Exemple #1
0
def _maybe_tracer_tuple_to_abstract_tuple(tup):
  """Convert a (possibly nested) tracer tuple into an abstract tuple.

  A ``pe.JaxprTracerTuple`` is converted recursively, element by element, into
  a ``core.AbstractTuple``. A ``core.AbstractValue`` is returned unchanged, and
  ``None`` is mapped to the empty ``core.AbstractTuple``.

  Raises:
    TypeError: if ``tup`` is none of the accepted kinds; the offending value is
      passed as the exception argument.
  """
  if isinstance(tup, pe.JaxprTracerTuple):
    converted = [_maybe_tracer_tuple_to_abstract_tuple(elt) for elt in tup]
    return core.AbstractTuple(converted)
  if isinstance(tup, core.AbstractValue):
    return tup
  if tup is None:
    return core.AbstractTuple(())
  raise TypeError(tup)
Exemple #2
0
def _scan_transpose(ct, consts, init, xs, forward, length, jaxpr):
  """Transpose a scan by binding a scan over the transposed body, reversed.

  Args:
    ct: cotangent of the scan output. Zeros are instantiated for it against the
      abstract value ``(c, [b])`` before it is split into ``ct_c`` and
      ``ct_bs``.
    consts: must be None (asserted).
    init: must be None (asserted).
    xs: a Python tuple ``(a, res)`` in which ``a`` must be None and ``res``
      must not be None (both asserted).
    forward: direction flag of the scan being transposed; the transposed scan
      is bound with ``forward=not forward``.
    length: passed through unchanged to the transposed scan.
    jaxpr: typed jaxpr of the scan body, of type
      ``d -> c -> (a, res) -> (c, b)``; its third input binder must be a tuple.

  Returns:
    The triple ``(ct_consts, ct_init, (ct_as, None))``.
  """
  # NOTE(review): presumably None marks the inputs this rule is transposing
  # with respect to, while `res` carries known residuals -- confirm with caller.
  assert consts is None and init is None
  assert type(xs) is tuple
  a, res = xs
  assert a is None and res is not None

  # jaxpr :: d -> c -> (a, res) ->  (c, b)
  # jaxpr_lifted :: res -> (d, c, a) -> (c, b)
  # jaxpr_lifted_trans :: res -> (CT c, CT b) -> (CT d, CT c, CT a)
  # jaxpr_trans :: * -> (CT c, CT d) -> (CT b, res) -> ((CT c, CT d), CT a)
  assert type(jaxpr.jaxpr.invars[2]) is tuple  # assume restructuring
  # Move `res` out front and group the remaining inputs (d, c, a) together.
  jaxpr_lifted = rearrange_binders(
      lambda d, c, a_res: (a_res[1], (d, c, a_res[0])), jaxpr)
  jaxpr_lifted_trans = _transpose_jaxpr(jaxpr_lifted)
  jaxpr_trans = _move_stuff_and_add_add(jaxpr_lifted_trans)

  c_aval, b_aval = jaxpr.out_aval
  d_aval, c_aval2, _ = jaxpr.in_avals
  # The carry must have the same abstract value on input and output.
  assert c_aval == c_aval2
  bs_aval = _promote_aval_rank(length, b_aval)
  # The cotangent for d starts as zeros and is carried through the new scan.
  ct_d = ad_util.zeros_like_aval(d_aval)
  ct_c, ct_bs = ad.instantiate_zeros_aval(core.AbstractTuple((c_aval, bs_aval)), ct)
  carry_ct = core.pack((ct_c, ct_d))

  # jaxpr_trans :: * -> (CT c, CT d) -> (CT b, res) -> ((CT c, CT d), CT a)
  core.check_jaxpr(jaxpr_trans.jaxpr)
  # Only ct_c_aval and ct_d_aval are used below; the other names are unused.
  unit_aval, (ct_c_aval, ct_d_aval), (ct_b_aval, _) = jaxpr_trans.in_avals
  # Sanity check: joining the expected carry avals with the actual carry
  # values' avals must leave the expected avals unchanged.
  assert core.lattice_join(ct_c_aval, core.get_aval(ct_c)) == ct_c_aval
  assert core.lattice_join(ct_d_aval, core.get_aval(ct_d)) == ct_d_aval

  # Bind the transposed scan with the direction flipped; the carry is
  # (ct_c, ct_d) and the scanned-over inputs are (ct_bs, res).
  out = scan_p.bind(
      core.unit, carry_ct, core.pack((ct_bs, res)),
      forward=not forward, length=length, jaxpr=jaxpr_trans)
  (ct_init, ct_consts), ct_as = out
  return ct_consts, ct_init, (ct_as, None)
