Exemplo n.º 1
0
def log_likelihood(model,
                   posterior_samples,
                   *args,
                   parallel=False,
                   batch_ndims=1,
                   **kwargs):
    """
    (EXPERIMENTAL INTERFACE) Returns log likelihood at observation nodes of model,
    given samples of all latent variables.

    :param model: Python callable containing Pyro primitives.
    :param dict posterior_samples: dictionary of samples from the posterior.
    :param args: model arguments.
    :param bool parallel: if True, all samples are passed to ``soft_vmap`` as a
        single chunk (chunk size equal to the total batch size); otherwise the
        chunk size is 1. Defaults to False.
    :param batch_ndims: the number of batch dimensions in posterior samples. Some usages:

        + set `batch_ndims=0` to get log likelihoods for 1 single sample

        + set `batch_ndims=1` to get log likelihoods for `posterior_samples`
          with shapes `(num_samples x ...)`

        + set `batch_ndims=2` to get log likelihoods for `posterior_samples`
          with shapes `(num_chains x num_samples x ...)`

    :param kwargs: model kwargs.
    :return: dict of log likelihoods at observation sites.
    :raises ValueError: if the leading ``batch_ndims`` dimensions of two sites
        in ``posterior_samples`` differ.
    """
    def single_loglik(samples):
        # For the empty-dict case below, `samples` is a dummy array rather than
        # a dict, so the model is traced without any substitution.
        substituted_model = (substitute(model, samples) if isinstance(
            samples, dict) else model)
        model_trace = trace(substituted_model).get_trace(*args, **kwargs)
        return {
            name: site["fn"].log_prob(site["value"])
            for name, site in model_trace.items()
            if site["type"] == "sample" and site["is_observed"]
        }

    prototype_site = batch_shape = None
    for name, sample in posterior_samples.items():
        # Use jnp.shape for both the check and the message: samples without a
        # `.shape` attribute (e.g. Python scalars) must still yield the
        # intended ValueError rather than an AttributeError.
        sample_batch_shape = jnp.shape(sample)[:batch_ndims]
        if batch_shape is not None and sample_batch_shape != batch_shape:
            raise ValueError(
                f"Batch shapes at site {name} and {prototype_site} "
                f"should be the same, but got "
                f"{sample_batch_shape} and {batch_shape}")
        prototype_site = name
        batch_shape = sample_batch_shape

    if batch_shape is None:  # posterior_samples is an empty dict
        batch_shape = (1, ) * batch_ndims
        posterior_samples = np.zeros(batch_shape)

    batch_size = int(np.prod(batch_shape))
    chunk_size = batch_size if parallel else 1
    return soft_vmap(single_loglik, posterior_samples, len(batch_shape),
                     chunk_size)
Exemplo n.º 2
0
def test_soft_vmap(batch_shape, chunk_size):
    """Check that ``soft_vmap`` over a dict input matches the direct computation.

    The mapped function expands entry ``"a"`` with a trailing axis of length 4
    and bitwise-inverts every other entry; the mapped result is compared with
    the same operations applied to the full batched inputs.
    """
    def f(x):
        result = {}
        for key, value in x.items():
            if key == "a":
                result[key] = value[..., None] * jnp.ones(4)
            else:
                result[key] = ~value
        return result

    inputs = {
        "a": jnp.ones(batch_shape + (4,)),
        "b": jnp.zeros(batch_shape).astype(bool),
    }
    n_batch_dims = len(batch_shape)
    outputs = soft_vmap(f, inputs, n_batch_dims, chunk_size)

    assert set(outputs.keys()) == {"a", "b"}
    expected_a = inputs["a"][..., None] * jnp.ones(4)
    assert_allclose(outputs["a"], expected_a)
    expected_b = ~inputs["b"]
    assert_allclose(outputs["b"], expected_b)
