def entity_linking_plain(text):
    jpype.attachThreadToJVM()
    processed_text = make_text_into_conll(text)

    print('load conll at', datadir)
    conll = D.TestDataset(testdir, person_path)
    dev_datasets = [('tta', conll.tta)]

    if args.mode == 'test':
        org_dev_datasets = dev_datasets  # + [('aida-train', conll.train)]
        dev_datasets = []
        for dname, data in org_dev_datasets:
            dev_datasets.append((dname, ranker.get_data_items(data, predict=True)))
            print(dname, '#dev docs', len(dev_datasets[-1][1]))

        vecs = ranker.model.rel_embs.cpu().data.numpy()

        for di, (dname, data) in enumerate(dev_datasets):
            if di == 1:
                break
            ranker.model._coh_ctx_vecs = []
            predictions = ranker.predict(data)
            print(dname, utils.tokgreen(
                'micro F1: ' + str(D.eval_for_api(org_dev_datasets[di][1], predictions))))

    with open("test_result_marking.txt", encoding="UTF8") as result_file:
        merged = merge_item(processed_text, result_file)
    return json.dumps(postprocess(merged), indent=4, sort_keys=True, ensure_ascii=False)
def print_attention(self, gold_pos):
    token_ids = self._token_ids.data.cpu().numpy()
    entity_ids = self._entity_ids.data.cpu().numpy()
    att_probs = self._att_probs.data.cpu().numpy()
    top_tok_att_ids = self._top_tok_att_ids.data.cpu().numpy()
    gold_pos = gold_pos.data.cpu().numpy()
    scores = self._scores.data.cpu().numpy()

    print('===========================================')
    for tids, eids, ap, aids, gpos, ss in zip(token_ids, entity_ids, att_probs,
                                              top_tok_att_ids, gold_pos, scores):
        selected_tids = tids[aids]
        print('-------------------------------')
        print(utils.tokgreen(repr([(self.entity_voca.id2word[e], s)
                                   for e, s in zip(eids, ss)])),
              utils.tokblue(repr(self.entity_voca.id2word[eids[gpos]]
                                 if gpos > -1 else 'UNKNOWN')))
        print([(self.word_voca.id2word[t], a[0])
               for t, a in zip(selected_tids, ap)])
dev_datasets = [('aida-A', conll.testA), ('aida-B', conll.testB),
                ('msnbc', conll.msnbc), ('aquaint', conll.aquaint),
                ('ace2004', conll.ace2004), ('clueweb', conll.clueweb),
                ('wikipedia', conll.wikipedia)]

if args.mode == 'train':
    print('training...')
    config = {'lr': args.learning_rate, 'n_epochs': args.n_epochs}
    pprint(config)
    ranker.train(conll.train, dev_datasets, config)

elif args.mode == 'eval':
    org_dev_datasets = dev_datasets  # + [('aida-train', conll.train)]
    dev_datasets = []
    for dname, data in org_dev_datasets:
        dev_datasets.append((dname, ranker.get_data_items(data, predict=True)))
        print(dname, '#dev docs', len(dev_datasets[-1][1]))

    vecs = ranker.model.rel_embs.cpu().data.numpy()

    for di, (dname, data) in enumerate(dev_datasets):
        ranker.model._coh_ctx_vecs = []
        # predict each dataset one by one
        predictions = ranker.predict(data)
        print(dname, utils.tokgreen(
            'micro F1: ' + str(D.eval(org_dev_datasets[di][1], predictions))))
def predict(self, data, n_best=1, with_oracle=False, inference=None):
    """
    n_best = 1: normal prediction
    n_best = -1: print out all candidates sorted by their scores
    n_best > 1: using oracle
    """
    predictions = {items[0]['doc_name']: [] for items in data}
    self.model.eval()

    for batch in data:  # each document is a minibatch
        if self.args.multi_instance:
            inputs = self.minibatch2input(batch, predict=True, topk=n_best)
        else:
            inputs = self.minibatch2input(batch, predict=True)

        lctx_ids = inputs['s_ltoken_ids']
        rctx_ids = inputs['s_rtoken_ids']
        m_ids = inputs['s_mtoken_ids']

        scores = self.model.forward(inputs, gold=inputs['true_pos'].view(-1, 1),
                                    inference=inference)
        scores = scores.cpu().data.numpy()

        # print out relation weights
        if self.args.mode == 'eval' and self.args.print_rel:
            lctx_ids = [m['snd_ctx'][0] for m in batch]
            rctx_ids = [m['snd_ctx'][1] for m in batch]
            m_ids = [m['snd_ment'] for m in batch]

            print('================================')
            weights = self.model._rel_ctx_ctx_weights.cpu().data.numpy()
            voca = self.model.snd_word_voca
            for i in range(len(batch)):
                print(' '.join([voca.id2word[id] for id in lctx_ids[i]]),
                      utils.tokgreen(' '.join([voca.id2word[id] for id in m_ids[i]])),
                      ' '.join([voca.id2word[id] for id in rctx_ids[i]]))
                for j in range(len(batch)):
                    if i == j:
                        continue
                    np.set_printoptions(precision=2)
                    print('\t', weights[:, i, j], '\t',
                          ' '.join([voca.id2word[id] for id in lctx_ids[j]]),
                          utils.tokgreen(' '.join([voca.id2word[id] for id in m_ids[j]])),
                          ' '.join([voca.id2word[id] for id in rctx_ids[j]]))

        doc_names = [m['doc_name'] for m in batch]

        if n_best >= 1:
            if n_best == 1:
                pred_ids = np.argmax(scores, axis=1)
            else:
                pred_scores, pred_ids = torch.topk(torch.Tensor(scores),
                                                   k=min(n_best, scores.shape[1]))
                pred_scores = pred_scores.numpy()
                pred_ids = pred_ids.numpy()
                true_pos = inputs['true_pos'].cpu().data.numpy()

                if with_oracle:
                    pred_ids = [t if t in ids else ids[0]
                                for (ids, t) in zip(pred_ids, true_pos)]
                else:
                    best_ids = []
                    p_e_m = inputs['p_e_m'].cpu().data.numpy()
                    p_e_ent_net = inputs['p_e_ent_net'].cpu().data.numpy()

                    for i, m in enumerate(batch):
                        names = [m['selected_cands']['named_cands'][j]
                                 if m['selected_cands']['mask'][j] == 1 else 'UNKNOWN'
                                 for j in pred_ids[i, :]]
                        mention = m['raw']['mention']
                        strsim = [Levenshtein.ratio(n.replace('_', ' ').lower(),
                                                    mention.lower())
                                  for n in names]
                        final_score = pred_scores[i] + np.array(strsim) * alpha \
                                      + p_e_m[i][pred_ids[i]] * beta \
                                      + p_e_ent_net[i][pred_ids[i]] * gamma
                        # final_score = np.array(strsim) * alpha + p_e_m[i][pred_ids[i]] * beta + p_e_ent_net[i][pred_ids[i]] * gamma
                        best_ids.append(pred_ids[i, np.argmax(final_score)])

                        if self.args.print_incorrect:
                            if true_pos[i] in pred_ids[i] and true_pos[i] != best_ids[-1]:
                                print('-----------------------------------------------------')
                                pprint(m['raw'])
                                print('true pos', true_pos[i])
                                names = [m['selected_cands']['named_cands'][j]
                                         if m['selected_cands']['mask'][j] == 1 else 'UNKNOWN'
                                         for j in pred_ids[i, :]]
                                pprint(list(zip(pred_ids[i], pred_scores[i], names)))
                                print(strsim, p_e_m[i][pred_ids[i]], pred_scores[i])
                                print(final_score)

                    pred_ids = best_ids

            pred_entities = [m['selected_cands']['named_cands'][i]
                             if m['selected_cands']['mask'][i] == 1
                             else (m['selected_cands']['named_cands'][0]
                                   if m['selected_cands']['mask'][0] == 1 else 'NIL')
                             for (i, m) in zip(pred_ids, batch)]

            if self.args.mode == 'eval' and self.args.print_incorrect:
                gold = [item['selected_cands']['named_cands'][item['selected_cands']['true_pos']]
                        if item['selected_cands']['true_pos'] >= 0 else 'UNKNOWN'
                        for item in batch]
                pred = pred_entities
                for i in range(len(gold)):
                    if gold[i] != pred[i]:
                        print('--------------------------------------------')
                        pprint(batch[i]['raw'])
                        print(gold[i], pred[i])

            for dname, entity in zip(doc_names, pred_entities):
                predictions[dname].append({'pred': (entity, 0.)})

        elif n_best == -1:
            for i, dname in enumerate(doc_names):
                m = batch[i]
                pred = [(m['selected_cands']['named_cands'][j], scores[i, j])
                        for j in range(len(m['selected_cands']['named_cands']))
                        if m['selected_cands']['mask'][j] == 1]
                if len(pred) == 0:
                    pred = [('NIL', 1)]
                predictions[dname].append(sorted(pred, key=lambda x: x[1])[::-1])

    return predictions
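# Hedged sketch (not part of the original code): a standalone illustration of the
# re-ranking used above when n_best > 1, where the model score is mixed with a
# Levenshtein string-similarity term and the p_e_m / p_e_ent_net priors.
# The weights alpha, beta, gamma and the toy numbers below are assumed values,
# not the ones used by the authors.
import numpy as np
import Levenshtein


def rerank_candidates(mention, names, model_scores, p_e_m, p_e_ent_net,
                      alpha=0.2, beta=0.2, gamma=0.05):
    """Return the index of the best candidate after mixing the scores."""
    strsim = np.array([Levenshtein.ratio(n.replace('_', ' ').lower(), mention.lower())
                       for n in names])
    final_score = (np.array(model_scores) + strsim * alpha
                   + np.array(p_e_m) * beta + np.array(p_e_ent_net) * gamma)
    return int(np.argmax(final_score))


# toy example: two candidates for the mention "Paris"
best = rerank_candidates('Paris',
                         ['Paris', 'Paris_Hilton'],
                         model_scores=[1.2, 1.1],
                         p_e_m=[0.8, 0.1],
                         p_e_ent_net=[0.5, 0.2])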
def train(self, org_train_dataset, org_dev_datasets, config,
          preranked_train=None, preranked_dev=None):
    print('extracting training data')
    if preranked_train is None:
        train_dataset = self.get_data_items(org_train_dataset, predict=False)
    else:
        train_dataset = preranked_train
    print('#train docs', len(train_dataset))

    if preranked_dev is None:
        dev_datasets = []
        for dname, data in org_dev_datasets:
            dev_datasets.append((dname, self.get_data_items(data, predict=True)))
            print(dname, '#dev docs', len(dev_datasets[-1][1]))
    else:
        dev_datasets = preranked_dev

    print('creating optimizer')
    optimizer = optim.Adam([p for p in self.model.parameters() if p.requires_grad],
                           lr=config['lr'])

    best_f1 = -1
    not_better_count = 0
    is_counting = False
    stop = False
    eval_after_n_epochs = self.args.eval_after_n_epochs
    final_result_str = ''

    print('total training items', len(train_dataset))
    n_updates = 0

    if config['multi_instance']:
        n_updates_to_eval = 1000
        n_updates_to_stop = 60000
        f1_threshold = 0.875
        f1_start_counting = 0.87
    elif config['semisup']:
        n_updates_to_eval = 5000
        n_updates_to_stop = 1e10
        f1_threshold = 0.86
        f1_start_counting = 0.86
    else:  # for supervised learning
        n_updates_to_eval = 1000
        n_updates_to_stop = 1000 * self.args.n_epochs
        f1_threshold = 0.95
        f1_start_counting = 0.95

    for e in range(config['n_epochs']):
        shuffle(train_dataset)

        total_loss = 0
        total = 0
        for dc, batch in enumerate(train_dataset):  # each document is a minibatch
            self.model.train()
            optimizer.zero_grad()

            tps = [m['selected_cands']['true_pos'] >= 0 for m in batch]
            any_true = np.any(tps)
            if any_true:
                inputs = self.minibatch2input(batch)
            else:
                inputs = self.minibatch2input(batch, topk=2)

            if config['semisup']:
                if any_true:  # from supervision (i.e. CoNLL)
                    scores = self.model.forward(inputs,
                                                gold=inputs['true_pos'].view(-1, 1),
                                                inference='LBP')
                else:
                    scores = self.model.forward(inputs,
                                                gold=inputs['true_pos'].view(-1, 1),
                                                inference='star')
            else:
                scores = self.model.forward(inputs, gold=inputs['true_pos'].view(-1, 1))

            if any_true:
                loss = self.model.loss(scores, inputs['true_pos'])
            else:
                loss = self.model.multi_instance_loss(scores, inputs)

            loss.backward()
            optimizer.step()

            loss = loss.cpu().data.item()
            total_loss += loss
            if dc % 100 == 0:
                print('epoch', e, "%0.2f%%" % (dc / len(train_dataset) * 100), loss,
                      end='\r')

            n_updates += 1
            if n_updates % n_updates_to_eval == 0:
                # only continue if the best f1 is larger than f1_threshold
                if n_updates >= n_updates_to_stop and best_f1 < f1_threshold:
                    stop = True
                    print('this initialization is not good. Run another one... STOP')
                    break

                print('\n--------------------')
                dev_f1 = 0
                results = ''
                for di, (dname, data) in enumerate(dev_datasets):
                    # if dname == 'aida-A':
                    # dname != '':
                    if dname != '':
                        # a = 0.1
                        # b = 1.
                        # c = 0.95
                        # global alpha
                        # global beta
                        # global gamma
                        # alpha = a
                        # beta = b
                        # gamma = c
                        predictions = self.predict(data, n_best=n_best)

                        cats = None
                        # **YD** only ignore .conll for reddit data
                        # if 'cat' in data[0][0]['raw']['conll_m']:
                        if 'reddit' not in dname and 'cat' in data[0][0]['raw']['conll_m']:
                            cats = []
                            for doc in data:
                                cats += [m['raw']['conll_m']['cat'] for m in doc]

                        # **YD** change output of D.eval to include prec, rec and f1
                        """
                        if cats is None:
                            f1 = D.eval(org_dev_datasets[di][1], predictions)
                        else:
                            f1 = D.eval(org_dev_datasets[di][1], predictions, cats)
                        # print(alpha, beta, gamma, dname, utils.tokgreen('micro F1: ' + str(f1)))
                        print(dname, utils.tokgreen('micro F1: ' + str(f1)))
                        results += dname + '\t' + utils.tokgreen('micro F1: ' + str(f1)) + '\n'
                        """
                        if cats is None:
                            f1, out_s = D.eval(org_dev_datasets[di][1], predictions)
                        else:
                            f1, out_s = D.eval(org_dev_datasets[di][1], predictions, cats)
                        print(dname, utils.tokgreen(out_s))
                        results = dname + '\t' + utils.tokgreen(out_s) + '\n'

                        if dname == 'aida-A':
                            dev_f1 = f1
                        continue

                    if config['multi_instance']:
                        predictions = self.predict(data, n_best=n_best)
                    else:  # including semisup
                        predictions = self.predict(data, n_best=1)

                    cats = None
                    if 'cat' in data[0][0]['raw']['conll_m']:
                        cats = []
                        for doc in data:
                            cats += [m['raw']['conll_m']['cat'] for m in doc]

                    if cats is None:
                        f1 = D.eval(org_dev_datasets[di][1], predictions)
                    else:
                        f1, tab = D.eval(org_dev_datasets[di][1], predictions, cats)
                        pprint(tab)
                    print(dname, utils.tokgreen('micro F1: ' + str(f1)))

                if dev_f1 >= best_f1 and dev_f1 >= f1_start_counting:
                    # 0.82 (for weak supervised learning alone)
                    is_counting = True
                    not_better_count = 0

                if is_counting:
                    if dev_f1 < best_f1:
                        not_better_count += 1
                        print('not dev f1 inc after', not_better_count)
                    else:
                        final_result_str = results
                        not_better_count = 0
                        best_f1 = dev_f1
                        if self.args.model_path is not None:
                            print('save model to', self.args.model_path)
                            self.model.save(self.args.model_path)

                if not_better_count == self.args.n_not_inc:
                    print('dev f1 not inc after', not_better_count, '... STOP')
                    stop = True
                    break

        print('epoch', e, 'total loss', total_loss, total_loss / len(train_dataset))

        if stop:
            print('**********************************************')
            print('best results (f1 on aida-A):')
            print(final_result_str)
            break
def predict(self, data):
    predictions = {items[0]['doc_name']: [] for items in data}
    self.model.eval()

    for batch in data:  # each document is a minibatch
        token_ids = [m['context'][0] + m['context'][1]
                     if len(m['context'][0]) + len(m['context'][1]) > 0
                     else [self.model.word_voca.unk_id]
                     for m in batch]
        s_ltoken_ids = [m['snd_ctx'][0] for m in batch]
        s_rtoken_ids = [m['snd_ctx'][1] for m in batch]
        s_mtoken_ids = [m['snd_ment'] for m in batch]

        lctx_ids = s_ltoken_ids
        rctx_ids = s_rtoken_ids
        m_ids = s_mtoken_ids

        entity_ids = Variable(torch.LongTensor(
            [m['selected_cands']['cands'] for m in batch]).cuda())
        p_e_m = Variable(torch.FloatTensor(
            [m['selected_cands']['p_e_m'] for m in batch]).cuda())
        entity_mask = Variable(torch.FloatTensor(
            [m['selected_cands']['mask'] for m in batch]).cuda())
        true_pos = Variable(torch.LongTensor(
            [m['selected_cands']['true_pos'] for m in batch]).cuda())

        token_ids, token_mask = utils.make_equal_len(token_ids, self.model.word_voca.unk_id)
        s_ltoken_ids, s_ltoken_mask = utils.make_equal_len(
            s_ltoken_ids, self.model.snd_word_voca.unk_id, to_right=False)
        s_rtoken_ids, s_rtoken_mask = utils.make_equal_len(
            s_rtoken_ids, self.model.snd_word_voca.unk_id)
        s_rtoken_ids = [l[::-1] for l in s_rtoken_ids]
        s_rtoken_mask = [l[::-1] for l in s_rtoken_mask]
        s_mtoken_ids, s_mtoken_mask = utils.make_equal_len(
            s_mtoken_ids, self.model.snd_word_voca.unk_id)

        token_ids = Variable(torch.LongTensor(token_ids).cuda())
        token_mask = Variable(torch.FloatTensor(token_mask).cuda())

        # too ugly, but too lazy to fix it
        self.model.s_ltoken_ids = Variable(torch.LongTensor(s_ltoken_ids).cuda())
        self.model.s_ltoken_mask = Variable(torch.FloatTensor(s_ltoken_mask).cuda())
        self.model.s_rtoken_ids = Variable(torch.LongTensor(s_rtoken_ids).cuda())
        self.model.s_rtoken_mask = Variable(torch.FloatTensor(s_rtoken_mask).cuda())
        self.model.s_mtoken_ids = Variable(torch.LongTensor(s_mtoken_ids).cuda())
        self.model.s_mtoken_mask = Variable(torch.FloatTensor(s_mtoken_mask).cuda())

        scores = self.model.forward(token_ids, token_mask, entity_ids, entity_mask,
                                    p_e_m, gold=true_pos.view(-1, 1))
        scores = scores.cpu().data.numpy()

        # print out relation weights
        if self.args.mode == 'eval' and self.args.print_rel:
            print('================================')
            weights = self.model._rel_ctx_ctx_weights.cpu().data.numpy()
            voca = self.model.snd_word_voca
            for i in range(len(batch)):
                print(' '.join([voca.id2word[id] for id in lctx_ids[i]]),
                      utils.tokgreen(' '.join([voca.id2word[id] for id in m_ids[i]])),
                      ' '.join([voca.id2word[id] for id in rctx_ids[i]]))
                for j in range(len(batch)):
                    if i == j:
                        continue
                    np.set_printoptions(precision=2)
                    print('\t', weights[:, i, j], '\t',
                          ' '.join([voca.id2word[id] for id in lctx_ids[j]]),
                          utils.tokgreen(' '.join([voca.id2word[id] for id in m_ids[j]])),
                          ' '.join([voca.id2word[id] for id in rctx_ids[j]]))

        pred_ids = np.argmax(scores, axis=1)
        pred_entities = [m['selected_cands']['named_cands'][i]
                         if m['selected_cands']['mask'][i] == 1
                         else (m['selected_cands']['named_cands'][0]
                               if m['selected_cands']['mask'][0] == 1 else 'NIL')
                         for (i, m) in zip(pred_ids, batch)]
        # print(pred_entities)
        doc_names = [m['doc_name'] for m in batch]

        if self.args.mode == 'eval' and self.args.print_incorrect:
            gold = [item['selected_cands']['named_cands'][item['selected_cands']['true_pos']]
                    if item['selected_cands']['true_pos'] >= 0 else 'UNKNOWN'
                    for item in batch]
            pred = pred_entities
            for i in range(len(gold)):
                if gold[i] != pred[i]:
                    print('--------------------------------------------')
                    pprint(batch[i]['raw'])
                    print(gold[i], pred[i])
                    print(pred_ids[i], scores[i])

        for dname, entity in zip(doc_names, pred_entities):
            predictions[dname].append({'pred': (entity, 0.)})

    return predictions
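# Hedged usage sketch (assumed caller code, not from the original repo): how the dict
# returned by predict() might be consumed. It maps each doc_name to a list of
# {'pred': (entity, score)} items in mention order; 'ranker' and 'dev_data' are
# assumed to exist in the calling script.
predictions = ranker.predict(dev_data)
for doc_name, mentions in predictions.items():
    linked = [entity for entity, _ in (m['pred'] for m in mentions)]
    print(doc_name, linked)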
def train(self, org_train_dataset, org_dev_datasets, config):
    print('extracting training data')
    train_dataset = self.get_data_items(org_train_dataset, predict=False)
    print('#train docs', len(train_dataset))

    dev_datasets = []
    for dname, data in org_dev_datasets:
        dev_datasets.append((dname, self.get_data_items(data, predict=True)))
        print(dname, '#dev docs', len(dev_datasets[-1][1]))

    print('creating optimizer')
    optimizer = optim.Adam([p for p in self.model.parameters() if p.requires_grad],
                           lr=config['lr'])
    best_f1 = -1
    not_better_count = 0
    is_counting = False
    eval_after_n_epochs = self.args.eval_after_n_epochs

    for e in range(config['n_epochs']):
        shuffle(train_dataset)

        total_loss = 0
        for dc, batch in enumerate(train_dataset):  # each document is a minibatch
            self.model.train()
            optimizer.zero_grad()

            # convert data items to pytorch inputs
            token_ids = [m['context'][0] + m['context'][1]
                         if len(m['context'][0]) + len(m['context'][1]) > 0
                         else [self.model.word_voca.unk_id]
                         for m in batch]
            s_ltoken_ids = [m['snd_ctx'][0] for m in batch]
            s_rtoken_ids = [m['snd_ctx'][1] for m in batch]
            s_mtoken_ids = [m['snd_ment'] for m in batch]

            entity_ids = Variable(torch.LongTensor(
                [m['selected_cands']['cands'] for m in batch]).cuda())
            true_pos = Variable(torch.LongTensor(
                [m['selected_cands']['true_pos'] for m in batch]).cuda())
            p_e_m = Variable(torch.FloatTensor(
                [m['selected_cands']['p_e_m'] for m in batch]).cuda())
            entity_mask = Variable(torch.FloatTensor(
                [m['selected_cands']['mask'] for m in batch]).cuda())

            token_ids, token_mask = utils.make_equal_len(token_ids, self.model.word_voca.unk_id)
            s_ltoken_ids, s_ltoken_mask = utils.make_equal_len(
                s_ltoken_ids, self.model.snd_word_voca.unk_id, to_right=False)
            s_rtoken_ids, s_rtoken_mask = utils.make_equal_len(
                s_rtoken_ids, self.model.snd_word_voca.unk_id)
            s_rtoken_ids = [l[::-1] for l in s_rtoken_ids]
            s_rtoken_mask = [l[::-1] for l in s_rtoken_mask]
            s_mtoken_ids, s_mtoken_mask = utils.make_equal_len(
                s_mtoken_ids, self.model.snd_word_voca.unk_id)

            token_ids = Variable(torch.LongTensor(token_ids).cuda())
            token_mask = Variable(torch.FloatTensor(token_mask).cuda())

            # too ugly but too lazy to fix it
            self.model.s_ltoken_ids = Variable(torch.LongTensor(s_ltoken_ids).cuda())
            self.model.s_ltoken_mask = Variable(torch.FloatTensor(s_ltoken_mask).cuda())
            self.model.s_rtoken_ids = Variable(torch.LongTensor(s_rtoken_ids).cuda())
            self.model.s_rtoken_mask = Variable(torch.FloatTensor(s_rtoken_mask).cuda())
            self.model.s_mtoken_ids = Variable(torch.LongTensor(s_mtoken_ids).cuda())
            self.model.s_mtoken_mask = Variable(torch.FloatTensor(s_mtoken_mask).cuda())

            scores = self.model.forward(token_ids, token_mask, entity_ids, entity_mask,
                                        p_e_m, gold=true_pos.view(-1, 1))
            loss = self.model.loss(scores, true_pos)

            loss.backward()
            optimizer.step()
            # self.model.regularize(max_norm=100)

            loss = loss.cpu().data.numpy()
            total_loss += loss
            print('epoch', e, "%0.2f%%" % (dc / len(train_dataset) * 100), loss,
                  end='\r')

        print('epoch', e, 'total loss', total_loss, total_loss / len(train_dataset))

        if (e + 1) % eval_after_n_epochs == 0:
            dev_f1 = 0
            for di, (dname, data) in enumerate(dev_datasets):
                predictions = self.predict(data)
                f1 = D.eval(org_dev_datasets[di][1], predictions)
                print(dname, utils.tokgreen('micro F1: ' + str(f1)))

                if dname == 'aida-A':
                    dev_f1 = f1

            if config['lr'] == 1e-4 and dev_f1 >= self.args.dev_f1_change_lr:
                eval_after_n_epochs = 2
                is_counting = True
                best_f1 = dev_f1
                not_better_count = 0

                config['lr'] = 1e-5
                print('change learning rate to', config['lr'])
                if self.args.mulrel_type == 'rel-norm':
                    optimizer = optim.Adam(
                        [p for p in self.model.parameters() if p.requires_grad],
                        lr=config['lr'])
                elif self.args.mulrel_type == 'ment-norm':
                    for param_group in optimizer.param_groups:
                        param_group['lr'] = config['lr']

            if is_counting:
                if dev_f1 < best_f1:
                    not_better_count += 1
                else:
                    not_better_count = 0
                    best_f1 = dev_f1
                    print('save model to', self.args.model_path)
                    self.model.save(self.args.model_path)

            if not_better_count == self.args.n_not_inc:
                break

    self.model.print_weight_norm()
    preranked_train = preranked_train[:min(args.n_docs, len(preranked_train))]

    org_dev_datasets = [(all_datasets[i][0], all_datasets[i][1])
                        for i in range(1, len(all_datasets))]
    dev_datasets = [(all_datasets[i][0], all_datasets[i][2])
                    for i in range(1, len(all_datasets))]
    dev_datasets = [(all_datasets[i][0], all_datasets[i][1], all_datasets[i][2])
                    for i in range(1, len(all_datasets))]

    for di, (dname, data, preranked) in enumerate(dev_datasets):
        ranker.model._coh_ctx_vecs = []
        predictions = ranker.predict(preranked)
        print(dname, utils.tokgreen('micro F1: ' + str(D.eval(data, predictions))))

elif args.mode == 'ed':
    with open(args.filelist, 'r') as flist:
        for fname in flist:
            fname = fname.strip()
            print('load file from', fname)
            conll_path = fname
            cands_path = conll_path + '.csv'
            data = D.CoNLLDataset.load_file(conll_path, cands_path, person_path)

            data = ranker.get_data_items(data, predict=True)
            print('#docs', len(data))
            continue
    ('reddit2020silver', conll.reddit2020silver),
    ('reddit2020g_s', conll.reddit2020g_s),
]

if args.mode == 'train':
    print('training...')
    config = {'lr': args.learning_rate, 'n_epochs': args.n_epochs}
    pprint(config)
    ranker.train(conll.train, dev_datasets, config)

elif args.mode == 'eval':
    org_dev_datasets = dev_datasets  # + [('aida-train', conll.train)]
    dev_datasets = []
    for dname, data in org_dev_datasets:
        dev_datasets.append((dname, ranker.get_data_items(data, predict=True)))
        print(dname, '#dev docs', len(dev_datasets[-1][1]))

    vecs = ranker.model.rel_embs.cpu().data.numpy()

    for di, (dname, data) in enumerate(dev_datasets):
        ranker.model._coh_ctx_vecs = []
        predictions = ranker.predict(data)

        # **YD** change output of D.eval to include prec, rec and f1
        """
        print(dname, utils.tokgreen(
            'micro F1: ' + str(D.eval(org_dev_datasets[di][1], predictions))))
        """
        f1, out_s = D.eval(data, predictions)
        print(dname, utils.tokgreen(out_s))
def predict(self, data):
    predictions = {items[0]['doc_name']: [] for items in data}
    self.model.eval()

    for batch in data:  # each document is a minibatch
        token_ids = [m['context'][0] + m['context'][1]
                     if len(m['context'][0]) + len(m['context'][1]) > 0
                     else [self.model.word_voca.unk_id]
                     for m in batch]
        s_ltoken_ids = [m['snd_ctx'][0] for m in batch]
        s_rtoken_ids = [m['snd_ctx'][1] for m in batch]
        s_mtoken_ids = [m['snd_ment'] for m in batch]

        lctx_ids = s_ltoken_ids
        rctx_ids = s_rtoken_ids
        m_ids = s_mtoken_ids

        entity_ids = Variable(torch.LongTensor(
            [m['selected_cands']['cands'] for m in batch]).to(device))
        p_e_m = Variable(torch.FloatTensor(
            [m['selected_cands']['p_e_m'] for m in batch]).to(device))
        entity_mask = Variable(torch.FloatTensor(
            [m['selected_cands']['mask'] for m in batch]).to(device))
        true_pos = Variable(torch.LongTensor(
            [m['selected_cands']['true_pos'] for m in batch]).to(device))

        token_ids, token_mask = utils.make_equal_len(token_ids, self.model.word_voca.unk_id)
        s_ltoken_ids, s_ltoken_mask = utils.make_equal_len(
            s_ltoken_ids, self.model.snd_word_voca.unk_id, to_right=False)
        s_rtoken_ids, s_rtoken_mask = utils.make_equal_len(
            s_rtoken_ids, self.model.snd_word_voca.unk_id)
        s_rtoken_ids = [l[::-1] for l in s_rtoken_ids]
        s_rtoken_mask = [l[::-1] for l in s_rtoken_mask]
        s_mtoken_ids, s_mtoken_mask = utils.make_equal_len(
            s_mtoken_ids, self.model.snd_word_voca.unk_id)

        token_ids = Variable(torch.LongTensor(token_ids).to(device))
        token_mask = Variable(torch.FloatTensor(token_mask).to(device))

        # too ugly, but too lazy to fix it
        self.model.s_ltoken_ids = Variable(torch.LongTensor(s_ltoken_ids).to(device))
        self.model.s_ltoken_mask = Variable(torch.FloatTensor(s_ltoken_mask).to(device))
        self.model.s_rtoken_ids = Variable(torch.LongTensor(s_rtoken_ids).to(device))
        self.model.s_rtoken_mask = Variable(torch.FloatTensor(s_rtoken_mask).to(device))
        self.model.s_mtoken_ids = Variable(torch.LongTensor(s_mtoken_ids).to(device))
        self.model.s_mtoken_mask = Variable(torch.FloatTensor(s_mtoken_mask).to(device))

        scores = self.model.forward(token_ids, token_mask, entity_ids, entity_mask,
                                    p_e_m, gold=true_pos.view(-1, 1))
        scores = scores.cpu().data.numpy()

        # print out relation weights
        if self.args.mode == 'eval' and self.args.print_rel:
            print('================================')
            weights = self.model._rel_ctx_ctx_weights.cpu().data.numpy()
            voca = self.model.snd_word_voca
            for i in range(len(batch)):
                print(' '.join([voca.id2word[id] for id in lctx_ids[i]]),
                      utils.tokgreen(' '.join([voca.id2word[id] for id in m_ids[i]])),
                      ' '.join([voca.id2word[id] for id in rctx_ids[i]]))
                for j in range(len(batch)):
                    if i == j:
                        continue
                    np.set_printoptions(precision=2)
                    print('\t', weights[:, i, j], '\t',
                          ' '.join([voca.id2word[id] for id in lctx_ids[j]]),
                          utils.tokgreen(' '.join([voca.id2word[id] for id in m_ids[j]])),
                          ' '.join([voca.id2word[id] for id in rctx_ids[j]]))

        pred_ids = np.argmax(scores, axis=1)

        # rescale each row of scores (masking padded candidates, whose score equals the
        # row minimum) so that a row-wise softmax yields a usable confidence estimate
        processed_scores = []
        for i, score in enumerate(scores):
            min_score = np.min(score)
            if all(list(map(lambda x: x == min_score, score))):
                processed_scores.append(scores[i].tolist())
                continue
            count = 0
            for j, s in enumerate(score):
                if s == min_score:
                    scores[i][j] = np.NINF
                    count += 1
            non_inf_list = list(filter(lambda x: x != float('-inf'), scores[i]))
            mean = np.mean(non_inf_list)
            processed_scores.append(((scores[i] - mean) * len(non_inf_list) * 10).tolist())

        e_x = np.exp(processed_scores - np.amax(processed_scores, axis=1, keepdims=True))
        softmax_scores = e_x / e_x.sum(axis=1, keepdims=True)
        pred_scores = np.amax(np.array(processed_scores), axis=1)
        pred_confidences = np.amax(softmax_scores, axis=1)

        pred_entities = [m['selected_cands']['named_cands'][i]
                         if m['selected_cands']['mask'][i] == 1
                         else (m['selected_cands']['named_cands'][0]
                               if m['selected_cands']['mask'][0] == 1 else 'NIL')
                         for (i, m) in zip(pred_ids, batch)]
        doc_names = [m['doc_name'] for m in batch]

        if self.args.mode == 'eval' and self.args.print_incorrect:
            gold = [item['selected_cands']['named_cands'][item['selected_cands']['true_pos']]
                    if item['selected_cands']['true_pos'] >= 0 else 'UNKNOWN'
                    for item in batch]
            pred = pred_entities
            for i in range(len(gold)):
                if gold[i] != pred[i]:
                    print('--------------------------------------------')
                    pprint(batch[i]['raw'])
                    print(gold[i], pred[i])

        for dname, entity, pred_score, pred_confidence in zip(
                doc_names, pred_entities, pred_scores, pred_confidences):
            predictions[dname].append({
                'pred': (entity, 0.),
                'score': pred_score,
                'confidence': pred_confidence
            })

    return predictions
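# Hedged sketch (illustration only): the confidence reported above is simply a
# row-wise softmax over the rescaled candidate scores, keeping the top probability.
# The toy numbers below are assumed, not taken from any run.
import numpy as np

processed_scores = np.array([[2.0, 0.5, -1.0]])        # one mention, three candidates
e_x = np.exp(processed_scores - processed_scores.max(axis=1, keepdims=True))
softmax_scores = e_x / e_x.sum(axis=1, keepdims=True)
pred_confidence = softmax_scores.max(axis=1)           # ~0.79 for this row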