Code example #1
class Detector(object):
    def __init__(self, weights_path):
        self.model = SSD('test')
        self.model.cuda().eval()

        state = torch.load(weights_path,
                           map_location=lambda storage, loc: storage)
        state = {key: value.float() for key, value in state.items()}
        self.model.load_state_dict(state)

        self.transform = GeneralizedRCNNTransform(DETECTOR_MIN_SIZE,
                                                  DETECTOR_MAX_SIZE,
                                                  DETECTOR_MEAN, DETECTOR_STD)
        self.transform.eval()

    def detect(self, images):
        # stack the HWC numpy frames on the GPU and convert to an NCHW float batch
        images = torch.stack(
            [torch.from_numpy(image).cuda() for image in images])
        images = images.transpose(1, 3).transpose(2, 3).float()
        original_image_sizes = [img.shape[-2:] for img in images]
        # resize/normalize with the torchvision transform, then run the SSD
        images, _ = self.transform(images, None)
        with torch.no_grad():
            detections_batch = self.model(images.tensors).cpu().numpy()
        result = []
        for detections, image_size in zip(detections_batch,
                                          images.image_sizes):
            # keep class-1 detections whose confidence exceeds the threshold
            scores = detections[1, :, 0]
            keep_idxs = scores > DETECTOR_THRESHOLD
            detections = detections[1, keep_idxs, :]
            # move the score to the last column and scale the normalized
            # box coordinates to the transformed image size
            detections = detections[:, [1, 2, 3, 4, 0]]
            detections[:, 0] *= image_size[1]
            detections[:, 1] *= image_size[0]
            detections[:, 2] *= image_size[1]
            detections[:, 3] *= image_size[0]
            result.append({
                'scores': torch.from_numpy(detections[:, 4]),
                'boxes': torch.from_numpy(detections[:, :4])
            })

        result = self.transform.postprocess(result, images.image_sizes,
                                            original_image_sizes)
        return result
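
A minimal usage sketch for the Detector class above (not part of the original snippet): the SSD class, the DETECTOR_* constants, and the weights path are assumed to come from the surrounding project, and the frames are illustrative placeholders.

import numpy as np

# hypothetical weights path; SSD and the DETECTOR_* constants must be
# importable from the surrounding project for this to run
detector = Detector('weights/ssd_detector.pth')

# detect() stacks the inputs, so all frames must share the same HWC shape
frames = [np.zeros((300, 300, 3), dtype=np.uint8) for _ in range(2)]
outputs = detector.detect(frames)
for out in outputs:
    print(out['boxes'].shape, out['scores'].shape)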
Code example #2
File: retinanet.py  Project: IntelAI/models
class RetinaNet(nn.Module):
    """
    Implements RetinaNet.

    The input to the model is expected to be a list of tensors, each of shape [C, H, W], one for each
    image, and should be in 0-1 range. Different images can have different sizes.

    The behavior of the model changes depending on whether it is in training or evaluation mode.

    During training, the model expects both the input tensors and the targets (a list of dictionaries),
    containing:
        - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
        - labels (Int64Tensor[N]): the class label for each ground-truth box

    The model returns a Dict[Tensor] during training, containing the classification and regression
    losses.

    During inference, the model requires only the input tensors, and returns the post-processed
    predictions as a List[Dict[Tensor]], one for each input image. The fields of the Dict are as
    follows:
        - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
        - labels (Int64Tensor[N]): the predicted labels for each image
        - scores (Tensor[N]): the scores for each prediction

    Args:
        backbone (nn.Module): the network used to compute the features for the model.
            It should contain an out_channels attribute, which indicates the number of output
            channels that each feature map has (and it should be the same for all feature maps).
            The backbone should return a single Tensor or an OrderedDict[Tensor].
        num_classes (int): number of output classes of the model (including the background).
        min_size (int): minimum size of the image to be rescaled before feeding it to the backbone
        max_size (int): maximum size of the image to be rescaled before feeding it to the backbone
        image_mean (Tuple[float, float, float]): mean values used for input normalization.
            They are generally the mean values of the dataset on which the backbone has been
            trained.
        image_std (Tuple[float, float, float]): std values used for input normalization.
            They are generally the std values of the dataset on which the backbone has been trained.
        anchor_generator (AnchorGenerator): module that generates the anchors for a set of feature
            maps.
        head (nn.Module): Module run on top of the feature pyramid.
            Defaults to a module containing a classification and regression module.
        score_thresh (float): Score threshold used for postprocessing the detections.
        nms_thresh (float): NMS threshold used for postprocessing the detections.
        detections_per_img (int): Number of best detections to keep after NMS.
        fg_iou_thresh (float): minimum IoU between the anchor and the GT box so that they can be
            considered as positive during training.
        bg_iou_thresh (float): maximum IoU between the anchor and the GT box so that they can be
            considered as negative during training.
        topk_candidates (int): Number of best detections to keep before NMS.

    Example:

        >>> import torch
        >>> import torchvision
        >>> from torchvision.models.detection import RetinaNet
        >>> from torchvision.models.detection.anchor_utils import AnchorGenerator
        >>> # load a pre-trained model for classification and return
        >>> # only the features
        >>> backbone = torchvision.models.mobilenet_v2(pretrained=True).features
        >>> # RetinaNet needs to know the number of
        >>> # output channels in a backbone. For mobilenet_v2, it's 1280
        >>> # so we need to add it here
        >>> backbone.out_channels = 1280
        >>>
        >>> # let's make the network generate 5 x 3 anchors per spatial
        >>> # location, with 5 different sizes and 3 different aspect
        >>> # ratios. We have a Tuple[Tuple[int]] because each feature
        >>> # map could potentially have different sizes and
        >>> # aspect ratios
        >>> anchor_generator = AnchorGenerator(
        >>>     sizes=((32, 64, 128, 256, 512),),
        >>>     aspect_ratios=((0.5, 1.0, 2.0),)
        >>> )
        >>>
        >>> # put the pieces together inside a RetinaNet model
        >>> model = RetinaNet(backbone,
        >>>                   num_classes=2,
        >>>                   anchor_generator=anchor_generator)
        >>> model.eval()
        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
        >>> predictions = model(x)
    """
    __annotations__ = {
        'box_coder': det_utils.BoxCoder,
        'proposal_matcher': det_utils.Matcher,
    }

    def __init__(
            self,
            backbone,
            num_classes,
            # transform parameters
            min_size=800,
            max_size=1333,
            image_mean=None,
            image_std=None,
            # Anchor parameters
            anchor_generator=None,
            head=None,
            proposal_matcher=None,
            score_thresh=0.05,
            nms_thresh=0.5,
            detections_per_img=300,
            fg_iou_thresh=0.5,
            bg_iou_thresh=0.4,
            topk_candidates=1000):
        super().__init__()

        if not hasattr(backbone, "out_channels"):
            raise ValueError(
                "backbone should contain an attribute out_channels "
                "specifying the number of output channels (assumed to be the "
                "same for all the levels)")
        self.backbone = backbone

        assert isinstance(anchor_generator, (AnchorGenerator, type(None)))

        if anchor_generator is None:
            anchor_sizes = tuple(
                (x, int(x * 2**(1.0 / 3)), int(x * 2**(2.0 / 3)))
                for x in [32, 64, 128, 256, 512])
            aspect_ratios = ((0.5, 1.0, 2.0), ) * len(anchor_sizes)
            anchor_generator = AnchorGenerator(anchor_sizes, aspect_ratios)
        self.anchor_generator = anchor_generator

        if head is None:
            head = RetinaNetHead(
                backbone.out_channels,
                anchor_generator.num_anchors_per_location()[0], num_classes)
        self.head = head

        if proposal_matcher is None:
            proposal_matcher = det_utils.Matcher(
                fg_iou_thresh,
                bg_iou_thresh,
                allow_low_quality_matches=True,
            )
        self.proposal_matcher = proposal_matcher

        self.box_coder = det_utils.BoxCoder(weights=(1.0, 1.0, 1.0, 1.0))

        if image_mean is None:
            image_mean = [0.485, 0.456, 0.406]
        if image_std is None:
            image_std = [0.229, 0.224, 0.225]
        self.transform = GeneralizedRCNNTransform(min_size, max_size,
                                                  image_mean, image_std)

        self.score_thresh = score_thresh
        self.nms_thresh = nms_thresh
        self.detections_per_img = detections_per_img
        self.topk_candidates = topk_candidates

        # used only on torchscript mode
        self._has_warned = False

    @torch.jit.unused
    def eager_outputs(self, losses, detections):
        # type: (Dict[str, Tensor], List[Dict[str, Tensor]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]
        if self.training:
            return losses

        return detections

    def compute_loss(self, targets, head_outputs, anchors):
        # type: (List[Dict[str, Tensor]], Dict[str, Tensor], List[Tensor]) -> Dict[str, Tensor]
        matched_idxs = []
        for anchors_per_image, targets_per_image in zip(anchors, targets):
            if targets_per_image['boxes'].numel() == 0:
                matched_idxs.append(
                    torch.full((anchors_per_image.size(0), ),
                               -1,
                               dtype=torch.int64,
                               device=anchors_per_image.device))
                continue

            match_quality_matrix = box_ops.box_iou(targets_per_image['boxes'],
                                                   anchors_per_image)
            matched_idxs.append(self.proposal_matcher(match_quality_matrix))

        return self.head.compute_loss(targets, head_outputs, anchors,
                                      matched_idxs)

    def postprocess_detections(self, head_outputs, anchors, image_shapes):
        # type: (Dict[str, List[Tensor]], List[List[Tensor]], List[Tuple[int, int]]) -> List[Dict[str, Tensor]]
        class_logits = head_outputs['cls_logits']
        box_regression = head_outputs['bbox_regression']

        num_images = len(image_shapes)

        detections: List[Dict[str, Tensor]] = []

        for index in range(num_images):
            box_regression_per_image = [br[index] for br in box_regression]
            logits_per_image = [cl[index] for cl in class_logits]
            anchors_per_image, image_shape = anchors[index], image_shapes[
                index]

            image_boxes = []
            image_scores = []
            image_labels = []

            for box_regression_per_level, logits_per_level, anchors_per_level in \
                    zip(box_regression_per_image, logits_per_image, anchors_per_image):
                num_classes = logits_per_level.shape[-1]

                # remove low scoring boxes
                scores_per_level = torch.sigmoid(logits_per_level).flatten()
                keep_idxs = scores_per_level > self.score_thresh
                scores_per_level = scores_per_level[keep_idxs]
                topk_idxs = torch.where(keep_idxs)[0]

                # keep only topk scoring predictions
                num_topk = min(self.topk_candidates, topk_idxs.size(0))
                scores_per_level, idxs = scores_per_level.topk(num_topk)
                topk_idxs = topk_idxs[idxs]

                anchor_idxs = torch.div(topk_idxs,
                                        num_classes,
                                        rounding_mode='floor')
                labels_per_level = topk_idxs % num_classes

                boxes_per_level = self.box_coder.decode_single(
                    box_regression_per_level[anchor_idxs],
                    anchors_per_level[anchor_idxs])
                boxes_per_level = box_ops.clip_boxes_to_image(
                    boxes_per_level, image_shape)

                image_boxes.append(boxes_per_level)
                image_scores.append(scores_per_level)
                image_labels.append(labels_per_level)

            image_boxes = torch.cat(image_boxes, dim=0)
            image_scores = torch.cat(image_scores, dim=0)
            image_labels = torch.cat(image_labels, dim=0)

            # non-maximum suppression
            keep = box_ops.batched_nms(image_boxes, image_scores, image_labels,
                                       self.nms_thresh)
            keep = keep[:self.detections_per_img]

            detections.append({
                'boxes': image_boxes[keep],
                'scores': image_scores[keep],
                'labels': image_labels[keep],
            })

        return detections

    def forward(self, images, targets=None):
        # type: (List[Tensor], Optional[List[Dict[str, Tensor]]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]
        """
        Args:
            images (list[Tensor]): images to be processed
            targets (list[Dict[Tensor]]): ground-truth boxes present in the image (optional)

        Returns:
            result (list[Dict[Tensor]] or dict[Tensor]): the output from the model.
                During training, it returns a dict[Tensor] which contains the losses.
                During testing, it returns a list[Dict[Tensor]], one dict per input image,
                with fields such as `boxes`, `scores` and `labels`.

        """
        if self.training and targets is None:
            raise ValueError("In training mode, targets should be passed")

        if self.training:
            assert targets is not None
            for target in targets:
                boxes = target["boxes"]
                if isinstance(boxes, torch.Tensor):
                    if len(boxes.shape) != 2 or boxes.shape[-1] != 4:
                        raise ValueError("Expected target boxes to be a tensor"
                                         "of shape [N, 4], got {:}.".format(
                                             boxes.shape))
                else:
                    raise ValueError("Expected target boxes to be of type "
                                     "Tensor, got {:}.".format(type(boxes)))

        # get the original image sizes
        original_image_sizes: List[Tuple[int, int]] = []
        for img in images:
            val = img.shape[-2:]
            assert len(val) == 2
            original_image_sizes.append((val[0], val[1]))

        # transform the input
        images, targets = self.transform(images, targets)

        # Check for degenerate boxes
        # TODO: Move this to a function
        if targets is not None:
            for target_idx, target in enumerate(targets):
                boxes = target["boxes"]
                degenerate_boxes = boxes[:, 2:] <= boxes[:, :2]
                if degenerate_boxes.any():
                    # print the first degenerate box
                    bb_idx = torch.where(degenerate_boxes.any(dim=1))[0][0]
                    degen_bb: List[float] = boxes[bb_idx].tolist()
                    raise ValueError(
                        "All bounding boxes should have positive height and width."
                        " Found invalid box {} for target at index {}.".format(
                            degen_bb, target_idx))

        # get the features from the backbone
        features = self.backbone(images.tensors)
        if isinstance(features, torch.Tensor):
            features = OrderedDict([('0', features)])

        # TODO: Do we want a list or a dict?
        features = list(features.values())
        for idx in range(len(features)):
            features[idx] = features[idx].to(torch.float32)

        # compute the retinanet heads outputs using the features
        head_outputs = self.head(features)

        for key in head_outputs:
            head_outputs[key] = head_outputs[key].to(torch.float32)

        # create the set of anchors
        anchors = self.anchor_generator(images, features)

        losses = {}
        detections: List[Dict[str, Tensor]] = []
        if self.training:
            assert targets is not None

            # compute the losses
            losses = self.compute_loss(targets, head_outputs, anchors)
        else:
            # recover level sizes
            num_anchors_per_level = [x.size(2) * x.size(3) for x in features]
            HW = 0
            for v in num_anchors_per_level:
                HW += v
            HWA = head_outputs['cls_logits'].size(1)
            A = HWA // HW
            num_anchors_per_level = [hw * A for hw in num_anchors_per_level]

            # split outputs per level
            split_head_outputs: Dict[str, List[Tensor]] = {}
            for k in head_outputs:
                split_head_outputs[k] = list(head_outputs[k].split(
                    num_anchors_per_level, dim=1))
            split_anchors = [
                list(a.split(num_anchors_per_level)) for a in anchors
            ]

            # compute the detections
            detections = self.postprocess_detections(split_head_outputs,
                                                     split_anchors,
                                                     images.image_sizes)
            detections = self.transform.postprocess(detections,
                                                    images.image_sizes,
                                                    original_image_sizes)

        if torch.jit.is_scripting():
            if not self._has_warned:
                warnings.warn(
                    "RetinaNet always returns a (Losses, Detections) tuple in scripting"
                )
                self._has_warned = True
            return losses, detections
        return self.eager_outputs(losses, detections)
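
A minimal training-step sketch for the RetinaNet class above (not from the original file), reusing the backbone and anchor setup from the docstring example; the optimizer settings and box values are illustrative only.

import torch
import torchvision
from torchvision.models.detection.anchor_utils import AnchorGenerator

backbone = torchvision.models.mobilenet_v2(pretrained=True).features
backbone.out_channels = 1280
anchor_generator = AnchorGenerator(sizes=((32, 64, 128, 256, 512),),
                                   aspect_ratios=((0.5, 1.0, 2.0),))
model = RetinaNet(backbone, num_classes=2,
                  anchor_generator=anchor_generator)
model.train()

images = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
targets = [{'boxes': torch.tensor([[50., 60., 200., 220.]]),
            'labels': torch.tensor([1])} for _ in images]

optimizer = torch.optim.SGD(model.parameters(), lr=0.005, momentum=0.9)
loss_dict = model(images, targets)  # dict with classification/regression losses
loss = sum(loss_dict.values())
optimizer.zero_grad()
loss.backward()
optimizer.step()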
Code example #3
class YOLO(nn.Module):
    def __init__(
        self,
        backbone: nn.Module,
        num_classes: int,
        anchor_grids: List[List[int]],
        # transform parameters
        min_size: int = 320,
        max_size: int = 416,
        image_mean: Optional[List[float]] = None,
        image_std: Optional[List[float]] = None,
        # Anchor parameters
        anchor_generator: Optional[nn.Module] = None,
        head: Optional[nn.Module] = None,
        # Training parameter
        compute_loss: Optional[nn.Module] = None,
        fg_iou_thresh: float = 0.5,
        bg_iou_thresh: float = 0.4,
        # Post Process parameter
        postprocess_detections: Optional[nn.Module] = None,
        score_thresh: float = 0.05,
        nms_thresh: float = 0.5,
        detections_per_img: int = 300,
    ):
        super().__init__()
        if not hasattr(backbone, "out_channels"):
            raise ValueError(
                "backbone should contain an attribute out_channels "
                "specifying the number of output channels (assumed to be the "
                "same for all the levels)")
        self.backbone = backbone

        if anchor_generator is None:
            strides: List[int] = [8, 16, 32]
            anchor_generator = AnchorGenerator(strides, anchor_grids)
        self.anchor_generator = anchor_generator

        if compute_loss is None:
            compute_loss = SetCriterion(
                weights=(1.0, 1.0, 1.0, 1.0),
                fg_iou_thresh=fg_iou_thresh,
                bg_iou_thresh=bg_iou_thresh,
            )
        self.compute_loss = compute_loss

        if head is None:
            head = YoloHead(
                backbone.out_channels,
                anchor_generator.num_anchors,
                num_classes,
            )
        self.head = head

        if image_mean is None:
            image_mean = [0., 0., 0.]
        if image_std is None:
            image_std = [1., 1., 1.]

        self.transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std)

        if postprocess_detections is None:
            postprocess_detections = PostProcess(score_thresh, nms_thresh, detections_per_img)
        self.postprocess_detections = postprocess_detections

        # used only on torchscript mode
        self._has_warned = False

    @torch.jit.unused
    def eager_outputs(
        self,
        losses: Dict[str, Tensor],
        detections: List[Dict[str, Tensor]],
    ) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]:
        if self.training:
            return losses

        return detections

    def forward(
        self,
        images: List[Tensor],
        targets: Optional[List[Dict[str, Tensor]]] = None,
    ) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]:
        """
        Arguments:
            images (list[Tensor]): images to be processed
            targets (list[Dict[Tensor]]): ground-truth boxes present in the image (optional)

        Returns:
            result (list[Dict[Tensor]] or dict[Tensor]): the output from the model.
                During training, it returns a dict[Tensor] which contains the losses.
                TODO: currently this repo doesn't support training.
                During testing, it returns a list[Dict[Tensor]] containing additional
                fields like `scores` and `labels`.
        """
        # get the original image sizes
        original_image_sizes = torch.jit.annotate(List[Tuple[int, int]], [])
        for img in images:
            val = img.shape[-2:]
            assert len(val) == 2
            original_image_sizes.append((val[0], val[1]))

        # transform the input
        images, targets = self.transform(images, targets)

        # get the features from the backbone
        features = self.backbone(images.tensors)

        # compute the yolo heads outputs using the features
        head_outputs = self.head(features)

        # create the set of anchors
        anchors_tuple = self.anchor_generator(features)
        losses = {}
        detections = torch.jit.annotate(List[Dict[str, Tensor]], [])

        if self.training:
            assert targets is not None

            # compute the losses
            losses = self.compute_loss(targets, head_outputs, anchors_tuple[0])
        else:
            # compute the detections
            detections = self.postprocess_detections(head_outputs, anchors_tuple, images.image_sizes)
            detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes)

        if torch.jit.is_scripting():
            if not self._has_warned:
                warnings.warn("YOLO always returns a (Losses, Detections) tuple in scripting")
                self._has_warned = True
            return losses, detections
        else:
            return self.eager_outputs(losses, detections)
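
A hypothetical calling convention for the YOLO class above, inference only since the snippet notes that training is not yet supported; it assumes `model` is a YOLO instance already built elsewhere in the project with a suitable backbone and anchor_grids.

import torch

model.eval()  # assumes `model` is a constructed YOLO instance
images = [torch.rand(3, 416, 352), torch.rand(3, 384, 416)]
with torch.no_grad():
    detections = model(images)
# each entry is a dict with at least 'boxes', 'scores' and 'labels'
print(detections[0]['boxes'].shape)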
Code example #4
def evaluate_yolo_2017(model, data_loader, device):
    n_threads = torch.get_num_threads()
    # FIXME remove this and make paste_masks_in_image run on the GPU
    torch.set_num_threads(1)
    cpu_device = torch.device("cpu")
    model.eval()
    metric_logger = utils.MetricLogger(delimiter="  ")
    header = 'Test:'

    coco = get_coco_api_from_dataset(data_loader.dataset)
    iou_types = _get_iou_types(model)
    coco_evaluator = CocoEvaluator(coco, iou_types)
    transform = GeneralizedRCNNTransform(416, 416, [0, 0, 0], [1, 1, 1])
    transform.eval()
    for image, targets in metric_logger.log_every(data_loader, 100, header):
        image = list(img.to(device) for img in image)

        original_image_sizes = [img.shape[-2:] for img in image]

        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        torch.cuda.synchronize()
        model_time = time.time()
        transformed_img = transform(image)
        transformed_shape = transformed_img[0].tensors.shape[-2:]
        inf_out, _ = model(transformed_img[0].tensors)
        # Run NMS
        output = non_max_suppression(inf_out, conf_thres=0.001, iou_thres=0.6)

        # Statistics per image
        predictions = []
        for si, pred in enumerate(output):
            prediction = {'boxes': [], 'labels': [], 'scores': []}
            if pred is None:
                # keep predictions aligned with targets even when NMS
                # returned nothing for this image
                predictions.append({'boxes': torch.zeros((0, 4)),
                                    'labels': torch.zeros(0, dtype=torch.int64),
                                    'scores': torch.zeros(0)})
                continue
            # Append to text file
            # with open('test.txt', 'a') as file:
            #    [file.write('%11.5g' * 7 % tuple(x) + '\n') for x in pred]

            # Clip boxes to image bounds
            clip_coords(pred, transformed_shape)
            # Append to pycocotools JSON dictionary
            # [{"image_id": 42, "category_id": 18, "bbox": [258.15, 41.29, 348.26, 243.78], "score": 0.236}, ...
            image_id = int(targets[si]['image_id'])
            box = pred[:, :4].clone()  # xyxy
            # scale_coords(transformed_shape, box, shapes[si][0], shapes[si][1])  # to original shape
            # box = xyxy2xywh(box)  # xywh
            # box[:, :2] -= box[:, 2:] / 2  # xy center to top-left corner
            for di, d in enumerate(pred):
                box_T = [floatn(x, 3) for x in box[di]]
                label = coco91class[int(d[5])]
                score = floatn(d[4], 5)
                prediction['boxes'].append(box_T)
                prediction['labels'].append(label)
                prediction['scores'].append(score)
            prediction['boxes'] = torch.tensor(prediction['boxes'])
            prediction['labels'] = torch.tensor(prediction['labels'])
            prediction['scores'] = torch.tensor(prediction['scores'])
            predictions.append(prediction)

        outputs = transform.postprocess(predictions,
                                        transformed_img[0].image_sizes,
                                        original_image_sizes)

        outputs = [{k: v.to(cpu_device)
                    for k, v in t.items()} for t in outputs]
        model_time = time.time() - model_time

        res = {
            target["image_id"].item(): output
            for target, output in zip(targets, outputs)
        }
        evaluator_time = time.time()
        coco_evaluator.update(res)
        evaluator_time = time.time() - evaluator_time
        metric_logger.update(model_time=model_time,
                             evaluator_time=evaluator_time)

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    coco_evaluator.synchronize_between_processes()

    # accumulate predictions from all images
    coco_evaluator.accumulate()
    coco_evaluator.summarize()
    torch.set_num_threads(n_threads)
    return coco_evaluator
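
Every example on this page leans on GeneralizedRCNNTransform, so here is a small self-contained sketch (not from any of the projects above) of the round trip it performs: resize and normalize the inputs into a padded batch on the way in, then map boxes back to the original image sizes with postprocess on the way out. The box values are made up for illustration.

import torch
from torchvision.models.detection.transform import GeneralizedRCNNTransform

transform = GeneralizedRCNNTransform(min_size=416, max_size=416,
                                     image_mean=[0., 0., 0.],
                                     image_std=[1., 1., 1.])
transform.eval()

images = [torch.rand(3, 480, 640)]
original_image_sizes = [img.shape[-2:] for img in images]

image_list, _ = transform(images, None)
print(image_list.tensors.shape)   # padded, size-divisible batch tensor
print(image_list.image_sizes)     # per-image (H, W) after resizing, before padding

# a fake detection in transformed coordinates, mapped back to the 480x640 frame
preds = [{'boxes': torch.tensor([[10., 20., 100., 200.]]),
          'scores': torch.tensor([0.9]),
          'labels': torch.tensor([1])}]
preds = transform.postprocess(preds, image_list.image_sizes,
                              original_image_sizes)
print(preds[0]['boxes'])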
Code example #5
class Retinanet(nn.Module):
    """
    Implement RetinaNet in :paper:`RetinaNet`.

    The input to the model is expected to be a list of tensors, each of shape [C, H, W], one for each
    image, and should be in 0-1 range. Different images can have different sizes.

    The behavior of the model changes depending on whether it is in training or evaluation mode.

    During training, the model expects both the input tensors and the targets (a list of dictionaries),
    containing:
        - boxes (FloatTensor[N, 4]): the ground-truth boxes in [x1, y1, x2, y2] format, with values of x
          between 0 and W and values of y between 0 and H
        - labels (Int64Tensor[N]): the class label for each ground-truth box

    The model returns a Dict[Tensor] during training, containing the `classification` and `regression` losses for
    the RetinaNet `classSubnet` and `BoxSubnet` respectively.

    For inference, use `.predict`.
    During inference, the model requires only the input tensors, and returns the post-processed
    predictions as a List[Dict[Tensor]], one for each input image. The fields of the Dict are as
    follows:
        - boxes (FloatTensor[N, 4]): the predicted boxes in [x1, y1, x2, y2] format, with values of x
          between 0 and W and values of y between 0 and H
        - labels (Int64Tensor[N]): the predicted labels for each image
        - scores (Tensor[N]): the scores for each prediction

    Arguments:
        - num_classes   (int): number of output classes of the model (excluding the background).

        - backbone_kind (str): the network used to compute the features for the model.
                               currently support only `Resnet` networks.
        - prior       (float): Prior probability for the rare (i.e. foreground) class at the beginning of training.
        - pretrained   (bool): Whether the backbone should be `pretrained` or not.
        - nms_thres   (float): Overlap threshold used for non-maximum suppression
                               (suppress boxes with IoU >= this threshold).
        - score_thres (float): Minimum score threshold (assuming scores in a [0, 1] range).
        - max_detections_per_images (int): Number of proposals to keep after applying NMS.
        - freeze_bn   (bool) : Whether to freeze the `BatchNorm` layers of the `BackBone` network.
        - anchor_generator (AnchorGenerator): Must be an instance of `AnchorGenerator`.
                                            If None the default AnchorGenerator is used.
                                            see `config.py`
        - min_size (int)     : `minimum size` of the image to be rescaled before
                               feeding it to the backbone.
        - max_size (int)     : `maximum size` of the image to be rescaled before
                               feeding it to the backbone.
        - image_mean (List[float]): mean values used for input normalization.
        - image_std (List[float]) : std values used for input normalization.

    >>> For default values see `config.py`
    """

    def __init__(
        self,
        num_classes: Optional[int] = None,
        backbone_kind: Optional[str] = None,
        prior: Optional[float] = None,
        pretrained: Optional[bool] = None,
        nms_thres: Optional[float] = None,
        score_thres: Optional[float] = None,
        max_detections_per_images: Optional[int] = None,
        freeze_bn: Optional[bool] = None,
        min_size: Optional[int] = None,
        max_size: Optional[int] = None,
        image_mean: Optional[List[float]] = None,
        image_std: Optional[List[float]] = None,
        anchor_generator: Optional[AnchorGenerator] = None,
        logger=None,
    ) -> None:

        super(Retinanet, self).__init__()

        # Set Parameters
        num_classes = ifnone(num_classes, NUM_CLASSES)
        backbone_kind = ifnone(backbone_kind, BACKBONE)
        prior = ifnone(prior, PRIOR)
        pretrained = ifnone(pretrained, PRETRAINED_BACKBONE)
        nms_thres = ifnone(nms_thres, NMS_THRES)
        score_thres = ifnone(score_thres, SCORE_THRES)
        max_detections_per_images = ifnone(max_detections_per_images, MAX_DETECTIONS_PER_IMAGE)
        freeze_bn = ifnone(freeze_bn, FREEZE_BN)
        min_size = ifnone(min_size, MIN_IMAGE_SIZE)
        max_size = ifnone(max_size, MAX_IMAGE_SIZE)
        image_mean = ifnone(image_mean, MEAN)
        image_std = ifnone(image_std, STD)
        anchor_generator = ifnone(anchor_generator, AnchorGenerator())
        logger = ifnone(logger, logging.getLogger(__name__))
        logger.name = __name__

        if backbone_kind not in __small__ + __big__:
            _prompt = f"Expected `backbone_kind` to be one of {__small__+__big__} got {backbone_kind}"
            raise ValueError(_prompt)

        # Instantiate modules for RetinaNet
        self.backbone_kind = backbone_kind
        self.transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std)
        self.backbone = get_backbone(backbone_kind, pretrained, freeze_bn=freeze_bn)
        fpn_szs = self._get_backbone_ouputs()
        self.fpn = FeaturePyramid(fpn_szs[0], fpn_szs[1], fpn_szs[2], 256)
        self.anchor_generator = anchor_generator
        num_anchors = self.anchor_generator.num_cell_anchors[0]
        self.retinanet_head = RetinaNetHead(256, 256, num_anchors, num_classes, prior)

        # Parameters for detection
        self.score_thres        = score_thres
        self.nms_thres          = nms_thres
        self.detections_per_img = max_detections_per_images
        self.num_classes        = num_classes

        # Log some information
        logger.info(f"BACKBONE     : {backbone_kind}")
        logger.info(f"INPUT_PARAMS : MAX_SIZE={max_size}, MIN_SIZE={min_size}")
        logger.info(f"NUM_CLASSES  : {self.num_classes}")

    def _get_backbone_ouputs(self) -> List:
        if self.backbone_kind in __small__:
            fpn_szs = [
                self.backbone.backbone.layer2[1].conv2.out_channels,
                self.backbone.backbone.layer3[1].conv2.out_channels,
                self.backbone.backbone.layer4[1].conv2.out_channels,
            ]
            return fpn_szs

        elif self.backbone_kind in __big__:
            fpn_szs = [
                self.backbone.backbone.layer2[2].conv3.out_channels,
                self.backbone.backbone.layer3[2].conv3.out_channels,
                self.backbone.backbone.layer4[2].conv3.out_channels,
            ]
            return fpn_szs

    def compute_loss(
        self,
        targets: List[Dict[str, Tensor]],
        outputs: Dict[str, Tensor],
        anchors: List[Tensor],
    ) -> Dict[str, Tensor]:
        return self.retinanet_head.compute_loss(targets, outputs, anchors)

    def process_detections(
        self,
        outputs: Dict[str, Tensor],
        anchors: List[Tensor],
        im_szs: List[Tuple[int, int]],
    ) -> List[Dict[str, Tensor]]:
        " Process `outputs` and return the predicted bboxes, score, clas_labels above `detect_thres` "

        class_logits = outputs.pop("cls_preds")
        bboxes = outputs.pop("bbox_preds")
        scores = torch.sigmoid(class_logits)

        device = class_logits.device
        num_classes = class_logits.shape[-1]

        # create labels for each score
        labels = torch.arange(num_classes, device=device)
        labels = labels.view(1, -1).expand_as(scores)

        detections = torch.jit.annotate(List[Dict[str, Tensor]], [])

        for bb_per_im, sc_per_im, ancs_per_im, im_sz, lbl_per_im in zip(bboxes, scores, anchors, im_szs, labels):
            
            all_boxes = []
            all_scores = []
            all_labels = []
            # convert the activation i.e, outputs of the model to bounding boxes
            bb_per_im = activ_2_bbox(bb_per_im, ancs_per_im)
            # clip the bounding boxes to the image size
            bb_per_im = ops.clip_boxes_to_image(bb_per_im, im_sz)

            # Iterate over each `cls_idx` in `num_classes` and apply nms
            # to each class individually
            for cls_idx in range(num_classes):
                # remove low-scoring predictions (score < score_thres)
                # and grab the predictions corresponding to this cls_idx
                inds = torch.gt(sc_per_im[:, cls_idx], self.score_thres)
                bb_per_cls, sc_per_cls, lbl_per_cls = (
                    bb_per_im[inds],
                    sc_per_im[inds, cls_idx],
                    lbl_per_im[inds, cls_idx],
                )
                # remove boxes that are too small (side < 1e-2)
                keep = ops.remove_small_boxes(bb_per_cls, min_size=1e-2)
                bb_per_cls, sc_per_cls, lbl_per_cls = (
                    bb_per_cls[keep],
                    sc_per_cls[keep],
                    lbl_per_cls[keep],
                )
                # apply non-maximum suppression to suppress overlapping boxes
                keep = ops.nms(bb_per_cls, sc_per_cls, self.nms_thres)
                bb_per_cls, sc_per_cls, lbl_per_cls = (
                    bb_per_cls[keep],
                    sc_per_cls[keep],
                    lbl_per_cls[keep],
                )

                all_boxes.append(bb_per_cls)
                all_scores.append(sc_per_cls)
                all_labels.append(lbl_per_cls)

            # Convert to tensors
            all_boxes = torch.cat(all_boxes, dim=0)
            all_scores = torch.cat(all_scores, dim=0)
            all_labels = torch.cat(all_labels, dim=0)

            # the model predicts class indices in the range [0, num_classes);
            # 0 is reserved for the background class, for which no loss is calculated, so
            # we add 1 to all class predictions to shift the range from
            # [0, num_classes) to [1, num_classes]
            all_labels = all_labels + 1

            # Sort by scores and grab the indices
            # corresponding to the topk predictions
            _, topk_idxs = all_scores.sort(descending=True)
            topk_idxs = topk_idxs[: self.detections_per_img]
            all_boxes, all_scores, all_labels = (
                all_boxes[topk_idxs],
                all_scores[topk_idxs],
                all_labels[topk_idxs],
            )

            detections.append({"boxes": all_boxes, "scores": all_scores, "labels": all_labels,})
        return detections

    def predict(self, images: List[Tensor]) -> List[Dict[str, Tensor]]:
        """
        Computes predictions for the given model
        """
        # switch the model (and all of its submodules) to eval mode
        if self.training:
            self.eval()
        
        targets = None
        # get the original image sizes
        original_image_sizes = []
        for img in images:
            val = img.shape[-2:]
            assert len(val) == 2
            original_image_sizes.append((val[0], val[1]))
        
        # Forward pass of the model
        images, targets = self.transform(images, targets)
        feature_maps    = self.backbone(images.tensors)
        feature_maps    = self.fpn(feature_maps)
        outputs         = self.retinanet_head(feature_maps)
        anchors         = self.anchor_generator(images, feature_maps)
        
        detections       = torch.jit.annotate(List[Dict[str, Tensor]], [])
        # compute the detections
        detections       = self.process_detections(outputs, anchors, images.image_sizes)
        final_detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes)
        return final_detections

    def forward(self, images: List[Tensor], targets: Optional[List[Dict[str, Tensor]]]) -> Dict[str, Tensor]:
        """
        Computes the loss of the model
        """
        # Forward pass of the model
        images, targets = self.transform(images, targets)
        feature_maps    = self.backbone(images.tensors)
        feature_maps    = self.fpn(feature_maps)
        outputs         = self.retinanet_head(feature_maps)
        # Generate anchors for the images
        anchors         = self.anchor_generator(images, feature_maps)
        # store losses
        losses = {}
        losses = self.compute_loss(targets, outputs, anchors)
        return losses
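
A hypothetical usage sketch for the Retinanet class above: every default is pulled from the project's config.py via ifnone, and 'resnet18' is only an assumed member of the project's __small__ backbone list.

import torch

model = Retinanet(num_classes=3, backbone_kind='resnet18')  # assumed backbone name

# training step: forward() returns the loss dict computed by the head
images = [torch.rand(3, 512, 512)]
targets = [{'boxes': torch.tensor([[30., 40., 200., 260.]]),
            'labels': torch.tensor([1])}]
losses = model(images, targets)

# inference: predict() switches to eval mode, decodes the head outputs,
# runs per-class NMS and maps boxes back to the original image sizes
detections = model.predict(images)
print(detections[0]['boxes'].shape)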
Code example #6
File: retina_ssm.py  Project: wolfworld6/CALD
class RetinaNet(nn.Module):
    """
    Implements RetinaNet.
    The input to the model is expected to be a list of tensors, each of shape [C, H, W], one for each
    image, and should be in 0-1 range. Different images can have different sizes.
    The behavior of the model changes depending on whether it is in training or evaluation mode.
    During training, the model expects both the input tensors and the targets (a list of dictionaries),
    containing:
        - boxes (FloatTensor[N, 4]): the ground-truth boxes in [x1, y1, x2, y2] format, with values of x
          between 0 and W and values of y between 0 and H
        - labels (Int64Tensor[N]): the class label for each ground-truth box
    The model returns a Dict[Tensor] during training, containing the classification and regression
    losses.
    During inference, the model requires only the input tensors, and returns the post-processed
    predictions as a List[Dict[Tensor]], one for each input image. The fields of the Dict are as
    follows:
        - boxes (FloatTensor[N, 4]): the predicted boxes in [x1, y1, x2, y2] format, with values of x
          between 0 and W and values of y between 0 and H
        - labels (Int64Tensor[N]): the predicted labels for each image
        - scores (Tensor[N]): the scores for each prediction
    Arguments:
        backbone (nn.Module): the network used to compute the features for the model.
            It should contain an out_channels attribute, which indicates the number of output
            channels that each feature map has (and it should be the same for all feature maps).
            The backbone should return a single Tensor or an OrderedDict[Tensor].
        num_classes (int): number of output classes of the model (excluding the background).
        min_size (int): minimum size of the image to be rescaled before feeding it to the backbone
        max_size (int): maximum size of the image to be rescaled before feeding it to the backbone
        image_mean (Tuple[float, float, float]): mean values used for input normalization.
            They are generally the mean values of the dataset on which the backbone has been
            trained.
        image_std (Tuple[float, float, float]): std values used for input normalization.
            They are generally the std values of the dataset on which the backbone has been trained.
        anchor_generator (AnchorGenerator): module that generates the anchors for a set of feature
            maps.
        head (nn.Module): Module run on top of the feature pyramid.
            Defaults to a module containing a classification and regression module.
        score_thresh (float): Score threshold used for postprocessing the detections.
        nms_thresh (float): NMS threshold used for postprocessing the detections.
        detections_per_img (int): Number of best detections to keep after NMS.
        fg_iou_thresh (float): minimum IoU between the anchor and the GT box so that they can be
            considered as positive during training.
        bg_iou_thresh (float): maximum IoU between the anchor and the GT box so that they can be
            considered as negative during training.
    Example:
        >>> import torch
        >>> import torchvision
        >>> from torchvision.models.detection import RetinaNet
        >>> from torchvision.models.detection.anchor_utils import AnchorGenerator
        >>> # load a pre-trained model for classification and return
        >>> # only the features
        >>> backbone = torchvision.models.mobilenet_v2(pretrained=True).features
        >>> # RetinaNet needs to know the number of
        >>> # output channels in a backbone. For mobilenet_v2, it's 1280
        >>> # so we need to add it here
        >>> backbone.out_channels = 1280
        >>>
        >>> # let's make the network generate 5 x 3 anchors per spatial
        >>> # location, with 5 different sizes and 3 different aspect
        >>> # ratios. We have a Tuple[Tuple[int]] because each feature
        >>> # map could potentially have different sizes and
        >>> # aspect ratios
        >>> anchor_generator = AnchorGenerator(
        >>>     sizes=((32, 64, 128, 256, 512),),
        >>>     aspect_ratios=((0.5, 1.0, 2.0),)
        >>> )
        >>>
        >>> # put the pieces together inside a RetinaNet model
        >>> model = RetinaNet(backbone,
        >>>                   num_classes=2,
        >>>                   anchor_generator=anchor_generator)
        >>> model.eval()
        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
        >>> predictions = model(x)
    """
    __annotations__ = {
        'box_coder': det_utils.BoxCoder,
        'proposal_matcher': det_utils.Matcher,
    }

    def __init__(
            self,
            backbone,
            num_classes,
            # transform parameters
            min_size=800,
            max_size=1333,
            image_mean=None,
            image_std=None,
            # Anchor parameters
            anchor_generator=None,
            head=None,
            proposal_matcher=None,
            score_thresh=0.05,
            nms_thresh=0.5,
            detections_per_img=50,
            fg_iou_thresh=0.5,
            bg_iou_thresh=0.4):
        super().__init__()

        if not hasattr(backbone, "out_channels"):
            raise ValueError(
                "backbone should contain an attribute out_channels "
                "specifying the number of output channels (assumed to be the "
                "same for all the levels)")
        self.backbone = backbone

        assert isinstance(anchor_generator, (AnchorGenerator, type(None)))

        if anchor_generator is None:
            anchor_sizes = tuple(
                (x, int(x * 2**(1.0 / 3)), int(x * 2**(2.0 / 3)))
                for x in [32, 64, 128, 256, 512])
            aspect_ratios = ((0.5, 1.0, 2.0), ) * len(anchor_sizes)
            anchor_generator = AnchorGenerator(anchor_sizes, aspect_ratios)
        self.anchor_generator = anchor_generator

        if head is None:
            head = RetinaNetHead(
                backbone.out_channels,
                anchor_generator.num_anchors_per_location()[0], num_classes)
        self.head = head

        if proposal_matcher is None:
            proposal_matcher = det_utils.Matcher(
                fg_iou_thresh,
                bg_iou_thresh,
                allow_low_quality_matches=True,
            )
        self.proposal_matcher = proposal_matcher

        self.box_coder = det_utils.BoxCoder(weights=(1.0, 1.0, 1.0, 1.0))

        if image_mean is None:
            image_mean = [0.485, 0.456, 0.406]
        if image_std is None:
            image_std = [0.229, 0.224, 0.225]
        self.transform = GeneralizedRCNNTransform(min_size, max_size,
                                                  image_mean, image_std)

        self.score_thresh = score_thresh
        self.nms_thresh = nms_thresh
        self.detections_per_img = detections_per_img
        self.ssm = False
        # used only on torchscript mode
        self._has_warned = False

    @torch.jit.unused
    def eager_outputs(self, losses, detections):
        # type: (Dict[str, Tensor], List[Dict[str, Tensor]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]
        if self.training:
            return losses

        return detections

    def ssm_mode(self, ssm):
        self.ssm = ssm

    def compute_loss(self, targets, head_outputs, anchors):
        # type: (List[Dict[str, Tensor]], Dict[str, Tensor], List[Tensor]) -> Dict[str, Tensor]
        matched_idxs = []
        for anchors_per_image, targets_per_image in zip(anchors, targets):
            if targets_per_image['boxes'].numel() == 0:
                # no ground-truth boxes: mark every anchor as background (-1)
                matched_idxs.append(
                    torch.full((anchors_per_image.size(0), ),
                               -1,
                               dtype=torch.int64,
                               device=anchors_per_image.device))
                continue

            match_quality_matrix = box_ops.box_iou(targets_per_image['boxes'],
                                                   anchors_per_image)
            matched_idxs.append(self.proposal_matcher(match_quality_matrix))

        return self.head.compute_loss(targets, head_outputs, anchors,
                                      matched_idxs)

    def postprocess_detections(self, head_outputs, anchors, image_shapes):
        # type: (Dict[str, Tensor], List[Tensor], List[Tuple[int, int]]) -> List[Dict[str, Tensor]]
        # TODO: Merge this with roi_heads.RoIHeads.postprocess_detections ?

        class_logits = head_outputs.pop('cls_logits')
        box_regression = head_outputs.pop('bbox_regression')
        other_outputs = head_outputs

        device = class_logits.device
        num_classes = class_logits.shape[-1]

        scores = torch.sigmoid(class_logits)

        # create labels for each score
        labels = torch.arange(num_classes, device=device)
        labels = labels.view(1, -1).expand_as(scores)

        detections = torch.jit.annotate(List[Dict[str, Tensor]], [])

        for index, (box_regression_per_image, scores_per_image, labels_per_image, anchors_per_image, image_shape) in \
                enumerate(zip(box_regression, scores, labels, anchors, image_shapes)):

            boxes_per_image = self.box_coder.decode_single(
                box_regression_per_image, anchors_per_image)
            boxes_per_image = box_ops.clip_boxes_to_image(
                boxes_per_image, image_shape)

            other_outputs_per_image = [(k, v[index])
                                       for k, v in other_outputs.items()]

            image_boxes = []
            image_scores = []
            image_labels = []
            image_other_outputs = torch.jit.annotate(Dict[str, List[Tensor]],
                                                     {})
            for class_index in range(num_classes):
                # remove low scoring boxes
                inds = torch.gt(scores_per_image[:, class_index],
                                self.score_thresh)
                boxes_per_class, scores_per_class, labels_per_class = \
                    boxes_per_image[inds], scores_per_image[inds, class_index], labels_per_image[inds, class_index]
                other_outputs_per_class = [(k, v[inds])
                                           for k, v in other_outputs_per_image]

                # remove empty boxes
                keep = box_ops.remove_small_boxes(boxes_per_class,
                                                  min_size=1e-2)
                boxes_per_class, scores_per_class, labels_per_class = \
                    boxes_per_class[keep], scores_per_class[keep], labels_per_class[keep]
                other_outputs_per_class = [(k, v[keep])
                                           for k, v in other_outputs_per_class]

                # non-maximum suppression, independently done per class
                keep = box_ops.nms(boxes_per_class, scores_per_class,
                                   self.nms_thresh)

                # keep only topk scoring predictions
                keep = keep[:self.detections_per_img]
                boxes_per_class, scores_per_class, labels_per_class = \
                    boxes_per_class[keep], scores_per_class[keep], labels_per_class[keep]
                other_outputs_per_class = [(k, v[keep])
                                           for k, v in other_outputs_per_class]

                image_boxes.append(boxes_per_class)
                image_scores.append(scores_per_class)
                image_labels.append(labels_per_class)

                for k, v in other_outputs_per_class:
                    if k not in image_other_outputs:
                        image_other_outputs[k] = []
                    image_other_outputs[k].append(v)

            detections.append({
                'boxes': torch.cat(image_boxes, dim=0),
                'scores': torch.cat(image_scores, dim=0),
                'labels': torch.cat(image_labels, dim=0),
            })

            for k, v in image_other_outputs.items():
                detections[-1].update({k: torch.cat(v, dim=0)})

        return detections

    def ssm_postprocess_detections(self, head_outputs, anchors, image_shapes):
        # type: (Dict[str, Tensor], List[Tensor], List[Tuple[int, int]]) -> List[Dict[str, Tensor]]
        # TODO: Merge this with roi_heads.RoIHeads.postprocess_detections ?

        class_logits = head_outputs.pop('cls_logits')
        box_regression = head_outputs.pop('bbox_regression')
        other_outputs = head_outputs

        device = class_logits.device
        num_classes = class_logits.shape[-1]

        scores = torch.sigmoid(class_logits)

        # create labels for each score
        labels = torch.arange(num_classes, device=device)
        labels = labels.view(1, -1).expand_as(scores)

        detections = torch.jit.annotate(List[Dict[str, Tensor]], [])
        al_idx = 0
        all_boxes = torch.empty([0, 4]).cuda()
        all_scores = torch.tensor([]).cuda()
        all_labels = []
        CONF_THRESH = 0.5  # a larger value yields more active-learning samples
        for index, (box_regression_per_image, scores_per_image, labels_per_image, anchors_per_image, image_shape) in \
                enumerate(zip(box_regression, scores, labels, anchors, image_shapes)):
            if torch.max(scores_per_image) < CONF_THRESH:
                # print(scores)
                al_idx = 1
                detections.append({
                    "boxes": all_boxes,
                    "labels": all_labels,
                    "scores": all_scores,
                    'al': al_idx,
                })
                continue
            boxes_per_image = self.box_coder.decode_single(
                box_regression_per_image, anchors_per_image)
            boxes_per_image = box_ops.clip_boxes_to_image(
                boxes_per_image, image_shape)

            other_outputs_per_image = [(k, v[index])
                                       for k, v in other_outputs.items()]

            image_boxes = []
            image_scores = []
            image_labels = []
            image_other_outputs = torch.jit.annotate(Dict[str, List[Tensor]],
                                                     {})

            for class_index in range(num_classes):
                # remove low scoring boxes
                inds = torch.gt(scores_per_image[:, class_index],
                                self.score_thresh)
                boxes_per_class, scores_per_class, scores_all_class, labels_per_class = \
                    boxes_per_image[inds], scores_per_image[inds, class_index], scores_per_image[inds], \
                    labels_per_image[inds, class_index]
                other_outputs_per_class = [(k, v[inds])
                                           for k, v in other_outputs_per_image]

                # randomly subsample at most 500 candidate boxes for this class
                keep = [i for i in range(len(boxes_per_class))]
                random.shuffle(keep)
                keep = keep[:500]
                boxes_per_class, scores_per_class, scores_all_class, labels_per_class = \
                    boxes_per_class[keep], scores_per_class[keep], scores_all_class[keep], labels_per_class[keep]
                other_outputs_per_class = [(k, v[keep])
                                           for k, v in other_outputs_per_class]

                # remove empty boxes
                keep = box_ops.remove_small_boxes(boxes_per_class,
                                                  min_size=1e-2)
                boxes_per_class, scores_per_class, scores_all_class, labels_per_class = \
                    boxes_per_class[keep], scores_per_class[keep], scores_all_class[keep], labels_per_class[keep]
                other_outputs_per_class = [(k, v[keep])
                                           for k, v in other_outputs_per_class]

                # non-maximum suppression, independently done per class
                keep = box_ops.nms(boxes_per_class, scores_per_class,
                                   self.nms_thresh)

                # keep only topk scoring predictions
                keep = keep[:self.detections_per_img]
                boxes_per_class, scores_per_class, scores_all_class, labels_per_class = \
                    boxes_per_class[keep], scores_per_class[keep], scores_all_class[keep], labels_per_class[keep]
                other_outputs_per_class = [(k, v[keep])
                                           for k, v in other_outputs_per_class]

                image_boxes.append(boxes_per_class)
                image_scores.append(scores_per_class)
                image_labels.append(labels_per_class)

                for k, v in other_outputs_per_class:
                    if k not in image_other_outputs:
                        image_other_outputs[k] = []
                    image_other_outputs[k].append(v)

                for i in range(len(boxes_per_class)):
                    all_boxes = torch.cat(
                        (all_boxes, boxes_per_class[i].unsqueeze(0)), 0)
                    all_scores = torch.cat(
                        (all_scores, scores_per_class[i].unsqueeze(0)), 0)
                    all_labels.append(judge_y(scores_all_class[i][1:]))
            detections.append({
                "boxes": all_boxes,
                "labels": all_labels,
                "scores": all_scores,
                'al': al_idx,
            })
            for k, v in image_other_outputs.items():
                detections[-1].update({k: torch.cat(v, dim=0)})

        return detections

    def forward(self, images, targets=None):
        # type: (List[Tensor], Optional[List[Dict[str, Tensor]]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]
        """
        Arguments:
            images (list[Tensor]): images to be processed
            targets (list[Dict[Tensor]]): ground-truth boxes present in the image (optional)
        Returns:
            result (list[Dict[Tensor]] or dict[Tensor]): the output from the model.
                During training, it returns a dict[Tensor] which contains the losses.
                During testing, it returns a list[Dict[Tensor]], one dict per input image,
                with fields such as `boxes`, `scores` and `labels`.
        """
        if self.training and targets is None:
            raise ValueError("In training mode, targets should be passed")

        if self.training:
            assert targets is not None
            for target in targets:
                boxes = target["boxes"]
                if isinstance(boxes, torch.Tensor):
                    if len(boxes.shape) != 2 or boxes.shape[-1] != 4:
                        raise ValueError("Expected target boxes to be a tensor"
                                         "of shape [N, 4], got {:}.".format(
                                             boxes.shape))
                else:
                    raise ValueError("Expected target boxes to be of type "
                                     "Tensor, got {:}.".format(type(boxes)))

        # get the original image sizes
        original_image_sizes = torch.jit.annotate(List[Tuple[int, int]], [])
        for img in images:
            val = img.shape[-2:]
            assert len(val) == 2
            original_image_sizes.append((val[0], val[1]))

        # transform the input
        images, targets = self.transform(images, targets)

        # Check for degenerate boxes
        # TODO: Move this to a function
        if targets is not None:
            for target_idx, target in enumerate(targets):
                boxes = target["boxes"]
                degenerate_boxes = boxes[:, 2:] <= boxes[:, :2]
                if degenerate_boxes.any():
                    # print the first degenerate box
                    bb_idx = torch.where(degenerate_boxes.any(dim=1))[0][0]
                    degen_bb: List[float] = boxes[bb_idx].tolist()
                    raise ValueError(
                        "All bounding boxes should have positive height and width."
                        " Found invalid box {} for target at index {}.".format(
                            degen_bb, target_idx))

        # get the features from the backbone
        features = self.backbone(images.tensors)
        if isinstance(features, torch.Tensor):
            features = OrderedDict([('0', features)])

        # TODO: Do we want a list or a dict?
        features = list(features.values())

        # compute the retinanet heads outputs using the features
        head_outputs = self.head(features)

        # create the set of anchors
        anchors = self.anchor_generator(images, features)

        losses = {}
        detections = torch.jit.annotate(List[Dict[str, Tensor]], [])
        if self.training:
            assert targets is not None

            # compute the losses
            losses = self.compute_loss(targets, head_outputs, anchors)
        else:
            # compute the detections
            # print(self.ssm)
            if self.ssm:
                detections = self.ssm_postprocess_detections(
                    head_outputs, anchors, images.image_sizes)
            else:
                detections = self.postprocess_detections(
                    head_outputs, anchors, images.image_sizes)
            detections = self.transform.postprocess(detections,
                                                    images.image_sizes,
                                                    original_image_sizes)

        if torch.jit.is_scripting():
            if not self._has_warned:
                warnings.warn(
                    "RetinaNet always returns a (Losses, Detections) tuple in scripting"
                )
                self._has_warned = True
            return (losses, detections)
        return self.eager_outputs(losses, detections)
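
A hypothetical sketch of the ssm switch exposed above, assuming `model` is an instance of this RetinaNet built as in the docstring example and moved to the GPU (the ssm branch allocates CUDA tensors internally). With ssm_mode(True), each returned dict carries an extra 'al' flag that is set to 1 when the image's best score falls below CONF_THRESH, which an active-learning loop can use to queue that image for labeling.

import torch

model.eval()
model.ssm_mode(True)
images = [torch.rand(3, 300, 400).cuda()]
with torch.no_grad():
    detections = model(images)
# images whose best score is below CONF_THRESH are flagged for labeling
to_label = [i for i, det in enumerate(detections) if det['al'] == 1]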