def _reap_cond_rule(trace, *tracers, branches, linear):
  """Reaps each path of the `cond`.

  Every branch is wrapped with `call_and_reap`, so it returns its normal
  outputs together with a dictionary of reaped values. The wrapped branches
  are re-traced with shared consts and bound as a new `cond_p`. The reaped
  values are then sown again in the enclosing trace, using the sow mode that
  the harvest metadata records for the first branch.

  Args:
    trace: the active `HarvestTrace`.
    *tracers: the `cond` operands. The first is the branch index; the rest
      are the operands passed to every branch.
    branches: the closed jaxprs of the `cond` branches.
    linear: the `linear` parameter of the original `cond_p` equation.

  Returns:
    The (non-reaped) outputs of the `cond`, lifted into `trace`.
  """
  index_tracer, ops_tracers = tracers[0], tracers[1:]
  index_val, ops_vals = tree_util.tree_map(lambda x: x.val,
                                           (index_tracer, ops_tracers))
  # Only the operand avals are used; the index aval is discarded.
  _, ops_avals = tree_util.tree_map(lambda x: x.aval,
                                    (index_tracer, ops_tracers))
  context = trace_util.get_dynamic_context(trace)
  settings = context.settings
  reap_settings = dict(
      tag=settings.tag,
      allowlist=settings.allowlist,
      blocklist=settings.blocklist,
      exclusive=settings.exclusive)
  branch_metadatas = tuple(
      _get_harvest_metadata(branch, settings, *ops_tracers)
      for branch in branches)
  # Presumably verifies that all branches sow the same names/avals/modes, so
  # the rewritten branches share one output structure -- confirm in helper.
  _check_branch_metadata(branch_metadatas)
  branch_funs = tuple(map(jax_core.jaxpr_as_fun, branches))
  reaped_branches = tuple(
      call_and_reap(f, **reap_settings) for f in branch_funs)
  in_tree = tree_util.tree_structure(ops_avals)
  new_branch_jaxprs, consts, out_trees = (
      lcf._initial_style_jaxprs_with_common_consts(  # pylint: disable=protected-access
          reaped_branches, in_tree, ops_avals, lax.cond_p.name))
  # The hoisted common consts are prepended to the operands, and every
  # operand (consts and original operands alike) is marked non-linear.
  out = lax.cond_p.bind(
      index_val,
      *(tuple(consts) + ops_vals),
      branches=tuple(new_branch_jaxprs),
      linear=(False, ) * len(tuple(consts) + linear))
  out = jax_util.safe_map(trace.pure, out)
  # The first branch's output tree splits the flat outputs into
  # (outputs, reaps); the branches are assumed to agree on it.
  out, reaps = tree_util.tree_unflatten(out_trees[0], out)
  for k, v in reaps.items():
    sow(v, name=k, tag=settings.tag, mode=branch_metadatas[0][k]['mode'])
  return out
def process_call(self, call_primitive: jax_core.Primitive, f: Any,
                 tracers: List['HarvestTracer'], params: Dict[str, Any]):
  """Dispatches a call primitive to the active harvest context.

  `nest_p` gets dedicated handling through the context's `process_nest`.
  Every other call primitive is forwarded to
  `process_higher_order_primitive` as a non-map primitive.
  """
  ctx = trace_util.get_dynamic_context(self)
  if call_primitive is not nest_p:
    return ctx.process_higher_order_primitive(self, call_primitive, f,
                                              tracers, params, False)
  return ctx.process_nest(self, f, *tracers, **params)
def _plant_cond_rule(trace, *tracers, branches, linear):
  """Injects the same values into every branch of a conditional.

  Each branch is wrapped with `plant` and partially applied to the context's
  plants, so whichever branch runs sees the same planted values. The wrapped
  branches are re-traced with shared consts and bound as a new `cond_p`.

  Args:
    trace: the active `HarvestTrace`.
    *tracers: the `cond` operands. The first is the branch index; the rest
      are the operands passed to every branch.
    branches: the closed jaxprs of the `cond` branches.
    linear: the `linear` parameter of the original `cond_p` equation.

  Returns:
    The outputs of the rewritten `cond`, lifted into `trace`.
  """
  index_tracer, ops_tracers = tracers[0], tracers[1:]
  index_val, ops_vals = tree_util.tree_map(lambda x: x.val,
                                           (index_tracer, ops_tracers))
  ops_avals = tree_util.tree_map(lambda x: x.aval, ops_tracers)
  context = trace_util.get_dynamic_context(trace)
  settings = context.settings
  plant_settings = dict(
      tag=settings.tag,
      allowlist=settings.allowlist,
      blocklist=settings.blocklist,
      exclusive=settings.exclusive)
  branch_metadatas = tuple(
      _get_harvest_metadata(branch, settings, *ops_tracers)
      for branch in branches)
  # Presumably rejects branches whose sows disagree -- confirm in helper.
  _check_branch_metadata(branch_metadatas)
  plants = context.plants
  branch_funs = tuple(map(jax_core.jaxpr_as_fun, branches))
  # Bind the plants up front so each planted branch takes only the operands.
  planted_branches = tuple(
      functools.partial(plant(f, **plant_settings), plants)
      for f in branch_funs)
  in_tree = tree_util.tree_structure(ops_avals)
  new_branch_jaxprs, consts, _ = (
      lcf._initial_style_jaxprs_with_common_consts(  # pylint: disable=protected-access
          planted_branches, in_tree, ops_avals, lax.cond_p.name))
  # Hoisted consts are prepended; every operand is marked non-linear.
  out = lax.cond_p.bind(
      index_val,
      *(tuple(consts) + ops_vals),
      branches=tuple(new_branch_jaxprs),
      linear=(False, ) * len(tuple(consts) + linear))
  return jax_util.safe_map(trace.pure, out)
