Beispiel #1
0
def test_ravel_pytree(pytree):
    """Check that ``ravel_pytree`` followed by its unravel function is a round trip.

    Both the leaf values and the (canonicalized) leaf dtypes of the restored pytree
    must match the input ``pytree``.
    """
    flat, unravel_fn = ravel_pytree(pytree)
    restored = unravel_fn(flat)

    # Values: assert_allclose raises on the first mismatching leaf.
    tree_multimap(assert_allclose, restored, pytree)

    # Dtypes: compare leaf by leaf after canonicalization.
    dtype_matches = tree_multimap(
        lambda got, expected: canonicalize_dtype(lax.dtype(got)) == canonicalize_dtype(lax.dtype(expected)),
        restored,
        pytree,
    )
    match_leaves, _ = tree_flatten(dtype_matches)
    assert all(match_leaves)
Beispiel #2
0
    def gibbs_fn(rng_key, gibbs_sites, hmc_sites, pe):
        """Update the discrete (Gibbs) sites with one sweep of proposals in random order.

        Every entry of the raveled ``gibbs_sites`` is visited exactly once. The order comes
        from a random permutation. At each entry ``proposal_fn`` proposes a new value, which
        is accepted or rejected with a Metropolis-Hastings test.

        :param rng_key: JAX PRNG key.
        :param dict gibbs_sites: current values of the discrete sites.
        :param hmc_sites: current HMC sites, bound to ``potential_fn`` as ``z_hmc``.
        :param pe: potential energy at the current state.
        :return: tuple ``(gibbs_sites, pe)`` after the sweep.
        """
        # get support_sizes of gibbs_sites
        support_sizes_flat, _ = ravel_pytree({k: support_sizes[k] for k in gibbs_sites})
        num_discretes = support_sizes_flat.shape[0]

        # Draw the visiting order from its dedicated subkey. Reusing ``rng_key`` here
        # would correlate the order with the randomness consumed by the proposals below.
        rng_key, rng_permute = random.split(rng_key)
        idxs = random.permutation(rng_permute, jnp.arange(num_discretes))

        def body_fn(i, val):
            idx = idxs[i]
            support_size = support_sizes_flat[idx]
            rng_key, z, pe = val
            rng_key, z_new, pe_new, log_accept_ratio = proposal_fn(
                rng_key, z, pe, potential_fn=partial(potential_fn, z_hmc=hmc_sites),
                idx=idx, support_size=support_size)
            rng_key, rng_accept = random.split(rng_key)
            # u ~ Uniform(0, 1), u < accept_ratio => -log(u) > -log_accept_ratio
            # and -log(u) ~ exponential(1)
            z, pe = cond(random.exponential(rng_accept) > -log_accept_ratio,
                         (z_new, pe_new), identity,
                         (z, pe), identity)
            return rng_key, z, pe

        init_val = (rng_key, gibbs_sites, pe)
        _, gibbs_sites, pe = fori_loop(0, num_discretes, body_fn, init_val)
        return gibbs_sites, pe
Beispiel #3
0
        def proxy_fn(params, subsample_lik_sites, gibbs_state):
            """Evaluate the second-order Taylor proxy of the log likelihood at ``params``.

            The expansion point is ``ref_params_flat``. For every name in
            ``subsample_lik_sites`` two proxies are built:
            ``value + grad . diff + 0.5 * diff^T hess diff``. One uses the subsample
            statistics stored in ``gibbs_state``. The other uses the full-data sums
            ``ref_log_likelihood*_sum``.

            :return: tuple ``(proxy_sum, proxy_subsample)`` of ``defaultdict(float)``
                keyed by site name.
            """
            flat_params, _ = ravel_pytree(params)
            delta = flat_params - ref_params_flat

            def _quadratic(value, grad, hess):
                # value + grad . delta + 0.5 * delta^T hess delta
                return (value
                        + jnp.dot(grad, delta)
                        + 0.5 * jnp.dot(jnp.dot(hess, delta), delta))

            sub_values = gibbs_state.ref_subsample_log_liks
            sub_grads = gibbs_state.ref_subsample_log_lik_grads
            sub_hessians = gibbs_state.ref_subsample_log_lik_hessians

            proxy_sum = defaultdict(float)
            proxy_subsample = defaultdict(float)
            for site in subsample_lik_sites:
                proxy_subsample[site] = _quadratic(
                    sub_values[site], sub_grads[site], sub_hessians[site])
                proxy_sum[site] = _quadratic(
                    ref_log_likelihoods_sum[site],
                    ref_log_likelihood_grads_sum[site],
                    ref_log_likelihood_hessians_sum[site])
            return proxy_sum, proxy_subsample
