示例#1
0
def while_loop(cond_fun, body_fun, init_val):
    """Apply `body_fun` to a carried value for as long as `cond_fun` holds.

    The behavior matches this plain-Python reference loop::

        def while_loop(cond_fun, body_fun, init_val):
          val = init_val
          while cond_fun(val):
            val = body_fun(val)
          return val

    In contrast to that reference loop, `while_loop` is a JAX primitive that
    is lowered to a single XLA While HLO. Native Python loops inside an `@jit`
    function are unrolled and produce large XLA computations, so using this
    primitive can reduce compilation times.

    `while_loop` is also not (yet) reverse-mode differentiable, because XLA
    computations require static bounds on memory requirements.

    Arguments:
      cond_fun: pure function of type `T -> Bool`.
      body_fun: pure function of type `T -> T`.
      init_val: value of type `T`; `T` may be a scalar, an array, or any
        (nested) Python tuple/list/dict of those.

    Returns:
      The value produced by the last application of `body_fun`, of type `T`.

    Raises:
      TypeError: if the tree structure of `body_fun`'s output differs from the
        tree structure of `init_val`.
    """
    flat_init, in_tree = pytree_to_jaxtupletree(init_val)
    in_trees = (in_tree, )
    flat_body, out_tree = pytree_fun_to_jaxtupletree_fun(
        lu.wrap_init(body_fun), in_trees)
    flat_cond, _ = pytree_fun_to_jaxtupletree_fun(
        lu.wrap_init(cond_fun), in_trees)

    # Both functions are traced against the same abstracted initial carry.
    trace_args = (lax._abstractify(flat_init), )
    cond_jaxpr, _, cond_consts = pe.trace_to_jaxpr(flat_cond, trace_args)
    body_jaxpr, body_pval_out, body_consts = pe.trace_to_jaxpr(
        flat_body, trace_args)
    aval_out, _ = body_pval_out

    def partition_consts(jaxpr, consts):
        """Separate a jaxpr's constants into literal and traced groups.

        Literal constants should not be promoted to loop arguments (there are
        sometimes many of them), so only tracers are passed as loop arguments
        and everything else stays a constant.

        Returns a triple: the jaxpr's constvars reordered so the literal ones
        come first, the literal constant values, and the traced constant
        values.
        """
        traced_pairs = []
        literal_pairs = []
        for var, const in zip(jaxpr.constvars, consts):
            # TODO(phawkins): We avoid treating DeviceArrays as constant
            # literals so we don't copy large arrays back to the host. We
            # probably should relax this and either always copy small
            # constants, or opportunistically use DeviceArray values for which
            # we already know npy_value.
            if isinstance(const, (core.Tracer, xla.DeviceArray)):
                traced_pairs.append((var, const))
            else:
                literal_pairs.append((var, const))
        traced_vars, traced_consts = unzip2(traced_pairs)
        literal_vars, literal_consts = unzip2(literal_pairs)
        return literal_vars + traced_vars, literal_consts, traced_consts

    # The jaxprs' constvars are overwritten in place with the reordered ones.
    (cond_jaxpr.constvars, cond_literal_consts,
     cond_traced_consts) = partition_consts(cond_jaxpr, cond_consts)
    (body_jaxpr.constvars, body_literal_consts,
     body_traced_consts) = partition_consts(body_jaxpr, body_consts)

    body_out_tree = out_tree()
    if body_out_tree != in_tree:
        raise TypeError(
            "body_fun input and output must have identical structure")

    flat_result = while_p.bind(
        flat_init,
        core.pack(cond_traced_consts),
        core.pack(body_traced_consts),
        cond_consts=lax._OpaqueParam(cond_literal_consts),
        body_consts=lax._OpaqueParam(body_literal_consts),
        aval_out=aval_out,
        cond_jaxpr=cond_jaxpr,
        body_jaxpr=body_jaxpr)
    return build_tree(body_out_tree, flat_result)
示例#2
0
 def testRoundtripViaBuild(self, inputs):
   """Check that `build_tree` reconstructs `inputs` from `_process_pytree` output."""
   # `_process_pytree(tuple, inputs)` yields a pair; its second element is
   # handed to `build_tree` as the tree and its first as the values.
   flattened = _process_pytree(tuple, inputs)
   values, structure = flattened
   rebuilt = tree_util.build_tree(structure, values)
   self.assertEqual(rebuilt, inputs)
