def test_nvfuser_executor_partitioned(self, device):
    # This test is to ensure that nvfuser partitioned executor works correctly
    # It's assumed that digamma is not supported by nvfuser
    # If it's ever supported, this test will need to be updated
    self.assertTrue(torch.ops.prims.digamma.default.impl_nvfuser is None)

    from torch.fx.experimental.proxy_tensor import make_fx
    from torch._prims.context import TorchRefsMode
    from torch._prims.executor import execute

    a = torch.randn(3, 4, device=device)
    b = torch.rand(3, 1, device=device)
    c = torch.rand(3, 4, device=device)

    def func(a, b, c):
        aa = torch.digamma(a)  # not supported by nvfuser
        d = torch.add(b, c)
        dd = torch.sqrt(d)
        return torch.mul(aa, dd.digamma())

    with TorchRefsMode.push():
        gm = make_fx(func)(a, b, c)

    expected = execute(gm, a, b, c, executor="aten")
    actual = execute(gm, a, b, c, executor="nvfuser")
    self.assertEqual(expected, actual)
def test_nvfuser_executor_partitioned_no_partitions_error(self, device):
    # This test is to ensure that the nvfuser partitioned executor warns
    # when a graph contains no ops that nvfuser can execute
    # It's assumed that digamma is not supported by nvfuser
    # If it's ever supported, this test will need to be updated
    self.assertTrue(torch.ops.prims.digamma.default.impl_nvfuser is None)

    from warnings import catch_warnings
    from torch.fx.experimental.proxy_tensor import make_fx
    from torch._prims.context import TorchRefsMode
    from torch._prims.executor import execute

    a = torch.randn(3, 4, device=device)

    def func(a):
        return torch.digamma(a)  # not supported by nvfuser

    with TorchRefsMode.push():
        gm = make_fx(func)(a)

    with catch_warnings(record=True) as w:
        # Trigger warning
        execute(gm, a, executor="nvfuser")
        # Check warning occurs
        self.assertEqual(len(w), 1)
        self.assertTrue("is not supported by nvFuser" in str(w[-1].message))
def test_nvfuser_executor_cached_noncontiguous(self, device):
    # This test is to ensure that nvfuser computes correct results for noncontiguous tensors
    from torch.fx.experimental.proxy_tensor import make_fx
    from torch._prims.context import TorchRefsMode
    from torch._prims.executor import execute

    a = torch.randn(3, 3, device=device)

    def func(a):
        return torch.sigmoid(a)

    with TorchRefsMode.push():
        gm = make_fx(func)(a)

    # First run to create the cache
    execute(gm, a, executor="nvfuser")

    # a.mT is noncontiguous, but it shouldn't affect correctness
    expected = execute(gm, a.mT, executor="aten")
    actual = execute(gm, a.mT, executor="nvfuser")
    self.assertEqual(expected, actual)
def lower_to_prims_and_execute(self, graph_module: GraphModule, *args, **kwargs):
    # `graph_module` is an ATen-level FX graph
    # "lowering to prims" and "trace execution" are grouped into this function,
    # as they are both input-dependent
    if graph_module in self.prim_decomp_cache:
        logging.debug("prim_decomp_cache hit!")
        prim_module = self.prim_decomp_cache[graph_module]
    else:
        prim_graph = torch.fx.Graph()
        DecompositionInterpreter(
            graph_module, prim_graph, decomposition_table=aten2prim_decomp
        ).run(*args, **kwargs)
        prim_module = torch.fx.GraphModule(graph_module, prim_graph)
        self.prim_decomp_cache[graph_module] = prim_module

        logging.debug("Lower to prims graph: %s", prim_module.code)

    # invokes trace executor for running the prim graph
    return execute(prim_module, *args, executor="nvfuser")
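
# --- Hedged sketch of the class scaffolding that lower_to_prims_and_execute assumes ---
# Only the method above is from the original; the class name (NvFuserBackend), the
# empty aten2prim_decomp table, and this __init__ body are illustrative assumptions,
# shown so the cache attribute and free names used by the method resolve to something.
import logging

import torch
import torch.fx
from torch.fx import GraphModule
from torch.fx.experimental.proxy_tensor import DecompositionInterpreter
from torch._prims.executor import execute

# Assumed: a decomposition table mapping ATen ops to prims; in practice it would be
# populated with real decompositions rather than left empty.
aten2prim_decomp = {}


class NvFuserBackend:  # hypothetical name for the class owning the method above
    def __init__(self):
        # Cache of ATen GraphModule -> prims GraphModule, consulted by
        # lower_to_prims_and_execute to avoid re-running the decomposition
        # for inputs that have already been lowered.
        self.prim_decomp_cache = {}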