Beispiel #4
0
def _discrete_gibbs_proposal(rng_key, z_discrete, pe, potential_fn, idx,
                             support_size):
    """Gibbs proposal for one entry of the raveled discrete state.

    :param rng_key: JAX PRNG key.
    :param z_discrete: pytree of current discrete values.
    :param pe: potential energy at ``z_discrete``.
    :param potential_fn: callable mapping a ``z_discrete``-shaped pytree to a potential energy.
    :param idx: index into ``ravel_pytree(z_discrete)`` of the entry to update.
    :param support_size: support size of the entry at ``idx``.
    :return: tuple ``(rng_key, z_new, pe_new, log_accept_ratio)``. ``log_accept_ratio``
        is always 0, i.e. the proposal carries no Metropolis-Hastings correction.
    """
    flat_z, unravel = ravel_pytree(z_discrete)

    # ``support_size`` may be a traced value that differs across discrete variables, so
    # the proposals cannot be vmapped over a fixed-size support. Instead we loop over
    # the candidates and let ``_discrete_gibbs_proposal_body_fn`` sample online from
    # the conditional categorical distribution.
    step = partial(_discrete_gibbs_proposal_body_fn, flat_z, unravel, pe,
                   potential_fn, idx)
    num_candidates = support_size - 1
    # Last carry slot presumably holds a running log-weight sum (see the body fn).
    carry = (rng_key, z_discrete, pe, jnp.array(0.0))
    rng_key, z_new, pe_new, _ = fori_loop(0, num_candidates, step, carry)
    return rng_key, z_new, pe_new, jnp.array(0.0)
Beispiel #5
0
def _discrete_modified_gibbs_proposal(rng_key,
                                      z_discrete,
                                      pe,
                                      potential_fn,
                                      idx,
                                      support_size,
                                      stay_prob=0.0):
    """Modified Gibbs proposal for one entry of the raveled discrete state.

    Like ``_discrete_gibbs_proposal``, but the current value is given zero weight
    (its log-weight starts at ``-inf``). With probability ``stay_prob`` the current
    state is kept instead. Because the proposal is no longer an exact Gibbs step, a
    Metropolis-Hastings correction is returned.

    :param rng_key: JAX PRNG key.
    :param z_discrete: pytree of current discrete values.
    :param pe: potential energy at ``z_discrete``.
    :param potential_fn: callable mapping a ``z_discrete``-shaped pytree to a potential energy.
    :param idx: index into ``ravel_pytree(z_discrete)`` of the entry to update.
    :param support_size: support size of the entry at ``idx``.
    :param float stay_prob: probability of keeping the current state; must be a
        Python float in ``[0, 1)``.
    :return: tuple ``(rng_key, z_new, pe_new, log_accept_ratio)``.
    """
    # Validate stay_prob: must be a Python float in [0, 1).
    assert isinstance(stay_prob, float) and stay_prob >= 0.0 and stay_prob < 1
    z_discrete_flat, unravel_fn = ravel_pytree(z_discrete)
    body_fn = partial(
        _discrete_gibbs_proposal_body_fn,
        z_discrete_flat,
        unravel_fn,
        pe,
        potential_fn,
        idx,
    )
    # like gibbs_step but here, weight of the current value is 0
    init_val = (rng_key, z_discrete, pe, jnp.array(-jnp.inf))
    # Runs support_size - 1 iterations of the body fn; the returned last carry slot is
    # the accumulated log-weight sum used in the MH correction below.
    rng_key, z_new, pe_new, log_weight_sum = fori_loop(0, support_size - 1,
                                                       body_fn, init_val)
    rng_key, rng_stay = random.split(rng_key)
    # With probability stay_prob, discard the proposal and keep (z_discrete, pe).
    z_new, pe_new = cond(
        random.bernoulli(rng_stay, stay_prob),
        (z_discrete, pe),
        identity,
        (z_new, pe_new),
        identity,
    )
    # here we calculate the MH correction: (1 - P(z)) / (1 - P(z_new))
    # where 1 - P(z) ~ weight_sum
    # and 1 - P(z_new) ~ 1 + weight_sum - z_new_weight
    # Since expm1(x) = exp(x) - 1, the denominator below equals
    # 1 + exp(log_weight_sum) - exp(pe - pe_new); exp(pe - pe_new) is presumably z_new's
    # weight relative to the current state. If the stay branch was taken, pe_new == pe,
    # expm1(0) == 0, and the ratio is exactly 0 (always accept).
    log_accept_ratio = log_weight_sum - jnp.log(
        jnp.exp(log_weight_sum) - jnp.expm1(pe - pe_new))
    return rng_key, z_new, pe_new, log_accept_ratio
