# NOTE: these snippets come from fast_abs_rl-style RL summarization code and
# assume the usual module-level imports (os, json, math, pickle as pkl,
# numpy as np, pandas as pd, torch, torch.nn.functional as F, torch.autograd
# as autograd, cytoolz.concat, time/timedelta, sent_tokenize, ...) plus the
# curried metric helpers compute_rouge_n / compute_rouge_l / compute_rouge_l_summ.

def compute_reward_old(pred_str, pred_sent_list, trg_str, trg_sent_list,
                       reward_type, regularization_factor=0.0,
                       regularization_type=0, entropy=None):
    if regularization_type == 1:
        raise ValueError("Not implemented.")
    elif regularization_type == 2:
        regularization = entropy
    else:
        regularization = 0.0
    if reward_type == 0:
        # mixed reward: 0.3 * ROUGE-1 + 0.2 * ROUGE-2 + 0.5 * summary-level ROUGE-L
        tmp_reward = 0.3 * compute_rouge_n(pred_str, trg_str, n=1, mode='f') \
            + 0.2 * compute_rouge_n(pred_str, trg_str, n=2, mode='f') \
            + 0.5 * compute_rouge_l_summ(pred_sent_list, trg_sent_list, mode='f')
    elif reward_type == 1:
        tmp_reward = compute_rouge_l_summ(pred_sent_list, trg_sent_list, mode='f')
    else:
        raise ValueError('unknown reward_type: {}'.format(reward_type))
    # add the regularization term to the reward only if regularization type != 0
    if regularization_type == 0 or regularization_factor == 0:
        reward = tmp_reward
    else:
        reward = (1 - regularization_factor) * tmp_reward \
            + regularization_factor * regularization
    return reward
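# A minimal sketch (not from the original code; all numbers hypothetical)
# showing how the entropy regularization above blends into the final reward:
# with factor lambda, reward = (1 - lambda) * rouge_mix + lambda * entropy.
def _demo_entropy_mixing():
    rouge_mix = 0.42   # assumed mixed-ROUGE score
    entropy = 1.5      # assumed policy entropy
    lam = 0.1          # regularization_factor
    reward = (1 - lam) * rouge_mix + lam * entropy
    assert abs(reward - (0.9 * 0.42 + 0.1 * 1.5)) < 1e-9
    return reward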
def get_extract_label(art_sents, abs_sents):
    """ greedily match summary sentences to article sentences"""
    extracted = []
    scores = []
    indices = list(range(len(art_sents)))
    for abst in abs_sents:
        rouges = list(map(compute_rouge_l(reference=abst, mode='f'),
                          art_sents))
        rouge1 = list(map(compute_rouge_n(reference=abst, n=1, mode='f'),
                          art_sents))
        rouge2 = list(map(compute_rouge_n(reference=abst, n=2, mode='f'),
                          art_sents))
        # ext = max(indices, key=lambda i: (rouges[i]))
        # pick the remaining article sentence with the best averaged ROUGE-L/1/2
        ext = max(indices,
                  key=lambda i: (rouges[i] + rouge1[i] + rouge2[i]) / 3)
        indices.remove(ext)
        extracted.append(ext)
        scores.append(rouges[ext])
        if not indices:
            break
    return extracted, scores
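# A self-contained toy run of the greedy matching above.  The real curried
# metric helpers live in the training repo, so this sketch swaps in a
# simplified unigram-F1 stand-in for all three ROUGE variants; it only
# illustrates the control flow, not the exact scores.
from collections import Counter

def _unigram_f1(output, reference):
    # simplified stand-in for ROUGE: F1 over unigram overlap counts
    overlap = sum((Counter(output) & Counter(reference)).values())
    if overlap == 0:
        return 0.0
    prec, rec = overlap / len(output), overlap / len(reference)
    return 2 * prec * rec / (prec + rec)

def _demo_greedy_match():
    art_sents = [['the', 'cat', 'sat'], ['dogs', 'bark'], ['cats', 'purr']]
    abs_sents = [['the', 'cat', 'purrs']]
    indices = list(range(len(art_sents)))
    extracted = []
    for abst in abs_sents:
        scores = [_unigram_f1(s, abst) for s in art_sents]
        ext = max(indices, key=lambda i: scores[i])
        indices.remove(ext)
        extracted.append(ext)
    return extracted  # -> [0]: 'the cat sat' overlaps the abstract most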
def xsum_mixed_rouge_reward(pred_str_list, pred_sent_2d_list, trg_str_list,
                            trg_sent_2d_list, batch_size, device):
    #reward = np.zeros(batch_size)
    reward = []
    for idx, (pred_str, pred_sent_list, trg_str, trg_sent_list) in enumerate(
            zip(pred_str_list, pred_sent_2d_list, trg_str_list,
                trg_sent_2d_list)):
        #reward[idx] = 0.2 * compute_rouge_n(pred_str, trg_str, n=1, mode='f') \
        #    + 0.5 * compute_rouge_n(pred_str, trg_str, n=2, mode='f') \
        #    + 0.3 * compute_rouge_l_summ(pred_sent_list, trg_sent_list, mode='f')
        reward.append(
            0.2 * compute_rouge_n(pred_str, trg_str, n=1, mode='f') +
            0.5 * compute_rouge_n(pred_str, trg_str, n=2, mode='f') +
            0.3 * compute_rouge_l_summ(pred_sent_list, trg_sent_list,
                                       mode='f'))
    return torch.FloatTensor(reward).to(device)  # tensor: [batch_size]
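# A minimal sketch (hypothetical scores, not part of the original code) of how
# the mixed XSum reward above is assembled into a [batch_size] tensor; the
# compute_rouge_* calls are replaced by fixed numbers since the real metric
# helpers live in the training repo.
def _demo_xsum_reward_tensor():
    import torch
    fake_scores = [(0.4, 0.2, 0.3), (0.5, 0.1, 0.2)]  # (R1, R2, RL-summ) per sample
    reward = [0.2 * r1 + 0.5 * r2 + 0.3 * rl for r1, r2, rl in fake_scores]
    return torch.FloatTensor(reward)  # shape: [batch_size]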
def train(args):
    if not exists(args.path):
        os.makedirs(args.path)

    # make net
    agent, agent_vocab, abstractor, net_args = configure_net(
        args.abs_dir, args.ext_dir, args.cuda)

    # configure training setting
    assert args.stop > 0
    train_params = configure_training(
        'adam', args.lr, args.clip, args.decay, args.batch, args.gamma,
        args.reward, args.stop, 'rouge-1'
    )
    train_batcher, val_batcher = build_batchers(args.batch)
    # TODO different reward
    reward_fn = compute_rouge_l
    stop_reward_fn = compute_rouge_n(n=1)

    # save abstractor binary
    if args.abs_dir is not None:
        abs_ckpt = {}
        abs_ckpt['state_dict'] = load_best_ckpt(args.abs_dir)
        abs_vocab = pkl.load(open(join(args.abs_dir, 'vocab.pkl'), 'rb'))
        abs_dir = join(args.path, 'abstractor')
        os.makedirs(join(abs_dir, 'ckpt'))
        with open(join(abs_dir, 'meta.json'), 'w') as f:
            json.dump(net_args['abstractor'], f, indent=4)
        torch.save(abs_ckpt, join(abs_dir, 'ckpt/ckpt-0-0'))
        with open(join(abs_dir, 'vocab.pkl'), 'wb') as f:
            pkl.dump(abs_vocab, f)
    # save configuration
    meta = {}
    meta['net'] = 'rnn-ext_abs_rl'
    meta['net_args'] = net_args
    meta['train_params'] = train_params
    with open(join(args.path, 'meta.json'), 'w') as f:
        json.dump(meta, f, indent=4)
    with open(join(args.path, 'agent_vocab.pkl'), 'wb') as f:
        pkl.dump(agent_vocab, f)

    # prepare trainer
    grad_fn = get_grad_fn(agent, args.clip)
    optimizer = optim.Adam(agent.parameters(), **train_params['optimizer'][1])
    scheduler = ReduceLROnPlateau(optimizer, 'max', verbose=True,
                                  factor=args.decay, min_lr=0,
                                  patience=args.lr_p)

    pipeline = A2CPipeline(meta['net'], agent, abstractor,
                           train_batcher, val_batcher,
                           optimizer, grad_fn,
                           reward_fn, args.gamma,
                           stop_reward_fn, args.stop)
    trainer = BasicTrainer(pipeline, args.path,
                           args.ckpt_freq, args.patience, scheduler,
                           val_mode='score')

    print('start training with the following hyper-parameters:')
    print(meta)
    trainer.train()
def main(args):
    # define directories of the train set and the decoded summaries
    DATA_DIR = os.environ['DATA']
    train_data_dir = os.path.join(DATA_DIR, 'train')
    dec_dir = args.decoded_data_dir

    # get the reference file list
    ref_files_df = pd.read_csv(os.path.join(DATA_DIR, args.bucket_file_path))
    ref_files = ref_files_df['filename']

    # get the ROUGE score between each reference and generated text
    rouge_scores_list = []
    for ind, act_file_name in enumerate(ref_files):
        with open(join(train_data_dir, act_file_name)) as f:
            js = json.loads(f.read())
        abstract = '\n'.join(js['abstract'])
        dec_file_name = str(ind) + '.dec'
        with open(join(dec_dir, dec_file_name)) as f:
            generation = f.read()
        rouge_score = compute_rouge_n(generation.split(), abstract.split())
        rouge_scores_list.append(rouge_score)

    df = pd.DataFrame()
    df['filename'] = ref_files
    df['rouge_score'] = rouge_scores_list
    df.to_csv(join('./rouge_score_files', args.rouge_scores_file_name))
def a2c_validate(agent, abstractor, loader):
    agent.eval()
    start = time()
    print('start running validation...', end='')
    avg_reward = 0
    i = 0
    with torch.no_grad():
        for art_batch, abs_batch in loader:
            ext_sents = []
            ext_inds = []
            for raw_arts in art_batch:
                indices = agent(raw_arts)
                ext_inds += [(len(ext_sents), len(indices)-1)]
                ext_sents += [raw_arts[idx.item()]
                              for idx in indices
                              if idx.item() < len(raw_arts)]
            all_summs = abstractor(ext_sents)
            for (j, n), abs_sents in zip(ext_inds, abs_batch):
                summs = all_summs[j:j+n]
                # python ROUGE-1 (not official evaluation)
                avg_reward += compute_rouge_n(list(concat(summs)),
                                              list(concat(abs_sents)), n=1)
                i += 1
    avg_reward /= (i/100)  # i.e. 100 * mean reward, reported as a percentage
    print('finished in {}! avg reward: {:.2f}'.format(
        timedelta(seconds=int(time()-start)), avg_reward))
    return {'reward': avg_reward}
def a2c_validate(agent, abstractor, loader):
    agent.eval()
    start = time()
    print('start running validation...', end='')
    avg_reward = 0
    i = 0
    with torch.no_grad():
        for art_batch, topic_batch, abs_batch in loader:
            ext_sents = []
            ext_inds = []
            for raw_arts, topic in zip(art_batch, topic_batch):
                indices = agent(raw_arts, topic)
                ext_inds += [(len(ext_sents), len(indices) - 1)]
                ext_sents += [raw_arts[idx.item()]
                              for idx in indices
                              if idx.item() < len(raw_arts)]
            all_summs = abstractor(ext_sents)
            for (j, n), abs_sents in zip(ext_inds, abs_batch):
                summs = all_summs[j:j + n]
                # python ROUGE-1 (not official evaluation)
                avg_reward += compute_rouge_n(list(concat(summs)),
                                              list(concat(abs_sents)), n=1)
                i += 1
    avg_reward /= (i / 100)
    print('finished in {}! avg reward: {:.2f}'.format(
        timedelta(seconds=int(time() - start)), avg_reward))
    return {'reward': avg_reward}
def main(args):
    dec_dir = args.decoded_data_dir
    act_dir = args.actual_data_dir
    decoded_files = os.listdir(dec_dir)
    ref_files = os.listdir(act_dir)
    file_ids = [dec_file.split('.')[0] for dec_file in decoded_files]

    rouge_scores_list = []
    for ind, file_id in enumerate(file_ids):
        act_file_name = '.'.join([file_id, 'json'])
        with open(join(act_dir, act_file_name)) as f:
            js = json.loads(f.read())
        abstract = '\n'.join(js['abstract'])
        dec_file_name = decoded_files[ind]
        with open(join(dec_dir, dec_file_name)) as f:
            generation = f.read()
        rouge_score = compute_rouge_n(generation.split(), abstract.split())
        rouge_scores_list.append(rouge_score)

    df = pd.DataFrame()
    df['file_id'] = file_ids
    df['rouge_score'] = rouge_scores_list
    df.to_csv(join('./rouge_score_files', args.rouge_scores_file_name))
def a2c_validate(agent, abstractor, loader):
    agent.eval()
    start = time()
    print('start running validation...', end='')
    avg_reward = 0
    i = 0
    with torch.no_grad():
        for art_batch, abs_batch, extract in loader:
            greedy_inputs = []
            for idx, raw_arts in enumerate(art_batch):
                greedy, sample, log_probs = agent(raw_arts, sample_time=1,
                                                  validate=True)
                sample = sample[0]
                log_probs = log_probs[0]
                greedy_sents = [raw_arts[ind] for ind in greedy]
                greedy_sents = [word for sent in greedy_sents
                                for word in sent]
                #print(greedy_sents)
                #greedy_sents = list(concat(greedy_sents))
                # NOTE: the flat word list above is discarded; below, the
                # index 0 is treated as a sentence delimiter when regrouping
                greedy_sents = []
                ext_sent = []
                for ids in greedy:
                    if ids < len(raw_arts):
                        if ids == 0:
                            if ext_sent:
                                greedy_sents.append(ext_sent)
                                ext_sent = []
                        else:
                            ext_sent += raw_arts[ids]
                if greedy[-1] != 0 and ext_sent:
                    greedy_sents.append(ext_sent)
                #print(greedy_sents)
                #exit()
                greedy_inputs.append(greedy_sents)
            greedy_abstracts = []
            for abs_src in greedy_inputs:
                with torch.no_grad():
                    greedy_outputs = abstractor(abs_src)
                #greedy_abstract = []
                #for greedy_sents in greedy_outputs:
                #    greedy_sents = sent_tokenize(' '.join(greedy_sents))
                #    greedy_sents = [sent.strip().split(' ')
                #                    for sent in greedy_sents]
                #    greedy_abstract += greedy_sents
                greedy_abstract = list(concat(greedy_outputs))
                greedy_abstracts.append(greedy_abstract)
            for idx, greedy_sents in enumerate(greedy_abstracts):
                abss = abs_batch[idx]
                bs = compute_rouge_n(greedy_sents, list(concat(abss)))
                avg_reward += bs
                i += 1
            #print(i)
            #print(avg_reward)
            #exit()
    avg_reward /= (i / 100)
    print('finished in {}! avg reward: {:.2f}'.format(
        timedelta(seconds=int(time() - start)), avg_reward))
    return {'reward': avg_reward}
def a2c_validate(agent, abstractor, loader):
    agent.eval()
    start = time()
    print('start running validation...', end='')
    avg_reward = 0
    i = 0
    with torch.no_grad():
        for art_batch, abs_batch, sent_batch in loader:
            print(i)
            ext_sents = []
            ext_inds = []
            masks = []
            dirty = []
            for raw_arts, sent_labels in zip(art_batch, sent_batch):
                indices = agent(raw_arts, sent_labels)
                ext_inds += [(len(ext_sents), len(indices) - 1)]
                assert indices[-1][-1].item() == len(raw_arts) + 1
                tmp_stop = indices[-1][-1].item()
                tmp_truncate = tmp_stop - 1
                str_arts = list(map(lambda x: ' '.join(x), raw_arts))
                for idx in indices:
                    t, m = rl_edu_to_sentence(str_arts, idx)
                    if t == []:
                        assert len(idx) == 1
                        id = idx[0].item()
                        if id == tmp_truncate:
                            dirty.append(len(ext_sents))
                            # `label` / `label_mask` are assumed to be
                            # placeholders defined in the enclosing module
                            ext_sents.append(label)
                            masks.append(label_mask)
                    else:
                        if idx[-1].item() != tmp_stop:
                            ext_sents.append(t)
                            masks.append(m)
            all_summs = abstractor(ext_sents, masks)
            for d in dirty:
                all_summs[d] = []
            for (j, n), abs_sents in zip(ext_inds, abs_batch):
                summs = all_summs[j:j + n]
                # python ROUGE-1 (not official evaluation)
                avg_reward += compute_rouge_n(list(concat(summs)),
                                              list(concat(abs_sents)), n=1)
                i += 1
                if i % 100 == 1:
                    print(avg_reward / i, i)
                '''
                with open('./compare/rl/' + str(i - 1) + '.dec', 'w') as f:
                    for s in summs:
                        s = ' '.join(s)
                        f.write(s + '\n')
                '''
            #if i > 1000:
            #    break
    avg_reward /= (i / 100)
    print('finished in {}! avg reward: {:.2f}'.format(
        timedelta(seconds=int(time() - start)), avg_reward))
    return {'reward': avg_reward}
def a2c_validate(agent, abstractor, loader):
    agent.eval()
    start = time()
    print('start running validation...', end='')
    avg_reward = 0
    i = 0
    with torch.no_grad():
        for art_batch, abs_batch, ext_batch in loader:
            ext_sents = []
            ext_inds = []
            sent_acts = []
            for raw_arts in art_batch:
                (indices, _), actions = agent(raw_arts)
                ext_inds += [(len(ext_sents), len(indices) - 1)]
                ext_sents += [raw_arts[idx.item()]
                              for idx in indices
                              if idx.item() < len(raw_arts)]
                sent_acts += [actions[j]
                              for j, idx in enumerate(indices)
                              if idx.item() < len(raw_arts)]
            assert len(ext_sents) == len(sent_acts)
            all_summs = []
            # only sentences whose action is 0 are rewritten by the abstractor
            need_abs_sents = [ext_sents[iters]
                              for iters, act in enumerate(sent_acts)
                              if act == 0]
            if len(need_abs_sents) > 0:
                turn_abs_sents = abstractor(need_abs_sents)
            for nums, action in enumerate(sent_acts):
                if action == 0:
                    all_summs += turn_abs_sents.pop(0)
                else:
                    all_summs += ext_sents[nums]
            for (j, n), abs_sents in zip(ext_inds, abs_batch):
                summs = all_summs[j:j + n]
                # python ROUGE-1 (not official evaluation)
                avg_reward += compute_rouge_n(list(concat(summs)),
                                              list(concat(abs_sents)), n=1)
                i += 1
    avg_reward /= (i / 100)
    print('finished in {}! avg reward: {:.2f}'.format(
        timedelta(seconds=int(time() - start)), avg_reward))
    return {'reward': avg_reward}
def get_extract_label(art_sents, abs_sents):
    """ greedily match summary sentences to article sentences"""
    extracted = []
    scores = []
    new_art_sents, composed = my_compose(art_sents)
    indices = list(range(len(new_art_sents)))
    for abst in abs_sents:
        rouges = list(map(compute_rouge_n(reference=abst, mode='f'),
                          new_art_sents))
        ext = max(indices, key=lambda i: rouges[i])
        extracted.append(composed[ext])
        scores.append(rouges[ext])
    #print(extracted, scores)
    return extracted, scores
def a2c_validate(agent, abstractor, loader):
    agent.eval()
    start = time()
    print('start running validation...', end='')
    avg_reward = 0
    i = 0
    with torch.no_grad():
        for art_batch, abs_batch, _ in loader:
            #print(art_batch, abs_batch, _)
            ext_sents = []
            ext_inds = []
            new_indices = []
            for raw_arts in art_batch:
                indices = agent(raw_arts)
                extL = len(ext_sents)
                #ext_sents += [raw_arts[idx.item()]
                #              for idx in indices if idx.item() < len(raw_arts)]
                inds_ = []
                ext_sent = []
                for idx in indices:
                    if idx.item() < len(raw_arts):
                        if idx.item() == 0:
                            if ext_sent:
                                ext_sents.append(ext_sent)
                                ext_sent = []
                                inds_.append(1)
                        else:
                            ext_sent += raw_arts[idx.item()]
                if indices[-1].item() != 0 and ext_sent:
                    ext_sents.append(ext_sent)
                    inds_.append(1)
                new_indices.append(inds_)
                indxL = len(new_indices) - 1
                ext_inds += [(extL, indxL)]
            all_summs = abstractor(ext_sents)
            for (j, n), abs_sents in zip(ext_inds, abs_batch):
                summs = all_summs[j:j + n]
                # python ROUGE-1 (not official evaluation)
                avg_reward += compute_rouge_n(list(concat(summs)),
                                              list(concat(abs_sents)), n=1)
                i += 1
    avg_reward /= (i / 100)
    print('finished in {}! avg reward: {:.2f}'.format(
        timedelta(seconds=int(time() - start)), avg_reward))
    return {'reward': avg_reward}
def sc_validate(agent, abstractor, loader, entity=False, bert=False):
    agent.eval()
    start = time()
    print('start running validation...', end='')
    avg_reward = 0
    i = 0
    with torch.no_grad():
        for art_batch, abs_batch in loader:
            greedy_inputs = []
            for idx, raw_arts in enumerate(art_batch):
                greedy, sample, log_probs = agent(raw_arts, sample_time=1,
                                                  validate=True)
                if entity or bert:
                    raw_arts = raw_arts[0]
                # sample = sample[0]
                # log_probs = log_probs[0]
                greedy_sents = [raw_arts[ind] for ind in greedy]
                #greedy_sents = list(concat(greedy_sents))
                greedy_sents = [word for sent in greedy_sents
                                for word in sent]
                greedy_inputs.append(greedy_sents)
            with torch.no_grad():
                greedy_outputs = abstractor(greedy_inputs)
            greedy_abstracts = []
            for greedy_sents in greedy_outputs:
                greedy_sents = sent_tokenize(' '.join(greedy_sents))
                greedy_sents = [sent.strip().split(' ')
                                for sent in greedy_sents]
                greedy_abstracts.append(greedy_sents)
            for idx, greedy_sents in enumerate(greedy_abstracts):
                abss = abs_batch[idx]
                bs = compute_rouge_n(list(concat(greedy_sents)),
                                     list(concat(abss)))
                avg_reward += bs
                i += 1
    avg_reward /= (i / 100)
    print('finished in {}! avg reward: {:.2f}'.format(
        timedelta(seconds=int(time() - start)), avg_reward))
    return {'reward': avg_reward}
def a2c_validate(agent, abstractor, loader):
    agent.eval()
    start = time()
    print('start running validation...', end='')
    avg_reward = 0
    i = 0
    # IF using BERTScore:
    # reward_fn = compute_bertscore_wo_baseline_rescaling
    # reward_fn.metric = datasets.load_metric('bertscore')
    with torch.no_grad():
        for art_batch, abs_batch in loader:
            ext_sents = []
            ext_inds = []
            for raw_arts in art_batch:
                indices = agent(raw_arts)
                ext_inds += [(len(ext_sents), len(indices) - 1)]
                ext_sents += [raw_arts[idx.item()]
                              for idx in indices
                              if idx.item() < len(raw_arts)]
            all_summs = abstractor(ext_sents)
            for (j, n), abs_sents in zip(ext_inds, abs_batch):
                summs = all_summs[j:j + n]
                # python ROUGE-1 (not official evaluation)
                avg_reward += compute_rouge_n(list(concat(summs)),
                                              list(concat(abs_sents)), n=1)
                # IF using BERTScore:
                # avg_reward += reward_fn(' '.join(concat(summs)),
                #                         ' '.join(concat(abs_sents)))
                i += 1
    avg_reward /= (i / 100)
    # IF using BERTScore:
    # del reward_fn
    print('finished in {}! avg reward: {:.2f}'.format(
        timedelta(seconds=int(time() - start)), avg_reward))
    return {'reward': avg_reward}
def coll(tokenizer, batch):
    blank = '[unused0]'
    questions, context, _ids, abstract = list(
        filter(bool, list(zip(*batch))))
    system = context
    # print('q:', questions)
    # print('c:', context)
    # NOTE: guard added so the return below cannot raise a NameError when
    # the batch carries no abstract
    rouges = None
    if len(abstract) > 0:
        rouges = {
            'RLr': compute_rouge_l_summ(
                [_c.lower().split(' ') for _c in context[0]],
                abstract[0], mode='r'),
            'RLf': compute_rouge_l_summ(
                [_c.lower().split(' ') for _c in context[0]],
                abstract[0], mode='f'),
            'R1': compute_rouge_n(' '.join(context[0]).split(' '),
                                  list(concat(abstract[0])), mode='f', n=1),
            'R2': compute_rouge_n(' '.join(context[0]).split(' '),
                                  list(concat(abstract[0])), mode='f', n=2)
        }
    if len(questions[0]) != 0:
        questions = questions[0]
        context = context[0]
        context = ' '.join(context)
        choicess = [[question['answer'], question['choice1'],
                     question['choice2'], question['choice3']]
                    for question in questions]
        questions = [question['question'].replace('<\\blank>', blank)
                     for question in questions]
        questions = [[tokenizer.tokenize(qp.lower())
                      for qp in question.split(blank)]
                     for question in questions]
        new_questions = []
        for question in questions:
            new_q = ['[CLS]']
            for q in question:
                new_q += q + [blank]
            new_q.pop()
            new_questions.append(new_q)
        questions = new_questions
        contexts = [['[SEP]'] + tokenizer.tokenize(context.lower())
                    for _ in range(len(questions))]
        choicess = [[[tokenizer.tokenize(c.lower()) for c in choice]
                     for choice in choices] for choices in choicess]
        choicess = [[['[SEP]'] + choice[0] + ['[SEP]'] + choice[1]
                     if len(choice) == 2 else ['[SEP]'] + choice[0]
                     for choice in choices] for choices in choicess]
        _inputs = [[tokenizer.convert_tokens_to_ids(
                        (question + context + choice)[:MAX_LEN])
                    for choice in choices]
                   for question, context, choices in zip(
                       questions, contexts, choicess)]
        _inputs = pad_batch_tensorize_3d(_inputs, pad=0, cuda=False)
    else:
        _inputs = []
    return (_inputs, rouges, system, abstract)
def a2c_train_step(agent, target_agent, abstractor, loader, opt, grad_fn,
                   gamma=0.99, reward_fn=compute_rouge_l,
                   stop_reward_fn=compute_rouge_n(n=1), stop_coeff=1.0):
    opt.zero_grad()

    def length_penalty(l, eps=0.7):
        return (5 + len(l) - 1)**eps / (5 + 1)**eps

    indices, probs, baselines = [], [], []
    target_indices, target_probs, target_baselines = [], [], []
    act, act_probs, act_baselines = [], [], []
    pure_act = []
    ext_sents = []  # all extracted sentences
    abstract = []
    pres_or_rewr = []
    act_precision = []
    target_probs = []
    true_indices, false_indices = [], []
    recall, precision = [], []
    art_batch, abs_batch, ext_indices = next(loader)
    for n, (raw_arts, raw_abs, ext_index) in enumerate(
            zip(art_batch, abs_batch, ext_indices)):
        with torch.no_grad():
            (target_inds, target_ms), target_bs, \
                (target_act_inds, target_act_ms), target_act_bs = \
                target_agent(raw_arts)
        (inds, ms), bs, (act_inds, act_ms), act_bs = agent(
            raw_arts, observations=target_inds)
        baselines.append(bs)
        indices.append(inds)
        probs.append(ms)
        target_probs.append(target_ms)
        act.append(act_inds)
        act_probs.append(act_ms)
        act_baselines.append(act_bs)
        pure_act += [a.detach() for a in act_inds[:-1]]

        def precision_recall():
            """ sentence-level precision/recall """
            true_positive_extracts = []
            false_extracts = []
            # acc_pre = []
            for k, index in enumerate(inds[:-1]):
                if index.item() in ext_index:
                    true_positive_extracts.append(index.detach().item())
                else:
                    false_extracts.append(index.detach().item())
            each_recall = len(true_positive_extracts) / len(ext_index)
            each_precision = len(true_positive_extracts) / (
                len(inds) - 1) if len(inds) > 1 else 0
            recall.append(each_recall)
            precision.append(each_precision)
            true_indices.append(true_positive_extracts)
            false_indices.append(false_extracts)

        precision_recall()
        # store the original (pre-rewrite) extracted sentences
        current_sentence = [raw_arts[idx.item()]
                            for idx in inds if idx.item() < len(raw_arts)]
        ext_sents += current_sentence
    assert len(indices) == len(act)
    assert len(probs) == len(act_probs)
    assert len(baselines) == len(act_baselines)
    # run the rewriting (abstractor) model
    with torch.no_grad():
        rewrite_sentences = abstractor(ext_sents)
    # after the abstractor agent acts, reassemble the sentences:
    # action 0 keeps the rewritten sentence, otherwise keep the original
    results = [rewrite_sentences[n] if bins.item() == 0 else ext_sents[n]
               for n, bins in enumerate(pure_act)]
    # reward that the originally selected sentences would receive
    i = 0
    rewards = []
    avg_reward = 0
    avg_reward_abs_rew = 0
    avg_reward_abs_pre = 0
    rewrite_rewards, preserve_rewards = [], []
    for inds, act_inds, abss, labeled in zip(indices, act, abs_batch,
                                             ext_indices):
        # total score over the raw sentences picked by the extractor agent
        rs = ([reward_fn(ext_sents[i + j], abss[j])
               for j in range(min(len(inds) - 1, len(abss)))]
              + [0 for _ in range(max(0, len(inds) - 1 - len(abss)))]
              + [stop_coeff * stop_reward_fn(
                  list(concat(ext_sents[i:i + len(inds) - 1])),
                  list(concat(abss)))])
        # rs = ([reward_fn(sum(ext_sents[i:i+j+1],[]), sum(abss[:j+1],[]))
        #        for j in range(min(len(inds)-1, len(abss)))]
        #       + [0 for _ in range(max(0, len(inds)-1-len(abss)))]
        #       + [stop_coeff*stop_reward_fn(
        #           list(concat(ext_sents[i:i+len(inds)-1])),
        #           list(concat(abss)))])

        # per-sentence (single-sentence) comparison
        def single_compared_MC_TD(choice):
            # before_rewrite = ([stop_reward_fn(ext_sents[i+j], abss[j])
            #                    for j in range(min(len(inds)-1, len(abss)))]
            #                   + [0 for _ in range(max(0, len(inds)-1-len(abss)))]
            #                   + [0])
            # after_rewrite = ([stop_reward_fn(rewrite_sentences[i+j], abss[j])
            #                   for j in range(min(len(inds)-1, len(abss)))]
            #                  + [1 for _ in range(max(0, len(inds)-1-len(abss)))]
            #                  + [0])
            remain_sents = min(len(inds) - 1, len(abss))
            if choice == 'TD':
                previous_results = [[]] + results[i:i + remain_sents]
                # score obtainable after the abstractor agent REWRITES
                rew_rs = ([reward_fn(results[i + j], abss[j]) -
                           reward_fn(previous_results[j], abss[j])
                           if act_inds[j].item() == 0 else 0
                           for j in range(min(len(inds) - 1, len(abss)))]
                          + [0 for _ in range(
                              max(0, len(inds) - 1 - len(abss)))])
                # + [reward_fn(sum(results[i:i+remain_sents+j+1],[]), sum(abss[:-1],[]))
                #    - reward_fn(sum(results[i:i+remain_sents+j],[]), sum(abss[:-1],[]))
                #    if act_inds[j].item()==0 else 0
                #    for j in range(max(0, len(inds)-1-len(abss)))]
                # + [stop_coeff*stop_reward_fn(
                #     list(concat([results[x] for x in range(i, i+len(inds)-1)
                #                  if act_inds[x-i].item()==0])),
                #     list(concat(abss)))]
                rew_rs += [sum(rew_rs)]
                # score obtainable after the abstractor agent PRESERVES
                pres_rs = ([reward_fn(results[i + j], abss[j]) -
                            reward_fn(previous_results[j], abss[j])
                            if act_inds[j].item() == 1 else 0
                            for j in range(min(len(inds) - 1, len(abss)))]
                           + [0 for _ in range(
                               max(0, len(inds) - 1 - len(abss)))])
                # (analogous commented-out cumulative / stop-reward variants
                #  as in the rewrite branch, with act_inds[j].item()==1)
                pres_rs += [sum(pres_rs)]
            elif choice == 'MC':
                # score obtainable after the abstractor agent REWRITES
                rew_rs = ([reward_fn(results[i + j], abss[j])
                           if act_inds[j].item() == 0 else 0
                           for j in range(min(len(inds) - 1, len(abss)))]
                          + [0 for _ in range(
                              max(0, len(inds) - 1 - len(abss)))])
                # + [reward_fn(results[i+remain_sents+j], sum(abss[:-1],[]))
                #    if act_inds[j].item()==0 else 0
                #    for j in range(max(0, len(inds)-1-len(abss)))]
                rew_rs += [sum(rew_rs)]
                # score obtainable after the abstractor agent PRESERVES
                pres_rs = ([reward_fn(results[i + j], abss[j])
                            if act_inds[j].item() == 1 else 0
                            for j in range(min(len(inds) - 1, len(abss)))]
                           + [0 for _ in range(
                               max(0, len(inds) - 1 - len(abss)))])
                # + [reward_fn(results[i+remain_sents+j], sum(abss[:-1],[]))
                #    if act_inds[j].item()==1 else 0
                #    for j in range(max(0, len(inds)-1-len(abss)))]
                pres_rs += [sum(pres_rs)]
            return rew_rs, pres_rs

        rew_rs, pres_rs = single_compared_MC_TD(choice='TD')

        # accumulated (prefix) comparison
        def accumulated_compared_MC_TD(choice):
            remain_sents = min(len(inds) - 1, len(abss))
            if choice == 'MC':
                # score obtainable after the abstractor agent REWRITES
                rew_rs = ([reward_fn(results[i + j], sum(abss[:j + 1], []))
                           if act_inds[j].item() == 0 else 0
                           for j in range(min(len(inds) - 1, len(abss)))]
                          + [reward_fn(results[i + remain_sents + j],
                                       sum(abss[:-1], []))
                             if act_inds[j].item() == 0 else 0
                             for j in range(
                                 max(0, len(inds) - 1 - len(abss)))])
                # + [0 for _ in range(max(0, len(inds)-1-len(abss)))])
                rew_rs += [sum(rew_rs)]
                # score obtainable after the abstractor agent PRESERVES
                pres_rs = ([reward_fn(results[i + j], sum(abss[:j + 1], []))
                            if act_inds[j].item() == 1 else 0
                            for j in range(min(len(inds) - 1, len(abss)))]
                           + [reward_fn(results[i + remain_sents + j],
                                        sum(abss[:-1], []))
                              if act_inds[j].item() == 1 else 0
                              for j in range(
                                  max(0, len(inds) - 1 - len(abss)))])
                # + [0 for _ in range(max(0, len(inds)-1-len(abss)))])
                pres_rs += [sum(pres_rs)]
            elif choice == 'TD':
                # score obtainable after the abstractor agent REWRITES
                rew_rs = ([reward_fn(sum(results[i:i + j + 1], []),
                                     sum(abss[:j + 1], [])) -
                           reward_fn(sum(results[i:i + j], []),
                                     sum(abss[:j], []))
                           if act_inds[j].item() == 0 else 0
                           for j in range(min(len(inds) - 1, len(abss)))]
                          + [reward_fn(
                                 sum(results[i:i + remain_sents + j + 1], []),
                                 sum(abss[:-1], [])) -
                             reward_fn(
                                 sum(results[i:i + remain_sents + j], []),
                                 sum(abss[:-1], []))
                             if act_inds[j].item() == 0 else 0
                             for j in range(
                                 max(0, len(inds) - 1 - len(abss)))])
                # + [0 for _ in range(max(0, len(inds)-1-len(abss)))])
                rew_rs += [sum(rew_rs)]
                # score obtainable after the abstractor agent PRESERVES
                pres_rs = ([reward_fn(sum(results[i:i + j + 1], []),
                                      sum(abss[:j + 1], [])) -
                            reward_fn(sum(results[i:i + j], []),
                                      sum(abss[:j], []))
                            if act_inds[j].item() == 1 else 0
                            for j in range(min(len(inds) - 1, len(abss)))]
                           + [reward_fn(
                                  sum(results[i:i + remain_sents + j + 1], []),
                                  sum(abss[:-1], [])) -
                              reward_fn(
                                  sum(results[i:i + remain_sents + j], []),
                                  sum(abss[:-1], []))
                              if act_inds[j].item() == 1 else 0
                              for j in range(
                                  max(0, len(inds) - 1 - len(abss)))])
                pres_rs += [sum(pres_rs)]
            return rew_rs, pres_rs

        # rew_rs, pres_rs = accumulated_compared_MC_TD(choice='MC')
        assert len(rs) == len(inds)
        avg_reward += rs[-1] / stop_coeff
        avg_reward_abs_rew += rew_rs[-1] / stop_coeff
        avg_reward_abs_pre += pres_rs[-1] / stop_coeff
        i += len(inds) - 1
        # compute discounted rewards
        R = 0
        disc_rs = []
        for r in rs[::-1]:
            R = r + gamma * R
            disc_rs.insert(0, R)
        rewards += disc_rs
        rewrite_rewards += rew_rs
        preserve_rewards += pres_rs
    indices = list(concat(indices))
    probs = list(concat(probs))
    baselines = list(concat(baselines))
    target_probs = list(concat(target_probs))
    act = list(concat(act))
    act_probs = list(concat(act_probs))
    act_baselines = list(concat(act_baselines))
    # act_precision = list(concat(act_precision))
    # act_precision = torch.Tensor(act_precision).to(act_baselines[0].device)
    # latent-action (rewrite / preserve) reward computation
    rewrite_rewards = torch.Tensor(rewrite_rewards).to(
        act_baselines[0].device)
    rewrite_rewards = (rewrite_rewards - rewrite_rewards.mean()) / (
        rewrite_rewards.std() + float(np.finfo(np.float32).eps))
    preserve_rewards = torch.Tensor(preserve_rewards).to(
        act_baselines[0].device)
    preserve_rewards = (preserve_rewards - preserve_rewards.mean()) / (
        preserve_rewards.std() + float(np.finfo(np.float32).eps))
    complex_rewards = torch.stack([rewrite_rewards, preserve_rewards], dim=-1)
    # standardize rewards
    reward = torch.Tensor(rewards).to(baselines[0].device)
    reward = (reward - reward.mean()) / (
        reward.std() + float(np.finfo(np.float32).eps))
    baseline = torch.cat(baselines).squeeze()
    act_baselines = torch.cat(act_baselines).squeeze().view(-1, 2)
    assert len(indices) == len(probs) == len(reward) == len(baseline)
    assert len(act) == len(act_probs) == len(act_baselines)
    avg_advantage = 0
    losses = []
    entropy_target = []
    entropy_value = []
    ratios = []
    for action, p, tp, r, b in zip(indices, probs, target_probs, reward,
                                   baseline):
        # ratio = torch.exp(p.log_prob(action) - tp.log_prob(action))
        # ratios.append(ratio)
        entropy_target.append(tp.entropy().detach().item())
        entropy_value.append(p.entropy().detach().item())
        advantage = r - b
        avg_advantage += advantage
        current_avg_advantage = advantage / len(indices)
        losses.append(-p.log_prob(action) * current_avg_advantage)
    abs_losses_rew = []
    abs_losses_pre = []
    for action, p, r, b in zip(act, act_probs, complex_rewards,
                               act_baselines):
        advantage = r - b
        current_avg_advantage = advantage / len(act)
        # rewrite branch
        abs_losses_rew.append(-p.log_prob(action)
                              * current_avg_advantage[0])  # divide by T*B
        # preserve branch
        abs_losses_pre.append(-p.log_prob(action)
                              * current_avg_advantage[1])  # divide by T*B
    critic_loss = F.mse_loss(baseline, reward)
    critic_loss_re = F.mse_loss(act_baselines, complex_rewards)
    # backprop and update
    autograd.backward(
        tensors=[critic_loss] + losses,
        grad_tensors=[torch.ones(1).to(critic_loss.device)]
        * (1 + len(losses)),
        retain_graph=True)
    autograd.backward(
        tensors=[critic_loss_re] + abs_losses_rew + abs_losses_pre,
        grad_tensors=[torch.ones(1).to(critic_loss.device)]
        * (1 + len(abs_losses_rew) + len(abs_losses_pre)))
    opt.step()
    target_agent.load_state_dict(agent.state_dict())
    # clear cache
    torch.cuda.empty_cache()
    grad_log = grad_fn()
    log_dict = {}
    log_dict.update(grad_log)
    log_dict['entropy'] = sum(entropy_value) / len(entropy_value)
    log_dict['entropy_target'] = sum(entropy_target) / len(entropy_target)
    # log_dict['ratios'] = sum(ratios)/len(ratios)
    log_dict['reward'] = avg_reward / len(art_batch)
    log_dict['abs_reward_rew'] = avg_reward_abs_rew / len(art_batch)
    log_dict['abs_reward_pre'] = avg_reward_abs_pre / len(art_batch)
    log_dict['advantage'] = avg_advantage.item() / len(indices)
    log_dict['act_mse'] = critic_loss_re.item()
    log_dict['recall'] = sum(recall) / len(recall)
    log_dict['precision'] = sum(precision) / len(precision)
    assert not math.isnan(log_dict['grad_norm'])
    return log_dict
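# A minimal, self-contained sketch (not part of the original code) contrasting
# the two credit-assignment schemes above with a toy stand-in reward: the MC
# variant scores each kept sentence directly against its reference sentence,
# while the prefix-based TD variant (as in accumulated_compared_MC_TD) scores
# the marginal gain of adding each sentence to the running prefix.
def _demo_mc_vs_td():
    def toy_reward(pred, ref):
        # hypothetical stand-in for reward_fn: fraction of ref tokens covered
        return len(set(pred) & set(ref)) / max(len(set(ref)), 1)

    results = [['a', 'b'], ['c'], ['d']]   # hypothetical rewritten sentences
    refs = [['a', 'b'], ['c'], ['z']]      # hypothetical abstract sentences
    mc = [toy_reward(r, a) for r, a in zip(results, refs)]
    td = []
    prefix_prev = []
    for r, a in zip(results, refs):
        prefix = prefix_prev + r
        td.append(toy_reward(prefix, a) - toy_reward(prefix_prev, a))
        prefix_prev = prefix
    return mc, td  # MC: absolute scores; TD: incremental gains per step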
def a2c_train_step(agent, abstractor, loader, opt, grad_fn,
                   gamma=0.99, reward_fn=compute_rouge_l,
                   stop_reward_fn=compute_rouge_n(n=1), stop_coeff=1.0):
    opt.zero_grad()
    indices = []
    probs = []
    baselines = []
    ext_sents = []
    art_batch, abs_batch = next(loader)
    for raw_arts in art_batch:
        (inds, ms), bs = agent(raw_arts)
        baselines.append(bs)
        indices.append(inds)
        probs.append(ms)
        ext_sents += [raw_arts[idx.item()]
                      for idx in inds if idx.item() < len(raw_arts)]
    with torch.no_grad():
        summaries = abstractor(ext_sents)
    i = 0
    rewards = []
    avg_reward = 0
    for inds, abss in zip(indices, abs_batch):
        rs = ([reward_fn(summaries[i+j], abss[j])
               for j in range(min(len(inds)-1, len(abss)))]
              + [0 for _ in range(max(0, len(inds)-1-len(abss)))]
              + [stop_coeff*stop_reward_fn(
                  list(concat(summaries[i:i+len(inds)-1])),
                  list(concat(abss)))])
        assert len(rs) == len(inds)
        avg_reward += rs[-1]/stop_coeff
        i += len(inds)-1
        # compute discounted rewards
        R = 0
        disc_rs = []
        for r in rs[::-1]:
            R = r + gamma * R
            disc_rs.insert(0, R)
        rewards += disc_rs
    indices = list(concat(indices))
    probs = list(concat(probs))
    baselines = list(concat(baselines))
    # standardize rewards
    reward = torch.Tensor(rewards).to(baselines[0].get_device())
    reward = (reward - reward.mean()) / (
        reward.std() + float(np.finfo(np.float32).eps))
    baseline = torch.cat(baselines).squeeze()
    avg_advantage = 0
    losses = []
    for action, p, r, b in zip(indices, probs, reward, baseline):
        advantage = r - b
        avg_advantage += advantage
        losses.append(-p.log_prob(action)
                      * (advantage/len(indices)))  # divide by T*B
    critic_loss = F.mse_loss(baseline, reward)
    # backprop and update
    autograd.backward(
        [critic_loss] + losses,
        [torch.ones(1).to(critic_loss.get_device())]*(1+len(losses))
    )
    grad_log = grad_fn()
    opt.step()
    log_dict = {}
    log_dict.update(grad_log)
    log_dict['reward'] = avg_reward/len(art_batch)
    log_dict['advantage'] = avg_advantage.item()/len(indices)
    log_dict['mse'] = critic_loss.item()
    assert not math.isnan(log_dict['grad_norm'])
    return log_dict
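# Self-contained sketch of the discounted-return loop shared by the train
# steps in this file: the per-step rewards are scanned right-to-left so each
# position accumulates R_t = r_t + gamma * R_{t+1}.  The reward values below
# are hypothetical.
def _demo_discounted_returns():
    gamma = 0.99
    rs = [0.2, 0.0, 0.5]  # hypothetical per-step rewards
    R, disc_rs = 0, []
    for r in reversed(rs):
        R = r + gamma * R
        disc_rs.insert(0, R)
    # disc_rs[0] == 0.2 + 0.99 * (0.0 + 0.99 * 0.5)
    assert abs(disc_rs[0] - (0.2 + 0.99 * (0.0 + 0.99 * 0.5))) < 1e-9
    return disc_rs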
def a2c_train_step(agent, abstractor, loader, opt, grad_fn,
                   gamma=0.99, reward_fn=compute_rouge_l,
                   stop_reward_fn=compute_rouge_n(n=1), stop_coeff=1.0):
    opt.zero_grad()
    indices = []
    probs = []
    baselines = []
    ext_sents = []
    masks = []
    art_batch, abs_batch, sent_label_batch = next(loader)
    leng = []
    avg_leng = []
    dirty = []
    time1 = time()
    for raw_arts, sent_labels in zip(art_batch, sent_label_batch):
        (inds, ms), bs = agent(raw_arts, sent_labels)
        assert inds[-1][-1].item() == len(raw_arts) + 1
        baselines.append(bs)
        indices.append(inds)
        probs.append(ms)
        try:
            avg_leng.append(
                sum(list(map(lambda x: len(x) - 1, inds[:-1])))
                / (len(inds) - 1))
        except:  # guards against division by zero when nothing is extracted
            pass
        leng.append(len(inds) - 1)
        tmp_stop = inds[-1][-1].item()
        tmp_truncate = tmp_stop - 1
        str_arts = list(map(lambda x: ' '.join(x), raw_arts))
        for idx in inds:
            t, m = rl_edu_to_sentence(str_arts, idx)
            assert len(t) == len(m)
            if t == []:
                assert len(idx) == 1
                id = idx[0].item()
                if id == tmp_truncate:
                    dirty.append(len(ext_sents))
                    # `label` / `label_mask` are assumed to be placeholders
                    # defined in the enclosing module
                    ext_sents.append(label)
                    masks.append(label_mask)
            else:
                if idx[-1].item() != tmp_stop:
                    ext_sents.append(t)
                    masks.append(m)
    print('lengths:', leng)
    leng = list(map(lambda x: x[0] - len(x[1]), zip(leng, abs_batch)))
    avg_dis = sum(leng) / len(leng)
    print('average gap:', avg_dis)
    print('average EDU count:', sum(avg_leng) / len(avg_leng))
    time2 = time()
    with torch.no_grad():
        summaries = abstractor(ext_sents, masks)
    for d in dirty:
        summaries[d] = []
    time3 = time()
    i = 0
    rewards = []
    avg_reward = 0
    for inds, abss in zip(indices, abs_batch):
        rs_abs = []
        rs_len = []
        abs_num = min(len(inds) - 1, len(abss))
        for j in range(abs_num):
            rs_abs.append(reward_fn(summaries[i + j], abss[j]))
            rs_len.append(len(inds[j]))
        rs_zero = []
        for j in range(max(0, len(inds) - 1 - len(abss))):
            rs_zero += [0] * len(inds[j + abs_num])
        rs_zero += [0] * (len(inds[-1]) - 1)
        rs_final = stop_coeff * stop_reward_fn(
            list(concat(summaries[i:i + len(inds) - 1])),
            list(concat(abss)))
        avg_reward += rs_final / stop_coeff
        i += len(inds) - 1
        disc_rs = [rs_final]
        R = rs_final
        for _ in rs_zero:
            R = R * gamma
            disc_rs.append(R)
        for r, leng in zip(rs_abs, rs_len):
            R = r + R * gamma
            disc_rs += [R] * leng
        disc_rs = list(reversed(disc_rs))
        assert len(disc_rs) == sum(list(map(lambda x: len(x), inds)))
        rewards += disc_rs
    indices = list(concat(list(concat(indices))))
    probs = list(concat(probs))
    baselines = list(concat(baselines))
    # standardize rewards
    reward = torch.Tensor(rewards).to(baselines[0].device)
    assert (len(reward) == len(probs) and len(baselines) == len(probs)
            and len(baselines) == len(indices))
    reward = (reward - reward.mean()) / (
        reward.std() + float(np.finfo(np.float32).eps))
    baseline = torch.cat(baselines).squeeze()
    avg_advantage = 0
    losses = []
    for action, p, r, b in zip(indices, probs, reward, baseline):
        advantage = r - b
        avg_advantage += advantage
        losses.append(-(p.log_prob(action))
                      * (advantage / len(indices)))  # divide by T*B
    critic_loss = F.mse_loss(baseline, reward)
    time4 = time()
    # backprop and update
    autograd.backward([critic_loss] + losses,
                      [torch.ones(1).to(critic_loss.device)]
                      * (1 + len(losses)))
    grad_log = grad_fn()
    opt.step()
    log_dict = {}
    log_dict.update(grad_log)
    log_dict['reward'] = avg_reward / len(art_batch)
    #print(avg_reward)
    log_dict['advantage'] = avg_advantage.item() / len(indices)
    log_dict['mse'] = critic_loss.item()
    assert not math.isnan(log_dict['grad_norm'])
    time5 = time()
    print(time2 - time1, time3 - time2, time4 - time3, time5 - time4)
    return log_dict
def decode(save_path, model_dir, split, batch_size,
           beam_size, diverse, max_len, cuda):
    start = time()
    # setup model
    with open(join(model_dir, 'meta.json')) as f:
        meta = json.loads(f.read())
    if meta['net_args']['abstractor'] is None:
        # NOTE: if no abstractor is provided then
        # the whole model would be extractive summarization
        assert beam_size == 1
        abstractor = identity
    else:
        if beam_size == 1:
            abstractor = Abstractor(join(model_dir, 'abstractor'),
                                    max_len, cuda)
        else:
            abstractor = BeamAbstractor(join(model_dir, 'abstractor'),
                                        max_len, cuda)
    extractor = RLExtractor(model_dir, cuda=cuda)

    # setup loader
    def coll(batch):
        articles, abstract, extracted = unzip(batch)
        articles = list(filter(bool, articles))
        abstract = list(filter(bool, abstract))
        extracted = list(filter(bool, extracted))
        return articles, abstract, extracted
    dataset = DecodeDataset(split)

    n_data = len(dataset[0])  # article sentences
    loader = DataLoader(
        dataset, batch_size=batch_size, shuffle=False, num_workers=4,
        collate_fn=coll
    )

    # prepare save paths and logs
    if not os.path.exists(join(save_path, 'output')):
        os.makedirs(join(save_path, 'output'))
    dec_log = {}
    dec_log['abstractor'] = meta['net_args']['abstractor']
    dec_log['extractor'] = meta['net_args']['extractor']
    dec_log['rl'] = True
    dec_log['split'] = split
    dec_log['beam'] = beam_size
    dec_log['diverse'] = diverse
    with open(join(save_path, 'log.json'), 'w') as f:
        json.dump(dec_log, f, indent=4)
    file_path = os.path.join(save_path, 'Attention')
    act_path = os.path.join(save_path, 'Actions')
    header = "index,rouge_score1,rouge_score2," \
        "rouge_scorel,dec_sent_nums,abs_sent_nums,doc_sent_nums,doc_words_nums," \
        "ext_words_nums, abs_words_nums, diff," \
        "recall, precision, less_rewrite, preserve_action, rewrite_action, each_actions," \
        "top3AsAns, top3AsGold, any_top2AsAns, any_top2AsGold,true_rewrite,true_preserve\n"
    if not os.path.exists(file_path):
        print('create dir:{}'.format(file_path))
        os.makedirs(file_path)
    if not os.path.exists(act_path):
        print('create dir:{}'.format(act_path))
        os.makedirs(act_path)
    with open(join(save_path, '_statisticsDecode.log.csv'), 'w') as w:
        w.write(header)

    # decoding
    i = 0
    with torch.no_grad():
        for i_debug, (raw_article_batch, raw_abstract_batch,
                      extracted_batch) in enumerate(loader):
            tokenized_article_batch = map(tokenize(None), raw_article_batch)
            tokenized_abstract_batch = map(tokenize(None), raw_abstract_batch)
            token_nums_batch = list(map(token_nums(None), raw_article_batch))
            ext_nums = []
            ext_arts = []
            ext_inds = []
            rewrite_less_rouge = []
            dec_outs_act = []
            ext_acts = []
            abs_collections = []
            ext_collections = []
            # extract sentences
            for ind, (raw_art_sents, abs_sents) in enumerate(
                    zip(tokenized_article_batch, tokenized_abstract_batch)):
                (ext, (state, act_dists)), act = extractor(raw_art_sents)  # exclude EOE
                extracted_state = state[extracted_batch[ind]]
                attn = torch.softmax(
                    state.mm(extracted_state.transpose(1, 0)), dim=-1)
                # (_, abs_state), _ = extractor(abs_sents)  # exclude EOE

                def plot_actDist(actons, nums):
                    print('index: {} distribution ...'.format(nums))
                    # write the MDP state attention weight matrix
                    file_name = os.path.join(
                        act_path, '{}.attention.pdf'.format(nums))
                    pdf_pages = PdfPages(file_name)
                    plot_attention(actons.cpu().numpy(),
                                   name='{}-th article'.format(nums),
                                   X_label=list(range(len(raw_art_sents))),
                                   Y_label=list(range(len(ext))),
                                   dirpath=save_path,
                                   pdf_page=pdf_pages, action=True)
                    pdf_pages.close()
                # plot_actDist(torch.stack(act_dists, dim=0), nums=ind+i)

                def plot_attn():
                    print('index: {} write_attention_pdf ...'.format(i + ind))
                    # write the MDP state attention weight matrix
                    file_name = os.path.join(
                        file_path, '{}.attention.pdf'.format(i + ind))
                    pdf_pages = PdfPages(file_name)
                    plot_attention(attn.cpu().numpy(),
                                   name='{}-th article'.format(i + ind),
                                   X_label=extracted_batch[ind],
                                   Y_label=list(range(len(raw_art_sents))),
                                   dirpath=save_path, pdf_page=pdf_pages)
                    pdf_pages.close()
                # plot_attn()

                ext = ext[:-1]
                act = act[:-1]
                if not ext:
                    # use top-5 if nothing is extracted
                    # in some rare cases rnn-ext does not extract at all
                    ext = list(range(5))[:len(raw_art_sents)]
                    act = list([1]*5)[:len(raw_art_sents)]
                else:
                    ext = [i.item() for i in ext]
                    act = [i.item() for i in act]
                ext_nums.append(ext)
                ext_inds += [(len(ext_arts), len(ext))]  # [(0,5),(5,7),(7,3),...]
                ext_arts += [raw_art_sents[k] for k in ext]
                ext_acts += [k for k in act]
                # accumulate prefixes of the extracted / abstract sentences
                ext_collections += [
                    sum(ext_arts[ext_inds[-1][0]:ext_inds[-1][0]+k+1], [])
                    for k in range(ext_inds[-1][1])]
                abs_collections += [
                    sum(abs_sents[:k+1], []) if k < len(abs_sents)
                    else sum(abs_sents[0:len(abs_sents)], [])
                    for k in range(ext_inds[-1][1])]
            if beam_size > 1:
                # do n-best abstracting
                all_beams = abstractor(ext_arts, beam_size, diverse)
                dec_outs = rerank_mp(all_beams, ext_inds)
                dec_collections = [[sum(dec_outs[pos[0]:pos[0]+k+1], [])
                                    for k in range(pos[1])]
                                   for pos in ext_inds]
                dec_collections = [x for sublist in dec_collections
                                   for x in sublist]
                for index, chooser in enumerate(ext_acts):
                    if chooser == 0:
                        dec_outs_act += [dec_outs[index]]
                    else:
                        dec_outs_act += [ext_arts[index]]
                assert len(ext_collections) == len(dec_collections) \
                    == len(abs_collections)
                for ext, dec, abss, act in zip(ext_collections,
                                               dec_collections,
                                               abs_collections, ext_acts):
                    # for each sentence prefix in the extracted digest,
                    # compare against the full abstract mapping
                    rouge_before_rewriten = compute_rouge_n(ext, abss, n=1)
                    rouge_after_rewriten = compute_rouge_n(dec, abss, n=1)
                    diff_ins = rouge_before_rewriten - rouge_after_rewriten
                    rewrite_less_rouge.append(diff_ins)
            else:
                # do greedy (1-best) abstracting
                dec_outs = abstractor(ext_arts)
                dec_collections = [[sum(dec_outs[pos[0]:pos[0]+k+1], [])
                                    for k in range(pos[1])]
                                   for pos in ext_inds]
                dec_collections = [x for sublist in dec_collections
                                   for x in sublist]
                for index, chooser in enumerate(ext_acts):
                    if chooser == 0:
                        dec_outs_act += [dec_outs[index]]
                    else:
                        dec_outs_act += [ext_arts[index]]
                # dec_outs_act = dec_outs
                # dec_outs_act = ext_arts
                assert len(ext_collections) == len(dec_collections) \
                    == len(abs_collections)
                for ext, dec, abss, act in zip(ext_collections,
                                               dec_collections,
                                               abs_collections, ext_acts):
                    # for each sentence prefix in the extracted digest,
                    # compare against the full abstract mapping
                    rouge_before_rewriten = compute_rouge_n(ext, abss, n=1)
                    rouge_after_rewriten = compute_rouge_n(dec, abss, n=1)
                    diff_ins = rouge_before_rewriten - rouge_after_rewriten
                    rewrite_less_rouge.append(diff_ins)

            assert i == batch_size*i_debug
            for iters, (j, n) in enumerate(ext_inds):
                do_right_rewrite = sum(
                    [1 for rouge, action in zip(rewrite_less_rouge[j:j+n],
                                                ext_acts[j:j+n])
                     if rouge < 0 and action == 0])
                do_right_preserve = sum(
                    [1 for rouge, action in zip(rewrite_less_rouge[j:j+n],
                                                ext_acts[j:j+n])
                     if rouge >= 0 and action == 1])
                decoded_words_nums = [len(dec) for dec in dec_outs_act[j:j+n]]
                ext_words_nums = [token_nums_batch[iters][x]
                                  for x in range(len(token_nums_batch[iters]))
                                  if x in ext_nums[iters]]
                # always take the extracted label:
                # decoded_sents = [raw_article_batch[iters][x]
                #                  for x in extracted_batch[iters]]
                # statistics [START]
                decoded_sents = [' '.join(dec) for dec in dec_outs_act[j:j+n]]
                rouge_score1 = compute_rouge_n(
                    ' '.join(decoded_sents),
                    ' '.join(raw_abstract_batch[iters]), n=1)
                rouge_score2 = compute_rouge_n(
                    ' '.join(decoded_sents),
                    ' '.join(raw_abstract_batch[iters]), n=2)
                rouge_scorel = compute_rouge_l(
                    ' '.join(decoded_sents),
                    ' '.join(raw_abstract_batch[iters]))
                dec_sent_nums = len(decoded_sents)
                abs_sent_nums = len(raw_abstract_batch[iters])
                doc_sent_nums = len(raw_article_batch[iters])
                doc_words_nums = sum(token_nums_batch[iters])
                ext_words_nums = sum(ext_words_nums)
                abs_words_nums = sum(decoded_words_nums)
                label_recall = len(set(ext_nums[iters])
                                   & set(extracted_batch[iters])) \
                    / len(extracted_batch[iters])
                label_precision = len(set(ext_nums[iters])
                                      & set(extracted_batch[iters])) \
                    / len(ext_nums[iters])
                less_rewrite = rewrite_less_rouge[j+n-1]
                dec_one_action_num = sum(ext_acts[j:j+n])
                dec_zero_action_num = n - dec_one_action_num
                ext_indices = '_'.join([str(i) for i in ext_nums[iters]])
                top3 = set([0, 1, 2]) <= set(ext_nums[iters])
                top3_gold = set([0, 1, 2]) <= set(extracted_batch[iters])
                # any top 2
                top2 = (set([0, 1]) <= set(ext_nums[iters])
                        or set([1, 2]) <= set(ext_nums[iters])
                        or set([0, 2]) <= set(ext_nums[iters]))
                top2_gold = (set([0, 1]) <= set(extracted_batch[iters])
                             or set([1, 2]) <= set(extracted_batch[iters])
                             or set([0, 2]) <= set(extracted_batch[iters]))
                with open(join(save_path, '_statisticsDecode.log.csv'),
                          'a') as w:
                    w.write('{},{},{},{},{},{},{},{},{},{},{},{},{},{},{},'
                            '{},{},{},{},{},{},{},{}\n'.format(
                                i, rouge_score1, rouge_score2, rouge_scorel,
                                dec_sent_nums, abs_sent_nums, doc_sent_nums,
                                doc_words_nums, ext_words_nums,
                                abs_words_nums,
                                (ext_words_nums - abs_words_nums),
                                label_recall, label_precision, less_rewrite,
                                dec_one_action_num, dec_zero_action_num,
                                ext_indices, top3, top3_gold, top2, top2_gold,
                                do_right_rewrite, do_right_preserve))
                # statistics END
                with open(join(save_path, 'output/{}.dec'.format(i)),
                          'w') as f:
                    decoded_sents = [s for s in decoded_sents if s != '']
                    if len(decoded_sents) > 0:
                        f.write(make_html_safe('\n'.join(decoded_sents)))
                    else:
                        f.write('')
                i += 1
                print('{}/{} ({:.2f}%) decoded in {} seconds\r'.format(
                    i, n_data, i/n_data*100,
                    timedelta(seconds=int(time()-start))), end='')
    print()
def rl_validate(net, val_batches, reward_func=None, reward_coef=0.01,
                local_coh_func=None, local_coh_coef=0.005, bert=False):
    print('running validation ... ', end='')
    if bert:
        tokenizer = net._bert_model._tokenizer
        end = tokenizer.encoder[tokenizer._eos_token]
        unk = tokenizer.encoder[tokenizer._unk_token]

    def argmax(arr, keys):
        return arr[max(range(len(arr)), key=lambda i: keys[i].item())]

    def sum_id2word(raw_article_sents, decs, attns):
        if bert:
            dec_sents = []
            for i, raw_words in enumerate(raw_article_sents):
                dec = []
                for id_, attn in zip(decs, attns):
                    if id_[i] == end:
                        break
                    elif id_[i] == unk:
                        dec.append(argmax(raw_words, attn[i]))
                    else:
                        dec.append(id2word[id_[i].item()])
                dec_sents.append(dec)
        else:
            dec_sents = []
            for i, raw_words in enumerate(raw_article_sents):
                dec = []
                for id_, attn in zip(decs, attns):
                    if id_[i] == END:
                        break
                    elif id_[i] == UNK:
                        dec.append(argmax(raw_words, attn[i]))
                    else:
                        dec.append(id2word[id_[i].item()])
                dec_sents.append(dec)
        return dec_sents

    net.eval()
    start = time()
    i = 0
    score = 0
    score_reward = 0
    score_local_coh = 0
    score_r2 = 0
    score_r1 = 0
    bl_r2 = []
    bl_r1 = []
    with torch.no_grad():
        for fw_args, bw_args in val_batches:
            raw_articles = bw_args[0]
            id2word = bw_args[1]
            raw_targets = bw_args[2]
            if reward_func is not None:
                questions = bw_args[3]
            greedies, greedy_attns = net.greedy(*fw_args)
            greedy_sents = sum_id2word(raw_articles, greedies, greedy_attns)
            bl_scores = []
            if reward_func is not None:
                bl_coh_inputs = []
            if local_coh_func is not None:
                bl_local_coh_scores = []
            for baseline, target in zip(greedy_sents, raw_targets):
                if bert:
                    text = ''.join(baseline)
                    baseline = bytearray(
                        [tokenizer.byte_decoder[c] for c in text]
                    ).decode('utf-8', errors=tokenizer.errors)
                    baseline = baseline.split(' ')
                    text = ''.join(target)
                    target = bytearray(
                        [tokenizer.byte_decoder[c] for c in text]
                    ).decode('utf-8', errors=tokenizer.errors)
                    target = target.split(' ')
                bss = sent_tokenize(' '.join(baseline))
                if reward_func is not None:
                    bl_coh_inputs.append(bss)
                bss = [bs.split(' ') for bs in bss]
                tgs = sent_tokenize(' '.join(target))
                tgs = [tg.split(' ') for tg in tgs]
                bl_score = compute_rouge_l_summ(bss, tgs)
                bl_r1.append(compute_rouge_n(list(concat(bss)),
                                             list(concat(tgs)), n=1))
                bl_r2.append(compute_rouge_n(list(concat(bss)),
                                             list(concat(tgs)), n=2))
                bl_scores.append(bl_score)
            bl_scores = torch.tensor(bl_scores, dtype=torch.float32,
                                     device=greedy_attns[0].device)
            if reward_func is not None:
                bl_reward_scores = reward_func.score(questions, bl_coh_inputs)
                bl_reward_scores = torch.tensor(
                    bl_reward_scores, dtype=torch.float32,
                    device=greedy_attns[0].device)
                score_reward += bl_reward_scores.mean().item() * 100
            reward = bl_scores.mean().item()
            i += 1
            score += reward * 100
            score_r2 += torch.tensor(
                bl_r2, dtype=torch.float32,
                device=greedy_attns[0].device).mean().item() * 100
            score_r1 += torch.tensor(
                bl_r1, dtype=torch.float32,
                device=greedy_attns[0].device).mean().item() * 100
    val_score = score / i
    score_r2 = score_r2 / i
    score_r1 = score_r1 / i
    if reward_func is not None:
        val_reward_score = score_reward / i
    else:
        val_reward_score = 0
    val_local_coh_score = 0
    print('validation finished in {} '.format(
        timedelta(seconds=int(time() - start))))
    print('validation reward: {:.4f} ... '.format(val_score))
    print('validation r2: {:.4f} ... '.format(score_r2))
    print('validation r1: {:.4f} ... '.format(score_r1))
    if reward_func is not None:
        print('validation reward: {:.4f} ... '.format(val_reward_score))
    if local_coh_func is not None:
        val_local_coh_score = score_local_coh / i
        print('validation {} reward: {:.4f} ... '.format(
            local_coh_func.__name__, val_local_coh_score))
    print('n_data:', i)
    return {'score': val_score, 'score_reward': val_reward_score}
def train(args):
    if not exists(args.path):
        os.makedirs(args.path)

    # make net
    if args.docgraph or args.paragraph:
        agent, agent_vocab, abstractor, net_args = configure_net_graph(
            args.abs_dir, args.ext_dir, args.cuda,
            args.docgraph, args.paragraph)
    else:
        agent, agent_vocab, abstractor, net_args = configure_net(
            args.abs_dir, args.ext_dir, args.cuda, True, False, args.rl_dir)
    if args.bert_stride > 0:
        assert args.bert_stride == agent._bert_stride

    # configure training setting
    assert args.stop > 0
    train_params = configure_training('adam', args.lr, args.clip, args.decay,
                                      args.batch, args.gamma, args.reward,
                                      args.stop, 'rouge-1')
    if args.docgraph or args.paragraph:
        if args.bert:
            train_batcher, val_batcher = build_batchers_graph_bert(
                args.batch, args.key, args.adj_type, args.max_bert_word,
                args.docgraph, args.paragraph)
        else:
            train_batcher, val_batcher = build_batchers_graph(
                args.batch, args.key, args.adj_type, args.gold_key,
                args.docgraph, args.paragraph)
    elif args.bert:
        train_batcher, val_batcher = build_batchers_bert(
            args.batch, args.bert_sent, args.bert_stride, args.max_bert_word)
    else:
        train_batcher, val_batcher = build_batchers(args.batch)

    # TODO different reward
    if args.reward == 'rouge-l':
        reward_fn = compute_rouge_l
    elif args.reward == 'rouge-1':
        reward_fn = compute_rouge_n(n=1)
    elif args.reward == 'rouge-2':
        reward_fn = compute_rouge_n(n=2)
    elif args.reward == 'rouge-l-s':
        reward_fn = compute_rouge_l_summ
    else:
        raise Exception('Not prepared reward')
    stop_reward_fn = compute_rouge_n(n=1)

    # save abstractor binary
    if args.abs_dir is not None:
        abs_ckpt = {}
        abs_ckpt['state_dict'] = load_best_ckpt(args.abs_dir, reverse=True)
        abs_vocab = pkl.load(open(join(args.abs_dir, 'vocab.pkl'), 'rb'))
        abs_dir = join(args.path, 'abstractor')
        os.makedirs(join(abs_dir, 'ckpt'))
        with open(join(abs_dir, 'meta.json'), 'w') as f:
            json.dump(net_args['abstractor'], f, indent=4)
        torch.save(abs_ckpt, join(abs_dir, 'ckpt/ckpt-0-0'))
        with open(join(abs_dir, 'vocab.pkl'), 'wb') as f:
            pkl.dump(abs_vocab, f)
    # save configuration
    meta = {}
    meta['net'] = 'rnn-ext_abs_rl'
    meta['net_args'] = net_args
    meta['train_params'] = train_params
    with open(join(args.path, 'meta.json'), 'w') as f:
        json.dump(meta, f, indent=4)
    with open(join(args.path, 'agent_vocab.pkl'), 'wb') as f:
        pkl.dump(agent_vocab, f)

    # prepare trainer
    grad_fn = get_grad_fn(agent, args.clip)
    optimizer = optim.Adam(agent.parameters(), **train_params['optimizer'][1])
    scheduler = ReduceLROnPlateau(optimizer, 'max', verbose=True,
                                  factor=args.decay, min_lr=1e-5,
                                  patience=args.lr_p)
    entity = bool(args.docgraph or args.paragraph)
    pipeline = SCPipeline(meta['net'], agent, abstractor,
                          train_batcher, val_batcher,
                          optimizer, grad_fn, reward_fn, entity, args.bert)
    trainer = BasicTrainer(pipeline, args.path, args.ckpt_freq,
                           args.patience, scheduler, val_mode='score')

    print('start training with the following hyper-parameters:')
    print(meta)
    trainer.train()
def a2c_train_step(agent, abstractor, loader, opt, grad_fn,
                   gamma=0.99, reward_fn=compute_rouge_l,
                   stop_reward_fn=compute_rouge_n(n=1), stop_coeff=1.0):
    opt.zero_grad()
    indices = []
    probs = []
    baselines = []
    ext_sents, ext_sents_topic, ext_sents_probs, summ_indices = [], [], [], []
    #art_batch, topic_batch, abs_batch, topic_label_batch = next(loader)
    art_batch, topic_batch, abs_batch = next(loader)
    for raw_arts, topic in zip(art_batch, topic_batch):
        (inds, ms), bs = agent(raw_arts, topic)
        baselines.append(bs)
        indices.append(inds)
        probs.append(ms)
        num_sents = len(raw_arts)
        ext_sents += [raw_arts[idx.item()]
                      for idx in inds if idx.item() < num_sents]
        #ext_sents_topic += [topic[idx.item()]
        #                    for idx in inds if idx.item() < num_sents]
        ext_sents_probs += [m.probs[0][idx]
                            for idx, m in zip(inds, ms)
                            if idx.item() < num_sents]
        summ_indices += [[idx.item()
                          for idx in inds if idx.item() < num_sents]]
    if ext_sents == []:
        #print('Reach the end')
        return None
    with torch.no_grad():
        summaries = abstractor(ext_sents)
    # expand the topic distribution of the reference
    #topic_label_expand = []
    #for label, inds in zip(topic_label_batch, summ_indices):
    #    num_ext = len(inds)
    #    topic_label_expand += [label]*num_ext
    """
    legal_ext_num = len(ext_sents)
    #topic_label_expand = torch.stack(reconstruct_topic_dis(topic_label_expand, legal_ext_num)).squeeze(1).cuda()
    ext_sents_topic = torch.stack(reconstruct_topic_dis_rl(ext_sents_topic, legal_ext_num)).squeeze(1).cuda()
    ext_sents_probs = torch.stack(ext_sents_probs).squeeze(1).cuda()
    #rs_topic = 1 - ext_sents_probs * abs(topic_label_expand - ext_sents_topic).sum(dim=1)  # by L1 distance
    #rs_topic = 0.01*(ext_sents_probs * abs(topic_label_expand * ext_sents_topic).sum(dim=1))  # by inner product
    rs_topic = 0.05*ext_sents_probs * (ext_sents_topic).sum(dim=1)  # by inner product if m2
    rs_topic_avg = rs_topic.sum()/rs_topic.size(0)
    rs_topic_list = rs_topic.cpu().tolist()
    """
    ## collect the generated headlines
    summaries_collect = [None] * len(summ_indices)
    cnt = 0
    i = 0
    for summ_inds in summ_indices:
        try:
            if len(summ_inds) != 0:
                summaries_collect[cnt] = summaries[i]
                i = i + len(summ_inds)
            else:
                summaries_collect[cnt] = [" "]
        except:
            pdb.set_trace()
        cnt += 1
    add_cls_loss = True
    if add_cls_loss == True:
        ## collect the generated headlines (recomputed as above)
        summaries_collect = [None] * len(summ_indices)
        cnt = 0
        i = 0
        for summ_inds in summ_indices:
            try:
                if len(summ_inds) != 0:
                    summaries_collect[cnt] = summaries[i]
                    i = i + len(summ_inds)
                else:
                    summaries_collect[cnt] = [" "]
            except:
                pdb.set_trace()
            cnt += 1
        #with open('/home/yunzhu/Headline/FASum/PORLHG_v3/model/classifier/TEXT_abs.Field', 'rb') as f:
        with open('/home/yunzhu/Headline/FASum/PORLHG_v3/model/classifier/TEXT.Field', 'rb') as f:
            TEXT = dill.load(f)
        vocab_size = len(TEXT.vocab)
        word_embeddings = TEXT.vocab.vectors
        prediction = get_gen_score(summaries_collect, TEXT, vocab_size,
                                   word_embeddings)
        cls_score = (prediction[:, 1].sum() / prediction.size(0)).item()
    else:
        TEXT = None
        vocab_size = None
        word_embeddings = None
        #prediction = get_gen_score(summaries_collect, TEXT, vocab_size, word_embeddings)
        #cls_score = prediction[:,1].sum().item()
    i, i_topic = 0, 0
    rewards = []
    avg_reward = 0
    cnt = 0
    for inds, abss in zip(indices, abs_batch):
        if add_cls_loss == True:
            try:
                #cls_r = prediction[cnt,1].item()*0.2
                cls_r = prediction[cnt, 1].item()
            except:
                print('[Info] In rl.py code, cls_r=0')
                cls_r = 0
        else:
            cls_r = 0
        #rs_topic_ = (rs_topic[:len(inds)].sum()/len(inds)).item()
        rs_topic_ = 0
        rs = ([reward_fn(summaries[i + j], abss[j]) + cls_r + rs_topic_
               for j in range(min(len(inds) - 1, len(abss)))]
              + [0 for _ in range(max(0, len(inds) - 1 - len(abss)))]
              + [stop_coeff * (stop_reward_fn(
                  list(concat(summaries[i:i + len(inds) - 1])),
                  list(concat(abss))))])
        #          list(concat(abss))) + cls_r)])  # + cls_r
        #try:
        #    rs[:len(inds)-1] = np.add(rs[:len(inds)-1],
        #                              rs_topic_list[i_topic:i_topic+len(inds)-1])
        #    rs[:-1] = np.add(rs[:-1], rs_topic_list[i:i+len(inds)-1])
        #except:
        #    pdb.set_trace()
        assert len(rs) == len(inds)
        avg_reward += rs[-1] / stop_coeff
        i += len(inds) - 1
        i_topic += len(inds)
        # compute discounted rewards
        R = 0
        disc_rs = []
        for r in rs[::-1]:
            R = r + gamma * R
            disc_rs.insert(0, R)
        rewards += disc_rs
        cnt += 1
    indices = list(concat(indices))
    probs = list(concat(probs))
    baselines = list(concat(baselines))
    # standardize rewards
    reward = torch.Tensor(rewards).to(baselines[0].device)
    reward = (reward - reward.mean()) / (
        reward.std() + float(np.finfo(np.float32).eps))
    baseline = torch.cat(baselines).squeeze()
    avg_advantage = 0
    losses = []
    for action, prob, r, b in zip(indices, probs, reward, baseline):
        advantage = r - b
        avg_advantage += advantage
        losses.append(-prob.log_prob(action)
                      * (advantage / len(indices)))  # divide by T*B
    critic_loss = F.mse_loss(baseline, reward)
    # backprop and update
    autograd.backward(
        [critic_loss.unsqueeze(0)] + losses,
        [torch.ones(1).to('cuda')] * (1 + len(losses))
        #[torch.ones(1).to(critic_loss.device)]*(1+len(losses))
    )
    grad_log = grad_fn()
    opt.step()
    log_dict = {}
    log_dict.update(grad_log)
    log_dict['reward'] = avg_reward / len(art_batch)
    log_dict['advantage'] = avg_advantage.item() / len(indices)
    log_dict['mse'] = critic_loss.item()
    #log_dict['rs_topic'] = rs_topic_avg.item()
    if add_cls_loss == True:
        log_dict['cls_score'] = cls_score
    assert not math.isnan(log_dict['grad_norm'])
    return log_dict
def a2c_train_step(agent, abstractor, loader, opt, grad_fn,
                   gamma=0.99, reward_fn=compute_rouge_l,
                   stop_reward_fn=compute_rouge_n(n=1), stop_coeff=1.0):
    opt.zero_grad()
    indices = []
    probs = []
    baselines = []
    ext_sents = []
    art_batch, abs_batch = next(loader)
    for raw_arts in art_batch:
        (inds, ms), bs = agent(raw_arts)
        baselines.append(bs)
        indices.append(inds)
        probs.append(ms)
        ext_sents += [raw_arts[idx.item()]
                      for idx in inds if idx.item() < len(raw_arts)]
    with torch.no_grad():
        summaries = abstractor(ext_sents)
    i = 0
    rewards = []
    avg_reward = 0
    for inds, abss in zip(indices, abs_batch):
        rs = ([reward_fn(summaries[i + j], abss[j])
               for j in range(min(len(inds) - 1, len(abss)))]
              + [0 for _ in range(max(0, len(inds) - 1 - len(abss)))]
              + [stop_coeff * stop_reward_fn(
                  list(concat(summaries[i:i + len(inds) - 1])),
                  list(concat(abss)))])
        assert len(rs) == len(inds)
        avg_reward += rs[-1] / stop_coeff
        i += len(inds) - 1
        # compute discounted rewards
        R = 0
        disc_rs = []
        for r in rs[::-1]:
            R = r + gamma * R
            disc_rs.insert(0, R)
        rewards += disc_rs
    indices = list(concat(indices))
    probs = list(concat(probs))
    baselines = list(concat(baselines))
    # standardize rewards
    reward = torch.Tensor(rewards).to(baselines[0].device)
    reward = (reward - reward.mean()) / (
        reward.std() + float(np.finfo(np.float32).eps))
    baseline = torch.cat(baselines).squeeze()
    avg_advantage = 0
    losses = []
    for action, p, r, b in zip(indices, probs, reward, baseline):
        advantage = r - b
        avg_advantage += advantage
        losses.append(-p.log_prob(action)
                      * (advantage / len(indices)))  # divide by T*B
    critic_loss = F.mse_loss(baseline, reward)
    # backprop and update
    autograd.backward([critic_loss] + losses,
                      [torch.ones(1).to(critic_loss.device)]
                      * (1 + len(losses)))
    grad_log = grad_fn()
    opt.step()
    log_dict = {}
    log_dict.update(grad_log)
    log_dict['reward'] = avg_reward / len(art_batch)
    log_dict['advantage'] = avg_advantage.item() / len(indices)
    log_dict['mse'] = critic_loss.item()
    assert not math.isnan(log_dict['grad_norm'])
    return log_dict
def a2c_train_step(agent, abstractor, loader, opt, grad_fn,
                   gamma=0.99, reward_fn=compute_rouge_l,
                   stop_reward_fn=compute_rouge_n(n=1), stop_coeff=1.0):
    sample_time = 1
    time_variant = True
    gamma = 0.95
    opt.zero_grad()
    art_batch, abs_batch, extracts = next(loader)
    all_loss = []
    reward = 0
    advantage = 0
    n_art = 0
    greedy_inputs = []
    sample_inputs = []
    sample_log_probs = []

    def group_extraction(ids_list, raw_arts):
        """Group extracted sentences; index 0 acts as a group separator."""
        groups = []
        ext_sent = []
        for ids in ids_list:
            if ids < len(raw_arts):
                if ids == 0:
                    if ext_sent:
                        groups.append(ext_sent)
                    ext_sent = []
                else:
                    ext_sent += raw_arts[ids]
        if ids_list[-1] != 0 and ext_sent:
            groups.append(ext_sent)
        return groups

    for idx, raw_arts in enumerate(art_batch):
        greedy, samples, all_log_probs = agent(raw_arts,
                                               sample_time=sample_time)
        if time_variant:
            bs = []
            abss = abs_batch[idx]
            for _ind, gd in enumerate(greedy):
                greedy_sents = group_extraction(gd, raw_arts)
                baseline = 0
                for si, sent in enumerate(greedy_sents):
                    with torch.no_grad():
                        greedy_sent = abstractor([sent])
                    greedy_sent = sent_tokenize(' '.join(greedy_sent[0]))
                    greedy_sent = [s.strip().split(' ') for s in greedy_sent]
                    if si >= len(abss):
                        continue
                    if _ind != len(greedy) - 1:
                        baseline += compute_rouge_with_marginal_increase(
                            reward_fn, greedy_sent, abss[si], _ind,
                            gamma=gamma)
                    elif reward_fn.__name__ != 'compute_rouge_l_summ':
                        baseline += reward_fn(list(concat(greedy_sent)),
                                              list(concat(abss[si])))
                    else:
                        baseline += reward_fn(greedy_sent, abss[si])
                bs.append(baseline)
            sample_sents = group_extraction(samples[0], raw_arts)
            all_rewards = []
            for j, sent in enumerate(sample_sents):
                with torch.no_grad():
                    sample_sent = abstractor([sent])
                sample_sent = sent_tokenize(' '.join(sample_sent[0]))
                sample_sent = [s.strip().split(' ') for s in sample_sent]
                # reward of each prefix of the abstracted sentence
                rewards = []
                for k in range(len(sample_sent)):
                    if j < len(abss):
                        rewards.append(
                            reward_fn(list(concat(sample_sent[:k + 1])),
                                      list(concat(abss[j]))))
                    else:
                        rewards.append(0)
                # discounted marginal gains of adding each sentence
                for index in range(len(rewards)):
                    rwd = rewards[index]
                    for _index in range(1, len(rewards) - index):
                        rwd += (rewards[_index + index]
                                - rewards[_index + index - 1]) \
                            * math.pow(gamma, _index)
                    all_rewards.append(rwd)
                if reward_fn.__name__ != 'compute_rouge_l_summ':
                    if j < len(abss):
                        all_rewards.append(compute_rouge_n(
                            list(concat(sample_sent[:j])),
                            list(concat(abss[:j]))))
                    else:
                        all_rewards.append(compute_rouge_n(
                            list(concat(sample_sent[:j])),
                            list(concat(abss))))
                else:
                    all_rewards.append(compute_rouge_n(
                        list(concat(sample_sent)), list(concat(abss[j]))))
            reward += bs[-1]
            advantage += (all_rewards[-1] - bs[-1])
            n_art += 1
            advs = [torch.tensor([_bs - rwd], dtype=torch.float)
                    .to(all_log_probs[0][0].device)
                    for _bs, rwd in zip(bs, all_rewards)]
            for log_prob, adv in zip(all_log_probs[0], advs):
                all_loss.append(log_prob * adv)
        else:
            greedy_sents = [raw_arts[ind] for ind in greedy]
            greedy_sents = [word for sent in greedy_sents for word in sent]
            greedy_inputs.append(greedy_sents)
            sample_sents = [raw_arts[ind] for ind in samples[0]]
            sample_sents = [word for sent in sample_sents for word in sent]
            sample_inputs.append(sample_sents)
            sample_log_probs.append(all_log_probs[0])
    if not time_variant:
        with torch.no_grad():
            greedy_outs = abstractor(greedy_inputs)
            sample_outs = abstractor(sample_inputs)
        for greedy_sents, sample_sents, log_probs, abss in zip(
                greedy_outs, sample_outs, sample_log_probs, abs_batch):
            greedy_sents = sent_tokenize(' '.join(greedy_sents))
            greedy_sents = [sent.strip().split(' ') for sent in greedy_sents]
            if reward_fn.__name__ != 'compute_rouge_l_summ':
                bs = reward_fn(list(concat(greedy_sents)), list(concat(abss)))
            else:
                bs = reward_fn(greedy_sents, abss)
            sample_sents = sent_tokenize(' '.join(sample_sents))
            sample_sents = [sent.strip().split(' ') for sent in sample_sents]
            if reward_fn.__name__ != 'compute_rouge_l_summ':
                rwd = reward_fn(list(concat(sample_sents)),
                                list(concat(abss)))
            else:
                rwd = reward_fn(sample_sents, abss)
            reward += bs
            advantage += (rwd - bs)
            n_art += 1
            adv = torch.tensor([bs - rwd],
                               dtype=torch.float).to(log_probs[0].device)
            for log_prob in log_probs:
                all_loss.append(log_prob * adv)
    reward = reward / n_art
    advantage = advantage / n_art
    # backprop and update
    loss = torch.cat(all_loss, dim=0).mean()
    loss.backward()
    grad_log = grad_fn()
    opt.step()
    log_dict = {}
    log_dict.update(grad_log)
    log_dict['reward'] = reward
    log_dict['advantage'] = advantage
    log_dict['mse'] = 0
    assert not math.isnan(log_dict['grad_norm'])
    return log_dict
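# ---------------------------------------------------------------------------
# compute_rouge_with_marginal_increase is called above but not defined in this
# file.  The sketch below is inferred from how the sampled-side rewards are
# accumulated (a discounted sum of the marginal gain of each added sentence);
# treat it as an assumption about the real implementation, not a copy of it.
def compute_rouge_with_marginal_increase(reward_fn, sents, abss, step,
                                         gamma=0.95):
    scores = [reward_fn(list(concat(sents[:k + 1])), list(concat(abss)))
              for k in range(len(sents))]
    rwd = scores[step]
    for offset in range(1, len(scores) - step):
        # discounted marginal gain of appending one more sentence
        rwd += (scores[step + offset] - scores[step + offset - 1]) \
            * math.pow(gamma, offset)
    return rwd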
def sc_train_step(agent, abstractor, loader, opt, grad_fn,
                  reward_fn=compute_rouge_l, sample_time=1, entity=False):
    gamma = 0.95
    opt.zero_grad()
    art_batch, abs_batch = next(loader)
    all_loss = []
    reward = 0
    advantage = 0
    i = 0
    greedy_inputs = []
    sample_inputs = []
    sample_log_probs = []
    for idx, raw_arts in enumerate(art_batch):
        greedy, samples, all_log_probs = agent(raw_arts,
                                               sample_time=sample_time)
        if agent.time_variant:
            bs = []
            abss = abs_batch[idx]
            for _ind, gd in enumerate(greedy):
                greedy_sents = [raw_arts[ind] for ind in gd]
                with torch.no_grad():
                    greedy_sents = [[word for sent in greedy_sents
                                     for word in sent]]
                    greedy_sents = abstractor(greedy_sents)
                greedy_sents = sent_tokenize(' '.join(greedy_sents[0]))
                greedy_sents = [sent.strip().split(' ')
                                for sent in greedy_sents]
                if _ind != len(greedy) - 1:
                    baseline = compute_rouge_with_marginal_increase(
                        reward_fn, greedy_sents, abss, _ind, gamma=gamma)
                elif reward_fn.__name__ != 'compute_rouge_l_summ':
                    baseline = reward_fn(list(concat(greedy_sents)),
                                         list(concat(abss)))
                else:
                    baseline = reward_fn(greedy_sents, abss)
                bs.append(baseline)
            sample_sents = [raw_arts[ind] for ind in samples[0]]
            with torch.no_grad():
                sample_sents = [[word for sent in sample_sents
                                 for word in sent]]
                sample_sents = abstractor(sample_sents)
            sample_sents = sent_tokenize(' '.join(sample_sents[0]))
            sample_sents = [sent.strip().split(' ') for sent in sample_sents]
            if reward_fn.__name__ != 'compute_rouge_l_summ':
                rewards = [reward_fn(list(concat(sample_sents[:i + 1])),
                                     list(concat(abss)))
                           for i in range(len(sample_sents))]
            else:
                rewards = [reward_fn(sample_sents[:i + 1], abss)
                           for i in range(len(sample_sents))]
            # discounted marginal gains of adding each sentence
            all_rewards = []
            for index in range(len(rewards)):
                rwd = rewards[index]
                for _index in range(1, len(rewards) - index):
                    rwd += (rewards[_index + index]
                            - rewards[_index + index - 1]) \
                        * math.pow(gamma, _index)
                all_rewards.append(rwd)
            all_rewards.append(compute_rouge_n(list(concat(sample_sents)),
                                               list(concat(abss))))
            reward += bs[-1]
            advantage += (all_rewards[-1] - bs[-1])
            i += 1
            advs = [torch.tensor([_bs - rwd], dtype=torch.float)
                    .to(all_log_probs[0][0].device)
                    for _bs, rwd in zip(bs, all_rewards)]
            for log_prob, adv in zip(all_log_probs[0], advs):
                all_loss.append(log_prob * adv)
        else:
            if entity:
                raw_arts = raw_arts[0]
            greedy_sents = [raw_arts[ind] for ind in greedy]
            greedy_sents = [word for sent in greedy_sents for word in sent]
            greedy_inputs.append(greedy_sents)
            sample_sents = [raw_arts[ind] for ind in samples[0]]
            sample_sents = [word for sent in sample_sents for word in sent]
            sample_inputs.append(sample_sents)
            sample_log_probs.append(all_log_probs[0])
    if not agent.time_variant:
        with torch.no_grad():
            greedy_outs = abstractor(greedy_inputs)
            sample_outs = abstractor(sample_inputs)
        for greedy_sents, sample_sents, log_probs, abss in zip(
                greedy_outs, sample_outs, sample_log_probs, abs_batch):
            greedy_sents = sent_tokenize(' '.join(greedy_sents))
            greedy_sents = [sent.strip().split(' ') for sent in greedy_sents]
            if reward_fn.__name__ != 'compute_rouge_l_summ':
                bs = reward_fn(list(concat(greedy_sents)), list(concat(abss)))
            else:
                bs = reward_fn(greedy_sents, abss)
            sample_sents = sent_tokenize(' '.join(sample_sents))
            sample_sents = [sent.strip().split(' ') for sent in sample_sents]
            if reward_fn.__name__ != 'compute_rouge_l_summ':
                rwd = reward_fn(list(concat(sample_sents)),
                                list(concat(abss)))
            else:
                rwd = reward_fn(sample_sents, abss)
            reward += bs
            advantage += (rwd - bs)
            i += 1
            adv = torch.tensor([bs - rwd],
                               dtype=torch.float).to(log_probs[0].device)
            for log_prob in log_probs:
                all_loss.append(log_prob * adv)
    reward = reward / i
    advantage = advantage / i
    # backprop and update
    loss = torch.cat(all_loss, dim=0).mean()
    loss.backward()
    grad_log = grad_fn()
    opt.step()
    log_dict = {}
    log_dict.update(grad_log)
    log_dict['reward'] = reward
    log_dict['advantage'] = advantage
    log_dict['mse'] = 0
    assert not math.isnan(log_dict['grad_norm'])
    return log_dict
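# ---------------------------------------------------------------------------
# The quadratic double loop over `rewards` above can be replaced by a single
# backward pass: writing rwd[t] for the discounted sum of marginal gains from
# position t, the recurrence rwd[t] = rewards[t] + gamma * (rwd[t+1] -
# rewards[t]) reproduces it exactly.  A sketch with a hypothetical name:
def marginal_gain_returns(rewards, gamma=0.95):
    out = [0.0] * len(rewards)
    for t in reversed(range(len(rewards))):
        if t == len(rewards) - 1:
            out[t] = rewards[t]
        else:
            out[t] = rewards[t] + gamma * (out[t + 1] - rewards[t])
    return out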
def a2c_train_step(agent, abstractor, loader, opt, grad_fn,
                   gamma=0.99, reward_fn=compute_rouge_l,
                   stop_reward_fn=compute_rouge_n(n=1), stop_coeff=1.0):
    # NOTE: wt_rge, use_bart, tokenizer and bart_model are assumed to be
    # module-level globals
    opt.zero_grad()
    indices = []
    probs = []
    baselines = []
    ext_sents = []
    art_batch, abs_batch = next(loader)
    for raw_arts, raw_abs in zip(art_batch, abs_batch):
        (inds, ms), bs, raw_arts = agent(raw_arts, raw_abs_sents=raw_abs,
                                         n_abs=None)
        baselines.append(bs)
        indices.append(inds)
        probs.append(ms)
        ext_sents += [raw_arts[idx.item()]
                      for idx in inds if idx.item() < len(raw_arts)]
    with torch.no_grad():
        if not ext_sents:
            summaries = []
        elif use_bart:
            summaries = get_bart_summaries(ext_sents, tokenizer, bart_model)
        else:
            summaries = abstractor(ext_sents)
    i = 0
    rewards = []
    avg_reward = 0
    # BERTScore / BLEURT are scored in a single batched call for efficiency
    cands = []
    refs = []
    F1s = None
    if reward_fn == compute_bert_score or reward_fn == compute_bleurt_score:
        for inds, abss in zip(indices, abs_batch):
            for j in range(min(len(inds) - 1, len(abss))):
                cands.append(" ".join(summaries[i + j]))
                refs.append(" ".join(abss[j]))
            i += len(inds) - 1
        F1s = reward_fn(cands, refs)
        i = 0
    t = 0
    avg_len = 0
    for inds, abss in zip(indices, abs_batch):
        x = len(inds) - 1
        if (reward_fn == compute_bert_score
                or reward_fn == compute_bleurt_score) and wt_rge >= 0.0001:
            rwd_lst = [(1 - wt_rge) * F1s[t + j]
                       + wt_rge * compute_rouge_n(summaries[i + j],
                                                  abss[j], n=2)
                       for j in range(min(len(inds) - 1, len(abss)))]
            t += min(len(inds) - 1, len(abss))
        elif (reward_fn == compute_bert_score
              or reward_fn == compute_bleurt_score):
            rwd_lst = [F1s[t + j]
                       for j in range(min(len(inds) - 1, len(abss)))]
            t += min(len(inds) - 1, len(abss))
        elif wt_rge >= 0.0001:
            rwd_lst = [(1 - wt_rge) * reward_fn(summaries[i + j], abss[j])
                       + wt_rge * compute_rouge_l(summaries[i + j], abss[j])
                       for j in range(min(len(inds) - 1, len(abss)))]
        elif reward_fn in (presumm_reward, presumm_reward2,
                           presumm_reward3, presumm_reward4):
            rwd_lst = reward_fn(summaries[i:i + len(inds) - 1], abss)
            x = 3
        elif reward_fn == "summ-rouge-l":
            tmp = list(concat(abss))
            rwd_lst = [compute_rouge_l(summaries[i + j], tmp, mode='r')
                       for j in range(min(len(inds) - 1, 3))]
        else:
            rwd_lst = [reward_fn(summaries[i + j], abss[j])
                       for j in range(min(len(inds) - 1, len(abss)))]
        rs = (rwd_lst
              + [0 for _ in range(max(0, len(inds) - 1 - len(rwd_lst)))]
              + [stop_coeff * stop_reward_fn(
                  list(concat(summaries[i:i + x])), list(concat(abss)))])
        assert len(rs) == len(inds)
        avg_reward += rs[-1] / stop_coeff
        if reward_fn in (presumm_reward, presumm_reward3, presumm_reward4):
            rs[-1] = 0
        avg_len += (len(inds) - 1) / len(abs_batch)
        i += len(inds) - 1
        # compute discounted rewards
        R = 0
        disc_rs = []
        for r in rs[::-1]:
            R = r + gamma * R
            disc_rs.insert(0, R)
        rewards += disc_rs
    indices = list(concat(indices))
    probs = list(concat(probs))
    baselines = list(concat(baselines))
    # standardize rewards
    reward = torch.Tensor(rewards).to(baselines[0].device)
    reward = (reward - reward.mean()) / (reward.std()
                                         + float(np.finfo(np.float32).eps))
    baseline = torch.cat(baselines).squeeze()
    avg_advantage = 0
    losses = []
    for action, p, r, b in zip(indices, probs, reward, baseline):
        advantage = (r - b).item()
        avg_advantage += advantage
        losses.append(-p.log_prob(action)
                      * (advantage / len(indices)))  # divide by T*B
    critic_loss = F.mse_loss(baseline, reward).unsqueeze(dim=0)
    # backprop and update
    autograd.backward([critic_loss] + losses,
                      [torch.ones(1).to(critic_loss.device)]
                      * (1 + len(losses)))
    grad_log = grad_fn()
    opt.step()
    log_dict = {}
    log_dict.update(grad_log)
    log_dict['reward'] = avg_reward / len(art_batch)
    log_dict['advantage'] = avg_advantage / len(indices)
    log_dict['mse'] = critic_loss.item()
    log_dict['avg_len'] = avg_len
    assert not math.isnan(log_dict['grad_norm'])
    return log_dict
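# ---------------------------------------------------------------------------
# The step above indexes F1s as one score per (candidate, reference) pair, so
# compute_bert_score is assumed to take parallel lists of strings and return a
# flat list of F1 values.  A minimal sketch on top of the public `bert-score`
# package; the actual implementation in this codebase may differ.
from bert_score import score as _bert_score

def compute_bert_score(cands, refs):
    # P, R, F1 each have shape [len(cands)]
    P, R, F1 = _bert_score(cands, refs, lang='en', verbose=False)
    return F1.tolist()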
def a2c_train_step(agent, abstractor, loader, opt, grad_fn,
                   gamma=0.99, reward_fn=compute_rouge_l,
                   stop_reward_fn=compute_rouge_n(n=1), stop_coeff=1.0):
    opt.zero_grad()
    indices = []
    indicesL = []
    new_indices = []
    probs = []
    baselines = []
    ext_sents = []
    art_batch, abs_batch, ext_batch = next(loader)
    raw_adv = []
    r = 0
    a = 0
    zeros = []
    for raw_arts in art_batch:
        (inds, ms), bs = agent(raw_arts)
        baselines.append(bs)
        indices.append(inds)
        indicesL.append(len(inds))
        probs.append(ms)
        # group the extracted sentences; index 0 acts as a group separator
        inds_ = []
        ext_sent = []
        zero = []
        count = 0
        q = 0
        f = 0
        for idx in inds:
            if idx.item() < len(raw_arts):
                if idx.item() == 0:
                    if ext_sent:
                        count += 1
                        ext_sents.append(ext_sent)
                        zero.append(f)
                        inds_.append(1)
                        raw_adv.append((r, a))
                        a += 1
                        q = 1
                    else:
                        if q == 1:
                            raw_adv.append((r, a - 1))
                        else:
                            raw_adv.append((r, a))
                    ext_sent = []
                    count = 0
                else:
                    ext_sent += raw_arts[idx.item()]
                    count += 1
                    raw_adv.append((r, a))
                    q = 0
            else:
                if q == 1:
                    raw_adv.append((r, a - 1))
                else:
                    raw_adv.append((r, a))
            f += 1
        if inds[-1].item() != 0:
            if ext_sent:
                ext_sents.append(ext_sent)
                zero.append(f)
                inds_.append(1)
                a += 1
        if not inds_:
            ext_sents.append(raw_arts[0])
            inds_.append(1)
            zero.append(f)
            a += 1
        r += 1
        zeros.append(zero)
        new_indices.append(inds_)
    with torch.no_grad():
        summaries = abstractor(ext_sents)
    i = 0
    set_reward, summ_reward = [], []
    for inds, abss in zip(new_indices, abs_batch):
        set_reward.append([reward_fn(summaries[i + j], abss[j])
                           for j in range(min(len(inds) - 1, len(abss)))])
        summ_reward.append(stop_coeff * stop_reward_fn(
            list(concat(summaries[i:i + len(inds) - 1])),
            list(concat(abss))))
        i += len(inds)  # advance past this article's emitted groups
    ext_rewards = []
    avg_ext_reward = 0
    for inds, exts, arts, setR, summR, zero in zip(
            indices, ext_batch, art_batch, set_reward, summ_reward, zeros):
        k = 0
        rs1 = []
        ext_zero = [e for e in range(len(exts)) if exts[e] == 0]
        e0 = 0
        maxlen = 0
        for j in range(min(len(inds) - 1, len(exts))):
            maxlen += 1
            if j in zero and k < len(setR) and k < len(ext_zero) - 1:
                rs1.append(setR[k])
                k += 1
                e0 = ext_zero[k] + 1
            else:
                try:
                    if inds[j].item() < len(arts):
                        rs1.append(compute_rouge_n(arts[inds[j].item()],
                                                   arts[exts[e0]], n=1))
                    else:
                        rs1.append(0.0)
                    e0 += 1
                except IndexError:
                    maxlen -= 1
                    break
        rs2 = [0 for _ in range(len(inds) - 1 - maxlen)]
        # MAP of the predicted extraction against the gold labels
        stop = metrics(gt=exts, pred=inds, metrics_map=['MAP'])[0]
        rs3 = [stop_coeff * summR]
        rs = rs1 + rs2 + rs3
        assert len(rs) == len(inds)
        avg_ext_reward += rs[-1] / stop_coeff
        # compute discounted rewards
        R = 0
        disc_rs = []
        for r in rs[::-1]:
            R = r + gamma * R
            disc_rs.insert(0, R)
        ext_rewards += disc_rs
    indices = list(concat(indices))
    new_indices = list(concat(new_indices))
    probs = list(concat(probs))
    baselines = list(concat(baselines))
    # standardize rewards
    ext_reward = torch.Tensor(ext_rewards).to(baselines[0].device)
    ext_reward = (ext_reward - ext_reward.mean()) / (
        ext_reward.std() + float(np.finfo(np.float32).eps))
    baseline = torch.cat(baselines).squeeze()
    avg_advantage = 0
    losses = []
    for r, b, p, action in zip(ext_reward, baseline, probs, indices):
        advantage = r - b
        avg_advantage += advantage
        losses.append(-p.log_prob(action)
                      * (advantage / len(indices)))  # divide by T*B
    critic_loss = F.mse_loss(baseline, ext_reward)
    # backprop and update
    critic_loss = critic_loss.view([1])
    autograd.backward(tensors=[critic_loss] + losses,
                      grad_tensors=[torch.ones(1).to(critic_loss.device)]
                      * (1 + len(losses)))
    grad_log = grad_fn()
    opt.step()
    log_dict = {}
    log_dict.update(grad_log)
    log_dict['reward'] = avg_ext_reward / len(art_batch)
    log_dict['advantage'] = avg_advantage.item() / len(indices)
    log_dict['mse'] = critic_loss.item()
    assert not math.isnan(log_dict['grad_norm'])
    return log_dict
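# ---------------------------------------------------------------------------
# `metrics(gt=exts, pred=inds, metrics_map=['MAP'])` above comes from an
# external utility.  The sketch below shows the mean average precision it is
# assumed to compute for a predicted extraction order against gold labels;
# the real utility may differ in tie-breaking and normalization.
def mean_average_precision(gt, pred):
    gt_set = set(int(g) for g in gt)
    hits, precisions = 0, []
    for rank, p in enumerate(pred, start=1):
        if int(p) in gt_set:
            hits += 1
            precisions.append(hits / rank)
    return sum(precisions) / max(1, len(gt_set))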
def train_step(self, sample_time=1):
    def argmax(arr, keys):
        return arr[max(range(len(arr)), key=lambda i: keys[i].item())]

    def sum_id2word(raw_article_sents, decs, attns, id2word):
        dec_sents = []
        for i, raw_words in enumerate(raw_article_sents):
            dec = []
            for id_, attn in zip(decs, attns):
                if id_[i] == END:
                    break
                elif id_[i] == UNK:
                    dec.append(argmax(raw_words, attn[i]))
                else:
                    dec.append(id2word[id_[i].item()])
            dec_sents.append(dec)
        return dec_sents

    def pack_seq(seq_list):
        return torch.cat([_.unsqueeze(1) for _ in seq_list], 1)

    # forward pass of model
    self._net.train()
    fw_args, bw_args = next(self._batches)
    raw_articles = bw_args[0]
    id2word = bw_args[1]
    raw_targets = bw_args[2]
    with torch.no_grad():
        greedies, greedy_attns = self._net.greedy(*fw_args)
    greedy_sents = sum_id2word(raw_articles, greedies, greedy_attns, id2word)
    bl_scores = []
    for baseline, target in zip(greedy_sents, raw_targets):
        bss = sent_tokenize(' '.join(baseline))
        tgs = sent_tokenize(' '.join(target))
        bss = [bs.split(' ') for bs in bss]
        tgs = [tg.split(' ') for tg in tgs]
        bss_bleu = ' '.join(list(concat(bss)))
        tgs_bleu = ' '.join(list(concat(tgs)))
        if self._bleu:
            bleu_scores = bleu(bss_bleu, tgs_bleu)
            bl_score = (bleu_scores[0] + bleu_scores[1]
                        + bleu_scores[2] + bleu_scores[3])
        elif self.f1:
            bl_score = compute_f1(bss_bleu, tgs_bleu)
        else:
            bl_score = (self._w8[2] * compute_rouge_l_summ(bss, tgs)
                        + self._w8[0] * compute_rouge_n(
                            list(concat(bss)), list(concat(tgs)), n=1)
                        + self._w8[1] * compute_rouge_n(
                            list(concat(bss)), list(concat(tgs)), n=2))
        bl_scores.append(bl_score)
    bl_scores = torch.tensor(bl_scores, dtype=torch.float32,
                             device=greedy_attns[0].device)
    samples, sample_attns, seqLogProbs = self._net.sample(*fw_args)
    sample_sents = sum_id2word(raw_articles, samples, sample_attns, id2word)
    sp_seqs = pack_seq(samples)
    _masks = (sp_seqs > PAD).float()
    sp_seqLogProb = pack_seq(seqLogProbs)
    loss_nll = -sp_seqLogProb.squeeze(2) * _masks.detach().type_as(
        sp_seqLogProb)
    sp_scores = []
    for sample, target in zip(sample_sents, raw_targets):
        sps = sent_tokenize(' '.join(sample))
        tgs = sent_tokenize(' '.join(target))
        sps = [sp.split(' ') for sp in sps]
        tgs = [tg.split(' ') for tg in tgs]
        sps_bleu = ' '.join(list(concat(sps)))
        tgs_bleu = ' '.join(list(concat(tgs)))
        if self._bleu:
            bleu_scores = bleu(sps_bleu, tgs_bleu)
            sp_score = (bleu_scores[0] + bleu_scores[1]
                        + bleu_scores[2] + bleu_scores[3])
        elif self.f1:
            sp_score = compute_f1(sps_bleu, tgs_bleu)
        else:
            sp_score = (self._w8[2] * compute_rouge_l_summ(sps, tgs)
                        + self._w8[0] * compute_rouge_n(
                            list(concat(sps)), list(concat(tgs)), n=1)
                        + self._w8[1] * compute_rouge_n(
                            list(concat(sps)), list(concat(tgs)), n=2))
        sp_scores.append(sp_score)
    sp_scores = torch.tensor(sp_scores, dtype=torch.float32,
                             device=greedy_attns[0].device)
    # self-critical reward: sample score minus greedy-baseline score
    reward = sp_scores.view(-1, 1) - bl_scores.view(-1, 1)
    reward.requires_grad_(False)
    loss = reward.contiguous().detach() * loss_nll
    loss = loss.sum()
    full_length = _masks.data.float().sum()
    loss = loss / full_length
    loss.backward()
    log_dict = {}
    log_dict['reward'] = bl_scores.mean().item()
    if self._grad_fn is not None:
        log_dict.update(self._grad_fn())
    self._opt.step()
    self._net.zero_grad()
    return log_dict
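# ---------------------------------------------------------------------------
# The loss in train_step above is self-critical sequence training (SCST):
# REINFORCE with the greedy decode as baseline.  A condensed sketch of just
# the loss term, given per-token log-probs [B, T, 1], a PAD mask [B, T], and
# per-sequence sample/greedy scores (hypothetical helper name):
def scst_loss(seq_log_probs, mask, sample_scores, greedy_scores):
    # reward > 0 when the sample beats the greedy baseline; minimizing the
    # loss then raises the log-probability of better-than-greedy samples
    reward = (sample_scores - greedy_scores).view(-1, 1).detach()
    loss_nll = -seq_log_probs.squeeze(2) * mask
    return (reward * loss_nll).sum() / mask.sum()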
def train_step(self, sample_time=1):
    def argmax(arr, keys):
        return arr[max(range(len(arr)), key=lambda i: keys[i].item())]

    def sum_id2word(raw_article_sents, decs, attns, id2word):
        dec_sents = []
        for i, raw_words in enumerate(raw_article_sents):
            dec = []
            for id_, attn in zip(decs, attns):
                if id_[i] == END:
                    break
                elif id_[i] == UNK:
                    dec.append(argmax(raw_words, attn[i]))
                else:
                    dec.append(id2word[id_[i].item()])
            dec_sents.append(dec)
        return dec_sents

    def pack_seq(seq_list):
        return torch.cat([_.unsqueeze(1) for _ in seq_list], 1)

    # forward pass of model
    self._net.train()
    fw_args, bw_args = next(self._batches)
    raw_articles = bw_args[0]
    id2word = bw_args[1]
    raw_targets = bw_args[2]
    with torch.no_grad():
        greedies, greedy_attns = self._net.greedy(*fw_args)
    greedy_sents = sum_id2word(raw_articles, greedies, greedy_attns, id2word)
    bl_scores = []
    if self._coh_fn is not None:
        bl_coh_scores = []
        bl_coh_inputs = []
    if self._local_coh_fun is not None:
        bl_local_coh_scores = []
    for baseline, target in zip(greedy_sents, raw_targets):
        bss = sent_tokenize(' '.join(baseline))
        tgs = sent_tokenize(' '.join(target))
        if self._coh_fn is not None:
            # global coherence is scored in one batched call below
            bl_coh_inputs.append(bss)
        if self._local_coh_fun is not None:
            bl_local_coh_scores.append(self._local_coh_fun(bss))
        bss = [bs.split(' ') for bs in bss]
        tgs = [tg.split(' ') for tg in tgs]
        bl_score = (self._weights[2] * compute_rouge_l_summ(bss, tgs)
                    + self._weights[0] * compute_rouge_n(
                        list(concat(bss)), list(concat(tgs)), n=1)
                    + self._weights[1] * compute_rouge_n(
                        list(concat(bss)), list(concat(tgs)), n=2))
        bl_scores.append(bl_score)
    bl_scores = torch.tensor(bl_scores, dtype=torch.float32,
                             device=greedy_attns[0].device)
    if self._coh_fn is not None:
        input_args = (bl_coh_inputs,) + self._coh_fn
        bl_coh_scores = batch_global_infer(*input_args)
        bl_coh_scores = torch.tensor(bl_coh_scores, dtype=torch.float32,
                                     device=greedy_attns[0].device)
    if self._local_coh_fun is not None:
        bl_local_coh_scores = torch.tensor(bl_local_coh_scores,
                                           dtype=torch.float32,
                                           device=greedy_attns[0].device)
    for _ in range(sample_time):
        samples, sample_attns, seqLogProbs = self._net.sample(*fw_args)
        sample_sents = sum_id2word(raw_articles, samples, sample_attns,
                                   id2word)
        sp_seqs = pack_seq(samples)
        _masks = (sp_seqs > PAD).float()
        sp_seqLogProb = pack_seq(seqLogProbs)
        loss_nll = -sp_seqLogProb.squeeze(2) * _masks.detach().type_as(
            sp_seqLogProb)
        sp_scores = []
        if self._coh_fn is not None:
            sp_coh_inputs = []
        if self._local_coh_fun is not None:
            sp_local_coh_scores = []
        for sample, target in zip(sample_sents, raw_targets):
            sps = sent_tokenize(' '.join(sample))
            tgs = sent_tokenize(' '.join(target))
            if self._coh_fn is not None:
                sp_coh_inputs.append(sps)
            if self._local_coh_fun is not None:
                sp_local_coh_scores.append(self._local_coh_fun(sps))
            sps = [sp.split(' ') for sp in sps]
            tgs = [tg.split(' ') for tg in tgs]
            sp_score = (self._weights[2] * compute_rouge_l_summ(sps, tgs)
                        + self._weights[0] * compute_rouge_n(
                            list(concat(sps)), list(concat(tgs)), n=1)
                        + self._weights[1] * compute_rouge_n(
                            list(concat(sps)), list(concat(tgs)), n=2))
            sp_scores.append(sp_score)
        sp_scores = torch.tensor(sp_scores, dtype=torch.float32,
                                 device=greedy_attns[0].device)
        reward = sp_scores.view(-1, 1) - bl_scores.view(-1, 1)
        reward.requires_grad_(False)
        if self._coh_fn is not None:
            input_args = (sp_coh_inputs,) + self._coh_fn
            sp_coh_scores = batch_global_infer(*input_args)
            sp_coh_scores = torch.tensor(sp_coh_scores, dtype=torch.float32,
                                         device=greedy_attns[0].device)
            reward_coh = (sp_coh_scores.view(-1, 1)
                          - bl_coh_scores.view(-1, 1))
            reward_coh.requires_grad_(False)
            reward = reward + self._coh_cof * reward_coh
        if self._local_coh_fun is not None:
            sp_local_coh_scores = torch.tensor(sp_local_coh_scores,
                                               dtype=torch.float32,
                                               device=greedy_attns[0].device)
            reward_local = (sp_local_coh_scores.view(-1, 1)
                            - bl_local_coh_scores.view(-1, 1))
            reward_local.requires_grad_(False)
            reward = reward + self._local_co_coef * reward_local
        if _ == 0:
            loss = reward.contiguous().detach() * loss_nll
            loss = loss.sum()
            full_length = _masks.data.float().sum()
        else:
            loss += (reward.contiguous().detach() * loss_nll).sum()
            full_length += _masks.data.float().sum()
    loss = loss / full_length
    # backward and update (and optional gradient monitoring)
    loss.backward()
    log_dict = {}
    if self._coh_fn is not None:
        log_dict['reward'] = (bl_scores.mean().item()
                              + self._coh_cof * bl_coh_scores.mean().item())
    else:
        log_dict['reward'] = bl_scores.mean().item()
    if self._grad_fn is not None:
        log_dict.update(self._grad_fn())
    self._opt.step()
    self._net.zero_grad()
    torch.cuda.empty_cache()
    return log_dict
def a2c_train_step(agent, abstractor, loader, opt, grad_fn,
                   gamma=0.99, reward_fn=compute_rouge_l,
                   stop_reward_fn=compute_rouge_n(n=1), stop_coeff=1.0):
    opt.zero_grad()
    indices = []
    probs = []
    baselines = []
    ext_sents = []
    summ_indices = []
    art_batch, topic_batch, abs_batch = next(loader)
    for raw_arts, topic in zip(art_batch, topic_batch):
        (inds, ms), bs = agent(raw_arts, topic)
        baselines.append(bs)
        indices.append(inds)
        probs.append(ms)
        ext_sents += [raw_arts[idx.item()]
                      for idx in inds if idx.item() < len(raw_arts)]
        summ_indices += [[idx.item()
                          for idx in inds if idx.item() < len(raw_arts)]]
    with torch.no_grad():
        summaries = abstractor(ext_sents)
    # collect the generated headline for each article
    summaries_collect = [None] * len(summ_indices)
    i = 0
    for cnt, summ_inds in enumerate(summ_indices):
        if len(summ_inds) != 0:
            summaries_collect[cnt] = summaries[i]
            i = i + len(summ_inds)
        else:
            summaries_collect[cnt] = [" "]
    add_cls_loss = False
    if add_cls_loss:
        with open('/home/yunzhu/Headline/FASum/FASRL/model/classifier/'
                  'TEXT.Field', 'rb') as f:
            TEXT = dill.load(f)
        vocab_size = len(TEXT.vocab)
        word_embeddings = TEXT.vocab.vectors
    else:
        TEXT = None
        vocab_size = None
        word_embeddings = None
    #prediction = get_gen_score(summaries_collect, TEXT, vocab_size, word_embeddings)
    #cls_score = prediction[:, 1].sum().item()
    i = 0
    rewards = []
    avg_reward = 0
    for inds, abss in zip(indices, abs_batch):
        cls_r = 0  # hook for an optional classifier reward (disabled above)
        rs = ([reward_fn(summaries[i + j], abss[j]) + cls_r
               for j in range(min(len(inds) - 1, len(abss)))]
              + [0 for _ in range(max(0, len(inds) - 1 - len(abss)))]
              + [stop_coeff * (stop_reward_fn(
                  list(concat(summaries[i:i + len(inds) - 1])),
                  list(concat(abss))) + cls_r)])
        assert len(rs) == len(inds)
        avg_reward += rs[-1] / stop_coeff
        i += len(inds) - 1
        # compute discounted rewards
        R = 0
        disc_rs = []
        for r in rs[::-1]:
            R = r + gamma * R
            disc_rs.insert(0, R)
        rewards += disc_rs
    indices = list(concat(indices))
    probs = list(concat(probs))
    baselines = list(concat(baselines))
    # standardize rewards
    reward = torch.Tensor(rewards).to(baselines[0].device)
    reward = (reward - reward.mean()) / (reward.std()
                                         + float(np.finfo(np.float32).eps))
    baseline = torch.cat(baselines).squeeze()
    avg_advantage = 0
    losses = []
    for action, p, r, b in zip(indices, probs, reward, baseline):
        advantage = r - b
        avg_advantage += advantage
        losses.append(-p.log_prob(action)
                      * (advantage / len(indices)))  # divide by T*B
    critic_loss = F.mse_loss(baseline, reward)
    # backprop and update
    autograd.backward([critic_loss.unsqueeze(0)] + losses,
                      [torch.ones(1).to(critic_loss.device)]
                      * (1 + len(losses)))
    grad_log = grad_fn()
    opt.step()
    log_dict = {}
    log_dict.update(grad_log)
    log_dict['reward'] = avg_reward / len(art_batch)
    log_dict['advantage'] = avg_advantage.item() / len(indices)
    log_dict['mse'] = critic_loss.item()
    assert not math.isnan(log_dict['grad_norm'])
    return log_dict
def train_step(self, sample_time=1):
    torch.autograd.set_detect_anomaly(True)

    def argmax(arr, keys):
        return arr[max(range(len(arr)), key=lambda i: keys[i].item())]

    def sum_id2word(raw_article_sents, decs, attns, id2word):
        # decoding is identical for BERT and non-BERT vocabularies
        dec_sents = []
        for i, raw_words in enumerate(raw_article_sents):
            dec = []
            for id_, attn in zip(decs, attns):
                if id_[i] == self._end:
                    break
                elif id_[i] == self._unk:
                    dec.append(argmax(raw_words, attn[i]))
                else:
                    dec.append(id2word[id_[i].item()])
            dec_sents.append(dec)
        return dec_sents

    def pack_seq(seq_list):
        return torch.cat([_.unsqueeze(1) for _ in seq_list], 1)

    # forward pass of model
    self._net.train()
    for i in range(self._accumulate_g_step):
        fw_args, bw_args = next(self._batches)
        raw_articles = bw_args[0]
        id2word = bw_args[1]
        raw_targets = bw_args[2]
        if self._reward_fn is not None:
            questions = bw_args[3]
            targets = bw_args[4]
        with torch.no_grad():
            greedies, greedy_attns = self._net.greedy(*fw_args)
        greedy_sents = sum_id2word(raw_articles, greedies, greedy_attns,
                                   id2word)
        bl_scores = []
        if self._reward_fn is not None:
            bl_reward_scores = []
            bl_reward_inputs = []
        if self._local_coh_fun is not None:
            bl_local_coh_scores = []
        for baseline, target in zip(greedy_sents, raw_targets):
            if self._bert:
                text = ''.join(baseline)
                baseline = bytearray(
                    [self._tokenizer.byte_decoder[c] for c in text]
                ).decode('utf-8', errors=self._tokenizer.errors)
                baseline = baseline.strip().lower().split(' ')
                text = ''.join(target)
                target = bytearray(
                    [self._tokenizer.byte_decoder[c] for c in text]
                ).decode('utf-8', errors=self._tokenizer.errors)
                target = target.strip().lower().split(' ')
            bss = sent_tokenize(' '.join(baseline))
            tgs = sent_tokenize(' '.join(target))
            if self._reward_fn is not None:
                bl_reward_inputs.append(bss)
            if self._local_coh_fun is not None:
                bl_local_coh_scores.append(self._local_coh_fun(bss))
            bss = [bs.split(' ') for bs in bss]
            tgs = [tg.split(' ') for tg in tgs]
            bl_score = (self._w8[2] * compute_rouge_l_summ(bss, tgs)
                        + self._w8[0] * compute_rouge_n(
                            list(concat(bss)), list(concat(tgs)), n=1)
                        + self._w8[1] * compute_rouge_n(
                            list(concat(bss)), list(concat(tgs)), n=2))
            bl_scores.append(bl_score)
        bl_scores = torch.tensor(bl_scores, dtype=torch.float32,
                                 device=greedy_attns[0].device)
        # sample
        fw_args += (self._ml_loss, )
        if self._ml_loss:
            samples, sample_attns, seqLogProbs, ml_logit = self._net.sample(
                *fw_args)
        else:
            samples, sample_attns, seqLogProbs = self._net.sample(*fw_args)
        sample_sents = sum_id2word(raw_articles, samples, sample_attns,
                                   id2word)
        sp_seqs = pack_seq(samples)
        _masks = (sp_seqs > PAD).float()
        sp_seqLogProb = pack_seq(seqLogProbs)
        loss_nll = -sp_seqLogProb.squeeze(2) * _masks.detach().type_as(
            sp_seqLogProb)
        sp_scores = []
        if self._reward_fn is not None:
            sp_reward_inputs = []
        for sample, target in zip(sample_sents, raw_targets):
            if self._bert:
                text = ''.join(sample)
                sample = bytearray(
                    [self._tokenizer.byte_decoder[c] for c in text]
                ).decode('utf-8', errors=self._tokenizer.errors)
                sample = sample.strip().lower().split(' ')
                text = ''.join(target)
                target = bytearray(
                    [self._tokenizer.byte_decoder[c] for c in text]
                ).decode('utf-8', errors=self._tokenizer.errors)
                target = target.strip().lower().split(' ')
            sps = sent_tokenize(' '.join(sample))
            tgs = sent_tokenize(' '.join(target))
            if self._reward_fn is not None:
                sp_reward_inputs.append(sps)
            sps = [sp.split(' ') for sp in sps]
            tgs = [tg.split(' ') for tg in tgs]
            sp_score = (self._w8[2] * compute_rouge_l_summ(sps, tgs)
                        + self._w8[0] * compute_rouge_n(
                            list(concat(sps)), list(concat(tgs)), n=1)
                        + self._w8[1] * compute_rouge_n(
                            list(concat(sps)), list(concat(tgs)), n=2))
            sp_scores.append(sp_score)
        sp_scores = torch.tensor(sp_scores, dtype=torch.float32,
                                 device=greedy_attns[0].device)
        if self._reward_fn is not None:
            sp_reward_scores, bl_reward_scores = \
                self._reward_fn.score_two_seqs(questions, sp_reward_inputs,
                                               bl_reward_inputs)
            sp_reward_scores = torch.tensor(sp_reward_scores,
                                            dtype=torch.float32,
                                            device=greedy_attns[0].device)
            bl_reward_scores = torch.tensor(bl_reward_scores,
                                            dtype=torch.float32,
                                            device=greedy_attns[0].device)
        reward = sp_scores.view(-1, 1) - bl_scores.view(-1, 1)
        if self._reward_fn is not None:
            reward += self._reward_w8 * (sp_reward_scores.view(-1, 1)
                                         - bl_reward_scores.view(-1, 1))
        reward.requires_grad_(False)
        loss = reward.contiguous().detach() * loss_nll
        loss = loss.sum()
        full_length = _masks.data.float().sum()
        loss = loss / full_length
        if self._ml_loss:
            ml_loss = self._ml_criterion(ml_logit, targets)
            loss += self._ml_loss_w8 * ml_loss.mean()
        # scale each micro-batch so gradients accumulate to the batch mean
        loss = loss / self._accumulate_g_step
        loss.backward()
    log_dict = {}
    if self._reward_fn is not None:
        log_dict['reward'] = bl_scores.mean().item()
        log_dict['question_reward'] = bl_reward_scores.mean().item()
        log_dict['sample_question_reward'] = sp_reward_scores.mean().item()
        log_dict['sample_reward'] = sp_scores.mean().item()
    else:
        log_dict['reward'] = bl_scores.mean().item()
    if self._grad_fn is not None:
        log_dict.update(self._grad_fn())
    self._opt.step()
    self._net.zero_grad()
    return log_dict
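# ---------------------------------------------------------------------------
# The accumulation pattern used in the loop above, in isolation: scale each
# micro-batch loss by 1/K, call backward immediately so the graph is freed,
# and step the optimizer once after K micro-batches.  A self-contained toy
# sketch, not the trainer's exact code:
import torch

model = torch.nn.Linear(8, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
K = 4  # plays the role of self._accumulate_g_step
opt.zero_grad()
for _ in range(K):
    x, y = torch.randn(2, 8), torch.randn(2, 1)
    loss = torch.nn.functional.mse_loss(model(x), y)
    (loss / K).backward()  # gradients accumulate to the K-batch mean
opt.step()
model.zero_grad()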