Exemple #3
0
def _move_stuff_and_add_add(typed_jaxpr):
  """Rearrange a transposed scan-body jaxpr into scan's three-binder layout.

  Takes a typed jaxpr of type ``res -> (CT c, CT b) -> (CT d, CT c, CT a)`` and
  returns a new ``core.TypedJaxpr`` of type
  ``* -> (CT c, CT d) -> (CT b, res) -> ((CT c, CT d), CT a)``.

  The work is done on a copy of ``typed_jaxpr.jaxpr``. One equation is
  prepended to rebuild the original ``(CT c, CT b)`` input from the new
  binders. Equations are appended that unpack the original output, combine the
  newly computed ``CT d`` with the incoming ``CT d`` and repack the result.

  Args:
    typed_jaxpr: the transposed, lifted scan body. Its jaxpr must not
      destructure any of its inputs (asserted).

  Returns:
    A ``core.TypedJaxpr`` wrapping the modified jaxpr copy, with the literals
    of ``typed_jaxpr`` and the rearranged input/output abstract values.
  """
  # jaxpr_lifted_trans :: res -> (CT c, CT b) -> (CT d, CT c, CT a)
  # jaxpr_trans :: * -> (CT c, CT d) -> (CT b, res) -> ((CT c, CT d), CT a)

  res_aval, (CTc_aval, CTb_aval) = typed_jaxpr.in_avals
  CTd_aval, CTc_aval2, CTa_aval = typed_jaxpr.out_aval
  # The carry cotangent must have the same abstract value on both sides.
  assert CTc_aval == CTc_aval2
  in_avals = (core.AbstractTuple(()), core.AbstractTuple((CTc_aval, CTd_aval)),
              core.AbstractTuple((CTb_aval, res_aval)))
  out_aval = core.AbstractTuple((core.AbstractTuple((CTc_aval, CTd_aval)),
                                 CTa_aval))

  # Work on a copy so the caller's jaxpr is not modified by the edits below.
  jaxpr = typed_jaxpr.jaxpr.copy()
  # assume the jaxpr isn't restructuring any inputs
  assert not any(type(invar) is tuple for invar in jaxpr.invars)

  # munge input side
  CTc_in = _scan_newvar()
  CTb_in = _scan_newvar()
  CTd_in = _scan_newvar()
  res_in, CTc_CTb_in = jaxpr.invars
  # New binders: an empty first input, the (CT c, CT d) carry, and (CT b, res).
  jaxpr.invars = ((), (CTc_in, CTd_in), (CTb_in, res_in))
  # Rebuild the (CT c, CT b) variable that the original equations refer to.
  jaxpr.eqns = (
      [pe._pack_eqn([CTc_in, CTb_in], CTc_CTb_in)] +
      jaxpr.eqns)

  # munge output side
  CTd_new = _scan_newvar()
  CTd_sum = _scan_newvar()
  CTc = _scan_newvar()
  CTa = _scan_newvar()
  partial_out = _scan_newvar()
  outvar = _scan_newvar()
  # Unpack the original output, combine the new and incoming CT d (presumably
  # CTd_sum = CTd_new + CTd_in -- confirm against _add_any_eqn), then repack as
  # ((CT c, CT d), CT a).
  jaxpr.eqns = (
      jaxpr.eqns +
      [pe._unpack_eqn(jaxpr.outvar, [CTd_new, CTc, CTa]),
       _add_any_eqn(CTd_sum, CTd_new, CTd_in),
       pe._pack_eqn([CTc, CTd_sum], partial_out),
       pe._pack_eqn([partial_out, CTa], outvar)])
  jaxpr.outvar = outvar

  # TODO(mattjj): add a check_typed_jaxpr and use it here
  # Short-circuit: check_jaxpr runs only when core.skip_checks is falsy.
  core.skip_checks or core.check_jaxpr(jaxpr)
  return core.TypedJaxpr(jaxpr, typed_jaxpr.literals, in_avals, out_aval)
