def _eval_epoch(self, subset: Text, batch_size: int):
        """Evaluates the online network over one pass of a dataset split.

        Every batch is scored with `self.eval_batch_jit`. The returned scalars
        are summed along their leading axis, accumulated across batches, and
        finally divided by the total number of labels seen.

        Args:
          subset: Name of the split; parsed with `dataset.Split.from_string`.
          batch_size: Passed to `dataset.load` as the single batch dimension.

        Returns:
          A pytree with the structure of `self.eval_batch_jit`'s output, where
          each leaf is the summed value divided by the number of samples, or
          None if the dataset iterator yielded no batches.
        """
        # Kept as a float so the final division is a true division.
        num_samples = 0.
        summed_scalars = None

        params = helpers.get_first(self._byol_state.online_params)
        state = helpers.get_first(self._byol_state.online_state)
        split = dataset.Split.from_string(subset)

        dataset_iterator = dataset.load(
            split,
            preprocess_mode=dataset.PreprocessMode.EVAL,
            transpose=self._should_transpose_images(),
            batch_dims=[batch_size])

        for inputs in dataset_iterator:
            num_samples += inputs['labels'].shape[0]
            scalars = self.eval_batch_jit(params, state, inputs)

            # Accumulate the sum of scalars for each step.
            scalars = jax.tree_util.tree_map(
                lambda x: jnp.sum(x, axis=0), scalars)
            if summed_scalars is None:
                summed_scalars = scalars
            else:
                # `jax.tree_multimap` was removed from JAX; `tree_map` accepts
                # several trees and applies the function leaf-wise.
                summed_scalars = jax.tree_util.tree_map(
                    jnp.add, summed_scalars, scalars)

        if summed_scalars is None:
            # No batches were produced, so there is nothing to average.
            return None

        mean_scalars = jax.tree_util.tree_map(
            lambda x: x / num_samples, summed_scalars)
        return mean_scalars
    def _build_train_input(self) -> Generator[dataset.Batch, None, None]:
        """Builds the (infinitely looping) pretraining input iterator.

        The global batch size is split evenly over all devices reported by
        `jax.device_count()`; batches are requested with dimensions
        `[jax.local_device_count(), per-device batch size]`.

        Returns:
          The result of `dataset.load` for the TRAIN_AND_VALID split with
          PRETRAIN preprocessing.

        Raises:
          ValueError: If the global batch size is not a multiple of the device
            count.
        """
        device_total = jax.device_count()
        batch_total = self._batch_size

        # Refuse uneven splits up front: every device must get the same share.
        if batch_total % device_total:
            raise ValueError(
                f'Global batch size {batch_total} must be divisible by '
                f'num devices {device_total}')
        device_share = batch_total // device_total

        transpose_images = self._should_transpose_images()
        leading_dims = [jax.local_device_count(), device_share]
        return dataset.load(
            dataset.Split.TRAIN_AND_VALID,
            preprocess_mode=dataset.PreprocessMode.PRETRAIN,
            transpose=transpose_images,
            batch_dims=leading_dims)
    def _build_train_input(self):
        """See base class.

        Loads the TRAIN_AND_VALID split with LINEAR_TRAIN preprocessing, with
        batch dimensions `[jax.local_device_count(), per-device batch size]`.

        Raises:
          ValueError: If `self._batch_size` is not evenly divisible by
            `jax.device_count()`.
        """
        n_devices = jax.device_count()
        full_batch = self._batch_size

        # The per-device size is only meaningful when nothing is left over.
        if full_batch % n_devices != 0:
            raise ValueError(
                f'Global batch size {full_batch} must be divisible by '
                f'num devices {n_devices}')
        batch_per_device = full_batch // n_devices

        load_options = dict(
            preprocess_mode=dataset.PreprocessMode.LINEAR_TRAIN,
            transpose=self._should_transpose_images(),
            batch_dims=[jax.local_device_count(), batch_per_device])
        return dataset.load(dataset.Split.TRAIN_AND_VALID, **load_options)