def get_jaxified_graph(
    inputs: Optional[List[TensorVariable]] = None,
    outputs: Optional[List[TensorVariable]] = None,
) -> Callable:
    """Compile an Aesara graph into an optimized JAX function"""

    graph = _replace_shared_variables(outputs)

    fgraph = FunctionGraph(inputs=inputs, outputs=graph, clone=True)
    # We need to add a Supervisor to the fgraph to be able to run the
    # JAX sequential optimizer without warnings. We made sure there
    # are no mutable input variables, so we only need to check for
    # "destroyers". This should be automatically handled by Aesara
    # once https://github.com/aesara-devs/aesara/issues/637 is fixed.
    fgraph.attach_feature(
        Supervisor(
            input
            for input in fgraph.inputs
            if not (hasattr(fgraph, "destroyers") and fgraph.has_destroyers([input]))
        )
    )
    mode.JAX.optimizer.optimize(fgraph)

    # We now jaxify the optimized fgraph
    return jax_funcify(fgraph)
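# Hedged usage sketch (not part of the original module): `get_jaxified_graph`
# compiles an Aesara graph from the model's value variables to derived
# outputs. The helper name `_demo_jaxified_graph` is hypothetical and only
# illustrates the calling convention assumed by the functions below.
def _demo_jaxified_graph(model, outputs):
    """Evaluate `outputs` at the model's initial point via the jaxified graph."""
    jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=outputs)
    point = [model.initial_point[rv.name] for rv in model.value_vars]
    # The compiled function takes one positional argument per input variable
    # and returns one JAX array per requested output.
    return jax_fn(*point)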
def get_jaxified_logp(model: Model) -> Callable:
    """Compile model.logpt into an optimized jax function"""

    logpt = replace_shared_variables([model.logpt])[0]

    logpt_fgraph = FunctionGraph(outputs=[logpt], clone=False)
    optimize_graph(logpt_fgraph, include=["fast_run"], exclude=["cxx_only", "BlasOpt"])

    # We now jaxify the optimized fgraph
    logp_fn = jax_funcify(logpt_fgraph)

    if isinstance(logp_fn, (list, tuple)):
        # This handles the new JAX backend, which always returns a tuple
        logp_fn = logp_fn[0]

    def logp_fn_wrap(x):
        res = logp_fn(*x)

        if isinstance(res, (list, tuple)):
            # This handles the new JAX backend, which always returns a tuple
            res = res[0]

        # Jax expects a potential with the opposite sign of model.logpt
        return -res

    return logp_fn_wrap
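# Hedged usage sketch (illustrative only): the wrapper returned by
# `get_jaxified_logp` takes a single sequence of value-variable arrays,
# ordered like `model.value_vars`, and returns the *negative* of
# `model.logpt`, which is the sign convention NumPyro's `potential_fn`
# expects. `_demo_jaxified_logp` is a hypothetical helper, not original code.
def _demo_jaxified_logp(model):
    """Evaluate the jaxified negative log-probability at the model's initial point."""
    neg_logp = get_jaxified_logp(model)
    point = [model.initial_point[rv.name] for rv in model.value_vars]
    return neg_logp(point)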
def _get_log_likelihood(model, samples):
    """Compute log-likelihood for all observations"""
    data = {}
    for v in model.observed_RVs:
        logp_v = replace_shared_variables([logpt(v)])
        fgraph = FunctionGraph(model.value_vars, logp_v, clone=False)
        optimize_graph(fgraph, include=["fast_run"], exclude=["cxx_only", "BlasOpt"])
        jax_fn = jax_funcify(fgraph)
        result = jax.jit(jax.vmap(jax.vmap(jax_fn)))(*samples)[0]
        data[v.name] = result
    return data
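# Hedged sketch (not in the original module): the dictionary returned by
# `_get_log_likelihood` maps each observed RV's name to an array of pointwise
# log-likelihoods whose leading dimensions are (chains, draws), matching the
# layout `az.from_dict(log_likelihood=...)` expects. The checking helper and
# its name are hypothetical.
def _check_log_likelihood_shapes(model, samples, chains, draws):
    """Assert every pointwise log-likelihood array has a (chains, draws, ...) shape."""
    log_like = _get_log_likelihood(model, samples)
    for name, values in log_like.items():
        assert values.shape[:2] == (chains, draws), name
    return log_like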
def test_jax_FunctionGraph_once():
    """Make sure that an output is only computed once when it's referenced multiple times."""
    from aesara.link.jax.dispatch import jax_funcify

    x = vector("x")
    y = vector("y")

    class TestOp(Op):
        def __init__(self):
            self.called = 0

        def make_node(self, *args):
            return Apply(self, list(args), [x.type() for x in args])

        def perform(self, inputs, outputs):
            for i, inp in enumerate(inputs):
                outputs[i][0] = inp[0]

    @jax_funcify.register(TestOp)
    def jax_funcify_TestOp(op, **kwargs):
        def func(*args, op=op):
            op.called += 1
            return list(args)

        return func

    op1 = TestOp()
    op2 = TestOp()

    q, r = op1(x, y)
    outs = op2(q + r, q + r)

    out_fg = FunctionGraph([x, y], outs, clone=False)
    assert len(out_fg.outputs) == 2

    out_jx = jax_funcify(out_fg)

    x_val = np.r_[1, 2].astype(config.floatX)
    y_val = np.r_[2, 3].astype(config.floatX)

    res = out_jx(x_val, y_val)
    assert len(res) == 2
    assert op1.called == 1
    assert op2.called == 1

    res = out_jx(x_val, y_val)
    assert len(res) == 2
    assert op1.called == 2
    assert op2.called == 2
def test_jax_FunctionGraph_names():
    import inspect

    from aesara.link.jax.dispatch import jax_funcify

    x = scalar("1x")
    y = scalar("_")
    z = scalar()
    q = scalar("def")

    out_fg = FunctionGraph([x, y, z, q], [x, y, z, q], clone=False)
    out_jx = jax_funcify(out_fg)
    sig = inspect.signature(out_jx)
    assert (x.auto_name, "_", z.auto_name, q.auto_name) == tuple(sig.parameters.keys())
    assert (1, 2, 3, 4) == out_jx(1, 2, 3, 4)
def fgraph_convert(self, fgraph, **kwargs):
    from aesara.link.jax.dispatch import jax_funcify

    return jax_funcify(fgraph, **kwargs)
def sample_numpyro_nuts(
    draws=1000,
    tune=1000,
    chains=4,
    target_accept=0.8,
    random_seed=10,
    model=None,
    var_names=None,
    progress_bar=True,
    keep_untransformed=False,
    chain_method="parallel",
):
    from numpyro.infer import MCMC, NUTS

    model = modelcontext(model)

    if var_names is None:
        var_names = model.unobserved_value_vars

    vars_to_sample = list(get_default_varnames(var_names, include_transformed=keep_untransformed))

    coords = {
        cname: np.array(cvals) if isinstance(cvals, tuple) else cvals
        for cname, cvals in model.coords.items()
        if cvals is not None
    }

    if hasattr(model, "RV_dims"):
        dims = {
            var_name: [dim for dim in dims if dim is not None]
            for var_name, dims in model.RV_dims.items()
        }
    else:
        dims = {}

    tic1 = pd.Timestamp.now()
    print("Compiling...", file=sys.stdout)

    rv_names = [rv.name for rv in model.value_vars]
    init_state = [model.initial_point[rv_name] for rv_name in rv_names]
    init_state_batched = jax.tree_map(lambda x: np.repeat(x[None, ...], chains, axis=0), init_state)

    logp_fn = get_jaxified_logp(model)

    nuts_kernel = NUTS(
        potential_fn=logp_fn,
        target_accept_prob=target_accept,
        adapt_step_size=True,
        adapt_mass_matrix=True,
        dense_mass=False,
    )

    pmap_numpyro = MCMC(
        nuts_kernel,
        num_warmup=tune,
        num_samples=draws,
        num_chains=chains,
        postprocess_fn=None,
        chain_method=chain_method,
        progress_bar=progress_bar,
    )

    tic2 = pd.Timestamp.now()
    print("Compilation time = ", tic2 - tic1, file=sys.stdout)

    print("Sampling...", file=sys.stdout)

    seed = jax.random.PRNGKey(random_seed)
    map_seed = jax.random.split(seed, chains)

    if chains == 1:
        init_params = init_state
        map_seed = seed
    else:
        init_params = init_state_batched

    pmap_numpyro.run(
        map_seed,
        init_params=init_params,
        extra_fields=(
            "num_steps",
            "potential_energy",
            "energy",
            "adapt_state.step_size",
            "accept_prob",
            "diverging",
        ),
    )

    raw_mcmc_samples = pmap_numpyro.get_samples(group_by_chain=True)

    tic3 = pd.Timestamp.now()
    print("Sampling time = ", tic3 - tic2, file=sys.stdout)

    print("Transforming variables...", file=sys.stdout)
    mcmc_samples = {}
    for v in vars_to_sample:
        fgraph = FunctionGraph(model.value_vars, [v], clone=False)
        optimize_graph(fgraph, include=["fast_run"], exclude=["cxx_only", "BlasOpt"])
        jax_fn = jax_funcify(fgraph)
        result = jax.vmap(jax.vmap(jax_fn))(*raw_mcmc_samples)[0]
        mcmc_samples[v.name] = result

    tic4 = pd.Timestamp.now()
    print("Transformation time = ", tic4 - tic3, file=sys.stdout)

    posterior = mcmc_samples
    az_trace = az.from_dict(
        posterior=posterior,
        log_likelihood=_get_log_likelihood(model, raw_mcmc_samples),
        observed_data=find_observations(model),
        sample_stats=_sample_stats_to_xarray(pmap_numpyro),
        coords=coords,
        dims=dims,
    )

    return az_trace
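# Hedged usage sketch (illustrative only, not part of the original module):
# `sample_numpyro_nuts` follows the `pm.sample` calling convention and picks
# up the model from the enclosing context via `modelcontext(None)`. The toy
# model below is an assumption for illustration; the return value is the
# ArviZ trace built by `az.from_dict` above.
def _demo_sample_numpyro_nuts():
    import numpy as np
    import pymc as pm

    with pm.Model():
        mu = pm.Normal("mu", 0.0, 1.0)
        sigma = pm.HalfNormal("sigma", 1.0)
        pm.Normal("obs", mu, sigma, observed=np.random.normal(size=100))
        idata = sample_numpyro_nuts(draws=500, tune=500, chains=2)
    return idata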
def jax_funcify_NumPyroNUTS(op, node, **kwargs):
    from numpyro.infer import MCMC, NUTS

    draws = op.draws
    tune = op.tune
    chains = op.chains
    target_accept = op.target_accept
    progress_bar = op.progress_bar
    seed = op.seed

    # Compile the "inner" log-likelihood function. This will have extra shared
    # variable inputs as the last arguments
    logp_fn = jax_funcify(op.fgraph, **kwargs)

    if isinstance(logp_fn, (list, tuple)):
        # This handles the new JAX backend, which always returns a tuple
        logp_fn = logp_fn[0]

    def _sample(*inputs):
        if op.nshared > 0:
            current_state = inputs[: -op.nshared]
            shared_inputs = tuple(op.fgraph.inputs[-op.nshared :])
        else:
            current_state = inputs
            shared_inputs = ()

        def log_fn_wrap(x):
            res = logp_fn(
                *(
                    x
                    # We manually obtain the shared values and add them
                    # as arguments to our compiled "inner" function
                    + tuple(
                        v.get_value(borrow=True, return_internal_type=True)
                        for v in shared_inputs
                    )
                )
            )

            if isinstance(res, (list, tuple)):
                # This handles the new JAX backend, which always returns a tuple
                res = res[0]

            return -res

        nuts_kernel = NUTS(
            potential_fn=log_fn_wrap,
            target_accept_prob=target_accept,
            adapt_step_size=True,
            adapt_mass_matrix=True,
            dense_mass=False,
        )

        pmap_numpyro = MCMC(
            nuts_kernel,
            num_warmup=tune,
            num_samples=draws,
            num_chains=chains,
            postprocess_fn=None,
            chain_method="parallel",
            progress_bar=progress_bar,
        )

        pmap_numpyro.run(seed, init_params=current_state, extra_fields=("num_steps",))
        samples = pmap_numpyro.get_samples(group_by_chain=True)
        leapfrogs_taken = pmap_numpyro.get_extra_fields(group_by_chain=True)["num_steps"]
        return tuple(samples) + (leapfrogs_taken,)

    return _sample