def test_PixelSNAIL(self):
    """Smoke-test a tiny 3-channel PixelSNAIL that uses a custom sampling function.

    NOTE(review): the original line mixed two versions of this test and was not
    valid Python. This keeps the newer variant (``out_channels=3 * 10``,
    ``sample_fn=_multi_channel_test_sample_fn``, ``in_channels=3``); confirm
    against version control.
    """
    model = models.PixelSNAIL(
        in_channels=3,
        out_channels=3 * 10,
        n_channels=1,
        n_pixel_snail_blocks=1,
        n_residual_blocks=1,
        attention_key_channels=1,
        attention_value_channels=1,
        head_channels=1,
        sample_fn=_multi_channel_test_sample_fn,
    )
    self._smoke_test(model, in_channels=3)
def test_PixelSNAIL(self):
    """Build the smallest 3-in/3-out PixelSNAIL and run the shared smoke test on it."""
    tiny_config = dict(
        in_channels=3,
        out_channels=3,
        n_channels=1,
        n_pixel_snail_blocks=1,
        n_residual_blocks=1,
        attention_key_channels=1,
        attention_value_channels=1,
        head_channels=1,
    )
    self._smoke_test(models.PixelSNAIL(**tiny_config), in_channels=3)
def test_PixelSNAIL_multiple_channels(self):
    """Smoke-test PixelSNAIL with n_channels=2; head_channels is left at its default."""
    hyperparams = {
        "in_channels": 3,
        "out_channels": 3,
        "n_channels": 2,
        "n_pixel_snail_blocks": 1,
        "n_residual_blocks": 1,
        "attention_key_channels": 1,
        "attention_value_channels": 1,
    }
    snail = models.PixelSNAIL(**hyperparams)
    self._smoke_test(snail)
def test_PixelSNAIL(self):
    """Run the multi-channel test helper on a small PixelSNAIL with conditional sampling on."""
    snail = models.PixelSNAIL(
        in_channels=3,
        out_channels=3,
        n_channels=2,
        n_pixel_snail_blocks=1,
        n_residual_blocks=1,
        attention_key_channels=1,
        attention_value_channels=1,
    )
    use_conditional_sampling = True
    self._test_multiple_channels(snail, conditional_sample=use_conditional_sampling)
def reproduce(n_epochs=457, batch_size=128, log_dir="/tmp/run", device="cuda", debug_loader=None):
    """Train PixelSNAIL on dynamically binarized MNIST with the reference hyper-parameters.

    Every dependency is imported inside the function, so the body can be
    pasted into a notebook and run as a top-level training script.

    Args:
        n_epochs: Number of epochs to train for.
        batch_size: Batch size to use for training and evaluation.
        log_dir: Directory where to log trainer state and TensorBoard summaries.
        device: Device to train on (either 'cuda' or 'cpu').
        debug_loader: Debug DataLoader which replaces the default training and
            evaluation loaders if not 'None'. Do not use unless you're writing
            unit tests.
    """
    from torch import optim
    from torch.nn import functional as F
    from torch.optim import lr_scheduler

    from pytorch_generative import datasets
    from pytorch_generative import models
    from pytorch_generative import trainer

    # A debug loader, when supplied, stands in for both the train and eval data.
    if debug_loader is not None:
        train_loader = test_loader = debug_loader
    else:
        train_loader, test_loader = datasets.get_mnist_loaders(
            batch_size, dynamically_binarize=True
        )

    model = models.PixelSNAIL(
        in_channels=1,
        out_channels=1,
        n_channels=64,
        n_pixel_snail_blocks=8,
        n_residual_blocks=2,
        attention_value_channels=32,  # n_channels / 2
        attention_key_channels=4,  # attention_value_channels / 8
    )
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    lr_decay = 0.999977  # multiplicative per-step factor applied by the scheduler
    scheduler = lr_scheduler.MultiplicativeLR(optimizer, lr_lambda=lambda _epoch: lr_decay)

    def loss_fn(x, _, preds):
        # Flatten each example, then sum the per-pixel BCE (on logits) over pixels
        # and average over the batch.
        n_examples = x.shape[0]
        per_pixel = F.binary_cross_entropy_with_logits(
            preds.view((n_examples, -1)), x.view((n_examples, -1)), reduction="none"
        )
        return per_pixel.sum(dim=1).mean()

    model_trainer = trainer.Trainer(
        model=model,
        loss_fn=loss_fn,
        optimizer=optimizer,
        train_loader=train_loader,
        eval_loader=test_loader,
        lr_scheduler=scheduler,
        log_dir=log_dir,
        device=device,
    )
    model_trainer.interleaved_train_and_eval(n_epochs)
