Example #1
    def _eval_batch(
        self,
        params: hk.Params,
        state: hk.State,
        batch: dataset.Batch,
    ) -> Mapping[Text, jnp.ndarray]:
        """Evaluates a batch.

    Args:
      params: Parameters of the model to evaluate. Typically Byol's online
        parameters.
      state: State of the model to evaluate. Typically Byol's online state.
      batch: Batch of data to evaluate (must contain keys images and labels).

    Returns:
      Unreduced evaluation loss and top1 accuracy on the batch.
    """
        if self._should_transpose_images():
            batch = dataset.transpose_images(batch)

        outputs, _ = self.forward.apply(params,
                                        state,
                                        batch,
                                        is_training=False)
        logits = outputs['logits']
        labels = hk.one_hot(batch['labels'], self._num_classes)
        loss = helpers.softmax_cross_entropy(logits, labels, reduction=None)
        top1_correct = helpers.topk_accuracy(logits, batch['labels'], topk=1)
        top5_correct = helpers.topk_accuracy(logits, batch['labels'], topk=5)
        # NOTE: Returned values will be summed and finally divided by num_samples.
        return {
            'eval_loss': loss,
            'top1_accuracy': top1_correct,
            'top5_accuracy': top5_correct,
        }
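The NOTE at the end hints at how these unreduced, per-example outputs are consumed. Below is a minimal sketch of that aggregation (hypothetical code, not from the repository; `evaluate_epoch` and `eval_batch_fn` are made-up names), assuming each metric is summed over batches and finally divided by the number of samples:

import jax.numpy as jnp

def evaluate_epoch(eval_batch_fn, params, state, batches):
    """Sums per-example metrics over batches, then divides by the sample count."""
    totals = {}
    num_samples = 0
    for batch in batches:
        outputs = eval_batch_fn(params, state, batch)  # unreduced, per-example
        num_samples += batch['labels'].shape[0]
        for key, value in outputs.items():
            totals[key] = totals.get(key, 0.0) + jnp.sum(value)
    return {key: total / num_samples for key, total in totals.items()}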
Example #2
    def _loss_fn(
        self,
        backbone_params: hk.Params,
        classif_params: hk.Params,
        backbone_state: hk.State,
        inputs: dataset.Batch,
    ) -> Tuple[jnp.ndarray, Tuple[jnp.ndarray, hk.State]]:
        """Compute the classification loss function.

    Args:
      backbone_params: parameters of the encoder network.
      classif_params: parameters of the linear classifier.
      backbone_state: internal state of encoder network.
      inputs: inputs, containing `images` and `labels`.

    Returns:
      The classification loss and various logs.
    """
        embeddings, backbone_state = self.forward_backbone.apply(
            backbone_params,
            backbone_state,
            inputs,
            is_training=not self._freeze_backbone)

        logits = self.forward_classif.apply(classif_params, embeddings)
        labels = hk.one_hot(inputs['labels'], self._num_classes)
        loss = helpers.softmax_cross_entropy(logits, labels, reduction='mean')
        scaled_loss = loss / jax.device_count()

        return scaled_loss, (loss, backbone_state)
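The division by `jax.device_count()` is the usual pmap trick: summing the gradients of the scaled loss across devices with `jax.lax.psum` yields the cross-device mean gradient of the unscaled loss. A sketch of a matching update step (hypothetical, not the repository's code; it assumes a frozen backbone, so only the classifier parameters are differentiated, and that it runs under `jax.pmap` with `axis_name='i'`):

import jax

def make_update_fn(loss_fn):
    # `loss_fn` is a bound `_loss_fn` as above; argnums=1 is `classif_params`.
    def update_fn(backbone_params, classif_params, backbone_state, inputs):
        grad_fn = jax.grad(loss_fn, argnums=1, has_aux=True)
        grads, (loss, backbone_state) = grad_fn(
            backbone_params, classif_params, backbone_state, inputs)
        # psum of gradients of loss/num_devices == mean gradient over devices.
        grads = jax.lax.psum(grads, axis_name='i')
        return grads, loss, backbone_state
    return update_fn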
Example #3
    def _eval_batch(
        self,
        backbone_params: hk.Params,
        classif_params: hk.Params,
        backbone_state: hk.State,
        inputs: dataset.Batch,
    ) -> LogsDict:
        """Evaluates a batch."""
        embeddings, backbone_state = self.forward_backbone.apply(
            backbone_params, backbone_state, inputs, is_training=False)
        logits = self.forward_classif.apply(classif_params, embeddings)
        labels = hk.one_hot(inputs['labels'], self._num_classes)
        loss = helpers.softmax_cross_entropy(logits, labels, reduction=None)
        top1_correct = helpers.topk_accuracy(logits, inputs['labels'], topk=1)
        top5_correct = helpers.topk_accuracy(logits, inputs['labels'], topk=5)
        # NOTE: Returned values will be summed and finally divided by num_samples.
        return {
            'eval_loss': loss,
            'top1_accuracy': top1_correct,
            'top5_accuracy': top5_correct,
        }
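For reference, here are plausible implementations of the two `helpers` functions called above, consistent with their call sites (`reduction=None` returning per-example losses, `topk_accuracy` returning per-example 0/1 indicators). These are reconstructions, not the repository's actual code, and the real signatures may differ:

import jax
import jax.numpy as jnp

def softmax_cross_entropy(logits, labels, reduction='mean'):
    """Cross-entropy between logits and one-hot labels; unreduced if reduction is None."""
    loss = -jnp.sum(labels * jax.nn.log_softmax(logits), axis=-1)
    return jnp.mean(loss) if reduction == 'mean' else loss

def topk_accuracy(logits, labels, topk=1):
    """Per-example indicator of whether the true label is among the top-k logits."""
    top_k_preds = jnp.argsort(logits, axis=-1)[:, -topk:]
    hits = (top_k_preds == labels[:, None]).any(axis=-1)
    return hits.astype(jnp.float32)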
Example #4
    def loss_fn(
        self,
        online_params: hk.Params,
        target_params: hk.Params,
        online_state: hk.State,
        target_state: hk.State,
        rng: jnp.ndarray,
        inputs: dataset.Batch,
    ) -> Tuple[jnp.ndarray, Tuple[Mapping[Text, hk.State], LogsDict]]:
        """Compute BYOL's loss function.

    Args:
      online_params: parameters of the online network (the loss is later
        differentiated with respect to the online parameters).
      target_params: parameters of the target network.
      online_state: internal state of online network.
      target_state: internal state of target network.
      rng: random number generator state.
      inputs: inputs, containing two batches of crops from the same images,
        view1 and view2 and labels

    Returns:
      BYOL's loss, a mapping containing the online and target networks updated
      states after processing inputs, and various logs.
    """
        if self._should_transpose_images():
            inputs = dataset.transpose_images(inputs)
        inputs = augmentations.postprocess(inputs, rng)
        labels = inputs['labels']

        online_network_out, online_state = self.forward.apply(
            params=online_params,
            state=online_state,
            inputs=inputs,
            is_training=True)
        target_network_out, target_state = self.forward.apply(
            params=target_params,
            state=target_state,
            inputs=inputs,
            is_training=True)

        # Representation loss.

        # The stop_gradient is not strictly necessary, since the gradient is only
        # taken with respect to the online parameters (and only those are updated
        # via `optax.apply_updates`). We keep it to make explicit that gradients
        # are not backpropagated through the target network.
        repr_loss = helpers.regression_loss(
            online_network_out['prediction_view1'],
            jax.lax.stop_gradient(target_network_out['projection_view2']))
        repr_loss = repr_loss + helpers.regression_loss(
            online_network_out['prediction_view2'],
            jax.lax.stop_gradient(target_network_out['projection_view1']))

        repr_loss = jnp.mean(repr_loss)

        # Classification loss (with gradients stopped from flowing into the
        # ResNet backbone). This provides an online evaluation of the
        # representation quality during training.

        classif_loss = helpers.softmax_cross_entropy(
            logits=online_network_out['logits_view1'],
            labels=jax.nn.one_hot(labels, self._num_classes))

        top1_correct = helpers.topk_accuracy(
            online_network_out['logits_view1'],
            inputs['labels'],
            topk=1,
        )

        top5_correct = helpers.topk_accuracy(
            online_network_out['logits_view1'],
            inputs['labels'],
            topk=5,
        )

        top1_acc = jnp.mean(top1_correct)
        top5_acc = jnp.mean(top5_correct)

        classif_loss = jnp.mean(classif_loss)
        loss = repr_loss + classif_loss
        logs = dict(
            loss=loss,
            repr_loss=repr_loss,
            classif_loss=classif_loss,
            top1_accuracy=top1_acc,
            top5_accuracy=top5_acc,
        )

        return loss, (dict(online_state=online_state,
                           target_state=target_state), logs)
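Two pieces referenced above are worth sketching. `helpers.regression_loss` is, per the BYOL paper, a normalized L2 distance (equivalently 2 - 2 * cosine similarity), and the target network behind `target_params` is updated as an exponential moving average of the online parameters. Both sketches below are reconstructions consistent with the paper, not necessarily the repository's exact code (`ema_update` is a made-up name):

import jax
import jax.numpy as jnp

def regression_loss(x, y):
    """BYOL's normalized L2 loss; equals 2 - 2 * cosine_similarity(x, y)."""
    x = x / (jnp.linalg.norm(x, axis=-1, keepdims=True) + 1e-12)
    y = y / (jnp.linalg.norm(y, axis=-1, keepdims=True) + 1e-12)
    return jnp.sum((x - y) ** 2, axis=-1)  # per-example; caller takes the mean

def ema_update(online_params, target_params, tau):
    """Target update from the BYOL paper: target <- tau * target + (1 - tau) * online."""
    return jax.tree_util.tree_map(
        lambda t, o: tau * t + (1.0 - tau) * o, target_params, online_params)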