Example #1
0
 def test_sub_graph_between_nodes_branches_included(self):
     r"""
     Check that sub_graph_between_nodes works correctly for tree-like graphs with side branches.

     1 -> 2 -> 3 -> 4
           \
            5 -> 6
           / \
     9 -->    -> 7 -> 8

     Each case below is (start nodes, end nodes, expected sorted list of matched nodes).
     """
     graph = Graph()
     all_nodes = list(range(1, 10))
     graph.add_nodes_from(all_nodes)
     graph.add_edges_from([(1, 2), (2, 3), (3, 4), (2, 5), (5, 6), (5, 7),
                           (7, 8), (9, 5)])

     cases = [
         ([1], [4], all_nodes),
         ([1], [6], all_nodes),
         ([1], [8], all_nodes),
         # Node 4 is left out: it is a child of end node 3.
         ([1], [3], [n for n in all_nodes if n != 4]),
         # Node 1 is left out: it is a parent of start node 2. Nodes 3 and 4 must still be included:
         # once node 2 is merged into the sub-graph it is removed, and it would otherwise be unknown
         # how to calculate the tensor between nodes 2 and 3.
         ([2], [8], [n for n in all_nodes if n != 1]),
     ]
     for start_nodes, end_nodes, expected in cases:
         self.assertListEqual(sorted(sub_graph_between_nodes(graph, start_nodes, end_nodes)),
                              expected)
Example #2
0
    def test_sub_graph_between_nodes_include_incoming_edges_for_internal_nodes(self):
        r"""
        Check that nodes feeding node 2 are pulled into the match even though they are not
        reachable from the start node. Matching from node 1, nodes 5 and 6 must be included
        both when the match ends at node 4 and when it ends at node 2 itself.

        6 -> 5 ->
                 \
            1 -> 2 -> 3 -> 4
        """
        graph = Graph()
        graph.add_nodes_from(list(range(1, 7)))
        graph.add_edges_from([(1, 2), (2, 3), (3, 4), (5, 2), (6, 5)])

        for end_node, expected in ((4, list(range(1, 7))), (2, [1, 2, 5, 6])):
            matched = sub_graph_between_nodes(graph, [1], [end_node])
            self.assertIsNotNone(matched)
            self.assertListEqual(sorted(matched), expected)
def get_body(graph, inputs, outputs):
    """
    Collect the body nodes of the sub-graph spanned by ``inputs`` and ``outputs``.

    With non-empty ``inputs`` the sub-graph is found with ``sub_graph_between_nodes(graph, inputs, outputs, ...)``.
    With empty ``inputs`` it is found with ``invert_sub_graph_between_nodes(graph, outputs, inputs, ...)``
    instead (presumably a traversal starting from the outputs -- confirm against that helper).
    Both helpers receive a predicate that selects nodes whose op is 'TensorIteratorInput'.

    :param graph: graph to search in.
    :param inputs: start nodes of the sub-graph (may be empty).
    :param outputs: end nodes of the sub-graph.
    :return: tuple (body_nodes, extra_inputs). body_nodes is a list, in arbitrary order, of the found nodes
        minus ``inputs``, ``outputs`` and ``extra_inputs``. extra_inputs is the second value returned by the
        helper that was called.
    """
    def is_tensor_iterator_input(node):
        return node.soft_get('op') == 'TensorIteratorInput'

    if len(inputs):
        nodes, extra_inputs = sub_graph_between_nodes(graph, inputs, outputs, is_tensor_iterator_input)
    else:
        # No explicit inputs: start from the outputs instead.
        nodes, extra_inputs = invert_sub_graph_between_nodes(graph, outputs, inputs, is_tensor_iterator_input)

    body_nodes = list(set(nodes) - set(inputs) - set(outputs) - set(extra_inputs))
    return body_nodes, extra_inputs
Example #4
0
 def test_sub_graph_between_nodes_multiple_inputs(self):
     r"""
     Check that several start nodes can be given at once: starting from nodes 2 and 5 and ending
     at node 4, the match is exactly nodes 2, 3, 4 and 5 (node 1 stays out).

       5 ->
           \
      1 -> 2 -> 3 -> 4
     """
     edges = [(1, 2), (2, 3), (3, 4), (5, 2)]
     graph = Graph()
     graph.add_nodes_from(list(range(1, 6)))
     graph.add_edges_from(edges)

     matched = sub_graph_between_nodes(graph, [2, 5], [4])
     self.assertIsNotNone(matched)
     self.assertListEqual(sorted(matched), [2, 3, 4, 5])