Exemple #4
0
def lu_abstract_eval(operand):
  """Abstract evaluation rule for the LU decomposition.

  Args:
    operand: abstract value of the input. When it is a ``ShapedArray`` it must
      have at least two dimensions.

  Returns:
    A ``core.AbstractTuple`` ``(operand, pivot)``. For a ``ShapedArray`` input
    of shape ``[..., m, n]`` the pivot is an ``np.int32`` ``ShapedArray`` of
    shape ``[..., min(m, n)]``. For any other input the operand itself is used
    as the pivot's abstract value.

  Raises:
    ValueError: if ``operand`` is a ``ShapedArray`` with fewer than two
      dimensions.
  """
  if not isinstance(operand, ShapedArray):
    return core.AbstractTuple((operand, operand))

  if operand.ndim < 2:
    raise ValueError("Argument to LU decomposition must have ndims >= 2")

  num_rows, num_cols = operand.shape[-2:]
  pivot_shape = operand.shape[:-2] + (min(num_rows, num_cols),)
  return core.AbstractTuple((operand, ShapedArray(pivot_shape, np.int32)))
Exemple #5
0
def eigh_abstract_eval(operand, lower):
  """Abstract evaluation rule for the symmetric eigendecomposition.

  Args:
    operand: abstract value of the input. When it is a ``ShapedArray`` it must
      have shape ``[..., n, n]``.
    lower: not used by this rule.

  Returns:
    A ``core.AbstractTuple`` ``(v, w)``. For a ``ShapedArray`` input, ``v`` has
    shape ``[..., n, n]`` and ``w`` has shape ``[..., n]``, both with the
    operand's dtype. For any other input both entries are the operand itself.

  Raises:
    ValueError: if ``operand`` is a ``ShapedArray`` whose last two dimensions
      are missing or unequal.
  """
  if not isinstance(operand, ShapedArray):
    return core.AbstractTuple((operand, operand))

  is_square = operand.ndim >= 2 and operand.shape[-2] == operand.shape[-1]
  if not is_square:
    raise ValueError(
      "Argument to symmetric eigendecomposition must have shape [..., n, n]")

  leading = operand.shape[:-2]
  size = operand.shape[-1]
  eigenvectors = ShapedArray(leading + (size, size), operand.dtype)
  eigenvalues = ShapedArray(leading + (size,), operand.dtype)
  return core.AbstractTuple((eigenvectors, eigenvalues))
Exemple #6
0
def eig_abstract_eval(operand):
  """Abstract evaluation rule for the nonsymmetric eigendecomposition.

  Args:
    operand: abstract value of the input. When it is a ``ShapedArray`` it must
      have shape ``[..., n, n]``.

  Returns:
    A ``core.AbstractTuple`` ``(w, vl, vr)``. For a ``ShapedArray`` input,
    ``vl`` and ``vr`` are the same ``ShapedArray`` of shape ``[..., n, n]``
    with the operand's dtype, and ``w`` has shape ``[..., n]`` with the dtype
    given by ``lax.lax._complex_basetype(operand.dtype)``. For any other input
    all three entries are the operand itself.

  Raises:
    ValueError: if ``operand`` is a ``ShapedArray`` whose last two dimensions
      are missing or unequal.
  """
  if not isinstance(operand, ShapedArray):
    return core.AbstractTuple((operand, operand, operand))

  is_square = operand.ndim >= 2 and operand.shape[-2] == operand.shape[-1]
  if not is_square:
    raise ValueError("Argument to nonsymmetric eigendecomposition must have "
                     "shape [..., n, n], got shape {}".format(operand.shape))

  leading = operand.shape[:-2]
  size = operand.shape[-1]
  # Left and right eigenvectors share a single abstract value object.
  eigenvectors = ShapedArray(leading + (size, size), operand.dtype)
  eigenvalue_dtype = lax.lax._complex_basetype(operand.dtype)
  eigenvalues = ShapedArray(leading + (size,), eigenvalue_dtype)
  return core.AbstractTuple((eigenvalues, eigenvectors, eigenvectors))
Exemple #7
0
def qr_abstract_eval(operand, full_matrices):
  """Abstract evaluation rule for the QR decomposition.

  Args:
    operand: abstract value of the input. When it is a ``ShapedArray`` it must
      have at least two dimensions.
    full_matrices: if truthy the inner dimension ``k`` is ``m``, otherwise it
      is ``min(m, n)``.

  Returns:
    A ``core.AbstractTuple`` ``(q, r)``. For a ``ShapedArray`` input of shape
    ``[..., m, n]``, ``q`` has shape ``[..., m, k]`` and ``r`` has shape
    ``[..., k, n]``, both with the operand's dtype. For any other input both
    entries are the operand itself.

  Raises:
    ValueError: if ``operand`` is a ``ShapedArray`` with fewer than two
      dimensions.
  """
  if not isinstance(operand, ShapedArray):
    return core.AbstractTuple((operand, operand))

  if operand.ndim < 2:
    raise ValueError("Argument to QR decomposition must have ndims >= 2")

  leading = operand.shape[:-2]
  num_rows, num_cols = operand.shape[-2:]
  if full_matrices:
    inner = num_rows
  else:
    inner = min(num_rows, num_cols)
  q_aval = ShapedArray(leading + (num_rows, inner), operand.dtype)
  r_aval = ShapedArray(leading + (inner, num_cols), operand.dtype)
  return core.AbstractTuple((q_aval, r_aval))
