def test_valid_shape_and_dtype(self):
    """A tensor matching both the declared shape and dtype is accepted.

    Exercises the module form (``Ensure``) and the functional form
    (``ensure``) with the same expectations; neither call may raise.
    """
    expected_shape, expected_dtype = (10, 5), torch.float32
    sample = torch.zeros(expected_shape, dtype=expected_dtype)

    # Module form.
    checker = Ensure(shape=expected_shape, dtype=expected_dtype)
    checker(sample)

    # Functional form, with shape and dtype passed positionally.
    ensure(sample, expected_shape, expected_dtype)
def test_invalid_shape(self, shape):
    """A ``(1, 2, 3)`` tensor is rejected when checked against ``shape``.

    ``shape`` is supplied by the caller (presumably a pytest
    parametrization that is not visible here -- confirm). Both the module
    and the functional form must raise ``ValueError`` with a message
    matching ``'input shape is'``.
    """
    sample = torch.zeros((1, 2, 3))
    checker = Ensure(shape=shape)

    # Same expectation for the module first, then the function version.
    for check in (checker, lambda t: ensure(t, shape)):
        with pytest.raises(ValueError, match='input shape is'):
            check(sample)
def test_torchscript_module(self):
    """A scripted ``Ensure`` module still validates the input shape.

    The module is compiled with ``torch.jit.script``; a matching tensor
    must pass, and a mismatching one must fail with ``torch.jit.Error``.
    """
    dtype = torch.float32
    declared_shape = (10, 5)
    matching = torch.zeros(declared_shape, dtype=dtype)

    scripted = torch.jit.script(Ensure(shape=declared_shape, dtype=dtype))
    scripted(matching)

    mismatching = torch.zeros((5, 5), dtype=dtype)
    # torchscript changes the exception type, so ValueError is not
    # what surfaces here.
    with pytest.raises(torch.jit.Error, match='input shape is'):
        scripted(mismatching)
def test_jit_module(self):
    """Behaviour of ``Ensure`` when traced with ``torch.jit.trace``.

    Three expectations are exercised in order:

    1. tracing with a matching tensor succeeds and the traced module
       accepts that tensor;
    2. the traced module also accepts a mismatching tensor;
    3. tracing the eager module with a mismatching tensor raises
       ``ValueError`` during the trace itself.
    """
    dtype = torch.float32
    declared_shape = (10, 5)
    matching = torch.zeros(declared_shape, dtype=dtype)
    module = Ensure(shape=declared_shape, dtype=dtype)

    traced = torch.jit.trace(module, (matching,))
    traced(matching)

    # An invalid tensor also passes, because checks are disabled
    # in the traced module.
    mismatching = torch.zeros((5, 5), dtype=dtype)
    traced(mismatching)

    # Tracing with a different shape fails during the trace process:
    # the eager module is run on the example input.
    with pytest.raises(ValueError, match='input shape is'):
        torch.jit.trace(module, (mismatching,))
def test_invalid_dtypes(self, dtype_t, dtype_c):
    """A tensor of dtype ``dtype_t`` is rejected when ``dtype_c`` is required.

    Both dtypes come from the caller (presumably a parametrization not
    visible here -- confirm). No shape constraint is given, so only the
    dtype check is exercised.
    """
    sample = torch.zeros(1, dtype=dtype_t)
    dtype_only_check = Ensure(shape=None, dtype=dtype_c)

    with pytest.raises(ValueError, match='input dtype'):
        dtype_only_check(sample)
def test_valid_dtypes(self, dtype):
    """A tensor whose dtype equals the required ``dtype`` is accepted.

    Checked through the module and through the functional form, each
    without a shape constraint; neither call may raise.
    """
    sample = torch.zeros(1, dtype=dtype)

    # Module form.
    dtype_only_check = Ensure(shape=None, dtype=dtype)
    dtype_only_check(sample)

    # Functional form, with ``None`` as the shape argument.
    ensure(sample, None, dtype)
