예제 #1
0
def train_model(model, fields, optim, data_type, model_opt, train_part):
    """Run the full training loop: train and validate each epoch, log
    statistics, step the learning rate, and drop checkpoints.

    Relies on module-level ``opt`` (command-line options) and helpers such
    as ``make_loss_compute``, ``make_dataset_iter``, ``lazily_load_dataset``
    and ``report_func``.
    """
    train_loss = make_loss_compute(model, fields["tgt"].vocab, opt)
    valid_loss = make_loss_compute(model,
                                   fields["tgt"].vocab,
                                   opt,
                                   train=False)

    trunc_size = opt.truncated_decoder  # Badly named...
    shard_size = opt.max_generator_batches
    norm_method = opt.normalization
    grad_accum_count = opt.accum_count

    trainer = onmt.Trainer(model, train_loss, valid_loss, optim, trunc_size,
                           shard_size, data_type, norm_method,
                           grad_accum_count)

    print('\nStart training...')
    print(' * number of epochs: %d, starting from Epoch %d' %
          (opt.epochs + 1 - opt.start_epoch, opt.start_epoch))
    print(' * batch size: %d' % opt.batch_size)

    # NOTE(review): the loop intentionally starts at epoch 1 rather than
    # opt.start_epoch (see commented-out line), so the banner above may
    # disagree with the actual starting epoch — confirm this is intended.
    #for epoch in range(opt.start_epoch, opt.epochs + 1):
    for epoch in range(1, opt.epochs + 1):
        print(f"Start to train on {epoch}/{opt.epochs}")

        # 1. Train for one epoch on the training set.
        train_iter = make_dataset_iter(lazily_load_dataset("train"), fields,
                                       opt)
        train_stats = trainer.train(train_iter, epoch, report_func, train_part,
                                    model_opt, fields)
        print('Train perplexity: %g' % train_stats.ppl())
        print('Train accuracy: %g' % train_stats.accuracy())

        # 2. Validate on the validation set.  Drop the reference to the
        # train iterator first so its lazily loaded shards can be freed.
        train_iter = None
        valid_iter = make_dataset_iter(lazily_load_dataset("valid"),
                                       fields,
                                       opt,
                                       is_train=False)
        valid_stats = trainer.validate(valid_iter, train_part)
        print(f"Epoch: {epoch}")
        print('Validation perplexity: %g' % valid_stats.ppl())
        print('Validation accuracy: %g' % valid_stats.accuracy())

        # 3. Log to remote server.
        if opt.exp_host:
            train_stats.log("train", experiment, optim.lr)
            valid_stats.log("valid", experiment, optim.lr)
        if opt.tensorboard:
            train_stats.log_tensorboard("train", writer, optim.lr, epoch)
            # BUG FIX: the "valid" curve was previously logged from
            # train_stats, duplicating the train curve on tensorboard.
            valid_stats.log_tensorboard("valid", writer, optim.lr, epoch)

        # 4. Update the learning rate
        trainer.epoch_step(valid_stats.ppl(), epoch)

        # 5. Drop a checkpoint if needed.
        if epoch >= opt.start_checkpoint_at:
            trainer.drop_checkpoint(model_opt, epoch, fields, valid_stats)
예제 #2
0
def train_model(model, train_data, valid_data, fields, optim, opt, num_runs):
    """Train for exactly one epoch and push ppl/accuracy to the logger.

    ``num_runs`` distinguishes repeated invocations both in the summary tag
    and in the global step offset.
    """
    train_iter = make_train_data_iter(train_data, opt)
    valid_iter = make_valid_data_iter(valid_data, opt)

    tgt_vocab = fields["tgt"].vocab
    train_loss = make_loss_compute(model, tgt_vocab, train_data, opt)
    valid_loss = make_loss_compute(model, tgt_vocab, valid_data, opt)

    trainer = onmt.Trainer(model, train_iter, valid_iter,
                           train_loss, valid_loss, optim,
                           opt.truncated_decoder,      # truncation length ("badly named" upstream)
                           opt.max_generator_batches)  # loss shard size
    # pdb.set_trace()

    for epoch in range(1):
        print('')

        # Train one epoch, report to stdout and record scalar summaries.
        train_stats = trainer.train(epoch, report_func)
        print('Train perplexity: %g' % train_stats.ppl())
        print('Train accuracy: %g' % train_stats.accuracy())
        step = num_runs * (opt.epochs + 1) + epoch
        logger.scalar_summary('train_%s_ppl' % num_runs, train_stats.ppl(), step)
        logger.scalar_summary('train_%s_acc' % num_runs, train_stats.accuracy(), step)
예제 #3
0
def train_model(model, train_data, valid_data, fields, optim):
    """Epoch loop that checkpoints only on best-ever accuracy or best-ever
    perplexity, tracking both running bests across epochs."""
    min_ppl = float('inf')
    max_accuracy = -1

    train_iter = make_train_data_iter(train_data, opt)
    valid_iter = make_valid_data_iter(valid_data, opt)

    tgt_vocab = fields["tgt"].vocab
    train_loss = make_loss_compute(model, tgt_vocab, train_data, opt)
    valid_loss = make_loss_compute(model, tgt_vocab, valid_data, opt)

    trainer = onmt.Trainer(model, train_iter, valid_iter,
                           train_loss, valid_loss, optim,
                           opt.truncated_decoder,      # truncation length ("badly named" upstream)
                           opt.max_generator_batches)  # loss shard size

    for epoch in range(opt.start_epoch, opt.epochs + 1):
        print('')

        # 1. Train for one epoch on the training set.
        train_stats = trainer.train(epoch, report_func)
        print('Train perplexity: %g' % train_stats.ppl())
        print('Train accuracy: %g' % train_stats.accuracy())

        # 2. Validate on the validation set.
        valid_stats = trainer.validate()
        print('Validation perplexity: %g' % valid_stats.ppl())
        print('Validation accuracy: %g' % valid_stats.accuracy())

        # 3. Log to remote server.
        if opt.exp_host:
            train_stats.log("train", experiment, optim.lr)
            valid_stats.log("valid", experiment, optim.lr)

        # 4. Update the learning rate
        trainer.epoch_step(valid_stats.ppl(), epoch)

        # 5. Checkpoint when either metric sets a new best.  Accuracy takes
        # precedence in the reported message (mirrors the original if/elif).
        if epoch >= opt.start_checkpoint_at:
            acc = valid_stats.accuracy()
            ppl = valid_stats.ppl()
            improved_acc = acc > max_accuracy
            improved_ppl = ppl < min_ppl
            if improved_acc or improved_ppl:
                min_ppl = min(ppl, min_ppl)
                max_accuracy = max(acc, max_accuracy)
                trainer.drop_checkpoint(opt, epoch, fields, valid_stats)
                if improved_acc:
                    print('Save model according to biggest-ever accuracy: acc: {0}, ppl: {1}'.format(max_accuracy, min_ppl))
                else:
                    print('Save model according to lowest-ever ppl: acc: {0}, ppl: {1}'.format(max_accuracy, min_ppl))
