def train_model(model, fields, optim, data_type, model_opt, train_part):
    """Train `model` epoch-by-epoch, validating and checkpointing each epoch.

    Reads hyper-parameters from the module-level `opt`; uses the global
    `experiment` / `writer` handles for remote and tensorboard logging.
    """
    train_loss = make_loss_compute(model, fields["tgt"].vocab, opt)
    valid_loss = make_loss_compute(model, fields["tgt"].vocab, opt, train=False)

    trunc_size = opt.truncated_decoder  # Badly named...
    shard_size = opt.max_generator_batches
    norm_method = opt.normalization
    grad_accum_count = opt.accum_count

    trainer = onmt.Trainer(model, train_loss, valid_loss, optim,
                           trunc_size, shard_size, data_type,
                           norm_method, grad_accum_count)

    print('\nStart training...')
    print(' * number of epochs: %d, starting from Epoch %d' %
          (opt.epochs + 1 - opt.start_epoch, opt.start_epoch))
    print(' * batch size: %d' % opt.batch_size)

    # NOTE(review): the loop deliberately starts at 1 (the opt.start_epoch
    # variant is commented out) -- confirm resuming is not needed here.
    # for epoch in range(opt.start_epoch, opt.epochs + 1):
    for epoch in range(1, opt.epochs + 1):
        print(f"Start to train on {epoch}/{opt.epochs}")

        # 1. Train for one epoch on the training set.
        train_iter = make_dataset_iter(lazily_load_dataset("train"), fields, opt)
        train_stats = trainer.train(train_iter, epoch, report_func,
                                    train_part, model_opt, fields)
        print('Train perplexity: %g' % train_stats.ppl())
        print('Train accuracy: %g' % train_stats.accuracy())

        # 2. Validate on the validation set.  Drop the train iterator first
        # so its dataset can be released before loading "valid".
        train_iter = None
        valid_iter = make_dataset_iter(lazily_load_dataset("valid"),
                                       fields, opt, is_train=False)
        valid_stats = trainer.validate(valid_iter, train_part)
        print(f"Epoch: {epoch}")
        print('Validation perplexity: %g' % valid_stats.ppl())
        print('Validation accuracy: %g' % valid_stats.accuracy())

        # 3. Log to remote server.
        if opt.exp_host:
            train_stats.log("train", experiment, optim.lr)
            valid_stats.log("valid", experiment, optim.lr)
        if opt.tensorboard:
            train_stats.log_tensorboard("train", writer, optim.lr, epoch)
            # BUG FIX: the "valid" curve was previously written from
            # train_stats, so tensorboard showed the training metrics twice.
            valid_stats.log_tensorboard("valid", writer, optim.lr, epoch)

        # 4. Update the learning rate from validation perplexity.
        trainer.epoch_step(valid_stats.ppl(), epoch)

        # 5. Drop a checkpoint if needed.
        if epoch >= opt.start_checkpoint_at:
            trainer.drop_checkpoint(model_opt, epoch, fields, valid_stats)
def train_model(model, train_data, valid_data, fields, optim, opt, num_runs):
    """Run one training epoch and push ppl/accuracy scalars to the logger."""
    # Data iterators and the matching loss computations.
    tr_iter = make_train_data_iter(train_data, opt)
    va_iter = make_valid_data_iter(valid_data, opt)
    tgt_vocab = fields["tgt"].vocab
    tr_loss = make_loss_compute(model, tgt_vocab, train_data, opt)
    va_loss = make_loss_compute(model, tgt_vocab, valid_data, opt)

    trainer = onmt.Trainer(model, tr_iter, va_iter, tr_loss, va_loss,
                           optim, opt.truncated_decoder,
                           opt.max_generator_batches)

    # Single pass; the caller drives repetition via num_runs.
    for epoch in range(1):
        print('')
        # 1. Train for one epoch on the training set.
        stats = trainer.train(epoch, report_func)
        print('Train perplexity: %g' % stats.ppl())
        print('Train accuracy: %g' % stats.accuracy())
        step = num_runs * (opt.epochs + 1) + epoch
        logger.scalar_summary('train_%s_ppl' % num_runs, stats.ppl(), step)
        logger.scalar_summary('train_%s_acc' % num_runs, stats.accuracy(), step)
def train_model(model, train_data, valid_data, fields, optim):
    """Train `model`, checkpointing only when validation accuracy or
    perplexity improves on the best value seen so far.

    Reads hyper-parameters from the module-level `opt`.
    """
    # Best-so-far trackers: a new max accuracy or min ppl triggers a save.
    min_ppl, max_accuracy = float('inf'), -1
    train_iter = make_train_data_iter(train_data, opt)
    valid_iter = make_valid_data_iter(valid_data, opt)
    train_loss = make_loss_compute(model, fields["tgt"].vocab, train_data, opt)
    valid_loss = make_loss_compute(model, fields["tgt"].vocab, valid_data, opt)
    trunc_size = opt.truncated_decoder  # Badly named...
    shard_size = opt.max_generator_batches
    trainer = onmt.Trainer(model, train_iter, valid_iter, train_loss,
                           valid_loss, optim, trunc_size, shard_size)
    for epoch in range(opt.start_epoch, opt.epochs + 1):
        print('')
        # 1. Train for one epoch on the training set.
        train_stats = trainer.train(epoch, report_func)
        print('Train perplexity: %g' % train_stats.ppl())
        print('Train accuracy: %g' % train_stats.accuracy())
        # 2. Validate on the validation set.
        valid_stats = trainer.validate()
        print('Validation perplexity: %g' % valid_stats.ppl())
        print('Validation accuracy: %g' % valid_stats.accuracy())
        # 3. Log to remote server.
        if opt.exp_host:
            train_stats.log("train", experiment, optim.lr)
            valid_stats.log("valid", experiment, optim.lr)
        # 4. Update the learning rate
        trainer.epoch_step(valid_stats.ppl(), epoch)
        # 5. Drop a checkpoint if needed.
        if epoch >= opt.start_checkpoint_at:
            # Accuracy improvement takes precedence over ppl improvement;
            # both branches refresh both trackers before saving.
            if valid_stats.accuracy() > max_accuracy:
                # 5.1 drop checkpoint when bigger accuracy is achieved.
                min_ppl = min(valid_stats.ppl(), min_ppl)
                max_accuracy = max(valid_stats.accuracy(), max_accuracy)
                trainer.drop_checkpoint(opt, epoch, fields, valid_stats)
                print('Save model according to biggest-ever accuracy: acc: {0}, ppl: {1}'.format(max_accuracy, min_ppl))
            elif valid_stats.ppl() < min_ppl:
                # 5.2 drop checkpoint when smaller ppl is achieved.
                min_ppl = min(valid_stats.ppl(), min_ppl)
                max_accuracy = max(valid_stats.accuracy(), max_accuracy)
                trainer.drop_checkpoint(opt, epoch, fields, valid_stats)
                print('Save model according to lowest-ever ppl: acc: {0}, ppl: {1}'.format(max_accuracy, min_ppl))
