def point_sample_fine_grained_features(features_list, feature_scales, boxes, point_coords):
    """
    Sample point features from several feature maps and concatenate them.

    The box-relative `point_coords` are converted to image coordinates with
    `get_point_coords_wrt_image`; every feature map is then sampled at those
    locations with `point_sample`.

    Args:
        features_list (list[Tensor]): feature maps. Dimension 0 is indexed by
            image and the last two dimensions are read as (height, width).
        feature_scales (list[float]): one scale per entry of `features_list`.
            NOTE(review): presumably the map-to-image resolution ratio --
            confirm with the caller.
        boxes (list[Boxes]): one `Boxes` object per image.
        point_coords (Tensor): point coordinates for all boxes, ordered like
            the concatenation of `boxes`.

    Returns:
        tuple: the sampled features (per-map features concatenated along
        dim 1, images concatenated along dim 0) and the point coordinates
        with respect to the image.
    """
    all_boxes = Boxes.cat(boxes)
    boxes_per_image = [len(image_boxes) for image_boxes in boxes]
    point_coords_wrt_image = get_point_coords_wrt_image(all_boxes.tensor, point_coords)
    coords_by_image = torch.split(point_coords_wrt_image, boxes_per_image)

    sampled_per_image = []
    for image_index, image_coords in enumerate(coords_by_image):
        sampled_per_map = []
        for map_index, level_features in enumerate(features_list):
            height, width = level_features.shape[-2:]
            size_xy = torch.tensor([width, height], device=level_features.device)
            # Dividing by this value maps image coordinates into the
            # normalized coordinates handed to `point_sample`.
            normalizer = size_xy / feature_scales[map_index]
            normalized_coords = image_coords / normalizer
            sampled = point_sample(
                level_features[image_index].unsqueeze(0),
                normalized_coords.unsqueeze(0),
                align_corners=False,
            )
            sampled_per_map.append(sampled.squeeze(0).transpose(1, 0))
        sampled_per_image.append(cat(sampled_per_map, dim=1))

    return cat(sampled_per_image, dim=0), point_coords_wrt_image
def permute_all_cls_and_box_to_N_HWA_K_and_concat(box_cls, box_delta, num_classes=80):
    """
    Flatten per-level class and box predictions into two 2-D tensors.

    Every tensor in `box_cls` / `box_delta` is rearranged with
    `permute_to_N_HWA_K`, the levels are concatenated along dim 1 and the
    result is viewed as `(-1, K)`.

    Args:
        box_cls (list[Tensor]): per-level classification predictions.
        box_delta (list[Tensor]): per-level box regression predictions.
        num_classes (int): size of the last dimension of the class output.

    Returns:
        tuple[Tensor, Tensor]: class predictions viewed as `(-1, num_classes)`
        and box deltas viewed as `(-1, 4)`.
    """

    def _flatten_levels(per_level, last_dim):
        # Concatenate the permuted levels, then collapse everything except
        # the last dimension.
        permuted = [permute_to_N_HWA_K(level, last_dim) for level in per_level]
        return cat(permuted, dim=1).view(-1, last_dim)

    flat_cls = _flatten_levels(box_cls, num_classes)
    flat_delta = _flatten_levels(box_delta, 4)
    return flat_cls, flat_delta
def roi_mask_point_loss(mask_logits, instances, points_coord):
    """
    Compute the point-based mask loss and log the point accuracy.

    Ground-truth bit masks are sampled with `point_sample` at `points_coord`
    and compared with `mask_logits` through a binary cross-entropy with
    logits (mean reduction). The accuracy is written to the event storage
    under "point_rend/accuracy".

    Args:
        mask_logits (Tensor): point predictions. Dimension 0 indexes the
            masks and dimension 1 the classes; a size of 1 in dimension 1 is
            treated as class-agnostic.
        instances (list[Instances]): per-image instances with `gt_masks`
            (must be `BitMasks`) and, for class-specific predictions,
            `gt_classes`. Their order must match dimension 0 of
            `mask_logits` and `points_coord`.
        points_coord (Tensor): point coordinates for every mask. They are
            divided by the (width, height) of the image's GT masks before
            sampling.

    Returns:
        Tensor: scalar loss. When there is nothing to supervise (no
        instances, or no sampled points) the result is `mask_logits.sum() * 0`
        so that it stays connected to the graph.
    """
    assert len(instances) == 0 or isinstance(instances[0].gt_masks, BitMasks), \
        "Point head works with GT in 'bitmask' format only. Set INPUT.MASK_FORMAT to 'bitmask'."

    cls_agnostic_mask = mask_logits.size(1) == 1
    total_num_masks = mask_logits.size(0)

    gt_classes = []
    gt_mask_logits = []
    with torch.no_grad():
        idx = 0
        for instances_per_image in instances:
            num_instances = len(instances_per_image)
            if num_instances == 0:
                # Nothing to sample for this image; skipping it leaves the
                # concatenated targets unchanged.
                continue
            if not cls_agnostic_mask:
                gt_classes.append(instances_per_image.gt_classes.to(dtype=torch.int64))

            gt_bit_masks = instances_per_image.gt_masks.tensor
            h, w = instances_per_image.gt_masks.image_size
            scale = torch.tensor([w, h], dtype=torch.float, device=gt_bit_masks.device)
            points_coord_grid_sample_format = points_coord[idx:idx + num_instances] / scale
            idx += num_instances
            gt_mask_logits.append(
                point_sample(
                    gt_bit_masks.to(torch.float32).unsqueeze(1),
                    points_coord_grid_sample_format,
                    align_corners=False,
                ).squeeze(1)
            )

    # Concatenating an empty list would fail, so bail out before `cat`.
    if not gt_mask_logits:
        return mask_logits.sum() * 0
    gt_mask_logits = cat(gt_mask_logits)
    if gt_mask_logits.numel() == 0:
        return mask_logits.sum() * 0

    if cls_agnostic_mask:
        mask_logits = mask_logits[:, 0]
    else:
        # Build the row index on the logits' device to avoid indexing with
        # tensors that live on different devices.
        indices = torch.arange(total_num_masks, device=mask_logits.device)
        gt_classes = cat(gt_classes, dim=0)
        mask_logits = mask_logits[indices, gt_classes]

    # A logit above 0 counts as a foreground prediction.
    mask_accurate = (mask_logits > 0.0) == gt_mask_logits.to(dtype=torch.uint8)
    mask_accuracy = mask_accurate.nonzero().size(0) / mask_accurate.numel()
    get_event_storage().put_scalar("point_rend/accuracy", mask_accuracy)

    point_loss = F.binary_cross_entropy_with_logits(
        mask_logits, gt_mask_logits.to(dtype=torch.float32), reduction="mean"
    )
    return point_loss
