Ejemplo n.º 1
0
def test_logistic_regression(jit, num_chains):
    """Slice sampling on a synthetic logistic regression recovers its weights.

    Labels are drawn from a Bernoulli whose logits are ``data @ true_coefs``;
    the posterior mean of ``beta`` must lie within 0.1 RMSE of ``true_coefs``.
    """
    dim = 3
    data = torch.randn(2000, dim)
    true_coefs = torch.arange(1.0, dim + 1.0)
    true_logits = (true_coefs * data).sum(-1)
    labels = dist.Bernoulli(logits=true_logits).sample()

    def model(data):
        prior = dist.Normal(torch.zeros(dim), torch.ones(dim))
        beta = pyro.sample("beta", prior)
        logits = (beta * data).sum(-1)
        # `labels` is captured from the enclosing test as the observed value.
        return pyro.sample("y", dist.Bernoulli(logits=logits), obs=labels)

    kernel = Slice(model, jit_compile=jit, ignore_jit_warnings=True)
    mcmc = MCMC(kernel,
                num_samples=500,
                warmup_steps=100,
                num_chains=num_chains,
                mp_context="fork",
                available_cpu=1)
    mcmc.run(data)
    posterior_mean = mcmc.get_samples()["beta"].mean(0)
    assert_equal(rmse(true_coefs, posterior_mean).item(), 0.0, prec=0.1)
Ejemplo n.º 2
0
def test_beta_binomial(hyperpriors):
    """Collapsed Beta-Binomial model: the replayed posterior of ``probs``
    matches the generating probabilities within 0.05.

    When ``hyperpriors`` is true, ``alpha``/``beta`` get HalfCauchy(1) priors;
    otherwise they are fixed to ones.
    """
    def model(data):
        with pyro.plate("plate_0", data.shape[-1]):
            if hyperpriors:
                alpha = pyro.sample("alpha", dist.HalfCauchy(1.0))
                beta = pyro.sample("beta", dist.HalfCauchy(1.0))
            else:
                alpha = torch.tensor([1.0, 1.0])
                beta = torch.tensor([1.0, 1.0])
            pair = BetaBinomialPair()
            with pyro.plate("plate_1", data.shape[-2]):
                probs = pyro.sample("probs", pair.latent(alpha, beta))
                with pyro.plate("data", data.shape[0]):
                    # `total_count` is bound later in the enclosing test; the
                    # closure resolves it when the model is actually run.
                    likelihood = pair.conditional(probs=probs,
                                                  total_count=total_count)
                    pyro.sample("binomial", likelihood, obs=data)

    true_probs = torch.tensor([[0.7, 0.4], [0.6, 0.4]])
    total_count = torch.tensor([[1000, 600], [400, 800]])
    num_samples = 80
    generator = dist.Binomial(total_count=total_count, probs=true_probs)
    data = generator.sample(sample_shape=torch.Size((10, )))

    kernel = Slice(collapse_conjugate(model),
                   jit_compile=True,
                   ignore_jit_warnings=True)
    mcmc = MCMC(kernel, num_samples=num_samples, warmup_steps=50)
    mcmc.run(data)
    posterior = posterior_replay(model,
                                 mcmc.get_samples(),
                                 data,
                                 num_samples=num_samples)
    assert_equal(posterior["probs"].mean(0), true_probs, prec=0.05)
Ejemplo n.º 3
0
def test_gaussian_mixture_model(jit):
    """Slice sampling on a 3-component 1-D Gaussian mixture.

    Sorted posterior means of the mixture weights (``phi``) and the cluster
    locations must match the generating values within 0.05 and 0.2.
    """
    K, N = 3, 1000

    def gmm(data):
        phi = pyro.sample("phi", dist.Dirichlet(torch.ones(K)))
        prior_locs = torch.arange(float(K))
        with pyro.plate("num_clusters", K):
            locs = pyro.sample("cluster_means", dist.Normal(prior_locs, 1.0))
        with pyro.plate("data", data.shape[0]):
            z = pyro.sample("assignments", dist.Categorical(phi))
            pyro.sample("obs", dist.Normal(locs[z], 1.0), obs=data)
        return locs

    true_cluster_means = torch.tensor([1.0, 5.0, 10.0])
    true_mix_proportions = torch.tensor([0.1, 0.3, 0.6])
    true_z = dist.Categorical(true_mix_proportions).sample(torch.Size((N, )))
    data = dist.Normal(true_cluster_means[true_z], 1.0).sample()

    kernel = Slice(gmm,
                   max_plate_nesting=1,
                   jit_compile=jit,
                   ignore_jit_warnings=True)
    mcmc = MCMC(kernel, num_samples=300, warmup_steps=100)
    mcmc.run(data)
    samples = mcmc.get_samples()

    # Sorting makes the comparison invariant to label switching.
    phi_hat, _ = samples["phi"].mean(0).sort()
    assert_equal(phi_hat, true_mix_proportions, prec=0.05)
    locs_hat, _ = samples["cluster_means"].mean(0).sort()
    assert_equal(locs_hat, true_cluster_means, prec=0.2)
