Example #1
0
    def forward(self, predicted_locs, predicted_scores, boxes, labels):
        """
        Forward propagation with adaptive (mean + std IoU threshold) positive-sample selection.

        For each image and each pyramid level, the ``self.n_candidates`` priors with the
        smallest ``find_distance`` to a ground-truth box are that box's candidates.
        - Positive test: a candidate is positive for the object when (a) its IoU with the
          object exceeds the object's threshold, and (b) the prior centre lies strictly
          inside the object box. The threshold is the mean + std of the object's candidate
          IoUs over all levels.
        - Ties across objects: a prior that is positive for several objects is matched to
          the object it overlaps most.

        :param predicted_locs: predicted locations/boxes w.r.t the prior boxes (22536 in the original setup), a tensor of dimensions (N, n_priors, 4)
        :param predicted_scores: class scores for each of the encoded locations/boxes, a tensor of dimensions (N, n_priors, n_classes)
        :param boxes: true object bounding boxes in boundary coordinates, a list of N tensors
        :param labels: true object labels, a list of N tensors
        :return: multibox loss, a scalar
        """
        n_levels = len(self.priors_cxcy)
        batch_size = predicted_locs.size(0)

        decoded_locs = list()  # one (n_pos, 4) tensor per image that has positives
        true_locs = list()  # matching ground-truth boxes, same order as decoded_locs
        true_classes = list()  # one (n_priors,) label tensor per image
        predicted_class_scores = list()

        for i in range(batch_size):
            image_bboxes = boxes[i]
            image_labels = labels[i]
            n_objects = image_bboxes.size(0)
            image_bboxes_cxcy = xy_to_cxcy(image_bboxes)

            # Split this image's predictions according to the feature pyramid levels
            level_locs = []
            level_scores = []
            for level in range(n_levels):
                start = self.prior_split_points[level]
                end = self.prior_split_points[level + 1]
                level_locs.append(predicted_locs[i][start:end, :])
                level_scores.append(predicted_scores[i][start:end, :])

            # Candidates for positive samples: the closest priors of each object on every level
            candidate_idx = list()  # per level: (n_objects, k) prior indices
            candidate_overlaps = list()  # per level: (n_objects, k) IoUs of those priors
            overlap = list()  # per level: (n_objects, n_priors_level)
            for level in range(n_levels):
                distance = find_distance(image_bboxes_cxcy,
                                         self.priors_cxcy[level])  # n_bboxes, n_priors
                k = min(self.n_candidates, distance.size(1))
                _, top_idx_level = torch.topk(-1. * distance, k, dim=1)
                overlap_level = find_jaccard_overlap(image_bboxes,
                                                     self.priors_xy[level])
                candidate_idx.append(top_idx_level)
                candidate_overlaps.append(
                    torch.gather(overlap_level, dim=1, index=top_idx_level))
                overlap.append(overlap_level)

            # One adaptive IoU threshold per object, over its candidates on all levels
            all_candidate_overlaps = torch.cat(candidate_overlaps, dim=1)
            iou_threshold = (torch.mean(all_candidate_overlaps, dim=1) +
                             torch.std(all_candidate_overlaps, dim=1))  # (n_objects)

            true_classes_level = list()
            true_locs_level = list()
            decoded_locs_level = list()
            for level in range(n_levels):
                n_priors_level = self.priors_cxcy[level].size(0)
                label_for_each_prior = torch.zeros(n_priors_level,
                                                   dtype=torch.long,
                                                   device=self.device)

                if n_objects > 0:
                    cand_idx = candidate_idx[level]  # (n_objects, k)
                    cand_centres = self.priors_cxcy[level][cand_idx]  # (n_objects, k, 4)
                    cx = cand_centres[..., 0]
                    cy = cand_centres[..., 1]
                    # The prior centre must lie strictly inside the object box
                    centre_inside = ((image_bboxes[:, 0:1] < cx) & (cx < image_bboxes[:, 2:3])
                                     & (image_bboxes[:, 1:2] < cy) & (cy < image_bboxes[:, 3:4]))
                    cand_positive = (candidate_overlaps[level] >
                                     iou_threshold.unsqueeze(1)) & centre_inside

                    # positive_mask[ob, p] is True when prior p is positive for object ob
                    positive_mask = torch.zeros((n_objects, n_priors_level),
                                                dtype=torch.bool,
                                                device=cand_idx.device)
                    object_rows = torch.arange(
                        n_objects, device=cand_idx.device).unsqueeze(1).expand_as(cand_idx)
                    positive_mask[object_rows[cand_positive], cand_idx[cand_positive]] = True

                    # A prior may be positive for several objects: keep the one it overlaps
                    # most. Non-positive (object, prior) pairs are masked out with -1.
                    best_overlap, best_object = overlap[level].masked_fill(
                        ~positive_mask, -1.).max(dim=0)  # (n_priors_level)
                    positive_prior_idx = torch.nonzero(best_overlap > 0.,
                                                       as_tuple=True)[0]

                    if positive_prior_idx.numel() > 0:
                        matched_objects = best_object[positive_prior_idx]
                        label_for_each_prior[positive_prior_idx] = image_labels[
                            matched_objects].to(label_for_each_prior)
                        level_decoded = cxcy_to_xy(
                            gcxgcy_to_cxcy(level_locs[level], self.priors_cxcy[level]))
                        true_locs_level.append(image_bboxes[matched_objects].view(-1, 4))
                        decoded_locs_level.append(level_decoded[positive_prior_idx].view(-1, 4))

                true_classes_level.append(label_for_each_prior)

            # Store
            true_classes.append(torch.cat(true_classes_level, dim=0))  # (n_priors)
            predicted_class_scores.append(torch.cat(level_scores, dim=0))
            if len(true_locs_level) > 0:
                true_locs.append(torch.cat(true_locs_level, dim=0))
                decoded_locs.append(torch.cat(decoded_locs_level, dim=0))

        # Assemble all samples from the batch
        all_true_classes = torch.cat(true_classes, dim=0)  # (N * n_priors)
        positive_priors = all_true_classes > 0
        all_scores = torch.cat(predicted_class_scores, dim=0)  # (N * n_priors, n_classes)

        # LOCALIZATION LOSS
        if len(true_locs) > 0:
            loc_loss = self.regression_loss(torch.cat(decoded_locs, dim=0),
                                            torch.cat(true_locs, dim=0))
        else:
            # No positive prior in the whole batch: nothing to regress (torch.cat([]) would raise)
            loc_loss = predicted_locs.sum() * 0.

        # CONFIDENCE LOSS, normalised by the number of positives (at least 1 to avoid x / 0)
        n_positives = positive_priors.sum().float().clamp(min=1.)
        conf_loss = self.classification_loss(all_scores, all_true_classes) / n_positives

        # TOTAL LOSS
        return conf_loss + self.alpha * loc_loss
