def get_jaxified_logp(model: Model) -> Callable:
    """Compile model.logpt into an optimized JAX function."""
    logpt = replace_shared_variables([model.logpt])[0]

    logpt_fgraph = FunctionGraph(outputs=[logpt], clone=False)
    optimize_graph(logpt_fgraph, include=["fast_run"], exclude=["cxx_only", "BlasOpt"])

    # We now jaxify the optimized fgraph
    logp_fn = jax_funcify(logpt_fgraph)

    if isinstance(logp_fn, (list, tuple)):
        # This handles the new JAX backend, which always returns a tuple
        logp_fn = logp_fn[0]

    def logp_fn_wrap(x):
        res = logp_fn(*x)

        if isinstance(res, (list, tuple)):
            # This handles the new JAX backend, which always returns a tuple
            res = res[0]

        # NumPyro's potential_fn convention expects the opposite sign of model.logpt
        return -res

    return logp_fn_wrap
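
# Usage sketch for get_jaxified_logp (illustrative only; the toy model below
# is an assumption, not part of this module):
#
#     import pymc as pm
#
#     with pm.Model() as model:
#         pm.Normal("x", 0.0, 1.0)
#
#     logp_fn = get_jaxified_logp(model)
#     # The wrapper takes a sequence of value-variable arrays and returns the
#     # *negative* joint log-probability, matching NumPyro's potential_fn
#     # convention:
#     point = [model.initial_point[rv.name] for rv in model.value_vars]
#     print(logp_fn(point))
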
def _get_log_likelihood(model, samples):
    """Compute log-likelihood for all observations."""
    data = {}
    for v in model.observed_RVs:
        logp_v = replace_shared_variables([logpt(v)])
        fgraph = FunctionGraph(model.value_vars, logp_v, clone=False)
        optimize_graph(fgraph, include=["fast_run"], exclude=["cxx_only", "BlasOpt"])
        jax_fn = jax_funcify(fgraph)
        # vmap twice to map over the leading (chain, draw) dimensions
        result = jax.jit(jax.vmap(jax.vmap(jax_fn)))(*samples)[0]
        data[v.name] = result
    return data
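
# Shape sketch (inferred from the double jax.vmap above, stated as an
# assumption): `samples` holds one array per entry of model.value_vars, each
# with leading (chain, draw) dimensions, e.g. the output of
# MCMC.get_samples(group_by_chain=True) used in sample_numpyro_nuts below:
#
#     raw_mcmc_samples = pmap_numpyro.get_samples(group_by_chain=True)
#     log_lik = _get_log_likelihood(model, raw_mcmc_samples)
#     log_lik["y"].shape  # (chains, draws, *shape_of_observed_y)
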
def _get_ar_order(cls, rhos: TensorVariable, ar_order: Optional[int], constant: bool) -> int:
    """Compute ar_order given inputs

    If ar_order is not specified we do constant folding on the shape of rhos
    to retrieve it. For example, this will detect that
    Normal(size=(5, 3)).shape[-1] == 3, which is not otherwise known by Aesara.

    Raises
    ------
    ValueError
        If ar_order cannot be inferred from rhos or if it is less than 1
    """
    if ar_order is None:
        shape_fg = FunctionGraph(
            outputs=[rhos.shape[-1]],
            features=[ShapeFeature()],
            clone=True,
        )
        (folded_shape,) = optimize_graph(shape_fg, custom_opt=topo_constant_folding).outputs
        folded_shape = getattr(folded_shape, "data", None)
        if folded_shape is None:
            raise ValueError(
                "Could not infer ar_order from last dimension of rho. Pass it "
                "explicitly or make sure rho has a static shape"
            )
        ar_order = int(folded_shape) - int(constant)
        if ar_order < 1:
            raise ValueError(
                "Inferred ar_order is smaller than 1. Increase the last dimension "
                "of rho or remove constant_term"
            )

    return ar_order
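
# Worked sketch of the constant-folding branch (calling it through a
# hypothetical AR class, since _get_ar_order takes cls as a classmethod):
#
#     import pymc as pm
#
#     rhos = pm.Normal.dist(size=(5, 3))
#     # rhos.shape[-1] constant-folds to 3, so:
#     AR._get_ar_order(rhos=rhos, ar_order=None, constant=False)  # -> 3
#     # with a constant term, one rho is the intercept:
#     AR._get_ar_order(rhos=rhos, ar_order=None, constant=True)   # -> 2
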
def sample_numpyro_nuts(
    draws=1000,
    tune=1000,
    chains=4,
    target_accept=0.8,
    random_seed=10,
    model=None,
    var_names=None,
    progress_bar=True,
    keep_untransformed=False,
    chain_method="parallel",
):
    from numpyro.infer import MCMC, NUTS

    model = modelcontext(model)

    if var_names is None:
        var_names = model.unobserved_value_vars

    vars_to_sample = list(get_default_varnames(var_names, include_transformed=keep_untransformed))

    coords = {
        cname: np.array(cvals) if isinstance(cvals, tuple) else cvals
        for cname, cvals in model.coords.items()
        if cvals is not None
    }

    if hasattr(model, "RV_dims"):
        dims = {
            var_name: [dim for dim in dims if dim is not None]
            for var_name, dims in model.RV_dims.items()
        }
    else:
        dims = {}

    tic1 = pd.Timestamp.now()
    print("Compiling...", file=sys.stdout)

    rv_names = [rv.name for rv in model.value_vars]
    init_state = [model.initial_point[rv_name] for rv_name in rv_names]
    init_state_batched = jax.tree_map(
        lambda x: np.repeat(x[None, ...], chains, axis=0), init_state
    )

    logp_fn = get_jaxified_logp(model)

    nuts_kernel = NUTS(
        potential_fn=logp_fn,
        target_accept_prob=target_accept,
        adapt_step_size=True,
        adapt_mass_matrix=True,
        dense_mass=False,
    )

    pmap_numpyro = MCMC(
        nuts_kernel,
        num_warmup=tune,
        num_samples=draws,
        num_chains=chains,
        postprocess_fn=None,
        chain_method=chain_method,
        progress_bar=progress_bar,
    )

    tic2 = pd.Timestamp.now()
    print("Compilation time = ", tic2 - tic1, file=sys.stdout)

    print("Sampling...", file=sys.stdout)

    seed = jax.random.PRNGKey(random_seed)
    map_seed = jax.random.split(seed, chains)

    if chains == 1:
        # With a single chain, NumPyro expects an unbatched rng key and
        # an unbatched initial point
        init_params = init_state
        map_seed = seed
    else:
        init_params = init_state_batched

    pmap_numpyro.run(
        map_seed,
        init_params=init_params,
        extra_fields=(
            "num_steps",
            "potential_energy",
            "energy",
            "adapt_state.step_size",
            "accept_prob",
            "diverging",
        ),
    )

    raw_mcmc_samples = pmap_numpyro.get_samples(group_by_chain=True)

    tic3 = pd.Timestamp.now()
    print("Sampling time = ", tic3 - tic2, file=sys.stdout)

    print("Transforming variables...", file=sys.stdout)
    mcmc_samples = {}
    for v in vars_to_sample:
        # Compile each deterministic transformation to JAX and map it over
        # the (chain, draw) dimensions of the raw samples
        fgraph = FunctionGraph(model.value_vars, [v], clone=False)
        optimize_graph(fgraph, include=["fast_run"], exclude=["cxx_only", "BlasOpt"])
        jax_fn = jax_funcify(fgraph)
        result = jax.vmap(jax.vmap(jax_fn))(*raw_mcmc_samples)[0]
        mcmc_samples[v.name] = result

    tic4 = pd.Timestamp.now()
    print("Transformation time = ", tic4 - tic3, file=sys.stdout)

    posterior = mcmc_samples
    az_trace = az.from_dict(
        posterior=posterior,
        log_likelihood=_get_log_likelihood(model, raw_mcmc_samples),
        observed_data=find_observations(model),
        sample_stats=_sample_stats_to_xarray(pmap_numpyro),
        coords=coords,
        dims=dims,
    )

    return az_trace
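
# End-to-end usage sketch (the toy model and data are illustrative only):
#
#     import pymc as pm
#
#     with pm.Model() as model:
#         mu = pm.Normal("mu", 0.0, 1.0)
#         pm.Normal("y", mu, 1.0, observed=[0.1, -0.3, 0.2])
#         idata = sample_numpyro_nuts(draws=500, tune=500, chains=2)
#
# The return value is an arviz.InferenceData assembled with az.from_dict,
# containing posterior, sample_stats, log_likelihood, and observed_data
# groups, with coords and dims taken from the model.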