Beispiel #1
0
    def build(self, config, dataset_type):
        """Download and extract the CLEVR archive unless it is already on disk.

        Sets ``self.data_folder`` to the directory the archive is expected to
        unpack into: the archive file name with its final extension removed,
        under ``<mmf_root>/<config.data_dir>/<config.data_folder>``.

        Nothing is downloaded when either the archive file already exists or
        the extraction directory exists and is non-empty.

        Args:
            config: Dataset config; ``data_dir`` and ``data_folder`` are read.
            dataset_type: Not used by this method (presumably required by the
                builder interface — confirm against the base class).
        """
        download_folder = os.path.join(
            get_mmf_root(), config.data_dir, config.data_folder
        )

        archive_name = CLEVR_DOWNLOAD_URL.split("/")[-1]
        archive_path = os.path.join(download_folder, archive_name)

        # Only the last dot-suffix is stripped to get the extraction dir name.
        stem = ".".join(archive_name.split(".")[:-1])
        extraction_folder = os.path.join(download_folder, stem)
        self.data_folder = extraction_folder

        # An existing archive short-circuits everything, including extraction.
        # NOTE(review): an extraction interrupted on a previous run is
        # therefore never retried.
        if os.path.exists(archive_path):
            logger.info("CLEVR dataset is already present. Skipping download.")
            return

        # A populated extraction directory also means there is nothing to do.
        if os.path.exists(extraction_folder) and os.listdir(extraction_folder):
            return

        logger.info("Downloading the CLEVR dataset now")
        download(CLEVR_DOWNLOAD_URL, download_folder, archive_name)

        logger.info("Downloaded. Extracting now. This can take time.")
        with zipfile.ZipFile(archive_path, "r") as archive:
            archive.extractall(download_folder)
def resolve_cache_dir(env_variable="MMF_CACHE_DIR", default="mmf"):
    """Return the MMF cache directory, creating it if it does not exist.

    Resolution order:
      1. The value of the environment variable ``env_variable``, if set.
      2. ``<torch home>/<default>``. The torch home comes from
         ``torch.hub._get_torch_home``. If that cannot be imported, it is
         ``$TORCH_HOME``, falling back to ``$XDG_CACHE_HOME/torch``
         (``~/.cache/torch`` when ``XDG_CACHE_HOME`` is unset), with ``~``
         expanded.

    If the chosen directory is missing and creating it raises
    ``PermissionError``, ``<mmf_root>/.mmf_cache`` is created and returned
    instead.

    Args:
        env_variable: Environment variable that overrides the cache location.
        default: Sub-directory name appended to the torch home.

    Returns:
        The cache directory path.
    """
    # Parts of this follow how the "transformers" library resolves its cache.
    try:
        from torch.hub import _get_torch_home

        torch_home = _get_torch_home()
    except ImportError:
        fallback_home = os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "torch")
        torch_home = os.path.expanduser(os.getenv("TORCH_HOME", fallback_home))

    cache_path = os.getenv(env_variable, os.path.join(torch_home, default))

    if PathManager.exists(cache_path):
        return cache_path

    try:
        PathManager.mkdirs(cache_path)
    except PermissionError:
        # Cannot create the preferred location; use a directory inside the
        # MMF root instead.
        cache_path = os.path.join(get_mmf_root(), ".mmf_cache")
        PathManager.mkdirs(cache_path)

    return cache_path
def get_mmf_cache_dir():
    """Return the cache directory configured in ``env.cache_dir``.

    The configured value is returned unchanged when that path already exists.
    Otherwise it is joined onto the MMF root.

    Returns:
        The resolved cache directory path.
    """
    cache_dir = get_global_config().env.cache_dir
    if os.path.exists(cache_dir):
        return cache_dir
    # Not found as given, so treat it as relative to the MMF root.
    return os.path.join(get_mmf_root(), cache_dir)
