예제 #1
0
    def __call__(self, feats, rois, spatial_scale):
        """Pool RoI features from one feature level or from several levels.

        Args:
            feats: Indexable collection of feature maps, addressed by level
                index in ``[self.start_level, self.end_level]``.
            rois: Pair ``(roi, rois_num)``; both items are forwarded to
                ``ops.roi_align`` / ``ops.distribute_fpn_proposals``.
            spatial_scale: Forwarded unchanged to ``ops.roi_align`` when a
                single level is used, indexed per level otherwise.

        Returns:
            The result of ``ops.roi_align`` in the single-level case.
            Otherwise the per-level results, concatenated and then gathered
            with the restore index returned by
            ``ops.distribute_fpn_proposals``.
        """
        roi, rois_num = rois
        if self.start_level == self.end_level:
            # NOTE(review): this branch passes ``spatial_scale`` without
            # indexing and does not forward ``self.sampling_ratio``, unlike
            # the multi-level branch below -- confirm this is intentional.
            return ops.roi_align(feats[self.start_level],
                                 roi,
                                 self.resolution,
                                 spatial_scale,
                                 rois_num=rois_num)

        # presumably converts the feature index into the pyramid level number
        # expected by distribute_fpn_proposals -- TODO confirm
        offset = 2
        k_min = self.start_level + offset
        k_max = self.end_level + offset
        rois_dist, restore_index, rois_num_dist = ops.distribute_fpn_proposals(
            roi,
            k_min,
            k_max,
            self.canconical_level,
            self.canonical_size,
            rois_num=rois_num)

        rois_feat_list = []
        for lvl in range(self.start_level, self.end_level + 1):
            # The distributed outputs are lists with one entry per level in
            # [k_min, k_max], starting at position 0 (per the Paddle op
            # documentation), so they must be indexed relative to
            # start_level. Indexing with the absolute level is only correct
            # when start_level == 0.
            dist_idx = lvl - self.start_level
            roi_feat = ops.roi_align(feats[lvl],
                                     rois_dist[dist_idx],
                                     self.resolution,
                                     spatial_scale[lvl],
                                     sampling_ratio=self.sampling_ratio,
                                     rois_num=rois_num_dist[dist_idx])
            rois_feat_list.append(roi_feat)

        # Per-level pooling leaves the RoIs grouped by level; gathering with
        # restore_index re-orders the concatenated features.
        rois_feat_shuffle = paddle.concat(rois_feat_list)
        rois_feat = paddle.gather(rois_feat_shuffle, restore_index)

        return rois_feat
예제 #2
0
    def __call__(self, feats, roi, rois_num):
        """Pool RoI features from one feature level or from several levels.

        Args:
            feats: Sequence of feature maps. When it holds exactly one entry
                the single-level path is taken; otherwise it is indexed by
                level in ``[self.start_level, self.end_level]``.
            roi: Sequence of RoI tensors; concatenated with ``paddle.concat``
                when it holds more than one entry, else its only entry is
                used.
            rois_num: Forwarded to ``ops.roi_align`` /
                ``ops.distribute_fpn_proposals`` as ``rois_num``.

        Returns:
            The result of ``ops.roi_align`` in the single-level case.
            Otherwise the non-empty per-level results, concatenated and then
            gathered with the restore index returned by
            ``ops.distribute_fpn_proposals``.
        """
        roi = paddle.concat(roi) if len(roi) > 1 else roi[0]
        if len(feats) == 1:
            # NOTE(review): ``self.sampling_ratio`` is not forwarded here,
            # unlike the multi-level branch -- confirm this is intentional.
            rois_feat = ops.roi_align(feats[self.start_level],
                                      roi,
                                      self.resolution,
                                      self.spatial_scale[0],
                                      rois_num=rois_num,
                                      aligned=self.aligned)
        else:
            # presumably converts the feature index into the pyramid level
            # number expected by distribute_fpn_proposals -- TODO confirm
            offset = 2
            k_min = self.start_level + offset
            k_max = self.end_level + offset
            rois_dist, restore_index, rois_num_dist = ops.distribute_fpn_proposals(
                roi,
                k_min,
                k_max,
                self.canconical_level,
                self.canonical_size,
                rois_num=rois_num)
            rois_feat_list = []
            for lvl in range(self.start_level, self.end_level + 1):
                # The distributed outputs are lists with one entry per level
                # in [k_min, k_max], starting at position 0 (per the Paddle
                # op documentation), so they must be indexed relative to
                # start_level. Indexing with the absolute level is only
                # correct when start_level == 0.
                dist_idx = lvl - self.start_level
                roi_feat = ops.roi_align(feats[lvl],
                                         rois_dist[dist_idx],
                                         self.resolution,
                                         self.spatial_scale[lvl],
                                         sampling_ratio=self.sampling_ratio,
                                         rois_num=rois_num_dist[dist_idx],
                                         aligned=self.aligned)
                # Levels that produced no rows are left out of the concat.
                if roi_feat.shape[0] > 0:
                    rois_feat_list.append(roi_feat)
            # Per-level pooling leaves the RoIs grouped by level; gathering
            # with restore_index re-orders the concatenated features.
            rois_feat_shuffle = paddle.concat(rois_feat_list)
            rois_feat = paddle.gather(rois_feat_shuffle, restore_index)

        return rois_feat
예제 #3
0
    def test_roi_align(self):
        """Check ``ops.roi_align`` gives equal output in static and dynamic mode."""
        batch, channels, height, width = 2, 12, 20, 20
        feat_shape = [batch, channels, height, width]
        feat_np = np.random.rand(*feat_shape).astype('float32')
        rois_per_image = [4, 6]
        pooled_size = (7, 7)
        boxes_np = make_rois(height, width, rois_per_image, pooled_size)
        boxes_num_np = np.array(rois_per_image).astype('int32')

        # Static-graph run: declare placeholders, build the op, then execute
        # it with the numpy inputs.
        with self.static_graph():
            feat_var = paddle.static.data(
                name='inputs', shape=feat_shape, dtype='float32')
            # presumably 10 is the total RoI count (4 + 6) -- TODO confirm
            boxes_var = paddle.static.data(
                name='rois', shape=[10, 4], dtype='float32')
            boxes_num_var = paddle.static.data(
                name='rois_num', shape=[None], dtype='int32')

            static_out = ops.roi_align(input=feat_var,
                                       rois=boxes_var,
                                       output_size=pooled_size,
                                       rois_num=boxes_num_var)
            feed = {
                'inputs': feat_np,
                'rois': boxes_np,
                'rois_num': boxes_num_np,
            }
            (static_result,) = self.get_static_graph_result(
                feed=feed, fetch_list=static_out, with_lod=False)

        # Dynamic-graph run on the very same numpy inputs.
        with self.dynamic_graph():
            feat_dy = base.to_variable(feat_np)
            boxes_dy = base.to_variable(boxes_np)
            boxes_num_dy = base.to_variable(boxes_num_np)

            dynamic_result = ops.roi_align(input=feat_dy,
                                           rois=boxes_dy,
                                           output_size=pooled_size,
                                           rois_num=boxes_num_dy).numpy()

        # Both execution modes must agree element for element.
        outputs_match = np.array_equal(static_result, dynamic_result)
        self.assertTrue(outputs_match)