class Trainer: def __init__(self, model, sv, tv, optim, trg_dict_size, valid_data=None, tests_data=None, n_critic=1): self.lamda = 5 self.eps = 1e-20 #self.beta_KL = 0.005 self.beta_KL = 0. self.beta_RLGen = 0.2 self.clip_rate = 0. self.beta_RLBatch = 0. self.model = model self.decoder = model.decoder self.classifier = self.decoder.classifier self.sv, self.tv = sv, tv self.trg_dict_size = trg_dict_size self.n_critic = 1 self.translator_sample = Translator(self.model, sv, tv, k=1, noise=False) #self.translator = Translator(model, sv, tv, k=10) if isinstance(optim, list): self.optim_G, self.optim_D = optim[0], optim[1] self.optim_G.init_optimizer(self.model.parameters()) self.optim_D.init_optimizer(self.model.parameters()) else: self.optim_G = Optim( 'adam', 10e-05, wargs.max_grad_norm, learning_rate_decay=wargs.learning_rate_decay, start_decay_from=wargs.start_decay_from, last_valid_bleu=wargs.last_valid_bleu ) self.optim_G.init_optimizer(self.model.parameters()) self.optim_D = optim self.optim_D.init_optimizer(self.model.parameters()) self.optim = [self.optim_G, self.optim_D] ''' self.optim_RL = Optim( 'adadelta', 1.0, wargs.max_grad_norm, learning_rate_decay=wargs.learning_rate_decay, start_decay_from=wargs.start_decay_from, last_valid_bleu=wargs.last_valid_bleu ) self.optim_RL.init_optimizer(self.model.parameters()) ''' self.maskSoftmax = MaskSoftmax() self.valid_data = valid_data self.tests_data = tests_data def mt_eval(self, eid, bid, optim=None): if optim: self.optim = optim state_dict = { 'model': self.model.state_dict(), 'epoch': eid, 'batch': bid, 'optim': self.optim } if wargs.save_one_model: model_file = '{}.pt'.format(wargs.model_prefix) else: model_file = '{}_e{}_upd{}.pt'.format(wargs.model_prefix, eid, bid) tc.save(state_dict, model_file) wlog('Saving temporary model in {}'.format(model_file)) self.model.eval() tor0 = Translator(self.model, self.sv, self.tv, print_att=wargs.print_att) BLEU = tor0.trans_eval(self.valid_data, eid, bid, model_file, self.tests_data) self.model.train() return BLEU # p1: (max_tlen_batch, batch_size, vocab_size) def distance(self, P, Q, y_masks, type='JS', y_gold=None): B = y_masks.size(1) hypo_N = y_masks.data.sum() if Q.size(0) > P.size(0): Q = Q[:(P.size(0) + 1)] if type == 'JS': #D_kl = tc.mean(tc.sum((tc.log(p1) - tc.log(p2)) * p1, dim=-1).squeeze(), dim=0) M = (P + Q) / 2. 
D_kl1 = tc.sum((tc.log(P) - tc.log(M)) * P, dim=-1).squeeze() D_kl2 = tc.sum((tc.log(Q) - tc.log(M)) * Q, dim=-1).squeeze() Js = 0.5 * D_kl1 + 0.5 * D_kl2 sent_batch_dist = tc.sum(Js * y_masks) / B Js = Js / y_masks.sum(0)[None, :] word_level_dist = tc.sum(Js * y_masks) / B del M, D_kl1, D_kl2, Js elif type == 'KL': KL = tc.sum(P * (tc.log(P + self.eps) - tc.log(Q + self.eps)), dim=-1) # (L, B, V) -> (L, B) sent_batch_dist = tc.sum(KL * y_masks) / B word_level_dist0 = tc.sum(KL * y_masks) / hypo_N KL = KL / y_masks.sum(0)[None, :] #print W_KL.data word_level_dist1 = tc.sum(KL * y_masks) / B #print W_dist.data[0], y_masks.size(1) del KL elif type == 'KL-sent': #print p1[0] #print p2[0] #print '-----------------------------' p1 = tc.gather(p1, 2, y_gold[:, :, None])[:, :, 0] p2 = tc.gather(p2, 2, y_gold[:, :, None])[:, :, 0] # p1 (max_tlen_batch, batch_size) #print (p2 < 1) == False KL = (y_masks * (tc.log(p1) - tc.log(p2))) * p1 sent_batch_dist = tc.sum(KL) / B KL = KL / y_masks.sum(0)[None, :] word_level_dist = tc.sum(KL * y_masks) / B # KL: (1, batch_size) del p1, p2, KL return sent_batch_dist, word_level_dist0, word_level_dist1 def hyps_padding_dist(self, oracle, hyps_L, y_gold_maxL, p_y_hyp): #hyps_dist = [None] * B B, hyps_dist, hyps = oracle.size(1), [], [] # oracle, w/o bos assert (B == len(hyps_L)) and (oracle.size(0) == p_y_hyp.size(0)) for bidx in range(B): hyp_L = hyps_L[bidx] - 1 # remove bos if hyp_L < y_gold_maxL: padding = tc.ones(y_gold_maxL - hyp_L) / self.trg_dict_size padding = padding[:, None].expand(padding.size(0), self.trg_dict_size) #pad = pad[:, None].expand((pad.size(0), one_p_y_hyp.size(-1))) padding = Variable(padding, requires_grad=False) if wargs.gpu_id and not padding.is_cuda: padding = padding.cuda() #print one_p_y_hyp.size(0), pad.size(0) #print tc.cat((p_y_hyp[:hyp_L, bidx, :], padding), dim=0).size() hyps_dist.append(tc.cat((p_y_hyp[:hyp_L, bidx, :], padding), dim=0)) hyps.append(tc.cat((oracle[:hyp_L, bidx], Variable(PAD * tc.ones(y_gold_maxL - hyp_L).long()).cuda()), dim=0)) else: hyps_dist.append(p_y_hyp[:y_gold_maxL, bidx, :]) hyps.append(oracle[:y_gold_maxL, bidx]) #hyps_dist[bidx] = one_p_y_hyp hyps_dist = tc.stack(hyps_dist, dim=1) hyps = tc.stack(hyps, dim=1) return hyps_dist, hyps def gumbel_sampling(self, B, y_maxL, feed_gold_out, noise=False): # feed_gold_out (L * B, V) logit = self.classifier.pred_map(feed_gold_out, noise=noise) if logit.is_cuda: logit = logit.cpu() hyps = tc.max(logit, 1)[1] # hyps (L*B, 1) hyps = hyps.view(y_maxL, B) hyps[0] = BOS * tc.ones(B).long() # first words are <s> # hyps (L, B) c1 = tc.clamp((hyps.data - EOS), min=0, max=self.trg_dict_size) c2 = tc.clamp((EOS - hyps.data), min=0, max=self.trg_dict_size) _hyps = c1 + c2 _hyps = tc.cat([_hyps, tc.zeros(B).long().unsqueeze(0)], 0) _hyps = tc.min(_hyps, 0)[1] #_hyps = tc.max(0 - _hyps, 0)[1] # idx: (1, B) hyps_L = _hyps.view(-1).tolist() hyps_mask = tc.zeros(y_maxL, B) for bid in range(B): hyps_mask[:, bid][:hyps_L[bid]] = 1. 
hyps_mask = Variable(hyps_mask, requires_grad=False) if wargs.gpu_id and not hyps_mask.is_cuda: hyps_mask = hyps_mask.cuda() if wargs.gpu_id and not hyps.is_cuda: hyps = hyps.cuda() return hyps, hyps_mask, hyps_L def try_trans(self, srcs, ref): # (len, 1) #src = sent_filter(list(srcs[:, bid].data)) x_filter = sent_filter(list(srcs)) y_filter = sent_filter(list(ref)) #wlog('\n[{:3}] {}'.format('Src', idx2sent(x_filter, self.sv))) #wlog('[{:3}] {}'.format('Ref', idx2sent(y_filter, self.tv))) onebest, onebest_ids, _ = self.translator_sample.trans_onesent(x_filter) #wlog('[{:3}] {}'.format('Out', onebest)) # no EOS and BOS return onebest_ids def beamsearch_sampling(self, srcs, trgs, eos=True): # y_masks: (trg_max_len, batch_size) B = srcs.size(1) oracles, oracles_L = [None] * B, [None] * B for bidx in range(B): onebest_ids = self.try_trans(srcs[:, bidx].data, trgs[:, bidx].data) if len(onebest_ids) == 0 or onebest_ids[0] != BOS: onebest_ids = [BOS] + onebest_ids if eos is True and onebest_ids[-1] != EOS: onebest_ids = onebest_ids + [EOS] oracles_L[bidx] = len(onebest_ids) oracles[bidx] = onebest_ids maxL = max(oracles_L) for bidx in range(B): cur_L, oracle = oracles_L[bidx], oracles[bidx] if cur_L < maxL: oracles[bidx] = oracle + [PAD] * (maxL - cur_L) oracles = Variable(tc.Tensor(oracles).long().t(), requires_grad=False) # -> (L, B) if wargs.gpu_id and not oracles.is_cuda: oracles = oracles.cuda() oracles_mask = oracles.ne(PAD).float() return oracles, oracles_mask, oracles_L def train(self, dh, dev_input, k, merge=False, name='default', percentage=0.1): #if (k + 1) % 1 == 0 and self.valid_data and self.tests_data: # wlog('Evaluation on dev ... ') # mt_eval(valid_data, self.model, self.sv, self.tv, # 0, 0, [self.optim, self.optim_RL, self.optim_G], self.tests_data) batch_count = len(dev_input) self.model.train() self.sampler = Nbs(self.model, self.tv, k=3, noise=False, print_att=False, batch_sample=True) for eid in range(wargs.start_epoch, wargs.max_epochs + 1): #self.optim_G.init_optimizer(self.model.parameters()) #self.optim_RL.init_optimizer(self.model.parameters()) size = int(percentage * batch_count) shuffled_batch_idx = tc.randperm(batch_count) wlog('{} NEW Epoch {}'.format('-' * 50, '-' * 50)) wlog('{}, Epo:{:>2}/{:>2} start, random {}/{}({:.2%}) calc BLEU ... 
'.format( name, eid, wargs.max_epochs, size, batch_count, percentage), False) param_1, param_2, param_3, param_4, param_5, param_6 = [], [], [], [], [], [] for k in range(size): bid, half_size = shuffled_batch_idx[k], wargs.batch_size # srcs: (max_sLen_batch, batch_size, emb), trgs: (max_tLen_batch, batch_size, emb) if merge is False: _, srcs, _, trgs, _, slens, srcs_m, trgs_m = dev_input[bid] else: _, srcs, _, trgs, _, slens, srcs_m, trgs_m = dh.merge_batch(dev_input[bid])[0] trgs, trgs_m = trgs[0], trgs_m[0] # we only use the first dev reference if wargs.sampling == 'gumbeling': oracles, oracles_mask, oracles_L = self.gumbel_sampling(B, y_gold_maxL, feed_gold_out, True) elif wargs.sampling == 'truncation': oracles, oracles_mask, oracles_L = self.beamsearch_sampling(srcs, trgs) elif wargs.sampling == 'length_limit': batch_beam_trgs = self.sampler.beam_search_trans(srcs, srcs_m, trgs_m) hyps = [list(zip(*b)[0]) for b in batch_beam_trgs] oracles = batch_search_oracle(hyps, trgs[1:], trgs_m[1:]) if wargs.gpu_id and not oracles.is_cuda: oracles = oracles.cuda() oracles_mask = oracles.ne(0).float() oracles_L = oracles_mask.sum(0).data.int().tolist() # oracles same with trgs, with bos and eos,(L, B) param_1.append(BLToStrList(oracles[1:-1].t(), [l-2 for l in oracles_L])) param_2.append(BLToStrList(trgs[1:-1].t(), trgs_m[1:-1].sum(0).data.int().tolist())) param_3.append(BLToStrList(oracles[1:-1, :half_size].t(), [l-2 for l in oracles_L[:half_size]])) param_4.append(BLToStrList(trgs[1:-1, :half_size].t(), trgs_m[1:-1, :half_size].sum(0).data.int().tolist())) param_5.append(BLToStrList(oracles[1:-1, half_size:].t(), [l-2 for l in oracles_L[half_size:]])) param_6.append(BLToStrList(trgs[1:-1, half_size:].t(), trgs_m[1:-1, half_size:].sum(0).data.int().tolist())) start_bat_bleu_hist = bleu('\n'.join(param_3), ['\n'.join(param_4)], logfun=debug) start_bat_bleu_new = bleu('\n'.join(param_5), ['\n'.join(param_6)], logfun=debug) start_bat_bleu = bleu('\n'.join(param_1), ['\n'.join(param_2)], logfun=debug) wlog('Random BLEU on history {}, new {}, mix {}'.format( start_bat_bleu_hist, start_bat_bleu_new, start_bat_bleu)) wlog('Model selection and testing ... ') self.mt_eval(eid, 0, [self.optim_G, self.optim_D]) if start_bat_bleu > 0.9: wlog('Better BLEU ... go to next data history ...') return s_kl_seen, w_kl_seen0, w_kl_seen1, rl_gen_seen, rl_rho_seen, rl_bat_seen, w_mle_seen, \ s_mle_seen, ppl_seen = 0., 0., 0., 0., 0., 0., 0., 0., 0. for bid in range(batch_count): if merge is False: _, srcs, _, trgs, _, slens, srcs_m, trgs_m = dev_input[bid] else: _, srcs, _, trgs, _, slens, srcs_m, trgs_m = dh.merge_batch(dev_input[bid], True)[0] trgs, trgs_m = trgs[0], trgs_m[0] gold_feed, gold_feed_mask = trgs[:-1], trgs_m[:-1] gold, gold_mask = trgs[1:], trgs_m[1:] B, y_gold_maxL = srcs.size(1), gold_feed.size(0) N = gold.data.ne(PAD).sum() debug('B:{}, gold_feed_ymaxL:{}, N:{}'.format(B, y_gold_maxL, N)) ################################################################################### debug('Optimizing KL distance ................................ 
{}'.format(name)) #self.model.zero_grad() self.optim_G.zero_grad() feed_gold_out, _ = self.model(srcs, gold_feed, srcs_m, gold_feed_mask) p_y_gold = self.classifier.logit_to_prob(feed_gold_out) # p_y_gold: (gold_max_len - 1, B, trg_dict_size) if wargs.sampling == 'gumbeling': oracles, oracles_mask, oracles_L = self.gumbel_sampling(B, y_gold_maxL, feed_gold_out, True) elif wargs.sampling == 'truncation': oracles, oracles_mask, oracles_L = self.beamsearch_sampling(srcs, trgs) elif wargs.sampling == 'length_limit': # w/o eos batch_beam_trgs = self.sampler.beam_search_trans(srcs, srcs_m, trgs_m) hyps = [list(zip(*b)[0]) for b in batch_beam_trgs] oracles = batch_search_oracle(hyps, trgs[1:], trgs_m[1:]) if wargs.gpu_id and not oracles.is_cuda: oracles = oracles.cuda() oracles_mask = oracles.ne(0).float() oracles_L = oracles_mask.sum(0).data.int().tolist() oracle_feed, oracle_feed_mask = oracles[:-1], oracles_mask[:-1] oracle, oracle_mask = oracles[1:], oracles_mask[1:] # oracles same with trgs, with bos and eos,(L, B) feed_oracle_out, _ = self.model(srcs, oracle_feed, srcs_m, oracle_feed_mask) p_y_hyp = self.classifier.logit_to_prob(feed_oracle_out) p_y_hyp_pad, oracle = self.hyps_padding_dist(oracle, oracles_L, y_gold_maxL, p_y_hyp) #wlog('feed oracle dist: {}, feed gold dist: {}, oracle: {}'.format(p_y_hyp_pad.size(), p_y_gold.size(), oracle.size())) #B_KL_loss = self.distance(p_y_gold, p_y_hyp_pad, hyps_mask[1:], type='KL', y_gold=gold) S_KL_loss, W_KL_loss0, W_KL_loss1 = self.distance( p_y_gold, p_y_hyp_pad, gold_mask, type='KL', y_gold=gold) debug('KL: Sent-level {}, Word0-level {}, Word1-level {}'.format( S_KL_loss.data[0], W_KL_loss0.data[0], W_KL_loss1.data[0])) s_kl_seen += S_KL_loss.data[0] w_kl_seen0 += W_KL_loss0.data[0] w_kl_seen1 += W_KL_loss1.data[0] del p_y_hyp, feed_oracle_out ################################################################################### debug('Optimizing RL(Gen) .......... {}'.format(name)) hyps_list = BLToStrList(oracle[:-1].t(), [l-2 for l in oracles_L], True) trgs_list = BLToStrList(trgs[1:-1].t(), trgs_m[1:-1].sum(0).data.int().tolist(), True) bleus_sampling = [] for hyp, ref in zip(hyps_list, trgs_list): bleus_sampling.append(bleu(hyp, [ref], logfun=debug)) bleus_sampling = toVar(bleus_sampling, wargs.gpu_id) oracle_mask = oracle.ne(0).float() p_y_ahyp = p_y_hyp_pad.gather(2, oracle[:, :, None])[:, :, 0] p_y_ahyp = ((p_y_ahyp + self.eps).log() * oracle_mask).sum(0) / oracle_mask.sum(0) p_y_agold = p_y_gold.gather(2, gold[:, :, None])[:, :, 0] p_y_agold = ((p_y_agold + self.eps).log() * gold_mask).sum(0) / gold_mask.sum(0) r_theta = p_y_ahyp / p_y_agold A = 1. - bleus_sampling RL_Gen_loss = tc.min(r_theta * A, clip(r_theta, self.clip_rate) * A).sum() RL_Gen_loss = (RL_Gen_loss).div(B) debug('...... RL(Gen) cliped loss {}'.format(RL_Gen_loss.data[0])) rl_gen_seen += RL_Gen_loss.data[0] del p_y_agold ################################################################################### debug('Optimizing RL(Batch) -> Gap of MLE and BLEU ... rho ... feed onebest .... 
') param_1 = BLToStrList(oracles[1:-1].t(), [l-2 for l in oracles_L]) param_2 = BLToStrList(trgs[1:-1].t(), trgs_m[1:-1].sum(0).data.int().tolist()) rl_bat_bleu = bleu(param_1, [param_2], logfun=debug) rl_avg_bleu = tc.mean(bleus_sampling).data[0] rl_rho = cor_coef(p_y_ahyp, bleus_sampling, eps=self.eps) rl_rho_seen += rl_rho.data[0] # must use data, accumulating Variable needs more memory #p_y_hyp = p_y_hyp.exp() #p_y_hyp = (p_y_hyp * self.lamda / 3).exp() #p_y_hyp = self.maskSoftmax(p_y_hyp) p_y_ahyp = p_y_ahyp[None, :] p_y_ahyp_T = p_y_ahyp.t().expand(B, B) p_y_ahyp = p_y_ahyp.expand(B, B) p_y_ahyp_sum = p_y_ahyp_T + p_y_ahyp + self.eps #bleus_sampling = bleus_sampling[None, :].exp() bleus_sampling = self.maskSoftmax(self.lamda * bleus_sampling[None, :]) bleus_T = bleus_sampling.t().expand(B, B) bleus = bleus_sampling.expand(B, B) bleus_sum = bleus_T + bleus + self.eps #print 'p_y_hyp_sum......................' #print p_y_hyp_sum.data RL_Batch_loss = p_y_ahyp / p_y_ahyp_sum * tc.log(bleus_T / bleus_sum) + \ p_y_ahyp_T / p_y_ahyp_sum * tc.log(bleus / bleus_sum) #RL_Batch_loss = tc.sum(-RL_Batch_loss * toVar(1 - tc.eye(B))).div(B) RL_Batch_loss = tc.sum(-RL_Batch_loss * toVar(1 - tc.eye(B), wargs.gpu_id)) debug('RL(Batch) Mean BLEU: {}, rl_batch_loss: {}, rl_rho: {}, Bat BLEU: {}'.format( rl_avg_bleu, RL_Batch_loss.data[0], rl_rho.data[0], rl_bat_bleu)) rl_bat_seen += RL_Batch_loss.data[0] del oracles, oracles_mask, oracle_feed, oracle_feed_mask, oracle, oracle_mask,\ p_y_ahyp, bleus_sampling, bleus, p_y_ahyp_T, p_y_ahyp_sum, bleus_T, bleus_sum ''' (self.beta_KL * S_KL_loss + self.beta_RLGen * RL_Gen_loss + \ self.beta_RLBatch * RL_Batch_loss).backward(retain_graph=True) mle_loss, grad_output, _ = memory_efficient( feed_gold_out, gold, gold_mask, self.model.classifier) feed_gold_out.backward(grad_output) ''' (self.beta_KL * W_KL_loss0 + self.beta_RLGen * RL_Gen_loss + \ self.beta_RLBatch * RL_Batch_loss).backward(retain_graph=True) self.optim_G.step() ###################################################### discrimitor #mle_loss, _, _ = self.classifier(feed_gold_out, gold, gold_mask) #mle_loss = mle_loss.div(B) #mle_loss = mle_loss.data[0] self.optim_D.zero_grad() mle_loss, _, _ = self.classifier.snip_back_prop(feed_gold_out, gold, gold_mask) self.optim_D.step() w_mle_seen += ( mle_loss / N ) s_mle_seen += ( mle_loss / B ) ppl_seen += math.exp(mle_loss/N) wlog('Epo:{:>2}/{:>2}, Bat:[{}/{}], W0-KL {:4.2f}, W1-KL {:4.2f}, ' 'S-RLGen {:4.2f}, B-rho {:4.2f}, B-RLBat {:4.2f}, W-MLE:{:4.2f}, ' 'S-MLE:{:4.2f}, W-ppl:{:4.2f}, B-bleu:{:4.2f}, A-bleu:{:4.2f}'.format( eid, wargs.max_epochs, bid, batch_count, W_KL_loss0.data[0], W_KL_loss1.data[0], RL_Gen_loss.data[0], rl_rho.data[0], RL_Batch_loss.data[0], mle_loss/N, mle_loss/B, math.exp(mle_loss/N), rl_bat_bleu, rl_avg_bleu)) #wlog('=' * 100) del S_KL_loss, W_KL_loss0, W_KL_loss1, RL_Gen_loss, RL_Batch_loss, feed_gold_out wlog('End epoch: S-KL {:4.2f}, W0-KL {:4.2f}, W1-KL {:4.2f}, S-RLGen {:4.2f}, B-rho ' '{:4.2f}, B-RLBat {:4.2f}, W-MLE {:4.2f}, S-MLE {:4.2f}, W-ppl {:4.2f}'.format( s_kl_seen/batch_count, w_kl_seen0/batch_count, w_kl_seen1/batch_count, rl_gen_seen/batch_count, rl_rho_seen/batch_count, rl_bat_seen/batch_count, w_mle_seen/batch_count, s_mle_seen/batch_count, ppl_seen/batch_count))
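# --------------------------------------------------------------------------
# Hedged sketch (not part of the original code): the "Optimizing KL distance"
# step above compares the model's distribution under gold teacher forcing with
# its distribution when fed the sampled oracle prefix.  The helper below shows
# the same masked KL with the two normalizations used in Trainer.distance()
# (per sentence and per real token).  Tensor names are illustrative only.
import torch


def masked_kl(p_gold, p_hyp, mask, eps=1e-20):
    """KL(p_gold || p_hyp) over the vocabulary axis, ignoring padded positions.

    p_gold, p_hyp: (L, B, V) probabilities, L = target length, B = batch size.
    mask:          (L, B) with 1.0 for real tokens and 0.0 for padding.
    Returns (sentence-level mean, word-level mean) over the batch.
    """
    kl = (p_gold * ((p_gold + eps).log() - (p_hyp + eps).log())).sum(-1)  # (L, B)
    batch_size = mask.size(1)
    sent_level = (kl * mask).sum() / batch_size    # averaged per sentence
    word_level = (kl * mask).sum() / mask.sum()    # averaged per real token
    return sent_level, word_level
# --------------------------------------------------------------------------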
def main(): init_dir(wargs.dir_model) init_dir(wargs.dir_valid) vocab_data = {} train_srcD_file = wargs.src_vocab_from wlog('\nPreparing out of domain source vocabulary from {} ... '.format( train_srcD_file)) src_vocab = extract_vocab(train_srcD_file, wargs.src_dict, wargs.src_dict_size) #DANN train_srcD_file_domain = wargs.src_domain_vocab_from wlog('\nPreparing in domain source vocabulary from {} ...'.format( train_srcD_file_domain)) src_vocab = updata_vocab(train_srcD_file_domain, src_vocab, wargs.src_dict, wargs.src_dict_size) vocab_data['src'] = src_vocab train_trgD_file = wargs.trg_vocab_from wlog('\nPreparing out of domain target vocabulary from {} ... '.format( train_trgD_file)) trg_vocab = extract_vocab(train_trgD_file, wargs.trg_dict, wargs.trg_dict_size) #DANN train_trgD_file_domain = wargs.trg_domain_vocab_from wlog('\nPreparing in domain target vocabulary from {} ... '.format( train_trgD_file_domain)) trg_vocab = updata_vocab(train_trgD_file_domain, trg_vocab, wargs.trg_dict, wargs.trg_dict_size) vocab_data['trg'] = trg_vocab train_src_file = wargs.train_src train_trg_file = wargs.train_trg if wargs.fine_tune is False: wlog('\nPreparing out of domain training set from {} and {} ... '. format(train_src_file, train_trg_file)) train_src_tlst, train_trg_tlst = wrap_data( train_src_file, train_trg_file, vocab_data['src'], vocab_data['trg'], max_seq_len=wargs.max_seq_len) else: wlog('\nNo out of domain trainin set ...') #DANN train_src_file_domain = wargs.train_src_domain train_trg_file_domain = wargs.train_trg_domain wlog('\nPreparing in domain training set from {} and {}...'.format( train_src_file_domain, train_trg_file_domain)) train_src_tlst_domain, train_trg_tlst_domain = wrap_data( train_src_file_domain, train_trg_file_domain, vocab_data['src'], vocab_data['trg'], max_seq_len=wargs.max_seq_len) ''' list [torch.LongTensor (sentence), torch.LongTensor, torch.LongTensor, ...] no padding ''' valid_file = '{}{}.{}'.format(wargs.val_tst_dir, wargs.val_prefix, wargs.val_src_suffix) wlog('\nPreparing validation set from {} ... '.format(valid_file)) valid_src_tlst, valid_src_lens = val_wrap_data(valid_file, src_vocab) if wargs.fine_tune is False: wlog('Out of domain Sentence-pairs count in training data: {}'.format( len(train_src_tlst))) wlog('In domain Sentence-pairs count in training data: {}'.format( len(train_src_tlst_domain))) src_vocab_size, trg_vocab_size = vocab_data['src'].size( ), vocab_data['trg'].size() wlog('Vocabulary size: |source|={}, |target|={}'.format( src_vocab_size, trg_vocab_size)) if wargs.fine_tune is False: batch_train = Input(train_src_tlst, train_trg_tlst, wargs.batch_size) else: batch_train = None batch_valid = Input(valid_src_tlst, None, 1, volatile=True) #DANN batch_train_domain = Input(train_src_tlst_domain, train_trg_tlst_domain, wargs.batch_size) tests_data = None if wargs.tests_prefix is not None: init_dir(wargs.dir_tests) tests_data = {} for prefix in wargs.tests_prefix: init_dir(wargs.dir_tests + '/' + prefix) test_file = '{}{}.{}'.format(wargs.val_tst_dir, prefix, wargs.val_src_suffix) wlog('Preparing test set from {} ... 
'.format(test_file)) test_src_tlst, _ = val_wrap_data(test_file, src_vocab) tests_data[prefix] = Input(test_src_tlst, None, 1, volatile=True) sv = vocab_data['src'].idx2key tv = vocab_data['trg'].idx2key nmtModel = NMT(src_vocab_size, trg_vocab_size) if wargs.pre_train is not None: assert os.path.exists(wargs.pre_train), 'Requires pre-trained model' _dict = _load_model(wargs.pre_train) # initializing parameters of interactive attention model class_dict = None if len(_dict) == 4: model_dict, eid, bid, optim = _dict elif len(_dict) == 5: model_dict, class_dict, eid, bid, optim = _dict for name, param in nmtModel.named_parameters(): if name in model_dict: param.requires_grad = not wargs.fix_pre_params param.data.copy_(model_dict[name]) wlog('{:7} -> grad {}\t{}'.format('Model', param.requires_grad, name)) elif name.endswith('map_vocab.weight'): if class_dict is not None: param.requires_grad = not wargs.fix_pre_params param.data.copy_(class_dict['map_vocab.weight']) wlog('{:7} -> grad {}\t{}'.format('Model', param.requires_grad, name)) elif name.endswith('map_vocab.bias'): if class_dict is not None: param.requires_grad = not wargs.fix_pre_params param.data.copy_(class_dict['map_vocab.bias']) wlog('{:7} -> grad {}\t{}'.format('Model', param.requires_grad, name)) else: init_params(param, name, True) wargs.start_epoch = eid + 1 else: for n, p in nmtModel.named_parameters(): init_params(p, n, True) optim = Optim(wargs.opt_mode, wargs.learning_rate, wargs.max_grad_norm, learning_rate_decay=wargs.learning_rate_decay, start_decay_from=wargs.start_decay_from, last_valid_bleu=wargs.last_valid_bleu) if wargs.gpu_id: nmtModel.cuda() wlog('Push model onto GPU[{}] ... '.format(wargs.gpu_id[0])) else: nmtModel.cpu() wlog('Push model onto CPU ... ') wlog(nmtModel) wlog(optim) pcnt1 = len([p for p in nmtModel.parameters()]) pcnt2 = sum([p.nelement() for p in nmtModel.parameters()]) wlog('Parameters number: {}/{}'.format(pcnt1, pcnt2)) optim.init_optimizer(nmtModel.parameters()) trainer = Trainer(nmtModel, batch_train, batch_train_domain, vocab_data, optim, batch_valid, tests_data) trainer.train()
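# --------------------------------------------------------------------------
# Hedged sketch (not part of the original code): the pre-training branch above
# copies every parameter whose name appears in the loaded state dict and sets
# requires_grad according to wargs.fix_pre_params.  A compact stand-alone
# version of that idea is shown below; `ckpt_path` and `fix_params` are
# illustrative arguments, and the checkpoint is assumed to either be a plain
# state dict or to carry one under the 'model' key as in Trainer.mt_eval().
import torch as tc


def load_pretrained_params(model, ckpt_path, fix_params=False):
    """Copy matching parameters from a checkpoint into `model`, in place."""
    ckpt = tc.load(ckpt_path, map_location='cpu')
    state = ckpt['model'] if isinstance(ckpt, dict) and 'model' in ckpt else ckpt
    for name, param in model.named_parameters():
        if name in state:
            param.requires_grad = not fix_params
            param.data.copy_(state[name])
    return model
# --------------------------------------------------------------------------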
def main(): #if wargs.ss_type is not None: assert wargs.model == 1, 'Only rnnsearch support schedule sample' init_dir(wargs.dir_model) init_dir(wargs.dir_valid) src = os.path.join(wargs.dir_data, '{}.{}'.format(wargs.train_prefix, wargs.train_src_suffix)) trg = os.path.join(wargs.dir_data, '{}.{}'.format(wargs.train_prefix, wargs.train_trg_suffix)) vocabs = {} wlog('\nPreparing source vocabulary from {} ... '.format(src)) src_vocab = extract_vocab(src, wargs.src_vcb, wargs.n_src_vcb_plan, wargs.max_seq_len, char=wargs.src_char) wlog('\nPreparing target vocabulary from {} ... '.format(trg)) trg_vocab = extract_vocab(trg, wargs.trg_vcb, wargs.n_trg_vcb_plan, wargs.max_seq_len) n_src_vcb, n_trg_vcb = src_vocab.size(), trg_vocab.size() wlog('Vocabulary size: |source|={}, |target|={}'.format(n_src_vcb, n_trg_vcb)) vocabs['src'], vocabs['trg'] = src_vocab, trg_vocab wlog('\nPreparing training set from {} and {} ... '.format(src, trg)) trains = {} train_src_tlst, train_trg_tlst = wrap_data(wargs.dir_data, wargs.train_prefix, wargs.train_src_suffix, wargs.train_trg_suffix, src_vocab, trg_vocab, shuffle=True, sort_k_batches=wargs.sort_k_batches, max_seq_len=wargs.max_seq_len, char=wargs.src_char) ''' list [torch.LongTensor (sentence), torch.LongTensor, torch.LongTensor, ...] no padding ''' batch_train = Input(train_src_tlst, train_trg_tlst, wargs.batch_size, batch_type=wargs.batch_type, bow=wargs.trg_bow, batch_sort=False) wlog('Sentence-pairs count in training data: {}'.format(len(train_src_tlst))) batch_valid = None if wargs.val_prefix is not None: val_src_file = os.path.join(wargs.val_tst_dir, '{}.{}'.format(wargs.val_prefix, wargs.val_src_suffix)) val_trg_file = os.path.join(wargs.val_tst_dir, '{}.{}'.format(wargs.val_prefix, wargs.val_ref_suffix)) wlog('\nPreparing validation set from {} and {} ... '.format(val_src_file, val_trg_file)) valid_src_tlst, valid_trg_tlst = wrap_data(wargs.val_tst_dir, wargs.val_prefix, wargs.val_src_suffix, wargs.val_ref_suffix, src_vocab, trg_vocab, shuffle=False, max_seq_len=wargs.dev_max_seq_len, char=wargs.src_char) batch_valid = Input(valid_src_tlst, valid_trg_tlst, 1, batch_sort=False) batch_tests = None if wargs.tests_prefix is not None: assert isinstance(wargs.tests_prefix, list), 'Test files should be list.' init_dir(wargs.dir_tests) batch_tests = {} for prefix in wargs.tests_prefix: init_dir(wargs.dir_tests + '/' + prefix) test_file = '{}{}.{}'.format(wargs.val_tst_dir, prefix, wargs.val_src_suffix) wlog('\nPreparing test set from {} ... '.format(test_file)) test_src_tlst, _ = wrap_tst_data(test_file, src_vocab, char=wargs.src_char) batch_tests[prefix] = Input(test_src_tlst, None, 1, batch_sort=False) wlog('\n## Finish to Prepare Dataset ! ##\n') src_emb = WordEmbedding(n_src_vcb, wargs.d_src_emb, wargs.input_dropout, wargs.position_encoding, prefix='Src') trg_emb = WordEmbedding(n_trg_vcb, wargs.d_trg_emb, wargs.input_dropout, wargs.position_encoding, prefix='Trg') # share the embedding matrix - preprocess with share_vocab required. 
if wargs.embs_share_weight: if n_src_vcb != n_trg_vcb: raise AssertionError('The `-share_vocab` should be set during ' 'preprocess if you use share_embeddings!') src_emb.we.weight = trg_emb.we.weight nmtModel = build_NMT(src_emb, trg_emb) if not wargs.copy_attn: classifier = Classifier(wargs.d_model if wargs.decoder_type == 'att' else 2 * wargs.d_enc_hid, n_trg_vcb, trg_emb, loss_norm=wargs.loss_norm, label_smoothing=wargs.label_smoothing, emb_loss=wargs.emb_loss, bow_loss=wargs.bow_loss) nmtModel.decoder.classifier = classifier if wargs.gpu_id is not None: wlog('push model onto GPU {} ... '.format(wargs.gpu_id), 0) #nmtModel = nn.DataParallel(nmtModel, device_ids=wargs.gpu_id) nmtModel.to(tc.device('cuda')) else: wlog('push model onto CPU ... ', 0) nmtModel.to(tc.device('cpu')) wlog('done.') if wargs.pre_train is not None: assert os.path.exists(wargs.pre_train) from tools.utils import load_model _dict = load_model(wargs.pre_train) # initializing parameters of interactive attention model class_dict = None if len(_dict) == 5: model_dict, class_dict, eid, bid, optim = _dict elif len(_dict) == 4: model_dict, eid, bid, optim = _dict for name, param in nmtModel.named_parameters(): if name in model_dict: param.requires_grad = not wargs.fix_pre_params param.data.copy_(model_dict[name]) wlog('{:7} -> grad {}\t{}'.format('Model', param.requires_grad, name)) elif name.endswith('map_vocab.weight'): if class_dict is not None: param.requires_grad = not wargs.fix_pre_params param.data.copy_(class_dict['map_vocab.weight']) wlog('{:7} -> grad {}\t{}'.format('Model', param.requires_grad, name)) elif name.endswith('map_vocab.bias'): if class_dict is not None: param.requires_grad = not wargs.fix_pre_params param.data.copy_(class_dict['map_vocab.bias']) wlog('{:7} -> grad {}\t{}'.format('Model', param.requires_grad, name)) else: init_params(param, name, init_D=wargs.param_init_D, a=float(wargs.u_gain)) wargs.start_epoch = eid + 1 else: optim = Optim(wargs.opt_mode, wargs.learning_rate, wargs.max_grad_norm) #for n, p in nmtModel.named_parameters(): # bias can not be initialized uniformly #if wargs.encoder_type != 'att' and wargs.decoder_type != 'att': # init_params(p, n, init_D=wargs.param_init_D, a=float(wargs.u_gain)) wlog(nmtModel) wlog(optim) pcnt1 = len([p for p in nmtModel.parameters()]) pcnt2 = sum([p.nelement() for p in nmtModel.parameters()]) wlog('parameters number: {}/{}'.format(pcnt1, pcnt2)) wlog('\n' + '*' * 30 + ' trainable parameters ' + '*' * 30) for n, p in nmtModel.named_parameters(): if p.requires_grad: wlog('{:60} : {}'.format(n, p.size())) optim.init_optimizer(nmtModel.parameters()) trainer = Trainer(nmtModel, batch_train, vocabs, optim, batch_valid, batch_tests) trainer.train()
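# --------------------------------------------------------------------------
# Hedged sketch (not part of the original code): when wargs.embs_share_weight
# is set above, the source embedding reuses the target embedding tensor
# (src_emb.we.weight = trg_emb.we.weight), which only makes sense with a
# shared vocabulary.  The minimal pattern, with illustrative names, is:
import torch.nn as nn


class TiedEmbeddings(nn.Module):
    """Source/target embeddings backed by a single shared weight matrix."""

    def __init__(self, n_vocab, d_emb):
        super(TiedEmbeddings, self).__init__()
        self.src = nn.Embedding(n_vocab, d_emb)
        self.trg = nn.Embedding(n_vocab, d_emb)
        self.trg.weight = self.src.weight  # both sides now update the same tensor
# --------------------------------------------------------------------------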
def main(): # if wargs.ss_type is not None: assert wargs.model == 1, 'Only rnnsearch support schedule sample' init_dir(wargs.dir_model) init_dir(wargs.dir_valid) src = os.path.join( wargs.dir_data, '{}.{}'.format(wargs.train_prefix, wargs.train_src_suffix)) trg = os.path.join( wargs.dir_data, '{}.{}'.format(wargs.train_prefix, wargs.train_trg_suffix)) src, trg = os.path.abspath(src), os.path.abspath(trg) vocabs = {} if wargs.share_vocab is False: wlog('\nPreparing source vocabulary from {} ... '.format(src)) src_vocab = extract_vocab(src, wargs.src_vcb, wargs.n_src_vcb_plan, wargs.max_seq_len, char=wargs.src_char) wlog('\nPreparing target vocabulary from {} ... '.format(trg)) trg_vocab = extract_vocab(trg, wargs.trg_vcb, wargs.n_trg_vcb_plan, wargs.max_seq_len) n_src_vcb, n_trg_vcb = src_vocab.size(), trg_vocab.size() wlog('Vocabulary size: |source|={}, |target|={}'.format( n_src_vcb, n_trg_vcb)) else: wlog('\nPreparing the shared vocabulary from \n\t{}\n\t{}'.format( src, trg)) trg_vocab = src_vocab = extract_vocab(src, wargs.src_vcb, wargs.n_src_vcb_plan, wargs.max_seq_len, share_vocab=True, trg_file=trg) n_src_vcb, n_trg_vcb = src_vocab.size(), trg_vocab.size() wlog('Shared vocabulary size: |vocab|={}'.format(src_vocab.size())) vocabs['src'], vocabs['trg'] = src_vocab, trg_vocab wlog('\nPreparing training set from {} and {} ... '.format(src, trg)) trains = {} train_src_tlst, train_trg_tlst = wrap_data( wargs.dir_data, wargs.train_prefix, wargs.train_src_suffix, wargs.train_trg_suffix, src_vocab, trg_vocab, shuffle=True, sort_k_batches=wargs.sort_k_batches, max_seq_len=wargs.max_seq_len, char=wargs.src_char) ''' list [torch.LongTensor (sentence), torch.LongTensor, torch.LongTensor, ...] no padding ''' batch_train = Input(train_src_tlst, train_trg_tlst, wargs.batch_size, batch_type=wargs.batch_type, bow=wargs.trg_bow, batch_sort=False, gpu_ids=device_ids) wlog('Sentence-pairs count in training data: {}'.format( len(train_src_tlst))) batch_valid = None if wargs.val_prefix is not None: val_src_file = os.path.join( wargs.val_tst_dir, '{}.{}'.format(wargs.val_prefix, wargs.val_src_suffix)) val_trg_file = os.path.join( wargs.val_tst_dir, '{}.{}'.format(wargs.val_prefix, wargs.val_ref_suffix)) val_src_file, val_trg_file = os.path.abspath( val_src_file), os.path.abspath(val_trg_file) wlog('\nPreparing validation set from {} and {} ... '.format( val_src_file, val_trg_file)) valid_src_tlst, valid_trg_tlst = wrap_data( wargs.val_tst_dir, wargs.val_prefix, wargs.val_src_suffix, wargs.val_ref_suffix, src_vocab, trg_vocab, shuffle=False, max_seq_len=wargs.dev_max_seq_len, char=wargs.src_char) batch_valid = Input(valid_src_tlst, valid_trg_tlst, batch_size=wargs.valid_batch_size, batch_sort=False, gpu_ids=device_ids) batch_tests = None if wargs.tests_prefix is not None: assert isinstance(wargs.tests_prefix, list), 'Test files should be list.' init_dir(wargs.dir_tests) batch_tests = {} for prefix in wargs.tests_prefix: init_dir(wargs.dir_tests + '/' + prefix) test_file = '{}{}.{}'.format(wargs.val_tst_dir, prefix, wargs.val_src_suffix) test_file = os.path.abspath(test_file) wlog('\nPreparing test set from {} ... '.format(test_file)) test_src_tlst, _ = wrap_tst_data(test_file, src_vocab, char=wargs.src_char) batch_tests[prefix] = Input(test_src_tlst, None, batch_size=wargs.test_batch_size, batch_sort=False, gpu_ids=device_ids) wlog('\n## Finish to Prepare Dataset ! 
##\n') src_emb = WordEmbedding(n_src_vcb, wargs.d_src_emb, wargs.input_dropout, wargs.position_encoding, prefix='Src') trg_emb = WordEmbedding(n_trg_vcb, wargs.d_trg_emb, wargs.input_dropout, wargs.position_encoding, prefix='Trg') # share the embedding matrix between the source and target if wargs.share_vocab is True: src_emb.we.weight = trg_emb.we.weight nmtModel = build_NMT(src_emb, trg_emb) if device_ids is not None: wlog('push model onto GPU {} ... '.format(device_ids[0]), 0) nmtModel_par = nn.DataParallel(nmtModel, device_ids=device_ids) nmtModel_par.to(device) else: wlog('push model onto CPU ... ', 0) nmtModel.to(tc.device('cpu')) wlog('done.') if wargs.pre_train is not None: wlog(wargs.pre_train) assert os.path.exists(wargs.pre_train) from tools.utils import load_model _dict = load_model(wargs.pre_train) # initializing parameters of interactive attention model class_dict = None if len(_dict) == 5: # model_dict, e_idx, e_bidx, n_steps, optim = _dict['model'], _dict['epoch'], _dict['batch'], _dict['steps'], _dict['optim'] model_dict, e_idx, e_bidx, n_steps, optim = _dict elif len(_dict) == 4: model_dict, e_idx, e_bidx, optim = _dict for name, param in nmtModel.named_parameters(): if name in model_dict: param.requires_grad = not wargs.fix_pre_params param.data.copy_(model_dict[name]) # wlog('{:7} -> grad {}\t{}'.format('Model', param.requires_grad, name)) elif name.endswith('map_vocab.weight'): if class_dict is not None: param.requires_grad = not wargs.fix_pre_params param.data.copy_(class_dict['map_vocab.weight']) # wlog('{:7} -> grad {}\t{}'.format('Model', param.requires_grad, name)) elif name.endswith('map_vocab.bias'): if class_dict is not None: param.requires_grad = not wargs.fix_pre_params param.data.copy_(class_dict['map_vocab.bias']) # wlog('{:7} -> grad {}\t{}'.format('Model', param.requires_grad, name)) else: init_params(param, name, init_D=wargs.param_init_D, a=float(wargs.u_gain)) # wargs.start_epoch = e_idx + 1 # # 不重新开始 # optim.n_current_steps = 0 else: optim = Optim(wargs.opt_mode, wargs.learning_rate, wargs.max_grad_norm) for n, p in nmtModel.named_parameters(): # bias can not be initialized uniformly if 'norm' in n: wlog('ignore layer norm init ...') continue if 'emb' in n: wlog('ignore word embedding weight init ...') continue if 'vcb_proj' in n: wlog('ignore vcb_proj weight init ...') continue init_params(p, n, init_D=wargs.param_init_D, a=float(wargs.u_gain)) # if wargs.encoder_type != 'att' and wargs.decoder_type != 'att': # init_params(p, n, init_D=wargs.param_init_D, a=float(wargs.u_gain)) # wlog(nmtModel) wlog(optim) pcnt1 = len([p for p in nmtModel.parameters()]) pcnt2 = sum([p.nelement() for p in nmtModel.parameters()]) wlog('parameters number: {}/{}'.format(pcnt1, pcnt2)) # wlog('\n' + '*' * 30 + ' trainable parameters ' + '*' * 30) # for n, p in nmtModel.named_parameters(): # if p.requires_grad: wlog('{:60} : {}'.format(n, p.size())) opt_state = None if wargs.pre_train: opt_state = optim.optimizer.state_dict() if wargs.use_reinfore_ce is False: criterion = LabelSmoothingCriterion( trg_emb.n_vocab, label_smoothing=wargs.label_smoothing) else: word2vec = tc.load(wargs.word2vec_weight)['w2v'] # criterion = Word2VecDistanceCriterion(word2vec) criterion = CosineDistance(word2vec) if device_ids is not None: wlog('push criterion onto GPU {} ... 
'.format(device_ids[0]), 0) criterion = criterion.to(device) wlog('done.') # if wargs.reinfore_type == 0 or wargs.reinfore_type == 1: # param = list(nmtModel.parameters()) # else: # param = list(nmtModel.parameters()) + list(criterion.parameters()) param = list(nmtModel.parameters()) optim.init_optimizer(param) lossCompute = MultiGPULossCompute( nmtModel.generator, criterion, wargs.d_model if wargs.decoder_type == 'att' else 2 * wargs.d_enc_hid, n_trg_vcb, trg_emb, nmtModel.bowMapper, loss_norm=wargs.loss_norm, chunk_size=wargs.chunk_size, device_ids=device_ids) trainer = Trainer(nmtModel_par, batch_train, vocabs, optim, lossCompute, nmtModel, batch_valid, batch_tests, writer) trainer.train() writer.close()
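# --------------------------------------------------------------------------
# Hedged sketch (not part of the original code): when wargs.use_reinfore_ce is
# False, the main() above builds a LabelSmoothingCriterion.  The repository's
# own implementation is not reproduced here; a standard label-smoothing loss
# over log-probabilities looks like the following (padding_idx and smoothing
# values are illustrative defaults).
import torch
import torch.nn as nn


class LabelSmoothingLoss(nn.Module):
    """KL-divergence loss against a smoothed one-hot target distribution."""

    def __init__(self, n_vocab, padding_idx=0, smoothing=0.1):
        super(LabelSmoothingLoss, self).__init__()
        self.criterion = nn.KLDivLoss(reduction='sum')
        self.n_vocab = n_vocab
        self.padding_idx = padding_idx
        self.smoothing = smoothing

    def forward(self, log_probs, target):
        # log_probs: (N, V) log-probabilities; target: (N,) gold indices.
        true_dist = log_probs.new_full(log_probs.size(),
                                       self.smoothing / (self.n_vocab - 2))
        true_dist.scatter_(1, target.unsqueeze(1), 1.0 - self.smoothing)
        true_dist[:, self.padding_idx] = 0.0
        true_dist.masked_fill_((target == self.padding_idx).unsqueeze(1), 0.0)
        return self.criterion(log_probs, true_dist)
# --------------------------------------------------------------------------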
def main():
    #if wargs.ss_type is not None: assert wargs.model == 1, 'Only rnnsearch support schedule sample'
    init_dir(wargs.dir_model)
    init_dir(wargs.dir_valid)

    src = os.path.join(wargs.dir_data, '{}.{}'.format(wargs.train_prefix, wargs.train_src_suffix))
    trg = os.path.join(wargs.dir_data, '{}.{}'.format(wargs.train_prefix, wargs.train_trg_suffix))

    vocabs = {}
    wlog('\n[o/Subword] Preparing source vocabulary from {} ... '.format(src))
    src_vocab = extract_vocab(src, wargs.src_dict, wargs.src_dict_size,
                              wargs.max_seq_len, char=wargs.src_char)
    wlog('\n[o/Subword] Preparing target vocabulary from {} ... '.format(trg))
    trg_vocab = extract_vocab(trg, wargs.trg_dict, wargs.trg_dict_size, wargs.max_seq_len)
    src_vocab_size, trg_vocab_size = src_vocab.size(), trg_vocab.size()
    wlog('Vocabulary size: |source|={}, |target|={}'.format(src_vocab_size, trg_vocab_size))
    vocabs['src'], vocabs['trg'] = src_vocab, trg_vocab

    wlog('\nPreparing training set from {} and {} ... '.format(src, trg))
    trains = {}
    train_src_tlst, train_trg_tlst = wrap_data(wargs.dir_data, wargs.train_prefix,
                                               wargs.train_src_suffix, wargs.train_trg_suffix,
                                               src_vocab, trg_vocab,
                                               max_seq_len=wargs.max_seq_len, char=wargs.src_char)
    '''
    list [torch.LongTensor (sentence), torch.LongTensor, torch.LongTensor, ...]
    no padding
    '''
    batch_train = Input(train_src_tlst, train_trg_tlst, wargs.batch_size, batch_sort=True)
    wlog('Sentence-pairs count in training data: {}'.format(len(train_src_tlst)))

    batch_valid = None
    if wargs.val_prefix is not None:
        val_src_file = '{}{}.{}'.format(wargs.val_tst_dir, wargs.val_prefix, wargs.val_src_suffix)
        val_trg_file = '{}{}.{}'.format(wargs.val_tst_dir, wargs.val_prefix, wargs.val_ref_suffix)
        wlog('\nPreparing validation set from {} and {} ... '.format(val_src_file, val_trg_file))
        valid_src_tlst, valid_trg_tlst = wrap_data(wargs.val_tst_dir, wargs.val_prefix,
                                                   wargs.val_src_suffix, wargs.val_ref_suffix,
                                                   src_vocab, trg_vocab, shuffle=False,
                                                   sort_data=False,
                                                   max_seq_len=wargs.dev_max_seq_len,
                                                   char=wargs.src_char)
        batch_valid = Input(valid_src_tlst, valid_trg_tlst, 1, volatile=True, batch_sort=False)

    batch_tests = None
    if wargs.tests_prefix is not None:
        assert isinstance(wargs.tests_prefix, list), 'Test files should be list.'
        init_dir(wargs.dir_tests)
        batch_tests = {}
        for prefix in wargs.tests_prefix:
            init_dir(wargs.dir_tests + '/' + prefix)
            test_file = '{}{}.{}'.format(wargs.val_tst_dir, prefix, wargs.val_src_suffix)
            wlog('\nPreparing test set from {} ... '.format(test_file))
            test_src_tlst, _ = wrap_tst_data(test_file, src_vocab, char=wargs.src_char)
            batch_tests[prefix] = Input(test_src_tlst, None, 1, volatile=True, batch_sort=False)
    wlog('\n## Finish to Prepare Dataset ! ##\n')

    nmtModel = NMT(src_vocab_size, trg_vocab_size)

    if wargs.pre_train is not None:
        assert os.path.exists(wargs.pre_train)
        _dict = _load_model(wargs.pre_train)
        # initializing parameters of interactive attention model
        class_dict = None
        if len(_dict) == 4: model_dict, eid, bid, optim = _dict
        elif len(_dict) == 5: model_dict, class_dict, eid, bid, optim = _dict
        for name, param in nmtModel.named_parameters():
            if name in model_dict:
                param.requires_grad = not wargs.fix_pre_params
                param.data.copy_(model_dict[name])
                wlog('{:7} -> grad {}\t{}'.format('Model', param.requires_grad, name))
            elif name.endswith('map_vocab.weight'):
                if class_dict is not None:
                    param.requires_grad = not wargs.fix_pre_params
                    param.data.copy_(class_dict['map_vocab.weight'])
                    wlog('{:7} -> grad {}\t{}'.format('Model', param.requires_grad, name))
            elif name.endswith('map_vocab.bias'):
                if class_dict is not None:
                    param.requires_grad = not wargs.fix_pre_params
                    param.data.copy_(class_dict['map_vocab.bias'])
                    wlog('{:7} -> grad {}\t{}'.format('Model', param.requires_grad, name))
            else:
                init_params(param, name, True)
        wargs.start_epoch = eid + 1
    else:
        for n, p in nmtModel.named_parameters(): init_params(p, n, True)
        optim = Optim(wargs.opt_mode, wargs.learning_rate, wargs.max_grad_norm,
                      learning_rate_decay=wargs.learning_rate_decay,
                      start_decay_from=wargs.start_decay_from,
                      last_valid_bleu=wargs.last_valid_bleu, model=wargs.model)

    if wargs.gpu_id is not None:
        wlog('Push model onto GPU {} ... '.format(wargs.gpu_id), 0)
        nmtModel.cuda()
    else:
        wlog('Push model onto CPU ... ', 0)
        nmtModel.cpu()
    wlog('done.')

    wlog(nmtModel)
    wlog(optim)
    pcnt1 = len([p for p in nmtModel.parameters()])
    pcnt2 = sum([p.nelement() for p in nmtModel.parameters()])
    wlog('Parameters number: {}/{}'.format(pcnt1, pcnt2))
    optim.init_optimizer(nmtModel.parameters())

    trainer = Trainer(nmtModel, batch_train, vocabs, optim, batch_valid, batch_tests)
    trainer.train()
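# --------------------------------------------------------------------------
# Hedged sketch (not part of the original code): the Optim wrapper above is
# configured with learning_rate_decay / start_decay_from / last_valid_bleu,
# i.e. a decay-on-BLEU-plateau policy.  Its actual implementation lives in the
# repository's Optim class; a bare-bones version of that policy, written
# against a raw torch optimizer, could look like this (function name and
# defaults are hypothetical).
def decay_lr_on_plateau(optimizer, epoch, last_valid_bleu, new_valid_bleu,
                        decay=0.5, start_decay_from=5, min_lr=5e-5):
    """Multiply the learning rate by `decay` when validation BLEU did not improve."""
    if epoch >= start_decay_from and new_valid_bleu <= last_valid_bleu:
        for group in optimizer.param_groups:
            group['lr'] = max(group['lr'] * decay, min_lr)
    return max(last_valid_bleu, new_valid_bleu)
# --------------------------------------------------------------------------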
def main(): # Check if CUDA is available if cuda.is_available(): wlog('CUDA is available, specify device by gpu_id argument (i.e. gpu_id=[3])') else: wlog('Warning: CUDA is not available, try CPU') if wargs.gpu_id: cuda.set_device(wargs.gpu_id[0]) wlog('Using GPU {}'.format(wargs.gpu_id[0])) init_dir(wargs.dir_model) init_dir(wargs.dir_valid) init_dir(wargs.dir_tests) for prefix in wargs.tests_prefix: if not prefix == wargs.val_prefix: init_dir(wargs.dir_tests + '/' + prefix) wlog('Preparing data ... ', 0) train_srcD_file = wargs.dir_data + 'train.10k.zh5' wlog('\nPreparing source vocabulary from {} ... '.format(train_srcD_file)) src_vocab = extract_vocab(train_srcD_file, wargs.src_dict, wargs.src_dict_size) train_trgD_file = wargs.dir_data + 'train.10k.en5' wlog('\nPreparing target vocabulary from {} ... '.format(train_trgD_file)) trg_vocab = extract_vocab(train_trgD_file, wargs.trg_dict, wargs.trg_dict_size) train_src_file = wargs.dir_data + 'train.10k.zh0' train_trg_file = wargs.dir_data + 'train.10k.en0' wlog('\nPreparing training set from {} and {} ... '.format(train_src_file, train_trg_file)) train_src_tlst, train_trg_tlst = wrap_data(train_src_file, train_trg_file, src_vocab, trg_vocab) #list [torch.LongTensor (sentence), torch.LongTensor, torch.LongTensor, ...], no padding wlog('Sentence-pairs count in training data: {}'.format(len(train_src_tlst))) src_vocab_size, trg_vocab_size = src_vocab.size(), trg_vocab.size() wlog('Vocabulary size: |source|={}, |target|={}'.format(src_vocab_size, trg_vocab_size)) batch_train = Input(train_src_tlst, train_trg_tlst, wargs.batch_size) tests_data = None if wargs.tests_prefix is not None: tests_data = {} for prefix in wargs.tests_prefix: test_file = wargs.val_tst_dir + prefix + '.src' test_src_tlst, _ = val_wrap_data(test_file, src_vocab) # we select best model by nist03 testing data if prefix == wargs.val_prefix: wlog('\nPreparing model-select set from {} ... '.format(test_file)) batch_valid = Input(test_src_tlst, None, 1, volatile=True, prefix=prefix) else: wlog('\nPreparing test set from {} ... '.format(test_file)) tests_data[prefix] = Input(test_src_tlst, None, 1, volatile=True) nmtModel = NMT() classifier = Classifier(wargs.out_size, trg_vocab_size) if wargs.pre_train: model_dict, class_dict, eid, bid, optim = load_pytorch_model(wargs.pre_train) if isinstance(optim, list): _, _, optim = optim # initializing parameters of interactive attention model for p in nmtModel.named_parameters(): p[1].data = model_dict[p[0]] for p in classifier.named_parameters(): p[1].data = class_dict[p[0]] #wargs.start_epoch = eid + 1 else: for p in nmtModel.parameters(): init_params(p, uniform=True) for p in classifier.parameters(): init_params(p, uniform=True) optim = Optim( wargs.opt_mode, wargs.learning_rate, wargs.max_grad_norm, learning_rate_decay=wargs.learning_rate_decay, start_decay_from=wargs.start_decay_from, last_valid_bleu=wargs.last_valid_bleu ) if wargs.gpu_id: wlog('Push model onto GPU ... ') nmtModel.cuda() classifier.cuda() else: wlog('Push model onto CPU ... 
') nmtModel.cpu() classifier.cuda() nmtModel.classifier = classifier wlog(nmtModel) pcnt1 = len([p for p in nmtModel.parameters()]) pcnt2 = sum([p.nelement() for p in nmtModel.parameters()]) wlog('Parameters number: {}/{}'.format(pcnt1, pcnt2)) optim.init_optimizer(nmtModel.parameters()) #tor = Translator(nmtModel, src_vocab.idx2key, trg_vocab.idx2key) #tor.trans_tests(tests_data, pre_dict['epoch'], pre_dict['batch']) trainer = Trainer(nmtModel, src_vocab.idx2key, trg_vocab.idx2key, optim, trg_vocab_size) dev_src0 = wargs.dir_data + 'dev.1k.zh0' dev_trg0 = wargs.dir_data + 'dev.1k.en0' wlog('\nPreparing dev set for tuning from {} and {} ... '.format(dev_src0, dev_trg0)) dev_src0, dev_trg0 = wrap_data(dev_src0, dev_trg0, src_vocab, trg_vocab) wlog(len(train_src_tlst)) # add 1000 to train train_all_chunks = (train_src_tlst, train_trg_tlst) dh = DataHisto(train_all_chunks) dev_src1 = wargs.dir_data + 'dev.1k.zh1' dev_trg1 = wargs.dir_data + 'dev.1k.en1' wlog('\nPreparing dev set for tuning from {} and {} ... '.format(dev_src1, dev_trg1)) dev_src1, dev_trg1 = wrap_data(dev_src1, dev_trg1, src_vocab, trg_vocab) dev_src2 = wargs.dir_data + 'dev.1k.zh2' dev_trg2 = wargs.dir_data + 'dev.1k.en2' wlog('\nPreparing dev set for tuning from {} and {} ... '.format(dev_src2, dev_trg2)) dev_src2, dev_trg2 = wrap_data(dev_src2, dev_trg2, src_vocab, trg_vocab) dev_src3 = wargs.dir_data + 'dev.1k.zh3' dev_trg3 = wargs.dir_data + 'dev.1k.en3' wlog('\nPreparing dev set for tuning from {} and {} ... '.format(dev_src3, dev_trg3)) dev_src3, dev_trg3 = wrap_data(dev_src3, dev_trg3, src_vocab, trg_vocab) dev_src4 = wargs.dir_data + 'dev.1k.zh4' dev_trg4 = wargs.dir_data + 'dev.1k.en4' wlog('\nPreparing dev set for tuning from {} and {} ... '.format(dev_src4, dev_trg4)) dev_src4, dev_trg4 = wrap_data(dev_src4, dev_trg4, src_vocab, trg_vocab) wlog(len(dev_src4+dev_src3+dev_src2+dev_src1+dev_src0)) dev_input = Input(dev_src4+dev_src3+dev_src2+dev_src1+dev_src0, dev_trg4+dev_trg3+dev_trg2+dev_trg1+dev_trg0, wargs.batch_size) trainer.train(dh, dev_input, 0, batch_valid, tests_data, merge=True, name='DH_{}'.format('dev')) ''' chunk_size = 1000 rand_ids = tc.randperm(len(train_src_tlst))[:chunk_size * 1000] rand_ids = rand_ids.split(chunk_size) #train_chunks = [(dev_src, dev_trg)] train_chunks = [] for k in range(len(rand_ids)): rand_id = rand_ids[k] chunk_src_tlst = [train_src_tlst[i] for i in rand_id] chunk_trg_tlst = [train_trg_tlst[i] for i in rand_id] #wlog('Sentence-pairs count in training data: {}'.format(len(src_samples_train))) #batch_train = Input(train_src_tlst, train_trg_tlst, wargs.batch_size) #batch_train = Input(src_samples_train, trg_samples_train, wargs.batch_size) train_chunks.append((chunk_src_tlst, chunk_trg_tlst)) chunk_D0 = train_chunks[0] dh = DataHisto(chunk_D0) c0_input = Input(chunk_D0[0], chunk_D0[1], wargs.batch_size) trainer.train(dh, c0_input, 0, batch_valid, tests_data, merge=False, name='DH_{}'.format(0)) for k in range(1, len(train_chunks)): wlog('*' * 30, False) wlog(' Next Data {} '.format(k), False) wlog('*' * 30) chunk_Dk = train_chunks[k] ck_input = Input(chunk_Dk[0], chunk_Dk[1], wargs.batch_size) trainer.train(dh, ck_input, k, batch_valid, tests_data, merge=True, name='DH_{}'.format(k)) dh.add_batch_data(chunk_Dk) ''' if tests_data and wargs.final_test: bestModel = NMT() classifier = Classifier(wargs.out_size, trg_vocab_size) assert os.path.exists(wargs.best_model) model_dict = tc.load(wargs.best_model) best_model_dict = model_dict['model'] best_model_dict = {k: v for k, v in 
best_model_dict.items() if 'classifier' not in k} bestModel.load_state_dict(best_model_dict) classifier.load_state_dict(model_dict['class']) if wargs.gpu_id: wlog('Push NMT model onto GPU ... ') bestModel.cuda() classifier.cuda() else: wlog('Push NMT model onto CPU ... ') bestModel.cpu() classifier.cpu() bestModel.classifier = classifier tor = Translator(bestModel, src_vocab.idx2key, trg_vocab.idx2key) tor.trans_tests(tests_data, model_dict['epoch'], model_dict['batch'])
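# --------------------------------------------------------------------------
# Hedged sketch (not part of the original code): the final-test branch above
# reloads wargs.best_model, strips the classifier entries out of the saved
# state dict, and restores the classifier from the separate 'class' entry.
# The same pattern in isolation, with the checkpoint layout assumed from the
# code above:
import torch as tc


def load_best_model(model, classifier, path):
    """Restore the NMT body and its classifier from one saved checkpoint."""
    ckpt = tc.load(path, map_location='cpu')
    body = {k: v for k, v in ckpt['model'].items() if 'classifier' not in k}
    model.load_state_dict(body)
    classifier.load_state_dict(ckpt['class'])
    model.classifier = classifier
    return ckpt['epoch'], ckpt['batch']
# --------------------------------------------------------------------------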
def main(): init_dir(wargs.dir_model) init_dir(wargs.dir_valid) vocab_data = {} train_srcD_file = wargs.src_vocab_from wlog('\nPreparing source vocabulary from {} ... '.format(train_srcD_file)) src_vocab = extract_vocab(train_srcD_file, wargs.src_dict, wargs.src_dict_size) vocab_data['src'] = src_vocab train_trgD_file = wargs.trg_vocab_from wlog('\nPreparing target vocabulary from {} ... '.format(train_trgD_file)) trg_vocab = extract_vocab(train_trgD_file, wargs.trg_dict, wargs.trg_dict_size) vocab_data['trg'] = trg_vocab train_src_file = wargs.train_src train_trg_file = wargs.train_trg wlog('\nPreparing training set from {} and {} ... '.format( train_src_file, train_trg_file)) train_src_tlst, train_trg_tlst = wrap_data(train_src_file, train_trg_file, src_vocab, trg_vocab, max_seq_len=wargs.max_seq_len) ''' list [torch.LongTensor (sentence), torch.LongTensor, torch.LongTensor, ...] no padding ''' ''' devs = {} dev_src = wargs.val_tst_dir + wargs.val_prefix + '.src' dev_trg = wargs.val_tst_dir + wargs.val_prefix + '.ref0' wlog('\nPreparing dev set for tuning from {} and {} ... '.format(dev_src, dev_trg)) dev_src, dev_trg = wrap_data(dev_src, dev_trg, src_vocab, trg_vocab) devs['src'], devs['trg'] = dev_src, dev_trg ''' valid_file = '{}{}.{}'.format(wargs.val_tst_dir, wargs.val_prefix, wargs.val_src_suffix) wlog('\nPreparing validation set from {} ... '.format(valid_file)) valid_src_tlst, valid_src_lens = val_wrap_data(valid_file, src_vocab) wlog('Sentence-pairs count in training data: {}'.format( len(train_src_tlst))) src_vocab_size, trg_vocab_size = vocab_data['src'].size( ), vocab_data['trg'].size() wlog('Vocabulary size: |source|={}, |target|={}'.format( src_vocab_size, trg_vocab_size)) batch_train = Input(train_src_tlst, train_trg_tlst, wargs.batch_size) batch_valid = Input(valid_src_tlst, None, 1, volatile=True) tests_data = None if wargs.tests_prefix is not None: init_dir(wargs.dir_tests) tests_data = {} for prefix in wargs.tests_prefix: init_dir(wargs.dir_tests + '/' + prefix) test_file = '{}{}.{}'.format(wargs.val_tst_dir, prefix, wargs.val_src_suffix) wlog('Preparing test set from {} ... '.format(test_file)) test_src_tlst, _ = val_wrap_data(test_file, src_vocab) tests_data[prefix] = Input(test_src_tlst, None, 1, volatile=True) ''' # lookup_table on cpu to save memory src_lookup_table = nn.Embedding(wargs.src_dict_size + 4, wargs.src_wemb_size, padding_idx=utils.PAD).cpu() trg_lookup_table = nn.Embedding(wargs.trg_dict_size + 4, wargs.trg_wemb_size, padding_idx=utils.PAD).cpu() wlog('Lookup table on CPU ... 
') wlog(src_lookup_table) wlog(trg_lookup_table) ''' sv = vocab_data['src'].idx2key tv = vocab_data['trg'].idx2key nmtModel = NMT(src_vocab_size, trg_vocab_size) #classifier = Classifier(wargs.out_size, trg_vocab_size, # nmtModel.decoder.trg_lookup_table if wargs.copy_trg_emb is True else None) if wargs.pre_train: assert os.path.exists(wargs.pre_train) _dict = _load_model(wargs.pre_train) # initializing parameters of interactive attention model class_dict = None if len(_dict) == 4: model_dict, eid, bid, optim = _dict elif len(_dict) == 5: model_dict, class_dict, eid, bid, optim = _dict for name, param in nmtModel.named_parameters(): if name in model_dict: param.requires_grad = not wargs.fix_pre_params param.data.copy_(model_dict[name]) wlog('{:7} -> grad {}\t{}'.format('Model', param.requires_grad, name)) elif name.endswith('map_vocab.weight'): if class_dict is not None: param.requires_grad = not wargs.fix_pre_params param.data.copy_(class_dict['map_vocab.weight']) wlog('{:7} -> grad {}\t{}'.format('Model', param.requires_grad, name)) elif name.endswith('map_vocab.bias'): if class_dict is not None: param.requires_grad = not wargs.fix_pre_params param.data.copy_(class_dict['map_vocab.bias']) wlog('{:7} -> grad {}\t{}'.format('Model', param.requires_grad, name)) else: init_params(param, name, True) wargs.start_epoch = eid + 1 #tor = Translator(nmtModel, sv, tv) #tor.trans_tests(tests_data, eid, bid) else: for n, p in nmtModel.named_parameters(): init_params(p, n, True) #for n, p in classifier.named_parameters(): init_params(p, n, True) optim = Optim(wargs.opt_mode, wargs.learning_rate, wargs.max_grad_norm, learning_rate_decay=wargs.learning_rate_decay, start_decay_from=wargs.start_decay_from, last_valid_bleu=wargs.last_valid_bleu) if wargs.gpu_id: nmtModel.cuda() #classifier.cuda() wlog('Push model onto GPU[{}] ... '.format(wargs.gpu_id[0])) else: nmtModel.cpu() #classifier.cpu() wlog('Push model onto CPU ... ') #nmtModel.classifier = classifier #nmtModel.decoder.map_vocab = classifier.map_vocab ''' nmtModel.src_lookup_table = src_lookup_table nmtModel.trg_lookup_table = trg_lookup_table print nmtModel.src_lookup_table.weight.data.is_cuda nmtModel.classifier.init_weights(nmtModel.trg_lookup_table) ''' wlog(nmtModel) wlog(optim) pcnt1 = len([p for p in nmtModel.parameters()]) pcnt2 = sum([p.nelement() for p in nmtModel.parameters()]) wlog('Parameters number: {}/{}'.format(pcnt1, pcnt2)) optim.init_optimizer(nmtModel.parameters()) #tor = Translator(nmtModel, sv, tv, wargs.search_mode) #tor.trans_tests(tests_data, pre_dict['epoch'], pre_dict['batch']) trainer = Trainer(nmtModel, batch_train, vocab_data, optim, batch_valid, tests_data) trainer.train()
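# --------------------------------------------------------------------------
# Hedged sketch (not part of the original code): the two counters above
# (pcnt1 / pcnt2) report the number of parameter tensors and the total number
# of scalar weights.  A small reusable helper that does the same thing:
def count_parameters(model):
    """Return (number of parameter tensors, total number of scalar weights)."""
    n_tensors = sum(1 for _ in model.parameters())
    n_scalars = sum(p.nelement() for p in model.parameters())
    return n_tensors, n_scalars
# --------------------------------------------------------------------------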
def main(): # Check if CUDA is available if cuda.is_available(): wlog( 'CUDA is available, specify device by gpu_id argument (i.e. gpu_id=[3])' ) else: wlog('Warning: CUDA is not available, try CPU') if wargs.gpu_id: cuda.set_device(wargs.gpu_id[0]) wlog('Using GPU {}'.format(wargs.gpu_id[0])) init_dir(wargs.dir_model) init_dir(wargs.dir_valid) ''' train_srcD_file = wargs.dir_data + 'train.10k.zh5' wlog('\nPreparing source vocabulary from {} ... '.format(train_srcD_file)) src_vocab = extract_vocab(train_srcD_file, wargs.src_dict, wargs.src_dict_size) train_trgD_file = wargs.dir_data + 'train.10k.en5' wlog('\nPreparing target vocabulary from {} ... '.format(train_trgD_file)) trg_vocab = extract_vocab(train_trgD_file, wargs.trg_dict, wargs.trg_dict_size) train_src_file = wargs.dir_data + 'train.10k.zh0' train_trg_file = wargs.dir_data + 'train.10k.en0' wlog('\nPreparing training set from {} and {} ... '.format(train_src_file, train_trg_file)) train_src_tlst, train_trg_tlst = wrap_data(train_src_file, train_trg_file, src_vocab, trg_vocab) #list [torch.LongTensor (sentence), torch.LongTensor, torch.LongTensor, ...], no padding wlog('Sentence-pairs count in training data: {}'.format(len(train_src_tlst))) src_vocab_size, trg_vocab_size = src_vocab.size(), trg_vocab.size() wlog('Vocabulary size: |source|={}, |target|={}'.format(src_vocab_size, trg_vocab_size)) batch_train = Input(train_src_tlst, train_trg_tlst, wargs.batch_size) ''' src = os.path.join( wargs.dir_data, '{}.{}'.format(wargs.train_prefix, wargs.train_src_suffix)) trg = os.path.join( wargs.dir_data, '{}.{}'.format(wargs.train_prefix, wargs.train_trg_suffix)) vocabs = {} wlog('\nPreparing source vocabulary from {} ... '.format(src)) src_vocab = extract_vocab(src, wargs.src_dict, wargs.src_dict_size) wlog('\nPreparing target vocabulary from {} ... '.format(trg)) trg_vocab = extract_vocab(trg, wargs.trg_dict, wargs.trg_dict_size) src_vocab_size, trg_vocab_size = src_vocab.size(), trg_vocab.size() wlog('Vocabulary size: |source|={}, |target|={}'.format( src_vocab_size, trg_vocab_size)) vocabs['src'], vocabs['trg'] = src_vocab, trg_vocab wlog('\nPreparing training set from {} and {} ... '.format(src, trg)) trains = {} train_src_tlst, train_trg_tlst = wrap_data(wargs.dir_data, wargs.train_prefix, wargs.train_src_suffix, wargs.train_trg_suffix, src_vocab, trg_vocab, max_seq_len=wargs.max_seq_len) ''' list [torch.LongTensor (sentence), torch.LongTensor, torch.LongTensor, ...] no padding ''' batch_train = Input(train_src_tlst, train_trg_tlst, wargs.batch_size, batch_sort=True) wlog('Sentence-pairs count in training data: {}'.format( len(train_src_tlst))) batch_valid = None if wargs.val_prefix is not None: val_src_file = '{}{}.{}'.format(wargs.val_tst_dir, wargs.val_prefix, wargs.val_src_suffix) val_trg_file = '{}{}.{}'.format(wargs.val_tst_dir, wargs.val_prefix, wargs.val_ref_suffix) wlog('\nPreparing validation set from {} and {} ... '.format( val_src_file, val_trg_file)) valid_src_tlst, valid_trg_tlst = wrap_data( wargs.val_tst_dir, wargs.val_prefix, wargs.val_src_suffix, wargs.val_ref_suffix, src_vocab, trg_vocab, shuffle=False, sort_data=False, max_seq_len=wargs.dev_max_seq_len) batch_valid = Input(valid_src_tlst, valid_trg_tlst, 1, volatile=True, batch_sort=False) batch_tests = None if wargs.tests_prefix is not None: assert isinstance(wargs.tests_prefix, list), 'Test files should be list.' 
init_dir(wargs.dir_tests) batch_tests = {} for prefix in wargs.tests_prefix: init_dir(wargs.dir_tests + '/' + prefix) test_file = '{}{}.{}'.format(wargs.val_tst_dir, prefix, wargs.val_src_suffix) wlog('\nPreparing test set from {} ... '.format(test_file)) test_src_tlst, _ = wrap_tst_data(test_file, src_vocab) batch_tests[prefix] = Input(test_src_tlst, None, 1, volatile=True, batch_sort=False) wlog('\n## Finish to Prepare Dataset ! ##\n') nmtModel = NMT(src_vocab_size, trg_vocab_size) if wargs.pre_train is not None: assert os.path.exists(wargs.pre_train) _dict = _load_model(wargs.pre_train) # initializing parameters of interactive attention model class_dict = None if len(_dict) == 4: model_dict, eid, bid, optim = _dict elif len(_dict) == 5: model_dict, class_dict, eid, bid, optim = _dict for name, param in nmtModel.named_parameters(): if name in model_dict: param.requires_grad = not wargs.fix_pre_params param.data.copy_(model_dict[name]) wlog('{:7} -> grad {}\t{}'.format('Model', param.requires_grad, name)) elif name.endswith('map_vocab.weight'): if class_dict is not None: param.requires_grad = not wargs.fix_pre_params param.data.copy_(class_dict['map_vocab.weight']) wlog('{:7} -> grad {}\t{}'.format('Model', param.requires_grad, name)) elif name.endswith('map_vocab.bias'): if class_dict is not None: param.requires_grad = not wargs.fix_pre_params param.data.copy_(class_dict['map_vocab.bias']) wlog('{:7} -> grad {}\t{}'.format('Model', param.requires_grad, name)) else: init_params(param, name, True) wargs.start_epoch = eid + 1 else: for n, p in nmtModel.named_parameters(): init_params(p, n, True) optim = Optim(wargs.opt_mode, wargs.learning_rate, wargs.max_grad_norm, learning_rate_decay=wargs.learning_rate_decay, start_decay_from=wargs.start_decay_from, last_valid_bleu=wargs.last_valid_bleu) optim.init_optimizer(nmtModel.parameters()) if wargs.gpu_id: wlog('Push model onto GPU {} ... '.format(wargs.gpu_id), 0) nmtModel.cuda() else: wlog('Push model onto CPU ... ', 0) nmtModel.cpu() wlog('done.') wlog(nmtModel) wlog(optim) pcnt1 = len([p for p in nmtModel.parameters()]) pcnt2 = sum([p.nelement() for p in nmtModel.parameters()]) wlog('Parameters number: {}/{}'.format(pcnt1, pcnt2)) trainer = Trainer(nmtModel, src_vocab.idx2key, trg_vocab.idx2key, optim, trg_vocab_size, valid_data=batch_valid, tests_data=batch_tests) # add 1000 to train train_all_chunks = (train_src_tlst, train_trg_tlst) dh = DataHisto(train_all_chunks) ''' dev_src0 = wargs.dir_data + 'dev.1k.zh0' dev_trg0 = wargs.dir_data + 'dev.1k.en0' wlog('\nPreparing dev set for tuning from {} and {} ... '.format(dev_src0, dev_trg0)) dev_src0, dev_trg0 = wrap_data(dev_src0, dev_trg0, src_vocab, trg_vocab) wlog(len(train_src_tlst)) dev_src1 = wargs.dir_data + 'dev.1k.zh1' dev_trg1 = wargs.dir_data + 'dev.1k.en1' wlog('\nPreparing dev set for tuning from {} and {} ... '.format(dev_src1, dev_trg1)) dev_src1, dev_trg1 = wrap_data(dev_src1, dev_trg1, src_vocab, trg_vocab) dev_src2 = wargs.dir_data + 'dev.1k.zh2' dev_trg2 = wargs.dir_data + 'dev.1k.en2' wlog('\nPreparing dev set for tuning from {} and {} ... '.format(dev_src2, dev_trg2)) dev_src2, dev_trg2 = wrap_data(dev_src2, dev_trg2, src_vocab, trg_vocab) dev_src3 = wargs.dir_data + 'dev.1k.zh3' dev_trg3 = wargs.dir_data + 'dev.1k.en3' wlog('\nPreparing dev set for tuning from {} and {} ... 
'.format(dev_src3, dev_trg3)) dev_src3, dev_trg3 = wrap_data(dev_src3, dev_trg3, src_vocab, trg_vocab) dev_src4 = wargs.dir_data + 'dev.1k.zh4' dev_trg4 = wargs.dir_data + 'dev.1k.en4' wlog('\nPreparing dev set for tuning from {} and {} ... '.format(dev_src4, dev_trg4)) dev_src4, dev_trg4 = wrap_data(dev_src4, dev_trg4, src_vocab, trg_vocab) wlog(len(dev_src4+dev_src3+dev_src2+dev_src1+dev_src0)) batch_dev = Input(dev_src4+dev_src3+dev_src2+dev_src1+dev_src0, dev_trg4+dev_trg3+dev_trg2+dev_trg1+dev_trg0, wargs.batch_size) ''' batch_dev = None assert wargs.dev_prefix is not None, 'Requires development to tuning.' dev_src_file = '{}{}.{}'.format(wargs.val_tst_dir, wargs.dev_prefix, wargs.val_src_suffix) dev_trg_file = '{}{}.{}'.format(wargs.val_tst_dir, wargs.dev_prefix, wargs.val_ref_suffix) wlog('\nPreparing dev set from {} and {} ... '.format( dev_src_file, dev_trg_file)) valid_src_tlst, valid_trg_tlst = wrap_data( wargs.val_tst_dir, wargs.dev_prefix, wargs.val_src_suffix, wargs.val_ref_suffix, src_vocab, trg_vocab, shuffle=True, sort_data=True, max_seq_len=wargs.dev_max_seq_len) batch_dev = Input(valid_src_tlst, valid_trg_tlst, wargs.batch_size, batch_sort=True) trainer.train(dh, batch_dev, 0, merge=True, name='DH_{}'.format('dev')) '''
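# --------------------------------------------------------------------------
# Hedged sketch (not part of the original code): trainer.train() is called
# above with merge=True, so each dev batch is merged with sentence pairs drawn
# from the DataHisto built over the full training set.  DataHisto itself is
# defined elsewhere in the repository; a hypothetical minimal stand-in that
# captures the idea (random history pairs appended to the current batch) is:
import random


class SimpleDataHisto(object):
    """Keeps previously seen (src, trg) pairs and mixes them into new batches."""

    def __init__(self, src_list, trg_list):
        self.pairs = list(zip(src_list, trg_list))

    def add_batch_data(self, chunk):
        self.pairs.extend(zip(chunk[0], chunk[1]))

    def merge_batch(self, batch_pairs, k=None):
        k = len(batch_pairs) if k is None else k
        history = random.sample(self.pairs, min(k, len(self.pairs)))
        return list(batch_pairs) + history
# --------------------------------------------------------------------------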