def __call__(self, mask_logits):
    """
    Arguments:
        mask_logits (Tensor)

    Return:
        mask_loss (Tensor): scalar tensor containing the loss
    """
    labels = [
        proposals_per_img.get_field("labels")
        for proposals_per_img in self.positive_proposals
    ]
    mask_targets = [
        proposals_per_img.get_field("mask_targets")
        for proposals_per_img in self.positive_proposals
    ]
    labels = cat(labels, dim=0)
    mask_targets = cat(mask_targets, dim=0)
    positive_inds = torch.nonzero(labels > 0).squeeze(1)
    labels_pos = labels[positive_inds]

    # torch.mean (in binary_cross_entropy_with_logits) doesn't
    # accept empty tensors, so handle it separately
    if mask_targets.numel() == 0:
        return mask_logits.sum() * 0

    mask_loss = F.binary_cross_entropy_with_logits(
        mask_logits[positive_inds, labels_pos], mask_targets
    )
    mask_loss *= cfg.MRCNN.LOSS_WEIGHT
    return mask_loss
def __call__(self, keypoint_logits):
    heatmaps = []
    valid = []
    for proposals_per_image in self.positive_proposals:
        kp = proposals_per_image.get_field("keypoints_target")
        heatmaps_per_image, valid_per_image = project_keypoints_to_heatmap(
            kp, proposals_per_image, self.resolution
        )
        heatmaps.append(heatmaps_per_image.view(-1))
        valid.append(valid_per_image.view(-1))

    keypoint_targets = cat(heatmaps, dim=0)
    valid = cat(valid, dim=0).to(dtype=torch.uint8)
    valid = torch.nonzero(valid).squeeze(1)

    # torch.mean (in cross_entropy) doesn't accept empty tensors,
    # so handle it separately
    if keypoint_targets.numel() == 0 or len(valid) == 0:
        return keypoint_logits.sum() * 0

    N, K, H, W = keypoint_logits.shape
    keypoint_logits = keypoint_logits.view(N * K, H * W)

    keypoint_loss = F.cross_entropy(keypoint_logits[valid], keypoint_targets[valid])
    keypoint_loss *= cfg.KRCNN.LOSS_WEIGHT
    return keypoint_loss
def convert_to_roi_format(self, boxes):
    concat_boxes = cat([b.bbox for b in boxes], dim=0)
    device, dtype = concat_boxes.device, concat_boxes.dtype
    ids = cat(
        [
            torch.full((len(b), 1), i, dtype=dtype, device=device)
            for i, b in enumerate(boxes)
        ],
        dim=0,
    )
    rois = torch.cat([ids, concat_boxes], dim=1)
    return rois
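
# The returned rois follow the [batch_index, x1, y1, x2, y2] layout expected by
# RoIAlign/RoIPool. A self-contained sketch of the same conversion (a hypothetical
# stand-alone helper, not the pooler method itself):
import torch

def to_roi_format(box_tensors):
    # box_tensors: list of [N_i, 4] box tensors, one entry per image in the batch
    concat_boxes = torch.cat(box_tensors, dim=0)
    ids = torch.cat(
        [
            torch.full((len(b), 1), i, dtype=concat_boxes.dtype)
            for i, b in enumerate(box_tensors)
        ],
        dim=0,
    )
    return torch.cat([ids, concat_boxes], dim=1)

rois = to_roi_format([
    torch.tensor([[0., 0., 10., 10.], [5., 5., 20., 20.]]),  # image 0
    torch.tensor([[2., 2., 8., 8.]]),                         # image 1
])
# tensor([[ 0.,  0.,  0., 10., 10.],
#         [ 0.,  5.,  5., 20., 20.],
#         [ 1.,  2.,  2.,  8.,  8.]])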
def __call__(self, class_logits, box_regression):
    """
    Computes the loss for Faster R-CNN.
    This requires that the subsample method has been called beforehand.

    Arguments:
        class_logits (list[Tensor])
        box_regression (list[Tensor])

    Returns:
        loss_dict (dict[str, Tensor]): classification and box regression losses
    """
    loss_dict = {}
    if not hasattr(self, "_proposals"):
        raise RuntimeError("subsample needs to be called before")

    proposals = self._proposals
    labels = cat([proposal.get_field("labels") for proposal in proposals], dim=0)

    assert (
        class_logits[0] is not None or box_regression[0] is not None
    ), "Fast R-CNN should keep at least one branch"

    if class_logits[0] is not None:
        class_logits = cat(class_logits, dim=0)
        classification_loss = F.cross_entropy(class_logits, labels)
        loss_dict["loss_classifier"] = classification_loss

    if box_regression[0] is not None:
        box_regression = cat(box_regression, dim=0)
        device = box_regression.device
        regression_targets = cat(
            [proposal.get_field("regression_targets") for proposal in proposals], dim=0
        )

        # get indices that correspond to the regression targets for
        # the corresponding ground truth labels, to be used with
        # advanced indexing
        sampled_pos_inds_subset = torch.nonzero(labels > 0).squeeze(1)
        labels_pos = labels[sampled_pos_inds_subset]
        if self.cls_agnostic_bbox_reg:
            map_inds = torch.tensor([4, 5, 6, 7], device=device)
        else:
            map_inds = 4 * labels_pos[:, None] + torch.tensor([0, 1, 2, 3], device=device)

        box_loss = smooth_l1_loss(
            box_regression[sampled_pos_inds_subset[:, None], map_inds],
            regression_targets[sampled_pos_inds_subset],
            size_average=False,
            beta=cfg.FAST_RCNN.SMOOTH_L1_BETA,
        )
        box_loss = box_loss / labels.numel()
        loss_dict["loss_box_reg"] = box_loss

    return loss_dict
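
# smooth_l1_loss above is the codebase's own helper (note the size_average and beta
# arguments) rather than F.smooth_l1_loss. Its exact definition is not shown here;
# a typical implementation probably looks roughly like this sketch:
import torch

def smooth_l1_loss_sketch(input, target, beta=1.0 / 9, size_average=True):
    # Quadratic for |x| < beta, linear beyond it (the usual Huber-style
    # smooth L1 used for box regression).
    n = torch.abs(input - target)
    loss = torch.where(n < beta, 0.5 * n ** 2 / beta, n - 0.5 * beta)
    return loss.mean() if size_average else loss.sum()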
def __call__(self, parsing_logits):
    parsing_targets = [
        proposals_per_img.get_field("parsing_targets")
        for proposals_per_img in self.positive_proposals
    ]
    parsing_targets = cat(parsing_targets, dim=0)

    if parsing_targets.numel() == 0:
        if not self.parsingiou_on:
            return parsing_logits.sum() * 0
        else:
            return parsing_logits.sum() * 0, None

    if self.parsingiou_on:
        # TODO: use tensors for speeding up
        pred_parsings_np = parsing_logits.detach().argmax(dim=1).cpu().numpy()
        parsing_targets_np = parsing_targets.cpu().numpy()

        N = parsing_targets_np.shape[0]
        parsingiou_targets = np.zeros(N, dtype=np.float64)
        for i in range(N):
            parsing_iou = cal_one_mean_iou(
                parsing_targets_np[i], pred_parsings_np[i], cfg.PRCNN.NUM_PARSING
            )
            parsingiou_targets[i] = np.nanmean(parsing_iou)
        parsingiou_targets = torch.from_numpy(parsingiou_targets).to(
            parsing_targets.device, dtype=torch.float
        )

    parsing_loss = F.cross_entropy(parsing_logits, parsing_targets, reduction="mean")
    parsing_loss *= cfg.PRCNN.LOSS_WEIGHT

    if not self.parsingiou_on:
        return parsing_loss
    else:
        return parsing_loss, parsingiou_targets
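
# cal_one_mean_iou is assumed here to return a per-class IoU vector between the
# target parsing map and the predicted one (NaN for classes absent from both),
# which np.nanmean then averages into a single parsing-IoU target. A minimal
# sketch under that assumption:
import numpy as np

def cal_one_mean_iou_sketch(target, pred, num_classes):
    # Per-class IoU between two label maps; NaN where a class appears in
    # neither map so that np.nanmean skips it.
    ious = np.full(num_classes, np.nan, dtype=np.float64)
    for c in range(num_classes):
        gt_c = target == c
        pr_c = pred == c
        union = np.logical_or(gt_c, pr_c).sum()
        if union > 0:
            ious[c] = np.logical_and(gt_c, pr_c).sum() / union
    return ious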
def __call__(self, semantic_pred, targets):
    labels = self.semseg_batch_resize(targets)
    labels = cat([label for label in labels], dim=0).long()
    assert len(labels.shape) == 3

    loss_semseg = F.cross_entropy(semantic_pred, labels, ignore_index=self.ignore_label)
    loss_semseg *= self.loss_weight
    return loss_semseg
def __call__(self, boxlists):
    """
    Arguments:
        boxlists (list[BoxList])
    """
    # Compute level ids
    s = torch.sqrt(cat([boxlist.area() for boxlist in boxlists]))

    # Eqn.(1) in FPN paper
    target_lvls = torch.floor(self.lvl0 + torch.log2(s / self.s0 + self.eps))
    target_lvls = torch.clamp(target_lvls, min=self.k_min, max=self.k_max)
    return target_lvls.to(torch.int64) - self.k_min
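
# With the canonical FPN settings (lvl0 = 4, s0 = 224, levels P2-P5), Eqn.(1)
# maps an RoI of scale sqrt(w * h) = 112 to floor(4 + log2(112 / 224)) = 3,
# i.e. P3; subtracting k_min turns the level into an index into the feature-map
# list. A toy check under those assumed settings:
import torch

lvl0, s0, k_min, k_max, eps = 4, 224, 2, 5, 1e-6  # assumed config values

areas = torch.tensor([32. * 32., 112. * 112., 448. * 448.])  # toy RoI areas
s = torch.sqrt(areas)
target_lvls = torch.floor(lvl0 + torch.log2(s / s0 + eps))
target_lvls = torch.clamp(target_lvls, min=k_min, max=k_max)
print(target_lvls.to(torch.int64) - k_min)  # tensor([0, 1, 3]) -> P2, P3, P5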
def __call__(self, parsing_logits):
    parsing_targets = [
        proposals_per_img.get_field("parsing_targets")
        for proposals_per_img in self.positive_proposals
    ]
    parsing_targets = cat(parsing_targets, dim=0)

    if parsing_targets.numel() == 0:
        return parsing_logits.sum() * 0

    parsing_loss = F.cross_entropy(parsing_logits, parsing_targets, reduction="mean")
    parsing_loss *= cfg.PRCNN.LOSS_WEIGHT
    return parsing_loss
def __call__(self, logits):
    targets = [
        proposals_per_img.get_field("targets")
        for proposals_per_img in self.positive_proposals
    ]
    targets = cat(targets, dim=0).float()

    if targets.numel() == 0:
        return logits['fused'].sum() * 0

    loss_fused = self.loss_weight * F.binary_cross_entropy_with_logits(
        logits['fused'], targets
    )
    loss_unfused = self.loss_weight * F.binary_cross_entropy_with_logits(
        logits['unfused'], targets
    )
    loss = loss_fused + loss_unfused
    return loss
def __call__(self, mask_logits):
    """
    Arguments:
        mask_logits (Tensor)

    Return:
        mask_loss (Tensor): scalar tensor containing the loss
        If the maskiou head is enabled, extra features for the maskiou head
        are returned as well.
    """
    labels = [
        proposals_per_img.get_field("labels")
        for proposals_per_img in self.positive_proposals
    ]
    mask_targets = [
        proposals_per_img.get_field("mask_targets")
        for proposals_per_img in self.positive_proposals
    ]
    if self.maskiou_on:
        mask_ratios = [
            proposals_per_img.get_field("mask_ratios")
            for proposals_per_img in self.positive_proposals
        ]

    labels = cat(labels, dim=0)
    mask_targets = cat(mask_targets, dim=0)
    positive_inds = torch.nonzero(labels > 0).squeeze(1)
    labels_pos = labels[positive_inds]

    # torch.mean (in binary_cross_entropy_with_logits) doesn't
    # accept empty tensors, so handle it separately
    if mask_targets.numel() == 0:
        if not self.maskiou_on:
            return mask_logits.sum() * 0
        else:
            selected_index = torch.arange(mask_logits.shape[0], device=labels.device)
            selected_mask = mask_logits[selected_index, labels]
            mask_num, mask_h, mask_w = selected_mask.shape
            selected_mask = selected_mask.reshape(mask_num, 1, mask_h, mask_w)
            return mask_logits.sum() * 0, selected_mask, labels, None

    if self.maskiou_on:
        mask_ratios = cat(mask_ratios, dim=0)
        value_eps = 1e-10 * torch.ones(mask_targets.shape[0], device=labels.device)
        mask_ratios = torch.max(mask_ratios, value_eps)
        pred_masks = mask_logits[positive_inds, labels_pos]
        pred_masks[:] = pred_masks > 0.5
        mask_targets_full_area = mask_targets.sum(dim=[1, 2]) / mask_ratios
        mask_ovr = pred_masks * mask_targets
        mask_ovr_area = mask_ovr.sum(dim=[1, 2])
        mask_union_area = pred_masks.sum(dim=[1, 2]) + mask_targets_full_area - mask_ovr_area
        value_1 = torch.ones(pred_masks.shape[0], device=labels.device)
        value_0 = torch.zeros(pred_masks.shape[0], device=labels.device)
        mask_union_area = torch.max(mask_union_area, value_1)
        mask_ovr_area = torch.max(mask_ovr_area, value_0)
        maskiou_targets = mask_ovr_area / mask_union_area

    mask_loss = F.binary_cross_entropy_with_logits(
        mask_logits[positive_inds, labels_pos], mask_targets
    )
    mask_loss *= cfg.MRCNN.LOSS_WEIGHT

    if not self.maskiou_on:
        return mask_loss
    else:
        selected_index = torch.arange(mask_logits.shape[0], device=labels.device)
        selected_mask = mask_logits[selected_index, labels]
        mask_num, mask_h, mask_w = selected_mask.shape
        selected_mask = selected_mask.reshape(mask_num, 1, mask_h, mask_w)
        selected_mask = selected_mask.sigmoid()
        return mask_loss, selected_mask, labels, maskiou_targets
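
# mask_ratios appears to be the fraction of each ground-truth mask's full area
# that falls inside the proposal box, so dividing the cropped target area by it
# recovers the full-mask area, and maskiou_targets approximate the IoU between
# the binarized prediction and the complete ground-truth mask (the Mask Scoring
# R-CNN training target). A toy illustration with hypothetical numbers:
import torch

pred_mask = torch.tensor([[1., 1.], [0., 0.]])       # binarized prediction, area 2
gt_mask_in_box = torch.tensor([[1., 1.], [1., 0.]])  # GT cropped to the box, area 3
mask_ratio = 0.75                                    # assumed: box holds 75% of the full GT

full_gt_area = gt_mask_in_box.sum() / mask_ratio     # 3 / 0.75 = 4
overlap = (pred_mask * gt_mask_in_box).sum()         # 2
union = pred_mask.sum() + full_gt_area - overlap     # 2 + 4 - 2 = 4
maskiou_target = overlap / union                     # 0.5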