Exemple #8
0
def svd_abstract_eval(operand, full_matrices, compute_uv):
  """Abstract evaluation rule for the singular value decomposition.

  Args:
    operand: abstract value of the input. When it is a ``ShapedArray`` it must
      have at least two dimensions.
    full_matrices: if truthy, ``u`` is ``[..., m, m]`` and ``vt`` is
      ``[..., n, n]``; otherwise their inner dimension is ``min(m, n)``.
    compute_uv: not used by this rule; all three outputs are always described.

  Returns:
    A ``core.AbstractTuple`` ``(s, u, vt)``. For a ``ShapedArray`` input of
    shape ``[..., m, n]``, ``s`` has shape ``[..., min(m, n)]`` and all three
    outputs carry the operand's dtype. For any other input all three entries
    are the operand itself.

  Raises:
    ValueError: if ``operand`` is a ``ShapedArray`` with fewer than two
      dimensions.
  """
  if not isinstance(operand, ShapedArray):
    return core.AbstractTuple((operand, operand, operand))

  if operand.ndim < 2:
    raise ValueError("Argument to singular value decomposition must have ndims >= 2")

  leading = operand.shape[:-2]
  num_rows, num_cols = operand.shape[-2:]
  rank = min(num_rows, num_cols)
  u_cols = num_rows if full_matrices else rank
  vt_rows = num_cols if full_matrices else rank

  s_aval = ShapedArray(leading + (rank,), operand.dtype)
  u_aval = ShapedArray(leading + (num_rows, u_cols), operand.dtype)
  vt_aval = ShapedArray(leading + (vt_rows, num_cols), operand.dtype)
  return core.AbstractTuple((s_aval, u_aval, vt_aval))
Exemple #9
0
def _scan_partial_eval(trace, *tracers, **kwargs):
    """Partial-evaluation rule for scan.

    Splits the scan body into two jaxprs with ``pe.partial_eval_jaxpr``. The
    first (``jaxpr_1``) is bound immediately through ``scan_p`` on the known
    input constants. The second (``jaxpr_2``) is recorded in a ``core.JaxprEqn``
    attached to the returned tracer.

    Args:
      trace: the trace used to lift the init tracer and to create the
        residuals tracer.
      *tracers: exactly three tracers ``(consts, init, xs)``.
      **kwargs: must contain exactly ``jaxpr``, ``length`` and ``forward``
        (any other key trips an assertion).

    Returns:
      A ``pe.JaxprTracer`` whose partial value pairs ``out_pv`` with the packed
      ``(out_carry, ys)`` and whose recipe is the equation for ``jaxpr_2``.

    Raises:
      FixedPointError: if the carry's unknown-ness pattern does not stabilise
        within 1000 iterations.
    """
    jaxpr = kwargs.pop('jaxpr')
    length = kwargs.pop('length')
    forward = kwargs.pop('forward')
    assert not kwargs
    in_pvs, _ = unzip2([t.pval for t in tracers])
    # One pe.unknown result per input (consts, init, xs).
    sc_consts, sc_init, sc_xs = map(pe.unknown, in_pvs)

    # Iterate to a fixed point: the carry fed into the body must be treated as
    # unknown wherever the carry coming out of the body is unknown.
    sc_carry = sc_init
    for i in range(1000):
        second_components = (sc_consts, sc_carry, sc_xs)
        jaxpr_1, jaxpr_2, sc_out = pe.partial_eval_jaxpr(jaxpr,
                                                         second_components,
                                                         instantiate=(sc_carry,
                                                                      False))
        sc_carry_out, sc_ys = sc_out
        if sc_carry_out == sc_carry:
            break
        else:
            sc_carry = _binary_lattice_join(sc_carry, sc_carry_out)
    else:
        # The loop ran all 1000 iterations without reaching a fixed point.
        raise FixedPointError

    consts_tracer, init_tracer, xs_tracer = tracers
    # Lift the init tracer so it matches the final carry pattern sc_carry.
    lifted_init_tracer = _lift_tracer(trace, init_tracer, sc_carry)
    lifted_tracers = consts_tracer, lifted_init_tracer, xs_tracer
    in_pvs, in_consts = unzip2([t.pval for t in lifted_tracers])

    carry_aval, y_aval = jaxpr.out_aval
    ys_aval = _promote_aval_rank(length, y_aval)
    out_aval = core.AbstractTuple((carry_aval, ys_aval))
    out_pv = _put_known_pvs(sc_out, out_aval)

    # Bind the first jaxpr now; besides the carry and ys it yields residuals
    # that are passed on to the second scan below.
    out_carry, (ys, residuals) = scan_p.bind(*in_consts,
                                             forward=forward,
                                             length=length,
                                             jaxpr=jaxpr_1)
    out_const = core.pack((out_carry, ys))
    residuals_tracer = trace.new_instantiated_const(core.pack(residuals))
    d, c, a = lifted_tracers
    # The second scan takes the residuals alongside the original xs input.
    new_tracers = (d, c, (a, residuals_tracer))
    eqn = core.JaxprEqn(new_tracers, None, scan_p, (), True, False,
                        dict(forward=forward, length=length, jaxpr=jaxpr_2))
    return pe.JaxprTracer(trace, pe.PartialVal((out_pv, out_const)), eqn)
