Exemplo n.º 1
0
def _scan_transpose(cts, *args, **kwargs):
  """Transpose rule for `scan_p`.

  Binds a scan over the transposed body `jaxpr_trans`, running in the
  opposite direction, to map the output cotangents `cts` to cotangents for
  the linear scan inputs.

  Args:
    cts: cotangents for the scan outputs: `num_carry` carry cotangents
      followed by the ys cotangents.
    *args: the scan operands, laid out as consts, init (carry), the linear
      xs, then the remaining nonlinear xs, which act as residuals `res`.
    **kwargs: scan params (forward, length, num_consts, num_carry, jaxpr,
      linear), unpacked with `split_dict`.

  Returns:
    Cotangents for consts, init and the linear xs, followed by `None` for
    each residual.

  Raises:
    NotImplementedError: if any const or carry is nonlinear, or if the linear
      xs do not form a leading prefix of the xs.
  """
  forward, length, num_consts, num_carry, jaxpr, linear = split_dict(
      kwargs, ["forward", "length", "num_consts", "num_carry", "jaxpr", "linear"])

  # we can only transpose scans for which the nonlinear values appear in xs
  # (and, per the prefix check below, only after all of the linear xs)
  consts_lin, init_lin, xs_lin = split_list(linear, [num_consts, num_carry])
  num_lin = sum(xs_lin)
  if not all(consts_lin) or not all(init_lin) or not all(xs_lin[:num_lin]):
    raise NotImplementedError

  consts, init, xs, res = split_list(args, [num_consts, num_carry, num_lin])
  # Residuals are nonlinear inputs, so their primal values must be known.
  assert not any(r is ad.undefined_primal for r in res)

  carry_avals, y_avals = split_list(jaxpr.out_avals, [num_carry])
  # Per-iteration y avals promoted by `length` (presumably adds a leading
  # stacking axis of size `length`).
  ys_avals = _map(partial(_promote_aval_rank, length), y_avals)
  ct_carry, ct_ys = split_list(cts, [num_carry])
  # Instantiate zero cotangents against their avals so every operand passed
  # to scan_p.bind below is concrete (presumably replaces symbolic zeros).
  ct_carry = _map(ad.instantiate_zeros_aval, carry_avals, ct_carry)
  ct_ys = _map(ad.instantiate_zeros_aval, ys_avals, ct_ys)
  # Const cotangents start at zero and are threaded through the transposed
  # scan as carries (see num_carry=num_consts+num_carry below).
  ct_consts = _map(ad_util.zeros_like_aval, jaxpr.in_avals[:num_consts])

  #       jaxpr :: [T d] -> [T c] -> [T a, res] -> ([T c], [T b])
  # jaxpr_trans :: [] -> [CT d, CT c] -> [CT b, res] -> ([CT d, CT c], [CT a])
  jaxpr_trans = _transpose_jaxpr(num_consts, len(res), jaxpr)
  # Every cotangent operand is linear; the residuals are not.
  linear_trans = ([True] * (len(ct_consts) + len(ct_carry) + len(ct_ys))
                  + [False] * len(res))

  # The transposed scan iterates in the opposite direction of the primal one.
  outs = scan_p.bind(
      *(ct_consts + ct_carry + ct_ys + res), forward=not forward, length=length,
      jaxpr=jaxpr_trans, num_consts=0, num_carry=num_consts+num_carry,
      linear=linear_trans)
  ct_consts, ct_init, ct_xs = split_list(outs, [num_consts, num_carry])
  return ct_consts + ct_init + ct_xs + [None] * len(res)
Exemplo n.º 2
0
def _custom_linear_solve_impl(*args, **kwargs):
  """Evaluate custom_linear_solve by running the `solve` jaxpr directly.

  The solve jaxpr is called on its constant parameters followed by the
  right-hand side `b`. The result is checked against `b` with `_check_shapes`
  (using `tree`) before being returned.
  """
  const_lengths, jaxprs, tree = split_dict(
      kwargs, ['const_lengths', 'jaxprs', 'tree'])
  params, rhs = _split_linear_solve_args(args, const_lengths)
  run_solve = core.jaxpr_as_fun(jaxprs.solve)
  solution = run_solve(*(params.solve + rhs))
  _check_shapes('solve', 'b', solution, rhs, tree)
  return solution
Exemplo n.º 3
0
def scan_bind(*args, **kwargs):
    """Validate a scan's operands and jaxpr, then bind `scan_p`.

    Checks with asserts that:
      - `linear` has one flag per operand;
      - consts and the initial carry typecheck against the jaxpr's input avals;
      - the jaxpr's carry outputs match its carry inputs;
      - the jaxpr passes `core.check_jaxpr`.
    """
    param_names = ["forward", "length", "num_consts", "num_carry", "jaxpr",
                   "linear"]
    forward, length, num_consts, num_carry, jaxpr, linear = split_dict(
        kwargs, param_names)
    consts, init, xs = split_list(args, [num_consts, num_carry])
    assert len(linear) == len(args)

    # Operand types must match the jaxpr's declared input types.
    const_avals, carry_in_avals, x_avals = split_list(
        jaxpr.in_avals, [num_consts, num_carry])
    xs_avals = _map(partial(_promote_aval_rank, length), x_avals)
    assert all(_map(typecheck, const_avals, consts))
    assert all(_map(typecheck, carry_in_avals, init))
    # The xs check is currently disabled:
    # assert all(_map(typecheck, xs_avals, xs))

    # One step of the body must preserve the loop-carried types.
    carry_out_avals, _ = split_list(jaxpr.out_avals, [num_carry])
    assert all(_map(typematch, carry_in_avals, carry_out_avals))

    # Check that the jaxpr's data flow is well formed.
    core.check_jaxpr(jaxpr.jaxpr)

    bind_params = dict(forward=forward, length=length, jaxpr=jaxpr,
                       num_consts=num_consts, num_carry=num_carry,
                       linear=linear)
    return core.Primitive.bind(scan_p, *args, **bind_params)
