def make_dataset(filename, mode):
    """Return the TFRecord dataset read from ``filename`` for ``mode``.

    Sizing hyper-parameters are looked up in ``job``, which is not defined in
    this function (presumably the enclosing scope). ``seq_depth`` falls back
    to 4 when ``job`` has no such key.
    """
    build_dataset = tfrecord_batcher.tfrecord_dataset
    batch_size = job['batch_size']
    seq_length = job['seq_length']
    seq_depth = job.get('seq_depth', 4)
    num_targets = job['num_targets']
    target_length = job['target_length']
    # NOTE(review): this variant passes num_targets BEFORE target_length,
    # while the other copies in this file pass target_length first --
    # confirm against the tfrecord_batcher.tfrecord_dataset signature in use.
    return build_dataset(filename, batch_size, seq_length, seq_depth,
                         num_targets, target_length, mode=mode)
# Example #2
 def make_dataset(pat, mode):
     """Build a non-repeating TFRecord dataset from ``pat`` for ``mode``.

     ``pat`` is forwarded unchanged as the batcher's data source (presumably
     a file glob -- confirm with callers). Sizes come from ``job``, defined
     outside this function; ``seq_depth`` defaults to 4 when missing.
     """
     reader = tfrecord_batcher.tfrecord_dataset
     positional = (
         pat,
         job['batch_size'],
         job['seq_length'],
         job.get('seq_depth', 4),
         job['target_length'],
         job['num_targets'],
     )
     return reader(*positional, mode=mode, repeat=False)
 def make_dataset(filename):
     """Create a shuffled TFRecord dataset from ``filename``.

     The five sizing values are read, in order, from ``job`` (defined outside
     this function); every key is required. The batcher is asked to shuffle
     and to trim EOS (``shuffle=True``, ``trim_eos=True``).
     """
     batcher_fn = tfrecord_batcher.tfrecord_dataset
     size_keys = ('batch_size', 'seq_length', 'seq_depth',
                  'num_targets', 'target_width')
     size_args = [job[key] for key in size_keys]
     return batcher_fn(filename, *size_args, shuffle=True, trim_eos=True)
# Example #4
 def make_dataset(filename, mode):
     """Return a single-pass TFRecord dataset for ``filename`` in ``mode``.

     Hyper-parameters come from ``job`` (defined outside this function);
     ``seq_depth`` defaults to 4 when absent. ``repeat=False`` is passed so
     the batcher does not repeat the data.
     """
     dataset_fn = tfrecord_batcher.tfrecord_dataset
     batch_size = job["batch_size"]
     seq_length = job["seq_length"]
     seq_depth = job.get("seq_depth", 4)
     target_length = job["target_length"]
     num_targets = job["num_targets"]
     options = {"mode": mode, "repeat": False}
     return dataset_fn(filename, batch_size, seq_length, seq_depth,
                       target_length, num_targets, **options)
    def make_dataset(loc, mode, is_dir=False):
        """Create the TFRecord dataset for ``mode``.

        Args:
          loc: Either a filename string or a list of filename strings, passed
            straight through as the data source of
            ``tfrecord_batcher.tfrecord_dataset``. When ``is_dir`` is True,
            ``loc`` is instead a directory holding the per-split files.
          mode: A ``tf.estimator.ModeKeys`` value. In directory mode it picks
            the split: TRAIN -> ``train-*.tfr``, EVAL -> ``valid-*.tfr``,
            PREDICT -> ``test-*.tfr``.
          is_dir: If True, join ``loc`` with the split pattern chosen by
            ``mode`` and use that as the data source.

        Returns:
          The result of ``tfrecord_batcher.tfrecord_dataset`` built with
          ``repeat=False``. ``seq_depth`` is read from ``job`` with a
          fallback of 4 in both modes.

        Raises:
          ValueError: if ``is_dir`` is True and ``mode`` is not TRAIN, EVAL
            or PREDICT. (A subclass of ``Exception``, so handlers written
            for the previous generic ``Exception`` still catch it.)
        """
        if is_dir:
            # Each estimator mode reads the files of one dataset split.
            split_patterns = {
                tf.estimator.ModeKeys.TRAIN: 'train-*.tfr',
                tf.estimator.ModeKeys.EVAL: 'valid-*.tfr',
                tf.estimator.ModeKeys.PREDICT: 'test-*.tfr',
            }
            if mode not in split_patterns:
                raise ValueError('unrecognized tfrecord mode. Aborting.')
            data_source = path.join(loc, split_patterns[mode])
        else:
            data_source = loc

        # Previously the directory branch used job['seq_depth'] (KeyError when
        # absent) while the file branch used job.get('seq_depth', 4); both now
        # share the defaulting lookup so the two modes behave the same.
        return tfrecord_batcher.tfrecord_dataset(data_source,
                                                 job['batch_size'],
                                                 job['seq_length'],
                                                 job.get('seq_depth', 4),
                                                 job['target_length'],
                                                 job['num_targets'],
                                                 mode=mode,
                                                 repeat=False)