def cli_main():
    """Train a VAE from the command line (mnist by default, stl10/imagenet optional)."""
    from pl_bolts.callbacks import LatentDimInterpolator, TensorboardGenerativeModelImageSampler
    from pl_bolts.datamodules import ImagenetDataModule

    pl.seed_everything(1234)

    parser = ArgumentParser()
    parser.add_argument('--dataset', default='mnist', type=str, help='mnist, stl10, imagenet')
    # extend the parser with trainer, model and datamodule arguments, in order
    for extend in (pl.Trainer.add_argparse_args,
                   VAE.add_model_specific_args,
                   ImagenetDataModule.add_argparse_args,
                   MNISTDataModule.add_argparse_args):
        parser = extend(parser)
    args = parser.parse_args()

    # mnist (the default) passes no datamodule; the model falls back internally
    if args.dataset == 'imagenet2012':
        dm = ImagenetDataModule.from_argparse_args(args)
    elif args.dataset == 'stl10':
        dm = STL10DataModule.from_argparse_args(args)
    else:
        dm = None

    vae = VAE(**vars(args), datamodule=dm)

    trainer = pl.Trainer.from_argparse_args(
        args,
        callbacks=[TensorboardGenerativeModelImageSampler(),
                   LatentDimInterpolator(interpolate_epoch_interval=5)],
        progress_bar_refresh_rate=10,
    )
    trainer.fit(vae)
def cli_main():
    """Pretrain BYOL from the command line on cifar10, stl10 or imagenet2012."""
    from pl_bolts.callbacks.self_supervised import SSLOnlineEvaluator
    from pl_bolts.datamodules import CIFAR10DataModule, STL10DataModule, ImagenetDataModule
    from pl_bolts.models.self_supervised.simclr import SimCLRTrainDataTransform, SimCLREvalDataTransform

    seed_everything(1234)

    parser = ArgumentParser()
    parser = pl.Trainer.add_argparse_args(parser)   # trainer args
    parser = BYOL.add_model_specific_args(parser)   # model args
    args = parser.parse_args()

    # build the datamodule for the requested dataset
    dm = None
    if args.dataset == 'cifar10':
        dm = CIFAR10DataModule.from_argparse_args(args)
        dm.train_transforms = SimCLRTrainDataTransform(32)
        dm.val_transforms = SimCLREvalDataTransform(32)
        args.num_classes = dm.num_classes
    elif args.dataset == 'stl10':
        dm = STL10DataModule.from_argparse_args(args)
        dm.train_dataloader = dm.train_dataloader_mixed
        dm.val_dataloader = dm.val_dataloader_mixed
        _, side, _ = dm.size()
        dm.train_transforms = SimCLRTrainDataTransform(side)
        dm.val_transforms = SimCLREvalDataTransform(side)
        args.num_classes = dm.num_classes
    elif args.dataset == 'imagenet2012':
        dm = ImagenetDataModule.from_argparse_args(args, image_size=196)
        _, side, _ = dm.size()
        dm.train_transforms = SimCLRTrainDataTransform(side)
        dm.val_transforms = SimCLREvalDataTransform(side)
        args.num_classes = dm.num_classes

    model = BYOL(**args.__dict__)

    def move_batch(batch, device):
        # the online evaluator only consumes the first augmented view + label
        (view_one, _), target = batch
        return view_one.to(device), target.to(device)

    # finetune a linear probe in real-time during pretraining
    online_eval = SSLOnlineEvaluator(z_dim=2048, num_classes=dm.num_classes)
    online_eval.to_device = move_batch

    trainer = pl.Trainer.from_argparse_args(args, max_steps=300000, callbacks=[online_eval])
    trainer.fit(model, dm)
def get_datamodule(cfg: DictConfig) -> pl.LightningDataModule:
    """Build the DataModule requested by the user configuration.

    Only the CIFAR10 and ImageNet datasets are currently supported.

    Args:
        cfg: The top-level user configuration object.

    Returns:
        A LightningDataModule, which will be consumed by the neural network model.

    Raises:
        ValueError: An unsupported dataset is specified.
    """
    ds = cfg.dataset
    name = ds.name

    if name == 'cifar10':
        return CIFAR10DataModule(ds.data_dir,
                                 num_workers=ds.workers,
                                 batch_size=ds.batch_size,
                                 seed=cfg.seed)
    if name == 'imagenet':
        return ImagenetDataModule(ds.data_dir,
                                  num_workers=ds.workers,
                                  batch_size=ds.batch_size)

    raise ValueError(f'get_datamodule does not support dataset {name}')
