Example #1
import logging
import random

import torch
import torch.nn as nn
from torch.autograd import Variable

# module-level logger (the original setup is not shown; assumed here)
logger = logging.getLogger(__name__)

# project-local helpers used below (their import paths are not part of this
# snippet): trainer_utils, io_utils, trainer, estimator_trainer


def test(est_model, est_args, args, test_samples, vocab, test_scores=None):
    est_model.eval()
    sample_idx = 0
    total_loss_value = 0
    est_criterion = trainer_utils.set_criterion(est_args.loss)
    out_scores = []
    for sample in test_samples:
        sample_as_batch = [sample]
        pred_input = io_utils.create_predictor_input(sample_as_batch, vocab)

        # extract source/target sentence tokens and their masks from the input
        source = pred_input[0]
        target = pred_input[1]
        source_mask = pred_input[2]
        target_mask = pred_input[3]

        # wrap as autograd Variables; volatile=True is the legacy (pre-0.4
        # PyTorch) way of disabling gradient tracking during inference
        source_input = Variable(torch.LongTensor(source), volatile=True).cuda()
        source_mask_input = Variable(torch.LongTensor(source_mask),
                                     volatile=True).cuda()
        target_ref = Variable(torch.LongTensor(target), volatile=True).cuda()
        target_ref_mask = Variable(torch.LongTensor(target_mask),
                                   volatile=True).cuda()
        target_length = target_ref.size()[0]

        model_input = (source_input, source_mask_input, target_ref,
                       target_ref_mask)
        est_score, log_probs = est_model(model_input)

        out_scores.append(est_score.data[0][0])  # only one element in output
        if test_scores:
            scores_ref = Variable(torch.FloatTensor([test_scores[sample_idx]
                                                     ])).cuda()
            est_loss = est_criterion(est_score, scores_ref)
            total_loss_value += (est_loss.data[0])

        sample_idx += 1
        if args.debug:
            # debug mode: stop after scoring a single sample
            return 0.0, 0.0
    assert sample_idx == len(test_samples), \
        "number of scored samples does not match the size of the test set"
    if test_scores:
        avg_loss = total_loss_value / len(test_samples)
    else:
        avg_loss = None
    return out_scores, avg_loss
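
# The trainer_utils helpers are not shown in this snippet. Below is a plausible
# sketch of set_criterion, consistent with how it is called above (it maps a
# loss name to a torch.nn criterion); the supported names are an assumption,
# not part of the original code.
def set_criterion(loss_name):
    criteria = {'mse': nn.MSELoss(), 'mae': nn.L1Loss(),
                'smooth_l1': nn.SmoothL1Loss()}
    return criteria[loss_name.lower()]
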
def train(model, args, trainset_reader, vocab, validset_reader=None):
    """Training loop for the predictor model."""

    debug = args.debug

    # for logging
    total_loss_value = 0

    # set up the optimizer
    optimizer = trainer_utils.set_optimizer(args.optimizer)(
        model.parameters(), lr=args.learning_rate)

    # loss function: cross-entropy over target tokens, ignoring index 0 (padding)
    criterion = nn.CrossEntropyLoss(ignore_index=0)

    trainset_reader.reset()
    num_batches = None
    best_valid_loss = None
    best_model = None
    is_best = False
    for epoch_idx in range(1, args.num_epochs + 1):

        # shuffling trainset
        logger.info("shuffling batches...")
        random.seed(args.seed + (epoch_idx - 1))
        if trainset_reader.shuffle_batches:
            trainset_reader.shuffle()

        # initializing minibatch
        minibatch_idx = 0
        minibatch = trainset_reader.next()

        while minibatch:
            minibatch_idx += 1
            train_input = io_utils.create_predictor_input(minibatch, vocab)
            loss_value = trainer.train_step(train_input,
                                            model,
                                            optimizer,
                                            criterion,
                                            clip_norm=args.clip_norm,
                                            debug=debug)

            # calculating total loss for logging (per epoch)
            total_loss_value += loss_value

            # logging after set interval
            if minibatch_idx % args.log_interval == 0:
                trainer_utils.log_train_info(epoch_idx, minibatch_idx,
                                             total_loss_value, num_batches)

            if debug:
                # debug mode: run a single training step and stop
                return

            # read next batch
            minibatch = trainset_reader.next()

        num_batches = minibatch_idx
        trainer_utils.log_train_info(epoch_idx, minibatch_idx,
                                     total_loss_value, num_batches)

        logger.info("epoch {} completed.".format(epoch_idx))
        total_loss_value = 0

        # validation
        if validset_reader:
            valid_loss = trainer.run_validation(model,
                                                validset_reader,
                                                vocab,
                                                debug=debug)
            is_best = False
            if best_valid_loss is None or best_valid_loss > valid_loss:
                best_epoch_idx = epoch_idx
                best_valid_loss = valid_loss
                is_best = True

            logger.info(
                'validation: average loss per batch = %.4f (best %.4f @ epoch %d)'
                % (valid_loss, best_valid_loss, best_epoch_idx))

            state = {
                'epoch': epoch_idx,
                'vocab': vocab,
                'args': args,
                'state_dict': model.state_dict(),
                'best_valid_loss': best_valid_loss,
                'best_epoch_idx': best_epoch_idx,
                'optimizer': optimizer.state_dict(),
            }
            model_path = args.output_dir + '/model.epoch' + str(epoch_idx) + '.pt'
            best_model_path = args.output_dir + '/model.best.pt'
            trainer_utils.save_checkpoint(state,
                                          args.save_after_epochs,
                                          is_best,
                                          model_path=model_path,
                                          best_model_path=best_model_path)
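
# trainer_utils.set_optimizer is not shown either. A plausible sketch,
# consistent with the call sites above: it returns an optimizer *class*,
# which is then instantiated with the model parameters and learning rate.
# The set of supported names is an assumption.
import torch.optim as optim

def set_optimizer(name):
    optimizers = {'sgd': optim.SGD, 'adam': optim.Adam,
                  'adagrad': optim.Adagrad, 'adadelta': optim.Adadelta}
    return optimizers[name.lower()]
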
def train(est_model, pred_model, args, trainset_reader, vocab,
          validset_reader, testset_readers=None):
    """Training loop for the estimator model, with per-epoch validation,
    optional test-set evaluation and early stopping."""
    debug = args.debug

    # for logging
    total_loss_value = 0

    # set up the optimizer (only over parameters that require gradients)
    est_optimizer = trainer_utils.set_optimizer(args.optimizer)(
        filter(lambda p: p.requires_grad, est_model.parameters()),
        lr=args.learning_rate, weight_decay=args.weight_decay)

    # set up the loss function
    est_criterion = trainer_utils.set_criterion(args.loss)

    trainset_reader.reset()
    num_batches = None
    best_valid_loss = None
    best_model = None

    patience = 0
    for epoch_idx in range(1, args.num_epochs + 1):
        # shuffling trainset
        logger.info("shuffling batches...")
        random.seed(args.seed + (epoch_idx - 1))
        if trainset_reader.shuffle_batches:
            trainset_reader.shuffle()

        # initializing minibatch
        minibatch_idx = 0
        minibatch = trainset_reader.next()

        while minibatch:
            minibatch_idx += 1

            # split the minibatch into predictor input and estimator target scores
            pred_minibatch = [(src, hyp) for src, hyp, score in minibatch]
            scores = [score for src, hyp, score in minibatch]

            # create input as (source, hypothesis) pairs and their masks, indexed with the vocab
            train_input = io_utils.create_predictor_input(pred_minibatch, vocab)

            # perform a step of training
            loss_value = estimator_trainer.train_step(train_input, scores,
                                                      est_model, est_optimizer,
                                                      est_criterion,
                                                      clip_norm=args.clip_norm,
                                                      debug=args.debug)

            # calculating total loss for logging (per epoch)
            total_loss_value += loss_value

            # logging after set interval
            if minibatch_idx % args.log_interval == 0:
                trainer_utils.log_train_info(epoch_idx, minibatch_idx,
                                             total_loss_value, num_batches)
            if debug:
                # debug mode: run a single training step and stop
                return

            # read next batch
            minibatch = trainset_reader.next()

        # find total number of batches
        num_batches = minibatch_idx

        # print the training log
        trainer_utils.log_train_info(epoch_idx, minibatch_idx,
                                     total_loss_value, num_batches)

        # completing one epoch
        logger.info("epoch {} completed.".format(epoch_idx))
        total_loss_value = 0

        #################
        # validation
        #################
        valid_loss, metric_scores = estimator_trainer.run_validation(
            est_model, validset_reader, vocab, est_criterion,
            metrics=args.metrics, debug=debug)

        is_best = False
        patience += 1
        if best_valid_loss is None or best_valid_loss > valid_loss:
            best_epoch_idx = epoch_idx
            best_valid_loss = valid_loss
            is_best = True
            patience = 0


        logger.info('epoch {0} validation \t\t| average {1} loss/batch = {2:.4f} '
                    '(best {3:.4f} @ epoch {4})'.format(epoch_idx, args.loss,
                                                        valid_loss,
                                                        best_valid_loss,
                                                        best_epoch_idx))
        if metric_scores:
            logger.info('epoch {0} validation \t\t| '.format(epoch_idx) +
                        ', '.join(["{0}={1:.4f}".format(metric, score)
                                   for metric, score in metric_scores.items()]))

        state = {
            'epoch': epoch_idx,
            'args': args,
            'state_dict': est_model.state_dict(),
            'best_valid_loss': best_valid_loss,
            'best_epoch_idx': best_epoch_idx,
            'optimizer': est_optimizer.state_dict(),
        }

        ##############
        # testing
        ##############
        if testset_readers:
            for testset_reader in testset_readers:
                test_loss, metric_scores = estimator_trainer.run_validation(
                    est_model, testset_reader, vocab, est_criterion,
                    metrics=args.metrics, debug=debug)
                logger.info('epoch {0} testing on {1} \t\t| average {2} loss/batch = {3:.4f}'.format(
                    epoch_idx, testset_reader.source_dataset_path, args.loss, test_loss))
                if metric_scores:
                    logger.info('epoch {0} testing on {1} \t\t| '.format(
                        epoch_idx, testset_reader.source_dataset_path) +
                        ', '.join(["{0}={1:.4f}".format(metric, score)
                                   for metric, score in metric_scores.items()]))

        # save the model checkpoint
        est_model_path = args.output_dir + '/est_model.epoch' + str(epoch_idx) + '.pt'
        est_best_model_path = args.output_dir + '/est_model.best.pt'
        logger.info("saving model...")
        trainer_utils.save_checkpoint(state, args.save_after_epochs, is_best,
                                      args.no_save_best, est_model_path,
                                      est_best_model_path)

        if patience >= args.patience:
            logger.info("early stopping at epoch {} (patience param: {})".format(
                epoch_idx, args.patience))
            logger.info("training complete.")
            break
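
# trainer_utils.save_checkpoint is also not shown on this page. Below is a
# plausible sketch consistent with both call sites above (keyword arguments in
# the first train(), fully positional in the second); the periodic-save
# meaning of save_after_epochs is an assumption.
def save_checkpoint(state, save_after_epochs, is_best, no_save_best=False,
                    model_path=None, best_model_path=None):
    epoch = state['epoch']
    if save_after_epochs and epoch % save_after_epochs == 0:
        torch.save(state, model_path)           # periodic per-epoch snapshot
    if is_best and not no_save_best:
        torch.save(state, best_model_path)      # best validation loss so far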