def test_lstm_traced_cuda(self):
    # Trace LSTMCellF on CUDA and require exactly one fusion group in the
    # optimized graph. Inside that group, "Chunk", "aten::sigmoid" and
    # "aten::tanh" must appear in that order.
    lstm_inputs = get_lstm_inputs('cuda')
    traced = self.checkTrace(LSTMCellF, lstm_inputs)
    groups = self.findFusionGroups(traced.graph_for(*lstm_inputs))
    self.assertEqual(len(groups), 1)
    checker = FileCheck().check("Chunk").check("aten::sigmoid").check("aten::tanh")
    checker.run(str(groups[0]))
def test_lstm_traced_cuda(self):
    # Trace LSTMCellF on CUDA and inspect the printed optimized graph.
    # None of Chunk / aten::add / aten::sigmoid / aten::tanh may appear before
    # the FusionGroup node. That node must be followed on the next lines by
    # TupleConstruct and then return, and no "FusionGroup_1" may appear after.
    lstm_inputs = get_lstm_inputs('cuda')
    traced = self.checkTrace(LSTMCellF, lstm_inputs)
    graph_text = str(traced.graph_for(*lstm_inputs))
    checker = FileCheck()
    for op in ("Chunk", "aten::add", "aten::sigmoid", "aten::tanh"):
        checker = checker.check_not(op)
    (checker.check("FusionGroup")
            .check_next("TupleConstruct")
            .check_next("return")
            .check_not("FusionGroup_1")
            .run(graph_text))
def test_lstm_traced_cuda(self):
    # Trace LSTMCellF on CUDA and inspect the printed optimized graph.
    # No Chunk / aten::sigmoid / aten::tanh may appear before the FUSION_GROUP
    # node. That node must be followed on the next lines by TupleConstruct and
    # then return, and no FUSION_GROUP + "_2" may appear afterwards.
    # aten::add is intentionally not checked: the adds don't get pulled into
    # the fusion group because of BailOuts.
    lstm_inputs = get_lstm_inputs('cuda')
    traced = self.checkTrace(LSTMCellF, lstm_inputs)
    graph_text = str(traced.graph_for(*lstm_inputs))
    checker = FileCheck()
    for op in ("Chunk", "aten::sigmoid", "aten::tanh"):
        checker = checker.check_not(op)
    checker = checker.check(FUSION_GROUP)
    for op in ("TupleConstruct", "return"):
        checker = checker.check_next(op)
    checker.check_not(FUSION_GROUP + "_2").run(graph_text)
def test_lstm_traced_cpu(self):
    # Trace LSTMCellF on CPU and check that "FusionGroup" appears in the
    # printed optimized graph.
    # A RuntimeError whose message contains 'Failed to compile' becomes a
    # warning plus a skip, because the CPU fuser's kernels sometimes trigger
    # host-compiler bugs. Any other error propagates.
    inputs = get_lstm_inputs('cpu')
    try:
        ge = self.checkTrace(LSTMCellF, inputs)
        graph = ge.graph_for(*inputs)
        # FileCheck must be instantiated. Calling ``check`` on the class
        # itself passes "FusionGroup" as ``self`` and raises TypeError
        # instead of checking the graph.
        FileCheck().check("FusionGroup").run(str(graph))
    except RuntimeError as e:
        # str(e) works even when the exception carries no args; e.args[0]
        # would raise IndexError here and mask the original error.
        if 'Failed to compile' not in str(e):
            raise
        warnings.warn('CPU fuser test has failed! This is not a hard failure, '
                      'because the kernels sometimes trigger bugs in compilers '
                      '(most notably GCC 7.2).')
        raise unittest.SkipTest('Failed to compile') from e
