import random

import numpy as np
import torch
import torch.nn.functional as F
# Assumed library: the API used below (BertConfig.from_json_file, models
# returning raw tensors, tokenizer.ids_to_tokens) matches pytorch-pretrained-bert.
from pytorch_pretrained_bert import BertConfig, BertForMaskedLM, BertTokenizer


class Classifier(torch.nn.Module):
    def __init__(self, hidden_size=768, linear_out=2):
        super(Classifier, self).__init__()
        self.output_model_file = "lm/pytorch_model.bin"
        self.output_config_file = "lm/config.json"
        self.tokenizer = BertTokenizer.from_pretrained("lm", do_lower_case=False)
        self.config = BertConfig.from_json_file(self.output_config_file)
        self.model = BertForMaskedLM(self.config)
        # Keep the device on the instance so the helper methods below can use it.
        self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
        self.state_dict = torch.load(self.output_model_file, map_location=self.device)
        self.model.load_state_dict(self.state_dict)
        self.lstm = torch.nn.LSTM(hidden_size, 300)
        self.linear = torch.nn.Linear(300, linear_out)
        # Move all submodules (BERT, LSTM, linear head) to the chosen device.
        self.to(self.device)

    def get_embeddings(self, x_instance):
        indexed_tokens = x_instance.tolist()
        # 102 is the [SEP] token id; everything up to and including it is segment A.
        break_sentence = indexed_tokens.index(102)
        tokens_tensor = torch.tensor([indexed_tokens])
        segments_ids = [0] * (break_sentence + 1)
        segments_ids += [1] * (len(indexed_tokens) - break_sentence - 1)
        segments_tensors = torch.tensor([segments_ids])
        self.model.eval()
        with torch.no_grad():
            encoded_layers, _ = self.model.bert(tokens_tensor.to(self.device),
                                                segments_tensors.to(self.device))
        # [n_layers, 1, seq_len, hidden] -> [seq_len, n_layers, hidden]
        token_embeddings = torch.stack(encoded_layers, dim=0)
        token_embeddings = torch.squeeze(token_embeddings, dim=1)
        token_embeddings = token_embeddings.permute(1, 0, 2)
        # One vector per token: the mean of its last four hidden layers.
        token_vecs = []
        for token in token_embeddings:
            last_four = torch.stack((token[-1], token[-2], token[-3], token[-4]))
            token_vecs.append(torch.mean(last_four, 0))
        return torch.stack(token_vecs, dim=0)

    def embed_data(self, x):
        entries = []
        for entry in x:
            emb = self.get_embeddings(entry.to(self.device)).to(self.device)
            entries.append(emb)
        return torch.stack(entries)

    def forward(self, x):
        h = self.embed_data(x)
        # The LSTM expects (seq_len, batch, features); inputs arrive batch-first.
        h = h.permute(1, 0, 2)
        output, _ = self.lstm(h)
        pred = self.linear(output)
        # Back to (batch, seq_len, linear_out).
        pred = pred.permute(1, 0, 2)
        return pred
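# A minimal usage sketch for Classifier (not from the original source): it
# assumes a fine-tuned checkpoint and vocab actually exist under "lm/", and that
# every input row is a [CLS] ... [SEP] ... [SEP] token-id sequence, since
# get_embeddings looks for the [SEP] id (102) to build the segment ids.
def _demo_classifier():
    clf = Classifier(hidden_size=768, linear_out=2)
    tokens = clf.tokenizer.tokenize("[CLS] The cat sat [SEP] on the mat [SEP]")
    token_ids = clf.tokenizer.convert_tokens_to_ids(tokens)
    batch = torch.tensor([token_ids])  # (batch=1, seq_len)
    logits = clf(batch)                # (batch, seq_len, linear_out)
    print(logits.shape)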
def create_bert_for_masked_lm(self, config, input_ids, token_type_ids, input_mask,
                              sequence_labels, token_labels, choice_labels):
    # sequence_labels and choice_labels are unused here; presumably kept so the
    # signature matches the sibling create_* helpers.
    model = BertForMaskedLM(config=config)
    model.eval()
    # Positional args: (input_ids, token_type_ids, attention_mask, masked_lm_labels).
    # With labels the model returns the masked-LM loss; without them, the raw
    # prediction scores over the vocabulary.
    loss = model(input_ids, token_type_ids, input_mask, token_labels)
    prediction_scores = model(input_ids, token_type_ids, input_mask)
    outputs = {
        "loss": loss,
        "prediction_scores": prediction_scores,
    }
    return outputs
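# A hedged, self-contained driver for the check above. All sizes and the tiny
# BertConfig are illustrative assumptions, not from the original source; the
# small config just keeps the randomly initialised model cheap to run.
def _demo_create_bert_for_masked_lm():
    config = BertConfig(vocab_size_or_config_json_file=99, hidden_size=32,
                        num_hidden_layers=2, num_attention_heads=4,
                        intermediate_size=37)
    batch_size, seq_length = 2, 7
    input_ids = torch.randint(0, 99, (batch_size, seq_length))
    token_type_ids = torch.zeros(batch_size, seq_length, dtype=torch.long)
    input_mask = torch.ones(batch_size, seq_length, dtype=torch.long)
    token_labels = torch.randint(0, 99, (batch_size, seq_length))
    # self is unused by create_bert_for_masked_lm, so None stands in for it here.
    outputs = create_bert_for_masked_lm(None, config, input_ids, token_type_ids,
                                        input_mask, None, token_labels, None)
    assert outputs["prediction_scores"].shape == (batch_size, seq_length, 99)
    return outputs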
def get_words_for_blank_slow_decode(text: str, model: BertForMaskedLM,
                                    tokenizer: BertTokenizer):
    random.seed(42)
    np.random.seed(42)
    torch.manual_seed(42)
    # Replace each '_' placeholder with [MASK] and remember its position.
    mask_positions = []
    tokenized_text = tokenizer.tokenize(text)
    top_words_all = []
    for i in range(len(tokenized_text)):
        if tokenized_text[i] == '_':
            tokenized_text[i] = '[MASK]'
            mask_positions.append(i)
    # Greedy decoding: fill the single most confident blank, then re-predict.
    while mask_positions:
        top_words = []
        # Convert tokens to vocab indices
        token_ids = tokenizer.convert_tokens_to_ids(tokenized_text)
        tokens_tensor = torch.tensor([token_ids])
        # Call BERT to calculate unnormalized probabilities for all positions
        model.eval()
        with torch.no_grad():
            predictions = model(tokens_tensor)
        candidates = []  # (word, prob)
        for mask_pos in mask_positions:
            mask_preds = predictions[0, mask_pos, :]
            top_idxs = mask_preds.numpy().argsort()[::-1]
            top_idx = top_idxs[0]
            top_prob = mask_preds[top_idx]
            top_word = tokenizer.ids_to_tokens[top_idx]
            candidates.append((top_word, top_prob.item()))
            top_words_pos = []
            for i in top_idxs[:20]:
                top_words_pos.append((tokenizer.ids_to_tokens[i], mask_preds[i].item()))
            top_words.append(top_words_pos)
        # Commit the blank whose best word has the highest score.
        best_candidate = max(candidates, key=lambda x: x[1])
        best_pos = mask_positions[candidates.index(best_candidate)]
        tokenized_text[best_pos] = best_candidate[0]
        mask_positions = [i for i in mask_positions if i != best_pos]
        top_words_all.append(top_words[candidates.index(best_candidate)])
    pred_sent = ' '.join(tokenized_text).replace(' ##', '')
    return (pred_sent, top_words_all)
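# Illustrative usage of the slow decoder (the "bert-base-cased" weights and the
# example template are assumptions, not from the original source). Each '_' in
# the template marks a blank to fill.
def _demo_slow_decode():
    tokenizer = BertTokenizer.from_pretrained("bert-base-cased", do_lower_case=False)
    model = BertForMaskedLM.from_pretrained("bert-base-cased")
    pred_sent, top_words_all = get_words_for_blank_slow_decode(
        "The _ sat on the _ .", model, tokenizer)
    print(pred_sent)             # template with both blanks greedily filled
    print(top_words_all[0][:3])  # top-3 (word, score) pairs for the first-filled blank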
def predict_word(text: str, model: BertForMaskedLM, tokenizer: BertTokenizer,
                 tgt_word: str, tgt_pos: int):
    # print('Template sentence: ', text)
    # Replace each '_' placeholder with [MASK] and remember its position.
    mask_positions = []
    tokenized_text = tokenizer.tokenize(text)
    for i in range(len(tokenized_text)):
        if tokenized_text[i] == '_':
            tokenized_text[i] = '[MASK]'
            mask_positions.append(i)
    # Convert tokens to vocab indices
    token_ids = tokenizer.convert_tokens_to_ids(tokenized_text)
    tokens_tensor = torch.tensor([token_ids])
    # Call BERT to calculate unnormalized probabilities for all positions
    model.eval()
    with torch.no_grad():
        predictions = model(tokens_tensor)
    # Normalize with softmax over the vocabulary dimension.
    predictions = F.softmax(predictions, dim=2)
    # Probability of the target word at the target position.
    normalized = predictions[0, tgt_pos, :]
    out_prob = normalized[tokenizer.vocab[tgt_word]].item()
    # Also fill in every blank with its highest-probability token, for inspection.
    for mask_pos in mask_positions:
        predicted_index = torch.argmax(predictions[0, mask_pos, :]).item()
        predicted_token = tokenizer.convert_ids_to_tokens([predicted_index])[0]
        tokenized_text[mask_pos] = predicted_token
    # Wrap the filled-in blanks in underscores so they stand out when printed.
    for mask_pos in mask_positions:
        tokenized_text[mask_pos] = "_" + tokenized_text[mask_pos] + "_"
    pred_sent = ' '.join(tokenized_text).replace(' ##', '')
    # print(pred_sent)
    return out_prob, pred_sent
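# Illustrative usage of predict_word (model name, template, and target word are
# assumptions, not from the original source): it scores how probable "cat" is
# at token position 1 while the other blank is filled in for display.
def _demo_predict_word():
    tokenizer = BertTokenizer.from_pretrained("bert-base-cased", do_lower_case=False)
    model = BertForMaskedLM.from_pretrained("bert-base-cased")
    # "The _ sat on the _ ." tokenizes to ["The", "_", "sat", ...], so the
    # first blank sits at position 1.
    prob, pred_sent = predict_word("The _ sat on the _ .", model, tokenizer,
                                   tgt_word="cat", tgt_pos=1)
    print(prob)       # P(tgt_word at tgt_pos), softmax-normalized
    print(pred_sent)  # blanks filled greedily and wrapped in underscores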