def build_retrieval_model(cfg):
    """
    Build the model on a single GPU and initialize it from the weights
    file configured under cfg.MODEL.WEIGHTS_INIT (when one exists).

    Returns the constructed (and possibly weight-initialized) model.
    """
    logging.info("Building model....")
    model = build_model(cfg.MODEL, cfg.OPTIMIZER)

    weights_file = cfg.MODEL.WEIGHTS_INIT.PARAMS_FILE
    if PathManager.exists(weights_file):
        logging.info(f"Initializing model from: {weights_file}")
        # NOTE(review): weights are mapped directly onto the GPU — this
        # assumes CUDA is available; confirm for CPU-only environments.
        weights = torch.load(weights_file, map_location=torch.device("cuda"))
        weights_init_cfg = cfg.MODEL.WEIGHTS_INIT
        init_model_from_consolidated_weights(
            cfg,
            model,
            weights,
            state_dict_key_name=weights_init_cfg.get("STATE_DICT_KEY_NAME", None),
            skip_layers=weights_init_cfg.get("SKIP_LAYERS", []),
            replace_prefix=weights_init_cfg.get("REMOVE_PREFIX", None),
            append_prefix=weights_init_cfg.get("APPEND_PREFIX", None),
        )
    else:
        # Warn only when no weights file is provided: randomly initialized
        # models are a supported benchmarking case, not an error.
        logging.warning("Model is randomly initialized....")

    logging.info(f"Model is:\n {model}")
    return model
def init_model_from_weights_params_file(self, config: AttrDict,
                                        checkpoint: Dict[str, Any]):
    """
    Initialize this model's weights from the given checkpoint.

    Only the state_dict is consumed; other checkpoint metadata such as
    the iteration number is deliberately ignored.
    """
    # TODO (Quentin) - support: different number of nodes + different checkpoint
    # formats + fine tuning

    # FSDP-trained trunks are a special case: they must be evaluated in
    # FSDP mode with the same number of workers, and must have been
    # trained with VISSL (no other checkpoint types supported for now).
    fsdp_target = None
    if isinstance(self.trunk, FeatureExtractorModel) and isinstance(
            self.trunk.base_model, FSDP):
        fsdp_target = self.trunk.base_model
    elif isinstance(self.trunk, FSDP):
        fsdp_target = self.trunk

    if fsdp_target is not None:
        # Consolidated checkpoints carry the full state dict; otherwise
        # each rank loads its local shard.
        if checkpoint["type"] == CheckpointType.consolidated.name:
            load = fsdp_target.load_state_dict
        else:
            load = fsdp_target.load_local_state_dict
        load(checkpoint["classy_state_dict"]["base_model"]["model"]["trunk"])
        fsdp_recursive_reset_lazy_init(fsdp_target)
    else:
        # General case: support for multiple formats of checkpoint.
        weights_init_cfg = config["MODEL"]["WEIGHTS_INIT"]
        init_model_from_consolidated_weights(
            config,
            self,
            checkpoint,
            state_dict_key_name=weights_init_cfg.get("STATE_DICT_KEY_NAME", None),
            skip_layers=weights_init_cfg.get("SKIP_LAYERS", []),
            replace_prefix=weights_init_cfg.get("REMOVE_PREFIX", None),
            append_prefix=weights_init_cfg.get("APPEND_PREFIX", None),
        )
def init_model_from_weights_params_file(self, config: AttrDict,
                                        checkpoint: Dict[str, Any],
                                        strict: bool = False):
    """
    Initialize this model's weights from the given checkpoint.

    Only the state_dict is consumed; other checkpoint metadata such as
    the iteration number is deliberately ignored.
    """
    # FSDP trunks are a specific case: both the model and the checkpoint
    # have to be created with VISSL. Resolve the FSDP module to load into
    # (extraction mode wraps the trunk in a FeatureExtractorModel).
    if isinstance(self.trunk, FeatureExtractorModel) and isinstance(
            self.trunk.base_model, FSDP):
        logging.info("Loading FSDP trunk in extraction mode")
        fsdp_trunk = self.trunk.base_model
    elif isinstance(self.trunk, FSDP):
        logging.info("Loading FSDP trunk")
        fsdp_trunk = self.trunk
    else:
        fsdp_trunk = None

    if fsdp_trunk is not None:
        # Trunk: complete strict load. Head: optional, partial load supported.
        CheckpointLoader.init_fsdp_model_from_weights(
            fsdp_trunk,
            checkpoint,
            weights_path=[
                "classy_state_dict", "base_model", "model", "trunk"
            ],
        )
        fsdp_recursive_reset_lazy_init(fsdp_trunk)
        if should_init_head_weights(config.MODEL):
            self._init_fsdp_model_heads_from_weights_params_file(checkpoint)
    else:
        # General case: support for multiple formats of checkpoint.
        weights_init_cfg = config["MODEL"]["WEIGHTS_INIT"]
        init_model_from_consolidated_weights(
            config,
            self,
            checkpoint,
            state_dict_key_name=weights_init_cfg.get("STATE_DICT_KEY_NAME", None),
            skip_layers=weights_init_cfg.get("SKIP_LAYERS", []),
            replace_prefix=weights_init_cfg.get("REMOVE_PREFIX", None),
            append_prefix=weights_init_cfg.get("APPEND_PREFIX", None),
            strict=strict,
        )