コード例 #1
0
ファイル: encoders.py プロジェクト: ytsheng/mmf
    def __init__(self, in_dim, weights_file, bias_file, model_data_dir, *args,
                 **kwargs):
        """Build a linear layer initialised from pickled fc7 weight/bias arrays.

        Relative ``weights_file`` / ``bias_file`` paths are resolved against
        ``model_data_dir``. If either resolved file is missing, the
        ``detectron.vmb_weights`` bundle is downloaded and its ``fc7_w.pkl`` /
        ``fc7_b.pkl`` files are used instead.

        Args:
            in_dim: input size of the linear layer.
            weights_file: path of the pickled weight array.
            bias_file: path of the pickled bias array; its first dimension
                becomes ``self.out_dim``.
            model_data_dir: base directory for relative file paths.
        """
        super().__init__()
        base_dir = get_absolute_path(model_data_dir)

        def _resolve(file_path):
            # Absolute paths are kept as-is; relative ones live under base_dir.
            if os.path.isabs(file_path):
                return file_path
            return os.path.join(base_dir, file_path)

        weights_file = _resolve(weights_file)
        bias_file = _resolve(bias_file)

        have_both = PathManager.exists(bias_file) and PathManager.exists(
            weights_file)
        if not have_both:
            download_path = download_pretrained_model("detectron.vmb_weights")
            weights_file = get_absolute_path(
                os.path.join(download_path, "fc7_w.pkl"))
            bias_file = get_absolute_path(
                os.path.join(download_path, "fc7_b.pkl"))

        # NOTE(review): pickle.load can execute arbitrary code embedded in the
        # file; only point this at trusted weight files.
        with PathManager.open(weights_file, "rb") as handle:
            weights = pickle.load(handle)
        with PathManager.open(bias_file, "rb") as handle:
            bias = pickle.load(handle)

        self.out_dim = bias.shape[0]
        self.lc = nn.Linear(in_dim, self.out_dim)
        self.lc.weight.data.copy_(torch.from_numpy(weights))
        self.lc.bias.data.copy_(torch.from_numpy(bias))
コード例 #2
0
ファイル: env.py プロジェクト: hila-chefer/NLP_Final_Project
def import_user_module(user_dir: str):
    """Given a user dir, this function imports it as a module.

    This user_module is expected to have an __init__.py at its root.
    You can use import_files to import your python files easily in
    __init__.py

    The parent directory of ``user_dir`` is prepended to ``sys.path`` and the
    last path component is imported. The result is also registered as
    ``sys.modules["mmf_user_dir"]``. A module that was already registered
    under the same name is temporarily evicted and restored afterwards,
    including when the import raises.

    Args:
        user_dir (str): directory which has to be imported
    """
    logger = logging.getLogger(__name__)
    if user_dir:
        user_dir = get_absolute_path(user_dir)
        module_parent, module_name = os.path.split(user_dir)

        # Evict any module already registered under this name so that the
        # import below loads the user's package; keep it for restoration.
        module_bak = sys.modules.pop(module_name, None)

        logger.info(f"Importing from {user_dir}")
        # The parent directory is left on sys.path after the import.
        sys.path.insert(0, module_parent)
        try:
            importlib.import_module(module_name)
        except BaseException:
            # Without this, a failed import would permanently drop the
            # previously registered module from sys.modules.
            if module_bak is not None:
                sys.modules[module_name] = module_bak
            raise

        sys.modules["mmf_user_dir"] = sys.modules[module_name]
        if module_bak is not None and module_name != "mmf_user_dir":
            sys.modules[module_name] = module_bak
コード例 #3
0
ファイル: annotation_database.py プロジェクト: n-zhang/mmf
 def __init__(self, config, path, *args, **kwargs):
     """Set up empty metadata and load the annotation database at ``path``.

     Args:
         config: dataset configuration, stored as ``self.config``.
         path: annotation location; passed through ``get_absolute_path``
             before ``load_annotation_db`` is called with it.
     """
     super().__init__()
     self.metadata = {}
     self.config = config
     self.start_idx = 0
     self.load_annotation_db(get_absolute_path(path))
