def while_loop(cond_fun, body_fun, init_val):
  """Call `body_fun` repeatedly in a loop while `cond_fun` is True.

  Arguments:
    cond_fun: pure function of type `T -> Bool`.
    body_fun: pure function of type `T -> T`.
    init_val: value of type `T`, a type that can be a scalar, array, or any
      (nested) Python tuple/list/dict thereof.

  Returns:
    The output from the final iteration of body_fun, of type `T`.

  The semantics of `while_loop` are given by this Python implementation::

    def while_loop(cond_fun, body_fun, init_val):
      val = init_val
      while cond_fun(val):
        val = body_fun(val)
      return val

  Unlike that pure Python version, `while_loop` is a JAX primitive and is
  lowered to a single XLA While HLO. That makes it useful for reducing
  compilation times for jit-compiled functions, since native Python loop
  constructs in an `@jit` function are unrolled, leading to large XLA
  computations.

  Another difference from using Python-native loop constructs is that
  `while_loop` is not (yet) reverse-mode differentiable because XLA
  computations require static bounds on memory requirements.
  """
  # Flatten the pytree carry value to a single flat value, and wrap cond_fun
  # and body_fun so they consume/produce that flat representation.
  init_val_flat, in_tree = pytree_to_jaxtupletree(init_val)
  flat_body_fun, out_tree = pytree_fun_to_jaxtupletree_fun(
      lu.wrap_init(body_fun), (in_tree, ))
  flat_cond_fun, _ = pytree_fun_to_jaxtupletree_fun(lu.wrap_init(cond_fun),
                                                    (in_tree, ))

  # Trace both functions to jaxprs at the abstract level of the flat carry.
  pval_flat = lax._abstractify(init_val_flat)
  cond_jaxpr, _, cond_consts = pe.trace_to_jaxpr(flat_cond_fun, (pval_flat, ))
  body_jaxpr, pval_out, body_consts = pe.trace_to_jaxpr(
      flat_body_fun, (pval_flat, ))
  aval_out, _ = pval_out

  # We don't want to promote literal constants as loop arguments; there are
  # sometimes many of them. We pass tracers as loop arguments, but leave
  # nontracers as constants. We also sort the constants so the nontracers are
  # first.
  def split_tracers_and_nontracers(jaxpr, consts):
    # Partitions (constvar, const) pairs by whether the const must be carried
    # as a loop argument (tracers, DeviceArrays) or can stay a literal.
    tracer = []
    nontracer = []
    for x in zip(jaxpr.constvars, consts):
      # TODO(phawkins): We avoid treating DeviceArrays as constant literals so
      # we don't copy large arrays back to the host. We probably should relax
      # this and either always copy small constants, or opportunistically use
      # DeviceArray values for which we already know npy_value.
      not_literal_const = isinstance(x[1], (core.Tracer, xla.DeviceArray))
      (tracer if not_literal_const else nontracer).append(x)
    tracer_vars, tracer_consts = unzip2(tracer)
    nontracer_vars, nontracer_consts = unzip2(nontracer)
    return nontracer_vars + tracer_vars, nontracer_consts, tracer_consts

  # Reorder each jaxpr's constvars in place so nontracer consts come first,
  # matching the (nontracer literal, tracer argument) split used by while_p.
  cond_split = split_tracers_and_nontracers(cond_jaxpr, cond_consts)
  cond_jaxpr.constvars, cond_nontracer_consts, cond_tracer_consts = cond_split
  body_split = split_tracers_and_nontracers(body_jaxpr, body_consts)
  body_jaxpr.constvars, body_nontracer_consts, body_tracer_consts = body_split

  if out_tree() != in_tree:
    raise TypeError(
        "body_fun input and output must have identical structure")
  # Tracer consts are packed and passed as operands; nontracer consts travel
  # as opaque (non-traced) primitive parameters.
  out_flat = while_p.bind(
      init_val_flat, core.pack(cond_tracer_consts),
      core.pack(body_tracer_consts),
      cond_consts=lax._OpaqueParam(cond_nontracer_consts),
      body_consts=lax._OpaqueParam(body_nontracer_consts),
      aval_out=aval_out, cond_jaxpr=cond_jaxpr, body_jaxpr=body_jaxpr)
  return build_tree(out_tree(), out_flat)
def testRoundtripViaBuild(self, inputs):
  """Flattening a pytree and rebuilding it with build_tree round-trips."""
  leaves, treedef = _process_pytree(tuple, inputs)
  reconstructed = tree_util.build_tree(treedef, leaves)
  self.assertEqual(reconstructed, inputs)
