Exemple #1
0
def _initial_style_jaxprs_with_common_consts(funs: Sequence[Callable], in_tree,
                                             in_avals, primitive_name: str):
    """Stage ``funs`` into closed jaxprs that all accept the same constants.

    When the branches of a conditional are staged into jaxprs, each branch's
    constants are extracted and turned into jaxpr arguments, so the branches
    end up with different input signatures. A conditional *primitive* needs
    those signatures to match, so this "joins" the staged jaxprs: every
    returned jaxpr accepts the constants of *all* branches (branch by branch,
    in order) but only uses its own; the others are bound to fresh variables.

    Returns:
      ``(closed_jaxprs, consts, all_out_trees)``: one constant-free
      ``ClosedJaxpr`` per function whose leading inputs line up with
      ``consts`` (the concatenation of every branch's constants), and the
      output tree of each staged function.
    """
    staged = [_initial_style_open_jaxpr(fun, in_tree, in_avals, primitive_name)
              for fun in funs]
    jaxprs, all_consts, all_out_trees = unzip3(staged)

    # Placeholder variables (suffix '_') for each branch's constants; every
    # branch is padded with the placeholders of all the *other* branches.
    newvar = core.gensym(jaxprs, suffix='_')
    placeholder_vars = []
    for branch_consts in all_consts:
        branch_avals = map(_abstractify, branch_consts)
        placeholder_vars.append(map(newvar, branch_avals))

    def pad_constvars(i, jaxpr):
        # Surround this branch's constvars with the other branches'
        # placeholders so they line up one-to-one with the joined ``consts``.
        leading = util.concatenate(placeholder_vars[:i])
        trailing = util.concatenate(placeholder_vars[i + 1:])
        return jaxpr.replace(constvars=[*leading, *jaxpr.constvars, *trailing])

    consts = util.concatenate(all_consts)
    padded = [pad_constvars(i, jaxpr) for i, jaxpr in enumerate(jaxprs)]
    closed_jaxprs = [core.ClosedJaxpr(pe.convert_constvars_jaxpr(j), ())
                     for j in padded]
    return closed_jaxprs, consts, all_out_trees
Exemple #2
0
def _dynamic_jaxpr_process_xmap(self, primitive, f, tracers, params):
  """Stage an ``xmap`` call into the current dynamic jaxpr frame.

  The body ``f`` is traced on avals with the mapped axes deleted, inside an
  axis environment holding ``params['global_axis_sizes']``. Constants the
  body closes over become extra leading operands: their ``in_axes`` entry is
  an empty ``AxisNamePos`` and they are never donated.

  Returns:
    One new ``DynamicJaxprTracer`` per output, whose aval has the mapped
    axes re-inserted with local (per-resource) sizes.
  """
  from jax.interpreters.partial_eval import (
    DynamicJaxprTracer, convert_constvars_jaxpr, new_jaxpr_eqn,
    source_info_util, trace_to_subjaxpr_dynamic)
  assert primitive is xmap_p
  global_axis_sizes = params['global_axis_sizes']
  in_axes = params['in_axes']

  # Trace the body on avals with the mapped axes removed.
  mapped_in_avals = [_delete_aval_axes(t.aval, axes)
                     for t, axes in zip(tracers, in_axes)]
  with core.extend_axis_env_nd(global_axis_sizes.items()):
    jaxpr, mapped_out_avals, consts = trace_to_subjaxpr_dynamic(
        f, self.main, mapped_in_avals)
  out_axes = params['out_axes_thunk']()

  # Put the mapped axes back on the outputs, sized for the local resources.
  resource_count = _get_axis_resource_count(params['axis_resources'],
                                            params['resource_env'])
  local_axis_sizes = {name: resource_count[name].to_local(size)
                      for name, size in global_axis_sizes.items()}
  out_avals = [_insert_aval_axes(aval, axes, local_axis_sizes)
               for aval, axes in zip(mapped_out_avals, out_axes)]
  _check_out_avals_vs_out_axes(out_avals, out_axes, global_axis_sizes)

  source_info = source_info_util.current()
  out_tracers = [DynamicJaxprTracer(self, aval, source_info)
                 for aval in out_avals]
  invars = map(self.getvar, tracers)
  constvars = map(self.getvar, map(self.instantiate_const, consts))
  outvars = map(self.makevar, out_tracers)

  # Constants are prepended with an empty axis mapping and no donation.
  num_consts = len(consts)
  new_params = dict(
      params,
      in_axes=(AxisNamePos(user_repr='{}'),) * num_consts + in_axes,
      out_axes=out_axes,
      donated_invars=(False,) * num_consts + params['donated_invars'],
      call_jaxpr=convert_constvars_jaxpr(jaxpr))
  del new_params['out_axes_thunk']
  eqn = new_jaxpr_eqn([*constvars, *invars], outvars, primitive,
                      new_params, source_info)
  self.frame.eqns.append(eqn)
  return out_tracers
