def _while_callback_rule(trace, *tracers, cond_jaxpr, body_jaxpr, cond_nconsts,
                         body_nconsts):
  """Re-stages a `while_p` so the callback transform reaches cond and body.

  The constant and carry tracers are lowered to their values. The original
  cond/body jaxprs are wrapped so their constants are closed over, and each
  wrapper is passed through `callback_transform`. Both are then traced again
  with `_initial_style_jaxpr` and bound into a fresh `while_p`.

  Returns:
    The loop outputs, lifted into `trace` with `trace.pure`.
  """
  cond_const_tracers, body_const_tracers, init_tracers = split_list(
      tracers, [cond_nconsts, body_nconsts])
  init_avals = safe_map(lambda t: t.aval, init_tracers)
  cond_const_vals, body_const_vals, init_vals = tree_map(
      lambda t: t.val, (cond_const_tracers, body_const_tracers, init_tracers))

  original_body = jaxpr_as_fun(body_jaxpr)
  original_cond = jaxpr_as_fun(cond_jaxpr)

  def cond(*carry):
    # Re-attach the predicate's constants in front of the loop carry.
    return original_cond(*it.chain(cond_const_vals, carry))

  def body(*carry):
    # Re-attach the body's constants in front of the loop carry.
    return original_body(*it.chain(body_const_vals, carry))

  # Wrap cond first, then body (same order as the plain sequential form).
  wrapped_cond, wrapped_body = (
      callback_transform(fn, trace.callback, strip_calls=trace.strip_calls)  # type: ignore
      for fn in (cond, body))

  carry_tree = tree_structure(init_avals)
  carry_avals = tuple(init_avals)
  new_cond_jaxpr, new_cond_consts, _ = lcf._initial_style_jaxpr(
      wrapped_cond, carry_tree, carry_avals)
  new_body_jaxpr, new_body_consts, _ = lcf._initial_style_jaxpr(
      wrapped_body, carry_tree, carry_avals)

  operands = [*new_cond_consts, *new_body_consts, *init_vals]
  results = lcf.while_p.bind(
      *operands,
      cond_nconsts=len(new_cond_consts),
      body_nconsts=len(new_body_consts),
      cond_jaxpr=new_cond_jaxpr,
      body_jaxpr=new_body_jaxpr)
  return safe_map(trace.pure, results)
