Ejemplo n.º 1
0
    def test_remove_noop_nodes_front(self):
        """Erase a NoOp that feeds 'output' and check what is left.

        After the erase a single node must remain, the graph must have no
        edges and 'output' must have no incoming edges.
        """
        nodes = {
            'noop': dict(type='NoOp', value=None, kind='op'),
            'output': dict(type='Identity', value=None, kind='op'),
        }
        edges = [('noop', 'output')]
        graph = build_graph(nodes, edges)

        # Sanity-check the freshly built graph before touching it.
        self.assertEqual(len(graph.nodes()), 2)
        self.assertEqual(len(graph.edges()), 1)
        noop_out_edges = list(graph.out_edges('noop'))
        self.assertListEqual(noop_out_edges, [('noop', 'output')])

        noop_node = Node(graph, 'noop')
        erase_node(noop_node)

        self.assertEqual(len(graph.nodes()), 1)
        self.assertEqual(len(graph.edges()), 0)
        self.assertEqual(len(graph.in_edges('output')), 0)
Ejemplo n.º 2
0
    def test_remove_noop_nodes_middle(self):
        """Erase a NoOp sitting between 'input' and 'output'.

        After the erase two nodes and one edge must remain, and the only
        outgoing edge of 'input' must point directly at 'output'.
        """
        nodes = {
            'input': dict(type='Placeholder', value=None, kind='op'),
            'noop': dict(type='NoOp', value=None, kind='op'),
            'output': dict(type='Identity', value=None, kind='op'),
        }
        edges = [('input', 'noop'), ('noop', 'output')]
        graph = build_graph(nodes, edges)

        # Sanity-check the freshly built graph before touching it.
        self.assertEqual(len(graph.nodes()), 3)
        self.assertEqual(len(graph.edges()), 2)
        edges_before = list(graph.out_edges('input'))
        self.assertListEqual(edges_before, [('input', 'noop')])

        noop_node = Node(graph, 'noop')
        erase_node(noop_node)

        self.assertEqual(len(graph.nodes()), 2)
        self.assertEqual(len(graph.edges()), 1)
        edges_after = list(graph.out_edges('input'))
        self.assertListEqual(edges_after, [('input', 'output')])
Ejemplo n.º 3
0
 def replace_sub_graph(graph: nx.MultiDiGraph, match: dict):
     """Remove a matched Const node that has no inputs and exactly one output edge.

     Both match['const'] and match['output'] are erased, and the removal is
     logged. Nothing happens when the Const node has incoming edges or a
     number of outgoing edges other than one.
     """
     const_node = match['const']
     # Guard clauses: only a Const without inputs and with one consumer qualifies.
     if len(const_node.in_edges()) != 0:
         return
     if len(const_node.out_edges()) != 1:
         return

     erase_node(const_node)
     erase_node(match['output'])
     message = "Standalone Const node \"{}\" was removed from the graph"
     log.info(message.format(const_node.id))
Ejemplo n.º 4
0
    def test_remove_noop_nodes_noop_only(self):
        """Erase the only node of a graph and check that the graph ends up empty."""
        import networkx as nx
        noop_attrs = {'type': 'NoOp', 'value': None, 'kind': 'op'}
        graph = nx.MultiDiGraph()
        graph.add_node('noop', **noop_attrs)

        # A single isolated node: one node, no edges.
        self.assertEqual(len(graph.nodes()), 1)
        self.assertEqual(len(graph.edges()), 0)

        noop_node = Node(graph, 'noop')
        erase_node(noop_node)

        self.assertEqual(len(graph.nodes()), 0)
        self.assertEqual(len(graph.edges()), 0)
Ejemplo n.º 5
0
    def replace_sub_graph(self, graph: nx.MultiDiGraph, match: dict):
        """
        Erase the Flatten node of a matched _contrib_MultiBoxPrior -> Flatten pattern.

        The Flatten layer is removed because IE does not expect the outputs
        to be flattened.

        Parameters
        ----------
        graph : nx.MultiDiGraph
           Graph with loaded model.
        match : dict
           Patterns which were found in graph structure; the node stored
           under the 'flatten' key is the one that gets erased.
        """
        flatten_node = match['flatten']
        erase_node(flatten_node)
