def test_import_user_module_from_file(self):
    self.assertIsNone(registry.get_builder_class("always_one"))
    self.assertIsNone(registry.get_model_class("simple"))

    user_dir = self._get_user_dir()
    user_file = os.path.join(user_dir, "models", "simple.py")
    import_user_module(user_file)

    # Only the model should be found; the builder should still be unregistered
    self.assertIsNone(registry.get_builder_class("always_one"))
    self.assertIsNotNone(registry.get_model_class("simple"))
    self.assertTrue("mmf_user_dir" in sys.modules)
    self.assertTrue(user_dir in get_mmf_env("user_dir"))

def test_import_user_module_from_directory_absolute(self, abs_path=True):
    # Make sure the modules are not available first
    self.assertIsNone(registry.get_builder_class("always_one"))
    self.assertIsNone(registry.get_model_class("simple"))
    self.assertFalse("mmf_user_dir" in sys.modules)

    # Now, import and test
    user_dir = self._get_user_dir(abs_path)
    import_user_module(user_dir)

    self.assertIsNotNone(registry.get_builder_class("always_one"))
    self.assertIsNotNone(registry.get_model_class("simple"))
    self.assertTrue("mmf_user_dir" in sys.modules)
    self.assertTrue(user_dir in get_mmf_env("user_dir"))

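# Hedged sketch (not from the original snippets): roughly what the user_dir
# fixture's models/simple.py has to contain for the tests above to pass -- a
# model registered under the name "simple" through the MMF registry decorator.
# The class body below is a placeholder, not the actual fixture.
from mmf.common.registry import registry
from mmf.models.base_model import BaseModel


@registry.register_model("simple")
class SimpleModel(BaseModel):
    @classmethod
    def config_path(cls):
        # No packaged default config for this toy model
        return None

    def build(self):
        pass

    def forward(self, sample_list):
        return {}
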
def setUp(self):
    test_utils.setup_proxy()
    setup_imports()
    model_name = "vilbert"
    args = test_utils.dummy_args(model=model_name)
    configuration = Configuration(args)
    config = configuration.get_config()
    model_class = registry.get_model_class(model_name)

    self.vision_feature_size = 1024
    self.vision_target_size = 1279

    config.model_config[model_name]["training_head_type"] = "pretraining"
    config.model_config[model_name]["visual_embedding_dim"] = self.vision_feature_size
    config.model_config[model_name]["v_feature_size"] = self.vision_feature_size
    config.model_config[model_name]["v_target_size"] = self.vision_target_size
    config.model_config[model_name]["dynamic_attention"] = False
    self.pretrain_model = model_class(config.model_config[model_name])
    self.pretrain_model.build()

    config.model_config[model_name]["training_head_type"] = "classification"
    config.model_config[model_name]["num_labels"] = 2
    self.finetune_model = model_class(config.model_config[model_name])
    self.finetune_model.build()

def from_pretrained(cls, model_name_or_path, *args, **kwargs):
    # Check if the path exists; if not, it is a pretrained model key,
    # otherwise we will try to load the checkpoint from the path
    if not PathManager.exists(model_name_or_path):
        model_key = model_name_or_path.split(".")[0]
        model_cls = registry.get_model_class(model_key)
        assert model_cls == cls, (
            f"Incorrect pretrained model key {model_name_or_path} "
            f"for class {cls.__name__}"
        )
    output = load_pretrained_model(model_name_or_path, *args, **kwargs)
    config, checkpoint, full_config = (
        output["config"],
        output["checkpoint"],
        output["full_config"],
    )

    # Save the original config so the registry state can be reset later
    config_temp_holder = registry.get("config")
    # Register the full config from the checkpoint while loading the model
    registry.register("config", full_config)

    # Some models need registry updates to load a pretrained model.
    # If they have this method, call it so they can update accordingly
    if hasattr(cls, "update_registry_for_pretrained"):
        cls.update_registry_for_pretrained(config, checkpoint, output)

    instance = cls(config)
    instance.is_pretrained = True
    instance.build()
    incompatible_keys = instance.load_state_dict(checkpoint, strict=False)

    # The model has loaded; reset the registry state
    registry.register("config", config_temp_holder)

    if len(incompatible_keys.missing_keys) != 0:
        logger.warning(
            f"Missing keys {incompatible_keys.missing_keys} in the"
            + " checkpoint.\n"
            + "If this is not your checkpoint, please open up an "
            + "issue on MMF GitHub. \n"
            + f"Unexpected keys if any: {incompatible_keys.unexpected_keys}"
        )

    if len(incompatible_keys.unexpected_keys) != 0:
        logger.warning(
            "Unexpected keys in state dict: "
            + f"{incompatible_keys.unexpected_keys} \n"
            + "This is usually not a problem with pretrained models, but "
            + "if this is your own model, please double check. \n"
            + "If you think this is an issue, please open up a "
            + "bug at MMF GitHub."
        )

    instance.eval()
    return instance

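# Hedged usage sketch (not from the original snippets): how the classmethod
# above is typically invoked. It assumes VisualBERT is registered under the key
# "visual_bert" and that the "visual_bert.pretrained.coco" zoo entry is
# available, as in the later snippets; the key before the first "." must
# resolve to the same class via registry.get_model_class.
from mmf.common.registry import registry

visual_bert_cls = registry.get_model_class("visual_bert")
model = visual_bert_cls.from_pretrained("visual_bert.pretrained.coco")
# The returned instance is built, has the checkpoint loaded with strict=False,
# and is already in eval() mode.
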
def setUp(self):
    setup_imports()
    self.model_name = "mmf_transformer"
    args = test_utils.dummy_args(model=self.model_name)
    configuration = Configuration(args)
    self.config = configuration.get_config()
    self.model_class = registry.get_model_class(self.model_name)
    self.finetune_model = self.model_class(self.config.model_config[self.model_name])
    self.finetune_model.build()

def setUp(self):
    setup_imports()
    model_name = "mmbt"
    args = test_utils.dummy_args(model=model_name)
    configuration = Configuration(args)
    config = configuration.get_config()
    model_class = registry.get_model_class(model_name)
    config.model_config[model_name]["training_head_type"] = "classification"
    config.model_config[model_name]["num_labels"] = 2
    self.finetune_model = model_class(config.model_config[model_name])
    self.finetune_model.build()

def _load_from_zoo(self, file):
    ckpt_config = self.trainer.config.checkpoint
    zoo_ckpt = load_pretrained_model(file)

    # If zoo_config_override, load the model directly using `from_pretrained`
    if ckpt_config.zoo_config_override:
        model_cls = registry.get_model_class(self.trainer.config.model)
        self.trainer.model = model_cls.from_pretrained(ckpt_config.resume_zoo)
        self.trainer.config.model_config = zoo_ckpt["full_config"].model_config
        return None, False
    else:
        return self.upgrade_state_dict(zoo_ckpt["checkpoint"]), True

def build_model(config):
    model_name = config.model

    model_class = registry.get_model_class(model_name)

    if model_class is None:
        registry.get("writer").write("No model registered for name: %s" % model_name)
    model = model_class(config)

    if hasattr(model, "build"):
        model.load_requirements()
        model.build()
        model.init_losses()

    return model

def __init__(self):
    self._init_processors()
    self.visual_bert = registry.get_model_class("visual_bert").from_pretrained(
        "visual_bert.pretrained.coco"
    )
    # Set this option so that the model only outputs hidden states
    self.visual_bert.model.output_hidden_states = True
    self.visual_bert.model.to("cuda")
    self.visual_bert.model.eval()
    # Set this option so that losses are not pushed into the output
    self.visual_bert.training_head_type = "finetuning"

    self.detection_model = self._build_detection_model()

def build_model(config):
    model_name = config.model

    model_class = registry.get_model_class(model_name)

    if model_class is None:
        raise RuntimeError(f"No model registered for name: {model_name}")
    model = model_class(config)

    if hasattr(model, "build"):
        model.load_requirements()
        model.build()
        model.init_losses()

    return model

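# Hedged usage sketch (not from the original snippets): feeding build_model a
# model config. The import path of build_model and the open_dict handling are
# assumptions; "mmbt" stands in for any registered model, and `config` is the
# object produced in the setUp snippets above.
from omegaconf import open_dict
from mmf.utils.build import build_model  # assumed module path

model_config = config.model_config["mmbt"]
with open_dict(model_config):
    model_config.model = "mmbt"  # build_model reads config.model for the registry lookup
model = build_model(model_config)
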
def _build_model_config(self, config):
    model = config.model
    if model is None:
        raise KeyError("Required argument 'model' not passed")
    model_cls = registry.get_model_class(model)

    if model_cls is None:
        warning = f"No model named '{model}' has been registered"
        warnings.warn(warning)
        return OmegaConf.create()

    default_model_config_path = model_cls.config_path()

    if default_model_config_path is None:
        warning = "Model {}'s class has no default configuration provided".format(
            model
        )
        warnings.warn(warning)
        return OmegaConf.create()

    return load_yaml(default_model_config_path)

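# Hedged sketch (not from the original snippets): the config_path() hook that
# _build_model_config relies on. A registered model usually points it at a
# YAML file shipped with the package; the path below is illustrative only.
@classmethod
def config_path(cls):
    return "configs/models/mmbt/pretrain.yaml"
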
def from_pretrained(cls, model_name, *args, **kwargs):
    model_key = model_name.split(".")[0]
    model_cls = registry.get_model_class(model_key)
    assert (
        model_cls == cls
    ), f"Incorrect pretrained model key {model_name} for class {cls.__name__}"
    output = load_pretrained_model(model_name, *args, **kwargs)
    config, checkpoint = output["config"], output["checkpoint"]

    # Some models need registry updates to load a pretrained model.
    # If they have this method, call it so they can update accordingly
    if hasattr(cls, "update_registry_for_pretrained"):
        cls.update_registry_for_pretrained(config, checkpoint, output)

    instance = cls(config)
    instance.is_pretrained = True
    instance.build()
    incompatible_keys = instance.load_state_dict(checkpoint, strict=False)

    if len(incompatible_keys.missing_keys) != 0:
        logger.warning(
            f"Missing keys {incompatible_keys.missing_keys} in the"
            + " checkpoint.\n"
            + "If this is not your checkpoint, please open up an "
            + "issue on MMF GitHub. \n"
            + f"Unexpected keys if any: {incompatible_keys.unexpected_keys}"
        )

    if len(incompatible_keys.unexpected_keys) != 0:
        logger.warning(
            "Unexpected keys in state dict: "
            + f"{incompatible_keys.unexpected_keys} \n"
            + "This is usually not a problem with pretrained models, but "
            + "if this is your own model, please double check. \n"
            + "If you think this is an issue, please open up a "
            + "bug at MMF GitHub."
        )

    instance.eval()
    return instance

def from_pretrained(cls, model_name, *args, **kwargs):
    model_key = model_name.split(".")[0]
    model_cls = registry.get_model_class(model_key)
    assert (
        model_cls == cls
    ), f"Incorrect pretrained model key {model_name} for class {cls.__name__}"
    output = load_pretrained_model(model_name, *args, **kwargs)
    config, checkpoint = output["config"], output["checkpoint"]

    # Some models need registry updates to load a pretrained model.
    # If they have this method, call it so they can update accordingly
    if hasattr(cls, "update_registry_for_pretrained"):
        cls.update_registry_for_pretrained(config, checkpoint, output)

    instance = cls(config)
    instance.is_pretrained = True
    instance.build()
    instance.load_state_dict(checkpoint)
    instance.eval()

    return instance

def build_lightning_model(
    config: Union[DictConfig, "mmf.models.base_model.BaseModel.Config"],
    checkpoint_path: str = None,
) -> "mmf.models.base_model.BaseModel":
    from mmf.models.base_model import BaseModel

    if not checkpoint_path:
        model = build_model(config)
        model.is_pl_enabled = True
        return model

    # If it is not an OmegaConf object, create the object
    if not isinstance(config, DictConfig) and isinstance(config, BaseModel.Config):
        config = OmegaConf.structured(config)

    model_name = config.model
    model_class = registry.get_model_class(model_name)

    if model_class is None:
        raise RuntimeError(f"No model registered for name: {model_name}")

    """model.build is called inside on_load_checkpoint as suggested here:
    https://github.com/PyTorchLightning/pytorch-lightning/issues/5410
    """
    if is_main():
        model_class.load_requirements(model_class, config=config)
        model = model_class.load_from_checkpoint(
            checkpoint_path, config=config, strict=False
        )
        synchronize()
    else:
        synchronize()
        model = model_class.load_from_checkpoint(
            checkpoint_path, config=config, strict=False
        )

    model.init_losses()
    model.is_pl_enabled = True
    return model

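# Hedged usage sketch (not from the original snippets): restoring a
# Lightning-enabled MMF model from a saved checkpoint; the checkpoint path is a
# placeholder and model_config is prepared as in the build_model sketch above.
lightning_model = build_lightning_model(
    model_config, checkpoint_path="save/models/model_final.ckpt"
)
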
def build_model(
    config: Union[DictConfig, "mmf.models.base_model.BaseModel.Config"]
) -> "mmf.models.base_model.BaseModel":
    from mmf.models.base_model import BaseModel

    # If it is not an OmegaConf object, create the object
    if not isinstance(config, DictConfig) and isinstance(config, BaseModel.Config):
        config = OmegaConf.structured(config)

    model_name = config.model
    model_class = registry.get_model_class(model_name)

    if model_class is None:
        raise RuntimeError(f"No model registered for name: {model_name}")
    model = model_class(config)

    if hasattr(model, "build"):
        model.load_requirements()
        model.build()
        model.init_losses()

    return model

def build_model(
    config: Union[DictConfig, "mmf.models.base_model.BaseModel.Config"]
) -> "mmf.models.base_model.BaseModel":
    from mmf.models.base_model import BaseModel

    # If it is not an OmegaConf object, create the object
    if not isinstance(config, DictConfig) and isinstance(config, BaseModel.Config):
        config = OmegaConf.structured(config)

    model_name = config.model
    model_class = registry.get_model_class(model_name)

    if model_class is None:
        raise RuntimeError(f"No model registered for name: {model_name}")
    model = model_class(config)

    if hasattr(model, "build"):
        model.load_requirements()

        """Model build involves checkpoint loading. If the checkpoint is not
        available, the underlying methods try to download it. Let the master
        rank build the model (and download the checkpoints) while the other
        ranks wait for the sync message. Once the master has downloaded the
        checkpoint and built the model, it sends the sync message, completing
        the synchronization; the other ranks can then proceed to build the
        model using the already downloaded checkpoint.
        """
        if is_master():
            model.build()
            synchronize()
        else:
            synchronize()
            model.build()

        model.init_losses()

    return model

import torch.nn as nn
import torch.nn.functional as F
import torch
from mmf.common.registry import registry
from transformers import RobertaModel

model_cls = registry.get_model_class("visual_bert")


class MyVisualBert(nn.Module):
    def __init__(self):
        super(MyVisualBert, self).__init__()
        self.visual_bert = model_cls.from_pretrained(
            "visual_bert.pretrained.coco"
        ).model
        # self.image_features = nn.Sequential(
        #     nn.Linear(2048, 2048), nn.ReLU(), nn.LayerNorm(2048)
        # )
        self.classifier = nn.Sequential(
            nn.Linear(768, 768),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.LayerNorm(768),
            nn.Linear(768, 1),
        )

    def forward(self, input_ids, attention_mask, visual_embeddings):
        device = input_ids.device
        # visual_embeddings = self.image_features(visual_embeddings)
        embs, pooled, _ = self.visual_bert.bert(
            input_ids=input_ids,  # token ids
            attention_mask=attention_mask,  # attention mask (sentence + words)
            visual_embeddings=visual_embeddings,  # 2048-dim visual features
            visual_embeddings_type=torch.ones(
                (input_ids.shape[0], visual_embeddings.shape[1]), dtype=torch.long
            ).to(device),
        )
        # The original snippet ends here; presumably the pooled output feeds
        # the classifier defined above, along the lines of:
        return self.classifier(pooled)