Exemple #10
0
def _tscan(f, a, bs, fields=(0, )):
    """
    Scan ``f`` over ``bs`` starting from ``a``, like ``jax.lax.scan``, with an
    extra `fields` argument that selects which fields of `a`'s structure are
    needed. By default only the first field is selected; the remaining fields
    are filled with None.
    """
    # Note: code is copied and modified from lax.scan implementation in
    # [JAX](https://github.com/google/jax) to support the additional `fields`
    # arg. Original code has the following copyright:
    #
    # Copyright 2018 Google LLC
    #
    # Licensed under the Apache License, Version 2.0 (the "License")

    # Flatten every pytree argument into a flat jaxtuple.
    flat_a, a_tree = pytree_to_flatjaxtuple(a)
    flat_bs, b_tree = pytree_to_flatjaxtuple(bs)
    flat_fields, _ = pytree_to_flatjaxtuple(fields)
    in_trees = (a_tree, b_tree)
    flat_f, out_tree = pytree_fun_to_flatjaxtuple_fun(wrap_init(f), in_trees)

    # Abstract values of the inputs; each element of bs loses its leading axis
    # to give the abstract value of a single scanned slice.
    a_aval, _ = lax._abstractify(flat_a)
    bs_aval, _ = lax._abstractify(flat_bs)
    slice_avals = [ShapedArray(elt.shape[1:], elt.dtype) for elt in bs_aval]
    b_aval = core.AbstractTuple(slice_avals)

    # Trace the flattened function on partial values to obtain its jaxpr.
    in_pvals = (partial_eval.PartialVal((a_aval, core.unit)),
                partial_eval.PartialVal((b_aval, core.unit)))
    jaxpr, pval_out, traced_consts = partial_eval.trace_to_jaxpr(flat_f,
                                                                 in_pvals)
    aval_out, _ = pval_out
    packed_consts = core.pack(traced_consts)

    result = tscan_p.bind(flat_a, flat_bs, flat_fields, packed_consts,
                          aval_out=aval_out, jaxpr=jaxpr)
    return tree_unflatten(out_tree(), result)