def _reap_while_rule(trace: HarvestTrace, *tracers, cond_jaxpr, body_jaxpr,
                     cond_nconsts, body_nconsts):
  """Reaps the body of a while loop to get the reaps of the final iteration.

  Every sow found in the body must be in `clobber` mode. The reaped values are
  threaded through the loop as extra carry, initialised with zeros of the
  sown aval. Each iteration overwrites them, so after the loop they hold the
  last iteration's values. They are then re-sown outside the loop.

  Returns:
    The original loop outputs, lifted into `trace`.

  Raises:
    ValueError: if a sow inside the body uses a mode other than `clobber`.
  """
  cond_const_tracers, body_const_tracers, init_tracers = jax_util.split_list(
      tracers, [cond_nconsts, body_nconsts])
  _, init_avals = tree_util.tree_map(lambda t: t.aval,
                                     (body_const_tracers, init_tracers))
  cond_const_vals, body_const_vals, init_vals = tree_util.tree_map(
      lambda t: t.val, (cond_const_tracers, body_const_tracers, init_tracers))
  settings = trace_util.get_dynamic_context(trace).settings

  body_metadata = _get_harvest_metadata(body_jaxpr, settings,
                                        *(body_const_tracers + init_tracers))
  for k, meta in body_metadata.items():
    if meta['mode'] != 'clobber':
      raise ValueError(
          f'Must use clobber mode for \'{k}\' inside of a `while_loop`.')
  reap_avals = {name: meta['aval'] for name, meta in body_metadata.items()}

  cond_fun = jax_core.jaxpr_as_fun(cond_jaxpr)
  body_fun = jax_core.jaxpr_as_fun(body_jaxpr)
  reap_settings = dict(tag=settings.tag,
                       allowlist=settings.allowlist,
                       blocklist=settings.blocklist,
                       exclusive=settings.exclusive)

  def new_cond(carry, _):
    # The reap accumulator plays no part in the loop predicate.
    return cond_fun(*(cond_const_vals + carry))

  def new_body(carry, _):
    # Incoming reaps are dropped: clobber mode only keeps the newest value.
    next_carry, fresh_reaps = call_and_reap(
        body_fun, **reap_settings)(*(body_const_vals + carry))
    return (next_carry, fresh_reaps)

  flat_in_avals, in_tree = tree_util.tree_flatten((init_avals, reap_avals))
  new_cond_jaxpr, cond_consts, _ = lcf._initial_style_jaxpr(  # pylint: disable=protected-access
      new_cond, in_tree, tuple(flat_in_avals))
  new_body_jaxpr, body_consts, out_tree = lcf._initial_style_jaxpr(  # pylint: disable=protected-access
      new_body, in_tree, tuple(flat_in_avals))

  zero_reaps = tree_util.tree_map(lambda a: jnp.zeros(a.shape, a.dtype),
                                  reap_avals)
  flat_init = tree_util.tree_leaves((init_vals, zero_reaps))
  raw_out = lax.while_p.bind(*(cond_consts + body_consts + flat_init),
                             cond_nconsts=len(cond_consts),
                             body_nconsts=len(body_consts),
                             cond_jaxpr=new_cond_jaxpr,
                             body_jaxpr=new_body_jaxpr)
  lifted = jax_util.safe_map(trace.pure, raw_out)
  loop_out, final_reaps = tree_util.tree_unflatten(out_tree, lifted)
  for name, value in final_reaps.items():
    sow(value, name=name, tag=settings.tag,
        mode=body_metadata[name]['mode'])
  return loop_out
def _scan_callback_rule(trace, *tracers, reverse, length, num_consts,
                        num_carry, jaxpr, linear, unroll):
  """Re-stages a `scan_p` so the callback transform reaches the scan body.

  The body jaxpr is turned back into a Python function. Its original
  constants are closed over, so `_initial_style_jaxpr` re-hoists them into
  `new_consts`. The function is passed through `callback_transform`, traced
  again at the per-iteration avals, and bound into a fresh `scan_p`.

  Returns:
    The scan outputs (final carry followed by stacked ys), lifted into `trace`.
  """
  const_tracers, carry_tracers, xs_tracers = split_list(
      tracers, [num_consts, num_carry])
  carry_avals, xs_avals = tree_map(lambda x: x.aval,
                                   (carry_tracers, xs_tracers))
  const_vals, carry_vals, xs_vals = tree_map(
      lambda x: x.val, (const_tracers, carry_tracers, xs_tracers))

  # Per-iteration avals: index 0 of each scanned operand.
  x_tracers = [t[0] for t in xs_tracers]
  x_avals = [t.aval for t in x_tracers]

  body_fun = jaxpr_as_fun(jaxpr)

  def new_body(*vals):
    # `jaxpr` expects its `num_consts` constants first. Previously they were
    # dropped here (const_vals was computed but never used). That made the
    # jaxpr receive carry/xs in the const positions whenever num_consts > 0.
    out = body_fun(*it.chain(const_vals, vals))
    out_carry, y = split_list(out, [num_carry])
    return out_carry, y

  new_body = callback_transform(new_body, trace.callback,
                                strip_calls=trace.strip_calls)  # type: ignore
  in_tree = tree_structure(carry_avals + x_avals)
  new_jaxpr, new_consts, _ = lcf._initial_style_jaxpr(
      new_body, in_tree, tuple(carry_avals + x_avals))
  vals = tuple(it.chain(new_consts, carry_vals, xs_vals))
  # `linear` described the ORIGINAL operands (num_consts consts, carry, xs).
  # The const prefix is now len(new_consts) long, so rebuild the flags:
  # hoisted consts are non-linear, and carry/xs keep their original flags.
  new_linear = (False,) * len(new_consts) + tuple(linear[num_consts:])
  out_vals = lax.scan_p.bind(*vals, reverse=reverse, length=length,
                             num_consts=len(new_consts), num_carry=num_carry,
                             jaxpr=new_jaxpr, linear=new_linear,
                             unroll=unroll)
  return safe_map(trace.pure, out_vals)
