def eval_provenance(fun, *args, **kwargs):
    """
    Trace ``fun`` abstractly and report the provenance attached to its outputs.

    The function is run through JAX's abstract interpretation machinery only;
    no actual array computation is performed.

    :param fun: A callable to track provenance of its (keyword) arguments.
    :param args: Positional arguments of `fun`.
    :param kwargs: Keyword arguments of `fun`.
    :returns: A pytree of :class:`ProvenanceArray`, with the same structure
        as the output of `fun`.
    """
    # Reduce (args, kwargs) to a flat list of leaves and adapt `fun` so that
    # it takes and returns flat lists as well.
    leaves, in_treedef = jax.tree_util.tree_flatten((args, kwargs))
    flat_fun, out_treedef_thunk = jax.api_util.flatten_fun(
        wrap_init(fun), in_treedef
    )
    traceable = wrap_init(flat_fun.call_wrapped)
    abstract_inputs = jax.util.safe_map(
        jax.api_util.shaped_abstractify, leaves
    )

    # Trace under the provenance-aware trace; only the output avals (item 1
    # of the returned triple) are needed.
    with jax.core.new_main(_ProvenanceJaxprTrace, dynamic=True) as main:
        main.jaxpr_stack = ()
        traced_avals = partial_eval.trace_to_subjaxpr_dynamic(
            traceable, main, abstract_inputs
        )[1]

    # Convert each output aval into a ShapeDtypeStruct, keeping `named_shape`,
    # which is where the "_provenance" entry is looked up below.
    structs = []
    for aval in traced_avals:
        structs.append(
            jax.ShapeDtypeStruct(aval.shape, aval.dtype, aval.named_shape)
        )
    # `out_treedef_thunk` is only called after tracing has finished.
    result_tree = jax.tree_util.tree_unflatten(out_treedef_thunk(), structs)

    def _with_provenance(struct):
        # Outputs without a "_provenance" entry get an empty frozenset.
        provenance = struct.named_shape.get("_provenance", frozenset())
        return ProvenanceArray(struct, provenance)

    return jax.tree_util.tree_map(_with_provenance, result_tree)
def _dynamic_jaxpr_process_xmap(self, primitive, f, tracers, params):
    """Record an ``xmap_p`` call as one equation in the current trace frame.

    ``f`` is traced to a jaxpr with ``trace_to_subjaxpr_dynamic`` while the
    axis environment is extended with ``params['axis_sizes']``. Constants
    produced by that trace become extra leading inputs of the equation.

    :param primitive: The primitive being processed; must be ``xmap_p``.
    :param f: The function handed to ``trace_to_subjaxpr_dynamic``.
    :param tracers: Input tracers; each one's ``aval`` is read.
    :param params: Primitive parameters. The keys ``'axis_sizes'``,
        ``'in_axes'`` and ``'out_axes_thunk'`` are read here.
    :returns: A list with one ``DynamicJaxprTracer`` per output.
    """
    from jax.interpreters.partial_eval import (
        DynamicJaxprTracer,
        call_param_updaters,
        convert_constvars_jaxpr,
        new_jaxpr_eqn,
        source_info_util,
        trace_to_subjaxpr_dynamic,
    )

    assert primitive is xmap_p

    tracer_avals = [tracer.aval for tracer in tracers]
    sizes = params['axis_sizes']

    # Avals as seen inside the mapped body: per-input axes are removed.
    body_in_avals = []
    for aval, axes in zip(tracer_avals, params['in_axes']):
        body_in_avals.append(_delete_aval_axes(aval, axes))

    with core.extend_axis_env_nd(sizes.items()):
        body_jaxpr, body_out_avals, body_consts = trace_to_subjaxpr_dynamic(
            f, self.main, body_in_avals
        )

    # The thunk is only called once the body has been traced.
    resolved_out_axes = params['out_axes_thunk']()
    result_avals = [
        _insert_aval_axes(aval, axes, sizes)
        for aval, axes in zip(body_out_avals, resolved_out_axes)
    ]

    src_info = source_info_util.current()
    result_tracers = [
        DynamicJaxprTracer(self, aval, src_info) for aval in result_avals
    ]

    eqn_invars = map(self.getvar, tracers)
    eqn_constvars = map(self.getvar, map(self.instantiate_const, body_consts))
    eqn_outvars = map(self.makevar, result_tracers)

    # Constants are passed ahead of the regular inputs, each with in_axes
    # entry `None`.
    eqn_params = dict(
        params,
        in_axes=(None,) * len(body_consts) + params['in_axes'],
        out_axes=resolved_out_axes,
        call_jaxpr=convert_constvars_jaxpr(body_jaxpr),
    )
    # The thunk has been resolved into `out_axes`; drop it from the params.
    del eqn_params['out_axes_thunk']

    updater = call_param_updaters.get(primitive)
    if updater:
        eqn_params = updater(eqn_params, [True] * len(tracers))

    eqn = new_jaxpr_eqn(
        [*eqn_constvars, *eqn_invars],
        eqn_outvars,
        primitive,
        eqn_params,
        src_info,
    )
    self.frame.eqns.append(eqn)
    return result_tracers
