Example #1
    def test_import_user_module_from_file(self):
        self.assertIsNone(registry.get_builder_class("always_one"))
        self.assertIsNone(registry.get_model_class("simple"))

        user_dir = self._get_user_dir()
        user_file = os.path.join(user_dir, "models", "simple.py")
        import_user_module(user_file)
        # Only the model should be found; the builder should still be None
        self.assertIsNone(registry.get_builder_class("always_one"))
        self.assertIsNotNone(registry.get_model_class("simple"))
        self.assertTrue("mmf_user_dir" in sys.modules)
        self.assertTrue(user_dir in get_mmf_env("user_dir"))
Example #2
    def test_import_user_module_from_directory_absolute(self, abs_path=True):
        # Make sure the modules are not available first
        self.assertIsNone(registry.get_builder_class("always_one"))
        self.assertIsNone(registry.get_model_class("simple"))
        self.assertFalse("mmf_user_dir" in sys.modules)

        # Now, import and test
        user_dir = self._get_user_dir(abs_path)
        import_user_module(user_dir)
        self.assertIsNotNone(registry.get_builder_class("always_one"))
        self.assertIsNotNone(registry.get_model_class("simple"))
        self.assertTrue("mmf_user_dir" in sys.modules)
        self.assertTrue(user_dir in get_mmf_env("user_dir"))
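Both tests assume a fixture directory whose modules register components on import. A minimal sketch of what such a user directory might contain, with hypothetical module contents (the keys match the tests above; the decorators are MMF's standard registration API):

```python
# mmf_user_dir/models/simple.py -- hypothetical contents
from mmf.common.registry import registry
from mmf.models.base_model import BaseModel


@registry.register_model("simple")
class SimpleModel(BaseModel):
    ...


# mmf_user_dir/datasets/always_one.py -- hypothetical contents and location
from mmf.common.registry import registry
from mmf.datasets.base_dataset_builder import BaseDatasetBuilder


@registry.register_builder("always_one")
class AlwaysOneBuilder(BaseDatasetBuilder):
    ...
```

Importing the single file (Example #1) registers only the model; importing the directory (Example #2) registers both the model and the builder.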
Example #3
    def _build_dataset_config(self, config):
        dataset = config.dataset
        datasets = config.datasets

        if dataset is None and datasets is None:
            raise KeyError("Required argument 'dataset|datasets' not passed")

        if datasets is None:
            config.datasets = dataset
            datasets = dataset.split(",")
        else:
            datasets = datasets.split(",")

        dataset_config = OmegaConf.create()

        for dataset in datasets:
            builder_cls = registry.get_builder_class(dataset)

            if builder_cls is None:
                warning = f"No dataset named '{dataset}' has been registered"
                warnings.warn(warning)
                continue
            default_dataset_config_path = builder_cls.config_path()

            if default_dataset_config_path is None:
                warning = (
                    f"Dataset {dataset}'s builder class has no default "
                    "configuration provided"
                )
                warnings.warn(warning)
                continue

            dataset_config = OmegaConf.merge(
                dataset_config, load_yaml(default_dataset_config_path))

        return dataset_config
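The heavy lifting here is OmegaConf.merge, which combines each builder's default yaml into a single config, with later datasets taking precedence on conflicting keys. A standalone sketch of that merge behavior, with illustrative keys:

```python
from omegaconf import OmegaConf

# Illustrative stand-ins for two builders' default configs.
vqa2_defaults = OmegaConf.create({"dataset_config": {"vqa2": {"use_images": True}}})
coco_defaults = OmegaConf.create({"dataset_config": {"coco": {"use_features": True}}})

merged = OmegaConf.merge(vqa2_defaults, coco_defaults)
# merged.dataset_config now holds both the "vqa2" and "coco" subtrees,
# mirroring what _build_dataset_config accumulates across the loop.
print(OmegaConf.to_yaml(merged))
```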
Example #4
def build_datamodule(dataset_key) -> pl.LightningDataModule:
    dataset_builder = registry.get_builder_class(dataset_key)
    assert dataset_builder, (
        f"Key {dataset_key} doesn't have a registered dataset builder"
    )
    builder_instance: pl.LightningDataModule = dataset_builder()
    return builder_instance
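A minimal usage sketch, assuming a builder has been registered under the key shown (the key is illustrative) and that it implements the standard LightningDataModule hooks:

```python
# "vqa2" stands in for any key registered via @registry.register_builder.
datamodule = build_datamodule("vqa2")

# Standard LightningDataModule lifecycle, assuming the builder defines it:
datamodule.prepare_data()
datamodule.setup("fit")
train_loader = datamodule.train_dataloader()
```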
Example #5
def build_dataset(dataset_key: str,
                  config=None,
                  dataset_type="train") -> torch.utils.data.Dataset:
    """Builder function for creating a dataset. If dataset_key is passed
    the dataset is created from default config of the dataset and thus is
    disable config even if it is passed. Otherwise, we use MultiDatasetLoader to
    build and return an instance of dataset based on the config

    Args:
        dataset_key (str): Key of dataset to build.
        config (DictConfig, optional): Configuration that will be used to create
            the dataset. If not passed, dataset's default config will be used.
            Defaults to None.
        dataset_type (str, optional): Type of the dataset to build, train|val|test.
            Defaults to "train".

    Returns:
        (torch.utils.data.Dataset): A dataset instance of type torch Dataset
    """
    from mmf.datasets.base_dataset_builder import BaseDatasetBuilder
    from mmf.utils.configuration import load_yaml_with_defaults

    dataset_builder = registry.get_builder_class(dataset_key)
    assert dataset_builder, (f"Key {dataset_key} doesn't have a registered " +
                             "dataset builder")

    # If config is not provided, we take it from default one
    if not config:
        config_path = dataset_builder.config_path()
        if config_path is None:
            # If config path wasn't defined, send an empty config path
            # but don't force dataset to define a config
            warnings.warn(f"Config path not defined for {dataset_key}, " +
                          "continuing with empty config")
            config = OmegaConf.create()
        else:
            config = load_yaml_with_defaults(config_path)
            config = OmegaConf.select(config, f"dataset_config.{dataset_key}")
            if config is None:
                config = OmegaConf.create()
            OmegaConf.set_struct(config, True)

    builder_instance: BaseDatasetBuilder = dataset_builder()
    builder_instance.build_dataset(config, dataset_type)
    dataset = builder_instance.load_dataset(config, dataset_type)
    if hasattr(builder_instance, "update_registry_for_model"):
        builder_instance.update_registry_for_model(config)

    return dataset
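A minimal usage sketch, assuming the key names a registered builder; with no config passed, the function falls back to the builder's default yaml as described above:

```python
import torch

# "vqa2" is an illustrative key; its default config is loaded implicitly.
val_dataset = build_dataset("vqa2", dataset_type="val")

# The result is a torch Dataset, so it can be wrapped in a DataLoader
# (a dataset-appropriate collate_fn may still be needed for MMF samples).
loader = torch.utils.data.DataLoader(val_dataset, batch_size=32)
```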
Example #6
def build_dataset(
    dataset_key: str, config=None, dataset_type="train"
) -> mmf_typings.DatasetType:
    """Builder function for creating a dataset. If dataset_key is passed
    the dataset is created from default config of the dataset and thus is
    disable config even if it is passed. Otherwise, we use MultiDatasetLoader to
    build and return an instance of dataset based on the config

    Args:
        dataset_key (str): Key of dataset to build.
        config (DictConfig, optional): Configuration that will be used to create
            the dataset. If not passed, dataset's default config will be used.
            Defaults to None.
        dataset_type (str, optional): Type of the dataset to build, train|val|test.
            Defaults to "train".

    Returns:
        (DatasetType): A dataset instance of type BaseDataset
    """
    from mmf.utils.configuration import load_yaml_with_defaults

    dataset_builder = registry.get_builder_class(dataset_key)
    assert dataset_builder, (
        f"Key {dataset_key} doesn't have a registered dataset builder"
    )

    # If config is not provided, we take it from default one
    if not config:
        config = load_yaml_with_defaults(dataset_builder.config_path())
        config = OmegaConf.select(config, f"dataset_config.{dataset_key}")
        OmegaConf.set_struct(config, True)

    builder_instance: mmf_typings.DatasetBuilderType = dataset_builder()
    builder_instance.build_dataset(config, dataset_type)
    dataset = builder_instance.load_dataset(config, dataset_type)
    if hasattr(builder_instance, "update_registry_for_model"):
        builder_instance.update_registry_for_model(config)

    return dataset
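When a config is passed explicitly, this variant skips the default-config lookup entirely and hands the config straight to the builder. A sketch of that path, with a hypothetical per-dataset config:

```python
from omegaconf import OmegaConf

# Hypothetical per-dataset config; because it is provided, the builder's
# default yaml is never loaded and no struct flag is set on it.
config = OmegaConf.create({"use_images": True, "max_features": 100})
train_dataset = build_dataset("vqa2", config=config, dataset_type="train")
```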
Example #7
    def load(self, config):
        self.config = config
        self._process_datasets()

        self._datasets = []
        self._builders = []
        self._loaders = []
        self._samplers = []
        self._iterators = []

        self._total_length = 0
        self._per_dataset_lengths = []
        self._num_datasets = 0
        self._finished_iterators = {}
        self._used_once = {}

        for dataset in self._given_datasets:
            builder_class = registry.get_builder_class(dataset)

            if builder_class is None:
                print("No builder class found for %s." % dataset)
                continue
            builder_instance = builder_class()

            if dataset in self.config.dataset_config:
                dataset_config = self.config.dataset_config[dataset]
            else:
                self.writer.write(
                    "Dataset %s is missing from "
                    "dataset_config in config." % dataset,
                    "error",
                )
                sys.exit(1)

            builder_instance.build_dataset(dataset_config, self._dataset_type)
            dataset_instance = builder_instance.load_dataset(
                dataset_config, self._dataset_type)

            if dataset_instance is None:
                continue

            loader_instance, sampler_instance = self.build_dataloader(
                self.config, dataset_instance)

            self._builders.append(builder_instance)
            self._datasets.append(dataset_instance)
            self._loaders.append(loader_instance)
            self._samplers.append(sampler_instance)

            self._per_dataset_lengths.append(len(dataset_instance))
            self._total_length += len(dataset_instance)

        self._num_datasets = len(self._datasets)
        self._dataset_probablities = [
            1 / self._num_datasets for _ in range(self._num_datasets)
        ]

        training = self._global_config.training
        self._proportional_sampling = training.dataset_size_proportional_sampling

        if self._dataset_type != "train":
            # If it is val or test, all datasets need to be fully iterated
            # as metrics will be calculated in eval mode over complete
            # datasets
            self._proportional_sampling = True

        if self._proportional_sampling is True:
            self._dataset_probablities = self._per_dataset_lengths[:]
            self._dataset_probablities = [
                prob / self._total_length
                for prob in self._dataset_probablities
            ]

        self._loader_index = 0
        self._chosen_dataset = self._datasets[self._loader_index]
        self._chosen_loader = self._loaders[self._loader_index]
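The sampling probabilities computed at the end are either uniform across datasets or proportional to dataset size. A standalone sketch of that arithmetic with made-up lengths:

```python
# Made-up per-dataset lengths for illustration.
per_dataset_lengths = [1000, 3000, 6000]
total_length = sum(per_dataset_lengths)
num_datasets = len(per_dataset_lengths)

# Default: each dataset is equally likely to be chosen.
uniform = [1 / num_datasets for _ in range(num_datasets)]
# -> [0.33, 0.33, 0.33] (approximately)

# With dataset_size_proportional_sampling (forced on for val/test):
proportional = [length / total_length for length in per_dataset_lengths]
# -> [0.1, 0.3, 0.6]
```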
Example #8
    def load(self, config, dataset_type, *args, **kwargs):
        """The VinVL dataset is a dataset that augments an existing
        dataset within MMF. VinVL requires unique inputs for
        finetuning and pretraining unsupported by general datasets.
        To enable this functionality on arbitrary datasets,
        the VinVL dataset contains a base dataset,
        and returns an augmented version of samples from the
        base dataset.
        For more details, read the VinVL dataset docstring.

        The Builder:
        This class is a builder for the VinVL dataset. As the VinVL
        dataset must be constructed with an instance of a base dataset,
        configured by the client in the VinVL configs yaml, this builder
        class instantiates 2 datasets, then passes the base dataset to
        the VinVL dataset instance.

        The VinVL config is expected to have the following structure,
        ```yaml
        dataset_config:
            vinvl:
                base_dataset_name: vqa2
                label_map: <path to label map>
                base_dataset: ${dataset_config.vqa2}
                processors:
                    text_processor:
                        type: vinvl_text_tokenizer
                        params:
                            ...
        ```
        where base_dataset is the yaml config for the base dataset
        (in this example, vqa2) and base_dataset_name is vqa2.

        Returns:
            VinVLDataset: Instance of the VinVLDataset class which contains
            a base dataset instance.
        """
        base_dataset_name = config.get("base_dataset_name", "vqa2")
        base_dataset_config = config.get("base_dataset", config)
        # instantiate the base dataset builder
        base_dataset_builder_class = registry.get_builder_class(
            base_dataset_name)
        base_dataset_builder_instance = base_dataset_builder_class()
        # build base dataset instance
        base_dataset_builder_instance.build_dataset(base_dataset_config)
        base_dataset = base_dataset_builder_instance.load_dataset(
            base_dataset_config, dataset_type)
        if hasattr(base_dataset_builder_instance, "update_registry_for_model"):
            base_dataset_builder_instance.update_registry_for_model(
                base_dataset_config)

        # instantiate vinvl dataset
        vinvl_text_processor = config["processors"]["text_processor"]
        with open_dict(base_dataset_config):
            base_dataset_config["processors"][
                "text_processor"] = vinvl_text_processor
            base_dataset_config["label_map"] = config["label_map"]

        vinvl_dataset = super().load(base_dataset_config, dataset_type, *args,
                                     **kwargs)
        vinvl_dataset.set_base_dataset(base_dataset)
        return vinvl_dataset
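Given a config shaped like the yaml in the docstring, the builder is exercised through the usual build/load path. A sketch, assuming the builder is registered under the key vinvl and the yaml above has been loaded into config:

```python
# Assumes `config` is an OmegaConf DictConfig containing dataset_config.vinvl
# shaped like the yaml in the docstring above.
vinvl_config = config.dataset_config.vinvl

builder_cls = registry.get_builder_class("vinvl")
builder = builder_cls()
builder.build_dataset(vinvl_config, "train")
vinvl_dataset = builder.load_dataset(vinvl_config, "train")
# load_dataset ends up in the load() method above, which builds the base
# dataset (vqa2 here), swaps in the vinvl text processor, and attaches
# the base dataset to the returned VinVLDataset.
```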