Exemplo n.º 1
0
def _rewrite_jaxpr(jaxpr: core.Jaxpr, has_input_token: bool,
                   has_output_token: bool) -> Tuple[core.Jaxpr, bool]:
    """Rewrite a Jaxpr to thread the token, if needed.

    Args:
      jaxpr: the Jaxpr to rewrite.
      has_input_token: if True, a fresh token variable is appended to the
        invars and used as the initial token. If False, and a token is
        needed, one is created with ``lax.create_token_p`` bound to the
        first invar (so that path requires at least one invar).
      has_output_token: if True, the final token variable is appended to the
        outvars. Requires ``has_input_token``.

    Returns:
      A pair ``(new_jaxpr, rewritten)``. When no token is needed
      (``has_input_token`` is False and the jaxpr does not use outfeed), the
      original ``jaxpr`` is returned unchanged with ``rewritten=False``.
    """
    assert has_input_token or not has_output_token

    if not has_input_token and not xla.jaxpr_uses_outfeed(jaxpr):
        return (jaxpr, False)
    # Fresh variable ids start just above the largest count reported by
    # `_jaxpr_var_defs`. `default=-1` keeps this working for a jaxpr that
    # defines no variables at all; a bare max() raises ValueError there.
    max_var_count = max(_jaxpr_var_defs(jaxpr), default=-1)
    mk_new_id = itertools.count(start=max_var_count + 1)

    def mk_new_var(aval: core.AbstractValue) -> core.Var:
        return core.Var(next(mk_new_id), '', aval)

    eqns: List[core.JaxprEqn] = []
    # The token to be consumed by the next outfeed-using equation.
    last_token_var = mk_new_var(core.abstract_token)
    if has_input_token:
        # The incoming token becomes an extra trailing input.
        invars = jaxpr.invars + [last_token_var]
    else:
        # No incoming token: create one, tied into the trace via the first
        # invar.
        invars = jaxpr.invars
        eqns.append(
            core.new_jaxpr_eqn([jaxpr.invars[0]], [last_token_var],
                               lax.create_token_p, {}))

    for eqn in jaxpr.eqns:
        if not xla.primitive_uses_outfeed(eqn.primitive, eqn.params):
            eqns.append(eqn)
        else:
            # _rewrite_eqn receives the current token and a fresh output token
            # variable; later equations thread from the latter.
            output_token_var = mk_new_var(core.abstract_token)
            _rewrite_eqn(eqn, eqns, last_token_var, output_token_var)
            last_token_var = output_token_var

    outvars = jaxpr.outvars + ([last_token_var] if has_output_token else [])
    return (core.Jaxpr(jaxpr.constvars, invars, outvars, eqns), True)
Exemplo n.º 2
0
def _rewrite_jaxpr(jaxpr: core.Jaxpr, has_input_token: bool,
                   has_output_token: bool) -> core.Jaxpr:
  """Return `jaxpr` with a token threaded through its outfeed equations.

  If `has_input_token` is False and the jaxpr does not use outfeed, `jaxpr`
  itself is returned. Otherwise a new Jaxpr is built in which:
    * the initial token is an extra trailing invar (`has_input_token=True`),
      or is produced by a `create_token_p` eqn bound to the first invar;
    * every outfeed-using eqn is rewritten by `_rewrite_eqn` to consume the
      current token and produce a fresh one;
    * the final token is appended to the outvars if `has_output_token`.
  """
  assert has_input_token or not has_output_token

  needs_rewrite = has_input_token or xla.jaxpr_uses_outfeed(jaxpr)
  if not needs_rewrite:
    return jaxpr

  fresh_var = core.gensym([jaxpr])

  new_eqns: List[core.JaxprEqn] = []
  token = fresh_var(core.abstract_token)  # the incoming / initial token
  if has_input_token:
    new_invars = jaxpr.invars + [token]
  else:
    new_invars = jaxpr.invars
    create_eqn = core.new_jaxpr_eqn([jaxpr.invars[0]], [token],
                                    lax.create_token_p, {},
                                    source_info_util.current())
    new_eqns.append(create_eqn)

  for eqn in jaxpr.eqns:
    if xla.primitive_uses_outfeed(eqn.primitive, eqn.params):
      next_token = fresh_var(core.abstract_token)
      _rewrite_eqn(eqn, new_eqns, token, next_token, fresh_var)
      token = next_token
    else:
      new_eqns.append(eqn)

  extra_outvars = [token] if has_output_token else []
  return core.Jaxpr(jaxpr.constvars, new_invars,
                    jaxpr.outvars + extra_outvars, new_eqns)
