示例#1
0
def _mnist_dataset():
  """Builds batched MNIST streams for tests.

  Fetches the 'mnist' streams via `tf_inputs.data_streams` and hands them to
  `inputs.batcher` with fixed shapes and a batch size of 256 for both
  `batch_size_per_device` and `eval_batch_size`.

  NOTE(review): an earlier docstring said "Loads (and caches)"; this function
  itself does no caching. Any caching would come from the callees, so confirm
  there before relying on it.

  Returns:
    Whatever `inputs.batcher` returns for the MNIST streams.
  """
  mnist_streams = tf_inputs.data_streams('mnist')
  batch_size = 256
  return inputs.batcher(
      mnist_streams,
      variable_shapes=False,
      batch_size_per_device=batch_size,
      eval_batch_size=batch_size,
  )
示例#2
0
 def data_streams():
   """Returns `tf_inputs.data_streams` for the 'squad' test data.

   Reads from `_TESTDATA` with input_name='inputs' and target_name='targets'.
   It applies `tf_inputs.generic_text_dataset_preprocess_fn` as the bare
   preprocess function and uses a shuffle buffer of size 1.
   """
   # NOTE(review): a shuffle buffer of 1 presumably keeps example order
   # deterministic for the test; confirm against tf_inputs.
   stream_kwargs = dict(
       data_dir=_TESTDATA,
       input_name='inputs',
       target_name='targets',
       bare_preprocess_fn=tf_inputs.generic_text_dataset_preprocess_fn,
       shuffle_buffer_size=1,
   )
   return tf_inputs.data_streams('squad', **stream_kwargs)
示例#3
0
  def test_c4(self):
    """Smoke test: building 'c4' streams with c4_preprocess must not raise."""
    for name, value in (('c4_preprocess.max_target_length', 2048),
                        ('c4_preprocess.tokenization', 'spc')):
      gin.bind_parameter(name, value)
    gin.bind_parameter('c4_preprocess.spm_path', _spm_path())

    # Only checks that no exception is raised; the returned streams are unused.
    tf_inputs.data_streams(
        'c4',
        data_dir=_TESTDATA,
        input_name='targets',
        target_name='text',
        preprocess_fn=tf_inputs.c4_preprocess,
    )
示例#4
0
def _mnist_brightness_dataset():
  """Builds batched MNIST streams whose target is mean image brightness.

  Each 2-tuple example from the MNIST streams is replaced by
  (image, brightness). Here brightness is `(image / 255).mean()` indexed with
  `[None]`. Assuming `image` is a NumPy array, that yields a length-1 array.
  The original second element of each example is discarded.

  NOTE(review): an earlier docstring said "Loads (and caches)"; nothing in
  this function caches. Confirm whether the callees do.
  """
  def with_brightness_target(stream_fn):
    # Each stream is a zero-argument callable that returns an iterable. Wrap
    # the callable itself so every call starts a fresh pass.
    def brightness_stream():
      for image, _ in stream_fn():
        brightness = (image / 255).mean()[None]
        yield image, brightness
    return brightness_stream

  mnist_streams = tf_inputs.data_streams('mnist')
  streams = tuple(with_brightness_target(s) for s in mnist_streams)
  batch_size = 256
  return inputs.batcher(streams, variable_shapes=False,
                        batch_size_per_device=batch_size,
                        eval_batch_size=batch_size)
示例#5
0
  def test_c4_pretrain(self):
    """Smoke test: 'c4' streams with c4_bare_preprocess_fn build cleanly."""
    _t5_gin_config()
    gin.bind_parameter('c4_bare_preprocess_fn.spm_path', _spm_path())

    # Bound in insertion order, before the streams are built.
    batcher_bindings = {
        'batcher.batch_size_per_device': 8,
        'batcher.eval_batch_size': 8,
        'batcher.max_eval_length': 50,
        'batcher.buckets': ([51], [8, 1]),
    }
    for name, value in batcher_bindings.items():
      gin.bind_parameter(name, value)

    # Only checks that no exception is raised; the returned streams are unused.
    tf_inputs.data_streams(
        'c4',
        data_dir=_TESTDATA,
        input_name='inputs',
        target_name='targets',
        bare_preprocess_fn=tf_inputs.c4_bare_preprocess_fn,
    )