コード例 #1
0
 def test_valid_shape_and_dtype(self):
     """A tensor whose shape and dtype both match must pass every check.

     Both the ``Ensure`` module and the functional ``ensure`` helper are
     exercised; the test fails if either one raises.
     """
     expected_shape = (10, 5)
     expected_dtype = torch.float32
     sample = torch.zeros(expected_shape, dtype=expected_dtype)
     checker = Ensure(shape=expected_shape, dtype=expected_dtype)
     checker(sample)
     ensure(sample, expected_shape, expected_dtype)
コード例 #2
0
 def test_invalid_shape(self, shape):
     """A tensor of the wrong shape is rejected by both entry points.

     A (1, 2, 3) tensor is checked against ``shape``, which is expected
     not to match it. The module form and the function form must each
     raise ``ValueError`` with a message matching 'input shape is'.
     """
     sample = torch.zeros((1, 2, 3))
     checker = Ensure(shape=shape)
     # Module call first, then the functional helper, as separate checks.
     for run_check in (lambda: checker(sample), lambda: ensure(sample, shape)):
         with pytest.raises(ValueError, match='input shape is'):
             run_check()
コード例 #3
0
    def test_torchscript_module(self):
        """A scripted ``Ensure`` accepts a match and rejects a mismatch.

        The module is compiled with ``torch.jit.script``. A (10, 5) float32
        tensor passes. A (5, 5) tensor must fail, and under TorchScript the
        failure surfaces as ``torch.jit.Error`` rather than ``ValueError``.
        """
        float_type = torch.float32
        good_shape = (10, 5)
        matching = torch.zeros(good_shape, dtype=float_type)
        scripted = torch.jit.script(Ensure(shape=good_shape, dtype=float_type))
        scripted(matching)

        mismatched = torch.zeros((5, 5), dtype=float_type)
        # TorchScript re-raises the check failure with its own exception type.
        with pytest.raises(torch.jit.Error, match='input shape is'):
            scripted(mismatched)
コード例 #4
0
    def test_jit_module(self):
        """A traced ``Ensure`` skips checks when called, but not while tracing.

        After ``torch.jit.trace`` with a matching (10, 5) float32 example,
        the traced module accepts even a mismatched (5, 5) tensor. Tracing
        the original module again with the mismatched tensor as the example
        must raise ``ValueError`` during the trace itself.
        """
        float_type = torch.float32
        good_shape = (10, 5)
        matching = torch.zeros(good_shape, dtype=float_type)
        checker = Ensure(shape=good_shape, dtype=float_type)
        traced = torch.jit.trace(checker, (matching,))
        traced(matching)

        # The traced module does not enforce the checks, so this passes.
        mismatched = torch.zeros((5, 5), dtype=float_type)
        traced(mismatched)

        # Tracing with a mismatched example fails while the trace is built.
        with pytest.raises(ValueError, match='input shape is'):
            torch.jit.trace(checker, (mismatched,))
コード例 #5
0
 def test_invalid_dtypes(self, dtype_t, dtype_c):
     """A dtype-only ``Ensure`` rejects a tensor of a different dtype.

     Args:
         dtype_t: dtype of the one-element tensor under test.
         dtype_c: dtype the module is configured to require.
     """
     sample = torch.zeros(1, dtype=dtype_t)
     dtype_checker = Ensure(shape=None, dtype=dtype_c)
     with pytest.raises(ValueError, match='input dtype'):
         dtype_checker(sample)
コード例 #6
0
 def test_valid_dtypes(self, dtype):
     """A tensor of ``dtype`` passes a dtype-only check of the same dtype.

     Shape checking is disabled (``shape=None``). Both the module form and
     the functional ``ensure`` form must complete without raising.
     """
     sample = torch.zeros(1, dtype=dtype)
     Ensure(shape=None, dtype=dtype)(sample)
     ensure(sample, None, dtype)
コード例 #7
0
 def test_invalid_unknown_shape(self, shape_t, shape_c):
     """An incompatible tensor shape is rejected by a shape-only ``Ensure``.

     Args:
         shape_t: shape of the zero tensor under test.
         shape_c: expected shape given to ``Ensure``. The name suggests it
             contains unspecified dimensions; confirm against the
             parametrization.
     """
     sample = torch.zeros(shape_t)
     shape_checker = Ensure(shape=shape_c)
     with pytest.raises(ValueError, match='non broadcastable'):
         shape_checker(sample)
コード例 #8
0
 def test_wrong_initialization(self):
     """Building ``Ensure`` with neither a shape nor a dtype must raise."""
     with pytest.raises(ValueError, match='both arguments'):
         Ensure(dtype=None, shape=None)
コード例 #9
0
 def test_unknown_shape(self, shape_t, shape_c):
     """A tensor of shape ``shape_t`` passes a check configured with ``shape_c``.

     The name suggests ``shape_c`` leaves some dimensions unspecified;
     confirm against the parametrization. The test fails if the call raises.
     """
     sample = torch.zeros(shape_t)
     Ensure(shape=shape_c)(sample)
コード例 #10
0
 def test_nonbroadcastable_shape(self, shape_t, shape_c):
     """With ``broadcastable=True``, an incompatible shape still raises.

     Args:
         shape_t: shape of the zero tensor under test.
         shape_c: target shape given to ``Ensure``.
     """
     sample = torch.zeros(shape_t)
     broadcast_checker = Ensure(shape=shape_c, broadcastable=True)
     with pytest.raises(ValueError, match='non broadcastable'):
         broadcast_checker(sample)
コード例 #11
0
 def test_broadcastable_shape(self, shape_t, shape_c):
     """With ``broadcastable=True``, a compatible shape passes the check.

     The test fails if calling the module on the ``shape_t`` tensor raises.
     """
     sample = torch.zeros(shape_t)
     Ensure(shape=shape_c, broadcastable=True)(sample)
コード例 #12
0
 def test_valid_shape(self, shape):
     """A tensor built with exactly ``shape`` passes a shape-only check.

     The module form runs first, then the functional ``ensure`` form;
     neither may raise.
     """
     sample = torch.zeros(shape)
     checker = Ensure(shape=shape)
     for run_check in (checker, lambda t: ensure(t, shape)):
         run_check(sample)
コード例 #13
0
 def test_invalid_dtypes_with_cast(self, dtype_t, dtype_c):
     """With ``can_cast=True``, a non-castable dtype pairing still raises.

     Args:
         dtype_t: dtype of the one-element tensor under test.
         dtype_c: dtype the module is configured to require.
     """
     sample = torch.zeros(1, dtype=dtype_t)
     cast_checker = Ensure(shape=None, dtype=dtype_c, can_cast=True)
     with pytest.raises(ValueError, match='be casted to'):
         cast_checker(sample)
コード例 #14
0
 def test_dtypes_with_cast(self, dtype_t, dtype_c):
     """With ``can_cast=True``, a castable dtype pairing passes the check.

     The test fails if calling the module on the ``dtype_t`` tensor raises.
     """
     sample = torch.zeros(1, dtype=dtype_t)
     Ensure(shape=None, dtype=dtype_c, can_cast=True)(sample)