Example #2
0
def random_crop(image, boxes, labels):
    """
    Performs a random crop in the manner stated in the paper. Helps to learn to detect larger and partial objects.

    Note that some objects may be cut out entirely.

    Adapted from https://github.com/amdegroot/ssd.pytorch/blob/master/utils/augmentations.py

    :param image: image, a tensor of dimensions (3, original_h, original_w)
    :param boxes: bounding boxes in boundary coordinates, a tensor of dimensions (n_objects, 4)
    :param labels: labels of objects, a tensor of dimensions (n_objects)
    :return: cropped image, updated bounding box coordinates, updated labels. The inputs are
        returned unchanged when "no crop" is drawn, and always when ``boxes`` is empty.
    """
    original_h = image.size(1)
    original_w = image.size(2)

    # With no boxes the overlap test below has nothing to work on (max() of an empty
    # tensor raises), so leave the sample as it is.
    if boxes.size(0) == 0:
        return image, boxes, labels

    # Keep choosing a minimum overlap until a successful crop is made
    while True:
        # Randomly draw the value for minimum overlap; 'None' refers to no cropping
        min_overlap = random.choice([0., .1, .3, .5, .7, .9, None])
        if min_overlap is None:
            return image, boxes, labels

        # Try up to 50 times for this choice of minimum overlap.
        # The paper does not mention this; 50 is the value in the authors' original Caffe repo.
        max_trials = 50
        for _ in range(max_trials):
            # Crop dimensions must be in [0.3, 1] of original dimensions
            # (the paper says [0.1, 1], but the authors' repo uses [0.3, 1])
            min_scale = 0.3
            scale_h = random.uniform(min_scale, 1)
            scale_w = random.uniform(min_scale, 1)
            new_h = int(scale_h * original_h)
            new_w = int(scale_w * original_w)

            # Tiny images can round a side down to 0 pixels (new_h / new_w would divide by zero)
            if new_h == 0 or new_w == 0:
                continue

            # Aspect ratio has to be in (0.5, 2)
            aspect_ratio = new_h / new_w
            if not 0.5 < aspect_ratio < 2:
                continue

            # Crop coordinates (origin at top-left of image)
            left = random.randint(0, original_w - new_w)
            right = left + new_w
            top = random.randint(0, original_h - new_h)
            bottom = top + new_h
            # Built on the boxes' device so it can be compared with them directly
            crop = torch.tensor([left, top, right, bottom],
                                dtype=torch.float,
                                device=boxes.device)  # (4)

            # Jaccard overlap between the crop and every bounding box
            overlap = find_jaccard_overlap(crop.unsqueeze(0), boxes).squeeze(0)  # (n_objects)

            # If not a single bounding box reaches the minimum overlap, try again
            if overlap.max().item() < min_overlap:
                continue

            # Crop image
            new_image = image[:, top:bottom, left:right]  # (3, new_h, new_w)

            # Keep only the bounding boxes whose centres fall inside the crop
            bb_centers = (boxes[:, :2] + boxes[:, 2:]) / 2.  # (n_objects, 2)
            centers_in_crop = ((bb_centers[:, 0] > left) & (bb_centers[:, 0] < right)
                               & (bb_centers[:, 1] > top) & (bb_centers[:, 1] < bottom))  # bool mask

            if not centers_in_crop.any():
                continue

            # Boolean indexing copies, so the caller's boxes are not modified below
            new_boxes = boxes[centers_in_crop, :]
            new_labels = labels[centers_in_crop]

            # Clip the boxes to the crop and shift them into crop coordinates
            new_boxes[:, :2] = torch.max(new_boxes[:, :2], crop[:2])  # crop[:2] is [left, top]
            new_boxes[:, :2] -= crop[:2]
            new_boxes[:, 2:] = torch.min(new_boxes[:, 2:], crop[2:])  # crop[2:] is [right, bottom]
            new_boxes[:, 2:] -= crop[:2]

            return new_image, new_boxes, new_labels
    def compute_odm_loss(self, arm_locs, arm_scores, odm_locs, odm_scores,
                         boxes, labels):
        """
        Object Detection Module (ODM) loss, computed w.r.t. ARM-refined anchors.

        The ARM offsets are decoded against the priors to give refined anchors. Ground
        truth is matched to and encoded relative to these anchors. Anchors whose ARM
        object probability (softmax column 1) is below ``self.theta`` are easy negatives:
        they are excluded from the positives and from hard negative mining.

        :param arm_locs: ARM offsets w.r.t. the priors, (N, n_priors, 4); decoded they serve as "anchor boxes"
        :param arm_scores: ARM classification scores, (N, n_priors, 2); column 1 is used as object score
        :param odm_locs: ODM offsets w.r.t. the refined anchors, (N, n_priors, 4)
        :param odm_scores: ODM class scores, (N, n_priors, n_classes)
        :param boxes: ground-truth boxes in boundary coordinates, a list of N tensors
        :param labels: ground-truth labels, a list of N tensors
        :return: conf_loss + alpha * loc_loss, a scalar
        """
        batch_size = odm_locs.size(0)
        n_priors = self.priors_cxcy.size(0)
        n_classes = odm_scores.size(2)

        assert n_priors == odm_locs.size(1) == odm_scores.size(1)

        true_locs_encoded = torch.zeros((batch_size, n_priors, 4),
                                        dtype=torch.float, device=self.device)
        true_classes = torch.zeros((batch_size, n_priors),
                                   dtype=torch.long, device=self.device)

        for i in range(batch_size):
            # Refined anchors only define the matching/encoding targets, so they are
            # detached: no gradient should reach the ARM through the ODM targets.
            refined_anchors_xy = cxcy_to_xy(
                gcxgcy_to_cxcy(arm_locs[i].detach(), self.priors_cxcy))
            overlap = find_jaccard_overlap(boxes[i], refined_anchors_xy)  # (n_objects, n_priors)

            # For each prior, the object with the maximum overlap
            overlap_for_each_prior, object_for_each_prior = overlap.max(dim=0)  # (n_priors)

            # Every object that overlaps some anchor is forced onto its best anchor
            overlap_for_each_object, prior_for_each_object = overlap.max(dim=1)  # (n_objects)
            matched_objects = torch.nonzero(overlap_for_each_object > 0,
                                            as_tuple=True)[0]
            prior_for_each_object = prior_for_each_object[matched_objects]
            if prior_for_each_object.numel() > 0:
                overlap_for_each_prior.index_fill_(0, prior_for_each_object, 1.0)
            # Use the real object index (not the position in the filtered list); sequential
            # so that if two objects share a best anchor, the later object wins.
            for obj, prior in zip(matched_objects.tolist(),
                                  prior_for_each_object.tolist()):
                object_for_each_prior[prior] = obj

            # Labels for each prior; low-overlap priors become background (0)
            label_for_each_prior = labels[i][object_for_each_prior]
            label_for_each_prior[overlap_for_each_prior < self.threshold] = 0
            true_classes[i] = label_for_each_prior

            # Encode matched boxes relative to the refined anchors
            true_locs_encoded[i] = cxcy_to_gcxgcy(
                xy_to_cxcy(boxes[i][object_for_each_prior]),
                xy_to_cxcy(refined_anchors_xy))

        # Positives, minus the anchors the ARM already rejected as easy negatives
        arm_object_prob = F.softmax(arm_scores, dim=2)[:, :, 1]
        easy_negative_idx = arm_object_prob < self.theta
        positive_priors = (true_classes > 0) & ~easy_negative_idx

        # LOCALIZATION LOSS
        if positive_priors.any():
            loc_loss = self.odm_loss(
                odm_locs[positive_priors].view(-1, 4),
                true_locs_encoded[positive_priors].view(-1, 4))
        else:
            # Nothing to regress; keep a zero that is still connected to the graph
            loc_loss = odm_locs.sum() * 0.

        # CONFIDENCE LOSS
        n_positives = positive_priors.sum(dim=1)  # (N)
        n_hard_negatives = self.neg_pos_ratio * n_positives  # (N)

        conf_loss_all = self.odm_cross_entropy(odm_scores.view(-1, n_classes),
                                               true_classes.view(-1))
        conf_loss_all = conf_loss_all.view(batch_size, -1)  # (N, n_priors)
        conf_loss_pos = conf_loss_all[positive_priors]  # (sum(n_positives))

        # Hard negatives: sort the remaining (non-positive, non-easy) priors by decreasing loss
        conf_loss_neg = conf_loss_all.clone()
        conf_loss_neg[positive_priors] = 0.
        conf_loss_neg[easy_negative_idx] = 0.
        conf_loss_neg, _ = conf_loss_neg.sort(dim=-1, descending=True)
        hardness_ranks = torch.arange(
            n_priors, device=self.device).unsqueeze(0).expand_as(conf_loss_neg)
        hard_negatives = hardness_ranks < n_hard_negatives.unsqueeze(1)
        conf_loss_hard_neg = conf_loss_neg[hard_negatives]

        # Averaged over positives (at least 1, so a batch without positives does not divide by 0)
        conf_loss = (conf_loss_hard_neg.sum() + conf_loss_pos.sum()
                     ) / n_positives.sum().float().clamp(min=1.)

        # TOTAL LOSS
        return conf_loss + self.alpha * loc_loss
    def compute_arm_loss(self, arm_locs, arm_scores, boxes, labels):
        """
        Anchor Refinement Module (ARM) loss.

        Binary object/background classification with hard negative mining plus
        regression of offsets w.r.t. the priors.

        :param arm_locs: offset prediction from Anchor Refinement Modules, (N, n_priors, 4)
        :param arm_scores: binary classification scores from Anchor Refinement Modules, (N, n_priors, n_classes), n_classes should be 2
        :param boxes: gt bbox in boundary coordinates, a list of N tensors
        :param labels: gt labels, a list of N tensors; every label > 0 counts as "object"
        :return: conf_loss + alpha * loc_loss, a scalar
        """
        batch_size = arm_locs.size(0)
        n_priors = self.priors_cxcy.size(0)
        n_classes = arm_scores.size(2)  # should be 2

        true_locs_encoded = torch.zeros((batch_size, n_priors, 4),
                                        dtype=torch.float, device=self.device)
        true_classes = torch.zeros((batch_size, n_priors),
                                   dtype=torch.long, device=self.device)

        for i in range(batch_size):
            overlap = find_jaccard_overlap(boxes[i], self.priors_xy)  # initial overlap

            # For each prior, the object with the maximum overlap
            overlap_for_each_prior, object_for_each_prior = overlap.max(dim=0)

            # Every object that overlaps some prior is forced onto its best prior
            overlap_for_each_object, prior_for_each_object = overlap.max(dim=1)  # (N_o)
            matched_objects = torch.nonzero(overlap_for_each_object > 0,
                                            as_tuple=True)[0]
            prior_for_each_object = prior_for_each_object[matched_objects]
            if prior_for_each_object.numel() > 0:
                overlap_for_each_prior.index_fill_(0, prior_for_each_object, 1.0)
            # Use the real object index (not the position in the filtered list); sequential
            # so that if two objects share a best prior, the later object wins.
            for obj, prior in zip(matched_objects.tolist(),
                                  prior_for_each_object.tolist()):
                object_for_each_prior[prior] = obj

            # Low-overlap priors become background, then labels are collapsed to 0/1
            label_for_each_prior = labels[i][object_for_each_prior]
            label_for_each_prior[overlap_for_each_prior < self.threshold] = 0
            true_classes[i] = (label_for_each_prior > 0).long()

            # Encode matched boxes relative to the priors
            true_locs_encoded[i] = cxcy_to_gcxgcy(
                xy_to_cxcy(boxes[i][object_for_each_prior]), self.priors_cxcy)

        positive_priors = true_classes > 0
        n_positives = positive_priors.sum(dim=1)  # (N)

        # LOCALIZATION LOSS
        if positive_priors.any():
            loc_loss = self.arm_loss(
                arm_locs[positive_priors].view(-1, 4),
                true_locs_encoded[positive_priors].view(-1, 4))
        else:
            # Nothing to regress; keep a zero that is still connected to the graph
            loc_loss = arm_locs.sum() * 0.

        # CONFIDENCE LOSS
        n_hard_negatives = self.neg_pos_ratio * n_positives  # (N)

        conf_loss_all = self.arm_cross_entropy(arm_scores.view(-1, n_classes),
                                               true_classes.view(-1))
        conf_loss_all = conf_loss_all.view(batch_size, -1)  # (N, n_priors)
        conf_loss_pos = conf_loss_all[positive_priors]  # (sum(n_positives))

        # Hard negatives: sort ONLY negative priors by decreasing loss, keep the top ones
        conf_loss_neg = conf_loss_all.clone()
        conf_loss_neg[positive_priors] = 0.
        conf_loss_neg, _ = conf_loss_neg.sort(dim=-1, descending=True)
        hardness_ranks = torch.arange(
            n_priors, device=self.device).unsqueeze(0).expand_as(conf_loss_neg)
        hard_negatives = hardness_ranks < n_hard_negatives.unsqueeze(1)
        conf_loss_hard_neg = conf_loss_neg[hard_negatives]

        # Averaged over positives (at least 1, so a batch without positives does not divide by 0)
        conf_loss = (conf_loss_hard_neg.sum() + conf_loss_pos.sum()
                     ) / n_positives.sum().float().clamp(min=1.)

        # TOTAL LOSS
        return conf_loss + self.alpha * loc_loss