def test_invalid_unknown_shape(self, shape_t, shape_c):
    """A tensor of shape ``shape_t`` is rejected by the constraint ``shape_c``.

    Both shapes come from the caller (presumably a parametrization not
    visible here; the test name suggests ``shape_c`` has unspecified
    dimensions -- confirm). The error is expected to be a ``ValueError``
    whose message matches ``'non broadcastable'``.
    """
    sample = torch.zeros(shape_t)
    shape_check = Ensure(shape=shape_c)

    with pytest.raises(ValueError, match='non broadcastable'):
        shape_check(sample)
def test_wrong_initialization(self):
    """Building ``Ensure`` with neither a shape nor a dtype is an error.

    The constructor must raise ``ValueError`` with a message matching
    ``'both arguments'``.
    """
    no_constraints = {'shape': None, 'dtype': None}

    with pytest.raises(ValueError, match='both arguments'):
        Ensure(**no_constraints)
def test_unknown_shape(self, shape_t, shape_c):
    """A tensor of shape ``shape_t`` satisfies the constraint ``shape_c``.

    Both shapes come from the caller (presumably a parametrization not
    visible here; the test name suggests ``shape_c`` has unspecified
    dimensions -- confirm). The call must not raise.
    """
    sample = torch.zeros(shape_t)
    shape_check = Ensure(shape=shape_c)

    shape_check(sample)
def test_nonbroadcastable_shape(self, shape_t, shape_c):
    """With ``broadcastable=True``, an incompatible shape is still rejected.

    ``shape_t`` (tensor shape) and ``shape_c`` (constraint) come from the
    caller, presumably a parametrization not visible here. The module
    must raise ``ValueError`` matching ``'non broadcastable'``.
    """
    sample = torch.zeros(shape_t)
    broadcast_check = Ensure(shape=shape_c, broadcastable=True)

    with pytest.raises(ValueError, match='non broadcastable'):
        broadcast_check(sample)
def test_broadcastable_shape(self, shape_t, shape_c):
    """With ``broadcastable=True``, a compatible shape is accepted.

    ``shape_t`` (tensor shape) and ``shape_c`` (constraint) come from the
    caller, presumably a parametrization not visible here. The call must
    not raise.
    """
    sample = torch.zeros(shape_t)
    broadcast_check = Ensure(shape=shape_c, broadcastable=True)

    broadcast_check(sample)
def test_valid_shape(self, shape):
    """A tensor built with exactly ``shape`` passes the shape check.

    ``shape`` comes from the caller, presumably a parametrization not
    visible here. Both the module and the functional form are exercised;
    neither call may raise.
    """
    sample = torch.zeros(shape)

    # Module version.
    shape_check = Ensure(shape=shape)
    shape_check(sample)

    # Function version.
    ensure(sample, shape)
def test_invalid_dtypes_with_cast(self, dtype_t, dtype_c):
    """With ``can_cast=True``, some dtype pairs are still rejected.

    ``dtype_t`` (tensor dtype) and ``dtype_c`` (required dtype) come from
    the caller, presumably a parametrization not visible here. The module
    must raise ``ValueError`` matching ``'be casted to'``.
    """
    sample = torch.zeros(1, dtype=dtype_t)
    castable_check = Ensure(shape=None, dtype=dtype_c, can_cast=True)

    with pytest.raises(ValueError, match='be casted to'):
        castable_check(sample)
def test_dtypes_with_cast(self, dtype_t, dtype_c):
    """With ``can_cast=True``, a compatible dtype pair is accepted.

    ``dtype_t`` (tensor dtype) and ``dtype_c`` (required dtype) come from
    the caller, presumably a parametrization not visible here. The call
    must not raise.
    """
    sample = torch.zeros(1, dtype=dtype_t)
    castable_check = Ensure(shape=None, dtype=dtype_c, can_cast=True)

    castable_check(sample)