示例#1
0
def logistic_random_effects(positions, annotations):
    """
    This model corresponds to the plate diagram in Figure 5 of reference [1].

    :param positions: integer array of annotator ids used to index the annotator
        dimension of ``beta`` alongside the columns of ``annotations`` (presumably
        shape ``(num_positions,)`` -- confirm against callers); ids are assumed
        to be ``0..max(positions)``.
    :param annotations: integer label array of shape ``(num_items, num_positions)``;
        classes are assumed to be ``0..max(annotations)``.
    """
    num_annotators = int(np.max(positions)) + 1
    num_classes = int(np.max(annotations)) + 1
    num_items, num_positions = annotations.shape

    with numpyro.plate("class", num_classes):
        # Logits are invariant up to an additive constant, so only the first
        # `num_classes - 1` terms get hyperpriors; the last term is padded with 0.
        zeta = numpyro.sample("zeta", dist.Normal(0, 1).expand([num_classes - 1]).to_event(1))
        omega = numpyro.sample("Omega", dist.HalfNormal(1).expand([num_classes - 1]).to_event(1))
        chi = numpyro.sample("Chi", dist.HalfNormal(1).expand([num_classes - 1]).to_event(1))

    with numpyro.plate("annotator", num_annotators, dim=-2):
        with numpyro.plate("class", num_classes):
            # non-centered parameterization of the annotator effects
            with handlers.reparam(config={"beta": LocScaleReparam(0)}):
                beta = numpyro.sample("beta", dist.Normal(zeta, omega).to_event(1))
            # pad 0 to the last item
            beta = jnp.pad(beta, [(0, 0)] * (jnp.ndim(beta) - 1) + [(0, 1)])

    pi = numpyro.sample("pi", dist.Dirichlet(jnp.ones(num_classes)))

    with numpyro.plate("item", num_items, dim=-2):
        # Mark the discrete true class for parallel enumeration explicitly (as
        # `item_difficulty` does) instead of relying on the inference algorithm
        # to configure enumeration for unmarked discrete sites.
        c = numpyro.sample("c", dist.Categorical(pi), infer={"enumerate": "parallel"})

        # non-centered parameterization of the per-item random effect
        with handlers.reparam(config={"theta": LocScaleReparam(0)}):
            theta = numpyro.sample("theta", dist.Normal(0, chi[c]).to_event(1))
        # pad 0 to the last item
        theta = jnp.pad(theta, [(0, 0)] * (jnp.ndim(theta) - 1) + [(0, 1)])

        with numpyro.plate("position", num_positions):
            logits = Vindex(beta)[positions, c, :] - theta
            numpyro.sample("y", dist.Categorical(logits=logits), obs=annotations)
示例#2
0
def item_difficulty(annotations):
    """
    This model corresponds to the plate diagram in Figure 5 of reference [1].

    Each item draws a latent true class ``c`` (marked for parallel enumeration)
    and a per-item logit vector ``theta`` whose location and scale are the
    class-specific ``eta`` / ``Chi``; every label in the item's row of
    ``annotations`` is observed as a Categorical draw with those logits.

    :param annotations: integer label array of shape ``(num_items, num_positions)``;
        classes are assumed to be ``0..max(annotations)``.
    """
    n_classes = int(np.max(annotations)) + 1
    n_items, n_positions = annotations.shape
    # Only the first `n_classes - 1` logits are free; the last one is padded with 0.
    n_free = n_classes - 1

    with numpyro.plate("class", n_classes):
        class_loc = numpyro.sample(
            "eta", dist.Normal(0, 1).expand([n_free]).to_event(1))
        class_scale = numpyro.sample(
            "Chi", dist.HalfNormal(1).expand([n_free]).to_event(1))

    class_probs = numpyro.sample("pi", dist.Dirichlet(jnp.ones(n_classes)))

    with numpyro.plate("item", n_items, dim=-2):
        true_class = numpyro.sample(
            "c", dist.Categorical(class_probs), infer={"enumerate": "parallel"})

        # non-centered parameterization of the per-item logits
        with handlers.reparam(config={"theta": LocScaleReparam(0)}):
            item_logits = numpyro.sample(
                "theta",
                dist.Normal(class_loc[true_class], class_scale[true_class]).to_event(1))
        pad_width = [(0, 0)] * (jnp.ndim(item_logits) - 1) + [(0, 1)]
        item_logits = jnp.pad(item_logits, pad_width)

        with numpyro.plate("position", n_positions):
            numpyro.sample("y", dist.Categorical(logits=item_logits), obs=annotations)