def scan(f, init, xs):
  """Scan a function over leading array axes while carrying along state.

  The type signature in brief is

  .. code-block:: haskell

    scan :: (c -> a -> (c, b)) -> c -> [a] -> (c, [b])

  where we use [t] here to denote the type t with an additional leading axis.
  That is, if t is an array type then [t] represents the type with an
  additional leading axis, and if t is a pytree (container) type with array
  leaves then [t] represents the type with the same pytree structure and
  corresponding leaves each with an additional leading axis.

  When both ``a`` and ``b`` are array types, the semantics of ``scan`` are
  given by this Python implementation::

    def scan(f, init, xs):
      carry = init
      ys = []
      for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
      return carry, np.stack(ys)

  Unlike that Python version, both ``a`` and ``b`` may be arbitrary pytree
  types, and so multiple arrays can be scanned over at once and produce
  multiple output arrays.

  Also unlike that Python version, ``scan`` is a JAX primitive and is lowered
  to a single XLA While HLO. That makes it useful for reducing compilation
  times for jit-compiled functions, since native Python loop constructs in an
  ``@jit`` function are unrolled, leading to large XLA computations.

  Args:
    f: a Python function to be scanned of type ``c -> a -> (c, b)``, meaning
      that ``f`` accepts two arguments where the first is a value of the loop
      carry and the second is a slice of ``xs`` along its leading axis, and
      that ``f`` returns a pair where the first element represents a new value
      for the loop carry and the second represents a slice of the output.
    init: an initial loop carry value of type ``c``, which can be a scalar,
      array, or any pytree (nested Python tuple/list/dict) thereof,
      representing the initial loop carry value.
    xs: the value of type ``[a]`` over which to scan along the leading axis,
      where ``[a]`` can be an array or any pytree (nested Python
      tuple/list/dict) thereof with consistent leading axis sizes.

  Returns:
    A pair of type ``(c, [b])`` where the first element represents the final
    loop carry value and the second element represents the stacked outputs of
    the second output of ``f`` when scanned over the leading axis of the
    inputs.
  """
  # Flatten the carry and the scanned-over pytrees, and wrap f so it operates
  # on the flat representations.
  (init, xs), in_trees = unzip2(map(pytree_to_jaxtupletree, (init, xs)))
  f, out_tree = pytree_fun_to_jaxtupletree_fun(lu.wrap_init(f), in_trees)
  carry_pval = carry_aval, _ = _abstractify(init)
  xs_aval, _ = _abstractify(xs)
  # A single loop iteration sees one slice of xs, i.e. xs with the leading
  # axis dropped.
  x_aval = _demote_aval_rank(xs_aval)
  x_pval = pe.PartialVal((x_aval, core.unit))
  # instantiate=True forces all outputs to be abstract so the jaxpr fully
  # describes one step of the scan.
  jaxpr, pval_out, consts = pe.trace_to_jaxpr(
      f, (carry_pval, x_pval), instantiate=True)
  pv_out, const_out = pval_out
  assert isinstance(pv_out, core.AbstractValue) and const_out == core.unit
  # f must return a (carry, output-slice) pair.
  if not isinstance(pv_out, core.AbstractTuple) or len(pv_out) != 2:
    msg = ("scanned function must have signature `c -> a -> (c, b)`, but the "
           "output was not a pair: got type {}.")
    raise TypeError(msg.format(pv_out))
  carry_aval_out, y_aval = pv_out
  # The carry type must be loop-invariant.
  if carry_aval != carry_aval_out:
    msg = ("scanned function carry output does not match carry input: "
           "input carry is {} and output carry is {}.")
    raise TypeError(msg.format(carry_aval, carry_aval_out))
  # Convert closed-over constants into explicit jaxpr inputs so the jaxpr has
  # no free constvars, then package it with its types.
  lifted_jaxpr = pe._closure_convert_jaxpr(jaxpr)
  consts_aval, _ = _abstractify(core.pack(consts))
  in_avals = (consts_aval, carry_aval, x_aval)
  out_aval = core.AbstractTuple((carry_aval, y_aval))
  jaxpr = core.TypedJaxpr(lifted_jaxpr, (), in_avals, out_aval)
  length = _leading_dim_size(xs)
  out = scan_p.bind(core.pack(consts), init, xs,
                    forward=True, length=length, jaxpr=jaxpr)
  return build_tree(out_tree(), out)