Beispiel #6
0
    def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs):
        """Initialize the mixed HMC state.

        Steps:

        - delegate to the parent ``init``;
        - cache the raveled support sizes of the Gibbs sites;
        - build the warmup-adaptation update function;
        - draw an initial momentum so it can be carried across HMC sub-trajectories.

        :param rng_key: JAX PRNG key; split into a key for the parent ``init`` and a
            key for the initial momentum.
        :param int num_warmup: number of warmup steps; also stored on ``self``.
        :param init_params: initial parameters, forwarded to the parent ``init``.
        :param model_args: forwarded to the parent ``init``.
        :param model_kwargs: forwarded to the parent ``init``.
        :return: a ``MixedHMCState`` built from the parent state. Its ``hmc_state``
            carries the freshly drawn momentum ``r``, and its last field is
            ``jnp.zeros(())`` (meaning defined by ``MixedHMCState``).
        """
        rng_key, rng_r = random.split(rng_key)
        state = super().init(rng_key, num_warmup, init_params, model_args,
                             model_kwargs)
        self._support_sizes_flat, _ = ravel_pytree(
            {k: self._support_sizes[k]
             for k in self._gibbs_sites})
        # Default: one discrete update per raveled Gibbs entry.
        if self._num_discrete_updates is None:
            self._num_discrete_updates = self._support_sizes_flat.shape[0]
        self._num_warmup = num_warmup

        # NB: the warmup adaptation can not be performed in sub-trajectories (i.e. the hmc trajectory
        # between two discrete updates), so we will do it here, at the end of each MixedHMC step.
        # Only the update function is kept; the adapter's init function is discarded.
        _, self._wa_update = warmup_adapter(
            num_warmup,
            adapt_step_size=self.inner_kernel._adapt_step_size,
            adapt_mass_matrix=self.inner_kernel._adapt_mass_matrix,
            dense_mass=self.inner_kernel._dense_mass,
            target_accept_prob=self.inner_kernel._target_accept_prob,
            find_reasonable_step_size=None)

        # In HMC, when `hmc_state.r` is not None, we will skip drawing a random momentum at the
        # beginning of an HMC step. The reason is we need to maintain `r` between each sub-trajectories.
        r = momentum_generator(state.hmc_state.z,
                               state.hmc_state.adapt_state.mass_matrix_sqrt,
                               rng_r)
        return MixedHMCState(state.z, state.hmc_state._replace(r=r),
                             state.rng_key, jnp.zeros(()))
