def _match_sub_graph_for_scope(self, graph: Graph, scope_pattern: str):
        """
        Match the sub-graph described by this replacement against one scope.

        Every input node pattern from the replacement description is first
        checked to exist somewhere in the graph; if any is missing, matching
        fails early without looking at the scope.

        :param graph: networkx graph to find sub-graph in.
        :param scope_pattern: regular expression specifying sub-graph scope.
        :return: a SubgraphMatch describing the matched sub-graph, or None when
            an input node pattern is absent from the graph or no node matches
            the scope.
        """
        inputs_order = self.replacement_desc.get_inputs_description()
        outputs_order = self.replacement_desc.get_outputs_description()

        # Each input description is an iterable of (node_name_pattern, port)
        # pairs; only the node pattern matters for the existence check.
        for list_nodes in inputs_order:
            for node_name_pattern, _port in list_nodes:
                # The '.*' prefix allows any leading scope before the pattern
                # (presumably find_object_by_pattern anchors at the start of
                # the name — TODO confirm).
                if not find_object_by_pattern(graph.nodes(), '.*' + node_name_pattern):
                    log.info(
                        'Node "{}" does not exist in the graph. Failed to match sub-graph "{}" by scope "{}".'
                        .format(node_name_pattern, self.replacement_desc.id, scope_pattern))
                    return None

        matched_nodes = nodes_matching_name_pattern(graph, scope_pattern)
        if not matched_nodes:
            log.info(
                'There are no instances of the sub-graph by scope "{}"'.format(
                    scope_pattern))
            return None

        return SubgraphMatch(graph, self.replacement_desc, matched_nodes,
                             inputs_order, outputs_order, scope_pattern)
def generate_pattern_for_node(graph: Graph, sub_graph_pattern: str,
                              node_name: str):
    """
    Build a regular-expression suffix that uniquely identifies a node inside
    the sub-graph selected by ``sub_graph_pattern``.

    The node name is split on '/'. The first prefix of components (joined
    with '/' and ending in '/') matched by ``sub_graph_pattern`` marks the
    scope. The remaining components, joined with '/' and terminated with '$',
    form the suffix.

    :param graph: graph whose sub-graph nodes are checked for uniqueness.
    :param sub_graph_pattern: regular expression selecting the sub-graph scope;
        an empty string means "no scope".
    :param node_name: full name of the node to build the pattern for.
    :return: ``node_name`` unchanged if ``sub_graph_pattern`` is empty,
        otherwise the '$'-anchored suffix pattern.
    :raises RuntimeError: if no prefix of ``node_name`` matches
        ``sub_graph_pattern``, or if the suffix does not identify exactly one
        node of the sub-graph.
    """
    if sub_graph_pattern == '':
        return node_name

    node_name_components = node_name.split('/')
    compiled_pattern = compile(sub_graph_pattern)

    # Grow the name one component at a time until the scope pattern matches.
    matched_index = None  # index of the node name component to start new pattern from
    cur_name = ''
    for index, component in enumerate(node_name_components):
        cur_name += component + '/'
        if match(compiled_pattern, cur_name):
            matched_index = index
            break
    if matched_index is None:
        raise RuntimeError('Node name "{}" does not match pattern "{}"'.format(
            node_name, sub_graph_pattern))

    # sub_graph_pattern is non-empty here (the empty case returned above), so
    # only the trailing separator has to be ensured.
    if not sub_graph_pattern.endswith('/'):
        sub_graph_pattern += '/'

    sub_graph_nodes = nodes_matching_name_pattern(graph, sub_graph_pattern)
    name_suffix = '/'.join(node_name_components[matched_index + 1:]) + '$'
    # The suffix is only usable if it singles out exactly one sub-graph node.
    full_pattern = sub_graph_pattern + name_suffix
    if sum(1 for node in sub_graph_nodes if match(full_pattern, node)) == 1:
        return name_suffix

    raise RuntimeError(
        'The pattern that uniquely identifies node "{}" using sub-graph pattern "{}" has not been found'
        .format(node_name, sub_graph_pattern))
