def _get_imitation_mask(self, x, gt_boxes, iou_factor=0.5):
    """
    Build a binary mask of the grid cells whose anchors overlap a ground-truth box.

    For each ground-truth box, the best IoU over all anchors is taken as a reference. A grid cell
    is marked when at least one of its anchors has an IoU with that box strictly greater than
    ``iou_factor * best_iou``.

    :param x: feature map; only x.size(0) (batch size) and x.size(2) (grid size) are read
    :param gt_boxes: ground-truth boxes, (B, K, 4) as [x_min, y_min, x_max, y_max]; they are
        multiplied by the grid size before the IoU, so presumably normalised to [0, 1] - TODO confirm
    :param iou_factor: fraction of each box's best anchor IoU used as its threshold
    :return: mask with values in {0, 1}, a float tensor of dimensions (B, h, w)
    """
    out_size = x.size(2)
    batch_size = x.size(0)

    center_anchors = make_center_anchors(anchors_wh=self.anchors, grid_size=out_size)
    # -1 instead of a hard-coded anchor count; the view below still enforces self.num_anchors
    anchors = center_to_corner(center_anchors).view(-1, 4)  # (N, 4)

    mask_batch = torch.zeros([batch_size, out_size, out_size])

    for i in range(batch_size):
        num_obj = gt_boxes[i].size(0)
        if not num_obj:
            # No objects in this image: its mask stays all-zero
            continue

        iou_map = find_jaccard_overlap(anchors, gt_boxes[i] * float(out_size), 0).view(
            out_size, out_size, self.num_anchors, num_obj)
        max_iou, _ = iou_map.view(-1, num_obj).max(dim=0)  # (num_obj), best anchor IoU per box
        threshold = max_iou * iou_factor  # (num_obj)

        # A cell is selected if ANY (anchor, object) pair in it beats that object's threshold.
        # This equals summing per-object hit counts and clamping to [0, 1].
        cell_hit = (iou_map > threshold).flatten(start_dim=2).any(dim=2)  # (h, w), bool
        mask_batch[i] = cell_hit.to(mask_batch.dtype)

    return mask_batch  # (B, h, w)
Exemplo n.º 2
    def match_gt_priors(self, boxes, labels):
        ''' Given gt boxes, labels and the priors, match each prior to its most suited object.

        Every prior takes the object it overlaps most. Every object is additionally guaranteed the
        prior it overlaps most, and that prior's IoU is forced to 1 so it survives thresholding.
        Priors whose IoU is below self.threshold are assigned class 0 (background).

        N: batch size
        Params:
            boxes: true object bounding boxes in boundary coordinates, (xy), a list of N tensors: N(n_objects, 4)
            labels: true object labels, a list of N tensors: N(n_objects,)
        Return:
            truth_offsets: tensor (N, n_priors, 4), boxes encoded w.r.t. the priors by cxcy_to_gcxgcy
            truth_classes: tensor (N, n_priors,)
        Both outputs are created on the module-level `device`.
        '''
        N = len(boxes)  # batch size
        n_priors = self.priors_cxcy.size(0)
        truth_offsets = torch.zeros((N, n_priors, 4), dtype=torch.float, device=device)
        truth_classes = torch.zeros((N, n_priors), dtype=torch.long, device=device)

        # for each image
        for i in range(N):
            n_objects = labels[i].shape[0]
            if n_objects == 0:
                # max() over an empty dimension raises; an image without objects is all background
                continue

            overlap = find_jaccard_overlap(self.priors_xy, boxes[i])  # (n_priors, n_objects)
            # for each prior, find the max iou and the corresponding object id
            prior_iou, prior_obj = overlap.max(dim=1)  # (n_priors)
            # for each object, find the most suited prior id
            _, object_prior = overlap.max(dim=0)  # (n_objects)
            # for each object, assign its most suited prior the object id
            # (explicit loop: if two objects share a prior, the later object wins deterministically)
            for obj_id, prior_id in enumerate(object_prior.tolist()):
                prior_obj[prior_id] = obj_id
            # give those priors a high iou so that they pass the thresholding below
            prior_iou[object_prior] = 1.
            # match bbox coordinates
            boxes_xy = boxes[i][prior_obj]  # (n_priors, 4)
            # match prior class
            prior_class = labels[i][prior_obj]  # (n_priors)
            # thresholding: assign priors with iou < threshold to class 0: background
            prior_class[prior_iou < self.threshold] = 0
            # save into the truth tensors
            truth_offsets[i] = cxcy_to_gcxgcy(xy_to_cxcy(boxes_xy), self.priors_cxcy)
            truth_classes[i] = prior_class
        return truth_offsets, truth_classes
Exemplo n.º 3
    def forward(self, predicted_locs, predicted_scores, boxes, labels, device):
        """
        Forward propagation.

        :param predicted_locs: predicted locations/boxes w.r.t the 8732 prior boxes, a tensor of dimensions (N, 8732, 4)
        :param predicted_scores: class scores for each of the encoded locations/boxes, a tensor of dimensions (N, 8732, n_classes)
        :param boxes: true  object bounding boxes in boundary coordinates, a list of N tensors
        :param labels: true object labels, a list of N tensors
        :param device: device on which the target tensors are created
        :return: multibox loss, a scalar
        """
        batch_size = predicted_locs.size(0)
        n_priors = self.priors_cxcy.size(0)
        n_classes = predicted_scores.size(2)

        assert n_priors == predicted_locs.size(1) == predicted_scores.size(1)

        true_locs = torch.zeros((batch_size, n_priors, 4),
                                dtype=torch.float, device=device)  # (N, 8732, 4)
        true_classes = torch.zeros((batch_size, n_priors),
                                   dtype=torch.long, device=device)  # (N, 8732)

        # For each image
        for i in range(batch_size):
            n_objects = boxes[i].size(0)
            if n_objects == 0:
                # max() over an empty dimension raises; an image without objects is all background
                continue

            overlap = find_jaccard_overlap(boxes[i],
                                           self.priors_xy)  # (n_objects, 8732)

            # For each prior, find the object that has the maximum overlap
            overlap_for_each_prior, object_for_each_prior = overlap.max(
                dim=0)  # (8732)

            # We don't want a situation where an object is not represented in our positive (non-background) priors -
            # 1. An object might not be the best object for all priors, and is therefore not in object_for_each_prior.
            # 2. All priors with the object may be assigned as background based on the threshold (0.5).

            # To remedy this -
            # First, find the prior that has the maximum overlap for each object.
            _, prior_for_each_object = overlap.max(dim=1)  # (N_o)

            # Then, assign each object to the corresponding maximum-overlap-prior. (This fixes 1.)
            object_for_each_prior[prior_for_each_object] = torch.arange(
                n_objects, device=device)

            # To ensure these priors qualify, artificially give them an overlap of greater than 0.5. (This fixes 2.)
            overlap_for_each_prior[prior_for_each_object] = 1.

            # Labels for each prior
            label_for_each_prior = labels[i][object_for_each_prior]  # (8732)
            # Set priors whose overlaps with objects are less than the threshold to be background (no object)
            label_for_each_prior[
                overlap_for_each_prior < self.threshold] = 0  # (8732)

            # Store
            true_classes[i] = label_for_each_prior

            # Encode center-size object coordinates into the form we regressed predicted boxes to
            true_locs[i] = cxcy_to_gcxgcy(
                xy_to_cxcy(boxes[i][object_for_each_prior]),
                self.priors_cxcy)  # (8732, 4)

        # Identify priors that are positive (object/non-background)
        positive_priors = true_classes != 0  # (N, 8732)

        # Localization loss is computed only over positive (non-background) priors
        loc_loss = self.smooth_l1(predicted_locs[positive_priors],
                                  true_locs[positive_priors])  # (), scalar

        # Note: indexing with a boolean mask flattens the tensor when indexing is across multiple dimensions (N & 8732)
        # So, if predicted_locs has the shape (N, 8732, 4), predicted_locs[positive_priors] will have (total positives, 4)

        # Confidence loss is computed over positive priors and the most difficult (hardest) negative priors in each image
        # That is, FOR EACH IMAGE,
        # we will take the hardest (neg_pos_ratio * n_positives) negative priors, i.e where there is maximum loss
        # This is called Hard Negative Mining - it concentrates on hardest negatives in each image, and also minimizes pos/neg imbalance

        # Number of positive and hard-negative priors per image
        n_positives = positive_priors.sum(dim=1)  # (N)
        n_hard_negatives = self.neg_pos_ratio * n_positives  # (N)

        # First, find the loss for all priors
        conf_loss_all = self.cross_entropy(predicted_scores.view(
            -1, n_classes), true_classes.view(-1))  # (N * 8732)
        conf_loss_all = conf_loss_all.view(batch_size, n_priors)  # (N, 8732)

        # We already know which priors are positive
        conf_loss_pos = conf_loss_all[positive_priors]  # (sum(n_positives))

        # Next, find which priors are hard-negative
        # To do this, sort ONLY negative priors in each image in order of decreasing loss and take top n_hard_negatives
        conf_loss_neg = conf_loss_all.clone()  # (N, 8732)
        conf_loss_neg[
            positive_priors] = 0.  # (N, 8732), positive priors are ignored (never in top n_hard_negatives)
        conf_loss_neg, _ = conf_loss_neg.sort(
            dim=1, descending=True)  # (N, 8732), sorted by decreasing hardness
        hardness_ranks = torch.arange(
            n_priors, device=device).unsqueeze(0).expand_as(conf_loss_neg)  # (N, 8732)
        hard_negatives = hardness_ranks < n_hard_negatives.unsqueeze(
            1)  # (N, 8732)
        conf_loss_hard_neg = conf_loss_neg[
            hard_negatives]  # (sum(n_hard_negatives))

        # As in the paper, averaged over positive priors only, although computed over both positive and hard-negative priors
        # NOTE(review): a batch with no positive prior at all still divides by zero here
        conf_loss = (conf_loss_hard_neg.sum() + conf_loss_pos.sum()
                     ) / n_positives.sum().float()  # (), scalar

        # TOTAL LOSS

        return conf_loss + self.alpha * loc_loss
