def test_layers(self):
    """Sampling a Binomial whose `p` is itself a Uniform RV yields draws with mean ~0.5."""
    with pm.Model(rng_seeder=232093) as model:
        a = pm.Uniform("a", lower=0, upper=1, size=10)
        b = pm.Binomial("b", n=1, p=a, size=10)
        # Compile a forward sampler for `b` and average many draws;
        # marginally, E[b] = E[p] = 0.5 for every element.
        b_sampler = compile_rv_inplace([], b, mode="FAST_RUN")
        draws = np.stack([b_sampler() for _ in range(10000)])
        empirical_mean = draws.mean(0)
        npt.assert_array_almost_equal(empirical_mean, 0.5 * np.ones((10,)), decimal=2)
def delta_logp(point, logp, vars, shared):
    """Compile a function computing the log-probability difference between two points.

    The free variables in ``vars`` are joined into a single flat input vector;
    the returned Aesara function takes two such vectors (proposal, current) and
    returns ``logp(proposal) - logp(current)``.
    """
    [flat_logp], current_input = pm.join_nonshared_inputs(point, [logp], vars, shared)

    # Build a second input of the same flat tensor type for the proposal point
    # and re-evaluate the joined logp graph on it.
    proposal_input = current_input.type("inarray1")
    proposal_logp = pm.CallableTensor(flat_logp)(proposal_input)

    f = compile_rv_inplace([proposal_input, current_input], proposal_logp - flat_logp)
    # Inputs are produced internally, so skip Aesara's input validation.
    f.trust_input = True
    return f
def _logp_forw(point, out_vars, vars, shared):
    """Compile an Aesara function of the model and the input and output variables.

    Parameters
    ----------
    out_vars: List
        Containing :class:`pymc.Distribution` for the output variables
    vars: List
        Containing :class:`pymc.Distribution` for the input variables
    shared: List
        Containing :class:`aesara.tensor.Tensor` for depended shared data
    """
    # Discrete inputs are swapped for float stand-ins that get rounded inside
    # the graph, so the compiled function accepts (continuous) float inputs.
    if any(var.dtype in discrete_types for var in vars):
        int_to_float = {}
        float_to_round = {}
        resolved_vars = []

        for var in vars:
            if var.dtype not in discrete_types:
                resolved_vars.append(var)
                continue
            float_proxy = at.TensorType("floatX", var.broadcastable)(var.name)
            int_to_float[var] = float_proxy
            resolved_vars.append(float_proxy)
            rounded = at.round(float_proxy)
            rounded.name = var.name
            float_to_round[float_proxy] = rounded

        int_to_float.update(shared)
        float_to_round.update(shared)
        # First substitute floats for the discrete inputs, then wrap those
        # floats in a rounding op — two passes keep the replacements disjoint.
        out_vars = clone_replace(out_vars, int_to_float, strict=False)
        out_vars = clone_replace(out_vars, float_to_round)
        vars = resolved_vars

    out_list, flat_input = join_nonshared_inputs(point, out_vars, vars, shared)
    f = compile_rv_inplace([flat_input], out_list[0])
    # Inputs come from trusted internal callers; skip validation for speed.
    f.trust_input = True
    return f
def make_initial_point_fn(
    *,
    model,
    overrides: Optional[StartDict] = None,
    jitter_rvs: Optional[Set[TensorVariable]] = None,
    default_strategy: str = "prior",
    return_transformed: bool = True,
) -> Callable:
    """Create seeded function that computes initial values for all free model variables.

    Parameters
    ----------
    jitter_rvs : set
        The set (or list or tuple) of random variables for which a U(-1, +1) jitter should be
        added to the initial value. Only available for variables that have a transform or real-valued support.
    default_strategy : str
        Which of { "moment", "prior" } to prefer if the initval setting for an RV is None.
    overrides : dict
        Initial value (strategies) to use instead of what's specified in `Model.initial_values`.
    return_transformed : bool
        If `True` the returned variables will correspond to transformed initial values.
    """

    def find_rng_nodes(variables):
        # All RNG shared variables (RandomState or Generator flavored) that
        # feed into the given output variables.
        rng_classes = (
            at.random.var.RandomStateSharedVariable,
            at.random.var.RandomGeneratorSharedVariable,
        )
        return [node for node in graph_inputs(variables) if isinstance(node, rng_classes)]

    overrides = convert_str_to_rv_dict(model, overrides or {})

    initial_values = make_initial_point_expression(
        free_rvs=model.free_RVs,
        rvs_to_values=model.rvs_to_values,
        # Explicit overrides take precedence over the model's own settings.
        initval_strategies={**model.initial_values, **overrides},
        jitter_rvs=jitter_rvs,
        default_strategy=default_strategy,
        return_transformed=return_transformed,
    )

    # Replace original rng shared variables so that we don't mess with them
    # when calling the final seeded function
    graph = FunctionGraph(outputs=initial_values, clone=False)
    old_rngs = find_rng_nodes(graph.outputs)
    fresh_rngs = []
    for old_rng in old_rngs:
        if isinstance(old_rng, at.random.var.RandomStateSharedVariable):
            state = np.random.RandomState(np.random.PCG64())
        else:
            state = np.random.Generator(np.random.PCG64())
        fresh_rngs.append(aesara.shared(state))
    graph.replace_all(zip(old_rngs, fresh_rngs), import_missing=True)
    func = compile_rv_inplace(
        inputs=[], outputs=graph.outputs, mode=aesara.compile.mode.FAST_COMPILE
    )

    # Output names mirror the (possibly transformed) free RVs, in order.
    varnames = []
    for var in model.free_RVs:
        transform = getattr(model.rvs_to_values[var].tag, "transform", None)
        if transform is not None and return_transformed:
            varnames.append(get_transformed_name(var.name, transform))
        else:
            varnames.append(var.name)

    def make_seeded_function(func):
        # Wrap the compiled function so every call reseeds its RNGs from the
        # given seed, making the initial point reproducible.
        rngs = find_rng_nodes(func.maker.fgraph.outputs)

        @functools.wraps(func)
        def inner(seed, *args, **kwargs):
            bit_generators = [
                np.random.PCG64(sub_seed)
                for sub_seed in np.random.SeedSequence(seed).spawn(len(rngs))
            ]
            for rng, bit_generator in zip(rngs, bit_generators):
                if isinstance(rng, at.random.var.RandomStateSharedVariable):
                    new_rng = np.random.RandomState(bit_generator)
                else:
                    new_rng = np.random.Generator(bit_generator)
                rng.set_value(new_rng, True)
            values = func(*args, **kwargs)
            return dict(zip(varnames, values))

        return inner

    return make_seeded_function(func)
