Esempio n. 1
0
def build_retrieval_model(cfg):
    """
    Build the model and, if a weights file is configured and exists,
    initialize it from that file.

    The checkpoint tensors are mapped onto CUDA when it is available and onto
    the CPU otherwise, so loading does not crash on hosts without a GPU. On
    CUDA hosts this is identical to always mapping to "cuda".

    Args:
        cfg: configuration object exposing ``MODEL`` (with
            ``MODEL.WEIGHTS_INIT``) and ``OPTIMIZER``.

    Returns:
        The built model, initialized from ``MODEL.WEIGHTS_INIT.PARAMS_FILE``
        when that file exists, randomly initialized otherwise.
    """
    logging.info("Building model....")
    model = build_model(cfg.MODEL, cfg.OPTIMIZER)
    weights_init = cfg.MODEL.WEIGHTS_INIT
    init_weights_path = weights_init.PARAMS_FILE
    if PathManager.exists(init_weights_path):
        logging.info(f"Initializing model from: {init_weights_path}")
        map_device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        weights = torch.load(init_weights_path, map_location=map_device)
        skip_layers = weights_init.get("SKIP_LAYERS", [])
        replace_prefix = weights_init.get("REMOVE_PREFIX", None)
        append_prefix = weights_init.get("APPEND_PREFIX", None)
        state_dict_key_name = weights_init.get("STATE_DICT_KEY_NAME", None)

        init_model_from_consolidated_weights(
            cfg,
            model,
            weights,
            state_dict_key_name=state_dict_key_name,
            skip_layers=skip_layers,
            replace_prefix=replace_prefix,
            append_prefix=append_prefix,
        )
    else:
        # A path was configured but does not exist: surface it explicitly so a
        # typo in the config is not mistaken for an intentional random init.
        if init_weights_path:
            logging.warning(f"Weights file not found: {init_weights_path}")
        # Random initialization is allowed on purpose (e.g. to benchmark a
        # randomly initialized model), so we only warn when no weights file is
        # available instead of raising.
        logging.warning("Model is randomly initialized....")
    logging.info(f"Model is:\n {model}")
    return model
Esempio n. 2
0
    def init_model_from_weights_params_file(self, config: AttrDict,
                                            checkpoint: Dict[str, Any]):
        """
        Load the model weights from ``checkpoint``. Only the state dict is
        read; other metadata in the checkpoint (iteration number etc.) is
        ignored.

        Two situations are handled:
        - FSDP trunk (the trunk itself, or the ``base_model`` of a
          ``FeatureExtractorModel`` trunk): the weights stored under
          ``checkpoint["classy_state_dict"]["base_model"]["model"]["trunk"]``
          are loaded with ``load_state_dict`` when ``checkpoint["type"]`` is the
          consolidated checkpoint type, and with ``load_local_state_dict``
          otherwise; the FSDP lazy-init state is then reset.
        - Any other trunk: loading is delegated to
          ``init_model_from_consolidated_weights`` using the options found in
          ``config["MODEL"]["WEIGHTS_INIT"]``.
        """

        # TODO (Quentin) - support: different number of nodes + different checkpoint
        # formats + fine tuning
        # Special cases in which we want to evaluate a model trained with FSDP:
        # - we need to benchmark it in FSDP mode as well and with the same number of
        #   workers
        # - we need to have it trained with VISSL (no support for other checkpoint
        #   types for now)
        fsdp_target = None
        if isinstance(self.trunk, FeatureExtractorModel) and isinstance(
                self.trunk.base_model, FSDP):
            fsdp_target = self.trunk.base_model
        elif isinstance(self.trunk, FSDP):
            fsdp_target = self.trunk

        if fsdp_target is None:
            # General case: support for multiple formats of checkpoint.
            weights_init_cfg = config["MODEL"]["WEIGHTS_INIT"]
            skip = weights_init_cfg.get("SKIP_LAYERS", [])
            remove_prefix = weights_init_cfg.get("REMOVE_PREFIX", None)
            add_prefix = weights_init_cfg.get("APPEND_PREFIX", None)
            key_name = weights_init_cfg.get("STATE_DICT_KEY_NAME", None)
            init_model_from_consolidated_weights(
                config,
                self,
                checkpoint,
                state_dict_key_name=key_name,
                skip_layers=skip,
                replace_prefix=remove_prefix,
                append_prefix=add_prefix,
            )
            return

        # FSDP case: pick the loader matching the checkpoint layout.
        is_consolidated = checkpoint["type"] == CheckpointType.consolidated.name
        if is_consolidated:
            load_fn = fsdp_target.load_state_dict
        else:
            load_fn = fsdp_target.load_local_state_dict
        trunk_weights = checkpoint["classy_state_dict"]["base_model"]["model"][
            "trunk"]
        load_fn(trunk_weights)
        fsdp_recursive_reset_lazy_init(fsdp_target)
