def builder_cls_from_module(
    module_name: str,
) -> Type[dataset_builder.DatasetBuilder]:
  """Imports the module and extracts the `tfds.core.DatasetBuilder`.

  Importing the module executes its class definitions, which records the
  builder classes it defines in `registered._MODULE_TO_DATASETS` (keyed by
  module name). The import runs inside `registered.skip_registration()` so
  that loading a module this way does not register its builders globally.

  Args:
    module_name: Dataset module to import containing the dataset definition
      (e.g. `tensorflow_datasets.image.mnist.mnist`)

  Returns:
    The extracted tfds.core.DatasetBuilder builder class.

  Raises:
    ValueError: If the module does not define exactly one `DatasetBuilder`
      (none detected -- e.g. abstract methods left unimplemented -- or
      several).
  """
  if module_name not in sys.modules:  # Module not imported yet.
    # Module can be created during execution, so call invalidate_caches() to
    # make sure the new module is noticed by the import system.
    importlib.invalidate_caches()

  # Executing the module will register the datasets in _MODULE_TO_DATASETS.
  # NOTE: `import_module` does not re-execute a module that is already in
  # `sys.modules`; in that case the entries recorded at first import are used.
  # NOTE(review): `mock_huggingface_import()` presumably lets modules that
  # import Hugging Face libraries load without them installed -- confirm.
  with registered.skip_registration(), \
       huggingface_wrapper.mock_huggingface_import():
    importlib.import_module(module_name)

  # TODO(tfds): For community-installed modules, we should raise cleaner
  # error if there is additional missing dependency. E.g. Parsing all
  # import statements. Or wrap this `importlib.import_module` within a
  # `with lazy_imports():` context manager ?

  builder_classes = registered._MODULE_TO_DATASETS.get(module_name, [])  # pylint: disable=protected-access

  if len(builder_classes) != 1:
    raise ValueError(
        f'Could not load DatasetBuilder from: {module_name}. '
        'Make sure the module only contains a single `DatasetBuilder`.\n'
        'If no dataset is detected, make sure that all abstractmethods are '
        'implemented.\n'
        f'Detected builders: {builder_classes}')
  return builder_classes[0]
def builder_cls_from_module(
    module_name: str,
) -> Type[dataset_builder.DatasetBuilder]:
  """Imports the module and extracts the `tfds.core.DatasetBuilder`.

  Importing the module executes its class definitions, which records the
  builder classes it defines in `registered._MODULE_TO_DATASETS` (keyed by
  module name). The import runs inside `registered.skip_registration()` so
  that loading a module this way does not register its builders globally.

  Args:
    module_name: Dataset module to import containing the dataset definition
      (e.g. `tensorflow_datasets.image.mnist.mnist`)

  Returns:
    The extracted tfds.core.DatasetBuilder builder class.

  Raises:
    ValueError: If the module does not define exactly one `DatasetBuilder`.
  """
  if module_name not in sys.modules:  # Module not imported yet.
    # Module can be created during execution, so call invalidate_caches() to
    # make sure the new module is noticed by the import system.
    importlib.invalidate_caches()

  # Executing the module will register the datasets in _MODULE_TO_DATASETS.
  # NOTE: `import_module` does not re-execute a module that is already in
  # `sys.modules`; in that case the entries recorded at first import are used.
  with registered.skip_registration():
    importlib.import_module(module_name)

  builder_classes = registered._MODULE_TO_DATASETS.get(module_name, [])  # pylint: disable=protected-access

  if len(builder_classes) != 1:
    raise ValueError(
        f'Could not load DatasetBuilder from: {module_name}. '
        'Make sure the module only contains a single `DatasetBuilder`. '
        f'Detected builders: {builder_classes}'
    )
  return builder_classes[0]
def test_skip_dataset_collection_registration():
  """Collections defined under `skip_registration()` are not listed."""
  name = "skip_collection_builder"

  with registered.skip_registration():

    class SkipCollectionBuilder(registered.RegisteredDatasetCollection):
      pass

  # Guard against a vacuous pass: `name` must really be the name derived from
  # the class, otherwise the absence check below would prove nothing.
  assert name == SkipCollectionBuilder.name
  assert name not in registered.list_imported_dataset_collections()
def test_skip_regitration():
  """A builder declared under `skip_registration()` is not listed by `load`."""
  expected_name = "skip_registered_dataset"

  with registered.skip_registration():

    class SkipRegisteredDataset(registered.RegisteredDataset):
      pass

  # Confirm the derived name first so the absence check is meaningful.
  assert expected_name == SkipRegisteredDataset.name
  assert expected_name not in load.list_builders()
def test_skip_regitration(self):
  """A dataset declared under `skip_registration()` stays unregistered."""
  expected_name = "skip_registered_dataset"

  with registered.skip_registration():

    @six.add_metaclass(registered.RegisteredDataset)
    class SkipRegisteredDataset(object):
      pass

  # Check the derived name first so the `assertNotIn` below is meaningful.
  self.assertEqual(expected_name, SkipRegisteredDataset.name)
  self.assertNotIn(expected_name, registered.list_builders())
# coding=utf-8
# Copyright 2020 The TensorFlow Datasets Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Custom Datasets APIs."""

from tensorflow_datasets.core import registered

# Custom datasets cannot be instantiated through `tfds.load`, so their builder
# classes are imported inside `skip_registration()` to keep them out of the
# builder registry.
# NOTE: this only takes effect if the modules below are imported here for the
# first time; a module already in `sys.modules` is not re-executed, so classes
# defined by an earlier import keep whatever registration they got then.
with registered.skip_registration():
  # pylint: disable=g-import-not-at-top
  from tensorflow_datasets.core.folder_dataset.image_folder import ImageFolder
  from tensorflow_datasets.core.folder_dataset.translate_folder import TranslateFolder
  # pylint: enable=g-import-not-at-top

__all__ = [
    "ImageFolder",
    "TranslateFolder",
]