def inference_single_image(self, box_cls, box_delta, box_center, box_param,
                           shifts, image_size, fpn_levels, img_id):
    """Decode one image's per-level head outputs into post-NMS detections.

    For every feature level the classification logits are converted to
    probabilities. When ``self.thresh_with_centerness`` is set, they are
    multiplied by the centerness probability before thresholding. The
    ``min(self.topk_candidates, R)`` highest scores are kept, where R is
    ``box_reg_i.shape[0]``, and of those only scores strictly above
    ``self.score_threshold`` survive. The surviving (position, class) pairs
    are decoded into boxes with ``self.shift2box_transform.apply_deltas``.
    All levels are concatenated and filtered with ``generalized_batched_nms``
    (given the class indices and ``self.nms_type``). The result is truncated
    to ``self.max_detections_per_image``.

    The input tensors are NOT modified. Probabilities are computed
    out-of-place, because ``flatten()`` returns a view of a contiguous input
    and an in-place sigmoid would overwrite the caller's logits.

    Args:
        box_cls, box_delta, box_center, box_param, shifts, fpn_levels:
            per-level sequences iterated in lockstep. Each ``box_cls`` entry
            is flattened and decoded as ``position * num_classes + class``,
            so it is presumably (R, num_classes). The other entries are
            indexed by position along dim 0. TODO: confirm exact shapes with
            the caller.
        image_size: forwarded to ``Instances``.
        img_id: image index, written into every row of ``im_inds``.

    Returns:
        Instances with ``pred_boxes``, ``scores`` (sqrt of class probability
        times centerness probability), ``pred_classes``, ``inst_parmas``,
        ``fpn_levels``, ``shifts`` and ``im_inds``.
    """
    boxes_all = []
    scores_all = []
    class_idxs_all = []
    box_params_all = []
    shifts_all = []
    fpn_levels_all = []

    # Iterate over every feature level
    for box_cls_i, box_reg_i, box_ctr_i, box_param_i, shifts_i, fpn_level_i in zip(
            box_cls, box_delta, box_center, box_param, shifts, fpn_levels):
        # Out-of-place sigmoid: flatten() may return a view of the caller's
        # tensor, so sigmoid_() would corrupt it.
        box_cls_i = box_cls_i.flatten().sigmoid()
        if self.thresh_with_centerness:
            box_ctr_i = box_ctr_i.expand(
                (-1, self.num_classes)).flatten().sigmoid()
            box_cls_i = box_cls_i * box_ctr_i

        # Keep top k top scoring indices only. The cap is the number of
        # positions in this level, not positions * classes.
        num_topk = min(self.topk_candidates, box_reg_i.shape[0])
        # torch.sort is actually faster than .topk (at least on GPUs)
        predicted_prob, topk_idxs = box_cls_i.sort(descending=True)
        predicted_prob = predicted_prob[:num_topk]
        topk_idxs = topk_idxs[:num_topk]

        # filter out the proposals with low confidence score
        keep_idxs = predicted_prob > self.score_threshold
        predicted_prob = predicted_prob[keep_idxs]
        topk_idxs = topk_idxs[keep_idxs]

        # The flat index encodes position * num_classes + class.
        shift_idxs = topk_idxs // self.num_classes
        classes_idxs = topk_idxs % self.num_classes

        box_reg_i = box_reg_i[shift_idxs]
        shifts_i = shifts_i[shift_idxs]
        fpn_level_i = fpn_level_i[shift_idxs]
        # predict boxes
        predicted_boxes = self.shift2box_transform.apply_deltas(
            box_reg_i, shifts_i)

        if not self.thresh_with_centerness:
            # Gather the kept positions first, then apply sigmoid. This gives
            # the same values, and the caller's tensor is left untouched.
            box_ctr_i = box_ctr_i.flatten()[shift_idxs].sigmoid()
            predicted_prob = predicted_prob * box_ctr_i

        # Instance conv params for the kept positions. Use a new name so the
        # ``box_param`` argument is not shadowed.
        box_param_sel = box_param_i[shift_idxs]

        boxes_all.append(predicted_boxes)
        scores_all.append(torch.sqrt(predicted_prob))
        class_idxs_all.append(classes_idxs)
        box_params_all.append(box_param_sel)
        shifts_all.append(shifts_i)
        fpn_levels_all.append(fpn_level_i)

    boxes_all, scores_all, class_idxs_all, box_params_all, shifts_all, fpn_levels_all = [
        cat(x) for x in [
            boxes_all, scores_all, class_idxs_all, box_params_all,
            shifts_all, fpn_levels_all,
        ]
    ]

    keep = generalized_batched_nms(boxes_all, scores_all, class_idxs_all,
                                   self.nms_threshold, nms_type=self.nms_type)
    keep = keep[:self.max_detections_per_image]

    im_inds = scores_all.new_ones(len(scores_all), dtype=torch.long) * img_id

    proposals_i = Instances(image_size)
    proposals_i.pred_boxes = Boxes(boxes_all[keep])
    proposals_i.scores = scores_all[keep]
    proposals_i.pred_classes = class_idxs_all[keep]
    # NOTE: "inst_parmas" (sic) is the field name read elsewhere; keep it.
    proposals_i.inst_parmas = box_params_all[keep]
    proposals_i.fpn_levels = fpn_levels_all[keep]
    proposals_i.shifts = shifts_all[keep]
    proposals_i.im_inds = im_inds[keep]
    return proposals_i