Exemplo n.º 4
0
def _cond_translation_rule(c, axis_env, pred, *args, **kwargs):
    """Lower cond to an XLA Conditional on builder `c`.

    `args` holds the true-branch consts and operands followed by the
    false-branch consts and operands. Each branch receives a single tuple
    operand (its consts followed by its operands). Each branch is built as a
    sub-computation that unpacks that tuple and calls the branch jaxpr.
    """
    backend = kwargs.pop("backend", None)
    true_jaxpr, false_jaxpr, true_nconsts, false_nconsts = split_dict(
        kwargs, ["true_jaxpr", "false_jaxpr", "true_nconsts", "false_nconsts"])
    # in_avals covers the consts too, so subtract them to count the operands.
    true_nops = len(true_jaxpr.in_avals) - true_nconsts
    true_consts, true_ops, false_consts, false_ops = split_list(
        args, [true_nconsts, true_nops, false_nconsts])

    def build_branch(name, branch_jaxpr, tuple_shape):
        # Builds a computation taking one tuple parameter of `tuple_shape`.
        sub = xb.make_computation_builder(name)
        packed = sub.ParameterWithShape(tuple_shape)
        unpacked = [sub.GetTupleElement(packed, idx)
                    for idx in range(len(branch_jaxpr.in_avals))]
        operand_shapes = _map(sub.GetShape, unpacked)
        branch_comp = xla.jaxpr_computation(branch_jaxpr.jaxpr, backend,
                                            axis_env, branch_jaxpr.literals,
                                            (), *operand_shapes)
        return sub.Build(sub.Call(branch_comp, unpacked))

    true_operand = c.Tuple(*(true_consts + true_ops))
    true_comp = build_branch("true_comp", true_jaxpr,
                             c.GetShape(true_operand))

    false_operand = c.Tuple(*(false_consts + false_ops))
    false_comp = build_branch("false_comp", false_jaxpr,
                              c.GetShape(false_operand))

    return c.Conditional(pred, true_operand, true_comp, false_operand,
                         false_comp)
Exemplo n.º 5
0
 def process_parametrized(self, primitive, *flat_inputs, **kwargs):
     """Apply `primitive` to inputs rebuilt from `flat_inputs` via `in_tree`.

     The output tree structure is reported back to the caller by appending it
     to `out_tree_container`. The flattened outputs are returned.
     """
     in_tree, out_tree_container = split_dict(
         kwargs, ['in_tree', 'out_tree_container'])
     structured_outputs = self._process_parametrized_nonflat(
         primitive, *tree_unflatten(in_tree, flat_inputs))
     flat_outputs, out_tree = tree_flatten(structured_outputs)
     out_tree_container.append(out_tree)
     return flat_outputs
Exemplo n.º 6
0
def _while_loop_translation_rule(c, axis_env, *args, **kwargs):
    """Lower while_p to an XLA While loop on computation builder `c`.

    `args` holds the cond consts, then the body consts, then the initial loop
    values. All of them are packed into one tuple that forms the XLA loop
    state. The consts are passed through unchanged; only the loop values are
    updated.

    When the cond jaxpr's output has a non-empty shape (the batched case):
      - the predicate is OR-reduced to a scalar for the While;
      - the body re-evaluates cond on the old loop values and combines new and
        old values with `_pred_bcast_select` (presumably keeping old values
        where the per-element predicate is false).

    Returns:
      An XLA tuple holding the final loop values.
    """
    cond_jaxpr, body_jaxpr, cond_nconsts, body_nconsts = split_dict(
        kwargs, ["cond_jaxpr", "body_jaxpr", "cond_nconsts", "body_nconsts"])
    cond_consts, body_consts, init_vals = split_list(
        args, [cond_nconsts, body_nconsts])
    # A non-scalar predicate output means the loop is batched.
    batched = bool(cond_jaxpr.out_avals[0].shape)

    # Since jaxprs don't have tuples and have multiple return values, but we need
    # the HLO While loop to take a single tuple input and output a single boolean
    # (for the cond computation) or a single tuple output (for the body
    # computation), we build XLA computations that handle the tuple munging before
    # generating a Call into the computations formed from the jaxprs.

    # Loop state tuple: (cond_consts..., body_consts..., loop values...).
    init_carry = c.Tuple(*(cond_consts + body_consts + init_vals))

    # Cond computation: call cond_jaxpr on cond consts + loop values.
    cond_c = xb.make_computation_builder("cond_computation")
    cond_carry = cond_c.ParameterWithShape(c.GetShape(init_carry))
    cond_carry_elts = [
        cond_c.GetTupleElement(cond_carry, i) for i in range(len(args))
    ]
    x, _, z = split_list(cond_carry_elts, [cond_nconsts, body_nconsts])
    cond_outs = cond_c.Call(
        xla.jaxpr_computation(cond_jaxpr.jaxpr, axis_env, cond_jaxpr.literals,
                              (), *_map(cond_c.GetShape, x + z)), x + z)
    pred = cond_c.GetTupleElement(cond_outs, 0)
    if batched:
        # Reduce the batched predicate with logical OR over all of its dims so
        # the While receives a scalar predicate.
        scalar = xla_client.Shape.array_shape(onp.dtype(onp.bool_), ())
        or_ = xla.primitive_computation(lax.or_p, scalar, scalar)
        pred = cond_c.Reduce(pred, cond_c.Constant(onp.array(False)), or_,
                             list(range(cond_jaxpr.out_avals[0].ndim)))

    # Body computation: call body_jaxpr on body consts + loop values.
    body_c = xb.make_computation_builder("body_computation")
    body_carry = body_c.ParameterWithShape(c.GetShape(init_carry))
    body_carry_elts = [
        body_c.GetTupleElement(body_carry, i) for i in range(len(args))
    ]
    x, y, z = split_list(body_carry_elts, [cond_nconsts, body_nconsts])
    body_out = body_c.Call(
        xla.jaxpr_computation(body_jaxpr.jaxpr, axis_env, body_jaxpr.literals,
                              (), *_map(body_c.GetShape, y + z)), y + z)
    new_z = [
        body_c.GetTupleElement(body_out, i) for i in range(len(init_vals))
    ]
    if batched:
        # Re-evaluate cond on the *old* loop values z and select per element
        # between the updated values new_z and z.
        body_cond_outs = body_c.Call(
            xla.jaxpr_computation(cond_jaxpr.jaxpr, axis_env,
                                  cond_jaxpr.literals, (),
                                  *_map(body_c.GetShape, x + z)), x + z)
        body_pred = body_c.GetTupleElement(body_cond_outs, 0)
        new_z = _map(partial(_pred_bcast_select, body_c, body_pred), new_z, z)
        assert _map(body_c.GetShape, new_z) == _map(body_c.GetShape,
                                                    z)  # no broadcast
    # Consts x and y pass through unchanged; only the loop values update.
    new_carry = body_c.Tuple(*(x + y + new_z))

    ans = c.While(cond_c.Build(pred), body_c.Build(new_carry), init_carry)
    ans_elts = [c.GetTupleElement(ans, i) for i in range(len(args))]
    # Drop the consts and return only the final loop values.
    _, _, z = split_list(ans_elts, [cond_nconsts, body_nconsts])
    return c.Tuple(*z)
