Example #1
def train_model(model, fields, optim, data_type, train_attr, valid_attr,
                train_img_feats, valid_img_feats, train_img_mask, valid_img_mask, train_feat_indices,
                model_opt):
    # `opt` (along with `report_func`, `experiment` and `writer` used below) is
    # expected to be defined at module level in the surrounding training script.
    train_loss = make_loss_compute(model, fields["tgt"].vocab, opt)
    valid_loss = make_loss_compute(model, fields["tgt"].vocab, opt,
                                   train=False)

    trunc_size = opt.truncated_decoder  # Badly named...
    shard_size = opt.max_generator_batches
    norm_method = opt.normalization
    grad_accum_count = opt.accum_count
    multimodal_model_type = opt.multimodal_model_type

    trainer = onmt.TrainerMultimodal(model,
                           train_loss, valid_loss,
                           optim, trunc_size, shard_size, data_type,
                           norm_method, grad_accum_count,
                           train_attr, valid_attr,
                           train_img_feats, valid_img_feats,
                           train_img_mask, valid_img_mask,
                           train_feat_indices,
                           multimodal_model_type)

    print('\nStart training...')
    print(' * number of epochs: %d, starting from Epoch %d' %
          (opt.epochs + 1 - opt.start_epoch, opt.start_epoch))
    print(' * batch size: %d' % opt.batch_size)

    for epoch in range(opt.start_epoch, opt.epochs + 1):
        print('')

        # 1. Train for one epoch on the training set.
        train_iter = make_dataset_iter(lazily_load_dataset("train"),
                                       fields, opt)
        train_stats = trainer.train(train_iter, epoch, report_func)
        print('Train perplexity: %g' % train_stats.ppl())
        print('Train accuracy: %g' % train_stats.accuracy())

        # 2. Validate on the validation set.
        valid_iter = make_dataset_iter(lazily_load_dataset("valid"),
                                       fields, opt,
                                       is_train=False)
        valid_stats = trainer.validate(valid_iter)
        print('Validation perplexity: %g' % valid_stats.ppl())
        print('Validation accuracy: %g' % valid_stats.accuracy())

        # 3. Log to remote server.
        if opt.exp_host:
            train_stats.log("train", experiment, optim.lr)
            valid_stats.log("valid", experiment, optim.lr)
        if opt.tensorboard:
            train_stats.log_tensorboard("train", writer, optim.lr, epoch)
            valid_stats.log_tensorboard("valid", writer, optim.lr, epoch)

        # 4. Update the learning rate
        trainer.epoch_step(valid_stats.ppl(), epoch)

        # 5. Drop a checkpoint if needed.
        if epoch >= opt.start_checkpoint_at:
            trainer.drop_checkpoint(model_opt, epoch, fields, valid_stats)
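
Example #1 follows the classic epoch-driven recipe: train for one epoch, validate, decay the learning rate from the validation perplexity, and drop a checkpoint. The toy sketch below reproduces only that control flow in plain PyTorch so it can be run outside of OpenNMT-py; every name in it is illustrative and none of it belongs to the OpenNMT-py or TrainerMultimodal API.

# Self-contained toy sketch of the epoch loop above (illustrative only).
import math

import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(8, 4)                     # stand-in for the real NMT model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
# ReduceLROnPlateau plays the role of trainer.epoch_step(): decay the LR
# whenever the validation perplexity stops improving.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                       factor=0.5, patience=0)
criterion = nn.CrossEntropyLoss()

def run_split(train):
    """One pass over a random toy split; returns its perplexity."""
    model.train(train)
    total_loss, total_tokens = 0.0, 0
    with torch.set_grad_enabled(train):
        for _ in range(10):                 # 10 toy batches
            x = torch.randn(32, 8)
            y = torch.randint(0, 4, (32,))
            loss = criterion(model(x), y)
            if train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            total_loss += loss.item() * y.numel()
            total_tokens += y.numel()
    return math.exp(total_loss / total_tokens)

for epoch in range(1, 4):
    train_ppl = run_split(train=True)       # 1. train for one epoch
    valid_ppl = run_split(train=False)      # 2. validate
    print('Epoch %d: train ppl %.3f, valid ppl %.3f' % (epoch, train_ppl, valid_ppl))
    scheduler.step(valid_ppl)               # 4. update the learning rate
    torch.save({'epoch': epoch, 'model': model.state_dict()},
               'toy_checkpoint_e%d.pt' % epoch)  # 5. drop a checkpoint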
Example #2
def train_model(model, fields, optim, data_type, train_img_feats,
                valid_img_feats, train_img_vecs, valid_img_vecs, model_opt):
    # As in Example #1, `opt` and `report_func` are assumed to be module-level
    # names defined by the surrounding training script.
    train_loss = make_loss_compute(model,
                                   fields["tgt"].vocab,
                                   opt,
                                   training=True)
    valid_loss = make_loss_compute(model,
                                   fields["tgt"].vocab,
                                   opt,
                                   training=False)

    trunc_size = opt.truncated_decoder  # Badly named...
    shard_size = opt.max_generator_batches
    norm_method = opt.normalization
    grad_accum_count = opt.accum_count
    multimodal_model_type = opt.multimodal_model_type

    trainer = onmt.TrainerMultimodal(
        model,
        train_loss,
        valid_loss,
        optim,
        trunc_size,
        shard_size,
        data_type,
        norm_method,
        grad_accum_count,
        train_img_feats,
        valid_img_feats,
        multimodal_model_type=multimodal_model_type,
        train_img_vecs=train_img_vecs,
        valid_img_vecs=valid_img_vecs,
        model_opt=model_opt,
        fields=fields)

    print('\nStart training...')
    print(' * number of epochs: %d, starting from Epoch %d' %
          (opt.epochs + 1 - opt.start_epoch, opt.start_epoch))
    print(' * batch size: %d' % opt.batch_size)

    for epoch in range(opt.start_epoch, opt.epochs + 1):
        print('')

        # 0. Validate on the validation set.
        if epoch == 1:
            valid_iter = make_dataset_iter(lazily_load_dataset("valid"),
                                           fields,
                                           opt,
                                           is_train=False)
            valid_stats = trainer.validate(valid_iter)
            print('Validation perplexity: %g' % valid_stats.ppl())
            print('Validation accuracy: %g' % valid_stats.accuracy())

        # 1. Train for one epoch on the training set.
        train_iter = make_dataset_iter(lazily_load_dataset("train"),
                                       fields,
                                       opt,
                                       is_train=True)
        train_stats = trainer.train(train_iter, epoch, report_func)
        print('Train perplexity: %g' % train_stats.ppl())
        print('Train accuracy: %g' % train_stats.accuracy())

        # 2. Validate on the validation set.
        valid_iter = make_dataset_iter(lazily_load_dataset("valid"),
                                       fields,
                                       opt,
                                       is_train=False)
        valid_stats = trainer.validate(valid_iter)
        print('Validation perplexity: %g' % valid_stats.ppl())
        print('Validation accuracy: %g' % valid_stats.accuracy())
        image_feats_loss = valid_stats.image_feats_loss
        image_feats_cos = valid_stats.image_feats_cos
        image_pixels_loss = valid_stats.image_pixels_loss
        image_pixels_acc = valid_stats.image_pixels_acc
        print('Validation image feats nll (avg.): %g' %
              (image_feats_loss / valid_stats.n_updates))
        print('Validation image feats cosine (avg.): %g' %
              (image_feats_cos / valid_stats.n_updates))
        #print('Validation image pixels nll (avg.): %g' % (image_pixels_loss / valid_stats.n_updates))
        #print('Validation image pixels acc (avg.): %g' % (image_pixels_acc / valid_stats.n_updates))

        # 3. Log to remote server.
        if opt.exp_host:
            train_stats.log("train", experiment, optim.lr)
            valid_stats.log("valid", experiment, optim.lr)

        # 4. Update the learning rate
        trainer.epoch_step(valid_stats.ppl(), epoch)

        # 5. Drop a checkpoint if needed.
        if trainer.early_stop.early_stop_criteria in ['perplexity', None]:
            # not early-stopping
            if epoch >= opt.start_checkpoint_at:
                trainer.drop_checkpoint(model_opt,
                                        epoch,
                                        fields,
                                        valid_stats,
                                        overwrite=opt.overwrite_model_file)
        else:
            # If we are using a non-default early-stopping criterion,
            # save the model so training can be resumed later if need be.
            model_name = trainer.drop_checkpoint(
                model_opt,
                epoch,
                fields,
                valid_stats,
                overwrite=opt.overwrite_model_file,
                checkpoint_type='last')
            trainer.drop_metric_scores(model_opt,
                                       epoch,
                                       fields,
                                       valid_stats,
                                       overwrite=True,
                                       checkpoint_type='last')
            print("")

        if trainer.early_stop.signal_early_stopping:
            print("WARNING: Early stopping!")
            break
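
The extra branching at the end of Example #2 hinges on an early-stopping helper that watches a validation metric and raises a flag once it stops improving (trainer.early_stop.signal_early_stopping). The fragment below is a minimal, self-contained sketch of that patience-based idea; it only illustrates the pattern and does not mirror the actual onmt.TrainerMultimodal early-stopping implementation.

# Minimal patience-based early-stopping sketch (illustrative only).
class EarlyStopping:
    def __init__(self, patience=3, higher_is_better=True):
        self.patience = patience              # epochs to wait without improvement
        self.higher_is_better = higher_is_better
        self.best = None
        self.bad_epochs = 0
        self.signal_early_stopping = False

    def step(self, score):
        """Record this epoch's validation score and return the stop flag."""
        improved = (self.best is None or
                    (score > self.best if self.higher_is_better
                     else score < self.best))
        if improved:
            self.best = score
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
            if self.bad_epochs >= self.patience:
                self.signal_early_stopping = True
        return self.signal_early_stopping

# Toy usage: a validation score that plateaus after the third epoch.
early_stop = EarlyStopping(patience=2)
for epoch, score in enumerate([0.30, 0.33, 0.35, 0.34, 0.34, 0.33], start=1):
    if early_stop.step(score):
        print('WARNING: Early stopping at epoch %d (best score %.2f)'
              % (epoch, early_stop.best))
        break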