def _eval_epoch(self, subset: Text, batch_size: int):
        """Evaluates the online network over one full pass of a dataset split.

        The online params/state are taken from `self._byol_state` via
        `helpers.get_first`. Each batch is scored with `self.eval_batch_jit`.
        Every returned scalar is summed over its leading axis and accumulated
        across batches. The sums are then divided by the total number of
        examples seen (counted from `inputs['labels'].shape[0]`).

        Args:
          subset: name of the split, parsed with `dataset.Split.from_string`.
          batch_size: batch size passed to `dataset.load` as `batch_dims`.

        Returns:
          A pytree with the same structure as the output of `eval_batch_jit`,
          holding accumulated sums divided by the number of examples. If the
          split yields no batches, this is `None`: nothing was accumulated, and
          mapping over `None` returns `None`.
        """
        num_samples = 0.
        summed_scalars = None

        params = helpers.get_first(self._byol_state.online_params)
        state = helpers.get_first(self._byol_state.online_state)
        split = dataset.Split.from_string(subset)

        dataset_iterator = dataset.load(
            split,
            preprocess_mode=dataset.PreprocessMode.EVAL,
            transpose=self._should_transpose_images(),
            batch_dims=[batch_size])

        for inputs in dataset_iterator:
            num_samples += inputs['labels'].shape[0]
            scalars = self.eval_batch_jit(params, state, inputs)

            # Sum each scalar over axis 0. This assumes eval_batch_jit returns
            # per-example values along that axis (TODO confirm), which makes
            # the final division a per-example mean.
            scalars = jax.tree_util.tree_map(
                lambda x: jnp.sum(x, axis=0), scalars)
            if summed_scalars is None:
                summed_scalars = scalars
            else:
                # jax.tree_util.tree_map accepts several trees; it replaces the
                # removed jax.tree_multimap.
                summed_scalars = jax.tree_util.tree_map(
                    jnp.add, summed_scalars, scalars)

        mean_scalars = jax.tree_util.tree_map(
            lambda x: x / num_samples, summed_scalars)
        return mean_scalars
    def evaluate(self, global_step, **unused_args):
        """Runs one evaluation pass, logs its scalars, and returns them.

        Args:
          global_step: step counter. It is reduced with `helpers.get_first`,
            converted with `np.array`, and used only for the log line.
          **unused_args: accepted for call compatibility and ignored.

        Returns:
          The result of `self._eval_epoch(**self._evaluation_config)`, fetched
          to host memory with `jax.device_get`.
        """
        step_value = np.array(helpers.get_first(global_step))
        eval_config = self._evaluation_config
        host_scalars = jax.device_get(self._eval_epoch(**eval_config))
        logging.info('[Step %d] Eval scalars: %s', step_value, host_scalars)
        return host_scalars
    def step(self, *, global_step, rng):
        """Performs a single training step.

        On first use this builds the training input iterator by calling
        `self._initialize_train()`. It then takes the next batch and applies
        one `self.update_pmap` call, storing the new state back into
        `self._byol_state`.

        Args:
          global_step: current step counter, forwarded to `update_pmap`.
          rng: random key, forwarded to `update_pmap`.

        Returns:
          The scalars returned by `update_pmap`, passed through
          `helpers.get_first` (presumably this selects one device's copy;
          verify against `helpers`).
        """
        # `step` must be defined only once in this class. In a class body the
        # last `def` silently replaces any earlier one with the same name.
        if self._train_input is None:
            self._initialize_train()

        inputs = next(self._train_input)

        self._byol_state, scalars = self.update_pmap(
            self._byol_state,
            global_step=global_step,
            rng=rng,
            inputs=inputs,
        )

        return helpers.get_first(scalars)