def train_model(model1, model2, train_data, valid_data, fields1, fields2, optim1, optim2):
    """Jointly train two models sharing the same data iterators; each model
    has its own loss computation, optimizer and statistics.

    Reads hyper-parameters from the module-level `opt`.
    """
    train_iter = make_train_data_iter(train_data, opt)
    valid_iter = make_valid_data_iter(valid_data, opt)
    # One call returns the loss computations for both models.
    train_loss1, train_loss2 = make_loss_compute(model1, model2, fields1["tgt"].vocab, train_data, opt)
    valid_loss1, valid_loss2 = make_loss_compute(model1, model2, fields1["tgt"].vocab, valid_data, opt)
    trunc_size = opt.truncated_decoder  # Badly named...
    shard_size = opt.max_generator_batches
    trainer = onmt.Trainer(model1, model2, train_iter, valid_iter,
                           train_loss1, valid_loss1, train_loss2, valid_loss2,
                           optim1, optim2, trunc_size, shard_size)
    for epoch in range(opt.start_epoch, opt.epochs + 1):
        print('')
        # 1. Train for one epoch on the training set.
        train_stats1, train_stats2 = trainer.train(epoch, report_func)
        print('Train perplexity 1: %g' % train_stats1.ppl())
        print('Train accuracy 1: %g' % train_stats1.accuracy())
        print('Train perplexity 2: %g' % train_stats2.ppl())
        print('Train accuracy 2: %g' % train_stats2.accuracy())
        # 2. Validate on the validation set.
        valid_stats1, valid_stats2 = trainer.validate()
        print('Validation perplexity 1: %g' % valid_stats1.ppl())
        print('Validation accuracy 1: %g' % valid_stats1.accuracy())
        print('Validation perplexity 2: %g' % valid_stats2.ppl())
        print('Validation accuracy 2: %g' % valid_stats2.accuracy())
        # 3. Log to remote server.
        if opt.exp_host:
            train_stats1.log("train", experiment, optim1.lr)
            valid_stats1.log("valid", experiment, optim1.lr)
            train_stats2.log("train", experiment, optim2.lr)
            valid_stats2.log("valid", experiment, optim2.lr)
        # 4. Update the learning rate
        # NOTE(review): epoch_step is invoked twice, once per model's ppl --
        # confirm the shared trainer expects both calls.
        trainer.epoch_step(valid_stats1.ppl(), epoch)
        trainer.epoch_step(valid_stats2.ppl(), epoch)
        # 5. Drop a checkpoint if needed.
        if epoch >= opt.start_checkpoint_at:
            trainer.drop_checkpoint(opt, epoch, fields1, fields2, valid_stats1, valid_stats2)
def train_model(model, train_dataset, valid_dataset, fields, optim, model_opt):
    """Train for opt.start_epoch..opt.epochs, validating and checkpointing
    once per epoch."""
    print("making iters")
    iter_train = make_train_data_iter(train_dataset, opt)
    iter_valid = make_valid_data_iter(valid_dataset, opt)

    print("making losses")
    vocab_tgt = fields["tgt"].vocab
    loss_train = make_loss_compute(model, vocab_tgt, train_dataset, opt)
    loss_valid = make_loss_compute(model, vocab_tgt, valid_dataset, opt)

    print("making trainer")
    trainer = onmt.Trainer(model, iter_train, iter_valid, loss_train,
                           loss_valid, optim, opt.truncated_decoder,
                           opt.max_generator_batches,
                           train_dataset.data_type)
    print("made trainer")

    for epoch in range(opt.start_epoch, opt.epochs + 1):
        print('')
        # 1. Train for one epoch on the training set.
        stats_train = trainer.train(epoch, report_func)
        print('Train perplexity: %g' % stats_train.ppl())
        print('Train accuracy: %g' % stats_train.accuracy())
        # 2. Validate on the validation set.
        stats_valid = trainer.validate()
        print('Validation perplexity: %g' % stats_valid.ppl())
        print('Validation accuracy: %g' % stats_valid.accuracy())
        # 3. Optionally log to the remote experiment server.
        if opt.exp_host:
            stats_train.log("train", experiment, optim.lr)
            stats_valid.log("valid", experiment, optim.lr)
        # 4. Learning-rate schedule driven by validation perplexity.
        trainer.epoch_step(stats_valid.ppl(), epoch)
        # 5. Drop a checkpoint if needed.
        if epoch >= opt.start_checkpoint_at:
            trainer.drop_checkpoint(model_opt, epoch, fields, stats_valid)
def create_trainer(
        encoder_vocab,
        decoder_vocab,
        device,
        decoder_hidden_size=256,
        embedding_padding_idx=0,
        report_every=50,
):
    """Build a model plus an onmt.Trainer wired with SGD, NLL loss and a
    basic report manager; returns (model, trainer)."""
    # Model on the requested device.
    model = create_model(
        encoder_vocab=encoder_vocab,
        decoder_vocab=decoder_vocab,
        batch_size=64,
    )
    model.to(device)

    # Plain SGD wrapped in the onmt optimizer (constant lr schedule).
    sgd = torch.optim.SGD(model.parameters(), lr=1)
    optimizer = onmt.utils.Optimizer(
        optimizer=sgd,
        learning_rate=1,
        learning_rate_decay_fn=lambda n: 1,
    )

    # Output projection + log-softmax, shared by model and loss.
    projection = torch.nn.Linear(decoder_hidden_size, len(decoder_vocab))
    model.generator = torch.nn.Sequential(
        projection, torch.nn.LogSoftmax(dim=-1)).to(device)
    nll = torch.nn.NLLLoss(ignore_index=embedding_padding_idx,
                           reduction='sum')
    loss = onmt.utils.loss.NMTLossCompute(criterion=nll,
                                          generator=model.generator)

    # Periodic progress reporting (no tensorboard).
    reporter = onmt.utils.ReportMgr(report_every=report_every,
                                    start_time=None,
                                    tensorboard_writer=None)

    trainer = onmt.Trainer(
        model=model,
        optim=optimizer,
        train_loss=loss,
        valid_loss=loss,
        report_manager=reporter,
    )
    return model, trainer
def train_model(auto_models, valid_model, train_data, valid_data, fields_list, valid_fields, optims, discrim_models, discrim_optims, labels):
    """Adversarial training driver: one validation-only trainer, one
    AdvTrainer per auto-encoder, and one DiscrimTrainer per discriminator.

    Reads hyper-parameters from the module-level `opt`.
    """
    trunc_size = opt.truncated_decoder  # Badly named...
    shard_size = opt.max_generator_batches

    valid_iter = make_valid_data_iter(valid_data, opt)
    valid_loss = make_loss_compute(valid_model, valid_fields["tgt"].vocab,
                                   valid_data, opt)

    # trainers[0] is the validation-only trainer; the rest are AdvTrainers.
    trainers = []
    trainers.append(
        onmt.Trainer(valid_model, valid_iter, valid_iter, valid_loss,
                     valid_loss, optims[0], trunc_size, shard_size))
    for model, discrim_model, optim, label, train, fields in zip(
            auto_models, discrim_models, optims, labels, train_data,
            fields_list):
        train_iter = make_train_data_iter(train, opt)
        train_loss = make_loss_compute(model, fields["tgt"].vocab, train, opt)
        trainers.append(
            onmt.AdvTrainer(model, discrim_model, train_iter, valid_iter,
                            train_loss, valid_loss, optim, label,
                            trunc_size, shard_size))

    discrim_trainers = []
    for model, discrim_optim, data in zip(discrim_models, discrim_optims,
                                          train_data):
        train_iter = make_train_data_iter(data, opt)
        discrim_trainers.append(
            onmt.DiscrimTrainer(model, train_iter, discrim_optim, shard_size))

    for epoch in range(opt.start_epoch, opt.epochs + 1):
        print('')
        # 1a. One epoch for every discriminator.
        for label, trainer in zip(labels, discrim_trainers):
            train_stats = trainer.train(epoch, label, discrim_report_func)
            print('Train loss: %g' % train_stats.loss)
            if opt.exp_host:
                # NOTE(review): `optim` is the leftover binding from the
                # trainer-building loop above, not this trainer's optimizer
                # -- confirm which learning rate should be logged.
                train_stats.log("train", experiment, optim.lr)
        # 1b. One epoch for every adversarial auto-encoder trainer.
        for trainer in trainers[1:]:
            train_stats = trainer.train(epoch, report_func)
            print('Train perplexity: %g' % train_stats.ppl())
            print('Train accuracy: %g' % train_stats.accuracy())
            if opt.exp_host:
                train_stats.log("train", experiment, optim.lr)
        # 2. Validate on the validation set.
        valid_stats = trainers[0].validate()
        print('Validation perplexity: %g' % valid_stats.ppl())
        print('Validation accuracy: %g' % valid_stats.accuracy())
        # 3. Log to remote server.
        if opt.exp_host:
            valid_stats.log("valid", experiment, optim.lr)
        # 4. Learning-rate updates are intentionally disabled:
        # for trainer in trainers[1:]:
        #     trainer.epoch_step(valid_stats.ppl(), epoch)
        # for trainer in discrim_trainers:
        #     trainer.epoch_step(valid_stats.ppl(), epoch)
        # 5. Drop a checkpoint if needed.
        if epoch >= opt.start_checkpoint_at:
            # BUG FIX: previously passed `fields`, the stale loop variable
            # from the AdvTrainer-construction loop; the validation trainer
            # should checkpoint with the validation fields (as the parallel
            # variant of this function does).
            trainers[0].drop_checkpoint(opt, epoch, valid_fields, valid_stats)