Exemplo n.º 7
0
 def abstract_eval(self, *avals, **kwargs):
     """Abstractly evaluate the wrapped example-outputs function on `avals`.

     Tracing the flattened function fills the output-tree thunk. The tree it
     returns is appended to `out_tree_container`, and the flat outputs of the
     trace are returned.
     """
     in_tree, out_tree_container = split_dict(
         kwargs, ['in_tree', 'out_tree_container'])
     flat_fun, get_out_tree = flatten_fun_nokwargs(
         self._wrapped_example_outputs_fun, in_tree)
     # The trace must run before get_out_tree() is called: it populates it.
     _, flat_outs, _ = _instantiated_trace_to_jaxpr(flat_fun, avals)
     out_tree_container.append(get_out_tree())
     return flat_outs
Exemplo n.º 8
0
def _cond_impl(pred, *args, **kwargs):
    """Evaluate cond_p eagerly by running the branch selected by `pred`.

    `args` holds the true-branch consts and operands followed by the
    false-branch consts and operands. Each branch jaxpr is called with its
    consts followed by its operands.

    `true_jaxpr.in_avals` therefore covers both the true consts and the true
    operands. The number of true operands is `len(true_jaxpr.in_avals)` minus
    `true_nconsts`; using the full length would wrongly pull false-branch
    arguments into `true_ops` whenever `true_nconsts > 0`.
    """
    true_jaxpr, false_jaxpr, true_nconsts, false_nconsts = split_dict(
        kwargs, ["true_jaxpr", "false_jaxpr", "true_nconsts", "false_nconsts"])
    # in_avals includes the consts; subtract them to get the operand count.
    true_nops = len(true_jaxpr.in_avals) - true_nconsts
    true_consts, true_ops, false_consts, false_ops = split_list(
        args, [true_nconsts, true_nops, false_nconsts])

    if pred:
        return core.jaxpr_as_fun(true_jaxpr)(*(true_consts + true_ops))
    return core.jaxpr_as_fun(false_jaxpr)(*(false_consts + false_ops))
Exemplo n.º 9
0
def _custom_linear_solve_transpose_rule(cotangent, *primals, **kwargs):
  """Transpose rule for custom_linear_solve_p.

  Solves the transposed system by binding custom_linear_solve_p again on the
  transposed parameters and the cotangent. It uses transposed const lengths
  and jaxprs.

  Returns:
    `None` for every const operand, followed by the cotangent of `b`.

  Raises:
    TypeError: if `jaxprs.transpose_solve` is None.
  """
  const_lengths, jaxprs, tree = split_dict(
      kwargs, ['const_lengths', 'jaxprs', 'tree'])

  if jaxprs.transpose_solve is None:
    raise TypeError('transpose_solve required for backwards mode automatic '
                    'differentiation of custom_linear_solve')

  params, b = _split_linear_solve_args(primals, const_lengths)
  # Every entry of b must be an undefined primal (the linear input).
  assert b == [ad.undefined_primal] * len(b)
  transposed_operands = _flatten(params.transpose()) + cotangent
  cotangent_b = custom_linear_solve_p.bind(
      *transposed_operands,
      const_lengths=const_lengths.transpose(),
      jaxprs=jaxprs.transpose(),
      tree=tree)
  const_cotangents = [None] * sum(const_lengths)
  return const_cotangents + cotangent_b
