Beispiel #1
0
def StepPGLoss(G, D, src_seq):
    """Compute a step-wise policy-gradient loss for generator ``G``.

    At every decoding step ``G.step_rollout`` proposes several candidate next
    tokens, ``D`` scores each candidate-extended prefix, and the candidate with
    the highest score is appended to the decoded sequence. The loss is
    ``-mean(rewards * probs)`` over all candidates of all steps.

    Args:
        G: generator exposing ``get_position``, ``encoder``, ``step_rollout``
            and ``max_len``.
        D: discriminator; callable on a batch of sequences and exposing
            ``min_len``.
        src_seq: source sequence fed to the encoder (a single example is
            assumed, since decoding starts from a ``(1, 1)`` BOS tensor).

    Returns:
        The scalar loss. Nothing is back-propagated here.
    """
    n_rollout = 6  # number of candidate tokens sampled per step

    src_pos = G.get_position(src_seq.data)
    enc_output, *_ = G.encoder(src_seq, src_pos)
    dec_seq = autograd.Variable(torch.LongTensor(1, 1).fill_(Constants.BOS))
    if torch.cuda.is_available():
        dec_seq = dec_seq.cuda()

    step_rewards = []
    step_probs = []

    # decode
    for _ in range(G.max_len):
        rollout_tokens, prob = G.step_rollout(src_seq,
                                              enc_output,
                                              dec_seq,
                                              n_rollout=n_rollout)
        rollout_tokens = rollout_tokens.transpose(1, 0)  # (n_rollout, 1)

        # Repeat the current prefix once per candidate and append the
        # candidate tokens, giving one extended prefix per rollout.
        partial_seq = helper.stack(dec_seq.data,
                                   n_rollout)  # (n_rollout, cur_len)
        partial_seq = torch.cat([partial_seq, rollout_tokens],
                                dim=1)  # (n_rollout, cur_len+1)
        # D cannot score sequences shorter than its minimum length.
        if partial_seq.size(1) < D.min_len:
            partial_seq = helper.pad_seq(partial_seq, D.min_len, Constants.PAD)

        partial_seq = autograd.Variable(partial_seq)

        reward = D(partial_seq)

        # Pick the candidate with the highest reward as the next token.
        top_i = reward.max(dim=0)[1].data
        next_token = rollout_tokens.squeeze(1)[top_i]
        # Wrap in a Variable so torch.cat() with dec_seq does not fail.
        next_token = autograd.Variable(next_token.unsqueeze(1))

        dec_seq = torch.cat([dec_seq, next_token], dim=1)

        step_rewards.append(reward)
        # Fix: the first step used to store `reward` here instead of `prob`,
        # so that step contributed reward*reward and no generator gradient.
        step_probs.append(prob)

        if next_token[0] == Constants.EOS:
            break

    rewards = torch.cat(step_rewards)
    probs = torch.cat(step_probs)
    loss = -torch.mean(rewards * probs)

    return loss
 def check_input(self, x):
     """Return ``x``, padded when its second dimension is too short.

     If ``x.size(1)`` is at least ``self.min_len`` the input is returned
     unchanged; otherwise ``x.data`` is padded up to ``self.min_len`` with
     ``Constants.PAD`` via ``helper.pad_seq``.
     """
     if x.size(1) >= self.min_len:
         return x
     return helper.pad_seq(x.data, self.min_len, Constants.PAD)