def sample_numpyro_nuts(
    draws=1000,
    tune=1000,
    chains=4,
    target_accept=0.8,
    random_seed=10,
    model=None,
    progress_bar=True,
    keep_untransformed=False,
):
    """Sample a PyMC model with NumPyro's NUTS kernel and return an ArviZ trace.

    The model's logp is JAX-compiled, chains are run in parallel via pmap, and
    raw draws are mapped back through each value variable's backward transform.
    """
    from numpyro.infer import MCMC, NUTS

    model = modelcontext(model)

    tic1 = pd.Timestamp.now()
    print("Compiling...", file=sys.stdout)

    # One batched initial state per chain, stacked along a new leading axis.
    rv_names = [rv.name for rv in model.value_vars]
    init_state = [model.initial_point[rv_name] for rv_name in rv_names]
    init_state_batched = jax.tree_map(
        lambda x: np.repeat(x[None, ...], chains, axis=0), init_state
    )

    logp_fn = get_jaxified_logp(model)

    nuts_kernel = NUTS(
        potential_fn=logp_fn,
        target_accept_prob=target_accept,
        adapt_step_size=True,
        adapt_mass_matrix=True,
        dense_mass=False,
    )

    pmap_numpyro = MCMC(
        nuts_kernel,
        num_warmup=tune,
        num_samples=draws,
        num_chains=chains,
        postprocess_fn=None,
        chain_method="parallel",
        progress_bar=progress_bar,
    )

    tic2 = pd.Timestamp.now()
    print("Compilation time = ", tic2 - tic1, file=sys.stdout)

    print("Sampling...", file=sys.stdout)

    seed = jax.random.PRNGKey(random_seed)
    map_seed = jax.random.split(seed, chains)

    pmap_numpyro.run(
        map_seed, init_params=init_state_batched, extra_fields=("num_steps",)
    )
    raw_mcmc_samples = pmap_numpyro.get_samples(group_by_chain=True)

    tic3 = pd.Timestamp.now()
    print("Sampling time = ", tic3 - tic2, file=sys.stdout)

    print("Transforming variables...", file=sys.stdout)
    mcmc_samples = []
    for value_var, raw_samples in zip(model.value_vars, raw_mcmc_samples):
        raw_samples = at.constant(np.asarray(raw_samples))
        rv = model.values_to_rvs[value_var]
        transform = getattr(value_var.tag, "transform", None)

        if transform is None:
            raw_samples.name = rv.name
            mcmc_samples.append(raw_samples)
        else:
            # TODO: This will fail when the transformation depends on another variable
            # such as in interval transform with RVs as edges
            trans_samples = transform.backward(raw_samples, *rv.owner.inputs)
            trans_samples.name = rv.name
            mcmc_samples.append(trans_samples)

            if keep_untransformed:
                raw_samples.name = value_var.name
                mcmc_samples.append(raw_samples)

    mcmc_varnames = [var.name for var in mcmc_samples]
    # Evaluate all back-transformations in one JAX-compiled pass.
    mcmc_samples = compile_rv_inplace(
        [],
        mcmc_samples,
        mode="JAX",
    )()

    tic4 = pd.Timestamp.now()
    print("Transformation time = ", tic4 - tic3, file=sys.stdout)

    posterior = dict(zip(mcmc_varnames, mcmc_samples))
    az_trace = az.from_dict(posterior=posterior)
    return az_trace