def test_invalid_split_dataset(self):
  with testing.tmp_dir(self.get_temp_dir()) as tmp_dir:
    with self.assertRaisesWithPredicateMatch(
        ValueError, "`all` is a reserved keyword"):
      # Raise error during .download_and_prepare()
      load.load(
          name="invalid_split_dataset",
          data_dir=tmp_dir,
      )

def test_invalid_split_dataset(self):
  with testing.tmp_dir(self.get_temp_dir()) as tmp_dir:
    with self.assertRaisesWithPredicateMatch(ValueError, "ALL is a special"):
      # Raise error during .download_and_prepare()
      registered.load(
          name="invalid_split_dataset",
          data_dir=tmp_dir,
      )

def test_file_backed(self):
  with testing.tmp_dir(self.get_temp_dir()) as tmp_dir:
    vocab_fname = os.path.join(tmp_dir, 'vocab')
    encoder = text_encoder.TokenTextEncoder(
        vocab_list=['hi', 'bye', ZH_HELLO])
    encoder.save_to_file(vocab_fname)

    file_backed_encoder = text_encoder.TokenTextEncoder.load_from_file(
        vocab_fname)
    self.assertEqual(encoder.tokens, file_backed_encoder.tokens)

def test_load(self):
  with testing.tmp_dir(self.get_temp_dir()) as tmp_dir:
    dataset = registered.load(
        name="dummy_dataset_shared_generator",
        data_dir=tmp_dir,
        download=True,
        split=splits_lib.Split.TRAIN)
    data = list(dataset_utils.as_numpy(dataset))
    self.assertEqual(20, len(data))
    self.assertLess(data[0]["x"], 30)

def test_metadata(self):
  feature = feature_lib.Sequence(feature_lib.ClassLabel(num_classes=2))
  feature.feature.names = ['left', 'right']
  with testing.tmp_dir() as tmp_dir:
    feature.save_metadata(data_dir=tmp_dir, feature_name='test')

    feature2 = feature_lib.Sequence(feature_lib.ClassLabel(num_classes=2))
    feature2.load_metadata(data_dir=tmp_dir, feature_name='test')
    self.assertEqual(feature2.feature.names, ['left', 'right'])

def test_dataset_download(self):
  with testing.mock_kaggle_api():
    with testing.tmp_dir() as tmp_dir:
      out_path = kaggle.download_kaggle_data('user/dataset', tmp_dir)
      self.assertIsInstance(out_path, os.PathLike)
      self.assertEqual(
          os.fspath(out_path), os.path.join(tmp_dir, 'user_dataset'))
      with tf.io.gfile.GFile(os.path.join(out_path, 'output.txt')) as f:
        self.assertEqual('user/dataset', f.read())

def test_with_config(self):
  """Test that builders with configs are correctly documented."""
  with testing.tmp_dir() as tmp_dir:
    builder = DummyMnistConfigs(data_dir=tmp_dir)
    builder.download_and_prepare()
    doc_str = document_datasets.document_single_builder(builder)

    self.assertIn("Some manual instructions.", doc_str)
    self.assertIn("Mnist description.", doc_str)  # Shared description.
    self.assertIn("Config description.", doc_str)  # Config-specific description.

def test_feature_save_load_metadata_slashes(self):
  with testing.tmp_dir() as data_dir:
    fd = features_lib.FeaturesDict({
        'image/frame': features_lib.Image(shape=(32, 32, 3)),
        'image/label': features_lib.ClassLabel(num_classes=2),
    })
    fd.save_metadata(data_dir)
    fd.load_metadata(data_dir)

def test_competition_download(self):
  with testing.mock_kaggle_api():
    with testing.tmp_dir() as tmp_dir:
      out_path = kaggle.download_kaggle_data('digit-recognizer', tmp_dir)
      self.assertEqual(
          os.fspath(out_path), os.path.join(tmp_dir, 'digit-recognizer'))
      with tf.io.gfile.GFile(os.path.join(out_path, 'output.txt')) as f:
        self.assertEqual('digit-recognizer', f.read())

def test_get_data_dir_with_config(self):
  with testing.tmp_dir(self.get_temp_dir()) as tmp_dir:
    config_name = "plus1"
    builder = DummyDatasetWithConfigs(config=config_name, data_dir=tmp_dir)

    builder_data_dir = os.path.join(tmp_dir, builder.name, config_name)
    version_data_dir = os.path.join(builder_data_dir, "0.0.1")

    tf.io.gfile.makedirs(version_data_dir)
    self.assertEqual(builder._build_data_dir(), version_data_dir)

