def conf_objectness_loss(self, conf_data, conf_t, batch_size, loc_p, loc_t, priors):
    """YOLO-style objectness loss plus a positives-only class loss.

    Rather than a softmax over all channels, channel 0 of ``conf_data`` is an
    objectness logit trained with binary cross-entropy:

    * background priors (``conf_t == 0``) get target 0;
    * foreground priors (``conf_t > 0``) get a soft target equal to the IoU
      between their decoded predicted box and decoded target box.

    The remaining channels are a softmax over the foreground classes and are
    trained with cross-entropy on the foreground priors only. Priors with
    ``conf_t < 0`` fall in neither mask and are ignored.

    Args:
        conf_data: confidence logits; flattened here to ``[-1, num_channels]``.
        conf_t: per-prior class targets; flattened here to ``[-1]``. 0 is
            background, values >= 1 are 1-based foreground class ids.
        batch_size: number of images; used to tile ``priors`` across the batch.
        loc_p: predicted box encodings. They are decoded against the positive
            priors only, so presumably already restricted to positives — TODO
            confirm against caller.
        loc_t: target box encodings, same layout as ``loc_p``.
        priors: prior boxes, 2-D with 4 columns (``reshape(-1, 4)`` below).

    Returns:
        ``cfg.conf_alpha * (class_loss + objectness_pos_loss + objectness_neg_loss)``,
        where every term is a sum (no averaging).
    """
    flat_targets = conf_t.view(-1)
    num_channels = conf_data.size(-1)
    flat_logits = conf_data.view(-1, num_channels)

    is_pos = flat_targets > 0
    is_neg = flat_targets == 0

    objectness = flat_logits[:, 0]
    pos_logits = objectness[is_pos]
    neg_logits = objectness[is_neg]

    # BCE against a 0 target, written via logsigmoid for numerical stability:
    # -log(1 - sigmoid(x)) == -logsigmoid(-x)
    neg_term = -F.logsigmoid(-neg_logits).sum()

    # IoU soft labels are targets only; no gradient flows through them.
    with torch.no_grad():
        tiled_priors = priors.unsqueeze(0).expand(batch_size, -1, -1).reshape(-1, 4)
        matched_priors = tiled_priors[is_pos, :]
        pred_boxes = decode(loc_p, matched_priors, cfg.use_yolo_regressors)
        target_boxes = decode(loc_t, matched_priors, cfg.use_yolo_regressors)
        iou_soft_labels = elemwise_box_iou(pred_boxes, target_boxes)

    # Full BCE with a soft (IoU) target for the positive priors.
    pos_term = -iou_soft_labels * F.logsigmoid(pos_logits) - (
        1 - iou_soft_labels) * F.logsigmoid(-pos_logits)
    pos_term = pos_term.sum()

    # Class channels exclude the objectness channel, so targets shift down by 1
    # to become 0-based indices into those channels.
    fg_logits = (flat_logits[:, 1:])[is_pos]
    fg_targets = flat_targets[is_pos] - 1
    class_term = F.cross_entropy(fg_logits, fg_targets, reduction='sum')

    return cfg.conf_alpha * (class_term + pos_term + neg_term)
