Code example #1: training the VAE with PyTorch Lightning
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger

# ConfigProvider, MyDataModule, MNISTDataModule, VAEFC and the path constants
# (tb_logs_folder, vae_checkpoints_path) are project-local and assumed in scope.


def train_vae():
    config = ConfigProvider.get_config()
    seed_everything(config.random_seed)

    if config.dataset == "toy":
        datamodule = MyDataModule(config)
        latent_dim = config.latent_dim_toy
        enc_layer_sizes = config.enc_layer_sizes_toy + [latent_dim]
        dec_layer_sizes = [latent_dim] + config.dec_layer_sizes_toy
    elif config.dataset == "mnist":
        datamodule = MNISTDataModule(config)
        latent_dim = config.latent_dim_mnist
        enc_layer_sizes = config.enc_layer_sizes_mnist + [latent_dim]
        dec_layer_sizes = [latent_dim] + config.dec_layer_sizes_mnist
    else:
        raise ValueError(
            f"unsupported config.dataset {config.dataset!r}; expected 'toy' or 'mnist'")

    model = VAEFC(config=config,
                  encoder_layer_sizes=enc_layer_sizes,
                  decoder_layer_sizes=dec_layer_sizes)

    logger = TensorBoardLogger(save_dir=tb_logs_folder,
                               name='VAEFC',
                               default_hp_metric=False)
    logger.hparams = config  # TODO only put here relevant stuff

    # No monitor is set, so ModelCheckpoint saves the latest epoch's checkpoint
    # and best_model_path simply points to it.
    checkpoint_callback = ModelCheckpoint(dirpath=vae_checkpoints_path)
    trainer = Trainer(
        deterministic=config.is_deterministic,
        # auto_lr_find=config.auto_lr_find,
        # log_gpu_memory='all',
        # min_epochs=99999,
        max_epochs=config.num_epochs,
        default_root_dir=vae_checkpoints_path,
        logger=logger,
        callbacks=[checkpoint_callback],
        gpus=1)  # Lightning 1.x API; newer versions use accelerator="gpu", devices=1
    # trainer.tune(model)
    trainer.fit(model, datamodule=datamodule)
    best_model_path = checkpoint_callback.best_model_path
    print("done training vae with lightning")
    print(f"best model path = {best_model_path}")
    return trainer
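
train_vae() only wires the pieces together; the model itself is defined elsewhere in the project. For a self-contained picture, here is a minimal sketch of the kind of fully-connected VAE LightningModule the layer-size lists imply. The layer wiring, the Gaussian reparameterization, the loss, and the config.learning_rate field are all assumptions for illustration, not the actual VAEFC implementation.

# Hypothetical sketch of a fully-connected VAE LightningModule; not the
# project's real VAEFC.
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl


class VAEFCSketch(pl.LightningModule):
    def __init__(self, config, encoder_layer_sizes, decoder_layer_sizes):
        super().__init__()
        # encoder_layer_sizes ends with latent_dim and decoder_layer_sizes
        # starts with it, matching how train_vae() builds both lists.
        self.encoder = self._mlp(encoder_layer_sizes)
        self.decoder = self._mlp(decoder_layer_sizes)
        latent_dim = encoder_layer_sizes[-1]
        self.fc_mu = nn.Linear(latent_dim, latent_dim)
        self.fc_logvar = nn.Linear(latent_dim, latent_dim)
        self.lr = config.learning_rate  # assumed config field

    @staticmethod
    def _mlp(sizes):
        layers = []
        for n_in, n_out in zip(sizes, sizes[1:]):
            layers += [nn.Linear(n_in, n_out), nn.ReLU()]
        return nn.Sequential(*layers[:-1])  # no activation after the last layer

    def forward(self, x):
        h = self.encoder(x.flatten(1))
        mu, logvar = self.fc_mu(h), self.fc_logvar(h)
        z = mu + torch.exp(0.5 * logvar) * torch.randn_like(mu)  # reparameterization
        return self.decoder(z), mu, logvar

    def training_step(self, batch, batch_idx):
        x, _ = batch
        x_hat, mu, logvar = self(x)
        recon = F.mse_loss(x_hat, x.flatten(1))
        kld = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
        self.log("train_loss", recon + kld)
        return recon + kld

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)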
Code example #2: training a classifier on the VAE's latent space
import glob
import os

from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger

# ConfigProvider, the datamodules, VAEFC, LatentSpaceClassifierLightning and the
# path constants (tb_logs_folder, classifier_checkpoints_path) are project-local.


def train_latent_classifier():
    config = ConfigProvider.get_config()
    seed_everything(config.random_seed)

    if config.dataset == "toy":
        datamodule = MyDataModule(config)
        latent_dim = config.latent_dim_toy
        enc_layer_sizes = config.enc_layer_sizes_toy + [latent_dim]
        dec_layer_sizes = [latent_dim] + config.dec_layer_sizes_toy
    elif config.dataset == "mnist":
        datamodule = MNISTDataModule(config)
        latent_dim = config.latent_dim_mnist
        enc_layer_sizes = config.enc_layer_sizes_mnist + [latent_dim]
        dec_layer_sizes = [latent_dim] + config.dec_layer_sizes_mnist
    else:
        raise ValueError(
            f"unsupported config.dataset {config.dataset!r}; expected 'toy' or 'mnist'")

    # Load the most recently written VAE checkpoint produced by train_vae().
    last_vae = max(
        glob.glob(os.path.join(os.path.abspath(vae_checkpoints_path), "**/*.ckpt"),
                  recursive=True),
        key=os.path.getctime)
    trained_vae = VAEFC.load_from_checkpoint(last_vae,
                                             config=config,
                                             encoder_layer_sizes=enc_layer_sizes,
                                             decoder_layer_sizes=dec_layer_sizes)

    logger = TensorBoardLogger(save_dir=tb_logs_folder,
                               name='Classifier',
                               default_hp_metric=False)
    logger.hparams = config  # TODO only put here relevant stuff

    # As above, no monitor is set: best_model_path is simply the latest checkpoint.
    checkpoint_callback = ModelCheckpoint(dirpath=classifier_checkpoints_path)
    trainer = Trainer(deterministic=config.is_deterministic,
                      # auto_lr_find=config.auto_lr_find,
                      # log_gpu_memory='all',
                      # min_epochs=99999,
                      max_epochs=config.num_epochs,
                      default_root_dir=classifier_checkpoints_path,
                      logger=logger,
                      callbacks=[checkpoint_callback],
                      gpus=1)  # Lightning 1.x API; newer versions use accelerator="gpu", devices=1

    classifier = LatentSpaceClassifierLightning(config, trained_vae, latent_dim=latent_dim)
    trainer.fit(classifier, datamodule=datamodule)
    best_model_path = checkpoint_callback.best_model_path
    print("done training classifier with lightning")
    print(f"best model path = {best_model_path}")
    return trainer
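
Both functions repeat the dataset-dependent setup block verbatim. A small helper would remove the duplication; the helper itself is hypothetical (not part of the original code), but every name and config field in it is taken from the two snippets above.

def _build_data_and_sizes(config):
    """Shared dataset-dependent setup used by both training entry points."""
    if config.dataset == "toy":
        datamodule = MyDataModule(config)
        latent_dim = config.latent_dim_toy
        enc_layer_sizes = config.enc_layer_sizes_toy + [latent_dim]
        dec_layer_sizes = [latent_dim] + config.dec_layer_sizes_toy
    elif config.dataset == "mnist":
        datamodule = MNISTDataModule(config)
        latent_dim = config.latent_dim_mnist
        enc_layer_sizes = config.enc_layer_sizes_mnist + [latent_dim]
        dec_layer_sizes = [latent_dim] + config.dec_layer_sizes_mnist
    else:
        raise ValueError(
            f"unsupported config.dataset {config.dataset!r}; expected 'toy' or 'mnist'")
    return datamodule, latent_dim, enc_layer_sizes, dec_layer_sizes

Each training function would then open with a single line, datamodule, latent_dim, enc_layer_sizes, dec_layer_sizes = _build_data_and_sizes(config), keeping the two entry points in sync if a new dataset is added.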