def test_single_chain(self, make_kernel, target_accept_rate):
    """Checks one jitted MCMC chain against exact samples from `self.model`.

    Runs `kernels.sample_chain` for 20000 steps while harvesting the values
    tagged `kernels.MCMC_METRICS`. It then compares the chain's mean and
    covariance with 4096 exact draws, and the mean acceptance probability with
    `target_accept_rate`.
    """
    num_samples = 20000
    sample_key, chain_key, init_key = random.split(self._seed, 3)

    log_prob_fn = self._make_unconstrained_log_prob()
    start_state = self._initialize_state(init_key)
    mcmc_kernel = make_kernel(log_prob_fn)

    # Harvesting returns the sampled chain together with the metrics sown
    # under MCMC_METRICS; the leading `{}` passed below plants nothing.
    chain_with_metrics = harvest.harvest(
        kernels.sample_chain(mcmc_kernel, num_samples),
        tag=kernels.MCMC_METRICS,
    )
    run_chain = jax.jit(chain_with_metrics)

    reference = self.model.sample(sample_shape=4096, seed=sample_key)
    draws, chain_metrics = run_chain({}, chain_key, start_state)

    onp.testing.assert_allclose(
        reference.mean(axis=0), draws.mean(axis=0), rtol=0.5, atol=0.1)
    onp.testing.assert_allclose(
        np.cov(reference.T), np.cov(draws.T), rtol=0.5, atol=0.1)

    observed_accept = chain_metrics['kernel']['accept_prob'].mean()
    onp.testing.assert_allclose(
        target_accept_rate, observed_accept, atol=1e-2, rtol=1e-2)
def test_multiple_chains(self, make_kernel, target_accept_rate):
    """Checks 16 vmapped, jitted MCMC chains against exact model samples.

    Each chain runs 4000 steps from its own initial state and key. All chains
    are then flattened into one sample set, whose mean and covariance are
    compared with 4096 exact draws from `self.model`. The mean acceptance
    probability is compared with `target_accept_rate`.
    """
    num_chains = 16
    num_samples = 4000
    sample_key, chain_key, init_key = random.split(self._seed, 3)

    log_prob_fn = self._make_unconstrained_log_prob()
    init_keys = random.split(init_key, num_chains)
    start_states = jax.vmap(self._initialize_state)(init_keys)
    mcmc_kernel = make_kernel(log_prob_fn)

    single_chain = harvest.harvest(
        kernels.sample_chain(mcmc_kernel, num_samples),
        tag=kernels.MCMC_METRICS,
    )
    run_chains = jax.jit(jax.vmap(single_chain))

    reference = self.model.sample(sample_shape=4096, seed=sample_key)
    chain_keys = random.split(chain_key, num_chains)
    draws, chain_metrics = run_chains({}, chain_keys, start_states)

    def _merge_chains(per_chain, event_shape):
        # Collapse the leading (chain, step) axes into a single sample axis.
        return per_chain.reshape(
            [num_chains * num_samples] + list(event_shape))

    draws = tf.nest.map_structure(
        _merge_chains, draws, self.model.event_shape)

    onp.testing.assert_allclose(
        reference.mean(axis=0), draws.mean(axis=0), rtol=0.1, atol=0.1)
    onp.testing.assert_allclose(
        np.cov(reference.T), np.cov(draws.T), rtol=0.1, atol=0.1)

    observed_accept = chain_metrics['kernel']['accept_prob'].mean()
    onp.testing.assert_allclose(
        target_accept_rate, observed_accept, atol=1e-2, rtol=1e-2)
def get_summaries(f):
    """Transforms a function into one that additionally outputs its summaries.

    Args:
      f: a callable.

    Returns:
      A function that, when called, returns the original output of `f`
      together with a dictionary mapping summary names to the values they
      took during execution.
    """
    # Collect every value sown under the SUMMARY tag.
    summarizing_f = harvest.harvest(f, tag=SUMMARY)
    # Bind an empty dict as the harvested function's first argument
    # (presumably the values to plant), so nothing is injected into `f`.
    return functools.partial(summarizing_f, {})
