Example #1
import os

import torch
from transformers import BertConfig, BertModel, BertTokenizer


class RetrievalModel(torch.nn.Module):
    def __init__(self, bert_path_or_config, fc_path=None):
        super().__init__()
        # Either load a pretrained checkpoint (path or model name) or build
        # a freshly initialized encoder from an explicit BertConfig.
        if isinstance(bert_path_or_config, str):
            self.bert = BertModel.from_pretrained(
                pretrained_model_name_or_path=bert_path_or_config)
            self.tokenizer = BertTokenizer.from_pretrained(bert_path_or_config)
        elif isinstance(bert_path_or_config, BertConfig):
            self.bert = BertModel(bert_path_or_config)
            self.tokenizer = None
        else:
            raise TypeError("bert_path_or_config must be a str or a BertConfig")

        # Single-score head over the pooled [CLS] representation.
        self.fc = torch.nn.Linear(in_features=self.bert.config.hidden_size,
                                  out_features=1,
                                  bias=False)
        if fc_path:
            self.fc.load_state_dict(torch.load(fc_path, map_location="cpu"))

    def forward(self,
                input_ids=None,
                attention_mask=None,
                token_type_ids=None):
        # transformers v4+ returns a ModelOutput rather than a plain tuple,
        # so read the pooled [CLS] vector by attribute instead of unpacking.
        outputs = self.bert(input_ids=input_ids,
                            attention_mask=attention_mask,
                            token_type_ids=token_type_ids,
                            return_dict=True)
        scores = self.fc(outputs.pooler_output)
        return scores

    def save(self, save_dir):
        # Persist the encoder (and tokenizer, when present) in the standard
        # Hugging Face layout; the scoring head gets its own weight file.
        self.bert.save_pretrained(save_dir)
        if self.tokenizer:
            self.tokenizer.save_pretrained(save_dir)
        torch.save(self.fc.state_dict(), os.path.join(save_dir,
                                                      "fc_weight.bin"))
Example #2
import copy
import os

from transformers import AutoTokenizer, BertModel


def run(pretrained_model, out_dir, num_layers=3):
    """Truncate a pretrained BERT to its bottom `num_layers` encoder layers."""
    os.makedirs(out_dir, exist_ok=True)

    tokenizer = AutoTokenizer.from_pretrained(pretrained_model)
    model = BertModel.from_pretrained(pretrained_model, return_dict=True)

    # Shrink a copy of the config, build the smaller model, then copy over
    # every matching weight; strict=False silently skips the checkpoint's
    # upper encoder layers, which have no counterpart in the small model.
    small_config = copy.deepcopy(model.config)
    small_config.num_hidden_layers = num_layers
    small_model = BertModel(small_config)
    small_model.load_state_dict(model.state_dict(), strict=False)

    tokenizer.save_pretrained(out_dir)
    small_model.save_pretrained(out_dir)
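A short usage sketch for the helper above; the checkpoint name and output directory are assumptions, and the reload merely verifies that the saved model kept the reduced depth.

# Illustrative call: keep only the bottom 3 layers of bert-base-uncased.
run("bert-base-uncased", "./bert-3layer", num_layers=3)

small = BertModel.from_pretrained("./bert-3layer")
print(small.config.num_hidden_layers)  # 3
print(len(small.encoder.layer))        # 3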
Example #3
    def test_export_custom_bert_model(self):
        from tempfile import NamedTemporaryFile, TemporaryDirectory

        from transformers import BertConfig, BertModel, BertTokenizerFast

        # Write a minimal WordPiece vocabulary to a temporary file and
        # construct a fast tokenizer from it.
        vocab = ["[UNK]", "[SEP]", "[CLS]", "[PAD]", "[MASK]", "some", "other", "words"]
        with NamedTemporaryFile(mode="w+t") as vocab_file:
            vocab_file.write("\n".join(vocab))
            vocab_file.flush()
            tokenizer = BertTokenizerFast(vocab_file.name)

        # Save a randomly initialized BERT whose vocab size matches the
        # tokenizer, then exercise the export helper on that checkpoint.
        with TemporaryDirectory() as bert_save_dir:
            model = BertModel(BertConfig(vocab_size=len(vocab)))
            model.save_pretrained(bert_save_dir)
            self._test_export(bert_save_dir, "pt", 12, tokenizer)
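The same pattern sketched standalone, since _test_export belongs to the surrounding test class and is not shown here; this sketch stops at a round-trip check that the saved model and the tokenizer agree on vocabulary size.

from tempfile import NamedTemporaryFile, TemporaryDirectory

from transformers import BertConfig, BertModel, BertTokenizerFast

vocab = ["[UNK]", "[SEP]", "[CLS]", "[PAD]", "[MASK]", "some", "other", "words"]
with NamedTemporaryFile(mode="w+t") as vocab_file:
    vocab_file.write("\n".join(vocab))
    vocab_file.flush()
    tokenizer = BertTokenizerFast(vocab_file.name)

with TemporaryDirectory() as save_dir:
    model = BertModel(BertConfig(vocab_size=len(vocab)))
    model.save_pretrained(save_dir)
    reloaded = BertModel.from_pretrained(save_dir)

assert reloaded.config.vocab_size == tokenizer.vocab_size == len(vocab)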