예제 #1
0
    def _restore_model_weights(self, model):
        """
        Initialize the model from the weights file named in the config.

        The weights file specified by the user might not hold VISSL trained
        weights, so several config options (SKIP_LAYERS, REMOVE_PREFIX,
        APPEND_PREFIX, STATE_DICT_KEY_NAME) are read from MODEL.WEIGHTS_INIT
        and forwarded to init_model_from_weights to allow successful loading
        of the weights.
        See MODEL.WEIGHTS_INIT description in vissl/config/defaults.yaml for details.

        Args:
            model: the model to initialize; passed to init_model_from_weights.

        Returns:
            The ``model`` object that was passed in. If the weights file does
            not exist, a warning is logged and the model is returned without
            any weights being loaded.

        Raises:
            AssertionError: if MODEL.WEIGHTS_INIT.PARAMS_FILE is empty.
        """
        params_from_file = self.config["MODEL"]["WEIGHTS_INIT"]
        init_weights_path = params_from_file["PARAMS_FILE"]
        assert init_weights_path, "Shouldn't call this when init_weight_path is empty"
        logging.info(f"Initializing model from: {init_weights_path}")

        if PathManager.exists(init_weights_path):
            # The checkpoint is loaded onto the CPU device.
            weights = load_and_broadcast_checkpoint(
                init_weights_path, device=torch.device("cpu")
            )
            skip_layers = params_from_file.get("SKIP_LAYERS", [])
            replace_prefix = params_from_file.get("REMOVE_PREFIX", None)
            append_prefix = params_from_file.get("APPEND_PREFIX", None)
            state_dict_key_name = params_from_file.get("STATE_DICT_KEY_NAME", None)

            # we initialize the weights from this checkpoint. However, we
            # don't care about the other metadata like iteration number etc.
            # So the method only reads the state_dict
            init_model_from_weights(
                self.config,
                model,
                weights,
                state_dict_key_name=state_dict_key_name,
                skip_layers=skip_layers,
                replace_prefix=replace_prefix,
                append_prefix=append_prefix,
            )
        else:
            # Without this warning the "Initializing model from" message above
            # is the last thing logged, which wrongly suggests that the
            # weights were loaded.
            logging.warning(
                f"Weights file not found: {init_weights_path}. "
                "No weights were loaded into the model."
            )
        return model
예제 #2
0
def build_retrieval_model(cfg):
    """
    Build the model and initialize it from the configured weights file.

    The model is created with build_model(cfg.MODEL, cfg.OPTIMIZER). If the
    file at cfg.MODEL.WEIGHTS_INIT.PARAMS_FILE exists, it is loaded with
    torch.load (onto CUDA when available, otherwise onto the CPU) and passed
    to init_model_from_weights together with the SKIP_LAYERS, REMOVE_PREFIX,
    APPEND_PREFIX and STATE_DICT_KEY_NAME options from cfg.MODEL.WEIGHTS_INIT.
    Otherwise the model is left as built and a warning is logged.

    Args:
        cfg: config object exposing MODEL, MODEL.WEIGHTS_INIT and OPTIMIZER.

    Returns:
        The model returned by build_model.
    """
    logging.info("Building model....")
    model = build_model(cfg.MODEL, cfg.OPTIMIZER)
    weights_cfg = cfg.MODEL.WEIGHTS_INIT
    init_weights_path = weights_cfg.PARAMS_FILE
    if PathManager.exists(init_weights_path):
        logging.info(f"Initializing model from: {init_weights_path}")
        # Prefer the GPU, but fall back to the CPU so that loading does not
        # fail on hosts where CUDA is unavailable.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        weights = torch.load(init_weights_path, map_location=device)
        skip_layers = weights_cfg.get("SKIP_LAYERS", [])
        replace_prefix = weights_cfg.get("REMOVE_PREFIX", None)
        append_prefix = weights_cfg.get("APPEND_PREFIX", None)
        state_dict_key_name = weights_cfg.get("STATE_DICT_KEY_NAME", None)

        init_model_from_weights(
            cfg,
            model,
            weights,
            state_dict_key_name=state_dict_key_name,
            skip_layers=skip_layers,
            replace_prefix=replace_prefix,
            append_prefix=append_prefix,
        )
    else:
        if init_weights_path:
            # A path was configured but nothing exists there: say so
            # explicitly, otherwise a mistyped path is indistinguishable from
            # an intentional random initialization.
            logging.warning(f"Weights file not found: {init_weights_path}")
        # Only a warning is raised when no weights file is available: random
        # initialization is supported on purpose so that such models can be
        # benchmarked too.
        logging.warning("Model is randomly initialized....")
    logging.info(f"Model is:\n {model}")
    return model