class VAE(LightningModule):
    """Variational autoencoder with a Normal(0, 1) prior and a Normal(mu, sigma) posterior.

    The networks are created through :meth:`init_encoder` and
    :meth:`init_decoder`; the training objective is computed in
    :meth:`elbo_loss`.
    """

    def __init__(
            self,
            hidden_dim: int = 128,
            latent_dim: int = 32,
            input_channels: int = 3,
            input_width: int = 224,
            input_height: int = 224,
            batch_size: int = 32,
            learning_rate: float = 0.001,
            data_dir: str = '.',
            datamodule: pl_bolts.datamodules.LightningDataModule = None,
            pretrained: str = None,
            **kwargs
    ):
        """
        Standard VAE with Gaussian Prior and approx posterior.

        Example::

            # not pretrained
            vae = VAE()

            # pretrained weights ('imagenet2012' is the only name this class
            # recognises; any other name triggers a warning and loads nothing)
            vae = VAE(pretrained='imagenet2012')

        Args:
            hidden_dim: encoder and decoder hidden dims
            latent_dim: latent code dim
            input_channels: num of channels of the input image.
            input_width: image input width
            input_height: image input height
            batch_size: the batch size
            learning_rate: the learning rate
            data_dir: the directory to store data
            datamodule: The Lightning DataModule. When ``None`` an
                ``MNISTDataModule`` is created and the ``input_*``
                hyperparameters are overwritten from its ``size()``.
            pretrained: Load weights pretrained on a dataset. With
                ``'imagenet2012'`` the datamodule is replaced by an
                ``ImagenetDataModule`` and the ``input_*`` hyperparameters
                are overwritten from its ``size()``.
            **kwargs: extra keyword arguments; not referenced directly by
                this constructor.
        """
        super().__init__()
        self.save_hyperparameters()

        self.datamodule = datamodule
        self.__set_pretrained_dims(pretrained)

        # use mnist as the default module
        self.__set_default_datamodule(data_dir)

        # init actual model
        self.__init_system()

        if pretrained:
            self.load_pretrained(pretrained)

    def __init_system(self):
        """Build the encoder and decoder from the current hyperparameters."""
        self.encoder = self.init_encoder()
        self.decoder = self.init_decoder()

    def __set_pretrained_dims(self, pretrained):
        """Switch to the ImageNet datamodule and its dimensions for 'imagenet2012'."""
        if pretrained == 'imagenet2012':
            self.datamodule = ImagenetDataModule(data_dir=self.hparams.data_dir)
            # size() is unpacked as (channels, height, width).
            (self.hparams.input_channels,
             self.hparams.input_height,
             self.hparams.input_width) = self.datamodule.size()

    def __set_default_datamodule(self, data_dir):
        """Fall back to MNIST (and its dimensions) when no datamodule is set."""
        if self.datamodule is None:
            self.datamodule = MNISTDataModule(data_dir=data_dir)
            # size() is unpacked as (channels, height, width).
            (self.hparams.input_channels,
             self.hparams.input_height,
             self.hparams.input_width) = self.datamodule.size()

    def load_pretrained(self, pretrained):
        """Load pretrained weights named ``vae-<pretrained>`` into this model.

        Args:
            pretrained: name of the dataset the weights were trained on.
                Only ``'imagenet2012'`` is recognised here; for any other
                value a ``UserWarning`` is emitted and the weights are left
                untouched.
        """
        available_weights = {'imagenet2012'}

        if pretrained in available_weights:
            weights_name = f'vae-{pretrained}'
            # Inside a method the bare name skips class scope, so this calls
            # the module-level ``load_pretrained`` callable, not this method.
            load_pretrained(self, weights_name)
        else:
            # Previously an unknown name was ignored silently, leaving the
            # caller with an untrained model and no indication why.
            import warnings
            warnings.warn(
                f"No pretrained VAE weights available for '{pretrained}'; "
                f"expected one of {sorted(available_weights)}. "
                "Model weights were left unchanged.",
                stacklevel=2,
            )

    def init_encoder(self):
        """Create the encoder network from the stored hyperparameters."""
        encoder = Encoder(self.hparams.hidden_dim,
                          self.hparams.latent_dim,
                          self.hparams.input_channels,
                          self.hparams.input_width,
                          self.hparams.input_height)
        return encoder

    def init_decoder(self):
        """Create the decoder network from the stored hyperparameters.

        Note that the positional argument order differs from ``Encoder``:
        the channel count is passed last here.
        """
        decoder = Decoder(self.hparams.hidden_dim,
                          self.hparams.latent_dim,
                          self.hparams.input_width,
                          self.hparams.input_height,
                          self.hparams.input_channels)
        return decoder

    def get_prior(self, z_mu, z_std):
        """Return the prior Normal(0, 1), shaped like ``z_mu`` / ``z_std``."""
        # Prior ~ Normal(0,1)
        P = distributions.normal.Normal(loc=torch.zeros_like(z_mu),
                                        scale=torch.ones_like(z_std))
        return P

    def get_approx_posterior(self, z_mu, z_std):
        """Return the approximate posterior Normal(z_mu, z_std)."""
        # Approx Posterior ~ Normal(mu, sigma)
        Q = distributions.normal.Normal(loc=z_mu, scale=z_std)
        return Q

    def elbo_loss(self, x, P, Q):
        """Compute the training objective from one posterior sample.

        Args:
            x: reconstruction target; ``_run_step`` passes the input
                flattened to ``(batch, -1)``.
            P: prior distribution over the latent code.
            Q: approximate posterior distribution over the latent code.

        Returns:
            Tuple ``(loss, recon_loss, kl_div, pxz)`` where the first three
            are means of the per-sample terms and ``pxz`` is the
            tanh-squashed decoder output.
        """
        # Reconstruction loss
        z = Q.rsample()
        pxz = self(z)
        pxz = torch.tanh(pxz)
        recon_loss = F.mse_loss(pxz, x, reduction='none')

        # sum across dimensions because sum of log probabilities of iid
        # univariate gaussians is the same as multivariate gaussian
        recon_loss = recon_loss.sum(dim=-1)

        # KL divergence loss, estimated from the single sample z
        log_qz = Q.log_prob(z)
        log_pz = P.log_prob(z)
        kl_div = (log_qz - log_pz).sum(dim=1)

        # ELBO = reconstruction + KL
        loss = recon_loss + kl_div

        # average over batch
        loss = loss.mean()
        recon_loss = recon_loss.mean()
        kl_div = kl_div.mean()

        return loss, recon_loss, kl_div, pxz

    def forward(self, z):
        """Decode a latent code ``z`` with the decoder network."""
        return self.decoder(z)

    def _run_step(self, batch):
        """Encode a batch, build prior/posterior and evaluate :meth:`elbo_loss`.

        Args:
            batch: a pair whose first element is the input tensor; the
                second element is ignored.

        Returns:
            The tuple returned by :meth:`elbo_loss`.
        """
        x, _ = batch
        z_mu, z_log_var = self.encoder(x)
        # The second encoder output is treated as a log-variance.
        z_std = torch.exp(z_log_var / 2)

        P = self.get_prior(z_mu, z_std)
        Q = self.get_approx_posterior(z_mu, z_std)

        # Flatten everything but the batch dimension to compare against the
        # decoder output.
        x = x.view(x.size(0), -1)

        loss, recon_loss, kl_div, pxz = self.elbo_loss(x, P, Q)

        return loss, recon_loss, kl_div, pxz

    def training_step(self, batch, batch_idx):
        """Return the training loss plus the scalars to log for one batch."""
        loss, recon_loss, kl_div, pxz = self._run_step(batch)

        tensorboard_logs = {
            'train_elbo_loss': loss,
            'train_recon_loss': recon_loss,
            'train_kl_loss': kl_div
        }

        return {'loss': loss, 'log': tensorboard_logs}

    def validation_step(self, batch, batch_idx):
        """Return the loss terms and reconstructions for one validation batch."""
        loss, recon_loss, kl_div, pxz = self._run_step(batch)

        return {
            'val_loss': loss,
            'val_recon_loss': recon_loss,
            'val_kl_div': kl_div,
            'pxz': pxz
        }

    def validation_epoch_end(self, outputs):
        """Average the per-batch validation terms collected in ``outputs``."""
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        recon_loss = torch.stack([x['val_recon_loss'] for x in outputs]).mean()
        kl_loss = torch.stack([x['val_kl_div'] for x in outputs]).mean()

        tensorboard_logs = {
            'val_elbo_loss': avg_loss,
            'val_recon_loss': recon_loss,
            'val_kl_loss': kl_loss
        }

        return {'val_loss': avg_loss, 'log': tensorboard_logs}

    def test_step(self, batch, batch_idx):
        """Return the loss terms and reconstructions for one test batch."""
        loss, recon_loss, kl_div, pxz = self._run_step(batch)

        return {
            'test_loss': loss,
            'test_recon_loss': recon_loss,
            'test_kl_div': kl_div,
            'pxz': pxz
        }

    def test_epoch_end(self, outputs):
        """Average the per-batch test terms collected in ``outputs``."""
        avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean()
        recon_loss = torch.stack([x['test_recon_loss'] for x in outputs]).mean()
        kl_loss = torch.stack([x['test_kl_div'] for x in outputs]).mean()

        tensorboard_logs = {
            'test_elbo_loss': avg_loss,
            'test_recon_loss': recon_loss,
            'test_kl_loss': kl_loss
        }

        return {'test_loss': avg_loss, 'log': tensorboard_logs}

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters at ``learning_rate``."""
        return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)

    def prepare_data(self):
        """Delegate data preparation to the datamodule."""
        self.datamodule.prepare_data()

    def train_dataloader(self):
        """Return the datamodule's training loader for ``batch_size``."""
        return self.datamodule.train_dataloader(self.hparams.batch_size)

    def val_dataloader(self):
        """Return the datamodule's validation loader for ``batch_size``."""
        return self.datamodule.val_dataloader(self.hparams.batch_size)

    def test_dataloader(self):
        """Return the datamodule's test loader for ``batch_size``."""
        return self.datamodule.test_dataloader(self.hparams.batch_size)

    @staticmethod
    def add_model_specific_args(parent_parser):
        """Return a parser extending ``parent_parser`` with the VAE options."""
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument(
            '--hidden_dim', type=int, default=128,
            help='itermediate layers dimension before embedding for default encoder/decoder')
        parser.add_argument('--latent_dim', type=int, default=32,
                            help='dimension of latent variables z')
        parser.add_argument('--input_width', type=int, default=224,
                            help='input width (used Imagenet downsampled size)')
        # The help text used to say "input width" for this option as well.
        parser.add_argument('--input_height', type=int, default=224,
                            help='input height (used Imagenet downsampled size)')
        parser.add_argument('--input_channels', type=int, default=3,
                            help='number of input channels')
        parser.add_argument('--batch_size', type=int, default=32)
        parser.add_argument('--pretrained', type=str, default=None)
        parser.add_argument('--data_dir', type=str, default=os.getcwd())
        parser.add_argument('--learning_rate', type=float, default=1e-3)
        return parser