# Top-level training script: parse the config, build model / optimizer /
# loss, then train for opt.max_epoch epochs with per-epoch validation.
args, extra_args = argparser.parse_known_args()
opt = Configurable(args.config_file, extra_args)
model = ADVModel(opt)

optim = onmt.Optim(
    opt.optim, opt.learning_rate, opt.max_grad_norm,
    lr_decay=opt.learning_rate_decay,
    start_decay_at=opt.start_decay_at,
    beta1=opt.adam_beta1,
    beta2=opt.adam_beta2,
    adagrad_accum=opt.adagrad_accumulator_init,
    decay_method=opt.decay_method,
    warmup_steps=opt.warmup_steps,
    model_size=opt.rnn_size)
optim.set_parameters(model.named_parameters())

tgt_vocab = Vocab(opt.tgt_vocab)
loss_compute = onmt.Loss.NMTLossCompute(model.generator, tgt_vocab).cuda()
trainer = onmt.Trainer(model, loss_compute, loss_compute, optim)

train_set = Data_Loader(opt.train_file, opt.batch_size)
valid_set = Data_Loader(opt.dev_file, opt.batch_size)

# BUG FIX: `xrange` is Python 2 only and raises NameError on Python 3
# (this file already uses f-strings, i.e. Python 3.6+); use `range`.
for epoch in range(opt.max_epoch):
    # Train one epoch.
    train_stats = trainer.train(train_set, epoch, report_func)
    print('Train perplexity: %g' % train_stats.ppl())
    print('Train accuracy: %g' % train_stats.accuracy())

    # Validate.
    valid_stats = trainer.validate(valid_set)
    print('Validation perplexity: %g' % valid_stats.ppl())
    print('Validation accuracy: %g' % valid_stats.accuracy())

    # Decay the LR from validation ppl, then checkpoint every epoch.
    trainer.epoch_step(valid_stats.ppl(), epoch)
    model.save_checkpoint(epoch, opt)
def train_model(model, fields, optim, data_type, model_opt):
    """Train an affect-aware model, appending per-epoch validation metrics
    to a CSV log under ./data/log/.

    Reads hyper-parameters from the module-level `opt`.
    """
    # Either load a dedicated emotional-word embedding table from disk, or
    # reuse the copy kept on the decoder's embeddings.
    if opt.emotional_words:
        print("Emotional words only")
        embedding_copy = torch.load(opt.emotional_words)
    else:
        embedding_copy = model.decoder.embeddings.embedding_copy
    # NOTE(review): train and valid losses are built identically here (no
    # train=False flag, unlike other variants) -- confirm this is intended.
    train_loss = make_loss_compute(model, fields["tgt"].vocab, opt, embedding_copy)
    valid_loss = make_loss_compute(model, fields["tgt"].vocab, opt, embedding_copy)
    trunc_size = opt.truncated_decoder  # Badly named...
    shard_size = opt.max_generator_batches
    norm_method = opt.normalization
    grad_accum_count = opt.accum_count
    trainer = onmt.Trainer(model, train_loss, valid_loss, optim,
                           trunc_size, shard_size, data_type, norm_method,
                           grad_accum_count, opt.affect_bias,
                           fields["tgt"].vocab)
    print('\nStart training...')
    print(' * number of epochs: %d, starting from Epoch %d' %
          (opt.epochs + 1 - opt.start_epoch, opt.start_epoch))
    print(' * batch size: %d' % opt.batch_size)
    # Create a file to save validation metrics
    val_metrics_file = "./data/log/" + opt.exp + ".csv"
    with open(val_metrics_file, "w") as f:
        f.write("loss,perplexity,distinct-1,distinct-2,embed_greedy,embed_avg,embed_extrema,affect_distance,affect_strength\n")
    for epoch in range(opt.start_epoch, opt.epochs + 1):
        print('')
        # 1. Train for one epoch on the training set.
        train_iter = make_dataset_iter(lazily_load_dataset("train"),
                                       fields, opt, repeat=False)
        # Validation iterator repeats so it can be polled mid-epoch.
        valid_iter = make_dataset_iter(lazily_load_dataset("valid"),
                                       fields, opt, is_train=False,
                                       repeat=True)
        # Run validation every a few iterations px
        train_stats = trainer.train(train_iter, epoch, report_func,
                                    valid_iter, opt.evaluate_every,
                                    report_evaluation)
        print('Train perplexity: %g' % train_stats.ppl())
        print('Train accuracy: %g' % train_stats.accuracy())
        # 2. Validate on the validation set.
        valid_stats, num_val_batches = trainer.validate(valid_iter)
        report_evaluation(valid_stats, num_val_batches)
        # Save validation metrics file (embedding/affect metrics are summed
        # over batches, hence the division by num_val_batches).
        with open(val_metrics_file, "a") as f:
            f.write("{0},{1},{2},{3},{4},{5},{6},{7},{8}\n".format(
                valid_stats.loss/valid_stats.n_words,
                valid_stats.ppl(),
                valid_stats.distinct_1(),
                valid_stats.distinct_2(),
                valid_stats.embed_greedy/num_val_batches,
                valid_stats.embed_avg/num_val_batches,
                valid_stats.embed_extrema/num_val_batches,
                valid_stats.affect_dist/num_val_batches,
                valid_stats.affect_strength/num_val_batches))
        # 3. Log to remote server.
        if opt.exp_host:
            train_stats.log("train", experiment, optim.lr)
            valid_stats.log("valid", experiment, optim.lr)
        # 4. Update the learning rate
        trainer.epoch_step(valid_stats.ppl(), epoch)
        # 5. Drop a checkpoint if needed.
        if epoch >= opt.start_checkpoint_at:
            trainer.drop_checkpoint(model_opt, epoch, fields, valid_stats)
