def generate_pattern_for_node(graph: Graph, sub_graph_pattern: str, node_name: str):
    """
    Build a name pattern that identifies a node relative to the sub-graph selected by a scope pattern.

    The node name is split on '/' and its components are accumulated, each followed by '/', until the
    accumulated prefix matches 'sub_graph_pattern'. The components after that point, joined with '/' and
    terminated with '$', form the candidate pattern. The candidate is returned only when the scope pattern
    followed by the candidate matches exactly one of the nodes returned by 'nodes_matching_name_pattern'.

    :param graph: graph that is passed to 'nodes_matching_name_pattern' to collect the sub-graph nodes.
    :param sub_graph_pattern: regular expression describing the sub-graph scope; empty string means no scope.
    :param node_name: full name of the node to generate the pattern for.
    :return: 'node_name' unchanged when the scope pattern is empty, otherwise the '$'-terminated name suffix.
    :raises RuntimeError: if no accumulated prefix of the node name matches the scope pattern, or if the
        generated suffix does not match exactly one node of the sub-graph.
    """
    if sub_graph_pattern == '':
        return node_name

    scope_regex = compile(sub_graph_pattern)
    name_parts = node_name.split("/")

    # Grow the prefix one component at a time; 'scope_end' is left at the index of the first component
    # at which the prefix matches the scope, and the new pattern starts from the component after it.
    prefix = ''
    for scope_end, part in enumerate(name_parts):
        prefix = '{}{}/'.format(prefix, part)
        if match(scope_regex, prefix):
            break
    else:
        raise RuntimeError('Node name "{}" does not match pattern "{}"'.format(node_name, sub_graph_pattern))

    # The scope pattern is non-empty here, so only the trailing separator has to be ensured.
    if not sub_graph_pattern.endswith('/'):
        sub_graph_pattern += '/'

    scope_nodes = nodes_matching_name_pattern(graph, sub_graph_pattern)
    name_suffix = '/'.join(name_parts[scope_end + 1:]) + '$'
    full_pattern = sub_graph_pattern + name_suffix
    hits = sum(1 for candidate in scope_nodes if match(full_pattern, candidate))
    if hits == 1:
        return name_suffix

    raise RuntimeError('The pattern that uniquely identifies node "{}" using sub-graph pattern "{}" has not been found'.
                       format(node_name, sub_graph_pattern))
def update_custom_replacement_attributes(self, graph: Graph):
    """
    Fill in the 'inputs' and 'outputs' of the replacement description when they are missing or empty.

    The nodes of the sub-graph are taken from the first instance pattern. Edges entering the sub-graph are
    grouped by the producing tensor and become inputs; edges leaving the sub-graph, and sub-graph nodes
    without output edges whose op is not 'Const', become outputs.

    :param graph: graph to analyze.
    :raises Error: if the replacement has no 'instances' attribute or the list of instances is empty.
    """
    has_instances = self.has('instances') and len(self.instances) != 0
    if not has_instances:
        raise Error("No instances are defined for replacement with id '{}'. ".format(self.replacement_id) +
                    refer_to_faq_msg(68))

    # TODO verify that all instances will produce the same sub-graph
    scope = self.instances[0]  # only the first instance pattern is used to derive input/output patterns
    scope_nodes = nodes_matching_name_pattern(graph, scope)

    consumers_by_tensor = dict()  # "<producer name>:<out port>" -> list of (consumer node pattern, in port)
    produced = set()  # pairs (node pattern, out port)
    for producer, consumer, attrs in graph.edges(data=True):
        producer_inside = producer in scope_nodes
        consumer_inside = consumer in scope_nodes

        # edge outside sub-graph into sub-graph
        if consumer_inside and not producer_inside:
            tensor = producer + ":" + str(attrs['out'])
            consumers_by_tensor.setdefault(tensor, list()).append(
                (generate_pattern_for_node(graph, scope, consumer), attrs['in']))

        # edge from inside sub-graph to outside sub-graph
        if producer_inside and not consumer_inside:
            # NOTE(review): the node name is read from the consumer's 'pb' input list rather than taken from
            # the edge source; presumably both name the same producer node - confirm.
            tensor_source = graph.node[consumer]['pb'].input[attrs['in']]
            produced.add((generate_pattern_for_node(graph, scope, tensor_source), attrs['out']))

    for name in graph.nodes():
        candidate = Node(graph, name)
        if name not in scope_nodes:
            continue
        if len(candidate.out_nodes()) != 0:
            continue
        if candidate['pb'].op == 'Const':
            continue
        log.debug("Node {} doesn't have output edges. Consider it output".format(name))
        produced.add((generate_pattern_for_node(graph, scope, name), 0))

    inputs_missing = not self.has('inputs') or len(self._replacement_desc['inputs']) == 0
    if inputs_missing:
        described_inputs = []
        for group in sorted(consumers_by_tensor.values()):
            described_inputs.append([{'node': node_pattern, 'port': in_port} for node_pattern, in_port in group])
        self._replacement_desc['inputs'] = described_inputs
        log.debug('Updated inputs of sub-graph for instance "{}"'.format(self.instances))

    outputs_missing = not self.has('outputs') or len(self._replacement_desc['outputs']) == 0
    if outputs_missing:
        described_outputs = []
        for node_pattern, out_port in sorted(produced):
            described_outputs.append({'node': node_pattern, 'port': out_port})
        self._replacement_desc['outputs'] = described_outputs
        log.debug('Updated outputs of sub-graph for instance "{}"'.format(self.instances))
def _match_sub_graph_for_scope(self, graph: Graph, scope_pattern: str):
    """
    Match the sub-graph defined by the scope pattern and describe it with a SubgraphMatch object.

    Before matching, every input node pattern from the replacement description is looked up among the
    graph nodes (prefixed with '.*'); if any of them matches nothing the sub-graph is considered not found.

    :param graph: networkx graph to find sub-graph in.
    :param scope_pattern: regular expression specifying sub-graph scope.
    :return: an object describing matched sub-graph, or None if an input node pattern matches no node of
        the graph or no nodes match the scope pattern.
    """
    inputs_order = self.replacement_desc.get_inputs_description()
    outputs_order = self.replacement_desc.get_outputs_description()

    for list_nodes in inputs_order:
        # the port of the input is not needed for the existence check
        for node_name_pattern, _port in list_nodes:
            if len(find_object_by_pattern(graph.nodes(), '.*' + node_name_pattern)) == 0:
                # NOTE(review): the value reported as "scope" is the replacement description id, not
                # 'scope_pattern' - confirm this is intended.
                log.info('Node "{}" does not exist in the graph. Failed to match sub-graph by scope "{}".'.format(
                    node_name_pattern, self.replacement_desc.id))
                return None

    matched_nodes = nodes_matching_name_pattern(graph, scope_pattern)
    if len(matched_nodes) == 0:
        log.info('There are no instances of the sub-graph by scope "{}"'.format(scope_pattern))
        return None

    return SubgraphMatch(graph, self.replacement_desc, matched_nodes, inputs_order, outputs_order, scope_pattern)