コード例 #1
0
    def test_reverse_arith_ops(self):
        """Check arithmetic with a plain int on the left and a symbolic int on the right.

        Covers floor division and multiplication; each result is compared
        against the same expression computed on plain ints.
        """
        env = ShapeEnv()

        # int // symint
        sym_two = env.create_symint("s1", 2)
        floordiv_matches = (5 // sym_two == 5 // 2)
        self.assertTrue(floordiv_matches)

        # int * symint, using a freshly created symbol
        sym_two = env.create_symint("s1", 2)
        mul_matches = (5 * sym_two == 5 * 2)
        self.assertTrue(mul_matches)
コード例 #2
0
    def test_arith_ops(self):
        """Check binary arithmetic on symbolic ints against plain-int arithmetic.

        Creates symbols ``s0`` .. ``s4`` with values 0..4 and, for every ordered
        pair of distinct entries, asserts that applying each operator to the
        symbolic operands equals applying it to the underlying int values.

        Pairs whose right-hand value is 0 are skipped only for floor division
        and modulo. The previous guard,
        ``op != operator.mod or op != operator.floordiv``, was always true, so
        zero right-hand operands were skipped for add, sub and mul as well.
        """
        shape_env = ShapeEnv()
        symints = [(i, shape_env.create_symint(f"s{i}", i)) for i in range(5)]

        ops = [operator.add, operator.sub, operator.floordiv, operator.mul, operator.mod]
        # The only operators here that fail on a zero right-hand operand.
        zero_divisor_ops = (operator.floordiv, operator.mod)

        for op in ops:
            for (lhs_val, lhs_sym), (rhs_val, rhs_sym) in itertools.permutations(symints, 2):
                # Kept from the original guard: skip pairs whose left operand is
                # a plain int. NOTE(review): presumably create_symint can hand
                # back a concrete int for some values -- confirm.
                if isinstance(lhs_sym, int):
                    continue
                if op in zero_divisor_ops and rhs_val == 0:
                    continue
                self.assertTrue(op(lhs_sym, rhs_sym) == op(lhs_val, rhs_val))
コード例 #3
0
 def test_meta_symint(self):
     """Check that a meta tensor sized by a symbolic int reports a symbolic dim.

     Builds a tensor on the 'meta' device from a symbol created with value 2
     and asserts its first shape entry is an instance of CPP_SYMINT_CLASS.
     """
     env = ShapeEnv()
     sym_size = env.create_symint("a0", 2)
     meta_tensor = torch.empty(sym_size, device='meta')
     leading_dim = meta_tensor.shape[0]
     self.assertIsInstance(leading_dim, CPP_SYMINT_CLASS)