def extract(cls, node):
    """Translate a TensorFlow TopK node (``k`` given as an attribute) into the internal TopK op.

    Always selects the largest values (``mode='max'``) along the last axis (``axis=-1``).

    ``node.pb.attr['sorted']`` is an ``AttrValue`` protobuf message. A message object is truthy
    regardless of its contents, so the boolean flag has to be read from its ``.b`` field.
    If the attribute is missing (e.g. default attrs were stripped), TensorFlow's op-definition
    default ``sorted=True`` is used.

    :param node: node whose ``pb`` holds the TensorFlow node definition (assumed ``NodeDef``)
    :return: ``cls.enabled``
    """
    attrs = node.pb.attr
    # Read the bool payload; testing the AttrValue message itself is always True.
    is_sorted = attrs['sorted'].b if 'sorted' in attrs else True
    sort = 'value' if is_sorted else 'none'
    TopK.update_node_stat(node, {'mode': 'max', 'axis': -1, 'sort': sort, 'k': attrs['k'].i})
    return cls.enabled
def test_topk_infer_opset1(self):
    """TopK opset1: both outputs are [20, 10, 4]; port 0 is float32 and port 1 is int32."""
    topk_node = Node(self.graph, 'topk')
    topk_node['version'] = 'opset1'
    TopK.infer(topk_node)
    TopK.type_infer(topk_node)

    # Both outputs share the same shape. The numpy/unittest helpers below report the
    # actual values on failure, unlike assertTrue(np.array_equal(...)).
    expected_shape = int64_array([20, 10, 4])
    for port_idx in (0, 1):
        np.testing.assert_array_equal(topk_node.out_port(port_idx).data.get_shape(), expected_shape)
    self.assertEqual(topk_node.out_port(0).get_data_type(), np.float32)
    self.assertEqual(topk_node.out_port(1).get_data_type(), np.int32)
def extract(cls, node):
    """Translate a TensorFlow TopKV2 node (``k`` provided as an input) into the internal TopK op.

    Always selects the largest values (``mode='max'``) along the last axis (``axis=-1``) and
    requests int32 indices.

    ``node.pb.attr['sorted']`` is an ``AttrValue`` protobuf message. A message object is truthy
    regardless of its contents, so the boolean flag has to be read from its ``.b`` field.
    If the attribute is missing (e.g. default attrs were stripped), TensorFlow's op-definition
    default ``sorted=True`` is used.

    :param node: node whose ``pb`` holds the TensorFlow node definition (assumed ``NodeDef``)
    :return: ``cls.enabled``
    """
    attrs = node.pb.attr
    # Read the bool payload; testing the AttrValue message itself is always True.
    is_sorted = attrs['sorted'].b if 'sorted' in attrs else True
    sort = 'value' if is_sorted else 'none'
    TopK.update_node_stat(
        node, {
            'mode': 'max',
            'axis': -1,
            'sort': sort,
            'index_element_type': np.int32
        })
    return cls.enabled
def test_topk_infer_v10_i32_opset3(self):
    """TopK opset3, IR v10, int32 indices: outputs are [20, 10, 4]; port 0 float32, port 1 int32."""
    self.graph.graph['cmd_params'] = FakeAttr(ir_version=10)
    topk_node = Node(self.graph, 'topk')
    topk_node['version'] = 'opset3'
    topk_node['index_element_type'] = np.int32
    TopK.infer(topk_node)
    TopK.type_infer(topk_node)

    # Both outputs share the same shape. The numpy/unittest helpers below report the
    # actual values on failure, unlike assertTrue(np.array_equal(...)).
    expected_shape = int64_array([20, 10, 4])
    for port_idx in (0, 1):
        np.testing.assert_array_equal(topk_node.out_port(port_idx).data.get_shape(), expected_shape)
    self.assertEqual(topk_node.out_port(0).get_data_type(), np.float32)
    # Port 1 is expected to follow the index_element_type requested above.
    self.assertEqual(topk_node.out_port(1).get_data_type(), np.int32)
def extract(cls, node):
    """Map an ONNX TopK node onto the internal TopK operation.

    Covers all three ONNX TopK versions:
      * TopK-1  -- ``k`` is a required attribute;
      * TopK-10 -- ``k`` arrives as an input, no sorting controls;
      * TopK-11 -- ``k`` arrives as an input, ordering is controlled by the
        ``sorted`` and ``largest`` attributes.

    :param node: ONNX node to read attributes from
    :return: ``cls.enabled``
    """
    axis = onnx_attr(node, 'axis', 'i', default=-1)
    topk_attrs = dict(axis=axis)
    if onnx_node_has_attr(node, 'k'):
        # Only TopK-1 carries k as an attribute; later versions feed it as an input.
        topk_attrs['k'] = onnx_attr(node, 'k', 'i')

    # Both flags default to 1 (sorted, largest) when the attribute is absent.
    is_sorted = onnx_attr(node, 'sorted', 'i', default=1)
    is_largest = onnx_attr(node, 'largest', 'i', default=1)
    topk_attrs['sort'] = 'value' if is_sorted else 'none'
    topk_attrs['mode'] = 'max' if is_largest else 'min'

    TopK.update_node_stat(node, topk_attrs)
    return cls.enabled
