def test_nvfuser_executor_partitioned_no_partitions_error(self, device):
    """Check the nvFuser executor warns when a graph contains nothing it can run.

    The traced graph contains only ``digamma``, which is assumed to have no
    nvFuser implementation. Executing it with ``executor="nvfuser"`` must
    emit exactly one warning mentioning "is not supported by nvFuser".
    """
    # This test is to ensure that nvfuser partitioned executor works correctly
    # It's assumed that digamma is not supported by nvfuser
    # If it's ever supported, this test will need to be updated
    self.assertTrue(torch.ops.prims.digamma.default.impl_nvfuser is None)

    import warnings
    from torch.fx.experimental.proxy_tensor import make_fx
    from torch._prims.context import TorchRefsMode
    from torch._prims.executor import execute

    a = torch.randn(3, 4, device=device)

    def func(a):
        return torch.digamma(a)  # not supported by nvfuser

    with TorchRefsMode.push():
        gm = make_fx(func)(a)

    with warnings.catch_warnings(record=True) as caught:
        # Record every warning regardless of the surrounding filter state:
        # an inherited "ignore"/"once" filter would hide the warning and an
        # "error" filter would raise it, making the test depend on how the
        # suite was launched.
        warnings.simplefilter("always")
        # Trigger warning
        execute(gm, a, executor="nvfuser")

    # Count only the nvFuser fallback warning, so unrelated warnings that the
    # "always" filter may surface do not make the check fail.
    nvfuser_warnings = [
        entry for entry in caught
        if "is not supported by nvFuser" in str(entry.message)
    ]
    # Check warning occurs (exactly once)
    self.assertEqual(len(nvfuser_warnings), 1)
def test_nvfuser_executor_partitioned(self, device):
    """Check the nvFuser executor agrees with aten on a graph with mixed op support.

    ``digamma`` is assumed to have no nvFuser implementation, while add/sqrt/mul
    presumably do. The graph therefore mixes ops the nvFuser (partitioned)
    executor can and cannot handle. Its result must equal a pure aten run.
    """
    # If digamma ever gains an nvFuser implementation, this test needs updating.
    self.assertTrue(torch.ops.prims.digamma.default.impl_nvfuser is None)

    from torch.fx.experimental.proxy_tensor import make_fx
    from torch._prims.context import TorchRefsMode
    from torch._prims.executor import execute

    # Inputs are drawn in a fixed order (randn, rand, rand) so the RNG stream
    # is consumed identically on every run.
    a = torch.randn(3, 4, device=device)
    b = torch.rand(3, 1, device=device)  # broadcasts against c in the add
    c = torch.rand(3, 4, device=device)

    def func(a, b, c):
        digamma_a = torch.digamma(a)  # not supported by nvfuser
        root = torch.sqrt(torch.add(b, c))
        return torch.mul(digamma_a, root.digamma())

    with TorchRefsMode.push():
        gm = make_fx(func)(a, b, c)

    # Run the same graph on the reference backend first, then on nvFuser.
    expected, actual = (
        execute(gm, a, b, c, executor=backend) for backend in ("aten", "nvfuser")
    )
    self.assertEqual(expected, actual)
def _traced(*args, executor="aten", **kwargs):
    """Trace ``fn`` with nvFuser prims and torch refs modes active, then run it.

    ``fn`` is not a parameter. It is resolved from an enclosing scope not
    visible here (presumably the function a decorator/factory is wrapping).

    Args:
        *args: positional inputs for ``fn``.
        executor: backend name forwarded to ``execute`` (default ``"aten"``).
        **kwargs: keyword inputs for ``fn``.

    Returns:
        Whatever ``execute`` returns for the traced graph.
    """
    # TODO: caching -- the graph is re-traced on every call.
    # The helper yields a one-argument wrapper plus the flat argument list it
    # takes; the same list is used for tracing and for execution.
    wrapped, flat_args = wrapper_and_args_for_make_fx(fn, args, kwargs)
    with NvfuserPrimsMode():
        with TorchRefsMode():
            graph_module = make_fx(wrapped)(flat_args)
    return execute(graph_module, flat_args, executor=executor)
