Example #1
 def data_streams():
     """Build SQuAD input/target data streams from the test-data directory.

     The generic text-dataset preprocessing function is supplied as the
     bare preprocessor, so raw examples are mapped before batching.
     """
     stream_kwargs = dict(
         data_dir=_TESTDATA,
         input_name='inputs',
         target_name='targets',
         bare_preprocess_fn=tf_inputs.generic_text_dataset_preprocess_fn,
     )
     return tf_inputs.data_streams('squad', **stream_kwargs)
Example #2
def _mnist_dataset():
    """Loads (and caches) the standard MNIST data set."""
    mnist_streams = tf_inputs.data_streams('mnist')
    # Fixed-shape batches of 256 for both training and evaluation.
    batched = inputs.batcher(
        mnist_streams,
        variable_shapes=False,
        batch_size_per_device=256,
        eval_batch_size=256,
    )
    return batched
Example #3
  def test_c4(self):
    """Smoke test: building C4 streams with SPC tokenization must not raise."""
    preprocess_bindings = {
        'c4_preprocess.max_target_length': 2048,
        'c4_preprocess.tokenization': 'spc',
        'c4_preprocess.spm_path': _spm_path(),
    }
    for param, value in preprocess_bindings.items():
      gin.bind_parameter(param, value)

    # Constructing the streams is the whole assertion: no exception = pass.
    _ = tf_inputs.data_streams(
        'c4',
        data_dir=_TESTDATA,
        input_name='targets',
        target_name='text',
        preprocess_fn=tf_inputs.c4_preprocess)
Example #4
  def test_c4_pretrain(self):
    """Smoke test: building C4 pretraining streams must not raise."""
    _t5_gin_config()

    gin.bind_parameter('c4_bare_preprocess_fn.spm_path', _spm_path())

    batcher_bindings = {
        'batcher.batch_size_per_device': 8,
        'batcher.eval_batch_size': 8,
        'batcher.max_eval_length': 50,
        'batcher.buckets': ([51], [8, 1]),
    }
    for param, value in batcher_bindings.items():
      gin.bind_parameter(param, value)

    # Constructing the streams is the whole assertion: no exception = pass.
    _ = tf_inputs.data_streams(
        'c4',
        data_dir=_TESTDATA,
        input_name='inputs',
        target_name='targets',
        bare_preprocess_fn=tf_inputs.c4_bare_preprocess_fn)
Example #5
def dataset_from_trax_tfds(dataset_name='mnist',
                           variable_shapes=False,
                           tfds_dir=None,
                           batch_size_per_device=256,
                           eval_batch_size=256,
                           **kwargs):
    """Load a TFDS dataset via trax and return batched input streams.

    Args:
      dataset_name: TFDS dataset identifier (e.g. 'mnist').
      variable_shapes: whether batches may have variable shapes.
      tfds_dir: optional data directory forwarded to the stream loader.
      batch_size_per_device: training batch size per device.
      eval_batch_size: batch size used during evaluation.
      **kwargs: unsupported extra arguments; any non-None entries are
        ignored with a warning.

    Returns:
      The result of `inputs.batcher` over the dataset's data streams.
    """
    ignored = {k: v for k, v in kwargs.items() if v is not None}
    if ignored:  # truthiness check — `len(...)` on a dict is redundant
        # `logger.warning`, not the deprecated `logger.warn` alias;
        # lazy %-formatting so the message is only built if emitted.
        logger.warning(
            'dataset_from_trax_tfds: ignoring arguments %s', ignored)
    streams = tf_inputs.data_streams(dataset_name, data_dir=tfds_dir)
    return inputs.batcher(streams,
                          variable_shapes=variable_shapes,
                          batch_size_per_device=batch_size_per_device,
                          eval_batch_size=eval_batch_size)