Example #1
0
def _rewrite_jaxpr(jaxpr: core.Jaxpr, has_input_token: bool,
                   has_output_token: bool) -> core.Jaxpr:
  """Return a Jaxpr in which a token is threaded through outfeed equations.

  Args:
    jaxpr: the Jaxpr to rewrite.
    has_input_token: if true, a token variable is appended to the input
      variables of the result; otherwise a token is created inside the
      result by a `lax.create_token_p` equation.
    has_output_token: if true, the last token variable is appended to the
      output variables of the result. Requires `has_input_token`.

  Returns:
    `jaxpr` itself when no input token is expected and
    `xla.jaxpr_uses_outfeed(jaxpr)` is false; otherwise a new `core.Jaxpr`
    with the same constvars and the rewritten invars, outvars and equations.
  """
  assert has_input_token or not has_output_token

  # Nothing to thread: no incoming token and no equation uses the outfeed.
  if not (has_input_token or xla.jaxpr_uses_outfeed(jaxpr)):
    return jaxpr

  fresh_var = core.gensym([jaxpr])

  # `token_var` always names the most recent token; it starts out as the
  # incoming (or freshly created) token.
  token_var = fresh_var(core.abstract_token)
  new_eqns: List[core.JaxprEqn] = []
  if has_input_token:
    new_invars = jaxpr.invars + [token_var]
  else:
    new_invars = jaxpr.invars
    # NOTE(review): the token-creating equation takes the first invar as its
    # input, so this indexing presumably requires a non-empty `jaxpr.invars`.
    create_token_eqn = core.new_jaxpr_eqn(
        [jaxpr.invars[0]], [token_var], lax.create_token_p, {},
        source_info_util.current())
    new_eqns.append(create_token_eqn)

  for eqn in jaxpr.eqns:
    if xla.primitive_uses_outfeed(eqn.primitive, eqn.params):
      # The rewritten equation consumes the current token and produces the
      # next one, which later outfeed equations will consume in turn.
      next_token_var = fresh_var(core.abstract_token)
      _rewrite_eqn(eqn, new_eqns, token_var, next_token_var, fresh_var)
      token_var = next_token_var
    else:
      new_eqns.append(eqn)

  extra_outvars = [token_var] if has_output_token else []
  return core.Jaxpr(jaxpr.constvars, new_invars,
                    jaxpr.outvars + extra_outvars, new_eqns)
Example #2
0
 def _bound_output_tracers(self, primitive, params, jaxpr, consts, env,
                           in_tracers, out_pvs, out_consts, out_keys, name,
                           is_map):
     """Create output tracers that share one equation recipe for `jaxpr`.

     The constvars of `jaxpr` are converted with
     `pe.convert_constvars_jaxpr`, `consts` and `env` are turned into
     tracers, and a single recipe for `primitive` is built whose inputs are
     the const tracers, then the env tracers, then `in_tracers`.

     The recipe parameters are `params` with `name` and `call_jaxpr`
     overridden. When `is_map` is true, `mapped_invars` is set to True for
     consts, False for env values and True for `in_tracers`. When `params`
     has `donated_invars`, it is replaced by False for every const and env
     input followed by the original flags of the `in_tracers` whose `pval`
     is not known.

     Returns:
       A list with one `UnzipTracer` per `(out_pvs, out_consts, out_keys)`
       triple, each with its `recipe` attribute set to the shared recipe.
     """
     call_jaxpr = pe.convert_constvars_jaxpr(jaxpr)
     const_tracers = safe_map(self.new_instantiated_const, consts)
     env_tracers = safe_map(self.instantiate_const, env)
     num_consts, num_env = len(const_tracers), len(env_tracers)

     out_tracers = []
     for pv, const, key in safe_zip(out_pvs, out_consts, out_keys):
         out_tracers.append(
             UnzipTracer(self, pe.PartialVal((pv, const)), None, key))

     new_params = dict(params, name=name, call_jaxpr=call_jaxpr)
     if is_map:
         new_params['mapped_invars'] = (
             (True,) * num_consts + (False,) * num_env +
             (True,) * len(in_tracers))
     if 'donated_invars' in params:
         # Keep the donation flag only for inputs whose value is not known.
         kept_donations = tuple(
             donated
             for donated, tracer in zip(params['donated_invars'], in_tracers)
             if not tracer.pval.is_known())
         new_params['donated_invars'] = (
             (False,) * (num_consts + num_env) + kept_donations)

     recipe = pe.new_eqn_recipe(
         tuple(const_tracers + env_tracers + in_tracers), out_tracers,
         primitive, new_params,
         source_info_util.current())  # pytype: disable=wrong-arg-types
     for tracer in out_tracers:
         tracer.recipe = recipe
     return out_tracers
