Example No. 1
def _while_callback_rule(trace, *tracers, cond_jaxpr, body_jaxpr,
                         cond_nconsts, body_nconsts):
  cond_const_tracers, body_const_tracers, init_tracers = split_list(
            tracers, [cond_nconsts, body_nconsts])
  init_avals = safe_map(lambda x: x.aval, init_tracers)
  cond_const_vals, body_const_vals, init_vals = tree_map(
      lambda x: x.val, (cond_const_tracers, body_const_tracers, init_tracers))

  body_fun = jaxpr_as_fun(body_jaxpr)
  cond_fun = jaxpr_as_fun(cond_jaxpr)

  def cond(*carry):
    return cond_fun(*it.chain(cond_const_vals, carry))

  def body(*carry):
    return body_fun(*it.chain(body_const_vals, carry))

  new_cond = callback_transform(cond, trace.callback, strip_calls=trace.strip_calls)  # type: ignore
  new_body = callback_transform(body, trace.callback, strip_calls=trace.strip_calls)  # type: ignore
  in_tree = tree_structure(init_avals)

  new_cond_jaxpr, new_cond_consts, _ = lcf._initial_style_jaxpr(new_cond, in_tree, tuple(init_avals))
  new_body_jaxpr, new_body_consts, _ = lcf._initial_style_jaxpr(new_body, in_tree, tuple(init_avals))
  out = lcf.while_p.bind(
      *it.chain(new_cond_consts, new_body_consts, init_vals),
      cond_nconsts=len(new_cond_consts),
      body_nconsts=len(new_body_consts),
      cond_jaxpr=new_cond_jaxpr,
      body_jaxpr=new_body_jaxpr)
  return safe_map(trace.pure, out)
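
Every snippet collected here leans on core.jaxpr_as_fun, which turns a (closed) jaxpr back into a plain callable over flat arguments. Below is a minimal, self-contained sketch of that round-trip, not taken from any of the examples; note that jax.core is an internal module, so the exact import path may differ across JAX versions.

import jax
import jax.numpy as jnp
from jax import core

def f(x, y):
    return jnp.sin(x) * y + 1.0

closed_jaxpr = jax.make_jaxpr(f)(2.0, 3.0)   # trace f into a ClosedJaxpr
g = core.jaxpr_as_fun(closed_jaxpr)          # ... and back into a callable

print(g(2.0, 3.0))   # flat list of outputs, here [sin(2.0) * 3.0 + 1.0]
print(f(2.0, 3.0))   # same value, computed directly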
Example No. 2
def _root_jvp(const_lengths, jaxprs, primals, tangents):
    params, _ = _split_root_args(primals, const_lengths)
    sol = _custom_root(const_lengths, jaxprs, *primals)

    f_out_vals = len(jaxprs.f.out_avals)
    solution, aux = split_list(sol, [f_out_vals])

    params_dot, _ = _split_root_args(tangents, const_lengths)

    # F(m, u) = 0      # system of equations in u, parameterized by m
    #                  # solution is u*(m) defined in a neighborhood
    # F(m, u*(m)) = 0  # satisfied in a neighborhood
    #
    # ∂_0 F(m, u*(m)) + ∂_1 F(m, u*(m)) ∂ u*(m) = 0       # implied by line above
    # ∂ u*(m) = - (∂_1 F(m, u*(m)))^{-1} ∂_0 F(m, u*(m))  # rearrange
    #
    # ∂ u*(m)[v] = - (∂_1 F(m, u*(m)))^{-1} [∂_0 F(m, u*(m))[v]]  # jvp

    f = core.jaxpr_as_fun(jaxprs.f)
    linearize_and_solve = partial(core.jaxpr_as_fun(jaxprs.l_and_s),
                                  *params.l_and_s)
    f_at_solution = lambda *params: f(*params, *solution)
    _, rhs = ad.jvp(lu.wrap_init(f_at_solution)).call_wrapped(
        params.f, params_dot.f)
    solution_dot = _map(operator.neg, linearize_and_solve(*solution, *rhs))
    # append aux, create symbolic zero tangents for the aux values
    solution += aux
    solution_dot += _map(lax.zeros_like_array, aux)

    return solution, solution_dot
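
The commented derivation inside _root_jvp is the implicit function theorem. A small numeric check of it, as a sketch using only public JAX APIs, with F, m, and u chosen purely for illustration:

import jax

def F(m, u):              # F(m, u) = 0 implicitly defines u*(m) = m ** (1/3)
    return u ** 3 - m

m, v = 8.0, 1.0           # parameter value and tangent direction
u_star = m ** (1.0 / 3.0)

_, dF_dm_v = jax.jvp(lambda m_: F(m_, u_star), (m,), (v,))   # ∂_0 F(m, u*)[v]
dF_du = jax.grad(lambda u_: F(m, u_))(u_star)                # ∂_1 F(m, u*)

u_dot = -dF_dm_v / dF_du                                     # ∂ u*(m)[v], per the rearranged equation
u_dot_direct = jax.grad(lambda m_: m_ ** (1.0 / 3.0))(m) * v

print(u_dot, u_dot_direct)   # both ≈ 1/12, so the implicit JVP matches the closed-form derivative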
Example No. 3
def _cond_batching_rule(args, dims, true_jaxpr, false_jaxpr, true_nconsts,
                        false_nconsts):
  # TODO: maybe avoid moving arg axes to front if we're promoting to select?
  args = [batching.moveaxis(x, d, 0) if d is not batching.not_mapped and d != 0
          else x for x, d in zip(args, dims)]
  true_nops = len(true_jaxpr.in_avals) - true_nconsts
  (pred,), true_consts, true_ops, false_consts, false_ops = split_list(
      args, [1, true_nconsts, true_nops, false_nconsts])
  size, = {x.shape[d] for x, d in zip(args, dims) if d is not batching.not_mapped}
  orig_bat = [d is not batching.not_mapped for d in dims]
  (pred_bat,), tconst_bat, t_bat, fconst_bat, f_bat = split_list(
    orig_bat, [1, true_nconsts, true_nops, false_nconsts])

  _, true_out_bat = batching.batch_jaxpr(true_jaxpr, size, tconst_bat + t_bat, False)
  _, false_out_bat = batching.batch_jaxpr(false_jaxpr, size, fconst_bat + f_bat, False)
  out_bat = [a or b for a, b in zip(true_out_bat, false_out_bat)]

  true_jaxpr_batched, _ = batching.batch_jaxpr(true_jaxpr, size, tconst_bat + t_bat, out_bat)
  false_jaxpr_batched, _ = batching.batch_jaxpr(false_jaxpr, size, fconst_bat + f_bat, out_bat)

  if pred_bat:
    true_out = core.jaxpr_as_fun(true_jaxpr_batched)(*(true_consts + true_ops))
    false_out = core.jaxpr_as_fun(false_jaxpr_batched)(*(false_consts + false_ops))
    true_out = [batching.broadcast(x, size, 0) if not b else x
                for x, b in zip(true_out, out_bat)]
    false_out = [batching.broadcast(x, size, 0) if not b else x
                 for x, b in zip(false_out, out_bat)]
    return [_cond_pred_bcast_select(pred, t, f)
            for t, f in zip(true_out, false_out)], [0] * len(true_out)
  else:
    out_dims = [0 if b else batching.not_mapped for b in out_bat]
    return cond_p.bind(
      *itertools.chain([pred], true_consts, true_ops, false_consts, false_ops),
      true_jaxpr=true_jaxpr_batched, false_jaxpr=false_jaxpr_batched,
      true_nconsts=len(true_consts), false_nconsts=len(false_consts)), out_dims
Example No. 4
def checkify_while_body_jaxpr(cond_jaxpr, body_jaxpr, error):
  cond_f = core.jaxpr_as_fun(cond_jaxpr)
  body_f = core.jaxpr_as_fun(body_jaxpr)
  def new_body_f(*vals):
    out = body_f(*vals)
    _ = cond_f(*out)  # this checks if the next cond application will error
    return out
  return checkify_fun_to_jaxpr(lu.wrap_init(new_body_f), error, body_jaxpr.in_avals)
Example No. 5
def checkify_while_body_jaxpr(cond_jaxpr, body_jaxpr, error):
    cond_f = core.jaxpr_as_fun(cond_jaxpr)
    body_f = core.jaxpr_as_fun(body_jaxpr)

    def new_body_f(*vals):
        _ = cond_f(*vals)
        return body_f(*vals)

    return checkify_fun_to_jaxpr(lu.wrap_init(new_body_f), error,
                                 body_jaxpr.in_avals)
Example No. 6
def _reap_while_rule(trace: HarvestTrace, *tracers, cond_jaxpr, body_jaxpr,
                     cond_nconsts, body_nconsts):
    """Reaps the body of a while loop to get the reaps of the final iteration."""
    cond_const_tracers, body_const_tracers, init_tracers = jax_util.split_list(
        tracers, [cond_nconsts, body_nconsts])
    _, init_avals = tree_util.tree_map(lambda x: x.aval,
                                       (body_const_tracers, init_tracers))
    cond_const_vals, body_const_vals, init_vals = tree_util.tree_map(
        lambda x: x.val,
        (cond_const_tracers, body_const_tracers, init_tracers))
    context = trace_util.get_dynamic_context(trace)
    settings = context.settings
    body_metadata = _get_harvest_metadata(body_jaxpr, settings,
                                          *(body_const_tracers + init_tracers))
    for k, meta in body_metadata.items():
        mode = meta['mode']
        if mode != 'clobber':
            raise ValueError(
                f'Must use clobber mode for \'{k}\' inside of a `while_loop`.')
    reap_avals = {k: v['aval'] for k, v in body_metadata.items()}

    cond_fun = jax_core.jaxpr_as_fun(cond_jaxpr)
    body_fun = jax_core.jaxpr_as_fun(body_jaxpr)
    reap_settings = dict(tag=settings.tag,
                         allowlist=settings.allowlist,
                         blocklist=settings.blocklist,
                         exclusive=settings.exclusive)

    def new_cond(carry, _):
        return cond_fun(*(cond_const_vals + carry))

    def new_body(carry, _):
        carry, reaps = call_and_reap(
            body_fun, **reap_settings)(*(body_const_vals + carry))
        return (carry, reaps)

    new_in_avals, new_in_tree = tree_util.tree_flatten(
        (init_avals, reap_avals))
    new_cond_jaxpr, cond_consts, _ = lcf._initial_style_jaxpr(  # pylint: disable=protected-access
        new_cond, new_in_tree, tuple(new_in_avals))
    new_body_jaxpr, body_consts, out_tree = lcf._initial_style_jaxpr(  # pylint: disable=protected-access
        new_body, new_in_tree, tuple(new_in_avals))
    dummy_reap_vals = tree_util.tree_map(lambda x: jnp.zeros(x.shape, x.dtype),
                                         reap_avals)
    new_in_vals = tree_util.tree_leaves((init_vals, dummy_reap_vals))
    out = lax.while_p.bind(*(cond_consts + body_consts + new_in_vals),
                           cond_nconsts=len(cond_consts),
                           body_nconsts=len(body_consts),
                           cond_jaxpr=new_cond_jaxpr,
                           body_jaxpr=new_body_jaxpr)
    out = jax_util.safe_map(trace.pure, out)
    out, reaps = tree_util.tree_unflatten(out_tree, out)
    for k, v in reaps.items():
        sow(v, name=k, tag=settings.tag, mode=body_metadata[k]['mode'])
    return out
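
This rule backs a harvest-style transform (as in Oryx): values sown in 'clobber' mode inside a while_loop body are reaped as the values from the final iteration. A hedged usage sketch, assuming the oryx sow/reap API; the tag and name here are chosen only for illustration.

import jax
from oryx.core import sow, reap

def body(c):
    c = c + 1
    # 'clobber' mode is required inside while_loop, per the check in the rule above
    sow(c, tag='loop', name='last_carry', mode='clobber')
    return c

def run(c0):
    return jax.lax.while_loop(lambda c: c < 10, body, c0)

print(reap(run, tag='loop')(0))   # expected: {'last_carry': 10}, i.e. the final iteration's value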
Example No. 7
def _cond_impl(pred, *args, **kwargs):
    true_jaxpr, false_jaxpr, true_nconsts, false_nconsts = split_dict(
        kwargs, ["true_jaxpr", "false_jaxpr", "true_nconsts", "false_nconsts"])
    true_consts, true_ops, false_consts, false_ops = split_list(
        args,
        [true_nconsts, len(true_jaxpr.in_avals), false_nconsts])

    if pred:
        return core.jaxpr_as_fun(true_jaxpr)(*(true_consts + true_ops))
    else:
        return core.jaxpr_as_fun(false_jaxpr)(*(false_consts + false_ops))
Example No. 8
def checkify_while_body_jaxpr(cond_jaxpr, body_jaxpr, error, enabled_errors,
                              c_consts):
    cond_f = core.jaxpr_as_fun(cond_jaxpr)
    body_f = core.jaxpr_as_fun(body_jaxpr)

    def new_body_f(*vals):
        out = body_f(*vals)
        # This checks if the next cond application will error
        _ = cond_f(*c_consts, *out)
        return out

    return checkify_fun_to_jaxpr(lu.wrap_init(new_body_f), error,
                                 enabled_errors, body_jaxpr.in_avals)
Example No. 9
def while_loop_error_check(error, enabled_errors, *in_flat, cond_nconsts,
                           cond_jaxpr, body_nconsts, body_jaxpr):
    c_consts, b_consts, carry = split_list(in_flat,
                                           [cond_nconsts, body_nconsts])

    # Check if the first cond application will error.
    cond_jaxpr_, msgs_cond = checkify_jaxpr(cond_jaxpr, error, enabled_errors)
    cond_err, cond_code, cond_payload, _ = core.jaxpr_as_fun(cond_jaxpr_)(
        error.err, error.code, error.payload, *c_consts, *carry)
    del cond_jaxpr_

    checked_body_jaxpr_, msgs_body = checkify_while_body_jaxpr(
        cond_jaxpr, body_jaxpr, error, enabled_errors, c_consts)
    to_move = [False] * 3 + [True] * body_nconsts + [False] * len(carry)
    checked_body_jaxpr = pe.move_binders_to_front(checked_body_jaxpr_, to_move)

    compat_cond_jaxpr_ = ignore_errors_jaxpr(cond_jaxpr, error)
    to_move = [False] * 3 + [True] * cond_nconsts + [False] * len(carry)
    compat_cond_jaxpr = pe.move_binders_to_front(compat_cond_jaxpr_, to_move)
    new_in_flat = [
        *c_consts, *b_consts, cond_err, cond_code, cond_payload, *carry
    ]

    err, code, payload, *out = lax.while_p.bind(*new_in_flat,
                                                cond_nconsts=cond_nconsts,
                                                cond_jaxpr=compat_cond_jaxpr,
                                                body_nconsts=body_nconsts,
                                                body_jaxpr=checked_body_jaxpr)
    new_msgs = {**error.msgs, **msgs_body, **msgs_cond}
    return out, Error(err, code, new_msgs, payload)
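
From the user's side, this is the rule that lets checkify-style checks survive a lax.while_loop: the error state is threaded through the loop carry, and the cond is also run in checked form so a failing next iteration is caught. A hedged sketch with the public jax.experimental.checkify API; the loop and message are illustrative.

import jax
from jax.experimental import checkify

def loop(x):
    def body(c):
        checkify.check(c < 100.0, "carry overflowed!")
        return c * 2.0
    return jax.lax.while_loop(lambda c: c < 1000.0, body, x)

err, out = checkify.checkify(loop)(1.0)
err.throw()   # raises, because the check failed on a later iteration of the body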
Example No. 10
def _custom_linear_solve_impl(*args, **kwargs):
  const_lengths, jaxprs, tree = split_dict(
      kwargs, ['const_lengths', 'jaxprs', 'tree'])
  params, b = _split_linear_solve_args(args, const_lengths)
  x = core.jaxpr_as_fun(jaxprs.solve)(*(params.solve + b))
  _check_shapes('solve', 'b', x, b, tree)
  return x
Example No. 11
def _scan_callback_rule(trace, *tracers, reverse, length, num_consts, num_carry,
                        jaxpr, linear, unroll):
  const_tracers, carry_tracers, xs_tracers = split_list(tracers, [num_consts, num_carry])
  carry_avals, xs_avals = tree_map(lambda x: x.aval, (carry_tracers, xs_tracers))
  const_vals, carry_vals, xs_vals = tree_map(lambda x: x.val, (const_tracers, carry_tracers, xs_tracers))

  x_tracers = [t[0] for t in xs_tracers]
  x_avals = [t.aval for t in x_tracers]

  body_fun = jaxpr_as_fun(jaxpr)

  def new_body(*vals):
    out = body_fun(*vals)
    out_carry, y = split_list(out, [num_carry])
    return out_carry, y
  new_body = callback_transform(new_body, trace.callback,
                                strip_calls=trace.strip_calls)  # type: ignore
  in_tree = tree_structure(carry_avals + xs_avals)
  new_jaxpr, new_consts, _ = lcf._initial_style_jaxpr(
      new_body, in_tree, tuple(carry_avals + x_avals))
  vals = tuple(it.chain(new_consts, carry_vals, xs_vals))
  out_vals = lax.scan_p.bind(*vals, reverse=reverse, length=length,
                             num_consts=len(new_consts), num_carry=num_carry,
                             jaxpr=new_jaxpr, linear=linear, unroll=unroll)
  return safe_map(trace.pure, out_vals)
Example No. 12
def callback_jaxpr(closed_jaxpr, callback, strip_calls):
  fun = lu.wrap_init(jaxpr_as_fun(closed_jaxpr))
  fun = callback_subtrace(fun)
  fun = _callback_fun(fun, callback, strip_calls)
  avals_in = closed_jaxpr.in_avals
  jaxpr_out, consts = cd._initial_style_jaxpr(fun, avals_in)
  return core.ClosedJaxpr(jaxpr_out, consts)
Example No. 13
 def body_fun(i, vals):
   idx = i if forward else length - i - 1
   carry, ys = vals
   x = _index_arrays(idx, x_aval, xs)
   carry_out, y = core.jaxpr_as_fun(jaxpr)(consts, carry, x)
   ys_out = _update_arrays(idx, y_aval, ys, y)
   return (carry_out, ys_out)
Example No. 14
 def get_bind_params(self, params):
     assert 'call_jaxpr' in params
     assert 'transpose_jaxpr_thunk' in params
     new_params = dict(params)
     new_params['transpose'] = make_transpose_from_thunk(
         new_params.pop('transpose_jaxpr_thunk'), new_params['lin_tree'])
     call = lu.wrap_init(core.jaxpr_as_fun(new_params.pop('call_jaxpr')))
     return [call], new_params
Example No. 15
 def f_aug(*args):
     outs_and_residuals = core.jaxpr_as_fun(jaxpr)(*args)
     outs, residuals = split_list(outs_and_residuals,
                                  [num_non_res_outputs])
     aug_residuals = _map(ad_util.zeros_like_aval, all_res_avals)
     aug_residuals = util.subvals(aug_residuals,
                                  zip(res_indices, residuals))
     return outs + list(aug_residuals)
Example No. 16
 def body_fun(i, vals):
     idx = i if forward else length - i - 1
     carry, ys = vals
     x = _index_arrays(idx, x_aval, xs)
     cell = parametrized(jc.jaxpr_as_fun(jaxpr))
     carry_out, y = cell.apply(cell_params, consts, carry, x)
     ys_out = _update_arrays(idx, y_aval, ys, y)
     return carry_out, ys_out
Example No. 17
 def body_fun(i, vals):
   i = i if forward else length - i - 1
   carry, ys = split_list(vals, [num_carry])
   x = _map(partial(_index_array, i), x_avals, xs)
   out_flat = core.jaxpr_as_fun(jaxpr)(*(consts + carry + x))
   carry_out, y_updates = split_list(out_flat, [num_carry])
   ys_out = _map(partial(_update_array, i), y_avals, ys, y_updates)
   return carry_out + ys_out
Example No. 18
def eval_sparse(
    jaxpr: core.Jaxpr,
    consts: Sequence[Array],  # all consts are dense
    argspecs: Sequence[ArgSpec],  # mix of sparse and dense pointers into spenv
    spenv: SparseEnv,
) -> Sequence[ArgSpec]:
    env: Dict[core.Var, ArgSpec] = {}

    def read(var: core.Var) -> Union[Array, ArgSpec]:
        # all literals are dense
        if isinstance(var, core.Literal):
            return ArgSpec(np.shape(var.val), spenv.push(var.val), None)
        else:
            return env[var]

    def write_buffer(var: core.Var, a: Array) -> None:
        if var is core.dropvar:
            return
        env[var] = ArgSpec(a.shape, spenv.push(a), None)

    def write(var: core.Var, a: ArgSpec) -> None:
        if var is core.dropvar:
            return
        env[var] = a

    # TODO: handle unitvar at all?
    #write_buffer(core.unitvar, core.unit)
    safe_map(write_buffer, jaxpr.constvars, consts)
    safe_map(write, jaxpr.invars, argspecs)

    for eqn in jaxpr.eqns:
        prim = eqn.primitive
        invals = safe_map(read, eqn.invars)

        if any(val.is_sparse() for val in invals):
            if prim not in sparse_rules:
                raise NotImplementedError(f"sparse rule for {prim}")
            out = sparse_rules[prim](spenv, *invals, **eqn.params)
        else:
            if prim is xla.xla_call_p:
                # TODO(vanderplas,frostig): workaround for binding call primitives
                # within a jaxpr interpreter
                params = eqn.params.copy()
                fun = lu.wrap_init(
                    core.jaxpr_as_fun(
                        pe.ClosedJaxpr(params.pop('call_jaxpr'), ())))
                out_bufs = prim.bind(fun, *(val.data(spenv) for val in invals),
                                     **params)
            else:
                out_bufs = prim.bind(*(val.data(spenv) for val in invals),
                                     **eqn.params)
            out_bufs = out_bufs if prim.multiple_results else [out_bufs]
            out = []
            for buf in out_bufs:
                out.append(ArgSpec(buf.shape, spenv.push(buf), None))
        safe_map(write, eqn.outvars, out)

    return safe_map(read, jaxpr.outvars)
Example No. 19
def _custom_ivjp_jvp(primals, tangents, *, fun_jaxpr, ivjp_jaxpr):
  primals_out = custom_ivjp_p.bind(*primals, fun_jaxpr=fun_jaxpr,
                                             ivjp_jaxpr=ivjp_jaxpr)
  fun = core.jaxpr_as_fun(fun_jaxpr)
  # FIXME: This might compute the primals multiple times, but we only need to do
  #        this trick while linearizing. It should be possible to do it through
  #        a custom partial eval rule.
  _, tangents_out = ad.jvp(lu.wrap_init(fun)).call_wrapped(primals, tangents)
  return primals_out, tangents_out
Example No. 20
def _batch_jaxpr_axes(closed_jaxpr, axis_size, in_axes, out_axes_dest,
                      axis_name, main_type):
  f = lu.wrap_init(core.jaxpr_as_fun(closed_jaxpr))
  f, out_batched = _batch_jaxpr_inner(f, axis_size, out_axes_dest)
  f = _batch_jaxpr_outer(f, axis_name, axis_size, in_axes, main_type)
  avals_in = [core.unmapped_aval(axis_size, axis_name, b, aval) if b is not not_mapped
              else aval for aval, b in zip(closed_jaxpr.in_avals, in_axes)]
  jaxpr_out, _, consts = pe.trace_to_jaxpr_dynamic(f, avals_in)
  return core.ClosedJaxpr(jaxpr_out, consts), out_batched()
Example No. 21
File: ad.py Project: jbampton/jax
def _jvp_jaxpr(jaxpr, nonzeros, instantiate):
  assert len(jaxpr.in_avals) == len(nonzeros)
  f = lu.wrap_init(core.jaxpr_as_fun(jaxpr))
  f_jvp, out_nonzeros = f_jvp_traceable(jvp(f, instantiate=instantiate, transform_stack=False),
                                        nonzeros)
  tangent_avals = [aval for aval, nz in zip(jaxpr.in_avals, nonzeros) if nz]
  avals_in = list(it.chain(jaxpr.in_avals, tangent_avals))
  jaxpr_out, avals_out, literals_out = pe.trace_to_jaxpr_dynamic(f_jvp, avals_in)
  return core.ClosedJaxpr(jaxpr_out, literals_out), out_nonzeros()
Example No. 22
def _xla_call_sparse(spenv, *spvalues, call_jaxpr, donated_invars, **params):
  if any(donated_invars):
    raise NotImplementedError("sparse xla_call with donated_invars")
  sp_call_jaxpr, out_tree = _sparsify_jaxpr(spenv, pe.ClosedJaxpr(call_jaxpr, ()), *spvalues)
  fun = lu.wrap_init(core.jaxpr_as_fun(sp_call_jaxpr))
  args_flat, _ = tree_flatten(spvalues_to_arrays(spenv, spvalues))
  donated_invars = tuple(False for arg in args_flat)
  out_flat = xla.xla_call_p.bind(fun, *args_flat, donated_invars=donated_invars, **params)
  return arrays_to_spvalues(spenv, tree_unflatten(out_tree, out_flat))
Example No. 23
def checkify_jaxpr(jaxpr, error):
  f = lu.wrap_init(core.jaxpr_as_fun(jaxpr))
  f, msgs = check_errors_subtrace(f)
  f = check_errors_traceable(f, tuple(error.msgs.items()))
  err_aval = core.raise_to_shaped(core.get_aval(error.err))
  code_aval = core.raise_to_shaped(core.get_aval(error.code))
  avals_in = [err_aval, code_aval, *jaxpr.in_avals]
  jaxpr_out, _, literals_out = pe.trace_to_jaxpr_dynamic(f, avals_in)
  return core.ClosedJaxpr(jaxpr_out, literals_out), msgs()
Example No. 24
def eval_sparse(
    jaxpr: core.Jaxpr,
    consts: Sequence[Array],  # all consts are dense
    spvalues: Sequence[SparsifyValue],  # mix of sparse and dense pointers into spenv
    spenv: SparsifyEnv,
) -> Sequence[SparsifyValue]:
  env : Dict[core.Var, SparsifyValue] = {}

  def read(var: core.Var) -> Union[Array, SparsifyValue]:
    # all literals are dense
    if isinstance(var, core.Literal):
      return spenv.dense(var.val)
    else:
      return env[var]

  def write_buffer(var: core.Var, a: Array) -> None:
    if isinstance(var, core.DropVar):
      return
    env[var] = spenv.dense(a)

  def write(var: core.Var, a: SparsifyValue) -> None:
    if isinstance(var, core.DropVar):
      return
    assert a is not None
    env[var] = a

  safe_map(write_buffer, jaxpr.constvars, consts)
  safe_map(write, jaxpr.invars, spvalues)

  for eqn in jaxpr.eqns:
    prim = eqn.primitive
    invals = safe_map(read, eqn.invars)

    if any(val.is_sparse() for val in invals):
      if prim not in sparse_rules:
        _raise_unimplemented_primitive(prim)
      out = sparse_rules[prim](spenv, *invals, **eqn.params)
    else:
      if prim is xla.xla_call_p:
        # TODO(vanderplas,frostig): workaround for binding call primitives
        # within a jaxpr interpreter
        params = eqn.params.copy()
        fun = lu.wrap_init(core.jaxpr_as_fun(pe.ClosedJaxpr(params.pop('call_jaxpr'), ())))
        out_bufs = prim.bind(fun, *(spenv.data(val) for val in invals), **params)
      else:
        out_bufs = prim.bind(*(spenv.data(val) for val in invals), **eqn.params)
      out_bufs = out_bufs if prim.multiple_results else [out_bufs]
      out = []
      for buf, outvar in safe_zip(out_bufs, eqn.outvars):
        if isinstance(outvar, core.DropVar):
          out.append(None)
        else:
          out.append(spenv.dense(buf))
    safe_map(write, eqn.outvars, out)

  return safe_map(read, jaxpr.outvars)
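
eval_sparse is the jaxpr interpreter behind the sparsify transform: equations with only dense inputs are re-bound as usual, while equations with sparse inputs are dispatched to sparse_rules. A hedged usage sketch via the public jax.experimental.sparse API; the matrix and function are illustrative.

import jax.numpy as jnp
from jax.experimental import sparse

M = sparse.BCOO.fromdense(jnp.array([[0., 1.],
                                     [2., 0.]]))
f = lambda M, x: M @ x + 1.0          # the matmul hits a sparse rule; the add stays dense

print(sparse.sparsify(f)(M, jnp.array([1., 2.])))   # expected: [3. 3.]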
Example No. 25
def _cond_batching_rule(axis_size, axis_name, main_type, args, dims, branches,
                        linear):
    index, *ops = args
    index_dim, *op_dims = dims

    if index_dim is not batching.not_mapped:
        # Convert to a lax.select. While we could get away with not broadcasting
        # some operands yet, because all outputs must be broadcast together anyway
        # for the select we broadcast the input operands for simplicity and leave
        # optimizations to XLA.
        # TODO(mattjj,frostig): assumes branches are side-effect-free, revise!
        index, *ops = (batching.bdim_at_front(x, d, axis_size)
                       for x, d in zip(args, dims))

        in_batched = [True] * len(branches[0].in_avals)
        out_batched = [True] * len(branches[0].out_avals)

        branches_batched = [
            batching.batch_jaxpr(jaxpr, axis_size, in_batched, out_batched,
                                 axis_name, main_type)[0] for jaxpr in branches
        ]

        branch_outs = []
        for i, jaxpr in enumerate(branches_batched):
            # Perform a select on the inputs for safety of reverse-mode autodiff; see
            # https://github.com/google/jax/issues/1052
            predicate = lax.eq(index, lax._const(index, i))
            ops_ = [
                _bcast_select(predicate, x, lax.stop_gradient(x)) for x in ops
            ]
            branch_outs.append(core.jaxpr_as_fun(jaxpr)(*ops_))
        out = [_bcast_select_n(index, *outs) for outs in zip(*branch_outs)]
        return out, [0 if b else None for b in out_batched]
    else:
        ops_bat = [d is not batching.not_mapped for d in op_dims]
        ops = [
            batching.moveaxis(x, d, 0) if b else x
            for b, x, d in zip(ops_bat, ops, op_dims)
        ]

        branches_out_bat = [
            batching.batch_jaxpr(jaxpr, axis_size, ops_bat, False, axis_name,
                                 main_type)[1] for jaxpr in branches
        ]
        out_bat = [any(bat) for bat in zip(*branches_out_bat)]
        branches_batched = tuple(
            batching.batch_jaxpr(jaxpr, axis_size, ops_bat, out_bat, axis_name,
                                 main_type)[0] for jaxpr in branches)

        out_dims = [0 if b else batching.not_mapped for b in out_bat]
        out = cond_p.bind(index,
                          *ops,
                          branches=branches_batched,
                          linear=linear)
        return out, out_dims
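
The batched-predicate branch above (run every branch, then combine with a select) is what makes vmap over lax.cond work when the predicate itself is batched. A small sketch of that behavior with the public API:

import jax
import jax.numpy as jnp

def f(pred, x):
    return jax.lax.cond(pred, lambda x: x + 1.0, lambda x: x - 1.0, x)

preds = jnp.array([True, False, True])
xs = jnp.array([1.0, 2.0, 3.0])

print(jax.vmap(f)(preds, xs))   # [2. 1. 4.]: both branches run, a select picks per element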
Example No. 26
 def do_transpose(primals_in, cotangents_in):
     # NOTE: This is passing in undefined primals in place of tangent arguments, but it
     #       should all work out, because we're only computing the primal part here.
     residuals = core.jaxpr_as_fun(primal_jaxpr)(
         *primals_in)[len(cotangents_in):]
     # Now that we have a purely linear jaxpr, we can transpose it
     cotangents_out = backward_pass(tangent_jaxpr.jaxpr, reduce_axes, (),
                                    primals_in + residuals, cotangents_in)
     # backward_pass will return cotangents computed for all invars, but some of them
     # are residuals appended by partial eval, so we need to skip those before we return.
     return cotangents_out[:len(primals_in)]
Example No. 27
def _remat_transpose(primal_jaxpr, tangent_jaxpr, reduce_axes,
                     primals_tangents_in, cotangents_in):
  primals_in  = [x for x in primals_tangents_in if not is_undefined_primal(x)]
  tangents_in = [x for x in primals_tangents_in if     is_undefined_primal(x)]
  res = core.jaxpr_as_fun(primal_jaxpr)(*primals_in)
  cotangents_out_ = backward_pass(tangent_jaxpr.jaxpr, reduce_axes, False, (),
                                  (*res, *tangents_in), cotangents_in)
  cotangents_out = iter(cotangents_out_[len(res):])
  outs = [next(cotangents_out) if is_undefined_primal(x) else Zero.from_value(x)
          for x in primals_tangents_in]
  assert next(cotangents_out, None) is None
  return outs
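
Both transpose rules above re-run the primal jaxpr to recover residuals before transposing the linear part; that is part of the mechanism behind differentiating through jax.checkpoint / jax.remat, where primal values are recomputed in the backward pass rather than stored. A minimal sketch of the user-facing behavior:

import jax
import jax.numpy as jnp

f = jax.checkpoint(lambda x: jnp.sin(jnp.sin(x)))   # inner values are recomputed when transposing

print(jax.grad(f)(1.0))                       # cos(sin(1)) * cos(1)
print(jnp.cos(jnp.sin(1.0)) * jnp.cos(1.0))   # matches the rematerialized gradient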
Example No. 28
def _interpret_jaxpr(jaxpr: core.TypedJaxpr, *args: TfValOrUnit) -> Sequence[TfVal]:
  """Evaluates a Jaxpr with tf.Tensor arguments.

  It is safe to call this function with arguments TfVal or TfValOrUnit, they
  will be replaced with `core.unit` if the `jaxpr` expects units.

  The output is a sequence of TfVal (no `core.unit`), suitable for use with TF.
  """
  fun: lu.WrappedFun = lu.wrap_init(core.jaxpr_as_fun(jaxpr))
  args_jax: Sequence[TfValOrUnit] = _tfval_add_unit(args, jaxpr.in_avals)
  out_vals_jax: Sequence[TfValOrUnit] = _interpret_fun(fun, args_jax)
  return _tfval_remove_unit(out_vals_jax)
Example No. 29
def _masked_scan_jaxpr(jaxpr, num_consts, num_carry):
  fun = core.jaxpr_as_fun(jaxpr)

  @lu.wrap_init
  def masked(*args):
    [dynamic_length], consts, [i], carry, xs = split_list(
        args, [1, num_consts, 1, num_carry])
    out = fun(*(consts + carry + xs))
    new_carry, ys = split_list(out, [num_carry])
    new_carry = [lax.select(i < dynamic_length, new_c, c)
                 for new_c, c in zip(new_carry, carry)]
    return [i + 1] + new_carry + ys

  aval = ShapedArray((), onp.int64)
  const_avals, carry_avals, x_avals = split_list(jaxpr.in_avals, [num_consts, num_carry])
  return _make_typed_jaxpr(masked, [aval] + const_avals + [aval] + carry_avals + x_avals)
Example No. 30
def _get_harvest_metadata(closed_jaxpr, settings, *args):
    """Probes a jaxpr for metadata like its sown values."""
    fun = lu.wrap_init(jax_core.jaxpr_as_fun(closed_jaxpr))
    with jax_core.new_main(HarvestTrace) as main:
        settings = HarvestSettings(settings.tag, settings.blocklist,
                                   settings.allowlist, True)
        fun = reap_function(fun, main, settings, True)
        fun, aux = _reap_metadata_wrapper(fun)
        flat_args, in_tree = tree_util.tree_flatten(args)
        flat_fun, out_tree = api_util.flatten_fun_nokwargs(fun, in_tree)
        in_avals = jax_util.safe_map(
            lambda a: abstract_arrays.raise_to_shaped(jax_core.get_aval(a)),
            flat_args)
        pe.trace_to_jaxpr_final(flat_fun, in_avals)
        metadata = aux()
        out_tree()
    return metadata