コード例 #1
0
ファイル: main.py プロジェクト: senior-sigan/GAN_zoo
def main():
    """Train ``LitDCGAN`` on a folder of images using command-line options.

    One parser collects three groups of options, registered in this order:
    the data options (``add_data_specific_args``), the model's own options
    (``LitDCGAN.add_model_specific_args``) and the Lightning ``Trainer``
    flags. Images are resized and center-cropped to the model's
    ``input_size``, converted to tensors and normalized per channel with
    mean 0.5 and std 0.5.
    """
    cli = ArgumentParser()
    for register in (
        add_data_specific_args,
        LitDCGAN.add_model_specific_args,
        pl.Trainer.add_argparse_args,
    ):
        cli = register(cli)
    opts = cli.parse_args()

    # Seed every RNG before the model (and its weights) is created.
    pl.seed_everything(42)

    # NOTE(review): the model-specific options parsed above are not passed to
    # the constructor -- confirm LitDCGAN is meant to use its own defaults.
    gan = LitDCGAN()

    trainer = pl.Trainer.from_argparse_args(
        opts,
        callbacks=[
            TensorboardGenerativeModelImageSampler(),
            LatentDimInterpolator(interpolate_epoch_interval=5),
        ],
    )

    side = gan.input_size
    half = (0.5, 0.5, 0.5)
    preprocess = transforms.Compose([
        transforms.Resize(side),
        transforms.CenterCrop(side),
        transforms.ToTensor(),
        # (x - 0.5) / 0.5 on each of the three channels.
        transforms.Normalize(half, half),
    ])

    loader = DataLoader(
        ImagesFolder(root=opts.data_dir, transform=preprocess),
        batch_size=opts.batch_size,
        shuffle=True,
        num_workers=opts.workers,
    )
    # NOTE(review): newer Lightning releases renamed this keyword to
    # ``train_dataloaders``; confirm against the pinned version.
    trainer.fit(gan, train_dataloader=loader)
コード例 #2
0
def cli_main(args=None):
    """Train the pl_bolts ``GAN`` from the command line.

    Parsing happens in two passes. The first pass reads only ``--dataset`` to
    pick a datamodule class. The second pass parses everything once the
    datamodule, the ``Trainer`` and the model have registered their options.

    Args:
        args: Optional list of CLI tokens. ``None`` makes argparse read
            ``sys.argv[1:]``.

    Returns:
        Tuple ``(dm, model, trainer)`` after ``trainer.fit`` has run.

    Raises:
        ValueError: If ``--dataset`` is not one of mnist, cifar10, stl10 or
            imagenet.
    """
    from pl_bolts.callbacks import LatentDimInterpolator, TensorboardGenerativeModelImageSampler
    from pl_bolts.datamodules import CIFAR10DataModule, ImagenetDataModule, MNISTDataModule, STL10DataModule

    pl.seed_everything(1234)

    parser = ArgumentParser()
    parser.add_argument("--dataset",
                        default="mnist",
                        type=str,
                        help="mnist, cifar10, stl10, imagenet")
    # Unknown tokens are ignored here; they are parsed in the second pass.
    script_args, _ = parser.parse_known_args(args)

    datamodules = {
        "mnist": MNISTDataModule,
        "cifar10": CIFAR10DataModule,
        "stl10": STL10DataModule,
        "imagenet": ImagenetDataModule,
    }
    try:
        dm_cls = datamodules[script_args.dataset]
    except KeyError:
        # Previously an unsupported name left dm_cls unbound, which surfaced
        # as an UnboundLocalError further down.
        raise ValueError(f"undefined dataset {script_args.dataset}") from None

    parser = dm_cls.add_argparse_args(parser)
    parser = pl.Trainer.add_argparse_args(parser)
    parser = GAN.add_model_specific_args(parser)
    args = parser.parse_args(args)

    dm = dm_cls.from_argparse_args(args)
    # dm.size() is unpacked positionally into the model, presumably as
    # (channels, height, width) -- confirm against GAN's signature.
    model = GAN(*dm.size(), **vars(args))
    callbacks = [
        TensorboardGenerativeModelImageSampler(),
        LatentDimInterpolator(interpolate_epoch_interval=5)
    ]
    trainer = pl.Trainer.from_argparse_args(args,
                                            callbacks=callbacks,
                                            progress_bar_refresh_rate=20)
    trainer.fit(model, dm)
    return dm, model, trainer
