示例#1
0
        def training_data_fn():
            """Return the training input for the configured data source.

            The source is selected from ``self.run_hparams`` in this order:

            1. DALI, when ``use_dali`` is truthy and ``data_idx_dir`` is set.
            2. TFRecords, when ``data_dir`` is set.
            3. Synthetic data otherwise.

            ``filenames``, ``idx_filenames`` and ``batch_size`` come from the
            enclosing scope. Both file-backed readers are asked to be
            deterministic exactly when ``seed`` is not ``None``.

            Returns:
                Whatever the selected ``data_utils.get_*_input_fn`` helper
                returns.
            """
            hparams = self.run_hparams
            dali_requested = hparams.use_dali and hparams.data_idx_dir is not None

            # Guard clause: neither DALI nor a data directory is configured,
            # so fall back to synthetic inputs.
            if not dali_requested and hparams.data_dir is None:
                if hvd.rank() == 0:
                    print("Using Synthetic Data ...")
                return data_utils.get_synth_input_fn(
                    batch_size=batch_size,
                    height=hparams.height,
                    width=hparams.width,
                    num_channels=hparams.n_channels,
                    data_format=hparams.input_format,
                    num_classes=hparams.n_classes,
                    dtype=hparams.dtype,
                )

            # Only rank 0 announces the DALI path; the TFRecords path is silent.
            if dali_requested and hvd.rank() == 0:
                print("Using DALI input... ")

            # Arguments shared by the DALI and TFRecords readers.
            reader_kwargs = dict(
                filenames=filenames,
                batch_size=batch_size,
                height=hparams.height,
                width=hparams.width,
                training=True,
                distort_color=hparams.distort_colors,
                num_threads=hparams.num_preprocessing_threads,
                deterministic=hparams.seed is not None,
            )

            if dali_requested:
                # DALI additionally needs the index files.
                return data_utils.get_dali_input_fn(
                    idx_filenames=idx_filenames, **reader_kwargs)

            return data_utils.get_tfrecords_input_fn(**reader_kwargs)
 def train_data_fn(self, params, data_dir='/data',
                   data_index_dir='/data/index_files'):
     """Build the DALI input for the split selected by ``self.is_training``.

     Despite the name, this serves both splits: ``'train'`` when
     ``self.is_training`` is truthy and ``'validation'`` otherwise.

     Args:
         params: Mapping that must contain ``'batch_size'``.
         data_dir: Directory passed to ``list_filenames_in_dataset``.
             Defaults to the previously hard-coded ``'/data'``.
         data_index_dir: Directory passed to ``parse_dali_idx_dataset``.
             Defaults to the previously hard-coded ``'/data/index_files'``.

     Returns:
         The result of ``data_utils.get_dali_input_fn`` for the selected
         split, with square inputs of side ``self.image_size``.

     Raises:
         ValueError: If ``data_dir`` is ``None``. The original code would
             have fallen through and returned ``None`` silently.
     """
     if data_dir is None:
         raise ValueError("data_dir must not be None when using DALI input")

     mode = 'train' if self.is_training else 'validation'

     tf.logging.info("Using DALI input... ")

     # The helper also returns the sample count, which is not needed here.
     filenames, _ = list_filenames_in_dataset(data_dir=data_dir, mode=mode)
     idx_filenames = parse_dali_idx_dataset(
         data_idx_dir=data_index_dir, mode=mode)

     return data_utils.get_dali_input_fn(
         filenames=filenames,
         idx_filenames=idx_filenames,
         batch_size=params['batch_size'],
         height=self.image_size,
         width=self.image_size,
         training=self.is_training,
         distort_color=False,
         num_threads=4,
         deterministic=False)