def partial_eval_by_shape(fn, input_spec, *args, **kwargs):
  """Lazily evaluate a function by using the shapes of the inputs.

  This function is similar to `jax.eval_shape` with the key difference that
  function outputs that can be computed without a concrete value of the
  inputs are returned as is instead of only the shape. See for example
  `module.init_by_shape` where this functionality is used to initialize a
  model without using input data or computation.

  Args:
    fn: the function to be lazily evaluated.
    input_spec: an iterable of shapes or (shape, dtype) tuples specifying the
      shape and type of the inputs. If unspecified the dtype is float32.
    *args: other arguments passed to the module's apply function
    **kwargs: keyword arguments passed to the module's apply function
  Returns:
    The function output: leaves that do not depend on the input values are
    returned as concrete values, the rest as `jax.ShapeDtypeStruct`s.
  """
  # output cannot be returned in lazy_create because jax.eval_shape will only
  # return the shape and dtype.
  # TODO(mattjj,jheek): use a public JAX API
  f = lambda *inputs: fn(*inputs, *args, **kwargs)
  input_structs = [_parse_spec(spec) for spec in input_spec]
  inputs_flat, in_tree = jax.tree_flatten(input_structs)
  f_flat, out_tree = jax.api_util.flatten_fun_nokwargs(lu.wrap_init(f), in_tree)
  in_pvals = [pe.PartialVal.unknown(jax.ShapedArray(x.shape, x.dtype))
              for x in inputs_flat]
  if config.omnistaging_enabled:
    _, out_pvals, _ = pe.trace_to_jaxpr(f_flat, in_pvals)
  else:
    with jax.core.initial_style_staging():
      _, out_pvals, _ = pe.trace_to_jaxpr(f_flat, in_pvals, stage_out=True)
  out_flat = [const if pv is None else jax.ShapeDtypeStruct(pv.shape, pv.dtype)
              for pv, const in out_pvals]
  return jax.tree_unflatten(out_tree(), out_flat)

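# --- Example (illustrative sketch) -----------------------------------------
# A minimal usage sketch of `partial_eval_by_shape` as defined above, assuming
# the surrounding module's helpers (`_parse_spec`, `config`) are importable.
# `_example_model` and its spec are hypothetical. Outputs that only depend on
# input *shapes* come back as concrete values; value-dependent outputs come
# back as `jax.ShapeDtypeStruct` placeholders.
import jax.numpy as jnp

def _example_model(x):
  kernel_shape = (x.shape[-1], 16)
  num_params = kernel_shape[0] * kernel_shape[1]  # shape-only computation
  y = x @ jnp.zeros(kernel_shape)                 # depends on x's values
  return num_params, y

num_params, y_struct = partial_eval_by_shape(
    _example_model, [((8, 4), jnp.float32)])
# num_params == 64 (a plain int); y_struct is ShapeDtypeStruct((8, 16), float32).
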
def wrapped(*args, **kwargs):
  fun = lu.wrap_init(f, kwargs)
  flat_args, in_tree = tree_util.tree_flatten(args)
  flat_fun, out_tree = api_util.flatten_fun_nokwargs(fun, in_tree)
  flat_avals = safe_map(get_shaped_aval, flat_args)
  pvals = [pe.PartialVal((aval, jax_core.unit)) for aval in flat_avals]
  jaxpr, out_pvals, consts = pe.trace_to_jaxpr(
      flat_fun, pvals, instantiate=True, stage_out=True,
      trace_type=pe.StagingJaxprTrace)
  out_avals = [pval.get_aval() for pval in out_pvals]
  typed_jaxpr = jax_core.TypedJaxpr(jaxpr, consts, flat_avals, out_avals)
  return typed_jaxpr, (in_tree, out_tree())

def transposed(*args):
  in_primals, out_cts = tree_unflatten(treedef, args)
  in_pvals = [pe.PartialVal.unknown(x.aval) if ad.is_undefined_primal(x) else
              pe.PartialVal.known(x) for x in in_primals]
  primal_fun = lu.wrap_init(partial(core.eval_jaxpr, jaxpr, ()))
  tangent_jaxpr, _, consts = pe.trace_to_jaxpr(primal_fun, in_pvals, False)
  dummy_args = [ad.UndefinedPrimal(v.aval) for v in tangent_jaxpr.invars]
  in_cts_ = ad.backward_pass(tangent_jaxpr, reduce_axes, False, consts,
                             dummy_args, out_cts)
  in_cts, cell.treedef = tree_flatten(in_cts_)
  return in_cts

def scan_fn(broadcast_in, init, *args):
  xs = jax.tree_multimap(transpose_to_front, in_axes, args)

  def body_fn(c, xs, init_mode=False):
    # inject constants
    xs = jax.tree_multimap(lambda ax, arg, x: (arg if ax is broadcast else x),
                           in_axes, args, xs)
    broadcast_out, c, ys = fn(broadcast_in, c, *xs)

    if init_mode:
      ys = jax.tree_multimap(lambda ax, y: (y if ax is broadcast else ()),
                             out_axes, ys)
      return broadcast_out, ys
    else:
      ys = jax.tree_multimap(lambda ax, y: (() if ax is broadcast else y),
                             out_axes, ys)
      return c, ys

  broadcast_body = functools.partial(body_fn, init_mode=True)

  carry_pvals = jax.tree_map(
      lambda x: pe.PartialVal.unknown(
          jax.ShapedArray(jnp.shape(x), jnp.result_type(x))),
      init)
  scan_pvals = jax.tree_map(
      lambda x: pe.PartialVal.unknown(
          jax.ShapedArray(jnp.shape(x)[1:], jnp.result_type(x))),
      xs)
  input_pvals = (carry_pvals, scan_pvals)
  in_pvals, in_tree = jax.tree_flatten(input_pvals)
  f_flat, out_tree = jax.api_util.flatten_fun_nokwargs(
      lu.wrap_init(broadcast_body), in_tree)
  _, out_pvals, _ = pe.trace_to_jaxpr(f_flat, in_pvals)

  out_flat = []
  for pv, const in out_pvals:
    if pv is not None:
      raise ValueError(
          'broadcasted variable has a data dependency on the scan body.')
    out_flat.append(const)
  broadcast_in, constants_out = jax.tree_unflatten(out_tree(), out_flat)

  c, ys = lax.scan(body_fn, init, xs, length=length, reverse=reverse)
  ys = jax.tree_multimap(transpose_from_front, out_axes, ys)
  ys = jax.tree_multimap(
      lambda ax, const, y: (const if ax is broadcast else y),
      out_axes, constants_out, ys)
  return broadcast_in, c, ys

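# --- Example (illustrative sketch) -----------------------------------------
# The broadcast check in `scan_fn` above relies on a property of
# `pe.trace_to_jaxpr`: an output that can be computed without the unknown
# inputs comes back as a known PartialVal (pv is None, const holds the value).
# A small sketch of that behaviour, assuming the same older internal JAX APIs
# (jax.tree_flatten, jax.api_util, jax.interpreters.partial_eval) used by the
# surrounding snippets; these are not public and have changed across versions.
import jax
import jax.numpy as jnp
from jax import linear_util as lu
from jax.interpreters import partial_eval as pe

def _h(x):
  return x * 2.0, 7.0  # first output depends on x, second does not

flat_args, in_tree = jax.tree_flatten((jnp.zeros((3,)),))
f_flat, out_tree = jax.api_util.flatten_fun_nokwargs(lu.wrap_init(_h), in_tree)
in_pvals = [pe.PartialVal.unknown(jax.ShapedArray(x.shape, x.dtype))
            for x in flat_args]
_, out_pvals, _ = pe.trace_to_jaxpr(f_flat, in_pvals)
# Each element of out_pvals unpacks as (pv, const): the x-dependent output has
# an abstract pv, while the constant output has pv None and const == 7.0.
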
def _get_jax_objects(function, args, kwargs):
  # Set up function for transformation
  wrapped_function = j_linear_util.wrap_init(function)
  # Flatten input arguments
  jax_arguments, in_tree = j_tree_util.tree_flatten((args, kwargs))
  # Transform function to accept flat arguments and return a flat list result
  jaxtree_function, _ = j_api_util.flatten_fun(wrapped_function, in_tree)
  # Abstract and partially evaluate the flat arguments
  partial_values = j_util.safe_map(_get_partial_value, jax_arguments)
  # Trace function into Jaxpr
  jaxpr, _, constants = ji_partial_eval.trace_to_jaxpr(
      jaxtree_function, partial_values)
  result = (jaxpr, constants)
  return result

def wrapped(*args, **kwargs):
  fun = lu.wrap_init(f, kwargs)
  flat_args, in_tree = tree_util.tree_flatten(args)
  flat_fun, out_tree = api_util.flatten_fun_nokwargs(fun, in_tree)
  flat_avals = safe_map(get_shaped_aval, flat_args)
  if dynamic:
    jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, flat_avals)
  else:
    pvals = [pe.PartialVal.unknown(aval) for aval in flat_avals]
    jaxpr, _, consts = pe.trace_to_jaxpr(flat_fun, pvals, instantiate=True)
  typed_jaxpr = jax_core.ClosedJaxpr(jaxpr, consts)
  return typed_jaxpr, (in_tree, out_tree())

def wrapped(*args, **kwargs):
  fun = lu.wrap_init(f, kwargs)
  flat_args, in_tree = tree_util.tree_flatten(args)
  flat_fun, out_tree = api_util.flatten_fun_nokwargs(fun, in_tree)
  flat_avals = safe_map(get_shaped_aval, flat_args)
  if not jax.config.omnistaging_enabled:
    raise ValueError('Oryx must be used with JAX omnistaging enabled.')
  if dynamic:
    jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, flat_avals)
  else:
    pvals = [pe.PartialVal((aval, jax_core.unit)) for aval in flat_avals]
    jaxpr, _, consts = pe.trace_to_jaxpr(flat_fun, pvals, instantiate=True)
  typed_jaxpr = jax_core.ClosedJaxpr(jaxpr, consts)
  return typed_jaxpr, (in_tree, out_tree())

def closure_convert(fun, in_tree, in_avals):
  in_pvals = [pe.PartialVal.unknown(aval) for aval in in_avals]
  wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
  with core.initial_style_staging():
    jaxpr, out_pvals, consts = pe.trace_to_jaxpr(
        wrapped_fun, in_pvals, instantiate=True, stage_out=False)
  out_tree = out_tree()
  num_consts = len(consts)

  def converted_fun(y, t, *consts_args):
    consts, args = split_list(consts_args, [num_consts])
    all_args, in_tree2 = tree_flatten((y, t, *args))
    assert in_tree == in_tree2
    out_flat = core.eval_jaxpr(jaxpr, consts, *all_args)
    return tree_unflatten(out_tree, out_flat)

  return converted_fun, consts

def linearize(traceable, *primals, **kwargs):
  has_aux = kwargs.pop('has_aux', False)
  if not has_aux:
    jvpfun = jvp(traceable)
  else:
    jvpfun, aux = jvp(traceable, has_aux=True)

  in_pvals = (tuple(pe.PartialVal.known(p) for p in primals)
              + tuple(pe.PartialVal.unknown(get_aval(p).at_least_vspace())
                      for p in primals))
  _, in_tree = tree_flatten(((primals, primals), {}))
  jvpfun_flat, out_tree = flatten_fun(jvpfun, in_tree)
  jaxpr, out_pvals, consts = pe.trace_to_jaxpr(jvpfun_flat, in_pvals)
  out_primals_pvals, out_tangents_pvals = tree_unflatten(out_tree(), out_pvals)
  assert all(out_primal_pval.is_known() for out_primal_pval in out_primals_pvals)
  _, out_primals_consts = unzip2(out_primals_pvals)
  jaxpr.invars = jaxpr.invars[len(primals):]
  jaxpr.outvars = jaxpr.outvars[len(out_primals_pvals):]
  if not has_aux:
    return out_primals_consts, out_tangents_pvals, jaxpr, consts
  else:
    return out_primals_consts, out_tangents_pvals, jaxpr, consts, aux()

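# --- Example (illustrative sketch) -----------------------------------------
# The known/unknown PartialVal split in `linearize` above is what powers the
# public `jax.linearize` API. A small sketch of the observable behaviour,
# using only public JAX calls:
import jax
import jax.numpy as jnp

def _g(x):
  return jnp.sin(x) * x

y, g_lin = jax.linearize(_g, 2.0)  # primal output and a linear (JVP) map
tangent_out = g_lin(1.0)           # JVP at x=2.0 applied to tangent 1.0
# tangent_out == jnp.cos(2.0) * 2.0 + jnp.sin(2.0)
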
def _tscan(f, a, bs, fields=(0,)):
  """
  Works as jax.lax.scan but has additional `fields` argument to select only
  necessary fields from `a`'s structure. Defaults to selecting only the first
  field. Other fields will be filled by None.
  """
  # Note: code is copied and modified from lax.scan implementation in
  # [JAX](https://github.com/google/jax) to support the additional `fields`
  # arg. Original code has the following copyright:
  #
  # Copyright 2018 Google LLC
  #
  # Licensed under the Apache License, Version 2.0 (the "License")

  # convert pytree to flat jaxtuple
  a, a_tree = pytree_to_flatjaxtuple(a)
  bs, b_tree = pytree_to_flatjaxtuple(bs)
  fields, _ = pytree_to_flatjaxtuple(fields)
  f, out_tree = pytree_fun_to_flatjaxtuple_fun(wrap_init(f), (a_tree, b_tree))

  # convert arrays to abstract values
  a_aval, _ = lax._abstractify(a)
  bs_aval, _ = lax._abstractify(bs)
  # convert bs to b
  b_aval = core.AbstractTuple(
      [ShapedArray(b.shape[1:], b.dtype) for b in bs_aval])

  # convert abstract values to partial values (?) then evaluate to get jaxpr
  a_pval = partial_eval.PartialVal((a_aval, core.unit))
  b_pval = partial_eval.PartialVal((b_aval, core.unit))
  jaxpr, pval_out, consts = partial_eval.trace_to_jaxpr(f, (a_pval, b_pval))
  aval_out, _ = pval_out
  consts = core.pack(consts)

  out = tscan_p.bind(a, bs, fields, consts, aval_out=aval_out, jaxpr=jaxpr)
  return tree_unflatten(out_tree(), out)

def closure_convert(fun, in_tree, in_avals):
  in_pvals = [pe.PartialVal.unknown(aval) for aval in in_avals]
  wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
  with core.initial_style_staging():
    jaxpr, out_pvals, consts = pe.trace_to_jaxpr(
        wrapped_fun, in_pvals, instantiate=True, stage_out=False)
  out_tree = out_tree()

  # We only want to closure convert for constants with respect to which we're
  # differentiating. As a proxy for that, we hoist consts with float dtype.
  # TODO(mattjj): revise this approach
  is_float = lambda c: dtypes.issubdtype(dtypes.dtype(c), jnp.inexact)
  (closure_consts, hoisted_consts), merge = partition_list(is_float, consts)
  num_consts = len(hoisted_consts)

  def converted_fun(y, t, *hconsts_args):
    hoisted_consts, args = split_list(hconsts_args, [num_consts])
    consts = merge(closure_consts, hoisted_consts)
    all_args, in_tree2 = tree_flatten((y, t, *args))
    assert in_tree == in_tree2
    out_flat = core.eval_jaxpr(jaxpr, consts, *all_args)
    return tree_unflatten(out_tree, out_flat)

  return converted_fun, hoisted_consts

def _make_typed_jaxpr(traceable, in_avals):
  pvals = [pe.PartialVal((aval, core.unit)) for aval in in_avals]
  jaxpr, pvals_out, consts = pe.trace_to_jaxpr(traceable, pvals,
                                               instantiate=True)
  out_avals, _ = unzip2(pvals_out)
  return core.TypedJaxpr(jaxpr, consts, in_avals, out_avals)

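# --- Example (illustrative sketch) -----------------------------------------
# `_make_typed_jaxpr` above traces a callable at abstract inputs only. The
# public `jax.make_jaxpr` exposes essentially the same capability (its return
# type has varied across JAX versions); a sketch:
import jax
import jax.numpy as jnp

def _f(x, y):
  return jnp.dot(x, y) + 1.0

closed_jaxpr = jax.make_jaxpr(_f)(jnp.ones((2, 3)), jnp.ones((3,)))
# closed_jaxpr.jaxpr lists the traced primitive equations; closed_jaxpr.consts
# holds constants captured during tracing.
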
def while_loop(cond_fun, body_fun, init_val):
  """Call ``body_fun`` repeatedly in a loop while ``cond_fun`` is True.

  The type signature in brief is

  .. code-block:: haskell

    while_loop :: (a -> Bool) -> (a -> a) -> a -> a

  The semantics of ``while_loop`` are given by this Python implementation::

    def while_loop(cond_fun, body_fun, init_val):
      val = init_val
      while cond_fun(val):
        val = body_fun(val)
      return val

  Unlike that Python version, ``while_loop`` is a JAX primitive and is lowered
  to a single XLA While HLO. That makes it useful for reducing compilation
  times for jit-compiled functions, since native Python loop constructs in an
  ``@jit`` function are unrolled, leading to large XLA computations.

  Another difference from using Python-native loop constructs is that
  ``while_loop`` is not reverse-mode differentiable because XLA computations
  require static bounds on memory requirements.

  Args:
    cond_fun: function of type ``a -> Bool``.
    body_fun: function of type ``a -> a``.
    init_val: value of type ``a``, a type that can be a scalar, array, or any
      pytree (nested Python tuple/list/dict) thereof, representing the initial
      loop carry value.

  Returns:
    The output from the final iteration of body_fun, of type ``a``.
  """
  init_val_flat, in_tree = pytree_to_jaxtupletree(init_val)
  flat_body_fun, out_tree = pytree_fun_to_jaxtupletree_fun(
      lu.wrap_init(body_fun), (in_tree,))
  flat_cond_fun, _ = pytree_fun_to_jaxtupletree_fun(
      lu.wrap_init(cond_fun), (in_tree,))

  carry_pval_flat = carry_aval, _ = _abstractify(init_val_flat)
  cond_jaxpr, cond_pval_out, cond_consts = pe.trace_to_jaxpr(
      flat_cond_fun, (carry_pval_flat,))
  body_jaxpr, body_pval_out, body_consts = pe.trace_to_jaxpr(
      flat_body_fun, (carry_pval_flat,), instantiate=True)
  carry_aval_out, _ = body_pval_out
  assert isinstance(carry_aval_out, core.AbstractValue)
  assert carry_aval == core.lattice_join(carry_aval, carry_aval_out)

  cond_pv, cond_const = cond_pval_out
  if cond_pv is None:
    # cond_fun evaluates to a constant, so don't need to generate a while_loop
    if cond_const:
      raise ValueError("infinite loop with no effects")
    else:
      return init_val
  else:
    assert isinstance(cond_pv, core.AbstractValue)
    if (not isinstance(cond_pv, ShapedArray) or cond_pv.shape
        or cond_pv.dtype != onp.bool_):
      msg = "while_loop cond_fun must return a scalar boolean, got {}."
      raise TypeError(msg.format(cond_pv))

  if out_tree() != in_tree:
    raise TypeError("body_fun input and output must have identical structure")
  out_flat = while_p.bind(
      init_val_flat, core.pack(cond_consts), core.pack(body_consts),
      aval_out=carry_aval_out, cond_jaxpr=cond_jaxpr, body_jaxpr=body_jaxpr)
  return build_tree(out_tree(), out_flat)

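# --- Example (illustrative sketch) -----------------------------------------
# Usage matching the `while_loop` docstring above, via the public
# `jax.lax.while_loop`: sum the integers 0..9 with a carry of (i, total).
from jax import lax

def _sum_below_10():
  def cond(carry):
    i, _ = carry
    return i < 10

  def body(carry):
    i, total = carry
    return i + 1, total + i

  return lax.while_loop(cond, body, (0, 0))

_, total = _sum_below_10()  # total == 45
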
def _make_typed_jaxpr(traceable, in_avals):
  pvals = [pe.PartialVal((aval, core.unit)) for aval in in_avals]
  jaxpr, pval_out, consts = pe.trace_to_jaxpr(traceable, pvals,
                                              instantiate=True)
  out_aval, _ = pval_out
  assert isinstance(out_aval, core.AbstractValue)
  return core.TypedJaxpr(jaxpr, consts, in_avals, out_aval)

def scan(f, init, xs):
  """Scan a function over leading array axes while carrying along state.

  The type signature in brief is

  .. code-block:: haskell

    scan :: (c -> a -> (c, b)) -> c -> [a] -> (c, [b])

  where we use [t] here to denote the type t with an additional leading axis.
  That is, if t is an array type then [t] represents the type with an
  additional leading axis, and if t is a pytree (container) type with array
  leaves then [t] represents the type with the same pytree structure and
  corresponding leaves each with an additional leading axis.

  When both ``a`` and ``b`` are array types, the semantics of ``scan`` are
  given by this Python implementation::

    def scan(f, init, xs):
      carry = init
      ys = []
      for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
      return carry, np.stack(ys)

  Unlike that Python version, both ``a`` and ``b`` may be arbitrary pytree
  types, and so multiple arrays can be scanned over at once and produce
  multiple output arrays.

  Also unlike that Python version, ``scan`` is a JAX primitive and is lowered
  to a single XLA While HLO. That makes it useful for reducing compilation
  times for jit-compiled functions, since native Python loop constructs in an
  ``@jit`` function are unrolled, leading to large XLA computations.

  Args:
    f: a Python function to be scanned of type ``c -> a -> (c, b)``, meaning
      that ``f`` accepts two arguments where the first is a value of the loop
      carry and the second is a slice of ``xs`` along its leading axis, and
      that ``f`` returns a pair where the first element represents a new value
      for the loop carry and the second represents a slice of the output.
    init: an initial loop carry value of type ``c``, which can be a scalar,
      array, or any pytree (nested Python tuple/list/dict) thereof,
      representing the initial loop carry value.
    xs: the value of type ``[a]`` over which to scan along the leading axis,
      where ``[a]`` can be an array or any pytree (nested Python
      tuple/list/dict) thereof with consistent leading axis sizes.

  Returns:
    A pair of type ``(c, [b])`` where the first element represents the final
    loop carry value and the second element represents the stacked outputs of
    the second output of ``f`` when scanned over the leading axis of the
    inputs.
  """
  (init, xs), in_trees = unzip2(map(pytree_to_jaxtupletree, (init, xs)))
  f, out_tree = pytree_fun_to_jaxtupletree_fun(lu.wrap_init(f), in_trees)
  carry_pval = carry_aval, _ = _abstractify(init)
  xs_aval, _ = _abstractify(xs)
  x_aval = _demote_aval_rank(xs_aval)
  x_pval = pe.PartialVal((x_aval, core.unit))
  jaxpr, pval_out, consts = pe.trace_to_jaxpr(
      f, (carry_pval, x_pval), instantiate=True)
  pv_out, const_out = pval_out
  assert isinstance(pv_out, core.AbstractValue) and const_out == core.unit
  if not isinstance(pv_out, core.AbstractTuple) or len(pv_out) != 2:
    msg = ("scanned function must have signature `c -> a -> (c, b)`, but the "
           "output was not a pair: got type {}.")
    raise TypeError(msg.format(pv_out))
  carry_aval_out, y_aval = pv_out
  if carry_aval != carry_aval_out:
    msg = ("scanned function carry output does not match carry input: "
           "input carry is {} and output carry is {}.")
    raise TypeError(msg.format(carry_aval, carry_aval_out))
  lifted_jaxpr = pe._closure_convert_jaxpr(jaxpr)
  consts_aval, _ = _abstractify(core.pack(consts))
  in_avals = (consts_aval, carry_aval, x_aval)
  out_aval = core.AbstractTuple((carry_aval, y_aval))
  jaxpr = core.TypedJaxpr(lifted_jaxpr, (), in_avals, out_aval)
  length = _leading_dim_size(xs)
  out = scan_p.bind(core.pack(consts), init, xs,
                    forward=True, length=length, jaxpr=jaxpr)
  return build_tree(out_tree(), out)

def inv_backward_pass(jaxpr: core.Jaxpr, consts, primals_in, primals_out,
                      cotangents_in):
  if all(type(ct) is ad.Zero for ct in cotangents_in):
    return map(lambda v: ad.Zero(v.aval), jaxpr.invars)

  def write_cotangent(v, ct):
    # assert v not in primal_env
    if ct is not None and type(v) is not Literal:
      ct_env[v] = ad.add_tangents(ct_env[v], ct) if v in ct_env else ct

  def read_cotangent(v):
    return ct_env.get(v, ad.Zero(v.aval))

  def read_primal(v):
    if type(v) is Literal:
      return v.val
    else:
      return primal_env.get(v, ad.UndefinedPrimal(v.aval))

  def write_primal(v, val):
    if type(v) is Literal:
      return
    primal_env.setdefault(v, val)

  # Invert while computing cotangents
  ct_env: Dict[Any, Any] = {}
  primal_env: Dict[Any, Any] = {}
  write_primal(core.unitvar, core.unit)
  map(write_primal, jaxpr.invars, primals_in)
  map(write_primal, jaxpr.outvars, primals_out)
  map(write_primal, jaxpr.constvars, consts)
  map(write_cotangent, jaxpr.outvars, cotangents_in)
  for eqn in jaxpr.eqns[::-1]:
    primals_in = map(read_primal, eqn.invars)
    primals_out = map(read_primal, eqn.outvars)
    cts_in = map(read_cotangent, eqn.outvars)
    should_invert = any(type(primal) is not ad.UndefinedPrimal
                        for primal in primals_out)
    should_vjp = any(type(ct) is not ad.Zero for ct in cts_in)
    assert not eqn.primitive.call_primitive

    # Skip primal equations that are only jvp coefficients and don't affect
    # primal outputs.
    if not should_invert and not should_vjp:
      continue

    def abstract(value):
      return raise_to_shaped(value.aval if ad.is_undefined_primal(value)
                             else get_aval(value))

    # Get the ivjp_jaxpr
    if eqn.primitive is custom_ivjp_p:
      ivjp_jaxpr = eqn.params['ivjp_jaxpr']
    else:
      if eqn.primitive in primitive_ivjps:
        complete_ivjp = lu.wrap_init(primitive_ivjps[eqn.primitive])
      else:
        complete_ivjp = lu.wrap_init(partial(
            synthesize_ivjp, eqn, map(ad.is_undefined_primal, primals_in)))
      _, in_tree = tree_flatten(
          tuple(map(abstract, x)
                for x in (primals_in, primals_out, primals_out)))
      complete_ivjp_flat, _ = flatten_fun_nokwargs(complete_ivjp, in_tree)

      in_avals = map(abstract, primals_in + primals_out + primals_out)
      # TODO: Actually we do know some of the inputs, because they might
      # be literals!
      ivjp_jaxpr, out_pvals, _ = pe.trace_to_jaxpr(
          complete_ivjp_flat, map(pe.PartialVal.unknown, in_avals),
          instantiate=True)
      assert not ivjp_jaxpr.constvars  # That might happen some time, but don't bother until then
      ivjp_jaxpr = core.ClosedJaxpr(ivjp_jaxpr, [])

    # Once we know what the ivjp can do exactly, we have to isolate the part
    # we are actually able to compute with the values we have at hand.
    num_inputs = len(eqn.invars)
    unknowns = (map(ad.is_undefined_primal, primals_in) +
                map(ad.is_undefined_primal, primals_out) +
                [False] * len(cts_in))
    jaxpr_known, jaxpr_unknown, out_unknowns = pe.partial_eval_jaxpr(  # type: ignore
        ivjp_jaxpr, unknowns, instantiate=False)  # type: ignore
    unknown_rec_primals_in, unknown_cotangents = split_list(
        out_unknowns, [num_inputs])
    # Make sure we're able to compute all cotangents. We don't really care if
    # we can reconstruct the primals or not, although failure to do so might
    # result in failing to compute cotangents later.
    assert not any(unknown_cotangents)
    # Remove residual outputs -- we won't be computing the unknown jaxpr anyway.
    num_outputs = len(jaxpr_unknown.jaxpr.outvars)
    jaxpr_known.jaxpr.outvars = jaxpr_known.jaxpr.outvars[:num_outputs]
    # TODO: We could drop the outputs that correspond to primals that we
    # already know. This only matters in eager mode, so leaving it out for now...
    ivjp = core.jaxpr_as_fun(jaxpr_known)
    rec_primals_in, cts_out = split_list(
        ivjp(*primals_in, *primals_out, *cts_in), [num_inputs])
    # Unknown rec_primals_in are core.units, so we have to replace them
    # with UnknownPrimals because that's what write_primal will ignore.
    rec_primals_in = [prev if unknown else rec
                      for prev, rec, unknown
                      in zip(primals_in, rec_primals_in, unknown_rec_primals_in)]
    map(write_primal, eqn.invars, rec_primals_in)
    map(write_cotangent, [v for v in eqn.invars if type(v) is not Literal],
        cts_out)

  # NOTE: We keep the cotangents associated with primal variables, while the
  # contract of a transpose is to return them in positions associated with
  # tangent variables, which is what causes this whole confusion.
  return map(read_cotangent, jaxpr.invars)

def DISABLED_test_print_jaxpr_compound(self):
  # TODO(dougalm): figure out what jaxpr-tracing api to expose and re-enable
  pv = pe.PartialVal((ShapedArray((2, 3), onp.float32), core.unit))
  print(pe.trace_to_jaxpr(fun_with_call_closure, (pv,))[0])

def while_loop(cond_fun, body_fun, init_val):
  """Call ``body_fun`` repeatedly in a loop while ``cond_fun`` is True.

  The type signature in brief is

  .. code-block:: haskell

    while_loop :: (a -> Bool) -> (a -> a) -> a -> a

  The semantics of ``while_loop`` are given by this Python implementation::

    def while_loop(cond_fun, body_fun, init_val):
      val = init_val
      while cond_fun(val):
        val = body_fun(val)
      return val

  Unlike that Python version, ``while_loop`` is a JAX primitive and is lowered
  to a single XLA While HLO. That makes it useful for reducing compilation
  times for jit-compiled functions, since native Python loop constructs in an
  ``@jit`` function are unrolled, leading to large XLA computations.

  Another difference from using Python-native loop constructs is that
  ``while_loop`` is not reverse-mode differentiable because XLA computations
  require static bounds on memory requirements.

  Args:
    cond_fun: function of type ``a -> Bool``.
    body_fun: function of type ``a -> a``.
    init_val: value of type ``a``, a type that can be a scalar, array, or any
      pytree (nested Python tuple/list/dict) thereof, representing the initial
      loop carry value.

  Returns:
    The output from the final iteration of body_fun, of type ``a``.
  """
  init_val_flat, in_tree = pytree_to_jaxtupletree(init_val)
  flat_body_fun, out_tree = pytree_fun_to_jaxtupletree_fun(
      lu.wrap_init(body_fun), (in_tree,))
  flat_cond_fun, _ = pytree_fun_to_jaxtupletree_fun(
      lu.wrap_init(cond_fun), (in_tree,))

  carry_pval_flat = carry_aval, _ = _abstractify(init_val_flat)
  cond_jaxpr, cond_pval_out, cond_consts = pe.trace_to_jaxpr(
      flat_cond_fun, (carry_pval_flat,))
  body_jaxpr, body_pval_out, body_consts = pe.trace_to_jaxpr(
      flat_body_fun, (carry_pval_flat,), instantiate=True)
  carry_aval_out, _ = body_pval_out
  assert isinstance(carry_aval_out, core.AbstractValue)
  assert carry_aval == core.lattice_join(carry_aval, carry_aval_out)

  cond_pv, cond_const = cond_pval_out
  if cond_pv is None:
    # cond_fun evaluates to a constant, so don't need to generate a while_loop
    if cond_const:
      raise ValueError("infinite loop with no effects")
    else:
      return init_val
  else:
    assert isinstance(cond_pv, core.AbstractValue)
    if (not isinstance(cond_pv, ShapedArray) or cond_pv.shape
        or cond_pv.dtype != onp.bool_):
      msg = "while_loop cond_fun must return a scalar boolean, got {}."
      raise TypeError(msg.format(cond_pv))

  # We don't want to promote literal constants as loop arguments; there are
  # sometimes many of them. We pass tracers as loop arguments, but leave
  # nontracers as constants. We also sort the constants so the nontracers are
  # first.
  def split_tracers_and_nontracers(jaxpr, consts):
    tracer = []
    nontracer = []
    for x in zip(jaxpr.constvars, consts):
      # TODO(phawkins): We avoid treating DeviceArrays as constant literals so
      # we don't copy large arrays back to the host. We probably should relax
      # this and either always copy small constants, or opportunistically use
      # DeviceArray values for which we already know npy_value.
      not_literal_const = isinstance(x[1], (core.Tracer, xla.DeviceArray))
      (tracer if not_literal_const else nontracer).append(x)
    tracer_vars, tracer_consts = unzip2(tracer)
    nontracer_vars, nontracer_consts = unzip2(nontracer)
    return nontracer_vars + tracer_vars, nontracer_consts, tracer_consts

  cond_split = split_tracers_and_nontracers(cond_jaxpr, cond_consts)
  cond_jaxpr.constvars, cond_nontracer_consts, cond_tracer_consts = cond_split
  body_split = split_tracers_and_nontracers(body_jaxpr, body_consts)
  body_jaxpr.constvars, body_nontracer_consts, body_tracer_consts = body_split

  if out_tree() != in_tree:
    raise TypeError("body_fun input and output must have identical structure")
  out_flat = while_p.bind(
      init_val_flat, core.pack(cond_tracer_consts),
      core.pack(body_tracer_consts),
      cond_consts=lax._OpaqueParam(cond_nontracer_consts),
      body_consts=lax._OpaqueParam(body_nontracer_consts),
      aval_out=carry_aval_out, cond_jaxpr=cond_jaxpr, body_jaxpr=body_jaxpr)
  return build_tree(out_tree(), out_flat)

def inner(scope_fn, repack_fn, variable_groups_xs, rng_groups_xs):
  # split rngs
  split_fn = lambda rng: random.split(rng, length)

  broadcast_rngs_xs = []
  scan_rngs_xs = []
  for rng_groups in rng_groups_xs:
    broadcast_rngs_xs.append(
        tuple(rng_group for rng_group, split in zip(rng_groups, rng_splits)
              if not split))
    scan_rngs_xs.append(
        tuple(jax.tree_map(split_fn, rng_group)
              for rng_group, split in zip(rng_groups, rng_splits) if split))

  def body(carry, xs, init_mode=False):
    carry_vars_xs, c = carry
    scan_vars_xs, scan_rngs_xs, x = xs
    variable_groups_xs = combine(scan_vars_xs, carry_vars_xs,
                                 broadcast_vars_xs)
    rng_groups_xs = []
    for broadcast_rngs, scan_rngs in zip(broadcast_rngs_xs, scan_rngs_xs):
      rng_groups_xs.append(broadcast_rngs + scan_rngs)
    scope = scope_fn(variable_groups_xs, rng_groups_xs)
    carry, y = fn(scope, c, x)
    out_vars = repack_fn(scope)
    scan_vars_xs, carry_vars_out_xs, broadcast_vars_out_xs = split(out_vars, 1)

    # TODO(jheek) more informative error check
    def check_shapes(c_in, c_out):
      if not isinstance(c_in, jnp.ndarray) or not isinstance(c_out, jnp.ndarray):
        return
      if jnp.shape(c_in) != jnp.shape(c_out) or jnp.dtype(c_in) != jnp.dtype(c_out):
        raise ValueError()
    try:
      jax.tree_multimap(check_shapes, carry_vars_xs, carry_vars_out_xs)
    except ValueError:
      raise ValueError(
          'carry variables must have the same shape and dtype before and '
          'after scan.')

    if init_mode:
      return broadcast_vars_out_xs
    else:
      return (carry_vars_out_xs, carry), (scan_vars_xs, y)

  broadcast_body = functools.partial(body, init_mode=True)

  scan_vars_xs, carry_vars_xs, broadcast_vars_xs = split(variable_groups_xs, 0)
  carry0 = (carry_vars_xs, init_carry)
  xxs = (scan_vars_xs, scan_rngs_xs, xs)

  # use partial evaluation to find the variables that are broadcasted out
  # an error is thrown if a broadcasted output has a dependency on any scan
  # variables
  carry_pvals = jax.tree_map(
      lambda x: pe.PartialVal.unknown(jax.ShapedArray(x.shape, x.dtype)),
      carry0)
  scan_pvals = jax.tree_map(
      lambda x: pe.PartialVal.unknown(jax.ShapedArray(x.shape[1:], x.dtype)),
      xxs)
  input_pvals = (carry_pvals, scan_pvals)
  in_pvals, in_tree = jax.tree_flatten(input_pvals)
  f_flat, out_tree = jax.api_util.flatten_fun_nokwargs(
      lu.wrap_init(broadcast_body), in_tree)
  _, out_pvals, _ = pe.trace_to_jaxpr(f_flat, in_pvals)
  # _, out_pvals, _ = pe.trace_to_jaxpr(f_flat, in_pvals, stage_out=True)

  out_flat = []
  for pv, const in out_pvals:
    if pv is not None:
      raise ValueError(
          'broadcasted variable has a data dependency on the scan body.')
    out_flat.append(const)

  (carry_vars_xs, carry), (scan_vars_xs, ys) = lax.scan(
      body, carry0, xxs, length=length, reverse=reverse)
  broadcast_vars_xs = jax.tree_unflatten(out_tree(), out_flat)
  out_vars_xs = combine(carry_vars_xs, scan_vars_xs, broadcast_vars_xs)
  return (carry, ys), out_vars_xs

def _instantiated_trace_to_jaxpr(fun, avals):
  pvals = map(lambda aval: pe.PartialVal((aval, unit)), avals)
  jaxpr, out_pvals, consts = pe.trace_to_jaxpr(fun, pvals, instantiate=True)
  out_avals, _ = unzip2(out_pvals)
  return jaxpr, out_avals, consts

def while_loop(cond_fun, body_fun, init_val):
  """Call `body_fun` repeatedly in a loop while `cond_fun` is True.

  Arguments:
    cond_fun: pure function of type `T -> Bool`.
    body_fun: pure function of type `T -> T`.
    init_val: value of type `T`, a type that can be a scalar, array, or any
      (nested) Python tuple/list/dict thereof.

  Returns:
    The output from the final iteration of body_fun, of type `T`.

  The semantics of `while_loop` are given by this Python implementation::

    def while_loop(cond_fun, body_fun, init_val):
      val = init_val
      while cond_fun(val):
        val = body_fun(val)
      return val

  Unlike that pure Python version, `while_loop` is a JAX primitive and is
  lowered to a single XLA While HLO. That makes it useful for reducing
  compilation times for jit-compiled functions, since native Python loop
  constructs in an `@jit` function are unrolled, leading to large XLA
  computations.

  Another difference from using Python-native loop constructs is that
  `while_loop` is not (yet) reverse-mode differentiable because XLA
  computations require static bounds on memory requirements.
  """
  init_val_flat, in_tree = pytree_to_jaxtupletree(init_val)
  flat_body_fun, out_tree = pytree_fun_to_jaxtupletree_fun(
      lu.wrap_init(body_fun), (in_tree,))
  flat_cond_fun, _ = pytree_fun_to_jaxtupletree_fun(
      lu.wrap_init(cond_fun), (in_tree,))

  pval_flat = lax._abstractify(init_val_flat)
  cond_jaxpr, _, cond_consts = pe.trace_to_jaxpr(flat_cond_fun, (pval_flat,))
  body_jaxpr, pval_out, body_consts = pe.trace_to_jaxpr(
      flat_body_fun, (pval_flat,))
  aval_out, _ = pval_out

  # We don't want to promote literal constants as loop arguments; there are
  # sometimes many of them. We pass tracers as loop arguments, but leave
  # nontracers as constants. We also sort the constants so the nontracers are
  # first.
  def split_tracers_and_nontracers(jaxpr, consts):
    tracer = []
    nontracer = []
    for x in zip(jaxpr.constvars, consts):
      # TODO(phawkins): We avoid treating DeviceArrays as constant literals so
      # we don't copy large arrays back to the host. We probably should relax
      # this and either always copy small constants, or opportunistically use
      # DeviceArray values for which we already know npy_value.
      not_literal_const = isinstance(x[1], (core.Tracer, xla.DeviceArray))
      (tracer if not_literal_const else nontracer).append(x)
    tracer_vars, tracer_consts = unzip2(tracer)
    nontracer_vars, nontracer_consts = unzip2(nontracer)
    return nontracer_vars + tracer_vars, nontracer_consts, tracer_consts

  cond_split = split_tracers_and_nontracers(cond_jaxpr, cond_consts)
  cond_jaxpr.constvars, cond_nontracer_consts, cond_tracer_consts = cond_split
  body_split = split_tracers_and_nontracers(body_jaxpr, body_consts)
  body_jaxpr.constvars, body_nontracer_consts, body_tracer_consts = body_split

  if out_tree() != in_tree:
    raise TypeError("body_fun input and output must have identical structure")
  out_flat = while_p.bind(
      init_val_flat, core.pack(cond_tracer_consts),
      core.pack(body_tracer_consts),
      cond_consts=lax._OpaqueParam(cond_nontracer_consts),
      body_consts=lax._OpaqueParam(body_nontracer_consts),
      aval_out=aval_out, cond_jaxpr=cond_jaxpr, body_jaxpr=body_jaxpr)
  return build_tree(out_tree(), out_flat)

def trace_jaxpr(fun, operand):
  op_flat, in_tree = pytree_to_flatjaxtuple(operand)
  fun_flat, out_tree = pytree_fun_to_flatjaxtuple_fun(lu.wrap_init(fun),
                                                      (in_tree,))
  jaxpr, pvout, consts = pe.trace_to_jaxpr(fun_flat, (_abstractify(op_flat),))
  return op_flat, jaxpr, consts, pvout, out_tree