def _build_eval_input(self):
    """Create the evaluation input pipeline described by ``self.config``.

    Reads ``eval_batch_size``, ``eval_subset`` and ``which_dataset`` from the
    config, along with the optional ``eval_preproc`` (default
    ``'crop_resize'``) and ``transpose`` (default ``False``) entries, and uses
    ``self.test_imsize`` for both entries of ``image_size``.

    The per-device batch size is ``eval_batch_size`` floor-divided by the
    local device count, so any remainder is dropped without an error.

    Returns:
        Whatever ``input_pipeline.load`` returns for the requested split.
    """
    cfg = self.config
    total_batch = cfg.eval_batch_size
    local_devices = jax.local_device_count()
    per_device_batch = total_batch // local_devices

    eval_split = input_pipeline.Split.from_string(cfg.eval_subset)
    preproc_name = cfg.get('eval_preproc', 'crop_resize')

    load_kwargs = dict(
        is_training=False,
        # Leading dim is the local device count, then the per-device batch.
        batch_dims=[local_devices, per_device_batch],
        transpose=cfg.get('transpose', False),
        image_size=(self.test_imsize, self.test_imsize),
        name=cfg.which_dataset,
        eval_preproc=preproc_name,
        fake_data=False,
    )
    return input_pipeline.load(eval_split, **load_kwargs)
def build_eval_input(data_dir, batch_size, img_size):
    """Create the ImageNet test-split input pipeline.

    Args:
        data_dir: Forwarded unchanged to ``input_pipeline.load`` as
            ``data_dir``.
        batch_size: Batch size that is floor-divided by the local device
            count to get the per-device batch size; any remainder is dropped
            without an error.
        img_size: Value used for both entries of ``image_size``.

    Returns:
        Whatever ``input_pipeline.load`` returns for ``Split.TEST`` with
        ``'crop_resize'`` eval preprocessing and ``transpose=True``.
    """
    local_devices = jax.local_device_count()
    per_device_batch = batch_size // local_devices

    load_kwargs = dict(
        data_dir=data_dir,
        is_training=False,
        # Leading dim is the local device count, then the per-device batch.
        batch_dims=[local_devices, per_device_batch],
        transpose=True,
        image_size=(img_size, img_size),
        name='imagenet',
        eval_preproc='crop_resize',
        fake_data=False,
    )
    return input_pipeline.load(input_pipeline.Split.TEST, **load_kwargs)
def _build_train_input(self):
    """Create the training input pipeline described by ``self.config``.

    ``train_batch_size`` is treated as the global batch size: it is divided
    by ``jax.device_count()`` to get the per-device batch size, while the
    leading batch dimension passed to the loader is
    ``jax.local_device_count()``.

    Also reads ``augment_name`` and ``which_dataset`` from the config, the
    optional ``transpose`` (default ``False``) and ``augment_before_mix``
    (default ``True``) entries, and ``self.train_imsize`` for both entries
    of ``image_size``.

    Returns:
        Whatever ``input_pipeline.load`` returns for
        ``Split.TRAIN_AND_VALID``.

    Raises:
        ValueError: If ``train_batch_size`` is not divisible by the total
            number of devices.
    """
    cfg = self.config
    n_devices = jax.device_count()
    total_batch = cfg.train_batch_size

    per_device_batch, remainder = divmod(total_batch, n_devices)
    if remainder:
        message = (
            f'Global batch size {total_batch} must be divisible by '
            f'num devices {n_devices}'
        )
        raise ValueError(message)

    load_kwargs = dict(
        is_training=True,
        # Leading dim is the local device count, then the per-device batch.
        batch_dims=[jax.local_device_count(), per_device_batch],
        transpose=cfg.get('transpose', False),
        image_size=(self.train_imsize, self.train_imsize),
        augment_name=cfg.augment_name,
        augment_before_mix=cfg.get('augment_before_mix', True),
        name=cfg.which_dataset,
        fake_data=False,
    )
    return input_pipeline.load(
        input_pipeline.Split.TRAIN_AND_VALID, **load_kwargs)
def build_train_input(data_dir, batch_size, img_size, augmentation):
    """Create the ImageNet training input pipeline.

    ``batch_size`` is divided by ``jax.device_count()`` to get the
    per-device batch size, while the leading batch dimension passed to the
    loader is ``jax.local_device_count()``.

    Args:
        data_dir: Forwarded unchanged to ``input_pipeline.load`` as
            ``data_dir``.
        batch_size: Batch size to split evenly across all devices.
        img_size: Value used for both entries of ``image_size``.
        augmentation: Forwarded to the loader as ``augment_name``.

    Returns:
        Whatever ``input_pipeline.load`` returns for
        ``Split.TRAIN_AND_VALID`` with ``transpose=True`` and
        ``augment_before_mix=True``.

    Raises:
        ValueError: If ``batch_size`` is not divisible by the total number
            of devices.
    """
    n_devices = jax.device_count()
    per_device_batch, remainder = divmod(batch_size, n_devices)
    if remainder:
        message = (
            f'Batch size {batch_size} must be divisible by '
            f'num devices {n_devices}'
        )
        raise ValueError(message)

    load_kwargs = dict(
        data_dir=data_dir,
        is_training=True,
        # Leading dim is the local device count, then the per-device batch.
        batch_dims=[jax.local_device_count(), per_device_batch],
        transpose=True,
        image_size=(img_size, img_size),
        augment_name=augmentation,
        augment_before_mix=True,
        name='imagenet',
        fake_data=False,
    )
    return input_pipeline.load(
        input_pipeline.Split.TRAIN_AND_VALID, **load_kwargs)