def validate_pipeline(self,
                      task_name,
                      expected_task_dir="cached_task",
                      token_preprocessed=False,
                      num_shards=2):
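        # Runs the caching pipeline for `task_name` (alongside the
        # pre-registered "cached_task") and validates the resulting cache
        # directory. Assumes the enclosing test case defines `test_data_dir`
        # and `verify_task_matches_fake_datasets`.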
        self.assertTrue(TaskRegistry.get("cached_task").cache_dir)
        task = TaskRegistry.get(task_name)
        self.assertFalse(task.cache_dir)

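        # Cache both tasks into the test directory using a Beam TestPipeline.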
        with TestPipeline() as p:
            output_dirs = cache_tasks_main.run_pipeline(
                p, ["cached_task", task_name], cache_dir=self.test_data_dir)

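        # Expected layout: `num_shards` train TFRecord shards plus per-split
        # stats/info JSON files, mirrored from the golden test data directory.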
        actual_task_dir = os.path.join(self.test_data_dir,
                                       seqio.get_task_dir_from_name(task_name))
        expected_task_dir = os.path.join(test_utils.TEST_DATA_DIR,
                                         expected_task_dir)
        expected_tfrecord_files = [
            "train.tfrecord-%05d-of-%05d" % (i, num_shards)
            for i in range(num_shards)
        ]
        expected_auxiliary_files = ["stats.train.json", "info.train.json"]

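        # Tasks with a validation split produce one extra shard plus its
        # metadata files.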
        if "validation" in task.splits:
            expected_tfrecord_files.append(
                "validation.tfrecord-00000-of-00001")
            expected_auxiliary_files.extend(
                ["stats.validation.json", "info.validation.json"])
        self.assertEqual([actual_task_dir], output_dirs)
        self.assertCountEqual(
            expected_tfrecord_files + expected_auxiliary_files,
            tf.io.gfile.listdir(actual_task_dir))

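        # Each JSON metadata file should match its golden counterpart after
        # normalizing the shard count and seqio version.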
        for fname in expected_auxiliary_files:
            actual_content = tf.io.gfile.GFile(
                os.path.join(actual_task_dir, fname)).read()
            expected_content = tf.io.gfile.GFile(
                os.path.join(expected_task_dir, fname)).read()

            # Accept minor formatting difference.
            actual_content = actual_content.replace(", ", ",")
            # Replace with actual number of shards.
            expected_content = expected_content.replace(
                '"num_shards": 2', f'"num_shards": {num_shards}')
            # Replace with actual version.
            version = seqio.__version__
            expected_content = expected_content.replace(
                '"seqio_version": "0.0.0"', f'"seqio_version": "{version}"')
            self.assertEqual(expected_content, actual_content)

        # Add a COMPLETED file so that the task can be loaded from the cache.
        mark_completed(self.test_data_dir,
                       seqio.get_task_dir_from_name(task_name))

        # Check datasets.
        self.verify_task_matches_fake_datasets(
            task_name,
            use_cached=True,
            splits=task.splits,
            token_preprocessed=token_preprocessed)
Example #2
  def test_overwrite(self):
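    """Re-running the pipeline with overwrite=True should rewrite cached files."""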
    with TestPipeline() as p:
      _ = cache_tasks_main.run_pipeline(
          p, ["uncached_task"], cache_dir=self.test_data_dir, overwrite=True)

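    # Record the mtime of one output shard from the first run.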
    actual_task_dir = os.path.join(self.test_data_dir, "uncached_task")
    stat_old = tf.io.gfile.stat(
        os.path.join(actual_task_dir, "train.tfrecord-00000-of-00002"))

    with TestPipeline() as p:
      _ = cache_tasks_main.run_pipeline(
          p, ["uncached_task"], cache_dir=self.test_data_dir, overwrite=True)

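    # Stat the same shard again; overwrite=True should have rewritten it.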
    stat_new = tf.io.gfile.stat(
        os.path.join(actual_task_dir, "train.tfrecord-00000-of-00002"))

    self.assertGreater(stat_new.mtime_nsec, stat_old.mtime_nsec)