def test_reverse_arith_ops(self):
    """A python int on the LEFT of ``//`` / ``*`` must work with a SymInt on the right.

    Python falls back to the right operand's reflected method when ``int``'s
    own operator does not handle the SymInt, so these exercise the reflected ops.
    """
    env = ShapeEnv()

    # int // SymInt
    sym = env.create_symint("s1", 2)
    self.assertTrue(5 // sym == 5 // 2)

    # int * SymInt, using a freshly created symint just like the floordiv case
    sym = env.create_symint("s1", 2)
    self.assertTrue(5 * sym == 5 * 2)
def test_aten_ops(self):
    """Call aten ``.SymInt`` overloads directly with symbolic sizes as arguments."""
    # narrow_copy across the full (symbolic) length of a 1-D tensor
    env = ShapeEnv()
    t = create_symbolic_tensor("x", torch.randn(5), env)
    torch.ops.aten.narrow_copy.SymInt(t, 0, 0, t.shape[0])

    # expand a 3-D tensor to its own symbolic shape, passed as a list
    env = ShapeEnv()
    t = create_symbolic_tensor("x", torch.randn(5, 4, 3), env)
    target_sizes = [t.shape[dim] for dim in range(3)]
    torch.ops.aten.expand.SymInt(t, target_sizes)
def test_arith_ops(self):
    """Check SymInt binary arithmetic against the same op on plain python ints.

    Every ordered pair of distinct symints built from the values 0..4 is
    combined with each op. Only ``floordiv`` and ``mod`` skip a zero right
    operand, because those are the ops that would divide by zero. Previously
    the guard was a tautology (``op != mod or op != floordiv``), which skipped
    zero for *every* op and left add/sub/mul with 0 untested.
    """
    shape_env = ShapeEnv()
    # (concrete value, symint built from that value) pairs
    symints = [(i, shape_env.create_symint(f"s{i}", i)) for i in range(5)]
    ops = [operator.add, operator.sub, operator.floordiv, operator.mul, operator.mod]
    # Ops that raise ZeroDivisionError when the right operand is 0.
    divides = (operator.floordiv, operator.mod)
    for op in ops:
        for (lhs_val, lhs_sym), (rhs_val, rhs_sym) in itertools.permutations(symints, 2):
            # create_symint may return a plain int (presumably for specialized
            # values); only a symbolic left operand exercises the SymInt op.
            if isinstance(lhs_sym, int):
                continue
            if op in divides and rhs_val == 0:
                continue
            self.assertTrue(op(lhs_sym, rhs_sym) == op(lhs_val, rhs_val))
def test_size_expressions(self):
    """Branching on a symbolic size records a sympy guard on that size.

    Both branches of the ``if`` are deliberately identical: what matters is
    the ``> 3`` comparison, which is expected to leave a ``StrictGreaterThan``
    guard as the first entry of ``shape_env.guards``.
    """
    shape_env = ShapeEnv()
    x = create_symbolic_tensor("x", torch.randn(5), shape_env)
    expand_x = x.expand(x.shape[0], x.shape[0])
    if expand_x.shape[0] > 3:
        result = expand_x + expand_x
    else:
        result = expand_x + expand_x

    gt_op = shape_env.guards[0][0]
    self.assertIsInstance(gt_op, sympy.core.relational.StrictGreaterThan)
    # assertTrue(a, b) treats b as the failure message and passes for any
    # non-empty string a; compare the string forms for real instead.
    self.assertEqual(str(x.shape[0]), str(gt_op.args[0]))
    self.assertEqual(str(expand_x.shape[1]), str(x.shape[0]))
    self.assertEqual(str(expand_x.shape[1]), str(result.shape[0]))
def test_symint_args(self):
    """narrow_copy accepts a symbolic length, either bare or inside arithmetic."""
    shape_env = ShapeEnv()
    x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)
    y = create_symbolic_tensor("y", torch.randn(5, 4, 1), shape_env)
    dim = 2  # last dimension of both tensors

    # length is a bare symint taken from another tensor (y's last dim == 1)
    narrowed = x.narrow_copy(dim, 0, y.shape[dim])
    self.assertTrue(narrowed.shape[2] == int(y.shape[2]))

    # length is an expression combining two symints: 3 - 1
    narrowed = x.narrow_copy(dim, 0, x.shape[dim] - y.shape[dim])
    self.assertTrue(narrowed.shape[2] == 2)

    # length is an expression mixing a symint and a python int: 3 - 1
    narrowed = x.narrow_copy(dim, 0, x.shape[dim] - 1)
    self.assertTrue(narrowed.shape[2] == 2)
