def test_cast_tensorflow_dtype(self):
    """Casting a TensorFlow object with a native TensorFlow dtype works.

    A ``tf.Variable`` of integers is cast using ``tf.complex128`` directly
    (rather than a string or NumPy dtype); the result must be a ``tf.Tensor``
    whose dtype is exactly ``tf.complex128``.
    """
    variable = tf.Variable([1, 2, 3])
    target_dtype = tf.complex128

    output = fn.cast(variable, target_dtype)

    assert isinstance(output, tf.Tensor)
    # identity check: the dtype object itself must be tf.complex128
    assert output.dtype is target_dtype
def test_cast_torch_dtype(self):
    """Casting a Torch tensor with a native Torch dtype works.

    An ``int64`` tensor is cast using ``torch.float64`` directly; the result
    must still be a ``torch.Tensor`` and carry exactly ``torch.float64``.
    """
    source = torch.tensor([1, 2, 3], dtype=torch.int64)
    target_dtype = torch.float64

    output = fn.cast(source, target_dtype)

    assert isinstance(output, torch.Tensor)
    # identity check: the dtype object itself must be torch.float64
    assert output.dtype is target_dtype
def test_cast_numpy_string(self, t):
    """Casting with a NumPy dtype given as a string (``"float64"``) works.

    The result must stay in the same interface as the input ``t``. Its
    underlying data must have type ``numpy.float64``. The input must still
    have an integer dtype after the cast.
    """
    casted = fn.cast(t, "float64")
    assert fn.get_interface(casted) == fn.get_interface(t)

    original = t
    if hasattr(casted, "numpy"):
        # objects exposing .numpy() (e.g. TensorFlow or PyTorch tensors):
        # compare against views of the underlying data instead
        casted, original = casted.numpy(), t.numpy()

    assert onp.issubdtype(onp.asarray(original).dtype, onp.integer)
    assert casted.dtype.type is onp.float64