Example #1
    def __init__(self, anchor_ratios=(0.5, 1, 2), **kwargs):
        super().__init__(
            2,
            SparseCategoricalCrossentropy(reduction=tf.keras.losses.Reduction.NONE,
                                          from_logits=True),
            L1Loss(reduction=tf.keras.losses.Reduction.NONE),
            multiples=len(anchor_ratios),
            kernel_initializer_classification_head=initializers.RandomNormal(stddev=0.01),
            kernel_initializer_box_prediction_head=initializers.RandomNormal(stddev=0.01),
            **kwargs)

        # Force each ground truth to match at least one anchor
        matcher = Matcher([0.3, 0.7], [0, -1, 1], allow_low_quality_matches=True)
        self.target_assigner = TargetAssigner(IoUSimilarity(),
                                              matcher,
                                              encode_boxes_faster_rcnn,
                                              dtype=self._compute_dtype)

        anchor_strides = (4, 8, 16, 32, 64)
        anchor_sizes = (32, 64, 128, 256, 512)
        self._anchor_ratios = anchor_ratios

        # Precompute a deterministic grid of anchors for each layer of the pyramid.
        # We will extract a subpart of the anchors according to the shape of each feature map.
        self._anchors = [
            Anchors(stride, size, self._anchor_ratios)
            for stride, size in zip(anchor_strides, anchor_sizes)
        ]
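
To get a rough sense of how many anchors these precomputed grids hold, here is a back-of-the-envelope count. It is only a sketch: it assumes one anchor per feature-map cell and per ratio, and a hypothetical 512x512 input; the `Anchors` layer itself is not used.

```python
# Rough anchor count per pyramid level for a hypothetical 512x512 input,
# assuming one anchor per feature-map cell and per aspect ratio.
anchor_strides = (4, 8, 16, 32, 64)
anchor_ratios = (0.5, 1, 2)
image_size = 512

per_level = [(image_size // stride) ** 2 * len(anchor_ratios) for stride in anchor_strides]
print(per_level)       # [49152, 12288, 3072, 768, 192]
print(sum(per_level))  # 65472
```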
Example #2
    def __init__(self, num_classes, backbone, num_queries=300, **kwargs):
        super().__init__(**kwargs)

        self.num_classes = num_classes
        self.num_queries = num_queries
        self.hidden_dim = 256

        self.backbone = backbone
        self.input_proj = tf.keras.layers.Conv2D(self.hidden_dim, 1)
        self.pos_embed = PositionEmbeddingSine(output_dim=self.hidden_dim)
        num_heads = 8
        self.transformer_num_layers = 6
        self.transformer = Transformer(num_layers=self.transformer_num_layers,
                                       d_model=self.hidden_dim,
                                       num_heads=num_heads,
                                       dim_feedforward=2048)

        # SMCA layers
        self.dyn_weight_map = DynamicalWeightMaps()
        self.ref_points = SMCAReferencePoints(self.hidden_dim, num_heads)

        self.bbox_embed = tf.keras.models.Sequential([
            tf.keras.layers.Dense(self.hidden_dim, activation='relu'),
            tf.keras.layers.Dense(self.hidden_dim, activation='relu'),
            tf.keras.layers.Dense(4, dtype=tf.float32)  # (x1, y1, x2, y2)
        ])
        self.class_embed = tf.keras.layers.Dense(num_classes + 1, dtype=tf.float32)

        # Create a learnable embedding matrix for all our queries.
        # It is a matrix of shape [num_queries, self.hidden_dim].
        self.query_embed = tf.keras.layers.Embedding(
            num_queries,
            self.hidden_dim,
            embeddings_initializer=tf.keras.initializers.RandomNormal(mean=0., stddev=1.))

        self.all_the_queries = tf.range(num_queries)

        # Loss computation
        self.weight_class, self.weight_l1, self.weight_giou = 2, 5, 2
        similarity_func = DetrSimilarity(self.weight_class, self.weight_l1, self.weight_giou)
        self.target_assigner = TargetAssigner(similarity_func,
                                              hungarian_matching,
                                              lambda gt, pred: gt,
                                              negative_class_weight=1.0)

        # Losses
        self.giou = tfa.losses.GIoULoss(reduction=tf.keras.losses.Reduction.NONE)
        self.l1 = L1Loss(reduction=tf.keras.losses.Reduction.NONE)
        self.focal_loss = tfa.losses.SigmoidFocalCrossEntropy(
            alpha=0.25, gamma=2, reduction=tf.keras.losses.Reduction.NONE, from_logits=True)

        # Metrics
        self.giou_metric = tf.keras.metrics.Mean(name="giou_last_layer")
        self.l1_metric = tf.keras.metrics.Mean(name="l1_last_layer")
        self.focal_loss_metric = tf.keras.metrics.Mean(name="focal_loss_last_layer")
        self.loss_metric = tf.keras.metrics.Mean(name="loss")
        self.precision_metric = tf.keras.metrics.SparseCategoricalAccuracy()
        # Object recall = foreground
        self.recall_metric = tf.keras.metrics.Mean(name="object_recall")
Example #3
    def __init__(self, num_classes: int, backbone, num_queries=100, **kwargs):
        super().__init__(**kwargs)

        self.num_classes = num_classes
        self.num_queries = num_queries
        self.hidden_dim = 256

        self.backbone = backbone
        self.input_proj = tf.keras.layers.Conv2D(self.hidden_dim, 1)
        self.pos_embed = PositionEmbeddingSine(output_dim=self.hidden_dim)
        self.transformer_num_layers = 6
        self.transformer = Transformer(num_layers=self.transformer_num_layers,
                                       d_model=self.hidden_dim,
                                       num_heads=8,
                                       dim_feedforward=2048)

        self.bbox_embed = tf.keras.models.Sequential([
            tf.keras.layers.Dense(self.hidden_dim, activation='relu'),
            tf.keras.layers.Dense(self.hidden_dim, activation='relu'),
            tf.keras.layers.Dense(4, activation='sigmoid',
                                  dtype=tf.float32)  # (x1, y1, x2, y2)
        ])
        self.class_embed = tf.keras.layers.Dense(num_classes + 1,
                                                 dtype=tf.float32)

        # Create a learnable embedding matrix for all our queries.
        # It is a matrix of shape [num_queries, self.hidden_dim].
        self.query_embed = tf.keras.layers.Embedding(num_queries,
                                                     self.hidden_dim)
        self.all_the_queries = tf.range(num_queries)

        # Loss computation
        self.weight_class, self.weight_l1, self.weight_giou = 1, 5, 2
        similarity_func = DetrSimilarity(self.weight_class, self.weight_l1,
                                         self.weight_giou)
        self.target_assigner = TargetAssigner(similarity_func,
                                              hungarian_matching,
                                              lambda gt, pred: gt,
                                              negative_class_weight=1.0)

        # Relative classification weight applied to the no-object category.
        # It down-weights the log-probability term of a no-object
        # by a factor of 10 to account for class imbalance.
        self.non_object_weight = tf.constant(0.1, dtype=self.compute_dtype)

        # Losses
        self.giou = GIoULoss(reduction=tf.keras.losses.Reduction.NONE)
        self.l1 = L1Loss(reduction=tf.keras.losses.Reduction.NONE)
        self.scc = SparseCategoricalCrossentropy(
            reduction=tf.keras.losses.Reduction.NONE, from_logits=True)

        # Metrics
        self.giou_metric = tf.keras.metrics.Mean(name="giou_last_layer")
        self.l1_metric = tf.keras.metrics.Mean(name="l1_last_layer")
        self.scc_metric = tf.keras.metrics.Mean(name="scc_last_layer")
        self.loss_metric = tf.keras.metrics.Mean(name="loss")
        self.precision_metric = tf.keras.metrics.SparseCategoricalAccuracy()
        # Object recall = foreground
        self.recall_metric = tf.keras.metrics.Mean(name="object_recall")
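
The `non_object_weight` above is applied to the classification weights of every target matched to the no-object class. Below is a minimal sketch of that down-weighting, using `tf.where` as a stand-in for kerod's `item_assignment` (an assumption about its behaviour):

```python
import tensorflow as tf

# Per-query classification weights and assigned labels (0 = no-object).
labels = tf.constant([[1, 0, 3, 0]])
weights = tf.ones_like(labels, dtype=tf.float32)
non_object_weight = 0.1

# Down-weight the no-object targets by a factor of 10.
weights = tf.where(labels == 0, non_object_weight, weights)
print(weights)  # [[1.0, 0.1, 1.0, 0.1]]
```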
Example #4
    def __init__(self, num_classes, **kwargs):
        super().__init__(
            num_classes,
            SparseCategoricalCrossentropy(
                reduction=tf.keras.losses.Reduction.NONE, from_logits=True),
            L1Loss(reduction=tf.keras.losses.Reduction.NONE
                   ),  # like in tensorpack
            kernel_initializer_classification_head=initializers.RandomNormal(
                stddev=0.01),
            kernel_initializer_box_prediction_head=initializers.RandomNormal(
                stddev=0.001),
            **kwargs)

        matcher = Matcher([0.5], [0, 1])
        # The same scale_factors are used in decoding as well
        encode = functools.partial(encode_boxes_faster_rcnn,
                                   scale_factors=(10.0, 10.0, 5.0, 5.0))
        self.target_assigner = TargetAssigner(IoUSimilarity(),
                                              matcher,
                                              encode,
                                              dtype=self._compute_dtype)
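
Below is a minimal sketch of what the single-threshold `Matcher([0.5], [0, 1])` above amounts to, assuming it follows the usual convention of bucketing each anchor by its best IoU with the ground truths (this is an interpretation, not kerod's actual implementation):

```python
import tensorflow as tf

# Best IoU between each anchor and any ground-truth box (hypothetical values).
best_iou_per_anchor = tf.constant([0.10, 0.49, 0.50, 0.73])

# IoU < 0.5 -> 0 (negative match), IoU >= 0.5 -> 1 (positive match).
match_values = tf.where(best_iou_per_anchor < 0.5, 0, 1)
print(match_values)  # [0, 0, 1, 1]
```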
Example #5
class RegionProposalNetwork(AbstractDetectionHead):
    """It has been introduced in the [Faster R-CNN paper](https://arxiv.org/abs/1506.01497) and
    use the parameters from [Feature Pyramidal Networks for Object Detection](https://arxiv.org/abs/1612.03144).

    Arguments:

    - *anchor_ratios*: The aspect ratios that you want to apply to your anchors,
            e.g. (0.5, 1, 2)
    """

    def __init__(self, anchor_ratios=(0.5, 1, 2), **kwargs):
        super().__init__(
            2,
            SparseCategoricalCrossentropy(reduction=tf.keras.losses.Reduction.NONE,
                                          from_logits=True),
            L1Loss(reduction=tf.keras.losses.Reduction.NONE),
            multiples=len(anchor_ratios),
            kernel_initializer_classification_head=initializers.RandomNormal(stddev=0.01),
            kernel_initializer_box_prediction_head=initializers.RandomNormal(stddev=0.01),
            **kwargs)

        # Force each ground truth to match at least one anchor
        matcher = Matcher([0.3, 0.7], [0, -1, 1], allow_low_quality_matches=True)
        self.target_assigner = TargetAssigner(IoUSimilarity(),
                                              matcher,
                                              encode_boxes_faster_rcnn,
                                              dtype=self._compute_dtype)

        anchor_strides = (4, 8, 16, 32, 64)
        anchor_sizes = (32, 64, 128, 256, 512)
        self._anchor_ratios = anchor_ratios

        # Precompute a deterministic grid of anchors for each layer of the pyramid.
        # We will extract a subpart of the anchors according to the shape of each feature map.
        self._anchors = [
            Anchors(stride, size, self._anchor_ratios)
            for stride, size in zip(anchor_strides, anchor_sizes)
        ]

    def build(self, input_shape):
        self.rpn_conv2d = KL.Conv2D(512, (3, 3),
                                    padding='same',
                                    kernel_initializer=self._kernel_initializer_classification_head,
                                    kernel_regularizer=self._kernel_regularizer)
        super().build(input_shape)

    def build_rpn_head(self, inputs):
        """Predictions for the classification and the regression

        Arguments:

        - *inputs*: A tensor of shape [batch_size, height, width, channels]

        Returns:

        A tuple of tensors of shape ([batch_size, num_anchors, 2], [batch_size, num_anchors, 4])
        """

        batch_size = tf.shape(inputs)[0]
        rpn_conv2d = self.rpn_conv2d(inputs)
        classification_head, localization_head = self.build_detection_head(rpn_conv2d)
        classification_head = tf.reshape(classification_head, (batch_size, -1, 2))
        localization_head = tf.reshape(localization_head, (batch_size, -1, 4))
        return classification_head, localization_head

    def call(self, inputs: List[tf.Tensor]):
        """Create the computation graph for the rpn inference

        Arguments:

        - *inputs*: A list of tensors, the outputs of the pyramid

        Returns:
        
        - *localization_pred*: A list of 3-D tensors of shape [batch_size, num_anchors, 4]
        - *classification_pred*: A list of 3-D logit tensors of shape [batch_size, num_anchors, 2]
        - *anchors*: A list of tensors of shape [batch_size, num_anchors, (y_min, x_min, y_max, x_max)]
        """
        anchors = [anchors(tensor) for tensor, anchors in zip(inputs, self._anchors)]

        rpn_predictions = [self.build_rpn_head(tensor) for tensor in inputs]
        localization_pred = [prediction[1] for prediction in rpn_predictions]
        classification_pred = [prediction[0] for prediction in rpn_predictions]

        return localization_pred, classification_pred, anchors

    def compute_loss(self, localization_pred, classification_pred, anchors, ground_truths):
        """Compute the loss

        Arguments:

        - *localization_pred*: A list of tensors of shape [batch_size, num_anchors, 4].
        - *classification_pred*: A list of tensors of shape [batch_size, num_anchors, 2]
        - *anchors*: A list of tensors of shape [num_anchors, (y_min, x_min, y_max, x_max)]
        - *ground_truths*: A dict with BoxField as key and a tensor as value.

        ```python
        ground_truths = {
            BoxField.BOXES:
                tf.constant([[[0, 0, 1, 1], [0, 0, 2, 2]], [[0, 0, 3, 3], [0, 0, 0, 0]]], tf.float32),
            BoxField.LABELS:
                tf.constant([[1, 0], [1, 0]], tf.float32),
            BoxField.WEIGHTS:
                tf.constant([[1, 0], [1, 1]], tf.float32),
            BoxField.NUM_BOXES:
                tf.constant([[2], [1]], tf.int32)
        }
        ```

        where `NUM_BOXES` makes it possible to remove the padding created by tf.data.

        Returns:

        - *classification_loss*: A scalar in tf.float32
        - *localization_loss*: A scalar in tf.float32
        """
        localization_pred = tf.concat(localization_pred, 1)
        classification_pred = tf.concat(classification_pred, 1)
        anchors = tf.concat(anchors, 0)

        ground_truths = {
            # We add one because the background is not counted in ground_truths[BoxField.LABELS]
            BoxField.LABELS:
                ground_truths[BoxField.LABELS] + 1,
            BoxField.BOXES:
                ground_truths[BoxField.BOXES],
            BoxField.WEIGHTS:
                ground_truths[BoxField.WEIGHTS],
            BoxField.NUM_BOXES:
                ground_truths[BoxField.NUM_BOXES]
        }
        # Anchors are deterministic; duplicate them to create a batch
        anchors = tf.tile(anchors[None], (tf.shape(ground_truths[BoxField.BOXES])[0], 1, 1))
        y_true, weights = self.target_assigner.assign({BoxField.BOXES: anchors}, ground_truths)
        y_true[BoxField.LABELS] = tf.minimum(y_true[BoxField.LABELS], 1)

        ## Compute metrics
        recall = compute_rpn_metrics(y_true[BoxField.LABELS], classification_pred,
                                     weights[BoxField.LABELS])
        self.add_metric(recall, name='rpn_recall', aggregation='mean')

        # All the boxes which are not -1 can be sampled
        labels = y_true[BoxField.LABELS] > 0
        sample_idx = batch_sample_balanced_positive_negative(
            weights[BoxField.LABELS],
            SAMPLING_SIZE,
            labels,
            positive_fraction=SAMPLING_POSITIVE_RATIO,
            dtype=self._compute_dtype)

        weights[BoxField.LABELS] = sample_idx * weights[BoxField.LABELS]
        weights[BoxField.BOXES] = sample_idx * weights[BoxField.BOXES]

        y_pred = {BoxField.LABELS: classification_pred, BoxField.BOXES: localization_pred}

        return self.compute_losses(y_true, y_pred, weights)

    def get_config(self):
        base_config = super().get_config()
        base_config['anchor_ratios'] = self._anchor_ratios
        return base_config
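
The per-level outputs of `call` are merged before target assignment exactly as at the top of `compute_loss` above. The sketch below (plain TensorFlow with hypothetical shapes) shows that merge: predictions are concatenated along the anchor axis, and the deterministic anchor grid is tiled to form a batch.

```python
import tensorflow as tf

batch_size = 2
# Two pyramid levels with 12 and 3 anchors respectively (hypothetical values).
classification_pred = [tf.zeros((batch_size, 12, 2)), tf.zeros((batch_size, 3, 2))]
localization_pred = [tf.zeros((batch_size, 12, 4)), tf.zeros((batch_size, 3, 4))]
anchors = [tf.zeros((12, 4)), tf.zeros((3, 4))]

# Concatenate every level along the anchor axis.
classification_pred = tf.concat(classification_pred, axis=1)  # [2, 15, 2]
localization_pred = tf.concat(localization_pred, axis=1)      # [2, 15, 4]
anchors = tf.concat(anchors, axis=0)                          # [15, 4]

# The anchor grid is deterministic, so it is simply tiled to create a batch.
anchors = tf.tile(anchors[None], (batch_size, 1, 1))          # [2, 15, 4]
```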
Example #6
class DeTr(tf.keras.Model):
    """Build a DeTr model according to the paper
    [End-to-End Object Detection with Transformers](https://arxiv.org/abs/2005.12872)

    You can use it as follow:

    ```python
    model = DeTrResnet50Pytorch(80)
    base_lr = 0.1
    optimizer = tf.keras.optimizers.SGD(learning_rate=base_lr)
    model.compile(optimizer=optimizer, loss=None)
    model.fit(ds_train, validation_data=ds_test, epochs=11,)
    ```

    Arguments:
        num_classes: The number of classes of your dataset
            (**do not include the background class**, it is handled for you)
        backbone: A vision model like ResNet50.
        num_queries: The number of object queries, i.e. detection slots.
            This is the maximal number of objects
            DETR can detect in a single image. For COCO, we recommend 100 queries.

    Call arguments:
        inputs: Tuple
            1. images: A 4-D tensor of float32 and shape [batch_size, None, None, 3]
            2. image_informations: A 1D tensor of float32 and shape [(height, width),].
                It contains the shape of the image without any padding.
            3. images_padding_mask: A 3D tensor of int8 and shape [batch_size, None, None]
                composed of 0 and 1 which allows to know where a padding has been applied.
        training: Is automatically set to `True` in train mode

    Call returns:
        Tuple:
            - `logits`: A Tensor of shape [batch_size, h, num_classes + 1] class logits
            - `boxes`: A Tensor of shape [batch_size, h, 4]
            where h is num_queries * transformer_decoder.transformer_num_layers if
            training is true and num_queries otherwise.
    """
    def __init__(self, num_classes: int, backbone, num_queries=100, **kwargs):
        super().__init__(**kwargs)

        self.num_classes = num_classes
        self.num_queries = num_queries
        self.hidden_dim = 256

        self.backbone = backbone
        self.input_proj = tf.keras.layers.Conv2D(self.hidden_dim, 1)
        self.pos_embed = PositionEmbeddingSine(output_dim=self.hidden_dim)
        self.transformer_num_layers = 6
        self.transformer = Transformer(num_layers=self.transformer_num_layers,
                                       d_model=self.hidden_dim,
                                       num_heads=8,
                                       dim_feedforward=2048)

        self.bbox_embed = tf.keras.models.Sequential([
            tf.keras.layers.Dense(self.hidden_dim, activation='relu'),
            tf.keras.layers.Dense(self.hidden_dim, activation='relu'),
            tf.keras.layers.Dense(4, activation='sigmoid',
                                  dtype=tf.float32)  # (x1, y1, x2, y2)
        ])
        self.class_embed = tf.keras.layers.Dense(num_classes + 1,
                                                 dtype=tf.float32)

        # Create a learnable embedding matrix for all our queries.
        # It is a matrix of shape [num_queries, self.hidden_dim].
        self.query_embed = tf.keras.layers.Embedding(
            num_queries,
            self.hidden_dim,
            embeddings_initializer=tf.keras.initializers.RandomNormal(
                mean=0., stddev=1.))
        self.all_the_queries = tf.range(num_queries)

        # Loss computation
        self.weight_class, self.weight_l1, self.weight_giou = 1, 5, 2
        similarity_func = DetrSimilarity(self.weight_class, self.weight_l1,
                                         self.weight_giou)
        self.target_assigner = TargetAssigner(similarity_func,
                                              hungarian_matching,
                                              lambda gt, pred: gt,
                                              negative_class_weight=1.0)

        # Relative classification weight applied to the no-object category.
        # It down-weights the log-probability term of a no-object
        # by a factor of 10 to account for class imbalance.
        self.non_object_weight = tf.constant(0.1, dtype=self.compute_dtype)

        # Losses
        self.giou = GIoULoss(reduction=tf.keras.losses.Reduction.NONE)
        self.l1 = L1Loss(reduction=tf.keras.losses.Reduction.NONE)
        self.scc = SparseCategoricalCrossentropy(
            reduction=tf.keras.losses.Reduction.NONE, from_logits=True)

        # Metrics
        self.giou_metric = tf.keras.metrics.Mean(name="giou_last_layer")
        self.l1_metric = tf.keras.metrics.Mean(name="l1_last_layer")
        self.scc_metric = tf.keras.metrics.Mean(name="scc_last_layer")
        self.loss_metric = tf.keras.metrics.Mean(name="loss")
        self.precision_metric = tf.keras.metrics.SparseCategoricalAccuracy()
        # Object recall = foreground
        self.recall_metric = tf.keras.metrics.Mean(name="object_recall")

    @property
    def metrics(self):
        return [
            self.loss_metric, self.giou_metric, self.l1_metric,
            self.scc_metric, self.precision_metric, self.recall_metric
        ]

    def call(self, inputs, training=None):
        """Perform an inference in training.

        Arguments:
            inputs: Tuple
                1. images: A 4-D tensor of float32 and shape [batch_size, None, None, 3]
                2. image_informations: A 1D tensor of float32 and shape [(height, width),]. It contains the shape
                of the image without any padding.
                3. images_padding_mask: A 3D tensor of int8 and shape
                    [batch_size, None, None] composed of 0 and 1 which
                    allows to know where a padding has been applied.
            training: Is automatically set to `True` in train mode

        Returns:
            Tuple:
                - `logits`: A Tensor of shape [batch_size, h, num_classes + 1] class logits
                - `boxes`: A Tensor of shape [batch_size, h, 4]
                where h is num_queries * transformer_decoder.transformer_num_layers if
                training is true and num_queries otherwise.
        """
        images = inputs[DatasetField.IMAGES]
        images_padding_masks = inputs[DatasetField.IMAGES_PMASK]
        batch_size = tf.shape(images)[0]
        # The preprocessing dedicated to the backbone is done inside the model.
        x = self.backbone(images)[-1]
        features_mask = tf.image.resize(
            tf.cast(images_padding_masks[..., None], tf.float32),
            tf.shape(x)[1:3],
            method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
        features_mask = tf.cast(features_mask, tf.bool)
        # Positional_encoding for the backbone
        pos_embed = self.pos_embed(features_mask)
        # [batch_size, num_queries, self.hidden_dim]
        all_the_queries = tf.tile(self.all_the_queries[None], (batch_size, 1))
        # [batch_size, num_queries, self.hidden_dim]
        query_embed = self.query_embed(all_the_queries)
        # add positional_encoding to x [batch_size, h, w, self.hidden_dim]
        x = self.input_proj(x)

        # Flatten the position embedding and the spatial tensor
        # to allow the preprocessing by the Transformer
        # [batch_size, h * w,  self.hidden_dim]
        x = tf.reshape(x, (batch_size, -1, self.hidden_dim))
        pos_embed = tf.reshape(pos_embed, (batch_size, -1, self.hidden_dim))
        # Flatten the padding masks
        features_mask = tf.reshape(features_mask, (batch_size, -1))

        decoder_out, _ = self.transformer(x,
                                          pos_embed,
                                          query_embed,
                                          key_padding_mask=features_mask,
                                          training=training)
        boxes = self.bbox_embed(decoder_out)
        logits = self.class_embed(decoder_out)

        return {
            BoxField.SCORES: logits,
            BoxField.BOXES: boxes,
        }

    def compute_loss(
        self,
        ground_truths: Dict[str, tf.Tensor],
        y_pred: Dict[str, tf.Tensor],
        input_shape: tf.Tensor,
    ) -> int:
        """Apply the GIoU, L1 and SCC to each layers of the transformer decoder

        Arguments:
            ground_truths: see the output of kerod.dataset.preprocessing for the documentation
            y_pred: A dict
                - *scores*: A Tensor of shape [batch_size, num_queries, num_classes + 1] class logits
                - *bbox*: A Tensor of shape [batch_size, num_queries, 4]
            input_shape: [height, width] of the input tensor.
                It is the shape of the images with all the padding included.
                It is used to normalize the ground_truths boxes.
        """
        normalized_boxes = ground_truths[BoxField.BOXES] / tf.tile(
            input_shape[None], [1, 2])
        centered_normalized_boxes = convert_to_center_coordinates(
            normalized_boxes)
        ground_truths = {
            # We add one because the background is not counted in ground_truths [BoxField.LABELS]
            BoxField.LABELS:
            ground_truths[BoxField.LABELS] + 1,
            BoxField.BOXES:
            centered_normalized_boxes,
            BoxField.WEIGHTS:
            ground_truths[BoxField.WEIGHTS],
            BoxField.NUM_BOXES:
            ground_truths[BoxField.NUM_BOXES]
        }
        boxes_per_lvl = tf.split(y_pred[BoxField.BOXES],
                                 self.transformer_num_layers,
                                 axis=1)
        logits_per_lvl = tf.split(y_pred[BoxField.SCORES],
                                  self.transformer_num_layers,
                                  axis=1)

        y_pred_per_lvl = [{
            BoxField.BOXES: boxes,
            BoxField.SCORES: logits
        } for boxes, logits in zip(boxes_per_lvl, logits_per_lvl)]

        num_boxes = tf.cast(tf.reduce_sum(ground_truths[BoxField.NUM_BOXES]),
                            tf.float32)
        loss = 0
        # Compute the GIoU, L1 and SCC at each layer of the transformer decoder
        for i, y_pred in enumerate(y_pred_per_lvl):
            # Logs the metrics for the last layer of the decoder
            compute_metrics = i == self.transformer_num_layers - 1
            loss += self._compute_loss(y_pred,
                                       ground_truths,
                                       num_boxes,
                                       compute_metrics=compute_metrics)
        return loss

    def _compute_loss(
        self,
        y_pred: Dict[str, tf.Tensor],
        ground_truths: Dict[str, tf.Tensor],
        num_boxes: int,
        compute_metrics=False,
    ):
        y_true, weights = self.target_assigner.assign(y_pred, ground_truths)

        # Reduce the class imbalance by applying self.non_object_weight
        # to the weights of the no-object class (position 0)
        weights[BoxField.LABELS] = item_assignment(
            weights[BoxField.LABELS], y_true[BoxField.LABELS] == 0,
            self.non_object_weight)
        # Caveat: GIoU is buggy; if the batch_size is 1 and a sample_weight
        # is provided it will raise an error
        giou = self.giou(convert_to_xyxy_coordinates(y_true[BoxField.BOXES]),
                         convert_to_xyxy_coordinates(y_pred[BoxField.BOXES]),
                         sample_weight=weights[BoxField.BOXES])

        l1 = self.l1(y_true[BoxField.BOXES],
                     y_pred[BoxField.BOXES],
                     sample_weight=weights[BoxField.BOXES])

        # SparseCategoricalCrossentropy
        scc = self.scc(y_true[BoxField.LABELS],
                       y_pred[BoxField.SCORES],
                       sample_weight=weights[BoxField.LABELS])

        giou = self.weight_giou * tf.reduce_sum(giou) / num_boxes

        l1 = self.weight_l1 * tf.reduce_sum(l1) / num_boxes

        scc = self.weight_class * tf.reduce_sum(scc) / tf.reduce_sum(
            weights[BoxField.LABELS])

        if compute_metrics:
            self.giou_metric.update_state(giou)
            self.l1_metric.update_state(l1)
            self.scc_metric.update_state(scc)
            self.precision_metric.update_state(
                y_true[BoxField.LABELS],
                y_pred[BoxField.SCORES],
                sample_weight=weights[BoxField.LABELS])

            recall = compute_detr_metrics(y_true[BoxField.LABELS],
                                          y_pred[BoxField.SCORES])
            self.recall_metric.update_state(recall)
        return giou + l1 + scc

    def train_step(self, data):
        data = data_adapter.expand_1d(data)
        x, ground_truths, _ = data_adapter.unpack_x_y_sample_weight(data)

        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)
            input_shape = tf.cast(
                tf.shape(x[DatasetField.IMAGES])[1:3], self.compute_dtype)
            loss = self.compute_loss(ground_truths, y_pred, input_shape)

            loss += self.compiled_loss(None,
                                       y_pred,
                                       None,
                                       regularization_losses=self.losses)

        self.optimizer.minimize(loss, self.trainable_variables, tape=tape)
        self.loss_metric.update_state(loss)
        return {m.name: m.result() for m in self.metrics}

    def test_step(self, data):
        data = data_adapter.expand_1d(data)
        x, ground_truths, _ = data_adapter.unpack_x_y_sample_weight(data)

        # To compute the loss we need the outputs of each decoder layer.
        # Setting training to True provides them.
        y_pred = self(x, training=True)
        input_shape = tf.cast(
            tf.shape(x[DatasetField.IMAGES])[1:3], self.compute_dtype)
        loss = self.compute_loss(ground_truths, y_pred, input_shape)
        loss += self.compiled_loss(None,
                                   y_pred,
                                   None,
                                   regularization_losses=self.losses)
        self.loss_metric.update_state(loss)
        return {m.name: m.result() for m in self.metrics}

    def predict_step(self, data):
        """Perform an inference and returns the boxes, scores and labels associated.
        Background is discarded the max and argmax operation are performed.
        It means that if background was predicted the second maximum score would
        be outputed.

        Example: background + 3 classes
        [0.54, 0.40, 0.03, 0.03] => score = 0.40, label = 0 (1 - 1)


        "To optimize for AP, we override the prediction of these slots
        with the second highest scoring class, using the corresponding confidence"
        Section 4 (Experiments) of End-to-End Object Detection with Transformers

        Returns:
            boxes: A Tensor of shape [batch_size, self.num_queries, (y1,x1,y2,x2)]
                containing the boxes with the coordinates between 0 and 1.
            scores: A Tensor of shape [batch_size, self.num_queries] containing
                the score of the boxes.
            classes: A Tensor of shape [batch_size, self.num_queries]
                containing the class of the boxes [0, num_classes).
        """
        data = data_adapter.expand_1d(data)
        x, _, _ = data_adapter.unpack_x_y_sample_weight(data)
        y_pred = self(x, training=False)
        boxes_without_padding, scores, labels = detr_postprocessing(
            y_pred[BoxField.BOXES],
            y_pred[BoxField.SCORES],
            x[DatasetField.IMAGES_INFO],
            tf.shape(x[DatasetField.IMAGES])[1:3],
        )
        return boxes_without_padding, scores, labels
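
The `predict_step` docstring above describes how the background class is dropped before taking the max and argmax. Here is a minimal, standalone sketch of that behaviour (an illustration only, not the actual `detr_postprocessing` implementation):

```python
import tensorflow as tf

# [batch_size, num_queries, background + 3 classes], already softmaxed probabilities.
probs = tf.constant([[[0.54, 0.40, 0.03, 0.03]]])

foreground = probs[..., 1:]                  # discard the background column
scores = tf.reduce_max(foreground, axis=-1)  # [[0.40]]
labels = tf.argmax(foreground, axis=-1)      # [[0]], i.e. the class in [0, num_classes)
```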
Example #7
class FastRCNN(AbstractDetectionHead):
    """Build the Fast-RCNN on top of the FPN. The parameters used
    are from [Feature Pyramid Networks for Object Detection](https://arxiv.org/abs/1612.03144).

    Arguments:
        num_classes: The number of classes predicted by the classification head (N + 1), where N
            is the number of classes of your dataset and 1 is the background.

    Call arguments:
        inputs: A Tuple
            1. `pyramid`: A list of tensors, the outputs of the pyramid
            2. `anchors`: A tensor of shape [batch_size, num_boxes, (y_min, x_min, y_max, x_max)]

    Call returns:
        Tuple:
            `classification_pred`: A logit Tensor of shape [batch_size, num_boxes, num_classes]
            `localization_pred`: A Tensor of shape [batch_size, num_boxes, 4 * (num_classes - 1)]
            `anchors`: A Tensor of shape [batch_size, num_boxes, 4]
    """
    def __init__(self, num_classes, **kwargs):
        super().__init__(
            num_classes,
            SparseCategoricalCrossentropy(
                reduction=tf.keras.losses.Reduction.NONE, from_logits=True),
            L1Loss(reduction=tf.keras.losses.Reduction.NONE
                   ),  # like in tensorpack
            kernel_initializer_classification_head=initializers.RandomNormal(
                stddev=0.01),
            kernel_initializer_box_prediction_head=initializers.RandomNormal(
                stddev=0.001),
            **kwargs)

        matcher = Matcher([0.5], [0, 1])
        # The same scale_factors are used in decoding as well
        encode = functools.partial(encode_boxes_faster_rcnn,
                                   scale_factors=(10.0, 10.0, 5.0, 5.0))
        self.target_assigner = TargetAssigner(IoUSimilarity(),
                                              matcher,
                                              encode,
                                              dtype=self._compute_dtype)

    def build(self, input_shape):
        self.denses = [
            KL.Dense(1024,
                     kernel_initializer=initializers.VarianceScaling(),
                     kernel_regularizer=self._kernel_regularizer,
                     activation='relu') for _ in range(2)
        ]
        super().build(input_shape)

    def call(self, inputs):
        """Build the computational graph of the fast RCNN HEAD.

        It performs the raw predictions of the FastRCNN head, which you can post-process using:

        ```python
        from kerod.layers.post_processing import post_process_fast_rcnn_boxes

        outputs = post_process_fast_rcnn_boxes(classification_pred, localization_pred, anchors,
                                    images_information, num_classes)
        ```

        where `images_information` is provided as input of your model and `num_classes` includes
        the background.


        Arguments:
            inputs: A Tuple
                1. `pyramid`: A list of tensors, the outputs of the pyramid
                2. `anchors`: A tensor of shape [batch_size, num_boxes, (y_min, x_min, y_max, x_max)]

        Returns:
            Tuple:
                `classification_pred`: A logit Tensor of shape [batch_size, num_boxes, num_classes]
                `localization_pred`: A Tensor of shape [batch_size, num_boxes, 4 * (num_classes - 1)]
                `anchors`: A Tensor of shape [batch_size, num_boxes, 4]
        """
        # Remove P6
        pyramid = inputs[0][:-1]
        anchors = inputs[1]

        # We can compute the original image shape from the last pyramid level (stride 32).
        # TODO compute it more automatically without knowing that the last layer is stride 32
        image_shape = tf.cast(tf.shape(pyramid[-1])[1:3] * 32,
                              dtype=self._compute_dtype)
        box_tensors = multilevel_roi_align(pyramid,
                                           anchors,
                                           image_shape,
                                           crop_size=7)
        l = KL.Flatten()(box_tensors)
        for dense in self.denses:
            l = dense(l)

        classification_pred, localization_pred = self.build_detection_head(
            tf.reshape(l, (-1, 1, 1, 1024)))
        batch_size = tf.shape(anchors)[0]
        classification_pred = tf.reshape(classification_pred,
                                         (batch_size, -1, self._num_classes))
        localization_pred = tf.reshape(localization_pred,
                                       (batch_size, -1,
                                        (self._num_classes - 1) * 4))
        return classification_pred, localization_pred

    def sample_boxes(self,
                     anchors: tf.Tensor,
                     ground_truths: Dict[str, tf.Tensor],
                     sampling_size: int = 512,
                     sampling_positive_ratio: float = 0.25):
        """Perform the sampling of the target anchors.

        During training, a set of RoIs is detected by the RPN.
        However, you do not want to analyse the whole set; you only want
        to analyse the anchors sampled with this method.

        Arguments:
            anchors: A tensor of shape [batch_size, num_boxes, (y_min, x_min, y_max, x_max)]
            ground_truths: A dict
                - `BoxField.BOXES`: A 3-D tensor of shape [batch_size, num_gt, (y1, x1, y2, x2)]
                - `BoxField.LABELS`: A 2-D tensor of int32 and shape [batch_size, num_gt]
                - `BoxField.WEIGHTS`: A 2-D tensor of float and shape [batch_size, num_gt]
                - `BoxField.NUM_BOXES`: A 2-D tensor of int32 and shape [batch_size, 1]
                    which allows to remove the padding created by tf.Data.
                    Example: if batch_size=2 and this field equals tf.constant([[2], [1]], tf.int32),
                    then the second sample of the batch has one padded box
            sampling_size: Desired sampling size. If None, keeps all positive samples and
                randomly selects negative samples so that the positive sample fraction
                matches positive_fraction.
            sampling_positive_ratio: Desired fraction of positive examples (scalar in [0,1])
                in the batch.

        Returns:
            Tuple:
                1. y_true: A dict with :
                    - `BoxField.LABELS`: A 3-D tensor of shape [batch_size, num_anchors,
                        num_classes],
                    - `BoxField.BOXES`: A 3-D tensor of shape [batch_size, num_anchors,
                        box_code_dimension]

                2. weights: A dict with:
                    - `BoxField.LABELS`: A 2-D tensor of shape [batch_size, num_anchors],
                    - `BoxField.BOXES`: A 2-D tensor of shape [batch_size, num_anchors]

        Raises:
            ValueError: If the batch_size is None.
            ValueError: If the batch_size between your ground_truths and the anchors does not match.
        """

        ground_truths = {
            # We add one because the background is not counted in ground_truths[BoxField.LABELS]
            BoxField.LABELS:
            ground_truths[BoxField.LABELS] + 1,
            BoxField.BOXES:
            ground_truths[BoxField.BOXES],
            BoxField.WEIGHTS:
            ground_truths[BoxField.WEIGHTS],
            BoxField.NUM_BOXES:
            ground_truths[BoxField.NUM_BOXES]
        }
        y_true, weights = self.target_assigner.assign(
            {BoxField.BOXES: anchors}, ground_truths)

        labels = y_true[BoxField.LABELS] > 0
        sample_idx = batch_sample_balanced_positive_negative(
            weights[BoxField.LABELS],
            sampling_size,
            labels,
            positive_fraction=sampling_positive_ratio,
            dtype=self._compute_dtype)

        weights[BoxField.LABELS] = sample_idx * weights[BoxField.LABELS]
        weights[BoxField.BOXES] = sample_idx * weights[BoxField.BOXES]

        selected_boxes_idx = tf.where(sample_idx == 1)

        batch_size = tf.shape(sample_idx)[0]

        # Extract the anchors corresponding to the selected indices.
        # tf.gather_nd collapses the batch dimension, so we reshape with the proper batch_size
        anchors = tf.reshape(tf.gather_nd(anchors, selected_boxes_idx),
                             (batch_size, -1, 4))

        y_true[BoxField.BOXES] = tf.reshape(
            tf.gather_nd(y_true[BoxField.BOXES], selected_boxes_idx),
            (batch_size, -1, 4))

        y_true[BoxField.LABELS] = tf.reshape(
            tf.gather_nd(y_true[BoxField.LABELS], selected_boxes_idx),
            (batch_size, -1))

        for key in y_true.keys():
            weights[key] = tf.reshape(
                tf.gather_nd(weights[key], selected_boxes_idx),
                (batch_size, -1))
            weights[key] = tf.stop_gradient(weights[key])
            y_true[key] = tf.stop_gradient(y_true[key])
        return y_true, weights, anchors

    def compute_loss(self, y_true: dict, weights: dict,
                     classification_pred: tf.Tensor,
                     localization_pred: tf.Tensor):
        """Compute the loss of the FastRCNN

        Arguments:
            y_true: A dict with:
                - `BoxField.LABELS`: A 2-D tensor of int32 and shape [batch_size, num_anchors]
                - `BoxField.BOXES`: A 3-D tensor of shape [batch_size, num_anchors, 4]
            weights: A dict with:
                - `BoxField.LABELS`: A 2-D tensor of shape [batch_size, num_anchors]
                - `BoxField.BOXES`: A 2-D tensor of shape [batch_size, num_anchors]
            classification_pred: A 3-D tensor of float and shape
                [batch_size, num_anchors, num_classes]
            localization_pred: A 3-D tensor of float and shape
                [batch_size, num_anchors, (num_classes - 1) * 4]

        Returns:
            Tuple:
                - `classification_loss`: A scalar
                - `localization_loss`: A scalar
        """
        y_true_classification = tf.cast(y_true[BoxField.LABELS], tf.int32)
        accuracy, fg_accuracy, false_negative = compute_fast_rcnn_metrics(
            y_true_classification, classification_pred)
        self.add_metric(accuracy, name='accuracy', aggregation='mean')
        self.add_metric(fg_accuracy, name='fg_accuracy', aggregation='mean')
        self.add_metric(false_negative,
                        name='false_negative',
                        aggregation='mean')

        # y_true[BoxField.LABELS] is just 1 and 0; we use it as a mask to extract
        # the corresponding target anchors
        batch_size = tf.shape(classification_pred)[0]
        # We create a boolean mask to extract the desired localization prediction to compute
        # the loss
        one_hot_targets = tf.one_hot(y_true_classification,
                                     self._num_classes,
                                     dtype=tf.int8)
        one_hot_targets = tf.reshape(one_hot_targets, [-1])

        # We need to insert a fake background class at position 0
        localization_pred = tf.pad(localization_pred, [[0, 0], [0, 0], [4, 0]])
        localization_pred = tf.reshape(localization_pred, [-1, 4])

        extracted_localization_pred = tf.boolean_mask(localization_pred,
                                                      one_hot_targets > 0)
        extracted_localization_pred = tf.reshape(extracted_localization_pred,
                                                 (batch_size, -1, 4))

        y_pred = {
            BoxField.LABELS: classification_pred,
            BoxField.BOXES: extracted_localization_pred
        }

        return self.compute_losses(y_true, y_pred, weights)
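
The masking trick in `compute_loss` above (pad a fake background slot, then use the one-hot labels to pick the regression belonging to the ground-truth class) can be seen in isolation in the sketch below, using plain TensorFlow and hypothetical shapes:

```python
import tensorflow as tf

num_classes = 3                  # background + 2 foreground classes
batch_size = 1
labels = tf.constant([[2, 0]])   # first box is class 2, second box is background

# The box head predicts 4 coordinates per foreground class: [batch, boxes, (num_classes - 1) * 4]
localization_pred = tf.reshape(tf.range(16, dtype=tf.float32), (1, 2, 8))

# Pad a fake background slot at position 0 so every class index has 4 coordinates.
localization_pred = tf.pad(localization_pred, [[0, 0], [0, 0], [4, 0]])
localization_pred = tf.reshape(localization_pred, [-1, 4])   # [boxes * num_classes, 4]

# The one-hot labels act as a boolean mask selecting one 4-vector per box.
one_hot_targets = tf.reshape(tf.one_hot(labels, num_classes, dtype=tf.int8), [-1])
extracted = tf.boolean_mask(localization_pred, one_hot_targets > 0)
extracted = tf.reshape(extracted, (batch_size, -1, 4))       # [1, 2, 4]
print(extracted)  # [[[4. 5. 6. 7.], [0. 0. 0. 0.]]]
```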
Example #8
class SMCA(tf.keras.Model):
    """Build a single scale SCMA model according to the paper
    [Fast Convergence of DETR with Spatially Modulated Co-Attention](https://arxiv.org/pdf/2101.07448.pdf).

    How is it different from DETR?

    Just imagine that your object queries are learned anchors.
    Those learned "anchors" modulate the attention map during
    the co-attention stage of the decoder. They help the model focus
    on the right spots faster, which speeds up training by roughly a factor
    of 10 while maintaining the same performance as DETR.

    You can use it as follow:

    ```python
    model = SMCAR50(80)
    base_lr = 0.1
    optimizer = tf.keras.optimizers.SGD(learning_rate=base_lr)
    model.compile(optimizer=optimizer, loss=None)
    model.fit(ds_train, validation_data=ds_test, epochs=11,)
    ```

    Arguments:
        num_classes: The number of classes of your dataset
            (**do not include the background class**, it is handled for you)
        backbone: A vision model like ResNet50.
        num_queries: The number of object queries, i.e. detection slots.
            This is the maximal number of objects
            SMCA can detect in a single image. For COCO, we recommend 300 queries.

    Call arguments:
        inputs: Tuple
            1. images: A 4-D tensor of float32 and shape [batch_size, None, None, 3]
            2. image_informations: A 1D tensor of float32 and shape [(height, width),].
                It contains the shape of the image without any padding.
            3. images_padding_mask: A 3D tensor of int8 and shape [batch_size, None, None]
                composed of 0 and 1 which allows to know where a padding has been applied.
        training: Is automatically set to `True` in train mode

    Call returns:
        logits: A Tensor of shape [batch_size, h, num_classes + 1] class logits
        boxes: A Tensor of shape [batch_size, h, 4]

    where h is num_queries * transformer_decoder.num_layers if
    training is true and num_queries otherwise.
    """
    def __init__(self, num_classes, backbone, num_queries=300, **kwargs):
        super().__init__(**kwargs)

        self.num_classes = num_classes
        self.num_queries = num_queries
        self.hidden_dim = 256

        self.backbone = backbone
        self.input_proj = tf.keras.layers.Conv2D(self.hidden_dim, 1)
        self.pos_embed = PositionEmbeddingSine(output_dim=self.hidden_dim)
        num_heads = 8
        self.transformer_num_layers = 6
        self.transformer = Transformer(num_layers=self.transformer_num_layers,
                                       d_model=self.hidden_dim,
                                       num_heads=num_heads,
                                       dim_feedforward=2048)

        # SMCA layers
        self.dyn_weight_map = DynamicalWeightMaps()
        self.ref_points = SMCAReferencePoints(self.hidden_dim, num_heads)

        self.bbox_embed = tf.keras.models.Sequential([
            tf.keras.layers.Dense(self.hidden_dim, activation='relu'),
            tf.keras.layers.Dense(self.hidden_dim, activation='relu'),
            tf.keras.layers.Dense(4, dtype=tf.float32)  # (x1, y1, x2, y2)
        ])
        self.class_embed = tf.keras.layers.Dense(num_classes + 1,
                                                 dtype=tf.float32)

        # Create a learnable embedding matrix for all our queries.
        # It is a matrix of shape [num_queries, self.hidden_dim].
        self.query_embed = tf.keras.layers.Embedding(num_queries,
                                                     self.hidden_dim)
        self.all_the_queries = tf.range(num_queries)

        # Loss computation
        self.weight_class, self.weight_l1, self.weight_giou = 2, 5, 2
        similarity_func = DetrSimilarity(self.weight_class, self.weight_l1,
                                         self.weight_giou)
        self.target_assigner = TargetAssigner(similarity_func,
                                              hungarian_matching,
                                              lambda gt, pred: gt,
                                              negative_class_weight=1.0)

        # Losses
        self.giou = tfa.losses.GIoULoss(
            reduction=tf.keras.losses.Reduction.NONE)
        self.l1 = L1Loss(reduction=tf.keras.losses.Reduction.NONE)
        self.focal_loss = tfa.losses.SigmoidFocalCrossEntropy(
            alpha=0.25,
            gamma=2,
            reduction=tf.keras.losses.Reduction.NONE,
            from_logits=True)

        # Metrics
        self.giou_metric = tf.keras.metrics.Mean(name="giou_last_layer")
        self.l1_metric = tf.keras.metrics.Mean(name="l1_last_layer")
        self.focal_loss_metric = tf.keras.metrics.Mean(
            name="focal_loss_last_layer")
        self.loss_metric = tf.keras.metrics.Mean(name="loss")
        self.precision_metric = tf.keras.metrics.SparseCategoricalAccuracy()
        # Object recall = foreground
        self.recall_metric = tf.keras.metrics.Mean(name="object_recall")

    @property
    def metrics(self):
        return [
            self.loss_metric, self.giou_metric, self.l1_metric,
            self.focal_loss_metric, self.precision_metric, self.recall_metric
        ]

    def call(self, inputs, training=None):
        """Perform an inference in training.

        Arguments:

        - *inputs*: Tuple
            1. images: A 4-D tensor of float32 and shape [batch_size, None, None, 3]
            2. image_informations: A 1D tensor of float32 and shape [(height, width),]. It contains the shape
            of the image without any padding.
            3. images_padding_mask: A 3D tensor of int8 and shape [batch_size, None, None] composed of 0 and 1 which allows to know where a padding has been applied.


        - *training*: Is automatically set to `True` in train mode

        Returns:

        - *logits*: A Tensor of shape [batch_size, h, num_classes + 1] class logits
        - *boxes*: A Tensor of shape [batch_size, h, 4]

        where h is num_queries * transformer_decoder.num_layers if
        training is true and num_queries otherwise.
        """
        images = inputs[DatasetField.IMAGES]
        images_padding_masks = inputs[DatasetField.IMAGES_PMASK]
        batch_size = tf.shape(images)[0]
        # The preprocessing dedicated to the backbone is done inside the model.
        x = self.backbone(images)[-1]
        features_mask = tf.image.resize(
            tf.cast(images_padding_masks[..., None], tf.float32),
            tf.shape(x)[1:3],
            method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
        features_mask = tf.cast(features_mask, tf.bool)
        # Positional_encoding for the backbone
        pos_embed = self.pos_embed(features_mask)
        # [batch_size, num_queries, self.hidden_dim]
        all_the_queries = tf.tile(self.all_the_queries[None], (batch_size, 1))
        # [batch_size, num_queries, self.hidden_dim]
        query_embed = self.query_embed(all_the_queries)
        h_backbone_out, w_backbone_out = tf.shape(x)[1], tf.shape(x)[2]
        x = self.input_proj(x)

        # Flatten the position embedding and the spatial tensor
        # to allow the preprocessing by the Transformer
        # [batch_size, h * w,  self.hidden_dim]
        x = tf.reshape(x, (batch_size, -1, self.hidden_dim))
        pos_embed = tf.reshape(pos_embed, (batch_size, -1, self.hidden_dim))
        # Flatten the padding masks
        features_mask = tf.reshape(features_mask, (batch_size, -1))

        ref_points, ref_points_presigmoid = self.ref_points(query_embed)

        # dyn_weight_map_per_head = G in the paper
        dyn_weight_map_per_head = self.dyn_weight_map(h_backbone_out,
                                                      w_backbone_out,
                                                      ref_points)
        dyn_weight_map_per_head = tf.math.log(dyn_weight_map_per_head +
                                              10e-4)  # log G

        decoder_out, _ = self.transformer(x,
                                          pos_embed,
                                          query_embed,
                                          key_padding_mask=features_mask,
                                          coattn_mask=dyn_weight_map_per_head,
                                          training=training)

        logits = self.class_embed(decoder_out)
        boxes = self.bbox_embed(decoder_out)

        if training:
            # In training all the outputs of the decoders are stacked together.
            # We tile the reference_points to match those outputs
            ref_points_presigmoid = tf.tile(
                ref_points_presigmoid, (1, self.transformer_num_layers, 1))

        # Add the initial center to constrain the bounding box predictions
        offset = tf.concat([
            ref_points_presigmoid,
            tf.zeros((batch_size, tf.shape(ref_points_presigmoid)[1], 2))
        ],
                           axis=-1)
        boxes = tf.nn.sigmoid(boxes + offset)

        return {
            BoxField.SCORES: logits,
            BoxField.BOXES: boxes,
        }

    def compute_loss(
        self,
        ground_truths: Dict[str, tf.Tensor],
        y_pred: Dict[str, tf.Tensor],
        input_shape: tf.Tensor,
    ) -> int:
        """Apply the GIoU, L1 and SCC to each layers of the transformer decoder

        Args:
            ground_truths: see the output of kerod.dataset.preprocessing for the documentation
            y_pred: A dict
                - *scores*: A Tensor of shape [batch_size, num_queries, num_classes + 1] class logits
                - *bbox*: A Tensor of shape [batch_size, num_queries, 4]
            input_shape: [height, width] of the input tensor.
                It is the shape of the images with all the padding included.
                It is used to normalize the ground_truths boxes.
        """
        normalized_boxes = ground_truths[BoxField.BOXES] / tf.tile(
            input_shape[None], [1, 2])
        centered_normalized_boxes = convert_to_center_coordinates(
            normalized_boxes)
        ground_truths = {
            # We add one because the background is not counted in ground_truths [BoxField.LABELS]
            BoxField.LABELS:
            ground_truths[BoxField.LABELS] + 1,
            BoxField.BOXES:
            centered_normalized_boxes,
            BoxField.WEIGHTS:
            ground_truths[BoxField.WEIGHTS],
            BoxField.NUM_BOXES:
            ground_truths[BoxField.NUM_BOXES]
        }
        boxes_per_lvl = tf.split(y_pred[BoxField.BOXES],
                                 self.transformer_num_layers,
                                 axis=1)
        logits_per_lvl = tf.split(y_pred[BoxField.SCORES],
                                  self.transformer_num_layers,
                                  axis=1)

        y_pred_per_lvl = [{
            BoxField.BOXES: boxes,
            BoxField.SCORES: logits
        } for boxes, logits in zip(boxes_per_lvl, logits_per_lvl)]

        num_boxes = tf.cast(tf.reduce_sum(ground_truths[BoxField.NUM_BOXES]),
                            tf.float32)
        loss = 0
        # Compute the GIoU, L1 and focal losses at each layer of the transformer decoder
        for i, y_pred in enumerate(y_pred_per_lvl):
            # Logs the metrics for the last layer of the decoder
            compute_metrics = i == self.transformer_num_layers - 1
            loss += self._compute_loss(y_pred,
                                       ground_truths,
                                       num_boxes,
                                       compute_metrics=compute_metrics)
        return loss

    def _compute_loss(
        self,
        y_pred: Dict[str, tf.Tensor],
        ground_truths: Dict[str, tf.Tensor],
        num_boxes: int,
        compute_metrics=False,
    ):
        y_true, weights = self.target_assigner.assign(y_pred, ground_truths)

        # Caveat: GIoU is buggy; if the batch_size is 1 and a sample_weight
        # is provided it will raise an error
        giou = self.giou(convert_to_xyxy_coordinates(y_true[BoxField.BOXES]),
                         convert_to_xyxy_coordinates(y_pred[BoxField.BOXES]),
                         sample_weight=weights[BoxField.BOXES])

        l1 = self.l1(y_true[BoxField.BOXES],
                     y_pred[BoxField.BOXES],
                     sample_weight=weights[BoxField.BOXES])

        cls_labels = tf.one_hot(
            y_true[BoxField.LABELS],
            depth=self.num_classes + 1,
            dtype=tf.float32,
        )
        focal_loss = self.focal_loss(cls_labels,
                                     y_pred[BoxField.SCORES],
                                     sample_weight=weights[BoxField.LABELS])

        giou = self.weight_giou * tf.reduce_sum(giou) / num_boxes

        l1 = self.weight_l1 * tf.reduce_sum(l1) / num_boxes

        focal_loss = self.weight_class * tf.reduce_sum(
            focal_loss) / tf.reduce_sum(weights[BoxField.LABELS])

        if compute_metrics:
            self.giou_metric.update_state(giou)
            self.l1_metric.update_state(l1)
            self.focal_loss_metric.update_state(focal_loss)
            self.precision_metric.update_state(
                y_true[BoxField.LABELS],
                y_pred[BoxField.SCORES],
                sample_weight=weights[BoxField.LABELS])

            recall = compute_detr_metrics(y_true[BoxField.LABELS],
                                          y_pred[BoxField.SCORES])
            self.recall_metric.update_state(recall)
        return giou + l1 + focal_loss

    def train_step(self, data):
        data = data_adapter.expand_1d(data)
        x, ground_truths, _ = data_adapter.unpack_x_y_sample_weight(data)

        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)
            input_shape = tf.cast(
                tf.shape(x[DatasetField.IMAGES])[1:3], self.compute_dtype)
            loss = self.compute_loss(ground_truths, y_pred, input_shape)

            loss += self.compiled_loss(None,
                                       y_pred,
                                       None,
                                       regularization_losses=self.losses)

        self.optimizer.minimize(loss, self.trainable_variables, tape=tape)
        self.loss_metric.update_state(loss)
        return {m.name: m.result() for m in self.metrics}

    def test_step(self, data):
        data = data_adapter.expand_1d(data)
        x, ground_truths, _ = data_adapter.unpack_x_y_sample_weight(data)

        # To compute the loss we need the outputs of each decoder layer.
        # Setting training to True provides them.
        y_pred = self(x, training=True)
        input_shape = tf.cast(
            tf.shape(x[DatasetField.IMAGES])[1:3], self.compute_dtype)
        loss = self.compute_loss(ground_truths, y_pred, input_shape)
        loss += self.compiled_loss(None,
                                   y_pred,
                                   None,
                                   regularization_losses=self.losses)
        self.loss_metric.update_state(loss)
        return {m.name: m.result() for m in self.metrics}

    def predict_step(self, data):
        """Perform an inference and returns the boxes, scores and labels associated.
        Background is discarded the max and argmax operation are performed.
        It means that if background was predicted the second maximum score would
        be outputed.

        Example: background + 3 classes
        [0.54, 0.40, 0.03, 0.03] => score = 0.40, label = 0 (1 - 1)


        "To optimize for AP, we override the prediction of these slots
        with the second highest scoring class, using the corresponding confidence"
        Section 4 (Experiments) of End-to-End Object Detection with Transformers

        Returns:
            boxes: A Tensor of shape [batch_size, self.num_queries, (y1,x1,y2,x2)]
                containing the boxes with the coordinates between 0 and 1.
            scores: A Tensor of shape [batch_size, self.num_queries] containing
                the score of the boxes.
            classes: A Tensor of shape [batch_size, self.num_queries]
                containing the class of the boxes [0, num_classes).
        """
        data = data_adapter.expand_1d(data)
        x, _, _ = data_adapter.unpack_x_y_sample_weight(data)
        y_pred = self(x, training=False)
        boxes_without_padding, scores, labels = detr_postprocessing(
            y_pred[BoxField.BOXES],
            y_pred[BoxField.SCORES],
            x[DatasetField.IMAGES_INFO],
            tf.shape(x[DatasetField.IMAGES])[1:3],
        )
        return boxes_without_padding, scores, labels
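
At training time the decoder stacks the outputs of all its layers along the query axis, and `compute_loss` above splits them back so that the matching and the losses can be applied per layer. A minimal sketch of that split with hypothetical shapes:

```python
import tensorflow as tf

num_layers, num_queries, batch_size = 6, 300, 2

# At training time the decoder output stacks the 6 intermediate predictions.
boxes = tf.zeros((batch_size, num_layers * num_queries, 4))

boxes_per_lvl = tf.split(boxes, num_layers, axis=1)
print(len(boxes_per_lvl), boxes_per_lvl[0].shape)  # 6 (2, 300, 4)
```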