def process_primitive(
    self, primitive: jax_core.Primitive, tracers: List['HarvestTracer'],
    params: Dict[str, Any]) -> Union['HarvestTracer', List['HarvestTracer']]:
  """Applies `primitive`, preferring a rule registered on the context.

  If the active context provides a custom rule for `primitive`, that rule
  is called with this trace, the unpacked tracers and the params as keyword
  arguments. Otherwise evaluation falls back to `default_process_primitive`.
  """
  rule = trace_util.get_dynamic_context(self).get_custom_rule(primitive)
  if not rule:
    return self.default_process_primitive(primitive, tracers, params)
  return rule(self, *tracers, **params)
def process_higher_order_primitive(self, primitive, f, tracers, params,
                                   is_map):
  """Harvests inside a call/map primitive and re-sows what was reaped.

  The context's plants are flattened together with the input values and
  passed as extra leading arguments to a `harvest_eval`-wrapped version of
  `f`. After binding, the flat outputs are split into the real outputs and
  the reaped values. The reaps are sown in this trace with mode 'strict'.
  For `nest_p` they are sown as one value under the nest's `scope`;
  otherwise each reap is sown under its own name.

  Args:
    primitive: the call or map primitive being processed.
    f: the wrapped function called by the primitive.
    tracers: the primitive's input tracers.
    params: the primitive's params. This dict is not modified.
    is_map: whether `primitive` is a map primitive; if so, `in_axes` and
      `out_axes_thunk` are extended to cover the plants/reaps.

  Returns:
    The output tracers of the primitive, excluding the reaped values.
  """
  # Copy before popping so the caller's params dict is left untouched.
  params = dict(params)
  name = params.pop('name', f.__name__)
  tracers = safe_map(self.instantiate_const, tracers)
  vals = [t.val for t in tracers]
  context = trace_util.get_dynamic_context(self)
  active_tag = context.settings.tag
  plants = context.plants
  if primitive is nest_p:
    # A nest only sees the plants registered under its own scope.
    plants = plants.get(params['scope'], {})
  if is_map:
    # TODO(sharadmv): figure out if invars are mapped or unmapped
    out_axes_thunk = params['out_axes_thunk']

    # `as_hashable_function` takes its hash key via the `closure` argument
    # (matching the unzip handler's usage).
    @jax_util.as_hashable_function(closure=('harvest', out_axes_thunk))
    def new_out_axes_thunk():
      out_axes = out_axes_thunk()
      assert all(out_axis == 0 for out_axis in out_axes)
      # `out_tree` is bound below by `harvest_eval`; presumably this thunk
      # is only invoked after the wrapped function has been traced.
      return (0, ) * out_tree().num_leaves

    new_params = dict(
        params,
        in_axes=(0, ) * len(tree_util.tree_leaves(plants)) +
        params['in_axes'],
        out_axes_thunk=new_out_axes_thunk)
  else:
    new_params = dict(params)
  all_args, all_tree = tree_util.tree_flatten((plants, vals))
  num_plants = len(all_args) - len(vals)
  if 'donated_invars' in params:
    # Planted values are never donated.
    new_params['donated_invars'] = ((False, ) * num_plants +
                                    params['donated_invars'])
  f, out_tree = harvest_eval(f, self, context.settings, all_tree)
  out_flat = primitive.bind(
      f, *all_args, **new_params, name=jax_util.wrap_name(name, 'harvest'))
  out, reaps = tree_util.tree_unflatten(out_tree(), out_flat)
  out_tracers = safe_map(self.pure, out)
  reap_tracers = tree_util.tree_map(self.pure, reaps)
  if primitive is nest_p and reap_tracers:
    flat_tracers, tree = tree_util.tree_flatten(reap_tracers)
    self.handle_sow(
        *flat_tracers,
        name=params['scope'],
        tag=active_tag,
        mode='strict',
        tree=tree)
  else:
    for reap_name, reap_tracer in reap_tracers.items():
      flat_tracers, tree = tree_util.tree_flatten(reap_tracer)
      self.handle_sow(
          *flat_tracers,
          name=reap_name,
          tag=active_tag,
          mode='strict',
          tree=tree)
  return out_tracers
