def postprocess(det_output, w, h, batch_idx=0, interpolation_mode='bilinear',
                visualize_lincomb=False, crop_masks=True, score_threshold=0):
    """
    Postprocesses the output of Yolact on testing mode into a format that makes sense,
    accounting for all the possible configuration settings.

    Args:
        - det_output: The list of dicts that Detect outputs. The entry at batch_idx is read
                      through its 'net' and 'detection' keys.
        - w: The real width of the image.
        - h: The real height of the image.
        - batch_idx: If you have multiple images for this batch, the image's index in the batch.
        - interpolation_mode: Can be 'nearest' | 'area' | 'bilinear'
                              (see torch.nn.functional.interpolate)
        - visualize_lincomb: If True, pass the prototypes and coefficients to display_lincomb.
        - crop_masks: If True, crop the assembled masks with their boxes before upsampling.
        - score_threshold: If > 0, only detections with a score above this value are kept.

    Returns 4 torch Tensors (in the following order):
        - classes [num_det]: The class idx for each detection.
        - scores  [num_det]: The confidence score for each detection. When cfg.use_maskiou and
                             cfg.rescore_mask are set but cfg.rescore_bbox is not, this is the
                             list [scores, scores * maskiou_p] instead.
        - boxes   [num_det, 4]: The bounding box for each detection in absolute point form.
        - masks   [num_det, h, w]: Full image masks for each detection.

    When nothing is detected (or nothing survives score_threshold) a list of four separate
    empty tensors is returned.

    Side effects: the 'detection' dict is filtered in place when score_threshold > 0, and the
    coordinates of its 'box' tensor are overwritten in place with the sanitized values.
    """
    dets = det_output[batch_idx]
    net = dets['net']
    dets = dets['detection']

    if dets is None:
        # One fresh tensor per slot so callers can't accidentally alias them
        return [torch.Tensor() for _ in range(4)]

    if score_threshold > 0:
        keep = dets['score'] > score_threshold

        for k in dets:
            # 'proto' is not indexed per detection, so it must not be filtered
            if k != 'proto':
                dets[k] = dets[k][keep]

        if dets['score'].size(0) == 0:
            return [torch.Tensor() for _ in range(4)]

    # Actually extract everything from dets now
    classes = dets['class']
    boxes = dets['box']
    scores = dets['score']
    masks = dets['mask']

    if cfg.mask_type == mask_type.lincomb and cfg.eval_mask_branch:
        # At this point masks is only the coefficients
        proto_data = dets['proto']

        # Debug flag: dump the prototypes to disk
        if cfg.mask_proto_debug:
            np.save('scripts/proto.npy', proto_data.cpu().numpy())

        if visualize_lincomb:
            display_lincomb(proto_data, masks)

        masks = proto_data @ masks.t()
        masks = cfg.mask_proto_mask_activation(masks)

        # Crop masks before upsampling (boxes are still in their unsanitized form here)
        if crop_masks:
            masks = crop(masks, boxes)

        # Permute into the correct output shape [num_dets, proto_h, proto_w]
        masks = masks.permute(2, 0, 1).contiguous()

        if cfg.use_maskiou:
            with timer.env('maskiou_net'):
                with torch.no_grad():
                    maskiou_p = net.maskiou_net(masks.unsqueeze(1))
                    # Pick, for every detection, the predicted IoU of its own class
                    maskiou_p = torch.gather(
                        maskiou_p, dim=1, index=classes.unsqueeze(1)).squeeze(1)
                    if cfg.rescore_mask:
                        if cfg.rescore_bbox:
                            scores = scores * maskiou_p
                        else:
                            scores = [scores, scores * maskiou_p]

        # Scale masks up to the full image
        masks = F.interpolate(masks.unsqueeze(0), (h, w), mode=interpolation_mode,
                              align_corners=False).squeeze(0)

        # Binarize the masks
        masks.gt_(0.5)

    boxes[:, 0], boxes[:, 2] = sanitize_coordinates(boxes[:, 0], boxes[:, 2], w, cast=False)
    boxes[:, 1], boxes[:, 3] = sanitize_coordinates(boxes[:, 1], boxes[:, 3], h, cast=False)
    boxes = boxes.long()

    if cfg.mask_type == mask_type.direct and cfg.eval_mask_branch:
        # Upscale masks. Allocate the canvas next to the masks so pasting into it does not
        # mix devices.
        full_masks = torch.zeros(masks.size(0), h, w, device=masks.device)

        for jdx in range(masks.size(0)):
            # Plain Python ints are what interpolate sizes and slices expect
            x1, y1, x2, y2 = boxes[jdx, :].tolist()

            mask_w = x2 - x1
            mask_h = y2 - y1

            # Just in case: skip degenerate boxes
            if mask_w <= 0 or mask_h <= 0:
                continue

            mask = masks[jdx, :].view(1, 1, cfg.mask_size, cfg.mask_size)
            mask = F.interpolate(mask, (mask_h, mask_w), mode=interpolation_mode,
                                 align_corners=False)
            mask = mask.gt(0.5).float()
            full_masks[jdx, y1:y2, x1:x2] = mask[0, 0]

        masks = full_masks

    return classes, scores, boxes, masks