コード例 #4
0
    def __init__(self,
                 config,
                 path,
                 annotation_db=None,
                 feature_key=None,
                 *args,
                 **kwargs):
        """Create one ``FeatureReader`` per comma-separated feature directory.

        Args:
            config: dataset config. Reads ``feature_key`` (default
                ``"feature_path"``), ``fast_read``, ``depth_first``,
                ``max_features`` and ``return_features_info``.
            path (str): one or more feature directories separated by commas;
                whitespace around each entry is ignored.
            annotation_db: optional annotation database, stored on the
                instance.
            feature_key (str, optional): overrides ``config.feature_key``
                when truthy.
        """
        super().__init__(config, path, annotation_db, *args, **kwargs)
        self.feature_readers = []
        self.feature_dict = {}
        self.feature_key = config.get("feature_key", "feature_path")
        self.feature_key = feature_key if feature_key else self.feature_key
        self._fast_read = config.get("fast_read", False)

        # Strip each entry so that "a, b" yields ["a", "b"] rather than a
        # second directory name starting with a space.
        feature_dirs = [entry.strip() for entry in path.split(",")]

        for image_feature_dir in feature_dirs:
            feature_reader = FeatureReader(
                base_path=get_absolute_path(image_feature_dir),
                depth_first=config.get("depth_first", False),
                max_features=config.get("max_features", 100),
            )
            self.feature_readers.append(feature_reader)

        self.paths = feature_dirs
        self.annotation_db = annotation_db
        self._should_return_info = config.get("return_features_info", True)

        if self._fast_read:
            joined_dirs = ", ".join(feature_dirs)
            logger.info(f"Fast reading features from {joined_dirs}")
            logger.info("Hold tight, this may take a while...")
            self._threaded_read()
コード例 #5
0
ファイル: text.py プロジェクト: slbinilkumar/mmf
    def __init__(self, vocab_file, data_dir=None):
        """Load the word list from ``vocab_file`` and build the lookup tables.

        Args:
            vocab_file: vocabulary file path. When it is relative and
                ``data_dir`` is given, it is joined with ``data_dir`` and made
                absolute.
            data_dir: optional base directory for a relative ``vocab_file``.

        Raises:
            RuntimeError: if the resolved vocabulary file does not exist.
        """
        is_relative = not os.path.isabs(vocab_file)
        if is_relative and data_dir is not None:
            joined = os.path.join(data_dir, vocab_file)
            vocab_file = get_absolute_path(joined)

        vocab_exists = PathManager.exists(vocab_file)
        if not vocab_exists:
            message = f"Vocab file {vocab_file} for vocab dict doesn't exist"
            raise RuntimeError(message)

        self.word_list = load_str_list(vocab_file)
        self._build()
コード例 #6
0
    def __init__(self,
                 vocab_file,
                 embedding_file,
                 data_dir=None,
                 *args,
                 **kwargs):
        """Use this vocab class when you have a custom vocab as well as a
        custom embeddings file.

        This will inherit vocab class, so you will get predefined tokens with
        this one.

        IMPORTANT: To init your embedding, get your vectors from this class's
        object by calling `get_vectors` function

        Parameters
        ----------
        vocab_file : str
            Path of custom vocabulary
        embedding_file : str
            Path to custom embedding initialization file, loaded with
            ``np.load``. Row ``j`` is copied to vocab index ``j + 4``.
        data_dir : str
            Path to data directory if embedding file is not an absolute path.
            Default: None

        Raises
        ------
        RuntimeError
            If the embedding file does not exist, or has fewer rows than the
            vocabulary has entries after the first 4.
        """
        super().__init__(vocab_file)
        self.type = "custom"

        if not os.path.isabs(embedding_file) and data_dir is not None:
            embedding_file = os.path.join(data_dir, embedding_file)
            embedding_file = get_absolute_path(embedding_file)

        if not PathManager.exists(embedding_file):
            from mmf.common.registry import registry

            writer = registry.get("writer")
            error = "Embedding file path %s doesn't exist" % embedding_file
            if writer is not None:
                writer.write(error, "error")
            raise RuntimeError(error)

        embedding_vectors = torch.from_numpy(np.load(embedding_file))

        vocab_size = self.get_size()
        # The first 4 rows are filled with constants below; presumably these
        # are the predefined special tokens of the parent vocab.
        num_fixed = 4
        num_words = vocab_size - num_fixed

        # Fail with a clear message instead of an IndexError part-way through
        # filling the vectors.
        if embedding_vectors.size(0) < num_words:
            raise RuntimeError(
                "Embedding file %s has %d rows but the vocabulary needs %d"
                % (embedding_file, embedding_vectors.size(0), num_words))

        self.vectors = torch.FloatTensor(vocab_size,
                                         embedding_vectors.size(1))

        # Rows 0..3 get constant vectors with values 0.0, 0.1, 0.2 and 0.3.
        for i in range(0, num_fixed):
            self.vectors[i] = torch.ones_like(self.vectors[i]) * 0.1 * i

        # Remaining rows are copied in order from the file (cast to float32
        # by the slice assignment); extra rows in the file are ignored.
        self.vectors[num_fixed:] = embedding_vectors[:num_words]