def test_binary(self):
    """Elementwise add of symbolic tensors, same-shape and then with broadcasting."""
    shape_env = ShapeEnv()
    x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)
    expected = (5, 4, 3)

    # First an identically shaped operand, then one that broadcasts on dims 0 and 2.
    for y_shape in ((5, 4, 3), (1, 4, 1)):
        y = create_symbolic_tensor("y", torch.randn(*y_shape), shape_env)
        z = x + y
        for dim, size in enumerate(expected):
            self.assertTrue(z.shape[dim] == size)
def test_symint_vargs(self):
    """Tensor.expand accepts symints as varargs or in a sequence, mixed with ints."""
    shape_env = ShapeEnv()
    x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)
    y = create_symbolic_tensor("y", torch.randn(1, 4, 1), shape_env)

    def check_shape(t):
        # Every expand below targets the concrete shape (5, 4, 3).
        for dim, size in enumerate((5, 4, 3)):
            self.assertTrue(t.shape[dim] == size)

    # varargs
    check_shape(y.expand(x.shape[0], y.shape[1], x.shape[2]))
    # shape list
    check_shape(y.expand((x.shape[0], y.shape[1], x.shape[2])))
    # mixed python symints and ints
    check_shape(y.expand(x.shape[0], y.shape[1], 3))
    # mixed python symints and ints in a list
    check_shape(y.expand((x.shape[0], y.shape[1], 3)))
    # mixed python ints and symints
    check_shape(y.expand(5, y.shape[1], x.shape[2]))
    # mixed python ints and symints in a list
    check_shape(y.expand((5, y.shape[1], x.shape[2])))

    # A lone symint, wrapped in a tuple and as a single vararg; only checks
    # that the calls go through, the results are not inspected.
    y.expand((y.shape[1],))
    y.expand(y.shape[1])
def test_roundtrip(self):
    """Sizes of a symbolic tensor come back as C++ SymInt objects with the right values.

    Checks ``x.shape[i]``, ``x.size()[i]`` and ``x.size(i)`` for every dim.
    Each value check uses ``==``. ``assertTrue(a, b)`` would treat ``b`` as the
    failure message and pass vacuously, which is how three of these checks
    were previously written.
    """
    shape_env = ShapeEnv()
    x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)
    self.assertNotIsInstance(x.shape[0], PySymInt)
    self.assertIsInstance(x.shape[0], CPP_SYMINT_CLASS)

    # via .shape
    self.assertTrue(x.shape[0] == 5)
    self.assertTrue(x.shape[1] == 4)
    self.assertTrue(x.shape[2] == 3)

    # via .size()[i]
    self.assertTrue(x.size()[0] == 5)
    self.assertTrue(x.size()[1] == 4)
    self.assertIsInstance(x.size()[1], CPP_SYMINT_CLASS)
    self.assertTrue(x.size()[2] == 3)

    # via .size(i)
    self.assertTrue(x.size(0) == 5)
    self.assertTrue(x.size(1) == 4)
    self.assertTrue(x.size(2) == 3)
    self.assertIsInstance(x.size(2), CPP_SYMINT_CLASS)
def test_meta_symint(self):
    """torch.empty on the meta device accepts a symint size and reports it back as a SymInt."""
    env = ShapeEnv()
    length = env.create_symint("a0", 2)
    self.assertIsInstance(torch.empty(length, device='meta').shape[0], CPP_SYMINT_CLASS)