Example #1
def point_sample_fine_grained_features(features_list, feature_scales, boxes, point_coords):
    cat_boxes = Boxes.cat(boxes)
    num_boxes = [len(b) for b in boxes]

    point_coords_wrt_image = get_point_coords_wrt_image(cat_boxes.tensor, point_coords)
    split_point_coords_wrt_image = torch.split(point_coords_wrt_image, num_boxes)

    point_features = []
    for idx_img, point_coords_wrt_image_per_image in enumerate(split_point_coords_wrt_image):
        point_features_per_image = []
        for idx_feature, feature_map in enumerate(features_list):
            h, w = feature_map.shape[-2:]
            scale = torch.tensor([w, h], device=feature_map.device) / feature_scales[idx_feature]
            point_coords_scaled = point_coords_wrt_image_per_image / scale
            point_features_per_image.append(
                point_sample(
                    feature_map[idx_img].unsqueeze(0),
                    point_coords_scaled.unsqueeze(0),
                    align_corners=False,
                )
                .squeeze(0)
                .transpose(1, 0)
            )
        point_features.append(cat(point_features_per_image, dim=1))

    return cat(point_features, dim=0), point_coords_wrt_image
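
A minimal shape-check sketch of a call (a hypothetical smoke test, assuming a detectron2-style environment where Boxes and point_sample are importable; all sizes are illustrative):

import torch
from detectron2.structures import Boxes

# Two images, one box each, two feature levels (strides 4 and 8 of a 224x224 image).
features_list = [torch.randn(2, 256, 56, 56), torch.randn(2, 256, 28, 28)]
feature_scales = [1.0 / 4, 1.0 / 8]
boxes = [Boxes(torch.tensor([[16., 16., 96., 96.]])),
         Boxes(torch.tensor([[32., 32., 128., 128.]]))]
point_coords = torch.rand(2, 14, 2)  # 14 points per box, normalized to each box
point_features, coords_wrt_image = point_sample_fine_grained_features(
    features_list, feature_scales, boxes, point_coords)
# point_features: (2, 512, 14) after concatenating both levels' channels;
# coords_wrt_image: (2, 14, 2) in absolute image coordinates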
Example #2
def permute_all_cls_and_box_to_N_HWA_K_and_concat(box_cls,
                                                  box_delta,
                                                  num_classes=80):
    box_cls_flattened = [permute_to_N_HWA_K(x, num_classes) for x in box_cls]
    box_delta_flattened = [permute_to_N_HWA_K(x, 4) for x in box_delta]
    box_cls = cat(box_cls_flattened, dim=1).view(-1, num_classes)
    box_delta = cat(box_delta_flattened, dim=1).view(-1, 4)
    return box_cls, box_delta
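
A hedged shape sketch (assumes permute_to_N_HWA_K and cat are in scope, as in detectron2; the level sizes are illustrative):

import torch

# Batch of 2, two feature levels, A=9 anchors per location, K=80 classes.
box_cls = [torch.randn(2, 9 * 80, 32, 32), torch.randn(2, 9 * 80, 16, 16)]
box_delta = [torch.randn(2, 9 * 4, 32, 32), torch.randn(2, 9 * 4, 16, 16)]
cls_flat, delta_flat = permute_all_cls_and_box_to_N_HWA_K_and_concat(box_cls, box_delta)
# cls_flat: (2 * (32*32 + 16*16) * 9, 80); delta_flat has the same rows and 4 columns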
Example #3
def roi_mask_point_loss(mask_logits, instances, points_coord):
    assert len(instances) == 0 or isinstance(instances[0].gt_masks, BitMasks), \
        "Point head works with GT in 'bitmask' format only. Set INPUT.MASK_FORMAT to 'bitmask'."
    with torch.no_grad():
        cls_agnostic_mask = mask_logits.size(1) == 1
        total_num_masks = mask_logits.size(0)

        gt_classes = []
        gt_mask_logits = []
        idx = 0
        for instances_per_image in instances:
            if len(instances_per_image) == 0:
                continue
            if not cls_agnostic_mask:
                gt_classes_per_image = instances_per_image.gt_classes.to(
                    dtype=torch.int64)
                gt_classes.append(gt_classes_per_image)

            gt_bit_masks = instances_per_image.gt_masks.tensor
            h, w = instances_per_image.gt_masks.image_size
            scale = torch.tensor([w, h],
                                 dtype=torch.float,
                                 device=gt_bit_masks.device)
            points_coord_grid_sample_format = (
                points_coord[idx:idx + len(instances_per_image)] / scale)
            idx += len(instances_per_image)
            gt_mask_logits.append(
                point_sample(
                    gt_bit_masks.to(torch.float32).unsqueeze(1),
                    points_coord_grid_sample_format,
                    align_corners=False,
                ).squeeze(1))
    # If no image had instances, cat() would fail on an empty list; return a
    # graph-connected zero loss instead.
    if len(gt_mask_logits) == 0:
        return mask_logits.sum() * 0
    gt_mask_logits = cat(gt_mask_logits)

    if cls_agnostic_mask:
        mask_logits = mask_logits[:, 0]
    else:
        indices = torch.arange(total_num_masks)
        gt_classes = cat(gt_classes, dim=0)
        mask_logits = mask_logits[indices, gt_classes]

    mask_accurate = (mask_logits > 0.0) == gt_mask_logits.to(dtype=torch.uint8)
    mask_accuracy = mask_accurate.nonzero().size(0) / mask_accurate.numel()
    get_event_storage().put_scalar("point_rend/accuracy", mask_accuracy)

    point_loss = F.binary_cross_entropy_with_logits(
        mask_logits, gt_mask_logits.to(dtype=torch.float32), reduction="mean")
    return point_loss
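
A minimal invocation sketch (assumes detectron2 is installed; the random logits and all-zero masks are placeholders, and an EventStorage must be active for the logged accuracy):

import torch
from detectron2.structures import BitMasks, Instances
from detectron2.utils.events import EventStorage

inst = Instances((64, 64))
inst.gt_masks = BitMasks(torch.zeros(3, 64, 64, dtype=torch.bool))
inst.gt_classes = torch.tensor([0, 1, 2])
mask_logits = torch.randn(3, 80, 14)       # (instances, classes, points)
points_coord = torch.rand(3, 14, 2) * 64   # point coordinates in image space
with EventStorage():
    loss = roi_mask_point_loss(mask_logits, [inst], points_coord)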
Example #4
def get_uncertain_point_coords_with_randomness(
    coarse_logits,
    uncertainty_func,
    num_points,
    oversample_ratio,
    importance_sample_ratio
):
    assert oversample_ratio >= 1
    assert 0 <= importance_sample_ratio <= 1

    num_boxes = coarse_logits.shape[0]
    num_sampled = int(num_points * oversample_ratio)
    point_coords = torch.rand(num_boxes, num_sampled, 2, device=coarse_logits.device)
    point_logits = point_sample(coarse_logits, point_coords, align_corners=False)
    point_uncertainties = uncertainty_func(point_logits)
    num_uncertain_points = int(importance_sample_ratio * num_points)
    num_random_points = num_points - num_uncertain_points
    idx = torch.topk(point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1]
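    # Shift each row's top-k indices so they address the flattened
    # (num_boxes * num_sampled, 2) view of point_coords used below.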
    shift = num_sampled * torch.arange(num_boxes, dtype=torch.long, device=coarse_logits.device)
    idx += shift[:, None]
    point_coords = (
        point_coords.view(-1, 2)[idx.view(-1), :].view(num_boxes, num_uncertain_points, 2)
    )
    if num_random_points > 0:
        point_coords = cat(
            [
                point_coords,
                torch.rand(num_boxes, num_random_points, 2, device=coarse_logits.device),
            ],
            dim=1,
        )
    return point_coords
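
A hedged usage sketch (the lambda stands in for PointRend's calculate_uncertainty, and point_sample is assumed to be in scope; logits close to zero count as most uncertain):

import torch

coarse_logits = torch.randn(4, 1, 7, 7)            # (boxes, channels, H, W)
uncertainty = lambda logits: -torch.abs(logits)
pts = get_uncertain_point_coords_with_randomness(
    coarse_logits, uncertainty, num_points=14,
    oversample_ratio=3, importance_sample_ratio=0.75)
# pts: (4, 14, 2) normalized coordinates, 10 importance-sampled plus 4 random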
Example #5
    def __init__(
        self,
        box2box_transform,
        pred_class_logits,
        pred_proposal_deltas,
        proposals,
        smooth_l1_beta=0,
    ):
        self.box2box_transform = box2box_transform
        self.num_preds_per_image = [len(p) for p in proposals]
        self.pred_class_logits = pred_class_logits
        self.pred_proposal_deltas = pred_proposal_deltas
        self.smooth_l1_beta = smooth_l1_beta
        self.image_shapes = [x.image_size for x in proposals]

        if len(proposals):
            box_type = type(proposals[0].proposal_boxes)
            self.proposals = box_type.cat(
                [p.proposal_boxes for p in proposals])
            assert not self.proposals.tensor.requires_grad, \
                "Proposals should not require gradients!"

            if proposals[0].has("gt_boxes"):
                self.gt_boxes = box_type.cat([p.gt_boxes for p in proposals])
                assert proposals[0].has("gt_classes")

                self.gt_classes = cat([p.gt_classes for p in proposals], dim=0)
        else:
            self.proposals = Boxes(
                torch.zeros(0, 4, device=self.pred_proposal_deltas.device))
        self._no_instances = len(proposals) == 0
Example #6
    def forward(self, fine_grained_features, coarse_features):
        x = torch.cat((fine_grained_features, coarse_features), dim=1)
        for layer in self.fc_layers:
            x = F.relu(layer(x))
            if self.coarse_pred_each_layer:
                x = cat((x, coarse_features), dim=1)
        return self.predictor(x)
Example #7
    def prepare_targets(self, points, targets):
        object_sizes_of_interest = [
            [-1, 64],
            [64, 128],
            [128, 256],
            [256, 512],
            [512, INF],
        ]
        expanded_object_sizes_of_interest = []
        for level, points_per_level in enumerate(points):
            object_sizes_of_interest_per_level = points_per_level.new_tensor(
                object_sizes_of_interest[level]
            )
            expanded_object_sizes_of_interest.append(
                object_sizes_of_interest_per_level[None].expand(len(points_per_level), -1)
            )

        expanded_object_sizes_of_interest = cat(expanded_object_sizes_of_interest, 0)
        num_points_per_level = [len(points_per_level) for points_per_level in points]
        self.num_points_per_level = num_points_per_level
        points_all_level = cat(points, dim=0)
        labels, reg_targets = self.compute_targets_for_locations(
            points_all_level,
            targets,
            expanded_object_sizes_of_interest
        )

        for i in range(len(labels)):
            labels[i] = torch.split(labels[i], num_points_per_level, dim=0)
            reg_targets[i] = torch.split(reg_targets[i], num_points_per_level, dim=0)

        labels_level_first = []
        reg_targets_level_first = []
        for level in range(len(points)):
            labels_level_first.append(cat([labels_per_im[level] for labels_per_im in labels], 0))

            reg_targets_per_level = cat(
                [reg_targets_per_im[level] for reg_targets_per_im in reg_targets],
                dim=0
            )

            if self.norm_reg_targets:
                reg_targets_per_level = reg_targets_per_level / self.neck_strides[level]
            reg_targets_level_first.append(reg_targets_per_level)

        return labels_level_first, reg_targets_level_first
Example #8
def keypoint_rcnn_inference(pred_keypoint_logits, pred_instances):
    bboxes_flat = cat([b.pred_boxes.tensor for b in pred_instances], dim=0)

    keypoint_results = heatmaps_to_keypoints(pred_keypoint_logits.detach(), bboxes_flat.detach())
    num_instances_per_image = [len(i) for i in pred_instances]
    keypoint_results = keypoint_results[:, :, [0, 1, 3]].split(num_instances_per_image, dim=0)

    for keypoint_results_per_image, instances_per_image in zip(keypoint_results, pred_instances):
        instances_per_image.pred_keypoints = keypoint_results_per_image
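
A quick sketch of a call (assumes detectron2's heatmaps_to_keypoints and cat are in scope; the single box and random heatmaps are placeholders):

import torch
from detectron2.structures import Boxes, Instances

inst = Instances((128, 128))
inst.pred_boxes = Boxes(torch.tensor([[8., 8., 72., 72.]]))
logits = torch.randn(1, 17, 56, 56)   # 17 COCO keypoints on a 56x56 heatmap
keypoint_rcnn_inference(logits, [inst])
# inst.pred_keypoints: (1, 17, 3), one (x, y, score) row per keypoint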
Example #9
def keypoint_rcnn_loss(pred_keypoint_logits, instances, normalizer):
    heatmaps = []
    valid = []

    keypoint_side_len = pred_keypoint_logits.shape[2]
    for instances_per_image in instances:
        if len(instances_per_image) == 0:
            continue
        keypoints = instances_per_image.gt_keypoints
        heatmaps_per_image, valid_per_image = keypoints.to_heatmap(
            instances_per_image.proposal_boxes.tensor,
            keypoint_side_len
        )
        heatmaps.append(heatmaps_per_image.view(-1))
        valid.append(valid_per_image.view(-1))

    if len(heatmaps):
        keypoint_targets = cat(heatmaps, dim=0)
        valid = cat(valid, dim=0).to(dtype=torch.uint8)
        valid = torch.nonzero(valid).squeeze(1)

    if len(heatmaps) == 0 or valid.numel() == 0:
        global _TOTAL_SKIPPED
        _TOTAL_SKIPPED += 1
        storage = get_event_storage()
        storage.put_scalar("kpts_num_skipped_batches", _TOTAL_SKIPPED, smoothing_hint=False)
        return pred_keypoint_logits.sum() * 0

    N, K, H, W = pred_keypoint_logits.shape
    pred_keypoint_logits = pred_keypoint_logits.view(N * K, H * W)

    keypoint_loss = F.cross_entropy(
        pred_keypoint_logits[valid],
        keypoint_targets[valid],
        reduction="sum"
    )

    if normalizer is None:
        normalizer = valid.numel()
    keypoint_loss /= normalizer

    return keypoint_loss
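
A minimal training-side sketch (assumes detectron2 is installed; visibility flag 2 marks the keypoints as labeled, and the logits are placeholders):

import torch
from detectron2.structures import Boxes, Instances, Keypoints
from detectron2.utils.events import EventStorage

inst = Instances((128, 128))
inst.proposal_boxes = Boxes(torch.tensor([[8., 8., 72., 72.]]))
inst.gt_keypoints = Keypoints(torch.tensor([[[20., 20., 2.]] * 17]))  # (1, 17, 3)
logits = torch.randn(1, 17, 56, 56)
with EventStorage():
    loss = keypoint_rcnn_loss(logits, [inst], normalizer=None)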
Example #10
def assign_boxes_to_levels(box_lists, min_level: int, max_level: int,
                           canonical_box_size: int, canonical_level: int):
    eps = sys.float_info.epsilon
    box_sizes = torch.sqrt(cat([boxes.area() for boxes in box_lists]))
    level_assignments = torch.floor(canonical_level +
                                    torch.log2(box_sizes / canonical_box_size +
                                               eps))
    level_assignments = torch.clamp(level_assignments,
                                    min=min_level,
                                    max=max_level)
    return level_assignments.to(torch.int64) - min_level
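
A worked example of the level formula, assuming cat and sys are imported as in the source module (with canonical size 224 at level 4, a 32x32 box falls below the range and clamps to the lowest level):

import torch
from detectron2.structures import Boxes

small = Boxes(torch.tensor([[0., 0., 32., 32.]]))
large = Boxes(torch.tensor([[0., 0., 512., 512.]]))
levels = assign_boxes_to_levels([small, large], min_level=2, max_level=5,
                                canonical_box_size=224, canonical_level=4)
# tensor([0, 3]): indices are relative to min_level, i.e. p2 and p5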
Example #11
    def losses(
        self,
        anchors,
        pred_objectness_logits: List[torch.Tensor],
        gt_labels: List[torch.Tensor],
        pred_anchor_deltas: List[torch.Tensor],
        gt_boxes,
    ):
        num_images = len(gt_labels)
        gt_labels = torch.stack(gt_labels)
        anchors = type(anchors[0]).cat(anchors).tensor
        gt_anchor_deltas = [
            self.box2box_transform.get_deltas(anchors, k) for k in gt_boxes
        ]
        gt_anchor_deltas = torch.stack(gt_anchor_deltas)

        pos_mask = gt_labels == 1
        num_pos_anchors = pos_mask.sum().item()
        num_neg_anchors = (gt_labels == 0).sum().item()
        storage = get_event_storage()
        storage.put_scalar("rpn/num_pos_anchors", num_pos_anchors / num_images)
        storage.put_scalar("rpn/num_neg_anchors", num_neg_anchors / num_images)

        localization_loss = smooth_l1_loss(
            cat(pred_anchor_deltas, dim=1)[pos_mask],
            gt_anchor_deltas[pos_mask],
            self.smooth_l1_beta,
            reduction="sum",
        )
        valid_mask = gt_labels >= 0
        objectness_loss = F.binary_cross_entropy_with_logits(
            cat(pred_objectness_logits, dim=1)[valid_mask],
            gt_labels[valid_mask].to(torch.float32),
            reduction="sum",
        )
        normalizer = self.batch_size_per_image * num_images
        return {
            "loss_rpn_cls": objectness_loss / normalizer,
            "loss_rpn_loc": localization_loss / normalizer,
        }
Example #12
    def losses(self, anchors, pred_logits, gt_labels, pred_anchor_deltas,
               gt_boxes):
        num_images = len(gt_labels)
        gt_labels = torch.stack(gt_labels)
        anchors = type(anchors[0]).cat(anchors).tensor
        gt_anchor_deltas = [
            self.box2box_transform.get_deltas(anchors, k) for k in gt_boxes
        ]
        gt_anchor_deltas = torch.stack(gt_anchor_deltas)

        valid_mask = gt_labels >= 0
        pos_mask = (gt_labels >= 0) & (gt_labels != self.num_classes)
        num_pos_anchors = pos_mask.sum().item()
        get_event_storage().put_scalar("num_pos_anchors",
                                       num_pos_anchors / num_images)
        self.loss_normalizer = self.loss_normalizer_momentum * self.loss_normalizer + (
            1 - self.loss_normalizer_momentum) * max(num_pos_anchors, 1)

        gt_labels_target = F.one_hot(gt_labels[valid_mask],
                                     num_classes=self.num_classes + 1)[:, :-1]
        loss_cls = sigmoid_focal_loss_jit(
            cat(pred_logits, dim=1)[valid_mask],
            gt_labels_target.to(pred_logits[0].dtype),
            alpha=self.focal_loss_alpha,
            gamma=self.focal_loss_gamma,
            reduction="sum",
        )

        loss_box_reg = smooth_l1_loss(
            cat(pred_anchor_deltas, dim=1)[pos_mask],
            gt_anchor_deltas[pos_mask],
            beta=self.smooth_l1_loss_beta,
            reduction="sum",
        )

        return {
            "loss_cls": loss_cls / self.loss_normalizer,
            "loss_box_reg": loss_box_reg / self.loss_normalizer,
        }
Example #13
def convert_boxes_to_pooler_format(box_lists):
    def fmt_box_list(box_tensor, batch_index):
        repeated_index = torch.full((len(box_tensor), 1),
                                    batch_index,
                                    dtype=box_tensor.dtype,
                                    device=box_tensor.device)
        return cat((repeated_index, box_tensor), dim=1)

    pooler_fmt_boxes = cat([
        fmt_box_list(box_list.tensor, i)
        for i, box_list in enumerate(box_lists)
    ],
                           dim=0)

    return pooler_fmt_boxes
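
A small sketch of the output layout (assumes detectron2's Boxes and cat are in scope):

import torch
from detectron2.structures import Boxes

b0 = Boxes(torch.tensor([[0., 0., 10., 10.]]))
b1 = Boxes(torch.tensor([[5., 5., 20., 20.], [1., 1., 4., 4.]]))
pooler_boxes = convert_boxes_to_pooler_format([b0, b1])
# (3, 5) tensor whose rows are (batch_index, x1, y1, x2, y2)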
Example #14
def mask_rcnn_inference(pred_mask_logits, pred_instances):
    cls_agnostic_mask = pred_mask_logits.size(1) == 1

    if cls_agnostic_mask:
        mask_probs_pred = pred_mask_logits.sigmoid()
    else:
        num_masks = pred_mask_logits.shape[0]
        class_pred = cat([i.pred_classes for i in pred_instances])
        indices = torch.arange(num_masks, device=class_pred.device)
        mask_probs_pred = pred_mask_logits[indices,
                                           class_pred][:, None].sigmoid()

    num_boxes_per_image = [len(i) for i in pred_instances]
    mask_probs_pred = mask_probs_pred.split(num_boxes_per_image, dim=0)

    for prob, instances in zip(mask_probs_pred, pred_instances):
        instances.pred_masks = prob
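
A hedged inference sketch (random logits stand in for real mask head outputs; cat is assumed imported as in detectron2):

import torch
from detectron2.structures import Instances

inst = Instances((64, 64))
inst.pred_classes = torch.tensor([3, 7])
logits = torch.randn(2, 80, 28, 28)   # per-class mask logits for 2 detections
mask_rcnn_inference(logits, [inst])
# inst.pred_masks: (2, 1, 28, 28), sigmoid probabilities of each predicted class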
Example #15
    def inference_single_image(self, box_cls, box_delta, anchors, image_size):
        boxes_all = []
        scores_all = []
        class_idxs_all = []

        for box_cls_i, box_reg_i, anchors_i in zip(box_cls, box_delta,
                                                   anchors):
            box_cls_i = F.softmax(box_cls_i, dim=-1)[:, :-1].flatten()

            num_topk = box_reg_i.size(0)
            predicted_prob, topk_idxs = box_cls_i.sort(descending=True)
            predicted_prob = predicted_prob[:num_topk]
            topk_idxs = topk_idxs[:num_topk]

            keep_idxs = predicted_prob > self.score_threshold
            predicted_prob = predicted_prob[keep_idxs]
            topk_idxs = topk_idxs[keep_idxs]

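            # Decode flattened (anchor, class) indices: each anchor occupies
            # num_classes consecutive slots in the flattened score vector.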
            anchor_idxs = topk_idxs // self.num_classes
            classes_idxs = topk_idxs % self.num_classes

            box_reg_i = box_reg_i[anchor_idxs]
            anchors_i = anchors_i[anchor_idxs]
            predicted_boxes = self.box2box_transform.apply_deltas(
                box_reg_i, anchors_i.tensor)

            boxes_all.append(predicted_boxes)
            scores_all.append(predicted_prob)
            class_idxs_all.append(classes_idxs)

        boxes_all, scores_all, class_idxs_all = [
            cat(x) for x in [boxes_all, scores_all, class_idxs_all]
        ]
        keep = batched_nms(boxes_all, scores_all, class_idxs_all,
                           self.nms_threshold)
        keep = keep[:self.max_detections_per_image]

        result = Instances(image_size)
        result.pred_boxes = Boxes(boxes_all[keep])
        result.scores = scores_all[keep]
        result.pred_classes = class_idxs_all[keep]
        return result
Example #16
    def losses(self, locations, box_cls, box_regression, centerness, targets):
        N = box_cls[0].size(0)
        num_classes = box_cls[0].size(1)

        labels, reg_targets = self.prepare_targets(locations, targets)

        box_cls_flatten = []
        box_regression_flatten = []
        centerness_flatten = []
        labels_flatten = []
        reg_targets_flatten = []

        for l in range(len(labels)):
            box_cls_flatten.append(box_cls[l].permute(0, 2, 3, 1).reshape(-1, num_classes))
            box_regression_flatten.append(box_regression[l].permute(0, 2, 3, 1).reshape(-1, 4))
            labels_flatten.append(labels[l].reshape(-1))
            reg_targets_flatten.append(reg_targets[l].reshape(-1, 4))
            centerness_flatten.append(centerness[l].reshape(-1))

        box_cls_flatten = cat(box_cls_flatten, dim=0)
        box_regression_flatten = cat(box_regression_flatten, dim=0)
        centerness_flatten = cat(centerness_flatten, dim=0)
        labels_flatten = cat(labels_flatten, dim=0)
        reg_targets_flatten = cat(reg_targets_flatten, dim=0)

        pos_inds = torch.nonzero(
            (labels_flatten >= 0) & (labels_flatten < self.num_classes)
        ).squeeze(1)

        box_regression_flatten = box_regression_flatten[pos_inds]
        reg_targets_flatten = reg_targets_flatten[pos_inds]
        centerness_flatten = centerness_flatten[pos_inds]

        num_foreground = pos_inds.numel()
        gt_labels = torch.zeros_like(box_cls_flatten)
        gt_labels[pos_inds, labels_flatten[pos_inds]] = 1

        cls_loss = sigmoid_focal_loss_jit(
            box_cls_flatten,
            gt_labels,
            alpha=0.25,
            gamma=2,
            reduction="sum"
        ) / max(num_foreground, 1)  # guard against batches with no foreground

        if num_foreground > 0:
            centerness_targets = compute_centerness_targets(reg_targets_flatten)
            reg_loss = giou_loss(
                box_regression_flatten,
                reg_targets_flatten,
                centerness_flatten
            ) / centerness_targets.sum().item()
            centerness_loss = F.binary_cross_entropy_with_logits(
                centerness_flatten,
                centerness_targets,
                reduction="sum"
            ) / num_foreground
        else:
            reg_loss = box_cls_flatten.sum()
            centerness_loss = centerness_flatten.sum()
        return {
            "loss_cls": cls_loss,
            "loss_box_reg": reg_loss,
            "loss_centerness": centerness_loss
        }
Example #17
    def _forward_mask_point(self, features, mask_coarse_logits, instances):
        if not self.mask_point_on:
            return {} if self.training else mask_coarse_logits

        mask_features_list = [features[k] for k in self.mask_point_in_features]
        features_scales = [
            self._feature_scales[k] for k in self.mask_point_in_features
        ]

        if self.training:
            proposal_boxes = [x.proposal_boxes for x in instances]
            gt_classes = cat([x.gt_classes for x in instances])
            with torch.no_grad():
                point_coords = get_uncertain_point_coords_with_randomness(
                    mask_coarse_logits,
                    lambda logits: calculate_uncertainty(logits, gt_classes),
                    self.mask_point_train_num_points,
                    self.mask_point_oversample_ratio,
                    self.mask_point_importance_sample_ratio,
                )

            fine_grained_features, point_coords_wrt_image = point_sample_fine_grained_features(
                mask_features_list, features_scales, proposal_boxes,
                point_coords)
            coarse_features = point_sample(mask_coarse_logits,
                                           point_coords,
                                           align_corners=False)
            point_logits = self.mask_point_head(fine_grained_features,
                                                coarse_features)
            return {
                "loss_mask_point":
                roi_mask_point_loss(point_logits, instances,
                                    point_coords_wrt_image)
            }
        else:
            pred_boxes = [x.pred_boxes for x in instances]
            pred_classes = cat([x.pred_classes for x in instances])
            if len(pred_classes) == 0:
                return mask_coarse_logits

            mask_logits = mask_coarse_logits.clone()
            for subdivision_step in range(self.mask_point_subdivision_steps):
                mask_logits = interpolate(mask_logits,
                                          scale_factor=2,
                                          mode="bilinear",
                                          align_corners=False)
                H, W = mask_logits.shape[-2:]
                if (self.mask_point_subdivision_num_points >= 4 * H * W
                        and subdivision_step < self.mask_point_subdivision_steps - 1):
                    continue
                uncertainty_map = calculate_uncertainty(
                    mask_logits, pred_classes)
                point_indices, point_coords = get_uncertain_point_coords_on_grid(
                    uncertainty_map, self.mask_point_subdivision_num_points)
                fine_grained_features, _ = point_sample_fine_grained_features(
                    mask_features_list, features_scales, pred_boxes,
                    point_coords)
                coarse_features = point_sample(mask_coarse_logits,
                                               point_coords,
                                               align_corners=False)
                point_logits = self.mask_point_head(fine_grained_features,
                                                    coarse_features)

                R, C, H, W = mask_logits.shape
                point_indices = point_indices.unsqueeze(1).expand(-1, C, -1)
                mask_logits = (mask_logits.reshape(R, C, H * W).scatter_(
                    2, point_indices, point_logits).view(R, C, H, W))
            return mask_logits
Example #18
def fmt_box_list(box_tensor, batch_index):
    repeated_index = torch.full((len(box_tensor), 1),
                                batch_index,
                                dtype=box_tensor.dtype,
                                device=box_tensor.device)
    return cat((repeated_index, box_tensor), dim=1)
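
A one-line check (torch.full casts batch_index to the boxes' dtype):

import torch

row = fmt_box_list(torch.tensor([[0., 0., 10., 10.]]), batch_index=2)
# tensor([[2., 0., 0., 10., 10.]]): the batch index is prepended to each box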
Example #19
def find_top_rpn_proposals(
    proposals: List[torch.Tensor],
    pred_objectness_logits: List[torch.Tensor],
    image_sizes: List[Tuple[int, int]],
    nms_thresh: float,
    pre_nms_topk: int,
    post_nms_topk: int,
    min_box_size: int,
    training: bool,
):
    num_images = len(image_sizes)
    device = proposals[0].device

    topk_scores = []
    topk_proposals = []
    level_ids = []
    batch_idx = torch.arange(num_images, device=device)
    for level_id, proposals_i, logits_i in zip(
        itertools.count(),
        proposals,
        pred_objectness_logits
    ):
        Hi_Wi_A = logits_i.shape[1]
        num_proposals_i = min(pre_nms_topk, Hi_Wi_A)

        logits_i, idx = logits_i.sort(descending=True, dim=1)
        topk_scores_i = logits_i[batch_idx, :num_proposals_i]
        topk_idx = idx[batch_idx, :num_proposals_i]

        topk_proposals_i = proposals_i[batch_idx[:, None], topk_idx]

        topk_proposals.append(topk_proposals_i)
        topk_scores.append(topk_scores_i)
        level_ids.append(
            torch.full((num_proposals_i,), level_id, dtype=torch.int64, device=device)
        )

    topk_scores = cat(topk_scores, dim=1)
    topk_proposals = cat(topk_proposals, dim=1)
    level_ids = cat(level_ids, dim=0)

    results = []
    for n, image_size in enumerate(image_sizes):
        boxes = Boxes(topk_proposals[n])
        scores_per_img = topk_scores[n]
        lvl = level_ids

        valid_mask = torch.isfinite(boxes.tensor).all(dim=1) & torch.isfinite(scores_per_img)
        if not valid_mask.all():
            if training:
                raise FloatingPointError(
                    "Predicted boxes or scores contain Inf/NaN. Training has diverged."
                )
            boxes = boxes[valid_mask]
            scores_per_img = scores_per_img[valid_mask]
            lvl = lvl[valid_mask]
        boxes.clip(image_size)

        keep = boxes.nonempty(threshold=min_box_size)
        if keep.sum().item() != len(boxes):
            boxes, scores_per_img, lvl = boxes[keep], scores_per_img[keep], lvl[keep]

        keep = batched_nms(boxes.tensor, scores_per_img, lvl, nms_thresh)
        keep = keep[:post_nms_topk]

        res = Instances(image_size)
        res.proposal_boxes = boxes[keep]
        res.objectness_logits = scores_per_img[keep]
        results.append(res)
    return results
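
A tiny single-image sketch (assumes the helpers the function uses, namely cat, Boxes, Instances, batched_nms, and itertools, are imported as in detectron2; per level the shapes are (images, Hi*Wi*A, 4) for proposals and (images, Hi*Wi*A) for logits):

import torch

proposals = [torch.tensor([[[0., 0., 20., 20.], [5., 5., 30., 30.]]])]  # one level
logits = [torch.tensor([[1.5, 0.3]])]
results = find_top_rpn_proposals(
    proposals, logits, image_sizes=[(64, 64)], nms_thresh=0.7,
    pre_nms_topk=2, post_nms_topk=1, min_box_size=0, training=False)
# results[0].proposal_boxes / .objectness_logits: the surviving top proposal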
Example #20
def mask_rcnn_loss(pred_mask_logits, instances, vis_period=0):
    cls_agnostic_mask = pred_mask_logits.size(1) == 1
    total_num_masks = pred_mask_logits.size(0)
    mask_side_len = pred_mask_logits.size(2)
    assert pred_mask_logits.size(2) == pred_mask_logits.size(
        3), "Mask prediction must be square!"

    gt_classes = []
    gt_masks = []
    for instances_per_image in instances:
        if len(instances_per_image) == 0:
            continue
        if not cls_agnostic_mask:
            gt_classes_per_image = instances_per_image.gt_classes.to(
                dtype=torch.int64)
            gt_classes.append(gt_classes_per_image)

        gt_masks_per_image = instances_per_image.gt_masks.crop_and_resize(
            instances_per_image.proposal_boxes.tensor,
            mask_side_len).to(device=pred_mask_logits.device)
        gt_masks.append(gt_masks_per_image)

    if len(gt_masks) == 0:
        return pred_mask_logits.sum() * 0

    gt_masks = cat(gt_masks, dim=0)

    if cls_agnostic_mask:
        pred_mask_logits = pred_mask_logits[:, 0]
    else:
        indices = torch.arange(total_num_masks)
        gt_classes = cat(gt_classes, dim=0)
        pred_mask_logits = pred_mask_logits[indices, gt_classes]

    if gt_masks.dtype == torch.bool:
        gt_masks_bool = gt_masks
    else:
        gt_masks_bool = gt_masks > 0.5
    gt_masks = gt_masks.to(dtype=torch.float32)

    mask_incorrect = (pred_mask_logits > 0.0) != gt_masks_bool
    mask_accuracy = 1 - (mask_incorrect.sum().item() /
                         max(mask_incorrect.numel(), 1.0))
    num_positive = gt_masks_bool.sum().item()
    false_positive = (mask_incorrect & ~gt_masks_bool).sum().item() / max(
        gt_masks_bool.numel() - num_positive, 1.0)
    false_negative = (mask_incorrect & gt_masks_bool).sum().item() / max(
        num_positive, 1.0)

    storage = get_event_storage()
    storage.put_scalar("mask_rcnn/accuracy", mask_accuracy)
    storage.put_scalar("mask_rcnn/false_positive", false_positive)
    storage.put_scalar("mask_rcnn/false_negative", false_negative)
    if vis_period > 0 and storage.iter % vis_period == 0:
        pred_masks = pred_mask_logits.sigmoid()
        vis_masks = torch.cat([pred_masks, gt_masks], axis=2)
        name = "Left: mask prediction;   Right: mask GT"
        for idx, vis_mask in enumerate(vis_masks):
            vis_mask = torch.stack([vis_mask] * 3, axis=0)
            storage.put_image(name + f" ({idx})", vis_mask)

    mask_loss = F.binary_cross_entropy_with_logits(pred_mask_logits,
                                                   gt_masks,
                                                   reduction="mean")
    return mask_loss
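
A minimal loss sketch (assumes detectron2 is installed; the all-ones ground-truth mask and random logits are placeholders, and an EventStorage must be active for the logged metrics):

import torch
from detectron2.structures import BitMasks, Boxes, Instances
from detectron2.utils.events import EventStorage

inst = Instances((64, 64))
inst.proposal_boxes = Boxes(torch.tensor([[0., 0., 32., 32.]]))
inst.gt_classes = torch.tensor([5])
inst.gt_masks = BitMasks(torch.ones(1, 64, 64, dtype=torch.bool))
logits = torch.randn(1, 80, 28, 28)
with EventStorage():
    loss = mask_rcnn_loss(logits, [inst])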