def _reap_while_rule(trace: HarvestTrace, *tracers, cond_jaxpr, body_jaxpr,
                     cond_nconsts, body_nconsts):
  """Reaps the body of a while loop to get the reaps of the final iteration.

  The reaped values are threaded through the loop as extra carry,
  initialized with zeros. Each iteration overwrites them, so after the loop
  the carry holds the values sown by the last iteration that ran. If the
  loop runs zero times, the zero initial values are what get sown. Those
  values are then sown again in the enclosing trace.

  Args:
    trace: the active `HarvestTrace`.
    *tracers: the `while_p` operands, laid out as cond consts, body consts,
      then the initial carry.
    cond_jaxpr, body_jaxpr, cond_nconsts, body_nconsts: the parameters of
      the original `while_p` equation.

  Returns:
    The final loop carry (without the reaps), lifted into `trace`.

  Raises:
    ValueError: if a sow in the loop body uses a mode other than 'clobber'.
  """
  cond_const_tracers, body_const_tracers, init_tracers = jax_util.split_list(
      tracers, [cond_nconsts, body_nconsts])
  _, init_avals = tree_util.tree_map(lambda x: x.aval,
                                     (body_const_tracers, init_tracers))
  cond_const_vals, body_const_vals, init_vals = tree_util.tree_map(
      lambda x: x.val, (cond_const_tracers, body_const_tracers, init_tracers))
  context = trace_util.get_dynamic_context(trace)
  settings = context.settings
  body_metadata = _get_harvest_metadata(body_jaxpr, settings,
                                        *(body_const_tracers + init_tracers))
  # Only 'clobber' (last write wins) is meaningful for an unknown number of
  # iterations.
  for k, meta in body_metadata.items():
    mode = meta['mode']
    if mode != 'clobber':
      raise ValueError(
          f'Must use clobber mode for \'{k}\' inside of a `while_loop`.')
  reap_avals = {k: v['aval'] for k, v in body_metadata.items()}
  cond_fun = jax_core.jaxpr_as_fun(cond_jaxpr)
  body_fun = jax_core.jaxpr_as_fun(body_jaxpr)
  reap_settings = dict(
      tag=settings.tag,
      allowlist=settings.allowlist,
      blocklist=settings.blocklist,
      exclusive=settings.exclusive)

  # The condition ignores the reap part of the carry.
  def new_cond(carry, _):
    return cond_fun(*(cond_const_vals + carry))

  # The incoming reaps are discarded and replaced by this iteration's reaps.
  def new_body(carry, _):
    carry, reaps = call_and_reap(body_fun,
                                 **reap_settings)(*(body_const_vals + carry))
    return (carry, reaps)

  new_in_avals, new_in_tree = tree_util.tree_flatten((init_avals, reap_avals))
  new_cond_jaxpr, cond_consts, _ = lcf._initial_style_jaxpr(  # pylint: disable=protected-access
      new_cond, new_in_tree, tuple(new_in_avals))
  new_body_jaxpr, body_consts, out_tree = lcf._initial_style_jaxpr(  # pylint: disable=protected-access
      new_body, new_in_tree, tuple(new_in_avals))
  # Zero-filled placeholders for the reap carry on entry to the loop.
  dummy_reap_vals = tree_util.tree_map(lambda x: jnp.zeros(x.shape, x.dtype),
                                       reap_avals)
  new_in_vals = tree_util.tree_leaves((init_vals, dummy_reap_vals))
  out = lax.while_p.bind(
      *(cond_consts + body_consts + new_in_vals),
      cond_nconsts=len(cond_consts),
      body_nconsts=len(body_consts),
      cond_jaxpr=new_cond_jaxpr,
      body_jaxpr=new_body_jaxpr)
  out = jax_util.safe_map(trace.pure, out)
  out, reaps = tree_util.tree_unflatten(out_tree, out)
  for k, v in reaps.items():
    sow(v, name=k, tag=settings.tag, mode=body_metadata[k]['mode'])
  return out
def default_process_primitive(
    self, primitive: jax_core.Primitive, tracers: List['HarvestTracer'],
    params: Dict[str, Any]) -> Union['HarvestTracer', List['HarvestTracer']]:
  """Binds `primitive` on the underlying values and re-wraps the results.

  `sow_p` is routed to the active context's `process_sow`, and its outputs
  are always returned as a list. Any other primitive is bound directly on
  the tracers' values. Its result is lifted back into this trace as a list
  for multi-result primitives, and as a single tracer otherwise.
  """
  context = trace_util.get_dynamic_context(self)
  vals = [t.val for t in tracers]
  if primitive is sow_p:
    return jax_util.safe_map(self.pure, context.process_sow(*vals, **params))
  result = primitive.bind(*vals, **params)
  if primitive.multiple_results:
    return jax_util.safe_map(self.pure, result)
  return self.pure(result)
def _plant_while_rule(trace: HarvestTrace, *tracers, cond_jaxpr, body_jaxpr,
                      cond_nconsts, body_nconsts):
  """Injects values into a while loop, overriding values for all iterations.

  Only the body is rewritten: it is wrapped with `plant`, so every iteration
  sees the context's planted values. The cond jaxpr and its consts are
  reused unchanged.

  Args:
    trace: the active `HarvestTrace`.
    *tracers: the `while_p` operands, laid out as cond consts, body consts,
      then the initial carry.
    cond_jaxpr, body_jaxpr, cond_nconsts, body_nconsts: the parameters of
      the original `while_p` equation.

  Returns:
    The final loop carry, lifted into `trace`.

  Raises:
    ValueError: if a sow in the loop body uses a mode other than 'clobber'.
  """
  cond_const_tracers, body_const_tracers, init_tracers = jax_util.split_list(
      tracers, [cond_nconsts, body_nconsts])
  init_avals = tree_util.tree_map(lambda x: x.aval, init_tracers)
  cond_const_vals, body_const_vals, init_vals = tree_util.tree_map(
      lambda x: x.val, (cond_const_tracers, body_const_tracers, init_tracers))
  context = trace_util.get_dynamic_context(trace)
  settings = context.settings
  body_metadata = _get_harvest_metadata(body_jaxpr, settings,
                                        *(body_const_tracers + init_tracers))
  for k, meta in body_metadata.items():
    mode = meta['mode']
    if mode != 'clobber':
      raise ValueError(
          f'Must use clobber mode for \'{k}\' inside of a `while_loop`.')
  body_fun = jax_core.jaxpr_as_fun(body_jaxpr)
  plant_settings = dict(
      tag=settings.tag,
      allowlist=settings.allowlist,
      blocklist=settings.blocklist,
      exclusive=settings.exclusive)
  plants = context.plants

  # Body consts are closed over here and re-hoisted by `_initial_style_jaxpr`.
  def new_body(*carry):
    carry = plant(body_fun, **plant_settings)(plants,
                                              *(tuple(body_const_vals) + carry))
    return carry

  in_tree = tree_util.tree_structure(init_avals)
  new_body_jaxpr, new_body_consts, _ = lcf._initial_style_jaxpr(  # pylint: disable=protected-access
      new_body, in_tree, tuple(init_avals))
  out = lcf.while_p.bind(
      *(cond_const_vals + new_body_consts + init_vals),
      cond_nconsts=len(cond_const_vals),
      body_nconsts=len(new_body_consts),
      cond_jaxpr=cond_jaxpr,
      body_jaxpr=new_body_jaxpr)
  return jax_util.safe_map(trace.pure, out)