Ejemplo n.º 6
0
 def replace_sub_graph(self, graph: nx.MultiDiGraph, match: dict):
     """
     Replace a matched placeholder with a Const node holding a frozen value.

     Nothing happens unless the placeholder name is a key of
     ``self.replacement_dict``. Otherwise the replacement value is cast to
     the placeholder data type (or, when the placeholder has none set, to the
     first entry of ``SUPPORTED_DATA_TYPES`` for the command-line data type),
     reshaped to the placeholder shape and stored in a new Const node that
     takes over all outgoing edges of the placeholder.

     Parameters
     ----------
     graph : nx.MultiDiGraph
        Graph being transformed; ``graph.graph['cmd_params']`` and
        ``graph.graph['fw']`` are read.
     match : dict
        Match result; the node under the 'placeholder' key is replaced.

     Raises
     ------
     Error
        If the value cannot be cast to the data type or reshaped to the
        placeholder shape.
     """
     ph = match['placeholder']
     if ph.name not in self.replacement_dict:
         return

     name = ph.name
     if ph.has_and_set('data_type'):
         data_type = ph.data_type
     else:
         data_type = SUPPORTED_DATA_TYPES[
             graph.graph['cmd_params'].data_type][0]
     string_value = self.replacement_dict[name]
     try:
         # NOTE(review): np.bool was removed in NumPy 1.24 and is an alias of
         # np.bool_ in NumPy 2.x; left untouched to keep the comparison
         # semantics of the NumPy version this file targets - confirm.
         if data_type != np.bool:
             value = np.array(string_value, dtype=data_type)
         elif data_type == np.bool and graph.graph['fw'] == 'tf':
             from mo.front.tf.common import tf_data_type_cast
             if isinstance(string_value, list):
                 casted_list = [tf_data_type_cast[ph.data_type](v)
                                for v in np.array(string_value)]
                 # Build the array from the casted elements; the raw strings
                 # were previously used here and the cast result was discarded.
                 value = np.array(casted_list, dtype=data_type)
             else:
                 value = tf_data_type_cast[ph.data_type](string_value)
         else:
             raise Error("Can not cast value {} to {} data_type".format(
                 string_value, data_type))
     except Error:
         # Already the error we want to report; do not wrap it again.
         raise
     except Exception as err:
         raise Error("Can not cast value {} to {} data_type".format(
             string_value, data_type)) from err
     try:
         # Positional shape argument works across NumPy versions (the
         # ``newshape`` keyword is deprecated in recent releases).
         value = np.reshape(value, ph.shape)
     except Exception as err:
         raise Error("Can not reshape value {} to shape {}".format(
             value, ph.shape)) from err

     # Capture the consumers before the placeholder is erased so that they
     # can be re-attached to the new Const node afterwards.
     out_edges = list(graph.out_edges(ph.id, data=True))
     new_node = Const(graph).create_node(
         attrs={
             'value': value,
             'data_type': type(value),
             'name': name + '/const_placeholder',
             'shape': ph.shape
         })
     erase_node(ph)
     graph.add_edges_from([(new_node.id, v, attrs)
                           for u, v, attrs in out_edges])
     log.info(
         "Placeholder node \"{}\" was replaced with Const node \"{}\" with value \"{}\""
         .format(name, new_node.name, value))
Ejemplo n.º 7
0
    def test_remove_node_from_graph(self):
        """
        Check that erasing a node drops it from the graph's node list.

        The graph is the chain below; "node_2" is erased and the remaining
        node names are compared with the expected ones.

        placeholder_1->node_1->node_2->node_3

        :return: None
        """
        chain = [
            ('placeholder_1', 'node_1'),
            ('node_1', 'node_2'),
            ('node_2', 'node_3'),
        ]
        graph = build_graph(nodes_attributes, chain, nodes_with_edges_only=True)

        erase_node(Node(graph, 'node_2'))

        expected_names = sorted(['placeholder_1', 'node_1', 'node_3'])
        remaining_names = sorted(graph.nodes())
        self.assertListEqual(expected_names, remaining_names)
Ejemplo n.º 8
0
    def replace_sub_graph(self, graph: nx.MultiDiGraph, match: dict):
        """
        Erase the Reshape node of a matched
        _contrib_MultiBoxPrior(s) -> Concat -> Reshape pattern and, when the
        Concat layer attributes contain 'dim', switch the concatenation to axis 2.

        The Reshape layer is removed because IE does not expect outputs from
        concatenation of _contrib_MultiBoxPrior to be reshaped.

        Parameters
        ----------
        graph : nx.MultiDiGraph
           Graph with loaded model.
        match : dict
           Patterns which were found in graph structure; the 'reshape' and
           'concat' entries are used.
        """
        reshape_node = match['reshape']
        erase_node(reshape_node)

        concat_node = match['concat']
        symbol_dict = concat_node.graph.node[concat_node.id]['symbol_dict']
        attr = get_json_layer_attrs(symbol_dict)
        if 'dim' not in attr:
            return

        # concat should be performed for the third axis
        attr['dim'] = 2
        concat_node['axis'] = 2
