def distributed_induction(graph: nx.MultiDiGraph, sample: nx.MultiDiGraph,
                          partition_map: PartitionMap,
                          ownership: Set[Vertex]):
    """
    Add to ``sample`` the edges of ``graph`` induced by the sampled nodes.

    Candidate edges are those of ``graph`` that are not yet in ``sample`` but
    whose source node is. Each candidate is assigned to one randomly chosen
    owner (as reported by ``partition_map.get_owners``) of its target node.
    Candidates assigned to this rank are resolved locally; the rest are
    forwarded to ``query_inductions``.

    :param graph: the graph whose edges are examined.
    :param sample: the sampled graph; edges are added to it in place.
    :param partition_map: object providing ``get_owners(node)``.
    :param ownership: collection of nodes tested with ``in`` to decide
        whether a locally assigned edge is added to ``sample``.
    """
    # Step 1: bucket every candidate edge by the rank that will answer for it
    queries_by_rank = [[] for _ in range(mpi.size)]
    for candidate in graph.edges:
        if sample.has_edge(*candidate) or not sample.has_node(candidate[0]):
            continue
        target_owners = partition_map.get_owners(candidate[1])
        # only one of the owners is asked, picked at random
        chosen_rank = random.choice(target_owners)
        queries_by_rank[chosen_rank].append(candidate)

    # Step 2: edges assigned to this rank are resolved without communication
    local_queries = queries_by_rank[mpi.rank]
    for candidate in local_queries:
        if candidate[1] in ownership:
            sample.add_edge(*candidate)
    local_queries.clear()

    # Step 3: hand the remaining per-rank queries over to query_inductions
    query_inductions(sample, queries_by_rank, ownership)
def fuse_sequence_of_reshapes(graph: nx.MultiDiGraph):
    """
    Remove the first Reshape of every ``Reshape -> data -> Reshape`` chain.

    A node qualifies when it is a 'Reshape' with exactly one output node, that
    output is a 'data' node with exactly one consumer, and the consumer is
    itself a 'Reshape'. The qualifying node is removed through
    ``remove_op_node_with_data_node``.

    :param graph: the graph to modify in place.
    """
    # Iterate over a snapshot: nodes are removed from the graph while looping.
    for node_id in list(graph.nodes()):
        node = Node(graph, node_id)
        if not graph.has_node(node.id):
            # data node can be already removed
            continue

        if not (node.has_valid('type') and node.type == 'Reshape'):
            continue
        if len(node.out_nodes()) != 1:
            continue

        data_node = node.out_node()
        if not (data_node.has_valid('kind') and data_node.kind == 'data'):
            continue
        if len(data_node.out_nodes()) != 1:
            continue

        log.debug('First phase for Reshape: {}'.format(node.name))
        next_op = data_node.out_node()
        log.debug('second node: {}'.format(next_op.graph.node[next_op.id]))

        if not next_op.has_valid('type'):
            continue
        if next_op.type == 'Reshape':
            # Reshape1 --> data --> Reshape2 with no side edges: drop Reshape1
            log.debug('Second phase for Reshape: {}'.format(node.name))
            remove_op_node_with_data_node(graph, node)