Exemplo n.º 3
0
def _predictive(
        rng_key,
        model,
        posterior_samples,
        batch_shape,
        return_sites=None,
        parallel=True,
        model_args=(),
        model_kwargs={},
):
    """
    Trace ``model`` once per posterior sample and collect selected site values.

    :param rng_key: a single PRNG key. It is split into one key per sample
        when ``batch_shape`` describes more than one sample.
    :param model: Python callable containing Pyro primitives.
    :param posterior_samples: samples substituted into the model; mapped over
        together with the per-sample keys by ``soft_vmap`` (leading dimensions
        are presumably ``batch_shape`` -- confirm against callers).
    :param batch_shape: sequence of batch dimension sizes; its product is the
        number of samples.
    :param return_sites: which sites to return. ``None`` selects sample sites
        not present in the substituted samples plus deterministic sites;
        ``""`` selects every site whose type is not ``"plate"``; any other
        value is used as the collection of site names.
    :param bool parallel: if True, all samples form a single chunk; otherwise
        the chunk size passed to ``soft_vmap`` is 1.
    :param tuple model_args: positional arguments for the model.
    :param dict model_kwargs: keyword arguments for the model. The default
        dict is only unpacked, never mutated.
    :return: the result of ``soft_vmap`` applied to the per-sample dict of
        ``{site name: value}``.
    """
    # Wrap the model in the mask handler with mask=False before tracing.
    model = numpyro.handlers.mask(model, mask=False)

    def single_prediction(val):
        rng_key, samples = val
        model_trace = trace(seed(substitute(model, samples),
                                 rng_key)).get_trace(*model_args,
                                                     **model_kwargs)
        if return_sites is not None:
            if return_sites == "":
                sites = {
                    k
                    for k, site in model_trace.items()
                    if site["type"] != "plate"
                }
            else:
                sites = return_sites
        else:
            sites = {
                k
                for k, site in model_trace.items()
                if (site["type"] == "sample" and k not in samples) or (
                    site["type"] == "deterministic")
            }
        return {
            name: site["value"]
            for name, site in model_trace.items() if name in sites
        }

    num_samples = int(np.prod(batch_shape))
    # Record the shape of a single key before splitting instead of assuming a
    # trailing dimension of 2: typed keys have shape () and other PRNG
    # implementations use different key widths.
    key_shape = rng_key.shape
    if num_samples > 1:
        rng_key = random.split(rng_key, num_samples)
    rng_key = rng_key.reshape((*batch_shape, *key_shape))
    chunk_size = num_samples if parallel else 1
    return soft_vmap(single_prediction, (rng_key, posterior_samples),
                     len(batch_shape), chunk_size)
Exemplo n.º 4
0
def _predictive(rng_key,
                model,
                posterior_samples,
                batch_shape,
                return_sites=None,
                parallel=True,
                model_args=(),
                model_kwargs={}):
    """
    Trace ``model`` once per posterior sample and collect selected site values.

    :param rng_key: a single PRNG key. It is split into one key per sample
        when ``batch_shape`` describes more than one sample.
    :param model: Python callable containing Pyro primitives.
    :param posterior_samples: samples substituted into the model; mapped over
        together with the per-sample keys by ``soft_vmap`` (leading dimensions
        are presumably ``batch_shape`` -- confirm against callers).
    :param batch_shape: sequence of batch dimension sizes; its product is the
        number of samples.
    :param return_sites: which sites to return. ``None`` selects sample sites
        not present in the substituted samples plus deterministic sites;
        ``''`` selects every site whose type is not ``'plate'``; any other
        value is used as the collection of site names.
    :param bool parallel: if True, all samples form a single chunk; otherwise
        the chunk size passed to ``soft_vmap`` is 1.
    :param tuple model_args: positional arguments for the model.
    :param dict model_kwargs: keyword arguments for the model. The default
        dict is only unpacked, never mutated.
    :return: the result of ``soft_vmap`` applied to the per-sample dict of
        ``{site name: value}``.
    """
    def single_prediction(val):
        rng_key, samples = val
        model_trace = trace(seed(substitute(model, samples),
                                 rng_key)).get_trace(*model_args,
                                                     **model_kwargs)
        if return_sites is not None:
            if return_sites == '':
                sites = {
                    k
                    for k, site in model_trace.items()
                    if site['type'] != 'plate'
                }
            else:
                sites = return_sites
        else:
            sites = {
                k
                for k, site in model_trace.items()
                if (site['type'] == 'sample' and k not in samples) or (
                    site['type'] == 'deterministic')
            }
        return {
            name: site['value']
            for name, site in model_trace.items() if name in sites
        }

    num_samples = int(np.prod(batch_shape))
    # Record the shape of a single key before splitting instead of assuming a
    # trailing dimension of 2: typed keys have shape () and other PRNG
    # implementations use different key widths.
    key_shape = rng_key.shape
    if num_samples > 1:
        rng_key = random.split(rng_key, num_samples)
    rng_key = rng_key.reshape((*batch_shape, *key_shape))
    chunk_size = num_samples if parallel else 1
    return soft_vmap(single_prediction, (rng_key, posterior_samples),
                     len(batch_shape), chunk_size)
