import torch
import numpy as np
import matplotlib.pyplot as plt

import model
import data

# Rebuild the conditional INN (lr argument 0: evaluation only, no training)
# and restore the trained weights. Entries containing 'tmp_var' are
# temporary buffers saved alongside the parameters and must be dropped
# before load_state_dict.
cinn = model.MNIST_cINN(0)
cinn.cuda()
state_dict = {k: v for k, v in torch.load('output/mnist_cinn.pt').items()
              if 'tmp_var' not in k}
cinn.load_state_dict(state_dict)
cinn.eval()


def show_samples(label):
    '''produces and shows cINN samples for a given label (0-9)'''
    N_samples = 100

    # Condition vector: the same class label repeated for every sample.
    labels = torch.cuda.LongTensor(N_samples)
    labels[:] = label

    # Latent codes from a standard normal prior (temperature 1.0).
    z = 1.0 * torch.randn(N_samples, model.ndim_total).cuda()

    with torch.no_grad():
        generated = cinn.reverse_sample(z, labels).cpu().numpy()
        generated = data.unnormalize(generated)
from time import time

from tqdm import tqdm
import torch
import torch.nn
import torch.optim
import numpy as np

import model
import data

# Conditional INN with weight decay / lr setting 5e-4; optimizer lives on the model.
cinn = model.MNIST_cINN(5e-4)
cinn.cuda()
scheduler = torch.optim.lr_scheduler.MultiStepLR(cinn.optimizer, milestones=[20, 40], gamma=0.1)

N_epochs = 60
t_start = time()
nll_mean = []

print('Epoch\tBatch/Total \tTime \tNLL train\tNLL val\tLR')
for epoch in range(N_epochs):
    for i, (x, l) in enumerate(data.train_loader):
        x, l = x.cuda(), l.cuda()
        z, log_j = cinn(x, l)

        # Per-dimension negative log-likelihood under a standard normal
        # latent prior: 0.5*E[z^2] minus the mean log-det-Jacobian.
        nll = torch.mean(z**2) / 2 - torch.mean(log_j) / model.ndim_total
        nll.backward()
        torch.nn.utils.clip_grad_norm_(cinn.trainable_parameters, 10.)
        nll_mean.append(nll.item())

        # FIX: the original computed and clipped gradients but never called
        # optimizer.step() / zero_grad(), so gradients accumulated across
        # batches and the parameters were never updated.
        cinn.optimizer.step()
        cinn.optimizer.zero_grad()

    # FIX: the scheduler was constructed but never stepped, so the learning
    # rate decay at epochs 20 and 40 never took effect.
    scheduler.step()
from time import time

from tqdm import tqdm
import torch
import torch.optim
import numpy as np

import model
import data

# Conditional INN with lr 1e-3; the optimizer is owned by the model object.
cinn = model.MNIST_cINN(1e-3)
cinn.cuda()
scheduler = torch.optim.lr_scheduler.MultiStepLR(cinn.optimizer, milestones=[40, 80], gamma=0.1)

N_epochs = 120
t_start = time()
nll_mean = []

print('Epoch\tBatch/Total \tTime \tNLL train\tNLL val\tLR')
for epoch in range(N_epochs):
    for i, (x, l) in enumerate(data.train_loader):
        x, l = x.cuda(), l.cuda()
        z, log_j = cinn(x, l)

        # Per-dimension negative log-likelihood under a standard normal
        # latent prior: 0.5*E[z^2] minus the mean log-det-Jacobian.
        nll = torch.mean(z**2) / 2 - torch.mean(log_j) / model.ndim_total
        nll.backward()
        nll_mean.append(nll.item())
        cinn.optimizer.step()
        cinn.optimizer.zero_grad()

    # FIX: the scheduler was constructed but never stepped, so the learning
    # rate decay at epochs 40 and 80 never took effect.
    scheduler.step()