def test_optional_features(self):
    def _dummy_preprocessor(output):
      return lambda _: tf.data.Dataset.from_tensors(output)

    default_vocab = test_utils.sentencepiece_vocab()
    features = {
        "inputs":
            dataset_providers.Feature(vocabulary=default_vocab, required=False),
        "targets":
            dataset_providers.Feature(vocabulary=default_vocab, required=True),
    }

    test_utils.add_task(
        "text_missing_optional_feature",
        test_utils.get_fake_dataset,
        output_features=features,
        text_preprocessor=_dummy_preprocessor({"targets": "a"}))
    TaskRegistry.get_dataset(
        "text_missing_optional_feature", {"targets": 13},
        "train", use_cached=False)

    test_utils.add_task(
        "text_missing_required_feature",
        test_utils.get_fake_dataset,
        output_features=features,
        text_preprocessor=_dummy_preprocessor({"inputs": "a"}))
    with self.assertRaisesRegex(
        ValueError,
        "Task dataset is missing expected output feature after preprocessing: "
        "targets"):
      TaskRegistry.get_dataset(
          "text_missing_required_feature", {"inputs": 13},
          "train", use_cached=False)
 def test_vocab(self):
     vocab = test_utils.sentencepiece_vocab()
     self.assertEqual(26, vocab.vocab_size)
     self.assertSequenceEqual(_TEST_TOKENS, vocab.encode(_TEST_STRING))
     self.assertEqual(_TEST_STRING,
                      tf.compat.as_bytes(vocab.decode(_TEST_TOKENS)))
     self.assertSequenceEqual(_TEST_TOKENS,
                              tuple(vocab.encode_tf(_TEST_STRING).numpy()))
     self.assertEqual(_TEST_STRING, vocab.decode_tf(_TEST_TOKENS).numpy())
Code example #3
 def test_vocab(self):
     vocab = test_utils.sentencepiece_vocab()
     self.assertEqual(26, vocab.vocab_size)
     self.assertSequenceEqual(self.TEST_TOKENS,
                              vocab.encode(self.TEST_STRING))
     self.assertEqual(self.TEST_STRING, vocab.decode(self.TEST_TOKENS))
     self.assertSequenceEqual(
         self.TEST_TOKENS, tuple(vocab.encode_tf(self.TEST_STRING).numpy()))
     self.assertEqual(self.TEST_STRING, _decode_tf(vocab, self.TEST_TOKENS))
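The _decode_tf helper used in this example is not defined in the snippet. A minimal sketch of what such a helper presumably does, assuming the vocabulary's decode_tf returns a scalar tf.string tensor (an illustration, not the library's own code):

def _decode_tf(vocab, tokens):
  # Run the graph-mode decoder on the token IDs and convert the resulting
  # tf.string tensor back to a Python str for comparison with TEST_STRING.
  return vocab.decode_tf(
      tf.constant(tokens, dtype=tf.int32)).numpy().decode("utf-8")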
Code example #4
 def test_no_eos(self):
     default_vocab = test_utils.sentencepiece_vocab()
     features = {
         "inputs": utils.Feature(add_eos=True, vocabulary=default_vocab),
         "targets": utils.Feature(add_eos=False, vocabulary=default_vocab),
     }
     test_utils.add_task("task_no_eos",
                         test_utils.get_fake_dataset,
                         output_features=features)
     fn_task = TaskRegistry.get("task_no_eos")
     test_utils.verify_task_matches_fake_datasets(fn_task, use_cached=False)
Code example #5
    def test_triviaqa_truncate_text(self):

        vocab = test_utils.sentencepiece_vocab()

        def tokenize_and_prepare_dataset(inputs, targets):
            tokenized_inputs = vocab.encode(inputs)
            tokenized_targets = vocab.encode(targets)

            dataset = tf.data.Dataset.from_tensors({
                'inputs':
                tokenized_inputs,
                'targets':
                tokenized_targets,
            })

            return dataset, tokenized_targets

        inputs = 'This is a very very long string which must contain the answer.'
        targets = 'long string'

        og_dataset, tokenized_targets = tokenize_and_prepare_dataset(
            inputs, targets)

        for _ in range(0, 10):
            dataset = prep.trivia_qa_truncate_inputs(
                og_dataset,
                output_features=None,
                sequence_length={'inputs': 20})

            for data in test_utils.dataset_as_text(dataset):
                self.assertLen(data['inputs'], 20)
                self.assertContainsSubset(tokenized_targets, data['inputs'])

        # Dummy input which exists in the vocab to be able to compare strings after
        # decoding.
        inputs = 'w h d n r t v'
        targets = 'h d'

        og_dataset, _ = tokenize_and_prepare_dataset(inputs, targets)

        for _ in range(0, 5):
            dataset = prep.trivia_qa_truncate_inputs(
                og_dataset,
                output_features=None,
                sequence_length={'inputs': 5})

            for data in test_utils.dataset_as_text(dataset):
                self.assertLen(data['inputs'], 5)
                truncated_inputs = vocab.decode(data['inputs'].tolist())
                new_targets = vocab.decode(data['targets'].tolist())
                self.assertRegex(truncated_inputs, '.*' + targets + '.*')
                self.assertEqual(targets, new_targets)
Code example #6
 def test_dtype(self):
     default_vocab = test_utils.sentencepiece_vocab()
     features = {
         "inputs":
         # defaults to int32
         dataset_providers.Feature(vocabulary=default_vocab),
         "targets":
         dataset_providers.Feature(dtype=tf.int64,
                                   vocabulary=default_vocab),
     }
     test_utils.add_task("task_dtypes",
                         test_utils.get_fake_dataset,
                         output_features=features)
     dtype_task = TaskRegistry.get("task_dtypes")
     test_utils.verify_task_matches_fake_datasets(dtype_task,
                                                  use_cached=False)
Code example #7
def register_dummy_task(
    task_name: str,
    dataset_fn: Callable[[str, str], tf.data.Dataset],
    output_feature_names: Sequence[str] = ("inputs", "targets")) -> None:
  """Register a dummy task for GetDatasetTest."""
  dataset_providers.TaskRegistry.add(
      task_name,
      dataset_providers.TaskV3,
      source=dataset_providers.FunctionSource(
          dataset_fn=dataset_fn, splits=["train", "validation"]),
      preprocessors=[dataset_providers.CacheDatasetPlaceholder()],
      output_features={
          feat: dataset_providers.Feature(test_utils.sentencepiece_vocab())
          for feat in output_feature_names
      },
      metric_fns=[])
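A brief usage sketch for the helper above; the task name and fake dataset function below are illustrative placeholders rather than part of the original test:

def _fake_dataset_fn(split, shuffle_files):
  # Stand-in source matching the (split, shuffle_files) signature expected
  # by FunctionSource; returns a single fake example.
  del split, shuffle_files
  return tf.data.Dataset.from_tensors({"inputs": "e1 i1", "targets": "t1"})

register_dummy_task("dummy_for_get_dataset", dataset_fn=_fake_dataset_fn)
task = dataset_providers.TaskRegistry.get("dummy_for_get_dataset")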
Code example #8
    def test_v3_value_errors(self):
        dataset_fn = (lambda split, shuffle_files: tf.data.Dataset.
                      from_tensors(["test"]))
        output_features = {
            "inputs":
            dataset_providers.Feature(test_utils.sentencepiece_vocab())
        }

        with self.assertRaisesRegex(
                ValueError,
                "`CacheDatasetPlaceholder` can appear at most once in the "
                "preprocessing pipeline. Found 2 in 'multiple_cache_placeholders'."
        ):
            dataset_providers.TaskV3(
                "multiple_cache_placeholders",
                source=dataset_providers.FunctionSource(
                    dataset_fn=dataset_fn, splits=["train", "validation"]),
                preprocessors=[
                    test_utils.test_text_preprocessor, preprocessors.tokenize,
                    dataset_providers.CacheDatasetPlaceholder(),
                    test_utils.test_token_preprocessor,
                    dataset_providers.CacheDatasetPlaceholder()
                ],
                output_features=output_features,
                metric_fns=[])

        with self.assertRaisesRegex(
                ValueError,
                "'test_token_preprocessor' has a `sequence_length` argument but occurs "
                "before `CacheDatasetPlaceholder` in 'sequence_length_pre_cache'. This "
                "is not allowed since the sequence length is specified at run time."
        ):
            dataset_providers.TaskV3(
                "sequence_length_pre_cache",
                dataset_providers.FunctionSource(
                    dataset_fn=dataset_fn,
                    splits=["train"],
                ),
                preprocessors=[
                    test_utils.test_text_preprocessor, preprocessors.tokenize,
                    test_utils.test_token_preprocessor,
                    dataset_providers.CacheDatasetPlaceholder()
                ],
                output_features=output_features,
                metric_fns=[])
Code example #9
 def test_denoise_nested_decorators(self):
     """Test whether gin and utils.map_over_dataset decorators are compatible."""
     bindings = """
   preprocessors.unsupervised.preprocessors = [@preprocessors.denoise]
   preprocessors.denoise.noise_density = 0.15
   preprocessors.denoise.noise_mask_fn = @preprocessors.iid_noise_mask
   preprocessors.denoise.inputs_fn = @noise_token_to_sentinel
 """
     gin.parse_config(bindings)
     og_dataset = tf.data.Dataset.from_tensor_slices({'targets': [1, 2, 3]})
     output_features = {
         'targets': Feature(test_utils.sentencepiece_vocab())
     }
     # Test denoise function when it is used as a gin-configurable of another
     # gin-configurable, prep.unsupervised.
     dataset = prep.unsupervised(og_dataset,
                                 output_features=output_features)
     self.assertIsInstance(dataset, tf.data.Dataset)
Code example #10
  def test_denoise(self):
    tf.set_random_seed(55)

    vocab = test_utils.sentencepiece_vocab()
    target_tokens = vocab.encode('The quick brown fox.')

    # This is what it encodes to.
    self.assertEqual(
        target_tokens,
        [3, 2, 20, 4, 3, 2, 8, 13, 2, 3, 2, 23, 7, 19, 22, 3, 2, 7, 2])

    og_dataset = tf.data.Dataset.from_tensor_slices({
        'targets': [target_tokens],
    })

    output_features = {
        'targets': utils.Feature(vocab),
    }

    # These are the parameters of denoise in the operative config of 'base'.
    # Except noise_density, bumped up from 0.15 to 0.3 in order to demonstrate
    # multiple corrupted spans.
    denoised_dataset = prep.denoise(
        og_dataset,
        output_features,
        noise_density=0.3,
        noise_mask_fn=prep.random_spans_noise_mask,
        inputs_fn=prep.noise_span_to_unique_sentinel,
        targets_fn=prep.nonnoise_span_to_unique_sentinel)

    # Two spans corrupted, [2] and [22, 3, 2, 7, 2], replaced by unique
    # sentinels 25 and 24 respectively.
    assert_dataset(denoised_dataset, [
        {
            'inputs': [
                3, 25, 20, 4, 3, 2, 8, 13, 2, 3, 2, 23, 7, 19, 24
            ],
            'targets': [
                25, 2, 24, 22, 3, 2, 7, 2
            ],
        },
    ])
Code example #11
 def test_prefix_lm(self):
     vocab = test_utils.sentencepiece_vocab()
     inp = list(range(1, 101))
     og_dataset = tf.data.Dataset.from_tensor_slices({'targets': [inp]})
     og_dataset = og_dataset.repeat(100)
     output_features = {'targets': Feature(vocab)}
     output_dataset = prep.prefix_lm(
         og_dataset,
         {
             'inputs': 100,
             'targets': 100
         },
         output_features,
     )
     input_lengths = set()
     for ex in output_dataset.as_numpy_iterator():
         self.assertListEqual(
             ex['inputs'].tolist() + ex['targets'].tolist(), inp)
         input_lengths.add(len(ex['inputs']))
     self.assertGreater(len(input_lengths), 1)
Code example #12
 def test_not_equal(self):
     vocab1 = test_utils.sentencepiece_vocab()
     vocab2 = test_utils.sentencepiece_vocab(10)
     self.assertNotEqual(vocab1, vocab2)
Code example #13
 def test_extra_ids(self):
     vocab = test_utils.sentencepiece_vocab(extra_ids=10)
     self.assertEqual(36, vocab.vocab_size)
     self.assertEqual("v", vocab.decode([25]))
     self.assertEqual(_UNK_STRING, tf.compat.as_bytes(vocab.decode([35])))
     self.assertEqual(_UNK_STRING, vocab.decode_tf([35]).numpy())
Code example #14
 def test_extra_ids(self):
     vocab = test_utils.sentencepiece_vocab(extra_ids=10)
     self.assertEqual(36, vocab.vocab_size)
     self.assertEqual("v", vocab.decode([25]))
     self.assertEqual(self.UNK_STRING, vocab.decode([35]))
     self.assertEqual(self.UNK_STRING, _decode_tf(vocab, [35]))
Code example #15
 def test_get_sentencepiece_model_path(self):
   self.assertEqual(
       test_utils.sentencepiece_vocab().sentencepiece_model_file,
       mesh_transformer.get_sentencepiece_model_path("cached_mixture")
   )