    def _eval_batch(
        self,
        params: hk.Params,
        state: hk.State,
        batch: dataset.Batch,
    ) -> Mapping[Text, jnp.ndarray]:
        """Evaluates a batch.

        Args:
          params: Parameters of the model to evaluate. Typically BYOL's online
            parameters.
          state: State of the model to evaluate. Typically BYOL's online state.
          batch: Batch of data to evaluate (must contain the keys images and
            labels).

        Returns:
          Unreduced evaluation loss and top-1 / top-5 accuracies on the batch.
        """
        if self._should_transpose_images():
            batch = dataset.transpose_images(batch)

        outputs, _ = self.forward.apply(params,
                                        state,
                                        batch,
                                        is_training=False)
        logits = outputs['logits']
        labels = jax.nn.one_hot(batch['labels'], self._num_classes)
        loss = helpers.softmax_cross_entropy(logits, labels, reduction=None)
        top1_correct = helpers.topk_accuracy(logits, batch['labels'], topk=1)
        top5_correct = helpers.topk_accuracy(logits, batch['labels'], topk=5)
        # NOTE: Returned values will be summed and finally divided by num_samples.
        return {
            'eval_loss': loss,
            'top1_accuracy': top1_correct,
            'top5_accuracy': top5_correct,
        }
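The NOTE above implies a simple aggregation scheme: the per-example values returned by _eval_batch are summed across batches and divided by the total number of samples at the end of evaluation. Below is a minimal sketch of such an aggregation loop; the name eval_epoch and the iteration over batches are illustrative assumptions, not part of the original experiment code.

import jax.numpy as jnp

def eval_epoch(eval_batch_fn, params, state, batches):
    """Sums the unreduced per-example metrics, then divides by the sample count."""
    summed = None
    num_samples = 0
    for batch in batches:
        num_samples += batch['labels'].shape[0]
        metrics = eval_batch_fn(params, state, batch)
        # Each metric is unreduced (one value per example); sum it per batch.
        batch_sums = {k: jnp.sum(v) for k, v in metrics.items()}
        if summed is None:
            summed = batch_sums
        else:
            summed = {k: summed[k] + batch_sums[k] for k in summed}
    return {k: v / num_samples for k, v in summed.items()}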
Example #2
    def _backbone_fn(
        self,
        inputs: dataset.Batch,
        encoder_class: Text,
        encoder_config: Mapping[Text, Any],
        bn_decay_rate: float,
        is_training: bool,
    ) -> jnp.ndarray:
        """Forward of the encoder (backbone)."""
        bn_config = {'decay_rate': bn_decay_rate}
        encoder = getattr(networks, encoder_class)
        model = encoder(None, bn_config=bn_config, **encoder_config)

        if self._should_transpose_images():
            inputs = dataset.transpose_images(inputs)
        images = dataset.normalize_images(inputs['images'])
        return model(images, is_training=is_training)
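Because _backbone_fn instantiates Haiku modules, it has to be wrapped with hk.transform_with_state before it can be initialized or applied as a pure function; the batch-norm statistics created inside the encoder then live in the explicit state. The toy backbone below is only a sketch of that pattern, with made-up layer sizes, and is not the networks encoder used by the experiment.

import haiku as hk
import jax
import jax.numpy as jnp

def toy_backbone(images, is_training):
    # Minimal stand-in for an encoder with batch norm.
    x = hk.Conv2D(output_channels=8, kernel_shape=3)(images)
    x = hk.BatchNorm(create_scale=True, create_offset=True, decay_rate=0.9)(
        x, is_training=is_training)
    x = jax.nn.relu(x)
    return jnp.mean(x, axis=(1, 2))  # global average pooling -> embedding

forward = hk.without_apply_rng(hk.transform_with_state(toy_backbone))
rng = jax.random.PRNGKey(0)
dummy_images = jnp.zeros((2, 32, 32, 3))
params, state = forward.init(rng, dummy_images, is_training=True)
embeddings, new_state = forward.apply(params, state, dummy_images,
                                      is_training=False)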
    def _make_initial_state(
        self,
        rng: jnp.ndarray,
        dummy_input: dataset.Batch,
    ) -> _ByolExperimentState:
        """BYOL's _ByolExperimentState initialization.

    Args:
      rng: random number generator used to initialize parameters. If working in
        a multi device setup, this need to be a ShardedArray.
      dummy_input: a dummy image, used to compute intermediate outputs shapes.

    Returns:
      Initial Byol state.
    """
        rng_online, rng_target = jax.random.split(rng)

        if self._should_transpose_images():
            dummy_input = dataset.transpose_images(dummy_input)

        # Online and target parameters are initialized using different rngs;
        # in our experiments we did not notice a significant difference when
        # using the same rng for both.
        online_params, online_state = self.forward.init(
            rng_online,
            dummy_input,
            is_training=True,
        )
        target_params, target_state = self.forward.init(
            rng_target,
            dummy_input,
            is_training=True,
        )
        opt_state = self._optimizer(0).init(online_params)
        return _ByolExperimentState(
            online_params=online_params,
            target_params=target_params,
            opt_state=opt_state,
            online_state=online_state,
            target_state=target_state,
        )
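The initial state created above holds two copies of the network: the online parameters, which receive gradient updates, and the target parameters, which BYOL moves towards the online parameters with an exponential moving average after every step. A hedged sketch of that EMA update is shown below; the function name and the fixed tau are illustrative, as the original code derives tau from a schedule inside its update function.

import jax

def update_target_params(online_params, target_params, tau=0.99):
    # New target = tau * old target + (1 - tau) * online, applied leaf-wise.
    return jax.tree_util.tree_map(
        lambda online, target: tau * target + (1.0 - tau) * online,
        online_params, target_params)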
    def loss_fn(
        self,
        online_params: hk.Params,
        target_params: hk.Params,
        online_state: hk.State,
        target_state: hk.State,
        rng: jnp.ndarray,
        inputs: dataset.Batch,
    ) -> Tuple[jnp.ndarray, Tuple[Mapping[Text, hk.State], LogsDict]]:
        """Compute BYOL's loss function.

    Args:
      online_params: parameters of the online network (the loss is later
        differentiated with respect to the online parameters).
      target_params: parameters of the target network.
      online_state: internal state of online network.
      target_state: internal state of target network.
      rng: random number generator state.
      inputs: inputs, containing two batches of crops from the same images,
        view1 and view2 and labels

    Returns:
      BYOL's loss, a mapping containing the online and target networks updated
      states after processing inputs, and various logs.
    """
        if self._should_transpose_images():
            inputs = dataset.transpose_images(inputs)
        inputs = augmentations.postprocess(inputs, rng)
        labels = inputs['labels']

        online_network_out, online_state = self.forward.apply(
            params=online_params,
            state=online_state,
            inputs=inputs,
            is_training=True)
        target_network_out, target_state = self.forward.apply(
            params=target_params,
            state=target_state,
            inputs=inputs,
            is_training=True)

        # Representation loss

        # The stop_gradient is not strictly necessary, as gradients are only taken
        # with respect to the online parameters in the update step. We leave it to
        # make explicit that gradients are not backpropagated through the target
        # network.
        repr_loss = helpers.regression_loss(
            online_network_out['prediction_view1'],
            jax.lax.stop_gradient(target_network_out['projection_view2']))
        repr_loss = repr_loss + helpers.regression_loss(
            online_network_out['prediction_view2'],
            jax.lax.stop_gradient(target_network_out['projection_view1']))

        repr_loss = jnp.mean(repr_loss)

        # Classification loss (with gradients stopped from flowing into the
        # ResNet). This is used to provide an evaluation of the representation
        # quality during training.

        classif_loss = helpers.softmax_cross_entropy(
            logits=online_network_out['logits_view1'],
            labels=jax.nn.one_hot(labels, self._num_classes))

        top1_correct = helpers.topk_accuracy(
            online_network_out['logits_view1'],
            inputs['labels'],
            topk=1,
        )

        top5_correct = helpers.topk_accuracy(
            online_network_out['logits_view1'],
            inputs['labels'],
            topk=5,
        )

        top1_acc = jnp.mean(top1_correct)
        top5_acc = jnp.mean(top5_correct)

        classif_loss = jnp.mean(classif_loss)
        loss = repr_loss + classif_loss
        logs = dict(
            loss=loss,
            repr_loss=repr_loss,
            classif_loss=classif_loss,
            top1_accuracy=top1_acc,
            top5_accuracy=top5_acc,
        )

        return loss, (dict(online_state=online_state,
                           target_state=target_state), logs)
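The representation term relies on helpers.regression_loss, which is not shown here. In the BYOL paper this is the squared L2 distance between the l2-normalized prediction and the (stop-gradient) projection, i.e. 2 - 2 * cosine_similarity; the sketch below is a reconstruction of that formula rather than a copy of the helpers module. In training, loss_fn would typically be differentiated with jax.grad(..., has_aux=True) with respect to online_params only, which is why the auxiliary output carries the updated states and logs.

import jax.numpy as jnp

def regression_loss(x, y, eps=1e-12):
    """Per-example squared L2 distance between l2-normalized vectors."""
    x_norm = x / (jnp.linalg.norm(x, axis=-1, keepdims=True) + eps)
    y_norm = y / (jnp.linalg.norm(y, axis=-1, keepdims=True) + eps)
    return jnp.sum((x_norm - y_norm) ** 2, axis=-1)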