def compute_reward_old(pred_str, pred_sent_list, trg_str, trg_sent_list,
                       reward_type, regularization_factor=0.0,
                       regularization_type=0, entropy=None):
    if regularization_type == 1:
        raise ValueError("Not implemented.")
    elif regularization_type == 2:
        regularization = entropy
    else:
        regularization = 0.0
    if reward_type == 0:
        # weighted mix of ROUGE-1, ROUGE-2 and summary-level ROUGE-L F-scores
        tmp_reward = (0.3 * compute_rouge_n(pred_str, trg_str, n=1, mode='f')
                      + 0.2 * compute_rouge_n(pred_str, trg_str, n=2, mode='f')
                      + 0.5 * compute_rouge_l_summ(pred_sent_list,
                                                   trg_sent_list, mode='f'))
    elif reward_type == 1:
        tmp_reward = compute_rouge_l_summ(pred_sent_list, trg_sent_list,
                                          mode='f')
    else:
        raise ValueError('Unknown reward_type: {}'.format(reward_type))
    # Interpolate the regularization term into the reward only when a
    # regularization type is selected and the factor is non-zero.
    if regularization_type == 0 or regularization_factor == 0:
        reward = tmp_reward
    else:
        reward = ((1 - regularization_factor) * tmp_reward
                  + regularization_factor * regularization)
    return reward
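# Usage sketch for compute_reward_old with made-up token lists (the ROUGE
# helpers are the ones this module already uses; inputs are illustrative):
pred_str = 'the cat sat on the mat'.split()   # hypothetical prediction tokens
trg_str = 'a cat sat on a mat'.split()        # hypothetical reference tokens
# reward_type=0 mixes 0.3*ROUGE-1(F) + 0.2*ROUGE-2(F) + 0.5*ROUGE-L(F)
r = compute_reward_old(pred_str, [pred_str], trg_str, [trg_str],
                       reward_type=0)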
def rouge_l_reward(pred_str_list, pred_sent_2d_list, trg_str_list,
                   trg_sent_2d_list, batch_size, device):
    # Summary-level ROUGE-L F-score for each (prediction, target) pair.
    reward = []
    for pred_str, pred_sent_list, trg_str, trg_sent_list in zip(
            pred_str_list, pred_sent_2d_list, trg_str_list, trg_sent_2d_list):
        reward.append(
            compute_rouge_l_summ(pred_sent_list, trg_sent_list, mode='f'))
    return torch.FloatTensor(reward).to(device)  # tensor: [batch_size]
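# Usage sketch for rouge_l_reward on a hypothetical batch of two summaries,
# each consisting of a single tokenized sentence:
pred_strs = [['a', 'cat', 'sat'], ['dogs', 'bark']]
trg_strs = [['the', 'cat', 'sat'], ['dogs', 'barked']]
pred_sent_2d = [[s] for s in pred_strs]  # one sentence per summary
trg_sent_2d = [[s] for s in trg_strs]
scores = rouge_l_reward(pred_strs, pred_sent_2d, trg_strs, trg_sent_2d,
                        batch_size=2, device='cpu')  # FloatTensor, shape [2]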
def rl_validate(net, val_batches, coherence_func=None, coh_coef=0.01,
                local_coh_func=None, local_coh_coef=0.005):
    print('running validation ... ', end='')

    def argmax(arr, keys):
        return arr[max(range(len(arr)), key=lambda i: keys[i].item())]

    def sum_id2word(raw_article_sents, decs, attns):
        # Map decoder ids back to words; copy the most-attended source word
        # for UNKs and stop at the end-of-sequence token.
        dec_sents = []
        for i, raw_words in enumerate(raw_article_sents):
            dec = []
            for id_, attn in zip(decs, attns):
                if id_[i] == END:
                    break
                elif id_[i] == UNK:
                    dec.append(argmax(raw_words, attn[i]))
                else:
                    dec.append(id2word[id_[i].item()])
            dec_sents.append(dec)
        return dec_sents

    net.eval()
    start = time()
    i = 0
    score = 0
    score_coh = 0
    score_local_coh = 0
    with torch.no_grad():
        for fw_args, bw_args in val_batches:
            raw_articles = bw_args[0]
            id2word = bw_args[1]
            raw_targets = bw_args[2]
            greedies, greedy_attns = net.greedy(*fw_args)
            greedy_sents = sum_id2word(raw_articles, greedies, greedy_attns)
            bl_scores = []
            if coherence_func is not None:
                bl_coh_inputs = []
            if local_coh_func is not None:
                bl_local_coh_scores = []
            for baseline, target in zip(greedy_sents, raw_targets):
                bss = sent_tokenize(' '.join(baseline))
                if coherence_func is not None:
                    bl_coh_inputs.append(bss)
                if local_coh_func is not None:
                    local_coh_score = local_coh_func(bss)
                    bl_local_coh_scores.append(local_coh_score)
                bss = [bs.split(' ') for bs in bss]
                tgs = sent_tokenize(' '.join(target))
                tgs = [tg.split(' ') for tg in tgs]
                bl_score = compute_rouge_l_summ(bss, tgs)
                bl_scores.append(bl_score)
            bl_scores = torch.tensor(bl_scores, dtype=torch.float32,
                                     device=greedy_attns[0].device)
            if coherence_func is not None:
                input_args = (bl_coh_inputs,) + coherence_func
                bl_coh_scores = batch_global_infer(*input_args)
                bl_coh_scores = torch.tensor(bl_coh_scores,
                                             dtype=torch.float32,
                                             device=greedy_attns[0].device)
                score_coh += bl_coh_scores.mean().item() * 100 * coh_coef
            if local_coh_func is not None:
                bl_local_coh_scores = torch.tensor(
                    bl_local_coh_scores, dtype=torch.float32,
                    device=greedy_attns[0].device)
                score_local_coh += (bl_local_coh_scores.mean().item()
                                    * local_coh_coef * 100)
            reward = bl_scores.mean().item()
            i += 1
            score += reward * 100
    val_score = score / i
    if coherence_func is not None:
        val_coh_score = score_coh / i
    else:
        val_coh_score = 0
    val_local_coh_score = 0
    print('validation finished in {} '.format(
        timedelta(seconds=int(time() - start))))
    print('validation reward: {:.4f} ... '.format(val_score))
    if coherence_func is not None:
        print('validation coherence reward: {:.4f} ... '.format(
            val_coh_score))
    if local_coh_func is not None:
        val_local_coh_score = score_local_coh / i
        print('validation {} reward: {:.4f} ... '.format(
            local_coh_func.__name__, val_local_coh_score))
    print('n_data:', i)
    return {'score': val_score + val_coh_score + val_local_coh_score}
def train_step(self, sample_time=1):

    def argmax(arr, keys):
        return arr[max(range(len(arr)), key=lambda i: keys[i].item())]

    def sum_id2word(raw_article_sents, decs, attns, id2word):
        # Map decoder ids back to words; copy the most-attended source word
        # for UNKs and stop at the end-of-sequence token.
        dec_sents = []
        for i, raw_words in enumerate(raw_article_sents):
            dec = []
            for id_, attn in zip(decs, attns):
                if id_[i] == END:
                    break
                elif id_[i] == UNK:
                    dec.append(argmax(raw_words, attn[i]))
                else:
                    dec.append(id2word[id_[i].item()])
            dec_sents.append(dec)
        return dec_sents

    def pack_seq(seq_list):
        return torch.cat([s.unsqueeze(1) for s in seq_list], 1)

    # forward pass of model
    self._net.train()
    fw_args, bw_args = next(self._batches)
    raw_articles = bw_args[0]
    id2word = bw_args[1]
    raw_targets = bw_args[2]

    # Greedy decoding serves as the self-critical baseline; no gradients.
    with torch.no_grad():
        greedies, greedy_attns = self._net.greedy(*fw_args)
    greedy_sents = sum_id2word(raw_articles, greedies, greedy_attns, id2word)
    bl_scores = []
    if self._coh_fn is not None:
        bl_coh_inputs = []
    if self._local_coh_fun is not None:
        bl_local_coh_scores = []
    for baseline, target in zip(greedy_sents, raw_targets):
        bss = sent_tokenize(' '.join(baseline))
        tgs = sent_tokenize(' '.join(target))
        if self._coh_fn is not None:
            bl_coh_inputs.append(bss)
        if self._local_coh_fun is not None:
            local_coh_score = self._local_coh_fun(bss)
            bl_local_coh_scores.append(local_coh_score)
        bss = [bs.split(' ') for bs in bss]
        tgs = [tg.split(' ') for tg in tgs]
        # weighted ROUGE-1 / ROUGE-2 / summary-level ROUGE-L baseline reward
        bl_score = (self._weights[2] * compute_rouge_l_summ(bss, tgs)
                    + self._weights[0] * compute_rouge_n(
                        list(concat(bss)), list(concat(tgs)), n=1)
                    + self._weights[1] * compute_rouge_n(
                        list(concat(bss)), list(concat(tgs)), n=2))
        bl_scores.append(bl_score)
    bl_scores = torch.tensor(bl_scores, dtype=torch.float32,
                             device=greedy_attns[0].device)
    if self._coh_fn is not None:
        input_args = (bl_coh_inputs,) + self._coh_fn
        bl_coh_scores = batch_global_infer(*input_args)
        bl_coh_scores = torch.tensor(bl_coh_scores, dtype=torch.float32,
                                     device=greedy_attns[0].device)
    if self._local_coh_fun is not None:
        bl_local_coh_scores = torch.tensor(bl_local_coh_scores,
                                           dtype=torch.float32,
                                           device=greedy_attns[0].device)

    for t in range(sample_time):
        samples, sample_attns, seqLogProbs = self._net.sample(*fw_args)
        sample_sents = sum_id2word(raw_articles, samples, sample_attns,
                                   id2word)
        sp_seqs = pack_seq(samples)
        _masks = (sp_seqs > PAD).float()
        sp_seqLogProb = pack_seq(seqLogProbs)
        # mask out log-probs of padded positions
        loss_nll = (-sp_seqLogProb.squeeze(2)
                    * _masks.detach().type_as(sp_seqLogProb))
        sp_scores = []
        if self._coh_fn is not None:
            sp_coh_inputs = []
        if self._local_coh_fun is not None:
            sp_local_coh_scores = []
        for sample, target in zip(sample_sents, raw_targets):
            sps = sent_tokenize(' '.join(sample))
            tgs = sent_tokenize(' '.join(target))
            if self._coh_fn is not None:
                sp_coh_inputs.append(sps)
            if self._local_coh_fun is not None:
                local_coh_score = self._local_coh_fun(sps)
                sp_local_coh_scores.append(local_coh_score)
            sps = [sp.split(' ') for sp in sps]
            tgs = [tg.split(' ') for tg in tgs]
            sp_score = (self._weights[2] * compute_rouge_l_summ(sps, tgs)
                        + self._weights[0] * compute_rouge_n(
                            list(concat(sps)), list(concat(tgs)), n=1)
                        + self._weights[1] * compute_rouge_n(
                            list(concat(sps)), list(concat(tgs)), n=2))
            sp_scores.append(sp_score)
        sp_scores = torch.tensor(sp_scores, dtype=torch.float32,
                                 device=greedy_attns[0].device)
        # advantage: sampled reward minus greedy baseline
        reward = sp_scores.view(-1, 1) - bl_scores.view(-1, 1)
        reward.requires_grad_(False)
        if self._coh_fn is not None:
            input_args = (sp_coh_inputs,) + self._coh_fn
            sp_coh_scores = batch_global_infer(*input_args)
            sp_coh_scores = torch.tensor(sp_coh_scores, dtype=torch.float32,
                                         device=greedy_attns[0].device)
            reward_coh = sp_coh_scores.view(-1, 1) - bl_coh_scores.view(-1, 1)
            reward_coh.requires_grad_(False)
            reward = reward + self._coh_cof * reward_coh
        if self._local_coh_fun is not None:
            sp_local_coh_scores = torch.tensor(sp_local_coh_scores,
                                               dtype=torch.float32,
                                               device=greedy_attns[0].device)
            reward_local = (sp_local_coh_scores.view(-1, 1)
                            - bl_local_coh_scores.view(-1, 1))
            reward_local.requires_grad_(False)
            reward = reward + self._local_co_coef * reward_local
        if t == 0:
            loss = (reward.contiguous().detach() * loss_nll).sum()
            full_length = _masks.data.float().sum()
        else:
            loss += (reward.contiguous().detach() * loss_nll).sum()
            full_length += _masks.data.float().sum()
    loss = loss / full_length

    # backward and update (and optional gradient monitoring)
    loss.backward()
    log_dict = {}
    if self._coh_fn is not None:
        log_dict['reward'] = (bl_scores.mean().item()
                              + self._coh_cof * bl_coh_scores.mean().item())
    else:
        log_dict['reward'] = bl_scores.mean().item()
    if self._grad_fn is not None:
        log_dict.update(self._grad_fn())
    self._opt.step()
    self._net.zero_grad()
    torch.cuda.empty_cache()
    return log_dict
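# The train_step above (and the variants later in this section) follows the
# self-critical sequence training (SCST) pattern: sample a summary, score it
# against the greedy decode, and weight the masked token log-likelihood by
# that advantage. A minimal standalone sketch (tensor names and shapes are
# illustrative, not the project's exact API):
import torch

def scst_loss(sample_logp, mask, sample_reward, greedy_reward):
    # sample_logp:   [batch, len] log-probs of the sampled tokens
    # mask:          [batch, len] 1.0 for real tokens, 0.0 for padding
    # sample_reward: [batch] e.g. ROUGE of the sampled summaries
    # greedy_reward: [batch] e.g. ROUGE of the greedy baseline summaries
    advantage = (sample_reward - greedy_reward).view(-1, 1)  # [batch, 1]
    # minimizing -(advantage * logp) raises the likelihood of samples that
    # beat the greedy baseline and lowers it otherwise
    loss = -(advantage.detach() * sample_logp * mask).sum() / mask.sum()
    return loss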
def train_step(self, sample_time=1):
    # anomaly detection is a debugging aid; disable for production runs
    torch.autograd.set_detect_anomaly(True)

    def argmax(arr, keys):
        return arr[max(range(len(arr)), key=lambda i: keys[i].item())]

    def sum_id2word(raw_article_sents, decs, attns, id2word):
        # The BERT and non-BERT paths differ only in the end/unk ids, which
        # are already stored on self, so a single loop covers both.
        dec_sents = []
        for i, raw_words in enumerate(raw_article_sents):
            dec = []
            for id_, attn in zip(decs, attns):
                if id_[i] == self._end:
                    break
                elif id_[i] == self._unk:
                    dec.append(argmax(raw_words, attn[i]))
                else:
                    dec.append(id2word[id_[i].item()])
            dec_sents.append(dec)
        return dec_sents

    def pack_seq(seq_list):
        return torch.cat([s.unsqueeze(1) for s in seq_list], 1)

    # forward pass of model
    self._net.train()
    for i in range(self._accumulate_g_step):
        fw_args, bw_args = next(self._batches)
        raw_articles = bw_args[0]
        id2word = bw_args[1]
        raw_targets = bw_args[2]
        if self._reward_fn is not None:
            questions = bw_args[3]
            targets = bw_args[4]

        # Greedy decoding serves as the self-critical baseline; no gradients.
        with torch.no_grad():
            greedies, greedy_attns = self._net.greedy(*fw_args)
        greedy_sents = sum_id2word(raw_articles, greedies, greedy_attns,
                                   id2word)
        bl_scores = []
        if self._reward_fn is not None:
            bl_reward_scores = []
            bl_reward_inputs = []
        if self._local_coh_fun is not None:
            bl_local_coh_scores = []
        for baseline, target in zip(greedy_sents, raw_targets):
            if self._bert:
                # decode byte-pair tokens back to plain text
                text = ''.join(baseline)
                baseline = bytearray(
                    [self._tokenizer.byte_decoder[c] for c in text]
                ).decode('utf-8', errors=self._tokenizer.errors)
                baseline = baseline.strip().lower().split(' ')
                text = ''.join(target)
                target = bytearray(
                    [self._tokenizer.byte_decoder[c] for c in text]
                ).decode('utf-8', errors=self._tokenizer.errors)
                target = target.strip().lower().split(' ')
            bss = sent_tokenize(' '.join(baseline))
            tgs = sent_tokenize(' '.join(target))
            if self._reward_fn is not None:
                bl_reward_inputs.append(bss)
            if self._local_coh_fun is not None:
                local_coh_score = self._local_coh_fun(bss)
                bl_local_coh_scores.append(local_coh_score)
            bss = [bs.split(' ') for bs in bss]
            tgs = [tg.split(' ') for tg in tgs]
            bl_score = (self._w8[2] * compute_rouge_l_summ(bss, tgs)
                        + self._w8[0] * compute_rouge_n(
                            list(concat(bss)), list(concat(tgs)), n=1)
                        + self._w8[1] * compute_rouge_n(
                            list(concat(bss)), list(concat(tgs)), n=2))
            bl_scores.append(bl_score)
        bl_scores = torch.tensor(bl_scores, dtype=torch.float32,
                                 device=greedy_attns[0].device)

        # sample
        fw_args += (self._ml_loss, )
        if self._ml_loss:
            samples, sample_attns, seqLogProbs, ml_logit = self._net.sample(
                *fw_args)
        else:
            samples, sample_attns, seqLogProbs = self._net.sample(*fw_args)
        sample_sents = sum_id2word(raw_articles, samples, sample_attns,
                                   id2word)
        sp_seqs = pack_seq(samples)
        _masks = (sp_seqs > PAD).float()
        sp_seqLogProb = pack_seq(seqLogProbs)
        # mask out log-probs of padded positions
        loss_nll = (-sp_seqLogProb.squeeze(2)
                    * _masks.detach().type_as(sp_seqLogProb))
        sp_scores = []
        if self._reward_fn is not None:
            sp_reward_inputs = []
        for sample, target in zip(sample_sents, raw_targets):
            if self._bert:
                text = ''.join(sample)
                sample = bytearray(
                    [self._tokenizer.byte_decoder[c] for c in text]
                ).decode('utf-8', errors=self._tokenizer.errors)
                sample = sample.strip().lower().split(' ')
                text = ''.join(target)
                target = bytearray(
                    [self._tokenizer.byte_decoder[c] for c in text]
                ).decode('utf-8', errors=self._tokenizer.errors)
                target = target.strip().lower().split(' ')
            sps = sent_tokenize(' '.join(sample))
            tgs = sent_tokenize(' '.join(target))
            if self._reward_fn is not None:
                sp_reward_inputs.append(sps)
            sps = [sp.split(' ') for sp in sps]
            tgs = [tg.split(' ') for tg in tgs]
            sp_score = (self._w8[2] * compute_rouge_l_summ(sps, tgs)
                        + self._w8[0] * compute_rouge_n(
                            list(concat(sps)), list(concat(tgs)), n=1)
                        + self._w8[1] * compute_rouge_n(
                            list(concat(sps)), list(concat(tgs)), n=2))
            sp_scores.append(sp_score)
        sp_scores = torch.tensor(sp_scores, dtype=torch.float32,
                                 device=greedy_attns[0].device)
        if self._reward_fn is not None:
            sp_reward_scores, bl_reward_scores = \
                self._reward_fn.score_two_seqs(questions, sp_reward_inputs,
                                               bl_reward_inputs)
            sp_reward_scores = torch.tensor(sp_reward_scores,
                                            dtype=torch.float32,
                                            device=greedy_attns[0].device)
            bl_reward_scores = torch.tensor(bl_reward_scores,
                                            dtype=torch.float32,
                                            device=greedy_attns[0].device)

        # advantage: sampled reward minus greedy baseline
        reward = sp_scores.view(-1, 1) - bl_scores.view(-1, 1)
        if self._reward_fn is not None:
            reward += self._reward_w8 * (sp_reward_scores.view(-1, 1)
                                         - bl_reward_scores.view(-1, 1))
        reward.requires_grad_(False)
        loss = (reward.contiguous().detach() * loss_nll).sum()
        full_length = _masks.data.float().sum()
        loss = loss / full_length
        if self._ml_loss:
            # mix in the maximum-likelihood objective
            ml_loss = self._ml_criterion(ml_logit, targets)
            loss += self._ml_loss_w8 * ml_loss.mean()
        # scale so gradients accumulated over the micro-batches average out
        loss = loss / self._accumulate_g_step
        loss.backward()

    log_dict = {}
    if self._reward_fn is not None:
        log_dict['reward'] = bl_scores.mean().item()
        log_dict['question_reward'] = bl_reward_scores.mean().item()
        log_dict['sample_question_reward'] = sp_reward_scores.mean().item()
        log_dict['sample_reward'] = sp_scores.mean().item()
    else:
        log_dict['reward'] = bl_scores.mean().item()
    if self._grad_fn is not None:
        log_dict.update(self._grad_fn())
    self._opt.step()
    self._net.zero_grad()
    return log_dict
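# The train_step above spreads one optimizer step over
# self._accumulate_g_step micro-batches: each loss is scaled by 1/k before
# backward() so the summed gradients equal the average. A minimal sketch
# with placeholder model / loss_fn / batch iterator (not the project's real
# API):
def accumulated_step(model, opt, batches, loss_fn, k):
    model.train()
    for _ in range(k):
        inputs, targets = next(batches)
        loss = loss_fn(model(inputs), targets) / k  # scale per micro-batch
        loss.backward()  # gradients accumulate across backward() calls
    opt.step()
    model.zero_grad()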
def coll(tokenizer, batch):
    blank = '[unused0]'
    questions, context, _ids, abstract = list(
        filter(bool, list(zip(*batch))))
    system = context
    rouges = {}  # avoids a NameError below when no abstract is given
    if len(abstract) > 0:
        rouges = {
            'RLr': compute_rouge_l_summ(
                [_c.lower().split(' ') for _c in context[0]],
                abstract[0], mode='r'),
            'RLf': compute_rouge_l_summ(
                [_c.lower().split(' ') for _c in context[0]],
                abstract[0], mode='f'),
            'R1': compute_rouge_n(' '.join(context[0]).split(' '),
                                  list(concat(abstract[0])), mode='f', n=1),
            'R2': compute_rouge_n(' '.join(context[0]).split(' '),
                                  list(concat(abstract[0])), mode='f', n=2)
        }
    if len(questions[0]) != 0:
        questions = questions[0]
        context = context[0]
        context = ' '.join(context)
        choicess = [[question['answer'], question['choice1'],
                     question['choice2'], question['choice3']]
                    for question in questions]
        questions = [question['question'].replace('<\\blank>', blank)
                     for question in questions]
        questions = [[tokenizer.tokenize(qp.lower())
                      for qp in question.split(blank)]
                     for question in questions]
        # re-join the question pieces around the blank marker
        new_questions = []
        for question in questions:
            new_q = ['[CLS]']
            for q in question:
                new_q += q + [blank]
            new_q.pop()
            new_questions.append(new_q)
        questions = new_questions
        contexts = [['[SEP]'] + tokenizer.tokenize(context.lower())
                    for _ in range(len(questions))]
        choicess = [[[tokenizer.tokenize(c.lower()) for c in choice]
                     for choice in choices] for choices in choicess]
        choicess = [[['[SEP]'] + choice[0] + ['[SEP]'] + choice[1]
                     if len(choice) == 2 else ['[SEP]'] + choice[0]
                     for choice in choices] for choices in choicess]
        _inputs = [[tokenizer.convert_tokens_to_ids(
                        (question + context + choice)[:MAX_LEN])
                    for choice in choices]
                   for question, context, choices in zip(
                       questions, contexts, choicess)]
        _inputs = pad_batch_tensorize_3d(_inputs, pad=0, cuda=False)
    else:
        _inputs = []
    return (_inputs, rouges, system, abstract)
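# The token layout coll builds for each (question, choice) pair is
# [CLS] question-with-blank [SEP] context [SEP] choice, truncated to
# MAX_LEN. A worked sketch with a hypothetical question dict:
question = {'question': 'the capital of france is <\\blank> .',
            'answer': 'paris', 'choice1': 'london',
            'choice2': 'berlin', 'choice3': 'rome'}
# after tokenization, the candidate sequence for the answer choice becomes
#   ['[CLS]', 'the', 'capital', 'of', 'france', 'is', '[unused0]', '.',
#    '[SEP]', <context tokens ...>, '[SEP]', 'paris']
# before being converted to ids and padded by pad_batch_tensorize_3d.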
def rl_validate(net, val_batches, reward_func=None, reward_coef=0.01,
                local_coh_func=None, local_coh_coef=0.005, bert=False):
    print('running validation ... ', end='')
    if bert:
        tokenizer = net._bert_model._tokenizer
        end = tokenizer.encoder[tokenizer._eos_token]
        unk = tokenizer.encoder[tokenizer._unk_token]

    def argmax(arr, keys):
        return arr[max(range(len(arr)), key=lambda i: keys[i].item())]

    def sum_id2word(raw_article_sents, decs, attns):
        # identical loops for BERT and word-level vocabularies; only the
        # end/unk token ids differ
        end_id = end if bert else END
        unk_id = unk if bert else UNK
        dec_sents = []
        for i, raw_words in enumerate(raw_article_sents):
            dec = []
            for id_, attn in zip(decs, attns):
                if id_[i] == end_id:
                    break
                elif id_[i] == unk_id:
                    dec.append(argmax(raw_words, attn[i]))
                else:
                    dec.append(id2word[id_[i].item()])
            dec_sents.append(dec)
        return dec_sents

    net.eval()
    start = time()
    i = 0
    score = 0
    score_reward = 0
    score_local_coh = 0
    score_r2 = 0
    score_r1 = 0
    with torch.no_grad():
        for fw_args, bw_args in val_batches:
            raw_articles = bw_args[0]
            id2word = bw_args[1]
            raw_targets = bw_args[2]
            if reward_func is not None:
                questions = bw_args[3]
            greedies, greedy_attns = net.greedy(*fw_args)
            greedy_sents = sum_id2word(raw_articles, greedies, greedy_attns)
            bl_scores = []
            # reset per batch so every batch is weighted equally in the
            # running ROUGE-1/ROUGE-2 totals
            bl_r1 = []
            bl_r2 = []
            if reward_func is not None:
                bl_coh_inputs = []
            for baseline, target in zip(greedy_sents, raw_targets):
                if bert:
                    text = ''.join(baseline)
                    baseline = bytearray(
                        [tokenizer.byte_decoder[c] for c in text]
                    ).decode('utf-8', errors=tokenizer.errors)
                    baseline = baseline.split(' ')
                    text = ''.join(target)
                    target = bytearray(
                        [tokenizer.byte_decoder[c] for c in text]
                    ).decode('utf-8', errors=tokenizer.errors)
                    target = target.split(' ')
                bss = sent_tokenize(' '.join(baseline))
                if reward_func is not None:
                    bl_coh_inputs.append(bss)
                bss = [bs.split(' ') for bs in bss]
                tgs = sent_tokenize(' '.join(target))
                tgs = [tg.split(' ') for tg in tgs]
                bl_score = compute_rouge_l_summ(bss, tgs)
                bl_r1.append(compute_rouge_n(list(concat(bss)),
                                             list(concat(tgs)), n=1))
                bl_r2.append(compute_rouge_n(list(concat(bss)),
                                             list(concat(tgs)), n=2))
                bl_scores.append(bl_score)
            bl_scores = torch.tensor(bl_scores, dtype=torch.float32,
                                     device=greedy_attns[0].device)
            if reward_func is not None:
                bl_reward_scores = reward_func.score(questions,
                                                     bl_coh_inputs)
                bl_reward_scores = torch.tensor(
                    bl_reward_scores, dtype=torch.float32,
                    device=greedy_attns[0].device)
                score_reward += bl_reward_scores.mean().item() * 100
            reward = bl_scores.mean().item()
            i += 1
            score += reward * 100
            score_r2 += torch.tensor(
                bl_r2, dtype=torch.float32,
                device=greedy_attns[0].device).mean().item() * 100
            score_r1 += torch.tensor(
                bl_r1, dtype=torch.float32,
                device=greedy_attns[0].device).mean().item() * 100
    val_score = score / i
    score_r2 = score_r2 / i
    score_r1 = score_r1 / i
    if reward_func is not None:
        val_reward_score = score_reward / i
    else:
        val_reward_score = 0
    val_local_coh_score = 0
    print('validation finished in {} '.format(
        timedelta(seconds=int(time() - start))))
    print('validation reward: {:.4f} ... '.format(val_score))
    print('validation r2: {:.4f} ... '.format(score_r2))
    print('validation r1: {:.4f} ... '.format(score_r1))
    if reward_func is not None:
        print('validation question reward: {:.4f} ... '.format(
            val_reward_score))
    if local_coh_func is not None:
        val_local_coh_score = score_local_coh / i
        print('validation {} reward: {:.4f} ... '.format(
            local_coh_func.__name__, val_local_coh_score))
    print('n_data:', i)
    return {'score': val_score, 'score_reward': val_reward_score}
def train_step(self, sample_time=1):

    def argmax(arr, keys):
        return arr[max(range(len(arr)), key=lambda i: keys[i].item())]

    def sum_id2word(raw_article_sents, decs, attns, id2word):
        dec_sents = []
        for i, raw_words in enumerate(raw_article_sents):
            dec = []
            for id_, attn in zip(decs, attns):
                if id_[i] == END:
                    break
                elif id_[i] == UNK:
                    dec.append(argmax(raw_words, attn[i]))
                else:
                    dec.append(id2word[id_[i].item()])
            dec_sents.append(dec)
        return dec_sents

    def pack_seq(seq_list):
        return torch.cat([s.unsqueeze(1) for s in seq_list], 1)

    # forward pass of model
    self._net.train()
    fw_args, bw_args = next(self._batches)
    raw_articles = bw_args[0]
    id2word = bw_args[1]
    raw_targets = bw_args[2]
    with torch.no_grad():
        greedies, greedy_attns = self._net.greedy(*fw_args)
    greedy_sents = sum_id2word(raw_articles, greedies, greedy_attns, id2word)
    bl_scores = []
    for baseline, target in zip(greedy_sents, raw_targets):
        bss = sent_tokenize(' '.join(baseline))
        tgs = sent_tokenize(' '.join(target))
        bss = [bs.split(' ') for bs in bss]
        tgs = [tg.split(' ') for tg in tgs]
        bss_bleu = ' '.join(list(concat(bss)))
        tgs_bleu = ' '.join(list(concat(tgs)))
        # reward is BLEU, F1, or a weighted ROUGE mix, depending on config
        if self._bleu:
            bleu_scores = bleu(bss_bleu, tgs_bleu)
            bl_score = (bleu_scores[0] + bleu_scores[1]
                        + bleu_scores[2] + bleu_scores[3])
        elif self.f1:
            bl_score = compute_f1(bss_bleu, tgs_bleu)
        else:
            bl_score = (self._w8[2] * compute_rouge_l_summ(bss, tgs)
                        + self._w8[0] * compute_rouge_n(
                            list(concat(bss)), list(concat(tgs)), n=1)
                        + self._w8[1] * compute_rouge_n(
                            list(concat(bss)), list(concat(tgs)), n=2))
        bl_scores.append(bl_score)
    bl_scores = torch.tensor(bl_scores, dtype=torch.float32,
                             device=greedy_attns[0].device)

    samples, sample_attns, seqLogProbs = self._net.sample(*fw_args)
    sample_sents = sum_id2word(raw_articles, samples, sample_attns, id2word)
    sp_seqs = pack_seq(samples)
    _masks = (sp_seqs > PAD).float()
    sp_seqLogProb = pack_seq(seqLogProbs)
    loss_nll = (-sp_seqLogProb.squeeze(2)
                * _masks.detach().type_as(sp_seqLogProb))
    sp_scores = []
    for sample, target in zip(sample_sents, raw_targets):
        sps = sent_tokenize(' '.join(sample))
        tgs = sent_tokenize(' '.join(target))
        sps = [sp.split(' ') for sp in sps]
        tgs = [tg.split(' ') for tg in tgs]
        sps_bleu = ' '.join(list(concat(sps)))
        tgs_bleu = ' '.join(list(concat(tgs)))
        if self._bleu:
            bleu_scores = bleu(sps_bleu, tgs_bleu)
            sp_score = (bleu_scores[0] + bleu_scores[1]
                        + bleu_scores[2] + bleu_scores[3])
        elif self.f1:
            sp_score = compute_f1(sps_bleu, tgs_bleu)
        else:
            sp_score = (self._w8[2] * compute_rouge_l_summ(sps, tgs)
                        + self._w8[0] * compute_rouge_n(
                            list(concat(sps)), list(concat(tgs)), n=1)
                        + self._w8[1] * compute_rouge_n(
                            list(concat(sps)), list(concat(tgs)), n=2))
        sp_scores.append(sp_score)
    sp_scores = torch.tensor(sp_scores, dtype=torch.float32,
                             device=greedy_attns[0].device)
    reward = sp_scores.view(-1, 1) - bl_scores.view(-1, 1)
    reward.requires_grad_(False)
    loss = (reward.contiguous().detach() * loss_nll).sum()
    full_length = _masks.data.float().sum()
    loss = loss / full_length
    loss.backward()
    log_dict = {}
    log_dict['reward'] = bl_scores.mean().item()
    if self._grad_fn is not None:
        log_dict.update(self._grad_fn())
    self._opt.step()
    self._net.zero_grad()
    return log_dict
def multi_reward(x, y):
    # summary-level ROUGE-L + ROUGE-2 + (1/3) * ROUGE-1
    return (compute_rouge_l_summ(x, y)
            + compute_rouge_n(list_cat(x), list_cat(y), n=2)
            + 1 / 3 * compute_rouge_n(list_cat(x), list_cat(y), n=1))
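# Usage sketch for multi_reward. list_cat is assumed to flatten a list of
# token lists (a guess consistent with how it is used here; the project may
# define it elsewhere):
from itertools import chain

def list_cat(xs):
    return list(chain.from_iterable(xs))

x = [['the', 'cat', 'sat'], ['it', 'was', 'happy']]       # hypothetical summary
y = [['a', 'cat', 'sat', 'down'], ['it', 'was', 'glad']]  # hypothetical reference
r = multi_reward(x, y)  # ROUGE-L(summ) + ROUGE-2 + (1/3) * ROUGE-1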
def sc_validate(agent, extractor, loader):
    agent.eval()
    start = time()
    print('start running validation...')
    ave_reward = 0
    ave_reward_1 = 0
    ave_reward_2 = 0
    batch_num = 0
    for art_batch, abs_batch, ext_batch in loader:
        ext_sents = []
        sum_sents = []
        indices = []
        for raw_arts in art_batch:
            inds = extractor(raw_arts)
            inds = [i.item() for i in inds]
            indices.append(inds)
            # indices below len(raw_arts) copy a sentence verbatim; indices
            # in [len(raw_arts), 2 * len(raw_arts)) send it to the abstractor
            ext_sents += [raw_arts[idx] for idx in inds
                          if idx < len(raw_arts)]
            sum_sents += [raw_arts[idx - len(raw_arts)] for idx in inds
                          if len(raw_arts) <= idx < len(raw_arts) * 2]
        # -------- batch decode --------
        if sum_sents:
            dec_greedy = agent(sum_sents)
            # merge extracted and abstracted sentences back into order
            count_ext = 0
            count_sum = 0
            dec_greedy_mix = []
            for indice in indices:
                max_step = indice[-1] / 2
                for ind in indice:
                    if ind < max_step:
                        dec_greedy_mix.append(ext_sents[count_ext])
                        count_ext += 1
                    elif max_step <= ind < max_step * 2:
                        dec_greedy_mix.append(dec_greedy[count_sum])
                        count_sum += 1
        else:
            dec_greedy_mix = ext_sents
        i = 0
        for ex, abss in zip(indices, abs_batch):
            ex = ex[:-1]  # drop the stop action
            ave_reward += compute_rouge_l_summ(
                dec_greedy_mix[i:i + len(ex)], abss)
            ave_reward_1 += compute_rouge_n(
                list_cat(dec_greedy_mix[i:i + len(ex)]),
                list(concat(abss)), n=1)
            ave_reward_2 += compute_rouge_n(
                list_cat(dec_greedy_mix[i:i + len(ex)]),
                list(concat(abss)), n=2)
            i += len(ex)
        batch_num += 1
        assert i == len(dec_greedy_mix)
    # equivalent to (summed score * 100) / batch_num
    ave_reward /= (batch_num / 100)
    ave_reward_1 /= (batch_num / 100)
    ave_reward_2 /= (batch_num / 100)
    print('finished in {}! avg reward: {:.2f} rouge-1: {:.2f} rouge-2: {:.2f}'
          .format(timedelta(seconds=int(time() - start)),
                  ave_reward, ave_reward_1, ave_reward_2))
    return {'reward': ave_reward}
def get_extract_label(art_sents, abs_sents):
    """Greedily match summary sentences to article sentences."""
    extracted = []
    scores = []
    # sentinel sentence at index 0 so real sentence indices start at 1
    if ''.join(art_sents[0]) != '<E>':
        art_sents.insert(0, '<E>')
    indices = list(range(len(art_sents)))
    new_abs_sents = []
    for j in range(len(abs_sents)):
        # ROUGE-L recall of every article sentence against summary sentence j
        rouges = list(map(compute_rouge_l(reference=abs_sents[j], mode='r'),
                          art_sents[1:]))
        rouges.insert(0, 0)
        ext = max(indices, key=lambda i: rouges[i])
        max_scores = rouges[ext]
        max_exts = collections.Counter(rouges)[max_scores]
        if max_exts != 1:
            # break ties among recall maxima by ROUGE-L F-score
            max_inds = []
            rouge_f = []
            for idx, score in enumerate(rouges):
                if idx in indices and score == max_scores:
                    max_inds.append(idx)
                    rouge_f.append(compute_rouge_l_summ(
                        art_sents[idx], abs_sents[j], mode='f'))
            maxrouge = max(range(len(max_inds)), key=lambda i: rouge_f[i])
            ext = max_inds[maxrouge]
        if ext == 0:
            ext = 1
        # try pairing the chosen sentence with every other sentence to see
        # whether a two-sentence extraction matches the summary better
        new_art_sents = ['<E>']
        for i in range(1, len(art_sents)):
            if i < ext:
                new_art_sents.append(art_sents[i] + art_sents[ext])
            elif i > ext:
                new_art_sents.append(art_sents[ext] + art_sents[i])
            else:
                new_art_sents.append(art_sents[ext])
        new_rouges = list(map(compute_rouge_l_summ(refs=abs_sents[j],
                                                   mode='fr'),
                              new_art_sents[1:]))
        new_rouges_f = [fr[0] for fr in new_rouges]
        new_rouges_r = [fr[1] for fr in new_rouges]
        new_rouges_f.insert(0, 0)
        new_rouges_r.insert(0, 0)
        new_ext = max(indices, key=lambda i: new_rouges_f[i])
        if new_ext == 0:
            new_ext = 1
        # 0 acts as a separator between the extractions for each summary
        # sentence
        if ext == new_ext or rouges[ext] >= new_rouges_r[new_ext]:
            extracted.append(ext)
            extracted.append(0)
            scores.append(new_rouges_f[ext])
        elif ext < new_ext:
            extracted.append(ext)
            extracted.append(new_ext)
            extracted.append(0)
            scores.append(new_rouges_f[new_ext])
        else:
            extracted.append(new_ext)
            extracted.append(ext)
            extracted.append(0)
            scores.append(new_rouges_f[new_ext])
        # reduce duplication: if extractions ab->A and bc->B overlap,
        # merge them into abc->AB
        new_abs_sents.append(abs_sents[j])
        index = findindex(extracted, 0)
        while len(index) >= 2:
            if len(index) == 2:
                overlap = list(set(extracted[:index[-2]])
                               & set(extracted[index[-2] + 1:index[-1]]))
                if len(overlap) > 0:
                    new = sorted(set(extracted[:index[-2]]).union(
                        set(extracted[index[-2] + 1:index[-1]])))
                    del extracted[:index[-1] + 1]
                    extracted = extracted + new
                    extracted.append(0)
                    new_sent = new_abs_sents[-2] + new_abs_sents[-1]
                    del new_abs_sents[-2:]
                    new_abs_sents.append(new_sent)
                    index = findindex(extracted, 0)
                else:
                    break
            else:
                overlap = list(set(extracted[index[-3] + 1:index[-2]])
                               & set(extracted[index[-2] + 1:index[-1]]))
                if len(overlap) > 0:
                    new = sorted(set(extracted[index[-3] + 1:index[-2]]).union(
                        set(extracted[index[-2] + 1:index[-1]])))
                    del extracted[index[-3] + 1:index[-1] + 1]
                    extracted = extracted + new
                    extracted.append(0)
                    new_sent = new_abs_sents[-2] + new_abs_sents[-1]
                    del new_abs_sents[-2:]
                    new_abs_sents.append(new_sent)
                    index = findindex(extracted, 0)
                else:
                    break
        # sentences already assigned to earlier extractions are no longer
        # candidates
        if len(index) >= 2:
            if len(index) == 2:
                for idx in extracted[:index[-2]]:
                    try:
                        indices.remove(idx)
                    except ValueError:
                        continue
            if len(index) > 2:
                for idx in extracted[index[-3]:index[-2]]:
                    try:
                        indices.remove(idx)
                    except ValueError:
                        continue
        if not indices:
            break
    # interleave '<E>' separators at positions 1, 3, 5, ...
    length = len(new_abs_sents)
    for i in range(length):
        new_abs_sents.insert(2 * i + 1, ['<E>'])
    return extracted, scores, new_abs_sents, art_sents
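# get_extract_label relies on a findindex helper that is not shown here.
# From its use above (indexed with [-1]/[-2]/[-3] and measured with len()),
# it returns every position of a value in a list; a sketch of a compatible
# implementation:
def findindex(lst, value):
    return [i for i, x in enumerate(lst) if x == value]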
def a2c_train_step(agent, abstractor, loader, opt, grad_fn, gamma=0.99,
                   reward_fn=compute_rouge_l,
                   stop_reward_fn=compute_rouge_n(n=1),
                   stop_coeff=1.0, step=0, reward_plan=1):
    opt.zero_grad()
    indices = []
    probs = []
    baselines = []
    ext_sents = []
    sum_sents = []
    art_batch, abs_batch = next(loader)
    for raw_arts in art_batch:
        (inds, ms), bs = agent(raw_arts, step=step)
        baselines.append(bs)
        indices.append(inds)
        probs.append(ms)
        # first half of the action space extracts verbatim, second half
        # routes the sentence through the abstractor
        ext_sents += [raw_arts[idx.item()] for idx in inds
                      if idx.item() < len(raw_arts)]
        sum_sents += [raw_arts[idx.item() - len(raw_arts)] for idx in inds
                      if len(raw_arts) <= idx.item() < len(raw_arts) * 2]
    if sum_sents:
        with torch.no_grad():
            sum_sents = abstractor(sum_sents)
        # mix the exts and sums back into their original order
        count_ext = 0
        count_sum = 0
        summaries = []
        for indice in indices:
            max_step = indice[-1].item() / 2
            for ind in indice[:-1]:
                if ind.item() < max_step:
                    summaries.append(ext_sents[count_ext])
                    count_ext += 1
                elif max_step <= ind.item() < max_step * 2:
                    summaries.append(sum_sents[count_sum])
                    count_sum += 1
    else:
        summaries = ext_sents
    assert len(summaries) == len(ext_sents) + len(sum_sents)
    copy_rate = len(ext_sents) / len(summaries)
    i = 0
    rewards = []
    avg_reward = 0
    for inds, abss in zip(indices, abs_batch):
        if reward_plan == 0:
            # plain: marginal summary-level ROUGE-L per step, capped at the
            # number of reference sentences; extra steps get zero reward
            reward_save = [compute_rouge_l_summ(summaries[i:i + j + 1], abss)
                           for j in range(min(len(inds) - 1, len(abss)))]
            rs = ([(reward_save[j] - reward_save[j - 1]) if j > 0
                   else reward_save[j]
                   for j in range(min(len(inds) - 1, len(abss)))]
                  + [0 for _ in range(max(0, len(inds) - 1 - len(abss)))]
                  + [stop_coeff * stop_reward_fn(
                      list(concat(summaries[i:i + len(inds) - 1])),
                      list(concat(abss)))])
        elif reward_plan == 1:
            # margin: marginal summary-level ROUGE-L for every step
            reward_save = [compute_rouge_l_summ(summaries[i:i + j + 1], abss)
                           for j in range(len(inds) - 1)]
            rs = ([(reward_save[j] - reward_save[j - 1]) if j > 0
                   else reward_save[j]
                   for j in range(len(inds) - 1)]
                  + [stop_coeff * stop_reward_fn(
                      list(concat(summaries[i:i + len(inds) - 1])),
                      list(concat(abss)))])
        assert len(rs) == len(inds)
        avg_reward += rs[-1] / stop_coeff
        i += len(inds) - 1
        # compute discounted rewards
        R = 0
        disc_rs = []
        for r in rs[::-1]:
            R = r + gamma * R
            disc_rs.insert(0, R)
        rewards += disc_rs
    indices = list(concat(indices))
    probs = list(concat(probs))
    baselines = list(concat(baselines))
    # standardize rewards
    reward = torch.Tensor(rewards).to(baselines[0].device)
    reward = (reward - reward.mean()) / (
        reward.std() + float(np.finfo(np.float32).eps))
    baseline = torch.cat(baselines).squeeze()
    avg_advantage = 0
    losses = []
    for action, p, r, b in zip(indices, probs, reward, baseline):
        advantage = r - b
        avg_advantage += advantage
        losses.append(-p.log_prob(action)
                      * (advantage / len(indices)))  # divide by T*B
    critic_loss = F.mse_loss(baseline, reward)
    # backprop and update
    autograd.backward(
        [critic_loss] + losses,
        [torch.ones(1).to(critic_loss.device)] * (1 + len(losses)))
    grad_log = grad_fn()
    opt.step()
    log_dict = {}
    log_dict.update(grad_log)
    log_dict['reward'] = avg_reward / len(art_batch)
    log_dict['advantage'] = avg_advantage.item() / len(indices)
    log_dict['mse'] = critic_loss.item()
    log_dict['copy_rate'] = copy_rate
    assert not math.isnan(log_dict['grad_norm'])
    return log_dict
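# The discounted-return loop inside a2c_train_step is the standard backward
# recursion R_t = r_t + gamma * R_{t+1}. A standalone sketch:
def discounted_returns(rs, gamma=0.99):
    R = 0.0
    out = []
    for r in reversed(rs):
        R = r + gamma * R
        out.append(R)
    return out[::-1]

# e.g. discounted_returns([0.1, 0.2, 1.0], gamma=0.9) -> [1.09, 1.1, 1.0]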