Exemple #3
0
def _sparsify_jaxpr(spenv, jaxpr, *argspecs):
    """Trace a sparsified version of the closed jaxpr ``jaxpr``.

    ``jaxpr`` is evaluated with ``eval_sparse`` on the given ``argspecs`` and
    the result is re-traced into a new jaxpr.

    Args:
      spenv: the sparse environment used to convert between argspecs and
        arrays.
      jaxpr: a closed jaxpr (its ``.jaxpr`` and ``.consts`` are read).
      *argspecs: argument specs for the jaxpr's inputs.

    Returns:
      ``(sp_jaxpr, out_tree)``: a ``ClosedJaxpr`` whose inputs are the
      flattened arrays of ``argspecs`` and which closes over any constants
      produced while tracing, and the pytree structure of its outputs.
    """
    # TODO(jakevdp): currently this approach discards all information about
    #   shared data & indices when generating the sparsified jaxpr. The
    #   current approach produces valid sparsified while loops, but they
    #   don't work in corner cases (see associated TODO in sparsify_test.py)
    out_tree = None

    @lu.wrap_init
    def wrapped(*args_flat):
        nonlocal out_tree
        args = tree_unflatten(in_tree, args_flat)
        traced_argspecs = arrays_to_argspecs(spenv, args)
        result = eval_sparse(jaxpr.jaxpr, jaxpr.consts, traced_argspecs, spenv)
        out = argspecs_to_arrays(spenv, result)
        out_flat, out_tree = tree_flatten(out)
        return out_flat

    args = argspecs_to_arrays(spenv, argspecs)
    args_flat, in_tree = tree_flatten(args)
    avals_flat = [
        core.raise_to_shaped(core.get_aval(arg)) for arg in args_flat
    ]
    sp_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped, avals_flat)
    # Keep the traced constants as constvars closed over by the ClosedJaxpr.
    # Lifting them into leading invars (convert_constvars_jaxpr) while also
    # passing ``consts`` would leave a jaxpr with no constvars but non-empty
    # consts, and with more invars than the flattened ``argspecs``.
    sp_jaxpr = pe.ClosedJaxpr(sp_jaxpr, consts)
    return sp_jaxpr, out_tree
Exemple #4
0
def _custom_derivative_call_jaxpr_callback_rule(primitive, trace, *tracers,
                                                fun_jaxpr, num_consts, **params):
  """Callback-trace rule shared by the custom_jvp/custom_vjp call primitives.

  The primal ``fun_jaxpr`` and the jvp/fwd jaxpr thunk are both rewritten
  with ``callback_jaxpr`` using ``trace.callback``. For custom_vjp, ``bwd``
  is additionally wrapped with ``callback_subtrace``. ``primitive`` is then
  re-bound on the tracers' values. Constants produced by rewriting the primal
  jaxpr are prepended to the operands and added to ``num_consts``.

  Raises:
    NotImplementedError: if ``primitive`` is neither
      ``cd.custom_jvp_call_jaxpr_p`` nor ``cd.custom_vjp_call_jaxpr_p``.
  """
  main = trace.main
  vals = [t.val for t in tracers]
  rewritten_fun = callback_jaxpr(fun_jaxpr, trace.callback,
                                 strip_calls=trace.strip_calls)

  if primitive == cd.custom_jvp_call_jaxpr_p:
    thunk_name = 'jvp_jaxpr_thunk'
  elif primitive == cd.custom_vjp_call_jaxpr_p:
    thunk_name = 'fwd_jaxpr_thunk'
    params['bwd'] = callback_subtrace(params['bwd'], main)
  else:
    raise NotImplementedError(primitive)

  orig_thunk = params.pop(thunk_name)

  @pe._memoize
  def new_thunk():
    # Rewrite the derivative jaxpr lazily, only when the thunk is forced.
    rewritten = callback_jaxpr(core.ClosedJaxpr(*orig_thunk()),
                               trace.callback, trace.strip_calls)
    return rewritten.jaxpr, rewritten.literals

  params[thunk_name] = new_thunk
  extra_consts = rewritten_fun.literals
  # Lift the rewritten jaxpr's constants into leading inputs.
  closed_fun_jaxpr = core.ClosedJaxpr(
      pe.convert_constvars_jaxpr(rewritten_fun.jaxpr), ())
  out = primitive.bind(*extra_consts, *vals, fun_jaxpr=closed_fun_jaxpr,
                       num_consts=len(extra_consts) + num_consts, **params)
  return safe_map(trace.pure, out)