コード例 #7
0
 def build(self, config, *args, **kwargs):
     """Check that the Hateful Memes data was downloaded manually, then build.

     Raises:
         AssertionError: if ``<data_dir>/annotations/train.jsonl`` does not
             exist.
     """
     # First, check whether manual downloads have been performed
     data_dir = get_mmf_env(key="data_dir")
     test_path = get_absolute_path(
         os.path.join(
             data_dir,
             "annotations",
             "train.jsonl",
         ))
     # NOTE: only train.jsonl is checked; the other files are assumed present.
     # Raise explicitly instead of using ``assert`` so the check still runs
     # under ``python -O``. The exception type is unchanged.
     if not PathManager.exists(test_path):
         raise AssertionError(
             "Hateful Memes Dataset doesn't do automatic downloads; please "
             "follow instructions at https://fb.me/hm_prerequisites")
     super().build(config, *args, **kwargs)
コード例 #8
0
ファイル: download.py プロジェクト: zeta1999/mmf
def download_pretrained_model(model_name, *args, **kwargs):
    """Download the resources registered for ``model_name`` in the MMF zoo.

    The zoo config is loaded from the ``model_zoo`` env entry. If the entry
    for ``model_name`` lacks ``version`` or ``resources``, its ``defaults``
    child is used and the download path gets a ``.defaults`` suffix. Any
    ``zoo_requirements`` (a string or a list) are downloaded first,
    recursively. Only the master process downloads; every process then calls
    ``synchronize()``.

    Args:
        model_name (str): dotted key of the model in the zoo config.

    Returns:
        str: ``<data_dir>/models/<model_name>`` (or with ``.defaults``
        appended when the defaults entry was used).

    Raises:
        KeyError: if ``model_name`` is not defined in the zoo.
        omegaconf.errors.OmegaConfBaseException: if the zoo lookup fails, or
            no ``defaults`` are provided when they are needed.
    """
    import omegaconf
    from omegaconf import OmegaConf

    from mmf.utils.configuration import load_yaml, get_mmf_env

    model_zoo = load_yaml(get_mmf_env(key="model_zoo"))
    OmegaConf.set_struct(model_zoo, True)
    OmegaConf.set_readonly(model_zoo, True)

    data_dir = get_absolute_path(get_mmf_env("data_dir"))
    model_data_dir = os.path.join(data_dir, "models")
    download_path = os.path.join(model_data_dir, model_name)

    try:
        model_config = OmegaConf.select(model_zoo, model_name)
    except omegaconf.errors.OmegaConfBaseException as e:
        print(f"No such model name {model_name} defined in mmf zoo")
        raise e

    # OmegaConf.select can return None for a missing key instead of raising.
    # Without this check the membership test below fails with an opaque
    # TypeError.
    if model_config is None:
        print(f"No such model name {model_name} defined in mmf zoo")
        raise KeyError(model_name)

    if "version" not in model_config or "resources" not in model_config:
        # Version and Resources are not present time to try the defaults
        try:
            model_config = model_config.defaults
            download_path = os.path.join(model_data_dir, model_name + ".defaults")
        except omegaconf.errors.OmegaConfBaseException as e:
            print(
                f"Model name {model_name} doesn't specify 'resources' and 'version' "
                "while no defaults have been provided"
            )
            raise e

    # Download requirements if any specified by "zoo_requirements" field
    # This can either be a list or a string
    if "zoo_requirements" in model_config:
        requirements = model_config.zoo_requirements
        if isinstance(requirements, str):
            requirements = [requirements]
        for item in requirements:
            download_pretrained_model(item, *args, **kwargs)

    version = model_config.version
    resources = model_config.resources

    if is_master():
        download_resources(resources, download_path, version)
    synchronize()

    return download_path