Exemplo n.º 3
0
def rearrange_binders(jaxpr: core.ClosedJaxpr, primals_in, tangents_in, primals_out, tangents_out):
  """Return a ClosedJaxpr whose binders are permuted by `_perm`.

  Invars are reordered with `_perm(primals_in, tangents_in, ...)` and outvars
  with `_perm(primals_out, tangents_out, ...)`. Constvars, equations, effects
  and consts are carried over unchanged.
  """
  inner = jaxpr.jaxpr
  permuted_in = _perm(primals_in, tangents_in, inner.invars)
  permuted_out = _perm(primals_out, tangents_out, inner.outvars)
  reordered = core.Jaxpr(inner.constvars, permuted_in, permuted_out,
                         inner.eqns, inner.effects)
  return core.ClosedJaxpr(reordered, jaxpr.consts)
Exemplo n.º 4
0
def ignore_errors_jaxpr(jaxpr, error):
    """Constructs a jaxpr which takes two extra args but ignores them.

    The result has two new leading invars, with the (shaped) avals of
    `error.err` and `error.code`. No equation references them. Constvars,
    outvars, equations and consts are taken from the input ClosedJaxpr.
    """
    # NOTE(review): `effects` of the inner jaxpr are not passed to core.Jaxpr
    # here; confirm whether this JAX version's core.Jaxpr tracks effects.
    err_aval = core.raise_to_shaped(core.get_aval(error.err))
    code_aval = core.raise_to_shaped(core.get_aval(error.code))
    inner = jaxpr.jaxpr
    fresh = core.gensym([inner])
    err_var = fresh(err_aval)
    code_var = fresh(code_aval)
    widened = core.Jaxpr(inner.constvars, (err_var, code_var, *inner.invars),
                         inner.outvars, inner.eqns)
    return core.ClosedJaxpr(widened, jaxpr.consts)
Exemplo n.º 5
0
    def augment_jaxpr(jaxpr, res_indices):
        """Widen `jaxpr` so that it accepts the full residual list.

        The first `len(res_indices)` invars of `jaxpr` are its own residuals.
        They are substituted into `all_res_vars` (from the enclosing scope) at
        the positions in `res_indices` via `util.subvals`. The widened
        residual list is then followed by the jaxpr's remaining invars.
        """
        inner = jaxpr.jaxpr
        split = len(res_indices)
        own_res = inner.invars[:split]
        rest = inner.invars[split:]
        widened_res = list(
            util.subvals(all_res_vars, zip(res_indices, own_res)))
        widened = core.Jaxpr(inner.constvars, widened_res + rest,
                             inner.outvars, inner.eqns, inner.effects)
        return core.ClosedJaxpr(widened, jaxpr.consts)
Exemplo n.º 6
0
def _prune_unused_inputs(
    jaxpr: core.Jaxpr) -> Tuple[core.Jaxpr, Set[int], Set[int]]:
  """Drop constvars and invars that are never read.

  A variable counts as used if it is one of `jaxpr.outvars` or an input of
  any equation. Only `core.Var` instances are tracked, so literals are
  ignored. Equations are kept unchanged.

  Returns:
    `(new_jaxpr, kept_const_idx, kept_var_idx)`. The two sets hold the
    original positions of the surviving constvars and invars.
  """
  # TODO(zhangqiaorjc): Improve the DCE algorithm by also pruning primitive
  # applications that do not produce used outputs. Must handle side-effecting
  # primitives and nested jaxpr.
  referenced = list(jaxpr.outvars)
  for eqn in jaxpr.eqns:
    referenced.extend(eqn.invars)
  used = {v for v in referenced if isinstance(v, core.Var)}

  def keep_used(binders):
    # Pairs (original_index, var) for each binder that is used, unzipped.
    return util.unzip2((i, v) for i, v in enumerate(binders) if v in used)

  kept_const_idx, new_constvars = keep_used(jaxpr.constvars)
  kept_var_idx, new_invars = keep_used(jaxpr.invars)
  pruned = core.Jaxpr(new_constvars, new_invars, jaxpr.outvars, jaxpr.eqns)
  return pruned, set(kept_const_idx), set(kept_var_idx)
