def test_get_dataset_cached(self):
    """Checks the cached task against the fake datasets, then again with a token preprocessor attached."""
    task = self.cached_task
    test_utils.verify_task_matches_fake_datasets(task, use_cached=True)

    # Attach a token preprocessor and verify once more; reading uncached so
    # the preprocessor actually runs over the examples.
    task._token_preprocessor = test_utils.test_token_preprocessor
    test_utils.verify_task_matches_fake_datasets(
        task, use_cached=False, token_preprocessed=True)
  def test_get_dataset_cached(self):
    """Verifies cached datasets both with and without token preprocessing."""
    # Cached task whose pipeline includes a token preprocessor.
    test_utils.verify_task_matches_fake_datasets(
        self.cached_task, use_cached=True, token_preprocessed=True)

    # Registered variant of the task that skips token preprocessing.
    no_prep_task = TaskRegistry.get("cached_task_no_token_prep")
    test_utils.verify_task_matches_fake_datasets(
        no_prep_task,
        use_cached=True,
        token_preprocessed=False)
  def test_tasks(self):
    """Registers two tasks in a mixture and checks each member and its rate."""
    for task_name in ("task1", "task2"):
      test_utils.add_task(task_name, test_utils.get_fake_dataset)
    MixtureRegistry.add("test_mix1", [("task1", 1), ("task2", 1)])
    mix = MixtureRegistry.get("test_mix1")
    self.assertEqual(len(mix.tasks), 2)

    # Every member task should round-trip the fake data and carry rate 1.
    for member in mix.tasks:
      test_utils.verify_task_matches_fake_datasets(member, use_cached=False)
      self.assertEqual(mix.get_rate(member), 1)
# Example #4 (snippet separator from the original scrape: "Beispiel #4" / "0")
 def test_no_eos(self):
     """A task whose targets feature disables EOS still matches the fake data."""
     # EOS is appended to inputs but deliberately left off the targets.
     feats = dict(
         inputs=utils.Feature(add_eos=True),
         targets=utils.Feature(add_eos=False),
     )
     test_utils.add_task("task_no_eos",
                         test_utils.get_fake_dataset,
                         output_features=feats)
     task = TaskRegistry.get("task_no_eos")
     test_utils.verify_task_matches_fake_datasets(task, use_cached=False)
# Example #5 (snippet separator from the original scrape: "Beispiel #5" / "0")
 def test_no_eos(self):
     """Per-feature EOS handling, with an explicit vocabulary on each feature."""
     vocab = test_utils.sentencepiece_vocab()
     # Inputs get an EOS token appended; targets deliberately do not.
     feats = dict(
         inputs=utils.Feature(add_eos=True, vocabulary=vocab),
         targets=utils.Feature(add_eos=False, vocabulary=vocab),
     )
     test_utils.add_task("task_no_eos",
                         test_utils.get_fake_dataset,
                         output_features=feats)
     task = TaskRegistry.get("task_no_eos")
     test_utils.verify_task_matches_fake_datasets(task, use_cached=False)
# Example #6 (snippet separator from the original scrape: "Beispiel #6" / "0")
 def test_sharding(self):
     """Sharded reads (0-2 shards) match the fake data, uncached and cached."""
     for num_shards in range(3):
         # Check the on-the-fly pipeline first, then the cached one, at each
         # shard count — same order as checking both per shard count.
         for cached in (False, True):
             test_utils.verify_task_matches_fake_datasets(
                 self.cached_task,
                 use_cached=cached,
                 num_shards=num_shards,
                 token_preprocessed=True)
  def test_get_dataset_onthefly(self):
    """Uncached datasets match the fakes, with and without token preprocessing."""
    task = self.uncached_task
    test_utils.verify_task_matches_fake_datasets(task, use_cached=False)

    # Attach a token preprocessor and re-verify.
    task._token_preprocessor = test_utils.test_token_preprocessor
    test_utils.verify_task_matches_fake_datasets(
        task, use_cached=False, token_preprocessed=True)

    # Swap in a loader that repeats the fake data so later reads see 20
    # examples instead of the small fixture.
    def repeated_load(split, shuffle_files=False):
      del shuffle_files  # Unused, to mimic TFDS API
      return test_utils.get_fake_dataset(split).repeat().take(20)
    test_utils.add_fake_tfds(
        utils.LazyTfdsLoader("fake:0.0.0")._replace(load=repeated_load))
  def test_get_dataset_onthefly(self):
    """Uncached dataset checks, then patches the TFDS mock to yield more examples."""
    test_utils.verify_task_matches_fake_datasets(
        self.uncached_task, use_cached=False, token_preprocessed=True)

    # Registered variant of the task without a token preprocessor.
    test_utils.verify_task_matches_fake_datasets(
        TaskRegistry.get("uncached_task_no_token_prep"),
        use_cached=False,
        token_preprocessed=False)

    # Replace the mocked loader's `load` so subsequent reads repeat the fake
    # data and yield 20 examples.
    def repeated_load(split, shuffle_files=False):
      del shuffle_files  # Unused, to mimic TFDS API
      return test_utils.get_fake_dataset(split).repeat().take(20)
    patched_loader = self._tfds_patcher.new.return_value._replace(
        load=repeated_load)
    self._tfds_patcher.new.return_value = patched_loader
# Example #9 (snippet separator from the original scrape: "Beispiel #9" / "0")
 def test_dtype(self):
     """Features may override the token dtype; unspecified features default to int32."""
     vocab = test_utils.sentencepiece_vocab()
     feats = {
         # No dtype given: Feature defaults apply (int32 per the original note).
         "inputs": dataset_providers.Feature(vocabulary=vocab),
         # Targets explicitly request int64 tokens.
         "targets": dataset_providers.Feature(dtype=tf.int64,
                                              vocabulary=vocab),
     }
     test_utils.add_task("task_dtypes",
                         test_utils.get_fake_dataset,
                         output_features=feats)
     task = TaskRegistry.get("task_dtypes")
     test_utils.verify_task_matches_fake_datasets(task, use_cached=False)
 def test_dataset_fn(self):
   """A task registered from a plain dataset function round-trips the fake data."""
   test_utils.add_task("fn_task", test_utils.get_fake_dataset)
   task = TaskRegistry.get("fn_task")
   test_utils.verify_task_matches_fake_datasets(task, use_cached=False)
 def test_get_dataset_v3(self):
   """The v3 task fixture matches the fake datasets when read uncached."""
   # Uncached read so the token preprocessor runs as part of the pipeline.
   test_utils.verify_task_matches_fake_datasets(
       self.task_v3,
       use_cached=False,
       token_preprocessed=True)
 def test_tf_example_task(self):
   """The TF-example-backed task matches the fakes for the train split only."""
   test_utils.verify_task_matches_fake_datasets(
       self.tf_example_task,
       use_cached=False,
       splits=["train"])