def test_sub_graph_between_nodes_branches_included(self):
    """
    Verify sub-graph extraction on a graph with branches (tree-like structure).

        1 -> 2 -> 3 -> 4
             |
             v
        9 -> 5 -> 6
             |
             v
             7 -> 8
    """
    graph = Graph()
    node_names = list(range(1, 10))
    graph.add_nodes_from(node_names)
    graph.add_edges_from([(1, 2), (2, 3), (3, 4), (2, 5), (5, 6), (5, 7), (7, 8), (9, 5)])

    # matching from node 1 to any leaf reached through the branches covers the whole graph
    for end_node in (4, 6, 8):
        self.assertListEqual(sorted(sub_graph_between_nodes(graph, [1], [end_node])), node_names)

    # node 4 only consumes the output of the end node 3, so it stays outside of the sub-graph
    all_but_4 = [name for name in node_names if name != 4]
    self.assertListEqual(sorted(sub_graph_between_nodes(graph, [1], [3])), all_but_4)

    # node 1 only produces the input of the start node 2, so it stays outside of the sub-graph. Nodes 3 and 4 are
    # still included: once node 2 is merged into the sub-graph it is removed, and it would no longer be known how to
    # compute the tensor flowing from node 2 to node 3.
    all_but_1 = [name for name in node_names if name != 1]
    self.assertListEqual(sorted(sub_graph_between_nodes(graph, [2], [8])), all_but_1)
def test_sub_graph_between_nodes_include_incoming_edges_for_internal_nodes(self):
    """
    Verify that extra producers of non-start nodes are pulled into the sub-graph.

    When matching from node 1 to node 4, node 2 is internal, so its producer chain 6 -> 5 must be added.
    The same chain is expected when matching from node 1 to node 2.

        6 -> 5
             |
             v
        1 -> 2 -> 3 -> 4
    """
    graph = Graph()
    graph.add_nodes_from(list(range(1, 7)))
    graph.add_edges_from([(1, 2), (2, 3), (3, 4), (5, 2), (6, 5)])

    cases = [
        ([4], list(range(1, 7))),
        ([2], [1, 2, 5, 6]),
    ]
    for end_points, expected in cases:
        matched = sub_graph_between_nodes(graph, [1], end_points)
        self.assertIsNotNone(matched)
        self.assertListEqual(sorted(matched), expected)
def get_body(graph, inputs, outputs):
    """
    Collect the body nodes lying between `inputs` and `outputs`.

    If `inputs` is empty, invert_sub_graph_between_nodes is called starting from `outputs`. Otherwise
    sub_graph_between_nodes is called from `inputs` to `outputs`. Both receive a predicate that flags nodes whose
    'op' is 'TensorIteratorInput'. NOTE(review): presumably this predicate marks where the traversal stops; confirm
    in the helpers.

    :param graph: graph to search in.
    :param inputs: start nodes of the body (may be empty).
    :param outputs: end nodes of the body.
    :return: pair (body nodes, extra inputs). Body nodes exclude the given inputs, the outputs and the extra inputs;
        their order is not defined because they come from a set.
    """
    def is_iterator_input(node):
        return node.soft_get('op') == 'TensorIteratorInput'

    if len(inputs) > 0:
        found, extra_inputs = sub_graph_between_nodes(graph, inputs, outputs, is_iterator_input)
    else:
        found, extra_inputs = invert_sub_graph_between_nodes(graph, outputs, inputs, is_iterator_input)

    body = list(set(found) - set(inputs) - set(outputs) - set(extra_inputs))
    return body, extra_inputs
def test_sub_graph_between_nodes_multiple_inputs(self):
    """
    Verify the sub-graph when several start nodes are specified.

             5
             |
             v
        1 -> 2 -> 3 -> 4
    """
    graph = Graph()
    graph.add_nodes_from(list(range(1, 6)))
    graph.add_edges_from([(1, 2), (2, 3), (3, 4), (5, 2)])

    matched = sub_graph_between_nodes(graph, [2, 5], [4])
    self.assertIsNotNone(matched)
    # node 1 only feeds the start node 2, so it is not part of the result
    self.assertListEqual(sorted(matched), [2, 3, 4, 5])