def get_uncertain_point_coords_with_randomness(
    coarse_logits, uncertainty_func, num_points, oversample_ratio, importance_sample_ratio
):
    """
    Pick `num_points` points per box, biased towards uncertain locations.

    `int(num_points * oversample_ratio)` candidate points are drawn uniformly
    from [0, 1) x [0, 1). The `int(importance_sample_ratio * num_points)`
    candidates with the highest uncertainty are kept, and the remaining points
    are drawn uniformly at random.

    Args:
        coarse_logits (Tensor): coarse predictions; dimension 0 indexes boxes.
        uncertainty_func (callable): maps the logits sampled at the candidate
            points to uncertainties; channel 0 of its output is used.
        num_points (int): number of points returned per box.
        oversample_ratio (float): oversampling factor, must be >= 1.
        importance_sample_ratio (float): fraction of points chosen by
            uncertainty, must lie in [0, 1].

    Returns:
        Tensor: point coordinates of shape (num_boxes, num_points, 2).
    """
    assert oversample_ratio >= 1
    assert 0 <= importance_sample_ratio <= 1

    device = coarse_logits.device
    box_count = coarse_logits.shape[0]
    candidate_count = int(num_points * oversample_ratio)

    candidates = torch.rand(box_count, candidate_count, 2, device=device)
    candidate_logits = point_sample(coarse_logits, candidates, align_corners=False)
    uncertainties = uncertainty_func(candidate_logits)

    important_count = int(importance_sample_ratio * num_points)
    random_count = num_points - important_count

    _, top_idx = torch.topk(uncertainties[:, 0, :], k=important_count, dim=1)
    # Offset per-box indices so they address the flattened candidate tensor.
    row_offsets = candidate_count * torch.arange(box_count, dtype=torch.long, device=device)
    top_idx = top_idx + row_offsets[:, None]
    flat_candidates = candidates.view(-1, 2)
    chosen = flat_candidates[top_idx.view(-1), :].view(box_count, important_count, 2)

    if random_count <= 0:
        return chosen

    random_points = torch.rand(box_count, random_count, 2, device=device)
    return cat([chosen, random_points], dim=1)
def __init__(
    self,
    box2box_transform,
    pred_class_logits,
    pred_proposal_deltas,
    proposals,
    smooth_l1_beta=0,
):
    """
    Store the predictions and gather the per-image proposals.

    Args:
        box2box_transform: box transform object, stored as-is.
        pred_class_logits (Tensor): predicted class logits, stored as-is.
        pred_proposal_deltas (Tensor): predicted box deltas, stored as-is.
        proposals (list[Instances]): per-image proposals. Each must provide
            `image_size` and `proposal_boxes`; when the first one has
            "gt_boxes" it must also have "gt_classes", and both are
            concatenated over all images.
        smooth_l1_beta (float): stored as `self.smooth_l1_beta`.
    """
    self.box2box_transform = box2box_transform
    self.num_preds_per_image = [len(per_image) for per_image in proposals]
    self.pred_class_logits = pred_class_logits
    self.pred_proposal_deltas = pred_proposal_deltas
    self.smooth_l1_beta = smooth_l1_beta
    self.image_shapes = [per_image.image_size for per_image in proposals]

    has_proposals = len(proposals) != 0
    if not has_proposals:
        # No images at all: fall back to an empty box container on the
        # device of the predicted deltas.
        empty_boxes = torch.zeros(0, 4, device=self.pred_proposal_deltas.device)
        self.proposals = Boxes(empty_boxes)
    else:
        first = proposals[0]
        box_type = type(first.proposal_boxes)
        self.proposals = box_type.cat([per_image.proposal_boxes for per_image in proposals])
        assert not self.proposals.tensor.requires_grad, \
            "Proposals should not require gradients!"

        if first.has("gt_boxes"):
            self.gt_boxes = box_type.cat([per_image.gt_boxes for per_image in proposals])
            assert first.has("gt_classes")
            self.gt_classes = cat([per_image.gt_classes for per_image in proposals], dim=0)

    self._no_instances = not has_proposals
def forward(self, fine_grained_features, coarse_features):
    """
    Predict point logits from fine-grained and coarse point features.

    The two inputs are concatenated along dim 1 and passed through every
    layer in `self.fc_layers`, each followed by a ReLU. When
    `self.coarse_pred_each_layer` is set, `coarse_features` is concatenated
    again (dim 1) after every layer. The result goes through
    `self.predictor`.

    Args:
        fine_grained_features (Tensor): fine-grained point features.
        coarse_features (Tensor): coarse point features; must be
            concatenable with the other tensors along dim 1.

    Returns:
        The output of `self.predictor`.
    """
    hidden = torch.cat((fine_grained_features, coarse_features), dim=1)
    for fc_layer in self.fc_layers:
        hidden = F.relu(fc_layer(hidden))
        if not self.coarse_pred_each_layer:
            continue
        hidden = cat((hidden, coarse_features), dim=1)
    return self.predictor(hidden)
