Beispiel #1
0
def _rewrite_jaxpr(jaxpr: core.Jaxpr, has_input_token: bool,
                   has_output_token: bool) -> core.Jaxpr:
  """Rewrite a Jaxpr to thread the token, if needed.

  Args:
    jaxpr: the Jaxpr to rewrite.
    has_input_token: whether the rewritten Jaxpr takes the token as an extra,
      last input. If False, a token is created inside the Jaxpr with
      `create_token_p`.
    has_output_token: whether the rewritten Jaxpr returns the final token as
      an extra, last output. Requires `has_input_token`.

  Returns:
    `jaxpr` unchanged when there is no input token and it does not use
    outfeed; otherwise a new Jaxpr in which every outfeed-using equation is
    rewritten to consume and produce the token.
  """
  assert has_input_token or not has_output_token

  if not has_input_token and not xla.jaxpr_uses_outfeed(jaxpr):
    return jaxpr

  mk_new_var = core.gensym([jaxpr])

  eqns: List[core.JaxprEqn] = []
  last_token_var = mk_new_var(core.abstract_token)  # store the incoming token
  if has_input_token:
    invars = jaxpr.invars + [last_token_var]
  else:
    invars = jaxpr.invars
    # Use at most the first input as the create_token operand. Slicing instead
    # of indexing avoids an IndexError for a Jaxpr with no inputs.
    # NOTE(review): relies on create_token_p accepting zero operands (its
    # operand is documented as a dummy) — confirm for the JAX version in use.
    eqns.append(
        core.new_jaxpr_eqn(jaxpr.invars[:1], [last_token_var],
                           lax.create_token_p, {}, source_info_util.current()))

  for eqn in jaxpr.eqns:
    if not xla.primitive_uses_outfeed(eqn.primitive, eqn.params):
      eqns.append(eqn)
    else:
      # Each outfeed-using equation consumes the latest token and produces a
      # fresh one, which subsequent equations then consume.
      output_token_var = mk_new_var(core.abstract_token)
      _rewrite_eqn(eqn, eqns, last_token_var, output_token_var, mk_new_var)
      last_token_var = output_token_var

  outvars = jaxpr.outvars + ([last_token_var] if has_output_token else [])
  new_jaxpr = core.Jaxpr(jaxpr.constvars, invars, outvars, eqns)
  return new_jaxpr
Beispiel #2
0
def _rewrite_jaxpr(jaxpr: core.Jaxpr, has_input_token: bool,
                   has_output_token: bool) -> Tuple[core.Jaxpr, bool]:
    """Rewrite a Jaxpr to thread the token, if needed.

    Args:
      jaxpr: the Jaxpr to rewrite.
      has_input_token: whether the rewritten Jaxpr takes the token as an
        extra, last input. If False, a token is created inside the Jaxpr
        with `create_token_p`.
      has_output_token: whether the rewritten Jaxpr returns the final token
        as an extra, last output. Requires `has_input_token`.

    Returns:
      A pair `(new_jaxpr, rewritten)`. It is `(jaxpr, False)` when there is
      no input token and `jaxpr` does not use outfeed. Otherwise it is the
      token-threaded Jaxpr and True.
    """
    assert has_input_token or not has_output_token

    if not has_input_token and not xla.jaxpr_uses_outfeed(jaxpr):
        return (jaxpr, False)
    # Number fresh variables after the largest one already defined in
    # `jaxpr`, so they cannot collide. A Jaxpr that defines no variables
    # (e.g. a cond returning only a literal) starts numbering at 0 instead
    # of making `max` raise ValueError on an empty sequence.
    max_var_count = max(_jaxpr_var_defs(jaxpr), default=-1)
    mk_new_id = itertools.count(start=max_var_count + 1)

    def mk_new_var(aval: core.AbstractValue) -> core.Var:
        return core.Var(next(mk_new_id), '', aval)

    eqns: List[core.JaxprEqn] = []
    last_token_var = mk_new_var(
        core.abstract_token)  # store the incoming token
    if has_input_token:
        invars = jaxpr.invars + [last_token_var]
    else:
        invars = jaxpr.invars
        # Use at most the first input as the create_token operand. Slicing
        # instead of indexing avoids an IndexError for a Jaxpr with no inputs.
        # NOTE(review): relies on create_token_p accepting zero operands (its
        # operand is documented as a dummy) — confirm for this JAX version.
        eqns.append(
            core.new_jaxpr_eqn(jaxpr.invars[:1], [last_token_var],
                               lax.create_token_p, {}))

    for eqn in jaxpr.eqns:
        if not xla.primitive_uses_outfeed(eqn.primitive, eqn.params):
            eqns.append(eqn)
        else:
            # Chain the token through each outfeed-using equation in order.
            output_token_var = mk_new_var(core.abstract_token)
            _rewrite_eqn(eqn, eqns, last_token_var, output_token_var)
            last_token_var = output_token_var

    outvars = jaxpr.outvars + ([last_token_var] if has_output_token else [])
    return (core.Jaxpr(jaxpr.constvars, invars, outvars, eqns), True)
