import numpy as np
import torch

# SplitPrior, NegativeGaussianLoss and the other flow layers used below are
# assumed to be importable from the project's own modules.
def test_splitprior_logdet(input_size, distribution, n_times=10):
    module = SplitPrior(input_size[1:], NegativeGaussianLoss).to('cuda')
    half_c = input_size[1] // 2
    nfeats = np.prod((half_c, *input_size[2:]))

    x = torch.randn(input_size).to('cuda')
    ldj_ours = module.logdet(x)

    # Wrap the module so torch.autograd.functional.jacobian yields a
    # per-sample Jacobian: the per-sample inputs are re-stacked into a batch
    # and the outputs summed over the batch dimension.
    def func(*inputs):
        inp = torch.stack(inputs, dim=0)
        out, _ = module(inp)
        out = out.sum(dim=0)
        return out

    # Brute-force reference: build the full Jacobian, keep only the retained
    # half of the input channels, and take its log-determinant. The log-prob
    # of the split-off half is added to match the module's logdet convention.
    J = torch.autograd.functional.jacobian(func,
                                           tuple(x),
                                           create_graph=False,
                                           strict=False)
    J = torch.stack(J, dim=0)
    J = J[:, :, :, :, :half_c, :, :]
    J = J.view(x.size(0), nfeats, nfeats)
    logdet_pytorch = torch.slogdet(J)[1]
    logdet_pytorch += module.base.log_prob(module.transform(x)[0][:, half_c:])

    ldj_ours = ldj_ours.cpu().detach().numpy()
    ldj_pytorch = logdet_pytorch.cpu().detach().numpy()

    np.testing.assert_allclose(ldj_ours, ldj_pytorch, atol=1e-4)
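
# A hedged usage sketch (not part of the original test file): invoking the
# check by hand for a small input, assuming a CUDA device, an even channel
# count, and that NegativeGaussianLoss is an acceptable stand-in for the
# unused `distribution` argument.
def run_splitprior_logdet_check():
    test_splitprior_logdet(input_size=(4, 4, 8, 8),
                           distribution=NegativeGaussianLoss)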
def create_model(num_blocks=3,
                 block_size=48,
                 sym_recon_grad=False,
                 actnorm=False,
                 split_prior=False,
                 recon_loss_weight=1.0):
    current_size = (3, 32, 32)

    alpha = 1e-6
    layers = [
        Dequantization(UniformDistribution(size=current_size)),
        Normalization(translation=0, scale=256),
        Normalization(translation=-alpha, scale=1 / (1 - 2 * alpha)),
        LogitTransform(),
    ]

    for l in range(num_blocks):
        layers.append(Squeeze())
        current_size = (current_size[0] * 4, current_size[1] // 2,
                        current_size[2] // 2)

        for k in range(block_size):
            if actnorm:
                layers.append(ActNorm(current_size[0]))
            layers.append(Conv1x1(current_size[0]))
            layers.append(Coupling(current_size))

        if split_prior and l < num_blocks - 1:
            layers.append(SplitPrior(current_size, NegativeGaussianLoss))
            current_size = (current_size[0] // 2, current_size[1],
                            current_size[2])

    return FlowSequential(NegativeGaussianLoss(size=current_size), *layers)
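
# A minimal usage sketch for the 3x32x32 (CIFAR-sized) model above; the
# hyperparameters are illustrative and assume FlowSequential behaves as a
# standard torch.nn.Module.
def build_cifar_flow_example():
    model = create_model(num_blocks=3, block_size=48,
                         actnorm=True, split_prior=True)
    n_params = sum(p.numel() for p in model.parameters())
    return model, n_params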
def create_model(num_blocks=2, block_size=16, sym_recon_grad=False, 
                 actnorm=False, split_prior=False, recon_loss_weight=100.0):
    alpha = 1e-6
    layers = [
        Dequantization(UniformDistribution(size=(1, 28, 28))),
        Normalization(translation=0, scale=256),
        Normalization(translation=-alpha, scale=1 / (1 - 2 * alpha)),
        LogitTransform(),
    ]

    current_size = (1, 28, 28)

    for l in range(num_blocks):
        layers.append(Squeeze())
        current_size = (current_size[0]*4, current_size[1]//2, current_size[2]//2)

        for k in range(block_size):
            if actnorm:
                layers.append(ActNorm(current_size[0]))
            
            layers.append(SelfNormConv(current_size[0], current_size[0], (1, 1), 
                                       bias=True, stride=1, padding=0,
                                       sym_recon_grad=sym_recon_grad, 
                                       recon_loss_weight=recon_loss_weight))
            layers.append(Coupling(current_size))

        if split_prior and l < num_blocks - 1:
            layers.append(SplitPrior(current_size, NegativeGaussianLoss))
            current_size = (current_size[0] // 2, current_size[1], current_size[2])

    return FlowSequential(NegativeGaussianLoss(size=current_size), *layers)
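
# A hedged usage sketch for the 1x28x28 (MNIST-sized) variant above, which
# swaps Conv1x1 for SelfNormConv; the settings are illustrative and again
# assume FlowSequential behaves as a standard torch.nn.Module.
def build_mnist_flow_example():
    model = create_model(num_blocks=2, block_size=16,
                         sym_recon_grad=True, actnorm=True,
                         split_prior=True, recon_loss_weight=100.0)
    return model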
def test_splitprior_inverse(input_size, distribution, n_times=10):
    module = SplitPrior(input_size[1:], NegativeGaussianLoss).to('cuda')
    half_c = input_size[1] // 2

    for _ in range(n_times):
        x = torch.randn(input_size).to('cuda')

        forward, logdet = module(x)
        reverse = module.reverse(forward)

        # Only the retained half is inverted; the split-off half is taken
        # directly from the original input before comparing.
        full_reverse = torch.cat([reverse[:, :half_c], x[:, half_c:]],
                                 dim=1)

        inp = x.cpu().detach().numpy()
        outp = full_reverse.cpu().detach().view(input_size).numpy()

        np.testing.assert_allclose(inp, outp, atol=1e-3)
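
# A hedged usage sketch (not in the original file): running the inverse check
# by hand a few times, assuming a CUDA device and an even channel count.
def run_splitprior_inverse_check():
    test_splitprior_inverse(input_size=(4, 4, 8, 8),
                            distribution=NegativeGaussianLoss,
                            n_times=3)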