def cli_main():
    parser = ArgumentParser()

    # trainer args
    parser = pl.Trainer.add_argparse_args(parser)

    # model args
    parser = ImageGPT.add_model_specific_args(parser)
    args = parser.parse_args()

    if args.dataset == "fashion_mnist":
        datamodule = FashionMNISTDataModule.from_argparse_args(args)

    elif args.dataset == "imagenet128":
        datamodule = ImagenetDataModule.from_argparse_args(args)

    model = ImageGPT(**args.__dict__, datamodule=datamodule)

    trainer = pl.Trainer.from_argparse_args(args)
    trainer.fit(model)
        parser.add_argument("--classify", action="store_true", default=False)
        parser.add_argument("--batch_size", type=int, default=64)
        parser.add_argument("--learning_rate", type=float, default=1e-2)
        parser.add_argument("--steps", type=int, default=25_000)
        return parser


# todo: covert to CLI func and add test
if __name__ == "__main__":
    from argparse import ArgumentParser

    parser = ArgumentParser()

    # trainer args
    parser = pl.Trainer.add_argparse_args(parser)

    # model args
    parser = ImageGPT.add_model_specific_args(parser)
    args = parser.parse_args()

    if args.dataset == "fashion_mnist":
        datamodule = FashionMNISTDataModule.from_argparse_args(args)

    elif args.dataset == "imagenet128":
        datamodule = ImagenetDataModule.from_argparse_args(args)

    model = ImageGPT(**args.__dict__, datamodule=datamodule)

    trainer = pl.Trainer.from_argparse_args(args)
    trainer.fit(model)