示例#1
0
    def test_case2(self):
        """OutputCut replaces FakeOutput consumers on non-zero output ports.

        Op1 (output port 1) and Op2 (output port 3) each feed a FakeOutput
        node. After the front-stage transformation, the consumer on each of
        those ports must be a ``Result`` node. The edge must still carry its
        original ``fw_tensor_debug_info``, and no op nodes named
        FakeOutput1 / FakeOutput2 may remain in the graph.
        """
        edges = [
            ('Parameter1', 'Op1'),
            ('Op1', 'FakeOutput1', {
                'in': 1,
                'out': 1,
                'fw_tensor_debug_info': [('Op1', 0, 'Op1_tensor_name')],
            }),
            ('Parameter1', 'Op2'),
            ('Op2', 'FakeOutput2', {
                'in': 2,
                'out': 3,
                'fw_tensor_debug_info': [('Op2', 0, 'Op2_tensor_name')],
            }),
        ]
        graph = build_graph(nodes, edges)
        # NOTE(review): presumably OutputCut consults these graph attributes;
        # None is assumed to mean "no packed / user-specified outputs" — confirm.
        graph.graph['packed_outputs'] = None
        graph.graph['user_shapes'] = None

        graph.stage = 'front'
        OutputCut().find_and_replace_pattern(graph)

        op1 = Node(graph, 'Op1')
        op2 = Node(graph, 'Op2')
        # assertEqual (instead of assertTrue(a == b)) reports both values on failure.
        self.assertEqual(op1.out_node(1)['type'], 'Result')
        self.assertEqual(op2.out_node(3)['type'], 'Result')
        self.assertEqual(op1.out_edge(1)['fw_tensor_debug_info'],
                         [('Op1', 0, 'Op1_tensor_name')])
        self.assertEqual(op2.out_edge(3)['fw_tensor_debug_info'],
                         [('Op2', 0, 'Op2_tensor_name')])
        self.assertEqual(graph.get_op_nodes(name='FakeOutput1'), [])
        self.assertEqual(graph.get_op_nodes(name='FakeOutput2'), [])
示例#2
0
    def test_case3(self):
        """OutputCut on a graph built with no edges leaves no FakeOutput ops.

        The graph is built from ``nodes`` with an empty edge list. After the
        front-stage transformation, no op nodes named FakeOutput1 /
        FakeOutput2 may be present.
        """
        # NOTE(review): whether build_graph keeps the unconnected nodes from
        # `nodes` (so OutputCut must remove them) depends on build_graph — confirm.
        graph = build_graph(nodes, [])
        # NOTE(review): presumably OutputCut consults these graph attributes;
        # None is assumed to mean "no packed / user-specified outputs" — confirm.
        graph.graph['packed_outputs'] = None
        graph.graph['user_shapes'] = None

        graph.stage = 'front'
        OutputCut().find_and_replace_pattern(graph)

        # assertEqual (instead of assertTrue(a == b)) reports both values on failure.
        self.assertEqual(graph.get_op_nodes(name='FakeOutput1'), [])
        self.assertEqual(graph.get_op_nodes(name='FakeOutput2'), [])
示例#3
0
    def test_case1(self):
        """OutputCut replaces a FakeOutput fed directly by a Parameter.

        Parameter1 (output port 0) feeds FakeOutput1 (input port 0). After the
        front-stage transformation, Parameter1's consumer must be a ``Result``
        node. The edge must keep its original ``fw_tensor_debug_info``, and
        no op node named FakeOutput1 may remain.
        """
        edges = [
            ('Parameter1', 'FakeOutput1', {
                'in': 0,
                'out': 0,
                'fw_tensor_debug_info': [('Parameter1', 'Parameter1_tensor_name')],
            }),
        ]
        graph = build_graph(nodes, edges)
        # NOTE(review): presumably OutputCut consults these graph attributes;
        # None is assumed to mean "no packed / user-specified outputs" — confirm.
        graph.graph['packed_outputs'] = None
        graph.graph['user_shapes'] = None

        graph.stage = 'front'
        OutputCut().find_and_replace_pattern(graph)

        param1 = Node(graph, 'Parameter1')
        # assertEqual (instead of assertTrue(a == b)) reports both values on failure.
        self.assertEqual(param1.out_node()['type'], 'Result')
        self.assertEqual(param1.out_edge()['fw_tensor_debug_info'],
                         [('Parameter1', 'Parameter1_tensor_name')])
        self.assertEqual(graph.get_op_nodes(name='FakeOutput1'), [])