def test_scan_enum_plate():
    N, D = 10, 3
    data = random.normal(random.PRNGKey(0), (N, D))
    init_probs = jnp.array([0.6, 0.4])
    transition_probs = jnp.array([[0.8, 0.2], [0.1, 0.9]])
    locs = jnp.array([-1.0, 1.0])

    def model(data):
        x = None
        D_plate = numpyro.plate("D", D, dim=-1)
        for i, y in markov(enumerate(data)):
            with D_plate:
                probs = init_probs if x is None else transition_probs[x]
                x = numpyro.sample(f"x_{i}", dist.Categorical(probs))
                numpyro.sample(f"y_{i}", dist.Normal(locs[x], 1), obs=y)

    def fun_model(data):
        def transition_fn(x, y):
            probs = init_probs if x is None else transition_probs[x]
            with numpyro.plate("D", D, dim=-1):
                x = numpyro.sample("x", dist.Categorical(probs))
                numpyro.sample("y", dist.Normal(locs[x], 1), obs=y)
            return x, None

        scan(transition_fn, None, data)

    actual_log_joint = log_density(enum(config_enumerate(fun_model), -2), (data,), {}, {})[0]
    expected_log_joint = log_density(enum(config_enumerate(model), -2), (data,), {}, {})[0]
    assert_allclose(actual_log_joint, expected_log_joint)
def test_scan_enum_separated_plate_discrete():
    N, D = 10, 3
    data = random.normal(random.PRNGKey(0), (N, D))
    transition_probs = jnp.array([[0.8, 0.2], [0.1, 0.9]])
    locs = jnp.array([[-1.0, 1.0], [2.0, 3.0]])

    def model(data):
        x = 0
        D_plate = numpyro.plate("D", D, dim=-1)
        for i, y in markov(enumerate(data)):
            probs = transition_probs[x]
            x = numpyro.sample(f"x_{i}", dist.Categorical(probs))
            with D_plate:
                w = numpyro.sample(f"w_{i}", dist.Bernoulli(0.6))
                numpyro.sample(f"y_{i}", dist.Normal(Vindex(locs)[x, w], 1), obs=y)

    def fun_model(data):
        def transition_fn(x, y):
            probs = transition_probs[x]
            x = numpyro.sample("x", dist.Categorical(probs))
            with numpyro.plate("D", D, dim=-1):
                w = numpyro.sample("w", dist.Bernoulli(0.6))
                numpyro.sample("y", dist.Normal(Vindex(locs)[x, w], 1), obs=y)
            return x, None

        scan(transition_fn, 0, data)

    actual_log_joint = log_density(enum(config_enumerate(fun_model), -2), (data,), {}, {})[0]
    expected_log_joint = log_density(enum(config_enumerate(model), -2), (data,), {}, {})[0]
    assert_allclose(actual_log_joint, expected_log_joint)
def test_scan_enum_discrete_outside():
    data = random.normal(random.PRNGKey(0), (10,))
    probs = jnp.array([[[0.8, 0.2], [0.1, 0.9]], [[0.7, 0.3], [0.6, 0.4]]])
    locs = jnp.array([-1.0, 1.0])

    def model(data):
        w = numpyro.sample("w", dist.Bernoulli(0.6))
        x = 0
        for i, y in markov(enumerate(data)):
            x = numpyro.sample(f"x_{i}", dist.Categorical(probs[w, x]))
            numpyro.sample(f"y_{i}", dist.Normal(locs[x], 1), obs=y)

    def fun_model(data):
        w = numpyro.sample("w", dist.Bernoulli(0.6))

        def transition_fn(x, y):
            x = numpyro.sample("x", dist.Categorical(probs[w, x]))
            numpyro.sample("y", dist.Normal(locs[x], 1), obs=y)
            return x, None

        scan(transition_fn, 0, data)

    actual_log_joint = log_density(enum(config_enumerate(fun_model)), (data,), {}, {})[0]
    expected_log_joint = log_density(enum(config_enumerate(model)), (data,), {}, {})[0]
    assert_allclose(actual_log_joint, expected_log_joint)
def test_scan_enum_two_latents():
    num_steps = 11
    data = random.normal(random.PRNGKey(0), (num_steps,))
    probs_x = jnp.array([[0.8, 0.2], [0.1, 0.9]])
    probs_w = jnp.array([[0.7, 0.3], [0.6, 0.4]])
    locs = jnp.array([[-1.0, 1.0], [2.0, 3.0]])

    def model(data):
        x = w = 0
        for i, y in markov(enumerate(data)):
            x = numpyro.sample(f"x_{i}", dist.Categorical(probs_x[x]))
            w = numpyro.sample(f"w_{i}", dist.Categorical(probs_w[w]))
            numpyro.sample(f"y_{i}", dist.Normal(locs[w, x], 1), obs=y)

    def fun_model(data):
        def transition_fn(carry, y):
            x, w = carry
            x = numpyro.sample("x", dist.Categorical(probs_x[x]))
            w = numpyro.sample("w", dist.Categorical(probs_w[w]))
            numpyro.sample("y", dist.Normal(locs[w, x], 1), obs=y)
            # also test that scan's `ys` are recorded correctly
            return (x, w), x

        scan(transition_fn, (0, 0), data)

    actual_log_joint = log_density(enum(config_enumerate(fun_model)), (data,), {}, {})[0]
    expected_log_joint = log_density(enum(config_enumerate(model)), (data,), {}, {})[0]
    assert_allclose(actual_log_joint, expected_log_joint)
def test_scan_enum_scan_enum():
    num_steps = 11
    data_x = random.normal(random.PRNGKey(0), (num_steps,))
    data_w = data_x[:-1] + 1
    probs_x = jnp.array([[0.8, 0.2], [0.1, 0.9]])
    probs_w = jnp.array([[0.7, 0.3], [0.6, 0.4]])
    locs_x = jnp.array([-1.0, 1.0])
    locs_w = jnp.array([2.0, 3.0])

    def model(data_x, data_w):
        x = w = 0
        for i, y in markov(enumerate(data_x)):
            x = numpyro.sample(f"x_{i}", dist.Categorical(probs_x[x]))
            numpyro.sample(f"y_x_{i}", dist.Normal(locs_x[x], 1), obs=y)

        for i, y in markov(enumerate(data_w)):
            w = numpyro.sample(f"w{i}", dist.Categorical(probs_w[w]))
            numpyro.sample(f"y_w_{i}", dist.Normal(locs_w[w], 1), obs=y)

    def fun_model(data_x, data_w):
        def transition_fn(name, probs, locs, x, y):
            x = numpyro.sample(name, dist.Categorical(probs[x]))
            numpyro.sample("y_" + name, dist.Normal(locs[x], 1), obs=y)
            return x, None

        scan(partial(transition_fn, "x", probs_x, locs_x), 0, data_x)
        scan(partial(transition_fn, "w", probs_w, locs_w), 0, data_w)

    actual_log_joint = log_density(enum(config_enumerate(fun_model)), (data_x, data_w), {}, {})[0]
    expected_log_joint = log_density(enum(config_enumerate(model)), (data_x, data_w), {}, {})[0]
    assert_allclose(actual_log_joint, expected_log_joint)
def test_scan_enum_one_latent(num_steps):
    data = random.normal(random.PRNGKey(0), (num_steps,))
    init_probs = jnp.array([0.6, 0.4])
    transition_probs = jnp.array([[0.8, 0.2], [0.1, 0.9]])
    locs = jnp.array([-1.0, 1.0])

    def model(data):
        x = None
        for i, y in markov(enumerate(data)):
            probs = init_probs if x is None else transition_probs[x]
            x = numpyro.sample(f"x_{i}", dist.Categorical(probs))
            numpyro.sample(f"y_{i}", dist.Normal(locs[x], 1), obs=y)
        return x

    def fun_model(data):
        def transition_fn(x, y):
            probs = init_probs if x is None else transition_probs[x]
            x = numpyro.sample("x", dist.Categorical(probs))
            numpyro.sample("y", dist.Normal(locs[x], 1), obs=y)
            return x, None

        x, collections = scan(transition_fn, None, data)
        assert collections is None
        return x

    actual_log_joint = log_density(enum(config_enumerate(fun_model)), (data,), {}, {})[0]
    expected_log_joint = log_density(enum(config_enumerate(model)), (data,), {}, {})[0]
    assert_allclose(actual_log_joint, expected_log_joint)

    actual_last_x = enum(config_enumerate(fun_model))(data)
    expected_last_x = enum(config_enumerate(model))(data)
    assert_allclose(actual_last_x, expected_last_x)
def body_fn(wrapped_carry, x, prefix=None):
    i, rng_key, carry = wrapped_carry
    init = True if (not_jax_tracer(i) and i in range(unroll_steps)) else False
    rng_key, subkey = random.split(rng_key) if rng_key is not None else (None, None)

    # we need to tell the unconstrained messenger in the potential energy computation
    # that only the item at time `i` is needed when transforming
    fn = handlers.infer_config(f, config_fn=lambda msg: {"_scan_current_index": i})

    seeded_fn = handlers.seed(fn, subkey) if subkey is not None else fn
    for subs_type, subs_map in substitute_stack:
        subs_fn = partial(_subs_wrapper, subs_map, i, length)
        if subs_type == "condition":
            seeded_fn = handlers.condition(seeded_fn, condition_fn=subs_fn)
        elif subs_type == "substitute":
            seeded_fn = handlers.substitute(seeded_fn, substitute_fn=subs_fn)

    if init:
        # rename the sites to match the pattern of the sarkka_bilmes product
        with handlers.scope(prefix="_PREV_" * (unroll_steps - i), divider=""):
            new_carry, y = config_enumerate(seeded_fn)(carry, x)
        trace = {}
    else:
        # Like scan_wrapper, we collect the trace of scan's transition function
        # `seeded_fn` here. To put the time dimension in the correct position, we need to
        # promote shapes so that `fn` and `value` at each site have the same batch dims
        # (e.g. if `fn.batch_shape = (2, 3)` and value's batch_shape is (3,), then we
        # promote the shape of value so that its batch shape is (1, 3)).
        # Here we promote the `fn` shape first. The `value` shape will be promoted after scanning.
        # We don't promote the `value` shape here because we need to store the carry shape
        # at this step. If we reshaped the `value` here, the output carry might get the wrong shape.
        with _promote_fn_shapes(), packed_trace() as trace:
            new_carry, y = config_enumerate(seeded_fn)(carry, x)

        # store the shape of new_carry in an enclosing-scope variable
        if len(carry_shapes) < (history + 1):
            carry_shapes.append([jnp.shape(x) for x in tree_flatten(new_carry)[0]])
        # make new_carry have the same shape as carry
        # FIXME: is this rigorous?
        new_carry = tree_multimap(
            lambda a, b: jnp.reshape(a, jnp.shape(b)), new_carry, carry
        )
    return (i + 1, rng_key, new_carry), (PytreeTrace(trace), y)
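# A minimal sketch (not part of the library) of the shape promotion described in the
# comments above: a site value with batch_shape (3,) is aligned with a fn of
# batch_shape (2, 3) by prepending singleton dimensions, so the two broadcast along
# the enumeration/time dims. The helper name and shapes here are illustrative only.
def _promote_value_shape_example():
    import jax.numpy as jnp

    fn_batch_shape = (2, 3)
    value = jnp.zeros((3,))
    pad = (1,) * (len(fn_batch_shape) - value.ndim)
    promoted = jnp.reshape(value, pad + value.shape)
    assert promoted.shape == (1, 3)  # broadcastable against fn_batch_shape (2, 3)
    return promoted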
def test_scan_hmm_smoke(length, temperature):
    # This should match the example in the infer_discrete docstring.
    def hmm(data, hidden_dim=10):
        transition = 0.3 / hidden_dim + 0.7 * jnp.eye(hidden_dim)
        means = jnp.arange(float(hidden_dim))

        def transition_fn(state, y):
            state = numpyro.sample("states", dist.Categorical(transition[state]))
            y = numpyro.sample("obs", dist.Normal(means[state], 1.0), obs=y)
            return state, (state, y)

        _, (states, data) = scan(transition_fn, 0, data, length=length)
        return [0] + [s for s in states], data

    true_states, data = handlers.seed(hmm, 0)(None)
    assert len(data) == length
    assert len(true_states) == 1 + len(data)

    decoder = infer_discrete(
        config_enumerate(hmm), temperature=temperature, rng_key=random.PRNGKey(1)
    )
    inferred_states, _ = decoder(data)
    assert len(inferred_states) == len(true_states)

    logger.info("true states: {}".format(list(map(int, true_states))))
    logger.info("inferred states: {}".format(list(map(int, inferred_states))))
def test_hmm_smoke(length, temperature):
    # This should match the example in the infer_discrete docstring.
    def hmm(data, hidden_dim=10):
        transition = 0.3 / hidden_dim + 0.7 * jnp.eye(hidden_dim)
        means = jnp.arange(float(hidden_dim))
        states = [0]
        for t in markov(range(len(data))):
            states.append(
                numpyro.sample(
                    "states_{}".format(t), dist.Categorical(transition[states[-1]])
                )
            )
            data[t] = numpyro.sample(
                "obs_{}".format(t), dist.Normal(means[states[-1]], 1.0), obs=data[t]
            )
        return states, data

    true_states, data = handlers.seed(hmm, 0)([None] * length)
    assert len(data) == length
    assert len(true_states) == 1 + len(data)

    decoder = infer_discrete(
        config_enumerate(hmm), temperature=temperature, rng_key=random.PRNGKey(1)
    )
    inferred_states, _ = decoder(data)
    assert len(inferred_states) == len(true_states)

    logger.info("true states: {}".format(list(map(int, true_states))))
    logger.info("inferred states: {}".format(list(map(int, inferred_states))))
def body_fn(wrapped_carry, x, prefix=None):
    i, rng_key, carry = wrapped_carry
    init = True if (not_jax_tracer(i) and i == 0) else False
    rng_key, subkey = random.split(rng_key) if rng_key is not None else (None, None)

    seeded_fn = handlers.seed(f, subkey) if subkey is not None else f
    for subs_type, subs_map in substitute_stack:
        subs_fn = partial(_subs_wrapper, subs_map, i, length)
        if subs_type == 'condition':
            seeded_fn = handlers.condition(seeded_fn, condition_fn=subs_fn)
        elif subs_type == 'substitute':
            seeded_fn = handlers.substitute(seeded_fn, substitute_fn=subs_fn)

    if init:
        with handlers.scope(prefix="_init"):
            new_carry, y = seeded_fn(carry, x)
            trace = {}
    else:
        with handlers.block(), packed_trace() as trace, promote_shapes(), enum(), markov():
            # Like scan_wrapper, we collect the trace of scan's transition function
            # `seeded_fn` here. To put the time dimension in the correct position, we need to
            # promote shapes so that `fn` and `value` at each site have the same batch dims
            # (e.g. if `fn.batch_shape = (2, 3)` and value's batch_shape is (3,), then we
            # promote the shape of value so that its batch shape is (1, 3)).
            new_carry, y = config_enumerate(seeded_fn)(carry, x)

        # store the shape of new_carry in an enclosing-scope variable
        nonlocal carry_shape_at_t1
        carry_shape_at_t1 = [jnp.shape(x) for x in tree_flatten(new_carry)[0]]
        # make new_carry have the same shape as carry
        # FIXME: is this rigorous?
        new_carry = tree_multimap(
            lambda a, b: jnp.reshape(a, jnp.shape(b)), new_carry, carry
        )
    return (i + jnp.array(1), rng_key, new_carry), (PytreeTrace(trace), y)
def gibbs_fn(rng_key, gibbs_sites, hmc_sites):
    # convert to unconstrained values
    z_hmc = {
        k: biject_to(prototype_trace[k]["fn"].support).inv(v)
        for k, v in hmc_sites.items()
        if k in prototype_trace and prototype_trace[k]["type"] == "sample"
    }
    use_enum = len(set(support_sizes) - set(gibbs_sites)) > 0
    wrapped_model = _wrap_model(model)
    if use_enum:
        from numpyro.contrib.funsor import config_enumerate, enum

        wrapped_model = enum(config_enumerate(wrapped_model), -max_plate_nesting - 1)

    def potential_fn(z_discrete):
        model_kwargs_ = model_kwargs.copy()
        model_kwargs_["_gibbs_sites"] = z_discrete
        return potential_energy(wrapped_model, model_args, model_kwargs_, z_hmc, enum=use_enum)

    # get support_sizes of gibbs_sites
    support_sizes_flat, _ = ravel_pytree({k: support_sizes[k] for k in gibbs_sites})
    num_discretes = support_sizes_flat.shape[0]

    rng_key, rng_permute = random.split(rng_key)
    idxs = random.permutation(rng_key, jnp.arange(num_discretes))

    def body_fn(i, val):
        idx = idxs[i]
        support_size = support_sizes_flat[idx]
        rng_key, z, pe = val
        rng_key, z_new, pe_new, log_accept_ratio = proposal_fn(
            rng_key, z, pe, potential_fn=potential_fn, idx=idx, support_size=support_size
        )
        rng_key, rng_accept = random.split(rng_key)
        # u ~ Uniform(0, 1), u < accept_ratio => -log(u) > -log_accept_ratio
        # and -log(u) ~ exponential(1)
        z, pe = cond(
            random.exponential(rng_accept) > -log_accept_ratio,
            (z_new, pe_new), identity,
            (z, pe), identity,
        )
        return rng_key, z, pe

    init_val = (rng_key, gibbs_sites, potential_fn(gibbs_sites))
    _, gibbs_sites, _ = fori_loop(0, num_discretes, body_fn, init_val)
    return gibbs_sites
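# A small self-contained check (illustration only, not library code) of the acceptance
# trick used in `body_fn` above: accepting when an Exponential(1) draw exceeds
# -log_accept_ratio is equivalent in distribution to the usual test u < accept_ratio
# with u ~ Uniform(0, 1), because -log(u) ~ Exponential(1). The acceptance ratio of 0.3
# is a made-up example value.
def _check_exponential_acceptance_trick(num_samples=100_000):
    from jax import random
    import jax.numpy as jnp

    key_u, key_e = random.split(random.PRNGKey(0))
    log_accept_ratio = jnp.log(0.3)
    accept_uniform = random.uniform(key_u, (num_samples,)) < jnp.exp(log_accept_ratio)
    accept_exponential = random.exponential(key_e, (num_samples,)) > -log_accept_ratio
    # both empirical acceptance rates should be close to 0.3
    return accept_uniform.mean(), accept_exponential.mean()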
def test_scan_history(history, T):
    def model():
        p = numpyro.param("p", 0.25 * jnp.ones((2, 2, 2)))
        q = numpyro.param("q", 0.25 * jnp.ones(2))
        z = numpyro.sample("z", dist.Bernoulli(0.5))
        x_prev = 0
        x_curr = 0
        for t in markov(range(T), history=history):
            probs = p[x_prev, x_curr, z]
            x_prev, x_curr = x_curr, numpyro.sample("x_{}".format(t), dist.Bernoulli(probs))
            numpyro.sample("y_{}".format(t), dist.Bernoulli(q[x_curr]), obs=0)
        return x_prev, x_curr

    def fun_model():
        p = numpyro.param("p", 0.25 * jnp.ones((2, 2, 2)))
        q = numpyro.param("q", 0.25 * jnp.ones(2))
        z = numpyro.sample("z", dist.Bernoulli(0.5))

        def transition_fn(carry, y):
            x_prev, x_curr = carry
            probs = p[x_prev, x_curr, z]
            x_prev, x_curr = x_curr, numpyro.sample("x", dist.Bernoulli(probs))
            numpyro.sample("y", dist.Bernoulli(q[x_curr]), obs=y)
            return (x_prev, x_curr), None

        (x_prev, x_curr), _ = scan(transition_fn, (0, 0), jnp.zeros(T), history=history)
        return x_prev, x_curr

    expected_log_joint = log_density(enum(config_enumerate(model)), (), {}, {})[0]
    actual_log_joint = log_density(enum(config_enumerate(fun_model)), (), {}, {})[0]
    assert_allclose(actual_log_joint, expected_log_joint)

    expected_x_prev, expected_x_curr = enum(config_enumerate(model))()
    actual_x_prev, actual_x_curr = enum(config_enumerate(fun_model))()
    assert_allclose(actual_x_prev, expected_x_prev)
    assert_allclose(actual_x_curr, expected_x_curr)
def gibbs_fn(rng_key, gibbs_sites, hmc_sites):
    z_hmc = hmc_sites
    use_enum = len(set(support_sizes) - set(gibbs_sites)) > 0
    if use_enum:
        from numpyro.contrib.funsor import config_enumerate, enum

        wrapped_model_ = enum(config_enumerate(wrapped_model), -max_plate_nesting - 1)
    else:
        wrapped_model_ = wrapped_model

    def potential_fn(z_discrete):
        model_kwargs_ = model_kwargs.copy()
        model_kwargs_["_gibbs_sites"] = z_discrete
        return potential_energy(wrapped_model_, model_args, model_kwargs_, z_hmc, enum=use_enum)

    # get support_sizes of gibbs_sites
    support_sizes_flat, _ = ravel_pytree({k: support_sizes[k] for k in gibbs_sites})
    num_discretes = support_sizes_flat.shape[0]

    rng_key, rng_permute = random.split(rng_key)
    idxs = random.permutation(rng_key, jnp.arange(num_discretes))

    def body_fn(i, val):
        idx = idxs[i]
        support_size = support_sizes_flat[idx]
        rng_key, z, pe = val
        rng_key, z_new, pe_new, log_accept_ratio = proposal_fn(
            rng_key, z, pe, potential_fn=potential_fn, idx=idx, support_size=support_size
        )
        rng_key, rng_accept = random.split(rng_key)
        # u ~ Uniform(0, 1), u < accept_ratio => -log(u) > -log_accept_ratio
        # and -log(u) ~ exponential(1)
        z, pe = cond(
            random.exponential(rng_accept) > -log_accept_ratio,
            (z_new, pe_new), identity,
            (z, pe), identity,
        )
        return rng_key, z, pe

    init_val = (rng_key, gibbs_sites, potential_fn(gibbs_sites))
    _, gibbs_sites, _ = fori_loop(0, num_discretes, body_fn, init_val)
    return gibbs_sites
def test_scan_enum_separated_plates_same_dim():
    N, D1, D2 = 10, 3, 4
    data = random.normal(random.PRNGKey(0), (N, D1 + D2))
    data1, data2 = data[:, :D1], data[:, D1:]
    init_probs = jnp.array([0.6, 0.4])
    transition_probs = jnp.array([[0.8, 0.2], [0.1, 0.9]])
    locs = jnp.array([-1.0, 1.0])

    def model(data1, data2):
        x = None
        D1_plate = numpyro.plate("D1", D1, dim=-1)
        D2_plate = numpyro.plate("D2", D2, dim=-1)
        for i, (y1, y2) in markov(enumerate(zip(data1, data2))):
            probs = init_probs if x is None else transition_probs[x]
            x = numpyro.sample(f"x_{i}", dist.Categorical(probs))
            with D1_plate:
                numpyro.sample(f"y1_{i}", dist.Normal(locs[x], 1), obs=y1)
            with D2_plate:
                numpyro.sample(f"y2_{i}", dist.Normal(locs[x], 1), obs=y2)

    def fun_model(data1, data2):
        def transition_fn(x, y):
            y1, y2 = y
            probs = init_probs if x is None else transition_probs[x]
            x = numpyro.sample("x", dist.Categorical(probs))
            with numpyro.plate("D1", D1, dim=-1):
                numpyro.sample("y1", dist.Normal(locs[x], 1), obs=y1)
            with numpyro.plate("D2", D2, dim=-1):
                numpyro.sample("y2", dist.Normal(locs[x], 1), obs=y2)
            return x, None

        scan(transition_fn, None, (data1, data2))

    actual_log_joint = log_density(enum(config_enumerate(fun_model), -2), (data1, data2), {}, {})[0]
    expected_log_joint = log_density(enum(config_enumerate(model), -2), (data1, data2), {}, {})[0]
    assert_allclose(actual_log_joint, expected_log_joint)
def test_mcmc_model_side_enumeration(model, temperature):
    mcmc = infer.MCMC(infer.NUTS(config_enumerate(model)), num_warmup=0, num_samples=1)
    mcmc.run(random.PRNGKey(0))
    mcmc_data = {
        k: v[0] for k, v in mcmc.get_samples().items() if k in ["loc", "scale"]
    }

    # MAP estimate of discretes, conditioned on posterior sampled continuous latents.
    model = handlers.seed(model, rng_seed=1)
    actual_trace = handlers.trace(
        infer_discrete(
            # TODO support replayed sites in infer_discrete.
            # handlers.replay(config_enumerate(model), mcmc_trace),
            handlers.condition(config_enumerate(model), mcmc_data),
            temperature=temperature,
            rng_key=random.PRNGKey(1),
        )
    ).get_trace()

    # Check site names and shapes.
    expected_trace = handlers.trace(model).get_trace()
    assert set(actual_trace) == set(expected_trace)
def main(_args):
    data = generate_data()
    init_rng_key = PRNGKey(1273)
    # nuts = NUTS(gmm)
    # mcmc = MCMC(nuts, 100, 1000)
    # mcmc.print_summary()
    seeded_gmm = seed(gmm, init_rng_key)
    model_trace = trace(seeded_gmm).get_trace(data)
    max_plate_nesting = _guess_max_plate_nesting(model_trace)
    enum_gmm = enum(config_enumerate(gmm), -max_plate_nesting - 1)
    svi = SVI(enum_gmm, gmm_guide, Adam(0.1), RenyiELBO(-10.0))
    svi_state = svi.init(init_rng_key, data)
    upd_fun = jax.jit(svi.update)
    with tqdm.trange(100_000) as pbar:
        for i in pbar:
            svi_state, loss = upd_fun(svi_state, data)
            pbar.set_description(f"SVI {loss}", True)
def single_prediction(val):
    rng_key, samples = val
    if infer_discrete:
        from numpyro.contrib.funsor import config_enumerate
        from numpyro.contrib.funsor.discrete import _sample_posterior

        model_trace = prototype_trace
        temperature = 1
        pred_samples = _sample_posterior(
            config_enumerate(condition(model, samples)),
            first_available_dim,
            temperature,
            rng_key,
            *model_args,
            **model_kwargs,
        )
    else:
        model_trace = trace(
            seed(substitute(masked_model, samples), rng_key)
        ).get_trace(*model_args, **model_kwargs)
        pred_samples = {name: site["value"] for name, site in model_trace.items()}

    if return_sites is not None:
        if return_sites == "":
            sites = {
                k for k, site in model_trace.items() if site["type"] != "plate"
            }
        else:
            sites = return_sites
    else:
        sites = {
            k
            for k, site in model_trace.items()
            if (site["type"] == "sample" and k not in samples)
            or (site["type"] == "deterministic")
        }
    return {name: value for name, value in pred_samples.items() if name in sites}
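# Hedged usage sketch (assumption: the public `Predictive(..., infer_discrete=True)`
# option routes predictions through the enumeration branch above; the `model`,
# `posterior_samples`, and `data` names below are placeholders for illustration).
def _predictive_with_discrete_sites_example(model, posterior_samples, data):
    from jax import random
    from numpyro.infer import Predictive

    predictive = Predictive(model, posterior_samples, infer_discrete=True)
    # discrete latent sites are drawn from their conditional posterior (temperature=1)
    return predictive(random.PRNGKey(2), data)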
def init(self, rng_key, *args, **kwargs):
    """
    :param jax.random.PRNGKey rng_key: random number generator seed.
    :param args: arguments to the model / guide (these can possibly vary during
        the course of fitting).
    :param kwargs: keyword arguments to the model / guide (these can possibly vary
        during the course of fitting).
    :return: initial :data:`SteinVIState`
    """
    rng_key, kernel_seed, model_seed, guide_seed = jax.random.split(rng_key, 4)
    model_init = handlers.seed(self.model, model_seed)
    guide_init = handlers.seed(self.guide, guide_seed)
    guide_trace = handlers.trace(guide_init).get_trace(*args, **kwargs, **self.static_kwargs)
    model_trace = handlers.trace(model_init).get_trace(*args, **kwargs, **self.static_kwargs)
    rng_key, particle_seed = jax.random.split(rng_key)
    guide_init_params = self._find_init_params(particle_seed, self.guide, guide_trace)
    params = {}
    transforms = {}
    inv_transforms = {}
    particle_transforms = {}
    guide_param_names = set()
    should_enum = False
    for site in model_trace.values():
        if (
            "fn" in site
            and site["type"] == "sample"
            and not site["is_observed"]
            and isinstance(site["fn"], Distribution)
            and site["fn"].is_discrete
        ):
            if site["fn"].has_enumerate_support and self.enum:
                should_enum = True
            else:
                raise Exception(
                    "Cannot enumerate model with discrete variables without enumerate support"
                )
    # NB: params in model_trace will be overwritten by params in guide_trace
    for site in chain(model_trace.values(), guide_trace.values()):
        if site["type"] == "param":
            transform = get_parameter_transform(site)
            inv_transforms[site["name"]] = transform
            transforms[site["name"]] = transform.inv
            particle_transforms[site["name"]] = site.get("particle_transform", IdentityTransform())
            if site["name"] in guide_init_params:
                pval, _ = guide_init_params[site["name"]]
                if self.classic_guide_params_fn(site["name"]):
                    pval = tree_map(lambda x: x[0], pval)
            else:
                pval = site["value"]
            params[site["name"]] = transform.inv(pval)
            if site["name"] in guide_trace:
                guide_param_names.add(site["name"])

    if should_enum:
        mpn = _guess_max_plate_nesting(model_trace)
        self._inference_model = enum(config_enumerate(self.model), -mpn - 1)
    self.guide_param_names = guide_param_names
    self.constrain_fn = partial(transform_fn, inv_transforms)
    self.uconstrain_fn = partial(transform_fn, transforms)
    self.particle_transforms = particle_transforms
    self.particle_transform_fn = partial(transform_fn, particle_transforms)
    stein_particles, _, _ = batch_ravel_pytree(
        {
            k: params[k]
            for k, site in guide_trace.items()
            if site["type"] == "param" and site["name"] in guide_init_params
        },
        nbatch_dims=1,
    )
    self.kernel_fn.init(kernel_seed, stein_particles.shape)
    return SteinVIState(self.optim.init(params), rng_key)
def initialize_model(rng_key, model,
                     init_strategy=init_to_uniform,
                     dynamic_args=False,
                     model_args=(),
                     model_kwargs=None):
    """
    (EXPERIMENTAL INTERFACE) Helper function that calls :func:`~numpyro.infer.util.get_potential_fn`
    and :func:`~numpyro.infer.util.find_valid_initial_params` under the hood
    to return a tuple of (`init_params_info`, `potential_fn`, `postprocess_fn`, `model_trace`).

    :param jax.random.PRNGKey rng_key: random number generator seed to
        sample from the prior. The returned `init_params` will have the
        batch shape ``rng_key.shape[:-1]``.
    :param model: Python callable containing Pyro primitives.
    :param callable init_strategy: a per-site initialization function.
        See :ref:`init_strategy` section for available functions.
    :param bool dynamic_args: if `True`, the `potential_fn` and
        `constraints_fn` are themselves dependent on model arguments.
        When provided a `*model_args, **model_kwargs`, they return
        `potential_fn` and `constraints_fn` callables, respectively.
    :param tuple model_args: args provided to the model.
    :param dict model_kwargs: kwargs provided to the model.
    :return: a namedtuple `ModelInfo` which contains the fields
        (`param_info`, `potential_fn`, `postprocess_fn`, `model_trace`), where
        `param_info` is a namedtuple `ParamInfo` containing values from the prior
        used to initiate MCMC, their corresponding potential energy, and their gradients;
        `postprocess_fn` is a callable that uses inverse transforms
        to convert unconstrained HMC samples to constrained values that
        lie within the site's support, in addition to returning values at
        `deterministic` sites in the model.
    """
    model_kwargs = {} if model_kwargs is None else model_kwargs
    substituted_model = substitute(
        seed(model, rng_key if jnp.ndim(rng_key) == 1 else rng_key[0]),
        substitute_fn=init_strategy)
    inv_transforms, replay_model, has_enumerate_support, model_trace = _get_model_transforms(
        substituted_model, model_args, model_kwargs)
    constrained_values = {k: v['value'] for k, v in model_trace.items()
                          if v['type'] == 'sample' and not v['is_observed']
                          and not v['fn'].is_discrete}

    if has_enumerate_support:
        from numpyro.contrib.funsor import config_enumerate, enum

        if not isinstance(model, enum):
            max_plate_nesting = _guess_max_plate_nesting(model_trace)
            model = enum(config_enumerate(model), -max_plate_nesting - 1)

    potential_fn, postprocess_fn = get_potential_fn(model,
                                                    inv_transforms,
                                                    replay_model=replay_model,
                                                    enum=has_enumerate_support,
                                                    dynamic_args=dynamic_args,
                                                    model_args=model_args,
                                                    model_kwargs=model_kwargs)

    init_strategy = init_strategy if isinstance(init_strategy, partial) else init_strategy()
    if (init_strategy.func is init_to_value) and not replay_model:
        init_values = init_strategy.keywords.get("values")
        unconstrained_values = transform_fn(inv_transforms, init_values, invert=True)
        init_strategy = _init_to_unconstrained_value(values=unconstrained_values)
    prototype_params = transform_fn(inv_transforms, constrained_values, invert=True)
    (init_params, pe, grad), is_valid = find_valid_initial_params(
        rng_key, model,
        init_strategy=init_strategy,
        enum=has_enumerate_support,
        model_args=model_args,
        model_kwargs=model_kwargs,
        prototype_params=prototype_params)

    if not_jax_tracer(is_valid):
        if device_get(~jnp.all(is_valid)):
            raise RuntimeError(
                "Cannot find valid initial parameters. Please check your model again."
            )
    return ModelInfo(ParamInfo(init_params, pe, grad), potential_fn, postprocess_fn, model_trace)
def initialize_model(
    rng_key,
    model,
    *,
    init_strategy=init_to_uniform,
    dynamic_args=False,
    model_args=(),
    model_kwargs=None,
    forward_mode_differentiation=False,
    validate_grad=True,
):
    """
    (EXPERIMENTAL INTERFACE) Helper function that calls :func:`~numpyro.infer.util.get_potential_fn`
    and :func:`~numpyro.infer.util.find_valid_initial_params` under the hood
    to return a tuple of (`init_params_info`, `potential_fn`, `postprocess_fn`, `model_trace`).

    :param jax.random.PRNGKey rng_key: random number generator seed to
        sample from the prior. The returned `init_params` will have the
        batch shape ``rng_key.shape[:-1]``.
    :param model: Python callable containing Pyro primitives.
    :param callable init_strategy: a per-site initialization function.
        See :ref:`init_strategy` section for available functions.
    :param bool dynamic_args: if `True`, the `potential_fn` and
        `constraints_fn` are themselves dependent on model arguments.
        When provided a `*model_args, **model_kwargs`, they return
        `potential_fn` and `constraints_fn` callables, respectively.
    :param tuple model_args: args provided to the model.
    :param dict model_kwargs: kwargs provided to the model.
    :param bool forward_mode_differentiation: whether to use forward-mode differentiation
        or reverse-mode differentiation. By default, we use reverse mode, but forward
        mode can be useful in some cases to improve performance. In addition, some
        control flow utilities on JAX such as `jax.lax.while_loop` or `jax.lax.fori_loop`
        only support forward-mode differentiation. See
        `JAX's The Autodiff Cookbook <https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html>`_
        for more information.
    :param bool validate_grad: whether to validate the gradient of the initial params.
        Defaults to True.
    :return: a namedtuple `ModelInfo` which contains the fields
        (`param_info`, `potential_fn`, `postprocess_fn`, `model_trace`), where
        `param_info` is a namedtuple `ParamInfo` containing values from the prior
        used to initiate MCMC, their corresponding potential energy, and their gradients;
        `postprocess_fn` is a callable that uses inverse transforms
        to convert unconstrained HMC samples to constrained values that
        lie within the site's support, in addition to returning values at
        `deterministic` sites in the model.
    """
    model_kwargs = {} if model_kwargs is None else model_kwargs
    substituted_model = substitute(
        seed(model, rng_key if jnp.ndim(rng_key) == 1 else rng_key[0]),
        substitute_fn=init_strategy,
    )
    (
        inv_transforms,
        replay_model,
        has_enumerate_support,
        model_trace,
    ) = _get_model_transforms(substituted_model, model_args, model_kwargs)
    # substitute param sites from model_trace to model so
    # we don't need to generate again parameters of `numpyro.module`
    model = substitute(
        model,
        data={
            k: site["value"]
            for k, site in model_trace.items()
            if site["type"] in ["param"]
        },
    )
    constrained_values = {
        k: v["value"]
        for k, v in model_trace.items()
        if v["type"] == "sample"
        and not v["is_observed"]
        and not v["fn"].is_discrete
    }

    if has_enumerate_support:
        from numpyro.contrib.funsor import config_enumerate, enum

        if not isinstance(model, enum):
            max_plate_nesting = _guess_max_plate_nesting(model_trace)
            _validate_model(model_trace)
            model = enum(config_enumerate(model), -max_plate_nesting - 1)

    potential_fn, postprocess_fn = get_potential_fn(
        model,
        inv_transforms,
        replay_model=replay_model,
        enum=has_enumerate_support,
        dynamic_args=dynamic_args,
        model_args=model_args,
        model_kwargs=model_kwargs,
    )

    init_strategy = (
        init_strategy if isinstance(init_strategy, partial) else init_strategy()
    )
    if (init_strategy.func is init_to_value) and not replay_model:
        init_values = init_strategy.keywords.get("values")
        unconstrained_values = transform_fn(inv_transforms, init_values, invert=True)
        init_strategy = _init_to_unconstrained_value(values=unconstrained_values)
    prototype_params = transform_fn(inv_transforms, constrained_values, invert=True)
    (init_params, pe, grad), is_valid = find_valid_initial_params(
        rng_key,
        substitute(
            model,
            data={
                k: site["value"]
                for k, site in model_trace.items()
                if site["type"] in ["plate"]
            },
        ),
        init_strategy=init_strategy,
        enum=has_enumerate_support,
        model_args=model_args,
        model_kwargs=model_kwargs,
        prototype_params=prototype_params,
        forward_mode_differentiation=forward_mode_differentiation,
        validate_grad=validate_grad,
    )

    if not_jax_tracer(is_valid):
        if device_get(~jnp.all(is_valid)):
            with numpyro.validation_enabled(), trace() as tr:
                # validate parameters
                substituted_model(*model_args, **model_kwargs)
                # validate values
                for site in tr.values():
                    if site["type"] == "sample":
                        with warnings.catch_warnings(record=True) as ws:
                            site["fn"]._validate_sample(site["value"])
                        if len(ws) > 0:
                            for w in ws:
                                # add site information to the warning message
                                w.message.args = (
                                    "Site {}: {}".format(site["name"], w.message.args[0]),
                                ) + w.message.args[1:]
                                warnings.showwarning(
                                    w.message,
                                    w.category,
                                    w.filename,
                                    w.lineno,
                                    file=w.file,
                                    line=w.line,
                                )
            raise RuntimeError(
                "Cannot find valid initial parameters. Please check your model again."
            )
    return ModelInfo(
        ParamInfo(init_params, pe, grad), potential_fn, postprocess_fn, model_trace
    )