Exemple #5
0
    def trace_to_jaxpr_finalize(in_tracers,
                                out_tracers,
                                trace,
                                instantiate=True):
        """Build a ``TypedJaxpr`` from traced inputs and outputs.

        The outputs are lowered, raised into ``trace`` and instantiated
        (every output uses the same ``instantiate`` flag). Then they are
        converted into a jaxpr together with ``in_tracers``.

        Returns:
          ``(typed_jaxpr, consts)``: a ``TypedJaxpr`` whose inputs are the
          avals of ``consts`` followed by the input avals, and the constants
          to pass as its leading arguments.
        """
        # TODO: This is the final part of the partial_eval.trace_to_subjaxpr. Share.
        raised = [trace.full_raise(core.full_lower(t)) for t in out_tracers]
        finished = [pe.instantiate_const_at(trace, instantiate, t)
                    for t in raised]
        jaxpr, consts, env = pe.tracers_to_jaxpr(in_tracers, finished)
        # TODO: this is from partial_eval.trace_to_jaxpr. Share.
        assert not env

        # TODO: this is from the final part of lax_control_flow._initial_style_jaxpr
        out_avals = [abstract_arrays.raise_to_shaped(pv)
                     for pv, _ in (t.pval for t in finished)]
        const_avals = tuple(abstract_arrays.raise_to_shaped(core.get_aval(c))
                            for c in consts)
        in_avals = tuple(abstract_arrays.raise_to_shaped(pv)
                         for pv, _ in (t.pval for t in in_tracers))

        lifted = pe.convert_constvars_jaxpr(jaxpr)
        typed_jaxpr = core.TypedJaxpr(lifted, (), const_avals + in_avals,
                                      out_avals)
        return typed_jaxpr, consts
Exemple #6
0
def remat_transpose(reduce_axes, out_cts, *in_primals, jaxpr, **params):
  """Transpose rule for ``remat_p``.

  The transposition is traced into a jaxpr and bound as a new ``remat_p``
  call. Inside that trace, the primal jaxpr is partially evaluated with the
  undefined primals unknown, and then ``ad.backward_pass`` runs on the result.
  Inputs that are not undefined primals get ``ad_util.Zero`` cotangents.

  Returns:
    Cotangents for ``in_primals``, in the same pytree structure.
  """
  assert not jaxpr.constvars
  # A function object used as a mutable slot: the output treedef is only
  # known after ``transposed`` has been traced.
  cell = lambda: None

  @lu.wrap_init
  def transposed(*flat_args):
    primals, cts = tree_unflatten(in_treedef, flat_args)
    pvals = [pe.PartialVal.unknown(p.aval) if ad.is_undefined_primal(p)
             else pe.PartialVal.known(p) for p in primals]
    primal_fun = lu.wrap_init(partial(core.eval_jaxpr, jaxpr, ()))
    lin_jaxpr, _, lin_consts = pe.trace_to_jaxpr_nounits(primal_fun, pvals,
                                                         False)
    undefined = [ad.UndefinedPrimal(v.aval) for v in lin_jaxpr.invars]
    lin_cts = iter(ad.backward_pass(lin_jaxpr, reduce_axes, False, lin_consts,
                                    undefined, cts))
    primal_cts = [next(lin_cts) if ad.is_undefined_primal(p)
                  else ad_util.Zero(p.aval) for p in primals]
    assert next(lin_cts, None) is None
    flat_cts, cell.treedef = tree_flatten(primal_cts)
    return flat_cts

  flat_args, in_treedef = tree_flatten((in_primals, out_cts))
  in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in flat_args]
  t_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(transposed, in_avals)
  out = remat_p.bind(*consts, *flat_args,
                     jaxpr=pe.convert_constvars_jaxpr(t_jaxpr), **params)
  return tree_unflatten(cell.treedef, out)  # type: ignore