Ejemplo n.º 4
0
def test_gamma_normal(jit):
    """Slice sampling on a Normal likelihood with a Gamma prior on its scale.

    Data are drawn from ``Normal(3, true_std)``; the posterior mean of
    ``p_latent`` (the scale) must match ``true_std`` within 0.05.
    """
    def model(data):
        # Gamma's positional order is (concentration, rate). The original
        # passed tensors named `rate, concentration` in that order, so the
        # names were swapped. Keyword arguments make the intent explicit.
        concentration = torch.tensor([1.0, 1.0])
        rate = torch.tensor([1.0, 1.0])
        p_latent = pyro.sample(
            "p_latent", dist.Gamma(concentration=concentration, rate=rate))
        pyro.sample("obs", dist.Normal(3, p_latent), obs=data)
        return p_latent

    true_std = torch.tensor([0.5, 2])
    data = dist.Normal(3, true_std).sample(sample_shape=(torch.Size((2000, ))))
    slice_kernel = Slice(model, jit_compile=jit, ignore_jit_warnings=True)
    mcmc = MCMC(slice_kernel, num_samples=200, warmup_steps=100)
    mcmc.run(data)
    samples = mcmc.get_samples()
    assert_equal(samples["p_latent"].mean(0), true_std, prec=0.05)
Ejemplo n.º 5
0
def test_dirichlet_categorical(jit):
    """A flat Dirichlet prior with a Categorical likelihood: the posterior mean
    of ``p_latent`` must match the generating probabilities within 0.02.
    """
    def model(data):
        p_latent = pyro.sample("p_latent",
                               dist.Dirichlet(torch.tensor([1.0, 1.0, 1.0])))
        pyro.sample("obs", dist.Categorical(p_latent), obs=data)
        return p_latent

    true_probs = torch.tensor([0.1, 0.6, 0.3])
    data = dist.Categorical(true_probs).sample(
        sample_shape=torch.Size((2000, )))
    mcmc = MCMC(Slice(model, jit_compile=jit, ignore_jit_warnings=True),
                num_samples=200,
                warmup_steps=100)
    mcmc.run(data)
    assert_equal(mcmc.get_samples()["p_latent"].mean(0), true_probs, prec=0.02)
Ejemplo n.º 6
0
def test_beta_bernoulli():
    """Beta(1.1, 1.1) prior with a Bernoulli likelihood: the posterior mean of
    ``p_latent`` must match the generating probabilities within 0.02.
    """
    def model(data):
        concentration1 = torch.tensor([1.1, 1.1])
        concentration0 = torch.tensor([1.1, 1.1])
        p_latent = pyro.sample("p_latent",
                               dist.Beta(concentration1, concentration0))
        pyro.sample("obs", dist.Bernoulli(p_latent), obs=data)
        return p_latent

    true_probs = torch.tensor([0.9, 0.1])
    data = dist.Bernoulli(true_probs).sample(sample_shape=torch.Size((1000, )))
    mcmc = MCMC(Slice(model), num_samples=400, warmup_steps=200)
    mcmc.run(data)
    assert_equal(mcmc.get_samples()["p_latent"].mean(0), true_probs, prec=0.02)