Example #5
0
    def forward(self, predicted_locs, predicted_scores, boxes, labels):
        """
        Forward propagation.

        Priors are matched to ground truth by IoU:
        - A prior is positive when its best overlap is >= ``self.threshold``.
        - It is a usable negative when the overlap is below ``self.threshold - 0.1``.
        - Priors in between are used by neither classification branch.
        - Every object that overlaps some prior is additionally forced onto its best prior.

        Regression: DIoU on decoded boxes when ``self.config.reg_loss`` is 'DIOU',
        otherwise smooth L1 on encoded offsets.

        Classification: focal loss over positives + usable negatives when
        ``self.config.cls_loss`` is 'FOCAL'. Otherwise cross entropy with hard negative
        mining (``self.neg_pos_ratio`` negatives per positive).

        :param predicted_locs: predicted locations/boxes w.r.t the prior boxes (22536 in the original setup), a tensor of dimensions (N, n_priors, 4)
        :param predicted_scores: class scores for each of the encoded locations/boxes, a tensor of dimensions (N, n_priors, n_classes)
        :param boxes: true object bounding boxes in boundary coordinates, a list of N tensors
        :param labels: true object labels, a list of N tensors
        :return: multibox loss, a scalar
        """
        batch_size = predicted_locs.size(0)
        n_priors = self.priors_cxcy.size(0)
        n_classes = predicted_scores.size(2)

        assert n_priors == predicted_locs.size(1) == predicted_scores.size(1)

        decoded_locs = torch.zeros((batch_size, n_priors, 4), dtype=torch.float, device=self.device)
        true_locs = torch.zeros((batch_size, n_priors, 4), dtype=torch.float, device=self.device)
        true_locs_encoded = torch.zeros((batch_size, n_priors, 4), dtype=torch.float, device=self.device)
        true_classes = torch.zeros((batch_size, n_priors), dtype=torch.long, device=self.device)
        true_neg_classes = torch.zeros((batch_size, n_priors), dtype=torch.long, device=self.device)

        for i in range(batch_size):
            overlap = find_jaccard_overlap(boxes[i], self.priors_xy)  # (n_objects, n_priors)

            # For each prior, the object with the maximum overlap
            overlap_for_each_prior, object_for_each_prior = overlap.max(dim=0)  # (n_priors)

            # An object might be nobody's best object, or all its priors might fall below the
            # threshold. To avoid losing it, force every object that overlaps some prior onto
            # its best prior with overlap 1.
            overlap_for_each_object, prior_for_each_object = overlap.max(dim=1)  # (n_objects)
            matched_objects = torch.nonzero(overlap_for_each_object > 0, as_tuple=True)[0]
            prior_for_each_object = prior_for_each_object[matched_objects]
            if prior_for_each_object.numel() > 0:
                overlap_for_each_prior.index_fill_(0, prior_for_each_object, 1.0)
            # Use the real object index (not the position in the filtered list); sequential
            # so that if two objects share a best prior, the later object wins.
            for obj, prior in zip(matched_objects.tolist(), prior_for_each_object.tolist()):
                object_for_each_prior[prior] = obj

            # Labels for each prior: positives keep their label, everything below the threshold
            # is background. A separate copy marks usable negatives (< threshold - 0.1) with -1.
            label_for_each_prior = labels[i][object_for_each_prior]
            label_for_each_prior_neg_used = labels[i][object_for_each_prior]
            label_for_each_prior[overlap_for_each_prior < self.threshold] = 0
            label_for_each_prior_neg_used[overlap_for_each_prior < (self.threshold - 0.1)] = -1

            true_classes[i] = label_for_each_prior
            true_neg_classes[i] = label_for_each_prior_neg_used

            # Regression targets, both encoded (smooth L1) and in boundary coordinates (DIoU)
            matched_boxes = boxes[i][object_for_each_prior]
            true_locs_encoded[i] = cxcy_to_gcxgcy(xy_to_cxcy(matched_boxes), self.priors_cxcy)
            true_locs[i] = matched_boxes
            decoded_locs[i] = cxcy_to_xy(gcxgcy_to_cxcy(predicted_locs[i], self.priors_cxcy))

        positive_priors = true_classes > 0
        negative_priors = true_neg_classes == -1
        n_positives = positive_priors.sum(dim=1)  # (N)
        # Loss normaliser; at least 1 so a batch without positives does not divide by zero
        n_positives_total = n_positives.sum().float().clamp(min=1.)

        # LOCALIZATION LOSS
        if not positive_priors.any():
            # Nothing to regress; keep a zero that is still connected to the graph
            loc_loss = predicted_locs.sum() * 0.
        elif self.config.reg_loss.upper() == 'DIOU':
            loc_loss = self.Diou_loss(decoded_locs[positive_priors].view(-1, 4),
                                      true_locs[positive_priors].view(-1, 4))
        else:
            loc_loss = self.smooth_l1(predicted_locs[positive_priors].view(-1, 4),
                                      true_locs_encoded[positive_priors].view(-1, 4))

        # CONFIDENCE LOSS
        if self.config.cls_loss.upper() == 'FOCAL':
            predicted_objects = torch.cat([predicted_scores[positive_priors],
                                           predicted_scores[negative_priors]], dim=0)
            target_class = torch.cat([true_classes[positive_priors],
                                      true_classes[negative_priors]], dim=0)

            conf_loss = self.Focal_loss(predicted_objects.view(-1, n_classes),
                                        target_class.view(-1),
                                        device=self.config.device) / n_positives_total
        else:
            n_hard_negatives = self.neg_pos_ratio * n_positives  # (N)

            conf_loss_all = self.cross_entropy(predicted_scores.view(-1, n_classes),
                                               true_classes.view(-1))  # (N * n_priors)
            conf_loss_all = conf_loss_all.view(batch_size, -1)  # (N, n_priors)

            conf_loss_pos = conf_loss_all[positive_priors]  # (sum(n_positives))

            # Hard negatives are mined among usable negatives only (positives and the
            # ignore band are zeroed so they never reach the top n_hard_negatives)
            conf_loss_neg = conf_loss_all.clone()
            conf_loss_neg[~negative_priors] = 0.
            conf_loss_neg, _ = conf_loss_neg.sort(dim=-1, descending=True)
            hardness_ranks = torch.arange(n_priors, device=self.device).unsqueeze(0).expand_as(conf_loss_neg)
            hard_negatives = hardness_ranks < n_hard_negatives.unsqueeze(1)
            conf_loss_hard_neg = conf_loss_neg[hard_negatives]

            # As in the paper, averaged over positive priors only
            conf_loss = (conf_loss_hard_neg.sum() + conf_loss_pos.sum()) / n_positives_total

        # TOTAL LOSS
        return conf_loss + self.alpha * loc_loss
    def forward(self, odm_locs, odm_scores, boxes, labels):
        """
        Detection loss on decoded boxes, with an IoU ignore band.

        Priors are matched to ground truth by IoU:
        - best overlap >= ``self.threshold``: the object's label;
        - best overlap < ``self.threshold - 0.1``: background (0);
        - in between: -1.
        Every object that overlaps some prior is additionally forced onto its best prior.

        :param odm_locs: predicted offsets w.r.t. the priors, a tensor of dimensions (N, n_priors, 4)
        :param odm_scores: predicted class scores for each prior, a tensor of dimensions (N, n_priors, n_classes)
        :param boxes: ground-truth boxes in boundary coordinates, a list of N tensors
        :param labels: ground-truth labels, a list of N tensors
        :return: classification loss / number of positives + alpha * regression loss, a scalar
        """
        batch_size = odm_locs.size(0)
        n_priors = self.priors_cxcy.size(0)
        n_classes = odm_scores.size(2)

        assert n_priors == odm_locs.size(1) == odm_scores.size(1)

        decoded_locs = torch.zeros((batch_size, n_priors, 4),
                                   dtype=torch.float, device=self.device)
        true_locs = torch.zeros((batch_size, n_priors, 4),
                                dtype=torch.float, device=self.device)
        true_classes = torch.zeros((batch_size, n_priors),
                                   dtype=torch.long, device=self.device)

        for i in range(batch_size):
            overlap = find_jaccard_overlap(boxes[i], self.priors_xy)  # initial overlap

            # For each prior, the object with the maximum overlap
            overlap_for_each_prior, object_for_each_prior = overlap.max(dim=0)  # (n_priors)

            # An object might be nobody's best object, or all its priors might fall below the
            # threshold. To avoid losing it, force every object that overlaps some prior onto
            # its best prior with overlap 1.
            overlap_for_each_object, prior_for_each_object = overlap.max(dim=1)  # (n_objects)
            matched_objects = torch.nonzero(overlap_for_each_object > 0,
                                            as_tuple=True)[0]
            prior_for_each_object = prior_for_each_object[matched_objects]
            if prior_for_each_object.numel() > 0:
                overlap_for_each_prior.index_fill_(0, prior_for_each_object, 1.0)
            # Use the real object index (not the position in the filtered list); sequential
            # so that if two objects share a best prior, the later object wins.
            for obj, prior in zip(matched_objects.tolist(),
                                  prior_for_each_object.tolist()):
                object_for_each_prior[prior] = obj

            # Labels for each prior: -1 for the ignore band, 0 for background
            label_for_each_prior = labels[i][object_for_each_prior]
            label_for_each_prior[overlap_for_each_prior < self.threshold] = -1
            label_for_each_prior[overlap_for_each_prior < self.threshold - 0.1] = 0
            # NOTE(review): the -1 labels are passed to self.classification_loss as-is; this
            # assumes that loss ignores them (e.g. ignore_index=-1) — confirm.

            true_classes[i] = label_for_each_prior

            # Regression targets in boundary coordinates, compared with decoded predictions
            true_locs[i] = boxes[i][object_for_each_prior]
            decoded_locs[i] = cxcy_to_xy(
                gcxgcy_to_cxcy(odm_locs[i], self.priors_cxcy))

        positive_priors = true_classes > 0

        # LOCALIZATION LOSS
        if positive_priors.any():
            loc_loss = self.regression_loss(
                decoded_locs[positive_priors].view(-1, 4),
                true_locs[positive_priors].view(-1, 4))
        else:
            # Nothing to regress; keep a zero that is still connected to the graph
            loc_loss = odm_locs.sum() * 0.

        # CONFIDENCE LOSS, normalised by the total number of positives in the batch
        # (scalar; at least 1 so a batch without positives does not divide by zero)
        n_positives = positive_priors.sum().float().clamp(min=1.)
        conf_loss = self.classification_loss(odm_scores.view(
            -1, n_classes), true_classes.view(-1)) / n_positives

        # TOTAL LOSS
        return conf_loss + self.alpha * loc_loss
    def forward(self, predicted_locs, predicted_scores, boxes, labels):
        """
        Forward propagation.

        Matches every prior to a ground-truth object by Jaccard overlap, builds
        per-prior classification targets and encoded regression targets, and
        returns ``conf_loss + self.alpha * loc_loss``.

        Per-prior class targets:
          * the matched object's label if its overlap >= ``self.threshold``;
          * -1 if the overlap is in ``[self.threshold - 0.1, self.threshold)``
            (these priors are marked as "not used");
          * 0 (background) if the overlap is below ``self.threshold - 0.1``.
        Each object's best-overlapping prior (when that overlap is > 0) has its
        overlap forced to 1.0 and is matched to that object, so every visible
        object gets at least one candidate positive prior.

        :param predicted_locs: predicted locations/boxes w.r.t the prior boxes, a tensor of dimensions (N, n_priors, 4)
        :param predicted_scores: class scores for each of the encoded locations/boxes, a tensor of dimensions (N, n_priors, n_classes)
        :param boxes: true object bounding boxes in boundary coordinates, a list of N tensors
        :param labels: true object labels, a list of N tensors
        :return: multibox loss, a scalar
        """
        batch_size = predicted_locs.size(0)
        n_priors = self.priors_cxcy.size(0)
        n_classes = predicted_scores.size(2)

        assert n_priors == predicted_locs.size(1) == predicted_scores.size(1)

        true_locs_encoded = torch.zeros((batch_size, n_priors, 4),
                                        dtype=torch.float).to(self.device)
        # Zero-initialised = background, which is also the correct target for
        # images that contain no objects (skipped below).
        true_classes = torch.zeros((batch_size, n_priors),
                                   dtype=torch.long).to(self.device)

        # For each image
        for i in range(batch_size):
            n_objects = boxes[i].size(0)
            if n_objects == 0:
                # overlap.max() cannot reduce over an empty object dimension;
                # the all-background / zero-offset targets are already set.
                continue

            overlap = find_jaccard_overlap(
                boxes[i], self.priors_xy)  # (n_objects, n_priors)

            # For each prior, the object with the maximum overlap and its value.
            overlap_for_each_prior, object_for_each_prior = overlap.max(dim=0)

            # For each object, the prior with the maximum overlap. Objects that
            # do not overlap any prior at all are dropped.
            overlap_for_each_object, prior_for_each_object = overlap.max(dim=1)
            has_overlap = overlap_for_each_object > 0
            prior_for_each_object = prior_for_each_object[has_overlap]
            # Keep the ORIGINAL object indices of the surviving objects: after
            # filtering, a position in prior_for_each_object is no longer the
            # object's index if any earlier object was dropped.
            matched_objects = torch.arange(
                n_objects, device=overlap.device)[has_overlap]

            if prior_for_each_object.numel() > 0:
                # Force each object's best prior above the threshold.
                overlap_for_each_prior.index_fill_(0, prior_for_each_object,
                                                   1.0)

            # Sequential assignment: if two objects share the same best prior,
            # the later object wins deterministically.
            for obj, prior in zip(matched_objects.tolist(),
                                  prior_for_each_object.tolist()):
                object_for_each_prior[prior] = obj

            # Labels for each prior
            label_for_each_prior = labels[i][object_for_each_prior]

            # Below the threshold: -1 (label in 0.4-0.5 style band is not used);
            # further below (threshold - 0.1): background.
            label_for_each_prior[overlap_for_each_prior < self.threshold] = -1
            label_for_each_prior[overlap_for_each_prior < self.threshold -
                                 0.1] = 0

            # Store
            true_classes[i] = label_for_each_prior

            # Encode the matched boxes into the offset form predicted_locs uses.
            true_locs_encoded[i] = cxcy_to_gcxgcy(
                xy_to_cxcy(boxes[i][object_for_each_prior]), self.priors_cxcy)

        # Identify priors that are positive (object/non-background)
        positive_priors = true_classes > 0
        # Normaliser for the classification loss; clamped so a batch with no
        # positive priors does not divide by zero.
        n_positives = positive_priors.sum().float().clamp(min=1.0)

        # LOCALIZATION LOSS
        # NOTE(review): with no positive priors this receives empty tensors;
        # whether that yields NaN depends on self.smooth_l1's reduction — confirm.
        loc_loss = self.smooth_l1(
            predicted_locs[positive_priors].view(-1, 4),
            true_locs_encoded[positive_priors].view(-1, 4))

        # CONFIDENCE LOSS (computed over all priors, normalised by positives)
        conf_loss = self.Focal_loss(predicted_scores.reshape(-1, n_classes),
                                    true_classes.view(-1)) / n_positives

        # TOTAL LOSS
        return conf_loss + self.alpha * loc_loss