Exemplo n.º 10
0
def _scan_partial_eval(trace, *tracers, **kwargs):
  """Partial-evaluation rule for `scan_p`.

  Splits the scan into two scans:
    - `jaxpr_1`, over the known inputs, which is bound immediately;
    - `jaxpr_2`, over the unknown inputs, which is staged out as a new jaxpr
      equation on `trace` and consumes residuals produced by the first.

  Args:
    trace: the partial-evaluation trace.
    *tracers: one tracer per scan operand (consts, init, xs). A tracer counts
      as unknown when `t.pval[0] is not None`.
    **kwargs: scan params (forward, length, num_consts, num_carry, jaxpr,
      linear), unpacked with `split_dict`.

  Returns:
    One `pe.JaxprTracer` per scan output (carry outputs, then ys). The staged
    equation is set as the recipe of each one.

  Raises:
    FixedPointError: if the carry unknowns do not stabilize within 1000
      iterations.
  """
  forward, length, num_consts, num_carry, jaxpr, linear = split_dict(
      kwargs, ["forward", "length", "num_consts", "num_carry", "jaxpr", "linear"])
  num_xs = len(jaxpr.in_avals) - num_carry - num_consts
  num_ys = len(jaxpr.out_avals) - num_carry

  # NOTE(review): original_unknowns is not used in the visible code.
  unknowns = original_unknowns = [t.pval[0] is not None for t in tracers]
  const_uk, init_uk, xs_uk = split_list(unknowns, [num_consts, num_carry])

  # Iterate until the carry-out unknowns reported by partial_eval_jaxpr equal
  # the carry-in unknowns used for that partial evaluation.
  carry_uk = init_uk
  for _ in range(1000):
    unknowns = const_uk + carry_uk + xs_uk
    jaxpr_1, jaxpr_2, out_uk = pe.partial_eval_jaxpr(
        jaxpr, unknowns, instantiate=carry_uk + [False] * num_ys)
    carry_uk_out, ys_uk = out_uk[:num_carry], out_uk[num_carry:]
    if carry_uk_out == carry_uk:
      break
    else:
      carry_uk = carry_uk_out
  else:
    raise FixedPointError

  # Known inputs feed the first scan by value; unknown ones become core.unit.
  in_consts = [core.unit if uk else t.pval[1] for uk, t in zip(unknowns, tracers)]
  # Unknown inputs become instantiated tracers for the staged scan; known ones
  # are replaced by unit literals.
  new_tracers = [trace.instantiate_const(t) if uk else trace.new_instantiated_literal(core.unit)
                 for uk, t in zip(unknowns, tracers)]

  carry_avals, y_avals = split_list(jaxpr.out_avals, [num_carry])
  ys_avals = _map(partial(_promote_aval_rank, length), y_avals)
  out_avals = carry_avals + ys_avals
  # Unknown outputs carry an abstract value; known outputs have pv None.
  out_pvs = [aval if uk else None for aval, uk in zip(out_avals, out_uk)]

  # Inputs that are units in jaxpr_1 (the unknown ones) are marked linear.
  linear_1 = [lin or uk for uk, lin in zip(unknowns, linear)]
  out_flat = scan_p.bind(
      *in_consts, forward=forward, length=length, jaxpr=jaxpr_1,
      num_consts=num_consts, num_carry=num_carry, linear=linear_1)
  # jaxpr_1 outputs: known carry, known ys, then residuals for jaxpr_2.
  out_carry, ys, residuals = split_list(out_flat, [num_carry, num_ys])
  out_consts = out_carry + ys
  residual_tracers = _map(trace.new_instantiated_const, residuals)
  out_tracers = [pe.JaxprTracer(trace, pe.PartialVal((pv, const)), None)
                 for pv, const in zip(out_pvs, out_consts)]
  # Inputs that are units in jaxpr_2 (the known ones) are marked linear;
  # the appended residuals are not.
  linear_2 = ([lin or not uk for uk, lin in zip(unknowns, linear)]
              + [False] * len(residual_tracers))
  eqn = pe.new_jaxpr_eqn(new_tracers + residual_tracers, out_tracers, scan_p,
                         (), dict(forward=forward, length=length, jaxpr=jaxpr_2,
                                  num_consts=num_consts, num_carry=num_carry,
                                  linear=linear_2))
  for t in out_tracers: t.recipe = eqn
  return out_tracers
Exemplo n.º 11
0
def _custom_cell_scan_impl(flat_cell, *args, **kwargs):
    """lax_control_flow._scan_impl, but allowing for a custom cell function.

    Retraces `flat_cell` on abstract values of (consts, init, slice 0 of each
    xs) to obtain a new jaxpr and its constants. It then binds `scan_p` with:
      - the new constants as the consts, followed by init and xs;
      - `jaxpr` and `num_consts` updated to match;
      - every operand marked nonlinear.
    All other scan params are passed through unchanged.
    """

    reverse, length, num_consts, num_carry, jaxpr, linear, unroll = split_dict(
        kwargs, ["reverse", "length", "num_consts", "num_carry", "jaxpr", "linear", "unroll"])

    consts, init, xs = split_list(args, [num_consts, num_carry])
    _, _, x_avals = split_list(jaxpr.in_avals, [num_consts, num_carry])
    # NOTE(review): if `map` here is the Python 3 builtin (rather than a
    # list-returning safe_map), `list + map(...)` raises TypeError -- confirm
    # which `map` is in scope in this module.
    cell_args = consts + init + map(partial(_index_array, 0), x_avals, xs)

    # NOTE(review): cell_args includes the original consts, but the operands
    # built below contain only new_consts + init + xs. Confirm that the jaxpr
    # returned by _flat_initial_style_jaxpr expects exactly that layout.
    jaxpr, new_consts = _flat_initial_style_jaxpr(wrap_init(flat_cell), _abstractified(cell_args))

    # Rebuild the operands and override the params describing their layout.
    args = list(new_consts) + init + xs
    kwargs['jaxpr'] = jaxpr
    kwargs['num_consts'] = len(new_consts)
    kwargs['linear'] = (False,) * len(args)

    return scan_p.bind(*args, **kwargs)
Exemplo n.º 12
0
def _scan_impl(*args, **kwargs):
  """Reference implementation of scan_p as a fori_loop.

  Each iteration at position `idx`:
    - reads slice `idx` of every xs operand;
    - runs `jaxpr` on consts + carry + those slices;
    - writes the per-step outputs into position `idx` of preallocated ys
      buffers.
  `idx` runs from 0 upward when `forward` is true, and from `length - 1`
  downward otherwise. Returns the final carry followed by the ys buffers.
  """
  forward, length, num_consts, num_carry, jaxpr, linear = split_dict(
      kwargs, ["forward", "length", "num_consts", "num_carry", "jaxpr", "linear"])

  consts, init, xs = split_list(args, [num_consts, num_carry])
  _, _, per_step_x_avals = split_list(jaxpr.in_avals, [num_consts, num_carry])
  _, per_step_y_avals = split_list(jaxpr.out_avals, [num_carry])

  def _step(counter, state):
    idx = counter if forward else length - counter - 1
    carry, ys_bufs = split_list(state, [num_carry])
    x_slices = _map(partial(_index_array, idx), per_step_x_avals, xs)
    step_outs = core.jaxpr_as_fun(jaxpr)(*(consts + carry + x_slices))
    new_carry, y_slices = split_list(step_outs, [num_carry])
    new_ys_bufs = _map(partial(_update_array, idx), per_step_y_avals,
                       ys_bufs, y_slices)
    return new_carry + new_ys_bufs

  empty_ys = _map(partial(_empty_array, length), per_step_y_avals)
  return fori_loop(0, length, _step, init + empty_ys)
Exemplo n.º 13
0
def _root_impl(*args, **kwargs):
  """Evaluate root_p by calling the user-provided `solve`.

  The first `num_consts` operands are bound as leading arguments of the jaxpr
  function. The remaining operands are the initial guess. `tangent_solve` is
  unpacked but not used here.
  """
  num_consts, jaxpr, solve, _ = split_dict(
      kwargs, ['num_consts', 'jaxpr', 'solve', 'tangent_solve'])
  consts, guess = split_list(args, [num_consts])
  objective = partial(core.jaxpr_as_fun(jaxpr), *consts)
  return solve(objective, *guess)
