def test_tensor_type_not_determined_by_inputs(self):
    @torch.jit.script
    def scalar_type_input(x, y, z):
        return x + y + 4 + z.item()

    x = torch.tensor([2, 2])
    scalar_type_input(x, x, torch.tensor(1))
    scalar_type_input(x, x, torch.tensor(1))
    scalar_type_input(x, x, torch.tensor(1.0))
    g = torch.jit.last_executed_optimized_graph()

    # item & add should not get pulled into the fusion group -
    # we expect to see Fusion Group, (item / add), Fusion Group in the IR dump
    FileCheck().check("TensorExpr").check(
        "Scalar = aten::item").check_next("Tensor = aten::add").check(
        "TensorExpr").run(g)

    @torch.jit.script
    def non_const_dtype(x, y, cond: bool):
        dtype = torch.int16 if cond else torch.int32
        return (x + y + 3).sum(dtype=dtype)

    non_const_dtype(x, x, True)
    non_const_dtype(x, x, True)
    g = torch.jit.last_executed_optimized_graph()
    # because dtype is non-const, sum should not get pulled into the fusion group
    FileCheck().check("TensorExpr").check("TensorExpr").check_not(
        "aten::sum").run(g)
def test_specialized_types(self):
    @torch.jit.script
    def test_fuse(a, b):
        c = a * b
        d = c * b
        return d

    x = torch.tensor([.5])
    for _ in range(3):
        test_fuse(x, x)

    g = torch.jit.last_executed_optimized_graph()
    # Types should remain specialized for typecheck outputs & fusion outputs
    FileCheck().check("Double(").check_same("prim::TypeCheck").check_same("\n") \
        .check("Double").check_same("TensorExpr").run(g)

    # other outputs should not be specialized
    FileCheck().check("Tensor = prim::If").run(g)
def test_not_fusing_scalar_ops(self):
    @torch.jit.script
    def foo(x: int, y: int):
        return x + y + 2 + 4 + 5 + 6

    foo(1, 2)
    foo(2, 3)
    g = torch.jit.last_executed_optimized_graph()
    # purely scalar arithmetic should not produce a TensorExpr fusion group
    FileCheck().check_not("TensorExpr").run(g)
def test_not_optimizing_property(self):
    @torch.jit.script
    def foo(x, y):
        return x + y + 1 + 2 + 3, x.size()

    x = torch.ones(1)
    foo(x, x)
    foo(x, x)
    g = torch.jit.last_executed_optimized_graph()
    # aten::size must remain in the graph, not be specialized to the profiled shape
    FileCheck().check("aten::size").run(g)
    x = torch.ones([2, 3, 5])
    self.assertEqual(foo(x, x), (x + x + 1 + 2 + 3, x.size()))
def test_tensor_constant(self):
    def foo(a, b):
        return a + b + torch.tensor([2])

    x = torch.ones(1, requires_grad=False)
    foo_script = torch.jit.script(foo)
    foo_script(x, x)
    foo_script(x, x)
    self.assertEqual(foo_script(x, x), foo(x, x))
    g = torch.jit.last_executed_optimized_graph()
    # the optimized graph dump should contain exactly two aten::add calls
    FileCheck().check_count("aten::add", 2, exactly=True).run(g)
def test_fallback_graph_not_specialized(self):
    @torch.jit.script
    def foo(a, b):
        c = a * b
        d = c * b
        e = d * b
        return d + e

    x = torch.ones(1)
    y = torch.ones(1)
    foo(x, y)
    foo(x, y)
    g = torch.jit.last_executed_optimized_graph()
    FileCheck().check("CallFunction").check_next("Tensor = prim::TupleUnpack").run(g)
def test_aliasing_merge(self):
    @torch.jit.script
    def foo(a, b):
        c = a * b
        d = c * b
        d.add_(b)
        e = d * b
        return d + e

    x = torch.ones(1)
    y = torch.ones(1)
    foo(x, y)
    b = foo(x, y)
    g = torch.jit.last_executed_optimized_graph()
    # the in-place add_ splits fusion into two groups, each guarded by a TypeCheck
    self.assertEqual(len(list(g.findAllNodes("prim::TypeCheck"))), 2)
    FileCheck().check("TensorExpr").check("aten::add_").check("TensorExpr").run(g)