def default_process_primitive(self, primitive, tracers, params):
  """Partially evaluate primitives and saves variable recipes.

  If every input is known, the primitive is bound eagerly on the known
  values. Otherwise it is staged: output `UnzipTracer`s are created from the
  primitive's abstract evaluation and linked to an equation recipe. When
  all inputs are key tracers and the primitive is a `harvest.sow_p` with
  the active tag, a `VariableRecipe` is also attached to the outputs.

  Args:
    primitive: the primitive to process.
    tracers: the input tracers.
    params: the primitive's params.

  Returns:
    A list of output tracers for multi-result primitives, otherwise a single
    output tracer (or the eagerly computed value when all inputs are known).
  """
  pvs, consts = jax_util.unzip2(t.pval for t in tracers)
  # No unknown part on any input: evaluate directly instead of staging.
  if all(pv is None for pv in pvs):
    return primitive.bind(*consts, **params)
  settings = trace_util.get_dynamic_context(self).settings
  tracers = safe_map(self.instantiate_const, tracers)
  # Invariant check; note that `assert` is stripped under `python -O`.
  if any(not isinstance(t, UnzipTracer) for t in tracers):
    assert False
  # Outputs are key tracers only if every input is a key tracer.
  key = all(t.is_key() for t in tracers)
  avals = [t.aval for t in tracers]
  ans = primitive.abstract_eval(*avals, **params)
  if not primitive.multiple_results:
    ans = [ans]
  out_tracers = [
      UnzipTracer(self, pe.PartialVal((aval, jax_core.unit)), None, key)
      for aval in ans
  ]
  # Passing in UnzipTracer, which pytype does not recognize as JaxprTracer
  eqn = pe.new_eqn_recipe(tracers, out_tracers, primitive, params,
                          source_info_util.current())  # pytype: disable=wrong-arg-types
  for t in out_tracers:
    t.recipe = eqn

  is_variable = (key and primitive is harvest.sow_p and
                 params['tag'] == settings.tag)
  # This block is where UnzipTrace mainly differs from pe.JaxprTrace. Where
  # JaxprTrace will just return out_tracers, UnzipTrace will record an
  # additional VariableRecipe into the tracers, which will be used after
  # the trace is complete to construct init/apply Jaxprs.
  if is_variable:
    name, var_in_tracers, var_out_tracers = unzip_registry[primitive](
        tracers, out_tracers, **params)
    variable_recipe = VariableRecipe(name, var_in_tracers, var_out_tracers)
    for t in out_tracers:
      t.variable_recipe = variable_recipe

  if primitive.multiple_results:
    return out_tracers
  return out_tracers[0]
def process_higher_order_primitive(self, trace, call_primitive, f, tracers,
                                   params, is_map):
  """Plants values into the body of a call/map primitive.

  The context's plants are flattened with the input values and passed as
  extra leading arguments to a `plant_eval`-wrapped version of `f`. If
  `in_axes` / `donated_invars` are present, they are extended with one
  entry per plant leaf (mapped along axis 0, never donated). For `nest_p`
  without `donated_invars`, only the plants under the nest's `scope` are
  passed.

  Args:
    trace: the active trace whose tracers are being processed.
    call_primitive: the call or map primitive being processed.
    f: the wrapped function called by the primitive.
    tracers: the primitive's input tracers.
    params: the primitive's params. Note that 'name' is popped from it.
    is_map: unused.

  Returns:
    The primitive's outputs, lifted into `trace`.
  """
  del is_map
  # NOTE(review): `params.pop` mutates the caller's dict, and the wrapped
  # name gets a 'reap' suffix even though this is the plant path -- confirm
  # both are intended.
  name = jax_util.wrap_name(params.pop('name', f.__name__), 'reap')
  context = trace_util.get_dynamic_context(trace)
  vals = [t.val for t in tracers]
  plants = context.plants
  if 'in_axes' in params:
    # TODO(b/199459308): figure out if invars are mapped or unmapped
    params = dict(
        params,
        in_axes=(0, ) * len(tree_util.tree_leaves(plants)) +
        params['in_axes'])
  if 'donated_invars' in params:
    params = dict(params)
    params['donated_invars'] = (
        (False, ) * len(tree_util.tree_leaves(plants)) +
        params['donated_invars'])
  elif call_primitive is nest_p:
    # A nest only sees the plants registered under its own scope.
    plants = plants.get(params['scope'], {})
  all_vals, all_tree = tree_util.tree_flatten((plants, vals))
  f = plant_eval(f, trace, self.settings, all_tree)
  out_vals = call_primitive.bind(f, *all_vals, name=name, **params)
  return jax_util.safe_map(trace.pure, out_vals)
