def test_nvfuser_capability_context(self, device):
    # This test is to ensure that torch calls are replaced with refs
    # based on the nvfuser+prims capability
    from torch.fx.experimental.proxy_tensor import make_fx
    from torch._prims.context import TorchRefsNvfuserCapabilityMode

    # It's assumed that digamma is not supported by nvfuser
    # If it's ever supported, this test will need to be updated
    self.assertTrue(getattr(torch.ops.nvprims, "digamma", None) is None)

    a = torch.randn(3, 3, device=device)

    def func(a):
        return torch.digamma(a)

    with TorchRefsNvfuserCapabilityMode():
        gm = make_fx(func)(a)

    # Check that torch.digamma is not replaced with torch.ops.prims.digamma
    call_function_nodes = list(filter(lambda n: n.op == "call_function", gm.graph.nodes))
    includes_aten_digamma = any(
        torch.ops.aten.digamma.default == node.target
        for node in call_function_nodes
    )
    includes_prims_digamma = any(
        torch.ops.prims.digamma.default == node.target
        for node in call_function_nodes
    )
    self.assertTrue(includes_aten_digamma)
    self.assertFalse(includes_prims_digamma)

    # Check the mixed case: sigmoid is replaced with refs, but digamma is not
    def func(a):
        return torch.sigmoid(torch.digamma(a))

    with TorchRefsNvfuserCapabilityMode():
        gm = make_fx(func)(a)

    call_function_nodes = list(filter(lambda n: n.op == "call_function", gm.graph.nodes))
    includes_aten_sigmoid = any(
        torch.ops.aten.sigmoid.default == node.target
        for node in call_function_nodes
    )
    includes_prims_digamma = any(
        torch.ops.prims.digamma.default == node.target
        for node in call_function_nodes
    )
    includes_nvprims_exp = any(
        torch.ops.nvprims.exp.default == node.target
        for node in call_function_nodes
    )
    self.assertFalse(includes_aten_sigmoid)
    self.assertFalse(includes_prims_digamma)
    self.assertTrue(includes_nvprims_exp)
def test_decomposition_interpreter(self):
    def fn(x):
        return torch.nn.functional.silu(x)

    x = torch.rand((4, 4))
    fx_module = make_fx(fn, decomposition_table=None)(x)

    found_silu = False
    for n in fx_module.graph.nodes:
        if n.target == torch.ops.aten.silu or n.target == torch.ops.aten.silu.default:
            found_silu = True

    self.assertTrue(found_silu)

    new_graph = torch.fx.Graph()
    silu_decomp_table = {torch.ops.aten.silu.default: decomposition_table[torch.ops.aten.silu.default]}
    DecompositionInterpreter(
        fx_module,
        new_graph=new_graph,
        decomposition_table=silu_decomp_table,
    ).run(x)

    decomposed_module = torch.fx.GraphModule(fx_module, new_graph)

    for n in decomposed_module.graph.nodes:
        self.assertTrue(n.target != torch.ops.aten.silu)
        self.assertTrue(n.target != torch.ops.aten.silu.default)

    self.assertEqual(fx_module(x), decomposed_module(x))
def test_reinplace_scatter_twice_with_different_view_op_invalid2(self):
    def f(a_):
        a = a_.clone()
        b = a[:, 1]
        c = b[1]
        c_updated = c.add(1)
        bad_mirror_of_b = a.as_strided((4,), (4,), 0)
        # The first arg to select_scatter points to a different tensor than c's base.
        # This makes it invalid to re-inplace.
        b_updated = torch.select_scatter(bad_mirror_of_b, c_updated, 0, 1)
        return b_updated

    inpt = torch.ones(4, 4)
    f2 = reinplace(make_fx(f)(inpt), inpt)
    expected_out = f(inpt)
    actual_out = f2(inpt)
    # self.assertEqual(actual_out, expected_out)
    self.assertExpectedInline(f2.code, """\
def forward(self, a__1):
    clone_default = torch.ops.aten.clone.default(a__1); a__1 = None
    slice_tensor = torch.ops.aten.slice.Tensor(clone_default, 0, 0, 9223372036854775807)
    select_int = torch.ops.aten.select.int(slice_tensor, 1, 1); slice_tensor = None
    select_int_1 = torch.ops.aten.select.int(select_int, 0, 1); select_int = None
    add_tensor = torch.ops.aten.add.Tensor(select_int_1, 1); select_int_1 = None
    as_strided_default = torch.ops.aten.as_strided.default(clone_default, [4], [4], 0); clone_default = None
    select_int_2 = torch.ops.aten.select.int(as_strided_default, 0, 1)
    copy__default = torch.ops.aten.copy_.default(select_int_2, add_tensor); select_int_2 = add_tensor = None
    return as_strided_default
    """)  # noqa: B950
def test_nvfuser_executor_partitioned_no_partitions_error(self, device):
    # This test is to ensure that the nvfuser partitioned executor warns
    # when the graph contains no executable partitions
    # It's assumed that digamma is not supported by nvfuser
    # If it's ever supported, this test will need to be updated
    self.assertTrue(torch.ops.prims.digamma.default.impl_nvfuser is None)

    from torch.fx.experimental.proxy_tensor import make_fx
    from torch._prims.context import TorchRefsMode
    from torch._prims.executor import execute

    a = torch.randn(3, 4, device=device)

    def func(a):
        return torch.digamma(a)  # not supported by nvfuser

    with TorchRefsMode.push():
        gm = make_fx(func)(a)

    with catch_warnings(record=True) as w:
        # Trigger warning
        execute(gm, a, executor="nvfuser")
        # Check warning occurs
        self.assertEqual(len(w), 1)
        self.assertTrue("is not supported by nvFuser" in str(w[-1].message))
def _traced(*args, executor="aten", **kwargs): # TODO: caching wrapped, all_args = wrapper_and_args_for_make_fx(fn, args, kwargs) with NvfuserPrimsMode(), TorchRefsMode(): gm = make_fx(wrapped)(all_args) return execute(gm, all_args, executor=executor)
def test_scalar_device(self, device):
    def f(a, b):
        return a + b

    inps = [torch.randn(3, device=device), torch.tensor(5)]
    fx_f = make_fx(f)(*inps)
    self.assertEqual(fx_f(*inps), f(*inps))
def test_aten_where(self, device, dtype):
    def fn(x):
        where = torch.ops.aten.where(x < 0, -x, x)
        a = where + 1
        b = a + 1
        return b

    inputs = torch.randn(4, device=device)
    traced = make_fx(fn)(inputs)

    nvfuser = NvFuserBackend()
    compiled_module = nvfuser.compile(copy.deepcopy(traced))

    for node in compiled_module.graph.nodes:
        if node.op == "call_function":
            assert "fused" in str(node.target), \
                "the entire function should be fused into a single fusion group"

    eager_result = traced(inputs)
    nvfuser_result = compiled_module(inputs)
    torch.testing.assert_close(eager_result, nvfuser_result, rtol=1e-5, atol=1e-5)
def test_reinplace_with_view(self):
    def f(x):
        a = x.clone()
        a_view = a.view(-1)
        # We shouldn't re-inplace the first add(), because an alias of a is re-used later in the program
        b = a.add(1)
        # Second add() is fine to re-inplace
        c = a_view.add(1)
        return c

    inpt = torch.ones(2)
    f2 = reinplace(make_fx(f)(inpt), inpt)
    expected_out = f(inpt)
    actual_out = f2(inpt)
    self.assertEqual(actual_out, expected_out)
    self.assertExpectedInline(f2.code, """\
def forward(self, x_1):
    clone_default = torch.ops.aten.clone.default(x_1); x_1 = None
    view_default = torch.ops.aten.view.default(clone_default, [-1])
    add_tensor = torch.ops.aten.add.Tensor(clone_default, 1); clone_default = None
    add_tensor_1 = torch.ops.aten.add_.Tensor(view_default, 1)
    return view_default
    """)
def test_reinplace_different_metadata(self):
    def f(a_):
        a = a_.clone()
        b = a + 1
        # Naively, we shouldn't try to inplace the .ge() call,
        # because that would require resizing "b" (from a float to a bool tensor).
        c = torch.ge(b, a)
        return c

    inpt = torch.ones(4)
    f2 = reinplace(make_fx(f)(inpt), inpt)
    expected_out = f(inpt)
    actual_out = f2(inpt)
    self.assertEqual(actual_out, expected_out)
    # The .ge() should not be reinplaced.
    self.assertExpectedInline(f2.code, """\
def forward(self, a__1):
    clone_default = torch.ops.aten.clone.default(a__1); a__1 = None
    add_tensor = torch.ops.aten.add.Tensor(clone_default, 1)
    ge_tensor = torch.ops.aten.ge.Tensor(add_tensor, clone_default); add_tensor = clone_default = None
    return ge_tensor
    """)
def test_out_node_updated(self):
    def f():
        x = torch.zeros(2, 2)
        y = x.diagonal()
        y_updated = y.add(1)
        z = torch.diagonal_scatter(x, y_updated)
        # reinplace needs to know to replace output [z] with [x]
        return [z]

    if not HAS_FUNCTIONALIZATION:
        return

    f2 = reinplace(make_fx(functionalize(f))())
    expected_out = f()
    actual_out = f2()
    self.assertEqual(actual_out, expected_out)
    self.assertExpectedInline(f2.code, """\
def forward(self):
    zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False)
    diagonal_default = torch.ops.aten.diagonal.default(zeros)
    add_tensor = torch.ops.aten.add_.Tensor(diagonal_default, 1); diagonal_default = None
    return [zeros]
    """)
def test_reinplace_index_mutation(self):
    def f():
        a = torch.zeros(4, 4, 4)
        a[:, 2:] = torch.ones(4, 2, 4)
        return a

    if not HAS_FUNCTIONALIZATION:
        return

    f2 = reinplace(make_fx(functionalize(f))())
    expected_out = f()
    actual_out = f2()
    self.assertEqual(actual_out, expected_out)
    self.assertExpectedInline(f2.code, """\
def forward(self):
    zeros = torch.ops.aten.zeros.default([4, 4, 4], device = device(type='cpu'), pin_memory = False)
    ones = torch.ops.aten.ones.default([4, 2, 4], device = device(type='cpu'), pin_memory = False)
    slice_tensor = torch.ops.aten.slice.Tensor(zeros, 0, 0, 9223372036854775807)
    slice_tensor_1 = torch.ops.aten.slice.Tensor(slice_tensor, 1, 2, 9223372036854775807); slice_tensor = None
    slice_tensor_2 = torch.ops.aten.slice.Tensor(zeros, 0, 0, 9223372036854775807)
    slice_tensor_3 = torch.ops.aten.slice.Tensor(slice_tensor_2, 1, 2, 9223372036854775807); slice_tensor_2 = None
    copy__default = torch.ops.aten.copy_.default(slice_tensor_3, ones); slice_tensor_3 = ones = None
    return zeros
    """)
def test_nvfuser_executor_partitioned(self, device):
    # This test is to ensure that the nvfuser partitioned executor works correctly
    # It's assumed that digamma is not supported by nvfuser
    # If it's ever supported, this test will need to be updated
    self.assertTrue(torch.ops.prims.digamma.default.impl_nvfuser is None)

    from torch.fx.experimental.proxy_tensor import make_fx
    from torch._prims.context import TorchRefsMode
    from torch._prims.executor import execute

    a = torch.randn(3, 4, device=device)
    b = torch.rand(3, 1, device=device)
    c = torch.rand(3, 4, device=device)

    def func(a, b, c):
        aa = torch.digamma(a)  # not supported by nvfuser
        d = torch.add(b, c)
        dd = torch.sqrt(d)
        return torch.mul(aa, dd.digamma())

    with TorchRefsMode.push():
        gm = make_fx(func)(a, b, c)

    expected = execute(gm, a, b, c, executor="aten")
    actual = execute(gm, a, b, c, executor="nvfuser")
    self.assertEqual(expected, actual)
def test_reinplace_scatter_twice_with_different_view_op_invalid(self):
    def f(a_):
        a = a_.clone()
        b = a[:, 1]
        c = b[1]
        c_updated = c.add(1)
        good_mirror_of_b = a.as_strided((4,), (4,), 1)
        # The first arg to select_scatter is an equivalent view to b.
        # However, the select_scatter call below tries to put c_updated
        # into a different slice of "b" than what "c" currently occupies.
        b_updated = torch.select_scatter(good_mirror_of_b, c_updated, 0, 0)
        return b_updated

    inpt = torch.ones(4, 4)
    f2 = reinplace(make_fx(f)(inpt), inpt)
    expected_out = f(inpt)
    actual_out = f2(inpt)
    self.assertEqual(actual_out, expected_out)
    self.assertExpectedInline(f2.code, """\
def forward(self, a__1):
    clone_default = torch.ops.aten.clone.default(a__1); a__1 = None
    slice_tensor = torch.ops.aten.slice.Tensor(clone_default, 0, 0, 9223372036854775807)
    select_int = torch.ops.aten.select.int(slice_tensor, 1, 1); slice_tensor = None
    select_int_1 = torch.ops.aten.select.int(select_int, 0, 1); select_int = None
    add_tensor = torch.ops.aten.add.Tensor(select_int_1, 1); select_int_1 = None
    as_strided_default = torch.ops.aten.as_strided.default(clone_default, [4], [4], 1); clone_default = None
    select_int_2 = torch.ops.aten.select.int(as_strided_default, 0, 0)
    copy__default = torch.ops.aten.copy_.default(select_int_2, add_tensor); select_int_2 = add_tensor = None
    return as_strided_default
    """)  # noqa: B950
def test_reinplace_scatter_twice(self):
    def f(a_):
        # for now, don't test mutations to inputs
        a = a_.clone()
        b = a[:, 1]
        c = b[1]
        c.add_(1)
        return a

    if not HAS_FUNCTIONALIZATION:
        return

    inpt = torch.ones(4, 4)
    f2 = reinplace(make_fx(functionalize(f))(inpt), inpt)
    expected_out = f(inpt)
    actual_out = f2(inpt)
    self.assertEqual(actual_out, expected_out)
    self.assertExpectedInline(f2.code, """\
def forward(self, a__1):
    clone_default = torch.ops.aten.clone.default(a__1); a__1 = None
    slice_tensor = torch.ops.aten.slice.Tensor(clone_default, 0, 0, 9223372036854775807)
    select_int = torch.ops.aten.select.int(slice_tensor, 1, 1); slice_tensor = None
    select_int_1 = torch.ops.aten.select.int(select_int, 0, 1); select_int = None
    add_tensor = torch.ops.aten.add_.Tensor(select_int_1, 1); select_int_1 = None
    slice_tensor_1 = torch.ops.aten.slice.Tensor(clone_default, 0, 0, 9223372036854775807)
    select_int_2 = torch.ops.aten.select.int(slice_tensor_1, 1, 1); slice_tensor_1 = None
    return clone_default
    """)
def test_use_fake_and_tensor(self):
    def f(x, y):
        z = torch.tensor([2.0, 3.0])
        return x + y + z

    g = make_fx(f, use_fake=True)(torch.randn(2), torch.randn(2))
    x, y = torch.randn(2), torch.randn(2)
    self.assertEqual(g(x, y), f(x, y))
def test_make_fx_overloads(self):
    def f(x):
        return x.cos() + torch.randn(x.shape)

    traced = make_fx(f)(torch.randn(3))
    self.assertTrue(all(
        isinstance(node.target, torch._ops.OpOverload)
        for node in traced.graph.nodes if node.op == 'call_function'
    ))
def test_constant_proxy_tensor_mut(self):
    from torch.fx.experimental.proxy_tensor import make_fx

    def f():
        val = torch.tensor(float(1))
        val.add_(2)
        return torch.full((100, 100), val)

    g = make_fx(f)()
    self.assertEqual(g(), f())
    # In case we mutated shared state in the g graph!
    self.assertEqual(g(), f())

    g = make_fx(f, use_fake=True)()
    self.assertEqual(g(), f())
    # In case we mutated shared state in the g graph!
    self.assertEqual(g(), f())
def check(self, f, t, delta, check_val=True, graph_input=False, P=None):
    """
    Check that the CSE-modified graph of ``f``
    1) has exactly ``delta`` fewer nodes,
    2) is not reduced further by a second pass, and
    3) reports ``modified`` as True only if the number of nodes decreased.

    Args:
        f: function to be checked
        t: tensor to be passed to f
        delta: an integer >= -1. If delta == -1, only check that the new
            graph has at most as many nodes as the old one.
        check_val: if True, check that the output of f is correct
        graph_input: True if f is a GraphModule
        P: the pass to use. If None, use P_default
    """
    if graph_input:
        fx_g = f
    else:
        fx_g = make_fx(f)(t)
    if P is None:
        P = P_default

    res = P(fx_g)
    new_g = res.graph_module
    new_graph = new_g.graph
    modified = res.modified

    # the number of nodes should decrease or stay the same
    old_num_nodes = len(fx_g.graph.nodes)
    new_num_nodes = len(new_graph.nodes)
    assert (new_num_nodes < old_num_nodes) == modified, \
        "modified should be True iff the number of nodes decreased"

    if delta == -1:
        self.assertTrue(old_num_nodes >= new_num_nodes,
                        f"number of nodes increased {old_num_nodes}, {new_num_nodes}")
    else:
        self.assertTrue(old_num_nodes == new_num_nodes + delta,
                        f"number of nodes not the same {old_num_nodes - delta}, {new_num_nodes}\n {fx_g.graph} \n {new_graph}")

    # a second pass should not remove any more nodes
    res = P(new_g)
    pass_2_graph = res.graph_module.graph
    pass_2_num_nodes = len(pass_2_graph.nodes)
    self.assertTrue(pass_2_num_nodes == new_num_nodes,
                    f"second pass graph has fewer nodes {pass_2_num_nodes}, {new_num_nodes}\n {new_graph} \n {pass_2_graph}")

    # check correctness
    if check_val:
        true_result = fx_g(t)
        our_result = new_g(t)
        if true_result is None:
            # both return None
            self.assertTrue(our_result is None,
                            f"true result is None, CSE result is {our_result}")
        else:
            # results returned are the same
            self.assertTrue(torch.all(true_result == our_result),
                            f"results are different {true_result}, {our_result}")
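# Hedged usage sketch for `check` above, assuming `P_default` is a CSE pass
# that returns a PassResult with `graph_module` and `modified` fields (e.g.,
# CSEPass from torch.fx.passes.dialect.common.cse_pass). The test body is
# illustrative, not part of the original suite.
def _example_cse_check(self):
    def f(x):
        # x.cos() appears twice; CSE should eliminate exactly one node
        return x.cos() + x.cos()

    t = torch.randn(2, 2)
    # The traced graph has two cos nodes; after CSE one is removed, so the
    # node count drops by exactly one and `modified` is True.
    self.check(f, t, delta=1)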
def test_mode_tracing_factory_function(self):
    def f(x):
        return x + torch.randn(x.shape)

    # default behavior should trace factory functions
    traced = make_fx(f)(torch.randn(3))
    self.assertTrue(
        any(node.target == torch.ops.aten.randn.default
            for node in traced.graph.nodes))
def test_make_fx(self, device):
    def f(x):
        return torch.sin(x)

    inp = torch.randn(3, device=device)
    fx_f = make_fx(f)(inp)
    new_inp = torch.randn(3, device=device)
    self.assertEqual(fx_f(new_inp), f(new_inp))
def test_mode_tracing_factory_function_no_factory_function(self):
    def f(x):
        return x + torch.randn(x.shape)

    # setting the flag to False should not trace factory functions
    traced = make_fx(f, trace_factory_functions=False)(torch.randn(3))
    self.assertFalse(
        any(node.target == torch.ops.aten.randn.default
            for node in traced.graph.nodes))
def test_constant_proxy_tensor(self):
    from torch.fx.experimental.proxy_tensor import make_fx

    def f():
        val = torch.tensor(float('inf'))
        return torch.full((100, 100), val)

    g = make_fx(f)()
    self.assertEqual(g(), f())
def test_mode_tracing_factory_function(self):
    def f(x):
        return x + torch.randn(x.shape)

    traced = make_fx(f, trace_factory_functions=True)(torch.randn(3))
    self.assertTrue(
        any(
            isinstance(node.target, torch._ops.OpOverloadPacket)
            and node.target._qualified_op_name == 'aten::randn'
            for node in traced.graph.nodes
        ))
def test_inplace_metadata(self):
    def f(x):
        x = x.clone()
        x.unsqueeze_(-1)
        assert x.shape[-1] == 1
        return x

    inps = [torch.randn(5)]
    fx_f = make_fx(f)(*inps)
    self.assertEqual(fx_f(*inps), f(*inps))
def test_mode_tracing_factory_function_default_behavior(self):
    def f(x):
        return x + torch.randn(x.shape)

    # default behavior should not trace factory functions
    traced = make_fx(f)(torch.randn(3))
    self.assertFalse(
        any(
            isinstance(node.target, torch._ops.OpOverloadPacket)
            and node.target._qualified_op_name == 'aten::randn'
            for node in traced.graph.nodes
        ))
def test_proxy_tensor(self):
    def f_grad(x):
        val = x.cos().cos().sum()
        return torch.autograd.grad(val, x)

    def f_backward(x):
        val = x.cos().cos().sum()
        val.backward()
        return x.grad

    for f in [f_grad, f_backward]:
        traced_graph = make_fx(f)(torch.randn(3, requires_grad=True))
        inp = torch.randn(3, requires_grad=True)
        traced_graph_out = traced_graph(inp)
        assert inp.grad is None
        torch.testing.assert_close(traced_graph_out, f(inp))
def _traced(*args, executor="aten", **kwargs): # TODO: caching nargs = len(args) fn_kwargs = kwargs flat_fn_kwargs = list(fn_kwargs.values()) all_args = list(args) + flat_fn_kwargs def wrapped(args): fn_args = args[:nargs] kwargs_keys = list(fn_kwargs.keys()) kwargs = dict(zip(kwargs_keys, args[nargs:])) return fn(*fn_args, **kwargs) with TorchRefsMode.push(): gm = make_fx(wrapped)(all_args) return execute(gm, all_args, executor=executor)
def test_nvprims(self, device):
    # This test is to ensure that nvfuser-specific prims are exposed
    # and can be traced with make_fx
    from torch.fx.experimental.proxy_tensor import make_fx

    def func(a):
        return torch.ops.nvprims.add(a, a)

    a = torch.randn(3, 4, device=device)
    gm = make_fx(func)(a)

    for node in gm.graph.nodes:
        if node.op == "call_function":
            self.assertTrue(node.name == "add_default")
            self.assertTrue(node.target == torch.ops.nvprims.add.default)
            self.assertFalse(node.target == torch.ops.prims.add.default)
            self.assertFalse(node.target == torch.ops.aten.add.default)
def test_reinplace_scatter_op(self):
    def f(a_):
        # for now, don't test mutations to inputs
        a = a_.clone()
        e = a.view(-1)
        b = a.view(-1)
        c = b[0]
        d = c.view(-1)
        d.add_(1)
        return a + e

    if not HAS_FUNCTIONALIZATION:
        return

    inpt = torch.ones(4)
    f2 = reinplace(make_fx(functionalize(f))(inpt), inpt)
    expected_out = f(inpt)
    actual_out = f2(inpt)
    self.assertEqual(actual_out, expected_out)
    # NOTE: one slight pessimization here is the fact that
    # there are a bunch of redundant views in the graph.
    # Technically, half of these views are duplicates that we could de-dup.
    # This shouldn't really hurt performance though, since creating an extra view
    # is effectively just moving some metadata around (and allocating a new TensorImpl).
    # We can/should update the pass in the future to clean this up.
    self.assertExpectedInline(f2.code, """\
def forward(self, a__1):
    clone_default = torch.ops.aten.clone.default(a__1); a__1 = None
    view_default = torch.ops.aten.view.default(clone_default, [-1])
    view_default_1 = torch.ops.aten.view.default(clone_default, [-1])
    select_int = torch.ops.aten.select.int(view_default_1, 0, 0); view_default_1 = None
    view_default_2 = torch.ops.aten.view.default(select_int, [-1]); select_int = None
    add_tensor = torch.ops.aten.add_.Tensor(view_default_2, 1)
    view_default_3 = torch.ops.aten.view.default(clone_default, [-1]); clone_default = None
    select_int_1 = torch.ops.aten.select.int(view_default_3, 0, 0)
    view_default_4 = torch.ops.aten.view.default(view_default_2, []); view_default_2 = None
    view_default_5 = torch.ops.aten.view.default(view_default_3, [4]); view_default_3 = None
    view_default_6 = torch.ops.aten.view.default(view_default_5, [-1])
    add_tensor_1 = torch.ops.aten.add_.Tensor(view_default_5, view_default_6); view_default_6 = None
    return view_default_5
    """)
def test_nvfuser_executor_cached_noncontiguous(self, device):
    # This test is to ensure that nvfuser computes correct results for noncontiguous tensors
    from torch.fx.experimental.proxy_tensor import make_fx
    from torch._prims.context import TorchRefsMode
    from torch._prims.executor import execute

    a = torch.randn(3, 3, device=device)

    def func(a):
        return torch.sigmoid(a)

    with TorchRefsMode.push():
        gm = make_fx(func)(a)

    # First run to create the cache
    execute(gm, a, executor="nvfuser")

    # a.mT is noncontiguous, but it shouldn't affect correctness
    expected = execute(gm, a.mT, executor="aten")
    actual = execute(gm, a.mT, executor="nvfuser")
    self.assertEqual(expected, actual)