Example #1
    def test_nvfuser_capability_context(self, device):
        # This test is to ensure that the torch calls are replaced with refs
        # based on the nvfuser+prims capability
        from torch.fx.experimental.proxy_tensor import make_fx
        from torch._prims.context import TorchRefsNvfuserCapabilityMode

        # It's assumed that digamma is not supported by nvfuser
        # If it's ever supported, this test will need to be updated
        self.assertTrue(getattr(torch.ops.nvprims, "digamma", None) is None)

        a = torch.randn(3, 3, device=device)

        def func(a):
            return torch.digamma(a)

        with TorchRefsNvfuserCapabilityMode():
            gm = make_fx(func)(a)

        # Check that the torch.digamma is not replaced with torch.ops.prims.digamma
        call_function_nodes = list(
            filter(lambda n: n.op == "call_function", gm.graph.nodes))
        includes_aten_digamma = any(
            torch.ops.aten.digamma.default == node.target
            for node in call_function_nodes)
        includes_prims_digamma = any(
            torch.ops.prims.digamma.default == node.target
            for node in call_function_nodes)
        self.assertTrue(includes_aten_digamma)
        self.assertFalse(includes_prims_digamma)

        # Check the mixed case: sigmoid is replaced with refs, but digamma is not
        def func(a):
            return torch.sigmoid(torch.digamma(a))

        with TorchRefsNvfuserCapabilityMode():
            gm = make_fx(func)(a)

        call_function_nodes = list(
            filter(lambda n: n.op == "call_function", gm.graph.nodes))
        includes_aten_sigmoid = any(
            torch.ops.aten.sigmoid.default == node.target
            for node in call_function_nodes)
        includes_prims_digamma = any(
            torch.ops.prims.digamma.default == node.target
            for node in call_function_nodes)
        includes_nvprims_exp = any(torch.ops.nvprims.exp.default == node.target
                                   for node in call_function_nodes)
        self.assertFalse(includes_aten_sigmoid)
        self.assertFalse(includes_prims_digamma)
        self.assertTrue(includes_nvprims_exp)
Example #2
    def test_decomposition_interpreter(self):
        def fn(x):
            return torch.nn.functional.silu(x)

        x = torch.rand((4, 4))
        fx_module = make_fx(fn, decomposition_table=None)(x)

        found_silu = False
        for n in fx_module.graph.nodes:
            if n.target == torch.ops.aten.silu or n.target == torch.ops.aten.silu.default:
                found_silu = True

        self.assertTrue(found_silu)

        new_graph = torch.fx.Graph()
        silu_decomp_table = {torch.ops.aten.silu.default: decomposition_table[torch.ops.aten.silu.default]}
        DecompositionInterpreter(
            fx_module,
            new_graph=new_graph,
            decomposition_table=silu_decomp_table,
        ).run(x)

        decomposed_module = torch.fx.GraphModule(fx_module, new_graph)

        for n in decomposed_module.graph.nodes:
            self.assertTrue(n.target != torch.ops.aten.silu)
            self.assertTrue(n.target != torch.ops.aten.silu.default)

        self.assertEqual(fx_module(x), decomposed_module(x))
Example #3
    def test_reinplace_scatter_twice_with_different_view_op_invalid2(self):
        def f(a_):
            a = a_.clone()
            b = a[:, 1]
            c = b[1]
            c_updated = c.add(1)
            bad_mirror_of_b = a.as_strided((4, ), (4, ), 0)
            # The first arg to select_scatter views different memory than c's base (b).
            # This makes it invalid to re-inplace.
            b_updated = torch.select_scatter(bad_mirror_of_b, c_updated, 0, 1)
            return b_updated

        inpt = torch.ones(4, 4)
        f2 = reinplace(make_fx(f)(inpt), inpt)
        expected_out = f(inpt)
        actual_out = f2(inpt)
        # self.assertEqual(actual_out, expected_out)
        self.assertExpectedInline(f2.code, """\



def forward(self, a__1):
    clone_default = torch.ops.aten.clone.default(a__1);  a__1 = None
    slice_tensor = torch.ops.aten.slice.Tensor(clone_default, 0, 0, 9223372036854775807)
    select_int = torch.ops.aten.select.int(slice_tensor, 1, 1);  slice_tensor = None
    select_int_1 = torch.ops.aten.select.int(select_int, 0, 1);  select_int = None
    add_tensor = torch.ops.aten.add.Tensor(select_int_1, 1);  select_int_1 = None
    as_strided_default = torch.ops.aten.as_strided.default(clone_default, [4], [4], 0);  clone_default = None
    select_int_2 = torch.ops.aten.select.int(as_strided_default, 0, 1)
    copy__default = torch.ops.aten.copy_.default(select_int_2, add_tensor);  select_int_2 = add_tensor = None
    return as_strided_default
    """)  # noqa: B950
Example #4
    def test_nvfuser_executor_partitioned_no_partitions_error(self, device):
        # This test ensures that the nvfuser partitioned executor warns when no partition is supported
        # It's assumed that digamma is not supported by nvfuser
        # If it's ever supported, this test will need to be updated
        self.assertTrue(torch.ops.prims.digamma.default.impl_nvfuser is None)

        from torch.fx.experimental.proxy_tensor import make_fx
        from torch._prims.context import TorchRefsMode
        from torch._prims.executor import execute

        a = torch.randn(3, 4, device=device)

        def func(a):
            return torch.digamma(a)  # not supported by nvfuser

        with TorchRefsMode.push():
            gm = make_fx(func)(a)

        with catch_warnings(record=True) as w:
            # Trigger warning
            execute(gm, a, executor="nvfuser")
            # Check warning occurs
            self.assertEqual(len(w), 1)
            self.assertTrue(
                "is not supported by nvFuser" in str(w[-1].message))
Example #5
    def _traced(*args, executor="aten", **kwargs):
        # TODO: caching
        wrapped, all_args = wrapper_and_args_for_make_fx(fn, args, kwargs)

        with NvfuserPrimsMode(), TorchRefsMode():
            gm = make_fx(wrapped)(all_args)
        return execute(gm, all_args, executor=executor)
Example #6
    def test_scalar_device(self, device):
        def f(a, b):
            return a + b

        inps = [torch.randn(3, device=device), torch.tensor(5)]
        fx_f = make_fx(f)(*inps)
        self.assertEqual(fx_f(*inps), f(*inps))
Example #7
    def test_aten_where(self, device, dtype):
        def fn(x):
            where = torch.ops.aten.where(x < 0, -x, x)
            a = where + 1
            b = a + 1
            return b

        inputs = torch.randn(4, device=device)
        traced = make_fx(fn)(inputs)

        nvfuser = NvFuserBackend()
        compiled_module = nvfuser.compile(copy.deepcopy(traced))

        for node in compiled_module.graph.nodes:
            if node.op == "call_function":
                assert "fused" in str(
                    node.target
                ), "the entire function should be fused into a single fusion group"

        eager_result = traced(inputs)
        nvfuser_result = compiled_module(inputs)
        torch.testing.assert_close(eager_result,
                                   nvfuser_result,
                                   rtol=1e-5,
                                   atol=1e-5)
Example #8
    def test_reinplace_with_view(self):
        def f(x):
            a = x.clone()
            a_view = a.view(-1)
            # We shouldn't re-inplace the first add(), because an alias of a is re-used later in the program
            b = a.add(1)
            # Second add() is fine to re-inplace
            c = a_view.add(1)
            return c

        inpt = torch.ones(2)
        f2 = reinplace(make_fx(f)(inpt), inpt)
        expected_out = f(inpt)
        actual_out = f2(inpt)
        self.assertEqual(actual_out, expected_out)
        self.assertExpectedInline(
            f2.code, """\



def forward(self, x_1):
    clone_default = torch.ops.aten.clone.default(x_1);  x_1 = None
    view_default = torch.ops.aten.view.default(clone_default, [-1])
    add_tensor = torch.ops.aten.add.Tensor(clone_default, 1);  clone_default = None
    add_tensor_1 = torch.ops.aten.add_.Tensor(view_default, 1)
    return view_default
    """)
Example #9
    def test_reinplace_different_metadata(self):
        def f(a_):
            a = a_.clone()
            b = a + 1
            # Naively, we shouldn't try to inplace the .ge() call,
            # because that would require changing "b"'s metadata (from a float tensor to a bool tensor).
            c = torch.ge(b, a)
            return c

        inpt = torch.ones(4)
        f2 = reinplace(make_fx(f)(inpt), inpt)
        expected_out = f(inpt)
        actual_out = f2(inpt)
        self.assertEqual(actual_out, expected_out)
        # The .ge() should not be reinplaced.
        self.assertExpectedInline(
            f2.code, """\



def forward(self, a__1):
    clone_default = torch.ops.aten.clone.default(a__1);  a__1 = None
    add_tensor = torch.ops.aten.add.Tensor(clone_default, 1)
    ge_tensor = torch.ops.aten.ge.Tensor(add_tensor, clone_default);  add_tensor = clone_default = None
    return ge_tensor
    """)
Example #10
    def test_out_node_updated(self):
        def f():
            x = torch.zeros(2, 2)
            y = x.diagonal()
            y_updated = y.add(1)
            z = torch.diagonal_scatter(x, y_updated)
            # reinplace needs to know to replace output [z] with [x]
            return [z]

        if not HAS_FUNCTIONALIZATION:
            return
        f2 = reinplace(make_fx(functionalize(f))())
        expected_out = f()
        actual_out = f2()
        self.assertEqual(actual_out, expected_out)
        self.assertExpectedInline(
            f2.code, """\



def forward(self):
    zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False)
    diagonal_default = torch.ops.aten.diagonal.default(zeros)
    add_tensor = torch.ops.aten.add_.Tensor(diagonal_default, 1);  diagonal_default = None
    return [zeros]
    """)
Example #11
    def test_reinplace_index_mutation(self):
        def f():
            a = torch.zeros(4, 4, 4)
            a[:, 2:] = torch.ones(4, 2, 4)
            return a

        if not HAS_FUNCTIONALIZATION:
            return
        f2 = reinplace(make_fx(functionalize(f))())
        expected_out = f()
        actual_out = f2()
        self.assertEqual(actual_out, expected_out)
        self.assertExpectedInline(
            f2.code, """\



def forward(self):
    zeros = torch.ops.aten.zeros.default([4, 4, 4], device = device(type='cpu'), pin_memory = False)
    ones = torch.ops.aten.ones.default([4, 2, 4], device = device(type='cpu'), pin_memory = False)
    slice_tensor = torch.ops.aten.slice.Tensor(zeros, 0, 0, 9223372036854775807)
    slice_tensor_1 = torch.ops.aten.slice.Tensor(slice_tensor, 1, 2, 9223372036854775807);  slice_tensor = None
    slice_tensor_2 = torch.ops.aten.slice.Tensor(zeros, 0, 0, 9223372036854775807)
    slice_tensor_3 = torch.ops.aten.slice.Tensor(slice_tensor_2, 1, 2, 9223372036854775807);  slice_tensor_2 = None
    copy__default = torch.ops.aten.copy_.default(slice_tensor_3, ones);  slice_tensor_3 = ones = None
    return zeros
    """)
Example #12
    def test_nvfuser_executor_partitioned(self, device):
        # This test is to ensure that nvfuser partitioned executor works correctly
        # It's assumed that digamma is not supported by nvfuser
        # If it's ever supported, this test will need to be updated
        self.assertTrue(torch.ops.prims.digamma.default.impl_nvfuser is None)

        from torch.fx.experimental.proxy_tensor import make_fx
        from torch._prims.context import TorchRefsMode
        from torch._prims.executor import execute

        a = torch.randn(3, 4, device=device)
        b = torch.rand(3, 1, device=device)
        c = torch.rand(3, 4, device=device)

        def func(a, b, c):
            aa = torch.digamma(a)  # not supported by nvfuser
            d = torch.add(b, c)
            dd = torch.sqrt(d)
            return torch.mul(aa, dd.digamma())

        with TorchRefsMode.push():
            gm = make_fx(func)(a, b, c)

        expected = execute(gm, a, b, c, executor="aten")
        actual = execute(gm, a, b, c, executor="nvfuser")
        self.assertEqual(expected, actual)
Example #13
    def test_reinplace_scatter_twice_with_different_view_op_invalid(self):
        def f(a_):
            a = a_.clone()
            b = a[:, 1]
            c = b[1]
            c_updated = c.add(1)
            good_mirror_of_b = a.as_strided((4, ), (4, ), 1)
            # The first arg to select_scatter is an equivalent view to b.
            # However, the select_scatter call below tries to put c_updated
            # into a different slice of "b" than what "c" currently occupies.
            #
            b_updated = torch.select_scatter(good_mirror_of_b, c_updated, 0, 0)
            return b_updated

        inpt = torch.ones(4, 4)
        f2 = reinplace(make_fx(f)(inpt), inpt)
        expected_out = f(inpt)
        actual_out = f2(inpt)
        self.assertEqual(actual_out, expected_out)
        self.assertExpectedInline(f2.code, """\



def forward(self, a__1):
    clone_default = torch.ops.aten.clone.default(a__1);  a__1 = None
    slice_tensor = torch.ops.aten.slice.Tensor(clone_default, 0, 0, 9223372036854775807)
    select_int = torch.ops.aten.select.int(slice_tensor, 1, 1);  slice_tensor = None
    select_int_1 = torch.ops.aten.select.int(select_int, 0, 1);  select_int = None
    add_tensor = torch.ops.aten.add.Tensor(select_int_1, 1);  select_int_1 = None
    as_strided_default = torch.ops.aten.as_strided.default(clone_default, [4], [4], 1);  clone_default = None
    select_int_2 = torch.ops.aten.select.int(as_strided_default, 0, 0)
    copy__default = torch.ops.aten.copy_.default(select_int_2, add_tensor);  select_int_2 = add_tensor = None
    return as_strided_default
    """)  # noqa: B950
Example #14
    def test_reinplace_scatter_twice(self):
        def f(a_):
            # for now, don't test mutations to inputs
            a = a_.clone()
            b = a[:, 1]
            c = b[1]
            c.add_(1)
            return a

        if not HAS_FUNCTIONALIZATION:
            return

        inpt = torch.ones(4, 4)
        f2 = reinplace(make_fx(functionalize(f))(inpt), inpt)
        expected_out = f(inpt)
        actual_out = f2(inpt)
        self.assertEqual(actual_out, expected_out)
        self.assertExpectedInline(
            f2.code, """\



def forward(self, a__1):
    clone_default = torch.ops.aten.clone.default(a__1);  a__1 = None
    slice_tensor = torch.ops.aten.slice.Tensor(clone_default, 0, 0, 9223372036854775807)
    select_int = torch.ops.aten.select.int(slice_tensor, 1, 1);  slice_tensor = None
    select_int_1 = torch.ops.aten.select.int(select_int, 0, 1);  select_int = None
    add_tensor = torch.ops.aten.add_.Tensor(select_int_1, 1);  select_int_1 = None
    slice_tensor_1 = torch.ops.aten.slice.Tensor(clone_default, 0, 0, 9223372036854775807)
    select_int_2 = torch.ops.aten.select.int(slice_tensor_1, 1, 1);  slice_tensor_1 = None
    return clone_default
    """)
Example #15
    def test_use_fake_and_tensor(self):
        def f(x, y):
            z = torch.tensor([2.0, 3.0])
            return x + y + z

        g = make_fx(f, use_fake=True)(torch.randn(2), torch.randn(2))
        x, y = torch.randn(2), torch.randn(2)
        self.assertEqual(g(x, y), f(x, y))
Example #16
    def test_make_fx_overloads(self):
        def f(x):
            return x.cos() + torch.randn(x.shape)

        traced = make_fx(f)(torch.randn(3))

        self.assertTrue(all([isinstance(node.target, torch._ops.OpOverload)
                             for node in traced.graph.nodes if node.op == 'call_function']))
Example #17
    def test_constant_proxy_tensor_mut(self):
        from torch.fx.experimental.proxy_tensor import make_fx

        def f():
            val = torch.tensor(float(1))
            val.add_(2)
            return torch.full((100, 100), val)

        g = make_fx(f)()
        self.assertEqual(g(), f())
        # In case we mutated shared state in the g graph!
        self.assertEqual(g(), f())

        g = make_fx(f, use_fake=True)()
        self.assertEqual(g(), f())
        # In case we mutated shared state in the g graph!
        self.assertEqual(g(), f())
Example #18
def check(self, f, t, delta, check_val=True, graph_input=False, P=None):
    """
    Check that the CSE-modified graph of ``f``
    1) has ``delta`` fewer nodes,
    2) is not reduced further by a second pass, and
    3) reports ``modified`` as True exactly when the number of nodes decreases.

    Args:
        f: function to be checked
        t: tensor to be passed to f
        delta: an integer >= -1.
               If delta = -1, only check that the new graph has no more nodes than the original
        check_val: if True, check if the output of f is correct
        graph_input: True if f is already a GraphModule
        P: the pass to use. If None, use P_default
    """
    if graph_input:
        fx_g = f
    else:
        fx_g = make_fx(f)(t)

    if P is None:
        P = P_default

    res = P(fx_g)
    new_g = res.graph_module
    new_graph = new_g.graph
    modified = res.modified

    # the number of nodes should decrease or stay the same
    old_num_nodes = len(fx_g.graph.nodes)
    new_num_nodes = len(new_graph.nodes)

    assert (new_num_nodes < old_num_nodes) == modified, "modified should be True exactly when the number of nodes decreases"

    if delta == -1:
        self.assertTrue(old_num_nodes >= new_num_nodes, (
            f"number of nodes increased {old_num_nodes}, {new_num_nodes}"))
    else:
        self.assertTrue(old_num_nodes == new_num_nodes + delta, (
            f"number of nodes not the same {old_num_nodes - delta}, {new_num_nodes}\n {fx_g.graph} \n {new_graph}"))

    # a second pass should not reduce more nodes
    res = P(new_g)
    pass_2_graph = res.graph_module.graph
    pass_2_num_nodes = len(pass_2_graph.nodes)
    self.assertTrue(pass_2_num_nodes == new_num_nodes, (
        f"second pass graph has less node {pass_2_num_nodes}, {new_num_nodes}\n {new_graph} \n {pass_2_graph}"))

    # check correctness
    if check_val:
        true_result = fx_g(t)
        our_result = new_g(t)
        if true_result is None:  # both return None
            self.assertTrue(our_result is None, f"true result is None, CSE result is {our_result}")
        else:  # results returned are the same
            self.assertTrue(torch.all(true_result == our_result), (
                f"results are different {true_result}, {our_result}"))  # check results are the same
Example #19
    def test_mode_tracing_factory_function(self):
        def f(x):
            return x + torch.randn(x.shape)

        # default behavior should trace factory functions
        traced = make_fx(f)(torch.randn(3))
        self.assertTrue(
            any(node.target == torch.ops.aten.randn.default
                for node in traced.graph.nodes))
Example #20
    def test_make_fx(self, device):
        def f(x):
            return torch.sin(x)

        inp = torch.randn(3)
        fx_f = make_fx(f)(inp)

        new_inp = torch.randn(3)
        self.assertEqual(fx_f(new_inp), f(new_inp))
Example #21
    def test_mode_tracing_factory_function_no_factory_function(self):
        def f(x):
            return x + torch.randn(x.shape)

        # setting the flag to false should not trace factory functions
        traced = make_fx(f, trace_factory_functions=False)(torch.randn(3))
        self.assertFalse(
            any(node.target == torch.ops.aten.randn.default
                for node in traced.graph.nodes))
Example #22
    def test_constant_proxy_tensor(self):
        from torch.fx.experimental.proxy_tensor import make_fx

        def f():
            val = torch.tensor(float('inf'))
            return torch.full((100, 100), val)

        g = make_fx(f)()
        self.assertEqual(g(), f())
Example #23
    def test_mode_tracing_factory_function(self):
        def f(x):
            return x + torch.randn(x.shape)

        traced = make_fx(f, trace_factory_functions=True)(torch.randn(3))
        self.assertTrue(
            any(
                isinstance(node.target, torch._ops.OpOverloadPacket)
                and node.target._qualified_op_name == 'aten::randn'
                for node in traced.graph.nodes))
Example #24
    def test_inplace_metadata(self):
        def f(x):
            x = x.clone()
            x.unsqueeze_(-1)
            assert x.shape[-1] == 1
            return x

        inps = [torch.randn(5)]
        fx_f = make_fx(f)(*inps)
        self.assertEqual(fx_f(*inps), f(*inps))
Example #25
    def test_mode_tracing_factory_function_default_behavior(self):
        def f(x):
            return x + torch.randn(x.shape)

        # default behavior should not trace factory functions
        traced = make_fx(f)(torch.randn(3))
        self.assertFalse(
            any(
                isinstance(node.target, torch._ops.OpOverloadPacket)
                and node.target._qualified_op_name == 'aten::randn'
                for node in traced.graph.nodes))
Example #26
    def test_proxy_tensor(self):
        def f_grad(x):
            val = x.cos().cos().sum()
            return torch.autograd.grad(val, x)

        def f_backward(x):
            val = x.cos().cos().sum()
            val.backward()
            return x.grad

        for f in [f_grad, f_backward]:
            traced_graph = make_fx(f)(torch.randn(3, requires_grad=True))
            inp = torch.randn(3, requires_grad=True)
            traced_graph_out = traced_graph(inp)
            assert inp.grad is None
            torch.testing.assert_close(traced_graph_out, f(inp))
Example #27
    def _traced(*args, executor="aten", **kwargs):
        # TODO: caching
        nargs = len(args)
        fn_kwargs = kwargs
        flat_fn_kwargs = list(fn_kwargs.values())
        all_args = list(args) + flat_fn_kwargs

        def wrapped(args):
            fn_args = args[:nargs]
            kwargs_keys = list(fn_kwargs.keys())
            kwargs = dict(zip(kwargs_keys, args[nargs:]))
            return fn(*fn_args, **kwargs)

        with TorchRefsMode.push():
            gm = make_fx(wrapped)(all_args)
        return execute(gm, all_args, executor=executor)
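For context, this closure presumably sits inside a factory that captures ``fn`` and returns ``_traced``; a minimal sketch of such a wrapper, reusing only imports that appear elsewhere in these examples. The factory name ``make_traced`` and the usage comment are assumptions, not confirmed by the snippet:

    # Hedged sketch of a plausible enclosing factory for `_traced`; only the
    # body of `_traced` comes from the example above, the rest is assumed.
    from torch.fx.experimental.proxy_tensor import make_fx
    from torch._prims.context import TorchRefsMode
    from torch._prims.executor import execute

    def make_traced(fn):
        def _traced(*args, executor="aten", **kwargs):
            nargs = len(args)
            all_args = list(args) + list(kwargs.values())

            def wrapped(flat_args):
                # Split the flat argument list back into positional args and kwargs
                fn_args = flat_args[:nargs]
                inner_kwargs = dict(zip(kwargs.keys(), flat_args[nargs:]))
                return fn(*fn_args, **inner_kwargs)

            # Trace `fn` through the torch refs, then run the graph on the chosen executor
            with TorchRefsMode.push():
                gm = make_fx(wrapped)(all_args)
            return execute(gm, all_args, executor=executor)

        return _traced

    # Hypothetical usage:
    # add = make_traced(lambda a, b: torch.add(a, b))
    # out = add(torch.randn(3), torch.randn(3), executor="aten")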
Example #28
    def test_nvprims(self, device):
        # This test is to ensure that nvfuser specific prims are exposed
        # and can be traced with make_fx
        from torch.fx.experimental.proxy_tensor import make_fx

        def func(a):
            return torch.ops.nvprims.add(a, a)

        a = torch.randn(3, 4, device=device)
        gm = make_fx(func)(a)

        for node in gm.graph.nodes:
            if node.op == "call_function":
                self.assertTrue(node.name == "add_default")
                self.assertTrue(node.target == torch.ops.nvprims.add.default)
                self.assertFalse(node.target == torch.ops.prims.add.default)
                self.assertFalse(node.target == torch.ops.aten.add.default)
Example #29
    def test_reinplace_scatter_op(self):
        def f(a_):
            # for now, don't test mutations to inputs
            a = a_.clone()
            e = a.view(-1)
            b = a.view(-1)
            c = b[0]
            d = c.view(-1)
            d.add_(1)
            return a + e

        if not HAS_FUNCTIONALIZATION:
            return
        inpt = torch.ones(4)
        f2 = reinplace(make_fx(functionalize(f))(inpt), inpt)
        expected_out = f(inpt)
        actual_out = f2(inpt)
        self.assertEqual(actual_out, expected_out)
        # NOTE: one slight pessimization here is the fact that
        # there are a bunch of redundant views in the graph.
        # Technically, half of these views are duplicates that we could de-dup.
        # This shouldn't really hurt performance though, since creating an extra view
        # is effectively just moving some metadata around (and allocating a new TensorImpl).
        # We can/should update the pass in the future to clean this up.
        self.assertExpectedInline(
            f2.code, """\



def forward(self, a__1):
    clone_default = torch.ops.aten.clone.default(a__1);  a__1 = None
    view_default = torch.ops.aten.view.default(clone_default, [-1])
    view_default_1 = torch.ops.aten.view.default(clone_default, [-1])
    select_int = torch.ops.aten.select.int(view_default_1, 0, 0);  view_default_1 = None
    view_default_2 = torch.ops.aten.view.default(select_int, [-1]);  select_int = None
    add_tensor = torch.ops.aten.add_.Tensor(view_default_2, 1)
    view_default_3 = torch.ops.aten.view.default(clone_default, [-1]);  clone_default = None
    select_int_1 = torch.ops.aten.select.int(view_default_3, 0, 0)
    view_default_4 = torch.ops.aten.view.default(view_default_2, []);  view_default_2 = None
    view_default_5 = torch.ops.aten.view.default(view_default_3, [4]);  view_default_3 = None
    view_default_6 = torch.ops.aten.view.default(view_default_5, [-1])
    add_tensor_1 = torch.ops.aten.add_.Tensor(view_default_5, view_default_6);  view_default_6 = None
    return view_default_5
    """)
Example #30
    def test_nvfuser_executor_cached_noncontiguous(self, device):
        # This test is to ensure that nvfuser computes correct results for noncontiguous tensors
        from torch.fx.experimental.proxy_tensor import make_fx
        from torch._prims.context import TorchRefsMode
        from torch._prims.executor import execute

        a = torch.randn(3, 3, device=device)

        def func(a):
            return torch.sigmoid(a)

        with TorchRefsMode.push():
            gm = make_fx(func)(a)

        # First run to create the cache
        execute(gm, a, executor="nvfuser")

        # a.mT is noncontiguous, but it shouldn't affect correctness
        expected = execute(gm, a.mT, executor="aten")
        actual = execute(gm, a.mT, executor="nvfuser")
        self.assertEqual(expected, actual)