def _mnist_dataset():
    """Returns batched inputs built from the standard MNIST data streams.

    The streams come from `tf_inputs.data_streams('mnist')` and are handed to
    `inputs.batcher` with fixed shapes and a batch size of 256 for both the
    per-device training batches and the eval batches.
    """
    mnist_streams = tf_inputs.data_streams('mnist')
    batching_options = dict(
        variable_shapes=False,
        batch_size_per_device=256,
        eval_batch_size=256,
    )
    return inputs.batcher(mnist_streams, **batching_options)
def data_streams():
    """Returns the 'squad' data streams read from the `_TESTDATA` directory.

    Examples are preprocessed with
    `tf_inputs.generic_text_dataset_preprocess_fn`, exposed under the
    'inputs' / 'targets' names, and shuffled with a buffer of size 1.
    """
    stream_options = dict(
        data_dir=_TESTDATA,
        input_name='inputs',
        target_name='targets',
        bare_preprocess_fn=tf_inputs.generic_text_dataset_preprocess_fn,
        shuffle_buffer_size=1,
    )
    return tf_inputs.data_streams('squad', **stream_options)
def test_c4(self):
    """Smoke test: building 'c4' streams with `c4_preprocess` must not raise."""
    # Constant gin bindings for the c4 preprocessing function.
    constant_bindings = (
        ('c4_preprocess.max_target_length', 2048),
        ('c4_preprocess.tokenization', 'spc'),
    )
    for binding_name, binding_value in constant_bindings:
        gin.bind_parameter(binding_name, binding_value)
    # Resolved last, in the same position as the other bindings expect it.
    gin.bind_parameter('c4_preprocess.spm_path', _spm_path())

    # Only construction is exercised; the returned streams are not inspected.
    tf_inputs.data_streams(
        'c4',
        data_dir=_TESTDATA,
        input_name='targets',
        target_name='text',
        preprocess_fn=tf_inputs.c4_preprocess,
    )
def _mnist_brightness_dataset():
    """Returns batched MNIST inputs whose targets are mean image brightness.

    Every stream from `tf_inputs.data_streams('mnist')` is wrapped so that the
    second element of each example is discarded and replaced by
    `(image / 255).mean()[None]`. The wrapped streams are then batched with
    fixed shapes and a batch size of 256 for both training and eval.
    """

    def with_brightness_target(stream):
        """Wraps a stream factory so each example is (image, brightness)."""

        def brightness_stream():
            for example in stream():
                image, _unused_target = example
                # NOTE(review): presumably `image` is a NumPy-like array, in
                # which case `[None]` adds a leading axis to the scalar mean.
                yield (image, (image / 255).mean()[None])

        return brightness_stream

    brightness_streams = tuple(
        with_brightness_target(raw_stream)
        for raw_stream in tf_inputs.data_streams('mnist')
    )
    return inputs.batcher(
        brightness_streams,
        variable_shapes=False,
        batch_size_per_device=256,
        eval_batch_size=256,
    )
def test_c4_pretrain(self):
    """Smoke test: 'c4' streams with `c4_bare_preprocess_fn` must not raise."""
    _t5_gin_config()
    gin.bind_parameter('c4_bare_preprocess_fn.spm_path', _spm_path())

    # Batcher settings, applied in this order.
    batcher_bindings = (
        ('batcher.batch_size_per_device', 8),
        ('batcher.eval_batch_size', 8),
        ('batcher.max_eval_length', 50),
        ('batcher.buckets', ([51], [8, 1])),
    )
    for binding_name, binding_value in batcher_bindings:
        gin.bind_parameter(binding_name, binding_value)

    # Only construction is exercised; the returned streams are not inspected.
    tf_inputs.data_streams(
        'c4',
        data_dir=_TESTDATA,
        input_name='inputs',
        target_name='targets',
        bare_preprocess_fn=tf_inputs.c4_bare_preprocess_fn,
    )