def load_encoder_from_transformers_weights(
    encoder: nn.Module, weights_dict: dict, return_remainder=False
):
    """Find encoder weights in weights dict, load them into encoder, return any remaining weights.

    TODO: clarify how we know the encoder weights will be prefixed by transformer name.

    Args:
        encoder (PreTrainedModel): Transformer w/o heads (embedding layer + self-attention layers).
        weights_dict (Dict): model weights.
        return_remainder (bool): If True, return any leftover weights.

    Returns:
        Dict containing any leftover weights.

    """
    remainder_weights_dict = {}
    load_weights_dict = {}
    model_arch = ModelArchitectures.from_encoder(encoder=encoder)
    encoder_prefix = MODEL_PREFIX[model_arch] + "."
    # Encoder: strip the architecture prefix and split out non-encoder weights
    for k, v in weights_dict.items():
        if k.startswith(encoder_prefix):
            load_weights_dict[strings.remove_prefix(k, encoder_prefix)] = v
        else:
            remainder_weights_dict[k] = v
    encoder.load_state_dict(load_weights_dict)
    if return_remainder:
        return remainder_weights_dict
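
# Usage sketch (illustrative, not part of this module): load a transformers-style
# state dict into a bare encoder and collect the non-encoder (e.g. head) weights.
# The checkpoint path and pretrained model name below are assumptions.
#
#     import torch
#     import transformers
#
#     encoder = transformers.RobertaModel.from_pretrained("roberta-base")
#     weights_dict = torch.load("/path/to/checkpoint.p", map_location="cpu")
#     head_weights = load_encoder_from_transformers_weights(
#         encoder=encoder, weights_dict=weights_dict, return_remainder=True,
#     )
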
@contextlib.contextmanager
def output_hidden_states_context(encoder):
    """Context manager that temporarily sets the encoder to output hidden states.

    The previous ``output_hidden_states`` value is restored on exit.
    """
    model_arch = ModelArchitectures.from_encoder(encoder)
    if model_arch in (
        ModelArchitectures.BERT,
        ModelArchitectures.ROBERTA,
        ModelArchitectures.ALBERT,
        ModelArchitectures.XLM_ROBERTA,
        ModelArchitectures.ELECTRA,
    ):
        if hasattr(encoder.encoder, "output_hidden_states"):
            # Transformers < v3
            modified_obj = encoder.encoder
        elif hasattr(encoder.encoder.config, "output_hidden_states"):
            # Transformers >= v3
            modified_obj = encoder.encoder.config
        else:
            raise RuntimeError(
                f"Failed to convert model {type(encoder)} to output hidden states"
            )
        old_value = modified_obj.output_hidden_states
        modified_obj.output_hidden_states = True
        yield
        modified_obj.output_hidden_states = old_value
    elif model_arch in (ModelArchitectures.BART, ModelArchitectures.MBART):
        yield
        return
    else:
        raise KeyError(model_arch)
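
# Usage sketch (assumed call pattern): request hidden states only for the duration
# of a forward pass; the encoder's previous setting is restored when the block exits.
#
#     with output_hidden_states_context(encoder):
#         outputs = encoder(input_ids=input_ids, attention_mask=input_mask)
#     # Outside the block, output_hidden_states is back to its original value.
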
def get_output_from_encoder(encoder, input_ids, segment_ids, input_mask) -> EncoderOutput:
    """Pass inputs to encoder, return encoder output.

    Args:
        encoder: bare model outputting raw hidden-states without any specific head.
        input_ids: token indices (see huggingface.co/transformers/glossary.html#input-ids).
        segment_ids: token type ids (see huggingface.co/transformers/glossary.html#token-type-ids).
        input_mask: attention mask (see huggingface.co/transformers/glossary.html#attention-mask).

    Raises:
        RuntimeError if encoder output contains fewer than 2 elements.

    Returns:
        EncoderOutput containing pooled and unpooled model outputs as well as any other outputs.

    """
    model_arch = ModelArchitectures.from_encoder(encoder)
    if model_arch in [
        ModelArchitectures.BERT,
        ModelArchitectures.ROBERTA,
        ModelArchitectures.ALBERT,
        ModelArchitectures.XLM_ROBERTA,
    ]:
        pooled, unpooled, other = get_output_from_standard_transformer_models(
            encoder=encoder,
            input_ids=input_ids,
            segment_ids=segment_ids,
            input_mask=input_mask,
        )
    elif model_arch == ModelArchitectures.ELECTRA:
        pooled, unpooled, other = get_output_from_electra(
            encoder=encoder,
            input_ids=input_ids,
            segment_ids=segment_ids,
            input_mask=input_mask,
        )
    elif model_arch in [
        ModelArchitectures.BART,
        ModelArchitectures.MBART,
    ]:
        pooled, unpooled, other = get_output_from_bart_models(
            encoder=encoder,
            input_ids=input_ids,
            input_mask=input_mask,
        )
    else:
        raise KeyError(model_arch)

    # Extend later with attention, hidden_acts, etc.
    if other:
        return EncoderOutput(pooled=pooled, unpooled=unpooled, other=other)
    else:
        return EncoderOutput(pooled=pooled, unpooled=unpooled)
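
# Usage sketch (the `batch` attribute names and shapes are illustrative assumptions):
# the same call works for any supported architecture; dispatch on the encoder type
# happens inside the function.
#
#     encoder_output = get_output_from_encoder(
#         encoder=encoder,
#         input_ids=batch.input_ids,        # [batch_size, seq_len]
#         segment_ids=batch.segment_ids,    # [batch_size, seq_len]
#         input_mask=batch.input_mask,      # [batch_size, seq_len]
#     )
#     pooled, unpooled = encoder_output.pooled, encoder_output.unpooled
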
def get_model_arch_from_jiant_model(jiant_model: nn.Module) -> ModelArchitectures:
    return ModelArchitectures.from_encoder(encoder=jiant_model.encoder)