Beispiel #3
0
def _rewrite_while_outfeed_cond(eqn: core.JaxprEqn, eqns: List[core.JaxprEqn],
                                input_token_var: core.Var,
                                output_token_var: core.Var,
                                mk_new_var: Callable):
  """Rewrite a while whose cond has outfeed.

  The token-threaded cond is called (via `xla_call`) once before the loop
  ("cond_before") and again at the end of every iteration of the new body
  ("cond_body"). The new loop carries (pred, *carry, token), and its cond
  simply returns the carried pred.

  Args:
    eqn: the `while` equation to rewrite.
    eqns: list to which the replacement equations are appended.
    input_token_var: variable holding the token before the loop.
    output_token_var: variable that receives the token after the loop.
    mk_new_var: factory producing a fresh variable for an abstract value.
  """
  cond_jaxpr, cond_nconsts, body_jaxpr, body_nconsts = util.split_dict(
      eqn.params, ["cond_jaxpr", "cond_nconsts", "body_jaxpr", "body_nconsts"])
  # The cond is rewritten to both take and return the token.
  transformed_cond_jaxpr = _rewrite_typed_jaxpr(cond_jaxpr, True, True)
  # Loop-carried operands follow the cond constants and the body constants.
  carry_invars = eqn.invars[cond_nconsts + body_nconsts:]
  # pred1, token1 = rewrite(COND)(cond_consts, carry_invars, input_token)
  pred1_and_token1 = [
      mk_new_var(ov.aval) for ov in transformed_cond_jaxpr.jaxpr.outvars
  ]
  eqns.append(
      core.new_jaxpr_eqn(
          eqn.invars[0:cond_nconsts] + carry_invars + [input_token_var],
          pred1_and_token1, xla.xla_call_p,
          dict(
              call_jaxpr=transformed_cond_jaxpr.jaxpr,
              name="cond_before",
              donated_invars=(False,) * len(transformed_cond_jaxpr.in_avals)),
          eqn.source_info))
  # Make a new cond "lambda pred, carry, token: pred"
  new_cond_pred_invar = mk_new_var(cond_jaxpr.out_avals[0])
  new_cond_invars = ([new_cond_pred_invar] +
                     [mk_new_var(cv.aval) for cv in carry_invars] +
                     [mk_new_var(core.abstract_token)])
  new_cond_jaxpr = _mk_typed_jaxpr(
      core.Jaxpr([], new_cond_invars, [new_cond_pred_invar], []), [])
  # Make a new body:
  #   "lambda cond_constvars, body_constvars, pred, carry, token:
  #        carry2, token2 = rewrite(BODY)(body_constvars, carry, token)
  #        pred2, token3 = rewrite(COND)(cond_constvars, carry2, token2)
  #        (pred2, carry2, token3)
  transformed_body_jaxpr = _rewrite_typed_jaxpr(body_jaxpr, True, True)
  # Fresh input variables of the new body, in the order described above.
  new_body_invars_cond_constvars = [
      mk_new_var(v.aval) for v in eqn.invars[0:cond_nconsts]
  ]
  new_body_invars_body_constvars = [
      mk_new_var(v.aval)
      for v in eqn.invars[cond_nconsts:cond_nconsts + body_nconsts]
  ]
  new_body_invars_pred = mk_new_var(cond_jaxpr.out_avals[0])
  new_body_invars_carry = [mk_new_var(cv.aval) for cv in carry_invars]
  new_body_invars_token = mk_new_var(core.abstract_token)

  # Intermediate and output variables of the new body.
  new_body_carry2 = [mk_new_var(cv.aval) for cv in carry_invars]
  new_body_token2 = mk_new_var(core.abstract_token)
  new_body_pred2 = mk_new_var(cond_jaxpr.out_avals[0])
  new_body_token3 = mk_new_var(core.abstract_token)

  new_body_eqns = [
      # carry2, token2 = rewrite(BODY)(body_constvars, carry, token)
      core.new_jaxpr_eqn(
          new_body_invars_body_constvars + new_body_invars_carry +
          [new_body_invars_token], new_body_carry2 + [new_body_token2],
          xla.xla_call_p,
          dict(
              call_jaxpr=transformed_body_jaxpr.jaxpr,
              name="body",
              donated_invars=(False,) * len(transformed_body_jaxpr.in_avals)),
          eqn.source_info),
      # pred2, token3 = rewrite(COND)(cond_constvars, carry2, token2)
      core.new_jaxpr_eqn(
          new_body_invars_cond_constvars + new_body_carry2 + [new_body_token2],
          [new_body_pred2, new_body_token3], xla.xla_call_p,
          dict(
              call_jaxpr=transformed_cond_jaxpr.jaxpr,
              name="cond_body",
              donated_invars=(False,) * len(transformed_cond_jaxpr.in_avals)),
          eqn.source_info)
  ]
  new_body_jaxpr = _mk_typed_jaxpr(
      core.Jaxpr([], (new_body_invars_cond_constvars +
                      new_body_invars_body_constvars + [new_body_invars_pred] +
                      new_body_invars_carry + [new_body_invars_token]),
                 ([new_body_pred2] + new_body_carry2 + [new_body_token3]),
                 new_body_eqns), [])

  # The new while has no cond constants; both the original cond and body
  # constants are passed as body constants. Its initial carry is
  # (pred1, carry, token1). The final pred output (pred_out) is not used
  # further here; the carry results go to the original outvars and the final
  # token to `output_token_var`.
  pred_out = mk_new_var(cond_jaxpr.out_avals[0])
  eqns.append(
      core.new_jaxpr_eqn(
          (eqn.invars[0:cond_nconsts + body_nconsts] + [pred1_and_token1[0]] +
           carry_invars + [pred1_and_token1[1]]),
          ([pred_out] + eqn.outvars + [output_token_var]), lax.while_p,
          dict(
              cond_jaxpr=new_cond_jaxpr,
              cond_nconsts=0,
              body_jaxpr=new_body_jaxpr,
              body_nconsts=cond_nconsts + body_nconsts), eqn.source_info))
