def instances_losses(self, gt_insts_mask, proposals, dummy_feature):
    """
    Arguments:
        gt_insts_mask (Tensor): shape (N, mask_h, mask_w), where N is the
            number of GT instances per batch. Segmentation ground truth.
        proposals (Instances): an Instances object containing all sampled
            foreground information per batch, thus len(proposals) depends
            on the select_instances function. Two fields are required for
            loss computation when len(proposals) > 0:
            "gt_inds": len(proposals) elements, stores the mapping between
                each predicted instance and its GT instance.
            "pred_global_logits": shape (len(proposals), 1, mask_h, mask_w),
                stores predicted logits of the foreground segmentation.
        dummy_feature (Tensor): a tensor with "requires_grad" set to True,
            only used when len(proposals) == 0.

    Returns:
        dict[str: Tensor]: mapping from a named loss to a scalar tensor
            storing the loss. Used during training only. The dict key is
            "loss_mask".
    """
    if len(proposals):
        gt_inds = proposals.gt_inds
        pred_instances_mask = proposals.pred_global_logits.sigmoid()
        gt_insts_mask = gt_insts_mask[gt_inds] \
            .unsqueeze(dim=1).to(dtype=pred_instances_mask.dtype)
        loss_mask = dice_loss(pred_instances_mask, gt_insts_mask).mean()
    else:
        # No sampled foreground: emit a zero loss that still participates
        # in the autograd graph via dummy_feature.
        loss_mask = dummy_feature.sum() * 0.
    return {"loss_mask": loss_mask}
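
# ---------------------------------------------------------------------------
# Illustrative sketch (hypothetical, not part of the original module):
# `dice_loss` is imported from elsewhere in this repo. The reference
# formulation assumed in these comments returns one loss value per instance,
# so that `.mean()` above reduces over instances. The standalone demo below
# exercises that assumption on toy tensors; `_demo_dice_loss` and its inner
# `dice_loss` are illustrative names only.
# ---------------------------------------------------------------------------
def _demo_dice_loss():
    import torch

    def dice_loss(pred, target, eps=1e-5):
        # Flatten each instance mask and compute the soft Dice loss,
        # 1 - 2*|P∩T| / (|P|^2 + |T|^2), one value per instance.
        pred = pred.reshape(pred.size(0), -1)
        target = target.reshape(target.size(0), -1)
        inter = (pred * target).sum(dim=1)
        union = (pred * pred).sum(dim=1) + (target * target).sum(dim=1)
        return 1. - 2. * inter / (union + eps)

    pred = torch.rand(3, 1, 16, 16)                   # 3 soft instance masks
    target = (torch.rand(3, 1, 16, 16) > 0.5).float()  # 3 binary GT masks
    per_instance = dice_loss(pred, target)            # shape (3,)
    print(per_instance.shape, per_instance.mean().item())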
def projection_losses(self, gt_masks, proposals, dummy_feature):
    """
    Arguments:
        gt_masks (List[Tensor]): a list of N elements, where N is the
            number of GT instances per batch and each element has shape
            (1, mask_h, mask_w). Segmentation ground truth where values
            inside the box are 1 and values outside the box are 0.
        proposals (Instances): an Instances object containing all sampled
            foreground information per batch, thus len(proposals) depends
            on the select_instances function. Two fields are required for
            loss computation when len(proposals) > 0:
            "gt_inds": len(proposals) elements, stores the mapping between
                each predicted instance and its GT instance.
            "pred_global_logits": shape (len(proposals), 1, mask_h, mask_w),
                stores predicted logits of the foreground segmentation.
        dummy_feature (Tensor): a tensor with "requires_grad" set to True,
            only used when len(proposals) == 0.

    Returns:
        dict[str: Tensor]: mapping from a named loss to a scalar tensor
            storing the loss. Used during training only. The dict key is
            "loss_proj".
    """
    if len(proposals):
        gt_inds = proposals.gt_inds
        pred_instances_mask = proposals.pred_global_logits.sigmoid()
        # Gather GT masks based on gt_inds, then max-project them onto
        # the x and y axes.
        gt_masks = torch.cat(gt_masks)[gt_inds].to(
            dtype=pred_instances_mask.dtype)
        gt_proj_x = gt_masks.max(dim=1)[0]
        gt_proj_y = gt_masks.max(dim=2)[0]
        # Project the predicted masks the same way before computing the loss.
        pred_x_proj = pred_instances_mask.squeeze(1).max(dim=1)[0]
        pred_y_proj = pred_instances_mask.squeeze(1).max(dim=2)[0]
        loss_proj = dice_loss(pred_x_proj, gt_proj_x) + \
            dice_loss(pred_y_proj, gt_proj_y)
        loss_proj = loss_proj.mean()
    else:
        # No sampled foreground: emit a zero loss that keeps the graph
        # connected via dummy_feature.
        loss_proj = dummy_feature.sum() * 0.
    return {"loss_proj": loss_proj}
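
# ---------------------------------------------------------------------------
# Illustrative sketch (hypothetical helper, not in the original module):
# the projection loss needs only box-level supervision, because
# max-projecting a box-filled mask onto each axis yields the same 1-D
# profiles as projecting any mask that spans that box. The toy example
# below shows the two projections computed in `projection_losses`.
# ---------------------------------------------------------------------------
def _demo_axis_projections():
    import torch

    mask = torch.zeros(1, 8, 8)
    mask[0, 2:5, 3:7] = 1.        # a 3x4 box of ones
    proj_x = mask.max(dim=1)[0]   # collapse rows    -> shape (1, 8)
    proj_y = mask.max(dim=2)[0]   # collapse columns -> shape (1, 8)
    print(proj_x)                 # 1s over columns 3..6
    print(proj_y)                 # 1s over rows 2..4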
def losses(self, ins_preds, cate_preds, ins_label_list, cate_label_list,
           ins_ind_label_list):
    """
    Compute losses:

        L = L_cate + λ * L_mask

    Args:
        ins_preds (list[Tensor]): each element in the list is the mask
            prediction of one level, with shape [N, G*G, H, W], where:
            * N is the number of images per mini-batch
            * G is the side length of the grid at that level
            * H and W are the height and width of the predicted mask
        cate_preds (list[Tensor]): each element in the list is the
            category prediction of one level, with shape [N, C, G, G],
            where:
            * C is the number of classes
        ins_label_list (list[list[Tensor]]): each element is the mask
            ground truth of one image, itself a list of per-level tensors
            with shape [G*G, H, W], where:
            * H and W are the ground truth mask size per level
              (same as `ins_preds`)
        cate_label_list (list[list[Tensor]]): each element is the category
            ground truth of one image, itself a list of per-level tensors
            with shape [G, G].
        ins_ind_label_list (list[list[Tensor]]): indicates which grid
            cells contain objects and therefore contribute to the mask
            loss. Each element is the indicator of one image, itself a
            list of per-level tensors with shape [G*G].

    Returns:
        dict[str -> Tensor]: losses.
    """
    # ins, per level
    ins_preds_valid = []
    ins_labels_valid = []
    cate_labels_valid = []
    num_images = len(ins_label_list)
    num_levels = len(ins_label_list[0])
    for level_idx in range(num_levels):
        ins_preds_per_level = []
        ins_labels_per_level = []
        cate_labels_per_level = []
        for img_idx in range(num_images):
            valid_ins_inds = ins_ind_label_list[img_idx][level_idx]
            ins_preds_per_level.append(
                ins_preds[level_idx][img_idx][valid_ins_inds, ...])
            ins_labels_per_level.append(
                ins_label_list[img_idx][level_idx][valid_ins_inds, ...])
            cate_labels_per_level.append(
                cate_label_list[img_idx][level_idx].flatten())
        ins_preds_valid.append(torch.cat(ins_preds_per_level))
        ins_labels_valid.append(torch.cat(ins_labels_per_level))
        cate_labels_valid.append(torch.cat(cate_labels_per_level))

    # dice loss, per level
    loss_ins = []
    for input, target in zip(ins_preds_valid, ins_labels_valid):
        if input.size(0) == 0:
            continue
        input = torch.sigmoid(input)
        # GT masks are stored as uint8 in [0, 255]; normalize to [0, 1].
        target = target.float() / 255.
        loss_ins.append(dice_loss(input, target))
    # loss_ins (list[Tensor]): each element's shape is [#Ins, #H*#W]
    loss_ins = torch.cat(loss_ins).mean()
    loss_ins = loss_ins * self.loss_ins_weight

    # cate
    cate_preds = [
        cate_pred.permute(0, 2, 3, 1).reshape(-1, self.num_classes)
        for cate_pred in cate_preds
    ]
    cate_preds = torch.cat(cate_preds)
    flatten_cate_labels = torch.cat(cate_labels_valid)
    # Background is encoded as class index == num_classes.
    foreground_idxs = flatten_cate_labels != self.num_classes
    cate_labels = torch.zeros_like(cate_preds)
    cate_labels[foreground_idxs, flatten_cate_labels[foreground_idxs]] = 1
    num_ins = foreground_idxs.sum()

    loss_cate = self.loss_cat_weight * sigmoid_focal_loss_jit(
        cate_preds,
        cate_labels,
        alpha=self.loss_cat_alpha,
        gamma=self.loss_cat_gamma,
        reduction="sum",
    ) / max(1, num_ins)
    return dict(loss_ins=loss_ins, loss_cate=loss_cate)
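
# ---------------------------------------------------------------------------
# Illustrative sketch (hypothetical, not part of the original module):
# the category branch encodes background as class index == num_classes and
# leaves its one-hot target row all zeros, so sigmoid focal loss pushes
# every class logit of a background cell toward 0. A minimal standalone
# reproduction of that target construction:
# ---------------------------------------------------------------------------
def _demo_focal_targets():
    import torch

    num_classes = 4
    # Four flattened grid cells; the third one is background.
    flatten_cate_labels = torch.tensor([0, 2, num_classes, 1])
    cate_preds = torch.randn(4, num_classes)
    foreground_idxs = flatten_cate_labels != num_classes
    cate_labels = torch.zeros_like(cate_preds)
    cate_labels[foreground_idxs, flatten_cate_labels[foreground_idxs]] = 1
    print(cate_labels)  # rows: one-hot, one-hot, all zeros, one-hot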
def losses(self, ins_preds_x, ins_preds_y, cate_preds, ins_label_list,
           cate_label_list, ins_ind_label_list, ins_ind_label_list_xy):
    """
    Decoupled variant of `losses`: the mask branch predicts separate
    x-axis and y-axis mask maps, and the final instance mask probability
    is the elementwise product of their sigmoids. Arguments mirror
    `losses` above, except that mask predictions are split into
    `ins_preds_x` / `ins_preds_y`, and each tensor in
    `ins_ind_label_list_xy` holds index pairs for the positive grid
    cells, where column 0 selects the y-branch map and column 1 selects
    the x-branch map.
    """
    # ins, per level
    ins_labels = []
    for ins_labels_level, ins_ind_labels_level in \
            zip(zip(*ins_label_list), zip(*ins_ind_label_list)):
        ins_labels_per_level = []
        for ins_labels_level_img, ins_ind_labels_level_img in \
                zip(ins_labels_level, ins_ind_labels_level):
            ins_labels_per_level.append(
                ins_labels_level_img[ins_ind_labels_level_img, ...])
        ins_labels.append(torch.cat(ins_labels_per_level))

    ins_preds_x_final = []
    for ins_preds_level_x, ins_ind_labels_level in \
            zip(ins_preds_x, zip(*ins_ind_label_list_xy)):
        ins_preds_x_final_per_level = []
        for ins_preds_level_img_x, ins_ind_labels_level_img in \
                zip(ins_preds_level_x, ins_ind_labels_level):
            # Column 1 holds the x indices of the positive grid cells.
            ins_preds_x_final_per_level.append(
                ins_preds_level_img_x[ins_ind_labels_level_img[:, 1], ...])
        ins_preds_x_final.append(torch.cat(ins_preds_x_final_per_level))

    ins_preds_y_final = []
    for ins_preds_level_y, ins_ind_labels_level in \
            zip(ins_preds_y, zip(*ins_ind_label_list_xy)):
        ins_preds_y_final_per_level = []
        for ins_preds_level_img_y, ins_ind_labels_level_img in \
                zip(ins_preds_level_y, ins_ind_labels_level):
            # Column 0 holds the y indices of the positive grid cells.
            ins_preds_y_final_per_level.append(
                ins_preds_level_img_y[ins_ind_labels_level_img[:, 0], ...])
        ins_preds_y_final.append(torch.cat(ins_preds_y_final_per_level))

    num_ins = 0.
    # dice loss, per level
    loss_ins = []
    for input_x, input_y, target in zip(ins_preds_x_final,
                                        ins_preds_y_final, ins_labels):
        mask_n = input_x.size(0)
        if mask_n == 0:
            continue
        num_ins += mask_n
        # Final mask probability is the product of the x and y branches.
        input = input_x.sigmoid() * input_y.sigmoid()
        # GT masks are stored as uint8 in [0, 255]; normalize to [0, 1].
        target = target.float() / 255.
        loss_ins.append(dice_loss(input, target))
    loss_ins = torch.cat(loss_ins).mean()
    loss_ins = loss_ins * self.loss_ins_weight

    # cate
    cate_preds = [
        cate_pred.permute(0, 2, 3, 1).reshape(-1, self.num_classes)
        for cate_pred in cate_preds
    ]
    cate_preds = torch.cat(cate_preds)

    cate_labels = []
    for cate_labels_level in zip(*cate_label_list):
        cate_labels_per_level = [
            cate_labels_level_img.flatten()
            for cate_labels_level_img in cate_labels_level
        ]
        cate_labels.append(torch.cat(cate_labels_per_level))
    flatten_cate_labels = torch.cat(cate_labels)

    # Background is encoded as class index == num_classes.
    foreground_idxs = flatten_cate_labels != self.num_classes
    cate_labels = torch.zeros_like(cate_preds)
    cate_labels[foreground_idxs, flatten_cate_labels[foreground_idxs]] = 1

    loss_cate = self.loss_cat_weight * sigmoid_focal_loss_jit(
        cate_preds,
        cate_labels,
        alpha=self.loss_cat_alpha,
        gamma=self.loss_cat_gamma,
        reduction="sum",
    ) / max(1, num_ins)
    return dict(loss_ins=loss_ins, loss_cate=loss_cate)
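
# ---------------------------------------------------------------------------
# Illustrative sketch (hypothetical, not part of the original module):
# in the decoupled head, the grid cell at row i, column j takes its mask as
# the elementwise product of the j-th x-branch map and the i-th y-branch
# map, which shrinks G*G mask channels per level to 2*G. A toy
# reconstruction of that pairing, using the same column layout indexed
# above (column 0 -> y, column 1 -> x):
# ---------------------------------------------------------------------------
def _demo_decoupled_masks():
    import torch

    G, H, W = 4, 8, 8
    preds_x = torch.rand(G, H, W)  # one map per grid column
    preds_y = torch.rand(G, H, W)  # one map per grid row
    ind_xy = torch.tensor([[1, 3], [2, 0]])  # (y, x) of two positive cells
    masks = preds_x[ind_xy[:, 1]].sigmoid() * preds_y[ind_xy[:, 0]].sigmoid()
    print(masks.shape)             # torch.Size([2, 8, 8])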