def _eval_batch(
    self,
    params: hk.Params,
    state: hk.State,
    batch: dataset.Batch,
) -> Mapping[Text, jnp.ndarray]:
  """Evaluates a batch.

  Args:
    params: Parameters of the model to evaluate. Typically BYOL's online
      parameters.
    state: State of the model to evaluate. Typically BYOL's online state.
    batch: Batch of data to evaluate (must contain keys images and labels).

  Returns:
    Unreduced evaluation loss and top-1/top-5 accuracy on the batch.
  """
  if self._should_transpose_images():
    batch = dataset.transpose_images(batch)

  outputs, _ = self.forward.apply(params, state, batch, is_training=False)
  logits = outputs['logits']
  labels = jax.nn.one_hot(batch['labels'], self._num_classes)
  loss = helpers.softmax_cross_entropy(logits, labels, reduction=None)
  top1_correct = helpers.topk_accuracy(logits, batch['labels'], topk=1)
  top5_correct = helpers.topk_accuracy(logits, batch['labels'], topk=5)
  # NOTE: Returned values will be summed and finally divided by num_samples.
  return {
      'eval_loss': loss,
      'top1_accuracy': top1_correct,
      'top5_accuracy': top5_correct,
  }
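
# A minimal sketch, not from the original code, of how the unreduced values
# returned by `_eval_batch` could be aggregated over a full evaluation set, as
# the NOTE above describes: entries are summed across batches and finally
# divided by the total number of samples. `eval_batch_fn` and `eval_batches`
# are hypothetical placeholders, and batches are assumed to be single-device.
def _example_eval_epoch(eval_batch_fn, params, state, eval_batches):
  summed = {}
  num_samples = 0
  for batch in eval_batches:
    num_samples += batch['labels'].shape[0]
    outputs = eval_batch_fn(params, state, batch)
    for key, value in outputs.items():
      # `value` is unreduced (one entry per example), so sum it here.
      summed[key] = summed.get(key, 0.) + jnp.sum(value)
  return {key: value / num_samples for key, value in summed.items()}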
def _backbone_fn(
    self,
    inputs: dataset.Batch,
    encoder_class: Text,
    encoder_config: Mapping[Text, Any],
    bn_decay_rate: float,
    is_training: bool,
) -> jnp.ndarray:
  """Forward of the encoder (backbone)."""
  bn_config = {'decay_rate': bn_decay_rate}
  encoder = getattr(networks, encoder_class)
  model = encoder(None, bn_config=bn_config, **encoder_config)

  if self._should_transpose_images():
    inputs = dataset.transpose_images(inputs)
  images = dataset.normalize_images(inputs['images'])
  return model(images, is_training=is_training)
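
# A minimal sketch, under the assumption that `self.forward` used in the other
# methods is built the same way, of how a Haiku forward function such as
# `_backbone_fn` is turned into a pure `init`/`apply` pair. The `experiment`
# argument and the encoder configuration values below are illustrative
# placeholders only.
def _example_build_backbone(experiment):
  def backbone_fn(inputs, is_training):
    return experiment._backbone_fn(
        inputs,
        encoder_class='ResNet50',
        encoder_config={},
        bn_decay_rate=0.9,
        is_training=is_training,
    )
  # `hk.transform_with_state` yields pure `init(rng, inputs, is_training)` and
  # `apply(params, state, inputs, is_training)` functions; `without_apply_rng`
  # removes the rng argument from `apply`, matching the calls in this file.
  return hk.without_apply_rng(hk.transform_with_state(backbone_fn))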
def _make_initial_state(
    self,
    rng: jnp.ndarray,
    dummy_input: dataset.Batch,
) -> _ByolExperimentState:
  """BYOL's _ByolExperimentState initialization.

  Args:
    rng: random number generator used to initialize parameters. If working in
      a multi device setup, this needs to be a ShardedArray.
    dummy_input: a dummy image, used to compute intermediate output shapes.

  Returns:
    Initial BYOL state.
  """
  rng_online, rng_target = jax.random.split(rng)

  if self._should_transpose_images():
    dummy_input = dataset.transpose_images(dummy_input)

  # Online and target parameters are initialized using different rngs; in our
  # experiments we did not notice a significant difference compared to using
  # the same rng for both.
  online_params, online_state = self.forward.init(
      rng_online,
      dummy_input,
      is_training=True,
  )
  target_params, target_state = self.forward.init(
      rng_target,
      dummy_input,
      is_training=True,
  )
  opt_state = self._optimizer(0).init(online_params)
  return _ByolExperimentState(
      online_params=online_params,
      target_params=target_params,
      opt_state=opt_state,
      online_state=online_state,
      target_state=target_state,
  )
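
# A minimal sketch of the multi-device case mentioned in the docstring above:
# `rng` is split into one key per local device and `_make_initial_state` is
# mapped over devices with `jax.pmap`, so every field of the returned state
# carries a leading device axis. `experiment` is a placeholder, and
# `dummy_input` is assumed to already be replicated across devices.
def _example_sharded_init(experiment, rng, dummy_input):
  num_devices = jax.local_device_count()
  sharded_rng = jax.random.split(rng, num_devices)  # One key per device.
  init_fn = jax.pmap(experiment._make_initial_state, axis_name='i')
  return init_fn(sharded_rng, dummy_input)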
def loss_fn(
    self,
    online_params: hk.Params,
    target_params: hk.Params,
    online_state: hk.State,
    target_state: hk.State,
    rng: jnp.ndarray,
    inputs: dataset.Batch,
) -> Tuple[jnp.ndarray, Tuple[Mapping[Text, hk.State], LogsDict]]:
  """Computes BYOL's loss function.

  Args:
    online_params: parameters of the online network (the loss is later
      differentiated with respect to the online parameters).
    target_params: parameters of the target network.
    online_state: internal state of the online network.
    target_state: internal state of the target network.
    rng: random number generator state.
    inputs: inputs, containing two batches of crops from the same images
      (view1 and view2) and labels.

  Returns:
    BYOL's loss, together with a mapping containing the online and target
    networks' updated states after processing inputs, and various logs.
  """
  if self._should_transpose_images():
    inputs = dataset.transpose_images(inputs)
  inputs = augmentations.postprocess(inputs, rng)
  labels = inputs['labels']

  online_network_out, online_state = self.forward.apply(
      params=online_params,
      state=online_state,
      inputs=inputs,
      is_training=True)
  target_network_out, target_state = self.forward.apply(
      params=target_params,
      state=target_state,
      inputs=inputs,
      is_training=True)

  # Representation loss.
  # The stop_gradient is not strictly necessary, since the loss is only
  # differentiated with respect to the online parameters in the update step.
  # We leave it in to make explicit that gradients are not backpropagated
  # through the target network.
  repr_loss = helpers.regression_loss(
      online_network_out['prediction_view1'],
      jax.lax.stop_gradient(target_network_out['projection_view2']))
  repr_loss = repr_loss + helpers.regression_loss(
      online_network_out['prediction_view2'],
      jax.lax.stop_gradient(target_network_out['projection_view1']))
  repr_loss = jnp.mean(repr_loss)

  # Classification loss (with gradients stopped from flowing into the ResNet).
  # This is used to evaluate the representation quality during training.
  classif_loss = helpers.softmax_cross_entropy(
      logits=online_network_out['logits_view1'],
      labels=jax.nn.one_hot(labels, self._num_classes))

  top1_correct = helpers.topk_accuracy(
      online_network_out['logits_view1'],
      inputs['labels'],
      topk=1,
  )
  top5_correct = helpers.topk_accuracy(
      online_network_out['logits_view1'],
      inputs['labels'],
      topk=5,
  )
  top1_acc = jnp.mean(top1_correct)
  top5_acc = jnp.mean(top5_correct)

  classif_loss = jnp.mean(classif_loss)
  loss = repr_loss + classif_loss
  logs = dict(
      loss=loss,
      repr_loss=repr_loss,
      classif_loss=classif_loss,
      top1_accuracy=top1_acc,
      top5_accuracy=top5_acc,
  )
  return loss, (dict(online_state=online_state, target_state=target_state),
                logs)
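
# A minimal sketch, not the original update function, of how `loss_fn` is
# typically used in a BYOL training step: the loss is differentiated with
# respect to the online parameters only (`argnums=0`), the optimizer updates
# the online network, and the target network is moved towards the online one
# by an exponential moving average. `experiment`, `optimizer` and `target_ema`
# are illustrative placeholders; `_ByolExperimentState` is assumed to be a
# NamedTuple, and `optax` to be available as in the rest of this file.
def _example_update_step(experiment, byol_state, optimizer, rng, inputs,
                         target_ema=0.996):
  grad_fn = jax.grad(experiment.loss_fn, argnums=0, has_aux=True)
  grads, (net_states, logs) = grad_fn(
      byol_state.online_params,
      byol_state.target_params,
      byol_state.online_state,
      byol_state.target_state,
      rng,
      inputs,
  )
  updates, opt_state = optimizer.update(
      grads, byol_state.opt_state, byol_state.online_params)
  online_params = optax.apply_updates(byol_state.online_params, updates)
  # BYOL's target update: an exponential moving average of the online
  # parameters; the target network never takes a gradient step of its own.
  target_params = jax.tree_util.tree_map(
      lambda t, o: target_ema * t + (1. - target_ema) * o,
      byol_state.target_params, online_params)
  new_state = byol_state._replace(
      online_params=online_params,
      target_params=target_params,
      opt_state=opt_state,
      online_state=net_states['online_state'],
      target_state=net_states['target_state'],
  )
  return new_state, logs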