# Repo-local dependencies assumed importable / in scope here:
# `evaluate` (module exposing evaluate()), plus the project helpers
# `get_match_result` and `to_cpu_list`.
import logging

import numpy as np
import torch

import evaluate
import pykp.io


def train_rl_1(one2many_batch, model, optimizer, generator, opt, reward_cache):
    src_list, src_len, trg_list, _, trg_copy_target_list, src_oov_map_list, oov_list = one2many_batch

    if torch.cuda.is_available():
        src_list = src_list.cuda()
        src_oov_map_list = src_oov_map_list.cuda()

    # Draw k=5 stochastic samples per source sequence (policy rollouts)
    sampled_seqs_list = generator.sample(src_list, src_len, src_oov_map_list, oov_list, opt.word2id, k=5, is_greedy=False)

    policy_loss = []
    policy_rewards = []
    # Compute their rewards and losses
    for seq_i, (src, trg, trg_copy, sampled_seqs, oov) in enumerate(zip(src_list, trg_list, trg_copy_target_list, sampled_seqs_list, oov_list)):
        # convert to string sequences
        sampled_str_seqs = [[opt.id2word[x] if x < opt.vocab_size else oov[x - opt.vocab_size] for x in to_cpu_list(seq.sentence)] for seq in sampled_seqs]
        sampled_str_seqs = [seq[:seq.index(pykp.io.EOS_WORD) + 1] if pykp.io.EOS_WORD in seq else seq for seq in sampled_str_seqs]

        # convert target copies back to strings (the EOS padding below is left disabled)
        trg_seqs = [[opt.id2word[x] if x < opt.vocab_size else oov[x - opt.vocab_size] for x in seq] for seq in trg_copy]
        # trg_seqs            =  [seq + [pykp.IO.EOS_WORD] * (opt.max_sent_length - len(seq)) for seq in trg_seqs]

        # local rewards (bleu)
        bleu_samples = get_match_result(true_seqs=trg_seqs, pred_seqs=sampled_str_seqs, type='bleu')

        # global rewards
        match_samples = get_match_result(true_seqs=trg_seqs, pred_seqs=sampled_str_seqs, type='exact')

        _, _, fscore_samples = evaluate.evaluate(match_samples, sampled_str_seqs, trg_seqs, topk=5)

        # final reward: convex mix of BLEU (local) and F-score@5 (global);
        # with alpha = 0.0 the reward reduces to the F-score alone
        alpha = 0.0
        rewards = alpha * np.asarray(bleu_samples) + (1.0 - alpha) * fscore_samples
        # moving-average baseline over recent rewards (variance reduction)
        baseline = reward_cache.get_average()
        for reward in rewards:
            reward_cache.push(float(reward))

        # REINFORCE loss: sequence log-likelihood weighted by the advantage (reward - baseline)
        for seq, reward in zip(sampled_seqs, rewards):
            policy_loss.append(-torch.stack(seq.logprobs, dim=0).sum() * float(reward - baseline))
            policy_rewards.append(reward)

    optimizer.zero_grad()
    policy_loss = torch.stack(policy_loss).mean() * (1 - opt.loss_scale)
    policy_loss.backward()

    if opt.max_grad_norm > 0:
        pre_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), opt.max_grad_norm)
        after_norm = (sum([p.grad.data.norm(2) ** 2 for p in model.parameters() if p.grad is not None])) ** (1.0 / 2)
        # logging.info('clip grad (%f -> %f)' % (pre_norm, after_norm))

    optimizer.step()
    return np.average(policy_rewards)
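
# The `reward_cache` passed into train_rl_1 is only used through push() and
# get_average(). A minimal sketch of a compatible moving-average reward cache
# (a hypothetical helper, not part of the original project):
from collections import deque


class RewardCache(object):
    """Keeps the most recent `capacity` rewards and exposes their mean as a baseline."""

    def __init__(self, capacity=1000):
        self.rewards = deque(maxlen=capacity)

    def push(self, reward):
        self.rewards.append(float(reward))

    def get_average(self):
        # fall back to 0.0 before any reward has been pushed
        return sum(self.rewards) / len(self.rewards) if self.rewards else 0.0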
Example #2
def train_rl(one2many_batch, model, optimizer, generator, opt):
    src_list, src_len, trg_list, _, trg_copy_target_list, src_oov_map_list, oov_list = one2many_batch

    if torch.cuda.is_available():
        src_list = src_list.cuda()
        src_oov_map_list = src_oov_map_list.cuda()

    # Greedy rollouts provide the self-critical baseline
    baseline_seqs_list = generator.sample(src_list, src_len, src_oov_map_list, oov_list, opt.word2id, k=5, is_greedy=True)

    # Draw k=5 stochastic samples per source sequence (policy rollouts)
    sampled_seqs_list = generator.sample(src_list, src_len, src_oov_map_list, oov_list, opt.word2id, k=5, is_greedy=False)
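
    # Self-critical REINFORCE: the greedy rollout's reward acts as the baseline,
    # so each sampled sequence is weighted by (reward(sample) - reward(greedy));
    # samples that beat the greedy decode are reinforced, worse ones suppressed.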

    policy_loss = []
    policy_rewards = []
    # Compute their rewards and losses
    for seq_i, (src, trg, trg_copy, sampled_seqs, baseline_seqs, oov) in enumerate(zip(src_list, trg_list, trg_copy_target_list, sampled_seqs_list, baseline_seqs_list, oov_list)):
        # convert to string sequences
        baseline_str_seqs = [[opt.id2word[x] if x < opt.vocab_size else oov[x - opt.vocab_size] for x in seq.sentence] for seq in baseline_seqs]
        baseline_str_seqs = [seq[:seq.index(pykp.io.EOS_WORD) + 1] if pykp.io.EOS_WORD in seq else seq for seq in baseline_str_seqs]
        sampled_str_seqs = [[opt.id2word[x] if x < opt.vocab_size else oov[x - opt.vocab_size] for x in seq.sentence] for seq in sampled_seqs]
        sampled_str_seqs = [seq[:seq.index(pykp.io.EOS_WORD) + 1] if pykp.io.EOS_WORD in seq else seq for seq in sampled_str_seqs]

        # convert target copies back to strings (the EOS padding below is left disabled)
        trg_seqs = [[opt.id2word[x] if x < opt.vocab_size else oov[x - opt.vocab_size] for x in seq] for seq in trg_copy]
        # trg_seqs            =  [seq + [pykp.IO.EOS_WORD] * (opt.max_sent_length - len(seq)) for seq in trg_seqs]

        # local rewards (bleu)
        bleu_baselines = get_match_result(true_seqs=trg_seqs, pred_seqs=baseline_str_seqs, type='bleu')
        bleu_samples = get_match_result(true_seqs=trg_seqs, pred_seqs=sampled_str_seqs, type='bleu')

        # global rewards
        match_baselines = get_match_result(true_seqs=trg_seqs, pred_seqs=baseline_str_seqs, type='exact')
        match_samples = get_match_result(true_seqs=trg_seqs, pred_seqs=sampled_str_seqs, type='exact')

        _, _, fscore_baselines = evaluate.evaluate(match_baselines, baseline_str_seqs, trg_seqs, topk=5)
        _, _, fscore_samples = evaluate.evaluate(match_samples, sampled_str_seqs, trg_seqs, topk=5)

        # final reward: convex mix of BLEU (local) and F-score@5 (global);
        # with alpha = 0.0 the reward reduces to the F-score alone
        alpha = 0.0
        baseline = alpha * np.average(bleu_baselines) + (1.0 - alpha) * fscore_baselines
        rewards = alpha * np.asarray(bleu_samples) + (1.0 - alpha) * fscore_samples

        """
        print('*' * 20 + '  ' + str(seq_i) + '  ' + '*' * 20)
        print('Target Sequences:\n\t\t %s' % str(trg_seqs))
        print('Baseline Sequences:')
        for pred_seq, reward in zip(baseline_str_seqs, bleu_baselines):
            print('\t\t[%f] %s' % (reward, ' '.join(pred_seq)))
        print('Predict Sequences:')
        for pred_seq, reward in zip(sampled_str_seqs, rewards):
            print('\t\t[%f] %s' % (reward, ' '.join(pred_seq)))
        """

        # self-critical REINFORCE loss: per-step log-probs weighted by the advantage
        for seq, reward in zip(sampled_seqs, rewards):
            policy_loss.append(-torch.cat(seq.logprobs, dim=0) * float(reward - baseline))
            policy_rewards.append(reward)

    optimizer.zero_grad()
    policy_loss = torch.cat(policy_loss).sum() * (1 - opt.loss_scale)
    policy_loss.backward()

    if opt.max_grad_norm > 0:
        pre_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), opt.max_grad_norm)
        after_norm = (sum([p.grad.data.norm(2) ** 2 for p in model.parameters() if p.grad is not None])) ** (1.0 / 2)
        logging.info('clip grad (%f -> %f)' % (pre_norm, after_norm))

    optimizer.step()

    return np.average(policy_rewards)
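
# A minimal usage sketch for either function above, assuming a data loader that
# yields one2many batches in the tuple layout unpacked at the top of train_rl
# (the data_loader, generator, and opt objects are repo-specific assumptions):
def train_rl_epoch(data_loader, model, optimizer, generator, opt):
    model.train()
    epoch_rewards = []
    for batch_i, one2many_batch in enumerate(data_loader):
        avg_reward = train_rl(one2many_batch, model, optimizer, generator, opt)
        epoch_rewards.append(avg_reward)
        if batch_i % 100 == 0:
            logging.info('batch %d: avg policy reward = %.4f' % (batch_i, avg_reward))
    return np.average(epoch_rewards)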