Beispiel #7
0
def _discrete_rw_proposal(rng_key, z_discrete, pe, potential_fn, idx, support_size):
    """Random-walk proposal for one entry of the raveled discrete state.

    Draws a value uniformly from ``[0, support_size)`` for the entry at ``idx`` of
    ``ravel_pytree(z_discrete)``. The draw may coincide with the current value.

    :param rng_key: JAX PRNG key.
    :param z_discrete: pytree of current discrete values.
    :param pe: potential energy at ``z_discrete``.
    :param potential_fn: callable mapping a ``z_discrete``-shaped pytree to a potential energy.
    :param idx: index into the raveled ``z_discrete`` of the entry to update.
    :param support_size: support size of the entry at ``idx``.
    :return: tuple ``(rng_key, z_new, pe_new, log_accept_ratio)`` with
        ``log_accept_ratio = pe - pe_new``. No proposal correction is needed because
        the uniform draw does not depend on the current value.
    """
    rng_key, rng_proposal = random.split(rng_key, 2)
    z_discrete_flat, unravel_fn = ravel_pytree(z_discrete)

    proposal = random.randint(rng_proposal, (), minval=0, maxval=support_size)
    # Functional update via ``.at[].set``; ``jax.ops.index_update`` is deprecated and
    # removed in newer JAX releases.
    z_new_flat = z_discrete_flat.at[idx].set(proposal)
    z_new = unravel_fn(z_new_flat)
    pe_new = potential_fn(z_new)
    log_accept_ratio = pe - pe_new
    return rng_key, z_new, pe_new, log_accept_ratio
Beispiel #8
0
    def gibbs_fn(rng_key, gibbs_sites, hmc_sites):
        """Update the discrete (Gibbs) sites with one sweep of proposals in random order.

        :param rng_key: JAX PRNG key.
        :param dict gibbs_sites: current values of the discrete sites.
        :param dict hmc_sites: current values of the HMC sites in constrained space. They
            are mapped to unconstrained space before the potential energy is evaluated.
        :return: updated ``gibbs_sites``.
        """
        # convert to unconstrained values
        z_hmc = {
            k: biject_to(prototype_trace[k]["fn"].support).inv(v)
            for k, v in hmc_sites.items()
            if k in prototype_trace and prototype_trace[k]["type"] == "sample"
        }
        # Discrete sites that have a support size but are not Gibbs-updated here are
        # enumerated out instead.
        use_enum = len(set(support_sizes) - set(gibbs_sites)) > 0
        wrapped_model = _wrap_model(model)
        if use_enum:
            from numpyro.contrib.funsor import config_enumerate, enum

            wrapped_model = enum(config_enumerate(wrapped_model),
                                 -max_plate_nesting - 1)

        def potential_fn(z_discrete):
            model_kwargs_ = model_kwargs.copy()
            model_kwargs_["_gibbs_sites"] = z_discrete
            return potential_energy(wrapped_model,
                                    model_args,
                                    model_kwargs_,
                                    z_hmc,
                                    enum=use_enum)

        # get support_sizes of gibbs_sites
        support_sizes_flat, _ = ravel_pytree(
            {k: support_sizes[k]
             for k in gibbs_sites})
        num_discretes = support_sizes_flat.shape[0]

        # Draw the visiting order from its dedicated subkey. Reusing ``rng_key`` here
        # would correlate the order with the randomness consumed by the proposals below.
        rng_key, rng_permute = random.split(rng_key)
        idxs = random.permutation(rng_permute, jnp.arange(num_discretes))

        def body_fn(i, val):
            idx = idxs[i]
            support_size = support_sizes_flat[idx]
            rng_key, z, pe = val
            rng_key, z_new, pe_new, log_accept_ratio = proposal_fn(
                rng_key,
                z,
                pe,
                potential_fn=potential_fn,
                idx=idx,
                support_size=support_size)
            rng_key, rng_accept = random.split(rng_key)
            # u ~ Uniform(0, 1), u < accept_ratio => -log(u) > -log_accept_ratio
            # and -log(u) ~ exponential(1)
            z, pe = cond(
                random.exponential(rng_accept) > -log_accept_ratio,
                (z_new, pe_new), identity, (z, pe), identity)
            return rng_key, z, pe

        init_val = (rng_key, gibbs_sites, potential_fn(gibbs_sites))
        _, gibbs_sites, _ = fori_loop(0, num_discretes, body_fn, init_val)
        return gibbs_sites