Example #3
0
def _axis_index_bind(*, axis_name):
    """Return a tracer standing for the index along `axis_name`.

    The frame for `axis_name` is looked up in the thread-local dynamic axis
    environment, and a `pe.JaxprTracer` with an unknown scalar int32 value is
    created in that frame's `pmap_trace`. Its recipe is an `axis_index_p`
    equation with no inputs and `axis_name` as its only parameter.
    """
    axis_frame = pxla._thread_local_state.dynamic_axis_env[axis_name]
    pmap_trace = axis_frame.pmap_trace

    unknown_index = pe.PartialVal.unknown(ShapedArray((), np.int32))
    index_tracer = pe.JaxprTracer(pmap_trace, unknown_index, None)
    # The tracer is created first because the recipe lists it as its output.
    index_tracer.recipe = pe.new_eqn_recipe(
        [], [index_tracer], axis_index_p, {'axis_name': axis_name},
        source_info_util.current())
    return index_tracer
Example #4
0
    def default_process_primitive(self, primitive, tracers, params):
        """Partially evaluate `primitive`, recording equation recipes.

        If no input tracer has a partial value (every `pv` is None), the
        primitive is bound directly on the constants and that result is
        returned. Otherwise every input is instantiated, the output avals
        are obtained from `primitive.abstract_eval`, and one `UnzipTracer`
        per output is created, all sharing a single equation recipe.

        If all inputs are keys, the primitive is `harvest.sow_p` and its
        `tag` parameter equals the tag in the dynamic context's settings, a
        `VariableRecipe` is also attached to every output tracer.

        Returns:
          The list of output tracers when `primitive.multiple_results` is
          true, otherwise the single output tracer.
        """
        pvs, consts = jax_util.unzip2(t.pval for t in tracers)
        if not any(pv is not None for pv in pvs):
            return primitive.bind(*consts, **params)

        settings = trace_util.get_dynamic_context(self).settings
        tracers = safe_map(self.instantiate_const, tracers)
        assert all(isinstance(t, UnzipTracer) for t in tracers)

        key = all(t.is_key() for t in tracers)
        multiple_results = primitive.multiple_results
        abstract_out = primitive.abstract_eval(*[t.aval for t in tracers],
                                               **params)
        out_avals = abstract_out if multiple_results else [abstract_out]

        out_tracers = []
        for aval in out_avals:
            out_tracers.append(
                UnzipTracer(self, pe.PartialVal((aval, jax_core.unit)), None,
                            key))

        # Passing in UnzipTracer, which pytype does not recognize as
        # JaxprTracer.
        recipe = pe.new_eqn_recipe(
            tracers, out_tracers, primitive, params,
            source_info_util.current())  # pytype: disable=wrong-arg-types
        for out_tracer in out_tracers:
            out_tracer.recipe = recipe

        # This is the main difference from pe.JaxprTrace: instead of only
        # returning out_tracers, UnzipTrace also records a VariableRecipe on
        # them, which is used once tracing is complete to construct the
        # init/apply Jaxprs.
        if (key and primitive is harvest.sow_p
                and params['tag'] == settings.tag):
            name, var_in_tracers, var_out_tracers = unzip_registry[primitive](
                tracers, out_tracers, **params)
            variable_recipe = VariableRecipe(name, var_in_tracers,
                                             var_out_tracers)
            for out_tracer in out_tracers:
                out_tracer.variable_recipe = variable_recipe

        return out_tracers if multiple_results else out_tracers[0]