예제 #4
0
def train_model(model1, model2, train_data, valid_data, fields1, fields2,  optim1, optim2):
    """Train two models side by side through a single shared onmt.Trainer.

    Each epoch: train both, validate both, optionally log to the remote
    experiment server, step both learning rates, and maybe checkpoint.
    """
    tgt_vocab = fields1["tgt"].vocab

    train_iter = make_train_data_iter(train_data, opt)
    valid_iter = make_valid_data_iter(valid_data, opt)

    # One loss pair per phase; the helper returns losses for both models.
    train_loss1, train_loss2 = make_loss_compute(model1, model2, tgt_vocab,
                                                 train_data, opt)
    valid_loss1, valid_loss2 = make_loss_compute(model1, model2, tgt_vocab,
                                                 valid_data, opt)

    trainer = onmt.Trainer(model1, model2, train_iter, valid_iter,
                           train_loss1, valid_loss1, train_loss2, valid_loss2,
                           optim1, optim2,
                           opt.truncated_decoder,      # truncation length ("badly named" upstream)
                           opt.max_generator_batches)  # loss shard size

    for epoch in range(opt.start_epoch, opt.epochs + 1):
        print('')

        # Train both models for one epoch and report their statistics.
        train_stats1, train_stats2 = trainer.train(epoch, report_func)
        print('Train perplexity 1: %g' % train_stats1.ppl())
        print('Train accuracy 1: %g' % train_stats1.accuracy())

        print('Train perplexity 2: %g' % train_stats2.ppl())
        print('Train accuracy 2: %g' % train_stats2.accuracy())

        # Validate both models.
        valid_stats1, valid_stats2 = trainer.validate()
        print('Validation perplexity 1: %g' % valid_stats1.ppl())
        print('Validation accuracy 1: %g' % valid_stats1.accuracy())

        print('Validation perplexity 2: %g' % valid_stats2.ppl())
        print('Validation accuracy 2: %g' % valid_stats2.accuracy())

        # Optional remote experiment logging, one line per (phase, model).
        if opt.exp_host:
            train_stats1.log("train", experiment, optim1.lr)
            valid_stats1.log("valid", experiment, optim1.lr)
            train_stats2.log("train", experiment, optim2.lr)
            valid_stats2.log("valid", experiment, optim2.lr)

        # LR schedule: one step per model's validation perplexity.
        trainer.epoch_step(valid_stats1.ppl(), epoch)
        trainer.epoch_step(valid_stats2.ppl(), epoch)

        # Checkpoint once the configured epoch is reached.
        if epoch >= opt.start_checkpoint_at:
            trainer.drop_checkpoint(opt, epoch, fields1, fields2, valid_stats1, valid_stats2)
예제 #5
0
def train_model(model, train_dataset, valid_dataset, fields, optim, model_opt):
    """Standard epoch loop: train, validate, log, decay LR, checkpoint.

    Progress messages mark each construction stage before training begins.
    """
    print("making iters")
    train_iter = make_train_data_iter(train_dataset, opt)
    valid_iter = make_valid_data_iter(valid_dataset, opt)

    print("making losses")
    tgt_vocab = fields["tgt"].vocab
    train_loss = make_loss_compute(model, tgt_vocab, train_dataset, opt)
    valid_loss = make_loss_compute(model, tgt_vocab, valid_dataset, opt)

    print("making trainer")
    trainer = onmt.Trainer(model, train_iter, valid_iter, train_loss,
                           valid_loss, optim,
                           opt.truncated_decoder,      # truncation length ("badly named" upstream)
                           opt.max_generator_batches,  # loss shard size
                           train_dataset.data_type)
    print("made trainer")

    for epoch in range(opt.start_epoch, opt.epochs + 1):
        print('')

        # 1. Train for one epoch on the training set.
        train_stats = trainer.train(epoch, report_func)
        print('Train perplexity: %g' % train_stats.ppl())
        print('Train accuracy: %g' % train_stats.accuracy())

        # 2. Validate on the validation set.
        valid_stats = trainer.validate()
        print('Validation perplexity: %g' % valid_stats.ppl())
        print('Validation accuracy: %g' % valid_stats.accuracy())

        # 3. Optional remote logging.
        if opt.exp_host:
            train_stats.log("train", experiment, optim.lr)
            valid_stats.log("valid", experiment, optim.lr)

        # 4. LR schedule driven by validation perplexity.
        trainer.epoch_step(valid_stats.ppl(), epoch)

        # 5. Checkpoint once the configured epoch is reached.
        if epoch >= opt.start_checkpoint_at:
            trainer.drop_checkpoint(model_opt, epoch, fields, valid_stats)
예제 #6
0
def create_trainer(
    encoder_vocab,
    decoder_vocab,
    device,
    decoder_hidden_size=256,
    embedding_padding_idx=0,
    report_every=50,
):
    """Assemble a model and an onmt.Trainer wired with SGD, summed NLL
    loss, and a report manager.

    Returns a ``(model, trainer)`` tuple.
    """
    # Build the model and move it to the target device.
    model = create_model(
        encoder_vocab=encoder_vocab,
        decoder_vocab=decoder_vocab,
        batch_size=64,
    )
    model.to(device)

    # Plain SGD at lr=1 with a constant (identity) decay schedule.
    optimizer = onmt.utils.Optimizer(
        optimizer=torch.optim.SGD(model.parameters(), lr=1),
        learning_rate=1,
        learning_rate_decay_fn=lambda n: 1,
    )

    # Attach a generator head mapping decoder hidden states to vocab
    # log-probabilities.  NOTE(review): it is created *after* the SGD
    # optimizer captured model.parameters(), so the generator's weights are
    # not in the optimizer's parameter list — confirm this is intended.
    projection = torch.nn.Linear(decoder_hidden_size, len(decoder_vocab))
    log_softmax = torch.nn.LogSoftmax(dim=-1)
    model.generator = torch.nn.Sequential(projection, log_softmax).to(device)

    # Summed NLL loss, ignoring the embedding padding index.
    nll = torch.nn.NLLLoss(ignore_index=embedding_padding_idx,
                           reduction='sum')
    loss = onmt.utils.loss.NMTLossCompute(criterion=nll,
                                          generator=model.generator)

    # Console reporting every `report_every` steps, no tensorboard.
    report_manager = onmt.utils.ReportMgr(report_every=report_every,
                                          start_time=None,
                                          tensorboard_writer=None)

    trainer = onmt.Trainer(
        model=model,
        optim=optimizer,
        train_loss=loss,
        valid_loss=loss,
        report_manager=report_manager,
    )
    return model, trainer
예제 #7
0
def train_model(auto_models, valid_model, train_data, valid_data, fields_list,
                valid_fields, optims, discrim_models, discrim_optims, labels):
    """Adversarial multi-trainer loop.

    Builds one plain Trainer used only for validation (``trainers[0]``),
    one AdvTrainer per (autoencoder, discriminator) pair, and one
    DiscrimTrainer per discriminator, then alternates training them each
    epoch.  Relies on module-level ``opt`` and helpers (``make_*_iter``,
    ``make_loss_compute``, ``report_func``, ``discrim_report_func``).
    """

    #     train_model(models, valid_model, train, valid, fields, fields_valid, optims,
    #                 discrim_models, discrim_optims, advers_optims, labels)

    trunc_size = opt.truncated_decoder  # Badly named...
    shard_size = opt.max_generator_batches

    # Validation-only trainer; it is never trained below.
    valid_iter = make_valid_data_iter(valid_data, opt)
    valid_loss = make_loss_compute(valid_model, valid_fields["tgt"].vocab,
                                   valid_data, opt)
    trainers = []
    trainers.append(
        onmt.Trainer(valid_model, valid_iter, valid_iter, valid_loss,
                     valid_loss, optims[0], trunc_size, shard_size))

    # One adversarial trainer per autoencoder/discriminator pair.
    for model, discrim_model, optim, label, train, fields in zip(
            auto_models, discrim_models, optims, labels, train_data,
            fields_list):
        train_iter = make_train_data_iter(train, opt)
        train_loss = make_loss_compute(model, fields["tgt"].vocab, train, opt)
        trainers.append(
            onmt.AdvTrainer(model, discrim_model, train_iter, valid_iter,
                            train_loss, valid_loss, optim, label, trunc_size,
                            shard_size))

    # One discriminator trainer per discriminator model.
    discrim_trainers = []
    for model, discrim_optim, data in zip(discrim_models, discrim_optims,
                                          train_data):
        train_iter = make_train_data_iter(data, opt)
        discrim_trainers.append(
            onmt.DiscrimTrainer(model, train_iter, discrim_optim, shard_size))
    #for model, optim, data in zip(discrim_models, advers_optims, advers_data):
    #    train_iter = make_train_data_iter(data, opt)
    #    discrim_trainers.append(onmt.DiscrimTrainer(model, train_iter, optim, shard_size))

    for epoch in range(opt.start_epoch, opt.epochs + 1):
        print('')

        # Phase 1a: train each discriminator for one epoch.
        for label, trainer in zip(labels, discrim_trainers):
            # 1. Train for one epoch on the training set.
            train_stats = trainer.train(epoch, label, discrim_report_func)
            print('Train loss: %g' % train_stats.loss)
            #print('Train accuracy: %g' % train_stats.accuracy())

            # NOTE(review): ``optim`` here is the leftover loop variable
            # from the trainer-building zip above, not a per-trainer
            # optimizer — confirm this is intended.
            if opt.exp_host:
                train_stats.log("train", experiment, optim.lr)

        # Phase 1b: train each adversarial trainer (skip trainers[0]).
        for trainer in trainers[1:]:
            # 1. Train for one epoch on the training set.
            train_stats = trainer.train(epoch, report_func)
            print('Train perplexity: %g' % train_stats.ppl())
            print('Train accuracy: %g' % train_stats.accuracy())

            if opt.exp_host:
                train_stats.log("train", experiment, optim.lr)

        # 2. Validate on the validation set.
        valid_stats = trainers[0].validate()
        print('Validation perplexity: %g' % valid_stats.ppl())
        print('Validation accuracy: %g' % valid_stats.accuracy())

        # 3. Log to remote server.
        if opt.exp_host:
            valid_stats.log("valid", experiment, optim.lr)
        '''
        for trainer in trainers[1:]:
            # 4. Update the learning rate
            trainer.epoch_step(valid_stats.ppl(), epoch)
            
        for trainer in discrim_trainers:
            # 4. Update the learning rate
            trainer.epoch_step(valid_stats.ppl(), epoch)
        '''

        # 5. Drop a checkpoint if needed.
        # NOTE(review): ``fields`` is the leftover loop variable from the
        # trainer-building zip; ``valid_fields`` may have been intended.
        if epoch >= opt.start_checkpoint_at:
            trainers[0].drop_checkpoint(opt, epoch, fields, valid_stats)