def prepare_targets(self, points, targets):
    """
    Build per-level classification and regression targets.

    Each level is given a fixed object-size range; the ranges are expanded to
    one row per point and passed, together with all points, to
    `self.compute_targets_for_locations`. The per-image results are then
    regrouped so that the outer list runs over levels.

    Side effect: `self.num_points_per_level` is set to the number of points
    of every level before the targets are computed.

    Args:
        points (list[Tensor]): point locations, one tensor per level (at most
            five levels are covered by the size ranges).
        targets: ground truth forwarded to
            `self.compute_targets_for_locations`.

    Returns:
        tuple[list[Tensor], list[Tensor]]: labels and regression targets, one
        tensor per level with all images concatenated along dim 0. Regression
        targets are divided by `self.neck_strides[level]` when
        `self.norm_reg_targets` is set.
    """
    size_ranges = [
        [-1, 64],
        [64, 128],
        [128, 256],
        [256, 512],
        [512, INF],
    ]

    # One (min, max) row per point, level after level.
    ranges_per_level = [
        level_points.new_tensor(size_ranges[level])[None].expand(len(level_points), -1)
        for level, level_points in enumerate(points)
    ]
    ranges_per_point = cat(ranges_per_level, 0)

    points_per_level = [len(level_points) for level_points in points]
    self.num_points_per_level = points_per_level

    all_points = cat(points, dim=0)
    labels, reg_targets = self.compute_targets_for_locations(
        all_points, targets, ranges_per_point
    )

    # Split every image's flat targets back into per-level chunks.
    for image_index in range(len(labels)):
        labels[image_index] = torch.split(labels[image_index], points_per_level, dim=0)
        reg_targets[image_index] = torch.split(
            reg_targets[image_index], points_per_level, dim=0
        )

    labels_level_first = []
    reg_targets_level_first = []
    for level in range(len(points)):
        level_labels = cat([per_image[level] for per_image in labels], 0)
        level_reg_targets = cat([per_image[level] for per_image in reg_targets], dim=0)
        if self.norm_reg_targets:
            level_reg_targets = level_reg_targets / self.neck_strides[level]
        labels_level_first.append(level_labels)
        reg_targets_level_first.append(level_reg_targets)

    return labels_level_first, reg_targets_level_first
def keypoint_rcnn_inference(pred_keypoint_logits, pred_instances):
    """
    Convert keypoint heatmaps to keypoints and attach them to the instances.

    The predicted boxes of all images are concatenated and passed with the
    (detached) logits to `heatmaps_to_keypoints`. Entries 0, 1 and 3 of the
    last dimension are kept and the result is split per image.

    Args:
        pred_keypoint_logits (Tensor): keypoint heatmap logits for all
            instances of all images.
        pred_instances (list[Instances]): per-image instances with a
            `pred_boxes` field. Each one receives a `pred_keypoints` field.

    Returns:
        None; `pred_instances` is modified in place.
    """
    all_boxes = cat([image_instances.pred_boxes.tensor for image_instances in pred_instances],
                    dim=0)
    keypoints = heatmaps_to_keypoints(pred_keypoint_logits.detach(), all_boxes.detach())

    instance_counts = [len(image_instances) for image_instances in pred_instances]
    # NOTE(review): presumably (x, y, score) out of a 4-value row -- confirm
    # against `heatmaps_to_keypoints`.
    selected = keypoints[:, :, [0, 1, 3]]
    keypoints_per_image = selected.split(instance_counts, dim=0)

    for image_keypoints, image_instances in zip(keypoints_per_image, pred_instances):
        image_instances.pred_keypoints = image_keypoints
def keypoint_rcnn_loss(pred_keypoint_logits, instances, normalizer):
    """
    Cross-entropy loss between keypoint heatmap logits and GT keypoints.

    Every spatial location of a heatmap is treated as a class: the logits are
    viewed as (N * K, H * W) and compared with the target index produced by
    `gt_keypoints.to_heatmap`, restricted to the keypoints flagged as valid.

    When no image has instances, or no keypoint is valid, the module-level
    counter `_TOTAL_SKIPPED` is incremented, logged as
    "kpts_num_skipped_batches" and a zero loss connected to the graph is
    returned.

    Args:
        pred_keypoint_logits (Tensor): logits of shape (N, K, H, W).
        instances (list[Instances]): per-image instances with `gt_keypoints`
            and `proposal_boxes`; images without instances are skipped.
        normalizer (float or None): divisor of the summed loss. `None` means
            the number of valid keypoints.

    Returns:
        Tensor: scalar keypoint loss.
    """
    global _TOTAL_SKIPPED

    heatmap_size = pred_keypoint_logits.shape[2]
    target_chunks = []
    valid_chunks = []
    for image_instances in instances:
        if len(image_instances) == 0:
            continue
        image_targets, image_valid = image_instances.gt_keypoints.to_heatmap(
            image_instances.proposal_boxes.tensor, heatmap_size
        )
        target_chunks.append(image_targets.view(-1))
        valid_chunks.append(image_valid.view(-1))

    have_targets = len(target_chunks) > 0
    if have_targets:
        keypoint_targets = cat(target_chunks, dim=0)
        valid_flags = cat(valid_chunks, dim=0).to(dtype=torch.uint8)
        valid = torch.nonzero(valid_flags).squeeze(1)

    if not have_targets or valid.numel() == 0:
        _TOTAL_SKIPPED += 1
        event_storage = get_event_storage()
        event_storage.put_scalar(
            "kpts_num_skipped_batches", _TOTAL_SKIPPED, smoothing_hint=False
        )
        return pred_keypoint_logits.sum() * 0

    N, K, H, W = pred_keypoint_logits.shape
    flat_logits = pred_keypoint_logits.view(N * K, H * W)
    summed_loss = F.cross_entropy(
        flat_logits[valid], keypoint_targets[valid], reduction="sum"
    )

    if normalizer is None:
        normalizer = valid.numel()
    return summed_loss / normalizer
