def prerank(self, dataset, predict=False):
    """Prune every mention's candidate list and attach negative candidates.

    For each mention, ``args.keep_ctx_ent + args.keep_p_e_m`` candidate
    indices are kept:

    * the ``keep_ctx_ent`` candidates scored highest by ``self.prerank_model``
      given up to ``args.prerank_ctx_window // 2`` context tokens on each
      side of the mention (only when ``keep_ctx_ent > 0``), and
    * the remaining slots, filled with the lowest candidate indices not yet
      chosen. The original comment calls these the "keep_p_e_m best
      candidates (p_e_m scores)", which presumes ``m['cands']`` is sorted by
      p_e_m. NOTE(review): confirm this against the data loader.

    Each mention dict is updated in place:

    * ``m['selected_cands']`` is a dict with the kept ``cands``,
      ``named_cands``, ``p_e_m`` and ``mask`` values, in ascending original
      index order. It also holds ``true_pos``: the gold's position inside
      the kept lists, or -1 if the gold was not kept.
    * ``m['neg_cands']`` holds, when ``keep_ctx_ent > 0``, the row indices
      of the (at most) 1000 ``prerank_model.entity_embeddings`` entries with
      the highest softmax score against the mention's context vector. It is
      ``None`` when the prerank model is not run.

    Args:
        dataset: iterable of documents. Each document is a list of mention
            dicts with keys 'context' (left/right context token ids at
            [0]/[1]), 'cands', 'named_cands', 'p_e_m', 'mask' and
            'true_pos'.
        predict: when False and neither ``args.multi_instance`` nor
            ``args.semisup`` is set, mentions whose gold candidate was not
            kept are dropped.

    Returns:
        List of documents (lists of kept mention dicts). Documents left with
        no mentions are omitted.

    Progress, recall and the statistics returned by ``self.get_p_e_ent_net``
    are printed to stdout. ``get_p_e_ent_net`` is only called for documents
    with more than one kept mention.
    """
    args = self.args
    n_keep = args.keep_ctx_ent + args.keep_p_e_m
    # Mentions without a kept gold are only dropped in plain supervised
    # training; prediction, multi-instance and semi-supervised modes keep them.
    drop_no_gold = not predict and not (args.multi_instance or args.semisup)

    new_dataset = []
    has_gold = 0
    total = 0
    correct = 0
    larger_than_x = 0
    larger_than_x_correct = 0
    total_cands = 0

    print('preranking...')

    for count, content in enumerate(dataset):
        if count % 1000 == 0:
            print(count, end='\r')

        if args.keep_ctx_ent > 0:
            # Rank the candidates by ntee scores, using only the context
            # window around the mention (mention tokens are deliberately
            # left out).
            half_window = args.prerank_ctx_window // 2
            lctx_ids = [
                m['context'][0][max(len(m['context'][0]) - half_window, 0):]
                for m in content
            ]
            rctx_ids = [
                m['context'][1][:min(len(m['context'][1]), half_window)]
                for m in content
            ]
            # A mention with no context at all gets a single UNK token so the
            # bag-of-words input is never empty.
            token_ids = [
                l + r if len(l) + len(r) > 0
                else [self.prerank_model.word_voca.unk_id]
                for l, r in zip(lctx_ids, rctx_ids)
            ]

            entity_ids = Variable(
                torch.LongTensor([m['cands'] for m in content]).cuda())
            entity_mask = Variable(
                torch.FloatTensor([m['mask'] for m in content]).cuda())

            token_ids, token_offsets = utils.flatten_list_of_lists(token_ids)
            token_offsets = Variable(torch.LongTensor(token_offsets).cuda())
            token_ids = Variable(torch.LongTensor(token_ids).cuda())

            scores, sent_vecs = self.prerank_model.forward(
                token_ids, token_offsets, entity_ids,
                use_sum=True, return_sent_vecs=True)
            # Where mask == 0 this adds -1e10, pushing those candidates to the
            # bottom so topk does not pick them (assumes a 0/1 mask).
            scores = (scores * entity_mask).add_((entity_mask - 1).mul_(1e10))

            _, top_pos = torch.topk(scores, dim=1, k=args.keep_ctx_ent)
            top_pos = top_pos.data.cpu().numpy()

            # Compute the distribution used for sampling negatives. k is capped
            # at the number of entities so small vocabularies don't break topk.
            probs = F.softmax(
                torch.matmul(
                    sent_vecs,
                    self.prerank_model.entity_embeddings.weight.t()),
                dim=1)
            n_neg = min(1000, probs.size(1))
            _, neg_cands = torch.topk(probs, dim=1, k=n_neg)
            neg_cands = neg_cands.data.cpu().numpy()
        else:
            top_pos = [[]] * len(content)
            # The prerank model is not run, so there are no negatives.
            # Previously neg_cands was left unbound here, causing a crash.
            neg_cands = None

        # Select candidates: mix the keep_ctx_ent best candidates (ntee
        # scores) with the keep_p_e_m best candidates (p_e_m scores).
        items = []
        for i, m in enumerate(content):
            sm = {
                'cands': [],
                'named_cands': [],
                'p_e_m': [],
                'mask': [],
                'true_pos': -1,
            }
            m['selected_cands'] = sm
            m['neg_cands'] = neg_cands[i, :] if neg_cands is not None else None

            # Fill the remaining slots with the lowest unused indices.
            selected = set(top_pos[i])
            idx = 0
            while len(selected) < n_keep:
                selected.add(idx)
                idx += 1

            for idx in sorted(selected):
                sm['cands'].append(m['cands'][idx])
                sm['named_cands'].append(m['named_cands'][idx])
                sm['p_e_m'].append(m['p_e_m'][idx])
                sm['mask'].append(m['mask'][idx])
                if idx == m['true_pos']:
                    sm['true_pos'] = len(sm['cands']) - 1

            if drop_no_gold and sm['true_pos'] == -1:
                # Forcing the gold into slot 0 instead was tried and made
                # performance worse, so such mentions are skipped.
                continue

            items.append(m)
            if sm['true_pos'] >= 0:
                has_gold += 1
            total += 1

        if not items:
            continue

        if len(items) > 1:
            c, l, lc, tc = self.get_p_e_ent_net(items)
            correct += c
            larger_than_x += l
            larger_than_x_correct += lc
            total_cands += tc

        if drop_no_gold:
            # Same gold filter as above, re-applied after get_p_e_ent_net.
            # It is a no-op unless that helper changes true_pos.
            items = [m for m in items if m['selected_cands']['true_pos'] >= 0]
        new_dataset.append(items)

    # Guard each ratio separately so a zero denominator only hides its own line.
    if total > 0:
        print('recall', has_gold / total)
        print('correct', correct, correct / total)
    if larger_than_x > 0:
        print(larger_than_x, larger_than_x_correct,
              larger_than_x_correct / larger_than_x)
    if total > 0:
        print(total_cands, total_cands / total)
    print('------------------------------------------')
    return new_dataset
