def test_VeryDeepVAE(self):
    """Builds a minimal 3-channel VeryDeepVAE and runs `_test_multiple_channels` on it."""
    from pytorch_generative.models.vae.vd_vae import StackConfig

    # Two independent, minimal stacks: one encoder block and one decoder block each.
    tiny_stacks = [
        StackConfig(n_encoder_blocks=1, n_decoder_blocks=1) for _ in range(2)
    ]
    vae = models.VeryDeepVAE(
        in_channels=3,
        out_channels=3,
        input_resolution=8,
        stack_configs=tiny_stacks,
        latent_channels=1,
        bottleneck_channels=1,
    )
    self._test_multiple_channels(vae)
def reproduce(
    n_epochs=500,
    batch_size=128,
    log_dir="/tmp/run",
    n_gpus=1,
    device_id=0,
    debug_loader=None,
):
    """Training script with defaults to reproduce results.

    The code inside this function is self contained and can be used as a top level
    training script, e.g. by copy/pasting it into a Jupyter notebook.

    Args:
        n_epochs: Number of epochs to train for.
        batch_size: Batch size to use for training and evaluation.
        log_dir: Directory where to log trainer state and TensorBoard summaries.
        n_gpus: Number of GPUs to use for training the model. If 0, uses CPU.
        device_id: The device_id of the current GPU when training on multiple GPUs.
        debug_loader: Debug DataLoader which replaces the default training and
            evaluation loaders if not 'None'. Do not use unless you're writing unit
            tests.
    """
    import torch
    from torch import optim
    from torch.nn import functional as F

    from pytorch_generative import datasets
    from pytorch_generative import models
    from pytorch_generative import trainer

    # StackConfig must be imported here too: without it the function is not self
    # contained and raises NameError when copy/pasted outside its module.
    from pytorch_generative.models.vae.vd_vae import StackConfig

    train_loader, test_loader = debug_loader, debug_loader
    if train_loader is None:
        train_loader, test_loader = datasets.get_mnist_loaders(
            batch_size, dynamically_binarize=True, resize_to_32=True
        )

    stack_configs = [
        StackConfig(n_encoder_blocks=3, n_decoder_blocks=5),
        StackConfig(n_encoder_blocks=3, n_decoder_blocks=5),
        StackConfig(n_encoder_blocks=2, n_decoder_blocks=4),
        StackConfig(n_encoder_blocks=2, n_decoder_blocks=3),
        StackConfig(n_encoder_blocks=2, n_decoder_blocks=2),
        StackConfig(n_encoder_blocks=1, n_decoder_blocks=1),
    ]
    model = models.VeryDeepVAE(
        in_channels=1,
        out_channels=1,
        input_resolution=32,
        stack_configs=stack_configs,
        latent_channels=16,
        hidden_channels=64,
        bottleneck_channels=32,
    )
    optimizer = optim.Adam(model.parameters(), lr=5e-4)

    def loss_fn(x, _, preds):
        """Returns per-batch means of reconstruction loss, KL and their sum."""
        # The model's forward output is unpacked as (logits, kl_div).
        preds, kl_div = preds
        recon_loss = F.binary_cross_entropy_with_logits(preds, x, reduction="none")
        # Sum over every non-batch dim to get one reconstruction term per example
        # (assumes x is (N, C, H, W) — TODO confirm for non-image inputs).
        recon_loss = recon_loss.sum(dim=(1, 2, 3))
        # Negative ELBO: the quantity being minimized.
        neg_elbo = recon_loss + kl_div
        return {
            "recon_loss": recon_loss.mean(),
            "kl_div": kl_div.mean(),
            "loss": neg_elbo.mean(),
        }

    def sample_fn(model):
        """Draws 16 samples and binarizes their probabilities at 0.5."""
        sample = torch.sigmoid(model.sample(n_samples=16))
        return torch.where(
            sample < 0.5, torch.zeros_like(sample), torch.ones_like(sample)
        )

    model_trainer = trainer.Trainer(
        model=model,
        loss_fn=loss_fn,
        optimizer=optimizer,
        train_loader=train_loader,
        eval_loader=test_loader,
        sample_epochs=1,
        sample_fn=sample_fn,
        log_dir=log_dir,
        n_gpus=n_gpus,
        device_id=device_id,
    )
    model_trainer.interleaved_train_and_eval(n_epochs)
def reproduce(
    n_epochs=500, batch_size=128, log_dir="/tmp/run", device="cuda", debug_loader=None
):
    """Training script with defaults to reproduce results.

    The code inside this function is self contained and can be used as a top level
    training script, e.g. by copy/pasting it into a Jupyter notebook.

    Args:
        n_epochs: Number of epochs to train for.
        batch_size: Batch size to use for training and evaluation.
        log_dir: Directory where to log trainer state and TensorBoard summaries.
        device: Device to train on (either 'cuda' or 'cpu').
        debug_loader: Debug DataLoader which replaces the default training and
            evaluation loaders if not 'None'. Do not use unless you're writing unit
            tests.
    """
    # `torch` itself is required by `sample_fn` (torch.sigmoid). Without this
    # import the function is not self contained and raises NameError when
    # copy/pasted outside its module.
    import torch
    from torch import optim
    from torch.nn import functional as F

    from pytorch_generative import datasets
    from pytorch_generative import models
    from pytorch_generative import trainer

    train_loader, test_loader = debug_loader, debug_loader
    if train_loader is None:
        train_loader, test_loader = datasets.get_mnist_loaders(
            batch_size,
            resize_to_32=True,
        )

    model = models.VeryDeepVAE(
        in_channels=1,
        out_channels=1,
        input_resolution=32,
        latent_channels=16,
        hidden_channels=32,
        bottleneck_channels=8,
    )
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    def loss_fn(x, _, preds):
        """Returns per-batch means of reconstruction loss, KL and their sum."""
        # The model's forward output is unpacked as (logits, kl_div).
        preds, kl_div = preds
        recon_loss = F.binary_cross_entropy_with_logits(preds, x, reduction="none")
        # NOTE(review): reconstruction is *averaged* over dims (1, 2, 3) while
        # kl_div is added unscaled. If kl_div is a sum over latent dims (not
        # visible here), this heavily down-weights reconstruction relative to a
        # true ELBO. Confirm this weighting is intended.
        recon_loss = recon_loss.mean(dim=(1, 2, 3))
        loss = recon_loss + kl_div
        return {
            "recon_loss": recon_loss.mean(),
            "kl_div": kl_div.mean(),
            "loss": loss.mean(),
        }

    def sample_fn(model):
        """Draws 16 samples and maps their logits to probabilities in [0, 1]."""
        return torch.sigmoid(model.sample(n_samples=16))

    model_trainer = trainer.Trainer(
        model=model,
        loss_fn=loss_fn,
        optimizer=optimizer,
        train_loader=train_loader,
        eval_loader=test_loader,
        sample_epochs=1,
        sample_fn=sample_fn,
        log_dir=log_dir,
        device=device,
    )
    model_trainer.interleaved_train_and_eval(n_epochs)