def _plant_scan_rule(trace: HarvestTrace, *tracers, length, reverse, jaxpr,
                     num_consts, num_carry, linear, unroll):
  """Injects values into a scan according to their sow mode.

  'clobber' plants are closed over by the new body, so every iteration sees
  the same planted value. 'append' plants are passed as extra scanned-over
  inputs, so each iteration sees its own slice of the planted value.
  'strict' sows are rejected inside a scan.

  Args:
    trace: the active `HarvestTrace`.
    *tracers: the `scan_p` operands, laid out as consts, carry, then xs.
    length, reverse, jaxpr, num_consts, num_carry, linear, unroll: the
      parameters of the original `scan_p` equation.

  Returns:
    The scan outputs (final carry followed by stacked ys), lifted into
    `trace`.

  Raises:
    ValueError: if a sow inside the scan body uses 'strict' mode.
  """
  const_tracers, carry_tracers, xs_tracers = jax_util.split_list(
      tracers, [num_consts, num_carry])
  carry_avals, xs_avals = tree_util.tree_map(lambda x: x.aval,
                                             (carry_tracers, xs_tracers))
  const_vals, carry_vals, xs_vals = tree_util.tree_map(
      lambda x: x.val, (const_tracers, carry_tracers, xs_tracers))
  context = trace_util.get_dynamic_context(trace)
  settings = context.settings
  # Per-iteration views of the scanned-over inputs (leading axis dropped).
  x_tracers = [t[0] if hasattr(t, '_getitem') else t for t in xs_tracers]
  x_avals = [t.aval for t in x_tracers]
  metadata = _get_harvest_metadata(
      jaxpr, settings, *(const_tracers + carry_tracers + x_tracers))

  plants = context.plants
  plant_modes = collections.defaultdict(set)
  plant_xs_avals = {}
  for name, meta in metadata.items():
    mode = meta['mode']
    aval = meta['aval']
    if mode == 'strict':
      raise ValueError(
          f'Cannot use strict mode for \'{name}\' inside `scan`.')
    plant_modes[mode].add(name)
    if mode == 'append' and name in plants:
      plant_xs_avals[name] = aval
  body_fun = jax_core.jaxpr_as_fun(jaxpr)
  clobber_plants = {
      name: value
      for name, value in plants.items()
      if name in plant_modes['clobber']
  }
  append_plants = {
      name: value
      for name, value in plants.items()
      if name in plant_modes['append']
  }

  plant_xs_flat_avals, _ = tree_util.tree_flatten(plant_xs_avals)

  plant_xs_in_tree = tree_util.tree_structure(
      (carry_avals, (xs_avals, plant_xs_avals)))

  def new_body(carry, x):
    # `x` is (original xs slice, this iteration's 'append' plants).
    x, plants = x
    all_plants = {**plants, **clobber_plants}
    all_values = const_vals + tree_util.tree_leaves((carry, x))
    out = plant(
        body_fun,
        tag=settings.tag,
        allowlist=settings.allowlist,
        blocklist=settings.blocklist,
        exclusive=settings.exclusive)(all_plants, *all_values)
    carry_out, y = jax_util.split_list(out, [num_carry])
    return carry_out, y

  new_body_jaxpr, consts, _ = lcf._initial_style_jaxpr(  # pylint: disable=protected-access
      new_body, plant_xs_in_tree,
      tuple(carry_avals + x_avals + plant_xs_flat_avals))
  plant_vals = tree_util.tree_leaves(append_plants)
  # The new body closes over `const_vals` and `clobber_plants`, so its
  # hoisted consts need not line up with `num_consts`. Rebuild the
  # linearity flags from the new operand layout (consts, carry, xs,
  # planted xs); the new consts and planted xs are never linear.
  new_linear = ((False, ) * len(consts) + tuple(linear[num_consts:]) +
                (False, ) * len(plant_vals))
  out = lcf.scan_p.bind(
      *(consts + carry_vals + xs_vals + plant_vals),
      reverse=reverse,
      length=length,
      jaxpr=new_body_jaxpr,
      num_consts=len(consts),
      num_carry=num_carry,
      linear=new_linear,
      unroll=unroll)
  # Lift the raw outputs back into the harvest trace, like the other rules.
  return jax_util.safe_map(trace.pure, out)