Exemplo n.º 14
0
def _rewrite_eqn(eqn: core.JaxprEqn, eqns: List[core.JaxprEqn],
                 input_token_var: core.Var, output_token_var: core.Var,
                 mk_new_var: Callable[[core.AbstractValue], core.Var]):
    """Rewrite an `eqn` and append equations to `eqns`.

    Assume that the current token is in `input_token_var` and the resulting
    token must end in `output_token_var`.

    The token is threaded as an extra input and output of the equation:
      - id_tap, while, cond, xla_call: token appended at the end;
      - scan: token placed at the end of the carry;
      - while whose cond uses outfeed: delegated to
        `_rewrite_while_outfeed_cond`.

    Raises:
      NotImplementedError: for any other primitive.
    """
    if eqn.primitive is id_tap_p:
        eqns.append(
            core.new_jaxpr_eqn(eqn.invars + [input_token_var],
                               eqn.outvars + [output_token_var], eqn.primitive,
                               eqn.params, eqn.source_info))
    elif eqn.primitive is lax.while_p:
        # cond_nconsts and body_nconsts are unpacked but unused in this branch.
        cond_jaxpr, cond_nconsts, body_jaxpr, body_nconsts = util.split_dict(
            eqn.params,
            ["cond_jaxpr", "cond_nconsts", "body_jaxpr", "body_nconsts"])
        if xla.jaxpr_uses_outfeed(cond_jaxpr.jaxpr):
            _rewrite_while_outfeed_cond(eqn, eqns, input_token_var,
                                        output_token_var, mk_new_var)
            return

        # The body and cond are rewritten with different flags (True, True vs
        # True, False); presumably cond takes the token but does not return it.
        eqns.append(
            core.new_jaxpr_eqn(
                eqn.invars + [input_token_var],
                eqn.outvars + [output_token_var], eqn.primitive,
                dict(eqn.params,
                     body_jaxpr=_rewrite_typed_jaxpr(body_jaxpr, True,
                                                     True)[0],
                     cond_jaxpr=_rewrite_typed_jaxpr(cond_jaxpr, True,
                                                     False)[0]),
                eqn.source_info))
    elif eqn.primitive is lax.cond_p:
        branches, linear = util.split_dict(eqn.params, ["branches", "linear"])
        index, *operands = eqn.invars
        # Token goes after the operands; linear gets a matching False entry.
        new_invars = [index, *operands, input_token_var]
        eqns.append(
            core.new_jaxpr_eqn(
                new_invars, eqn.outvars + [output_token_var], eqn.primitive,
                dict(eqn.params,
                     branches=tuple(
                         _rewrite_typed_jaxpr(jaxpr, True, True)[0]
                         for jaxpr in branches),
                     linear=(*linear, False)), eqn.source_info))
    elif eqn.primitive is lax.scan_p:
        # "reverse" and "length" are unpacked but unused here.
        num_consts, num_carry, carry_jaxpr, linear, _, _ = util.split_dict(
            eqn.params, [
                "num_consts", "num_carry", "jaxpr", "linear", "reverse",
                "length"
            ])
        # We add the token right at the end of carry
        nr_const_and_carry = num_consts + num_carry
        new_invars = eqn.invars[0:nr_const_and_carry] + [
            input_token_var
        ] + eqn.invars[nr_const_and_carry:]
        new_jaxpr = _rewrite_typed_jaxpr(carry_jaxpr, True, True)[0]
        # The rewrite has put the token at end, it has to be at end of carry
        # (new_jaxpr is mutated in place; presumably it is a fresh object).
        new_jaxpr_invars = new_jaxpr.jaxpr.invars
        new_jaxpr_invars = (new_jaxpr_invars[0:nr_const_and_carry] +
                            [new_jaxpr_invars[-1]] +
                            new_jaxpr_invars[nr_const_and_carry:-1])
        new_jaxpr.jaxpr.invars = new_jaxpr_invars
        new_jaxpr.in_avals = [v.aval for v in new_jaxpr_invars]

        new_jaxpr_outvars = new_jaxpr.jaxpr.outvars
        new_jaxpr_outvars = (new_jaxpr_outvars[0:num_carry] +
                             [new_jaxpr_outvars[-1]] +
                             new_jaxpr_outvars[num_carry:-1])
        new_jaxpr.jaxpr.outvars = new_jaxpr_outvars
        new_jaxpr.out_avals = [v.aval for v in new_jaxpr_outvars]
        # NOTE(review): the token is inserted after the carry operands, but its
        # False flag is appended to the end of `linear`. If `linear` is
        # positional, the xs flags shift by one (harmless only while they are
        # all False). Confirm.
        eqns.append(
            core.new_jaxpr_eqn(
                new_invars,
                # Output token is at the end of carry result
                eqn.outvars[0:num_carry] + [output_token_var] +
                eqn.outvars[num_carry:],
                eqn.primitive,
                dict(eqn.params,
                     jaxpr=new_jaxpr,
                     num_carry=num_carry + 1,
                     linear=linear + (False, )),
                eqn.source_info))
    elif eqn.primitive is xla.xla_call_p:
        call_jaxpr = cast(core.Jaxpr, eqn.params["call_jaxpr"])
        eqns.append(
            core.new_jaxpr_eqn(
                eqn.invars + [input_token_var],
                eqn.outvars + [output_token_var], eqn.primitive,
                dict(eqn.params,
                     call_jaxpr=_rewrite_jaxpr(call_jaxpr, True,
                                               True)[0]), eqn.source_info))
    else:
        raise NotImplementedError(f"outfeed rewrite {eqn.primitive}")
