def init_model_from_weights_params_file(
    self, config: AttrDict, checkpoint: Dict[str, Any]
):
    """
    We initialize the weights from this checkpoint. However, we don't care
    about the other metadata like iteration number etc., so the method only
    reads the state_dict.
    """
    # TODO (Quentin) - support: different number of nodes + different checkpoint
    # formats + fine tuning
    # Special cases in which we want to evaluate a model trained with FSDP:
    # - we need to benchmark it in FSDP mode as well and with the same number of
    #   workers
    # - we need to have it trained with VISSL (no support for other checkpoint
    #   types for now)
    if isinstance(self.trunk, FeatureExtractorModel) and isinstance(
        self.trunk.base_model, FSDP
    ):
        CheckpointLoader.init_fsdp_model_from_weights(
            self.trunk.base_model,
            checkpoint,
            weights_path=["classy_state_dict", "base_model", "model", "trunk"],
        )
        fsdp_recursive_reset_lazy_init(self.trunk.base_model)
    elif isinstance(self.trunk, FSDP):
        CheckpointLoader.init_fsdp_model_from_weights(
            self.trunk,
            checkpoint,
            weights_path=["classy_state_dict", "base_model", "model", "trunk"],
        )
        fsdp_recursive_reset_lazy_init(self.trunk)
    # General case: support for multiple checkpoint formats
    else:
        params_from_file = config["MODEL"]["WEIGHTS_INIT"]
        skip_layers = params_from_file.get("SKIP_LAYERS", [])
        replace_prefix = params_from_file.get("REMOVE_PREFIX", None)
        append_prefix = params_from_file.get("APPEND_PREFIX", None)
        state_dict_key_name = params_from_file.get("STATE_DICT_KEY_NAME", None)
        init_model_from_consolidated_weights(
            config,
            self,
            checkpoint,
            state_dict_key_name=state_dict_key_name,
            skip_layers=skip_layers,
            replace_prefix=replace_prefix,
            append_prefix=append_prefix,
        )
def _init_fsdp_model_heads_from_weights_params_file(
    self, checkpoint: Dict[str, Any]
):
    for i, head in enumerate(self.heads):
        logging.info(f"Loading FSDP head {i}")
        if isinstance(head, FSDP):
            CheckpointLoader.init_fsdp_model_from_weights(
                head,
                checkpoint,
                weights_path=["classy_state_dict", "base_model", "model", "heads"],
                strict=False,
                head_index=i,
            )
            fsdp_recursive_reset_lazy_init(head)
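# ---------------------------------------------------------------------------
# Hedged illustration (not part of the original module). The weights_path lists
# passed to CheckpointLoader above name the nested keys used to drill into a
# VISSL checkpoint dictionary. The small helper below mimics that traversal on
# a plain dict; the actual loader does more (sharded FSDP loading, strictness,
# head_index handling).
# ---------------------------------------------------------------------------
from typing import Any, Dict, List


def _follow_weights_path(checkpoint: Dict[str, Any], weights_path: List[str]) -> Any:
    node = checkpoint
    for key in weights_path:
        # e.g. checkpoint["classy_state_dict"]["base_model"]["model"]["heads"]
        node = node[key]
    return node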
def init_model_from_weights_params_file(
    self, config: AttrDict, checkpoint: Dict[str, Any], strict: bool = False
):
    """
    We initialize the weights from this checkpoint. However, we don't care
    about the other metadata like iteration number etc., so the method only
    reads the state_dict.
    """
    # Specific case for FSDP trunks:
    # - models have to be created with VISSL
    # - checkpoints have to be created with VISSL
    if isinstance(self.trunk, FeatureExtractorModel) and isinstance(
        self.trunk.base_model, FSDP
    ):
        # Linear evaluation / extraction from FSDP models:
        # - load the trunk (complete strict load)
        # - load the head (optional and partial load supported)
        logging.info("Loading FSDP trunk in extraction mode")
        CheckpointLoader.init_fsdp_model_from_weights(
            self.trunk.base_model,
            checkpoint,
            weights_path=["classy_state_dict", "base_model", "model", "trunk"],
        )
        fsdp_recursive_reset_lazy_init(self.trunk.base_model)
        if should_init_head_weights(config.MODEL):
            self._init_fsdp_model_heads_from_weights_params_file(checkpoint)
    elif isinstance(self.trunk, FSDP):
        # Fine-tuning of FSDP models:
        # - load the trunk (complete strict load)
        # - load the head (optional and partial load supported)
        logging.info("Loading FSDP trunk")
        CheckpointLoader.init_fsdp_model_from_weights(
            self.trunk,
            checkpoint,
            weights_path=["classy_state_dict", "base_model", "model", "trunk"],
        )
        fsdp_recursive_reset_lazy_init(self.trunk)
        if should_init_head_weights(config.MODEL):
            self._init_fsdp_model_heads_from_weights_params_file(checkpoint)
    # General case: support for multiple checkpoint formats
    else:
        params_from_file = config["MODEL"]["WEIGHTS_INIT"]
        skip_layers = params_from_file.get("SKIP_LAYERS", [])
        replace_prefix = params_from_file.get("REMOVE_PREFIX", None)
        append_prefix = params_from_file.get("APPEND_PREFIX", None)
        state_dict_key_name = params_from_file.get("STATE_DICT_KEY_NAME", None)
        init_model_from_consolidated_weights(
            config,
            self,
            checkpoint,
            state_dict_key_name=state_dict_key_name,
            skip_layers=skip_layers,
            replace_prefix=replace_prefix,
            append_prefix=append_prefix,
            strict=strict,
        )
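# ---------------------------------------------------------------------------
# Hedged usage sketch (illustrative, not part of the original module). It shows
# how the general (non-FSDP) branch above consumes MODEL.WEIGHTS_INIT. The key
# names mirror the .get() calls in init_model_from_weights_params_file; the
# PARAMS_FILE key, the checkpoint path, and the concrete values are assumptions
# for illustration only.
# ---------------------------------------------------------------------------
def _example_weights_init_usage(model, config: AttrDict) -> None:
    import torch

    # Hypothetical path to a consolidated checkpoint on disk.
    weights_file = config.MODEL.WEIGHTS_INIT.PARAMS_FILE
    checkpoint = torch.load(weights_file, map_location="cpu")

    # WEIGHTS_INIT options read by the general branch:
    #   STATE_DICT_KEY_NAME: key under which the state_dict lives in the file
    #   SKIP_LAYERS:         layer names to leave at their initialized values
    #   REMOVE_PREFIX:       prefix stripped from checkpoint parameter names
    #   APPEND_PREFIX:       prefix prepended to checkpoint parameter names
    model.init_model_from_weights_params_file(config, checkpoint, strict=False)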