def test_caffe_argmax_no_shape(self):
    """ArgMax shape inference must leave the output shape unset when the
    input shape is unknown (``None``)."""
    edges = [
        ('node_1', 'argmax'),
        ('argmax', 'node_3'),
        ('node_3', 'op_output'),
    ]
    attrs = {
        'node_3': {'shape': None},
        # Input shape deliberately unknown.
        'node_1': {'shape': None},
        'argmax': {'out_max_val': False, 'top_k': 100},
    }
    graph = build_graph(nodes_attributes, edges, attrs)

    ArgMaxOp.argmax_infer(Node(graph, 'argmax'))

    self.assertIsNone(graph.node['node_3']['shape'])
def test_caffe_argmax_extend_shape(self):
    """Check ArgMax shape inference on a 2-D input with no ``axis`` attribute.

    With ``out_max_val=True`` and ``top_k=100`` on a ``[1, 3]`` input, the
    inferred output shape is expected to be exactly ``[1, 2, 100]``. The
    rank is extended to 3, and the trailing dimension equals ``top_k``.
    """
    graph = build_graph(nodes_attributes,
                        [('node_1', 'argmax'),
                         ('argmax', 'node_3'),
                         ('node_3', 'op_output')
                         ],
                        {'node_3': {'shape': None},
                         'node_1': {'shape': np.array([1, 3])},
                         'argmax': {
                             'out_max_val': True,
                             'top_k': 100
                         }
                         })
    argmax_node = Node(graph, 'argmax')
    ArgMaxOp.argmax_infer(argmax_node)
    exp_shape = np.array([1, 2, 100])
    res_shape = graph.node['node_3']['shape']
    # Compare the whole arrays. An element-wise loop over exp_shape would
    # silently accept extra trailing dims in res_shape. It would also error
    # out (IndexError/TypeError) rather than fail cleanly if res_shape were
    # shorter or None.
    np.testing.assert_array_equal(res_shape, exp_shape)
def test_caffe_argmax_axis_negative(self):
    """Check ArgMax shape inference with a negative ``axis``.

    Inference is expected to:

    * normalize ``axis=-1`` on a rank-4 input to ``3``;
    * replace that dimension with ``top_k``, giving ``[1, 3, 1025, 100]``.
    """
    # The graph output is marked with an explicit op_output node, matching
    # the other ArgMax tests, rather than an 'is_output' attribute.
    graph = build_graph(nodes_attributes,
                        [('node_1', 'argmax'),
                         ('argmax', 'node_3'),
                         ('node_3', 'op_output')
                         ],
                        {'node_3': {'shape': None},
                         'node_1': {'shape': np.array([1, 3, 1025, 2049])},
                         'argmax': {
                             'out_max_val': True,
                             'top_k': 100,
                             'axis': -1
                         }
                         })
    argmax_node = Node(graph, 'argmax')
    ArgMaxOp.argmax_infer(argmax_node)
    exp_shape = np.array([1, 3, 1025, 100])
    res_shape = graph.node['node_3']['shape']
    self.assertEqual(argmax_node.axis, 3)
    # A full-array comparison also verifies the rank of the result, which
    # an element-wise loop over exp_shape would not.
    np.testing.assert_array_equal(res_shape, exp_shape)