def test_scan_enum_one_latent(num_steps):
    data = random.normal(random.PRNGKey(0), (num_steps,))
    init_probs = jnp.array([0.6, 0.4])
    transition_probs = jnp.array([[0.8, 0.2], [0.1, 0.9]])
    locs = jnp.array([-1.0, 1.0])

    def model(data):
        x = None
        for i, y in markov(enumerate(data)):
            probs = init_probs if x is None else transition_probs[x]
            x = numpyro.sample(f"x_{i}", dist.Categorical(probs))
            numpyro.sample(f"y_{i}", dist.Normal(locs[x], 1), obs=y)
        return x

    def fun_model(data):
        def transition_fn(x, y):
            probs = init_probs if x is None else transition_probs[x]
            x = numpyro.sample("x", dist.Categorical(probs))
            numpyro.sample("y", dist.Normal(locs[x], 1), obs=y)
            return x, 1

        x, collections = scan(transition_fn, None, data)
        assert collections.shape == data.shape[:1]
        return x

    expected_log_joint = log_density(enum(config_enumerate(model)), (data,), {}, {})[0]
    actual_log_joint = log_density(enum(config_enumerate(fun_model)), (data,), {}, {})[0]
    assert_allclose(actual_log_joint, expected_log_joint)

    actual_last_x = enum(config_enumerate(fun_model))(data)
    expected_last_x = enum(config_enumerate(model))(data)
    assert_allclose(actual_last_x, expected_last_x)
def test_scan_enum_plate():
    N, D = 10, 3
    data = random.normal(random.PRNGKey(0), (N, D))
    init_probs = jnp.array([0.6, 0.4])
    transition_probs = jnp.array([[0.8, 0.2], [0.1, 0.9]])
    locs = jnp.array([-1.0, 1.0])

    def model(data):
        x = None
        D_plate = numpyro.plate("D", D, dim=-1)
        for i, y in markov(enumerate(data)):
            with D_plate:
                probs = init_probs if x is None else transition_probs[x]
                x = numpyro.sample(f"x_{i}", dist.Categorical(probs))
                numpyro.sample(f"y_{i}", dist.Normal(locs[x], 1), obs=y)

    def fun_model(data):
        def transition_fn(x, y):
            probs = init_probs if x is None else transition_probs[x]
            with numpyro.plate("D", D, dim=-1):
                x = numpyro.sample("x", dist.Categorical(probs))
                numpyro.sample("y", dist.Normal(locs[x], 1), obs=y)
            return x, None

        scan(transition_fn, None, data)

    expected_log_joint = log_density(enum(config_enumerate(model), -2), (data,), {}, {})[0]
    actual_log_joint = log_density(enum(config_enumerate(fun_model), -2), (data,), {}, {})[0]
    assert_allclose(actual_log_joint, expected_log_joint)
def test_scan_enum_two_latents():
    num_steps = 11
    data = random.normal(random.PRNGKey(0), (num_steps,))
    probs_x = jnp.array([[0.8, 0.2], [0.1, 0.9]])
    probs_w = jnp.array([[0.7, 0.3], [0.6, 0.4]])
    locs = jnp.array([[-1.0, 1.0], [2.0, 3.0]])

    def model(data):
        x = w = 0
        for i, y in markov(enumerate(data)):
            x = numpyro.sample(f"x_{i}", dist.Categorical(probs_x[x]))
            w = numpyro.sample(f"w_{i}", dist.Categorical(probs_w[w]))
            numpyro.sample(f"y_{i}", dist.Normal(locs[w, x], 1), obs=y)

    def fun_model(data):
        def transition_fn(carry, y):
            x, w = carry
            x = numpyro.sample("x", dist.Categorical(probs_x[x]))
            w = numpyro.sample("w", dist.Categorical(probs_w[w]))
            numpyro.sample("y", dist.Normal(locs[w, x], 1), obs=y)
            # also test that scan's `ys` are recorded correctly
            return (x, w), x

        scan(transition_fn, (0, 0), data)

    actual_log_joint = log_density(enum(config_enumerate(fun_model)), (data,), {}, {})[0]
    expected_log_joint = log_density(enum(config_enumerate(model)), (data,), {}, {})[0]
    assert_allclose(actual_log_joint, expected_log_joint)
def test_scan_enum_scan_enum():
    num_steps = 11
    data_x = random.normal(random.PRNGKey(0), (num_steps,))
    data_w = data_x[:-1] + 1
    probs_x = jnp.array([[0.8, 0.2], [0.1, 0.9]])
    probs_w = jnp.array([[0.7, 0.3], [0.6, 0.4]])
    locs_x = jnp.array([-1.0, 1.0])
    locs_w = jnp.array([2.0, 3.0])

    def model(data_x, data_w):
        x = w = 0
        for i, y in markov(enumerate(data_x)):
            x = numpyro.sample(f"x_{i}", dist.Categorical(probs_x[x]))
            numpyro.sample(f"y_x_{i}", dist.Normal(locs_x[x], 1), obs=y)

        for i, y in markov(enumerate(data_w)):
            w = numpyro.sample(f"w{i}", dist.Categorical(probs_w[w]))
            numpyro.sample(f"y_w_{i}", dist.Normal(locs_w[w], 1), obs=y)

    def fun_model(data_x, data_w):
        def transition_fn(name, probs, locs, x, y):
            x = numpyro.sample(name, dist.Categorical(probs[x]))
            numpyro.sample("y_" + name, dist.Normal(locs[x], 1), obs=y)
            return x, None

        scan(partial(transition_fn, "x", probs_x, locs_x), 0, data_x)
        scan(partial(transition_fn, "w", probs_w, locs_w), 0, data_w)

    actual_log_joint = log_density(enum(config_enumerate(fun_model)), (data_x, data_w), {}, {})[0]
    expected_log_joint = log_density(enum(config_enumerate(model)), (data_x, data_w), {}, {})[0]
    assert_allclose(actual_log_joint, expected_log_joint)
def test_scan_enum_discrete_outside():
    data = random.normal(random.PRNGKey(0), (10,))
    probs = jnp.array([[[0.8, 0.2], [0.1, 0.9]], [[0.7, 0.3], [0.6, 0.4]]])
    locs = jnp.array([-1.0, 1.0])

    def model(data):
        w = numpyro.sample("w", dist.Bernoulli(0.6))
        x = 0
        for i, y in markov(enumerate(data)):
            x = numpyro.sample(f"x_{i}", dist.Categorical(probs[w, x]))
            numpyro.sample(f"y_{i}", dist.Normal(locs[x], 1), obs=y)

    def fun_model(data):
        w = numpyro.sample("w", dist.Bernoulli(0.6))

        def transition_fn(x, y):
            x = numpyro.sample("x", dist.Categorical(probs[w, x]))
            numpyro.sample("y", dist.Normal(locs[x], 1), obs=y)
            return x, None

        scan(transition_fn, 0, data)

    actual_log_joint = log_density(enum(config_enumerate(fun_model)), (data,), {}, {})[0]
    expected_log_joint = log_density(enum(config_enumerate(model)), (data,), {}, {})[0]
    assert_allclose(actual_log_joint, expected_log_joint)
def test_scan_enum_separated_plate_discrete():
    N, D = 10, 3
    data = random.normal(random.PRNGKey(0), (N, D))
    transition_probs = jnp.array([[0.8, 0.2], [0.1, 0.9]])
    locs = jnp.array([[-1.0, 1.0], [2.0, 3.0]])

    def model(data):
        x = 0
        D_plate = numpyro.plate("D", D, dim=-1)
        for i, y in markov(enumerate(data)):
            probs = transition_probs[x]
            x = numpyro.sample(f"x_{i}", dist.Categorical(probs))
            with D_plate:
                w = numpyro.sample(f"w_{i}", dist.Bernoulli(0.6))
                numpyro.sample(f"y_{i}", dist.Normal(Vindex(locs)[x, w], 1), obs=y)

    def fun_model(data):
        def transition_fn(x, y):
            probs = transition_probs[x]
            x = numpyro.sample("x", dist.Categorical(probs))
            with numpyro.plate("D", D, dim=-1):
                w = numpyro.sample("w", dist.Bernoulli(0.6))
                numpyro.sample("y", dist.Normal(Vindex(locs)[x, w], 1), obs=y)
            return x, None

        scan(transition_fn, 0, data)

    actual_log_joint = log_density(enum(config_enumerate(fun_model), -2), (data,), {}, {})[0]
    expected_log_joint = log_density(enum(config_enumerate(model), -2), (data,), {}, {})[0]
    assert_allclose(actual_log_joint, expected_log_joint)
def body_fn(wrapped_carry, x, prefix=None):
    i, rng_key, carry = wrapped_carry
    init = True if (not_jax_tracer(i) and i == 0) else False
    rng_key, subkey = random.split(rng_key) if rng_key is not None else (None, None)

    seeded_fn = handlers.seed(f, subkey) if subkey is not None else f
    for subs_type, subs_map in substitute_stack:
        subs_fn = partial(_subs_wrapper, subs_map, i, length)
        if subs_type == 'condition':
            seeded_fn = handlers.condition(seeded_fn, condition_fn=subs_fn)
        elif subs_type == 'substitute':
            seeded_fn = handlers.substitute(seeded_fn, substitute_fn=subs_fn)

    if init:
        with handlers.scope(prefix="_init"):
            new_carry, y = seeded_fn(carry, x)
            trace = {}
    else:
        with handlers.block(), packed_trace() as trace, promote_shapes(), enum(), markov():
            # Like scan_wrapper, we collect the trace of scan's transition function
            # `seeded_fn` here. To put the time dimension in the correct position, we
            # need to promote shapes so that `fn` and `value` at each site have the same
            # batch dims (e.g. if `fn.batch_shape = (2, 3)` and value's batch_shape is
            # (3,), then we promote the shape of value so that its batch shape is (1, 3)).
            new_carry, y = config_enumerate(seeded_fn)(carry, x)

        # store the shape of new_carry in a nonlocal variable
        nonlocal carry_shape_at_t1
        carry_shape_at_t1 = [jnp.shape(x) for x in tree_flatten(new_carry)[0]]
        # make new_carry have the same shape as carry
        # FIXME: is this rigorous?
        new_carry = tree_multimap(lambda a, b: jnp.reshape(a, jnp.shape(b)),
                                  new_carry, carry)
    return (i + jnp.array(1), rng_key, new_carry), (PytreeTrace(trace), y)
def test_nested_plate():
    with enum(first_available_dim=-3):
        with enum_plate("a", 5):
            with enum_plate("b", 2):
                x = numpyro.sample("x", dist.Normal(0, 1), rng_key=random.PRNGKey(0))
                assert x.shape == (2, 5)
def gibbs_fn(rng_key, gibbs_sites, hmc_sites):
    # convert to unconstrained values
    z_hmc = {
        k: biject_to(prototype_trace[k]["fn"].support).inv(v)
        for k, v in hmc_sites.items()
        if k in prototype_trace and prototype_trace[k]["type"] == "sample"
    }
    use_enum = len(set(support_sizes) - set(gibbs_sites)) > 0
    wrapped_model = _wrap_model(model)
    if use_enum:
        from numpyro.contrib.funsor import config_enumerate, enum

        wrapped_model = enum(config_enumerate(wrapped_model), -max_plate_nesting - 1)

    def potential_fn(z_discrete):
        model_kwargs_ = model_kwargs.copy()
        model_kwargs_["_gibbs_sites"] = z_discrete
        return potential_energy(
            wrapped_model, model_args, model_kwargs_, z_hmc, enum=use_enum)

    # get support_sizes of gibbs_sites
    support_sizes_flat, _ = ravel_pytree({k: support_sizes[k] for k in gibbs_sites})
    num_discretes = support_sizes_flat.shape[0]

    rng_key, rng_permute = random.split(rng_key)
    idxs = random.permutation(rng_key, jnp.arange(num_discretes))

    def body_fn(i, val):
        idx = idxs[i]
        support_size = support_sizes_flat[idx]
        rng_key, z, pe = val
        rng_key, z_new, pe_new, log_accept_ratio = proposal_fn(
            rng_key, z, pe, potential_fn=potential_fn, idx=idx, support_size=support_size)
        rng_key, rng_accept = random.split(rng_key)
        # u ~ Uniform(0, 1), u < accept_ratio => -log(u) > -log_accept_ratio
        # and -log(u) ~ exponential(1)
        z, pe = cond(random.exponential(rng_accept) > -log_accept_ratio,
                     (z_new, pe_new), identity, (z, pe), identity)
        return rng_key, z, pe

    init_val = (rng_key, gibbs_sites, potential_fn(gibbs_sites))
    _, gibbs_sites, _ = fori_loop(0, num_discretes, body_fn, init_val)
    return gibbs_sites
def test_scan_history(history, T):
    def model():
        p = numpyro.param("p", 0.25 * jnp.ones((2, 2, 2)))
        q = numpyro.param("q", 0.25 * jnp.ones(2))
        z = numpyro.sample("z", dist.Bernoulli(0.5))
        x_prev = 0
        x_curr = 0
        for t in markov(range(T), history=history):
            probs = p[x_prev, x_curr, z]
            x_prev, x_curr = x_curr, numpyro.sample("x_{}".format(t), dist.Bernoulli(probs))
            numpyro.sample("y_{}".format(t), dist.Bernoulli(q[x_curr]), obs=0)
        return x_prev, x_curr

    def fun_model():
        p = numpyro.param("p", 0.25 * jnp.ones((2, 2, 2)))
        q = numpyro.param("q", 0.25 * jnp.ones(2))
        z = numpyro.sample("z", dist.Bernoulli(0.5))

        def transition_fn(carry, y):
            x_prev, x_curr = carry
            probs = p[x_prev, x_curr, z]
            x_prev, x_curr = x_curr, numpyro.sample("x", dist.Bernoulli(probs))
            numpyro.sample("y", dist.Bernoulli(q[x_curr]), obs=y)
            return (x_prev, x_curr), None

        (x_prev, x_curr), _ = scan(transition_fn, (0, 0), jnp.zeros(T), history=history)
        return x_prev, x_curr

    expected_log_joint = log_density(enum(config_enumerate(model)), (), {}, {})[0]
    actual_log_joint = log_density(enum(config_enumerate(fun_model)), (), {}, {})[0]
    assert_allclose(actual_log_joint, expected_log_joint)

    expected_x_prev, expected_x_curr = enum(config_enumerate(model))()
    actual_x_prev, actual_x_curr = enum(config_enumerate(fun_model))()
    assert_allclose(actual_x_prev, expected_x_prev)
    assert_allclose(actual_x_curr, expected_x_curr)
def gibbs_fn(rng_key, gibbs_sites, hmc_sites):
    z_hmc = hmc_sites
    use_enum = len(set(support_sizes) - set(gibbs_sites)) > 0
    if use_enum:
        from numpyro.contrib.funsor import config_enumerate, enum

        wrapped_model_ = enum(config_enumerate(wrapped_model), -max_plate_nesting - 1)
    else:
        wrapped_model_ = wrapped_model

    def potential_fn(z_discrete):
        model_kwargs_ = model_kwargs.copy()
        model_kwargs_["_gibbs_sites"] = z_discrete
        return potential_energy(
            wrapped_model_, model_args, model_kwargs_, z_hmc, enum=use_enum)

    # get support_sizes of gibbs_sites
    support_sizes_flat, _ = ravel_pytree({k: support_sizes[k] for k in gibbs_sites})
    num_discretes = support_sizes_flat.shape[0]

    rng_key, rng_permute = random.split(rng_key)
    idxs = random.permutation(rng_key, jnp.arange(num_discretes))

    def body_fn(i, val):
        idx = idxs[i]
        support_size = support_sizes_flat[idx]
        rng_key, z, pe = val
        rng_key, z_new, pe_new, log_accept_ratio = proposal_fn(
            rng_key, z, pe, potential_fn=potential_fn, idx=idx, support_size=support_size)
        rng_key, rng_accept = random.split(rng_key)
        # u ~ Uniform(0, 1), u < accept_ratio => -log(u) > -log_accept_ratio
        # and -log(u) ~ exponential(1)
        z, pe = cond(random.exponential(rng_accept) > -log_accept_ratio,
                     (z_new, pe_new), identity, (z, pe), identity)
        return rng_key, z, pe

    init_val = (rng_key, gibbs_sites, potential_fn(gibbs_sites))
    _, gibbs_sites, _ = fori_loop(0, num_discretes, body_fn, init_val)
    return gibbs_sites
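# Hedged usage sketch (not part of the original source): how a `gibbs_fn` like the one
# above is typically exercised through the public DiscreteHMCGibbs kernel. The toy
# mixture model below is an illustrative assumption, not taken from this file.
def _example_discrete_hmc_gibbs():
    import jax.numpy as jnp
    from jax import random

    import numpyro
    import numpyro.distributions as dist
    from numpyro.infer import MCMC, NUTS, DiscreteHMCGibbs

    def mix_model(data):
        # discrete site updated by the Gibbs proposal, continuous site by NUTS
        c = numpyro.sample("c", dist.Bernoulli(0.3))
        loc = numpyro.sample("loc", dist.Normal(0.0, 1.0))
        numpyro.sample("obs", dist.Normal(loc + c, 1.0), obs=data)

    kernel = DiscreteHMCGibbs(NUTS(mix_model))
    mcmc = MCMC(kernel, num_warmup=100, num_samples=100, progress_bar=False)
    mcmc.run(random.PRNGKey(0), jnp.ones(5))
    return mcmc.get_samples()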
def test_scan_enum_separated_plates_same_dim():
    N, D1, D2 = 10, 3, 4
    data = random.normal(random.PRNGKey(0), (N, D1 + D2))
    data1, data2 = data[:, :D1], data[:, D1:]
    init_probs = jnp.array([0.6, 0.4])
    transition_probs = jnp.array([[0.8, 0.2], [0.1, 0.9]])
    locs = jnp.array([-1.0, 1.0])

    def model(data1, data2):
        x = None
        D1_plate = numpyro.plate("D1", D1, dim=-1)
        D2_plate = numpyro.plate("D2", D2, dim=-1)
        for i, (y1, y2) in markov(enumerate(zip(data1, data2))):
            probs = init_probs if x is None else transition_probs[x]
            x = numpyro.sample(f"x_{i}", dist.Categorical(probs))
            with D1_plate:
                numpyro.sample(f"y1_{i}", dist.Normal(locs[x], 1), obs=y1)
            with D2_plate:
                numpyro.sample(f"y2_{i}", dist.Normal(locs[x], 1), obs=y2)

    def fun_model(data1, data2):
        def transition_fn(x, y):
            y1, y2 = y
            probs = init_probs if x is None else transition_probs[x]
            x = numpyro.sample("x", dist.Categorical(probs))
            with numpyro.plate("D1", D1, dim=-1):
                numpyro.sample("y1", dist.Normal(locs[x], 1), obs=y1)
            with numpyro.plate("D2", D2, dim=-1):
                numpyro.sample("y2", dist.Normal(locs[x], 1), obs=y2)
            return x, None

        scan(transition_fn, None, (data1, data2))

    actual_log_joint = log_density(enum(config_enumerate(fun_model), -2), (data1, data2), {}, {})[0]
    expected_log_joint = log_density(enum(config_enumerate(model), -2), (data1, data2), {}, {})[0]
    assert_allclose(actual_log_joint, expected_log_joint)
def main(_args):
    data = generate_data()
    init_rng_key = PRNGKey(1273)
    # nuts = NUTS(gmm)
    # mcmc = MCMC(nuts, 100, 1000)
    # mcmc.print_summary()
    seeded_gmm = seed(gmm, init_rng_key)
    model_trace = trace(seeded_gmm).get_trace(data)
    max_plate_nesting = _guess_max_plate_nesting(model_trace)
    enum_gmm = enum(config_enumerate(gmm), -max_plate_nesting - 1)
    svi = SVI(enum_gmm, gmm_guide, Adam(0.1), RenyiELBO(-10.0))
    svi_state = svi.init(init_rng_key, data)
    upd_fun = jax.jit(svi.update)
    with tqdm.trange(100_000) as pbar:
        for i in pbar:
            svi_state, loss = upd_fun(svi_state, data)
            pbar.set_description(f"SVI {loss}", True)
def run(self, rng_key, *args, **kwargs):
    """
    Run the nested samplers and collect weighted samples.

    :param random.PRNGKey rng_key: Random number generator key to be used for the sampling.
    :param args: The arguments needed by the `model`.
    :param kwargs: The keyword arguments needed by the `model`.
    """
    rng_sampling, rng_predictive = random.split(rng_key)
    # reparam the model so that latent sites have Uniform(0, 1) priors
    prototype_trace = trace(seed(self.model, rng_key)).get_trace(*args, **kwargs)
    param_names = [
        site["name"]
        for site in prototype_trace.values()
        if site["type"] == "sample"
        and not site["is_observed"]
        and site["infer"].get("enumerate", "") != "parallel"
    ]
    deterministics = [
        site["name"]
        for site in prototype_trace.values()
        if site["type"] == "deterministic"
    ]
    reparam_model = reparam(self.model, config={k: UniformReparam() for k in param_names})

    # enable enumerate if needed
    has_enum = any(
        site["type"] == "sample"
        and site["infer"].get("enumerate", "") == "parallel"
        for site in prototype_trace.values()
    )
    if has_enum:
        from numpyro.contrib.funsor import enum, log_density as log_density_

        max_plate_nesting = _guess_max_plate_nesting(prototype_trace)
        _validate_model(prototype_trace)
        reparam_model = enum(reparam_model, -max_plate_nesting - 1)
    else:
        log_density_ = log_density

    def loglik_fn(**params):
        return log_density_(reparam_model, args, kwargs, params)[0]

    # use NestedSampler with identity prior chain
    prior_chain = PriorChain()
    for name in param_names:
        prior = UniformPrior(name + "_base", prototype_trace[name]["fn"].shape())
        prior_chain.push(prior)

    # XXX: the `marginalised` keyword in jaxns can be used to get the expectation of some
    # quantity over posterior samples; it could be helpful to expose it in this wrapper
    ns = OrigNestedSampler(
        loglik_fn,
        prior_chain,
        sampler_name=self.sampler_name,
        sampler_kwargs={"depth": self.depth, "num_slices": self.num_slices},
        max_samples=self.max_samples,
        num_live_points=self.num_live_points,
        collect_samples=True,
    )

    # some parts of jaxns use float64 and raise warnings if the default dtype is
    # float32, so we suppress them here to avoid confusion
    with warnings.catch_warnings():
        warnings.filterwarnings(
            "ignore", message=".*will be truncated to dtype float32.*")
        results = ns(rng_sampling, termination_frac=self.termination_frac)

    # transform base samples back to the original domains
    # Here we only transform the first valid num_samples samples
    # NB: the number of weighted samples obtained from jaxns is results.num_samples
    # and only the first num_samples values of results.samples are valid.
    num_samples = results.num_samples
    samples = tree_util.tree_map(lambda x: x[:num_samples], results.samples)
    predictive = Predictive(reparam_model, samples, return_sites=param_names + deterministics)
    samples = predictive(rng_predictive, *args, **kwargs)
    # replace base samples in jaxns results by transformed samples
    self._results = results._replace(samples=samples)
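# Hedged usage sketch (not part of the original source): how the `run` method above is
# typically driven through the NestedSampler wrapper in numpyro.contrib.nested_sampling.
# The toy model is an illustrative assumption, not taken from this file.
def _example_nested_sampler():
    import jax.numpy as jnp
    from jax import random

    import numpyro
    import numpyro.distributions as dist
    from numpyro.contrib.nested_sampling import NestedSampler

    def toy_model(y):
        loc = numpyro.sample("loc", dist.Normal(0.0, 1.0))
        numpyro.sample("y", dist.Normal(loc, 1.0), obs=y)

    ns = NestedSampler(toy_model)
    ns.run(random.PRNGKey(0), jnp.ones(3))
    # draw equally-weighted posterior samples from the weighted nested-sampling results
    return ns.get_samples(random.PRNGKey(1), num_samples=100)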
def initialize_model(rng_key, model,
                     init_strategy=init_to_uniform,
                     dynamic_args=False,
                     model_args=(),
                     model_kwargs=None):
    """
    (EXPERIMENTAL INTERFACE) Helper function that calls :func:`~numpyro.infer.util.get_potential_fn`
    and :func:`~numpyro.infer.util.find_valid_initial_params` under the hood to return a tuple of
    (`init_params_info`, `potential_fn`, `postprocess_fn`, `model_trace`).

    :param jax.random.PRNGKey rng_key: random number generator seed to
        sample from the prior. The returned `init_params` will have the
        batch shape ``rng_key.shape[:-1]``.
    :param model: Python callable containing Pyro primitives.
    :param callable init_strategy: a per-site initialization function.
        See :ref:`init_strategy` section for available functions.
    :param bool dynamic_args: if `True`, the `potential_fn` and
        `constraints_fn` are themselves dependent on model arguments.
        When provided a `*model_args, **model_kwargs`, they return
        `potential_fn` and `constraints_fn` callables, respectively.
    :param tuple model_args: args provided to the model.
    :param dict model_kwargs: kwargs provided to the model.
    :return: a namedtuple `ModelInfo` which contains the fields
        (`param_info`, `potential_fn`, `postprocess_fn`, `model_trace`), where
        `param_info` is a namedtuple `ParamInfo` containing values from the prior
        used to initiate MCMC, their corresponding potential energy, and their gradients;
        `postprocess_fn` is a callable that uses inverse transforms
        to convert unconstrained HMC samples to constrained values that
        lie within the site's support, in addition to returning values
        at `deterministic` sites in the model.
    """
    model_kwargs = {} if model_kwargs is None else model_kwargs
    substituted_model = substitute(
        seed(model, rng_key if jnp.ndim(rng_key) == 1 else rng_key[0]),
        substitute_fn=init_strategy)
    inv_transforms, replay_model, has_enumerate_support, model_trace = _get_model_transforms(
        substituted_model, model_args, model_kwargs)
    constrained_values = {k: v['value'] for k, v in model_trace.items()
                          if v['type'] == 'sample' and not v['is_observed']
                          and not v['fn'].is_discrete}

    if has_enumerate_support:
        from numpyro.contrib.funsor import config_enumerate, enum

        if not isinstance(model, enum):
            max_plate_nesting = _guess_max_plate_nesting(model_trace)
            model = enum(config_enumerate(model), -max_plate_nesting - 1)

    potential_fn, postprocess_fn = get_potential_fn(model,
                                                    inv_transforms,
                                                    replay_model=replay_model,
                                                    enum=has_enumerate_support,
                                                    dynamic_args=dynamic_args,
                                                    model_args=model_args,
                                                    model_kwargs=model_kwargs)

    init_strategy = init_strategy if isinstance(init_strategy, partial) else init_strategy()
    if (init_strategy.func is init_to_value) and not replay_model:
        init_values = init_strategy.keywords.get("values")
        unconstrained_values = transform_fn(inv_transforms, init_values, invert=True)
        init_strategy = _init_to_unconstrained_value(values=unconstrained_values)
    prototype_params = transform_fn(inv_transforms, constrained_values, invert=True)
    (init_params, pe, grad), is_valid = find_valid_initial_params(
        rng_key, model,
        init_strategy=init_strategy,
        enum=has_enumerate_support,
        model_args=model_args,
        model_kwargs=model_kwargs,
        prototype_params=prototype_params)

    if not_jax_tracer(is_valid):
        if device_get(~jnp.all(is_valid)):
            raise RuntimeError(
                "Cannot find valid initial parameters. Please check your model again.")

    return ModelInfo(ParamInfo(init_params, pe, grad), potential_fn, postprocess_fn, model_trace)
def initialize_model(
    rng_key,
    model,
    *,
    init_strategy=init_to_uniform,
    dynamic_args=False,
    model_args=(),
    model_kwargs=None,
    forward_mode_differentiation=False,
    validate_grad=True,
):
    """
    (EXPERIMENTAL INTERFACE) Helper function that calls :func:`~numpyro.infer.util.get_potential_fn`
    and :func:`~numpyro.infer.util.find_valid_initial_params` under the hood
    to return a tuple of (`init_params_info`, `potential_fn`, `postprocess_fn`, `model_trace`).

    :param jax.random.PRNGKey rng_key: random number generator seed to
        sample from the prior. The returned `init_params` will have the
        batch shape ``rng_key.shape[:-1]``.
    :param model: Python callable containing Pyro primitives.
    :param callable init_strategy: a per-site initialization function.
        See :ref:`init_strategy` section for available functions.
    :param bool dynamic_args: if `True`, the `potential_fn` and
        `constraints_fn` are themselves dependent on model arguments.
        When provided a `*model_args, **model_kwargs`, they return
        `potential_fn` and `constraints_fn` callables, respectively.
    :param tuple model_args: args provided to the model.
    :param dict model_kwargs: kwargs provided to the model.
    :param bool forward_mode_differentiation: whether to use forward-mode differentiation
        or reverse-mode differentiation. By default, we use reverse mode, but forward
        mode can be useful in some cases to improve performance. In addition, some
        control flow utilities in JAX such as `jax.lax.while_loop` or `jax.lax.fori_loop`
        only support forward-mode differentiation. See
        `JAX's The Autodiff Cookbook <https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html>`_
        for more information.
    :param bool validate_grad: whether to validate gradient of the initial params.
        Defaults to True.
    :return: a namedtuple `ModelInfo` which contains the fields
        (`param_info`, `potential_fn`, `postprocess_fn`, `model_trace`), where
        `param_info` is a namedtuple `ParamInfo` containing values from the prior
        used to initiate MCMC, their corresponding potential energy, and their gradients;
        `postprocess_fn` is a callable that uses inverse transforms
        to convert unconstrained HMC samples to constrained values that
        lie within the site's support, in addition to returning values
        at `deterministic` sites in the model.
    """
    model_kwargs = {} if model_kwargs is None else model_kwargs
    substituted_model = substitute(
        seed(model, rng_key if jnp.ndim(rng_key) == 1 else rng_key[0]),
        substitute_fn=init_strategy,
    )
    (
        inv_transforms,
        replay_model,
        has_enumerate_support,
        model_trace,
    ) = _get_model_transforms(substituted_model, model_args, model_kwargs)
    # substitute param sites from model_trace to model so
    # we don't need to generate again parameters of `numpyro.module`
    model = substitute(
        model,
        data={
            k: site["value"]
            for k, site in model_trace.items()
            if site["type"] in ["param"]
        },
    )
    constrained_values = {
        k: v["value"]
        for k, v in model_trace.items()
        if v["type"] == "sample"
        and not v["is_observed"]
        and not v["fn"].is_discrete
    }

    if has_enumerate_support:
        from numpyro.contrib.funsor import config_enumerate, enum

        if not isinstance(model, enum):
            max_plate_nesting = _guess_max_plate_nesting(model_trace)
            _validate_model(model_trace)
            model = enum(config_enumerate(model), -max_plate_nesting - 1)

    potential_fn, postprocess_fn = get_potential_fn(
        model,
        inv_transforms,
        replay_model=replay_model,
        enum=has_enumerate_support,
        dynamic_args=dynamic_args,
        model_args=model_args,
        model_kwargs=model_kwargs,
    )

    init_strategy = (
        init_strategy if isinstance(init_strategy, partial) else init_strategy()
    )
    if (init_strategy.func is init_to_value) and not replay_model:
        init_values = init_strategy.keywords.get("values")
        unconstrained_values = transform_fn(inv_transforms, init_values, invert=True)
        init_strategy = _init_to_unconstrained_value(values=unconstrained_values)
    prototype_params = transform_fn(inv_transforms, constrained_values, invert=True)
    (init_params, pe, grad), is_valid = find_valid_initial_params(
        rng_key,
        substitute(
            model,
            data={
                k: site["value"]
                for k, site in model_trace.items()
                if site["type"] in ["plate"]
            },
        ),
        init_strategy=init_strategy,
        enum=has_enumerate_support,
        model_args=model_args,
        model_kwargs=model_kwargs,
        prototype_params=prototype_params,
        forward_mode_differentiation=forward_mode_differentiation,
        validate_grad=validate_grad,
    )

    if not_jax_tracer(is_valid):
        if device_get(~jnp.all(is_valid)):
            with numpyro.validation_enabled(), trace() as tr:
                # validate parameters
                substituted_model(*model_args, **model_kwargs)
                # validate values
                for site in tr.values():
                    if site["type"] == "sample":
                        with warnings.catch_warnings(record=True) as ws:
                            site["fn"]._validate_sample(site["value"])
                        if len(ws) > 0:
                            for w in ws:
                                # add site information to the warning message
                                w.message.args = (
                                    "Site {}: {}".format(
                                        site["name"], w.message.args[0]
                                    ),
                                ) + w.message.args[1:]
                                warnings.showwarning(
                                    w.message,
                                    w.category,
                                    w.filename,
                                    w.lineno,
                                    file=w.file,
                                    line=w.line,
                                )
            raise RuntimeError(
                "Cannot find valid initial parameters. Please check your model again."
            )
    return ModelInfo(
        ParamInfo(init_params, pe, grad), potential_fn, postprocess_fn, model_trace
    )
def scan_enum(
    f,
    init,
    xs,
    length,
    reverse,
    rng_key=None,
    substitute_stack=None,
    history=1,
    first_available_dim=None,
):
    from numpyro.contrib.funsor import (
        config_enumerate,
        enum,
        markov,
        trace as packed_trace,
    )

    # the number of steps to unroll
    history = min(history, length)
    unroll_steps = min(2 * history - 1, length)
    if reverse:
        x0 = tree_map(lambda x: x[-unroll_steps:][::-1], xs)
        xs_ = tree_map(lambda x: x[:-unroll_steps], xs)
    else:
        x0 = tree_map(lambda x: x[:unroll_steps], xs)
        xs_ = tree_map(lambda x: x[unroll_steps:], xs)

    carry_shapes = []

    def body_fn(wrapped_carry, x, prefix=None):
        i, rng_key, carry = wrapped_carry
        init = True if (not_jax_tracer(i) and i in range(unroll_steps)) else False
        rng_key, subkey = random.split(rng_key) if rng_key is not None else (None, None)

        # we need to tell unconstrained messenger in potential energy computation
        # that only the item at time `i` is needed when transforming
        fn = handlers.infer_config(f, config_fn=lambda msg: {"_scan_current_index": i})

        seeded_fn = handlers.seed(fn, subkey) if subkey is not None else fn
        for subs_type, subs_map in substitute_stack:
            subs_fn = partial(_subs_wrapper, subs_map, i, length)
            if subs_type == "condition":
                seeded_fn = handlers.condition(seeded_fn, condition_fn=subs_fn)
            elif subs_type == "substitute":
                seeded_fn = handlers.substitute(seeded_fn, substitute_fn=subs_fn)

        if init:
            # prefix the site names to match the pattern of funsor's sarkka_bilmes product
            with handlers.scope(prefix="_PREV_" * (unroll_steps - i), divider=""):
                new_carry, y = config_enumerate(seeded_fn)(carry, x)
            trace = {}
        else:
            # Like scan_wrapper, we collect the trace of scan's transition function
            # `seeded_fn` here. To put the time dimension in the correct position, we
            # need to promote shapes so that `fn` and `value` at each site have the same
            # batch dims (e.g. if `fn.batch_shape = (2, 3)` and value's batch_shape is
            # (3,), then we promote the shape of value so that its batch shape is (1, 3)).
            # Here we promote the `fn` shape first; the `value` shape will be promoted
            # after scanning. We don't promote the `value` shape here because we need to
            # store the carry shape at this step. If we reshaped the `value` here, the
            # output carry might get the wrong shape.
            with _promote_fn_shapes(), packed_trace() as trace:
                new_carry, y = config_enumerate(seeded_fn)(carry, x)

            # store the shape of new_carry
            if len(carry_shapes) < (history + 1):
                carry_shapes.append([jnp.shape(x) for x in tree_flatten(new_carry)[0]])
            # make new_carry have the same shape as carry
            # FIXME: is this rigorous?
            new_carry = tree_multimap(
                lambda a, b: jnp.reshape(a, jnp.shape(b)), new_carry, carry
            )
        return (i + 1, rng_key, new_carry), (PytreeTrace(trace), y)

    with handlers.block(
        hide_fn=lambda site: not site["name"].startswith("_PREV_")
    ), enum(first_available_dim=first_available_dim):
        wrapped_carry = (0, rng_key, init)
        y0s = []
        # We run unroll_steps + 1 steps, where the last step is used for rolling with `lax.scan`
        for i in markov(range(unroll_steps + 1), history=history):
            if i < unroll_steps:
                wrapped_carry, (_, y0) = body_fn(
                    wrapped_carry, tree_map(lambda z: z[i], x0)
                )
                if i > 0:
                    # reshape y1, y2,... to have the same shape as y0
                    y0 = tree_multimap(
                        lambda z0, z: jnp.reshape(z, jnp.shape(z0)), y0s[0], y0
                    )
                y0s.append(y0)
                # shapes of the first `history - 1` steps are not useful to interpret the
                # last carry shape so we don't need to record them here
                if (i >= history - 1) and (len(carry_shapes) < history + 1):
                    carry_shapes.append(
                        [jnp.shape(x) for x in tree_flatten(wrapped_carry[-1])[0]]
                    )
            else:
                # this is the last rolling step
                y0s = tree_multimap(lambda *z: jnp.stack(z, axis=0), *y0s)
                # return early if length = unroll_steps
                if length == unroll_steps:
                    return wrapped_carry, (PytreeTrace({}), y0s)
                wrapped_carry = device_put(wrapped_carry)
                wrapped_carry, (pytree_trace, ys) = lax.scan(
                    body_fn, wrapped_carry, xs_, length - unroll_steps, reverse
                )

    first_var = None
    for name, site in pytree_trace.trace.items():
        # currently, we only record sample or deterministic sites in the trace;
        # we don't need to adjust `dim_to_name` for deterministic sites
        if site["type"] not in ("sample",):
            continue
        # add `time` dimension, the name will be '_time_{first variable in the trace}'
        if first_var is None:
            first_var = name

        # we haven't promoted shapes of values during `lax.scan` yet, so we do it here
        site["value"] = _promote_scanned_value_shapes(site["value"], site["fn"])

        # XXX: site['infer']['dim_to_name'] is not enough to determine the leftmost
        # dimension because we don't record 1-size dimensions in this field
        time_dim = -min(
            len(site["fn"].batch_shape), jnp.ndim(site["value"]) - site["fn"].event_dim
        )
        site["infer"]["dim_to_name"][time_dim] = "_time_{}".format(first_var)

    # similar to carry, we need to reshape due to shape alternating in markov
    ys = tree_multimap(
        lambda z0, z: jnp.reshape(z, z.shape[:1] + jnp.shape(z0)[1:]), y0s, ys
    )
    # then join with y0s
    ys = tree_multimap(lambda z0, z: jnp.concatenate([z0, z], axis=0), y0s, ys)
    # we also need to reshape `carry` to match sequential behavior
    i = (length + 1) % (history + 1)
    t, rng_key, carry = wrapped_carry
    carry_shape = carry_shapes[i]
    flatten_carry, treedef = tree_flatten(carry)
    flatten_carry = [
        jnp.reshape(x, t1_shape) for x, t1_shape in zip(flatten_carry, carry_shape)
    ]
    carry = tree_unflatten(treedef, flatten_carry)
    wrapped_carry = (t, rng_key, carry)
    return wrapped_carry, (pytree_trace, ys)
def init(self, rng_key, *args, **kwargs):
    """
    :param jax.random.PRNGKey rng_key: random number generator seed.
    :param args: arguments to the model / guide (these can possibly vary during
        the course of fitting).
    :param kwargs: keyword arguments to the model / guide (these can possibly vary
        during the course of fitting).
    :return: initial :data:`SteinVIState`
    """
    rng_key, kernel_seed, model_seed, guide_seed = jax.random.split(rng_key, 4)

    model_init = handlers.seed(self.model, model_seed)
    guide_init = handlers.seed(self.guide, guide_seed)
    guide_trace = handlers.trace(guide_init).get_trace(*args, **kwargs, **self.static_kwargs)
    model_trace = handlers.trace(model_init).get_trace(*args, **kwargs, **self.static_kwargs)
    rng_key, particle_seed = jax.random.split(rng_key)
    guide_init_params = self._find_init_params(particle_seed, self.guide, guide_trace)
    params = {}
    transforms = {}
    inv_transforms = {}
    particle_transforms = {}
    guide_param_names = set()
    should_enum = False
    for site in model_trace.values():
        if (
            "fn" in site
            and site["type"] == "sample"
            and not site["is_observed"]
            and isinstance(site["fn"], Distribution)
            and site["fn"].is_discrete
        ):
            if site["fn"].has_enumerate_support and self.enum:
                should_enum = True
            else:
                raise Exception(
                    "Cannot enumerate model with discrete variables without enumerate support"
                )
    # NB: params in model_trace will be overwritten by params in guide_trace
    for site in chain(model_trace.values(), guide_trace.values()):
        if site["type"] == "param":
            transform = get_parameter_transform(site)
            inv_transforms[site["name"]] = transform
            transforms[site["name"]] = transform.inv
            particle_transforms[site["name"]] = site.get(
                "particle_transform", IdentityTransform()
            )
            if site["name"] in guide_init_params:
                pval, _ = guide_init_params[site["name"]]
                if self.classic_guide_params_fn(site["name"]):
                    pval = tree_map(lambda x: x[0], pval)
            else:
                pval = site["value"]
            params[site["name"]] = transform.inv(pval)
            if site["name"] in guide_trace:
                guide_param_names.add(site["name"])

    if should_enum:
        mpn = _guess_max_plate_nesting(model_trace)
        self._inference_model = enum(config_enumerate(self.model), -mpn - 1)
    self.guide_param_names = guide_param_names
    self.constrain_fn = partial(transform_fn, inv_transforms)
    self.uconstrain_fn = partial(transform_fn, transforms)
    self.particle_transforms = particle_transforms
    self.particle_transform_fn = partial(transform_fn, particle_transforms)

    stein_particles, _, _ = batch_ravel_pytree(
        {
            k: params[k]
            for k, site in guide_trace.items()
            if site["type"] == "param" and site["name"] in guide_init_params
        },
        nbatch_dims=1,
    )

    self.kernel_fn.init(kernel_seed, stein_particles.shape)
    return SteinVIState(self.optim.init(params), rng_key)