예제 #8
0
    # NOTE(review): this fragment is the body of a function/script whose
    # header lies outside this view; documented in place only.
    args, extra_args = argparser.parse_known_args()
    opt = Configurable(args.config_file, extra_args)

    # Build the adversarial model and its optimizer from the parsed config.
    model = ADVModel(opt)
    optim = onmt.Optim(
            opt.optim, opt.learning_rate, opt.max_grad_norm,
            lr_decay=opt.learning_rate_decay,
            start_decay_at=opt.start_decay_at,
            beta1=opt.adam_beta1,
            beta2=opt.adam_beta2,
            adagrad_accum=opt.adagrad_accumulator_init,
            decay_method=opt.decay_method,
            warmup_steps=opt.warmup_steps,
            model_size=opt.rnn_size)
    optim.set_parameters(model.named_parameters())
    tgt_vocab = Vocab(opt.tgt_vocab)
    # The same loss object serves both training and validation.
    loss_compute = onmt.Loss.NMTLossCompute(model.generator, tgt_vocab).cuda()
    trainer = onmt.Trainer(model, loss_compute, loss_compute, optim)
    train_set = Data_Loader(opt.train_file, opt.batch_size)
    valid_set = Data_Loader(opt.dev_file, opt.batch_size)
    # ``xrange`` indicates this fragment targets Python 2.
    for epoch in xrange(opt.max_epoch):
        train_stats = trainer.train(train_set, epoch, report_func)
        print('Train perplexity: %g' % train_stats.ppl())
        print('Train accuracy: %g' % train_stats.accuracy())

        valid_stats = trainer.validate(valid_set)
        print('Validation perplexity: %g' % valid_stats.ppl())
        print('Validation accuracy: %g' % valid_stats.accuracy())

        # Decay LR on validation ppl, then checkpoint every epoch.
        trainer.epoch_step(valid_stats.ppl(), epoch)
        model.save_checkpoint(epoch, opt)
def train_model(model, fields, optim, data_type, model_opt):
    """Train an affect-aware model: per-epoch train/validate, CSV metric
    logging, learning-rate decay and checkpointing.

    Uses module-level ``opt``; appends one row of validation metrics per
    epoch to ``./data/log/<opt.exp>.csv``.
    """
    # Choose the embedding copy: either a precomputed emotional-words-only
    # tensor loaded from disk, or the decoder's own embedding copy.
    if opt.emotional_words:
        print("Emotional words only")
        embedding_copy = torch.load(opt.emotional_words)
    else:
        embedding_copy = model.decoder.embeddings.embedding_copy
    train_loss = make_loss_compute(model, fields["tgt"].vocab, opt, embedding_copy)
    valid_loss = make_loss_compute(model, fields["tgt"].vocab, opt, embedding_copy)

    trunc_size = opt.truncated_decoder  # Badly named...
    shard_size = opt.max_generator_batches
    norm_method = opt.normalization
    grad_accum_count = opt.accum_count

    trainer = onmt.Trainer(model, train_loss, valid_loss, optim,
                           trunc_size, shard_size, data_type,
                           norm_method, grad_accum_count, opt.affect_bias, fields["tgt"].vocab)

    print('\nStart training...')
    print(' * number of epochs: %d, starting from Epoch %d' %
          (opt.epochs + 1 - opt.start_epoch, opt.start_epoch))
    print(' * batch size: %d' % opt.batch_size)

    # Create a file to save validation metrics
    val_metrics_file = "./data/log/" + opt.exp + ".csv"
    with open(val_metrics_file, "w") as f:
        f.write("loss,perplexity,distinct-1,distinct-2,embed_greedy,embed_avg,embed_extrema,affect_distance,affect_strength\n")

    for epoch in range(opt.start_epoch, opt.epochs + 1):
        print('')

        # 1. Train for one epoch on the training set.
        train_iter = make_dataset_iter(lazily_load_dataset("train"),
                                       fields, opt, repeat=False)
        # The valid iterator repeats — presumably so it can be consumed for
        # mid-epoch evaluation inside trainer.train; TODO confirm.
        valid_iter = make_dataset_iter(lazily_load_dataset("valid"),
                                fields, opt,
                                is_train=False, repeat=True)
        # Run validation every a few iterations px
        train_stats = trainer.train(train_iter, epoch, report_func, valid_iter, opt.evaluate_every, report_evaluation) 
        print('Train perplexity: %g' % train_stats.ppl())
        print('Train accuracy: %g' % train_stats.accuracy())

        # 2. Validate on the validation set.
        valid_stats, num_val_batches = trainer.validate(valid_iter)
        report_evaluation(valid_stats, num_val_batches)

        # Save validation metrics file — per-batch averages for the
        # embedding/affect metrics, aggregate loss/ppl/distinct-n.
        with open(val_metrics_file, "a") as f:
            f.write("{0},{1},{2},{3},{4},{5},{6},{7},{8}\n".format(
                valid_stats.loss/valid_stats.n_words, valid_stats.ppl(), valid_stats.distinct_1(), valid_stats.distinct_2(),
                valid_stats.embed_greedy/num_val_batches, valid_stats.embed_avg/num_val_batches, valid_stats.embed_extrema/num_val_batches, 
                valid_stats.affect_dist/num_val_batches, valid_stats.affect_strength/num_val_batches))


        
        # 3. Log to remote server.
        if opt.exp_host:
            train_stats.log("train", experiment, optim.lr)
            valid_stats.log("valid", experiment, optim.lr)

        # 4. Update the learning rate
        trainer.epoch_step(valid_stats.ppl(), epoch)

        # 5. Drop a checkpoint if needed.
        if epoch >= opt.start_checkpoint_at:
            trainer.drop_checkpoint(model_opt, epoch, fields, valid_stats)
