import numpy as np
import torch

# Project-local imports; the module paths below are assumptions and may need
# adjusting to the actual repository layout.
from layers import (ActNorm, Conv1x1, Coupling, Dequantization,
                    LogitTransform, Normalization, SelfNormConv,
                    SplitPrior, Squeeze)
from layers.distributions import NegativeGaussianLoss, UniformDistribution
from layers.flowsequential import FlowSequential


def test_splitprior_logdet(input_size, distribution, n_times=10):
    # `distribution` and `n_times` are accepted for parametrization symmetry
    # with the related tests; this check runs on a single random sample.
    module = SplitPrior(input_size[1:], NegativeGaussianLoss).to('cuda')
    half_c = input_size[1] // 2
    nfeats = np.prod((half_c, *input_size[2:]))

    x = torch.randn(input_size).to('cuda')
    ldj_ours = module.logdet(x)

    def func(*inputs):
        # Re-stack the per-example inputs into a batch so autograd computes
        # one Jacobian per example rather than across the whole batch.
        inp = torch.stack(inputs, dim=0)
        out, _ = module(inp)
        out = out.sum(dim=0)
        return out

    J = torch.autograd.functional.jacobian(func, tuple(x),
                                           create_graph=False, strict=False)
    J = torch.stack(J, dim=0)
    # Keep only the retained half of the input channels, then flatten into a
    # square (batch, nfeats, nfeats) Jacobian per example.
    J = J[:, :, :, :, :half_c, :, :]
    J = J.view(x.size(0), nfeats, nfeats)

    logdet_pytorch = torch.slogdet(J)[1]
    # SplitPrior's log-det also includes the log-probability of the
    # factored-out half under the base distribution.
    logdet_pytorch += module.base.log_prob(module.transform(x)[0][:, half_c:])

    ldj_ours = ldj_ours.cpu().detach().numpy()
    ldj_pytorch = logdet_pytorch.cpu().detach().numpy()
    np.testing.assert_allclose(ldj_ours, ldj_pytorch, atol=1e-4)
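# A minimal sketch of how the test above might be parametrized with pytest.
# The shapes are illustrative (any even channel count works, since SplitPrior
# splits along the channel dimension), and the use of pytest here is an
# assumption, not part of the original test file.
import pytest

@pytest.mark.parametrize('input_size', [(2, 4, 8, 8), (2, 16, 4, 4)])
def test_splitprior_logdet_param(input_size):
    test_splitprior_logdet(input_size, NegativeGaussianLoss)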
# CIFAR-10 model builder (3x32x32 inputs): Glow-style blocks of ActNorm,
# 1x1 convolution mixing, and affine coupling. `sym_recon_grad` and
# `recon_loss_weight` are accepted for interface parity with the
# self-normalizing variant below but are unused here.
def create_model(num_blocks=3, block_size=48, sym_recon_grad=False,
                 actnorm=False, split_prior=False, recon_loss_weight=1.0):
    current_size = (3, 32, 32)
    alpha = 1e-6
    layers = [
        Dequantization(UniformDistribution(size=current_size)),
        Normalization(translation=0, scale=256),
        Normalization(translation=-alpha, scale=1 / (1 - 2 * alpha)),
        LogitTransform(),
    ]

    for l in range(num_blocks):
        # Squeeze trades spatial resolution for channels: (C, H, W) -> (4C, H/2, W/2).
        layers.append(Squeeze())
        current_size = (current_size[0] * 4,
                        current_size[1] // 2,
                        current_size[2] // 2)

        for k in range(block_size):
            if actnorm:
                layers.append(ActNorm(current_size[0]))
            layers.append(Conv1x1(current_size[0]))
            layers.append(Coupling(current_size))

        if split_prior and l < num_blocks - 1:
            # Factor out half of the channels under a Gaussian prior.
            layers.append(SplitPrior(current_size, NegativeGaussianLoss))
            current_size = (current_size[0] // 2,
                            current_size[1],
                            current_size[2])

    return FlowSequential(NegativeGaussianLoss(size=current_size), *layers)
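# A quick smoke check for the CIFAR-10 builder above. That FlowSequential
# exposes a log_prob(x) method folding in the NegativeGaussianLoss base is an
# assumption about the surrounding codebase; batch size, device, and the use
# of integer-valued pixels in [0, 256) are illustrative.
def _smoke_test_cifar_model():
    model = create_model(num_blocks=3, block_size=48,
                         actnorm=True, split_prior=True).to('cuda')
    x = torch.randint(0, 256, (8, 3, 32, 32), device='cuda').float()
    print(model.log_prob(x).shape)  # expected: one log-likelihood per example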
# MNIST model builder (1x28x28 inputs): the same block structure as the
# CIFAR-10 builder above, but with self-normalizing 1x1 convolutions in place
# of Conv1x1; these consume `sym_recon_grad` and `recon_loss_weight`.
def create_model(num_blocks=2, block_size=16, sym_recon_grad=False,
                 actnorm=False, split_prior=False, recon_loss_weight=100.0):
    alpha = 1e-6
    layers = [
        Dequantization(UniformDistribution(size=(1, 28, 28))),
        Normalization(translation=0, scale=256),
        Normalization(translation=-alpha, scale=1 / (1 - 2 * alpha)),
        LogitTransform(),
    ]
    current_size = (1, 28, 28)

    for l in range(num_blocks):
        layers.append(Squeeze())
        current_size = (current_size[0] * 4,
                        current_size[1] // 2,
                        current_size[2] // 2)

        for k in range(block_size):
            if actnorm:
                layers.append(ActNorm(current_size[0]))
            layers.append(SelfNormConv(current_size[0], current_size[0], (1, 1),
                                       bias=True, stride=1, padding=0,
                                       sym_recon_grad=sym_recon_grad,
                                       recon_loss_weight=recon_loss_weight))
            layers.append(Coupling(current_size))

        if split_prior and l < num_blocks - 1:
            layers.append(SplitPrior(current_size, NegativeGaussianLoss))
            current_size = (current_size[0] // 2,
                            current_size[1],
                            current_size[2])

    return FlowSequential(NegativeGaussianLoss(size=current_size), *layers)
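# Sanity check of the shape bookkeeping in the MNIST builder: each Squeeze
# maps (C, H, W) -> (4C, H/2, W/2) and each SplitPrior halves the channels,
# so with num_blocks=2 and split_prior=True the final latent is (8, 7, 7).
# This helper mirrors the loop above and is not part of the original file.
def _final_size(num_blocks=2, split_prior=True, size=(1, 28, 28)):
    for l in range(num_blocks):
        size = (size[0] * 4, size[1] // 2, size[2] // 2)  # Squeeze
        if split_prior and l < num_blocks - 1:
            size = (size[0] // 2, size[1], size[2])       # SplitPrior
    return size

assert _final_size() == (8, 7, 7)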
def test_splitprior_inverse(input_size, distribution, n_times=10):
    module = SplitPrior(input_size[1:], NegativeGaussianLoss).to('cuda')
    half_c = input_size[1] // 2

    for _ in range(n_times):
        # Renamed from `input` to avoid shadowing the builtin.
        x = torch.randn(input_size).to('cuda')
        forward, logdet = module(x)
        reverse = module.reverse(forward)

        # Only the retained half is reconstructed by reverse(); splice the
        # factored-out half back in before comparing against the input.
        full_reverse = torch.cat([reverse[:, :half_c], x[:, half_c:]], dim=1)

        inp = x.cpu().detach().numpy()
        outp = full_reverse.cpu().detach().view(input_size).numpy()
        np.testing.assert_allclose(inp, outp, atol=1e-3)
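# Illustrative entry point for running both SplitPrior tests directly rather
# than through a test runner; the shape is arbitrary beyond requiring an even
# channel count, and a CUDA device is assumed to be available.
if __name__ == '__main__':
    test_splitprior_inverse(input_size=(2, 4, 8, 8), distribution=NegativeGaussianLoss)
    test_splitprior_logdet(input_size=(2, 4, 8, 8), distribution=NegativeGaussianLoss)
    print('SplitPrior tests passed.')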