def build_output_vals(self, scope, carried_state_names, carried_tree,
                      init_vals, body_closed_jaxpr, body_const_vals):
  """Emits a `cond_p` that either runs the body or passes the state through.

  Branch 0 is a synthetic pass-through that returns the carried state
  unchanged. Branch 1 is `body_closed_jaxpr`. `self.index` is the selector
  operand handed to `cond_p`. Both branches receive the body constants
  followed by the carried values.
  """
  def _return_carried(*branch_args):
    # branch_args == (body_const_vals, carried_pytree); echo the carry.
    return branch_args[1]

  structured_operands = (body_const_vals,
                         tree_util.tree_unflatten(carried_tree, init_vals))
  flat_operands, operand_tree = tree_util.tree_flatten(structured_operands)
  operand_avals = tuple(safe_map(_BodyTracer.abstractify, flat_operands))
  identity_jaxpr, identity_consts, _ = (
      lax_control_flow._initial_style_jaxpr(_return_carried, operand_tree,
                                            operand_avals))
  # The pass-through closes over nothing, so it must not have hoisted consts.
  assert len(identity_consts) == 0

  cond_operands = [*body_const_vals, *init_vals]
  return lax_control_flow.cond_p.bind(
      self.index, *cond_operands,
      branches=(identity_jaxpr, body_closed_jaxpr),
      linear=tuple(False for _ in cond_operands))
def build_output_vals(self, scope, carried_state_names, carried_tree,
                      init_vals, body_closed_jaxpr, body_const_vals):
  """Emits a `while_p` whose predicate is `self.cond_func`.

  `self.cond_func` takes no arguments and reads the loop state from
  `scope._mutable_state`. It is therefore wrapped so the traced carry values
  are stored there before it runs.

  Raises:
    ValueError: if the conditional function rebinds any carried scope field.
    TypeError: if the conditional function does not return exactly one
      boolean scalar.
  """
  # _initial_style_jaxpr starts its own trace and creates tracers for all the
  # carried state. They must be placed in scope._mutable_state before the
  # conditional function is traced.
  def cond_func_wrapped(*args):
    assert len(args) == len(carried_state_names)
    for ms, init_ms in zip(carried_state_names, args):
      scope._mutable_state[ms] = init_ms
    res = self.cond_func()
    # The conditional function is not allowed to modify the scope state.
    for ms, init_ms in zip(carried_state_names, args):
      if scope._mutable_state[ms] is not init_ms:
        raise ValueError(
            f"Conditional function modifies scope.{ms} field.")
    return res

  init_avals = safe_map(_BodyTracer.abstractify, init_vals)
  cond_jaxpr, cond_consts, cond_tree = (
      lax_control_flow._initial_style_jaxpr(cond_func_wrapped, carried_tree,
                                            tuple(init_avals)))
  # TODO: share these checks with lax_control_flow.while
  if not tree_util.treedef_is_leaf(cond_tree):
    raise TypeError(
        f"cond_fun must return a boolean scalar, but got pytree {cond_tree}.")
  out_avals = cond_jaxpr.out_avals
  # The old check was `if not safe_map(...)`. That tests a one-element list,
  # which is always truthy, so a non-boolean result was never rejected. Check
  # the single element explicitly. Also guard the zero-output case (e.g. a
  # `None` return, whose treedef is still a leaf).
  if (len(out_avals) != 1 or
      not core.typecompat(out_avals[0], core.ShapedArray((), np.bool_))):
    raise TypeError(
        f"cond_fun must return a boolean scalar, but got output type(s) "
        f"{cond_jaxpr.out_avals}.")
  return lax_control_flow.while_p.bind(*cond_consts, *body_const_vals,
                                       *init_vals,
                                       cond_nconsts=len(cond_consts),
                                       cond_jaxpr=cond_jaxpr,
                                       body_nconsts=len(body_const_vals),
                                       body_jaxpr=body_closed_jaxpr)
