def load_lm_heads_from_transformers_weights(jiant_model, weights_dict):
    model_arch = get_model_arch_from_jiant_model(jiant_model=jiant_model)
    if model_arch == ModelArchitectures.BERT:
        mlm_weights_map = {
            "bias": "cls.predictions.bias",
            "dense.weight": "cls.predictions.transform.dense.weight",
            "dense.bias": "cls.predictions.transform.dense.bias",
            "LayerNorm.weight": "cls.predictions.transform.LayerNorm.weight",
            "LayerNorm.bias": "cls.predictions.transform.LayerNorm.bias",
            "decoder.weight": "cls.predictions.decoder.weight",
            "decoder.bias": "cls.predictions.bias",  # <-- linked directly to bias
        }
        mlm_weights_dict = {
            new_k: weights_dict[old_k] for new_k, old_k in mlm_weights_map.items()
        }
    elif model_arch in (ModelArchitectures.ROBERTA, ModelArchitectures.XLM_ROBERTA):
        mlm_weights_dict = {
            strings.remove_prefix(k, "lm_head."): v for k, v in weights_dict.items()
        }
        mlm_weights_dict["decoder.bias"] = mlm_weights_dict["bias"]
    elif model_arch == ModelArchitectures.ALBERT:
        mlm_weights_dict = {
            strings.remove_prefix(k, "predictions."): v for k, v in weights_dict.items()
        }
    else:
        raise KeyError(model_arch)
    missed = set()
    for taskmodel_name, taskmodel in jiant_model.taskmodels_dict.items():
        if not isinstance(taskmodel, taskmodels.MLMModel):
            continue
        # strict=False so that unexpected keys are reported back instead of raising;
        # the code below collects them, which only works with a non-strict load.
        mismatch = taskmodel.mlm_head.load_state_dict(mlm_weights_dict, strict=False)
        assert not mismatch.missing_keys
        missed.update(mismatch.unexpected_keys)
        # Tie the MLM decoder to the encoder's input embeddings (weight sharing).
        taskmodel.mlm_head.decoder.weight = jiant_model.encoder.embeddings.word_embeddings.weight
    return list(missed)
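
# Side note on the last line of the loop above: assigning the embedding
# Parameter to the decoder ties both modules to the same storage, so the MLM
# decoder always mirrors the encoder's input embeddings. A minimal,
# self-contained torch sketch of that tying trick (not part of this module):
def _weight_tying_demo():
    import torch.nn as nn

    embeddings = nn.Embedding(num_embeddings=10, embedding_dim=4)
    decoder = nn.Linear(in_features=4, out_features=10, bias=False)
    decoder.weight = embeddings.weight  # same Parameter object, shared storage
    assert decoder.weight.data_ptr() == embeddings.weight.data_ptr()
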
def load_encoder_from_transformers_weights(
    encoder: nn.Module, weights_dict: dict, return_remainder=False
):
    """Find encoder weights in weights dict, load them into encoder, return any remaining weights.

    TODO: clarify how we know the encoder weights will be prefixed by transformer name.

    Args:
        encoder (PreTrainedModel): Transformer w/o heads (embedding layer + self-attention layer).
        weights_dict (Dict): model weights.
        return_remainder (bool): If True, return any leftover weights.

    Returns:
        Dict containing any leftover weights.
    """
    remainder_weights_dict = {}
    load_weights_dict = {}
    model_arch = ModelArchitectures.from_encoder(encoder=encoder)
    encoder_prefix = MODEL_PREFIX[model_arch] + "."
    # Encoder
    for k, v in weights_dict.items():
        if k.startswith(encoder_prefix):
            load_weights_dict[strings.remove_prefix(k, encoder_prefix)] = v
        else:
            remainder_weights_dict[k] = v
    encoder.load_state_dict(load_weights_dict)
    if return_remainder:
        return remainder_weights_dict
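
# A hypothetical end-to-end call site for the two loaders above (an assumed
# sketch, not part of this module; `checkpoint_path` is an illustrative name):
# load a transformers-format state dict, feed the encoder weights to the
# encoder, then route the leftover LM-head weights into the MLM task heads.
def _example_load_pretrained_weights(jiant_model, checkpoint_path: str):
    import torch

    weights_dict = torch.load(checkpoint_path, map_location="cpu")
    remainder = load_encoder_from_transformers_weights(
        encoder=jiant_model.encoder, weights_dict=weights_dict, return_remainder=True
    )
    # Any keys still unmatched after head-loading are returned for inspection.
    return load_lm_heads_from_transformers_weights(jiant_model, remainder)
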
def get_mlm_weights_dict(self, weights_dict):
    mlm_weights_dict = {
        strings.remove_prefix(k, "lm_head."): v for k, v in weights_dict.items()
    }
    # The checkpoint stores a single "bias"; mirror it onto the decoder's bias key.
    mlm_weights_dict["decoder.bias"] = mlm_weights_dict["bias"]
    return mlm_weights_dict
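
# For reference, strings.remove_prefix (as used throughout these helpers) is
# assumed to behave like Python 3.9+ str.removeprefix, matching how it is
# applied here to keys that may or may not carry the prefix: strip the prefix
# when present, otherwise return the key unchanged. A stdlib-only equivalent:
def _remove_prefix(s: str, prefix: str) -> str:
    return s[len(prefix):] if s.startswith(prefix) else s
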
def load_encoder_from_transformers_weights(
    encoder: nn.Module, weights_dict: dict, return_remainder=False
):
    """Find encoder weights in weights dict, load them into encoder, return any remaining weights.

    TODO: clarify how we know the encoder weights will be prefixed by transformer name.

    Args:
        encoder (PreTrainedModel): Transformer w/o heads (embedding layer + self-attention layer).
        weights_dict (Dict): model weights.
        return_remainder (bool): If True, return any leftover weights.

    Returns:
        Dict containing any leftover weights.
    """
    remainder_weights_dict = {}
    load_weights_dict = {}
    model_arch = ModelArchitectures.from_model_type(model_type=encoder.config.model_type)
    encoder_prefix = model_arch.value + "."
    # Encoder
    for k, v in weights_dict.items():
        if k.startswith(encoder_prefix):
            load_weights_dict[strings.remove_prefix(k, encoder_prefix)] = v
        elif k.startswith(encoder_prefix.split("-")[0]):
            # workaround for deberta-v2: drop the "-v2" suffix, since weight names
            # are prefixed with "deberta." and not "deberta-v2."
            load_weights_dict[strings.remove_prefix(k, encoder_prefix.split("-")[0] + ".")] = v
        else:
            remainder_weights_dict[k] = v
    encoder.load_state_dict(load_weights_dict, strict=False)
    if remainder_weights_dict:
        warnings.warn(
            "The following weights were not loaded: {}".format(remainder_weights_dict.keys())
        )
    if return_remainder:
        return remainder_weights_dict
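
# Toy illustration of the deberta-v2 fallback branch above (made-up key, but
# the prefix mismatch is real): the architecture value is "deberta-v2", while
# checkpoint keys are prefixed with plain "deberta.".
def _deberta_v2_prefix_demo():
    encoder_prefix = "deberta-v2."
    key = "deberta.embeddings.word_embeddings.weight"
    assert not key.startswith(encoder_prefix)
    fallback_prefix = encoder_prefix.split("-")[0] + "."  # -> "deberta."
    assert key.startswith(fallback_prefix)
    assert key[len(fallback_prefix):] == "embeddings.word_embeddings.weight"
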
def get_mlm_weights_dict(self, weights_dict):
    mlm_weights_dict = {
        strings.remove_prefix(k, "predictions."): v for k, v in weights_dict.items()
    }
    return mlm_weights_dict
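
# Toy contrast between the two get_mlm_weights_dict variants above (made-up
# key names): RoBERTa-style checkpoints nest MLM-head weights under
# "lm_head.*", while ALBERT-style checkpoints nest them under "predictions.*".
def _mlm_prefix_demo():
    roberta_key = "lm_head.dense.weight"
    albert_key = "predictions.dense.weight"
    assert roberta_key[len("lm_head."):] == "dense.weight"
    assert albert_key[len("predictions."):] == "dense.weight"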