Beispiel #9
0
    def gibbs_fn(rng_key, gibbs_sites, hmc_sites):
        """Update the discrete (Gibbs) sites with one sweep of proposals in random order.

        :param rng_key: JAX PRNG key.
        :param dict gibbs_sites: current values of the discrete sites.
        :param dict hmc_sites: current values of the HMC sites, passed as-is as ``z_hmc``
            to ``potential_energy``.
        :return: updated ``gibbs_sites``.
        """
        z_hmc = hmc_sites
        # Discrete sites that have a support size but are not Gibbs-updated here are
        # enumerated out instead.
        use_enum = len(set(support_sizes) - set(gibbs_sites)) > 0
        if use_enum:
            from numpyro.contrib.funsor import config_enumerate, enum

            wrapped_model_ = enum(config_enumerate(wrapped_model),
                                  -max_plate_nesting - 1)
        else:
            wrapped_model_ = wrapped_model

        def potential_fn(z_discrete):
            model_kwargs_ = model_kwargs.copy()
            model_kwargs_["_gibbs_sites"] = z_discrete
            return potential_energy(wrapped_model_,
                                    model_args,
                                    model_kwargs_,
                                    z_hmc,
                                    enum=use_enum)

        # get support_sizes of gibbs_sites
        support_sizes_flat, _ = ravel_pytree(
            {k: support_sizes[k]
             for k in gibbs_sites})
        num_discretes = support_sizes_flat.shape[0]

        # Draw the visiting order from its dedicated subkey. Reusing ``rng_key`` here
        # would correlate the order with the randomness consumed by the proposals below.
        rng_key, rng_permute = random.split(rng_key)
        idxs = random.permutation(rng_permute, jnp.arange(num_discretes))

        def body_fn(i, val):
            idx = idxs[i]
            support_size = support_sizes_flat[idx]
            rng_key, z, pe = val
            rng_key, z_new, pe_new, log_accept_ratio = proposal_fn(
                rng_key,
                z,
                pe,
                potential_fn=potential_fn,
                idx=idx,
                support_size=support_size)
            rng_key, rng_accept = random.split(rng_key)
            # u ~ Uniform(0, 1), u < accept_ratio => -log(u) > -log_accept_ratio
            # and -log(u) ~ exponential(1)
            z, pe = cond(
                random.exponential(rng_accept) > -log_accept_ratio,
                (z_new, pe_new), identity, (z, pe), identity)
            return rng_key, z, pe

        init_val = (rng_key, gibbs_sites, potential_fn(gibbs_sites))
        _, gibbs_sites, _ = fori_loop(0, num_discretes, body_fn, init_val)
        return gibbs_sites