def assign_boxes_to_levels(
    box_lists,
    min_level: int,
    max_level: int,
    canonical_box_size: int,
    canonical_level: int,
):
    """
    Map every box to a feature level based on its size.

    The level is `floor(canonical_level + log2(sqrt(area) / canonical_box_size))`
    clamped to `[min_level, max_level]`; a machine epsilon is added inside
    the logarithm so that zero-area boxes do not produce `-inf`.

    Args:
        box_lists (list): objects exposing `area()`, one per image.
        min_level (int): smallest level that may be assigned.
        max_level (int): largest level that may be assigned.
        canonical_box_size (int): box side length mapped to `canonical_level`.
        canonical_level (int): level assigned to a canonically sized box.

    Returns:
        Tensor: int64 level indices for all boxes (images concatenated),
        expressed relative to `min_level` (0 means `min_level`).
    """
    eps = sys.float_info.epsilon
    areas = cat([boxes.area() for boxes in box_lists])
    side_lengths = torch.sqrt(areas)
    relative_size = side_lengths / canonical_box_size + eps
    raw_levels = torch.floor(canonical_level + torch.log2(relative_size))
    clamped_levels = torch.clamp(raw_levels, min=min_level, max=max_level)
    return clamped_levels.to(torch.int64) - min_level
def losses(
    self,
    anchors,
    pred_objectness_logits: List[torch.Tensor],
    gt_labels: List[torch.Tensor],
    pred_anchor_deltas: List[torch.Tensor],
    gt_boxes,
):
    """
    Compute the objectness and localization losses.

    Anchors labelled 1 are positives and contribute to the localization loss
    (smooth L1, summed). Anchors with a label >= 0 contribute to the
    objectness loss (binary cross-entropy with logits, summed). Both sums are
    divided by `self.batch_size_per_image * num_images`.

    The average number of positive (label 1) and negative (label 0) anchors
    per image is logged as "rpn/num_pos_anchors" / "rpn/num_neg_anchors".

    Args:
        anchors (list): per-level anchor containers; they are concatenated
            with the `cat` classmethod of their type.
        pred_objectness_logits (list[Tensor]): per-level objectness logits,
            concatenated along dim 1.
        gt_labels (list[Tensor]): per-image anchor labels (stacked here).
        pred_anchor_deltas (list[Tensor]): per-level predicted deltas,
            concatenated along dim 1.
        gt_boxes (list[Tensor]): per-image GT boxes matched to the anchors.

    Returns:
        dict[str, Tensor]: "loss_rpn_cls" and "loss_rpn_loc".
    """
    image_count = len(gt_labels)
    labels = torch.stack(gt_labels)
    anchor_tensor = type(anchors[0]).cat(anchors).tensor
    target_deltas = torch.stack(
        [self.box2box_transform.get_deltas(anchor_tensor, image_boxes) for image_boxes in gt_boxes]
    )

    positive = labels == 1
    positive_count = positive.sum().item()
    negative_count = (labels == 0).sum().item()
    event_storage = get_event_storage()
    event_storage.put_scalar("rpn/num_pos_anchors", positive_count / image_count)
    event_storage.put_scalar("rpn/num_neg_anchors", negative_count / image_count)

    predicted_deltas = cat(pred_anchor_deltas, dim=1)
    localization_loss = smooth_l1_loss(
        predicted_deltas[positive],
        target_deltas[positive],
        self.smooth_l1_beta,
        reduction="sum",
    )

    # Anchors with a negative label are left out of the objectness loss.
    considered = labels >= 0
    objectness_logits = cat(pred_objectness_logits, dim=1)
    objectness_loss = F.binary_cross_entropy_with_logits(
        objectness_logits[considered],
        labels[considered].to(torch.float32),
        reduction="sum",
    )

    normalizer = self.batch_size_per_image * image_count
    return {
        "loss_rpn_cls": objectness_loss / normalizer,
        "loss_rpn_loc": localization_loss / normalizer,
    }
