Exemple #1
0
def postprocess(det_output,
                w,
                h,
                batch_idx=0,
                interpolation_mode='bilinear',
                visualize_lincomb=False,
                crop_masks=True,
                score_threshold=0):
    """
    Turn the raw test-mode output of Yolact into per-detection results,
    honouring the mask-related settings found on the global ``cfg``.

    Args:
        - det_output: The list of dicts that Detect outputs. The entry at
          ``batch_idx`` is read through its 'net' and 'detection' keys.
        - w: The real width of the image.
        - h: The real height of the image.
        - batch_idx: Index of the image inside the batch.
        - interpolation_mode: Mode forwarded to torch.nn.functional.interpolate
          when masks are resized.
        - visualize_lincomb: If True, the prototypes and mask coefficients are
          passed to display_lincomb before the masks are assembled.
        - crop_masks: If True, assembled masks are cropped with their boxes
          before being upsampled.
        - score_threshold: When > 0, only detections scoring strictly above it
          are kept. The filtering overwrites the entries of the detection dict.

    Returns (in the following order):
        - classes [num_det]: The class idx for each detection.
        - scores  [num_det]: The confidence score for each detection. When
          cfg.rescore_mask is set without cfg.rescore_bbox this is a two-item
          list: [scores, scores * maskiou prediction].
        - boxes   [num_det, 4]: The bounding box for each detection in absolute point form.
        - masks   [num_det, h, w]: Full image masks for each detection.
      When there is nothing to report, a list holding the same empty tensor
      four times is returned instead.
    """

    entry = det_output[batch_idx]
    net = entry['net']
    dets = entry['detection']

    if dets is None:
        # Note: one empty tensor object, referenced four times.
        return [torch.Tensor()] * 4

    if score_threshold > 0:
        keep = dets['score'] > score_threshold

        for key in dets:
            # Prototypes are shared by all detections, so they are not indexed.
            if key == 'proto':
                continue
            dets[key] = dets[key][keep]

        if dets['score'].size(0) == 0:
            return [torch.Tensor()] * 4

    # Pull the individual fields out of the (possibly filtered) dict.
    classes = dets['class']
    boxes = dets['box']
    scores = dets['score']
    masks = dets['mask']

    if cfg.mask_type == mask_type.lincomb and cfg.eval_mask_branch:
        # Up to here `masks` only holds the per-detection coefficients.
        protos = dets['proto']

        # Debug switch: dump the prototypes to disk.
        if cfg.mask_proto_debug:
            np.save('scripts/proto.npy', protos.cpu().numpy())

        if visualize_lincomb:
            display_lincomb(protos, masks)

        # Linear combination of prototypes, then the configured activation.
        masks = torch.matmul(protos, masks.t())
        masks = cfg.mask_proto_mask_activation(masks)

        # Cropping happens at prototype resolution, i.e. before upsampling.
        if crop_masks:
            masks = crop(masks, boxes)

        # Reorder to [num_dets, proto_h, proto_w].
        masks = masks.permute(2, 0, 1).contiguous()

        if cfg.use_maskiou:
            with timer.env('maskiou_net'), torch.no_grad():
                iou_pred = net.maskiou_net(masks.unsqueeze(1))
                # Keep only the prediction belonging to each detection's class.
                iou_pred = torch.gather(iou_pred,
                                        dim=1,
                                        index=classes.unsqueeze(1)).squeeze(1)
                if cfg.rescore_mask:
                    rescored = scores * iou_pred
                    scores = rescored if cfg.rescore_bbox else [scores, rescored]

        # Resize to the full image, then threshold in place at 0.5.
        masks = F.interpolate(masks.unsqueeze(0), (h, w),
                              mode=interpolation_mode,
                              align_corners=False).squeeze(0)
        masks.gt_(0.5)

    # Sanitize the x pair against the width and the y pair against the height.
    for lo, hi, extent in ((0, 2, w), (1, 3, h)):
        boxes[:, lo], boxes[:, hi] = sanitize_coordinates(boxes[:, lo],
                                                          boxes[:, hi],
                                                          extent,
                                                          cast=False)
    boxes = boxes.long()

    if cfg.mask_type == mask_type.direct and cfg.eval_mask_branch:
        # Paste every fixed-size mask into its box on an empty full-size canvas.
        canvas = torch.zeros(masks.size(0), h, w)

        for det_i in range(masks.size(0)):
            x1, y1, x2, y2 = boxes[det_i, :]

            box_w = x2 - x1
            box_h = y2 - y1

            # Degenerate boxes keep an all-zero mask.
            if box_w * box_h > 0 and box_w >= 0:
                patch = masks[det_i, :].view(1, 1, cfg.mask_size,
                                             cfg.mask_size)
                patch = F.interpolate(patch, (box_h, box_w),
                                      mode=interpolation_mode,
                                      align_corners=False)
                canvas[det_i, y1:y2, x1:x2] = patch.gt(0.5).float()

        masks = canvas

    return classes, scores, boxes, masks
