Example #1
def partial_eval_by_shape(fn, input_spec, *args, **kwargs):
  """Lazily evaluate a function by using the shapes of the inputs.

  This function is similar to `jax.eval_shape` with the key difference that
  function outputs that can be computed without a concrete value of the
  inputs are returned as is instead of only the shape. See for example
  `module.init_by_shape` where this functionality is used to initialize a
  model without using input data or computation.

  Args:
    fn: the function to be lazily evaluated.
    input_spec: an iterable of shapes or (shape, dtype) tuples specifying the
      shape and type of the inputs. If unspecified the dtype is float32.
    *args: other arguments passed to the module's apply function
    **kwargs: keyword arguments passed to the module's apply function
  Returns:
    The output of fn, where outputs that could be computed without concrete
    inputs are returned as actual values and all remaining outputs are
    jax.ShapeDtypeStruct instances.
  """
  # output cannot be returned in lazy_create because jax.eval_shape will only
  # return the shape and dtype.
  # TODO(mattjj,jheek): use a public JAX API
  f = lambda *inputs: fn(*inputs, *args, **kwargs)
  input_structs = [_parse_spec(spec) for spec in input_spec]
  inputs_flat, in_tree = jax.tree_flatten(input_structs)
  f_flat, out_tree = jax.api_util.flatten_fun_nokwargs(lu.wrap_init(f), in_tree)
  in_pvals = [pe.PartialVal.unknown(jax.ShapedArray(x.shape, x.dtype))
              for x in inputs_flat]

  if config.omnistaging_enabled:
    _, out_pvals, _ = pe.trace_to_jaxpr(f_flat, in_pvals)
  else:
    with jax.core.initial_style_staging():
      _, out_pvals, _ = pe.trace_to_jaxpr(f_flat, in_pvals, stage_out=True)
  out_flat = [const if pv is None else jax.ShapeDtypeStruct(pv.shape, pv.dtype)
              for pv, const in out_pvals]
  return jax.tree_unflatten(out_tree(), out_flat)
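
For comparison, the public jax.eval_shape covers the shape-only half of this helper. A minimal sketch (the toy model below is an assumption for illustration, not part of the original):

import jax
import jax.numpy as jnp

def model(x):
    # The first output's shape depends on x; the second is a data-independent constant.
    return jnp.dot(x, jnp.ones((3, 4))), 42

# jax.eval_shape reports shape/dtype structs for every output, whereas
# partial_eval_by_shape above would return the constant 42 as a concrete value.
outs = jax.eval_shape(model, jax.ShapeDtypeStruct((2, 3), jnp.float32))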
Example #2
 def wrapped(*args, **kwargs):
     fun = lu.wrap_init(f, kwargs)
     flat_args, in_tree = tree_util.tree_flatten(args)
     flat_fun, out_tree = api_util.flatten_fun_nokwargs(fun, in_tree)
     flat_avals = safe_map(get_shaped_aval, flat_args)
     pvals = [pe.PartialVal((aval, jax_core.unit)) for aval in flat_avals]
     jaxpr, out_pvals, consts = pe.trace_to_jaxpr(
         flat_fun,
         pvals,
         instantiate=True,
         stage_out=True,
         trace_type=pe.StagingJaxprTrace)
     out_avals = [pval.get_aval() for pval in out_pvals]
     typed_jaxpr = jax_core.TypedJaxpr(jaxpr, consts, flat_avals, out_avals)
     return typed_jaxpr, (in_tree, out_tree())
Example #3
 def transposed(*args):
     in_primals, out_cts = tree_unflatten(treedef, args)
     in_pvals = [
         pe.PartialVal.unknown(x.aval)
         if ad.is_undefined_primal(x) else pe.PartialVal.known(x)
         for x in in_primals
     ]
     primal_fun = lu.wrap_init(partial(core.eval_jaxpr, jaxpr, ()))
     tangent_jaxpr, _, consts = pe.trace_to_jaxpr(primal_fun, in_pvals,
                                                  False)
     dummy_args = [ad.UndefinedPrimal(v.aval) for v in tangent_jaxpr.invars]
     in_cts_ = ad.backward_pass(tangent_jaxpr, reduce_axes, False, consts,
                                dummy_args, out_cts)
     in_cts, cell.treedef = tree_flatten(in_cts_)
     return in_cts
Example #4
    def scan_fn(broadcast_in, init, *args):
        xs = jax.tree_multimap(transpose_to_front, in_axes, args)

        def body_fn(c, xs, init_mode=False):
            # inject constants
            xs = jax.tree_multimap(
                lambda ax, arg, x: (arg if ax is broadcast else x), in_axes,
                args, xs)
            broadcast_out, c, ys = fn(broadcast_in, c, *xs)

            if init_mode:
                ys = jax.tree_multimap(
                    lambda ax, y: (y if ax is broadcast else ()), out_axes, ys)
                return broadcast_out, ys
            else:
                ys = jax.tree_multimap(
                    lambda ax, y: (() if ax is broadcast else y), out_axes, ys)
                return c, ys

        broadcast_body = functools.partial(body_fn, init_mode=True)

        carry_pvals = jax.tree_map(
            lambda x: pe.PartialVal.unknown(
                jax.ShapedArray(jnp.shape(x), jnp.result_type(x))), init)
        scan_pvals = jax.tree_map(
            lambda x: pe.PartialVal.unknown(
                jax.ShapedArray(jnp.shape(x)[1:], jnp.result_type(x))), xs)
        input_pvals = (carry_pvals, scan_pvals)
        in_pvals, in_tree = jax.tree_flatten(input_pvals)
        f_flat, out_tree = jax.api_util.flatten_fun_nokwargs(
            lu.wrap_init(broadcast_body), in_tree)
        _, out_pvals, _ = pe.trace_to_jaxpr(f_flat, in_pvals)

        out_flat = []
        for pv, const in out_pvals:
            if pv is not None:
                raise ValueError(
                    'broadcasted variable has a data dependency on the scan body.'
                )
            out_flat.append(const)
        broadcast_in, constants_out = jax.tree_unflatten(out_tree(), out_flat)

        c, ys = lax.scan(body_fn, init, xs, length=length, reverse=reverse)
        ys = jax.tree_multimap(transpose_from_front, out_axes, ys)
        ys = jax.tree_multimap(
            lambda ax, const, y: (const if ax is broadcast else y), out_axes,
            constants_out, ys)
        return broadcast_in, c, ys
Example #5
def _get_jax_objects(function, args, kwargs):
    # Set up function for transformation
    wrapped_function = j_linear_util.wrap_init(function)
    # Flatten input arguments
    jax_arguments, in_tree = j_tree_util.tree_flatten((args, kwargs))
    # Transform function to accept flat arguments
    # and return a flat list result
    jaxtree_function, _ = j_api_util.flatten_fun(wrapped_function, in_tree)
    # Convert the flat arguments to abstract partial values
    partial_values = j_util.safe_map(_get_partial_value, jax_arguments)
    # Trace function into Jaxpr
    jaxpr, _, constants = ji_partial_eval.trace_to_jaxpr(
        jaxtree_function, partial_values
    )

    result = (jaxpr, constants)
    return result
Example #6
 def wrapped(*args, **kwargs):
   fun = lu.wrap_init(f, kwargs)
   flat_args, in_tree = tree_util.tree_flatten(args)
   flat_fun, out_tree = api_util.flatten_fun_nokwargs(fun, in_tree)
   flat_avals = safe_map(get_shaped_aval, flat_args)
   if dynamic:
     jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(
         flat_fun,
         flat_avals)
   else:
     pvals = [pe.PartialVal.unknown(aval) for aval in flat_avals]
     jaxpr, _, consts = pe.trace_to_jaxpr(
         flat_fun,
         pvals,
         instantiate=True)
   typed_jaxpr = jax_core.ClosedJaxpr(jaxpr, consts)
   return typed_jaxpr, (in_tree, out_tree())
Example #7
 def wrapped(*args, **kwargs):
     fun = lu.wrap_init(f, kwargs)
     flat_args, in_tree = tree_util.tree_flatten(args)
     flat_fun, out_tree = api_util.flatten_fun_nokwargs(fun, in_tree)
     flat_avals = safe_map(get_shaped_aval, flat_args)
     if not jax.config.omnistaging_enabled:
         raise ValueError('Oryx must be used with JAX omnistaging enabled.')
     if dynamic:
         jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, flat_avals)
     else:
         pvals = [
             pe.PartialVal((aval, jax_core.unit)) for aval in flat_avals
         ]
         jaxpr, _, consts = pe.trace_to_jaxpr(flat_fun,
                                              pvals,
                                              instantiate=True)
     typed_jaxpr = jax_core.ClosedJaxpr(jaxpr, consts)
     return typed_jaxpr, (in_tree, out_tree())
Example #8
def closure_convert(fun, in_tree, in_avals):
    in_pvals = [pe.PartialVal.unknown(aval) for aval in in_avals]
    wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
    with core.initial_style_staging():
        jaxpr, out_pvals, consts = pe.trace_to_jaxpr(wrapped_fun,
                                                     in_pvals,
                                                     instantiate=True,
                                                     stage_out=False)
    out_tree = out_tree()
    num_consts = len(consts)

    def converted_fun(y, t, *consts_args):
        consts, args = split_list(consts_args, [num_consts])
        all_args, in_tree2 = tree_flatten((y, t, *args))
        assert in_tree == in_tree2
        out_flat = core.eval_jaxpr(jaxpr, consts, *all_args)
        return tree_unflatten(out_tree, out_flat)

    return converted_fun, consts
Example #9
File: ad.py Project: jbampton/jax
def linearize(traceable, *primals, **kwargs):
  has_aux = kwargs.pop('has_aux', False)
  if not has_aux:
    jvpfun = jvp(traceable)
  else:
    jvpfun, aux = jvp(traceable, has_aux=True)

  in_pvals = (tuple(pe.PartialVal.known(p) for p in primals)
              + tuple(pe.PartialVal.unknown(get_aval(p).at_least_vspace())
                    for p in primals))
  _, in_tree = tree_flatten(((primals, primals), {}))
  jvpfun_flat, out_tree = flatten_fun(jvpfun, in_tree)
  jaxpr, out_pvals, consts = pe.trace_to_jaxpr(jvpfun_flat, in_pvals)
  out_primals_pvals, out_tangents_pvals = tree_unflatten(out_tree(), out_pvals)
  assert all(out_primal_pval.is_known() for out_primal_pval in out_primals_pvals)
  _, out_primals_consts = unzip2(out_primals_pvals)
  jaxpr.invars = jaxpr.invars[len(primals):]
  jaxpr.outvars = jaxpr.outvars[len(out_primals_pvals):]
  if not has_aux:
    return out_primals_consts, out_tangents_pvals, jaxpr, consts
  else:
    return out_primals_consts, out_tangents_pvals, jaxpr, consts, aux()
Example #10
File: util.py Project: ColCarroll/numpyro
def _tscan(f, a, bs, fields=(0, )):
    """
    Works as jax.lax.scan but has an additional `fields` argument to select
    only the necessary fields from `a`'s structure. Defaults to selecting only
    the first field. Other fields will be filled with None.
    """
    # Note: code is copied and modified from lax.scan implementation in
    # [JAX](https://github.com/google/jax) to support the additional `fields`
    # arg. Original code has the following copyright:
    #
    # Copyright 2018 Google LLC
    #
    # Licensed under the Apache License, Version 2.0 (the "License")

    # convert pytree to flat jaxtuple
    a, a_tree = pytree_to_flatjaxtuple(a)
    bs, b_tree = pytree_to_flatjaxtuple(bs)
    fields, _ = pytree_to_flatjaxtuple(fields)
    f, out_tree = pytree_fun_to_flatjaxtuple_fun(wrap_init(f),
                                                 (a_tree, b_tree))

    # convert arrays to abstract values
    a_aval, _ = lax._abstractify(a)
    bs_aval, _ = lax._abstractify(bs)
    # convert bs to b
    b_aval = core.AbstractTuple(
        [ShapedArray(b.shape[1:], b.dtype) for b in bs_aval])

    # convert abstract values to partial values (?) then evaluate to get jaxpr
    a_pval = partial_eval.PartialVal((a_aval, core.unit))
    b_pval = partial_eval.PartialVal((b_aval, core.unit))
    jaxpr, pval_out, consts = partial_eval.trace_to_jaxpr(f, (a_pval, b_pval))
    aval_out, _ = pval_out
    consts = core.pack(consts)

    out = tscan_p.bind(a, bs, fields, consts, aval_out=aval_out, jaxpr=jaxpr)
    return tree_unflatten(out_tree(), out)
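
A minimal sketch of the plain lax.scan baseline that _tscan restricts; the toy f and inputs here are assumed for illustration:

import jax.numpy as jnp
from jax import lax

def f(carry, b):
    # The carry is a 2-field structure; _tscan(f, a, bs, fields=(0,)) would
    # propagate only field 0 and return None for the other fields.
    total, prod = carry
    return (total + b, prod * b), total + b

final_carry, ys = lax.scan(f, (jnp.zeros(()), jnp.ones(())), jnp.arange(1.0, 5.0))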
Example #11
File: ode.py Project: yangliuy/jax
def closure_convert(fun, in_tree, in_avals):
  in_pvals = [pe.PartialVal.unknown(aval) for aval in in_avals]
  wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
  with core.initial_style_staging():
    jaxpr, out_pvals, consts = pe.trace_to_jaxpr(
      wrapped_fun, in_pvals, instantiate=True, stage_out=False)
  out_tree = out_tree()

  # We only want to closure convert for constants with respect to which we're
  # differentiating. As a proxy for that, we hoist consts with float dtype.
  # TODO(mattjj): revise this approach
  is_float = lambda c: dtypes.issubdtype(dtypes.dtype(c), jnp.inexact)
  (closure_consts, hoisted_consts), merge = partition_list(is_float, consts)
  num_consts = len(hoisted_consts)

  def converted_fun(y, t, *hconsts_args):
    hoisted_consts, args = split_list(hconsts_args, [num_consts])
    consts = merge(closure_consts, hoisted_consts)
    all_args, in_tree2 = tree_flatten((y, t, *args))
    assert in_tree == in_tree2
    out_flat = core.eval_jaxpr(jaxpr, consts, *all_args)
    return tree_unflatten(out_tree, out_flat)

  return converted_fun, hoisted_consts
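
The hoisting predicate above can be exercised on its own. A minimal sketch using the public jnp helpers in place of the internal dtypes module:

import jax.numpy as jnp

# Floating-point (inexact) constants are hoisted into explicit arguments so they
# can be differentiated; integer and boolean constants stay closed over.
is_float = lambda c: jnp.issubdtype(jnp.result_type(c), jnp.inexact)
assert is_float(jnp.float32(1.0))
assert not is_float(5)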
Example #12
def _make_typed_jaxpr(traceable, in_avals):
  pvals = [pe.PartialVal((aval, core.unit)) for aval in in_avals]
  jaxpr, pvals_out, consts = pe.trace_to_jaxpr(traceable, pvals, instantiate=True)
  out_avals, _ = unzip2(pvals_out)
  return core.TypedJaxpr(jaxpr, consts, in_avals, out_avals)
Example #13
def while_loop(cond_fun, body_fun, init_val):
  """Call ``body_fun`` repeatedly in a loop while ``cond_fun`` is True.

  The type signature in brief is

  .. code-block:: haskell

    while_loop :: (a -> Bool) -> (a -> a) -> a -> a

  The semantics of ``while_loop`` are given by this Python implementation::

    def while_loop(cond_fun, body_fun, init_val):
      val = init_val
      while cond_fun(val):
        val = body_fun(val)
      return val

  Unlike that Python version, ``while_loop`` is a JAX primitive and is lowered
  to a single XLA While HLO. That makes it useful for reducing compilation times
  for jit-compiled functions, since native Python loop constructs in an ``@jit``
  function are unrolled, leading to large XLA computations.

  Another difference from using Python-native loop constructs is that
  ``while_loop`` is not reverse-mode differentiable because XLA computations
  require static bounds on memory requirements.

  Args:
    cond_fun: function of type ``a -> Bool``.
    body_fun: function of type ``a -> a``.
    init_val: value of type ``a``, a type that can be a scalar, array, or any
      pytree (nested Python tuple/list/dict) thereof, representing the initial
      loop carry value.

  Returns:
    The output from the final iteration of body_fun, of type ``a``.
  """
  init_val_flat, in_tree = pytree_to_jaxtupletree(init_val)
  flat_body_fun, out_tree = pytree_fun_to_jaxtupletree_fun(lu.wrap_init(body_fun), (in_tree,))
  flat_cond_fun, _ = pytree_fun_to_jaxtupletree_fun(lu.wrap_init(cond_fun), (in_tree,))

  carry_pval_flat = carry_aval, _ = _abstractify(init_val_flat)
  cond_jaxpr, cond_pval_out, cond_consts = pe.trace_to_jaxpr(flat_cond_fun, (carry_pval_flat,))
  body_jaxpr, body_pval_out, body_consts = pe.trace_to_jaxpr(flat_body_fun, (carry_pval_flat,), instantiate=True)
  carry_aval_out, _ = body_pval_out
  assert isinstance(carry_aval_out, core.AbstractValue)
  assert carry_aval == core.lattice_join(carry_aval, carry_aval_out)

  cond_pv, cond_const = cond_pval_out
  if cond_pv is None:
    # cond_fun evaluates to a constant, so we don't need to generate a while_loop
    if cond_const:
      raise ValueError("infinite loop with no effects")
    else:
      return init_val
  else:
    assert isinstance(cond_pv, core.AbstractValue)
    if (not isinstance(cond_pv, ShapedArray) or cond_pv.shape
        or cond_pv.dtype != onp.bool_):
      msg = "while_loop cond_fun must return a scalar boolean, got {}."
      raise TypeError(msg.format(cond_pv))

  if out_tree() != in_tree:
    raise TypeError("body_fun input and output must have identical structure")
  out_flat = while_p.bind(
      init_val_flat, core.pack(cond_consts), core.pack(body_consts),
      aval_out=carry_aval_out, cond_jaxpr=cond_jaxpr, body_jaxpr=body_jaxpr)
  return build_tree(out_tree(), out_flat)
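
A minimal usage sketch of the public jax.lax.while_loop, which exposes the semantics documented above:

from jax import lax

# Count up to 10: the carry is a scalar; cond_fun reads it, body_fun updates it.
result = lax.while_loop(lambda x: x < 10, lambda x: x + 1, 0)
assert int(result) == 10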
Example #14
def _make_typed_jaxpr(traceable, in_avals):
  pvals = [pe.PartialVal((aval, core.unit)) for aval in in_avals]
  jaxpr, pval_out, consts = pe.trace_to_jaxpr(traceable, pvals, instantiate=True)
  out_aval, _ = pval_out
  assert isinstance(out_aval, core.AbstractValue)
  return core.TypedJaxpr(jaxpr, consts, in_avals, out_aval)
Example #15
def scan(f, init, xs):
  """Scan a function over leading array axes while carrying along state.

  The type signature in brief is

  .. code-block:: haskell

    scan :: (c -> a -> (c, b)) -> c -> [a] -> (c, [b])

  where we use [t] here to denote the type t with an additional leading axis.
  That is, if t is an array type then [t] represents the type with an additional
  leading axis, and if t is a pytree (container) type with array leaves then [t]
  represents the type with the same pytree structure and corresponding leaves
  each with an additional leading axis.

  When both ``a`` and ``b`` are array types, the semantics of ``scan`` are given
  by this Python implementation::

    def scan(f, init, xs):
      carry = init
      ys = []
      for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
      return carry, np.stack(ys)

  Unlike that Python version, both ``a`` and ``b`` may be arbitrary pytree
  types, and so multiple arrays can be scanned over at once and produce multiple
  output arrays.

  Also unlike that Python version, ``scan`` is a JAX primitive and is lowered to
  a single XLA While HLO. That makes it useful for reducing compilation times
  for jit-compiled functions, since native Python loop constructs in an ``@jit``
  function are unrolled, leading to large XLA computations.

  Args:
    f: a Python function to be scanned of type ``c -> a -> (c, b)``, meaning
      that ``f`` accepts two arguments where the first is a value of the loop
      carry and the second is a slice of ``xs`` along its leading axis, and that
      ``f`` returns a pair where the first element represents a new value for
      the loop carry and the second represents a slice of the output.
    init: an initial loop carry value of type ``c``, which can be a scalar,
      array, or any pytree (nested Python tuple/list/dict) thereof, representing
      the initial loop carry value.
    xs: the value of type ``[a]`` over which to scan along the leading axis,
      where ``[a]`` can be an array or any pytree (nested Python
      tuple/list/dict) thereof with consistent leading axis sizes.

  Returns:
    A pair of type ``(c, [b])`` where the first element represents the final
    loop carry value and the second element represents the stacked outputs of
    the second output of ``f`` when scanned over the leading axis of the inputs.
  """
  (init, xs), in_trees = unzip2(map(pytree_to_jaxtupletree, (init, xs)))
  f, out_tree = pytree_fun_to_jaxtupletree_fun(lu.wrap_init(f), in_trees)
  carry_pval = carry_aval, _ = _abstractify(init)
  xs_aval, _ = _abstractify(xs)
  x_aval = _demote_aval_rank(xs_aval)
  x_pval = pe.PartialVal((x_aval, core.unit))
  jaxpr, pval_out, consts = pe.trace_to_jaxpr(
      f, (carry_pval, x_pval), instantiate=True)
  pv_out, const_out = pval_out
  assert isinstance(pv_out, core.AbstractValue) and const_out == core.unit
  if not isinstance(pv_out, core.AbstractTuple) or len(pv_out) != 2:
    msg = ("scanned function must have signature `c -> a -> (c, b)`, but the "
           "output was not a pair: got type {}.")
    raise TypeError(msg.format(pv_out))
  carry_aval_out, y_aval = pv_out
  if carry_aval != carry_aval_out:
    msg = ("scanned function carry output does not match carry input: "
           "input carry is {} and output carry is {}.")
    raise TypeError(msg.format(carry_aval, carry_aval_out))
  lifted_jaxpr = pe._closure_convert_jaxpr(jaxpr)
  consts_aval, _ = _abstractify(core.pack(consts))
  in_avals = (consts_aval, carry_aval, x_aval)
  out_aval = core.AbstractTuple((carry_aval, y_aval))
  jaxpr = core.TypedJaxpr(lifted_jaxpr, (), in_avals, out_aval)
  length = _leading_dim_size(xs)
  out = scan_p.bind(core.pack(consts), init, xs,
                    forward=True, length=length, jaxpr=jaxpr)
  return build_tree(out_tree(), out)
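
A minimal usage sketch of the public jax.lax.scan matching the c -> a -> (c, b) signature documented above; the carry holds a running sum and the per-step outputs are stacked:

import jax.numpy as jnp
from jax import lax

def f(c, x):
    # New carry and per-step output are both the running sum.
    return c + x, c + x

total, cumsum = lax.scan(f, jnp.zeros(()), jnp.arange(1.0, 5.0))
# total == 10.0; cumsum == [1., 3., 6., 10.]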
Example #16
def inv_backward_pass(jaxpr: core.Jaxpr, consts, primals_in, primals_out, cotangents_in):
  if all(type(ct) is ad.Zero for ct in cotangents_in):
    return map(lambda v: ad.Zero(v.aval), jaxpr.invars)

  def write_cotangent(v, ct):
    # assert v not in primal_env
    if ct is not None and type(v) is not Literal:
      ct_env[v] = ad.add_tangents(ct_env[v], ct) if v in ct_env else ct

  def read_cotangent(v):
    return ct_env.get(v, ad.Zero(v.aval))

  def read_primal(v):
    if type(v) is Literal:
      return v.val
    else:
      return primal_env.get(v, ad.UndefinedPrimal(v.aval))

  def write_primal(v, val):
    if type(v) is Literal:
      return
    primal_env.setdefault(v, val)

  # Invert while computing cotangents
  ct_env: Dict[Any, Any] = {}
  primal_env: Dict[Any, Any] = {}
  write_primal(core.unitvar, core.unit)
  map(write_primal, jaxpr.invars, primals_in)
  map(write_primal, jaxpr.outvars, primals_out)
  map(write_primal, jaxpr.constvars, consts)
  map(write_cotangent, jaxpr.outvars, cotangents_in)
  for eqn in jaxpr.eqns[::-1]:
    primals_in = map(read_primal, eqn.invars)
    primals_out = map(read_primal, eqn.outvars)
    cts_in = map(read_cotangent, eqn.outvars)
    should_invert = any(type(primal) is not ad.UndefinedPrimal
                        for primal in primals_out)
    should_vjp = any(type(ct) is not ad.Zero for ct in cts_in)
    assert not eqn.primitive.call_primitive

    # Skip equations that only compute jvp coefficients and don't affect the
    # primal outputs.
    if not should_invert and not should_vjp:
      continue

    def abstract(value):
      return raise_to_shaped(value.aval if ad.is_undefined_primal(value) else get_aval(value))

    # Get the ivjp_jaxpr
    if eqn.primitive is custom_ivjp_p:
      ivjp_jaxpr = eqn.params['ivjp_jaxpr']
    else:
      if eqn.primitive in primitive_ivjps:
        complete_ivjp = lu.wrap_init(primitive_ivjps[eqn.primitive])
      else:
        complete_ivjp = lu.wrap_init(partial(synthesize_ivjp, eqn, map(ad.is_undefined_primal, primals_in)))
      _, in_tree = tree_flatten(
          tuple(map(abstract, x) for x in (primals_in, primals_out, primals_out)))
      complete_ivjp_flat, _ = flatten_fun_nokwargs(complete_ivjp, in_tree)

      in_avals = map(abstract, primals_in + primals_out + primals_out)
      # TODO: Actually we do know some of the inputs, because they might be literals!
      ivjp_jaxpr, out_pvals, _ = pe.trace_to_jaxpr(
          complete_ivjp_flat, map(pe.PartialVal.unknown, in_avals), instantiate=True)
      assert not ivjp_jaxpr.constvars  # That might happen some time, but don't bother until then
      ivjp_jaxpr = core.ClosedJaxpr(ivjp_jaxpr, [])

    # Once we know what the ivjp can do exactly, we have to isolate the part we are
    # actually able to compute with the values we have at hand.
    num_inputs = len(eqn.invars)
    unknowns = (map(ad.is_undefined_primal, primals_in) +
                map(ad.is_undefined_primal, primals_out) +
                [False] * len(cts_in))
    jaxpr_known, jaxpr_unknown, out_unknowns = pe.partial_eval_jaxpr(  # type: ignore
        ivjp_jaxpr, unknowns, instantiate=False)  # type:ignore
    unknown_rec_primals_in, unknown_cotangents = split_list(out_unknowns, [num_inputs])
    # Make sure we're able to compute all cotangents. We don't really care
    # whether we can reconstruct the primals or not, although failure to do so
    # might result in failing to compute cotangents later.
    assert not any(unknown_cotangents)
    # Remove residual outputs -- we won't be computing the unknown jaxpr anyway.
    num_outputs = len(jaxpr_unknown.jaxpr.outvars)
    jaxpr_known.jaxpr.outvars = jaxpr_known.jaxpr.outvars[:num_outputs]
    # TODO: We could drop the outputs that correspond to primals that we already know.
    #       This only matters in eager mode, so leaving it out for now...
    ivjp = core.jaxpr_as_fun(jaxpr_known)
    rec_primals_in, cts_out = split_list(ivjp(*primals_in, *primals_out, *cts_in),
                                         [num_inputs])
    # Unknown rec_primals_in are core.units, so we have to replace them
    # with UndefinedPrimals because that's what write_primal will ignore.
    rec_primals_in = [prev if unknown else rec
                      for prev, rec, unknown
                      in zip(primals_in, rec_primals_in, unknown_rec_primals_in)]
    map(write_primal, eqn.invars, rec_primals_in)
    map(write_cotangent, [v for v in eqn.invars if type(v) is not Literal], cts_out)

  # NOTE: We keep the cotangents associated with primal variables, while the contract of a
  #       transpose is to return them in positions associated with tangent variables, which
  #       is what causes this whole confusion.
  return map(read_cotangent, jaxpr.invars)
Example #17
 def DISABLED_test_print_jaxpr_compound(self):
     # TODO(dougalm): figure out what jaxpr-tracing api to expose and re-enable
     pv = pe.PartialVal((ShapedArray((2, 3), onp.float32), core.unit))
     print(pe.trace_to_jaxpr(fun_with_call_closure, (pv, ))[0])
Example #18
def while_loop(cond_fun, body_fun, init_val):
    """Call ``body_fun`` repeatedly in a loop while ``cond_fun`` is True.

  The type signature in brief is

  .. code-block:: haskell

    while_loop :: (a -> Bool) -> (a -> a) -> a -> a

  The semantics of ``while_loop`` are given by this Python implementation::

    def while_loop(cond_fun, body_fun, init_val):
      val = init_val
      while cond_fun(val):
        val = body_fun(val)
      return val

  Unlike that Python version, ``while_loop`` is a JAX primitive and is lowered
  to a single XLA While HLO. That makes it useful for reducing compilation times
  for jit-compiled functions, since native Python loop constructs in an ``@jit``
  function are unrolled, leading to large XLA computations.

  Another difference from using Python-native loop constructs is that
  ``while_loop`` is not reverse-mode differentiable because XLA computations
  require static bounds on memory requirements.

  Args:
    cond_fun: function of type ``a -> Bool``.
    body_fun: function of type ``a -> a``.
    init_val: value of type ``a``, a type that can be a scalar, array, or any
      pytree (nested Python tuple/list/dict) thereof, representing the initial
      loop carry value.

  Returns:
    The output from the final iteration of body_fun, of type ``a``.
  """
    init_val_flat, in_tree = pytree_to_jaxtupletree(init_val)
    flat_body_fun, out_tree = pytree_fun_to_jaxtupletree_fun(
        lu.wrap_init(body_fun), (in_tree, ))
    flat_cond_fun, _ = pytree_fun_to_jaxtupletree_fun(lu.wrap_init(cond_fun),
                                                      (in_tree, ))

    carry_pval_flat = carry_aval, _ = _abstractify(init_val_flat)
    cond_jaxpr, cond_pval_out, cond_consts = pe.trace_to_jaxpr(
        flat_cond_fun, (carry_pval_flat, ))
    body_jaxpr, body_pval_out, body_consts = pe.trace_to_jaxpr(
        flat_body_fun, (carry_pval_flat, ), instantiate=True)
    carry_aval_out, _ = body_pval_out
    assert isinstance(carry_aval_out, core.AbstractValue)
    assert carry_aval == core.lattice_join(carry_aval, carry_aval_out)

    cond_pv, cond_const = cond_pval_out
    if cond_pv is None:
        # cond_fun evaluates to a constant, so we don't need to generate a while_loop
        if cond_const:
            raise ValueError("infinite loop with no effects")
        else:
            return init_val
    else:
        assert isinstance(cond_pv, core.AbstractValue)
        if (not isinstance(cond_pv, ShapedArray) or cond_pv.shape
                or cond_pv.dtype != onp.bool_):
            msg = "while_loop cond_fun must return a scalar boolean, got {}."
            raise TypeError(msg.format(cond_pv))

    # We don't want to promote literal constants as loop arguments; there are
    # sometimes many of them. We pass tracers as loop arguments, but leave
    # nontracers as constants. We also sort the constants so the nontracers are
    # first.
    def split_tracers_and_nontracers(jaxpr, consts):
        tracer = []
        nontracer = []
        for x in zip(jaxpr.constvars, consts):
            # TODO(phawkins): We avoid treating DeviceArrays as constant literals so
            # we don't copy large arrays back to the host. We probably should relax
            # this and either always copy small constants, or opportunistically use
            # DeviceArray values for which we already know npy_value.
            not_literal_const = isinstance(x[1],
                                           (core.Tracer, xla.DeviceArray))
            (tracer if not_literal_const else nontracer).append(x)
        tracer_vars, tracer_consts = unzip2(tracer)
        nontracer_vars, nontracer_consts = unzip2(nontracer)
        return nontracer_vars + tracer_vars, nontracer_consts, tracer_consts

    cond_split = split_tracers_and_nontracers(cond_jaxpr, cond_consts)
    cond_jaxpr.constvars, cond_nontracer_consts, cond_tracer_consts = cond_split
    body_split = split_tracers_and_nontracers(body_jaxpr, body_consts)
    body_jaxpr.constvars, body_nontracer_consts, body_tracer_consts = body_split

    if out_tree() != in_tree:
        raise TypeError(
            "body_fun input and output must have identical structure")
    out_flat = while_p.bind(
        init_val_flat,
        core.pack(cond_tracer_consts),
        core.pack(body_tracer_consts),
        cond_consts=lax._OpaqueParam(cond_nontracer_consts),
        body_consts=lax._OpaqueParam(body_nontracer_consts),
        aval_out=carry_aval_out,
        cond_jaxpr=cond_jaxpr,
        body_jaxpr=body_jaxpr)
    return build_tree(out_tree(), out_flat)
Example #19
    def inner(scope_fn, repack_fn, variable_groups_xs, rng_groups_xs):
        # split rngs
        split_fn = lambda rng: random.split(rng, length)
        broadcast_rngs_xs = []
        scan_rngs_xs = []
        for rng_groups in rng_groups_xs:
            broadcast_rngs_xs.append(
                tuple(rng_group
                      for rng_group, split in zip(rng_groups, rng_splits)
                      if not split))
            scan_rngs_xs.append(
                tuple(
                    jax.tree_map(split_fn, rng_group)
                    for rng_group, split in zip(rng_groups, rng_splits)
                    if split))

        def body(carry, xs, init_mode=False):
            carry_vars_xs, c = carry
            scan_vars_xs, scan_rngs_xs, x = xs
            variable_groups_xs = combine(scan_vars_xs, carry_vars_xs,
                                         broadcast_vars_xs)
            rng_groups_xs = []
            for broadcast_rngs, scan_rngs in zip(broadcast_rngs_xs,
                                                 scan_rngs_xs):
                rng_groups_xs.append(broadcast_rngs + scan_rngs)
            scope = scope_fn(variable_groups_xs, rng_groups_xs)
            carry, y = fn(scope, c, x)
            out_vars = repack_fn(scope)
            scan_vars_xs, carry_vars_out_xs, broadcast_vars_out_xs = split(
                out_vars, 1)

            # TODO(jheek) more informative error check
            def check_shapes(c_in, c_out):
                if not isinstance(c_in, jnp.ndarray) or not isinstance(
                        c_out, jnp.ndarray):
                    return
                if jnp.shape(c_in) != jnp.shape(c_out) or jnp.dtype(
                        c_in) != jnp.dtype(c_out):
                    raise ValueError()

            try:
                jax.tree_multimap(check_shapes, carry_vars_xs,
                                  carry_vars_out_xs)
            except ValueError:
                raise ValueError(
                    'carry variables must have the same shape and dtype before and after scan.'
                )

            if init_mode:
                return broadcast_vars_out_xs
            else:
                return (carry_vars_out_xs, carry), (scan_vars_xs, y)

        broadcast_body = functools.partial(body, init_mode=True)

        scan_vars_xs, carry_vars_xs, broadcast_vars_xs = split(
            variable_groups_xs, 0)
        carry0 = (carry_vars_xs, init_carry)
        xxs = (scan_vars_xs, scan_rngs_xs, xs)

        # use partial evaluation to find the variables that are broadcasted out
        # an error is thrown if a broadcasted output has a dependency on any scan variables
        carry_pvals = jax.tree_map(
            lambda x: pe.PartialVal.unknown(jax.ShapedArray(x.shape, x.dtype)),
            carry0)
        scan_pvals = jax.tree_map(
            lambda x: pe.PartialVal.unknown(
                jax.ShapedArray(x.shape[1:], x.dtype)), xxs)
        input_pvals = (carry_pvals, scan_pvals)
        in_pvals, in_tree = jax.tree_flatten(input_pvals)
        f_flat, out_tree = jax.api_util.flatten_fun_nokwargs(
            lu.wrap_init(broadcast_body), in_tree)

        _, out_pvals, _ = pe.trace_to_jaxpr(f_flat, in_pvals)
        # _, out_pvals, _ = pe.trace_to_jaxpr(f_flat, in_pvals, stage_out=True)

        out_flat = []
        for pv, const in out_pvals:
            if pv is not None:
                raise ValueError(
                    'broadcasted variable has a data dependency on the scan body.'
                )
            out_flat.append(const)

        (carry_vars_xs, carry), (scan_vars_xs, ys) = lax.scan(body,
                                                              carry0,
                                                              xxs,
                                                              length=length,
                                                              reverse=reverse)

        broadcast_vars_xs = jax.tree_unflatten(out_tree(), out_flat)

        out_vars_xs = combine(carry_vars_xs, scan_vars_xs, broadcast_vars_xs)
        return (carry, ys), out_vars_xs
Example #20
File: core.py Project: juliuskunze/jaxnet
def _instantiated_trace_to_jaxpr(fun, avals):
    pvals = map(lambda aval: pe.PartialVal((aval, unit)), avals)
    jaxpr, out_pvals, consts = pe.trace_to_jaxpr(fun, pvals, instantiate=True)
    out_avals, _ = unzip2(out_pvals)
    return jaxpr, out_avals, consts
Example #21
def while_loop(cond_fun, body_fun, init_val):
    """Call `body_fun` repeatedly in a loop while `cond_fun` is True.

  Arguments:
    cond_fun: pure function of type `T -> Bool`.
    body_fun: pure function of type `T -> T`.
    init_val: value of type `T`, a type that can be a scalar, array, or any
      (nested) Python tuple/list/dict thereof.

  Returns:
    The output from the final iteration of body_fun, of type `T`.

  The semantics of `while_loop` are given by this Python implementation::

    def while_loop(cond_fun, body_fun, init_val):
      val = init_val
      while cond_fun(val):
        val = body_fun(val)
      return val

  Unlike that pure Python version, `while_loop` is a JAX primitive and is
  lowered to a single XLA While HLO. That makes it useful for reducing
  compilation times for jit-compiled functions, since native Python loop
  constructs in an `@jit` function are unrolled, leading to large XLA
  computations.

  Another difference from using Python-native loop constructs is that
  `while_loop` is not (yet) reverse-mode differentiable because XLA computations
  require static bounds on memory requirements.
  """
    init_val_flat, in_tree = pytree_to_jaxtupletree(init_val)
    flat_body_fun, out_tree = pytree_fun_to_jaxtupletree_fun(
        lu.wrap_init(body_fun), (in_tree, ))
    flat_cond_fun, _ = pytree_fun_to_jaxtupletree_fun(lu.wrap_init(cond_fun),
                                                      (in_tree, ))

    pval_flat = lax._abstractify(init_val_flat)
    cond_jaxpr, _, cond_consts = pe.trace_to_jaxpr(flat_cond_fun,
                                                   (pval_flat, ))
    body_jaxpr, pval_out, body_consts = pe.trace_to_jaxpr(
        flat_body_fun, (pval_flat, ))
    aval_out, _ = pval_out

    # We don't want to promote literal constants as loop arguments; there are
    # sometimes many of them. We pass tracers as loop arguments, but leave
    # nontracers as constants. We also sort the constants so the nontracers are
    # first.
    def split_tracers_and_nontracers(jaxpr, consts):
        tracer = []
        nontracer = []
        for x in zip(jaxpr.constvars, consts):
            # TODO(phawkins): We avoid treating DeviceArrays as constant literals so
            # we don't copy large arrays back to the host. We probably should relax
            # this and either always copy small constants, or opportunistically use
            # DeviceArray values for which we already know npy_value.
            not_literal_const = isinstance(x[1],
                                           (core.Tracer, xla.DeviceArray))
            (tracer if not_literal_const else nontracer).append(x)
        tracer_vars, tracer_consts = unzip2(tracer)
        nontracer_vars, nontracer_consts = unzip2(nontracer)
        return nontracer_vars + tracer_vars, nontracer_consts, tracer_consts

    cond_split = split_tracers_and_nontracers(cond_jaxpr, cond_consts)
    cond_jaxpr.constvars, cond_nontracer_consts, cond_tracer_consts = cond_split
    body_split = split_tracers_and_nontracers(body_jaxpr, body_consts)
    body_jaxpr.constvars, body_nontracer_consts, body_tracer_consts = body_split

    if out_tree() != in_tree:
        raise TypeError(
            "body_fun input and output must have identical structure")
    out_flat = while_p.bind(
        init_val_flat,
        core.pack(cond_tracer_consts),
        core.pack(body_tracer_consts),
        cond_consts=lax._OpaqueParam(cond_nontracer_consts),
        body_consts=lax._OpaqueParam(body_nontracer_consts),
        aval_out=aval_out,
        cond_jaxpr=cond_jaxpr,
        body_jaxpr=body_jaxpr)
    return build_tree(out_tree(), out_flat)
Example #22
 def trace_jaxpr(fun, operand):
   op_flat, in_tree = pytree_to_flatjaxtuple(operand)
   fun_flat, out_tree = pytree_fun_to_flatjaxtuple_fun(lu.wrap_init(fun), (in_tree,))
   jaxpr, pvout, consts = pe.trace_to_jaxpr(fun_flat, (_abstractify(op_flat),))
   return op_flat, jaxpr, consts, pvout, out_tree
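
The internal pe.trace_to_jaxpr signatures used throughout these examples have changed across JAX releases. The stable public entry point for obtaining a jaxpr is jax.make_jaxpr; a minimal sketch:

import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) ** 2

# make_jaxpr traces f with abstract values and returns a ClosedJaxpr.
closed_jaxpr = jax.make_jaxpr(f)(jnp.ones(3))
print(closed_jaxpr)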