def lincomb_mask_loss(self, pos, idx_t, loc_data, mask_data, priors, proto_data, masks, gt_box_t, inst_data, interpolation_mode='bilinear'):
    """
    Prototype (linear combination) mask loss, plus the optional coefficient diversity loss.

    For each image in the batch the ground-truth masks are resized to the prototype
    resolution, the predicted masks are assembled as prototypes @ coefficients for the
    positive priors, and a per-pixel loss between the two is accumulated.

    Args:
        - pos: Per-image selector of positive priors (indexed as pos[idx]). It is cloned
               before being edited when cfg.mask_proto_remove_empty_masks is set.
        - idx_t: For every prior, the index of the gt object it is matched to.
        - loc_data, priors: Not used by this variant.
        - mask_data: Mask coefficients, indexed as mask_data[idx, positives, :].
        - proto_data: Prototypes; dims 1 and 2 are used as the mask height and width.
        - masks: Per-image gt masks (masks[idx] is resized with F.interpolate).
        - gt_box_t: Per-prior gt boxes in point form.
        - inst_data: Optional alternative coefficients for the diversity loss.
        - interpolation_mode: Mode handed to F.interpolate for the gt masks.

    Returns:
        A dict with 'M' (mask loss) and, if cfg.mask_proto_coeff_diversity_loss is set,
        'D' (coefficient diversity loss).
    """
    eps = 0.0001
    proto_h = proto_data.size(1)
    proto_w = proto_data.size(2)

    needs_gt_boxes = cfg.mask_proto_normalize_emulate_roi_pooling or cfg.mask_proto_crop

    if cfg.mask_proto_remove_empty_masks:
        # Work on a copy: positives matched to vanished gt masks get switched off below
        pos = pos.clone()

    def mask_criterion(prediction, target, reduction):
        # BCE for sigmoid prototypes, smooth L1 for every other activation
        if cfg.mask_proto_mask_activation == activation_func.sigmoid:
            return F.binary_cross_entropy(torch.clamp(prediction, 0, 1), target,
                                          reduction=reduction)
        return F.smooth_l1_loss(prediction, target, reduction=reduction)

    total_mask_loss = 0
    total_diversity_loss = 0  # Coefficient diversity loss

    for img in range(mask_data.size(0)):
        with torch.no_grad():
            gt_small = F.interpolate(masks[img].unsqueeze(0), (proto_h, proto_w),
                                     mode=interpolation_mode, align_corners=False)
            # -> [proto_h, proto_w, num_gt]
            gt_small = gt_small.squeeze(0).permute(1, 2, 0).contiguous()

            if cfg.mask_proto_binarize_downsampled_gt:
                gt_small = gt_small.gt(0.5).float()

            if cfg.mask_proto_remove_empty_masks:
                # Drop positives whose gt mask is so small it got downsampled away
                vanished = gt_small.sum(dim=(0, 1)) <= eps
                for gt_idx, is_gone in enumerate(vanished):
                    if is_gone:
                        pos[img, idx_t[img] == gt_idx] = 0

            if cfg.mask_proto_reweight_mask_loss:
                # The reweighting needs a binary gt
                if cfg.mask_proto_binarize_downsampled_gt:
                    binary_gt = gt_small
                else:
                    binary_gt = gt_small.gt(0.5).float()

                fg_norm = binary_gt / (torch.sum(binary_gt, dim=(0, 1), keepdim=True) + eps)
                bg_norm = (1 - binary_gt) / (
                    torch.sum(1 - binary_gt, dim=(0, 1), keepdim=True) + eps)

                pixel_weights = fg_norm * cfg.mask_proto_reweight_coeff + bg_norm
                pixel_weights *= proto_h * proto_w

        positives = pos[img]
        matched_gt = idx_t[img, positives]

        if needs_gt_boxes:
            # Note: this is in point-form
            boxes_for_pos = gt_box_t[img, positives]

        if matched_gt.size(0) == 0:
            continue

        prototypes = proto_data[img]
        coeffs = mask_data[img, positives, :]

        if cfg.mask_proto_coeff_diversity_loss:
            diversity_input = coeffs if inst_data is None else inst_data[img, positives, :]
            total_diversity_loss += self.coeff_diversity_loss(diversity_input, matched_gt)

        # If we have over the allowed number of masks, train on a random sample only
        num_before = coeffs.size(0)
        if num_before > cfg.masks_to_train:
            chosen = torch.randperm(coeffs.size(0))[:cfg.masks_to_train]

            coeffs = coeffs[chosen, :]
            matched_gt = matched_gt[chosen]

            if needs_gt_boxes:
                boxes_for_pos = boxes_for_pos[chosen, :]

        num_after = coeffs.size(0)
        target = gt_small[:, :, matched_gt]

        # Size: [mask_h, mask_w, num_pos]
        predicted = cfg.mask_proto_mask_activation(prototypes @ coeffs.t())

        if cfg.mask_proto_double_loss:
            # Extra loss term on the uncropped prediction
            uncropped_loss = mask_criterion(predicted, target, 'sum')
            total_mask_loss += cfg.mask_proto_double_loss_alpha * uncropped_loss

        if cfg.mask_proto_crop:
            predicted = crop(predicted, boxes_for_pos)

        pixel_loss = mask_criterion(predicted, target, 'none')

        if cfg.mask_proto_normalize_mask_loss_by_sqrt_area:
            area = torch.sum(target, dim=(0, 1), keepdim=True)
            pixel_loss = pixel_loss / (torch.sqrt(area) + eps)

        if cfg.mask_proto_reweight_mask_loss:
            pixel_loss = pixel_loss * pixel_weights[:, :, matched_gt]

        if cfg.mask_proto_normalize_emulate_roi_pooling:
            scale = proto_h * proto_w if cfg.mask_proto_crop else 1
            box_csize = center_size(boxes_for_pos)
            box_w = box_csize[:, 2] * proto_w
            box_h = box_csize[:, 3] * proto_h
            pixel_loss = pixel_loss.sum(dim=(0, 1)) / box_w / box_h * scale

        # If the number of masks was limited, scale the loss back up accordingly
        if num_before > num_after:
            pixel_loss *= num_before / num_after

        total_mask_loss += torch.sum(pixel_loss)

    losses = {'M': total_mask_loss * cfg.mask_alpha / proto_h / proto_w}

    if cfg.mask_proto_coeff_diversity_loss:
        losses['D'] = total_diversity_loss

    return losses
