示例#1
0
文件: load.py 项目: mbbessa/datasets
def builder_cls_from_module(
    module_name: str,) -> Type[dataset_builder.DatasetBuilder]:
  """Imports the module and extracts the `tfds.core.DatasetBuilder`.

  If the module is not yet present in `sys.modules`, it is imported inside
  `registered.skip_registration()` and
  `huggingface_wrapper.mock_huggingface_import()`. The single builder class
  recorded for the module in `registered._MODULE_TO_DATASETS` is then
  returned.

  Args:
    module_name: Dataset module to import containing the dataset definition
      (e.g. `tensorflow_datasets.image.mnist.mnist`)

  Returns:
    The extracted tfds.core.DatasetBuilder builder class.

  Raises:
    ValueError: If the module has zero or more than one builder class
      recorded in `registered._MODULE_TO_DATASETS`.
  """
  if module_name not in sys.modules:  # Module not imported yet.
    # Module can be created during execution, so call invalidate_caches() to
    # make sure the new module is noticed by the import system.
    importlib.invalidate_caches()

    # Executing the module will register the datasets in _MODULE_TO_DATASETS.
    with registered.skip_registration(),\
         huggingface_wrapper.mock_huggingface_import():
      importlib.import_module(module_name)
      # TODO(tfds): For community-installed modules, we should raise cleaner
      # error if there is additional missing dependency. E.g. Parsing all
      # import statements. Or wrap this `importlib.import_module` within a
      # `with lazy_imports():` context manager ?
  # NOTE(review): when the module was already in `sys.modules`, no import
  # happens here and the lookup below relies on whatever was recorded when it
  # was first imported.

  builder_classes = registered._MODULE_TO_DATASETS.get(module_name, [])  # pylint: disable=protected-access

  if len(builder_classes) != 1:
    raise ValueError(
        f'Could not load DatasetBuilder from: {module_name}. '
        'Make sure the module only contains a single `DatasetBuilder`.\n'
        'If no dataset is detected, make sure that all abstractmethods are '
        'implemented.\n'
        f'Detected builders: {builder_classes}')
  return builder_classes[0]
示例#2
0
文件: load.py 项目: zdz1130/datasets
def builder_cls_from_module(
    module_name: str,
) -> Type[dataset_builder.DatasetBuilder]:
  """Imports the module and extracts the `tfds.core.DatasetBuilder`.

  If the module is not yet present in `sys.modules`, it is imported inside
  `registered.skip_registration()`. The single builder class recorded for the
  module in `registered._MODULE_TO_DATASETS` is then returned.

  Args:
    module_name: Dataset module to import containing the dataset definition
      (e.g. `tensorflow_datasets.image.mnist.mnist`)

  Returns:
    The extracted tfds.core.DatasetBuilder builder class.

  Raises:
    ValueError: If the module has zero or more than one builder class
      recorded in `registered._MODULE_TO_DATASETS`.
  """
  if module_name not in sys.modules:  # Module not imported yet.
    # Module can be created during execution, so call invalidate_caches() to
    # make sure the new module is noticed by the import system.
    importlib.invalidate_caches()

    # Executing the module will register the datasets in _MODULE_TO_DATASETS.
    with registered.skip_registration():
      importlib.import_module(module_name)
  # NOTE(review): when the module was already in `sys.modules`, no import
  # happens here and the lookup below relies on whatever was recorded when it
  # was first imported.

  builder_classes = registered._MODULE_TO_DATASETS.get(module_name, [])  # pylint: disable=protected-access

  if len(builder_classes) != 1:
    raise ValueError(
        f'Could not load DatasetBuilder from: {module_name}. '
        'Make sure the module only contains a single `DatasetBuilder`. '
        f'Detected builders: {builder_classes}'
    )
  return builder_classes[0]
示例#3
0
def test_skip_dataset_collection_registration():
    """A collection defined under `skip_registration()` must not be listed."""
    with registered.skip_registration():

        # Defining the subclass is the point of the test; it is never used.
        class SkipCollectionBuilder(registered.RegisteredDatasetCollection):  # pylint: disable=unused-variable
            pass

    imported_collections = registered.list_imported_dataset_collections()
    assert "skip_collection_builder" not in imported_collections
示例#4
0
def test_skip_regitration():
    """A dataset defined under `skip_registration()` keeps its name but is unlisted."""
    expected_name = "skip_registered_dataset"

    with registered.skip_registration():

        class SkipRegisteredDataset(registered.RegisteredDataset):
            pass

    assert expected_name == SkipRegisteredDataset.name
    known_builders = load.list_builders()
    assert expected_name not in known_builders
示例#5
0
    def test_skip_regitration(self):
        """A dataset defined under `skip_registration()` is absent from `list_builders()`."""
        expected_name = "skip_registered_dataset"

        with registered.skip_registration():

            # `RegisteredDataset` is applied as the metaclass via `six`.
            @six.add_metaclass(registered.RegisteredDataset)
            class SkipRegisteredDataset(object):
                pass

        self.assertEqual(expected_name, SkipRegisteredDataset.name)
        listed_builders = registered.list_builders()
        self.assertNotIn(expected_name, listed_builders)
示例#6
0
# coding=utf-8
# Copyright 2020 The TensorFlow Datasets Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Custom Datasets APIs."""

from tensorflow_datasets.core import registered

# Custom datasets cannot be instantiated through `tfds.load`, so their
# builders are imported under `registered.skip_registration()`.
with registered.skip_registration():
    # pylint: disable=g-import-not-at-top
    from tensorflow_datasets.core.folder_dataset.image_folder import ImageFolder
    from tensorflow_datasets.core.folder_dataset.translate_folder import TranslateFolder
    # pylint: enable=g-import-not-at-top

# Public names re-exported by this module.
__all__ = [
    "ImageFolder",
    "TranslateFolder",
]