Exemplo n.º 4
    def detect_objects(self, predicted_locs, predicted_scores, min_score,
                       max_overlap, top_k, device):
        """
        Decipher the 8732 locations and class scores (output of the SSD300) to detect objects.

        For each class, perform Non-Maximum Suppression (NMS) on boxes that are above a minimum threshold.

        :param predicted_locs: predicted locations/boxes w.r.t the 8732 prior boxes, a tensor of dimensions (N, 8732, 4)
        :param predicted_scores: class scores for each of the encoded locations/boxes, a tensor of dimensions (N, 8732, n_classes)
        :param min_score: minimum threshold for a box to be considered a match for a certain class
        :param max_overlap: maximum overlap two boxes can have so that the one with the lower score is not suppressed via NMS
        :param top_k: if there are a lot of resulting detection across all classes, keep only the top 'k'
        :param device: device on which helper tensors (NMS mask, labels, placeholders) are created
        :return: detections (boxes, labels, and scores), lists of length batch_size
        """
        batch_size = predicted_locs.size(0)
        n_priors = self.priors_cxcy.size(0)
        predicted_scores = F.softmax(predicted_scores,
                                     dim=2)  # (N, 8732, n_classes)

        # Lists to store final predicted boxes, labels, and scores for all images
        all_images_boxes = list()
        all_images_labels = list()
        all_images_scores = list()

        assert n_priors == predicted_locs.size(1) == predicted_scores.size(1)

        for i in range(batch_size):
            # Decode object coordinates from the form we regressed predicted boxes to
            decoded_locs = cxcy_to_xy(
                gcxgcy_to_cxcy(predicted_locs[i], self.priors_cxcy)
            )  # (8732, 4), these are fractional pt. coordinates

            # Lists to store boxes and scores for this image
            image_boxes = list()
            image_labels = list()
            image_scores = list()

            # Check for each class (class 0 is background and is skipped)
            for c in range(1, self.n_classes):
                # Keep only predicted boxes and scores where scores for this class are above the minimum score
                class_scores = predicted_scores[i][:, c]  # (8732)
                score_above_min_score = class_scores > min_score  # boolean mask, for indexing
                n_above_min_score = score_above_min_score.sum().item()
                if n_above_min_score == 0:
                    continue
                class_scores = class_scores[
                    score_above_min_score]  # (n_qualified), n_qualified <= 8732
                class_decoded_locs = decoded_locs[
                    score_above_min_score]  # (n_qualified, 4)

                # Sort predicted boxes and scores by scores
                class_scores, sort_ind = class_scores.sort(
                    dim=0, descending=True)  # (n_qualified), (n_qualified)
                class_decoded_locs = class_decoded_locs[
                    sort_ind]  # (n_qualified, 4)

                # Find the overlap between predicted boxes
                overlap = find_jaccard_overlap(
                    class_decoded_locs,
                    class_decoded_locs)  # (n_qualified, n_qualified)

                # Non-Maximum Suppression (NMS)

                # A boolean tensor to keep track of which predicted boxes to suppress
                # True implies suppress, False implies don't suppress
                # (bool rather than uint8: mixing uint8 with comparison results and `1 - mask`
                # indexing is rejected by current PyTorch)
                suppress = torch.zeros(
                    n_above_min_score, dtype=torch.bool, device=device)  # (n_qualified)

                # Consider each box in order of decreasing scores
                for box in range(class_decoded_locs.size(0)):
                    # If this box is already marked for suppression
                    if suppress[box]:
                        continue

                    # Suppress boxes whose overlaps (with this box) are greater than maximum overlap
                    # The OR retains previously suppressed boxes
                    suppress = suppress | (overlap[box] > max_overlap)

                    # Don't suppress this box, even though it has an overlap of 1 with itself
                    suppress[box] = False

                # Store only unsuppressed boxes for this class
                keep = ~suppress
                image_boxes.append(class_decoded_locs[keep])
                image_labels.append(
                    torch.full((int(keep.sum().item()),), c,
                               dtype=torch.long, device=device))
                image_scores.append(class_scores[keep])

            # If no object in any class is found, store a placeholder for 'background'
            if len(image_boxes) == 0:
                image_boxes.append(
                    torch.FloatTensor([[0., 0., 1., 1.]]).to(device))
                image_labels.append(torch.LongTensor([0]).to(device))
                image_scores.append(torch.FloatTensor([0.]).to(device))

            # Concatenate into single tensors
            image_boxes = torch.cat(image_boxes, dim=0)  # (n_objects, 4)
            image_labels = torch.cat(image_labels, dim=0)  # (n_objects)
            image_scores = torch.cat(image_scores, dim=0)  # (n_objects)
            n_objects = image_scores.size(0)

            # Keep only the top k objects
            if n_objects > top_k:
                image_scores, sort_ind = image_scores.sort(dim=0,
                                                           descending=True)
                image_scores = image_scores[:top_k]  # (top_k)
                image_boxes = image_boxes[sort_ind][:top_k]  # (top_k, 4)
                image_labels = image_labels[sort_ind][:top_k]  # (top_k)

            # Append to lists that store predicted boxes and scores for all images
            all_images_boxes.append(image_boxes)
            all_images_labels.append(image_labels)
            all_images_scores.append(image_scores)

        return all_images_boxes, all_images_labels, all_images_scores  # lists of length batch_size