Exemple #2
0
    def lincomb_mask_loss(self,
                          pos,
                          idx_t,
                          loc_data,
                          mask_data,
                          priors,
                          proto_data,
                          masks,
                          gt_box_t,
                          inst_data,
                          interpolation_mode='bilinear'):
        """
        Accumulate the prototype ("lincomb") mask loss over a batch.

        For every image the ground-truth masks are resized to the prototype
        resolution, the predicted masks are assembled as
        ``proto_data[idx] @ coefficients.t()`` for the positive priors, and a
        per-pixel loss (BCE for a sigmoid activation, smooth L1 otherwise) is
        summed. Optional behaviour is driven by the ``cfg.mask_proto_*`` flags.

        Args:
            pos: Per-image selector of positive priors (used to index idx_t,
                mask_data and gt_box_t). Cloned before being edited when
                cfg.mask_proto_remove_empty_masks is set.
            idx_t: For each prior, the index of its matched ground-truth mask.
            loc_data: Not used by this implementation.
            mask_data: Mask coefficients, indexed as [image, prior, :].
            priors: Not used by this implementation.
            proto_data: Prototypes; dims 1 and 2 give the mask height / width.
            masks: Per-image ground-truth masks (resized with F.interpolate).
            gt_box_t: Matched ground-truth boxes, in point form.
            inst_data: Optional coefficients for the diversity loss; the mask
                coefficients are used when it is None.
            interpolation_mode: Mode used to downsample the ground truth.

        Returns:
            dict with key 'M' (mask loss scaled by cfg.mask_alpha and divided by
            the prototype area) and, if cfg.mask_proto_coeff_diversity_loss is
            set, key 'D' (coefficient diversity loss).
        """
        proto_h = proto_data.size(1)
        proto_w = proto_data.size(2)

        def _pixel_loss(pred, target, reduction):
            # BCE (on predictions clamped to [0, 1]) for sigmoid masks,
            # smooth L1 for every other activation.
            if cfg.mask_proto_mask_activation == activation_func.sigmoid:
                return F.binary_cross_entropy(torch.clamp(pred, 0, 1),
                                              target,
                                              reduction=reduction)
            return F.smooth_l1_loss(pred, target, reduction=reduction)

        needs_boxes = cfg.mask_proto_normalize_emulate_roi_pooling or cfg.mask_proto_crop

        if cfg.mask_proto_remove_empty_masks:
            # Work on a copy: positives matched to vanished masks get zeroed.
            pos = pos.clone()

        mask_loss = 0
        diversity_loss = 0  # Coefficient diversity loss

        for img in range(mask_data.size(0)):
            with torch.no_grad():
                gt_small = F.interpolate(masks[img].unsqueeze(0),
                                         (proto_h, proto_w),
                                         mode=interpolation_mode,
                                         align_corners=False)
                # -> [proto_h, proto_w, num_gt]
                gt_small = gt_small.squeeze(0).permute(1, 2, 0).contiguous()

                if cfg.mask_proto_binarize_downsampled_gt:
                    gt_small = gt_small.gt(0.5).float()

                if cfg.mask_proto_remove_empty_masks:
                    # Drop positives whose gt mask was downsampled away.
                    vanished = gt_small.sum(dim=(0, 1)) <= 0.0001
                    for gt_i, is_gone in enumerate(vanished):
                        if is_gone:
                            pos[img, idx_t[img] == gt_i] = 0

                if cfg.mask_proto_reweight_mask_loss:
                    # The weights need a binary ground truth.
                    if cfg.mask_proto_binarize_downsampled_gt:
                        gt_bin = gt_small
                    else:
                        gt_bin = gt_small.gt(0.5).float()

                    fg_norm = gt_bin / (
                        torch.sum(gt_bin, dim=(0, 1), keepdim=True) + 0.0001)
                    bg_norm = (1 - gt_bin) / (torch.sum(
                        1 - gt_bin, dim=(0, 1), keepdim=True) + 0.0001)

                    pixel_weights = fg_norm * cfg.mask_proto_reweight_coeff + bg_norm
                    pixel_weights *= proto_h * proto_w

            positives = pos[img]
            matched_gt = idx_t[img, positives]

            if needs_boxes:
                # Point-form boxes of the positives.
                matched_boxes = gt_box_t[img, positives]

            if matched_gt.size(0) == 0:
                continue

            protos = proto_data[img]
            coeffs = mask_data[img, positives, :]

            if cfg.mask_proto_coeff_diversity_loss:
                if inst_data is None:
                    div_coeffs = coeffs
                else:
                    div_coeffs = inst_data[img, positives, :]

                diversity_loss += self.coeff_diversity_loss(
                    div_coeffs, matched_gt)

            # Too many positives: train on a random subset only.
            total_pos = coeffs.size(0)
            if total_pos > cfg.masks_to_train:
                chosen = torch.randperm(coeffs.size(0))[:cfg.masks_to_train]

                coeffs = coeffs[chosen, :]
                matched_gt = matched_gt[chosen]

                if needs_boxes:
                    matched_boxes = matched_boxes[chosen, :]

            used_pos = coeffs.size(0)
            target = gt_small[:, :, matched_gt]

            # Size: [proto_h, proto_w, used_pos]
            pred = torch.matmul(protos, coeffs.t())
            pred = cfg.mask_proto_mask_activation(pred)

            if cfg.mask_proto_double_loss:
                # Extra loss term on the masks before cropping.
                uncropped = _pixel_loss(pred, target, 'sum')
                mask_loss += cfg.mask_proto_double_loss_alpha * uncropped

            if cfg.mask_proto_crop:
                pred = crop(pred, matched_boxes)

            per_pixel = _pixel_loss(pred, target, 'none')

            if cfg.mask_proto_normalize_mask_loss_by_sqrt_area:
                area = torch.sum(target, dim=(0, 1), keepdim=True)
                per_pixel = per_pixel / (torch.sqrt(area) + 0.0001)

            if cfg.mask_proto_reweight_mask_loss:
                per_pixel = per_pixel * pixel_weights[:, :, matched_gt]

            if cfg.mask_proto_normalize_emulate_roi_pooling:
                scale = proto_h * proto_w if cfg.mask_proto_crop else 1
                boxes_cs = center_size(matched_boxes)
                box_w = boxes_cs[:, 2] * proto_w
                box_h = boxes_cs[:, 3] * proto_h
                per_pixel = per_pixel.sum(dim=(0, 1)) / box_w / box_h * scale

            # Compensate for the positives dropped by the random subsampling.
            if total_pos > used_pos:
                per_pixel *= total_pos / used_pos

            mask_loss += torch.sum(per_pixel)

        losses = {'M': mask_loss * cfg.mask_alpha / proto_h / proto_w}

        if cfg.mask_proto_coeff_diversity_loss:
            losses['D'] = diversity_loss

        return losses
