def test_lstm_traced_cuda(self):
    """Trace ``LSTMCellF`` on ``get_lstm_inputs('cuda')`` and inspect its fusion group.

    Asserts that ``findFusionGroups`` reports exactly one group for the
    optimized graph, and that the group's text passes a FileCheck for
    "Chunk", "aten::sigmoid" and "aten::tanh".
    """
    lstm_args = get_lstm_inputs('cuda')
    traced = self.checkTrace(LSTMCellF, lstm_args)
    groups = self.findFusionGroups(traced.graph_for(*lstm_args))
    self.assertEqual(len(groups), 1)

    # Build the pattern chain first, then run it on the single fusion group.
    checker = (
        FileCheck()
        .check("Chunk")
        .check("aten::sigmoid")
        .check("aten::tanh")
    )
    checker.run(str(groups[0]))
Exemple #2
0
 def test_lstm_traced_cuda(self):
     """Trace ``LSTMCellF`` on ``get_lstm_inputs('cuda')`` and FileCheck the whole graph.

     The check chain forbids "Chunk", "aten::add", "aten::sigmoid" and
     "aten::tanh" ahead of a "FusionGroup" match, requires "TupleConstruct"
     and then "return" on the lines that follow it, and forbids
     "FusionGroup_1" afterwards.
     """
     lstm_args = get_lstm_inputs('cuda')
     traced = self.checkTrace(LSTMCellF, lstm_args)
     optimized = traced.graph_for(*lstm_args)

     checker = (
         FileCheck()
         .check_not("Chunk")
         .check_not("aten::add")
         .check_not("aten::sigmoid")
         .check_not("aten::tanh")
         .check("FusionGroup")
         .check_next("TupleConstruct")
         .check_next("return")
         .check_not("FusionGroup_1")
     )
     checker.run(str(optimized))
Exemple #3
0
 def test_lstm_traced_cuda(self):
     """Trace ``LSTMCellF`` on ``get_lstm_inputs('cuda')`` and FileCheck the whole graph.

     The check chain forbids "Chunk", "aten::sigmoid" and "aten::tanh" ahead
     of a ``FUSION_GROUP`` match, requires "TupleConstruct" and then "return"
     on the lines that follow it, and forbids ``FUSION_GROUP + "_2"``
     afterwards.
     """
     lstm_args = get_lstm_inputs('cuda')
     traced = self.checkTrace(LSTMCellF, lstm_args)
     optimized = traced.graph_for(*lstm_args)

     # No .check_not("aten::add") here: per the original note, the adds do not
     # get pulled into the FusionGroup because of BailOuts.
     checker = (
         FileCheck()
         .check_not("Chunk")
         .check_not("aten::sigmoid")
         .check_not("aten::tanh")
         .check(FUSION_GROUP)
         .check_next("TupleConstruct")
         .check_next("return")
         .check_not(FUSION_GROUP + "_2")
     )
     checker.run(str(optimized))
Exemple #4
0
 def test_lstm_traced_cpu(self):
     """Trace ``LSTMCellF`` on ``get_lstm_inputs('cpu')`` and check for a fusion group.

     Runs a FileCheck for "FusionGroup" against the optimized graph text.

     Raises:
         unittest.SkipTest: if a RuntimeError whose message contains
             'Failed to compile' is raised while tracing/checking; a warning
             is emitted first, since this is treated as a soft failure.
         RuntimeError: any other RuntimeError is re-raised unchanged.
     """
     inputs = get_lstm_inputs('cpu')
     try:
         ge = self.checkTrace(LSTMCellF, inputs)
         graph = ge.graph_for(*inputs)
         # FileCheck must be instantiated: calling ``check`` on the class
         # itself raises TypeError instead of running the check.
         FileCheck().check("FusionGroup").run(str(graph))
     except RuntimeError as e:
         # str(e) is safe even when the exception carries no args or a
         # non-string first argument.
         if 'Failed to compile' not in str(e):
             raise
         warnings.warn('CPU fuser test has failed! This is not a hard failure, '
                       'because the kernels sometimes trigger bugs in compilers '
                       '(most notably GCC 7.2).')
         raise unittest.SkipTest('Failed to compile')