def while_loop(cond_fun, body_fun, init_val):
  """Call ``body_fun`` repeatedly in a loop while ``cond_fun`` is True.

  The type signature in brief is

  .. code-block:: haskell

    while_loop :: (a -> Bool) -> (a -> a) -> a -> a

  The semantics of ``while_loop`` are given by this Python implementation::

    def while_loop(cond_fun, body_fun, init_val):
      val = init_val
      while cond_fun(val):
        val = body_fun(val)
      return val

  Unlike that Python version, ``while_loop`` is a JAX primitive and is lowered
  to a single XLA While HLO. That makes it useful for reducing compilation
  times for jit-compiled functions, since native Python loop constructs in an
  ``@jit`` function are unrolled, leading to large XLA computations.

  Another difference from using Python-native loop constructs is that
  ``while_loop`` is not reverse-mode differentiable because XLA computations
  require static bounds on memory requirements.

  Args:
    cond_fun: function of type ``a -> Bool``.
    body_fun: function of type ``a -> a``.
    init_val: value of type ``a``, a type that can be a scalar, array, or any
      pytree (nested Python tuple/list/dict) thereof, representing the initial
      loop carry value.

  Returns:
    The output from the final iteration of body_fun, of type ``a``.
  """
  # Flatten the pytree carry and wrap cond_fun/body_fun to work on the flat
  # representation.
  init_val_flat, in_tree = pytree_to_jaxtupletree(init_val)
  flat_body_fun, out_tree = pytree_fun_to_jaxtupletree_fun(lu.wrap_init(body_fun), (in_tree,))
  flat_cond_fun, _ = pytree_fun_to_jaxtupletree_fun(lu.wrap_init(cond_fun), (in_tree,))

  # Trace cond and body to jaxprs at the abstract level of the flat carry.
  # instantiate=True on the body forces its output to be fully abstract.
  carry_pval_flat = carry_aval, _ = _abstractify(init_val_flat)
  cond_jaxpr, cond_pval_out, cond_consts = pe.trace_to_jaxpr(flat_cond_fun, (carry_pval_flat,))
  body_jaxpr, body_pval_out, body_consts = pe.trace_to_jaxpr(flat_body_fun, (carry_pval_flat,), instantiate=True)
  carry_aval_out, _ = body_pval_out
  assert isinstance(carry_aval_out, core.AbstractValue)
  # The body's output aval must be compatible with (no more general than) the
  # input carry aval.
  assert carry_aval == core.lattice_join(carry_aval, carry_aval_out)

  cond_pv, cond_const = cond_pval_out
  if cond_pv is None:
    # cond_fun evaluates to a constant, so don't need to generate a while_loop
    if cond_const:
      # A constant-True condition would loop forever without producing
      # anything new.
      raise ValueError("infinite loop with no effects")
    else:
      # Constant-False condition: the loop body never runs.
      return init_val
  else:
    assert isinstance(cond_pv, core.AbstractValue)
    # The condition must be a scalar boolean for XLA While.
    if (not isinstance(cond_pv, ShapedArray) or cond_pv.shape
        or cond_pv.dtype != onp.bool_):
      msg = "while_loop cond_fun must return a scalar boolean, got {}."
      raise TypeError(msg.format(cond_pv))

    if out_tree() != in_tree:
      raise TypeError("body_fun input and output must have identical structure")
    out_flat = while_p.bind(
        init_val_flat, core.pack(cond_consts), core.pack(body_consts),
        aval_out=carry_aval_out, cond_jaxpr=cond_jaxpr, body_jaxpr=body_jaxpr)
    return build_tree(out_tree(), out_flat)