def direct_mask_loss(self, pos_idx, idx_t, loc_data, mask_data, priors, masks):
    """Crop the gt masks with the predicted bboxes, downscale them, and return the BCE loss.

    For every image, each positive prior's predicted box is decoded and used to
    crop the gt mask that prior was matched to. The crop is average-pooled to
    ``self.cfg.mask_size`` and thresholded at 0.5 to form the target. The
    prior's mask prediction (clamped to [0, 1]) is compared against that target
    with summed binary cross-entropy, scaled by ``self.cfg.mask_alpha``.

    Images with no positive priors contribute nothing and are skipped.
    Previously they reached ``torch.cat`` with an empty list and raised.

    Args:
        pos_idx: boolean positive mask, sliced per image as ``pos_idx[idx, :, :]``.
            Column 1 of that slice is used as the per-prior positive flag
            (presumably the [batch, num_priors] positive mask expanded to the
            last dim of ``loc_data`` — TODO confirm against caller).
        idx_t: per-prior index of the matched gt object, indexed ``[idx, prior]``.
        loc_data: box regressions; ``loc_data[idx, :, :]`` is decoded against
            ``priors``.
        mask_data: per-prior mask predictions, indexed ``[idx, prior, :]``. The
            last dim must match the flattened pooled target mask.
        priors: prior boxes passed to ``decode``.
        masks: per-image gt masks; ``masks[idx]`` is indexed
            ``[obj, y, x]`` (its sliced size unpacks to num_pos, height, width).

    Returns:
        The summed, ``mask_alpha``-scaled BCE over the batch (the plain int 0
        if no image has any positive prior).
    """
    loss_m = 0
    for idx in range(mask_data.size(0)):
        with torch.no_grad():
            cur_pos_idx = pos_idx[idx, :, :]
            cur_pos_idx_squeezed = cur_pos_idx[:, 1]

            # No positives -> no masks to crop; skipping here also keeps the
            # torch.cat below from being handed an empty list (which raises).
            if not cur_pos_idx_squeezed.any():
                continue

            # Shape: [num_priors, 4], decoded predicted bboxes
            pos_bboxes = decode(loc_data[idx, :, :], priors.data, self.cfg.use_yolo_regressors)
            pos_bboxes = pos_bboxes[cur_pos_idx].view(-1, 4).clamp(0, 1)
            pos_lookup = idx_t[idx, cur_pos_idx_squeezed]

            cur_masks = masks[idx]
            pos_masks = cur_masks[pos_lookup, :, :]

            num_pos, img_height, img_width = pos_masks.size()

            # sanitize_coordinates is given the image size, presumably to turn
            # the relative [0, 1] coords into clamped pixel indices and avoid
            # the bad behavior caused by out-of-bounds coordinates.
            x1, x2 = sanitize_coordinates(pos_bboxes[:, 0], pos_bboxes[:, 2], img_width)
            y1, y2 = sanitize_coordinates(pos_bboxes[:, 1], pos_bboxes[:, 3], img_height)

            # Crop each gt mask with its predicted bbox and rescale to the
            # predicted mask size. Every crop has a different size, so this is
            # done per mask rather than vectorized.
            scaled_masks = []
            for jdx in range(num_pos):
                tmp_mask = pos_masks[jdx, y1[jdx]:y2[jdx], x1[jdx]:x2[jdx]]

                # Restore any dimensions we've left out because our bbox was 1px wide
                while tmp_mask.dim() < 2:
                    tmp_mask = tmp_mask.unsqueeze(0)

                new_mask = F.adaptive_avg_pool2d(tmp_mask.unsqueeze(0), self.cfg.mask_size)
                scaled_masks.append(new_mask.view(1, -1))

            # Threshold the downsampled masks back to a binary target.
            mask_t = torch.cat(scaled_masks, 0).gt(0.5).float()

        pos_mask_data = mask_data[idx, cur_pos_idx_squeezed, :]
        loss_m += F.binary_cross_entropy(
            torch.clamp(pos_mask_data, 0, 1), mask_t, reduction="sum"
        ) * self.cfg.mask_alpha

    return loss_m
def __call__(self, predictions, net):
    """Run detection post-processing on a batch of raw network outputs.

    Args:
        predictions: (dict) Raw head outputs. Keys read here:
            'loc':    (tensor) Loc preds, Shape: [batch, num_priors, 4]
            'conf':   (tensor) Conf preds; viewed as
                      [batch, num_priors, self.num_classes]
            'mask':   (tensor) Mask preds, Shape: [batch, num_priors, mask_dim]
            'priors': (tensor) Prior boxes, Shape: [num_priors, 4]
            'proto':  (tensor, optional) If using mask_type.lincomb, the
                      prototype masks, Shape: [batch, mask_h, mask_w, mask_dim]
            'inst':   (tensor, optional) Forwarded unchanged to ``self.detect``.
        net: The network; passed through unchanged in every output entry.

    Returns:
        A list with one dict per image, ``{'detection': result, 'net': net}``,
        where ``result`` is what ``self.detect`` returned for that image (it may
        be None). When ``result`` is not None and 'proto' was provided,
        ``result['proto']`` is set to that image's prototype tensor.
        (Per the earlier docs, ``self.detect`` output is sorted only if
        cross_class_nms is False — not verified here.)
    """
    loc_data = predictions['loc']
    conf_data = predictions['conf']
    mask_data = predictions['mask']
    prior_data = predictions['priors']
    # Optional heads: missing keys become None.
    proto_data = predictions.get('proto')
    inst_data = predictions.get('inst')

    out = []

    with timer.env('Detect'):
        batch_size = loc_data.size(0)
        num_priors = prior_data.size(0)

        # [batch, num_priors, num_classes] -> [batch, num_classes, num_priors]
        conf_preds = conf_data.view(batch_size, num_priors, self.num_classes).transpose(2, 1).contiguous()

        for batch_idx in range(batch_size):
            decoded_boxes = decode(loc_data[batch_idx], prior_data)
            result = self.detect(batch_idx, conf_preds, decoded_boxes, mask_data, inst_data)

            if result is not None and proto_data is not None:
                result['proto'] = proto_data[batch_idx]

            out.append({'detection': result, 'net': net})

    return out