def wrapped(*args, **kwargs):
    """Runs `f` under a RANDOM_VARIABLE harvest and re-sows its latents.

    `observations` is passed as the harvested function's leading argument.
    Every value harvested from the run is sown again under RANDOM_VARIABLE
    with `mode='strict'`, in the order the harvest returned it. The resulting
    values are tied to `f`'s output via `primitive.tie_in`.
    """
    harvested_f = harvest.harvest(f, tag=RANDOM_VARIABLE)
    result, harvested = harvested_f(observations, *args, **kwargs)

    resown = {}
    for latent_name, latent_value in harvested.items():
        resown[latent_name] = harvest.sow(
            latent_value,
            tag=RANDOM_VARIABLE,
            name=latent_name,
            mode='strict',
        )
    return primitive.tie_in(resown, result)
def wrapped(init_key: Key, *args, **kwargs) -> FunctionModule:
    """Builds a `FunctionModule` from `f`, initializing it with `init_key`.

    If `f` accepts the `init_keyword` keyword argument, `f` is split with
    `unzip.unzip` (tag `module.VARIABLE`) into an initializer and an apply
    function. Otherwise the module has no variables.

    Args:
      init_key: key forwarded to `f` under `init_keyword` and used when
        wrapping results with `module.variable`.
      *args: positional arguments forwarded to `f` for tracing.
      **kwargs: keyword arguments forwarded to `f` for tracing.

    Returns:
      When `name` is None, a `FunctionModule` whose individual variables are
      each wrapped with `module.variable`. Otherwise, the `FunctionModule`
      itself wrapped by `module.variable` under `name`.
    """
    has_init_key = kwargs_util.check_in_kwargs(f, init_keyword)
    if not has_init_key:
        # `f` takes no init key: there is nothing to initialize.
        def init_f(init_key, *args, **kwargs):
            del init_key, args, kwargs
            return {}

        # Calls `f` and passes the (empty) variables through unchanged.
        def cau_f(variables, *args, **kwargs):
            return f(*args, **kwargs), variables
    else:
        # Custom unzip rules are active only while `fun` is defined and unzipped.
        with unzip.new_custom_rules(custom_unzip_rules):
            def fun(init_key, *args, **kwargs):
                # Forward the key under the configured keyword name.
                kwargs = {**kwargs, init_keyword: init_key}
                return f(*args, **kwargs)

            init_f, apply_f = unzip.unzip(fun, tag=module.VARIABLE)(init_key,
                                                                    *args, **kwargs)
        # Harvest values tagged ASSIGN from the apply function; the leading
        # `{}` binds an empty first argument (presumably no planted values).
        cau_f = functools.partial(
            harvest.harvest(apply_f, tag=module.ASSIGN), {})
    if name is not None:
        # Scope both functions under `name`.
        init_f = harvest.nest(init_f, scope=name)
        cau_f = harvest.nest(cau_f, scope=name)
    variables = init_f(init_key)
    # Stage the call-and-update function into a jaxpr plus input/output trees.
    cau_jaxpr, (in_tree, out_tree) = trace_util.stage(cau_f, dynamic=True)(
        variables, *args, **kwargs)
    if name is None:
        # Unnamed module: register each variable individually.
        variables = {
            k: module.variable(val, name=k, key=init_key)
            for k, val in variables.items()
        }
        return FunctionModule(variables, cau_jaxpr, in_tree, out_tree, name=name)
    else:
        # NOTE(review): the module is wrapped as a variable even when
        # `variables` is empty (e.g. `f` takes no init key) — confirm intended.
        return module.variable(
            FunctionModule(variables, cau_jaxpr, in_tree, out_tree, name=name),
            name=name, key=init_key)