Exemple #3
0
def postprocess(det_output,
                w,
                h,
                batch_idx=0,
                interpolation_mode='bilinear',
                visualize_lincomb=False,
                crop_masks=True,
                score_threshold=0,
                mask_score=True):
    """
    Postprocesses the output of EPSNet on testing mode into a format that makes sense,
    accounting for all the possible configuration settings.

    Args:
        - det_output: The list of dicts that Detect outputs.
        - w: The real width of the image.
        - h: The real height of the image.
        - batch_idx: If you have multiple images for this batch, the image's index in the batch.
        - interpolation_mode: Mode forwarded to torch.nn.functional.interpolate
          when masks are resized.
        - visualize_lincomb: If True, the prototypes and mask coefficients are
          passed to display_lincomb before the masks are assembled.
        - crop_masks: If True, assembled masks are cropped with their boxes
          before being upsampled.
        - score_threshold: When > 0, only detections scoring strictly above it
          are kept. The filtering overwrites the entries of the detection dict.
        - mask_score: When truthy, the mask activation is applied and masks are
          binarized at 0.5; otherwise the raw values are upsampled and returned.
          Boxes are only sanitized and cast to long when this is exactly True.

    Returns 4 torch Tensors (in the following order):
        - classes [num_det]: The class idx for each detection.
        - scores  [num_det]: The confidence score for each detection.
        - boxes   [num_det, 4]: The bounding box for each detection in absolute point form.
        - masks   [num_det, h, w]: Full image masks for each detection.
      When the dict has no 'score' entry, or the score threshold removes every
      detection, a list holding the same empty tensor four times is returned.
    """

    dets = det_output[batch_idx]

    if 'score' not in dets:
        return [torch.Tensor()
                ] * 4  # Warning, this is 4 copies of the same thing

    # Entries that are not indexed per detection and therefore must never be
    # filtered with a per-detection mask.
    # NOTE(review): 'segm' is presumably a whole-image output - confirm.
    shared_keys = ('proto', 'segm')

    if score_threshold > 0:
        keep = dets['score'] > score_threshold

        for k in dets:
            if k not in shared_keys:
                dets[k] = dets[k][keep]

        if dets['score'].size(0) == 0:
            return [torch.Tensor()] * 4

    # im_w and im_h when it concerns bboxes. This is a workaround hack for preserve_aspect_ratio
    b_w, b_h = (w, h)

    # Undo the padding introduced with preserve_aspect_ratio
    if cfg.preserve_aspect_ratio:
        r_w, r_h = Resize.faster_rcnn_scale(w, h, cfg.min_size, cfg.max_size)

        # Get rid of any detections whose centers are outside the image
        boxes = dets['box']
        boxes = center_size(boxes)
        s_w, s_h = (r_w / cfg.max_size, r_h / cfg.max_size)

        not_outside = ((boxes[:, 0] > s_w) +
                       (boxes[:, 1] > s_h)) < 1  # not (a or b)
        for k in dets:
            # Skip the same shared entries as the score filter above; the
            # original only skipped 'proto' here.
            if k not in shared_keys:
                dets[k] = dets[k][not_outside]

        # A hack to scale the bboxes to the right size
        b_w, b_h = (cfg.max_size / r_w * w, cfg.max_size / r_h * h)

    # Actually extract everything from dets now
    classes = dets['class']
    boxes = dets['box']
    scores = dets['score']
    masks = dets['mask']

    if cfg.mask_type == mask_type.lincomb and cfg.eval_mask_branch:
        # At this point masks is only the coefficients
        proto_data = dets['proto']

        # Debug switch: dump the prototypes to disk
        if cfg.mask_proto_debug:
            np.save('scripts/proto.npy', proto_data.cpu().numpy())

        if visualize_lincomb:
            display_lincomb(proto_data, masks)

        # Linear combination of the prototypes
        masks = torch.matmul(proto_data, masks.t())
        if mask_score:
            masks = cfg.mask_proto_mask_activation(masks)

        # Crop masks at prototype resolution, i.e. before upsampling
        if crop_masks:
            masks = crop(masks, boxes)

        # Permute into the correct output shape [num_dets, proto_h, proto_w]
        masks = masks.permute(2, 0, 1).contiguous()

        # Scale masks up to the full image
        if cfg.preserve_aspect_ratio:
            # Undo padding
            masks = masks[:, :int(r_h / cfg.max_size * proto_data.size(1)
                                  ), :int(r_w / cfg.max_size *
                                          proto_data.size(2))]

        masks = F.interpolate(masks.unsqueeze(0), (h, w),
                              mode=interpolation_mode,
                              align_corners=False).squeeze(0)
        # Binarize the masks
        if mask_score:
            masks.gt_(0.5)

    if mask_score is True:
        boxes[:, 0], boxes[:, 2] = sanitize_coordinates(boxes[:, 0],
                                                        boxes[:, 2],
                                                        b_w,
                                                        cast=False)
        boxes[:, 1], boxes[:, 3] = sanitize_coordinates(boxes[:, 1],
                                                        boxes[:, 3],
                                                        b_h,
                                                        cast=False)
        boxes = boxes.long()

    if cfg.mask_type == mask_type.direct and cfg.eval_mask_branch:
        # Paste every fixed-size mask into its box on an empty full-size canvas
        full_masks = torch.zeros(masks.size(0), h, w)

        for jdx in range(masks.size(0)):
            x1, y1, x2, y2 = boxes[jdx, :]

            mask_w = x2 - x1
            mask_h = y2 - y1

            # Degenerate boxes keep an all-zero mask
            if mask_w * mask_h <= 0 or mask_w < 0:
                continue

            mask = masks[jdx, :].view(1, 1, cfg.mask_size, cfg.mask_size)
            mask = F.interpolate(mask, (mask_h, mask_w),
                                 mode=interpolation_mode,
                                 align_corners=False)
            if mask_score:
                mask = mask.gt(0.5).float()
            full_masks[jdx, y1:y2, x1:x2] = mask

        masks = full_masks

    return classes, scores, boxes, masks
    def lincomb_mask_loss(self,
                          pos,
                          idx_t,
                          loc_data,
                          mask_data,
                          priors,
                          proto_data,
                          masks,
                          gt_box_t,
                          score_data,
                          inst_data,
                          labels,
                          dist_maps,
                          interpolation_mode='bilinear'):
        """
        Accumulate the prototype ("lincomb") mask loss, a surface loss and,
        optionally, the training examples for the mask-IoU head.

        For every image the ground-truth masks and distance maps are resized to
        the prototype resolution, the predicted masks are assembled as
        ``proto_data[idx] @ coefficients.t()`` for the positive priors, and a
        per-pixel loss (BCE for a sigmoid activation, smooth L1 otherwise) is
        summed. The surface loss is the mean of the (possibly cropped)
        predicted masks multiplied element-wise by the matched distance maps.

        Args:
            pos: Per-image selector of positive priors. Cloned before being
                edited when cfg.mask_proto_remove_empty_masks is set.
            idx_t: For each prior, the index of its matched ground-truth mask.
            loc_data: Box regressions; only decoded when
                cfg.mask_proto_crop_with_pred_box is set.
            mask_data: Mask coefficients, indexed as [image, prior, :].
            priors: Prior boxes used to decode loc_data.
            proto_data: Prototypes; dims 1 and 2 give the mask height / width.
            masks: Per-image ground-truth masks.
            gt_box_t: Matched ground-truth boxes, in point form.
            score_data: Mask scores, read only when cfg.use_mask_scoring is set.
            inst_data: Optional coefficients for the diversity loss; the mask
                coefficients are used when it is None.
            labels: Per-image ground-truth labels, indexed by matched mask.
            dist_maps: Per-image distance maps, resized like ``masks``.
            interpolation_mode: Mode used to downsample masks and distance maps.

        Returns:
            ``losses`` when cfg.use_maskiou is off, a dict with 'M' (mask loss),
            'MS' (surface loss) and optionally 'D' (diversity loss).
            When cfg.use_maskiou is on, the tuple ``(losses, maskiou_targets)``
            where ``maskiou_targets`` is ``[maskiou_net_input, maskiou_t,
            label_t]`` or None if no mask survived the area filter.
        """
        mask_h = proto_data.size(1)
        mask_w = proto_data.size(2)

        process_gt_bboxes = cfg.mask_proto_normalize_emulate_roi_pooling or cfg.mask_proto_crop

        if cfg.mask_proto_remove_empty_masks:
            # Make sure to store a copy of this because we edit it to get rid of all-zero masks
            pos = pos.clone()

        loss_m = 0
        loss_ms = 0  # Surface loss
        loss_d = 0  # Coefficient diversity loss

        maskiou_t_list = []
        maskiou_net_input_list = []
        label_t_list = []

        for idx in range(mask_data.size(0)):
            with torch.no_grad():
                downsampled_masks = F.interpolate(
                    masks[idx].unsqueeze(0), (mask_h, mask_w),
                    mode=interpolation_mode,
                    align_corners=False).squeeze(0)
                downsampled_masks = downsampled_masks.permute(1, 2,
                                                              0).contiguous()

                # The distance maps get exactly the same resizing as the masks.
                downsampled_dist_maps = F.interpolate(
                    dist_maps[idx].unsqueeze(0), (mask_h, mask_w),
                    mode=interpolation_mode,
                    align_corners=False).squeeze(0)
                downsampled_dist_maps = downsampled_dist_maps.permute(
                    1, 2, 0).contiguous()

                if cfg.mask_proto_binarize_downsampled_gt:
                    downsampled_masks = downsampled_masks.gt(0.5).float()

                if cfg.mask_proto_remove_empty_masks:
                    # Get rid of gt masks that are so small they get downsampled away
                    very_small_masks = (downsampled_masks.sum(dim=(0, 1)) <=
                                        0.0001)
                    for i in range(very_small_masks.size(0)):
                        if very_small_masks[i]:
                            pos[idx, idx_t[idx] == i] = 0

                if cfg.mask_proto_reweight_mask_loss:
                    # Ensure that the gt is binary
                    if not cfg.mask_proto_binarize_downsampled_gt:
                        bin_gt = downsampled_masks.gt(0.5).float()
                    else:
                        bin_gt = downsampled_masks

                    gt_foreground_norm = bin_gt / (
                        torch.sum(bin_gt, dim=(0, 1), keepdim=True) + 0.0001)
                    gt_background_norm = (1 - bin_gt) / (torch.sum(
                        1 - bin_gt, dim=(0, 1), keepdim=True) + 0.0001)

                    mask_reweighting = gt_foreground_norm * cfg.mask_proto_reweight_coeff + gt_background_norm
                    mask_reweighting *= mask_h * mask_w

            cur_pos = pos[idx]
            pos_idx_t = idx_t[idx, cur_pos]

            if process_gt_bboxes:
                # Note: this is in point-form
                if cfg.mask_proto_crop_with_pred_box:
                    pos_gt_box_t = decode(loc_data[idx, :, :], priors.data,
                                          cfg.use_yolo_regressors)[cur_pos]
                else:
                    pos_gt_box_t = gt_box_t[idx, cur_pos]

            if pos_idx_t.size(0) == 0:
                continue

            proto_masks = proto_data[idx]
            proto_coef = mask_data[idx, cur_pos, :]
            if cfg.use_mask_scoring:
                # Gathered (and subsampled below) but not consumed further in
                # this method.
                mask_scores = score_data[idx, cur_pos, :]

            if cfg.mask_proto_coeff_diversity_loss:
                if inst_data is not None:
                    div_coeffs = inst_data[idx, cur_pos, :]
                else:
                    div_coeffs = proto_coef

                loss_d += self.coeff_diversity_loss(div_coeffs, pos_idx_t)

            # If we have over the allowed number of masks, select a random sample
            old_num_pos = proto_coef.size(0)
            if old_num_pos > cfg.masks_to_train:
                perm = torch.randperm(proto_coef.size(0))
                select = perm[:cfg.masks_to_train]

                proto_coef = proto_coef[select, :]
                pos_idx_t = pos_idx_t[select]

                if process_gt_bboxes:
                    pos_gt_box_t = pos_gt_box_t[select, :]
                if cfg.use_mask_scoring:
                    mask_scores = mask_scores[select, :]

            num_pos = proto_coef.size(0)
            mask_t = downsampled_masks[:, :, pos_idx_t]
            dist_maps_t = downsampled_dist_maps[:, :, pos_idx_t]
            label_t = labels[idx][pos_idx_t]

            # Size: [mask_h, mask_w, num_pos]
            pred_masks = proto_masks @ proto_coef.t()
            pred_masks = cfg.mask_proto_mask_activation(pred_masks)

            if cfg.mask_proto_double_loss:
                # Extra loss term on the masks before cropping.
                if cfg.mask_proto_mask_activation == activation_func.sigmoid:
                    pre_loss = F.binary_cross_entropy(torch.clamp(
                        pred_masks, 0, 1),
                                                      mask_t,
                                                      reduction='sum')
                else:
                    pre_loss = F.smooth_l1_loss(pred_masks,
                                                mask_t,
                                                reduction='sum')

                loss_m += cfg.mask_proto_double_loss_alpha * pre_loss

            if cfg.mask_proto_crop:
                pred_masks = crop(pred_masks, pos_gt_box_t)

            if cfg.mask_proto_mask_activation == activation_func.sigmoid:
                pre_loss = F.binary_cross_entropy(torch.clamp(
                    pred_masks, 0, 1),
                                                  mask_t,
                                                  reduction='none')
            else:
                pre_loss = F.smooth_l1_loss(pred_masks,
                                            mask_t,
                                            reduction='none')

            # Surface loss: mean of prediction * distance map, computed on the
            # tensors rearranged to [1, num_pos, mask_h, mask_w].
            pc = pred_masks.unsqueeze(0).permute(0, 3, 1, 2)
            dc = dist_maps_t.unsqueeze(0).permute(0, 3, 1, 2)
            multipled = torch.einsum("bcwh,bcwh->bcwh", pc, dc)
            surface_loss = multipled.mean()
            loss_ms += surface_loss

            if cfg.mask_proto_normalize_mask_loss_by_sqrt_area:
                gt_area = torch.sum(mask_t, dim=(0, 1), keepdim=True)
                pre_loss = pre_loss / (torch.sqrt(gt_area) + 0.0001)

            if cfg.mask_proto_reweight_mask_loss:
                pre_loss = pre_loss * mask_reweighting[:, :, pos_idx_t]

            if cfg.mask_proto_normalize_emulate_roi_pooling:
                weight = mask_h * mask_w if cfg.mask_proto_crop else 1
                pos_gt_csize = center_size(pos_gt_box_t)
                gt_box_width = pos_gt_csize[:, 2] * mask_w
                gt_box_height = pos_gt_csize[:, 3] * mask_h
                pre_loss = pre_loss.sum(
                    dim=(0, 1)) / gt_box_width / gt_box_height * weight

            # If the number of masks were limited scale the loss accordingly
            if old_num_pos > num_pos:
                pre_loss *= old_num_pos / num_pos

            loss_m += torch.sum(pre_loss)

            if cfg.use_maskiou:
                if cfg.discard_mask_area > 0:
                    # Ignore masks whose ground-truth area is too small.
                    gt_mask_area = torch.sum(mask_t, dim=(0, 1))
                    select = gt_mask_area > cfg.discard_mask_area

                    if torch.sum(select) < 1:
                        continue

                    # The boxes only exist when they were gathered above;
                    # touching them otherwise raised UnboundLocalError.
                    if process_gt_bboxes:
                        pos_gt_box_t = pos_gt_box_t[select, :]
                    pred_masks = pred_masks[:, :, select]
                    mask_t = mask_t[:, :, select]
                    label_t = label_t[select]

                # The head sees the soft masks; the IoU target uses the
                # masks thresholded at 0.5.
                maskiou_net_input = pred_masks.permute(
                    2, 0, 1).contiguous().unsqueeze(1)
                pred_masks = pred_masks.gt(0.5).float()
                maskiou_t = self._mask_iou(pred_masks, mask_t)

                maskiou_net_input_list.append(maskiou_net_input)
                maskiou_t_list.append(maskiou_t)
                label_t_list.append(label_t)

        losses = {'M': loss_m * cfg.mask_alpha / mask_h / mask_w}
        losses['MS'] = loss_ms

        if cfg.mask_proto_coeff_diversity_loss:
            losses['D'] = loss_d

        if cfg.use_maskiou:
            # discard_mask_area discarded every mask in the batch, so nothing to do here
            if len(maskiou_t_list) == 0:
                return losses, None

            maskiou_t = torch.cat(maskiou_t_list)
            label_t = torch.cat(label_t_list)
            maskiou_net_input = torch.cat(maskiou_net_input_list)

            num_samples = maskiou_t.size(0)
            if cfg.maskious_to_train > 0 and num_samples > cfg.maskious_to_train:
                perm = torch.randperm(num_samples)
                # Cap at the mask-IoU budget that the guard above checks
                # (the original sliced with cfg.masks_to_train).
                select = perm[:cfg.maskious_to_train]
                maskiou_t = maskiou_t[select]
                label_t = label_t[select]
                maskiou_net_input = maskiou_net_input[select]

            return losses, [maskiou_net_input, maskiou_t, label_t]

        return losses
    def lincomb_mask_loss(self, pos, idx_t, loc_data, mask_data, priors, proto_data, masks, gt_box_t, score_data, inst_data, labels, interpolation_mode='bilinear'):
        #cw: proto_data : (batch_size, mask_h, mask_w, mask_dim) mask_dim은 k값 말하는 것 같음.
        # mask_data : (batch_size, num_priors, mask_dim)
        mask_h = proto_data.size(1)
        mask_w = proto_data.size(2)
        
        #cw : 둘 다 True
        process_gt_bboxes = cfg.mask_proto_normalize_emulate_roi_pooling or cfg.mask_proto_crop
        
        #cw: yolact_plus -- False
        if cfg.mask_proto_remove_empty_masks: 
            # Make sure to store a copy of this because we edit it to get rid of all-zero masks
            pos = pos.clone()

        loss_m = 0
        loss_d = 0 # Coefficient diversity loss

        maskiou_t_list = []
        maskiou_net_input_list = []
        label_t_list = []

        #cw: masks -- (num_objs,im_height,im_width)

        for idx in range(mask_data.size(0)):    
            #cw : 1. mask를 재조정하는 과정
            with torch.no_grad():
                downsampled_masks = F.interpolate(masks[idx].unsqueeze(0), (mask_h, mask_w),
                                                  mode=interpolation_mode, align_corners=False).squeeze(0)
                downsampled_masks = downsampled_masks.permute(1, 2, 0).contiguous() 
                #cw : permute를 거침으로써 원소들의 메모리 배치가 contiguous하지 않아졌으므로
                # contiguous() 를 통해 조정.
                # permute: 차원을 인덱스로 뒤섞어 줌                (h, w, 1)

                #cw : yolact_plus -- True
                #bilinear interpolate를 거친 GT mask를 다시 0, 1로 재조정합니다.(threshold == 0.5)
                if cfg.mask_proto_binarize_downsampled_gt:
                    downsampled_masks = downsampled_masks.gt(0.5).float()

                #cw : NOT NEEDED
                if cfg.mask_proto_remove_empty_masks:
                    # Get rid of gt masks that are so small they get downsampled away
                    very_small_masks = (downsampled_masks.sum(dim=(0,1)) <= 0.0001)
                    for i in range(very_small_masks.size(0)):
                        if very_small_masks[i]:
                            pos[idx, idx_t[idx] == i] = 0

                #cw : NOT NEEDED
                if cfg.mask_proto_reweight_mask_loss:
                    # Ensure that the gt is binary
                    if not cfg.mask_proto_binarize_downsampled_gt:
                        bin_gt = downsampled_masks.gt(0.5).float()
                    else:
                        bin_gt = downsampled_masks

                    gt_foreground_norm = bin_gt     / (torch.sum(bin_gt,   dim=(0,1), keepdim=True) + 0.0001)
                    gt_background_norm = (1-bin_gt) / (torch.sum(1-bin_gt, dim=(0,1), keepdim=True) + 0.0001)

                    mask_reweighting   = gt_foreground_norm * cfg.mask_proto_reweight_coeff + gt_background_norm
                    mask_reweighting  *= mask_h * mask_w

            #cw: 2. 
            
            #idx번째 배치샘플의 prior box중 confidence통과한 애들의 index만 가지고 있음
            cur_pos = pos[idx] 
            # match를 통과한 idx_t에서 cur_pos에서 True나온 부분들을 빼냄.
            # 본래 idx_t에서 골라진 애들은 cur_pos와 동일한게 perfect한것.
            pos_idx_t = idx_t[idx, cur_pos] # 그런애들을 따로 빼냄
            
            #cw : True
            if process_gt_bboxes:
                # Note: this is in point-form
                if cfg.mask_proto_crop_with_pred_box: #False
                    pos_gt_box_t = decode(loc_data[idx, :, :], priors.data, cfg.use_yolo_regressors)[cur_pos]
                else:
                    pos_gt_box_t = gt_box_t[idx, cur_pos]

            if pos_idx_t.size(0) == 0:
                continue

            proto_masks = proto_data[idx]
            proto_coef  = mask_data[idx, cur_pos, :]
            if cfg.use_mask_scoring:
                mask_scores = score_data[idx, cur_pos, :]

            #cw : NOT NEEDED
            if cfg.mask_proto_coeff_diversity_loss:
                if inst_data is not None:
                    div_coeffs = inst_data[idx, cur_pos, :]
                else:
                    div_coeffs = proto_coef

                loss_d += self.coeff_diversity_loss(div_coeffs, pos_idx_t)
            
            # If we have over the allowed number of masks, select a random sample
            old_num_pos = proto_coef.size(0)
            if old_num_pos > cfg.masks_to_train:    #cw yolact_plus -- 100개만큼의 mask에 대해서만 훈련.
                perm = torch.randperm(proto_coef.size(0))
                select = perm[:cfg.masks_to_train]

                proto_coef = proto_coef[select, :]
                pos_idx_t  = pos_idx_t[select]
                
                if process_gt_bboxes: #cw : True
                    pos_gt_box_t = pos_gt_box_t[select, :]
                if cfg.use_mask_scoring:
                    mask_scores = mask_scores[select, :]

            num_pos = proto_coef.size(0)
            mask_t = downsampled_masks[:, :, pos_idx_t]     
            label_t = labels[idx][pos_idx_t]     

            # Size: [mask_h, mask_w, num_pos]
            pred_masks = proto_masks @ proto_coef.t()
            pred_masks = cfg.mask_proto_mask_activation(pred_masks)

            if cfg.mask_proto_double_loss:
                if cfg.mask_proto_mask_activation == activation_func.sigmoid:
                    pre_loss = F.binary_cross_entropy(torch.clamp(pred_masks, 0, 1), mask_t, reduction='sum')
                else:
                    pre_loss = F.smooth_l1_loss(pred_masks, mask_t, reduction='sum')
                
                loss_m += cfg.mask_proto_double_loss_alpha * pre_loss

            if cfg.mask_proto_crop:
                pred_masks = crop(pred_masks, pos_gt_box_t)
            
            if cfg.mask_proto_mask_activation == activation_func.sigmoid:
                pre_loss = F.binary_cross_entropy(torch.clamp(pred_masks, 0, 1), mask_t, reduction='none')
            else:
                pre_loss = F.smooth_l1_loss(pred_masks, mask_t, reduction='none')

            if cfg.mask_proto_normalize_mask_loss_by_sqrt_area:
                gt_area  = torch.sum(mask_t, dim=(0, 1), keepdim=True)
                pre_loss = pre_loss / (torch.sqrt(gt_area) + 0.0001)
            
            if cfg.mask_proto_reweight_mask_loss:
                pre_loss = pre_loss * mask_reweighting[:, :, pos_idx_t]
                
            if cfg.mask_proto_normalize_emulate_roi_pooling:
                weight = mask_h * mask_w if cfg.mask_proto_crop else 1
                pos_gt_csize = center_size(pos_gt_box_t)
                gt_box_width  = pos_gt_csize[:, 2] * mask_w
                gt_box_height = pos_gt_csize[:, 3] * mask_h
                pre_loss = pre_loss.sum(dim=(0, 1)) / gt_box_width / gt_box_height * weight

            # If the number of masks were limited scale the loss accordingly
            if old_num_pos > num_pos:
                pre_loss *= old_num_pos / num_pos

            loss_m += torch.sum(pre_loss)

            #cw : yolact_plus -- True
            if cfg.use_maskiou:
                #cw yolact_plus -- 'discard_mask_area': 5*5,
                if cfg.discard_mask_area > 0:
                    gt_mask_area = torch.sum(mask_t, dim=(0, 1)) #cw 0과 1차원은 reduced 되고 나머지 차원에 맞게 sum진행.
                    select = gt_mask_area > cfg.discard_mask_area

                    if torch.sum(select) < 1:
                        continue

                    pos_gt_box_t = pos_gt_box_t[select, :]
                    pred_masks = pred_masks[:, :, select]
                    mask_t = mask_t[:, :, select]
                    label_t = label_t[select]

                maskiou_net_input = pred_masks.permute(2, 0, 1).contiguous().unsqueeze(1)
                pred_masks = pred_masks.gt(0.5).float()                
                maskiou_t = self._mask_iou(pred_masks, mask_t)
                
                maskiou_net_input_list.append(maskiou_net_input)
                maskiou_t_list.append(maskiou_t)
                label_t_list.append(label_t)
        
        losses = {'M': loss_m * cfg.mask_alpha / mask_h / mask_w}
        
        #cw : NOT NEEDED
        if cfg.mask_proto_coeff_diversity_loss:
            losses['D'] = loss_d

        #cw yolact_plus : True
        if cfg.use_maskiou:
            # discard_mask_area discarded every mask in the batch, so nothing to do here
            if len(maskiou_t_list) == 0:
                return losses, None

            maskiou_t = torch.cat(maskiou_t_list)
            label_t = torch.cat(label_t_list)
            maskiou_net_input = torch.cat(maskiou_net_input_list)

            num_samples = maskiou_t.size(0)
            if cfg.maskious_to_train > 0 and num_samples > cfg.maskious_to_train:
                perm = torch.randperm(num_samples)
                select = perm[:cfg.masks_to_train]
                maskiou_t = maskiou_t[select]
                label_t = label_t[select]
                maskiou_net_input = maskiou_net_input[select]

            return losses, [maskiou_net_input, maskiou_t, label_t]

        return losses
    def lincomb_mask_loss(self,
                          pos,
                          idx_t,
                          loc_data,
                          mask_data,
                          priors,
                          wrapper,
                          masks,
                          gt_box_t,
                          inst_data,
                          labels,
                          truths,
                          interpolation_mode='bilinear'):
        """
        Mask loss for the variant that combines the predicted coefficients
        with preselected per-category bases (``wrapper.get_bases``) instead of
        a learned prototype tensor.

        Shapes quoted below come from debug dumps left by the original author
        (batch size 1) and are not enforced by the code.

        :param pos: per-image mask of positive priors, indexed as ``pos[idx]``
            (dump: bool tensor of size [1, 19248]).
        :param idx_t: for every prior, the index used to pick its target out
            of the per-image GT masks (dump: size [1, 19248]).
        :param loc_data: unused in this variant.
        :param mask_data: predicted coefficients, indexed
            ``[idx, cur_pos, :]``; its first dimension is the batch size.
        :param priors: unused in this variant.
        :param wrapper: object exposing ``get_bases(category_id)``; each result
            is stored into a (64, 64, 50) slot.
        :param masks: GT masks per image (dump: ``masks[idx]`` is
            [5, 550, 550]).
        :param gt_box_t: GT boxes matched to the priors, in point form
            (dump: size [1, 19248, 4]).
        :param inst_data: optional tensor used instead of the coefficients for
            the coefficient diversity loss; may be None.
        :param labels: Python list indexed as ``labels[idx][truth_id]``.
        :param truths: GT boxes according to the original author; only
            ``truths.size(0)`` is read here.
        :param interpolation_mode: unused in this variant (GT masks go through
            ``crop_gt_mask`` instead of ``F.interpolate``).
        :return: dict with key ``'M'`` (mask loss) and, when
            ``cfg.mask_proto_coeff_diversity_loss`` is set, ``'D'``
            (coefficient diversity loss).
        """
        # Hard-coded normalisation size; the original author's dump shows the
        # prototype output as [1, 138, 138, 32].
        mask_h, mask_w = 138, 138

        if cfg.mask_proto_remove_empty_masks:
            # Work on a copy so the caller's tensor is never edited.
            pos = pos.clone()

        loss_m = 0
        loss_d = 0  # Coefficient diversity loss

        for idx in range(mask_data.size(0)):  # one iteration per batch item
            cur_pos = pos[idx]
            pos_idx_t = idx_t[idx, cur_pos]

            if pos_idx_t.size(0) == 0:
                continue

            # Point-form GT boxes of the positive priors. Bound
            # unconditionally because crop_gt_mask below always consumes them;
            # binding them only under the crop / ROI-pooling flags left the
            # name undefined when both flags were off.
            pos_gt_box_t = gt_box_t[idx, cur_pos]

            with torch.no_grad():
                # Replaces the F.interpolate-based downsampling of the GT
                # masks. 550 matches the mask side length in the author's dump
                # -- presumably the input resolution, TODO confirm.
                downsampled_masks = crop_gt_mask(masks[idx], pos_gt_box_t, 550)

                if cfg.mask_proto_binarize_downsampled_gt:
                    downsampled_masks = downsampled_masks.gt(0.5).float()

            proto_coef = mask_data[idx, cur_pos, :]

            if cfg.mask_proto_coeff_diversity_loss:
                if inst_data is not None:
                    div_coeffs = inst_data[idx, cur_pos, :]
                else:
                    div_coeffs = proto_coef

                loss_d += self.coeff_diversity_loss(div_coeffs, pos_idx_t)

            # If we have over the allowed number of masks, select a random sample
            old_num_pos = proto_coef.size(0)
            if old_num_pos > cfg.masks_to_train:
                perm = torch.randperm(proto_coef.size(0))
                select = perm[:cfg.masks_to_train]

                proto_coef = proto_coef[select, :]
                pos_idx_t = pos_idx_t[select]
                pos_gt_box_t = pos_gt_box_t[select, :]

            num_pos = proto_coef.size(0)
            mask_t = downsampled_masks[:, :, pos_idx_t]

            # Load the preselected bases, one (64, 64, 50) slot per entry of
            # truths; slots whose category id is 0 stay all-zero.
            gt_bases = np.zeros((truths.size(0), 64, 64, 50))
            for truth_id in range(truths.size(0)):
                current_cat_id = labels[idx][truth_id].item()
                if current_cat_id == 0:
                    # TODO: confirm that id 0 really is the background class.
                    continue
                gt_bases[truth_id] = wrapper.get_bases(current_cat_id)
            # NOTE(review): `[idx]` keeps only the slot whose index equals the
            # batch index (the original comment blames "the batch_size
            # problem"); confirm this is intended for batch sizes > 1.
            gt_bases = torch.from_numpy(gt_bases)[idx].float().to(
                proto_coef.device)

            # [64, 64, 50] @ [50, num_pos] -> one predicted mask per positive
            pred_masks = gt_bases @ proto_coef.t()
            pred_masks = cfg.mask_proto_mask_activation(pred_masks)

            # The predictions are deliberately not passed through crop() here;
            # the original author notes that crop() is not the right operation
            # for bringing mask_t and pred_masks into agreement.

            if cfg.mask_proto_mask_activation == activation_func.sigmoid:
                pre_loss = F.binary_cross_entropy(
                    torch.clamp(pred_masks, 0, 1), mask_t, reduction='none')
            else:
                pre_loss = F.smooth_l1_loss(pred_masks,
                                            mask_t,
                                            reduction='none')

            if cfg.mask_proto_normalize_emulate_roi_pooling:
                weight = mask_h * mask_w if cfg.mask_proto_crop else 1
                pos_gt_csize = center_size(pos_gt_box_t)
                gt_box_width = pos_gt_csize[:, 2] * mask_w
                gt_box_height = pos_gt_csize[:, 3] * mask_h
                pre_loss = pre_loss.sum(
                    dim=(0, 1)) / gt_box_width / gt_box_height * weight

            # If the number of masks were limited scale the loss accordingly
            if old_num_pos > num_pos:
                pre_loss *= old_num_pos / num_pos

            loss_m += torch.sum(pre_loss)

        losses = {'M': loss_m * cfg.mask_alpha / mask_h / mask_w}

        # Expose the diversity loss instead of silently discarding it.
        if cfg.mask_proto_coeff_diversity_loss:
            losses['D'] = loss_d

        return losses
