def test_simplernms_infer_no_shape(self):
    """Check that shape inference leaves the output shape unset.

    Builds a graph with a single ``SimplerNMS_1 -> node_1`` edge, runs
    ``SimplerNMSOp.simplernms_infer`` on the SimplerNMS node configured with
    ``feat_stride`` 12, and asserts the output node's shape is still ``None``.
    """
    # NOTE(review): presumably feat_stride=12 is a value the infer function
    # rejects -- confirm against SimplerNMSOp.simplernms_infer.
    output_attrs = {'is_output': True, 'shape': None}
    nms_attrs = {'feat_stride': 12, 'post_nms_topn': 150, 'scale': [1, 2, 3]}
    graph = build_graph(
        nodes_attributes,
        [('SimplerNMS_1', 'node_1')],
        {'node_1': output_attrs, 'SimplerNMS_1': nms_attrs},
    )

    simplernms_node = Node(graph, 'SimplerNMS_1')
    SimplerNMSOp.simplernms_infer(simplernms_node)

    self.assertIsNone(graph.node['node_1']['shape'])
def test_simplernms_infer_ideal(self):
    """Check shape inference on a SimplerNMS node configured with feat_stride 16.

    Builds a graph with a single ``SimplerNMS_1 -> node_1`` edge, runs
    ``SimplerNMSOp.simplernms_infer`` and asserts that the leading dimensions
    of the output shape are ``[150, 5]`` and that the node's ``scale``
    attribute is the list of strings ``['1', '2', '3']`` afterwards.
    """
    output_attrs = {'is_output': True, 'shape': None}
    nms_attrs = {'feat_stride': 16, 'post_nms_topn': 150, 'scale': [1, 2, 3]}
    graph = build_graph(
        nodes_attributes,
        [('SimplerNMS_1', 'node_1')],
        {'node_1': output_attrs, 'SimplerNMS_1': nms_attrs},
    )

    simplernms_node = Node(graph, 'SimplerNMS_1')
    SimplerNMSOp.simplernms_infer(simplernms_node)

    exp_shape = np.array([150, 5])
    res_shape = graph.node['node_1']['shape']
    # Compare dimension by dimension, indexing the result by the expected
    # shape's positions.
    for dim_index, expected_dim in enumerate(exp_shape):
        self.assertEqual(expected_dim, res_shape[dim_index])

    # The scale values were given as ints but are expected as strings here.
    self.assertEqual(simplernms_node.scale, ['1', '2', '3'])
def extract(cls, node):
    """Copy SimplerNMS parameters from the layer protobuf onto the node.

    Reads ``node.pb.simpler_nms_param``, collects the listed fields into a
    dict, combines it with the parameter message via ``merge_attrs`` and
    passes the result to ``SimplerNMSOp.update_node_stat``.

    :param node: graph node whose ``pb`` attribute holds the layer protobuf.
    :return: the value of ``cls.enabled``.
    """
    nms_param = node.pb.simpler_nms_param

    # Fields read from the protobuf message, in the order they are stored.
    field_names = (
        'cls_threshold',
        'max_num_proposals',
        'iou_threshold',
        'min_bbox_size',
        'feat_stride',
        'pre_nms_topn',
        'post_nms_topn',
        'scale',
    )
    extracted_attrs = {name: getattr(nms_param, name) for name in field_names}

    node_attrs = merge_attrs(nms_param, extracted_attrs)

    # Apply the collected attributes to the node.
    SimplerNMSOp.update_node_stat(node, node_attrs)
    return cls.enabled