コード例 #9
0
ファイル: builder.py プロジェクト: lilyli2004/mmf
 def build(self, config, *args, **kwargs):
     """Check that the dataset's annotations were downloaded manually, then build.

     Raises:
         AssertionError: if
             ``<data_dir>/datasets/<dataset_name>/defaults/annotations/train.jsonl``
             does not exist.
     """
     # First, check whether manual downloads have been performed
     data_dir = get_mmf_env(key="data_dir")
     test_path = get_absolute_path(
         os.path.join(
             data_dir,
             "datasets",
             self.dataset_name,
             "defaults",
             "annotations",
             "train.jsonl",
         ))
     # NOTE: only train.jsonl is checked; the other files are assumed present.
     # Raise explicitly instead of using ``assert`` so the check still runs
     # under ``python -O``. The exception type is unchanged, and the message
     # now says what was missing.
     if not PathManager.exists(test_path):
         raise AssertionError(
             f"Manually downloaded annotations not found at {test_path}")
     super().build(config, *args, **kwargs)
コード例 #10
0
def load_yaml(f):
    """Load a YAML config with OmegaConf and recursively merge its includes.

    ``f`` is first resolved with ``get_absolute_path``. If loading that path
    raises ``FileNotFoundError``, ``f`` is retried relative to the MMF root.
    Each entry of the top-level ``includes`` list is looked up relative to the
    MMF root, falling back to the directory of the including file, and is
    loaded recursively. Includes are merged in order, and the including
    file's own keys are merged last so they take precedence.

    Args:
        f (str): path of the YAML file to load.

    Returns:
        The merged OmegaConf config without the ``includes`` key. An empty
        config is used when the file loads as ``None``.

    Raises:
        FileNotFoundError: if ``f`` is found by neither lookup.
        AttributeError: if ``includes`` is not a list of paths.
    """
    # Convert to absolute path for loading includes
    abs_f = get_absolute_path(f)

    try:
        mapping = OmegaConf.load(abs_f)
        f = abs_f
    except FileNotFoundError as e:
        # Check if this file might be relative to root?
        # TODO: Later test if this can be removed
        relative = os.path.abspath(os.path.join(get_mmf_root(), f))
        if not PathManager.isfile(relative):
            raise e
        else:
            f = relative
            mapping = OmegaConf.load(f)

    if mapping is None:
        mapping = OmegaConf.create()

    includes = mapping.get("includes", [])

    # A plain string is also a Sequence. Without the explicit str check, every
    # character would be treated as a separate include path.
    if isinstance(includes, str) or not isinstance(
            includes, collections.abc.Sequence):
        raise AttributeError("Includes must be a list, {} provided".format(
            type(includes)))

    include_mapping = OmegaConf.create()

    mmf_root_dir = get_mmf_root()

    for include in includes:
        original_include_path = include
        include = os.path.join(mmf_root_dir, include)

        # If path doesn't exist relative to MMF root, try relative to current file
        if not PathManager.exists(include):
            include = os.path.join(os.path.dirname(f), original_include_path)

        current_include_mapping = load_yaml(include)
        include_mapping = OmegaConf.merge(include_mapping,
                                          current_include_mapping)

    mapping.pop("includes", None)

    # Later arguments win in OmegaConf.merge, so the file's own keys override
    # anything that came from its includes.
    mapping = OmegaConf.merge(include_mapping, mapping)

    return mapping