def prerank(self, dataset, predict=False):
    """Prune every mention's candidate list before the main model runs.

    For each mention, ``args.keep_ctx_ent + args.keep_p_e_m`` candidate
    indices are kept:

    * the ``keep_ctx_ent`` candidates scored highest by ``self.prerank_model``
      given up to ``args.prerank_ctx_window // 2`` context tokens on each
      side of the mention (only when ``keep_ctx_ent > 0``), and
    * the remaining slots, filled with the lowest candidate indices not yet
      chosen. The original comment calls these the "keep_p_e_m best
      candidates (p_e_m scores)", which presumes ``m['cands']`` is sorted by
      p_e_m. NOTE(review): confirm this against the data loader.

    ``m['selected_cands']`` is set in place on every mention. It is a dict
    with the kept ``cands``, ``named_cands``, ``p_e_m`` and ``mask`` values,
    in ascending original index order. It also holds ``true_pos``: the gold's
    position inside the kept lists, or -1 if the gold was not kept.

    Args:
        dataset: iterable of documents. Each document is a list of mention
            dicts with keys 'context' (left/right context token ids at
            [0]/[1]), 'cands', 'named_cands', 'p_e_m', 'mask' and
            'true_pos'.
        predict: when False, mentions whose gold candidate was not kept are
            dropped.

    Returns:
        List of documents (lists of kept mention dicts). Documents left with
        no mentions are omitted.

    The gold recall over kept mentions is printed when at least one mention
    was kept.
    """
    args = self.args
    n_keep = args.keep_ctx_ent + args.keep_p_e_m

    new_dataset = []
    has_gold = 0
    total = 0

    for content in dataset:
        if args.keep_ctx_ent > 0:
            # Rank the candidates by ntee scores, using only the context
            # window around the mention (mention tokens are deliberately
            # left out).
            half_window = args.prerank_ctx_window // 2
            lctx_ids = [
                m['context'][0][max(len(m['context'][0]) - half_window, 0):]
                for m in content
            ]
            rctx_ids = [
                m['context'][1][:min(len(m['context'][1]), half_window)]
                for m in content
            ]
            # A mention with no context at all gets a single UNK token so the
            # bag-of-words input is never empty.
            token_ids = [
                l + r if len(l) + len(r) > 0
                else [self.prerank_model.word_voca.unk_id]
                for l, r in zip(lctx_ids, rctx_ids)
            ]

            entity_ids = Variable(
                torch.LongTensor([m['cands'] for m in content]).cuda())
            entity_mask = Variable(
                torch.FloatTensor([m['mask'] for m in content]).cuda())

            token_ids, token_offsets = utils.flatten_list_of_lists(token_ids)
            token_offsets = Variable(torch.LongTensor(token_offsets).cuda())
            token_ids = Variable(torch.LongTensor(token_ids).cuda())

            log_probs = self.prerank_model.forward(
                token_ids, token_offsets, entity_ids, use_sum=True)
            # Where mask == 0 this adds -1e10, pushing those candidates to the
            # bottom so topk does not pick them (assumes a 0/1 mask).
            log_probs = (log_probs * entity_mask).add_(
                (entity_mask - 1).mul_(1e10))
            _, top_pos = torch.topk(log_probs, dim=1, k=args.keep_ctx_ent)
            top_pos = top_pos.data.cpu().numpy()
        else:
            top_pos = [[]] * len(content)

        # Select candidates: mix the keep_ctx_ent best candidates (ntee
        # scores) with the keep_p_e_m best candidates (p_e_m scores).
        items = []
        for i, m in enumerate(content):
            sm = {
                'cands': [],
                'named_cands': [],
                'p_e_m': [],
                'mask': [],
                'true_pos': -1,
            }
            m['selected_cands'] = sm

            # Fill the remaining slots with the lowest unused indices.
            selected = set(top_pos[i])
            idx = 0
            while len(selected) < n_keep:
                selected.add(idx)
                idx += 1

            for idx in sorted(selected):
                sm['cands'].append(m['cands'][idx])
                sm['named_cands'].append(m['named_cands'][idx])
                sm['p_e_m'].append(m['p_e_m'][idx])
                sm['mask'].append(m['mask'][idx])
                if idx == m['true_pos']:
                    sm['true_pos'] = len(sm['cands']) - 1

            if not predict and sm['true_pos'] == -1:
                # Forcing the gold into slot 0 instead was tried and made
                # performance worse, so such mentions are skipped.
                continue

            items.append(m)
            if sm['true_pos'] >= 0:
                has_gold += 1
            total += 1

        if items:
            new_dataset.append(items)

    # total is 0 for an empty dataset, or when every mention was dropped.
    if total > 0:
        print('recall', has_gold / total)
    return new_dataset