def postprocess(det_output, w, h, batch_idx=0, interpolation_mode='bilinear',
                visualize_lincomb=False, crop_masks=True, score_threshold=0):
    """
    Postprocesses the output of Yolact on testing mode into a format that makes sense,
    accounting for all the possible configuration settings.

    The entry of det_output that is read is never modified: filtering builds a new
    dict and the box tensor is copied before it is rescaled, so this function can
    be called more than once on the same det_output.

    Args:
        - det_output: The list of dicts that Detect outputs.
        - w: The real width of the image.
        - h: The real height of the image.
        - batch_idx: If you have multiple images for this batch, the image's index in the batch.
        - interpolation_mode: Can be 'nearest' | 'area' | 'bilinear'
                              (see torch.nn.functional.interpolate)
        - visualize_lincomb: If True (and lincomb masks are being evaluated), pass the
                             prototypes and coefficients to display_lincomb.
        - crop_masks: If True, lincomb masks are cropped with their boxes before upsampling.
        - score_threshold: If > 0, detections whose score is not strictly greater than
                           this value are dropped.

    Returns 4 torch Tensors (in the following order):
        - classes [num_det]: The class idx for each detection.
        - scores  [num_det]: The confidence score for each detection.
        - boxes   [num_det, 4]: The bounding box for each detection in absolute point form.
        - masks   [num_det, h, w]: Full image masks for each detection.

    When there is nothing to return (the batch entry is None, or score_threshold
    removed every detection) a list of 4 distinct empty tensors is returned instead.
    """
    dets = det_output[batch_idx]

    if dets is None:
        # Four separate tensors, so callers can modify one without touching the others
        return [torch.Tensor() for _ in range(4)]

    if score_threshold > 0:
        keep = dets['score'] > score_threshold

        # Filter into a new dict rather than overwriting the caller's entries
        dets = {k: (v if k == 'proto' else v[keep]) for k, v in dets.items()}

        if dets['score'].size(0) == 0:
            return [torch.Tensor() for _ in range(4)]

    # im_w and im_h when it concerns bboxes. This is a workaround hack for preserve_aspect_ratio
    b_w, b_h = (w, h)

    # Undo the padding introduced with preserve_aspect_ratio
    if cfg.preserve_aspect_ratio:
        r_w, r_h = Resize.faster_rcnn_scale(w, h, cfg.min_size, cfg.max_size)

        # Get rid of any detections whose centers are outside the image
        boxes = dets['box']
        boxes = center_size(boxes)
        s_w, s_h = (r_w / cfg.max_size, r_h / cfg.max_size)

        not_outside = ((boxes[:, 0] > s_w) + (boxes[:, 1] > s_h)) < 1  # not (a or b)
        dets = {k: (v if k == 'proto' else v[not_outside]) for k, v in dets.items()}

        # A hack to scale the bboxes to the right size
        b_w, b_h = (cfg.max_size / r_w * w, cfg.max_size / r_h * h)

    # Actually extract everything from dets now
    classes = dets['class']
    boxes = dets['box']
    scores = dets['score']
    masks = dets['mask']

    if cfg.mask_type == mask_type.lincomb and cfg.eval_mask_branch:
        # At this points masks is only the coefficients
        proto_data = dets['proto']

        # Test flag, do not upvote
        if cfg.mask_proto_debug:
            np.save('scripts/proto.npy', proto_data.cpu().numpy())

        if visualize_lincomb:
            display_lincomb(proto_data, masks)

        masks = torch.matmul(proto_data, masks.t())
        masks = cfg.mask_proto_mask_activation(masks)

        # Crop masks before upsampling (boxes are still in their un-scaled form here)
        if crop_masks:
            masks = crop(masks, boxes)

        # Permute into the correct output shape [num_dets, proto_h, proto_w]
        masks = masks.permute(2, 0, 1).contiguous()

        # Scale masks up to the full image
        if cfg.preserve_aspect_ratio:
            # Undo padding
            masks = masks[:, :int(r_h / cfg.max_size * proto_data.size(1)),
                          :int(r_w / cfg.max_size * proto_data.size(2))]

        masks = F.interpolate(masks.unsqueeze(0), (h, w), mode=interpolation_mode,
                              align_corners=False).squeeze(0)

        # Binarize the masks
        masks.gt_(0.5)

    # Work on a copy: the column assignments below write in place, and without
    # filtering `boxes` is still the tensor stored in det_output.
    boxes = boxes.clone()
    boxes[:, 0], boxes[:, 2] = sanitize_coordinates(boxes[:, 0], boxes[:, 2], b_w, cast=False)
    boxes[:, 1], boxes[:, 3] = sanitize_coordinates(boxes[:, 1], boxes[:, 3], b_h, cast=False)
    boxes = boxes.long()

    if cfg.mask_type == mask_type.direct and cfg.eval_mask_branch:
        # Upscale masks: paste each fixed-size mask into its box on an empty canvas
        full_masks = torch.zeros(masks.size(0), h, w)

        for jdx in range(masks.size(0)):
            x1, y1, x2, y2 = boxes[jdx, :]

            mask_w = x2 - x1
            mask_h = y2 - y1

            # Just in case
            if mask_w * mask_h <= 0 or mask_w < 0:
                continue

            mask = masks[jdx, :].view(1, 1, cfg.mask_size, cfg.mask_size)
            mask = F.interpolate(mask, (mask_h, mask_w), mode=interpolation_mode,
                                 align_corners=False)
            mask = mask.gt(0.5).float()
            full_masks[jdx, y1:y2, x1:x2] = mask

        masks = full_masks

    return classes, scores, boxes, masks
