def __predict(self, data, include_timing=False, eval_raw=False):
    """
    Uses the trained model to make predictions for individual batches
    (i.e. documents).

    :return: predictions and, optionally, the time taken for the ED step.
    """
    predictions = {items[0]["doc_name"]: [] for items in data}
    self.model.eval()

    timing = []

    for batch in data:  # each document is a minibatch
        start = time.time()

        # Concatenate left and right context; fall back to a single UNK token
        # for mentions with no context at all.
        token_ids = [
            m["context"][0] + m["context"][1]
            if len(m["context"][0]) + len(m["context"][1]) > 0
            else [self.embeddings["word_voca"].unk_id]
            for m in batch
        ]
        s_ltoken_ids = [m["snd_ctx"][0] for m in batch]
        s_rtoken_ids = [m["snd_ctx"][1] for m in batch]
        s_mtoken_ids = [m["snd_ment"] for m in batch]

        entity_ids = Variable(
            torch.LongTensor([m["selected_cands"]["cands"] for m in batch]).to(
                self.device
            )
        )
        p_e_m = Variable(
            torch.FloatTensor([m["selected_cands"]["p_e_m"] for m in batch]).to(
                self.device
            )
        )
        entity_mask = Variable(
            torch.FloatTensor([m["selected_cands"]["mask"] for m in batch]).to(
                self.device
            )
        )
        true_pos = Variable(
            torch.LongTensor([m["selected_cands"]["true_pos"] for m in batch]).to(
                self.device
            )
        )

        # Pad all sequences in the batch to equal length and build 0/1 masks.
        token_ids, token_mask = utils.make_equal_len(
            token_ids, self.embeddings["word_voca"].unk_id
        )
        s_ltoken_ids, s_ltoken_mask = utils.make_equal_len(
            s_ltoken_ids, self.embeddings["snd_voca"].unk_id, to_right=False
        )
        s_rtoken_ids, s_rtoken_mask = utils.make_equal_len(
            s_rtoken_ids, self.embeddings["snd_voca"].unk_id
        )
        s_rtoken_ids = [l[::-1] for l in s_rtoken_ids]
        s_rtoken_mask = [l[::-1] for l in s_rtoken_mask]
        s_mtoken_ids, s_mtoken_mask = utils.make_equal_len(
            s_mtoken_ids, self.embeddings["snd_voca"].unk_id
        )

        token_ids = Variable(torch.LongTensor(token_ids).to(self.device))
        token_mask = Variable(torch.FloatTensor(token_mask).to(self.device))

        # Workaround: forward() reads the second-context tensors from
        # attributes on the model rather than from its arguments.
        self.model.s_ltoken_ids = Variable(
            torch.LongTensor(s_ltoken_ids).to(self.device)
        )
        self.model.s_ltoken_mask = Variable(
            torch.FloatTensor(s_ltoken_mask).to(self.device)
        )
        self.model.s_rtoken_ids = Variable(
            torch.LongTensor(s_rtoken_ids).to(self.device)
        )
        self.model.s_rtoken_mask = Variable(
            torch.FloatTensor(s_rtoken_mask).to(self.device)
        )
        self.model.s_mtoken_ids = Variable(
            torch.LongTensor(s_mtoken_ids).to(self.device)
        )
        self.model.s_mtoken_mask = Variable(
            torch.FloatTensor(s_mtoken_mask).to(self.device)
        )

        scores, ent_scores = self.model.forward(
            token_ids,
            token_mask,
            entity_ids,
            entity_mask,
            p_e_m,
            self.embeddings,
            gold=true_pos.view(-1, 1),
        )
        pred_ids = torch.argmax(scores, dim=1)
        scores = scores.cpu().data.numpy()
        confidence_scores = self.__compute_confidence(scores, pred_ids)
        pred_ids = np.argmax(scores, axis=1)

        if not eval_raw:
            # Fall back to the first candidate when the predicted slot is
            # masked out, and to "NIL" when there are no valid candidates.
            pred_entities = [
                m["selected_cands"]["named_cands"][i]
                if m["selected_cands"]["mask"][i] == 1
                else (
                    m["selected_cands"]["named_cands"][0]
                    if m["selected_cands"]["mask"][0] == 1
                    else "NIL"
                )
                for (i, m) in zip(pred_ids, batch)
            ]
            doc_names = [m["doc_name"] for m in batch]

            for dname, entity in zip(doc_names, pred_entities):
                predictions[dname].append({"pred": (entity, 0.0)})
        else:
            # Same fallback logic, but keep the mention, full candidate list,
            # raw scores, confidence, and mask alongside the prediction.
            pred_entities = [
                [
                    m["selected_cands"]["named_cands"][i],
                    m["raw"]["mention"],
                    m["selected_cands"]["named_cands"],
                    s,
                    cs,
                    m["selected_cands"]["mask"],
                ]
                if m["selected_cands"]["mask"][i] == 1
                else (
                    [
                        m["selected_cands"]["named_cands"][0],
                        m["raw"]["mention"],
                        m["selected_cands"]["named_cands"],
                        s,
                        cs,
                        m["selected_cands"]["mask"],
                    ]
                    if m["selected_cands"]["mask"][0] == 1
                    else [
                        "NIL",
                        m["raw"]["mention"],
                        m["selected_cands"]["named_cands"],
                        s,
                        cs,
                        m["selected_cands"]["mask"],
                    ]
                )
                for (i, m, s, cs) in zip(
                    pred_ids, batch, scores, confidence_scores
                )
            ]
            doc_names = [m["doc_name"] for m in batch]

            for dname, entity in zip(doc_names, pred_entities):
                if entity[0] != "NIL":
                    predictions[dname].append(
                        {
                            "mention": entity[1],
                            "prediction": entity[0],
                            "candidates": entity[2],
                            "conf_ed": entity[4],
                            "scores": [str(x) for x in entity[3]],
                        }
                    )
                else:
                    predictions[dname].append(
                        {
                            "mention": entity[1],
                            "prediction": entity[0],
                            "candidates": entity[2],
                            "scores": [],
                        }
                    )

        timing.append(time.time() - start)

    if include_timing:
        return predictions, timing
    else:
        return predictions
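# For reference, a hedged sketch of the structure __predict returns when
# eval_raw=True. Key names come from the code above; the document name and
# all values shown are hypothetical.
#
#     {
#         "doc_01": [
#             {
#                 "mention": "Berlin",
#                 "prediction": "Berlin",
#                 "candidates": ["Berlin", "Berlin_Wall", ...],
#                 "conf_ed": 0.91,
#                 "scores": ["0.91", "0.06", ...],
#             },
#             ...
#         ],
#     }
#
# With eval_raw=False each entry is instead {"pred": (entity, 0.0)}.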
def train(self, org_train_dataset, org_dev_datasets):
    """
    Responsible for training the ED model.

    :return: -
    """
    train_dataset = self.get_data_items(org_train_dataset, "train", predict=False)
    dev_datasets = []
    for dname, data in org_dev_datasets.items():
        dev_datasets.append((dname, self.get_data_items(data, dname, predict=True)))

    print("Creating optimizer")
    optimizer = optim.Adam(
        [p for p in self.model.parameters() if p.requires_grad],
        lr=self.config["learning_rate"],
    )
    best_f1 = -1
    not_better_count = 0

    eval_after_n_epochs = self.config["eval_after_n_epochs"]

    for e in range(self.config["n_epochs"]):
        shuffle(train_dataset)

        total_loss = 0
        for dc, batch in enumerate(train_dataset):  # each document is a minibatch
            self.model.train()
            optimizer.zero_grad()

            # convert data items to pytorch inputs
            token_ids = [
                m["context"][0] + m["context"][1]
                if len(m["context"][0]) + len(m["context"][1]) > 0
                else [self.embeddings["word_voca"].unk_id]
                for m in batch
            ]
            s_ltoken_ids = [m["snd_ctx"][0] for m in batch]
            s_rtoken_ids = [m["snd_ctx"][1] for m in batch]
            s_mtoken_ids = [m["snd_ment"] for m in batch]

            entity_ids = Variable(
                torch.LongTensor(
                    [m["selected_cands"]["cands"] for m in batch]
                ).to(self.device)
            )
            true_pos = Variable(
                torch.LongTensor(
                    [m["selected_cands"]["true_pos"] for m in batch]
                ).to(self.device)
            )
            p_e_m = Variable(
                torch.FloatTensor(
                    [m["selected_cands"]["p_e_m"] for m in batch]
                ).to(self.device)
            )
            entity_mask = Variable(
                torch.FloatTensor(
                    [m["selected_cands"]["mask"] for m in batch]
                ).to(self.device)
            )

            # Pad all sequences in the batch to equal length and build masks.
            token_ids, token_mask = utils.make_equal_len(
                token_ids, self.embeddings["word_voca"].unk_id
            )
            s_ltoken_ids, s_ltoken_mask = utils.make_equal_len(
                s_ltoken_ids, self.embeddings["snd_voca"].unk_id, to_right=False
            )
            s_rtoken_ids, s_rtoken_mask = utils.make_equal_len(
                s_rtoken_ids, self.embeddings["snd_voca"].unk_id
            )
            s_rtoken_ids = [l[::-1] for l in s_rtoken_ids]
            s_rtoken_mask = [l[::-1] for l in s_rtoken_mask]
            s_mtoken_ids, s_mtoken_mask = utils.make_equal_len(
                s_mtoken_ids, self.embeddings["snd_voca"].unk_id
            )

            token_ids = Variable(torch.LongTensor(token_ids).to(self.device))
            token_mask = Variable(torch.FloatTensor(token_mask).to(self.device))

            # Workaround (as in __predict): forward() reads these tensors
            # from attributes on the model rather than from its arguments.
            self.model.s_ltoken_ids = Variable(
                torch.LongTensor(s_ltoken_ids).to(self.device)
            )
            self.model.s_ltoken_mask = Variable(
                torch.FloatTensor(s_ltoken_mask).to(self.device)
            )
            self.model.s_rtoken_ids = Variable(
                torch.LongTensor(s_rtoken_ids).to(self.device)
            )
            self.model.s_rtoken_mask = Variable(
                torch.FloatTensor(s_rtoken_mask).to(self.device)
            )
            self.model.s_mtoken_ids = Variable(
                torch.LongTensor(s_mtoken_ids).to(self.device)
            )
            self.model.s_mtoken_mask = Variable(
                torch.FloatTensor(s_mtoken_mask).to(self.device)
            )

            scores, ent_scores = self.model.forward(
                token_ids,
                token_mask,
                entity_ids,
                entity_mask,
                p_e_m,
                self.embeddings,
                gold=true_pos.view(-1, 1),
            )
            loss = self.model.loss(scores, true_pos)
            # loss = self.model.prob_loss(scores, true_pos)
            loss.backward()
            optimizer.step()
            self.model.regularize(max_norm=100)

            loss = loss.cpu().data.numpy()
            total_loss += loss
            print(
                "epoch",
                e,
                "%0.2f%%" % (dc / len(train_dataset) * 100),
                loss,
                end="\r",
            )

        print(
            "epoch", e, "total loss", total_loss, total_loss / len(train_dataset)
        )

        if (e + 1) % eval_after_n_epochs == 0:
            dev_f1 = 0
            for dname, data in dev_datasets:
                predictions = self.__predict(data)
                f1, recall, precision, _ = self.__eval(
                    org_dev_datasets[dname], predictions
                )

                print(
                    dname,
                    utils.tokgreen(
                        "Micro F1: {}, Recall: {}, Precision: {}".format(
                            f1, recall, precision
                        )
                    ),
                )

                if dname == "aida_testA":
                    dev_f1 = f1

            # Once dev F1 passes the threshold, lower the learning rate and
            # evaluate more frequently.
            if (
                self.config["learning_rate"] == 1e-4
                and dev_f1 >= self.config["dev_f1_change_lr"]
            ):
                eval_after_n_epochs = 2
                best_f1 = dev_f1
                not_better_count = 0

                self.config["learning_rate"] = 1e-5
                print("change learning rate to", self.config["learning_rate"])
                for param_group in optimizer.param_groups:
                    param_group["lr"] = self.config["learning_rate"]

            # Early stopping: save on improvement, stop after n_not_inc
            # evaluations without improvement.
            if dev_f1 < best_f1:
                not_better_count += 1
                print("Not improving", not_better_count)
            else:
                not_better_count = 0
                best_f1 = dev_f1
                print("save model to", self.config["model_path"])
                self.__save(self.config["model_path"])

            if not_better_count == self.config["n_not_inc"]:
                break