Example No. 1
def train(data_dir, num_workers, use_gpu, batch_size, embed_dim, max_epochs,
          max_steps):
    # Make sure data is downloaded on all nodes.
    def download_data():
        from filelock import FileLock
        with FileLock(os.path.join(data_dir, ".lock")):
            MNISTDataModule(data_dir=data_dir).prepare_data()

    plugin = RayShardedPlugin(num_workers=num_workers,
                              use_gpu=use_gpu,
                              init_hook=download_data)

    dm = MNISTDataModule(data_dir, batch_size=batch_size)

    model = ImageGPT(embed_dim=embed_dim,
                     layers=16,
                     heads=4,
                     vocab_size=32,
                     num_pixels=28)

    trainer = pl.Trainer(max_epochs=max_epochs,
                         precision=16 if use_gpu else 32,
                         callbacks=[CUDACallback()] if use_gpu else [],
                         plugins=plugin,
                         max_steps=max_steps)

    trainer.fit(model, dm)
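# A minimal sketch of how the train() function above could be driven. The
# ray.init() call and all argument values below are illustrative assumptions,
# not part of the original example.
import ray

if __name__ == "__main__":
    ray.init()
    train(data_dir="/tmp/mnist", num_workers=2, use_gpu=False,
          batch_size=32, embed_dim=16, max_epochs=1, max_steps=100)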
Example No. 2
def test_igpt(tmpdir, datadir):
    seed_everything(0)
    dm = MNISTDataModule(data_dir=datadir, normalize=False)
    model = ImageGPT()

    trainer = Trainer(
        limit_train_batches=2,
        limit_val_batches=2,
        limit_test_batches=2,
        max_epochs=1,
    )
    trainer.fit(model, datamodule=dm)
    trainer.test(datamodule=dm)
    assert trainer.callback_metrics["test_loss"] < 1.7

    dm = FashionMNISTDataModule(data_dir=datadir, num_workers=1)
    model = ImageGPT(classify=True)
    trainer = Trainer(
        limit_train_batches=2,
        limit_val_batches=2,
        limit_test_batches=2,
        max_epochs=1,
        logger=False,
        checkpoint_callback=False,
    )
    trainer.fit(model, datamodule=dm)
Example No. 3
def cli_main():
    from pl_bolts.callbacks import LatentDimInterpolator, TensorboardGenerativeModelImageSampler
    from pl_bolts.datamodules import ImagenetDataModule

    pl.seed_everything(1234)
    parser = ArgumentParser()
    parser.add_argument('--dataset',
                        default='mnist',
                        type=str,
                        help='mnist, stl10, imagenet2012')

    parser = pl.Trainer.add_argparse_args(parser)
    parser = VAE.add_model_specific_args(parser)
    parser = ImagenetDataModule.add_argparse_args(parser)
    parser = MNISTDataModule.add_argparse_args(parser)
    args = parser.parse_args()

    # default is mnist
    datamodule = None
    if args.dataset == 'imagenet2012':
        datamodule = ImagenetDataModule.from_argparse_args(args)
    elif args.dataset == 'stl10':
        datamodule = STL10DataModule.from_argparse_args(args)

    callbacks = [
        TensorboardGenerativeModelImageSampler(),
        LatentDimInterpolator(interpolate_epoch_interval=5)
    ]
    vae = VAE(**vars(args), datamodule=datamodule)
    trainer = pl.Trainer.from_argparse_args(args,
                                            callbacks=callbacks,
                                            progress_bar_refresh_rate=10)
    trainer.fit(vae)
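# Hedged usage sketch: cli_main() is meant to be called from a script entry
# point. The script name and the flags below are assumptions for illustration,
# e.g.:  python vae_cli.py --dataset mnist --max_epochs 1
if __name__ == '__main__':
    cli_main()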
Example No. 4
def test_logistic_regression_model(tmpdir, datadir):
    pl.seed_everything(0)

    # create dataset
    dm = MNISTDataModule(num_workers=0, data_dir=datadir)

    model = LogisticRegression(input_dim=28 * 28,
                               num_classes=10,
                               learning_rate=0.001)
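    # Bind the datamodule's data hooks directly onto the model so that
    # trainer.fit(model) / trainer.test(model) below need no explicit datamodule.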
    model.prepare_data = dm.prepare_data
    model.setup = dm.setup
    model.train_dataloader = dm.train_dataloader
    model.val_dataloader = dm.val_dataloader
    model.test_dataloader = dm.test_dataloader

    trainer = pl.Trainer(
        max_epochs=3,
        default_root_dir=tmpdir,
        progress_bar_refresh_rate=0,
        logger=False,
        checkpoint_callback=False,
    )
    trainer.fit(model)
    trainer.test(model)
    assert trainer.progress_bar_dict['test_acc'] >= 0.9
Example No. 5
def tune_mnist(num_samples=10, num_epochs=10, gpus_per_trial=0):
    data_dir = os.path.join(tempfile.gettempdir(), "mnist_data_")
    # Download data
    MNISTDataModule(data_dir=data_dir).prepare_data()

    config = {
        "layer_1": tune.choice([32, 64, 128]),
        "layer_2": tune.choice([64, 128, 256]),
        "lr": tune.loguniform(1e-4, 1e-1),
        "batch_size": tune.choice([32, 64, 128]),
    }

    trainable = tune.with_parameters(train_mnist_tune,
                                     data_dir=data_dir,
                                     num_epochs=num_epochs,
                                     num_gpus=gpus_per_trial)
    analysis = tune.run(trainable,
                        resources_per_trial={
                            "cpu": 1,
                            "gpu": gpus_per_trial
                        },
                        metric="loss",
                        mode="min",
                        config=config,
                        num_samples=num_samples,
                        name="tune_mnist")

    print("Best hyperparameters found were: ", analysis.best_config)
Example No. 6
    def _set_default_datamodule(self, datamodule):
        # link default data
        if datamodule is None:
            datamodule = MNISTDataModule(data_dir=self.hparams.data_dir,
                                         num_workers=self.hparams.num_workers,
                                         normalize=True)
        self.datamodule = datamodule
        self.img_dim = self.datamodule.size()
Example No. 7
    def __init__(self,
                 datamodule: LightningDataModule = None,
                 latent_dim: int = 32,
                 batch_size: int = 100,
                 learning_rate: float = 0.0002,
                 data_dir: str = '',
                 num_workers: int = 8,
                 **kwargs):
        """
        Vanilla GAN implementation.

        Example::

            from pl_bolts.models.gan import GAN

            m = GAN()
            Trainer(gpus=2).fit(m)

        Example CLI::

            # mnist
            python basic_gan_module.py --gpus 1

            # imagenet
            python basic_gan_module.py --gpus 1 --dataset 'imagenet2012' \
                --data_dir /path/to/imagenet/folder/ --meta_dir ~/path/to/meta/bin/folder \
                --batch_size 256 --learning_rate 0.0001

        Args:

            datamodule: the datamodule (train, val, test splits)
            latent_dim: dim of the latent noise vector fed to the generator
            batch_size: the batch size
            learning_rate: the learning rate
            data_dir: where to store data
            num_workers: data workers

        """
        super().__init__()

        # makes self.hparams under the hood and saves to ckpt
        self.save_hyperparameters()

        # link default data
        if datamodule is None:
            datamodule = MNISTDataModule(
                data_dir=self.hparams.data_dir,
                num_workers=self.hparams.num_workers,
                normalize=True
            )
        self.datamodule = datamodule
        self.img_dim = self.datamodule.size()

        # networks
        self.generator = self.init_generator(self.img_dim)
        self.discriminator = self.init_discriminator(self.img_dim)
Example No. 8
    def _set_default_datamodule(self, datamodule):
        # link default data
        if datamodule is None:
            datamodule = MNISTDataModule(data_dir=self.hparams.data_dir,
                                         num_workers=self.hparams.num_workers,
                                         normalize=False)
        self.datamodule = datamodule
        self.img_dim = self.datamodule.size()

        (self.hparams.input_channels, self.hparams.input_height,
         self.hparams.input_width) = self.img_dim
Example No. 9
def train_mnist_tune(config, data_dir=None, num_epochs=10, num_gpus=0):
    model = LightningMNISTClassifier(config, data_dir)
    dm = MNISTDataModule(data_dir=data_dir,
                         num_workers=1,
                         batch_size=config["batch_size"])
    metrics = {"loss": "ptl/val_loss", "acc": "ptl/val_accuracy"}
    trainer = pl.Trainer(
        max_epochs=num_epochs,
        gpus=num_gpus,
        progress_bar_refresh_rate=0,
        callbacks=[TuneReportCallback(metrics, on="validation_end")])
    trainer.fit(model, dm)
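# Note: train_mnist_tune() is the Tune trainable; Examples No. 5 and No. 11
# wrap it with tune.with_parameters(...) and hand it to tune.run(...).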
Example No. 10
    def __init__(
            self,
            hidden_dim=128,
            latent_dim=32,
            input_channels=1,
            input_width=28,
            input_height=28,
            batch_size=32,
            learning_rate=0.001,
            data_dir=os.getcwd(),
            **kwargs
    ):
        super().__init__()
        self.save_hyperparameters()

        self.dataloaders = MNISTDataModule(data_dir=data_dir)
        self.img_dim = self.dataloaders.size()

        self.encoder = self.init_encoder(self.hparams.hidden_dim, self.hparams.latent_dim,
                                         self.hparams.input_width, self.hparams.input_height)
        self.decoder = self.init_decoder(self.hparams.hidden_dim, self.hparams.latent_dim)
Example No. 11
def tune_mnist(
    num_samples=10,
    num_epochs=10,
    gpus_per_trial=0,
    tracking_uri=None,
    experiment_name="ptl_autologging_example",
):
    data_dir = os.path.join(tempfile.gettempdir(), "mnist_data_")
    # Download data
    MNISTDataModule(data_dir=data_dir).prepare_data()

    # Set the MLflow experiment, or create it if it does not exist.
    mlflow.set_tracking_uri(tracking_uri)
    mlflow.set_experiment(experiment_name)

    config = {
        "layer_1": tune.choice([32, 64, 128]),
        "layer_2": tune.choice([64, 128, 256]),
        "lr": tune.loguniform(1e-4, 1e-1),
        "batch_size": tune.choice([32, 64, 128]),
        "mlflow": {
            "experiment_name": experiment_name,
            "tracking_uri": mlflow.get_tracking_uri(),
        },
        "data_dir": os.path.join(tempfile.gettempdir(), "mnist_data_"),
        "num_epochs": num_epochs,
    }

    trainable = tune.with_parameters(
        train_mnist_tune,
        data_dir=data_dir,
        num_epochs=num_epochs,
        num_gpus=gpus_per_trial,
    )

    analysis = tune.run(
        trainable,
        resources_per_trial={
            "cpu": 1,
            "gpu": gpus_per_trial
        },
        metric="loss",
        mode="min",
        config=config,
        num_samples=num_samples,
        name="tune_mnist",
    )

    print("Best hyperparameters found were: ", analysis.best_config)
Example No. 12
def test_predict(tmpdir, ray_start_2_cpus, seed, num_workers):
    """Tests if trained model has high accuracy on test set."""
    config = {
        "layer_1": 32,
        "layer_2": 32,
        "lr": 1e-2,
        "batch_size": 32,
    }

    model = LightningMNISTClassifier(config, tmpdir)
    dm = MNISTDataModule(
        data_dir=tmpdir, num_workers=1, batch_size=config["batch_size"])
    plugin = RayPlugin(num_workers=num_workers, use_gpu=False)
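    # get_trainer() and predict_test() are helper functions defined elsewhere
    # in the test suite; they are not shown in this snippet.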
    trainer = get_trainer(
        tmpdir, limit_train_batches=20, max_epochs=1, plugins=[plugin])
    predict_test(trainer, model, dm)
Example No. 13
    def __init__(self,
                 datamodule: LightningDataModule = None,
                 input_channels=1,
                 input_height=28,
                 input_width=28,
                 latent_dim=32,
                 batch_size=32,
                 hidden_dim=128,
                 learning_rate=0.001,
                 num_workers=8,
                 data_dir='.',
                 **kwargs):
        """
        Args:

            datamodule: the datamodule (train, val, test splits)
            input_channels: num of image channels
            input_height: image height
            input_width: image width
            latent_dim: emb dim for encoder
            batch_size: the batch size
            hidden_dim: the encoder dim
            learning_rate: the learning rate
            num_workers: num dataloader workers
            data_dir: where to store data
        """
        super().__init__()
        self.save_hyperparameters()

        # link default data
        if datamodule is None:
            datamodule = MNISTDataModule(data_dir=self.hparams.data_dir,
                                         num_workers=self.hparams.num_workers)

        self.datamodule = datamodule

        self.img_dim = self.datamodule.size()

        self.encoder = self.init_encoder(self.hparams.hidden_dim,
                                         self.hparams.latent_dim,
                                         self.hparams.input_width,
                                         self.hparams.input_height)
        self.decoder = self.init_decoder(self.hparams.hidden_dim,
                                         self.hparams.latent_dim)
Example No. 14
def test_igpt(tmpdir):
    pl.seed_everything(0)
    dm = MNISTDataModule(tmpdir, normalize=False)
    model = ImageGPT(datamodule=dm)

    trainer = pl.Trainer(limit_train_batches=2,
                         limit_val_batches=2,
                         limit_test_batches=2,
                         max_epochs=1)
    trainer.fit(model)
    trainer.test()
    assert trainer.callback_metrics["test_loss"] < 1.7

    model = ImageGPT(classify=True)
    trainer = pl.Trainer(limit_train_batches=2,
                         limit_val_batches=2,
                         limit_test_batches=2,
                         max_epochs=1)
    trainer.fit(model)
Example No. 15
def test_predict_client(tmpdir, start_ray_client_server_2_cpus, seed,
                        num_workers):
    assert ray.util.client.ray.is_connected()
    config = {
        "layer_1": 32,
        "layer_2": 32,
        "lr": 1e-2,
        "batch_size": 32,
    }

    model = LightningMNISTClassifier(config, tmpdir)
    dm = MNISTDataModule(data_dir=tmpdir,
                         num_workers=1,
                         batch_size=config["batch_size"])
    plugin = RayPlugin(num_workers=num_workers, use_gpu=False)
    trainer = get_trainer(tmpdir,
                          limit_train_batches=20,
                          max_epochs=1,
                          plugins=[plugin])
    predict_test(trainer, model, dm)
Example No. 16
        parser.add_argument('--pretrained', type=str, default=None)
        parser.add_argument('--data_dir', type=str, default=os.getcwd())

        parser.add_argument('--learning_rate', type=float, default=1e-3)
        return parser


if __name__ == '__main__':
    from pl_bolts.datamodules import ImagenetDataModule
    parser = ArgumentParser()
    parser.add_argument('--dataset', default='mnist', type=str)

    parser = Trainer.add_argparse_args(parser)
    parser = VAE.add_model_specific_args(parser)
    parser = ImagenetDataModule.add_argparse_args(parser)
    parser = MNISTDataModule.add_argparse_args(parser)
    args = parser.parse_args()
    #
    # if args.dataset == 'imagenet' or args.pretrained:
    #     datamodule = ImagenetDataModule.from_argparse_args(args)
    #     args.image_width = datamodule.size()[1]
    #     args.image_height = datamodule.size()[2]
    #     args.input_channels = datamodule.size()[0]
    #
    # elif args.dataset == 'mnist':
    #     datamodule = MNISTDataModule.from_argparse_args(args)
    #     args.image_width = datamodule.size()[1]
    #     args.image_height = datamodule.size()[2]
    #     args.input_channels = datamodule.size()[0]

    vae = VAE(**vars(args))
Example No. 17
    def __set_default_datamodule(self, data_dir):
        if self.datamodule is None:
            self.datamodule = MNISTDataModule(data_dir=data_dir)
            (self.hparams.input_channels, self.hparams.input_height,
             self.hparams.input_width) = self.datamodule.size()
Example No. 18
class VAE(LightningModule):
    def __init__(self,
                 hidden_dim: int = 128,
                 latent_dim: int = 32,
                 input_channels: int = 3,
                 input_width: int = 224,
                 input_height: int = 224,
                 batch_size: int = 32,
                 learning_rate: float = 0.001,
                 data_dir: str = '.',
                 datamodule: pl_bolts.datamodules.LightningDataModule = None,
                 pretrained: str = None,
                 **kwargs):
        """
        Standard VAE with Gaussian Prior and approx posterior.

        Model is available pretrained on different datasets:

        Example::

            # not pretrained
            vae = VAE()

            # pretrained on imagenet
            vae = VAE(pretrained='imagenet')

            # pretrained on cifar10
            vae = VAE(pretrained='cifar10')

        Args:

            hidden_dim: encoder and decoder hidden dims
            latent_dim: latent code dim
            input_channels: num of channels of the input image.
            input_width: image input width
            input_height: image input height
            batch_size: the batch size
            learning_rate" the learning rate
            data_dir: the directory to store data
            datamodule: The Lightning DataModule
            pretrained: Load weights pretrained on a dataset
        """
        super().__init__()
        self.save_hyperparameters()

        self.datamodule = datamodule
        self.__set_pretrained_dims(pretrained)

        # use mnist as the default module
        self.__set_default_datamodule(data_dir)

        # init actual model
        self.__init_system()

        if pretrained:
            self.load_pretrained(pretrained)

    def __init_system(self):
        self.encoder = self.init_encoder()
        self.decoder = self.init_decoder()

    def __set_pretrained_dims(self, pretrained):
        if pretrained == 'imagenet2012':
            self.datamodule = ImagenetDataModule(
                data_dir=self.hparams.data_dir)
            (self.hparams.input_channels, self.hparams.input_height,
             self.hparams.input_width) = self.datamodule.size()

    def __set_default_datamodule(self, data_dir):
        if self.datamodule is None:
            self.datamodule = MNISTDataModule(data_dir=data_dir)
            (self.hparams.input_channels, self.hparams.input_height,
             self.hparams.input_width) = self.datamodule.size()

    def load_pretrained(self, pretrained):
        available_weights = {'imagenet2012'}

        if pretrained in available_weights:
            weights_name = f'vae-{pretrained}'
            load_pretrained(self, weights_name)

    def init_encoder(self):
        encoder = Encoder(self.hparams.hidden_dim, self.hparams.latent_dim,
                          self.hparams.input_channels,
                          self.hparams.input_width, self.hparams.input_height)
        return encoder

    def init_decoder(self):
        decoder = Decoder(self.hparams.hidden_dim, self.hparams.latent_dim,
                          self.hparams.input_width, self.hparams.input_height,
                          self.hparams.input_channels)
        return decoder

    def get_prior(self, z_mu, z_std):
        # Prior ~ Normal(0,1)
        P = distributions.normal.Normal(loc=torch.zeros_like(z_mu),
                                        scale=torch.ones_like(z_std))
        return P

    def get_approx_posterior(self, z_mu, z_std):
        # Approx Posterior ~ Normal(mu, sigma)
        Q = distributions.normal.Normal(loc=z_mu, scale=z_std)
        return Q

    def elbo_loss(self, x, P, Q):
        # Reconstruction loss
        z = Q.rsample()
        pxz = self(z)
        pxz = torch.tanh(pxz)
        recon_loss = F.mse_loss(pxz, x, reduction='none')

        # sum across dimensions because sum of log probabilities of iid univariate gaussians is the same as
        # multivariate gaussian
        recon_loss = recon_loss.sum(dim=-1)

        # KL divergence loss
        log_qz = Q.log_prob(z)
        log_pz = P.log_prob(z)
        kl_div = (log_qz - log_pz).sum(dim=1)

        # ELBO = reconstruction + KL
        loss = recon_loss + kl_div

        # average over batch
        loss = loss.mean()
        recon_loss = recon_loss.mean()
        kl_div = kl_div.mean()

        return loss, recon_loss, kl_div, pxz

    def forward(self, z):
        return self.decoder(z)

    def _run_step(self, batch):
        x, _ = batch
        z_mu, z_log_var = self.encoder(x)
        z_std = torch.exp(z_log_var / 2)

        P = self.get_prior(z_mu, z_std)
        Q = self.get_approx_posterior(z_mu, z_std)

        x = x.view(x.size(0), -1)

        loss, recon_loss, kl_div, pxz = self.elbo_loss(x, P, Q)

        return loss, recon_loss, kl_div, pxz

    def training_step(self, batch, batch_idx):
        loss, recon_loss, kl_div, pxz = self._run_step(batch)

        tensorboard_logs = {
            'train_elbo_loss': loss,
            'train_recon_loss': recon_loss,
            'train_kl_loss': kl_div
        }

        return {'loss': loss, 'log': tensorboard_logs}

    def validation_step(self, batch, batch_idx):
        loss, recon_loss, kl_div, pxz = self._run_step(batch)

        return {
            'val_loss': loss,
            'val_recon_loss': recon_loss,
            'val_kl_div': kl_div,
            'pxz': pxz
        }

    def validation_epoch_end(self, outputs):
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        recon_loss = torch.stack([x['val_recon_loss'] for x in outputs]).mean()
        kl_loss = torch.stack([x['val_kl_div'] for x in outputs]).mean()

        tensorboard_logs = {
            'val_elbo_loss': avg_loss,
            'val_recon_loss': recon_loss,
            'val_kl_loss': kl_loss
        }

        return {'val_loss': avg_loss, 'log': tensorboard_logs}

    def test_step(self, batch, batch_idx):
        loss, recon_loss, kl_div, pxz = self._run_step(batch)

        return {
            'test_loss': loss,
            'test_recon_loss': recon_loss,
            'test_kl_div': kl_div,
            'pxz': pxz
        }

    def test_epoch_end(self, outputs):
        avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean()
        recon_loss = torch.stack([x['test_recon_loss']
                                  for x in outputs]).mean()
        kl_loss = torch.stack([x['test_kl_div'] for x in outputs]).mean()

        tensorboard_logs = {
            'test_elbo_loss': avg_loss,
            'test_recon_loss': recon_loss,
            'test_kl_loss': kl_loss
        }

        return {'test_loss': avg_loss, 'log': tensorboard_logs}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(),
                                lr=self.hparams.learning_rate)

    def prepare_data(self):
        self.datamodule.prepare_data()

    def train_dataloader(self):
        return self.datamodule.train_dataloader(self.hparams.batch_size)

    def val_dataloader(self):
        return self.datamodule.val_dataloader(self.hparams.batch_size)

    def test_dataloader(self):
        return self.datamodule.test_dataloader(self.hparams.batch_size)

    @staticmethod
    def add_model_specific_args(parent_parser):
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument(
            '--hidden_dim',
            type=int,
            default=128,
            help=
            'intermediate layers dimension before embedding for default encoder/decoder'
        )
        parser.add_argument('--latent_dim',
                            type=int,
                            default=32,
                            help='dimension of latent variables z')
        parser.add_argument(
            '--input_width',
            type=int,
            default=224,
            help='input width (used Imagenet downsampled size)')
        parser.add_argument(
            '--input_height',
            type=int,
            default=224,
            help='input height (used Imagenet downsampled size)')
        parser.add_argument('--input_channels',
                            type=int,
                            default=3,
                            help='number of input channels')
        parser.add_argument('--batch_size', type=int, default=32)
        parser.add_argument('--pretrained', type=str, default=None)
        parser.add_argument('--data_dir', type=str, default=os.getcwd())

        parser.add_argument('--learning_rate', type=float, default=1e-3)
        return parser
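# A hedged training sketch for the VAE above, mirroring its docstring example;
# the one-epoch Trainer configuration is an illustrative assumption.
if __name__ == '__main__':
    import pytorch_lightning as pl

    vae = VAE()  # falls back to the default MNIST datamodule
    trainer = pl.Trainer(max_epochs=1)
    trainer.fit(vae)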
Example No. 19
            hook.record = False

    def on_train_epoch_start(self, trainer, pl_module):
        for hook in self.all_hooks:
            hook.record = True

    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx,
                             dataloader_idx):
        opt = pl_module.optimizers(use_pl_optimizer=False)
        self.lrs.append(opt.param_groups[0]['lr'])
        self.momentums.append(opt.param_groups[0]['betas'][0])


if __name__ == '__main__':

    dm = MNISTDataModule()
    dm.batch_size = 512
    print(dm.batch_size)
    dm.prepare_data()
    dm.setup()
    dl = dm.train_dataloader()
    steps_per_epoch = len(dl)

    model = CustomModel(steps_per_epoch, 0.06)
    all_hooks = []
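    # Walk the model and attach an ActivationHistHook to every layer inside
    # each ConvBlock, keeping the hooks so they can be inspected later.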
    for i, k in enumerate(model.model):
        if isinstance(k, ConvBlock):
            for l in k:
                hook = ActivationHistHook(i, type(l))
                hook.add_hook(l)
                all_hooks.append(hook)
Example No. 20
    def __set_pretrained_dims(self, pretrained):
        if pretrained == 'imagenet2012':
            self.datamodule = ImagenetDataModule(
                data_dir=self.hparams.data_dir)
            (self.hparams.input_channels, self.hparams.input_height,
             self.hparams.input_width) = self.datamodule.size()
Example No. 21
class AE(LightningModule):

    def __init__(
            self,
            hidden_dim=128,
            latent_dim=32,
            input_channels=1,
            input_width=28,
            input_height=28,
            batch_size=32,
            learning_rate=0.001,
            data_dir=os.getcwd(),
            **kwargs
    ):
        super().__init__()
        self.save_hyperparameters()

        self.dataloaders = MNISTDataModule(data_dir=data_dir)
        self.img_dim = self.dataloaders.size()

        self.encoder = self.init_encoder(self.hparams.hidden_dim, self.hparams.latent_dim,
                                         self.hparams.input_width, self.hparams.input_height)
        self.decoder = self.init_decoder(self.hparams.hidden_dim, self.hparams.latent_dim)

    def init_encoder(self, hidden_dim, latent_dim, input_width, input_height):
        encoder = AEEncoder(hidden_dim, latent_dim, input_width, input_height)
        return encoder

    def init_decoder(self, hidden_dim, latent_dim):
        c, h, w = self.img_dim
        decoder = Decoder(hidden_dim, latent_dim, w, h, c)
        return decoder

    def forward(self, z):
        return self.decoder(z)

    def _run_step(self, batch):
        x, _ = batch
        z = self.encoder(x)
        x_hat = self(z)

        loss = F.mse_loss(x.view(x.size(0), -1), x_hat)
        return loss

    def training_step(self, batch, batch_idx):
        loss = self._run_step(batch)

        tensorboard_logs = {
            'mse_loss': loss,
        }

        return {'loss': loss, 'log': tensorboard_logs}

    def validation_step(self, batch, batch_idx):
        loss = self._run_step(batch)

        return {
            'val_loss': loss,
        }

    def validation_epoch_end(self, outputs):
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()

        tensorboard_logs = {'mse_loss': avg_loss}

        return {
            'val_loss': avg_loss,
            'log': tensorboard_logs
        }

    def test_step(self, batch, batch_idx):
        loss = self._run_step(batch)

        return {
            'test_loss': loss,
        }

    def test_epoch_end(self, outputs):
        avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean()

        tensorboard_logs = {'mse_loss': avg_loss}

        return {
            'test_loss': avg_loss,
            'log': tensorboard_logs
        }

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)

    def prepare_data(self):
        self.dataloaders.prepare_data()

    def train_dataloader(self):
        return self.dataloaders.train_dataloader(self.hparams.batch_size)

    def val_dataloader(self):
        return self.dataloaders.val_dataloader(self.hparams.batch_size)

    def test_dataloader(self):
        return self.dataloaders.test_dataloader(self.hparams.batch_size)

    @staticmethod
    def add_model_specific_args(parent_parser):
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument('--hidden_dim', type=int, default=128,
                            help='intermediate layers dimension before embedding for default encoder/decoder')
        parser.add_argument('--latent_dim', type=int, default=32,
                            help='dimension of latent variables z')
        parser.add_argument('--input_width', type=int, default=28,
                            help='input image width - 28 for MNIST (must be even)')
        parser.add_argument('--input_height', type=int, default=28,
                            help='input image height - 28 for MNIST (must be even)')
        parser.add_argument('--batch_size', type=int, default=32)
        parser.add_argument('--learning_rate', type=float, default=1e-3)
        parser.add_argument('--data_dir', type=str, default='')
        return parser
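# Hedged usage sketch for the AE class above; the constructor arguments and the
# single-epoch Trainer are illustrative assumptions.
if __name__ == '__main__':
    import pytorch_lightning as pl

    autoencoder = AE(hidden_dim=128, latent_dim=32)
    trainer = pl.Trainer(max_epochs=1)
    trainer.fit(autoencoder)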
Example No. 22
    from pl_bolts.datamodules import MNISTDataModule, ImagenetDataModule
    '''
    datamodule = SprayDataModule(dataset_parms=config['exp_params'],
                                 trainer_parms=config['trainer_params'],
                                 model_params=config['model_params'])
    model = VAE(hidden_dim = 128,
                latent_dim = 32,
                input_channels = config['model_params']['in_channels'],
                input_width = config['exp_params']['img_size'],
                input_height = config['exp_params']['img_size'],
                batch_size = config['exp_params']['batch_size'],
                learning_rate = config['exp_params']['LR'],
                datamodule = datamodule)
    '''
    datamodule = MNISTDataModule(data_dir='./')
    #datamodule = ImagenetDataModule(data_dir='./')

    model = VAE(datamodule=datamodule)

    trainer = pl.Trainer(
        gpus=1,
        max_epochs=config['trainer_params']['max_epochs'],
        checkpoint_callback=checkpoint_callback)  #, profiler=True)

    trainer.fit(model)

    output_dir = 'trained_models/'
    if not os.path.exists(output_dir):
        os.mkdir(output_dir)
    torch.save(model.state_dict(), os.path.join(output_dir, "model.pt"))
Example No. 23
    def download_data():
        from filelock import FileLock
        with FileLock(os.path.join(data_dir, ".lock")):
            MNISTDataModule(data_dir=data_dir).prepare_data()