def test_inputs_using_generic_text_dataset_preprocess_fn(self):
    """Batches gin-configured SQuAD eval data evenly across devices.

    The SentencePiece model path and a T5 SQuAD text preprocessor are bound
    via gin. Data streams are then built with
    ``tf_inputs.generic_text_dataset_preprocess_fn`` as the bare preprocessing
    step. The test checks that the first eval batch's leading dimension is a
    multiple of the device count.
    """
    gin.bind_parameter('generic_text_dataset_preprocess_fn.spm_path',
                       _spm_path())

    def squad_text_fn(ds, training):
        # `training` is part of the expected callback signature but unused.
        return t5_processors.squad(ds)

    gin.bind_parameter(
        'generic_text_dataset_preprocess_fn.text_preprocess_fns',
        [squad_text_fn])

    # Just make sure this doesn't throw.
    def make_streams():
        return tf_inputs.data_streams(
            'squad',
            data_dir=_TESTDATA,
            input_name='inputs',
            target_name='targets',
            bare_preprocess_fn=tf_inputs.generic_text_dataset_preprocess_fn,
            shuffle_buffer_size=1)

    num_devices = 3
    # NOTE(review): buckets presumably is (length_boundaries, batch_sizes);
    # a single boundary of 513 is paired with two batch sizes of num_devices.
    batcher = inputs.batcher(
        data_streams=make_streams,
        max_eval_length=512,
        buckets=([513], [num_devices, num_devices]))

    batch_inputs, batch_targets, _ = next(batcher.eval_stream(num_devices))

    # We can only assert that the batch dim gets divided by num_devices.
    for batch_array in (batch_inputs, batch_targets):
        self.assertEqual(batch_array.shape[0] % num_devices, 0)
def test_squad(self):
    """Checks the output of ``prep.squad`` on a single hand-built example.

    The expected output shows punctuation separated by spaces and a
    trailing space on each text field. ``inputs`` combines the question and
    context, and ``targets`` is the first answer.
    """
    raw_example = {
        'id': 'testid',
        'context': 'Some context.',
        'question': 'A question?',
        'answers': {
            'text': ['The answer.', 'Another answer.'],
        },
    }
    expected_example = {
        'id': 'testid',
        'inputs': 'question: A question ? context: Some context . ',
        'targets': 'The answer . ',
        'context': 'Some context . ',
        'question': 'A question ? ',
        'answers': ['The answer . ', 'Another answer . '],
    }

    source_ds = tf.data.Dataset.from_tensors(raw_example)
    processed_ds = prep.squad(source_ds)
    assert_dataset(processed_ds, expected_example)
def test_generic_text_dataset_preprocess_fn(self):
    """Checks that preprocessing adds int64 ``inputs``/``targets`` features.

    A raw SQuAD example must not yet contain ``inputs`` or ``targets``. After
    ``tf_inputs.generic_text_dataset_preprocess_fn`` runs the T5 SQuAD text
    preprocessor, both features must be present with dtype ``np.int64``.
    """
    feature_keys = ('inputs', 'targets')

    raw_ds = _load_dataset('squad')
    (raw_example,) = tfds.as_numpy(raw_ds.take(1))
    for key in feature_keys:
        self.assertNotIn(key, raw_example)

    processed_ds = tf_inputs.generic_text_dataset_preprocess_fn(
        raw_ds,
        spm_path=_spm_path(),
        text_preprocess_fns=[lambda ds, training: t5_processors.squad(ds)],
        copy_plaintext=True,
        debug_print_examples=True,
        debug_print_examples_rate=1.0)
    (processed_example,) = tfds.as_numpy(processed_ds.take(1))

    # Presence is checked for every key before any dtype check, as before.
    for key in feature_keys:
        self.assertIn(key, processed_example)
    for key in feature_keys:
        self.assertEqual(processed_example[key].dtype, np.int64)