Exemplo n.º 5
def random_crop(image, boxes, labels):
    """
    Performs a random crop in the manner stated in the paper. Helps to learn to detect larger and partial objects.

    Note that some objects may be cut out entirely.

    If there are no bounding boxes, the inputs are returned unchanged, since the crop criteria cannot be evaluated.

    :param image: image, a tensor of dimensions (3, original_h, original_w)
    :param boxes: bounding boxes in boundary coordinates, a tensor of dimensions (n_objects, 4)
    :param labels: labels of objects, a tensor of dimensions (n_objects)
    :return: cropped image, updated bounding box coordinates, updated labels
    """
    original_h = image.size(1)
    original_w = image.size(2)

    # overlap.max() below raises on an empty tensor, so bail out early when there is nothing to keep
    if boxes.size(0) == 0:
        return image, boxes, labels

    # Keep choosing a minimum overlap until a successful crop is made
    while True:
        # Randomly draw the value for minimum overlap
        min_overlap = random.choice([0., .1, .3, .5, .7, .9,
                                     None])  # 'None' refers to no cropping

        # If not cropping
        if min_overlap is None:
            return image, boxes, labels

        # Try up to 50 times for this choice of minimum overlap
        # This isn't mentioned in the paper, of course, but 50 is chosen in paper authors' original Caffe repo
        max_trials = 50
        for _ in range(max_trials):
            # Crop dimensions must be in [0.3, 1] of original dimensions
            # Note - it's [0.1, 1] in the paper, but actually [0.3, 1] in the authors' repo
            min_scale = 0.3
            scale_h = random.uniform(min_scale, 1)
            scale_w = random.uniform(min_scale, 1)
            new_h = int(scale_h * original_h)
            new_w = int(scale_w * original_w)

            # Aspect ratio has to be in [0.5, 2]
            aspect_ratio = new_h / new_w
            if not 0.5 < aspect_ratio < 2:
                continue

            # Crop coordinates (origin at top-left of image)
            left = random.randint(0, original_w - new_w)
            right = left + new_w
            top = random.randint(0, original_h - new_h)
            bottom = top + new_h
            crop = torch.FloatTensor([left, top, right, bottom])  # (4)

            # Calculate Jaccard overlap between the crop and the bounding boxes
            overlap = find_jaccard_overlap(
                crop.unsqueeze(0), boxes
            )  # (1, n_objects), n_objects is the no. of objects in this image
            overlap = overlap.squeeze(0)  # (n_objects)

            # If not a single bounding box has a Jaccard overlap of greater than the minimum, try again
            if overlap.max().item() < min_overlap:
                continue

            # Crop image
            new_image = image[:, top:bottom, left:right]  # (3, new_h, new_w)

            # Find centers of original bounding boxes
            bb_centers = (boxes[:, :2] + boxes[:, 2:]) / 2.  # (n_objects, 2)

            # Find bounding boxes whose centers are in the crop
            centers_in_crop = (
                (bb_centers[:, 0] > left) & (bb_centers[:, 0] < right)
                & (bb_centers[:, 1] > top) & (bb_centers[:, 1] < bottom)
            )  # (n_objects), a boolean mask, can be used as a boolean index

            # If not a single bounding box has its center in the crop, try again
            if not centers_in_crop.any():
                continue

            # Discard bounding boxes that don't meet this criterion
            # (mask indexing copies, so the caller's tensors are not modified below)
            new_boxes = boxes[centers_in_crop, :]
            new_labels = labels[centers_in_crop]

            # Calculate bounding boxes' new coordinates in the crop
            new_boxes[:, :2] = torch.max(new_boxes[:, :2],
                                         crop[:2])  # crop[:2] is [left, top]
            new_boxes[:, :2] -= crop[:2]
            new_boxes[:, 2:] = torch.min(new_boxes[:, 2:],
                                         crop[2:])  # crop[2:] is [right, bottom]
            new_boxes[:, 2:] -= crop[:2]

            return new_image, new_boxes, new_labels