def _reap_scan_rule(trace: HarvestTrace, *tracers, length, reverse, jaxpr,
                    num_consts, num_carry, linear, unroll):
  """Reaps the body of a scan to pull out `clobber` and `append` sows.

  'clobber' reaps are threaded through the scan as extra carry, overwritten
  every iteration, so the value from the last executed iteration survives.
  'append' reaps are emitted as extra ys, so the values from all iterations
  are stacked. The reaped values are then sown again in the enclosing trace
  with their original modes.

  Args:
    trace: the active `HarvestTrace`.
    *tracers: the `scan_p` operands, laid out as consts, carry, then xs.
    length, reverse, jaxpr, num_consts, num_carry, linear, unroll: the
      parameters of the original `scan_p` equation.

  Returns:
    The scan outputs (final carry followed by stacked ys), without the
    reaps, lifted into `trace`.

  Raises:
    ValueError: if a sow inside the scan body uses 'strict' mode.
  """
  const_tracers, carry_tracers, xs_tracers = jax_util.split_list(
      tracers, [num_consts, num_carry])
  _, carry_avals, xs_avals = tree_util.tree_map(
      lambda x: x.aval, (const_tracers, carry_tracers, xs_tracers))
  const_vals, carry_vals, xs_vals = tree_util.tree_map(
      lambda x: x.val, (const_tracers, carry_tracers, xs_tracers))
  context = trace_util.get_dynamic_context(trace)
  settings = context.settings
  # Per-iteration views of the scanned-over inputs (leading axis dropped).
  x_tracers = [t[0] if hasattr(t, '_getitem') else t for t in xs_tracers]
  x_avals = [t.aval for t in x_tracers]
  x_vals = [t.val for t in x_tracers]
  metadata = _get_harvest_metadata(jaxpr, settings,
                                   *(const_vals + carry_vals + x_vals))

  reap_modes = collections.defaultdict(set)
  reap_carry_avals = {}
  for name, meta in metadata.items():
    mode = meta['mode']
    aval = meta['aval']
    if mode == 'strict':
      raise ValueError(
          f'Cannot use strict mode for \'{name}\' inside `scan`.')
    reap_modes[mode].add(name)
    if mode == 'clobber':
      reap_carry_avals[name] = aval
  body_fun = jax_core.jaxpr_as_fun(jaxpr)

  reap_carry_flat_avals, _ = tree_util.tree_flatten(reap_carry_avals)

  reap_carry_in_tree = tree_util.tree_structure(
      ((carry_avals, reap_carry_avals), xs_avals))

  def new_body(carry, x):
    # The incoming reap carry is discarded; this iteration's reaps replace it.
    carry, _ = carry
    all_values = const_vals + tree_util.tree_leaves((carry, x))
    out, reaps = call_and_reap(
        body_fun,
        tag=settings.tag,
        allowlist=settings.allowlist,
        blocklist=settings.blocklist,
        exclusive=settings.exclusive)(*all_values)
    carry_out, y = jax_util.split_list(out, [num_carry])
    carry_reaps = {
        name: val
        for name, val in reaps.items()
        if name in reap_modes['clobber']
    }
    xs_reaps = {
        name: val
        for name, val in reaps.items()
        if name in reap_modes['append']
    }
    return (carry_out, carry_reaps), (y, xs_reaps)

  new_body_jaxpr, consts, out_tree = lcf._initial_style_jaxpr(  # pylint: disable=protected-access
      new_body, reap_carry_in_tree,
      tuple(carry_avals + reap_carry_flat_avals + x_avals))
  # Zero-filled initial values for the 'clobber' reap carry.
  dummy_reap_carry_vals = tree_util.tree_map(
      lambda x: jnp.zeros(x.shape, x.dtype), reap_carry_flat_avals)
  # Operands are laid out as new consts, original carry, reap carry, xs.
  # The linearity flags must follow the same layout: the new consts (which
  # need not match `num_consts`) and the reap carry are never linear, while
  # the original carry/xs flags are kept in place.
  carry_linear = tuple(linear[num_consts:num_consts + num_carry])
  xs_linear = tuple(linear[num_consts + num_carry:])
  new_linear = ((False, ) * len(consts) + carry_linear +
                (False, ) * len(dummy_reap_carry_vals) + xs_linear)
  out = lax.scan_p.bind(
      *(consts + carry_vals + dummy_reap_carry_vals + xs_vals),
      reverse=reverse,
      length=length,
      jaxpr=new_body_jaxpr,
      num_consts=len(consts),
      num_carry=len(carry_vals + dummy_reap_carry_vals),
      linear=new_linear,
      unroll=unroll)
  (carry_out, carry_reaps), (ys, ys_reaps) = tree_util.tree_unflatten(
      out_tree, out)
  (carry_out, carry_reaps), (ys, ys_reaps) = tree_util.tree_map(
      trace.pure, ((carry_out, carry_reaps), (ys, ys_reaps)))
  for k, v in {**carry_reaps, **ys_reaps}.items():
    sow(v, tag=settings.tag, mode=metadata[k]['mode'], name=k)
  return carry_out + ys
def process_map(self, call_primitive: jax_core.Primitive, f: Any,
                tracers: List['HarvestTracer'], params: Dict[str, Any]):
  """Delegates a map primitive to the active harvest context.

  The primitive is forwarded to the context's
  `process_higher_order_primitive` with the map flag set to True.
  """
  ctx = trace_util.get_dynamic_context(self)
  handle = ctx.process_higher_order_primitive
  return handle(self, call_primitive, f, tracers, params, True)