示例#3
0
def test_model_with_transformed_distribution():
    """Check ``TransformReparam`` against a hand-written change of variables.

    ``y`` has a LogNormal (transformed-distribution) prior. After reparameterizing
    it with ``TransformReparam``, the model is driven through a ``y_base`` site.
    Constraining and scoring that model must reproduce what applying the
    ``biject_to`` transforms to the original model by hand gives.
    """
    x_prior = dist.HalfNormal(2)
    y_prior = dist.LogNormal(scale=3.)  # transformed distribution

    def model():
        numpyro.sample('x', x_prior)
        numpyro.sample('y', y_prior)

    unconstrained = {'x': jnp.array(-5.), 'y': jnp.array(7.)}
    seeded_model = handlers.seed(model, random.PRNGKey(0))

    # Expected values: push the unconstrained params through biject_to by hand.
    to_support = {name: biject_to(prior.support)
                  for name, prior in (('x', x_prior), ('y', y_prior))}
    expected_samples = transform_fn(to_support, unconstrained)

    x_log_prob = x_prior.log_prob(expected_samples['x'])
    y_log_prob = y_prior.log_prob(expected_samples['y'])
    x_log_jac = to_support['x'].log_abs_det_jacobian(unconstrained['x'],
                                                     expected_samples['x'])
    y_log_jac = to_support['y'].log_abs_det_jacobian(unconstrained['y'],
                                                     expected_samples['y'])
    # keep the original left-to-right subtraction order (float32 rounding)
    expected_potential_energy = -x_log_prob - y_log_prob - x_log_jac - y_log_jac

    # Actual values: go through TransformReparam and the y_base site.
    reparam_model = handlers.reparam(seeded_model, {'y': TransformReparam()})
    base_params = {'x': unconstrained['x'], 'y_base': unconstrained['y']}
    actual_samples = constrain_fn(
        handlers.seed(reparam_model, random.PRNGKey(0)),
        (), {}, base_params,
        return_deterministic=True,
    )
    actual_potential_energy = potential_energy(reparam_model, (), {}, base_params)

    for site in ('x', 'y'):
        assert_allclose(expected_samples[site], actual_samples[site])
    assert_allclose(actual_potential_energy, expected_potential_energy)
示例#4
0
 def model(data):
     """Toy model applying ``TransformReparam`` to a masked base distribution.

     ``alpha ~ Uniform(0, 1)``. ``loc`` is the ``AffineTransform(0, alpha)`` image
     of a ``Uniform(0, 1)`` base that is wrapped in ``.mask(False)``. ``data`` is
     observed as ``Normal(loc, 0.1)``. The ``TransformReparam`` handler rewrites
     the ``loc`` site in terms of its base distribution.
     """
     alpha = numpyro.sample('alpha', dist.Uniform(0, 1))
     masked_base = dist.Uniform(0, 1).mask(False)
     loc_prior = dist.TransformedDistribution(masked_base, AffineTransform(0, alpha))
     with handlers.reparam(config={'loc': TransformReparam()}):
         loc = numpyro.sample('loc', loc_prior)
     numpyro.sample('obs', dist.Normal(loc, 0.1), obs=data)
示例#5
0
def hierarchical_dawid_skene(positions, annotations):
    """
    This model corresponds to the plate diagram in Figure 4 of reference [1].

    :param positions: integer array of annotator ids used to index the annotator
        dimension of ``beta`` alongside the columns of ``annotations`` (presumably
        shape ``(num_positions,)`` -- confirm against callers); ids are assumed
        to be ``0..max(positions)``.
    :param annotations: integer label array of shape ``(num_items, num_positions)``;
        classes are assumed to be ``0..max(annotations)``.
    """
    num_annotators = int(np.max(positions)) + 1
    num_classes = int(np.max(annotations)) + 1
    num_items, num_positions = annotations.shape

    with numpyro.plate("class", num_classes):
        # NB: we define `beta` as the `logits` of `y` likelihood; but `logits` is
        # invariant up to a constant, so we'll follow [1]: fix the last term of `beta`
        # to 0 and only define hyperpriors for the first `num_classes - 1` terms.
        zeta = numpyro.sample("zeta", dist.Normal(0, 1).expand([num_classes - 1]).to_event(1))
        omega = numpyro.sample("Omega", dist.HalfNormal(1).expand([num_classes - 1]).to_event(1))

    with numpyro.plate("annotator", num_annotators, dim=-2):
        with numpyro.plate("class", num_classes):
            # non-centered parameterization
            with handlers.reparam(config={"beta": LocScaleReparam(0)}):
                beta = numpyro.sample("beta", dist.Normal(zeta, omega).to_event(1))
            # pad 0 to the last item
            beta = jnp.pad(beta, [(0, 0)] * (jnp.ndim(beta) - 1) + [(0, 1)])

    pi = numpyro.sample("pi", dist.Dirichlet(jnp.ones(num_classes)))

    with numpyro.plate("item", num_items, dim=-2):
        # Mark the discrete true class for parallel enumeration explicitly (as
        # `item_difficulty` does) instead of relying on the inference algorithm
        # to configure enumeration for unmarked discrete sites.
        c = numpyro.sample("c", dist.Categorical(pi), infer={"enumerate": "parallel"})

        with numpyro.plate("position", num_positions):
            logits = Vindex(beta)[positions, c, :]
            numpyro.sample("y", dist.Categorical(logits=logits), obs=annotations)