Exemplo n.º 6
    def make_target(self, gt_boxes, gt_labels, pred_xy, pred_wh):
        """
        Build the training targets for one batch on an (out_size x out_size x 5 anchors) grid.

        For each ground-truth box, the anchor with the highest IoU in the cell containing the box
        center is made responsible for it. The confidence target of every anchor is the best IoU
        between its predicted box and any ground-truth box of the image.

        :param gt_boxes: per-image ground-truth boxes in corner form; scaled by the grid size here,
            so presumably normalised to [0, 1] - TODO confirm
        :param gt_labels: per-image class indices of the ground-truth boxes
        :param pred_xy: predicted center offsets; size(0) is the batch size, size(2) the grid size
        :param pred_wh: predicted width/height ratios w.r.t. the anchors
        :return: resp_mask (B, S, S, 5), gt_xy (B, S, S, 5, 2), gt_wh (B, S, S, 5, 2),
            gt_conf (B, S, S, 5), gt_cls (B, S, S, 5, num_classes)
        """
        out_size = pred_xy.size(2)
        batch_size = pred_xy.size(0)
        resp_mask = torch.zeros([batch_size, out_size, out_size,
                                 5])  # y, x, anchor, ~
        gt_xy = torch.zeros([batch_size, out_size, out_size, 5, 2])
        gt_wh = torch.zeros([batch_size, out_size, out_size, 5, 2])
        gt_conf = torch.zeros([batch_size, out_size, out_size, 5])
        gt_cls = torch.zeros(
            [batch_size, out_size, out_size, 5, self.num_classes])

        center_anchors = make_center_anchors(anchors_wh=self.anchors,
                                             grid_size=out_size)
        corner_anchors = center_to_corner(center_anchors).view(
            out_size * out_size * 5, 4)

        # 1. make resp_mask
        for b in range(batch_size):

            label = gt_labels[b]
            corner_gt_box = gt_boxes[b]
            num_obj = corner_gt_box.size(0)
            if num_obj == 0:
                # Nothing to assign; the IoU views / max below cannot handle an empty object axis
                continue

            corner_gt_box_13 = corner_gt_box * float(out_size)

            center_gt_box = corner_to_center(corner_gt_box)
            center_gt_box_13 = center_gt_box * float(out_size)

            bxby = center_gt_box_13[..., :2]  # [# obj, 2]
            x_y_ = bxby - bxby.floor()  # [# obj, 2], 0~1 scale
            bwbh = center_gt_box_13[..., 2:]

            iou_anchors_gt = find_jaccard_overlap(
                corner_anchors, corner_gt_box_13)  # [845, # obj]
            iou_anchors_gt = iou_anchors_gt.view(out_size, out_size, 5, -1)

            for n_obj in range(num_obj):
                cx, cy = bxby[n_obj]
                cx = int(cx)
                cy = int(cy)

                _, max_idx = iou_anchors_gt[cy, cx, :, n_obj].max(
                    0)  # which anchor has maximum iou?
                j = max_idx  # j is idx.
                # # j-th anchor
                resp_mask[b, cy, cx, j] = 1
                gt_xy[b, cy, cx, j, :] = x_y_[n_obj]
                w_h_ = bwbh[n_obj] / torch.FloatTensor(self.anchors[j]).to(
                    device)  # ratio
                gt_wh[b, cy, cx, j, :] = w_h_
                gt_cls[b, cy, cx, j, int(label[n_obj].item())] = 1

            pred_xy_ = pred_xy[b]
            pred_wh_ = pred_wh[b]
            center_pred_xy = center_anchors[
                ..., :2].floor() + pred_xy_  # [845, 2] fix floor error
            center_pred_wh = center_anchors[..., 2:] * pred_wh_  # [845, 2]
            # NOTE(review): the call on this line was lost in the source text; concatenating
            # (xy, wh) along the last axis is what center_to_corner needs - confirm
            center_pred_bbox = torch.cat([center_pred_xy, center_pred_wh],
                                         dim=-1)
            corner_pred_bbox = center_to_corner(center_pred_bbox).view(
                -1, 4)  # [845, 4]

            iou_pred_gt = find_jaccard_overlap(
                corner_pred_bbox, corner_gt_box_13)  # [845, # obj]
            iou_pred_gt = iou_pred_gt.view(out_size, out_size, 5, -1)

            gt_conf[b] = iou_pred_gt.max(-1)[
                0]  # each obj, maximum preds          # [13, 13, 5]

        return resp_mask, gt_xy, gt_wh, gt_conf, gt_cls
    def compute_odm_loss(self, arm_locs, arm_scores, odm_locs, odm_scores,
                         boxes, labels):
        """
        Loss of the Object Detection Module (ODM): multi-class confidence loss with hard negative
        mining plus localization loss, both computed against the ARM-refined anchors.

        :param arm_locs: ARM offset predictions; decoded here and used as the "anchor boxes"
        :param arm_scores: ARM scores; softmax over dim 2, channel 1 is compared with self.theta
            to drop easy negatives
        :param odm_locs: ODM offset predictions, (N, n_priors, 4)
        :param odm_scores: ODM class scores, (N, n_priors, n_classes)
        :param boxes: true object bounding boxes in boundary coordinates, a list of N tensors
        :param labels: true object labels, a list of N tensors
        :return: conf_loss + self.alpha * loc_loss, a scalar
        """
        batch_size = odm_locs.size(0)
        n_priors = self.priors_cxcy.size(0)
        n_classes = odm_scores.size(2)
        # Targets must live with the predictions they are compared to
        device = odm_locs.device

        assert n_priors == odm_locs.size(1) == odm_scores.size(1)

        decoded_arm_locs = torch.zeros((batch_size, n_priors, 4),
                                       dtype=torch.float, device=device)
        true_locs_encoded = torch.zeros((batch_size, n_priors, 4),
                                        dtype=torch.float, device=device)
        true_classes = torch.zeros((batch_size, n_priors),
                                   dtype=torch.long, device=device)

        # For each image
        for i in range(batch_size):
            n_objects = boxes[i].size(0)
            if n_objects == 0:
                # max() over an empty dimension raises; an image without objects is all background
                continue

            decoded_arm_locs[i] = cxcy_to_xy(
                gcxgcy_to_cxcy(arm_locs[i], self.priors_cxcy))
            overlap = find_jaccard_overlap(boxes[i], decoded_arm_locs[i])

            # For each prior, find the object that has the maximum overlap, return [value, indices]
            overlap_for_each_prior, object_for_each_prior = overlap.max(
                dim=0)  # (n_priors)

            # For each object, find the prior with the maximum overlap; ignore objects that
            # overlap no prior at all
            overlap_for_each_object, prior_for_each_object = overlap.max(
                dim=1)  # (N_o)
            has_overlap = overlap_for_each_object > 0
            matched_objects = torch.arange(n_objects, device=device)[has_overlap]
            matched_priors = prior_for_each_object[has_overlap]

            # Force those priors to qualify as positives
            if matched_priors.numel() > 0:
                overlap_for_each_prior.index_fill_(0, matched_priors, 1.0)

            # Then, assign each object to its maximum-overlap prior. The ORIGINAL object index
            # is used: indexing by position in the filtered list mislabels objects whenever an
            # earlier object was filtered out.
            for obj_id, prior_id in zip(matched_objects.tolist(),
                                        matched_priors.tolist()):
                object_for_each_prior[prior_id] = obj_id

            # Labels for each prior
            label_for_each_prior = labels[i][object_for_each_prior]

            # Set priors whose overlaps with objects are less than the threshold to be background (no object)
            label_for_each_prior[overlap_for_each_prior < self.threshold] = 0

            # Store
            true_classes[i] = label_for_each_prior

            # Encode center-size object coordinates into the form we regressed predicted boxes to
            # NOTE(review): the arguments of this call were lost in the source text; encoding
            # against the ARM-refined anchors (rather than self.priors_cxcy) is assumed because
            # they are the anchors used for matching above - confirm
            true_locs_encoded[i] = cxcy_to_gcxgcy(
                xy_to_cxcy(boxes[i][object_for_each_prior]),
                xy_to_cxcy(decoded_arm_locs[i]))

        # Identify priors that are positive (object/non-background)
        positive_priors = true_classes > 0
        # Eliminate easy background bboxes from ARM
        arm_scores_prob = F.softmax(arm_scores, dim=2)
        easy_negative_idx = arm_scores_prob[:, :, 1] < self.theta

        positive_priors = positive_priors & ~easy_negative_idx

        loc_loss = self.odm_loss(
            odm_locs[positive_priors].view(-1, 4),
            true_locs_encoded[positive_priors].view(-1, 4))

        n_positives = positive_priors.sum(dim=1)  # (N)
        n_hard_negatives = self.neg_pos_ratio * n_positives  # (N)

        # First, find the loss for all priors
        conf_loss_all = self.odm_cross_entropy(odm_scores.view(-1, n_classes),
                                               true_classes.view(-1))
        conf_loss_all = conf_loss_all.view(batch_size, -1)  # (N, n_priors)

        # We already know which priors are positive
        conf_loss_pos = conf_loss_all[positive_priors]  # (sum(n_positives))

        # Positives and ARM-rejected easy negatives never count as hard negatives
        conf_loss_neg = conf_loss_all.clone()
        conf_loss_neg[positive_priors] = 0.
        conf_loss_neg[easy_negative_idx] = 0.

        conf_loss_neg, _ = conf_loss_neg.sort(
            dim=1, descending=True)  # (N, n_priors), sorted by decreasing hardness
        hardness_ranks = torch.arange(
            n_priors, device=device).unsqueeze(0).expand_as(conf_loss_neg)
        hard_negatives = hardness_ranks < n_hard_negatives.unsqueeze(1)
        conf_loss_hard_neg = conf_loss_neg[hard_negatives]

        # NOTE(review): a batch with no positive prior at all still divides by zero here
        conf_loss = (conf_loss_hard_neg.sum() + conf_loss_pos.sum()
                     ) / n_positives.sum().float()  # (), scalar

        # TOTAL LOSS
        return conf_loss + self.alpha * loc_loss
    def compute_arm_loss(self, arm_locs, arm_scores, boxes, labels):
        """
        Loss of the Anchor Refinement Module (ARM): binary object/background confidence loss with
        hard negative mining plus localization loss, both computed against the original priors.

        :param arm_locs: offset prediction from Anchor Refinement Modules
        :param arm_scores: binary classification scores from Anchor Refinement Modules
        :param boxes: gt bbox
        :param labels: gt labels
        :return: conf_loss + self.alpha * loc_loss, a scalar
        """
        batch_size = arm_locs.size(0)
        n_priors = self.priors_cxcy.size(0)
        n_classes = arm_scores.size(2)  # should be 2
        # Targets must live with the predictions they are compared to
        device = arm_locs.device

        true_locs_encoded = torch.zeros((batch_size, n_priors, 4),
                                        dtype=torch.float, device=device)
        true_classes = torch.zeros((batch_size, n_priors),
                                   dtype=torch.long, device=device)

        # For each image
        for i in range(batch_size):
            n_objects = boxes[i].size(0)
            if n_objects == 0:
                # max() over an empty dimension raises; an image without objects is all background
                continue

            overlap = find_jaccard_overlap(boxes[i],
                                           self.priors_xy)  # initial overlap

            # For each prior, find the object that has the maximum overlap, return [value, indices]
            overlap_for_each_prior, object_for_each_prior = overlap.max(dim=0)

            # For each object, find the prior with the maximum overlap; ignore objects that
            # overlap no prior at all
            overlap_for_each_object, prior_for_each_object = overlap.max(
                dim=1)  # (N_o)
            has_overlap = overlap_for_each_object > 0
            matched_objects = torch.arange(n_objects, device=device)[has_overlap]
            matched_priors = prior_for_each_object[has_overlap]

            # Force those priors to qualify as positives
            if matched_priors.numel() > 0:
                overlap_for_each_prior.index_fill_(0, matched_priors, 1.0)

            # Assign each object to its maximum-overlap prior. The ORIGINAL object index is
            # used: indexing by position in the filtered list mislabels objects whenever an
            # earlier object was filtered out.
            for obj_id, prior_id in zip(matched_objects.tolist(),
                                        matched_priors.tolist()):
                object_for_each_prior[prior_id] = obj_id

            # Labels for each prior
            label_for_each_prior = labels[i][object_for_each_prior]

            # Set priors whose overlaps with objects are less than the threshold to be background (no object)
            label_for_each_prior[overlap_for_each_prior < self.threshold] = 0

            # Converted labels to 0, 1's
            label_for_each_prior = (label_for_each_prior > 0).long()
            true_classes[i] = label_for_each_prior

            # Encode center-size object coordinates into the form we regressed predicted boxes to
            true_locs_encoded[i] = cxcy_to_gcxgcy(
                xy_to_cxcy(boxes[i][object_for_each_prior]), self.priors_cxcy)

        # Identify priors that are positive (non-background, binary)
        positive_priors = true_classes > 0
        n_positives = positive_priors.sum(dim=1)  # (N)
        loc_loss = self.arm_loss(
            arm_locs[positive_priors].view(-1, 4),
            true_locs_encoded[positive_priors].view(-1, 4))

        n_hard_negatives = self.neg_pos_ratio * n_positives

        # First, find the loss for all priors
        conf_loss_all = self.arm_cross_entropy(arm_scores.view(-1, n_classes),
                                               true_classes.view(-1))
        conf_loss_all = conf_loss_all.view(batch_size, -1)

        # We already know which priors are positive
        conf_loss_pos = conf_loss_all[positive_priors]

        # Next, find which priors are hard-negative
        # To do this, sort ONLY negative priors in each image in order of decreasing loss and take top n_hard_negatives
        conf_loss_neg = conf_loss_all.clone()
        conf_loss_neg[positive_priors] = 0.
        conf_loss_neg, _ = conf_loss_neg.sort(
            dim=1, descending=True)  # (N, n_priors), sorted by decreasing hardness
        hardness_ranks = torch.arange(
            n_priors, device=device).unsqueeze(0).expand_as(conf_loss_neg)
        hard_negatives = hardness_ranks < n_hard_negatives.unsqueeze(1)
        conf_loss_hard_neg = conf_loss_neg[hard_negatives]

        # NOTE(review): a batch with no positive prior at all still divides by zero here
        conf_loss = (conf_loss_hard_neg.sum() + conf_loss_pos.sum()
                     ) / n_positives.sum().float()  # (), scalar

        # TOTAL LOSS
        return conf_loss + self.alpha * loc_loss
