def construct_qa_transformer(options: KaggleEvaluationOptions) -> Reranker:
    """Build a question-answering reranker from a sequence-classification checkpoint.

    Workaround: the checkpoint is saved as a sequence-classification model, so we
    load it as such, then graft its backbone and classifier head onto a fresh
    ``BertForQuestionAnswering`` shell. TODO: refactor once checkpoints are saved
    in the QA format directly.
    """
    # Load the checkpoint, falling back to a TensorFlow-format load if the
    # PyTorch weights are not present.
    try:
        seq_cls_model = AutoModelForSequenceClassification.from_pretrained(
            options.model_name)
    except OSError:
        seq_cls_model = AutoModelForSequenceClassification.from_pretrained(
            options.model_name, from_tf=True)

    # Transplant the loaded weights into a QA-shaped model: the classifier
    # layer becomes the span-prediction head, the encoder is reused as-is.
    qa_model = BertForQuestionAnswering(seq_cls_model.config)
    qa_model.qa_outputs = seq_cls_model.classifier
    qa_model.bert = seq_cls_model.bert

    qa_model = qa_model.to(torch.device(options.device)).eval()

    tokenizer = AutoTokenizer.from_pretrained(
        options.tokenizer_name, do_lower_case=options.do_lower_case)
    return QuestionAnsweringTransformerReranker(qa_model, tokenizer)
def demo4():
    """Demo: extractive question answering with an (unfine-tuned) BERT model.

    Loads a local bert-base-uncased checkpoint, asks one question against one
    context sentence, and prints the decoded answer span. Because the QA head
    is randomly initialized (no fine-tuning), the printed answer is expected
    to be poor.
    """
    # Bug fix: the original referenced `transformers.BertConfig` /
    # `transformers.BertModel` without importing the `transformers` module
    # itself, which raised NameError. Import the names explicitly instead.
    from transformers import (BertConfig, BertForQuestionAnswering, BertModel,
                              BertTokenizer)
    import torch

    MODEL_PATH = r"D:\transformr_files\bert-base-uncased/"

    # Instantiate the tokenizer from the local vocab file.
    tokenizer = BertTokenizer.from_pretrained(
        r"D:\transformr_files\bert-base-uncased\bert-base-uncased-vocab.txt")

    # Load BERT's model config, then the backbone itself.
    model_config = BertConfig.from_pretrained(MODEL_PATH)
    bert_model = BertModel.from_pretrained(MODEL_PATH, config=model_config)

    # The QA head emits two logits per token: span start and span end.
    model_config.num_labels = 2

    # Build a QA-shaped model from the same config and reuse the loaded
    # backbone (the QA head stays randomly initialized).
    model = BertForQuestionAnswering(model_config)
    model.bert = bert_model
    model.eval()

    question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"

    # Encode question + context as a single sequence:
    # [CLS] question [SEP] context [SEP]
    input_ids = tokenizer.encode(question, text)
    # Build token_type_ids by hand (encode_plus would do this for us):
    # 0 up to and including the first [SEP] (id 102), 1 afterwards.
    token_type_ids = [
        0 if i <= input_ids.index(102) else 1
        for i in range(len(input_ids))
    ]

    # Inference only — no gradients needed.
    with torch.no_grad():
        outputs = model(torch.tensor([input_ids]),
                        token_type_ids=torch.tensor([token_type_ids]))
    # Bug fix: modern transformers returns a QuestionAnsweringModelOutput
    # rather than a plain tuple, so direct tuple-unpacking fails. Integer
    # indexing works on both the legacy tuple and the ModelOutput object.
    start_scores, end_scores = outputs[0], outputs[1]

    # Map ids back to tokens, e.g.
    # ['[CLS]', 'who', 'was', 'jim', 'henson', '?', '[SEP]', 'jim', ...]
    all_tokens = tokenizer.convert_ids_to_tokens(input_ids)

    # Decode the answer: the span between the argmax start and argmax end.
    answer = ' '.join(
        all_tokens[torch.argmax(start_scores):torch.argmax(end_scores) + 1])
    # assert answer == "a nice puppet"
    # With no fine-tuning the head is random, so the output is usually wrong.
    print(answer)