示例#3
0
def scan(f, init, xs):
  """Thread a carry through ``f`` while stepping along the leading axis of ``xs``.

  In brief, the type signature is

  .. code-block:: haskell

    scan :: (c -> a -> (c, b)) -> c -> [a] -> (c, [b])

  Here ``[t]`` stands for the type ``t`` with one extra leading axis. For an
  array type ``t`` that is simply an array with an additional leading axis; for
  a pytree (container) type ``t`` with array leaves it is the same pytree
  structure in which every leaf has gained a leading axis.

  If ``a`` and ``b`` are both array types, ``scan`` behaves like this Python
  reference implementation::

    def scan(f, init, xs):
      carry = init
      ys = []
      for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
      return carry, np.stack(ys)

  The reference implementation differs from ``scan`` in two ways. First,
  ``a`` and ``b`` may be arbitrary pytree types, so several arrays can be
  scanned over together and several output arrays produced. Second, ``scan``
  is a JAX primitive lowered to a single XLA While HLO; since native Python
  loops in an ``@jit`` function are unrolled into large XLA computations, this
  makes ``scan`` useful for reducing compilation times.

  Args:
    f: the Python function to scan, of type ``c -> a -> (c, b)``. It takes the
      current loop carry and one slice of ``xs`` along its leading axis, and
      returns a pair: the new loop carry and one slice of the output.
    init: the initial loop carry, of type ``c``. It may be a scalar, an array,
      or any pytree (nested Python tuple/list/dict) of those.
    xs: the value of type ``[a]`` to scan over along the leading axis. It may
      be an array or any pytree (nested Python tuple/list/dict) of arrays with
      consistent leading axis sizes.

  Returns:
    A pair of type ``(c, [b])``: the final loop carry, and the second outputs
    of ``f`` stacked along the leading axis.

  Raises:
    TypeError: if the traced output of ``f`` is not a pair, or if the abstract
      value of the carry returned by ``f`` differs from that of ``init``.
  """
  # Flatten both inputs, remembering the tree structure of each.
  init, init_tree = pytree_to_jaxtupletree(init)
  xs, xs_tree = pytree_to_jaxtupletree(xs)
  flat_f, out_tree = pytree_fun_to_jaxtupletree_fun(
      lu.wrap_init(f), (init_tree, xs_tree))

  carry_pval = _abstractify(init)
  carry_aval, _ = carry_pval
  xs_aval, _ = _abstractify(xs)
  # `f` is traced against the abstract value of `xs` after
  # `_demote_aval_rank`, paired with `core.unit`.
  x_aval = _demote_aval_rank(xs_aval)
  x_pval = pe.PartialVal((x_aval, core.unit))

  jaxpr, pval_out, consts = pe.trace_to_jaxpr(
      flat_f, (carry_pval, x_pval), instantiate=True)
  pv_out, const_out = pval_out
  assert isinstance(pv_out, core.AbstractValue) and const_out == core.unit

  returns_pair = isinstance(pv_out, core.AbstractTuple) and len(pv_out) == 2
  if not returns_pair:
    not_pair_msg = (
        "scanned function must have signature `c -> a -> (c, b)`, but the "
        "output was not a pair: got type {}.")
    raise TypeError(not_pair_msg.format(pv_out))

  carry_aval_out, y_aval = pv_out
  if carry_aval_out != carry_aval:
    carry_mismatch_msg = (
        "scanned function carry output does not match carry input: "
        "input carry is {} and output carry is {}.")
    raise TypeError(carry_mismatch_msg.format(carry_aval, carry_aval_out))

  # Build a typed jaxpr whose inputs are (consts, carry, x).
  lifted_jaxpr = pe._closure_convert_jaxpr(jaxpr)
  consts_aval, _ = _abstractify(core.pack(consts))
  typed_jaxpr = core.TypedJaxpr(
      lifted_jaxpr,
      (),
      (consts_aval, carry_aval, x_aval),
      core.AbstractTuple((carry_aval, y_aval)))

  num_steps = _leading_dim_size(xs)
  flat_out = scan_p.bind(
      core.pack(consts), init, xs,
      forward=True, length=num_steps, jaxpr=typed_jaxpr)
  return build_tree(out_tree(), flat_out)