def cli_main():  # pragma: no-cover
    """Fine-tune a pretrained SimCLR backbone on a labeled dataset.

    Loads the backbone from ``--ckpt_path``, attaches an ``SSLFineTuner`` head,
    then trains and tests it on the dataset selected by ``--dataset``.

    Raises:
        ValueError: If ``--dataset`` is not one of cifar10, stl10, imagenet2012.
    """
    pl.seed_everything(1234)

    parser = ArgumentParser()
    parser = pl.Trainer.add_argparse_args(parser)
    parser.add_argument('--dataset', type=str, help='stl10, cifar10', default='cifar10')
    parser.add_argument('--ckpt_path', type=str, help='path to ckpt')
    # fix: help text previously read 'path to ckpt' (copy-paste error)
    parser.add_argument('--data_dir', type=str, help='path to dataset', default=os.getcwd())
    args = parser.parse_args()

    # load the backbone (strict=False: the checkpoint may not match exactly)
    backbone = SimCLR.load_from_checkpoint(args.ckpt_path, strict=False)

    # init default datamodule
    if args.dataset == 'cifar10':
        dm = CIFAR10DataModule.from_argparse_args(args)
        dm.train_transforms = SimCLRTrainDataTransform(32)
        dm.val_transforms = SimCLREvalDataTransform(32)
        dm.test_transforms = SimCLREvalDataTransform(32)
        args.num_samples = dm.num_samples
    elif args.dataset == 'stl10':
        dm = STL10DataModule.from_argparse_args(args)
        # fine-tuning uses the labeled STL10 splits
        dm.train_dataloader = dm.train_dataloader_labeled
        dm.val_dataloader = dm.val_dataloader_labeled
        args.num_samples = dm.num_labeled_samples
        (c, h, w) = dm.size()
        dm.train_transforms = SimCLRTrainDataTransform(h)
        dm.val_transforms = SimCLREvalDataTransform(h)
    elif args.dataset == 'imagenet2012':
        dm = ImagenetDataModule.from_argparse_args(args, image_size=196)
        (c, h, w) = dm.size()
        dm.train_transforms = SimCLRTrainDataTransform(h)
        dm.val_transforms = SimCLREvalDataTransform(h)
    else:
        # fix: previously an unknown dataset fell through and crashed later
        # with an opaque NameError on `dm`
        raise ValueError(f'unsupported dataset: {args.dataset}')

    # finetune
    tuner = SSLFineTuner(backbone, in_features=2048 * 2 * 2, num_classes=dm.num_classes, hidden_dim=None)
    trainer = pl.Trainer.from_argparse_args(args, early_stop_callback=True)
    trainer.fit(tuner, dm)

    trainer.test(datamodule=dm)
def cli_main():
    """Pretrain SimCLR from the command line on cifar10, stl10 or imagenet2012."""
    from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator
    from pl_bolts.datamodules import CIFAR10DataModule, STL10DataModule, ImagenetDataModule

    parser = ArgumentParser()
    parser = pl.Trainer.add_argparse_args(parser)       # trainer args
    parser = SimCLR.add_model_specific_args(parser)     # model args
    args = parser.parse_args()

    # build the datamodule and patch dataset-specific loaders/transforms
    if args.dataset == 'cifar10':
        dm = CIFAR10DataModule.from_argparse_args(args)
        dm.train_transforms = SimCLRTrainDataTransform(32)
        dm.val_transforms = SimCLREvalDataTransform(32)
        args.num_samples = dm.num_samples
    elif args.dataset == 'stl10':
        dm = STL10DataModule.from_argparse_args(args)
        dm.train_dataloader = dm.train_dataloader_mixed
        dm.val_dataloader = dm.val_dataloader_mixed
        args.num_samples = dm.num_unlabeled_samples
        _, height, _ = dm.size()
        dm.train_transforms = SimCLRTrainDataTransform(height)
        dm.val_transforms = SimCLREvalDataTransform(height)
    elif args.dataset == 'imagenet2012':
        dm = ImagenetDataModule.from_argparse_args(args, image_size=196)
        _, height, _ = dm.size()
        dm.train_transforms = SimCLRTrainDataTransform(height)
        dm.val_transforms = SimCLREvalDataTransform(height)

    model = SimCLR(**args.__dict__)

    def move_batch(batch, device):
        # the online evaluator only consumes the first augmented view + label
        (view_one, _), target = batch
        return view_one.to(device), target.to(device)

    # finetune a linear probe in real-time during pretraining
    online_eval = SSLOnlineEvaluator(z_dim=2048 * 2 * 2, num_classes=dm.num_classes)
    online_eval.to_device = move_batch

    trainer = pl.Trainer.from_argparse_args(args, callbacks=[online_eval])
    trainer.fit(model, dm)
def train():
    """Train a VAE with an image-sampling callback on cifar10 or imagenet.

    Raises:
        ValueError: If ``--dataset`` is neither 'cifar10' nor 'imagenet'.
    """
    parser = ArgumentParser()
    parser.add_argument('--gpus', type=int, default=None)
    parser.add_argument('--dataset', type=str, default='cifar10')
    args = parser.parse_args()

    # fix: the second check was an independent `if`; any other dataset value
    # fell through and crashed later with NameError on `dataset`
    if args.dataset == 'cifar10':
        dataset = CIFAR10DataModule('.')
    elif args.dataset == 'imagenet':
        dataset = ImagenetDataModule('.')
    else:
        raise ValueError(f'unsupported dataset: {args.dataset}')

    sampler = ImageSampler()
    vae = VAE()
    trainer = pl.Trainer(gpus=args.gpus, max_epochs=20, callbacks=[sampler])
    trainer.fit(vae, dataset)
def cli_main():
    """Train ImageGPT on fashion_mnist or imagenet128.

    Raises:
        ValueError: If ``--dataset`` is neither 'fashion_mnist' nor 'imagenet128'.
    """
    parser = ArgumentParser()

    # trainer args
    parser = pl.Trainer.add_argparse_args(parser)

    # model args
    parser = ImageGPT.add_model_specific_args(parser)
    args = parser.parse_args()

    if args.dataset == "fashion_mnist":
        datamodule = FashionMNISTDataModule.from_argparse_args(args)
    elif args.dataset == "imagenet128":
        datamodule = ImagenetDataModule.from_argparse_args(args)
    else:
        # fix: previously fell through with `datamodule` unbound -> NameError
        raise ValueError(f"unsupported dataset: {args.dataset}")

    model = ImageGPT(**args.__dict__, datamodule=datamodule)

    trainer = pl.Trainer.from_argparse_args(args)
    trainer.fit(model)
# NOTE(review): this chunk is truncated — it begins mid-method (the sampling
# callback whose `def` is outside this view) and ends inside an unfinished
# `Trainer.from_argparse_args(` call. Only comments are changed here.
images = pl_module(z)
grid = torchvision.utils.make_grid(images)
trainer.logger.experiment.add_image('gan_images', grid, global_step=trainer.global_step)


