예제 #1
0
파일: common.py 프로젝트: xueeinstein/jax
def _initial_style_open_jaxpr(fun: Callable,
                              in_tree,
                              in_avals,
                              primitive_name: Optional[str] = None):
    """Trace ``fun`` to a jaxpr and return it alongside its constants.

    Args:
      fun: The callable to trace.
      in_tree: Tree structure passed to ``flatten_fun_nokwargs`` and to
        ``pe.debug_info`` to describe ``fun``'s positional arguments.
      in_avals: Abstract values handed to ``pe.trace_to_jaxpr_dynamic``.
      primitive_name: Label recorded in the debug info; ``"<unknown>"`` is
        used when it is ``None`` or otherwise falsy.

    Returns:
      A ``(jaxpr, consts, out_tree)`` triple. The jaxpr is returned exactly
      as produced by the tracer, with the constants delivered separately.
    """
    flat_callable, out_tree_thunk = flatten_fun_nokwargs(
        lu.wrap_init(fun), in_tree)
    label = primitive_name or "<unknown>"
    dbg_info = pe.debug_info(fun, in_tree, False, label)
    traced = pe.trace_to_jaxpr_dynamic(flat_callable, in_avals, dbg_info)
    # The middle element of the trace result is not needed here.
    open_jaxpr, _, closure_consts = traced
    # The output-tree thunk is only invoked once tracing has finished.
    return open_jaxpr, closure_consts, out_tree_thunk()
예제 #2
0
 def fun_remat(*args, **kwargs):
   """Trace ``fun`` on the given arguments and bind the result to ``remat_p``.

   ``fun``, ``prevent_cse`` and ``policy`` are taken from the enclosing scope,
   which is not part of this snippet. Positional and keyword arguments are
   flattened together, the function is traced to a jaxpr, and the flat
   outputs of the bind are rebuilt into the function's output structure.
   """
   leaves, in_tree = tree_flatten((args, kwargs))
   flat_fun, out_tree_thunk = flatten_fun(lu.wrap_init(fun), in_tree)

   # One abstract value per flattened input leaf.
   in_avals = []
   for leaf in leaves:
     in_avals.append(core.raise_to_shaped(core.get_aval(leaf)))

   dbg_info = pe.debug_info(fun, in_tree, False, "checkpoint")
   open_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals,
                                                     dbg_info)

   # The constants are passed as leading operands, ahead of the real inputs.
   converted_jaxpr = pe.convert_constvars_jaxpr(open_jaxpr)
   out_flat = remat_p.bind(
       *consts, *leaves,
       jaxpr=converted_jaxpr,
       prevent_cse=prevent_cse,
       differentiated=False,
       policy=policy)
   return tree_unflatten(out_tree_thunk(), out_flat)
예제 #3
0
def make_jaxpr(fun: Callable, in_tree: PyTreeDef,
               in_avals: typing.Tuple[core.AbstractValue],  # with DBIdx in them
               keep_inputs: typing.Tuple[bool]
               ) -> typing.Tuple[core.Jaxpr, List[Any], PyTreeDef]:
  """Trace ``fun`` to a jaxpr whose constants are passed in as arguments.

  Args:
    fun: The callable to trace.
    in_tree: Tree structure of ``fun``'s inputs, used to flatten the call.
    in_avals: Abstract values for the flattened inputs.
    keep_inputs: Forwarded unchanged to ``pe.trace_to_jaxpr_dynamic``.

  Returns:
    ``(jaxpr, consts, out_tree)``, where ``jaxpr`` has been passed through
    ``pe.convert_constvars_jaxpr`` and every constant has been passed through
    ``_canonicalize_arg``.
  """
  flat_fun, out_tree_thunk = flatten_fun(lu.wrap_init(fun), in_tree)
  dbg_info = pe.debug_info(fun, in_tree, False, "dex_jit")
  traced = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, dbg_info,
                                     keep_inputs=keep_inputs)
  open_jaxpr, _, raw_consts = traced

  converted_jaxpr = pe.convert_constvars_jaxpr(open_jaxpr)

  canonical_consts = []
  for const in raw_consts:
    canonical_consts.append(_canonicalize_arg(const))

  return converted_jaxpr, canonical_consts, out_tree_thunk()
예제 #4
0
 def __call__(self, *args, **kwargs):
   """Trace ``self.fun`` on ``args`` and bind it through ``custom_vmap_p``.

   Keyword arguments are not supported, and the traced function must not
   produce any constants; both conditions are enforced with ``assert``.
   The bind receives the closed jaxpr, ``self.vmap_rule`` and the input
   tree structure, and its flat outputs are rebuilt into the output
   structure of ``self.fun``.
   """
   assert not kwargs
   leaves, in_tree = tree_flatten(args)
   flat_fun, out_tree_thunk = flatten_fun_nokwargs(
       lu.wrap_init(self.fun), in_tree)

   # One abstract value per flattened input leaf.
   in_avals = []
   for leaf in leaves:
     in_avals.append(core.raise_to_shaped(core.get_aval(leaf)))

   dbg_info = pe.debug_info(self.fun, in_tree, False, "custom_vmap")
   open_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals,
                                                     dbg_info)
   assert len(consts) == 0

   # The jaxpr is closed over an empty tuple of constants.
   converted_jaxpr = pe.convert_constvars_jaxpr(open_jaxpr)
   closed_call = core.ClosedJaxpr(converted_jaxpr, ())
   bind_params = {
       "call": closed_call,
       "rule": self.vmap_rule,
       "in_tree": in_tree,
   }
   out_flat = custom_vmap_p.bind(*consts, *leaves, **bind_params)
   return tree_unflatten(out_tree_thunk(), out_flat)
예제 #5
0
    def __call__(self, residual_arg, linear_arg):
        """Trace ``self.fun`` and bind it through ``custom_transpose_p``.

        Args:
          residual_arg: First argument to ``self.fun``; its tree structure is
            forwarded to the bind as ``res_tree``.
          linear_arg: Second argument to ``self.fun``; its tree structure is
            forwarded to the bind as ``lin_tree``.

        Returns:
          The flat outputs of the bind, rebuilt into the output structure of
          ``self.fun``.
        """
        # Only the tree structures are needed from the per-argument flattens;
        # the leaves come from flattening both arguments together.
        _, res_tree = tree_flatten(residual_arg)
        _, lin_tree = tree_flatten(linear_arg)
        leaves, in_tree = tree_flatten((residual_arg, linear_arg))

        wrapped = lu.wrap_init(self.fun)
        flat_fun, out_tree_thunk = flatten_fun_nokwargs(wrapped, in_tree)

        # One abstract value per flattened input leaf.
        in_avals = []
        for leaf in leaves:
            in_avals.append(core.raise_to_shaped(core.get_aval(leaf)))

        dbg_info = pe.debug_info(self.fun, in_tree, False, "custom_transpose")
        traced = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, dbg_info)
        open_jaxpr, _, consts = traced
        # The traced function must not have produced any constants.
        assert len(consts) == 0

        converted_jaxpr = pe.convert_constvars_jaxpr(open_jaxpr)
        closed_call = core.ClosedJaxpr(converted_jaxpr, ())

        # Read the output tree once; it is used for the bind and the result.
        result_tree = out_tree_thunk()
        out_flat = custom_transpose_p.bind(
            *consts,
            *leaves,
            call=closed_call,
            rule=self.transpose,
            lin_tree=lin_tree,
            res_tree=res_tree,
            out_tree=result_tree,
        )
        return tree_unflatten(result_tree, out_flat)