def train_model(auto_models, valid_model, train_data, valid_data, fields_list, valid_fields, optims, discrim_models, discrim_optims, labels, advers_optims):
    """Unsupervised/adversarial training driver with src and tgt
    auto-encoder trainers, discriminator and adversarial trainers, and a
    separate validation-only trainer.

    Reads hyper-parameters from the module-level `opt`.
    """
    # train_model(models, valid_model, train, valid, fields, fields_valid, optims,
    # discrim_models, discrim_optims, advers_optims, labels)
    trunc_size = opt.truncated_decoder  # Badly named...
    shard_size = opt.max_generator_batches
    valid_iter = make_valid_data_iter(valid_data, opt)
    valid_loss = make_loss_compute(valid_model, valid_fields["tgt"].vocab,
                                   valid_data, opt)
    valid_trainer = onmt.Trainer(valid_model, valid_iter, valid_iter,
                                 valid_loss, valid_loss, optims[0],
                                 trunc_size, shard_size)
    # Source-side auto-encoder trainer.
    src_train_iter = make_train_data_iter(train_data[0], opt)
    src_train_loss = make_loss_compute(auto_models[0],
                                       fields_list[0]["tgt"].vocab,
                                       train_data[0], opt)
    src_trainer = onmt.Trainer(auto_models[0], src_train_iter, valid_iter,
                               src_train_loss, valid_loss, optims[0],
                               trunc_size, shard_size)
    # Target-side auto-encoder trainer.
    tgt_train_iter = make_train_data_iter(train_data[1], opt)
    tgt_train_loss = make_loss_compute(auto_models[1],
                                       fields_list[1]["tgt"].vocab,
                                       train_data[1], opt)
    tgt_trainer = onmt.Trainer(auto_models[1], tgt_train_iter, valid_iter,
                               tgt_train_loss, valid_loss, optims[1],
                               trunc_size, shard_size)
    # Fresh iterators/losses for the unsupervised trainer (iterators are
    # single-use, hence the re-creation).
    src_train_iter = make_train_data_iter(train_data[0], opt)
    tgt_train_iter = make_train_data_iter(train_data[1], opt)
    src_train_loss = make_loss_compute(auto_models[0],
                                       fields_list[0]["tgt"].vocab,
                                       train_data[0], opt)
    tgt_train_loss = make_loss_compute(auto_models[1],
                                       fields_list[1]["tgt"].vocab,
                                       train_data[1], opt)
    unsup_trainer = onmt.UnsupTrainer(
        auto_models, [None, None], discrim_models,
        [src_train_iter, tgt_train_iter], valid_iter,
        [src_train_loss, tgt_train_loss], [None, None], valid_loss,
        [optims[0], None], [None, None], [0.9, 0.9], trunc_size, shard_size)
    # Discriminator trainer (real labels) and adversarial trainer
    # (smoothed 0.9 labels), each on fresh iterators.
    src_train_iter = make_train_data_iter(train_data[0], opt)
    tgt_train_iter = make_train_data_iter(train_data[1], opt)
    discrim_trainer = onmt.DiscrimTrainer(discrim_models, [src_train_iter,
                                          tgt_train_iter], discrim_optims,
                                          labels, shard_size)
    src_train_iter = make_train_data_iter(train_data[0], opt)
    tgt_train_iter = make_train_data_iter(train_data[1], opt)
    advers_trainer = onmt.DiscrimTrainer(discrim_models, [src_train_iter,
                                         tgt_train_iter], advers_optims,
                                         [0.9, 0.9], shard_size)
    '''
    for epoch in range(10):
        train_stats = discrim_trainer.train(epoch, discrim_report_func)
        print('Discrim Train loss: %g' % train_stats.loss)
    for epoch in range(10):
        train_stats = discrim_trainer.train(epoch, discrim_report_func)
        print('Discrim Train loss: %g' % train_stats.loss)
        train_stats = advers_trainer.train(epoch, discrim_report_func)
        print('Advers Train loss: %g' % train_stats.loss)
    '''
    # First phase: pretrain both auto-encoders.
    for epoch in range(opt.start_epoch, opt.epochs + 1):
        print('')
        # 1. Train for one epoch on the training set.
        train_stats = src_trainer.train(epoch, discrim_report_func)
        print('SRC Train perplexity: %g' % train_stats.ppl())
        print('SRC Train accuracy: %g' % train_stats.accuracy())
        # 1. Train for one epoch on the training set.
        train_stats = tgt_trainer.train(epoch, discrim_report_func)
        print('TGT Train perplexity: %g' % train_stats.ppl())
        print('TGT Train accuracy: %g' % train_stats.accuracy())
        if opt.exp_host:
            # NOTE(review): `optim` is never bound in this function, so this
            # line raises NameError when opt.exp_host is set -- likely meant
            # one of `optims`; confirm which learning rate to log.
            train_stats.log("train", experiment, optim.lr)
    # Second phase: auto-encoder + discriminator + adversarial updates.
    for epoch in range(opt.start_epoch, opt.epochs + 1):
        print('')
        # 1. Train for one epoch on the training set.
        train_stats = src_trainer.train(epoch, discrim_report_func)
        print('SRC Train perplexity: %g' % train_stats.ppl())
        print('SRC Train accuracy: %g' % train_stats.accuracy())
        train_stats = discrim_trainer.train(epoch, discrim_report_func)
        print('Discrim Train loss: %g' % train_stats.loss)
        train_stats = advers_trainer.train(epoch, discrim_report_func)
        print('Advers Train loss: %g' % train_stats.loss)
        '''
        train_stats = unsup_trainer.train(epoch, report_func)
        print('Train perplexity: %g' % train_stats.ppl())
        print('Train accuracy: %g' % train_stats.accuracy())
        '''
        # 2. Validate on the validation set.
        valid_stats = valid_trainer.validate()
        print('Validation perplexity: %g' % valid_stats.ppl())
        print('Validation accuracy: %g' % valid_stats.accuracy())
        # 3. Log to remote server.
        if opt.exp_host:
            valid_stats.log("valid", experiment, optim.lr)
        # 5. Drop a checkpoint if needed.
        if epoch >= opt.start_checkpoint_at:
            valid_trainer.drop_checkpoint(opt, epoch, valid_fields, valid_stats)
def train_model(model, fields, optim, data_type, model_opt):
    """Train `model` with optional early stopping managed by the trainer's
    `early_stop` object; checkpointing strategy depends on the criteria.

    Reads hyper-parameters from the module-level `opt`.
    """
    train_loss = make_loss_compute(model, fields["tgt"].vocab, opt)
    valid_loss = make_loss_compute(model, fields["tgt"].vocab, opt)
    trunc_size = opt.truncated_decoder  # Badly named...
    shard_size = opt.max_generator_batches
    norm_method = opt.normalization
    grad_accum_count = opt.accum_count
    print("model_opt: %s" % str(model_opt))
    # The trainer receives a shallow copy of `fields` so it can hold its own
    # reference.
    trainer = onmt.Trainer(model, train_loss, valid_loss, optim,
                           trunc_size, shard_size, data_type, norm_method,
                           grad_accum_count, model_opt, copy.copy(fields))
    print('\nStart training...')
    print(' * number of epochs: %d, starting from Epoch %d' %
          (opt.epochs + 1 - opt.start_epoch, opt.start_epoch))
    print(' * batch size: %d' % opt.batch_size)
    for epoch in range(opt.start_epoch, opt.epochs + 1):
        print('')
        # 0. Validate on the validation set.
        # work-around to make model work :O :O :O
        if epoch == 1:
            valid_iter = make_dataset_iter(lazily_load_dataset("valid"),
                                           fields, opt, is_train=False)
            valid_stats = trainer.validate(valid_iter)
            print('Validation perplexity: %g' % valid_stats.ppl())
            print('Validation accuracy: %g' % valid_stats.accuracy())
        # 1. Train for one epoch on the training set.
        train_iter = make_dataset_iter(lazily_load_dataset("train"),
                                       fields, opt)
        train_stats = trainer.train(train_iter, epoch, report_func)
        print('Train perplexity: %g' % train_stats.ppl())
        print('Train accuracy: %g' % train_stats.accuracy())
        # 2. Validate on the validation set.
        valid_iter = make_dataset_iter(lazily_load_dataset("valid"),
                                       fields, opt, is_train=False)
        valid_stats = trainer.validate(valid_iter)
        print('Validation perplexity: %g' % valid_stats.ppl())
        print('Validation accuracy: %g' % valid_stats.accuracy())
        # 3. Log to remote server.
        if opt.exp_host:
            train_stats.log("train", experiment, optim.lr)
            valid_stats.log("valid", experiment, optim.lr)
        # 4. Update the learning rate
        trainer.epoch_step(valid_stats.ppl(), epoch)
        # 5. Drop a checkpoint if needed.
        if trainer.early_stop.early_stop_criteria == 'perplexity' or trainer.early_stop.early_stop_criteria is None:
            # no early-stopping
            if epoch >= opt.start_checkpoint_at:
                trainer.drop_checkpoint(model_opt, epoch, fields, valid_stats)
        else:
            # if we are using a non-default early-stopping criteria
            # save model to use for continuing training later on if needed be
            # NOTE(review): `model_name` is assigned but never used below.
            model_name = trainer.drop_checkpoint(
                model_opt, epoch, fields, valid_stats,
                overwrite=opt.overwrite_model_file,
                checkpoint_type='last')
            trainer.drop_metric_scores(model_opt, epoch, fields, valid_stats,
                                       overwrite=True,
                                       checkpoint_type='last')
        print("")
        # Stop as soon as the early-stopping monitor raises the signal.
        if trainer.early_stop.signal_early_stopping:
            print("WARNING: Early stopping!")
            break
