Example #1
  def test_inputs_using_generic_text_dataset_preprocess_fn(self):
    # Configure generic_text_dataset_preprocess_fn via gin: point it at the
    # test SentencePiece model and use the T5 SQuAD text preprocessor.
    gin.bind_parameter('generic_text_dataset_preprocess_fn.spm_path',
                       _spm_path())
    gin.bind_parameter('generic_text_dataset_preprocess_fn.text_preprocess_fns',
                       [lambda ds, training: t5_processors.squad(ds)])

    # Just make sure this doesn't throw.
    def data_streams():
      return tf_inputs.data_streams(
          'squad',
          data_dir=_TESTDATA,
          input_name='inputs',
          target_name='targets',
          bare_preprocess_fn=tf_inputs.generic_text_dataset_preprocess_fn,
          shuffle_buffer_size=1)

    n_devices = 3

    # Single bucket boundary at 513; both length buckets batch n_devices
    # examples at a time.
    squad_inputs = inputs.batcher(
        data_streams=data_streams,
        max_eval_length=512,
        buckets=([513], [n_devices, n_devices]))

    eval_stream = squad_inputs.eval_stream(n_devices)
    inps, tgts, _ = next(eval_stream)

    # We can only assert that the batch dim gets divided by n_devices.
    self.assertEqual(inps.shape[0] % n_devices, 0)
    self.assertEqual(tgts.shape[0] % n_devices, 0)
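For completeness, the training stream from the same batcher can be consumed in the same way. A minimal sketch, reusing squad_inputs and n_devices from the example above and assuming the train stream yields the same (inputs, targets, weights) triple as the eval stream:

# Sketch: pull one training batch; the bucket spec is shared, so the batch
# dimension should again be a multiple of n_devices.
train_stream = squad_inputs.train_stream(n_devices)
train_inps, train_tgts, _ = next(train_stream)
assert train_inps.shape[0] % n_devices == 0
assert train_tgts.shape[0] % n_devices == 0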
Example #2
    def test_squad(self):
        og_dataset = tf.data.Dataset.from_tensors({
            'id': 'testid',
            'context': 'Some context.',
            'question': 'A question?',
            'answers': {
                'text': ['The answer.', 'Another answer.'],
            }
        })

        # prep.squad formats 'inputs' as "question: ... context: ..." and uses
        # the first answer as 'targets', tokenizing punctuation along the way.
        dataset = prep.squad(og_dataset)
        assert_dataset(
            dataset, {
                'id': 'testid',
                'inputs': 'question: A question ? context: Some context . ',
                'targets': 'The answer . ',
                'context': 'Some context . ',
                'question': 'A question ? ',
                'answers': ['The answer . ', 'Another answer . '],
            })
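Outside the assert_dataset test helper, the preprocessed dataset can also be inspected eagerly. A minimal sketch, assuming TF 2.x eager execution and the og_dataset defined above:

# Iterate the preprocessed dataset and print the rewritten text fields.
for ex in prep.squad(og_dataset).as_numpy_iterator():
    print(ex['inputs'])   # b'question: A question ? context: Some context . '
    print(ex['targets'])  # b'The answer . '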
Example #3
  def test_generic_text_dataset_preprocess_fn(self):
    dataset = _load_dataset('squad')

    example, = tfds.as_numpy(dataset.take(1))

    self.assertNotIn('inputs', example)
    self.assertNotIn('targets', example)

    # Run the T5 SQuAD text preprocessor, then tokenize the resulting text
    # fields with the SentencePiece model; the debug flags print every
    # processed example.
    proc_dataset = tf_inputs.generic_text_dataset_preprocess_fn(
        dataset, spm_path=_spm_path(),
        text_preprocess_fns=[lambda ds, training: t5_processors.squad(ds)],
        copy_plaintext=True,
        debug_print_examples=True,
        debug_print_examples_rate=1.0)

    proc_example, = tfds.as_numpy(proc_dataset.take(1))

    self.assertIn('inputs', proc_example)
    self.assertIn('targets', proc_example)

    self.assertEqual(proc_example['inputs'].dtype, np.int64)
    self.assertEqual(proc_example['targets'].dtype, np.int64)
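As an additional sanity check, the token IDs can be decoded back to text with the same vocabulary. A minimal sketch, assuming the sentencepiece package is installed and that _spm_path() points at a valid SentencePiece .model file:

import sentencepiece as spm

# Decode the tokenized 'inputs' back to plaintext using the model that
# generic_text_dataset_preprocess_fn used for encoding.
sp = spm.SentencePieceProcessor()
sp.Load(_spm_path())
print(sp.DecodeIds(proc_example['inputs'].tolist()))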