def test_cnf_computation(input_shape, latent_dim):
    """Round-trip a single continuous flow and run one training step.

    A ``ContinuousFlow`` with exactly computed divergence and tight dopri5
    tolerances is built, optionally conditioned on a random latent. The test
    then checks that inverting the forward pass recovers the input, that the
    forward/inverse log-Jacobians are consistent (``check_jacobian``), and that
    one optimisation step succeeds (``check_one_training_step``).
    """
    np.random.seed(123)
    torch.manual_seed(123)

    x = torch.rand(*input_shape)
    # A falsy latent_dim (None or 0) means "no conditioning latent".
    if latent_dim:
        latent = torch.randn(*x.shape[:-1], latent_dim)
    else:
        latent = None

    dim = x.shape[-1]
    if latent_dim is None:
        net_in = dim
    else:
        net_in = dim + latent_dim

    # The extra input column is presumably the ODE time variable.
    diffeq_net = nf.net.DiffeqMLP(net_in + 1, [32], dim)
    flow = nf.ContinuousFlow(
        dim,
        net=diffeq_net,
        atol=1e-8,
        rtol=1e-8,
        divergence='compute',
        solver='dopri5',
        has_latent=latent is not None,
    )
    base = nf.Normal(torch.zeros(dim), torch.ones(dim))
    model = nf.Flow(base, [flow])

    z, logdet_fwd = model.forward(x, latent=latent)
    x_back, logdet_inv = model.inverse(z, latent=latent)

    check_inverse(x, x_back)
    check_jacobian(logdet_inv, logdet_fwd)
    check_one_training_step(input_shape[-1], model, x, latent)
def test_normal():
    """nf's Normal and MultivariateNormal must agree with torch's reference Normal.

    The reference sums per-dimension log-probs over the last axis. ``nf.Normal``
    is expected to match it exactly, and ``nf.MultivariateNormal`` with identity
    covariance is expected to match ``nf.Normal`` up to ``torch.isclose``
    tolerance.
    """
    samples = torch.randn(4, 3, 5, 2)

    reference = (
        torch.distributions.Normal(torch.zeros(2), torch.ones(2))
        .log_prob(samples)
        .sum(-1)
    )
    diag_logp = nf.Normal(torch.zeros(2), torch.ones(2)).log_prob(samples)
    full_logp = nf.MultivariateNormal(torch.zeros(2), torch.eye(2)).log_prob(samples)

    assert (reference == diag_logp).all()
    assert torch.isclose(diag_logp, full_logp).all()
def build_coupling_flow(dim, hidden_dims, latent_dim, mask, flow_type, num_layers):
    """Build a flow of ``num_layers`` coupling layers over a standard-normal base.

    Each layer pairs an ``nf.<flow_type>`` transform (created with ``n_bins=5``)
    with an ``nf.net.MLP`` conditioner whose output size is ``hidden_dims[-1]``.
    The transform's ``latent_dim`` is that output size, plus ``latent_dim``
    when one is given. Every layer receives the same ``mask``.
    """
    base_dist = nf.Normal(torch.zeros(dim), torch.ones(dim))

    def make_layer():
        # Conditioner output size, widened by the external latent if present.
        cond_size = hidden_dims[-1]
        if latent_dim is not None:
            cond_size = cond_size + latent_dim
        inner = getattr(nf, flow_type)(dim, latent_dim=cond_size, n_bins=5)
        conditioner = nf.net.MLP(dim, hidden_dims, hidden_dims[-1])
        return nf.Coupling(inner, conditioner, mask)

    return nf.Flow(base_dist, [make_layer() for _ in range(num_layers)])
def test_cnf_definition(input_shape, latent_dim, diffeq, hidden_dims, solver,
                        solver_options, rademacher, use_adjoint):
    """Check that a ContinuousFlow can be constructed for a given configuration.

    Only construction is exercised: the diffeq network named by ``diffeq`` is
    instantiated, wrapped in ``nf.ContinuousFlow`` with the given solver and
    divergence options, and placed in an ``nf.Flow``. No forward or inverse
    pass is run.
    """
    np.random.seed(123)
    torch.manual_seed(123)
    x = torch.rand(*input_shape)
    # A falsy latent_dim (None or 0) means "no conditioning latent".
    latent = torch.randn(*x.shape[:-1], latent_dim) if latent_dim else None
    dim = x.shape[-1]
    in_dim = dim if latent_dim is None else dim + latent_dim
    net = getattr(nf.net, diffeq)(in_dim + 1, hidden_dims, dim)
    cnf = nf.ContinuousFlow(
        dim,
        net=net,
        solver=solver,
        solver_options=solver_options,
        rademacher=rademacher,
        use_adjoint=use_adjoint,
        # The net's input width already includes latent_dim, so the flow must
        # be told a latent is present (mirrors test_cnf_computation).
        has_latent=latent is not None,
    )
    # Construction of the full model must not raise.
    nf.Flow(nf.Normal(torch.zeros(dim), torch.ones(dim)), [cnf])
def test_diffeq_self_attention(input_shape):
    """Round-trip a ContinuousFlow driven by a DiffeqSelfAttention network.

    The flow has no latent conditioning. The test checks invertibility,
    forward/inverse log-Jacobian consistency, and one training step.
    """
    torch.manual_seed(123)
    n_features = input_shape[-1]

    # The extra input column is presumably the ODE time variable.
    attention_net = nf.net.DiffeqSelfAttention(n_features + 1, [32], n_features)
    flow = nf.ContinuousFlow(
        n_features,
        net=attention_net,
        atol=1e-8,
        rtol=1e-8,
        divergence='compute',
        solver='dopri5',
    )
    base = nf.Normal(torch.zeros(n_features), torch.ones(n_features))
    model = nf.Flow(base, [flow])

    x = torch.rand(*input_shape)
    z, logdet_fwd = model.forward(x)
    x_back, logdet_inv = model.inverse(z)

    check_inverse(x, x_back)
    check_jacobian(logdet_inv, logdet_fwd)
    check_one_training_step(input_shape[-1], model, x, None)
def build_affine(dim, num_layers, latent_dim):
    """Build a flow of ``num_layers`` ``nf.Affine`` layers over a standard-normal base.

    Every layer is created with the same ``latent_dim``.
    """
    base_dist = nf.Normal(torch.zeros(dim), torch.ones(dim))
    layers = [nf.Affine(dim, latent_dim=latent_dim) for _ in range(num_layers)]
    return nf.Flow(base_dist, layers)
def build_spline(dim, n_bins, lower, upper, spline_type, latent_dim, num_layers):
    """Build a flow of ``num_layers`` identical-config ``nf.Spline`` layers.

    All layers share the given bin count, bounds, spline type and
    ``latent_dim``. The base distribution is a standard normal over ``dim``.
    """
    base_dist = nf.Normal(torch.zeros(dim), torch.ones(dim))
    spline_options = dict(
        n_bins=n_bins,
        lower=lower,
        upper=upper,
        spline_type=spline_type,
        latent_dim=latent_dim,
    )
    layers = [nf.Spline(dim, **spline_options) for _ in range(num_layers)]
    return nf.Flow(base_dist, layers)