def test_ones_like_explicit_dtype(self, t):
    """Test that the ones like function creates the correct shape and
    type tensor when an explicit dtype is requested.

    The result of ``fn.ones_like(t, dtype=np.float16)`` must have the same
    shape and interface as ``t``, compare close to an array of ones, and
    hold ``float16`` data.
    """
    res = fn.ones_like(t, dtype=np.float16)

    # Lists and tuples have no ``shape`` attribute; convert them to an
    # array so the comparisons below can be made.
    if isinstance(t, (list, tuple)):
        t = onp.asarray(t)

    assert res.shape == t.shape
    assert fn.get_interface(res) == fn.get_interface(t)
    assert fn.allclose(res, np.ones(t.shape))

    # Results exposing ``numpy()`` (e.g. TensorFlow or PyTorch tensors) are
    # unwrapped so the dtype of the underlying data can be inspected.
    # Only ``res`` is converted: ``t`` is not used past this point.
    if hasattr(res, "numpy"):
        res = res.numpy()

    assert onp.asarray(res).dtype.type is np.float16
def test_where(t):
    """Check that ``fn.where`` replaces the negative entries of ``t`` with 100
    and leaves every other entry untouched."""
    expected = np.array(
        [
            [[1, 2], [3, 4], [100, 1]],
            [[5, 6], [0, 100], [2, 1]],
        ]
    )

    # Select the negative entries and build the values that replace them.
    negative_mask = t < 0
    replacement = 100 * fn.ones_like(t)

    res = fn.where(negative_mask, replacement, t)
    assert fn.allclose(res, expected)