def direct_mask_loss(self, pos_idx, idx_t, loc_data, mask_data, priors, masks):
    """ Crops the gt masks using the predicted bboxes, scales them down, and outputs the BCE loss. """
    loss_m = 0
    for idx in range(mask_data.shape[0]):
        with jt.no_grad():
            cur_pos_idx = pos_idx[idx]
            cur_pos_idx_squeezed = cur_pos_idx[:, 1]

            # Shape: [num_priors, 4], decoded predicted bboxes
            pos_bboxes = decode(loc_data[idx], priors.data, cfg.use_yolo_regressors)
            pos_bboxes = pos_bboxes[cur_pos_idx].view(-1, 4).clamp(0, 1)
            pos_lookup = idx_t[idx, cur_pos_idx_squeezed]

            cur_masks = masks[idx]
            pos_masks = cur_masks[pos_lookup]

            # Convert bboxes to absolute coordinates
            num_pos, img_height, img_width = pos_masks.shape

            # Take care of all the bad behavior that can be caused by out of bounds coordinates
            x1, x2 = sanitize_coordinates(pos_bboxes[:, 0], pos_bboxes[:, 2], img_width)
            y1, y2 = sanitize_coordinates(pos_bboxes[:, 1], pos_bboxes[:, 3], img_height)

            # Crop each gt mask with the predicted bbox and rescale to the predicted mask size
            # Note that each bounding box crop is a different size so I don't think we can vectorize this
            scaled_masks = []
            for jdx in range(num_pos):
                tmp_mask = pos_masks[jdx, y1[jdx]:y2[jdx], x1[jdx]:x2[jdx]]

                # Restore any dimensions we've left out because our bbox was 1px wide
                while tmp_mask.ndim < 2:
                    tmp_mask = tmp_mask.unsqueeze(0)

                new_mask = nn.AdaptiveAvgPool2d(cfg.mask_size)(tmp_mask.unsqueeze(0))
                scaled_masks.append(new_mask.view(1, -1))

            # Threshold the downsampled masks to get binary targets
            mask_t = (jt.contrib.concat(scaled_masks, 0) > 0.5).float()

        # Only the targets were built under no_grad; the predictions keep their gradient
        pos_mask_data = mask_data[idx, cur_pos_idx_squeezed, :]
        loss_m += nn.bce_loss(jt.clamp(pos_mask_data, 0, 1), mask_t, size_average=False) * cfg.mask_alpha

    return loss_m
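
# `direct_mask_loss` relies on a `sanitize_coordinates` helper defined elsewhere
# in the repo. For readers of this section, the following is a minimal sketch of
# what it is assumed to do, based on how it is called above (normalized
# point-form coordinates in, ordered and clamped integer pixel bounds out).
# This is illustrative only, not the repo's actual implementation.
def _sanitize_coordinates_sketch(_x1, _x2, img_size, padding=0):
    # Scale normalized coordinates to absolute pixels.
    _x1 = _x1 * img_size
    _x2 = _x2 * img_size
    # Order each pair so x1 <= x2, then clamp into [0, img_size] so that
    # slicing a mask with these bounds can never go out of range.
    x1 = jt.minimum(_x1, _x2)
    x2 = jt.maximum(_x1, _x2)
    x1 = jt.clamp(x1 - padding, min_v=0)
    x2 = jt.clamp(x2 + padding, max_v=img_size)
    return x1.int32(), x2.int32()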
def __call__(self, anchors, objectness, box_regression, targets):
    """
    Arguments:
        anchors (list[list[BoxList]])
        objectness (list[Tensor])
        box_regression (list[Tensor])
        targets (list[BoxList])

    Returns:
        objectness_loss (Tensor)
        box_loss (Tensor)
    """
    anchors = [cat_boxlist(anchors_per_image) for anchors_per_image in anchors]
    labels, regression_targets = self.prepare_targets(anchors, targets)
    sampled_pos_inds, sampled_neg_inds = self.fg_bg_sampler(labels)
    sampled_pos_inds = jt.nonzero(jt.contrib.concat(sampled_pos_inds, dim=0)).squeeze(1)
    sampled_neg_inds = jt.nonzero(jt.contrib.concat(sampled_neg_inds, dim=0)).squeeze(1)

    sampled_inds = jt.contrib.concat([sampled_pos_inds, sampled_neg_inds], dim=0)

    objectness, box_regression = concat_box_prediction_layers(objectness, box_regression)

    objectness = objectness.squeeze(1)

    labels = jt.contrib.concat(labels, dim=0)
    regression_targets = jt.contrib.concat(regression_targets, dim=0)

    box_loss = _smooth_l1_loss(box_regression[sampled_pos_inds],
                               regression_targets[sampled_pos_inds],
                               sigma=3.) / (sampled_inds.numel())

    # bce_loss_with_logits = nn.BCEWithLogitsLoss()
    # objectness_loss = bce_loss_with_logits(
    #     objectness[sampled_inds], labels[sampled_inds]
    # )
    objectness_loss = nn.bce_loss(objectness[sampled_inds].sigmoid(),
                                  labels[sampled_inds])

    return objectness_loss, box_loss
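
# The commented-out `BCEWithLogitsLoss` path above was replaced by an explicit
# sigmoid followed by `nn.bce_loss`. The two are mathematically equivalent, but
# a fused log-sum-exp form is more robust for large-magnitude logits. Below is
# a minimal sketch of such a stable variant, using only elementwise jt ops
# (illustrative, not part of this repo):
def _bce_with_logits_sketch(logits, targets):
    # Stable elementwise BCE on raw logits:
    # max(x, 0) - x*t + log(1 + exp(-|x|)), averaged over all elements.
    return (jt.clamp(logits, min_v=0) - logits * targets
            + jt.log(1 + jt.exp(-jt.abs(logits)))).mean()
# With it, the objectness term would read:
#   objectness_loss = _bce_with_logits_sketch(objectness[sampled_inds], labels[sampled_inds])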
def __call__(self, output, target):
    from jittor.nn import bce_loss
    return bce_loss(output, target)
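
# Minimal usage sketch for the wrapper above (names are illustrative and assume
# the enclosing class is called BCELoss). Note that jittor.nn.bce_loss expects
# probabilities in [0, 1], so the output must already be sigmoid-activated:
#
#   criterion = BCELoss()
#   pred = jt.sigmoid(jt.randn(4, 10))          # probabilities in (0, 1)
#   target = (jt.rand(4, 10) > 0.5).float()     # binary targets
#   loss = criterion(pred, target)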
def execute(self, net, predictions, targets, masks, num_crowds):
    """Multibox Loss
    Args:
        predictions (dict): A dict containing loc preds, conf preds,
        mask preds, and prior boxes from SSD net.
            loc shape: jt.size(batch_size,num_priors,4)
            conf shape: jt.size(batch_size,num_priors,num_classes)
            masks shape: jt.size(batch_size,num_priors,mask_dim)
            priors shape: jt.size(num_priors,4)
            proto* shape: jt.size(batch_size,mask_h,mask_w,mask_dim)

        targets (list<tensor>): Ground truth boxes and labels for a batch,
            shape: [batch_size][num_objs,5] (last idx is the label).

        masks (list<tensor>): Ground truth masks for each object in each image,
            shape: [batch_size][num_objs,im_height,im_width]

        num_crowds (list<int>): Number of crowd annotations per batch.
            The crowd annotations should be the last num_crowds elements of
            targets and masks.

        * Only if mask_type == lincomb
    """
    loc_data = predictions['loc']
    conf_data = predictions['conf']
    mask_data = predictions['mask']
    priors = predictions['priors']

    if cfg.mask_type == mask_type.lincomb:
        proto_data = predictions['proto']

    score_data = predictions['score'] if cfg.use_mask_scoring else None
    inst_data = predictions['inst'] if cfg.use_instance_coeff else None

    labels = [None] * len(targets)  # Used in sem segm loss

    batch_size = loc_data.shape[0]
    num_priors = priors.shape[0]
    num_classes = self.num_classes

    # Match priors (default boxes) and ground truth boxes
    # These tensors will be created with the same device as loc_data
    loc_t = jt.empty((batch_size, num_priors, 4), dtype=loc_data.dtype)
    gt_box_t = jt.empty((batch_size, num_priors, 4), dtype=loc_data.dtype)
    conf_t = jt.empty((batch_size, num_priors)).int32()
    idx_t = jt.empty((batch_size, num_priors)).int32()

    if cfg.use_class_existence_loss:
        class_existence_t = jt.empty((batch_size, num_classes - 1), dtype=loc_data.dtype)

    # jt.sync(list(predictions.values()))

    for idx in range(batch_size):
        truths = targets[idx][:, :-1]
        labels[idx] = targets[idx][:, -1].int32()

        if cfg.use_class_existence_loss:
            # Construct a one-hot vector for each object and collapse it into an existence vector with max
            # Also it's fine to include the crowd annotations here
            class_existence_t[idx, :] = jt.eye(num_classes - 1)[labels[idx]].max(dim=0)[0]

        # Split the crowd annotations because they come bundled in
        cur_crowds = num_crowds[idx]
        if cur_crowds > 0:
            split = lambda x: (x[-cur_crowds:], x[:-cur_crowds])
            crowd_boxes, truths = split(truths)

            # We don't use the crowd labels or masks
            _, labels[idx] = split(labels[idx])
            _, masks[idx] = split(masks[idx])
        else:
            crowd_boxes = None

        match(self.pos_threshold, self.neg_threshold,
              truths, priors, labels[idx], crowd_boxes,
              loc_t, conf_t, idx_t, idx, loc_data[idx])

        gt_box_t[idx, :, :] = truths[idx_t[idx]]

    # wrap targets
    loc_t.stop_grad()
    conf_t.stop_grad()
    idx_t.stop_grad()

    pos = conf_t > 0
    num_pos = pos.sum(dim=1, keepdims=True)

    # Shape: [batch,num_priors,4]
    pos_idx = pos.unsqueeze(pos.ndim).expand_as(loc_data)

    losses = {}

    # Localization Loss (Smooth L1)
    if cfg.train_boxes:
        loc_p = loc_data[pos_idx].view(-1, 4)
        loc_t = loc_t[pos_idx].view(-1, 4)
        losses['B'] = nn.smooth_l1_loss(loc_p, loc_t, reduction='sum') * cfg.bbox_alpha

    if cfg.train_masks:
        if cfg.mask_type == mask_type.direct:
            if cfg.use_gt_bboxes:
                pos_masks = []
                for idx in range(batch_size):
                    pos_masks.append(masks[idx][idx_t[idx, pos[idx]]])
                masks_t = jt.contrib.concat(pos_masks, 0)
                masks_p = mask_data[pos, :].view(-1, cfg.mask_dim)
                losses['M'] = nn.bce_loss(jt.clamp(masks_p, 0, 1), masks_t, size_average=False) * cfg.mask_alpha
            else:
                losses['M'] = self.direct_mask_loss(pos_idx, idx_t, loc_data, mask_data, priors, masks)
        elif cfg.mask_type == mask_type.lincomb:
            ret = self.lincomb_mask_loss(pos, idx_t, loc_data, mask_data, priors, proto_data,
                                         masks, gt_box_t, score_data, inst_data, labels)
            if cfg.use_maskiou:
                loss, maskiou_targets = ret
            else:
                loss = ret
            losses.update(loss)

            if cfg.mask_proto_loss is not None:
                if cfg.mask_proto_loss == 'l1':
                    losses['P'] = jt.mean(jt.abs(proto_data)) / self.l1_expected_area * self.l1_alpha
                elif cfg.mask_proto_loss == 'disj':
                    losses['P'] = -jt.mean(jt.max(nn.log_softmax(proto_data, dim=-1), dim=-1)[0])

    # Confidence loss
    if cfg.use_focal_loss:
        if cfg.use_sigmoid_focal_loss:
            losses['C'] = self.focal_conf_sigmoid_loss(conf_data, conf_t)
        elif cfg.use_objectness_score:
            losses['C'] = self.focal_conf_objectness_loss(conf_data, conf_t)
        else:
            losses['C'] = self.focal_conf_loss(conf_data, conf_t)
    else:
        if cfg.use_objectness_score:
            losses['C'] = self.conf_objectness_loss(conf_data, conf_t, batch_size, loc_p, loc_t, priors)
        else:
            losses['C'] = self.ohem_conf_loss(conf_data, conf_t, pos, batch_size)

    # Mask IoU Loss
    if cfg.use_maskiou and maskiou_targets is not None:
        losses['I'] = self.mask_iou_loss(net, maskiou_targets)

    # These losses also don't depend on anchors
    if cfg.use_class_existence_loss:
        losses['E'] = self.class_existence_loss(predictions['classes'], class_existence_t)
    if cfg.use_semantic_segmentation_loss:
        losses['S'] = self.semantic_segmentation_loss(predictions['segm'], masks, labels)

    # Divide all losses by the number of positives.
    # Don't do it for loss[P] because that doesn't depend on the anchors.
    total_num_pos = num_pos.sum().float()
    for k in losses:
        if k not in ('P', 'E', 'S'):
            losses[k] /= total_num_pos
        else:
            losses[k] /= batch_size

    # Loss Key:
    #  - B: Box Localization Loss
    #  - C: Class Confidence Loss
    #  - M: Mask Loss
    #  - P: Prototype Loss
    #  - D: Coefficient Diversity Loss
    #  - E: Class Existence Loss
    #  - S: Semantic Segmentation Loss
    return losses
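
# The dict returned by execute() is keyed by the loss codes listed above. A
# typical consumer sums the values into one scalar before the backward pass.
# A minimal training-step sketch, assuming a criterion instance and a standard
# Jittor optimizer (all names here are illustrative):
#
#   losses = criterion(net, predictions, targets, masks, num_crowds)
#   total_loss = sum(losses.values())
#   optimizer.step(total_loss)   # Jittor optimizers take the loss directly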
def lincomb_mask_loss(self, pos, idx_t, loc_data, mask_data, priors, proto_data, masks,
                      gt_box_t, score_data, inst_data, labels, interpolation_mode='bilinear'):
    mask_h = proto_data.shape[1]
    mask_w = proto_data.shape[2]

    process_gt_bboxes = cfg.mask_proto_normalize_emulate_roi_pooling or cfg.mask_proto_crop

    if cfg.mask_proto_remove_empty_masks:
        # Make sure to store a copy of this because we edit it to get rid of all-zero masks
        pos = pos.clone()

    loss_m = 0
    loss_d = 0  # Coefficient diversity loss

    maskiou_t_list = []
    maskiou_net_input_list = []
    label_t_list = []

    for idx in range(mask_data.shape[0]):
        with jt.no_grad():
            downsampled_masks = nn.interpolate(masks[idx].unsqueeze(0), (mask_h, mask_w),
                                               mode=interpolation_mode, align_corners=False).squeeze(0)
            downsampled_masks = downsampled_masks.permute(1, 2, 0)

            if cfg.mask_proto_binarize_downsampled_gt:
                downsampled_masks = (downsampled_masks > 0.5).float()

            if cfg.mask_proto_remove_empty_masks:
                # Get rid of gt masks that are so small they get downsampled away
                very_small_masks = (downsampled_masks.sum(0).sum(0) <= 0.0001)
                for i in range(very_small_masks.shape[0]):
                    if very_small_masks[i]:
                        pos[idx, idx_t[idx] == i] = 0

            if cfg.mask_proto_reweight_mask_loss:
                # Ensure that the gt is binary
                if not cfg.mask_proto_binarize_downsampled_gt:
                    bin_gt = (downsampled_masks > 0.5).float()
                else:
                    bin_gt = downsampled_masks

                gt_foreground_norm = bin_gt / (jt.sum(bin_gt, dim=(0, 1), keepdims=True) + 0.0001)
                gt_background_norm = (1 - bin_gt) / (jt.sum(1 - bin_gt, dim=(0, 1), keepdims=True) + 0.0001)

                mask_reweighting = gt_foreground_norm * cfg.mask_proto_reweight_coeff + gt_background_norm
                mask_reweighting *= mask_h * mask_w

        cur_pos = pos[idx]
        cur_pos = jt.where(cur_pos)[0]
        pos_idx_t = idx_t[idx, cur_pos]

        if process_gt_bboxes:
            # Note: this is in point-form
            if cfg.mask_proto_crop_with_pred_box:
                pos_gt_box_t = decode(loc_data[idx, :, :], priors.data, cfg.use_yolo_regressors)[cur_pos]
            else:
                pos_gt_box_t = gt_box_t[idx, cur_pos]

        if pos_idx_t.shape[0] == 0:
            continue

        proto_masks = proto_data[idx]
        proto_coef = mask_data[idx, cur_pos, :]
        if cfg.use_mask_scoring:
            mask_scores = score_data[idx, cur_pos, :]

        if cfg.mask_proto_coeff_diversity_loss:
            if inst_data is not None:
                div_coeffs = inst_data[idx, cur_pos, :]
            else:
                div_coeffs = proto_coef

            loss_d += self.coeff_diversity_loss(div_coeffs, pos_idx_t)

        # If we have over the allowed number of masks, select a random sample
        old_num_pos = proto_coef.shape[0]
        if old_num_pos > cfg.masks_to_train:
            perm = jt.randperm(proto_coef.shape[0])
            select = perm[:cfg.masks_to_train]

            proto_coef = proto_coef[select, :]
            pos_idx_t = pos_idx_t[select]

            if process_gt_bboxes:
                pos_gt_box_t = pos_gt_box_t[select, :]
            if cfg.use_mask_scoring:
                mask_scores = mask_scores[select, :]

        num_pos = proto_coef.shape[0]
        mask_t = downsampled_masks[:, :, pos_idx_t]
        label_t = labels[idx][pos_idx_t]

        # Size: [mask_h, mask_w, num_pos]
        pred_masks = proto_masks @ proto_coef.transpose(1, 0)
        pred_masks = cfg.mask_proto_mask_activation(pred_masks)

        if cfg.mask_proto_double_loss:
            if cfg.mask_proto_mask_activation == activation_func.sigmoid:
                pre_loss = nn.bce_loss(jt.clamp(pred_masks, 0, 1), mask_t, size_average=False)
            else:
                pre_loss = nn.smooth_l1_loss(pred_masks, mask_t, reduction='sum')

            loss_m += cfg.mask_proto_double_loss_alpha * pre_loss

        if cfg.mask_proto_crop:
            pred_masks = crop(pred_masks, pos_gt_box_t)

        if cfg.mask_proto_mask_activation == activation_func.sigmoid:
            # binary_cross_entropy here is expected to be elementwise (no reduction)
            pre_loss = binary_cross_entropy(jt.clamp(pred_masks, 0, 1), mask_t)
        else:
            pre_loss = nn.smooth_l1_loss(pred_masks, mask_t, reduction='none')

        if cfg.mask_proto_normalize_mask_loss_by_sqrt_area:
            gt_area = jt.sum(mask_t, dim=(0, 1), keepdims=True)
            pre_loss = pre_loss / (jt.sqrt(gt_area) + 0.0001)

        if cfg.mask_proto_reweight_mask_loss:
            pre_loss = pre_loss * mask_reweighting[:, :, pos_idx_t]

        if cfg.mask_proto_normalize_emulate_roi_pooling:
            weight = mask_h * mask_w if cfg.mask_proto_crop else 1
            pos_gt_csize = center_size(pos_gt_box_t)
            gt_box_width = pos_gt_csize[:, 2] * mask_w
            gt_box_height = pos_gt_csize[:, 3] * mask_h
            pre_loss = pre_loss.sum(0).sum(0) / gt_box_width / gt_box_height * weight

        # If the number of masks were limited scale the loss accordingly
        if old_num_pos > num_pos:
            pre_loss *= old_num_pos / num_pos

        loss_m += jt.sum(pre_loss)

        if cfg.use_maskiou:
            if cfg.discard_mask_area > 0:
                gt_mask_area = jt.sum(mask_t, dim=(0, 1))
                select = gt_mask_area > cfg.discard_mask_area

                if jt.sum(select).item() < 1:
                    continue

                pos_gt_box_t = pos_gt_box_t[select, :]
                pred_masks = pred_masks[:, :, select]
                mask_t = mask_t[:, :, select]
                label_t = label_t[select]

            maskiou_net_input = pred_masks.permute(2, 0, 1).unsqueeze(1)
            pred_masks = (pred_masks > 0.5).float()
            maskiou_t = self._mask_iou(pred_masks, mask_t)

            maskiou_net_input_list.append(maskiou_net_input)
            maskiou_t_list.append(maskiou_t)
            label_t_list.append(label_t)

    losses = {'M': loss_m * cfg.mask_alpha / mask_h / mask_w}

    if cfg.mask_proto_coeff_diversity_loss:
        losses['D'] = loss_d

    if cfg.use_maskiou:
        # discard_mask_area discarded every mask in the batch, so nothing to do here
        if len(maskiou_t_list) == 0:
            return losses, None

        maskiou_t = jt.contrib.concat(maskiou_t_list)
        label_t = jt.contrib.concat(label_t_list)
        maskiou_net_input = jt.contrib.concat(maskiou_net_input_list)

        num_samples = maskiou_t.shape[0]
        if cfg.maskious_to_train > 0 and num_samples > cfg.maskious_to_train:
            perm = jt.randperm(num_samples)
            select = perm[:cfg.maskious_to_train]
            maskiou_t = maskiou_t[select]
            label_t = label_t[select]
            maskiou_net_input = maskiou_net_input[select]

        return losses, [maskiou_net_input, maskiou_t, label_t]

    return losses
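
# `lincomb_mask_loss` builds its mask-IoU targets with `self._mask_iou`, which
# is defined elsewhere in the repo. The following is a minimal sketch of its
# assumed behavior for masks shaped [mask_h, mask_w, num_pos], matching how it
# is called above (illustrative only):
def _mask_iou_sketch(mask1, mask2):
    # Per-instance IoU reduced over the two spatial dims. When
    # cfg.discard_mask_area > 0 the caller has already removed tiny masks,
    # so union should be nonzero here.
    intersection = jt.sum(mask1 * mask2, dim=(0, 1))
    area1 = jt.sum(mask1, dim=(0, 1))
    area2 = jt.sum(mask2, dim=(0, 1))
    union = area1 + area2 - intersection
    return intersection / union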