Exemple #7
0
def _dynamic_jaxpr_process_xmap(self, primitive, f, tracers, params):
  """Stage an ``xmap`` call into the current dynamic jaxpr frame.

  The body ``f`` is traced on avals with the mapped axes deleted, inside an
  axis environment holding ``params['axis_sizes']``. Constants the body
  closes over become extra leading operands with ``in_axes`` entry ``None``.
  If ``call_param_updaters`` has an entry for ``primitive``, it is applied
  to the new params.

  Returns:
    One new ``DynamicJaxprTracer`` per output, whose aval has the mapped
    axes re-inserted using ``params['axis_sizes']``.
  """
  from jax.interpreters.partial_eval import (
    DynamicJaxprTracer, call_param_updaters, convert_constvars_jaxpr,
    new_jaxpr_eqn, source_info_util, trace_to_subjaxpr_dynamic)
  assert primitive is xmap_p
  axis_sizes = params['axis_sizes']
  in_axes = params['in_axes']

  mapped_in_avals = [_delete_aval_axes(t.aval, axes)
                     for t, axes in zip(tracers, in_axes)]
  with core.extend_axis_env_nd(axis_sizes.items()):
    jaxpr, mapped_out_avals, consts = trace_to_subjaxpr_dynamic(
        f, self.main, mapped_in_avals)
  out_axes = params['out_axes_thunk']()
  out_avals = [_insert_aval_axes(aval, axes, axis_sizes)
               for aval, axes in zip(mapped_out_avals, out_axes)]

  source_info = source_info_util.current()
  out_tracers = [DynamicJaxprTracer(self, aval, source_info)
                 for aval in out_avals]
  invars = map(self.getvar, tracers)
  constvars = map(self.getvar, map(self.instantiate_const, consts))
  outvars = map(self.makevar, out_tracers)

  # Constants become leading operands that are not mapped over any axis.
  new_params = dict(params,
                    in_axes=(None,) * len(consts) + in_axes,
                    out_axes=out_axes,
                    call_jaxpr=convert_constvars_jaxpr(jaxpr))
  del new_params['out_axes_thunk']
  updater = call_param_updaters.get(primitive)
  if updater:
    new_params = updater(new_params, [True] * len(tracers))
  eqn = new_jaxpr_eqn([*constvars, *invars], outvars, primitive,
                      new_params, source_info)
  self.frame.eqns.append(eqn)
  return out_tracers
Exemple #8
0
 def _bound_output_tracers(self, primitive, params, jaxpr, consts, env,
                           in_tracers, out_pvs, out_consts, out_keys, name,
                           is_map):
     """Takes a traced function and binds the Jaxpr to output tracers.

     The jaxpr's constvars are lifted into leading invars. The operands of
     the resulting equation recipe are the instantiated constants, then the
     ``env`` tracers, then ``in_tracers``. Every returned ``UnzipTracer`` has
     its ``recipe`` set to that equation.
     """
     lifted_jaxpr = pe.convert_constvars_jaxpr(jaxpr)
     const_tracers = safe_map(self.new_instantiated_const, consts)
     env_tracers = safe_map(self.instantiate_const, env)
     out_tracers = [UnzipTracer(self, pe.PartialVal((pv, const)), None, key)
                    for pv, const, key
                    in safe_zip(out_pvs, out_consts, out_keys)]
     new_params = dict(params, name=name, call_jaxpr=lifted_jaxpr)
     if 'donated_invars' in params:
         # Constant and env operands are never donated; only the flags of
         # input tracers whose pval is unknown are kept.
         num_closed = len(const_tracers) + len(env_tracers)
         kept_flags = tuple(donated for donated, t
                            in zip(params['donated_invars'], in_tracers)
                            if not t.pval.is_known())
         new_params['donated_invars'] = (False,) * num_closed + kept_flags
     if is_map:
         out_axes = params['out_axes_thunk']()
         assert all(axis == 0 for axis in out_axes)
         new_params['out_axes'] = (0,) * len(out_tracers)
         del new_params['out_axes_thunk']
     operands = tuple(const_tracers + env_tracers + in_tracers)
     eqn = pe.new_eqn_recipe(operands, out_tracers, primitive, new_params,
                             source_info_util.current())  # pytype: disable=wrong-arg-types
     for tracer in out_tracers:
         tracer.recipe = eqn
     return out_tracers
