Example #1
# imports the snippet relies on (not shown in the original excerpt)
import os
from argparse import ArgumentParser

from pytorch_lightning import Trainer, seed_everything

from pl_bolts.models.self_supervised import SSLFineTuner, SwAV
from pl_bolts.models.self_supervised.swav.transforms import SwAVFinetuneTransform
from pl_bolts.transforms.dataset_normalizations import imagenet_normalization, stl10_normalization


def cli_main():  # pragma: no cover
    from pl_bolts.datamodules import ImagenetDataModule, STL10DataModule

    seed_everything(1234)

    parser = ArgumentParser()
    parser.add_argument("--dataset",
                        type=str,
                        help="stl10, imagenet",
                        default="stl10")
    parser.add_argument("--ckpt_path", type=str, help="path to ckpt")
    parser.add_argument("--data_dir",
                        type=str,
                        help="path to dataset",
                        default=os.getcwd())

    parser.add_argument("--batch_size",
                        default=64,
                        type=int,
                        help="batch size per gpu")
    parser.add_argument("--num_workers",
                        default=8,
                        type=int,
                        help="num of workers per GPU")
    parser.add_argument("--gpus", default=4, type=int, help="number of GPUs")
    parser.add_argument("--num_epochs",
                        default=100,
                        type=int,
                        help="number of epochs")

    # fine-tuner params
    parser.add_argument("--in_features", type=int, default=2048)
    parser.add_argument("--dropout", type=float, default=0.0)
    parser.add_argument("--learning_rate", type=float, default=0.3)
    parser.add_argument("--weight_decay", type=float, default=1e-6)
    parser.add_argument("--nesterov", type=bool, default=False)
    parser.add_argument("--scheduler_type", type=str, default="cosine")
    parser.add_argument("--gamma", type=float, default=0.1)
    parser.add_argument("--final_lr", type=float, default=0.0)

    args = parser.parse_args()

    if args.dataset == "stl10":
        dm = STL10DataModule(data_dir=args.data_dir,
                             batch_size=args.batch_size,
                             num_workers=args.num_workers)

        dm.train_dataloader = dm.train_dataloader_labeled
        dm.val_dataloader = dm.val_dataloader_labeled
        args.num_samples = 0

        dm.train_transforms = SwAVFinetuneTransform(
            normalize=stl10_normalization(),
            input_height=dm.size()[-1],
            eval_transform=False)
        dm.val_transforms = SwAVFinetuneTransform(
            normalize=stl10_normalization(),
            input_height=dm.size()[-1],
            eval_transform=True)
        dm.test_transforms = SwAVFinetuneTransform(
            normalize=stl10_normalization(),
            input_height=dm.size()[-1],
            eval_transform=True)

        args.maxpool1 = False
        args.first_conv = True
    elif args.dataset == "imagenet":
        dm = ImagenetDataModule(data_dir=args.data_dir,
                                batch_size=args.batch_size,
                                num_workers=args.num_workers)

        dm.train_transforms = SwAVFinetuneTransform(
            normalize=imagenet_normalization(),
            input_height=dm.size()[-1],
            eval_transform=False)
        dm.val_transforms = SwAVFinetuneTransform(
            normalize=imagenet_normalization(),
            input_height=dm.size()[-1],
            eval_transform=True)

        dm.test_transforms = SwAVFinetuneTransform(
            normalize=imagenet_normalization(),
            input_height=dm.size()[-1],
            eval_transform=True)

        args.num_samples = 1
        args.maxpool1 = True
        args.first_conv = True
    else:
        raise NotImplementedError(
            "other datasets have not been implemented yet")

    # load_from_checkpoint is a classmethod: the instance constructed here is
    # discarded and the backbone is rebuilt from the checkpoint's hparams
    backbone = SwAV(
        gpus=args.gpus,
        nodes=1,
        num_samples=args.num_samples,
        batch_size=args.batch_size,
        maxpool1=args.maxpool1,
        first_conv=args.first_conv,
        dataset=args.dataset,
    ).load_from_checkpoint(args.ckpt_path, strict=False)

    tuner = SSLFineTuner(
        backbone,
        in_features=args.in_features,
        num_classes=dm.num_classes,
        epochs=args.num_epochs,
        hidden_dim=None,
        dropout=args.dropout,
        learning_rate=args.learning_rate,
        weight_decay=args.weight_decay,
        nesterov=args.nesterov,
        scheduler_type=args.scheduler_type,
        gamma=args.gamma,
        final_lr=args.final_lr,
    )

    trainer = Trainer(
        gpus=args.gpus,
        num_nodes=1,
        precision=16,
        max_epochs=args.num_epochs,
        distributed_backend="ddp",
        sync_batchnorm=args.gpus > 1,
    )

    trainer.fit(tuner, dm)
    trainer.test(datamodule=dm)
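
The parser in this example defines every flag it needs, so the entry point can be run directly; a minimal sketch of the guard (the script name is hypothetical):

if __name__ == "__main__":
    # e.g. python swav_finetuner.py --dataset stl10 \
    #        --ckpt_path /path/to/swav.ckpt --data_dir ./data --gpus 1
    cli_main()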
Example #2


# imports the snippet relies on (not shown in the original excerpt)
from argparse import ArgumentParser

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint

from pl_bolts.models.self_supervised import SwAV
from pl_bolts.transforms.dataset_normalizations import (
    cifar10_normalization, imagenet_normalization, stl10_normalization)


def cli_main():
    from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator
    from pl_bolts.datamodules import CIFAR10DataModule, ImagenetDataModule, STL10DataModule
    from pl_bolts.models.self_supervised.swav.transforms import SwAVEvalDataTransform, SwAVTrainDataTransform

    parser = ArgumentParser()

    # model args
    parser = SwAV.add_model_specific_args(parser)
    args = parser.parse_args()

    if args.dataset == 'stl10':
        dm = STL10DataModule(data_dir=args.data_dir,
                             batch_size=args.batch_size,
                             num_workers=args.num_workers)

        dm.train_dataloader = dm.train_dataloader_mixed
        dm.val_dataloader = dm.val_dataloader_mixed
        args.num_samples = dm.num_unlabeled_samples

        args.maxpool1 = False

        normalization = stl10_normalization()
    elif args.dataset == 'cifar10':
        # tiny settings; this branch looks tuned for smoke tests rather than full training
        args.batch_size = 2
        args.num_workers = 0

        dm = CIFAR10DataModule(data_dir=args.data_dir,
                               batch_size=args.batch_size,
                               num_workers=args.num_workers)

        args.num_samples = dm.num_samples

        args.maxpool1 = False
        args.first_conv = False

        normalization = cifar10_normalization()

        # cifar10 specific params
        args.size_crops = [32, 16]
        args.nmb_crops = [2, 1]
        args.gaussian_blur = False
    elif args.dataset == 'imagenet':
        args.maxpool1 = True
        args.first_conv = True
        normalization = imagenet_normalization()

        args.size_crops = [224, 96]
        args.nmb_crops = [2, 6]
        args.min_scale_crops = [0.14, 0.05]
        args.max_scale_crops = [1., 0.14]
        args.gaussian_blur = True
        args.jitter_strength = 1.

        args.batch_size = 64
        args.num_nodes = 8
        args.gpus = 8  # per-node
        args.max_epochs = 800

        args.optimizer = 'lars'
        args.learning_rate = 4.8
        args.final_lr = 0.0048
        args.start_lr = 0.3

        args.nmb_prototypes = 3000
        args.online_ft = True

        dm = ImagenetDataModule(data_dir=args.data_dir,
                                batch_size=args.batch_size,
                                num_workers=args.num_workers)

        args.num_samples = dm.num_samples
        args.input_height = dm.size()[-1]
    else:
        raise NotImplementedError(
            "other datasets have not been implemented yet")

    dm.train_transforms = SwAVTrainDataTransform(
        normalize=normalization,
        size_crops=args.size_crops,
        nmb_crops=args.nmb_crops,
        min_scale_crops=args.min_scale_crops,
        max_scale_crops=args.max_scale_crops,
        gaussian_blur=args.gaussian_blur,
        jitter_strength=args.jitter_strength)

    dm.val_transforms = SwAVEvalDataTransform(
        normalize=normalization,
        size_crops=args.size_crops,
        nmb_crops=args.nmb_crops,
        min_scale_crops=args.min_scale_crops,
        max_scale_crops=args.max_scale_crops,
        gaussian_blur=args.gaussian_blur,
        jitter_strength=args.jitter_strength)

    # swav model init
    model = SwAV(**args.__dict__)

    online_evaluator = None
    if args.online_ft:
        # online eval
        online_evaluator = SSLOnlineEvaluator(
            drop_p=0.,
            hidden_dim=None,
            z_dim=args.hidden_mlp,
            num_classes=dm.num_classes,
            dataset=args.dataset,
        )

    lr_monitor = LearningRateMonitor(logging_interval="step")
    model_checkpoint = ModelCheckpoint(save_last=True,
                                       save_top_k=1,
                                       monitor='val_loss')
    callbacks = [model_checkpoint, online_evaluator] if args.online_ft else [model_checkpoint]
    callbacks.append(lr_monitor)

    trainer = Trainer(
        max_epochs=args.max_epochs,
        max_steps=None if args.max_steps == -1 else args.max_steps,
        gpus=args.gpus,
        num_nodes=args.num_nodes,
        distributed_backend='ddp' if args.gpus > 1 else None,
        sync_batchnorm=args.gpus > 1,
        precision=32 if args.fp32 else 16,
        callbacks=callbacks,
        fast_dev_run=args.fast_dev_run)

    trainer.fit(model, datamodule=dm)
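
The multi-crop settings above pair element-wise: nmb_crops[i] random crops are taken at resolution size_crops[i], so the imagenet branch yields 2 global 224px views and 6 local 96px views per image. A minimal sketch of that pairing, independent of the bolts transform internals:

from torchvision import transforms

size_crops, nmb_crops = [224, 96], [2, 6]
min_scale_crops, max_scale_crops = [0.14, 0.05], [1.0, 0.14]

crop_ops = []
for size, n, lo, hi in zip(size_crops, nmb_crops, min_scale_crops, max_scale_crops):
    # n crops at this resolution, sampled from the given area-scale range
    crop_ops.extend([transforms.RandomResizedCrop(size, scale=(lo, hi))] * n)

# applying each op to one PIL image gives the 8 multi-crop views:
# views = [op(img) for op in crop_ops]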
Example #3


# imports the snippet relies on; the pl_bolts paths are assumed from the
# version this snippet was written against
import os
from argparse import ArgumentParser

import pytorch_lightning as pl
import torch
from torch import distributions
from torch.nn import functional as F

from pl_bolts.datamodules import ImagenetDataModule, MNISTDataModule
from pl_bolts.models.autoencoders.basic_vae.components import Decoder, Encoder
from pl_bolts.utils import shaping
from pl_bolts.utils.pretrained_weights import load_pretrained


class VAE(pl.LightningModule):
    def __init__(self,
                 hidden_dim: int = 128,
                 latent_dim: int = 32,
                 input_channels: int = 3,
                 input_width: int = 224,
                 input_height: int = 224,
                 batch_size: int = 32,
                 learning_rate: float = 0.001,
                 data_dir: str = '.',
                 datamodule: pl.LightningDataModule = None,
                 num_workers: int = 8,
                 pretrained: str = None,
                 **kwargs):
        """
        Standard VAE with Gaussian Prior and approx posterior.

        Model is available pretrained on different datasets:

        Example::

            # not pretrained
            vae = VAE()

            # pretrained on imagenet (currently the only available weights)
            vae = VAE(pretrained='imagenet2012')

        Args:

            hidden_dim: encoder and decoder hidden dims
            latent_dim: latent code dim
            input_channels: num of channels of the input image.
            input_width: image input width
            input_height: image input height
            batch_size: the batch size
            learning_rate" the learning rate
            data_dir: the directory to store data
            datamodule: The Lightning DataModule
            pretrained: Load weights pretrained on a dataset
        """
        super().__init__()
        self.save_hyperparameters()

        self.datamodule = datamodule
        self.__set_pretrained_dims(pretrained)

        # use mnist as the default module
        self._set_default_datamodule(datamodule)

        # init actual model
        self.__init_system()

        if pretrained:
            self.load_pretrained(pretrained)

    def __init_system(self):
        self.encoder = self.init_encoder()
        self.decoder = self.init_decoder()

    def __set_pretrained_dims(self, pretrained):
        if pretrained == 'imagenet2012':
            self.datamodule = ImagenetDataModule(
                data_dir=self.hparams.data_dir)
            (self.hparams.input_channels, self.hparams.input_height,
             self.hparams.input_width) = self.datamodule.size()

    def _set_default_datamodule(self, datamodule):
        # link default data; keep any datamodule already set by
        # __set_pretrained_dims, otherwise fall back to mnist
        if datamodule is None:
            datamodule = self.datamodule or MNISTDataModule(
                data_dir=self.hparams.data_dir,
                num_workers=self.hparams.num_workers,
                normalize=False)
        self.datamodule = datamodule
        self.img_dim = self.datamodule.size()

        (self.hparams.input_channels, self.hparams.input_height,
         self.hparams.input_width) = self.img_dim

    def load_pretrained(self, pretrained):
        available_weights = {'imagenet2012'}

        if pretrained in available_weights:
            weights_name = f'vae-{pretrained}'
            load_pretrained(self, weights_name)

    def init_encoder(self):
        encoder = Encoder(self.hparams.hidden_dim, self.hparams.latent_dim,
                          self.hparams.input_channels,
                          self.hparams.input_width, self.hparams.input_height)
        return encoder

    def init_decoder(self):
        decoder = Decoder(self.hparams.hidden_dim, self.hparams.latent_dim,
                          self.hparams.input_width, self.hparams.input_height,
                          self.hparams.input_channels)
        return decoder

    def get_prior(self, z_mu, z_std):
        # Prior ~ Normal(0,1)
        P = distributions.normal.Normal(loc=torch.zeros_like(z_mu),
                                        scale=torch.ones_like(z_std))
        return P

    def get_approx_posterior(self, z_mu, z_std):
        # Approx Posterior ~ Normal(mu, sigma)
        Q = distributions.normal.Normal(loc=z_mu, scale=z_std)
        return Q

    def elbo_loss(self, x, P, Q, num_samples):
        z = Q.rsample()

        # ----------------------
        # KL divergence loss (using monte carlo sampling)
        # ----------------------
        log_qz = Q.log_prob(z)
        log_pz = P.log_prob(z)

        # (batch, num_samples, z_dim) -> (batch, num_samples)
        kl_div = (log_qz - log_pz).sum(dim=2)

        # we used monte carlo sampling to estimate. average across samples
        # kl_div = kl_div.mean(-1)

        # ----------------------
        # Reconstruction loss
        # ----------------------
        z = z.view(-1, z.size(-1)).contiguous()
        pxz = self.decoder(z)

        pxz = pxz.view(-1, num_samples, pxz.size(-1))
        x = shaping.tile(x.unsqueeze(1), 1, num_samples)

        pxz = torch.sigmoid(pxz)
        recon_loss = F.binary_cross_entropy(pxz, x, reduction='none')

        # sum across dimensions because sum of log probabilities of iid univariate gaussians is the same as
        # multivariate gaussian
        recon_loss = recon_loss.sum(dim=-1)

        # we used monte carlo sampling to estimate. average across samples
        # recon_loss = recon_loss.mean(-1)

        # ELBO = reconstruction + KL
        loss = recon_loss + kl_div

        # average over the batch and the monte carlo samples
        loss = loss.mean()
        recon_loss = recon_loss.mean()
        kl_div = kl_div.mean()

        return loss, recon_loss, kl_div, pxz

    def forward(self, z):
        return self.decoder(z)

    def _run_step(self, batch):
        x, _ = batch
        z_mu, z_log_var = self.encoder(x)

        # we're estimating the KL divergence using sampling
        num_samples = 32

        # expand dims to sample all at once
        # (batch, z_dim) -> (batch, num_samples, z_dim)
        z_mu = z_mu.unsqueeze(1)
        z_mu = shaping.tile(z_mu, 1, num_samples)

        # (batch, z_dim) -> (batch, num_samples, z_dim)
        z_log_var = z_log_var.unsqueeze(1)
        z_log_var = shaping.tile(z_log_var, 1, num_samples)

        # convert log-variance to std: std = exp(log_var / 2)
        z_std = torch.exp(z_log_var / 2)

        P = self.get_prior(z_mu, z_std)
        Q = self.get_approx_posterior(z_mu, z_std)

        x = x.view(x.size(0), -1)

        loss, recon_loss, kl_div, pxz = self.elbo_loss(x, P, Q, num_samples)

        return loss, recon_loss, kl_div, pxz

    def training_step(self, batch, batch_idx):
        loss, recon_loss, kl_div, pxz = self._run_step(batch)
        result = pl.TrainResult(loss)
        result.log_dict({
            'train_elbo_loss': loss,
            'train_recon_loss': recon_loss,
            'train_kl_loss': kl_div
        })
        return result

    def validation_step(self, batch, batch_idx):
        loss, recon_loss, kl_div, pxz = self._run_step(batch)
        result = pl.EvalResult(loss, checkpoint_on=loss)
        result.log_dict({
            'val_loss': loss,
            'val_recon_loss': recon_loss,
            'val_kl_div': kl_div,
        })
        return result

    def test_step(self, batch, batch_idx):
        loss, recon_loss, kl_div, pxz = self._run_step(batch)
        result = pl.EvalResult(loss)
        result.log_dict({
            'test_loss': loss,
            'test_recon_loss': recon_loss,
            'test_kl_div': kl_div,
        })
        return result

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(),
                                lr=self.hparams.learning_rate)

    @staticmethod
    def add_model_specific_args(parent_parser):
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument(
            '--hidden_dim',
            type=int,
            default=128,
            help='intermediate layer dimension before embedding for the default encoder/decoder')
        parser.add_argument('--latent_dim',
                            type=int,
                            default=4,
                            help='dimension of latent variables z')
        parser.add_argument(
            '--input_width',
            type=int,
            default=224,
            help='input width (uses Imagenet downsampled size)')
        parser.add_argument(
            '--input_height',
            type=int,
            default=224,
            help='input height (uses Imagenet downsampled size)')
        parser.add_argument('--input_channels',
                            type=int,
                            default=3,
                            help='number of input channels')
        parser.add_argument('--batch_size', type=int, default=64)
        parser.add_argument('--pretrained', type=str, default=None)
        parser.add_argument('--data_dir', type=str, default=os.getcwd())
        parser.add_argument('--num_workers',
                            type=int,
                            default=8,
                            help="num dataloader workers")

        parser.add_argument('--learning_rate', type=float, default=1e-3)
        return parser
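
The KL term in elbo_loss above is a Monte Carlo estimate, E_{z~Q}[log q(z) - log p(z)], rather than the closed form. A self-contained sketch (shapes chosen arbitrarily) showing the estimate agrees with torch's exact Gaussian KL:

import torch
from torch import distributions

batch, num_samples, z_dim = 4, 32, 8
z_mu = torch.zeros(batch, num_samples, z_dim)
z_std = torch.full((batch, num_samples, z_dim), 0.5)

P = distributions.Normal(torch.zeros_like(z_mu), torch.ones_like(z_std))  # prior N(0, 1)
Q = distributions.Normal(z_mu, z_std)  # approximate posterior

z = Q.rsample()  # reparameterized sample, so gradients flow through mu/std
kl_mc = (Q.log_prob(z) - P.log_prob(z)).sum(dim=2).mean()  # estimate used in elbo_loss
kl_exact = distributions.kl_divergence(Q, P).sum(dim=2).mean()  # closed form
print(f"monte carlo: {kl_mc:.3f}  exact: {kl_exact:.3f}")  # close with 32 samples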
Example #4


# imports the snippet relies on (not shown in the original excerpt)
from argparse import ArgumentParser

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint

from pl_bolts.models.self_supervised import SimCLR
from pl_bolts.transforms.dataset_normalizations import (
    cifar10_normalization, imagenet_normalization, stl10_normalization)


def cli_main():
    from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator
    from pl_bolts.datamodules import CIFAR10DataModule, ImagenetDataModule, STL10DataModule
    from pl_bolts.models.self_supervised.simclr.transforms import SimCLREvalDataTransform, SimCLRTrainDataTransform

    parser = ArgumentParser()

    # model args
    parser = SimCLR.add_model_specific_args(parser)
    parser = pl.Trainer.add_argparse_args(parser)
    args = parser.parse_args()

    if args.dataset == 'stl10':
        dm = STL10DataModule(data_dir=args.data_dir,
                             batch_size=args.batch_size,
                             num_workers=args.num_workers)

        dm.train_dataloader = dm.train_dataloader_mixed
        dm.val_dataloader = dm.val_dataloader_mixed
        args.num_samples = dm.num_unlabeled_samples

        args.maxpool1 = False
        args.first_conv = True
        args.input_height = dm.size()[-1]

        normalization = stl10_normalization()

        args.gaussian_blur = True
        args.jitter_strength = 1.
    elif args.dataset == 'cifar10':
        val_split = 5000
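        # grow the split so validation has at least one full global batch
        # (num_nodes * gpus * batch_size) to distribute across DDP processes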
        if args.num_nodes * args.gpus * args.batch_size > val_split:
            val_split = args.num_nodes * args.gpus * args.batch_size

        dm = CIFAR10DataModule(data_dir=args.data_dir,
                               batch_size=args.batch_size,
                               num_workers=args.num_workers,
                               val_split=val_split)

        args.num_samples = dm.num_samples

        args.maxpool1 = False
        args.first_conv = False
        args.input_height = dm.size()[-1]
        args.temperature = 0.5

        normalization = cifar10_normalization()

        args.gaussian_blur = False
        args.jitter_strength = 0.5
    elif args.dataset == 'imagenet':
        args.maxpool1 = True
        args.first_conv = True
        normalization = imagenet_normalization()

        args.gaussian_blur = True
        args.jitter_strength = 1.

        args.batch_size = 64
        args.num_nodes = 8
        args.gpus = 8  # per-node
        args.max_epochs = 800

        args.optimizer = 'sgd'
        args.lars_wrapper = True
        args.learning_rate = 4.8
        args.final_lr = 0.0048
        args.start_lr = 0.3
        args.online_ft = True

        dm = ImagenetDataModule(data_dir=args.data_dir,
                                batch_size=args.batch_size,
                                num_workers=args.num_workers)

        args.num_samples = dm.num_samples
        args.input_height = dm.size()[-1]
    else:
        raise NotImplementedError(
            "other datasets have not been implemented yet")

    dm.train_transforms = SimCLRTrainDataTransform(
        input_height=args.input_height,
        gaussian_blur=args.gaussian_blur,
        jitter_strength=args.jitter_strength,
        normalize=normalization,
    )

    dm.val_transforms = SimCLREvalDataTransform(
        input_height=args.input_height,
        gaussian_blur=args.gaussian_blur,
        jitter_strength=args.jitter_strength,
        normalize=normalization,
    )

    model = SimCLR(**args.__dict__)

    online_evaluator = None
    if args.online_ft:
        # online eval
        online_evaluator = SSLOnlineEvaluator(drop_p=0.,
                                              hidden_dim=None,
                                              z_dim=args.hidden_mlp,
                                              num_classes=dm.num_classes,
                                              dataset=args.dataset)

    model_checkpoint = ModelCheckpoint(save_last=True,
                                       save_top_k=1,
                                       monitor='val_loss')
    callbacks = [model_checkpoint, online_evaluator] if args.online_ft else [model_checkpoint]

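    # from_argparse_args builds the Trainer from the parsed CLI flags;
    # keyword arguments given explicitly here override the parsed values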
    trainer = pl.Trainer.from_argparse_args(
        args,
        sync_batchnorm=args.gpus > 1,
        callbacks=callbacks,
    )

    trainer.fit(model, datamodule=dm)
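
Because pl.Trainer.add_argparse_args was attached to the parser, Trainer flags (--gpus, --max_epochs, --fast_dev_run, ...) come straight from the command line; the dataset/model flags are assumed to be defined by SimCLR.add_model_specific_args. A hedged invocation sketch (script name hypothetical):

if __name__ == "__main__":
    # e.g. python simclr_module.py --dataset cifar10 --data_dir ./data --gpus 1 --fast_dev_run 1
    cli_main()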
Example #5
# imports the snippet relies on (not shown in the original excerpt)
from argparse import ArgumentParser

import pytorch_lightning as pl
from pytorch_lightning import seed_everything
from torchvision import transforms

from pl_bolts.models.self_supervised import SimSiam
from pl_bolts.transforms.dataset_normalizations import (
    cifar10_normalization, imagenet_normalization, stl10_normalization)


def cli_main():
    from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator
    from pl_bolts.datamodules import CIFAR10DataModule, ImagenetDataModule, STL10DataModule
    from pl_bolts.models.self_supervised.simclr import SimCLREvalDataTransform, SimCLRTrainDataTransform

    # NOTE: the cifar100 branch below uses CIFAR100DataModule, which
    # pl_bolts.datamodules does not ship; it must be provided separately
    # (a sketch follows this example)

    seed_everything(1234)

    parser = ArgumentParser()

    # trainer args
    parser = pl.Trainer.add_argparse_args(parser)

    # model args
    parser = SimSiam.add_model_specific_args(parser)
    args = parser.parse_args()

    # pick data
    dm = None

    # init datamodule
    if args.dataset == "stl10":
        dm = STL10DataModule(data_dir=args.data_dir, batch_size=args.batch_size, num_workers=args.num_workers)

        dm.train_dataloader = dm.train_dataloader_mixed
        dm.val_dataloader = dm.val_dataloader_mixed
        args.num_samples = dm.num_unlabeled_samples

        args.maxpool1 = False
        args.first_conv = True
        args.input_height = dm.size()[-1]

        normalization = stl10_normalization()

        args.gaussian_blur = True
        args.jitter_strength = 1.0
    elif args.dataset == "cifar10":
        val_split = 5000
        if args.nodes * args.gpus * args.batch_size > val_split:
            val_split = args.nodes * args.gpus * args.batch_size

        dm = CIFAR10DataModule(
            data_dir=args.data_dir,
            batch_size=args.batch_size,
            num_workers=args.num_workers,
            val_split=val_split,
        )

        args.num_samples = dm.num_samples

        args.maxpool1 = False
        args.first_conv = False
        args.input_height = dm.size()[-1]
        args.temperature = 0.5

        normalization = cifar10_normalization()

        args.gaussian_blur = False
        args.jitter_strength = 0.5
    elif args.dataset == "cifar100":
        val_split = 5000
        if args.nodes * args.gpus * args.batch_size > val_split:
            val_split = args.nodes * args.gpus * args.batch_size

        dm = CIFAR100DataModule(
            data_dir=args.data_dir,
            batch_size=args.batch_size,
            num_workers=args.num_workers,
            val_split=val_split,
        )

        args.num_samples = dm.num_samples

        args.maxpool1 = False
        args.first_conv = False
        args.input_height = dm.size()[-1]
        args.temperature = 0.5

        normalization = transforms.Normalize(
            mean=(0.5071, 0.4866, 0.4409),
            std=(0.2009, 0.1984, 0.2023),
        )

        args.gaussian_blur = False
        args.jitter_strength = 0.5
    elif args.dataset == "imagenet":
        args.maxpool1 = True
        args.first_conv = True
        normalization = imagenet_normalization()

        args.gaussian_blur = True
        args.jitter_strength = 1.0

        args.batch_size = 64
        args.nodes = 8
        args.gpus = 8  # per-node
        args.max_epochs = 800

        args.optimizer = "sgd"
        args.lars_wrapper = True
        args.learning_rate = 4.8
        args.final_lr = 0.0048
        args.start_lr = 0.3
        args.online_ft = True

        dm = ImagenetDataModule(data_dir=args.data_dir, batch_size=args.batch_size, num_workers=args.num_workers)

        args.num_samples = dm.num_samples
        args.input_height = dm.size()[-1]
    else:
        raise NotImplementedError("other datasets have not been implemented yet")

    dm.train_transforms = SimCLRTrainDataTransform(
        input_height=args.input_height,
        gaussian_blur=args.gaussian_blur,
        jitter_strength=args.jitter_strength,
        normalize=normalization,
    )

    dm.val_transforms = SimCLREvalDataTransform(
        input_height=args.input_height,
        gaussian_blur=args.gaussian_blur,
        jitter_strength=args.jitter_strength,
        normalize=normalization,
    )

    model = SimSiam(**args.__dict__)

    # finetune in real-time
    online_evaluator = None
    if args.online_ft:
        # online eval
        online_evaluator = SSLOnlineEvaluator(
            drop_p=0.0,
            hidden_dim=None,
            z_dim=args.hidden_mlp,
            num_classes=dm.num_classes,
            dataset=args.dataset,
        )

    trainer = pl.Trainer(
        max_epochs=args.max_epochs,
        max_steps=None if args.max_steps == -1 else args.max_steps,
        gpus=args.gpus,
        num_nodes=args.nodes,
        distributed_backend="ddp" if args.gpus > 1 else None,
        sync_batchnorm=args.gpus > 1,
        precision=32 if args.fp32 else 16,
        callbacks=[online_evaluator] if args.online_ft else None,
        fast_dev_run=args.fast_dev_run,
    )

    trainer.fit(model, dm)
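
The cifar100 branch above references a CIFAR100DataModule that pl_bolts does not provide. A minimal stand-in, sketched under the assumption that it only needs the interface the snippet touches (num_samples, num_classes, size(), externally assigned transforms, train/val loaders):

import pytorch_lightning as pl
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import CIFAR100


class CIFAR100DataModule(pl.LightningDataModule):
    """Minimal stand-in mirroring the pl_bolts datamodule interface."""

    def __init__(self, data_dir=".", batch_size=32, num_workers=8, val_split=5000):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.val_split = val_split
        self.num_samples = 50_000 - val_split  # CIFAR-100 train-set size minus val
        self.num_classes = 100
        self.train_transforms = None  # assigned externally, as in the snippet
        self.val_transforms = None

    def size(self):
        return (3, 32, 32)  # channels, height, width

    def prepare_data(self):
        CIFAR100(self.data_dir, train=True, download=True)
        CIFAR100(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        # random_split cannot give the two splits different transforms, so the
        # train transform is applied to both; a production version would wrap
        # each split to swap in self.val_transforms
        full = CIFAR100(self.data_dir, train=True, transform=self.train_transforms)
        self.ds_train, self.ds_val = random_split(
            full, [self.num_samples, self.val_split])

    def train_dataloader(self):
        return DataLoader(self.ds_train, batch_size=self.batch_size,
                          num_workers=self.num_workers, shuffle=True, drop_last=True)

    def val_dataloader(self):
        return DataLoader(self.ds_val, batch_size=self.batch_size,
                          num_workers=self.num_workers)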