def test_lstm_cuda(self):
    # Script LSTMCellS on CUDA with training inputs and check the graphs.
    #  * Forward: exactly one prim::FusionGroup (subgraphs included), and only
    #    two top-level nodes, a DifferentiableGraph followed by TupleConstruct,
    #    which feeds the return.
    #  * Backward (after one backward pass): fully fused except for aten::t,
    #    aten::mm and aten::_grad_sum_to_size.
    inputs = get_lstm_inputs('cuda', training=True)
    module = self.checkScript(LSTMCellS, inputs)
    forward_graph = module.graph_for(*inputs)
    self.assertGraphContainsExactly(
        forward_graph, 'prim::FusionGroup', 1, consider_subgraphs=True)
    # Everything is differentiable but TupleConstruct return.
    # assertEqual (instead of assertTrue(... == 2)) reports the actual node
    # count when it fails.
    self.assertEqual(len(list(forward_graph.nodes())), 2)
    FileCheck().check("DifferentiableGraph").check_next("TupleConstruct") \
        .check_next("return").run(str(forward_graph))

    hy, cy = module(*inputs)
    (hy + cy).sum().backward()
    backward = backward_graph(module)
    self.assertAllFused(backward, except_for=("aten::t", "aten::mm",
                                              "aten::_grad_sum_to_size"))
def test_lstm_cuda(self):
    # Script LSTMCellS on CUDA with training inputs through self.checkScript.
    # The early ``return`` below deliberately ends the test there. Every
    # forward/backward graph-structure check after it is unreachable.
    # NOTE(review): either re-enable these checks or turn the early return
    # into an explicit skip so the disabled coverage is visible in reports.
    lstm_inputs = get_lstm_inputs('cuda', training=True)
    scripted = self.checkScript(LSTMCellS, lstm_inputs)
    return
    fwd = scripted.graph_for(*lstm_inputs)
    self.assertGraphContainsExactly(
        fwd, FUSION_GROUP, 1, consider_subgraphs=True)
    # Everything is differentiable but TupleConstruct return
    self.assertTrue(len(strip_profiling_nodes(fwd.nodes())) == 2)
    (FileCheck()
        .check("DifferentiableGraph")
        .check_next("TupleConstruct")
        .check_next("return")
        .run(str(fwd)))

    with enable_profiling_mode_for_profiling_tests(True):
        hy, cy = scripted(*lstm_inputs)
        warmup_backward((hy + cy).sum())
        bwd = backward_graph(scripted)
    self.assertAllFused(bwd, except_for=("aten::t", "aten::mm",
                                         "aten::_grad_sum_to_size"))
def test_lstm_gates_permutations_cuda(self):
    # lstm has gates = x.mm(w_ih.t()) + hx.mm(w_hh.t()) + b_ih + b_hh.
    # For every ordering of those four summands, the test does three things:
    #  1. Compiles the cell with TorchScript.
    #  2. Checks its output against the same source executed as plain Python.
    #  3. Checks that its optimized graph contains exactly one fusion group.
    terms = ['x.mm(w_ih.t())', 'hx.mm(w_hh.t())', 'b_ih', 'b_hh']
    source_template = dedent('''
        def cell(x, hx, cx, w_ih, w_hh, b_ih, b_hh):
            gates = {} + {} + {} + {}
            ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
            return ingate * forgetgate * cellgate * outgate
        ''')
    for ordering in permutations(terms, len(terms)):
        src = source_template.format(*ordering)
        eager_ns = {}
        # src is assembled only from the fixed strings above, never from
        # external input, so exec-ing it is safe.
        exec(src, globals(), eager_ns)
        compiled = torch.jit.CompilationUnit(src)
        lstm_inputs = get_lstm_inputs('cuda', training=False)
        self.assertEqual(compiled.cell(*lstm_inputs),
                         eager_ns['cell'](*lstm_inputs))
        self.assertGraphContainsExactly(
            compiled.cell.graph_for(*lstm_inputs), FUSION_GROUP, 1)
def test_lstm_concat_cuda(self):
    # Trace LSTMCellC on CUDA. In the printed optimized graph, a "FusedConcat"
    # line must be directly followed by the "return" line.
    lstm_inputs = get_lstm_inputs('cuda')
    traced = self.checkTrace(LSTMCellC, lstm_inputs)
    graph_text = str(traced.graph_for(*lstm_inputs))
    checker = FileCheck().check("FusedConcat")
    checker.check_next("return").run(graph_text)
def test_lstm_concat_cuda(self):
    # Trace LSTMCellC on CUDA via self.checkTrace, then request the optimized
    # graph for the same inputs. The graph is not inspected, so this only
    # verifies that checkTrace and graph_for complete without raising.
    # TODO(review): the graph used to be bound to an unused local. Add a
    # structural assertion on it if one is intended.
    inputs = get_lstm_inputs('cuda')
    ge = self.checkTrace(LSTMCellC, inputs)
    ge.graph_for(*inputs)