tf.bitcast equivalent in pytorch?

Issue

This question is different from "tf.cast equivalent in pytorch?".

tf.bitcast does a bitwise reinterpretation (like reinterpret_cast in C++) instead of a "safe" type conversion.
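To make the difference concrete, here is a small sketch contrasting a numeric cast with a bitwise reinterpretation in PyTorch. It assumes float32 input and uses Tensor.to for the cast and Tensor.view(dtype) for the reinterpretation; both dtypes are 4 bytes wide, so the shape is unchanged.

```python
import torch

x = torch.tensor([1.0, -2.0], dtype=torch.float32)

# "Safe" conversion (tf.cast-like): each value is converted numerically.
cast = x.to(torch.int32)

# Bitwise reinterpretation (tf.bitcast-like): the same 32 bits are read as int32.
bits = x.view(torch.int32)

print(cast.tolist())  # [1, -2]
print(bits.tolist())  # [1065353216, -1073741824]  (0x3F800000, 0xC0000000)
```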

This operation is useful when you want to store a bfloat16 tensor with NumPy.

import torch
x = torch.ones(224, 224, 3, dtype=torch.bfloat16)
x_np = bitcast(x, torch.uint8).numpy()  # bitcast: the desired tf.bitcast-like function

Currently, NumPy does not natively support bfloat16, so x.numpy() raises TypeError: Got unsupported ScalarType BFloat16.

Solution

Use the second overload of torch.Tensor.view, view(dtype), which reinterprets the tensor's bytes as a different dtype.
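Applied to the example from the question, this might look like the sketch below. It assumes a PyTorch release in which Tensor.view(dtype) accepts a dtype with a different element size (the tensor must be contiguous in its last dimension). Each 2-byte bfloat16 becomes two uint8 bytes, so the last dimension doubles, and the same call in reverse restores the original tensor.

```python
import torch

x = torch.ones(224, 224, 3, dtype=torch.bfloat16)

# Reinterpret each 2-byte bfloat16 as two uint8 bytes; no data is copied.
x_u8 = x.view(torch.uint8)
print(x_u8.shape)  # torch.Size([224, 224, 6])

# uint8 is supported by NumPy, so this conversion succeeds.
x_np = x_u8.numpy()

# Round trip: wrap the bytes in a tensor again and reinterpret them as bfloat16.
x_back = torch.from_numpy(x_np).view(torch.bfloat16)
print(torch.equal(x_back, x))  # True
```

The uint8 array can then be stored with ordinary NumPy tools; the shape and the original dtype need to be known when reading it back.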

Its semantics closely match numpy.ndarray.view: when the new dtype has a different element size, the size of the last dimension is scaled accordingly.
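As a quick side-by-side check, this sketch views the same float32 data as uint8 in both libraries, assuming contiguous inputs. Both scale the last dimension by 4 (4-byte floats to 1-byte integers) and produce identical bytes.

```python
import numpy as np
import torch

a = np.ones((2, 3), dtype=np.float32)
t = torch.ones(2, 3, dtype=torch.float32)

# 4-byte floats viewed as 1-byte integers: the last dimension grows from 3 to 12.
a_u8 = a.view(np.uint8)
t_u8 = t.view(torch.uint8)

print(a_u8.shape)  # (2, 12)
print(t_u8.shape)  # torch.Size([2, 12])

# Both reinterpret the same underlying bytes, so the results agree element by element.
print((a_u8 == t_u8.numpy()).all())  # True
```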

Answered By – YouJiacheng

This answer was collected from Stack Overflow and is licensed under CC BY-SA 2.5, CC BY-SA 3.0 and CC BY-SA 4.0.