def postprocess(det_output, w, h, batch_idx=0, interpolation_mode='bilinear',
                visualize_lincomb=False, crop_masks=True, score_threshold=0,
                mask_score=True):
    """
    Postprocesses the output of EPSNet on testing mode into a format that makes sense,
    accounting for all the possible configuration settings.

    Args:
        - det_output: The list of dicts that Detect outputs.
        - w: The real width of the image.
        - h: The real height of the image.
        - batch_idx: If you have multiple images for this batch, the image's index in the batch.
        - interpolation_mode: Can be 'nearest' | 'area' | 'bilinear'
                              (see torch.nn.functional.interpolate)
        - visualize_lincomb: If True, pass the prototypes and coefficients to display_lincomb.
        - crop_masks: If True, crop the assembled masks with their boxes before upsampling.
        - score_threshold: If > 0, only detections with a score above this value are kept.
        - mask_score: If True (default), masks go through the mask activation and are
                      binarized at 0.5, and boxes are converted to absolute integer
                      coordinates. Otherwise the masks are left un-activated and un-binarized
                      and the boxes are returned as stored in the detection dict.

    Returns 4 torch Tensors (in the following order):
        - classes [num_det]: The class idx for each detection.
        - scores  [num_det]: The confidence score for each detection.
        - boxes   [num_det, 4]: The bounding box for each detection in absolute point form.
        - masks   [num_det, h, w]: Full image masks for each detection.

    When the dict has no 'score' entry (or nothing survives score_threshold) a list of four
    separate empty tensors is returned.

    Side effects: the detection dict is filtered in place, and with mask_score=True the
    coordinates of its 'box' tensor are overwritten in place with the sanitized values.
    """
    dets = det_output[batch_idx]

    if 'score' not in dets:
        # One fresh tensor per slot so callers can't accidentally alias them
        return [torch.Tensor() for _ in range(4)]

    # Entries that are not indexed per detection and therefore must never be filtered
    per_image_keys = ('proto', 'segm')

    if score_threshold > 0:
        keep = dets['score'] > score_threshold

        for k in dets:
            if k not in per_image_keys:
                dets[k] = dets[k][keep]

        if dets['score'].size(0) == 0:
            return [torch.Tensor() for _ in range(4)]

    # im_w and im_h when it concerns bboxes.
    # This is a workaround hack for preserve_aspect_ratio
    b_w, b_h = (w, h)

    # Undo the padding introduced with preserve_aspect_ratio
    if cfg.preserve_aspect_ratio:
        r_w, r_h = Resize.faster_rcnn_scale(w, h, cfg.min_size, cfg.max_size)

        # Get rid of any detections whose centers are outside the image
        boxes = dets['box']
        boxes = center_size(boxes)
        s_w, s_h = (r_w / cfg.max_size, r_h / cfg.max_size)

        not_outside = ((boxes[:, 0] > s_w) + (boxes[:, 1] > s_h)) < 1  # not (a or b)
        for k in dets:
            # Same exclusions as the score filter: 'segm' is not per detection either
            if k not in per_image_keys:
                dets[k] = dets[k][not_outside]

        # A hack to scale the bboxes to the right size
        b_w, b_h = (cfg.max_size / r_w * w, cfg.max_size / r_h * h)

    # Actually extract everything from dets now
    classes = dets['class']
    boxes = dets['box']
    scores = dets['score']
    masks = dets['mask']

    if cfg.mask_type == mask_type.lincomb and cfg.eval_mask_branch:
        # At this point masks is only the coefficients
        proto_data = dets['proto']

        # Debug flag: dump the prototypes to disk
        if cfg.mask_proto_debug:
            np.save('scripts/proto.npy', proto_data.cpu().numpy())

        if visualize_lincomb:
            display_lincomb(proto_data, masks)

        masks = torch.matmul(proto_data, masks.t())
        if mask_score:
            masks = cfg.mask_proto_mask_activation(masks)

        # Crop masks before upsampling (boxes are still in their unsanitized form here)
        if crop_masks:
            masks = crop(masks, boxes)

        # Permute into the correct output shape [num_dets, proto_h, proto_w]
        masks = masks.permute(2, 0, 1).contiguous()

        # Scale masks up to the full image
        if cfg.preserve_aspect_ratio:
            # Undo padding
            masks = masks[:, :int(r_h / cfg.max_size * proto_data.size(1)),
                          :int(r_w / cfg.max_size * proto_data.size(2))]

        masks = F.interpolate(masks.unsqueeze(0), (h, w), mode=interpolation_mode,
                              align_corners=False).squeeze(0)

        # Binarize the masks
        if mask_score:
            masks.gt_(0.5)

    # NOTE(review): the integer cast is kept together with the sanitizing step under
    # mask_score; casting unsanitized boxes would truncate them. Confirm against callers
    # that use mask_score=False.
    if mask_score is True:
        boxes[:, 0], boxes[:, 2] = sanitize_coordinates(boxes[:, 0], boxes[:, 2], b_w,
                                                        cast=False)
        boxes[:, 1], boxes[:, 3] = sanitize_coordinates(boxes[:, 1], boxes[:, 3], b_h,
                                                        cast=False)
        boxes = boxes.long()

    if cfg.mask_type == mask_type.direct and cfg.eval_mask_branch:
        # Upscale masks. Allocate the canvas next to the masks so pasting into it does not
        # mix devices.
        full_masks = torch.zeros(masks.size(0), h, w, device=masks.device)

        for jdx in range(masks.size(0)):
            # Plain Python numbers are what interpolate sizes and slices expect
            x1, y1, x2, y2 = boxes[jdx, :].tolist()

            mask_w = x2 - x1
            mask_h = y2 - y1

            # Just in case: skip degenerate boxes
            if mask_w <= 0 or mask_h <= 0:
                continue

            mask = masks[jdx, :].view(1, 1, cfg.mask_size, cfg.mask_size)
            mask = F.interpolate(mask, (mask_h, mask_w), mode=interpolation_mode,
                                 align_corners=False)
            if mask_score:
                mask = mask.gt(0.5).float()
            full_masks[jdx, y1:y2, x1:x2] = mask[0, 0]

        masks = full_masks

    return classes, scores, boxes, masks
