def train_vae():
    """Train the fully connected VAE (``VAEFC``) on the configured dataset.

    The global config comes from ``ConfigProvider.get_config()``, and all RNGs
    are seeded with ``config.random_seed``. The datamodule and the
    encoder/decoder layer sizes are taken from the ``*_toy`` or ``*_mnist``
    config fields, depending on ``config.dataset``. Training is logged to
    TensorBoard under ``tb_logs_folder`` (run name ``'VAEFC'``). Checkpoints
    are written to ``vae_checkpoints_path``.

    Returns:
        The ``Trainer`` that was used for fitting.

    Raises:
        ValueError: if ``config.dataset`` is neither ``"toy"`` nor ``"mnist"``.
    """
    config = ConfigProvider.get_config()
    seed_everything(config.random_seed)

    if config.dataset == "toy":
        datamodule = MyDataModule(config)
        latent_dim, enc_hidden, dec_hidden = (
            config.latent_dim_toy,
            config.enc_layer_sizes_toy,
            config.dec_layer_sizes_toy,
        )
    elif config.dataset == "mnist":
        datamodule = MNISTDataModule(config)
        latent_dim, enc_hidden, dec_hidden = (
            config.latent_dim_mnist,
            config.enc_layer_sizes_mnist,
            config.dec_layer_sizes_mnist,
        )
    else:
        raise ValueError(
            "undefined config.dataset. Allowed are either 'toy' or 'mnist'")

    # The latent layer is the last encoder layer and the first decoder layer.
    # The config's layer-size sequences are concatenated, not mutated.
    model = VAEFC(
        config=config,
        encoder_layer_sizes=enc_hidden + [latent_dim],
        decoder_layer_sizes=[latent_dim] + dec_hidden,
    )

    logger = TensorBoardLogger(save_dir=tb_logs_folder, name='VAEFC', default_hp_metric=False)
    logger.hparams = config  # TODO only put here relevant stuff
    checkpoint_callback = ModelCheckpoint(dirpath=vae_checkpoints_path)
    trainer = Trainer(
        deterministic=config.is_deterministic,
        max_epochs=config.num_epochs,
        default_root_dir=vae_checkpoints_path,
        logger=logger,
        callbacks=[checkpoint_callback],
        gpus=1,
    )
    trainer.fit(model, datamodule=datamodule)

    best_path = checkpoint_callback.best_model_path
    print("done training vae with lightning")
    print(f"best model path = {best_path}")
    return trainer
def _latest_checkpoint(directory):
    """Return the newest ``.ckpt`` file found recursively under *directory*.

    "Newest" means the largest ``os.path.getctime``. On Unix this is the last
    metadata-change time; on Windows it is the creation time.

    Raises:
        FileNotFoundError: if no ``.ckpt`` file exists under *directory*.
    """
    pattern = os.path.join(os.path.abspath(directory), "**", "*.ckpt")
    candidates = glob.glob(pattern, recursive=True)
    if not candidates:
        # Without this check max() would fail with the opaque
        # "max() arg is an empty sequence".
        raise FileNotFoundError(
            f"no checkpoint (*.ckpt) found under {directory!r}")
    return max(candidates, key=os.path.getctime)


def train_latent_classifier():
    """Train a latent-space classifier on top of the most recent trained VAE.

    The global config comes from ``ConfigProvider.get_config()``, and all RNGs
    are seeded with ``config.random_seed``. The datamodule and VAE layer sizes
    are selected from the ``*_toy`` or ``*_mnist`` config fields. The newest
    VAE checkpoint under ``vae_checkpoints_path`` is restored, and a
    ``LatentSpaceClassifierLightning`` is fitted on it. Training is logged to
    TensorBoard under ``tb_logs_folder`` (run name ``'Classifier'``).
    Checkpoints are written to ``classifier_checkpoints_path``.

    Returns:
        The ``Trainer`` that was used for fitting the classifier.

    Raises:
        ValueError: if ``config.dataset`` is neither ``"toy"`` nor ``"mnist"``.
        FileNotFoundError: if no VAE checkpoint exists under
            ``vae_checkpoints_path``.
    """
    config = ConfigProvider.get_config()
    seed_everything(config.random_seed)

    if config.dataset == "toy":
        datamodule = MyDataModule(config)
        latent_dim = config.latent_dim_toy
        enc_layer_sizes = config.enc_layer_sizes_toy + [latent_dim]
        dec_layer_sizes = [latent_dim] + config.dec_layer_sizes_toy
    elif config.dataset == "mnist":
        datamodule = MNISTDataModule(config)
        latent_dim = config.latent_dim_mnist
        enc_layer_sizes = config.enc_layer_sizes_mnist + [latent_dim]
        dec_layer_sizes = [latent_dim] + config.dec_layer_sizes_mnist
    else:
        raise ValueError("undefined config.dataset. Allowed are either 'toy' or 'mnist'")

    # Restore the newest VAE. The constructor arguments are passed explicitly
    # so the rebuilt VAEFC has the layer sizes it was trained with. This
    # assumes the current config matches the one train_vae() ran with.
    last_vae = _latest_checkpoint(vae_checkpoints_path)
    trained_vae = VAEFC.load_from_checkpoint(
        last_vae,
        config=config,
        encoder_layer_sizes=enc_layer_sizes,
        decoder_layer_sizes=dec_layer_sizes,
    )

    logger = TensorBoardLogger(save_dir=tb_logs_folder, name='Classifier', default_hp_metric=False)
    logger.hparams = config  # TODO only put here relevant stuff
    checkpoint_callback = ModelCheckpoint(dirpath=classifier_checkpoints_path)
    trainer = Trainer(
        deterministic=config.is_deterministic,
        max_epochs=config.num_epochs,
        default_root_dir=classifier_checkpoints_path,
        logger=logger,
        callbacks=[checkpoint_callback],
        gpus=1,
    )
    classifier = LatentSpaceClassifierLightning(config, trained_vae, latent_dim=latent_dim)
    trainer.fit(classifier, datamodule=datamodule)

    best_model_path = checkpoint_callback.best_model_path
    print("done training classifier with lightning")
    print(f"best model path = {best_model_path}")
    return trainer