def log_likelihood(model, posterior_samples, *args, parallel=False, batch_ndims=1, **kwargs):
    """
    (EXPERIMENTAL INTERFACE) Returns log likelihood at observation nodes of model,
    given samples of all latent variables.

    :param model: Python callable containing Pyro primitives.
    :param dict posterior_samples: dictionary of samples from the posterior.
    :param args: model arguments.
    :param bool parallel: if True, ``soft_vmap`` is called with ``chunk_size``
        equal to the total batch size (all samples in one chunk); otherwise
        with ``chunk_size=1`` (one sample at a time).
    :param batch_ndims: the number of batch dimensions in posterior samples.
        Some usages:

        + set `batch_ndims=0` to get log likelihoods for 1 single sample

        + set `batch_ndims=1` to get log likelihoods for `posterior_samples`
          with shapes `(num_samples x ...)`

        + set `batch_ndims=2` to get log likelihoods for `posterior_samples`
          with shapes `(num_chains x num_samples x ...)`

    :param kwargs: model kwargs.
    :return: dict of log likelihoods at observation sites.
    :raises ValueError: if the leading ``batch_ndims`` dimensions differ
        between sites of ``posterior_samples``.
    """

    def single_loglik(samples):
        # When ``posterior_samples`` is empty it is replaced below by a dummy
        # array, so only substitute when we actually have a dict of values.
        if isinstance(samples, dict):
            substituted_model = substitute(model, samples)
        else:
            substituted_model = model
        model_trace = trace(substituted_model).get_trace(*args, **kwargs)
        return {
            name: site["fn"].log_prob(site["value"])
            for name, site in model_trace.items()
            if site["type"] == "sample" and site["is_observed"]
        }

    # Every site must share the same leading ``batch_ndims`` dimensions.
    # ``jnp.shape`` is used (not ``.shape``) so lists/scalars are accepted
    # both in the check and in the error message.
    prototype_site = batch_shape = None
    for name, sample in posterior_samples.items():
        sample_batch_shape = jnp.shape(sample)[:batch_ndims]
        if batch_shape is None:
            prototype_site, batch_shape = name, sample_batch_shape
        elif sample_batch_shape != batch_shape:
            raise ValueError(
                f"Batch shapes at site {name} and {prototype_site} "
                f"should be the same, but got "
                f"{sample_batch_shape} and {batch_shape}")

    if batch_shape is None:
        # posterior_samples is an empty dict: use an all-ones batch shape and a
        # placeholder array, which single_loglik treats as "nothing to substitute".
        batch_shape = (1, ) * batch_ndims
        posterior_samples = np.zeros(batch_shape)

    batch_size = int(np.prod(batch_shape))
    chunk_size = batch_size if parallel else 1
    return soft_vmap(single_loglik, posterior_samples, len(batch_shape), chunk_size)
def test_soft_vmap(batch_shape, chunk_size):
    # Site "a" gains a trailing axis of size 4; site "b" is logically negated.
    def f(x):
        out = {}
        for key, val in x.items():
            if key == "a":
                out[key] = val[..., None] * jnp.ones(4)
            else:
                out[key] = ~val
        return out

    inputs = {
        "a": jnp.ones(batch_shape + (4,)),
        "b": jnp.zeros(batch_shape).astype(bool),
    }
    outputs = soft_vmap(f, inputs, len(batch_shape), chunk_size)

    # soft_vmap must preserve the dict structure and produce the same result
    # as applying f to the whole (unbatched) input directly.
    assert set(outputs.keys()) == {"a", "b"}
    expected_a = inputs["a"][..., None] * jnp.ones(4)
    assert_allclose(outputs["a"], expected_a)
    assert_allclose(outputs["b"], ~inputs["b"])
