def test_mark_output_unreachable_nodes(self):
    """
    Nodes that have no path to an output node must be marked as not output-reachable.

    The graph has no data nodes yet; "node_4" is the output (it feeds "op_output"):

        placeholder_1 -> node_1 -> node_2
              |
              +------> node_3 -> node_4 -> op_output

    Expected: placeholder_1, node_3, node_4 and op_output are output-reachable;
    node_1 and node_2 are not.

    :return: None
    """
    edges = [
        ('placeholder_1', 'node_1'),
        ('node_1', 'node_2'),
        ('placeholder_1', 'node_3'),
        ('node_3', 'node_4'),
        ('node_4', 'op_output'),
    ]
    graph = build_graph(nodes_attributes, edges, {'node_4': {}},
                        nodes_with_edges_only=True)

    mark_output_reachable_nodes(graph)

    expected_reachable = sorted(['placeholder_1', 'node_3', 'op_output', 'node_4'])
    actual_reachable = sorted(graph.get_nodes_with_attributes(is_output_reachable=True))
    self.assertListEqual(expected_reachable, actual_reachable)

    expected_unreachable = sorted(['node_1', 'node_2'])
    actual_unreachable = sorted(graph.get_nodes_with_attributes(is_output_reachable=False))
    self.assertListEqual(expected_unreachable, actual_unreachable)
def graph_clean_up(graph: Graph, undead_node_types: list = None):
    """
    Remove dead nodes from the graph, then re-add constant ops and re-run shape inference.

    The passes run in this order: mark_output_reachable_nodes, mark_undead_nodes,
    mark_const_producer_nodes, eliminate_dead_nodes, add_constant_operations,
    shape_inference. All of them are called for their effect on ``graph``;
    their return values are ignored.

    :param graph: the graph to clean up.
    :param undead_node_types: node types forwarded to mark_undead_nodes.
        Every 'Shape' entry is dropped before forwarding. The caller's list
        is NOT modified; a filtered copy is used instead.
    :return: None
    """
    if undead_node_types is None:
        undead_node_types = []
    # Filter into a new list rather than calling .remove() on the argument:
    # .remove() mutated the caller's list and only dropped the first 'Shape'.
    # NOTE(review): the reason 'Shape' must never be treated as undead is not
    # visible in this function; confirm with the callers / mark_undead_nodes.
    undead_node_types = [node_type for node_type in undead_node_types
                         if node_type != 'Shape']

    mark_output_reachable_nodes(graph)
    mark_undead_nodes(graph, undead_node_types)
    mark_const_producer_nodes(graph)
    eliminate_dead_nodes(graph)
    # Add Const op for constant data nodes
    add_constant_operations(graph)
    shape_inference(graph)
def test_mark_output_unreachable_nodes_behind_output(self):
    """
    A node that hangs off an output node (i.e. is its child) must not be output-reachable.

    The graph has no data nodes yet; "node_2" is the output (it feeds "op_output"):

        placeholder_1 -> node_1 -> node_2 -> node_3
                                     |
                                     +----> op_output

    Expected: placeholder_1, node_1, node_2 and op_output are output-reachable;
    node_3 carries a falsy ``is_output_reachable`` attribute.

    :return: None
    """
    edges = [
        ('placeholder_1', 'node_1'),
        ('node_1', 'node_2'),
        ('node_2', 'node_3'),
        ('node_2', 'op_output'),
    ]
    graph = build_graph(nodes_attributes, edges, {'node_2': {}},
                        nodes_with_edges_only=True)

    mark_output_reachable_nodes(graph)

    expected_reachable = sorted(['node_1', 'node_2', 'op_output', 'placeholder_1'])
    actual_reachable = sorted(graph.get_nodes_with_attributes(is_output_reachable=True))
    self.assertListEqual(expected_reachable, actual_reachable)

    node_3_attrs = graph.node['node_3']
    self.assertFalse(node_3_attrs['is_output_reachable'])