Exemplo n.º 15
0
def _rewrite_eqn(eqn: core.JaxprEqn, eqns: List[core.JaxprEqn],
                 input_token_var: core.Var, output_token_var: core.Var,
                 mk_new_var: Callable[[core.AbstractValue], core.Var]):
  """Rewrite an `eqn` and append equations to `eqns`.

  Assume that the current token is in `input_token_var` and the resulting
  token must end in `output_token_var`.

  Handled primitives:
    - id_tap: marked with `has_token_=True`;
    - while: a cond that uses outfeed is delegated to
      `_rewrite_while_outfeed_cond`;
    - cond;
    - scan: token placed at the end of the carry;
    - xla_call: the token is an extra, non-donated input;
    - custom_jvp_call_jaxpr / custom_vjp_call_jaxpr: only `fun_jaxpr` is
      rewritten; the derivative-related params are replaced by placeholders
      that must not be used.

  Raises:
    NotImplementedError: for any other primitive.
  """
  if eqn.primitive is id_tap_p:
    # The id_tap must not already carry a token.
    assert "has_token_" not in eqn.params
    eqns.append(
        core.new_jaxpr_eqn(eqn.invars + [input_token_var],
                           eqn.outvars + [output_token_var], eqn.primitive,
                           dict(eqn.params, has_token_=True),
                           eqn.source_info))
  elif eqn.primitive is lax.while_p:
    cond_jaxpr, _, body_jaxpr, _ = util.split_dict(
        eqn.params,
        ["cond_jaxpr", "cond_nconsts", "body_jaxpr", "body_nconsts"])
    if xla.jaxpr_uses_outfeed(cond_jaxpr.jaxpr):
      _rewrite_while_outfeed_cond(eqn, eqns, input_token_var, output_token_var,
                                  mk_new_var)
      return

    # Body and cond are rewritten with different flags (True, True vs
    # True, False); presumably cond takes the token but does not return it.
    eqns.append(
        core.new_jaxpr_eqn(
            eqn.invars + [input_token_var], eqn.outvars + [output_token_var],
            eqn.primitive,
            dict(
                eqn.params,
                body_jaxpr=_rewrite_typed_jaxpr(body_jaxpr, True, True),
                cond_jaxpr=_rewrite_typed_jaxpr(cond_jaxpr, True,
                                                False)), eqn.source_info))
  elif eqn.primitive is lax.cond_p:
    branches, linear = util.split_dict(eqn.params, ["branches", "linear"])
    index, *operands = eqn.invars
    # Token goes after the operands; linear gets a matching False entry.
    new_invars = [index, *operands, input_token_var]
    eqns.append(
        core.new_jaxpr_eqn(
            new_invars, eqn.outvars + [output_token_var], eqn.primitive,
            dict(
                eqn.params,
                branches=tuple(
                    _rewrite_typed_jaxpr(jaxpr, True, True)
                    for jaxpr in branches),
                linear=(*linear, False)), eqn.source_info))
  elif eqn.primitive is lax.scan_p:
    # "reverse", "length" and "unroll" are unpacked but unused here.
    num_consts, num_carry, carry_jaxpr, linear, _, _, _ = util.split_dict(
        eqn.params,
        ["num_consts", "num_carry", "jaxpr", "linear", "reverse", "length",
         "unroll"])
    # We add the token right at the end of carry
    nr_const_and_carry = num_consts + num_carry
    new_invars = eqn.invars[0:nr_const_and_carry] + [
        input_token_var
    ] + eqn.invars[nr_const_and_carry:]
    new_jaxpr = _rewrite_typed_jaxpr(carry_jaxpr, True, True)
    # The rewrite has put the token at end, it has to be at end of carry
    # (new_jaxpr is mutated in place; presumably it is a fresh object).
    new_jaxpr_invars = new_jaxpr.jaxpr.invars
    new_jaxpr_invars = (
        new_jaxpr_invars[0:nr_const_and_carry] + [new_jaxpr_invars[-1]] +
        new_jaxpr_invars[nr_const_and_carry:-1])
    new_jaxpr.jaxpr.invars = new_jaxpr_invars
    new_jaxpr.in_avals = [v.aval for v in new_jaxpr_invars]

    new_jaxpr_outvars = new_jaxpr.jaxpr.outvars
    new_jaxpr_outvars = (
        new_jaxpr_outvars[0:num_carry] + [new_jaxpr_outvars[-1]] +
        new_jaxpr_outvars[num_carry:-1])
    new_jaxpr.jaxpr.outvars = new_jaxpr_outvars
    new_jaxpr.out_avals = [v.aval for v in new_jaxpr_outvars]
    # NOTE(review): the token is inserted after the carry operands, but its
    # False flag is appended to the end of `linear`. If `linear` is
    # positional, the xs flags shift by one (harmless only while they are all
    # False). Confirm.
    eqns.append(
        core.new_jaxpr_eqn(
            new_invars,
            # Output token is at the end of carry result
            eqn.outvars[0:num_carry] + [output_token_var] +
            eqn.outvars[num_carry:],
            eqn.primitive,
            dict(
                eqn.params,
                jaxpr=new_jaxpr,
                num_carry=num_carry + 1,
                linear=linear + (False,)),
            eqn.source_info))
  elif eqn.primitive is xla.xla_call_p:
    call_jaxpr = cast(core.Jaxpr, eqn.params["call_jaxpr"])
    # The appended token input is never donated.
    eqns.append(
        core.new_jaxpr_eqn(
            eqn.invars + [input_token_var], eqn.outvars + [output_token_var],
            eqn.primitive,
            dict(
                eqn.params,
                call_jaxpr=_rewrite_jaxpr(call_jaxpr, True,
                                          True),
                donated_invars=eqn.params["donated_invars"] + (False,)
            ),
          eqn.source_info))
  elif eqn.primitive is custom_derivatives.custom_jvp_call_jaxpr_p:
    fun_jaxpr = eqn.params["fun_jaxpr"]
    new_invars = [*eqn.invars, input_token_var]
    # Placeholder for the JVP thunk; it fails loudly if ever invoked.
    def unreachable_thunk():
      assert False, "Should not be reached"
    eqns.append(
        core.new_jaxpr_eqn(
            new_invars, eqn.outvars + [output_token_var], eqn.primitive,
            dict(
                eqn.params,
                fun_jaxpr=_rewrite_typed_jaxpr(fun_jaxpr, True, True),
                jvp_jaxpr_thunk=unreachable_thunk
            ),
            eqn.source_info))
  elif eqn.primitive is custom_derivatives.custom_vjp_call_jaxpr_p:
    fun_jaxpr = eqn.params["fun_jaxpr"]
    new_invars = [*eqn.invars, input_token_var]
    # Placeholder for the forward thunk; it fails loudly if ever invoked.
    def unreachable_thunk():
      assert False, "Should not be reached"
    eqns.append(
        core.new_jaxpr_eqn(
            new_invars, eqn.outvars + [output_token_var], eqn.primitive,
            dict(
                eqn.params,
                fun_jaxpr=_rewrite_typed_jaxpr(fun_jaxpr, True, True),
                fwd_jaxpr_thunk=unreachable_thunk,
                # The following are illegal values for the parameters, they
                # should not be needed because this rewrite is just before
                # compilation to XLA, which does not use those parameters.
                bwd="illegal param",
                out_trees="illegal param"
            ),
            eqn.source_info))
  else:
    raise NotImplementedError(f"outfeed rewrite {eqn.primitive}")
