# Assumed imports: os and json are standard library; compose, _split_words and
# compute_rouge_l are project-internal helpers imported elsewhere.
import json
import os


def label(save_path, file):
    """Greedily label one paper: match each abstract sentence to its best
    still-unused article sentence by ROUGE-L recall."""
    name = os.path.basename(file)
    name, _ = os.path.splitext(name)
    save_name = os.path.join(save_path, "%s.json" % name)
    with open(file) as f:
        paper = json.load(f)
    abstract = paper['abstract']
    article = paper['article']
    tokenize = compose(list, _split_words)
    article = tokenize(article)
    abstract = tokenize(abstract)
    extracted = []
    scores = []
    indices = list(range(len(article)))
    for abst in abstract:
        rouges = list(map(compute_rouge_l(reference=abst, mode='r'), article))
        ext = max(indices, key=lambda i: rouges[i])
        indices.remove(ext)
        extracted.append(ext)
        scores.append(rouges[ext])
        if not indices:
            break
    paper['extracted'] = extracted
    paper['score'] = scores
    with open(save_name, 'w') as f:
        json.dump(paper, f, indent=4)

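# Hypothetical usage of label(); the directory layout below is illustrative
# only, not the project's actual one.
import glob

os.makedirs('labels/train', exist_ok=True)        # assumed output directory
for fname in glob.glob('data/train/*.json'):      # assumed input layout
    label('labels/train', fname)                  # writes labels/train/<name>.json
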
# Assumed imports: numpy and scipy are third-party; compose, _split_words and
# compute_rouge_l are project-internal helpers imported elsewhere.
import json
import os

import numpy as np
from scipy.optimize import linear_sum_assignment


def label(save_path, file):
    """Label one paper with an optimal one-to-one matching: solve the negated
    ROUGE-L recall matrix with the Hungarian algorithm, then drop matches
    scoring at or below 0.05."""
    name = os.path.basename(file)
    name, _ = os.path.splitext(name)
    save_name = os.path.join(save_path, "%s.json" % name)
    with open(file) as f:
        paper = json.load(f)
    abstract = paper['abstract']
    article = paper['article']
    tokenize = compose(list, _split_words)
    article = tokenize(article)
    abstract = tokenize(abstract)
    km_matrix = []
    for i in range(len(abstract)):
        rouges = list(
            map(compute_rouge_l(reference=abstract[i], mode='r'), article))
        # negate: linear_sum_assignment minimizes cost, we want maximal ROUGE
        km_matrix.append([-r for r in rouges])
    km_matrix = np.array(km_matrix)
    row_ind, col_ind = linear_sum_assignment(km_matrix)
    paper['extracted'] = [int(i) for i in col_ind]
    paper['score'] = [-s for s in km_matrix[row_ind, col_ind]]
    # prune weak matches, iterating backwards so pops do not shift indices
    for i in range(len(paper['score']) - 1, -1, -1):
        if paper['score'][i] <= 0.05:
            paper['score'].pop(i)
            paper['extracted'].pop(i)
    with open(save_name, 'w') as f:
        json.dump(paper, f, indent=4)

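# A minimal, self-contained illustration of why label() negates the ROUGE
# matrix: scipy's linear_sum_assignment minimizes total cost, so maximizing
# total ROUGE requires flipping the sign. The 2x3 matrix is toy data, not
# real ROUGE output.
import numpy as np
from scipy.optimize import linear_sum_assignment

toy_rouge = np.array([[0.9, 0.1, 0.3],
                      [0.2, 0.8, 0.4]])
row_ind, col_ind = linear_sum_assignment(-toy_rouge)
print(col_ind)                                     # [0 1]: best one-to-one match
print(toy_rouge[row_ind, col_ind].sum().round(2))  # 1.7, the maximal total ROUGE
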
def get_extract_label(art_sents, abs_sents):
    """ greedily match summary sentences to article sentences, ranking
    candidates by the mean of ROUGE-L, ROUGE-1 and ROUGE-2 F-scores"""
    extracted = []
    scores = []
    indices = list(range(len(art_sents)))
    for abst in abs_sents:
        rouges = list(
            map(compute_rouge_l(reference=abst, mode='f'), art_sents))
        rouge1 = list(
            map(compute_rouge_n(reference=abst, n=1, mode='f'), art_sents))
        rouge2 = list(
            map(compute_rouge_n(reference=abst, n=2, mode='f'), art_sents))
        ext = max(indices,
                  key=lambda i: (rouges[i] + rouge1[i] + rouge2[i]) / 3)
        indices.remove(ext)
        extracted.append(ext)
        scores.append(rouges[ext])
        if not indices:
            break
    return extracted, scores

def get_extract_label(art_sents, abs_sents):
    """ greedily match summary sentences to article sentences"""
    extracted = []
    scores = []
    indices = list(range(len(art_sents)))
    for abst in abs_sents:
        rouges = list(
            map(compute_rouge_l(reference=abst, mode='r'), art_sents))
        ext = max(indices, key=lambda i: rouges[i])
        indices.remove(ext)
        extracted.append(ext)
        scores.append(rouges[ext])
        if not indices:
            break
    return extracted, scores

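# A self-contained toy run of the greedy matching above, using word overlap as
# a stand-in for ROUGE-L recall (compute_rouge_l itself is a project-internal
# curried metric and is not redefined here).
def _toy_overlap(reference, candidate):
    ref = set(reference)
    return len(ref & set(candidate)) / len(ref)

art = [['the', 'cat', 'sat'], ['dogs', 'bark'], ['cats', 'purr']]
abs_ = [['the', 'cat'], ['dogs', 'bark', 'loudly']]
indices = list(range(len(art)))
extracted = []
for abst in abs_:
    ext = max(indices, key=lambda i: _toy_overlap(abst, art[i]))
    indices.remove(ext)   # each article sentence can be used only once
    extracted.append(ext)
print(extracted)  # [0, 1]: each summary sentence grabs its best unused article sentence
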
# Assumed imports: json/pickle/defaultdict are standard library, tqdm is the
# progress-bar package; DataInfo, is_quote and compute_rouge_l are
# project-internal.
import json
import pickle
from collections import defaultdict

from tqdm import tqdm


def process_one(doc_id, dataset='080910'):
    """
    :param doc_id: an id in [0, n_sets); find n_sets in data_info.py
    :param dataset: one of '010203', '04', '11', '080910'
    """
    sim = defaultdict(dict)
    ext = []
    data_info = DataInfo(dataset)
    path = f"{data_info.doc_path}/{doc_id}.json"
    with open(path) as f:
        data = json.loads(f.read())
    name = data['id']
    for sent in data['article']:
        if sent[0] == '###':
            continue
        if not is_quote(sent.split()):
            ext.append(sent)
    # ROUGE-L precision is asymmetric, so both orderings are stored
    for i in tqdm(range(len(ext))):
        for j in range(i + 1, len(ext)):
            sim[name][j, i] = compute_rouge_l(ext[j], ext[i], mode='p')
            sim[name][i, j] = compute_rouge_l(ext[i], ext[j], mode='p')
    with open(f'sim_{dataset}_new/sim{doc_id}_{dataset}.pkl', 'wb') as f:
        pickle.dump(sim, f)

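# ROUGE-L precision normalizes the LCS length by the *candidate's* length, so
# swapping the arguments changes the score; that is why process_one stores
# both (i, j) and (j, i). A minimal sketch with a toy LCS-based precision (an
# assumption, standing in for the project's compute_rouge_l):
def _lcs_len(a, b):
    # classic dynamic-programming longest-common-subsequence length
    dp = [[0] * (len(b) + 1) for _ in range(len(a) + 1)]
    for i, x in enumerate(a, 1):
        for j, y in enumerate(b, 1):
            dp[i][j] = dp[i - 1][j - 1] + 1 if x == y else max(dp[i - 1][j],
                                                               dp[i][j - 1])
    return dp[len(a)][len(b)]

s1 = 'the cat sat on the mat'.split()
s2 = 'the cat'.split()
print(_lcs_len(s1, s2) / len(s1))  # s1 as candidate: 2/6
print(_lcs_len(s2, s1) / len(s2))  # s2 as candidate: 2/2
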
# Assumed imports: math/numpy/torch are standard or third-party; concat is
# assumed to come from cytoolz as elsewhere in this codebase; the reward
# functions (compute_rouge_*, compute_bert_score, compute_bleurt_score,
# presumm_reward*) and get_bart_summaries are project-internal.
import math

import numpy as np
import torch
from cytoolz import concat
from torch import autograd
from torch.nn import functional as F


def a2c_train_step(agent, abstractor, loader, opt, grad_fn,
                   gamma=0.99, reward_fn=compute_rouge_l,
                   stop_reward_fn=compute_rouge_n(n=1), stop_coeff=1.0):
    # wt_rge (the ROUGE mixing weight), use_bart, tokenizer and bart_model are
    # assumed to be module-level globals configured elsewhere.
    opt.zero_grad()
    indices = []
    probs = []
    baselines = []
    ext_sents = []
    art_batch, abs_batch = next(loader)
    for raw_arts, raw_abs in zip(art_batch, abs_batch):
        (inds, ms), bs, raw_arts = agent(raw_arts, raw_abs_sents=raw_abs,
                                         n_abs=None)
        baselines.append(bs)
        indices.append(inds)
        probs.append(ms)
        ext_sents += [raw_arts[idx.item()] for idx in inds
                      if idx.item() < len(raw_arts)]
    with torch.no_grad():
        if not ext_sents:
            summaries = []
        elif use_bart:
            summaries = get_bart_summaries(ext_sents, tokenizer, bart_model)
        else:
            summaries = abstractor(ext_sents)
    i = 0
    rewards = []
    avg_reward = 0
    cands = []
    refs = []
    F1s = None
    # BERTScore / BLEURT are batched metrics: collect all candidate/reference
    # pairs first, then score them in a single call
    if reward_fn == compute_bert_score or reward_fn == compute_bleurt_score:
        for inds, abss in zip(indices, abs_batch):
            for j in range(min(len(inds) - 1, len(abss))):
                cands.append(" ".join(summaries[i + j]))
                refs.append(" ".join(abss[j]))
            i += len(inds) - 1
        F1s = reward_fn(cands, refs)
    i = 0
    t = 0
    avg_len = 0
    for inds, abss in zip(indices, abs_batch):
        x = len(inds) - 1
        if (reward_fn == compute_bert_score
                or reward_fn == compute_bleurt_score) and wt_rge >= 0.0001:
            # mix the batched metric with ROUGE-2 at weight wt_rge
            rwd_lst = [(1 - wt_rge) * F1s[t + j]
                       + wt_rge * compute_rouge_n(summaries[i + j], abss[j],
                                                  n=2)
                       for j in range(min(len(inds) - 1, len(abss)))]
            t += min(len(inds) - 1, len(abss))
        elif reward_fn == compute_bert_score or reward_fn == compute_bleurt_score:
            rwd_lst = [F1s[t + j]
                       for j in range(min(len(inds) - 1, len(abss)))]
            t += min(len(inds) - 1, len(abss))
        elif wt_rge >= 0.0001:
            # mix the per-sentence reward with ROUGE-L at weight wt_rge
            rwd_lst = [(1 - wt_rge) * reward_fn(summaries[i + j], abss[j])
                       + wt_rge * compute_rouge_l(summaries[i + j], abss[j])
                       for j in range(min(len(inds) - 1, len(abss)))]
        elif reward_fn in (presumm_reward, presumm_reward2,
                           presumm_reward3, presumm_reward4):
            rwd_lst = reward_fn(summaries[i:i + len(inds) - 1], abss)
            x = 3
        elif reward_fn == "summ-rouge-l":
            # string flag: summary-level ROUGE-L recall against the whole abstract
            tmp = list(concat(abss))
            rwd_lst = [compute_rouge_l(summaries[i + j], tmp, mode='r')
                       for j in range(min(len(inds) - 1, 3))]
        else:
            rwd_lst = [reward_fn(summaries[i + j], abss[j])
                       for j in range(min(len(inds) - 1, len(abss)))]
        # pad the per-sentence rewards, then append the terminal (stop) reward
        rs = (rwd_lst
              + [0 for _ in range(max(0, len(inds) - 1 - len(rwd_lst)))]
              + [stop_coeff * stop_reward_fn(
                    list(concat(summaries[i:i + x])), list(concat(abss)))])
        assert len(rs) == len(inds)
        avg_reward += rs[-1] / stop_coeff
        if reward_fn in (presumm_reward, presumm_reward3, presumm_reward4):
            rs[-1] = 0
        avg_len += (len(inds) - 1) / len(abs_batch)
        i += len(inds) - 1
        # compute discounted returns, right to left
        R = 0
        disc_rs = []
        for r in rs[::-1]:
            R = r + gamma * R
            disc_rs.insert(0, R)
        rewards += disc_rs
    indices = list(concat(indices))
    probs = list(concat(probs))
    baselines = list(concat(baselines))
    # standardize rewards
    reward = torch.Tensor(rewards).to(baselines[0].device)
    reward = (reward - reward.mean()) / (
        reward.std() + float(np.finfo(np.float32).eps))
    baseline = torch.cat(baselines).squeeze()
    avg_advantage = 0
    losses = []
    avg_list = []
    for action, p, r, b in zip(indices, probs, reward, baseline):
        advantage = (r - b).item()
        avg_advantage += advantage
        avg_list.append(advantage)
        losses.append(-p.log_prob(action)
                      * (advantage / len(indices)))  # divide by T*B
    # disabled experiment: standardize the advantages and recompute the losses
    # avg_list = torch.Tensor(avg_list).to(baselines[0].device)
    # avg_list = (avg_list - avg_list.mean()) / (
    #     avg_list.std() + float(np.finfo(np.float32).eps))
    # for action, p, advantage in zip(indices, probs, avg_list):
    #     losses.append(-p.log_prob(action) * (advantage / len(indices)))
    critic_loss = F.mse_loss(baseline, reward).unsqueeze(dim=0)
    # backprop and update
    autograd.backward(
        [critic_loss] + losses,
        [torch.ones(1).to(critic_loss.device)] * (1 + len(losses)))
    grad_log = grad_fn()
    opt.step()
    log_dict = {}
    log_dict.update(grad_log)
    log_dict['reward'] = avg_reward / len(art_batch)
    log_dict['advantage'] = avg_advantage / len(indices)
    log_dict['mse'] = critic_loss.item()
    log_dict['avg_len'] = avg_len
    assert not math.isnan(log_dict['grad_norm'])
    return log_dict

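# A stand-alone numeric check of the discounted-return loop in a2c_train_step:
# rewards are accumulated right to left, so each step's return is its own
# reward plus the discounted return of everything after it. Toy rewards only.
gamma = 0.99
rs = [0.2, 0.0, 0.5]
R = 0
disc_rs = []
for r in rs[::-1]:
    R = r + gamma * R
    disc_rs.insert(0, R)
print(disc_rs)  # ≈ [0.690, 0.495, 0.5]
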
# Assumed imports for decode(): standard library plus torch/matplotlib;
# Abstractor, BeamAbstractor, RLExtractor, DecodeDataset, identity, unzip,
# tokenize, token_nums, rerank_mp, make_html_safe, plot_attention and the
# compute_rouge_* metrics are project-internal.
import json
import os
from datetime import timedelta
from os.path import join
from time import time

import torch
from matplotlib.backends.backend_pdf import PdfPages
from torch.utils.data import DataLoader


def decode(save_path, model_dir, split, batch_size, beam_size, diverse,
           max_len, cuda):
    start = time()
    # setup model
    with open(join(model_dir, 'meta.json')) as f:
        meta = json.loads(f.read())
    if meta['net_args']['abstractor'] is None:
        # NOTE: if no abstractor is provided then
        # the whole model would be extractive summarization
        assert beam_size == 1
        abstractor = identity
    else:
        if beam_size == 1:
            abstractor = Abstractor(join(model_dir, 'abstractor'),
                                    max_len, cuda)
        else:
            abstractor = BeamAbstractor(join(model_dir, 'abstractor'),
                                        max_len, cuda)
    extractor = RLExtractor(model_dir, cuda=cuda)

    # setup loader
    def coll(batch):
        articles, abstract, extracted = unzip(batch)
        articles = list(filter(bool, articles))
        abstract = list(filter(bool, abstract))
        extracted = list(filter(bool, extracted))
        return articles, abstract, extracted

    dataset = DecodeDataset(split)
    n_data = len(dataset[0])  # article sentences
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False,
                        num_workers=4, collate_fn=coll)

    # prepare save paths and logs
    os.makedirs(join(save_path, 'output'), exist_ok=True)
    dec_log = {}
    dec_log['abstractor'] = meta['net_args']['abstractor']
    dec_log['extractor'] = meta['net_args']['extractor']
    dec_log['rl'] = True
    dec_log['split'] = split
    dec_log['beam'] = beam_size
    dec_log['diverse'] = diverse
    with open(join(save_path, 'log.json'), 'w') as f:
        json.dump(dec_log, f, indent=4)
    file_path = os.path.join(save_path, 'Attention')
    act_path = os.path.join(save_path, 'Actions')
    header = ("index,rouge_score1,rouge_score2,"
              "rouge_scorel,dec_sent_nums,abs_sent_nums,doc_sent_nums,"
              "doc_words_nums,ext_words_nums,abs_words_nums,diff,"
              "recall,precision,less_rewrite,preserve_action,rewrite_action,"
              "each_actions,top3AsAns,top3AsGold,any_top2AsAns,any_top2AsGold,"
              "true_rewrite,true_preserve\n")
    if not os.path.exists(file_path):
        print('create dir:{}'.format(file_path))
        os.makedirs(file_path)
    if not os.path.exists(act_path):
        print('create dir:{}'.format(act_path))
        os.makedirs(act_path)
    with open(join(save_path, '_statisticsDecode.log.csv'), 'w') as w:
        w.write(header)

    # decoding
    i = 0
    with torch.no_grad():
        for i_debug, (raw_article_batch, raw_abstract_batch,
                      extracted_batch) in enumerate(loader):
            tokenized_article_batch = map(tokenize(None), raw_article_batch)
            tokenized_abstract_batch = map(tokenize(None), raw_abstract_batch)
            token_nums_batch = list(map(token_nums(None), raw_article_batch))
            ext_nums = []
            ext_arts = []
            ext_inds = []
            rewrite_less_rouge = []
            dec_outs_act = []
            ext_acts = []
            abs_collections = []
            ext_collections = []
            # extract sentences
            for ind, (raw_art_sents, abs_sents) in enumerate(
                    zip(tokenized_article_batch, tokenized_abstract_batch)):
                (ext, (state, act_dists)), act = extractor(raw_art_sents)
                extracted_state = state[extracted_batch[ind]]
                attn = torch.softmax(
                    state.mm(extracted_state.transpose(1, 0)), dim=-1)

                def plot_actDist(actions, nums):
                    print('index: {} distribution ...'.format(nums))
                    # write the MDP-state attention weight matrix
                    file_name = os.path.join(
                        act_path, '{}.attention.pdf'.format(nums))
                    pdf_pages = PdfPages(file_name)
                    plot_attention(actions.cpu().numpy(),
                                   name='{}-th article'.format(nums),
                                   X_label=list(range(len(raw_art_sents))),
                                   Y_label=list(range(len(ext))),
                                   dirpath=save_path,
                                   pdf_page=pdf_pages, action=True)
                    pdf_pages.close()
                # plot_actDist(torch.stack(act_dists, dim=0), nums=ind+i)

                def plot_attn():
                    print('index: {} write_attention_pdf ...'.format(i + ind))
                    # write the MDP-state attention weight matrix
                    file_name = os.path.join(
                        file_path, '{}.attention.pdf'.format(i + ind))
                    pdf_pages = PdfPages(file_name)
                    plot_attention(attn.cpu().numpy(),
                                   name='{}-th article'.format(i + ind),
                                   X_label=extracted_batch[ind],
                                   Y_label=list(range(len(raw_art_sents))),
                                   dirpath=save_path, pdf_page=pdf_pages)
                    pdf_pages.close()
                # plot_attn()

                ext = ext[:-1]  # exclude EOE
                act = act[:-1]
                if not ext:
                    # use top-5 if nothing is extracted
                    # (in some rare cases rnn-ext does not extract at all)
                    ext = list(range(5))[:len(raw_art_sents)]
                    act = list([1] * 5)[:len(raw_art_sents)]
                else:
                    ext = [i.item() for i in ext]
                    act = [i.item() for i in act]
                ext_nums.append(ext)
                ext_inds += [(len(ext_arts), len(ext))]  # [(0,5),(5,7),(7,3),...]
                ext_arts += [raw_art_sents[k] for k in ext]
                ext_acts += [k for k in act]
                # accumulate the running concatenations of extracted and
                # reference sentences
                ext_collections += [
                    sum(ext_arts[ext_inds[-1][0]:ext_inds[-1][0] + k + 1], [])
                    for k in range(ext_inds[-1][1])]
                abs_collections += [
                    sum(abs_sents[:k + 1], []) if k < len(abs_sents)
                    else sum(abs_sents[0:len(abs_sents)], [])
                    for k in range(ext_inds[-1][1])]
            if beam_size > 1:
                # abstract with beam search, then rerank
                all_beams = abstractor(ext_arts, beam_size, diverse)
                dec_outs = rerank_mp(all_beams, ext_inds)
            else:
                # greedy abstracting
                dec_outs = abstractor(ext_arts)
            dec_collections = [
                [sum(dec_outs[pos[0]:pos[0] + k + 1], [])
                 for k in range(pos[1])]
                for pos in ext_inds]
            dec_collections = [x for sublist in dec_collections
                               for x in sublist]
            # action 0 keeps the rewritten sentence, action 1 preserves the original
            for index, chooser in enumerate(ext_acts):
                if chooser == 0:
                    dec_outs_act += [dec_outs[index]]
                else:
                    dec_outs_act += [ext_arts[index]]
            assert len(ext_collections) == len(dec_collections) \
                == len(abs_collections)
            # for each sentence of the extracted digest, measure how much
            # rewriting changed ROUGE-1 against the abstract so far
            for ext, dec, abss, act in zip(ext_collections, dec_collections,
                                           abs_collections, ext_acts):
                rouge_before_rewritten = compute_rouge_n(ext, abss, n=1)
                rouge_after_rewritten = compute_rouge_n(dec, abss, n=1)
                diff_ins = rouge_before_rewritten - rouge_after_rewritten
                rewrite_less_rouge.append(diff_ins)
            assert i == batch_size * i_debug
            for iters, (j, n) in enumerate(ext_inds):
                do_right_rewrite = sum(
                    [1 for rouge, action in zip(rewrite_less_rouge[j:j + n],
                                                ext_acts[j:j + n])
                     if rouge < 0 and action == 0])
                do_right_preserve = sum(
                    [1 for rouge, action in zip(rewrite_less_rouge[j:j + n],
                                                ext_acts[j:j + n])
                     if rouge >= 0 and action == 1])
                decoded_words_nums = [len(dec)
                                      for dec in dec_outs_act[j:j + n]]
                ext_words_nums = [token_nums_batch[iters][x]
                                  for x in range(len(token_nums_batch[iters]))
                                  if x in ext_nums[iters]]
                # statistics [START]
                decoded_sents = [' '.join(dec)
                                 for dec in dec_outs_act[j:j + n]]
                rouge_score1 = compute_rouge_n(
                    ' '.join(decoded_sents),
                    ' '.join(raw_abstract_batch[iters]), n=1)
                rouge_score2 = compute_rouge_n(
                    ' '.join(decoded_sents),
                    ' '.join(raw_abstract_batch[iters]), n=2)
                rouge_scorel = compute_rouge_l(
                    ' '.join(decoded_sents),
                    ' '.join(raw_abstract_batch[iters]))
                dec_sent_nums = len(decoded_sents)
                abs_sent_nums = len(raw_abstract_batch[iters])
                doc_sent_nums = len(raw_article_batch[iters])
                doc_words_nums = sum(token_nums_batch[iters])
                ext_words_nums = sum(ext_words_nums)
                abs_words_nums = sum(decoded_words_nums)
                label_recall = (len(set(ext_nums[iters])
                                    & set(extracted_batch[iters]))
                                / len(extracted_batch[iters]))
                label_precision = (len(set(ext_nums[iters])
                                       & set(extracted_batch[iters]))
                                   / len(ext_nums[iters]))
                less_rewrite = rewrite_less_rouge[j + n - 1]
                dec_one_action_num = sum(ext_acts[j:j + n])
                dec_zero_action_num = n - dec_one_action_num
                ext_indices = '_'.join([str(x) for x in ext_nums[iters]])
                top3 = set([0, 1, 2]) <= set(ext_nums[iters])
                top3_gold = set([0, 1, 2]) <= set(extracted_batch[iters])
                # any two of the top three
                top2 = (set([0, 1]) <= set(ext_nums[iters])
                        or set([1, 2]) <= set(ext_nums[iters])
                        or set([0, 2]) <= set(ext_nums[iters]))
                top2_gold = (set([0, 1]) <= set(extracted_batch[iters])
                             or set([1, 2]) <= set(extracted_batch[iters])
                             or set([0, 2]) <= set(extracted_batch[iters]))
                with open(join(save_path, '_statisticsDecode.log.csv'),
                          'a') as w:
                    w.write('{},{},{},{},{},{},{},{},{},{},{},{},{},{},{},'
                            '{},{},{},{},{},{},{},{}\n'.format(
                                i, rouge_score1, rouge_score2, rouge_scorel,
                                dec_sent_nums, abs_sent_nums, doc_sent_nums,
                                doc_words_nums, ext_words_nums,
                                abs_words_nums,
                                (ext_words_nums - abs_words_nums),
                                label_recall, label_precision, less_rewrite,
                                dec_one_action_num, dec_zero_action_num,
                                ext_indices, top3, top3_gold, top2, top2_gold,
                                do_right_rewrite, do_right_preserve))
                # statistics [END]
                with open(join(save_path, 'output/{}.dec'.format(i)),
                          'w') as f:
                    decoded_sents = [s for s in decoded_sents if s != '']
                    if len(decoded_sents) > 0:
                        f.write(make_html_safe('\n'.join(decoded_sents)))
                    else:
                        f.write('')
                i += 1
                print('{}/{} ({:.2f}%) decoded in {} seconds\r'.format(
                    i, n_data, i / n_data * 100,
                    timedelta(seconds=int(time() - start))), end='')
    print()

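# Hypothetical invocation of decode(); the directory names and beam settings
# below are illustrative only, not the project's actual configuration.
if __name__ == '__main__':
    decode(save_path='decode_out',       # assumed output directory
           model_dir='saved_model',      # assumed trained model directory
           split='test', batch_size=32,
           beam_size=5, diverse=1.0,     # beam_size > 1 uses BeamAbstractor + rerank_mp
           max_len=30, cuda=torch.cuda.is_available())
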
# findindex, compute_rouge_l and compute_rouge_l_summ are project-internal
# helpers; '<E>' is a sentinel kept at article position 0.
import collections


def get_extract_label(art_sents, abs_sents):
    """ greedily match summary sentences to article sentences"""
    extracted = []
    scores = []
    if ''.join(art_sents[0]) != '<E>':
        art_sents.insert(0, '<E>')
    indices = list(range(len(art_sents)))
    new_abs_sents = []
    for j in range(len(abs_sents)):
        rouges = list(
            map(compute_rouge_l(reference=abs_sents[j], mode='r'),
                art_sents[1:]))
        rouges.insert(0, 0)  # the sentinel scores 0
        ext = max(indices, key=lambda i: rouges[i])
        max_scores = rouges[ext]
        max_exts = collections.Counter(rouges)[max_scores]
        if max_exts != 1:
            # break ties on summary-level ROUGE-L F-score
            max_inds = []
            rouge_f = []
            for idx, score in enumerate(rouges):
                if idx in indices and score == max_scores:
                    max_inds.append(idx)
                    rouge_f.append(
                        compute_rouge_l_summ(art_sents[idx], abs_sents[j],
                                             mode='f'))
            maxrouge = max(range(len(max_inds)), key=lambda i: rouge_f[i])
            ext = max_inds[maxrouge]
        if ext == 0:
            ext = 1
        # try pairing the chosen sentence with every other article sentence
        new_art_sents = ['<E>']
        for i in range(1, len(art_sents)):
            if i < ext:
                new_art_sents.append(art_sents[i] + art_sents[ext])
            elif i > ext:
                new_art_sents.append(art_sents[ext] + art_sents[i])
            else:
                new_art_sents.append(art_sents[ext])
        new_rouges = list(
            map(compute_rouge_l_summ(refs=abs_sents[j], mode='fr'),
                new_art_sents[1:]))
        new_rouges_f = [fr[0] for fr in new_rouges]
        new_rouges_r = [fr[1] for fr in new_rouges]
        new_rouges_f.insert(0, 0)
        new_rouges_r.insert(0, 0)
        new_ext = max(indices, key=lambda i: new_rouges_f[i])
        if new_ext == 0:
            new_ext = 1
        if ext == new_ext or rouges[ext] >= new_rouges_r[new_ext]:
            # a single sentence suffices; 0 marks the end of a group
            extracted.append(ext)
            extracted.append(0)
            scores.append(new_rouges_f[ext])
        elif ext < new_ext:
            extracted.append(ext)
            extracted.append(new_ext)
            extracted.append(0)
            scores.append(new_rouges_f[new_ext])
        else:
            extracted.append(new_ext)
            extracted.append(ext)
            extracted.append(0)
            scores.append(new_rouges_f[new_ext])
        # reduce duplication: if groups ab->A and bc->B overlap,
        # merge them into abc->AB
        new_abs_sents.append(abs_sents[j])
        index = findindex(extracted, 0)
        while len(index) >= 2:
            if len(index) == 2:
                overlap = (set(extracted[:index[-2]])
                           & set(extracted[index[-2] + 1:index[-1]]))
                if len(overlap) > 0:
                    new = sorted(set(extracted[:index[-2]])
                                 | set(extracted[index[-2] + 1:index[-1]]))
                    del extracted[:index[-1] + 1]
                    extracted = extracted + new
                    extracted.append(0)
                    new_sent = new_abs_sents[-2] + new_abs_sents[-1]
                    del new_abs_sents[-2:]
                    new_abs_sents.append(new_sent)
                    index = findindex(extracted, 0)
                else:
                    break
            else:
                overlap = (set(extracted[index[-3] + 1:index[-2]])
                           & set(extracted[index[-2] + 1:index[-1]]))
                if len(overlap) > 0:
                    new = sorted(set(extracted[index[-3] + 1:index[-2]])
                                 | set(extracted[index[-2] + 1:index[-1]]))
                    del extracted[index[-3] + 1:index[-1] + 1]
                    extracted = extracted + new
                    extracted.append(0)
                    new_sent = new_abs_sents[-2] + new_abs_sents[-1]
                    del new_abs_sents[-2:]
                    new_abs_sents.append(new_sent)
                    index = findindex(extracted, 0)
                else:
                    break
        # mark the sentences of the finished group as used
        if len(index) >= 2:
            if len(index) == 2:
                for idx in extracted[:index[-2]]:
                    try:
                        indices.remove(idx)
                    except ValueError:
                        continue
            if len(index) > 2:
                for idx in extracted[index[-3]:index[-2]]:
                    try:
                        indices.remove(idx)
                    except ValueError:
                        continue
        if not indices:
            break
    # interleave '<E>' separators between the merged abstract sentences
    length = len(new_abs_sents)
    for i in range(length):
        new_abs_sents.insert(i + (i + 1), ['<E>'])
    return extracted, scores, new_abs_sents, art_sents

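# findindex is project-internal; a plausible stand-in (clearly an assumption)
# returns every position of a value, which is how it is used above:
def findindex(lst, value):
    return [pos for pos, x in enumerate(lst) if x == value]

# Toy trace of the deduplication step: groups [1, 2, 0] and [2, 3, 0] share
# article sentence 2, so the loop above would merge them into [1, 2, 3, 0].
extracted = [1, 2, 0, 2, 3, 0]
index = findindex(extracted, 0)   # [2, 5]
overlap = set(extracted[:index[-2]]) & set(extracted[index[-2] + 1:index[-1]])
print(overlap)                    # {2}
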