Ejemplo n.º 7
0
def test_bernoulli_latent_model(jit):
    """Slice sampling on a model with two layers of discrete latents.

    ``y ~ Bernoulli(y_prob)``, ``z ~ Bernoulli(0.65 * y + 0.1)`` and
    observations ``~ Normal(2 * z, 1)``; the posterior mean of ``y_prob`` must
    match the generating value within 0.05.
    """
    # `@poutine.broadcast` was dropped: `pyro.plate` broadcasts automatically,
    # so the decorator was redundant (deprecated, and missing in newer Pyro).
    def model(data):
        y_prob = pyro.sample("y_prob", dist.Beta(1.0, 1.0))
        with pyro.plate("data", data.shape[0]):
            y = pyro.sample("y", dist.Bernoulli(y_prob))
            z = pyro.sample("z", dist.Bernoulli(0.65 * y + 0.1))
            pyro.sample("obs", dist.Normal(2.0 * z, 1.0), obs=data)

    N = 2000
    y_prob = torch.tensor(0.3)
    y = dist.Bernoulli(y_prob).sample(torch.Size((N, )))
    z = dist.Bernoulli(0.65 * y + 0.1).sample()
    data = dist.Normal(2.0 * z, 1.0).sample()
    slice_kernel = Slice(model,
                         max_plate_nesting=1,
                         jit_compile=jit,
                         ignore_jit_warnings=True)
    mcmc = MCMC(slice_kernel, num_samples=600, warmup_steps=200)
    mcmc.run(data)
    samples = mcmc.get_samples()
    assert_equal(samples["y_prob"].mean(0), y_prob, prec=0.05)
Ejemplo n.º 8
0
def test_gamma_beta(jit):
    """Beta likelihood with Gamma(1, 1) priors on both concentrations.

    The posterior means of ``alpha`` and ``beta`` must match the generating
    concentrations within 0.08 and 0.05.
    """
    def model(data):
        alpha = pyro.sample("alpha", dist.Gamma(concentration=1.0, rate=1.0))
        beta = pyro.sample("beta", dist.Gamma(concentration=1.0, rate=1.0))
        likelihood = dist.Beta(concentration1=alpha, concentration0=beta)
        pyro.sample("x", likelihood, obs=data)

    true_alpha, true_beta = torch.tensor(5.0), torch.tensor(1.0)
    generator = dist.Beta(concentration1=true_alpha, concentration0=true_beta)
    data = generator.sample(torch.Size((5000, )))

    kernel = Slice(model, jit_compile=jit, ignore_jit_warnings=True)
    mcmc = MCMC(kernel, num_samples=500, warmup_steps=200)
    mcmc.run(data)
    draws = mcmc.get_samples()
    assert_equal(draws["alpha"].mean(0), true_alpha, prec=0.08)
    assert_equal(draws["beta"].mean(0), true_beta, prec=0.05)
Ejemplo n.º 9
0
def test_slice_conjugate_gaussian(
    fixture,
    num_samples,
    warmup_steps,
    expected_means,
    expected_precs,
    mean_tol,
    std_tol,
):
    """Slice sampling on a chain of conjugate Gaussians.

    For each latent ``loc_1`` … ``loc_{fixture.chain_len}``, the posterior
    mean and standard deviation are compared (by RMSE across
    ``fixture.dim``) with the analytic values from ``expected_means`` and
    ``expected_precs`` (std = 1 / sqrt(precision)).
    """
    pyro.get_param_store().clear()
    mcmc = MCMC(Slice(fixture.model), num_samples, warmup_steps, num_chains=3)
    mcmc.run(fixture.data)
    samples = mcmc.get_samples()

    for k in range(fixture.chain_len):
        param_name = f"loc_{k + 1}"
        draws = samples[param_name]
        ones = torch.ones(fixture.dim)
        mean_hat, std_hat = draws.mean(0), draws.std(0)
        mean_ref = ones * expected_means[k]
        std_ref = 1 / torch.sqrt(ones * expected_precs[k])

        # Posterior means: sampled vs analytic.
        logger.debug("Posterior mean (actual) - {}".format(param_name))
        logger.debug(mean_hat)
        logger.debug("Posterior mean (expected) - {}".format(param_name))
        logger.debug(mean_ref)
        assert_equal(rmse(mean_hat, mean_ref).item(), 0.0, prec=mean_tol)

        # Posterior standard deviations: sampled vs analytic.
        logger.debug("Posterior std (actual) - {}".format(param_name))
        logger.debug(std_hat)
        logger.debug("Posterior std (expected) - {}".format(param_name))
        logger.debug(std_ref)
        assert_equal(rmse(std_hat, std_ref).item(), 0.0, prec=std_tol)