def _plant_while_rule(trace: HarvestTrace, *tracers, cond_jaxpr, body_jaxpr,
                      cond_nconsts, body_nconsts):
  """Injects values into a while loop, overriding values for all iterations.

  Only the body is rewritten (via `plant`); the predicate jaxpr is reused
  as-is. Every sow in the body must be in `clobber` mode.

  Returns:
    The loop outputs, lifted into `trace`.

  Raises:
    ValueError: if a sow inside the body uses a mode other than `clobber`.
  """
  cond_const_tracers, body_const_tracers, init_tracers = jax_util.split_list(
      tracers, [cond_nconsts, body_nconsts])
  init_avals = tree_util.tree_map(lambda t: t.aval, init_tracers)
  cond_const_vals, body_const_vals, init_vals = tree_util.tree_map(
      lambda t: t.val, (cond_const_tracers, body_const_tracers, init_tracers))
  context = trace_util.get_dynamic_context(trace)
  settings = context.settings

  body_metadata = _get_harvest_metadata(body_jaxpr, settings,
                                        *(body_const_tracers + init_tracers))
  for k, meta in body_metadata.items():
    if meta['mode'] != 'clobber':
      raise ValueError(
          f'Must use clobber mode for \'{k}\' inside of a `while_loop`.')

  body_fun = jax_core.jaxpr_as_fun(body_jaxpr)
  plant_settings = dict(tag=settings.tag,
                        allowlist=settings.allowlist,
                        blocklist=settings.blocklist,
                        exclusive=settings.exclusive)
  plants = context.plants
  body_const_tuple = tuple(body_const_vals)

  def new_body(*carry):
    # Body constants go first, then the loop carry, as the jaxpr expects.
    planted_body = plant(body_fun, **plant_settings)
    return planted_body(plants, *body_const_tuple, *carry)

  carry_tree = tree_util.tree_structure(init_avals)
  new_body_jaxpr, new_body_consts, _ = lcf._initial_style_jaxpr(  # pylint: disable=protected-access
      new_body, carry_tree, tuple(init_avals))
  operands = cond_const_vals + new_body_consts + init_vals
  results = lcf.while_p.bind(*operands,
                             cond_nconsts=len(cond_const_vals),
                             body_nconsts=len(new_body_consts),
                             cond_jaxpr=cond_jaxpr,
                             body_jaxpr=new_body_jaxpr)
  return jax_util.safe_map(trace.pure, results)
