def encode_ground_truth(self, ground_truth, anchors):
        """
        Build a dense per-anchor training target from per-image ground truth.

        Anchors are matched to the ground-truth boxes of their image by IoU
        through ``matcher.match_per_prediction``:
          * matched anchors copy the box, class and score of their gt box;
          * anchors the matcher marks as ``matcher.IGNORE`` get
            ``IGNORE_CLASS`` written to both the class and the score column;
          * all other anchors keep class ``NEGATIVE_CLASS``, score 1 and an
            all-zero box.

        Args:
            ground_truth: list(:len Batch) of torch.tensor(:shape [Boxes_i, K]);
                box, class and score columns are addressed with the
                ``det_ds`` index constants. Boxes are presumably in corner
                form, since they are IoU-matched against corner-form anchors
                (TODO confirm).
            anchors: torch.tensor(:shape [AnchorBoxes, 4]); converted with
                ``box_utils.to_corners`` before matching.
        Returns:
            target: torch.tensor(:shape [Batch, AnchorBoxes, TARGET_SIZE]),
                float32, on ``anchors.device``. Columns
                LOC_INDEX_START:LOC_INDEX_END hold the matched (not encoded)
                ground-truth box; CLASS_INDEX and SCORE_INDEX hold the label
                and score.
        """
        batch_size = len(ground_truth)
        num_anchors = anchors.size(0)
        device = anchors.device

        ground_truth = [x.to(device, non_blocking=True) for x in ground_truth]
        corner_anchors = box_utils.to_corners(anchors)

        target = torch.zeros((batch_size, num_anchors, TARGET_SIZE),
                             dtype=torch.float32,
                             device=device)
        # Every anchor starts out as background with full confidence.
        target[..., CLASS_INDEX] = NEGATIVE_CLASS
        target[..., SCORE_INDEX] = 1.0

        for i, gt in enumerate(ground_truth):
            if not len(gt):
                # No objects in this image: all anchors stay background.
                continue

            gt_boxes = gt[:, det_ds.LOC_INDEX_START:det_ds.LOC_INDEX_END]
            weights = box_utils.iou(gt_boxes, corner_anchors)

            # box_idx[a] is either the index of the gt box assigned to anchor
            # a, or one of the matcher sentinels NOT_MATCHED / IGNORE.
            box_idx = matcher.match_per_prediction(weights,
                                                   self.matched_threshold,
                                                   self.unmatched_threshold)
            matched = box_idx.ne(matcher.NOT_MATCHED) & box_idx.ne(
                matcher.IGNORE)
            matched_idx = box_idx[matched]

            target[i, matched, LOC_INDEX_START:LOC_INDEX_END] = gt_boxes[
                matched_idx]
            target[i, matched, CLASS_INDEX] = gt[matched_idx,
                                                 det_ds.CLASS_INDEX]
            target[i, matched, SCORE_INDEX] = gt[matched_idx,
                                                 det_ds.SCORE_INDEX]

            # box_idx holds matcher codes, so test it against the matcher's
            # IGNORE sentinel (exactly as ``matched`` does above), not against
            # the IGNORE_CLASS label constant.
            ignored = box_idx.eq(matcher.IGNORE)
            target[i, ignored, CLASS_INDEX] = IGNORE_CLASS
            target[i, ignored, SCORE_INDEX] = IGNORE_CLASS

        positive = target[..., CLASS_INDEX].ne(NEGATIVE_CLASS) & target[
            ..., CLASS_INDEX].ne(IGNORE_CLASS)
        # Sanity check: every positive anchor must carry a finite box.
        assert not torch.isnan(
            target[..., LOC_INDEX_START:LOC_INDEX_END][positive]).any().item()

        return target
    def encode_ground_truth(self, ground_truth, priors):
        """
        Assign every prior a class label and an encoded regression target.

        Args:
            ground_truth: list(:len Batch) of torch.tensor(:shape [Boxes_i, 5]);
                columns 0-3 are the box, column 4 the class label
            priors: torch.tensor(:shape [AnchorBoxes, 4]); converted with
                ``box_utils.to_corners`` for matching and used as-is by
                ``self.box_coder.encode_box``
        Returns:
            target_classes: torch.tensor(:shape [Batch, AnchorBoxes]), long;
                0 for priors with no match, -1 where the matcher returned -1,
                otherwise the matched box's label
            target_locs: torch.tensor(:shape [Batch, AnchorBoxes, 4]); matched
                boxes converted to centroids and encoded against ``priors``
        """
        device = priors.device
        batch_size, num_priors = len(ground_truth), priors.size(0)

        on_device = [boxes.to(device, non_blocking=True) for boxes in ground_truth]
        corner_priors = box_utils.to_corners(priors)

        target_classes = torch.zeros((batch_size, num_priors),
                                     dtype=torch.long, device=device)
        target_locs = torch.zeros((batch_size, num_priors, 4),
                                  dtype=torch.float32, device=device)

        for sample, boxes in enumerate(on_device):
            if not len(boxes):
                continue  # nothing to match for this image

            overlaps = box_utils.jaccard(boxes[:, :4], corner_priors)
            assignment = match_per_prediction(
                overlaps,
                matched_threshold=self.matched_threshold,
                unmatched_threshold=self.unmatched_threshold)

            # Non-negative entries index the gt box assigned to each prior.
            is_matched = assignment.ge(0)
            assigned = boxes[assignment[is_matched]]
            target_classes[sample, is_matched] = assigned[:, 4].long()
            target_locs[sample, is_matched] = assigned[:, :4]

            # A matcher code of -1 is written through as class -1.
            target_classes[sample, assignment.eq(-1)] = -1

        centroid_locs = box_utils.to_centroids(target_locs)
        target_locs = self.box_coder.encode_box(centroid_locs,
                                                priors,
                                                inplace=False)

        # Only labelled (class > 0) priors must have finite regression targets.
        positive = target_classes > 0
        assert not torch.isnan(target_locs[positive]).any().item()

        return target_classes, target_locs
