def _transform_samples(samples, model, keep_untransformed=False):

    # Find out which RVs we need to compute:
    free_rv_names = {x.name for x in model.free_RVs}
    unobserved_names = {x.name for x in model.unobserved_RVs}

    names_to_compute = unobserved_names - free_rv_names
    ops_to_compute = [x for x in model.unobserved_RVs if x.name in names_to_compute]

    # Create function graph for these:
    fgraph = theano.graph.fg.FunctionGraph(model.free_RVs, ops_to_compute)

    # Jaxify, which returns a list of functions, one for each op
    jax_fns = jax_funcify(fgraph)

    # Put together the inputs
    inputs = [samples[x.name] for x in model.free_RVs]

    for cur_op, cur_jax_fn in zip(ops_to_compute, jax_fns):
        # Map the jaxified op over the two leading sample dimensions:
        # the outer vmap runs over chains, the inner one over draws.
        result = jax.vmap(jax.vmap(cur_jax_fn))(*inputs)

        # Add to sample dict
        samples[cur_op.name] = result

    # Discard unwanted transformed variables, if desired:
    vars_to_keep = set(
        pm.util.get_default_varnames(
            list(samples.keys()), include_transformed=keep_untransformed
        )
    )
    samples = {x: y for x, y in samples.items() if x in vars_to_keep}

    return samples
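# --- Usage sketch (not part of the original module) --------------------------
# A minimal, hypothetical illustration of how `_transform_samples` is meant to
# be called: `samples` maps each free RV name to an array whose two leading
# axes are (chain, draw), which is why the function vmaps twice. The model,
# data, and helper name below are made up for illustration.
def _example_transform_samples():
    import numpy as np
    import pymc3 as pm

    with pm.Model() as model:
        sd = pm.HalfNormal("sd", 1.0)  # sampled on the log scale as "sd_log__"
        pm.Normal("obs", 0.0, sd, observed=np.random.randn(50))

    # Pretend these came from a JAX-based sampler: shape (chains, draws).
    samples = {"sd_log__": np.zeros((2, 10))}

    # Recovers "sd" from "sd_log__" and drops the transformed variable.
    return _transform_samples(samples, model, keep_untransformed=False)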
def create_jax_thunks(self, compute_map, storage_map):
    """Create a thunk for each output of the `Linker`'s `FunctionGraph`.

    This differs from the other thunk-making function in that it only
    produces thunks for the `FunctionGraph` output nodes.

    Parameters
    ----------
    compute_map: dict
        The compute map dictionary.
    storage_map: dict
        The storage map dictionary.

    Returns
    -------
    thunks: list
        A list containing the thunks.
    output_nodes: list
        A list containing the corresponding output nodes.

    """
    import jax

    from theano.link.jax.jax_dispatch import jax_funcify

    output_nodes = [o.owner for o in self.fgraph.outputs]

    # Create a JAX-compilable function from our `FunctionGraph`
    jaxed_fgraph_outputs = jax_funcify(self.fgraph)

    assert len(jaxed_fgraph_outputs) == len(output_nodes)

    # Treat `Constant` inputs as "static" arguments for JAX's JIT.
    static_argnums = [
        n for n, i in enumerate(self.fgraph.inputs) if isinstance(i, Constant)
    ]

    thunk_inputs = [storage_map[n] for n in self.fgraph.inputs]

    thunks = []

    for node, jax_funcs in zip(output_nodes, jaxed_fgraph_outputs):

        thunk_outputs = [storage_map[n] for n in node.outputs]

        if not isinstance(jax_funcs, Sequence):
            jax_funcs = [jax_funcs]

        jax_impl_jits = [
            jax.jit(jax_func, static_argnums) for jax_func in jax_funcs
        ]

        def thunk(
            node=node, jax_impl_jits=jax_impl_jits, thunk_outputs=thunk_outputs
        ):
            outputs = [
                jax_impl_jit(*[x[0] for x in thunk_inputs])
                for jax_impl_jit in jax_impl_jits
            ]

            if len(jax_impl_jits) < len(node.outputs):
                # In this case, the JAX function will output a single
                # output that contains the other outputs.
                # This happens for multi-output `Op`s that directly
                # correspond to multi-output JAX functions (e.g. `SVD` and
                # `jax.numpy.linalg.svd`).
                outputs = outputs[0]

            # Write results back into the storage map and mark the outputs
            # as computed.
            for o_node, o_storage, o_val in zip(node.outputs, thunk_outputs, outputs):
                compute_map[o_node][0] = True
                if len(o_storage) > 1:
                    assert len(o_storage) == len(o_val)
                    for i, o_sub_val in enumerate(o_val):
                        o_storage[i] = o_sub_val
                else:
                    o_storage[0] = o_val

            return outputs

        thunk.inputs = thunk_inputs
        thunk.outputs = thunk_outputs
        thunk.lazy = False

        thunks.append(thunk)

    return thunks, output_nodes
def sample_tfp_mhrw(
    draws=1000,
    tune=5000,
    burnin=1000,
    thin=0,
    chains=4,
    random_seed=10,
    model=None,
    num_tuning_epoch=5,
    step_size=0.1,
):
    import jax

    from tensorflow_probability.substrates import jax as tfp

    model = modelcontext(model)

    seed = jax.random.PRNGKey(random_seed)

    # Jaxify the model's log-probability graph.
    fgraph = theano.graph.fg.FunctionGraph(model.free_RVs, [model.logpt])
    fns = jax_funcify(fgraph)
    logp_fn_jax = fns[0]

    rv_names = [rv.name for rv in model.free_RVs]
    init_state = [model.test_point[rv_name] for rv_name in rv_names]
    init_state_batched = jax.tree_map(
        lambda x: np.repeat(x[None, ...], chains, axis=0), init_state
    )

    @jax.vmap
    def _sample(init_state, seed, step_size):
        def trace_is_accepted(states, previous_kernel_results):
            return previous_kernel_results.is_accepted

        def gen_kernel(step_size):
            kernel_ = tfp.mcmc.RandomWalkMetropolis(
                target_log_prob_fn=logp_fn_jax,
                new_state_fn=tfp.mcmc.random_walk_normal_fn(scale=step_size),
            )
            return kernel_

        accept_rate = 0.0
        for i in range(num_tuning_epoch - 1):
            # Diagnostics: under `jax.vmap` these print traced values, not
            # concrete numbers.
            print(jax.numpy.mean(accept_rate))
            print(jax.numpy.std(step_size))

            tuning_mhrw = gen_kernel(step_size)
            samples, stats = tfp.mcmc.sample_chain(
                num_results=burnin // num_tuning_epoch,
                current_state=init_state,
                kernel=tuning_mhrw,
                num_burnin_steps=burnin,
                num_steps_between_results=thin,
                trace_fn=trace_is_accepted,
                seed=seed,
            )

            accept_rate = jax.numpy.ravel(stats).mean()
            # NOTE: `vtune` is assumed to be a step-size tuning helper defined
            # elsewhere in this module; it is not defined here.
            step_size = vtune(step_size, accept_rate)

        # Run inference
        sample_kernel = gen_kernel(step_size)
        mcmc_samples, mcmc_stats = tfp.mcmc.sample_chain(
            num_results=draws,
            num_burnin_steps=tune // num_tuning_epoch,
            current_state=init_state,
            kernel=sample_kernel,
            trace_fn=trace_is_accepted,
            seed=seed,
        )

        return mcmc_samples, mcmc_stats

    print("Compiling and sampling...")
    tic2 = pd.Timestamp.now()

    map_seed = jax.random.split(seed, chains)
    # One initial step size per chain (array of shape (chains,)).
    map_stepsize = jax.tree_map(jax.numpy.ones, chains)

    mcmc_samples, accept_rate = _sample(init_state_batched, map_seed, map_stepsize)

    posterior = {k: v for k, v in zip(rv_names, mcmc_samples)}
    az_trace = az.from_dict(posterior=posterior)

    tic3 = pd.Timestamp.now()
    print("Compilation + sampling time = ", tic3 - tic2)

    return az_trace, accept_rate
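# --- Usage sketch (not part of the original module) --------------------------
# Hypothetical example of calling `sample_tfp_mhrw`; note that it returns both
# an ArviZ trace and the per-chain acceptance statistics. The model and data
# below are made up for illustration and assume pymc3, jax and
# tensorflow_probability are installed.
def _example_sample_tfp_mhrw():
    import numpy as np
    import pymc3 as pm

    with pm.Model() as model:
        mu = pm.Normal("mu", 0.0, 1.0)
        pm.Normal("obs", mu, 1.0, observed=np.random.randn(100))

    az_trace, accept_rate = sample_tfp_mhrw(
        draws=500, burnin=500, chains=2, model=model
    )
    return az_trace, accept_rate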
def sample_tfp_nuts(
    draws=1000,
    tune=1000,
    chains=4,
    target_accept=0.8,
    random_seed=10,
    model=None,
    num_tuning_epoch=2,
    num_compute_step_size=500,
):
    import jax

    from tensorflow_probability.substrates import jax as tfp

    model = modelcontext(model)

    seed = jax.random.PRNGKey(random_seed)

    fgraph = theano.graph.fg.FunctionGraph(model.free_RVs, [model.logpt])
    fns = jax_funcify(fgraph)
    logp_fn_jax = fns[0]

    rv_names = [rv.name for rv in model.free_RVs]
    init_state = [model.test_point[rv_name] for rv_name in rv_names]
    init_state_batched = jax.tree_map(
        lambda x: np.repeat(x[None, ...], chains, axis=0), init_state
    )

    @jax.pmap
    def _sample(init_state, seed):
        def gen_kernel(step_size):
            hmc = tfp.mcmc.NoUTurnSampler(
                target_log_prob_fn=logp_fn_jax, step_size=step_size
            )
            return tfp.mcmc.DualAveragingStepSizeAdaptation(
                hmc, tune // num_tuning_epoch, target_accept_prob=target_accept
            )

        def trace_fn(_, pkr):
            return pkr.new_step_size

        def get_tuned_stepsize(samples, step_size):
            return step_size[-1] * jax.numpy.std(samples[-num_compute_step_size:])

        step_size = jax.tree_map(jax.numpy.ones_like, init_state)
        for i in range(num_tuning_epoch - 1):
            tuning_hmc = gen_kernel(step_size)
            init_samples, tuning_result, kernel_results = tfp.mcmc.sample_chain(
                num_results=tune // num_tuning_epoch,
                current_state=init_state,
                kernel=tuning_hmc,
                trace_fn=trace_fn,
                return_final_kernel_results=True,
                seed=seed,
            )

            step_size = jax.tree_multimap(
                get_tuned_stepsize, list(init_samples), tuning_result
            )
            init_state = [x[-1] for x in init_samples]

        # Run inference
        sample_kernel = gen_kernel(step_size)
        mcmc_samples, leapfrog_num = tfp.mcmc.sample_chain(
            num_results=draws,
            num_burnin_steps=tune // num_tuning_epoch,
            current_state=init_state,
            kernel=sample_kernel,
            trace_fn=lambda _, pkr: pkr.inner_results.leapfrogs_taken,
            seed=seed,
        )

        return mcmc_samples, leapfrog_num

    print("Compiling and sampling...")
    tic2 = pd.Timestamp.now()

    map_seed = jax.random.split(seed, chains)
    mcmc_samples, leapfrog_num = _sample(init_state_batched, map_seed)

    posterior = {k: v for k, v in zip(rv_names, mcmc_samples)}
    az_trace = az.from_dict(posterior=posterior)

    tic3 = pd.Timestamp.now()
    print("Compilation + sampling time = ", tic3 - tic2)

    return az_trace  # , leapfrog_num, tic3 - tic2
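# --- Usage sketch (not part of the original module) --------------------------
# Hypothetical example of running the TFP NUTS sampler on a trivial model; the
# return value is an ArviZ trace built from the posterior draws. The model and
# data below are illustrative only.
def _example_sample_tfp_nuts():
    import numpy as np
    import pymc3 as pm

    with pm.Model() as model:
        mu = pm.Normal("mu", 0.0, 1.0)
        pm.Normal("obs", mu, 1.0, observed=np.random.randn(100))

    # `chains` should not exceed the number of available JAX devices, since
    # `_sample` is wrapped in `jax.pmap`.
    return sample_tfp_nuts(draws=500, tune=500, chains=1, model=model)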
def sample_numpyro_nuts_vmap(
    draws=1000,
    tune=1000,
    chains=4,
    target_accept=0.8,
    random_seed=10,
    model=None,
    progress_bar=True,
    chain_method="parallel",
):
    from numpyro.infer import MCMC, NUTS

    from pymc3 import modelcontext

    model = modelcontext(model)

    seed = jax.random.PRNGKey(random_seed)

    fgraph = theano.graph.fg.FunctionGraph(model.free_RVs, [model.logpt])
    fns = jax_funcify(fgraph)
    logp_fn_jax = fns[0]

    rv_names = [rv.name for rv in model.free_RVs]
    init_state = [model.test_point[rv_name] for rv_name in rv_names]
    init_state_batched = jax.tree_map(
        lambda x: np.repeat(x[None, ...], chains, axis=0), init_state
    )

    @jax.jit
    def _sample(current_state, seed):
        step_size = jax.tree_map(jax.numpy.ones_like, init_state)
        nuts_kernel = NUTS(
            potential_fn=lambda x: -logp_fn_jax(*x),
            # model=model,
            target_accept_prob=target_accept,
            adapt_step_size=True,
            adapt_mass_matrix=True,
            dense_mass=False,
        )

        pmap_numpyro = MCMC(
            nuts_kernel,
            num_warmup=tune,
            num_samples=draws,
            num_chains=chains,
            postprocess_fn=None,
            chain_method=chain_method,
            progress_bar=progress_bar,
        )

        pmap_numpyro.run(seed, init_params=current_state, extra_fields=("num_steps",))
        samples = pmap_numpyro.get_samples(group_by_chain=True)
        leapfrogs_taken = pmap_numpyro.get_extra_fields(group_by_chain=True)["num_steps"]
        return samples, leapfrogs_taken

    print("Compiling and sampling...")
    tic2 = pd.Timestamp.now()

    map_seed = jax.random.split(seed, chains)
    mcmc_samples, leapfrogs_taken = _sample(init_state_batched, map_seed)

    posterior = {k: v for k, v in zip(rv_names, mcmc_samples)}
    az_trace = az.from_dict(posterior=posterior)

    tic3 = pd.Timestamp.now()
    print("Compilation + sampling time = ", tic3 - tic2)

    return az_trace  # , leapfrogs_taken, tic3 - tic2
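# --- Usage sketch (not part of the original module) --------------------------
# Hypothetical example of the NumPyro NUTS entry point. Because the whole run
# is wrapped in `jax.jit`, passing `chain_method="vectorized"` avoids needing
# multiple devices; that choice is an assumption for this sketch, not part of
# the original code. The model and data below are illustrative only.
def _example_sample_numpyro_nuts_vmap():
    import numpy as np
    import pymc3 as pm

    with pm.Model() as model:
        mu = pm.Normal("mu", 0.0, 1.0)
        pm.Normal("obs", mu, 1.0, observed=np.random.randn(100))

    return sample_numpyro_nuts_vmap(
        draws=500, tune=500, chains=2, chain_method="vectorized", model=model
    )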