def test(est_model, est_args, args, test_samples, vocab, test_scores=None):
    """Score test samples one at a time with the estimator model.

    Args:
        est_model: estimator model. It is switched to eval mode, called with a
            4-tuple ``(source, source_mask, target, target_mask)`` of CUDA
            Variables, and must return ``(est_score, log_probs)``.
        est_args: arguments of the estimator; only ``est_args.loss`` is read,
            to select the criterion via ``trainer_utils.set_criterion``.
        args: run-time arguments; only ``args.debug`` is read.
        test_samples: sized sequence of samples. Each sample is wrapped in a
            one-element batch for ``io_utils.create_predictor_input``.
        vocab: vocabulary forwarded to ``io_utils.create_predictor_input``.
        test_scores: optional reference scores, indexed in step with
            ``test_samples``. ``None`` or an empty sequence disables the loss
            computation.

    Returns:
        A pair ``(out_scores, avg_loss)``. ``out_scores`` holds one predicted
        score per sample. ``avg_loss`` is the summed loss divided by
        ``len(test_samples)``, or ``None`` when no reference scores were given
        or there were no samples. In debug mode ``(0.0, 0.0)`` is returned
        instead.
    """
    est_model.eval()
    est_criterion = trainer_utils.set_criterion(est_args.loss)

    # Decide once whether references are available. An explicit None/length
    # check (rather than truthiness) also works for array-like score
    # containers, whose truth value is ambiguous.
    has_refs = test_scores is not None and len(test_scores) > 0

    sample_idx = 0
    total_loss_value = 0
    out_scores = []
    for sample in test_samples:
        # the input builder works on batches, so wrap the single sample
        pred_input = io_utils.create_predictor_input([sample], vocab)

        # source / target token ids and their masks
        source = pred_input[0]
        target = pred_input[1]
        source_mask = pred_input[2]
        target_mask = pred_input[3]

        # volatile=True: inference only, no graph is kept for backprop
        source_input = Variable(torch.LongTensor(source), volatile=True).cuda()
        source_mask_input = Variable(torch.LongTensor(source_mask), volatile=True).cuda()
        target_ref = Variable(torch.LongTensor(target), volatile=True).cuda()
        target_ref_mask = Variable(torch.LongTensor(target_mask), volatile=True).cuda()

        model_input = (source_input, source_mask_input, target_ref, target_ref_mask)
        est_score, log_probs = est_model(model_input)
        out_scores.append(est_score.data[0][0])  # only one element in output

        if has_refs:
            scores_ref = Variable(torch.FloatTensor([test_scores[sample_idx]])).cuda()
            est_loss = est_criterion(est_score, scores_ref)
            total_loss_value += est_loss.data[0]

        sample_idx += 1

        # NOTE(review): the debug short-circuit is placed inside the loop (exit
        # after the first sample), mirroring the training loops -- confirm.
        if args.debug == True:  # noqa: E712 - explicit comparison kept on purpose
            return 0.0, 0.0

    assert sample_idx == len(test_samples), "error in dimension of samples and testset"

    # Guard the division so an empty test set yields None instead of raising.
    if has_refs and sample_idx:
        avg_loss = total_loss_value / len(test_samples)
    else:
        avg_loss = None
    return out_scores, avg_loss
