Code example #1
File: io.py Project: FXuZ/mlmono
    def _gen_tf_dataset(self):
        # TODO(jdaaph) (URGENT): remove commented-out code from this method and
        #                        declare config constants only in the ctor.

        list_files = tf.data.Dataset.list_files(self.filenames)
        if is_training(self.mode):
            list_files = list_files.shuffle(self.io_config.fn_shuffle_buffer)

        # parallel_interleave is preferred because its ordering is deterministic,
        # which ensures better reproducibility.
        dataset = list_files.apply(
            tf.contrib.data.parallel_interleave(
                tf.data.TFRecordDataset,
                cycle_length=self.interleave_cycle,
                block_length=self.interleave_block))
        dataset = dataset.map(lambda raw: self.parse_file(raw))  # parse each serialized record

        # Always shuffle before batching to promote randomness in the training data.
        # data_shuffle_buffer should be larger than the number of rows in a single
        # data shard (record).
        if is_training(self.mode):
            dataset = dataset.shuffle(self.io_config.data_shuffle_buffer)
        # Repeating num_steps times guarantees enough batches for the whole run.
        dataset = dataset.repeat(self.global_config.trainer.num_steps)
        dataset = dataset.batch(batch_size=self.batch_size)

        # Parse batches in parallel after batching.
        dataset = dataset.map(map_func=self.parse_iter,
                              num_parallel_calls=self.num_parallel_parse)
        return dataset
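
A minimal consumption sketch for a pipeline like this, assuming TF 1.x graph mode
(the snippet already relies on `tf.contrib`); the `reader` instance is hypothetical:

    import tensorflow as tf

    dataset = reader._gen_tf_dataset()           # build the input pipeline
    iterator = dataset.make_one_shot_iterator()  # TF 1.x iterator over the dataset
    next_batch = iterator.get_next()
    with tf.Session() as sess:
        batch = sess.run(next_batch)             # one parsed, batched example
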
Code example #2
def keras_load_data(mode):
    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
    if is_training(mode):
        return x_train, y_train
    else:
        return x_test, y_test
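
A quick usage sketch, assuming `is_training` treats a mode string such as 'train'
as training (the helper itself is not shown in this snippet):

    x, y = keras_load_data(mode='train')
    print(x.shape, y.shape)  # (60000, 28, 28) (60000,) for MNIST training data
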
Code example #3
    def _gen_tf_dataset(self):
        list_files = tf.data.Dataset.list_files(self.filenames)
        if is_training(self.mode):
            list_files = list_files.shuffle(self.io_config.fn_shuffle_buffer)

        # parallel_interleave is preferred because its ordering is deterministic,
        # which ensures better reproducibility.
        dataset = list_files.apply(
            tf.contrib.data.parallel_interleave(
                lambda filename: self.parse_file(filename),
                cycle_length=self.interleave_cycle,
                block_length=self.interleave_block))

        # data_shuffle_buffer should be larger than the number of rows in a single
        # data shard (record). Note that batching happens before shuffling here,
        # so whole batches (not individual rows) are shuffled.
        dataset = dataset.batch(batch_size=self.batch_size)
        if is_training(self.mode):
            dataset = dataset.shuffle(self.io_config.data_shuffle_buffer)
        # NOTE: `num_epochs` is not defined in this snippet; it must come from the
        # enclosing scope (e.g. an attribute set in the ctor).
        dataset = dataset.repeat(num_epochs)
        return dataset
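
Unlike example #1, this variant batches before shuffling, so only the order of
whole batches is randomized. A self-contained sketch of the row-level alternative
(shuffle before batch), runnable under TF 1.x:

    import tensorflow as tf

    ds = tf.data.Dataset.range(10)
    ds = ds.shuffle(buffer_size=10)  # shuffle individual rows
    ds = ds.batch(4)                 # then batch the shuffled rows
    ds = ds.repeat(2)                # two passes over the 10 rows
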
Code example #4
File: io.py Project: FXuZ/mlmono
    def _gen_tf_dataset(self):
        '''Construct a dataset using `tf.data.Dataset`.'''
        images, labels = self.keras_load_data(self.mode)
        dataset = tf.data.Dataset.from_tensor_slices((images, labels))

        if is_training(self.mode):
            dataset = dataset.shuffle(self.data_shuffle_buffer)
        dataset = dataset.map(self.parse_keras_tensor)
        dataset = dataset.batch(self.batch_size)
        dataset = dataset.repeat()  # no count: repeat indefinitely
        dataset = dataset.prefetch(tf.contrib.data.AUTOTUNE)

        return dataset
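
The same shuffle → map → batch → repeat → prefetch chain as a self-contained
sketch on toy tensors (TF 1.x, matching the snippet's use of
`tf.contrib.data.AUTOTUNE`); the normalization lambda stands in for
`parse_keras_tensor`:

    import numpy as np
    import tensorflow as tf

    images = np.zeros((100, 28, 28), dtype=np.float32)  # toy stand-ins
    labels = np.zeros((100,), dtype=np.int64)

    ds = tf.data.Dataset.from_tensor_slices((images, labels))
    ds = ds.shuffle(100)
    ds = ds.map(lambda x, y: (x / 255.0, y))    # stand-in for parse_keras_tensor
    ds = ds.batch(32)
    ds = ds.repeat()                            # repeat indefinitely
    ds = ds.prefetch(tf.contrib.data.AUTOTUNE)  # overlap preprocessing with training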