def _dynamic_jaxpr_process_xmap(self, primitive, f, tracers, params):
    """Record an ``xmap_p`` call as one equation in the current trace frame.

    ``f`` is traced to a jaxpr with ``trace_to_subjaxpr_dynamic`` while the
    axis environment is extended with ``params['global_axis_sizes']``.
    Output avals are rebuilt using per-axis local sizes derived from the
    axis resource counts. Constants produced by the trace become extra
    leading inputs of the equation.

    :param primitive: The primitive being processed; must be ``xmap_p``.
    :param f: The function handed to ``trace_to_subjaxpr_dynamic``.
    :param tracers: Input tracers; each one's ``aval`` is read.
    :param params: Primitive parameters. The keys ``'global_axis_sizes'``,
        ``'in_axes'``, ``'out_axes_thunk'``, ``'axis_resources'``,
        ``'resource_env'`` and ``'donated_invars'`` are read here.
    :returns: A list with one ``DynamicJaxprTracer`` per output.
    """
    from jax.interpreters.partial_eval import (
        DynamicJaxprTracer,
        convert_constvars_jaxpr,
        new_jaxpr_eqn,
        source_info_util,
        trace_to_subjaxpr_dynamic,
    )

    assert primitive is xmap_p

    tracer_avals = [tracer.aval for tracer in tracers]
    global_sizes = params['global_axis_sizes']

    # Avals as seen inside the mapped body: per-input axes are removed.
    body_in_avals = []
    for aval, axes in zip(tracer_avals, params['in_axes']):
        body_in_avals.append(_delete_aval_axes(aval, axes))

    with core.extend_axis_env_nd(global_sizes.items()):
        body_jaxpr, body_out_avals, body_consts = trace_to_subjaxpr_dynamic(
            f, self.main, body_in_avals
        )

    # The thunk is only called once the body has been traced.
    resolved_out_axes = params['out_axes_thunk']()

    # Turn every global axis size into its local counterpart.
    resource_count = _get_axis_resource_count(
        params['axis_resources'], params['resource_env']
    )
    local_sizes = {}
    for axis_name, global_size in global_sizes.items():
        local_sizes[axis_name] = resource_count[axis_name].to_local(global_size)

    result_avals = [
        _insert_aval_axes(aval, axes, local_sizes)
        for aval, axes in zip(body_out_avals, resolved_out_axes)
    ]
    _check_out_avals_vs_out_axes(result_avals, resolved_out_axes, global_sizes)

    src_info = source_info_util.current()
    result_tracers = [
        DynamicJaxprTracer(self, aval, src_info) for aval in result_avals
    ]

    eqn_invars = map(self.getvar, tracers)
    eqn_constvars = map(self.getvar, map(self.instantiate_const, body_consts))
    eqn_outvars = map(self.makevar, result_tracers)

    # Constants are passed ahead of the regular inputs: each gets an
    # `AxisNamePos(user_repr='{}')` in_axes entry and is marked not donated.
    num_consts = len(body_consts)
    eqn_params = dict(
        params,
        in_axes=(AxisNamePos(user_repr='{}'),) * num_consts + params['in_axes'],
        out_axes=resolved_out_axes,
        donated_invars=(False,) * num_consts + params['donated_invars'],
        call_jaxpr=convert_constvars_jaxpr(body_jaxpr),
    )
    # The thunk has been resolved into `out_axes`; drop it from the params.
    del eqn_params['out_axes_thunk']

    eqn = new_jaxpr_eqn(
        [*eqn_constvars, *eqn_invars],
        eqn_outvars,
        primitive,
        eqn_params,
        src_info,
    )
    self.frame.eqns.append(eqn)
    return result_tracers