コード例 #11
0
    def _download_requirement(self,
                              config,
                              requirement_key,
                              requirement_variation="defaults"):
        """Download the zoo resources for a single dataset requirement.

        Args:
            config: dataset config. When the resources are a mapping,
                ``use_features`` / ``use_images`` decide whether the
                ``features`` / ``images`` groups are downloaded. ``annotations``
                and ``extras`` are always downloaded.
            requirement_key (str): ``<dataset>`` or ``<dataset>.<variation>``.
            requirement_variation (str): variation to use when the key does
                not carry one.
        """
        version, resources = get_zoo_config(requirement_key,
                                            requirement_variation,
                                            self.zoo_config_path,
                                            self.zoo_type)
        if resources is None:
            return

        dataset_name, *key_rest = requirement_key.split(".")
        # A variation embedded in the key overrides the argument.
        dataset_variation = key_rest[0] if key_rest else requirement_variation

        # Use the root env data_dir so downloads don't mix with the
        # per-dataset download roots.
        download_path = get_absolute_path(
            os.path.join(get_mmf_env("data_dir"), "datasets", dataset_name,
                         dataset_variation))

        if isinstance(resources, collections.abc.Mapping):
            want_features = config.get("use_features", False)
            want_images = config.get("use_images", False)

            if want_features:
                self._download_based_on_attribute(resources, download_path,
                                                  version, "features")
            if want_images:
                self._download_based_on_attribute(resources, download_path,
                                                  version, "images")

            self._download_based_on_attribute(resources, download_path,
                                              version, "annotations")
            extras = resources.get("extras", [])
            self._download_resources(extras, download_path, version)
        else:
            self._download_resources(resources, download_path, version)
コード例 #12
0
ファイル: env.py プロジェクト: EXYNOS-999/DeepMeMes
def import_user_module(user_dir: str, no_print: bool = False):
    """Given a user dir, this function imports it as a module.

    This user_module is expected to have an __init__.py at its root.
    You can use import_files to import your python files easily in
    __init__.py

    Nothing happens if a module with the directory's name is already in
    ``sys.modules``. Otherwise the parent directory is put on ``sys.path``
    only for the duration of the import.

    Args:
        user_dir (str): directory which has to be imported
        no_print (bool): This function won't print anything if set to true
    """
    if user_dir:
        user_dir = get_absolute_path(user_dir)
        module_parent, module_name = os.path.split(user_dir)

        if module_name not in sys.modules:
            sys.path.insert(0, module_parent)
            if not no_print:
                print(f"Importing user_dir from {user_dir}")
            try:
                importlib.import_module(module_name)
            finally:
                # Undo the temporary sys.path change even if the import
                # fails. Remove by value rather than pop(0): the imported
                # package may itself have inserted entries at index 0.
                try:
                    sys.path.remove(module_parent)
                except ValueError:
                    pass
コード例 #13
0
    def __init__(
        self,
        config,
        path,
        annotation_db=None,
        transform=None,
        loader=default_loader,
        is_valid_file=None,
        image_key=None,
        *args,
        **kwargs
    ):
        """Initialize an instance of ImageDatabase

        Args:
            config (DictConfig): Config object from dataset_config
            path (str): Path to images folder; resolved with
                ``get_absolute_path`` and stored as ``self.base_path``.
            annotation_db (AnnotationDB, optional): Annotation DB used to
                figure out image paths. Defaults to None.
            transform (callable, optional): Transform to be called upon loaded image.
                Defaults to None.
            loader (callable, optional): Custom loader for image which given a path
                returns a PIL Image. Defaults to torchvision's default loader.
            is_valid_file (callable, optional): Custom callable to filter out invalid
                files. If image is invalid, {"images": []} will be returned which you
                can filter out in your dataset. Defaults to None.
            image_key (str, optional): Key that points to image path in annotation db.
                When truthy it takes precedence over ``config.image_key``. If neither
                is set, ImageDatabase will make some intelligent guesses about the
                possible key. Defaults to None.
        """
        super().__init__()
        self.config = config
        self.base_path = get_absolute_path(path)
        self.annotation_db = annotation_db
        self.transform = transform
        self.loader = loader

        configured_key = config.get("image_key", None)
        if image_key:
            self.image_key = image_key
        else:
            self.image_key = configured_key

        self.is_valid_file = is_valid_file
