def test_inputs_using_generic_text_dataset_preprocess_fn(self):
  """Checks eval batches from generic_text_dataset_preprocess_fn on SQuAD.

  Only asserts that the eval batch dimension is divisible by the number
  of devices; exact contents are not checked.
  """
  gin.bind_parameter(
      'generic_text_dataset_preprocess_fn.spm_path', _spm_path())
  gin.bind_parameter(
      'generic_text_dataset_preprocess_fn.text_preprocess_fns',
      [lambda ds, training: t5_processors.squad(ds)])

  # Just make sure this doesn't throw.
  def make_streams():
    return tf_inputs.data_streams(
        'squad',
        data_dir=_TESTDATA,
        input_name='inputs',
        target_name='targets',
        bare_preprocess_fn=tf_inputs.generic_text_dataset_preprocess_fn,
        shuffle_buffer_size=1)

  num_devices = 3

  batched = inputs.batcher(
      data_streams=make_streams,
      max_eval_length=512,
      buckets=([513], [num_devices, num_devices]))

  stream = batched.eval_stream(num_devices)
  batch_in, batch_out, _ = next(stream)

  # We can only assert that the batch dim gets divided by num_devices.
  self.assertEqual(batch_in.shape[0] % num_devices, 0)
  self.assertEqual(batch_out.shape[0] % num_devices, 0)
def test_inputs_using_generic_text_dataset_preprocess_fn(self):
  """Checks train batches from generic_text_dataset_preprocess_fn on SQuAD.

  NOTE(review): this method appears to share its name with another test in
  this file; if both are in the same TestCase, the earlier one is shadowed
  and never runs — confirm and rename if so.
  """
  gin.bind_parameter(
      'generic_text_dataset_preprocess_fn.spm_path', _spm_path())
  gin.bind_parameter(
      'generic_text_dataset_preprocess_fn.copy_plaintext', True)
  gin.bind_parameter(
      'generic_text_dataset_preprocess_fn.text_preprocess_fn',
      t5_processors.squad)

  # Just make sure this doesn't throw.
  def make_streams():
    return tf_inputs.data_streams(
        'squad',
        data_dir=_TESTDATA,
        input_name='inputs',
        target_name='targets',
        bare_preprocess_fn=tf_inputs.generic_text_dataset_preprocess_fn)

  batched = inputs.batcher(
      data_streams=make_streams,
      batch_size_per_device=2,
      eval_batch_size=2,
      max_eval_length=50)

  num_devices = 3
  stream = batched.train_stream(num_devices)
  batch_in, batch_out = next(stream)

  # We can only assert that the batch dim gets divided by num_devices.
  self.assertEqual(batch_in.shape[0] % num_devices, 0)
  self.assertEqual(batch_out.shape[0] % num_devices, 0)
def _mnist_dataset():
  """Loads (and caches) the standard MNIST data set."""
  return inputs.batcher(
      tf_inputs.data_streams('mnist'),
      variable_shapes=False,
      batch_size_per_device=256,
      eval_batch_size=256)
def dataset_from_trax_tfds(dataset_name='mnist',
                           variable_shapes=False,
                           tfds_dir=None,
                           batch_size_per_device=256,
                           eval_batch_size=256,
                           **kwargs):
  """Builds a batched Trax input pipeline from a TFDS dataset.

  Args:
    dataset_name: TFDS dataset name passed to `tf_inputs.data_streams`.
    variable_shapes: whether batches may have variable shapes.
    tfds_dir: optional TFDS data directory (`data_dir`).
    batch_size_per_device: training batch size per device.
    eval_batch_size: evaluation batch size.
    **kwargs: ignored; any entries whose value is not None are reported
      with a warning.

  Returns:
    The result of `inputs.batcher` over the dataset's data streams.
  """
  # Don't rebind the `kwargs` parameter; keep the filtered view separate.
  ignored = {k: v for k, v in kwargs.items() if v is not None}
  if ignored:  # truthiness instead of len(); `warn` is a deprecated alias
    logger.warning('dataset_from_trax_tfds: ignoring arguments %s', ignored)
  streams = tf_inputs.data_streams(dataset_name, data_dir=tfds_dir)
  return inputs.batcher(streams,
                        variable_shapes=variable_shapes,
                        batch_size_per_device=batch_size_per_device,
                        eval_batch_size=eval_batch_size)