Example #1
# Imports assumed by this snippet; `gmm`, `gmm_guide`, and `generate_data` are
# project-local (a hypothetical sketch of them follows the example).
import jax
import tqdm
from jax.random import PRNGKey
from numpyro.contrib.funsor import config_enumerate, enum
from numpyro.handlers import seed, trace
from numpyro.infer import SVI, RenyiELBO
from numpyro.infer.util import _guess_max_plate_nesting
from numpyro.optim import Adam


def main(_args):
    data = generate_data()
    init_rng_key = PRNGKey(1273)
    # nuts = NUTS(gmm)
    # mcmc = MCMC(nuts, 100, 1000)
    # mcmc.print_summary()
    # trace the seeded model once so the plate nesting depth can be inferred
    seeded_gmm = seed(gmm, init_rng_key)
    model_trace = trace(seeded_gmm).get_trace(data)
    max_plate_nesting = _guess_max_plate_nesting(model_trace)
    # enumerate discrete latents to the left of all plate dimensions
    enum_gmm = enum(config_enumerate(gmm), -max_plate_nesting - 1)
    svi = SVI(enum_gmm, gmm_guide, Adam(0.1), RenyiELBO(-10.))
    svi_state = svi.init(init_rng_key, data)
    upd_fun = jax.jit(svi.update)
    with tqdm.trange(100_000) as pbar:
        for i in pbar:
            svi_state, loss = upd_fun(svi_state, data)
            pbar.set_description(f"SVI {loss}", True)
Example #2
    def run(self, rng_key, *args, **kwargs):
        """
        Run the nested samplers and collect weighted samples.

        :param random.PRNGKey rng_key: Random number generator key to be used for the sampling.
        :param args: The arguments needed by the `model`.
        :param kwargs: The keyword arguments needed by the `model`.
        """
        rng_sampling, rng_predictive = random.split(rng_key)
        # reparam the model so that latent sites have Uniform(0, 1) priors
        prototype_trace = trace(seed(self.model,
                                     rng_key)).get_trace(*args, **kwargs)
        param_names = [
            site["name"] for site in prototype_trace.values()
            if site["type"] == "sample" and not site["is_observed"]
            and site["infer"].get("enumerate", "") != "parallel"
        ]
        deterministics = [
            site["name"] for site in prototype_trace.values()
            if site["type"] == "deterministic"
        ]
        reparam_model = reparam(
            self.model, config={k: UniformReparam()
                                for k in param_names})

        # enable enumerate if needed
        has_enum = any(site["type"] == "sample"
                       and site["infer"].get("enumerate", "") == "parallel"
                       for site in prototype_trace.values())
        if has_enum:
            from numpyro.contrib.funsor import enum, log_density as log_density_

            max_plate_nesting = _guess_max_plate_nesting(prototype_trace)
            _validate_model(prototype_trace)
            reparam_model = enum(reparam_model, -max_plate_nesting - 1)
        else:
            log_density_ = log_density

        def loglik_fn(**params):
            return log_density_(reparam_model, args, kwargs, params)[0]

        # use NestedSampler with identity prior chain
        prior_chain = PriorChain()
        for name in param_names:
            prior = UniformPrior(name + "_base",
                                 prototype_trace[name]["fn"].shape())
            prior_chain.push(prior)
        # XXX: jaxns's `marginalised` keyword can be used to compute the expectation
        # of a quantity over posterior samples; it could be useful to expose it
        # in this wrapper
        ns = OrigNestedSampler(
            loglik_fn,
            prior_chain,
            sampler_name=self.sampler_name,
            sampler_kwargs={
                "depth": self.depth,
                "num_slices": self.num_slices
            },
            max_samples=self.max_samples,
            num_live_points=self.num_live_points,
            collect_samples=True,
        )
        # some parts of jaxns use float64 and raise warnings when the default dtype
        # is float32, so we suppress those warnings here to avoid confusion
        with warnings.catch_warnings():
            warnings.filterwarnings(
                "ignore", message=".*will be truncated to dtype float32.*")
            results = ns(rng_sampling, termination_frac=self.termination_frac)
        # transform base samples back to original domains
        # Here we only transform the first valid num_samples samples
        # NB: the number of weighted samples obtained from jaxns is results.num_samples
        # and only the first num_samples values of results.samples are valid.
        num_samples = results.num_samples
        samples = tree_util.tree_map(lambda x: x[:num_samples],
                                     results.samples)
        predictive = Predictive(reparam_model,
                                samples,
                                return_sites=param_names + deterministics)
        samples = predictive(rng_predictive, *args, **kwargs)
        # replace base samples in jaxns results by transformed samples
        self._results = results._replace(samples=samples)
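
For context, this `run` appears to be the core of numpyro's `NestedSampler` wrapper around jaxns. A minimal usage sketch with a toy model (the model and data are illustrative, and the constructor is assumed to have workable defaults for this scale):

import jax.numpy as jnp
from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro.contrib.nested_sampling import NestedSampler


def toy_model(y=None):
    mu = numpyro.sample("mu", dist.Normal(0.0, 10.0))
    numpyro.sample("y", dist.Normal(mu, 1.0), obs=y)


ns = NestedSampler(toy_model)
ns.run(random.PRNGKey(2), y=jnp.array([1.0, 1.2, 0.8]))
# equally-weighted posterior samples resampled from the weighted set
samples = ns.get_samples(random.PRNGKey(3), num_samples=1000)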
Example #3
def _sample_posterior(model, first_available_dim, temperature, rng_key, *args,
                      **kwargs):

    if temperature == 0:
        sum_op, prod_op = funsor.ops.max, funsor.ops.add
        approx = funsor.approximations.argmax_approximate
    elif temperature == 1:
        sum_op, prod_op = funsor.ops.logaddexp, funsor.ops.add
        rng_key, sub_key = random.split(rng_key)
        approx = funsor.montecarlo.MonteCarlo(rng_key=sub_key)
    else:
        raise ValueError("temperature must be 0 (map) or 1 (sample) for now")

    if first_available_dim is None:
        with block():
            model_trace = trace(seed(model,
                                     rng_key)).get_trace(*args, **kwargs)
        first_available_dim = -_guess_max_plate_nesting(model_trace) - 1

    with block(), enum(first_available_dim=first_available_dim):
        with plate_to_enum_plate():
            model_tr = packed_trace(model).get_trace(*args, **kwargs)

    terms = terms_from_trace(model_tr)
    # terms["log_factors"] = [log p(x) for each observed or latent sample site x]
    # terms["log_measures"] = [log p(z) or other Dice factor
    #                          for each latent sample site z]

    with funsor.interpretations.lazy:
        log_prob = funsor.sum_product.sum_product(
            sum_op,
            prod_op,
            list(terms["log_factors"].values()) +
            list(terms["log_measures"].values()),
            eliminate=terms["measure_vars"] | terms["plate_vars"],
            plates=terms["plate_vars"],
        )
        log_prob = funsor.optimizer.apply_optimizer(log_prob)

    with approx:
        approx_factors = funsor.adjoint.adjoint(sum_op, prod_op, log_prob)

    # construct a result trace to replay against the model
    sample_tr = model_tr.copy()
    sample_subs = {}
    for name, node in sample_tr.items():
        if node["type"] != "sample":
            continue
        if node["is_observed"]:
            # "observed" values may be collapsed samples that depend on enumerated
            # values, so we have to slice them down
            # TODO this should really be handled entirely under the hood by adjoint
            output = funsor.Reals[node["fn"].event_shape]
            value = funsor.to_funsor(node["value"],
                                     output,
                                     dim_to_name=node["infer"]["dim_to_name"])
            value = value(**sample_subs)
            node["value"] = funsor.to_data(
                value, name_to_dim=node["infer"]["name_to_dim"])
        else:
            log_measure = approx_factors[terms["log_measures"][name]]
            sample_subs[name] = _get_support_value(log_measure, name)
            node["value"] = funsor.to_data(
                sample_subs[name], name_to_dim=node["infer"]["name_to_dim"])

    with replay(guide_trace=sample_tr):
        return model(*args, **kwargs)
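
This `_sample_posterior` is the machinery behind the public `infer_discrete` handler in `numpyro.contrib.funsor`. A small sketch of how it is typically reached (toy model; since this variant replays the model, tracing the returned callable is one way to read off the sampled discrete values):

from jax import random
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.contrib.funsor import config_enumerate, infer_discrete
from numpyro.handlers import trace


@config_enumerate
def toy_model(probs, locs, obs):
    c = numpyro.sample("c", dist.Categorical(probs))
    numpyro.sample("x", dist.Normal(locs[c], 0.5), obs=obs)


map_model = infer_discrete(toy_model, temperature=0,
                           rng_key=random.PRNGKey(1))
tr = trace(map_model).get_trace(jnp.array([0.3, 0.7]),
                                jnp.array([-1.0, 1.0]), 0.8)
# MAP assignment of the enumerated site
print(tr["c"]["value"])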
Example #4
def discrete_gibbs_fn(model,
                      model_args=(),
                      model_kwargs={},
                      *,
                      random_walk=False,
                      modified=False):
    """
    [EXPERIMENTAL INTERFACE]

    Returns a gibbs_fn to be used in :class:`HMCGibbs`, which works for discrete latent sites
    with enumerate support. The site update order is randomly permuted at each step.

    Note that discrete latent sites not specified in the constructor of
    :class:`HMCGibbs` will be marginalized out by default (if they have enumerate
    support).

    :param callable model: a callable with NumPyro primitives. This should be the same model
        as the one used in the `inner_kernel` of :class:`HMCGibbs`.
    :param tuple model_args: Arguments provided to the model.
    :param dict model_kwargs: Keyword arguments provided to the model.
    :param bool random_walk: If False, Gibbs sampling will be used to draw a sample from the
        conditional `p(gibbs_site | remaining sites)`. Otherwise, a sample will be drawn uniformly
        from the domain of `gibbs_site`.
    :param bool modified: whether to use a modified proposal, as suggested in reference [1], which
        always proposes a new state for the current Gibbs site.
        The modified scheme appears in the literature under the name "modified Gibbs sampler" or
        "Metropolised Gibbs sampler".
    :return: a callable `gibbs_fn` to be used in :class:`HMCGibbs`

    **References:**

    1. *Peskun's theorem and a modified discrete-state Gibbs sampler*,
       Liu, J. S. (1996)

    **Example**

    .. doctest::

        >>> from jax import random
        >>> import jax.numpy as jnp
        >>> import numpyro
        >>> import numpyro.distributions as dist
        >>> from numpyro.infer import MCMC, NUTS, HMCGibbs, discrete_gibbs_fn
        ...
        >>> def model(probs, locs):
        ...     c = numpyro.sample("c", dist.Categorical(probs))
        ...     numpyro.sample("x", dist.Normal(locs[c], 0.5))
        ...
        >>> probs = jnp.array([0.15, 0.3, 0.3, 0.25])
        >>> locs = jnp.array([-2, 0, 2, 4])
        >>> gibbs_fn = discrete_gibbs_fn(model, (probs, locs))
        >>> kernel = HMCGibbs(NUTS(model), gibbs_fn, gibbs_sites=["c"])
        >>> mcmc = MCMC(kernel, num_warmup=1000, num_samples=100000, progress_bar=False)
        >>> mcmc.run(random.PRNGKey(0), probs, locs)
        >>> mcmc.print_summary()  # doctest: +SKIP

    """
    # NB: all of the information such as `model`, `model_args`, `model_kwargs`
    # can be accessed from HMCGibbs.sample, but we require them here to
    # simplify the API of `gibbs_fn`
    prototype_trace = trace(seed(model, rng_seed=0)).get_trace(
        *model_args, **model_kwargs)
    support_sizes = {
        name: jnp.broadcast_to(site["fn"].enumerate_support(False).shape[0],
                               jnp.shape(site["value"]))
        for name, site in prototype_trace.items() if site["type"] == "sample"
        and site["fn"].has_enumerate_support and not site["is_observed"]
    }
    max_plate_nesting = _guess_max_plate_nesting(prototype_trace)
    if random_walk:
        if modified:
            proposal_fn = partial(_discrete_modified_rw_proposal, stay_prob=0.)
        else:
            proposal_fn = _discrete_rw_proposal
    else:
        if modified:
            proposal_fn = partial(_discrete_modified_gibbs_proposal,
                                  stay_prob=0.)
        else:
            proposal_fn = _discrete_gibbs_proposal

    def gibbs_fn(rng_key, gibbs_sites, hmc_sites):
        # convert to unconstrained values
        z_hmc = {
            k: biject_to(prototype_trace[k]["fn"].support).inv(v)
            for k, v in hmc_sites.items()
            if k in prototype_trace and prototype_trace[k]["type"] == "sample"
        }
        use_enum = len(set(support_sizes) - set(gibbs_sites)) > 0
        wrapped_model = _wrap_model(model)
        if use_enum:
            from numpyro.contrib.funsor import config_enumerate, enum

            wrapped_model = enum(config_enumerate(wrapped_model),
                                 -max_plate_nesting - 1)

        def potential_fn(z_discrete):
            model_kwargs_ = model_kwargs.copy()
            model_kwargs_["_gibbs_sites"] = z_discrete
            return potential_energy(wrapped_model,
                                    model_args,
                                    model_kwargs_,
                                    z_hmc,
                                    enum=use_enum)

        # get support_sizes of gibbs_sites
        support_sizes_flat, _ = ravel_pytree(
            {k: support_sizes[k]
             for k in gibbs_sites})
        num_discretes = support_sizes_flat.shape[0]

        rng_key, rng_permute = random.split(rng_key)
        # use the dedicated key so `rng_key` is not reused inside the loop below
        idxs = random.permutation(rng_permute, jnp.arange(num_discretes))

        def body_fn(i, val):
            idx = idxs[i]
            support_size = support_sizes_flat[idx]
            rng_key, z, pe = val
            rng_key, z_new, pe_new, log_accept_ratio = proposal_fn(
                rng_key,
                z,
                pe,
                potential_fn=potential_fn,
                idx=idx,
                support_size=support_size)
            rng_key, rng_accept = random.split(rng_key)
            # u ~ Uniform(0, 1), u < accept_ratio => -log(u) > -log_accept_ratio
            # and -log(u) ~ exponential(1)
            z, pe = cond(
                random.exponential(rng_accept) > -log_accept_ratio,
                (z_new, pe_new), identity, (z, pe), identity)
            return rng_key, z, pe

        init_val = (rng_key, gibbs_sites, potential_fn(gibbs_sites))
        _, gibbs_sites, _ = fori_loop(0, num_discretes, body_fn, init_val)
        return gibbs_sites

    return gibbs_fn
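
The helper `_wrap_model` is not shown. A sketch consistent with its use here: it has to clamp the current Gibbs values passed through the `_gibbs_sites` keyword before the model runs. The handler choice is an assumption, not the verbatim source:

from numpyro.handlers import condition, substitute


def _wrap_model(model):
    def wrapped(*args, **kwargs):
        # pop the Gibbs values injected via model_kwargs_["_gibbs_sites"]
        gibbs_values = kwargs.pop("_gibbs_sites", {})
        # clamp both observed-style and latent sites to the current values
        with condition(data=gibbs_values), substitute(data=gibbs_values):
            return model(*args, **kwargs)

    return wrapped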
Example #5
def _discrete_gibbs_fn(wrapped_model,
                       model_args,
                       model_kwargs,
                       prototype_trace,
                       random_walk=False,
                       modified=False):
    support_sizes = {
        name: jnp.broadcast_to(site["fn"].enumerate_support(False).shape[0],
                               jnp.shape(site["value"]))
        for name, site in prototype_trace.items() if site["type"] == "sample"
        and site["fn"].has_enumerate_support and not site["is_observed"]
    }
    max_plate_nesting = _guess_max_plate_nesting(prototype_trace)
    if random_walk:
        if modified:
            proposal_fn = partial(_discrete_modified_rw_proposal, stay_prob=0.)
        else:
            proposal_fn = _discrete_rw_proposal
    else:
        if modified:
            proposal_fn = partial(_discrete_modified_gibbs_proposal,
                                  stay_prob=0.)
        else:
            proposal_fn = _discrete_gibbs_proposal

    def gibbs_fn(rng_key, gibbs_sites, hmc_sites):
        z_hmc = hmc_sites
        use_enum = len(set(support_sizes) - set(gibbs_sites)) > 0
        if use_enum:
            from numpyro.contrib.funsor import config_enumerate, enum

            wrapped_model_ = enum(config_enumerate(wrapped_model),
                                  -max_plate_nesting - 1)
        else:
            wrapped_model_ = wrapped_model

        def potential_fn(z_discrete):
            model_kwargs_ = model_kwargs.copy()
            model_kwargs_["_gibbs_sites"] = z_discrete
            return potential_energy(wrapped_model_,
                                    model_args,
                                    model_kwargs_,
                                    z_hmc,
                                    enum=use_enum)

        # get support_sizes of gibbs_sites
        support_sizes_flat, _ = ravel_pytree(
            {k: support_sizes[k]
             for k in gibbs_sites})
        num_discretes = support_sizes_flat.shape[0]

        rng_key, rng_permute = random.split(rng_key)
        # use the dedicated key so `rng_key` is not reused inside the loop below
        idxs = random.permutation(rng_permute, jnp.arange(num_discretes))

        def body_fn(i, val):
            idx = idxs[i]
            support_size = support_sizes_flat[idx]
            rng_key, z, pe = val
            rng_key, z_new, pe_new, log_accept_ratio = proposal_fn(
                rng_key,
                z,
                pe,
                potential_fn=potential_fn,
                idx=idx,
                support_size=support_size)
            rng_key, rng_accept = random.split(rng_key)
            # u ~ Uniform(0, 1), u < accept_ratio => -log(u) > -log_accept_ratio
            # and -log(u) ~ exponential(1)
            z, pe = cond(
                random.exponential(rng_accept) > -log_accept_ratio,
                (z_new, pe_new), identity, (z, pe), identity)
            return rng_key, z, pe

        init_val = (rng_key, gibbs_sites, potential_fn(gibbs_sites))
        _, gibbs_sites, _ = fori_loop(0, num_discretes, body_fn, init_val)
        return gibbs_sites

    return gibbs_fn
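
A short usage sketch, under the assumption that this private helper is what backs numpyro's `DiscreteHMCGibbs` kernel (the model is borrowed from the docstring in Example #4):

from jax import random
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, DiscreteHMCGibbs


def model(probs, locs):
    c = numpyro.sample("c", dist.Categorical(probs))
    numpyro.sample("x", dist.Normal(locs[c], 0.5))


probs = jnp.array([0.15, 0.3, 0.3, 0.25])
locs = jnp.array([-2.0, 0.0, 2.0, 4.0])
kernel = DiscreteHMCGibbs(NUTS(model), modified=True)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=1000, progress_bar=False)
mcmc.run(random.PRNGKey(0), probs, locs)
mcmc.print_summary()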
Example #6
    def init(self, rng_key, *args, **kwargs):
        """
        :param jax.random.PRNGKey rng_key: random number generator key.
        :param args: arguments to the model / guide (these can possibly vary during
            the course of fitting).
        :param kwargs: keyword arguments to the model / guide (these can possibly vary
            during the course of fitting).
        :return: initial :data:`SteinVIState`
        """
        rng_key, kernel_seed, model_seed, guide_seed = jax.random.split(
            rng_key, 4)
        model_init = handlers.seed(self.model, model_seed)
        guide_init = handlers.seed(self.guide, guide_seed)
        guide_trace = handlers.trace(guide_init).get_trace(
            *args, **kwargs, **self.static_kwargs)
        model_trace = handlers.trace(model_init).get_trace(
            *args, **kwargs, **self.static_kwargs)
        rng_key, particle_seed = jax.random.split(rng_key)
        guide_init_params = self._find_init_params(particle_seed, self.guide,
                                                   guide_trace)
        params = {}
        transforms = {}
        inv_transforms = {}
        particle_transforms = {}
        guide_param_names = set()
        should_enum = False
        for site in model_trace.values():
            if ("fn" in site and site["type"] == "sample"
                    and not site["is_observed"]
                    and isinstance(site["fn"], Distribution)
                    and site["fn"].is_discrete):
                if site["fn"].has_enumerate_support and self.enum:
                    should_enum = True
                else:
                    raise Exception(
                        "Cannot enumerate a model whose discrete variables lack "
                        "enumerate support (or when enumeration is disabled)")
        # NB: params in model_trace will be overwritten by params in guide_trace
        for site in chain(model_trace.values(), guide_trace.values()):
            if site["type"] == "param":
                transform = get_parameter_transform(site)
                inv_transforms[site["name"]] = transform
                transforms[site["name"]] = transform.inv
                particle_transforms[site["name"]] = site.get(
                    "particle_transform", IdentityTransform())
                if site["name"] in guide_init_params:
                    pval, _ = guide_init_params[site["name"]]
                    if self.classic_guide_params_fn(site["name"]):
                        pval = tree_map(lambda x: x[0], pval)
                else:
                    pval = site["value"]
                params[site["name"]] = transform.inv(pval)
                if site["name"] in guide_trace:
                    guide_param_names.add(site["name"])

        if should_enum:
            mpn = _guess_max_plate_nesting(model_trace)
            self._inference_model = enum(config_enumerate(self.model),
                                         -mpn - 1)
        self.guide_param_names = guide_param_names
        self.constrain_fn = partial(transform_fn, inv_transforms)
        self.uconstrain_fn = partial(transform_fn, transforms)
        self.particle_transforms = particle_transforms
        self.particle_transform_fn = partial(transform_fn, particle_transforms)
        stein_particles, _, _ = batch_ravel_pytree(
            {
                k: params[k]
                for k, site in guide_trace.items() if site["type"] == "param"
                and site["name"] in guide_init_params
            },
            nbatch_dims=1,
        )

        self.kernel_fn.init(kernel_seed, stein_particles.shape)
        return SteinVIState(self.optim.init(params), rng_key)
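
Every example on this page leans on `_guess_max_plate_nesting`. A minimal sketch of what it computes, consistent with the private helper in `numpyro.infer.util` (plate dimensions are negative, counted from the right):

def _guess_max_plate_nesting(model_trace):
    """Guess the deepest plate nesting from a model trace."""
    dims = [
        frame.dim
        for site in model_trace.values()
        if site["type"] == "sample"
        for frame in site["cond_indep_stack"]
        if frame.dim is not None
    ]
    # e.g. dims == [-1, -2] means two nested plates, so the answer is 2
    return -min(dims) if dims else 0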
Example #7
def _sample_posterior(model, first_available_dim, temperature, rng_key, *args,
                      **kwargs):

    if temperature == 0:
        sum_op, prod_op = funsor.ops.max, funsor.ops.add
        approx = funsor.approximations.argmax_approximate
    elif temperature == 1:
        sum_op, prod_op = funsor.ops.logaddexp, funsor.ops.add
        rng_key, sub_key = random.split(rng_key)
        approx = funsor.montecarlo.MonteCarlo(rng_key=sub_key)
    else:
        raise ValueError("temperature must be 0 (map) or 1 (sample) for now")

    if first_available_dim is None:
        with block():
            model_trace = trace(seed(model,
                                     rng_key)).get_trace(*args, **kwargs)
        first_available_dim = -_guess_max_plate_nesting(model_trace) - 1

    with funsor.adjoint.AdjointTape() as tape:
        with block(), enum(first_available_dim=first_available_dim):
            log_prob, model_tr, log_measures = _enum_log_density(
                model, args, kwargs, {}, sum_op, prod_op)

    with approx:
        approx_factors = tape.adjoint(sum_op, prod_op, log_prob)

    # construct a result trace to replay against the model
    sample_tr = model_tr.copy()
    sample_subs = {}
    for name, node in sample_tr.items():
        if node["type"] != "sample":
            continue
        if node["is_observed"]:
            # "observed" values may be collapsed samples that depend on enumerated
            # values, so we have to slice them down
            # TODO this should really be handled entirely under the hood by adjoint
            output = funsor.Reals[node["fn"].event_shape]
            value = funsor.to_funsor(node["value"],
                                     output,
                                     dim_to_name=node["infer"]["dim_to_name"])
            value = value(**sample_subs)
            node["value"] = funsor.to_data(
                value, name_to_dim=node["infer"]["name_to_dim"])
        else:
            log_measure = approx_factors[log_measures[name]]
            sample_subs[name] = _get_support_value(log_measure, name)
            node["value"] = funsor.to_data(
                sample_subs[name], name_to_dim=node["infer"]["name_to_dim"])

    data = {
        name: site["value"]
        for name, site in sample_tr.items() if site["type"] == "sample"
    }

    # concatenate the markov sites _PREV_foo back onto foo
    time_vars = defaultdict(list)
    for name in data:
        if name.startswith("_PREV_"):
            root_name = _shift_name(name, -_get_shift(name))
            time_vars[root_name].append(name)
    for name in time_vars:
        if name in data:
            time_vars[name].append(name)
        time_vars[name] = sorted(time_vars[name], key=len, reverse=True)

    for root_name, names in time_vars.items():
        # NB: use `model_tr` here; `model_trace` is only defined when
        # `first_available_dim` is None
        prototype_shape = model_tr[root_name]["value"].shape
        values = [data.pop(name) for name in names]
        if len(values) == 1:
            data[root_name] = values[0].reshape(prototype_shape)
        else:
            assert len(prototype_shape) >= 1
            values = [v.reshape((-1,) + prototype_shape[1:]) for v in values]
            data[root_name] = jnp.concatenate(values)

    return data
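
The `_PREV_` bookkeeping above comes from numpyro's `markov` support for time series: sample sites from earlier steps carry `_PREV_` prefixes. A toy sketch of the two name helpers as they are used here (illustrative, not the verbatim source):

def _get_shift(name):
    """Count the leading ``_PREV_`` prefixes on a site name."""
    shift = 0
    while name.startswith("_PREV_"):
        shift += 1
        name = name[len("_PREV_"):]
    return shift


def _shift_name(name, shift):
    """Add (shift > 0) or strip (shift < 0) ``_PREV_`` prefixes."""
    if shift >= 0:
        return "_PREV_" * shift + name
    return name[-shift * len("_PREV_"):]


assert _get_shift("_PREV__PREV_x") == 2
assert _shift_name("_PREV__PREV_x", -2) == "x"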