def _plant_scan_rule(trace: HarvestTrace, *tracers, length, reverse, jaxpr,
                     num_consts, num_carry, linear, unroll):
  """Injects values into a scan according to their sow mode.

  `clobber` plants are closed over by the new body, so the same value is used
  on every iteration. `append` plants are appended as extra scanned operands,
  so each iteration receives its own slice. `strict` sows are rejected.

  Returns:
    The scan outputs (final carry followed by stacked ys), lifted into `trace`.

  Raises:
    ValueError: if a sow inside the body uses `strict` mode.
  """
  const_tracers, carry_tracers, xs_tracers = jax_util.split_list(
      tracers, [num_consts, num_carry])
  carry_avals, xs_avals = tree_util.tree_map(lambda x: x.aval,
                                             (carry_tracers, xs_tracers))
  const_vals, carry_vals, xs_vals = tree_util.tree_map(
      lambda x: x.val, (const_tracers, carry_tracers, xs_tracers))
  context = trace_util.get_dynamic_context(trace)
  settings = context.settings
  # Per-iteration view of each scanned operand (index 0 when indexable).
  x_tracers = [t[0] if hasattr(t, '_getitem') else t for t in xs_tracers]
  x_avals = [t.aval for t in x_tracers]
  metadata = _get_harvest_metadata(
      jaxpr, settings, *(const_tracers + carry_tracers + x_tracers))

  plants = context.plants
  plant_modes = collections.defaultdict(set)
  plant_xs_avals = {}
  for name, meta in metadata.items():
    mode = meta['mode']
    aval = meta['aval']
    if mode == 'strict':
      raise ValueError(
          f'Cannot use strict mode for \'{name}\' inside `scan`.')
    plant_modes[mode].add(name)
    if mode == 'append' and name in plants:
      plant_xs_avals[name] = aval
  body_fun = jax_core.jaxpr_as_fun(jaxpr)
  clobber_plants = {
      name: value
      for name, value in plants.items()
      if name in plant_modes['clobber']
  }
  append_plants = {
      name: value
      for name, value in plants.items()
      if name in plant_modes['append']
  }

  plant_xs_flat_avals, _ = tree_util.tree_flatten(plant_xs_avals)
  plant_xs_in_tree = tree_util.tree_structure(
      (carry_avals, (xs_avals, plant_xs_avals)))

  def new_body(carry, x):
    # x is (per-iteration xs, per-iteration append plants).
    x, x_plants = x
    all_plants = {**x_plants, **clobber_plants}
    all_values = const_vals + tree_util.tree_leaves((carry, x))
    out = plant(body_fun,
                tag=settings.tag,
                allowlist=settings.allowlist,
                blocklist=settings.blocklist,
                exclusive=settings.exclusive)(all_plants, *all_values)
    carry_out, y = jax_util.split_list(out, [num_carry])
    return carry_out, y

  new_body_jaxpr, consts, _ = lcf._initial_style_jaxpr(  # pylint: disable=protected-access
      new_body, plant_xs_in_tree,
      tuple(carry_avals + x_avals + plant_xs_flat_avals))
  plant_vals = tree_util.tree_leaves(append_plants)
  # `linear` describes the ORIGINAL operands (num_consts consts, carry, xs).
  # The new scan has len(consts) hoisted consts (which includes the closed-
  # over clobber plants) plus trailing plant inputs. Rebuild the flags so
  # they line up with the new operands: consts and plants are non-linear,
  # and carry/xs keep their flags.
  new_linear = ((False,) * len(consts) + tuple(linear[num_consts:]) +
                (False,) * len(plant_vals))
  out = lcf.scan_p.bind(*(consts + carry_vals + xs_vals + plant_vals),
                        reverse=reverse,
                        length=length,
                        jaxpr=new_body_jaxpr,
                        num_consts=len(consts),
                        num_carry=num_carry,
                        linear=new_linear,
                        unroll=unroll)
  # Lift the results into the harvest trace, as the other control-flow rules
  # do. Otherwise downstream sows on these outputs escape the trace.
  return jax_util.safe_map(trace.pure, out)