Exemplo n.º 9
    def forward(self, predicted_locs, predicted_scores, boxes, labels):
        """
        Forward propagation.

        Target tensors are created on the module-level `device`.

        :param predicted_locs: predicted locations/boxes w.r.t the 8732 prior boxes, a tensor of dimensions (N, 8732, 4)
        :param predicted_scores: class scores for each of the encoded locations/boxes, a tensor of dimensions (N, 8732, n_classes)
        :param boxes: true  object bounding boxes in boundary coordinates, a list of N tensors
        :param labels: true object labels, a list of N tensors
        :return: multibox loss, a scalar
        """
        batch_size = predicted_locs.size(0)
        n_priors = self.priors_cxcy.size(0)
        n_classes = predicted_scores.size(2)

        assert n_priors == predicted_locs.size(1) == predicted_scores.size(1)

        true_locs = torch.zeros((batch_size, n_priors, 4), dtype=torch.float, device=device)  # (N, 8732, 4)
        true_classes = torch.zeros((batch_size, n_priors), dtype=torch.long, device=device)  # (N, 8732)

        # For each image
        for i in range(batch_size):
            n_objects = boxes[i].size(0)
            if n_objects == 0:
                # max() over an empty dimension raises; an image without objects is all background
                continue

            overlap = find_jaccard_overlap(boxes[i],
                                           self.priors_xy)  # (n_objects, 8732)

            # For each prior, find the object that has the maximum overlap
            overlap_for_each_prior, object_for_each_prior = overlap.max(dim=0)  # (8732)

            # We don't want a situation where an object is not represented in our positive (non-background) priors -
            # 1. An object might not be the best object for all priors, and is therefore not in object_for_each_prior.
            # 2. All priors with the object may be assigned as background based on the threshold (0.5).

            # To remedy this -
            # First, find the prior that has the maximum overlap for each object.
            _, prior_for_each_object = overlap.max(dim=1)  # (N_o)

            # Then, assign each object to the corresponding maximum-overlap-prior. (This fixes 1.)
            object_for_each_prior[prior_for_each_object] = torch.arange(n_objects, device=device)

            # To ensure these priors qualify, artificially give them an overlap of greater than 0.5. (This fixes 2.)
            overlap_for_each_prior[prior_for_each_object] = 1.

            # Labels for each prior
            label_for_each_prior = labels[i][object_for_each_prior]  # (8732)
            # Set priors whose overlaps with objects are less than the threshold to be background (no object)
            label_for_each_prior[overlap_for_each_prior < self.threshold] = 0  # (8732)

            # Store
            true_classes[i] = label_for_each_prior

            # Encode center-size object coordinates into the form we regressed predicted boxes to
            true_locs[i] = cxcy_to_gcxgcy(xy_to_cxcy(boxes[i][object_for_each_prior]), self.priors_cxcy)  # (8732, 4)

        # Identify priors that are positive (object/non-background)
        positive_priors = true_classes != 0  # (N, 8732)

        # Localization loss is computed only over positive (non-background) priors
        loc_loss = self.smooth_l1(predicted_locs[positive_priors], true_locs[positive_priors])  # (), scalar

        # Number of positive and hard-negative priors per image
        n_positives = positive_priors.sum(dim=1)  # (N)
        n_hard_negatives = self.neg_pos_ratio * n_positives  # (N)

        # First, find the loss for all priors
        conf_loss_all = self.cross_entropy(predicted_scores.view(-1, n_classes), true_classes.view(-1))  # (N * 8732)
        conf_loss_all = conf_loss_all.view(batch_size, n_priors)  # (N, 8732)

        # We already know which priors are positive
        conf_loss_pos = conf_loss_all[positive_priors]  # (sum(n_positives))

        # Next, find which priors are hard-negative
        # To do this, sort ONLY negative priors in each image in order of decreasing loss and take top n_hard_negatives
        conf_loss_neg = conf_loss_all.clone()  # (N, 8732)
        conf_loss_neg[positive_priors] = 0.  # (N, 8732), positive priors are ignored (never in top n_hard_negatives)
        conf_loss_neg, _ = conf_loss_neg.sort(dim=1, descending=True)  # (N, 8732), sorted by decreasing hardness
        hardness_ranks = torch.arange(n_priors, device=device).unsqueeze(0).expand_as(conf_loss_neg)  # (N, 8732)
        hard_negatives = hardness_ranks < n_hard_negatives.unsqueeze(1)  # (N, 8732)
        conf_loss_hard_neg = conf_loss_neg[hard_negatives]  # (sum(n_hard_negatives))

        # Averaged over positive priors only, although computed over both positive and hard-negative priors
        # NOTE(review): a batch with no positive prior at all still divides by zero here
        conf_loss = (conf_loss_hard_neg.sum() + conf_loss_pos.sum()) / n_positives.sum().float()  # (), scalar

        return conf_loss + self.alpha * loc_loss