def proposals_losses(self, gt_classes, gt_shifts_deltas, gt_centerness,
                     gt_inds, im_inds, pred_class_logits, pred_shift_deltas,
                     pred_centerness, pred_inst_params, fpn_levels, shifts):
    """Compute the proposal losses and collect foreground locations.

    Predictions are first reshaped by ``permute_all_to_N_HWA_K_and_concat``.
    Targets are flattened to one row per location. A row is "valid" when
    ``gt_classes >= 0``. It is "foreground" when it is valid and its class
    is not ``self.num_classes``.

    Losses:
        loss_cls: ``sigmoid_focal_loss_jit`` summed over valid rows, divided
            by max(1, foreground count averaged over workers).
        loss_box_reg: ``iou_loss`` ("ltrb") summed over foreground rows. The
            foreground gt centerness is passed as its third argument
            (presumably per-row weights; verify against ``iou_loss``). The
            sum is divided by max(1e-6, summed foreground centerness
            averaged over workers).
        loss_centerness: BCE-with-logits summed over foreground rows, divided
            by max(1, foreground count averaged over workers).

    Returns:
        (losses, proposals). ``losses`` is a dict with keys "loss_cls",
        "loss_box_reg" and "loss_centerness". ``proposals`` is an
        ``Instances((0, 0))`` holding ``inst_parmas``, ``fpn_levels``,
        ``shifts``, ``gt_inds`` and ``im_inds`` of the foreground rows. When
        it is non-empty and ``self.topk_proposals_per_im != -1``, it also
        gets ``gt_cls``, ``pred_logits`` and ``pred_centerness`` and is
        reduced with ``self.select_instances`` to save memory.
    """
    def _average_over_workers(value):
        # Sum across processes, then divide by the process count.
        return comm.all_reduce(value) / float(comm.get_world_size())

    (pred_class_logits, pred_shift_deltas, pred_centerness,
     pred_inst_params) = permute_all_to_N_HWA_K_and_concat(
        pred_class_logits, pred_shift_deltas, pred_centerness,
        pred_inst_params, self.num_gen_params, self.num_classes)

    # One row per location, e.g. (N x R, K) logits vs (N x R, 4) deltas.
    gt_classes = gt_classes.flatten()
    gt_shifts_deltas = gt_shifts_deltas.reshape(-1, 4)
    gt_centerness = gt_centerness.reshape(-1, 1)
    fpn_levels = fpn_levels.flatten()
    im_inds = im_inds.flatten()
    gt_inds = gt_inds.flatten()

    valid_mask = gt_classes >= 0
    fg_mask = valid_mask & (gt_classes != self.num_classes)
    fg_count = fg_mask.sum()

    # One-hot class targets; non-foreground rows stay all zero.
    cls_targets = torch.zeros_like(pred_class_logits)
    cls_targets[fg_mask, gt_classes[fg_mask]] = 1

    fg_count = _average_over_workers(fg_count)
    fg_centerness = gt_centerness[fg_mask]
    centerness_norm = _average_over_workers(fg_centerness.sum())

    # logits loss
    loss_cls = sigmoid_focal_loss_jit(
        pred_class_logits[valid_mask],
        cls_targets[valid_mask],
        alpha=self.focal_loss_alpha,
        gamma=self.focal_loss_gamma,
        reduction="sum",
    ) / max(1.0, fg_count)

    # regression loss
    loss_box_reg = iou_loss(
        pred_shift_deltas[fg_mask],
        gt_shifts_deltas[fg_mask],
        fg_centerness,
        box_mode="ltrb",
        loss_type=self.iou_loss_type,
        reduction="sum",
        smooth=self.iou_smooth,
    ) / max(1e-6, centerness_norm)

    # centerness loss
    loss_centerness = F.binary_cross_entropy_with_logits(
        pred_centerness[fg_mask],
        fg_centerness,
        reduction="sum",
    ) / max(1, fg_count)

    losses = {
        "loss_cls": loss_cls,
        "loss_box_reg": loss_box_reg,
        "loss_centerness": loss_centerness,
    }

    # ``shifts`` is a sequence of sequences of tensors, concatenated
    # inner-first. Presumably per-image lists of per-level shifts, matching
    # the row order of the flattened predictions; TODO confirm with caller.
    all_shifts = torch.cat([torch.cat(per_image) for per_image in shifts])

    proposals = Instances((0, 0))
    # NOTE: "inst_parmas" (sic) is the field name read elsewhere; keep it.
    proposals.inst_parmas = pred_inst_params[fg_mask]
    proposals.fpn_levels = fpn_levels[fg_mask]
    proposals.shifts = all_shifts[fg_mask]
    proposals.gt_inds = gt_inds[fg_mask]
    proposals.im_inds = im_inds[fg_mask]

    # select_instances for saving memory
    if len(proposals) and self.topk_proposals_per_im != -1:
        proposals.gt_cls = gt_classes[fg_mask]
        proposals.pred_logits = pred_class_logits[fg_mask]
        proposals.pred_centerness = pred_centerness[fg_mask]
        proposals = self.select_instances(proposals)

    return losses, proposals