예제 #10
0
def train_model(auto_models, valid_model, train_data, valid_data, fields_list,
                valid_fields, optims, discrim_models, discrim_optims, labels,
                advers_optims):
    """Unsupervised/adversarial training with separate src and tgt
    autoencoder trainers, discriminator and adversarial trainers, and a
    validation-only trainer.

    Relies on module-level ``opt`` and the ``make_*`` helpers.  Fresh
    train iterators are rebuilt before each trainer that consumes them.
    """

    #     train_model(models, valid_model, train, valid, fields, fields_valid, optims,
    #                 discrim_models, discrim_optims, advers_optims, labels)

    trunc_size = opt.truncated_decoder  # Badly named...
    shard_size = opt.max_generator_batches

    # Validation-only trainer; it is never trained below.
    valid_iter = make_valid_data_iter(valid_data, opt)
    valid_loss = make_loss_compute(valid_model, valid_fields["tgt"].vocab,
                                   valid_data, opt)
    valid_trainer = onmt.Trainer(valid_model, valid_iter, valid_iter,
                                 valid_loss, valid_loss, optims[0], trunc_size,
                                 shard_size)

    # Source-side autoencoder trainer (train_data[0] / auto_models[0]).
    src_train_iter = make_train_data_iter(train_data[0], opt)
    src_train_loss = make_loss_compute(auto_models[0],
                                       fields_list[0]["tgt"].vocab,
                                       train_data[0], opt)
    src_trainer = onmt.Trainer(auto_models[0], src_train_iter, valid_iter,
                               src_train_loss, valid_loss, optims[0],
                               trunc_size, shard_size)

    # Target-side autoencoder trainer (train_data[1] / auto_models[1]).
    tgt_train_iter = make_train_data_iter(train_data[1], opt)
    tgt_train_loss = make_loss_compute(auto_models[1],
                                       fields_list[1]["tgt"].vocab,
                                       train_data[1], opt)
    tgt_trainer = onmt.Trainer(auto_models[1], tgt_train_iter, valid_iter,
                               tgt_train_loss, valid_loss, optims[1],
                               trunc_size, shard_size)

    # Fresh iterators/losses for the unsupervised trainer (the earlier
    # iterators may already be partially consumed).
    src_train_iter = make_train_data_iter(train_data[0], opt)
    tgt_train_iter = make_train_data_iter(train_data[1], opt)
    src_train_loss = make_loss_compute(auto_models[0],
                                       fields_list[0]["tgt"].vocab,
                                       train_data[0], opt)
    tgt_train_loss = make_loss_compute(auto_models[1],
                                       fields_list[1]["tgt"].vocab,
                                       train_data[1], opt)
    unsup_trainer = onmt.UnsupTrainer(
        auto_models, [None, None], discrim_models,
        [src_train_iter, tgt_train_iter], valid_iter,
        [src_train_loss, tgt_train_loss], [None, None], valid_loss,
        [optims[0], None], [None, None], [0.9, 0.9], trunc_size, shard_size)

    # Discriminator trainer with the real labels.
    src_train_iter = make_train_data_iter(train_data[0], opt)
    tgt_train_iter = make_train_data_iter(train_data[1], opt)
    discrim_trainer = onmt.DiscrimTrainer(discrim_models,
                                          [src_train_iter, tgt_train_iter],
                                          discrim_optims, labels, shard_size)

    # Adversarial trainer: same discriminators but soft 0.9 labels and the
    # adversarial optimizers.
    src_train_iter = make_train_data_iter(train_data[0], opt)
    tgt_train_iter = make_train_data_iter(train_data[1], opt)
    advers_trainer = onmt.DiscrimTrainer(discrim_models,
                                         [src_train_iter, tgt_train_iter],
                                         advers_optims, [0.9, 0.9], shard_size)
    '''
    for epoch in range(10):
        train_stats = discrim_trainer.train(epoch, discrim_report_func)
        print('Discrim Train loss: %g' % train_stats.loss)
    for epoch in range(10):
        train_stats = discrim_trainer.train(epoch, discrim_report_func)
        print('Discrim Train loss: %g' % train_stats.loss)
        train_stats = advers_trainer.train(epoch, discrim_report_func)
        print('Advers Train loss: %g' % train_stats.loss)
    '''

    # Warm-up phase: autoencoders only.
    for epoch in range(opt.start_epoch, opt.epochs + 1):
        print('')

        # 1. Train for one epoch on the training set.
        train_stats = src_trainer.train(epoch, discrim_report_func)
        print('SRC Train perplexity: %g' % train_stats.ppl())
        print('SRC Train accuracy: %g' % train_stats.accuracy())

        # 1. Train for one epoch on the training set.
        train_stats = tgt_trainer.train(epoch, discrim_report_func)
        print('TGT Train perplexity: %g' % train_stats.ppl())
        print('TGT Train accuracy: %g' % train_stats.accuracy())

        # NOTE(review): ``optim`` is not defined in this function (the
        # parameter is ``optims``); this line would raise NameError unless
        # a module-level ``optim`` exists — confirm.
        if opt.exp_host:
            train_stats.log("train", experiment, optim.lr)

    # Main phase: autoencoder + discriminator + adversarial updates.
    for epoch in range(opt.start_epoch, opt.epochs + 1):
        print('')

        # 1. Train for one epoch on the training set.
        train_stats = src_trainer.train(epoch, discrim_report_func)
        print('SRC Train perplexity: %g' % train_stats.ppl())
        print('SRC Train accuracy: %g' % train_stats.accuracy())

        train_stats = discrim_trainer.train(epoch, discrim_report_func)
        print('Discrim Train loss: %g' % train_stats.loss)

        train_stats = advers_trainer.train(epoch, discrim_report_func)
        print('Advers Train loss: %g' % train_stats.loss)
        '''
        train_stats = unsup_trainer.train(epoch, report_func)
        print('Train perplexity: %g' % train_stats.ppl())
        print('Train accuracy: %g' % train_stats.accuracy())
        '''

        # 2. Validate on the validation set.
        valid_stats = valid_trainer.validate()
        print('Validation perplexity: %g' % valid_stats.ppl())
        print('Validation accuracy: %g' % valid_stats.accuracy())

        # 3. Log to remote server.
        # NOTE(review): same undefined-``optim`` concern as above.
        if opt.exp_host:
            valid_stats.log("valid", experiment, optim.lr)

        # 5. Drop a checkpoint if needed.
        if epoch >= opt.start_checkpoint_at:
            valid_trainer.drop_checkpoint(opt, epoch, valid_fields,
                                          valid_stats)
예제 #11
0
def train_model(model, fields, optim, data_type, model_opt):
    """Epoch loop with early-stopping support.

    Follows the standard train/validate/log/decay/checkpoint cycle, but
    checkpointing depends on ``trainer.early_stop``: with the default
    perplexity criterion (or none), checkpoints drop per epoch; with a
    custom criterion, the latest model and its metric scores are
    overwritten each epoch, and training aborts once early stopping fires.
    """
    train_loss = make_loss_compute(model, fields["tgt"].vocab, opt)
    valid_loss = make_loss_compute(model, fields["tgt"].vocab, opt)

    trunc_size = opt.truncated_decoder  # Badly named...
    shard_size = opt.max_generator_batches
    norm_method = opt.normalization
    grad_accum_count = opt.accum_count

    print("model_opt: %s" % str(model_opt))
    # A shallow copy of fields is handed to the trainer — presumably so it
    # can be mutated without affecting the caller; TODO confirm.
    trainer = onmt.Trainer(model, train_loss, valid_loss, optim, trunc_size,
                           shard_size, data_type, norm_method,
                           grad_accum_count, model_opt, copy.copy(fields))

    print('\nStart training...')
    print(' * number of epochs: %d, starting from Epoch %d' %
          (opt.epochs + 1 - opt.start_epoch, opt.start_epoch))
    print(' * batch size: %d' % opt.batch_size)

    for epoch in range(opt.start_epoch, opt.epochs + 1):
        print('')

        # 0. Validate on the validation set.
        # work-around to make model work :O :O :O
        if epoch == 1:
            valid_iter = make_dataset_iter(lazily_load_dataset("valid"),
                                           fields,
                                           opt,
                                           is_train=False)
            valid_stats = trainer.validate(valid_iter)
            print('Validation perplexity: %g' % valid_stats.ppl())
            print('Validation accuracy: %g' % valid_stats.accuracy())

        # 1. Train for one epoch on the training set.
        train_iter = make_dataset_iter(lazily_load_dataset("train"), fields,
                                       opt)
        train_stats = trainer.train(train_iter, epoch, report_func)
        print('Train perplexity: %g' % train_stats.ppl())
        print('Train accuracy: %g' % train_stats.accuracy())

        # 2. Validate on the validation set.
        valid_iter = make_dataset_iter(lazily_load_dataset("valid"),
                                       fields,
                                       opt,
                                       is_train=False)
        valid_stats = trainer.validate(valid_iter)
        print('Validation perplexity: %g' % valid_stats.ppl())
        print('Validation accuracy: %g' % valid_stats.accuracy())

        # 3. Log to remote server.
        if opt.exp_host:
            train_stats.log("train", experiment, optim.lr)
            valid_stats.log("valid", experiment, optim.lr)

        # 4. Update the learning rate
        trainer.epoch_step(valid_stats.ppl(), epoch)

        # 5. Drop a checkpoint if needed.
        if trainer.early_stop.early_stop_criteria == 'perplexity' or trainer.early_stop.early_stop_criteria is None:
            # no early-stopping
            if epoch >= opt.start_checkpoint_at:
                trainer.drop_checkpoint(model_opt, epoch, fields, valid_stats)
        else:
            # if we are using a non-default early-stopping criteria
            # save model to use for continuing training later on if needed be
            model_name = trainer.drop_checkpoint(
                model_opt,
                epoch,
                fields,
                valid_stats,
                overwrite=opt.overwrite_model_file,
                checkpoint_type='last')
            trainer.drop_metric_scores(model_opt,
                                       epoch,
                                       fields,
                                       valid_stats,
                                       overwrite=True,
                                       checkpoint_type='last')
            print("")

        # Abort the epoch loop as soon as the early-stop signal is raised.
        if trainer.early_stop.signal_early_stopping:
            print("WARNING: Early stopping!")
            break
