def _eval_epoch(self, subset: Text, batch_size: int):
  """Evaluates an epoch."""
  num_samples = 0.
  summed_scalars = None

  params = helpers.get_first(self._byol_state.online_params)
  state = helpers.get_first(self._byol_state.online_state)
  split = dataset.Split.from_string(subset)

  dataset_iterator = dataset.load(
      split,
      preprocess_mode=dataset.PreprocessMode.EVAL,
      transpose=self._should_transpose_images(),
      batch_dims=[batch_size])

  for inputs in dataset_iterator:
    num_samples += inputs['labels'].shape[0]
    scalars = self.eval_batch_jit(params, state, inputs)

    # Accumulate the sum of scalars for each step.
    scalars = jax.tree_map(lambda x: jnp.sum(x, axis=0), scalars)
    if summed_scalars is None:
      summed_scalars = scalars
    else:
      summed_scalars = jax.tree_multimap(jnp.add, summed_scalars, scalars)

  mean_scalars = jax.tree_map(lambda x: x / num_samples, summed_scalars)
  return mean_scalars
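Note that `jax.tree_multimap` is kept as in the original, but it has since been removed from JAX; `jax.tree_util.tree_map` now accepts multiple trees directly. Below is a minimal, self-contained sketch of the same weighted-mean accumulation pattern using the current API. The `mean_over_batches` helper and the fake batch data are illustrative only, not part of the codebase:

import jax
import jax.numpy as jnp


def mean_over_batches(batches):
  """Averages a stream of per-batch scalar trees, weighted by sample count."""
  num_samples = 0.
  summed = None
  for batch_size, scalars in batches:
    num_samples += batch_size
    # Sum each leaf over its leading (per-device) axis, as in `_eval_epoch`.
    scalars = jax.tree_util.tree_map(lambda x: jnp.sum(x, axis=0), scalars)
    if summed is None:
      summed = scalars
    else:
      # Multi-tree `tree_map` replaces the removed `jax.tree_multimap`.
      summed = jax.tree_util.tree_map(jnp.add, summed, scalars)
  return jax.tree_util.tree_map(lambda x: x / num_samples, summed)


# Two fake batches of 4 samples each, with per-device summed accuracies.
batches = [(4, {'top1_acc': jnp.array([3.0])}),
           (4, {'top1_acc': jnp.array([1.0])})]
print(mean_over_batches(batches))  # {'top1_acc': 0.5}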
def _build_train_input(self) -> Generator[dataset.Batch, None, None]:
  """Loads the (infinitely looping) dataset iterator."""
  num_devices = jax.device_count()
  global_batch_size = self._batch_size
  per_device_batch_size, ragged = divmod(global_batch_size, num_devices)

  if ragged:
    raise ValueError(
        f'Global batch size {global_batch_size} must be divisible by '
        f'num devices {num_devices}')

  return dataset.load(
      dataset.Split.TRAIN_AND_VALID,
      preprocess_mode=dataset.PreprocessMode.PRETRAIN,
      transpose=self._should_transpose_images(),
      batch_dims=[jax.local_device_count(), per_device_batch_size])
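The leading `jax.local_device_count()` batch dimension exists so each host's shard of data can be fed directly to `jax.pmap`, which maps over the first axis with one slice per local device. A small sketch of that contract; the batch size, image shape, and the trivial pmapped function are assumptions for illustration:

import jax
import jax.numpy as jnp

global_batch_size = 64  # assumed value, purely for illustration
num_devices = jax.device_count()
per_device_batch_size, ragged = divmod(global_batch_size, num_devices)
assert not ragged, 'global batch size must divide evenly across devices'

# Fake image shard shaped the way `dataset.load` delivers it here:
# [local devices, per-device batch, height, width, channels].
images = jnp.zeros(
    (jax.local_device_count(), per_device_batch_size, 32, 32, 3))

# `jax.pmap` maps over the leading axis, one slice per local device.
per_device_means = jax.pmap(lambda x: x.mean())(images)
print(per_device_means.shape)  # (jax.local_device_count(),)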
def _build_train_input(self):
  """See base class."""
  num_devices = jax.device_count()
  global_batch_size = self._batch_size
  per_device_batch_size, ragged = divmod(global_batch_size, num_devices)

  if ragged:
    raise ValueError(
        f'Global batch size {global_batch_size} must be divisible by '
        f'num devices {num_devices}')

  return dataset.load(
      dataset.Split.TRAIN_AND_VALID,
      preprocess_mode=dataset.PreprocessMode.LINEAR_TRAIN,
      transpose=self._should_transpose_images(),
      batch_dims=[jax.local_device_count(), per_device_batch_size])
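This override is identical to the pretraining one above except for the preprocess mode (`LINEAR_TRAIN` rather than `PRETRAIN`). If the duplication ever becomes a burden, one option is a shared helper parameterized by mode; a sketch under that assumption, where `_build_train_input_for_mode` is hypothetical (not part of the original classes) and the same module-level imports (`jax`, `dataset`, `Generator`) are assumed:

def _build_train_input_for_mode(
    self, mode: dataset.PreprocessMode
) -> Generator[dataset.Batch, None, None]:
  """Builds the train input iterator for the given preprocess mode."""
  num_devices = jax.device_count()
  per_device_batch_size, ragged = divmod(self._batch_size, num_devices)
  if ragged:
    raise ValueError(
        f'Global batch size {self._batch_size} must be divisible by '
        f'num devices {num_devices}')
  return dataset.load(
      dataset.Split.TRAIN_AND_VALID,
      preprocess_mode=mode,
      transpose=self._should_transpose_images(),
      batch_dims=[jax.local_device_count(), per_device_batch_size])

Each subclass's `_build_train_input` would then reduce to a one-line call passing `dataset.PreprocessMode.PRETRAIN` or `dataset.PreprocessMode.LINEAR_TRAIN`.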