Example #1
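The two snippets in this section are `forward` methods from FPN-style RPN layers written in PyTorch; both lean on names defined elsewhere in their repository. A sketch of the context they assume is below; the module paths are guesses, not confirmed by the source.

# Context assumed by the snippets in this section; the commented module
# paths are illustrative guesses, not confirmed imports.
import numpy as np
import torch

# from model.utils.config import cfg   # config object with TRAIN/TEST sections
# from model.rpn.generate_anchors import generate_anchors_all_pyramids
# from model.rpn.bbox_transform import (bbox_transform_inv, clip_boxes,
#                                       bbox_overlaps_batch)
# from model.nms.nms_wrapper import nms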
    def forward(self, input):

        # Algorithm:
        #
        # for each (L, H, W) location i
        #   generate A anchor boxes centered on cell i
        #   apply predicted bbox deltas at cell i to each of the A anchors
        # clip predicted boxes to image
        # remove predicted boxes with either height or width < threshold
        # sort all (proposal, score) pairs by score from highest to lowest
        # take top pre_nms_topN proposals before NMS
        # apply NMS with threshold 0.7 to remaining proposals
        # take after_nms_topN proposals after NMS
        # return the top proposals (-> RoIs top, scores top)


        # the first set of _num_anchors channels are bg probs
        # the second set are the fg probs
        scores = input[0][:, :, 1]  # fg probs: batch_size x num_rois
        bbox_deltas = input[1]      # batch_size x num_rois x 6
        im_info = input[2]
        cfg_key = input[3]
        feat_shapes = input[4]        

        pre_nms_topN  = cfg[cfg_key].RPN_PRE_NMS_TOP_N
        post_nms_topN = cfg[cfg_key].RPN_POST_NMS_TOP_N
        nms_thresh    = cfg[cfg_key].RPN_NMS_THRESH
        min_size      = cfg[cfg_key].RPN_MIN_SIZE

        batch_size = bbox_deltas.size(0)

        anchors = torch.from_numpy(generate_anchors_all_pyramids(
                self._fpn_scales, self._anchor_ratios,
                self._anchor_l_ratios,  # depth (L) ratios; assumed attribute -- the original `l_ratios` was undefined
                feat_shapes,
                self._fpn_feature_strides, self._fpn_anchor_stride)).type_as(scores)
        num_anchors = anchors.size(0)

        anchors = anchors.view(1, num_anchors, 6).expand(batch_size, num_anchors, 6)

        # Convert anchors into proposals via bbox transformations
        proposals = bbox_transform_inv(anchors, bbox_deltas, batch_size)

        # 2. clip predicted boxes to image
        proposals = clip_boxes(proposals, im_info, batch_size)
        # keep_idx = self._filter_boxes(proposals, min_size).squeeze().long().nonzero().squeeze()
                
        scores_keep = scores
        proposals_keep = proposals

        _, order = torch.sort(scores_keep, 1, True)

        output = scores.new(batch_size, post_nms_topN, 7).zero_()
        for i in range(batch_size):
            # 3. remove predicted boxes with either height or width < threshold
            # (NOTE: convert min_size to input image scale stored in im_info[2])
            proposals_single = proposals_keep[i]
            scores_single = scores_keep[i]

            # 4. sort all (proposal, score) pairs by score from highest to lowest
            # 5. take top pre_nms_topN (e.g. 6000)
            order_single = order[i]

            if pre_nms_topN > 0 and pre_nms_topN < scores_single.numel():
                order_single = order_single[:pre_nms_topN]

            proposals_single = proposals_single[order_single, :]
            scores_single = scores_single[order_single].view(-1,1)

            # 6. apply nms (e.g. threshold = 0.7)
            # 7. take after_nms_topN (e.g. 300)
            # 8. return the top proposals (-> RoIs top)

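            # NOTE: `nms` is assumed to be the repo's custom wrapper taking
            # (dets, thresh), where dets stacks the 6 box coordinates plus the
            # score in the last column (torchvision's nms has a different
            # signature and only handles 4-coordinate boxes).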
            keep_idx_i = nms(torch.cat((proposals_single, scores_single), 1), nms_thresh)
            keep_idx_i = keep_idx_i.long().view(-1)

            if post_nms_topN > 0:
                keep_idx_i = keep_idx_i[:post_nms_topN]
            proposals_single = proposals_single[keep_idx_i, :]
            scores_single = scores_single[keep_idx_i, :]

            # padding 0 at the end.
            num_proposal = proposals_single.size(0)
            output[i,:,0] = i
            output[i,:num_proposal,1:] = proposals_single

        return output
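For context, here is a minimal sketch of how this proposal layer might be driven and its padded output unpacked. The instance `proposal_layer` and the input tensors are illustrative assumptions, not names taken from the snippet.

# Hypothetical driver -- `proposal_layer`, `rpn_cls_prob`, `rpn_bbox_pred`,
# `im_info`, and `feat_shapes` are assumed stand-ins, not confirmed names.
rois = proposal_layer((rpn_cls_prob, rpn_bbox_pred, im_info, 'TEST', feat_shapes))

# rois: (batch_size, post_nms_topN, 7) -- column 0 holds the batch index,
# columns 1:7 the six box coordinates; rows past the NMS survivors are
# zero-padded.
for b in range(rois.size(0)):
    boxes_b = rois[b, :, 1:]
    keep = boxes_b.abs().sum(dim=1) > 0  # drop all-zero padding rows
    boxes_b = boxes_b[keep]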
Example #2

    def forward(self, input):
        # Algorithm:
        #
        # for each (H, W) location i
        #   generate 9 anchor boxes centered on cell i
        #   apply predicted bbox deltas at cell i to each of the 9 anchors
        # filter out-of-image anchors

        scores = input[0]
        gt_boxes = input[1]
        im_info = input[2]
        num_boxes = input[3]
        feat_shapes = input[4]

        # NOTE: with FPN the anchors come from several pyramid levels, so a
        # single (height, width) from the score map no longer applies; the
        # placeholders below are only referenced by the commented-out
        # reshapes at the end of this method.
        # height, width = scores.size(2), scores.size(3)
        height, width = 0, 0

        batch_size = gt_boxes.size(0)

        anchors = torch.from_numpy(
            generate_anchors_all_pyramids(
                self._fpn_scales, self._anchor_ratios, feat_shapes,
                self._fpn_feature_strides,
                self._fpn_anchor_stride)).type_as(scores)
        total_anchors = anchors.size(0)

        keep = ((anchors[:, 0] >= -self._allowed_border) &
                (anchors[:, 1] >= -self._allowed_border) &
                (anchors[:, 2] < int(im_info[0][1]) + self._allowed_border) &
                (anchors[:, 3] < int(im_info[0][0]) + self._allowed_border))

        inds_inside = torch.nonzero(keep).view(-1)

        # keep only inside anchors
        anchors = anchors[inds_inside, :]

        # label: 1 is positive, 0 is negative, -1 is don't care
        labels = gt_boxes.new(batch_size, inds_inside.size(0)).fill_(-1)
        bbox_inside_weights = gt_boxes.new(batch_size,
                                           inds_inside.size(0)).zero_()
        bbox_outside_weights = gt_boxes.new(batch_size,
                                            inds_inside.size(0)).zero_()

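        # IoU between each inside anchor and each GT box; assumed shape
        # (batch_size, num_inside_anchors, num_gt_boxes), following the
        # batched overlaps helper this repo provides.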
        overlaps = bbox_overlaps_batch(anchors, gt_boxes)

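        # per-anchor best GT overlap (max over dim 2) and per-GT best anchor
        # overlap (max over dim 1); the latter guarantees every GT box gets
        # at least one positive anchor below.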
        max_overlaps, argmax_overlaps = torch.max(overlaps, 2)
        gt_max_overlaps, _ = torch.max(overlaps, 1)

        if not cfg.TRAIN.RPN_CLOBBER_POSITIVES:
            labels[max_overlaps < cfg.TRAIN.RPN_NEGATIVE_OVERLAP] = 0

        gt_max_overlaps[gt_max_overlaps == 0] = 1e-5
        keep = torch.sum(
            overlaps.eq(
                gt_max_overlaps.view(batch_size, 1, -1).expand_as(overlaps)),
            2)

        if torch.sum(keep) > 0:
            labels[keep > 0] = 1

        # fg label: above threshold IOU
        labels[max_overlaps >= cfg.TRAIN.RPN_POSITIVE_OVERLAP] = 1

        if cfg.TRAIN.RPN_CLOBBER_POSITIVES:
            labels[max_overlaps < cfg.TRAIN.RPN_NEGATIVE_OVERLAP] = 0

        num_fg = int(cfg.TRAIN.RPN_FG_FRACTION * cfg.TRAIN.RPN_BATCHSIZE)

        sum_fg = torch.sum((labels == 1).int(), 1)
        sum_bg = torch.sum((labels == 0).int(), 1)

        for i in range(batch_size):
            # subsample positive labels if we have too many
            if sum_fg[i] > num_fg:
                fg_inds = torch.nonzero(labels[i] == 1).view(-1)
                # torch.randperm seems to have a bug in multi-GPU settings that
                # causes a segfault; see https://github.com/pytorch/pytorch/issues/1868
                # for details. Use numpy instead.
                #rand_num = torch.randperm(fg_inds.size(0)).type_as(gt_boxes).long()
                rand_num = torch.from_numpy(
                    np.random.permutation(
                        fg_inds.size(0))).type_as(gt_boxes).long()
                disable_inds = fg_inds[rand_num[:fg_inds.size(0) - num_fg]]
                labels[i][disable_inds] = -1

            num_bg = cfg.TRAIN.RPN_BATCHSIZE - sum_fg[i]

            # subsample negative labels if we have too many
            if sum_bg[i] > num_bg:
                bg_inds = torch.nonzero(labels[i] == 0).view(-1)
                #rand_num = torch.randperm(bg_inds.size(0)).type_as(gt_boxes).long()

                rand_num = torch.from_numpy(
                    np.random.permutation(
                        bg_inds.size(0))).type_as(gt_boxes).long()
                disable_inds = bg_inds[rand_num[:bg_inds.size(0) - num_bg]]
                labels[i][disable_inds] = -1

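        # turn per-image argmax indices into flat indices over the whole
        # batch so gt_boxes.view(-1, 5) can be gathered in one shot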
        offset = torch.arange(0, batch_size) * gt_boxes.size(1)

        argmax_overlaps = argmax_overlaps + offset.view(
            batch_size, 1).type_as(argmax_overlaps)
        bbox_targets = _compute_targets_batch(
            anchors,
            gt_boxes.view(-1, 5)[argmax_overlaps.view(-1), :].view(
                batch_size, -1, 5))

        # use a single value instead of 4 values for easy indexing
        bbox_inside_weights[labels == 1] = cfg.TRAIN.RPN_BBOX_INSIDE_WEIGHTS[0]

        if cfg.TRAIN.RPN_POSITIVE_WEIGHT < 0:
            # uniform weighting of examples (given non-uniform sampling);
            # NOTE: `i` is reused from the loop above, so the count comes
            # from the last image in the batch
            num_examples = torch.sum(labels[i] >= 0)
            positive_weights = 1.0 / num_examples.item()
            negative_weights = 1.0 / num_examples.item()
        else:
            assert ((cfg.TRAIN.RPN_POSITIVE_WEIGHT > 0) &
                    (cfg.TRAIN.RPN_POSITIVE_WEIGHT < 1))
            # give positives a total weight of RPN_POSITIVE_WEIGHT and
            # negatives the remainder (py-faster-rcnn's weighting scheme)
            positive_weights = (cfg.TRAIN.RPN_POSITIVE_WEIGHT /
                                torch.sum(labels == 1).item())
            negative_weights = ((1.0 - cfg.TRAIN.RPN_POSITIVE_WEIGHT) /
                                torch.sum(labels == 0).item())

        bbox_outside_weights[labels == 1] = positive_weights
        bbox_outside_weights[labels == 0] = negative_weights

        labels = _unmap(labels,
                        total_anchors,
                        inds_inside,
                        batch_size,
                        fill=-1)
        bbox_targets = _unmap(bbox_targets,
                              total_anchors,
                              inds_inside,
                              batch_size,
                              fill=0)
        bbox_inside_weights = _unmap(bbox_inside_weights,
                                     total_anchors,
                                     inds_inside,
                                     batch_size,
                                     fill=0)
        bbox_outside_weights = _unmap(bbox_outside_weights,
                                      total_anchors,
                                      inds_inside,
                                      batch_size,
                                      fill=0)

        outputs = []

        # labels = labels.view(batch_size, height, width, A).permute(0,3,1,2).contiguous()
        # labels = labels.view(batch_size, 1, A * height, width)
        outputs.append(labels)
        # bbox_targets = bbox_targets.view(batch_size, height, width, A*4).permute(0,3,1,2).contiguous()
        outputs.append(bbox_targets)

        # anchors_count = bbox_inside_weights.size(1)
        # bbox_inside_weights = bbox_inside_weights.view(batch_size,anchors_count,1).expand(batch_size, anchors_count, 4)
        # bbox_inside_weights = bbox_inside_weights.contiguous().view(batch_size, height, width, 4*A)\
        # .permute(0,3,1,2).contiguous()

        outputs.append(bbox_inside_weights)

        # bbox_outside_weights = bbox_outside_weights.view(batch_size,anchors_count,1).expand(batch_size, anchors_count, 4)
        # bbox_outside_weights = bbox_outside_weights.contiguous().view(batch_size, height, width, 4*A)\
        # .permute(0,3,1,2).contiguous()
        outputs.append(bbox_outside_weights)

        return outputs
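The second snippet calls a helper `_unmap` that is not shown. Below is a minimal sketch consistent with how it is used above, assuming the usual behavior of scattering the inside-anchor subset back over the full anchor set.

# Sketch of the assumed `_unmap` helper; not confirmed to match the
# repo's exact implementation.
import torch

def _unmap(data, count, inds, batch_size, fill=0):
    # Scatter `data` (defined only for the `inds` subset of anchors) back
    # into a tensor covering all `count` anchors; positions outside `inds`
    # get `fill` (-1 for labels, 0 for targets/weights above).
    if data.dim() == 2:
        ret = data.new_full((batch_size, count), fill)
        ret[:, inds] = data
    else:
        ret = data.new_full((batch_size, count, data.size(2)), fill)
        ret[:, inds, :] = data
    return ret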