예제 #1
0
    def test_fails_return_label_mismatch(self):
        def f(t1: [5, 'a'], t2: ['a', 5]) -> ['a']:
            return (t1.transpose(0, 1) * t2).sum(dim=0)

        t1 = torch.randn(3, 5)
        t2 = torch.randn(5, 3)

        with self.assertRaises(ShapeError):
            dimchecked(f)(t1, t2)
예제 #2
0
    def test_fails_backward_just_ellipsis(self):
        def f(t1: [..., 2], t2: [..., 2]):
            pass

        t1 = torch.randn(3, 3, 3, 2)
        t2 = torch.randn(5, 3, 1, 5)

        msg = "Size mismatch on dimension 3 of argument `t2` (found 5, expected 2)"
        with self.assertRaises(ShapeError) as ex:
            dimchecked(f)(t1, t2)
        self.assertEqual(str(ex.exception), msg)
예제 #3
0
    def test_fails_wrong_parameter(self):
        def f(t1: [3, 3], t2: [5, 3]) -> [3]:
            return (t1.transpose(0, 1) * t2).sum(dim=0)

        t1 = torch.randn(3, 5)
        t2 = torch.randn(5, 3)

        msg = "Size mismatch on dimension 1 of argument `t1` (found 5, expected 3)"
        with self.assertRaises(ShapeError) as ex:
            dimchecked(f)(t1, t2)
        self.assertEqual(str(ex.exception), msg)
예제 #4
0
    def test_fails_wrong_return(self):
        def f(t1: [3, 5], t2: [5, 3]) -> [5]:
            return (t1.transpose(0, 1) * t2).sum(dim=0)

        t1 = torch.randn(3, 5)
        t2 = torch.randn(5, 3)

        msg = ("Size mismatch on dimension 0 of argument "
               "`<return value>` (found 3, expected 5)")
        with self.assertRaises(ShapeError) as ex:
            dimchecked(f)(t1, t2)
        self.assertEqual(str(ex.exception), msg)
예제 #5
0
    def test_fails_backward_ellipsis_wildcard(self):
        def f(t1: [3, ..., 'a'], t2: [5, ..., 'a']):
            pass

        t1 = torch.randn(3, 3, 5)
        t2 = torch.randn(5, 3, 3)

        msg = ("Label `a` already had dimension 5 bound to it "
               "(based on tensor t1 of shape (3, 3, 5)), but it "
               "appears with dimension 3 in tensor t2")
        with self.assertRaises(ShapeError) as ex:
            dimchecked(f)(t1, t2)
        self.assertEqual(str(ex.exception), msg)
예제 #6
0
    def test_wrap_no_anno(self):
        def f(t1, t2):  # t1: [3, 5], t2: [5, 3] -> [3]
            return (t1.transpose(0, 1) * t2).sum(dim=0)

        t1 = torch.randn(3, 5)
        t2 = torch.randn(5, 3)

        self.assertTrue((f(t1, t2) == dimchecked(f)(t1, t2)).all())
예제 #7
0
    def test_succeeds_ellipsis(self):
        def f(t1: [5, ..., 'a'], t2: ['a', ..., 5]) -> ['a']:
            return (t1.transpose(0, 3) * t2).sum(dim=(1, 2, 3))

        t1 = torch.randn(5, 1, 2, 3)
        t2 = torch.randn(3, 1, 2, 5)

        self.assertTrue((f(t1, t2) == dimchecked(f)(t1, t2)).all())