Ejemplo n.º 10
0
def test_gamma_poisson(hyperpriors):
    """Collapsed Gamma-Poisson model: the replayed posterior of ``rate``
    matches the generating rates within 0.3.

    When ``hyperpriors`` is true, ``alpha``/``beta`` get HalfCauchy(1) priors;
    otherwise they are fixed to ones.
    """
    def model(data):
        with pyro.plate("latent_dim", data.shape[1]):
            if hyperpriors:
                alpha = pyro.sample("alpha", dist.HalfCauchy(1.0))
                beta = pyro.sample("beta", dist.HalfCauchy(1.0))
            else:
                alpha = torch.tensor([1.0, 1.0])
                beta = torch.tensor([1.0, 1.0])
            pair = GammaPoissonPair()
            rate = pyro.sample("rate", pair.latent(alpha, beta))
            with pyro.plate("data", data.shape[0]):
                pyro.sample("obs", pair.conditional(rate), obs=data)

    true_rate = torch.tensor([3.0, 10.0])
    num_samples = 100
    data = dist.Poisson(rate=true_rate).sample(sample_shape=torch.Size((100, )))

    kernel = Slice(collapse_conjugate(model),
                   jit_compile=True,
                   ignore_jit_warnings=True)
    mcmc = MCMC(kernel, num_samples=num_samples, warmup_steps=50)
    mcmc.run(data)
    posterior = posterior_replay(model,
                                 mcmc.get_samples(),
                                 data,
                                 num_samples=num_samples)
    assert_equal(posterior["rate"].mean(0), true_rate, prec=0.3)