Beispiel #10
0
def _discrete_modified_rw_proposal(rng_key, z_discrete, pe, potential_fn, idx, support_size,
                                   stay_prob=0.):
    """Modified random-walk proposal for one entry of the raveled discrete state.

    The new value for entry ``idx`` is drawn uniformly from the ``support_size - 1``
    values that differ from the current one. With probability ``stay_prob`` the
    current value is kept instead.

    :param rng_key: JAX PRNG key.
    :param z_discrete: pytree of current discrete values.
    :param pe: potential energy at ``z_discrete``.
    :param potential_fn: callable mapping a ``z_discrete``-shaped pytree to a potential energy.
    :param idx: index into ``ravel_pytree(z_discrete)`` of the entry to update.
    :param support_size: support size of the entry at ``idx``.
    :param float stay_prob: probability of keeping the current value; must be a Python
        float in ``[0, 1)``.
    :return: tuple ``(rng_key, z_new, pe_new, log_accept_ratio)`` with
        ``log_accept_ratio = pe - pe_new``.
    """
    assert isinstance(stay_prob, float) and stay_prob >= 0. and stay_prob < 1
    rng_key, rng_proposal, rng_stay = random.split(rng_key, 3)
    z_discrete_flat, unravel_fn = ravel_pytree(z_discrete)
    current = z_discrete_flat[idx]

    # Draw i in [0, support_size - 1) and skip over the current value, so the proposal is
    # uniform over the other support_size - 1 values.
    i = random.randint(rng_proposal, (), minval=0, maxval=support_size - 1)
    proposal = jnp.where(i >= current, i + 1, i)
    # "Stay" must keep the current *value*. Using ``idx`` (the position) here would
    # write the index itself into the state.
    proposal = jnp.where(random.bernoulli(rng_stay, stay_prob), current, proposal)
    # ``.at[].set`` replaces the deprecated/removed ``jax.ops.index_update``.
    z_new_flat = z_discrete_flat.at[idx].set(proposal)
    z_new = unravel_fn(z_new_flat)
    pe_new = potential_fn(z_new)
    log_accept_ratio = pe - pe_new
    return rng_key, z_new, pe_new, log_accept_ratio
