Example #1
  def test_extra_ids(self):
    vocab = test_utils.sentencepiece_vocab(extra_ids=10)
    self.assertEqual(36, vocab.vocab_size)
    self.assertEqual("v", vocab.decode([25]))
    # The 10 extra IDs sit above the 26-piece base vocab (IDs 26..35):
    # <extra_id_0> gets the highest ID (35) and <extra_id_9> the lowest (26).
    test_string = "<extra_id_0> <extra_id_1> v <extra_id_9>"
    test_tokens = (35, 34, 3, 25, 26)
    self.assertEqual(test_string, vocab.decode(test_tokens))
    self.assertEqual(test_string, _decode_tf(vocab, test_tokens))
    self.assertSequenceEqual(test_tokens, vocab.encode(test_string))
    self.assertSequenceEqual(
        test_tokens,
        tuple(vocab.encode_tf(test_string).numpy()))

  def test_no_eos(self):
    default_vocab = test_utils.sentencepiece_vocab()
    features = {
        "inputs":
            dataset_providers.Feature(add_eos=True, vocabulary=default_vocab),
        "targets":
            dataset_providers.Feature(add_eos=False, vocabulary=default_vocab),
    }
    self.add_task("task_no_eos",
                  self.function_source,
                  output_features=features)
    self.verify_task_matches_fake_datasets("task_no_eos", use_cached=False)
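
For reference, the extra-ID behaviour exercised in test_extra_ids can be reproduced outside the test harness with seqio.SentencePieceVocabulary directly. A minimal sketch, assuming a local SentencePiece model file (the path is hypothetical, standing in for the bundled test model behind test_utils.sentencepiece_vocab):

import seqio

# Hypothetical model path; any trained SentencePiece model will do.
vocab = seqio.SentencePieceVocabulary("/path/to/spm.model", extra_ids=10)

# vocab_size counts the base pieces plus the 10 appended sentinel IDs.
print(vocab.vocab_size)

# Sentinels round-trip through encode/decode as "<extra_id_K>" strings,
# which is exactly what the assertions above verify.
ids = vocab.encode("<extra_id_0> hello <extra_id_9>")
print(vocab.decode(ids))
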
Example #3
    def test_feature_validation(self):
        default_vocab = test_utils.sentencepiece_vocab()
        features = {
            "inputs":
            dataset_providers.Feature(vocabulary=default_vocab,
                                      required=False),
            "targets":
            dataset_providers.Feature(vocabulary=default_vocab, required=True),
        }

        def _materialize(output):
            task = dataset_providers.Task(
                "feature_validation_task",
                self.function_source,
                output_features=features,
                preprocessors=(
                    lambda _: tf.data.Dataset.from_tensors(output), ),
                metric_fns=[],
            )
            list(
                task.get_dataset(
                    {"inputs": 13, "targets": 13},
                    "train",
                    use_cached=False).as_numpy_iterator())

        # Missing optional feature: OK
        _materialize({"targets": [0]})

        # Missing required feature.
        with self.assertRaisesWithLiteralMatch(
                ValueError,
                "Task dataset is missing expected output feature after preprocessing: "
                "targets"):
            _materialize({"inputs": [0]})

        # Wrong type.
        with self.assertRaisesWithLiteralMatch(
                ValueError,
                "Task dataset has incorrect type for feature 'targets' after "
                "preprocessing: Got string, expected int32"):
            _materialize({"targets": ["wrong type"]})

        # Wrong rank.
        with self.assertRaisesWithLiteralMatch(
                ValueError,
                "Task dataset has incorrect rank for feature 'targets' after "
                "preprocessing: Got 0, expected 1"):
            _materialize({"targets": 0})
Example #4
    def test_value_errors(self):
        dataset_fn = (
            lambda split, shuffle_files: tf.data.Dataset.from_tensors(["test"]))
        output_features = {
            "inputs":
            dataset_providers.Feature(test_utils.sentencepiece_vocab())
        }

        with self.assertRaisesWithLiteralMatch(
                ValueError,
                "`CacheDatasetPlaceholder` can appear at most once in the "
                "preprocessing pipeline. Found 2 in 'multiple_cache_placeholders'."
        ):
            dataset_providers.Task(
                "multiple_cache_placeholders",
                source=dataset_providers.FunctionDataSource(
                    dataset_fn=dataset_fn, splits=["train", "validation"]),
                preprocessors=[
                    test_utils.test_text_preprocessor, preprocessors.tokenize,
                    dataset_providers.CacheDatasetPlaceholder(),
                    test_utils.test_token_preprocessor,
                    dataset_providers.CacheDatasetPlaceholder()
                ],
                output_features=output_features,
                metric_fns=[])

        with self.assertRaisesWithLiteralMatch(
                ValueError,
                "'test_token_preprocessor' has a `sequence_length` argument but occurs "
                "before `CacheDatasetPlaceholder` in 'sequence_length_pre_cache'. This "
                "is not allowed since the sequence length is specified at run time."
        ):
            dataset_providers.Task(
                "sequence_length_pre_cache",
                dataset_providers.FunctionDataSource(
                    dataset_fn=dataset_fn,
                    splits=["train"],
                ),
                preprocessors=[
                    test_utils.test_text_preprocessor, preprocessors.tokenize,
                    test_utils.test_token_preprocessor,
                    dataset_providers.CacheDatasetPlaceholder()
                ],
                output_features=output_features,
                metric_fns=[])

    def test_append_eos(self):
        og_dataset = tf.data.Dataset.from_tensors({
            'inputs': [1, 2, 3],
            'targets': [4, 5, 6, 7],
            'arrows': [8, 9, 10, 11],
            'bows': [12, 13],
        })
        vocab = test_utils.sentencepiece_vocab()
        output_features = {
            'inputs': Feature(vocab, add_eos=False),
            'targets': Feature(vocab, add_eos=True),
            'arrows': Feature(vocab, add_eos=True),
        }
        sequence_length = {'inputs': 4, 'targets': 3, 'arrows': 5, 'bows': 1}

        # Add eos only.
        assert_dataset(
            preprocessors.append_eos(og_dataset, output_features), {
                'inputs': [1, 2, 3],
                'targets': [4, 5, 6, 7, 1],
                'arrows': [8, 9, 10, 11, 1],
                'bows': [12, 13],
            })

        # Trim to sequence lengths.
        assert_dataset(
            preprocessors.append_eos_after_trim(
                og_dataset,
                output_features=output_features,
                sequence_length=sequence_length), {
                    'inputs': [1, 2, 3],
                    'targets': [4, 5, 1],
                    'arrows': [8, 9, 10, 11, 1],
                    'bows': [12, 13],
                })

        # Don't trim to sequence lengths.
        assert_dataset(
            preprocessors.append_eos_after_trim(
                og_dataset, output_features=output_features), {
                    'inputs': [1, 2, 3],
                    'targets': [4, 5, 6, 7, 1],
                    'arrows': [8, 9, 10, 11, 1],
                    'bows': [12, 13],
                })

def register_dummy_task(
    task_name: str,
    dataset_fn: Callable[[str, str], tf.data.Dataset],
    output_feature_names: Sequence[str] = ("inputs", "targets")) -> None:
  """Register a dummy task for GetDatasetTest."""
  dataset_providers.TaskRegistry.add(
      task_name,
      source=dataset_providers.FunctionDataSource(
          dataset_fn=dataset_fn, splits=["train", "validation"]),
      preprocessors=[
          dataset_providers.CacheDatasetPlaceholder(),
          preprocessors.append_eos_after_trim,
      ],
      output_features={
          feat: dataset_providers.Feature(test_utils.sentencepiece_vocab())
          for feat in output_feature_names
      },
      metric_fns=[])
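
A minimal usage sketch for this helper, assuming (as the GetDatasetTest pattern implies) a dataset_fn that already yields tokenized int32 features; the task name and token values below are made up for illustration:

def _fake_dataset(split, shuffle_files):
  # Tiny deterministic dataset; the values are arbitrary token IDs.
  del split, shuffle_files
  return tf.data.Dataset.from_tensors({
      "inputs": tf.constant([7, 8, 5, 6, 9, 4, 3], tf.int32),
      "targets": tf.constant([3, 9], tf.int32),
  })

register_dummy_task("illustrative_dummy_task", dataset_fn=_fake_dataset)

task = dataset_providers.TaskRegistry.get("illustrative_dummy_task")
ds = task.get_dataset(
    {"inputs": 13, "targets": 13}, "train", use_cached=False)
for ex in ds.as_numpy_iterator():
  print(ex)
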
Example #7
    def test_dtype(self):
        default_vocab = test_utils.sentencepiece_vocab()
        features = {
            "inputs":
            # defaults to int32
            dataset_providers.Feature(vocabulary=default_vocab),
            "targets":
            dataset_providers.Feature(dtype=tf.int64,
                                      vocabulary=default_vocab),
        }

        self.add_task(
            "task_dtypes",
            self.function_source,
            preprocessors=self.DEFAULT_PREPROCESSORS + (
                utils.map_over_dataset(
                    lambda x: {
                        k: tf.cast(v, tf.int64) if k == "targets" else v  # pylint:disable=g-long-lambda
                        for k, v in x.items()
                    }), ),
            output_features=features)
        self.verify_task_matches_fake_datasets("task_dtypes", use_cached=False)
Example #8
  def test_not_equal(self):
    vocab1 = test_utils.sentencepiece_vocab()
    # The positional argument is presumably extra_ids (cf. test_extra_ids
    # above), so the two vocabularies differ in size and compare unequal.
    vocab2 = test_utils.sentencepiece_vocab(10)
    self.assertNotEqual(vocab1, vocab2)
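
The test_utils.sentencepiece_vocab helper itself is not shown in these snippets. Judging by how it is called above, it presumably wraps a small bundled SentencePiece test model in a seqio.SentencePieceVocabulary, roughly along these lines (the model location is an assumption):

import os

import seqio

# Assumed location of the bundled test model used by the tests above.
_TEST_DATA_DIR = os.path.join(os.path.dirname(__file__), "test_data")


def sentencepiece_vocab(extra_ids=0):
  # Sketch only: build a vocabulary over the bundled test model, optionally
  # reserving `extra_ids` sentinel tokens (as exercised in test_extra_ids).
  return seqio.SentencePieceVocabulary(
      os.path.join(_TEST_DATA_DIR, "sentencepiece.model"), extra_ids=extra_ids)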