Пример #1
0
    def wrapped(init_key: Key, *args, **kwargs) -> FunctionModule:
        """Runs initialization for `f` and stages it into a FunctionModule.

        Args:
          init_key: key handed to the initializer and to `module.variable`
            when the results are wrapped.
          *args: positional arguments used to trace `f`.
          **kwargs: keyword arguments used to trace `f`.

        Returns:
          A `FunctionModule` whose variables are each wrapped with
          `module.variable` when `name` is None; otherwise the module itself
          wrapped with `module.variable` under `name`.
        """
        # NOTE(review): presumably True when `f` accepts a keyword argument
        # named `init_keyword`; confirm against kwargs_util.
        if kwargs_util.check_in_kwargs(f, init_keyword):
            # Split `f` into an initializer and an apply function; the custom
            # unzip rules are only active for the duration of the unzip.
            with unzip.new_custom_rules(custom_unzip_rules):

                def fun(init_key, *args, **kwargs):
                    return f(*args, **{**kwargs, init_keyword: init_key})

                init_f, apply_f = unzip.unzip(
                    fun, tag=module.VARIABLE)(init_key, *args, **kwargs)
            cau_f = functools.partial(
                harvest.harvest(apply_f, tag=module.ASSIGN), {})
        else:
            # No init key: initialization yields an empty dict and the apply
            # function returns `variables` unchanged next to f's output.
            def init_f(init_key, *args, **kwargs):
                del init_key, args, kwargs
                return {}

            def cau_f(variables, *args, **kwargs):
                return f(*args, **kwargs), variables

        if name is not None:
            init_f, cau_f = (harvest.nest(init_f, scope=name),
                             harvest.nest(cau_f, scope=name))

        variables = init_f(init_key)
        staged = trace_util.stage(cau_f, dynamic=True)
        cau_jaxpr, (in_tree, out_tree) = staged(variables, *args, **kwargs)

        if name is not None:
            named_module = FunctionModule(variables, cau_jaxpr, in_tree,
                                          out_tree, name=name)
            return module.variable(named_module, name=name, key=init_key)

        # Unnamed: tag every collected variable individually instead.
        tagged = {
            var_name: module.variable(value, name=var_name, key=init_key)
            for var_name, value in variables.items()
        }
        return FunctionModule(tagged, cau_jaxpr, in_tree, out_tree, name=name)
Пример #2
0
  def wrapped(init_key: Key, *args, **kwargs) -> FunctionModule:
    """Initializes `f` and stages it into a `FunctionModule`.

    Args:
      init_key: key passed to `f` under the `init_keyword` keyword (when the
        check below succeeds) and to `module.variable` when wrapping results.
      *args: positional arguments used to initialize and trace `f`.
      **kwargs: keyword arguments used to initialize and trace `f`.

    Returns:
      The `FunctionModule` built from the collected variables and the staged
      function. When `name` is None, each variable is wrapped with
      `module.variable`. When `name` is set and at least one variable was
      collected, the module itself is wrapped with `module.variable`;
      otherwise the bare module is returned.
    """
    # NOTE(review): presumably True when `f` accepts a keyword argument named
    # `init_keyword`; confirm against kwargs_util.
    has_init_key = kwargs_util.check_in_kwargs(f, init_keyword)
    if not has_init_key:

      def init_f(init_key, *args, **kwargs):
        del init_key, args, kwargs
        return {}

      def cau_f(variables, *args, **kwargs):
        return f(*args, **kwargs), variables
    else:

      def f_(init_key, *args, **kwargs):
        # Forward the key under the configured keyword name, not a hard-coded
        # `init_key=`: the presence check above uses `init_keyword`, so the
        # call must too. Merging into a dict also avoids a duplicate-keyword
        # TypeError when `kwargs` already carries that keyword.
        return f(*args, **{**kwargs, init_keyword: init_key})

      def init_f(init_key, *args, **kwargs):
        # NOTE(review): presumably returns the values `f_` tags with
        # module.VARIABLE; confirm the meaning of `exclusive=True` in harvest.
        return harvest.reap(
            f_, tag=module.VARIABLE, exclusive=True)(init_key, *args, **kwargs)

      def apply_f(variables, *args, **kwargs):
        # A fixed dummy key is passed here. It is presumably unused once the
        # variable values are planted; confirm in harvest.plant.
        return harvest.plant(
            f_, tag=module.VARIABLE)(variables, random.PRNGKey(0), *args,
                                     **kwargs)

      cau_f = functools.partial(harvest.harvest(apply_f, tag=module.ASSIGN), {})
    variables = init_f(init_key, *args, **kwargs)
    cau_jaxpr, (in_tree, out_tree) = trace_util.stage(
        cau_f, dynamic=True)(variables, *args, **kwargs)
    if name is None:
      variables = {
          k: module.variable(val, name=k, key=init_key)
          for k, val in variables.items()
      }
      return FunctionModule(variables, cau_jaxpr, in_tree, out_tree, name=name)
    mod = FunctionModule(variables, cau_jaxpr, in_tree, out_tree, name=name)
    if variables:
      return module.variable(mod, name=name, key=init_key)
    # No variables were collected, so there is nothing to tag.
    return mod
Пример #3
0
 def f(x, init_key=None):
     """Reads variable 'y' (zeros of x's shape) and assigns `y + 1.` to 'y'.

     Returns `primitive.tie_in(<assigned value>, x) + y`.
     """
     initial = np.zeros(x.shape)
     current = module.variable(initial, name='y', key=init_key)
     updated = module.assign(current + 1., name='y')
     # NOTE(review): tie_in presumably adds a data dependence on `updated` so
     # the assignment is kept in the traced output; confirm in primitive.
     tied = primitive.tie_in(updated, x)
     return tied + current
Пример #4
0
 def f(x, init_key=None):
     """Declares variable 'y' but assigns `y + 1.` under the name 'z'.

     Returns `x` plus the assigned value.
     """
     # NOTE(review): the assign targets 'z', not the declared 'y'. This is
     # presumably intentional (e.g. to exercise a mismatched-name path);
     # confirm with the surrounding tests.
     current = module.variable(np.zeros(x.shape), name='y', key=init_key)
     return x + module.assign(current + 1., name='z')
Пример #5
0
 def f(x, init_key=None):
     """Adds a variable 'y', initialized to ones of x's shape, to `x`."""
     ones = np.ones(x.shape)
     return x + module.variable(ones, name='y', key=init_key)
Пример #6
0
 def f(x, init_key=None):
     """Returns `np.dot(w, x)` for a variable 'w' sampled with `init_key`.

     'w' is initialized by `random.normal(init_key, x.shape)`.
     """
     sample = random.normal(init_key, x.shape)
     weights = module.variable(sample, name='w')
     return np.dot(weights, x)