def cli_main():
    from pl_bolts.callbacks import LatentDimInterpolator, TensorboardGenerativeModelImageSampler
    from pl_bolts.datamodules import ImagenetDataModule

    pl.seed_everything(1234)
    parser = ArgumentParser()
    parser.add_argument('--dataset',
                        default='mnist',
                        type=str,
                        help='mnist, stl10, imagenet')

    parser = pl.Trainer.add_argparse_args(parser)
    parser = VAE.add_model_specific_args(parser)
    parser = ImagenetDataModule.add_argparse_args(parser)
    parser = MNISTDataModule.add_argparse_args(parser)
    args = parser.parse_args()

    # default is mnist
    datamodule = None
    if args.dataset == 'imagenet2012':
        datamodule = ImagenetDataModule.from_argparse_args(args)
    elif args.dataset == 'stl10':
        datamodule = STL10DataModule.from_argparse_args(args)

    callbacks = [
        TensorboardGenerativeModelImageSampler(),
        LatentDimInterpolator(interpolate_epoch_interval=5)
    ]
    vae = VAE(**vars(args), datamodule=datamodule)
    trainer = pl.Trainer.from_argparse_args(args,
                                            callbacks=callbacks,
                                            progress_bar_refresh_rate=10)
    trainer.fit(vae)
Beispiel #2
0
        parser.add_argument('--pretrained', type=str, default=None)
        parser.add_argument('--data_dir', type=str, default=os.getcwd())

        parser.add_argument('--learning_rate', type=float, default=1e-3)
        return parser


if __name__ == '__main__':
    from pl_bolts.datamodules import ImagenetDataModule
    parser = ArgumentParser()
    parser.add_argument('--dataset', default='mnist', type=str)

    parser = Trainer.add_argparse_args(parser)
    parser = VAE.add_model_specific_args(parser)
    parser = ImagenetDataModule.add_argparse_args(parser)
    parser = MNISTDataModule.add_argparse_args(parser)
    args = parser.parse_args()
    #
    # if args.dataset == 'imagenet' or args.pretrained:
    #     datamodule = ImagenetDataModule.from_argparse_args(args)
    #     args.image_width = datamodule.size()[1]
    #     args.image_height = datamodule.size()[2]
    #     args.input_channels = datamodule.size()[0]
    #
    # elif args.dataset == 'mnist':
    #     datamodule = MNISTDataModule.from_argparse_args(args)
    #     args.image_width = datamodule.size()[1]
    #     args.image_height = datamodule.size()[2]
    #     args.input_channels = datamodule.size()[0]

    vae = VAE(**vars(args))