def train_model(opt, model, train_iter, valid_iter, fields,
                optim, lr_scheduler, start_epoch_at):
    train_loss = nmt.NMTLossCompute(model.generator, fields['tgt'].vocab)
    valid_loss = nmt.NMTLossCompute(model.generator, fields['tgt'].vocab)
    if use_cuda:
        train_loss = train_loss.cuda()
        valid_loss = valid_loss.cuda()
    shard_size = opt.train_shard_size
    trainer = nmt.Trainer(opt, model, train_iter, valid_iter,
                          train_loss, valid_loss, optim, shard_size)
    num_train_epochs = opt.num_train_epochs
    print('start training...')
    for step_epoch in range(start_epoch_at + 1, num_train_epochs):
        if step_epoch >= opt.start_decay_at:
            lr_scheduler.step()
        # 1. Train for one epoch on the training set.
        train_stats = trainer.train(step_epoch, report_func)
        print('Train perplexity: %g' % train_stats.ppl())
        # 2. Validate on the validation set.
        valid_stats = trainer.validate()
        print('Validation perplexity: %g' % valid_stats.ppl())
        trainer.epoch_step(step_epoch, out_dir=opt.out_dir)
        model.train()
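# Hedged sketch: `report_func` is passed to trainer.train() above but is not
# defined in this file. A minimal callback consistent with the call site in
# the actor-critic loop further below (opt, global step, epoch, batch index,
# batch count, start time, learning rate, running stats) might look like
# this; `opt.steps_per_stats` is a hypothetical option name, and the
# Statistics.output() signature is assumed from the same code family.
def report_func(opt, global_step, epoch, batch, num_batches,
                start_time, lr, report_stats):
    # Print running statistics periodically, then reset the accumulator.
    if batch % getattr(opt, 'steps_per_stats', 100) == 0:
        report_stats.output(epoch, batch, num_batches, start_time)
        report_stats = Statistics()
    return report_stats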
# Variant that builds its own data iterators and logs train/valid statistics
# (perplexity, accuracy, optional BLEU) to a summary writer. Note that `opt`,
# `use_cuda`, `report_func`, and `summery_writer` are module-level globals
# here rather than parameters.
def train_model(model, train_data, valid_data, fields,
                optim, lr_scheduler, start_epoch_at):
    train_iter = make_train_data_iter(train_data, opt)
    valid_iter = make_valid_data_iter(valid_data, opt)
    train_loss = nmt.NMTLossCompute(model.generator, fields['tgt'].vocab)
    valid_loss = nmt.NMTLossCompute(model.generator, fields['tgt'].vocab)
    if use_cuda:
        train_loss = train_loss.cuda()
        valid_loss = valid_loss.cuda()
    shard_size = opt.train_shard_size
    trainer = nmt.Trainer(opt, model, train_iter, valid_iter,
                          train_loss, valid_loss, optim,
                          lr_scheduler, shard_size)
    num_train_epochs = opt.num_train_epochs
    print('start training...')
    for step_epoch in range(start_epoch_at + 1, num_train_epochs):
        if step_epoch >= opt.start_decay_at:
            trainer.lr_scheduler.step()
        # 1. Train for one epoch on the training set.
        train_stats = trainer.train(step_epoch, report_func)
        print('Train perplexity: %g' % train_stats.ppl())
        # 2. Validate on the validation set.
        valid_stats = trainer.validate()
        print('Validation perplexity: %g' % valid_stats.ppl())
        trainer.epoch_step(step_epoch, out_dir=opt.out_dir)
        if opt.test_bleu:
            model.eval()
            valid_bleu = test_bleu(model, fields, step_epoch)
            model.train()
        train_stats.log("train", summery_writer, step_epoch,
                        ppl=train_stats.ppl(),
                        learning_rate=optim.lr,
                        accuracy=train_stats.accuracy())
        valid_stats.log("valid", summery_writer, step_epoch,
                        ppl=valid_stats.ppl(),
                        learning_rate=optim.lr,
                        bleu=valid_bleu if opt.test_bleu else 0.0,
                        accuracy=valid_stats.accuracy())
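# Hedged sketch: make_train_data_iter / make_valid_data_iter are called above
# but not shown. Given the era this code suggests (Variable, .cuda()), they
# could be thin wrappers around the legacy torchtext Iterator, where device
# was an integer ordinal (-1 for CPU). The option names train_batch_size /
# valid_batch_size are assumptions, not confirmed by the source.
import torchtext

def make_train_data_iter(train_data, opt):
    return torchtext.data.Iterator(
        dataset=train_data, batch_size=opt.train_batch_size,
        device=0 if use_cuda else -1,  # legacy torchtext device convention
        train=True, repeat=False, sort=False)

def make_valid_data_iter(valid_data, opt):
    return torchtext.data.Iterator(
        dataset=valid_data, batch_size=opt.valid_batch_size,
        device=0 if use_cuda else -1,
        train=False, repeat=False, sort=False)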
# Adversarial (GAN-style) variant: alternates discriminator (D) and
# generator (G) updates, validating with the generator's NMT loss.
def train_model(opt, model, train_iter, valid_iter, fields,
                optimG, lr_schedulerG, optimD, lr_schedulerD,
                start_epoch_at):
    num_train_epochs = opt.num_train_epochs
    num_updates = 0
    print('start training...')
    valid_loss = nmt.NMTLossCompute(model.generator.generator,
                                    fields['tgt'].vocab)
    if use_cuda:
        valid_loss = valid_loss.cuda()
    shard_size = opt.train_shard_size
    trainer = nmt.Trainer(opt, model.generator, train_iter, valid_iter,
                          valid_loss, valid_loss, optimG, lr_schedulerG,
                          shard_size, train_loss_b=None)
    for step_epoch in range(start_epoch_at + 1, num_train_epochs):
        for batch in train_iter:
            # Of every (D_turns + 1) consecutive updates, the last one is a
            # G turn and the rest are D turns; note that in Python
            # -1 % (opt.D_turns + 1) == opt.D_turns.
            if num_updates % (opt.D_turns + 1) == -1 % (opt.D_turns + 1):
                G_turn(model, batch, optimG, opt)
            else:
                D_turn(model, batch, optimD, opt)
            # Periodically run an extra D turn that prints samples.
            if num_updates % opt.show_sample_every == -1 % opt.show_sample_every:
                D_turn(model, batch, optimD, opt, show_sample=True)
            num_updates += 1
            sys.stdout.flush()
        valid_stats = trainer.validate()
        print('Validation perplexity: %g' % valid_stats.ppl())
        sys.stdout.flush()
        if step_epoch >= opt.start_decay_at:
            lr_schedulerD.step()
            lr_schedulerG.step()
        save_per_epoch(model, step_epoch, opt)
        model.train()
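# Hedged sketch: save_per_epoch() is called above but not defined in this
# file. A minimal version, assuming the model exposes state_dict() and
# mirroring the "checkpoint_epoch_critic%d.pkl" naming used for the critic
# checkpoint below, might be:
import os
import torch

def save_per_epoch(model, epoch, opt):
    # Persist the model state once per epoch under opt.out_dir.
    if not os.path.exists(opt.out_dir):
        os.makedirs(opt.out_dir)
    torch.save(model.state_dict(),
               os.path.join(opt.out_dir, 'checkpoint_epoch%d.pkl' % epoch))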
# Template + response generation variant with an actor-critic setup: each
# update trains one of the response generator (R), the template generator
# (T), or the critic (C).
def train_model(opt, model, critic, train_iter, valid_iter, fields,
                optimR, lr_schedulerR, optimT, lr_schedulerT,
                optimC, lr_schedulerC, start_epoch_at):
    train_loss = nmt.NMTLossCompute(model.generator, fields['tgt'].vocab)
    valid_loss = nmt.NMTLossCompute(model.generator, fields['tgt'].vocab)
    if use_cuda:
        train_loss = train_loss.cuda()
        valid_loss = valid_loss.cuda()
    shard_size = opt.train_shard_size
    trainer = nmt.Trainer(opt, model, train_iter, valid_iter,
                          train_loss, valid_loss, optimR, shard_size)
    scorer = nmt.Scorer(model, fields['tgt'].vocab, fields['src'].vocab,
                        train_loss, opt)
    num_train_epochs = opt.num_train_epochs
    print('start training...')
    global_step = 0
    for step_epoch in range(start_epoch_at + 1, num_train_epochs):
        if step_epoch >= opt.start_decay_at:
            lr_schedulerR.step()
            if lr_schedulerT is not None:
                lr_schedulerT.step()
            if lr_schedulerC is not None:
                lr_schedulerC.step()
        total_stats = Statistics()
        report_stats = Statistics()
        for step_batch, batch in enumerate(train_iter):
            global_step += 1
            # NOTE: both branches set the same flags, so the T and C turns
            # below are effectively disabled and every update is an R turn.
            if global_step % 6 == -1 % global_step:
                T_turn = False
                C_turn = False
                R_turn = True
            else:
                T_turn = False
                C_turn = False
                R_turn = True
            if C_turn:
                # Critic turn: freeze both generators, update the critic.
                model.template_generator.eval()
                model.response_generator.eval()
                critic.train()
                optimC.optimizer.zero_grad()
                src_inputs, src_lengths = batch.src
                tgt_inputs, tgt_lengths = batch.tgt
                ref_src_inputs, ref_src_lengths = batch.ref_src
                ref_tgt_inputs, ref_tgt_lengths = batch.ref_tgt
                I_word, I_word_length = batch.I
                D_word, D_word_length = batch.D
                preds, ev = model.template_generator(
                    I_word, I_word_length, D_word, D_word_length,
                    ref_tgt_inputs, ref_tgt_lengths, return_ev=True)
                preds = preds.squeeze(2)
                template, template_lengths = \
                    model.template_generator.do_mask_and_clean(
                        preds, ref_tgt_inputs, ref_tgt_lengths)
                (response, response_length), logp = sample(
                    model.response_generator, src_inputs, None, template,
                    src_lengths, None, template_lengths, max_len=20)
                enc_embedding = model.response_generator.enc_embedding
                dec_embedding = model.response_generator.dec_embedding
                # Shuffle targets along the batch dimension to build
                # negative (mismatched) source/target pairs for the critic.
                inds = np.arange(len(tgt_lengths))
                np.random.shuffle(inds)
                inds_tensor = Variable(torch.LongTensor(inds).cuda())
                random_tgt = tgt_inputs.index_select(1, inds_tensor)
                random_tgt_len = [tgt_lengths[i] for i in inds]
                x, y, z = critic(enc_embedding(src_inputs), src_lengths,
                                 dec_embedding(tgt_inputs), tgt_lengths,
                                 dec_embedding(response), response_length,
                                 dec_embedding(random_tgt), random_tgt_len)
                loss = torch.mean(-x)
                loss.backward()
                optimC.step()
                stats = Statistics()
            elif T_turn:
                # Template turn: freeze the response generator and critic.
                model.template_generator.train()
                model.response_generator.eval()
                critic.eval()
                stats = scorer.update(batch, optimT, 'T', sample, critic)
            elif R_turn:
                # Response turn: alternate MLE updates with scorer updates.
                if model.__class__.__name__ != "jointTemplateResponseGenerator":
                    model.template_generator.eval()
                model.response_generator.train()
                critic.eval()
                if global_step % 2 == 0:
                    stats = trainer.update(batch)
                else:
                    stats = scorer.update(batch, optimR, 'R', sample, critic)
            else:
                stats = trainer.update(batch)
            report_stats.update(stats)
            total_stats.update(stats)
            report_func(opt, global_step, step_epoch, step_batch,
                        len(train_iter), total_stats.start_time,
                        optimR.lr, report_stats)
        if critic is not None:
            critic.save_checkpoint(
                step_epoch, opt,
                os.path.join(opt.out_dir,
                             "checkpoint_epoch_critic%d.pkl" % step_epoch))
        print("\nEpoch : {} ______________________________".format(step_epoch))
        print('Train perplexity: %g' % total_stats.ppl())
        # 2. Validate on the validation set.
        valid_stats = trainer.validate()
        print('Validation perplexity: %g' % valid_stats.ppl())
        trainer.epoch_step(step_epoch, out_dir=opt.out_dir)
        model.train()
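# The critic turn above builds its negative examples by permuting the target
# batch along the batch dimension, so each source ends up paired with a
# random, mismatched target. A self-contained illustration of that trick
# (plain PyTorch, no Variable wrapper):
import torch

tgt = torch.arange(12).view(3, 4)       # (seq_len=3, batch=4)
perm = torch.randperm(tgt.size(1))      # random permutation of batch indices
random_tgt = tgt.index_select(1, perm)  # mismatched "negative" targets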