def handle_call_primitive(self, call_primitive, f, tracers, params, is_map):
  """Handler for call_primitives, like jit or layer_call.

  When an UnzipTracer hits a call primitive, one of two cases applies:

  - There is a variable inside the call primitive. The input function then
    needs to be unzipped into two functions.
  - There are no variables in the function. The call_primitive is then
    recorded in the trace as-is.

  We use `unzip_eval`, whose auxiliary output reports whether the unzip
  succeeded. If it did, we record two new Jaxprs into the trace (one for
  init, one for apply); the variables produced by init are sown with mode
  'strict' before being fed into apply. Otherwise, we just record the Jaxpr
  corresponding to the function call.

  Args:
    call_primitive: a call primitive like xla_call
    f: a jax.linear_util wrapped function to be called
    tracers: inputs to the function
    params: parameters of the primitives
    is_map: whether or not the primitive is a map primitive (e.g. xla_pmap)

  Returns:
    A list of output tracers

  Raises:
    NotImplementedError: if `call_primitive` has a partial-eval rule in
      `pe.call_partial_eval_rules` and no custom rule is registered.
  """
  name = params.get('name', f.__name__)
  settings = trace_util.get_dynamic_context(self).settings
  tracers = safe_map(self.instantiate_const_abstracted, tracers)
  # Custom rules take precedence over the generic unzipping logic below.
  if call_primitive in current_custom_rules():
    return current_custom_rules()[call_primitive](self, f, *tracers, **params)
  if call_primitive in pe.call_partial_eval_rules:
    raise NotImplementedError
  in_pvals = [t.pval for t in tracers]
  if is_map:
    # NOTE(review): these mapped `in_pvals` are not used below (`pvs` is
    # re-derived from `tracers`) -- confirm whether they were meant to be
    # passed to `unzip_eval`.
    unknown = pe.PartialVal.unknown
    in_pvals = [
        pval if pval.is_known() or in_axis is None else unknown(
            mapped_aval(params['axis_size'], in_axis, pval[0]))
        for pval, in_axis in zip(in_pvals, params['in_axes'])
    ]
    out_axes_thunk = params['out_axes_thunk']

    @jax_util.as_hashable_function(closure=('unzip', out_axes_thunk))
    def new_out_axes_thunk():
      out_axes = out_axes_thunk()
      assert all(out_axis == 0 for out_axis in out_axes)
      # `aux` is bound below by `unzip_eval`; every output is mapped on 0.
      _, num_outputs, _ = aux()
      return (0, ) * num_outputs

    new_params = dict(params, out_axes_thunk=new_out_axes_thunk)
  else:
    new_params = params
  pvs, in_consts = jax_util.unzip2(t.pval for t in tracers)
  keys = tuple(t.is_key() for t in tracers)
  new_settings = UnzipSettings(settings.tag, call_primitive in block_registry)
  fun, aux = unzip_eval(f, self, keys, tuple(pvs), new_settings)
  out_flat = call_primitive.bind(fun, *in_consts, **new_params)
  success, _, results = aux()
  if not success:
    # No variables found: record the single call Jaxpr.
    out_pvs, out_keys, jaxpr, env = results
    out_pv_consts, consts = jax_util.split_list(out_flat, [len(out_pvs)])
    out_tracers = self._bound_output_tracers(call_primitive, new_params,
                                             jaxpr, consts, env, tracers,
                                             out_pvs, out_pv_consts,
                                             out_keys, name, is_map)
    return out_tracers
  init_name = jax_util.wrap_name(name, 'init')
  apply_name = jax_util.wrap_name(name, 'apply')
  init_pvs, num_init_consts, apply_pvs = results[0]
  init_jaxpr, apply_jaxpr = results[1]
  init_env, apply_env = results[2]
  variable_names, variable_tree, apply_keys = results[3]

  # init consumes only key inputs; apply consumes the variables plus the
  # non-key inputs.
  key_tracers = [t for t in tracers if t.is_key()]
  abstract_tracers = [t for t in tracers if not t.is_key()]
  # `out_flat` is laid out as init outputs (pv consts, then init consts)
  # followed by apply outputs (pv consts, then apply consts).
  all_init_consts, all_apply_consts = jax_util.split_list(
      out_flat, [len(init_pvs) + num_init_consts])
  init_pv_consts, init_consts = jax_util.split_list(all_init_consts,
                                                    [len(init_pvs)])
  apply_pv_consts, apply_consts = jax_util.split_list(all_apply_consts,
                                                      [len(apply_pvs)])

  variable_tracers = self._bound_output_tracers(
      call_primitive, new_params, init_jaxpr, init_consts, init_env,
      key_tracers, init_pvs, init_pv_consts, [True] * len(init_pvs),
      init_name, is_map)

  unflat_variables = tree_util.tree_unflatten(variable_tree, variable_tracers)
  if call_primitive is harvest.nest_p:
    # A nest sows all of its variables as one dict under its scope name.
    variable_dict = harvest.sow(
        dict(safe_zip(variable_names, unflat_variables)),
        tag=settings.tag,
        name=new_params['scope'],
        mode='strict')
    unflat_variables = tuple(variable_dict[name] for name in variable_names)
  else:
    unflat_variables = [
        harvest.sow(  # pylint: disable=g-complex-comprehension
            unflat_variable,
            tag=settings.tag,
            name=name,
            mode='strict') for unflat_variable, name in safe_zip(
                unflat_variables, variable_names)
    ]
  variable_tracers = tree_util.tree_leaves(unflat_variables)

  out_tracers = self._bound_output_tracers(
      call_primitive, new_params, apply_jaxpr, apply_consts, apply_env,
      variable_tracers + abstract_tracers, apply_pvs, apply_pv_consts,
      apply_keys, apply_name, is_map)
  return out_tracers