def lincomb_mask_loss(self, pos, idx_t, loc_data, mask_data, priors, proto_data, masks, gt_box_t, score_data, inst_data, labels, dist_maps, interpolation_mode='bilinear'):
    """
    Prototype (linear combination) mask loss with an additional distance-map surface loss.

    For each image in the batch the gt masks and distance maps are resized to the prototype
    resolution, the predicted masks are assembled as prototypes @ coefficients for the
    positive priors, and the per-pixel mask loss plus the mean of (prediction * distance map)
    are accumulated. With cfg.use_maskiou the inputs/targets for the mask IoU head are
    collected as well.

    Args:
        - pos: Per-image selector of positive priors (indexed as pos[idx]). It is cloned
               before being edited when cfg.mask_proto_remove_empty_masks is set.
        - idx_t: For every prior, the index of the gt object it is matched to.
        - loc_data, priors: Only read when cfg.mask_proto_crop_with_pred_box is set, to
                            decode the predicted boxes used for cropping.
        - mask_data: Mask coefficients, indexed as mask_data[idx, positives, :].
        - proto_data: Prototypes; dims 1 and 2 are used as the mask height and width.
        - masks: Per-image gt masks (masks[idx] is resized with F.interpolate).
        - gt_box_t: Per-prior gt boxes in point form.
        - score_data: Accepted for interface compatibility; not used by this loss.
        - inst_data: Optional alternative coefficients for the diversity loss.
        - labels: Per-image gt labels, indexed as labels[idx][gt_index].
        - dist_maps: Per-image distance maps, resized the same way as the gt masks.
        - interpolation_mode: Mode handed to F.interpolate.

    Returns:
        losses if cfg.use_maskiou is off, otherwise the tuple (losses, maskiou_targets) where
        maskiou_targets is [maskiou_net_input, maskiou_t, label_t] or None when no mask
        survived cfg.discard_mask_area. losses holds 'M' (mask loss), 'MS' (surface loss)
        and, if cfg.mask_proto_coeff_diversity_loss is set, 'D'.
    """
    mask_h = proto_data.size(1)
    mask_w = proto_data.size(2)

    process_gt_bboxes = cfg.mask_proto_normalize_emulate_roi_pooling or cfg.mask_proto_crop

    if cfg.mask_proto_remove_empty_masks:
        # Make sure to store a copy of this because we edit it to get rid of all-zero masks
        pos = pos.clone()

    loss_m = 0
    loss_ms = 0  # Surface (distance map) loss
    loss_d = 0   # Coefficient diversity loss

    maskiou_t_list = []
    maskiou_net_input_list = []
    label_t_list = []

    for idx in range(mask_data.size(0)):
        with torch.no_grad():
            downsampled_masks = F.interpolate(masks[idx].unsqueeze(0), (mask_h, mask_w),
                                              mode=interpolation_mode,
                                              align_corners=False).squeeze(0)
            downsampled_masks = downsampled_masks.permute(1, 2, 0).contiguous()

            # The distance maps follow exactly the same resize + permute as the gt masks
            downsampled_dist_maps = F.interpolate(dist_maps[idx].unsqueeze(0),
                                                  (mask_h, mask_w),
                                                  mode=interpolation_mode,
                                                  align_corners=False).squeeze(0)
            downsampled_dist_maps = downsampled_dist_maps.permute(1, 2, 0).contiguous()

            if cfg.mask_proto_binarize_downsampled_gt:
                downsampled_masks = downsampled_masks.gt(0.5).float()

            if cfg.mask_proto_remove_empty_masks:
                # Get rid of gt masks that are so small they get downsampled away
                very_small_masks = (downsampled_masks.sum(dim=(0, 1)) <= 0.0001)
                for i in range(very_small_masks.size(0)):
                    if very_small_masks[i]:
                        pos[idx, idx_t[idx] == i] = 0

            if cfg.mask_proto_reweight_mask_loss:
                # Ensure that the gt is binary
                if not cfg.mask_proto_binarize_downsampled_gt:
                    bin_gt = downsampled_masks.gt(0.5).float()
                else:
                    bin_gt = downsampled_masks

                gt_foreground_norm = bin_gt / (
                    torch.sum(bin_gt, dim=(0, 1), keepdim=True) + 0.0001)
                gt_background_norm = (1 - bin_gt) / (
                    torch.sum(1 - bin_gt, dim=(0, 1), keepdim=True) + 0.0001)

                mask_reweighting = (gt_foreground_norm * cfg.mask_proto_reweight_coeff
                                    + gt_background_norm)
                mask_reweighting *= mask_h * mask_w

        cur_pos = pos[idx]
        pos_idx_t = idx_t[idx, cur_pos]

        if process_gt_bboxes:
            # Note: this is in point-form
            if cfg.mask_proto_crop_with_pred_box:
                pos_gt_box_t = decode(loc_data[idx, :, :], priors.data,
                                      cfg.use_yolo_regressors)[cur_pos]
            else:
                pos_gt_box_t = gt_box_t[idx, cur_pos]

        if pos_idx_t.size(0) == 0:
            continue

        proto_masks = proto_data[idx]
        proto_coef = mask_data[idx, cur_pos, :]

        if cfg.mask_proto_coeff_diversity_loss:
            if inst_data is not None:
                div_coeffs = inst_data[idx, cur_pos, :]
            else:
                div_coeffs = proto_coef

            loss_d += self.coeff_diversity_loss(div_coeffs, pos_idx_t)

        # If we have over the allowed number of masks, select a random sample
        old_num_pos = proto_coef.size(0)
        if old_num_pos > cfg.masks_to_train:
            perm = torch.randperm(proto_coef.size(0))
            select = perm[:cfg.masks_to_train]

            proto_coef = proto_coef[select, :]
            pos_idx_t = pos_idx_t[select]

            if process_gt_bboxes:
                pos_gt_box_t = pos_gt_box_t[select, :]

        num_pos = proto_coef.size(0)
        mask_t = downsampled_masks[:, :, pos_idx_t]
        dist_maps_t = downsampled_dist_maps[:, :, pos_idx_t]
        label_t = labels[idx][pos_idx_t]

        # Size: [mask_h, mask_w, num_pos]
        pred_masks = proto_masks @ proto_coef.t()
        pred_masks = cfg.mask_proto_mask_activation(pred_masks)

        if cfg.mask_proto_double_loss:
            if cfg.mask_proto_mask_activation == activation_func.sigmoid:
                pre_loss = F.binary_cross_entropy(torch.clamp(pred_masks, 0, 1), mask_t,
                                                  reduction='sum')
            else:
                pre_loss = F.smooth_l1_loss(pred_masks, mask_t, reduction='sum')

            loss_m += cfg.mask_proto_double_loss_alpha * pre_loss

        if cfg.mask_proto_crop:
            pred_masks = crop(pred_masks, pos_gt_box_t)

        if cfg.mask_proto_mask_activation == activation_func.sigmoid:
            pre_loss = F.binary_cross_entropy(torch.clamp(pred_masks, 0, 1), mask_t,
                                              reduction='none')
        else:
            pre_loss = F.smooth_l1_loss(pred_masks, mask_t, reduction='none')

        # Surface loss (GAC_loss): mean of the prediction weighted by the distance map.
        # The mean of an elementwise product does not depend on the axis order, so no
        # unsqueeze/permute/einsum is needed.
        loss_ms += (pred_masks * dist_maps_t).mean()

        if cfg.mask_proto_normalize_mask_loss_by_sqrt_area:
            gt_area = torch.sum(mask_t, dim=(0, 1), keepdim=True)
            pre_loss = pre_loss / (torch.sqrt(gt_area) + 0.0001)

        if cfg.mask_proto_reweight_mask_loss:
            pre_loss = pre_loss * mask_reweighting[:, :, pos_idx_t]

        if cfg.mask_proto_normalize_emulate_roi_pooling:
            weight = mask_h * mask_w if cfg.mask_proto_crop else 1
            pos_gt_csize = center_size(pos_gt_box_t)
            gt_box_width = pos_gt_csize[:, 2] * mask_w
            gt_box_height = pos_gt_csize[:, 3] * mask_h
            pre_loss = pre_loss.sum(dim=(0, 1)) / gt_box_width / gt_box_height * weight

        # If the number of masks were limited scale the loss accordingly
        if old_num_pos > num_pos:
            pre_loss *= old_num_pos / num_pos

        loss_m += torch.sum(pre_loss)

        if cfg.use_maskiou:
            if cfg.discard_mask_area > 0:
                gt_mask_area = torch.sum(mask_t, dim=(0, 1))
                keep = gt_mask_area > cfg.discard_mask_area

                if torch.sum(keep) < 1:
                    continue

                # The boxes are not needed past this point, so only the tensors that feed
                # the mask IoU head are filtered.
                pred_masks = pred_masks[:, :, keep]
                mask_t = mask_t[:, :, keep]
                label_t = label_t[keep]

            maskiou_net_input = pred_masks.permute(2, 0, 1).contiguous().unsqueeze(1)
            pred_masks = pred_masks.gt(0.5).float()
            maskiou_t = self._mask_iou(pred_masks, mask_t)

            maskiou_net_input_list.append(maskiou_net_input)
            maskiou_t_list.append(maskiou_t)
            label_t_list.append(label_t)

    losses = {'M': loss_m * cfg.mask_alpha / mask_h / mask_w}
    losses['MS'] = loss_ms

    if cfg.mask_proto_coeff_diversity_loss:
        losses['D'] = loss_d

    if cfg.use_maskiou:
        # discard_mask_area discarded every mask in the batch, so nothing to do here
        if len(maskiou_t_list) == 0:
            return losses, None

        maskiou_t = torch.cat(maskiou_t_list)
        label_t = torch.cat(label_t_list)
        maskiou_net_input = torch.cat(maskiou_net_input_list)

        num_samples = maskiou_t.size(0)
        if cfg.maskious_to_train > 0 and num_samples > cfg.maskious_to_train:
            perm = torch.randperm(num_samples)
            # Subsample to the mask IoU budget (the guard above), not the mask budget
            select = perm[:cfg.maskious_to_train]
            maskiou_t = maskiou_t[select]
            label_t = label_t[select]
            maskiou_net_input = maskiou_net_input[select]

        return losses, [maskiou_net_input, maskiou_t, label_t]

    return losses
