Example #1
import numpy as np
import torch
import nf  # normalizing-flows library used throughout; module name assumed from the alias


def test_cnf_computation(input_shape, latent_dim):
    np.random.seed(123)
    torch.manual_seed(123)

    x = torch.rand(*input_shape)
    latent = torch.randn(*x.shape[:-1], latent_dim) if latent_dim else None

    dim = x.shape[-1]
    in_dim = dim if latent_dim is None else dim + latent_dim
    transforms = [
        nf.ContinuousFlow(dim,
                          net=nf.net.DiffeqMLP(in_dim + 1, [32], dim),
                          atol=1e-8,
                          rtol=1e-8,
                          divergence='compute',
                          solver='dopri5',
                          has_latent=latent is not None)
    ]
    model = nf.Flow(nf.Normal(torch.zeros(dim), torch.ones(dim)), transforms)

    y, log_jac_y = model.forward(x, latent=latent)
    x_, log_jac_x = model.inverse(y, latent=latent)

    check_inverse(x, x_)
    check_jacobian(log_jac_x, log_jac_y)
    check_one_training_step(input_shape[-1], model, x, latent)
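This example (and Example #5 below) calls three helpers that are not shown. A minimal sketch of what they might look like, assuming torch.allclose tolerances, a forward pass that maps data toward the base distribution, and a single Adam step on the negative log-likelihood; none of this is the library's actual test code:

def check_inverse(x, x_):
    # Running the inverse on the forward output should recover the input.
    assert torch.allclose(x, x_, atol=1e-4)


def check_jacobian(log_jac_x, log_jac_y):
    # Forward and inverse log-det-Jacobians should cancel each other.
    assert torch.allclose(log_jac_x, -log_jac_y, atol=1e-4)


def check_one_training_step(dim, model, x, latent):
    # One NLL gradient step should run end to end; the base distribution,
    # the per-dimension log-Jacobian shape, and the optimizer settings
    # are all assumptions.
    base = nf.Normal(torch.zeros(dim), torch.ones(dim))
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    optimizer.zero_grad()
    y, log_jac = model.forward(x, latent=latent)
    loss = -(base.log_prob(y) + log_jac.sum(-1)).mean()
    loss.backward()
    optimizer.step()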
Example #2
def test_normal():
    # nf.Normal should match torch.distributions.Normal exactly and agree
    # with nf.MultivariateNormal up to floating-point tolerance.
    x = torch.randn(4, 3, 5, 2)
    p0 = torch.distributions.Normal(torch.zeros(2),
                                    torch.ones(2)).log_prob(x).sum(-1)
    p1 = nf.Normal(torch.zeros(2), torch.ones(2)).log_prob(x)
    p2 = nf.MultivariateNormal(torch.zeros(2), torch.eye(2)).log_prob(x)
    assert (p0 == p1).all() and torch.isclose(p1, p2).all()
Example #3
def build_coupling_flow(dim, hidden_dims, latent_dim, mask, flow_type,
                        num_layers):
    base_dist = nf.Normal(torch.zeros(dim), torch.ones(dim))
    transforms = []
    for _ in range(num_layers):
        # The inner transform is conditioned on the coupling net's output
        # (size hidden_dims[-1]), concatenated with the external latent
        # when one is provided.
        lat_dim = hidden_dims[-1]
        if latent_dim is not None:
            lat_dim += latent_dim
        transforms.append(
            nf.Coupling(
                getattr(nf, flow_type)(dim, latent_dim=lat_dim, n_bins=5),
                nf.net.MLP(dim, hidden_dims, hidden_dims[-1]), mask))
    return nf.Flow(base_dist, transforms)
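A sketch of how this builder might be invoked. 'Spline' is consistent with the n_bins argument the builder passes through, but the mask specifier is purely an assumption about what nf.Coupling accepts:

flow = build_coupling_flow(dim=2,
                           hidden_dims=[64, 64],
                           latent_dim=None,
                           mask='ordered_left_half',  # assumed mask format
                           flow_type='Spline',
                           num_layers=4)
x = torch.rand(16, 2)
y, log_jac = flow.forward(x)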
Example #4
def test_cnf_definition(input_shape, latent_dim, diffeq, hidden_dims, solver,
                        solver_options, rademacher, use_adjoint):
    np.random.seed(123)
    torch.manual_seed(123)

    x = torch.rand(*input_shape)
    latent = torch.randn(*x.shape[:-1], latent_dim) if latent_dim else None

    dim = x.shape[-1]
    in_dim = dim if latent_dim is None else dim + latent_dim
    cnf = nf.ContinuousFlow(dim,
                            net=getattr(nf.net, diffeq)(in_dim + 1,
                                                        hidden_dims, dim),
                            solver=solver,
                            solver_options=solver_options,
                            rademacher=rademacher,
                            use_adjoint=use_adjoint)
    model = nf.Flow(nf.Normal(torch.zeros(dim), torch.ones(dim)), [cnf])
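This test only checks that construction succeeds. If the model should also be exercised, a round-trip in the style of Example #1 could follow inside the function body (the tolerance is an assumption):

    # Assumed continuation mirroring Example #1's round-trip check.
    y, log_jac_y = model.forward(x, latent=latent)
    x_, log_jac_x = model.inverse(y, latent=latent)
    assert torch.allclose(x, x_, atol=1e-4)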
Example #5
def test_diffeq_self_attention(input_shape):
    torch.manual_seed(123)

    dim = input_shape[-1]
    cnf = nf.ContinuousFlow(dim,
                            net=nf.net.DiffeqSelfAttention(dim + 1, [32], dim),
                            atol=1e-8,
                            rtol=1e-8,
                            divergence='compute',
                            solver='dopri5')
    model = nf.Flow(nf.Normal(torch.zeros(dim), torch.ones(dim)), [cnf])

    x = torch.rand(*input_shape)

    y, log_jac_y = model.forward(x)
    x_, log_jac_x = model.inverse(y)

    check_inverse(x, x_)
    check_jacobian(log_jac_x, log_jac_y)
    check_one_training_step(input_shape[-1], model, x, None)
Example #6
def build_affine(dim, num_layers, latent_dim):
    base_dist = nf.Normal(torch.zeros(dim), torch.ones(dim))
    transforms = []
    for _ in range(num_layers):
        transforms.append(nf.Affine(dim, latent_dim=latent_dim))
    return nf.Flow(base_dist, transforms)
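A usage sketch for the affine builder, following the latent-conditioned forward signature from Example #1; the shapes here are illustrative assumptions:

model = build_affine(dim=3, num_layers=2, latent_dim=4)
x = torch.rand(8, 3)
latent = torch.randn(8, 4)  # one latent vector per batch element
y, log_jac = model.forward(x, latent=latent)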
Example #7
def build_spline(dim, n_bins, lower, upper, spline_type, latent_dim,
                 num_layers):
    base_dist = nf.Normal(torch.zeros(dim), torch.ones(dim))
    transforms = []
    for _ in range(num_layers):
        transforms.append(
            nf.Spline(dim, n_bins=n_bins, lower=lower, upper=upper,
                      spline_type=spline_type, latent_dim=latent_dim))
    return nf.Flow(base_dist, transforms)