def test_mark_output_unreachable_nodes(self):
        # Raw docstring so the "\" in the diagram is kept literally instead of being
        # consumed as a backslash-newline line continuation.
        r"""
        Checks that all nodes that are unreachable from output nodes are marked correspondingly.
        The graph doesn't contain data nodes yet.
        "node_4" is output (it feeds the "op_output" node).

        placeholder_1->node_1->node_2
              \
               -> node_3->node_4->op_output

        :return: None
        """
        graph = build_graph(nodes_attributes, [('placeholder_1', 'node_1'),
                                               ('node_1', 'node_2'),
                                               ('placeholder_1', 'node_3'),
                                               ('node_3', 'node_4'),
                                               ('node_4', 'op_output')],
                            {'node_4': {}},
                            nodes_with_edges_only=True)
        mark_output_reachable_nodes(graph)

        # The branch placeholder_1 -> node_3 -> node_4 -> op_output is expected to be reachable.
        self.assertListEqual(
            sorted(['placeholder_1', 'node_3', 'op_output', 'node_4']),
            sorted(graph.get_nodes_with_attributes(is_output_reachable=True)))
        # The node_1 -> node_2 branch has no edge leading to op_output.
        self.assertListEqual(
            sorted(['node_1', 'node_2']),
            sorted(graph.get_nodes_with_attributes(is_output_reachable=False)))
Example #2
0
    def graph_clean_up(graph: Graph, undead_node_types: list = None):
        """
        Run the dead-node clean-up passes on `graph` and then re-run shape inference.

        The passes are, in order: mark output-reachable nodes, mark undead nodes,
        mark const-producer nodes, eliminate dead nodes, add Const operations for
        constant data nodes, and shape inference. Nothing is returned, so `graph` is
        presumably modified in place by these helpers.

        :param graph: the graph to clean up.
        :param undead_node_types: node types forwarded to ``mark_undead_nodes``.
            Every ``'Shape'`` entry is excluded. The caller's list is never modified.
        :return: None
        """
        # Filter into a new list instead of calling .remove() on the argument. This way
        # the caller's list is left untouched and every 'Shape' entry is dropped, not just
        # the first one.
        # NOTE(review): why 'Shape' must never be treated as undead is not visible here.
        undead_node_types = [
            node_type for node_type in (undead_node_types or []) if node_type != 'Shape'
        ]

        mark_output_reachable_nodes(graph)
        mark_undead_nodes(graph, undead_node_types)
        mark_const_producer_nodes(graph)
        eliminate_dead_nodes(graph)
        # Add Const op for constant data nodes
        add_constant_operations(graph)
        shape_inference(graph)
    def test_mark_output_unreachable_nodes_behind_output(self):
        """
        Checks the case when an unreachable node is 'behind' the output node, i.e. is a child of it.
        The graph doesn't contain data nodes yet.
        "node_2" is output (it also feeds the "op_output" node).

        placeholder_1->node_1->node_2->node_3

        :return: None
        """
        graph = build_graph(nodes_attributes, [('placeholder_1', 'node_1'),
                                               ('node_1', 'node_2'),
                                               ('node_2', 'node_3'),
                                               ('node_2', 'op_output')],
                            {'node_2': {}},
                            nodes_with_edges_only=True)
        mark_output_reachable_nodes(graph)

        self.assertListEqual(
            sorted(['node_1', 'node_2', 'op_output', 'placeholder_1']),
            sorted(graph.get_nodes_with_attributes(is_output_reachable=True)))
        # node_3 hangs off the output node but has no edge leading to op_output.
        # Query through get_nodes_with_attributes, as the sibling test does, rather than
        # graph.node[...]. If Graph is a NetworkX graph, that accessor was removed in NetworkX 2.4.
        self.assertListEqual(
            ['node_3'],
            sorted(graph.get_nodes_with_attributes(is_output_reachable=False)))