def losses(self, anchors, pred_logits, gt_labels, pred_anchor_deltas, gt_boxes):
    """
    Compute the focal classification loss and the box regression loss.

    Anchors with a negative label are excluded from both losses. Anchors
    labelled `self.num_classes` take part in classification only: their
    one-hot target is all zeros because the extra last column is dropped.

    Side effects: the number of positive anchors per image is logged as
    "num_pos_anchors", and `self.loss_normalizer` is updated as a moving
    average (momentum `self.loss_normalizer_momentum`) of
    `max(num_pos_anchors, 1)`.

    Args:
        anchors (list): per-level anchor containers; concatenated with the
            `cat` classmethod of their type.
        pred_logits (list[Tensor]): per-level class logits, concatenated
            along dim 1.
        gt_labels (list[Tensor]): per-image anchor labels (stacked here).
        pred_anchor_deltas (list[Tensor]): per-level box deltas, concatenated
            along dim 1.
        gt_boxes (list[Tensor]): per-image GT boxes matched to the anchors.

    Returns:
        dict[str, Tensor]: "loss_cls" and "loss_box_reg", both divided by the
        updated `self.loss_normalizer`.
    """
    image_count = len(gt_labels)
    labels = torch.stack(gt_labels)
    anchor_tensor = type(anchors[0]).cat(anchors).tensor
    target_deltas = torch.stack(
        [self.box2box_transform.get_deltas(anchor_tensor, image_boxes) for image_boxes in gt_boxes]
    )

    considered = labels >= 0
    foreground = (labels >= 0) & (labels != self.num_classes)
    foreground_count = foreground.sum().item()
    get_event_storage().put_scalar("num_pos_anchors", foreground_count / image_count)

    momentum = self.loss_normalizer_momentum
    self.loss_normalizer = momentum * self.loss_normalizer + (
        1 - momentum) * max(foreground_count, 1)

    # One extra class column is created and dropped right away.
    one_hot_targets = F.one_hot(labels[considered], num_classes=self.num_classes + 1)
    class_targets = one_hot_targets[:, :-1]

    class_logits = cat(pred_logits, dim=1)
    loss_cls = sigmoid_focal_loss_jit(
        class_logits[considered],
        class_targets.to(pred_logits[0].dtype),
        alpha=self.focal_loss_alpha,
        gamma=self.focal_loss_gamma,
        reduction="sum",
    )

    predicted_deltas = cat(pred_anchor_deltas, dim=1)
    loss_box_reg = smooth_l1_loss(
        predicted_deltas[foreground],
        target_deltas[foreground],
        beta=self.smooth_l1_loss_beta,
        reduction="sum",
    )

    return {
        "loss_cls": loss_cls / self.loss_normalizer,
        "loss_box_reg": loss_box_reg / self.loss_normalizer,
    }
def convert_boxes_to_pooler_format(box_lists):
    """
    Prefix every box with the index of the image it belongs to.

    Args:
        box_lists (list): objects with a 2-D `tensor` attribute holding the
            boxes of one image, one object per image.

    Returns:
        Tensor: all boxes concatenated along dim 0, with an extra first
        column containing the position of the box's image in `box_lists`
        (same dtype and device as the boxes).
    """
    indexed_boxes = []
    for image_index, image_boxes in enumerate(box_lists):
        coords = image_boxes.tensor
        index_column = torch.full(
            (len(coords), 1), image_index, dtype=coords.dtype, device=coords.device
        )
        indexed_boxes.append(cat((index_column, coords), dim=1))
    return cat(indexed_boxes, dim=0)
def mask_rcnn_inference(pred_mask_logits, pred_instances):
    """
    Turn mask logits into probabilities and attach them to the instances.

    For class-specific predictions (more than one channel in dim 1) only the
    channel of each instance's predicted class is kept. The probabilities are
    split according to the number of instances per image.

    Args:
        pred_mask_logits (Tensor): mask logits for all instances of all
            images; dim 0 indexes instances, dim 1 classes.
        pred_instances (list[Instances]): per-image instances. `pred_classes`
            is read for class-specific masks; every entry receives a
            `pred_masks` field that keeps a singleton dim 1.

    Returns:
        None; `pred_instances` is modified in place.
    """
    if pred_mask_logits.size(1) == 1:
        selected_logits = pred_mask_logits
    else:
        mask_count = pred_mask_logits.shape[0]
        predicted_classes = cat([image_instances.pred_classes
                                 for image_instances in pred_instances])
        rows = torch.arange(mask_count, device=predicted_classes.device)
        # Re-insert the channel dimension removed by the indexing.
        selected_logits = pred_mask_logits[rows, predicted_classes][:, None]

    mask_probs = selected_logits.sigmoid()
    instance_counts = [len(image_instances) for image_instances in pred_instances]
    probs_per_image = mask_probs.split(instance_counts, dim=0)

    for image_probs, image_instances in zip(probs_per_image, pred_instances):
        image_instances.pred_masks = image_probs
def inference_single_image(self, box_cls, box_delta, anchors, image_size):
    """
    Decode, filter and NMS the predictions of a single image.

    For every level the class logits go through a softmax whose last column
    is dropped. The flattened scores are sorted, at most as many candidates
    as there are rows in the level's deltas are kept, and those not above
    `self.score_threshold` are discarded. The surviving (anchor, class)
    pairs are decoded with `self.box2box_transform.apply_deltas`. All levels
    are merged, class-wise NMS is applied with `self.nms_threshold`, and at
    most `self.max_detections_per_image` detections are returned.

    Args:
        box_cls (list[Tensor]): per-level class logits, one row per anchor.
        box_delta (list[Tensor]): per-level box deltas, one row per anchor.
        anchors (list): per-level anchor containers supporting indexing and
            a `tensor` attribute.
        image_size: passed to the `Instances` constructor.

    Returns:
        Instances: with fields `pred_boxes`, `scores` and `pred_classes`.
    """
    level_boxes = []
    level_scores = []
    level_classes = []

    for logits_i, deltas_i, anchors_i in zip(box_cls, box_delta, anchors):
        scores_flat = F.softmax(logits_i, dim=-1)[:, :-1].flatten()

        candidate_cap = deltas_i.size(0)
        sorted_scores, sorted_idxs = scores_flat.sort(descending=True)
        top_scores = sorted_scores[:candidate_cap]
        top_idxs = sorted_idxs[:candidate_cap]

        confident = top_scores > self.score_threshold
        top_scores = top_scores[confident]
        top_idxs = top_idxs[confident]

        # The flattened index encodes (anchor, class) in row-major order.
        anchor_idxs = top_idxs // self.num_classes
        class_idxs = top_idxs % self.num_classes

        selected_deltas = deltas_i[anchor_idxs]
        selected_anchors = anchors_i[anchor_idxs]
        decoded = self.box2box_transform.apply_deltas(selected_deltas, selected_anchors.tensor)

        level_boxes.append(decoded)
        level_scores.append(top_scores)
        level_classes.append(class_idxs)

    boxes_all = cat(level_boxes)
    scores_all = cat(level_scores)
    class_idxs_all = cat(level_classes)

    keep = batched_nms(boxes_all, scores_all, class_idxs_all, self.nms_threshold)
    keep = keep[:self.max_detections_per_image]

    detections = Instances(image_size)
    detections.pred_boxes = Boxes(boxes_all[keep])
    detections.scores = scores_all[keep]
    detections.pred_classes = class_idxs_all[keep]
    return detections