Exemple #11
0
def scan(f, init, xs):
  """Apply a function along the leading axis of ``xs``, threading a carry.

  In brief, the type is

  .. code-block:: haskell

    scan :: (c -> a -> (c, b)) -> c -> [a] -> (c, [b])

  Here [t] stands for the type t with one extra leading axis: for an array type
  it is the array with an added leading axis, and for a pytree (container) type
  with array leaves it is the same pytree structure with every leaf given an
  added leading axis.

  For the case where ``a`` and ``b`` are both array types, ``scan`` behaves
  like this pure-Python reference::

    def scan(f, init, xs):
      carry = init
      ys = []
      for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
      return carry, np.stack(ys)

  In contrast to the reference, ``a`` and ``b`` may be arbitrary pytree types,
  so several arrays can be scanned together and several stacked outputs can be
  produced.

  Also in contrast to the reference, ``scan`` is a JAX primitive that lowers to
  a single XLA While HLO. Native Python loops inside an ``@jit`` function are
  unrolled into large XLA computations, so ``scan`` can help reduce compilation
  time for jit-compiled functions.

  Args:
    f: the Python function to scan, of type ``c -> a -> (c, b)``. Its first
      argument is the current loop carry and its second is one slice of ``xs``
      taken along the leading axis. It returns a pair: the new loop carry and
      one slice of the output.
    init: the initial loop carry, of type ``c``. It may be a scalar, an array,
      or any pytree (nested Python tuple/list/dict) of those.
    xs: the value of type ``[a]`` to scan over along the leading axis. ``[a]``
      may be an array or any pytree (nested Python tuple/list/dict) of arrays
      whose leading axis sizes agree.

  Returns:
    A pair of type ``(c, [b])``: the final loop carry, and the second outputs
    of ``f`` stacked along a new leading axis.

  Raises:
    TypeError: if the traced output of ``f`` is not a pair, or if the carry
      returned by ``f`` does not have the same abstract value as ``init``.
  """
  # Flatten both inputs to jaxtuple trees and wrap f to match.
  flattened = map(pytree_to_jaxtupletree, (init, xs))
  (init_flat, xs_flat), in_trees = unzip2(flattened)
  flat_f, out_tree = pytree_fun_to_jaxtupletree_fun(lu.wrap_init(f), in_trees)

  # Abstract the carry and a single slice of xs, then trace the body once.
  carry_pval = _abstractify(init_flat)
  carry_aval, _ = carry_pval
  xs_aval, _ = _abstractify(xs_flat)
  x_aval = _demote_aval_rank(xs_aval)
  x_pval = pe.PartialVal((x_aval, core.unit))
  body_jaxpr, pval_out, consts = pe.trace_to_jaxpr(
      flat_f, (carry_pval, x_pval), instantiate=True)
  pv_out, const_out = pval_out
  assert isinstance(pv_out, core.AbstractValue) and const_out == core.unit

  # Validate that the body returned a (carry, y) pair with a matching carry.
  is_pair = isinstance(pv_out, core.AbstractTuple) and len(pv_out) == 2
  if not is_pair:
    template = ("scanned function must have signature `c -> a -> (c, b)`, but the "
                "output was not a pair: got type {}.")
    raise TypeError(template.format(pv_out))
  carry_aval_out, y_aval = pv_out
  if carry_aval != carry_aval_out:
    template = ("scanned function carry output does not match carry input: "
                "input carry is {} and output carry is {}.")
    raise TypeError(template.format(carry_aval, carry_aval_out))

  # Turn closed-over constants into an explicit first argument of the body.
  lifted_jaxpr = pe._closure_convert_jaxpr(body_jaxpr)
  consts_aval, _ = _abstractify(core.pack(consts))
  typed_jaxpr = core.TypedJaxpr(
      lifted_jaxpr, (),
      (consts_aval, carry_aval, x_aval),
      core.AbstractTuple((carry_aval, y_aval)))
  num_steps = _leading_dim_size(xs_flat)
  result = scan_p.bind(core.pack(consts), init_flat, xs_flat,
                       forward=True, length=num_steps, jaxpr=typed_jaxpr)
  return build_tree(out_tree(), result)
Exemple #12
0
def _promote_aval_rank(n, xs):
  """Return the abstract value of ``xs`` with a new leading axis of size ``n``.

  A ``core.AbstractTuple`` is handled recursively, element by element. Any
  other abstract value becomes a ``ShapedArray`` of shape ``(n,) + xs.shape``
  with the same dtype.
  """
  assert isinstance(xs, core.AbstractValue)
  if not isinstance(xs, core.AbstractTuple):
    return ShapedArray((n,) + xs.shape, xs.dtype)
  return core.AbstractTuple(_promote_aval_rank(n, elt) for elt in xs)
Exemple #13
0
def _demote_aval_rank(xs):
  """Return the abstract value of ``xs`` with its leading axis removed.

  A ``core.AbstractTuple`` is handled recursively, element by element. Any
  other abstract value becomes a ``ShapedArray`` of shape ``xs.shape[1:]`` with
  the same dtype.
  """
  assert isinstance(xs, core.AbstractValue)
  if not isinstance(xs, core.AbstractTuple):
    return ShapedArray(xs.shape[1:], xs.dtype)
  return core.AbstractTuple(_demote_aval_rank(elt) for elt in xs)