def test_fixed_shape_basic():
    """Check how ``TensorType`` derives ``shape`` and ``broadcastable``.

    Also checks zero-value construction, the string form of a fully static
    type, and basic ``clone`` behaviour.
    """
    # (shape argument, expected ``shape``, expected ``broadcastable``)
    cases = [
        ((1, 1), (1, 1), (True, True)),
        ((0,), (0,), (False,)),
        # Boolean entries are accepted; per the expected values, False maps
        # to an unknown (None) dimension.
        ((False, False), (None, None), (False, False)),
    ]
    for shape_arg, want_shape, want_bcast in cases:
        ttype = TensorType("float64", shape_arg)
        assert ttype.shape == want_shape
        assert ttype.broadcastable == want_bcast

    # A fully static type: zeros built from its shape must match it exactly.
    static = TensorType("float64", (2, 3))
    assert static.shape == (2, 3)
    assert static.broadcastable == (False, False)
    assert static.value_zeros(static.shape).shape == static.shape
    assert str(static) == "TensorType(float64, (2, 3))"

    single = TensorType("float64", (1,))
    assert single.shape == (1,)
    assert single.broadcastable == (True,)

    # A plain clone is a distinct object that still compares equal.
    duplicate = single.clone()
    assert duplicate is not single
    assert duplicate == single

    # Overrides passed to clone take effect on the new type.
    converted = single.clone(dtype="float32", shape=(2, 4))
    assert converted.dtype == "float32"
    assert converted.shape == (2, 4)
def test_fixed_shape_clone():
    """Check that ``TensorType.clone`` applies ``dtype``/``shape`` overrides.

    Also checks that cloning leaves the source type untouched.
    """
    t1 = TensorType("float64", (1,))

    t2 = t1.clone(dtype="float32", shape=(2, 4))
    assert t2.dtype == "float32"
    assert t2.shape == (2, 4)
    assert t2.broadcastable == (False, False)

    # Boolean shape entries are accepted; False becomes an unknown (None) dim.
    t2 = t1.clone(dtype="float32", shape=(False, False))
    assert t2.dtype == "float32"
    assert t2.shape == (None, None)
    assert t2.broadcastable == (False, False)

    # Cloning with overrides must not mutate the original type.
    assert t1.dtype == "float64"
    assert t1.shape == (1,)
    assert t1.broadcastable == (True,)
def test_deprecated_kwargs():
    """Check the deprecated ``broadcastable`` keyword.

    Passing it to the constructor or to ``clone`` must emit a
    ``DeprecationWarning`` and still be translated into the equivalent
    ``shape`` (True -> 1, False -> None).
    """
    pattern = ".*broadcastable.*"

    with pytest.warns(DeprecationWarning, match=pattern):
        original = TensorType("float64", broadcastable=(True, False))
    assert original.shape == (1, None)

    with pytest.warns(DeprecationWarning, match=pattern):
        recloned = original.clone(broadcastable=(False, True))
    assert recloned.shape == (None, 1)