def forward_for_single_feature_map(self, anchors, objectness,
                                       box_regression):
        """
        Turn one feature map's RPN outputs into per-image proposal BoxLists.

        Scores go through a sigmoid, and the top ``self.pre_nms_top_n`` anchors
        per image are kept. Their regressions are decoded into boxes. Boxes are
        clipped to the image and filtered by ``self.min_size``; CUDA inputs use
        the fused ``C.GeneratePreNMSUprightBoxes`` kernel, other inputs use the
        BoxList helpers. Finally ``boxlist_nms`` is applied with
        ``self.nms_thresh`` / ``self.post_nms_top_n``.

        NOTE(review): this file defines ``forward_for_single_feature_map``
        a second time further down. If both are in the same class, Python keeps
        only the later one and this body never runs. Confirm and delete one.

        Arguments:
            anchors: list[BoxList], one per image
            objectness: tensor of size N, A, H, W
            box_regression: tensor of size N, A * 4, H, W

        Returns:
            list[BoxList]: one "xyxy" BoxList per image with an "objectness"
            field.
        """
        device = objectness.device
        N, A, H, W = objectness.shape

        num_anchors = A * H * W
        pre_nms_top_n = min(self.pre_nms_top_n, num_anchors)
        image_shapes = [box.size for box in anchors]

        # The fused kernel needs every input on the GPU.
        use_fast_cuda_path = objectness.is_cuda and box_regression.is_cuda
        if use_fast_cuda_path:
            # Scores stay in the native (A, H, W) flattening; the kernel maps
            # topk_idx back to anchors / regressions itself.
            objectness = objectness.reshape(N, -1).sigmoid()  # [N, AHW]
            objectness, topk_idx = objectness.topk(pre_nms_top_n,
                                                   dim=1,
                                                   sorted=True)

            # All image sizes in one float tensor, one row per image
            image_shapes_cat = torch.tensor(image_shapes,
                                            device=device).float()

            # All anchors as [N, AHW, 4]; the kernel gathers the top-k ones.
            concat_anchors = torch.cat([a.bbox for a in anchors], dim=0)
            concat_anchors = concat_anchors.reshape(N, -1, 4)

            # Fused box decode + clipping + filtering. Returns pre-NMS boxes,
            # scores and a per-box keep flag. They stay sorted by score because
            # topk_idx is sorted, and NMS below relies on that.
            proposals, objectness, keep = C.GeneratePreNMSUprightBoxes(
                N,
                A,
                H,
                W,
                topk_idx,
                objectness.float(),  # kernel doesn't support fp16
                box_regression.float(),
                concat_anchors,
                image_shapes_cat,
                pre_nms_top_n,
                self.min_size,
                self.box_coder.bbox_xform_clip,
                True)

            proposals = proposals.view(N, -1, 4)  # [N, pre_nms_top_n, 4]
            objectness = objectness.view(N, -1)
        else:
            # Reorder to (N, H, W, A) so flat indices match the anchor layout.
            objectness = objectness.permute(0, 2, 3, 1).reshape(N, -1)
            objectness = objectness.sigmoid()
            objectness, topk_idx = objectness.topk(pre_nms_top_n,
                                                   dim=1,
                                                   sorted=True)

            # Put the regressions in the same format as the anchors.
            box_regression = box_regression.view(N, -1, 4, H,
                                                 W).permute(0, 3, 4, 1, 2)
            box_regression = box_regression.reshape(N, -1, 4)

            # Gather the top-k regressions and anchors per image.
            batch_idx = torch.arange(N, device=device)[:, None]
            box_regression = box_regression[batch_idx, topk_idx]
            concat_anchors = torch.cat([a.bbox for a in anchors], dim=0)
            concat_anchors = concat_anchors.reshape(N, -1, 4)[batch_idx,
                                                              topk_idx]

            proposals = self.box_coder.decode(box_regression.view(-1, 4),
                                              concat_anchors.view(-1, 4))
            proposals = proposals.view(N, -1, 4)

            # This path has no kernel keep mask; a placeholder keeps one loop.
            keep = [None] * N

        result = []
        for proposal, score, im_shape, k in zip(proposals, objectness,
                                                image_shapes, keep):
            if use_fast_cuda_path:
                # The kernel already clipped and filtered; apply its keep
                # flags per image.
                p = proposal.masked_select(k[:, None]).view(-1, 4)
                score = score.masked_select(k)
                boxlist = BoxList(p, im_shape, mode="xyxy")
                boxlist.add_field("objectness", score)
            else:
                boxlist = BoxList(proposal, im_shape, mode="xyxy")
                # Attach scores BEFORE filtering so they are filtered
                # together with the boxes and stay aligned.
                boxlist.add_field("objectness", score)
                boxlist = boxlist.clip_to_image(remove_empty=False)
                boxlist = remove_small_boxes(boxlist, self.min_size)
            boxlist = boxlist_nms(
                boxlist,
                self.nms_thresh,
                max_proposals=self.post_nms_top_n,
                score_field="objectness",
            )
            result.append(boxlist)
        return result
    def forward_for_single_feature_map(self, anchors, objectness,
                                       box_regression):
        """
        Turn one feature map's RPN outputs into per-image proposal BoxLists.

        Scores go through a sigmoid, and the top ``self.pre_nms_top_n`` anchors
        per image are kept. Their regressions are decoded into boxes. Boxes are
        clipped to the image and filtered by ``self.min_size``; CUDA inputs use
        the fused ``C.GeneratePreNMSUprightBoxes`` kernel, other inputs use the
        BoxList helpers. If ``self.nms_thresh > 0``, ``box_nms`` is then
        applied and the result is capped at ``self.post_nms_top_n`` boxes when
        that is positive.

        NOTE(review): an earlier ``forward_for_single_feature_map`` appears
        above in this file. If both are in the same class, this later
        definition is the one Python keeps.

        Arguments:
            anchors: list of BoxList, one per image
            objectness: tensor of size N, A, H, W
            box_regression: tensor of size N, A * 4, H, W

        Returns:
            list[BoxList]: one "xyxy" BoxList per image with an "objectness"
            field.
        """
        device = objectness.device
        N, A, H, W = objectness.shape

        num_anchors = A * H * W
        pre_nms_top_n = min(self.pre_nms_top_n, num_anchors)

        # Per-image (height, width); this reverses each BoxList's ``size``.
        image_shapes = [box.size[::-1] for box in anchors]

        # The fused kernel needs every input on the GPU.
        use_fast_cuda_path = objectness.is_cuda and box_regression.is_cuda
        if use_fast_cuda_path:
            # Scores stay in the native (A, H, W) flattening; the kernel maps
            # topk_idx back to anchors / regressions itself.
            objectness = objectness.reshape(N, -1).sigmoid()  # [N, AHW]
            objectness, topk_idx = objectness.topk(pre_nms_top_n,
                                                   dim=1,
                                                   sorted=True)

            # Flat float tensor [h0, w0, h1, w1, ...] of all image sizes
            image_shapes_cat = torch.tensor(image_shapes,
                                            dtype=torch.float32,
                                            device=device).view(-1)

            # All anchors as [N, AHW, 4]; the kernel gathers the top-k ones.
            concat_anchors = torch.cat([a.bbox for a in anchors], dim=0)
            concat_anchors = concat_anchors.reshape(N, -1, 4)

            # Fused box decode + clipping + filtering. Returns pre-NMS boxes,
            # scores and a per-box keep flag. They stay sorted by score because
            # topk_idx is sorted, and NMS below relies on that.
            # NOTE(review): the earlier definition calls this kernel without
            # the feature_stride argument and with (width, height) rows;
            # confirm which matches the compiled extension.
            proposals, objectness, keep = C.GeneratePreNMSUprightBoxes(
                N,
                A,
                H,
                W,
                topk_idx,
                objectness.float(),  # kernel doesn't support fp16
                box_regression.float(),
                concat_anchors,
                image_shapes_cat,
                pre_nms_top_n,
                0,  # feature_stride
                self.min_size,
                self.box_coder.bbox_xform_clip,
                True)

            proposals = proposals.view(N, -1, 4)  # [N, pre_nms_top_n, 4]
            objectness = objectness.view(N, -1)
        else:
            # Reorder to (N, H, W, A) so flat indices match the anchor layout.
            objectness = objectness.permute(0, 2, 3, 1).reshape(N, -1)
            objectness = objectness.sigmoid()
            objectness, topk_idx = objectness.topk(pre_nms_top_n,
                                                   dim=1,
                                                   sorted=True)

            # Put the regressions in the same format as the anchors.
            box_regression = box_regression.view(N, -1, 4, H,
                                                 W).permute(0, 3, 4, 1, 2)
            box_regression = box_regression.reshape(N, -1, 4)

            # Gather the top-k regressions and anchors per image.
            batch_idx = torch.arange(N, device=device)[:, None]
            box_regression = box_regression[batch_idx, topk_idx]
            concat_anchors = torch.cat([a.bbox for a in anchors], dim=0)
            concat_anchors = concat_anchors.reshape(N, -1, 4)[batch_idx,
                                                              topk_idx]

            proposals = self.box_coder.decode(box_regression.view(-1, 4),
                                              concat_anchors.view(-1, 4))
            proposals = proposals.view(N, -1, 4)

            # This path has no kernel keep mask; a placeholder keeps one loop.
            keep = [None] * N

        # TODO optimize / make batch friendly
        sampled_bboxes = []
        for proposal, score, (height, width), k in zip(proposals, objectness,
                                                       image_shapes, keep):
            if use_fast_cuda_path:
                # The kernel already clipped and filtered; apply its keep
                # flags per image. masked_select expects a bool mask.
                k = k.bool()
                boxes = proposal.masked_select(k[:, None]).view(-1, 4)
                score = score.masked_select(k)
            else:
                # Do the clipping / small-box filtering the kernel does on
                # the GPU path. Scores are attached first so they are filtered
                # together with the boxes.
                boxlist = BoxList(proposal, (width, height), mode="xyxy")
                boxlist.add_field("objectness", score)
                boxlist = boxlist.clip_to_image(remove_empty=False)
                boxlist = remove_small_boxes(boxlist, self.min_size)
                boxes = boxlist.bbox
                score = boxlist.get_field("objectness")

            if self.nms_thresh > 0:
                nms_keep = box_nms(boxes, score, self.nms_thresh)
                if self.post_nms_top_n > 0:
                    nms_keep = nms_keep[:self.post_nms_top_n]
                boxes = boxes.index_select(0, nms_keep)
                score = score.index_select(0, nms_keep)

            sampled_bbox = BoxList(boxes, (width, height), mode="xyxy")
            sampled_bbox.add_field("objectness", score)
            sampled_bboxes.append(sampled_bbox)
            # TODO maybe also copy the other fields that were originally present?

        return sampled_bboxes