Exemplo n.º 7
0
def tie_the_knot(typed_jaxpr):
    """Feed a TypedJaxpr's outputs back into its own inputs.

    Invars and outvars are paired positionally. Each equation input that is
    one of the invars is replaced by the paired outvar. The result is a
    TypedJaxpr with no inputs. The input avals must equal the output avals
    pairwise (asserted).
    """
    jaxpr, _, in_avals, out_avals = typed_jaxpr
    assert all(a == b for a, b in zip(in_avals, out_avals))
    feedback = dict(zip(jaxpr.invars, jaxpr.outvars))

    def rewire(v):
        if isinstance(v, jc.Var) and v in feedback:
            return feedback[v]
        return v

    rewired_eqns = [
        jc.JaxprEqn([rewire(v) for v in e.invars], e.outvars, e.primitive,
                    e.params, e.source_info)
        for e in jaxpr.eqns
    ]
    closed = jc.Jaxpr(jaxpr.constvars, [], jaxpr.outvars, rewired_eqns)
    return jc.TypedJaxpr(closed, typed_jaxpr.literals, [],
                         typed_jaxpr.out_avals)
Exemplo n.º 8
0
def _rewrite_while_outfeed_cond(eqn: core.JaxprEqn, eqns: List[core.JaxprEqn],
                                input_token_var: core.Var,
                                output_token_var: core.Var,
                                mk_new_var: Callable):
  """Rewrite a while whose cond has outfeed.

  The original cond is evaluated once before the loop ("cond_before"),
  producing an initial predicate and token. It is evaluated again at the end
  of every body iteration ("cond_body"). The new loop's own cond therefore
  just returns a predicate carried in the loop state. Equations are appended
  to `eqns`. The final while_p eqn consumes the token produced from
  `input_token_var` and defines `eqn.outvars` plus `output_token_var`.

  Args:
    eqn: the original while equation. Its params must contain `cond_jaxpr`,
      `cond_nconsts`, `body_jaxpr` and `body_nconsts`.
    eqns: list of equations under construction; new eqns are appended.
    input_token_var: token passed to the pre-loop cond evaluation.
    output_token_var: variable bound to the token after the loop.
    mk_new_var: factory called with an aval to create fresh variables.
  """
  cond_jaxpr, cond_nconsts, body_jaxpr, body_nconsts = util.split_dict(
      eqn.params, ["cond_jaxpr", "cond_nconsts", "body_jaxpr", "body_nconsts"])
  # Token-threaded version of the cond (rewritten with an input and an output
  # token). It is used both before the loop and inside the new body.
  transformed_cond_jaxpr = _rewrite_typed_jaxpr(cond_jaxpr, True, True)
  # eqn.invars layout: [cond consts][body consts][loop carry].
  carry_invars = eqn.invars[cond_nconsts + body_nconsts:]
  # pred1, token1 = rewrite(COND)(cond_consts, carry_invars, input_token)
  pred1_and_token1 = [
      mk_new_var(ov.aval) for ov in transformed_cond_jaxpr.jaxpr.outvars
  ]
  eqns.append(
      core.new_jaxpr_eqn(
          eqn.invars[0:cond_nconsts] + carry_invars + [input_token_var],
          pred1_and_token1, xla.xla_call_p,
          dict(
              call_jaxpr=transformed_cond_jaxpr.jaxpr,
              name="cond_before",
              donated_invars=(False,) * len(transformed_cond_jaxpr.in_avals)),
          eqn.source_info))
  # Make a new cond "lambda pred, carry, token: pred"
  # (the carry and token invars are fresh and unused; only pred is returned).
  new_cond_pred_invar = mk_new_var(cond_jaxpr.out_avals[0])
  new_cond_invars = ([new_cond_pred_invar] +
                     [mk_new_var(cv.aval) for cv in carry_invars] +
                     [mk_new_var(core.abstract_token)])
  new_cond_jaxpr = _mk_typed_jaxpr(
      core.Jaxpr([], new_cond_invars, [new_cond_pred_invar], []), [])
  # Make a new body:
  #   "lambda cond_constvars, body_constvars, pred, carry, token:
  #        carry2, token2 = rewrite(BODY)(body_constvars, carry, token)
  #        pred2, token3 = rewrite(COND)(cond_constvars, carry2, token2)
  #        (pred2, carry2, token3)
  transformed_body_jaxpr = _rewrite_typed_jaxpr(body_jaxpr, True, True)
  # Fresh binders for the new body's inputs, in the order listed above.
  new_body_invars_cond_constvars = [
      mk_new_var(v.aval) for v in eqn.invars[0:cond_nconsts]
  ]
  new_body_invars_body_constvars = [
      mk_new_var(v.aval)
      for v in eqn.invars[cond_nconsts:cond_nconsts + body_nconsts]
  ]
  new_body_invars_pred = mk_new_var(cond_jaxpr.out_avals[0])
  new_body_invars_carry = [mk_new_var(cv.aval) for cv in carry_invars]
  new_body_invars_token = mk_new_var(core.abstract_token)

  # Intermediate and output variables of the new body.
  new_body_carry2 = [mk_new_var(cv.aval) for cv in carry_invars]
  new_body_token2 = mk_new_var(core.abstract_token)
  new_body_pred2 = mk_new_var(cond_jaxpr.out_avals[0])
  new_body_token3 = mk_new_var(core.abstract_token)

  new_body_eqns = [
      core.new_jaxpr_eqn(
          new_body_invars_body_constvars + new_body_invars_carry +
          [new_body_invars_token], new_body_carry2 + [new_body_token2],
          xla.xla_call_p,
          dict(
              call_jaxpr=transformed_body_jaxpr.jaxpr,
              name="body",
              donated_invars=(False,) * len(transformed_body_jaxpr.in_avals)),
          eqn.source_info),
      core.new_jaxpr_eqn(
          new_body_invars_cond_constvars + new_body_carry2 + [new_body_token2],
          [new_body_pred2, new_body_token3], xla.xla_call_p,
          dict(
              call_jaxpr=transformed_cond_jaxpr.jaxpr,
              name="cond_body",
              donated_invars=(False,) * len(transformed_cond_jaxpr.in_avals)),
          eqn.source_info)
  ]
  new_body_jaxpr = _mk_typed_jaxpr(
      core.Jaxpr([], (new_body_invars_cond_constvars +
                      new_body_invars_body_constvars + [new_body_invars_pred] +
                      new_body_invars_carry + [new_body_invars_token]),
                 ([new_body_pred2] + new_body_carry2 + [new_body_token3]),
                 new_body_eqns), [])

  # The loop state is (pred, carry..., token). The final predicate `pred_out`
  # is not used further in this function. The new cond has no consts; both
  # the cond and body consts are passed as body consts.
  pred_out = mk_new_var(cond_jaxpr.out_avals[0])
  eqns.append(
      core.new_jaxpr_eqn(
          (eqn.invars[0:cond_nconsts + body_nconsts] + [pred1_and_token1[0]] +
           carry_invars + [pred1_and_token1[1]]),
          ([pred_out] + eqn.outvars + [output_token_var]), lax.while_p,
          dict(
              cond_jaxpr=new_cond_jaxpr,
              cond_nconsts=0,
              body_jaxpr=new_body_jaxpr,
              body_nconsts=cond_nconsts + body_nconsts), eqn.source_info))