Beispiel #4
0
def _rewrite_eqn(eqn: core.JaxprEqn, eqns: List[core.JaxprEqn],
                 input_token_var: core.Var, output_token_var: core.Var,
                 mk_new_var: Callable[[core.AbstractValue], core.Var]):
  """Rewrite an `eqn` and append equations to `eqns`.

  Assume that the current token is in `input_token_var` and the resulting
  token must end in `output_token_var`.

  Args:
    eqn: the equation to rewrite.
    eqns: list to which the rewritten equation(s) are appended.
    input_token_var: variable holding the token before `eqn`.
    output_token_var: variable that must hold the token after `eqn`.
    mk_new_var: factory for fresh variables; only used when a while's cond
      uses outfeed.

  Raises:
    NotImplementedError: if the primitive of `eqn` is not supported.
  """

  def unreachable_thunk():
    # Stands in for the custom_jvp/custom_vjp thunks. Raise unconditionally:
    # an `assert False` would be stripped under `python -O` and silently
    # return None.
    raise AssertionError("Should not be reached")

  if eqn.primitive is id_tap_p:
    assert "has_token_" not in eqn.params
    eqns.append(
        core.new_jaxpr_eqn(eqn.invars + [input_token_var],
                           eqn.outvars + [output_token_var], eqn.primitive,
                           dict(eqn.params, has_token_=True),
                           eqn.source_info))
  elif eqn.primitive is lax.while_p:
    cond_jaxpr, _, body_jaxpr, _ = util.split_dict(
        eqn.params,
        ["cond_jaxpr", "cond_nconsts", "body_jaxpr", "body_nconsts"])
    if xla.jaxpr_uses_outfeed(cond_jaxpr.jaxpr):
      _rewrite_while_outfeed_cond(eqn, eqns, input_token_var, output_token_var,
                                  mk_new_var)
      return

    # The body threads the token; the cond receives it but returns only the
    # predicate.
    eqns.append(
        core.new_jaxpr_eqn(
            eqn.invars + [input_token_var], eqn.outvars + [output_token_var],
            eqn.primitive,
            dict(
                eqn.params,
                body_jaxpr=_rewrite_typed_jaxpr(body_jaxpr, True, True),
                cond_jaxpr=_rewrite_typed_jaxpr(cond_jaxpr, True,
                                                False)), eqn.source_info))
  elif eqn.primitive is lax.cond_p:
    branches, linear = util.split_dict(eqn.params, ["branches", "linear"])
    index, *operands = eqn.invars
    new_invars = [index, *operands, input_token_var]
    eqns.append(
        core.new_jaxpr_eqn(
            new_invars, eqn.outvars + [output_token_var], eqn.primitive,
            dict(
                eqn.params,
                branches=tuple(
                    _rewrite_typed_jaxpr(jaxpr, True, True)
                    for jaxpr in branches),
                linear=(*linear, False)), eqn.source_info))
  elif eqn.primitive is lax.scan_p:
    num_consts, num_carry, carry_jaxpr, linear, _, _, _ = util.split_dict(
        eqn.params,
        ["num_consts", "num_carry", "jaxpr", "linear", "reverse", "length",
         "unroll"])
    # We add the token right at the end of carry
    nr_const_and_carry = num_consts + num_carry
    new_invars = eqn.invars[0:nr_const_and_carry] + [
        input_token_var
    ] + eqn.invars[nr_const_and_carry:]
    new_jaxpr = _rewrite_typed_jaxpr(carry_jaxpr, True, True)
    # The rewrite has put the token at end, it has to be at end of carry
    new_jaxpr_invars = new_jaxpr.jaxpr.invars
    new_jaxpr_invars = (
        new_jaxpr_invars[0:nr_const_and_carry] + [new_jaxpr_invars[-1]] +
        new_jaxpr_invars[nr_const_and_carry:-1])
    new_jaxpr.jaxpr.invars = new_jaxpr_invars
    new_jaxpr.in_avals = [v.aval for v in new_jaxpr_invars]

    new_jaxpr_outvars = new_jaxpr.jaxpr.outvars
    new_jaxpr_outvars = (
        new_jaxpr_outvars[0:num_carry] + [new_jaxpr_outvars[-1]] +
        new_jaxpr_outvars[num_carry:-1])
    new_jaxpr.jaxpr.outvars = new_jaxpr_outvars
    new_jaxpr.out_avals = [v.aval for v in new_jaxpr_outvars]
    eqns.append(
        core.new_jaxpr_eqn(
            new_invars,
            # Output token is at the end of carry result
            eqn.outvars[0:num_carry] + [output_token_var] +
            eqn.outvars[num_carry:],
            eqn.primitive,
            dict(
                eqn.params,
                jaxpr=new_jaxpr,
                num_carry=num_carry + 1,
                linear=linear + (False,)),
            eqn.source_info))
  elif eqn.primitive is xla.xla_call_p:
    call_jaxpr = cast(core.Jaxpr, eqn.params["call_jaxpr"])
    eqns.append(
        core.new_jaxpr_eqn(
            eqn.invars + [input_token_var], eqn.outvars + [output_token_var],
            eqn.primitive,
            dict(
                eqn.params,
                call_jaxpr=_rewrite_jaxpr(call_jaxpr, True, True),
                # One donation flag per input; the appended token is never
                # donated.
                donated_invars=eqn.params["donated_invars"] + (False,)),
            eqn.source_info))
  elif eqn.primitive is custom_derivatives.custom_jvp_call_jaxpr_p:
    fun_jaxpr = eqn.params["fun_jaxpr"]
    new_invars = [*eqn.invars, input_token_var]
    eqns.append(
        core.new_jaxpr_eqn(
            new_invars, eqn.outvars + [output_token_var], eqn.primitive,
            dict(
                eqn.params,
                fun_jaxpr=_rewrite_typed_jaxpr(fun_jaxpr, True, True),
                jvp_jaxpr_thunk=unreachable_thunk
            ),
            eqn.source_info))
  elif eqn.primitive is custom_derivatives.custom_vjp_call_jaxpr_p:
    fun_jaxpr = eqn.params["fun_jaxpr"]
    new_invars = [*eqn.invars, input_token_var]
    eqns.append(
        core.new_jaxpr_eqn(
            new_invars, eqn.outvars + [output_token_var], eqn.primitive,
            dict(
                eqn.params,
                fun_jaxpr=_rewrite_typed_jaxpr(fun_jaxpr, True, True),
                fwd_jaxpr_thunk=unreachable_thunk,
                # The following are illegal values for the parameters, they
                # should not be needed because this rewrite is just before
                # compilation to XLA, which does not use those parameters.
                bwd="illegal param",
                out_trees="illegal param"
            ),
            eqn.source_info))
  else:
    raise NotImplementedError(f"outfeed rewrite {eqn.primitive}")