def losses(self, locations, box_cls, box_regression, centerness, targets):
    """
    Compute the classification, box regression and centerness losses.

    Predictions of all levels are flattened to one row per location.
    Locations whose label lies in `[0, self.num_classes)` are foreground.

    Args:
        locations (list[Tensor]): per-level point locations, forwarded to
            `self.prepare_targets`.
        box_cls (list[Tensor]): per-level class logits, laid out as
            (N, num_classes, H, W).
        box_regression (list[Tensor]): per-level box regression outputs with
            4 channels in dim 1.
        centerness (list[Tensor]): per-level centerness logits.
        targets: ground truth forwarded to `self.prepare_targets`.

    Returns:
        dict[str, Tensor]: "loss_cls", "loss_box_reg" and "loss_centerness".
        Without any foreground location the regression and centerness losses
        are zero-valued sums over empty selections (they stay connected to
        the graph), and the classification loss is normalized by 1 instead of
        by 0.
    """
    num_classes = box_cls[0].size(1)
    labels, reg_targets = self.prepare_targets(locations, targets)

    box_cls_flatten = []
    box_regression_flatten = []
    centerness_flatten = []
    labels_flatten = []
    reg_targets_flatten = []
    for level in range(len(labels)):
        # Move channels last so that every row is one spatial location.
        box_cls_flatten.append(box_cls[level].permute(0, 2, 3, 1).reshape(-1, num_classes))
        box_regression_flatten.append(box_regression[level].permute(0, 2, 3, 1).reshape(-1, 4))
        labels_flatten.append(labels[level].reshape(-1))
        reg_targets_flatten.append(reg_targets[level].reshape(-1, 4))
        centerness_flatten.append(centerness[level].reshape(-1))

    box_cls_flatten = cat(box_cls_flatten, dim=0)
    box_regression_flatten = cat(box_regression_flatten, dim=0)
    centerness_flatten = cat(centerness_flatten, dim=0)
    labels_flatten = cat(labels_flatten, dim=0)
    reg_targets_flatten = cat(reg_targets_flatten, dim=0)

    pos_inds = torch.nonzero(
        (labels_flatten >= 0) & (labels_flatten < self.num_classes)
    ).squeeze(1)

    box_regression_flatten = box_regression_flatten[pos_inds]
    reg_targets_flatten = reg_targets_flatten[pos_inds]
    centerness_flatten = centerness_flatten[pos_inds]

    num_foreground = pos_inds.numel()

    gt_labels = torch.zeros_like(box_cls_flatten)
    gt_labels[pos_inds, labels_flatten[pos_inds]] = 1

    # Guard the normalizer: with no foreground location a division by zero
    # would turn the classification loss into inf/nan.
    cls_loss = sigmoid_focal_loss_jit(
        box_cls_flatten, gt_labels, alpha=0.25, gamma=2, reduction="sum"
    ) / max(num_foreground, 1)

    if num_foreground > 0:
        centerness_targets = compute_centerness_targets(reg_targets_flatten)
        # NOTE(review): the third argument is the predicted centerness while
        # the normalizer is the sum of the centerness targets -- confirm
        # against the signature of `giou_loss`.
        reg_loss = giou_loss(
            box_regression_flatten, reg_targets_flatten, centerness_flatten
        ) / centerness_targets.sum().item()
        centerness_loss = F.binary_cross_entropy_with_logits(
            centerness_flatten, centerness_targets, reduction="sum"
        ) / num_foreground
    else:
        # Both tensors are empty here, so the sums are zero but still part of
        # the autograd graph. Summing the class logits instead would add a
        # spurious, unbounded term to the regression loss.
        reg_loss = box_regression_flatten.sum()
        centerness_loss = centerness_flatten.sum()

    return {
        "loss_cls": cls_loss,
        "loss_box_reg": reg_loss,
        "loss_centerness": centerness_loss,
    }