コード例 #14
0
    def __init__(self, config, dataset_type, imdb_file_index, *args, **kwargs):
        """Visual Genome dataset with optional scene-graph and region data.

        Args:
            config: dataset config. Reads ``return_scene_graph``,
                ``return_objects``, ``return_relationships``,
                ``return_region_descriptions``, ``max_features`` and,
                optionally, ``no_unk`` (default False).
            dataset_type: split identifier forwarded to the base dataset.
            imdb_file_index: imdb file index forwarded to the base dataset.
        """
        super().__init__(config,
                         dataset_type,
                         imdb_file_index,
                         dataset_name="visual_genome",
                         *args,
                         **kwargs)

        self._return_scene_graph = config.return_scene_graph
        self._return_objects = config.return_objects
        self._return_relationships = config.return_relationships
        self._return_region_descriptions = config.return_region_descriptions
        self._no_unk = config.get("no_unk", False)
        self.scene_graph_db = None
        self.region_descriptions_db = None
        self.image_metadata_db = None
        self._max_feature = config.max_features

        # Objects and relationships are read from the scene graph, so any of
        # these flags requires building it.
        build_scene_graph_db = (self._return_scene_graph
                                or self._return_objects
                                or self._return_relationships)

        if self._return_region_descriptions:
            self.region_descriptions_db = self.build_region_descriptions_db()
            self.image_metadata_db = self.build_image_metadata_db()

        if build_scene_graph_db:
            # Only build_scene_graph_db() determines the stored database. A
            # separate SceneGraphDatabase built from
            # <data_dir>/scene_graph_files[dataset_type][imdb_file_index] would
            # be overwritten immediately, loading the data twice for nothing.
            # NOTE(review): if that explicit path is meant to win, route it
            # through build_scene_graph_db instead.
            self.scene_graph_db = self.build_scene_graph_db()
コード例 #15
0
def import_user_module(user_dir: str):
    """Given a user dir, this function imports it as a module.

    This user_module is expected to have an __init__.py at its root.
    You can use import_files to import your python files easily in
    __init__.py

    A single ``.py`` file is also accepted. The path is first tried as a
    dotted module name. If that cannot be found, its absolute parent
    directory is prepended to ``sys.path`` and the last path component is
    imported. The imported module is aliased as ``sys.modules["mmf_user_dir"]``
    and ``config.env.user_dir`` in the registry is set to ``user_dir``. Once
    the registry flag ``__mmf_user_dir_imported__`` is set, later calls return
    early.

    Args:
        user_dir (str): directory which has to be imported
    """
    from mmf.common.registry import registry
    from mmf.utils.general import get_absolute_path  # noqa

    logger = logging.getLogger(__name__)
    if user_dir:
        if registry.get("__mmf_user_dir_imported__", no_warning=True):
            logger.info(f"User dir {user_dir} already imported. Skipping.")
            return

        # Normalize separators and drop trailing slashes and "." components.
        # "my_dir/" would otherwise give a dotted path ending in "." and an
        # empty module name from os.path.split.
        user_dir = os.path.normpath(user_dir)

        # Allow loading of files as user source
        if user_dir.endswith(".py"):
            user_dir = user_dir[:-3]

        dot_path = ".".join(user_dir.split(os.path.sep))
        # In case of abspath which start from "/" the first char
        # will be "." which turns it into relative module which
        # find_spec doesn't like
        if os.path.isabs(user_dir):
            dot_path = dot_path[1:]

        try:
            dot_spec = importlib.util.find_spec(dot_path)
        except (ImportError, ValueError):
            # ModuleNotFoundError is an ImportError. Dotted paths that still
            # start with "." (e.g. from "../my_dir") are rejected as relative
            # names with ImportError (ValueError before Python 3.9); fall back
            # to importing through the absolute path below.
            dot_spec = None
        abs_user_dir = get_absolute_path(user_dir)
        module_parent, module_name = os.path.split(abs_user_dir)

        # If dot path is found in sys.modules, or path can be directly
        # be imported, we don't need to play jugglery with actual path
        if dot_path in sys.modules or dot_spec is not None:
            module_name = dot_path
        else:
            user_dir = abs_user_dir

        logger.info(f"Importing from {user_dir}")
        if module_name != dot_path:
            # Since dot path hasn't been found or can't be imported,
            # we can try importing the module by changing sys path
            # to the parent
            sys.path.insert(0, module_parent)

        importlib.import_module(module_name)
        sys.modules["mmf_user_dir"] = sys.modules[module_name]

        # Register config for user's model and dataset config
        # relative path resolution
        config = registry.get("config")
        if config is None:
            registry.register(
                "config", OmegaConf.create({"env": {"user_dir": user_dir}})
            )
        else:
            with open_dict(config):
                config.env.user_dir = user_dir

        registry.register("__mmf_user_dir_imported__", True)