def test_sub_graph_between_nodes_placeholder_excluded(self):
    """
    Verify that the Placeholder check is not applied to nodes outside of the sub-graph.

    Node 5 is a Parameter (Placeholder), but it is not part of the sub-graph matched from node 2 to node 4.
    Its 'op' attribute must therefore be ignored.

             5
             |
             v
        1 -> 2 -> 3 -> 4
    """
    graph = Graph()
    graph.add_nodes_from(list(range(1, 6)))
    graph.node[5]['op'] = 'Parameter'
    graph.add_edges_from([(1, 2), (2, 3), (3, 4), (5, 2)])

    matched = sub_graph_between_nodes(graph, [2], [4])
    self.assertIsNotNone(matched)
    self.assertListEqual(sorted(matched), [2, 3, 4])
def test_sub_graph_between_nodes_do_not_include_incoming_edges_for_input_nodes(self):
    """
    Verify that producers of the start nodes are not added to the sub-graph.

    When matching from node 2 to node 4, node 5 feeds only the start node 2, so it must be left out.

             5
             |
             v
        1 -> 2 -> 3 -> 4
    """
    graph = Graph()
    graph.add_nodes_from(list(range(1, 6)))
    graph.add_edges_from([(1, 2), (2, 3), (3, 4), (5, 2)])

    matched = sub_graph_between_nodes(graph, [2], [4])
    self.assertIsNotNone(matched)
    self.assertListEqual(sorted(matched), [2, 3, 4])
def test_sub_graph_between_nodes_control_flow_not_included_forward(self):
    """
    Verify that control flow edges are not followed when include_control_flow=False.

    The dotted edge 3 -> 5 is a control flow edge, so nodes 5 and 6 must stay outside of the sub-graph.

        1 -> 2 -> 3 -> 4
                  :
                  v
                  5 -> 6
    """
    graph = Graph()
    graph.add_nodes_from(list(range(1, 7)))
    control_edge = (3, 5, {'control_flow_edge': True})
    graph.add_edges_from([(1, 2), (2, 3), (3, 4), control_edge, (5, 6)])

    matched = sub_graph_between_nodes(graph, [1], [4], include_control_flow=False)
    self.assertIsNotNone(matched)
    self.assertListEqual(sorted(matched), [1, 2, 3, 4])
def test_sub_graph_between_nodes_control_flow_included(self):
    """
    Verify that control flow edges are followed when include_control_flow=True.

    The dotted edge 5 -> 2 is a control flow edge, so the producer chain 6 -> 5 must be added to the sub-graph.

        6 -> 5
             :
             v
        1 -> 2 -> 3 -> 4
    """
    graph = Graph()
    graph.add_nodes_from(list(range(1, 7)))
    control_edge = (5, 2, {'control_flow_edge': True})
    graph.add_edges_from([(1, 2), (2, 3), (3, 4), control_edge, (6, 5)])

    matched = sub_graph_between_nodes(graph, [1], [4], include_control_flow=True)
    self.assertIsNotNone(matched)
    self.assertListEqual(sorted(matched), [1, 2, 3, 4, 5, 6])
def _match_sub_graph_for_points(self, graph: Graph):
    """
    Match the sub-graph spanned by the replacement's internal input and output points.

    :param graph: networkx graph to find sub-graph in.
    :return: a SubgraphMatch describing the matched sub-graph, or None if any start/end point is missing from the
        graph (an info message names the first missing one).
    """
    desc = self.replacement_desc
    start_points = desc.get_internal_input_nodes(graph)
    end_points = desc.get_internal_output_nodes(graph)

    # bail out on the first point that is absent from the graph: there is nothing to match then
    for name in start_points + end_points:
        if name in graph.nodes():
            continue
        log.info('Node "{}" does not exist in the graph. Failed to match sub-graph by points "{}".'.format(
            name, desc.id))
        return None

    matched_nodes = sub_graph_between_nodes(graph, start_points, end_points, include_control_flow=False)
    inputs_description = desc.get_inputs_description()
    outputs_description = desc.get_outputs_description()
    return SubgraphMatch(graph, desc, matched_nodes, inputs_description, outputs_description, '')
