def update_input_edges_attrs(graph: Graph, node: Node, added_nodes: list):
    """
    Copy edge attributes from 'old' input edges of node 'node' to new input sub-graph edges.

    An old input edge (old_u -> node) is matched with every edge (new_u -> x) such that:
    x is in 'added_nodes', new_u is NOT in 'added_nodes' (i.e. the edge enters the new sub-graph from
    outside), new_u == old_u, and both edges have the same 'out' attribute. The attributes of every matched
    old edge are merged into the new edge with 'merge_edge_props'.

    :param graph: graph to operate on
    :param node: Node object that was replaced.
    :param added_nodes: list of nodes names added.
    :return: None
    """
    added_nodes_set = set(added_nodes)

    # The edges entering the sub-graph from outside do not depend on the old edge being processed, so collect
    # them once instead of re-querying the graph (and re-scanning 'added_nodes') for every old edge.
    # 'added_nodes' itself is still passed to 'in_edges' so the edge iteration order is unchanged.
    external_in_edges = [(new_u, new_edge_attrs)
                         for new_u, _, new_edge_attrs in graph.in_edges(added_nodes, data=True)
                         if new_u not in added_nodes_set]

    for old_u, _, old_edge_attrs in graph.in_edges(node.id, data=True):
        for new_u, new_edge_attrs in external_in_edges:
            if old_u == new_u and old_edge_attrs['out'] == new_edge_attrs['out']:
                # The return value is ignored: this relies on 'merge_edge_props' updating 'new_edge_attrs'
                # (the graph's edge attribute dict) in place.
                merge_edge_props(new_edge_attrs, old_edge_attrs)  # copy old edge attributes
def merge_nodes(graph: Graph, nodes_to_merge_names: list, inputs_desc: list = None, outputs_desc: list = None):
    """
    Merges nodes specified in the set 'nodes_to_merge_names' into one mega-node, creating new edges between mega-node
    and inputs/outputs nodes of the mega-node. The added edges contain name of input/output nodes which will be used for
    generation of placeholders and will be saved to the IR xml so IE plug-in know how to map input/output data for the
    layer. Also the function adds protobufs of the nodes of the sub-graph and 'Const' ops consumed by nodes in the
    sub-graph to the node's attribute 'pbs'.

    If the nodes do not form a connected component, a warning is logged and the nodes are dumped for Graphviz, but the
    merge still proceeds. After the merge, every node from 'nodes_to_merge_names' that is still in the graph is removed.

    :param graph: the graph object to operate on.
    :param nodes_to_merge_names: list of nodes names that should be merged into a single node.
    :param inputs_desc: optional list describing input nodes order.
    :param outputs_desc: optional list describing output nodes order.
    :return: Node object of the newly created mega-node.
    """
    if not is_connected_component(graph, nodes_to_merge_names):
        log.warning("The following nodes do not form connected sub-graph: {}".format(nodes_to_merge_names))
        graph.dump_graph_for_graphviz(nodes_to_dump=nodes_to_merge_names)

    new_node_name = graph.unique_id("TFSubgraphCall_")
    log.info("Create new node with name '{}' for nodes '{}'".format(new_node_name, ', '.join(nodes_to_merge_names)))
    graph.add_node(new_node_name)
    new_node_attrs = graph.node[new_node_name]

    new_node_attrs['name'] = new_node_name
    set_tf_custom_call_node_attrs(new_node_attrs)
    new_node = Node(graph, new_node_name)

    # Membership is tested for every edge below, so use a set instead of scanning the list each time.
    # The original list is still used wherever iteration order matters.
    nodes_to_merge_set = set(nodes_to_merge_names)

    added_input_tensors_names = set()  # set of tensors that were added as input to the sub-graph
    added_new_node_output_tensors = dict()  # key - tensor name, value - out port

    for node_name in nodes_to_merge_names:
        node = Node(graph, node_name)
        add_node_pb_if_not_yet_added(node, new_node)
        # TODO: any improvements?
        for in_node_name, edge_attrs in node.get_inputs():
            # internal edges between nodes of the sub-graph
            if in_node_name in nodes_to_merge_set:
                add_node_pb_if_not_yet_added(Node(graph, in_node_name), new_node)
                continue

            # edge outside of sub-graph into sub-graph
            # we cannot use the 'in_node_name' as a protobuf operation name here
            # because the 'in_node_name' could be a sub-graph matched before.
            input_tensor_name = node.pb.input[edge_attrs['in']]
            if input_tensor_name in added_input_tensors_names:
                continue  # this external tensor is already connected to the mega-node

            graph.add_edge(in_node_name, new_node_name,
                           **merge_edge_props(
                               {'in': find_input_port(new_node, inputs_desc, node_name, edge_attrs['in']),
                                'out': edge_attrs['out'],
                                'internal_input_node_name': input_tensor_name,
                                'original_dst_node_name': node_name,
                                'original_dst_port': edge_attrs['in'],
                                'in_attrs': ['in', 'internal_input_node_name', 'original_dst_node_name',
                                             'original_dst_port', 'placeholder_name'],
                                'out_attrs': ['out']},
                               edge_attrs)
                           )
            log.debug("Creating edge from outside of sub-graph to inside sub-graph: {} -> {}".format(
                in_node_name, new_node_name))
            added_input_tensors_names.add(input_tensor_name)

        # edge from inside sub-graph to outside sub-graph
        for out_node_name, edge_attrs in node.get_outputs():
            if out_node_name in nodes_to_merge_set:
                continue
            log.debug("Creating edge from inside of sub-graph to outside sub-graph: {} -> {}".format(
                new_node_name, out_node_name))
            out_name = internal_output_name_for_node(node_name, edge_attrs['out'])
            # the same internal tensor may feed several consumers; allocate its mega-node out port only once
            if out_name not in added_new_node_output_tensors:
                added_new_node_output_tensors[out_name] = find_output_port(new_node, outputs_desc, node_name,
                                                                           edge_attrs['out'])

            graph.add_edge(new_node_name, out_node_name,
                           **merge_edge_props(
                               {'in': edge_attrs['in'],
                                'out': added_new_node_output_tensors[out_name],
                                'internal_output_node_name': out_name,
                                'in_attrs': ['in', 'internal_input_node_name'],
                                'out_attrs': ['out', 'internal_output_node_name']},
                               edge_attrs)
                           )

    # Invert (tensor name -> out port) into (out port -> tensor name) so each port appears once; if several tensors
    # share a port, the last one wins. The resulting list follows the order in which ports were first assigned.
    # NOTE(review): this equals port order only if find_output_port hands ports out in increasing order — confirm.
    new_node['output_tensors_names'] = list(
        {port: tensor for tensor, port in added_new_node_output_tensors.items()}.values())

    # add nodes using the same order as in initial GraphDef so we can dump them to IR in "correct" order
    new_node['nodes_order'] = [name for name in graph.graph['initial_nodes_order'] if name in new_node['pbs']]

    for name in nodes_to_merge_names:
        if graph.has_node(name):  # check if not deleted by another (similar) pattern
            graph.remove_node(name)
    return Node(graph, new_node_name)