def aux_losses(self, gt_classes, pred_class_logits):
    pred_class_logits = cat([
        permute_to_N_HWA_K(x, self.num_classes) for x in pred_class_logits
    ], dim=1).view(-1, self.num_classes)

    gt_classes = gt_classes.flatten()

    valid_idxs = gt_classes >= 0
    foreground_idxs = (gt_classes >= 0) & (gt_classes != self.num_classes)
    num_foreground = foreground_idxs.sum()

    gt_classes_target = torch.zeros_like(pred_class_logits)
    gt_classes_target[foreground_idxs, gt_classes[foreground_idxs]] = 1

    num_foreground = comm.all_reduce(num_foreground) / float(comm.get_world_size())

    # logits loss
    loss_cls_aux = sigmoid_focal_loss_jit(
        pred_class_logits[valid_idxs],
        gt_classes_target[valid_idxs],
        alpha=self.focal_loss_alpha,
        gamma=self.focal_loss_gamma,
        reduction="sum",
    ) / max(1.0, num_foreground)

    return {"loss_cls_aux": loss_cls_aux}
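
# NOTE: `permute_to_N_HWA_K` and `cat` are used throughout this section but not
# shown here. Below is a minimal sketch consistent with the detectron2/cvpods
# convention — an assumption, not necessarily this repo's exact code.
def permute_to_N_HWA_K(tensor, K):
    """Transpose/reshape a (N, AxK, H, W) prediction map into (N, HxWxA, K)."""
    assert tensor.dim() == 4, tensor.shape
    N, _, H, W = tensor.shape
    tensor = tensor.view(N, -1, K, H, W)    # (N, A, K, H, W)
    tensor = tensor.permute(0, 3, 4, 1, 2)  # (N, H, W, A, K)
    return tensor.reshape(N, -1, K)         # (N, HxWxA, K)
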
def losses(self, gt_classes, gt_shifts_deltas, pred_class_logits,
           pred_shift_deltas):
    """
    Args:
        For `gt_classes` and `gt_shifts_deltas` parameters, see
            :meth:`FCOS.get_ground_truth`.
        Their shapes are (N, R) and (N, R, 4), respectively, where R is
        the total number of shifts across levels, i.e. sum(Hi x Wi)
        For `pred_class_logits` and `pred_shift_deltas`, see
            :meth:`FCOSHead.forward`.

    Returns:
        dict[str: Tensor]:
            mapping from a named loss to a scalar tensor storing the loss.
            Used during training only. The dict keys are:
            "loss_cls" and "loss_box_reg"
    """
    pred_class_logits, pred_shift_deltas = \
        permute_all_cls_and_box_to_N_HWA_K_and_concat(
            pred_class_logits, pred_shift_deltas, self.num_classes
        )  # Shapes: (N x R, K) and (N x R, 4), respectively.

    gt_classes = gt_classes.flatten()
    gt_shifts_deltas = gt_shifts_deltas.view(-1, 4)

    valid_idxs = gt_classes >= 0
    foreground_idxs = (gt_classes >= 0) & (gt_classes != self.num_classes)
    num_foreground = foreground_idxs.sum()

    gt_classes_target = torch.zeros_like(pred_class_logits)
    gt_classes_target[foreground_idxs, gt_classes[foreground_idxs]] = 1

    num_foreground = comm.all_reduce(num_foreground) / float(comm.get_world_size())

    # logits loss
    loss_cls = sigmoid_focal_loss_jit(
        pred_class_logits[valid_idxs],
        gt_classes_target[valid_idxs],
        alpha=self.focal_loss_alpha,
        gamma=self.focal_loss_gamma,
        reduction="sum",
    ) / max(1.0, num_foreground)

    # regression loss
    loss_box_reg = iou_loss(
        pred_shift_deltas[foreground_idxs],
        gt_shifts_deltas[foreground_idxs],
        box_mode="ltrb",
        loss_type=self.iou_loss_type,
        reduction="sum",
    ) / max(1.0, num_foreground) * self.reg_weight

    return {
        "loss_cls": loss_cls,
        "loss_box_reg": loss_box_reg,
    }
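
# NOTE: `permute_all_cls_and_box_to_N_HWA_K_and_concat` is also not shown in
# this section. A minimal sketch following the detectron2 RetinaNet helper of
# the same name (assumed; the repo's version may differ in small details):
def permute_all_cls_and_box_to_N_HWA_K_and_concat(box_cls, box_delta, num_classes=80):
    """Flatten per-level predictions into (N*R, K) and (N*R, 4)."""
    box_cls_flattened = [permute_to_N_HWA_K(x, num_classes) for x in box_cls]
    box_delta_flattened = [permute_to_N_HWA_K(x, 4) for x in box_delta]
    # Concatenate along the R dimension first so the flattened layout matches
    # the ground-truth tensors built from (N, R) targets.
    box_cls = torch.cat(box_cls_flattened, dim=1).view(-1, num_classes)
    box_delta = torch.cat(box_delta_flattened, dim=1).view(-1, 4)
    return box_cls, box_delta
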
def losses(self, gt_classes, gt_anchors_deltas, pred_class_logits,
           pred_anchor_deltas):
    """
    Args:
        For `gt_classes` and `gt_anchors_deltas` parameters, see
            :meth:`RetinaNet.get_ground_truth`.
        Their shapes are (N, R) and (N, R, 4), respectively, where R is
        the total number of anchors across levels, i.e. sum(Hi x Wi x A)
        For `pred_class_logits` and `pred_anchor_deltas`, see
            :meth:`RetinaNetHead.forward`.

    Returns:
        dict[str: Tensor]:
            mapping from a named loss to a scalar tensor storing the loss.
            Used during training only. The dict keys are:
            "loss_cls" and "loss_box_reg"
    """
    pred_class_logits, pred_anchor_deltas = permute_all_cls_and_box_to_N_HWA_K_and_concat(
        pred_class_logits, pred_anchor_deltas, self.num_classes
    )  # Shapes: (N x R, K) and (N x R, 4), respectively.

    gt_classes = gt_classes.flatten()
    gt_anchors_deltas = gt_anchors_deltas.view(-1, 4)

    valid_idxs = gt_classes >= 0
    foreground_idxs = (gt_classes >= 0) & (gt_classes != self.num_classes)
    num_foreground = foreground_idxs.sum()

    gt_classes_target = torch.zeros_like(pred_class_logits)
    gt_classes_target[foreground_idxs, gt_classes[foreground_idxs]] = 1

    self.loss_normalizer = self.loss_normalizer_momentum * self.loss_normalizer + (
        1 - self.loss_normalizer_momentum) * max(num_foreground.item(), 1)

    # logits loss
    loss_cls = sigmoid_focal_loss_jit(
        pred_class_logits[valid_idxs],
        gt_classes_target[valid_idxs],
        alpha=self.focal_loss_alpha,
        gamma=self.focal_loss_gamma,
        reduction="sum",
    ) / self.loss_normalizer

    # regression loss
    loss_box_reg = smooth_l1_loss(
        pred_anchor_deltas[foreground_idxs],
        gt_anchors_deltas[foreground_idxs],
        beta=self.smooth_l1_loss_beta,
        reduction="sum",
    ) / self.loss_normalizer

    return {"loss_cls": loss_cls, "loss_box_reg": loss_box_reg}
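
# NOTE: unlike the other heads in this section, RetinaNet normalizes by an
# exponential moving average of the foreground count rather than the raw
# per-batch count, which keeps the loss scale stable on batches with few or no
# positives. A toy illustration with hypothetical values:
#
#   loss_normalizer, momentum = 100.0, 0.9
#   for num_foreground in [120, 80, 0, 95]:   # per-iteration positive counts
#       loss_normalizer = momentum * loss_normalizer \
#           + (1 - momentum) * max(num_foreground, 1)
#   # -> 102.0, 99.8, 89.92, 90.43: the normalizer drifts toward the running
#   # average instead of collapsing to 1 on the empty third batch.
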
def losses(self, gt_classes, gt_anchors_deltas, pred_class_logits,
           pred_anchor_deltas):
    """
    Args:
        For `gt_classes` and `gt_anchors_deltas` parameters, see
            :meth:`EfficientDet.get_ground_truth`.
        Their shapes are (N, R) and (N, R, 4), respectively, where R is
        the total number of anchors across levels, i.e. sum(Hi x Wi x A)
        For `pred_class_logits` and `pred_anchor_deltas`, see
            :meth:`EfficientDetHead.forward`.

    Returns:
        dict[str: Tensor]:
            mapping from a named loss to a scalar tensor storing the loss.
            Used during training only. The dict keys are:
            "loss_cls" and "loss_box_reg"
    """
    pred_class_logits, pred_anchor_deltas = permute_all_cls_and_box_to_N_HWA_K_and_concat(
        pred_class_logits, pred_anchor_deltas, self.num_classes
    )  # Shapes: (N x R, K) and (N x R, 4), respectively.

    gt_classes = gt_classes.flatten()
    gt_anchors_deltas = gt_anchors_deltas.view(-1, 4)

    valid_idxs = gt_classes >= 0
    foreground_idxs = (gt_classes >= 0) & (gt_classes != self.num_classes)
    num_foreground = foreground_idxs.sum()

    gt_classes_target = torch.zeros_like(pred_class_logits)
    gt_classes_target[foreground_idxs, gt_classes[foreground_idxs]] = 1

    # Classification loss
    loss_cls = sigmoid_focal_loss_jit(
        pred_class_logits[valid_idxs],
        gt_classes_target[valid_idxs],
        alpha=self.focal_loss_alpha,
        gamma=self.focal_loss_gamma,
        reduction="sum",
    ) / max(1, num_foreground)

    # Regression loss, refer to the official released code.
    # See: https://github.com/google/automl/blob/master/efficientdet/det_model_fn.py
    loss_box_reg = self.box_loss_weight * self.smooth_l1_loss_beta * smooth_l1_loss(
        pred_anchor_deltas[foreground_idxs],
        gt_anchors_deltas[foreground_idxs],
        beta=self.smooth_l1_loss_beta,
        reduction="sum",
    ) / max(1, num_foreground * self.regress_norm)

    return {"loss_cls": loss_cls, "loss_box_reg": loss_box_reg}
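
# NOTE: `sigmoid_focal_loss_jit`, used by every head in this section, is the
# TorchScript-compiled focal loss. A minimal sketch of the underlying function,
# matching the fvcore/detectron2 definition (assumed to be what this repo
# compiles; shown here for reference):
import torch
import torch.nn.functional as F

def sigmoid_focal_loss(inputs, targets, alpha=-1, gamma=2, reduction="none"):
    p = torch.sigmoid(inputs)
    ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
    p_t = p * targets + (1 - p) * (1 - targets)  # prob assigned to the true class
    loss = ce_loss * ((1 - p_t) ** gamma)        # down-weight easy examples
    if alpha >= 0:
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        loss = alpha_t * loss
    if reduction == "mean":
        loss = loss.mean()
    elif reduction == "sum":
        loss = loss.sum()
    return loss
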
def loss_labels(self, outputs, targets, indices, num_boxes, log=True):
    """Classification loss (NLL)
    targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
    """
    assert "pred_logits" in outputs
    src_logits = outputs["pred_logits"]

    idx = self._get_src_permutation_idx(indices)
    target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)])
    target_classes = torch.full(
        src_logits.shape[:2], self.num_classes, dtype=torch.int64, device=src_logits.device
    )
    target_classes[idx] = target_classes_o

    if self.use_focal:
        flat_src_logits = src_logits.flatten(0, 1)
        target_classes = target_classes.flatten(0, 1)
        pos_inds = torch.nonzero(target_classes != self.num_classes, as_tuple=True)[0]
        labels = torch.zeros_like(flat_src_logits)
        labels[pos_inds, target_classes[pos_inds]] = 1
        # comp focal loss.
        class_loss = sigmoid_focal_loss_jit(
            flat_src_logits,
            labels,
            alpha=self.focal_loss_alpha,
            gamma=self.focal_loss_gamma,
            reduction="sum",
        ) / num_boxes
        losses = {"loss_ce": class_loss}
    else:
        loss_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes, self.empty_weight)
        losses = {"loss_ce": loss_ce}

    if log:
        losses["class_error"] = 100 - accuracy(src_logits[idx], target_classes_o)[0]
    return losses
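
# NOTE: `_get_src_permutation_idx` is not shown here; a sketch mirroring the
# DETR `SetCriterion` helper of the same name (assumed to match this repo):
def _get_src_permutation_idx(self, indices):
    # Flatten the per-image (src, tgt) matchings into (batch_idx, query_idx)
    # pairs usable for advanced indexing into (N, num_queries, ...) tensors.
    batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
    src_idx = torch.cat([src for (src, _) in indices])
    return batch_idx, src_idx
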
def losses(
    self,
    gt_classes,
    gt_shifts_deltas,
    gt_centerness,
    gt_classes_border,
    gt_deltas_border,
    pred_class_logits,
    pred_shift_deltas,
    pred_centerness,
    border_box_cls,
    border_bbox_reg,
):
    """
    Args:
        For `gt_classes`, `gt_shifts_deltas` and `gt_centerness` parameters, see
            :meth:`BorderDet.get_ground_truth`.
        Their shapes are (N, R) and (N, R, 4), respectively, where R is
        the total number of shifts across levels, i.e. sum(Hi x Wi)
        For `pred_class_logits`, `pred_shift_deltas` and `pred_centerness`, see
            :meth:`BorderHead.forward`.

    Returns:
        dict[str: Tensor]:
            mapping from a named loss to a scalar tensor storing the loss.
            Used during training only. The dict keys are:
            "loss_cls" and "loss_box_reg"
    """
    (
        pred_class_logits,
        pred_shift_deltas,
        pred_centerness,
        border_class_logits,
        border_shift_deltas,
    ) = permute_all_cls_and_box_to_N_HWA_K_and_concat(
        pred_class_logits, pred_shift_deltas, pred_centerness,
        border_box_cls, border_bbox_reg, self.num_classes
    )  # Shapes: (N x R, K) and (N x R, 4), respectively.

    # fcos
    gt_classes = gt_classes.flatten()
    gt_shifts_deltas = gt_shifts_deltas.view(-1, 4)
    gt_centerness = gt_centerness.view(-1, 1)

    valid_idxs = gt_classes >= 0
    foreground_idxs = (gt_classes >= 0) & (gt_classes != self.num_classes)
    num_foreground = foreground_idxs.sum()
    acc_centerness_num = gt_centerness[foreground_idxs].sum()

    gt_classes_target = torch.zeros_like(pred_class_logits)
    gt_classes_target[foreground_idxs, gt_classes[foreground_idxs]] = 1

    dist.all_reduce(num_foreground)
    num_foreground /= dist.get_world_size()
    dist.all_reduce(acc_centerness_num)
    acc_centerness_num /= dist.get_world_size()

    # logits loss
    loss_cls = sigmoid_focal_loss_jit(
        pred_class_logits[valid_idxs],
        gt_classes_target[valid_idxs],
        alpha=self.focal_loss_alpha,
        gamma=self.focal_loss_gamma,
        reduction="sum",
    ) / max(1, num_foreground)

    # regression loss
    loss_box_reg = iou_loss(
        pred_shift_deltas[foreground_idxs],
        gt_shifts_deltas[foreground_idxs],
        gt_centerness[foreground_idxs],
        box_mode="ltrb",
        loss_type=self.iou_loss_type,
        reduction="sum",
    ) / max(1, acc_centerness_num)

    # centerness loss
    loss_centerness = F.binary_cross_entropy_with_logits(
        pred_centerness[foreground_idxs],
        gt_centerness[foreground_idxs],
        reduction="sum",
    ) / max(1, num_foreground)

    # borderdet
    gt_classes_border = gt_classes_border.flatten()
    gt_deltas_border = gt_deltas_border.view(-1, 4)

    valid_idxs_border = gt_classes_border >= 0
    foreground_idxs_border = (gt_classes_border >= 0) & (
        gt_classes_border != self.num_classes)
    num_foreground_border = foreground_idxs_border.sum()

    gt_classes_border_target = torch.zeros_like(border_class_logits)
    gt_classes_border_target[
        foreground_idxs_border, gt_classes_border[foreground_idxs_border]] = 1

    dist.all_reduce(num_foreground_border)
    num_foreground_border /= dist.get_world_size()
    num_foreground_border = max(num_foreground_border, 1.0)

    loss_border_cls = sigmoid_focal_loss_jit(
        border_class_logits[valid_idxs_border],
        gt_classes_border_target[valid_idxs_border],
        alpha=self.focal_loss_alpha,
        gamma=self.focal_loss_gamma,
        reduction="sum",
    ) / num_foreground_border

    if foreground_idxs_border.numel() > 0:
        loss_border_reg = (
            smooth_l1_loss(
                border_shift_deltas[foreground_idxs_border],
                gt_deltas_border[foreground_idxs_border],
                beta=0,
                reduction="sum",
            ) / num_foreground_border
        )
    else:
        loss_border_reg = border_shift_deltas.sum()

    return {
        "loss_cls": loss_cls,
        "loss_box_reg": loss_box_reg,
        "loss_centerness": loss_centerness,
        "loss_border_cls": loss_border_cls,
        "loss_border_reg": loss_border_reg,
    }
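
# NOTE: `iou_loss` in "ltrb" mode consumes per-location distances to the four
# box sides rather than corner coordinates. A minimal unweighted sketch of that
# computation (hypothetical helper; the repo's `iou_loss` additionally supports
# weights such as the centerness above, other loss types, and reductions):
def iou_loss_ltrb(pred, target, eps=1e-7):
    # pred/target: (n, 4) distances (l, t, r, b) from a location to box sides.
    pred_area = (pred[:, 0] + pred[:, 2]) * (pred[:, 1] + pred[:, 3])
    target_area = (target[:, 0] + target[:, 2]) * (target[:, 1] + target[:, 3])
    w_inter = torch.min(pred[:, 0], target[:, 0]) + torch.min(pred[:, 2], target[:, 2])
    h_inter = torch.min(pred[:, 1], target[:, 1]) + torch.min(pred[:, 3], target[:, 3])
    inter = w_inter.clamp(min=0) * h_inter.clamp(min=0)
    iou = inter / (pred_area + target_area - inter).clamp(min=eps)
    return -iou.clamp(min=eps).log()  # "iou" loss type: -log(IoU)
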
def losses(self, ins_preds, cate_preds, ins_label_list, cate_label_list,
           ins_ind_label_list):
    """
    Compute losses:

        L = L_cate + λ * L_mask

    Args:
        ins_preds (list[Tensor]): each element in the list is the mask prediction
            of one level, with shape [N, G*G, H, W], where:
            * N is the number of images per mini-batch
            * G is the side length of each level of the grids
            * H and W are the height and width of the predicted mask
        cate_preds (list[Tensor]): each element in the list is the category
            prediction of one level, with shape [N, C, G, G], where:
            * C is the number of classes
        ins_label_list (list[list[Tensor]]): each element in the list is the mask
            ground truth of one image, and each element is a list which contains
            mask tensors per level with shape [H, W], where:
            * H and W are the ground-truth mask size per level (same as `ins_preds`)
        cate_label_list (list[list[Tensor]]): each element in the list is the
            category ground truth of one image, and each element is a list which
            contains tensors with shape [G, G] per level.
        ins_ind_label_list (list[list[Tensor]]): indicates which grids contain
            objects; only these grids contribute to the mask loss. Each element
            in the list is the indicator of one image, and each element is a list
            which contains tensors with shape [G*G] per level.

    Returns:
        dict[str -> Tensor]: losses.
    """
    # ins, per level
    ins_preds_valid = []
    ins_labels_valid = []
    cate_labels_valid = []
    num_images = len(ins_label_list)
    num_levels = len(ins_label_list[0])
    for level_idx in range(num_levels):
        ins_preds_per_level = []
        ins_labels_per_level = []
        cate_labels_per_level = []
        for img_idx in range(num_images):
            valid_ins_inds = ins_ind_label_list[img_idx][level_idx]
            ins_preds_per_level.append(
                ins_preds[level_idx][img_idx][valid_ins_inds, ...])
            ins_labels_per_level.append(
                ins_label_list[img_idx][level_idx][valid_ins_inds, ...])
            cate_labels_per_level.append(
                cate_label_list[img_idx][level_idx].flatten())
        ins_preds_valid.append(torch.cat(ins_preds_per_level))
        ins_labels_valid.append(torch.cat(ins_labels_per_level))
        cate_labels_valid.append(torch.cat(cate_labels_per_level))

    # dice loss, per level
    loss_ins = []
    for input, target in zip(ins_preds_valid, ins_labels_valid):
        if input.size()[0] == 0:
            continue
        input = torch.sigmoid(input)
        target = target.float() / 255.
        loss_ins.append(dice_loss(input, target))
    # loss_ins (list[Tensor]): each element's shape is [#Ins, #H*#W]
    loss_ins = torch.cat(loss_ins).mean()
    loss_ins = loss_ins * self.loss_ins_weight

    # cate
    cate_preds = [
        cate_pred.permute(0, 2, 3, 1).reshape(-1, self.num_classes)
        for cate_pred in cate_preds
    ]
    cate_preds = torch.cat(cate_preds)

    flatten_cate_labels = torch.cat(cate_labels_valid)
    foreground_idxs = flatten_cate_labels != self.num_classes
    cate_labels = torch.zeros_like(cate_preds)
    cate_labels[foreground_idxs, flatten_cate_labels[foreground_idxs]] = 1
    num_ins = foreground_idxs.sum()

    loss_cate = self.loss_cat_weight * sigmoid_focal_loss_jit(
        cate_preds,
        cate_labels,
        alpha=self.loss_cat_alpha,
        gamma=self.loss_cat_gamma,
        reduction="sum",
    ) / max(1, num_ins)

    return dict(loss_ins=loss_ins, loss_cate=loss_cate)
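
# NOTE: `dice_loss` is not shown here; a sketch assumed to follow the SOLO
# paper's instance-wise Dice loss (epsilon value is this sketch's assumption):
def dice_loss(input, target, eps=0.001):
    # input/target: (num_ins, H, W) in [0, 1]; returns one loss per instance.
    input = input.contiguous().view(input.size(0), -1)
    target = target.contiguous().view(target.size(0), -1).float()
    a = torch.sum(input * target, 1)
    b = torch.sum(input * input, 1) + eps
    c = torch.sum(target * target, 1) + eps
    d = (2 * a) / (b + c)  # Dice coefficient per instance
    return 1 - d
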
def losses(self, gt_classes, gt_shifts_deltas, gt_centerness,
           pred_class_logits, pred_shift_deltas, pred_centerness):
    """
    Args:
        For `gt_classes`, `gt_shifts_deltas` and `gt_centerness` parameters, see
            :meth:`FCOS.get_ground_truth`.
        Their shapes are (N, R) and (N, R, 4), respectively, where R is
        the total number of shifts across levels, i.e. sum(Hi x Wi)
        For `pred_class_logits`, `pred_shift_deltas` and `pred_centerness`, see
            :meth:`FCOSHead.forward`.

    Returns:
        dict[str: Tensor]:
            mapping from a named loss to a scalar tensor storing the loss.
            Used during training only. The dict keys are:
            "loss_cls" and "loss_box_reg"
    """
    pred_class_logits, pred_shift_deltas, pred_centerness = \
        permute_all_cls_and_box_to_N_HWA_K_and_concat(
            pred_class_logits, pred_shift_deltas, pred_centerness,
            self.num_classes
        )  # Shapes: (N x R, K) and (N x R, 4), respectively.

    gt_classes = gt_classes.flatten()
    gt_shifts_deltas = gt_shifts_deltas.view(-1, 4)
    gt_centerness = gt_centerness.view(-1, 1)

    valid_idxs = gt_classes >= 0
    foreground_idxs = (gt_classes >= 0) & (gt_classes != self.num_classes)
    num_foreground = foreground_idxs.sum()

    gt_classes_target = torch.zeros_like(pred_class_logits)
    gt_classes_target[foreground_idxs, gt_classes[foreground_idxs]] = 1

    num_foreground = comm.all_reduce(num_foreground) / float(comm.get_world_size())
    num_foreground_centerness = gt_centerness[foreground_idxs].sum()
    num_targets = comm.all_reduce(num_foreground_centerness) / float(comm.get_world_size())

    # logits loss
    loss_cls = sigmoid_focal_loss_jit(
        pred_class_logits[valid_idxs],
        gt_classes_target[valid_idxs],
        alpha=self.focal_loss_alpha,
        gamma=self.focal_loss_gamma,
        reduction="sum",
    ) / max(1.0, num_foreground)

    # regression loss
    loss_box_reg = iou_loss(
        pred_shift_deltas[foreground_idxs],
        gt_shifts_deltas[foreground_idxs],
        gt_centerness[foreground_idxs],
        box_mode="ltrb",
        loss_type=self.iou_loss_type,
        reduction="sum",
    ) / max(1.0, num_targets)

    # centerness loss
    loss_centerness = F.binary_cross_entropy_with_logits(
        pred_centerness[foreground_idxs],
        gt_centerness[foreground_idxs],
        reduction="sum",
    ) / max(1, num_foreground)

    loss = {
        "loss_cls": loss_cls,
        "loss_box_reg": loss_box_reg,
        "loss_centerness": loss_centerness,
    }

    # budget loss
    if self.is_dynamic_head and self.budget_loss_lambda != 0:
        soft_cost, used_cost, full_cost = get_module_running_cost(self)
        loss_budget = (soft_cost / full_cost).mean() * self.budget_loss_lambda
        storage = get_event_storage()
        storage.put_scalar("complexity_ratio", (used_cost / full_cost).mean())
        loss.update({"loss_budget": loss_budget})

    return loss
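
# NOTE: `comm.all_reduce` is used here as a *value-returning* reduction, unlike
# the in-place `dist.all_reduce` used elsewhere in this section. A sketch of
# the assumed semantics (hypothetical helper, not this repo's verified code):
import torch.distributed as dist

def all_reduce(tensor):
    # Sum the value across workers and return it, degrading to a no-op for
    # single-process runs.
    initialized = dist.is_available() and dist.is_initialized()
    if initialized and dist.get_world_size() > 1:
        tensor = tensor.clone()
        dist.all_reduce(tensor)
    return tensor
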
def losses(self, indices, gt_instances, anchors, pred_class_logits,
           pred_anchor_deltas):
    pred_class_logits = cat(pred_class_logits, dim=1).view(-1, self.num_classes)
    pred_anchor_deltas = cat(pred_anchor_deltas, dim=1).view(-1, 4)

    # list[Tensor(R, 4)], one for each image
    anchors = [Boxes.cat(anchors_i) for anchors_i in anchors]
    N = len(anchors)
    # Boxes(Tensor(N*R, 4))
    all_anchors = Boxes.cat(anchors).tensor
    predicted_boxes = self.box2box_transform.apply_deltas(pred_anchor_deltas, all_anchors)
    predicted_boxes = predicted_boxes.reshape(N, -1, 4)

    ious = []
    pos_ious = []
    for i in range(N):
        src_idx, tgt_idx = indices[i]
        iou, _ = box_iou(predicted_boxes[i, ...], gt_instances[i].gt_boxes.tensor)
        if iou.numel() == 0:
            max_iou = iou.new_full((iou.size(0),), 0)
        else:
            max_iou = iou.max(dim=1)[0]
        a_iou, _ = box_iou(anchors[i].tensor, gt_instances[i].gt_boxes.tensor)
        if a_iou.numel() == 0:
            pos_iou = a_iou.new_full((0,), 0)
        else:
            pos_iou = a_iou[src_idx, tgt_idx]
        ious.append(max_iou)
        pos_ious.append(pos_iou)
    ious = torch.cat(ious)
    ignore_idx = ious > self.neg_ignore_thresh
    pos_ious = torch.cat(pos_ious)
    pos_ignore_idx = pos_ious < self.pos_ignore_thresh

    src_idx = torch.cat([
        src + idx * anchors[0].tensor.shape[0]
        for idx, (src, _) in enumerate(indices)
    ])
    gt_classes = torch.full(pred_class_logits.shape[:1],
                            self.num_classes,
                            dtype=torch.int64,
                            device=pred_class_logits.device)
    gt_classes[ignore_idx] = -1
    target_classes_o = torch.cat(
        [t.gt_classes[J] for t, (_, J) in zip(gt_instances, indices)])
    target_classes_o[pos_ignore_idx] = -1
    gt_classes[src_idx] = target_classes_o

    valid_idxs = gt_classes >= 0
    foreground_idxs = (gt_classes >= 0) & (gt_classes != self.num_classes)
    num_foreground = foreground_idxs.sum()

    gt_classes_target = torch.zeros_like(pred_class_logits)
    gt_classes_target[foreground_idxs, gt_classes[foreground_idxs]] = 1

    if comm.get_world_size() > 1:
        dist.all_reduce(num_foreground)
    num_foreground = num_foreground * 1.0 / comm.get_world_size()

    # cls loss
    loss_cls = sigmoid_focal_loss_jit(
        pred_class_logits[valid_idxs],
        gt_classes_target[valid_idxs],
        alpha=self.focal_loss_alpha,
        gamma=self.focal_loss_gamma,
        reduction="sum",
    )
    # reg loss
    target_boxes = torch.cat(
        [t.gt_boxes.tensor[i] for t, (_, i) in zip(gt_instances, indices)],
        dim=0)
    target_boxes = target_boxes[~pos_ignore_idx]
    matched_predicted_boxes = predicted_boxes.reshape(-1, 4)[src_idx[~pos_ignore_idx]]
    loss_box_reg = (1 - torch.diag(
        generalized_box_iou(matched_predicted_boxes, target_boxes))).sum()

    return {
        "loss_cls": loss_cls / max(1, num_foreground),
        "loss_box_reg": loss_box_reg / max(1, num_foreground),
    }
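
# NOTE: toy check of the matched-pair reduction above. The DETR-style
# `generalized_box_iou` returns a pairwise (n, m) GIoU matrix, and the diagonal
# picks out the matched (prediction, target) pairs (values are hypothetical):
import torch

boxes_a = torch.tensor([[0., 0., 10., 10.], [5., 5., 15., 15.]])
giou = generalized_box_iou(boxes_a, boxes_a)  # (2, 2) pairwise GIoU matrix
loss = (1 - torch.diag(giou)).sum()           # identical boxes -> loss == 0
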
def losses(self, center_pts, cls_outs, pts_outs_init, pts_outs_refine, targets):
    """
    Args:
        center_pts (list[list[Tensor]]): a list of N=#image elements. Each is
            a list of #feature level tensors. The tensors contain shifts of
            this image on the specific feature level.
        cls_outs (list[Tensor]): each item in the list has shape
            [N, num_classes, H, W]
        pts_outs_init (list[Tensor]): each item in the list has shape
            [N, num_points*2, H, W]
        pts_outs_refine (list[Tensor]): each item in the list has shape
            [N, num_points*2, H, W]
        targets (list[Instances]): a list of N `Instances`s. The i-th
            `Instances` contains the ground-truth per-instance annotations
            for the i-th input image. Specify `targets` during training only.

    Returns:
        dict[str: Tensor]: mapping from a named loss to a scalar tensor
    """
    featmap_sizes = [featmap.size()[-2:] for featmap in cls_outs]
    assert len(featmap_sizes) == len(center_pts[0])

    pts_dim = 2 * self.num_points
    cls_outs = [
        cls_out.permute(0, 2, 3, 1).reshape(cls_out.size(0), -1, self.num_classes)
        for cls_out in cls_outs
    ]
    pts_outs_init = [
        pts_out_init.permute(0, 2, 3, 1).reshape(pts_out_init.size(0), -1, pts_dim)
        for pts_out_init in pts_outs_init
    ]
    pts_outs_refine = [
        pts_out_refine.permute(0, 2, 3, 1).reshape(pts_out_refine.size(0), -1, pts_dim)
        for pts_out_refine in pts_outs_refine
    ]

    cls_outs = torch.cat(cls_outs, dim=1)
    pts_outs_init = torch.cat(pts_outs_init, dim=1)
    pts_outs_refine = torch.cat(pts_outs_refine, dim=1)

    pts_strides = []
    for i, s in enumerate(center_pts[0]):
        pts_strides.append(cls_outs.new_full((s.size(0),), self.fpn_strides[i]))
    pts_strides = torch.cat(pts_strides, dim=0)

    center_pts = [torch.cat(c_pts, dim=0).to(self.device) for c_pts in center_pts]

    pred_cls = []
    pred_init = []
    pred_refine = []

    target_cls = []
    target_init = []
    target_refine = []

    num_pos_init = 0
    num_pos_refine = 0

    for img, (per_center_pts, cls_prob, pts_init, pts_refine,
              per_targets) in enumerate(
                  zip(center_pts, cls_outs, pts_outs_init, pts_outs_refine,
                      targets)):
        assert per_center_pts.shape[:-1] == cls_prob.shape[:-1]

        gt_bboxes = per_targets.gt_boxes.to(cls_prob.device)
        gt_labels = per_targets.gt_classes.to(cls_prob.device)

        pts_init_bbox_targets, pts_init_labels_targets = \
            self.point_targets(per_center_pts, pts_strides,
                               gt_bboxes.tensor, gt_labels)

        # per_center_pts_repeat shape: [R, 2 * num_points] (18 when num_points=9)
        per_center_pts_repeat = per_center_pts.repeat(1, self.num_points)

        normalize_term = self.point_base_scale * pts_strides
        normalize_term = normalize_term.reshape(-1, 1)

        # bbox_center = torch.cat([per_center_pts, per_center_pts], dim=1)

        per_pts_strides = pts_strides.reshape(-1, 1)
        pts_init_coordinate = pts_init * per_pts_strides + per_center_pts_repeat
        init_bbox_pred = self.pts_to_bbox(pts_init_coordinate)

        foreground_idxs = (pts_init_labels_targets >= 0) & \
            (pts_init_labels_targets != self.num_classes)

        pred_init.append(init_bbox_pred[foreground_idxs] /
                         normalize_term[foreground_idxs])
        target_init.append(pts_init_bbox_targets[foreground_idxs] /
                           normalize_term[foreground_idxs])
        num_pos_init += foreground_idxs.sum()

        # Another way to convert predicted offsets to bboxes:
        # bbox_pred_init = self.pts_to_bbox(pts_init.detach()) * per_pts_strides
        # init_bbox_pred = bbox_center + bbox_pred_init

        pts_refine_bbox_targets, pts_refine_labels_targets = \
            self.bbox_targets(init_bbox_pred, gt_bboxes, gt_labels)

        pts_refine_coordinate = pts_refine * per_pts_strides + per_center_pts_repeat
        refine_bbox_pred = self.pts_to_bbox(pts_refine_coordinate)
        # bbox_pred_refine = self.pts_to_bbox(pts_refine) * per_pts_strides
        # refine_bbox_pred = bbox_center + bbox_pred_refine

        foreground_idxs = (pts_refine_labels_targets >= 0) & \
            (pts_refine_labels_targets != self.num_classes)

        pred_refine.append(refine_bbox_pred[foreground_idxs] /
                           normalize_term[foreground_idxs])
        target_refine.append(pts_refine_bbox_targets[foreground_idxs] /
                             normalize_term[foreground_idxs])
        num_pos_refine += foreground_idxs.sum()

        gt_classes_target = torch.zeros_like(cls_prob)
        gt_classes_target[foreground_idxs,
                          pts_refine_labels_targets[foreground_idxs]] = 1
        pred_cls.append(cls_prob)
        target_cls.append(gt_classes_target)

    pred_cls = torch.cat(pred_cls, dim=0)
    pred_init = torch.cat(pred_init, dim=0)
    pred_refine = torch.cat(pred_refine, dim=0)
    target_cls = torch.cat(target_cls, dim=0)
    target_init = torch.cat(target_init, dim=0)
    target_refine = torch.cat(target_refine, dim=0)

    loss_cls = sigmoid_focal_loss_jit(
        pred_cls,
        target_cls,
        alpha=self.focal_loss_alpha,
        gamma=self.focal_loss_gamma,
        reduction="sum") / max(1, num_pos_refine.item()) * self.loss_cls_weight

    loss_pts_init = smooth_l1_loss(
        pred_init,
        target_init,
        beta=0.11,
        reduction="sum") / max(1, num_pos_init.item()) * self.loss_bbox_init_weight

    loss_pts_refine = smooth_l1_loss(
        pred_refine,
        target_refine,
        beta=0.11,
        reduction="sum") / max(1, num_pos_refine.item()) * self.loss_bbox_refine_weight

    return {
        "loss_cls": loss_cls,
        "loss_pts_init": loss_pts_init,
        "loss_pts_refine": loss_pts_refine
    }
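
# NOTE: `pts_to_bbox` converts a learned point set into a box. A sketch of the
# common RepPoints "minmax" transform (one of several options in the paper;
# this repo's version may differ):
def pts_to_bbox_minmax(pts):
    # pts: (n, 2 * num_points) laid out as (x1, y1, x2, y2, ...).
    pts = pts.view(pts.size(0), -1, 2)
    xs, ys = pts[..., 0], pts[..., 1]
    return torch.stack(
        [xs.min(1).values, ys.min(1).values, xs.max(1).values, ys.max(1).values],
        dim=1)  # (n, 4) as (x_min, y_min, x_max, y_max)
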
def losses(self, ins_preds_x, ins_preds_y, cate_preds, ins_label_list,
           cate_label_list, ins_ind_label_list, ins_ind_label_list_xy):
    # ins, per level
    ins_labels = []
    for ins_labels_level, ins_ind_labels_level in \
            zip(zip(*ins_label_list), zip(*ins_ind_label_list)):
        ins_labels_per_level = []
        for ins_labels_level_img, ins_ind_labels_level_img in \
                zip(ins_labels_level, ins_ind_labels_level):
            ins_labels_per_level.append(
                ins_labels_level_img[ins_ind_labels_level_img, ...])
        ins_labels.append(torch.cat(ins_labels_per_level))

    ins_preds_x_final = []
    for ins_preds_level_x, ins_ind_labels_level in \
            zip(ins_preds_x, zip(*ins_ind_label_list_xy)):
        ins_preds_x_final_per_level = []
        for ins_preds_level_img_x, ins_ind_labels_level_img in \
                zip(ins_preds_level_x, ins_ind_labels_level):
            ins_preds_x_final_per_level.append(
                ins_preds_level_img_x[ins_ind_labels_level_img[:, 1], ...])
        ins_preds_x_final.append(torch.cat(ins_preds_x_final_per_level))

    ins_preds_y_final = []
    for ins_preds_level_y, ins_ind_labels_level in \
            zip(ins_preds_y, zip(*ins_ind_label_list_xy)):
        ins_preds_y_final_per_level = []
        for ins_preds_level_img_y, ins_ind_labels_level_img in \
                zip(ins_preds_level_y, ins_ind_labels_level):
            ins_preds_y_final_per_level.append(
                ins_preds_level_img_y[ins_ind_labels_level_img[:, 0], ...])
        ins_preds_y_final.append(torch.cat(ins_preds_y_final_per_level))

    num_ins = 0.
    # dice loss, per level
    loss_ins = []
    for input_x, input_y, target in zip(ins_preds_x_final,
                                        ins_preds_y_final, ins_labels):
        mask_n = input_x.size(0)
        if mask_n == 0:
            continue
        num_ins += mask_n
        input = (input_x.sigmoid()) * (input_y.sigmoid())
        target = target.float() / 255.
        loss_ins.append(dice_loss(input, target))
    loss_ins = torch.cat(loss_ins).mean()
    loss_ins = loss_ins * self.loss_ins_weight

    # cate
    cate_preds = [
        cate_pred.permute(0, 2, 3, 1).reshape(-1, self.num_classes)
        for cate_pred in cate_preds
    ]
    cate_preds = torch.cat(cate_preds)

    cate_labels = []
    for cate_labels_level in zip(*cate_label_list):
        cate_labels_per_level = []
        for cate_labels_level_img in cate_labels_level:
            cate_labels_per_level.append(cate_labels_level_img.flatten())
        cate_labels.append(torch.cat(cate_labels_per_level))
    flatten_cate_labels = torch.cat(cate_labels)

    foreground_idxs = flatten_cate_labels != self.num_classes
    cate_labels = torch.zeros_like(cate_preds)
    cate_labels[foreground_idxs, flatten_cate_labels[foreground_idxs]] = 1

    loss_cate = self.loss_cat_weight * sigmoid_focal_loss_jit(
        cate_preds,
        cate_labels,
        alpha=self.loss_cat_alpha,
        gamma=self.loss_cat_gamma,
        reduction="sum",
    ) / max(1, num_ins)

    return dict(loss_ins=loss_ins, loss_cate=loss_cate)
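
# NOTE: toy illustration of the decoupled SOLO head used above: the fused mask
# probability is the elementwise product of the x-branch and y-branch sigmoids
# (the `input` computed in the dice-loss loop). Shapes here are hypothetical.
import torch

mx = torch.randn(3, 64, 64)  # x-branch logits for 3 instances
my = torch.randn(3, 64, 64)  # y-branch logits for the same instances
mask_prob = mx.sigmoid() * my.sigmoid()  # in [0, 1], high only where both agree
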
def proposals_losses(self, gt_classes, gt_shifts_deltas, gt_centerness,
                     gt_inds, im_inds, pred_class_logits, pred_shift_deltas,
                     pred_centerness, pred_inst_params, fpn_levels, shifts):
    pred_class_logits, pred_shift_deltas, pred_centerness, pred_inst_params = \
        permute_all_to_N_HWA_K_and_concat(
            pred_class_logits, pred_shift_deltas, pred_centerness,
            pred_inst_params, self.num_gen_params, self.num_classes
        )  # Shapes: (N x R, K) and (N x R, 4), respectively.

    gt_classes = gt_classes.flatten()
    gt_shifts_deltas = gt_shifts_deltas.reshape(-1, 4)
    gt_centerness = gt_centerness.reshape(-1, 1)
    fpn_levels = fpn_levels.flatten()
    im_inds = im_inds.flatten()
    gt_inds = gt_inds.flatten()

    valid_idxs = gt_classes >= 0
    foreground_idxs = (gt_classes >= 0) & (gt_classes != self.num_classes)
    num_foreground = foreground_idxs.sum()

    gt_classes_target = torch.zeros_like(pred_class_logits)
    gt_classes_target[foreground_idxs, gt_classes[foreground_idxs]] = 1

    num_foreground = comm.all_reduce(num_foreground) / float(comm.get_world_size())
    num_foreground_centerness = gt_centerness[foreground_idxs].sum()
    num_targets = comm.all_reduce(num_foreground_centerness) / float(comm.get_world_size())

    # logits loss
    loss_cls = sigmoid_focal_loss_jit(
        pred_class_logits[valid_idxs],
        gt_classes_target[valid_idxs],
        alpha=self.focal_loss_alpha,
        gamma=self.focal_loss_gamma,
        reduction="sum",
    ) / max(1.0, num_foreground)

    # regression loss
    loss_box_reg = iou_loss(
        pred_shift_deltas[foreground_idxs],
        gt_shifts_deltas[foreground_idxs],
        gt_centerness[foreground_idxs],
        box_mode="ltrb",
        loss_type=self.iou_loss_type,
        reduction="sum",
        smooth=self.iou_smooth,
    ) / max(1e-6, num_targets)

    # centerness loss
    loss_centerness = F.binary_cross_entropy_with_logits(
        pred_centerness[foreground_idxs],
        gt_centerness[foreground_idxs],
        reduction="sum",
    ) / max(1, num_foreground)

    proposals_losses = {
        "loss_cls": loss_cls,
        "loss_box_reg": loss_box_reg,
        "loss_centerness": loss_centerness
    }

    all_shifts = torch.cat([torch.cat(shift) for shift in shifts])
    proposals = Instances((0, 0))
    proposals.inst_parmas = pred_inst_params[foreground_idxs]
    proposals.fpn_levels = fpn_levels[foreground_idxs]
    proposals.shifts = all_shifts[foreground_idxs]
    proposals.gt_inds = gt_inds[foreground_idxs]
    proposals.im_inds = im_inds[foreground_idxs]

    # select_instances for saving memory
    if len(proposals):
        if self.topk_proposals_per_im != -1:
            proposals.gt_cls = gt_classes[foreground_idxs]
            proposals.pred_logits = pred_class_logits[foreground_idxs]
            proposals.pred_centerness = pred_centerness[foreground_idxs]
            proposals = self.select_instances(proposals)

    return proposals_losses, proposals
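
# NOTE: `permute_all_to_N_HWA_K_and_concat` extends the earlier helper with the
# centerness and per-proposal controller (instance-parameter) maps. A sketch of
# the assumed behavior (hypothetical reconstruction, not the repo's exact code):
def permute_all_to_N_HWA_K_and_concat(box_cls, box_delta, box_center,
                                      inst_params, num_gen_params,
                                      num_classes=80):
    """Flatten per-level maps to (N*R, K), (N*R, 4), (N*R, 1), (N*R, P)."""
    box_cls = torch.cat(
        [permute_to_N_HWA_K(x, num_classes) for x in box_cls], dim=1).view(-1, num_classes)
    box_delta = torch.cat(
        [permute_to_N_HWA_K(x, 4) for x in box_delta], dim=1).view(-1, 4)
    box_center = torch.cat(
        [permute_to_N_HWA_K(x, 1) for x in box_center], dim=1).view(-1, 1)
    inst_params = torch.cat(
        [permute_to_N_HWA_K(x, num_gen_params) for x in inst_params],
        dim=1).view(-1, num_gen_params)
    return box_cls, box_delta, box_center, inst_params
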
def get_lla_assignments_and_losses(self, shifts, targets, box_cls, box_delta, box_iou):
    box_cls = [permute_to_N_HWA_K(x, self.num_classes) for x in box_cls]
    box_delta = [permute_to_N_HWA_K(x, 4) for x in box_delta]
    box_iou = [permute_to_N_HWA_K(x, 1) for x in box_iou]

    box_cls = torch.cat(box_cls, dim=1)
    box_delta = torch.cat(box_delta, dim=1)
    box_iou = torch.cat(box_iou, dim=1)

    losses_cls = []
    losses_box_reg = []
    losses_iou = []
    num_fg = 0

    for shifts_per_image, targets_per_image, box_cls_per_image, \
            box_delta_per_image, box_iou_per_image in zip(
                shifts, targets, box_cls, box_delta, box_iou):
        shifts_over_all = torch.cat(shifts_per_image, dim=0)
        gt_boxes = targets_per_image.gt_boxes
        gt_classes = targets_per_image.gt_classes

        deltas = self.shift2box_transform.get_deltas(
            shifts_over_all, gt_boxes.tensor.unsqueeze(1))
        is_in_boxes = deltas.min(dim=-1).values > 0.01

        shape = (len(targets_per_image), len(shifts_over_all), -1)

        box_cls_per_image_unexpanded = box_cls_per_image
        box_delta_per_image_unexpanded = box_delta_per_image

        box_cls_per_image = box_cls_per_image.unsqueeze(0).expand(shape)
        gt_cls_per_image = F.one_hot(
            torch.max(gt_classes, torch.zeros_like(gt_classes)),
            self.num_classes
        ).float().unsqueeze(1).expand(shape)

        with torch.no_grad():
            loss_cls = sigmoid_focal_loss_jit(
                box_cls_per_image,
                gt_cls_per_image,
                alpha=self.focal_loss_alpha,
                gamma=self.focal_loss_gamma).sum(dim=-1)
            loss_cls_bg = sigmoid_focal_loss_jit(
                box_cls_per_image_unexpanded,
                torch.zeros_like(box_cls_per_image_unexpanded),
                alpha=self.focal_loss_alpha,
                gamma=self.focal_loss_gamma).sum(dim=-1)

            box_delta_per_image = box_delta_per_image.unsqueeze(0).expand(shape)
            gt_delta_per_image = self.shift2box_transform.get_deltas(
                shifts_over_all, gt_boxes.tensor.unsqueeze(1))
            loss_delta = iou_loss(box_delta_per_image,
                                  gt_delta_per_image,
                                  box_mode="ltrb",
                                  loss_type='iou')
            ious = get_ious(box_delta_per_image,
                            gt_delta_per_image,
                            box_mode="ltrb",
                            loss_type='iou')

            loss = loss_cls + self.reg_cost * loss_delta + 1e3 * (
                1 - is_in_boxes.float())
            loss = torch.cat([loss, loss_cls_bg.unsqueeze(0)], dim=0)

            num_gt = loss.shape[0] - 1
            num_anchor = loss.shape[1]

            # Topk
            matching_matrix = torch.zeros_like(loss)
            _, topk_idx = torch.topk(loss[:-1], k=self.topk, dim=1, largest=False)
            matching_matrix[torch.arange(num_gt).unsqueeze(1).repeat(
                1, self.topk).view(-1), topk_idx.view(-1)] = 1.

            # make sure each anchor matches at most one gt
            anchor_matched_gt = matching_matrix.sum(0)
            if (anchor_matched_gt > 1).sum() > 0:
                loss_min, loss_argmin = torch.min(
                    loss[:-1, anchor_matched_gt > 1], dim=0)
                matching_matrix[:, anchor_matched_gt > 1] *= 0.
                matching_matrix[loss_argmin, anchor_matched_gt > 1] = 1.
                anchor_matched_gt = matching_matrix.sum(0)
            num_fg += matching_matrix.sum()
            # assignment for background
            matching_matrix[-1] = 1. - anchor_matched_gt
            assigned_gt_inds = torch.argmax(matching_matrix, dim=0)

        gt_cls_per_image_bg = gt_cls_per_image.new_zeros(
            (gt_cls_per_image.size(1), gt_cls_per_image.size(2))).unsqueeze(0)
        gt_cls_per_image_with_bg = torch.cat(
            [gt_cls_per_image, gt_cls_per_image_bg], dim=0)
        cls_target_per_image = gt_cls_per_image_with_bg[
            assigned_gt_inds, torch.arange(num_anchor)]

        # Dealing with CrowdHuman ignore label
        gt_classes_ = torch.cat([gt_classes, gt_classes.new_zeros(1)])
        anchor_cls_labels = gt_classes_[assigned_gt_inds]
        valid_flag = anchor_cls_labels >= 0
        # get foreground mask
        pos_mask = assigned_gt_inds != len(targets_per_image)
        valid_fg = pos_mask & valid_flag

        assigned_fg_inds = assigned_gt_inds[valid_fg]
        range_fg = torch.arange(num_anchor)[valid_fg]
        ious_fg = ious[assigned_fg_inds, range_fg]

        anchor_loss_cls = sigmoid_focal_loss_jit(
            box_cls_per_image_unexpanded[valid_flag],
            cls_target_per_image[valid_flag],
            alpha=self.focal_loss_alpha,
            gamma=self.focal_loss_gamma).sum(dim=-1)

        delta_target = gt_delta_per_image[assigned_fg_inds, range_fg]
        anchor_loss_delta = 2. * iou_loss(
            box_delta_per_image_unexpanded[valid_fg],
            delta_target,
            box_mode="ltrb",
            loss_type=self.iou_loss_type)

        anchor_loss_iou = 0.5 * F.binary_cross_entropy_with_logits(
            box_iou_per_image.squeeze(1)[valid_fg],
            ious_fg,
            reduction='none')

        losses_cls.append(anchor_loss_cls.sum())
        losses_box_reg.append(anchor_loss_delta.sum())
        losses_iou.append(anchor_loss_iou.sum())

    if self.norm_sync:
        dist.all_reduce(num_fg)
        num_fg = num_fg.float() / dist.get_world_size()

    return {
        'loss_cls': torch.stack(losses_cls).sum() / num_fg,
        'loss_box_reg': torch.stack(losses_box_reg).sum() / num_fg,
        'loss_iou': torch.stack(losses_iou).sum() / num_fg
    }
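
# NOTE: toy run of the cost-based top-k assignment above, with 2 gts, 5 anchors
# and k=2 (all values hypothetical):
import torch

cost = torch.tensor([[0.1, 0.9, 0.2, 0.8, 0.7],
                     [0.9, 0.1, 0.25, 0.3, 0.8]])
k = 2
matching = torch.zeros_like(cost)
_, topk_idx = torch.topk(cost, k=k, dim=1, largest=False)
matching[torch.arange(2).unsqueeze(1).repeat(1, k).view(-1), topk_idx.view(-1)] = 1.
# anchor 2 is claimed by both gts; as in the code above, it is reassigned to
# the gt with the lower cost (gt 0 here, 0.2 < 0.25), and every unmatched
# anchor falls into the appended background row.
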
def get_ground_truth(self, shifts, targets, box_cls, box_delta):
    """
    Args:
        shifts (list[list[Tensor]]): a list of N=#image elements. Each is a
            list of #feature level tensors. The tensors contain shifts of
            this image on the specific feature level.
        targets (list[Instances]): a list of N `Instances`s. The i-th
            `Instances` contains the ground-truth per-instance annotations
            for the i-th input image. Specify `targets` during training only.

    Returns:
        gt_classes (Tensor):
            An integer tensor of shape (N, R) storing ground-truth labels
            for each shift. R is the total number of shifts, i.e. the sum
            of Hi x Wi for all levels.
            Shifts in the valid boxes are assigned their corresponding label
            in the [0, K-1] range. Shifts in the background are assigned the
            label "K". Shifts in the ignore areas are assigned a label "-1",
            i.e. ignore.
        gt_shifts_deltas (Tensor):
            Shape (N, R, 4). The last dimension represents ground-truth
            shift2box transform targets (dl, dt, dr, db) that map each shift
            to its matched ground-truth box. The values in the tensor are
            meaningful only when the corresponding shift is labeled as
            foreground.
    """
    gt_classes = []
    gt_shifts_deltas = []

    box_cls = torch.cat(
        [permute_to_N_HWA_K(x, self.num_classes) for x in box_cls], dim=1)
    box_delta = torch.cat([permute_to_N_HWA_K(x, 4) for x in box_delta], dim=1)

    num_fg = 0
    num_gt = 0

    for shifts_per_image, targets_per_image, box_cls_per_image, box_delta_per_image in zip(
            shifts, targets, box_cls, box_delta):
        shifts_over_all_feature_maps = torch.cat(shifts_per_image, dim=0)
        gt_boxes = targets_per_image.gt_boxes

        shape = (len(targets_per_image), len(shifts_over_all_feature_maps), -1)

        gt_cls_per_image = F.one_hot(targets_per_image.gt_classes,
                                     self.num_classes).float()

        loss_cls = sigmoid_focal_loss_jit(
            box_cls_per_image.unsqueeze(0).expand(shape),
            gt_cls_per_image.unsqueeze(1).expand(shape),
            alpha=self.focal_loss_alpha,
            gamma=self.focal_loss_gamma,
        ).sum(dim=2)

        gt_delta_per_image = self.shift2box_transform.get_deltas(
            shifts_over_all_feature_maps, gt_boxes.tensor.unsqueeze(1))

        loss_delta = iou_loss(
            box_delta_per_image.unsqueeze(0).expand(shape),
            gt_delta_per_image,
            box_mode="ltrb",
            loss_type=self.iou_loss_type,
        ) * self.reg_weight

        loss = loss_cls + loss_delta

        INF = 1e8
        deltas = self.shift2box_transform.get_deltas(
            shifts_over_all_feature_maps, gt_boxes.tensor.unsqueeze(1))

        if self.center_sampling_radius > 0:
            centers = gt_boxes.get_centers()
            is_in_boxes = []
            for stride, shifts_i in zip(self.fpn_strides, shifts_per_image):
                radius = stride * self.center_sampling_radius
                center_boxes = torch.cat((
                    torch.max(centers - radius, gt_boxes.tensor[:, :2]),
                    torch.min(centers + radius, gt_boxes.tensor[:, 2:]),
                ), dim=-1)
                center_deltas = self.shift2box_transform.get_deltas(
                    shifts_i, center_boxes.unsqueeze(1))
                is_in_boxes.append(center_deltas.min(dim=-1).values > 0)
            is_in_boxes = torch.cat(is_in_boxes, dim=1)
        else:
            # no center sampling: use all locations within a ground-truth box
            is_in_boxes = deltas.min(dim=-1).values > 0

        loss[~is_in_boxes] = INF

        gt_idxs, shift_idxs = linear_sum_assignment(loss.cpu().numpy())

        num_fg += len(shift_idxs)
        num_gt += len(targets_per_image)

        gt_classes_i = shifts_over_all_feature_maps.new_full(
            (len(shifts_over_all_feature_maps),),
            self.num_classes,
            dtype=torch.long)
        gt_shifts_reg_deltas_i = shifts_over_all_feature_maps.new_zeros(
            len(shifts_over_all_feature_maps), 4)
        if len(targets_per_image) > 0:
            # ground truth classes
            gt_classes_i[shift_idxs] = targets_per_image.gt_classes[gt_idxs]
            # ground truth box regression
            gt_shifts_reg_deltas_i[shift_idxs] = self.shift2box_transform.get_deltas(
                shifts_over_all_feature_maps[shift_idxs], gt_boxes[gt_idxs].tensor)

        gt_classes.append(gt_classes_i)
        gt_shifts_deltas.append(gt_shifts_reg_deltas_i)

    get_event_storage().put_scalar("num_fg_per_gt", num_fg / num_gt)

    return torch.stack(gt_classes), torch.stack(gt_shifts_deltas)
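
# NOTE: toy check of the one-to-one assignment above. scipy's Hungarian solver
# returns exactly one shift per gt, minimizing the summed cost (values are
# hypothetical):
import numpy as np
from scipy.optimize import linear_sum_assignment

cost = np.array([[0.2, 0.9, 0.4],
                 [0.8, 0.1, 0.5]])
gt_idxs, shift_idxs = linear_sum_assignment(cost)
# gt_idxs == [0, 1], shift_idxs == [0, 1]; total cost 0.2 + 0.1 = 0.3
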
def get_ground_truth(self, shifts, targets, box_cls, box_delta, box_iou):
    gt_classes = []
    gt_shifts_deltas = []
    gt_ious = []
    assigned_units = []

    box_cls = [permute_to_N_HWA_K(x, self.num_classes) for x in box_cls]
    box_delta = [permute_to_N_HWA_K(x, 4) for x in box_delta]
    box_iou = [permute_to_N_HWA_K(x, 1) for x in box_iou]

    box_cls = torch.cat(box_cls, dim=1)
    box_delta = torch.cat(box_delta, dim=1)
    box_iou = torch.cat(box_iou, dim=1)

    for shifts_per_image, targets_per_image, box_cls_per_image, \
            box_delta_per_image, box_iou_per_image in zip(
                shifts, targets, box_cls, box_delta, box_iou):
        shifts_over_all = torch.cat(shifts_per_image, dim=0)
        gt_boxes = targets_per_image.gt_boxes

        # In gt box and center.
        deltas = self.shift2box_transform.get_deltas(
            shifts_over_all, gt_boxes.tensor.unsqueeze(1))
        is_in_boxes = deltas.min(dim=-1).values > 0.01

        center_sampling_radius = 2.5
        centers = gt_boxes.get_centers()
        is_in_centers = []
        for stride, shifts_i in zip(self.fpn_strides, shifts_per_image):
            radius = stride * center_sampling_radius
            center_boxes = torch.cat((
                torch.max(centers - radius, gt_boxes.tensor[:, :2]),
                torch.min(centers + radius, gt_boxes.tensor[:, 2:]),
            ), dim=-1)
            center_deltas = self.shift2box_transform.get_deltas(
                shifts_i, center_boxes.unsqueeze(1))
            is_in_centers.append(center_deltas.min(dim=-1).values > 0)
        is_in_centers = torch.cat(is_in_centers, dim=1)
        del centers, center_boxes, deltas, center_deltas

        is_in_boxes = (is_in_boxes & is_in_centers)

        num_gt = len(targets_per_image)
        num_anchor = len(shifts_over_all)
        shape = (num_gt, num_anchor, -1)

        gt_cls_per_image = F.one_hot(targets_per_image.gt_classes,
                                     self.num_classes).float()

        with torch.no_grad():
            loss_cls = sigmoid_focal_loss_jit(
                box_cls_per_image.unsqueeze(0).expand(shape),
                gt_cls_per_image.unsqueeze(1).expand(shape),
                alpha=self.focal_loss_alpha,
                gamma=self.focal_loss_gamma,
            ).sum(dim=-1)

            loss_cls_bg = sigmoid_focal_loss_jit(
                box_cls_per_image,
                torch.zeros_like(box_cls_per_image),
                alpha=self.focal_loss_alpha,
                gamma=self.focal_loss_gamma,
            ).sum(dim=-1)

            gt_delta_per_image = self.shift2box_transform.get_deltas(
                shifts_over_all, gt_boxes.tensor.unsqueeze(1))

            ious, loss_delta = get_ious_and_iou_loss(
                box_delta_per_image.unsqueeze(0).expand(shape),
                gt_delta_per_image,
                box_mode="ltrb",
                loss_type='iou')

            loss = loss_cls + self.reg_weight * loss_delta + 1e6 * (
                1 - is_in_boxes.float())

            # Performing Dynamic k Estimation
            topk_ious, _ = torch.topk(ious * is_in_boxes.float(),
                                      self.top_candidates,
                                      dim=1)
            mu = ious.new_ones(num_gt + 1)
            mu[:-1] = torch.clamp(topk_ious.sum(1).int(), min=1).float()
            mu[-1] = num_anchor - mu[:-1].sum()
            nu = ious.new_ones(num_anchor)
            loss = torch.cat([loss, loss_cls_bg.unsqueeze(0)], dim=0)

            # Solving Optimal-Transportation-Plan pi via Sinkhorn-Iteration.
            _, pi = self.sinkhorn(mu, nu, loss)

            # Rescale pi so that the max pi for each gt equals 1.
            rescale_factor, _ = pi.max(dim=1)
            pi = pi / rescale_factor.unsqueeze(1)

            max_assigned_units, matched_gt_inds = torch.max(pi, dim=0)

            gt_classes_i = targets_per_image.gt_classes.new_ones(
                num_anchor) * self.num_classes
            fg_mask = matched_gt_inds != num_gt
            gt_classes_i[fg_mask] = targets_per_image.gt_classes[
                matched_gt_inds[fg_mask]]
            gt_classes.append(gt_classes_i)
            assigned_units.append(max_assigned_units)

            box_target_per_image = gt_delta_per_image.new_zeros((num_anchor, 4))
            box_target_per_image[fg_mask] = \
                gt_delta_per_image[matched_gt_inds[fg_mask],
                                   torch.arange(num_anchor)[fg_mask]]
            gt_shifts_deltas.append(box_target_per_image)

            gt_ious_per_image = ious.new_zeros((num_anchor, 1))
            gt_ious_per_image[fg_mask] = ious[
                matched_gt_inds[fg_mask],
                torch.arange(num_anchor)[fg_mask]].unsqueeze(1)
            gt_ious.append(gt_ious_per_image)

    return torch.cat(gt_classes), torch.cat(gt_shifts_deltas), torch.cat(gt_ious)
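
# NOTE: `self.sinkhorn` solves the entropic-regularized transport problem; the
# code above unpacks its second return value as the plan pi. A minimal
# log-domain Sinkhorn sketch (eps and iteration count are assumptions; the
# repo's SinkhornDistance may differ in details and also returns the cost):
def sinkhorn(mu, nu, cost, eps=0.1, n_iters=50):
    # mu: (num_gt + 1,) supply per gt plus the background row;
    # nu: (num_anchor,) demand per anchor; cost: (num_gt + 1, num_anchor).
    def M(u, v):
        # log-space kernel of the modified cost: (-C + u_i + v_j) / eps
        return (-cost + u.unsqueeze(-1) + v.unsqueeze(-2)) / eps

    u, v = torch.zeros_like(mu), torch.zeros_like(nu)
    for _ in range(n_iters):
        u = eps * (torch.log(mu + 1e-8) - torch.logsumexp(M(u, v), dim=-1)) + u
        v = eps * (torch.log(nu + 1e-8)
                   - torch.logsumexp(M(u, v).transpose(-2, -1), dim=-1)) + v
    pi = torch.exp(M(u, v))  # transport plan: rows sum ~ mu, cols sum ~ nu
    return pi
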