Example #1
def load_lm_heads_from_transformers_weights(jiant_model, weights_dict):
    """Load pretrained MLM head weights into every MLM task model; return names of unused weights."""
    model_arch = get_model_arch_from_jiant_model(jiant_model=jiant_model)
    if model_arch == ModelArchitectures.BERT:
        mlm_weights_map = {
            "bias": "cls.predictions.bias",
            "dense.weight": "cls.predictions.transform.dense.weight",
            "dense.bias": "cls.predictions.transform.dense.bias",
            "LayerNorm.weight": "cls.predictions.transform.LayerNorm.weight",
            "LayerNorm.bias": "cls.predictions.transform.LayerNorm.bias",
            "decoder.weight": "cls.predictions.decoder.weight",
            "decoder.bias": "cls.predictions.bias",  # <-- linked directly to bias
        }
        mlm_weights_dict = {new_k: weights_dict[old_k] for new_k, old_k in mlm_weights_map.items()}
    elif model_arch in (ModelArchitectures.ROBERTA, ModelArchitectures.XLM_ROBERTA):
        mlm_weights_dict = {
            strings.remove_prefix(k, "lm_head."): v for k, v in weights_dict.items()
        }
        mlm_weights_dict["decoder.bias"] = mlm_weights_dict["bias"]
    elif model_arch == ModelArchitectures.ALBERT:
        mlm_weights_dict = {
            strings.remove_prefix(k, "predictions."): v for k, v in weights_dict.items()
        }
    else:
        raise KeyError(model_arch)
    missed = set()
    for taskmodel_name, taskmodel in jiant_model.taskmodels_dict.items():
        if not isinstance(taskmodel, taskmodels.MLMModel):
            continue
        mismatch = taskmodel.mlm_head.load_state_dict(mlm_weights_dict)
        assert not mismatch.missing_keys
        missed.update(mismatch.unexpected_keys)
        # Re-tie the decoder projection to the encoder's input word embeddings (weight sharing).
        taskmodel.mlm_head.decoder.weight = jiant_model.encoder.embeddings.word_embeddings.weight
    return list(missed)
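
The BERT branch above works by renaming checkpoint keys: each "cls.predictions.*" entry from a transformers-style state dict is copied under the short name the MLM head expects, and decoder.bias is pointed at the shared prediction bias. A minimal sketch of that remapping on a toy state dict (shapes and values are placeholders, not jiant's API):

import torch

# Toy fragment of a BERT MLM checkpoint; shapes are illustrative only.
weights_dict = {
    "cls.predictions.bias": torch.zeros(30522),
    "cls.predictions.transform.dense.weight": torch.zeros(768, 768),
    "cls.predictions.transform.dense.bias": torch.zeros(768),
    "cls.predictions.transform.LayerNorm.weight": torch.ones(768),
    "cls.predictions.transform.LayerNorm.bias": torch.zeros(768),
    "cls.predictions.decoder.weight": torch.zeros(30522, 768),
}

mlm_weights_map = {
    "bias": "cls.predictions.bias",
    "dense.weight": "cls.predictions.transform.dense.weight",
    "dense.bias": "cls.predictions.transform.dense.bias",
    "LayerNorm.weight": "cls.predictions.transform.LayerNorm.weight",
    "LayerNorm.bias": "cls.predictions.transform.LayerNorm.bias",
    "decoder.weight": "cls.predictions.decoder.weight",
    "decoder.bias": "cls.predictions.bias",  # decoder bias reuses the shared prediction bias
}

mlm_weights_dict = {new_k: weights_dict[old_k] for new_k, old_k in mlm_weights_map.items()}
print(sorted(mlm_weights_dict))
# ['LayerNorm.bias', 'LayerNorm.weight', 'bias', 'decoder.bias', 'decoder.weight', 'dense.bias', 'dense.weight']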
Example #2
def load_encoder_from_transformers_weights(
    encoder: nn.Module, weights_dict: dict, return_remainder=False
):
    """Find encoder weights in weights dict, load them into encoder, return any remaining weights.

    TODO: clarify how we know the encoder weights will be prefixed by transformer name.

    Args:
        encoder (PreTrainedModel): Transformer w/o heads (embedding layer + self-attention layer).
        weights_dict (Dict): model weights.
        return_remainder (bool): If True, return any leftover weights.

    Returns:
        Dict containing any leftover weights.

    """
    remainder_weights_dict = {}
    load_weights_dict = {}
    model_arch = ModelArchitectures.from_encoder(encoder=encoder)
    encoder_prefix = MODEL_PREFIX[model_arch] + "."
    # Encoder
    for k, v in weights_dict.items():
        if k.startswith(encoder_prefix):
            load_weights_dict[strings.remove_prefix(k, encoder_prefix)] = v
        else:
            remainder_weights_dict[k] = v
    encoder.load_state_dict(load_weights_dict)
    if return_remainder:
        return remainder_weights_dict
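
The loop above partitions the checkpoint by key prefix: keys that start with the architecture prefix (MODEL_PREFIX[model_arch] plus a dot, e.g. "bert.") are stripped and loaded into the encoder, and everything else, typically head weights, is kept as the remainder. A minimal self-contained sketch of that partitioning, with a stand-in remove_prefix helper assumed to behave like jiant's strings.remove_prefix:

def remove_prefix(s, prefix):
    # Assumed behavior of jiant's strings.remove_prefix helper.
    return s[len(prefix):] if s.startswith(prefix) else s

# Toy checkpoint: two encoder weights plus one MLM head weight.
weights_dict = {
    "bert.embeddings.word_embeddings.weight": "encoder weight",
    "bert.encoder.layer.0.attention.self.query.weight": "encoder weight",
    "cls.predictions.bias": "head weight",
}

encoder_prefix = "bert."
load_weights_dict, remainder_weights_dict = {}, {}
for k, v in weights_dict.items():
    if k.startswith(encoder_prefix):
        load_weights_dict[remove_prefix(k, encoder_prefix)] = v
    else:
        remainder_weights_dict[k] = v

print(list(load_weights_dict))       # encoder keys with the "bert." prefix stripped
print(list(remainder_weights_dict))  # ['cls.predictions.bias']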
Example #3
    def get_mlm_weights_dict(self, weights_dict):
        mlm_weights_dict = {
            strings.remove_prefix(k, "lm_head."): v for k, v in weights_dict.items()
        }
        # The checkpoint stores a single shared bias; also expose it under the decoder.bias key.
        mlm_weights_dict["decoder.bias"] = mlm_weights_dict["bias"]
        return mlm_weights_dict
Example #4
def load_encoder_from_transformers_weights(encoder: nn.Module,
                                           weights_dict: dict,
                                           return_remainder=False):
    """Find encoder weights in weights dict, load them into encoder, return any remaining weights.

    TODO: clarify how we know the encoder weights will be prefixed by transformer name.

    Args:
        encoder (PreTrainedModel): Transformer w/o heads (embedding layer + self-attention layer).
        weights_dict (Dict): model weights.
        return_remainder (bool): If True, return any leftover weights.

    Returns:
        Dict containing any leftover weights.

    """
    remainder_weights_dict = {}
    load_weights_dict = {}
    model_arch = ModelArchitectures.from_model_type(
        model_type=encoder.config.model_type)
    encoder_prefix = model_arch.value + "."
    # Encoder
    for k, v in weights_dict.items():
        if k.startswith(encoder_prefix):
            load_weights_dict[strings.remove_prefix(k, encoder_prefix)] = v
        elif k.startswith(encoder_prefix.split("-")[0]):
            # workaround for deberta-v2
            # remove "-v2" suffix. weight names are prefixed with "deberta" and not "deberta-v2"
            load_weights_dict[strings.remove_prefix(
                k,
                encoder_prefix.split("-")[0] + ".")] = v
        else:
            remainder_weights_dict[k] = v
    encoder.load_state_dict(load_weights_dict, strict=False)
    if remainder_weights_dict:
        warnings.warn("The following weights were not loaded: {}".format(
            remainder_weights_dict.keys()))
    if return_remainder:
        return remainder_weights_dict
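
This variant differs from Example #2 in two ways: the elif branch papers over the deberta-v2 naming mismatch (checkpoint keys are prefixed "deberta." while the architecture value is "deberta-v2"), and load_state_dict is called with strict=False so partially matching checkpoints do not raise. A small sketch of the prefix fallback, using an illustrative key name rather than jiant's actual ModelArchitectures lookup:

def remove_prefix(s, prefix):
    # Assumed behavior of jiant's strings.remove_prefix helper.
    return s[len(prefix):] if s.startswith(prefix) else s

encoder_prefix = "deberta-v2."
base_prefix = encoder_prefix.split("-")[0] + "."  # "deberta."

k = "deberta.encoder.layer.0.attention.self.query_proj.weight"  # illustrative key
if k.startswith(encoder_prefix):
    stripped = remove_prefix(k, encoder_prefix)
elif k.startswith(encoder_prefix.split("-")[0]):
    # deberta-v2 checkpoints name their weights "deberta.*", not "deberta-v2.*"
    stripped = remove_prefix(k, base_prefix)

print(stripped)  # encoder.layer.0.attention.self.query_proj.weight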
Example #5
    def get_mlm_weights_dict(self, weights_dict):
        mlm_weights_dict = {
            strings.remove_prefix(k, "predictions."): v for k, v in weights_dict.items()
        }
        return mlm_weights_dict
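
Unlike the RoBERTa variant in Example #3, no decoder.bias entry is added here, presumably because the ALBERT checkpoint already carries a "predictions.decoder.bias" key, so stripping the "predictions." prefix is enough. A toy illustration of that stripping (key names follow the snippet; values are placeholders):

def remove_prefix(s, prefix):
    # Assumed behavior of jiant's strings.remove_prefix helper.
    return s[len(prefix):] if s.startswith(prefix) else s

weights_dict = {
    "predictions.bias": "...",
    "predictions.dense.weight": "...",
    "predictions.LayerNorm.weight": "...",
    "predictions.decoder.weight": "...",
    "predictions.decoder.bias": "...",
}
mlm_weights_dict = {remove_prefix(k, "predictions."): v for k, v in weights_dict.items()}
print(sorted(mlm_weights_dict))
# ['LayerNorm.weight', 'bias', 'decoder.bias', 'decoder.weight', 'dense.weight']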