def data_streams():
    """Return ``tf_inputs.data_streams`` output for the 'squad' dataset.

    The dataset is read from ``_TESTDATA``, with 'inputs' as the input
    name and 'targets' as the target name, and
    ``tf_inputs.generic_text_dataset_preprocess_fn`` is passed as the
    bare preprocessing function.
    """
    make_streams = tf_inputs.data_streams
    return make_streams(
        'squad',
        data_dir=_TESTDATA,
        input_name='inputs',
        target_name='targets',
        bare_preprocess_fn=tf_inputs.generic_text_dataset_preprocess_fn,
    )
def _mnist_dataset():
    """Build the batched 'mnist' data streams.

    Calls ``tf_inputs.data_streams('mnist')`` and hands the result to
    ``inputs.batcher`` with fixed (non-variable) shapes and a batch size
    of 256 for both per-device training batches and evaluation batches.

    NOTE(review): the earlier docstring said the data set is cached;
    nothing in this body caches, so that presumably comes from a
    decorator or the callee -- confirm.
    """
    batch_size = 256
    mnist_streams = tf_inputs.data_streams('mnist')
    batcher_options = dict(
        variable_shapes=False,
        batch_size_per_device=batch_size,
        eval_batch_size=batch_size,
    )
    return inputs.batcher(mnist_streams, **batcher_options)
def test_c4(self):
    """Smoke test: build 'c4' data streams with ``c4_preprocess``.

    Binds the ``c4_preprocess`` gin parameters (max target length,
    tokenization and SentencePiece model path) and then calls
    ``tf_inputs.data_streams``. The test passes as long as nothing raises;
    the returned value is not inspected.
    """
    constant_bindings = (
        ('c4_preprocess.max_target_length', 2048),
        ('c4_preprocess.tokenization', 'spc'),
    )
    for binding_name, binding_value in constant_bindings:
        gin.bind_parameter(binding_name, binding_value)
    # Bound last, as before, so _spm_path() runs after the constant bindings.
    gin.bind_parameter('c4_preprocess.spm_path', _spm_path())

    # Only checking that stream construction does not throw.
    tf_inputs.data_streams(
        'c4',
        data_dir=_TESTDATA,
        input_name='targets',
        target_name='text',
        preprocess_fn=tf_inputs.c4_preprocess,
    )
def test_c4_pretrain(self):
    """Smoke test: build 'c4' data streams with ``c4_bare_preprocess_fn``.

    Applies ``_t5_gin_config()``, binds the SentencePiece model path for
    ``c4_bare_preprocess_fn`` and a set of ``batcher`` gin parameters, then
    calls ``tf_inputs.data_streams``. The test passes as long as nothing
    raises; the returned value is not inspected.
    """
    _t5_gin_config()
    gin.bind_parameter('c4_bare_preprocess_fn.spm_path', _spm_path())

    # Constant batcher settings, bound in the same order as before.
    batcher_bindings = (
        ('batcher.batch_size_per_device', 8),
        ('batcher.eval_batch_size', 8),
        ('batcher.max_eval_length', 50),
        ('batcher.buckets', ([51], [8, 1])),
    )
    for binding_name, binding_value in batcher_bindings:
        gin.bind_parameter(binding_name, binding_value)

    # Only checking that stream construction does not throw.
    tf_inputs.data_streams(
        'c4',
        data_dir=_TESTDATA,
        input_name='inputs',
        target_name='targets',
        bare_preprocess_fn=tf_inputs.c4_bare_preprocess_fn,
    )
def dataset_from_trax_tfds(dataset_name='mnist', variable_shapes=False,
                           tfds_dir=None, batch_size_per_device=256,
                           eval_batch_size=256, **kwargs):
    """Build batched data streams for a dataset via ``tf_inputs``.

    Args:
        dataset_name: Name passed to ``tf_inputs.data_streams``.
        variable_shapes: Forwarded to ``inputs.batcher``.
        tfds_dir: Forwarded to ``tf_inputs.data_streams`` as ``data_dir``.
        batch_size_per_device: Forwarded to ``inputs.batcher``.
        eval_batch_size: Forwarded to ``inputs.batcher``.
        **kwargs: Extra arguments. They are never used; any whose value is
            not ``None`` are reported in a warning so callers notice that
            they were ignored.

    Returns:
        Whatever ``inputs.batcher`` returns for the requested streams.
    """
    # None-valued extras are treated as "not supplied" and dropped silently.
    ignored = {k: v for k, v in kwargs.items() if v is not None}
    if ignored:
        # Logger.warn is a deprecated alias; warning() is the supported name.
        # The message is pre-formatted so the emitted text stays unchanged.
        logger.warning(
            'dataset_from_trax_tfds: ignoring arguments {}'.format(ignored))

    streams = tf_inputs.data_streams(dataset_name, data_dir=tfds_dir)
    return inputs.batcher(streams,
                          variable_shapes=variable_shapes,
                          batch_size_per_device=batch_size_per_device,
                          eval_batch_size=eval_batch_size)