コード例 #1
0
    def trace_to_jaxpr_finalize(in_tracers,
                                out_tracers,
                                trace,
                                instantiate=True):
        """Turn recorded input/output tracers into a TypedJaxpr plus constants.

        The output tracers are lowered, raised into `trace`, and optionally
        instantiated. The resulting jaxpr then has its constvars converted into
        leading input binders, so the returned TypedJaxpr takes the constants
        first, followed by the inputs.

        Args:
          in_tracers: tracers standing for the jaxpr inputs.
          out_tracers: tracers standing for the jaxpr outputs.
          trace: the trace the output tracers are raised into.
          instantiate: either a single flag applied to every output, or a list
            or tuple with one flag per output. Each flag is forwarded to
            `pe.instantiate_const_at` for the corresponding output.

        Returns:
          A pair `(typed_jaxpr, consts)`. The input avals of `typed_jaxpr` are
          the (shaped) avals of `consts` followed by those of `in_tracers`.

        Raises:
          ValueError: if `instantiate` is a sequence whose length differs from
            the number of outputs.
        """
        # TODO: This is the final part of the partial_eval.trace_to_subjaxpr. Share.
        if isinstance(instantiate, (list, tuple)):
            # Already one flag per output. Wrapping it again would pass the
            # whole list as each output's flag.
            instantiate = list(instantiate)
            if len(instantiate) != len(out_tracers):
                raise ValueError(
                    "instantiate has {} entries but there are {} outputs".format(
                        len(instantiate), len(out_tracers)))
        else:
            instantiate = [instantiate] * len(out_tracers)
        out_tracers = safe_map(trace.full_raise,
                               safe_map(core.full_lower, out_tracers))
        out_tracers = safe_map(partial(pe.instantiate_const_at, trace),
                               instantiate, out_tracers)
        jaxpr, consts, env = pe.tracers_to_jaxpr(in_tracers, out_tracers)
        out_pvals = [t.pval for t in out_tracers]
        # TODO: this is from partial_eval.trace_to_jaxpr. Share.
        assert not env

        # TODO: this is from the final part of lax_control_flow._initial_style_jaxpr
        out_avals = safe_map(abstract_arrays.raise_to_shaped,
                             unzip2(out_pvals)[0])
        const_avals = tuple(
            abstract_arrays.raise_to_shaped(core.get_aval(c)) for c in consts)

        in_pvals = [t.pval for t in in_tracers]
        in_avals = tuple(
            safe_map(abstract_arrays.raise_to_shaped,
                     unzip2(in_pvals)[0]))

        # Constants become leading input binders of the returned jaxpr.
        typed_jaxpr = core.TypedJaxpr(pe.convert_constvars_jaxpr(jaxpr), (),
                                      const_avals + in_avals, out_avals)
        return typed_jaxpr, consts
コード例 #2
0
ファイル: lax_control_flow.py プロジェクト: superbobry/jax
def rearrange_binders(f, typed_jaxpr):
  """Rearrange the input binders of a TypedJaxpr using `f`.

  `f` is applied to the jaxpr's input variables and, in the same way, to the
  TypedJaxpr's input avals, so variables and avals stay aligned. The literals
  and the output aval are carried over unchanged. The edited jaxpr is checked
  unless `core.skip_checks` is set.
  """
  new_jaxpr = typed_jaxpr.jaxpr.copy()
  new_jaxpr.invars = f(*new_jaxpr.invars)
  new_in_avals = f(*typed_jaxpr.in_avals)
  if not core.skip_checks:
    core.check_jaxpr(new_jaxpr)
  return core.TypedJaxpr(new_jaxpr, typed_jaxpr.literals, new_in_avals,
                         typed_jaxpr.out_aval)
コード例 #3
0
def _make_typed_jaxpr(traceable, in_avals):
    """Trace `traceable` on abstract inputs and wrap the result in a TypedJaxpr.

    Every input is given as a `pe.PartialVal((aval, core.unit))` and outputs
    are instantiated. The traced constants become the TypedJaxpr's literals.
    The output avals are the abstract parts of the output partial values.
    """
    abstract_inputs = [pe.PartialVal((aval, core.unit)) for aval in in_avals]
    traced = pe.trace_to_jaxpr(traceable, abstract_inputs, instantiate=True)
    jaxpr, pvals_out, consts = traced
    out_avals = unzip2(pvals_out)[0]
    return core.TypedJaxpr(jaxpr, consts, in_avals, out_avals)