示例#6
0
 def model(data):
     """Observe ``data`` around a ``loc`` built by scaling a masked Uniform by ``alpha``.

     The ``loc`` site is a ``TransformedDistribution`` of ``Uniform(0, 1).mask(False)``
     under ``AffineTransform(0, alpha)``. It is sampled under a ``TransformReparam``
     handler, and ``data`` is observed as ``Normal(loc, 0.1)``.
     """
     alpha = numpyro.sample("alpha", dist.Uniform(0, 1))
     scale_by_alpha = AffineTransform(0, alpha)
     with handlers.reparam(config={"loc": TransformReparam()}):
         loc = numpyro.sample(
             "loc",
             dist.TransformedDistribution(dist.Uniform(0, 1).mask(False), scale_by_alpha),
         )
     numpyro.sample("obs", dist.Normal(loc, 0.1), obs=data)
示例#7
0
    def transition_fn(carry, y):
        """One time step of the capture-recapture transition, written as a scan body.

        NOTE(review): presumably driven by a ``scan`` (e.g. from
        ``numpyro.contrib.control_flow``); ``phi_logit_mean``, ``phi_sigma``, ``N``
        and ``rho`` are read from the enclosing scope, which is not shown here.

        :param carry: tuple ``(first_capture_mask, z)`` -- the per-animal flag
            "detected at least once before this step" and the previous discrete
            state ``z`` (presumably alive/dead -- confirm).
        :param y: observations for this step; any truthy entry sets that animal's
            ``first_capture_mask`` for subsequent steps.
        :return: ``((first_capture_mask, z), None)`` -- the updated carry and no
            per-step output.
        """
        first_capture_mask, z = carry
        # per-step logit, sampled with LocScaleReparam(0); phi_t is its sigmoid
        with handlers.reparam(config={"phi_logit": LocScaleReparam(0)}):
            phi_logit_t = numpyro.sample("phi_logit", dist.Normal(phi_logit_mean, phi_sigma))
        phi_t = expit(phi_logit_t)
        with numpyro.plate("animals", N, dim=-1):
            # sites of animals not yet detected are masked out of the log density
            with handlers.mask(mask=first_capture_mask):
                # detected-before animals: P(z_t = 1) = phi_t * z_{t-1}; others: 1
                mu_z_t = first_capture_mask * phi_t * z + (1 - first_capture_mask)
                # NumPyro exactly sums out the discrete states z_t.
                z = numpyro.sample("z", dist.Bernoulli(dist.util.clamp_probs(mu_z_t)))
                # detection probability is rho when z_t = 1 and 0 otherwise
                mu_y_t = rho * z
                numpyro.sample("y", dist.Bernoulli(dist.util.clamp_probs(mu_y_t)), obs=y)

        # updated after the sample sites, so this step's detections only
        # unmask the following steps
        first_capture_mask = first_capture_mask | y.astype(bool)
        return (first_capture_mask, z), None
示例#8
0
from jax import random
import jax.numpy as jnp

import numpyro
import numpyro.distributions as dist
from numpyro.handlers import reparam
from numpyro.infer import MCMC, NUTS, Predictive
from numpyro.infer.reparam import LocScaleReparam


