Esempio n. 1
0
def eval_provenance(fun, *args, **kwargs):
    """
    Trace ``fun`` abstractly and report the provenance of each of its outputs.

    The call is staged through JAX's tracing machinery on abstract values
    derived from the inputs, so no actual array computation is performed.

    :param fun: A callable whose (keyword) arguments' provenance is tracked.
    :param args: Positional arguments of `fun`.
    :param kwargs: Keyword arguments of `fun`.
    :returns: A pytree of :class:`ProvenanceArray` with the same structure
        as the output of `fun`.
    """
    # Flatten the call: the traced function takes and returns flat lists.
    flat_args, in_tree = jax.tree_util.tree_flatten((args, kwargs))
    flat_fun, out_tree_thunk = jax.api_util.flatten_fun(wrap_init(fun), in_tree)
    traced_fun = wrap_init(flat_fun.call_wrapped)
    in_avals = jax.util.safe_map(jax.api_util.shaped_abstractify, flat_args)

    # Run the trace under the provenance-tracking trace type; only the
    # second element of the tracing result is used here.
    with jax.core.new_main(_ProvenanceJaxprTrace, dynamic=True) as main:
        main.jaxpr_stack = ()
        traced = partial_eval.trace_to_subjaxpr_dynamic(
            traced_fun, main, in_avals)
        out_avals = traced[1]

    def _with_provenance(struct):
        # The provenance is read from the "_provenance" entry of the named
        # shape; an empty frozenset is used when that entry is absent.
        provenance = struct.named_shape.get("_provenance", frozenset())
        return ProvenanceArray(struct, provenance)

    # Rebuild the output pytree out of shape/dtype descriptors.
    out_structs = [
        jax.ShapeDtypeStruct(aval.shape, aval.dtype, aval.named_shape)
        for aval in out_avals
    ]
    out_pytree = jax.tree_util.tree_unflatten(out_tree_thunk(), out_structs)
    return jax.tree_util.tree_map(_with_provenance, out_pytree)
Esempio n. 2
0
def _dynamic_jaxpr_process_xmap(self, primitive, f, tracers, params):
  """Record an ``xmap`` call as a single equation in the current frame.

  ``f`` is traced to a jaxpr on the input avals after
  ``_delete_aval_axes`` has been applied to them, while the axis
  environment is extended with ``params['axis_sizes']``. The resulting
  jaxpr is attached to a new equation as ``call_jaxpr`` and that equation
  is appended to ``self.frame.eqns``.

  Args:
    self: The trace object; its ``main``, ``frame``, ``getvar``,
      ``makevar`` and ``instantiate_const`` members are used.
    primitive: The primitive being processed; asserted to be ``xmap_p``.
    f: The function handed to ``trace_to_subjaxpr_dynamic``.
    tracers: Input tracers, each exposing an ``aval`` attribute.
    params: Parameters of the call. The keys ``'axis_sizes'``,
      ``'in_axes'`` and ``'out_axes_thunk'`` are read.

  Returns:
    A list of ``DynamicJaxprTracer`` objects, one per output aval.
  """
  from jax.interpreters.partial_eval import (
      DynamicJaxprTracer,
      call_param_updaters,
      convert_constvars_jaxpr,
      new_jaxpr_eqn,
      source_info_util,
      trace_to_subjaxpr_dynamic,
  )
  assert primitive is xmap_p

  arg_avals = [tracer.aval for tracer in tracers]
  axis_sizes = params['axis_sizes']
  in_axes = params['in_axes']
  body_in_avals = [_delete_aval_axes(aval, axes)
                   for aval, axes in zip(arg_avals, in_axes)]

  # Trace the body with the xmap axis names in scope.
  with core.extend_axis_env_nd(axis_sizes.items()):
    jaxpr, body_out_avals, consts = trace_to_subjaxpr_dynamic(
        f, self.main, body_in_avals)

  # The thunk is only called after tracing has finished.
  out_axes = params['out_axes_thunk']()
  result_avals = [_insert_aval_axes(aval, axes, axis_sizes)
                  for aval, axes in zip(body_out_avals, out_axes)]

  src_info = source_info_util.current()
  result_tracers = [DynamicJaxprTracer(self, aval, src_info)
                    for aval in result_avals]

  # NOTE(review): `map`/`zip` may be rebound at module level; the calls
  # and their order are kept exactly as written.
  invars = map(self.getvar, tracers)
  constvars = map(self.getvar, map(self.instantiate_const, consts))
  outvars = map(self.makevar, result_tracers)

  # Constants become leading inputs of the equation, each with a `None`
  # in_axes entry; the thunk is replaced by the concrete `out_axes`.
  eqn_params = dict(params)
  eqn_params.pop('out_axes_thunk')
  eqn_params.update(
      in_axes=(None,) * len(consts) + in_axes,
      out_axes=out_axes,
      call_jaxpr=convert_constvars_jaxpr(jaxpr),
  )

  param_updater = call_param_updaters.get(primitive)
  if param_updater:
    eqn_params = param_updater(eqn_params, [True] * len(tracers))

  self.frame.eqns.append(
      new_jaxpr_eqn([*constvars, *invars], outvars, primitive,
                    eqn_params, src_info))
  return result_tracers