def __call__(self, predictions, net):
    """Run detection post-processing on a batch of raw network outputs.

    Args:
        predictions: (dict) Raw head outputs. Keys read here:
            "loc":    (tensor) Loc preds, Shape: [batch, num_priors, 4]
            "conf":   (tensor) Conf preds; viewed as
                      [batch, num_priors, self.num_classes]
            "mask":   (tensor) Mask preds, Shape: [batch, num_priors, mask_dim]
            "priors": (tensor) Prior boxes, Shape: [num_priors, 4]
            "proto":  (tensor, optional) If using MaskType.LINCOMB, the
                      prototype masks, Shape: [batch, mask_h, mask_w, mask_dim]
            "inst":   (tensor, optional) Forwarded unchanged to ``self.detect``.
        net: The network; passed through unchanged in every output entry.

    Returns:
        A list with one dict per image, ``{"detection": result, "net": net}``,
        where ``result`` is what ``self.detect`` returned for that image (it may
        be None). When ``result`` is not None and "proto" was provided,
        ``result["proto"]`` is set to that image's prototype tensor.
        (Per the earlier docs, ``self.detect`` output is sorted only if
        cross_class_nms is False — not verified here.)
    """
    loc_data = predictions["loc"]
    conf_data = predictions["conf"]
    mask_data = predictions["mask"]
    prior_data = predictions["priors"]
    # Optional heads: missing keys become None.
    proto_data = predictions.get("proto")
    inst_data = predictions.get("inst")

    out = []

    with timer.env("Detect"):
        batch_size = loc_data.size(0)
        num_priors = prior_data.size(0)

        # [batch, num_priors, num_classes] -> [batch, num_classes, num_priors]
        conf_preds = (
            conf_data.view(batch_size, num_priors, self.num_classes)
            .transpose(2, 1)
            .contiguous()
        )

        for batch_idx in range(batch_size):
            decoded_boxes = decode(loc_data[batch_idx], prior_data)
            result = self.detect(batch_idx, conf_preds, decoded_boxes, mask_data, inst_data)

            if result is not None and proto_data is not None:
                result["proto"] = proto_data[batch_idx]

            out.append({"detection": result, "net": net})

    return out
