Example #1
def evaluate(iterator):
    actor.eval()
    sum_nll, cnt_nll = 0, 0
    for batch, (src, tgt) in enumerate(iterator, start=1):

        # get hyp
        hyps, _ = actor.generate(src, k=args.beamsize, n=1)

        # get log_prob
        log_prob = actor(src, tgt)

        # compute masked nll
        tgt_flat = tgt[1:].view(-1, 1)
        masked_nll = -log_prob.view(-1, ntoken_tgt).gather(
            1, tgt_flat).masked_select(tgt_flat.ne(tgt_pad_idx))

        # accumulate nll
        sum_nll += masked_nll.data.sum()
        cnt_nll += masked_nll.size(0)

        # pytorch_bleu requires size [bsz x nref|nhyp x seqlen]
        ref, hyp = utils.prepare_for_bleu(tgt,
                                          hyps,
                                          eos_idx,
                                          tgt_pad_idx,
                                          tgt_unk_idx,
                                          exclude_unk=True)

        # add hyp & ref to corpus
        bleu_metric.add_to_corpus(hyp, ref)

    # sanity check: decode one random sentence from the last batch
    vis_idx = np.random.randint(0, tgt.size(1))
    logging('===> [SRC]  {}'.format(vocab['src'].convert_to_sent(
        src[:, vis_idx].contiguous().data.cpu().view(-1),
        exclude=[src_pad_idx])))
    logging('===> [REF]  {}'.format(vocab['tgt'].convert_to_sent(
        tgt[1:, vis_idx].contiguous().data.cpu().view(-1),
        exclude=[tgt_pad_idx, eos_idx])))
    logging('===> [HYP]  {}'.format(vocab['tgt'].convert_to_sent(
        hyps[1:, vis_idx, 0].contiguous().data.cpu().view(-1),
        exclude=[tgt_pad_idx, eos_idx])))

    ppl = np.exp(sum_nll / cnt_nll)

    bleu4, precs, hyplen, reflen = bleu_metric.corpus_bleu()
    bleu = bleu4[0] * 100
    logging(
        'PPL {:.3f} | BLEU = {:.3f}, {:.1f}/{:.1f}/{:.1f}/{:.1f}, hyp_len={}, ref_len={}'
        .format(ppl, bleu, *[prec[0] * 100 for prec in precs], int(hyplen[0]),
                int(reflen[0])))

    return ppl, bleu
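
A minimal, self-contained sketch of the masked-NLL / perplexity computation above, with toy shapes standing in for the real batch (assumes log_prob is [tgtlen-1 x bsz x vocab] and tgt is [tgtlen x bsz], as in the loop):

import torch

# toy stand-ins for the real tensors
torch.manual_seed(0)
tgtlen, bsz, vocab, pad_idx = 6, 2, 10, 0
log_prob = torch.log_softmax(torch.randn(tgtlen - 1, bsz, vocab), dim=-1)
tgt = torch.randint(1, vocab, (tgtlen, bsz))
tgt[-2:, 0] = pad_idx  # simulate padding at the end of sentence 0

# same gather / masked_select pattern as in evaluate()
tgt_flat = tgt[1:].reshape(-1, 1)
masked_nll = -log_prob.reshape(-1, vocab).gather(1, tgt_flat) \
                      .masked_select(tgt_flat.ne(pad_idx))

# perplexity = exp(mean NLL over non-pad tokens)
ppl = torch.exp(masked_nll.sum() / masked_nll.numel())
print(ppl.item())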
Example #2
def evaluate(iterator):
    # actor.eval()
    critic.eval()
    sum_res, cnt_tok = 0, 0
    for batch, (src, tgt) in enumerate(iterator, start=1):
        max_len = tgt.size(0) + 5
        seq, act_log_dist = actor.sample(src, k=1, max_len=max_len)
        seq = seq.view(seq.size(0), -1).detach()

        # non-padding mask
        mask = seq[1:].ne(tgt_pad_idx).float()

        # given a demonstration trajectory / real sequence tgt, get Q(y_{<t}, w) for all w in W_+
        Q_all = critic(tgt, seq, out_mode=models.LOGIT)

        # compute Q(y_{<t}, y_t)
        Q_mod = Q_all.gather(2, seq[1:].unsqueeze(2)).squeeze(2)  # [tgtlen-1 x bsz]

        # compute V_hat(y_{<t})
        act_log_dist = act_log_dist.data.clone()
        act_log_dist.masked_fill_(seq.data[1:].eq(tgt_pad_idx)[:, :, None], 0.)
        if critic.dec_tau > 0:
            V_hat = (act_log_dist.exp() *
                     (Q_all.data -
                      critic.dec_tau * act_log_dist)).sum(2) * mask.data
        else:
            V_hat = (act_log_dist.exp() * Q_all.data).sum(2) * mask.data

        # compute rewards
        ref, hyp = utils.prepare_for_bleu(tgt,
                                          seq,
                                          eos_idx=eos_idx,
                                          pad_idx=tgt_pad_idx,
                                          unk_idx=tgt_unk_idx)
        R = utils.get_rewards(bleu_metric, hyp, ref)

        # compute target value : `Q_hat(s, a) = r(s, a) + V_hat(s')`
        Q_hat = R.clone()
        Q_hat[:-1] += V_hat[1:]

        # compute TD error : `td_error = Q_hat - Q_mod`
        td_error = Variable(Q_hat) - Q_mod

        # accumulate the absolute TD error for logging (not strictly necessary)
        cnt_tok += mask.data.sum()
        sum_res += (torch.abs(td_error.data) * mask.data).sum()

    res = sum_res / cnt_tok
    logging('Valid td error = {:.5f}'.format(res))

    return res
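
The value estimate V_hat used above is the entropy-regularized ("soft") state value V(s) = sum_w pi(w|s) * (Q(s, w) - tau * log pi(w|s)), which reduces to the plain expectation E_pi[Q] when tau == 0. A toy single-step illustration (hypothetical logits and Q values):

import torch

torch.manual_seed(0)
vocab, tau = 5, 0.1
log_pi = torch.log_softmax(torch.randn(vocab), dim=-1)  # actor distribution
Q = torch.randn(vocab)                                  # critic's Q(s, w)

V_soft = (log_pi.exp() * (Q - tau * log_pi)).sum()   # critic.dec_tau > 0 branch
V_plain = (log_pi.exp() * Q).sum()                   # critic.dec_tau == 0 branch
print(V_soft.item(), V_plain.item())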
Example #3
def train_erac(src, tgt):
    ##### Policy execution (actor)
    # sample sequence from the actor
    # max_len = min(tgt.size(0) + 10, 50)
    max_len = min(tgt.size(0) + 5, 50)
    seq, act_log_dist = actor.sample(src, k=args.nsample, max_len=max_len)

    seq = seq.view(seq.size(0), -1)
    mask = seq[1:].ne(tgt_pad_idx).float()
    act_dist = act_log_dist.exp()

    # compute rewards
    ref, hyp = utils.prepare_for_bleu(tgt, seq, eos_idx=eos_idx, pad_idx=tgt_pad_idx, unk_idx=tgt_unk_idx)
    bleu_R, bleu = utils.get_rewards(bleu_metric, hyp, ref, return_bleu=True)

    if args.use_unsuper_reward:
        R = utils.get_unsuper_rewards(GPTLM, tokenizer, XLM, bpe, dico, params, cos_sim, vocab, src, hyp,
                                      inc_adequacy=args.include_adequacy, mu=args.mu, device=device)
    else:
        R = bleu_R

    ##### Policy evaluation (critic)
    # compute Q value estimated by the critic
    Q_all = critic(tgt, seq, out_mode=models.LOGIT)
    # compute Q(y_{<t}, y_t)
    Q_mod = Q_all.gather(2, seq[1:].unsqueeze(2).to(torch.int64)).squeeze(2)
    # compute V_bar(y_{<t})
    act_log_dist.data.masked_fill_(seq.data[1:].eq(tgt_pad_idx)[:,:,None], 0.)
    if args.use_tgtnet:
        # evaluate the slow-moving target critic without tracking gradients
        tgt_volatile = Variable(tgt.data, volatile=True)
        seq_volatile = Variable(seq.data, volatile=True)
        Q_all_bar = tgt_critic(tgt_volatile, seq_volatile, out_mode=models.LOGIT)

        V_bar = (act_dist.data * (Q_all_bar.data - critic.dec_tau * act_log_dist.data)).sum(2) * mask.data
    else:
        V_bar = (act_dist.data * (Q_all.data - critic.dec_tau * act_log_dist.data)).sum(2) * mask.data

    # compute target value : `Q_hat(s, a) = r(s, a) + V_bar(s')`
    Q_hat = R.clone()
    Q_hat[:-1] += V_bar[1:]

    # compute TD error : `td_error = Q_hat - Q_mod`
    td_error = Variable(Q_hat - Q_mod.data)

    # critic loss
    loss_crt = -td_error * Q_mod * mask
    if args.smooth_coeff > 0:
        loss_crt += args.smooth_coeff * Q_all.var(2)
    loss_crt = loss_crt.sum(0).mean()

    # actor loss
    # clone so the in-place update below does not mutate Q_all's storage
    pg_signal = Q_all.data.clone()
    if args.tau > 0:
        # normalize to avoid instability
        pg_signal -= args.tau * act_log_dist.data / (1e-8 + act_log_dist.data.norm(p=2, dim=2, keepdim=True))

    loss_act = -(Variable(pg_signal) * act_log_dist.exp()).sum(2) * mask
    # loss_act = act_log_dist.exp()[:,:,:5].sum(2) * mask
    loss_act = loss_act.sum(0).mean()

    return loss_crt, loss_act, mask, td_error, R, bleu
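
The critic loss above, -td_error * Q_mod with td_error detached from Q_mod, is a semi-gradient surrogate: its gradient w.r.t. the critic output matches that of the squared TD loss 0.5 * (Q_hat - Q_mod)^2. A small numeric check (toy tensors, modern PyTorch API):

import torch

torch.manual_seed(0)
Q_hat = torch.randn(4)                      # bootstrapped target (no grad)
Q_mod = torch.randn(4, requires_grad=True)  # stands in for the critic output

# surrogate used in train_erac
loss_surrogate = (-(Q_hat - Q_mod).detach() * Q_mod).sum()
loss_surrogate.backward()
g1 = Q_mod.grad.clone()

# ordinary squared TD loss
Q_mod.grad = None
loss_mse = (0.5 * (Q_hat - Q_mod) ** 2).sum()
loss_mse.backward()

print(torch.allclose(g1, Q_mod.grad))  # True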
Example #4
def train(epoch):
    actor.train()
    critic.train()
    start_time = time.time()
    sum_res, cnt_tok = 0, 0
    sum_score, cnt_sent = 0, 0
    for batch, (src, tgt) in enumerate(tr_iter, start=1):
        # get trajectory
        max_len = min(tgt.size(0) + 5, 50)
        seq, act_log_dist = actor.sample(src, k=args.nsample, max_len=max_len)

        seq = seq.view(seq.size(0), -1).detach()
        mask = seq[1:].ne(tgt_pad_idx).float()

        # compute rewards
        ref, hyp = utils.prepare_for_bleu(tgt,
                                          seq,
                                          eos_idx=eos_idx,
                                          pad_idx=tgt_pad_idx,
                                          unk_idx=tgt_unk_idx)
        R, score = utils.get_rewards(bleu_metric, hyp, ref, return_bleu=True)

        # given a demonstration trajectory / real sequence tgt, get Q(y_{<t}, w) for all w in W_+
        Q_all = critic(tgt, seq, out_mode=models.LOGIT)

        # compute Q(y_{<t}, y_t)
        Q_mod = Q_all.gather(2, seq[1:].unsqueeze(2)).squeeze(2)

        # compute V_bar(y_{<t})
        act_log_dist = act_log_dist.data.clone()
        act_log_dist.masked_fill_(seq.data[1:].eq(tgt_pad_idx)[:, :, None], 0.)
        act_dist = act_log_dist.exp()

        if args.use_tgtnet:
            tgt_volatile = Variable(tgt.data, volatile=True)
            seq_volatile = Variable(seq.data, volatile=True)
            Q_all_bar = tgt_critic(tgt_volatile,
                                   seq_volatile,
                                   out_mode=models.LOGIT)

            if critic.dec_tau > 0:
                V_bar = (act_dist *
                         (Q_all_bar.data -
                          critic.dec_tau * act_log_dist)).sum(2) * mask.data
            else:
                V_bar = (act_dist * Q_all_bar.data).sum(2) * mask.data
        else:
            if critic.dec_tau > 0:
                V_bar = (act_dist *
                         (Q_all.data -
                          critic.dec_tau * act_log_dist)).sum(2) * mask.data
            else:
                V_bar = (act_dist * Q_all.data).sum(2) * mask.data

        # compute target value : `Q_hat(s, a) = r(s, a) + V_bar(s')`
        Q_hat = R.clone()
        Q_hat[:-1] += V_bar[1:]

        # compute TD error : `td_error = Q_hat - Q_mod`
        td_error = Variable(Q_hat - Q_mod.data)

        # construct loss function
        loss = -td_error * Q_mod * mask
        if args.smooth_coeff > 0:
            loss += args.smooth_coeff * Q_all.var(2)
        loss = loss.sum(0).mean()

        # accumulate the absolute TD error for logging (not strictly necessary)
        cnt_tok += mask.data.sum()
        sum_res += (torch.abs(td_error.data) * mask.data).sum()
        cnt_sent += seq.size(1)
        sum_score += score.sum()

        # optimization
        optimizer.zero_grad()
        loss.backward()
        gnorm = nn.utils.clip_grad_norm(critic.parameters(), args.grad_clip)
        optimizer.step()

        if args.use_tgtnet:
            utils.slow_update(critic, tgt_critic, args.tgt_speed)

        # logging
        if batch % args.log_interval == 0:
            elapsed = time.time() - start_time
            logging(
                '| epoch {:3d} | {:4d}/{:4d} batches | lr {:.6f} | ms/batch {:5.1f} | '
                'td error {:5.3f} | score {:5.3f}'.format(
                    epoch, batch, tr_iter.num_batch(),
                    optimizer.param_groups[0]['lr'],
                    elapsed * 1000 / args.log_interval, sum_res / cnt_tok,
                    sum_score / cnt_sent))
            start_time = time.time()
            sum_res, cnt_tok = 0, 0
            sum_score, cnt_sent = 0, 0
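
utils.slow_update above presumably performs a Polyak-style soft update of the target critic; a plausible sketch under that assumption (the actual helper is not shown here, and the name soft_update is illustrative):

import torch.nn as nn

def soft_update(src_net: nn.Module, tgt_net: nn.Module, speed: float):
    # tgt <- (1 - speed) * tgt + speed * src, applied parameter-wise
    for src_p, tgt_p in zip(src_net.parameters(), tgt_net.parameters()):
        tgt_p.data.mul_(1.0 - speed).add_(speed * src_p.data)

With a small speed (e.g. 1e-3), the target critic tracks the trained critic slowly, which stabilizes the bootstrapped targets V_bar.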
Example #5
def train(epoch):
    actor.train()
    start_time = time.time()
    sum_nll, cnt_nll = 0, 0
    sum_cet, cnt_cet = 0, 0

    for batch, (src, tgt) in enumerate(tr_iter, start=1):
        tgt_volatile = Variable(tgt.data.clone(), volatile=True)

        # sample a corrupted sequence
        seq = Variable(
            utils.ngram_sample(tgt.data,
                               args.nsample,
                               low=4,
                               high=ntoken_tgt,
                               pad_idx=tgt_pad_idx))
        seq = seq.view(seq.size(0), -1)

        # compute importance weight based on sentence bleu
        ref, hyp = utils.prepare_for_bleu(tgt, seq, eos_idx, tgt_pad_idx,
                                          tgt_unk_idx)
        sent_bleu = Variable(bleu_metric.sent_bleu(hyp, ref))

        weight = nn.functional.softmax(sent_bleu, 1)
        weight = weight.view(-1)

        # actor log distribution on seq
        log_act_dist = actor(src, seq)

        # non-padding mask
        mask_seq = seq[1:].ne(tgt_pad_idx).float()

        if args.vaml_coeff != 1:
            # negative loss likelihood estimated by the actor
            nll = -log_act_dist.gather(2, seq[1:].unsqueeze(2)).squeeze(2)

            nll_tgt = nll.data.view(nll.size(0), tgt.size(1), -1)[:, :, 0]
            mask_tgt = tgt[1:].data.ne(tgt_pad_idx).float()

            sum_nll += (nll_tgt * mask_tgt).sum()
            cnt_nll += mask_tgt.sum()

        if args.vaml_coeff != 0:
            # cross entropy based on the critic
            crt_dist = critic(tgt_volatile, seq, out_mode=models.PROB)
            cet = -(Variable(crt_dist.data) * log_act_dist).sum(-1)

            sum_cet += cet.data.sum()
            cnt_cet += mask_seq.data.sum()

        if 0 < args.vaml_coeff < 1:
            # 1 in mask_nll|mask_cet means the corresponding token WILL be
            # trained by loss_nll|loss_cet (an earlier, commented-out variant
            # used the inverted convention and filled masked tokens with 0)
            mask_cet = utils.random_mask(seq[1:].data, rate=args.vaml_coeff)
            mask_nll = (1 - mask_cet)
            mask_nonpad = seq[1:].data.ne(tgt_pad_idx)
            mask_nll = Variable((mask_nll & mask_nonpad).float())
            mask_cet = Variable((mask_cet & mask_nonpad).float())

            loss_nll = (nll * mask_nll).sum(0)
            loss_cet = (cet * mask_cet).sum(0)

            loss = ((loss_nll + loss_cet) * weight).view(tgt.size(1),
                                                         -1).sum(1).mean(0)

        elif args.vaml_coeff == 0:
            loss = ((nll * mask_seq).sum(0) * weight).view(tgt.size(1),
                                                           -1).sum(1).mean(0)

        elif args.vaml_coeff == 1:
            loss = ((cet * mask_seq).sum(0) * weight).view(tgt.size(1),
                                                           -1).sum(1).mean(0)

        # optimization
        optimizer.zero_grad()
        loss.backward()
        gnorm = nn.utils.clip_grad_norm(actor.parameters(), args.grad_clip)
        optimizer.step()

        # logging
        if batch % args.log_interval == 0:
            cur_loss = sum_nll / cnt_nll if cnt_nll > 0 else 0.
            elapsed = time.time() - start_time
            logging(
                '| epoch {:3d} | {:4d}/{:4d} batches | lr {:.6f} | ms/batch {:5.1f} | '
                'loss nll {:5.2f} | loss cet {:5.2f} | ppl {:8.2f} '.format(
                    epoch, batch, tr_iter.num_batch(),
                    optimizer.param_groups[0]['lr'],
                    elapsed * 1000 / args.log_interval,
                    cur_loss, sum_cet / cnt_cet if cnt_cet > 0 else 0.,
                    np.exp(cur_loss)))
            start_time = time.time()
            sum_nll, cnt_nll = 0, 0
            sum_cet, cnt_cet = 0, 0
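
The token-level split in the 0 < vaml_coeff < 1 branch routes every non-pad token to exactly one of the two losses. A stand-in sketch of that masking logic (utils.random_mask is assumed to draw a Bernoulli(rate) mask; this mimics it with torch.rand):

import torch

torch.manual_seed(0)
seqlen, bsz, pad_idx, rate = 7, 3, 0, 0.3
seq = torch.randint(0, 5, (seqlen, bsz))

mask_cet = torch.rand(seqlen, bsz) < rate  # tokens trained by loss_cet
mask_nll = ~mask_cet                       # complement, trained by loss_nll
mask_nonpad = seq.ne(pad_idx)

mask_cet = (mask_cet & mask_nonpad).float()
mask_nll = (mask_nll & mask_nonpad).float()

# every non-pad token is covered exactly once
assert ((mask_cet + mask_nll) == mask_nonpad.float()).all()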