Exemple #7
0
    def lincomb_mask_loss(self,
                          pos,
                          idx_t,
                          loc_data,
                          mask_data,
                          priors,
                          proto_data,
                          masks,
                          gt_box_t,
                          score_data,
                          inst_data,
                          labels,
                          interpolation_mode='bilinear',
                          pred_seg=False):
        """
        Mask loss for linear-combination masks: each positive prior's mask is
        ``proto_data[idx] @ coefficients`` and is compared with the
        downsampled ground-truth mask it was matched to.

        Args:
            - pos: Per-image mask of positive priors, indexed as pos[idx].
                   Cloned before editing when
                   cfg.mask_proto_remove_empty_masks is set.
            - idx_t: For every prior, the index of its matched GT object
                   (used to index the GT masks and labels[idx]).
            - loc_data: Box regressions; only decoded when
                   cfg.mask_proto_crop_with_pred_box is set.
            - mask_data: Predicted coefficients, indexed [idx, cur_pos, :].
            - priors: Prior boxes; only read together with loc_data.
            - proto_data: Prototypes, [batch, mask_h, mask_w, mask_dim].
            - masks: Per-image GT masks, resized here to (mask_h, mask_w).
            - gt_box_t: GT boxes matched to the priors, in point form.
            - score_data: Unused; kept so existing callers keep working.
            - inst_data: Optional tensor used instead of the coefficients for
                   the coefficient diversity loss; may be None.
            - labels: Per-image GT labels, indexed as labels[idx][pos_idx_t].
            - interpolation_mode: Mode passed to F.interpolate for the GT masks.
            - pred_seg: If True, also collect the activated, uncropped
                   predicted masks and their labels for every image.

        Returns:
            - If cfg.use_maskiou: (losses, None) when no mask survived the
              area filter, otherwise
              (losses, [maskiou_net_input, maskiou_t, label_t]).
              pred_seg outputs are not returned in this mode.
            - Else if pred_seg: (losses, pred_seg_list, label_list).
            - Else: losses.
            losses holds 'M' and, when cfg.mask_proto_coeff_diversity_loss is
            set, 'D'.
        """
        mask_h = proto_data.size(1)
        mask_w = proto_data.size(2)

        process_gt_bboxes = cfg.mask_proto_normalize_emulate_roi_pooling or cfg.mask_proto_crop

        if cfg.mask_proto_remove_empty_masks:
            # Make sure to store a copy of this because we edit it to get rid of all-zero masks
            pos = pos.clone()

        loss_m = 0
        loss_d = 0  # Coefficient diversity loss

        maskiou_t_list = []
        maskiou_net_input_list = []
        label_t_list = []

        # pred_seg outputs get their own label list: sharing label_t_list with
        # the maskiou branch appended every label twice when both were on.
        pred_seg_list = []
        pred_seg_label_list = []

        for idx in range(mask_data.size(0)):
            with torch.no_grad():
                downsampled_masks = F.interpolate(
                    masks[idx].unsqueeze(0), (mask_h, mask_w),
                    mode=interpolation_mode,
                    align_corners=False).squeeze(0)
                # [num_gt, mask_h, mask_w] -> [mask_h, mask_w, num_gt]
                downsampled_masks = downsampled_masks.permute(1, 2, 0).contiguous()

                if cfg.mask_proto_binarize_downsampled_gt:
                    downsampled_masks = downsampled_masks.gt(0.5).float()

                if cfg.mask_proto_remove_empty_masks:
                    # Get rid of gt masks that are so small they get downsampled away
                    very_small_masks = (downsampled_masks.sum(dim=(0, 1)) <= 0.0001)
                    for i in range(very_small_masks.size(0)):
                        if very_small_masks[i]:
                            pos[idx, idx_t[idx] == i] = 0

                if cfg.mask_proto_reweight_mask_loss:
                    # Ensure that the gt is binary
                    if not cfg.mask_proto_binarize_downsampled_gt:
                        bin_gt = downsampled_masks.gt(0.5).float()
                    else:
                        bin_gt = downsampled_masks

                    gt_foreground_norm = bin_gt / (
                        torch.sum(bin_gt, dim=(0, 1), keepdim=True) + 0.0001)
                    gt_background_norm = (1 - bin_gt) / (
                        torch.sum(1 - bin_gt, dim=(0, 1), keepdim=True) + 0.0001)

                    mask_reweighting = gt_foreground_norm * cfg.mask_proto_reweight_coeff + gt_background_norm
                    mask_reweighting *= mask_h * mask_w

            cur_pos = pos[idx]
            pos_idx_t = idx_t[idx, cur_pos]

            if process_gt_bboxes:
                # Note: this is in point-form
                if cfg.mask_proto_crop_with_pred_box:
                    pos_gt_box_t = decode(loc_data[idx, :, :], priors.data,
                                          cfg.use_yolo_regressors)[cur_pos]
                else:
                    pos_gt_box_t = gt_box_t[idx, cur_pos]

            if pos_idx_t.size(0) == 0:
                continue

            proto_masks = proto_data[idx]
            proto_coef = mask_data[idx, cur_pos, :]

            if cfg.mask_proto_coeff_diversity_loss:
                if inst_data is not None:
                    div_coeffs = inst_data[idx, cur_pos, :]
                else:
                    div_coeffs = proto_coef

                loss_d += self.coeff_diversity_loss(div_coeffs, pos_idx_t)

            # If we have over the allowed number of masks, select a random sample
            old_num_pos = proto_coef.size(0)
            if old_num_pos > cfg.masks_to_train:
                perm = torch.randperm(proto_coef.size(0))
                select = perm[:cfg.masks_to_train]

                proto_coef = proto_coef[select, :]
                pos_idx_t = pos_idx_t[select]

                if process_gt_bboxes:
                    pos_gt_box_t = pos_gt_box_t[select, :]

            num_pos = proto_coef.size(0)
            mask_t = downsampled_masks[:, :, pos_idx_t]
            label_t = labels[idx][pos_idx_t]

            # Size: [mask_h, mask_w, num_pos]
            pred_masks = proto_masks @ proto_coef.t()
            pred_masks = cfg.mask_proto_mask_activation(pred_masks)

            if pred_seg:
                # Collected after the activation and before any cropping.
                pred_seg_list.append(pred_masks)
                pred_seg_label_list.append(label_t)

            if cfg.mask_proto_double_loss:
                # Extra loss term on the uncropped prediction.
                if cfg.mask_proto_mask_activation == activation_func.sigmoid:
                    pre_loss = F.binary_cross_entropy(
                        torch.clamp(pred_masks, 0, 1), mask_t, reduction='sum')
                else:
                    pre_loss = F.smooth_l1_loss(pred_masks, mask_t, reduction='sum')

                loss_m += cfg.mask_proto_double_loss_alpha * pre_loss

            if cfg.mask_proto_crop:
                pred_masks = crop(pred_masks, pos_gt_box_t)

            if cfg.mask_proto_mask_activation == activation_func.sigmoid:
                pre_loss = F.binary_cross_entropy(
                    torch.clamp(pred_masks, 0, 1), mask_t, reduction='none')
            else:
                pre_loss = F.smooth_l1_loss(pred_masks, mask_t, reduction='none')

            if cfg.mask_proto_normalize_mask_loss_by_sqrt_area:
                gt_area = torch.sum(mask_t, dim=(0, 1), keepdim=True)
                pre_loss = pre_loss / (torch.sqrt(gt_area) + 0.0001)

            if cfg.mask_proto_reweight_mask_loss:
                pre_loss = pre_loss * mask_reweighting[:, :, pos_idx_t]

            if cfg.mask_proto_normalize_emulate_roi_pooling:
                weight = mask_h * mask_w if cfg.mask_proto_crop else 1
                pos_gt_csize = center_size(pos_gt_box_t)
                gt_box_width = pos_gt_csize[:, 2] * mask_w
                gt_box_height = pos_gt_csize[:, 3] * mask_h
                pre_loss = pre_loss.sum(
                    dim=(0, 1)) / gt_box_width / gt_box_height * weight

            # If the number of masks were limited scale the loss accordingly
            if old_num_pos > num_pos:
                pre_loss *= old_num_pos / num_pos

            loss_m += torch.sum(pre_loss)

            if cfg.use_maskiou:
                if cfg.discard_mask_area > 0:
                    gt_mask_area = torch.sum(mask_t, dim=(0, 1))
                    select = gt_mask_area > cfg.discard_mask_area

                    if torch.sum(select) < 1:
                        continue

                    # The matched boxes are not filtered here: they are not
                    # read again in this iteration, and the name is unbound
                    # when process_gt_bboxes is False.
                    pred_masks = pred_masks[:, :, select]
                    mask_t = mask_t[:, :, select]
                    label_t = label_t[select]

                # [mask_h, mask_w, n] -> [n, 1, mask_h, mask_w]
                maskiou_net_input = pred_masks.permute(
                    2, 0, 1).contiguous().unsqueeze(1)
                pred_masks = pred_masks.gt(0.5).float()
                maskiou_t = self._mask_iou(pred_masks, mask_t)

                maskiou_net_input_list.append(maskiou_net_input)
                maskiou_t_list.append(maskiou_t)
                label_t_list.append(label_t)

        losses = {'M': loss_m * cfg.mask_alpha / mask_h / mask_w}

        if cfg.mask_proto_coeff_diversity_loss:
            losses['D'] = loss_d

        if cfg.use_maskiou:
            # discard_mask_area discarded every mask in the batch, so nothing to do here
            if not maskiou_t_list:
                return losses, None

            maskiou_t = torch.cat(maskiou_t_list)
            label_t = torch.cat(label_t_list)
            maskiou_net_input = torch.cat(maskiou_net_input_list)

            num_samples = maskiou_t.size(0)
            if cfg.maskious_to_train > 0 and num_samples > cfg.maskious_to_train:
                perm = torch.randperm(num_samples)
                # Cap with the same limit the guard above tests.
                select = perm[:cfg.maskious_to_train]
                maskiou_t = maskiou_t[select]
                label_t = label_t[select]
                maskiou_net_input = maskiou_net_input[select]

            return losses, [maskiou_net_input, maskiou_t, label_t]

        if pred_seg:
            return losses, pred_seg_list, pred_seg_label_list
        return losses