Exemple #9
0
def _initial_style_jaxpr(fun: Callable,
                         in_tree,
                         in_avals,
                         primitive_name: Optional[str] = None):
    """Stage ``fun`` into a closed jaxpr that takes its constants as inputs.

    Returns:
      ``(closed_jaxpr, consts, out_tree)`` where ``closed_jaxpr`` closes over
      nothing. Its leading inputs correspond to ``consts``.
    """
    open_jaxpr, consts, out_tree = _initial_style_open_jaxpr(
        fun, in_tree, in_avals, primitive_name)
    lifted = pe.convert_constvars_jaxpr(open_jaxpr)
    return core.ClosedJaxpr(lifted, ()), consts, out_tree
Exemple #10
0
def make_transpose_from_thunk(thunk, lin_tree):
  """Build a transpose function from a thunk returning ``(jaxpr, consts)``.

  The thunk is called immediately. The returned ``transpose(res_arg, ct_out)``
  flattens its arguments, evaluates the jaxpr with the thunk's constants
  prepended, and unflattens the result with ``lin_tree``.
  """
  jaxpr, consts = thunk()
  closed = core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ())

  def transpose(res_arg, ct_out):
    flat_args = tree_leaves((res_arg, ct_out))
    flat_cts = core.jaxpr_as_fun(closed)(*consts, *flat_args)
    return tree_unflatten(lin_tree, flat_cts)

  return transpose
Exemple #11
0
 def trace_to_jaxpr_finalize(in_tracers, out_tracers, trace, instantiate=True):
   """Build a constant-free ``ClosedJaxpr`` from traced inputs and outputs.

   The outputs are lowered, raised into ``trace`` and instantiated with the
   same ``instantiate`` flag each.

   Returns:
     ``(closed_jaxpr, consts)`` where ``consts`` are the leading inputs of
     ``closed_jaxpr``.
   """
   # TODO: This is the final part of the partial_eval.trace_to_subjaxpr. Share.
   raised = [trace.full_raise(core.full_lower(t)) for t in out_tracers]
   finished = [pe.instantiate_const_at(trace, instantiate, t) for t in raised]
   jaxpr, consts, env = pe.tracers_to_jaxpr(in_tracers, finished)
   assert not env  # TODO: this is from partial_eval.trace_to_jaxpr. Share.
   lifted = pe.convert_constvars_jaxpr(jaxpr)
   return core.ClosedJaxpr(lifted, ()), consts
Exemple #12
0
 def fun_remat(*args, **kwargs):
   """Call ``fun`` through a single ``remat_p`` bind.

   ``fun`` is traced to a jaxpr. Its constants are passed as leading operands
   ahead of the flattened ``(args, kwargs)``.
   """
   flat_args, in_tree = tree_flatten((args, kwargs))
   flat_fun, out_tree = flatten_fun(lu.wrap_init(fun), in_tree)
   avals = [core.raise_to_shaped(core.get_aval(a)) for a in flat_args]
   debug = pe.debug_info(fun, in_tree, False, "checkpoint")
   jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, avals, debug)
   lifted = pe.convert_constvars_jaxpr(jaxpr)
   flat_out = remat_p.bind(*consts, *flat_args, jaxpr=lifted,
                           prevent_cse=prevent_cse, differentiated=False,
                           policy=policy)
   return tree_unflatten(out_tree(), flat_out)
Exemple #13
0
def make_jaxpr(fun: Callable, in_tree: PyTreeDef,
               in_avals: typing.Tuple[core.AbstractValue, ...],  # with DBIdx in them
               keep_inputs: typing.Tuple[bool, ...]
               ) -> typing.Tuple[core.Jaxpr, List[Any], PyTreeDef]:
  """Trace ``fun`` for the Dex JIT into a jaxpr whose constants are inputs.

  Args:
    fun: the Python callable to trace; its arguments are described by
      ``in_tree``.
    in_tree: pytree structure of the arguments.
    in_avals: abstract values of the flattened arguments (may contain DBIdx).
      This is variable length, hence ``Tuple[..., ...]`` rather than a 1-tuple.
    keep_inputs: forwarded as ``keep_inputs`` to
      ``pe.trace_to_jaxpr_dynamic``; presumably one flag per flattened
      argument (TODO confirm).

  Returns:
    ``(jaxpr, consts, out_tree)``: the jaxpr with its constvars lifted into
    leading invars, the constants canonicalized via ``_canonicalize_arg``,
    and the output pytree structure.
  """
  flat_fun, out_tree = flatten_fun(lu.wrap_init(fun), in_tree)
  debug = pe.debug_info(fun, in_tree, False, "dex_jit")
  jaxpr_, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug,
                                                keep_inputs=keep_inputs)
  jaxpr = pe.convert_constvars_jaxpr(jaxpr_)
  consts = [_canonicalize_arg(c) for c in consts]
  return jaxpr, consts, out_tree()
