def test_builder_memoization(self, mock_tfds_builder):
  mock_tfds_builder.side_effect = (
      lambda name, data_dir: ",".join([name, data_dir or ""]))

  ds1 = utils.LazyTfdsLoader("ds1")
  self.assertEqual("ds1,", ds1.builder)
  self.assertEqual(1, tfds.builder.call_count)

  # Builder should be cached with same name.
  self.assertEqual("ds1,", ds1.builder)
  self.assertEqual(1, tfds.builder.call_count)

  # Same name but different data dir is a cache miss.
  ds1_dir1 = utils.LazyTfdsLoader("ds1", "dir1")
  self.assertEqual("ds1,dir1", ds1_dir1.builder)
  self.assertEqual(2, tfds.builder.call_count)
  # Same name and data dir is a cache hit.
  self.assertEqual("ds1,dir1", ds1_dir1.builder)
  self.assertEqual(2, tfds.builder.call_count)

  # Different name is a cache miss.
  ds2 = utils.LazyTfdsLoader("ds2")
  self.assertEqual("ds2,", ds2.builder)
  self.assertEqual(3, tfds.builder.call_count)

  # Same name with a different split map is still a cache hit
  # (the split map does not affect the builder cache key).
  ds2 = utils.LazyTfdsLoader("ds2", split_map={"train": "validation"})
  self.assertEqual("ds2,", ds2.builder)
  self.assertEqual(3, tfds.builder.call_count)

  # Try calling everything again; order shouldn't matter.
  self.assertEqual("ds1,", ds1.builder)
  self.assertEqual("ds1,dir1", ds1_dir1.builder)
  self.assertEqual("ds2,", ds2.builder)
  self.assertEqual(3, tfds.builder.call_count)
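# The memoization exercised above can be pictured with a minimal sketch.
# This is an illustration, not the real implementation: it assumes only
# that builders are cached in a class-level `_MEMOIZED_BUILDERS` dict keyed
# by (name, data_dir), which is the key structure test_split_map pokes into
# below, and that `split_map` deliberately does not enter the cache key.
import tensorflow_datasets as tfds


class _SketchLazyTfdsLoader(object):
  """Minimal sketch of builder memoization (hypothetical, for illustration)."""

  _MEMOIZED_BUILDERS = {}

  def __init__(self, name, data_dir=None, split_map=None):
    self._name = name
    self._data_dir = data_dir
    self._split_map = split_map

  @property
  def builder(self):
    key = (self._name, self._data_dir)
    if key not in _SketchLazyTfdsLoader._MEMOIZED_BUILDERS:
      # Only constructed on a cache miss; repeated lookups reuse the entry.
      _SketchLazyTfdsLoader._MEMOIZED_BUILDERS[key] = tfds.builder(
          self._name, data_dir=self._data_dir)
    return _SketchLazyTfdsLoader._MEMOIZED_BUILDERS[key]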
def test_split_map(self, mock_tfds_load):
  utils.LazyTfdsLoader._MEMOIZED_BUILDERS[("ds/c1", None)] = mock.Mock(
      info=mock.Mock(splits={
          "validation": mock.Mock(
              num_examples=420, file_instructions=["f1", "f2"]),
          "test": mock.Mock(
              num_examples=42, file_instructions=["f3"]),
      }))

  ds = utils.LazyTfdsLoader(
      "ds/c1", split_map={"train": "validation", "validation": "test"})

  # Test .load().
  ds.load("train", shuffle_files=False)
  mock_tfds_load.assert_called_once_with(
      "ds/c1",
      split="validation",
      data_dir=None,
      shuffle_files=False,
      download=True,
      try_gcs=True)

  # Test .size().
  self.assertEqual(420, ds.size(split="train"))
  self.assertEqual(42, ds.size(split="validation"))
  with self.assertRaises(KeyError):
    ds.size(split="test")

  # Test .files().
  self.assertListEqual(["f1", "f2"], ds.files(split="train"))
  self.assertListEqual(["f3"], ds.files(split="validation"))
  with self.assertRaises(KeyError):
    ds.files(split="test")
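# A minimal sketch of the split_map translation verified above, under the
# assumption that the loader resolves canonical split names through the map
# before calling into TFDS. `_sketch_map_split` is a hypothetical helper,
# not the library's API. With a split_map set, only the map's keys are valid
# split names, which is why ds.size(split="test") raises KeyError above.
def _sketch_map_split(split, split_map=None):
  """Returns the TFDS split name that `split` resolves to."""
  if split_map:
    return split_map[split]  # KeyError for splits outside the map.
  return split


# e.g., _sketch_map_split("train", {"train": "validation"}) == "validation"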
def __init__(
    self,
    tfds_name: str,
    tfds_data_dir: Optional[str] = None,
    splits: Optional[Union[Iterable[str], Mapping[str, str]]] = None):
  """TfdsTask constructor.

  Args:
    tfds_name: string, the name and version number of a TFDS dataset,
      optionally with a config.
    tfds_data_dir: string, an optional path to a specific TFDS data
      directory to use.
    splits: an iterable of allowable string split names, a dict mapping
      allowable canonical splits (e.g., 'validation') to TFDS splits or
      slices (e.g., 'train[:1%]'), or None. The default, None, uses all
      available splits from the TFDS dataset info.
  """
  if ":" not in tfds_name:
    raise ValueError(
        "TFDS name must contain a version number, got: %s" % tfds_name)

  self._tfds_dataset = utils.LazyTfdsLoader(
      tfds_name,
      data_dir=tfds_data_dir,
      split_map=splits if isinstance(splits, dict) else None)

  # If splits are not provided, we pass an empty tuple and use the lazy
  # lookup in the `splits` property.
  super().__init__(splits=splits or ())
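# Example usage, with a hypothetical dataset name. When `splits` is a dict
# it doubles as the loader's split_map, remapping canonical split names to
# TFDS splits or slices; when it is an iterable it just restricts the
# allowable splits. A sketch only; the surrounding class is assumed.
task = TfdsTask(
    tfds_name="my_dataset:1.0.0",
    splits={"train": "train[:90%]", "validation": "train[90%:]"})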
def __init__(
    self,
    name,
    tfds_name,
    text_preprocessor,
    metric_fns,
    tfds_data_dir=None,
    splits=None,
    **task_kwargs):
  """TfdsTask constructor.

  Args:
    name: string, a unique name for the Task. A ValueError will be raised
      if another task with this name is already registered.
    tfds_name: string, the name and version number of a TFDS dataset,
      optionally with a config.
    text_preprocessor: a function (or list of functions) that (each) takes
      in a tf.data.Dataset of string features and returns a tf.data.Dataset
      of string features. Can be set to None as a no-op. If a list is
      given, they will be executed sequentially.
    metric_fns: list(callable), list of metric functions with the signature
      metric_fn(targets, predictions) to use during evaluation.
    tfds_data_dir: string, an optional path to a specific TFDS data
      directory to use.
    splits: a list(string) of allowable splits to load, a dict mapping
      allowable canonical splits (e.g., 'validation') to TFDS splits or
      slices (e.g., 'train[:1%]'), or None. The default, None, uses all
      available splits from the TFDS dataset info.
    **task_kwargs: dict, additional keyword arguments for the parent `Task`
      class.
  """
  if ":" not in tfds_name:
    raise ValueError(
        "TFDS name must contain a version number, got: %s" % tfds_name)

  self._tfds_dataset = utils.LazyTfdsLoader(
      tfds_name,
      data_dir=tfds_data_dir,
      split_map=splits if isinstance(splits, dict) else None)

  def dataset_fn(split, shuffle_files, seed=None):
    return self._tfds_dataset.load(split, shuffle_files, seed=seed)

  super().__init__(
      name,
      dataset_fn=dataset_fn,
      splits=list(splits) if splits else None,
      text_preprocessor=text_preprocessor,
      metric_fns=metric_fns,
      **task_kwargs)
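# Example registration in the t5 style. The task name, preprocessor, and
# metric function are hypothetical placeholders, and this assumes a
# `TaskRegistry.add`-style helper is in scope as used elsewhere in this
# codebase; shown only to illustrate how the constructor arguments flow.
TaskRegistry.add(
    "my_task_v001",
    TfdsTask,
    tfds_name="my_dataset:1.0.0",
    text_preprocessor=my_text_preprocessor,
    metric_fns=[my_metric_fn],
    splits={"train": "train", "validation": "test"})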
def test_get_dataset_onthefly(self):
  test_utils.verify_task_matches_fake_datasets(
      self.uncached_task, use_cached=False)

  # Test with token preprocessor.
  self.uncached_task._token_preprocessor = test_utils.test_token_preprocessor
  test_utils.verify_task_matches_fake_datasets(
      self.uncached_task, use_cached=False, token_preprocessed=True)

  # Override mock to get more examples.
  def fake_load(s, shuffle_files=False):
    del shuffle_files  # Unused, to mimic TFDS API
    return test_utils.get_fake_dataset(s).repeat().take(20)
  # `_replace` is a namedtuple method, so it must be called on the fake
  # loader namedtuple (assumed here to be `test_utils.fake_tfds`) rather
  # than on a real `utils.LazyTfdsLoader` instance, which has no `_replace`.
  test_utils.add_fake_tfds(test_utils.fake_tfds._replace(load=fake_load))
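# For context on the fix above: `_replace` is the standard namedtuple
# copy-with-changes method, so it only works on a namedtuple-based fake
# loader. A self-contained illustration with hypothetical names:
import collections

FakeLoader = collections.namedtuple("FakeLoader", ["name", "load"])
fake = FakeLoader(name="fake:0.0.0", load=None)
fake = fake._replace(load=lambda split: None)  # returns a new tuple
assert fake.load is not None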