def _match_sub_graph_for_scope(self, graph: Graph, scope_pattern: str):
    """
    Match the replacement's sub-graph in the graph using a scope pattern.

    First verifies that every input node pattern of the replacement description matches at least one
    node of the graph. Then collects the nodes whose names match the scope pattern.

    :param graph: networkx graph to find sub-graph in.
    :param scope_pattern: regular expression specifying sub-graph scope.
    :return: a SubgraphMatch object describing matched sub-graph, or None if some input node pattern has
        no matching node or no node matches the scope pattern.
    """
    inputs_order = self.replacement_desc.get_inputs_description()
    outputs_order = self.replacement_desc.get_outputs_description()

    # Each input description is a list of (node name pattern, port) pairs; only the pattern is checked here.
    # The '.*' prefix lets the pattern match past the beginning of the node name (assuming
    # find_object_by_pattern uses start-anchored matching — TODO confirm).
    for list_nodes in inputs_order:
        for node_name_pattern, _ in list_nodes:
            if not find_object_by_pattern(graph.nodes(), '.*' + node_name_pattern):
                log.info('Node "{}" does not exist in the graph. Failed to match sub-graph by scope "{}".'.format(
                    node_name_pattern, self.replacement_desc.id))
                return None

    matched_nodes = nodes_matching_name_pattern(graph, scope_pattern)
    if not matched_nodes:
        log.info('There are no instances of the sub-graph by scope "{}"'.format(scope_pattern))
        return None

    return SubgraphMatch(graph, self.replacement_desc, matched_nodes, inputs_order, outputs_order, scope_pattern)
def generate_pattern_for_node(graph: Graph, sub_graph_pattern: str, node_name: str):
    """
    Generate a regular expression identifying a node relative to a sub-graph scope.

    The node name is split into '/'-separated components. The shortest prefix of components, each followed
    by '/', that matches ``sub_graph_pattern`` is located. The remaining components are joined with '/' and
    terminated with '$' to form the returned pattern. The pattern is returned only if exactly one node of the
    sub-graph matches ``sub_graph_pattern`` (normalized to end with '/') followed by that suffix.

    :param graph: graph containing the node.
    :param sub_graph_pattern: regular expression describing the sub-graph scope; if empty, ``node_name`` is
        returned unchanged.
    :param node_name: full name of the node to generate the pattern for.
    :return: the node name suffix pattern relative to the sub-graph scope.
    :raises RuntimeError: if the node name does not match the sub-graph pattern, or if the generated pattern
        does not identify exactly one node of the sub-graph.
    """
    if sub_graph_pattern == '':
        return node_name

    node_name_components = node_name.split("/")
    compiled_pattern = compile(sub_graph_pattern)

    # Grow the name prefix one component at a time until it matches the scope pattern.
    matched_index = None  # index of the node name component to start new pattern from
    cur_name = ''
    for index, component in enumerate(node_name_components):
        cur_name += component + "/"
        if compiled_pattern.match(cur_name):
            matched_index = index
            break
    if matched_index is None:
        raise RuntimeError('Node name "{}" does not match pattern "{}"'.format(node_name, sub_graph_pattern))

    # sub_graph_pattern is known to be non-empty here (empty returned early), so only the trailing
    # separator has to be ensured.
    if not sub_graph_pattern.endswith('/'):
        sub_graph_pattern += '/'

    sub_graph_nodes = nodes_matching_name_pattern(graph, sub_graph_pattern)
    name_suffix = '/'.join(node_name_components[matched_index + 1:]) + '$'
    # The full pattern is loop-invariant: compile it once instead of once per sub-graph node.
    suffix_pattern = compile(sub_graph_pattern + name_suffix)
    if sum(1 for node in sub_graph_nodes if suffix_pattern.match(node)) == 1:
        return name_suffix

    raise RuntimeError(
        'The pattern that uniquely identifies node "{}" using sub-graph pattern "{}" has not been found'
        .format(node_name, sub_graph_pattern))