Exemple #14
0
 def __call__(self, *args, **kwargs):
   """Apply ``self.fun`` through ``custom_vmap_p`` with ``self.vmap_rule``.

   Keyword arguments are not supported, and the traced function must not
   produce any constants (both are asserted).
   """
   assert not kwargs
   flat_args, in_tree = tree_flatten(args)
   flat_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(self.fun), in_tree)
   avals = [core.raise_to_shaped(core.get_aval(a)) for a in flat_args]
   debug = pe.debug_info(self.fun, in_tree, False, "custom_vmap")
   jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, avals, debug)
   assert len(consts) == 0
   call = core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ())
   flat_out = custom_vmap_p.bind(*consts, *flat_args, call=call,
                                 rule=self.vmap_rule, in_tree=in_tree)
   return tree_unflatten(out_tree(), flat_out)
Exemple #15
0
def remat_jvp(primals, tangents, jaxpr, prevent_cse, differentiated, policy):
  """JVP rule for ``remat_p``.

  The body is differentiated with ``ad.jvp_jaxpr``. The JVP jaxpr is then
  bound as a new ``remat_p`` call on the primals plus only the non-symbolic-
  zero tangents. Output tangents that ``jvp_jaxpr`` reports as zero are
  rebuilt as ``ad_util.Zero`` values.
  """
  assert not jaxpr.constvars
  nonzeros = [type(t) is not ad_util.Zero for t in tangents]
  closed_jvp, out_nonzeros = ad.jvp_jaxpr(core.ClosedJaxpr(jaxpr, ()),
                                          nonzeros, False)
  nz_tangents = [t for t, nz in zip(tangents, nonzeros) if nz]
  outs = remat_p.bind(
      *closed_jvp.consts, *primals, *nz_tangents,
      jaxpr=pe.convert_constvars_jaxpr(closed_jvp.jaxpr),
      prevent_cse=prevent_cse, differentiated=differentiated, policy=policy)
  out_primals, nz_out_tangents = split_list(outs, [len(jaxpr.outvars)])
  nz_iter = iter(nz_out_tangents)
  out_tangents = [next(nz_iter) if nz else ad_util.Zero.from_value(p)
                  for p, nz in zip(out_primals, out_nonzeros)]
  return out_primals, out_tangents
Exemple #16
0
def custom_jvp_call_jaxpr(fun, jvp, *args):
    """A convenience wrapper to apply the custom_jvp_call_jaxpr primitive.

    ``fun`` is staged eagerly. ``jvp`` is staged inside a thunk wrapped by
    ``pe._memoize``, with ``in_avals * 2`` as its input avals (presumably
    primals followed by tangents).
    """
    in_avals = [abstract_arrays.raise_to_shaped(jax_core.get_aval(arg))
                for arg in args]
    # consts can be tracers!
    fun_jaxpr, consts = cd._initial_style_jaxpr(fun, in_avals)  # pylint: disable=protected-access
    lifted = pe.convert_constvars_jaxpr(fun_jaxpr)
    closed_fun_jaxpr = jax_core.ClosedJaxpr(lifted, ())

    def stage_jvp():
        return cd._initial_style_jaxpr(jvp, in_avals * 2)  # pylint: disable=protected-access

    jvp_jaxpr_thunk = pe._memoize(stage_jvp)  # pylint: disable=protected-access
    return cd.custom_jvp_call_jaxpr_p.bind(
        *consts, *args,
        fun_jaxpr=closed_fun_jaxpr,
        jvp_jaxpr_thunk=jvp_jaxpr_thunk,
        num_consts=len(consts))