def _scan_harvest_rule(trace: HarvestTrace, *tracers, length, reverse, jaxpr,
                       num_consts, num_carry, linear, unroll):
  """Collects and injects values into/from the scan body.

  The scan body is wrapped with `harvest`:
  - 'clobber' plants are closed over and seen by every iteration.
  - 'append' plants are scanned over as extra xs.
  - Every reap is emitted as a stacked y.

  After the scan, the stacked reaps are post-processed per mode and sown
  with mode 'strict' in the enclosing trace:
  - 'append' reaps are concatenated.
  - 'clobber' reaps keep the value from the last executed iteration.

  Args:
    trace: the active `HarvestTrace`.
    *tracers: the `scan_p` operands, laid out as consts, carry, then xs.
    length, reverse, jaxpr, num_consts, num_carry, linear, unroll: the
      parameters of the original `scan_p` equation.

  Returns:
    The scan outputs (final carry followed by stacked ys), lifted into
    `trace`.

  Raises:
    ValueError: if a sow inside the scan body uses 'strict' mode.
  """
  context = trace_util.get_dynamic_context(trace)
  settings = context.settings
  values = [t.val for t in tracers]
  consts, init, xs = jax_util.split_list(values, [num_consts, num_carry])

  active_sows = _find_sows(jaxpr, settings.tag)
  active_modes = [params['mode'] for params in active_sows]
  if any(mode == 'strict' for mode in active_modes):
    raise ValueError('Cannot use strict mode in a scan.')
  active_names = [params['name'] for params in active_sows]
  sow_modes = {name: mode for name, mode in zip(active_names, active_modes)}
  carry_plants = {
      name: context.plants[name]
      for name in active_names
      if name in context.plants and sow_modes[name] == 'clobber'
  }
  xs_plants = {
      name: context.plants[name]
      for name in active_names
      if name in context.plants and sow_modes[name] == 'append'
  }

  def jaxpr_fun(carry, x):
    body_out = jax_core.eval_jaxpr(jaxpr.jaxpr, jaxpr.literals,
                                   *(consts + carry + x))
    carry, y = jax_util.split_list(body_out, [num_carry])
    return carry, y

  harvest_body = harvest(
      jaxpr_fun,
      tag=settings.tag,
      allowlist=settings.allowlist,
      blocklist=settings.blocklist)

  def body(carry, x):
    # `x` is (this iteration's 'append' plants, original xs slice).
    x_plants, x_vals = x
    (carry, y), reaps = harvest_body({
        **carry_plants,
        **x_plants
    }, carry, x_vals)
    return carry, (y, reaps)

  # Per-iteration avals of the scanned inputs (plants first, then xs).
  xs_flat = tree_util.tree_leaves((xs_plants, xs))
  x_avals = []
  for x in xs_flat:
    x_aval = jax_core.get_aval(x)
    if x_aval is jax_core.abstract_unit:
      x_avals.append(x_aval)
    else:
      x_shape, x_dtype = masking.padded_shape_as_value(x.shape[1:]), x.dtype
      x_avals.append(abstract_arrays.ShapedArray(x_shape, x_dtype))
  x_avals = tuple(x_avals)
  init_avals = tuple(
      abstract_arrays.raise_to_shaped(jax_core.get_aval(a)) for a in init)
  in_flat, in_tree = tree_util.tree_flatten((init, (xs_plants, xs)))
  body_jaxpr, new_consts, out_tree = (
      jax.lax.lax_control_flow._initial_style_jaxpr(  # pylint: disable=protected-access
          body, in_tree, init_avals + x_avals))
  new_values = list(new_consts) + in_flat
  num_xs_plants = len(new_values) - len(init) - len(xs) - len(new_consts)
  # Operands are new consts, carry, planted xs, xs; consts and planted xs
  # are never linear.
  remaining_linear = linear[num_consts:]
  new_linear = ((False, ) * len(new_consts) + remaining_linear[:len(init)] +
                (False, ) * num_xs_plants + remaining_linear[len(init):])
  assert len(new_linear) == len(new_values)

  outs = lax.scan_p.bind(
      *new_values,
      length=length,
      reverse=reverse,
      jaxpr=body_jaxpr,
      num_consts=len(new_consts),
      num_carry=num_carry,
      linear=new_linear,
      unroll=unroll)
  outs = safe_map(trace.pure, outs)
  carry, (ys, reaps) = tree_util.tree_unflatten(out_tree, outs)
  # With `reverse=True` the scan runs from the last index down to 0 and
  # stacks ys by index, so the last executed iteration is stored at 0.
  last_index = 0 if reverse else -1
  out_reaps = {}
  for k, val in reaps.items():
    mode = sow_modes.get(k, 'strict')
    if mode == 'append':
      val = tree_util.tree_map(np.concatenate, val)
    elif mode == 'clobber':
      val = tree_util.tree_map(lambda x: x[last_index], val)
    out_reaps[k] = sow(val, tag=settings.tag, name=k, mode='strict')
  (carry, ys) = prim.tie_in(out_reaps, (carry, ys))
  return carry + ys
def handle_sow(self, *tracers, name, tag, mode, tree):
  """Routes a sow through the active context and re-lifts its outputs.

  The raw values underlying `tracers` are passed, together with the sow
  metadata, to the context's `handle_sow`. Every value it returns is wrapped
  back into this trace with `self.pure`.
  """
  raw_values = [tracer.val for tracer in tracers]
  ctx = trace_util.get_dynamic_context(self)
  sown = ctx.handle_sow(raw_values, name=name, tag=tag, mode=mode, tree=tree)
  return safe_map(self.pure, sown)