def _traced(*args, executor="aten", **kwargs):
    """Trace ``fn`` under ``TorchRefsMode`` on a flattened argument list and run it.

    Positional and keyword inputs are packed into one flat list (positionals
    first, then keyword values in insertion order). ``fn`` is traced through a
    wrapper that unpacks that list again. ``fn`` itself is resolved from an
    enclosing scope not visible here.

    Args:
        *args: positional inputs for ``fn``.
        executor: backend name forwarded to ``execute`` (default ``"aten"``).
        **kwargs: keyword inputs for ``fn``.

    Returns:
        Whatever ``execute`` returns for the traced graph.
    """
    # TODO: caching -- the graph is re-traced on every call.
    positional_count = len(args)
    # ``kwargs`` is a fresh dict owned by this call, so its key order is fixed.
    kwarg_names = list(kwargs.keys())
    flat_inputs = [*args, *kwargs.values()]

    def wrapped(args):
        # Undo the flattening: leading entries are positional, the rest are
        # keyword values paired back up with their names.
        keyword_values = dict(zip(kwarg_names, args[positional_count:]))
        return fn(*args[:positional_count], **keyword_values)

    with TorchRefsMode.push():
        gm = make_fx(wrapped)(flat_inputs)
    return execute(gm, flat_inputs, executor=executor)
def test_nvfuser_executor_cached_noncontiguous(self, device):
    """Check nvFuser stays correct when a cached graph is fed a non-contiguous input.

    The graph is first run once on a contiguous tensor (to populate the
    executor's cache). It is then run on the transposed view ``a.mT`` with
    both backends, and the results must match.
    """
    from torch.fx.experimental.proxy_tensor import make_fx
    from torch._prims.context import TorchRefsMode
    from torch._prims.executor import execute

    a = torch.randn(3, 3, device=device)

    def func(a):
        return torch.sigmoid(a)

    with TorchRefsMode.push():
        gm = make_fx(func)(a)

    # Warm-up run on the contiguous input creates the cache entry.
    execute(gm, a, executor="nvfuser")

    # a.mT is noncontiguous, but it shouldn't affect correctness.
    # Reference (aten) runs first, then nvFuser.
    outputs = {
        backend: execute(gm, a.mT, executor=backend)
        for backend in ("aten", "nvfuser")
    }
    self.assertEqual(outputs["aten"], outputs["nvfuser"])
def test_aten_overload_to_prims(self, device):
    """Check that explicit ``torch.ops.aten`` calls are lowered to prims under ``TorchRefsMode``.

    ``func`` calls aten overloads directly. After tracing with ``make_fx``
    inside ``TorchRefsMode``, every ``call_function`` node in the graph must
    target an op whose name starts with ``"prims"``.
    """
    from torch.fx.experimental.proxy_tensor import make_fx
    from torch._prims.context import TorchRefsMode

    a = torch.randn(3, 3, device=device)

    def func(a):
        return torch.ops.aten.sigmoid.default(
            torch.ops.aten.digamma.default(a))

    with TorchRefsMode():
        gm = make_fx(func)(a)

    # Check that all call_function nodes are prims
    call_function_nodes = [
        node for node in gm.graph.nodes if node.op == "call_function"
    ]
    # all() over an empty sequence is True, so without this guard the test
    # would pass vacuously if tracing recorded no operator calls at all.
    self.assertTrue(
        call_function_nodes,
        "expected the traced graph to contain call_function nodes")
    # NOTE(review): ``name`` is read as an attribute of the op overload here;
    # some PyTorch versions expose it as a method ``name()`` -- confirm
    # against the pinned torch version.
    all_prims_namespace = all(
        node.target.name.startswith("prims") for node in call_function_nodes)
    self.assertTrue(all_prims_namespace)
def _traced(*args, executor="aten"):
    """Trace ``fn`` on ``args`` under ``TorchRefsMode`` and run the resulting graph.

    ``fn`` is not a parameter. It is resolved from an enclosing scope not
    visible here (presumably the function a factory/decorator is wrapping).

    Args:
        *args: positional inputs, used both for tracing and for execution.
        executor: backend name forwarded to ``execute`` (default ``"aten"``).

    Returns:
        Whatever ``execute`` returns for the traced graph.
    """
    # TODO: caching -- the graph is re-traced on every call.
    with TorchRefsMode.push():
        trace_fn = make_fx(fn)
        graph_module = trace_fn(*args)
    return execute(graph_module, *args, executor=executor)