Code example #1
    def test_set_global_cache_dirs(self):
        utils.set_global_cache_dirs([])
        self.assertFalse(self.cached_task.cached)

        utils.set_global_cache_dirs([self.test_data_dir])
        self.cached_task._initialized = False
        self.assertTrue(self.cached_task.cached)
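The test above exercises the basic contract of set_global_cache_dirs: with no cache directories set, a registered Task does not report itself as cached; once a directory containing that Task's preprocessed data is added (and the Task's lazy initialization is reset), the same Task does. A minimal usage sketch, not part of the original test file, assuming set_global_cache_dirs lives in t5.data.utils (the module the tests alias as dataset_utils and patch below) and that TaskRegistry is already imported; the task name and cache path are hypothetical:

    from t5.data import utils as dataset_utils  # assumption: same module the tests alias as dataset_utils

    # Point the library at directories that may hold cached, preprocessed Task data.
    dataset_utils.set_global_cache_dirs(["/path/to/cache_dir"])  # hypothetical path

    # A registered Task reports itself as cached only when a directory named after it
    # (here /path/to/cache_dir/my_task) exists under one of the global cache dirs.
    task = TaskRegistry.get("my_task")  # hypothetical task name; TaskRegistry assumed in scope
    print(task.cached)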
Code example #2
    def setUp(self):
        super(FakeTaskTest, self).setUp()
        self.maxDiff = None  # pylint:disable=invalid-name

        # Mock TFDS
        # Note we don't use mock.Mock since they fail to pickle.
        fake_tfds_paths = {
            "train": ["train.tfrecord-%05d-of-00002" % i for i in range(2)],
            "validation": ["validation.tfrecord-00000-of-00001"],
        }

        def _load_shard(shard_path):
            if "train" in shard_path:
                if shard_path.endswith("00000-of-00002"):
                    return get_fake_dataset("train").take(2)
                else:
                    return get_fake_dataset("train").skip(2)
            else:
                return get_fake_dataset("validation")

        fake_tfds = FakeLazyTfds(name="fake:0.0.0",
                                 load=get_fake_dataset,
                                 load_shard=_load_shard,
                                 info=FakeTfdsInfo(splits={
                                     "train": None,
                                     "validation": None
                                 }),
                                 files=fake_tfds_paths.get,
                                 verify_split=lambda x: x,
                                 size=lambda x: 30 if x == "train" else 10)
        add_fake_tfds(fake_tfds)

        # Set up data directory.
        self.test_tmpdir = self.create_tempdir().full_path
        self.test_data_dir = os.path.join(self.test_tmpdir, "test_data")
        shutil.copytree(TEST_DATA_DIR, self.test_data_dir)
        for root, dirs, _ in os.walk(self.test_data_dir):
            for d in dirs + [""]:
                os.chmod(os.path.join(root, d), 0o777)

        # Register a cached test Task.
        dataset_utils.set_global_cache_dirs([self.test_data_dir])
        clear_tasks()
        add_tfds_task("cached_task")

        self.cached_task = TaskRegistry.get("cached_task")
        cached_task_dir = os.path.join(self.test_data_dir, "cached_task")
        _dump_fake_dataset(os.path.join(cached_task_dir, "train.tfrecord"),
                           "train", [2, 1])
        _dump_fake_dataset(
            os.path.join(cached_task_dir, "validation.tfrecord"), "validation",
            [2])

        # Register an uncached test Task.
        add_tfds_task("uncached_task")
        self.uncached_task = TaskRegistry.get("uncached_task")

        # Auto-verify any split by just returning the split name
        dataset_utils.verify_tfds_split = absltest.mock.Mock(
            side_effect=lambda x, y: y)
Code example #3
  def setUp(self):
    super().setUp()
    self.maxDiff = None  # pylint:disable=invalid-name

    # Mock TFDS
    # Note we don't use mock.Mock since they fail to pickle.
    fake_tfds_paths = {
        "train": [
            {  # pylint:disable=g-complex-comprehension
                "filename": "train.tfrecord-%05d-of-00002" % i,
                "skip": 0,
                "take": -1
            }
            for i in range(2)],
        "validation": [
            {
                "filename": "validation.tfrecord-00000-of-00001",
                "skip": 0,
                "take": -1
            }],
    }
    def _load_shard(shard_instruction):
      fname = shard_instruction["filename"]
      if "train" in fname:
        if fname.endswith("00000-of-00002"):
          return get_fake_dataset("train").take(2)
        else:
          return get_fake_dataset("train").skip(2)
      else:
        return get_fake_dataset("validation")

    fake_tfds = FakeLazyTfds(
        name="fake:0.0.0",
        load=get_fake_dataset,
        load_shard=_load_shard,
        info=FakeTfdsInfo(splits={"train": None, "validation": None}),
        files=fake_tfds_paths.get,
        size=lambda x: 30 if x == "train" else 10)
    self._tfds_patcher = mock.patch(
        "t5.data.utils.LazyTfdsLoader", new=mock.Mock(return_value=fake_tfds))
    self._tfds_patcher.start()

    # Set up data directory.
    self.test_tmpdir = self.get_tempdir()
    self.test_data_dir = os.path.join(self.test_tmpdir, "test_data")
    shutil.copytree(TEST_DATA_DIR, self.test_data_dir)
    for root, dirs, _ in os.walk(self.test_data_dir):
      for d in dirs + [""]:
        os.chmod(os.path.join(root, d), 0o777)

    # Register a cached test Task.
    dataset_utils.set_global_cache_dirs([self.test_data_dir])
    clear_tasks()
    add_tfds_task("cached_task")

    # Prepare cached task.
    self.cached_task = TaskRegistry.get("cached_task")
    cached_task_dir = os.path.join(self.test_data_dir, "cached_task")
    _dump_fake_dataset(
        os.path.join(cached_task_dir, "train.tfrecord"),
        _FAKE_TOKENIZED_DATASET["train"], [2, 1], _dump_examples_to_tfrecord)
    _dump_fake_dataset(
        os.path.join(cached_task_dir, "validation.tfrecord"),
        _FAKE_TOKENIZED_DATASET["validation"], [2], _dump_examples_to_tfrecord)

    # Prepare uncached TfdsTask.
    add_tfds_task("uncached_task")
    self.uncached_task = TaskRegistry.get("uncached_task")

    # Prepare uncached TextLineTask.
    _dump_fake_dataset(
        os.path.join(self.test_data_dir, "train.tsv"),
        _FAKE_DATASET["train"], [2, 1], _dump_examples_to_tsv)
    TaskRegistry.add(
        "text_line_task",
        dataset_utils.TextLineTask,
        split_to_filepattern={
            "train": os.path.join(self.test_data_dir, "train.tsv*"),
        },
        skip_header_lines=1,
        text_preprocessor=[_split_tsv_preprocessor, test_text_preprocessor],
        output_features=dataset_utils.Feature(sentencepiece_vocab()),
        metric_fns=[])
    self.text_line_task = TaskRegistry.get("text_line_task")
Code example #4
    def setUp(self):
        super().setUp()
        self.maxDiff = None  # pylint:disable=invalid-name

        # Mock TFDS
        # Note we don't use mock.Mock since they fail to pickle.
        fake_tfds_paths = {
            "train": [
                {  # pylint:disable=g-complex-comprehension
                    "filename": "train.tfrecord-%05d-of-00002" % i,
                    "skip": 0,
                    "take": -1
                } for i in range(2)
            ],
            "validation": [{
                "filename": "validation.tfrecord-00000-of-00001",
                "skip": 0,
                "take": -1
            }],
        }

        def _load_shard(shard_instruction, shuffle_files, seed):
            del shuffle_files
            del seed
            fname = shard_instruction["filename"]
            if "train" in fname:
                if fname.endswith("00000-of-00002"):
                    return get_fake_dataset("train").take(2)
                else:
                    return get_fake_dataset("train").skip(2)
            else:
                return get_fake_dataset("validation")

        fake_tfds = FakeLazyTfds(name="fake:0.0.0",
                                 load=get_fake_dataset,
                                 load_shard=_load_shard,
                                 info=FakeTfdsInfo(splits={
                                     "train": None,
                                     "validation": None
                                 }),
                                 files=fake_tfds_paths.get,
                                 size=lambda x: 30 if x == "train" else 10)
        self._tfds_patcher = mock.patch("t5.data.utils.LazyTfdsLoader",
                                        new=mock.Mock(return_value=fake_tfds))
        self._tfds_patcher.start()

        # Set up data directory.
        self.test_tmpdir = self.get_tempdir()
        self.test_data_dir = os.path.join(self.test_tmpdir, "test_data")
        shutil.copytree(TEST_DATA_DIR, self.test_data_dir)
        for root, dirs, _ in os.walk(self.test_data_dir):
            for d in dirs + [""]:
                os.chmod(os.path.join(root, d), 0o777)

        # Register a cached test Task.
        dataset_utils.set_global_cache_dirs([self.test_data_dir])
        clear_tasks()
        add_tfds_task("cached_task",
                      token_preprocessor=test_token_preprocessor)
        add_tfds_task("cached_task_no_token_prep")

        # Prepare cached tasks.
        self.cached_task = TaskRegistry.get("cached_task")
        cached_task_dir = os.path.join(self.test_data_dir, "cached_task")
        _dump_fake_dataset(os.path.join(cached_task_dir, "train.tfrecord"),
                           _FAKE_TOKENIZED_DATASET["train"], [2, 1],
                           _dump_examples_to_tfrecord)
        _dump_fake_dataset(
            os.path.join(cached_task_dir, "validation.tfrecord"),
            _FAKE_TOKENIZED_DATASET["validation"], [2],
            _dump_examples_to_tfrecord)
        shutil.copytree(
            cached_task_dir,
            os.path.join(self.test_data_dir, "cached_task_no_token_prep"))

        # Prepare uncached TfdsTask.
        add_tfds_task("uncached_task",
                      token_preprocessor=test_token_preprocessor)
        add_tfds_task("uncached_task_no_token_prep")
        self.uncached_task = TaskRegistry.get("uncached_task")
        # Prepare uncached, random TfdsTask
        add_tfds_task("uncached_random_task",
                      token_preprocessor=random_token_preprocessor)
        self.uncached_random_task = TaskRegistry.get("uncached_random_task")

        # Prepare uncached TextLineTask.
        _dump_fake_dataset(os.path.join(self.test_data_dir,
                                        "train.tsv"), _FAKE_DATASET["train"],
                           [2, 1], _dump_examples_to_tsv)
        TaskRegistry.add("text_line_task",
                         dataset_providers.TextLineTask,
                         split_to_filepattern={
                             "train":
                             os.path.join(self.test_data_dir, "train.tsv*"),
                         },
                         skip_header_lines=1,
                         text_preprocessor=[
                             _split_tsv_preprocessor, test_text_preprocessor
                         ],
                         output_features=dataset_providers.Feature(
                             sentencepiece_vocab()),
                         metric_fns=[])
        self.text_line_task = TaskRegistry.get("text_line_task")

        # Prepare uncached TFExampleTask
        _dump_fake_dataset(os.path.join(self.test_data_dir, "train.tfrecord"),
                           _FAKE_DATASET["train"], [2, 1],
                           _dump_examples_to_tfrecord)
        TaskRegistry.add(
            "tf_example_task",
            dataset_providers.TFExampleTask,
            split_to_filepattern={
                "train": os.path.join(self.test_data_dir, "train.tfrecord*"),
            },
            feature_description={
                "prefix": tf.io.FixedLenFeature([], tf.string),
                "suffix": tf.io.FixedLenFeature([], tf.string),
            },
            text_preprocessor=[test_text_preprocessor],
            output_features=dataset_providers.Feature(sentencepiece_vocab()),
            metric_fns=[])
        self.tf_example_task = TaskRegistry.get("tf_example_task")

        # Prepare uncached Task.
        def _dataset_fn(split,
                        shuffle_files,
                        filepattern=os.path.join(self.test_data_dir,
                                                 "train.tsv*")):
            del split
            files = tf.data.Dataset.list_files(filepattern,
                                               shuffle=shuffle_files)
            return files.interleave(
                lambda f: tf.data.TextLineDataset(f).skip(1),
                num_parallel_calls=tf.data.experimental.AUTOTUNE)

        TaskRegistry.add("general_task",
                         dataset_providers.Task,
                         dataset_fn=_dataset_fn,
                         splits=["train"],
                         text_preprocessor=[
                             _split_tsv_preprocessor, test_text_preprocessor
                         ],
                         output_features=dataset_providers.Feature(
                             sentencepiece_vocab()),
                         metric_fns=[])
        self.general_task = TaskRegistry.get("general_task")

        # Prepare uncached TaskV3.
        TaskRegistry.add("task_v3",
                         dataset_providers.TaskV3,
                         source=dataset_providers.FunctionSource(
                             dataset_fn=get_fake_dataset,
                             splits=["train", "validation"]),
                         preprocessors=[
                             test_text_preprocessor,
                             preprocessors.tokenize,
                             token_preprocessor_no_sequence_length,
                             dataset_providers.CacheDatasetPlaceholder(),
                         ],
                         output_features={
                             "inputs":
                             dataset_providers.Feature(sentencepiece_vocab()),
                             "targets":
                             dataset_providers.Feature(sentencepiece_vocab()),
                         },
                         metric_fns=[])
        self.task_v3 = TaskRegistry.get("task_v3")

        # Prepare uncached TaskV3 with no caching before tokenization.
        TaskRegistry.add("task_v3_tokenized_postcache",
                         dataset_providers.TaskV3,
                         source=dataset_providers.FunctionSource(
                             dataset_fn=get_fake_dataset,
                             splits=["train", "validation"]),
                         preprocessors=[
                             test_text_preprocessor,
                             dataset_providers.CacheDatasetPlaceholder(),
                             preprocessors.tokenize,
                             token_preprocessor_no_sequence_length,
                         ],
                         output_features={
                             "inputs":
                             dataset_providers.Feature(sentencepiece_vocab()),
                             "targets":
                             dataset_providers.Feature(sentencepiece_vocab()),
                         },
                         metric_fns=[])
Code example #5
File: test_utils.py  Project: yiwen92/my-project
    def setUp(self):
        super(FakeTaskTest, self).setUp()
        self.maxDiff = None  # pylint:disable=invalid-name

        # Mock TFDS
        # Note we don't use mock.Mock since they fail to pickle.
        fake_tfds_paths = {
            "train": ["train.tfrecord-%05d-of-00002" % i for i in range(2)],
            "validation": ["validation.tfrecord-00000-of-00001"],
        }

        def _load_shard(shard_path):
            if "train" in shard_path:
                if shard_path.endswith("00000-of-00002"):
                    return get_fake_dataset("train").take(2)
                else:
                    return get_fake_dataset("train").skip(2)
            else:
                return get_fake_dataset("validation")

        fake_tfds = FakeLazyTfds(name="fake:0.0.0",
                                 load=get_fake_dataset,
                                 load_shard=_load_shard,
                                 info=FakeTfdsInfo(splits={
                                     "train": None,
                                     "validation": None
                                 }),
                                 files=fake_tfds_paths.get,
                                 verify_split=lambda x: x,
                                 size=lambda x: 30 if x == "train" else 10)
        add_fake_tfds(fake_tfds)

        # Set up data directory.
        self.test_tmpdir = self.create_tempdir().full_path
        self.test_data_dir = os.path.join(self.test_tmpdir, "test_data")
        shutil.copytree(TEST_DATA_DIR, self.test_data_dir)
        for root, dirs, _ in os.walk(self.test_data_dir):
            for d in dirs + [""]:
                os.chmod(os.path.join(root, d), 0o777)

        # Register a cached test Task.
        dataset_utils.set_global_cache_dirs([self.test_data_dir])
        clear_tasks()
        add_tfds_task("cached_task")

        # Prepare cached task.
        self.cached_task = TaskRegistry.get("cached_task")
        cached_task_dir = os.path.join(self.test_data_dir, "cached_task")
        _dump_fake_dataset(os.path.join(cached_task_dir, "train.tfrecord"),
                           _FAKE_CACHED_DATASET["train"], [2, 1],
                           _dump_examples_to_tfrecord)
        _dump_fake_dataset(
            os.path.join(cached_task_dir, "validation.tfrecord"),
            _FAKE_CACHED_DATASET["validation"], [2],
            _dump_examples_to_tfrecord)

        # Prepare uncached TfdsTask.
        add_tfds_task("uncached_task")
        self.uncached_task = TaskRegistry.get("uncached_task")

        # Prepare uncached TextLineTask.
        _dump_fake_dataset(os.path.join(self.test_data_dir,
                                        "train.tsv"), _FAKE_DATASET["train"],
                           [2, 1], _dump_examples_to_tsv)
        TaskRegistry.add("text_line_task",
                         dataset_utils.TextLineTask,
                         split_to_filepattern={
                             "train":
                             os.path.join(self.test_data_dir, "train.tsv*"),
                         },
                         skip_header_lines=1,
                         text_preprocessor=[
                             _split_tsv_preprocessor, test_text_preprocessor
                         ],
                         sentencepiece_model_path=os.path.join(
                             TEST_DATA_DIR, "sentencepiece",
                             "sentencepiece.model"),
                         metric_fns=[])
        self.text_line_task = TaskRegistry.get("text_line_task")

        # Auto-verify any split by just returning the split name
        dataset_utils.verify_tfds_split = absltest.mock.Mock(
            side_effect=lambda x, y: y)
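Each setUp above mutates global state: examples #3 and #4 start a mock patcher stored as self._tfds_patcher, and examples #2 and #5 overwrite dataset_utils.verify_tfds_split directly. A matching tearDown would normally undo this so state does not leak between tests; a minimal sketch, assuming the attribute name used in examples #3 and #4:

    def tearDown(self):
        # Stop the mock.patch of t5.data.utils.LazyTfdsLoader started in setUp.
        self._tfds_patcher.stop()
        super().tearDown()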