Exemplo n.º 16
0
def _rewrite_while_outfeed_cond(eqn: core.JaxprEqn, eqns: List[core.JaxprEqn],
                                input_token_var: core.Var,
                                output_token_var: core.Var,
                                mk_new_var: Callable):
  """Rewrite a while whose cond has outfeed.

  Because the cond itself consumes and produces a token, the loop is
  restructured so that the predicate is computed ahead of time and carried
  in the loop state:

    pred1, token1 = rewrite(COND)(cond_consts, carry, input_token)
    pred_out, carry_out, output_token = while
        cond: lambda pred, carry, token: pred
        body: runs rewrite(BODY), then rewrite(COND) on the new carry
        init: (cond_consts, body_consts, pred1, carry, token1)

  The new while has `cond_nconsts=0` and `body_nconsts=cond_nconsts +
  body_nconsts`. Its extra leading predicate output goes to a fresh variable
  `pred_out`; the carry outputs go to `eqn.outvars` and the final token to
  `output_token_var`. Equations are appended to `eqns`.
  """
  cond_jaxpr, cond_nconsts, body_jaxpr, body_nconsts = util.split_dict(
      eqn.params, ["cond_jaxpr", "cond_nconsts", "body_jaxpr", "body_nconsts"])
  transformed_cond_jaxpr = _rewrite_typed_jaxpr(cond_jaxpr, True, True)
  carry_invars = eqn.invars[cond_nconsts + body_nconsts:]
  # pred1, token1 = rewrite(COND)(cond_consts, carry_invars, input_token)
  pred1_and_token1 = [
      mk_new_var(ov.aval) for ov in transformed_cond_jaxpr.jaxpr.outvars
  ]
  eqns.append(
      core.new_jaxpr_eqn(
          eqn.invars[0:cond_nconsts] + carry_invars + [input_token_var],
          pred1_and_token1, xla.xla_call_p,
          dict(
              call_jaxpr=transformed_cond_jaxpr.jaxpr,
              name="cond_before",
              donated_invars=(False,) * len(transformed_cond_jaxpr.in_avals)),
          eqn.source_info))
  # Make a new cond "lambda pred, carry, token: pred"
  new_cond_pred_invar = mk_new_var(cond_jaxpr.out_avals[0])
  new_cond_invars = ([new_cond_pred_invar] +
                     [mk_new_var(cv.aval) for cv in carry_invars] +
                     [mk_new_var(core.abstract_token)])
  new_cond_jaxpr = _mk_typed_jaxpr(
      core.Jaxpr([], new_cond_invars, [new_cond_pred_invar], []), [])
  # Make a new body:
  #   "lambda cond_constvars, body_constvars, pred, carry, token:
  #        carry2, token2 = rewrite(BODY)(body_constvars, carry, token)
  #        pred2, token3 = rewrite(COND)(cond_constvars, carry2, token2)
  #        (pred2, carry2, token3)
  transformed_body_jaxpr = _rewrite_typed_jaxpr(body_jaxpr, True, True)
  new_body_invars_cond_constvars = [
      mk_new_var(v.aval) for v in eqn.invars[0:cond_nconsts]
  ]
  new_body_invars_body_constvars = [
      mk_new_var(v.aval)
      for v in eqn.invars[cond_nconsts:cond_nconsts + body_nconsts]
  ]
  new_body_invars_pred = mk_new_var(cond_jaxpr.out_avals[0])
  new_body_invars_carry = [mk_new_var(cv.aval) for cv in carry_invars]
  new_body_invars_token = mk_new_var(core.abstract_token)

  new_body_carry2 = [mk_new_var(cv.aval) for cv in carry_invars]
  new_body_token2 = mk_new_var(core.abstract_token)
  new_body_pred2 = mk_new_var(cond_jaxpr.out_avals[0])
  new_body_token3 = mk_new_var(core.abstract_token)

  new_body_eqns = [
      core.new_jaxpr_eqn(
          new_body_invars_body_constvars + new_body_invars_carry +
          [new_body_invars_token], new_body_carry2 + [new_body_token2],
          xla.xla_call_p,
          dict(
              call_jaxpr=transformed_body_jaxpr.jaxpr,
              name="body",
              donated_invars=(False,) * len(transformed_body_jaxpr.in_avals)),
          eqn.source_info),
      core.new_jaxpr_eqn(
          new_body_invars_cond_constvars + new_body_carry2 + [new_body_token2],
          [new_body_pred2, new_body_token3], xla.xla_call_p,
          dict(
              call_jaxpr=transformed_cond_jaxpr.jaxpr,
              name="cond_body",
              donated_invars=(False,) * len(transformed_cond_jaxpr.in_avals)),
          eqn.source_info)
  ]
  # The incoming pred (new_body_invars_pred) is not read by the body; the
  # next predicate is recomputed as new_body_pred2.
  new_body_jaxpr = _mk_typed_jaxpr(
      core.Jaxpr([], (new_body_invars_cond_constvars +
                      new_body_invars_body_constvars + [new_body_invars_pred] +
                      new_body_invars_carry + [new_body_invars_token]),
                 ([new_body_pred2] + new_body_carry2 + [new_body_token3]),
                 new_body_eqns), [])

  # The final predicate is an extra output of the new while; it goes to a
  # fresh variable that nothing else in this function reads.
  pred_out = mk_new_var(cond_jaxpr.out_avals[0])
  eqns.append(
      core.new_jaxpr_eqn(
          (eqn.invars[0:cond_nconsts + body_nconsts] + [pred1_and_token1[0]] +
           carry_invars + [pred1_and_token1[1]]),
          ([pred_out] + eqn.outvars + [output_token_var]), lax.while_p,
          dict(
              cond_jaxpr=new_cond_jaxpr,
              cond_nconsts=0,
              body_jaxpr=new_body_jaxpr,
              body_nconsts=cond_nconsts + body_nconsts), eqn.source_info))
