Example #1
    def test_value_errors(self):
        dataset_fn = (lambda split, shuffle_files: tf.data.Dataset.
                      from_tensors(["test"]))
        output_features = {
            "inputs":
            dataset_providers.Feature(test_utils.sentencepiece_vocab())
        }

        with self.assertRaisesWithLiteralMatch(
                ValueError,
                "`CacheDatasetPlaceholder` can appear at most once in the "
                "preprocessing pipeline. Found 2 in 'multiple_cache_placeholders'."
        ):
            dataset_providers.Task(
                "multiple_cache_placeholders",
                source=dataset_providers.FunctionDataSource(
                    dataset_fn=dataset_fn, splits=["train", "validation"]),
                preprocessors=[
                    test_utils.test_text_preprocessor, preprocessors.tokenize,
                    dataset_providers.CacheDatasetPlaceholder(),
                    test_utils.test_token_preprocessor,
                    dataset_providers.CacheDatasetPlaceholder()
                ],
                output_features=output_features,
                metric_fns=[])

        with self.assertRaisesWithLiteralMatch(
                ValueError,
                "'test_token_preprocessor' has a `sequence_length` argument but occurs "
                "before `CacheDatasetPlaceholder` in 'sequence_length_pre_cache'. This "
                "is not allowed since the sequence length is specified at run time."
        ):
            dataset_providers.Task(
                "sequence_length_pre_cache",
                dataset_providers.FunctionDataSource(
                    dataset_fn=dataset_fn,
                    splits=["train"],
                ),
                preprocessors=[
                    test_utils.test_text_preprocessor, preprocessors.tokenize,
                    test_utils.test_token_preprocessor,
                    dataset_providers.CacheDatasetPlaceholder()
                ],
                output_features=output_features,
                metric_fns=[])
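For contrast, a pipeline that satisfies both constraints uses a single CacheDatasetPlaceholder and runs every sequence_length-dependent preprocessor after it. A minimal sketch reusing the helpers from the test above (the task name is illustrative, and the same imports are assumed):

        dataset_providers.Task(
            "valid_cache_placement",  # illustrative name
            source=dataset_providers.FunctionDataSource(
                dataset_fn=dataset_fn, splits=["train", "validation"]),
            preprocessors=[
                test_utils.test_text_preprocessor,
                preprocessors.tokenize,
                # Cache everything computed so far...
                dataset_providers.CacheDatasetPlaceholder(),
                # ...then apply the step that needs the run-time `sequence_length`.
                test_utils.test_token_preprocessor,
            ],
            output_features=output_features,
            metric_fns=[])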
Example #2
    def __enter__(self):
        def ds_fn(split, shuffle_files):
            del shuffle_files
            data = self.per_split_data[split]
            ds = tf.data.Dataset.from_tensors(data)
            return ds

        mock_source = dataset_providers.FunctionDataSource(
            ds_fn, splits=self.per_split_data.keys())
        self._task._source = mock_source
        self._mock_source = mock_source

  def test_function_source_signature(self):
    # Good signatures.
    def good_fn(split, shuffle_files):
      del split
      del shuffle_files
    dataset_providers.FunctionDataSource(good_fn, splits=("train",))

    def default_good_fn(split, shuffle_files=False):
      del split
      del shuffle_files
    dataset_providers.FunctionDataSource(default_good_fn, splits=("train",))

    def seed_fn(split, shuffle_files=True, seed=0):
      del split
      del shuffle_files
      del seed
    dataset_providers.FunctionDataSource(seed_fn, splits=("train",))

    def extra_kwarg_good_fn(split, shuffle_files, unused_kwarg=True):
      del split
      del shuffle_files
    dataset_providers.FunctionDataSource(extra_kwarg_good_fn, splits=("train",))

    # Bad signatures.
    with self.assertRaisesWithLiteralMatch(
        ValueError,
        "'missing_shuff' must have positional args ('split', 'shuffle_files'), "
        "got: ('split',)"):
      def missing_shuff(split):
        del split
      dataset_providers.FunctionDataSource(missing_shuff, splits=("train",))

    with self.assertRaisesWithLiteralMatch(
        ValueError,
        "'missing_split' must have positional args ('split', 'shuffle_files'), "
        "got: ('shuffle_files',)"):
      def missing_split(shuffle_files):
        del shuffle_files
      dataset_providers.FunctionDataSource(missing_split, splits=("train",))

    with self.assertRaisesWithLiteralMatch(
        ValueError,
        "'extra_pos_arg' may only have positional args ('split', "
        "'shuffle_files'), got: ('split', 'shuffle_files', 'unused_arg')"):
      def extra_pos_arg(split, shuffle_files, unused_arg):
        del split
        del shuffle_files
      dataset_providers.FunctionDataSource(extra_pos_arg, splits=("train",))
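In short, the contract this test enforces: a dataset_fn must expose exactly the positional parameters split and shuffle_files, while an optional seed and extra defaulted keyword arguments are tolerated. A minimal conforming function, sketched here with illustrative names:

  import tensorflow as tf

  def my_dataset_fn(split, shuffle_files, seed=None):
    # Deterministic in-memory data, so file shuffling and seeding are no-ops.
    del shuffle_files, seed
    return tf.data.Dataset.from_tensor_slices([split + "-example"])

  source = dataset_providers.FunctionDataSource(
      my_dataset_fn, splits=("train", "validation"))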
Example #4
def _task_from_tensor_slices(name, tensor_slices, label_classes):
  return dataset_providers.Task(
      name,
      dataset_providers.FunctionDataSource(
          lambda split, shuffle_files:
          tf.data.Dataset.from_tensor_slices(tensor_slices),
          # Note the trailing comma: ("validation") is just a string, and a bare
          # string would be treated as an iterable of characters.
          splits=("validation",)),
      preprocessors=[utils.map_over_dataset(lambda ex: {
          "inputs": tf.range(ex["inputs_lengths"]),
          "targets": tf.range(ex["targets_lengths"]),
          "targets_pretokenized": ex["targets_pretokenized"],
      })],
      postprocess_fn=functools.partial(
          _string_label_to_class_id_postprocessor,
          label_classes=label_classes),
      output_features={"inputs": dataset_providers.Feature(mock.Mock()),
                       "targets": dataset_providers.Feature(mock.Mock())}
  )

def register_dummy_task(
    task_name: str,
    dataset_fn: Callable[[str, bool], tf.data.Dataset],
    output_feature_names: Sequence[str] = ("inputs", "targets")) -> None:
  """Register a dummy task for GetDatasetTest."""
  dataset_providers.TaskRegistry.add(
      task_name,
      source=dataset_providers.FunctionDataSource(
          dataset_fn=dataset_fn, splits=["train", "validation"]),
      preprocessors=[
          dataset_providers.CacheDatasetPlaceholder(),
          preprocessors.append_eos_after_trim,
      ],
      output_features={
          feat: dataset_providers.Feature(test_utils.sentencepiece_vocab())
          for feat in output_feature_names
      },
      metric_fns=[])

    def setUp(self):
        super().setUp()

        TaskRegistry.reset()
        MixtureRegistry.reset()

        self.fake_source = dataset_providers.FunctionDataSource(
            lambda split, shuffle_files: tf.data.Dataset.range(2), ['train'])

        self.vocabulary = vocabularies.PassThroughVocabulary(100)

        self.metrics_fns = [lambda targets, predictions: 0]

        def fake_preprocessor(ds):
            """Adds one and casts to int32."""
            return ds.map(lambda x: tf.cast(x + 1, tf.int32))

        def fake_preprocessor_of(ds, output_features):
            """Creates output feature dict from scalar input."""
            return ds.map(lambda x: {k: [x] for k in output_features})

        def fake_preprocessor_sl(ds, sequence_length):
            """Concatenates the sequence length to each feature."""
            return ds.map(
                lambda x: {  # pylint:disable=g-long-lambda
                    k: tf.concat([v, [sequence_length[k]]], 0)
                    for k, v in x.items()
                })

        def fake_preprocessor_sl_of(ds, sequence_length, output_features):
            """Adds the sequence length to each feature with `add_eos` enabled."""
            return ds.map(
                lambda x: {  # pylint:disable=g-long-lambda
                    k: tf.concat([v, [sequence_length[k]]], 0)
                    if output_features[k].add_eos else v
                    for k, v in x.items()
                })

        self.preprocessors = [
            fake_preprocessor,
            fake_preprocessor_of,
            fake_preprocessor_sl,
            fake_preprocessor_sl_of,
        ]
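The four fake preprocessors cover every argument combination a Task injects at run time: nothing extra, output_features, sequence_length, or both. A sketch of registering them and materializing the dataset from inside a test method of this fixture; the task name, feature key, and length are illustrative, and it assumes the usual seqio Task.get_dataset(sequence_length=..., split=...) entry point:

        TaskRegistry.add(
            'signature_demo_task',  # illustrative name
            source=self.fake_source,
            preprocessors=self.preprocessors,
            output_features={'feature': dataset_providers.Feature(self.vocabulary)},
            metric_fns=self.metrics_fns)
        task = TaskRegistry.get('signature_demo_task')
        ds = task.get_dataset(
            sequence_length={'feature': 4}, split='train', shuffle=False)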
Example #7
def register_dummy_task(task_name: str,
                        dataset_fn: Callable[[str, bool], tf.data.Dataset],
                        output_feature_names: Sequence[str] = ("inputs",
                                                               "targets"),
                        postprocess_fn=None,
                        metrics_fn=None) -> None:
    """Register a dummy task for GetDatasetTest."""
    dataset_providers.TaskRegistry.add(
        task_name,
        source=dataset_providers.FunctionDataSource(
            dataset_fn=dataset_fn, splits=["train", "validation"]),
        preprocessors=[preprocessors.append_eos_after_trim],
        postprocess_fn=postprocess_fn,
        output_features={
            # Mock the sentencepiece vocabulary.
            feat: dataset_providers.Feature(mock.Mock(eos_id=True))
            for feat in output_feature_names
        },
        metric_fns=metrics_fn)
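A usage sketch: the helper accepts any task name, an in-line dataset_fn, and optional postprocess/metric hooks. Every name below is illustrative:

def _dummy_dataset_fn(split, shuffle_files):
    del shuffle_files
    return tf.data.Dataset.from_tensors(
        {"inputs": split + " inputs", "targets": split + " targets"})

register_dummy_task(
    "dummy_metrics_task",
    dataset_fn=_dummy_dataset_fn,
    # Postprocessors receive extra kwargs (e.g. `is_target`), hence **unused_kwargs.
    postprocess_fn=lambda output, **unused_kwargs: output,
    metrics_fn=[lambda targets, predictions: {"accuracy": 0.0}])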
Example #8
  def test_data_injection(self):

    def ds_fn(split, shuffle_files):
      del shuffle_files
      data = {'train': {'data': b'not used'}}
      ds = tf.data.Dataset.from_tensors(data[split])
      return ds

    source = dataset_providers.FunctionDataSource(
        dataset_fn=ds_fn, splits=['train'])

    dataset_providers.TaskRegistry.add(
        'test_data_injection_task',
        source=source,
        preprocessors=[],
        output_features={},
        metric_fns=[])

    data = {'train': {'data': b'This data is not used.'}}
    with DataInjector('test_data_injection_task', data):
      pass

    task = dataset_providers.TaskRegistry.get('test_data_injection_task')
    self.assertIs(source, task._source)

    def test_fewshot_data_source(self):
        def fake_dataset_fn(split, shuffle_files):
            del shuffle_files
            return tf.data.Dataset.range(
                *((0, 2) if split == 'validation' else (3, 5)))

        # 0 shot
        src = experimental.FewshotDataSource(
            dataset_providers.FunctionDataSource(
                dataset_fn=fake_dataset_fn, splits=['train', 'validation']),
            num_shots=0)
        dataset = src.get_dataset('validation')
        assert_dataset(dataset, [{
            'eval': 0,
        }, {
            'eval': 1
        }])

        # 3 shot
        src = experimental.FewshotDataSource(
            dataset_providers.FunctionDataSource(
                dataset_fn=fake_dataset_fn, splits=['train', 'validation']),
            train_preprocessors=[
                utils.map_over_dataset(lambda x: {
                    'inputs': 0,
                    'targets': x
                })
            ],
            num_shots=3)
        dataset = src.get_dataset('validation')
        assert_dataset(dataset, [
            {
                'eval': 0,
                'train': {
                    'inputs': [0, 0, 0],
                    'targets': [3, 4, 3]
                }
            },
            {
                'eval': 1,
                'train': {
                    'inputs': [0, 0, 0],
                    'targets': [4, 3, 4]
                }
            },
        ])

        # 3-shot, sharded.
        assert_dataset(
            src.get_dataset('validation', shard_info=ShardInfo(0, 2)), [
                {
                    'eval': 0,
                    'train': {
                        'inputs': [0, 0, 0],
                        'targets': [3, 3, 3]
                    }
                },
            ])
        assert_dataset(
            src.get_dataset('validation', shard_info=ShardInfo(1, 2)), [
                {
                    'eval': 1,
                    'train': {
                        'inputs': [0, 0, 0],
                        'targets': [4, 4, 4]
                    }
                },
            ])

        # Missing train
        src = experimental.FewshotDataSource(
            dataset_providers.FunctionDataSource(dataset_fn=fake_dataset_fn,
                                                 splits=['validation']),
            num_shots=3)
        with self.assertRaisesRegex(
                ValueError,
                'Train split \'train\' is not one of the original source splits: '
                r'\(\'validation\',\)'):
            dataset = src.get_dataset('validation')
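The final case fails because FewshotDataSource draws its shots from a 'train' split that the wrapped source does not provide. The error message implies the train split is configurable; a sketch assuming a train_split argument with that semantics (the parameter name is inferred from the message, not verified here):

        src = experimental.FewshotDataSource(
            dataset_providers.FunctionDataSource(
                dataset_fn=fake_dataset_fn, splits=['validation']),
            num_shots=3,
            # Assumed parameter: shots are now drawn from 'validation' instead
            # of the missing 'train' split.
            train_split='validation')
        dataset = src.get_dataset('validation')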
Example #10
    def _prepare_sources_and_tasks(self):
        clear_tasks()
        clear_mixtures()
        # Prepare TfdsSource
        # Note we don't use mock.Mock since they fail to pickle.
        fake_tfds_paths = {
            "train": [
                {  # pylint:disable=g-complex-comprehension
                    "filename": "train.tfrecord-%05d-of-00002" % i,
                    "skip": 0,
                    "take": -1
                } for i in range(2)
            ],
            "validation": [{
                "filename": "validation.tfrecord-00000-of-00001",
                "skip": 0,
                "take": -1
            }],
        }

        def _load_shard(shard_instruction, shuffle_files, seed):
            del shuffle_files
            del seed
            fname = shard_instruction["filename"]
            if "train" in fname:
                ds = get_fake_dataset("train")
                if fname.endswith("00000-of-00002"):
                    return ds.take(2)
                else:
                    return ds.skip(2)
            else:
                return get_fake_dataset("validation")

        fake_tfds = FakeLazyTfds(name="fake:0.0.0",
                                 load=get_fake_dataset,
                                 load_shard=_load_shard,
                                 info=FakeTfdsInfo(splits={
                                     "train": None,
                                     "validation": None
                                 }),
                                 files=fake_tfds_paths.get,
                                 size=lambda x: 30 if x == "train" else 10)
        self._tfds_patcher = mock.patch("t5.seqio.utils.LazyTfdsLoader",
                                        new=mock.Mock(return_value=fake_tfds))
        self._tfds_patcher.start()

        # Set up data directory.
        self.test_tmpdir = self.get_tempdir()
        self.test_data_dir = os.path.join(self.test_tmpdir, "test_data")
        shutil.copytree(TEST_DATA_DIR, self.test_data_dir)
        for root, dirs, _ in os.walk(self.test_data_dir):
            for d in dirs + [""]:
                os.chmod(os.path.join(root, d), 0o777)

        # Prepare uncached TextLineTask.
        self.tfds_source = dataset_providers.TfdsDataSource(
            tfds_name="fake:0.0.0", splits=("train", "validation"))
        self.add_task("tfds_task", source=self.tfds_source)

        # Prepare TextLineSource.
        _dump_fake_dataset(os.path.join(self.test_data_dir,
                                        "train.tsv"), _FAKE_DATASET["train"],
                           [2, 1], _dump_examples_to_tsv)
        self.text_line_source = dataset_providers.TextLineDataSource(
            split_to_filepattern={
                "train": os.path.join(self.test_data_dir, "train.tsv*"),
            },
            skip_header_lines=1,
        )
        self.add_task("text_line_task",
                      source=self.text_line_source,
                      preprocessors=(split_tsv_preprocessor, ) +
                      self.DEFAULT_PREPROCESSORS)

        # Prepare TFExampleSource.
        _dump_fake_dataset(os.path.join(self.test_data_dir, "train.tfrecord"),
                           _FAKE_DATASET["train"], [2, 1],
                           _dump_examples_to_tfrecord)
        self.tf_example_source = dataset_providers.TFExampleDataSource(
            split_to_filepattern={
                "train": os.path.join(self.test_data_dir, "train.tfrecord*"),
            },
            feature_description={
                "prefix": tf.io.FixedLenFeature([], tf.string),
                "suffix": tf.io.FixedLenFeature([], tf.string),
            })
        self.add_task("tf_example_task", source=self.tf_example_source)

        # Prepare FunctionDataSource
        self.function_source = dataset_providers.FunctionDataSource(
            dataset_fn=get_fake_dataset, splits=["train", "validation"])
        self.add_task("function_task", source=self.function_source)

        # Prepare Task that is tokenized and preprocessed before caching.
        self.add_task("fully_processed_precache",
                      source=self.function_source,
                      preprocessors=(
                          test_text_preprocessor,
                          preprocessors.tokenize,
                          token_preprocessor_no_sequence_length,
                          dataset_providers.CacheDatasetPlaceholder(),
                      ))

        # Prepare Task that is tokenized after caching.
        self.add_task("tokenized_postcache",
                      source=self.function_source,
                      preprocessors=(
                          test_text_preprocessor,
                          dataset_providers.CacheDatasetPlaceholder(),
                          preprocessors.tokenize,
                          token_preprocessor_no_sequence_length,
                      ))

        # Prepare Task with randomization.
        self.random_task = self.add_task(
            "random_task",
            source=self.function_source,
            preprocessors=(
                test_text_preprocessor,
                dataset_providers.CacheDatasetPlaceholder(),
                preprocessors.tokenize,
                random_token_preprocessor,
            ))

        self.uncached_task = self.add_task("uncached_task",
                                           source=self.tfds_source)

        # Prepare cached task.
        dataset_utils.set_global_cache_dirs([self.test_data_dir])
        self.cached_task_dir = os.path.join(self.test_data_dir, "cached_task")
        _dump_fake_dataset(
            os.path.join(self.cached_task_dir,
                         "train.tfrecord"), _FAKE_TOKENIZED_DATASET["train"],
            [2, 1], _dump_examples_to_tfrecord)
        _dump_fake_dataset(
            os.path.join(self.cached_task_dir, "validation.tfrecord"),
            _FAKE_TOKENIZED_DATASET["validation"], [2],
            _dump_examples_to_tfrecord)
        self.cached_task = self.add_task("cached_task",
                                         source=self.tfds_source)

        # Prepare cached plaintext task.
        _dump_fake_dataset(
            os.path.join(self.test_data_dir, "cached_plaintext_task",
                         "train.tfrecord"),
            _FAKE_PLAINTEXT_TOKENIZED_DATASET["train"], [2, 1],
            _dump_examples_to_tfrecord)
        self.cached_plaintext_task = self.add_task(
            "cached_plaintext_task",
            source=self.tfds_source,
            preprocessors=self.DEFAULT_PREPROCESSORS +
            (test_token_preprocessor, ))
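The fixture above wires all four source flavors (TFDS, text-line, TFExample, function) behind one interface. A sketch of reading raw examples straight off one of them from a test method, assuming the DataSource.get_dataset(split, shuffle=...) signature these tasks rely on:

        ds = self.text_line_source.get_dataset("train", shuffle=False)
        for ex in ds.as_numpy_iterator():
            # Each element is one raw TSV line (after `skip_header_lines=1`),
            # yielded as bytes; no Task preprocessing has run yet.
            print(ex)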