Exemple #5
0
    def test_lstm_cuda(self):
        """Script ``LSTMCellS`` on training-mode CUDA inputs and check fusion.

        Forward: exactly one 'prim::FusionGroup' (subgraphs included), two
        nodes in the graph, and a "DifferentiableGraph" followed by
        "TupleConstruct" and "return". Backward: everything fused except
        "aten::t", "aten::mm" and "aten::_grad_sum_to_size".
        """
        lstm_args = get_lstm_inputs('cuda', training=True)
        scripted = self.checkScript(LSTMCellS, lstm_args)

        fwd = scripted.graph_for(*lstm_args)
        self.assertGraphContainsExactly(
            fwd, 'prim::FusionGroup', 1, consider_subgraphs=True)
        node_count = len(list(fwd.nodes()))
        self.assertTrue(node_count == 2)

        # Everything is differentiable except the TupleConstruct feeding return.
        checker = (
            FileCheck()
            .check("DifferentiableGraph")
            .check_next("TupleConstruct")
            .check_next("return")
        )
        checker.run(str(fwd))

        # Run a backward pass so a backward graph exists to inspect.
        hy, cy = scripted(*lstm_args)
        total = (hy + cy).sum()
        total.backward()
        bwd = backward_graph(scripted)
        unfused_ok = ("aten::t", "aten::mm", "aten::_grad_sum_to_size")
        self.assertAllFused(bwd, except_for=unfused_ok)
    def test_lstm_cuda(self):
        # Scripts LSTMCellS on training-mode inputs from get_lstm_inputs('cuda')
        # via self.checkScript.
        inputs = get_lstm_inputs('cuda', training=True)
        module = self.checkScript(LSTMCellS, inputs)
        # NOTE(review): this unconditional return makes every statement below
        # unreachable, so only the checkScript call above actually runs.
        # Presumably a deliberate temporary disable of the fusion assertions;
        # confirm before removing it.
        return
        # --- unreachable from here on ---
        # Forward graph: expects exactly one FUSION_GROUP (subgraphs included)
        # and two nodes once profiling nodes are stripped.
        forward_graph = module.graph_for(*inputs)
        self.assertGraphContainsExactly(
            forward_graph, FUSION_GROUP, 1, consider_subgraphs=True)
        self.assertTrue(len(strip_profiling_nodes(forward_graph.nodes())) == 2)
        # Everything is differentiable but TupleConstruct return
        FileCheck().check("DifferentiableGraph").check_next("TupleConstruct") \
            .check_next("return").run(str(forward_graph))

        # Backward graph: run the module and a warmed-up backward pass with
        # profiling mode enabled, then require everything to be fused except
        # the listed ops.
        with enable_profiling_mode_for_profiling_tests(True):
            hy, cy = module(*inputs)
            warmup_backward((hy + cy).sum())
            backward = backward_graph(module)
        self.assertAllFused(backward, except_for=("aten::t", "aten::mm",
                                                  "aten::_grad_sum_to_size"))
Exemple #7
0
    def test_lstm_gates_permutations_cuda(self):
        """Check that every ordering of the LSTM gate summands fuses into one group.

        lstm has gates = x.mm(w_ih.t()) + hx.mm(w_hh.t()) + b_ih + b_hh.
        For each permutation of those four terms, the generated ``cell`` is
        run both as plain Python and as a ``torch.jit.CompilationUnit``; the
        results must match and the compiled graph must contain exactly one
        ``FUSION_GROUP``.
        """
        terms = ['x.mm(w_ih.t())', 'hx.mm(w_hh.t())', 'b_ih', 'b_hh']
        cell_source = dedent('''
        def cell(x, hx, cx, w_ih, w_hh, b_ih, b_hh):
            gates = {} + {} + {} + {}
            ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
            return ingate * forgetgate * cellgate * outgate
        ''')
        # permutations() with no length argument yields full-length orderings.
        for ordering in permutations(terms):
            code = cell_source.format(*ordering)

            # Plain-Python reference implementation of the generated cell.
            namespace = {}
            exec(code, globals(), namespace)
            eager_cell = namespace['cell']
            # Same source compiled by TorchScript.
            compiled = torch.jit.CompilationUnit(code)

            lstm_args = get_lstm_inputs('cuda', training=False)
            self.assertEqual(compiled.cell(*lstm_args), eager_cell(*lstm_args))
            fwd = compiled.cell.graph_for(*lstm_args)
            self.assertGraphContainsExactly(fwd, FUSION_GROUP, 1)
Exemple #8
0
 def test_lstm_concat_cuda(self):
     """Trace ``LSTMCellC`` on ``get_lstm_inputs('cuda')`` and FileCheck the graph.

     The graph text must contain "FusedConcat" with "return" on the line
     that follows it.
     """
     lstm_args = get_lstm_inputs('cuda')
     traced = self.checkTrace(LSTMCellC, lstm_args)
     optimized = traced.graph_for(*lstm_args)

     checker = FileCheck().check("FusedConcat").check_next("return")
     checker.run(str(optimized))
Exemple #9
0
 def test_lstm_concat_cuda(self):
     inputs = get_lstm_inputs('cuda')
     ge = self.checkTrace(LSTMCellC, inputs)
     graph = ge.graph_for(*inputs)