예제 #1
    def single_fmps_process(self, location, pred_cls, pred_box,
                            pred_centerness, image_sizes):
        """Decode the predictions of one feature map into detections.

        Class logits go through a sigmoid, are weighted by the sigmoid of
        the centerness logits, and are thresholded with
        ``self.conf_thresh``. For each batch element at most
        ``self.nms_thresh_topN`` candidates are kept.

        Args:
            location: per-position reference points, indexed by flattened
                spatial position; columns 0 and 1 are used (presumably
                x and y -- confirm against the caller).
            pred_cls: class logits, unpacked here as ``(B, H, W, C)``.
            pred_box: box regression output, viewed as ``(B, H*W, 4)``.
                Columns 0/1 are subtracted from and columns 2/3 added to
                the reference point to form the box corners.
            pred_centerness: centerness logits, viewed as ``(B, H*W)``.
            image_sizes: sequence whose first entry unpacks to ``(h, w)``.

        Returns:
            list: one ``BoxList`` in ``'xyxy'`` mode per batch element,
            with fields ``'labels'`` (class index + 1) and ``'scores'``.
        """
        B, _, _, C = pred_cls.shape

        pred_cls = pred_cls.view(B, -1, C).sigmoid()
        pred_box = pred_box.view(B, -1, 4)
        pred_centerness = pred_centerness.view(B, -1).sigmoid()

        # multiply the classification scores with centerness scores
        pred_cls = pred_cls * pred_centerness[:, :, None]

        # Candidates are (position, class) pairs whose weighted score
        # exceeds the threshold; the per-image count is capped at top-N.
        cls_mask = pred_cls > self.conf_thresh
        cls_mask_top_n = cls_mask.view(B, -1).sum(1)
        cls_mask_top_n = cls_mask_top_n.clamp(max=self.nms_thresh_topN)

        res = []
        for b in range(B):
            per_cls_mask = cls_mask[b]
            per_cls = pred_cls[b][per_cls_mask]

            # Each row of nonzero() is (flattened position, class channel).
            per_cls_mask_nonzeros = per_cls_mask.nonzero()
            per_box_loc = per_cls_mask_nonzeros[:, 0]
            per_box_cls = per_cls_mask_nonzeros[:, 1] + 1  # class index

            per_box = pred_box[b][per_box_loc]
            per_location = location[per_box_loc]
            per_cls_mask_top_n = cls_mask_top_n[b]

            # Only when more candidates passed the threshold than the cap
            # allows do we need to select the best-scoring subset.
            if per_cls_mask.sum().item() > per_cls_mask_top_n.item():
                per_cls, top_k_idx = per_cls.topk(per_cls_mask_top_n,
                                                  sorted=False)
                per_box_cls = per_box_cls[top_k_idx]
                per_box = per_box[top_k_idx]
                per_location = per_location[top_k_idx]

            # Turn (reference point, four offsets) into corner coordinates.
            detections = torch.stack([
                per_location[:, 0] - per_box[:, 0],
                per_location[:, 1] - per_box[:, 1],
                per_location[:, 0] + per_box[:, 2],
                per_location[:, 1] + per_box[:, 3],
            ], dim=1)

            # NOTE(review): the size of the first image is used for every
            # batch element; confirm all images share one size, otherwise
            # this should index with ``b``.
            h, w = image_sizes[0]
            box_list = BoxList(detections, (w, h), mode='xyxy')
            box_list.add_field('labels', per_box_cls)
            box_list.add_field('scores', per_cls)
            box_list = box_list.clip_to_image(remove_empty=False)
            res.append(box_list)

        return res
def select_layers(loc, cls, box, centerness, img_size):
    """Decode the predictions of one feature level into detections.

    Candidates are the (position, class) pairs whose sigmoid class score
    exceeds ``PRE_NMS_THRESH``. The reported score is that class score
    multiplied by the sigmoid centerness. For each image at most
    ``PRE_NMS_TOK_K`` candidates are kept.

    Args:
        loc: per-position reference points, indexed by flattened spatial
            position; columns 0 and 1 are used (presumably x and y --
            confirm against the caller).
        cls: class logits, unpacked here as
            ``(batch, channel, height, width)``.
        box: box regression output, moved to channels-last and reshaped
            to ``(batch, -1, 4)``. Columns 0/1 are subtracted from and
            columns 2/3 added to the reference point.
        centerness: centerness logits, moved to channels-last and
            reshaped to ``(batch, -1)``.
        img_size: per-image sizes; ``img_size[i]`` unpacks to ``(h, w)``.

    Returns:
        list: one ``BoxList`` in ``'xyxy'`` mode per image, with fields
        ``"labels"`` (class index + 1) and ``"scores"``.
    """
    batch, channel, height, width = cls.shape
    cls = cls.permute(0, 2, 3, 1)
    cls = cls.reshape(batch, -1, channel).sigmoid()
    box = box.permute(0, 2, 3, 1)
    box = box.reshape(batch, -1, 4)
    centerness = centerness.permute(0, 2, 3, 1)
    centerness = centerness.reshape(batch, -1).sigmoid()

    # The threshold is applied to the raw class score, before the
    # centerness weighting below.
    select_id = cls > PRE_NMS_THRESH
    num_id = select_id.view(batch, -1).sum(1)
    num_id = num_id.clamp(max=PRE_NMS_TOK_K)

    cls = cls * centerness[:, :, None]

    res = []

    for i in range(batch):
        select_id_i = select_id[i]
        cls_i = cls[i][select_id_i]

        # Each row of nonzero() is (flattened position, class channel).
        per_candidate_nonzeros = select_id_i.nonzero()
        loc_i = per_candidate_nonzeros[:, 0]
        per_class = per_candidate_nonzeros[:, 1] + 1

        box_i = box[i][loc_i]
        per_locs = loc[loc_i]

        tok_k = num_id[i]

        # Only when more candidates passed the threshold than the cap
        # allows do we need to select the best-scoring subset.
        if select_id_i.sum().item() > tok_k.item():
            cls_i, top_k_index = cls_i.topk(tok_k, sorted=False)
            per_class = per_class[top_k_index]
            box_i = box_i[top_k_index]
            per_locs = per_locs[top_k_index]

        # Turn (reference point, four offsets) into corner coordinates.
        l_ = per_locs[:, 0] - box_i[:, 0]
        t_ = per_locs[:, 1] - box_i[:, 1]
        r_ = per_locs[:, 0] + box_i[:, 2]
        b_ = per_locs[:, 1] + box_i[:, 3]

        regs = torch.stack([l_, t_, r_, b_], dim=1)
        h, w = img_size[i]
        boxlist = BoxList(regs, (int(w), int(h)), mode='xyxy')
        boxlist.add_field("labels", per_class)
        boxlist.add_field("scores", cls_i)
        boxlist = boxlist.clip_to_image(remove_empty=False)
        # NOTE: no minimum-size filtering of boxes is applied here.
        res.append(boxlist)

    return res