def train_model(model, fields, optim, data_type, model_opt):
    """Train `model`, checkpoint every 10 epochs, and translate the
    validation set with the freshly-saved checkpoint.

    Reads hyper-parameters from the module-level `opt` and logs through the
    module-level `logger`.
    """
    # Stand-alone option namespace for the in-training translator; parsed
    # from an empty argv so only the defaults (plus overrides below) apply.
    translate_parser = argparse.ArgumentParser(
        description='translate',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    onmt.opts.add_md_help_argument(translate_parser)
    onmt.opts.translate_opts(translate_parser)
    opt_translate = translate_parser.parse_args(args=[])
    opt_translate.replace_unk = False
    opt_translate.verbose = True

    train_loss = make_loss_compute(model, fields["tgt"].vocab, opt)
    valid_loss = make_loss_compute(model, fields["tgt"].vocab, opt, train=False)
    trunc_size = opt.truncated_decoder  # Badly named...
    shard_size = opt.max_generator_batches
    norm_method = opt.normalization
    grad_accum_count = opt.accum_count
    trainer = onmt.Trainer(model, train_loss, valid_loss, optim,
                           trunc_size, shard_size, data_type,
                           norm_method, grad_accum_count)

    logger.info('')
    logger.info('Start training...')
    logger.info(' * number of epochs: %d, starting from Epoch %d' %
                (opt.epochs + 1 - opt.start_epoch, opt.start_epoch))
    logger.info(' * batch size: %d' % opt.batch_size)

    for epoch in range(opt.start_epoch, opt.epochs + 1):
        logger.info('')

        # 1. Train for one epoch on the training set.
        train_iter = make_dataset_iter(lazily_load_dataset("train"), fields, opt)
        train_stats = trainer.train(train_iter, epoch, report_func)
        logger.info('Train perplexity: %g' % train_stats.ppl())
        logger.info('Train accuracy: %g' % train_stats.accuracy())

        # 2. Validate on the validation set.
        valid_iter = make_dataset_iter(lazily_load_dataset("valid"),
                                       fields, opt, is_train=False)
        valid_stats = trainer.validate(valid_iter)
        logger.info('Validation perplexity: %g' % valid_stats.ppl())
        logger.info('Validation accuracy: %g' % valid_stats.accuracy())

        # 3. Log to remote server / tensorboard.
        if opt.exp_host:
            train_stats.log("train", experiment, optim.lr)
            valid_stats.log("valid", experiment, optim.lr)
        if opt.tensorboard:
            train_stats.log_tensorboard("train", writer, optim.lr, epoch)
            # BUG FIX: "valid" was previously logged from train_stats,
            # duplicating the training curve; use the validation stats.
            valid_stats.log_tensorboard("valid", writer, optim.lr, epoch)

        # 4. Update the learning rate.
        decay = trainer.epoch_step(valid_stats.ppl(), epoch)
        if decay:
            logger.info("Decaying learning rate to %g" % trainer.optim.lr)

        # 5. Drop a checkpoint every 10 epochs and translate the validation
        # set with the model file drop_checkpoint just wrote.
        if epoch % 10 == 0:  # epoch >= opt.start_checkpoint_at:
            trainer.drop_checkpoint(model_opt, epoch, fields, valid_stats)
            opt_translate.src = 'cache/valid_src_{:s}.txt'.format(
                opt.file_templ)
            opt_translate.tgt = 'cache/valid_eval_refs_{:s}.txt'.format(
                opt.file_templ)
            opt_translate.output = 'result/{:s}/valid_res_{:s}.txt'.format(
                opt.dataset, opt.file_templ)
            opt_translate.model = '%s_acc_%.2f_ppl_%.2f_e%d.pt' % (
                opt.save_model, valid_stats.accuracy(), valid_stats.ppl(),
                epoch)
            check_save_result_path(opt_translate.output)
            translator = make_translator(opt_translate, report_score=False,
                                         logger=logger)
            translator.calc_sacre_bleu = False
            translator.translate(opt_translate.src_dir, opt_translate.src,
                                 opt_translate.tgt, opt_translate.batch_size,
                                 opt_translate.attn_debug)
def train_model(model, fields, optim, data_type, model_opt):
    """Train a (possibly multi-sense) model, with optional unigram-table
    negative sampling when the sense loss is enabled.

    Reads hyper-parameters from the module-level `opt`.
    """
    train_loss = make_loss_compute(model, fields["tgt"].vocab, opt)
    valid_loss = make_loss_compute(model, fields["tgt"].vocab, opt, train=False)
    trunc_size = opt.truncated_decoder  # Badly named...
    shard_size = opt.max_generator_batches
    norm_method = opt.normalization
    grad_accum_count = opt.accum_count

    # Negative-sampling prior over source unigram frequencies, built only
    # when the sense loss weight is effectively non-zero.
    if opt.sense_loss_lbd > 1e-5:
        vcount = fields['src'].vocab.freqs
        prior = init_unigram_table(vcount)
        print("Done generating negative sampling prior")
    else:
        prior = None

    trainer = onmt.Trainer(model, train_loss, valid_loss, optim,
                           trunc_size, shard_size, data_type,
                           norm_method, grad_accum_count,
                           use_sense=(opt.num_senses > 1),
                           window_size=opt.sense_window_size,
                           tau=opt.tau, scale=opt.scale,
                           num_neg=opt.num_neg,
                           sense_loss_lbd=opt.sense_loss_lbd)

    print('\nStart training...')
    print(' * number of epochs: %d, starting from Epoch %d' %
          (opt.epochs + 1 - opt.start_epoch, opt.start_epoch))
    print(' * batch size: %d' % opt.batch_size)

    for epoch in range(opt.start_epoch, opt.epochs + 1):
        print('')
        # 1. Train for one epoch on the training set.
        train_iter = make_dataset_iter(lazily_load_dataset("train"), fields, opt)
        train_stats = trainer.train(train_iter, epoch, report_func,
                                    neg_prior=prior)
        print('Train perplexity: %g' % train_stats.ppl())
        print('Train accuracy: %g' % train_stats.accuracy())

        # 2. Validate on the validation set.
        valid_iter = make_dataset_iter(lazily_load_dataset("valid"),
                                       fields, opt, is_train=False)
        valid_stats = trainer.validate(valid_iter)
        print('Validation perplexity: %g' % valid_stats.ppl())
        print('Validation accuracy: %g' % valid_stats.accuracy())

        # 3. Log to remote server / tensorboard.
        if opt.exp_host:
            train_stats.log("train", experiment, optim.lr)
            valid_stats.log("valid", experiment, optim.lr)
        if opt.tensorboard:
            train_stats.log_tensorboard("train", writer, optim.lr, epoch)
            # BUG FIX: the "valid" tensorboard curve was written from
            # train_stats; use valid_stats so it reflects validation metrics.
            valid_stats.log_tensorboard("valid", writer, optim.lr, epoch)

        # 4. Update the learning rate.
        trainer.epoch_step(valid_stats.ppl(), epoch)

        # 5. Drop a checkpoint if needed.
        if epoch >= opt.start_checkpoint_at:
            trainer.drop_checkpoint(model_opt, epoch, fields, valid_stats)
def train_model(model, fields, optim, data_type, model_opt, time_str=''):
    """Train a mirror-loss model with KL annealing, early-stopping on the
    combined validation objective, and cleanup of non-best checkpoints.

    Reads hyper-parameters from the module-level `opt`.
    """
    train_loss_dict = make_mirror_loss(model, fields, opt)
    valid_loss_dict = make_mirror_loss(model, fields, opt)
    trunc_size = opt.truncated_decoder  # Badly named...
    shard_size = opt.max_generator_batches
    trainer = onmt.Trainer(model, train_loss_dict, valid_loss_dict, optim,
                           trunc_size, shard_size, data_type,
                           opt.normalization, opt.accum_count,
                           opt.back_factor, opt.kl_balance, opt.kl_fix)
    # train_datasets = lazily_load_dataset("train")
    # valid_datasets = lazily_load_dataset("valid")
    # train_iter = make_dataset_iter(train_datasets, fields, opt)
    # valid_iter = make_dataset_iter(valid_datasets, fields, opt, is_train=False)
    # Per-epoch validation objective history; drives the early-stop check.
    val_value = []
    for epoch in range(opt.start_epoch, opt.epochs + 1):
        # print('')
        # train_iter = make_dataset_iter(train_datasets, fields, opt)
        # valid_iter = make_dataset_iter(valid_datasets, fields, opt, is_train=False)
        # 1. Train for one epoch on the training set.  Datasets are reloaded
        # lazily each epoch and dropped afterwards to bound memory.
        train_datasets = lazily_load_dataset("train")
        train_iter = make_dataset_iter(train_datasets, fields, opt)
        print("train iter size: {}".format(len(train_iter)))
        train_stats = trainer.train(train_iter, epoch, report_func)
        print('Train perplexity: %g' % train_stats.ppl())
        print('Train accuracy: %g' % train_stats.accuracy())
        del train_datasets, train_iter
        # 2. Validate on the validation set.
        valid_iter = make_dataset_iter(lazily_load_dataset("valid"),
                                       fields, opt, is_train=False)
        valid_size = valid_iter.__len__()
        with torch.no_grad():
            valid_stats = trainer.validate(valid_iter)
        print('Validation perplexity: %g' % valid_stats.ppl())
        print('Validation loss_total: %g' % (valid_stats.loss * 0.5))
        print('Validation accuracy: %g' % valid_stats.accuracy())
        # print('Validation kl: %g' % (valid_stats.loss_kl/valid_stats.counter))
        print('Validation kl: %g' % (valid_stats.loss_kl / valid_size))
        print("current kl_factor: {}".format(str(trainer.kl_knealling)))
        # Combined objective (half reconstruction loss + KL), averaged over
        # validation batches.
        val_value.append(
            (valid_stats.loss * 0.5 + valid_stats.loss_kl) / valid_size)
        del valid_iter
        # 3. Log to remote server.
        if opt.exp_host:
            train_stats.log("train", experiment, optim.lr)
            valid_stats.log("valid", experiment, optim.lr)
        # 4. Update the learning rate
        trainer.epoch_step(valid_stats.ppl(), epoch)
        # 5. Drop a checkpoint if needed.
        # if epoch >= opt.start_checkpoint_at:
        # Checkpointing starts only once KL annealing has balanced.
        if epoch >= opt.kl_balance:
            trainer.drop_checkpoint(model_opt, epoch, fields, valid_stats,
                                    time_str, val_value[-1])
        # Early stop after two consecutive increases of the combined
        # validation objective (once past the annealing warm-up).
        if len(val_value) > (opt.kl_balance + 3) and val_value[-1] > val_value[
                -2] and val_value[-2] > val_value[-3]:
            print("early stop due to the val loss")
            break
    print("**** Training Finished *****")
    # Keep only the best checkpoint on disk; the trainer tracked the rest.
    assert trainer.best_model['model_name'] in trainer.checkpoint_list
    for cp in trainer.checkpoint_list:
        if cp != trainer.best_model['model_name']:
            print(cp)
            os.remove(cp)
    print("Cleaning redundant checkpoints")
    print("the best model path: {}".format(trainer.best_model['model_name']))
    print("the best model ppl: {}".format(trainer.best_model['model_ppl']))
def train_model(model, fields, optim, data_type, model_opt):
    """Epoch-based training loop: train, validate, log, decay the learning
    rate, checkpoint, and stop once the learning rate collapses below 1e-6.

    Args:
        model: the NMT model to train.
        fields: dataset fields; ``fields["tgt"].vocab`` supplies the target
            vocabulary for the loss computation.
        optim: wrapped optimizer exposing ``.lr``.
        data_type: data-type tag forwarded to ``onmt.Trainer``.
        model_opt: model options stored alongside each checkpoint.

    Uses module-level ``opt``, ``logger``, ``report_func`` and (when
    enabled) ``experiment`` / ``writer`` for remote / tensorboard logging.
    """
    train_loss = make_loss_compute(model, fields["tgt"].vocab, opt)
    valid_loss = make_loss_compute(model, fields["tgt"].vocab, opt,
                                   train=False)
    trunc_size = opt.truncated_decoder  # Badly named...
    shard_size = opt.max_generator_batches
    norm_method = opt.normalization
    grad_accum_count = opt.accum_count
    trainer = onmt.Trainer(model, train_loss, valid_loss, optim, trunc_size,
                           shard_size, data_type, norm_method,
                           grad_accum_count)

    logger.info('')
    logger.info('Start training...')
    logger.info(' * number of epochs: %d, starting from Epoch %d' %
                (opt.epochs + 1 - opt.start_epoch, opt.start_epoch))
    logger.info(' * batch size: %d' % opt.batch_size)

    for epoch in range(opt.start_epoch, opt.epochs + 1):
        logger.info('The current epoch: %d' % epoch)

        # 1. Train for one epoch on the training set.
        train_iter = make_dataset_iter(lazily_load_dataset("train"), fields,
                                       opt)
        train_stats = trainer.train(train_iter, epoch, opt.max_src_len,
                                    opt.max_conv_len, report_func)
        logger.info('Train perplexity: %g' % train_stats.ppl())
        logger.info('Train accuracy: %g' % train_stats.accuracy())

        # 2. Validate on the validation set.
        valid_iter = make_dataset_iter(lazily_load_dataset("valid"), fields,
                                       opt, is_train=False)
        valid_stats = trainer.validate(valid_iter, opt.max_src_len,
                                       opt.max_conv_len)
        logger.info('Validation perplexity: %g' % valid_stats.ppl())
        logger.info('Validation accuracy: %g' % valid_stats.accuracy())

        # 3. Log to remote server.
        if opt.exp_host:
            train_stats.log("train", experiment, optim.lr)
            valid_stats.log("valid", experiment, optim.lr)
        if opt.tensorboard:
            train_stats.log_tensorboard("train", writer, optim.lr, epoch)
            # FIX: the "valid" curve was previously written from train_stats,
            # so tensorboard showed training metrics under the valid tag.
            valid_stats.log_tensorboard("valid", writer, optim.lr, epoch)

        # 4. Update the learning rate.
        decay = trainer.epoch_step(valid_stats.ppl(), epoch)
        if decay:
            logger.info("Decaying learning rate to %g" % trainer.optim.lr)

        # 5. Drop a checkpoint if needed.
        if epoch >= opt.start_checkpoint_at:
            trainer.drop_checkpoint(model_opt, epoch, fields, valid_stats)

        # Stop once the LR has decayed into numerical irrelevance.
        if trainer.optim.lr <= 1e-6:
            logger.info(
                "Learning rate %g is below 1e-6 at epoch %d. Stop training!!"
                % (trainer.optim.lr, epoch))
            break
def train(enc_rnn_size: int, dec_rnn_size: int, src_word_vec_size: int,
          tgt_word_vec_size: int, dropout: float, learning_rate: float,
          train_steps: int, valid_steps: int, early_stopping_tolerance: int,
          preprocessed_data_path: str, save_model_path: str):
    """Build, train, and evaluate a 2-layer-LSTM seq2seq model with OpenNMT-py.

    Loads the preprocessed vocab/data at ``preprocessed_data_path``, trains
    for ``train_steps`` steps (validating every ``valid_steps`` with early
    stopping), saves checkpoints to ``save_model_path``, then translates the
    validation set and prints exact-match accuracy.
    """
    vocab_fields = torch.load("{}.vocab.pt".format(preprocessed_data_path))

    src_text_field = vocab_fields["src"].base_field
    src_vocab = src_text_field.vocab
    src_padding = src_vocab.stoi[src_text_field.pad_token]
    tgt_text_field = vocab_fields['tgt'].base_field
    tgt_vocab = tgt_text_field.vocab
    tgt_padding = tgt_vocab.stoi[tgt_text_field.pad_token]

    # Specify the core model.
    encoder_embeddings = onmt.modules.Embeddings(src_word_vec_size,
                                                 len(src_vocab),
                                                 word_padding_idx=src_padding,
                                                 dropout=dropout)
    encoder = onmt.encoders.RNNEncoder(hidden_size=enc_rnn_size, num_layers=2,
                                       rnn_type="LSTM", bidirectional=True,
                                       embeddings=encoder_embeddings,
                                       dropout=dropout)
    decoder_embeddings = onmt.modules.Embeddings(tgt_word_vec_size,
                                                 len(tgt_vocab),
                                                 word_padding_idx=tgt_padding,
                                                 dropout=dropout)
    decoder = onmt.decoders.decoder.InputFeedRNNDecoder(
        hidden_size=dec_rnn_size, num_layers=1, bidirectional_encoder=True,
        rnn_type="LSTM", embeddings=decoder_embeddings, dropout=dropout,
        attn_type=None)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = onmt.models.model.NMTModel(encoder, decoder)
    model.to(device)

    # Specify the tgt word generator and loss computation module.
    model.generator = nn.Sequential(nn.Linear(dec_rnn_size, len(tgt_vocab)),
                                    nn.LogSoftmax(dim=-1)).to(device)
    loss = onmt.utils.loss.NMTLossCompute(
        criterion=nn.NLLLoss(ignore_index=tgt_padding, reduction="sum"),
        generator=model.generator)

    torch_optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    optim = onmt.utils.optimizers.Optimizer(torch_optimizer,
                                            learning_rate=learning_rate,
                                            max_grad_norm=2)

    # Load some data.  (FIX: removed unused `from itertools import chain`.)
    train_data_file = "{}.train.0.pt".format(preprocessed_data_path)
    valid_data_file = "{}.valid.0.pt".format(preprocessed_data_path)
    train_iter = onmt.inputters.inputter.DatasetLazyIter(
        dataset_paths=[train_data_file], fields=vocab_fields, batch_size=64,
        batch_size_multiple=1, batch_size_fn=None, device=device,
        is_train=True, repeat=True, pool_factor=1)
    valid_iter = onmt.inputters.inputter.DatasetLazyIter(
        dataset_paths=[valid_data_file], fields=vocab_fields, batch_size=128,
        batch_size_multiple=1, batch_size_fn=None, device=device,
        is_train=False, repeat=False, pool_factor=1)

    tensorboard = SummaryWriter(flush_secs=5)
    import logging
    import sys
    # Route the trainer's progress reports to stdout.
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)
    report_manager = onmt.utils.ReportMgr(report_every=64,
                                          tensorboard_writer=tensorboard)
    report_manager.start()
    early_stopper = EarlyStopping(early_stopping_tolerance)
    model_saver = ModelSaver(save_model_path, model, None, vocab_fields, optim)
    trainer = onmt.Trainer(model=model, train_loss=loss, valid_loss=loss,
                           optim=optim, report_manager=report_manager,
                           dropout=dropout, model_saver=model_saver,
                           earlystopper=early_stopper)
    print("Starting training...")
    trainer.train(train_iter=train_iter, train_steps=train_steps,
                  valid_iter=valid_iter, valid_steps=valid_steps)

    # Translate the validation set and measure exact-match accuracy.
    src_reader = onmt.inputters.str2reader["text"]
    tgt_reader = onmt.inputters.str2reader["text"]
    scorer = onmt.translate.GNMTGlobalScorer(alpha=0.7, beta=0.,
                                             length_penalty="avg",
                                             coverage_penalty="none")
    gpu = 0 if torch.cuda.is_available() else -1
    translator = onmt.translate.Translator(model=model, fields=vocab_fields,
                                           src_reader=src_reader,
                                           tgt_reader=tgt_reader,
                                           global_scorer=scorer, gpu=gpu)
    builder = onmt.translate.TranslationBuilder(
        data=torch.load(valid_data_file), fields=vocab_fields, has_tgt=True)

    pos_matches = count = 0
    for batch in valid_iter:
        trans_batch = translator.translate_batch(batch=batch,
                                                 src_vocabs=[src_vocab],
                                                 attn_debug=False)
        for trans in builder.from_batch(trans_batch):
            # Compare top hypothesis against the gold target, token-joined.
            pred = ' '.join(trans.pred_sents[0])
            gold = ' '.join(trans.gold_sent)
            pos_matches += 1 if pred == gold else 0
            count += 1
    # FIX: guard against an empty validation set (was a ZeroDivisionError).
    if count:
        print("Acc: ", pos_matches / count)
    else:
        print("Acc: no validation examples")