예제 #12
0
def train_model(model, fields, optim, data_type, model_opt):
    """Epoch training loop that also runs a translation pass every 10 epochs.

    Builds a default ``translate`` option namespace (no CLI args), trains
    and validates each epoch, logs to console/remote/tensorboard, decays
    the learning rate, and on every 10th epoch drops a checkpoint and
    translates the validation source with the just-saved model.  Relies on
    module-level ``opt``, ``logger``, ``writer``, ``experiment`` and the
    ``make_*`` helpers.
    """
    translate_parser = argparse.ArgumentParser(
        description='translate',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    onmt.opts.add_md_help_argument(translate_parser)
    onmt.opts.translate_opts(translate_parser)
    opt_translate = translate_parser.parse_args(args=[])
    opt_translate.replace_unk = False
    opt_translate.verbose = True

    train_loss = make_loss_compute(model, fields["tgt"].vocab, opt)
    valid_loss = make_loss_compute(model,
                                   fields["tgt"].vocab,
                                   opt,
                                   train=False)

    trunc_size = opt.truncated_decoder  # Badly named...
    shard_size = opt.max_generator_batches
    norm_method = opt.normalization
    grad_accum_count = opt.accum_count

    trainer = onmt.Trainer(model, train_loss, valid_loss, optim, trunc_size,
                           shard_size, data_type, norm_method,
                           grad_accum_count)

    logger.info('')
    logger.info('Start training...')
    logger.info(' * number of epochs: %d, starting from Epoch %d' %
                (opt.epochs + 1 - opt.start_epoch, opt.start_epoch))
    logger.info(' * batch size: %d' % opt.batch_size)

    for epoch in range(opt.start_epoch, opt.epochs + 1):
        logger.info('')

        # 1. Train for one epoch on the training set.
        train_iter = make_dataset_iter(lazily_load_dataset("train"), fields,
                                       opt)
        train_stats = trainer.train(train_iter, epoch, report_func)
        logger.info('Train perplexity: %g' % train_stats.ppl())
        logger.info('Train accuracy: %g' % train_stats.accuracy())

        # 2. Validate on the validation set.
        valid_iter = make_dataset_iter(lazily_load_dataset("valid"),
                                       fields,
                                       opt,
                                       is_train=False)
        valid_stats = trainer.validate(valid_iter)
        logger.info('Validation perplexity: %g' % valid_stats.ppl())
        logger.info('Validation accuracy: %g' % valid_stats.accuracy())

        # 3. Log to remote server.
        if opt.exp_host:
            train_stats.log("train", experiment, optim.lr)
            valid_stats.log("valid", experiment, optim.lr)
        if opt.tensorboard:
            train_stats.log_tensorboard("train", writer, optim.lr, epoch)
            # BUG FIX: the "valid" tensorboard curve was previously fed
            # train_stats, duplicating the train curve; use valid_stats.
            valid_stats.log_tensorboard("valid", writer, optim.lr, epoch)

        # 4. Update the learning rate
        decay = trainer.epoch_step(valid_stats.ppl(), epoch)
        if decay:
            logger.info("Decaying learning rate to %g" % trainer.optim.lr)

        # 5. Drop a checkpoint if needed.
        if epoch % 10 == 0:  #epoch >= opt.start_checkpoint_at:
            trainer.drop_checkpoint(model_opt, epoch, fields, valid_stats)

            # Point the translator at the validation files and at the
            # checkpoint that drop_checkpoint just wrote.
            opt_translate.src = 'cache/valid_src_{:s}.txt'.format(
                opt.file_templ)
            opt_translate.tgt = 'cache/valid_eval_refs_{:s}.txt'.format(
                opt.file_templ)
            opt_translate.output = 'result/{:s}/valid_res_{:s}.txt'.format(
                opt.dataset, opt.file_templ)
            opt_translate.model = '%s_acc_%.2f_ppl_%.2f_e%d.pt' % (
                opt.save_model, valid_stats.accuracy(), valid_stats.ppl(),
                epoch)

            check_save_result_path(opt_translate.output)

            translator = make_translator(opt_translate,
                                         report_score=False,
                                         logger=logger)
            translator.calc_sacre_bleu = False
            translator.translate(opt_translate.src_dir, opt_translate.src,
                                 opt_translate.tgt, opt_translate.batch_size,
                                 opt_translate.attn_debug)
예제 #13
0
def train_model(model, fields, optim, data_type, model_opt):
    """Run the epoch-based training loop: train, validate, log, decay LR,
    and checkpoint.

    Args:
        model: the NMT model to train.
        fields: dict of torchtext fields; ``fields["tgt"].vocab`` drives the
            loss, ``fields['src'].vocab.freqs`` drives negative sampling.
        optim: the wrapped optimizer (exposes ``.lr``).
        data_type: dataset type string forwarded to ``onmt.Trainer``.
        model_opt: options saved alongside each checkpoint.

    Relies on module-level ``opt`` (parsed CLI options) and, when enabled,
    ``experiment`` / ``writer`` for remote and tensorboard logging.
    """
    train_loss = make_loss_compute(model, fields["tgt"].vocab, opt)
    valid_loss = make_loss_compute(model,
                                   fields["tgt"].vocab,
                                   opt,
                                   train=False)

    trunc_size = opt.truncated_decoder  # Badly named...
    shard_size = opt.max_generator_batches
    norm_method = opt.normalization
    grad_accum_count = opt.accum_count
    # The unigram prior is only needed when the sense loss is active.
    if opt.sense_loss_lbd > 1e-5:
        vcount = fields['src'].vocab.freqs
        prior = init_unigram_table(vcount)
        print("Done generating negative sampling prior")
    else:
        prior = None

    trainer = onmt.Trainer(model,
                           train_loss,
                           valid_loss,
                           optim,
                           trunc_size,
                           shard_size,
                           data_type,
                           norm_method,
                           grad_accum_count,
                           use_sense=(opt.num_senses > 1),
                           window_size=opt.sense_window_size,
                           tau=opt.tau,
                           scale=opt.scale,
                           num_neg=opt.num_neg,
                           sense_loss_lbd=opt.sense_loss_lbd)

    print('\nStart training...')
    print(' * number of epochs: %d, starting from Epoch %d' %
          (opt.epochs + 1 - opt.start_epoch, opt.start_epoch))
    print(' * batch size: %d' % opt.batch_size)

    for epoch in range(opt.start_epoch, opt.epochs + 1):
        print('')

        # 1. Train for one epoch on the training set.
        train_iter = make_dataset_iter(lazily_load_dataset("train"), fields,
                                       opt)
        train_stats = trainer.train(train_iter,
                                    epoch,
                                    report_func,
                                    neg_prior=prior)
        print('Train perplexity: %g' % train_stats.ppl())
        print('Train accuracy: %g' % train_stats.accuracy())

        # 2. Validate on the validation set.
        valid_iter = make_dataset_iter(lazily_load_dataset("valid"),
                                       fields,
                                       opt,
                                       is_train=False)
        valid_stats = trainer.validate(valid_iter)
        print('Validation perplexity: %g' % valid_stats.ppl())
        print('Validation accuracy: %g' % valid_stats.accuracy())

        # 3. Log to remote server.
        if opt.exp_host:
            train_stats.log("train", experiment, optim.lr)
            valid_stats.log("valid", experiment, optim.lr)
        if opt.tensorboard:
            train_stats.log_tensorboard("train", writer, optim.lr, epoch)
            # Fix: the "valid" curve was previously fed train_stats,
            # so tensorboard showed training metrics twice.
            valid_stats.log_tensorboard("valid", writer, optim.lr, epoch)

        # 4. Update the learning rate
        trainer.epoch_step(valid_stats.ppl(), epoch)

        # 5. Drop a checkpoint if needed.
        if epoch >= opt.start_checkpoint_at:
            trainer.drop_checkpoint(model_opt, epoch, fields, valid_stats)
예제 #14
0
def train_model(model, fields, optim, data_type, model_opt, time_str=''):
    """Train a mirror-loss model with KL annealing, early stopping on the
    validation objective, and pruning of non-best checkpoints at the end.

    Args:
        model: the model to train.
        fields: torchtext fields used to build loss and iterators.
        optim: the wrapped optimizer (exposes ``.lr``).
        data_type: dataset type string forwarded to ``onmt.Trainer``.
        model_opt: options saved alongside each checkpoint.
        time_str: optional suffix used in checkpoint file names.

    Relies on module-level ``opt`` plus ``experiment`` for remote logging.
    """
    train_loss_dict = make_mirror_loss(model, fields, opt)
    valid_loss_dict = make_mirror_loss(model, fields, opt)

    trunc_size = opt.truncated_decoder  # Badly named...
    shard_size = opt.max_generator_batches

    trainer = onmt.Trainer(model, train_loss_dict, valid_loss_dict, optim,
                           trunc_size, shard_size, data_type,
                           opt.normalization, opt.accum_count, opt.back_factor,
                           opt.kl_balance, opt.kl_fix)

    # Per-epoch validation objective (reconstruction + KL), used for
    # checkpoint ranking and early stopping.
    val_value = []
    for epoch in range(opt.start_epoch, opt.epochs + 1):
        # 1. Train for one epoch on the training set.
        train_datasets = lazily_load_dataset("train")
        train_iter = make_dataset_iter(train_datasets, fields, opt)
        print("train iter size: {}".format(len(train_iter)))
        train_stats = trainer.train(train_iter, epoch, report_func)
        print('Train perplexity: %g' % train_stats.ppl())
        print('Train accuracy: %g' % train_stats.accuracy())
        # Release the (lazily loaded) training data before validation.
        del train_datasets, train_iter

        # 2. Validate on the validation set.
        valid_iter = make_dataset_iter(lazily_load_dataset("valid"),
                                       fields,
                                       opt,
                                       is_train=False)
        valid_size = len(valid_iter)  # idiomatic: len(), not __len__()
        with torch.no_grad():
            valid_stats = trainer.validate(valid_iter)
        print('Validation perplexity: %g' % valid_stats.ppl())
        # The mirror loss counts each direction; 0.5 averages them.
        print('Validation loss_total: %g' % (valid_stats.loss * 0.5))
        print('Validation accuracy: %g' % valid_stats.accuracy())
        print('Validation kl: %g' % (valid_stats.loss_kl / valid_size))
        print("current kl_factor: {}".format(str(trainer.kl_knealling)))
        val_value.append(
            (valid_stats.loss * 0.5 + valid_stats.loss_kl) / valid_size)
        del valid_iter

        # 3. Log to remote server.
        if opt.exp_host:
            train_stats.log("train", experiment, optim.lr)
            valid_stats.log("valid", experiment, optim.lr)

        # 4. Update the learning rate
        trainer.epoch_step(valid_stats.ppl(), epoch)

        # 5. Drop a checkpoint once KL annealing has finished
        # (before opt.kl_balance the objective is not comparable).
        if epoch >= opt.kl_balance:
            trainer.drop_checkpoint(model_opt, epoch, fields, valid_stats,
                                    time_str, val_value[-1])

        # Early stop after two consecutive increases of the val objective.
        if len(val_value) > (opt.kl_balance + 3) and \
                val_value[-1] > val_value[-2] > val_value[-3]:
            print("early stop due to the val loss")
            break
    print("**** Training Finished *****")

    # Keep only the best checkpoint on disk.
    assert trainer.best_model['model_name'] in trainer.checkpoint_list
    for cp in trainer.checkpoint_list:
        if cp != trainer.best_model['model_name']:
            print(cp)
            os.remove(cp)
    print("Cleaning redundant checkpoints")

    print("the best model path: {}".format(trainer.best_model['model_name']))
    print("the best model ppl: {}".format(trainer.best_model['model_ppl']))
예제 #15
0
def train_model(model, fields, optim, data_type, model_opt):
    """Epoch-based training loop with LR decay and an LR-floor stop.

    Args:
        model: the NMT model to train.
        fields: torchtext fields; ``fields["tgt"].vocab`` drives the loss.
        optim: the wrapped optimizer (exposes ``.lr``).
        data_type: dataset type string forwarded to ``onmt.Trainer``.
        model_opt: options saved alongside each checkpoint.

    Relies on module-level ``opt``, ``logger`` and, when enabled,
    ``experiment`` / ``writer`` for remote and tensorboard logging.
    """
    train_loss = make_loss_compute(model, fields["tgt"].vocab, opt)
    valid_loss = make_loss_compute(model,
                                   fields["tgt"].vocab,
                                   opt,
                                   train=False)

    trunc_size = opt.truncated_decoder  # Badly named...
    shard_size = opt.max_generator_batches
    norm_method = opt.normalization
    grad_accum_count = opt.accum_count

    trainer = onmt.Trainer(model, train_loss, valid_loss, optim, trunc_size,
                           shard_size, data_type, norm_method,
                           grad_accum_count)

    logger.info('')
    logger.info('Start training...')
    logger.info(' * number of epochs: %d, starting from Epoch %d' %
                (opt.epochs + 1 - opt.start_epoch, opt.start_epoch))
    logger.info(' * batch size: %d' % opt.batch_size)

    for epoch in range(opt.start_epoch, opt.epochs + 1):
        logger.info('The current epoch: %d' % epoch)

        # 1. Train for one epoch on the training set.
        train_iter = make_dataset_iter(lazily_load_dataset("train"), fields,
                                       opt)
        train_stats = trainer.train(train_iter, epoch, opt.max_src_len,
                                    opt.max_conv_len, report_func)
        logger.info('Train perplexity: %g' % train_stats.ppl())
        logger.info('Train accuracy: %g' % train_stats.accuracy())

        # 2. Validate on the validation set.
        valid_iter = make_dataset_iter(lazily_load_dataset("valid"),
                                       fields,
                                       opt,
                                       is_train=False)
        valid_stats = trainer.validate(valid_iter, opt.max_src_len,
                                       opt.max_conv_len)
        logger.info('Validation perplexity: %g' % valid_stats.ppl())
        logger.info('Validation accuracy: %g' % valid_stats.accuracy())

        # 3. Log to remote server.
        if opt.exp_host:
            train_stats.log("train", experiment, optim.lr)
            valid_stats.log("valid", experiment, optim.lr)
        if opt.tensorboard:
            train_stats.log_tensorboard("train", writer, optim.lr, epoch)
            # Fix: the "valid" curve was previously fed train_stats,
            # so tensorboard showed training metrics twice.
            valid_stats.log_tensorboard("valid", writer, optim.lr, epoch)

        # 4. Update the learning rate
        decay = trainer.epoch_step(valid_stats.ppl(), epoch)
        if decay:
            logger.info("Decaying learning rate to %g" % trainer.optim.lr)

        # 5. Drop a checkpoint if needed.
        if epoch >= opt.start_checkpoint_at:
            trainer.drop_checkpoint(model_opt, epoch, fields, valid_stats)

        # Stop once the LR has decayed to (effectively) zero; further
        # epochs would make no meaningful updates.
        if trainer.optim.lr <= 1e-6:
            logger.info(
                "Learning rate %g is below 1e-6 at epoch %d. Stop training!!" %
                (trainer.optim.lr, epoch))
            break
예제 #16
0
def train(enc_rnn_size: int, dec_rnn_size: int, src_word_vec_size: int,
          tgt_word_vec_size: int, dropout: float, learning_rate: float,
          train_steps: int, valid_steps: int, early_stopping_tolerance: int,
          preprocessed_data_path: str, save_model_path: str):
    """Build a 2-layer BiLSTM encoder / 1-layer input-feed LSTM decoder NMT
    model from preprocessed OpenNMT data, train it with early stopping and
    tensorboard reporting, then report exact-match accuracy on the
    validation set via greedy translation.

    Args:
        enc_rnn_size / dec_rnn_size: encoder / decoder hidden sizes.
        src_word_vec_size / tgt_word_vec_size: embedding sizes.
        dropout: dropout rate shared by embeddings, encoder and decoder.
        learning_rate: Adam learning rate.
        train_steps / valid_steps: total steps and validation interval.
        early_stopping_tolerance: patience for the early stopper.
        preprocessed_data_path: prefix of the ``.vocab.pt`` / ``.train.0.pt``
            / ``.valid.0.pt`` files produced by OpenNMT preprocessing.
        save_model_path: prefix for saved checkpoints.
    """
    vocab_fields = torch.load("{}.vocab.pt".format(preprocessed_data_path))

    src_text_field = vocab_fields["src"].base_field
    src_vocab = src_text_field.vocab
    src_padding = src_vocab.stoi[src_text_field.pad_token]

    tgt_text_field = vocab_fields['tgt'].base_field
    tgt_vocab = tgt_text_field.vocab
    tgt_padding = tgt_vocab.stoi[tgt_text_field.pad_token]

    # Specify the core model.

    encoder_embeddings = onmt.modules.Embeddings(src_word_vec_size,
                                                 len(src_vocab),
                                                 word_padding_idx=src_padding,
                                                 dropout=dropout)

    encoder = onmt.encoders.RNNEncoder(hidden_size=enc_rnn_size,
                                       num_layers=2,
                                       rnn_type="LSTM",
                                       bidirectional=True,
                                       embeddings=encoder_embeddings,
                                       dropout=dropout)

    decoder_embeddings = onmt.modules.Embeddings(tgt_word_vec_size,
                                                 len(tgt_vocab),
                                                 word_padding_idx=tgt_padding,
                                                 dropout=dropout)
    decoder = onmt.decoders.decoder.InputFeedRNNDecoder(
        hidden_size=dec_rnn_size,
        num_layers=1,
        bidirectional_encoder=True,
        rnn_type="LSTM",
        embeddings=decoder_embeddings,
        dropout=dropout,
        attn_type=None)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = onmt.models.model.NMTModel(encoder, decoder)
    model.to(device)

    # Specify the tgt word generator and loss computation module
    model.generator = nn.Sequential(nn.Linear(dec_rnn_size, len(tgt_vocab)),
                                    nn.LogSoftmax(dim=-1)).to(device)

    loss = onmt.utils.loss.NMTLossCompute(criterion=nn.NLLLoss(
        ignore_index=tgt_padding, reduction="sum"),
                                          generator=model.generator)

    torch_optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    optim = onmt.utils.optimizers.Optimizer(torch_optimizer,
                                            learning_rate=learning_rate,
                                            max_grad_norm=2)

    # Load some data
    train_data_file = "{}.train.0.pt".format(preprocessed_data_path)
    valid_data_file = "{}.valid.0.pt".format(preprocessed_data_path)
    train_iter = onmt.inputters.inputter.DatasetLazyIter(
        dataset_paths=[train_data_file],
        fields=vocab_fields,
        batch_size=64,
        batch_size_multiple=1,
        batch_size_fn=None,
        device=device,
        is_train=True,
        repeat=True,
        pool_factor=1)

    valid_iter = onmt.inputters.inputter.DatasetLazyIter(
        dataset_paths=[valid_data_file],
        fields=vocab_fields,
        batch_size=128,
        batch_size_multiple=1,
        batch_size_fn=None,
        device=device,
        is_train=False,
        repeat=False,
        pool_factor=1)

    tensorboard = SummaryWriter(flush_secs=5)

    import logging
    import sys
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    report_manager = onmt.utils.ReportMgr(report_every=64,
                                          tensorboard_writer=tensorboard)
    report_manager.start()

    early_stopper = EarlyStopping(early_stopping_tolerance)
    model_saver = ModelSaver(save_model_path, model, None, vocab_fields, optim)

    trainer = onmt.Trainer(model=model,
                           train_loss=loss,
                           valid_loss=loss,
                           optim=optim,
                           report_manager=report_manager,
                           dropout=dropout,
                           model_saver=model_saver,
                           earlystopper=early_stopper)

    print("Starting training...")
    trainer.train(train_iter=train_iter,
                  train_steps=train_steps,
                  valid_iter=valid_iter,
                  valid_steps=valid_steps)

    # Greedy-translate the validation set and score exact sentence matches.
    src_reader = onmt.inputters.str2reader["text"]
    tgt_reader = onmt.inputters.str2reader["text"]
    scorer = onmt.translate.GNMTGlobalScorer(alpha=0.7,
                                             beta=0.,
                                             length_penalty="avg",
                                             coverage_penalty="none")
    gpu = 0 if torch.cuda.is_available() else -1
    translator = onmt.translate.Translator(model=model,
                                           fields=vocab_fields,
                                           src_reader=src_reader,
                                           tgt_reader=tgt_reader,
                                           global_scorer=scorer,
                                           gpu=gpu)
    builder = onmt.translate.TranslationBuilder(
        data=torch.load(valid_data_file), fields=vocab_fields, has_tgt=True)
    pos_matches = count = 0

    for batch in valid_iter:
        trans_batch = translator.translate_batch(batch=batch,
                                                 src_vocabs=[src_vocab],
                                                 attn_debug=False)
        translations = builder.from_batch(trans_batch)
        for trans in translations:
            pred = ' '.join(trans.pred_sents[0])
            gold = ' '.join(trans.gold_sent)
            pos_matches += 1 if pred == gold else 0
            count += 1

    # Guard against an empty validation set (was a ZeroDivisionError).
    print("Acc: ", pos_matches / count if count else 0.0)
예제 #17
0
def train_model(model, fields, optim, data_type, opt_per_pred):
    """Train one model per predicate, validating each per epoch, and every
    10th epoch translate the validation set, evaluate BLEU/ROUGE/coverage,
    and drop per-predicate checkpoints annotated with those scores.

    Args:
        model: dict predicate -> model.
        fields: dict predicate -> torchtext fields.
        optim: dict predicate -> wrapped optimizer.
        data_type: dict predicate -> dataset type string.
        opt_per_pred: dict predicate -> per-predicate option namespace.

    Relies on module-level ``opt`` (incl. ``opt.parser.predicates``),
    ``logger`` and, when enabled, ``experiment`` / ``writer``.
    """
    # Build a fresh translate-option namespace (parsed from empty args)
    # and reuse it for every predicate, mutating per-predicate fields.
    translate_parser = argparse.ArgumentParser(
        description='translate',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    onmt.opts.add_md_help_argument(translate_parser)
    onmt.opts.translate_opts(translate_parser)
    opt_translate = translate_parser.parse_args(args=[])
    opt_translate.replace_unk = False
    opt_translate.verbose = False
    opt_translate.block_ngram_repeat = False
    if opt.gpuid:
        opt_translate.gpu = opt.gpuid[0]

    trunc_size = opt.truncated_decoder  # Badly named...
    shard_size = opt.max_generator_batches
    norm_method = opt.normalization
    grad_accum_count = opt.accum_count

    # One trainer per predicate, each with its own loss/optimizer.
    trainer = {}
    for predicate in opt.parser.predicates:
        train_loss = make_loss_compute(model[predicate],
                                       fields[predicate]["tgt"].vocab,
                                       opt_per_pred[predicate])
        valid_loss = make_loss_compute(model[predicate],
                                       fields[predicate]["tgt"].vocab,
                                       opt_per_pred[predicate],
                                       train=False)
        trainer[predicate] = onmt.Trainer(model[predicate], train_loss,
                                          valid_loss, optim[predicate],
                                          trunc_size, shard_size,
                                          data_type[predicate], norm_method,
                                          grad_accum_count)

    logger.info('')
    logger.info('Start training...')
    logger.info(' * number of epochs: %d, starting from Epoch %d' %
                (opt.epochs + 1 - opt.start_epoch, opt.start_epoch))
    logger.info(' * batch size: %d' % opt.batch_size)

    for epoch in range(opt.start_epoch, opt.epochs + 1):
        logger.info('')

        train_stats = {}
        valid_stats = {}
        for predicate in opt.parser.predicates:
            logger.info('Train predicate: %s' % predicate)
            # 1. Train for one epoch on the training set.
            train_iter = make_dataset_iter(
                lazily_load_dataset("train", opt_per_pred[predicate]),
                fields[predicate], opt_per_pred[predicate])
            train_stats[predicate] = trainer[predicate].train(
                train_iter, epoch, fields[predicate], report_func)
            logger.info('Train perplexity: %g' % train_stats[predicate].ppl())
            logger.info('Train accuracy: %g' %
                        train_stats[predicate].accuracy())

            # 2. Validate on the validation set.
            valid_iter = make_dataset_iter(lazily_load_dataset(
                "valid", opt_per_pred[predicate]),
                                           fields[predicate],
                                           opt_per_pred[predicate],
                                           is_train=False)
            valid_stats[predicate] = trainer[predicate].validate(valid_iter)
            logger.info('Validation perplexity: %g' %
                        valid_stats[predicate].ppl())
            logger.info('Validation accuracy: %g' %
                        valid_stats[predicate].accuracy())

            # 3. Log to remote server.
            if opt_per_pred[predicate].exp_host:
                train_stats[predicate].log("train", experiment,
                                           optim[predicate].lr)
                valid_stats[predicate].log("valid", experiment,
                                           optim[predicate].lr)
            if opt_per_pred[predicate].tensorboard:
                train_stats[predicate].log_tensorboard("train", writer,
                                                       optim[predicate].lr,
                                                       epoch)
                # Fix: the "valid" curve was previously fed train_stats,
                # so tensorboard showed training metrics twice.
                valid_stats[predicate].log_tensorboard("valid", writer,
                                                       optim[predicate].lr,
                                                       epoch)

            # 4. Update the learning rate
            decay = trainer[predicate].epoch_step(valid_stats[predicate].ppl(),
                                                  epoch)
            if decay:
                logger.info("Decaying learning rate to %g" %
                            trainer[predicate].optim.lr)

        # 5. Drop a checkpoint if needed.
        if epoch % 10 == 0:  #epoch >= opt.start_checkpoint_at:
            opt_translates = []
            for predicate in opt.parser.predicates:
                opt_translate.predicate = predicate
                opt_translate.batch_size = opt_per_pred[predicate].batch_size
                opt_translate.src = 'cache/valid_src_{:s}.txt'.format(
                    opt_per_pred[predicate].file_templ)
                opt_translate.tgt = 'cache/valid_eval_refs_{:s}.txt'.format(
                    opt_per_pred[predicate].file_templ)
                opt_translate.output = 'result/{:s}/valid_res_{:s}.txt'.format(
                    opt_per_pred[predicate].dataset,
                    opt_per_pred[predicate].file_templ)

                check_save_result_path(opt_translate.output)

                translator = make_translator(opt_translate,
                                             report_score=False,
                                             logger=logger,
                                             fields=fields[predicate],
                                             model=trainer[predicate].model,
                                             model_opt=opt_per_pred[predicate])
                translator.output_beam = 'result/{:s}/valid_res_beam_{:s}.txt'.format(
                    opt_per_pred[predicate].dataset,
                    opt_per_pred[predicate].file_templ)
                translator.translate(opt_translate.src_dir, opt_translate.src,
                                     opt_translate.tgt,
                                     opt_translate.batch_size,
                                     opt_translate.attn_debug)
                # Snapshot the mutated namespace so each predicate keeps
                # its own paths for evaluation below.
                opt_translates.append(copy(opt_translate))
            corpusBLEU, bleu, rouge, coverage, bleu_per_predicate = evaluate(
                opt_translates)
            for predicate in opt.parser.predicates:
                trainer[predicate].drop_checkpoint(
                    opt_per_pred[predicate], epoch, corpusBLEU, bleu, rouge,
                    coverage, bleu_per_predicate[predicate], fields[predicate],
                    valid_stats[predicate])
예제 #18
0
def train_model(model, fields, optim, model_opt, swap_dict):
    """Epoch-based training loop with in-training validation support.

    Args:
        model: the NMT model to train.
        fields: torchtext fields; ``fields["tgt"].vocab`` drives the loss.
        optim: the wrapped optimizer.
        model_opt: options saved alongside each checkpoint.
        swap_dict: forwarded to ``trainer.train`` (used while training).

    Relies on module-level ``opt``, ``writer``, ``validate_while_training``
    and ``report_func``.
    """
    tgt_vocab = fields["tgt"].vocab
    train_loss = make_loss_compute(model, tgt_vocab, opt)
    valid_loss = make_loss_compute(model, tgt_vocab, opt, train=False)

    trainer = onmt.Trainer(
        model,
        train_loss,
        valid_loss,
        optim,
        opt.truncated_decoder,  # Badly named...
        opt.max_generator_batches,
        opt.normalization,
        opt.accum_count,
        opt.select_model,
    )

    print('\nStart training...')
    print(' * number of epochs: %d, starting from Epoch %d' %
          (opt.epochs + 1 - opt.start_epoch, opt.start_epoch))
    print(' * batch size: %d' % opt.batch_size)

    for epoch in range(opt.start_epoch, opt.epochs + 1):
        print('')

        # 1. Train for one epoch on the training set.
        training_batches = make_dataset_iter(lazily_load_dataset("train"),
                                             fields, opt)
        train_stats = trainer.train(training_batches, epoch, opt, fields,
                                    validate_while_training, writer,
                                    report_func, opt.valid_pt, swap_dict)
        print('Train perplexity: %g' % train_stats.ppl())
        print('Train accuracy: %g' % train_stats.accuracy())

        # 2. Validate on the validation set.
        validation_batches = make_dataset_iter(
            lazily_load_dataset("valid", valid_pt=opt.valid_pt),
            fields,
            opt,
            is_train=False)
        valid_stats = trainer.validate(validation_batches)
        print('Validation perplexity: %g' % valid_stats.ppl())
        print('Validation accuracy: %g' % valid_stats.accuracy())

        # 3. Update the learning rate.
        trainer.epoch_step(valid_stats.ppl(), epoch)

        # 4. Drop a checkpoint at the configured interval.
        past_warmup = epoch >= opt.start_checkpoint_at
        on_interval = epoch % opt.save_interval == 0
        if past_warmup and on_interval:
            trainer.drop_checkpoint(model_opt, epoch, deepcopy(fields),
                                    valid_stats, 0)