def scan(f, init, xs, length=None):
  """Scan a function over leading array axes while carrying along state.

  The type signature in brief is

  .. code-block:: haskell

    scan :: (c -> a -> (c, b)) -> c -> [a] -> (c, [b])

  where we use [t] here to denote the type t with an additional leading axis.
  That is, if t is an array type then [t] represents the type with an
  additional leading axis, and if t is a pytree (container) type with array
  leaves then [t] represents the type with the same pytree structure and
  corresponding leaves each with an additional leading axis.

  When both ``a`` and ``b`` are array types, the semantics of ``scan`` are
  given by this Python implementation::

    def scan(f, init, xs):
      carry = init
      ys = []
      for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
      return carry, np.stack(ys)

  Unlike that Python version, both ``a`` and ``b`` may be arbitrary pytree
  types, and so multiple arrays can be scanned over at once and produce
  multiple output arrays.

  Also unlike that Python version, ``scan`` is a JAX primitive and is lowered
  to a single XLA While HLO. That makes it useful for reducing compilation
  times for jit-compiled functions, since native Python loop constructs in an
  ``@jit`` function are unrolled, leading to large XLA computations.

  Args:
    f: a Python function to be scanned of type ``c -> a -> (c, b)``, meaning
      that ``f`` accepts two arguments where the first is a value of the loop
      carry and the second is a slice of ``xs`` along its leading axis, and
      that ``f`` returns a pair where the first element represents a new value
      for the loop carry and the second represents a slice of the output.
    init: an initial loop carry value of type ``c``, which can be a scalar,
      array, or any pytree (nested Python tuple/list/dict) thereof,
      representing the initial loop carry value. This value must have the
      same structure as the first element of the pair returned by ``f``.
    xs: the value of type ``[a]`` over which to scan along the leading axis,
      where ``[a]`` can be an array or any pytree (nested Python
      tuple/list/dict) thereof with consistent leading axis sizes.
    length: optional integer number of loop iterations. It must agree with
      the leading axis sizes of the leaves of ``xs``. It is required when
      ``xs`` has no leaves (e.g. ``xs=None``), which lets ``scan`` loop
      without any scanned inputs.

  Returns:
    A pair of type ``(c, [b])`` where the first element represents the final
    loop carry value and the second element represents the stacked outputs of
    the second output of ``f`` when scanned over the leading axis of the
    inputs.

  Raises:
    ValueError: if a leaf of ``xs`` has no leading axis (no ``shape`` or a
      0-d shape), if the leading axis sizes disagree with each other or with
      ``length``, or if ``xs`` has no leaves and ``length`` is not given.
    TypeError: if ``f`` does not return a pair.
  """
  init_flat, init_tree = tree_flatten(init)
  xs_flat, _ = tree_flatten(xs)
  in_flat, in_tree = tree_flatten((init, xs))

  # Every scanned leaf needs a leading axis. A missing ``shape`` and a 0-d
  # ``shape == ()`` both mean there is nothing to index along.
  unscannable = [x for x in xs_flat if not getattr(x, 'shape', ())]
  if unscannable:
    msg = "scan got value with no leading axis to scan over: {}."
    raise ValueError(msg.format(unscannable))

  leading_sizes = {x.shape[0] for x in xs_flat}
  if length is not None:
    length = int(length)
    if any(size != length for size in leading_sizes):
      msg = ("scan got `length` argument of {} which disagrees with leading "
             "axis sizes {}.")
      raise ValueError(msg.format(length, [x.shape[0] for x in xs_flat]))
  elif len(leading_sizes) == 1:
    length, = leading_sizes
  elif not leading_sizes:
    raise ValueError(
        "scan got no values to scan over and `length` not provided.")
  else:
    msg = "scan got values with different leading axis sizes: {}."
    raise ValueError(msg.format([x.shape[0] for x in xs_flat]))

  # Trace ``f`` on the carry avals plus one leading-axis slice of each xs leaf.
  carry_avals = tuple(_map(_abstractify, init_flat))
  x_shapes = [masking.padded_shape_as_value(x.shape[1:]) for x in xs_flat]
  x_dtypes = [x.dtype for x in xs_flat]
  x_avals = tuple(_map(ShapedArray, x_shapes, x_dtypes))
  jaxpr, consts, out_tree = _initial_style_jaxpr(f, in_tree,
                                                 carry_avals + x_avals)

  out_tree_children = out_tree.children()
  if len(out_tree_children) != 2:
    msg = "scan body output must be a pair, got {}."
    raise TypeError(msg.format(tree_unflatten(out_tree, jaxpr.out_avals)))
  # Extract the subtree and avals for the first element of the returned pair
  # (the new carry) and check that they match the initial carry.
  carry_tree_out = out_tree_children[0]
  _check_tree_and_avals("scan carry output and input",
                        carry_tree_out,
                        jaxpr.out_avals[:carry_tree_out.num_leaves],
                        init_tree, carry_avals)

  out = scan_p.bind(*itertools.chain(consts, in_flat),
                    forward=True, length=length, jaxpr=jaxpr,
                    num_consts=len(consts), num_carry=len(init_flat),
                    linear=(False,) * (len(consts) + len(in_flat)))
  return tree_unflatten(out_tree, out)