Beispiel #11
0
 def single_particle_grad(particle, att_forces, rep_forces):
     """Pull the summed attractive + repulsive forces back through the inverse transforms.

     For each site, the Jacobian of ``self.particle_transforms[site].inv`` is taken at
     the particle's values. The flattened force sum is multiplied by that Jacobian, and
     the per-site results are raveled back into a single flat vector.
     """
     def _inv_jacobians(site, values):
         # Per-leaf Jacobian of the inverse transform for this site.
         return jax.tree_map(
             lambda leaf: jax.jacfwd(self.particle_transforms[site].inv)(leaf),
             values,
         )

     reparam_jac = {}
     for site, values in unravel_pytree(particle).items():
         reparam_jac[site] = _inv_jacobians(site, values)

     def _pull_back(att, rep, jac):
         # jacfwd returns shape out_shape + in_shape. This assumes the transform keeps
         # the shape, so the first half of the axes is the output. Collapse each half
         # into one matrix axis.
         num_out = _numel(jac.shape[:len(jac.shape) // 2])
         jac_matrix = jac.reshape((num_out, -1))
         total_force = att.reshape(-1) + rep.reshape(-1)
         return (total_force @ jac_matrix).reshape(rep.shape)

     jac_params = jax.tree_multimap(
         _pull_back,
         unravel_pytree(att_forces),
         unravel_pytree(rep_forces),
         reparam_jac,
     )
     flat_grad, _ = ravel_pytree(jac_params)
     return flat_grad
Beispiel #12
0
        def log_likelihood(params, subsample_indices=None):
            """Per-plate log likelihood of the observed sites under ``params``.

            :param params: parameter pytree, substituted into the model through
                ``_unconstrain_reparam``.
            :param subsample_indices: dict mapping subsample plate names to index arrays.
                Defaults to the full range ``arange(size)`` for every plate in
                ``subsample_plate_sizes``.
            :return: ``defaultdict(float)`` mapping each subsample plate name to the
                observed log-prob summed over all dims except that plate's dim.
            """
            flat, unravel = ravel_pytree(params)
            if subsample_indices is None:
                subsample_indices = {
                    plate_name: jnp.arange(sizes[0])
                    for plate_name, sizes in subsample_plate_sizes.items()
                }
            params = unravel(flat)

            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                reparam_fn = partial(_unconstrain_reparam, params)
                with block(), trace() as tr, substitute(data=subsample_indices), \
                        substitute(substitute_fn=reparam_fn):
                    model(*model_args, **model_kwargs)

            log_lik = defaultdict(float)
            for site in tr.values():
                if site["type"] != "sample" or not site["is_observed"]:
                    continue
                for frame in site["cond_indep_stack"]:
                    if frame.name not in subsample_plate_sizes:
                        continue
                    log_lik[frame.name] += _sum_all_except_at_dim(
                        site["fn"].log_prob(site["value"]), frame.dim)
            return log_lik
Beispiel #13
0
    def construct_proxy_fn(prototype_trace, subsample_plate_sizes, model, model_args, model_kwargs, num_blocks=1):
        """Build a second-order Taylor-expansion proxy for the model's log likelihood.

        The expansion point is ``reference_params`` (from the enclosing scope), mapped to
        unconstrained space and raveled. Full-data sums of the log likelihood, plus their
        gradients and Hessians at that point, are computed once here. The per-subsample
        statistics are held in a ``TaylorProxyState``, created by ``gibbs_init`` and
        refreshed block-wise by ``gibbs_update``.

        :param prototype_trace: trace whose ``[name]["fn"].support`` defines the
            constrained <-> unconstrained bijection for each parameter site.
        :param subsample_plate_sizes: dict mapping subsample plate names to
            ``(size, subsample_size)`` pairs.
        :param model: model callable, run under ``block``/``trace``/``substitute``.
        :param model_args: positional arguments for ``model``.
        :param model_kwargs: keyword arguments for ``model``.
        :param int num_blocks: forwarded to ``_block_update_proxy``.
        :return: tuple ``(proxy_fn, gibbs_init, gibbs_update)``.
        """
        # Reference parameters in unconstrained space; the proxy is expanded around them.
        ref_params = {name: biject_to(prototype_trace[name]["fn"].support).inv(value)
                      for name, value in reference_params.items()}

        ref_params_flat, unravel_fn = ravel_pytree(ref_params)

        def log_likelihood(params_flat, subsample_indices=None):
            # Per-plate log likelihood at unconstrained, raveled params. Defaults to the
            # full index range of every subsample plate.
            if subsample_indices is None:
                subsample_indices = {k: jnp.arange(v[0]) for k, v in subsample_plate_sizes.items()}
            params = unravel_fn(params_flat)
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                # Map back to constrained space before substituting into the model.
                params = {name: biject_to(prototype_trace[name]["fn"].support)(value) for name, value in params.items()}
                with block(), trace() as tr, substitute(data=subsample_indices), substitute(data=params):
                    model(*model_args, **model_kwargs)

            # Sum each observed site's log-prob over all dims except the plate dim,
            # accumulated per plate name.
            # NOTE(review): this covers every frame in cond_indep_stack, not only the
            # subsample plates.
            log_lik = {}
            for site in tr.values():
                if site["type"] == "sample" and site["is_observed"]:
                    for frame in site["cond_indep_stack"]:
                        if frame.name in log_lik:
                            log_lik[frame.name] += _sum_all_except_at_dim(
                                site["fn"].log_prob(site["value"]), frame.dim)
                        else:
                            log_lik[frame.name] = _sum_all_except_at_dim(
                                site["fn"].log_prob(site["value"]), frame.dim)
            return log_lik

        def log_likelihood_sum(params_flat, subsample_indices=None):
            # Scalar total per plate name.
            return {k: v.sum() for k, v in log_likelihood(params_flat, subsample_indices).items()}

        # those stats are dict keyed by subsample names
        ref_log_likelihoods_sum = log_likelihood_sum(ref_params_flat)
        ref_log_likelihood_grads_sum = jacobian(log_likelihood_sum)(ref_params_flat)
        ref_log_likelihood_hessians_sum = hessian(log_likelihood_sum)(ref_params_flat)

        def gibbs_init(rng_key, gibbs_sites):
            # gibbs_sites are the subsample indices (passed as subsample_indices);
            # Hessians use forward-over-forward differentiation.
            ref_subsample_log_liks = log_likelihood(ref_params_flat, gibbs_sites)
            ref_subsample_log_lik_grads = jacfwd(log_likelihood)(ref_params_flat, gibbs_sites)
            ref_subsample_log_lik_hessians = jacfwd(jacfwd(log_likelihood))(ref_params_flat, gibbs_sites)
            return TaylorProxyState(ref_subsample_log_liks, ref_subsample_log_lik_grads, ref_subsample_log_lik_hessians)

        def gibbs_update(rng_key, gibbs_sites, gibbs_state):
            # Propose new subsample indices for one block (see _block_update_proxy), then
            # recompute the reference statistics only for the new block's indices.
            u_new, pads, new_idxs, starts = _block_update_proxy(num_blocks, rng_key, gibbs_sites, subsample_plate_sizes)

            new_states = defaultdict(dict)
            ref_subsample_log_liks = log_likelihood(ref_params_flat, new_idxs)
            ref_subsample_log_lik_grads = jacfwd(log_likelihood)(ref_params_flat, new_idxs)
            ref_subsample_log_lik_hessians = jacfwd(jacfwd(log_likelihood))(ref_params_flat, new_idxs)
            for stat, new_block_values, last_values in zip(
                    ["log_liks", "grads", "hessians"],
                    [ref_subsample_log_liks,
                     ref_subsample_log_lik_grads,
                     ref_subsample_log_lik_hessians],
                    [gibbs_state.ref_subsample_log_liks,
                     gibbs_state.ref_subsample_log_lik_grads,
                     gibbs_state.ref_subsample_log_lik_hessians]):
                # NOTE(review): `subsample_idx` and `size` are unused below.
                for name, subsample_idx in gibbs_sites.items():
                    size, subsample_size = subsample_plate_sizes[name]
                    pad, start = pads[name], starts[name]
                    # Splice the new block into the previous statistics: pad axis 0 by
                    # `pad`, overwrite starting at `start`, then truncate back to
                    # `subsample_size` rows.
                    new_value = jnp.pad(last_values[name], [(0, pad)] + [(0, 0)] * (jnp.ndim(last_values[name]) - 1))
                    new_value = lax.dynamic_update_slice_in_dim(
                        new_value, new_block_values[name], start, 0)
                    new_states[stat][name] = new_value[:subsample_size]
            gibbs_state = TaylorProxyState(new_states["log_liks"], new_states["grads"], new_states["hessians"])
            return u_new, gibbs_state

        def proxy_fn(params, subsample_lik_sites, gibbs_state):
            # Quadratic Taylor proxy around ref_params_flat:
            #   value + grad . diff + 0.5 * diff^T hessian diff
            # built from both the subsample statistics and the full-data sums.
            params_flat, _ = ravel_pytree(params)
            params_diff = params_flat - ref_params_flat

            ref_subsample_log_liks = gibbs_state.ref_subsample_log_liks
            ref_subsample_log_lik_grads = gibbs_state.ref_subsample_log_lik_grads
            ref_subsample_log_lik_hessians = gibbs_state.ref_subsample_log_lik_hessians

            proxy_sum = defaultdict(float)
            proxy_subsample = defaultdict(float)
            for name in subsample_lik_sites:
                proxy_subsample[name] = (ref_subsample_log_liks[name] +
                                         jnp.dot(ref_subsample_log_lik_grads[name], params_diff) +
                                         0.5 * jnp.dot(jnp.dot(ref_subsample_log_lik_hessians[name], params_diff),
                                                       params_diff))

                proxy_sum[name] = (ref_log_likelihoods_sum[name] +
                                   jnp.dot(ref_log_likelihood_grads_sum[name], params_diff) +
                                   0.5 * jnp.dot(jnp.dot(ref_log_likelihood_hessians_sum[name], params_diff),
                                                 params_diff))
            return proxy_sum, proxy_subsample

        return proxy_fn, gibbs_init, gibbs_update
Beispiel #14
0
        def particle_transform_fn(particle):
            """Apply ``self.particle_transform_fn`` to a raveled particle and re-ravel the result.

            The particle is unraveled with ``unravel_pytree``, transformed, and flattened
            again with ``ravel_pytree``. The unravel function from that call is discarded.
            """
            transformed_params = self.particle_transform_fn(unravel_pytree(particle))
            flat_transformed, _ = ravel_pytree(transformed_params)
            return flat_transformed