def _forward_mask_point(self, features, mask_coarse_logits, instances):
    """
    Run the mask point head on top of the coarse mask predictions.

    Training: points are sampled with
    `get_uncertain_point_coords_with_randomness` (without gradients), point
    features are gathered from the feature maps and from the coarse logits,
    and the point loss is computed.

    Inference: the coarse logits are upsampled by a factor of 2 for
    `self.mask_point_subdivision_steps` steps; after each upsampling the most
    uncertain grid points are re-predicted by the point head and written back
    into the upsampled logits.

    Args:
        features (dict): feature maps; the entries named in
            `self.mask_point_in_features` are used.
        mask_coarse_logits (Tensor): coarse mask logits for all instances.
        instances (list[Instances]): per-image instances. `proposal_boxes`
            and `gt_classes` are read in training, `pred_boxes` and
            `pred_classes` in inference.

    Returns:
        With the point head disabled: `{}` in training, `mask_coarse_logits`
        otherwise. In training: a dict with the key "loss_mask_point". In
        inference: the refined mask logits, or `mask_coarse_logits` when
        there are no predicted instances.
    """
    if not self.mask_point_on:
        return {} if self.training else mask_coarse_logits

    point_feature_maps = [features[name] for name in self.mask_point_in_features]
    point_feature_scales = [self._feature_scales[name] for name in self.mask_point_in_features]

    if self.training:
        proposal_boxes = [image_instances.proposal_boxes for image_instances in instances]
        gt_classes = cat([image_instances.gt_classes for image_instances in instances])

        def uncertainty_of(logits):
            return calculate_uncertainty(logits, gt_classes)

        # Point selection is not differentiated through.
        with torch.no_grad():
            point_coords = get_uncertain_point_coords_with_randomness(
                mask_coarse_logits,
                uncertainty_of,
                self.mask_point_train_num_points,
                self.mask_point_oversample_ratio,
                self.mask_point_importance_sample_ratio,
            )

        fine_grained_features, point_coords_wrt_image = point_sample_fine_grained_features(
            point_feature_maps, point_feature_scales, proposal_boxes, point_coords
        )
        coarse_features = point_sample(mask_coarse_logits, point_coords, align_corners=False)
        point_logits = self.mask_point_head(fine_grained_features, coarse_features)
        point_loss = roi_mask_point_loss(point_logits, instances, point_coords_wrt_image)
        return {"loss_mask_point": point_loss}

    pred_boxes = [image_instances.pred_boxes for image_instances in instances]
    pred_classes = cat([image_instances.pred_classes for image_instances in instances])
    if len(pred_classes) == 0:
        return mask_coarse_logits

    mask_logits = mask_coarse_logits.clone()
    final_step = self.mask_point_subdivision_steps - 1
    for step in range(self.mask_point_subdivision_steps):
        mask_logits = interpolate(
            mask_logits, scale_factor=2, mode="bilinear", align_corners=False
        )

        # Skip the refinement when the point budget already covers the grid
        # of the next step; the last step is never skipped.
        grid_h, grid_w = mask_logits.shape[-2:]
        budget_covers_next = self.mask_point_subdivision_num_points >= 4 * grid_h * grid_w
        if budget_covers_next and step < final_step:
            continue

        uncertainty_map = calculate_uncertainty(mask_logits, pred_classes)
        point_indices, point_coords = get_uncertain_point_coords_on_grid(
            uncertainty_map, self.mask_point_subdivision_num_points
        )
        fine_grained_features, _ = point_sample_fine_grained_features(
            point_feature_maps, point_feature_scales, pred_boxes, point_coords
        )
        coarse_features = point_sample(mask_coarse_logits, point_coords, align_corners=False)
        point_logits = self.mask_point_head(fine_grained_features, coarse_features)

        # Write the point predictions back to their grid positions.
        num_rois, num_channels, grid_h, grid_w = mask_logits.shape
        scatter_index = point_indices.unsqueeze(1).expand(-1, num_channels, -1)
        flat_logits = mask_logits.reshape(num_rois, num_channels, grid_h * grid_w)
        mask_logits = flat_logits.scatter_(2, scatter_index, point_logits).view(
            num_rois, num_channels, grid_h, grid_w
        )

    return mask_logits
def fmt_box_list(box_tensor, batch_index):
    """
    Prepend a column holding `batch_index` to a 2-D box tensor.

    Args:
        box_tensor (Tensor): boxes of one image, one row per box.
        batch_index (int): value written into the new first column.

    Returns:
        Tensor: `box_tensor` with one extra leading column (same dtype and
        device as `box_tensor`).
    """
    row_count = len(box_tensor)
    index_column = torch.full(
        (row_count, 1),
        batch_index,
        dtype=box_tensor.dtype,
        device=box_tensor.device,
    )
    return cat((index_column, box_tensor), dim=1)
def find_top_rpn_proposals(
    proposals: List[torch.Tensor],
    pred_objectness_logits: List[torch.Tensor],
    image_sizes: List[Tuple[int, int]],
    nms_thresh: float,
    pre_nms_topk: int,
    post_nms_topk: int,
    min_box_size: int,
    training: bool,
):
    """
    Select the best proposals of every image.

    Per level, the `pre_nms_topk` highest-scoring proposals of each image are
    kept. Per image, the kept proposals of all levels are then cleaned
    (non-finite entries removed, boxes clipped to the image, boxes that fail
    `nonempty(threshold=min_box_size)` dropped), passed through NMS with the
    level as the grouping index, and truncated to `post_nms_topk`.

    Args:
        proposals (list[Tensor]): per-level proposal boxes, indexed as
            (image, proposal, coordinate).
        pred_objectness_logits (list[Tensor]): per-level objectness logits,
            indexed as (image, proposal).
        image_sizes (list[tuple]): one size per image, used for clipping and
            for constructing `Instances`.
        nms_thresh (float): threshold passed to `batched_nms`.
        pre_nms_topk (int): proposals kept per level and image before NMS.
        post_nms_topk (int): proposals kept per image after NMS.
        min_box_size (int): threshold passed to `Boxes.nonempty`.
        training (bool): when true, non-finite predictions raise instead of
            being filtered out.

    Returns:
        list[Instances]: one entry per image with `proposal_boxes` and
        `objectness_logits`.

    Raises:
        FloatingPointError: if `training` is true and a box or score is
            Inf/NaN.
    """
    image_count = len(image_sizes)
    device = proposals[0].device
    image_rows = torch.arange(image_count, device=device)

    # 1. Top-k selection, level by level.
    scores_by_level = []
    boxes_by_level = []
    level_tags = []
    for level, (level_boxes, level_logits) in enumerate(zip(proposals, pred_objectness_logits)):
        keep_count = min(pre_nms_topk, level_logits.shape[1])

        sorted_logits, order = level_logits.sort(descending=True, dim=1)
        level_top_scores = sorted_logits[image_rows, :keep_count]
        level_top_order = order[image_rows, :keep_count]
        level_top_boxes = level_boxes[image_rows[:, None], level_top_order]

        boxes_by_level.append(level_top_boxes)
        scores_by_level.append(level_top_scores)
        level_tags.append(
            torch.full((keep_count,), level, dtype=torch.int64, device=device)
        )

    # 2. Merge the levels.
    all_scores = cat(scores_by_level, dim=1)
    all_boxes = cat(boxes_by_level, dim=1)
    all_levels = cat(level_tags, dim=0)

    # 3. Per-image clean-up and NMS.
    results = []
    for image_index, image_size in enumerate(image_sizes):
        boxes = Boxes(all_boxes[image_index])
        scores = all_scores[image_index]
        levels = all_levels

        finite = torch.isfinite(boxes.tensor).all(dim=1) & torch.isfinite(scores)
        if not finite.all():
            if training:
                raise FloatingPointError(
                    "Predicted boxes or scores contain Inf/NaN. Training has diverged."
                )
            boxes, scores, levels = boxes[finite], scores[finite], levels[finite]
        boxes.clip(image_size)

        large_enough = boxes.nonempty(threshold=min_box_size)
        if large_enough.sum().item() != len(boxes):
            boxes = boxes[large_enough]
            scores = scores[large_enough]
            levels = levels[large_enough]

        selected = batched_nms(boxes.tensor, scores, levels, nms_thresh)[:post_nms_topk]

        image_result = Instances(image_size)
        image_result.proposal_boxes = boxes[selected]
        image_result.objectness_logits = scores[selected]
        results.append(image_result)
    return results
