Example #1
0
 def _build_eval_input(self):
     """Build the evaluation input pipeline from ``self.config``.

     Reads ``eval_batch_size``, ``eval_subset`` and ``which_dataset`` from the
     config, plus the optional ``eval_preproc`` (default ``'crop_resize'``) and
     ``transpose`` (default ``False``) entries, and uses ``self.test_imsize``
     as the side length of the square evaluation images.

     ``eval_batch_size`` is split over ``jax.local_device_count()`` devices
     with floor division, so the leading batch dimensions are
     ``[local_devices, eval_batch_size // local_devices]`` and any remainder
     is not included.

     Returns:
       The result of ``input_pipeline.load`` for the evaluation split.

     Raises:
       ValueError: if ``eval_batch_size`` is smaller than the number of local
         devices (the per-device batch would otherwise be zero or negative).
     """
     num_local_devices = jax.local_device_count()
     bs_per_device = self.config.eval_batch_size // num_local_devices
     # Fail loudly instead of handing a degenerate batch_dims to the loader.
     if bs_per_device <= 0:
         raise ValueError(
             f'Eval batch size {self.config.eval_batch_size} must be at least '
             f'the number of local devices {num_local_devices}')
     split = input_pipeline.Split.from_string(self.config.eval_subset)
     eval_preproc = self.config.get('eval_preproc', 'crop_resize')
     return input_pipeline.load(
         split,
         is_training=False,
         batch_dims=[num_local_devices, bs_per_device],
         transpose=self.config.get('transpose', False),
         image_size=(self.test_imsize, self.test_imsize),
         name=self.config.which_dataset,
         eval_preproc=eval_preproc,
         fake_data=False)
def build_eval_input(data_dir, batch_size, img_size, *, split=None,
                     eval_preproc='crop_resize', transpose=True,
                     name='imagenet'):
    """Build an evaluation input pipeline.

    Args:
      data_dir: Data directory forwarded to ``input_pipeline.load``.
      batch_size: Batch size split over ``jax.local_device_count()`` devices
        with floor division; any remainder is not included in the batch.
      img_size: Side length of the square evaluation images.
      split: Dataset split to load. ``None`` (the default) selects
        ``input_pipeline.Split.TEST``.
      eval_preproc: Evaluation preprocessing name (default ``'crop_resize'``).
      transpose: Forwarded to ``input_pipeline.load`` (default ``True``).
      name: Dataset name forwarded to ``input_pipeline.load``
        (default ``'imagenet'``).

    Returns:
      The result of ``input_pipeline.load``.

    Raises:
      ValueError: if ``batch_size`` is smaller than the number of local
        devices (the per-device batch would otherwise be zero or negative).
    """
    num_local_devices = jax.local_device_count()
    bs_per_device = batch_size // num_local_devices
    # Fail loudly instead of handing a degenerate batch_dims to the loader.
    if bs_per_device <= 0:
        raise ValueError(
            f'Batch size {batch_size} must be at least the number of local '
            f'devices {num_local_devices}')
    if split is None:
        split = input_pipeline.Split.TEST
    return input_pipeline.load(
        split,
        data_dir=data_dir,
        is_training=False,
        batch_dims=[num_local_devices, bs_per_device],
        transpose=transpose,
        image_size=(img_size, img_size),
        name=name,
        eval_preproc=eval_preproc,
        fake_data=False)
Example #3
0
 def _build_train_input(self):
     """Build the training input pipeline from ``self.config``.

     The global ``train_batch_size`` is divided evenly over all devices
     (``jax.device_count()``). The loader is asked for batches with leading
     dimensions ``[jax.local_device_count(), per_device_batch]`` drawn from
     ``Split.TRAIN_AND_VALID``, at ``self.train_imsize`` square resolution.
     Optional config entries: ``transpose`` (default ``False``) and
     ``augment_before_mix`` (default ``True``).

     Returns:
       The result of ``input_pipeline.load`` for the training split.

     Raises:
       ValueError: if ``train_batch_size`` is not a multiple of the total
         device count.
     """
     cfg = self.config
     total_devices = jax.device_count()
     global_batch_size = cfg.train_batch_size
     per_device, remainder = divmod(global_batch_size, total_devices)
     if remainder != 0:
         raise ValueError(
             f'Global batch size {global_batch_size} must be divisible by '
             f'num devices {total_devices}')
     side = self.train_imsize
     return input_pipeline.load(
         input_pipeline.Split.TRAIN_AND_VALID,
         is_training=True,
         batch_dims=[jax.local_device_count(), per_device],
         transpose=cfg.get('transpose', False),
         image_size=(side, side),
         augment_name=cfg.augment_name,
         augment_before_mix=cfg.get('augment_before_mix', True),
         name=cfg.which_dataset,
         fake_data=False)
def build_train_input(data_dir, batch_size, img_size, augmentation, *,
                      transpose=True, augment_before_mix=True,
                      name='imagenet'):
    """Build a training input pipeline over ``Split.TRAIN_AND_VALID``.

    Args:
      data_dir: Data directory forwarded to ``input_pipeline.load``.
      batch_size: Global batch size; must be a positive multiple of
        ``jax.device_count()``.
      img_size: Side length of the square training images.
      augmentation: Augmentation name forwarded as ``augment_name``.
      transpose: Forwarded to ``input_pipeline.load`` (default ``True``).
      augment_before_mix: Forwarded to ``input_pipeline.load``
        (default ``True``).
      name: Dataset name forwarded to ``input_pipeline.load``
        (default ``'imagenet'``).

    Returns:
      The result of ``input_pipeline.load``; batches have leading dimensions
      ``[jax.local_device_count(), batch_size // jax.device_count()]``.

    Raises:
      ValueError: if ``batch_size`` is not divisible by the device count, or
        is not positive.
    """
    num_devices = jax.device_count()
    bs_per_device, ragged = divmod(batch_size, num_devices)
    if ragged:
        raise ValueError(
            f'Batch size {batch_size} must be divisible by num devices {num_devices}'
        )
    # divmod(0, n) and divmod(-k*n, n) are not ragged but give an unusable
    # per-device batch, so reject them explicitly.
    if bs_per_device <= 0:
        raise ValueError(f'Batch size {batch_size} must be positive')
    return input_pipeline.load(
        input_pipeline.Split.TRAIN_AND_VALID,
        data_dir=data_dir,
        is_training=True,
        batch_dims=[jax.local_device_count(), bs_per_device],
        transpose=transpose,
        image_size=(img_size, img_size),
        augment_name=augmentation,
        augment_before_mix=augment_before_mix,
        name=name,
        fake_data=False)