Exemplo n.º 10
def meanAP(det_boxes, det_labels, det_scores, true_boxes, true_labels, true_difficulties):
    """
    Calculate the Mean Average Precision (mAP) of detected objects.

    Uses 11-point interpolated average precision per class (recall thresholds 0.0, 0.1, ..., 1.0)
    and an IoU threshold of 0.5 to decide whether a detection matches a ground-truth object.

    :param det_boxes: list of tensors, one tensor for each image containing detected objects' bounding boxes
    :param det_labels: list of tensors, one tensor for each image containing detected objects' labels
    :param det_scores: list of tensors, one tensor for each image containing detected objects' labels' scores
    :param true_boxes: list of tensors, one tensor for each image containing actual objects' bounding boxes
    :param true_labels: list of tensors, one tensor for each image containing actual objects' labels
    :param true_difficulties: list of tensors, one tensor for each image containing actual objects' difficulty (0 or 1)
    :return: dictionary mapping class name to its average precision, mean average precision (mAP)
    """
    assert len(det_boxes) == len(det_labels) == len(det_scores) == len(true_boxes) == len(
        true_labels)  # these are all lists of tensors of the same length, i.e. number of images
    n_classes = len(VOC_ENCODING)

    # Store all (true) objects in a single continuous tensor while keeping track of the image it is from
    true_images = list()
    for i in range(len(true_labels)):
        true_images.extend([i] * true_labels[i].size(0))
    true_images = torch.LongTensor(true_images).to(DEVICE)  # (n_objects), n_objects is the total no. of objects across all images
    true_boxes = torch.cat(true_boxes, dim=0)  # (n_objects, 4)
    true_labels = torch.cat(true_labels, dim=0)  # (n_objects)
    true_difficulties = torch.cat(true_difficulties, dim=0)  # (n_objects)

    assert true_images.size(0) == true_boxes.size(0) == true_labels.size(0)

    # Store all detections in a single continuous tensor while keeping track of the image it is from
    det_images = list()
    for i in range(len(det_labels)):
        det_images.extend([i] * det_labels[i].size(0))
    det_images = torch.LongTensor(det_images).to(DEVICE)  # (n_detections)
    det_boxes = torch.cat(det_boxes, dim=0)  # (n_detections, 4)
    det_labels = torch.cat(det_labels, dim=0)  # (n_detections)
    det_scores = torch.cat(det_scores, dim=0)  # (n_detections)

    assert det_images.size(0) == det_boxes.size(0) == det_labels.size(0) == det_scores.size(0)

    # Calculate APs for each class (except background)
    average_precisions = torch.zeros((n_classes - 1), dtype=torch.float)  # (n_classes - 1)
    for c in range(1, n_classes):
        # Extract only objects with this class
        true_class_images = true_images[true_labels == c]  # (n_class_objects)
        true_class_boxes = true_boxes[true_labels == c]  # (n_class_objects, 4)
        true_class_difficulties = true_difficulties[true_labels == c]  # (n_class_objects)
        n_easy_class_objects = (1 - true_class_difficulties).sum().item()  # ignore difficult objects

        # Keep track of which true objects with this class have already been 'detected'
        # So far, none
        true_class_boxes_detected = torch.zeros((true_class_difficulties.size(0)), dtype=torch.uint8).to(DEVICE)  # (n_class_objects)

        # Extract only detections with this class
        det_class_images = det_images[det_labels == c]  # (n_class_detections)
        det_class_boxes = det_boxes[det_labels == c]  # (n_class_detections, 4)
        det_class_scores = det_scores[det_labels == c]  # (n_class_detections)
        n_class_detections = det_class_boxes.size(0)
        if n_class_detections == 0:
            # Nothing detected for this class: its AP stays at the initial value of 0
            continue

        # Sort detections in decreasing order of confidence/scores
        det_class_scores, sort_ind = torch.sort(det_class_scores, dim=0, descending=True)  # (n_class_detections)
        det_class_images = det_class_images[sort_ind]  # (n_class_detections)
        det_class_boxes = det_class_boxes[sort_ind]  # (n_class_detections, 4)

        # In the order of decreasing scores, check if true or false positive
        true_positives = torch.zeros((n_class_detections), dtype=torch.float).to(DEVICE)  # (n_class_detections)
        false_positives = torch.zeros((n_class_detections), dtype=torch.float).to(DEVICE)  # (n_class_detections)
        for d in range(n_class_detections):
            this_detection_box = det_class_boxes[d].unsqueeze(0)  # (1, 4)
            this_image = det_class_images[d]  # (), scalar

            # Find objects in the same image with this class, their difficulties, and whether they have been detected before
            in_this_image = true_class_images == this_image  # (n_class_objects), boolean mask
            object_boxes = true_class_boxes[in_this_image]  # (n_class_objects_in_img)
            object_difficulties = true_class_difficulties[in_this_image]  # (n_class_objects_in_img)
            # If no such object in this image, then the detection is a false positive
            if object_boxes.size(0) == 0:
                false_positives[d] = 1
                continue

            # Find maximum overlap of this detection with objects in this image of this class
            overlaps = find_jaccard_overlap(this_detection_box, object_boxes)  # (1, n_class_objects_in_img)
            max_overlap, ind = torch.max(overlaps.squeeze(0), dim=0)  # (), () - scalars

            # 'ind' is the index of the object in these image-level tensors 'object_boxes', 'object_difficulties'
            # In the original class-level tensors 'true_class_boxes', etc., 'ind' corresponds to object with index...
            # (the index range is built on the mask's device so the lookup also works when DEVICE is a GPU)
            original_ind = torch.arange(true_class_boxes.size(0), device=in_this_image.device)[in_this_image][ind]
            # We need 'original_ind' to update 'true_class_boxes_detected'

            # If the maximum overlap is greater than the threshold of 0.5, it's a match
            if max_overlap.item() > 0.5:
                # If the object it matched with is 'difficult', ignore it
                if object_difficulties[ind] == 0:
                    # If this object has not already been detected, it's a true positive
                    if true_class_boxes_detected[original_ind] == 0:
                        true_positives[d] = 1
                        true_class_boxes_detected[original_ind] = 1  # this object has now been detected/accounted for
                    # Otherwise, it's a false positive (since this object is already accounted for)
                    else:
                        false_positives[d] = 1
            # Otherwise, the detection occurs in a different location than the actual object, and is a false positive
            else:
                false_positives[d] = 1

        # Compute cumulative precision and recall at each detection in the order of decreasing scores
        cumul_true_positives = torch.cumsum(true_positives, dim=0)  # (n_class_detections)
        cumul_false_positives = torch.cumsum(false_positives, dim=0)  # (n_class_detections)
        cumul_precision = cumul_true_positives / (
                cumul_true_positives + cumul_false_positives + 1e-10)  # (n_class_detections)
        cumul_recall = cumul_true_positives / n_easy_class_objects  # (n_class_detections)

        # Find the mean of the maximum of the precisions corresponding to recalls above the threshold 't'
        recall_thresholds = torch.arange(start=0, end=1.1, step=.1).tolist()  # (11)
        precisions = torch.zeros((len(recall_thresholds)), dtype=torch.float).to(DEVICE)  # (11)
        for i, t in enumerate(recall_thresholds):
            recalls_above_t = cumul_recall >= t
            if recalls_above_t.any():
                precisions[i] = cumul_precision[recalls_above_t].max()
            else:
                precisions[i] = 0.
        average_precisions[c - 1] = precisions.mean()  # c is in [1, n_classes - 1]

    # Calculate Mean Average Precision (mAP)
    mean_average_precision = average_precisions.mean().item()

    # Keep class-wise average precisions in a dictionary
    average_precisions = {VOC_DECODING[c + 1]: v for c, v in enumerate(average_precisions.tolist())}

    return average_precisions, mean_average_precision