def lincomb_mask_loss(self, pos, idx_t, loc_data, mask_data, priors, proto_data, masks,
                      gt_box_t, score_data, inst_data, labels, interpolation_mode='bilinear'):
    """Mask loss for masks built as a linear combination of prototypes.

    For each image, the gt masks are downsampled to the prototype resolution.
    Predicted masks are formed as ``proto_masks @ coefficients.T`` for the
    positive priors and passed through ``cfg.mask_proto_mask_activation``. They
    are then compared against the gt mask each prior was matched to. Several
    ``cfg`` flags modify the loss (cropping, reweighting, area normalization,
    ROI-pooling emulation, coefficient diversity, mask-IoU targets).

    Args:
        pos: per-prior positive flags, indexed ``pos[idx]``. It is cloned and
            edited when ``cfg.mask_proto_remove_empty_masks`` is set.
            Presumably bool [batch, num_priors] — TODO confirm against caller.
        idx_t: per-prior index of the matched gt object, indexed ``[idx, prior]``.
        loc_data: box regressions; decoded only when
            ``cfg.mask_proto_crop_with_pred_box`` is set.
        mask_data: per-prior prototype coefficients, indexed ``[idx, prior, :]``.
        priors: prior boxes passed to ``decode``.
        proto_data: prototypes; ``size(1)``/``size(2)`` give mask_h/mask_w, and
            ``proto_data[idx] @ coef.t()`` yields [mask_h, mask_w, num_pos].
        masks: per-image gt masks; ``masks[idx]`` is interpolated to
            (mask_h, mask_w) over its last two dims.
        gt_box_t: per-prior matched gt boxes in point form, indexed ``[idx, prior]``.
        score_data: mask-scoring output; only sliced when ``cfg.use_mask_scoring``
            (the slice is currently not used further in this method).
        inst_data: optional alternative coefficients for the diversity loss.
        labels: per-image gt labels, indexed by matched gt index.
        interpolation_mode: ``mode`` passed to ``F.interpolate``.

    Returns:
        ``losses`` = ``{'M': mask_loss}`` plus ``'D'`` when
        ``cfg.mask_proto_coeff_diversity_loss``. If ``cfg.use_maskiou`` is set,
        returns ``(losses, None)`` when no mask survived, else
        ``(losses, [maskiou_net_input, maskiou_t, label_t])``.
    """
    mask_h = proto_data.size(1)
    mask_w = proto_data.size(2)

    # Both of these features need a gt (or predicted) box per positive.
    process_gt_bboxes = cfg.mask_proto_normalize_emulate_roi_pooling or cfg.mask_proto_crop

    if cfg.mask_proto_remove_empty_masks:
        # Make sure to store a copy of this because we edit it to get rid of all-zero masks
        pos = pos.clone()

    loss_m = 0
    loss_d = 0  # Coefficient diversity loss

    maskiou_t_list = []
    maskiou_net_input_list = []
    label_t_list = []

    for idx in range(mask_data.size(0)):
        with torch.no_grad():
            downsampled_masks = F.interpolate(
                masks[idx].unsqueeze(0), (mask_h, mask_w),
                mode=interpolation_mode, align_corners=False).squeeze(0)
            # -> [mask_h, mask_w, num_objs] so the last dim selects a gt object
            downsampled_masks = downsampled_masks.permute(1, 2, 0).contiguous()

            if cfg.mask_proto_binarize_downsampled_gt:
                downsampled_masks = downsampled_masks.gt(0.5).float()

            if cfg.mask_proto_remove_empty_masks:
                # Get rid of gt masks that are so small they get downsampled away
                very_small_masks = (downsampled_masks.sum(dim=(0, 1)) <= 0.0001)
                for i in range(very_small_masks.size(0)):
                    if very_small_masks[i]:
                        pos[idx, idx_t[idx] == i] = 0

            if cfg.mask_proto_reweight_mask_loss:
                # Ensure that the gt is binary
                if not cfg.mask_proto_binarize_downsampled_gt:
                    bin_gt = downsampled_masks.gt(0.5).float()
                else:
                    bin_gt = downsampled_masks

                # Normalize foreground and background pixels separately per mask
                # (the 0.0001 guards against empty fg/bg).
                gt_foreground_norm = bin_gt / (
                    torch.sum(bin_gt, dim=(0, 1), keepdim=True) + 0.0001)
                gt_background_norm = (1 - bin_gt) / (
                    torch.sum(1 - bin_gt, dim=(0, 1), keepdim=True) + 0.0001)

                mask_reweighting = gt_foreground_norm * cfg.mask_proto_reweight_coeff + gt_background_norm
                mask_reweighting *= mask_h * mask_w

        cur_pos = pos[idx]
        pos_idx_t = idx_t[idx, cur_pos]

        if process_gt_bboxes:
            # Note: this is in point-form
            if cfg.mask_proto_crop_with_pred_box:
                pos_gt_box_t = decode(loc_data[idx, :, :], priors.data, cfg.use_yolo_regressors)[cur_pos]
            else:
                pos_gt_box_t = gt_box_t[idx, cur_pos]

        if pos_idx_t.size(0) == 0:
            continue

        proto_masks = proto_data[idx]
        proto_coef = mask_data[idx, cur_pos, :]
        if cfg.use_mask_scoring:
            # NOTE(review): mask_scores is subsampled below but never consumed here.
            mask_scores = score_data[idx, cur_pos, :]

        if cfg.mask_proto_coeff_diversity_loss:
            if inst_data is not None:
                div_coeffs = inst_data[idx, cur_pos, :]
            else:
                div_coeffs = proto_coef

            loss_d += self.coeff_diversity_loss(div_coeffs, pos_idx_t)

        # If we have over the allowed number of masks, select a random sample
        old_num_pos = proto_coef.size(0)
        if old_num_pos > cfg.masks_to_train:
            perm = torch.randperm(proto_coef.size(0))
            select = perm[:cfg.masks_to_train]

            proto_coef = proto_coef[select, :]
            pos_idx_t = pos_idx_t[select]

            if process_gt_bboxes:
                pos_gt_box_t = pos_gt_box_t[select, :]
            if cfg.use_mask_scoring:
                mask_scores = mask_scores[select, :]

        num_pos = proto_coef.size(0)
        mask_t = downsampled_masks[:, :, pos_idx_t]
        label_t = labels[idx][pos_idx_t]

        # Size: [mask_h, mask_w, num_pos]
        pred_masks = proto_masks @ proto_coef.t()
        pred_masks = cfg.mask_proto_mask_activation(pred_masks)

        if cfg.mask_proto_double_loss:
            # Extra loss term on the uncropped masks.
            if cfg.mask_proto_mask_activation == activation_func.sigmoid:
                pre_loss = F.binary_cross_entropy(
                    torch.clamp(pred_masks, 0, 1), mask_t, reduction='sum')
            else:
                pre_loss = F.smooth_l1_loss(pred_masks, mask_t, reduction='sum')

            loss_m += cfg.mask_proto_double_loss_alpha * pre_loss

        if cfg.mask_proto_crop:
            pred_masks = crop(pred_masks, pos_gt_box_t)

        if cfg.mask_proto_mask_activation == activation_func.sigmoid:
            pre_loss = F.binary_cross_entropy(
                torch.clamp(pred_masks, 0, 1), mask_t, reduction='none')
        else:
            pre_loss = F.smooth_l1_loss(pred_masks, mask_t, reduction='none')

        if cfg.mask_proto_normalize_mask_loss_by_sqrt_area:
            gt_area = torch.sum(mask_t, dim=(0, 1), keepdim=True)
            pre_loss = pre_loss / (torch.sqrt(gt_area) + 0.0001)

        if cfg.mask_proto_reweight_mask_loss:
            pre_loss = pre_loss * mask_reweighting[:, :, pos_idx_t]

        if cfg.mask_proto_normalize_emulate_roi_pooling:
            # Normalize each mask's loss by its gt box area (in proto pixels).
            weight = mask_h * mask_w if cfg.mask_proto_crop else 1
            pos_gt_csize = center_size(pos_gt_box_t)
            gt_box_width = pos_gt_csize[:, 2] * mask_w
            gt_box_height = pos_gt_csize[:, 3] * mask_h
            pre_loss = pre_loss.sum(dim=(0, 1)) / gt_box_width / gt_box_height * weight

        # If the number of masks were limited scale the loss accordingly
        if old_num_pos > num_pos:
            pre_loss *= old_num_pos / num_pos

        loss_m += torch.sum(pre_loss)

        if cfg.use_maskiou:
            if cfg.discard_mask_area > 0:
                gt_mask_area = torch.sum(mask_t, dim=(0, 1))
                select = gt_mask_area > cfg.discard_mask_area

                if torch.sum(select) < 1:
                    continue

                # pos_gt_box_t is not needed for the mask-IoU targets, so it is
                # not filtered here (it is unbound when process_gt_bboxes is False).
                pred_masks = pred_masks[:, :, select]
                mask_t = mask_t[:, :, select]
                label_t = label_t[select]

            maskiou_net_input = pred_masks.permute(2, 0, 1).contiguous().unsqueeze(1)
            pred_masks = pred_masks.gt(0.5).float()
            maskiou_t = self._mask_iou(pred_masks, mask_t)

            maskiou_net_input_list.append(maskiou_net_input)
            maskiou_t_list.append(maskiou_t)
            label_t_list.append(label_t)

    losses = {'M': loss_m * cfg.mask_alpha / mask_h / mask_w}

    if cfg.mask_proto_coeff_diversity_loss:
        losses['D'] = loss_d

    if cfg.use_maskiou:
        # discard_mask_area discarded every mask in the batch, so nothing to do here
        if len(maskiou_t_list) == 0:
            return losses, None

        maskiou_t = torch.cat(maskiou_t_list)
        label_t = torch.cat(label_t_list)
        maskiou_net_input = torch.cat(maskiou_net_input_list)

        num_samples = maskiou_t.size(0)
        if cfg.maskious_to_train > 0 and num_samples > cfg.maskious_to_train:
            perm = torch.randperm(num_samples)
            # Slice with the same budget the guard checks (was masks_to_train).
            select = perm[:cfg.maskious_to_train]
            maskiou_t = maskiou_t[select]
            label_t = label_t[select]
            maskiou_net_input = maskiou_net_input[select]

        return losses, [maskiou_net_input, maskiou_t, label_t]

    return losses