Пример #1
0
    def find_and_replace_pattern(self, graph: Graph):
        """Add output ops for the requested outputs and remove fake output nodes.

        ``add_output_ops`` is called with ``graph.graph['packed_outputs']`` and
        ``inputs=graph.graph['user_shapes']``. Afterwards every op node marked
        ``needs_removal=True`` is erased. If the edge feeding it carried a truthy
        ``'fw_tensor_debug_info'`` attribute, that value is written onto the
        producer's outgoing edge on the same output port. When the feeding edge
        has no ``'out'`` attribute, it is written onto the edges of all the
        producer's output ports instead.

        :param graph: graph to modify in place.
        """
        add_output_ops(graph,
                       graph.graph['packed_outputs'],
                       inputs=graph.graph['user_shapes'])

        # For keeping tensor names information for output nodes fake outputs are added
        # to graph during the model loading. In the following code fake outputs are removed
        # and tensor names information is moved to output->Result edge.
        # Snapshot the matching nodes first: they are erased inside the loop.
        for node in list(graph.get_op_nodes(needs_removal=True)):
            fw_info = None
            in_node = None
            # Output port of ``in_node`` that fed the node being removed.
            src_out_port = None
            for in_port_idx in node.in_edges():
                in_edge = node.in_edge(in_port_idx)
                node_idx = in_edge['in']
                if node_idx in node.in_nodes():
                    in_node = node.in_node(node_idx)
                    fw_info_value = get_edge_attribute_between_nodes(
                        in_node, node, 'fw_tensor_debug_info')
                    if fw_info_value:
                        fw_info = fw_info_value
                        # Read before erase_node() drops this edge.
                        src_out_port = in_edge.get('out')
                        break
            graph.erase_node(node)

            if fw_info is None or in_node is None:
                continue
            for out_idx in in_node.out_nodes():
                # fw_info names the tensor of a single output port; copying it onto
                # the other ports of a multi-output producer would overwrite the
                # names of unrelated tensors.
                if src_out_port is not None and out_idx != src_out_port:
                    continue
                set_edge_attribute_between_nodes(in_node,
                                                 in_node.out_node(out_idx),
                                                 'fw_tensor_debug_info',
                                                 fw_info)
Пример #2
0
 def test_output_port_cut(self, output):
     """Cutting the graph at ``output`` must leave exactly two nodes after clean-up.

     Builds A -> C (in port 0) and B -> C (in port 1), then C:0 -> D and
     C:1 -> E. Runs ``add_output_ops`` with the parametrized ``output`` spec,
     cleans the graph up and checks the remaining node count.
     """
     nodes = {'A': {'op': 'Parameter', 'kind': 'op'},
              'B': {'op': 'Parameter', 'kind': 'op'},
              'C': {'type': 'Identity', 'kind': 'op', 'op': 'Identity'},
              'D': {'type': 'Identity', 'kind': 'op', 'op': 'Identity'},
              'E': {'type': 'Identity', 'kind': 'op', 'op': 'Identity'},
              }
     edges = [
         ('A', 'C', {'in': 0, 'out': 0}),
         ('B', 'C', {'in': 1, 'out': 0}),
         ('C', 'D', {'in': 0, 'out': 0}),
         ('C', 'E', {'in': 0, 'out': 1}),
     ]
     graph = build_graph_with_edge_attrs(nodes, edges)
     # The returned sinks are not inspected; only the graph mutation is checked.
     add_output_ops(graph, output)
     graph.clean_up()
     self.assertEqual(len(graph.nodes()), 2)