def test_metadata(self):
  with testing.tmp_dir(self.get_temp_dir()) as tmp_dir:
    builder = RandomShapedImageGenerator(data_dir=tmp_dir)
    builder.download_and_prepare()

    # Metadata should have been created.
    self.assertEqual(builder.info.metadata, {"some_key": 123})

    # Metadata should have been restored.
    builder2 = RandomShapedImageGenerator(data_dir=tmp_dir)
    self.assertEqual(builder2.info.metadata, {"some_key": 123})

def test_competition_download(self):
  filenames = ["a", "b"]
  with testing.mock_kaggle_api(filenames):
    downloader = kaggle.KaggleCompetitionDownloader("digit-recognizer")
    self.assertEqual(downloader.competition_files, ["a", "b"])
    with testing.tmp_dir() as tmp_dir:
      for fname in downloader.competition_files:
        out_path = downloader.download_file(fname, tmp_dir)
        self.assertEqual(out_path, os.path.join(tmp_dir, fname))
        with tf.io.gfile.GFile(out_path) as f:
          self.assertEqual(fname, f.read())

def test_file_backed(self, additional_tokens):
  encoder = text_encoder.ByteTextEncoder(additional_tokens=additional_tokens)
  with testing.tmp_dir(self.get_temp_dir()) as tmp_dir:
    vocab_fname = os.path.join(tmp_dir, 'vocab')
    encoder.save_to_file(vocab_fname)

    file_backed_encoder = text_encoder.ByteTextEncoder.load_from_file(
        vocab_fname)
    self.assertEqual(encoder.vocab_size, file_backed_encoder.vocab_size)
    self.assertEqual(encoder.additional_tokens,
                     file_backed_encoder.additional_tokens)

def test_reuse_cache_if_exists(self):
  with testing.tmp_dir(self.get_temp_dir()) as tmp_dir:
    builder = testing.DummyMnist(data_dir=tmp_dir)
    dl_config = download.DownloadConfig(max_examples_per_split=3)
    builder.download_and_prepare(download_config=dl_config)

    dl_config = download.DownloadConfig(
        download_mode=download.GenerateMode.REUSE_CACHE_IF_EXISTS,
        max_examples_per_split=5)
    builder.download_and_prepare(download_config=dl_config)
    self.assertEqual(builder.info.splits["train"].num_examples, 5)

def test_stats_restored_from_gcs(self):
  with testing.tmp_dir(self.get_temp_dir()) as tmp_dir:
    builder = testing.DummyMnist(data_dir=tmp_dir)
    self.assertEqual(builder.info.splits.total_num_examples, 70000)
    self.assertFalse(self.compute_dynamic_property.called)

    builder.download_and_prepare()

    # Statistics shouldn't have been recomputed.
    self.assertEqual(builder.info.splits.total_num_examples, 70000)
    self.assertFalse(self.compute_dynamic_property.called)

def test_config_construction(self):
  with testing.tmp_dir(self.get_temp_dir()) as tmp_dir:
    self.assertSetEqual(
        set(["plus1", "plus2"]),
        set(DummyDatasetWithConfigs.builder_configs.keys()))
    plus1_config = DummyDatasetWithConfigs.builder_configs["plus1"]
    builder = DummyDatasetWithConfigs(config="plus1", data_dir=tmp_dir)
    self.assertIs(plus1_config, builder.builder_config)
    builder = DummyDatasetWithConfigs(config=plus1_config, data_dir=tmp_dir)
    self.assertIs(plus1_config, builder.builder_config)
    self.assertIs(builder.builder_config,
                  DummyDatasetWithConfigs.BUILDER_CONFIGS[0])

def test_s3_raise(self):
  dl_config = self._get_dl_config_if_need_to_run()
  if not dl_config:
    return
  dl_config.compute_stats = download.ComputeStatsMode.SKIP
  with testing.tmp_dir(self.get_temp_dir()) as tmp_dir:
    builder = FaultyS3DummyBeamDataset(data_dir=tmp_dir)
    builder.download_and_prepare(download_config=dl_config)
    with self.assertRaisesWithPredicateMatch(
        AssertionError, "`DatasetInfo.SplitInfo.num_shards` is empty"):
      builder.as_dataset()