# Example no. 11
    def forward(self, predicted_locs, predicted_scores, boxes, labels):
        """
        Forward propagation.

        :param predicted_locs: predicted locations/boxes w.r.t the prior boxes, a tensor of dimensions (N, n_priors, 4)
        :param predicted_scores: class scores for each of the encoded locations/boxes, a tensor of dimensions (N, n_priors, n_classes)
        :param boxes: true  object bounding boxes in boundary coordinates, a list of N tensors
        :param labels: true object labels, a list of N tensors
        :return: multibox loss, a scalar
        """
        batch_size = predicted_locs.size(0)
        n_priors = self.priors_cxcy.size(0)
        n_classes = predicted_scores.size(2)
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        assert n_priors == predicted_locs.size(1) == predicted_scores.size(1)

        true_locs = torch.zeros((batch_size, n_priors, 4),
                                dtype=torch.float).to(device)  # (N, n_priors, 4)
        true_classes = torch.zeros((batch_size, n_priors),
                                   dtype=torch.long).to(device)  # (N, n_priors)

        # For each image
        for i in range(batch_size):
            n_objects = boxes[i].size(0)

            overlap = find_jaccard_overlap(
                boxes[i], self.priors_xy)  # (n_objects, n_priors)

            # For each prior, find the object that has the maximum overlap
            overlap_for_each_prior, object_for_each_prior = overlap.max(
                dim=0)  # (n_priors)

            # An object could otherwise end up with no positive prior at all: it may not be the
            # best match for any prior, or its best overlap may fall below the threshold.
            # To remedy this, first find the prior that has the maximum overlap for each object.
            _, prior_for_each_object = overlap.max(dim=1)  # (N_o)

            # Then, assign each object to the corresponding maximum-overlap-prior.
            object_for_each_prior[prior_for_each_object] = torch.LongTensor(
                range(n_objects)).to(device)

            # To ensure these priors qualify, artificially give them an overlap above the threshold.
            overlap_for_each_prior[prior_for_each_object] = 1.

            # Labels for each prior
            label_for_each_prior = labels[i][
                object_for_each_prior]  # (n_priors)
            # Set priors whose overlaps with objects are less than the threshold to be background (no object)
            label_for_each_prior[
                overlap_for_each_prior < self.threshold] = 0  # (n_priors)

            # Store
            true_classes[i] = label_for_each_prior

            # Encode center-size object coordinates into the form we regressed predicted boxes to
            true_locs[i] = cxcy_to_gcxgcy(
                xy_to_cxcy(boxes[i][object_for_each_prior]),
                self.priors_cxcy)  # (n_priors, 4)

        # Identify priors that are positive (object/non-background)
        positive_priors = true_classes != 0  # (N, n_priors)

        # Localization loss is computed only over positive (non-background) priors
        loc_loss = self.smooth_l1(predicted_locs[positive_priors],
                                  true_locs[positive_priors])  # (), scalar

        # Number of positive and hard-negative priors per image
        n_positives = positive_priors.sum(dim=1)  # (N)
        n_hard_negatives = self.neg_pos_ratio * n_positives  # (N)

        # First, find the loss for all priors
        conf_loss_all = self.cross_entropy(predicted_scores.view(
            -1, n_classes), true_classes.view(-1))  # (N * n_priors)
        conf_loss_all = conf_loss_all.view(batch_size,
                                           n_priors)  # (N, n_priors)

        # We already know which priors are positive
        conf_loss_pos = conf_loss_all[positive_priors]  # (sum(n_positives))

        # Next, find which priors are hard-negative
        # To do this, sort ONLY negative priors in each image in order of decreasing loss and take top n_hard_negatives
        conf_loss_neg = conf_loss_all.clone()  # (N, n_priors)
        conf_loss_neg[
            positive_priors] = 0.  # (N, n_priors), positive priors are ignored (never in top n_hard_negatives)
        conf_loss_neg, _ = conf_loss_neg.sort(
            dim=1,
            descending=True)  # (N, n_priors), sorted by decreasing hardness
        hardness_ranks = torch.LongTensor(
            range(n_priors)).unsqueeze(0).expand_as(conf_loss_neg).to(
                device)  # (N, n_priors)
        hard_negatives = hardness_ranks < n_hard_negatives.unsqueeze(
            1)  # (N, n_priors)
        conf_loss_hard_neg = conf_loss_neg[
            hard_negatives]  # (sum(n_hard_negatives))

        # As in the paper, averaged over positive priors only, although computed over both positive and hard-negative priors
        conf_loss = (conf_loss_hard_neg.sum() + conf_loss_pos.sum()
                     ) / n_positives.sum().float()  # (), scalar

        # TOTAL LOSS

        return conf_loss + self.alpha * loc_loss
# Example no. 12
    def detect_objects(self, predicted_locs, predicted_scores, min_score,
                       max_overlap, top_k):
        """
        Decode predicted locations and class scores into detected objects.

        For each class, a Non-Maximum Suppression (NMS) is performed on the boxes whose score for
        that class is above a minimum threshold.

        :param predicted_locs: predicted locations/boxes w.r.t the prior boxes, a tensor of dimensions (N, n_priors, 4)
        :param predicted_scores: class scores for each of the encoded locations/boxes, a tensor of dimensions (N, n_priors, n_classes)
        :param min_score: minimum score for a box to be considered a match for a certain class
        :param max_overlap: maximum overlap two boxes can have so that the lower-scoring one is not suppressed via NMS
        :param top_k: if there are a lot of resulting detections in an image, keep only the top 'k'
        :return: detections (boxes, labels, and scores), lists of length batch_size
        """
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        batch_size = predicted_locs.size(0)
        n_priors = self.priors_cxcy.size(0)
        predicted_scores = F.softmax(predicted_scores, dim=2)

        all_images_boxes = list()
        all_images_labels = list()
        all_images_scores = list()

        assert n_priors == predicted_locs.size(1) == predicted_scores.size(1)

        for i in range(batch_size):
            # Decode object coordinates from the form we regressed predicted boxes to
            decoded_locs = cxcy_to_xy(
                gcxgcy_to_cxcy(predicted_locs[i], self.priors_cxcy)
            )  # (n_priors, 4), these are fractional pt. coordinates

            image_boxes = list()
            image_labels = list()
            image_scores = list()

            # Check for each class (class 0 is background and is skipped)
            for c in range(1, self.n_classes):
                # Keep only predicted boxes and scores where scores for this class are above the minimum score
                class_scores = predicted_scores[i][:, c]
                score_above_min_score = class_scores > min_score  # boolean mask, for indexing
                n_above_min_score = score_above_min_score.sum().item()
                if n_above_min_score == 0:
                    continue
                class_scores = class_scores[score_above_min_score]
                class_decoded_locs = decoded_locs[score_above_min_score]

                # Sort predicted boxes and scores by scores
                class_scores, sort_ind = class_scores.sort(dim=0,
                                                           descending=True)
                class_decoded_locs = class_decoded_locs[sort_ind]

                # Find the overlap between predicted boxes
                overlap = find_jaccard_overlap(class_decoded_locs,
                                               class_decoded_locs)  # (n_qualified, n_qualified)

                # Non-Maximum Suppression (NMS)

                # A boolean tensor to keep track of which predicted boxes to suppress
                # True implies suppress, False implies don't suppress
                suppress = torch.zeros(n_above_min_score,
                                       dtype=torch.bool).to(device)  # (n_qualified)

                # Consider each box in order of decreasing scores
                for box in range(class_decoded_locs.size(0)):
                    # If this box is already marked for suppression
                    if suppress[box]:
                        continue

                    # Suppress boxes whose overlaps (with this box) are greater than maximum overlap
                    # The OR retains previously suppressed boxes
                    suppress = suppress | (overlap[box] > max_overlap)

                    # Don't suppress this box, even though it has an overlap of 1 with itself
                    suppress[box] = False

                # Store only unsuppressed boxes for this class
                keep = ~suppress
                image_boxes.append(class_decoded_locs[keep])
                image_labels.append(
                    torch.LongTensor(
                        keep.sum().item() * [c]).to(device))
                image_scores.append(class_scores[keep])

            # If no object in any class is found, store a placeholder for 'background'
            if len(image_boxes) == 0:
                image_boxes.append(
                    torch.FloatTensor([[0., 0., 1., 1.]]).to(device))
                image_labels.append(torch.LongTensor([0]).to(device))
                image_scores.append(torch.FloatTensor([0.]).to(device))

            # Concatenate into single tensors
            image_boxes = torch.cat(image_boxes, dim=0)  # (n_objects, 4)
            image_labels = torch.cat(image_labels, dim=0)  # (n_objects)
            image_scores = torch.cat(image_scores, dim=0)  # (n_objects)
            n_objects = image_scores.size(0)

            # Keep only the top k objects
            if n_objects > top_k:
                image_scores, sort_ind = image_scores.sort(dim=0,
                                                           descending=True)
                image_scores = image_scores[:top_k]  # (top_k)
                image_boxes = image_boxes[sort_ind][:top_k]  # (top_k, 4)
                image_labels = image_labels[sort_ind][:top_k]  # (top_k)

            # Append to lists that store predicted boxes and scores for all images
            all_images_boxes.append(image_boxes)
            all_images_labels.append(image_labels)
            all_images_scores.append(image_scores)

        return all_images_boxes, all_images_labels, all_images_scores  # lists of length batch_size