コード例 #4
0
ファイル: lax_control_flow.py プロジェクト: proteneer/jax
def _make_typed_jaxpr(traceable, in_avals):
    """Trace `traceable` on abstract inputs into a single-output TypedJaxpr.

    Inputs are passed as `pe.PartialVal((aval, core.unit))` and the output is
    instantiated. The traced constants become the literals. The output's
    abstract part must be a `core.AbstractValue` (checked by assertion).
    """
    abstract_inputs = [pe.PartialVal((aval, core.unit)) for aval in in_avals]
    jaxpr, pval_out, consts = pe.trace_to_jaxpr(
        traceable, abstract_inputs, instantiate=True)
    # Only the abstract half of the output partial value is kept.
    out_aval, _known_part = pval_out
    assert isinstance(out_aval, core.AbstractValue)
    return core.TypedJaxpr(jaxpr, consts, in_avals, out_aval)
コード例 #5
0
ファイル: lax_control_flow.py プロジェクト: jonasrauber/jax
def _initial_style_jaxpr(fun, in_tree, in_avals):
  """Trace `fun` into a closure-converted TypedJaxpr for a control-flow primitive.

  Args:
    fun: Python callable whose positional arguments are described by `in_tree`.
    in_tree: treedef of `fun`'s flattened positional arguments.
    in_avals: abstract values of the flattened arguments. Any iterable is
      accepted; it is materialized once as a tuple.

  Returns:
    A triple `(typed_jaxpr, consts, out_tree)`:
      * `typed_jaxpr` takes the (shaped) avals of `consts` followed by
        `in_avals` as inputs, and its outputs are raised to shaped avals;
      * `consts` are the values to pass for the leading constant binders;
      * `out_tree` is the treedef of `fun`'s output.
  """
  # Materialize once: the avals are iterated to build the partial values and
  # again to build the signature below, where `tuple + list` would raise.
  in_avals = tuple(in_avals)
  in_pvals = [pe.PartialVal((aval, core.unit)) for aval in in_avals]
  fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
  jaxpr, out_pvals, consts = pe.trace_to_jaxpr(fun, in_pvals, instantiate=True)
  out_avals = _map(raise_to_shaped, unzip2(out_pvals)[0])
  const_avals = tuple(raise_to_shaped(core.get_aval(c)) for c in consts)
  # Closed-over constants become leading input binders.
  typed_jaxpr = core.TypedJaxpr(pe.closure_convert_jaxpr(jaxpr),
                                (), const_avals + in_avals, out_avals)
  return typed_jaxpr, consts, out_tree()
コード例 #6
0
def _rewrite_typed_jaxpr(
        tjaxpr: core.TypedJaxpr, has_input_token: bool,
        has_output_token: bool) -> Tuple[core.TypedJaxpr, bool]:
    """Rewrite a TypedJaxpr to thread the token through it, if needed.

    The underlying jaxpr is rewritten by `_rewrite_jaxpr`. The result is
    re-wrapped with the original literals and with input/output avals read
    off the rewritten jaxpr's binders.

    Returns:
      The rewritten TypedJaxpr, and whether it uses outfeed.
    """
    rewritten, uses_outfeed = _rewrite_jaxpr(tjaxpr.jaxpr, has_input_token,
                                             has_output_token)
    in_avals = tuple(var.aval for var in rewritten.invars)
    out_avals = tuple(var.aval for var in rewritten.outvars)
    typed = core.TypedJaxpr(rewritten, tjaxpr.literals, in_avals, out_avals)
    return typed, uses_outfeed
