trainer.logger.experiment.add_image('gan_images',
                                            grid,
                                            global_step=trainer.global_step)


if __name__ == '__main__':
    # STL10DataModule is imported here too: the 'stl10' branch below uses it,
    # and without this import that branch would raise NameError.
    from pl_bolts.datamodules import ImagenetDataModule, STL10DataModule

    # Build the CLI from the Trainer flags, the GAN's own hyperparameters and
    # the ImageNet datamodule options.
    parser = ArgumentParser()
    parser = Trainer.add_argparse_args(parser)
    parser = GAN.add_model_specific_args(parser)
    parser = ImagenetDataModule.add_argparse_args(parser)
    args = parser.parse_args()

    # datamodule=None leaves dataset selection to GAN (default is MNIST).
    datamodule = None
    if args.dataset == 'imagenet2012':
        datamodule = ImagenetDataModule.from_argparse_args(args)
    elif args.dataset == 'stl10':
        datamodule = STL10DataModule.from_argparse_args(args)

    gan = GAN(**vars(args), datamodule=datamodule)
    callbacks = [ImageGenerator(), LatentDimInterpolator()]

    # There is no validation loop, so checkpointing monitors the training
    # 'loss' metric; save_last=True always keeps the most recent checkpoint.
    checkpoint_cb = ModelCheckpoint(monitor='loss', save_last=True)
    trainer = Trainer.from_argparse_args(args,
                                         callbacks=callbacks,
                                         checkpoint_callback=checkpoint_cb)
    trainer.fit(gan)