def extract(cls, node):
        """Fill ``node`` with ArgMax attributes read from its protobuf layer.

        Reads ``argmax_param`` from ``node.pb``, combines it with the
        explicitly built attribute values through ``merge_attrs`` and stores
        the result on the node via ``ArgMaxOp.update_node_stat``.

        Returns:
            ``cls.enabled``.
        """
        argmax_param = node.pb.argmax_param

        # out_max_val is coerced to int; top_k and axis are forwarded as read.
        explicit_attrs = dict(
            out_max_val=int(argmax_param.out_max_val),
            top_k=argmax_param.top_k,
            axis=argmax_param.axis,
        )
        ArgMaxOp.update_node_stat(node, merge_attrs(argmax_param, explicit_attrs))

        # ArgMax must be converted to TopK but without the output with values,
        # so the values output is flagged for removal in a separate update.
        ArgMaxOp.update_node_stat(node, dict(remove_values_output=True))
        return cls.enabled
示例#2
0
    def extract(cls, node):
        """Fill ``node`` with ArgMax attributes read through ``onnx_attr``.

        ``keepdims`` (default 1) and ``axis`` (default 0) are read from the
        node; ``top_k`` and ``out_max_val`` are fixed constants.

        Returns:
            ``cls.enabled``.
        """
        keep_dims_flag = onnx_attr(node, 'keepdims', 'i', default=1)
        reduce_axis = onnx_attr(node, 'axis', 'i', default=0)

        # ONNX ArgMax always computes an index of one maximum value, hence
        # top_k=1 and out_max_val=0. 'keepdims' is stored so the ArgMax
        # replacer is triggered in case the dimension is not kept.
        node_attrs = dict(
            axis=reduce_axis,
            top_k=1,
            out_max_val=0,
            keepdims=keep_dims_flag,
        )
        ArgMaxOp.update_node_stat(node, node_attrs)
        return cls.enabled
示例#3
0
 def extract(cls, node):
     """Fill ``node`` with fixed ArgMax attributes plus an output type.

     The output type is produced by ``tf_dtype_extractor`` from the
     ``out_type`` entry of ``node.pb.attr``; every other attribute is a
     constant.

     Returns:
         ``cls.enabled``.
     """
     # np.int64 is handed to tf_dtype_extractor as its second argument
     # (presumably the fallback type -- confirm against the helper).
     index_dtype = tf_dtype_extractor(node.pb.attr['out_type'].type, np.int64)

     argmax_attrs = {
         'out_max_val': 0,
         'top_k': 1,
         'axis': None,
         'dim_attrs': ['axis'],
         'keepdims': 0,
         'remove_values_output': True,
         'output_type': index_dtype,
     }
     ArgMaxOp.update_node_stat(node, argmax_attrs)
     return cls.enabled
示例#4
0
    def test_caffe_argmax_no_shape(self):
        """Output shape stays ``None`` when the input node has no shape."""
        edges = [
            ('node_1', 'argmax'),
            ('argmax', 'node_3'),
            ('node_3', 'op_output'),
        ]
        # Both the input (node_1) and the output (node_3) start without a shape.
        overrides = {
            'node_3': {'shape': None},
            'node_1': {'shape': None},
            'argmax': {'out_max_val': False, 'top_k': 100},
        }
        graph = build_graph(nodes_attributes, edges, overrides)

        ArgMaxOp.argmax_infer(Node(graph, 'argmax'))

        self.assertIsNone(graph.node['node_3']['shape'])
示例#5
0
    def test_caffe_argmax_extend_shape(self):
        """A [1, 3] input with out_max_val and top_k=100 infers [1, 2, 100]."""
        edges = [
            ('node_1', 'argmax'),
            ('argmax', 'node_3'),
            ('node_3', 'op_output'),
        ]
        overrides = {
            'node_3': {'shape': None},
            'node_1': {'shape': np.array([1, 3])},
            'argmax': {'out_max_val': True, 'top_k': 100},
        }
        graph = build_graph(nodes_attributes, edges, overrides)

        ArgMaxOp.argmax_infer(Node(graph, 'argmax'))

        exp_shape = np.array([1, 2, 100])
        res_shape = graph.node['node_3']['shape']
        # Compare element by element over the expected dimensions only.
        for dim, expected in enumerate(exp_shape):
            self.assertEqual(expected, res_shape[dim])
示例#6
0
    def test_caffe_argmax_axis_negative(self):
        """axis=-1 on a 4D input becomes axis 3 and that dimension becomes top_k."""
        edges = [('node_1', 'argmax'), ('argmax', 'node_3')]
        overrides = {
            'node_3': {'is_output': True, 'shape': None},
            'node_1': {'shape': np.array([1, 3, 1025, 2049])},
            'argmax': {'out_max_val': True, 'top_k': 100, 'axis': -1},
        }
        graph = build_graph(nodes_attributes, edges, overrides)

        argmax_node = Node(graph, 'argmax')
        ArgMaxOp.argmax_infer(argmax_node)

        exp_shape = np.array([1, 3, 1025, 100])
        res_shape = graph.node['node_3']['shape']

        # The test expects the node's axis to read 3 after inference.
        self.assertEqual(argmax_node.axis, 3)
        for dim, expected in enumerate(exp_shape):
            self.assertEqual(expected, res_shape[dim])
示例#7
0
 def extract(cls, node):
     """Fill ``node`` with a fixed set of ArgMax attributes.

     All values are constants; nothing is read from the node itself.

     Returns:
         ``cls.enabled``.
     """
     fixed_attrs = {
         'out_max_val': 0,
         'top_k': 1,
         'axis': None,
         'dim_attrs': ['axis'],
         'keepdims': 0,
         'remove_values_output': True,
     }
     ArgMaxOp.update_node_stat(node, fixed_attrs)
     return cls.enabled