Example #1
    def instances_losses(self, gt_insts_mask, proposals, dummy_feature):
        """
        Arguments:
            gt_insts_mask (Tensor):
                segmentation ground truth of shape (N, mask_h, mask_w), where
                N is the number of GT instances in the batch
            proposals (Instances):
                an Instances object holding all sampled foreground information
                for the batch, so len(proposals) depends on the select_instances
                function. Two fields are required for loss computation when
                len(proposals) > 0:
                "gt_inds": len(proposals) elements, mapping each predicted
                    instance to its GT instance
                "pred_global_logits": shape (len(proposals), 1, mask_h, mask_w),
                    the predicted foreground segmentation logits
            dummy_feature (Tensor): a tensor with requires_grad=True, used only
                when len(proposals) == 0
        Returns:
            dict[str -> Tensor]:
                mapping from a loss name to a scalar tensor storing the loss.
                Used during training only. The only key is "loss_mask".
        """
        if len(proposals):
            gt_inds = proposals.gt_inds
            pred_instances_mask = proposals.pred_global_logits.sigmoid()
            # gather the matching GT mask for each predicted instance
            gt_insts_mask = gt_insts_mask[gt_inds].unsqueeze(dim=1).to(
                dtype=pred_instances_mask.dtype)
            loss_mask = dice_loss(pred_instances_mask, gt_insts_mask).mean()
        else:
            # no sampled foreground: return a zero loss that still carries grad
            loss_mask = dummy_feature.sum() * 0.
        return {"loss_mask": loss_mask}
Example #2
    def projection_losses(self, gt_masks, proposals, dummy_feature):
        """
        Arguments:
            gt_masks (List[Tensor]):
                a list of N elements, where N is the number of GT instances in
                the batch; each element has shape (1, mask_h, mask_w)
                box-mask ground truth: 1 inside the box, 0 outside
            proposals (Instances):
                an Instances object holding all sampled foreground information
                for the batch, so len(proposals) depends on the select_instances
                function. Two fields are required for loss computation when
                len(proposals) > 0:
                "gt_inds": len(proposals) elements, mapping each predicted
                    instance to its GT instance
                "pred_global_logits": shape (len(proposals), 1, mask_h, mask_w),
                    the predicted foreground segmentation logits
            dummy_feature (Tensor): a tensor with requires_grad=True, used only
                when len(proposals) == 0
        Returns:
            dict[str -> Tensor]:
                mapping from a loss name to a scalar tensor storing the loss.
                Used during training only. The only key is "loss_proj".
        """
        if len(proposals):
            gt_inds = proposals.gt_inds
            pred_instances_mask = proposals.pred_global_logits.sigmoid()
            # gather the matching GT box mask for each predicted instance;
            # torch.cat stacks the (1, mask_h, mask_w) masks into (N, mask_h, mask_w)
            gt_masks = torch.cat(gt_masks)[gt_inds].to(
                dtype=pred_instances_mask.dtype)
            gt_proj_x = gt_masks.max(dim=1)[0]
            gt_proj_y = gt_masks.max(dim=2)[0]
            # project the predicted mask onto each axis to compare with the
            # box projections
            pred_x_proj = pred_instances_mask.squeeze(1).max(dim=1)[0]
            pred_y_proj = pred_instances_mask.squeeze(1).max(dim=2)[0]

            loss_proj = dice_loss(pred_x_proj, gt_proj_x) + \
                dice_loss(pred_y_proj, gt_proj_y)
            loss_proj = loss_proj.mean()
        else:
            loss_proj = dummy_feature.sum() * 0.
        return {"loss_proj": loss_proj}
Example #3
    def losses(self, ins_preds, cate_preds, ins_label_list, cate_label_list,
               ins_ind_label_list):
        """
        Compute losses:

            L = L_cate + λ * L_mask

        Args:
            ins_preds (list[Tensor]): each element is the mask prediction of one
                level, with shape [N, G*G, H, W], where:
                * N is the number of images per mini-batch
                * G is the side length of that level's grid
                * H and W are the height and width of the predicted mask

            cate_preds (list[Tensor]): each element is the category prediction of
                one level, with shape [N, C, G, G], where:
                * C is the number of classes

            ins_label_list (list[list[Tensor]]): each element corresponds to one
                image and is itself a list of per-level mask ground-truth tensors
                with shape [H, W], where:
                * H and W are the ground-truth mask size per level (same as
                  `ins_preds`)

            cate_label_list (list[list[Tensor]]): each element corresponds to one
                image and is itself a list of per-level category ground-truth
                tensors with shape [G, G].

            ins_ind_label_list (list[list[Tensor]]): indicates which grid cells
                contain objects; only these cells contribute to the mask loss.
                Each element corresponds to one image and is itself a list of
                per-level indicator tensors with shape [G*G].

        Returns:
            dict[str -> Tensor]: losses.
        """
        # ins, per level
        ins_preds_valid = []
        ins_labels_valid = []
        cate_labels_valid = []
        num_images = len(ins_label_list)
        num_levels = len(ins_label_list[0])
        for level_idx in range(num_levels):
            ins_preds_per_level = []
            ins_labels_per_level = []
            cate_labels_per_level = []
            for img_idx in range(num_images):
                valid_ins_inds = ins_ind_label_list[img_idx][level_idx]
                ins_preds_per_level.append(
                    ins_preds[level_idx][img_idx][valid_ins_inds, ...])
                ins_labels_per_level.append(
                    ins_label_list[img_idx][level_idx][valid_ins_inds, ...])
                cate_labels_per_level.append(
                    cate_label_list[img_idx][level_idx].flatten())
            ins_preds_valid.append(torch.cat(ins_preds_per_level))
            ins_labels_valid.append(torch.cat(ins_labels_per_level))
            cate_labels_valid.append(torch.cat(cate_labels_per_level))

        # dice loss, per level
        loss_ins = []
        for pred, target in zip(ins_preds_valid, ins_labels_valid):
            if pred.size(0) == 0:
                continue
            pred = torch.sigmoid(pred)
            # rescale mask values from {0, 255} to {0, 1}
            target = target.float() / 255.
            loss_ins.append(dice_loss(pred, target))
        # loss_ins (list[Tensor]): each element has shape [#Ins], one dice
        # value per instance
        loss_ins = torch.cat(loss_ins).mean()
        loss_ins = loss_ins * self.loss_ins_weight

        # cate
        cate_preds = [
            cate_pred.permute(0, 2, 3, 1).reshape(-1, self.num_classes)
            for cate_pred in cate_preds
        ]
        cate_preds = torch.cat(cate_preds)

        flatten_cate_labels = torch.cat(cate_labels_valid)
        foreground_idxs = flatten_cate_labels != self.num_classes
        cate_labels = torch.zeros_like(cate_preds)
        cate_labels[foreground_idxs, flatten_cate_labels[foreground_idxs]] = 1
        num_ins = foreground_idxs.sum()

        loss_cate = self.loss_cat_weight * sigmoid_focal_loss_jit(
            cate_preds,
            cate_labels,
            alpha=self.loss_cat_alpha,
            gamma=self.loss_cat_gamma,
            reduction="sum",
        ) / max(1, num_ins)
        return dict(loss_ins=loss_ins, loss_cate=loss_cate)
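sigmoid_focal_loss_jit is presumably the JIT-scripted focal loss bundled with detectron2/fvcore; for reference, a plain-PyTorch equivalent matching the alpha/gamma/reduction signature used above:

import torch
import torch.nn.functional as F

def sigmoid_focal_loss(logits, targets, alpha=0.25, gamma=2.0, reduction="none"):
    p = torch.sigmoid(logits)
    ce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
    p_t = p * targets + (1 - p) * (1 - targets)
    loss = ce * (1 - p_t) ** gamma           # down-weight easy examples
    if alpha >= 0:
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        loss = alpha_t * loss                # balance positives vs. negatives
    if reduction == "sum":
        return loss.sum()
    if reduction == "mean":
        return loss.mean()
    return loss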
Example #4
    def losses(self, ins_preds_x, ins_preds_y, cate_preds, ins_label_list,
               cate_label_list, ins_ind_label_list, ins_ind_label_list_xy):
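        """
        Decoupled variant of `losses` in Example #3 (inferred from the indexing
        below): mask predictions are split into an x branch (`ins_preds_x`, one
        mask per grid column) and a y branch (`ins_preds_y`, one mask per grid
        row), and `ins_ind_label_list_xy` stores the (row, col) grid coordinates
        of positive cells per image and level. The remaining arguments match
        Example #3.
        """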
        # ins, per level
        ins_labels = []  # per level
        for ins_labels_level, ins_ind_labels_level in \
                zip(zip(*ins_label_list), zip(*ins_ind_label_list)):
            ins_labels_per_level = []
            for ins_labels_level_img, ins_ind_labels_level_img in \
                    zip(ins_labels_level, ins_ind_labels_level):
                ins_labels_per_level.append(
                    ins_labels_level_img[ins_ind_labels_level_img, ...])
            ins_labels.append(torch.cat(ins_labels_per_level))

        ins_preds_x_final = []
        for ins_preds_level_x, ins_ind_labels_level in \
                zip(ins_preds_x, zip(*ins_ind_label_list_xy)):
            ins_preds_x_final_per_level = []
            for ins_preds_level_img_x, ins_ind_labels_level_img in \
                    zip(ins_preds_level_x, ins_ind_labels_level):
                ins_preds_x_final_per_level.append(
                    ins_preds_level_img_x[ins_ind_labels_level_img[:, 1], ...])
            ins_preds_x_final.append(torch.cat(ins_preds_x_final_per_level))

        ins_preds_y_final = []
        for ins_preds_level_y, ins_ind_labels_level in \
                zip(ins_preds_y, zip(*ins_ind_label_list_xy)):
            ins_preds_y_final_per_level = []
            for ins_preds_level_img_y, ins_ind_labels_level_img in \
                    zip(ins_preds_level_y, ins_ind_labels_level):
                ins_preds_y_final_per_level.append(
                    ins_preds_level_img_y[ins_ind_labels_level_img[:, 0], ...])
            ins_preds_y_final.append(torch.cat(ins_preds_y_final_per_level))

        num_ins = 0.
        # dice loss, per level
        loss_ins = []
        for pred_x, pred_y, target in zip(ins_preds_x_final,
                                          ins_preds_y_final, ins_labels):
            mask_n = pred_x.size(0)
            if mask_n == 0:
                continue
            num_ins += mask_n
            # recombine the decoupled branches: elementwise product of sigmoids
            pred = pred_x.sigmoid() * pred_y.sigmoid()
            # rescale mask values from {0, 255} to {0, 1}
            target = target.float() / 255.
            loss_ins.append(dice_loss(pred, target))

        loss_ins = torch.cat(loss_ins).mean()
        loss_ins = loss_ins * self.loss_ins_weight

        # cate
        cate_preds = [
            cate_pred.permute(0, 2, 3, 1).reshape(-1, self.num_classes)
            for cate_pred in cate_preds
        ]
        cate_preds = torch.cat(cate_preds)

        cate_labels = []
        for cate_labels_level in zip(*cate_label_list):
            cate_labels_per_level = []
            for cate_labels_level_img in cate_labels_level:
                cate_labels_per_level.append(cate_labels_level_img.flatten())
            cate_labels.append(torch.cat(cate_labels_per_level))
        flatten_cate_labels = torch.cat(cate_labels)
        foreground_idxs = flatten_cate_labels != self.num_classes
        cate_labels = torch.zeros_like(cate_preds)
        cate_labels[foreground_idxs, flatten_cate_labels[foreground_idxs]] = 1

        loss_cate = self.loss_cat_weight * sigmoid_focal_loss_jit(
            cate_preds,
            cate_labels,
            alpha=self.loss_cat_alpha,
            gamma=self.loss_cat_gamma,
            reduction="sum",
        ) / max(1, num_ins)
        return dict(loss_ins=loss_ins, loss_cate=loss_cate)
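For reference, a hypothetical shape walk-through of how the decoupled branches are recombined for positive cells (the sizes and coordinates below are made up):

import torch

G, H, W = 4, 16, 16
pred_x = torch.randn(G, H, W)            # x branch: one mask per grid column
pred_y = torch.randn(G, H, W)            # y branch: one mask per grid row
pos_xy = torch.tensor([[1, 2], [3, 0]])  # (row, col) of two positive cells

# cell (i, j) is scored by mask_x[j] * mask_y[i], as in the loops above
mask = pred_x[pos_xy[:, 1]].sigmoid() * pred_y[pos_xy[:, 0]].sigmoid()
assert mask.shape == (2, H, W)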