def update_input_edges_attrs(graph: Graph, node: Node, added_nodes: list):
    """
    Propagate attributes of the replaced node's input edges onto the matching input edges of the new sub-graph.

    An edge entering one of 'added_nodes' matches an 'old' edge entering 'node' when its source is not itself one of
    'added_nodes', both edges start at the same source node and both carry the same 'out' value. For every such pair
    'merge_edge_props' is called with the new edge attributes first and the old ones second; its return value is
    ignored (presumably the new edge attributes are updated in place -- verify against 'merge_edge_props').
    :param graph: graph to operate on
    :param node: Node object that was replaced.
    :param added_nodes: list of nodes names added.
    :return: None
    """
    for src_old, _, attrs_old in graph.in_edges(node.id, data=True):
        for src_new, _, attrs_new in graph.in_edges(added_nodes, data=True):
            # only edges that enter the new sub-graph from outside are candidates
            comes_from_outside = src_new not in added_nodes
            if comes_from_outside and src_old == src_new and attrs_old['out'] == attrs_new['out']:
                merge_edge_props(attrs_new, attrs_old)
Esempio n. 2
0
def merge_nodes(graph: Graph, nodes_to_merge_names: list, inputs_desc: list = None,
                outputs_desc: list = None):
    """
    Merges nodes specified in the list 'nodes_to_merge_names' into one mega-node, creating new edges between mega-node
    and inputs/outputs nodes of the mega-node.

    Steps performed:
      1. If the nodes do not form a connected component, a warning is logged and the graph is dumped for graphviz;
         the merge still proceeds.
      2. A new node with a unique 'TFSubgraphCall_' id is added and initialised with 'set_tf_custom_call_node_attrs'.
      3. Every merged node, and every producer of a merged node that is itself merged, is passed to
         'add_node_pb_if_not_yet_added' together with the new node.
      4. For every edge entering the sub-graph from outside, one edge to the new node is created per distinct input
         tensor name (taken from 'node.pb.input'). The edge records the internal tensor name and the original
         destination node/port.
      5. For every edge leaving the sub-graph, an edge from the new node is created. The new node output port is
         obtained from 'find_output_port' once per internal output name and then reused.
      6. 'output_tensors_names' and 'nodes_order' are stored on the new node and the merged nodes are removed from
         the graph.

    :param graph: the graph object to operate on.
    :param nodes_to_merge_names: list of nodes names that should be merged into a single node.
    :param inputs_desc: optional list describing input nodes order.
    :param outputs_desc: optional list describing output nodes order.
    :return: Node object wrapping the newly created mega-node.
    """
    if not is_connected_component(graph, nodes_to_merge_names):
        log.warning("The following nodes do not form connected sub-graph: {}".format(nodes_to_merge_names))
        graph.dump_graph_for_graphviz(nodes_to_dump=nodes_to_merge_names)

    new_node_name = graph.unique_id("TFSubgraphCall_")
    log.info("Create new node with name '{}' for nodes '{}'".format(new_node_name, ', '.join(nodes_to_merge_names)))
    graph.add_node(new_node_name)
    new_node_attrs = graph.node[new_node_name]

    new_node_attrs['name'] = new_node_name
    set_tf_custom_call_node_attrs(new_node_attrs)
    new_node = Node(graph, new_node_name)

    # membership is tested for every edge of every merged node, so build a set once instead of scanning the list
    merged_names = set(nodes_to_merge_names)
    added_input_tensors_names = set()  # set of tensors that were already added as input to the sub-graph
    added_new_node_output_tensors = dict()  # key - tensor name, value - out port

    for node_name in nodes_to_merge_names:
        node = Node(graph, node_name)
        add_node_pb_if_not_yet_added(node, new_node)

        for in_node_name, edge_attrs in node.get_inputs():
            # internal edges between nodes of the sub-graph
            if in_node_name in merged_names:
                add_node_pb_if_not_yet_added(Node(graph, in_node_name), new_node)
                continue

            # edge outside of sub-graph into sub-graph:
            # we cannot use the 'in_node_name' as a protobuf operation name here
            # because the 'in_node_name' could be a sub-graph matched before.
            input_tensor_name = node.pb.input[edge_attrs['in']]
            if input_tensor_name in added_input_tensors_names:
                continue  # this tensor already feeds the mega-node

            input_edge_attrs = {
                'in': find_input_port(new_node, inputs_desc, node_name, edge_attrs['in']),
                'out': edge_attrs['out'],
                'internal_input_node_name': input_tensor_name,
                'original_dst_node_name': node_name,
                'original_dst_port': edge_attrs['in'],
                'in_attrs': ['in', 'internal_input_node_name', 'original_dst_node_name',
                             'original_dst_port', 'placeholder_name'],
                'out_attrs': ['out'],
            }
            graph.add_edge(in_node_name, new_node_name, **merge_edge_props(input_edge_attrs, edge_attrs))
            log.debug("Creating edge from outside of sub-graph to inside sub-graph: {} -> {}".format(
                in_node_name, new_node_name))
            added_input_tensors_names.add(input_tensor_name)

        # edge from inside sub-graph to outside sub-graph
        for out_node_name, edge_attrs in node.get_outputs():
            if out_node_name in merged_names:
                continue
            log.debug("Creating edge from inside of sub-graph to outside sub-graph: {} -> {}".format(
                new_node_name, out_node_name))
            out_name = internal_output_name_for_node(node_name, edge_attrs['out'])
            # resolve the mega-node output port only once per internal output and reuse it for all consumers
            if out_name not in added_new_node_output_tensors:
                added_new_node_output_tensors[out_name] = find_output_port(new_node, outputs_desc, node_name,
                                                                           edge_attrs['out'])
            output_edge_attrs = {
                'in': edge_attrs['in'],
                'out': added_new_node_output_tensors[out_name],
                'internal_output_node_name': out_name,
                'in_attrs': ['in', 'internal_input_node_name'],
                'out_attrs': ['out', 'internal_output_node_name'],
            }
            graph.add_edge(new_node_name, out_node_name, **merge_edge_props(output_edge_attrs, edge_attrs))

    # Computed once, after all output edges are known. Inverting the mapping keeps a single tensor name per output
    # port (the last one registered for that port), in the order in which the ports were first seen.
    port_to_tensor_name = {port: tensor_name for tensor_name, port in added_new_node_output_tensors.items()}
    new_node['output_tensors_names'] = list(port_to_tensor_name.values())

    # add nodes using the same order as in initial GraphDef so we can dump them to IR in "correct" order
    collected_pbs = new_node['pbs']
    new_node['nodes_order'] = [name for name in graph.graph['initial_nodes_order'] if name in collected_pbs]

    for merged_name in nodes_to_merge_names:
        if graph.has_node(merged_name):  # check if not deleted by another (similar) pattern
            graph.remove_node(merged_name)
    return Node(graph, new_node_name)