Example #3
0
def replace_subgraph_calls(graph: Graph, patterns_string: str):
    """
    The function replaces sub-graphs defined by the node names with single nodes that are executed using the TensorFlow.
    The patterns applied independently, so N patterns produce N TensorFlow call nodes.

    Each comma-separated entry is stripped of surrounding whitespace, and
    empty entries (e.g. from an empty string or a trailing comma) are skipped.
    An empty pattern would presumably match every node name and merge the
    whole graph.

    :param graph: networkX graph to operate on.
    :param patterns_string: comma separated list of node names patterns.
    """
    cycle_exist = False
    for raw_pattern in patterns_string.split(','):
        pattern = raw_pattern.strip()
        if not pattern:
            continue
        log.info("Merging nodes using pattern '{}'".format(pattern))
        matched_nodes = nodes_matching_name_pattern(graph, pattern)
        if not matched_nodes:
            # Nothing merged, so the graph (and its cycle status) is unchanged.
            continue
        merge_nodes(graph, matched_nodes)
        try:
            # 'find_cycle' raises NetworkXNoCycle when the graph is acyclic.
            nx.find_cycle(graph)
        except nx.exception.NetworkXNoCycle:
            cycle_exist = False
        else:
            cycle_exist = True
            log.warning("Graph contains a cycle after merging nodes using pattern '{}'".format(pattern))
    # cycle_exist reflects the graph state after the last merge that happened.
    if cycle_exist:
        graph.dump_graph_for_graphviz()
        log.error('graph contains cycle after applying all merge node patterns')
    def update_custom_replacement_attributes(self, graph: Graph):
        """
        Fill in the 'inputs' and 'outputs' of the replacement description from the graph.

        The first instance pattern selects the sub-graph nodes.
        - Edges entering the sub-graph become inputs, grouped by the source
          tensor name "<src_node>:<out_port>".
        - Edges leaving the sub-graph become outputs.
        - Non-Const sub-graph nodes that have no outgoing edges also become
          outputs, on port 0.
        Existing 'inputs' / 'outputs' entries are left untouched.

        :param graph: graph to analyze.
        :raises Error: if no instances are defined for the replacement.
        """
        if not self.has('instances') or not self.instances:
            raise Error(
                "No instances are defined for replacement with id '{}'. ".
                format(self.replacement_id) + refer_to_faq_msg(68))

        # use the first instance pattern to find input/output nodes patterns
        # TODO verify that all instances will produce the same sub-graph
        pattern = self.instances[0]
        # A set keeps the per-edge / per-node membership tests below O(1).
        matched_nodes = set(nodes_matching_name_pattern(graph, pattern))

        output_tensors = set()
        # key is the input tensor name "<src_node>:<out_port>", value is the list
        # of (input node pattern, input port) pairs that consume that tensor
        input_nodes_mapping = {}
        for src_node_name, dst_node_name, edge_attrs in graph.edges(data=True):
            src_inside = src_node_name in matched_nodes
            dst_inside = dst_node_name in matched_nodes

            if not src_inside and dst_inside:
                # edge outside sub-graph into sub-graph
                tensor_name = src_node_name + ":" + str(edge_attrs['out'])
                input_nodes_mapping.setdefault(tensor_name, []).append(
                    (generate_pattern_for_node(graph, pattern, dst_node_name),
                     edge_attrs['in']))
            elif src_inside and not dst_inside:
                # edge from inside sub-graph to outside sub-graph; the pb of the
                # destination is only needed here, so it is looked up lazily
                # NOTE(review): pb input names may carry a ':<port>' suffix for
                # non-zero ports — confirm generate_pattern_for_node handles it.
                dst_pb = Node(graph, dst_node_name)['pb']
                output_tensors.add(
                    (generate_pattern_for_node(graph, pattern,
                                               dst_pb.input[edge_attrs['in']]),
                     edge_attrs['out']))

        for node_name in graph.nodes():
            if node_name not in matched_nodes:
                continue
            node = Node(graph, node_name)
            if len(node.out_nodes()) == 0 and node['pb'].op != 'Const':
                log.debug(
                    "Node {} doesn't have output edges. Consider it output".
                    format(node_name))
                output_tensors.add(
                    (generate_pattern_for_node(graph, pattern, node_name), 0))

        if not self.has('inputs'):
            self._replacement_desc['inputs'] = [[{
                'node': desc[0],
                'port': desc[1]
            } for desc in inp] for inp in sorted(input_nodes_mapping.values())]
            log.debug('Updated inputs of sub-graph for instance "{}"'.format(
                self.instances))

        if not self.has('outputs'):
            self._replacement_desc['outputs'] = [{
                'node': node,
                'port': port
            } for node, port in sorted(output_tensors)]
            log.debug('Updated outputs of sub-graph for instance "{}"'.format(
                self.instances))