示例#1
0
def generate_pattern_for_node(graph: Graph, sub_graph_pattern: str, node_name: str):
    """
    Build a regular-expression suffix that uniquely identifies a node inside a sub-graph scope.

    The node name is split on '/' and its growing prefixes ("a/", "a/b/", ...) are tested against
    ``sub_graph_pattern``; the first prefix that matches marks the end of the sub-graph scope. The
    remaining name components, joined with '/' and terminated with '$', form the returned pattern.

    :param graph: graph containing the sub-graph nodes.
    :param sub_graph_pattern: regular expression describing the sub-graph scope. An empty pattern
        means "no scope": ``node_name`` is then returned unchanged.
    :param node_name: full name of the node to build the pattern for.
    :return: the part of ``node_name`` after the matched scope, suffixed with '$'.
    :raises RuntimeError: if no prefix of ``node_name`` matches ``sub_graph_pattern``, or if the
        resulting pattern does not match exactly one node of the sub-graph.
    """
    if sub_graph_pattern == '':
        return node_name

    node_name_components = node_name.split("/")
    compiled_pattern = compile(sub_graph_pattern)

    # index of the node name component to start the new pattern from
    matched_index = None
    cur_name = ''
    for index, component in enumerate(node_name_components):
        cur_name += component + "/"
        if match(compiled_pattern, cur_name):
            matched_index = index
            break
    if matched_index is None:
        raise RuntimeError('Node name "{}" does not match pattern "{}"'.format(node_name, sub_graph_pattern))

    # The empty-pattern case returned above, so only the trailing separator needs to be ensured.
    if not sub_graph_pattern.endswith('/'):
        sub_graph_pattern += '/'

    sub_graph_nodes = nodes_matching_name_pattern(graph, sub_graph_pattern)
    # NOTE(review): the suffix is built from raw node-name components without regex escaping, so
    # names containing regex metacharacters (e.g. '.') yield a looser pattern — confirm acceptable.
    name_suffix = '/'.join(node_name_components[matched_index + 1:]) + '$'
    full_pattern = sub_graph_pattern + name_suffix
    # The suffix is only usable if it selects exactly one node of the sub-graph.
    if sum(1 for node in sub_graph_nodes if match(full_pattern, node)) == 1:
        return name_suffix

    raise RuntimeError('The pattern that uniquely identifies node "{}" using sub-graph pattern "{}" has not been found'.
                       format(node_name, sub_graph_pattern))
示例#2
0
    def update_custom_replacement_attributes(self, graph: Graph):
        """
        Fill in missing 'inputs' and 'outputs' descriptions of this replacement from the graph.

        The first pattern in ``self.instances`` selects the sub-graph nodes.
        - Every edge entering the sub-graph from outside contributes an input description.
        - Every edge leaving it contributes an output description.
        - Every sub-graph node that has no output nodes and whose 'pb' op is not 'Const' is also
          recorded as an output with port 0.

        Node names are stored as patterns relative to the sub-graph scope, as produced by
        ``generate_pattern_for_node``. Existing non-empty 'inputs'/'outputs' descriptions are
        left untouched.

        :param graph: graph to analyse.
        :raises Error: if no instances are defined for this replacement.
        """
        if not self.has('instances') or len(self.instances) == 0:
            raise Error("No instances are defined for replacement with id '{}'. ".format(self.replacement_id) +
                        refer_to_faq_msg(68))

        pattern = self.instances[0]  # use the first instance pattern to find input/output nodes patterns
        # TODO verify that all instances will produce the same sub-graph
        # Only membership tests are performed on the matched nodes, so a set gives O(1) lookups
        # instead of a linear scan for every edge and node below.
        matched_nodes = set(nodes_matching_name_pattern(graph, pattern))

        output_tensors = set()
        # key is the input tensor name "<src_node>:<out_port>", value is the list of
        # (dst node pattern, input port) pairs consuming that tensor inside the sub-graph
        input_nodes_mapping = dict()
        for src_node_name, dst_node_name, edge_attrs in graph.edges(data=True):
            src_inside = src_node_name in matched_nodes
            dst_inside = dst_node_name in matched_nodes

            if not src_inside and dst_inside:
                # edge outside sub-graph into sub-graph
                tensor_name = src_node_name + ":" + str(edge_attrs['out'])
                input_nodes_mapping.setdefault(tensor_name, []).append(
                    (generate_pattern_for_node(graph, pattern, dst_node_name), edge_attrs['in']))
            elif src_inside and not dst_inside:
                # edge from inside sub-graph to outside sub-graph
                dst_node = graph.node[dst_node_name]
                # NOTE(review): the producer name is taken from the consumer's protobuf 'input' list
                # rather than src_node_name; if that entry carries a ':<port>' suffix it ends up in
                # the generated pattern — confirm this is intended.
                output_tensors.add(
                    (generate_pattern_for_node(graph, pattern, dst_node['pb'].input[edge_attrs['in']]),
                     edge_attrs['out']))

        for node_name in graph.nodes():
            if node_name not in matched_nodes:
                continue
            node = Node(graph, node_name)
            if len(node.out_nodes()) == 0 and node['pb'].op != 'Const':
                log.debug("Node {} doesn't have output edges. Consider it output".format(node_name))
                output_tensors.add((generate_pattern_for_node(graph, pattern, node_name), 0))

        if not self.has('inputs') or len(self._replacement_desc['inputs']) == 0:
            self._replacement_desc['inputs'] = [[{'node': desc[0], 'port': desc[1]} for desc in inp]
                                                for inp in sorted(input_nodes_mapping.values())]
            log.debug('Updated inputs of sub-graph for instance "{}"'.format(self.instances))

        if not self.has('outputs') or len(self._replacement_desc['outputs']) == 0:
            self._replacement_desc['outputs'] = [{'node': node, 'port': port} for node, port in sorted(output_tensors)]
            log.debug('Updated outputs of sub-graph for instance "{}"'.format(self.instances))
示例#3
0
    def _match_sub_graph_for_scope(self, graph: Graph, scope_pattern: str):
        """
        Match the sub-graph described by this replacement against a scope pattern.

        :param graph: networkx graph to find sub-graph in.
        :param scope_pattern: regular expression specifying sub-graph scope.
        :return: a SubgraphMatch describing the matched sub-graph, or None when either of these holds:
            - some input node pattern of the replacement description matches no graph node;
            - no graph node matches ``scope_pattern``.
        """
        inputs_order = self.replacement_desc.get_inputs_description()
        outputs_order = self.replacement_desc.get_outputs_description()

        for list_nodes in inputs_order:
            for node_name_pattern, _ in list_nodes:
                # NOTE(review): the '.*' prefix presumably lets the scope-relative input pattern
                # match a node regardless of its scope prefix — confirm against find_object_by_pattern.
                if len(find_object_by_pattern(graph.nodes(), '.*' + node_name_pattern)) == 0:
                    log.info('Node "{}" does not exist in the graph. Failed to match sub-graph by scope "{}" '
                             '(replacement id "{}").'.format(node_name_pattern, scope_pattern,
                                                             self.replacement_desc.id))
                    return None

        matched_nodes = nodes_matching_name_pattern(graph, scope_pattern)
        if len(matched_nodes) == 0:
            log.info('There are no instances of the sub-graph by scope "{}"'.format(scope_pattern))
            return None

        return SubgraphMatch(graph, self.replacement_desc, matched_nodes, inputs_order, outputs_order, scope_pattern)