def train(model, args, trainset_reader, vocab, validset_reader=None):
    """Train the predictor model and checkpoint it after every epoch.

    Args:
        model: model whose parameters are optimised; it is passed to
            ``trainer.train_step`` and ``trainer.run_validation``.
        args: run configuration. Fields read here: ``debug``, ``optimizer``,
            ``learning_rate``, ``num_epochs``, ``seed``, ``clip_norm``,
            ``log_interval``, ``output_dir`` and ``save_after_epochs``.
        trainset_reader: batch reader providing ``reset()``, ``shuffle()``,
            ``shuffle_batches`` and ``next()``; a falsy value from ``next()``
            ends the epoch.
        vocab: vocabulary used to index the batches; also stored in the
            checkpoint.
        validset_reader: optional reader for validation. Without it no
            validation is run and the best-loss fields of the checkpoint stay
            ``None``.

    Returns:
        None. In debug mode the function returns right after the first
        training step.
    """
    debug = args.debug

    # running loss for logging, reset at the end of every epoch
    total_loss_value = 0

    # setting optimizers
    optimizer = trainer_utils.set_optimizer(args.optimizer)(
        model.parameters(), lr=args.learning_rate)

    # setting loss function; targets equal to 0 do not contribute to the loss
    # (presumably the padding index -- confirm against the vocabulary)
    criterion = nn.CrossEntropyLoss(ignore_index=0)

    trainset_reader.reset()
    num_batches = None
    best_valid_loss = None
    # Initialised up front: without a validation set it is never reassigned,
    # yet it is still read when the checkpoint state is built.
    best_epoch_idx = None
    is_best = False

    for epoch_idx in range(1, args.num_epochs + 1):
        # shuffling trainset; the seed changes per epoch but is reproducible
        logger.info("shuffling batches...")
        random.seed(args.seed + (epoch_idx - 1))
        if trainset_reader.shuffle_batches:
            trainset_reader.shuffle()

        # initializing minibatch
        minibatch_idx = 0
        minibatch = trainset_reader.next()
        while minibatch:
            minibatch_idx += 1
            train_input = io_utils.create_predictor_input(minibatch, vocab)
            loss_value = trainer.train_step(train_input, model, optimizer,
                                            criterion,
                                            clip_norm=args.clip_norm,
                                            debug=debug)

            # calculating total loss for logging (per epoch)
            total_loss_value += loss_value

            # logging after set interval; num_batches is None during the first
            # epoch because the total is not known yet
            if minibatch_idx % args.log_interval == 0:
                trainer_utils.log_train_info(epoch_idx, minibatch_idx,
                                             total_loss_value, num_batches)

            if debug == True:  # noqa: E712 - explicit comparison kept on purpose
                return

            # read next batch
            minibatch = trainset_reader.next()

        num_batches = minibatch_idx
        trainer_utils.log_train_info(epoch_idx, minibatch_idx,
                                     total_loss_value, num_batches)
        logger.info("epoch {} completed.".format(epoch_idx))
        total_loss_value = 0

        # validation
        if validset_reader:
            valid_loss = trainer.run_validation(model, validset_reader, vocab,
                                                debug=debug)
            is_best = False
            if best_valid_loss is None or best_valid_loss > valid_loss:
                best_epoch_idx = epoch_idx
                best_valid_loss = valid_loss
                is_best = True
            logger.info(
                'validation: average loss per batch = %.4f (best %.4f @ epoch %d)'
                % (valid_loss, best_valid_loss, best_epoch_idx))

        # NOTE(review): checkpointing is placed outside the validation branch,
        # so a checkpoint is attempted even without a validation set -- confirm
        # against the intended layout.
        state = {
            'epoch': epoch_idx,
            'vocab': vocab,
            'args': args,
            'state_dict': model.state_dict(),
            'best_valid_loss': best_valid_loss,
            'best_epoch_idx': best_epoch_idx,
            'optimizer': optimizer.state_dict(),
        }
        model_path = args.output_dir + '/model.epoch' + str(epoch_idx) + '.pt'
        best_model_path = args.output_dir + '/model.best.pt'
        trainer_utils.save_checkpoint(state, args.save_after_epochs, is_best,
                                      model_path=model_path,
                                      best_model_path=best_model_path)