Esempio n. 3
0
File: maps.py Progetto: xxuffei/jax
def _dynamic_jaxpr_process_xmap(self, primitive, f, tracers, params):
  """Record an ``xmap`` call as a single equation in the current frame.

  ``f`` is traced to a jaxpr on the input avals after
  ``_delete_aval_axes`` has been applied to them, while the axis
  environment is extended with ``params['global_axis_sizes']``. Output
  avals are rebuilt with ``_insert_aval_axes`` using per-axis local sizes
  obtained through ``to_local`` on the axis resource counts. The traced
  jaxpr is attached to a new equation as ``call_jaxpr`` and that equation
  is appended to ``self.frame.eqns``.

  Args:
    self: The trace object; its ``main``, ``frame``, ``getvar``,
      ``makevar`` and ``instantiate_const`` members are used.
    primitive: The primitive being processed; asserted to be ``xmap_p``.
    f: The function handed to ``trace_to_subjaxpr_dynamic``.
    tracers: Input tracers, each exposing an ``aval`` attribute.
    params: Parameters of the call. The keys ``'global_axis_sizes'``,
      ``'in_axes'``, ``'out_axes_thunk'``, ``'axis_resources'``,
      ``'resource_env'`` and ``'donated_invars'`` are read.

  Returns:
    A list of ``DynamicJaxprTracer`` objects, one per output aval.
  """
  from jax.interpreters.partial_eval import (
      DynamicJaxprTracer,
      convert_constvars_jaxpr,
      new_jaxpr_eqn,
      source_info_util,
      trace_to_subjaxpr_dynamic,
  )
  assert primitive is xmap_p

  arg_avals = [tracer.aval for tracer in tracers]
  global_axis_sizes = params['global_axis_sizes']
  body_in_avals = [_delete_aval_axes(aval, axes)
                   for aval, axes in zip(arg_avals, params['in_axes'])]

  # Trace the body with the xmap axis names in scope.
  with core.extend_axis_env_nd(global_axis_sizes.items()):
    jaxpr, body_out_avals, consts = trace_to_subjaxpr_dynamic(
        f, self.main, body_in_avals)

  # The thunk is only called after tracing has finished.
  out_axes = params['out_axes_thunk']()

  # Translate every global axis size into its local counterpart.
  resource_counts = _get_axis_resource_count(params['axis_resources'],
                                             params['resource_env'])
  local_axis_sizes = {
      name: resource_counts[name].to_local(size)
      for name, size in global_axis_sizes.items()
  }
  result_avals = [_insert_aval_axes(aval, axes, local_axis_sizes)
                  for aval, axes in zip(body_out_avals, out_axes)]
  _check_out_avals_vs_out_axes(result_avals, out_axes, global_axis_sizes)

  src_info = source_info_util.current()
  result_tracers = [DynamicJaxprTracer(self, aval, src_info)
                    for aval in result_avals]

  # NOTE(review): `map`/`zip` may be rebound at module level; the calls
  # and their order are kept exactly as written.
  invars = map(self.getvar, tracers)
  constvars = map(self.getvar, map(self.instantiate_const, consts))
  outvars = map(self.makevar, result_tracers)

  # Constants become leading inputs of the equation: each gets an
  # `AxisNamePos(user_repr='{}')` in_axes entry and a `False`
  # donated_invars entry. The thunk is replaced by the concrete `out_axes`.
  num_consts = len(consts)
  const_in_axes = (AxisNamePos(user_repr='{}'),) * num_consts
  const_donated = (False,) * num_consts
  eqn_params = dict(params)
  eqn_params.update(
      in_axes=const_in_axes + params['in_axes'],
      out_axes=out_axes,
      donated_invars=const_donated + params['donated_invars'],
      call_jaxpr=convert_constvars_jaxpr(jaxpr),
  )
  eqn_params.pop('out_axes_thunk')

  self.frame.eqns.append(
      new_jaxpr_eqn([*constvars, *invars], outvars, primitive,
                    eqn_params, src_info))
  return result_tracers