# Example no. 13
    def forward(self, output_boxes, output_scores, true_boxes, true_labels):
        """
        Compute the multibox loss (confidence loss + alpha * localization loss).

        :param output_boxes: predicted box offsets w.r.t. the default boxes, indexed as (N, n_priors, 4)
        :param output_scores: predicted class scores, indexed as (N, n_priors, n_classes)
        :param true_boxes: ground-truth boxes in boundary coordinates, a list of N tensors
        :param true_labels: ground-truth labels, a list of N tensors
        :return: the combined loss, a scalar tensor
        """
        batch_size = output_boxes.size(0)
        # The class count comes from the scores tensor; the last dim of the boxes tensor is always 4
        n_classes = output_scores.size(2)
        n_priors = self.default_cxcy.size(0)

        gt_locs = torch.zeros((batch_size, n_priors, 4),
                              dtype=torch.float).to(self.device)
        gt_class = torch.zeros((batch_size, n_priors),
                               dtype=torch.long).to(self.device)

        for im in range(batch_size):
            n_objects = true_boxes[im].size(0)

            # compute IoU for each ground truth box with default boxes
            # (n_objects, 8732)
            overlaps = find_jaccard_overlap(true_boxes[im], self.default_xy)

            # find highest-overlap object for each default, and then highest-
            # overlap default for each object
            overlap_per_default, object_per_default = overlaps.max(dim=0)
            overlap_per_object, default_per_object = overlaps.max(dim=1)

            # assign object to default box with highest overlap
            object_per_default[default_per_object] = torch.LongTensor(
                range(n_objects)).to(self.device)

            # give these default boxes an overlap of 1 (ensure positive)
            overlap_per_default[default_per_object] = 1.

            # assign labels to the default boxes according to the best overlap
            default_labels = true_labels[im][object_per_default]
            default_labels[overlap_per_default < self.threshold] = 0

            gt_class[im] = default_labels
            # encode the matched ground-truth boxes relative to the default boxes
            gt_locs[im] = cxcy_to_gcxgcy(
                xy_to_cxcy(true_boxes[im][object_per_default]),
                self.default_cxcy)

        positive_defaults = (gt_class > 0)

        # localization loss, computed over positive (non-background) defaults only
        L_loc = self.smooth_l1(output_boxes[positive_defaults],
                               gt_locs[positive_defaults])

        # confidence loss
        n_positives = positive_defaults.sum(dim=1)  # (N)
        n_hard_negatives = self.hard_neg_scale * n_positives

        conf_all = output_scores.view(-1, n_classes)
        L_conf_all = self.cross_entropy(conf_all, gt_class.view(-1))
        L_conf_all = L_conf_all.view(batch_size, n_priors)  # (N, 8732)

        # We already know which priors are positive
        L_conf_pos = L_conf_all[positive_defaults]  # (sum(n_positives))

        # Next, find which priors are hard-negative
        # To do this, sort ONLY negative priors in each image in order of decreasing loss and take top n_hard_negatives
        L_conf_neg = L_conf_all.clone()  # (N, 8732)
        L_conf_neg[
            positive_defaults] = 0.  # (N, 8732), positive priors are ignored (never in top n_hard_negatives)
        L_conf_neg, _ = L_conf_neg.sort(
            dim=1, descending=True)  # (N, 8732), sorted by decreasing hardness
        hardness_ranks = torch.LongTensor(
            range(n_priors)).unsqueeze(0).expand_as(L_conf_neg).to(
                self.device)  # (N, 8732)
        hard_negatives = hardness_ranks < n_hard_negatives.unsqueeze(1)
        L_conf_hard_neg = L_conf_neg[hard_negatives]

        # As in the paper, averaged over positive priors only, although computed over both positive and hard-negative priors
        L_conf = (L_conf_hard_neg.sum() +
                  L_conf_pos.sum()) / n_positives.sum().float()  # (), scalar

        loss = L_conf + self.alpha * L_loc
        return loss
# Example no. 14
def random_crop(image, bboxes, labels, difficulties):
    """
    Performs a random crop to help detect larger and partial objects.

    A minimum IoU between the crop and at least one box is drawn at random; drawing ``None``
    returns the inputs unchanged. Only objects whose centers fall inside the crop are kept,
    and their boxes are clipped to the crop and shifted into its coordinate frame.

    :param image: a tensor with shape (C, H, W)
    :param bboxes: a tensor with shape (num_objects, 4)
    :param labels: a tensor with shape (num_objects)
    :param difficulties: a tensor with shape (num_objects)
    :return: cropped image, updated bounding boxes, updated labels, updated difficulties
    """
    org_h = image.shape[1]
    org_w = image.shape[2]

    # keep choosing a minimum overlap until a successful crop is made
    while True:
        min_overlap = random.choice([0.0, 0.1, 0.3, 0.5, 0.7, 0.9, None])
        if min_overlap is None:
            return image, bboxes, labels, difficulties

        # try up to 50 times for this choice of minimum overlap
        max_trials = 50
        for _ in range(max_trials):
            scale_h = random.uniform(0.3, 1)
            scale_w = random.uniform(0.3, 1)
            new_h = int(scale_h * org_h)
            new_w = int(scale_w * org_w)

            # aspect ratio must be in [0.5, 2]
            aspect_ratio = new_h / new_w
            if not 0.5 < aspect_ratio < 2:
                continue

            # calculate crop coordinates
            ymin = random.randint(0, org_h - new_h)
            ymax = ymin + new_h
            xmin = random.randint(0, org_w - new_w)
            xmax = xmin + new_w
            crop = torch.FloatTensor([xmin, ymin, xmax, ymax])
            overlap = find_jaccard_overlap(crop.unsqueeze(0), bboxes).squeeze(0)  # (num_objects)
            # reject the crop if no box overlaps it enough
            if overlap.max().item() < min_overlap:
                continue

            # crop image
            new_image = image[:, ymin:ymax, xmin:xmax]  # (3, new_h, new_w)

            # find centers of original bounding bboxes
            centers = (bboxes[:, :2] + bboxes[:, 2:]) / 2  # (num_objects, 2)
            centers_in_crop = (centers[:, 0] > xmin) & (centers[:, 0] < xmax) & \
                              (centers[:, 1] > ymin) & (centers[:, 1] < ymax)
            # reject the crop if it would contain no object center
            if not centers_in_crop.any():
                continue

            # filter bounding bboxes and labels (mask indexing copies, so the inputs are not modified below)
            new_bboxes = bboxes[centers_in_crop, :]
            new_labels = labels[centers_in_crop]
            new_difficulties = difficulties[centers_in_crop]

            # calculate bounding bboxes' new coordinates in the crop
            new_bboxes[:, :2] = torch.max(new_bboxes[:, :2], crop[:2])
            new_bboxes[:, :2] -= crop[:2]
            new_bboxes[:, 2:] = torch.min(new_bboxes[:, 2:], crop[2:])
            new_bboxes[:, 2:] -= crop[:2]

            return new_image, new_bboxes, new_labels, new_difficulties