def model(dim=10):
    """Neal's funnel: ``y ~ Normal(0, 3)`` and ``dim - 1`` coordinates ``x``.

    Each coordinate of ``x`` is drawn from ``Normal(0, exp(y / 2))``. The scale
    of ``x`` therefore depends exponentially on ``y``.
    """
    y = numpyro.sample("y", dist.Normal(0, 3))
    x_loc = jnp.zeros(dim - 1)
    x_scale = jnp.exp(y / 2)
    numpyro.sample("x", dist.Normal(x_loc, x_scale))


# `model` with its "x" site rewritten through LocScaleReparam(0).
reparam_model = reparam(
    model,
    config=dict(x=LocScaleReparam(0)),
)


def run_inference(model, args, rng_key):
    """Run NUTS on ``model`` and return its posterior samples.

    As a side effect, an MCMC summary is printed, with deterministic sites
    included.

    :param model: NumPyro model; it is run with no model arguments.
    :param args: namespace providing ``num_warmup``, ``num_samples`` and
        ``num_chains``.
    :param rng_key: JAX PRNG key passed to ``MCMC.run``.
    :return: the dict returned by ``mcmc.get_samples()``.
    """
    # `os` is not imported by this example's module-level imports.
    import os

    kernel = NUTS(model)
    mcmc = MCMC(
        kernel,
        num_warmup=args.num_warmup,
        num_samples=args.num_samples,
        num_chains=args.num_chains,
        # progress bar is hidden when NUMPYRO_SPHINXBUILD is set
        # (presumably during the Sphinx docs build)
        progress_bar="NUMPYRO_SPHINXBUILD" not in os.environ,
    )
    mcmc.run(rng_key)
    mcmc.print_summary(exclude_deterministic=False)
    return mcmc.get_samples()
示例#9
0
    def run(self, rng_key, *args, **kwargs):
        """
        Run the nested samplers and collect weighted samples.

        Latent sample sites of ``self.model`` are collected from a prototype trace;
        observed sites and sites marked ``infer={"enumerate": "parallel"}`` are
        excluded. They are reparameterized with ``UniformReparam`` and exposed to
        the jaxns sampler as ``UniformPrior`` entries named ``<site>_base``. After
        sampling, the valid base samples are mapped back to the original latent
        and deterministic sites via ``Predictive``. They are stored together with
        the rest of the jaxns results in ``self._results``. Nothing is returned.

        :param random.PRNGKey rng_key: Random number generator key to be used for the sampling.
        :param args: The arguments needed by the `model`.
        :param kwargs: The keyword arguments needed by the `model`.
        """
        # one key for the jaxns run, one for mapping samples back via Predictive
        rng_sampling, rng_predictive = random.split(rng_key)
        # reparam the model so that latent sites have Uniform(0, 1) priors
        # NOTE(review): `rng_key` is reused here after being split above; only site
        # metadata (names, types, infer flags, distribution shapes) is read from
        # this trace.
        prototype_trace = trace(seed(self.model,
                                     rng_key)).get_trace(*args, **kwargs)
        param_names = [
            site["name"] for site in prototype_trace.values()
            if site["type"] == "sample" and not site["is_observed"]
            and site["infer"].get("enumerate", "") != "parallel"
        ]
        deterministics = [
            site["name"] for site in prototype_trace.values()
            if site["type"] == "deterministic"
        ]
        reparam_model = reparam(
            self.model, config={k: UniformReparam()
                                for k in param_names})

        # enable enumerate if needed
        has_enum = any(site["type"] == "sample"
                       and site["infer"].get("enumerate", "") == "parallel"
                       for site in prototype_trace.values())
        if has_enum:
            # the funsor-based log_density is needed to sum out enumerated sites
            from numpyro.contrib.funsor import enum, log_density as log_density_

            max_plate_nesting = _guess_max_plate_nesting(prototype_trace)
            _validate_model(prototype_trace)
            # enumeration dims are allocated to the left of all plate dims
            reparam_model = enum(reparam_model, -max_plate_nesting - 1)
        else:
            log_density_ = log_density

        def loglik_fn(**params):
            # params are keyed by the `<site>_base` names of the uniform priors;
            # [0] selects the log joint density from log_density's return tuple
            return log_density_(reparam_model, args, kwargs, params)[0]

        # use NestedSampler with identity prior chain
        prior_chain = PriorChain()
        for name in param_names:
            prior = UniformPrior(name + "_base",
                                 prototype_trace[name]["fn"].shape())
            prior_chain.push(prior)
        # XXX: the `marginalised` keyword in jaxns can be used to get expectation of some
        # quantity over posterior samples; it can be helpful to expose it in this wrapper
        ns = OrigNestedSampler(
            loglik_fn,
            prior_chain,
            sampler_name=self.sampler_name,
            sampler_kwargs={
                "depth": self.depth,
                "num_slices": self.num_slices
            },
            max_samples=self.max_samples,
            num_live_points=self.num_live_points,
            collect_samples=True,
        )
        # some places of jaxns uses float64 and raises some warnings if the default dtype is
        # float32, so we suppress them here to avoid confusion
        with warnings.catch_warnings():
            warnings.filterwarnings(
                "ignore", message=".*will be truncated to dtype float32.*")
            results = ns(rng_sampling, termination_frac=self.termination_frac)
        # transform base samples back to original domains
        # Here we only transform the first valid num_samples samples
        # NB: the number of weighted samples obtained from jaxns is results.num_samples
        # and only the first num_samples values of results.samples are valid.
        num_samples = results.num_samples
        samples = tree_util.tree_map(lambda x: x[:num_samples],
                                     results.samples)
        predictive = Predictive(reparam_model,
                                samples,
                                return_sites=param_names + deterministics)
        samples = predictive(rng_predictive, *args, **kwargs)
        # replace base samples in jaxns results by transformed samples
        self._results = results._replace(samples=samples)
