Example #1
    def test_import_user_module_from_file(self):
        self.assertIsNone(registry.get_builder_class("always_one"))
        self.assertIsNone(registry.get_model_class("simple"))

        user_dir = self._get_user_dir()
        user_file = os.path.join(user_dir, "models", "simple.py")
        import_user_module(user_file)
        # Only the model should be found; the builder should still be None
        self.assertIsNone(registry.get_builder_class("always_one"))
        self.assertIsNotNone(registry.get_model_class("simple"))
        self.assertTrue("mmf_user_dir" in sys.modules)
        self.assertTrue(user_dir in get_mmf_env("user_dir"))
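The fixture file itself is not shown; below is a minimal sketch of what mmf_user_dir/models/simple.py would need to contain for the lookup above to flip from None to a class. The file layout and class name are assumptions.

# Hypothetical mmf_user_dir/models/simple.py: the decorator runs at import time,
# which is why import_user_module(user_file) makes the registry lookup succeed.
from mmf.common.registry import registry
from mmf.models.base_model import BaseModel


@registry.register_model("simple")
class SimpleModel(BaseModel):
    def build(self):
        ...  # construct submodules here

    def forward(self, sample_list):
        ...  # real models return a dict of outputs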
Example #2
    def test_import_user_module_from_directory_absolute(self, abs_path=True):
        # Make sure the modules are not available first
        self.assertIsNone(registry.get_builder_class("always_one"))
        self.assertIsNone(registry.get_model_class("simple"))
        self.assertFalse("mmf_user_dir" in sys.modules)

        # Now, import and test
        user_dir = self._get_user_dir(abs_path)
        import_user_module(user_dir)
        self.assertIsNotNone(registry.get_builder_class("always_one"))
        self.assertIsNotNone(registry.get_model_class("simple"))
        self.assertTrue("mmf_user_dir" in sys.modules)
        self.assertTrue(user_dir in get_mmf_env("user_dir"))
Example #3
 def setUp(self):
     test_utils.setup_proxy()
     setup_imports()
     model_name = "vilbert"
     args = test_utils.dummy_args(model=model_name)
     configuration = Configuration(args)
     config = configuration.get_config()
     model_class = registry.get_model_class(model_name)
     self.vision_feature_size = 1024
     self.vision_target_size = 1279
     config.model_config[model_name]["training_head_type"] = "pretraining"
     config.model_config[model_name][
         "visual_embedding_dim"] = self.vision_feature_size
     config.model_config[model_name][
         "v_feature_size"] = self.vision_feature_size
     config.model_config[model_name][
         "v_target_size"] = self.vision_target_size
     config.model_config[model_name]["dynamic_attention"] = False
     self.pretrain_model = model_class(config.model_config[model_name])
     self.pretrain_model.build()
     config.model_config[model_name][
         "training_head_type"] = "classification"
     config.model_config[model_name]["num_labels"] = 2
     self.finetune_model = model_class(config.model_config[model_name])
     self.finetune_model.build()
Example #4
    def from_pretrained(cls, model_name_or_path, *args, **kwargs):
        # If the path does not exist, treat the argument as a pretrained model
        # key; otherwise, try to load the checkpoint from the path
        if not PathManager.exists(model_name_or_path):
            model_key = model_name_or_path.split(".")[0]
            model_cls = registry.get_model_class(model_key)
            assert model_cls == cls, (
                f"Incorrect pretrained model key {model_name_or_path} "
                f"for class {cls.__name__}"
            )
        output = load_pretrained_model(model_name_or_path, *args, **kwargs)
        config, checkpoint, full_config = (
            output["config"],
            output["checkpoint"],
            output["full_config"],
        )

        # Save original config for state reset later
        config_temp_holder = registry.get("config")
        # Register full config from checkpoint when loading the model
        registry.register("config", full_config)

        # Some models need registry updates to load a pretrained model.
        # If they have this method, call it so they can update accordingly.
        if hasattr(cls, "update_registry_for_pretrained"):
            cls.update_registry_for_pretrained(config, checkpoint, output)

        instance = cls(config)
        instance.is_pretrained = True
        instance.build()
        incompatible_keys = instance.load_state_dict(checkpoint, strict=False)

        # The model has loaded, reset the state
        registry.register("config", config_temp_holder)

        if len(incompatible_keys.missing_keys) != 0:
            logger.warning(
                f"Missing keys {incompatible_keys.missing_keys} in the"
                + " checkpoint.\n"
                + "If this is not your checkpoint, please open up an "
                + "issue on MMF GitHub. \n"
                + f"Unexpected keys if any: {incompatible_keys.unexpected_keys}"
            )

        if len(incompatible_keys.unexpected_keys) != 0:
            logger.warning(
                "Unexpected keys in state dict: "
                + f"{incompatible_keys.unexpected_keys} \n"
                + "This is usually not a problem with pretrained models, but "
                + "if this is your own model, please double check. \n"
                + "If you think this is an issue, please open up a "
                + "bug at MMF GitHub."
            )

        instance.eval()

        return instance
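For reference, a hedged sketch of how this classmethod is typically called: the zoo key is the one used in Example #9 further down, while the local path alternative is hypothetical.

from mmf.common.registry import registry

# Assumes MMF's models have been imported (e.g. via setup_imports()) so the
# "visual_bert" key resolves in the registry.
model_cls = registry.get_model_class("visual_bert")
model = model_cls.from_pretrained("visual_bert.pretrained.coco")  # zoo key
# model = model_cls.from_pretrained("save/visual_bert_final.pth")  # hypothetical local path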
Example #5
 def setUp(self):
     setup_imports()
     self.model_name = "mmf_transformer"
     args = test_utils.dummy_args(model=self.model_name)
     configuration = Configuration(args)
     self.config = configuration.get_config()
     self.model_class = registry.get_model_class(self.model_name)
     self.finetune_model = self.model_class(
         self.config.model_config[self.model_name])
     self.finetune_model.build()
Example #6
 def setUp(self):
     setup_imports()
     model_name = "mmbt"
     args = test_utils.dummy_args(model=model_name)
     configuration = Configuration(args)
     config = configuration.get_config()
     model_class = registry.get_model_class(model_name)
     config.model_config[model_name]["training_head_type"] = "classification"
     config.model_config[model_name]["num_labels"] = 2
     self.finetune_model = model_class(config.model_config[model_name])
     self.finetune_model.build()
Example #7
    def _load_from_zoo(self, file):
        ckpt_config = self.trainer.config.checkpoint
        zoo_ckpt = load_pretrained_model(file)

        # If zoo_config_override, load the model directly using `from_pretrained`
        if ckpt_config.zoo_config_override:
            model_cls = registry.get_model_class(self.trainer.config.model)
            self.trainer.model = model_cls.from_pretrained(ckpt_config.resume_zoo)
            self.trainer.config.model_config = zoo_ckpt["full_config"].model_config
            return None, False
        else:
            return self.upgrade_state_dict(zoo_ckpt["checkpoint"]), True
Example #8
File: build.py  Project: zpppy/mmf
def build_model(config):
    model_name = config.model

    model_class = registry.get_model_class(model_name)

    if model_class is None:
        registry.get("writer").write("No model registered for name: %s" % model_name)
        raise RuntimeError("No model registered for name: %s" % model_name)
    model = model_class(config)

    if hasattr(model, "build"):
        model.load_requirements()
        model.build()
        model.init_losses()

    return model
Example #9
    def __init__(self):
        self._init_processors()
        self.visual_bert = registry.get_model_class(
            "visual_bert").from_pretrained("visual_bert.pretrained.coco")

        # Set this option so that the model only outputs hidden states
        self.visual_bert.model.output_hidden_states = True

        self.visual_bert.model.to("cuda")
        self.visual_bert.model.eval()

        # Set this option so that losses are not pushed into the output
        self.visual_bert.training_head_type = "finetuning"

        self.detection_model = self._build_detection_model()
Example #10
def build_model(config):
    model_name = config.model

    model_class = registry.get_model_class(model_name)

    if model_class is None:
        raise RuntimeError(f"No model registered for name: {model_name}")
    model = model_class(config)

    if hasattr(model, "build"):
        model.load_requirements()
        model.build()
        model.init_losses()

    return model
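build_model only works for names that were registered beforehand. Below is a minimal registration sketch in MMF's decorator style; the model name, class name, and config path are assumptions, and config_path is the hook consumed by _build_model_config in Example #11.

from mmf.common.registry import registry
from mmf.models.base_model import BaseModel


@registry.register_model("my_model")  # hypothetical key; becomes config.model
class MyModel(BaseModel):
    @classmethod
    def config_path(cls):
        # Hypothetical default config location for this model
        return "configs/models/my_model/defaults.yaml"

    def build(self):
        ...  # construct submodules; called by build_model after instantiation

    def forward(self, sample_list):
        ...  # real models return a dict of outputs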
Example #11
    def _build_model_config(self, config):
        model = config.model
        if model is None:
            raise KeyError("Required argument 'model' not passed")
        model_cls = registry.get_model_class(model)

        if model_cls is None:
            warning = f"No model named '{model}' has been registered"
            warnings.warn(warning)
            return OmegaConf.create()

        default_model_config_path = model_cls.config_path()

        if default_model_config_path is None:
            warning = "Model {}'s class has no default configuration provided".format(
                model)
            warnings.warn(warning)
            return OmegaConf.create()

        return load_yaml(default_model_config_path)
Example #12
    def from_pretrained(cls, model_name, *args, **kwargs):
        model_key = model_name.split(".")[0]
        model_cls = registry.get_model_class(model_key)
        assert (
            model_cls == cls
        ), f"Incorrect pretrained model key {model_name} for class {cls.__name__}"
        output = load_pretrained_model(model_name, *args, **kwargs)
        config, checkpoint = output["config"], output["checkpoint"]

        # Some models need registry updates to load a pretrained model.
        # If they have this method, call it so they can update accordingly.
        if hasattr(cls, "update_registry_for_pretrained"):
            cls.update_registry_for_pretrained(config, checkpoint, output)

        instance = cls(config)
        instance.is_pretrained = True
        instance.build()
        incompatible_keys = instance.load_state_dict(checkpoint, strict=False)

        if len(incompatible_keys.missing_keys) != 0:
            logger.warning(
                f"Missing keys {incompatible_keys.missing_keys} in the"
                + " checkpoint.\n"
                + "If this is not your checkpoint, please open up an "
                + "issue on MMF GitHub. \n"
                + f"Unexpected keys if any: {incompatible_keys.unexpected_keys}"
            )

        if len(incompatible_keys.unexpected_keys) != 0:
            logger.warning(
                "Unexpected keys in state dict: "
                + f"{incompatible_keys.unexpected_keys} \n"
                + "This is usually not a problem with pretrained models, but "
                + "if this is your own model, please double check. \n"
                + "If you think this is an issue, please open up a "
                + "bug at MMF GitHub."
            )

        instance.eval()

        return instance
Example #13
    def from_pretrained(cls, model_name, *args, **kwargs):
        model_key = model_name.split(".")[0]
        model_cls = registry.get_model_class(model_key)
        assert (
            model_cls == cls
        ), f"Incorrect pretrained model key {model_name} for class {cls.__name__}"
        output = load_pretrained_model(model_name, *args, **kwargs)
        config, checkpoint = output["config"], output["checkpoint"]

        # Some models need registry updates to load a pretrained model.
        # If they have this method, call it so they can update accordingly.
        if hasattr(cls, "update_registry_for_pretrained"):
            cls.update_registry_for_pretrained(config, checkpoint, output)

        instance = cls(config)
        instance.is_pretrained = True
        instance.build()
        instance.load_state_dict(checkpoint)
        instance.eval()

        return instance
Example #14
def build_lightning_model(
    config: Union[DictConfig, "mmf.models.base_model.BaseModel.Config"],
    checkpoint_path: str = None,
) -> "mmf.models.base_model.BaseModel":
    from mmf.models.base_model import BaseModel

    if not checkpoint_path:
        model = build_model(config)
        model.is_pl_enabled = True
        return model

    # If it is not an OmegaConf object, create the object
    if not isinstance(config, DictConfig) and isinstance(config, BaseModel.Config):
        config = OmegaConf.structured(config)

    model_name = config.model
    model_class = registry.get_model_class(model_name)

    if model_class is None:
        raise RuntimeError(f"No model registered for name: {model_name}")

    """ model.build is called inside on_load_checkpoint as suggested here:
    https://github.com/PyTorchLightning/pytorch-lightning/issues/5410
    """

    if is_main():
        model_class.load_requirements(model_class, config=config)
        model = model_class.load_from_checkpoint(
            checkpoint_path, config=config, strict=False
        )
        synchronize()
    else:
        synchronize()
        model = model_class.load_from_checkpoint(
            checkpoint_path, config=config, strict=False
        )

    model.init_losses()
    model.is_pl_enabled = True
    return model
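A hedged call-site sketch for the two branches above; `config` is assumed to be the same kind of model config that build_model accepts, the import path is assumed to be mmf/utils/build.py, and the checkpoint path is hypothetical.

from mmf.utils.build import build_lightning_model  # import path assumed

model = build_lightning_model(config)  # no checkpoint: plain build_model + is_pl_enabled
# With a checkpoint, weights are restored via Lightning's load_from_checkpoint
# and model.build() runs inside on_load_checkpoint:
# model = build_lightning_model(config, checkpoint_path="save/current.ckpt")  # hypothetical path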
Example #15
def build_model(
    config: Union[DictConfig, "mmf.models.base_model.BaseModel.Config"]
) -> "mmf.models.base_model.BaseModel":
    from mmf.models.base_model import BaseModel

    # If it is not an OmegaConf object, create the object
    if not isinstance(config, DictConfig) and isinstance(config, BaseModel.Config):
        config = OmegaConf.structured(config)

    model_name = config.model
    model_class = registry.get_model_class(model_name)

    if model_class is None:
        raise RuntimeError(f"No model registered for name: {model_name}")
    model = model_class(config)

    if hasattr(model, "build"):
        model.load_requirements()
        model.build()
        model.init_losses()

    return model
Example #16
def build_model(
    config: Union[DictConfig, "mmf.models.base_model.BaseModel.Config"]
) -> "mmf.models.base_model.BaseModel":
    from mmf.models.base_model import BaseModel

    # If it is not an OmegaConf object, create the object
    if not isinstance(config, DictConfig) and isinstance(
            config, BaseModel.Config):
        config = OmegaConf.structured(config)

    model_name = config.model
    model_class = registry.get_model_class(model_name)

    if model_class is None:
        raise RuntimeError(f"No model registered for name: {model_name}")
    model = model_class(config)

    if hasattr(model, "build"):
        model.load_requirements()
        """ Model build involves checkpoint loading
        If the checkpoint is not available the underlying
        methods try to download it.
        Let master build the model (download the checkpoints) while
        other ranks wait for the sync message
        Once the master has downloaded the checkpoint and built the
        model it sends the sync message, completing the synchronization
        now other cores can proceed to build the model
        using already downloaded checkpoint.
        """
        if is_master():
            model.build()
            synchronize()
        else:
            synchronize()
            model.build()
        model.init_losses()

    return model
Example #17
import torch.nn as nn
import torch.nn.functional as F
import torch
from mmf.common.registry import registry
from transformers import RobertaModel

model_cls = registry.get_model_class("visual_bert")


class MyVisualBert(nn.Module):
    def __init__(self):
        super(MyVisualBert, self).__init__()
        self.visual_bert = model_cls.from_pretrained(
            "visual_bert.pretrained.coco").model
        # self.image_features = nn.Sequential(nn.Linear(2048, 2048), nn.ReLU(), nn.LayerNorm(2048))
        self.classifier = nn.Sequential(nn.Linear(768, 768), nn.ReLU(),
                                        nn.Dropout(.1), nn.LayerNorm(768),
                                        nn.Linear(768, 1))

    def forward(self, input_ids, attention_mask, visual_embeddings):

        device = input_ids.device
        # visual_embeddings = self.image_features(visual_embeddings)
        embs, pooled, _ = self.visual_bert.bert(
            input_ids=input_ids,  # text tokens
            attention_mask=attention_mask,  # attention mask (sentence + word tokens)
            visual_embeddings=visual_embeddings,  # 2048-dim region features
            visual_embeddings_type=torch.ones(
                (input_ids.shape[0], visual_embeddings.shape[1]),
                dtype=torch.long).to(device))

        # Assumed completion (the original snippet is truncated here):
        # score each example from the pooled representation.
        return self.classifier(pooled)