def test_grad_sum_to_size_elimination(self):

    def my_broadcasted_cell(a, b, c):
        return (a + b) + c

    s1 = torch.randn(5, 1, requires_grad=True, device='cuda')
    s2 = torch.randn(5, 5, requires_grad=True, device='cuda')

    module = self.checkScript(my_broadcasted_cell, (s1, s1, s1), profiling=ProfilingMode.PROFILING)
    forward_graph = module.graph_for(s1, s1, s1)
    self.assertAllFused(forward_graph, except_for=("aten::size", "prim::BroadcastSizes",
                                                   "aten::_size_if_not_equal"))

    old_plans = set()
    for i in range(3):
        # if we have s2, then the s1 are _grad_sum_to_size'd
        args = s2 if i < 1 else s1, s2 if i < 2 else s1, s2
        args = [a.detach_().requires_grad_() for a in args]
        # recompile, so we don't trigger bailouts
        module = self.checkScript(my_broadcasted_cell, args, profiling=ProfilingMode.PROFILING)
        res = module(s2 if i < 1 else s1, s2 if i < 2 else s1, s2)
        warmup_backward(res.sum(), args)
        grads = torch.autograd.grad(res.sum(), args)
        for inp, gr in zip(args, grads):
            self.assertEqual(inp.shape, gr.shape)

        backward = None
        # this is a workaround for the backward graphs not being
        # in order for Python 2
        for g in all_backward_graphs(module):
            if str(g) not in old_plans:
                assert backward is None
                backward = g
                old_plans.add(str(backward))

        num_grads = 1 if i > 0 else 0
        self.assertEqual(len([n for n in backward.nodes() if n.kind() == 'aten::_grad_sum_to_size']),
                         num_grads)
def test_clamp(self):
    def func2(a, b):
        return torch.clamp(a + b, min=0, max=2)

    def funcInf(a, b):
        return torch.clamp(a + b, min=0, max=float('inf'))

    def funcOptMin(a, b):
        return torch.clamp(a + b, max=2)

    def funcOptMax(a, b):
        return torch.clamp(a + b, min=0)

    a = torch.randn(4, 4, dtype=torch.float, device='cuda', requires_grad=True)
    b = torch.randn(4, 4, dtype=torch.float, device='cuda')
    nan = torch.tensor(float('nan'), dtype=torch.float, device='cuda')

    funcs = (func2, funcInf, funcOptMin, funcOptMax)
    for f, inputs in product(funcs, [[a, b], [a, nan]]):
        # disable the script-function cache so each (function, inputs) pair is recompiled
        f.__disable_jit_function_caching__ = True
        inp1, inp2 = inputs
        s = self.checkScript(f, (inp1, inp2), profiling=ProfilingMode.PROFILING)
        self.assertAllFused(s.graph_for(inp1, inp2),
                            except_for={'aten::size', 'aten::_size_if_not_equal'})
        c = s(inp1, inp2)
        with enable_profiling_mode_for_profiling_tests():
            warmup_backward(c.sum())
            graph = backward_graph(s)
            self.assertAllFused(graph,
                                except_for={'aten::Float', 'aten::_grad_sum_to_size'})
def test_fuser_deduplication(self):
    # Check that fusion kernel outputs are deduplicated when _grad_sum_to_size is
    # removed during the fuser's compilation; see the discussion in PR #14957.
    def f(x, y):
        return torch.sigmoid(x + y)

    b = torch.randn(5, 5, requires_grad=True)
    a = torch.randn(5, 5, requires_grad=True)
    s = self.checkScript(f, (a, b))
    self.assertAllFused(s.graph_for(a, b), except_for={
        'aten::size', 'aten::_size_if_not_equal', 'prim::BroadcastSizes'})

    c = s(a, b)
    results = warmup_backward(c.sum(), [a, b])
    ga2, gb2 = results.pop()
    graph = backward_graph(s)
    self.assertAllFused(graph)
    # check that the gradients of a and b share storage, i.e. they were
    # generated as a single output of the fused kernel
    self.assertEqual(ga2.data_ptr(), gb2.data_ptr())