示例#4
0
def while_loop(cond_fun, body_fun, init_val):
  """Call ``body_fun`` repeatedly in a loop while ``cond_fun`` is True.

  The type signature in brief is

  .. code-block:: haskell

    while_loop :: (a -> Bool) -> (a -> a) -> a -> a

  The semantics of ``while_loop`` are given by this Python implementation::

    def while_loop(cond_fun, body_fun, init_val):
      val = init_val
      while cond_fun(val):
        val = body_fun(val)
      return val

  Unlike that Python version, ``while_loop`` is a JAX primitive and is lowered
  to a single XLA While HLO. That makes it useful for reducing compilation times
  for jit-compiled functions, since native Python loop constructs in an ``@jit``
  function are unrolled, leading to large XLA computations.

  Another difference from using Python-native loop constructs is that
  ``while_loop`` is not reverse-mode differentiable because XLA computations
  require static bounds on memory requirements.

  Args:
    cond_fun: function of type ``a -> Bool``.
    body_fun: function of type ``a -> a``.
    init_val: value of type ``a``, a type that can be a scalar, array, or any
      pytree (nested Python tuple/list/dict) thereof, representing the initial
      loop carry value.

  Returns:
    The output from the final iteration of body_fun, of type ``a``. If
    ``cond_fun`` traces to a constant false value, ``init_val`` is returned
    unchanged and no loop is emitted.

  Raises:
    TypeError: if the abstract value of ``body_fun``'s output is not the same
      as that of ``init_val``, if ``cond_fun`` does not return a scalar
      boolean, or if the pytree structure of ``body_fun``'s output differs
      from that of ``init_val``.
    ValueError: if ``cond_fun`` traces to a constant true value, which would
      be an infinite loop.
  """
  init_val_flat, in_tree = pytree_to_jaxtupletree(init_val)
  flat_body_fun, out_tree = pytree_fun_to_jaxtupletree_fun(
      lu.wrap_init(body_fun), (in_tree,))
  flat_cond_fun, _ = pytree_fun_to_jaxtupletree_fun(
      lu.wrap_init(cond_fun), (in_tree,))

  # Trace both functions against the abstracted initial carry.
  carry_pval_flat = carry_aval, _ = _abstractify(init_val_flat)
  cond_jaxpr, cond_pval_out, cond_consts = pe.trace_to_jaxpr(
      flat_cond_fun, (carry_pval_flat,))
  body_jaxpr, body_pval_out, body_consts = pe.trace_to_jaxpr(
      flat_body_fun, (carry_pval_flat,), instantiate=True)
  carry_aval_out, _ = body_pval_out
  # Sanity check on the tracing result itself, not on user input.
  assert isinstance(carry_aval_out, core.AbstractValue)
  # This validates the user-supplied body_fun, so it is an explicit raise
  # rather than an assert: asserts are stripped under `python -O` and would
  # otherwise surface as a bare AssertionError.
  if carry_aval != core.lattice_join(carry_aval, carry_aval_out):
    msg = "body_fun output and input must have identical types, got {} and {}."
    raise TypeError(msg.format(carry_aval_out, carry_aval))

  cond_pv, cond_const = cond_pval_out
  if cond_pv is None:
    # cond_fun evaluates to a constant, so don't need to generate a while_loop
    if cond_const:
      raise ValueError("infinite loop with no effects")
    return init_val

  assert isinstance(cond_pv, core.AbstractValue)
  # A non-empty shape is truthy, so this rejects anything but a rank-0 bool.
  if (not isinstance(cond_pv, ShapedArray) or cond_pv.shape
      or cond_pv.dtype != onp.bool_):
    msg = "while_loop cond_fun must return a scalar boolean, got {}."
    raise TypeError(msg.format(cond_pv))

  if out_tree() != in_tree:
    raise TypeError("body_fun input and output must have identical structure")
  out_flat = while_p.bind(
      init_val_flat, core.pack(cond_consts), core.pack(body_consts),
      aval_out=carry_aval_out, cond_jaxpr=cond_jaxpr, body_jaxpr=body_jaxpr)
  return build_tree(out_tree(), out_flat)
