Ejemplo n.º 1
0
    def test_simplernms_infer_no_shape(self):
        """With feat_stride=12, SimplerNMS inference must leave the output shape unset.

        The graph is a single edge SimplerNMS_1 -> node_1. After running
        ``SimplerNMSOp.simplernms_infer`` the shape of ``node_1`` is expected to
        still be ``None``.
        """
        # NOTE(review): 12 presumably is a stride value the op rejects; the
        # "ideal" test uses 16 and does get a shape — confirm against the op.
        edges = [('SimplerNMS_1', 'node_1')]
        attrs = {
            'node_1': {'is_output': True, 'shape': None},
            'SimplerNMS_1': {'feat_stride': 12, 'post_nms_topn': 150, 'scale': [1, 2, 3]},
        }
        graph = build_graph(nodes_attributes, edges, attrs)

        SimplerNMSOp.simplernms_infer(Node(graph, 'SimplerNMS_1'))

        output_shape = graph.node['node_1']['shape']
        self.assertIsNone(output_shape)
Ejemplo n.º 2
0
    def test_simplernms_infer_ideal(self):
        """Check SimplerNMS shape inference on a well-formed node (feat_stride=16).

        After ``SimplerNMSOp.simplernms_infer`` runs, the output node's shape
        must be exactly ``[150, 5]``. The 150 presumably comes from
        ``post_nms_topn`` (TODO confirm). The node's ``scale`` attribute,
        given as ``[1, 2, 3]``, must come back as ``['1', '2', '3']``.
        """
        graph = build_graph(nodes_attributes,
                            [('SimplerNMS_1', 'node_1')],
                            {'node_1': {'is_output': True, 'shape': None},
                             'SimplerNMS_1': {'feat_stride': 16, 'post_nms_topn': 150, 'scale': [1, 2, 3]}
                             })

        simplernms_node = Node(graph, 'SimplerNMS_1')

        SimplerNMSOp.simplernms_infer(simplernms_node)
        exp_shape = np.array([150, 5])
        res_shape = graph.node['node_1']['shape']
        # Fail cleanly (not with a TypeError) if inference produced no shape.
        self.assertIsNotNone(res_shape)
        # Compare whole arrays, length included: an element-wise loop over
        # exp_shape alone would silently accept a longer res_shape.
        np.testing.assert_array_equal(res_shape, exp_shape)
        self.assertEqual(simplernms_node.scale, ['1', '2', '3'])
Ejemplo n.º 3
0
    def extract(cls, node):
        """Copy the ``simpler_nms_param`` fields of ``node.pb`` onto ``node``.

        Reads each listed field from ``node.pb.simpler_nms_param`` and filters
        the resulting dict through ``merge_attrs``. The result is then applied
        to the node with ``SimplerNMSOp.update_node_stat``. ``node.pb`` is
        presumably the framework layer proto (Caffe-style naming) — confirm.

        Args:
            node: graph node whose ``pb`` attribute carries ``simpler_nms_param``.

        Returns:
            ``cls.enabled``.
        """
        param = node.pb.simpler_nms_param
        attr_names = (
            'cls_threshold',
            'max_num_proposals',
            'iou_threshold',
            'min_bbox_size',
            'feat_stride',
            'pre_nms_topn',
            'post_nms_topn',
            'scale',
        )
        # Same keys, values and insertion order as listing each field by hand.
        update_attrs = {name: getattr(param, name) for name in attr_names}

        SimplerNMSOp.update_node_stat(node, merge_attrs(param, update_attrs))
        return cls.enabled