def wrapped(init_key: Key, *args, **kwargs) -> FunctionModule:
    """Initialize `f`, stage it, and package the result as a `FunctionModule`.

    `f`, `init_keyword`, `name` and `custom_unzip_rules` are free variables
    that are not defined in this block.

    NOTE(review): this block was re-indented from a whitespace-collapsed
    source. The placement of `cau_f = functools.partial(...)` after the
    `with` body and of the `if name is not None` check inside the `else`
    branch is an assumption -- confirm against the original layout.

    Args:
      init_key: Passed to `init_f` to produce the variables, and used as the
        `key` argument of every `module.variable` call below.
      *args: Positional arguments forwarded to `f` while unzipping/staging.
      **kwargs: Keyword arguments forwarded to `f` while unzipping/staging.

    Returns:
      When `name` is None, a `FunctionModule` whose variables have each been
      passed through `module.variable` under their own key. Otherwise the
      result of `module.variable` applied to the `FunctionModule` itself,
      under `name`.
    """
    # Presumably reports whether `f` accepts `init_keyword` as a keyword
    # argument -- verify against `kwargs_util.check_in_kwargs`.
    has_init_key = kwargs_util.check_in_kwargs(f, init_keyword)
    if not has_init_key:
        # `f` is called without an init key: there are no variables to create.
        def init_f(init_key, *args, **kwargs):
            del init_key, args, kwargs
            return {}

        # Calls `f` directly and returns `variables` unchanged next to its
        # output.
        def cau_f(variables, *args, **kwargs):
            return f(*args, **kwargs), variables
    else:
        with unzip.new_custom_rules(custom_unzip_rules):
            # Adapter that takes the key positionally and hands it to `f`
            # under the `init_keyword` keyword.
            def fun(init_key, *args, **kwargs):
                kwargs = {**kwargs, init_keyword: init_key}
                return f(*args, **kwargs)

            # Split `fun` into an init function and an apply function,
            # using the `module.VARIABLE` tag.
            init_f, apply_f = unzip.unzip(fun, tag=module.VARIABLE)(init_key, *args, **kwargs)
        # Harvest `module.ASSIGN`-tagged values from `apply_f`; the first
        # positional argument of the harvested function is pre-bound to an
        # empty dict.
        cau_f = functools.partial(
            harvest.harvest(apply_f, tag=module.ASSIGN), {})
        if name is not None:
            # Wrap both functions with `harvest.nest` under the scope `name`.
            init_f = harvest.nest(init_f, scope=name)
            cau_f = harvest.nest(cau_f, scope=name)
    variables = init_f(init_key)
    # Stage `cau_f` on the concrete variables and arguments; yields the
    # jaxpr plus the input/output tree structures.
    cau_jaxpr, (in_tree, out_tree) = trace_util.stage(cau_f, dynamic=True)(variables, *args, **kwargs)
    if name is None:
        # Unnamed module: pass each variable through `module.variable`
        # individually, keyed by its own dictionary key.
        variables = {
            k: module.variable(val, name=k, key=init_key)
            for k, val in variables.items()
        }
        return FunctionModule(variables, cau_jaxpr, in_tree, out_tree, name=name)
    else:
        # Named module: pass the whole module through `module.variable`
        # under `name`.
        return module.variable(FunctionModule(variables, cau_jaxpr, in_tree, out_tree, name=name), name=name, key=init_key)
def wrapped(key):
    """Evaluate `_sample_distribution(key, dist)` bound via `random_variable_p`.

    The call is routed through `primitive.initial_style_bind`, with the class
    name of `dist` attached as the `distribution_name` parameter. `dist` and
    `name` are free variables that are not defined in this block.

    Args:
      key: Forwarded unchanged as the first argument of
        `_sample_distribution`.

    Returns:
      When `name` is None, the result of the bound call. Otherwise the bound
      call is wrapped with `harvest.nest(..., scope=name)` and its result is
      passed through `ppl.random_variable(..., name=name)`.
    """

    def sample(key):
        bind = primitive.initial_style_bind(
            random_variable_p,
            distribution_name=dist.__class__.__name__,
        )
        return bind(_sample_distribution)(key, dist)

    # Unnamed case: no nesting and no `ppl.random_variable` call.
    if name is None:
        return sample(key)

    nested_sample = harvest.nest(sample, scope=name)
    return ppl.random_variable(nested_sample(key), name=name)
def test_unzip_of_nest_should_nest_variables(self):
    # Unzipping a function wrapped in `harvest.nest` is expected to key its
    # variables by the nest scope, both for init output and for apply input.

    def f(x):
        return variable(x, name='x')

    # Case 1: a single nested call under scope 'f'.
    single_init, single_apply = unzip_variable(harvest.nest(f, scope='f'))(1.)
    self.assertDictEqual(single_init(1.), {'f': {'x': 1}})
    self.assertEqual(single_apply({'f': {'x': 2.}}), 2.)

    def g(x):
        after_f1 = harvest.nest(f, scope='f1')(x + 1.)
        return harvest.nest(f, scope='f2')(after_f1 + 1.)

    # Case 2: two chained nested calls under the sibling scopes 'f1' and 'f2'.
    chained_init, chained_apply = unzip_variable(g)(1.)
    self.assertDictEqual(
        chained_init(1.),
        {'f1': {'x': 2}, 'f2': {'x': 3.}},
    )
    self.assertEqual(
        chained_apply({'f1': {'x': 4.}, 'f2': {'x': 100.}}),
        100.,
    )
def f(x):
    """Pass `x` through an identity function nested under 'foo', then halve it.

    Args:
      x: Value handed to the `harvest.nest`-wrapped identity function.

    Returns:
      The wrapped identity function's result divided by `2.`.
    """
    nested_identity = harvest.nest(lambda x: x, scope='foo')
    return nested_identity(x) / 2.
def g(x):
    """Apply `f` twice, first nested under 'f1' and then under 'f2'.

    Args:
      x: Input value; `1.` is added before each nested call to `f`.

    Returns:
      The result of the second nested call, i.e. `f` under scope 'f2'
      applied to (`f` under scope 'f1' applied to `x + 1.`) `+ 1.`.
    """
    after_f1 = harvest.nest(f, scope='f1')(x + 1.)
    return harvest.nest(f, scope='f2')(after_f1 + 1.)