Esempio n. 1
0
    def test_caffe_argmax_no_shape(self):
        """Check that the output shape stays ``None`` when the input shape is ``None``."""
        edges = [
            ('node_1', 'argmax'),
            ('argmax', 'node_3'),
            ('node_3', 'op_output'),
        ]
        # Per-node attribute overrides: both data nodes start without a shape.
        overrides = {
            'node_3': {'shape': None},
            'node_1': {'shape': None},
            'argmax': {'out_max_val': False, 'top_k': 100},
        }
        graph = build_graph(nodes_attributes, edges, overrides)

        ArgMaxOp.argmax_infer(Node(graph, 'argmax'))

        # Inference must not have written a shape onto the output node.
        self.assertIsNone(graph.node['node_3']['shape'])
Esempio n. 2
0
    def test_caffe_argmax_extend_shape(self):
        """Check the inferred output shape for a 2-D input without an ``axis`` attribute.

        The input node has shape ``[1, 3]`` and the ArgMax node is configured
        with ``out_max_val=True`` and ``top_k=100``; the output node is
        expected to end up with shape ``[1, 2, 100]``.
        """
        graph = build_graph(nodes_attributes,
                            [('node_1', 'argmax'),
                             ('argmax', 'node_3'),
                             ('node_3', 'op_output')
                             ],
                            {'node_3': {'shape': None},
                             'node_1': {'shape': np.array([1, 3])},
                             'argmax': {
                                 'out_max_val': True,
                                 'top_k': 100
                             }
                             })

        argmax_node = Node(graph, 'argmax')
        ArgMaxOp.argmax_infer(argmax_node)
        exp_shape = np.array([1, 2, 100])
        res_shape = graph.node['node_3']['shape']
        # Fail with an assertion (not a TypeError) if inference produced no shape.
        self.assertIsNotNone(res_shape)
        # Compare the complete shapes: an element-by-element loop bounded by
        # len(exp_shape) would accept a result with extra trailing dimensions.
        self.assertListEqual(exp_shape.tolist(), np.asarray(res_shape).tolist())
Esempio n. 3
0
    def test_caffe_argmax_axis_negative(self):
        """Check shape inference when ``axis`` is given as a negative index.

        The input node has shape ``[1, 3, 1025, 2049]`` and the ArgMax node is
        configured with ``axis=-1``, ``top_k=100`` and ``out_max_val=True``.
        After inference the node's ``axis`` is expected to be ``3`` and the
        output shape ``[1, 3, 1025, 100]``.
        """
        graph = build_graph(
            nodes_attributes, [('node_1', 'argmax'), ('argmax', 'node_3')], {
                'node_3': {
                    'is_output': True,
                    'shape': None
                },
                'node_1': {
                    'shape': np.array([1, 3, 1025, 2049])
                },
                'argmax': {
                    'out_max_val': True,
                    'top_k': 100,
                    'axis': -1
                }
            })

        argmax_node = Node(graph, 'argmax')
        ArgMaxOp.argmax_infer(argmax_node)
        exp_shape = np.array([1, 3, 1025, 100])
        res_shape = graph.node['node_3']['shape']
        # The negative axis must have been rewritten to its non-negative index.
        self.assertEqual(argmax_node.axis, 3)
        # Fail with an assertion (not a TypeError) if inference produced no shape.
        self.assertIsNotNone(res_shape)
        # Compare the complete shapes: an element-by-element loop bounded by
        # len(exp_shape) would accept a result with extra trailing dimensions.
        self.assertListEqual(exp_shape.tolist(), np.asarray(res_shape).tolist())