def test_grad_sum_to_size_elimination(self):

    def my_broadcasted_cell(a, b, c):
        return (a + b) + c

    s1 = torch.randn(5, 1, requires_grad=True, device='cuda')
    s2 = torch.randn(5, 5, requires_grad=True, device='cuda')

    module = self.checkScript(my_broadcasted_cell, (s1, s1, s1))
    forward_graph = module.graph_for(s1, s1, s1)
    self.assertAllFused(forward_graph, except_for=("aten::size", "prim::BroadcastSizes",
                                                   "aten::_size_if_not_equal"))

    old_plans = set()
    for i in range(3):
        # if we have s2, then the s1 are _grad_sum_to_size'd
        args = s2 if i < 1 else s1, s2 if i < 2 else s1, s2
        args = [a.detach_().requires_grad_() for a in args]
        res = module(s2 if i < 1 else s1, s2 if i < 2 else s1, s2)
        grads = torch.autograd.grad(res.sum(), args)
        for inp, gr in zip(args, grads):
            self.assertEqual(inp.shape, gr.shape)

        backward = None
        # this is a workaround for the backward graphs not being
        # in order for Python 2
        for g in all_backward_graphs(module):
            if str(g) not in old_plans:
                assert backward is None
                backward = g
                old_plans.add(str(backward))
        self.assertEqual(len([1 for o in backward.outputs()
                              if o.node().kind() == "aten::_grad_sum_to_size"]), i)
        self.assertEqual(len([1 for o in backward.outputs()
                              if o.node().kind() == "prim::Param"]), 3 - i)
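
# Illustration (not part of the original test; the helper name below is hypothetical).
# The assertions above count aten::_grad_sum_to_size occurrences in the backward
# graph: whenever a (5, 1) input is broadcast against a (5, 5) input, autograd must
# sum the incoming gradient back down to (5, 1). A minimal eager-mode sketch of that
# shape contract, using only public torch APIs:
def _demo_broadcast_grad_shapes():
    import torch
    a = torch.randn(5, 1, requires_grad=True)
    b = torch.randn(5, 5, requires_grad=True)
    out = (a + b).sum()
    grad_a, grad_b = torch.autograd.grad(out, (a, b))
    assert grad_a.shape == (5, 1)  # gradient summed over the broadcast dimension
    assert grad_b.shape == (5, 5)  # gradient keeps the full shape

# The variant below is the same test adapted to the profiling graph executor
# (ProfilingMode.PROFILING), which recompiles per argument set and warms up the
# backward pass before the graph is inspected.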
def test_grad_sum_to_size_elimination(self):

    def my_broadcasted_cell(a, b, c):
        return (a + b) + c

    s1 = torch.randn(5, 1, requires_grad=True, device='cuda')
    s2 = torch.randn(5, 5, requires_grad=True, device='cuda')

    module = self.checkScript(my_broadcasted_cell, (s1, s1, s1),
                              profiling=ProfilingMode.PROFILING)
    forward_graph = module.graph_for(s1, s1, s1)
    self.assertAllFused(forward_graph, except_for=("aten::size", "prim::BroadcastSizes",
                                                   "aten::_size_if_not_equal"))

    old_plans = set()
    for i in range(3):
        # if we have s2, then the s1 are _grad_sum_to_size'd
        args = s2 if i < 1 else s1, s2 if i < 2 else s1, s2
        args = [a.detach_().requires_grad_() for a in args]
        # recompile, so we don't trigger bailouts
        module = self.checkScript(my_broadcasted_cell, args,
                                  profiling=ProfilingMode.PROFILING)
        res = module(s2 if i < 1 else s1, s2 if i < 2 else s1, s2)
        warmup_backward(res.sum(), args)
        grads = torch.autograd.grad(res.sum(), args)
        for inp, gr in zip(args, grads):
            self.assertEqual(inp.shape, gr.shape)

        backward = None
        # this is a workaround for the backward graphs not being
        # in order for Python 2
        for g in all_backward_graphs(module):
            if str(g) not in old_plans:
                assert backward is None
                backward = g
                old_plans.add(str(backward))
        num_grads = 1 if i > 0 else 0
        self.assertEqual(len([n for n in backward.nodes()
                              if n.kind() == 'aten::_grad_sum_to_size']),
                         num_grads)
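
# Illustration (not part of the original test). The profiling variant above
# recompiles for each argument set and calls warmup_backward before inspecting
# the backward graph, so the profiling executor has recorded shapes and built
# its optimized differentiable plan. A hedged sketch of what such a warm-up
# could look like; the real warmup_backward comes from the test utilities and
# may differ in details:
def _demo_warmup_backward(loss, inputs, num_runs=3):
    import torch
    for _ in range(num_runs):
        # retain_graph=True so the same autograd graph can still be used by the
        # real torch.autograd.grad call that follows the warm-up runs
        torch.autograd.grad(loss, inputs, retain_graph=True)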