def test_statistics_generation_variable_sizes(self):
  with testing.tmp_dir(self.get_temp_dir()) as tmp_dir:
    builder = RandomShapedImageGenerator(data_dir=tmp_dir)
    builder.download_and_prepare()

    # Get the expected type of the feature.
    schema_feature = builder.info.as_proto.schema.feature[0]
    self.assertEqual("im", schema_feature.name)
    self.assertEqual(-1, schema_feature.shape.dim[0].size)
    self.assertEqual(-1, schema_feature.shape.dim[1].size)
    self.assertEqual(3, schema_feature.shape.dim[2].size)

def test_stats_not_restored_gcs_overwritten(self):
  with testing.tmp_dir(self.get_temp_dir()) as tmp_dir:
    # If the splits differ from the restored ones, stats should be recomputed.
    builder = testing.DummyMnist(data_dir=tmp_dir, num_shards=5)
    self.assertEqual(builder.info.splits.total_num_examples, 70000)
    self.assertFalse(self.compute_dynamic_property.called)

    builder.download_and_prepare()

    # Statistics should have been recomputed (splits differ from the
    # restored ones).
    self.assertTrue(self.compute_dynamic_property.called)

def test_disable_tqdm(self):
  tqdm_utils.disable_progress_bar()

  with testing.tmp_dir(self.get_temp_dir()) as tmp_dir:
    builder = testing.DummyMnist(data_dir=tmp_dir)
    builder.download_and_prepare()

    # Check that the data has been generated.
    train_ds, test_ds = builder.as_dataset(split=['train', 'test'])
    train_ds, test_ds = dataset_utils.as_numpy((train_ds, test_ds))
    self.assertEqual(20, len(list(train_ds)))
    self.assertEqual(20, len(list(test_ds)))

def test_stats_not_restored_gcs_overwritten(self):
  with testing.tmp_dir(self.get_temp_dir()) as tmp_dir:
    # If the splits differ from the restored ones, stats should be recomputed.
    builder = testing.DummyMnist(data_dir=tmp_dir)
    self.assertEqual(builder.info.splits["train"].statistics.num_examples, 20)
    self.assertFalse(self.compute_dynamic_property.called)

    dl_config = download.DownloadConfig(max_examples_per_split=5)
    builder.download_and_prepare(download_config=dl_config)

    # Statistics should have been recomputed (splits differ from the
    # restored ones).
    self.assertTrue(self.compute_dynamic_property.called)

def test_statistics_generation(self):
  with testing.tmp_dir(self.get_temp_dir()) as tmp_dir:
    builder = DummyDatasetSharedGenerator(data_dir=tmp_dir)
    builder.download_and_prepare()

    # Overall.
    self.assertEqual(30, builder.info.splits.total_num_examples)

    # Per split.
    test_split = builder.info.splits["test"].get_proto()
    train_split = builder.info.splits["train"].get_proto()
    self.assertEqual(10, test_split.statistics.num_examples)
    self.assertEqual(20, train_split.statistics.num_examples)

def test_multi_split(self): with testing.tmp_dir(self.get_temp_dir()) as tmp_dir: ds_train, ds_test = registered.load( name="dummy_dataset_shared_generator", data_dir=tmp_dir, split=["train", "test"], shuffle_files=False) data = list(dataset_utils.as_numpy(ds_train)) self.assertEqual(20, len(data)) data = list(dataset_utils.as_numpy(ds_test)) self.assertEqual(10, len(data))
def test_multi_split(self): with testing.tmp_dir(self.get_temp_dir()) as tmp_dir: ds_train, ds_test = registered.load( name="dummy_dataset_shared_generator", data_dir=tmp_dir, split=[splits_lib.Split.TRAIN, splits_lib.Split.TEST], as_dataset_kwargs=dict(shuffle_files=False)) data = list(dataset_utils.as_numpy(ds_train)) self.assertEqual(20, len(data)) data = list(dataset_utils.as_numpy(ds_test)) self.assertEqual(10, len(data))
def test_save_load_metadata(self):
  text_f = features.Text(
      encoder=text_encoder.ByteTextEncoder(additional_tokens=['HI']))
  text = u'HI 你好'
  ids = text_f.str2ints(text)
  self.assertEqual(1, ids[0])

  with testing.tmp_dir(self.get_temp_dir()) as data_dir:
    feature_name = 'dummy'
    text_f.save_metadata(data_dir, feature_name)

    new_f = features.Text()
    new_f.load_metadata(data_dir, feature_name)
    # The restored feature should reproduce the same ids.
    self.assertEqual(ids, new_f.str2ints(text))