コード例 #7
0
 def wrapped(*args, **kwargs):
     """Trace the enclosing ``f`` on the given arguments into a TypedJaxpr.

     Keyword arguments are bound into ``f``; positional arguments are
     flattened as a pytree and abstracted with ``get_shaped_aval``. Tracing
     uses ``pe.StagingJaxprTrace`` with ``stage_out=True`` and instantiated
     outputs. The traced constants become the TypedJaxpr's literals.

     Returns:
       ``(typed_jaxpr, (in_tree, out_tree))`` with the input and output treedefs.
     """
     wrapped_fun = lu.wrap_init(f, kwargs)
     leaves, in_tree = tree_util.tree_flatten(args)
     flat_fun, out_tree = api_util.flatten_fun_nokwargs(wrapped_fun, in_tree)
     in_avals = safe_map(get_shaped_aval, leaves)
     in_pvals = [pe.PartialVal((aval, jax_core.unit)) for aval in in_avals]
     jaxpr, out_pvals, consts = pe.trace_to_jaxpr(
         flat_fun, in_pvals, instantiate=True, stage_out=True,
         trace_type=pe.StagingJaxprTrace)
     out_avals = [pv.get_aval() for pv in out_pvals]
     typed = jax_core.TypedJaxpr(jaxpr, consts, in_avals, out_avals)
     return typed, (in_tree, out_tree())
コード例 #8
0
def tie_the_knot(typed_jaxpr):
    """Feed each output of `typed_jaxpr` back in as the matching input.

    Every use of the i-th input variable inside the equations is replaced by
    the i-th output variable, and the input binders are dropped. The result is
    a TypedJaxpr with no inputs, the same constvars, outputs and literals.

    Requires the same number of inputs and outputs, with pairwise-equal avals.
    """
    jaxpr, _, in_avals, out_avals = typed_jaxpr
    # zip() stops at the shorter sequence, so check the arities explicitly: an
    # input with no matching output would stay referenced by the equations
    # after its binder is removed.
    assert len(in_avals) == len(out_avals), (
        "tie_the_knot needs as many inputs as outputs, got {} and {}".format(
            len(in_avals), len(out_avals)))
    assert all(i == o for i, o in zip(in_avals, out_avals))
    in2out = dict(zip(jaxpr.invars, jaxpr.outvars))

    def replace(eqn):
        """Return a copy of `eqn` with input-variable operands swapped for outputs."""
        # Only Var operands can be keys of in2out; literals pass through as-is.
        invars = [
            in2out[i] if (isinstance(i, jc.Var) and i in in2out) else i
            for i in eqn.invars
        ]
        return jc.JaxprEqn(invars, eqn.outvars, eqn.primitive, eqn.params,
                           eqn.source_info)

    eqns = [replace(eqn) for eqn in jaxpr.eqns]
    # NOTE(review): outvars are kept unchanged; if an outvar is itself one of
    # the removed invars it ends up unbound -- confirm callers never pass that.
    new_jaxpr = jc.Jaxpr(jaxpr.constvars, [], jaxpr.outvars, eqns)
    return jc.TypedJaxpr(new_jaxpr, typed_jaxpr.literals, [],
                         typed_jaxpr.out_avals)