示例#5
0
def while_loop(cond_fun, body_fun, init_val):
    """Call ``body_fun`` repeatedly in a loop while ``cond_fun`` is True.

    The type signature in brief is

    .. code-block:: haskell

      while_loop :: (a -> Bool) -> (a -> a) -> a -> a

    The semantics of ``while_loop`` are given by this Python implementation::

      def while_loop(cond_fun, body_fun, init_val):
        val = init_val
        while cond_fun(val):
          val = body_fun(val)
        return val

    Unlike that Python version, ``while_loop`` is a JAX primitive and is
    lowered to a single XLA While HLO. That makes it useful for reducing
    compilation times for jit-compiled functions, since native Python loop
    constructs in an ``@jit`` function are unrolled, leading to large XLA
    computations.

    Another difference from using Python-native loop constructs is that
    ``while_loop`` is not reverse-mode differentiable because XLA computations
    require static bounds on memory requirements.

    Args:
      cond_fun: function of type ``a -> Bool``.
      body_fun: function of type ``a -> a``.
      init_val: value of type ``a``, a type that can be a scalar, array, or
        any pytree (nested Python tuple/list/dict) thereof, representing the
        initial loop carry value.

    Returns:
      The output from the final iteration of body_fun, of type ``a``. If
      ``cond_fun`` traces to a constant false value, ``init_val`` is returned
      unchanged and no loop is emitted.

    Raises:
      TypeError: if the abstract value of ``body_fun``'s output is not the
        same as that of ``init_val``, if ``cond_fun`` does not return a scalar
        boolean, or if the pytree structure of ``body_fun``'s output differs
        from that of ``init_val``.
      ValueError: if ``cond_fun`` traces to a constant true value, which would
        be an infinite loop.
    """
    init_val_flat, in_tree = pytree_to_jaxtupletree(init_val)
    flat_body_fun, out_tree = pytree_fun_to_jaxtupletree_fun(
        lu.wrap_init(body_fun), (in_tree, ))
    flat_cond_fun, _ = pytree_fun_to_jaxtupletree_fun(
        lu.wrap_init(cond_fun), (in_tree, ))

    # Trace both functions against the abstracted initial carry.
    carry_pval_flat = carry_aval, _ = _abstractify(init_val_flat)
    cond_jaxpr, cond_pval_out, cond_consts = pe.trace_to_jaxpr(
        flat_cond_fun, (carry_pval_flat, ))
    body_jaxpr, body_pval_out, body_consts = pe.trace_to_jaxpr(
        flat_body_fun, (carry_pval_flat, ), instantiate=True)
    carry_aval_out, _ = body_pval_out
    # Sanity check on the tracing result itself, not on user input.
    assert isinstance(carry_aval_out, core.AbstractValue)
    # This validates the user-supplied body_fun, so it is an explicit raise
    # rather than an assert: asserts are stripped under `python -O` and would
    # otherwise surface as a bare AssertionError.
    if carry_aval != core.lattice_join(carry_aval, carry_aval_out):
        msg = ("body_fun output and input must have identical types, "
               "got {} and {}.")
        raise TypeError(msg.format(carry_aval_out, carry_aval))

    cond_pv, cond_const = cond_pval_out
    if cond_pv is None:
        # cond_fun evaluates to a constant, so don't need to generate a
        # while_loop
        if cond_const:
            raise ValueError("infinite loop with no effects")
        return init_val

    assert isinstance(cond_pv, core.AbstractValue)
    # A non-empty shape is truthy, so this rejects anything but a rank-0 bool.
    if (not isinstance(cond_pv, ShapedArray) or cond_pv.shape
            or cond_pv.dtype != onp.bool_):
        msg = "while_loop cond_fun must return a scalar boolean, got {}."
        raise TypeError(msg.format(cond_pv))

    # We don't want to promote literal constants as loop arguments; there are
    # sometimes many of them. We pass tracers as loop arguments, but leave
    # nontracers as constants. We also sort the constants so the nontracers are
    # first.
    def split_tracers_and_nontracers(jaxpr, consts):
        """Return (reordered constvars, nontracer consts, tracer consts)."""
        tracer = []
        nontracer = []
        for var, const in zip(jaxpr.constvars, consts):
            # TODO(phawkins): We avoid treating DeviceArrays as constant
            # literals so we don't copy large arrays back to the host. We
            # probably should relax this and either always copy small
            # constants, or opportunistically use DeviceArray values for which
            # we already know npy_value.
            if isinstance(const, (core.Tracer, xla.DeviceArray)):
                tracer.append((var, const))
            else:
                nontracer.append((var, const))
        tracer_vars, tracer_consts = unzip2(tracer)
        nontracer_vars, nontracer_consts = unzip2(nontracer)
        return nontracer_vars + tracer_vars, nontracer_consts, tracer_consts

    # The jaxprs' constvars are overwritten in place with the reordered ones.
    cond_split = split_tracers_and_nontracers(cond_jaxpr, cond_consts)
    cond_jaxpr.constvars, cond_nontracer_consts, cond_tracer_consts = cond_split
    body_split = split_tracers_and_nontracers(body_jaxpr, body_consts)
    body_jaxpr.constvars, body_nontracer_consts, body_tracer_consts = body_split

    if out_tree() != in_tree:
        raise TypeError(
            "body_fun input and output must have identical structure")
    out_flat = while_p.bind(
        init_val_flat,
        core.pack(cond_tracer_consts),
        core.pack(body_tracer_consts),
        cond_consts=lax._OpaqueParam(cond_nontracer_consts),
        body_consts=lax._OpaqueParam(body_nontracer_consts),
        aval_out=carry_aval_out,
        cond_jaxpr=cond_jaxpr,
        body_jaxpr=body_jaxpr)
    return build_tree(out_tree(), out_flat)