def lincomb_mask_loss(self, pos, idx_t, loc_data, mask_data, priors, proto_data, masks,
                      gt_box_t, score_data, inst_data, labels, interpolation_mode='bilinear'):
    """
    Computes the mask loss for linear-combination (prototype * coefficient) masks.

    For every image in the batch the ground truth masks are resized to the
    prototype resolution, the predicted masks are assembled as
    `proto_masks @ proto_coef.t()` followed by cfg.mask_proto_mask_activation,
    and the two are compared with binary cross entropy (sigmoid activation) or
    smooth L1 (any other activation). The cfg.mask_proto_* flags switch on
    optional cropping, reweighting and normalization of that loss.

    Args:
        - pos: Positive-prior mask, indexed as pos[idx]. It is cloned before being
               edited, so the caller's tensor is left untouched.
        - idx_t: Per prior, the index of the ground truth object it is matched to
                 (used to index the ground truth masks and labels).
        - loc_data: Box regressions; only read when cfg.mask_proto_crop_with_pred_box is set.
        - mask_data: Mask coefficients, indexed as mask_data[idx, cur_pos, :].
        - priors: Prior boxes; only read together with loc_data.
        - proto_data: Prototype masks; dims 1 and 2 are used as mask_h and mask_w.
        - masks: Per-image ground truth masks (masks[idx] is resized to [mask_h, mask_w]).
        - gt_box_t: Matched ground truth boxes in point form.
        - score_data: Mask scores; only read when cfg.use_mask_scoring is set.
        - inst_data: Optional coefficients for the diversity loss. If None, the mask
                     coefficients are used instead.
        - labels: Per-image ground truth labels.
        - interpolation_mode: Mode used to resize the ground truth masks
                              (see torch.nn.functional.interpolate)

    Returns:
        - If cfg.use_maskiou is off: a dict with 'M' (mask loss) and, when
          cfg.mask_proto_coeff_diversity_loss is on, 'D' (coefficient diversity loss).
        - If cfg.use_maskiou is on: a tuple (losses, [maskiou_net_input, maskiou_t, label_t]),
          or (losses, None) when no mask in the batch survived discard_mask_area.

    Note: subsampling uses torch.randperm, so results depend on the torch RNG state.
    """
    mask_h = proto_data.size(1)
    mask_w = proto_data.size(2)

    process_gt_bboxes = cfg.mask_proto_normalize_emulate_roi_pooling or cfg.mask_proto_crop

    if cfg.mask_proto_remove_empty_masks:
        # Make sure to store a copy of this because we edit it to get rid of all-zero masks
        pos = pos.clone()

    loss_m = 0
    loss_d = 0  # Coefficient diversity loss

    maskiou_t_list = []
    maskiou_net_input_list = []
    label_t_list = []

    for idx in range(mask_data.size(0)):
        # Ground truth preparation needs no gradients
        with torch.no_grad():
            downsampled_masks = F.interpolate(
                masks[idx].unsqueeze(0), (mask_h, mask_w),
                mode=interpolation_mode, align_corners=False).squeeze(0)
            downsampled_masks = downsampled_masks.permute(1, 2, 0).contiguous()

            if cfg.mask_proto_binarize_downsampled_gt:
                downsampled_masks = downsampled_masks.gt(0.5).float()

            if cfg.mask_proto_remove_empty_masks:
                # Get rid of gt masks that are so small they get downsampled away
                very_small_masks = (downsampled_masks.sum(dim=(0, 1)) <= 0.0001)
                for i in range(very_small_masks.size(0)):
                    if very_small_masks[i]:
                        pos[idx, idx_t[idx] == i] = 0

            if cfg.mask_proto_reweight_mask_loss:
                # Ensure that the gt is binary
                if not cfg.mask_proto_binarize_downsampled_gt:
                    bin_gt = downsampled_masks.gt(0.5).float()
                else:
                    bin_gt = downsampled_masks

                gt_foreground_norm = bin_gt / (
                    torch.sum(bin_gt, dim=(0, 1), keepdim=True) + 0.0001)
                gt_background_norm = (1 - bin_gt) / (
                    torch.sum(1 - bin_gt, dim=(0, 1), keepdim=True) + 0.0001)

                mask_reweighting = (gt_foreground_norm * cfg.mask_proto_reweight_coeff
                                    + gt_background_norm)
                mask_reweighting *= mask_h * mask_w

        cur_pos = pos[idx]
        pos_idx_t = idx_t[idx, cur_pos]

        if process_gt_bboxes:
            # Note: this is in point-form
            if cfg.mask_proto_crop_with_pred_box:
                pos_gt_box_t = decode(loc_data[idx, :, :], priors.data,
                                      cfg.use_yolo_regressors)[cur_pos]
            else:
                pos_gt_box_t = gt_box_t[idx, cur_pos]

        # No positive priors for this image, so there is nothing to train on
        if pos_idx_t.size(0) == 0:
            continue

        proto_masks = proto_data[idx]
        proto_coef = mask_data[idx, cur_pos, :]
        if cfg.use_mask_scoring:
            mask_scores = score_data[idx, cur_pos, :]

        if cfg.mask_proto_coeff_diversity_loss:
            if inst_data is not None:
                div_coeffs = inst_data[idx, cur_pos, :]
            else:
                div_coeffs = proto_coef

            loss_d += self.coeff_diversity_loss(div_coeffs, pos_idx_t)

        # If we have over the allowed number of masks, select a random sample
        old_num_pos = proto_coef.size(0)
        if old_num_pos > cfg.masks_to_train:
            perm = torch.randperm(proto_coef.size(0))
            select = perm[:cfg.masks_to_train]

            proto_coef = proto_coef[select, :]
            pos_idx_t = pos_idx_t[select]

            if process_gt_bboxes:
                pos_gt_box_t = pos_gt_box_t[select, :]
            if cfg.use_mask_scoring:
                mask_scores = mask_scores[select, :]

        num_pos = proto_coef.size(0)
        mask_t = downsampled_masks[:, :, pos_idx_t]
        label_t = labels[idx][pos_idx_t]

        # Size: [mask_h, mask_w, num_pos]
        pred_masks = proto_masks @ proto_coef.t()
        pred_masks = cfg.mask_proto_mask_activation(pred_masks)

        if cfg.mask_proto_double_loss:
            # Extra loss term on the masks before they are cropped
            if cfg.mask_proto_mask_activation == activation_func.sigmoid:
                pre_loss = F.binary_cross_entropy(
                    torch.clamp(pred_masks, 0, 1), mask_t, reduction='sum')
            else:
                pre_loss = F.smooth_l1_loss(pred_masks, mask_t, reduction='sum')

            loss_m += cfg.mask_proto_double_loss_alpha * pre_loss

        if cfg.mask_proto_crop:
            pred_masks = crop(pred_masks, pos_gt_box_t)

        if cfg.mask_proto_mask_activation == activation_func.sigmoid:
            pre_loss = F.binary_cross_entropy(
                torch.clamp(pred_masks, 0, 1), mask_t, reduction='none')
        else:
            pre_loss = F.smooth_l1_loss(pred_masks, mask_t, reduction='none')

        if cfg.mask_proto_normalize_mask_loss_by_sqrt_area:
            gt_area = torch.sum(mask_t, dim=(0, 1), keepdim=True)
            pre_loss = pre_loss / (torch.sqrt(gt_area) + 0.0001)

        if cfg.mask_proto_reweight_mask_loss:
            pre_loss = pre_loss * mask_reweighting[:, :, pos_idx_t]

        if cfg.mask_proto_normalize_emulate_roi_pooling:
            weight = mask_h * mask_w if cfg.mask_proto_crop else 1
            pos_gt_csize = center_size(pos_gt_box_t)
            gt_box_width = pos_gt_csize[:, 2] * mask_w
            gt_box_height = pos_gt_csize[:, 3] * mask_h
            pre_loss = pre_loss.sum(dim=(0, 1)) / gt_box_width / gt_box_height * weight

        # If the number of masks were limited scale the loss accordingly
        if old_num_pos > num_pos:
            pre_loss *= old_num_pos / num_pos

        loss_m += torch.sum(pre_loss)

        if cfg.use_maskiou:
            if cfg.discard_mask_area > 0:
                gt_mask_area = torch.sum(mask_t, dim=(0, 1))
                select = gt_mask_area > cfg.discard_mask_area

                if torch.sum(select) < 1:
                    continue

                # Only the tensors consumed below are filtered. The gt boxes are not
                # needed past this point (and do not exist unless process_gt_bboxes).
                pred_masks = pred_masks[:, :, select]
                mask_t = mask_t[:, :, select]
                label_t = label_t[select]

            maskiou_net_input = pred_masks.permute(2, 0, 1).contiguous().unsqueeze(1)
            pred_masks = pred_masks.gt(0.5).float()
            maskiou_t = self._mask_iou(pred_masks, mask_t)

            maskiou_net_input_list.append(maskiou_net_input)
            maskiou_t_list.append(maskiou_t)
            label_t_list.append(label_t)

    losses = {'M': loss_m * cfg.mask_alpha / mask_h / mask_w}

    if cfg.mask_proto_coeff_diversity_loss:
        losses['D'] = loss_d

    if cfg.use_maskiou:
        # discard_mask_area discarded every mask in the batch, so nothing to do here
        if len(maskiou_t_list) == 0:
            return losses, None

        maskiou_t = torch.cat(maskiou_t_list)
        label_t = torch.cat(label_t_list)
        maskiou_net_input = torch.cat(maskiou_net_input_list)

        num_samples = maskiou_t.size(0)
        if cfg.maskious_to_train > 0 and num_samples > cfg.maskious_to_train:
            # Keep exactly maskious_to_train samples, the same limit tested above
            perm = torch.randperm(num_samples)
            select = perm[:cfg.maskious_to_train]
            maskiou_t = maskiou_t[select]
            label_t = label_t[select]
            maskiou_net_input = maskiou_net_input[select]

        return losses, [maskiou_net_input, maskiou_t, label_t]

    return losses
