def extract(cls, node):
    """Populate ``node`` with ArgMax attributes taken from a Caffe ``argmax_param``.

    The layer's ``out_max_val``, ``top_k`` and ``axis`` values are filtered
    through ``merge_attrs`` and stored on the node; the node is then flagged
    with ``remove_values_output``.

    :param node: graph node whose ``pb`` attribute holds the Caffe layer proto.
    :return: ``cls.enabled``.
    """
    argmax_param = node.pb.argmax_param

    # out_max_val is stored as an int (0/1) rather than the raw proto value.
    # NOTE(review): if the proto leaves `axis` unset, `argmax_param.axis`
    # presumably reads back as 0 -- confirm that merge_attrs drops unset
    # fields so Caffe's "no axis" (flattened) semantics are preserved.
    requested_attrs = {
        'out_max_val': int(argmax_param.out_max_val),
        'top_k': argmax_param.top_k,
        'axis': argmax_param.axis,
    }
    ArgMaxOp.update_node_stat(node, merge_attrs(argmax_param, requested_attrs))

    # ArgMax must be converted to TopK without the output holding the values,
    # so mark the node accordingly.
    ArgMaxOp.update_node_stat(node, {'remove_values_output': True})
    return cls.enabled
def extract(cls, node):
    """Populate ``node`` with ArgMax attributes read from an ONNX ArgMax node.

    ONNX ArgMax yields only the index of the single maximum value along
    ``axis``, so the node is configured as a top-1 ArgMax without the
    max-value output.

    :param node: ONNX graph node to read the ``axis``/``keepdims`` attributes from.
    :return: ``cls.enabled``.
    """
    # Defaults mirror the ONNX operator definition: keepdims=1, axis=0.
    keepdims = onnx_attr(node, 'keepdims', 'i', default=1)
    axis = onnx_attr(node, 'axis', 'i', default=0)
    # NOTE(review): the opset-12 `select_last_index` attribute is not read
    # here; ties are resolved however the downstream TopK resolves them.

    attrs = {
        'axis': axis,

        # ONNX ArgMax always computes an index of one maximum value.
        'top_k': 1,
        'out_max_val': 0,

        # Recorded so the ArgMax replacer is triggered when the reduced
        # dimension is not kept.
        'keepdims': keepdims,

        # ONNX ArgMax has only the indices output, so when ArgMax is converted
        # to TopK the output with values must be removed.
        'remove_values_output': True,
    }
    ArgMaxOp.update_node_stat(node, attrs)
    return cls.enabled
def extract(cls, node):
    """Populate ``node`` with ArgMax attributes for a TensorFlow ArgMax node.

    The node is configured as a top-1 ArgMax that keeps no reduced dimension
    and drops the values output. ``axis`` is left as None and listed in
    ``dim_attrs`` (presumably resolved later from the dimension input -- TODO
    confirm). The index dtype comes from the ``out_type`` attribute, falling
    back to int64.

    :param node: graph node whose ``pb`` attribute holds the TF NodeDef.
    :return: ``cls.enabled``.
    """
    index_dtype = tf_dtype_extractor(node.pb.attr['out_type'].type, np.int64)

    argmax_attrs = dict()
    argmax_attrs['out_max_val'] = 0
    argmax_attrs['top_k'] = 1
    argmax_attrs['axis'] = None
    argmax_attrs['dim_attrs'] = ['axis']
    argmax_attrs['keepdims'] = 0
    argmax_attrs['remove_values_output'] = True
    argmax_attrs['output_type'] = index_dtype

    ArgMaxOp.update_node_stat(node, argmax_attrs)
    return cls.enabled
def test_caffe_argmax_no_shape(self):
    """argmax_infer must leave the output shape as None when the input shape is unknown."""
    edges = [
        ('node_1', 'argmax'),
        ('argmax', 'node_3'),
        ('node_3', 'op_output'),
    ]
    node_overrides = {
        'node_3': {'shape': None},
        'node_1': {'shape': None},
        'argmax': {'out_max_val': False, 'top_k': 100},
    }
    graph = build_graph(nodes_attributes, edges, node_overrides)

    ArgMaxOp.argmax_infer(Node(graph, 'argmax'))

    self.assertIsNone(graph.node['node_3']['shape'])
def test_caffe_argmax_extend_shape(self):
    """argmax_infer on a [1, 3] input with no axis, out_max_val=True and
    top_k=100 must produce an output shape of exactly [1, 2, 100]."""
    graph = build_graph(nodes_attributes,
                        [('node_1', 'argmax'),
                         ('argmax', 'node_3'),
                         ('node_3', 'op_output')],
                        {'node_3': {'shape': None},
                         'node_1': {'shape': np.array([1, 3])},
                         'argmax': {'out_max_val': True, 'top_k': 100}})
    argmax_node = Node(graph, 'argmax')

    ArgMaxOp.argmax_infer(argmax_node)

    exp_shape = np.array([1, 2, 100])
    res_shape = graph.node['node_3']['shape']
    # Compare whole arrays: looping only over exp_shape's indices would let a
    # result with extra trailing dimensions pass unnoticed.
    np.testing.assert_array_equal(res_shape, exp_shape)
def test_caffe_argmax_axis_negative(self):
    """With axis=-1 on a 4-D input, argmax_infer must normalize the axis to 3
    and produce an output shape of exactly [1, 3, 1025, 100]."""
    graph = build_graph(
        nodes_attributes,
        [('node_1', 'argmax'),
         ('argmax', 'node_3')],
        {
            'node_3': {
                'is_output': True,
                'shape': None
            },
            'node_1': {
                'shape': np.array([1, 3, 1025, 2049])
            },
            'argmax': {
                'out_max_val': True,
                'top_k': 100,
                'axis': -1
            }
        })
    argmax_node = Node(graph, 'argmax')

    ArgMaxOp.argmax_infer(argmax_node)

    exp_shape = np.array([1, 3, 1025, 100])
    res_shape = graph.node['node_3']['shape']
    # The negative axis is expected to be stored back in normalized form.
    self.assertEqual(argmax_node.axis, 3)
    # Compare whole arrays so a rank mismatch is reported as a failure.
    np.testing.assert_array_equal(res_shape, exp_shape)
def extract(cls, node):
    """Populate ``node`` with a fixed set of ArgMax attributes.

    The node becomes a top-1 ArgMax without the max-value output, with no
    kept reduced dimension. ``axis`` is left as None and listed in
    ``dim_attrs`` (presumably filled in later -- TODO confirm).

    :param node: graph node to update.
    :return: ``cls.enabled``.
    """
    fixed_attrs = {
        'out_max_val': 0,
        'top_k': 1,
        'axis': None,
        'dim_attrs': ['axis'],
        'keepdims': 0,
        'remove_values_output': True,
    }
    ArgMaxOp.update_node_stat(node, fixed_attrs)
    return cls.enabled