class NxGraph(BaseGraph):
    """
    NxGraph is a wrapper that provides methods to interact with a networkx.MultiDiGraph.

    NxGraph extends kgx.graph.base_graph.BaseGraph and implements all the methods from BaseGraph.
    """

    def __init__(self):
        super().__init__()
        self.graph = MultiDiGraph()
        self.name = None

    def add_node(self, node: str, **kwargs: Any) -> None:
        """
        Add a node to the graph.

        Parameters
        ----------
        node: str
            Node identifier
        **kwargs: Any
            Any additional node properties. If a ``data`` keyword is given,
            its value is used as the property dictionary and all other
            keywords are ignored.

        """
        if "data" in kwargs:
            data = kwargs["data"]
        else:
            data = kwargs
        self.graph.add_node(node, **data)

    def add_edge(self, subject_node: str, object_node: str, edge_key: str = None, **kwargs: Any) -> None:
        """
        Add an edge to the graph.

        Parameters
        ----------
        subject_node: str
            The subject (source) node
        object_node: str
            The object (target) node
        edge_key: Optional[str]
            The edge key
        kwargs: Any
            Any additional edge properties. If a ``data`` keyword is given,
            its value is used as the property dictionary and all other
            keywords are ignored.

        """
        if "data" in kwargs:
            data = kwargs["data"]
        else:
            data = kwargs
        # The networkx return value is passed through to the caller unchanged,
        # despite the ``None`` annotation on this method.
        return self.graph.add_edge(subject_node, object_node, key=edge_key, **data)

    def add_node_attribute(self, node: str, attr_key: str, attr_value: Any) -> None:
        """
        Add an attribute to a given node.

        Parameters
        ----------
        node: str
            The node identifier
        attr_key: str
            The key for an attribute
        attr_value: Any
            The value corresponding to the key

        """
        self.graph.add_node(node, **{attr_key: attr_value})

    def add_edge_attribute(
        self,
        subject_node: str,
        object_node: str,
        edge_key: Optional[str],
        attr_key: str,
        attr_value: Any,
    ) -> None:
        """
        Add an attribute to a given edge.

        Parameters
        ----------
        subject_node: str
            The subject (source) node
        object_node: str
            The object (target) node
        edge_key: Optional[str]
            The edge key
        attr_key: str
            The attribute key
        attr_value: Any
            The attribute value

        """
        self.graph.add_edge(subject_node, object_node, key=edge_key, **{attr_key: attr_value})

    def update_node_attribute(self, node: str, attr_key: str, attr_value: Any, preserve: bool = False) -> Dict:
        """
        Update an attribute of a given node.

        Parameters
        ----------
        node: str
            The node identifier
        attr_key: str
            The key for an attribute
        attr_value: Any
            The value corresponding to the key
        preserve: bool
            Whether or not to preserve existing values for the given attr_key

        Returns
        -------
        Dict
            A dictionary corresponding to the updated node properties

        Raises
        ------
        KeyError
            If ``node`` is not in the graph

        """
        node_data = self.graph.nodes[node]
        updated = prepare_data_dict(node_data, {attr_key: attr_value}, preserve=preserve)
        self.graph.add_node(node, **updated)
        return updated

    def update_edge_attribute(
        self,
        subject_node: str,
        object_node: str,
        edge_key: Optional[str],
        attr_key: str,
        attr_value: Any,
        preserve: bool = False,
    ) -> Dict:
        """
        Update an attribute of a given edge.

        Parameters
        ----------
        subject_node: str
            The subject (source) node
        object_node: str
            The object (target) node
        edge_key: Optional[str]
            The edge key
        attr_key: str
            The attribute key
        attr_value: Any
            The attribute value
        preserve: bool
            Whether or not to preserve existing values for the given attr_key

        Returns
        -------
        Dict
            A dictionary corresponding to the updated edge properties

        Raises
        ------
        KeyError
            If there is no edge ``(subject_node, object_node, edge_key)``

        """
        # Look the edge up by its exact (subject, object, key) triple. Calling
        # ``self.graph.edges((s, o, k), ...)`` would treat the tuple as a bunch
        # of *nodes* and yield every out-edge of those nodes, so taking the
        # first result could pick up the properties of an unrelated edge.
        edge_data = self.graph.edges[subject_node, object_node, edge_key]
        updated = prepare_data_dict(edge_data, {attr_key: attr_value}, preserve)
        self.graph.add_edge(subject_node, object_node, key=edge_key, **updated)
        return updated

    def get_node(self, node: str) -> Dict:
        """
        Get a node and its properties.

        Parameters
        ----------
        node: str
            The node identifier

        Returns
        -------
        Dict
            The node dictionary, or an empty dictionary if the node does not exist

        """
        n = {}
        if self.graph.has_node(node):
            n = self.graph.nodes[node]
        return n

    def get_edge(self, subject_node: str, object_node: str, edge_key: Optional[str] = None) -> Dict:
        """
        Get an edge and its properties.

        Parameters
        ----------
        subject_node: str
            The subject (source) node
        object_node: str
            The object (target) node
        edge_key: Optional[str]
            The edge key

        Returns
        -------
        Dict
            The edge dictionary, or an empty dictionary if the edge does not exist

        """
        e = {}
        if self.graph.has_edge(subject_node, object_node, edge_key):
            e = self.graph.get_edge_data(subject_node, object_node, edge_key)
        return e

    def nodes(self, data: bool = True) -> Dict:
        """
        Get all nodes in a graph.

        Parameters
        ----------
        data: bool
            Whether or not to fetch node properties

        Returns
        -------
        Dict
            A dictionary of nodes

        """
        return self.graph.nodes(data)

    def edges(self, keys: bool = False, data: bool = True) -> Dict:
        """
        Get all edges in a graph.

        Parameters
        ----------
        keys: bool
            Whether or not to include edge keys
        data: bool
            Whether or not to fetch edge properties

        Returns
        -------
        Dict
            A dictionary of edges

        """
        return self.graph.edges(keys=keys, data=data)

    def in_edges(self, node: str, keys: bool = False, data: bool = False) -> List:
        """
        Get all incoming edges for a given node.

        Parameters
        ----------
        node: str
            The node identifier
        keys: bool
            Whether or not to include edge keys
        data: bool
            Whether or not to fetch edge properties

        Returns
        -------
        List
            A list of edges

        """
        return self.graph.in_edges(node, keys=keys, data=data)

    def out_edges(self, node: str, keys: bool = False, data: bool = False) -> List:
        """
        Get all outgoing edges for a given node.

        Parameters
        ----------
        node: str
            The node identifier
        keys: bool
            Whether or not to include edge keys
        data: bool
            Whether or not to fetch edge properties

        Returns
        -------
        List
            A list of edges

        """
        return self.graph.out_edges(node, keys=keys, data=data)

    def nodes_iter(self) -> Generator:
        """
        Get an iterable to traverse through all the nodes in a graph.

        Returns
        -------
        Generator
            A generator for nodes where each element is a Tuple that
            contains (node_id, node_data)

        """
        yield from self.graph.nodes(data=True)

    def edges_iter(self) -> Generator:
        """
        Get an iterable to traverse through all the edges in a graph.

        Returns
        -------
        Generator
            A generator for edges where each element is a 4-tuple that
            contains (subject, object, edge_key, edge_data)

        """
        for u, v, k, data in self.graph.edges(keys=True, data=True):
            yield u, v, k, data

    def remove_node(self, node: str) -> None:
        """
        Remove a given node from the graph.

        Parameters
        ----------
        node: str
            The node identifier

        """
        self.graph.remove_node(node)

    def remove_edge(self, subject_node: str, object_node: str, edge_key: Optional[str] = None) -> None:
        """
        Remove a given edge from the graph.

        Parameters
        ----------
        subject_node: str
            The subject (source) node
        object_node: str
            The object (target) node
        edge_key: Optional[str]
            The edge key

        """
        self.graph.remove_edge(subject_node, object_node, edge_key)

    def has_node(self, node: str) -> bool:
        """
        Check whether a given node exists in the graph.

        Parameters
        ----------
        node: str
            The node identifier

        Returns
        -------
        bool
            Whether or not the given node exists

        """
        return self.graph.has_node(node)

    def has_edge(self, subject_node: str, object_node: str, edge_key: Optional[str] = None) -> bool:
        """
        Check whether a given edge exists in the graph.

        Parameters
        ----------
        subject_node: str
            The subject (source) node
        object_node: str
            The object (target) node
        edge_key: Optional[str]
            The edge key

        Returns
        -------
        bool
            Whether or not the given edge exists

        """
        return self.graph.has_edge(subject_node, object_node, key=edge_key)

    def number_of_nodes(self) -> int:
        """
        Returns the number of nodes in a graph.

        Returns
        -------
        int

        """
        return self.graph.number_of_nodes()

    def number_of_edges(self) -> int:
        """
        Returns the number of edges in a graph.

        Returns
        -------
        int

        """
        return self.graph.number_of_edges()

    def degree(self):
        """
        Get the degree of all the nodes in a graph.
        """
        return self.graph.degree()

    def clear(self) -> None:
        """
        Remove all the nodes and edges in the graph.
        """
        self.graph.clear()

    @staticmethod
    def set_node_attributes(graph: BaseGraph, attributes: Dict) -> None:
        """
        Set nodes attributes from a dictionary of key-values.

        Parameters
        ----------
        graph: kgx.graph.base_graph.BaseGraph
            The graph to modify
        attributes: Dict
            A dictionary of node identifier to key-value pairs

        """
        return set_node_attributes(graph.graph, attributes)

    @staticmethod
    def set_edge_attributes(graph: BaseGraph, attributes: Dict) -> None:
        """
        Set edge attributes from a dictionary of key-values.

        Parameters
        ----------
        graph: kgx.graph.base_graph.BaseGraph
            The graph to modify
        attributes: Dict
            A dictionary of edge identifier to key-value pairs

        Returns
        -------
        Any

        """
        return set_edge_attributes(graph.graph, attributes)

    @staticmethod
    def get_node_attributes(graph: BaseGraph, attr_key: str) -> Dict:
        """
        Get all nodes that have a value for the given attribute ``attr_key``.

        Parameters
        ----------
        graph: kgx.graph.base_graph.BaseGraph
            The graph to read from
        attr_key: str
            The attribute key

        Returns
        -------
        Dict
            A dictionary where nodes are the keys and the values
            are the attribute values for ``attr_key``

        """
        return get_node_attributes(graph.graph, attr_key)

    @staticmethod
    def get_edge_attributes(graph: BaseGraph, attr_key: str) -> Dict:
        """
        Get all edges that have a value for the given attribute ``attr_key``.

        Parameters
        ----------
        graph: kgx.graph.base_graph.BaseGraph
            The graph to read from
        attr_key: str
            The attribute key

        Returns
        -------
        Dict
            A dictionary where edges are the keys and the values
            are the attribute values for ``attr_key``

        """
        return get_edge_attributes(graph.graph, attr_key)

    @staticmethod
    def relabel_nodes(graph: BaseGraph, mapping: Dict) -> None:
        """
        Relabel identifiers for a series of nodes based on mappings.

        Parameters
        ----------
        graph: kgx.graph.base_graph.BaseGraph
            The graph to modify (relabelled in place)
        mapping: Dict
            A dictionary of mapping where the key is the old identifier
            and the value is the new identifier.

        """
        relabel_nodes(graph.graph, mapping, copy=False)
