def train_rl_1(one2many_batch, model, optimizer, generator, opt, reward_cache):
    """REINFORCE update on one batch, using a running average of past rewards as the baseline."""
    src_list, src_len, trg_list, _, trg_copy_target_list, src_oov_map_list, oov_list = one2many_batch

    if torch.cuda.is_available():
        src_list = src_list.cuda()
        src_oov_map_list = src_oov_map_list.cuda()

    # Sample number_batch sequences
    sampled_seqs_list = generator.sample(src_list, src_len, src_oov_map_list, oov_list, opt.word2id, k=5, is_greedy=False)

    policy_loss = []
    policy_rewards = []
    # Compute their rewards and losses
    for seq_i, (src, trg, trg_copy, sampled_seqs, oov) in enumerate(zip(src_list, trg_list, trg_copy_target_list, sampled_seqs_list, oov_list)):
        # convert sampled id sequences to string sequences and truncate at the first EOS
        sampled_str_seqs = [[opt.id2word[x] if x < opt.vocab_size else oov[x - opt.vocab_size] for x in to_cpu_list(seq.sentence)] for seq in sampled_seqs]
        sampled_str_seqs = [seq[:seq.index(pykp.io.EOS_WORD) + 1] if pykp.io.EOS_WORD in seq else seq for seq in sampled_str_seqs]

        # convert target id sequences to string sequences
        trg_seqs = [[opt.id2word[x] if x < opt.vocab_size else oov[x - opt.vocab_size] for x in seq] for seq in trg_copy]
        # trg_seqs = [seq + [pykp.io.EOS_WORD] * (opt.max_sent_length - len(seq)) for seq in trg_seqs]

        # local rewards (bleu)
        bleu_samples = get_match_result(true_seqs=trg_seqs, pred_seqs=sampled_str_seqs, type='bleu')

        # global rewards (exact-match F-score)
        match_samples = get_match_result(true_seqs=trg_seqs, pred_seqs=sampled_str_seqs, type='exact')
        _, _, fscore_samples = evaluate.evaluate(match_samples, sampled_str_seqs, trg_seqs, topk=5)

        # compute the final rewards as a mixture of the local and global rewards
        alpha = 0.0
        rewards = alpha * np.asarray(bleu_samples) + (1.0 - alpha) * fscore_samples

        # the baseline is the running average of previously observed rewards
        baseline = reward_cache.get_average()
        for reward in rewards:
            reward_cache.push(float(reward))

        # REINFORCE: scale the negative log-likelihood of each sample by its advantage
        for seq, reward in zip(sampled_seqs, rewards):
            policy_loss.append(-torch.stack(seq.logprobs, dim=0).sum() * float(reward - baseline))
        policy_rewards.extend(rewards)

    optimizer.zero_grad()
    policy_loss = torch.stack(policy_loss).mean() * (1 - opt.loss_scale)
    policy_loss.backward()
    if opt.max_grad_norm > 0:
        pre_norm = torch.nn.utils.clip_grad_norm(model.parameters(), opt.max_grad_norm)
        after_norm = (sum([p.grad.data.norm(2) ** 2 for p in model.parameters() if p.grad is not None])) ** (1.0 / 2)
        # logging.info('clip grad (%f -> %f)' % (pre_norm, after_norm))
    optimizer.step()

    return np.average(policy_rewards)
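# `reward_cache` is not defined in this section; train_rl_1 only relies on it exposing
# push(reward) and get_average(). Below is a minimal sketch of such a running-average
# baseline under that two-method assumption -- the class name, default capacity, and
# deque-based implementation are illustrative, not the project's actual code.
from collections import deque


class RewardCache(object):
    """Running-average reward baseline: keeps the most recent `capacity` rewards."""

    def __init__(self, capacity=1000):
        self.buffer = deque(maxlen=capacity)

    def push(self, reward):
        self.buffer.append(reward)

    def get_average(self):
        # Return 0.0 before any reward has been observed to avoid dividing by zero.
        return sum(self.buffer) / len(self.buffer) if self.buffer else 0.0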
def train_rl(one2many_batch, model, optimizer, generator, opt):
    """Self-critical REINFORCE update on one batch: greedy decoding provides the baseline."""
    src_list, src_len, trg_list, _, trg_copy_target_list, src_oov_map_list, oov_list = one2many_batch

    if torch.cuda.is_available():
        src_list = src_list.cuda()
        src_oov_map_list = src_oov_map_list.cuda()

    # Baseline sequences for the self-critic (greedy decoding)
    baseline_seqs_list = generator.sample(src_list, src_len, src_oov_map_list, oov_list, opt.word2id, k=5, is_greedy=True)
    # Sample number_batch * beam_size sequences
    sampled_seqs_list = generator.sample(src_list, src_len, src_oov_map_list, oov_list, opt.word2id, k=5, is_greedy=False)

    policy_loss = []
    policy_rewards = []
    # Compute their rewards and losses
    for seq_i, (src, trg, trg_copy, sampled_seqs, baseline_seqs, oov) in enumerate(zip(src_list, trg_list, trg_copy_target_list, sampled_seqs_list, baseline_seqs_list, oov_list)):
        # convert baseline and sampled id sequences to string sequences and truncate at the first EOS
        baseline_str_seqs = [[opt.id2word[x] if x < opt.vocab_size else oov[x - opt.vocab_size] for x in seq.sentence] for seq in baseline_seqs]
        baseline_str_seqs = [seq[:seq.index(pykp.io.EOS_WORD) + 1] if pykp.io.EOS_WORD in seq else seq for seq in baseline_str_seqs]
        sampled_str_seqs = [[opt.id2word[x] if x < opt.vocab_size else oov[x - opt.vocab_size] for x in seq.sentence] for seq in sampled_seqs]
        sampled_str_seqs = [seq[:seq.index(pykp.io.EOS_WORD) + 1] if pykp.io.EOS_WORD in seq else seq for seq in sampled_str_seqs]

        # convert target id sequences to string sequences
        trg_seqs = [[opt.id2word[x] if x < opt.vocab_size else oov[x - opt.vocab_size] for x in seq] for seq in trg_copy]
        # trg_seqs = [seq + [pykp.io.EOS_WORD] * (opt.max_sent_length - len(seq)) for seq in trg_seqs]

        # local rewards (bleu)
        bleu_baselines = get_match_result(true_seqs=trg_seqs, pred_seqs=baseline_str_seqs, type='bleu')
        bleu_samples = get_match_result(true_seqs=trg_seqs, pred_seqs=sampled_str_seqs, type='bleu')

        # global rewards (exact-match F-score)
        match_baselines = get_match_result(true_seqs=trg_seqs, pred_seqs=baseline_str_seqs, type='exact')
        match_samples = get_match_result(true_seqs=trg_seqs, pred_seqs=sampled_str_seqs, type='exact')
        _, _, fscore_baselines = evaluate.evaluate(match_baselines, baseline_str_seqs, trg_seqs, topk=5)
        _, _, fscore_samples = evaluate.evaluate(match_samples, sampled_str_seqs, trg_seqs, topk=5)

        # compute the final rewards as a mixture of the local and global rewards
        alpha = 0.0
        baseline = alpha * np.average(bleu_baselines) + (1.0 - alpha) * fscore_baselines
        rewards = alpha * np.asarray(bleu_samples) + (1.0 - alpha) * fscore_samples

        # debugging output, disabled by default
        # print('*' * 20 + ' ' + str(seq_i) + ' ' + '*' * 20)
        # print('Target Sequences:\n\t\t %s' % str(trg_seqs))
        # print('Baseline Sequences:')
        # for pred_seq, reward in zip(baseline_str_seqs, baselines):
        #     print('\t\t[%f] %s' % (reward, ' '.join(pred_seq)))
        # print('Predict Sequences:')
        # for pred_seq, reward in zip(sampled_str_seqs, rewards):
        #     print('\t\t[%f] %s' % (reward, ' '.join(pred_seq)))

        # self-critical REINFORCE: the advantage is the sampled reward minus the greedy baseline
        for seq, reward in zip(sampled_seqs, rewards):
            policy_loss.append(-torch.cat(seq.logprobs, dim=0) * float(reward - baseline))
        policy_rewards.extend(rewards)

    optimizer.zero_grad()
    policy_loss = torch.cat(policy_loss).sum() * (1 - opt.loss_scale)
    policy_loss.backward()
    if opt.max_grad_norm > 0:
        pre_norm = torch.nn.utils.clip_grad_norm(model.parameters(), opt.max_grad_norm)
        after_norm = (sum([p.grad.data.norm(2) ** 2 for p in model.parameters() if p.grad is not None])) ** (1.0 / 2)
        logging.info('clip grad (%f -> %f)' % (pre_norm, after_norm))
    optimizer.step()

    return np.average(policy_rewards)
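# train_rl above is the self-critical variant: the greedy decode of the same model is scored
# with the same reward function and used as the baseline, so no running average is needed,
# whereas train_rl_1 subtracts a moving average of past rewards instead. The sketch below is
# one hypothetical way to drive the two routines from a training loop; `data_loader` and the
# `use_self_critical` flag are assumed placeholders, not names taken from the surrounding project.
def train_rl_epoch(data_loader, model, optimizer, generator, opt, use_self_critical=True):
    """Hypothetical driver: run one RL epoch and return the mean per-batch reward."""
    reward_cache = RewardCache(capacity=1000)  # only consulted by train_rl_1
    epoch_rewards = []
    for one2many_batch in data_loader:
        if use_self_critical:
            avg_reward = train_rl(one2many_batch, model, optimizer, generator, opt)
        else:
            avg_reward = train_rl_1(one2many_batch, model, optimizer, generator, opt, reward_cache)
        epoch_rewards.append(avg_reward)
    return np.average(epoch_rewards)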