def wrapped(init_key: Key, *args, **kwargs) -> FunctionModule:
    """Builds a `FunctionModule` from `f`, initializing it with `init_key`.

    If `f` accepts the `init_keyword` keyword argument, its `module.VARIABLE`
    values are reaped from a run of `f` that receives `init_key`. An apply
    function then plants those values back in. If `f` does not accept
    `init_keyword`, the module has no variables.

    Args:
      init_key: key forwarded to `f` under `init_keyword` during
        initialization and used when wrapping results with `module.variable`.
      *args: positional arguments forwarded to `f`.
      **kwargs: keyword arguments forwarded to `f`.

    Returns:
      When `name` is None, a `FunctionModule` whose individual variables are
      each wrapped with `module.variable`. Otherwise, a `FunctionModule` that
      is itself wrapped by `module.variable` under `name`, but only when it
      has at least one variable.
    """
    has_init_key = kwargs_util.check_in_kwargs(f, init_keyword)
    if not has_init_key:
        # `f` takes no init key: there is nothing to initialize.
        def init_f(init_key, *args, **kwargs):
            del init_key, args, kwargs
            return {}

        # Calls `f` and passes the (empty) variables through unchanged.
        def cau_f(variables, *args, **kwargs):
            return f(*args, **kwargs), variables
    else:
        def f_(init_key, *args, **kwargs):
            # Forward the key under the configured keyword, i.e. the same one
            # `check_in_kwargs` probed above, not a hard-coded `init_key=`.
            # Merging into `kwargs` also avoids a duplicate-keyword TypeError.
            return f(*args, **{**kwargs, init_keyword: init_key})

        def init_f(init_key, *args, **kwargs):
            # Reap the VARIABLE-tagged values produced by running `f_`.
            return harvest.reap(
                f_, tag=module.VARIABLE, exclusive=True)(init_key, *args, **kwargs)

        def apply_f(variables, *args, **kwargs):
            # `variables` are planted in place of the VARIABLE values, so the
            # fixed key is only a placeholder (presumably unused; confirm).
            return harvest.plant(
                f_, tag=module.VARIABLE)(variables, random.PRNGKey(0), *args,
                                         **kwargs)

        # Harvest ASSIGN-tagged values with nothing planted, so `cau_f`
        # returns the output of `apply_f` plus the assigned values.
        cau_f = functools.partial(
            harvest.harvest(apply_f, tag=module.ASSIGN), {})
    variables = init_f(init_key, *args, **kwargs)
    # Stage the call-and-update function into a jaxpr plus input/output trees.
    cau_jaxpr, (in_tree, out_tree) = trace_util.stage(
        cau_f, dynamic=True)(variables, *args, **kwargs)
    if name is None:
        # Unnamed module: register each variable individually.
        variables = {
            k: module.variable(val, name=k, key=init_key)
            for k, val in variables.items()
        }
        return FunctionModule(variables, cau_jaxpr, in_tree, out_tree, name=name)
    else:
        mod = FunctionModule(variables, cau_jaxpr, in_tree, out_tree, name=name)
        # Only a module that actually holds state is registered as a variable.
        if variables:
            return module.variable(mod, name=name, key=init_key)
        return mod
def wrapped(*args, **kwargs):
    """Calls `f` in plant-only harvest mode with `observations` supplied.

    `observations` is the harvested function's leading argument. Only the
    first element of the harvest result (the output of `f`) is returned; the
    harvested dictionary is discarded.
    """
    plant_only_f = harvest.harvest(
        f, tag=RANDOM_VARIABLE, mode='plant_only')
    outputs = plant_only_f(observations, *args, **kwargs)
    return outputs[0]
def program(key, *args, **kwargs):
    """Runs `f` in plant-only harvest mode with `names` blocklisted.

    The harvested function receives an empty leading dict, followed by `key`
    and the remaining arguments. Only the first element of the harvest result
    (the output of `f`) is returned.
    """
    blocked_f = harvest.harvest(
        f,
        tag=RANDOM_VARIABLE,
        blocklist=names,
        mode=harvest.HarvestMode.PLANT_ONLY,
    )
    nothing_planted = {}
    outputs = blocked_f(nothing_planted, key, *args, **kwargs)
    return outputs[0]