Example #5
0
 def test_sub_graph_between_nodes_placeholder_excluded(self):
     r"""
     Check that the function does not apply its Parameter check to nodes that are not included
     into the sub-graph. Node 5 below has op 'Parameter', but the match starts at node 2, so node 5
     is not part of the sub-graph and this attribute is ignored.

       5 ->
           \
      1 -> 2 -> 3 -> 4
     """
     graph = Graph()
     graph.add_nodes_from(list(range(1, 6)))
     # add_node on an existing node only updates its attributes. Unlike ``graph.node[5]``
     # (removed in networkx 2.4), this works across networkx versions.
     graph.add_node(5, op='Parameter')
     graph.add_edges_from([(1, 2), (2, 3), (3, 4), (5, 2)])
     sub_graph_nodes = sub_graph_between_nodes(graph, [2], [4])
     self.assertIsNotNone(sub_graph_nodes)
     self.assertListEqual(sorted(sub_graph_nodes), [2, 3, 4])
Example #6
0
 def test_sub_graph_between_nodes_do_not_include_incoming_edges_for_input_nodes(self):
     r"""
     Check that producers of the start nodes are not added to the sub-graph: matching from
     node 2 till node 4, node 5 (which feeds start node 2) must not be included.

       5 ->
           \
      1 -> 2 -> 3 -> 4
     """
     edges = [(1, 2), (2, 3), (3, 4), (5, 2)]
     graph = Graph()
     graph.add_nodes_from(list(range(1, 6)))
     graph.add_edges_from(edges)

     matched = sub_graph_between_nodes(graph, [2], [4])
     self.assertIsNotNone(matched)
     self.assertListEqual(sorted(matched), [2, 3, 4])
Example #7
0
 def test_sub_graph_between_nodes_control_flow_not_included_forward(self):
     r"""
     Check that control flow edges are not traversed when include_control_flow=False: the control
     flow edge 3 -> 5 must not pull nodes 5 and 6 into the match from node 1 to node 4.

        1 -> 2 -> 3 -> 4
                   \
                    -> 5 -> 6
     """
     control_flow_attrs = {'control_flow_edge': True}
     graph = Graph()
     graph.add_nodes_from(list(range(1, 7)))
     graph.add_edges_from([(1, 2), (2, 3), (3, 4), (3, 5, control_flow_attrs), (5, 6)])

     matched = sub_graph_between_nodes(graph, [1], [4], include_control_flow=False)
     self.assertIsNotNone(matched)
     self.assertListEqual(sorted(matched), [1, 2, 3, 4])
Example #8
0
 def test_sub_graph_between_nodes_control_flow_included(self):
     r"""
     Check that control flow edges are traversed when include_control_flow=True: the control flow
     edge 5 -> 2 brings nodes 5 and 6 into the match from node 1 to node 4.

     6 -> 5 ->
             \
        1 -> 2 -> 3 -> 4
     """
     control_flow_attrs = {'control_flow_edge': True}
     graph = Graph()
     graph.add_nodes_from(list(range(1, 7)))
     graph.add_edges_from([(1, 2), (2, 3), (3, 4), (5, 2, control_flow_attrs), (6, 5)])

     matched = sub_graph_between_nodes(graph, [1], [4], include_control_flow=True)
     self.assertIsNotNone(matched)
     self.assertListEqual(sorted(matched), [1, 2, 3, 4, 5, 6])
