# Module-level imports assumed by this excerpt (inferred from usage below;
# `utils` and `D` are repo-local helpers whose import paths are not shown here):
# import numpy as np
# import torch
# import torch.optim as optim
# from torch.autograd import Variable
# from random import shuffle
# from pprint import pprint

def minibatch2input(self, batch, predict=False, topk=None):
    if topk is None:
        topk = 10000
    topk = min(topk, len(batch[0]['selected_cands']['cands']))
    n_ments = len(batch)

    # only using negative samples when the document doesn't have any
    # supervision (i.e. not CoNLL)
    tps = [m['selected_cands']['true_pos'] >= 0 for m in batch]
    if not predict and (self.args.multi_instance or self.args.semisup) and not np.any(tps):
        n_negs = self.args.n_negs
    else:
        n_negs = 0

    # convert data items to pytorch inputs
    token_ids = [
        m['context'][0] + m['context'][1]
        if len(m['context'][0]) + len(m['context'][1]) > 0
        else [self.model.word_voca.unk_id]
        for m in batch
    ]
    s_ltoken_ids = [m['snd_ctx'][0] for m in batch]
    s_rtoken_ids = [m['snd_ctx'][1] for m in batch]
    s_mtoken_ids = [m['snd_ment'] for m in batch]

    entity_ids = torch.LongTensor(
        [m['selected_cands']['cands'][:topk] for m in batch])
    p_e_m = torch.FloatTensor(
        [m['selected_cands']['p_e_m'][:topk] for m in batch])
    entity_mask = torch.FloatTensor(
        [m['selected_cands']['mask'][:topk] for m in batch])
    true_pos = torch.LongTensor([
        m['selected_cands']['true_pos']
        if m['selected_cands']['true_pos'] < topk else -1
        for m in batch
    ])
    p_e_ent_net = torch.FloatTensor([
        m['selected_cands']['p_e_ent_net'][:topk] for m in batch
    ]) if len(batch) > 1 else torch.zeros(1, entity_ids.shape[1])

    if n_negs > 0:
        # add n_negs negative samples at the beginning of lists
        def ent_neg_sample(neg_cands_p_e_m, exclusive):
            sample_ids = np.random.choice(len(neg_cands_p_e_m), n_negs * 10)
            all_samples = list(
                zip(
                    np.array([s[0] for s in neg_cands_p_e_m])[sample_ids].astype(int),
                    np.array([s[1] for s in neg_cands_p_e_m])[sample_ids]))
            exclusive = set(exclusive)

            samples = []
            for s in all_samples:
                if s[0] not in exclusive:
                    samples.append(s)
            if len(samples) < n_negs:
                samples = samples + [(self.model.entity_voca.unk_id, 1e-3)] * (n_negs - len(samples))
            else:
                shuffle(samples)
                samples = samples[:n_negs]
            return (np.array([s[0] for s in samples]),
                    np.array([s[1] for s in samples]))

        neg_cands_p_e_m = [
            list(zip(list(m['cands']), list(m['p_e_m']))) +
            (list(zip(list(m['neg_cands']), [1e-3] * len(m['neg_cands'])))
             if len(m['cands']) <= topk else [])
            for m in batch
        ]
        neg_cands_p_e_m = [
            ent_neg_sample(si, entity_ids_i)
            for si, entity_ids_i in zip(neg_cands_p_e_m, entity_ids.numpy())
        ]
        neg_entity_ids = torch.Tensor(
            [si[0].astype(float) for si in neg_cands_p_e_m]).long()
        neg_p_e_m = torch.Tensor(
            [si[1].astype(float) for si in neg_cands_p_e_m])
        neg_entity_mask = torch.ones(n_ments, n_negs)

        entity_ids = torch.cat([neg_entity_ids, entity_ids], dim=1)
        entity_mask = torch.cat([neg_entity_mask, entity_mask], dim=1)
        p_e_m = torch.cat([neg_p_e_m, p_e_m], dim=1)
        true_pos = true_pos.add_(n_negs)

    entity_ids = Variable(entity_ids.cuda())
    true_pos = Variable(true_pos.cuda())
    p_e_m = Variable(p_e_m.cuda())
    p_e_ent_net = Variable(p_e_ent_net.cuda())
    entity_mask = Variable(entity_mask.cuda())

    token_ids, token_mask = utils.make_equal_len(
        token_ids, self.model.word_voca.unk_id)
    s_ltoken_ids, s_ltoken_mask = utils.make_equal_len(
        s_ltoken_ids, self.model.snd_word_voca.unk_id, to_right=False)
    s_rtoken_ids, s_rtoken_mask = utils.make_equal_len(
        s_rtoken_ids, self.model.snd_word_voca.unk_id)
    s_rtoken_ids = [l[::-1] for l in s_rtoken_ids]
    s_rtoken_mask = [l[::-1] for l in s_rtoken_mask]
    s_mtoken_ids, s_mtoken_mask = utils.make_equal_len(
        s_mtoken_ids, self.model.snd_word_voca.unk_id)

    token_ids = Variable(torch.LongTensor(token_ids).cuda())
    token_mask = Variable(torch.FloatTensor(token_mask).cuda())
    s_ltoken_ids = Variable(torch.LongTensor(s_ltoken_ids).cuda())
    s_ltoken_mask = Variable(torch.FloatTensor(s_ltoken_mask).cuda())
    s_rtoken_ids = Variable(torch.LongTensor(s_rtoken_ids).cuda())
    s_rtoken_mask = Variable(torch.FloatTensor(s_rtoken_mask).cuda())
    s_mtoken_ids = Variable(torch.LongTensor(s_mtoken_ids).cuda())
    s_mtoken_mask = Variable(torch.FloatTensor(s_mtoken_mask).cuda())

    ret = {
        'token_ids': token_ids,
        'token_mask': token_mask,
        'entity_ids': entity_ids,
        'entity_mask': entity_mask,
        'p_e_m': p_e_m,
        'p_e_ent_net': p_e_ent_net,
        'true_pos': true_pos,
        's_ltoken_ids': s_ltoken_ids,
        's_ltoken_mask': s_ltoken_mask,
        's_rtoken_ids': s_rtoken_ids,
        's_rtoken_mask': s_rtoken_mask,
        's_mtoken_ids': s_mtoken_ids,
        's_mtoken_mask': s_mtoken_mask,
        'n_negs': n_negs
    }
    return ret
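
# --- Illustrative sketch (not part of the original code above) --------------
# minibatch2input relies on utils.make_equal_len to pad token-id lists to a
# common length and build a float mask. A minimal implementation consistent
# with its usage here might look like this; the real helper may differ:
def _make_equal_len_sketch(lists, fill_id, to_right=True):
    max_len = max(len(l) for l in lists)
    ids, mask = [], []
    for l in lists:
        pad = [fill_id] * (max_len - len(l))
        if to_right:                       # default: pad on the right
            ids.append(l + pad)
            mask.append([1.] * len(l) + [0.] * len(pad))
        else:                              # to_right=False: pad on the left,
            ids.append(pad + l)            # as used for the left context
            mask.append([0.] * len(pad) + [1.] * len(l))
    return ids, mask

# e.g. _make_equal_len_sketch([[5, 6], [7]], fill_id=0)
#      -> ([[5, 6], [7, 0]], [[1., 1.], [1., 0.]])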
def predict(self, data):
    predictions = {items[0]['doc_name']: [] for items in data}
    self.model.eval()

    for batch in data:  # each document is a minibatch
        token_ids = [
            m['context'][0] + m['context'][1]
            if len(m['context'][0]) + len(m['context'][1]) > 0
            else [self.model.word_voca.unk_id]
            for m in batch
        ]
        s_ltoken_ids = [m['snd_ctx'][0] for m in batch]
        s_rtoken_ids = [m['snd_ctx'][1] for m in batch]
        s_mtoken_ids = [m['snd_ment'] for m in batch]

        lctx_ids = s_ltoken_ids
        rctx_ids = s_rtoken_ids
        m_ids = s_mtoken_ids

        entity_ids = Variable(
            torch.LongTensor([m['selected_cands']['cands'] for m in batch]).cuda())
        p_e_m = Variable(
            torch.FloatTensor([m['selected_cands']['p_e_m'] for m in batch]).cuda())
        entity_mask = Variable(
            torch.FloatTensor([m['selected_cands']['mask'] for m in batch]).cuda())
        true_pos = Variable(
            torch.LongTensor([m['selected_cands']['true_pos'] for m in batch]).cuda())

        token_ids, token_mask = utils.make_equal_len(
            token_ids, self.model.word_voca.unk_id)
        s_ltoken_ids, s_ltoken_mask = utils.make_equal_len(
            s_ltoken_ids, self.model.snd_word_voca.unk_id, to_right=False)
        s_rtoken_ids, s_rtoken_mask = utils.make_equal_len(
            s_rtoken_ids, self.model.snd_word_voca.unk_id)
        s_rtoken_ids = [l[::-1] for l in s_rtoken_ids]
        s_rtoken_mask = [l[::-1] for l in s_rtoken_mask]
        s_mtoken_ids, s_mtoken_mask = utils.make_equal_len(
            s_mtoken_ids, self.model.snd_word_voca.unk_id)

        token_ids = Variable(torch.LongTensor(token_ids).cuda())
        token_mask = Variable(torch.FloatTensor(token_mask).cuda())

        # too ugly, but too lazy to fix it
        self.model.s_ltoken_ids = Variable(torch.LongTensor(s_ltoken_ids).cuda())
        self.model.s_ltoken_mask = Variable(torch.FloatTensor(s_ltoken_mask).cuda())
        self.model.s_rtoken_ids = Variable(torch.LongTensor(s_rtoken_ids).cuda())
        self.model.s_rtoken_mask = Variable(torch.FloatTensor(s_rtoken_mask).cuda())
        self.model.s_mtoken_ids = Variable(torch.LongTensor(s_mtoken_ids).cuda())
        self.model.s_mtoken_mask = Variable(torch.FloatTensor(s_mtoken_mask).cuda())

        scores = self.model.forward(token_ids, token_mask, entity_ids,
                                    entity_mask, p_e_m,
                                    gold=true_pos.view(-1, 1))
        scores = scores.cpu().data.numpy()

        # print out relation weights
        if self.args.mode == 'eval' and self.args.print_rel:
            print('================================')
            weights = self.model._rel_ctx_ctx_weights.cpu().data.numpy()
            voca = self.model.snd_word_voca
            for i in range(len(batch)):
                print(' '.join([voca.id2word[id] for id in lctx_ids[i]]),
                      utils.tokgreen(' '.join([voca.id2word[id] for id in m_ids[i]])),
                      ' '.join([voca.id2word[id] for id in rctx_ids[i]]))
                for j in range(len(batch)):
                    if i == j:
                        continue
                    np.set_printoptions(precision=2)
                    print('\t', weights[:, i, j], '\t',
                          ' '.join([voca.id2word[id] for id in lctx_ids[j]]),
                          utils.tokgreen(' '.join([voca.id2word[id] for id in m_ids[j]])),
                          ' '.join([voca.id2word[id] for id in rctx_ids[j]]))

        pred_ids = np.argmax(scores, axis=1)
        pred_entities = [
            m['selected_cands']['named_cands'][i]
            if m['selected_cands']['mask'][i] == 1
            else (m['selected_cands']['named_cands'][0]
                  if m['selected_cands']['mask'][0] == 1 else 'NIL')
            for (i, m) in zip(pred_ids, batch)
        ]
        # print(pred_entities)
        doc_names = [m['doc_name'] for m in batch]

        if self.args.mode == 'eval' and self.args.print_incorrect:
            gold = [
                item['selected_cands']['named_cands'][item['selected_cands']['true_pos']]
                if item['selected_cands']['true_pos'] >= 0 else 'UNKNOWN'
                for item in batch
            ]
            pred = pred_entities
            for i in range(len(gold)):
                if gold[i] != pred[i]:
                    print('--------------------------------------------')
                    pprint(batch[i]['raw'])
                    print(gold[i], pred[i])
                    print(pred_ids[i], scores[i])

        for dname, entity in zip(doc_names, pred_entities):
            predictions[dname].append({'pred': (entity, 0.)})

    return predictions
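
# --- Illustrative sketch (not part of the original code above) --------------
# How the pred_entities comprehension in predict() falls back when the argmax
# lands on a masked (padded) candidate slot; the toy mention dict mirrors the
# 'selected_cands' fields accessed above:
_m = {'selected_cands': {'named_cands': ['Paris', 'Paris_TX'], 'mask': [1, 0]}}
_i = 1                                     # argmax hit a padded slot
_pred = (_m['selected_cands']['named_cands'][_i]
         if _m['selected_cands']['mask'][_i] == 1
         else (_m['selected_cands']['named_cands'][0]
               if _m['selected_cands']['mask'][0] == 1 else 'NIL'))
assert _pred == 'Paris'                    # falls back to the top candidate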
def predict(self, data):
    predictions = {items[0]['doc_name']: [] for items in data}
    self.model.eval()

    for batch in data:  # each document is a minibatch
        token_ids = [
            m['context'][0] + m['context'][1]
            if len(m['context'][0]) + len(m['context'][1]) > 0
            else [self.model.word_voca.unk_id]
            for m in batch
        ]
        s_ltoken_ids = [m['snd_ctx'][0] for m in batch]
        s_rtoken_ids = [m['snd_ctx'][1] for m in batch]
        s_mtoken_ids = [m['snd_ment'] for m in batch]

        lctx_ids = s_ltoken_ids
        rctx_ids = s_rtoken_ids
        m_ids = s_mtoken_ids

        entity_ids = Variable(
            torch.LongTensor([m['selected_cands']['cands'] for m in batch]).to(device))
        p_e_m = Variable(
            torch.FloatTensor([m['selected_cands']['p_e_m'] for m in batch]).to(device))
        entity_mask = Variable(
            torch.FloatTensor([m['selected_cands']['mask'] for m in batch]).to(device))
        true_pos = Variable(
            torch.LongTensor([m['selected_cands']['true_pos'] for m in batch]).to(device))

        token_ids, token_mask = utils.make_equal_len(
            token_ids, self.model.word_voca.unk_id)
        s_ltoken_ids, s_ltoken_mask = utils.make_equal_len(
            s_ltoken_ids, self.model.snd_word_voca.unk_id, to_right=False)
        s_rtoken_ids, s_rtoken_mask = utils.make_equal_len(
            s_rtoken_ids, self.model.snd_word_voca.unk_id)
        s_rtoken_ids = [l[::-1] for l in s_rtoken_ids]
        s_rtoken_mask = [l[::-1] for l in s_rtoken_mask]
        s_mtoken_ids, s_mtoken_mask = utils.make_equal_len(
            s_mtoken_ids, self.model.snd_word_voca.unk_id)

        token_ids = Variable(torch.LongTensor(token_ids).to(device))
        token_mask = Variable(torch.FloatTensor(token_mask).to(device))

        # too ugly, but too lazy to fix it
        self.model.s_ltoken_ids = Variable(torch.LongTensor(s_ltoken_ids).to(device))
        self.model.s_ltoken_mask = Variable(torch.FloatTensor(s_ltoken_mask).to(device))
        self.model.s_rtoken_ids = Variable(torch.LongTensor(s_rtoken_ids).to(device))
        self.model.s_rtoken_mask = Variable(torch.FloatTensor(s_rtoken_mask).to(device))
        self.model.s_mtoken_ids = Variable(torch.LongTensor(s_mtoken_ids).to(device))
        self.model.s_mtoken_mask = Variable(torch.FloatTensor(s_mtoken_mask).to(device))

        scores = self.model.forward(token_ids, token_mask, entity_ids,
                                    entity_mask, p_e_m,
                                    gold=true_pos.view(-1, 1))
        scores = scores.cpu().data.numpy()

        # print out relation weights
        if self.args.mode == 'eval' and self.args.print_rel:
            print('================================')
            weights = self.model._rel_ctx_ctx_weights.cpu().data.numpy()
            voca = self.model.snd_word_voca
            for i in range(len(batch)):
                print(' '.join([voca.id2word[id] for id in lctx_ids[i]]),
                      utils.tokgreen(' '.join([voca.id2word[id] for id in m_ids[i]])),
                      ' '.join([voca.id2word[id] for id in rctx_ids[i]]))
                for j in range(len(batch)):
                    if i == j:
                        continue
                    np.set_printoptions(precision=2)
                    print('\t', weights[:, i, j], '\t',
                          ' '.join([voca.id2word[id] for id in lctx_ids[j]]),
                          utils.tokgreen(' '.join([voca.id2word[id] for id in m_ids[j]])),
                          ' '.join([voca.id2word[id] for id in rctx_ids[j]]))

        pred_ids = np.argmax(scores, axis=1)

        # rescale scores per mention: treat minimum scores as padding
        # (pushed to -inf), mean-centre and sharpen the rest, then
        # softmax to get per-mention confidences
        processed_scores = []
        for i, score in enumerate(scores):
            min_score = np.min(score)
            if all(list(map(lambda x: x == min_score, score))):
                processed_scores.append(scores[i].tolist())
                continue
            count = 0
            for j, s in enumerate(score):
                if s == min_score:
                    scores[i][j] = np.NINF
                    count += 1
            non_inf_list = list(filter(lambda x: x != float('-inf'), scores[i]))
            mean = np.mean(non_inf_list)
            processed_scores.append(
                ((scores[i] - mean) * len(non_inf_list) * 10).tolist())

        e_x = np.exp(processed_scores -
                     np.amax(processed_scores, axis=1, keepdims=True))
        softmax_scores = e_x / e_x.sum(axis=1, keepdims=True)
        pred_scores = np.amax(np.array(processed_scores), axis=1)
        pred_confidences = np.amax(softmax_scores, axis=1)

        pred_entities = [
            m['selected_cands']['named_cands'][i]
            if m['selected_cands']['mask'][i] == 1
            else (m['selected_cands']['named_cands'][0]
                  if m['selected_cands']['mask'][0] == 1 else 'NIL')
            for (i, m) in zip(pred_ids, batch)
        ]
        doc_names = [m['doc_name'] for m in batch]

        if self.args.mode == 'eval' and self.args.print_incorrect:
            gold = [
                item['selected_cands']['named_cands'][item['selected_cands']['true_pos']]
                if item['selected_cands']['true_pos'] >= 0 else 'UNKNOWN'
                for item in batch
            ]
            pred = pred_entities
            for i in range(len(gold)):
                if gold[i] != pred[i]:
                    print('--------------------------------------------')
                    pprint(batch[i]['raw'])
                    print(gold[i], pred[i])

        for dname, entity, pred_score, pred_confidence in zip(
                doc_names, pred_entities, pred_scores, pred_confidences):
            predictions[dname].append({
                'pred': (entity, 0.),
                'score': pred_score,
                'confidence': pred_confidence
            })

    return predictions
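
# --- Illustrative sketch (not part of the original code above) --------------
# Toy walk-through of the confidence computation in the predict() variant
# above: minimum scores (treated as padding) are pushed to -inf, the remaining
# scores are mean-centred and sharpened by a factor of len(valid) * 10, and a
# softmax turns them into confidences.
import numpy as np

_row = np.array([2.0, 1.0, -5.0, -5.0])        # the -5.0 entries act as padding
_row[_row == _row.min()] = -np.inf
_valid = _row[np.isfinite(_row)]
_processed = (_row - _valid.mean()) * len(_valid) * 10
_e = np.exp(_processed - _processed.max())
_conf = _e / _e.sum()
print(_conf.round(3))                          # ~[1. 0. 0. 0.]: near-certain top pick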
def train(self, org_train_dataset, org_dev_datasets, config):
    print('extracting training data')
    train_dataset = self.get_data_items(org_train_dataset, predict=False)
    print('#train docs', len(train_dataset))

    dev_datasets = []
    for dname, data in org_dev_datasets:
        dev_datasets.append((dname, self.get_data_items(data, predict=True)))
        print(dname, '#dev docs', len(dev_datasets[-1][1]))

    print('creating optimizer')
    optimizer = optim.Adam(
        [p for p in self.model.parameters() if p.requires_grad],
        lr=config['lr'])
    best_f1 = -1
    not_better_count = 0
    is_counting = False
    eval_after_n_epochs = self.args.eval_after_n_epochs

    for e in range(config['n_epochs']):
        shuffle(train_dataset)

        total_loss = 0
        for dc, batch in enumerate(train_dataset):  # each document is a minibatch
            self.model.train()
            optimizer.zero_grad()

            # convert data items to pytorch inputs
            token_ids = [
                m['context'][0] + m['context'][1]
                if len(m['context'][0]) + len(m['context'][1]) > 0
                else [self.model.word_voca.unk_id]
                for m in batch
            ]
            s_ltoken_ids = [m['snd_ctx'][0] for m in batch]
            s_rtoken_ids = [m['snd_ctx'][1] for m in batch]
            s_mtoken_ids = [m['snd_ment'] for m in batch]

            entity_ids = Variable(
                torch.LongTensor([m['selected_cands']['cands'] for m in batch]).cuda())
            true_pos = Variable(
                torch.LongTensor([m['selected_cands']['true_pos'] for m in batch]).cuda())
            p_e_m = Variable(
                torch.FloatTensor([m['selected_cands']['p_e_m'] for m in batch]).cuda())
            entity_mask = Variable(
                torch.FloatTensor([m['selected_cands']['mask'] for m in batch]).cuda())

            token_ids, token_mask = utils.make_equal_len(
                token_ids, self.model.word_voca.unk_id)
            s_ltoken_ids, s_ltoken_mask = utils.make_equal_len(
                s_ltoken_ids, self.model.snd_word_voca.unk_id, to_right=False)
            s_rtoken_ids, s_rtoken_mask = utils.make_equal_len(
                s_rtoken_ids, self.model.snd_word_voca.unk_id)
            s_rtoken_ids = [l[::-1] for l in s_rtoken_ids]
            s_rtoken_mask = [l[::-1] for l in s_rtoken_mask]
            s_mtoken_ids, s_mtoken_mask = utils.make_equal_len(
                s_mtoken_ids, self.model.snd_word_voca.unk_id)

            token_ids = Variable(torch.LongTensor(token_ids).cuda())
            token_mask = Variable(torch.FloatTensor(token_mask).cuda())

            # too ugly, but too lazy to fix it
            self.model.s_ltoken_ids = Variable(torch.LongTensor(s_ltoken_ids).cuda())
            self.model.s_ltoken_mask = Variable(torch.FloatTensor(s_ltoken_mask).cuda())
            self.model.s_rtoken_ids = Variable(torch.LongTensor(s_rtoken_ids).cuda())
            self.model.s_rtoken_mask = Variable(torch.FloatTensor(s_rtoken_mask).cuda())
            self.model.s_mtoken_ids = Variable(torch.LongTensor(s_mtoken_ids).cuda())
            self.model.s_mtoken_mask = Variable(torch.FloatTensor(s_mtoken_mask).cuda())

            scores = self.model.forward(token_ids, token_mask, entity_ids,
                                        entity_mask, p_e_m,
                                        gold=true_pos.view(-1, 1))
            loss = self.model.loss(scores, true_pos)

            loss.backward()
            optimizer.step()
            # self.model.regularize(max_norm=100)

            loss = loss.cpu().data.numpy()
            total_loss += loss
            print('epoch', e, "%0.2f%%" % (dc / len(train_dataset) * 100),
                  loss, end='\r')

        print('epoch', e, 'total loss', total_loss,
              total_loss / len(train_dataset))

        if (e + 1) % eval_after_n_epochs == 0:
            dev_f1 = 0
            for di, (dname, data) in enumerate(dev_datasets):
                predictions = self.predict(data)
                f1 = D.eval(org_dev_datasets[di][1], predictions)
                print(dname, utils.tokgreen('micro F1: ' + str(f1)))
                if dname == 'aida-A':
                    dev_f1 = f1

            if config['lr'] == 1e-4 and dev_f1 >= self.args.dev_f1_change_lr:
                eval_after_n_epochs = 2
                is_counting = True
                best_f1 = dev_f1
                not_better_count = 0

                config['lr'] = 1e-5
                print('change learning rate to', config['lr'])
                if self.args.mulrel_type == 'rel-norm':
                    optimizer = optim.Adam(
                        [p for p in self.model.parameters() if p.requires_grad],
                        lr=config['lr'])
                elif self.args.mulrel_type == 'ment-norm':
                    for param_group in optimizer.param_groups:
                        param_group['lr'] = config['lr']

            if is_counting:
                if dev_f1 < best_f1:
                    not_better_count += 1
                else:
                    not_better_count = 0
                    best_f1 = dev_f1
                    print('save model to', self.args.model_path)
                    self.model.save(self.args.model_path)

            if not_better_count == self.args.n_not_inc:
                break

            self.model.print_weight_norm()
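
# --- Illustrative sketch (not part of the original code above) --------------
# The 'ment-norm' branch in train() lowers the learning rate of the existing
# optimizer in place, while the 'rel-norm' branch rebuilds the optimizer
# (which also resets Adam's moment estimates). A standalone demo of the
# in-place pattern:
import torch
import torch.optim as optim

_param = torch.nn.Parameter(torch.zeros(1))
_opt = optim.Adam([_param], lr=1e-4)
for _group in _opt.param_groups:
    _group['lr'] = 1e-5                    # update lr without losing state
assert _opt.param_groups[0]['lr'] == 1e-5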
def predict(self, data, topk=1):
    predictions = {items[0]['doc_name']: [] for items in data}
    self.model.eval()

    for batch in data:  # each document is a minibatch
        token_ids = [
            m['context'][0] + m['context'][1]
            if len(m['context'][0]) + len(m['context'][1]) > 0
            else [self.model.word_voca.unk_id]
            for m in batch
        ]
        s_ltoken_ids = [m['snd_ctx'][0] for m in batch]
        s_rtoken_ids = [m['snd_ctx'][1] for m in batch]
        s_mtoken_ids = [m['snd_ment'] for m in batch]

        lctx_ids = s_ltoken_ids
        rctx_ids = s_rtoken_ids
        m_ids = s_mtoken_ids

        entity_ids = Variable(
            torch.LongTensor([m['selected_cands']['cands'] for m in batch]).cuda())
        p_e_m = Variable(
            torch.FloatTensor([m['selected_cands']['p_e_m'] for m in batch]).cuda())
        entity_mask = Variable(
            torch.FloatTensor([m['selected_cands']['mask'] for m in batch]).cuda())
        true_pos = Variable(
            torch.LongTensor([m['selected_cands']['true_pos'] for m in batch]).cuda())

        token_ids, token_mask = utils.make_equal_len(
            token_ids, self.model.word_voca.unk_id)
        s_ltoken_ids, s_ltoken_mask = utils.make_equal_len(
            s_ltoken_ids, self.model.snd_word_voca.unk_id, to_right=False)
        s_rtoken_ids, s_rtoken_mask = utils.make_equal_len(
            s_rtoken_ids, self.model.snd_word_voca.unk_id)
        s_rtoken_ids = [l[::-1] for l in s_rtoken_ids]
        s_rtoken_mask = [l[::-1] for l in s_rtoken_mask]
        s_mtoken_ids, s_mtoken_mask = utils.make_equal_len(
            s_mtoken_ids, self.model.snd_word_voca.unk_id)

        token_ids = Variable(torch.LongTensor(token_ids).cuda())
        token_mask = Variable(torch.FloatTensor(token_mask).cuda())

        # too ugly, but too lazy to fix it
        self.model.s_ltoken_ids = Variable(torch.LongTensor(s_ltoken_ids).cuda())
        self.model.s_ltoken_mask = Variable(torch.FloatTensor(s_ltoken_mask).cuda())
        self.model.s_rtoken_ids = Variable(torch.LongTensor(s_rtoken_ids).cuda())
        self.model.s_rtoken_mask = Variable(torch.FloatTensor(s_rtoken_mask).cuda())
        self.model.s_mtoken_ids = Variable(torch.LongTensor(s_mtoken_ids).cuda())
        self.model.s_mtoken_mask = Variable(torch.FloatTensor(s_mtoken_mask).cuda())

        scores = self.model.forward(token_ids, token_mask, entity_ids,
                                    entity_mask, p_e_m,
                                    gold=true_pos.view(-1, 1))
        scores = scores.cpu().data.numpy()

        if topk == 1:
            pred_ids = np.argmax(scores, axis=1)
            pred_entities = [
                m['selected_cands']['named_cands'][i]
                if m['selected_cands']['mask'][i] == 1
                else (m['selected_cands']['named_cands'][0]
                      if m['selected_cands']['mask'][0] == 1 else 'NIL')
                for (i, m) in zip(pred_ids, batch)
            ]
            doc_names = [m['doc_name'] for m in batch]
            for dname, entity in zip(doc_names, pred_entities):
                predictions[dname].append({'pred': (entity, 0.)})
        else:
            pred_ids = np.argsort(scores, axis=1)[:, ::-1]
            pred_entities = [
                [m['selected_cands']['named_cands'][i]
                 if m['selected_cands']['mask'][i] == 1 else 'NIL'
                 for i in ids]
                for (ids, m) in zip(pred_ids, batch)
            ]
            doc_names = [m['doc_name'] for m in batch]
            for dname, entities in zip(doc_names, pred_entities):
                # drop duplicated trailing 'NIL's, keeping at most one
                while (len(entities) >= 2 and entities[-1] == 'NIL'
                       and entities[-2] == 'NIL'):
                    del entities[-1]
                predictions[dname].append({'pred': (entities, 0.)})

    return predictions
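
# --- Illustrative sketch (not part of the original code above) --------------
# The topk > 1 branch trims duplicated trailing 'NIL's so that at most one
# remains at the end of each ranked candidate list:
_entities = ['Paris', 'Paris_TX', 'NIL', 'NIL', 'NIL']
while len(_entities) >= 2 and _entities[-1] == 'NIL' and _entities[-2] == 'NIL':
    del _entities[-1]
assert _entities == ['Paris', 'Paris_TX', 'NIL']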