def while_loop(cond_fun, body_fun, init_val):
  """Call ``body_fun`` repeatedly in a loop while ``cond_fun`` is True.

  The type signature in brief is

  .. code-block:: haskell

    while_loop :: (a -> Bool) -> (a -> a) -> a -> a

  The semantics of ``while_loop`` are given by this Python implementation::

    def while_loop(cond_fun, body_fun, init_val):
      val = init_val
      while cond_fun(val):
        val = body_fun(val)
      return val

  Unlike that Python version, ``while_loop`` is a JAX primitive and is lowered
  to a single XLA While HLO. That makes it useful for reducing compilation
  times for jit-compiled functions, since native Python loop constructs in an
  ``@jit`` function are unrolled, leading to large XLA computations.

  Another difference from using Python-native loop constructs is that
  ``while_loop`` is not reverse-mode differentiable because XLA computations
  require static bounds on memory requirements.

  Args:
    cond_fun: function of type ``a -> Bool``.
    body_fun: function of type ``a -> a``.
    init_val: value of type ``a``, a type that can be a scalar, array, or any
      pytree (nested Python tuple/list/dict) thereof, representing the initial
      loop carry value.

  Returns:
    The output from the final iteration of body_fun, of type ``a``.
  """
  # Flatten the pytree carry and wrap cond_fun/body_fun to work on the flat
  # representation.
  init_val_flat, in_tree = pytree_to_jaxtupletree(init_val)
  flat_body_fun, out_tree = pytree_fun_to_jaxtupletree_fun(
      lu.wrap_init(body_fun), (in_tree, ))
  flat_cond_fun, _ = pytree_fun_to_jaxtupletree_fun(lu.wrap_init(cond_fun),
                                                    (in_tree, ))

  # Trace cond and body to jaxprs at the abstract level of the flat carry.
  # instantiate=True on the body forces its output to be fully abstract.
  carry_pval_flat = carry_aval, _ = _abstractify(init_val_flat)
  cond_jaxpr, cond_pval_out, cond_consts = pe.trace_to_jaxpr(
      flat_cond_fun, (carry_pval_flat, ))
  body_jaxpr, body_pval_out, body_consts = pe.trace_to_jaxpr(
      flat_body_fun, (carry_pval_flat, ), instantiate=True)
  carry_aval_out, _ = body_pval_out
  assert isinstance(carry_aval_out, core.AbstractValue)
  # The body's output aval must be compatible with (no more general than) the
  # input carry aval.
  assert carry_aval == core.lattice_join(carry_aval, carry_aval_out)

  cond_pv, cond_const = cond_pval_out
  if cond_pv is None:
    # cond_fun evaluates to a constant, so don't need to generate a while_loop
    if cond_const:
      # A constant-True condition would loop forever without producing
      # anything new.
      raise ValueError("infinite loop with no effects")
    else:
      # Constant-False condition: the loop body never runs.
      return init_val
  else:
    assert isinstance(cond_pv, core.AbstractValue)
    # The condition must be a scalar boolean for XLA While.
    if (not isinstance(cond_pv, ShapedArray) or cond_pv.shape
        or cond_pv.dtype != onp.bool_):
      msg = "while_loop cond_fun must return a scalar boolean, got {}."
      raise TypeError(msg.format(cond_pv))

  # We don't want to promote literal constants as loop arguments; there are
  # sometimes many of them. We pass tracers as loop arguments, but leave
  # nontracers as constants. We also sort the constants so the nontracers are
  # first.
  def split_tracers_and_nontracers(jaxpr, consts):
    # Partitions (constvar, const) pairs by whether the const must be carried
    # as a loop argument (tracers, DeviceArrays) or can stay a literal.
    tracer = []
    nontracer = []
    for x in zip(jaxpr.constvars, consts):
      # TODO(phawkins): We avoid treating DeviceArrays as constant literals so
      # we don't copy large arrays back to the host. We probably should relax
      # this and either always copy small constants, or opportunistically use
      # DeviceArray values for which we already know npy_value.
      not_literal_const = isinstance(x[1], (core.Tracer, xla.DeviceArray))
      (tracer if not_literal_const else nontracer).append(x)
    tracer_vars, tracer_consts = unzip2(tracer)
    nontracer_vars, nontracer_consts = unzip2(nontracer)
    return nontracer_vars + tracer_vars, nontracer_consts, tracer_consts

  # Reorder each jaxpr's constvars in place so nontracer consts come first,
  # matching the (nontracer literal, tracer argument) split used by while_p.
  cond_split = split_tracers_and_nontracers(cond_jaxpr, cond_consts)
  cond_jaxpr.constvars, cond_nontracer_consts, cond_tracer_consts = cond_split
  body_split = split_tracers_and_nontracers(body_jaxpr, body_consts)
  body_jaxpr.constvars, body_nontracer_consts, body_tracer_consts = body_split

  if out_tree() != in_tree:
    raise TypeError(
        "body_fun input and output must have identical structure")
  # Tracer consts are packed and passed as operands; nontracer consts travel
  # as opaque (non-traced) primitive parameters.
  out_flat = while_p.bind(
      init_val_flat, core.pack(cond_tracer_consts),
      core.pack(body_tracer_consts),
      cond_consts=lax._OpaqueParam(cond_nontracer_consts),
      body_consts=lax._OpaqueParam(body_nontracer_consts),
      aval_out=carry_aval_out, cond_jaxpr=cond_jaxpr, body_jaxpr=body_jaxpr)
  return build_tree(out_tree(), out_flat)