def extract(cls, node):
     """Set NonMaxSuppression attributes on ``node`` from its ``center_point_box`` attribute.

     ``center_point_box`` is read with ``onnx_attr`` as an integer, defaulting
     to 0, and mapped to a box encoding: 0 -> 'corner', 1 -> 'center'. Any
     other value raises ``KeyError`` from the mapping lookup.

     The node is updated via ``NonMaxSuppression.update_node_stat`` with:
     ``sort_result_descending=0``, ``output_type=np.int64`` and the mapped
     ``box_encoding``.

     :param node: node passed to ``NonMaxSuppression.update_node_stat``.
     :return: ``cls.enabled``.
     """
     center_point_box = onnx_attr(node, 'center_point_box', 'i', default=0)
     box_encoding = {0: 'corner', 1: 'center'}[center_point_box]
     NonMaxSuppression.update_node_stat(
         node,
         dict(sort_result_descending=0, output_type=np.int64, box_encoding=box_encoding),
     )
     return cls.enabled
 def extract(cls, node):
     """Apply a fixed set of NonMaxSuppression attributes to ``node``.

     Nothing is read from ``node`` itself. It always receives
     ``sort_result_descending=1``, ``box_encoding='corner'`` and
     ``output_type=np.int32`` via ``NonMaxSuppression.update_node_stat``.

     :param node: node passed to ``NonMaxSuppression.update_node_stat``.
     :return: ``cls.enabled``.
     """
     NonMaxSuppression.update_node_stat(node, dict(
         sort_result_descending=1,
         box_encoding='corner',
         output_type=np.int32,
     ))
     return cls.enabled
    def test_nms_infer_opset1(self):
        # With version 'opset1', shape and type inference should give out
        # port 0 a [100, 3] shape and an int64 data type.
        node = Node(self.graph, 'nms')
        node['version'] = 'opset1'

        NonMaxSuppression.infer(node)
        NonMaxSuppression.type_infer(node)

        out_port = node.out_port(0)
        self.assertTrue(np.array_equal(out_port.data.get_shape(), [100, 3]))
        self.assertEqual(out_port.get_data_type(), np.int64)
    def test_nms_infer_i64_opset5_1_out(self):
        # With version 'opset5' and an explicit int64 output_type, out port 0
        # should get shape [dynamic, 3] and an int64 data type.
        node = Node(self.graph, 'nms')
        node['version'] = 'opset5'
        node['output_type'] = np.int64

        NonMaxSuppression.infer(node)
        NonMaxSuppression.type_infer(node)

        out_port = node.out_port(0)
        expected_shape = shape_array([dynamic_dimension_value, 3])
        self.assertTrue(np.array_equal(out_port.data.get_shape(), expected_shape))
        self.assertEqual(out_port.get_data_type(), np.int64)
 def extract(cls, node):
     """Set NonMaxSuppression attributes on ``node``; warn if padding is disabled.

     Reads the boolean ``pad_to_max_output_size`` attribute from
     ``node.pb.attr``. When it is False, a warning is logged; the value is not
     used otherwise. The node then receives ``sort_result_descending=1``,
     ``box_encoding='corner'`` and ``output_type=np.int32`` via
     ``NonMaxSuppression.update_node_stat``.

     :param node: node whose ``pb`` holds the framework attributes (presumably
         a TensorFlow NodeDef -- confirm against the caller).
     :return: ``cls.enabled``.
     """
     # The key is the bare attribute name, the same one the warning refers to.
     # A key with a stray trailing ':' never matches the real attribute, so it
     # always reads the default False and warns for every node.
     pad_to_max_output_size = node.pb.attr["pad_to_max_output_size"].b
     if not pad_to_max_output_size:
         # NOTE(review): the message says the value is "forced" to True, but
         # nothing is set here; presumably the target op always pads -- confirm.
         log.warning(
             'The attribute "pad_to_max_output_size" of node {} is equal to False which is not supported. '
             'Forcing it to be equal to True'.format(node.soft_get('name')))
     attrs = {
         'sort_result_descending': 1,
         'box_encoding': 'corner',
         'output_type': np.int32
     }
     NonMaxSuppression.update_node_stat(node, attrs)
     return cls.enabled