def StepPGLoss(G, D, src_seq):
    ''' Policy gradient training on G: returns the PG loss for one source sequence. '''
    n_rollout = 6  # number of rollout candidates sampled per decoding step

    src_pos = G.get_position(src_seq.data)
    enc_output, *_ = G.encoder(src_seq, src_pos)

    dec_seq = autograd.Variable(torch.LongTensor(1, 1).fill_(Constants.BOS))
    if torch.cuda.is_available():
        dec_seq = dec_seq.cuda()
    rewards = None
    probs = None

    # decode
    for i in range(G.max_len):
        rollout_tokens, prob = G.step_rollout(src_seq, enc_output, dec_seq,
                                              n_rollout=n_rollout)
        rollout_tokens = rollout_tokens.transpose(1, 0)  # (n_rollout, 1)

        partial_seq = helper.stack(dec_seq.data, n_rollout)  # (n_rollout, cur_len)
        partial_seq = torch.cat([partial_seq, rollout_tokens],
                                dim=1)  # (n_rollout, cur_len+1)
        if partial_seq.size(1) < D.min_len:
            partial_seq = helper.pad_seq(partial_seq, D.min_len, Constants.PAD)
        partial_seq = autograd.Variable(partial_seq)

        reward = D(partial_seq)
        top_i = reward.max(dim=0)[1].data
        next_token = rollout_tokens.squeeze(1)[top_i]  # pick the highest-reward rollout as the next token
        next_token = autograd.Variable(
            next_token.unsqueeze(1))  # wrap as a Variable so torch.cat() works
        dec_seq = torch.cat([dec_seq, next_token], dim=1)

        rewards = torch.cat([rewards, reward]) if rewards is not None else reward
        probs = torch.cat([probs, prob]) if probs is not None else prob

        if next_token[0] == Constants.EOS:
            break

    loss = -torch.mean(rewards * probs)
    return loss
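
# Hedged sketch (not part of the original script): one way StepPGLoss could be
# driven per batch. `reinforce_step` and its `clip` parameter are hypothetical;
# the backward/clip/step sequence mirrors _train_G_PG further below.
def reinforce_step(G, D, optim_G, src_seq, clip=40):
    optim_G.zero_grad()
    loss = StepPGLoss(G, D, src_seq)  # -mean(reward * prob) over the decoded steps
    loss.backward()                   # REINFORCE-style gradient through the sampled probs
    nn.utils.clip_grad_norm(G.get_trainable_parameters(), clip)  # prevent exploding gradients
    optim_G.step()
    return loss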

def check_input(self, x):
    ''' Right-pad the input batch up to the discriminator's minimum length. '''
    if x.size(1) < self.min_len:
        x = helper.pad_seq(x.data, self.min_len, Constants.PAD)
    return x
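
# Hedged sketch of the padding behaviour check_input relies on. The real
# helper.pad_seq lives elsewhere in this repo; `_pad_seq_sketch` is a
# hypothetical stand-in that right-pads a (batch, len) LongTensor with the
# PAD index up to min_len.
def _pad_seq_sketch(x, min_len, pad_value):
    batch_size, cur_len = x.size(0), x.size(1)
    if cur_len >= min_len:
        return x
    pad = x.new(batch_size, min_len - cur_len).fill_(pad_value)
    return torch.cat([x, pad], dim=1)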

def main():
    parser = argparse.ArgumentParser()
    opt = options.train_options(parser)
    opt = parser.parse_args()

    opt.cuda = torch.cuda.is_available()
    opt.device = None if opt.cuda else -1

    # quick config overrides
    opt.exp_dir = './experiment/transformer-reinforce/use_billion'
    opt.load_vocab_from = './experiment/transformer/lang8-cor2err/vocab.pt'
    opt.build_vocab_from = './data/billion/billion.30m.model.vocab'

    opt.load_D_from = opt.exp_dir
    # opt.load_D_from = None

    # dataset params
    opt.max_len = 20

    # G params
    # opt.load_G_a_from = './experiment/transformer/lang8-err2cor/'
    # opt.load_G_b_from = './experiment/transformer/lang8-cor2err/'
    opt.d_word_vec = 300
    opt.d_model = 300
    opt.d_inner_hid = 600
    opt.n_head = 6
    opt.n_layers = 3
    opt.embs_share_weight = False
    opt.beam_size = 1
    opt.max_token_seq_len = opt.max_len + 2  # includes <BOS> and <EOS>
    opt.n_warmup_steps = 4000

    # D params
    opt.embed_dim = opt.d_model
    opt.num_kernel = 100
    opt.kernel_sizes = [3, 4, 5, 6, 7]
    opt.dropout_p = 0.25

    # train params
    opt.batch_size = 1
    opt.n_epoch = 10

    if not os.path.exists(opt.exp_dir):
        os.makedirs(opt.exp_dir)

    logging.basicConfig(filename=opt.exp_dir + '/.log',
                        format=LOG_FORMAT,
                        level=logging.DEBUG)
    logging.getLogger().addHandler(logging.StreamHandler())
    logging.info('Use CUDA? ' + str(opt.cuda))
    logging.info(opt)

    # ---------- prepare dataset ----------

    def len_filter(example):
        return len(example.src) <= opt.max_len and len(example.tgt) <= opt.max_len

    EN = SentencePieceField(init_token=Constants.BOS_WORD,
                            eos_token=Constants.EOS_WORD,
                            batch_first=True,
                            include_lengths=True)
    train = datasets.TranslationDataset(path='./data/dualgan/train',
                                        exts=('.billion.sp', '.use.sp'),
                                        fields=[('src', EN), ('tgt', EN)],
                                        filter_pred=len_filter)
    val = datasets.TranslationDataset(path='./data/dualgan/val',
                                      exts=('.billion.sp', '.use.sp'),
                                      fields=[('src', EN), ('tgt', EN)],
                                      filter_pred=len_filter)
    train_lang8, val_lang8 = Lang8.splits(exts=('.err.sp', '.cor.sp'),
                                          fields=[('src', EN), ('tgt', EN)],
                                          train='test',
                                          validation='test',
                                          test=None,
                                          filter_pred=len_filter)

    # load the vocabulary (keep it consistent across experiments)
    try:
        logging.info('Load vocab from %s' % opt.load_vocab_from)
        EN.load_vocab(opt.load_vocab_from)
    except FileNotFoundError:
        EN.build_vocab_from(opt.build_vocab_from)
        EN.save_vocab(opt.load_vocab_from)
    logging.info('Vocab len: %d' % len(EN.vocab))

    # sanity-check that Constants match the loaded vocabulary
    assert EN.vocab.stoi[Constants.BOS_WORD] == Constants.BOS
    assert EN.vocab.stoi[Constants.EOS_WORD] == Constants.EOS
    assert EN.vocab.stoi[Constants.PAD_WORD] == Constants.PAD
    assert EN.vocab.stoi[Constants.UNK_WORD] == Constants.UNK

    # ---------- init model ----------

    # G = build_G(opt, EN, EN)
    hidden_size = 512
    bidirectional = True
    encoder = EncoderRNN(len(EN.vocab),
                         opt.max_len,
                         hidden_size,
                         n_layers=1,
                         bidirectional=bidirectional)
    decoder = DecoderRNN(len(EN.vocab),
                         opt.max_len,
                         hidden_size * 2 if bidirectional else hidden_size,
                         n_layers=1,
                         dropout_p=0.2,
                         use_attention=True,
                         bidirectional=bidirectional,
                         eos_id=Constants.EOS,
                         sos_id=Constants.BOS)
    G = Seq2seq(encoder, decoder)
    for param in G.parameters():
        param.data.uniform_(-0.08, 0.08)

    # optim_G = ScheduledOptim(optim.Adam(
    #     G.get_trainable_parameters(),
    #     betas=(0.9, 0.98), eps=1e-09),
    #     opt.d_model, opt.n_warmup_steps)
    optim_G = optim.Adam(G.parameters(), lr=1e-4, betas=(0.9, 0.98), eps=1e-09)
    loss_G = NLLLoss(size_average=False)
    if torch.cuda.is_available():
        loss_G.cuda()

    # D: load a pretrained discriminator if available, otherwise build one
    if opt.load_D_from:
        D = load_model(opt.load_D_from)
    else:
        D = build_D(opt, EN)
    optim_D = torch.optim.Adam(D.parameters(), lr=1e-4)

    def get_criterion(vocab_size):
        ''' Cross-entropy criterion with zero weight on the PAD token '''
        weight = torch.ones(vocab_size)
        weight[Constants.PAD] = 0
        return nn.CrossEntropyLoss(weight, size_average=False)

    crit_G = get_criterion(len(EN.vocab))
    crit_D = nn.BCELoss()

    if opt.cuda:
        G.cuda()
        D.cuda()
        crit_G.cuda()
        crit_D.cuda()

    # ---------- train ----------

    trainer_D = trainers.DiscriminatorTrainer()

    if not opt.load_D_from:
        for epoch in range(1):
            logging.info('[Pretrain D Epoch %d]' % epoch)

            pool = helper.DiscriminatorDataPool(opt.max_len, D.min_len,
                                                Constants.PAD)

            # fill the pool with training data
            train_iter = data.BucketIterator(dataset=train,
                                             batch_size=opt.batch_size,
                                             device=opt.device,
                                             sort_key=lambda x: len(x.src),
                                             repeat=False)
            pool.fill(train_iter)

            # train D
            trainer_D.train(D,
                            train_iter=pool.batch_gen(),
                            crit=crit_D,
                            optimizer=optim_D)
            pool.reset()

        Checkpoint(model=D,
                   optimizer=optim_D,
                   epoch=0,
                   step=0,
                   input_vocab=EN.vocab,
                   output_vocab=EN.vocab).save(opt.exp_dir)

    def eval_D():
        pool = helper.DiscriminatorDataPool(opt.max_len, D.min_len,
                                            Constants.PAD)
        val_iter = data.BucketIterator(dataset=val,
                                       batch_size=opt.batch_size,
                                       device=opt.device,
                                       sort_key=lambda x: len(x.src),
                                       repeat=False)
        pool.fill(val_iter)
        trainer_D.evaluate(D, val_iter=pool.batch_gen(), crit=crit_D)
    # eval_D()

    # Train G
    ALPHA = 0
    for epoch in range(100):
        logging.info('[Epoch %d]' % epoch)

        train_iter = data.BucketIterator(dataset=train,
                                         batch_size=1,
                                         device=opt.device,
                                         sort_within_batch=True,
                                         sort_key=lambda x: len(x.src),
                                         repeat=False)

        for step, batch in enumerate(train_iter):
            src_seq = batch.src[0]
            src_length = batch.src[1]
            tgt_seq = src_seq[0].clone()
            # gold = tgt_seq[:, 1:]

            optim_G.zero_grad()
            loss_G.reset()

            decoder_outputs, decoder_hidden, other = G.rollout(
                src_seq, None, None, n_rollout=1)
            for i, step_output in enumerate(decoder_outputs):
                batch_size = tgt_seq.size(0)
                # loss_G.eval_batch(step_output.contiguous().view(batch_size, -1),
                #                   tgt_seq[:, i + 1])

            softmax_output = torch.exp(
                torch.cat([x for x in decoder_outputs], dim=0)).unsqueeze(0)
            softmax_output = helper.stack(softmax_output, 8)
            rollout = softmax_output.multinomial(1)

            tgt_seq = helper.pad_seq(tgt_seq.data,
                                     max_len=len(decoder_outputs) + 1,
                                     pad_value=Constants.PAD)
            tgt_seq = autograd.Variable(tgt_seq)
            for i, step_output in enumerate(decoder_outputs):
                batch_size = tgt_seq.size(0)
                loss_G.eval_batch(
                    step_output.contiguous().view(batch_size, -1),
                    tgt_seq[:, i + 1])

            G.zero_grad()
            loss_G.backward()
            optim_G.step()

            if step % 100 == 0:
                pred = torch.cat([x for x in other['sequence']], dim=1)
                print('[step %d] loss_rest %.4f' %
                      (epoch * len(train_iter) + step, loss_G.get_loss()))
                print('%s -> %s' % (EN.reverse(tgt_seq.data)[0],
                                    EN.reverse(pred.data)[0]))

    # Reinforce training of G: freeze D so it only provides rewards
    for p in D.parameters():
        p.requires_grad = False
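
# Hedged toy illustration (not used by main above) of what get_criterion's
# zeroed PAD weight buys: target positions equal to PAD contribute nothing to
# the summed loss, so padded steps produce no gradient. The vocab size of 5
# and the use of index 0 as PAD are assumptions local to this example.
def _pad_masked_loss_demo():
    crit = nn.CrossEntropyLoss(torch.Tensor([0, 1, 1, 1, 1]), size_average=False)  # index 0 stands in for PAD
    logits = autograd.Variable(torch.randn(3, 5))
    targets = autograd.Variable(torch.LongTensor([2, 4, 0]))  # last target is PAD
    return crit(logits, targets)  # equals the loss computed over the first two positions only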

def train_G_PG(self, G, D, optim_G, src_seq):
    ''' Policy gradient training on G with a beam of size top_k '''
    batch_size = src_seq.size(0)

    for p in D.parameters():
        p.requires_grad = False
    # intermediate D reward
    # Dual training also folds reconstruction quality into the reward;
    # that is not considered here for now.

    optim_G.zero_grad()

    # encode
    src_pos = G.get_position(src_seq.data)
    enc_output, *_ = G.encoder(src_seq, src_pos)

    # init rollout variables
    # enc_output = helper.stack(enc_output, n_rollout, dim=0)
    # src_seq = helper.stack(src_seq, n_rollout, dim=0)
    cur_seq = autograd.Variable(
        torch.LongTensor(self.top_k, 1).fill_(Constants.BOS))
    if torch.cuda.is_available():
        cur_seq = cur_seq.cuda()
    rewards, probs = [], []
    final_seqs = []
    candidates = []

    # decode
    for i in range(G.max_len):
        rollouts, softmax_outs = [], []
        for s in cur_seq.chunk(self.top_k, dim=0):
            rollout_tokens, softmax_out = G.step_rollout(
                src_seq, enc_output, s, n_rollout=self.n_rollout)
            rollout_tokens = rollout_tokens.transpose(1, 0)  # (batch * k, 1)
            rollouts.append(rollout_tokens)
            softmax_outs.append(softmax_out)
        rollouts = torch.cat(rollouts, dim=0)
        softmax_outs = torch.cat(softmax_outs, dim=0)

        # replicate the current seq n times so it can be joined with the
        # rollout tokens (each seq has n rollouts)
        cur_seq = cur_seq.data
        cur_seq = helper.inflate(cur_seq, self.n_rollout, 0)  # (k * n, cur_len)
        cur_seq = torch.cat([cur_seq, rollouts], dim=1)  # (batch * k, cur_len+1)

        _cur_seq = cur_seq.clone()
        if _cur_seq.size(1) < D.min_len:
            _cur_seq = helper.pad_seq(_cur_seq, D.min_len, Constants.PAD)
        reward = D(_cur_seq)  # (batch * k)

        # store rewards and probs for the loss computation
        # rewards = torch.cat([rewards, reward]) if rewards is not None else reward
        # probs = torch.cat([probs, softmax_outs]) if probs is not None else softmax_outs

        # select the top-k sequences from cur_seq by reward
        sorted_rewards, indices = reward.sort(dim=0, descending=True)
        candidates = []
        for idx in indices.data.split(1):
            seq = cur_seq[idx]
            # skip sequences already present in candidates
            if not any(torch.equal(seq, x) for x in candidates):
                # if the candidate's latest token is EOS, move it to final_seqs
                if seq[:, -1][0] == Constants.EOS:
                    final_seqs.append(seq)
                else:
                    candidates.append(seq)
                # keep the selected rewards and probs
                rewards.append(reward[idx])
                probs.append(softmax_outs[idx])
            if len(candidates) == (self.top_k - len(final_seqs)):
                break

        # have all beams finished?
        if len(candidates) == 0:
            break
        else:
            cur_seq = autograd.Variable(torch.cat(candidates, dim=0))

    final_seqs += candidates
    rewards = torch.cat(rewards)
    probs = torch.cat(probs)

    # back-propagation
    loss = -torch.mean(rewards * probs)
    loss.backward()
    nn.utils.clip_grad_norm(G.get_trainable_parameters(), 40)  # prevent exploding gradients
    optim_G.step()

    return final_seqs[0], rewards, probs, loss
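
# Hedged sketch of the replication step train_G_PG relies on. The real
# helper.inflate is defined elsewhere in this repo and may differ;
# `_inflate_sketch` is a hypothetical stand-in that repeats each row of a
# (k, len) tensor n_rollout times consecutively, giving the (k * n, len)
# layout onto which the concatenated rollout tokens line up.
def _inflate_sketch(x, n_rollout):
    index = torch.arange(0, x.size(0)).long().view(-1, 1).repeat(1, n_rollout).view(-1)
    if x.is_cuda:
        index = index.cuda()
    return x.index_select(0, index)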

def _train_G_PG(self, G, D, optim_G, src_seq):
    ''' Policy gradient training on G (greedy, no beam) '''
    # TODO: add beam?
    for p in D.parameters():
        p.requires_grad = False
    # intermediate D reward
    # Dual training also folds reconstruction quality into the reward;
    # that is not considered here for now.

    optim_G.zero_grad()

    src_pos = G.get_position(src_seq.data)
    enc_output, *_ = G.encoder(src_seq, src_pos)

    dec_seq = autograd.Variable(
        torch.LongTensor(1, 1).fill_(Constants.BOS))
    if torch.cuda.is_available():
        dec_seq = dec_seq.cuda()
    rewards = None
    probs = None

    # decode
    for i in range(G.max_len):
        rollout_tokens, prob = G.step_rollout(src_seq,
                                              enc_output,
                                              dec_seq,
                                              n_rollout=self.n_rollout)
        rollout_tokens = rollout_tokens.transpose(1, 0)  # (n_rollout, 1)

        partial_seq = helper.stack_seq(
            dec_seq.data, self.n_rollout)  # (n_rollout, cur_len)
        partial_seq = torch.cat([partial_seq, rollout_tokens],
                                dim=1)  # (n_rollout, cur_len+1)
        if partial_seq.size(1) < D.min_len:
            partial_seq = helper.pad_seq(partial_seq, D.min_len, Constants.PAD)

        reward = D(partial_seq)
        top_i = reward.max(dim=0)[1].data
        next_token = rollout_tokens.squeeze(1)[top_i]  # pick the highest-reward rollout as the next token
        next_token = autograd.Variable(
            next_token.unsqueeze(1))  # wrap as a Variable so torch.cat() works
        dec_seq = torch.cat([dec_seq, next_token], dim=1)

        rewards = torch.cat([rewards, reward]) if rewards is not None else reward
        probs = torch.cat([probs, prob]) if probs is not None else prob

        if next_token[0] == Constants.EOS:
            break

    # back-propagation
    loss = -torch.mean(rewards * probs)
    loss.backward()
    nn.utils.clip_grad_norm(G.get_trainable_parameters(), 40)  # prevent exploding gradients
    optim_G.step()

    return dec_seq, rewards, probs, loss
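
# Hedged toy illustration of the loss form shared by StepPGLoss, train_G_PG and
# _train_G_PG: loss = -mean(reward * prob). Gradients flow only through `prob`
# (the generator side); `reward` comes from the frozen discriminator and acts
# as a constant scale on each step's gradient. All names below are local to
# this example.
def _pg_loss_demo():
    prob = autograd.Variable(torch.Tensor([0.2, 0.7, 0.1]), requires_grad=True)
    reward = autograd.Variable(torch.Tensor([0.9, 0.4, 0.1]))  # no grad: stands in for D's frozen output
    loss = -torch.mean(reward * prob)
    loss.backward()
    return prob.grad  # equals -reward / 3: higher reward, stronger push to raise that prob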