コード例 #3
0
def cli_main(args=None):
    """Train the pl_bolts autoencoder ``AE`` from the command line.

    ``--dataset`` is parsed first to choose a datamodule class. The remaining
    options (model options plus ``Trainer`` flags) are parsed afterwards.

    Args:
        args: Optional list of CLI tokens. ``None`` makes argparse read
            ``sys.argv[1:]``.

    Returns:
        Tuple ``(dm, model, trainer)`` after ``trainer.fit`` has run.

    Raises:
        ValueError: If the dataset name is not recognised. argparse's
            ``choices`` normally rejects such names earlier.
    """
    from pl_bolts.callbacks import TensorboardGenerativeModelImageSampler
    from pl_bolts.datamodules import CIFAR10DataModule, ImagenetDataModule, STL10DataModule

    parser = ArgumentParser()
    parser.add_argument("--dataset",
                        default="cifar10",
                        type=str,
                        choices=["cifar10", "stl10", "imagenet"])
    script_args, _ = parser.parse_known_args(args)

    if script_args.dataset == "cifar10":
        dm_cls = CIFAR10DataModule
    elif script_args.dataset == "stl10":
        dm_cls = STL10DataModule
    elif script_args.dataset == "imagenet":
        dm_cls = ImagenetDataModule
    else:
        raise ValueError(f"undefined dataset {script_args.dataset}")

    parser = AE.add_model_specific_args(parser)
    parser = pl.Trainer.add_argparse_args(parser)
    args = parser.parse_args(args)

    dm = dm_cls.from_argparse_args(args)
    # The last dimension reported by the datamodule becomes the model's
    # input_height.
    args.input_height = dm.size()[-1]

    # Turn the -1 "unset" value into None before building the model/Trainer.
    # NOTE(review): presumably the Lightning version in use expects None here.
    if args.max_steps == -1:
        args.max_steps = None

    model = AE(**vars(args))
    callbacks = [TensorboardGenerativeModelImageSampler()]

    # Pass callbacks at construction time, as the other entry points do,
    # instead of mutating trainer.callbacks afterwards. Mutating it bypassed
    # the Trainer's own callback setup.
    trainer = pl.Trainer.from_argparse_args(args, callbacks=callbacks)
    trainer.fit(model, dm)
    return dm, model, trainer
コード例 #4
0
def cli_main(args=None):
    """Train the pl_bolts variational autoencoder ``VAE`` from the command line.

    ``--dataset`` is parsed first to choose a datamodule class. The remaining
    options (model options plus ``Trainer`` flags) are parsed afterwards.

    Args:
        args: Optional list of CLI tokens. ``None`` makes argparse read
            ``sys.argv[1:]``.

    Returns:
        Tuple ``(dm, model, trainer)`` after ``trainer.fit`` has run.

    Raises:
        ValueError: If ``--dataset`` is not one of cifar10, stl10 or imagenet.
    """
    from pl_bolts.callbacks import LatentDimInterpolator, TensorboardGenerativeModelImageSampler

    # No explicit seed is given, so a seed is chosen by seed_everything itself.
    pl.seed_everything()

    parser = ArgumentParser()
    parser.add_argument("--dataset",
                        default="cifar10",
                        type=str,
                        help="cifar10, stl10, imagenet")
    script_args, _ = parser.parse_known_args(args)

    # NOTE(review): these datamodule classes are resolved from module scope;
    # confirm the file imports them (the AE entry point imports them from
    # pl_bolts.datamodules locally).
    datamodules = {
        "cifar10": CIFAR10DataModule,
        "stl10": STL10DataModule,
        "imagenet": ImagenetDataModule,
    }
    try:
        dm_cls = datamodules[script_args.dataset]
    except KeyError:
        # Previously an unsupported name left dm_cls unbound (UnboundLocalError).
        raise ValueError(f"undefined dataset {script_args.dataset}") from None

    parser = VAE.add_model_specific_args(parser)
    # Register the Trainer flags. Without them ``args.max_steps`` (read below)
    # did not exist, and flags such as --gpus were rejected by parse_args.
    parser = pl.Trainer.add_argparse_args(parser)
    args = parser.parse_args(args)

    dm = dm_cls.from_argparse_args(args)
    # The last dimension reported by the datamodule becomes the model's
    # input_height.
    args.input_height = dm.size()[-1]

    # Turn the -1 "unset" value into None before building the model/Trainer.
    # NOTE(review): presumably the Lightning version in use expects None here.
    if args.max_steps == -1:
        args.max_steps = None

    model = VAE(**vars(args))
    callbacks = [
        TensorboardGenerativeModelImageSampler(),
        LatentDimInterpolator(interpolate_epoch_interval=5)
    ]
    # The callbacks were previously built but never handed to the Trainer.
    trainer = pl.Trainer.from_argparse_args(args, callbacks=callbacks)
    trainer.fit(model, dm)
    return dm, model, trainer