Exemple #17
0
def custom_vjp_call_jaxpr(fun, fwd, bwd, *args, out_trees):
    """A convenience wrapper to apply the custom_vjp_call_jaxpr primitive.

    ``fun`` is staged eagerly. ``fwd`` is staged inside a thunk wrapped by
    ``pe._memoize``. ``bwd`` and ``out_trees`` are passed through to the
    primitive unchanged.
    """
    in_avals = [abstract_arrays.raise_to_shaped(jax_core.get_aval(arg))
                for arg in args]
    # consts can be tracers!
    fun_jaxpr, consts = cd._initial_style_jaxpr(fun, in_avals)  # pylint: disable=protected-access
    lifted = pe.convert_constvars_jaxpr(fun_jaxpr)
    closed_fun_jaxpr = jax_core.ClosedJaxpr(lifted, ())

    def stage_fwd():
        return cd._initial_style_jaxpr(fwd, in_avals)  # pylint: disable=protected-access

    fwd_jaxpr_thunk = pe._memoize(stage_fwd)  # pylint: disable=protected-access
    return cd.custom_vjp_call_jaxpr_p.bind(
        *consts, *args,
        fun_jaxpr=closed_fun_jaxpr,
        fwd_jaxpr_thunk=fwd_jaxpr_thunk,
        bwd=bwd,
        out_trees=out_trees,
        num_consts=len(consts))
Exemple #18
0
    def __call__(self, residual_arg, linear_arg):
        """Apply ``self.fun`` through ``custom_transpose_p``.

        The residual and linear arguments are flattened together. Their
        separate pytree structures are passed to the primitive as
        ``res_tree`` and ``lin_tree``, and ``self.transpose`` is passed as
        the rule. The traced function must not produce constants (asserted).
        """
        _, res_tree = tree_flatten(residual_arg)
        _, lin_tree = tree_flatten(linear_arg)
        flat_args, in_tree = tree_flatten((residual_arg, linear_arg))

        flat_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(self.fun),
                                                  in_tree)
        avals = [core.raise_to_shaped(core.get_aval(a)) for a in flat_args]
        debug = pe.debug_info(self.fun, in_tree, False, "custom_transpose")
        jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, avals, debug)
        assert len(consts) == 0
        call = core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ())
        flat_out = custom_transpose_p.bind(
            *consts, *flat_args,
            call=call,
            rule=self.transpose,
            lin_tree=lin_tree,
            res_tree=res_tree,
            out_tree=out_tree())
        return tree_unflatten(out_tree(), flat_out)
def _close_jaxpr(jaxpr):
    """Wrap ``jaxpr`` in a ``ClosedJaxpr``, lifting its constvars to invars."""
    lifted = pe.convert_constvars_jaxpr(jaxpr)
    return core.ClosedJaxpr(lifted, ())
Exemple #20
0
def _closed_call_transpose(params, jaxpr, args, ct, cts_in_avals, reduce_axes):
  """Transpose ``closed_call_p`` by delegating to ``call_transpose``.

  The closed jaxpr's constvars are lifted into leading invars, and its
  constants are passed ahead of ``args``.
  """
  lifted = pe.convert_constvars_jaxpr(jaxpr.jaxpr)
  operands = (*jaxpr.consts, *args)
  return call_transpose(core.closed_call_p, params, lifted, operands,
                        ct, cts_in_avals, reduce_axes)
Exemple #21
0
def _flat_initial_style_jaxpr(fun: Callable, in_avals):
    """lax_control_flow._initial_style_jaxpr, but for flat arguments and results.

    Returns ``(closed_jaxpr, consts)``. ``closed_jaxpr`` takes ``consts`` as
    leading inputs instead of closing over them.
    """
    jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(fun, in_avals)
    return ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ()), consts
Exemple #22
0
def _flat_initial_style_jaxpr(fun, in_avals):
    """lax_control_flow._initial_style_jaxpr, but for flat arguments and results.

    Returns ``(typed_jaxpr, consts)``. The ``TypedJaxpr`` takes the abstracted
    ``consts`` followed by ``in_avals`` as inputs, and its output avals are
    the traced ones raised to shaped.
    """
    jaxpr, out_avals, consts = _instantiated_trace_to_jaxpr(fun, in_avals)
    lifted = convert_constvars_jaxpr(jaxpr)
    typed = TypedJaxpr(lifted, (),
                       in_avals=_abstractified(consts) + in_avals,
                       out_avals=map(raise_to_shaped, out_avals))
    return typed, consts