Example #1
def load_encoder_from_transformers_weights(
    encoder: nn.Module, weights_dict: dict, return_remainder=False
):
    """Find encoder weights in weights dict, load them into encoder, return any remaining weights.

    TODO: clarify how we know the encoder weights will be prefixed by transformer name.

    Args:
        encoder (PreTrainedModel): Transformer w/o heads (embedding layer + self-attention layer).
        weights_dict (Dict): model weights.
        return_remainder (bool): If True, return any leftover weights.

    Returns:
        Dict containing any leftover weights (only returned if return_remainder is True).

    """
    remainder_weights_dict = {}
    load_weights_dict = {}
    model_arch = ModelArchitectures.from_encoder(encoder=encoder)
    encoder_prefix = MODEL_PREFIX[model_arch] + "."
    # Encoder
    for k, v in weights_dict.items():
        if k.startswith(encoder_prefix):
            load_weights_dict[strings.remove_prefix(k, encoder_prefix)] = v
        else:
            remainder_weights_dict[k] = v
    encoder.load_state_dict(load_weights_dict)
    if return_remainder:
        return remainder_weights_dict
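
A minimal usage sketch (not part of the original example), assuming the helper above and its jiant dependencies (ModelArchitectures, MODEL_PREFIX, strings) are in scope and that MODEL_PREFIX maps the BERT architecture to the usual "bert" prefix; the sequence-classification checkpoint is used purely for illustration:

from transformers import BertForSequenceClassification, BertModel

# Bare encoder that will receive the weights.
encoder = BertModel.from_pretrained("bert-base-uncased")

# Any state dict whose encoder keys carry the "bert." prefix works the same way;
# a sequence-classification model is just a convenient source of such a dict.
full_model = BertForSequenceClassification.from_pretrained("bert-base-uncased")
weights_dict = full_model.state_dict()

# Encoder weights are loaded in place; non-encoder weights (e.g. classifier.*) come back.
leftover = load_encoder_from_transformers_weights(
    encoder=encoder, weights_dict=weights_dict, return_remainder=True
)
print(sorted(leftover.keys()))  # e.g. ['classifier.bias', 'classifier.weight']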
Example #2
import contextlib


@contextlib.contextmanager
def output_hidden_states_context(encoder):
    """Temporarily set output_hidden_states=True on the encoder, restoring the previous value on exit."""
    model_arch = ModelArchitectures.from_encoder(encoder)
    if model_arch in (
            ModelArchitectures.BERT,
            ModelArchitectures.ROBERTA,
            ModelArchitectures.ALBERT,
            ModelArchitectures.XLM_ROBERTA,
            ModelArchitectures.ELECTRA,
    ):
        if hasattr(encoder.encoder, "output_hidden_states"):
            # Transformers < v2
            modified_obj = encoder.encoder
        elif hasattr(encoder.encoder.config, "output_hidden_states"):
            # Transformers >= v3
            modified_obj = encoder.encoder.config
        else:
            raise RuntimeError(
                f"Failed to convert model {type(encoder)} to output hidden states"
            )
        old_value = modified_obj.output_hidden_states
        modified_obj.output_hidden_states = True
        yield
        modified_obj.output_hidden_states = old_value
    elif model_arch in (ModelArchitectures.BART, ModelArchitectures.MBART):
        yield
        return
    else:
        raise KeyError(model_arch)
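
A short sketch of the intended use, assuming the contextlib.contextmanager decorator added above and a plain Hugging Face BertModel as the encoder; whether the per-layer states then appear under outputs.hidden_states or as a tuple element depends on the transformers version:

from transformers import BertModel, BertTokenizer

encoder = BertModel.from_pretrained("bert-base-uncased")
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
batch = tokenizer("a short example sentence", return_tensors="pt")

# output_hidden_states is flipped on only for the duration of the block,
# then restored to its previous value.
with output_hidden_states_context(encoder):
    outputs = encoder(**batch)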
Example #3
def get_output_from_encoder(encoder, input_ids, segment_ids,
                            input_mask) -> EncoderOutput:
    """Pass inputs to encoder, return encoder output.

    Args:
        encoder: bare model outputting raw hidden-states without any specific head.
        input_ids: token indices (see huggingface.co/transformers/glossary.html#input-ids).
        segment_ids: token type ids (see huggingface.co/transformers/glossary.html#token-type-ids).
        input_mask: attention mask (see huggingface.co/transformers/glossary.html#attention-mask).

    Raises:
        RuntimeError if the encoder output contains fewer than 2 elements.
        KeyError if the encoder's model architecture is not supported.

    Returns:
        EncoderOutput containing pooled and unpooled model outputs as well as any other outputs.

    """
    model_arch = ModelArchitectures.from_encoder(encoder)
    if model_arch in [
            ModelArchitectures.BERT,
            ModelArchitectures.ROBERTA,
            ModelArchitectures.ALBERT,
            ModelArchitectures.XLM_ROBERTA,
    ]:
        pooled, unpooled, other = get_output_from_standard_transformer_models(
            encoder=encoder,
            input_ids=input_ids,
            segment_ids=segment_ids,
            input_mask=input_mask,
        )
    elif model_arch == ModelArchitectures.ELECTRA:
        pooled, unpooled, other = get_output_from_electra(
            encoder=encoder,
            input_ids=input_ids,
            segment_ids=segment_ids,
            input_mask=input_mask,
        )
    elif model_arch in [
            ModelArchitectures.BART,
            ModelArchitectures.MBART,
    ]:
        pooled, unpooled, other = get_output_from_bart_models(
            encoder=encoder,
            input_ids=input_ids,
            input_mask=input_mask,
        )
    else:
        raise KeyError(model_arch)

    # Extend later with attention, hidden_acts, etc
    if other:
        return EncoderOutput(pooled=pooled, unpooled=unpooled, other=other)
    else:
        return EncoderOutput(pooled=pooled, unpooled=unpooled)
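
A sketch of a call to get_output_from_encoder, assuming the jiant helpers it dispatches to (get_output_from_standard_transformer_models, EncoderOutput, etc.) are in scope; the batch size, sequence length, and token ids below are made up for illustration:

import torch
from transformers import BertModel

encoder = BertModel.from_pretrained("bert-base-uncased")

batch_size, seq_len = 2, 8
input_ids = torch.randint(low=1000, high=2000, size=(batch_size, seq_len))
segment_ids = torch.zeros(batch_size, seq_len, dtype=torch.long)
input_mask = torch.ones(batch_size, seq_len, dtype=torch.long)

output = get_output_from_encoder(
    encoder=encoder,
    input_ids=input_ids,
    segment_ids=segment_ids,
    input_mask=input_mask,
)
# pooled: (batch_size, hidden_size); unpooled: (batch_size, seq_len, hidden_size)
print(output.pooled.shape, output.unpooled.shape)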
Example #4
def get_model_arch_from_jiant_model(jiant_model: nn.Module) -> ModelArchitectures:
    return ModelArchitectures.from_encoder(encoder=jiant_model.encoder)
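
For completeness, a sketch of this accessor in use; the SimpleNamespace below is only a hypothetical stand-in for a real JiantModel, since any object exposing an .encoder attribute that holds a Hugging Face model suffices here:

from types import SimpleNamespace
from transformers import RobertaModel

# Hypothetical stand-in for a JiantModel: only the .encoder attribute is needed.
jiant_model = SimpleNamespace(encoder=RobertaModel.from_pretrained("roberta-base"))

arch = get_model_arch_from_jiant_model(jiant_model=jiant_model)
assert arch == ModelArchitectures.ROBERTA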