def _predictive(
    rng_key,
    model,
    posterior_samples,
    batch_shape,
    return_sites=None,
    parallel=True,
    model_args=(),
    model_kwargs=None,
):
    """Run ``model`` once per posterior sample and collect the requested sites.

    :param rng_key: PRNG key; split into one key per batch element.
    :param model: Python callable containing NumPyro primitives. It is wrapped
        in ``numpyro.handlers.mask(..., mask=False)`` before being traced.
    :param posterior_samples: dict of samples whose leading dims are ``batch_shape``.
    :param batch_shape: leading batch shape shared by ``posterior_samples``.
    :param return_sites: ``None`` returns unsubstituted sample sites plus
        deterministic sites; ``""`` returns every non-plate site; otherwise a
        collection of site names to return.
    :param bool parallel: if True, pass ``chunk_size=num_samples`` to
        ``soft_vmap``; otherwise ``chunk_size=1``.
    :param model_args: positional arguments for the model.
    :param model_kwargs: keyword arguments for the model (``None`` means none).
    :return: dict mapping site names to values batched over ``batch_shape``.
    """
    model_kwargs = {} if model_kwargs is None else model_kwargs
    model = numpyro.handlers.mask(model, mask=False)

    def single_prediction(val):
        rng_key, samples = val
        model_trace = trace(seed(substitute(model, samples), rng_key)).get_trace(
            *model_args, **model_kwargs)
        if return_sites is not None:
            if return_sites == "":
                sites = {
                    k for k, site in model_trace.items() if site["type"] != "plate"
                }
            else:
                sites = return_sites
        else:
            sites = {
                k
                for k, site in model_trace.items()
                if (site["type"] == "sample" and k not in samples)
                or (site["type"] == "deterministic")
            }
        return {
            name: site["value"] for name, site in model_trace.items() if name in sites
        }

    num_samples = int(np.prod(batch_shape))
    key_shape = np.shape(rng_key)
    if num_samples > 1:
        rng_key = random.split(rng_key, num_samples)
    # Always give the keys the same leading batch shape as the posterior
    # samples: with num_samples == 1 but batch_shape == (1,), an unreshaped
    # key of shape key_shape would have its key data mistaken for a batch dim.
    rng_key = rng_key.reshape((*batch_shape, *key_shape))
    chunk_size = num_samples if parallel else 1
    return soft_vmap(single_prediction, (rng_key, posterior_samples),
                     len(batch_shape), chunk_size)
def _predictive(rng_key, model, posterior_samples, batch_shape, return_sites=None,
                parallel=True, model_args=(), model_kwargs=None):
    """Run ``model`` once per posterior sample and collect the requested sites.

    :param rng_key: PRNG key; split into one key per batch element.
    :param model: Python callable containing NumPyro primitives.
    :param posterior_samples: dict of samples whose leading dims are ``batch_shape``.
    :param batch_shape: leading batch shape shared by ``posterior_samples``.
    :param return_sites: ``None`` returns unsubstituted sample sites plus
        deterministic sites; ``''`` returns every non-plate site; otherwise a
        collection of site names to return.
    :param bool parallel: if True, pass ``chunk_size=num_samples`` to
        ``soft_vmap``; otherwise ``chunk_size=1``.
    :param model_args: positional arguments for the model.
    :param model_kwargs: keyword arguments for the model (``None`` means none).
    :return: dict mapping site names to values batched over ``batch_shape``.
    """
    model_kwargs = {} if model_kwargs is None else model_kwargs

    def single_prediction(val):
        rng_key, samples = val
        model_trace = trace(seed(substitute(model, samples), rng_key)).get_trace(
            *model_args, **model_kwargs)
        if return_sites is not None:
            if return_sites == '':
                sites = {k for k, site in model_trace.items() if site['type'] != 'plate'}
            else:
                sites = return_sites
        else:
            sites = {
                k for k, site in model_trace.items()
                if (site['type'] == 'sample' and k not in samples)
                or (site['type'] == 'deterministic')
            }
        return {name: site['value'] for name, site in model_trace.items() if name in sites}

    num_samples = int(np.prod(batch_shape))
    key_shape = np.shape(rng_key)
    if num_samples > 1:
        rng_key = random.split(rng_key, num_samples)
    # Always give the keys the same leading batch shape as the posterior
    # samples: with num_samples == 1 but batch_shape == (1,), an unreshaped
    # key of shape key_shape would have its key data mistaken for a batch dim.
    rng_key = rng_key.reshape((*batch_shape, *key_shape))
    chunk_size = num_samples if parallel else 1
    return soft_vmap(single_prediction, (rng_key, posterior_samples),
                     len(batch_shape), chunk_size)
