def training_data_fn():
    """Build the training input_fn, picking a pipeline by configuration.

    Selection order:
      1. DALI pipeline when ``use_dali`` is set and an index dir exists.
      2. TFRecords pipeline when a data dir is configured.
      3. Synthetic data otherwise (benchmarking fallback).

    Relies on names from the enclosing scope (``filenames``,
    ``idx_filenames``, ``batch_size``) and on ``self.run_hparams`` for
    preprocessing settings — presumably defined in the surrounding
    method; confirm against the full file.

    Returns:
        An input_fn suitable for the training loop, produced by one of
        the ``data_utils.get_*_input_fn`` factories.
    """
    if self.run_hparams.use_dali and self.run_hparams.data_idx_dir is not None:
        # Only one worker logs, to avoid duplicated output under Horovod.
        if hvd.rank() == 0:
            print("Using DALI input... ")

        return data_utils.get_dali_input_fn(
            filenames=filenames,
            idx_filenames=idx_filenames,
            batch_size=batch_size,
            height=self.run_hparams.height,
            width=self.run_hparams.width,
            training=True,
            distort_color=self.run_hparams.distort_colors,
            num_threads=self.run_hparams.num_preprocessing_threads,
            # A fixed seed requests deterministic preprocessing.
            deterministic=self.run_hparams.seed is not None,
        )

    elif self.run_hparams.data_dir is not None:
        return data_utils.get_tfrecords_input_fn(
            filenames=filenames,
            batch_size=batch_size,
            height=self.run_hparams.height,
            width=self.run_hparams.width,
            training=True,
            distort_color=self.run_hparams.distort_colors,
            num_threads=self.run_hparams.num_preprocessing_threads,
            deterministic=self.run_hparams.seed is not None,
        )

    else:
        if hvd.rank() == 0:
            print("Using Synthetic Data ...")

        return data_utils.get_synth_input_fn(
            batch_size=batch_size,
            height=self.run_hparams.height,
            width=self.run_hparams.width,
            num_channels=self.run_hparams.n_channels,
            data_format=self.run_hparams.input_format,
            num_classes=self.run_hparams.n_classes,
            dtype=self.run_hparams.dtype,
        )
def train_data_fn(self, params):
    """Build a DALI input_fn over the hard-coded ``/data`` dataset layout.

    Args:
        params: Configuration dict supplied by the caller (Estimator-style);
            only ``params['batch_size']`` is read here.

    Returns:
        The input_fn produced by ``data_utils.get_dali_input_fn``, reading
        either the train or validation split depending on
        ``self.is_training``.
    """
    # NOTE(review): dataset paths are hard-coded; consider promoting them
    # to hyperparameters so the data location is configurable.
    data_dir = '/data'
    data_index_dir = '/data/index_files'

    mode = 'train' if self.is_training else 'validation'

    tf.logging.info("Using DALI input... ")

    # num_samples is unused here but part of the helper's return contract.
    filenames, num_samples = list_filenames_in_dataset(
        data_dir=data_dir, mode=mode)
    idx_filenames = parse_dali_idx_dataset(
        data_idx_dir=data_index_dir, mode=mode)

    return data_utils.get_dali_input_fn(
        filenames=filenames,
        idx_filenames=idx_filenames,
        batch_size=params['batch_size'],
        height=self.image_size,
        width=self.image_size,
        training=self.is_training,
        distort_color=False,
        num_threads=4,
        deterministic=False)