示例#10
0
import pathlib
from typing import Callable, Dict

import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import numpyro
import numpyro.distributions as dist
from jax import random
from numpyro import handlers, infer
from numpyro.infer import reparam


def model(dim: int = 10) -> None:
    """Neal's funnel with ``y ~ Normal(0, 3)`` and ``dim - 1`` coordinates ``x``.

    Each coordinate of ``x`` is drawn from ``Normal(0, exp(y / 2))``.
    """
    y = numpyro.sample("y", dist.Normal(0, 3))
    funnel_scale = jnp.exp(y / 2)
    numpyro.sample("x", dist.Normal(jnp.zeros(dim - 1), funnel_scale))


# `model` with its "x" site rewritten through LocScaleReparam(0).
reparam_model = handlers.reparam(
    model,
    config=dict(x=reparam.LocScaleReparam(0)),
)


def run_inference(model: Callable,
                  rng_key: np.ndarray,
                  *,
                  num_warmup: int = 1000,
                  num_samples: int = 1000,
                  num_chains: int = 1) -> Dict[str, jnp.ndarray]:
    """Run NUTS on ``model`` and return its posterior samples.

    :param model: NumPyro model; it is run with no model arguments.
    :param rng_key: JAX PRNG key passed to ``MCMC.run``.
    :param num_warmup: number of warmup iterations (default 1000).
    :param num_samples: number of post-warmup samples (default 1000).
    :param num_chains: number of MCMC chains (default 1).
    :return: the dict returned by ``mcmc.get_samples()``.
    """
    kernel = infer.NUTS(model)
    # Pass these by keyword: current NumPyro releases declare them keyword-only
    # in `MCMC`, so the old positional call raises a TypeError there.
    mcmc = infer.MCMC(kernel,
                      num_warmup=num_warmup,
                      num_samples=num_samples,
                      num_chains=num_chains)
    mcmc.run(rng_key)
    return mcmc.get_samples()


def _plot_results(samples: Dict[str, jnp.ndarray],
                  reparam_samples: Dict[str, jnp.ndarray]) -> None:
    """Presumably plots ``samples`` against ``reparam_samples`` (inferred from the name).

    NOTE(review): the visible body looks truncated. It only builds the path
    ``./data/neals`` into ``root`` and uses neither ``root`` nor its arguments.
    Confirm against the full source.
    """

    root = pathlib.Path("./data/neals")
示例#11
0
def reparam_model(model):
    """Return ``model`` wrapped so that its ``'x'`` site uses ``LocScaleReparam(0)``."""
    site_config = {'x': LocScaleReparam(0)}
    return reparam(model, config=site_config)
 def _auto_reparam(self) -> numpyro.handlers.reparam:
     """Build a ``reparam`` handler that applies ``CircularReparam`` to circular sites.

     Each name returned by ``self.get_circ_var_names()`` is mapped to its own fresh
     ``CircularReparam()`` instance. The handler is created without a model
     function.
     """
     reparam_config = {}
     for var_name in self.get_circ_var_names():
         reparam_config[var_name] = CircularReparam()
     return hdl.reparam(config=reparam_config)