Example #1
    def init_model_from_weights_params_file(self, config: AttrDict,
                                            checkpoint: Dict[str, Any]):
        """
        Initialize the model weights from this checkpoint. Other metadata,
        such as the iteration number, is ignored: the method only reads
        the state_dict.
        """

        # TODO (Quentin) - support: different number of nodes + different checkpoint
        # formats + fine-tuning
        # Special cases in which we want to evaluate a model trained with FSDP:
        # - it has to be benchmarked in FSDP mode as well, with the same number
        #   of workers
        # - it has to have been trained with VISSL (no support for other
        #   checkpoint types for now)
        if isinstance(self.trunk, FeatureExtractorModel) and isinstance(
                self.trunk.base_model, FSDP):
            CheckpointLoader.init_fsdp_model_from_weights(
                self.trunk.base_model,
                checkpoint,
                weights_path=[
                    "classy_state_dict", "base_model", "model", "trunk"
                ],
            )
            fsdp_recursive_reset_lazy_init(self.trunk.base_model)
        elif isinstance(self.trunk, FSDP):
            CheckpointLoader.init_fsdp_model_from_weights(
                self.trunk,
                checkpoint,
                weights_path=[
                    "classy_state_dict", "base_model", "model", "trunk"
                ],
            )
            fsdp_recursive_reset_lazy_init(self.trunk)

        # General case: support for multiple checkpoint formats
        else:
            params_from_file = config["MODEL"]["WEIGHTS_INIT"]
            skip_layers = params_from_file.get("SKIP_LAYERS", [])
            replace_prefix = params_from_file.get("REMOVE_PREFIX", None)
            append_prefix = params_from_file.get("APPEND_PREFIX", None)
            state_dict_key_name = params_from_file.get("STATE_DICT_KEY_NAME",
                                                       None)
            init_model_from_consolidated_weights(
                config,
                self,
                checkpoint,
                state_dict_key_name=state_dict_key_name,
                skip_layers=skip_layers,
                replace_prefix=replace_prefix,
                append_prefix=append_prefix,
            )
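
The general (non-FSDP) branch above only consumes four options from config["MODEL"]["WEIGHTS_INIT"]. The sketch below is a minimal, self-contained illustration of how those options are assumed to rewrite the keys of a consolidated state_dict; remap_state_dict is a hypothetical helper standing in for the actual logic inside init_model_from_consolidated_weights, not VISSL code.

from typing import Any, Dict, List, Optional


def remap_state_dict(state_dict: Dict[str, Any],
                     skip_layers: List[str],
                     replace_prefix: Optional[str] = None,
                     append_prefix: Optional[str] = None) -> Dict[str, Any]:
    # Hypothetical helper: illustrates what SKIP_LAYERS / REMOVE_PREFIX /
    # APPEND_PREFIX are assumed to do to the checkpoint keys before loading.
    remapped = {}
    for name, tensor in state_dict.items():
        # SKIP_LAYERS: drop any parameter whose name contains a listed pattern.
        if any(pattern in name for pattern in skip_layers):
            continue
        # REMOVE_PREFIX: strip a prefix left over from the training wrapper,
        # e.g. "module." when the checkpoint came from DistributedDataParallel.
        if replace_prefix and name.startswith(replace_prefix):
            name = name[len(replace_prefix):]
        # APPEND_PREFIX: prepend the prefix expected by the evaluation model,
        # e.g. "trunk." when loading a bare backbone into a wrapped model.
        if append_prefix:
            name = append_prefix + name
        remapped[name] = tensor
    return remapped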
Example #2
    def _init_fsdp_model_heads_from_weights_params_file(
            self, checkpoint: Dict[str, Any]):
        for i, head in enumerate(self.heads):
            logging.info(f"Loading FSDP head {i}")
            if isinstance(head, FSDP):
                CheckpointLoader.init_fsdp_model_from_weights(
                    head,
                    checkpoint,
                    weights_path=[
                        "classy_state_dict", "base_model", "model", "heads"
                    ],
                    strict=False,
                    head_index=i,
                )
                fsdp_recursive_reset_lazy_init(head)
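
Each FSDP head is located inside the checkpoint via weights_path plus head_index. The sketch below is a minimal illustration, assuming init_fsdp_model_from_weights descends the nested checkpoint dict along weights_path and then picks the entry for the given head; descend_weights_path is a hypothetical helper, not part of VISSL.

from typing import Any, Dict, List, Optional


def descend_weights_path(checkpoint: Dict[str, Any],
                         weights_path: List[str],
                         head_index: Optional[int] = None) -> Any:
    # Hypothetical helper: follow the nested keys of the checkpoint, e.g.
    # checkpoint["classy_state_dict"]["base_model"]["model"]["heads"], then
    # optionally select the state of a single head by its index.
    node = checkpoint
    for key in weights_path:
        node = node[key]
    if head_index is not None:
        node = node[head_index]
    return node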
Example #3
    def init_model_from_weights_params_file(self,
                                            config: AttrDict,
                                            checkpoint: Dict[str, Any],
                                            strict: bool = False):
        """
        Initialize the model weights from this checkpoint. Other metadata,
        such as the iteration number, is ignored: the method only reads
        the state_dict.
        """

        # Specific case for FSDP trunks:
        # - models have to be created with VISSL
        # - checkpoints have to be created with VISSL
        if isinstance(self.trunk, FeatureExtractorModel) and isinstance(
                self.trunk.base_model, FSDP):
            # Linear evaluation / extraction from FSDP models:
            # - load the trunk (complete strict load)
            # - load the head (optional and partial load supported)
            logging.info("Loading FSDP trunk in extraction mode")
            CheckpointLoader.init_fsdp_model_from_weights(
                self.trunk.base_model,
                checkpoint,
                weights_path=[
                    "classy_state_dict", "base_model", "model", "trunk"
                ],
            )
            fsdp_recursive_reset_lazy_init(self.trunk.base_model)
            if should_init_head_weights(config.MODEL):
                self._init_fsdp_model_heads_from_weights_params_file(
                    checkpoint)

        elif isinstance(self.trunk, FSDP):
            # Fine-tuning of FSDP models:
            # - load the trunk (complete strict load)
            # - load the head (optional and partial load supported)
            logging.info("Loading FSDP trunk")
            CheckpointLoader.init_fsdp_model_from_weights(
                self.trunk,
                checkpoint,
                weights_path=[
                    "classy_state_dict", "base_model", "model", "trunk"
                ],
            )
            fsdp_recursive_reset_lazy_init(self.trunk)
            if should_init_head_weights(config.MODEL):
                self._init_fsdp_model_heads_from_weights_params_file(
                    checkpoint)

        # General case: support for multiple checkpoint formats
        else:
            params_from_file = config["MODEL"]["WEIGHTS_INIT"]
            skip_layers = params_from_file.get("SKIP_LAYERS", [])
            replace_prefix = params_from_file.get("REMOVE_PREFIX", None)
            append_prefix = params_from_file.get("APPEND_PREFIX", None)
            state_dict_key_name = params_from_file.get("STATE_DICT_KEY_NAME",
                                                       None)
            init_model_from_consolidated_weights(
                config,
                self,
                checkpoint,
                state_dict_key_name=state_dict_key_name,
                skip_layers=skip_layers,
                replace_prefix=replace_prefix,
                append_prefix=append_prefix,
                strict=strict,
            )
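
A usage sketch for the method above, assuming a model already constructed through VISSL and a consolidated checkpoint on disk. load_eval_weights and its weights_file argument are hypothetical; torch.load is only used to read the checkpoint into the Dict[str, Any] the method expects, and the config keys named in the comment are those read by the general (non-FSDP) branch.

import torch


def load_eval_weights(model, config, weights_file: str, strict: bool = False):
    # Hypothetical helper: read a consolidated checkpoint from disk and hand it
    # to the method shown in Example #3. `model` is assumed to be an already
    # constructed VISSL model; `config` must carry the MODEL.WEIGHTS_INIT
    # section whose keys (STATE_DICT_KEY_NAME, SKIP_LAYERS, REMOVE_PREFIX,
    # APPEND_PREFIX) are read by the general (non-FSDP) branch.
    checkpoint = torch.load(weights_file, map_location="cpu")
    model.init_model_from_weights_params_file(config, checkpoint, strict=strict)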