def lincomb_mask_loss(self, pos, idx_t, loc_data, mask_data, priors, proto_data, masks, gt_box_t, score_data, inst_data, labels, interpolation_mode='bilinear'):
    """
    Prototype (linear combination) mask loss, optionally collecting mask IoU targets.

    For each image in the batch the gt masks are resized to the prototype resolution, the
    predicted masks are assembled as prototypes @ coefficients for the positive priors, and
    a per-pixel loss between the two is accumulated.

    Args:
        - pos: Per-image selector of positive priors (indexed as pos[idx]). It is cloned
               before being edited when cfg.mask_proto_remove_empty_masks is set.
        - idx_t: For every prior, the index of the gt object it is matched to.
        - loc_data, priors: Only read when cfg.mask_proto_crop_with_pred_box is set, to
                            decode the predicted boxes used for cropping.
        - mask_data: Mask coefficients; per the original notes
                     (batch_size, num_priors, mask_dim).
        - proto_data: Prototypes; per the original notes
                      (batch_size, mask_h, mask_w, mask_dim), where mask_dim appears to be
                      the number of prototypes k.
        - masks: Per-image gt masks; per the original notes each entry is
                 (num_objs, im_height, im_width).
        - gt_box_t: Per-prior gt boxes in point form.
        - score_data: Accepted for interface compatibility; not used by this loss.
        - inst_data: Optional alternative coefficients for the diversity loss.
        - labels: Per-image gt labels, indexed as labels[idx][gt_index].
        - interpolation_mode: Mode handed to F.interpolate for the gt masks.

    Returns:
        losses if cfg.use_maskiou is off, otherwise the tuple (losses, maskiou_targets) where
        maskiou_targets is [maskiou_net_input, maskiou_t, label_t] or None when no mask
        survived cfg.discard_mask_area. losses holds 'M' (mask loss) and, if
        cfg.mask_proto_coeff_diversity_loss is set, 'D'.
    """
    mask_h = proto_data.size(1)
    mask_w = proto_data.size(2)

    # Original note: both flags are True in the annotated configuration
    process_gt_bboxes = cfg.mask_proto_normalize_emulate_roi_pooling or cfg.mask_proto_crop

    # Original note: False for yolact_plus
    if cfg.mask_proto_remove_empty_masks:
        # Make sure to store a copy of this because we edit it to get rid of all-zero masks
        pos = pos.clone()

    loss_m = 0
    loss_d = 0  # Coefficient diversity loss

    maskiou_t_list = []
    maskiou_net_input_list = []
    label_t_list = []

    for idx in range(mask_data.size(0)):
        # Step 1: bring the gt masks to the prototype resolution
        with torch.no_grad():
            downsampled_masks = F.interpolate(masks[idx].unsqueeze(0), (mask_h, mask_w),
                                              mode=interpolation_mode,
                                              align_corners=False).squeeze(0)
            # permute reorders the dimensions (channel-first -> channel-last) and leaves the
            # memory layout non-contiguous, hence the contiguous() call
            downsampled_masks = downsampled_masks.permute(1, 2, 0).contiguous()

            # Original note: True for yolact_plus. Re-binarize the interpolated gt masks to
            # 0/1 with a 0.5 threshold.
            if cfg.mask_proto_binarize_downsampled_gt:
                downsampled_masks = downsampled_masks.gt(0.5).float()

            # Original note: not needed for the annotated configuration
            if cfg.mask_proto_remove_empty_masks:
                # Get rid of gt masks that are so small they get downsampled away
                very_small_masks = (downsampled_masks.sum(dim=(0, 1)) <= 0.0001)
                for i in range(very_small_masks.size(0)):
                    if very_small_masks[i]:
                        pos[idx, idx_t[idx] == i] = 0

            # Original note: not needed for the annotated configuration
            if cfg.mask_proto_reweight_mask_loss:
                # Ensure that the gt is binary
                if not cfg.mask_proto_binarize_downsampled_gt:
                    bin_gt = downsampled_masks.gt(0.5).float()
                else:
                    bin_gt = downsampled_masks

                gt_foreground_norm = bin_gt / (
                    torch.sum(bin_gt, dim=(0, 1), keepdim=True) + 0.0001)
                gt_background_norm = (1 - bin_gt) / (
                    torch.sum(1 - bin_gt, dim=(0, 1), keepdim=True) + 0.0001)

                mask_reweighting = (gt_foreground_norm * cfg.mask_proto_reweight_coeff
                                    + gt_background_norm)
                mask_reweighting *= mask_h * mask_w

        # Step 2: select the positive priors of image idx and the gt index each one is
        # matched to
        cur_pos = pos[idx]
        pos_idx_t = idx_t[idx, cur_pos]

        if process_gt_bboxes:
            # Note: this is in point-form
            # Original note: mask_proto_crop_with_pred_box is False in the annotated config
            if cfg.mask_proto_crop_with_pred_box:
                pos_gt_box_t = decode(loc_data[idx, :, :], priors.data,
                                      cfg.use_yolo_regressors)[cur_pos]
            else:
                pos_gt_box_t = gt_box_t[idx, cur_pos]

        if pos_idx_t.size(0) == 0:
            continue

        proto_masks = proto_data[idx]
        proto_coef = mask_data[idx, cur_pos, :]

        # Original note: not needed for the annotated configuration
        if cfg.mask_proto_coeff_diversity_loss:
            if inst_data is not None:
                div_coeffs = inst_data[idx, cur_pos, :]
            else:
                div_coeffs = proto_coef

            loss_d += self.coeff_diversity_loss(div_coeffs, pos_idx_t)

        # If we have over the allowed number of masks, select a random sample
        # (original note: yolact_plus trains on at most 100 masks)
        old_num_pos = proto_coef.size(0)
        if old_num_pos > cfg.masks_to_train:
            perm = torch.randperm(proto_coef.size(0))
            select = perm[:cfg.masks_to_train]

            proto_coef = proto_coef[select, :]
            pos_idx_t = pos_idx_t[select]

            if process_gt_bboxes:
                pos_gt_box_t = pos_gt_box_t[select, :]

        num_pos = proto_coef.size(0)
        mask_t = downsampled_masks[:, :, pos_idx_t]
        label_t = labels[idx][pos_idx_t]

        # Size: [mask_h, mask_w, num_pos]
        pred_masks = proto_masks @ proto_coef.t()
        pred_masks = cfg.mask_proto_mask_activation(pred_masks)

        if cfg.mask_proto_double_loss:
            if cfg.mask_proto_mask_activation == activation_func.sigmoid:
                pre_loss = F.binary_cross_entropy(torch.clamp(pred_masks, 0, 1), mask_t,
                                                  reduction='sum')
            else:
                pre_loss = F.smooth_l1_loss(pred_masks, mask_t, reduction='sum')

            loss_m += cfg.mask_proto_double_loss_alpha * pre_loss

        if cfg.mask_proto_crop:
            pred_masks = crop(pred_masks, pos_gt_box_t)

        if cfg.mask_proto_mask_activation == activation_func.sigmoid:
            pre_loss = F.binary_cross_entropy(torch.clamp(pred_masks, 0, 1), mask_t,
                                              reduction='none')
        else:
            pre_loss = F.smooth_l1_loss(pred_masks, mask_t, reduction='none')

        if cfg.mask_proto_normalize_mask_loss_by_sqrt_area:
            gt_area = torch.sum(mask_t, dim=(0, 1), keepdim=True)
            pre_loss = pre_loss / (torch.sqrt(gt_area) + 0.0001)

        if cfg.mask_proto_reweight_mask_loss:
            pre_loss = pre_loss * mask_reweighting[:, :, pos_idx_t]

        if cfg.mask_proto_normalize_emulate_roi_pooling:
            weight = mask_h * mask_w if cfg.mask_proto_crop else 1
            pos_gt_csize = center_size(pos_gt_box_t)
            gt_box_width = pos_gt_csize[:, 2] * mask_w
            gt_box_height = pos_gt_csize[:, 3] * mask_h
            pre_loss = pre_loss.sum(dim=(0, 1)) / gt_box_width / gt_box_height * weight

        # If the number of masks were limited scale the loss accordingly
        if old_num_pos > num_pos:
            pre_loss *= old_num_pos / num_pos

        loss_m += torch.sum(pre_loss)

        # Original note: True for yolact_plus
        if cfg.use_maskiou:
            # Original note: yolact_plus uses discard_mask_area = 5*5
            if cfg.discard_mask_area > 0:
                # Sum over the two spatial dims, leaving one area per mask
                gt_mask_area = torch.sum(mask_t, dim=(0, 1))
                keep = gt_mask_area > cfg.discard_mask_area

                if torch.sum(keep) < 1:
                    continue

                # The boxes are not needed past this point, so only the tensors that feed
                # the mask IoU head are filtered.
                pred_masks = pred_masks[:, :, keep]
                mask_t = mask_t[:, :, keep]
                label_t = label_t[keep]

            maskiou_net_input = pred_masks.permute(2, 0, 1).contiguous().unsqueeze(1)
            pred_masks = pred_masks.gt(0.5).float()
            maskiou_t = self._mask_iou(pred_masks, mask_t)

            maskiou_net_input_list.append(maskiou_net_input)
            maskiou_t_list.append(maskiou_t)
            label_t_list.append(label_t)

    losses = {'M': loss_m * cfg.mask_alpha / mask_h / mask_w}

    # Original note: not needed for the annotated configuration
    if cfg.mask_proto_coeff_diversity_loss:
        losses['D'] = loss_d

    # Original note: True for yolact_plus
    if cfg.use_maskiou:
        # discard_mask_area discarded every mask in the batch, so nothing to do here
        if len(maskiou_t_list) == 0:
            return losses, None

        maskiou_t = torch.cat(maskiou_t_list)
        label_t = torch.cat(label_t_list)
        maskiou_net_input = torch.cat(maskiou_net_input_list)

        num_samples = maskiou_t.size(0)
        if cfg.maskious_to_train > 0 and num_samples > cfg.maskious_to_train:
            perm = torch.randperm(num_samples)
            # Subsample to the mask IoU budget (the guard above), not the mask budget
            select = perm[:cfg.maskious_to_train]
            maskiou_t = maskiou_t[select]
            label_t = label_t[select]
            maskiou_net_input = maskiou_net_input[select]

        return losses, [maskiou_net_input, maskiou_t, label_t]

    return losses