コード例 #9
0
ファイル: lax_control_flow.py プロジェクト: superbobry/jax
def _move_stuff_and_add_add(typed_jaxpr):
  """Rewire a lifted transposed scan body into the scan-body calling convention.

  Converts a TypedJaxpr shaped like `jaxpr_lifted_trans` into one shaped like
  `jaxpr_trans` (signatures below). It rebinds the inputs and repacks the
  outputs via pack/unpack equations. It also adds the incoming `CT d` to the
  body's produced `CT d` with an add-any equation, so that value is
  accumulated across iterations.
  """
  # jaxpr_lifted_trans :: res -> (CT c, CT b) -> (CT d, CT c, CT a)
  # jaxpr_trans :: * -> (CT c, CT d) -> (CT b, res) -> ((CT c, CT d), CT a)

  res_aval, (CTc_aval, CTb_aval) = typed_jaxpr.in_avals
  CTd_aval, CTc_aval2, CTa_aval = typed_jaxpr.out_aval
  # The CT c coming in and the CT c going out must have the same type (carry).
  assert CTc_aval == CTc_aval2
  # New signature: an empty-tuple first argument, then (CT c, CT d), then
  # (CT b, res); output ((CT c, CT d), CT a).
  in_avals = (core.AbstractTuple(()), core.AbstractTuple((CTc_aval, CTd_aval)),
              core.AbstractTuple((CTb_aval, res_aval)))
  out_aval = core.AbstractTuple((core.AbstractTuple((CTc_aval, CTd_aval)),
                                 CTa_aval))

  # Edit a copy so the caller's jaxpr is left untouched.
  jaxpr = typed_jaxpr.jaxpr.copy()
  # assume the jaxpr isn't restructuring any inputs
  assert not any(type(invar) is tuple for invar in jaxpr.invars)

  # munge input side: bind fresh variables in the new tuple layout and prepend
  # a pack equation that rebuilds the original (CT c, CT b) input variable.
  CTc_in = _scan_newvar()
  CTb_in = _scan_newvar()
  CTd_in = _scan_newvar()
  res_in, CTc_CTb_in = jaxpr.invars
  jaxpr.invars = ((), (CTc_in, CTd_in), (CTb_in, res_in))
  jaxpr.eqns = (
      [pe._pack_eqn([CTc_in, CTb_in], CTc_CTb_in)] +
      jaxpr.eqns)

  # munge output side: unpack the original (CT d, CT c, CT a) output, add the
  # incoming CT d to the newly produced one, and repack as ((CT c, CT d), CT a).
  CTd_new = _scan_newvar()
  CTd_sum = _scan_newvar()
  CTc = _scan_newvar()
  CTa = _scan_newvar()
  partial_out = _scan_newvar()
  outvar = _scan_newvar()
  jaxpr.eqns = (
      jaxpr.eqns +
      [pe._unpack_eqn(jaxpr.outvar, [CTd_new, CTc, CTa]),
       _add_any_eqn(CTd_sum, CTd_new, CTd_in),
       pe._pack_eqn([CTc, CTd_sum], partial_out),
       pe._pack_eqn([partial_out, CTa], outvar)])
  jaxpr.outvar = outvar

  # TODO(mattjj): add a check_typed_jaxpr and use it here
  core.skip_checks or core.check_jaxpr(jaxpr)
  return core.TypedJaxpr(jaxpr, typed_jaxpr.literals, in_avals, out_aval)
コード例 #10
0
 def wrapped(*args, **kwargs):
   """Trace the enclosing ``f`` on the given arguments into a TypedJaxpr.

   Keyword arguments are bound into ``f``; positional arguments are flattened
   as a pytree and abstracted with ``get_shaped_aval``. When ``dynamic`` is
   truthy, ``pe.trace_to_jaxpr_dynamic`` is used. Otherwise ``pe.trace_to_jaxpr``
   is used with instantiated outputs.

   Returns:
     ``(typed_jaxpr, (in_tree, out_tree))`` with the input and output treedefs.

   Raises:
     ValueError: if JAX omnistaging is not enabled.
   """
   fun = lu.wrap_init(f, kwargs)
   leaves, in_tree = tree_util.tree_flatten(args)
   flat_fun, out_tree = api_util.flatten_fun_nokwargs(fun, in_tree)
   in_avals = safe_map(get_shaped_aval, leaves)
   if not jax.config.omnistaging_enabled:
     raise ValueError('Oryx must be used with JAX omnistaging enabled.')
   if not dynamic:
     abstract_pvals = [pe.PartialVal((aval, jax_core.unit)) for aval in in_avals]
     jaxpr, out_pvals, consts = pe.trace_to_jaxpr(
         flat_fun, abstract_pvals, instantiate=True)
     out_avals = [pv.get_aval() for pv in out_pvals]
   else:
     jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals)
   typed = jax_core.TypedJaxpr(jaxpr, consts, in_avals, out_avals)
   return typed, (in_tree, out_tree())
コード例 #11
0
ファイル: host_callback.py プロジェクト: zohbahk/jax
def _mk_typed_jaxpr(jaxpr: core.Jaxpr, literals: Sequence) -> core.TypedJaxpr:
  """Wrap `jaxpr` in a TypedJaxpr whose avals are read off its own binders."""
  in_avals = tuple(var.aval for var in jaxpr.invars)
  out_avals = tuple(var.aval for var in jaxpr.outvars)
  return core.TypedJaxpr(jaxpr, literals, in_avals, out_avals)