def train_model(model, fields, optim, data_type, opt_per_pred):
    """Train one model per predicate, validating each epoch and — every 10
    epochs — translating the validation set per predicate, evaluating
    BLEU/ROUGE/coverage, and dropping checkpoints with those scores.

    Args:
        model: dict mapping predicate -> model.
        fields: dict mapping predicate -> dataset fields.
        optim: dict mapping predicate -> optimizer.
        data_type: dict mapping predicate -> data-type tag.
        opt_per_pred: dict mapping predicate -> per-predicate options.
    """
    # Build a fresh translate-option namespace (defaults only, no CLI args)
    # so in-training evaluation can reuse the standard translation path.
    translate_parser = argparse.ArgumentParser(
        description='translate',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    onmt.opts.add_md_help_argument(translate_parser)
    onmt.opts.translate_opts(translate_parser)
    opt_translate = translate_parser.parse_args(args=[])
    opt_translate.replace_unk = False
    opt_translate.verbose = False
    opt_translate.block_ngram_repeat = False
    if opt.gpuid:
        opt_translate.gpu = opt.gpuid[0]

    trunc_size = opt.truncated_decoder  # Badly named...
    shard_size = opt.max_generator_batches
    norm_method = opt.normalization
    grad_accum_count = opt.accum_count

    # One trainer per predicate, each with its own loss and optimizer.
    trainer = {}
    for predicate in opt.parser.predicates:
        train_loss = make_loss_compute(model[predicate],
                                       fields[predicate]["tgt"].vocab,
                                       opt_per_pred[predicate])
        valid_loss = make_loss_compute(model[predicate],
                                       fields[predicate]["tgt"].vocab,
                                       opt_per_pred[predicate], train=False)
        trainer[predicate] = onmt.Trainer(model[predicate], train_loss,
                                          valid_loss, optim[predicate],
                                          trunc_size, shard_size,
                                          data_type[predicate], norm_method,
                                          grad_accum_count)

    logger.info('')
    logger.info('Start training...')
    logger.info(' * number of epochs: %d, starting from Epoch %d' %
                (opt.epochs + 1 - opt.start_epoch, opt.start_epoch))
    logger.info(' * batch size: %d' % opt.batch_size)

    for epoch in range(opt.start_epoch, opt.epochs + 1):
        logger.info('')
        train_stats = {}
        valid_stats = {}
        for predicate in opt.parser.predicates:
            logger.info('Train predicate: %s' % predicate)

            # 1. Train for one epoch on the training set.
            train_iter = make_dataset_iter(
                lazily_load_dataset("train", opt_per_pred[predicate]),
                fields[predicate], opt_per_pred[predicate])
            train_stats[predicate] = trainer[predicate].train(
                train_iter, epoch, fields[predicate], report_func)
            logger.info('Train perplexity: %g' %
                        train_stats[predicate].ppl())
            logger.info('Train accuracy: %g' %
                        train_stats[predicate].accuracy())

            # 2. Validate on the validation set.
            valid_iter = make_dataset_iter(
                lazily_load_dataset("valid", opt_per_pred[predicate]),
                fields[predicate], opt_per_pred[predicate], is_train=False)
            valid_stats[predicate] = trainer[predicate].validate(valid_iter)
            logger.info('Validation perplexity: %g' %
                        valid_stats[predicate].ppl())
            logger.info('Validation accuracy: %g' %
                        valid_stats[predicate].accuracy())

            # 3. Log to remote server.
            if opt_per_pred[predicate].exp_host:
                train_stats[predicate].log("train", experiment,
                                           optim[predicate].lr)
                valid_stats[predicate].log("valid", experiment,
                                           optim[predicate].lr)
            if opt_per_pred[predicate].tensorboard:
                train_stats[predicate].log_tensorboard(
                    "train", writer, optim[predicate].lr, epoch)
                # FIX: the "valid" curve was previously written from
                # train_stats, so tensorboard showed training metrics
                # under the valid tag.
                valid_stats[predicate].log_tensorboard(
                    "valid", writer, optim[predicate].lr, epoch)

            # 4. Update the learning rate.
            decay = trainer[predicate].epoch_step(
                valid_stats[predicate].ppl(), epoch)
            if decay:
                logger.info("Decaying learning rate to %g" %
                            trainer[predicate].optim.lr)

        # 5. Drop a checkpoint if needed.
        if epoch % 10 == 0:  # epoch >= opt.start_checkpoint_at:
            # Translate the validation split for every predicate, then score
            # all translations together before checkpointing.
            opt_translates = []
            for predicate in opt.parser.predicates:
                opt_translate.predicate = predicate
                opt_translate.batch_size = opt_per_pred[predicate].batch_size
                opt_translate.src = 'cache/valid_src_{:s}.txt'.format(
                    opt_per_pred[predicate].file_templ)
                opt_translate.tgt = 'cache/valid_eval_refs_{:s}.txt'.format(
                    opt_per_pred[predicate].file_templ)
                opt_translate.output = 'result/{:s}/valid_res_{:s}.txt'.format(
                    opt_per_pred[predicate].dataset,
                    opt_per_pred[predicate].file_templ)
                # opt_translate.model = '%s_acc_%.2f_ppl_%.2f_e%d.pt' % (
                # opt_per_pred[predicate].save_model,
                # valid_stats[predicate].accuracy(),
                # valid_stats[predicate].ppl(), epoch)
                check_save_result_path(opt_translate.output)
                translator = make_translator(
                    opt_translate, report_score=False, logger=logger,
                    fields=fields[predicate],
                    model=trainer[predicate].model,
                    model_opt=opt_per_pred[predicate])
                translator.output_beam = (
                    'result/{:s}/valid_res_beam_{:s}.txt'.format(
                        opt_per_pred[predicate].dataset,
                        opt_per_pred[predicate].file_templ))
                # translator.beam_size = 5
                # translator.n_best = 5
                translator.translate(opt_translate.src_dir, opt_translate.src,
                                     opt_translate.tgt,
                                     opt_translate.batch_size,
                                     opt_translate.attn_debug)
                # Copy: opt_translate is mutated each iteration, so keep a
                # per-predicate snapshot for the joint evaluation below.
                opt_translates.append(copy(opt_translate))

            corpusBLEU, bleu, rouge, coverage, bleu_per_predicate = evaluate(
                opt_translates)
            for predicate in opt.parser.predicates:
                trainer[predicate].drop_checkpoint(
                    opt_per_pred[predicate], epoch, corpusBLEU, bleu, rouge,
                    coverage, bleu_per_predicate[predicate],
                    fields[predicate], valid_stats[predicate])
def train_model(model, fields, optim, model_opt, swap_dict):
    """Epoch-driven training loop: trains, validates, steps the learning-rate
    schedule, and saves checkpoints on the configured interval via
    ``onmt.Trainer``.

    Reads all hyper-parameters from the module-level ``opt`` namespace;
    ``swap_dict`` is forwarded untouched to ``trainer.train``.
    """
    tgt_vocab = fields["tgt"].vocab
    trainer = onmt.Trainer(
        model,
        make_loss_compute(model, tgt_vocab, opt),
        make_loss_compute(model, tgt_vocab, opt, train=False),
        optim,
        opt.truncated_decoder,  # Badly named...
        opt.max_generator_batches,
        opt.normalization,
        opt.accum_count,
        opt.select_model,
    )

    print('\nStart training...')
    print(' * number of epochs: %d, starting from Epoch %d' %
          (opt.epochs + 1 - opt.start_epoch, opt.start_epoch))
    print(' * batch size: %d' % opt.batch_size)

    for epoch in range(opt.start_epoch, opt.epochs + 1):
        print('')

        # One epoch over the (lazily loaded) training shards; the trainer
        # also receives the hooks needed to validate while training.
        training_batches = make_dataset_iter(lazily_load_dataset("train"),
                                             fields, opt)
        train_stats = trainer.train(training_batches, epoch, opt, fields,
                                    validate_while_training, writer,
                                    report_func, opt.valid_pt, swap_dict)
        print('Train perplexity: %g' % train_stats.ppl())
        print('Train accuracy: %g' % train_stats.accuracy())

        # Score the held-out set.
        validation_batches = make_dataset_iter(
            lazily_load_dataset("valid", valid_pt=opt.valid_pt),
            fields, opt, is_train=False)
        valid_stats = trainer.validate(validation_batches)
        print('Validation perplexity: %g' % valid_stats.ppl())
        print('Validation accuracy: %g' % valid_stats.accuracy())

        # Let the trainer adjust the learning rate from validation ppl.
        trainer.epoch_step(valid_stats.ppl(), epoch)

        # Persist a checkpoint on the configured save interval.
        if epoch >= opt.start_checkpoint_at and epoch % opt.save_interval == 0:
            trainer.drop_checkpoint(model_opt, epoch, deepcopy(fields),
                                    valid_stats, 0)