def _reap_scan_rule(trace: HarvestTrace, *tracers, length, reverse, jaxpr,
                    num_consts, num_carry, linear, unroll):
  """Reaps the body of a scan to pull out `clobber` and `append` sows.

  `clobber` reaps are threaded through the scan as extra carry, initialised
  with zeros of the sown aval. `append` reaps are emitted as extra stacked
  ys. After the scan, every reap is re-sown outside with its original mode.

  Returns:
    The final carry followed by the stacked ys, lifted into `trace`.

  Raises:
    ValueError: if a sow inside the body uses `strict` mode.
  """
  const_tracers, carry_tracers, xs_tracers = jax_util.split_list(
      tracers, [num_consts, num_carry])
  _, carry_avals, xs_avals = tree_util.tree_map(
      lambda x: x.aval, (const_tracers, carry_tracers, xs_tracers))
  const_vals, carry_vals, xs_vals = tree_util.tree_map(
      lambda x: x.val, (const_tracers, carry_tracers, xs_tracers))
  context = trace_util.get_dynamic_context(trace)
  settings = context.settings
  # Per-iteration view of each scanned operand (index 0 when indexable).
  x_tracers = [t[0] if hasattr(t, '_getitem') else t for t in xs_tracers]
  x_avals = [t.aval for t in x_tracers]
  x_vals = [t.val for t in x_tracers]
  metadata = _get_harvest_metadata(jaxpr, settings,
                                   *(const_vals + carry_vals + x_vals))

  reap_modes = collections.defaultdict(set)
  reap_carry_avals = {}
  for name, meta in metadata.items():
    mode = meta['mode']
    aval = meta['aval']
    if mode == 'strict':
      raise ValueError(
          f'Cannot use strict mode for \'{name}\' inside `scan`.')
    reap_modes[mode].add(name)
    if mode == 'clobber':
      reap_carry_avals[name] = aval
  body_fun = jax_core.jaxpr_as_fun(jaxpr)

  reap_carry_flat_avals, _ = tree_util.tree_flatten(reap_carry_avals)

  reap_carry_in_tree = tree_util.tree_structure(
      ((carry_avals, reap_carry_avals), xs_avals))

  def new_body(carry, x):
    # The incoming clobber reaps are discarded; this iteration overwrites them.
    carry, _ = carry
    all_values = const_vals + tree_util.tree_leaves((carry, x))
    out, reaps = call_and_reap(body_fun,
                               tag=settings.tag,
                               allowlist=settings.allowlist,
                               blocklist=settings.blocklist,
                               exclusive=settings.exclusive)(*all_values)
    carry_out, y = jax_util.split_list(out, [num_carry])
    carry_reaps = {
        name: val
        for name, val in reaps.items()
        if name in reap_modes['clobber']
    }
    xs_reaps = {
        name: val
        for name, val in reaps.items()
        if name in reap_modes['append']
    }
    return (carry_out, carry_reaps), (y, xs_reaps)

  new_body_jaxpr, consts, out_tree = lcf._initial_style_jaxpr(  # pylint: disable=protected-access
      new_body, reap_carry_in_tree,
      tuple(carry_avals + reap_carry_flat_avals + x_avals))
  dummy_reap_carry_vals = tree_util.tree_map(
      lambda x: jnp.zeros(x.shape, x.dtype), reap_carry_flat_avals)
  # New operand layout: consts | carry | dummy reap carry | xs. The old
  # `linear` covered num_consts consts, then num_carry carry, then xs.
  # Rebuild the flags so the dummy-carry flags sit after the carry (not after
  # the consts), and so the const prefix matches len(consts).
  carry_linear = tuple(linear[num_consts:num_consts + num_carry])
  xs_linear = tuple(linear[num_consts + num_carry:])
  new_linear = ((False,) * len(consts) + carry_linear +
                (False,) * len(dummy_reap_carry_vals) + xs_linear)
  out = lax.scan_p.bind(
      *(consts + carry_vals + dummy_reap_carry_vals + xs_vals),
      reverse=reverse,
      length=length,
      jaxpr=new_body_jaxpr,
      num_consts=len(consts),
      num_carry=len(carry_vals + dummy_reap_carry_vals),
      linear=new_linear,
      unroll=unroll)
  (carry_out, carry_reaps), (ys, ys_reaps) = tree_util.tree_unflatten(
      out_tree, out)
  (carry_out, carry_reaps), (ys, ys_reaps) = tree_util.tree_map(
      trace.pure, ((carry_out, carry_reaps), (ys, ys_reaps)))
  for k, v in {**carry_reaps, **ys_reaps}.items():
    sow(v, tag=settings.tag, mode=metadata[k]['mode'], name=k)
  return carry_out + ys