def postprocess(det_output, w, h, batch_idx=0, interpolation_mode='bilinear',
                visualize_lincomb=False, crop_masks=True, score_threshold=0):
    """
    Postprocesses the output of Yolact on testing mode into a format that makes sense,
    accounting for all the possible configuration settings.

    The entry of det_output that is read is never modified: filtering builds a new
    dict and the box tensor is copied before it is rescaled, so this function can
    be called more than once on the same det_output.

    Args:
        - det_output: The list of dicts that Detect outputs. Each entry is read through
                      its 'net' and 'detection' keys.
        - w: The real width of the image.
        - h: The real height of the image.
        - batch_idx: If you have multiple images for this batch, the image's index in the batch.
        - interpolation_mode: Can be 'nearest' | 'area' | 'bilinear'
                              (see torch.nn.functional.interpolate)
        - visualize_lincomb: If True (and lincomb masks are being evaluated), pass the
                             prototypes and coefficients to display_lincomb.
        - crop_masks: If True, lincomb masks are cropped with their boxes before upsampling.
        - score_threshold: If > 0, detections whose score is not strictly greater than
                           this value are dropped.

    Returns 4 torch Tensors (in the following order):
        - classes [num_det]: The class idx for each detection.
        - scores  [num_det]: The confidence score for each detection. With cfg.use_maskiou
                             and cfg.rescore_mask on but cfg.rescore_bbox off, this is
                             instead the list [scores, scores * maskiou_p].
        - boxes   [num_det, 4]: The bounding box for each detection in absolute point form.
        - masks   [num_det, h, w]: Full image masks for each detection.

    When there is nothing to return (the 'detection' entry is None, or score_threshold
    removed every detection) a list of 4 distinct empty tensors is returned instead.
    """
    entry = det_output[batch_idx]
    net = entry['net']
    dets = entry['detection']

    if dets is None:
        # Four separate tensors, so callers can modify one without touching the others
        return [torch.Tensor() for _ in range(4)]

    if score_threshold > 0:
        keep = dets['score'] > score_threshold

        # Filter into a new dict rather than overwriting the caller's entries
        dets = {k: (v if k == 'proto' else v[keep]) for k, v in dets.items()}

        if dets['score'].size(0) == 0:
            return [torch.Tensor() for _ in range(4)]

    # Actually extract everything from dets now
    classes = dets['class']
    boxes = dets['box']
    scores = dets['score']
    masks = dets['mask']

    if cfg.mask_type == mask_type.lincomb and cfg.eval_mask_branch:
        # At this points masks is only the coefficients
        proto_data = dets['proto']

        # Test flag, do not upvote
        if cfg.mask_proto_debug:
            np.save('scripts/proto.npy', proto_data.cpu().numpy())

        if visualize_lincomb:
            display_lincomb(proto_data, masks)

        masks = proto_data @ masks.t()
        masks = cfg.mask_proto_mask_activation(masks)

        # Crop masks before upsampling (boxes are still in their un-scaled form here)
        if crop_masks:
            masks = crop(masks, boxes)

        # Permute into the correct output shape [num_dets, proto_h, proto_w]
        masks = masks.permute(2, 0, 1).contiguous()

        if cfg.use_maskiou:
            with timer.env('maskiou_net'):
                with torch.no_grad():
                    maskiou_p = net.maskiou_net(masks.unsqueeze(1))
                    # Pick, for every detection, the prediction of its own class
                    maskiou_p = torch.gather(
                        maskiou_p, dim=1, index=classes.unsqueeze(1)).squeeze(1)
                    if cfg.rescore_mask:
                        if cfg.rescore_bbox:
                            scores = scores * maskiou_p
                        else:
                            scores = [scores, scores * maskiou_p]

        # Scale masks up to the full image
        masks = F.interpolate(masks.unsqueeze(0), (h, w), mode=interpolation_mode,
                              align_corners=False).squeeze(0)

        # Binarize the masks
        masks.gt_(0.5)

    # Work on a copy: the column assignments below write in place, and without
    # filtering `boxes` is still the tensor stored in det_output.
    boxes = boxes.clone()
    boxes[:, 0], boxes[:, 2] = sanitize_coordinates(boxes[:, 0], boxes[:, 2], w, cast=False)
    boxes[:, 1], boxes[:, 3] = sanitize_coordinates(boxes[:, 1], boxes[:, 3], h, cast=False)
    boxes = boxes.long()

    if cfg.mask_type == mask_type.direct and cfg.eval_mask_branch:
        # Upscale masks: paste each fixed-size mask into its box on an empty canvas
        full_masks = torch.zeros(masks.size(0), h, w)

        for jdx in range(masks.size(0)):
            x1, y1, x2, y2 = boxes[jdx, :]

            mask_w = x2 - x1
            mask_h = y2 - y1

            # Just in case
            if mask_w * mask_h <= 0 or mask_w < 0:
                continue

            mask = masks[jdx, :].view(1, 1, cfg.mask_size, cfg.mask_size)
            mask = F.interpolate(mask, (mask_h, mask_w), mode=interpolation_mode,
                                 align_corners=False)
            mask = mask.gt(0.5).float()
            full_masks[jdx, y1:y2, x1:x2] = mask

        masks = full_masks

    return classes, scores, boxes, masks