Example #3
0
    def postprocess(self, prediction, priors):
        """
        Decode raw head outputs into per-image, per-class NMS-filtered detections.

        Args:
            prediction: tuple of
                torch.tensor(:shape [Batch, AnchorBoxes * Classes])
                torch.tensor(:shape [Batch, AnchorBoxes * 4])
            priors: torch.tensor(:shape [AnchorBoxes, 4])
        Returns:
            processed: list(:len Batch) of torch.tensor(:shape [Boxes_i, 6] ~ {[0-3] - box, [4] - class, [5] - score})

        Scores go through ``self.score_converter_fn``; unless the converter is
        'SIGMOID', score column 0 is dropped. Reported class ids are the
        remaining score-column index plus one. Candidates at or below
        ``self.score_threshold`` are discarded before ``self.nms``. When
        ``self.max_total`` is set, each image keeps only its ``max_total``
        highest-scoring detections.
        """
        raw_scores, raw_boxes = prediction
        batch_size, num_priors = raw_scores.size(0), priors.size(0)

        class_scores_all = self.score_converter_fn(
            raw_scores.view(batch_size, num_priors, -1))
        if self.score_converter != 'SIGMOID':
            # Non-sigmoid converters carry an extra leading column; drop it.
            class_scores_all = class_scores_all[..., 1:]
        num_classes = class_scores_all.size(-1)
        class_scores_all = class_scores_all.cpu()

        decoded = self.box_coder.decode_box(
            raw_boxes.view(batch_size, num_priors, 4).to(priors.device),
            priors,
            inplace=False)
        corner_boxes = box_utils.to_corners(decoded).cpu()

        def detections_for_class(image_scores, image_boxes, class_index):
            # Threshold one score column, run NMS on the survivors and pack
            # them as rows of [box(4), class id, score].
            column = image_scores[:, class_index]
            keep = column > self.score_threshold
            (kept_boxes, kept_scores), _ = self.nms(image_boxes[keep],
                                                    column[keep])
            kept_scores = kept_scores.unsqueeze(1)
            labels = torch.full_like(kept_scores,
                                     class_index + 1,
                                     dtype=torch.float)
            return torch.cat([kept_boxes, labels, kept_scores], dim=-1)

        processed = []
        for image_scores, image_boxes in zip(class_scores_all, corner_boxes):
            detections = torch.cat([
                detections_for_class(image_scores, image_boxes, class_index)
                for class_index in range(num_classes)
            ], dim=0)

            cap = self.max_total
            if cap is not None and detections.size(0) > cap:
                _, order = torch.topk(detections[:, 5],
                                      cap,
                                      sorted=True,
                                      largest=True)
                detections = detections[order]

            processed.append(detections)

        return processed
    def forward(self, pred, anchors, target):
        """
        Compute the weighted classification + localization loss for a batch.

        Args:
            pred: tuple of
                torch.tensor(:shape [Batch, AnchorBoxes * Classes])
                torch.tensor(:shape [Batch, AnchorBoxes * 4])
            anchors: anchor boxes handed to ``self.box_coder`` for decoding
                (IoU loss) or encoding (regression loss); presumably
                torch.tensor(:shape [AnchorBoxes, 4]) — TODO confirm
            target: torch.tensor(:shape [Batch, AnchorBoxes, 6]); boxes in
                LOC_INDEX_START:LOC_INDEX_END, label at CLASS_INDEX, score at
                SCORE_INDEX. It is NOT modified by this method.
        Returns:
            losses: tuple(loss, class_loss, loc_loss) of tensors. Each part is
                multiplied by its configured weight and divided by the number
                of positive anchors (at least 1); loss is their sum.
        """
        scores, locs = pred

        # Basic slice -> a view that shares storage with the caller's target.
        target_locs = target[..., LOC_INDEX_START:LOC_INDEX_END]
        target_classes = target[..., CLASS_INDEX].long()
        target_scores = target[..., SCORE_INDEX]

        batch_size = target.size(0)
        num_priors = target.size(1)

        scores = scores.view(batch_size, num_priors, -1)
        locs = locs.view(batch_size, num_priors, 4)

        positive_mask = target_classes.ne(NEGATIVE_CLASS) & target_classes.ne(
            IGNORE_CLASS)
        # The sampler selects which anchors enter the classification loss.
        sampled_mask = self.sampler(scores, target_classes)

        scores = scores[sampled_mask]
        target_classes = target_classes[sampled_mask]
        target_scores = target_scores[sampled_mask]

        if self.multiclass:
            # One column per foreground class: column j holds class j + 1.
            class_target = torch.zeros_like(scores)
            mask = target_classes.ne(NEGATIVE_CLASS) & target_classes.ne(
                IGNORE_CLASS)
            class_target[mask, target_classes[mask] - 1] = target_scores[mask]
        elif self.soft_target:
            class_target = torch.zeros_like(scores)
            mask = target_classes.ne(IGNORE_CLASS)
            class_target[mask, target_classes[mask]] = target_scores[mask]
        else:
            class_target = target_classes.view(-1)

        class_loss = self.classification_loss(scores, class_target)

        if self.iou_loss:
            # Compare in box space: decode predictions to corner boxes.
            locs = self.box_coder.decode_box(locs, anchors)
            locs = box_utils.to_corners(locs)
        else:
            # Compare in offset space. target_locs is a view into the caller's
            # ``target``; clone it so the in-place conversions below do not
            # silently re-encode the input batch (which would corrupt it on
            # any later reuse, e.g. a second loss evaluation).
            target_locs = target_locs.clone()
            box_utils.to_centroids(target_locs, inplace=True)
            self.box_coder.encode_box(target_locs, anchors, inplace=True)

        positive_locs = locs[positive_mask].view(-1, 4)
        positive_target_locs = target_locs[positive_mask].view(-1, 4)
        loc_loss = self.localization_loss(positive_locs, positive_target_locs)

        # Normalise by the number of positives; clamp avoids division by zero.
        divider = positive_mask.sum().clamp_(min=1).float()
        loc_loss.mul_(self.localization_weight).div_(divider)
        class_loss.mul_(self.classification_weight).div_(divider)

        loss = class_loss + loc_loss

        return loss, class_loss, loc_loss