def _scan_harvest_rule(trace: HarvestTrace, *tracers, length, reverse, jaxpr,
                       num_consts, num_carry, linear, unroll):
  """Collects and injects values into/from the scan body.

  The body is re-traced under `harvest`:
    - `clobber` plants are closed over.
    - `append` plants are fed in as extra scanned operands.
  The per-iteration reaps come back as extra ys. After the scan, the reaps
  are reduced by their sow mode and re-sown in `strict` mode.

  Returns:
    The final carry followed by the stacked ys, lifted into `trace`.

  Raises:
    ValueError: if a sow in the body uses `strict` mode.
  """
  context = trace_util.get_dynamic_context(trace)
  settings = context.settings
  values = [t.val for t in tracers]
  consts, init, xs = jax_util.split_list(values, [num_consts, num_carry])

  active_sows = _find_sows(jaxpr, settings.tag)
  active_modes = [params['mode'] for params in active_sows]
  if any(mode == 'strict' for mode in active_modes):
    raise ValueError('Cannot use strict mode in a scan.')
  active_names = [params['name'] for params in active_sows]
  sow_modes = {name: mode for name, mode in zip(active_names, active_modes)}
  carry_plants = {
      name: context.plants[name]
      for name in active_names
      if name in context.plants and sow_modes[name] == 'clobber'
  }
  xs_plants = {
      name: context.plants[name]
      for name in active_names
      if name in context.plants and sow_modes[name] == 'append'
  }

  def jaxpr_fun(carry, x):
    body_out = jax_core.eval_jaxpr(jaxpr.jaxpr, jaxpr.literals,
                                   *(consts + carry + x))
    carry, y = jax_util.split_list(body_out, [num_carry])
    return carry, y

  harvest_body = harvest(jaxpr_fun,
                         tag=settings.tag,
                         allowlist=settings.allowlist,
                         blocklist=settings.blocklist,
                         mode=settings.mode)

  def body(carry, x):
    x_plants, x_vals = x
    (carry, y), reaps = harvest_body({
        **carry_plants,
        **x_plants
    }, carry, x_vals)
    return carry, (y, reaps)

  # Per-iteration avals for the scanned operands (append plants, then xs).
  xs_flat = tree_util.tree_leaves((xs_plants, xs))
  x_avals = []
  for x in xs_flat:
    x_aval = jax_core.get_aval(x)
    if x_aval is jax_core.abstract_unit:
      x_avals.append(x_aval)
    else:
      x_shape, x_dtype = masking.padded_shape_as_value(x.shape[1:]), x.dtype
      x_avals.append(abstract_arrays.ShapedArray(x_shape, x_dtype))
  x_avals = tuple(x_avals)

  init_avals = tuple(
      abstract_arrays.raise_to_shaped(jax_core.get_aval(a)) for a in init)
  in_flat, in_tree = tree_util.tree_flatten((init, (xs_plants, xs)))
  body_jaxpr, new_consts, out_tree = (
      lax_control_flow._initial_style_jaxpr(  # pylint: disable=protected-access
          body, in_tree, init_avals + x_avals))
  new_values = list(new_consts) + in_flat
  num_xs_plants = len(new_values) - len(init) - len(xs) - len(new_consts)
  # Operand layout: new consts | carry | append plants | xs.
  remaining_linear = linear[num_consts:]
  new_linear = ((False,) * len(new_consts) + remaining_linear[:len(init)] +
                (False,) * num_xs_plants + remaining_linear[len(init):])
  assert len(new_linear) == len(new_values)

  outs = lax.scan_p.bind(*new_values,
                         length=length,
                         reverse=reverse,
                         jaxpr=body_jaxpr,
                         num_consts=len(new_consts),
                         num_carry=num_carry,
                         linear=new_linear,
                         unroll=unroll)
  outs = safe_map(trace.pure, outs)
  carry, (ys, reaps) = tree_util.tree_unflatten(out_tree, outs)

  # With reverse=True, ys[i] comes from the iteration that consumed xs[i].
  # The iteration that runs last therefore writes index 0, not index -1, so
  # the clobber value must be taken from there.
  last_index = 0 if reverse else -1
  out_reaps = {}
  for k, val in reaps.items():
    mode = sow_modes.get(k, 'strict')
    if mode == 'append':
      # NOTE(review): np.concatenate joins along the stacked axis, which
      # merges the first two dimensions. Confirm this is intended rather
      # than keeping the stacked (length, ...) shape.
      val = tree_util.tree_map(np.concatenate, val)
    elif mode == 'clobber':
      val = tree_util.tree_map(lambda x: x[last_index], val)
    out_reaps[k] = sow(val, tag=settings.tag, name=k, mode='strict')
  (carry, ys) = prim.tie_in(out_reaps, (carry, ys))
  return carry + ys