Example #1
    # Relies on names imported elsewhere in the file: torch, cat,
    # generalized_batched_nms, and the Boxes/Instances structures.
    def inference_single_image(self, box_cls, box_delta, box_center, box_param,
                               shifts, image_size, fpn_levels, img_id):
        boxes_all = []
        scores_all = []
        class_idxs_all = []
        box_params_all = []
        shifts_all = []
        fpn_levels_all = []
        # Iterate over every feature level
        for box_cls_i, box_reg_i, box_ctr_i, box_param_i, shifts_i, fpn_level_i in zip(
                box_cls, box_delta, box_center, box_param, shifts, fpn_levels):

            box_cls_i = box_cls_i.flatten().sigmoid_()
            if self.thresh_with_centerness:
                box_ctr_i = box_ctr_i.expand(
                    (-1, self.num_classes)).flatten().sigmoid()
                box_cls_i = box_cls_i * box_ctr_i

            # Keep only the top-k highest-scoring candidates.
            num_topk = min(self.topk_candidates, box_reg_i.shape[0])
            # torch.sort is actually faster than .topk (at least on GPUs)
            predicted_prob, topk_idxs = box_cls_i.sort(descending=True)
            predicted_prob = predicted_prob[:num_topk]
            topk_idxs = topk_idxs[:num_topk]

            # Filter out candidates below the confidence threshold (applied after top-k).
            keep_idxs = predicted_prob > self.score_threshold
            predicted_prob = predicted_prob[keep_idxs]
            topk_idxs = topk_idxs[keep_idxs]

            # Undo the flattening: recover the shift (location) index and the class index.
            shift_idxs = topk_idxs // self.num_classes
            classes_idxs = topk_idxs % self.num_classes

            box_reg_i = box_reg_i[shift_idxs]
            shifts_i = shifts_i[shift_idxs]
            fpn_level_i = fpn_level_i[shift_idxs]
            # Decode boxes from the regression deltas at each shift location.
            predicted_boxes = self.shift2box_transform.apply_deltas(
                box_reg_i, shifts_i)

            if not self.thresh_with_centerness:
                # Centerness was not folded in before top-k; apply it to the scores now.
                box_ctr_i = box_ctr_i.flatten().sigmoid_()[shift_idxs]
                predicted_prob = predicted_prob * box_ctr_i

            # Instance-specific conv parameters for the predicted boxes.
            box_param_i = box_param_i[shift_idxs]

            boxes_all.append(predicted_boxes)
            # sqrt recovers the geometric mean of the classification and centerness scores.
            scores_all.append(torch.sqrt(predicted_prob))
            class_idxs_all.append(classes_idxs)
            box_params_all.append(box_param_i)
            shifts_all.append(shifts_i)
            fpn_levels_all.append(fpn_level_i)

        boxes_all, scores_all, class_idxs_all, box_params_all, shifts_all, fpn_levels_all = [
            cat(x) for x in [
                boxes_all, scores_all, class_idxs_all, box_params_all,
                shifts_all, fpn_levels_all
            ]
        ]

        # Class-aware NMS across all levels; keep at most max_detections_per_image.
        keep = generalized_batched_nms(boxes_all,
                                       scores_all,
                                       class_idxs_all,
                                       self.nms_threshold,
                                       nms_type=self.nms_type)
        keep = keep[:self.max_detections_per_image]

        im_inds = scores_all.new_ones(len(scores_all),
                                      dtype=torch.long) * img_id
        proposals_i = Instances(image_size)
        proposals_i.pred_boxes = Boxes(boxes_all[keep])
        proposals_i.scores = scores_all[keep]
        proposals_i.pred_classes = class_idxs_all[keep]

        proposals_i.inst_parmas = box_params_all[keep]
        proposals_i.fpn_levels = fpn_levels_all[keep]
        proposals_i.shifts = shifts_all[keep]
        proposals_i.im_inds = im_inds[keep]

        return proposals_i
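
The decode step above relies on flattening the (shift, class) score matrix before sorting, then recovering both indices with integer division and modulo. Below is a minimal standalone sketch of that trick; num_shifts, num_classes, and topk are illustrative assumptions, not values from the original class.

import torch

# Illustrative sizes only (assumptions, not values from the original class).
num_shifts, num_classes, topk = 8, 4, 5

scores = torch.rand(num_shifts, num_classes).flatten()  # (num_shifts * num_classes,)

# Sort once and slice; equivalent to .topk() and often faster on GPU.
sorted_scores, sorted_idxs = scores.sort(descending=True)
topk_scores, topk_idxs = sorted_scores[:topk], sorted_idxs[:topk]

# Undo the flattening: row = shift index, column = class index.
shift_idxs = topk_idxs // num_classes
class_idxs = topk_idxs % num_classes

# Re-deriving the flat index reproduces exactly the selected scores.
assert torch.equal(scores[shift_idxs * num_classes + class_idxs], topk_scores)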
Example #2
    # Relies on names imported elsewhere in the file: torch, F
    # (torch.nn.functional), comm, permute_all_to_N_HWA_K_and_concat,
    # sigmoid_focal_loss_jit, iou_loss, and Instances.
    def proposals_losses(self, gt_classes, gt_shifts_deltas, gt_centerness,
                         gt_inds, im_inds, pred_class_logits,
                         pred_shift_deltas, pred_centerness, pred_inst_params,
                         fpn_levels, shifts):

        pred_class_logits, pred_shift_deltas, pred_centerness, pred_inst_params = \
            permute_all_to_N_HWA_K_and_concat(
                pred_class_logits, pred_shift_deltas, pred_centerness,
                pred_inst_params, self.num_gen_params, self.num_classes
            )  # Shapes: (N x R, K) and (N x R, 4), respectively.

        # Flatten per-image targets to match the concatenated predictions.
        gt_classes = gt_classes.flatten()
        gt_shifts_deltas = gt_shifts_deltas.reshape(-1, 4)
        gt_centerness = gt_centerness.reshape(-1, 1)
        fpn_levels = fpn_levels.flatten()
        im_inds = im_inds.flatten()
        gt_inds = gt_inds.flatten()

        # "valid" excludes ignored shifts; "foreground" are shifts matched to a gt box.
        valid_idxs = gt_classes >= 0
        foreground_idxs = (gt_classes >= 0) & (gt_classes != self.num_classes)
        num_foreground = foreground_idxs.sum()

        # Build one-hot classification targets for the foreground shifts.
        gt_classes_target = torch.zeros_like(pred_class_logits)
        gt_classes_target[foreground_idxs, gt_classes[foreground_idxs]] = 1

        # Average foreground counts across workers so every GPU normalizes identically.
        num_foreground = comm.all_reduce(num_foreground) / float(
            comm.get_world_size())
        num_foreground_centerness = gt_centerness[foreground_idxs].sum()
        num_targets = comm.all_reduce(num_foreground_centerness) / float(
            comm.get_world_size())

        # classification (focal) loss
        loss_cls = sigmoid_focal_loss_jit(
            pred_class_logits[valid_idxs],
            gt_classes_target[valid_idxs],
            alpha=self.focal_loss_alpha,
            gamma=self.focal_loss_gamma,
            reduction="sum",
        ) / max(1.0, num_foreground)

        # regression loss
        loss_box_reg = iou_loss(pred_shift_deltas[foreground_idxs],
                                gt_shifts_deltas[foreground_idxs],
                                gt_centerness[foreground_idxs],
                                box_mode="ltrb",
                                loss_type=self.iou_loss_type,
                                reduction="sum",
                                smooth=self.iou_smooth) / max(
                                    1e-6, num_targets)

        # centerness loss
        loss_centerness = F.binary_cross_entropy_with_logits(
            pred_centerness[foreground_idxs],
            gt_centerness[foreground_idxs],
            reduction="sum",
        ) / max(1.0, num_foreground)

        proposals_losses = {
            "loss_cls": loss_cls,
            "loss_box_reg": loss_box_reg,
            "loss_centerness": loss_centerness
        }

        # Concatenate per-level shifts within each image, then across the batch.
        all_shifts = torch.cat([torch.cat(shift) for shift in shifts])
        proposals = Instances((0, 0))
        proposals.inst_parmas = pred_inst_params[foreground_idxs]
        proposals.fpn_levels = fpn_levels[foreground_idxs]
        proposals.shifts = all_shifts[foreground_idxs]
        proposals.gt_inds = gt_inds[foreground_idxs]
        proposals.im_inds = im_inds[foreground_idxs]

        # Subsample instances via select_instances to save memory.
        if len(proposals):
            if self.topk_proposals_per_im != -1:
                proposals.gt_cls = gt_classes[foreground_idxs]
                proposals.pred_logits = pred_class_logits[foreground_idxs]
                proposals.pred_centerness = pred_centerness[foreground_idxs]
            proposals = self.select_instances(proposals)

        return proposals_losses, proposals
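
The classification term above is a sum-reduced sigmoid focal loss divided by the foreground count averaged over all workers, so every GPU uses the same denominator. Below is a single-process sketch of that normalization; the focal loss is a plain re-implementation standing in for sigmoid_focal_loss_jit, and all sizes are invented for illustration.

import torch
import torch.nn.functional as F

def sigmoid_focal_loss(logits, targets, alpha=0.25, gamma=2.0):
    # Element-wise binary focal loss; the caller sums and normalizes.
    p = torch.sigmoid(logits)
    ce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
    p_t = p * targets + (1 - p) * (1 - targets)
    alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
    return alpha_t * (1 - p_t) ** gamma * ce

logits = torch.randn(10, 4)       # (R, K) class logits for one worker
targets = torch.zeros_like(logits)
targets[:3, 1] = 1.0              # pretend three shifts are foreground

# In the multi-GPU version this count is all_reduce'd and divided by the world size.
num_foreground = targets.sum()
loss_cls = sigmoid_focal_loss(logits, targets).sum() / num_foreground.clamp(min=1.0)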