Ejemplo n.º 9
0
    def replace_sub_graph(graph: nx.MultiDiGraph, match: dict, **kwargs):
        r"""
        Fold a matched FIFOQueue input pattern into a single Input node.

        Usually graph looks like:

          main_graph
            ...             OpOutput
             |                 |
        image_batch      label_batch
                \        /
                batch_join
                /        \
        placeholder      fifo_queue

        Replacer works for both cases (that's why the outputs of batch_join
        are processed in a loop):
            label_batch was marked as output
            there is no label_batch node

        The 'fifo_queue', 'placeholder' and 'batch_join' nodes are erased,
        together with every output of batch_join other than 'image_batch'
        (and that output's OpOutput consumer, if it has one). A new Input node
        named after the fifo_queue is then connected to 'image_batch'.

        Parameters
        ----------
        graph : nx.MultiDiGraph
           Graph being transformed.
        match : dict
           Match result with the 'placeholder', 'fifo_queue', 'batch_join'
           and 'image_batch' nodes.
        kwargs
           Accepted but not used by this function.
        """
        # The docstring is a raw string: the ASCII diagram contains backslashes
        # that would otherwise be read as (invalid) escapes and line continuations.
        true_placeholder_shape = match['placeholder'].shape
        placeholder_shape = match['fifo_queue'].shapes[0]
        assert true_placeholder_shape.ndim <= 1
        if true_placeholder_shape.ndim == 1 and len(true_placeholder_shape) > 1:
            log.warning(
                'Placeholder \'{}\' got non 0-dimensional shape {} in FIFOQueue pattern. Placeholder will have the '
                'same shape after folding the pattern instead of {} shape which is original for the network.'
                ''.format(match['placeholder'].id, true_placeholder_shape, placeholder_shape))
            # Keep the shape the placeholder actually has instead of the queue's one.
            placeholder_shape = true_placeholder_shape
        # The name is read before the fifo_queue node is erased below.
        placeholder_name = match['fifo_queue'].name
        erase_node(match['fifo_queue'])
        erase_node(match['placeholder'])
        for _, out in match['batch_join'].out_nodes().items():
            if out.id != match['image_batch'].id:
                if out.out_node().op == 'OpOutput':
                    erase_node(out.out_node())
                erase_node(out)
        erase_node(match['batch_join'])
        placeholder = Input(graph, {'name': placeholder_name, 'shape': placeholder_shape}).create_node()
        create_edge(placeholder, match['image_batch'])
        log.info("FIFOQueueV2 pattern was detected. New shape of placeholder {} is {}. Use -b to set batch size if "
                 "needed".format(placeholder.id, placeholder['shape']))
Ejemplo n.º 10
0
    def test_remove_noop_nodes_back(self):
        """Erase a NoOp that consumes 'input' and check what is left.

        After the erase a single node must remain, the graph must have no
        edges and 'input' must have no incoming edges.
        """
        nodes = {
            'input': dict(type='Placeholder', value=None, kind='op'),
            'noop': dict(type='NoOp', value=None, kind='op'),
        }
        edges = [('input', 'noop')]
        graph = build_graph(nodes, edges)

        # Sanity-check the freshly built graph before touching it.
        self.assertEqual(len(graph.nodes()), 2)
        self.assertEqual(len(graph.edges()), 1)
        noop_in_edges = list(graph.in_edges('noop'))
        self.assertListEqual(noop_in_edges, [('input', 'noop')])

        noop_node = Node(graph, 'noop')
        erase_node(noop_node)

        self.assertEqual(len(graph.nodes()), 1)
        self.assertEqual(len(graph.edges()), 0)
        self.assertEqual(len(graph.in_edges('input')), 0)
Ejemplo n.º 11
0
    def test_remove_noop_nodes_check_out_port(self):
        """Erase a NoOp with three consumers and compare with a reference graph.

        The NoOp feeds 'output_1', 'output_2' and 'output_3' through edges
        with 'out' set to 1; the reference graph connects 'input' to the same
        consumers with the same 'in' values and 'out' set to 0.
        """
        def op_attrs(op_type):
            # A fresh dict per node so that no two nodes share attribute storage.
            return {'type': op_type, 'value': None, 'kind': 'op'}

        consumers = ('output_1', 'output_2', 'output_3')
        in_ports = (4, 2, 10)

        nodes = {'input': op_attrs('Placeholder'), 'noop': op_attrs('NoOp')}
        for consumer in consumers:
            nodes[consumer] = op_attrs('Identity')
        edges = [('input', 'noop')]
        for consumer, in_port in zip(consumers, in_ports):
            edges.append(('noop', consumer, {'in': in_port, 'out': 1}))
        graph = build_graph(nodes, edges)

        ref_nodes = {'input': op_attrs('Placeholder')}
        for consumer in consumers:
            ref_nodes[consumer] = op_attrs('Identity')
        ref_edges = [('input', consumer, {'in': in_port, 'out': 0})
                     for consumer, in_port in zip(consumers, in_ports)]
        ref_graph = build_graph(ref_nodes, ref_edges, nodes_with_edges_only=True)

        erase_node(Node(graph, 'noop'))

        # NOTE(review): the value returned by compare_graphs is discarded, so
        # this test only fails if the call itself raises - confirm whether the
        # result is meant to be asserted.
        compare_graphs(graph, ref_graph, 'output_1')
Ejemplo n.º 12
0
 def replace_sub_graph(graph: nx.MultiDiGraph, match: dict):
     """Erase the matched 'output' and 'noop' nodes, in that order, and log the NoOp removal."""
     for key in ('output', 'noop'):
         erase_node(match[key])
     noop_id = match['noop'].id
     log.info("NoOp node \"{}\" was removed from the graph".format(noop_id))