def test_statistics_generation(self):
  with testing.tmp_dir(self.get_temp_dir()) as tmp_dir:
    builder = DummyDatasetSharedGenerator(data_dir=tmp_dir)
    builder.download_and_prepare(
        download_config=download.DownloadConfig(
            compute_stats=download.ComputeStatsMode.AUTO,
        ),
    )

    # Overall.
    self.assertEqual(30, builder.info.splits.total_num_examples)

    # Per split.
    test_split = builder.info.splits["test"].get_proto()
    train_split = builder.info.splits["train"].get_proto()

    expected_schema = text_format.Parse("""
        feature {
          name: "x"
          type: INT
          presence {
            min_fraction: 1.0
            min_count: 1
          }
          shape {
            dim {
              size: 1
            }
          }
        }""", schema_pb2.Schema())

    self.assertEqual(train_split.statistics.num_examples, 20)
    self.assertLen(train_split.statistics.features, 1)
    self.assertEqual(train_split.statistics.features[0].path.step[0], "x")
    self.assertLen(
        train_split.statistics.features[0].num_stats.common_stats
        .num_values_histogram.buckets, 10)
    self.assertLen(train_split.statistics.features[0].num_stats.histograms, 2)

    self.assertEqual(test_split.statistics.num_examples, 10)
    self.assertLen(test_split.statistics.features, 1)
    self.assertEqual(test_split.statistics.features[0].path.step[0], "x")
    self.assertLen(
        test_split.statistics.features[0].num_stats.common_stats
        .num_values_histogram.buckets, 10)
    self.assertLen(test_split.statistics.features[0].num_stats.histograms, 2)

    self.assertEqual(builder.info.as_proto.schema, expected_schema)

def test_force_stats(self):
  # Test when stats already exist but compute_stats='force'.
  with testing.tmp_dir(self.get_temp_dir()) as tmp_dir:
    # Stats are already populated from the restored dataset_info, so nothing
    # has been computed yet.
    builder = testing.DummyMnist(data_dir=tmp_dir)
    self.assertEqual(builder.info.splits.total_num_examples, 40)
    self.assertFalse(self.compute_dynamic_property.called)

    download_config = download.DownloadConfig(
        compute_stats=download.ComputeStatsMode.FORCE,
    )
    builder.download_and_prepare(download_config=download_config)

    # Statistics should have been recomputed.
    self.assertTrue(self.compute_dynamic_property.called)

def test_gcs_not_exists(self):
  # By disabling the patch, and because DummyMnist is not on GCS, we can
  # simulate a new dataset starting from scratch.
  self.patch_gcs.stop()
  with testing.tmp_dir(self.get_temp_dir()) as tmp_dir:
    builder = testing.DummyMnist(data_dir=tmp_dir)
    # No dataset_info restored, so stats are empty.
    self.assertEqual(builder.info.splits.total_num_examples, 0)
    self.assertFalse(self.compute_dynamic_property.called)

    builder.download_and_prepare()

    # Statistics should have been recomputed.
    self.assertTrue(self.compute_dynamic_property.called)
  self.patch_gcs.start()

def test_gcs_not_exists(self):
  # By disabling the patch, and because DummyMnist is not on GCS, we can
  # simulate a new dataset starting from scratch.
  self.patch_gcs.stop()
  with testing.tmp_dir(self.get_temp_dir()) as tmp_dir:
    builder = testing.DummyMnist(data_dir=tmp_dir)
    # No dataset_info restored, so stats are empty.
    self.assertEqual(builder.info.splits.total_num_examples, 0)

    dl_config = download.DownloadConfig()
    builder.download_and_prepare(download_config=dl_config)

    # Statistics should have been recomputed.
    self.assertEqual(builder.info.splits["train"].num_examples, 20)
  self.patch_gcs.start()

def test_build_data_dir(self):
  with testing.tmp_dir(self.get_temp_dir()) as tmp_dir:
    builder = DummyDatasetSharedGenerator(data_dir=tmp_dir)
    self.assertEqual(str(builder.info.version), "1.0.0")
    builder_data_dir = os.path.join(tmp_dir, builder.name)
    version_dir = os.path.join(builder_data_dir, "1.0.0")

    # The dataset folder contains multiple other versions.
    tf.io.gfile.makedirs(os.path.join(builder_data_dir, "14.0.0.invalid"))
    tf.io.gfile.makedirs(os.path.join(builder_data_dir, "10.0.0"))
    tf.io.gfile.makedirs(os.path.join(builder_data_dir, "9.0.0"))
    tf.io.gfile.makedirs(os.path.join(builder_data_dir, "0.1.0"))

    # The builder's version dir is chosen.
    self.assertEqual(builder._build_data_dir(tmp_dir)[1], version_dir)