def test_kernel_shape_prop(self): device, size = 'cpu', (4, 4) x = torch.rand(size, device=device) y = torch.rand(size, device=device) graph_str = """ graph(%a : Tensor, %b : Tensor): %c : Tensor = aten::mul(%a, %b) return (%c) """ graph = torch._C.parse_ir(graph_str) exception_thrown = False try: kernel = te.TensorExprKernel(graph) except RuntimeError: # Graph doesn't have shape info for inputs => compilation should # fail exception_thrown = True pass assert exception_thrown # Inject shape info and try compiling again example_inputs = [torch.rand(4, 4), torch.rand(4, 4)] torch._C._te.annotate_input_shapes(graph, example_inputs) torch._C._jit_pass_propagate_shapes_on_graph(graph) # Now compilation should pass kernel = te.TensorExprKernel(graph) res = kernel.run((x, y)) correct = torch.mul(x, y) np.testing.assert_allclose(res.numpy(), correct.numpy(), atol=1e-5)
def test_kernel_with_tensor_inputs(self): def f(a, b, c): return a + b + c device, size = 'cpu', (4, 4) x = torch.rand(size, device=device) y = torch.rand(size, device=device) z = torch.rand(size, device=device) graph_str = """ graph(%a.1 : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu), %b.1 : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu), %c.1 : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu)): %6 : int = prim::Constant[value=1]() %7 : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu) = aten::add(%a.1, %b.1, %6) %3 : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu) = aten::add(%7, %c.1, %6) return (%3) """ graph = torch._C.parse_ir(graph_str) kernel = te.TensorExprKernel(graph) res1 = kernel.run((x, y, z)) res2 = kernel.fallback((x, y, z)) correct = f(x, y, z) np.testing.assert_allclose(res1.numpy(), correct.numpy(), atol=2e-3) np.testing.assert_allclose(res2.numpy(), correct.numpy(), atol=2e-3)
def test_kernel_with_scalar_inputs(self): def f(a, b, c): return a + b + c x = torch.tensor(0.1, dtype=torch.float, device='cpu') y = torch.tensor(0.6, dtype=torch.float, device='cpu') z = torch.tensor(0.7, dtype=torch.float, device='cpu') graph_str = """ graph(%a.1 : Float(requires_grad=0, device=cpu), %b.1 : Float(requires_grad=0, device=cpu), %c.1 : Float(requires_grad=0, device=cpu)): %3 : int = prim::Constant[value=1]() %6 : Float(requires_grad=0, device=cpu) = aten::add(%a.1, %b.1, %3) %9 : Float(requires_grad=0, device=cpu) = aten::add(%6, %c.1, %3) return (%9) """ graph = torch._C.parse_ir(graph_str) kernel = te.TensorExprKernel(graph) res1 = kernel.run((x, y, z)) res2 = kernel.fallback((x, y, z)) correct = f(x, y, z) np.testing.assert_allclose(res1.numpy(), correct.numpy(), atol=2e-3) np.testing.assert_allclose(res2.numpy(), correct.numpy(), atol=2e-3)
def test_kernel_with_custom_lowering(self): def f(a): return a.nan_to_num() device = "cpu" x = torch.ones((2, 2), device=device) x[0, 0] = x[1, 1] = torch.nan graph_str = """ graph(%x : Float(2, 2, strides=[2, 1], requires_grad=0, device=cpu)): %none : NoneType = prim::Constant() %y : Float(2, 2, strides=[2, 1], requires_grad=0, device=cpu) = aten::nan_to_num(%x, %none, %none, %none) return (%y) """ graph = torch._C.parse_ir(graph_str) def my_custom_lowering(inputs, out_shape, out_type, device): def compute(idxs): load = inputs[0].as_buf().load(idxs) return te.ifThenElse(te.ExprHandle.isnan(load), te.ExprHandle.float(0.0), load) return te.Compute2("custom_nan_to_num", out_shape, compute) kernel = te.TensorExprKernel(graph, {"aten::nan_to_num": my_custom_lowering}) res1 = kernel.run((x, )) res2 = kernel.fallback((x, )) correct = f(x) np.testing.assert_allclose(res1.numpy(), correct.numpy(), atol=2e-3) np.testing.assert_allclose(res2.numpy(), correct.numpy(), atol=2e-3)
def test_kernel_shape_prop_module(self): class TestModule(torch.nn.Module): def forward(self, x, y): return x * x + y graph = torch.jit.script(TestModule()).graph # Try compiling the graph as-is. It should fail because it doesn't have # shape info. exception_thrown = False try: kernel = te.TensorExprKernel(graph) except RuntimeError: exception_thrown = True pass assert exception_thrown # Try injecting shape info for graph inputs example_inputs = [torch.rand(4, 4), torch.rand(4, 4)] exception_thrown = False try: torch._C._te.annotate_input_shapes(graph, example_inputs) except RuntimeError: # Graph has a 'self' argument for which we can't set shapes exception_thrown = True pass assert exception_thrown # Remove 'self' argument and try annotating shapes one more time torch._C._te.remove_unused_self_argument(graph) # Inject shape info and try compiling again torch._C._te.annotate_input_shapes(graph, example_inputs) torch._C._jit_pass_propagate_shapes_on_graph(graph) # Now compilation should pass kernel = te.TensorExprKernel(graph) device, size = 'cpu', (4, 4) x = torch.rand(size, device=device) y = torch.rand(size, device=device) res = kernel.run((x, y)) correct = TestModule().forward(x, y) np.testing.assert_allclose(res.numpy(), correct.numpy(), atol=1e-5)
def test_kernel_with_t(self): def f(a): return a.t() device, size = 'cpu', (3, 4) x = torch.rand(size, device=device) graph_str = """ graph(%a.1 : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu)): %3 : Float(4, 3, strides=[4, 1], requires_grad=0, device=cpu) = aten::t(%a.1) return (%3) """ graph = torch._C.parse_ir(graph_str) kernel = te.TensorExprKernel(graph) res1 = kernel.run((x, )) res2 = kernel.fallback((x, )) correct = f(x) np.testing.assert_allclose(res1.numpy(), correct.numpy(), atol=2e-3) np.testing.assert_allclose(res2.numpy(), correct.numpy(), atol=2e-3)
def test_kernel_with_transpose(self): def f(a): return a.transpose(-1, -2) device, size = 'cpu', (3, 4) x = torch.rand(size, device=device) graph_str = """ graph(%a.1 : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu)): %2 : int = prim::Constant[value=-1]() %3 : int = prim::Constant[value=-2]() %4 : Float(4, 3, strides=[4, 1], requires_grad=0, device=cpu) = aten::transpose(%a.1, %2, %3) return (%4) """ graph = torch._C.parse_ir(graph_str) kernel = te.TensorExprKernel(graph) res1 = kernel.run((x, )) res2 = kernel.fallback((x, )) correct = f(x) np.testing.assert_allclose(res1.numpy(), correct.numpy(), atol=2e-3) np.testing.assert_allclose(res2.numpy(), correct.numpy(), atol=2e-3)
def test_kernel_with_expand(self): def f(a): return a.expand((2, 3, 4)) device = 'cpu' x = torch.rand((1, 3, 1), device=device) graph_str = """ graph(%a : Float(1, 3, 1, strides=[3, 1, 1], requires_grad=0, device=cpu)): %1 : int = prim::Constant[value=2]() %2 : int = prim::Constant[value=3]() %3 : int = prim::Constant[value=4]() %4 : int[] = prim::ListConstruct(%1, %2, %3) %5 : bool = prim::Constant[value=0]() %6 : Float(2, 3, 4, strides=[12, 4, 0], requires_grad=0, device=cpu) = aten::expand(%a, %4, %5) return (%6) """ graph = torch._C.parse_ir(graph_str) kernel = te.TensorExprKernel(graph) res1 = kernel.run((x, )) res2 = kernel.fallback((x, )) correct = f(x) np.testing.assert_allclose(res1.numpy(), correct.numpy(), atol=2e-3) np.testing.assert_allclose(res2.numpy(), correct.numpy(), atol=2e-3)
def test_kernel_with_permute(self): def f(a): return a.permute([2, 1, 0]) device, size = 'cpu', (3, 4, 5) x = torch.rand(size, device=device) graph_str = """ graph(%a.1 : Float(3, 4, 5, strides=[20, 5, 1], requires_grad=0, device=cpu)): %1 : int = prim::Constant[value=2]() %2 : int = prim::Constant[value=1]() %3 : int = prim::Constant[value=0]() %4 : int[] = prim::ListConstruct(%1, %2, %3) %5 : Float(5, 4, 3, strides=[12, 3, 1], requires_grad=0, device=cpu) = aten::permute(%a.1, %4) return (%5) """ graph = torch._C.parse_ir(graph_str) kernel = te.TensorExprKernel(graph) res1 = kernel.run((x, )) res2 = kernel.fallback((x, )) correct = f(x) np.testing.assert_allclose(res1.numpy(), correct.numpy(), atol=2e-3) np.testing.assert_allclose(res2.numpy(), correct.numpy(), atol=2e-3)