Example #1
def upper_func(row, col, **kwargs):
    # Evaluate the 2D conditional p(theta_row, theta_col | condition) on a grid.
    p_image = (
        eval_conditional_density(
            opts["density"],
            opts["condition"].to(device),
            limits.to(device),
            row,
            col,
            resolution=resolution,
            eps_margins1=eps_margins[row],
            eps_margins2=eps_margins[col],
        )
        .to("cpu")
        .numpy()
    )
    # Transpose so dim1 runs along x; extent is (left, right, bottom, top).
    plt.imshow(
        p_image.T,
        origin="lower",
        extent=(
            limits[col, 0],
            limits[col, 1],
            limits[row, 0],
            limits[row, 1],
        ),
        aspect="auto",
    )
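
`upper_func` is a nested callback from sbi's conditional pairplot internals: it closes over `opts`, `limits`, `resolution`, `eps_margins`, and `device` in its enclosing scope. Below is a minimal sketch of stand-ins for that scope so the snippet can be exercised on its own; the concrete values are assumptions for illustration, not sbi's internals:

import matplotlib.pyplot as plt
import torch
from torch.distributions import MultivariateNormal

from sbi.analysis import eval_conditional_density

# Assumed stand-ins for the enclosing scope of upper_func (and of diag_func
# in Example #3, which shares it).
device = "cpu"
resolution = 50
limits = torch.tensor([[-3.0, 3.0], [-3.0, 3.0], [-3.0, 3.0]])
eps_margins = [1e-32, 1e-32, 1e-32]  # assumed per-dimension margins
opts = {
    "density": MultivariateNormal(torch.zeros(3), torch.eye(3)),
    "condition": torch.ones(3),  # values at which non-plotted dims are fixed
    "samples_colors": ["C0"],  # used by diag_func in Example #3
}

upper_func(row=0, col=1)  # fills the current axes with p(theta_0, theta_1 | cond)
plt.show()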
Example #2
import torch
from torch.distributions import MultivariateNormal

from sbi.analysis import eval_conditional_density


def test_conditional_density_2d():
    """
    Test whether the conditional density matches analytical results for an MVN.

    This uses a 3D joint and conditions on the last value to generate a 2D
    conditional.
    """
    joint_mean = torch.zeros(3)
    joint_cov = torch.tensor(
        [[1.0, 0.0, 0.7], [0.0, 1.0, 0.7], [0.7, 0.7, 1.0]]
    )
    joint_dist = MultivariateNormal(joint_mean, joint_cov)

    condition_dim2 = torch.ones(1)
    full_condition = torch.ones(3)

    # Build a (resolution**2, 2) grid of evaluation points; dim 1 varies fastest.
    resolution = 100
    vals_to_eval_at_dim1 = (
        torch.linspace(-3, 3, resolution).repeat(resolution).unsqueeze(1)
    )
    vals_to_eval_at_dim2 = torch.repeat_interleave(
        torch.linspace(-3, 3, resolution), resolution
    ).unsqueeze(1)
    vals_to_eval_at = torch.cat((vals_to_eval_at_dim1, vals_to_eval_at_dim2), dim=1)

    # Solution with sbi.
    probs = eval_conditional_density(
        density=joint_dist,
        condition=full_condition,
        limits=torch.tensor([[-3, 3], [-3, 3], [-3, 3]]),
        dim1=0,
        dim2=1,
        resolution=resolution,
    )
    probs_sbi = probs / torch.sum(probs)

    # Analytical solution. `conditional_of_mvn` is a test helper; a sketch of
    # its Schur-complement formulas is given after this example.
    conditional_mean, conditional_cov = conditional_of_mvn(
        joint_mean, joint_cov, condition_dim2
    )
    conditional_dist = MultivariateNormal(conditional_mean, conditional_cov)

    probs = torch.exp(conditional_dist.log_prob(vals_to_eval_at))
    probs = torch.reshape(probs, (resolution, resolution))
    probs_analytical = probs / torch.sum(probs)

    assert torch.all(torch.abs(probs_analytical - probs_sbi) < 1e-5)
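
`conditional_of_mvn` above is a helper from sbi's test suite rather than the public API. For reference, here is a sketch of the standard Schur-complement formulas such a helper relies on when conditioning a joint Gaussian on its trailing dimensions; the name and signature mirror the test, but the body is an assumption, not sbi's code:

import torch


def conditional_of_mvn(loc, cov, condition):
    """Mean and covariance of the leading dims of an MVN, given that its
    trailing dims equal `condition`:

        mu_{1|2}    = mu_1 + S_12 S_22^{-1} (x_2 - mu_2)
        Sigma_{1|2} = S_11 - S_12 S_22^{-1} S_21
    """
    num_kept = loc.shape[0] - condition.shape[0]
    mu_1, mu_2 = loc[:num_kept], loc[num_kept:]
    s_11 = cov[:num_kept, :num_kept]  # covariance of the kept dims
    s_12 = cov[:num_kept, num_kept:]  # cross-covariance
    s_22 = cov[num_kept:, num_kept:]  # covariance of the conditioned dims
    s_22_inv = torch.inverse(s_22)
    cond_mean = mu_1 + s_12 @ s_22_inv @ (condition - mu_2)
    cond_cov = s_11 - s_12 @ s_22_inv @ s_12.T
    return cond_mean, cond_cov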
Example #3
def diag_func(row, **kwargs):
    # Evaluate the 1D conditional p(theta_row | condition) by setting
    # dim1 == dim2 == row, then plot it over the limits of that dimension.
    p_vector = (
        eval_conditional_density(
            opts["density"],
            opts["condition"],
            limits,
            row,
            row,
            resolution=resolution,
            eps_margins1=eps_margins[row],
            eps_margins2=eps_margins[row],
        )
        .to("cpu")
        .numpy()
    )
    plt.plot(
        np.linspace(
            limits[row, 0],
            limits[row, 1],
            resolution,
        ),
        p_vector,
        c=opts["samples_colors"][0],
    )
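
`diag_func` and `upper_func` are meant to be invoked once per panel of a pairplot grid. Below is a minimal driver loop, reusing the stand-in scope sketched after Example #1 (assumed for illustration; sbi's real plotting code handles many more options):

import matplotlib.pyplot as plt
import numpy as np  # diag_func uses np.linspace from the enclosing module

dim = limits.shape[0]
fig, axes = plt.subplots(dim, dim, figsize=(6, 6))
for row in range(dim):
    for col in range(dim):
        plt.sca(axes[row, col])  # route plt.plot / plt.imshow to this panel
        if row == col:
            diag_func(row=row)  # 1D conditional on the diagonal
        elif row < col:
            upper_func(row=row, col=col)  # 2D conditional above the diagonal
        else:
            axes[row, col].axis("off")  # lower triangle left empty
plt.show()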
Example #4
import numpy as np
import torch
from scipy.stats import gaussian_kde
from torch import eye, ones, zeros

from sbi import analysis, utils
from sbi.inference import SNPE_A, prepare_for_sbi, simulate_for_sbi
from sbi.simulators.linear_gaussian import linear_gaussian


def test_sample_conditional(snpe_method: type, set_seed):
    """
    Test whether sampling from the conditional gives the same results as evaluating.

    This compares samples that get smoothed with a Gaussian kde to evaluating the
    conditional log-probability with `eval_conditional_density`.

    `eval_conditional_density` is itself tested in `sbiutils_test.py`. Here, we use
    a bimodal posterior to test the conditional.
    """

    num_dim = 3
    dim_to_sample_1 = 0
    dim_to_sample_2 = 2

    x_o = zeros(1, num_dim)

    likelihood_shift = -1.0 * ones(num_dim)
    likelihood_cov = 0.1 * eye(num_dim)

    prior = utils.BoxUniform(-2.0 * ones(num_dim), 2.0 * ones(num_dim))

    def simulator(theta):
        if torch.rand(1) > 0.5:
            return linear_gaussian(theta, likelihood_shift, likelihood_cov)
        else:
            return linear_gaussian(theta, -likelihood_shift, likelihood_cov)

    if snpe_method == SNPE_A:
        net = utils.posterior_nn("mdn_snpe_a", hidden_features=20)
    else:
        net = utils.posterior_nn("maf", hidden_features=20)

    simulator, prior = prepare_for_sbi(simulator, prior)
    inference = snpe_method(prior, density_estimator=net, show_progress_bars=False)

    # We need a pretty big dataset to properly model the bimodality.
    theta, x = simulate_for_sbi(simulator, prior, 10000)
    _ = inference.append_simulations(theta, x).train(max_num_epochs=50)

    posterior = inference.build_posterior().set_default_x(x_o)
    samples = posterior.sample((50,))

    # Evaluate the conditional density by drawing samples and smoothing with a
    # Gaussian kde.
    cond_samples = posterior.sample_conditional(
        (500,),
        condition=samples[0],
        dims_to_sample=[dim_to_sample_1, dim_to_sample_2],
    )
    _ = analysis.pairplot(
        cond_samples,
        limits=[[-2, 2], [-2, 2], [-2, 2]],
        fig_size=(2, 2),
        diag="kde",
        upper="kde",
    )

    limits = [[-2, 2], [-2, 2], [-2, 2]]

    density = gaussian_kde(cond_samples.numpy().T, bw_method="scott")

    # 50 points per axis, matching `eval_conditional_density`'s default
    # resolution used in the call below.
    X, Y = np.meshgrid(
        np.linspace(limits[0][0], limits[0][1], 50),
        np.linspace(limits[1][0], limits[1][1], 50),
    )
    positions = np.vstack([X.ravel(), Y.ravel()])
    sample_kde_grid = np.reshape(density(positions).T, X.shape)

    # Evaluate the conditional with eval_conditional_density.
    eval_grid = analysis.eval_conditional_density(
        posterior,
        condition=samples[0],
        dim1=dim_to_sample_1,
        dim2=dim_to_sample_2,
        limits=torch.tensor([[-2, 2], [-2, 2], [-2, 2]]),
    )

    # Compare the two densities.
    sample_kde_grid = sample_kde_grid / np.sum(sample_kde_grid)
    eval_grid = eval_grid / torch.sum(eval_grid)

    error = np.abs(sample_kde_grid - eval_grid.numpy())

    max_err = np.max(error)
    assert max_err < 0.0025
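
`upper_func` and `diag_func` from Examples #1 and #3 are the plotting internals behind sbi's `analysis.conditional_pairplot`, which renders the same conditionals this test checks numerically. A hedged sketch of calling it on the trained `posterior` from above (treat the exact argument names as assumptions based on these snippets):

# Assumes the trained `posterior` and `samples` from the test above.
# conditional_pairplot evaluates p(theta_i, theta_j | condition) on a grid
# for every panel, i.e. exactly what upper_func / diag_func implement.
analysis.conditional_pairplot(
    density=posterior,
    condition=samples[0],
    limits=torch.tensor([[-2.0, 2.0]] * 3),
)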