def train(est_model, pred_model, args, trainset_reader, vocab, validset_reader, testset_readers=None):
    """Train the estimator with per-epoch validation and early stopping.

    Every epoch runs all training batches, validates on ``validset_reader``,
    optionally evaluates each reader in ``testset_readers``, and hands a
    checkpoint to ``trainer_utils.save_checkpoint``. Training stops early once
    the validation loss has failed to improve for ``args.patience`` epochs in
    a row.

    Args:
        est_model: estimator model; only parameters with ``requires_grad`` set
            are given to the optimiser.
        pred_model: accepted for interface compatibility; not used in this
            function body.
        args: run configuration (``debug``, ``optimizer``, ``learning_rate``,
            ``weight_decay``, ``loss``, ``num_epochs``, ``seed``,
            ``clip_norm``, ``log_interval``, ``metrics``, ``output_dir``,
            ``save_after_epochs``, ``no_save_best``, ``patience``).
        trainset_reader: reader yielding batches of ``(src, hyp, score)``
            triples through ``next()``; a falsy batch ends the epoch.
        vocab: vocabulary used to index the source/hypothesis pairs.
        validset_reader: reader used for validation after each epoch.
        testset_readers: optional iterable of readers evaluated each epoch.

    Returns:
        None. In debug mode the function returns after the first training step.
    """
    debug = args.debug

    # optimiser over the trainable parameters only
    trainable_params = [p for p in est_model.parameters() if p.requires_grad]
    optimizer_cls = trainer_utils.set_optimizer(args.optimizer)
    est_optimizer = optimizer_cls(trainable_params,
                                  lr=args.learning_rate,
                                  weight_decay=args.weight_decay)

    # loss function
    est_criterion = trainer_utils.set_criterion(args.loss)

    def _metrics_line(scores):
        # "name=value" pairs, comma separated, four decimals each
        return ', '.join("{0}={1:.4f}".format(name, value)
                         for name, value in scores.items())

    trainset_reader.reset()

    epoch_loss = 0         # accumulated training loss, for logging only
    batch_count = None     # unknown until the first epoch has finished
    best_valid_loss = None
    stale_epochs = 0       # consecutive epochs without improvement

    for epoch_idx in range(1, args.num_epochs + 1):
        # reshuffle with an epoch-dependent seed
        logger.info("shuffling batches...")
        random.seed(args.seed + (epoch_idx - 1))
        if trainset_reader.shuffle_batches:
            trainset_reader.shuffle()

        step = 0
        batch = trainset_reader.next()
        while batch:
            step += 1

            # separate the (source, hypothesis) pairs from their target scores
            pairs = []
            scores = []
            for src, hyp, score in batch:
                pairs.append((src, hyp))
                scores.append(score)

            # index the pairs and build their masks
            train_input = io_utils.create_predictor_input(pairs, vocab)

            # one optimisation step
            epoch_loss += estimator_trainer.train_step(
                train_input, scores, est_model, est_optimizer, est_criterion,
                clip_norm=args.clip_norm, debug=args.debug)

            # periodic progress log
            if step % args.log_interval == 0:
                trainer_utils.log_train_info(epoch_idx, step, epoch_loss, batch_count)

            if debug == True:  # noqa: E712
                return

            batch = trainset_reader.next()

        # the epoch is over: the number of batches is now known
        batch_count = step
        trainer_utils.log_train_info(epoch_idx, step, epoch_loss, batch_count)
        logger.info("epoch {} completed.".format(epoch_idx))
        epoch_loss = 0

        # ---------- validation ----------
        valid_loss, valid_metrics = estimator_trainer.run_validation(
            est_model, validset_reader, vocab, est_criterion,
            metrics=args.metrics, debug=debug)

        if best_valid_loss is None or best_valid_loss > valid_loss:
            best_epoch_idx = epoch_idx
            best_valid_loss = valid_loss
            is_best = True
            stale_epochs = 0
        else:
            is_best = False
            stale_epochs += 1

        logger.info('epoch {0} validation \t\t| average {1} loss/batch = {2:.4f} (best {3:.4f} @ epoch {4})'.format(epoch_idx, args.loss, valid_loss, best_valid_loss, best_epoch_idx))
        if valid_metrics:
            logger.info('epoch {0} validation \t\t| '.format(epoch_idx) + _metrics_line(valid_metrics))

        # checkpoint contents are captured before the test sets are evaluated
        state = {
            'epoch': epoch_idx,
            'args': args,
            'state_dict': est_model.state_dict(),
            'best_valid_loss': best_valid_loss,
            'best_epoch_idx': best_epoch_idx,
            'optimizer': est_optimizer.state_dict(),
        }

        # ---------- testing ----------
        if testset_readers:
            for testset_reader in testset_readers:
                test_loss, test_metrics = estimator_trainer.run_validation(
                    est_model, testset_reader, vocab, est_criterion,
                    metrics=args.metrics, debug=debug)
                test_name = testset_reader.source_dataset_path
                logger.info('epoch {0} testing on {1} \t\t| average {2} loss/batch = {3:.4f}'.format(epoch_idx, test_name, args.loss, test_loss))
                if test_metrics:
                    logger.info('epoch {0} testing on {1} \t\t| '.format(epoch_idx, test_name) + _metrics_line(test_metrics))

        # ---------- saving ----------
        est_model_path = args.output_dir + '/est_model.epoch' + str(epoch_idx) + '.pt'
        est_best_model_path = args.output_dir + '/est_model.best.pt'
        logger.info("saving model...")
        trainer_utils.save_checkpoint(state, args.save_after_epochs, is_best,
                                      args.no_save_best, est_model_path,
                                      est_best_model_path)

        # ---------- early stopping ----------
        if stale_epochs >= args.patience:
            logger.info("early stopping at epoch {} (patience param: {})".format(epoch_idx, args.patience))
            logger.info("training complete.")
            break