def _scan_harvest_rule(trace: HarvestTrace, *tracers, length, reverse, jaxpr,
                       num_consts, num_carry, linear, unroll):
  """Collects and injects values into/from the scan body.

  The original scan body ``jaxpr`` is re-run under ``harvest`` so that sows
  inside it are visible to the enclosing harvest context:

  * ``'clobber'`` sows: a matching plant from the context is closed over by
    the new body, so the same value is injected on every iteration. The
    per-iteration reaps are stacked by the scan and a single entry (the last
    one written) is kept.
  * ``'append'`` sows: a matching plant is scanned over as an extra ``xs``
    input, one leading-axis slice per iteration. The reaps are concatenated.
  * ``'strict'`` sows inside the body raise ``ValueError``.

  Args:
    trace: the active ``HarvestTrace``.
    *tracers: tracers for the scan operands, ordered as consts, initial
      carry, then xs.
    length, reverse, jaxpr, num_consts, num_carry, linear, unroll: the
      parameters of the intercepted ``scan_p`` equation.

  Returns:
    A list of tracers: the final carry values followed by the stacked ys.

  Raises:
    ValueError: if the scan body contains a sow in ``'strict'`` mode.
  """
  context = trace_util.get_dynamic_context(trace)
  settings = context.settings
  values = [t.val for t in tracers]
  consts, init, xs = jax_util.split_list(values, [num_consts, num_carry])

  active_sows = _find_sows(jaxpr, settings.tag)
  active_modes = [params['mode'] for params in active_sows]
  if any(mode == 'strict' for mode in active_modes):
    raise ValueError('Cannot use strict mode in a scan.')
  active_names = [params['name'] for params in active_sows]
  sow_modes = {name: mode for name, mode in zip(active_names, active_modes)}
  # Clobber plants are closed over (constant across iterations); append plants
  # are scanned over alongside ``xs``.
  carry_plants = {
      name: context.plants[name]
      for name in active_names
      if name in context.plants and sow_modes[name] == 'clobber'
  }
  xs_plants = {
      name: context.plants[name]
      for name in active_names
      if name in context.plants and sow_modes[name] == 'append'
  }

  def jaxpr_fun(carry, x):
    # Re-evaluate the original scan body on (consts, carry, x).
    body_out = jax_core.eval_jaxpr(jaxpr.jaxpr, jaxpr.literals,
                                   *(consts + carry + x))
    carry, y = jax_util.split_list(body_out, [num_carry])
    return carry, y

  harvest_body = harvest(
      jaxpr_fun,
      tag=settings.tag,
      allowlist=settings.allowlist,
      blocklist=settings.blocklist)

  def body(carry, x):
    x_plants, x_vals = x
    plants = {**carry_plants, **x_plants}
    (carry, y), reaps = harvest_body(plants, carry, x_vals)
    # Reaps are emitted as extra per-iteration outputs so the scan stacks them.
    return carry, (y, reaps)

  # Per-iteration avals: drop the leading (scanned) axis of every xs leaf,
  # including the leaves of the scanned-over append plants.
  xs_flat = tree_util.tree_leaves((xs_plants, xs))
  x_avals = []
  for x in xs_flat:
    x_aval = jax_core.get_aval(x)
    if x_aval is jax_core.abstract_unit:
      x_avals.append(x_aval)
    else:
      x_shape = masking.padded_shape_as_value(x.shape[1:])
      x_avals.append(abstract_arrays.ShapedArray(x_shape, x.dtype))
  x_avals = tuple(x_avals)
  init_avals = tuple(
      abstract_arrays.raise_to_shaped(jax_core.get_aval(a)) for a in init)
  in_flat, in_tree = tree_util.tree_flatten((init, (xs_plants, xs)))
  body_jaxpr, new_consts, out_tree = (
      jax.lax.lax_control_flow._initial_style_jaxpr(  # pylint: disable=protected-access
          body, in_tree, init_avals + x_avals))
  new_values = list(new_consts) + in_flat

  # Operand order is now (new_consts, init, xs_plants, xs). Rebuild ``linear``
  # to match, marking the new consts and the plants as non-linear.
  num_xs_plants = len(new_values) - len(init) - len(xs) - len(new_consts)
  remaining_linear = linear[num_consts:]
  new_linear = ((False,) * len(new_consts) + remaining_linear[:len(init)] +
                (False,) * num_xs_plants + remaining_linear[len(init):])
  assert len(new_linear) == len(new_values)

  outs = lax.scan_p.bind(
      *new_values,
      length=length,
      reverse=reverse,
      jaxpr=body_jaxpr,
      num_consts=len(new_consts),
      num_carry=num_carry,
      linear=new_linear,
      unroll=unroll)
  outs = safe_map(trace.pure, outs)
  carry, (ys, reaps) = tree_util.tree_unflatten(out_tree, outs)

  # Stacked scan outputs are stored in index order. With reverse=True the last
  # iteration executed is index 0, so that entry holds the final clobber write.
  last_written = 0 if reverse else -1
  out_reaps = {}
  for k, val in reaps.items():
    mode = sow_modes.get(k, 'strict')
    if mode == 'append':
      # NOTE(review): with reverse=True this concatenates in index order, not
      # execution order. Confirm that this is the intended append semantics.
      val = tree_util.tree_map(np.concatenate, val)
    elif mode == 'clobber':
      val = tree_util.tree_map(lambda x: x[last_written], val)
    out_reaps[k] = sow(val, tag=settings.tag, name=k, mode='strict')
  (carry, ys) = prim.tie_in(out_reaps, (carry, ys))
  return carry + ys