Beispiel #3
0
def main():
    """Set up options, data, vocabulary and models, then train.

    Steps, in order: parse and override options, configure logging, build the
    datasets and the shared vocabulary field, construct the RNN seq2seq
    generator ``G`` and the discriminator ``D`` (loaded from disk or built and
    pretrained), train ``G`` for 100 epochs with ``loss_G``, and finally
    freeze the parameters of ``D``.
    """
    parser = argparse.ArgumentParser()
    # NOTE(review): this return value is overwritten on the next line;
    # presumably train_options() is called only to register arguments on
    # `parser` - confirm.
    opt = options.train_options(parser)
    opt = parser.parse_args()

    opt.cuda = torch.cuda.is_available()
    opt.device = None if opt.cuda else -1

    # Quick configuration overrides
    opt.exp_dir = './experiment/transformer-reinforce/use_billion'
    opt.load_vocab_from = './experiment/transformer/lang8-cor2err/vocab.pt'
    opt.build_vocab_from = './data/billion/billion.30m.model.vocab'

    opt.load_D_from = opt.exp_dir
    # opt.load_D_from = None

    # dataset params
    opt.max_len = 20

    # G params
    # opt.load_G_a_from = './experiment/transformer/lang8-err2cor/'
    # opt.load_G_b_from = './experiment/transformer/lang8-cor2err/'
    opt.d_word_vec = 300
    opt.d_model = 300
    opt.d_inner_hid = 600
    opt.n_head = 6
    opt.n_layers = 3
    opt.embs_share_weight = False
    opt.beam_size = 1
    opt.max_token_seq_len = opt.max_len + 2  # includes <BOS> and <EOS>
    opt.n_warmup_steps = 4000

    # D params
    opt.embed_dim = opt.d_model
    opt.num_kernel = 100
    opt.kernel_sizes = [3, 4, 5, 6, 7]
    opt.dropout_p = 0.25

    # train params
    opt.batch_size = 1
    opt.n_epoch = 10

    if not os.path.exists(opt.exp_dir):
        os.makedirs(opt.exp_dir)
    # Log to a file inside the experiment directory and also to the console.
    logging.basicConfig(filename=opt.exp_dir + '/.log',
                        format=LOG_FORMAT,
                        level=logging.DEBUG)
    logging.getLogger().addHandler(logging.StreamHandler())

    logging.info('Use CUDA? ' + str(opt.cuda))
    logging.info(opt)

    # ---------- prepare dataset ----------

    def len_filter(example):
        """Keep only examples whose src and tgt are at most opt.max_len long."""
        return len(example.src) <= opt.max_len and len(
            example.tgt) <= opt.max_len

    # A single field is shared by source and target in every dataset below.
    EN = SentencePieceField(init_token=Constants.BOS_WORD,
                            eos_token=Constants.EOS_WORD,
                            batch_first=True,
                            include_lengths=True)

    train = datasets.TranslationDataset(path='./data/dualgan/train',
                                        exts=('.billion.sp', '.use.sp'),
                                        fields=[('src', EN), ('tgt', EN)],
                                        filter_pred=len_filter)
    val = datasets.TranslationDataset(path='./data/dualgan/val',
                                      exts=('.billion.sp', '.use.sp'),
                                      fields=[('src', EN), ('tgt', EN)],
                                      filter_pred=len_filter)
    # NOTE(review): both splits point at 'test', and neither result is used
    # later in this function - confirm whether this is still needed.
    train_lang8, val_lang8 = Lang8.splits(exts=('.err.sp', '.cor.sp'),
                                          fields=[('src', EN), ('tgt', EN)],
                                          train='test',
                                          validation='test',
                                          test=None,
                                          filter_pred=len_filter)

    # Load the vocabulary (to keep it consistent); if the file is missing,
    # build it from opt.build_vocab_from and save it to opt.load_vocab_from.
    try:
        logging.info('Load voab from %s' % opt.load_vocab_from)
        EN.load_vocab(opt.load_vocab_from)
    except FileNotFoundError:
        EN.build_vocab_from(opt.build_vocab_from)
        EN.save_vocab(opt.load_vocab_from)

    logging.info('Vocab len: %d' % len(EN.vocab))

    # Verify that the indices in Constants match the loaded vocabulary
    assert EN.vocab.stoi[Constants.BOS_WORD] == Constants.BOS
    assert EN.vocab.stoi[Constants.EOS_WORD] == Constants.EOS
    assert EN.vocab.stoi[Constants.PAD_WORD] == Constants.PAD
    assert EN.vocab.stoi[Constants.UNK_WORD] == Constants.UNK

    # ---------- init model ----------

    # G = build_G(opt, EN, EN)
    hidden_size = 512
    bidirectional = True
    encoder = EncoderRNN(len(EN.vocab),
                         opt.max_len,
                         hidden_size,
                         n_layers=1,
                         bidirectional=bidirectional)
    # NOTE(review): the else-branch below yields a hidden size of 1 when
    # bidirectional is False; `hidden_size` was probably intended. Harmless
    # while bidirectional is hard-coded to True.
    decoder = DecoderRNN(len(EN.vocab),
                         opt.max_len,
                         hidden_size * 2 if bidirectional else 1,
                         n_layers=1,
                         dropout_p=0.2,
                         use_attention=True,
                         bidirectional=bidirectional,
                         eos_id=Constants.EOS,
                         sos_id=Constants.BOS)
    G = Seq2seq(encoder, decoder)
    # Initialise every parameter of G uniformly in [-0.08, 0.08].
    for param in G.parameters():
        param.data.uniform_(-0.08, 0.08)

    # optim_G = ScheduledOptim(optim.Adam(
    #     G.get_trainable_parameters(),
    #     betas=(0.9, 0.98), eps=1e-09),
    #     opt.d_model, opt.n_warmup_steps)
    optim_G = optim.Adam(G.parameters(), lr=1e-4, betas=(0.9, 0.98), eps=1e-09)
    loss_G = NLLLoss(size_average=False)
    if torch.cuda.is_available():
        loss_G.cuda()

    # Load a previously saved D, or build a fresh one to be pretrained below
    if opt.load_D_from:
        D = load_model(opt.load_D_from)
    else:
        D = build_D(opt, EN)
    optim_D = torch.optim.Adam(D.parameters(), lr=1e-4)

    def get_criterion(vocab_size):
        ''' With PAD token zero weight '''
        weight = torch.ones(vocab_size)
        weight[Constants.PAD] = 0
        return nn.CrossEntropyLoss(weight, size_average=False)

    # NOTE(review): crit_G is created and moved to the GPU but is not used
    # in the training loop visible in this function.
    crit_G = get_criterion(len(EN.vocab))
    crit_D = nn.BCELoss()

    if opt.cuda:
        G.cuda()
        D.cuda()
        crit_G.cuda()
        crit_D.cuda()

    # ---------- train ----------

    trainer_D = trainers.DiscriminatorTrainer()

    # Pretrain D for one epoch only when it was not loaded from disk, then
    # save a checkpoint into the experiment directory.
    if not opt.load_D_from:
        for epoch in range(1):
            logging.info('[Pretrain D Epoch %d]' % epoch)

            pool = helper.DiscriminatorDataPool(opt.max_len, D.min_len,
                                                Constants.PAD)

            # Fill the pool with the training data
            train_iter = data.BucketIterator(dataset=train,
                                             batch_size=opt.batch_size,
                                             device=opt.device,
                                             sort_key=lambda x: len(x.src),
                                             repeat=False)
            pool.fill(train_iter)

            # train D
            trainer_D.train(D,
                            train_iter=pool.batch_gen(),
                            crit=crit_D,
                            optimizer=optim_D)
            pool.reset()

        Checkpoint(model=D,
                   optimizer=optim_D,
                   epoch=0,
                   step=0,
                   input_vocab=EN.vocab,
                   output_vocab=EN.vocab).save(opt.exp_dir)

    def eval_D():
        """Evaluate D on a pool filled from the validation set."""
        pool = helper.DiscriminatorDataPool(opt.max_len, D.min_len,
                                            Constants.PAD)
        val_iter = data.BucketIterator(dataset=val,
                                       batch_size=opt.batch_size,
                                       device=opt.device,
                                       sort_key=lambda x: len(x.src),
                                       repeat=False)
        pool.fill(val_iter)
        trainer_D.evaluate(D, val_iter=pool.batch_gen(), crit=crit_D)

        # eval_D()

    # Train G
    # NOTE(review): ALPHA is assigned but not read anywhere in this function.
    ALPHA = 0
    for epoch in range(100):
        logging.info('[Epoch %d]' % epoch)
        train_iter = data.BucketIterator(dataset=train,
                                         batch_size=1,
                                         device=opt.device,
                                         sort_within_batch=True,
                                         sort_key=lambda x: len(x.src),
                                         repeat=False)

        for step, batch in enumerate(train_iter):
            # The field uses include_lengths=True, so batch.src is indexed
            # as (sequence, lengths).
            src_seq = batch.src[0]
            src_length = batch.src[1]
            # The training target is derived from the source itself.
            # NOTE(review): this takes the first row of src_seq, yet later
            # code indexes tgt_seq with two dimensions - confirm what shape
            # helper.pad_seq returns.
            tgt_seq = src_seq[0].clone()
            # gold = tgt_seq[:, 1:]

            optim_G.zero_grad()
            loss_G.reset()

            decoder_outputs, decoder_hidden, other = G.rollout(src_seq,
                                                               None,
                                                               None,
                                                               n_rollout=1)
            # NOTE(review): this loop only reassigns batch_size and has no
            # other effect; it looks like leftover debugging code.
            for i, step_output in enumerate(decoder_outputs):
                batch_size = tgt_seq.size(0)
                # print(step_output)

                # loss_G.eval_batch(step_output.contiguous().view(batch_size, -1), tgt_seq[:, i + 1])

            # Exponentiate the concatenated decoder outputs (presumably
            # log-probabilities - TODO confirm), replicate them 8 times with
            # helper.stack and sample one index per row. The results are
            # only printed, not used for the loss.
            softmax_output = torch.exp(
                torch.cat([x for x in decoder_outputs], dim=0)).unsqueeze(0)
            softmax_output = helper.stack(softmax_output, 8)

            print(softmax_output)
            rollout = softmax_output.multinomial(1)
            print(rollout)

            # Pad the target so there is one gold token per decoder step
            # after the leading position.
            tgt_seq = helper.pad_seq(tgt_seq.data,
                                     max_len=len(decoder_outputs) + 1,
                                     pad_value=Constants.PAD)
            tgt_seq = autograd.Variable(tgt_seq)
            # Accumulate the loss step by step against tgt_seq shifted by one.
            for i, step_output in enumerate(decoder_outputs):
                batch_size = tgt_seq.size(0)
                loss_G.eval_batch(
                    step_output.contiguous().view(batch_size, -1),
                    tgt_seq[:, i + 1])
            G.zero_grad()
            loss_G.backward()
            optim_G.step()

            # Every 100 steps print the loss and one target/prediction pair.
            if step % 100 == 0:
                pred = torch.cat([x for x in other['sequence']], dim=1)
                print('[step %d] loss_rest %.4f' %
                      (epoch * len(train_iter) + step, loss_G.get_loss()))
                print('%s -> %s' %
                      (EN.reverse(tgt_seq.data)[0], EN.reverse(pred.data)[0]))

    # Reinforce Train G
    # Freeze D so that later backward passes do not compute gradients for it.
    for p in D.parameters():
        p.requires_grad = False
