Exemplo n.º 1
0
    def _download_requirement(self,
                              config,
                              requirement_key,
                              requirement_variation="defaults"):
        """Download the zoo resources registered under ``requirement_key``.

        The zoo entry is looked up via ``get_zoo_config`` using this
        instance's ``zoo_config_path`` and ``zoo_type``. If the entry has no
        resources (``None``), nothing is downloaded.

        Files are placed under
        ``<data_dir>/datasets/<dataset_name>/<dataset_variation>``, where
        ``data_dir`` comes from ``get_mmf_env("data_dir")``. The path is made
        absolute with ``get_absolute_path``.

        Args:
            config: dict-like object queried with ``.get`` for the
                ``use_features`` and ``use_images`` flags. It is only
                consulted when the zoo resources are a mapping. (Presumably
                the dataset config; confirm against callers.)
            requirement_key: zoo key. Its first dot-separated component is the
                dataset name; its second component, if present, is the
                variation used for the download directory.
            requirement_variation: variation passed to ``get_zoo_config``. It
                is also used for the download directory when
                ``requirement_key`` has no second component.
        """
        version, resources = get_zoo_config(requirement_key,
                                            requirement_variation,
                                            self.zoo_config_path,
                                            self.zoo_type)
        if resources is None:
            return

        key_parts = requirement_key.split(".")
        dataset_name = key_parts[0]
        # A variation embedded directly in the key takes precedence over the
        # ``requirement_variation`` argument.
        dataset_variation = (key_parts[1] if len(key_parts) > 1
                             else requirement_variation)

        # Use the root env data_dir so the download root does not get mixed
        # up with the per-dataset directories.
        download_path = get_absolute_path(
            os.path.join(get_mmf_env("data_dir"), "datasets", dataset_name,
                         dataset_variation))

        if not isinstance(resources, collections.abc.Mapping):
            # Non-mapping resources are handed over wholesale.
            self._download_resources(resources, download_path, version)
            return

        # Read both opt-in flags first, in the same order as before, and only
        # then start downloading the optional groups they enable.
        optional_groups = (
            ("features", config.get("use_features", False)),
            ("images", config.get("use_images", False)),
        )
        for group_name, wanted in optional_groups:
            if wanted:
                self._download_based_on_attribute(resources, download_path,
                                                  version, group_name)

        # Annotations are always fetched, followed by any "extras" entries
        # (an empty list when the mapping has none).
        self._download_based_on_attribute(resources, download_path, version,
                                          "annotations")
        self._download_resources(resources.get("extras", []), download_path,
                                 version)
Exemplo n.º 2
0
    def test_get_zoo_config(self):
        """Exercise ``configuration.get_zoo_config`` lookups.

        The cases are:
        - known keys and variations resolve to a non-None version and
          resources;
        - an unknown key yields ``(None, None)``;
        - an unknown variation of a known key raises ``AssertionError``;
        - ``zoo_type`` and ``zoo_config_path`` select alternative zoo files.
        """

        def expect_found(*lookup_args, **lookup_kwargs):
            # Both halves of the returned pair must be populated.
            found_version, found_resources = configuration.get_zoo_config(
                *lookup_args, **lookup_kwargs)
            self.assertIsNotNone(found_version)
            self.assertIsNotNone(found_resources)

        # Variation embedded directly in the key.
        expect_found("textvqa.ocr_en")
        # No variation given: the default one is used.
        expect_found("textvqa")
        # Explicit non-default variation.
        expect_found("textvqa", variation="ocr_en")

        # A key absent from the zoo gives back (None, None).
        missing_version, missing_resources = configuration.get_zoo_config(
            "some_random")
        self.assertIsNone(missing_version)
        self.assertIsNone(missing_resources)

        # An unknown variation of a known key is rejected.
        with self.assertRaises(AssertionError):
            configuration.get_zoo_config("textvqa", variation="some_random")

        # Select the models zoo by type ...
        expect_found("visual_bert.pretrained", zoo_type="models")
        # ... or by pointing straight at its config file.
        expect_found(
            "visual_bert.pretrained",
            zoo_config_path=os.path.join("configs", "zoo", "models.yaml"),
        )