def test_diff_graph_inline_threshold(self):
    with enable_profiling_mode_for_profiling_tests():
        NUM_RUNS = 1
        with num_profiled_runs(NUM_RUNS):
            @torch.jit.script
            def foo(x):
                # two nodes should be fused
                # see https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/runtime/graph_executor_impl.h#L49
                return torch.sigmoid(torch.sigmoid(x))

            @torch.jit.script
            def bar(x):
                # two nodes should NOT be fused
                return torch.sigmoid(x)

            input = torch.rand([4, 4], requires_grad=True)
            foo(input)
            foo(input)
            bar(input)
            bar(input)

            self.assertGraphContainsExactly(foo.graph_for(input), 'prim::DifferentiableGraph', 1)
            self.assertGraphContainsExactly(bar.graph_for(input), 'prim::DifferentiableGraph', 0)
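
# A minimal sketch (not part of the original suite; `count_nodes` is a
# hypothetical helper) of how one could count node kinds by hand when
# assertGraphContainsExactly is unavailable. It assumes only the public
# Graph/Block .nodes()/.blocks()/.kind() API.
def count_nodes(graph_or_block, kind):
    # Count top-level nodes of `kind`, recursing into nested blocks
    # (e.g. the branches of a prim::If node).
    count = 0
    for node in graph_or_block.nodes():
        if node.kind() == kind:
            count += 1
        for block in node.blocks():
            count += count_nodes(block, kind)
    return count
# e.g. count_nodes(foo.graph_for(input), 'prim::DifferentiableGraph') should be 1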
def test_dynamic_shape(self):
    with num_profiled_runs(2):
        @torch.jit.script
        def test(x, y, z):
            return x * y * z

        cuda = CudaCodeGenCreated()
        x, y, z = [torch.rand(4, 8).cuda() for _ in range(3)]
        ref = test(x, y, z)
        # Run once with a different shape so the profiled graph stays dynamic.
        _ = test(*[torch.rand(6, 8).cuda() for _ in range(3)])
        res = test(x, y, z)
        np.testing.assert_allclose(ref.cpu().numpy(), res.cpu().numpy())
        assert cuda.elapsed_value() == 1

        # A wild broadcast appears.
        x = torch.rand(4, 8).cuda()
        y = torch.rand(1, 8).cuda()
        z = torch.rand(4, 1).cuda()
        res = test(x, y, z)
        xn, yn, zn = [t.cpu().numpy() for t in (x, y, z)]
        np.testing.assert_allclose(res.cpu().numpy(), xn * yn * zn)
        assert cuda.elapsed_value() == 1

        # Mismatched shapes shouldn't reach codegen.
        x = torch.rand(4, 8).cuda()
        y = torch.rand(4, 8).cuda()
        z = torch.rand(5, 8).cuda()
        try:
            test(x, y, z)
            self.fail("mismatched shapes should have raised a RuntimeError")
        except RuntimeError as e:
            assert "The size of tensor a (4) must match" in e.args[0]
        assert cuda.elapsed_value() == 1
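
# A standalone sketch (plain NumPy, no JIT; illustration only, relying on the
# module's existing `import numpy as np`) of the broadcasting rule the test
# above exercises: size-1 dimensions stretch to match their counterpart,
# while any other mismatch is an error.
def _broadcast_demo():
    x = np.random.rand(4, 8)
    y = np.random.rand(1, 8)  # stretches along dim 0
    z = np.random.rand(4, 1)  # stretches along dim 1
    assert (x * y * z).shape == (4, 8)  # (4,8)*(1,8)*(4,1) -> (4,8)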
def test_prune_grad(self):
    @torch.jit.script
    def t(input, bias):
        return torch.nn.functional.relu(input + bias)

    input = torch.randn(2, 8, requires_grad=True)
    bias = torch.randn(8, requires_grad=False)  # bias does NOT require grad
    NUM_PROFILED_RUNS = 1
    with num_profiled_runs(NUM_PROFILED_RUNS):
        WARMUP = 3  # 2 runs to reach backward + 1 to optimize it
        for _ in range(WARMUP):
            o = t(input, bias)
            o.sum().backward()

        fwd_plan = list(t.get_debug_state().execution_plans.values())[0]
        bwd_graph = list(fwd_plan.code.grad_executor_states()[0].execution_plans.values())[0].graph
        # Only one gradient output should remain: d(input); d(bias) is pruned.
        tup = next(bwd_graph.outputs())
        self.assertEqual(len(list(tup.node().inputs())), 1)
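
# An eager-mode sketch (no TorchScript; illustration only) of the behavior
# the test above checks at the graph level: autograd never produces a
# gradient for an input with requires_grad=False, which is why the JIT can
# prune that output from the backward graph.
def _prune_grad_demo():
    inp = torch.randn(2, 8, requires_grad=True)
    bias = torch.randn(8, requires_grad=False)
    torch.nn.functional.relu(inp + bias).sum().backward()
    assert inp.grad is not None  # gradient flows to inp
    assert bias.grad is None     # no gradient is ever computed for bias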