def cli_main():
    """Parse command-line arguments and train a ``VAE``.

    The argument parser combines ``pl.Trainer`` flags, the VAE's
    model-specific flags, and the ``ImagenetDataModule`` /
    ``MNISTDataModule`` flags. The model is then fit with a Tensorboard
    image sampler and a latent-dimension interpolator (every 5 epochs).

    ``--dataset`` selects the datamodule handed to ``VAE``:

    * ``imagenet`` or ``imagenet2012`` -> ``ImagenetDataModule``
    * ``stl10`` -> ``STL10DataModule``
    * anything else (default ``mnist``) -> ``datamodule=None``; the
      original script documents this case as "default is mnist", i.e.
      ``VAE`` is expected to choose its own dataset.
    """
    from pl_bolts.callbacks import LatentDimInterpolator, TensorboardGenerativeModelImageSampler
    from pl_bolts.datamodules import ImagenetDataModule

    pl.seed_everything(1234)

    parser = ArgumentParser()
    parser.add_argument('--dataset', default='mnist', type=str,
                        help='mnist, stl10, imagenet (alias: imagenet2012)')
    parser = pl.Trainer.add_argparse_args(parser)
    parser = VAE.add_model_specific_args(parser)
    parser = ImagenetDataModule.add_argparse_args(parser)
    parser = MNISTDataModule.add_argparse_args(parser)
    args = parser.parse_args()

    # default is mnist: a None datamodule leaves dataset selection to VAE.
    datamodule = None
    # The help text advertises 'imagenet', while the original dispatch only
    # matched 'imagenet2012'. Accept both so the documented value no longer
    # silently falls back to MNIST.
    if args.dataset in ('imagenet', 'imagenet2012'):
        datamodule = ImagenetDataModule.from_argparse_args(args)
    elif args.dataset == 'stl10':
        datamodule = STL10DataModule.from_argparse_args(args)

    callbacks = [
        TensorboardGenerativeModelImageSampler(),
        LatentDimInterpolator(interpolate_epoch_interval=5),
    ]

    vae = VAE(**vars(args), datamodule=datamodule)
    trainer = pl.Trainer.from_argparse_args(args, callbacks=callbacks, progress_bar_refresh_rate=10)
    trainer.fit(vae)
images = pl_module(z) grid = torchvision.utils.make_grid(images) trainer.logger.experiment.add_image('gan_images', grid, global_step=trainer.global_step) # todo: covert to CLI func and add test if __name__ == '__main__': from pl_bolts.datamodules import ImagenetDataModule parser = ArgumentParser() parser = Trainer.add_argparse_args(parser) parser = GAN.add_model_specific_args(parser) parser = ImagenetDataModule.add_argparse_args(parser) args = parser.parse_args() # default is mnist datamodule = None if args.dataset == 'imagenet2012': datamodule = ImagenetDataModule.from_argparse_args(args) elif args.dataset == 'stl10': datamodule = STL10DataModule.from_argparse_args(args) gan = GAN(**vars(args), datamodule=datamodule) callbacks = [ImageGenerator(), LatentDimInterpolator()] # no val loop... thus we condition on loss and always save the last checkpoint_cb = ModelCheckpoint(monitor='loss', save_last=True) trainer = Trainer.from_argparse_args(args,