def replace_subgraph_calls(graph: Graph, patterns_string: str):
    """
    Replace sub-graphs defined by node name patterns with single nodes executed using TensorFlow.

    The patterns are applied independently, so N patterns produce up to N TensorFlow call nodes. Empty
    patterns (for example produced by a trailing comma) are skipped with a warning. After each merge the graph
    is checked for cycles; if the final graph contains a cycle, it is dumped for Graphviz and an error is logged.

    :param graph: networkX graph to operate on.
    :param patterns_string: comma separated list of node names patterns.
    """
    cycle_exist = False
    for pattern in patterns_string.split(','):
        if pattern == '':
            # An empty regular expression would presumably match every node name and merge the whole graph
            # into a single node, which is almost certainly not what the user asked for.
            log.warning("Skipping empty node name pattern in '{}'".format(patterns_string))
            continue

        log.info("Merging nodes using pattern '{}'".format(pattern))
        matched_nodes = nodes_matching_name_pattern(graph, pattern)
        if not matched_nodes:
            continue

        merge_nodes(graph, matched_nodes)
        try:
            # the function 'find_cycle' raises exception if the cycle is not found
            nx.find_cycle(graph)
        except nx.exception.NetworkXNoCycle:
            cycle_exist = False
        else:
            cycle_exist = True
            log.warning("Graph contains a cycle after merging nodes using pattern '{}'".format(pattern))

    # The graph only changes on merges, so the flag from the most recent merge describes the final graph.
    if cycle_exist:
        graph.dump_graph_for_graphviz()
        log.error('graph contains cycle after applying all merge node patterns')
def update_custom_replacement_attributes(self, graph: Graph):
    """
    Fill in missing 'inputs' and 'outputs' of the replacement description by analyzing the graph.

    The first instance pattern selects the sub-graph nodes. The inputs come from edges entering the
    sub-graph, grouped by the source tensor name "<src_node>:<out_port>".

    The outputs come from two sources. The first is edges leaving the sub-graph. The second is matched
    non-'Const' nodes that have no output edges. Node references are stored as patterns produced by
    generate_pattern_for_node. Existing 'inputs'/'outputs' entries are left untouched.

    :param graph: graph to analyze.
    :raises Error: if the replacement description defines no instances.
    """
    if not self.has('instances') or len(self.instances) == 0:
        raise Error("No instances are defined for replacement with id '{}'. ".format(self.replacement_id) +
                    refer_to_faq_msg(68))

    pattern = self.instances[0]  # use the first instance pattern to find input/output nodes patterns
    # TODO verify that all instances will produce the same sub-graph
    # A set makes the per-edge membership tests below O(1) instead of a linear scan.
    matched_node_set = set(nodes_matching_name_pattern(graph, pattern))

    output_tensors = set()  # pairs (node pattern, output port)
    # key is the input tensor name "<src_node>:<out_port>", value is the list of (node pattern, input port)
    # pairs of the sub-graph nodes consuming that tensor
    input_nodes_mapping = dict()
    for src_node_name, dst_node_name, edge_attrs in graph.edges(data=True):
        src_inside = src_node_name in matched_node_set
        dst_inside = dst_node_name in matched_node_set

        if not src_inside and dst_inside:
            # edge outside sub-graph into sub-graph
            tensor_name = src_node_name + ":" + str(edge_attrs['out'])
            input_nodes_mapping.setdefault(tensor_name, []).append(
                (generate_pattern_for_node(graph, pattern, dst_node_name), edge_attrs['in']))
        elif src_inside and not dst_inside:
            # edge from inside sub-graph to outside sub-graph
            dst_node = graph.node[dst_node_name]
            # NOTE(review): the node pattern is built from the consumer's protobuf input name (which may carry a
            # ':<port>' suffix), while the port comes from the edge; src_node_name may be what is intended — confirm.
            output_tensors.add((generate_pattern_for_node(graph, pattern, dst_node['pb'].input[edge_attrs['in']]),
                                edge_attrs['out']))

    for node_name in graph.nodes():
        if node_name not in matched_node_set:
            continue
        node = Node(graph, node_name)
        # Sub-graph nodes without consumers are outputs too, except constants.
        if len(node.out_nodes()) == 0 and node['pb'].op != 'Const':
            log.debug("Node {} doesn't have output edges. Consider it output".format(node_name))
            output_tensors.add((generate_pattern_for_node(graph, pattern, node_name), 0))

    if not self.has('inputs'):
        self._replacement_desc['inputs'] = [[{'node': desc[0], 'port': desc[1]} for desc in inp]
                                            for inp in sorted(input_nodes_mapping.values())]
        log.debug('Updated inputs of sub-graph for instance "{}"'.format(self.instances))

    if not self.has('outputs'):
        self._replacement_desc['outputs'] = [{'node': node, 'port': port} for node, port in sorted(output_tensors)]
        log.debug('Updated outputs of sub-graph for instance "{}"'.format(self.instances))