def load_yaml(f):
    """Load a YAML config with OmegaConf, recursively resolving ``includes``.

    ``f`` is first resolved with ``get_absolute_path``. If loading that path
    raises ``FileNotFoundError``, the path is retried relative to the MMF
    root. Each entry of the optional top-level ``includes`` list is resolved
    relative to the MMF root or, failing that, relative to the including
    file's directory. Each entry is loaded recursively and the results are
    merged in order. Finally, this file's own keys are merged on top, so they
    take precedence over anything included.

    Args:
        f: Path to the YAML file.

    Returns:
        The merged OmegaConf config, with the ``includes`` key removed.

    Raises:
        FileNotFoundError: If ``f`` cannot be found either as given or
            relative to the MMF root.
        AttributeError: If ``includes`` is not a list of paths. A single
            string is rejected too.
    """
    # Convert to an absolute path so relative includes can later be resolved
    # against this file's directory.
    abs_f = get_absolute_path(f)

    try:
        mapping = OmegaConf.load(abs_f)
        f = abs_f
    except FileNotFoundError:
        # Fall back to interpreting ``f`` relative to the MMF root.
        # TODO: Later test if this can be removed
        relative = os.path.abspath(os.path.join(get_mmf_root(), f))
        if not PathManager.isfile(relative):
            raise
        f = relative
        mapping = OmegaConf.load(f)

    # Guard against ``load`` yielding None; use an empty config instead.
    if mapping is None:
        mapping = OmegaConf.create()

    includes = mapping.get("includes", [])

    # ``str`` is itself a ``Sequence``. Without the explicit check, a single
    # path written as ``includes: foo.yaml`` would be iterated character by
    # character, and each character would be loaded as a file.
    if isinstance(includes, str) or not isinstance(
        includes, collections.abc.Sequence
    ):
        raise AttributeError(
            "Includes must be a list, {} provided".format(type(includes))
        )

    include_mapping = OmegaConf.create()

    mmf_root_dir = get_mmf_root()

    for include in includes:
        original_include_path = include
        include = os.path.join(mmf_root_dir, include)

        # If path doesn't exist relative to MMF root, try relative to current file
        if not PathManager.exists(include):
            include = os.path.join(os.path.dirname(f), original_include_path)

        current_include_mapping = load_yaml(include)
        # Later includes override keys from earlier ones.
        include_mapping = OmegaConf.merge(include_mapping, current_include_mapping)

    mapping.pop("includes", None)

    # The file's own keys take precedence over everything it includes.
    mapping = OmegaConf.merge(include_mapping, mapping)

    return mapping