def lincomb_mask_loss(self, pos, idx_t, loc_data, mask_data, priors, wrapper, masks, gt_box_t, inst_data, labels, truths, interpolation_mode='bilinear'):
    """
    Mask loss computed against per-category preselected bases instead of predicted prototypes.

    For each image the gt masks are cropped with crop_gt_mask, the bases of every gt
    category are fetched through wrapper.get_bases, and the predicted masks are assembled
    as bases @ coefficients for the positive priors.

    :param pos: per-image selector of positive priors (indexed as pos[idx])
    :param idx_t: for every prior, the index of the gt object it is matched to
    :param loc_data: not used by this variant
    :param mask_data: mask coefficients, indexed as mask_data[idx, positives, :]
    :param priors: not used by this variant
    :param wrapper: object exposing get_bases(category_id)
    :param masks: masks from GT
    :param gt_box_t: per-prior gt boxes in point form
    :param inst_data: optional alternative coefficients for the diversity loss
    :param labels: a Python list; labels[idx][i] is read with .item() as a category id
    :param truths: GT boxes; only its first dimension is read here
    :param interpolation_mode: not used by this variant
    :return: dict with the single key 'M' (mask loss)
    """
    # The prototype size is hard-coded instead of being read from proto_data. The original
    # notes record proto_data as [1, 138, 138, 32] and mask_data as [1, 19248, 32] under
    # the default configuration, with gt_box_t of size [1, 19248, 4].
    mask_h, mask_w = 138, 138

    needs_gt_boxes = cfg.mask_proto_normalize_emulate_roi_pooling or cfg.mask_proto_crop

    if cfg.mask_proto_remove_empty_masks:
        # Make sure to store a copy of this because we edit it to get rid of all-zero masks
        pos = pos.clone()

    total_mask_loss = 0
    total_diversity_loss = 0  # Coefficient diversity loss (accumulated, not returned)

    # One iteration per image of the batch
    for idx in range(mask_data.size(0)):
        positives = pos[idx]
        matched_gt = idx_t[idx, positives]

        if needs_gt_boxes:
            # Note: this is in point-form
            pos_boxes = gt_box_t[idx, positives]

        if matched_gt.size(0) == 0:
            continue

        with torch.no_grad():
            # Replaces the interpolate-based downsampling of the gt masks used elsewhere:
            # the gt masks are cropped with the positive boxes instead.
            gt_crops = crop_gt_mask(masks[idx], pos_boxes, 550)

            if cfg.mask_proto_binarize_downsampled_gt:
                gt_crops = gt_crops.gt(0.5).float()

        coeffs = mask_data[idx, positives, :]

        if cfg.mask_proto_coeff_diversity_loss:
            diversity_input = coeffs if inst_data is None else inst_data[idx, positives, :]
            total_diversity_loss += self.coeff_diversity_loss(diversity_input, matched_gt)

        # If we have over the allowed number of masks, select a random sample
        num_before = coeffs.size(0)
        if num_before > cfg.masks_to_train:
            chosen = torch.randperm(coeffs.size(0))[:cfg.masks_to_train]

            coeffs = coeffs[chosen, :]
            matched_gt = matched_gt[chosen]

            if needs_gt_boxes:
                pos_boxes = pos_boxes[chosen, :]

        num_after = coeffs.size(0)
        target = gt_crops[:, :, matched_gt]

        # Load the preselected bases, one [64, 64, 50] slab per gt entry
        num_truths = truths.size(0)
        bases_np = np.zeros((num_truths, 64, 64, 50))
        for truth_id in range(num_truths):
            category = labels[idx][truth_id].item()
            # Category 0 keeps all-zero bases
            # (original note: best verify that 0 really is the background class)
            if category != 0:
                bases_np[truth_id] = wrapper.get_bases(category)

        # [idx] is needed due to the batch_size problem; the result is [64, 64, 50]
        bases = torch.from_numpy(bases_np)[idx].float().to(coeffs.device)

        # Bases take the place of the predicted prototypes
        # (the original was proto_masks @ proto_coef.t())
        predicted = cfg.mask_proto_mask_activation(bases @ coeffs.t())

        # Cropping the prediction (cfg.mask_proto_crop) is intentionally skipped here. The
        # original notes warn that the generic crop() must not be applied to the target
        # either ("this crop is not that crop").
        if cfg.mask_proto_mask_activation == activation_func.sigmoid:
            pixel_loss = F.binary_cross_entropy(torch.clamp(predicted, 0, 1), target,
                                                reduction='none')
        else:
            pixel_loss = F.smooth_l1_loss(predicted, target, reduction='none')

        if cfg.mask_proto_normalize_emulate_roi_pooling:
            scale = mask_h * mask_w if cfg.mask_proto_crop else 1
            box_csize = center_size(pos_boxes)
            box_w = box_csize[:, 2] * mask_w
            box_h = box_csize[:, 3] * mask_h
            pixel_loss = pixel_loss.sum(dim=(0, 1)) / box_w / box_h * scale

        # If the number of masks were limited scale the loss accordingly
        if num_before > num_after:
            pixel_loss *= num_before / num_after

        total_mask_loss += torch.sum(pixel_loss)

    losses = {'M': total_mask_loss * cfg.mask_alpha / mask_h / mask_w}

    return losses