class MossNet:
    '''A network of MOSS similarity results.

    Nodes are the first- and second-level keys of the results dictionary and
    every link between two nodes is stored as one edge of a NetworkX
    ``MultiDiGraph`` whose ``attr_dict`` attribute holds the ``files``,
    ``left`` and ``right`` entries of that link.
    '''

    def __init__(self, moss_results_dict):
        '''Create a ``MossNet`` object from a 3D dictionary of downloaded MOSS results

        Args:
            ``moss_results_dict`` (``dict``): A 3D dictionary of downloaded MOSS results. A ``MultiDiGraph`` is used as-is, and a ``str`` is treated as the path of a pickled dictionary (gzip-compressed if it ends with ``.gz``)

        Returns:
            ``MossNet``: A ``MossNet`` object

        Raises:
            ``ValueError``: If the dictionary cannot be loaded from the given path

            ``TypeError``: If the input is not a 3D dictionary of MOSS results
        '''
        if isinstance(moss_results_dict, MultiDiGraph):
            self.graph = moss_results_dict
            return
        if isinstance(moss_results_dict, str):
            path = moss_results_dict
            opener = gopen if path.lower().endswith('.gz') else open
            try:
                # SECURITY NOTE: the file is unpickled, which can execute
                # arbitrary code. Only load files from a trusted source.
                with opener(path, 'rb') as infile:
                    moss_results_dict = load(infile)
            except Exception as e:
                raise ValueError("Unable to load dictionary: %s" % path) from e
        bad_type_msg = "moss_results_dict must be a 3D dictionary of MOSS results"
        if not isinstance(moss_results_dict, dict):
            raise TypeError(bad_type_msg)
        self.graph = MultiDiGraph()
        for u in moss_results_dict:
            u_edges = moss_results_dict[u]
            if not isinstance(u_edges, dict):
                raise TypeError(bad_type_msg)
            for v in u_edges:
                u_v_links = u_edges[v]
                if not isinstance(u_v_links, dict):
                    raise TypeError(bad_type_msg)
                for f in u_v_links:
                    try:
                        left, right = u_v_links[f]
                    except (TypeError, ValueError) as e:
                        raise TypeError(bad_type_msg) from e
                    # stored under the literal attribute name 'attr_dict',
                    # which is what every reader in this class looks up
                    self.graph.add_edge(u, v, attr_dict={'files': f, 'left': left, 'right': right})

    def save(self, outfile):
        '''Save this ``MossNet`` object as a 3D dictionary of MOSS results

        Args:
            ``outfile`` (``str``): The desired output file's path (gzip-compressed if it ends with ``.gz``)
        '''
        out = dict()
        for u in self.graph.nodes:
            u_edges = dict()
            out[u] = u_edges
            for v in self.graph.neighbors(u):
                u_v_links = dict()
                u_edges[v] = u_v_links
                u_v_edge_data = self.graph.get_edge_data(u, v)
                for k in u_v_edge_data:
                    edge = u_v_edge_data[k]['attr_dict']
                    u_v_links[edge['files']] = (edge['left'], edge['right'])
        if outfile.lower().endswith('.gz'):
            f = gopen(outfile, mode='wb', compresslevel=9)
        else:
            f = open(outfile, 'wb')
        with f:  # close the file even if pickling fails
            pkldump(out, f)

    def __add__(self, o):
        '''Return a new ``MossNet`` containing the nodes and edges of both operands'''
        if not isinstance(o, MossNet):
            raise TypeError("unsupported operand type(s) for +: 'MossNet' and '%s'" % type(o).__name__)
        g = MultiDiGraph()
        g.add_edges_from(list(self.graph.edges(data=True)) + list(o.graph.edges(data=True)))
        g.add_nodes_from(list(self.graph.nodes(data=True)) + list(o.graph.nodes(data=True)))
        return MossNet(g)

    def get_networkx(self):
        '''Return a NetworkX ``MultiDiGraph`` equivalent to this ``MossNet`` object

        Returns:
            ``MultiDiGraph``: A NetworkX ``MultiDiGraph`` equivalent to this ``MossNet`` object
        '''
        return self.graph.copy()

    def get_nodes(self):
        '''Returns a ``set`` of node labels in this ``MossNet`` object

        Returns:
            ``set``: The node labels in this ``MossNet`` object
        '''
        return set(self.graph.nodes)

    def get_pair(self, u, v, style='tuples'):
        '''Returns the links between nodes ``u`` and ``v``

        Args:
            ``u`` (``str``): A node label

            ``v`` (``str``): A node label not equal to ``u``

            ``style`` (``str``): The representation of a given link

            * ``"tuples"``: Links are ``((u_percent, u_html), (v_percent, v_html))`` tuples

            * ``"html"``: Links are HTML representation (one HTML for all links)

            * ``"htmls"``: Links are HTML representations (one HTML per link)

        Returns:
            ``dict``: The links between ``u`` and ``v`` (keys are filenames). For ``style="html"`` a single HTML ``str`` is returned instead

        Raises:
            ``ValueError``: If ``style`` is invalid, ``u == v``, or either node does not exist
        '''
        if style not in {'tuples', 'html', 'htmls'}:
            raise ValueError("Invalid link style: %s" % style)
        if u == v:
            raise ValueError("u and v cannot be equal: %s" % u)
        for node in [u, v]:
            if not self.graph.has_node(node):
                raise ValueError("Nonexistant node: %s" % node)
        # get_edge_data returns None when the nodes share no edge
        links = self.graph.get_edge_data(u, v) or dict()
        out = dict()
        for k in sorted(links.keys(), key=lambda x: links[x]['attr_dict']['files']):
            d = links[k]['attr_dict']
            u_fn, v_fn = d['files']
            u_percent, u_html = d['left']
            v_percent, v_html = d['right']
            if style == 'tuples':
                out[(u_fn, v_fn)] = ((u_percent, u_html), (v_percent, v_html))
            else:
                out[(u_fn, v_fn)] = (
                    '<html><table style="width:100%%" border="1">'
                    '<tr><td colspan="2"><center><b>%s/%s --- %s/%s</b></center></td></tr>'
                    '<tr><td>%s (%d%%)</td><td>%s (%d%%)</td></tr>'
                    '<tr><td><pre>%s</pre></td><td><pre>%s</pre></td></tr>'
                    '</table></html>'
                ) % (u, u_fn, v, v_fn, u, u_percent, v, v_percent, u_html, v_html)
        if style == 'html':
            # merge the per-link documents into a single HTML document
            out = '<html>' + '<br>'.join(out[fns].replace('<html>', '').replace('</html>', '') for fns in sorted(out.keys())) + '</html>'
        return out

    def get_summary(self, style='html'):
        '''Returns a summary of this ``MossNet``

        Args:
            ``style`` (``str``): The representation of this ``MossNet`` (only ``"html"`` is supported)

        Returns:
            ``str``: An HTML table with one row per link, sorted in descending order of the larger of the two match percentages

        Raises:
            ``ValueError``: If ``style`` is invalid
        '''
        if style not in {'html'}:
            raise ValueError("Invalid summary style: %s" % style)
        matches = list()  # list of (u_path, u_percent, v_path, v_percent) tuples
        for u, v in self.traverse_pairs(order=None):
            links = self.graph.get_edge_data(u, v)
            for k in links:
                d = links[k]['attr_dict']
                u_fn, v_fn = d['files']
                u_percent, _ = d['left']
                v_percent, _ = d['right']
                matches.append(('%s/%s' % (u, u_fn), u_percent, '%s/%s' % (v, v_fn), v_percent))
        matches.sort(reverse=True, key=lambda x: max(x[1], x[3]))
        rows = ''.join(('<tr><td>%s (%d%%)</td><td>%s (%d%%)</td></tr>' % tup) for tup in matches)
        return '<html><table style="width:100%%" border="1">%s</table></html>' % rows

    def num_links(self, u, v):
        '''Returns the number of links between ``u`` and ``v``

        Args:
            ``u`` (``str``): A node label

            ``v`` (``str``): A node label not equal to ``u``

        Returns:
            ``int``: The number of links between ``u`` and ``v`` (``0`` if they share no link)

        Raises:
            ``ValueError``: If either node does not exist
        '''
        for node in [u, v]:
            if not self.graph.has_node(node):
                raise ValueError("Nonexistant node: %s" % node)
        links = self.graph.get_edge_data(u, v)
        # get_edge_data returns None when the nodes share no edge
        return 0 if links is None else len(links)

    def num_nodes(self):
        '''Returns the number of nodes in this ``MossNet`` object

        Returns:
            ``int``: The number of nodes in this ``MossNet`` object
        '''
        return self.graph.number_of_nodes()

    def num_edges(self):
        '''Returns the number of (undirected) edges in this ``MossNet`` object (including parallel edges)

        Returns:
            ``int``: The number of (undirected) edges in this ``MossNet`` object (including parallel edges)
        '''
        # the directed edge count is halved to give the undirected count
        return self.graph.number_of_edges() // 2

    def outlier_pairs(self):
        '''Predict which student pairs are outliers (i.e., too many problem similarities).

        The distribution of number of links between student pairs (i.e., histogram) is modeled as y = A/(B^x), where x = a number of links, and y = the number of student pairs with that many links

        Returns:
            ``list`` of ``tuple``: The ``(number of links, u, v)`` tuples of the student pairs expected to be outliers (in decreasing order of number of links). Empty if there are no pairs or the model cannot be fitted
        '''
        links = dict()  # key = number of links; value = set of student pairs that have that number of links
        for u, v in self.traverse_pairs():
            links.setdefault(self.num_links(u, v), set()).add((u, v))
        if not links:
            return list()
        # x ranges over the observed numbers of links, i.e. the keys of
        # ``links`` (not over the sizes of the pair sets)
        min_links = min(links)
        max_links = max(links)
        mult = list()  # successive ratios y(x)/y(x+1), each an estimate of B
        for i in range(min_links, max_links):
            if i not in links or i + 1 not in links or len(links[i + 1]) > len(links[i]):
                break
            mult.append(float(len(links[i])) / len(links[i + 1]))
        if not mult:
            return list()  # fewer than two consecutive histogram bins: nothing to fit
        B = sum(mult) / len(mult)
        if B <= 1:
            return list()  # flat histogram: y never decays, so there is no cutoff
        A = len(links[min_links]) * (B ** min_links)
        n_cutoff = log(A) / log(B)  # x at which the model predicts y = 1
        out = list()
        for n in sorted(links, reverse=True):
            if n < n_cutoff:
                break
            out.extend((n, u, v) for u, v in links[n])
        return out

    def traverse_pairs(self, order='descending'):
        '''Iterate over student pairs

        Args:
            ``order`` (``str``): Order to sort pairs in iteration

            * ``None`` to not sort (may be faster for large/dense graphs)

            * ``"ascending"`` to sort in ascending order of number of links

            * ``"descending"`` to sort in descending order of number of links

        Raises:
            ``ValueError``: If ``order`` is invalid (raised when iteration starts)
        '''
        if order not in {None, 'None', 'none', 'ascending', 'descending'}:
            raise ValueError("Invalid order: %s" % order)
        # ``u < v`` keeps each linked pair once
        pairs = [(u, v) for u in self.graph.nodes for v in self.graph.neighbors(u) if u < v]
        if order in {'ascending', 'descending'}:
            pairs.sort(key=lambda x: len(self.graph.get_edge_data(x[0], x[1])), reverse=(order == 'descending'))
        yield from pairs

    def export(self, outpath, style='html', gte=0, verbose=False):
        '''Export the links in this ``MossNet`` in the specified style

        Args:
            ``outpath`` (``str``): Path to desired output folder/file

            ``style`` (``str``): Desired output style

            * ``"dot"`` to export as a GraphViz DOT file

            * ``"gexf"`` to export as a Graph Exchange XML Format (GEXF) file

            * ``"html"`` to export one HTML file per pair

            ``gte`` (``int``): The minimum number of links for an edge to be exported

            ``verbose`` (``bool``): ``True`` to show verbose messages, otherwise ``False``

        Raises:
            ``ValueError``: If ``style`` is invalid, ``outpath`` already exists, or ``gte`` is negative

            ``TypeError``: If ``gte`` is not an ``int``

            ``RuntimeError``: If ``style`` is ``"dot"`` or ``"gexf"`` and seaborn cannot be imported
        '''
        if style not in {'dot', 'gexf', 'html'}:
            raise ValueError("Invalid export style: %s" % style)
        if isdir(outpath) or isfile(outpath):
            raise ValueError("Output path exists: %s" % outpath)
        if not isinstance(gte, int):
            raise TypeError("'gte' must be an 'int', but you provided a '%s'" % type(gte).__name__)
        if gte < 0:
            raise ValueError("'gte' must be non-negative, but yours was %d" % gte)

        # export as folder of HTML files
        if style == 'html':
            summary = self.get_summary(style='html')
            pairs = list(self.traverse_pairs(order=None))
            makedirs(outpath)
            with open('%s/summary.html' % outpath, 'w') as f:
                f.write(summary)
            num_exported = 0
            for i, (u, v) in enumerate(pairs):
                if verbose:
                    print("Exporting pair %d of %d..." % (i + 1, len(pairs)), end='\r')
                curr_num_links = self.num_links(u, v)
                if curr_num_links < gte:
                    continue
                with open("%s/%d_%s_%s.html" % (outpath, curr_num_links, u, v), 'w') as f:
                    f.write(self.get_pair(u, v, style='html'))
                num_exported += 1
            if verbose:
                # report the pairs actually written, not the ones skipped by ``gte``
                print("Successfully exported %d pairs" % num_exported)

        # export as GraphViz DOT or a GEXF file
        else:
            if verbose:
                print("Computing colors...", end='')
            max_links = max(self.num_links(u, v) for u, v in self.traverse_pairs())
            try:
                from seaborn import color_palette
            except ImportError as e:
                raise RuntimeError("Exporting as a DOT or GEXF file currently requires seaborn") from e
            pal = color_palette("Reds", max_links)  # one shade per possible number of links
            if verbose:
                print(" done")
                print("Computing node information...", end='')
            nodes = list(self.get_nodes())
            index = {u: i for i, u in enumerate(nodes)}
            if verbose:
                print(" done")
                print("Writing output file...", end='')
            with open(outpath, 'w') as outfile:
                if style == 'dot':
                    pal = [str(c).upper() for c in pal.as_hex()]
                    outfile.write("graph G {\n")
                    for u in nodes:
                        outfile.write(' node%d[label="%s"]\n' % (index[u], u))
                    for u, v in self.traverse_pairs():
                        curr_num_links = self.num_links(u, v)
                        if curr_num_links < gte:
                            continue
                        outfile.write(' node%d -- node%d[color="%s"]\n' % (index[u], index[v], pal[curr_num_links - 1]))
                    outfile.write('}\n')
                else:
                    from datetime import datetime
                    pal = [(int(255 * c[0]), int(255 * c[1]), int(255 * c[2])) for c in pal]
                    outfile.write('<?xml version="1.0" encoding="UTF-8"?>\n')
                    outfile.write('<gexf xmlns="http://www.gexf.net/1.3draft" xmlns:viz="http://www.gexf.net/1.3draft/viz">\n')
                    outfile.write(' <meta lastmodifieddate="%s">\n' % datetime.today().strftime('%Y-%m-%d'))
                    outfile.write(' <creator>MossNet</creator>\n')
                    outfile.write(' <description>A MossNet network exported to GEXF</description>\n')
                    outfile.write(' </meta>\n')
                    outfile.write(' <graph mode="static" defaultedgetype="undirected">\n')
                    outfile.write(' <nodes>\n')
                    for u in nodes:
                        outfile.write(' <node id="%d" label="%s"/>\n' % (index[u], u))
                    outfile.write(' </nodes>\n')
                    outfile.write(' <edges>\n')
                    for i, (u, v) in enumerate(self.traverse_pairs()):
                        curr_num_links = self.num_links(u, v)
                        # honour ``gte`` exactly like the DOT export does
                        if curr_num_links == 0 or curr_num_links < gte:
                            continue
                        color = pal[curr_num_links - 1]
                        outfile.write(' <edge id="%d" source="%d" target="%d">\n' % (i, index[u], index[v]))
                        outfile.write(' <viz:color r="%d" g="%d" b="%d"/>\n' % (color[0], color[1], color[2]))
                        outfile.write(' </edge>\n')
                    outfile.write(' </edges>\n')
                    outfile.write(' </graph>\n')
                    outfile.write('</gexf>\n')
            if verbose:
                print(" done")
