def cli_main(): from pl_bolts.callbacks import LatentDimInterpolator, TensorboardGenerativeModelImageSampler from pl_bolts.datamodules import ImagenetDataModule pl.seed_everything(1234) parser = ArgumentParser() parser.add_argument('--dataset', default='mnist', type=str, help='mnist, stl10, imagenet') parser = pl.Trainer.add_argparse_args(parser) parser = VAE.add_model_specific_args(parser) parser = ImagenetDataModule.add_argparse_args(parser) parser = MNISTDataModule.add_argparse_args(parser) args = parser.parse_args() # default is mnist datamodule = None if args.dataset == 'imagenet2012': datamodule = ImagenetDataModule.from_argparse_args(args) elif args.dataset == 'stl10': datamodule = STL10DataModule.from_argparse_args(args) callbacks = [ TensorboardGenerativeModelImageSampler(), LatentDimInterpolator(interpolate_epoch_interval=5) ] vae = VAE(**vars(args), datamodule=datamodule) trainer = pl.Trainer.from_argparse_args(args, callbacks=callbacks, progress_bar_refresh_rate=10) trainer.fit(vae)
parser.add_argument('--pretrained', type=str, default=None) parser.add_argument('--data_dir', type=str, default=os.getcwd()) parser.add_argument('--learning_rate', type=float, default=1e-3) return parser if __name__ == '__main__': from pl_bolts.datamodules import ImagenetDataModule parser = ArgumentParser() parser.add_argument('--dataset', default='mnist', type=str) parser = Trainer.add_argparse_args(parser) parser = VAE.add_model_specific_args(parser) parser = ImagenetDataModule.add_argparse_args(parser) parser = MNISTDataModule.add_argparse_args(parser) args = parser.parse_args() # # if args.dataset == 'imagenet' or args.pretrained: # datamodule = ImagenetDataModule.from_argparse_args(args) # args.image_width = datamodule.size()[1] # args.image_height = datamodule.size()[2] # args.input_channels = datamodule.size()[0] # # elif args.dataset == 'mnist': # datamodule = MNISTDataModule.from_argparse_args(args) # args.image_width = datamodule.size()[1] # args.image_height = datamodule.size()[2] # args.input_channels = datamodule.size()[0] vae = VAE(**vars(args))