def lincomb_mask_loss(self, pos, idx_t, loc_data, mask_data, priors, proto_data, masks, gt_box_t, score_data, inst_data, labels, interpolation_mode='bilinear', pred_seg=False):
    """
    Prototype (linear combination) mask loss that can also hand back the predicted masks.

    For each image in the batch the gt masks are resized to the prototype resolution, the
    predicted masks are assembled as prototypes @ coefficients for the positive priors, and
    a per-pixel loss between the two is accumulated.

    Args:
        - pos: Per-image selector of positive priors (indexed as pos[idx]; the original
               notes describe it as conf_t > 0). It is cloned before being edited when
               cfg.mask_proto_remove_empty_masks is set.
        - idx_t: For every prior, the index of the gt object it is matched to.
        - loc_data, priors: Only read when cfg.mask_proto_crop_with_pred_box is set, to
                            decode the predicted boxes used for cropping.
        - mask_data: Mask coefficients, indexed as mask_data[idx, positives, :].
        - proto_data: Prototypes; per the original notes [batch, mask_h, mask_w, mask_dim],
                      e.g. torch.Size([1, 138, 138, 32]).
        - masks: Per-image gt masks (masks[idx] is resized with F.interpolate).
        - gt_box_t: Per-prior gt boxes in point form.
        - score_data: Accepted for interface compatibility; not used by this loss.
        - inst_data: Optional alternative coefficients for the diversity loss.
        - labels: Per-image gt labels, indexed as labels[idx][gt_index].
        - interpolation_mode: Mode handed to F.interpolate for the gt masks.
        - pred_seg: If True, collect the activated (uncropped) predicted masks and their
                    labels per image.

    Returns:
        - With cfg.use_maskiou: (losses, maskiou_targets), where maskiou_targets is
          [maskiou_net_input, maskiou_t, label_t] or None when no mask survived
          cfg.discard_mask_area. This takes precedence over pred_seg.
        - Otherwise with pred_seg: (losses, pred_seg_list, label_list).
        - Otherwise: losses.
        losses holds 'M' (mask loss) and, if cfg.mask_proto_coeff_diversity_loss is set, 'D'.
    """
    mask_h = proto_data.size(1)
    mask_w = proto_data.size(2)

    process_gt_bboxes = cfg.mask_proto_normalize_emulate_roi_pooling or cfg.mask_proto_crop

    if cfg.mask_proto_remove_empty_masks:
        # Make sure to store a copy of this because we edit it to get rid of all-zero masks
        pos = pos.clone()

    loss_m = 0
    loss_d = 0  # Coefficient diversity loss

    maskiou_t_list = []
    maskiou_net_input_list = []
    label_t_list = []

    # Kept apart from label_t_list: mixing both would leave the mask IoU labels with more
    # rows than the mask IoU targets whenever pred_seg and cfg.use_maskiou are both on.
    pred_seg_list = []
    pred_seg_label_list = []

    for idx in range(mask_data.size(0)):
        with torch.no_grad():
            downsampled_masks = F.interpolate(masks[idx].unsqueeze(0), (mask_h, mask_w),
                                              mode=interpolation_mode,
                                              align_corners=False).squeeze(0)
            downsampled_masks = downsampled_masks.permute(1, 2, 0).contiguous()

            if cfg.mask_proto_binarize_downsampled_gt:
                downsampled_masks = downsampled_masks.gt(0.5).float()

            if cfg.mask_proto_remove_empty_masks:
                # Get rid of gt masks that are so small they get downsampled away
                very_small_masks = (downsampled_masks.sum(dim=(0, 1)) <= 0.0001)
                for i in range(very_small_masks.size(0)):
                    if very_small_masks[i]:
                        pos[idx, idx_t[idx] == i] = 0

            if cfg.mask_proto_reweight_mask_loss:
                # Ensure that the gt is binary
                if not cfg.mask_proto_binarize_downsampled_gt:
                    bin_gt = downsampled_masks.gt(0.5).float()
                else:
                    bin_gt = downsampled_masks

                gt_foreground_norm = bin_gt / (
                    torch.sum(bin_gt, dim=(0, 1), keepdim=True) + 0.0001)
                gt_background_norm = (1 - bin_gt) / (
                    torch.sum(1 - bin_gt, dim=(0, 1), keepdim=True) + 0.0001)

                mask_reweighting = (gt_foreground_norm * cfg.mask_proto_reweight_coeff
                                    + gt_background_norm)
                mask_reweighting *= mask_h * mask_w

        cur_pos = pos[idx]
        pos_idx_t = idx_t[idx, cur_pos]

        if process_gt_bboxes:
            # Note: this is in point-form
            if cfg.mask_proto_crop_with_pred_box:
                pos_gt_box_t = decode(loc_data[idx, :, :], priors.data,
                                      cfg.use_yolo_regressors)[cur_pos]
            else:
                pos_gt_box_t = gt_box_t[idx, cur_pos]

        if pos_idx_t.size(0) == 0:
            continue

        proto_masks = proto_data[idx]
        proto_coef = mask_data[idx, cur_pos, :]

        if cfg.mask_proto_coeff_diversity_loss:
            if inst_data is not None:
                div_coeffs = inst_data[idx, cur_pos, :]
            else:
                div_coeffs = proto_coef

            loss_d += self.coeff_diversity_loss(div_coeffs, pos_idx_t)

        # If we have over the allowed number of masks, select a random sample
        old_num_pos = proto_coef.size(0)
        if old_num_pos > cfg.masks_to_train:
            perm = torch.randperm(proto_coef.size(0))
            select = perm[:cfg.masks_to_train]

            proto_coef = proto_coef[select, :]
            pos_idx_t = pos_idx_t[select]

            if process_gt_bboxes:
                pos_gt_box_t = pos_gt_box_t[select, :]

        num_pos = proto_coef.size(0)
        mask_t = downsampled_masks[:, :, pos_idx_t]
        label_t = labels[idx][pos_idx_t]

        # Size: [mask_h, mask_w, num_pos]
        pred_masks = proto_masks @ proto_coef.t()
        pred_masks = cfg.mask_proto_mask_activation(pred_masks)

        if pred_seg:
            # Output the predicted segmentation as it is before cropping / thresholding
            pred_seg_list.append(pred_masks)
            pred_seg_label_list.append(label_t)

        if cfg.mask_proto_double_loss:
            if cfg.mask_proto_mask_activation == activation_func.sigmoid:
                pre_loss = F.binary_cross_entropy(torch.clamp(pred_masks, 0, 1), mask_t,
                                                  reduction='sum')
            else:
                pre_loss = F.smooth_l1_loss(pred_masks, mask_t, reduction='sum')

            loss_m += cfg.mask_proto_double_loss_alpha * pre_loss

        if cfg.mask_proto_crop:
            pred_masks = crop(pred_masks, pos_gt_box_t)

        if cfg.mask_proto_mask_activation == activation_func.sigmoid:
            pre_loss = F.binary_cross_entropy(torch.clamp(pred_masks, 0, 1), mask_t,
                                              reduction='none')
        else:
            pre_loss = F.smooth_l1_loss(pred_masks, mask_t, reduction='none')

        if cfg.mask_proto_normalize_mask_loss_by_sqrt_area:
            gt_area = torch.sum(mask_t, dim=(0, 1), keepdim=True)
            pre_loss = pre_loss / (torch.sqrt(gt_area) + 0.0001)

        if cfg.mask_proto_reweight_mask_loss:
            pre_loss = pre_loss * mask_reweighting[:, :, pos_idx_t]

        if cfg.mask_proto_normalize_emulate_roi_pooling:
            weight = mask_h * mask_w if cfg.mask_proto_crop else 1
            pos_gt_csize = center_size(pos_gt_box_t)
            gt_box_width = pos_gt_csize[:, 2] * mask_w
            gt_box_height = pos_gt_csize[:, 3] * mask_h
            pre_loss = pre_loss.sum(dim=(0, 1)) / gt_box_width / gt_box_height * weight

        # If the number of masks were limited scale the loss accordingly
        if old_num_pos > num_pos:
            pre_loss *= old_num_pos / num_pos

        loss_m += torch.sum(pre_loss)

        if cfg.use_maskiou:
            if cfg.discard_mask_area > 0:
                gt_mask_area = torch.sum(mask_t, dim=(0, 1))
                keep = gt_mask_area > cfg.discard_mask_area

                if torch.sum(keep) < 1:
                    continue

                # The boxes are not needed past this point, so only the tensors that feed
                # the mask IoU head are filtered.
                pred_masks = pred_masks[:, :, keep]
                mask_t = mask_t[:, :, keep]
                label_t = label_t[keep]

            maskiou_net_input = pred_masks.permute(2, 0, 1).contiguous().unsqueeze(1)
            pred_masks = pred_masks.gt(0.5).float()
            maskiou_t = self._mask_iou(pred_masks, mask_t)

            maskiou_net_input_list.append(maskiou_net_input)
            maskiou_t_list.append(maskiou_t)
            label_t_list.append(label_t)

    losses = {'M': loss_m * cfg.mask_alpha / mask_h / mask_w}

    if cfg.mask_proto_coeff_diversity_loss:
        losses['D'] = loss_d

    if cfg.use_maskiou:
        # discard_mask_area discarded every mask in the batch, so nothing to do here
        if len(maskiou_t_list) == 0:
            return losses, None

        maskiou_t = torch.cat(maskiou_t_list)
        label_t = torch.cat(label_t_list)
        maskiou_net_input = torch.cat(maskiou_net_input_list)

        num_samples = maskiou_t.size(0)
        if cfg.maskious_to_train > 0 and num_samples > cfg.maskious_to_train:
            perm = torch.randperm(num_samples)
            # Subsample to the mask IoU budget (the guard above), not the mask budget
            select = perm[:cfg.maskious_to_train]
            maskiou_t = maskiou_t[select]
            label_t = label_t[select]
            maskiou_net_input = maskiou_net_input[select]

        return losses, [maskiou_net_input, maskiou_t, label_t]

    if pred_seg:
        return losses, pred_seg_list, pred_seg_label_list

    return losses