class AE(LightningModule):
    """Autoencoder trained with a mean-squared-error reconstruction loss.

    Data always comes from an ``MNISTDataModule`` created in the constructor.
    The decoder is sized from that datamodule's ``size()``, whereas the
    encoder is sized from the ``input_width`` / ``input_height``
    hyperparameters.
    """

    def __init__(
            self,
            hidden_dim=128,
            latent_dim=32,
            input_channels=1,
            input_width=28,
            input_height=28,
            batch_size=32,
            learning_rate=0.001,
            data_dir=os.getcwd(),
            **kwargs
    ):
        """Store the hyperparameters, create the datamodule and both networks.

        Args:
            hidden_dim: hidden size handed to the encoder and decoder
            latent_dim: size of the latent code
            input_channels: stored as a hyperparameter; not read elsewhere
                in this class
            input_width: width handed to the encoder
            input_height: height handed to the encoder
            batch_size: batch size requested from the dataloaders
            learning_rate: learning rate of the Adam optimizer
            data_dir: directory handed to ``MNISTDataModule``
            **kwargs: extra keyword arguments; not referenced directly here
        """
        super().__init__()
        self.save_hyperparameters()

        self.dataloaders = MNISTDataModule(data_dir=data_dir)
        self.img_dim = self.dataloaders.size()

        self.encoder = self.init_encoder(
            self.hparams.hidden_dim,
            self.hparams.latent_dim,
            self.hparams.input_width,
            self.hparams.input_height,
        )
        self.decoder = self.init_decoder(
            self.hparams.hidden_dim,
            self.hparams.latent_dim,
        )

    def init_encoder(self, hidden_dim, latent_dim, input_width, input_height):
        """Create the ``AEEncoder`` for the given sizes."""
        return AEEncoder(hidden_dim, latent_dim, input_width, input_height)

    def init_decoder(self, hidden_dim, latent_dim):
        """Create the ``Decoder``, sized from ``self.img_dim``."""
        # img_dim is unpacked as (channels, height, width); the Decoder takes
        # them in the order width, height, channels.
        channels, height, width = self.img_dim
        return Decoder(hidden_dim, latent_dim, width, height, channels)

    def forward(self, z):
        """Decode a latent code ``z``."""
        return self.decoder(z)

    def _run_step(self, batch):
        """Return the MSE between the flattened input and its reconstruction.

        Args:
            batch: a pair whose first element is the input tensor; the
                second element is ignored.
        """
        inputs, _ = batch
        code = self.encoder(inputs)
        reconstruction = self(code)
        # Flatten everything but the batch dimension before comparing.
        flat_inputs = inputs.view(inputs.size(0), -1)
        return F.mse_loss(flat_inputs, reconstruction)

    def training_step(self, batch, batch_idx):
        """Return the training loss and the scalar to log for one batch."""
        step_loss = self._run_step(batch)
        return {'loss': step_loss, 'log': {'mse_loss': step_loss}}

    def validation_step(self, batch, batch_idx):
        """Return the loss of one validation batch."""
        return {'val_loss': self._run_step(batch)}

    def validation_epoch_end(self, outputs):
        """Average the per-batch validation losses in ``outputs``."""
        batch_losses = [out['val_loss'] for out in outputs]
        mean_loss = torch.stack(batch_losses).mean()
        return {'val_loss': mean_loss, 'log': {'mse_loss': mean_loss}}

    def test_step(self, batch, batch_idx):
        """Return the loss of one test batch."""
        return {'test_loss': self._run_step(batch)}

    def test_epoch_end(self, outputs):
        """Average the per-batch test losses in ``outputs``."""
        batch_losses = [out['test_loss'] for out in outputs]
        mean_loss = torch.stack(batch_losses).mean()
        return {'test_loss': mean_loss, 'log': {'mse_loss': mean_loss}}

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters at ``learning_rate``."""
        learning_rate = self.hparams.learning_rate
        return torch.optim.Adam(self.parameters(), lr=learning_rate)

    def prepare_data(self):
        """Delegate data preparation to the datamodule."""
        self.dataloaders.prepare_data()

    def train_dataloader(self):
        """Return the datamodule's training loader for ``batch_size``."""
        batch_size = self.hparams.batch_size
        return self.dataloaders.train_dataloader(batch_size)

    def val_dataloader(self):
        """Return the datamodule's validation loader for ``batch_size``."""
        batch_size = self.hparams.batch_size
        return self.dataloaders.val_dataloader(batch_size)

    def test_dataloader(self):
        """Return the datamodule's test loader for ``batch_size``."""
        batch_size = self.hparams.batch_size
        return self.dataloaders.test_dataloader(batch_size)

    @staticmethod
    def add_model_specific_args(parent_parser):
        """Return a parser extending ``parent_parser`` with the AE options."""
        parser = ArgumentParser(parents=[parent_parser], add_help=False)

        # Network sizes.
        parser.add_argument(
            '--hidden_dim',
            type=int,
            default=128,
            help='itermediate layers dimension before embedding for default encoder/decoder',
        )
        parser.add_argument(
            '--latent_dim',
            type=int,
            default=32,
            help='dimension of latent variables z',
        )

        # Input geometry.
        parser.add_argument(
            '--input_width',
            type=int,
            default=28,
            help='input image width - 28 for MNIST (must be even)',
        )
        parser.add_argument(
            '--input_height',
            type=int,
            default=28,
            help='input image height - 28 for MNIST (must be even)',
        )

        # Training and data options.
        parser.add_argument('--batch_size', type=int, default=32)
        parser.add_argument('--learning_rate', type=float, default=1e-3)
        parser.add_argument('--data_dir', type=str, default='')

        return parser