示例#1
0
文件: train_gan.py 项目: wnstlr/LNets
def main():
    """Train the GAN variant selected by the config, then visualize it.

    Steps, in order: parse the config, set the experiment seed, create the
    experiment directories, optionally enable cuDNN benchmark mode, build the
    GAN named by ``cfg.gan_type``, save the hyperparameters, train, and
    finally call ``visualize_results`` with ``cfg.epoch``.

    Raises:
        ValueError: if ``cfg.gan_type`` is not ``'WGAN'``, ``'LWGAN'`` or
            ``'WGAN_GP'``. ValueError subclasses Exception, so existing
            ``except Exception`` handlers still catch it.
    """
    # Parse config json.
    cfg = process_config()

    # Set the seed.
    set_experiment_seed(cfg.seed)

    # Create directories to be used in the experiments. create_dirs returns
    # the config object that is used from here on.
    cfg = create_dirs(cfg)

    # Enable cuDNN benchmark mode only when the config asks for it.
    if cfg.benchmark_mode:
        torch.backends.cudnn.benchmark = True

    # Instantiate the requested GAN variant.
    if cfg.gan_type == 'WGAN':
        gan = WGAN(cfg)
    elif cfg.gan_type == 'LWGAN':
        gan = LWGAN(cfg)
    elif cfg.gan_type == 'WGAN_GP':
        gan = WGAN_GP(cfg)
    else:
        # str() keeps the message identical for string values, and also
        # avoids a TypeError from concatenation when gan_type is not a
        # string (e.g. None).
        raise ValueError("[!] There is no option for " + str(cfg.gan_type))

    # Save the hyperparameters.
    save_hparams(cfg)

    # Train the GAN.
    gan.train()
    print(" [*] Training finished!")

    # Visualize learned generator.
    gan.visualize_results(cfg.epoch)
    print(" [*] Testing finished!")
示例#2
0
        'wait': 0
    })

    # Enter the training loop.
    trainer.train(model,
                  loaders['train'],
                  maxepoch=config.optim.epochs,
                  optimizer=optimizer)

    # Pick the best model according to validation score and test it.
    model.reset_meters()
    best_model_path = os.path.join(dirs.best_path, "best_model.pt")
    if os.path.exists(best_model_path):
        model.load_state_dict(torch.load(best_model_path))
    if loaders['test'] is not None:
        print("Testing the best model. ")
        logger.log_meters('test', trainer.test(model, loaders['test']))

    return model


# Script entry point: build the config, model and data loaders, then hand
# them to train(). Runs only when this file is executed directly, not when
# it is imported.
if __name__ == '__main__':
    # Get the config, initialize the model and construct the data loaders.
    cfg = process_config()
    model_initialization = get_model(cfg)
    # Print the constructed model for inspection.
    print(model_initialization)
    # NOTE(review): data_loaders is presumably a mapping holding 'train' and
    # 'test' loaders — confirm against load_data.
    data_loaders = load_data(cfg)

    # Train. The result is bound to trained_model but not used further here.
    trained_model = train(model_initialization, data_loaders, cfg)