Beispiel #5
0
    def build(self, config, dataset_type):
        """Store the config and split; download imdb and features for train.

        The imdb archive contains every split, so the downloads happen only
        when building the ``"train"`` split.

        Args:
            config: Dataset config; ``data_dir`` (joined onto the MMF root)
                is read.
            dataset_type: Split name; only ``"train"`` triggers downloads.
        """
        self._dataset_type = dataset_type
        self._config = config
        root_dir = os.path.join(get_mmf_root(), self._config.data_dir)

        # One archive holds all the splits, so only the train build downloads.
        if self._dataset_type == "train":
            self._download_and_extract_imdb(root_dir)
            self._download_and_extract_features(root_dir)
    def build(self, config, dataset_type):
        """Fetch the imdb for every split, plus vocabs and features for train.

        Args:
            config: Dataset config; ``data_dir`` (joined onto the MMF root)
                is read.
            dataset_type: Split name; vocabs and features are fetched only
                when it is ``"train"``.
        """
        self._dataset_type = dataset_type
        self._config = config
        root_dir = os.path.join(get_mmf_root(), self._config.data_dir)

        # The imdb is fetched regardless of the split being built.
        self._download_and_extract_imdb(root_dir)

        if self._dataset_type == "train":
            vocabs_source = VISUAL_DIALOG_CONSTS["vocabs"]
            self._download_and_extract("vocabs", vocabs_source, root_dir)
            self._download_and_extract_features(root_dir)
 def setUp(self):
     """Build a frozen BUTD/COCO beam-search config and register it."""
     setup_imports()
     torch.manual_seed(1234)
     path_parts = ("..", "projects", "butd", "configs", "coco", "beam_search.yaml")
     config_path = os.path.abspath(os.path.join(get_mmf_root(), *path_parts))
     args = dummy_args(model="butd", dataset="coco")
     args.opts.append(f"config={config_path}")
     configuration = Configuration(args)
     # Restrict the run to the coco dataset before freezing.
     configuration.config.datasets = "coco"
     configuration.freeze()
     self.config = configuration.config
     registry.register("config", self.config)
    def __init__(self, config, dataset_type, data_folder=None, *args, **kwargs):
        """Locate and validate the dataset folder on disk, then ``load()``.

        Args:
            config: Dataset config; ``data_dir`` (joined onto the MMF root)
                and ``data_folder`` are read.
            dataset_type: Split name, forwarded to the base class.
            data_folder: Explicit folder to use. When falsy, it defaults to
                ``<mmf_root>/<config.data_dir>/<config.data_folder>``.
            *args, **kwargs: Accepted and ignored.

        Raises:
            RuntimeError: If the resolved folder does not exist.
            FileNotFoundError: If the resolved folder is empty.
        """
        super().__init__(_CONSTANTS["dataset_key"], config, dataset_type)
        self._data_dir = os.path.join(get_mmf_root(), config.data_dir)
        self._data_folder = data_folder or os.path.join(
            self._data_dir, config.data_folder
        )

        if not os.path.exists(self._data_folder):
            raise RuntimeError(
                _TEMPLATES["data_folder_missing_error"].format(self._data_folder)
            )

        # The archive may have been extracted into a same-named subfolder;
        # if so, descend into it.
        if config.data_folder in os.listdir(self._data_folder):
            self._data_folder = os.path.join(self._data_folder, config.data_folder)

        if not os.listdir(self._data_folder):
            raise FileNotFoundError(_CONSTANTS["empty_folder_error"])

        self.load()
 def setUp(self):
     """Register CLEVR sizes and a frozen CNN-LSTM/CLEVR config."""
     torch.manual_seed(1234)
     # Sizes registered before building the config.
     registry.register("clevr_text_vocab_size", 80)
     registry.register("clevr_num_final_outputs", 32)
     path_parts = ("..", "projects", "others", "cnn_lstm", "clevr", "defaults.yaml")
     config_path = os.path.abspath(os.path.join(get_mmf_root(), *path_parts))
     args = dummy_args(model="cnn_lstm", dataset="clevr")
     args.opts.append(f"config={config_path}")
     configuration = Configuration(args)
     configuration.config.datasets = "clevr"
     configuration.freeze()
     self.config = configuration.config
     registry.register("config", self.config)
 def setUp(self):
     """Build a frozen BUTD/COCO nucleus-sampling config and register it.

     The BUTD inference ``sum_threshold`` is overridden to 0.5 before the
     config is frozen.
     """
     setup_imports()
     torch.manual_seed(1234)
     path_parts = ("..", "projects", "butd", "configs", "coco", "nucleus_sampling.yaml")
     config_path = os.path.abspath(os.path.join(get_mmf_root(), *path_parts))
     args = dummy_args(model="butd", dataset="coco")
     args.opts.append(f"config={config_path}")
     configuration = Configuration(args)
     configuration.config.datasets = "coco"
     configuration.config.model_config.butd.inference.params.sum_threshold = 0.5
     configuration.freeze()
     self.config = configuration.config
     registry.register("config", self.config)
Beispiel #11
0
 def test_config_overrides(self):
     """Command-line ``opts`` must override indexed list entries in the config."""
     project_config = os.path.join(
         "..", "projects", "m4c", "configs", "textvqa", "defaults.yaml"
     )
     config_path = os.path.abspath(os.path.join(get_mmf_root(), project_config))
     overrides = [
         f"config={config_path}",
         "training.lr_steps[1]=10000",
         'dataset_config.textvqa.zoo_requirements[0]="test"',
     ]
     args = dummy_args(model="m4c", dataset="textvqa")
     args.opts += overrides
     configuration = Configuration(args)
     configuration.freeze()
     config = configuration.get_config()
     self.assertEqual(config.training.lr_steps[1], 10000)
     self.assertEqual(
         config.dataset_config.textvqa.zoo_requirements[0], "test"
     )
Beispiel #12
0
 def _test_quality_check(self, fn):
     """Apply ``fn`` to the MMF root, then to ``mmf_cli`` and ``tests`` beside it."""
     for suffix in ((), ("..", "mmf_cli"), ("..", "tests")):
         fn(os.path.join(get_mmf_root(), *suffix))