Example #1
    def __init__(self,
                 params,
                 tokenizer,
                 start_mention_id=None,
                 end_mention_id=None):
        super(CrossEncoderModule, self).__init__()
        model_path = params["bert_model"]
        if params.get("roberta"):
            encoder_model = RobertaModel.from_pretrained(model_path)
        else:
            encoder_model = BertModel.from_pretrained(model_path)
        # Grow the embedding table to cover special tokens added to the tokenizer.
        encoder_model.resize_token_embeddings(len(tokenizer))
        self.pool_highlighted = params["pool_highlighted"]
        # When pooling over highlighted mention tokens, skip the encoder's own
        # linear head and request per-token outputs instead of just [CLS].
        self.encoder = BertEncoder(encoder_model,
                                   params["out_dim"],
                                   layer_pulled=params["pull_from_layer"],
                                   add_linear=params["add_linear"]
                                   and not self.pool_highlighted,
                                   get_all_outputs=self.pool_highlighted)
        self.config = self.encoder.bert_model.config
        # Token ids marking the boundaries of the highlighted mention spans.
        self.start_mention_id = start_mention_id
        self.end_mention_id = end_mention_id

        if self.pool_highlighted:
            bert_output_dim = encoder_model.embeddings.word_embeddings.weight.size(1)
            output_dim = params["out_dim"]
            # Maps the concatenation of the two pooled highlighted-span
            # vectors (hence 2 * bert_output_dim) down to the output dimension.
            self.additional_linear = nn.Linear(2 * bert_output_dim, output_dim)
            self.dropout = nn.Dropout(0.1)
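
A minimal usage sketch for the constructor above. The params values, the [M_START]/[M_END] marker tokens, and the tokenizer setup are illustrative assumptions, not taken from the original code; the Hugging Face transformers package is assumed for the tokenizer.

# Hypothetical setup for CrossEncoderModule; all values are assumptions.
from transformers import BertTokenizer

params = {
    "bert_model": "bert-base-uncased",  # local path or model id
    "out_dim": 1,                       # scalar score per input pair
    "pull_from_layer": -1,              # pull from the last encoder layer
    "add_linear": True,
    "pool_highlighted": True,           # pool over mention tokens, not [CLS]
}

tokenizer = BertTokenizer.from_pretrained(params["bert_model"])
# Hypothetical mention-boundary markers; resize_token_embeddings in the
# constructor grows the embedding table to cover them.
tokenizer.add_special_tokens(
    {"additional_special_tokens": ["[M_START]", "[M_END]"]})
start_id, end_id = tokenizer.convert_tokens_to_ids(["[M_START]", "[M_END]"])

model = CrossEncoderModule(params, tokenizer,
                           start_mention_id=start_id,
                           end_mention_id=end_id)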
Example #2
    def __init__(self, params):
        super(BiEncoderModule, self).__init__()
        # Separate BERT instances for the context and candidate towers,
        # loaded from the same checkpoint but with untied weights.
        ctxt_bert = BertModel.from_pretrained(params["bert_model"])
        cand_bert = BertModel.from_pretrained(params["bert_model"])
        self.context_encoder = BertEncoder(
            ctxt_bert,
            params["out_dim"],
            layer_pulled=params["pull_from_layer"],
            add_linear=params["add_linear"],
        )
        self.cand_encoder = BertEncoder(
            cand_bert,
            params["out_dim"],
            layer_pulled=params["pull_from_layer"],
            add_linear=params["add_linear"],
        )
        self.config = ctxt_bert.config
Example #3
    def __init__(self, params):
        super(BiEncoderModule, self).__init__()
        # "bert_model" may be a local path containing config.json and
        # pytorch_model.bin, or a shorthand id for a model known to the library.
        ctxt_bert = BertModel.from_pretrained(params["bert_model"])
        cand_bert = BertModel.from_pretrained(params["bert_model"])
        self.context_encoder = BertEncoder(
            ctxt_bert,
            params["out_dim"],
            layer_pulled=params["pull_from_layer"],
            add_linear=params["add_linear"],
        )
        self.cand_encoder = BertEncoder(
            cand_bert,
            params["out_dim"],
            layer_pulled=params["pull_from_layer"],
            add_linear=params["add_linear"],
        )
        self.config = ctxt_bert.config
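
Examples #2 and #3 build the same two-tower bi-encoder: the context and candidate encoders have untied weights, so candidate representations can be computed once and cached for retrieval. A construction sketch with assumed params values:

# Hypothetical params for BiEncoderModule; all values are assumptions.
params = {
    "bert_model": "bert-base-uncased",
    "out_dim": 100,          # shared embedding size for dot-product scoring
    "pull_from_layer": -1,   # pull from the last encoder layer
    "add_linear": True,      # project the pooled output down to out_dim
}

module = BiEncoderModule(params)
# Contexts and candidates are encoded by separate towers, so scoring
# reduces to a similarity between two independently computed vectors.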
Example #4
    def __init__(self, params, tokenizer):
        super(CrossEncoderModule, self).__init__()
        model_path = params["bert_model"]
        if params.get("roberta"):
            encoder_model = RobertaModel.from_pretrained(model_path)
        else:
            encoder_model = BertModel.from_pretrained(model_path)
        # Account for any special tokens added to the tokenizer.
        encoder_model.resize_token_embeddings(len(tokenizer))
        self.encoder = BertEncoder(
            encoder_model,
            params["out_dim"],
            layer_pulled=params["pull_from_layer"],
            add_linear=params["add_linear"],
        )
        self.config = self.encoder.bert_model.config
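
When params contains a truthy "roberta" key, the constructors in Examples #1 and #4 load RobertaModel from the same "bert_model" path instead of BertModel. A sketch of that branch against Example #4's signature, with assumed values and the Hugging Face transformers tokenizer:

# Hypothetical RoBERTa configuration; all values are assumptions.
from transformers import RobertaTokenizer

params = {
    "bert_model": "roberta-base",  # reused key; here it names a RoBERTa checkpoint
    "roberta": True,               # selects the RobertaModel branch above
    "out_dim": 1,
    "pull_from_layer": -1,
    "add_linear": True,
}

tokenizer = RobertaTokenizer.from_pretrained(params["bert_model"])
# Any tokens added here are covered by resize_token_embeddings(len(tokenizer)).
model = CrossEncoderModule(params, tokenizer)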