Example #1
 def test_invalid_split_dataset(self):
     with testing.tmp_dir(self.get_temp_dir()) as tmp_dir:
         with self.assertRaisesWithPredicateMatch(
                 ValueError, "`all` is a reserved keyword"):
             # Raise error during .download_and_prepare()
             load.load(
                 name="invalid_split_dataset",
                 data_dir=tmp_dir,
             )
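For context, a minimal sketch of what the `invalid_split_dataset` builder might look like, assuming the TFDS 1.x builder API these tests exercise (the class and feature names here are hypothetical):

import tensorflow as tf
import tensorflow_datasets as tfds

class InvalidSplitDataset(tfds.core.GeneratorBasedBuilder):
  """Declares the reserved split name `all`."""

  VERSION = tfds.core.Version("0.0.1")

  def _info(self):
    return tfds.core.DatasetInfo(
        builder=self,
        features=tfds.features.FeaturesDict({"x": tf.int64}),
    )

  def _split_generators(self, dl_manager):
    # "all" is reserved for the merged split, so download_and_prepare()
    # rejects it with the ValueError asserted above.
    return [tfds.core.SplitGenerator(name="all", num_shards=1, gen_kwargs={})]

  def _generate_examples(self):
    yield {"x": 1}  # Never reached; the reserved split name fails first.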
Example #2
 def test_invalid_split_dataset(self):
     with testing.tmp_dir(self.get_temp_dir()) as tmp_dir:
         with self.assertRaisesWithPredicateMatch(ValueError,
                                                  "ALL is a special"):
             # Raise error during .download_and_prepare()
             registered.load(
                 name="invalid_split_dataset",
                 data_dir=tmp_dir,
             )
Example #3
 def test_file_backed(self):
     with testing.tmp_dir(self.get_temp_dir()) as tmp_dir:
         vocab_fname = os.path.join(tmp_dir, 'vocab')
         encoder = text_encoder.TokenTextEncoder(
             vocab_list=['hi', 'bye', ZH_HELLO])
         encoder.save_to_file(vocab_fname)
         file_backed_encoder = text_encoder.TokenTextEncoder.load_from_file(
             vocab_fname)
         self.assertEqual(encoder.tokens, file_backed_encoder.tokens)
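Example #3 round-trips a TokenTextEncoder through save_to_file()/load_from_file(). A quick sketch of the encoder API itself, assuming the pre-4.x tfds.features.text namespace (later moved to tfds.deprecated.text):

import tensorflow_datasets as tfds

encoder = tfds.features.text.TokenTextEncoder(vocab_list=['hi', 'bye'])
ids = encoder.encode('hi bye')  # Token ids start at 1; 0 is reserved for PAD.
assert encoder.decode(ids) == 'hi bye'
print(encoder.vocab_size)  # Tokens plus PAD and the default OOV bucket.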
Example #4
 def test_load(self):
     with testing.tmp_dir(self.get_temp_dir()) as tmp_dir:
         dataset = registered.load(name="dummy_dataset_shared_generator",
                                   data_dir=tmp_dir,
                                   download=True,
                                   split=splits_lib.Split.TRAIN)
         data = list(dataset_utils.as_numpy(dataset))
         self.assertEqual(20, len(data))
         self.assertLess(data[0]["x"], 30)
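The same call through the public API, sketched with a real dataset name in place of the test-only dummy builder:

import tensorflow_datasets as tfds

ds = tfds.load('mnist', split='train')
for example in tfds.as_numpy(ds.take(1)):
  print(example['image'].shape, example['label'])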
Example #5
  def test_metadata(self):
    feature = feature_lib.Sequence(feature_lib.ClassLabel(num_classes=2))
    feature.feature.names = ['left', 'right']
    with testing.tmp_dir() as tmp_dir:
      feature.save_metadata(data_dir=tmp_dir, feature_name='test')

      feature2 = feature_lib.Sequence(feature_lib.ClassLabel(num_classes=2))
      feature2.load_metadata(data_dir=tmp_dir, feature_name='test')
    self.assertEqual(feature2.feature.names, ['left', 'right'])
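The names set on the inner ClassLabel are what the metadata round-trip above preserves; they define the string-to-integer mapping:

import tensorflow_datasets as tfds

label = tfds.features.ClassLabel(names=['left', 'right'])
assert label.num_classes == 2
assert label.str2int('right') == 1
assert label.int2str(0) == 'left'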
Example #6
 def test_dataset_download(self):
   with testing.mock_kaggle_api():
     with testing.tmp_dir() as tmp_dir:
       out_path = kaggle.download_kaggle_data('user/dataset', tmp_dir)
       self.assertIsInstance(out_path, os.PathLike)
       self.assertEqual(
           os.fspath(out_path), os.path.join(tmp_dir, 'user_dataset'))
       with tf.io.gfile.GFile(os.path.join(out_path, 'output.txt')) as f:
         self.assertEqual('user/dataset', f.read())
Example #7
  def test_with_config(self):
    """Test that builder with configs are correctly generated."""
    with testing.tmp_dir() as tmp_dir:
      builder = DummyMnistConfigs(data_dir=tmp_dir)
      builder.download_and_prepare()
    doc_str = document_datasets.document_single_builder(builder)

    self.assertIn("Some manual instructions.", doc_str)
    self.assertIn("Mnist description.", doc_str)  # Shared description.
    self.assertIn("Config description.", doc_str)  # Config-specific description
Example #8
 def test_feature_save_load_metadata_slashes(self):
     with testing.tmp_dir() as data_dir:
         fd = features_lib.FeaturesDict({
             'image/frame':
             features_lib.Image(shape=(32, 32, 3)),
             'image/label':
             features_lib.ClassLabel(num_classes=2),
         })
         fd.save_metadata(data_dir)
         fd.load_metadata(data_dir)
Example #9
 def test_competition_download(self):
     with testing.mock_kaggle_api():
         with testing.tmp_dir() as tmp_dir:
             out_path = kaggle.download_kaggle_data('digit-recognizer',
                                                    tmp_dir)
             self.assertEqual(os.fspath(out_path),
                              os.path.join(tmp_dir, 'digit-recognizer'))
             with tf.io.gfile.GFile(os.path.join(out_path,
                                                 'output.txt')) as f:
                 self.assertEqual('digit-recognizer', f.read())
Example #10
  def test_get_data_dir_with_config(self):
    with testing.tmp_dir(self.get_temp_dir()) as tmp_dir:
      config_name = "plus1"
      builder = DummyDatasetWithConfigs(config=config_name, data_dir=tmp_dir)

      builder_data_dir = os.path.join(tmp_dir, builder.name, config_name)
      version_data_dir = os.path.join(builder_data_dir, "0.0.1")

      tf.io.gfile.makedirs(version_data_dir)
      self.assertEqual(builder._build_data_dir(), version_data_dir)
Example #11
    def test_metadata(self):
        with testing.tmp_dir(self.get_temp_dir()) as tmp_dir:
            builder = RandomShapedImageGenerator(data_dir=tmp_dir)
            builder.download_and_prepare()
            # Metadata should have been created
            self.assertEqual(builder.info.metadata, {"some_key": 123})

            # Metadata should have been restored
            builder2 = RandomShapedImageGenerator(data_dir=tmp_dir)
            self.assertEqual(builder2.info.metadata, {"some_key": 123})
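Examples #10 and #11 rely on test-only dummy builders. For #11, a sketch of how a builder like RandomShapedImageGenerator might attach the metadata being checked (hypothetical `_info`, assuming `tfds.core.MetadataDict`):

import tensorflow_datasets as tfds

class RandomShapedImageGenerator(tfds.core.GeneratorBasedBuilder):
  VERSION = tfds.core.Version("0.0.1")

  def _info(self):
    return tfds.core.DatasetInfo(
        builder=self,
        features=tfds.features.FeaturesDict({"im": tfds.features.Image()}),
        # Saved with the dataset and restored on the next instantiation.
        metadata=tfds.core.MetadataDict(some_key=123),
    )
  # _split_generators() and _generate_examples() omitted.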
Example #12
 def test_competition_download(self):
     filenames = ["a", "b"]
     with testing.mock_kaggle_api(filenames):
         downloader = kaggle.KaggleCompetitionDownloader("digit-recognizer")
         self.assertEqual(downloader.competition_files, ["a", "b"])
         with testing.tmp_dir() as tmp_dir:
             for fname in downloader.competition_files:
                 out_path = downloader.download_file(fname, tmp_dir)
                 self.assertEqual(out_path, os.path.join(tmp_dir, fname))
                 with tf.io.gfile.GFile(out_path) as f:
                     self.assertEqual(fname, f.read())
Example #13
  def test_file_backed(self, additional_tokens):
    encoder = text_encoder.ByteTextEncoder(additional_tokens=additional_tokens)
    with testing.tmp_dir(self.get_temp_dir()) as tmp_dir:
      vocab_fname = os.path.join(tmp_dir, 'vocab')
      encoder.save_to_file(vocab_fname)

      file_backed_encoder = text_encoder.ByteTextEncoder.load_from_file(
          vocab_fname)
      self.assertEqual(encoder.vocab_size, file_backed_encoder.vocab_size)
      self.assertEqual(encoder.additional_tokens,
                       file_backed_encoder.additional_tokens)
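ByteTextEncoder needs no vocabulary: ids are byte values offset past the reserved tokens (PAD plus any additional_tokens). A sketch, again assuming the pre-4.x text-encoder namespace:

import tensorflow_datasets as tfds

encoder = tfds.features.text.ByteTextEncoder(additional_tokens=['<EOS>'])
ids = encoder.encode('hi')
assert encoder.decode(ids) == 'hi'
print(encoder.vocab_size)  # 256 byte values + PAD + the additional tokens.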
Example #14
    def test_reuse_cache_if_exists(self):
        with testing.tmp_dir(self.get_temp_dir()) as tmp_dir:
            builder = testing.DummyMnist(data_dir=tmp_dir)
            dl_config = download.DownloadConfig(max_examples_per_split=3)
            builder.download_and_prepare(download_config=dl_config)

            dl_config = download.DownloadConfig(
                download_mode=download.GenerateMode.REUSE_CACHE_IF_EXISTS,
                max_examples_per_split=5)
            builder.download_and_prepare(download_config=dl_config)
            self.assertEqual(builder.info.splits["train"].num_examples, 5)
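REUSE_CACHE_IF_EXISTS is one of three generate modes; the test above uses it to regenerate the dataset (picking up the new max_examples_per_split) without re-downloading. A sketch with the public API:

import tensorflow_datasets as tfds

# REUSE_DATASET_IF_EXISTS (default): reuse downloads and prepared data.
# REUSE_CACHE_IF_EXISTS: reuse raw downloads, but regenerate the dataset.
# FORCE_REDOWNLOAD: re-download and regenerate everything.
dl_config = tfds.download.DownloadConfig(
    download_mode=tfds.download.GenerateMode.REUSE_CACHE_IF_EXISTS)
tfds.builder('mnist').download_and_prepare(download_config=dl_config)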
Example #15
    def test_stats_restored_from_gcs(self):
        with testing.tmp_dir(self.get_temp_dir()) as tmp_dir:
            builder = testing.DummyMnist(data_dir=tmp_dir)
            self.assertEqual(builder.info.splits.total_num_examples, 70000)
            self.assertFalse(self.compute_dynamic_property.called)

            builder.download_and_prepare()

            # Statistics shouldn't have been recomputed
            self.assertEqual(builder.info.splits.total_num_examples, 70000)
            self.assertFalse(self.compute_dynamic_property.called)
Example #16
 def test_config_construction(self):
   with testing.tmp_dir(self.get_temp_dir()) as tmp_dir:
     self.assertSetEqual(
         set(["plus1", "plus2"]),
         set(DummyDatasetWithConfigs.builder_configs.keys()))
     plus1_config = DummyDatasetWithConfigs.builder_configs["plus1"]
     builder = DummyDatasetWithConfigs(config="plus1", data_dir=tmp_dir)
     self.assertIs(plus1_config, builder.builder_config)
     builder = DummyDatasetWithConfigs(config=plus1_config, data_dir=tmp_dir)
     self.assertIs(plus1_config, builder.builder_config)
     self.assertIs(builder.builder_config,
                   DummyDatasetWithConfigs.BUILDER_CONFIGS[0])
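Example #16 checks that a config can be passed either by name or as a BuilderConfig object. A sketch of how such configs are typically declared (the class names mirror the test's dummies and are hypothetical):

import tensorflow_datasets as tfds

class DummyConfig(tfds.core.BuilderConfig):
  """Carries an `increment` applied to each record."""

  def __init__(self, increment=0, **kwargs):
    super(DummyConfig, self).__init__(**kwargs)
    self.increment = increment

class DummyDatasetWithConfigs(tfds.core.GeneratorBasedBuilder):
  BUILDER_CONFIGS = [
      DummyConfig(name="plus1", version="0.0.1", increment=1,
                  description="Add 1 to the records"),
      DummyConfig(name="plus2", version="0.0.2", increment=2,
                  description="Add 2 to the records"),
  ]

# builder_configs indexes BUILDER_CONFIGS by name:
assert set(DummyDatasetWithConfigs.builder_configs) == {"plus1", "plus2"}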
Example #17
 def test_s3_raise(self):
     dl_config = self._get_dl_config_if_need_to_run()
     if not dl_config:
         return
     dl_config.compute_stats = download.ComputeStatsMode.SKIP
     with testing.tmp_dir(self.get_temp_dir()) as tmp_dir:
         builder = FaultyS3DummyBeamDataset(data_dir=tmp_dir)
         builder.download_and_prepare(download_config=dl_config)
         with self.assertRaisesWithPredicateMatch(
                 AssertionError,
                 "`DatasetInfo.SplitInfo.num_shards` is empty"):
             builder.as_dataset()
Example #18
    def test_statistics_generation_variable_sizes(self):
        with testing.tmp_dir(self.get_temp_dir()) as tmp_dir:
            builder = RandomShapedImageGenerator(data_dir=tmp_dir)
            builder.download_and_prepare()

            # Inspect the generated schema: name and (variable) shape of "im".
            schema_feature = builder.info.as_proto.schema.feature[0]
            self.assertEqual("im", schema_feature.name)

            self.assertEqual(-1, schema_feature.shape.dim[0].size)
            self.assertEqual(-1, schema_feature.shape.dim[1].size)
            self.assertEqual(3, schema_feature.shape.dim[2].size)
Example #19
    def test_stats_not_restored_gcs_overwritten(self):
        with testing.tmp_dir(self.get_temp_dir()) as tmp_dir:
            # If the splits differ from the restored ones, stats should be recomputed.
            builder = testing.DummyMnist(data_dir=tmp_dir, num_shards=5)
            self.assertEqual(builder.info.splits.total_num_examples, 70000)
            self.assertFalse(self.compute_dynamic_property.called)

            builder.download_and_prepare()

            # Statistics should have been recomputed (splits differ from the
            # restored ones).
            self.assertTrue(self.compute_dynamic_property.called)
Example #20
    def test_disable_tqdm(self):
        tqdm_utils.disable_progress_bar()

        with testing.tmp_dir(self.get_temp_dir()) as tmp_dir:
            builder = testing.DummyMnist(data_dir=tmp_dir)
            builder.download_and_prepare()

            # Check the data has been generated
            train_ds, test_ds = builder.as_dataset(split=['train', 'test'])
            train_ds, test_ds = dataset_utils.as_numpy((train_ds, test_ds))
            self.assertEqual(20, len(list(train_ds)))
            self.assertEqual(20, len(list(test_ds)))
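The public alias for disabling the bar is tfds.disable_progress_bar():

import tensorflow_datasets as tfds

tfds.disable_progress_bar()
ds = tfds.load('mnist', split='train')  # Prepares without tqdm output.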
Example #21
  def test_stats_not_restored_gcs_overwritten(self):
    with testing.tmp_dir(self.get_temp_dir()) as tmp_dir:
      # If the splits differ from the restored ones, stats should be recomputed.
      builder = testing.DummyMnist(data_dir=tmp_dir)
      self.assertEqual(builder.info.splits["train"].statistics.num_examples, 20)
      self.assertFalse(self.compute_dynamic_property.called)

      dl_config = download.DownloadConfig(max_examples_per_split=5)
      builder.download_and_prepare(download_config=dl_config)

      # Statistics should have been recomputed (splits differ from the
      # restored ones).
      self.assertTrue(self.compute_dynamic_property.called)
Example #22
    def test_statistics_generation(self):
        with testing.tmp_dir(self.get_temp_dir()) as tmp_dir:
            builder = DummyDatasetSharedGenerator(data_dir=tmp_dir)
            builder.download_and_prepare()

            # Overall
            self.assertEqual(30, builder.info.splits.total_num_examples)

            # Per split.
            test_split = builder.info.splits["test"].get_proto()
            train_split = builder.info.splits["train"].get_proto()
            self.assertEqual(10, test_split.statistics.num_examples)
            self.assertEqual(20, train_split.statistics.num_examples)
Example #23
  def test_multi_split(self):
    with testing.tmp_dir(self.get_temp_dir()) as tmp_dir:
      ds_train, ds_test = registered.load(
          name="dummy_dataset_shared_generator",
          data_dir=tmp_dir,
          split=["train", "test"],
          shuffle_files=False)

      data = list(dataset_utils.as_numpy(ds_train))
      self.assertEqual(20, len(data))

      data = list(dataset_utils.as_numpy(ds_test))
      self.assertEqual(10, len(data))
Example #24
    def test_multi_split(self):
        with testing.tmp_dir(self.get_temp_dir()) as tmp_dir:
            ds_train, ds_test = registered.load(
                name="dummy_dataset_shared_generator",
                data_dir=tmp_dir,
                split=[splits_lib.Split.TRAIN, splits_lib.Split.TEST],
                as_dataset_kwargs=dict(shuffle_files=False))

            data = list(dataset_utils.as_numpy(ds_train))
            self.assertEqual(20, len(data))

            data = list(dataset_utils.as_numpy(ds_test))
            self.assertEqual(10, len(data))
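Examples #23 and #24 are the same test across API revisions: the older signature forwards as_dataset() options through as_dataset_kwargs, while the newer one accepts shuffle_files directly. With the public API the newer form reads:

import tensorflow_datasets as tfds

ds_train, ds_test = tfds.load('mnist', split=['train', 'test'],
                              shuffle_files=False)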
Example #25
    def test_save_load_metadata(self):
        text_f = features.Text(encoder=text_encoder.ByteTextEncoder(
            additional_tokens=['HI']))
        text = u'HI 你好'
        ids = text_f.str2ints(text)
        self.assertEqual(1, ids[0])

        with testing.tmp_dir(self.get_temp_dir()) as data_dir:
            feature_name = 'dummy'
            text_f.save_metadata(data_dir, feature_name)

            new_f = features.Text()
            new_f.load_metadata(data_dir, feature_name)
            # The restored feature must produce the same ids as the original.
            self.assertEqual(ids, new_f.str2ints(text))
Example #26
  def test_statistics_generation(self):
    with testing.tmp_dir(self.get_temp_dir()) as tmp_dir:
      builder = DummyDatasetSharedGenerator(data_dir=tmp_dir)
      builder.download_and_prepare(
          download_config=download.DownloadConfig(
              compute_stats=download.ComputeStatsMode.AUTO,
          ),
      )

      # Overall
      self.assertEqual(30, builder.info.splits.total_num_examples)

      # Per split.
      test_split = builder.info.splits["test"].get_proto()
      train_split = builder.info.splits["train"].get_proto()
      expected_schema = text_format.Parse("""
      feature {
        name: "x"
        type: INT
        presence {
          min_fraction: 1.0
          min_count: 1
        }
        shape {
          dim {
            size: 1
          }
        }
      }""", schema_pb2.Schema())
      self.assertEqual(train_split.statistics.num_examples, 20)
      self.assertLen(train_split.statistics.features, 1)
      self.assertEqual(
          train_split.statistics.features[0].path.step[0], "x")
      self.assertLen(
          train_split.statistics.features[0].num_stats.common_stats.
          num_values_histogram.buckets, 10)
      self.assertLen(
          train_split.statistics.features[0].num_stats.histograms, 2)

      self.assertEqual(test_split.statistics.num_examples, 10)
      self.assertLen(test_split.statistics.features, 1)
      self.assertEqual(
          test_split.statistics.features[0].path.step[0], "x")
      self.assertLen(
          test_split.statistics.features[0].num_stats.common_stats.
          num_values_histogram.buckets, 10)
      self.assertLen(
          test_split.statistics.features[0].num_stats.histograms, 2)
      self.assertEqual(builder.info.as_proto.schema, expected_schema)
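AUTO is one of three statistics modes in the TFDS versions these tests target (statistics generation was removed in later releases). A sketch, assuming the modes are exposed under tfds.download as in the tests' download module:

import tensorflow_datasets as tfds

# AUTO: compute statistics only if they are not already present.
# FORCE: always recompute (see Example #27).
# SKIP: never compute (see Example #17).
dl_config = tfds.download.DownloadConfig(
    compute_stats=tfds.download.ComputeStatsMode.SKIP)
tfds.builder('mnist').download_and_prepare(download_config=dl_config)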
Example #27
    def test_force_stats(self):
        # Test when stats already exist but compute_stats is FORCE.

        with testing.tmp_dir(self.get_temp_dir()) as tmp_dir:
            # No dataset_info restored, so stats are empty
            builder = testing.DummyMnist(data_dir=tmp_dir)
            self.assertEqual(builder.info.splits.total_num_examples, 40)
            self.assertFalse(self.compute_dynamic_property.called)

            download_config = download.DownloadConfig(
                compute_stats=download.ComputeStatsMode.FORCE)
            builder.download_and_prepare(download_config=download_config)

            # Statistics computation should have been recomputed
            self.assertTrue(self.compute_dynamic_property.called)
Example #28
    def test_gcs_not_exists(self):
        # By disabling the patch, and because DummyMnist is not on GCS, we can
        # simulate a new dataset starting from scratch
        self.patch_gcs.stop()
        with testing.tmp_dir(self.get_temp_dir()) as tmp_dir:
            builder = testing.DummyMnist(data_dir=tmp_dir)
            # No dataset_info restored, so stats are empty
            self.assertEqual(builder.info.splits.total_num_examples, 0)
            self.assertFalse(self.compute_dynamic_property.called)

            builder.download_and_prepare()

            # Statistics should have been recomputed
            self.assertTrue(self.compute_dynamic_property.called)
        self.patch_gcs.start()
Example #29
    def test_gcs_not_exists(self):
        # By disabling the patch, and because DummyMnist is not on GCS, we can
        # simulate a new dataset starting from scratch
        self.patch_gcs.stop()
        with testing.tmp_dir(self.get_temp_dir()) as tmp_dir:
            builder = testing.DummyMnist(data_dir=tmp_dir)
            # No dataset_info restored, so stats are empty
            self.assertEqual(builder.info.splits.total_num_examples, 0)

            dl_config = download.DownloadConfig()
            builder.download_and_prepare(download_config=dl_config)

            # Statistics should have been recomputed
            self.assertEqual(builder.info.splits["train"].num_examples, 20)
        self.patch_gcs.start()
Example #30
  def test_build_data_dir(self):
    with testing.tmp_dir(self.get_temp_dir()) as tmp_dir:
      builder = DummyDatasetSharedGenerator(data_dir=tmp_dir)
      self.assertEqual(str(builder.info.version), "1.0.0")
      builder_data_dir = os.path.join(tmp_dir, builder.name)
      version_dir = os.path.join(builder_data_dir, "1.0.0")

      # The dataset folder contains multiple other versions
      tf.io.gfile.makedirs(os.path.join(builder_data_dir, "14.0.0.invalid"))
      tf.io.gfile.makedirs(os.path.join(builder_data_dir, "10.0.0"))
      tf.io.gfile.makedirs(os.path.join(builder_data_dir, "9.0.0"))
      tf.io.gfile.makedirs(os.path.join(builder_data_dir, "0.1.0"))

      # The builder's version dir is chosen
      self.assertEqual(builder._build_data_dir(tmp_dir)[1], version_dir)
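The takeaway: the build dir is pinned to the builder's own VERSION, laid out as <data_dir>/<name>[/<config>]/<version>, regardless of other version directories on disk. With the public API:

import tensorflow_datasets as tfds

builder = tfds.builder('mnist')
print(builder.data_dir)  # e.g. <data_dir>/mnist/<version>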