def test_local_fusion_strategy(self):
    @torch.jit.script
    def foo(x):
        return x + x + x

    torch.jit.set_fusion_strategy([("STATIC", 1)])
    for _ in range(3):
        foo(torch.rand([10]))

    torch.jit.set_fusion_strategy([("STATIC", 10)])

    for i in range(10):
        foo(torch.rand([i]))
        foo(torch.rand([i]))

    g = torch.jit.last_executed_optimized_graph()
    FileCheck().check_count(":TensorExprGroup", 2, exactly=True).run(g)
def test_autograd_fallback_graph(self):
    @torch.jit.script
    def foo(a, b):
        c = a * b
        d = c * b
        e = d * b
        return d + e

    x = torch.ones(1, requires_grad=True)
    y = torch.ones(1, requires_grad=True)
    foo(x, y)
    b = foo(x, y)
    b.backward(torch.ones([1], dtype=torch.float), retain_graph=True)
    b.backward(torch.ones([1], dtype=torch.float))

    g = torch.jit.last_executed_optimized_graph()
    FileCheck().check("fallback_function").check_next("CallFunction").run(g)
def test_use_not_profiled(self):
    def foo(t1, t2, t3, t4, t: float):
        h = t1 + t2 + t3 + t4
        if t > 0.5:
            # Putting a use of t1 in a never-executed conditional prevents
            # that use from being profiled
            return t1 + 1
        return h

    t = torch.rand(8, dtype=torch.float)

    foo_script = torch.jit.script(foo)
    for _ in range(torch._C._jit_get_num_profiled_runs() + 1):
        foo_script(t, t, t, t, 0.1)

    self.assertEqual(foo(t, t, t, t, 0.1), foo_script(t, t, t, t, 0.1))
    g = torch.jit.last_executed_optimized_graph()
    # all adds fused
    FileCheck().check("graph").check_not("aten::add").check("prim::If").run(g)
def test_specialize_backward(self):
    def test_fuse(a, b):
        c = a * b
        d = c * b
        return d

    test_fuse.__disable_jit_function_caching__ = True

    scripted_f = torch.jit.script(test_fuse)
    x = torch.ones(1, requires_grad=True)
    y = torch.ones(1, requires_grad=True)
    scripted_f(x, y)
    b = scripted_f(x, y)
    warmup_backward(b)
    g = torch.jit.last_executed_optimized_graph()
    # The backward graph has an if node guarding specializations;
    # within its true block there is only one if node,
    # which guards a tensorexpr group
    optimized_block = next(g.findNode("prim::If").blocks())
    if_nodes = list(optimized_block.findAllNodes("prim::If"))
    self.assertEqual(len(if_nodes), 1)
    FileCheck().check("Group[Subgraph").run(str(if_nodes[0]))
    # no broadcasts occurred, so sum_to_size has been specialized out
    self.assertIsNone(optimized_block.findNode("aten::_grad_sum_to_size"))

    broadcast_f = torch.jit.script(test_fuse)
    x = torch.ones([2, 2], requires_grad=True)
    y = torch.ones([1], requires_grad=True)
    broadcast_f(x, y)
    b = broadcast_f(x, y)
    b.backward(torch.ones([2, 2], dtype=torch.float), retain_graph=True)
    b.backward(torch.ones([2, 2], dtype=torch.float))
    # warmup_backward(b, torch.ones([2, 2], dtype=torch.float))
    g = torch.jit.last_executed_optimized_graph()
    optimized_block = next(g.findNode("prim::If").blocks())
    # broadcasts occurred, so we currently expect to see aten::_grad_sum_to_size
    self.assertIsNotNone(
        optimized_block.findNode("aten::_grad_sum_to_size"))