def test_reverse_arith_ops(self):
    """Check arithmetic with a plain int on the left and a symbolic int on the right.

    A symbol named ``s1`` is created with value 2, and the results of
    ``5 // sym`` and ``5 * sym`` are compared against the same expressions
    evaluated on the plain integer 2.
    """
    env = ShapeEnv()

    # Floor division with the plain int as the left operand.
    divisor = env.create_symint("s1", 2)
    quotient_matches = 5 // divisor == 5 // 2
    self.assertTrue(quotient_matches)

    # Multiplication; the symbol is requested from the environment a second
    # time before it is used as the right operand.
    factor = env.create_symint("s1", 2)
    product_matches = 5 * factor == 5 * 2
    self.assertTrue(product_matches)
def test_arith_ops(self):
    """Compare binary arithmetic on symbolic ints with the same arithmetic on plain ints.

    Symbols ``s0`` .. ``s4`` are created from the values 0 .. 4. Every ordered
    pair of distinct entries is passed through add, sub, floordiv, mul and
    mod, and the symbolic result is compared with the plain-int result.
    """
    shape_env = ShapeEnv()
    # Each entry pairs the plain value with whatever create_symint returns for it.
    symints = [(i, shape_env.create_symint(f"s{i}", i)) for i in range(5)]

    ops = [operator.add, operator.sub, operator.floordiv, operator.mul, operator.mod]
    # The only operators for which a zero right-hand operand is an error.
    division_ops = (operator.floordiv, operator.mod)

    for op in ops:
        for (lhs_val, lhs_sym), (rhs_val, rhs_sym) in itertools.permutations(symints, 2):
            # Skip pairs whose left operand is a plain int rather than a
            # symbolic one (presumably create_symint can return a plain int
            # for some values -- confirm against ShapeEnv).
            if isinstance(lhs_sym, int):
                continue
            # Skip division by zero. The earlier guard,
            # `op != operator.mod or op != operator.floordiv`, was always true,
            # so zero right-hand operands were skipped for add/sub/mul too.
            if op in division_ops and rhs_val == 0:
                continue
            self.assertTrue(op(lhs_sym, rhs_sym) == op(lhs_val, rhs_val))
def test_meta_symint(self):
    """Check that a meta tensor sized by a symbolic int reports a symbolic first dimension.

    A symbol ``a0`` with value 2 is used as the size argument of
    ``torch.empty`` on the ``meta`` device, and the first entry of the
    resulting shape is expected to be an instance of ``CPP_SYMINT_CLASS``.
    """
    env = ShapeEnv()
    sym_size = env.create_symint("a0", 2)

    meta_tensor = torch.empty(sym_size, device='meta')
    leading_dim = meta_tensor.shape[0]

    self.assertIsInstance(leading_dim, CPP_SYMINT_CLASS)