def _mnist_dataset():
  """Builds the batched standard MNIST input pipeline (cached by TFDS)."""
  # Fixed-shape batches of 256 per device, for both training and eval.
  return inputs.batcher(
      tf_inputs.data_streams('mnist'),
      variable_shapes=False,
      batch_size_per_device=256,
      eval_batch_size=256)
def test_inputs_using_generic_text_dataset_preprocess_fn(self):
  """Smoke test: SQuAD eval batches come out with a device-divisible batch dim."""
  gin.bind_parameter('generic_text_dataset_preprocess_fn.spm_path',
                     _spm_path())
  gin.bind_parameter(
      'generic_text_dataset_preprocess_fn.text_preprocess_fns',
      [lambda ds, training: t5_processors.squad(ds)])

  # The point is simply that building and pulling from the pipeline
  # does not raise.
  def squad_streams():
    return tf_inputs.data_streams(
        'squad',
        data_dir=_TESTDATA,
        input_name='inputs',
        target_name='targets',
        bare_preprocess_fn=tf_inputs.generic_text_dataset_preprocess_fn,
        shuffle_buffer_size=1)

  n_devices = 3
  squad_inputs = inputs.batcher(
      data_streams=squad_streams,
      max_eval_length=512,
      buckets=([513], [n_devices, n_devices]))

  batch_inputs, batch_targets, _ = next(squad_inputs.eval_stream(n_devices))
  # Exact batch contents depend on the data; we can only assert that the
  # leading (batch) dimension is divisible across the devices.
  self.assertEqual(batch_inputs.shape[0] % n_devices, 0)
  self.assertEqual(batch_targets.shape[0] % n_devices, 0)
def _mnist_brightness_dataset():
  """Batched MNIST pipeline whose target is each image's mean brightness."""

  def _to_brightness_stream(stream):
    # Wrap an (image, label) stream, discarding the label and replacing the
    # target with the image's mean pixel value after a /255 rescale
    # (presumably uint8 pixels in [0, 255] — the rescale suggests so).
    # The [None] gives the scalar target a trailing axis of size 1.
    def brightness_stream():
      for image, _ in stream():
        yield image, (image / 255).mean()[None]
    return brightness_stream

  streams = tuple(
      _to_brightness_stream(s) for s in tf_inputs.data_streams('mnist'))
  return inputs.batcher(streams, variable_shapes=False,
                        batch_size_per_device=256, eval_batch_size=256)