def merge_nodes(graph: nx.MultiDiGraph, nodes_to_merge_names: list, inputs_desc: list = None,
                outputs_desc: list = None):
    """
    Collapse the nodes listed in 'nodes_to_merge_names' into a single mega-node.

    Edges crossing the sub-graph boundary are re-created between the mega-node
    and the outside input/output nodes. These edges carry the names of the
    internal input/output tensors, which are used to generate placeholders and
    are saved to the IR xml so the IE plug-in knows how to map input/output
    data for the layer. The protobufs of the merged nodes are attached to the
    mega-node through 'add_node_pb_if_not_yet_added'.

    :param graph: the graph object to operate on.
    :param nodes_to_merge_names: list of nodes names that should be merged into a single node.
    :param inputs_desc: optional list describing input nodes order.
    :param outputs_desc: optional list describing output nodes order.
    :return: Node wrapper of the created mega-node.
    """
    if not is_connected_component(graph, nodes_to_merge_names):
        log.warning("The following nodes do not form connected sub-graph: {}".format(nodes_to_merge_names))
        dump_graph_for_graphviz(graph, nodes_to_dump=nodes_to_merge_names)

    new_node_name = unique_id(graph, "TFSubgraphCall_")
    log.info("Create new node with name '{}' for nodes '{}'".format(new_node_name,
                                                                   ', '.join(nodes_to_merge_names)))
    graph.add_node(new_node_name)
    mega_attrs = graph.node[new_node_name]
    mega_attrs['name'] = new_node_name
    set_tf_custom_call_node_attrs(mega_attrs)
    new_node = Node(graph, new_node_name)

    seen_input_tensors = set()  # tensors already wired as inputs of the sub-graph
    out_port_by_tensor = dict()  # key - tensor name, value - out port

    for node_name in nodes_to_merge_names:
        node = Node(graph, node_name)
        add_node_pb_if_not_yet_added(node, new_node)

        for in_node_name, edge_attrs in get_inputs(graph, node_name):
            in_node = Node(graph, in_node_name)

            if in_node_name in nodes_to_merge_names:
                # internal edge between two nodes of the sub-graph
                add_node_pb_if_not_yet_added(in_node, new_node)
                continue

            # Edge from outside of the sub-graph into it. 'in_node_name' cannot
            # serve as the protobuf operation name because it could itself be a
            # sub-graph matched before.
            input_tensor_name = node.pb.input[edge_attrs['in']]
            if input_tensor_name in seen_input_tensors:
                continue

            in_port = find_input_port(new_node, inputs_desc, node_name, edge_attrs['in'])
            incoming_props = {
                'in': in_port,
                'out': edge_attrs['out'],
                'internal_input_node_name': input_tensor_name,
                'original_dst_node_name': node_name,
                'original_dst_port': edge_attrs['in'],
                'in_attrs': ['in', 'internal_input_node_name', 'original_dst_node_name',
                             'original_dst_port', 'placeholder_name'],
                'out_attrs': ['out'],
            }
            graph.add_edge(in_node_name, new_node_name, **merge_edge_props(incoming_props, edge_attrs))
            log.debug("Creating edge from outside of sub-graph to inside sub-graph: {} -> {}".format(
                in_node_name, new_node_name))
            seen_input_tensors.add(input_tensor_name)

        # edges leaving the sub-graph
        for out_node_name, edge_attrs in get_outputs(graph, node_name):
            if out_node_name in nodes_to_merge_names:
                continue

            log.debug("Creating edge from inside of sub-graph to outside sub-graph: {} -> {}".format(
                new_node_name, out_node_name))
            out_name = internal_output_name_for_node(node_name, edge_attrs['out'])
            if out_name not in out_port_by_tensor:
                out_port_by_tensor[out_name] = find_output_port(new_node, outputs_desc, node_name,
                                                                edge_attrs['out'])
            outgoing_props = {
                'in': edge_attrs['in'],
                'out': out_port_by_tensor[out_name],
                'internal_output_node_name': out_name,
                'in_attrs': ['in', 'internal_input_node_name'],
                'out_attrs': ['out', 'internal_output_node_name'],
            }
            graph.add_edge(new_node_name, out_node_name, **merge_edge_props(outgoing_props, edge_attrs))

    # invert the mapping to port -> tensor name so each port contributes one name
    tensor_by_port = {port: tensor for tensor, port in out_port_by_tensor.items()}
    new_node['output_tensors_names'] = list(tensor_by_port.values())

    # add nodes using the same order as in initial GraphDef so we can dump them to IR in "correct" order
    ordered_names = []
    for candidate in graph.graph['initial_nodes_order']:
        if candidate in new_node['pbs'].keys():
            ordered_names.append(candidate)
    new_node['nodes_order'] = ordered_names

    for merged_name in nodes_to_merge_names:
        # the node may already be deleted by another (similar) pattern
        if graph.has_node(merged_name):
            graph.remove_node(merged_name)

    return Node(graph, new_node_name)