Exemplo n.º 5
0
def _predictive(
        rng_key,
        model,
        posterior_samples,
        batch_shape,
        return_sites=None,
        infer_discrete=False,
        parallel=True,
        model_args=(),
        model_kwargs={},
):
    """
    Trace ``model`` once per posterior sample and collect selected site values.

    :param rng_key: a single PRNG key. It is split into one key per sample
        when ``batch_shape`` describes more than one sample (and once more up
        front when ``infer_discrete`` is True, to trace a prototype).
    :param model: Python callable containing Pyro primitives.
    :param posterior_samples: samples substituted into (or conditioned on by)
        the model; mapped over together with the per-sample keys by
        ``soft_vmap`` (leading dimensions are presumably ``batch_shape`` --
        confirm against callers).
    :param batch_shape: sequence of batch dimension sizes; its product is the
        number of samples.
    :param return_sites: which sites to return. ``None`` selects sample sites
        not present in the samples plus deterministic sites; ``""`` selects
        every site whose type is not ``"plate"``; any other value is used as
        the collection of site names.
    :param bool infer_discrete: if True, values are drawn with
        ``_sample_posterior`` on the enumerated model conditioned on each
        sample, and a prototype trace (built from the first sample) is used
        to decide which sites exist.
    :param bool parallel: if True, all samples form a single chunk; otherwise
        the chunk size passed to ``soft_vmap`` is 1.
    :param tuple model_args: positional arguments for the model.
    :param dict model_kwargs: keyword arguments for the model. The default
        dict is only unpacked, never mutated.
    :return: the result of ``soft_vmap`` applied to the per-sample dict of
        ``{site name: value}``.
    """
    masked_model = numpyro.handlers.mask(model, mask=False)
    if infer_discrete:
        # inspect the model to get some structure
        rng_key, subkey = random.split(rng_key)
        batch_ndim = len(batch_shape)
        # Flatten the batch dimensions and take the first sample as prototype.
        prototype_sample = tree_map(
            lambda x: jnp.reshape(x, (-1, ) + jnp.shape(x)[batch_ndim:])[0],
            posterior_samples,
        )
        prototype_trace = trace(
            seed(substitute(masked_model, prototype_sample),
                 subkey)).get_trace(*model_args, **model_kwargs)
        first_available_dim = -_guess_max_plate_nesting(prototype_trace) - 1

    def single_prediction(val):
        rng_key, samples = val
        if infer_discrete:
            from numpyro.contrib.funsor import config_enumerate
            from numpyro.contrib.funsor.discrete import _sample_posterior

            # Site selection below relies on the prototype trace, since no
            # per-sample trace is produced on this path.
            model_trace = prototype_trace
            temperature = 1
            pred_samples = _sample_posterior(
                config_enumerate(condition(model, samples)),
                first_available_dim,
                temperature,
                rng_key,
                *model_args,
                **model_kwargs,
            )
        else:
            model_trace = trace(
                seed(substitute(masked_model, samples),
                     rng_key)).get_trace(*model_args, **model_kwargs)
            pred_samples = {
                name: site["value"]
                for name, site in model_trace.items()
            }

        if return_sites is not None:
            if return_sites == "":
                sites = {
                    k
                    for k, site in model_trace.items()
                    if site["type"] != "plate"
                }
            else:
                sites = return_sites
        else:
            sites = {
                k
                for k, site in model_trace.items()
                if (site["type"] == "sample" and k not in samples) or (
                    site["type"] == "deterministic")
            }
        return {
            name: value
            for name, value in pred_samples.items() if name in sites
        }

    num_samples = int(np.prod(batch_shape))
    # Record the shape of a single key before splitting instead of assuming a
    # trailing dimension of 2: typed keys have shape () and other PRNG
    # implementations use different key widths.
    key_shape = rng_key.shape
    if num_samples > 1:
        rng_key = random.split(rng_key, num_samples)
    rng_key = rng_key.reshape((*batch_shape, *key_shape))
    chunk_size = num_samples if parallel else 1
    return soft_vmap(single_prediction, (rng_key, posterior_samples),
                     len(batch_shape), chunk_size)