def test_dropout(self):
    def func(x):
        x = torch.nn.functional.dropout(x)
        return torch.nn.functional.relu(x)

    a = torch.randn(4, 4, dtype=torch.float, device='cuda', requires_grad=True)
    s = torch.jit.script(func)
    c = s(a)
    c = s(a)
    warmup_backward(c.sum())
    # skip_check to skip extra bailout nodes in between
    graph = backward_graph(s, skip_check=True)
    self.assertAllFused(graph, except_for={'aten::div', 'prim::Constant'})
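# A small standalone eager-mode sketch (not part of the test above, and one possible
# reading of why 'aten::div' is excluded from the fusion check): in training mode,
# dropout rescales the surviving elements by 1 / (1 - p), so a division accompanies
# the fused pointwise work in forward and backward.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.ones(8)
y = F.dropout(x, p=0.5, training=True)
print(y)  # kept entries are scaled to 1 / (1 - 0.5) == 2.0, dropped entries are 0.0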
def test_fuser_deduplication(self):
    # Check that fusion kernel outputs are deduplicated when removing
    # _grad_sum_to_size in the fuser's compilation; see the discussion in PR #14957.
    def f(x, y):
        return torch.sigmoid(x + y)

    b = torch.randn(5, 5, requires_grad=True)
    a = torch.randn(5, 5, requires_grad=True)
    s = self.checkScript(f, (a, b))
    self.assertAllFused(s.graph_for(a, b), except_for={'aten::size'})
    c = s(a, b)
    ga, gb = torch.autograd.grad(c.sum(), [a, b])
    graph = backward_graph(s)
    self.assertAllFused(graph)
    # Check that ga and gb share storage, i.e. were generated as a single output in the fuser.
    self.assertEqual(ga.data_ptr(), gb.data_ptr())
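# A brief eager-mode sketch (plain autograd, no fuser involved) of why deduplication is
# possible here: for f(x, y) = sigmoid(x + y), df/dx and df/dy have identical values,
# so a fuser may legally emit a single kernel output and alias it for both gradients.
import torch

x = torch.randn(5, 5, requires_grad=True)
y = torch.randn(5, 5, requires_grad=True)
gx, gy = torch.autograd.grad(torch.sigmoid(x + y).sum(), [x, y])
print(torch.equal(gx, gy))  # True: same values; sharing one storage is the fuser's optimization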
def test_lstm_cuda(self):
    inputs = get_lstm_inputs('cuda', training=True)
    module = self.checkScript(LSTMCellS, inputs)
    forward_graph = module.graph_for(*inputs)
    self.assertGraphContainsExactly(
        forward_graph, FUSION_GROUP, 1, consider_subgraphs=True)
    self.assertTrue(len(strip_profiling_nodes(forward_graph.nodes())) == 2)
    # Everything is differentiable but TupleConstruct return.
    FileCheck().check("DifferentiableGraph").check_next("TupleConstruct") \
        .check_next("return").run(str(forward_graph))
    with enable_profiling_mode_for_profiling_tests(True):
        hy, cy = module(*inputs)
        warmup_backward((hy + cy).sum())
        backward = backward_graph(module)
        self.assertAllFused(backward,
                            except_for=("aten::t", "aten::mm", "aten::_grad_sum_to_size"))
def test_lstm_cuda(self):
    inputs = get_lstm_inputs('cuda', training=True)
    module = self.checkScript(LSTMCellS, inputs)
    forward_graph = module.graph_for(*inputs)
    self.assertGraphContainsExactly(
        forward_graph, 'prim::FusionGroup', 1, consider_subgraphs=True)
    self.assertTrue(len(list(forward_graph.nodes())) == 2)
    # Everything is differentiable but TupleConstruct return.
    FileCheck().check("DifferentiableGraph").check_next("TupleConstruct") \
        .check_next("return").run(str(forward_graph))
    hy, cy = module(*inputs)
    (hy + cy).sum().backward()
    backward = backward_graph(module)
    FileCheck().check("FusionGroup_0").check_next("FusionGroup_1") \
        .check_not("FusionGroup_2").run(str(backward))
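# A minimal sketch (assuming a CUDA device and a TorchScript fuser are available) of
# inspecting an optimized graph by hand, mirroring the graph checks in the two tests
# above: run the scripted function a couple of times so the JIT can specialize it,
# then look at the graph text. The exact fusion node name depends on the fuser backend.
import torch

@torch.jit.script
def cell_like(x, h):
    return torch.sigmoid(x + h) * torch.tanh(x + h)

if torch.cuda.is_available():
    x = torch.randn(4, 4, device='cuda')
    h = torch.randn(4, 4, device='cuda')
    cell_like(x, h)  # warm-up runs let the JIT profile and specialize
    cell_like(x, h)
    print(cell_like.graph_for(x, h))  # look for a FusionGroup / TensorExprGroup node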
def test_fuser_iou(self):
    # This checks that most of Intersection over Union is fused.
    # In particular, the backward contains many _grad_sum_to_size nodes.
    def iou(b1x1, b1y1, b1x2, b1y2, b2x1, b2y1, b2x2, b2y2):
        ltx = torch.max(b1x1, b2x1)  # [N,M]
        lty = torch.max(b1y1, b2y1)
        rbx = torch.min(b1x2, b2x2)
        rby = torch.min(b1y2, b2y2)

        w = (rbx - ltx).clamp(min=0, max=float('inf'))  # [N,M]
        h = (rby - lty).clamp(min=0, max=float('inf'))  # [N,M]
        inter = w * h  # [N,M]

        area1 = (b1x2 - b1x1) * (b1y2 - b1y1)  # [N,1]
        area2 = (b2x2 - b2x1) * (b2y2 - b2y1)  # [1,M]
        iou = inter / (area1 + area2 - inter)
        return iou

    box1 = torch.randn(5, 4, requires_grad=True)
    box2 = torch.randn(5, 4, requires_grad=True)
    # unsqueeze currently cannot be fused
    b1x1 = box1[:, 0].unsqueeze(1)  # [N,1]
    b1y1 = box1[:, 1].unsqueeze(1)
    b1x2 = box1[:, 2].unsqueeze(1)
    b1y2 = box1[:, 3].unsqueeze(1)
    b2x1 = box2[:, 0].unsqueeze(0)  # [1,M]
    b2y1 = box2[:, 1].unsqueeze(0)
    b2x2 = box2[:, 2].unsqueeze(0)
    b2y2 = box2[:, 3].unsqueeze(0)

    s = self.checkScript(iou, (b1x1, b1y1, b1x2, b1y2, b2x1, b2y1, b2x2, b2y2))
    self.assertAllFused(s.graph_for(b1x1, b1y1, b1x2, b1y2, b2x1, b2y1, b2x2, b2y2),
                        except_for={'aten::size', 'prim::BroadcastSizes'})

    c = s(b1x1, b1y1, b1x2, b1y2, b2x1, b2y1, b2x2, b2y2)
    torch.autograd.grad(c.sum(), [b1x1, b1y1, b1x2, b1y2, b2x1, b2y1, b2x2, b2y2])
    graph = backward_graph(s)
    self.assertAllFused(graph, except_for={'aten::size', 'prim::BroadcastSizes'})
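# A quick eager-mode sanity check (separate from the fusion test above, using a
# hypothetical per-pair helper rather than the broadcasted version) of the IoU formula
# in iou(): identical boxes give IoU == 1, disjoint boxes give IoU == 0.
import torch

def iou_pairwise(box_a, box_b):
    ax1, ay1, ax2, ay2 = box_a
    bx1, by1, bx2, by2 = box_b
    ltx, lty = torch.max(ax1, bx1), torch.max(ay1, by1)
    rbx, rby = torch.min(ax2, bx2), torch.min(ay2, by2)
    inter = (rbx - ltx).clamp(min=0) * (rby - lty).clamp(min=0)
    area_a = (ax2 - ax1) * (ay2 - ay1)
    area_b = (bx2 - bx1) * (by2 - by1)
    return inter / (area_a + area_b - inter)

same = torch.tensor([0., 0., 2., 2.])
disjoint = torch.tensor([5., 5., 6., 6.])
print(iou_pairwise(same, same))      # tensor(1.)
print(iou_pairwise(same, disjoint))  # tensor(0.)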
def test_clamp(self):
    def func2(a, b):
        return torch.clamp(a + b, min=0, max=2)

    def funcInf(a, b):
        return torch.clamp(a + b, min=0, max=float('inf'))

    def funcNegInf(a, b):
        return torch.clamp(a + b, min=float('-inf'), max=0)

    def funcOptMin(a, b):
        return torch.clamp(a + b, max=2)

    def funcOptMax(a, b):
        return torch.clamp(a + b, min=0)

    a = torch.randn(4, 4, dtype=torch.float, device='cuda', requires_grad=True)
    b = torch.randn(4, 4, dtype=torch.float, device='cuda')
    nan = torch.tensor(float('nan'), dtype=torch.float, device='cuda')

    funcs = (func2, funcInf, funcNegInf, funcOptMin, funcOptMax)
    for f, inputs in product(funcs, [[a, b], [a, nan]]):
        inp1, inp2 = inputs
        s = self.checkScript(f, (inp1, inp2), profiling=ProfilingMode.PROFILING)
        self.assertAllFused(s.graph_for(inp1, inp2),
                            except_for={'aten::size', 'aten::_size_if_not_equal'})
        c = s(inp1, inp2)
        with enable_profiling_mode():
            warmup_backward(c.sum())
            graph = backward_graph(s)
            self.assertAllFused(graph,
                                except_for={'aten::Float', 'aten::_grad_sum_to_size'})
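# A tiny eager-mode sketch (separate from the test above) of why a NaN input is part of
# the parameter sweep: a fused kernel has to match eager clamp's handling of NaN, which
# propagates through the op rather than being clamped to the bounds (worth confirming
# on your build, hence the prints rather than asserts).
import torch

nan = torch.tensor(float('nan'))
print(torch.clamp(nan, min=0, max=2))  # expected: tensor(nan)
print(torch.clamp(nan + 1.0, max=2))   # expected: tensor(nan)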