def main():
    """Train ``LitDCGAN`` on a folder of images from the command line.

    Builds one argument parser out of the data-specific options
    (``add_data_specific_args``), the ``LitDCGAN`` model options and the
    ``pl.Trainer`` options, calls ``pl.seed_everything(42)``, and fits the
    model on the dataset rooted at ``args.data_dir``.

    Every image is resized and center-cropped to ``model.input_size``,
    converted to a tensor and normalized per channel as ``(x - 0.5) / 0.5``,
    which maps ``ToTensor`` output in [0, 1] to [-1, 1].
    """
    parser = ArgumentParser()
    parser = add_data_specific_args(parser)
    parser = LitDCGAN.add_model_specific_args(parser)
    parser = pl.Trainer.add_argparse_args(parser)
    args = parser.parse_args()

    pl.seed_everything(42)

    # NOTE(review): the model-specific options registered above are parsed
    # but never handed to the model; confirm LitDCGAN's constructor and
    # consider ``LitDCGAN(**vars(args))`` if it accepts them.
    model = LitDCGAN()

    callbacks = [
        TensorboardGenerativeModelImageSampler(),
        LatentDimInterpolator(interpolate_epoch_interval=5),
    ]
    trainer = pl.Trainer.from_argparse_args(args, callbacks=callbacks)

    transform = transforms.Compose([
        transforms.Resize(model.input_size),
        transforms.CenterCrop(model.input_size),
        transforms.ToTensor(),
        # Three channels (presumably RGB): (x - 0.5) / 0.5 per channel.
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    # NOTE(review): torchvision spells this dataset ``ImageFolder``; confirm
    # that ``ImagesFolder`` is a project-local class and not a typo.
    dataset = ImagesFolder(root=args.data_dir, transform=transform)
    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.workers,
    )

    # Pass the loader positionally: the ``train_dataloader=`` keyword was
    # deprecated in Lightning 1.4 (in favour of ``train_dataloaders``) and
    # later removed, whereas the second positional parameter of ``fit`` is
    # the training loader across the 1.x API.
    trainer.fit(model, dataloader)
def cli_main(args=None):
    """Train the ``GAN`` model on a pl_bolts datamodule selected by ``--dataset``.

    Arguments are parsed in two passes. A first ``parse_known_args`` reads
    only ``--dataset`` so that the matching datamodule class can register its
    own options. The full parser (datamodule + Trainer + GAN options) is then
    parsed. ``pl.seed_everything(1234)`` is called before anything else.

    Args:
        args: Optional list of command-line tokens; ``None`` lets argparse
            read ``sys.argv[1:]``.

    Returns:
        Tuple ``(dm, model, trainer)`` once ``trainer.fit`` has returned.

    Raises:
        ValueError: If ``--dataset`` is not one of mnist, cifar10, stl10,
            imagenet.
    """
    from pl_bolts.callbacks import LatentDimInterpolator, TensorboardGenerativeModelImageSampler
    from pl_bolts.datamodules import CIFAR10DataModule, ImagenetDataModule, MNISTDataModule, STL10DataModule

    pl.seed_everything(1234)

    parser = ArgumentParser()
    parser.add_argument("--dataset", default="mnist", type=str, help="mnist, cifar10, stl10, imagenet")
    # First pass: only ``--dataset`` is known yet; everything else is ignored.
    script_args, _ = parser.parse_known_args(args)

    if script_args.dataset == "mnist":
        dm_cls = MNISTDataModule
    elif script_args.dataset == "cifar10":
        dm_cls = CIFAR10DataModule
    elif script_args.dataset == "stl10":
        dm_cls = STL10DataModule
    elif script_args.dataset == "imagenet":
        dm_cls = ImagenetDataModule
    else:
        # Without this branch an unknown name left ``dm_cls`` unbound and the
        # next statement failed with an opaque UnboundLocalError.
        raise ValueError(f"undefined dataset {script_args.dataset}")

    # Second pass: the full parser, now that the datamodule is known.
    parser = dm_cls.add_argparse_args(parser)
    parser = pl.Trainer.add_argparse_args(parser)
    parser = GAN.add_model_specific_args(parser)
    args = parser.parse_args(args)

    dm = dm_cls.from_argparse_args(args)
    # ``dm.size()`` is unpacked positionally into the GAN constructor
    # (presumably channels, height, width — confirm against GAN.__init__).
    model = GAN(*dm.size(), **vars(args))

    callbacks = [
        TensorboardGenerativeModelImageSampler(),
        LatentDimInterpolator(interpolate_epoch_interval=5),
    ]
    trainer = pl.Trainer.from_argparse_args(args, callbacks=callbacks, progress_bar_refresh_rate=20)
    trainer.fit(model, dm)
    return dm, model, trainer
def cli_main(args=None):
    """Train the autoencoder ``AE`` on a pl_bolts datamodule picked by ``--dataset``.

    A preliminary ``parse_known_args`` pass looks only at ``--dataset``
    (restricted to cifar10 / stl10 / imagenet). The AE and Trainer options
    are then added and the complete command line is parsed. The model's
    ``input_height`` is taken from the last entry of the datamodule's
    ``size()``, and a ``max_steps`` of -1 is turned into ``None`` before the
    options reach the model.

    Args:
        args: Optional list of command-line tokens; ``None`` lets argparse
            read ``sys.argv[1:]``.

    Returns:
        Tuple ``(dm, model, trainer)`` once ``trainer.fit`` has returned.

    Raises:
        ValueError: If the dataset name has no matching datamodule.
    """
    from pl_bolts.callbacks import TensorboardGenerativeModelImageSampler
    from pl_bolts.datamodules import CIFAR10DataModule, ImagenetDataModule, STL10DataModule

    # Dispatch table from ``--dataset`` value to datamodule class; its key
    # order also defines the argparse ``choices`` list.
    datamodule_by_name = {
        "cifar10": CIFAR10DataModule,
        "stl10": STL10DataModule,
        "imagenet": ImagenetDataModule,
    }

    parser = ArgumentParser()
    parser.add_argument("--dataset", default="cifar10", type=str, choices=list(datamodule_by_name))
    known, _ = parser.parse_known_args(args)

    datamodule_cls = datamodule_by_name.get(known.dataset)
    if datamodule_cls is None:
        raise ValueError(f"undefined dataset {known.dataset}")

    parser = AE.add_model_specific_args(parser)
    parser = pl.Trainer.add_argparse_args(parser)
    parsed = parser.parse_args(args)

    dm = datamodule_cls.from_argparse_args(parsed)
    # Last entry of ``dm.size()`` — presumably the image side length
    # (square images assumed; TODO confirm).
    parsed.input_height = dm.size()[-1]

    # -1 is treated as "no step limit" and passed on as None.
    if parsed.max_steps == -1:
        parsed.max_steps = None

    model = AE(**vars(parsed))

    trainer = pl.Trainer.from_argparse_args(parsed)
    # Appended after construction, i.e. after the Trainer's own callbacks.
    trainer.callbacks += [TensorboardGenerativeModelImageSampler()]
    trainer.fit(model, dm)
    return dm, model, trainer
def cli_main(args=None):
    """Train the variational autoencoder ``VAE`` on a pl_bolts datamodule.

    The datamodule is selected by ``--dataset`` (cifar10, stl10 or imagenet)
    in a first ``parse_known_args`` pass. The VAE and Trainer options are then
    added and the full command line is parsed. ``input_height`` is derived
    from the datamodule, and a ``max_steps`` of -1 becomes ``None``. The model
    is trained with an image-sampler and a latent-interpolation callback.

    Args:
        args: Optional list of command-line tokens; ``None`` lets argparse
            read ``sys.argv[1:]``.

    Returns:
        Tuple ``(dm, model, trainer)`` once ``trainer.fit`` has returned.

    Raises:
        ValueError: If ``--dataset`` is not one of cifar10, stl10, imagenet.
    """
    from pl_bolts.callbacks import LatentDimInterpolator, TensorboardGenerativeModelImageSampler
    # Imported locally, like the callbacks, so this entry point does not rely
    # on the datamodule classes being imported at module level.
    from pl_bolts.datamodules import CIFAR10DataModule, ImagenetDataModule, STL10DataModule

    # NOTE(review): called without an explicit seed; confirm runs are not
    # meant to be reproducible.
    pl.seed_everything()

    parser = ArgumentParser()
    parser.add_argument("--dataset", default="cifar10", type=str, help="cifar10, stl10, imagenet")
    script_args, _ = parser.parse_known_args(args)

    if script_args.dataset == "cifar10":
        dm_cls = CIFAR10DataModule
    elif script_args.dataset == "stl10":
        dm_cls = STL10DataModule
    elif script_args.dataset == "imagenet":
        dm_cls = ImagenetDataModule
    else:
        # Previously an unknown name left ``dm_cls`` unbound (UnboundLocalError).
        raise ValueError(f"undefined dataset {script_args.dataset}")

    parser = VAE.add_model_specific_args(parser)
    # Trainer flags must be registered: ``args.max_steps`` below and
    # ``Trainer.from_argparse_args`` both depend on them (the original read
    # ``args.max_steps`` without ever adding it to the parser).
    parser = pl.Trainer.add_argparse_args(parser)
    args = parser.parse_args(args)

    dm = dm_cls.from_argparse_args(args)
    # Last entry of ``dm.size()`` — presumably the image side length
    # (square images assumed; TODO confirm).
    args.input_height = dm.size()[-1]

    # -1 is treated as "no step limit" and passed on as None.
    if args.max_steps == -1:
        args.max_steps = None

    model = VAE(**vars(args))

    callbacks = [
        TensorboardGenerativeModelImageSampler(),
        LatentDimInterpolator(interpolate_epoch_interval=5),
    ]
    # The callbacks used to be built and then discarded; hand them to the Trainer.
    trainer = pl.Trainer.from_argparse_args(args, callbacks=callbacks)
    trainer.fit(model, dm)
    return dm, model, trainer