Beispiel #1
0
 def test_default_multi_dir(self):
     # Scenario: a second data dir is registered and the caller passes
     # no explicit data_dir (None).
     # Expectation: the resolved build dir is checked against the default
     # data dir.
     constants.add_data_dir(self.other_data_dir)
     resolved_dir = self.builder._build_data_dir(None)
     self.assertBuildDataDir(resolved_dir, self.default_data_dir)
 def test_default_multi_dir_duplicate(self):
   # Scenario: the same dataset/version directory exists under both the
   # default data dir and the additionally registered one.
   # Expectation: resolving the build dir without an explicit data_dir
   # raises a ValueError mentioning the duplicate.
   constants.add_data_dir(self.other_data_dir)
   for root_dir in (self.default_data_dir, self.other_data_dir):
     version_dir = os.path.join(
         root_dir, "dummy_dataset_shared_generator", "1.0.0")
     tf.io.gfile.makedirs(version_dir)
   with self.assertRaisesRegex(ValueError, "found in more than one directory"):
     self.builder._build_data_dir(None)
 def test_expicit_multi_dir(self):
   # Scenario: both registered data dirs contain the same dataset version,
   # but the caller explicitly passes `other_data_dir`.
   # Expectation: the resolved build dir is checked against the explicitly
   # requested dir (no ambiguity error is expected here).
   constants.add_data_dir(self.other_data_dir)
   for root_dir in (self.default_data_dir, self.other_data_dir):
     version_dir = os.path.join(
         root_dir, "dummy_dataset_shared_generator", "1.0.0")
     tf.io.gfile.makedirs(version_dir)
   resolved_dir = self.builder._build_data_dir(self.other_data_dir)
   self.assertBuildDataDir(resolved_dir, self.other_data_dir)
 def test_default_multi_dir_version_exists(self):
   # Scenario: several data dirs are registered, no data_dir is passed,
   # and only `other_data_dir` already holds a "1.0.0" version directory.
   # Expectation: the existing data is re-used, i.e. the resolved build
   # dir is checked against `other_data_dir`.
   constants.add_data_dir(self.other_data_dir)
   existing_version_dir = os.path.join(
       self.other_data_dir, "dummy_dataset_shared_generator", "1.0.0")
   tf.io.gfile.makedirs(existing_version_dir)
   resolved_dir = self.builder._build_data_dir(None)
   self.assertBuildDataDir(resolved_dir, self.other_data_dir)
 def test_default_multi_dir_old_version_exists(self):
   # Scenario: several data dirs are registered, no data_dir is passed,
   # and `other_data_dir` only holds older version directories
   # ("0.1.0" and "0.2.0").
   # Expectation: the old versions are not picked up; the resolved build
   # dir is checked against the default data dir.
   constants.add_data_dir(self.other_data_dir)
   for old_version in ("0.1.0", "0.2.0"):
     old_version_dir = os.path.join(
         self.other_data_dir, "dummy_dataset_shared_generator", old_version)
     tf.io.gfile.makedirs(old_version_dir)
   resolved_dir = self.builder._build_data_dir(None)
   self.assertBuildDataDir(resolved_dir, self.default_data_dir)
Beispiel #6
0
    def test_load_data_dir(self):
        """Check that `tfds.load` works when multiple data dirs are registered.

        A dataset version directory is created only under the additionally
        registered data dir; the info object returned by `load.load` must
        report that directory as its `data_dir`.
        """
        constants.add_data_dir(self.other_data_dir)

        # The subclass is never referenced directly; it is looked up below by
        # the name "multi_dir_dataset" (presumably registered on definition).
        class MultiDirDataset(DummyDatasetSharedGenerator):  # pylint: disable=unused-variable
            VERSION = utils.Version("1.2.0")

        expected_dir = os.path.join(
            self.other_data_dir, "multi_dir_dataset", "1.2.0")
        tf.io.gfile.makedirs(expected_dir)

        # Patch out reading the dataset info from disk: the directory created
        # above is empty.
        read_info_patch = mock.patch.object(
            dataset_info.DatasetInfo, "read_from_directory")
        with read_info_patch:
            _, loaded_info = load.load(
                "multi_dir_dataset", split=[], with_info=True)
        self.assertEqual(loaded_info.data_dir, expected_dir)