コード例 #16
0
ファイル: vocab.py プロジェクト: facebookresearch/mmf
    def __init__(self,
                 vocab_file=None,
                 embedding_dim=300,
                 data_dir=None,
                 *args,
                 **kwargs):
        """Vocab class to be used when you want to train word embeddings from
        scratch based on a custom vocab. This will initialize random
        vectors for the vocabulary you pass. Get the vectors using
        `get_vectors` function. This will also create random embeddings for
        some predefined words like PAD - <pad>, SOS - <s>, EOS - </s>,
        UNK - <unk>.

        Parameters
        ----------
        vocab_file : str
            Path of the vocabulary file containing one word per line
        embedding_dim : int
            Size of the embedding
        data_dir : str
            Directory a relative ``vocab_file`` is joined with (and then made
            absolute). Default: None

        Raises
        ------
        RuntimeError
            If ``vocab_file`` is given but does not exist.
        """
        self.type = "base"
        self.word_dict = {}
        self.itos = {}

        self.itos[self.PAD_INDEX] = self.PAD_TOKEN
        self.itos[self.SOS_INDEX] = self.SOS_TOKEN
        self.itos[self.EOS_INDEX] = self.EOS_TOKEN
        self.itos[self.UNK_INDEX] = self.UNK_TOKEN

        self.word_dict[self.SOS_TOKEN] = self.SOS_INDEX
        self.word_dict[self.EOS_TOKEN] = self.EOS_INDEX
        self.word_dict[self.PAD_TOKEN] = self.PAD_INDEX
        self.word_dict[self.UNK_TOKEN] = self.UNK_INDEX

        # Words from the file are numbered after the predefined tokens.
        index = len(self.itos.keys())

        self.total_predefined = len(self.itos.keys())

        if vocab_file is not None:
            if not os.path.isabs(vocab_file) and data_dir is not None:
                vocab_file = os.path.join(data_dir, vocab_file)
                vocab_file = get_absolute_path(vocab_file)

            if not PathManager.exists(vocab_file):
                raise RuntimeError("Vocab not found at " + vocab_file)

            with PathManager.open(vocab_file, "r") as f:
                for line in f:
                    self.itos[index] = line.strip()
                    self.word_dict[line.strip()] = index
                    index += 1

        # Re-assert the predefined indices in case the vocab file contained
        # any of the special tokens, which would have overwritten them above.
        self.word_dict[self.SOS_TOKEN] = self.SOS_INDEX
        self.word_dict[self.EOS_TOKEN] = self.EOS_INDEX
        self.word_dict[self.PAD_TOKEN] = self.PAD_INDEX
        self.word_dict[self.UNK_TOKEN] = self.UNK_INDEX
        # Return unk index by default
        self.stoi = defaultdict(self.get_unk_index)
        self.stoi.update(self.word_dict)

        # Draw vectors from N(0, 1), the same default as torch.nn.Embedding.
        # torch.FloatTensor(n, d) returns uninitialized memory, which may hold
        # arbitrary values including NaN/inf.
        self.vectors = torch.randn(self.get_size(), embedding_dim,
                                   dtype=torch.float32)
コード例 #17
0
ファイル: dataset.py プロジェクト: sonuagarwal1008/mmf
 def _get_absolute_path(self, scene_graph_file):
     """Resolve ``scene_graph_file`` against the MMF ``data_dir`` env value.

     Returns:
         str: ``get_absolute_path`` applied to
         ``os.path.join(<data_dir>, scene_graph_file)``.
     """
     joined = os.path.join(get_mmf_env(key="data_dir"), scene_graph_file)
     return get_absolute_path(joined)