def replace_pattern(self, graph: Graph, match: dict):
    """Replace a matched ArgMax/ArgMin node with an equivalent TopK node.

    The axis is taken from the second input when two inputs are connected, otherwise from the
    ``axis`` attribute. ``node.top_k`` becomes a Const feeding TopK's input 1. When
    ``out_max_val`` is set, TopK's indices and values are concatenated along axis 1 so the
    result keeps the (max_ind, max_value) layout. Otherwise TopK's indices output replaces the
    original output. Finally the original node and its first output node are removed.

    :param graph: graph being transformed
    :param match: pattern match; ``match['node']`` is the ArgMax/ArgMin node to replace
    """
    node = match['node']
    node_name = node.soft_get('name', node.id)

    connected_ports = [port for port in node.in_ports().values() if not port.disconnected()]
    if len(connected_ports) == 2:
        axis = node.in_port(1).data.get_value()
    else:
        axis = node.axis

    assert axis is not None, 'The "axis" should be defined for node "{}"'.format(node_name)
    assert node.has_and_set('output_type'), 'The data type is not set for node "{}"'.format(node_name)

    topk_mode = 'max' if node.op == 'ArgMax' else 'min'
    topk_node = TopK(graph, {'axis': axis, 'mode': topk_mode, 'sort': 'index',
                             'remove_values_output': node.has_and_set('remove_values_output'),
                             'index_element_type': node.output_type}).create_node()
    node.in_port(0).get_connection().set_destination(topk_node.in_port(0))

    if node.has_and_set('out_max_val'):
        # In this mode ArgMax produces tuples (max_ind, max_value): TopK output 1 (indices)
        # feeds Concat input 0 and TopK output 0 (values) feeds Concat input 1.
        # NOTE(review): indices use `output_type` while values keep the input element type;
        # confirm the Concat below accepts mixed element types.
        # Use node_name: `node.name` raises when the node has no 'name' attribute.
        concat_node = Concat(graph, {'axis': 1, 'name': node_name + '/Concat'}).create_node()
        concat_node.add_input_port(0, skip_if_exist=True)
        concat_node.add_input_port(1, skip_if_exist=True)
        topk_node.out_port(0).connect(concat_node.in_port(1))  # values
        topk_node.out_port(1).connect(concat_node.in_port(0))  # indices
        if not node.out_port(0).disconnected():
            node.out_port(0).get_connection().set_source(concat_node.out_port(0))
    else:
        # Plain ArgMax/ArgMin: consumers receive TopK's indices output.
        if not node.out_port(0).disconnected():
            node.out_port(0).get_connection().set_source(topk_node.out_port(1))

    # Feed node.top_k into TopK's input 1 as a constant.
    k_const = Const(graph, {'name': node_name + '/TopK', 'value': node.top_k}).create_node()
    topk_node.in_port(1).connect(k_const.out_port(0))

    graph.remove_nodes_from([node.id, node.out_node(0).id])
def test_topk_infer_v10_opset1(self):
    """TopK opset1 with IR v10: both outputs are [20, 10, 4]; port 0 float32, port 1 int32."""
    self.graph.graph['cmd_params'] = FakeAttr(
        generate_experimental_IR_V10=True, ir_version=10)
    topk_node = Node(self.graph, 'topk')
    topk_node['version'] = 'opset1'
    TopK.infer(topk_node)
    TopK.type_infer(topk_node)

    # Both outputs share the same shape. The numpy/unittest helpers below report the
    # actual values on failure, unlike assertTrue(np.array_equal(...)).
    expected_shape = int64_array([20, 10, 4])
    for port_idx in (0, 1):
        np.testing.assert_array_equal(topk_node.out_port(port_idx).data.get_shape(), expected_shape)
    self.assertEqual(topk_node.out_port(0).get_data_type(), np.float32)
    self.assertEqual(topk_node.out_port(1).get_data_type(), np.int32)
def extract(node):
    """Translate an ONNX TopK node into the internal TopK operation.

    Reads the ``axis`` attribute (default -1) and always requests value-sorted output.

    :param node: ONNX node to read attributes from
    :return: the ``enabled`` flag of the class this function is defined in (via ``__class__``)
    """
    stat = dict(axis=onnx_attr(node, 'axis', 'i', default=-1), sort='value')
    TopK.update_node_stat(node, stat)
    return __class__.enabled
def extract(node):
    """Translate a TensorFlow TopK node into the internal TopK op (``mode='max'``, last axis).

    ``node.pb.attr['sorted']`` is an ``AttrValue`` protobuf message. A message object is truthy
    regardless of its contents, so the boolean flag has to be read from its ``.b`` field.
    If the attribute is missing (e.g. default attrs were stripped), TensorFlow's op-definition
    default ``sorted=True`` is used.

    :param node: node whose ``pb`` holds the TensorFlow node definition (assumed ``NodeDef``)
    :return: the ``enabled`` flag of the class this function is defined in (via ``__class__``)
    """
    attrs = node.pb.attr
    # Read the bool payload; testing the AttrValue message itself is always True.
    is_sorted = attrs['sorted'].b if 'sorted' in attrs else True
    sort = 'value' if is_sorted else 'none'
    TopK.update_node_stat(node, {'mode': 'max', 'axis': -1, 'sort': sort})
    return __class__.enabled