Example #1
    def test_nvfuser_executor_partitioned(self, device):
        # Tests that the nvfuser partitioned executor produces correct results
        # when the graph mixes nvfuser-supported ops with unsupported ones.
        # It's assumed that digamma is not supported by nvfuser;
        # if it's ever supported, this test will need to be updated.
        self.assertTrue(torch.ops.prims.digamma.default.impl_nvfuser is None)

        from torch.fx.experimental.proxy_tensor import make_fx
        from torch._prims.context import TorchRefsMode
        from torch._prims.executor import execute

        a = torch.randn(3, 4, device=device)
        b = torch.rand(3, 1, device=device)
        c = torch.rand(3, 4, device=device)

        def func(a, b, c):
            aa = torch.digamma(a)  # not supported by nvfuser
            d = torch.add(b, c)
            dd = torch.sqrt(d)
            return torch.mul(aa, dd.digamma())

        with TorchRefsMode.push():
            gm = make_fx(func)(a, b, c)

        expected = execute(gm, a, b, c, executor="aten")
        actual = execute(gm, a, b, c, executor="nvfuser")
        self.assertEqual(expected, actual)
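
Note: each example here is a method of a device-parametrized test class, so `torch` and the test harness are assumed to be imported at module level. A minimal sketch of that harness, assuming the usual `torch.testing` device-type machinery (the class name and registration below are illustrative, not from the source):

    import torch
    from torch.testing._internal.common_utils import TestCase, run_tests
    from torch.testing._internal.common_device_type import instantiate_device_type_tests

    class TestNvFuserExecutor(TestCase):
        def test_nvfuser_executor_partitioned(self, device):
            ...  # body as in Example #1

    # Generates per-device variants such as
    # test_nvfuser_executor_partitioned_cuda; nvfuser requires a CUDA
    # build, hence only_for="cuda".
    instantiate_device_type_tests(TestNvFuserExecutor, globals(), only_for="cuda")

    if __name__ == "__main__":
        run_tests()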
Example #2
    def test_nvfuser_executor_partitioned_no_partitions_error(self, device):
        # Tests that the nvfuser partitioned executor warns and falls back
        # when no part of the graph is supported by nvfuser.
        # It's assumed that digamma is not supported by nvfuser;
        # if it's ever supported, this test will need to be updated.
        self.assertTrue(torch.ops.prims.digamma.default.impl_nvfuser is None)

        from warnings import catch_warnings

        from torch.fx.experimental.proxy_tensor import make_fx
        from torch._prims.context import TorchRefsMode
        from torch._prims.executor import execute

        a = torch.randn(3, 4, device=device)

        def func(a):
            return torch.digamma(a)  # not supported by nvfuser

        with TorchRefsMode.push():
            gm = make_fx(func)(a)

        with catch_warnings(record=True) as w:
            # Trigger warning
            execute(gm, a, executor="nvfuser")
            # Check warning occurs
            self.assertEqual(len(w), 1)
            self.assertTrue(
                "is not supported by nvFuser" in str(w[-1].message))
Example #3
    def test_nvfuser_executor_cached_noncontiguous(self, device):
        # Tests that a cached nvfuser fusion, first compiled for a contiguous
        # input, still computes correct results for a noncontiguous tensor.
        from torch.fx.experimental.proxy_tensor import make_fx
        from torch._prims.context import TorchRefsMode
        from torch._prims.executor import execute

        a = torch.randn(3, 3, device=device)

        def func(a):
            return torch.sigmoid(a)

        with TorchRefsMode.push():
            gm = make_fx(func)(a)

        # First run to create the cache
        execute(gm, a, executor="nvfuser")

        # a.mT is noncontiguous, but it shouldn't affect correctness
        expected = execute(gm, a.mT, executor="aten")
        actual = execute(gm, a.mT, executor="nvfuser")
        self.assertEqual(expected, actual)
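
The point of re-running with `a.mT` is that the first call populates the fusion cache with a contiguous layout, while the transpose shares storage but changes strides. A quick plain-torch illustration of that layout difference:

    import torch

    a = torch.randn(3, 3)
    # .mT swaps the last two dimensions without copying: storage is shared,
    # the strides are reversed, and the result reports as noncontiguous.
    assert a.is_contiguous()
    assert not a.mT.is_contiguous()
    print(a.stride(), a.mT.stride())  # (3, 1) vs. (1, 3)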
Example #4
    def lower_to_prims_and_execute(self, graph_module: GraphModule, *args,
                                   **kwargs):
        # `graph_module` is an ATen FX graph.
        # "Lowering to prims" and "trace execution" are grouped into this
        # function because both are input-dependent.

        if graph_module in self.prim_decomp_cache:
            logging.debug("prim_decomp_cache hit!")
            prim_module = self.prim_decomp_cache[graph_module]
        else:
            prim_graph = torch.fx.Graph()
            DecompositionInterpreter(graph_module,
                                     prim_graph,
                                     decomposition_table=aten2prim_decomp).run(
                                         *args, **kwargs)
            prim_module = torch.fx.GraphModule(graph_module, prim_graph)
            self.prim_decomp_cache[graph_module] = prim_module

            logging.debug("Lower to prims graph: ", prim_module.code)

        # Invoke the trace executor to run the prim graph.
        return execute(prim_module, *args, executor="nvfuser")
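
This method reads like part of a backend class that owns `prim_decomp_cache`, with `aten2prim_decomp` and `execute` coming from the surrounding module. A minimal sketch of that scaffolding, with the cache field and decomposition table treated as assumptions inferred from the names used above:

    import logging
    from typing import Dict

    import torch
    from torch.fx import GraphModule
    from torch.fx.experimental.proxy_tensor import DecompositionInterpreter
    from torch._prims.executor import execute

    # Assumed: an ATen->prims decomposition table defined by the surrounding
    # module in the original source (its contents are not shown there).
    aten2prim_decomp: Dict = {}

    class NvFuserBackendSketch:
        def __init__(self):
            # Naive cache keyed by the ATen graph module; there are no
            # shape/dtype guards, so a hit is only valid when the module is
            # re-run with inputs like those it was traced with.
            self.prim_decomp_cache: Dict[GraphModule, GraphModule] = {}

        # lower_to_prims_and_execute from Example #4 would be defined here.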