Beispiel #4
0
    def train_G_PG(self, G, D, optim_G, src_seq):
        """Run one policy-gradient update of ``G`` using a beam of rollouts.

        Starting from ``self.top_k`` BOS-only beams, each step samples
        ``self.n_rollout`` next tokens per beam with ``G.step_rollout``,
        scores every extended sequence with ``D`` and keeps the best-scoring
        distinct sequences as the next beams. Sequences ending in EOS are
        moved to the finished list. The loss ``-mean(rewards * probs)`` over
        the selected sequences is back-propagated and ``optim_G`` is stepped.

        Args:
            G: generator exposing ``get_position``, ``encoder``,
                ``step_rollout``, ``max_len`` and
                ``get_trainable_parameters``.
            D: discriminator; callable on a batch of sequences and exposing
                ``min_len`` and ``parameters``.
            optim_G: optimizer for ``G``.
            src_seq: source sequence fed to the encoder.

        Returns:
            A tuple ``(best_seq, rewards, probs, loss)`` where ``best_seq`` is
            the first finished sequence, or the first remaining candidate if
            none finished.
        """
        # D is only used as a fixed reward function here.
        for p in D.parameters():
            p.requires_grad = False

        # Intermediate rewards come from D only. (Dual training would also
        # use reconstruction quality as a reward; not considered for now.)
        optim_G.zero_grad()

        # encode
        src_pos = G.get_position(src_seq.data)
        enc_output, *_ = G.encoder(src_seq, src_pos)

        # init beams: top_k sequences holding only BOS
        cur_seq = autograd.Variable(
            torch.LongTensor(self.top_k, 1).fill_(Constants.BOS))
        if torch.cuda.is_available():
            cur_seq = cur_seq.cuda()

        rewards, probs = [], []
        final_seqs = []
        candidates = []

        # decode
        for _ in range(G.max_len):
            rollouts, rollout_probs = [], []
            for beam in cur_seq.chunk(self.top_k, dim=0):
                rollout_tokens, softmax_out = G.step_rollout(
                    src_seq, enc_output, beam, n_rollout=self.n_rollout)
                # one sampled token per row
                rollouts.append(rollout_tokens.transpose(1, 0))
                rollout_probs.append(softmax_out)

            rollouts = torch.cat(rollouts, dim=0)
            softmax_outs = torch.cat(rollout_probs, dim=0)

            # Replicate each current sequence n_rollout times so it can be
            # joined with its n_rollout sampled tokens.
            cur_seq = helper.inflate(cur_seq.data, self.n_rollout,
                                     0)  # (k * n, cur_len)
            cur_seq = torch.cat([cur_seq, rollouts],
                                dim=1)  # (k * n, cur_len+1)

            # Score a padded copy; cur_seq itself stays unpadded.
            scored_seq = cur_seq.clone()
            if scored_seq.size(1) < D.min_len:
                scored_seq = helper.pad_seq(scored_seq, D.min_len,
                                            Constants.PAD)
            reward = D(scored_seq)

            # Select the top-k distinct sequences by reward.
            _, indices = reward.sort(dim=0, descending=True)
            candidates = []

            for idx in indices.data.split(1):
                seq = cur_seq[idx]

                # Skip sequences that are already among the candidates.
                if any(torch.equal(seq, x) for x in candidates):
                    continue

                # A sequence whose newest token is EOS is finished.
                if seq[:, -1][0] == Constants.EOS:
                    final_seqs.append(seq)
                else:
                    candidates.append(seq)

                # Keep the reward/probability of every selected sequence
                # for the loss.
                rewards.append(reward[idx])
                probs.append(softmax_outs[idx])

                if len(candidates) == (self.top_k - len(final_seqs)):
                    break

            # All beams finished?
            if len(candidates) == 0:
                break
            cur_seq = autograd.Variable(torch.cat(candidates, dim=0))

        final_seqs += candidates
        rewards = torch.cat(rewards)
        probs = torch.cat(probs)

        # back propagation
        loss = -torch.mean(rewards * probs)
        loss.backward()
        # Clip gradients to avoid exploding gradients.
        nn.utils.clip_grad_norm(G.get_trainable_parameters(), 40)
        optim_G.step()

        return final_seqs[0], rewards, probs, loss
