コード例 #1
0
    def wrapped(init_key: Key, *args, **kwargs) -> FunctionModule:
        """Initialize `f` with `init_key` and stage it into a `FunctionModule`.

        Args:
          init_key: key passed to the initializer and attached to every
            `module.variable` created here.
          *args: positional arguments used to unzip and stage `f`.
          **kwargs: keyword arguments used to unzip and stage `f`.

        Returns:
          If the enclosing `name` is `None`, a `FunctionModule` whose variables
          have each been wrapped with `module.variable`. Otherwise, the result
          of `module.variable` applied to the `FunctionModule` under `name`.
        """
        # NOTE(review): presumably this checks whether `f` accepts the
        # `init_keyword` argument -- confirm against `kwargs_util`.
        if kwargs_util.check_in_kwargs(f, init_keyword):

            def fun(init_key, *args, **kwargs):
                # Forward the key to `f` under the configured keyword name.
                return f(*args, **{**kwargs, init_keyword: init_key})

            with unzip.new_custom_rules(custom_unzip_rules):
                unzipped = unzip.unzip(fun, tag=module.VARIABLE)
                init_f, apply_f = unzipped(init_key, *args, **kwargs)
            harvested = harvest.harvest(apply_f, tag=module.ASSIGN)
            # Pre-bind an empty dict as the first argument of the harvested
            # function.
            cau_f = functools.partial(harvested, {})
        else:

            def init_f(init_key, *args, **kwargs):
                # Nothing to initialize: the variables are always empty.
                del init_key, args, kwargs
                return {}

            def cau_f(variables, *args, **kwargs):
                # Call `f` directly and hand the variables back untouched.
                return f(*args, **kwargs), variables

        if name is not None:
            init_f = harvest.nest(init_f, scope=name)
            cau_f = harvest.nest(cau_f, scope=name)

        variables = init_f(init_key)
        staged = trace_util.stage(cau_f, dynamic=True)
        cau_jaxpr, (in_tree, out_tree) = staged(variables, *args, **kwargs)

        if name is not None:
            # Named module: wrap the whole module as a single variable.
            named_module = FunctionModule(variables,
                                          cau_jaxpr,
                                          in_tree,
                                          out_tree,
                                          name=name)
            return module.variable(named_module, name=name, key=init_key)

        # Unnamed module: wrap every variable individually under its own key.
        wrapped_variables = {
            var_name: module.variable(value, name=var_name, key=init_key)
            for var_name, value in variables.items()
        }
        return FunctionModule(wrapped_variables,
                              cau_jaxpr,
                              in_tree,
                              out_tree,
                              name=name)
コード例 #2
0
    def wrapped(key):
        """Sample `dist` with `key`, nesting and naming the draw if `name` is set.

        Args:
          key: value passed, together with `dist`, to the bound
            `_sample_distribution`.

        Returns:
          The raw sample when the enclosing `name` is `None`; otherwise the
          result of `ppl.random_variable` applied to the sample drawn inside a
          `harvest.nest` scope called `name`.
        """
        def sample(key):
            # Bind the sampling primitive, recording the distribution's class
            # name as a parameter.
            bind = primitive.initial_style_bind(
                random_variable_p,
                distribution_name=dist.__class__.__name__)
            return bind(_sample_distribution)(key, dist)

        if name is None:
            return sample(key)

        nested_sample = harvest.nest(sample, scope=name)
        return ppl.random_variable(nested_sample(key), name=name)
コード例 #3
0
    def test_unzip_of_nest_should_nest_variables(self):
        """Unzipping a nested function keys its variables by the nest scope."""
        def f(x):
            return variable(x, name='x')

        # One nested scope: the variable 'x' appears under the scope 'f'.
        nested_f = harvest.nest(f, scope='f')
        init_fn, apply_fn = unzip_variable(nested_f)(1.)
        self.assertDictEqual(init_fn(1.), {'f': {'x': 1}})
        self.assertEqual(apply_fn({'f': {'x': 2.}}), 2.)

        def g(x):
            first = harvest.nest(f, scope='f1')(x + 1.)
            return harvest.nest(f, scope='f2')(first + 1.)

        # Two sibling scopes: each call of `f` gets its own sub-dictionary.
        init_fn, apply_fn = unzip_variable(g)(1.)
        expected_variables = {'f1': {'x': 2}, 'f2': {'x': 3.}}
        self.assertDictEqual(init_fn(1.), expected_variables)
        replacement_variables = {'f1': {'x': 4.}, 'f2': {'x': 100.}}
        self.assertEqual(apply_fn(replacement_variables), 100.)
コード例 #4
0
 def f(x):
     """Pass `x` through an identity function nested under 'foo', then halve it."""
     nested_identity = harvest.nest(lambda x: x, scope='foo')
     return nested_identity(x) / 2.
コード例 #5
0
 def g(x):
     """Apply `f` twice, under scopes 'f1' then 'f2', adding 1. before each call."""
     first = harvest.nest(f, scope='f1')(x + 1.)
     return harvest.nest(f, scope='f2')(first + 1.)