def mask_rcnn_loss(pred_mask_logits, instances, vis_period=0):
    """
    Compute the mask loss and log mask statistics.

    Ground-truth masks are cropped and resized to the prediction resolution
    with `gt_masks.crop_and_resize` and compared with the logits through a
    binary cross-entropy with logits (mean reduction). Accuracy, false
    positive rate and false negative rate (a logit above 0 counts as a
    foreground prediction) are written to the event storage under
    "mask_rcnn/...".

    Args:
        pred_mask_logits (Tensor): mask logits indexed as (mask, class,
            height, width); height and width must be equal. A size of 1 in
            dim 1 is treated as class-agnostic.
        instances (list[Instances]): per-image instances with `gt_masks`,
            `proposal_boxes` and, for class-specific predictions,
            `gt_classes`. Images without instances are skipped.
        vis_period (int): when positive, predicted and GT masks are stored as
            images every `vis_period` iterations.

    Returns:
        Tensor: scalar mask loss; `pred_mask_logits.sum() * 0` when no image
        has instances.
    """
    cls_agnostic_mask = pred_mask_logits.size(1) == 1
    total_num_masks = pred_mask_logits.size(0)
    mask_side_len = pred_mask_logits.size(2)
    assert pred_mask_logits.size(2) == pred_mask_logits.size(3), \
        "Mask prediction must be square!"

    gt_classes = []
    gt_masks = []
    for instances_per_image in instances:
        if len(instances_per_image) == 0:
            continue
        if not cls_agnostic_mask:
            gt_classes.append(instances_per_image.gt_classes.to(dtype=torch.int64))

        gt_masks_per_image = instances_per_image.gt_masks.crop_and_resize(
            instances_per_image.proposal_boxes.tensor, mask_side_len
        ).to(device=pred_mask_logits.device)
        gt_masks.append(gt_masks_per_image)

    if len(gt_masks) == 0:
        return pred_mask_logits.sum() * 0

    gt_masks = cat(gt_masks, dim=0)

    if cls_agnostic_mask:
        pred_mask_logits = pred_mask_logits[:, 0]
    else:
        # Build the row index on the logits' device to avoid indexing with
        # tensors that live on different devices.
        indices = torch.arange(total_num_masks, device=pred_mask_logits.device)
        gt_classes = cat(gt_classes, dim=0)
        pred_mask_logits = pred_mask_logits[indices, gt_classes]

    if gt_masks.dtype == torch.bool:
        gt_masks_bool = gt_masks
    else:
        gt_masks_bool = gt_masks > 0.5
    gt_masks = gt_masks.to(dtype=torch.float32)

    # Logging only: the `max(..., 1.0)` guards avoid divisions by zero.
    mask_incorrect = (pred_mask_logits > 0.0) != gt_masks_bool
    mask_accuracy = 1 - (mask_incorrect.sum().item() / max(mask_incorrect.numel(), 1.0))
    num_positive = gt_masks_bool.sum().item()
    false_positive = (mask_incorrect & ~gt_masks_bool).sum().item() / max(
        gt_masks_bool.numel() - num_positive, 1.0
    )
    false_negative = (mask_incorrect & gt_masks_bool).sum().item() / max(num_positive, 1.0)

    storage = get_event_storage()
    storage.put_scalar("mask_rcnn/accuracy", mask_accuracy)
    storage.put_scalar("mask_rcnn/false_positive", false_positive)
    storage.put_scalar("mask_rcnn/false_negative", false_negative)

    if vis_period > 0 and storage.iter % vis_period == 0:
        pred_masks = pred_mask_logits.sigmoid()
        # Prediction and ground truth side by side.
        vis_masks = torch.cat([pred_masks, gt_masks], dim=2)
        name = "Left: mask prediction; Right: mask GT"
        for idx, vis_mask in enumerate(vis_masks):
            vis_mask = torch.stack([vis_mask] * 3, dim=0)
            storage.put_image(name + f" ({idx})", vis_mask)

    mask_loss = F.binary_cross_entropy_with_logits(pred_mask_logits, gt_masks, reduction="mean")
    return mask_loss