Beispiel #5
0
    def _train_G_PG(self, G, D, optim_G, src_seq):
        """Run one greedy policy-gradient update of ``G`` (no beam).

        At every decoding step ``G.step_rollout`` proposes
        ``self.n_rollout`` candidate next tokens, ``D`` scores each
        candidate-extended prefix, and the highest-scoring candidate is
        appended to the decoded sequence. The loss
        ``-mean(rewards * probs)`` over all candidates of all steps is
        back-propagated and ``optim_G`` is stepped.

        Args:
            G: generator exposing ``get_position``, ``encoder``,
                ``step_rollout``, ``max_len`` and
                ``get_trainable_parameters``.
            D: discriminator; callable on a batch of sequences and exposing
                ``min_len`` and ``parameters``.
            optim_G: optimizer for ``G``.
            src_seq: source sequence fed to the encoder (a single example is
                assumed, since decoding starts from a ``(1, 1)`` BOS tensor).

        Returns:
            A tuple ``(dec_seq, rewards, probs, loss)``.
        """
        # TODO: add beam?
        # D is only used as a fixed reward function here.
        for p in D.parameters():
            p.requires_grad = False

        # Intermediate rewards come from D only. (Dual training would also
        # use reconstruction quality as a reward; not considered for now.)
        optim_G.zero_grad()

        src_pos = G.get_position(src_seq.data)
        enc_output, *_ = G.encoder(src_seq, src_pos)
        dec_seq = autograd.Variable(
            torch.LongTensor(1, 1).fill_(Constants.BOS))
        if torch.cuda.is_available():
            dec_seq = dec_seq.cuda()

        step_rewards = []
        step_probs = []

        # decode
        for _ in range(G.max_len):
            rollout_tokens, prob = G.step_rollout(src_seq,
                                                  enc_output,
                                                  dec_seq,
                                                  n_rollout=self.n_rollout)
            rollout_tokens = rollout_tokens.transpose(1, 0)  # (n_rollout, 1)

            # Repeat the current prefix once per candidate and append the
            # candidate tokens, giving one extended prefix per rollout.
            partial_seq = helper.stack_seq(
                dec_seq.data, self.n_rollout)  # (n_rollout, cur_len)
            partial_seq = torch.cat([partial_seq, rollout_tokens],
                                    dim=1)  # (n_rollout, cur_len+1)
            # D cannot score sequences shorter than its minimum length.
            if partial_seq.size(1) < D.min_len:
                partial_seq = helper.pad_seq(partial_seq, D.min_len,
                                             Constants.PAD)

            reward = D(partial_seq)

            # Pick the candidate with the highest reward as the next token.
            top_i = reward.max(dim=0)[1].data
            next_token = rollout_tokens.squeeze(1)[top_i]
            # Wrap in a Variable so torch.cat() with dec_seq does not fail.
            next_token = autograd.Variable(next_token.unsqueeze(1))

            dec_seq = torch.cat([dec_seq, next_token], dim=1)

            step_rewards.append(reward)
            # Fix: the first step used to store `reward` here instead of
            # `prob`, so that step contributed reward*reward and no
            # generator gradient.
            step_probs.append(prob)

            if next_token[0] == Constants.EOS:
                break

        rewards = torch.cat(step_rewards)
        probs = torch.cat(step_probs)

        # back propagation
        loss = -torch.mean(rewards * probs)
        loss.backward()
        # Clip gradients to avoid exploding gradients.
        nn.utils.clip_grad_norm(G.get_trainable_parameters(), 40)
        optim_G.step()

        return dec_seq, rewards, probs, loss