def _predictive(
    rng_key,
    model,
    posterior_samples,
    batch_shape,
    return_sites=None,
    infer_discrete=False,
    parallel=True,
    model_args=(),
    model_kwargs=None,
):
    """Run ``model`` once per posterior sample and collect the requested sites.

    :param rng_key: PRNG key; split into one key per batch element (and, when
        ``infer_discrete`` is set, an extra subkey for the prototype trace).
    :param model: Python callable containing NumPyro primitives.
    :param posterior_samples: dict of samples whose leading dims are ``batch_shape``.
    :param batch_shape: leading batch shape shared by ``posterior_samples``.
    :param return_sites: ``None`` returns unsubstituted sample sites plus
        deterministic sites; ``""`` returns every non-plate site; otherwise a
        collection of site names to return.
    :param bool infer_discrete: if True, values are drawn with
        ``numpyro.contrib.funsor.discrete._sample_posterior`` on the enumerated,
        conditioned model (temperature 1) instead of by tracing the masked model.
    :param bool parallel: if True, pass ``chunk_size=num_samples`` to
        ``soft_vmap``; otherwise ``chunk_size=1``.
    :param model_args: positional arguments for the model.
    :param model_kwargs: keyword arguments for the model (``None`` means none).
    :return: dict mapping site names to values batched over ``batch_shape``.
    """
    model_kwargs = {} if model_kwargs is None else model_kwargs
    masked_model = numpyro.handlers.mask(model, mask=False)

    if infer_discrete:
        # Trace the model once on the first posterior sample to learn its
        # structure (site types and plate nesting).
        rng_key, subkey = random.split(rng_key)
        batch_ndim = len(batch_shape)
        prototype_sample = tree_map(
            lambda x: jnp.reshape(x, (-1, ) + jnp.shape(x)[batch_ndim:])[0],
            posterior_samples,
        )
        prototype_trace = trace(
            seed(substitute(masked_model, prototype_sample), subkey)).get_trace(
                *model_args, **model_kwargs)
        # Presumably: enumeration dims are allocated to the left of all plate dims.
        first_available_dim = -_guess_max_plate_nesting(prototype_trace) - 1

    def single_prediction(val):
        rng_key, samples = val
        if infer_discrete:
            from numpyro.contrib.funsor import config_enumerate
            from numpyro.contrib.funsor.discrete import _sample_posterior

            # Site metadata comes from the prototype trace; values come from
            # the discrete posterior sampler.
            model_trace = prototype_trace
            temperature = 1
            pred_samples = _sample_posterior(
                config_enumerate(condition(model, samples)),
                first_available_dim,
                temperature,
                rng_key,
                *model_args,
                **model_kwargs,
            )
        else:
            model_trace = trace(
                seed(substitute(masked_model, samples), rng_key)).get_trace(
                    *model_args, **model_kwargs)
            pred_samples = {name: site["value"] for name, site in model_trace.items()}

        if return_sites is not None:
            if return_sites == "":
                sites = {
                    k for k, site in model_trace.items() if site["type"] != "plate"
                }
            else:
                sites = return_sites
        else:
            sites = {
                k
                for k, site in model_trace.items()
                if (site["type"] == "sample" and k not in samples)
                or (site["type"] == "deterministic")
            }
        return {name: value for name, value in pred_samples.items() if name in sites}

    num_samples = int(np.prod(batch_shape))
    key_shape = np.shape(rng_key)
    if num_samples > 1:
        rng_key = random.split(rng_key, num_samples)
    # Always give the keys the same leading batch shape as the posterior
    # samples: with num_samples == 1 but batch_shape == (1,), an unreshaped
    # key of shape key_shape would have its key data mistaken for a batch dim.
    rng_key = rng_key.reshape((*batch_shape, *key_shape))
    chunk_size = num_samples if parallel else 1
    return soft_vmap(single_prediction, (rng_key, posterior_samples),
                     len(batch_shape), chunk_size)