def train(self):

    wlog('Start training ... ')
    assert wargs.sample_size < wargs.batch_size, 'sample size must be less than batch size'
    # [low, high)
    batch_count = len(self.train_data) if wargs.fine_tune is False else 0
    batch_count_domain = len(self.train_data_domain)
    batch_start_sample = tc.randperm(batch_count + batch_count_domain)[0]
    wlog('Randomly select {} samples in the {}th/{} batch'.format(
        wargs.sample_size, batch_start_sample, batch_count + batch_count_domain))
    bidx, eval_cnt = 0, [0]
    wlog('Self-normalization alpha -> {}'.format(wargs.self_norm_alpha))

    train_start = time.time()
    wlog('')
    wlog('#' * 120)
    wlog('#' * 30, False)
    wlog(' Start Training ', False)
    wlog('#' * 30)
    wlog('#' * 120)

    # DANN: negative log-likelihood loss for the domain classifier
    loss_domain = tc.nn.NLLLoss()

    for epoch in range(wargs.start_epoch, wargs.max_epochs + 1):

        epoch_start = time.time()

        # train for one epoch on the training data
        wlog('')
        wlog('$' * 30, False)
        wlog(' Epoch [{}/{}] '.format(epoch, wargs.max_epochs), False)
        wlog('$' * 30)

        if wargs.epoch_shuffle and epoch > wargs.epoch_shuffle_minibatch:
            self.train_data.shuffle()
            self.train_data_domain.shuffle()

        # shuffle the original batch order
        shuffled_batch_idx = tc.randperm(batch_count + batch_count_domain)

        sample_size = wargs.sample_size
        epoch_loss, epoch_trg_words, epoch_num_correct = 0, 0, 0
        show_loss, show_src_words, show_trg_words, show_correct_num = 0, 0, 0, 0
        domain_loss = 0
        sample_spend, eval_spend, epoch_bidx = 0, 0, 0
        show_start = time.time()

        # phase 1: train the translation model, freeze the domain discriminator
        for name, par in self.model.named_parameters():
            if name.split('.')[0] != "domain_discriminator":
                par.requires_grad = True
            else:
                par.requires_grad = False

        for k in range(batch_count + batch_count_domain):

            bidx += 1
            epoch_bidx = k + 1
            batch_idx = shuffled_batch_idx[k] if epoch >= wargs.epoch_shuffle_minibatch else k
            #p = float(k + epoch * (batch_count + batch_count_domain)) / wargs.max_epochs / (batch_count + batch_count_domain)
            #alpha = 2. / (1. + np.exp(-10 * p)) - 1

            if batch_idx < batch_count:

                # (max_slen_batch, batch_size)
                _, srcs, trgs, slens, srcs_m, trgs_m = self.train_data[batch_idx]

                self.model.zero_grad()
                # (max_tlen_batch - 1, batch_size, out_size)
                # forward pass of the DANN out-of-domain branch
                decoder_outputs, domain_outputs = self.model(
                    srcs, trgs[:-1], srcs_m, trgs_m[:-1], 'OUT', alpha=wargs.alpha)
                if len(decoder_outputs) == 2:
                    (decoder_outputs, _checks) = decoder_outputs
                this_bnum = decoder_outputs.size(1)

                #batch_loss, grad_output, batch_correct_num = memory_efficient(
                #    outputs, trgs[1:], trgs_m[1:], self.model.classifier)
                # backward pass, now we have the gradients
                batch_loss, batch_correct_num, batch_log_norm = self.model.classifier.snip_back_prop(
                    decoder_outputs, trgs[1:], trgs_m[1:], wargs.snip_size)

                #self.model.zero_grad()
                #domain_outputs = self.model(srcs, None, srcs_m, None, 'OUT', alpha=alpha)
                #if wargs.cross_entropy is True:
                #    domain_label = tc.zeros(len(domain_outputs))
                #    domain_label = domain_label.long()
                #    domain_label = domain_label.cuda()
                #    domainv_label = Variable(domain_label, volatile=False)
                #    domain_loss_a = loss_domain(tc.log(domain_outputs), domainv_label)
                #    domain_loss_a.backward(retain_graph=True if wargs.max_entropy is True else False)
                if wargs.max_entropy is True:
                    # try to maximize this entropy: minimizing sum(p * log p) maximizes H(p)
                    lam = wargs.alpha if epoch > 1 else 0
                    domain_loss_b = lam * tc.dot(tc.log(domain_outputs), domain_outputs)
                    domain_loss_b.backward()

                batch_src_words = srcs.data.ne(PAD).sum()
                assert batch_src_words == slens.data.sum()
                batch_trg_words = trgs[1:].data.ne(PAD).sum()

            elif batch_idx >= batch_count:

                #domain_loss_out.backward(retain_graph=True)
                # forward pass of the DANN in-domain branch
                _, srcs_domain, trgs_domain, slens_domain, srcs_m_domain, trgs_m_domain = \
                    self.train_data_domain[batch_idx - batch_count]
                self.model.zero_grad()
                decoder_outputs, domain_outputs = self.model(
                    srcs_domain, trgs_domain[:-1], srcs_m_domain, trgs_m_domain[:-1],
                    'IN', alpha=wargs.alpha)
                if len(decoder_outputs) == 2:
                    (decoder_outputs, _checks) = decoder_outputs
                this_bnum = decoder_outputs.size(1)
                batch_loss, batch_correct_num, batch_log_norm = self.model.classifier.snip_back_prop(
                    decoder_outputs, trgs_domain[1:], trgs_m_domain[1:], wargs.snip_size)

                #domain_outputs = self.model(srcs_domain, None, srcs_m_domain, None, 'IN', alpha=alpha)
                #if wargs.cross_entropy is True:
                #    domain_label = tc.ones(len(domain_outputs))
                #    domain_label = domain_label.long()
                #    domain_label = domain_label.cuda()
                #    domainv_label = Variable(domain_label, volatile=False)
                #    domain_loss_a = loss_domain(tc.log(domain_outputs), domainv_label)
                #    domain_loss_a.backward(retain_graph=True if wargs.max_entropy is True else False)
                if wargs.max_entropy is True:
                    lam = wargs.alpha if epoch > 1 else 0
                    domain_loss_b = lam * tc.dot(tc.log(domain_outputs), domain_outputs)
                    domain_loss_b.backward()

                batch_src_words = srcs_domain.data.ne(PAD).sum()
                assert batch_src_words == slens_domain.data.sum()
                batch_trg_words = trgs_domain[1:].data.ne(PAD).sum()

            _grad_nan = False
            for n, p in self.model.named_parameters():
                if p.grad is None:
                    debug('grad None | {}'.format(n))
                    continue
                tmp_grad = p.grad.data.cpu().numpy()
                if np.isnan(tmp_grad).any():    # check the gradients for nan here
                    wlog("grad contains 'nan' | {}".format(n))
                    #wlog("gradient\n{}".format(tmp_grad))
                    _grad_nan = True
                if n == 'decoder.l_f1_0.weight' or n == 's_init.weight' or n == 'decoder.l_f1_1.weight' \
                        or n == 'decoder.l_conv.0.weight' or n == 'decoder.l_f2.weight':
                    debug('grad zeros |{:5} {}'.format(str(not np.any(tmp_grad)), n))

            if _grad_nan is True and wargs.dynamic_cyk_decoding is True:
                for _i, items in enumerate(_checks):
                    wlog('step {} Variable----------------:'.format(_i))
                    #for item in items: wlog(item.cpu().data.numpy())
                    wlog('wen _check_tanh_sa ---------------')
                    wlog(items[0].cpu().data.numpy())
                    wlog('wen _check_a1_weight ---------------')
                    wlog(items[1].cpu().data.numpy())
                    wlog('wen _check_a1 ---------------')
                    wlog(items[2].cpu().data.numpy())
                    wlog('wen alpha_ij---------------')
                    wlog(items[3].cpu().data.numpy())
                    wlog('wen before_mask---------------')
                    wlog(items[4].cpu().data.numpy())
                    wlog('wen after_mask---------------')
                    wlog(items[5].cpu().data.numpy())

            #outputs.backward(grad_output)
            self.optim.step()
            #del outputs, grad_output

            show_loss += batch_loss
            #domain_loss += domain_loss_a.data.clone()[0]
            #domain_loss += domain_loss_out
            show_correct_num += batch_correct_num
            epoch_loss += batch_loss
            epoch_num_correct += batch_correct_num
            show_src_words += batch_src_words
            show_trg_words += batch_trg_words
            epoch_trg_words += batch_trg_words
            batch_log_norm = tc.mean(tc.abs(batch_log_norm))

            if epoch_bidx % wargs.display_freq == 0:
                #print show_correct_num, show_loss, show_trg_words, show_loss / show_trg_words
                ud = time.time() - show_start - sample_spend - eval_spend
                wlog('Epo:{:>2}/{:>2} |[{:^5} {:^5} {:^5}k] |acc:{:5.2f}% |ppl:{:4.2f} '
                     '| |logZ|:{:.2f} '
                     '|stok/s:{:>4}/{:>2}={:>2} |ttok/s:{:>2} '
                     '|stok/sec:{:6.2f} |ttok/sec:{:6.2f} |elapsed:{:4.2f}/{:4.2f}m'.format(
                         epoch, wargs.max_epochs, epoch_bidx, batch_idx, bidx / 1000,
                         (show_correct_num / show_trg_words) * 100,
                         math.exp(show_loss / show_trg_words), batch_log_norm,
                         batch_src_words, this_bnum, int(batch_src_words / this_bnum),
                         int(batch_trg_words / this_bnum), show_src_words / ud,
                         show_trg_words / ud, ud, (time.time() - train_start) / 60.))
                show_loss, show_src_words, show_trg_words, show_correct_num = 0, 0, 0, 0
                sample_spend, eval_spend = 0, 0
                show_start = time.time()

            if epoch_bidx % wargs.sampling_freq == 0:
                sample_start = time.time()
                self.model.eval()
                #self.model.classifier.eval()
                tor = Translator(self.model, self.sv, self.tv)
                if batch_idx < batch_count:
                    # (max_len_batch, batch_size)
                    sample_src_tensor = srcs.t()[:sample_size]
                    sample_trg_tensor = trgs.t()[:sample_size]
                    tor.trans_samples(sample_src_tensor, sample_trg_tensor, "OUT")
                elif batch_idx >= batch_count:
                    sample_src_tensor = srcs_domain.t()[:sample_size]
                    sample_trg_tensor = trgs_domain.t()[:sample_size]
                    tor.trans_samples(sample_src_tensor, sample_trg_tensor, "IN")
                wlog('')
                sample_spend = time.time() - sample_start
                self.model.train()

            # just watch the translation of some source sentences in the training data
            if wargs.if_fixed_sampling and bidx == batch_start_sample:
                # randomly select sample_size samples from the current batch
                rand_rows = np.random.choice(this_bnum, sample_size, replace=False)
                sample_src_tensor = tc.Tensor(sample_size, srcs.size(0)).long()
                sample_src_tensor.fill_(PAD)
                sample_trg_tensor = tc.Tensor(sample_size, trgs.size(0)).long()
                sample_trg_tensor.fill_(PAD)
                for id in xrange(sample_size):
                    sample_src_tensor[id, :] = srcs.t()[rand_rows[id], :]
                    sample_trg_tensor[id, :] = trgs.t()[rand_rows[id], :]

            if wargs.epoch_eval is not True and bidx > wargs.eval_valid_from and \
                    bidx % wargs.eval_valid_freq == 0:
                eval_start = time.time()
                eval_cnt[0] += 1
                wlog('\nAmong epoch, batch [{}], [{}] eval save model ...'.format(
                    epoch_bidx, eval_cnt[0]))
                self.mt_eval(epoch, epoch_bidx, "IN")
                eval_spend = time.time() - eval_start

        # phase 2: train the domain discriminator adversarially, freeze the rest of the model
        for name, par in self.model.named_parameters():
            if name.split('.')[0] != "domain_discriminator":
                par.requires_grad = False
            else:
                par.requires_grad = True

        for k in range(batch_count + batch_count_domain):

            epoch_bidx = k + 1
            batch_idx = shuffled_batch_idx[k] if epoch >= wargs.epoch_shuffle_minibatch else k

            if batch_idx < batch_count:

                _, srcs, trgs, slens, srcs_m, trgs_m = self.train_data[batch_idx]
                self.model.zero_grad()
                # (max_tlen_batch - 1, batch_size, out_size)
                # adversarial forward pass of the DANN out-of-domain branch
                domain_outputs = self.model(srcs, trgs[:-1], srcs_m, trgs_m[:-1],
                                            'OUT', alpha=wargs.alpha, adv=True)
                #if len(decoder_outputs) == 2: (decoder_outputs, _checks) = decoder_outputs
                #this_bnum = decoder_outputs.size(1)
                #batch_loss, grad_output, batch_correct_num = memory_efficient(
                #    outputs, trgs[1:], trgs_m[1:], self.model.classifier)
                #batch_loss, batch_correct_num, batch_log_norm = self.model.classifier.snip_back_prop(
                #    decoder_outputs, trgs[1:], trgs_m[1:], wargs.snip_size)
                #self.model.zero_grad()
                #domain_outputs = self.model(srcs, None, srcs_m, None, 'OUT', alpha=alpha)
                if wargs.cross_entropy is True:
                    domain_label = tc.zeros(len(domain_outputs))
                    domain_label = domain_label.long()
                    domain_label = domain_label.cuda()
                    domainv_label = Variable(domain_label, volatile=False)
                    domain_loss_a = loss_domain(tc.log(domain_outputs), domainv_label)
                    domain_loss_a.backward()
                #if wargs.max_entropy is True:
                #    # try to maximize this entropy
                #    domain_loss_b = -wargs.alpha * tc.dot(tc.log(domain_outputs), domain_outputs)
                #    domain_loss_b.backward()
                batch_src_words = srcs.data.ne(PAD).sum()
                assert batch_src_words == slens.data.sum()
                batch_trg_words = trgs[1:].data.ne(PAD).sum()

            elif batch_idx >= batch_count:

                #domain_loss_out.backward(retain_graph=True)
                # adversarial forward pass of the DANN in-domain branch
                _, srcs_domain, trgs_domain, slens_domain, srcs_m_domain, trgs_m_domain = \
                    self.train_data_domain[batch_idx - batch_count]
                self.model.zero_grad()
                domain_outputs = self.model(srcs_domain, trgs_domain[:-1], srcs_m_domain,
                                            trgs_m_domain[:-1], 'IN', alpha=wargs.alpha, adv=True)
                #if len(decoder_outputs) == 2: (decoder_outputs, _checks) = decoder_outputs
                #this_bnum = decoder_outputs.size(1)
                #batch_loss, batch_correct_num, batch_log_norm = self.model.classifier.snip_back_prop(
                #    decoder_outputs, trgs_domain[1:], trgs_m_domain[1:], wargs.snip_size)
                #domain_outputs = self.model(srcs_domain, None, srcs_m_domain, None, 'IN', alpha=alpha)
                if wargs.cross_entropy is True:
                    domain_label = tc.ones(len(domain_outputs))
                    domain_label = domain_label.long()
                    domain_label = domain_label.cuda()
                    domainv_label = Variable(domain_label, volatile=False)
                    domain_loss_a = loss_domain(tc.log(domain_outputs), domainv_label)
                    domain_loss_a.backward(
                        retain_graph=True if wargs.max_entropy is True else False)
                #if wargs.max_entropy is True:
                #    domain_loss_b = -wargs.alpha * tc.dot(tc.log(domain_outputs), domain_outputs)
                #    domain_loss_b.backward()
                batch_src_words = srcs_domain.data.ne(PAD).sum()
                assert batch_src_words == slens_domain.data.sum()
                batch_trg_words = trgs_domain[1:].data.ne(PAD).sum()

            self.optim.step()

            # NOTE: batch_loss / batch_correct_num / this_bnum are not recomputed in this
            # adversarial pass; the values left over from the NLL pass are reused for logging
            show_loss += batch_loss
            #domain_loss += domain_loss_a.data.clone()[0]
            #domain_loss += domain_loss_out
            show_correct_num += batch_correct_num
            epoch_loss += batch_loss
            epoch_num_correct += batch_correct_num
            show_src_words += batch_src_words
            show_trg_words += batch_trg_words
            epoch_trg_words += batch_trg_words
            #batch_log_norm = tc.mean(tc.abs(batch_log_norm))

            if epoch_bidx % wargs.display_freq == 0:
                #print show_correct_num, show_loss, show_trg_words, show_loss / show_trg_words
                ud = time.time() - show_start - sample_spend - eval_spend
                wlog('Epo:{:>2}/{:>2} |[{:^5} {:^5} {:^5}k] |acc:{:5.2f}% |ppl:{:4.2f} '
                     '| |logZ|:{:.2f} '
                     '|stok/s:{:>4}/{:>2}={:>2} |ttok/s:{:>2} '
                     '|stok/sec:{:6.2f} |ttok/sec:{:6.2f} |elapsed:{:4.2f}/{:4.2f}m'.format(
                         epoch, wargs.max_epochs, epoch_bidx, batch_idx, bidx / 1000,
                         (show_correct_num / show_trg_words) * 100,
                         math.exp(show_loss / show_trg_words), 0,
                         batch_src_words, this_bnum, int(batch_src_words / this_bnum),
                         int(batch_trg_words / this_bnum), show_src_words / ud,
                         show_trg_words / ud, ud, (time.time() - train_start) / 60.))
                show_loss, show_src_words, show_trg_words, show_correct_num = 0, 0, 0, 0
                sample_spend, eval_spend = 0, 0
                show_start = time.time()

            if epoch_bidx % wargs.sampling_freq == 0:
                sample_start = time.time()
                self.model.eval()
                #self.model.classifier.eval()
                tor = Translator(self.model, self.sv, self.tv)
                if batch_idx < batch_count:
                    # (max_len_batch, batch_size)
                    sample_src_tensor = srcs.t()[:sample_size]
                    sample_trg_tensor = trgs.t()[:sample_size]
                    tor.trans_samples(sample_src_tensor, sample_trg_tensor, "OUT")
                elif batch_idx >= batch_count:
                    sample_src_tensor = srcs_domain.t()[:sample_size]
                    sample_trg_tensor = trgs_domain.t()[:sample_size]
                    tor.trans_samples(sample_src_tensor, sample_trg_tensor, "IN")
                wlog('')
                sample_spend = time.time() - sample_start
                self.model.train()

            # just watch the translation of some source sentences in the training data
            if wargs.if_fixed_sampling and bidx == batch_start_sample:
                # randomly select sample_size samples from the current batch
                rand_rows = np.random.choice(this_bnum, sample_size, replace=False)
                sample_src_tensor = tc.Tensor(sample_size, srcs.size(0)).long()
                sample_src_tensor.fill_(PAD)
                sample_trg_tensor = tc.Tensor(sample_size, trgs.size(0)).long()
                sample_trg_tensor.fill_(PAD)
                for id in xrange(sample_size):
                    sample_src_tensor[id, :] = srcs.t()[rand_rows[id], :]
                    sample_trg_tensor[id, :] = trgs.t()[rand_rows[id], :]

        avg_epoch_loss = epoch_loss / epoch_trg_words
        avg_epoch_acc = epoch_num_correct / epoch_trg_words
        wlog('\nEnd epoch [{}]'.format(epoch))
        wlog('Train accuracy {:4.2f}%'.format(avg_epoch_acc * 100))
        wlog('Average loss {:4.2f}'.format(avg_epoch_loss))
        wlog('Train perplexity: {0:4.2f}'.format(math.exp(avg_epoch_loss)))
        #wlog('Epoch domain loss is {:4.2f}'.format(float(domain_loss)))

        wlog('End epoch, batch [{}], [{}] eval save model ...'.format(epoch_bidx, eval_cnt[0]))
        mteval_bleu = self.mt_eval(epoch, epoch_bidx, "IN")
        self.optim.update_learning_rate(mteval_bleu, epoch)

        epoch_time_consume = time.time() - epoch_start
        wlog('Consuming: {:4.2f}s'.format(epoch_time_consume))

    wlog('Train finished, consuming {:6.2f} hours'.format((time.time() - train_start) / 3600))
    wlog('Congratulations!')
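
# --------------------------------------------------------------------------
# Hedged sketch, not part of the original training code: the two-phase loop
# above freezes/unfreezes a module named "domain_discriminator" and passes an
# `alpha` into the model, which is the usual DANN recipe built around a
# gradient reversal layer. The names below (GradReverse, DomainDiscriminator)
# are illustrative stand-ins, assumed rather than taken from this repository.
import torch as tc
import torch.nn as nn


class GradReverse(tc.autograd.Function):
    """Identity in the forward pass; multiplies the gradient by -alpha backward."""

    @staticmethod
    def forward(ctx, x, alpha):
        ctx.alpha = alpha
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output.neg() * ctx.alpha, None


class DomainDiscriminator(nn.Module):
    """Predicts IN/OUT domain from encoder states fed through gradient reversal."""

    def __init__(self, hidden_size, n_domains=2):
        super(DomainDiscriminator, self).__init__()
        self.net = nn.Sequential(nn.Linear(hidden_size, hidden_size), nn.ReLU(),
                                 nn.Linear(hidden_size, n_domains), nn.Softmax(dim=-1))

    def forward(self, enc_state, alpha=1.0):
        reversed_state = GradReverse.apply(enc_state, alpha)
        # returns probabilities; the loop above applies tc.log(...) before nn.NLLLoss
        return self.net(reversed_state)
# --------------------------------------------------------------------------
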
# "问题 交换 了 意见 。" #t = "( beijing , syndicated news ) the sino - us relation that was heated momentarily " \ # "by the us president bush 's visit to china is cooling down rapidly . china " \ # "confirmed yesterday that it has called off its naval fleet visit to the us " \ # "ports this year and refused to confirm whether the country 's vice president " \ # "hu jintao will visit the united states as planned ." s = '当 林肯 去 新奥尔良 时 , 我 听到 密西 西比 河 的 歌声 。' t = "When Lincoln goes to New Orleans, I hear Mississippi river's singing sound" #s = '新奥尔良 是 爵士 音乐 的 发源 地 。' #s = '新奥尔良 以 其 美食 而 闻名 。' # = '休斯顿 是 仅 次于 新奥尔良 和 纽约 的 美国 第三 大 港 。' s = [[src_vocab.key2idx[x] if x in src_vocab.key2idx else UNK for x in s.split(' ')]] t = [[trg_vocab.key2idx[x] if x in trg_vocab.key2idx else UNK for x in t.split(' ')]] tor.trans_samples(s, t) sys.exit(0) input_file = '{}{}.{}'.format(wargs.val_tst_dir, args.input_file, wargs.val_src_suffix) input_abspath = os.path.realpath(input_file) wlog('Translating test file {} ... '.format(input_abspath)) ref_file = '{}{}.{}'.format(wargs.val_tst_dir, args.input_file, wargs.val_ref_suffix) test_src_tlst, _ = wrap_tst_data(input_abspath, src_vocab, char=wargs.src_char) test_input_data = Input(test_src_tlst, None, batch_size=wargs.test_batch_size, batch_sort=False, gpu_ids=args.gpu_ids) batch_tst_data = None if os.path.exists(ref_file) and False: wlog('With force decoding test file {} ... to get alignments'.format(input_file)) wlog('\t\tRef file {}'.format(ref_file)) from tools.inputs_handler import wrap_data
def train(self):

    wlog('Start training ... ')
    assert wargs.sample_size < wargs.batch_size, 'sample size must be less than batch size'
    # [low, high)
    batch_count = len(self.train_data)
    batch_start_sample = tc.randperm(batch_count)[-1]
    wlog('Randomly select {} samples in the {}th/{} batch'.format(
        wargs.sample_size, batch_start_sample, batch_count))
    bidx, eval_cnt = 0, [0]
    wlog('Self-normalization alpha -> {}'.format(wargs.self_norm_alpha))
    tor_hook = Translator(self.model, self.sv, self.tv)

    train_start = time.time()
    wlog('\n' + '#' * 120 + '\n' + '#' * 30 + ' Start Training ' + '#' * 30 + '\n' + '#' * 120)

    batch_oracles, _checks = None, None
    for epoch in range(wargs.start_epoch, wargs.max_epochs + 1):

        epoch_start = time.time()

        # train for one epoch on the training data
        wlog('\n' + '$' * 30, 0)
        wlog(' Epoch [{}/{}] '.format(epoch, wargs.max_epochs) + '$' * 30)

        if wargs.epoch_shuffle and epoch > wargs.epoch_shuffle_minibatch:
            self.train_data.shuffle()

        # shuffle the original batch order
        shuffled_batch_idx = tc.randperm(batch_count)

        sample_size = wargs.sample_size
        epoch_loss, epoch_trg_words, epoch_num_correct, \
            epoch_batch_logZ, epoch_n_sents = 0, 0, 0, 0, 0
        show_loss, show_src_words, show_trg_words, show_correct_num, \
            show_batch_logZ, show_n_sents = 0, 0, 0, 0, 0, 0
        sample_spend, eval_spend, epoch_bidx = 0, 0, 0
        show_start = time.time()

        for k in range(batch_count):

            bidx += 1
            epoch_bidx = k + 1
            batch_idx = shuffled_batch_idx[k] if epoch >= wargs.epoch_shuffle_minibatch else k

            # (max_slen_batch, batch_size)
            #_, srcs, ttrgs_for_files, slens, srcs_m, trg_mask_for_files = self.train_data[batch_idx]
            _, srcs, spos, ttrgs_for_files, ttpos_for_files, slens, srcs_m, trg_mask_for_files = \
                self.train_data[batch_idx]
            trgs, tpos, trgs_m = ttrgs_for_files[0], ttpos_for_files[0], trg_mask_for_files[0]

            #self.model.zero_grad()
            self.optim.zero_grad()
            # (max_tlen_batch - 1, batch_size, out_size)
            if wargs.model == 8:
                gold, gold_mask = trgs[1:], trgs_m[1:]
                # (L, B) -> (B, L)
                T_srcs, T_spos, T_trgs, T_tpos = srcs.t(), spos.t(), trgs.t(), tpos.t()
                src, trg = (T_srcs, T_spos), (T_trgs, T_tpos)
                gold, gold_mask = gold.t(), gold_mask.t()
                #gold, gold_mask = trg[0][:, 1:], trgs_m.t()[:, 1:]
                if gold.is_contiguous() is False: gold = gold.contiguous()
                if gold_mask.is_contiguous() is False: gold_mask = gold_mask.contiguous()
                outputs = self.model(src, trg)
            else:
                gold, gold_mask = trgs[1:], trgs_m[1:]
                outputs = self.model(srcs, trgs[:-1], srcs_m, trgs_m[:-1])

            if len(outputs) == 2: (outputs, _checks) = outputs
            if len(outputs) == 2: (outputs, attends) = outputs
            this_bnum = outputs.size(1)
            epoch_n_sents += this_bnum
            show_n_sents += this_bnum

            #batch_loss, batch_correct_num, batch_log_norm = self.classifier(outputs, trgs[1:], trgs_m[1:])
            #batch_loss.div(this_bnum).backward()
            #batch_loss = batch_loss.data[0]
            #batch_correct_num = batch_correct_num.data[0]
            batch_loss, batch_correct_num, batch_Z = self.classifier.snip_back_prop(
                outputs, gold, gold_mask, wargs.snip_size)

            self.optim.step()
            grad_checker(self.model, _checks)

            batch_src_words = srcs.data.ne(PAD).sum()
            assert batch_src_words == slens.data.sum()
            batch_trg_words = trgs[1:].data.ne(PAD).sum()

            show_loss += batch_loss
            show_correct_num += batch_correct_num
            epoch_loss += batch_loss
            epoch_num_correct += batch_correct_num
            show_src_words += batch_src_words
            show_trg_words += batch_trg_words
            epoch_trg_words += batch_trg_words
            show_batch_logZ += batch_Z
            epoch_batch_logZ += batch_Z

            if epoch_bidx % wargs.display_freq == 0:
                #print show_correct_num, show_loss, show_trg_words, show_loss / show_trg_words
                ud = time.time() - show_start - sample_spend - eval_spend
                wlog('Epo:{:>2}/{:>2} |[{:^5} {:^5} {:^5}k] |acc:{:5.2f}% |ppl:{:4.2f} '
                     '||w-logZ|:{:.2f} ||s-logZ|:{:.2f} '
                     '|stok/s:{:>4}/{:>2}={:>2} |ttok/s:{:>2} '
                     '|stok/sec:{:6.2f} |ttok/sec:{:6.2f} |elapsed:{:4.2f}/{:4.2f}m'.format(
                         epoch, wargs.max_epochs, epoch_bidx, batch_idx, bidx / 1000,
                         (show_correct_num / show_trg_words) * 100,
                         math.exp(show_loss / show_trg_words),
                         show_batch_logZ / show_trg_words, show_batch_logZ / show_n_sents,
                         batch_src_words, this_bnum, int(batch_src_words / this_bnum),
                         int(batch_trg_words / this_bnum), show_src_words / ud,
                         show_trg_words / ud, ud, (time.time() - train_start) / 60.))
                show_loss, show_src_words, show_trg_words, show_correct_num, \
                    show_batch_logZ, show_n_sents = 0, 0, 0, 0, 0, 0
                sample_spend, eval_spend = 0, 0
                show_start = time.time()

            if epoch_bidx % wargs.sampling_freq == 0:
                sample_start = time.time()
                self.model.eval()
                # (max_len_batch, batch_size)
                sample_src_tensor = srcs.t()[:sample_size]
                sample_trg_tensor = trgs.t()[:sample_size]
                sample_src_tensor_pos = T_spos[:sample_size] if wargs.model == 8 else None
                tor_hook.trans_samples(sample_src_tensor, sample_trg_tensor, sample_src_tensor_pos)
                wlog('')
                sample_spend = time.time() - sample_start
                self.model.train()

            # just watch the translation of some source sentences in the training data
            if wargs.if_fixed_sampling and bidx == batch_start_sample:
                # randomly select sample_size samples from the current batch
                rand_rows = np.random.choice(this_bnum, sample_size, replace=False)
                sample_src_tensor = tc.Tensor(sample_size, srcs.size(0)).long()
                sample_src_tensor.fill_(PAD)
                sample_trg_tensor = tc.Tensor(sample_size, trgs.size(0)).long()
                sample_trg_tensor.fill_(PAD)
                for id in xrange(sample_size):
                    sample_src_tensor[id, :] = srcs.t()[rand_rows[id], :]
                    sample_trg_tensor[id, :] = trgs.t()[rand_rows[id], :]

            if wargs.epoch_eval is not True and bidx > wargs.eval_valid_from and \
                    bidx % wargs.eval_valid_freq == 0:
                eval_start = time.time()
                eval_cnt[0] += 1
                wlog('\nAmong epoch, batch [{}], [{}] eval save model ...'.format(
                    epoch_bidx, eval_cnt[0]))
                self.mt_eval(epoch, epoch_bidx)
                eval_spend = time.time() - eval_start

        avg_epoch_loss = epoch_loss / epoch_trg_words
        avg_epoch_acc = epoch_num_correct / epoch_trg_words
        wlog('\nEnd epoch [{}]'.format(epoch))
        wlog('Train accuracy {:4.2f}%'.format(avg_epoch_acc * 100))
        wlog('Average loss {:4.2f}'.format(avg_epoch_loss))
        wlog('Train perplexity: {0:4.2f}'.format(math.exp(avg_epoch_loss)))
        wlog('Train average |w-logZ|: {}/{}={} |s-logZ|: {}/{}={}'.format(
            epoch_batch_logZ, epoch_trg_words, epoch_batch_logZ / epoch_trg_words,
            epoch_batch_logZ, epoch_n_sents, epoch_batch_logZ / epoch_n_sents))

        wlog('End epoch, batch [{}], [{}] eval save model ...'.format(epoch_bidx, eval_cnt[0]))
        mteval_bleu = self.mt_eval(epoch, epoch_bidx)
        self.optim.update_learning_rate(mteval_bleu, epoch)

        epoch_time_consume = time.time() - epoch_start
        wlog('Consuming: {:4.2f}s'.format(epoch_time_consume))

    wlog('Finish training, consuming {:6.2f} hours'.format((time.time() - train_start) / 3600))
    wlog('Congratulations!')
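
# --------------------------------------------------------------------------
# Hedged sketch, an assumption rather than the repo's classifier: the |w-logZ|
# and |s-logZ| statistics logged above usually come from a self-normalization
# penalty on the softmax partition function, roughly
#     loss = NLL + self_norm_alpha * logZ^2  (per target token),
# which pushes logZ toward 0 so scores can be used unnormalized at decode time.
# The function name and signature below are illustrative only.
import torch as tc


def self_normalized_nll(scores, gold, gold_mask, alpha):
    """scores: (n_tokens, vocab) unnormalized logits; gold: (n_tokens,) long indices."""
    log_z = tc.logsumexp(scores, dim=-1)                        # per-token log partition
    log_probs = scores.gather(1, gold.unsqueeze(1)).squeeze(1) - log_z
    nll = -(log_probs * gold_mask).sum()
    penalty = alpha * ((log_z ** 2) * gold_mask).sum()          # drive logZ toward 0
    return nll + penalty, log_z.abs().sum()                     # loss and sum|logZ| for logging
# --------------------------------------------------------------------------
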
class Trainer(object):

    def __init__(self, model_par, train_data, vocab_data, optim, lossCompute, model,
                 valid_data=None, tests_data=None):

        self.model_par, self.model, self.lossCompute, self.optim = model_par, model, lossCompute, optim
        self.sv, self.tv = vocab_data['src'].idx2key, vocab_data['trg'].idx2key
        self.train_data, self.valid_data, self.tests_data = train_data, valid_data, tests_data
        self.max_epochs, self.start_epoch = wargs.max_epochs, wargs.start_epoch

        self.n_look = wargs.n_look
        assert self.n_look <= wargs.batch_size, 'eyeball count > batch size'
        self.n_batches = len(train_data)    # [low, high)
        self.look_xs, self.look_ys = None, None
        if wargs.fix_looking is True:
            rand_idxs = random.sample(range(train_data.n_sent), self.n_look)
            wlog('randomly look {} samples from the whole training data'.format(self.n_look))
            self.look_xs = [train_data.x_list[i][0] for i in rand_idxs]
            self.look_ys = [train_data.y_list_files[i][0] for i in rand_idxs]

        self.tor = Translator(model, self.sv, self.tv, gpu_ids=wargs.gpu_ids)
        self.n_eval = 1
        self.grad_accum_count = wargs.grad_accum_count
        self.epoch_shuffle_train = wargs.epoch_shuffle_train
        self.epoch_shuffle_batch = wargs.epoch_shuffle_batch

        self.ss_cur_prob = wargs.ss_prob_begin
        if wargs.ss_type is not None:
            wlog('word-level optimizing bias between training and decoding ...')
        if wargs.bleu_sampling is True:
            wlog('sentence-level optimizing ...')
        wlog('schedule sampling value {}'.format(self.ss_cur_prob))
        if self.ss_cur_prob < 1. and wargs.bleu_sampling is True:
            self.sampler = Nbs(self.model, self.tv, k=3,
                               noise=wargs.bleu_gumbel_noise, batch_sample=True)

        if self.grad_accum_count > 1:
            assert (wargs.chunk_size == 0), 'to accumulate grads, disable target sequence truncating'

    def accum_matrics(self, batch_size, xtoks, ytoks, nll, ok_ytoks, bow_loss):

        self.look_sents += batch_size
        self.e_sents += batch_size
        self.look_nll += nll
        self.look_bow_loss += bow_loss
        self.look_ok_ytoks += ok_ytoks
        self.e_nll += nll
        self.e_ok_ytoks += ok_ytoks
        self.look_xtoks += xtoks
        self.look_ytoks += ytoks
        self.e_ytoks += ytoks

    def grad_accumulate(self, real_batches, e_idx, n_upds):

        #if self.grad_accum_count > 1:
        #    self.model_par.zero_grad()
        for batch in real_batches:
            # (batch_size, max_slen_batch)
            _, xs, y_for_files, bows, x_lens, xs_mask, y_mask_for_files, bows_mask = batch
            _batch_size = xs.size(0)
            ys, ys_mask = y_for_files[0], y_mask_for_files[0]
            #wlog('x: {}, x_mask: {}, y: {}, y_mask: {}'.format(
            #    xs.size(), xs_mask.size(), ys.size(), ys_mask.size()))
            if bows is not None:
                bows, bows_mask = bows[0], bows_mask[0]
                #wlog('bows: {}, bows_mask: {})'.format(bows.size(), bows_mask.size()))
            _xtoks = xs.data.ne(PAD).sum().item()
            assert _xtoks == x_lens.data.sum().item()
            _ytoks = ys[:, 1:].data.ne(PAD).sum().item()

            #if self.grad_accum_count == 1: self.model_par.zero_grad()
            # exclude the last target word from the inputs
            results = self.model_par(xs, ys[:, :-1], xs_mask, ys_mask[:, :-1], self.ss_cur_prob)
            logits, alphas, contexts = results['logit'], results['attend'], results['context']

            # (batch_size, y_Lm1, out_size)
            gold, gold_mask = ys[:, 1:].contiguous(), ys_mask[:, 1:].contiguous()
            # 3. Compute loss in shards for memory efficiency.
            _nll, _ok_ytoks, _bow_loss = self.lossCompute(
                logits, e_idx, n_upds, gold, gold_mask, None, bows, bows_mask, contexts)

            self.accum_matrics(_batch_size, _xtoks, _ytoks, _nll, _ok_ytoks, _bow_loss)

        # 4. Update the parameters and statistics.
        self.optim.step()
        self.optim.optimizer.zero_grad()
        #tc.cuda.empty_cache()

    def look_samples(self, n_steps):

        if n_steps % wargs.look_freq == 0:

            look_start = time.time()
            self.model_par.eval()    # affects the dropout !!!
            self.model.eval()
            if self.look_xs is not None and self.look_ys is not None:
                _xs, _ys = self.look_xs, self.look_ys
            else:
                rand_idxs = random.sample(range(self.train_data.n_sent), self.n_look)
                wlog('randomly look {} samples from the whole training data'.format(self.n_look))
                _xs = [self.train_data.x_list[i][0] for i in rand_idxs]
                _ys = [self.train_data.y_list_files[i][0] for i in rand_idxs]

            self.tor.trans_samples(_xs, _ys)
            wlog('')
            self.look_spend = time.time() - look_start
            self.model_par.train()
            self.model.train()

    def try_valid(self, e_idx, e_bidx, n_steps):

        if n_steps > 150000: wargs.eval_valid_freq = 1000
        if wargs.epoch_eval is not True and n_steps > wargs.eval_valid_from and \
                n_steps % wargs.eval_valid_freq == 0:
            eval_start = time.time()
            wlog('\nAmong epoch, e_batch:{}, n_steps:{}, {}-th validation ...'.format(
                e_bidx, n_steps, self.n_eval))
            self.mt_eval(e_idx, e_bidx, n_steps)
            self.eval_spend = time.time() - eval_start

    def mt_eval(self, e_idx, e_bidx, n_steps):

        state_dict = {'model': self.model.state_dict(), 'epoch': e_idx, 'batch': e_bidx,
                      'steps': n_steps, 'optim': self.optim}
        if wargs.save_one_model:
            model_file = '{}.pt'.format(wargs.model_prefix)
        else:
            model_file = '{}_e{}_upd{}.pt'.format(wargs.model_prefix, e_idx, n_steps)
        tc.save(state_dict, model_file)
        wlog('Saving temporary model in {}'.format(model_file))

        self.model_par.eval()
        self.model.eval()
        self.tor.trans_eval(self.valid_data, e_idx, e_bidx, n_steps, model_file, self.tests_data)
        self.model_par.train()
        self.model.train()
        self.n_eval += 1

    def train(self):

        wlog('start training ... ')
        train_start = time.time()
        wlog('\n' + '#' * 120 + '\n' + '#' * 30 + ' Start Training ' + '#' * 30 + '\n' + '#' * 120)

        batch_oracles, _checks, accum_batches, real_batches = None, None, 0, []
        current_steps = self.optim.n_current_steps
        self.model_par.train()
        self.model.train()
        show_start = time.time()
        self.look_nll, self.look_ytoks, self.look_ok_ytoks, self.look_sents, self.look_bow_loss = 0, 0, 0, 0, 0

        for e_idx in range(self.start_epoch, self.max_epochs + 1):

            wlog('\n{} Epoch [{}/{}] {}'.format('$' * 30, e_idx, self.max_epochs, '$' * 30))
            if wargs.bow_loss is True:
                wlog('bow: {}'.format(schedule_bow_lambda(e_idx, 5, 0.5, 0.2)))
            # shuffle the training data for each epoch
            if self.epoch_shuffle_train: self.train_data.shuffle()

            self.e_nll, self.e_ytoks, self.e_ok_ytoks, self.e_sents = 0, 0, 0, 0
            self.look_xtoks, self.look_spend, b_counter, self.eval_spend = 0, 0, 0, 0
            epo_start = time.time()
            if self.epoch_shuffle_batch: shuffled_bidx = tc.randperm(self.n_batches)

            #for bidx in range(self.n_batches):
            bidx = 0
            cond = True if wargs.lr_update_way != 'invsqrt' else self.optim.learning_rate > wargs.min_lr
            while cond:

                if self.train_data.eos() is True: break
                if current_steps >= wargs.max_update:
                    wlog('Touch the max update {}'.format(wargs.max_update))
                    sys.exit(0)

                b_counter += 1
                e_bidx = shuffled_bidx[bidx] if self.epoch_shuffle_batch else bidx

                if wargs.ss_type is not None and self.ss_cur_prob < 1. and wargs.bleu_sampling:
                    # NOTE: xs, xs_mask, ys and ys_mask here are expected to come from the
                    # current batch; this sentence-level oracle path assumes they are in scope
                    batch_beam_trgs = self.sampler.beam_search_trans(xs, xs_mask, ys_mask)
                    batch_beam_trgs = [list(zip(*b)[0]) for b in batch_beam_trgs]
                    #wlog(batch_beam_trgs)
                    batch_oracles = batch_search_oracle(batch_beam_trgs, ys[1:], ys_mask[1:])
                    #wlog(batch_oracles)
                    batch_oracles = batch_oracles[:-1].cuda()
                    batch_oracles = self.model.decoder.trg_lookup_table(batch_oracles)

                batch = self.train_data[e_bidx]
                real_batches.append(batch)
                accum_batches += 1
                if accum_batches == self.grad_accum_count:
                    self.grad_accumulate(real_batches, e_idx, current_steps)
                    current_steps = self.optim.n_current_steps
                    del real_batches
                    accum_batches, real_batches = 0, []
                    tc.cuda.empty_cache()

                #grad_checker(self.model, _checks)

                if current_steps % wargs.display_freq == 0:
                    #wlog('look_ok_ytoks:{}, look_nll:{}, look_ytoks:{}'.format(self.look_ok_ytoks, self.look_nll, self.look_ytoks))
                    ud = time.time() - show_start - self.look_spend - self.eval_spend
                    wlog('Epo:{:>2}/{:>2} |[{:^5}/{} {:^5}] |acc:{:5.2f}% |{:4.2f}/{:4.2f}=nll:{:4.2f} |bow:{:4.2f}'
                         ' |w-ppl:{:4.2f} |x(y)/s:{:>4}({:>4})/{}={}({}) |x(y)/sec:{}({}) |lr:{:7.6f}'
                         ' |{:4.2f}s/{:4.2f}m'.format(
                             e_idx, self.max_epochs, b_counter, len(self.train_data), current_steps,
                             (self.look_ok_ytoks / self.look_ytoks) * 100,
                             self.look_nll, self.look_ytoks, self.look_nll / self.look_ytoks,
                             self.look_bow_loss / self.look_ytoks,
                             math.exp(self.look_nll / self.look_ytoks),
                             self.look_xtoks, self.look_ytoks, self.look_sents,
                             int(round(self.look_xtoks / self.look_sents)),
                             int(round(self.look_ytoks / self.look_sents)),
                             int(round(self.look_xtoks / ud)), int(round(self.look_ytoks / ud)),
                             self.optim.learning_rate, ud, (time.time() - train_start) / 60.))
                    self.look_nll, self.look_xtoks, self.look_ytoks, self.look_ok_ytoks, \
                        self.look_sents, self.look_bow_loss = 0, 0, 0, 0, 0, 0
                    self.look_spend, self.eval_spend = 0, 0
                    show_start = time.time()

                self.look_samples(current_steps)
                self.try_valid(e_idx, e_bidx, current_steps)

                bidx += 1

            avg_epo_acc, avg_epo_nll = self.e_ok_ytoks / self.e_ytoks, self.e_nll / self.e_ytoks
            wlog('\nEnd epoch [{}]'.format(e_idx))
            wlog('avg. w-acc: {:4.2f}%, w-nll: {:4.2f}, w-ppl: {:4.2f}'.format(
                avg_epo_acc * 100, avg_epo_nll, math.exp(avg_epo_nll)))

            if wargs.epoch_eval is True:
                wlog('\nEnd epoch, e_batch:{}, n_steps:{}, {}-th validation ...'.format(
                    e_bidx, current_steps, self.n_eval))
                self.mt_eval(e_idx, e_bidx, self.optim.n_current_steps)

            # decay the probability value epsilon of scheduled sampling per epoch
            if wargs.ss_type is not None:
                self.ss_cur_prob = ss_prob_decay(e_idx)    # start from 1.

            epo_time_consume = time.time() - epo_start
            wlog('Consuming: {:4.2f}s'.format(epo_time_consume))

        wlog('Finish training, consuming {:6.2f} hours'.format((time.time() - train_start) / 3600))
        wlog('Congratulations!')
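
# --------------------------------------------------------------------------
# Hedged sketch, a generic pattern rather than this Trainer's exact code path:
# with grad_accum_count > 1 the loop above collects `real_batches` and calls
# grad_accumulate() once per group, which amounts to the usual accumulation
# scheme of one optimizer step per N micro-batch backward passes. The names
# micro_batches and loss_fn below are illustrative assumptions.
import torch as tc


def accumulate_and_step(model, optimizer, micro_batches, loss_fn, accum_count):
    optimizer.zero_grad()
    for i, (xs, ys) in enumerate(micro_batches):
        loss = loss_fn(model(xs), ys) / accum_count   # scale so the step matches one big batch
        loss.backward()                               # gradients add up across micro-batches
        if (i + 1) % accum_count == 0:
            optimizer.step()
            optimizer.zero_grad()
# --------------------------------------------------------------------------
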
def train(self):

    wlog('Start training ... ')
    assert wargs.sample_size < wargs.batch_size, 'sample size must be less than batch size'
    # [low, high)
    batch_count = len(self.train_data)
    batch_start_sample = tc.randperm(batch_count)[0]
    wlog('Randomly select {} samples in the {}th/{} batch'.format(
        wargs.sample_size, batch_start_sample, batch_count))
    bidx, eval_cnt, ss_eps_cur = 0, [0], wargs.ss_eps_begin
    wlog('Self-normalization alpha -> {}'.format(wargs.self_norm_alpha))

    train_start = time.time()
    wlog('')
    wlog('#' * 120)
    wlog('#' * 30, False)
    wlog(' Start Training ', False)
    wlog('#' * 30)
    wlog('#' * 120)

    for epoch in range(wargs.start_epoch, wargs.max_epochs + 1):

        epoch_start = time.time()

        # train for one epoch on the training data
        wlog('')
        wlog('$' * 30, False)
        wlog(' Epoch [{}/{}] '.format(epoch, wargs.max_epochs), False)
        wlog('$' * 30)
        wlog('Schedule sampling value {}'.format(ss_eps_cur))

        if wargs.epoch_shuffle and epoch > wargs.epoch_shuffle_minibatch:
            self.train_data.shuffle()

        # shuffle the original batch order
        shuffled_batch_idx = tc.randperm(batch_count)

        sample_size = wargs.sample_size
        epoch_loss, epoch_trg_words, epoch_num_correct = 0, 0, 0
        show_loss, show_src_words, show_trg_words, show_correct_num = 0, 0, 0, 0
        sample_spend, eval_spend, epoch_bidx = 0, 0, 0
        show_start = time.time()

        for k in range(batch_count):

            bidx += 1
            epoch_bidx = k + 1
            batch_idx = shuffled_batch_idx[k] if epoch >= wargs.epoch_shuffle_minibatch else k

            # (max_slen_batch, batch_size)
            _, srcs, trgs, slens, srcs_m, trgs_m = self.train_data[batch_idx]

            self.model.zero_grad()
            # (max_tlen_batch - 1, batch_size, out_size)
            outputs = self.model(srcs, trgs[:-1], srcs_m, trgs_m[:-1], ss_eps=ss_eps_cur)
            if len(outputs) == 2:
                (outputs, _checks) = outputs
            this_bnum = outputs.size(1)

            #batch_loss, grad_output, batch_correct_num = memory_efficient(
            #    outputs, trgs[1:], trgs_m[1:], self.model.classifier)
            batch_loss, batch_correct_num, batch_log_norm = self.model.decoder.classifier.snip_back_prop(
                outputs, trgs[1:], trgs_m[1:], wargs.snip_size)

            _grad_nan = False
            for n, p in self.model.named_parameters():
                if p.grad is None:
                    debug('grad None | {}'.format(n))
                    continue
                tmp_grad = p.grad.data.cpu().numpy()
                if np.isnan(tmp_grad).any():    # check the gradients for nan here
                    wlog("grad contains 'nan' | {}".format(n))
                    #wlog("gradient\n{}".format(tmp_grad))
                    _grad_nan = True
                if n == 'decoder.l_f1_0.weight' or n == 's_init.weight' or n == 'decoder.l_f1_1.weight' \
                        or n == 'decoder.l_conv.0.weight' or n == 'decoder.l_f2.weight':
                    debug('grad zeros |{:5} {}'.format(str(not np.any(tmp_grad)), n))

            if _grad_nan is True and wargs.dynamic_cyk_decoding is True:
                for _i, items in enumerate(_checks):
                    wlog('step {} Variable----------------:'.format(_i))
                    #for item in items: wlog(item.cpu().data.numpy())
                    wlog('wen _check_tanh_sa ---------------')
                    wlog(items[0].cpu().data.numpy())
                    wlog('wen _check_a1_weight ---------------')
                    wlog(items[1].cpu().data.numpy())
                    wlog('wen _check_a1 ---------------')
                    wlog(items[2].cpu().data.numpy())
                    wlog('wen alpha_ij---------------')
                    wlog(items[3].cpu().data.numpy())
                    wlog('wen before_mask---------------')
                    wlog(items[4].cpu().data.numpy())
                    wlog('wen after_mask---------------')
                    wlog(items[5].cpu().data.numpy())

            #outputs.backward(grad_output)
            self.optim.step()
            #del outputs, grad_output

            batch_src_words = srcs.data.ne(PAD).sum()
            assert batch_src_words == slens.data.sum()
            batch_trg_words = trgs[1:].data.ne(PAD).sum()

            show_loss += batch_loss
            show_correct_num += batch_correct_num
            epoch_loss += batch_loss
            epoch_num_correct += batch_correct_num
            show_src_words += batch_src_words
            show_trg_words += batch_trg_words
            epoch_trg_words += batch_trg_words
            batch_log_norm = tc.mean(tc.abs(batch_log_norm))

            if epoch_bidx % wargs.display_freq == 0:
                #print show_correct_num, show_loss, show_trg_words, show_loss / show_trg_words
                ud = time.time() - show_start - sample_spend - eval_spend
                wlog('Epo:{:>2}/{:>2} |[{:^5} {:^5} {:^5}k] |acc:{:5.2f}% |ppl:{:4.2f} '
                     '| |logZ|:{:.2f} '
                     '|stok/s:{:>4}/{:>2}={:>2} |ttok/s:{:>2} '
                     '|stok/sec:{:6.2f} |ttok/sec:{:6.2f} |elapsed:{:4.2f}/{:4.2f}m'.format(
                         epoch, wargs.max_epochs, epoch_bidx, batch_idx, bidx / 1000,
                         (show_correct_num / show_trg_words) * 100,
                         math.exp(show_loss / show_trg_words), batch_log_norm,
                         batch_src_words, this_bnum, int(batch_src_words / this_bnum),
                         int(batch_trg_words / this_bnum), show_src_words / ud,
                         show_trg_words / ud, ud, (time.time() - train_start) / 60.))
                show_loss, show_src_words, show_trg_words, show_correct_num = 0, 0, 0, 0
                sample_spend, eval_spend = 0, 0
                show_start = time.time()

            if epoch_bidx % wargs.sampling_freq == 0:
                sample_start = time.time()
                self.model.eval()
                #self.model.classifier.eval()
                tor = Translator(self.model, self.sv, self.tv)
                # (max_len_batch, batch_size)
                sample_src_tensor = srcs.t()[:sample_size]
                sample_trg_tensor = trgs.t()[:sample_size]
                tor.trans_samples(sample_src_tensor, sample_trg_tensor)
                wlog('')
                sample_spend = time.time() - sample_start
                self.model.train()

            # just watch the translation of some source sentences in the training data
            if wargs.if_fixed_sampling and bidx == batch_start_sample:
                # randomly select sample_size samples from the current batch
                rand_rows = np.random.choice(this_bnum, sample_size, replace=False)
                sample_src_tensor = tc.Tensor(sample_size, srcs.size(0)).long()
                sample_src_tensor.fill_(PAD)
                sample_trg_tensor = tc.Tensor(sample_size, trgs.size(0)).long()
                sample_trg_tensor.fill_(PAD)
                for id in xrange(sample_size):
                    sample_src_tensor[id, :] = srcs.t()[rand_rows[id], :]
                    sample_trg_tensor[id, :] = trgs.t()[rand_rows[id], :]

            if wargs.epoch_eval is not True and bidx > wargs.eval_valid_from and \
                    bidx % wargs.eval_valid_freq == 0:
                eval_start = time.time()
                eval_cnt[0] += 1
                wlog('\nAmong epoch, batch [{}], [{}] eval save model ...'.format(
                    epoch_bidx, eval_cnt[0]))
                self.mt_eval(epoch, epoch_bidx)
                eval_spend = time.time() - eval_start

        avg_epoch_loss = epoch_loss / epoch_trg_words
        avg_epoch_acc = epoch_num_correct / epoch_trg_words
        wlog('\nEnd epoch [{}]'.format(epoch))
        wlog('Train accuracy {:4.2f}%'.format(avg_epoch_acc * 100))
        wlog('Average loss {:4.2f}'.format(avg_epoch_loss))
        wlog('Train perplexity: {0:4.2f}'.format(math.exp(avg_epoch_loss)))

        wlog('End epoch, batch [{}], [{}] eval save model ...'.format(epoch_bidx, eval_cnt[0]))
        mteval_bleu = self.mt_eval(epoch, epoch_bidx)
        self.optim.update_learning_rate(mteval_bleu, epoch)

        # decay the probability value epsilon of scheduled sampling per epoch
        ss_eps_cur = schedule_sample_eps_decay(epoch)    # start from 1

        epoch_time_consume = time.time() - epoch_start
        wlog('Consuming: {:4.2f}s'.format(epoch_time_consume))

    wlog('Finish training, consuming {:6.2f} hours'.format((time.time() - train_start) / 3600))
    wlog('Congratulations!')
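
# --------------------------------------------------------------------------
# Hedged sketch, an assumption about schedule_sample_eps_decay (whose actual
# definition is not shown here): scheduled sampling typically decays the
# probability of feeding the gold previous word from near 1.0 toward a floor,
# e.g. with an inverse-sigmoid schedule (Bengio et al., 2015). The constants
# k and eps_min below are illustrative.
import math


def schedule_sample_eps_decay_sketch(epoch, k=8.0, eps_min=0.2):
    """Inverse-sigmoid decay: close to 1 for early epochs, approaching eps_min later."""
    eps = k / (k + math.exp(epoch / k))
    return max(eps, eps_min)
# --------------------------------------------------------------------------
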