コード例 #12
0
ファイル: invertible_ad.py プロジェクト: xiaoral2/jax
def inv_backward_pass(jaxpr: core.Jaxpr, consts, primals_in, primals_out, cotangents_in):
  """Transpose `jaxpr`, reconstructing unknown primals by inverting equations.

  Equations are visited in reverse. For each equation whose outputs have a
  known primal and a non-zero cotangent, an "ivjp" jaxpr is obtained:
    * from the `ivjp_jaxpr` param for `custom_ivjp_p`;
    * from `primitive_ivjps`; or
    * synthesized via `synthesize_ivjp`.
  The ivjp jaxpr is partially evaluated against the values known at that
  point, and its known part is run. This recovers the equation's input
  primals (where possible) and its input cotangents.

  Args:
    jaxpr: the jaxpr being transposed.
    consts: values for `jaxpr.constvars`.
    primals_in: values (possibly `ad.UndefinedPrimal`) for `jaxpr.invars`.
    primals_out: values for `jaxpr.outvars`.
    cotangents_in: cotangents for `jaxpr.outvars`.

  Returns:
    One cotangent per entry of `jaxpr.invars`, with `ad.Zero` where nothing
    was accumulated.
  """
  # Fast path: with all-zero output cotangents, every input cotangent is zero.
  if all(type(ct) is ad.Zero for ct in cotangents_in):
    return map(lambda v: ad.Zero(v.aval), jaxpr.invars)

  def write_cotangent(v, ct):
    # Accumulate (via ad.add_tangents) into any cotangent already stored for v.
    # assert v not in primal_env
    if ct is not None and type(v) is not Literal:
      ct_env[v] = ad.add_tangents(ct_env[v], ct) if v in ct_env else ct

  def read_cotangent(v):
    return ct_env.get(v, ad.Zero(v.aval))

  def read_primal(v):
    # Variables with no recorded value read as UndefinedPrimal.
    if type(v) is Literal:
      return v.val
    else:
      return primal_env.get(v, ad.UndefinedPrimal(v.aval))

  def write_primal(v, val):
    # First write wins (setdefault). NOTE(review): UndefinedPrimal values are
    # stored too (see rec_primals_in below), so a value recovered later for the
    # same variable cannot replace them -- confirm this is intended.
    if type(v) is Literal:
      return
    primal_env.setdefault(v, val)

  # Invert while computing cotangents
  ct_env: Dict[Any, Any] = {}
  primal_env: Dict[Any, Any] = {}
  write_primal(core.unitvar, core.unit)
  map(write_primal, jaxpr.invars, primals_in)
  map(write_primal, jaxpr.outvars, primals_out)
  map(write_primal, jaxpr.constvars, consts)
  map(write_cotangent, jaxpr.outvars, cotangents_in)
  for eqn in jaxpr.eqns[::-1]:
    primals_in = map(read_primal, eqn.invars)
    primals_out = map(read_primal, eqn.outvars)
    cts_in = map(read_cotangent, eqn.outvars)
    should_invert = any(type(primal) is not ad.UndefinedPrimal
                        for primal in primals_out)
    should_vjp = any(type(ct) is not ad.Zero for ct in cts_in)
    assert not eqn.primitive.call_primitive
    assert not (should_invert ^ should_vjp)  # Either both true or both false

    # Skip primals equations that are only jvp coefficients and don't affect
    # primal outputs.
    if not should_invert and not should_vjp:
      continue

    def abstract(value):
      # Shaped aval of either a known value or an UndefinedPrimal placeholder.
      return raise_to_shaped(value.aval if ad.is_undefined_primal(value) else get_aval(value))

    # Get the ivjp_jaxpr
    if eqn.primitive is custom_ivjp_p:
      ivjp_jaxpr = eqn.params['ivjp_jaxpr']
    else:
      if eqn.primitive in primitive_ivjps:
        complete_ivjp = lu.wrap_init(primitive_ivjps[eqn.primitive])
      else:
        complete_ivjp = lu.wrap_init(partial(synthesize_ivjp, eqn, map(ad.is_undefined_primal, primals_in)))
      # The ivjp takes (primals_in, primals_out, cotangents). primals_out is
      # used a second time in place of the cotangents, presumably because the
      # cotangents share the primal outputs' avals (they may be ad.Zero here).
      _, in_tree = tree_flatten(
          tuple(map(abstract, x) for x in (primals_in, primals_out, primals_out)))
      complete_ivjp_flat, _ = flatten_fun_nokwargs(complete_ivjp, in_tree)

      in_avals = map(abstract, primals_in + primals_out + primals_out)
      if config.omnistaging_enabled:
        # TODO: Actually we do know some of the inputs, because they might be literals!
        ivjp_jaxpr, out_pvals, _ = pe.trace_to_jaxpr(
            complete_ivjp_flat, map(pe.PartialVal.unknown, in_avals), instantiate=True)
      else:
        ivjp_jaxpr, out_pvals, _ = pe.trace_to_jaxpr(
          complete_ivjp_flat, map(pe.PartialVal.unknown, in_avals),
          instantiate=True, stage_out=False)
      assert not ivjp_jaxpr.constvars  # That might happen some time, but don't bother until then
      out_avals = map(raise_to_shaped, unzip2(out_pvals)[0])
      ivjp_jaxpr = core.TypedJaxpr(ivjp_jaxpr, [], in_avals, out_avals)

    # Once we know what the ivjp can do exactly, we have to isolate the part we are
    # actually able to compute with the values we have at hand.
    num_inputs = len(eqn.invars)
    # Unknown ivjp inputs: undefined primals (in or out); cotangents are known.
    unknowns = (map(ad.is_undefined_primal, primals_in) +
                map(ad.is_undefined_primal, primals_out) +
                [False] * len(cts_in))
    if config.omnistaging_enabled:
      jaxpr_known, jaxpr_unknown, out_unknowns = pe.partial_eval_jaxpr(  # type: ignore
          ivjp_jaxpr, unknowns, instantiate=False)  # type:ignore
    else:
      jaxpr_known, jaxpr_unknown, out_unknowns = pe.partial_eval_jaxpr(
          ivjp_jaxpr, unknowns, instantiate=False, trace_type=None)
    # ivjp outputs are (reconstructed primals_in, cotangents for the inputs).
    unknown_rec_primals_in, unknown_cotangents = split_list(out_unknowns, [num_inputs])
    # Make sure we're able to compute all cotangents. We don't really care if we
    # can reconstruct our primals or not, although failure to do so might result in
    # failing to compute cotangents later.
    assert not any(unknown_cotangents)
    # Remove residual outputs -- we won't be computing the unknown jaxpr anyway.
    # Keep only as many leading outputs as jaxpr_unknown has outvars.
    num_outputs = len(jaxpr_unknown.jaxpr.outvars)
    jaxpr_known.out_avals = jaxpr_known.out_avals[:num_outputs]
    jaxpr_known.jaxpr.outvars = jaxpr_known.jaxpr.outvars[:num_outputs]
    # TODO: We could drop the outputs that correspond to primals that we already know.
    #       This only matters in eager mode, so leaving it out for now...
    ivjp = core.jaxpr_as_fun(jaxpr_known)
    rec_primals_in, cts_out = split_list(ivjp(*primals_in, *primals_out, *cts_in),
                                         [num_inputs])
    # Unknown rec_primals_in are core.units, so we replace them with the
    # previous values from primals_in (UndefinedPrimals in that case).
    # NOTE(review): write_primal does not skip UndefinedPrimal values; it
    # stores them via setdefault (see write_primal above).
    rec_primals_in = [prev if unknown else rec
                      for prev, rec, unknown
                      in zip(primals_in, rec_primals_in, unknown_rec_primals_in)]
    map(write_primal, eqn.invars, rec_primals_in)
    map(write_cotangent, [v for v in eqn.invars if type(v) is not Literal], cts_out)

  # NOTE: We keep the cotangents associated with primal variables, while the contract of a
  #       transpose is to return them in positions associated with tangent variables, which
  #       is what causes this whole confusion.
  return map(read_cotangent, jaxpr.invars)
