Example No. 1
def scan(f, init, xs):
    """Scan a function over leading array axes while carrying along state.

  The type signature in brief is

  .. code-block:: haskell

    scan :: (c -> a -> (c, b)) -> c -> [a] -> (c, [b])

  where we use [t] here to denote the type t with an additional leading axis.
  That is, if t is an array type then [t] represents the type with an additional
  leading axis, and if t is a pytree (container) type with array leaves then [t]
  represents the type with the same pytree structure and corresponding leaves
  each with an additional leading axis.

  When both ``a`` and ``b`` are array types, the semantics of ``scan`` are given
  by this Python implementation::

    def scan(f, init, xs):
      carry = init
      ys = []
      for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
      return carry, np.stack(ys)

  Unlike that Python version, both ``a`` and ``b`` may be arbitrary pytree
  types, and so multiple arrays can be scanned over at once and produce multiple
  output arrays.

  Also unlike that Python version, ``scan`` is a JAX primitive and is lowered to
  a single XLA While HLO. That makes it useful for reducing compilation times
  for jit-compiled functions, since native Python loop constructs in an ``@jit``
  function are unrolled, leading to large XLA computations.

  Args:
    f: a Python function to be scanned of type ``c -> a -> (c, b)``, meaning
      that ``f`` accepts two arguments where the first is a value of the loop
      carry and the second is a slice of ``xs`` along its leading axis, and that
      ``f`` returns a pair where the first element represents a new value for
      the loop carry and the second represents a slice of the output.
    init: an initial loop carry value of type ``c``, which can be a scalar,
      array, or any pytree (nested Python tuple/list/dict) thereof.
    xs: the value of type ``[a]`` over which to scan along the leading axis,
      where ``[a]`` can be an array or any pytree (nested Python
      tuple/list/dict) thereof with consistent leading axis sizes.

  Returns:
    A pair of type ``(c, [b])`` where the first element represents the final
    loop carry value and the second element represents the stacked outputs of
    the second output of ``f`` when scanned over the leading axis of the inputs.
  """
    (init, xs), in_trees = unzip2(map(pytree_to_jaxtupletree, (init, xs)))
    f, out_tree = pytree_fun_to_jaxtupletree_fun(lu.wrap_init(f), in_trees)
    carry_pval = carry_aval, _ = _abstractify(init)
    xs_aval, _ = _abstractify(xs)
    x_aval = _demote_aval_rank(xs_aval)
    x_pval = pe.PartialVal((x_aval, core.unit))
    jaxpr, pval_out, consts = pe.trace_to_jaxpr(f, (carry_pval, x_pval),
                                                instantiate=True)
    pv_out, const_out = pval_out
    assert isinstance(pv_out, core.AbstractValue) and const_out == core.unit
    if not isinstance(pv_out, core.AbstractTuple) or len(pv_out) != 2:
        msg = (
            "scanned function must have signature `c -> a -> (c, b)`, but the "
            "output was not a pair: got type {}.")
        raise TypeError(msg.format(pv_out))
    carry_aval_out, y_aval = pv_out
    if carry_aval != carry_aval_out:
        msg = ("scanned function carry output does not match carry input: "
               "input carry is {} and output carry is {}.")
        raise TypeError(msg.format(carry_aval, carry_aval_out))
    lifted_jaxpr = pe._closure_convert_jaxpr(jaxpr)
    consts_aval, _ = _abstractify(core.pack(consts))
    in_avals = (consts_aval, carry_aval, x_aval)
    out_aval = core.AbstractTuple((carry_aval, y_aval))
    jaxpr = core.TypedJaxpr(lifted_jaxpr, (), in_avals, out_aval)
    length = _leading_dim_size(xs)
    out = scan_p.bind(core.pack(consts),
                      init,
                      xs,
                      forward=True,
                      length=length,
                      jaxpr=jaxpr)
    return build_tree(out_tree(), out)
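As a usage sketch, a cumulative sum carries the running total and also emits it at each step (assuming the public ``jax.lax.scan`` entry point that this function backs):

import jax.numpy as jnp
from jax import lax

def cumsum_step(carry, x):
    new_carry = carry + x        # c -> a -> (c, b): update the running total
    return new_carry, new_carry  # emit the total as this step's output

final, partials = lax.scan(cumsum_step, 0.0, jnp.arange(1.0, 5.0))
# final    -> 10.0
# partials -> [ 1.  3.  6. 10.]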
Example No. 2
def propagate(cell_type: Type[Cell], rules: Dict[jax_core.Primitive,
                                                 PropagationRule],
              jaxpr: pe.Jaxpr, constcells: List[Cell], incells: List[Cell],
              outcells: List[Cell]) -> Environment:
    """Propagates cells in a Jaxpr using a set of rules.

  Args:
    cell_type: used to instantiate literals into cells
    rules: maps JAX primitives to propagation rule functions
    jaxpr: used to construct the propagation graph
    constcells: used to populate the Jaxpr's constvars
    incells: used to populate the Jaxpr's invars
    outcells: used to populate the Jaxpr's outvars
  Returns:
    The Jaxpr environment after propagation has terminated
  """
    env = Environment(cell_type, jaxpr)

    safe_map(env.write, jaxpr.constvars, constcells)
    safe_map(env.write, jaxpr.invars, incells)
    safe_map(env.write, jaxpr.outvars, outcells)

    eqns = safe_map(Equation.from_jaxpr_eqn, jaxpr.eqns)

    get_neighbor_eqns = construct_graph_representation(eqns)
    # Initialize propagation queue with equations neighboring constvars, invars,
    # and outvars.
    out_eqns = set()
    for var in it.chain(jaxpr.outvars, jaxpr.invars, jaxpr.constvars):
        out_eqns.update(get_neighbor_eqns(var))
    queue = collections.deque(out_eqns)
    done_eqns = set()
    # checked_eqns is used to stop propagation if all equations in queue have
    # been checked without the propagation progressing
    checked_eqns = set()
    while queue:
        eqn = queue.popleft()
        assert eqn not in done_eqns

        incells = safe_map(env.read, eqn.invars)
        outcells = safe_map(env.read, eqn.outvars)

        rule = rules[eqn.primitive]
        call_jaxpr, params = jax_core.extract_call_jaxpr(
            eqn.primitive, eqn.params)
        if call_jaxpr:
            subfuns = [
                lu.wrap_init(
                    functools.partial(propagate, cell_type, rules, call_jaxpr,
                                      ()))
            ]
        else:
            subfuns = []
        new_incells, new_outcells, done, subenv = rule(subfuns + incells,
                                                       outcells, **params)
        if subenv:
            env.write_subenv(eqn, subenv)

        safe_map(env.write, eqn.invars, new_incells)
        safe_map(env.write, eqn.outvars, new_outcells)

        update_queue_state(queue, eqn, get_neighbor_eqns, done, done_eqns,
                           checked_eqns, incells, outcells, new_incells,
                           new_outcells)
        if checked_eqns == set(queue):
            break
    return env
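For reference, a minimal sketch of the rule contract this loop expects. The ``UNKNOWN`` sentinel and the ``exp``/``log`` pairing are hypothetical illustrations, not the library's API:

import jax.numpy as jnp

UNKNOWN = None  # hypothetical stand-in for a cell carrying no information yet

def exp_rule(incells, outcells):
    # Bidirectional rule for y = exp(x): fill in whichever side is unknown.
    (x,), (y,) = incells, outcells
    if x is not UNKNOWN and y is UNKNOWN:
        y = jnp.exp(x)   # forward propagation
    elif y is not UNKNOWN and x is UNKNOWN:
        x = jnp.log(y)   # backward propagation
    done = x is not UNKNOWN and y is not UNKNOWN
    return [x], [y], done, None  # (new_incells, new_outcells, done, subenv)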
Example No. 3
def jet(fun, primals, series):
    r"""Taylor-mode higher-order automatic differentiation.

  Args:
    fun: Function to be differentiated. Its arguments should be arrays, scalars,
      or standard Python containers of arrays or scalars. It should return an
      array, scalar, or standard Python container of arrays or scalars.
    primals: The primal values at which the Taylor approximation of ``fun`` should be
      evaluated. Should be either a tuple or a list of arguments,
      and its length should be equal to the number of positional parameters of
      ``fun``.
    series: Higher-order Taylor series coefficients.
      Together, ``primals`` and ``series`` make up a truncated Taylor polynomial.
      Should be either a tuple or a list of tuples or lists,
      and its length dictates the degree of the truncated Taylor polynomial.

  Returns:
    A ``(primals_out, series_out)`` pair, where ``primals_out`` is ``fun(*primals)``,
    and together, ``primals_out`` and ``series_out`` are a
    truncated Taylor polynomial of :math:`f(h(\cdot))`.
    The ``primals_out`` value has the same Python tree structure as ``primals``,
    and the ``series_out`` value the same Python tree structure as ``series``.

  For example:

  >>> import jax
  >>> import jax.numpy as np

  Consider the function :math:`h(z) = z^3`, :math:`x = 0.5`,
  and the first few Taylor coefficients
  :math:`h_0=x^3`, :math:`h_1=3x^2`, and :math:`h_2=6x`.
  Let :math:`f(y) = \sin(y)`.

  >>> h0, h1, h2 = 0.5**3., 3.*0.5**2., 6.*0.5
  >>> f, df, ddf = np.sin, np.cos, lambda *args: -np.sin(*args)

  :func:`jet` returns the Taylor coefficients of :math:`f(h(z)) = \sin(z^3)`
  according to Faà di Bruno's formula:

  >>> f0, (f1, f2) = jet(f, (h0,), ((h1, h2),))
  >>> print(f0,  f(h0))
  0.12467473 0.12467473

  >>> print(f1, df(h0) * h1)
  0.7441479 0.74414825

  >>> print(f2, ddf(h0) * h1 ** 2 + df(h0) * h2)
  2.9064622 2.9064634
  """
    try:
        order, = set(map(len, series))
    except ValueError:
        msg = "jet terms have inconsistent lengths for different arguments"
        raise ValueError(msg) from None

    # TODO(mattjj): consider supporting pytree inputs
    for i, (x, terms) in enumerate(zip(primals, series)):
        treedef = tree_structure(x)
        if not treedef_is_leaf(treedef):
            raise ValueError(
                "primal value at position {} is not an array".format(i))
        for j, t in enumerate(terms):
            treedef = tree_structure(t)
            if not treedef_is_leaf(treedef):
                raise ValueError(
                    "term {} for argument {} is not an array".format(j, i))

    @lu.transformation_with_aux
    def flatten_fun_output(*args):
        ans = yield args, {}
        yield tree_flatten(ans)

    f, out_tree = flatten_fun_output(lu.wrap_init(fun))
    out_primals, out_terms = jet_fun(jet_subtrace(f),
                                     order).call_wrapped(primals, series)
    return tree_unflatten(out_tree(),
                          out_primals), tree_unflatten(out_tree(), out_terms)
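Seeding ``series`` with the coefficients of the identity map recovers plain derivatives, using the same derivative convention as the doctest above. A short sketch, assuming the ``jax.experimental.jet`` module path:

import jax.numpy as jnp
from jax.experimental.jet import jet

x = 0.5
# h(z) = x + z has h1 = 1 and all higher coefficients zero, so by Faa di
# Bruno's formula the output series is (f'(x), f''(x), f'''(x)).
primal_out, series_out = jet(jnp.sin, (x,), ((1.0, 0.0, 0.0),))
# primal_out ~= sin(0.5); series_out ~= [cos(0.5), -sin(0.5), -cos(0.5)]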
Example No. 4
def propagate(cell_type: Type[Cell],
              rules: Dict[jax_core.Primitive, PropagationRule],
              jaxpr: pe.Jaxpr,
              constcells: List[Cell],
              incells: List[Cell],
              outcells: List[Cell],
              reducer: Callable[[Environment, Equation, State, State],
                                State] = identity_reducer,
              initial_state: State = None) -> Tuple[Environment, State]:
    """Propagates cells in a Jaxpr using a set of rules.

  Args:
    cell_type: used to instantiate literals into cells
    rules: maps JAX primitives to propagation rule functions
    jaxpr: used to construct the propagation graph
    constcells: used to populate the Jaxpr's constvars
    incells: used to populate the Jaxpr's invars
    outcells: used to populate the Jaxpr's outvars
    reducer: An optional callable used to reduce over the state at each
      equation in the Jaxpr. `reducer` takes in `(env, eqn, state, new_state)`
      as arguments and should return an updated state. The `new_state` value
      is provided by each equation.
    initial_state: The initial `state` value used in the reducer
  Returns:
    A ``(env, state)`` pair: the Jaxpr environment after propagation has
    terminated, together with the final reduced state.
  """
    env = Environment(cell_type, jaxpr)
    safe_map(env.write, jaxpr.constvars, constcells)
    safe_map(env.write, jaxpr.invars, incells)
    safe_map(env.write, jaxpr.outvars, outcells)

    eqns = safe_map(Equation.from_jaxpr_eqn, jaxpr.eqns)

    get_neighbor_eqns = construct_graph_representation(eqns)
    # Initialize all equation variables to unknown cells.
    out_eqns = set()
    for eqn in jaxpr.eqns:
        for var in it.chain(eqn.invars, eqn.outvars):
            env.write(var, cell_type.unknown(var.aval))

    # Initialize the propagation queue with equations neighboring constvars,
    # invars, and outvars.
    for var in it.chain(jaxpr.outvars, jaxpr.invars, jaxpr.constvars):
        out_eqns.update(get_neighbor_eqns(var))
    queue = collections.deque(out_eqns)
    while queue:
        eqn = queue.popleft()

        incells = safe_map(env.read, eqn.invars)
        outcells = safe_map(env.read, eqn.outvars)

        call_jaxpr, params = jax_core.extract_call_jaxpr(
            eqn.primitive, eqn.params)
        if call_jaxpr:
            subfuns = [
                lu.wrap_init(
                    functools.partial(propagate,
                                      cell_type,
                                      rules,
                                      call_jaxpr, (),
                                      initial_state=initial_state,
                                      reducer=reducer))
            ]
            if eqn.primitive not in rules:
                rule = default_call_rules.get(eqn.primitive)
            else:
                rule = rules[eqn.primitive]
        else:
            subfuns = []
            rule = rules[eqn.primitive]
        new_incells, new_outcells, eqn_state = rule(subfuns + incells,
                                                    outcells, **params)
        env.write_state(eqn, eqn_state)

        new_incells = safe_map(env.write, eqn.invars, new_incells)
        new_outcells = safe_map(env.write, eqn.outvars, new_outcells)

        update_queue_state(queue, eqn, get_neighbor_eqns, incells, outcells,
                           new_incells, new_outcells)
    state = initial_state
    for eqn in eqns:
        state = reducer(env, eqn, state, env.read_state(eqn))
    return env, state
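A minimal sketch of the ``reducer`` contract; the state-counting reducer here is a hypothetical example, not library code:

# Matches the (env, eqn, state, new_state) -> state contract: count how many
# equations recorded a non-trivial state during propagation.
def count_states(env, eqn, state, new_state):
    return state + (1 if new_state is not None else 0)

# env, num_states = propagate(cell_type, rules, jaxpr, constcells, incells,
#                             outcells, reducer=count_states, initial_state=0)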
Example No. 5
def linear_call(fun: Callable, fun_transpose: Callable, residual_args,
                linear_args):
    """Call a linear function, with a custom implementation for its transpose.

  The type signatures of ``fun`` and ``fun_transpose`` are:

  .. code-block:: haskell

    fun           :: r -> a -o b
    fun_transpose :: r -> b -o a

  where the ``-o`` arrow indicates a linear function, ``r`` is the
  residual input type and ``a`` is the linear input type.

  The functions ``fun`` and ``fun_transpose`` are coupled as
  transposes of one another. Specifically, the transpose of a
  ``linear_call`` primitive is another ``linear_call`` to
  ``fun_transpose``, with ``fun`` as its custom transposition.

  For example:

  >>> import jax
  >>> from functools import partial

  >>> def f(r, x):
  ...   return x / r

  >>> def t(r, t):
  ...   return t / r

  >>> def div_add(x, denom):
  ...   return x + linear_call(f, t, denom, x)

  >>> def transpose(f, x_example):
  ...   def transposed(y):
  ...     x, = jax.linear_transpose(f, x_example)(y)
  ...     return x
  ...   return transposed

  >>> div_add(9., 3.)
  DeviceArray(12., dtype=float32)

  >>> transpose(partial(div_add, denom=3.), 1.)(18.)  # custom
  DeviceArray(24., dtype=float32)

  >>> transpose(lambda x: x + x / 3., 1.)(18.)  # reference
  DeviceArray(24., dtype=float32)

  The above definition of ``f`` illustrates the purpose of a residual
  argument: division is linear in one of its inputs (the dividend
  ``x``) but not the other (the divisor ``r``).

  As another example:

  >>> def custom_id(x):
  ...   def f(_, x): return x
  ...   def t(_, t): return 7.
  ...   return linear_call(f, t, (), x)
  >>> custom_id(1.)
  1.0
  >>> transpose(custom_id, 1.)(1.)
  7.0
  >>> transpose(transpose(custom_id, 1.), 1.)(1.)
  1.0
  >>> transpose(transpose(transpose(custom_id, 1.), 1.), 1.)(1.)
  7.0

  Args:
    fun: a Python callable specifying a linear function. It should
      take two arguments: one of "residual" inputs (type ``r``),
      i.e. inputs in which the function is not necessarily linear, and
      one of "linear" inputs (type ``a``).  It should return output
      whose components are linear in the linear input (type ``b``).
    fun_transpose: a Python callable specifying a structurally linear
      function that is the transpose of ``fun`` with respect to its
      linear inputs. Its first argument is the same residual inputs
      (``r``) as ``fun``. Its second argument is of type
      ``b``. Finally, its output is of type ``a`` and each of its
      components is linear in its second argument (the ``b`` inputs).
    residual_args: Argument in which ``fun`` and ``fun_transpose`` are
      not necessarily linear. Not involved in transposition.
    linear_args: Argument in which ``fun`` and ``fun_transpose`` are
      linear and with respect to which the two are transposes.

  Returns:
    The call result, i.e. ``fun(residual_args, linear_args)``.

  """
    operands_res, res_tree = tree_flatten(residual_args)
    operands_lin, lin_tree = tree_flatten(linear_args)

    f_in_tree = treedef_tuple((res_tree, lin_tree))
    f, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), f_in_tree)

    res_avals = map(abstractify, operands_res)
    lin_avals = map(abstractify, operands_lin)
    f_jaxpr, f_consts = _initial_style_jaxpr(f, (*res_avals, *lin_avals))
    f_jaxpr = _close_jaxpr(f_jaxpr)
    out_avals = map(core.raise_to_shaped, f_jaxpr.out_avals)

    t_in_tree = treedef_tuple((res_tree, out_tree()))
    t, t_out_tree = flatten_fun_nokwargs(lu.wrap_init(fun_transpose),
                                         t_in_tree)

    t_jaxpr, t_consts = _initial_style_jaxpr(t, (*res_avals, *out_avals))
    t_jaxpr = _close_jaxpr(t_jaxpr)

    if t_out_tree() != lin_tree:
        raise TypeError(
            'transpose output pytree structure must match that of linear inputs, '
            f'got output structure {t_out_tree()} '
            f'and input structure {lin_tree}.')

    out = linear_call_p.bind(*f_consts,
                             *t_consts,
                             *operands_res,
                             *operands_lin,
                             callee=f_jaxpr,
                             transpose=t_jaxpr,
                             num_callee_consts=len(f_consts),
                             num_transpose_consts=len(t_consts),
                             num_res=len(operands_res))

    return tree_unflatten(out_tree(), out)
Example No. 6
 def doit():
     f = lu.wrap_init(fun)
     args_flat, in_tree = tree_util.tree_flatten((args, {}))
     flat_fun, out_tree = flatten_fun(f, in_tree)
     out_flat = _interpret_fun(flat_fun, args_flat)
     return tree_util.tree_unflatten(out_tree(), out_flat)
Example No. 7
def inv_backward_pass(jaxpr: core.Jaxpr, consts, primals_in, primals_out, cotangents_in):
  if all(type(ct) is ad.Zero for ct in cotangents_in):
    return map(lambda v: ad.Zero(v.aval), jaxpr.invars)

  def write_cotangent(v, ct):
    # assert v not in primal_env
    if ct is not None and type(v) is not Literal:
      ct_env[v] = ad.add_tangents(ct_env[v], ct) if v in ct_env else ct

  def read_cotangent(v):
    return ct_env.get(v, ad.Zero(v.aval))

  def read_primal(v):
    if type(v) is Literal:
      return v.val
    else:
      return primal_env.get(v, ad.UndefinedPrimal(v.aval))

  def write_primal(v, val):
    if type(v) is Literal:
      return
    primal_env.setdefault(v, val)

  # Invert while computing cotangents
  ct_env: Dict[Any, Any] = {}
  primal_env: Dict[Any, Any] = {}
  write_primal(core.unitvar, core.unit)
  map(write_primal, jaxpr.invars, primals_in)
  map(write_primal, jaxpr.outvars, primals_out)
  map(write_primal, jaxpr.constvars, consts)
  map(write_cotangent, jaxpr.outvars, cotangents_in)
  for eqn in jaxpr.eqns[::-1]:
    primals_in = map(read_primal, eqn.invars)
    primals_out = map(read_primal, eqn.outvars)
    cts_in = map(read_cotangent, eqn.outvars)
    should_invert = any(type(primal) is not ad.UndefinedPrimal
                        for primal in primals_out)
    should_vjp = any(type(ct) is not ad.Zero for ct in cts_in)
    assert not eqn.primitive.call_primitive

    # Skip primal equations that are only jvp coefficients and don't affect
    # primal outputs.
    if not should_invert and not should_vjp:
      continue

    def abstract(value):
      return raise_to_shaped(value.aval if ad.is_undefined_primal(value) else get_aval(value))

    # Get the ivjp_jaxpr
    if eqn.primitive is custom_ivjp_p:
      ivjp_jaxpr = eqn.params['ivjp_jaxpr']
    else:
      if eqn.primitive in primitive_ivjps:
        complete_ivjp = lu.wrap_init(primitive_ivjps[eqn.primitive])
      else:
        complete_ivjp = lu.wrap_init(partial(synthesize_ivjp, eqn, map(ad.is_undefined_primal, primals_in)))
      _, in_tree = tree_flatten(
          tuple(map(abstract, x) for x in (primals_in, primals_out, primals_out)))
      complete_ivjp_flat, _ = flatten_fun_nokwargs(complete_ivjp, in_tree)

      in_avals = map(abstract, primals_in + primals_out + primals_out)
      # TODO: Actually we do know some of the inputs, because they might be literals!
      ivjp_jaxpr, out_pvals, _ = pe.trace_to_jaxpr(
          complete_ivjp_flat, map(pe.PartialVal.unknown, in_avals), instantiate=True)
      assert not ivjp_jaxpr.constvars  # That might happen some time, but don't bother until then
      ivjp_jaxpr = core.ClosedJaxpr(ivjp_jaxpr, [])

    # Once we know what the ivjp can do exactly, we have to isolate the part we are
    # actually able to compute with the values we have at hand.
    num_inputs = len(eqn.invars)
    unknowns = (map(ad.is_undefined_primal, primals_in) +
                map(ad.is_undefined_primal, primals_out) +
                [False] * len(cts_in))
    jaxpr_known, jaxpr_unknown, out_unknowns = pe.partial_eval_jaxpr(  # type: ignore
        ivjp_jaxpr, unknowns, instantiate=False)  # type:ignore
    unknown_rec_primals_in, unknown_cotangents = split_list(out_unknowns, [num_inputs])
    # Make sure we're able to compute all cotangents. We don't really care
    # whether we can reconstruct the primals or not, although failure to do so
    # might result in failing to compute cotangents later.
    assert not any(unknown_cotangents)
    # Remove residual outputs -- we won't be computing the unknown jaxpr anyway.
    num_outputs = len(jaxpr_unknown.jaxpr.outvars)
    jaxpr_known.jaxpr.outvars = jaxpr_known.jaxpr.outvars[:num_outputs]
    # TODO: We could drop the outputs that correspond to primals that we already know.
    #       This only matters in eager mode, so leaving it out for now...
    ivjp = core.jaxpr_as_fun(jaxpr_known)
    rec_primals_in, cts_out = split_list(ivjp(*primals_in, *primals_out, *cts_in),
                                         [num_inputs])
    # Unknown rec_primals_in come back as core.units, so for those positions we
    # keep the original (possibly UndefinedPrimal) values instead.
    rec_primals_in = [prev if unknown else rec
                      for prev, rec, unknown
                      in zip(primals_in, rec_primals_in, unknown_rec_primals_in)]
    map(write_primal, eqn.invars, rec_primals_in)
    map(write_cotangent, [v for v in eqn.invars if type(v) is not Literal], cts_out)

  # NOTE: We keep the cotangents associated with primal variables, while the contract of a
  #       transpose is to return them in positions associated with tangent variables, which
  #       is what causes this whole confusion.
  return map(read_cotangent, jaxpr.invars)
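To make the ivjp contract concrete: the traced function consumes ``(primals_in, primals_out, cotangents)`` and returns reconstructed inputs together with input cotangents. A hypothetical rule for ``exp`` (not taken from any rule table) would look like:

import jax.numpy as jnp

def exp_ivjp(primals_in, primals_out, cotangents):
    # y = exp(x): invert the primal computation, then transpose the tangent map.
    (y,), (ct_y,) = primals_out, cotangents
    x = jnp.log(y)   # reconstruct the input from the output
    ct_x = y * ct_y  # vjp of exp at x is multiplication by exp(x) == y
    return (x,), (ct_x,)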
Example No. 8
 def _batch(args, dims, **params):
     batched, out_dims = batching.batch_fun2(
         lu.wrap_init(self.impl, params), dims)
     return batched.call_wrapped(*args), out_dims()
Example No. 9
def eval_sparse(
    jaxpr: core.Jaxpr,
    consts: Sequence[Array],  # all consts are dense
    spvalues: Sequence[
        SparsifyValue],  # mix of sparse and dense pointers into spenv
    spenv: SparsifyEnv,
) -> Sequence[SparsifyValue]:
    env: Dict[core.Var, SparsifyValue] = {}

    def read(var: Union[core.Var, core.Literal]) -> SparsifyValue:
        # all literals are dense
        if isinstance(var, core.Literal):
            return spenv.dense(var.val)
        else:
            return env[var]

    def write_buffer(var: core.Var, a: Array) -> None:
        if isinstance(var, core.DropVar):
            return
        env[var] = spenv.dense(a)

    def write(var: core.Var, a: SparsifyValue) -> None:
        if isinstance(var, core.DropVar):
            return
        assert a is not None
        env[var] = a

    # TODO: handle unitvar at all?
    #write_buffer(core.unitvar, core.unit)
    safe_map(write_buffer, jaxpr.constvars, consts)
    safe_map(write, jaxpr.invars, spvalues)

    for eqn in jaxpr.eqns:
        prim = eqn.primitive
        invals = safe_map(read, eqn.invars)

        if any(val.is_sparse() for val in invals):
            if prim not in sparse_rules:
                raise NotImplementedError(f"sparse rule for {prim}")
            out = sparse_rules[prim](spenv, *invals, **eqn.params)
        else:
            if prim is xla.xla_call_p:
                # TODO(vanderplas,frostig): workaround for binding call primitives
                # within a jaxpr interpreter
                params = eqn.params.copy()
                fun = lu.wrap_init(
                    core.jaxpr_as_fun(
                        pe.ClosedJaxpr(params.pop('call_jaxpr'), ())))
                out_bufs = prim.bind(fun, *(spenv.data(val) for val in invals),
                                     **params)
            else:
                out_bufs = prim.bind(*(spenv.data(val) for val in invals),
                                     **eqn.params)
            out_bufs = out_bufs if prim.multiple_results else [out_bufs]
            out = []
            for buf, outvar in safe_zip(out_bufs, eqn.outvars):
                if isinstance(outvar, core.DropVar):
                    out.append(None)
                else:
                    out.append(spenv.dense(buf))
        safe_map(write, eqn.outvars, out)

    return safe_map(read, jaxpr.outvars)
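This interpreter appears to back the public sparsify transform; a short usage sketch, assuming a current ``jax.experimental.sparse``:

import jax.numpy as jnp
from jax.experimental import sparse

mat = sparse.BCOO.fromdense(jnp.eye(3))
mv = sparse.sparsify(lambda m, v: m @ v)  # the traced jaxpr runs through eval_sparse
mv(mat, jnp.arange(3.0))                  # -> [0., 1., 2.]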
Example No. 10
 def f_(*args, **kwargs):
     f = lu.wrap_init(fun, kwargs)
     f_partial, dyn_args = argnums_partial(
         f, argnums, args, require_static_args_hashable=False
     )
     return scan_fun(lambda x: f_partial.call_wrapped(*x), dyn_args)
Example No. 11
def core_closed_call(f, *args):
    args, in_tree = tree_flatten(args)
    f, out_tree = flatten_fun_nokwargs(lu.wrap_init(f), in_tree)
    out = core.closed_call_p.bind(f, *args)
    return tree_unflatten(out_tree(), out)
Example No. 12
def submerge_consts(jaxpr, consts, invals=None):
    """
  Replace constvars with literals in jaxpr and its sub-jaxprs.
  """
    # TODO(j-towns): check that consts are in jax.core.literalable_types
    consts = dict(zip(jaxpr.constvars, consts))
    if invals is not None:
        # We're in a call_jaxpr
        new_jaxpr_invars = []
        for var, val in zip(jaxpr.invars, invals):
            if isinstance(val, Var):
                new_jaxpr_invars.append(var)
            else:
                consts[var] = val
    else:
        new_jaxpr_invars = jaxpr.invars
    new_eqns = []
    for eqn in jaxpr.eqns:
        if all(
                isinstance(var, Literal) or var in consts
                for var in eqn.invars):
            # Perform constant folding if all inputs to an eqn are known
            in_vals = [
                var.val if isinstance(var, Literal) else consts[var]
                for var in eqn.invars
            ]
            call_jaxpr, params = jc.extract_call_jaxpr(eqn.primitive,
                                                       eqn.params)
            if call_jaxpr:
                subfuns = [
                    lu.wrap_init(partial(jc.eval_jaxpr, call_jaxpr, ()))
                ]
            else:
                subfuns = []
            ans = eqn.primitive.bind(*(subfuns + in_vals), **params)
            if eqn.primitive.multiple_results:
                for outvar, out in zip(eqn.outvars, ans):
                    consts[outvar] = out
            else:
                outvar, = eqn.outvars
                consts[outvar] = ans
        else:
            new_invars = [
                consts[var] if
                (isinstance(var, Var) and var in consts) else var
                for var in eqn.invars
            ]
            new_params = dict(eqn.params)
            if eqn.primitive.call_primitive or eqn.primitive.map_primitive:
                new_params['call_jaxpr'] = submerge_consts(
                    eqn.params['call_jaxpr'], [], new_invars)
                new_invars = [
                    var for var in new_invars if isinstance(var, Var)
                ]
            else:
                new_invars = [
                    var if isinstance(var, (Var, Literal)) else Literal_(var)
                    for var in new_invars
                ]
            new_eqns.append(
                JaxprEqn(invars=new_invars,
                         outvars=eqn.outvars,
                         primitive=eqn.primitive,
                         params=new_params,
                         source_info=eqn.source_info))
    return Jaxpr([], new_jaxpr_invars, jaxpr.outvars, new_eqns)
Example No. 13
def jvp_jaxpr(jaxpr):
  f = lu.wrap_init(jaxpr_as_fun(jaxpr))
  dimvars = dict((v, v.aval) for v in jaxpr.in_dim_binders)
  in_avals = [_replace_vars_with_avals(dimvars, v.aval) for v in jaxpr.in_binders]
  jaxpr, consts, _ = trace_to_jaxpr_dynamic(jvp_traceable(ad.jvp(f)), in_avals * 2)
  return jaxpr, consts
Example No. 14
def make_djaxpr(fun, *args, **kwargs):
  args, in_tree = tree_flatten((args, kwargs))
  f, out_tree = flatten_fun(lu.wrap_init(fun), in_tree)
  in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args]
  return trace_to_jaxpr_dynamic(f, in_avals)
Example No. 15
def while_loop(cond_fun, body_fun, init_val):
    """Call ``body_fun`` repeatedly in a loop while ``cond_fun`` is True.

  The type signature in brief is

  .. code-block:: haskell

    while_loop :: (a -> Bool) -> (a -> a) -> a -> a

  The semantics of ``while_loop`` are given by this Python implementation::

    def while_loop(cond_fun, body_fun, init_val):
      val = init_val
      while cond_fun(val):
        val = body_fun(val)
      return val

  Unlike that Python version, ``while_loop`` is a JAX primitive and is lowered
  to a single XLA While HLO. That makes it useful for reducing compilation times
  for jit-compiled functions, since native Python loop constructs in an ``@jit``
  function are unrolled, leading to large XLA computations.

  Another difference from using Python-native loop constructs is that
  ``while_loop`` is not reverse-mode differentiable because XLA computations
  require static bounds on memory requirements.

  Args:
    cond_fun: function of type ``a -> Bool``.
    body_fun: function of type ``a -> a``.
    init_val: value of type ``a``, a type that can be a scalar, array, or any
      pytree (nested Python tuple/list/dict) thereof, representing the initial
      loop carry value.

  Returns:
    The output from the final iteration of ``body_fun``, of type ``a``.
  """
    init_val_flat, in_tree = pytree_to_jaxtupletree(init_val)
    flat_body_fun, out_tree = pytree_fun_to_jaxtupletree_fun(
        lu.wrap_init(body_fun), (in_tree, ))
    flat_cond_fun, _ = pytree_fun_to_jaxtupletree_fun(lu.wrap_init(cond_fun),
                                                      (in_tree, ))

    carry_pval_flat = carry_aval, _ = _abstractify(init_val_flat)
    cond_jaxpr, cond_pval_out, cond_consts = pe.trace_to_jaxpr(
        flat_cond_fun, (carry_pval_flat, ))
    body_jaxpr, body_pval_out, body_consts = pe.trace_to_jaxpr(
        flat_body_fun, (carry_pval_flat, ), instantiate=True)
    carry_aval_out, _ = body_pval_out
    assert isinstance(carry_aval_out, core.AbstractValue)
    assert carry_aval == core.lattice_join(carry_aval, carry_aval_out)

    cond_pv, cond_const = cond_pval_out
    if cond_pv is None:
        # cond_fun evaluates to a constant, so we don't need to generate a while_loop
        if cond_const:
            raise ValueError("infinite loop with no effects")
        else:
            return init_val
    else:
        assert isinstance(cond_pv, core.AbstractValue)
        if (not isinstance(cond_pv, ShapedArray) or cond_pv.shape
                or cond_pv.dtype != onp.bool_):
            msg = "while_loop cond_fun must return a scalar boolean, got {}."
            raise TypeError(msg.format(cond_pv))

    # We don't want to promote literal constants as loop arguments; there are
    # sometimes many of them. We pass tracers as loop arguments, but leave
    # nontracers as constants. We also sort the constants so the nontracers are
    # first.
    def split_tracers_and_nontracers(jaxpr, consts):
        tracer = []
        nontracer = []
        for x in zip(jaxpr.constvars, consts):
            # TODO(phawkins): We avoid treating DeviceArrays as constant literals so
            # we don't copy large arrays back to the host. We probably should relax
            # this and either always copy small constants, or opportunistically use
            # DeviceArray values for which we already know npy_value.
            not_literal_const = isinstance(x[1],
                                           (core.Tracer, xla.DeviceArray))
            (tracer if not_literal_const else nontracer).append(x)
        tracer_vars, tracer_consts = unzip2(tracer)
        nontracer_vars, nontracer_consts = unzip2(nontracer)
        return nontracer_vars + tracer_vars, nontracer_consts, tracer_consts

    cond_split = split_tracers_and_nontracers(cond_jaxpr, cond_consts)
    cond_jaxpr.constvars, cond_nontracer_consts, cond_tracer_consts = cond_split
    body_split = split_tracers_and_nontracers(body_jaxpr, body_consts)
    body_jaxpr.constvars, body_nontracer_consts, body_tracer_consts = body_split

    if out_tree() != in_tree:
        raise TypeError(
            "body_fun input and output must have identical structure")
    out_flat = while_p.bind(
        init_val_flat,
        core.pack(cond_tracer_consts),
        core.pack(body_tracer_consts),
        cond_consts=lax._OpaqueParam(cond_nontracer_consts),
        body_consts=lax._OpaqueParam(body_nontracer_consts),
        aval_out=carry_aval_out,
        cond_jaxpr=cond_jaxpr,
        body_jaxpr=body_jaxpr)
    return build_tree(out_tree(), out_flat)
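A usage sketch against the public ``jax.lax.while_loop`` entry point, summing the integers below 10 with an ``(i, total)`` carry:

from jax import lax

def cond(carry):
    i, _ = carry
    return i < 10

def body(carry):
    i, total = carry
    return i + 1, total + i

lax.while_loop(cond, body, (0, 0))  # -> (10, 45)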
Example No. 16
 def _jvp(primals, tangents, **params):
     return ad.jvp(lu.wrap_init(self.impl, params)).call_wrapped(
         primals, tangents)
Example No. 17
    def jvp_of_rule_rule(axis_size, in_batched, primals, tangents):
        in_batched_ps, in_batched_ts = in_batched

        mutually_batched = tree_map(operator.and_, in_batched_ps,
                                    in_batched_ts)
        extra_batched_ps = tree_map(
            lambda pb, tb: 0 if pb and not tb else None,
            in_batched_ps, in_batched_ts)
        extra_batched_ts = tree_map(
            lambda pb, tb: 0 if tb and not pb else None,
            in_batched_ps, in_batched_ts)

        out_mutually_batched = lu.Store()
        flat_ps_ts, tree_ps_ts = tree_flatten((primals, tangents))
        flat_extra_batched_ps_ts, tree_ps_ts2 = tree_flatten(
            (extra_batched_ps, extra_batched_ts), is_leaf=lambda x: x is None)

        # TODO(frostig): assert these also equal:
        #   treedef_tuple((in_tree, in_tree))
        # once https://github.com/google/jax/issues/9066 is fixed
        assert tree_ps_ts == tree_ps_ts2
        del tree_ps_ts2

        def to_jvp(*primals):
            outs, out_batched = call_rule(rule, axis_size, mutually_batched,
                                          primals)
            check_vmap_rule_trees(rule, tree_structure(outs),
                                  tree_structure(out_batched))
            out_mutually_batched.store(out_batched)
            return outs

        def to_vmap_over_extra_batched_dims(primals, tangents):
            return jax.jvp(to_jvp, primals, tangents)

        to_vmap_over_extra_batched_dims_flat, out_tree = flatten_fun_nokwargs(
            lu.wrap_init(to_vmap_over_extra_batched_dims), tree_ps_ts)

        flat_out_ps_ts, flat_out_axes = vmap_unrestricted(
            to_vmap_over_extra_batched_dims_flat,
            *flat_ps_ts,
            in_axes=flat_extra_batched_ps_ts,
            axis_name=core.no_axis_name,
            axis_size=axis_size)

        n, ragged = divmod(len(flat_out_ps_ts), 2)
        assert not ragged
        flat_out_ps, flat_out_ts = flat_out_ps_ts[:n], flat_out_ps_ts[n:]
        flat_out_axes_p, flat_out_axes_t = flat_out_axes[:n], flat_out_axes[n:]
        flat_out_ps = map(maybe_bdim_at_front, flat_out_ps, flat_out_axes_p)
        flat_out_extra_batched_ps = [
            d is not not_mapped for d in flat_out_axes_p
        ]
        flat_out_ts = map(maybe_bdim_at_front, flat_out_ts, flat_out_axes_t)
        flat_out_extra_batched_ts = [
            d is not not_mapped for d in flat_out_axes_t
        ]

        out_ps, out_ts = tree_unflatten(out_tree(),
                                        [*flat_out_ps, *flat_out_ts])
        out_extra_batched_ps, out_extra_batched_ts = tree_unflatten(
            out_tree(),
            [*flat_out_extra_batched_ps, *flat_out_extra_batched_ts])

        out_batched_ps = tree_map(operator.or_, out_mutually_batched.val,
                                  out_extra_batched_ps)
        out_batched_ts = tree_map(operator.or_, out_mutually_batched.val,
                                  out_extra_batched_ts)

        return (out_ps, out_ts), (out_batched_ps, out_batched_ts)
Example No. 18
 def _batch(args, dims, **params):
     return batching.batch_fun(lu.wrap_init(self.impl, params), args,
                               dims)
Example No. 19
def _interpret_jaxpr(jaxpr: core.TypedJaxpr, *args: TfVal) -> Sequence[TfVal]:
    """Evaluate a Jaxpr with tf.Tensor arguments."""
    fun: lu.WrappedFun = lu.wrap_init(core.jaxpr_as_fun(jaxpr))
    out_vals = _interpret_fun(fun, args)
    return out_vals
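This helper evaluates a jaxpr with TensorFlow values; as a hedged sketch of the public entry point it appears to serve, ``jax.experimental.jax2tf.convert``:

import jax.numpy as jnp
import tensorflow as tf
from jax.experimental import jax2tf

f_tf = jax2tf.convert(lambda x: jnp.sin(x) ** 2)  # interprets the jaxpr with tf.Tensor args
print(f_tf(tf.constant(1.0)))                     # same value as jnp.sin(1.0) ** 2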
Example No. 20
 def wrapped(*args, **kwargs):
   fun = lu.wrap_init(f, kwargs)
   flat_args, in_tree = jax.tree_flatten(args)
   flat_fun, out_tree = jax.flatten_fun_nokwargs(fun, in_tree)
   ans = jax_core.call_p.bind(flat_fun, *flat_args)
   return jax.tree_unflatten(out_tree(), ans)
Example No. 21
def ravel_first_arg(f, unravel):
  return ravel_first_arg_(lu.wrap_init(f), unravel).call_wrapped
Example No. 22
def while_loop(cond_fun, body_fun, init_val):
    """Call ``body_fun`` repeatedly in a loop while ``cond_fun`` is True.

  The type signature in brief is

  .. code-block:: haskell

    while_loop :: (a -> Bool) -> (a -> a) -> a -> a

  The semantics of ``while_loop`` are given by this Python implementation::

    def while_loop(cond_fun, body_fun, init_val):
      val = init_val
      while cond_fun(val):
        val = body_fun(val)
      return val

  Unlike that Python version, ``while_loop`` is a JAX primitive and is lowered
  to a single XLA While HLO. That makes it useful for reducing compilation times
  for jit-compiled functions, since native Python loop constructs in an ``@jit``
  function are unrolled, leading to large XLA computations.

  Another difference from using Python-native loop constructs is that
  ``while_loop`` is not reverse-mode differentiable because XLA computations
  require static bounds on memory requirements.

  Args:
    cond_fun: function of type ``a -> Bool``.
    body_fun: function of type ``a -> a``.
    init_val: value of type ``a``, a type that can be a scalar, array, or any
      pytree (nested Python tuple/list/dict) thereof, representing the initial
      loop carry value.

  Returns:
    The output from the final iteration of ``body_fun``, of type ``a``.
  """
    init_val_flat, in_tree = pytree_to_jaxtupletree(init_val)
    flat_body_fun, out_tree = pytree_fun_to_jaxtupletree_fun(
        lu.wrap_init(body_fun), (in_tree, ))
    flat_cond_fun, _ = pytree_fun_to_jaxtupletree_fun(lu.wrap_init(cond_fun),
                                                      (in_tree, ))

    carry_pval_flat = carry_aval, _ = _abstractify(init_val_flat)
    cond_jaxpr, cond_pval_out, cond_consts = pe.trace_to_jaxpr(
        flat_cond_fun, (carry_pval_flat, ))
    body_jaxpr, body_pval_out, body_consts = pe.trace_to_jaxpr(
        flat_body_fun, (carry_pval_flat, ), instantiate=True)
    carry_aval_out, _ = body_pval_out
    assert isinstance(carry_aval_out, core.AbstractValue)
    assert carry_aval == core.lattice_join(carry_aval, carry_aval_out)

    cond_pv, cond_const = cond_pval_out
    if cond_pv is None:
        # cond_fun evaluates to a constant, so we don't need to generate a while_loop
        if cond_const:
            raise ValueError("infinite loop with no effects")
        else:
            return init_val
    else:
        assert isinstance(cond_pv, core.AbstractValue)
        if (not isinstance(cond_pv, ShapedArray) or cond_pv.shape
                or cond_pv.dtype != onp.bool_):
            msg = "while_loop cond_fun must return a scalar boolean, got {}."
            raise TypeError(msg.format(cond_pv))

    if out_tree() != in_tree:
        raise TypeError(
            "body_fun input and output must have identical structure")
    out_flat = while_p.bind(init_val_flat,
                            core.pack(cond_consts),
                            core.pack(body_consts),
                            aval_out=carry_aval_out,
                            cond_jaxpr=cond_jaxpr,
                            body_jaxpr=body_jaxpr)
    return build_tree(out_tree(), out_flat)
Example No. 23
 def checked_fun(*args, **kwargs):
     args_flat, in_tree = tree_flatten((args, kwargs))
     f, out_tree = flatten_fun(lu.wrap_init(fun), in_tree)
     (err, code, out_flat), msgs = checkify_flat(f, errors, *args_flat)
     out = tree_unflatten(out_tree(), out_flat)
     return Error(err, code, msgs), out
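For comparison, the public checkify API in current JAX releases wraps a function in the same way and returns an error value alongside the output; a hedged sketch:

import jax.numpy as jnp
from jax.experimental import checkify

def f(x):
    checkify.check(x > 0, "x must be positive!")
    return jnp.log(x)

err, out = checkify.checkify(f)(-1.0)
err.throw()  # raises, carrying the "x must be positive!" message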
Example No. 24
def custom_layer_cau_batch(vals, dims, *, num_consts, in_tree, out_tree,
                           kwargs, **params):
    """Batching rule for layer_cau primitive to handle custom layers."""
    if all(dim is batching.not_mapped for dim in dims):
        return layer_cau_p.bind(*vals,
                                num_consts=num_consts,
                                in_tree=in_tree,
                                out_tree=out_tree,
                                kwargs=kwargs,
                                **params)
    orig_vals, orig_dims = vals, dims
    vals, dims = vals[num_consts:], dims[num_consts:]
    args = tree_util.tree_unflatten(in_tree, vals)
    dims_ = [not_mapped if dim is None else dim for dim in dims]
    layer, args = args[0], args[1:]
    if hasattr(layer, '_call_and_update_batched'):
        num_params = len(tree_util.tree_leaves(layer))
        layer_dims, arg_dims = dims_[:num_params], dims_[num_params:]
        if kwargs['has_rng']:
            rng, args = args[0], args[1:]
            rng_dim, arg_dims = arg_dims[0], arg_dims[1:]
        mapping_over_layer = all(layer_dim is not not_mapped
                                 for layer_dim in layer_dims)
        mapping_over_args = all(arg_dim is not not_mapped
                                for arg_dim in arg_dims)
        assert mapping_over_layer or mapping_over_args, (layer_dims, arg_dims)
        if not mapping_over_layer and mapping_over_args:
            if kwargs['has_rng']:
                if rng_dim is not not_mapped:
                    arg_dims = tuple(None if dim is not_mapped else dim
                                     for dim in arg_dims)
                    map_fun = jax.vmap(
                        lambda layer, rng, *args: _layer_cau_batched(  # pylint: disable=unnecessary-lambda, g-long-lambda
                            layer, rng, *args, **kwargs),
                        in_axes=(None, rng_dim) + (None, ) * len(arg_dims))
                else:
                    map_fun = lambda layer, *args: _layer_cau_batched(  # pylint: disable=unnecessary-lambda, g-long-lambda
                        layer, *args, **kwargs)
                vals_out, update_out = map_fun(layer, rng, *args)
            else:
                vals_out, update_out = _layer_cau_batched(
                    layer, *args, **kwargs)
            vals_out = tree_util.tree_leaves(vals_out)
            update_out = tree_util.tree_leaves(update_out)
            assert all(dim == 0 for dim in arg_dims)
            # Assume dimensions out are consistent
            dims_out = (0, ) * len(vals_out)
            dims_update = (None, ) * len(update_out)
            assert len(vals_out) == len(dims_out)
            assert len(update_out) == len(dims_update)
            return vals_out + update_out, dims_out + dims_update
    batched, out_dims = primitive.batch_fun(
        lu.wrap_init(
            layer_cau_p.impl,
            dict(params,
                 num_consts=num_consts,
                 in_tree=in_tree,
                 out_tree=out_tree,
                 kwargs=kwargs)), orig_dims)
    return batched.call_wrapped(*orig_vals), out_dims()
Example No. 25
 def wrapped_fun(*args):
   args_flat, in_tree = tree_flatten(args)
   f = lu.wrap_init(fun)
   flat_fun, out_tree = flatten_fun_nokwargs(f, in_tree)
   out_flat = callback_fun(flat_fun, args_flat, callback, strip_calls)
   return tree_unflatten(out_tree(), out_flat)
Example No. 26
 def _wrapped(*args):
   args_flat, in_tree = tree_flatten(args, is_leaf=_is_bcoo)
   wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
   out = sparsify_fun(wrapped_fun, args_flat)
   return tree_unflatten(out_tree(), out)
Example No. 27
def checkify_jaxpr(jaxpr, error):
  f = lu.wrap_init(core.jaxpr_as_fun(jaxpr))
  return checkify_fun_to_jaxpr(f, error, jaxpr.in_avals)
Example No. 28
 def trace_jaxpr(fun, operand):
   op_flat, in_tree = pytree_to_flatjaxtuple(operand)
   fun_flat, out_tree = pytree_fun_to_flatjaxtuple_fun(lu.wrap_init(fun), (in_tree,))
   jaxpr, pvout, consts = pe.trace_to_jaxpr(fun_flat, (_abstractify(op_flat),))
   return op_flat, jaxpr, consts, pvout, out_tree