# todo: convert to CLI func and add test
if __name__ == '__main__':
    from pl_bolts.datamodules import ImagenetDataModule

    parser = ArgumentParser()
    parser = Trainer.add_argparse_args(parser)
    parser = GAN.add_model_specific_args(parser)
    parser = ImagenetDataModule.add_argparse_args(parser)
    args = parser.parse_args()

    # default is mnist
    datamodule = None
    if args.dataset == 'imagenet2012':
        datamodule = ImagenetDataModule.from_argparse_args(args)
    elif args.dataset == 'stl10':
        datamodule = STL10DataModule.from_argparse_args(args)

    gan = GAN(**vars(args), datamodule=datamodule)
    callbacks = [ImageGenerator(), LatentDimInterpolator()]

    # no val loop... thus we condition on loss and always save the last
    checkpoint_cb = ModelCheckpoint(monitor='loss', save_last=True)

    trainer = Trainer.from_argparse_args(args,
def cli_main():  # pragma: no cover
    """Fine-tune a SwAV checkpoint on STL10 or ImageNet, then run the test loop."""
    from pl_bolts.datamodules import ImagenetDataModule, STL10DataModule

    seed_everything(1234)

    parser = ArgumentParser()
    parser.add_argument("--dataset", type=str, help="stl10, imagenet", default="stl10")
    parser.add_argument("--ckpt_path", type=str, help="path to ckpt")
    parser.add_argument("--data_dir", type=str, help="path to dataset", default=os.getcwd())

    parser.add_argument("--batch_size", default=64, type=int, help="batch size per gpu")
    parser.add_argument("--num_workers", default=8, type=int, help="num of workers per GPU")
    parser.add_argument("--gpus", default=4, type=int, help="number of GPUs")
    parser.add_argument("--num_epochs", default=100, type=int, help="number of epochs")

    # fine-tuner params
    parser.add_argument("--in_features", type=int, default=2048)
    parser.add_argument("--dropout", type=float, default=0.0)
    parser.add_argument("--learning_rate", type=float, default=0.3)
    parser.add_argument("--weight_decay", type=float, default=1e-6)
    # NOTE(review): argparse `type=bool` parses any non-empty string (even
    # "False") as True — confirm this is intended before passing --nesterov
    parser.add_argument("--nesterov", type=bool, default=False)
    parser.add_argument("--scheduler_type", type=str, default="cosine")
    parser.add_argument("--gamma", type=float, default=0.1)
    parser.add_argument("--final_lr", type=float, default=0.0)

    args = parser.parse_args()

    if args.dataset == "stl10":
        dm = STL10DataModule(data_dir=args.data_dir, batch_size=args.batch_size, num_workers=args.num_workers)

        # fine-tuning uses the *labeled* STL10 splits
        dm.train_dataloader = dm.train_dataloader_labeled
        dm.val_dataloader = dm.val_dataloader_labeled
        args.num_samples = 0

        dm.train_transforms = SwAVFinetuneTransform(
            normalize=stl10_normalization(), input_height=dm.size()[-1], eval_transform=False)
        dm.val_transforms = SwAVFinetuneTransform(
            normalize=stl10_normalization(), input_height=dm.size()[-1], eval_transform=True)
        dm.test_transforms = SwAVFinetuneTransform(
            normalize=stl10_normalization(), input_height=dm.size()[-1], eval_transform=True)

        args.maxpool1 = False
        args.first_conv = True
    elif args.dataset == "imagenet":
        dm = ImagenetDataModule(data_dir=args.data_dir, batch_size=args.batch_size, num_workers=args.num_workers)

        dm.train_transforms = SwAVFinetuneTransform(
            normalize=imagenet_normalization(), input_height=dm.size()[-1], eval_transform=False)
        dm.val_transforms = SwAVFinetuneTransform(
            normalize=imagenet_normalization(), input_height=dm.size()[-1], eval_transform=True)
        dm.test_transforms = SwAVFinetuneTransform(
            normalize=imagenet_normalization(), input_height=dm.size()[-1], eval_transform=True)

        # NOTE(review): num_samples=1 here (vs 0 for stl10) looks like a
        # placeholder — confirm SwAV ignores it when loading from a checkpoint
        args.num_samples = 1
        args.maxpool1 = True
        args.first_conv = True
    else:
        raise NotImplementedError("other datasets have not been implemented till now")

    # the constructor args only shape the architecture; the actual weights
    # come from the checkpoint (strict=False tolerates mismatched keys)
    backbone = SwAV(
        gpus=args.gpus,
        nodes=1,
        num_samples=args.num_samples,
        batch_size=args.batch_size,
        maxpool1=args.maxpool1,
        first_conv=args.first_conv,
        dataset=args.dataset,
    ).load_from_checkpoint(args.ckpt_path, strict=False)

    tuner = SSLFineTuner(
        backbone,
        in_features=args.in_features,
        num_classes=dm.num_classes,
        epochs=args.num_epochs,
        hidden_dim=None,
        dropout=args.dropout,
        learning_rate=args.learning_rate,
        weight_decay=args.weight_decay,
        nesterov=args.nesterov,
        scheduler_type=args.scheduler_type,
        gamma=args.gamma,
        final_lr=args.final_lr,
    )

    trainer = Trainer(
        gpus=args.gpus,
        num_nodes=1,
        precision=16,
        max_epochs=args.num_epochs,
        distributed_backend="ddp",
        sync_batchnorm=True if args.gpus > 1 else False,
    )

    trainer.fit(tuner, dm)
    trainer.test(datamodule=dm)
def __set_pretrained_dims(self, pretrained):
    """Swap in the ImageNet datamodule and sync the input dims for pretrained weights.

    No-op unless ``pretrained == 'imagenet2012'``.
    """
    if pretrained == 'imagenet2012':
        self.datamodule = ImagenetDataModule(data_dir=self.hparams.data_dir)
        # datamodule.size() unpacks as (channels, height, width)
        (self.hparams.input_channels, self.hparams.input_height,
         self.hparams.input_width) = self.datamodule.size()
# NOTE(review): truncated chunk — the add_argument/return lines below are the
# tail of an `add_model_specific_args`-style method whose `def` is outside
# this view. Only comments are changed here.
parser.add_argument('--batch_size', type=int, default=32)
parser.add_argument('--pretrained', type=str, default=None)
parser.add_argument('--data_dir', type=str, default=os.getcwd())
parser.add_argument('--learning_rate', type=float, default=1e-3)
return parser


if __name__ == '__main__':
    from pl_bolts.datamodules import ImagenetDataModule

    parser = ArgumentParser()
    parser.add_argument('--dataset', default='mnist', type=str)
    parser = Trainer.add_argparse_args(parser)
    parser = VAE.add_model_specific_args(parser)
    parser = ImagenetDataModule.add_argparse_args(parser)
    parser = MNISTDataModule.add_argparse_args(parser)
    args = parser.parse_args()

    # NOTE(review): the datamodule selection below is commented out — as it
    # stands, this entry point parses args and then does nothing
    # # if args.dataset == 'imagenet' or args.pretrained:
    # datamodule = ImagenetDataModule.from_argparse_args(args)
    # args.image_width = datamodule.size()[1]
    # args.image_height = datamodule.size()[2]
    # args.input_channels = datamodule.size()[0]
    # # elif args.dataset == 'mnist':
    # datamodule = MNISTDataModule.from_argparse_args(args)
    # args.image_width = datamodule.size()[1]
    # args.image_height = datamodule.size()[2]
    # args.input_channels = datamodule.size()[0]
# NOTE(review): truncated chunk — begins mid-function (the enclosing `def` and
# earlier parser setup are outside this view) and ends inside an unfinished
# `pl.Trainer.from_argparse_args(` call. Only comments are changed here.
parser.add_argument('--batch_size', default=256, type=int, help='Per device batch size.')
parser.add_argument(
    '--data_dir',
    default='./',
    type=str,
    help='Directory for pre-downloaded ImageNet or cache for CIFAR10.')
parser = pl.Trainer.add_argparse_args(parser)
args = parser.parse_args()

# Can be swapped to CIFAR10DataModule
dm = ImagenetDataModule(batch_size=args.batch_size,
                        data_dir=args.data_dir,
                        train_transforms=Transform(),
                        test_transforms=Transform(),
                        val_transforms=Transform())

model = BarlowTwins(lr=0.2,
                    weight_decay=1e-6,
                    lambd=0.0051,
                    projector=[8192, 8192, 8192],
                    scale_loss=0.024,
                    per_device_batch_size=args.batch_size)

trainer = pl.Trainer.from_argparse_args(
    args, max_epochs=1000, precision=16, accelerator='ddp',
def cli_main():
    """Pretrain SimCLR on stl10, cifar10 or imagenet with an optional online probe."""
    from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator
    from pl_bolts.datamodules import CIFAR10DataModule, ImagenetDataModule, STL10DataModule
    from pl_bolts.models.self_supervised.simclr.transforms import SimCLREvalDataTransform, SimCLRTrainDataTransform

    parser = ArgumentParser()

    # model args
    parser = SimCLR.add_model_specific_args(parser)
    parser = pl.Trainer.add_argparse_args(parser)
    args = parser.parse_args()

    # each branch builds the datamodule and overrides dataset-specific
    # hyperparameters directly on `args` before they reach the model
    if args.dataset == 'stl10':
        dm = STL10DataModule(data_dir=args.data_dir, batch_size=args.batch_size, num_workers=args.num_workers)

        # pretraining uses the mixed (labeled + unlabeled) STL10 splits
        dm.train_dataloader = dm.train_dataloader_mixed
        dm.val_dataloader = dm.val_dataloader_mixed
        args.num_samples = dm.num_unlabeled_samples

        args.maxpool1 = False
        args.first_conv = True
        args.input_height = dm.size()[-1]

        normalization = stl10_normalization()

        args.gaussian_blur = True
        args.jitter_strength = 1.
    elif args.dataset == 'cifar10':
        val_split = 5000
        # grow the val split so it covers at least one global batch
        if args.num_nodes * args.gpus * args.batch_size > val_split:
            val_split = args.num_nodes * args.gpus * args.batch_size

        dm = CIFAR10DataModule(data_dir=args.data_dir,
                               batch_size=args.batch_size,
                               num_workers=args.num_workers,
                               val_split=val_split)

        args.num_samples = dm.num_samples

        args.maxpool1 = False
        args.first_conv = False
        args.input_height = dm.size()[-1]
        args.temperature = 0.5

        normalization = cifar10_normalization()

        args.gaussian_blur = False
        args.jitter_strength = 0.5
    elif args.dataset == 'imagenet':
        args.maxpool1 = True
        args.first_conv = True
        normalization = imagenet_normalization()

        args.gaussian_blur = True
        args.jitter_strength = 1.

        # hard-coded recipe for the reference ImageNet run (overrides CLI values)
        args.batch_size = 64
        args.num_nodes = 8
        args.gpus = 8  # per-node
        args.max_epochs = 800

        args.optimizer = 'sgd'
        args.lars_wrapper = True
        args.learning_rate = 4.8
        args.final_lr = 0.0048
        args.start_lr = 0.3
        args.online_ft = True

        dm = ImagenetDataModule(data_dir=args.data_dir, batch_size=args.batch_size, num_workers=args.num_workers)

        args.num_samples = dm.num_samples
        args.input_height = dm.size()[-1]
    else:
        raise NotImplementedError("other datasets have not been implemented till now")

    dm.train_transforms = SimCLRTrainDataTransform(
        input_height=args.input_height,
        gaussian_blur=args.gaussian_blur,
        jitter_strength=args.jitter_strength,
        normalize=normalization,
    )

    dm.val_transforms = SimCLREvalDataTransform(
        input_height=args.input_height,
        gaussian_blur=args.gaussian_blur,
        jitter_strength=args.jitter_strength,
        normalize=normalization,
    )

    model = SimCLR(**args.__dict__)

    online_evaluator = None
    if args.online_ft:
        # online eval
        online_evaluator = SSLOnlineEvaluator(drop_p=0.,
                                              hidden_dim=None,
                                              z_dim=args.hidden_mlp,
                                              num_classes=dm.num_classes,
                                              dataset=args.dataset)

    model_checkpoint = ModelCheckpoint(save_last=True, save_top_k=1, monitor='val_loss')
    callbacks = [model_checkpoint, online_evaluator] if args.online_ft else [model_checkpoint]

    trainer = pl.Trainer.from_argparse_args(
        args,
        sync_batchnorm=True if args.gpus > 1 else False,
        callbacks=callbacks,
    )

    trainer.fit(model, datamodule=dm)
# NOTE(review): truncated chunk — this is the continuation of a dataset-dispatch
# `if` whose opening branch (and the enclosing `def`) are outside this view.
# Only comments are changed here.
    dm.train_transforms = SimCLRTrainDataTransform(32)
    dm.val_transforms = SimCLREvalDataTransform(32)
    args.num_samples = dm.num_samples
elif args.dataset == 'stl10':
    dm = STL10DataModule.from_argparse_args(args)
    dm.train_dataloader = dm.train_dataloader_mixed
    dm.val_dataloader = dm.val_dataloader_mixed
    args.num_samples = dm.num_unlabeled_samples
    (c, h, w) = dm.size()
    dm.train_transforms = SimCLRTrainDataTransform(h)
    dm.val_transforms = SimCLREvalDataTransform(h)
elif args.dataset == 'imagenet2012':
    dm = ImagenetDataModule.from_argparse_args(args, image_size=196)
    (c, h, w) = dm.size()
    dm.train_transforms = SimCLRTrainDataTransform(h)
    dm.val_transforms = SimCLREvalDataTransform(h)

model = SimCLR(**args.__dict__)


# finetune in real-time
def to_device(batch, device):
    (x1, x2), y = batch
    x1 = x1.to(device)
    y = y.to(device)
    return x1, y


online_eval = SSLOnlineEvaluator(z_dim=2048 * 2 * 2, num_classes=dm.num_classes)
class VAE(pl.LightningModule):
    def __init__(self,
                 hidden_dim: int = 128,
                 latent_dim: int = 32,
                 input_channels: int = 3,
                 input_width: int = 224,
                 input_height: int = 224,
                 batch_size: int = 32,
                 learning_rate: float = 0.001,
                 data_dir: str = '.',
                 datamodule: pl.LightningDataModule = None,
                 num_workers: int = 8,
                 pretrained: str = None,
                 **kwargs):
        """
        Standard VAE with Gaussian Prior and approx posterior.

        Model is available pretrained on different datasets:

        Example::

            # not pretrained
            vae = VAE()

            # pretrained on imagenet
            vae = VAE(pretrained='imagenet')

            # pretrained on cifar10
            vae = VAE(pretrained='cifar10')

        Args:
            hidden_dim: encoder and decoder hidden dims
            latent_dim: latent code dim
            input_channels: num of channels of the input image.
            input_width: image input width
            input_height: image input height
            batch_size: the batch size
            learning_rate: the learning rate
            data_dir: the directory to store data
            datamodule: The Lightning DataModule
            num_workers: num of dataloader workers
            pretrained: Load weights pretrained on a dataset
        """
        super().__init__()
        self.save_hyperparameters()

        self.datamodule = datamodule
        # pretrained weights may dictate the datamodule and the input dims
        self.__set_pretrained_dims(pretrained)

        # use mnist as the default module
        self._set_default_datamodule(datamodule)

        # init actual model
        self.__init_system()

        if pretrained:
            self.load_pretrained(pretrained)

    def __init_system(self):
        # build the encoder/decoder pair from the (possibly updated) hparams
        self.encoder = self.init_encoder()
        self.decoder = self.init_decoder()

    def __set_pretrained_dims(self, pretrained):
        # imagenet2012 weights fix the datamodule and input dims up front
        if pretrained == 'imagenet2012':
            self.datamodule = ImagenetDataModule(data_dir=self.hparams.data_dir)
            # size() unpacks as (channels, height, width)
            (self.hparams.input_channels, self.hparams.input_height,
             self.hparams.input_width) = self.datamodule.size()

    def _set_default_datamodule(self, datamodule):
        # link default data
        if datamodule is None:
            datamodule = MNISTDataModule(data_dir=self.hparams.data_dir,
                                         num_workers=self.hparams.num_workers,
                                         normalize=False)
        self.datamodule = datamodule
        self.img_dim = self.datamodule.size()
        # keep the hparams in sync with whichever datamodule was linked
        (self.hparams.input_channels, self.hparams.input_height,
         self.hparams.input_width) = self.img_dim

    def load_pretrained(self, pretrained):
        """Load published weights for `pretrained`; unknown names are ignored."""
        available_weights = {'imagenet2012'}

        if pretrained in available_weights:
            weights_name = f'vae-{pretrained}'
            # NOTE(review): this calls the module-level `load_pretrained`
            # helper, which this method shadows by name — confirm the
            # corresponding import exists at module scope
            load_pretrained(self, weights_name)

    def init_encoder(self):
        encoder = Encoder(self.hparams.hidden_dim, self.hparams.latent_dim,
                          self.hparams.input_channels, self.hparams.input_width,
                          self.hparams.input_height)
        return encoder

    def init_decoder(self):
        decoder = Decoder(self.hparams.hidden_dim, self.hparams.latent_dim,
                          self.hparams.input_width, self.hparams.input_height,
                          self.hparams.input_channels)
        return decoder

    def get_prior(self, z_mu, z_std):
        # Prior ~ Normal(0,1)
        P = distributions.normal.Normal(loc=torch.zeros_like(z_mu),
                                        scale=torch.ones_like(z_std))
        return P

    def get_approx_posterior(self, z_mu, z_std):
        # Approx Posterior ~ Normal(mu, sigma)
        Q = distributions.normal.Normal(loc=z_mu, scale=z_std)
        return Q

    def elbo_loss(self, x, P, Q, num_samples):
        """Compute the VAE loss (reconstruction + KL) via Monte-Carlo sampling.

        Returns ``(loss, recon_loss, kl_div, pxz)`` with losses averaged over the batch.
        """
        z = Q.rsample()

        # ----------------------
        # KL divergence loss (using monte carlo sampling)
        # ----------------------
        log_qz = Q.log_prob(z)
        log_pz = P.log_prob(z)

        # (batch, num_samples, z_dim) -> (batch, num_samples)
        kl_div = (log_qz - log_pz).sum(dim=2)

        # we used monte carlo sampling to estimate. average across samples
        # kl_div = kl_div.mean(-1)

        # ----------------------
        # Reconstruction loss
        # ----------------------
        # flatten samples into the batch dim so the decoder sees (batch * num_samples, z_dim)
        z = z.view(-1, z.size(-1)).contiguous()
        pxz = self.decoder(z)
        pxz = pxz.view(-1, num_samples, pxz.size(-1))

        # repeat x per sample so it lines up with the decoded batch
        x = shaping.tile(x.unsqueeze(1), 1, num_samples)
        pxz = torch.sigmoid(pxz)
        recon_loss = F.binary_cross_entropy(pxz, x, reduction='none')

        # sum across dimensions because sum of log probabilities of iid univariate gaussians is the same as
        # multivariate gaussian
        recon_loss = recon_loss.sum(dim=-1)

        # we used monte carlo sampling to estimate. average across samples
        # recon_loss = recon_loss.mean(-1)

        # ELBO = reconstruction + KL
        loss = recon_loss + kl_div

        # average over batch
        loss = loss.mean()
        recon_loss = recon_loss.mean()
        kl_div = kl_div.mean()

        return loss, recon_loss, kl_div, pxz

    def forward(self, z):
        # decoder-only forward: map latent codes to reconstructions
        return self.decoder(z)

    def _run_step(self, batch):
        """Shared encode->sample->decode step used by train/val/test."""
        x, _ = batch
        z_mu, z_log_var = self.encoder(x)

        # we're estimating the KL divergence using sampling
        num_samples = 32

        # expand dims to sample all at once
        # (batch, z_dim) -> (batch, num_samples, z_dim)
        z_mu = z_mu.unsqueeze(1)
        z_mu = shaping.tile(z_mu, 1, num_samples)

        # (batch, z_dim) -> (batch, num_samples, z_dim)
        z_log_var = z_log_var.unsqueeze(1)
        z_log_var = shaping.tile(z_log_var, 1, num_samples)

        # convert to std
        z_std = torch.exp(z_log_var / 2)

        P = self.get_prior(z_mu, z_std)
        Q = self.get_approx_posterior(z_mu, z_std)

        x = x.view(x.size(0), -1)

        loss, recon_loss, kl_div, pxz = self.elbo_loss(x, P, Q, num_samples)

        return loss, recon_loss, kl_div, pxz

    def training_step(self, batch, batch_idx):
        loss, recon_loss, kl_div, pxz = self._run_step(batch)

        result = pl.TrainResult(loss)
        result.log_dict({
            'train_elbo_loss': loss,
            'train_recon_loss': recon_loss,
            'train_kl_loss': kl_div
        })

        return result

    def validation_step(self, batch, batch_idx):
        loss, recon_loss, kl_div, pxz = self._run_step(batch)

        # checkpoint on the validation ELBO loss
        result = pl.EvalResult(loss, checkpoint_on=loss)
        result.log_dict({
            'val_loss': loss,
            'val_recon_loss': recon_loss,
            'val_kl_div': kl_div,
        })

        return result

    def test_step(self, batch, batch_idx):
        loss, recon_loss, kl_div, pxz = self._run_step(batch)

        result = pl.EvalResult(loss)
        result.log_dict({
            'test_loss': loss,
            'test_recon_loss': recon_loss,
            'test_kl_div': kl_div,
        })

        return result

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)

    @staticmethod
    def add_model_specific_args(parent_parser):
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument(
            '--hidden_dim',
            type=int,
            default=128,
            help='itermediate layers dimension before embedding for default encoder/decoder')
        parser.add_argument('--latent_dim',
                            type=int,
                            default=4,
                            help='dimension of latent variables z')
        parser.add_argument(
            '--input_width',
            type=int,
            default=224,
            help='input width (used Imagenet downsampled size)')
        parser.add_argument(
            '--input_height',
            type=int,
            default=224,
            help='input width (used Imagenet downsampled size)')
        parser.add_argument('--input_channels',
                            type=int,
                            default=3,
                            help='number of input channels')
        parser.add_argument('--batch_size', type=int, default=64)
        parser.add_argument('--pretrained', type=str, default=None)
        parser.add_argument('--data_dir', type=str, default=os.getcwd())
        parser.add_argument('--num_workers',
                            type=int,
                            default=8,
                            help="num dataloader workers")
        parser.add_argument('--learning_rate', type=float, default=1e-3)
        return parser
# NOTE(review): truncated chunk — the add_argument/return lines below are the
# tail of ImageGPT's `add_model_specific_args`; its `def` is outside this view.
# Only comments are changed here.
parser.add_argument("--classify", action="store_true", default=False)
parser.add_argument("--batch_size", type=int, default=64)
parser.add_argument("--learning_rate", type=float, default=1e-2)
parser.add_argument("--steps", type=int, default=25_000)
return parser


# todo: convert to CLI func and add test
if __name__ == "__main__":
    from argparse import ArgumentParser

    parser = ArgumentParser()

    # trainer args
    parser = pl.Trainer.add_argparse_args(parser)

    # model args
    parser = ImageGPT.add_model_specific_args(parser)
    args = parser.parse_args()

    if args.dataset == "fashion_mnist":
        datamodule = FashionMNISTDataModule.from_argparse_args(args)
    elif args.dataset == "imagenet128":
        datamodule = ImagenetDataModule.from_argparse_args(args)

    model = ImageGPT(**args.__dict__, datamodule=datamodule)

    trainer = pl.Trainer.from_argparse_args(args)
    trainer.fit(model)
def cli_main():
    """Pretrain SwAV on stl10, cifar10 or imagenet with an optional online probe."""
    from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator
    from pl_bolts.datamodules import CIFAR10DataModule, ImagenetDataModule, STL10DataModule
    from pl_bolts.models.self_supervised.swav.transforms import SwAVEvalDataTransform, SwAVTrainDataTransform

    parser = ArgumentParser()

    # model args
    parser = SwAV.add_model_specific_args(parser)
    args = parser.parse_args()

    # each branch builds the datamodule and overrides dataset-specific
    # hyperparameters directly on `args` before they reach the model
    if args.dataset == 'stl10':
        dm = STL10DataModule(data_dir=args.data_dir, batch_size=args.batch_size, num_workers=args.num_workers)

        # pretraining uses the mixed (labeled + unlabeled) STL10 splits
        dm.train_dataloader = dm.train_dataloader_mixed
        dm.val_dataloader = dm.val_dataloader_mixed
        args.num_samples = dm.num_unlabeled_samples

        args.maxpool1 = False

        normalization = stl10_normalization()
    elif args.dataset == 'cifar10':
        # NOTE(review): batch_size=2 / num_workers=0 look like debug settings
        # left in place — confirm before training cifar10 for real
        args.batch_size = 2
        args.num_workers = 0

        dm = CIFAR10DataModule(data_dir=args.data_dir, batch_size=args.batch_size, num_workers=args.num_workers)

        args.num_samples = dm.num_samples

        args.maxpool1 = False
        args.first_conv = False
        normalization = cifar10_normalization()

        # cifar10 specific params
        args.size_crops = [32, 16]
        args.nmb_crops = [2, 1]
        args.gaussian_blur = False
    elif args.dataset == 'imagenet':
        args.maxpool1 = True
        args.first_conv = True
        normalization = imagenet_normalization()

        args.size_crops = [224, 96]
        args.nmb_crops = [2, 6]
        args.min_scale_crops = [0.14, 0.05]
        args.max_scale_crops = [1., 0.14]
        args.gaussian_blur = True
        args.jitter_strength = 1.

        # hard-coded recipe for the reference ImageNet run (overrides CLI values)
        args.batch_size = 64
        args.num_nodes = 8
        args.gpus = 8  # per-node
        args.max_epochs = 800

        args.optimizer = 'lars'
        args.learning_rate = 4.8
        args.final_lr = 0.0048
        args.start_lr = 0.3
        args.nmb_prototypes = 3000
        args.online_ft = True

        dm = ImagenetDataModule(data_dir=args.data_dir, batch_size=args.batch_size, num_workers=args.num_workers)

        args.num_samples = dm.num_samples
        args.input_height = dm.size()[-1]
    else:
        raise NotImplementedError("other datasets have not been implemented till now")

    dm.train_transforms = SwAVTrainDataTransform(
        normalize=normalization,
        size_crops=args.size_crops,
        nmb_crops=args.nmb_crops,
        min_scale_crops=args.min_scale_crops,
        max_scale_crops=args.max_scale_crops,
        gaussian_blur=args.gaussian_blur,
        jitter_strength=args.jitter_strength)

    dm.val_transforms = SwAVEvalDataTransform(
        normalize=normalization,
        size_crops=args.size_crops,
        nmb_crops=args.nmb_crops,
        min_scale_crops=args.min_scale_crops,
        max_scale_crops=args.max_scale_crops,
        gaussian_blur=args.gaussian_blur,
        jitter_strength=args.jitter_strength)

    # swav model init
    model = SwAV(**args.__dict__)

    online_evaluator = None
    if args.online_ft:
        # online eval
        online_evaluator = SSLOnlineEvaluator(
            drop_p=0.,
            hidden_dim=None,
            z_dim=args.hidden_mlp,
            num_classes=dm.num_classes,
            dataset=args.dataset,
        )

    lr_monitor = LearningRateMonitor(logging_interval="step")
    model_checkpoint = ModelCheckpoint(save_last=True, save_top_k=1, monitor='val_loss')
    callbacks = [model_checkpoint, online_evaluator] if args.online_ft else [model_checkpoint]
    callbacks.append(lr_monitor)

    trainer = Trainer(
        max_epochs=args.max_epochs,
        max_steps=None if args.max_steps == -1 else args.max_steps,
        gpus=args.gpus,
        num_nodes=args.num_nodes,
        distributed_backend='ddp' if args.gpus > 1 else None,
        sync_batchnorm=True if args.gpus > 1 else False,
        precision=32 if args.fp32 else 16,
        callbacks=callbacks,
        fast_dev_run=args.fast_dev_run)

    trainer.fit(model, datamodule=dm)
def cli_main():
    """Pretrain SimSiam on stl10, cifar10/100 or imagenet with an optional online probe."""
    from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator
    from pl_bolts.datamodules import CIFAR10DataModule, ImagenetDataModule, STL10DataModule
    from pl_bolts.models.self_supervised.simclr import SimCLREvalDataTransform, SimCLRTrainDataTransform

    seed_everything(1234)

    parser = ArgumentParser()

    # trainer args
    parser = pl.Trainer.add_argparse_args(parser)

    # model args
    parser = SimSiam.add_model_specific_args(parser)
    args = parser.parse_args()

    # pick data
    dm = None

    # init datamodule; each branch also overrides dataset-specific
    # hyperparameters directly on `args` before they reach the model
    if args.dataset == "stl10":
        dm = STL10DataModule(data_dir=args.data_dir, batch_size=args.batch_size, num_workers=args.num_workers)

        # pretraining uses the mixed (labeled + unlabeled) STL10 splits
        dm.train_dataloader = dm.train_dataloader_mixed
        dm.val_dataloader = dm.val_dataloader_mixed
        args.num_samples = dm.num_unlabeled_samples

        args.maxpool1 = False
        args.first_conv = True
        args.input_height = dm.size()[-1]

        normalization = stl10_normalization()

        args.gaussian_blur = True
        args.jitter_strength = 1.0
    elif args.dataset == "cifar10":
        val_split = 5000
        # grow the val split so it covers at least one global batch
        if args.nodes * args.gpus * args.batch_size > val_split:
            val_split = args.nodes * args.gpus * args.batch_size

        dm = CIFAR10DataModule(
            data_dir=args.data_dir,
            batch_size=args.batch_size,
            num_workers=args.num_workers,
            val_split=val_split,
        )

        args.num_samples = dm.num_samples

        args.maxpool1 = False
        args.first_conv = False
        args.input_height = dm.size()[-1]
        args.temperature = 0.5

        normalization = cifar10_normalization()

        args.gaussian_blur = False
        args.jitter_strength = 0.5
    elif args.dataset == "cifar100":
        val_split = 5000
        if args.nodes * args.gpus * args.batch_size > val_split:
            val_split = args.nodes * args.gpus * args.batch_size

        dm = CIFAR100DataModule(
            data_dir=args.data_dir,
            batch_size=args.batch_size,
            num_workers=args.num_workers,
            val_split=val_split,
        )

        args.num_samples = dm.num_samples

        args.maxpool1 = False
        args.first_conv = False
        args.input_height = dm.size()[-1]
        args.temperature = 0.5

        # ((0.5071, 0.4866, 0.4409), (0.2009, 0.1984, 0.2023))
        normalization = transforms.Normalize(
            mean=(0.5071, 0.4866, 0.4409),
            std=(0.2009, 0.1984, 0.2023),
        )

        args.gaussian_blur = False
        args.jitter_strength = 0.5
    elif args.dataset == "imagenet":
        args.maxpool1 = True
        args.first_conv = True
        normalization = imagenet_normalization()

        args.gaussian_blur = True
        args.jitter_strength = 1.0

        # hard-coded recipe for the reference ImageNet run (overrides CLI values)
        args.batch_size = 64
        args.nodes = 8
        args.gpus = 8  # per-node
        args.max_epochs = 800

        args.optimizer = "sgd"
        args.lars_wrapper = True
        args.learning_rate = 4.8
        args.final_lr = 0.0048
        args.start_lr = 0.3
        args.online_ft = True

        dm = ImagenetDataModule(data_dir=args.data_dir, batch_size=args.batch_size, num_workers=args.num_workers)

        args.num_samples = dm.num_samples
        args.input_height = dm.size()[-1]
    else:
        raise NotImplementedError("other datasets have not been implemented till now")

    # SimSiam reuses the SimCLR augmentation pipeline
    dm.train_transforms = SimCLRTrainDataTransform(
        input_height=args.input_height,
        gaussian_blur=args.gaussian_blur,
        jitter_strength=args.jitter_strength,
        normalize=normalization,
    )

    dm.val_transforms = SimCLREvalDataTransform(
        input_height=args.input_height,
        gaussian_blur=args.gaussian_blur,
        jitter_strength=args.jitter_strength,
        normalize=normalization,
    )

    model = SimSiam(**args.__dict__)

    # finetune in real-time
    online_evaluator = None
    if args.online_ft:
        # online eval
        online_evaluator = SSLOnlineEvaluator(
            drop_p=0.0,
            hidden_dim=None,
            z_dim=args.hidden_mlp,
            num_classes=dm.num_classes,
            dataset=args.dataset,
        )

    trainer = pl.Trainer(
        max_epochs=args.max_epochs,
        max_steps=None if args.max_steps == -1 else args.max_steps,
        gpus=args.gpus,
        num_nodes=args.nodes,
        distributed_backend="ddp" if args.gpus > 1 else None,
        sync_batchnorm=True if args.gpus > 1 else False,
        precision=32 if args.fp32 else 16,
        callbacks=[online_evaluator] if args.online_ft else None,
        fast_dev_run=args.fast_dev_run,
    )

    trainer.fit(model, dm)