コード例 #13
0
ファイル: lax_control_flow.py プロジェクト: superbobry/jax
def scan(f, init, xs):
  """Scan a function over leading array axes while carrying along state.

  In brief, the type signature is

  .. code-block:: haskell

    scan :: (c -> a -> (c, b)) -> c -> [a] -> (c, [b])

  where ``[t]`` denotes the type ``t`` with one extra leading axis. For an array
  type ``t``, ``[t]`` is that array type with an added leading axis. For a
  pytree (container) type ``t`` with array leaves, ``[t]`` is the pytree of the
  same structure whose leaves each carry an added leading axis.

  When ``a`` and ``b`` are both array types, ``scan`` behaves like this Python
  implementation::

    def scan(f, init, xs):
      carry = init
      ys = []
      for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
      return carry, np.stack(ys)

  Unlike that Python version, ``a`` and ``b`` may be arbitrary pytree types.
  Several arrays can therefore be scanned over at once, producing several
  output arrays.

  Also unlike that Python version, ``scan`` is a JAX primitive lowered to a
  single XLA While HLO. This helps keep compilation times down for
  jit-compiled functions, because native Python loops inside an ``@jit``
  function are unrolled into large XLA computations.

  Args:
    f: the Python function to scan, of type ``c -> a -> (c, b)``. Its first
      argument is a loop-carry value and its second is a slice of ``xs``
      along the leading axis. It returns a pair: the new loop-carry value,
      and a slice of the output.
    init: the initial loop-carry value of type ``c``: a scalar, an array, or
      any pytree (nested Python tuple/list/dict) thereof.
    xs: the value of type ``[a]`` to scan over along the leading axis: an
      array or any pytree (nested Python tuple/list/dict) thereof whose leaves
      have consistent leading axis sizes.

  Returns:
    A pair of type ``(c, [b])``: the final loop-carry value, and the stacked
    second outputs of ``f`` across the leading axis of the inputs.

  Raises:
    TypeError: if ``f`` does not return a pair, or if the carry it returns
      has a different abstract value than the carry it received.
  """
  # Convert the pytree arguments into JaxTuple trees and wrap ``f`` to match.
  (init, xs), in_trees = unzip2(map(pytree_to_jaxtupletree, (init, xs)))
  f, out_tree = pytree_fun_to_jaxtupletree_fun(lu.wrap_init(f), in_trees)

  # Abstract inputs for one loop iteration: the whole carry and one slice of
  # ``xs`` (its aval demoted by ``_demote_aval_rank``).
  carry_pval = _abstractify(init)
  carry_aval, _ = carry_pval
  xs_aval, _ = _abstractify(xs)
  x_aval = _demote_aval_rank(xs_aval)
  x_pval = pe.PartialVal((x_aval, core.unit))

  body_jaxpr, pval_out, consts = pe.trace_to_jaxpr(
      f, (carry_pval, x_pval), instantiate=True)
  out_pv, out_const = pval_out
  assert isinstance(out_pv, core.AbstractValue) and out_const == core.unit

  returns_pair = isinstance(out_pv, core.AbstractTuple) and len(out_pv) == 2
  if not returns_pair:
    msg = ("scanned function must have signature `c -> a -> (c, b)`, but the "
           "output was not a pair: got type {}.")
    raise TypeError(msg.format(out_pv))
  carry_aval_out, y_aval = out_pv
  if carry_aval != carry_aval_out:
    msg = ("scanned function carry output does not match carry input: "
           "input carry is {} and output carry is {}.")
    raise TypeError(msg.format(carry_aval, carry_aval_out))

  # The traced constants become the first argument of the scan body.
  lifted_jaxpr = pe._closure_convert_jaxpr(body_jaxpr)
  consts_aval, _ = _abstractify(core.pack(consts))
  typed_body = core.TypedJaxpr(
      lifted_jaxpr, (),
      (consts_aval, carry_aval, x_aval),
      core.AbstractTuple((carry_aval, y_aval)))
  length = _leading_dim_size(xs)
  out = scan_p.bind(core.pack(consts), init, xs,
                    forward=True, length=length, jaxpr=typed_body)
  return build_tree(out_tree(), out)