def reproduce(n_epochs=457, batch_size=128, log_dir="/tmp/run", device="cuda",
              n_channels=1, n_pixel_snail_blocks=1, n_residual_blocks=1,
              attention_value_channels=1, attention_key_channels=1,
              evalFlag=False, evaldir="/tmp/run", sampling_part=1):
    """Train PixelSNAIL on the data exposed by the project-local ``gmpm`` module.

    All torch / pytorch_generative dependencies are imported inside the
    function, so the body can be pasted into a notebook as a top-level training
    script. It still requires the project-local ``gmpm`` and
    ``pytorch_generative`` packages.

    Args:
        n_epochs: Number of epochs to train for.
        batch_size: Batch size to use for training and evaluation.
        log_dir: Directory where to log trainer state and TensorBoard summaries.
        device: Device to train on (either 'cuda' or 'cpu').
        n_channels, n_pixel_snail_blocks, n_residual_blocks,
        attention_value_channels, attention_key_channels: PixelSNAIL
            architecture hyper-parameters. They are passed to both
            ``models.PixelSNAIL`` and ``trainer.Trainer``.
        evalFlag, evaldir, sampling_part: Forwarded unchanged to
            ``trainer.Trainer``; their meaning is defined there.
    """
    import torch
    from torch import optim
    from torch.nn import functional as F
    from torch.optim import lr_scheduler
    from torch.utils import data

    from pytorch_generative import models
    from pytorch_generative import trainer

    # Load the ImageGPT data. gmpm exposes `train` / `test`, which must be
    # convertible with torch.Tensor(...).
    import gmpm

    train = gmpm.train
    test = gmpm.test

    # TensorDataset gets a random second tensor as placeholder targets;
    # loss_fn below ignores that argument.
    train_loader = data.DataLoader(
        data.TensorDataset(torch.Tensor(train), torch.rand(len(train))),
        batch_size=batch_size,
        shuffle=True,
        num_workers=8,
    )
    test_loader = data.DataLoader(
        data.TensorDataset(torch.Tensor(test), torch.rand(len(test))),
        batch_size=batch_size,
        num_workers=8,
    )

    model = models.PixelSNAIL(
        # 3 channels: image after the cluster-mapping function, used as NN input.
        in_channels=3,
        # 512 channels: per-pixel scores over the values 0..511.
        out_channels=512,
        n_channels=n_channels,
        n_pixel_snail_blocks=n_pixel_snail_blocks,
        n_residual_blocks=n_residual_blocks,
        attention_value_channels=attention_value_channels,
        attention_key_channels=attention_key_channels,
    )
    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    scheduler = lr_scheduler.MultiplicativeLR(optimizer, lr_lambda=lambda _: 0.999977)

    def loss_fn(x, _, preds):
        # Per-position 512-way classification loss (cross-entropy).
        # F.cross_entropy applies log_softmax itself, so it is correct for raw
        # logits. If preds are already log-probabilities the result equals NLL,
        # because log_softmax is idempotent on them. The previous nn.NLLLoss on
        # raw logits was unbounded below.
        # Assumes preds is 3-D (batch, classes, positions); the unpacking raises
        # otherwise. TODO confirm against the model's output shape.
        batch, n_classes, n_positions = preds.size()
        preds_2d = preds.view(batch, n_classes, n_positions, -1)
        x_2d = x.long().view(batch, n_positions, -1)
        return F.cross_entropy(preds_2d, x_2d)

    # nn.Module.to moves parameters in place, so the optimizer built above
    # keeps referring to the same parameters.
    model.to(device)
    model_trainer = trainer.Trainer(
        model=model,
        loss_fn=loss_fn,
        optimizer=optimizer,
        train_loader=train_loader,
        eval_loader=test_loader,
        lr_scheduler=scheduler,
        log_dir=log_dir,
        device=device,
        sample_epochs=5,
        sample_fn=None,
        n_channels=n_channels,
        n_pixel_snail_blocks=n_pixel_snail_blocks,
        n_residual_blocks=n_residual_blocks,
        attention_value_channels=attention_value_channels,
        attention_key_channels=attention_key_channels,
        evalFlag=evalFlag,
        evaldir=evaldir,
        sampling_part=sampling_part,
    )
    model_trainer.interleaved_train_and_eval(n_epochs)