Beispiel #5
0
def _rewrite_eqn(eqn: core.JaxprEqn, eqns: List[core.JaxprEqn],
                 input_token_var: core.Var, output_token_var: core.Var):
    """Thread the token through `eqn`, appending the result to `eqns`.

    The token flowing into `eqn` is `input_token_var`; the rewritten
    equation binds the token it produces to `output_token_var`.

    Raises:
      NotImplementedError: if `eqn` is a while whose cond uses outfeed, or
        if its primitive is not handled by this rewrite.
    """
    prim = eqn.primitive
    tok_in, tok_out = input_token_var, output_token_var

    if prim is id_tap_p:
        eqns.append(
            core.new_jaxpr_eqn(eqn.invars + [tok_in],
                               eqn.outvars + [tok_out], prim, eqn.params))
        return

    if prim is lax.while_p:
        cond_jaxpr, _, body_jaxpr, _ = util.split_dict(
            eqn.params,
            ["cond_jaxpr", "cond_nconsts", "body_jaxpr", "body_nconsts"])
        if xla.jaxpr_uses_outfeed(cond_jaxpr.jaxpr):
            # TODO(necula): implement tapping from the conditional of a while
            raise NotImplementedError(
                "outfeed not supported in the conditional of a while")
        rewritten_body = _rewrite_typed_jaxpr(body_jaxpr, True, True)[0]
        # The cond receives the token but does not return one.
        rewritten_cond = _rewrite_typed_jaxpr(cond_jaxpr, True, False)[0]
        eqns.append(
            core.new_jaxpr_eqn(
                eqn.invars + [tok_in], eqn.outvars + [tok_out], prim,
                dict(eqn.params,
                     body_jaxpr=rewritten_body,
                     cond_jaxpr=rewritten_cond)))
        return

    if prim is lax.cond_p:
        true_jaxpr, false_jaxpr, linear = util.split_dict(
            eqn.params, ["true_jaxpr", "false_jaxpr", "linear"])
        n_true = len(true_jaxpr.jaxpr.invars)
        pred, true_ops, false_ops = util.split_list(eqn.invars, [1, n_true])
        # Each branch gets the token appended after its own operands.
        cond_invars = pred + true_ops + [tok_in] + false_ops + [tok_in]
        rewritten_true = _rewrite_typed_jaxpr(true_jaxpr, True, True)[0]
        rewritten_false = _rewrite_typed_jaxpr(false_jaxpr, True, True)[0]
        eqns.append(
            core.new_jaxpr_eqn(
                cond_invars, eqn.outvars + [tok_out], prim,
                dict(eqn.params,
                     true_jaxpr=rewritten_true,
                     false_jaxpr=rewritten_false,
                     linear=linear + (False, False))))
        return

    if prim is lax.scan_p:
        num_consts, num_carry, scan_body, linear, _, _ = util.split_dict(
            eqn.params, [
                "num_consts", "num_carry", "jaxpr", "linear", "reverse",
                "length"
            ])
        n_lead = num_consts + num_carry
        # The token becomes the last carry, i.e. it goes right before the xs.
        scan_invars = eqn.invars[:n_lead] + [tok_in] + eqn.invars[n_lead:]
        new_jaxpr = _rewrite_typed_jaxpr(scan_body, True, True)[0]
        # The rewrite appended the token at the very end of the inputs and
        # outputs; move it to the end of the carry instead.
        ins = new_jaxpr.jaxpr.invars
        ins = ins[:n_lead] + [ins[-1]] + ins[n_lead:-1]
        new_jaxpr.jaxpr.invars = ins
        new_jaxpr.in_avals = [v.aval for v in ins]

        outs = new_jaxpr.jaxpr.outvars
        outs = outs[:num_carry] + [outs[-1]] + outs[num_carry:-1]
        new_jaxpr.jaxpr.outvars = outs
        new_jaxpr.out_avals = [v.aval for v in outs]
        eqns.append(
            core.new_jaxpr_eqn(
                scan_invars,
                # The output token sits at the end of the carry results.
                eqn.outvars[:num_carry] + [tok_out] + eqn.outvars[num_carry:],
                prim,
                dict(eqn.params,
                     jaxpr=new_jaxpr,
                     num_carry=num_carry + 1,
                     linear=linear + (False, ))))
        return

    if prim is xla.xla_call_p:
        rewritten_call = _rewrite_jaxpr(eqn.params["call_jaxpr"], True,
                                        True)[0]
        eqns.append(
            core.new_jaxpr_eqn(
                eqn.invars + [tok_in], eqn.outvars + [tok_out], prim,
                dict(eqn.params, call_jaxpr=rewritten_call)))
        return

    raise NotImplementedError(f"outfeed rewrite {eqn.primitive}")