def update_custom_replacement_attributes(self, graph: Graph):
    """
    Fill in the 'inputs' and 'outputs' descriptions of the custom replacement from the matched sub-graph.

    The sub-graph is the set of nodes between the internal input (start) and internal output (end) nodes of the
    replacement, computed without following control flow edges. Descriptions that are already present
    ('inputs' / 'outputs') are left untouched.

    Inputs are derived from edges entering the sub-graph. They are grouped by the source tensor "<src>:<out_port>";
    each group lists (node regex, input port) pairs. Outputs are derived from edges leaving the sub-graph, plus
    matched nodes without consumers whose op is not 'Const'.

    :param graph: graph to compute the sub-graph in.
    :raises Error: if no instance is defined, or the instance is not a single dictionary.
    """
    if not self.has('instances'):
        raise Error("No instance(s) is(are) defined for the custom replacement '{}'. ".format(self.replacement_id) +
                    refer_to_faq_msg(66))
    if not isinstance(self.instances, dict):
        raise Error("The instance must be a single dictionary for the custom replacement with id '{}'. ".format(
            self.replacement_id) + refer_to_faq_msg(67))

    start_points = self.get_internal_input_nodes(graph)
    end_points = self.get_internal_output_nodes(graph)

    matched_nodes = sub_graph_between_nodes(graph, start_points, end_points, include_control_flow=False)
    # membership is tested for every edge and every node below; a set keeps each test O(1) instead of O(len(list))
    matched_nodes_set = set(matched_nodes)

    output_tensors = set()
    # key is the input tensor name "<src>:<out_port>", value is the list of pairs: (consumer node regex, input_port)
    input_nodes_mapping = dict()
    for src_node_name, dst_node_name, edge_attrs in graph.edges(data=True):
        src_inside = src_node_name in matched_nodes_set
        dst_inside = dst_node_name in matched_nodes_set

        # edge outside sub-graph into sub-graph
        if not src_inside and dst_inside:
            tensor_name = src_node_name + ":" + str(edge_attrs['out'])
            input_nodes_mapping.setdefault(tensor_name, list()).append(('^' + dst_node_name + '$', edge_attrs['in']))

        # edge from inside sub-graph to outside sub-graph
        if src_inside and not dst_inside:
            dst_node = graph.node[dst_node_name]
            # NOTE(review): the name is taken from the consumer's 'pb' input list; if that entry carries a ':port'
            # suffix the resulting regex would not match the producer node name — confirm this is intended
            output_tensors.add(('^' + dst_node['pb'].input[edge_attrs['in']] + '$', edge_attrs['out']))

    # matched nodes without consumers (other than 'Const' ops) are also treated as outputs
    for node_name in graph.nodes():
        if node_name not in matched_nodes_set:
            continue
        node = Node(graph, node_name)
        if len(node.out_nodes()) == 0 and node['pb'].op != 'Const':
            log.debug("Node {} doesn't have output edges. Consider it output".format(node_name))
            output_tensors.add(('^' + node_name + '$', 0))

    if not self.has('inputs'):
        self._replacement_desc['inputs'] = [[{'node': desc[0], 'port': desc[1]} for desc in inp]
                                            for inp in sorted(input_nodes_mapping.values())]
        log.debug('Updated inputs of sub-graph for instance "{}"'.format(self.instances))

    if not self.has('outputs'):
        self._replacement_desc['outputs'] = [{'node': node, 'port': port} for node, port in sorted(output_tensors)]
        log.debug('Updated outputs of sub-graph for instance "{}"'.format(self.instances))