Exemplo n.º 17
0
def _rewrite_eqn(eqn: core.JaxprEqn, eqns: List[core.JaxprEqn],
                 input_token_var: core.Var, output_token_var: core.Var):
    """Rewrite an `eqn` and append equations to `eqns`.

    Assume that the current token is in `input_token_var` and the resulting
    token must end in `output_token_var`.

    The token is threaded as an extra input and output of the equation:
      - id_tap, while, xla_call: token appended at the end;
      - cond: one token input per branch;
      - scan: token placed at the end of the carry.

    Raises:
      NotImplementedError: for a while whose cond uses outfeed, or for any
        other primitive.
    """
    if eqn.primitive is id_tap_p:
        eqns.append(
            core.new_jaxpr_eqn(eqn.invars + [input_token_var],
                               eqn.outvars + [output_token_var], eqn.primitive,
                               eqn.params))
    elif eqn.primitive is lax.while_p:
        # cond_nconsts and body_nconsts are unpacked but unused in this branch.
        cond_jaxpr, cond_nconsts, body_jaxpr, body_nconsts = util.split_dict(
            eqn.params,
            ["cond_jaxpr", "cond_nconsts", "body_jaxpr", "body_nconsts"])
        if xla.jaxpr_uses_outfeed(cond_jaxpr.jaxpr):
            # TODO(necula): implement tapping from the conditional of a while
            raise NotImplementedError(
                "outfeed not supported in the conditional of a while")

        # Body and cond are rewritten with different flags (True, True vs
        # True, False); presumably cond takes the token but does not return it.
        eqns.append(
            core.new_jaxpr_eqn(
                eqn.invars + [input_token_var],
                eqn.outvars + [output_token_var], eqn.primitive,
                dict(eqn.params,
                     body_jaxpr=_rewrite_typed_jaxpr(body_jaxpr, True,
                                                     True)[0],
                     cond_jaxpr=_rewrite_typed_jaxpr(cond_jaxpr, True,
                                                     False)[0])))
    elif eqn.primitive is lax.cond_p:
        true_jaxpr, false_jaxpr, linear = util.split_dict(
            eqn.params, ["true_jaxpr", "false_jaxpr", "linear"])
        nr_true_invars = len(true_jaxpr.jaxpr.invars)
        # pred is a one-element list here, so it concatenates below.
        pred, true_invars, false_invars = util.split_list(
            eqn.invars, [1, nr_true_invars])
        # Each branch gets the token as its last input.
        new_invars = pred + true_invars + [input_token_var] + false_invars + [
            input_token_var
        ]
        # NOTE(review): a token is inserted after the true-branch operands and
        # another at the end, but both False flags are appended to the end of
        # `linear`. If `linear` is positional over the non-predicate operands,
        # the false-branch flags shift by one (harmless only while they are
        # all False). Confirm.
        eqns.append(
            core.new_jaxpr_eqn(
                new_invars, eqn.outvars + [output_token_var], eqn.primitive,
                dict(eqn.params,
                     true_jaxpr=_rewrite_typed_jaxpr(true_jaxpr, True,
                                                     True)[0],
                     false_jaxpr=_rewrite_typed_jaxpr(false_jaxpr, True,
                                                      True)[0],
                     linear=linear + (False, False))))
    elif eqn.primitive is lax.scan_p:
        # "reverse" and "length" are unpacked but unused here.
        num_consts, num_carry, carry_jaxpr, linear, _, _ = util.split_dict(
            eqn.params, [
                "num_consts", "num_carry", "jaxpr", "linear", "reverse",
                "length"
            ])
        # We add the token right at the end of carry
        nr_const_and_carry = num_consts + num_carry
        new_invars = eqn.invars[0:nr_const_and_carry] + [
            input_token_var
        ] + eqn.invars[nr_const_and_carry:]
        new_jaxpr = _rewrite_typed_jaxpr(carry_jaxpr, True, True)[0]
        # The rewrite has put the token at end, it has to be at end of carry
        # (new_jaxpr is mutated in place; presumably it is a fresh object).
        new_jaxpr_invars = new_jaxpr.jaxpr.invars
        new_jaxpr_invars = (new_jaxpr_invars[0:nr_const_and_carry] +
                            [new_jaxpr_invars[-1]] +
                            new_jaxpr_invars[nr_const_and_carry:-1])
        new_jaxpr.jaxpr.invars = new_jaxpr_invars
        new_jaxpr.in_avals = [v.aval for v in new_jaxpr_invars]

        new_jaxpr_outvars = new_jaxpr.jaxpr.outvars
        new_jaxpr_outvars = (new_jaxpr_outvars[0:num_carry] +
                             [new_jaxpr_outvars[-1]] +
                             new_jaxpr_outvars[num_carry:-1])
        new_jaxpr.jaxpr.outvars = new_jaxpr_outvars
        new_jaxpr.out_avals = [v.aval for v in new_jaxpr_outvars]
        # NOTE(review): the token is inserted after the carry operands, but its
        # False flag is appended to the end of `linear`. If `linear` is
        # positional, the xs flags shift by one (harmless only while they are
        # all False). Confirm.
        eqns.append(
            core.new_jaxpr_eqn(
                new_invars,
                # Output token is at the end of carry result
                eqn.outvars[0:num_carry] + [output_token_var] +
                eqn.outvars[num_carry:],
                eqn.primitive,
                dict(eqn.params,
                     jaxpr=new_jaxpr,
                     num_carry=num_carry + 1,
                     linear=linear + (False, ))))
    elif eqn.primitive is xla.xla_call_p:
        call_jaxpr = eqn.params["call_jaxpr"]
        eqns.append(
            core.new_jaxpr_eqn(
                eqn.invars + [input_token_var],
                eqn.outvars + [output_token_var], eqn.primitive,
                dict(eqn.params,
                     call_jaxpr=_rewrite_jaxpr(call_jaxpr, True, True)[0])))
    else:
        raise NotImplementedError(f"outfeed rewrite {eqn.primitive}")