Beispiel #6
0
def _rewrite_eqn(eqn: core.JaxprEqn, eqns: List[core.JaxprEqn],
                 input_token_var: core.Var, output_token_var: core.Var,
                 mk_new_var: Callable[[core.AbstractValue], core.Var]):
    """Rewrite an `eqn` and append equations to `eqns`.

    Assume that the current token is in `input_token_var` and the resulting
    token must end in `output_token_var`.

    Args:
      eqn: the equation to rewrite.
      eqns: list to which the rewritten equation(s) are appended.
      input_token_var: variable holding the token before `eqn`.
      output_token_var: variable that must hold the token after `eqn`.
      mk_new_var: factory for fresh variables; only used when a while's cond
        uses outfeed.

    Raises:
      NotImplementedError: if the primitive of `eqn` is not supported.
    """
    if eqn.primitive is id_tap_p:
        eqns.append(
            core.new_jaxpr_eqn(eqn.invars + [input_token_var],
                               eqn.outvars + [output_token_var], eqn.primitive,
                               eqn.params, eqn.source_info))
    elif eqn.primitive is lax.while_p:
        cond_jaxpr, _, body_jaxpr, _ = util.split_dict(
            eqn.params,
            ["cond_jaxpr", "cond_nconsts", "body_jaxpr", "body_nconsts"])
        if xla.jaxpr_uses_outfeed(cond_jaxpr.jaxpr):
            _rewrite_while_outfeed_cond(eqn, eqns, input_token_var,
                                        output_token_var, mk_new_var)
            return

        # The body threads the token; the cond receives it but returns only
        # the predicate.
        eqns.append(
            core.new_jaxpr_eqn(
                eqn.invars + [input_token_var],
                eqn.outvars + [output_token_var], eqn.primitive,
                dict(eqn.params,
                     body_jaxpr=_rewrite_typed_jaxpr(body_jaxpr, True,
                                                     True)[0],
                     cond_jaxpr=_rewrite_typed_jaxpr(cond_jaxpr, True,
                                                     False)[0]),
                eqn.source_info))
    elif eqn.primitive is lax.cond_p:
        branches, linear = util.split_dict(eqn.params, ["branches", "linear"])
        index, *operands = eqn.invars
        new_invars = [index, *operands, input_token_var]
        eqns.append(
            core.new_jaxpr_eqn(
                new_invars, eqn.outvars + [output_token_var], eqn.primitive,
                dict(eqn.params,
                     branches=tuple(
                         _rewrite_typed_jaxpr(jaxpr, True, True)[0]
                         for jaxpr in branches),
                     linear=(*linear, False)), eqn.source_info))
    elif eqn.primitive is lax.scan_p:
        num_consts, num_carry, carry_jaxpr, linear, _, _ = util.split_dict(
            eqn.params, [
                "num_consts", "num_carry", "jaxpr", "linear", "reverse",
                "length"
            ])
        # We add the token right at the end of carry
        nr_const_and_carry = num_consts + num_carry
        new_invars = eqn.invars[0:nr_const_and_carry] + [
            input_token_var
        ] + eqn.invars[nr_const_and_carry:]
        new_jaxpr = _rewrite_typed_jaxpr(carry_jaxpr, True, True)[0]
        # The rewrite has put the token at end, it has to be at end of carry
        new_jaxpr_invars = new_jaxpr.jaxpr.invars
        new_jaxpr_invars = (new_jaxpr_invars[0:nr_const_and_carry] +
                            [new_jaxpr_invars[-1]] +
                            new_jaxpr_invars[nr_const_and_carry:-1])
        new_jaxpr.jaxpr.invars = new_jaxpr_invars
        new_jaxpr.in_avals = [v.aval for v in new_jaxpr_invars]

        new_jaxpr_outvars = new_jaxpr.jaxpr.outvars
        new_jaxpr_outvars = (new_jaxpr_outvars[0:num_carry] +
                             [new_jaxpr_outvars[-1]] +
                             new_jaxpr_outvars[num_carry:-1])
        new_jaxpr.jaxpr.outvars = new_jaxpr_outvars
        new_jaxpr.out_avals = [v.aval for v in new_jaxpr_outvars]
        eqns.append(
            core.new_jaxpr_eqn(
                new_invars,
                # Output token is at the end of carry result
                eqn.outvars[0:num_carry] + [output_token_var] +
                eqn.outvars[num_carry:],
                eqn.primitive,
                dict(eqn.params,
                     jaxpr=new_jaxpr,
                     num_carry=num_carry + 1,
                     linear=linear + (False, )),
                eqn.source_info))
    elif eqn.primitive is xla.xla_call_p:
        call_jaxpr = cast(core.Jaxpr, eqn.params["call_jaxpr"])
        new_params = dict(eqn.params,
                          call_jaxpr=_rewrite_jaxpr(call_jaxpr, True,
                                                    True)[0])
        donated_invars = eqn.params.get("donated_invars")
        if donated_invars is not None:
            # Keep one donation flag per input: the token input appended
            # below is never donated.
            new_params["donated_invars"] = tuple(donated_invars) + (False,)
        eqns.append(
            core.new_jaxpr_eqn(
                eqn.invars + [input_token_var],
                eqn.outvars + [output_token_var], eqn.primitive,
                new_params, eqn.source_info))
    else:
        raise NotImplementedError(f"outfeed rewrite {eqn.primitive}")