Exemplo n.º 9
0
def _append_invars(jaxpr, avals):
  """Return a copy of `jaxpr` with one new invar appended per aval.

  New variables come from `core.gensym([jaxpr])`, one per entry of `avals`,
  in order. Constvars, outvars and equations are reused from `jaxpr`.
  """
  newvar = core.gensym([jaxpr])
  # Build an explicit list. `list + map(...)` only works when the module
  # shadows the builtin `map` with a list-returning version; with the
  # builtin it raises TypeError.
  extra_invars = [newvar(aval) for aval in avals]
  return core.Jaxpr(jaxpr.constvars, [*jaxpr.invars, *extra_invars],
                    jaxpr.outvars, jaxpr.eqns)
Exemplo n.º 10
0
def _tracers_to_jaxpr(in_tracers, out_tracers):
    """Constructs Jaxpr given tracers for inputs and outputs.

    Copied from jax.interpreters.partial_eval.tracers_to_jaxpr but modified to
    raise a VariableError when unknown in_tracers are found, rather than the
    default AssertionError.

    Args:
      in_tracers: the tracers that were created for the function inputs
      out_tracers: the tracers that were output by the function.

    Returns:
      a triple of a `Jaxpr`, a sequence of constant values corresponding to
      the `constvars` in the returned Jaxpr, and a sequence of environment
      values. The vars for the environment values have been pre-pended to the
      Jaxpr's `invars`.

    Raises:
      VariableError: if an unknown input tracer is found
    """
    newvar = jax_core.gensym(None)
    # Tracer -> Var, keyed by object identity (id(t)).
    t_to_var = {}

    def getvar(t):
        # Return the Var for tracer `t`, allocating one from its aval on first use.
        var = t_to_var.get(id(t))
        if var is None:
            var = newvar(t.pval.get_aval())
            t_to_var[id(t)] = var
        return var

    sorted_tracers = jax_util.toposort(out_tracers)
    invars = safe_map(getvar, in_tracers)
    eqns = []
    env = {}
    consts = {}
    # Constant value -> Var, keyed by id() so a constant shared by several
    # tracers gets a single constvar.
    const_to_var = {}

    def getconstvar(c):
        var = const_to_var.get(id(c))
        if var is None:
            var = newvar(jax_core.get_aval(c))
            const_to_var[id(c)] = var
        return var

    # Several output tracers may share one equation recipe; emit each eqn once.
    processed_eqn_ids = set()
    for t in sorted_tracers:
        recipe = t.recipe
        if isinstance(recipe, pe.JaxprEqnRecipe):
            if recipe.eqn_id not in processed_eqn_ids:
                eqns.append(pe.recipe_to_eqn(getvar, recipe))
                processed_eqn_ids.add(recipe.eqn_id)
        elif isinstance(recipe, pe.LambdaBinding):
            # A lambda-bound tracer must be one of the declared inputs.
            if not any(t is in_tracer for in_tracer in in_tracers):
                raise VariableError('Found unknown input tracer.')
            assert in_tracers, 'Lambda binding with no args'
        elif isinstance(recipe, pe.FreeVar):
            env[getvar(t)] = recipe.val
        elif isinstance(recipe, pe.ConstVar):
            v = t_to_var[id(t)] = getconstvar(recipe.val)
            consts[v] = recipe.val
        elif isinstance(recipe, jax_core.Literal):
            # Literals are used directly as operands instead of variables.
            t_to_var[id(t)] = recipe
        elif recipe is jax_core.unit:
            t_to_var[id(t)] = jax_core.unitvar
        else:
            raise TypeError(recipe)

    env_vars, env_vals = jax_util.unzip2(env.items())
    const_vars, const_vals = jax_util.unzip2(consts.items())
    # The env_vars are pre-pended to the invars
    jaxpr = jax_core.Jaxpr(const_vars, list(it.chain(env_vars, invars)),
                           safe_map(getvar, out_tracers), eqns)
    return jaxpr, const_vals, env_vals