def test_mark_ops_producing_constant_values(self):
        r"""
        Checks the case when an operation produces only constant tensors, so it could be removed. If the node
        produces several tensors and at least one of them is not constant, then the node must not be marked.
        The graph contains data nodes.
        "data_node_2" and "data_node_5" are outputs.
        "node_3" produces constant tensor "data_node_3" and non-constant tensor "data_node_3_2".
        "node_6" produces constant tensor "data_node_6".
        "node_4" could be eliminated since it gets constant input.

        NOTE(review): the edge list below contains no ('node_4', 'data_node_4') edge. As built, "node_4" has no
        output data node and "data_node_4" has no producer, unlike the diagram. Adding that edge may change
        which nodes get marked, so confirm the intent before changing it.

                             node_6->data_node_6->
                                                  \
        placeholder_1->placeholder_1_data_node->node_1->data_node_1->node_2->data_node_2
                                                  /
        node_3->data_node_3->node_4->data_node_4->
           \
            ->data_node_3_2->node_5->data_node_5

        (The docstring is raw so the trailing backslashes above are kept literally instead of being
        treated as line continuations that would join diagram lines.)

        :return: None
        """
        graph = build_graph(nodes_attributes,
                            [('placeholder_1', 'placeholder_1_data_node'),
                             ('placeholder_1_data_node', 'node_1'),
                             ('node_1', 'data_node_1'),
                             ('data_node_1', 'node_2'),
                             ('node_2', 'data_node_2'),
                             ('node_3', 'data_node_3'),
                             ('node_3', 'data_node_3_2'),
                             ('node_6', 'data_node_6'),
                             ('data_node_6', 'node_1'),
                             ('data_node_3_2', 'node_5'),
                             ('node_5', 'data_node_5'),
                             ('data_node_3', 'node_4'),
                             ('data_node_4', 'node_1'),
                             ('data_node_2', 'op_output'),
                             ('data_node_5', 'op_output_1')],
                            {
                                'data_node_2': {},
                                'data_node_5': {},
                                # Only these two data nodes carry constant values.
                                'data_node_3': {'value': np.array(1)},
                                'data_node_6': {'value': np.array(1)},
                            },
                            nodes_with_edges_only=True)
        mark_const_producer_nodes(graph)
        self.assertTrue(graph.node['node_6']['is_const_producer'])
        self.assertListEqual(
            sorted(['node_1', 'node_2', 'node_3', 'node_5', 'placeholder_1']),
            sorted(graph.get_nodes_with_attributes(is_const_producer=False, kind='op')))

        graph.clean_up()
        # assertIn/assertNotIn report the offending node on failure, unlike assertTrue(x in y).
        self.assertIn('node_3', graph.nodes())
        self.assertNotIn('node_4', graph.nodes())
        self.assertNotIn('node_6', graph.nodes())
Example #2
0
    def graph_clean_up(graph: Graph, undead_node_types: list = None):
        """
        Run the dead-node clean-up pipeline on ``graph`` and then re-run shape inference.

        Steps, in order: mark output-reachable nodes, mark "undead" nodes of the given types,
        mark constant-producer nodes, eliminate dead nodes, add Const operations for constant
        data nodes, and run shape inference.

        :param graph: graph to clean up; the helpers called here are assumed to modify it in place.
        :param undead_node_types: node types forwarded to ``mark_undead_nodes`` (presumably types that
            must survive dead-node elimination; confirm against that helper). Every 'Shape' entry is
            filtered out. The caller's list is left unmodified. ``None`` means no types.
        :return: None
        """
        # Build a new list rather than calling ``remove('Shape')`` on the argument. The in-place
        # version mutated the list owned by the caller and dropped only the first 'Shape' entry.
        undead_node_types = [node_type for node_type in (undead_node_types or []) if node_type != 'Shape']

        mark_output_reachable_nodes(graph)
        mark_undead_nodes(graph, undead_node_types)
        mark_const_producer_nodes(graph)
        eliminate_dead_nodes(graph)
        # Add Const op for constant data nodes
        add_constant_operations(graph)
        shape_inference(graph)