Example #9
0
    def _match_sub_graph_for_points(self, graph: Graph):
        """
        Match the sub-graph delimited by the internal input/output nodes of the replacement description.

        :param graph: networkx graph to find sub-graph in.
        :return: SubgraphMatch built from the nodes between the start and end points (control flow edges are
            not followed), or None if any start/end point is missing from the graph.
        """
        desc = self.replacement_desc
        start_points = desc.get_internal_input_nodes(graph)
        end_points = desc.get_internal_output_nodes(graph)

        # Every start and end point must exist in the graph; otherwise report the first missing one and give up.
        graph_nodes = graph.nodes()
        for point in start_points + end_points:
            if point in graph_nodes:
                continue
            log.info('Node "{}" does not exist in the graph. Failed to match sub-graph by points "{}".'.format(
                point, desc.id))
            return None

        matched_nodes = sub_graph_between_nodes(graph, start_points, end_points, include_control_flow=False)
        return SubgraphMatch(graph, desc, matched_nodes,
                             desc.get_inputs_description(),
                             desc.get_outputs_description(), '')
Example #10
0
    def update_custom_replacement_attributes(self, graph: Graph):
        """
        Fill in the 'inputs' and 'outputs' descriptions of this replacement from the matched sub-graph.

        The sub-graph between the internal input and output nodes is matched without following control flow
        edges. Then:
          * every edge entering the sub-graph becomes an input entry. Entries are grouped by the source tensor
            "<src_node>:<out_port>", and each group is a list of {'node': '^<dst_node>$', 'port': <in_port>};
          * every edge leaving the sub-graph, and every matched node with no output edges whose pb op is not
            'Const', becomes an output entry {'node': <regex>, 'port': <port>}.
        'inputs' / 'outputs' are only written when the description does not already have them.

        :param graph: graph to match the sub-graph in.
        :raises Error: if no 'instances' are defined, or 'instances' is not a single dictionary.
        """
        if not self.has('instances'):
            raise Error("No instance(s) is(are) defined for the custom replacement '{}'. ".format(self.replacement_id) +
                        refer_to_faq_msg(66))
        if not isinstance(self.instances, dict):
            raise Error("The instance must be a single dictionary for the custom replacement with id '{}'. ".format(
                self.replacement_id) +
                        refer_to_faq_msg(67))

        start_points = self.get_internal_input_nodes(graph)
        end_points = self.get_internal_output_nodes(graph)

        # A set turns the per-edge / per-node membership tests below into O(1) lookups instead of list scans.
        matched_nodes = set(sub_graph_between_nodes(graph, start_points, end_points, include_control_flow=False))
        output_tensors = set()
        input_nodes_mapping = dict()  # key is the input tensor name, value is a list of pairs: (input node regex, input port)
        for src_node_name, dst_node_name, edge_attrs in graph.edges(data=True):
            src_inside = src_node_name in matched_nodes
            dst_inside = dst_node_name in matched_nodes

            if dst_inside and not src_inside:
                # edge outside sub-graph into sub-graph
                tensor_name = src_node_name + ":" + str(edge_attrs['out'])
                input_nodes_mapping.setdefault(tensor_name, []).append(('^' + dst_node_name + '$', edge_attrs['in']))
            elif src_inside and not dst_inside:
                # edge from inside sub-graph to outside sub-graph.
                # Node(...)[...] is used instead of the removed-in-networkx-2.4 ``graph.node[...]``.
                # NOTE(review): the output regex is built from the consumer's pb.input entry at port 'in' and
                # paired with the producer's 'out' port -- confirm this is intended rather than src_node_name.
                dst_pb = Node(graph, dst_node_name)['pb']
                output_tensors.add(('^' + dst_pb.input[edge_attrs['in']] + '$', edge_attrs['out']))

        for node_name in graph.nodes():
            if node_name not in matched_nodes:
                continue
            node = Node(graph, node_name)
            # matched nodes without consumers (other than constants) are outputs of the sub-graph too
            if len(node.out_nodes()) == 0 and node['pb'].op != 'Const':
                log.debug("Node {} doesn't have output edges. Consider it output".format(node_name))
                output_tensors.add(('^' + node_name + '$', 0))

        if not self.has('inputs'):
            self._replacement_desc['inputs'] = [[{'node': desc[0], 'port': desc[1]} for desc in inp]
                                                for inp in sorted(input_nodes_mapping.values())]
            log.debug('Updated inputs of sub-graph for instance "{}"'.format(self.instances))

        if not self.has('outputs'):
            self._replacement_desc['outputs'] = [{'node': node, 'port': port} for node, port in sorted(output_tensors)]
            log.debug('Updated outputs of sub-graph for instance "{}"'.format(self.instances))