Code Example #1
    def scale_back_batch(self, bboxes_in, scores_in):
        """
            Do scale and transform from xywh to ltrb
            suppose input Nx4xnum_bbox Nxlabel_numxnum_bbox
        """
        if bboxes_in.device == torch.device("cpu"):
            self.dboxes = self.dboxes.cpu()
            self.dboxes_xywh = self.dboxes_xywh.cpu()
        else:
            self.dboxes = self.dboxes.cuda()
            self.dboxes_xywh = self.dboxes_xywh.cuda()

        bboxes_in = bboxes_in.permute(0, 2, 1)
        scores_in = scores_in.permute(0, 2, 1)

        bboxes_in[:, :, :2] = self.scale_xy * bboxes_in[:, :, :2]
        bboxes_in[:, :, 2:] = self.scale_wh * bboxes_in[:, :, 2:]

        bboxes_in[:, :, :2] = bboxes_in[:, :, :2] * self.dboxes_xywh[:, :, 2:] + self.dboxes_xywh[:, :, :2]
        bboxes_in[:, :, 2:] = bboxes_in[:, :, 2:].exp() * self.dboxes_xywh[:, :, 2:]
        bboxes_in = box_convert(bboxes_in, in_fmt="cxcywh", out_fmt="xyxy")

        return bboxes_in, F.softmax(scores_in, dim=-1)
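
The decode above finishes with torchvision.ops.box_convert; a minimal standalone sketch of that call (tensor values made up for illustration):

import torch
from torchvision.ops import box_convert

# one box at center (0.5, 0.5) with width 0.2 and height 0.4
boxes_cxcywh = torch.tensor([[0.5, 0.5, 0.2, 0.4]])
boxes_xyxy = box_convert(boxes_cxcywh, in_fmt="cxcywh", out_fmt="xyxy")
# -> tensor([[0.4000, 0.3000, 0.6000, 0.7000]])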
Code Example #2
File: modules.py  Project: NNHieu/Learn_SSD
    def forward(self, b_pred_prior_offsets: Tensor,
                b_pred_prior_classes: Tensor, b_anchors, b_labels):
        """

        """
        # print("SSDLoss - prior device:", self.prior_anchor.device)
        # self.prior_anchor = self.prior_anchor.to(b_gbs[0].device)
        device = b_pred_prior_classes.device
        batch_size = b_pred_prior_offsets.size(0)
        true_locs = torch.ones(
            (batch_size, self.bencoder.nboxes, 4),
            dtype=torch.float,
            device=device) * self.bencoder.cxcywh_dboxes.unsqueeze(
                0)  # (N, 8732, 4)
        true_classes = torch.zeros((batch_size, self.bencoder.nboxes),
                                   dtype=torch.long,
                                   device=device)  # (N, 8732)

        for i, (anchors, labels) in enumerate(zip(b_anchors, b_labels)):
            positive_map, positive_set = self.bencoder.matching(
                anchors, threshold=self.threshold)
            true_classes[i, positive_map] = labels[positive_set]
            true_locs[i, positive_map] = box_convert(anchors[positive_set],
                                                     'xyxy', 'cxcywh')
            true_locs[i] = self.bencoder.cxcy_to_gcxgcy(true_locs[i])
        positive_map = true_classes > 0  # (N, 8732)
        loc_loss = self.sl1(b_pred_prior_offsets[positive_map],
                            true_locs[positive_map])

        nclasses = b_pred_prior_classes.size(-1)
        # Number of positive and hard-negative priors per image
        n_positives = positive_map.sum(dim=1)  # (N)
        n_hard_negatives = self.neg_pos_ratio * n_positives  # (N)

        conf_loss_all = self.crossent(b_pred_prior_classes.view(-1, nclasses),
                                      true_classes.view(-1))
        conf_loss_all = conf_loss_all.view(batch_size, -1)  # (N, 8732)

        conf_loss_pos = conf_loss_all[positive_map]  # (sum(n_positives))

        conf_loss_neg = conf_loss_all.clone()  # (N, 8732)
        conf_loss_neg[
            positive_map] = 0.  # (N, 8732), positive priors are ignored
        conf_loss_neg, _ = conf_loss_neg.sort(
            dim=1, descending=True)  # (N, 8732), sorted by decreasing hardness
        hardness_ranks = torch.arange(0,
                                      self.bencoder.nboxes,
                                      step=1,
                                      dtype=torch.int,
                                      device=device).unsqueeze(0).expand_as(
                                          conf_loss_neg)  # (N, 8732)
        hard_negatives = hardness_ranks < n_hard_negatives.unsqueeze(
            1)  # (N, 8732)
        conf_loss_hard_neg = conf_loss_neg[
            hard_negatives]  # (sum(n_hard_negatives))

        conf_loss = (conf_loss_hard_neg.sum() + conf_loss_pos.sum()
                     ) / n_positives.sum().float()  # (), scalar

        return conf_loss, loc_loss
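
The loss above calls self.bencoder.cxcy_to_gcxgcy, which is not shown in this listing. A sketch of the common SSD offset encoding it presumably implements (the 10/5 variance factors are the usual SSD choice and an assumption here; the priors are passed explicitly rather than read from the encoder):

import torch

def cxcy_to_gcxgcy(cxcy, priors_cxcy):
    # (g_cx, g_cy) = (cx - d_cx) / (d_w / 10), (cy - d_cy) / (d_h / 10)
    # (g_w, g_h)   = 5 * log(w / d_w), 5 * log(h / d_h)
    return torch.cat(
        [(cxcy[:, :2] - priors_cxcy[:, :2]) / (priors_cxcy[:, 2:] / 10),
         torch.log(cxcy[:, 2:] / priors_cxcy[:, 2:]) * 5], dim=1)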
Code Example #3
File: utils.py  Project: uvipen/SSD-pytorch
    def __init__(self, fig_size, feat_size, steps, scales, aspect_ratios, scale_xy=0.1, scale_wh=0.2):

        self.feat_size = feat_size
        self.fig_size = fig_size

        self.scale_xy = scale_xy
        self.scale_wh = scale_wh

        self.steps = steps
        self.scales = scales

        fk = fig_size / np.array(steps)
        self.aspect_ratios = aspect_ratios

        self.default_boxes = []
        for idx, sfeat in enumerate(self.feat_size):

            sk1 = scales[idx] / fig_size
            sk2 = scales[idx + 1] / fig_size
            sk3 = sqrt(sk1 * sk2)
            all_sizes = [(sk1, sk1), (sk3, sk3)]

            for alpha in aspect_ratios[idx]:
                w, h = sk1 * sqrt(alpha), sk1 / sqrt(alpha)
                all_sizes.append((w, h))
                all_sizes.append((h, w))
            for w, h in all_sizes:
                for i, j in itertools.product(range(sfeat), repeat=2):
                    cx, cy = (j + 0.5) / fk[idx], (i + 0.5) / fk[idx]
                    self.default_boxes.append((cx, cy, w, h))

        self.dboxes = torch.tensor(self.default_boxes, dtype=torch.float)
        self.dboxes.clamp_(min=0, max=1)
        self.dboxes_ltrb = box_convert(self.dboxes, in_fmt="cxcywh", out_fmt="xyxy")
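
A hedged usage sketch of the constructor above with the usual SSD300 hyperparameters (the class name DefaultBoxes and all numeric values are assumptions based on the standard SSD300 configuration, not taken from this listing); it produces the familiar 8732 default boxes:

dboxes = DefaultBoxes(
    fig_size=300,
    feat_size=[38, 19, 10, 5, 3, 1],
    steps=[8, 16, 32, 64, 100, 300],
    scales=[21, 45, 99, 153, 207, 261, 315],
    aspect_ratios=[[2], [2, 3], [2, 3], [2, 3], [2], [2]])
print(dboxes.dboxes.shape)  # torch.Size([8732, 4])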
Code Example #4
File: modules.py  Project: NNHieu/Learn_SSD
    def __init__(self):
        super(BoxEncoder, self).__init__()
        interested_k = self.fmap_dims.keys()
        configs = [(self.fmap_dims[k], self.obj_scales[k],
                    self.aspect_ratios[k]) for k in interested_k]
        prior_boxes = []

        for k, (fmap_dim, scale, ratios) in enumerate(configs):
            # if fmap_dim != 5: continue
            for i in range(fmap_dim):
                for j in range(fmap_dim):
                    cx = (j + 0.5) / fmap_dim
                    cy = (i + 0.5) / fmap_dim
                    for ratio in ratios:
                        prior_boxes.append(
                            [cx, cy, scale * sqrt(ratio), scale / sqrt(ratio)])
                        if ratio == 1:
                            try:
                                next_scale = configs[k + 1][1]
                                additional_scale = sqrt(scale * next_scale)
                            except IndexError:
                                additional_scale = 1.
                            prior_boxes.append(
                                [cx, cy, additional_scale, additional_scale])
        prior_boxes = torch.FloatTensor(prior_boxes)
        prior_boxes.clamp_(0, 1)  # (8732, 4)

        self.cxcywh_dboxes = nn.Parameter(prior_boxes, requires_grad=False)
        self.xyxy_dboxes = nn.Parameter(box_convert(prior_boxes, 'cxcywh',
                                                    'xyxy'),
                                        requires_grad=False)
        self.nboxes = self.cxcywh_dboxes.size(0)
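
This __init__ reads self.fmap_dims, self.obj_scales and self.aspect_ratios, which are not shown. A sketch of what they presumably contain, using the values from the widely used SSD300 tutorial setup (an assumption, not taken from this project):

# Assumed class attributes (standard SSD300 values); with these the loop yields 8732 priors.
fmap_dims = {'conv4_3': 38, 'conv7': 19, 'conv8_2': 10, 'conv9_2': 5, 'conv10_2': 3, 'conv11_2': 1}
obj_scales = {'conv4_3': 0.1, 'conv7': 0.2, 'conv8_2': 0.375, 'conv9_2': 0.55, 'conv10_2': 0.725, 'conv11_2': 0.9}
aspect_ratios = {'conv4_3': [1., 2., 0.5],
                 'conv7': [1., 2., 3., 0.5, .333],
                 'conv8_2': [1., 2., 3., 0.5, .333],
                 'conv9_2': [1., 2., 3., 0.5, .333],
                 'conv10_2': [1., 2., 0.5],
                 'conv11_2': [1., 2., 0.5]}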
Code Example #5
    def encode(self, bboxes_in, labels_in, criteria=0.5):
        # 1x8732 tensor where each value is iou of bbox with dbox
        ious = box_iou(bboxes_in, self.dboxes)
        # 1x8732 (best iou in each column), 1x8732 (index of best iou in each column (0))
        best_dbox_ious, best_dbox_idx = ious.max(dim=0)
        # 1x1 (best iou in each row), 1x1 (index of best iou in each row (0))
        _, best_bbox_idx = ious.max(dim=1)

        # sets best iou 2.0
        best_dbox_ious.index_fill_(0, best_bbox_idx, 2.0)

        # tensor([0])
        idx = torch.arange(0, best_bbox_idx.size(0), dtype=torch.int64)
        best_dbox_idx[best_bbox_idx[idx]] = idx

        # filter IoU > 0.5
        masks = best_dbox_ious > criteria
        labels_out = torch.zeros(self.nboxes, dtype=torch.long)  # 1x8732
        # put class id on boxes with IoU > 0.5
        labels_out[masks] = labels_in[best_dbox_idx[masks]]

        bboxes_out = self.dboxes.clone()  # 8732x4
        bboxes_out[masks, :] = bboxes_in[best_dbox_idx[masks], :]
        bboxes_out = box_convert(bboxes_out, in_fmt="xyxy", out_fmt="cxcywh")
        return bboxes_out, labels_out
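
A hypothetical call to encode (all values made up; "encoder" stands for an instance of the class above, with self.dboxes stored in xyxy order): one normalized ground-truth box and its label produce per-prior targets:

gt_boxes = torch.tensor([[0.10, 0.20, 0.45, 0.60]])   # normalized xyxy
gt_labels = torch.tensor([5])
loc_targets, cls_targets = encoder.encode(gt_boxes, gt_labels, criteria=0.5)
# loc_targets: (8732, 4) cxcywh boxes, cls_targets: (8732,) with 5 on matched priors, 0 elsewhere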
Code Example #6
def crop(img: Image, target: Dict[str, Any],
         region: Tuple[int]) -> Tuple[Image, Dict[str, Any]]:
    """
    Args:
        region: [Top, Left, H, W]
    """
    # crop image
    src_w, src_h = img.size
    img = TF.crop(img, *region)

    target = deepcopy(target)
    top, left, h, w = region

    # set new image size
    if "size" in target.keys():
        target["size"] = (h, w)

    fields: List[str] = list()
    for k, v in target.items():
        if isinstance(v, Tensor):
            fields.append(k)

    # crop bounding boxes
    if "boxes" in target:
        boxes = target["boxes"]
        boxes[:, [0, 2]] *= src_w
        boxes[:, [1, 3]] *= src_h
        boxes = box_op.box_convert(boxes, "cxcywh", "xyxy")
        boxes -= torch.tensor([left, top, left, top])
        boxes = box_op.clip_boxes_to_image(boxes, (h, w))
        keep = box_op.remove_small_boxes(boxes, 1)
        boxes[:, [0, 2]] /= w
        boxes[:, [1, 3]] /= h
        boxes = box_op.box_convert(boxes, "xyxy", "cxcywh")
        target["boxes"] = boxes
        for field in fields:
            target[field] = target[field][keep]

    if "masks" in target:
        target['masks'] = target['masks'][:, top:top + h, left:left + w]
        keep = target['masks'].flatten(1).any(1)
        for field in fields:
            target[field] = target[field][keep]

    return img, target
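
A hypothetical usage of crop (all values made up): take a 120x160 window starting at top=20, left=30 from a 320x240 image whose target holds normalized cxcywh boxes:

import PIL.Image
import torch

img = PIL.Image.new("RGB", (320, 240))
target = {
    "size": (240, 320),
    "boxes": torch.tensor([[0.5, 0.5, 0.25, 0.25]]),  # cxcywh, normalized
    "labels": torch.tensor([1]),
}
img, target = crop(img, target, (20, 30, 120, 160))
print(target["size"], target["boxes"])  # (120, 160) and the box re-normalized to the crop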
Code Example #7
def pad_bottom_right(img: Image, target: Dict[str, Any],
                     padding: Tuple[int, int]) -> Tuple[Image, Dict[str, Any]]:
    # assumes that we only pad on the bottom right corners
    w, h = img.size
    img = TF.pad(img, (0, 0, padding[0], padding[1]))
    target = deepcopy(target)
    target["size"] = (img.size[1], img.size[0])

    if "boxes" in target:
        bboxes = box_op.box_convert(target["boxes"], "cxcywh", "xyxy")
        x_ratio = w / (w + padding[0])
        y_ratio = h / (h + padding[1])
        bboxes *= torch.tensor([x_ratio, y_ratio, x_ratio, y_ratio])
        target["boxes"] = box_op.box_convert(bboxes, "xyxy", "cxcywh")
    if "masks" in target:
        target['masks'] = TF.pad(target['masks'],
                                 (0, 0, padding[0], padding[1]))

    return img, target
Code Example #8
    def __call__(self, image, target=None):

        if target is None:
            return image, None
        target = target.copy()
        h, w = image.shape[-2:]
        if "boxes" in target:
            boxes = target["boxes"]
            boxes = box_convert(boxes, in_fmt='xyxy', out_fmt='cxcywh')
            boxes = boxes / torch.tensor([w, h, w, h], dtype=torch.float32)
            target["boxes"] = boxes
        return image, target
Code Example #9
def overlay_boxes(detections, path, time_consume, args):

    img = cv2.imread(path) if args.save_img else None

    for i, pred in enumerate(detections):  # detections per image
        det_logs = ""
        save_path = Path(args.output_dir).joinpath(Path(path).name)
        txt_path = Path(args.output_dir).joinpath(Path(path).stem)

        if pred is not None and len(pred) > 0:
            # Rescale boxes from img_size to im0 size
            boxes, scores, labels = (
                pred["boxes"].round(),
                pred["scores"],
                pred["labels"],
            )

            # Print results
            for c in labels.unique():
                n = (labels == c).sum()  # detections per class
                det_logs += "%g %ss, " % (n, args.names[int(c)]
                                          )  # add to string

            # Write results
            for xyxy, conf, cls_name in zip(boxes, scores, labels):
                if args.save_txt:  # Write to file
                    # cxcywh (absolute pixel coordinates)
                    cxcywh = box_convert(xyxy, in_fmt="xyxy",
                                         out_fmt="cxcywh").tolist()
                    with open(f"{txt_path}.txt", "a") as f:
                        f.write(("%g " * 5 + "\n") %
                                (cls_name, *cxcywh))  # label format

                if args.save_img:  # Add bbox to image
                    label = "%s %.2f" % (args.names[int(cls_name)], conf)
                    plot_one_box(
                        xyxy,
                        img,
                        label=label,
                        color=args.colors[int(cls_name) % len(args.colors)],
                        line_thickness=3,
                    )

        # Print inference time
        logger.info("%sDone. (%.3fs)" % (det_logs, time_consume))

        # Save results (image with detections)
        if args.save_img:
            cv2.imwrite(str(save_path), img)

    return (boxes.tolist(), scores.tolist(), labels.tolist())
Code Example #10
    def __getitem__(self, index: int) -> Tuple[Tensor, Dict[str, Any]]:
        """
        Return image and target where target is a dictionary e.g.
            target: {
                image_id: str or int,
                orig_size: original image size (h, w)
                size: image size after transformation (h, w)
                boxes: relative bounding box for each object in the image (cx, cy, w, h)
                    normalized to [0, 1]
                labels: label for each bounding box
                *OTHER_INFO*: other information
            }

        Warning: after transformation, the number of bounding boxes in an image could be ZERO.
        """
        img = pil_loader(self.images[index])
        img_w, img_h = img.size

        annotation = self.get_annotation(index)

        target: Dict[str, Any] = {
            "image_id": self.get_img_id(index),
            "orig_size": (img_h, img_w),
            "size": (img_h, img_w)
        }
        target.update(annotation)

        bbox_labels = target["labels"]
        bboxes: Tensor = target["boxes"]
        assert bboxes.ndim == 2 and bboxes.shape[1] == 4, "bounding boxes must have shape [n, 4]"
        # convert xywh -> cxcywh
        bboxes = box_ops.box_convert(bboxes, "xywh", "cxcywh")
        # normalize
        bboxes[:, (0, 2)] /= img_w
        bboxes[:, (1, 3)] /= img_h
        # boxes must not extend beyond the image
        bboxes.clamp_(0, 1)

        target["boxes"] = bboxes
        target["labels"] = bbox_labels

        if self.augmentations is not None:
            img, target = self.augmentations(img, target)

        if self.resize is not None:
            img = TF.resize(img, self.resize)
            target["size"] = self.resize
        img = TF.to_tensor(img)
        img = TF.normalize(img, self.dataset_mean, self.dataset_std, inplace=True)
        return img, target
Code Example #11
def vis_bbox(img: Tensor,
             bboxes: Tensor,
             bbox_labels: LongTensor,
             img_size: Tuple[int],
             dataset_mean: Tuple[int],
             dataset_std: Tuple[int],
             bbox_fmt: str = "cxcywh",
             save_fp: Optional[str] = None) -> np.ndarray:
    """
    Return: opencv type image
    """
    un_normalize = UnNormalize(dataset_mean, dataset_std)
    img = un_normalize(img) * 255
    bboxes = box_ops.box_convert(bboxes, bbox_fmt, "xyxy")
    img = draw_img_preds(img, bboxes, bbox_labels, img_size)
    if save_fp is not None:
        cv2.imwrite(save_fp, img)
    return img
Code Example #12
File: pil_draw.py  Project: NNHieu/Learn_SSD
def draw_boxes(boxes,
               image=None,
               draw=None,
               thickness=4,
               color="#00ff00",
               boxes_format='xyxy'):
    if draw is None:
        draw = ImageDraw.Draw(image)
    im_width, im_height = image.size
    if boxes_format != 'xyxy':
        boxes = box_convert(boxes, boxes_format, 'xyxy')
    for box in boxes:
        (left, top, right, bottom) = (box[0] * im_width, box[1] * im_height,
                                      box[2] * im_width, box[3] * im_height)
        draw.line([(left, top), (left, bottom), (right, bottom), (right, top),
                   (left, top)],
                  width=thickness,
                  fill=color)
    return image
Code Example #13
    def encode(self, bboxes_in, labels_in, criteria=0.5):
        ious = box_iou(bboxes_in, self.dboxes)
        best_dbox_ious, best_dbox_idx = ious.max(dim=0)
        best_bbox_ious, best_bbox_idx = ious.max(dim=1)

        # set best ious 2.0
        best_dbox_ious.index_fill_(0, best_bbox_idx, 2.0)

        idx = torch.arange(0, best_bbox_idx.size(0), dtype=torch.int64)
        best_dbox_idx[best_bbox_idx[idx]] = idx

        # filter IoU > 0.5
        masks = best_dbox_ious > criteria
        labels_out = torch.zeros(self.nboxes, dtype=torch.long)
        labels_out[masks] = labels_in[best_dbox_idx[masks]]
        bboxes_out = self.dboxes.clone()
        bboxes_out[masks, :] = bboxes_in[best_dbox_idx[masks], :]
        bboxes_out = box_convert(bboxes_out, in_fmt="xyxy", out_fmt="cxcywh")
        return bboxes_out, labels_out
Code Example #14
File: pil_draw.py  Project: NNHieu/Learn_SSD
    def draw_preds(self,
                   image,
                   boxes,
                   labels,
                   conf_scores=None,
                   boxes_format='xyxy'):
        """Overlay labeled boxes on an image with formatted scores and label names."""
        if boxes_format != 'xyxy':
            boxes = box_convert(boxes, boxes_format, 'xyxy')
        for i in range(len(boxes)):
            class_name = self.CLASSES[labels[i]]
            display_str = "{}".format(class_name)  # class_names[i].decode("ascii"),
            if conf_scores is not None:
                display_str += ": {:.2f}%".format(conf_scores[i] * 100)
            color = self.colors[hash(class_name) % len(self.colors)]
            draw_bounding_box_on_image(image,
                                       tuple(boxes[i]),
                                       color,
                                       self.font,
                                       display_str_list=[display_str])
        return image
Code Example #15
    def __getitem__(self, idx):
        data = self._get_img_data_by_idx(idx)
        img_path = os.path.join(self.img_dir, data.iloc[0]['file_name'])
        img_id = data.iloc[0]['image_id']
        img = Image.open(img_path).convert("RGB")

        target = {
            'boxes':
            box_convert(
                torch.from_numpy(np.array(data['bbox'].values.tolist())),
                'xywh', 'xyxy'),
            'labels':
            torch.from_numpy(data['category_id'].values),
            'image_id':
            torch.tensor([img_id]),
            'area':
            torch.from_numpy(data['area'].values),
            'iscrowd':
            torch.from_numpy(data['iscrowd'].values)
        }

        if self.img_size:
            h, w = data.iloc[0]['height'], data.iloc[0]['width']
            img = F.resize(img, self.img_size)
            # scaling x bbox
            target['boxes'][:, (0, 2)] = torch.round(
                target['boxes'][:, (0, 2)] * self.img_size[1] / w)
            # scaling y bbox
            target['boxes'][:, (1, 3)] = torch.round(
                target['boxes'][:, (1, 3)] * self.img_size[0] / h)
            # adjust area
            target['area'] = (target['boxes'][:, 2] -
                              target['boxes'][:, 0]) * (target['boxes'][:, 3] -
                                                        target['boxes'][:, 1])

        if self.transforms:
            img = self.transforms(img)

        return img, target
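
The dataset above converts COCO-style xywh annotations to corner format with box_convert; the conversion itself, checked on made-up values:

coco_box = torch.tensor([[10., 20., 30., 40.]])   # x, y, w, h
print(box_convert(coco_box, 'xywh', 'xyxy'))      # tensor([[10., 20., 40., 60.]])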
Code Example #16
File: pil_draw.py  Project: NNHieu/Learn_SSD
def draw_preds(image,
               boxes,
               labels,
               conf_scores,
               class_list,
               boxes_format='xyxy'):
    """Overlay labeled boxes on an image with formatted scores and label names."""
    colors = list(ImageColor.colormap.values())
    try:
        if platform.system() == 'Darwin':
            font = ImageFont.truetype("/System/Library/Fonts/NewYork.ttf", 17)
        else:
            font = ImageFont.truetype(
                "/usr/share/fonts/truetype/liberation/LiberationSansNarrow-Regular.ttf",
                18)
    except IOError:
        print("Font not found, using default font.")
        font = ImageFont.load_default()
    if boxes_format != 'xyxy':
        boxes = box_convert(boxes, boxes_format, 'xyxy')
    for i in range(len(boxes)):
        if labels[i] == 0: continue
        #   image_pil = Image.fromarray(np.uint8(image)).convert("RGB")
        class_name = class_list[labels[i] - 1]
        display_str = "{}: {:.2f}%".format(
            # class_names[i].decode("ascii"),
            class_name,
            conf_scores[i] * 100)
        color = colors[hash(class_name) % len(colors)]
        draw_bounding_box_on_image(image,
                                   tuple(boxes[i]),
                                   color,
                                   font,
                                   display_str_list=[display_str])
        #   np.copyto(image, np.array(image_pil))
    return image
Code Example #17
def parse_single_target(target):
    boxes = box_convert(target['boxes'], in_fmt="cxcywh", out_fmt="xyxy")
    boxes = to_numpy(boxes)
    sizes = np.tile(to_numpy(target['size'])[1::-1], 2)
    boxes = boxes * sizes
    return boxes
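
The [1::-1] plus np.tile trick above turns an (h, w) size into a (w, h, w, h) scale vector; a quick check with made-up values:

import numpy as np

size = np.array([480, 640])       # (h, w)
sizes = np.tile(size[1::-1], 2)   # array([640, 480, 640, 480]), i.e. (w, h, w, h)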
Code Example #18
File: modules.py  Project: NNHieu/Learn_SSD
    def detect_objects(self, predicted_locs, predicted_scores, min_score,
                       max_overlap, top_k):
        """
        Decipher the 8732 locations and class scores (output of the SSD300) to detect objects.

        For each class, perform Non-Maximum Suppression (NMS) on boxes that are above a minimum threshold.

        :param predicted_locs: predicted locations/boxes w.r.t the 8732 prior boxes, a tensor of dimensions (N, 8732, 4)
        :param predicted_scores: class scores for each of the encoded locations/boxes, a tensor of dimensions (N, 8732, n_classes)
        :param min_score: minimum threshold for a box to be considered a match for a certain class
        :param max_overlap: maximum overlap two boxes can have so that the one with the lower score is not suppressed via NMS
        :param top_k: if there are a lot of resulting detections across all classes, keep only the top 'k'
        :return: detections (boxes, labels, and scores), lists of length batch_size
        """
        batch_size = predicted_locs.size(0)
        predicted_scores = F.softmax(predicted_scores,
                                     dim=2)  # (N, 8732, n_classes)
        n_classes = predicted_scores.size(-1)
        device = predicted_locs.device

        # Lists to store final predicted boxes, labels, and scores for all images
        all_images_boxes = list()
        all_images_labels = list()
        all_images_scores = list()

        for i in range(batch_size):
            # print('detect ', i)
            # Decode object coordinates from the form we regressed predicted boxes to
            decoded_locs = box_convert(
                self.gcxgcy_to_cxcy(predicted_locs[i]), 'cxcywh',
                'xyxy')  # (8732, 4), these are fractional pt. coordinates

            # Lists to store boxes and scores for this image
            image_boxes = list()
            image_labels = list()
            image_scores = list()

            # Check for each class
            for c in range(1, n_classes):
                # print('class', c)
                # Keep only predicted boxes and scores where scores for this class are above the minimum score
                class_scores = predicted_scores[i][:, c]  # (8732)
                score_above_min_score = class_scores > min_score  # boolean tensor, for indexing
                n_above_min_score = score_above_min_score.sum().item()
                if n_above_min_score == 0:
                    continue
                class_scores = class_scores[
                    score_above_min_score]  # (n_qualified), n_min_score <= 8732
                class_decoded_locs = decoded_locs[
                    score_above_min_score]  # (n_qualified, 4)

                # Sort predicted boxes and scores by scores
                class_scores, sort_ind = class_scores.sort(
                    dim=0, descending=True)  # (n_qualified), (n_min_score)
                class_decoded_locs = class_decoded_locs[
                    sort_ind]  # (n_min_score, 4)

                # Find the overlap between predicted boxes
                overlap = box_iou(
                    class_decoded_locs,
                    class_decoded_locs)  # (n_qualified, n_min_score)

                # Non-Maximum Suppression (NMS)

                # A boolean tensor to keep track of which predicted boxes to suppress
                # True implies suppress, False implies don't suppress
                suppress = torch.zeros((n_above_min_score),
                                       dtype=torch.bool,
                                       device=device)  # (n_qualified)

                # Consider each box in order of decreasing scores
                for box in range(class_decoded_locs.size(0)):
                    # If this box is already marked for suppression
                    if suppress[box]:
                        continue

                    # Suppress boxes whose overlaps (with this box) are greater than maximum overlap
                    # Find such boxes and update suppress indices
                    suppress |= (overlap[box] > max_overlap)
                    # The in-place OR retains previously suppressed boxes

                    # Don't suppress this box, even though it has an overlap of 1 with itself
                    suppress[box] = False

                # Store only unsuppressed boxes for this class
                non_suppress = ~suppress
                image_boxes.append(class_decoded_locs[non_suppress])
                image_labels.append(
                    torch.LongTensor(
                        (non_suppress).sum().item() * [c]).to(device))
                image_scores.append(class_scores[non_suppress])

            # If no object in any class is found, store a placeholder for 'background'
            if len(image_boxes) == 0:
                image_boxes.append(
                    torch.FloatTensor([[0., 0., 1., 1.]]).to(device))
                image_labels.append(torch.LongTensor([0]).to(device))
                image_scores.append(torch.FloatTensor([0.]).to(device))

            # Concatenate into single tensors
            image_boxes = torch.cat(image_boxes, dim=0)  # (n_objects, 4)
            image_labels = torch.cat(image_labels, dim=0)  # (n_objects)
            image_scores = torch.cat(image_scores, dim=0)  # (n_objects)
            n_objects = image_scores.size(0)

            # Keep only the top k objects
            if n_objects > top_k:
                image_scores, sort_ind = image_scores.sort(dim=0,
                                                           descending=True)
                image_scores = image_scores[:top_k]  # (top_k)
                image_boxes = image_boxes[sort_ind][:top_k]  # (top_k, 4)
                image_labels = image_labels[sort_ind][:top_k]  # (top_k)

            # Append to lists that store predicted boxes and scores for all images
            all_images_boxes.append(image_boxes)
            all_images_labels.append(image_labels)
            all_images_scores.append(image_scores)

        return all_images_boxes, all_images_labels, all_images_scores  # lists of length batch_size
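
The hand-rolled suppression loop in detect_objects could also be expressed with torchvision's built-in NMS; a sketch of what the body of the per-class loop might look like (not the project's code; nms returns the indices of kept boxes sorted by decreasing score):

from torchvision.ops import nms

keep = nms(class_decoded_locs, class_scores, iou_threshold=max_overlap)
image_boxes.append(class_decoded_locs[keep])
image_scores.append(class_scores[keep])
image_labels.append(torch.full((keep.numel(),), c, dtype=torch.long, device=device))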