Esempio n. 3
0
    def init_model_from_weights_params_file(self,
                                            config: AttrDict,
                                            checkpoint: Dict[str, Any],
                                            strict: bool = False):
        """
        Load the model weights from ``checkpoint``. Only the state dict is
        read; other metadata in the checkpoint (iteration number etc.) is
        ignored.

        FSDP trunks (the trunk itself, or the ``base_model`` of a
        ``FeatureExtractorModel`` trunk) are loaded through
        ``CheckpointLoader.init_fsdp_model_from_weights`` from the
        ``classy_state_dict/base_model/model/trunk`` entry. Their FSDP lazy-init
        state is then reset, and the heads are initialized as well when
        ``should_init_head_weights(config.MODEL)`` is true. Models and
        checkpoints on this path are expected to come from VISSL.

        Any other trunk is loaded with ``init_model_from_consolidated_weights``
        using the options in ``config["MODEL"]["WEIGHTS_INIT"]``.

        Args:
            config: full configuration.
            checkpoint: loaded checkpoint dictionary.
            strict: forwarded only to ``init_model_from_consolidated_weights``
                (general path); the FSDP path does not use it.
        """
        fsdp_target = None
        if isinstance(self.trunk, FeatureExtractorModel) and isinstance(
                self.trunk.base_model, FSDP):
            # Linear evaluation / extraction from an FSDP model.
            logging.info("Loading FSDP trunk in extraction mode")
            fsdp_target = self.trunk.base_model
        elif isinstance(self.trunk, FSDP):
            # Fine-tuning of an FSDP model.
            logging.info("Loading FSDP trunk")
            fsdp_target = self.trunk

        if fsdp_target is None:
            # General case: support for multiple formats of checkpoint.
            weights_init_cfg = config["MODEL"]["WEIGHTS_INIT"]
            skip = weights_init_cfg.get("SKIP_LAYERS", [])
            remove_prefix = weights_init_cfg.get("REMOVE_PREFIX", None)
            add_prefix = weights_init_cfg.get("APPEND_PREFIX", None)
            key_name = weights_init_cfg.get("STATE_DICT_KEY_NAME", None)
            init_model_from_consolidated_weights(
                config,
                self,
                checkpoint,
                state_dict_key_name=key_name,
                skip_layers=skip,
                replace_prefix=remove_prefix,
                append_prefix=add_prefix,
                strict=strict,
            )
            return

        # Trunk: complete load. Head: optional, partial load supported.
        trunk_key_path = ["classy_state_dict", "base_model", "model", "trunk"]
        CheckpointLoader.init_fsdp_model_from_weights(
            fsdp_target,
            checkpoint,
            weights_path=trunk_key_path,
        )
        fsdp_recursive_reset_lazy_init(fsdp_target)
        if should_init_head_weights(config.MODEL):
            self._init_fsdp_model_heads_from_weights_params_file(checkpoint)