Ejemplo n.º 11
0
def run(
    task: Task,
    num_samples: int,
    num_observation: Optional[int] = None,
    observation: Optional[torch.Tensor] = None,
    num_chains: int = 10,
    num_warmup: int = 10000,
    kernel: str = "slice",
    kernel_parameters: Optional[Dict[str, Any]] = None,
    thinning: int = 1,
    diagnostics: bool = True,
    available_cpu: int = 1,
    mp_context: str = "fork",
    jit_compile: bool = False,
    automatic_transforms_enabled: bool = True,
    initial_params: Optional[torch.Tensor] = None,
    **kwargs: Any,
) -> torch.Tensor:
    """Runs MCMC using Pyro on the task's potential function.

    Produces `num_samples` samples after discarding warmup (burn-in) and
    applying thinning.

    Note that the number of simulations is not controlled with MCMC, since
    these algorithms are only used as a reference method in the benchmark.
    A `num_simulations` keyword is accepted but ignored (a warning is issued).

    MCMC is run on the potential function, which returns the unnormalized
    negative log posterior probability. Note that this requires a tractable
    likelihood. Pyro is used to automatically construct the potential function.

    Args:
        task: Task instance
        num_samples: Number of samples to generate from posterior
        num_observation: Observation number to load, alternative to `observation`
        observation: Observation, alternative to `num_observation`
        num_chains: Number of chains
        num_warmup: Warmup steps, during which parameters of the sampler are
            adapted. Warmup samples are not returned by the algorithm.
        kernel: "nuts", "hmc", or "slice" (case-insensitive)
        kernel_parameters: Parameters passed to the kernel. The dict is copied
            (the caller's dict is not modified); `jit_compile` and `transforms`
            are set on the copy and override any entries of the same name.
        thinning: Amount of thinning to apply, in order to avoid drawing
            correlated samples from the chain
        diagnostics: Flag for diagnostics (logging and TensorBoard output)
        available_cpu: Number of CPUs used to parallelize chains
        mp_context: Multiprocessing context forwarded to `MCMC`; only fork
            might work
        jit_compile: Just-in-time (JIT) compilation, can yield significant
            speed ups
        automatic_transforms_enabled: Whether or not to use automatic transforms
        initial_params: Parameters to initialize at; passed through
            `transforms["parameters"]` before being handed to `MCMC`

    Returns:
        Posterior samples of the "parameters" site, reshaped to
        `(num_samples, -1)` (one row per sample).

    Raises:
        NotImplementedError: If `kernel` is not one of nuts/hmc/slice.
    """
    assert not (num_observation is None and observation is None), \
        "Either `num_observation` or `observation` must be given"
    assert not (num_observation is not None and observation is not None), \
        "Only one of `num_observation` and `observation` may be given"

    tic = time.time()
    log = sbibm.get_logger(__name__)
    kernel_name = kernel.lower()

    hook_fn = None
    if diagnostics:
        log.info(f"MCMC sampling for observation {num_observation}")
        tb_writer, tb_close = tb_make_writer(
            logger=log,
            basepath=
            f"tensorboard/pyro_{kernel_name}/observation_{num_observation}",
        )
        hook_fn = tb_make_hook_fn(tb_writer)

    if "num_simulations" in kwargs:
        warnings.warn(
            "`num_simulations` was passed as a keyword but will be ignored, see docstring for more info."
        )

    # Prepare model and transforms
    conditioned_model = task._get_pyro_model(num_observation=num_observation,
                                             observation=observation)
    transforms = task._get_transforms(
        num_observation=num_observation,
        observation=observation,
        automatic_transforms_enabled=automatic_transforms_enabled,
    )

    # Copy so that the caller's dict is not mutated by the entries set below.
    kernel_parameters = (dict(kernel_parameters)
                         if kernel_parameters is not None else {})
    kernel_parameters["jit_compile"] = jit_compile
    kernel_parameters["transforms"] = transforms
    log.info("Using kernel: {name}({parameters})".format(
        name=kernel,
        parameters=",".join([f"{k}={v}"
                             for k, v in kernel_parameters.items()]),
    ))

    kernel_classes = {"nuts": NUTS, "hmc": HMC, "slice": Slice}
    if kernel_name not in kernel_classes:
        raise NotImplementedError(
            f"Unsupported kernel {kernel!r}; expected one of "
            f"{sorted(kernel_classes)}")
    mcmc_kernel = kernel_classes[kernel_name](model=conditioned_model,
                                              **kernel_parameters)

    if initial_params is not None:
        site_name = "parameters"
        initial_params = {site_name: transforms[site_name](initial_params)}

    mcmc_parameters = {
        "num_chains": num_chains,
        # Draw `thinning` times more so that `num_samples` remain after thinning.
        "num_samples": thinning * num_samples,
        "warmup_steps": num_warmup,
        "available_cpu": available_cpu,
        # Previously accepted but never forwarded, so it was silently ignored.
        "mp_context": mp_context,
        "initial_params": initial_params,
    }
    log.info("Calling MCMC with: MCMC({name}_kernel, {parameters})".format(
        name=kernel,
        parameters=",".join([f"{k}={v}" for k, v in mcmc_parameters.items()]),
    ))

    mcmc = MCMC(mcmc_kernel, hook_fn=hook_fn, **mcmc_parameters)
    mcmc.run()

    toc = time.time()
    log.info(f"Finished MCMC after {toc-tic:.3f} seconds")
    log.info(f"Automatic transforms {mcmc.transforms}")

    # Thin along the second axis of the stored samples (indexed as
    # [:, ::thinning, :], i.e. presumably (chain, draw, ...)); only the
    # "parameters" site is kept.
    log.info(f"Apply thinning of {thinning}")
    mcmc._samples = {
        "parameters": mcmc._samples["parameters"][:, ::thinning, :]
    }

    stored = mcmc._samples["parameters"]
    num_samples_available = stored.shape[0] * stored.shape[1]
    if num_samples_available < num_samples:
        warnings.warn("Some samples will be included multiple times")
        samples = mcmc.get_samples(num_samples=num_samples,
                                   group_by_chain=False)["parameters"]
    else:
        samples = mcmc.get_samples(group_by_chain=False)["parameters"]
        idx = torch.randperm(samples.shape[0])[:num_samples]
        samples = samples[idx]

    # Collapse everything after the sample axis into one parameter axis.
    # Unlike `.squeeze()`, this never drops the sample axis (num_samples == 1)
    # or a size-1 parameter axis (1-D parameters).
    samples = samples.reshape(samples.shape[0], -1)

    assert samples.shape[0] == num_samples

    if diagnostics:
        mcmc.summary()
        tb_ess(tb_writer, mcmc)
        tb_r_hat(tb_writer, mcmc)
        tb_marginals(tb_writer, mcmc)
        tb_acf(tb_writer, mcmc)
        tb_posteriors(tb_writer, mcmc)
        tb_plot_posterior(tb_writer, samples, tag="posterior/final")
        tb_close()

    return samples