def test_default_multi_dir(self):
  """Check that the default data dir is used when no data_dir is given.

  A second data dir is registered, but no existing data is found in any
  of them, so the builder should fall back to the default path.
  """
  constants.add_data_dir(self.other_data_dir)
  resolved_dir = self.builder._build_data_dir(None)
  self.assertBuildDataDir(resolved_dir, self.default_data_dir)
def test_default_multi_dir_duplicate(self):
  """Check that an error is raised when several data dirs hold the dataset.

  Version 1.0.0 is created under both the default and the secondary data
  dir. With no explicit data_dir, the choice is ambiguous, so
  `_build_data_dir` must raise a `ValueError`.
  """
  constants.add_data_dir(self.other_data_dir)
  for root in (self.default_data_dir, self.other_data_dir):
    tf.io.gfile.makedirs(
        os.path.join(root, "dummy_dataset_shared_generator", "1.0.0"))
  with self.assertRaisesRegex(ValueError, "found in more than one directory"):
    self.builder._build_data_dir(None)
def test_expicit_multi_dir(self):
  """Check that an explicitly passed data_dir removes any ambiguity.

  Both data dirs contain the same version (1.0.0), but because the caller
  names `other_data_dir` explicitly, that directory must be chosen instead
  of raising an error.
  """
  constants.add_data_dir(self.other_data_dir)
  for root in (self.default_data_dir, self.other_data_dir):
    tf.io.gfile.makedirs(
        os.path.join(root, "dummy_dataset_shared_generator", "1.0.0"))
  explicit_dir = self.other_data_dir
  self.assertBuildDataDir(
      self.builder._build_data_dir(explicit_dir), explicit_dir)
def test_default_multi_dir_version_exists(self):
  """Check that existing data in a secondary data dir is re-used.

  No data_dir is passed and two data dirs are registered. The requested
  version (1.0.0) exists only under `other_data_dir`, so that directory
  should be selected.
  """
  constants.add_data_dir(self.other_data_dir)
  existing_version_dir = os.path.join(
      self.other_data_dir, "dummy_dataset_shared_generator", "1.0.0")
  tf.io.gfile.makedirs(existing_version_dir)
  resolved_dir = self.builder._build_data_dir(None)
  self.assertBuildDataDir(resolved_dir, self.other_data_dir)
def test_default_multi_dir_old_version_exists(self):
  """Check that older versions in another data dir are ignored.

  No data_dir is passed and two data dirs are registered. The secondary
  dir only holds versions 0.1.0 and 0.2.0, not the builder's version, so
  the default path should be used.
  """
  constants.add_data_dir(self.other_data_dir)
  for old_version in ("0.1.0", "0.2.0"):
    tf.io.gfile.makedirs(os.path.join(
        self.other_data_dir, "dummy_dataset_shared_generator", old_version))
  self.assertBuildDataDir(
      self.builder._build_data_dir(None), self.default_data_dir)
def test_load_data_dir(self):
  """Ensure that `tfds.load` also supports multiple data_dir."""
  constants.add_data_dir(self.other_data_dir)

  # Defining this subclass is what lets "multi_dir_dataset" be loaded by name
  # below (presumably via a registration hook on subclassing — confirm). The
  # class object itself is never referenced again.
  class MultiDirDataset(DummyDatasetSharedGenerator):  # pylint: disable=unused-variable
    VERSION = utils.Version("1.2.0")

  expected_dir = os.path.join(
      self.other_data_dir, "multi_dir_dataset", "1.2.0")
  tf.io.gfile.makedirs(expected_dir)

  # Reading DatasetInfo from disk is mocked out: the directory exists but
  # contains no real metadata files.
  with mock.patch.object(dataset_info.DatasetInfo, "read_from_directory"):
    _unused_ds, info = load.load(
        "multi_dir_dataset", split=[], with_info=True)
  self.assertEqual(info.data_dir, expected_dir)