Example #1
def test_mdn_with_1D_uniform_prior():
    """
    Note, we have this test because for 1D uniform priors, mdn log prob evaluation
    results in batch_size x batch_size return. This is probably because Uniform does
    not allow for event dimension > 1 and somewhere in pyknos it is used as if this was
    possible.
    Casting to BoxUniform solves it.
    """
    num_dim = 1
    x_o = torch.tensor([[1.0]])
    num_samples = 100

    # likelihood_mean will be likelihood_shift+theta
    likelihood_shift = -1.0 * ones(num_dim)
    likelihood_cov = 0.3 * eye(num_dim)

    prior = Uniform(low=torch.zeros(num_dim), high=torch.ones(num_dim))

    def simulator(theta: Tensor) -> Tensor:
        return linear_gaussian(theta, likelihood_shift, likelihood_cov)

    simulator, prior = prepare_for_sbi(simulator, prior)
    inference = SNPE(prior, density_estimator="mdn")

    theta, x = simulate_for_sbi(simulator, prior, 100)
    _ = inference.append_simulations(theta, x).train(training_batch_size=50)
    posterior = inference.build_posterior().set_default_x(x_o)
    samples = posterior.sample((num_samples,))
    log_probs = posterior.log_prob(samples)

    assert log_probs.shape == torch.Size([num_samples])
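
The workaround mentioned in the docstring, as a minimal sketch (in the test above, prepare_for_sbi performs this cast automatically):

from sbi.utils import BoxUniform

# BoxUniform gives the 1D prior an event dimension of 1, avoiding the
# batch_size x batch_size log-prob shape described above.
prior = BoxUniform(low=torch.zeros(1), high=torch.ones(1))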
Example #2
def test_get_dataloaders(training_batch_size):

    N = 1000
    validation_fraction = 0.1

    dataset = TensorDataset(torch.ones(N), torch.zeros(N))

    inferer = SNPE()

    _, val_loader = inferer.get_dataloaders(
        dataset,
        training_batch_size=training_batch_size,
        validation_fraction=validation_fraction,
    )

    assert len(val_loader) * val_loader.batch_size == int(validation_fraction * N)
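
To make the assertion concrete, a tiny arithmetic sketch with assumed values:

# Assumed values for illustration only.
N, validation_fraction, training_batch_size = 1000, 0.1, 50
num_val = int(validation_fraction * N)        # 100 validation samples
num_batches = num_val // training_batch_size  # 2 batches of 50
assert num_batches * training_batch_size == num_val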
Example #3
def flexible():
    num_dim = 3
    x_o = torch.ones(1, num_dim)
    prior_mean = torch.zeros(num_dim)
    prior_cov = torch.eye(num_dim)
    simulator = diagonal_linear_gaussian

    # flexible interface
    prior = torch.distributions.MultivariateNormal(
        loc=prior_mean, covariance_matrix=prior_cov
    )
    simulator, prior = prepare_for_sbi(simulator, prior)
    inference = SNPE(prior)

    theta, x = simulate_for_sbi(simulator, proposal=prior, num_simulations=500)
    density_estimator = inference.append_simulations(theta, x).train()
    posterior = inference.build_posterior(density_estimator)
    posterior.sample((100,), x=x_o)

    return posterior
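
A hypothetical call site for this helper (a sketch; sampling and log-prob evaluation both condition on the observation):

# x_o mirrors the observation defined inside flexible().
posterior = flexible()
x_o = torch.ones(1, 3)
samples = posterior.sample((100,), x=x_o)
log_probs = posterior.log_prob(samples, x=x_o)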
Example #4
def flexible():
    num_dim = 3
    x_o = torch.ones(1, num_dim)
    prior_mean = torch.zeros(num_dim)
    prior_cov = torch.eye(num_dim)

    # flexible interface
    prior = torch.distributions.MultivariateNormal(loc=prior_mean,
                                                   covariance_matrix=prior_cov)
    simulator, prior = prepare_for_sbi(diagonal_linear_gaussian, prior)
    inference = SNPE(prior)

    theta, x = simulate_for_sbi(simulator, proposal=prior, num_simulations=500)
    density_estimator = inference.append_simulations(theta, x).train()
    posterior = inference.build_posterior(density_estimator)
    posterior.sample((100,), x=x_o)

    return posterior
Example #5
def test_embedding_net_api(method, num_dim: int, embedding_net: str):
    """Tests the API when using a preconfigured embedding net."""

    x_o = zeros(1, num_dim)

    # likelihood_mean will be likelihood_shift+theta
    likelihood_shift = -1.0 * ones(num_dim)
    likelihood_cov = 0.3 * eye(num_dim)

    prior = utils.BoxUniform(-2.0 * ones(num_dim), 2.0 * ones(num_dim))

    theta = prior.sample((1000, ))
    x = linear_gaussian(theta, likelihood_shift, likelihood_cov)

    if embedding_net == "mlp":
        embedding = FCEmbedding(input_dim=num_dim)
    else:
        raise NameError

    if method == "SNPE":
        density_estimator = posterior_nn("maf", embedding_net=embedding)
        inference = SNPE(prior,
                         density_estimator=density_estimator,
                         show_progress_bars=False)
    elif method == "SNLE":
        density_estimator = likelihood_nn("maf", embedding_net=embedding)
        inference = SNLE(prior,
                         density_estimator=density_estimator,
                         show_progress_bars=False)
    elif method == "SNRE":
        classifier = classifier_nn("resnet", embedding_net_x=embedding)
        inference = SNRE(prior,
                         classifier=classifier,
                         show_progress_bars=False)
    else:
        raise NameError

    _ = inference.append_simulations(theta, x).train(max_num_epochs=5)
    posterior = inference.build_posterior().set_default_x(x_o)

    s = posterior.sample((1, ))
    _ = posterior.potential(s)
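
Since embedding_net accepts any torch.nn.Module, a hand-rolled embedding works as well as the preconfigured FCEmbedding; a minimal sketch reusing num_dim and posterior_nn from the test above (not part of the original test):

import torch.nn as nn

# Any module mapping raw x to a feature vector can serve as the embedding;
# the hidden/output widths here are arbitrary choices.
custom_embedding = nn.Sequential(
    nn.Linear(num_dim, 32),
    nn.ReLU(),
    nn.Linear(32, 16),
)
density_estimator = posterior_nn("maf", embedding_net=custom_embedding)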
Example #6
    x = _W_eigs(U, V)
    return x


simulator, prior = prepare_for_sbi(simulator, prior)
density_estimator_build_fun = posterior_nn(
    model="maf",
    hidden_features=50,
    num_transforms=num_transforms,
    z_score_x=False,
    z_score_theta=False,
    support_map=True,
)
x_0 = torch.tensor([0.5, 1.5])

inference = SNPE(prior, density_estimator=density_estimator_build_fun)

best_round_ind = 0

# log initialized state
_theta, _x = simulate_for_sbi(simulator,
                              proposal=prior,
                              num_simulations=num_sims)
density_estimator = density_estimator_build_fun(_theta, _x)
posterior = inference.build_posterior(density_estimator)
z = posterior.sample((M, ), x=x_0)
x = simulator(z).numpy()
posteriors, zs, xs = [posterior], [z.numpy()], [x]
log_probs = [posterior.log_prob(z, x=x_0).numpy()]
mean_x = np.mean(x, axis=0)
distances = [np.linalg.norm(mean_x - x_0.numpy())]
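
The snippet above logs the untrained round-0 state; the usual sequential continuation would look roughly like this (a sketch following sbi's standard multi-round loop; num_rounds is assumed, everything else reuses the names above):

for _ in range(num_rounds):  # num_rounds is an assumed variable
    proposal = posterior.set_default_x(x_0)
    _theta, _x = simulate_for_sbi(simulator, proposal=proposal,
                                  num_simulations=num_sims)
    density_estimator = inference.append_simulations(
        _theta, _x, proposal=proposal).train()
    posterior = inference.build_posterior(density_estimator)
    z = posterior.sample((M,), x=x_0)
    x = simulator(z).numpy()
    posteriors.append(posterior)
    zs.append(z.numpy())
    xs.append(x)
    log_probs.append(posterior.log_prob(z, x=x_0).numpy())
    distances.append(np.linalg.norm(np.mean(x, axis=0) - x_0.numpy()))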
Example #7
File: sbiutils_test.py  Project: bkmi/sbi
def test_gaussian_transforms(snpe_method: str, plot_results: bool = False):
    """
    Tests whether the the product between proposal and posterior is computed correctly.

    For SNPE-C, this initializes two MoGs with two components each. It then evaluates
    their product by simply multiplying the probabilities of the two. The result is
    compared to the product of two MoGs as implemented in APT.

    For SNPE-A, it initializes a MoG with two compontents and one Gaussian (with one
    component). It then devices the MoG by the Gaussian and compares it to the
    transformation in SNPE-A.

    Args:
        snpe_method: String indicating whether to test snpe-a or snpe-c.
        plot_results: Whether to plot the products of the distributions.
    """
    class MoG:
        def __init__(self, means, covs, logits):
            self._means = means
            self._covs = covs
            self._logits = logits

        def log_prob(self, theta):
            # Note: despite its name, this returns the weighted probability
            # density (the mixture pdf), not its logarithm.
            probs = zeros(theta.shape[0])
            for m, c, w in zip(self._means, self._covs, self._logits):
                mvn = MultivariateNormal(m, c)
                weighted_prob = torch.exp(mvn.log_prob(theta)) * w
                probs += weighted_prob
            return probs

    # Build a grid on which to evaluate the densities.
    bound = 5.0
    theta_range = torch.linspace(-bound, bound, 100)
    theta1_grid, theta2_grid = torch.meshgrid(theta_range, theta_range)
    theta_grid = torch.stack([theta1_grid, theta2_grid])
    theta_grid_flat = torch.reshape(theta_grid, (2, 100**2))

    # Generate two MoGs.
    means1 = torch.tensor([[2.0, 2.0], [-2.0, -2.0]])
    covs1 = torch.stack([0.5 * torch.eye(2), torch.eye(2)])
    weights1 = torch.tensor([0.3, 0.7])

    if snpe_method == "snpe_c":
        means2 = torch.tensor([[2.0, -2.2], [-2.0, 1.9]])
        covs2 = torch.stack([0.6 * torch.eye(2), 0.9 * torch.eye(2)])
        weights2 = torch.tensor([0.6, 0.4])
    elif snpe_method == "snpe_a":
        means2 = torch.tensor([[-0.2, -0.4]])
        covs2 = torch.stack([3.5 * torch.eye(2)])
        weights2 = torch.tensor([1.0])

    mog1 = MoG(means1, covs1, weights1)
    mog2 = MoG(means2, covs2, weights2)

    # Evaluate the product of their pdfs by evaluating them separately and multiplying.
    probs1_raw = mog1.log_prob(theta_grid_flat.T)
    probs1 = torch.reshape(probs1_raw, (100, 100))

    probs2_raw = mog2.log_prob(theta_grid_flat.T)
    probs2 = torch.reshape(probs2_raw, (100, 100))

    if snpe_method == "snpe_c":
        probs_mult = probs1 * probs2

        # Set up a SNPE object in order to use the
        # `_automatic_posterior_transformation()`.
        prior = BoxUniform(-5 * ones(2), 5 * ones(2))
        # Testing new z-score arg options.
        density_estimator = posterior_nn("mdn",
                                         z_score_theta=None,
                                         z_score_x=None)
        inference = SNPE(prior=prior, density_estimator=density_estimator)
        theta_ = torch.rand(100, 2)
        x_ = torch.rand(100, 2)
        _ = inference.append_simulations(theta_, x_).train(max_num_epochs=1)
        inference._set_state_for_mog_proposal()

        precs1 = torch.inverse(covs1)
        precs2 = torch.inverse(covs2)

        # `.unsqueeze(0)` is needed because the method requires a batch dimension.
        logits_pp, means_pp, _, covs_pp = inference._automatic_posterior_transformation(
            torch.log(weights1.unsqueeze(0)),
            means1.unsqueeze(0),
            precs1.unsqueeze(0),
            torch.log(weights2.unsqueeze(0)),
            means2.unsqueeze(0),
            precs2.unsqueeze(0),
        )

    elif snpe_method == "snpe_a":
        probs_mult = probs1 / probs2

        prior = BoxUniform(-5 * ones(2), 5 * ones(2))

        inference = SNPE_A(prior=prior)
        theta_ = torch.rand(100, 2)
        x_ = torch.rand(100, 2)
        density_estimator = inference.append_simulations(
            theta_, x_).train(max_num_epochs=1)
        wrapped_density_estimator = SNPE_A_MDN(flow=density_estimator,
                                               proposal=prior,
                                               prior=prior,
                                               device="cpu")

        precs1 = torch.inverse(covs1)
        precs2 = torch.inverse(covs2)

        # `.unsqueeze(0)` is needed because the method requires a batch dimension.
        (
            logits_pp,
            means_pp,
            _,
            covs_pp,
        ) = wrapped_density_estimator._proposal_posterior_transformation(
            torch.log(weights2.unsqueeze(0)),
            means2.unsqueeze(0),
            precs2.unsqueeze(0),
            torch.log(weights1.unsqueeze(0)),
            means1.unsqueeze(0),
            precs1.unsqueeze(0),
        )

    # Normalize weights.
    logits_pp_norm = logits_pp - torch.logsumexp(
        logits_pp, dim=-1, keepdim=True)
    weights_pp = torch.exp(logits_pp_norm)

    # Evaluate the product of the two distributions.
    mog_apt = MoG(means_pp[0], covs_pp[0], weights_pp[0])

    probs_apt_raw = mog_apt.log_prob(theta_grid_flat.T)
    probs_apt = torch.reshape(probs_apt_raw, (100, 100))

    # Compute the error between the two methods.
    norm_probs_mult = probs_mult / torch.max(probs_mult)
    norm_probs_apt = probs_apt / torch.max(probs_apt)
    error = torch.abs(norm_probs_mult - norm_probs_apt)

    assert torch.max(error) < 1e-5

    if plot_results:
        _, ax = plt.subplots(1, 4, figsize=(16, 4))

        ax[0].imshow(probs1, extent=[-bound, bound, -bound, bound])
        ax[0].set_title("p_1")
        ax[1].imshow(probs2, extent=[-bound, bound, -bound, bound])
        ax[1].set_title("p_2")
        ax[2].imshow(probs_mult, extent=[-bound, bound, -bound, bound])
        ax[3].imshow(probs_apt, extent=[-bound, bound, -bound, bound])
        if snpe_method == "snpe_c":
            ax[2].set_title("p_1 * p_2")
            ax[3].set_title("APT")
        elif snpe_method == "snpe_a":
            ax[2].set_title("p_1 / p_2")
            ax[3].set_title("SNPE-A")

        plt.show()
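
The closed form underlying this test: the product of two Gaussians is, up to normalization, again a Gaussian whose precision is the sum of the two precisions. A minimal sketch of the identity for one pair of components (the means and covariances are taken from the test above):

import torch

# Component values copied from means1[0]/covs1[0] and means2/covs2 (snpe_a).
m1, m2 = torch.tensor([2.0, 2.0]), torch.tensor([-0.2, -0.4])
C1, C2 = 0.5 * torch.eye(2), 3.5 * torch.eye(2)

P1, P2 = torch.inverse(C1), torch.inverse(C2)  # precision matrices
C_prod = torch.inverse(P1 + P2)                # covariance of the product
m_prod = C_prod @ (P1 @ m1 + P2 @ m2)          # mean of the product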
Example #8
        new_params = dict(
            zip(self.param_names,
                new_param_values.detach().cpu().numpy()))
        self.params.update(new_params)

        net = Network(self.params)
        with JoblibBackend(n_jobs=1):
            dpl = simulate_dipole(net, n_trials=1)

        summstats = torch.as_tensor(dpl[0].data['agg'])
        return summstats


hnn_simulator = HNNSimulator(params_fname, prior_dict)
simulator, prior = prepare_for_sbi(hnn_simulator, prior)
inference = SNPE(prior)

dill_save(simulator, 'simulator', save_suffix, save_path)
dill_save(prior, 'prior', save_suffix, save_path)
dill_save(inference, 'inference', save_suffix, save_path)

theta, x = simulate_for_sbi(simulator,
                            proposal=prior,
                            num_simulations=10000,
                            num_workers=48)
dill_save(theta, 'theta', save_suffix, save_path)
dill_save(x, 'x', save_suffix, save_path)

density_estimator = inference.append_simulations(theta, x).train()
dill_save(density_estimator, 'density_estimator', save_suffix, save_path)
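
A natural continuation, reusing the script's own dill_save helper to persist the resulting posterior (a sketch, not taken from the source script):

# Hypothetical continuation; dill_save is the helper used above.
posterior = inference.build_posterior(density_estimator)
dill_save(posterior, 'posterior', save_suffix, save_path)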
Example #9
def test_pair_plot_scatter(
    env: SimEnv,
    policy: Policy,
    layout: str,
    labels: Optional[str],
    legend_labels: Optional[str],
    axis_limits: Optional[str],
    use_kde: bool,
    use_trafo: bool,
):
    def _simulator(dp: to.Tensor) -> to.Tensor:
        """The most simple interface of a simulation to sbi, using `env` and `policy` from outer scope"""
        ro = rollout(
            env,
            policy,
            eval=True,
            reset_kwargs=dict(domain_param=dict(m=dp[0], k=dp[1], d=dp[2])))
        observation_sim = to.from_numpy(
            ro.observations[-1]).to(dtype=to.float32)
        return to.atleast_2d(observation_sim)

    # Fix the init state
    env.init_space = SingularStateSpace(env.init_space.sample_uniform())
    env_real = deepcopy(env)
    env_real.domain_param = {"mass": 0.8, "stiffness": 15, "d": 0.7}

    # Optionally transformed domain parameters for inference
    if use_trafo:
        env = LogDomainParamTransform(env, mask=["stiffness"])

    # Domain parameter mapping and prior
    dp_mapping = {0: "mass", 1: "stiffness", 2: "d"}
    k_low = np.log(10) if use_trafo else 10
    k_up = np.log(20) if use_trafo else 20
    prior = sbiutils.BoxUniform(low=to.tensor([0.5, k_low, 0.2]),
                                high=to.tensor([1.5, k_up, 0.8]))

    # Learn a likelihood from the simulator
    density_estimator = sbiutils.posterior_nn(model="maf",
                                              hidden_features=10,
                                              num_transforms=3)
    snpe = SNPE(prior, density_estimator)
    simulator, prior = prepare_for_sbi(_simulator, prior)
    domain_param, data_sim = simulate_for_sbi(simulator=simulator,
                                              proposal=prior,
                                              num_simulations=50,
                                              num_workers=1)
    snpe.append_simulations(domain_param, data_sim)
    density_estimator = snpe.train(max_num_epochs=5)
    posterior = snpe.build_posterior(density_estimator)

    # Create a fake (random) true domain parameter
    domain_param_gt = to.tensor([
        env_real.domain_param[dp_mapping[key]]
        for key in sorted(dp_mapping.keys())
    ])
    domain_param_gt += domain_param_gt * to.randn(len(dp_mapping)) / 10
    domain_param_gt = domain_param_gt.unsqueeze(0)
    data_real = simulator(domain_param_gt)

    domain_params, log_probs = SBIBase.eval_posterior(
        posterior,
        data_real,
        num_samples=6,
        normalize_posterior=False,
        subrtn_sbi_sampling_hparam=dict(sample_with_mcmc=False),
    )
    dp_samples = [
        domain_params.reshape(1, -1, domain_params.shape[-1]).squeeze()
    ]

    if layout == "inside":
        num_rows, num_cols = len(dp_mapping), len(dp_mapping)
    else:
        num_rows, num_cols = len(dp_mapping) + 1, len(dp_mapping) + 1

    _, axs = plt.subplots(num_rows,
                          num_cols,
                          figsize=(8, 8),
                          tight_layout=True)
    fig = draw_posterior_pairwise_scatter(
        axs=axs,
        dp_samples=dp_samples,
        dp_mapping=dp_mapping,
        prior=prior if axis_limits == "use_prior" else None,
        env_sim=env,
        env_real=env_real,
        axis_limits=axis_limits,
        marginal_layout=layout,
        labels=labels,
        legend_labels=legend_labels,
        use_kde=use_kde,
    )
    assert fig is not None
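
For comparison, sbi ships its own pairwise plotting helper; a minimal sketch using the posterior from above:

from sbi.analysis import pairplot

# Sample count and figure size are arbitrary choices for illustration.
samples = posterior.sample((1000,), x=data_real)
fig, axes = pairplot(samples, figsize=(8, 8))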
Example #10
def test_pair_plot(
    env: SimEnv,
    policy: Policy,
    layout: str,
    labels: Optional[str],
    prob_labels: Optional[str],
    use_prior: bool,
    use_trafo: bool,
):
    def _simulator(dp: to.Tensor) -> to.Tensor:
        """The most simple interface of a simulation to sbi, using `env` and `policy` from outer scope"""
        ro = rollout(
            env,
            policy,
            eval=True,
            reset_kwargs=dict(domain_param=dict(m=dp[0], k=dp[1], d=dp[2])))
        observation_sim = to.from_numpy(
            ro.observations[-1]).to(dtype=to.float32)
        return to.atleast_2d(observation_sim)

    # Fix the init state
    env.init_space = SingularStateSpace(env.init_space.sample_uniform())
    env_real = deepcopy(env)
    env_real.domain_param = {"mass": 0.8, "stiffness": 35, "d": 0.7}

    # Optionally transformed domain parameters for inference
    if use_trafo:
        env = SqrtDomainParamTransform(env, mask=["stiffness"])

    # Domain parameter mapping and prior
    dp_mapping = {0: "mass", 1: "stiffness", 2: "d"}
    prior = sbiutils.BoxUniform(low=to.tensor([0.5, 20, 0.2]),
                                high=to.tensor([1.5, 40, 0.8]))

    # Learn a likelihood from the simulator
    density_estimator = sbiutils.posterior_nn(model="maf",
                                              hidden_features=10,
                                              num_transforms=3)
    snpe = SNPE(prior, density_estimator)
    simulator, prior = prepare_for_sbi(_simulator, prior)
    domain_param, data_sim = simulate_for_sbi(simulator=simulator,
                                              proposal=prior,
                                              num_simulations=50,
                                              num_workers=1)
    snpe.append_simulations(domain_param, data_sim)
    density_estimator = snpe.train(max_num_epochs=5)
    posterior = snpe.build_posterior(density_estimator)

    # Create a fake (random) true domain parameter
    domain_param_gt = to.tensor(
        [env_real.domain_param[key] for _, key in dp_mapping.items()])
    domain_param_gt += domain_param_gt * to.randn(len(dp_mapping)) / 5
    domain_param_gt = domain_param_gt.unsqueeze(0)
    data_real = simulator(domain_param_gt)

    # Get a (random) condition
    condition = Embedding.pack(domain_param_gt.clone())

    if layout == "inside":
        num_rows, num_cols = len(dp_mapping), len(dp_mapping)
    else:
        num_rows, num_cols = len(dp_mapping) + 1, len(dp_mapping) + 1

    if use_prior:
        grid_bounds = None
    else:
        prior = None
        grid_bounds = to.cat(
            [to.zeros((len(dp_mapping), 1)),
             to.ones((len(dp_mapping), 1))],
            dim=1)

    _, axs = plt.subplots(num_rows,
                          num_cols,
                          figsize=(14, 14),
                          tight_layout=True)
    fig = draw_posterior_pairwise_heatmap(
        axs,
        posterior,
        data_real,
        dp_mapping,
        condition,
        prior=prior,
        env_real=env_real,
        marginal_layout=layout,
        grid_bounds=grid_bounds,
        grid_res=100,
        normalize_posterior=False,
        rescale_posterior=True,
        labels=None if labels is None else [""] * len(dp_mapping),
        prob_labels=prob_labels,
    )

    assert fig is not None
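
sbi also provides a conditional variant that, like the heatmap above, evaluates the posterior on a grid given a condition; a sketch assuming sbi's conditional_pairplot, with the BoxUniform bounds from this test as axis limits:

from sbi.analysis import conditional_pairplot

# Assumed usage: condition on one posterior sample; limits match the prior.
posterior.set_default_x(data_real)
condition = posterior.sample((1,))
fig, axes = conditional_pairplot(
    density=posterior,
    condition=condition,
    limits=[[0.5, 1.5], [20, 40], [0.2, 0.8]],
)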