# Example #1
def add_enc_adapters(bert_model: BertModel,
                     config: AdapterConfig) -> BertModel:
    """Insert adapter modules into every encoder layer and freeze the rest.

    Each encoder layer's attention output and feed-forward output sub-module
    is wrapped by an adapter-augmented replacement built from *config*.
    Afterwards every parameter is frozen except those belonging to adapters
    and layer norms, which stay trainable (standard adapter fine-tuning).

    Args:
        bert_model: The BERT model to modify in place.
        config: Adapter configuration passed to the adapter factories.

    Returns:
        The same ``bert_model`` instance, modified in place.
    """
    # Wrap the two output sub-modules of each encoder layer with
    # adapter-added versions. Iterate the layers directly instead of
    # indexing via range(len(...)).
    for layer in bert_model.encoder.layer:
        layer.attention.output = adapt_bert_self_output(config)(
            layer.attention.output)
        layer.output = adapt_bert_output(config)(layer.output)

    # Freeze the entire model first ...
    for param in bert_model.parameters():
        param.requires_grad = False
    # ... then re-enable gradients only for adapters and layer norms.
    # modules()/parameters() are used since the names are not needed.
    for sub_module in bert_model.modules():
        if isinstance(sub_module, (Adapter_func, BertLayerNorm)):
            for param in sub_module.parameters():
                param.requires_grad = True
    return bert_model