Example #1
File: common.py Project: xueeinstein/jax
def _initial_style_jaxprs_with_common_consts(funs: Sequence[Callable], in_tree,
                                             in_avals, primitive_name: str):
    # When staging the branches of a conditional into jaxprs, constants are
    # extracted from each branch and converted to jaxpr arguments. To use the
    # staged jaxprs as the branches to a conditional *primitive*, we need for
    # their (input) signatures to match. This function "joins" the staged jaxprs:
    # for each one, it makes another that accepts *all* constants, but only uses
    # those that it needs (dropping the rest).

    jaxprs, all_consts, all_out_trees = \
        unzip3(_initial_style_open_jaxpr(fun, in_tree, in_avals, primitive_name)
               for fun in funs)

    newvar = core.gensym(jaxprs, suffix='_')
    all_const_avals = [map(_abstractify, consts) for consts in all_consts]
    unused_const_vars = [
        map(newvar, const_avals) for const_avals in all_const_avals
    ]

    def pad_jaxpr_constvars(i, jaxpr):
        prefix = util.concatenate(unused_const_vars[:i])
        suffix = util.concatenate(unused_const_vars[i + 1:])
        constvars = [*prefix, *jaxpr.constvars, *suffix]
        return jaxpr.replace(constvars=constvars)

    consts = util.concatenate(all_consts)
    jaxprs = [pad_jaxpr_constvars(i, jaxpr) for i, jaxpr in enumerate(jaxprs)]
    closed_jaxprs = [
        core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ())
        for jaxpr in jaxprs
    ]
    return closed_jaxprs, consts, all_out_trees
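The comment in this example describes a "join": each staged branch is padded so that every branch accepts the union of all branches' constants but only uses its own. A minimal plain-Python sketch of that idea (hypothetical names, no JAX internals) might look like:

def join_branches(branches, all_consts):
    # all_consts[i] holds the constants that branches[i] actually needs.
    n_total = sum(len(c) for c in all_consts)
    def pad(i, branch):
        start = sum(len(c) for c in all_consts[:i])
        stop = start + len(all_consts[i])
        def padded(*consts_and_args):
            own = consts_and_args[start:stop]   # this branch's own constants
            args = consts_and_args[n_total:]    # the shared "real" arguments
            return branch(*own, *args)          # every other constant is dropped
        return padded
    # All returned functions now share one signature:
    # (all constants of all branches, followed by the original arguments).
    return [pad(i, b) for i, b in enumerate(branches)]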
Example #2
File: callback.py Project: xueeinstein/jax
def callback_jaxpr(closed_jaxpr, callback, strip_calls):
  fun = lu.wrap_init(jaxpr_as_fun(closed_jaxpr))
  fun = callback_subtrace(fun)
  fun = _callback_fun(fun, callback, strip_calls)
  avals_in = closed_jaxpr.in_avals
  jaxpr_out, consts = cd._initial_style_jaxpr(fun, avals_in)
  return core.ClosedJaxpr(jaxpr_out, consts)
Example #3
def remat_transpose(params, call_jaxpr, primals_in, cotangents_in,
                    cotangent_in_avals, reduce_axes):
    # backward_pass can only transpose linear computations, but the call_jaxpr embedded in
    # remat contains primal (non-linear) equations too. Hence, we have to eliminate those
    # (in this case via partial_eval) before we call into backward_pass again.
    typed_call_jaxpr = core.ClosedJaxpr(call_jaxpr, [])
    unknowns = map(is_undefined_primal, primals_in)
    primal_jaxpr, tangent_jaxpr, out_unknowns = \
      pe.partial_eval_jaxpr(typed_call_jaxpr, unknowns=unknowns, instantiate=True)  # type: ignore

    def do_transpose(primals_in, cotangents_in):
        # NOTE: This is passing in undefined primals in place of tangent arguments, but it
        #       should all work out, because we're only computing the primal part here.
        residuals = core.jaxpr_as_fun(primal_jaxpr)(
            *primals_in)[len(cotangents_in):]
        # Now that we have a purely linear jaxpr, we can transpose it
        cotangents_out = backward_pass(tangent_jaxpr.jaxpr, reduce_axes, (),
                                       primals_in + residuals, cotangents_in)
        # backward_pass will return cotangents computed for all invars, but some of them
        # are residuals appended by partial eval, so we need to skip those before we return.
        return cotangents_out[:len(primals_in)]

    flat_args, in_tree_def = tree_flatten((primals_in, cotangents_in))
    flat_do_transpose, out_tree = flatten_fun_nokwargs(
        lu.wrap_init(do_transpose), in_tree_def)
    flat_cotangents_out = pe.remat_call_p.bind(flat_do_transpose, *flat_args,
                                               **params)
    return tree_unflatten(out_tree(), flat_cotangents_out)
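The comments above rely on the fact that only linear computations can be transposed, which is why the primal (non-linear) equations are partially evaluated away first. The same split is exposed by JAX's public API: jax.linearize separates a function into a primal result and a function that is linear in its tangent argument, and only that linear part is accepted by jax.linear_transpose. A small self-contained illustration (public API only, not the remat code path above):

import jax
import jax.numpy as jnp

def f(x):
  return jnp.sin(x) * x                     # non-linear in x

x = jnp.float32(1.0)
y, f_lin = jax.linearize(f, x)              # f_lin is linear in its tangent argument
ct_x, = jax.linear_transpose(f_lin, x)(jnp.float32(1.0))  # transpose the linear part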
Example #4
 def batched_fwd_jaxpr_thunk():
     fwd_jaxpr = core.ClosedJaxpr(
         *fwd_jaxpr_thunk())  # consts can be tracers
     batched_fwd_jaxpr, out_batched = batching.batch_jaxpr(
         fwd_jaxpr, axis_size, args_batched, False, axis_name, main_type)
     out_dims2.append([0 if b else not_mapped for b in out_batched])
     return batched_fwd_jaxpr.jaxpr, batched_fwd_jaxpr.consts
Example #5
File: callback.py Project: xueeinstein/jax
def _custom_derivative_call_jaxpr_callback_rule(primitive, trace, *tracers,
                                                fun_jaxpr, num_consts, **params):
  main = trace.main
  vals = [t.val for t in tracers]

  new_closed_jaxpr = callback_jaxpr(fun_jaxpr, trace.callback, strip_calls=trace.strip_calls)
  if primitive == cd.custom_jvp_call_jaxpr_p:
    thunk_name = 'jvp_jaxpr_thunk'
  elif primitive == cd.custom_vjp_call_jaxpr_p:
    thunk_name = 'fwd_jaxpr_thunk'
    params['bwd'] = callback_subtrace(params['bwd'], main)
  else:
    raise NotImplementedError(primitive)

  thunk = params.pop(thunk_name)
  @pe._memoize
  def new_thunk():
    thunk_jaxpr = core.ClosedJaxpr(*thunk())
    closed_jaxpr = callback_jaxpr(thunk_jaxpr, trace.callback, trace.strip_calls)
    return closed_jaxpr.jaxpr, closed_jaxpr.literals

  params[thunk_name] = new_thunk
  new_fun_jaxpr, new_consts = new_closed_jaxpr.jaxpr, new_closed_jaxpr.literals
  closed_fun_jaxpr = core.ClosedJaxpr(pe.convert_constvars_jaxpr(new_fun_jaxpr), ())
  new_num_consts = len(new_consts) + num_consts
  out = primitive.bind(*it.chain(new_consts, vals), fun_jaxpr=closed_fun_jaxpr,
                       num_consts=new_num_consts, **params)
  return safe_map(trace.pure, out)
Example #6
def rearrange_binders(jaxpr: core.ClosedJaxpr, primals_in, tangents_in, primals_out, tangents_out):
  new_invars = _perm(primals_in, tangents_in, jaxpr.jaxpr.invars)
  new_outvars = _perm(primals_out, tangents_out, jaxpr.jaxpr.outvars)
  new_jaxpr = core.Jaxpr(jaxpr.jaxpr.constvars,
                         new_invars, new_outvars, jaxpr.jaxpr.eqns,
                         jaxpr.jaxpr.effects)
  return core.ClosedJaxpr(new_jaxpr, jaxpr.consts)
Example #7
def _for_jvp(primals, tangents, *, jaxpr, nsteps, reverse, which_linear):
    nonzero_tangents = [not isinstance(t, ad_util.Zero) for t in tangents]
    # We need to find out which `Ref`s have nonzero tangents after running the
    # for loop. Ordinarily we do this with a fixed point on the body jaxpr but
    # a `for` body jaxpr is stateful and has no outputs. We therefore discharge
    # the state effect from the jaxpr and we will now have a "symmetric" jaxpr
    # where the inputs line up with the outputs. We use this discharged jaxpr
    # for the fixed point.
    discharged_jaxpr, body_consts = discharge_state(jaxpr, ())
    for _ in range(len(nonzero_tangents)):
        _, out_nonzero_tangents = ad.jvp_jaxpr(
            core.ClosedJaxpr(discharged_jaxpr, body_consts),
            [False] + nonzero_tangents, instantiate=nonzero_tangents)
        if out_nonzero_tangents == nonzero_tangents:
            break
        nonzero_tangents = map(operator.or_, nonzero_tangents,
                               out_nonzero_tangents)
    else:
        raise Exception("Invalid fixpoint")
    tangents = [
        ad.instantiate_zeros(t) if inst else t
        for t, inst in zip(tangents, nonzero_tangents)
    ]
    tangents = [t for t in tangents if type(t) is not ad_util.Zero]
    closed_jaxpr = core.ClosedJaxpr(jaxpr, ())
    jvp_jaxpr_, _ = ad.jvp_jaxpr(closed_jaxpr, [False] + nonzero_tangents, [])
    jvp_jaxpr, jvp_consts = jvp_jaxpr_.jaxpr, jvp_jaxpr_.consts
    jvp_which_linear = ((False, ) * len(jvp_consts) + which_linear +
                        (True, ) * len(tangents))
    out_flat = for_p.bind(*jvp_consts,
                          *primals,
                          *tangents,
                          jaxpr=jvp_jaxpr,
                          nsteps=nsteps,
                          reverse=reverse,
                          which_linear=jvp_which_linear)
    # `out_flat` includes constant inputs into the `for_loop` which are
    # converted into outputs as well. We don't care about these in AD so we
    # throw them out.
    _, out_primals, out_tangents = split_list(
        out_flat, [len(jvp_consts), len(primals)])
    out_tangents_iter = iter(out_tangents)
    out_tangents = [
        next(out_tangents_iter) if nz else ad_util.Zero.from_value(p)
        for p, nz in zip(out_primals, nonzero_tangents)
    ]
    return out_primals, out_tangents
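The loop at the top of this example is a fixed-point computation: the set of nonzero tangents is widened until one pass through the discharged body jaxpr discovers nothing new, and the loop bound caps the number of passes since each productive pass flips at least one entry from False to True. The same structure as a standalone plain-Python sketch (propagate is a hypothetical stand-in for the ad.jvp_jaxpr call):

def nonzero_fixpoint(propagate, nonzeros):
    for _ in range(len(nonzeros)):
        out_nonzeros = propagate(nonzeros)
        if out_nonzeros == nonzeros:
            break
        nonzeros = [a or b for a, b in zip(nonzeros, out_nonzeros)]
    else:
        raise Exception("Invalid fixpoint")
    return nonzeros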
Example #8
def checkify_fun_to_jaxpr(f, error, enabled_errors, in_avals):
    f, msgs = checkify_subtrace(f)
    f = checkify_traceable(f, tuple(error.msgs.items()), enabled_errors)
    err_aval = core.raise_to_shaped(core.get_aval(error.err))
    code_aval = core.raise_to_shaped(core.get_aval(error.code))
    avals_in = [err_aval, code_aval, *in_avals]
    jaxpr_out, _, literals_out = pe.trace_to_jaxpr_dynamic(f, avals_in)
    return core.ClosedJaxpr(jaxpr_out, literals_out), msgs()
Example #9
File: common.py Project: xueeinstein/jax
def _initial_style_jaxpr(fun: Callable,
                         in_tree,
                         in_avals,
                         primitive_name: Optional[str] = None):
    jaxpr, consts, out_tree = _initial_style_open_jaxpr(
        fun, in_tree, in_avals, primitive_name)
    closed_jaxpr = core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ())
    return closed_jaxpr, consts, out_tree
Example #10
 def _initial_style_jaxpr(fun, in_avals):
     in_pvals = [pe.PartialVal.unknown(aval) for aval in in_avals]
     jaxpr, _, consts = pe.trace_to_jaxpr(fun,
                                          in_pvals,
                                          instantiate=True,
                                          bottom=True,
                                          stage_out=False)  # type: ignore
     assert not any(isinstance(c, core.Tracer) for c in consts)
     return core.ClosedJaxpr(jaxpr, consts)
Example #11
File: batching.py Project: xueeinstein/jax
def _batch_jaxpr_axes(closed_jaxpr, axis_size, in_axes, out_axes_dest,
                      axis_name, main_type):
  f = lu.wrap_init(core.jaxpr_as_fun(closed_jaxpr))
  f, out_batched = _batch_jaxpr_inner(f, axis_size, out_axes_dest)
  f = _batch_jaxpr_outer(f, axis_name, axis_size, in_axes, main_type)
  avals_in = [core.unmapped_aval(axis_size, axis_name, b, aval) if b is not not_mapped
              else aval for aval, b in zip(closed_jaxpr.in_avals, in_axes)]
  jaxpr_out, _, consts = pe.trace_to_jaxpr_dynamic(f, avals_in)
  return core.ClosedJaxpr(jaxpr_out, consts), out_batched()
Example #12
def make_transpose_from_thunk(thunk, lin_tree):
  transpose_jaxpr, transpose_consts = thunk()
  transpose_jaxpr = core.ClosedJaxpr(
      pe.convert_constvars_jaxpr(transpose_jaxpr), ())
  def transpose(res_arg, ct_out):
    args_flat = tree_leaves((res_arg, ct_out))
    ct_ins = core.jaxpr_as_fun(transpose_jaxpr)(*transpose_consts, *args_flat)
    return tree_unflatten(lin_tree, ct_ins)
  return transpose
Example #13
def remat_vmap(axis_size, axis_name, main_type, args, dims, *, jaxpr, **params):
  assert not jaxpr.constvars
  jaxpr_ = core.ClosedJaxpr(jaxpr, ())
  jaxpr_batched_, out_batched = batching.batch_jaxpr_axes(
      jaxpr_, axis_size, dims, [batching.zero_if_mapped] * len(jaxpr.outvars),
      axis_name=axis_name, main_type=main_type)
  jaxpr_batched, consts = jaxpr_batched_.jaxpr, jaxpr_batched_.consts
  out_dims = [0 if b else None for b in out_batched]
  return remat_p.bind(*consts, *args, jaxpr=jaxpr_batched, **params), out_dims
Example #14
File: checkify.py Project: rsepassi/jax
def checkify_jaxpr(jaxpr, error):
  f = lu.wrap_init(core.jaxpr_as_fun(jaxpr))
  f, msgs = check_errors_subtrace(f)
  f = check_errors_traceable(f, tuple(error.msgs.items()))
  err_aval = core.raise_to_shaped(core.get_aval(error.err))
  code_aval = core.raise_to_shaped(core.get_aval(error.code))
  avals_in = [err_aval, code_aval, *jaxpr.in_avals]
  jaxpr_out, _, literals_out = pe.trace_to_jaxpr_dynamic(f, avals_in)
  return core.ClosedJaxpr(jaxpr_out, literals_out), msgs()
Example #15
File: ad.py Project: jbampton/jax
def _jvp_jaxpr(jaxpr, nonzeros, instantiate):
  assert len(jaxpr.in_avals) == len(nonzeros)
  f = lu.wrap_init(core.jaxpr_as_fun(jaxpr))
  f_jvp, out_nonzeros = f_jvp_traceable(jvp(f, instantiate=instantiate, transform_stack=False),
                                        nonzeros)
  tangent_avals = [aval for aval, nz in zip(jaxpr.in_avals, nonzeros) if nz]
  avals_in = list(it.chain(jaxpr.in_avals, tangent_avals))
  jaxpr_out, avals_out, literals_out = pe.trace_to_jaxpr_dynamic(f_jvp, avals_in)
  return core.ClosedJaxpr(jaxpr_out, literals_out), out_nonzeros()
Example #16
 def trace_to_jaxpr_finalize(in_tracers, out_tracers, trace, instantiate=True):
   # TODO: This is the final part of the partial_eval.trace_to_subjaxpr. Share.
   instantiate = [instantiate] * len(out_tracers)
   out_tracers = safe_map(trace.full_raise, safe_map(core.full_lower, out_tracers))
   out_tracers = safe_map(partial(pe.instantiate_const_at, trace),
                          instantiate, out_tracers)
   jaxpr, consts, env = pe.tracers_to_jaxpr(in_tracers, out_tracers)
   assert not env  # TODO: this is from partial_eval.trace_to_jaxpr. Share.
   closed_jaxpr = core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ())
   return closed_jaxpr, consts
Example #17
def remat_vmap(axis_size, axis_name, main_type, args, dims, *, jaxpr, **params):
  assert not jaxpr.constvars
  in_batched = [d is not batching.not_mapped for d in dims]
  jaxpr_ = core.ClosedJaxpr(jaxpr, ())
  jaxpr_batched_, out_batched = batching.batch_jaxpr(
      jaxpr_, axis_size, in_batched, instantiate=False, axis_name=axis_name,
      main_type=main_type)
  jaxpr_batched, consts = jaxpr_batched_.jaxpr, jaxpr_batched_.consts
  out_dims = [0 if b else None for b in out_batched]
  return remat_p.bind(*consts, *args, jaxpr=jaxpr_batched, **params), out_dims
Example #18
def ignore_errors_jaxpr(jaxpr, error):
    """Constructs a jaxpr which takes two extra args but ignores them."""
    err_aval = core.raise_to_shaped(core.get_aval(error.err))
    code_aval = core.raise_to_shaped(core.get_aval(error.code))
    consts = jaxpr.consts
    jaxpr = jaxpr.jaxpr
    new_vars = core.gensym([jaxpr])
    new_invars = (new_vars(err_aval), new_vars(code_aval), *jaxpr.invars)
    new_jaxpr = core.Jaxpr(jaxpr.constvars, new_invars, jaxpr.outvars,
                           jaxpr.eqns)
    return core.ClosedJaxpr(new_jaxpr, consts)
Example #19
File: mlir.py Project: rsepassi/jax
def _call_lowering(fn_name, stack_name, call_jaxpr, backend, ctx, avals_in,
                   avals_out, *args):
    xla.check_backend_matches(backend, ctx.platform)
    output_types = map(aval_to_ir_types, avals_out)
    flat_output_types = util.flatten(output_types)
    sub_ctx = ctx.replace(
        name_stack=xla.extend_name_stack(ctx.name_stack, stack_name))
    symbol_name = lower_jaxpr_to_fun(sub_ctx, fn_name,
                                     core.ClosedJaxpr(call_jaxpr, ()))
    call = std.CallOp(flat_output_types, ir.FlatSymbolRefAttr.get(symbol_name),
                      flatten_lowering_ir_args(args))
    return util.unflatten(call.results, map(len, output_types))
Example #20
 def batched_jvp_jaxpr_thunk():
   jvp_jaxpr = core.ClosedJaxpr(*jvp_jaxpr_thunk())  # consts can be tracers
   _, args_batched = split_list(in_batched, [num_consts])
   _, all_batched = batching.batch_jaxpr(jvp_jaxpr, size, args_batched * 2, False,
                                         axis_name, main_type)
   primals_batched, tangents_batched = split_list(all_batched, [num_out])
   out_batched = map(op.or_, primals_batched, tangents_batched)
   out_dims2.append([0 if b else not_mapped for b in out_batched])
   batched_jvp_jaxpr, _ = batching.batch_jaxpr(
       jvp_jaxpr, size, args_batched * 2, out_batched * 2,
       axis_name, main_type)
   return batched_jvp_jaxpr.jaxpr, batched_jvp_jaxpr.consts
Example #21
    def augment_jaxpr(jaxpr, res_indices):
        num_res = len(res_indices)
        res_vars = jaxpr.jaxpr.invars[:num_res]
        non_res_vars = jaxpr.jaxpr.invars[num_res:]

        aug_res_vars = list(
            util.subvals(all_res_vars, zip(res_indices, res_vars)))
        aug_invars = aug_res_vars + non_res_vars
        jaxpr_aug = core.Jaxpr(jaxpr.jaxpr.constvars, aug_invars,
                               jaxpr.jaxpr.outvars, jaxpr.jaxpr.eqns,
                               jaxpr.jaxpr.effects)
        jaxpr_aug = core.ClosedJaxpr(jaxpr_aug, jaxpr.consts)
        return jaxpr_aug
Example #22
 def __call__(self, *args, **kwargs):
   assert not kwargs
   args_flat, in_tree = tree_flatten(args)
   flat_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(self.fun), in_tree)
   in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat]
   debug = pe.debug_info(self.fun, in_tree, False, "custom_vmap")
   jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
   assert not len(consts)
   closed_call = core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ())
   out_flat = custom_vmap_p.bind(*consts, *args_flat,
                                 call=closed_call,
                                 rule=self.vmap_rule,
                                 in_tree=in_tree)
   return tree_unflatten(out_tree(), out_flat)
Example #23
def remat_jvp(primals, tangents, jaxpr, prevent_cse, differentiated, policy):
  assert not jaxpr.constvars
  in_nonzeros = [type(t) is not ad_util.Zero for t in tangents]
  jaxpr_ = core.ClosedJaxpr(jaxpr, ())
  jaxpr_jvp_, out_nonzeros = ad.jvp_jaxpr(jaxpr_, in_nonzeros, False)
  nonzero_tangents = [t for t in tangents if type(t) is not ad_util.Zero]
  jaxpr_jvp = pe.convert_constvars_jaxpr(jaxpr_jvp_.jaxpr)
  outs = remat_p.bind(
      *jaxpr_jvp_.consts, *primals, *nonzero_tangents, jaxpr=jaxpr_jvp,
      prevent_cse=prevent_cse, differentiated=differentiated, policy=policy)
  out_primals, out_tangents_ = split_list(outs, [len(jaxpr.outvars)])
  out_tangents_ = iter(out_tangents_)
  out_tangents = [next(out_tangents_) if nz else ad_util.Zero.from_value(p)
                  for p, nz in zip(out_primals, out_nonzeros)]
  return out_primals, out_tangents
Example #24
File: unzip.py Project: yli96/probability
def custom_jvp_call_jaxpr(fun, jvp, *args):
    """A convenience wrapper to apply the custom_jvp_call_jaxpr primitive."""
    in_avals = [
        abstract_arrays.raise_to_shaped(jax_core.get_aval(x)) for x in args
    ]
    fun_jaxpr, consts = cd._initial_style_jaxpr(  # pylint: disable=protected-access
        fun, in_avals)  # consts can be tracers!
    closed_fun_jaxpr = jax_core.ClosedJaxpr(
        pe.convert_constvars_jaxpr(fun_jaxpr), ())
    jvp_jaxpr_thunk = pe._memoize(  # pylint: disable=protected-access
        lambda: cd._initial_style_jaxpr(jvp, in_avals * 2))  # pylint: disable=protected-access
    return cd.custom_jvp_call_jaxpr_p.bind(*consts,
                                           *args,
                                           fun_jaxpr=closed_fun_jaxpr,
                                           jvp_jaxpr_thunk=jvp_jaxpr_thunk,
                                           num_consts=len(consts))
Example #25
 def wrapped(*args, **kwargs):
   fun = lu.wrap_init(f, kwargs)
   flat_args, in_tree = tree_util.tree_flatten(args)
   flat_fun, out_tree = api_util.flatten_fun_nokwargs(fun, in_tree)
   flat_avals = safe_map(get_shaped_aval, flat_args)
   if dynamic:
     jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(
         flat_fun,
         flat_avals)
   else:
     pvals = [pe.PartialVal.unknown(aval) for aval in flat_avals]
     jaxpr, _, consts = pe.trace_to_jaxpr(
         flat_fun,
         pvals,
         instantiate=True)
   typed_jaxpr = jax_core.ClosedJaxpr(jaxpr, consts)
   return typed_jaxpr, (in_tree, out_tree())
Example #26
File: unzip.py Project: yli96/probability
def custom_vjp_call_jaxpr(fun, fwd, bwd, *args, out_trees):
    in_avals = [
        abstract_arrays.raise_to_shaped(jax_core.get_aval(x)) for x in args
    ]
    fun_jaxpr, consts = cd._initial_style_jaxpr(  # pylint: disable=protected-access
        fun, in_avals)  # consts can be tracers!
    closed_fun_jaxpr = jax_core.ClosedJaxpr(
        pe.convert_constvars_jaxpr(fun_jaxpr), ())
    fwd_jaxpr_thunk = pe._memoize(
        lambda: cd._initial_style_jaxpr(fwd, in_avals))  # pylint: disable=protected-access
    return cd.custom_vjp_call_jaxpr_p.bind(*consts,
                                           *args,
                                           fun_jaxpr=closed_fun_jaxpr,
                                           fwd_jaxpr_thunk=fwd_jaxpr_thunk,
                                           bwd=bwd,
                                           out_trees=out_trees,
                                           num_consts=len(consts))
Example #27
 def wrapped(*args, **kwargs):
     fun = lu.wrap_init(f, kwargs)
     flat_args, in_tree = tree_util.tree_flatten(args)
     flat_fun, out_tree = api_util.flatten_fun_nokwargs(fun, in_tree)
     flat_avals = safe_map(get_shaped_aval, flat_args)
     if not jax.config.omnistaging_enabled:
         raise ValueError('Oryx must be used with JAX omnistaging enabled.')
     if dynamic:
         jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, flat_avals)
     else:
         pvals = [
             pe.PartialVal((aval, jax_core.unit)) for aval in flat_avals
         ]
         jaxpr, _, consts = pe.trace_to_jaxpr(flat_fun,
                                              pvals,
                                              instantiate=True)
     typed_jaxpr = jax_core.ClosedJaxpr(jaxpr, consts)
     return typed_jaxpr, (in_tree, out_tree())
Example #28
def _sharded_jit_lowering(ctx, *in_nodes, in_parts, out_parts_thunk, nparts,
                          name, call_jaxpr, local_in_parts,
                          local_out_parts_thunk, local_nparts):
    # We assume any extra leading in_nodes are constants and replicate them.
    num_extra_nodes = len(in_nodes) - len(in_parts)
    assert num_extra_nodes >= 0
    in_parts = (None, ) * num_extra_nodes + in_parts

    args = []
    for ns, sharding in safe_zip(
            safe_map(mlir.wrap_singleton_ir_values, in_nodes), in_parts):
        if sharding is not None:
            args.append([
                mlir.wrap_with_sharding_op(n, xla.sharding_to_proto(sharding))
                for n in ns
            ])
        else:
            args.append(ns)

    sub_ctx = ctx.module_context.replace(
        name_stack=extend_name_stack(wrap_name(name, "sharded_jit")))
    fn = mlir.lower_jaxpr_to_fun(sub_ctx, f"sharded_jit_{name}",
                                 core.ClosedJaxpr(call_jaxpr, ()))

    output_types = safe_map(mlir.aval_to_ir_types, ctx.avals_out)
    flat_output_types = util.flatten(output_types)
    call = std.CallOp(flat_output_types,
                      ir.FlatSymbolRefAttr.get(fn.name.value),
                      mlir.flatten_lowering_ir_args(args))
    out_nodes = util.unflatten(call.results, safe_map(len, output_types))

    out_parts = out_parts_thunk()
    outputs = []
    for ns, sharding in safe_zip(out_nodes, out_parts):
        if sharding is not None:
            outputs.append([
                mlir.wrap_with_sharding_op(n, xla.sharding_to_proto(sharding))
                for n in ns
            ])
        else:
            outputs.append(ns)
    return outputs
Example #29
    def __call__(self, residual_arg, linear_arg):
        res_arg, lin_arg = residual_arg, linear_arg
        _, res_tree = tree_flatten(res_arg)
        _, lin_tree = tree_flatten(lin_arg)
        args_flat, in_tree = tree_flatten((res_arg, lin_arg))

        flat_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(self.fun),
                                                  in_tree)
        in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat]
        debug = pe.debug_info(self.fun, in_tree, False, "custom_transpose")
        jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
        assert not len(consts)
        closed_call = core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ())
        out_flat = custom_transpose_p.bind(*consts,
                                           *args_flat,
                                           call=closed_call,
                                           rule=self.transpose,
                                           lin_tree=lin_tree,
                                           res_tree=res_tree,
                                           out_tree=out_tree())
        return tree_unflatten(out_tree(), out_flat)
Example #30
def _sharded_callable(
        fun: lu.WrappedFun, nparts: Optional[int],
        in_parts: Tuple[pxla.PartitionsOrReplicated, ...],
        out_parts_thunk: Callable[[], Tuple[pxla.PartitionsOrReplicated, ...]],
        local_in_parts: Optional[Tuple[pxla.PartitionsOrReplicated, ...]],
        local_out_parts_thunk: Callable[[], Optional[Tuple[
            pxla.PartitionsOrReplicated,
            ...]]], local_nparts: Optional[int], name: str, *abstract_args):
    nrep = 1

    if local_in_parts is None:
        local_in_parts = in_parts

    global_abstract_args = [
        pxla.get_global_aval(arg, parts,
                             lparts) for arg, parts, lparts in safe_zip(
                                 abstract_args, in_parts, local_in_parts)
    ]

    if logging.vlog_is_on(2):
        logging.vlog(2, "abstract_args: %s", abstract_args)
        logging.vlog(2, "global_abstract_args: %s", global_abstract_args)
        logging.vlog(2, "in_parts: %s", in_parts)
        logging.vlog(2, "local_in_parts: %s", local_in_parts)

    jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_final(
        fun, global_abstract_args)

    platform = xb.get_backend().platform

    nparts = pxla.reconcile_num_partitions(jaxpr, nparts)
    assert nparts is not None
    if nparts > xb.device_count():
        raise ValueError(
            f"sharded_jit computation requires {nparts} devices, "
            f"but only {xb.device_count()} devices are available.")
    if xb.local_device_count() < nparts < xb.device_count():
        raise NotImplementedError(
            f"sharded_jit across multiple hosts must use all available devices. "
            f"Got {nparts} out of {xb.device_count()} requested devices "
            f"(local device count: {xb.local_device_count()})")

    if local_nparts is None:
        if nparts > xb.local_device_count():
            raise ValueError(
                "Specify 'local_nparts' when using cross-process sharded_jit "
                "and all inputs and outputs are replicated.")
        else:
            local_nparts = nparts
    if local_nparts > xb.local_device_count():
        raise ValueError(
            f"sharded_jit computation requires {local_nparts} local devices, "
            f"but only {xb.local_device_count()} local devices are available.")

    if logging.vlog_is_on(2):
        logging.vlog(2, "nparts: %d  local_nparts: %d", nparts, local_nparts)

    out_parts = out_parts_thunk()

    local_out_parts = local_out_parts_thunk()
    if local_out_parts is None:
        local_out_parts = out_parts

    if logging.vlog_is_on(2):
        logging.vlog(2, "out_parts: %s", out_parts)
        logging.vlog(2, "local_out_parts: %s", local_out_parts)

    local_out_avals = [
        pxla.get_local_aval(out, parts,
                            lparts) for out, parts, lparts in safe_zip(
                                global_out_avals, out_parts, local_out_parts)
    ]

    log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
    logging.log(log_priority, "Compiling %s for %d devices with args %s.",
                fun.__name__, nparts, global_abstract_args)

    axis_env = xla.AxisEnv(nrep, (), ())
    unordered_effects = [
        eff for eff in jaxpr.effects if eff not in core.ordered_effects
    ]
    ordered_effects = [
        eff for eff in jaxpr.effects if eff in core.ordered_effects
    ]
    module, _ = mlir.lower_jaxpr_to_module(
        f"spjit_{fun.__name__}",
        core.ClosedJaxpr(jaxpr, consts),
        unordered_effects,
        ordered_effects,
        platform=platform,
        axis_context=mlir.ReplicaAxisContext(axis_env),
        name_stack=new_name_stack(wrap_name(name, "sharded_jit")),
        donated_args=[False] * len(in_parts),
        arg_shardings=safe_map(xla.sharding_to_proto, in_parts),
        result_shardings=safe_map(xla.sharding_to_proto, out_parts))
    built = xc._xla.mlir.mlir_module_to_xla_computation(
        mlir.module_to_string(module), use_tuple_args=False, return_tuple=True)

    if nparts <= xb.local_device_count():
        devices = xb.local_devices()[:nparts]
    else:
        assert nparts == xb.device_count()
        devices = xb.devices()
    device_assignment = np.array([[d for d in devices]])
    device_assignment = np.reshape(device_assignment, (-1, nparts))
    # device_assignment = None  # TODO(skye): replace with default device assignment?

    compiled = dispatch.backend_compile(
        xb.get_backend(), built,
        xb.get_compile_options(nrep, nparts, device_assignment))

    input_specs = [
        pxla.partitioned_sharding_spec(local_nparts, parts, aval)
        for parts, aval in zip(local_in_parts, abstract_args)
    ]
    input_indices = [
        pxla.spec_to_indices(aval.shape, spec) if spec is not None else None
        for aval, spec in zip(abstract_args, input_specs)
    ]

    handle_args = partial(pxla.shard_args, compiled.local_devices(),
                          input_indices)
    handle_outs = _avals_to_results_handler(
        nrep,
        local_nparts,  # type: ignore
        local_out_parts,
        local_out_avals)
    return partial(_execute_spatially_partitioned, compiled, handle_args,
                   handle_outs)