Beispiel #1
0
    def test_caffe_argmax_axis_negative(self):
        """Check ArgMax shape inference when ``axis`` is given as a negative index.

        The graph feeds an input of shape ``[1, 3, 1025, 2049]`` into an
        ``argmax`` node configured with ``axis=-1``, ``top_k=100`` and
        ``out_max_val=True``. After ``arg_ops_infer`` runs, the test expects
        the node's ``axis`` attribute to read ``3`` and the output node's
        shape to be ``[1, 3, 1025, 100]``.
        """
        graph = build_graph(
            nodes_attributes, [('op_input', 'node_1'), ('node_1', 'argmax'),
                               ('argmax', 'node_3'), ('node_3', 'op_output')],
            {
                'node_3': {
                    'shape': None
                },
                'node_1': {
                    'shape': np.array([1, 3, 1025, 2049])
                },
                'argmax': {
                    'out_max_val': True,
                    'top_k': 100,
                    'axis': -1
                }
            })

        argmax_node = Node(graph, 'argmax')
        arg_ops_infer(argmax_node)
        exp_shape = np.array([1, 3, 1025, 100])
        res_shape = graph.node['node_3']['shape']
        self.assertEqual(argmax_node.axis, 3)
        # Compare the arrays as a whole: an index loop bounded by
        # len(exp_shape) would accept a result with extra trailing dimensions
        # and would raise TypeError (not a test failure) if no shape was set.
        self.assertTrue(
            np.array_equal(exp_shape, res_shape),
            'expected shape {}, got {}'.format(exp_shape, res_shape))
Beispiel #2
0
    def test_caffe_argmax_out_max_val_false(self):
        """Check ArgMax shape inference with ``out_max_val=False`` and no ``axis``.

        The graph feeds an input of shape ``[1, 3]`` into an ``argmax`` node
        configured with ``top_k=100`` and ``out_max_val=False`` (no ``axis``
        attribute is supplied). After ``arg_ops_infer`` runs, the test expects
        the output node's shape to be ``[1, 1, 100]``.
        """
        graph = build_graph(
            nodes_attributes, [('op_input', 'node_1'), ('node_1', 'argmax'),
                               ('argmax', 'node_3'), ('node_3', 'op_output')],
            {
                'node_3': {
                    'shape': None
                },
                'node_1': {
                    'shape': np.array([1, 3])
                },
                'argmax': {
                    'out_max_val': False,
                    'top_k': 100
                }
            })

        argmax_node = Node(graph, 'argmax')
        arg_ops_infer(argmax_node)
        exp_shape = np.array([1, 1, 100])
        res_shape = graph.node['node_3']['shape']
        # Compare the arrays as a whole: an index loop bounded by
        # len(exp_shape) would accept a result with extra trailing dimensions
        # and would raise TypeError (not a test failure) if no shape was set.
        self.assertTrue(
            np.array_equal(exp_shape, res_shape),
            'expected shape {}, got {}'.format(exp_shape, res_shape))