Example #1
def qr_jvp_rule(primals, tangents, full_matrices):
    # See j-towns.github.io/papers/qr-derivative.pdf for a terse derivation.
    x, = primals
    if full_matrices or np.shape(x)[-2] < np.shape(x)[-1]:
        raise NotImplementedError
    dx, = tangents
    q, r = qr_p.bind(x, full_matrices=False)
    dx_rinv = triangular_solve(r, dx)  # Right side solve by default
    qt_dx_rinv = np.matmul(_T(q), dx_rinv)
    qt_dx_rinv_lower = np.tril(qt_dx_rinv, -1)
    domega = qt_dx_rinv_lower - _T(qt_dx_rinv_lower)  # This is skew-symmetric
    dq = np.matmul(q, domega - qt_dx_rinv) + dx_rinv
    dr = np.matmul(qt_dx_rinv - domega, r)
    return core.pack((q, r)), core.pack((dq, dr))
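These snippets use legacy JAX internals (core.pack, qr_p.bind, and tuple-valued jaxprs predate the current API), but the rule itself can still be sanity-checked against finite differences through the modern public API. A minimal sketch, assuming a current jax install:

import jax
import jax.numpy as jnp
import numpy as onp

# Compare the analytic JVP of QR against a finite-difference estimate.
x = jax.random.normal(jax.random.PRNGKey(0), (4, 3))
dx = jax.random.normal(jax.random.PRNGKey(1), (4, 3))
(q, r), (dq, dr) = jax.jvp(jnp.linalg.qr, (x,), (dx,))

eps = 1e-3
q2, r2 = jnp.linalg.qr(x + eps * dx)
print(onp.allclose(dq, (q2 - q) / eps, atol=1e-2))  # True
print(onp.allclose(dr, (r2 - r) / eps, atol=1e-2))  # True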
Example #2
def _scan_jvp(primals, tangents, forward, length, jaxpr):
    consts, init, xs = primals
    consts_dot, init_dot, xs_dot = tangents
    consts_aval, carry_aval, x_aval = jaxpr.in_avals
    _, y_aval = jaxpr.out_aval

    consts_nonzeros = ad.get_nonzeros(consts_dot)
    init_nonzeros = ad.get_nonzeros(init_dot)
    xs_nonzeros = ad.get_nonzeros(xs_dot)  # same as x_nonzeros b/c arrays

    carry_nonzeros = init_nonzeros
    for _ in range(1000):
        nonzeros = (consts_nonzeros, carry_nonzeros, xs_nonzeros)
        jaxpr_jvp, nonzeros_out = ad.jvp_jaxpr(jaxpr,
                                               nonzeros,
                                               instantiate=(carry_nonzeros,
                                                            False))
        carry_nonzeros_out, ys_nonzeros = nonzeros_out
        if _binary_lattice_eq(carry_nonzeros_out, carry_nonzeros):
            break
        else:
            carry_nonzeros = _binary_lattice_join(carry_nonzeros_out,
                                                  carry_nonzeros)
    else:
        raise FixedPointError

    # convert_zeros is like strip_zeros but uses explicit lattice information to
    # instantiate zeros in some cases, namely in init_dot based on the fixed point
    nonzero_init_dot = _convert_zeros(carry_nonzeros, init, init_dot)
    nonzero_consts_dot = _convert_zeros(consts_nonzeros, consts, consts_dot)
    nonzero_xs_dot = _convert_zeros(xs_nonzeros, xs, xs_dot)

    consts_dual = core.pack((consts, nonzero_consts_dot))
    init_dual = core.pack((init, nonzero_init_dot))
    xs_dual = core.pack((xs, nonzero_xs_dot))

    carry_out_dual, ys_dual = scan_p.bind(consts_dual,
                                          init_dual,
                                          xs_dual,
                                          forward=forward,
                                          length=length,
                                          jaxpr=jaxpr_jvp)

    ys, ys_dot = ys_dual
    ys_dot = ad.put_zeros(ad.TangentTuple, ys_nonzeros, ys_dot)

    carry_out, carry_out_dot = carry_out_dual
    carry_out_dot = ad.put_zeros(ad.TangentTuple, carry_nonzeros_out,
                                 carry_out_dot)
    return core.pack((carry_out, ys)), ad.TangentTuple((carry_out_dot, ys_dot))
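The fixed-point loop above iterates a join on a binary lattice of "nonzero" flags until the carry's flags stabilize. The helpers it calls are JAX internals; as a purely illustrative sketch (these bodies are an assumption, not the actual source), they behave like elementwise operations on nested tuples of bools ordered False <= True:

# Illustrative sketch: plausible bodies for the lattice helpers used above.
def _binary_lattice_join(a, b):
    if isinstance(a, tuple):
        return tuple(map(_binary_lattice_join, a, b))
    return a or b  # False <= True, so join is logical or

def _binary_lattice_eq(a, b):
    return a == b  # nested tuples of bools compare structurally

assert _binary_lattice_join((False, (True, False)),
                            (True, (False, False))) == (True, (True, False))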
Example #3
def _tscan_impl(a, bs, fields, consts, aval_out, jaxpr):
    length = tuple(bs)[0].shape[0]
    state = [
        lax.full((length, ) + a[i].shape, 0, lax._dtype(a[i])) for i in fields
    ]

    def body_fun(i, vals):
        a, state = vals
        # select i-th element from each b
        b = [lax.dynamic_index_in_dim(b, i, keepdims=False) for b in bs]
        a_out = core.eval_jaxpr(jaxpr, consts, (), a, core.pack(b))
        # select fields from a_out and update state
        state_out = [
            lax.dynamic_update_index_in_dim(s, a[None, ...], i, axis=0)
            for a, s in zip([tuple(a_out)[j] for j in fields], state)
        ]
        return a_out, state_out

    _, state = lax.fori_loop(0, length, body_fun, (a, state))

    # set None for non-selected fields
    out = [None] * len(a)
    for i, field in enumerate(fields):
        out[field] = state[i]
    return core.pack(out)
Example #4
def _update_arrays(i, aval, xs, x):
  assert isinstance(aval, core.AbstractValue)
  if isinstance(aval, core.AbstractTuple):
    return core.pack(map(partial(_update_arrays, i), aval, xs, x))
  else:
    x = lax.reshape(x, (1,) + onp.shape(x))
    return lax.dynamic_update_index_in_dim(xs, x, i, axis=0)
Example #5
def eigh_jvp_rule(primals, tangents, lower):
  # Derivative for eigh in the simplest case of distinct eigenvalues.
  # This is classic nondegenerate perturbation theory, but also see
  # https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf
  # The general solution treating the case of degenerate eigenvalues is
  # considerably more complicated. Ambitious readers may refer to the general
  # methods below or refer to degenerate perturbation theory in physics.
  # https://www.win.tue.nl/analysis/reports/rana06-33.pdf and
  # https://people.orie.cornell.edu/aslewis/publications/99-clarke.pdf
  a, = primals
  a_dot, = tangents

  v, w = eigh_p.bind(symmetrize(a), lower=lower)

  if a_dot is ad_util.zero:
    return core.pack((v, w)), ad.TangentTuple((ad_util.zero, ad_util.zero))

  # for complex numbers we need eigenvalues to be full dtype of v, a:
  w = w.astype(a.dtype)
  eye_n = np.eye(a.shape[-1], dtype=a.dtype)
  # carefully build reciprocal delta-eigenvalue matrix, avoiding NaNs.
  Fmat = np.reciprocal(eye_n + w - w[..., np.newaxis]) - eye_n
  # eigh impl doesn't support batch dims, but future-proof the grad.
  dot = lax.dot if a.ndim == 2 else lax.batch_matmul
  vdag_adot_v = dot(dot(_H(v), a_dot), v)
  dv = dot(v, np.multiply(Fmat, vdag_adot_v))
  dw = np.diagonal(vdag_adot_v)
  return (v, w), (dv, dw)
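The "carefully build" comment earns its keep: a small numpy sketch of why adding eye_n before the reciprocal and subtracting it after avoids dividing by the zero diagonal of w - w[..., np.newaxis]:

import numpy as onp
w = onp.array([1.0, 2.0, 4.0])
eye = onp.eye(3)
F = onp.reciprocal(eye + w - w[:, None]) - eye
# Off the diagonal, F[i, j] = 1 / (w[j] - w[i]); on the diagonal it is
# exactly 0, and no inf or NaN is ever produced.
print(F)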
Example #6
def svd_jvp_rule(primals, tangents, full_matrices, compute_uv):
  A, = primals
  dA, = tangents
  s, U, Vt = svd_p.bind(A, full_matrices=False, compute_uv=True)

  if dA is ad_util.zero:
    return (core.pack((s, U, Vt)),
            ad.TangentTuple((ad_util.zero, ad_util.zero, ad_util.zero)))

  if full_matrices:
    # TODO: implement full matrices case, documented here: https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf
    raise NotImplementedError(
      "Singular value decomposition JVP not implemented for full matrices")

  k = s.shape[-1]
  Ut, V = np.conj(U).T, np.conj(Vt).T
  s_dim = s[..., None, :]
  dS = np.dot(np.dot(Ut, dA), V)
  ds = np.real(np.diag(dS))
  F = 1 / (np.square(s_dim) - np.square(s_dim.T) + np.eye(k)) - np.eye(k)
  dSS = s_dim * dS
  SdS = s_dim.T * dS
  dU = np.dot(U, F * (dSS + dSS.T))
  dV = np.dot(V, F * (SdS + SdS.T))

  m, n = A.shape[-2], A.shape[-1]
  if m > n:
    dU = dU + np.dot(np.eye(m) - np.dot(U, Ut), np.dot(dA, V)) / s_dim
  if n > m:
    dV = dV + np.dot(np.eye(n) - np.dot(V, Vt), np.dot(np.conj(dA).T, U)) / s_dim
  return (s, U, Vt), (ds, dU, dV.T)
Example #7
def _jaxtupletree_select(pred, on_true, on_false):
  aval = core.get_aval(on_true)
  if type(aval) is core.AbstractTuple:
    return core.pack(map(partial(_jaxtupletree_select, pred), on_true, on_false))
  elif isinstance(aval, UnshapedArray):
    return lax.select(pred, on_true, on_false)
  else:
    raise TypeError(aval)
Example #8
def _lu_jvp_rule(primals, tangents):
    a, = primals
    a_dot, = tangents
    lu, pivots = lu_p.bind(a)

    if a_dot is ad_util.zero:
        return (core.pack(
            (lu, pivots)), ad.TangentTuple((ad_util.zero, ad_util.zero)))

    a_shape = np.shape(a)
    m, n = a_shape[-2:]
    dtype = lax.dtype(a)
    k = min(m, n)

    permutation = lu_pivots_to_permutation(pivots, m)
    batch_dims = a_shape[:-2]
    iotas = np.ix_(*(lax.iota(np.int32, b) for b in batch_dims + (1, )))
    x = a_dot[iotas[:-1] + (permutation, slice(None))]

    # Differentiation of Matrix Functionals Using Triangular Factorization
    # F. R. De Hoog, R. S. Anderssen, and M. A. Lukas
    #
    #     LU = A
    # ==> L'U + LU' = A'
    # ==> inv(L) . L' + U' . inv(U) = inv(L) A' inv(U)
    # ==> L' = L . tril(inv(L) . A' . inv(U), -1)
    #     U' = triu(inv(L) . A' . inv(U)) . U

    ndims = len(a_shape)
    l_padding = [(0, 0, 0)] * ndims
    l_padding[-1] = (0, m - k, 0)
    zero = np._constant_like(lu, 0)
    l = lax.pad(np.tril(lu[..., :, :k], -1), zero, l_padding)
    l = l + np.eye(m, m, dtype=dtype)

    u_eye = lax.pad(np.eye(n - k, n - k, dtype=dtype), zero,
                    ((k, 0, 0), (k, 0, 0)))
    u_padding = [(0, 0, 0)] * ndims
    u_padding[-2] = (0, n - k, 0)
    u = lax.pad(np.triu(lu[..., :k, :]), zero, u_padding) + u_eye

    la = triangular_solve(l,
                          x,
                          left_side=True,
                          transpose_a=False,
                          lower=True,
                          unit_diagonal=True)
    lau = triangular_solve(u,
                           la,
                           left_side=False,
                           transpose_a=False,
                           lower=False)

    l_dot = np.matmul(l, np.tril(lau, -1))
    u_dot = np.matmul(np.triu(lau), u)
    lu_dot = l_dot + u_dot
    return (lu, pivots), (lu_dot, ad_util.zero)
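The De Hoog/Anderssen/Lukas identities in the comment are easy to verify numerically: with X = inv(L) . A' . inv(U), the strict-lower/upper split of X reassembles A'. A plain-numpy sketch:

import numpy as onp
rng = onp.random.RandomState(0)
n = 3
L = onp.tril(rng.randn(n, n), -1) + onp.eye(n)   # unit lower triangular
U = onp.triu(rng.randn(n, n)) + onp.eye(n)       # upper triangular
dA = rng.randn(n, n)                             # tangent A'
X = onp.linalg.solve(L, dA) @ onp.linalg.inv(U)  # inv(L) . A' . inv(U)
dL = L @ onp.tril(X, -1)                         # L'
dU = onp.triu(X) @ U                             # U'
print(onp.allclose(dL @ U + L @ dU, dA))         # True: L'U + LU' == A'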
Example #9
def _scan_partial_eval(trace, *tracers, **kwargs):
    jaxpr = kwargs.pop('jaxpr')
    length = kwargs.pop('length')
    forward = kwargs.pop('forward')
    assert not kwargs
    in_pvs, _ = unzip2([t.pval for t in tracers])
    sc_consts, sc_init, sc_xs = map(pe.unknown, in_pvs)

    sc_carry = sc_init
    for i in range(1000):
        second_components = (sc_consts, sc_carry, sc_xs)
        jaxpr_1, jaxpr_2, sc_out = pe.partial_eval_jaxpr(jaxpr,
                                                         second_components,
                                                         instantiate=(sc_carry,
                                                                      False))
        sc_carry_out, sc_ys = sc_out
        if sc_carry_out == sc_carry:
            break
        else:
            sc_carry = _binary_lattice_join(sc_carry, sc_carry_out)
    else:
        raise FixedPointError

    consts_tracer, init_tracer, xs_tracer = tracers
    lifted_init_tracer = _lift_tracer(trace, init_tracer, sc_carry)
    lifted_tracers = consts_tracer, lifted_init_tracer, xs_tracer
    in_pvs, in_consts = unzip2([t.pval for t in lifted_tracers])

    carry_aval, y_aval = jaxpr.out_aval
    ys_aval = _promote_aval_rank(length, y_aval)
    out_aval = core.AbstractTuple((carry_aval, ys_aval))
    out_pv = _put_known_pvs(sc_out, out_aval)

    out_carry, (ys, residuals) = scan_p.bind(*in_consts,
                                             forward=forward,
                                             length=length,
                                             jaxpr=jaxpr_1)
    out_const = core.pack((out_carry, ys))
    residuals_tracer = trace.new_instantiated_const(core.pack(residuals))
    d, c, a = lifted_tracers
    new_tracers = (d, c, (a, residuals_tracer))
    eqn = core.JaxprEqn(new_tracers, None, scan_p, (), True, False,
                        dict(forward=forward, length=length, jaxpr=jaxpr_2))
    return pe.JaxprTracer(trace, pe.PartialVal((out_pv, out_const)), eqn)
Example #10
def _convert_zeros(instantiate, example, tangent):
  t = type(instantiate)
  if t is bool:
    if instantiate:
      return ad.instantiate_zeros(example, tangent)
    elif tangent is ad_util.zero:
      return core.unit
    else:
      raise TypeError(tangent)  # not clear if ever reachable
  elif t is tuple:
    if type(tangent) is ad.TangentTuple:
      return core.pack(map(_convert_zeros, instantiate, example, tangent))
    elif tangent is ad_util.zero:
      zeros = [ad_util.zero] * len(instantiate)
      return core.pack(map(_convert_zeros, instantiate, example, zeros))
    else:
      raise TypeError(tangent)
  else:
    raise TypeError(t)
Example #11
def _convert_zeros(convert_symbolic, example, tangent):
    if tangent is ad.zero:
        if not convert_symbolic:
            return core.unit
        else:
            return ad.zeros_like_jaxval(example)
    elif type(tangent) is ad.TangentTuple:
        return core.pack(
            map(_convert_zeros, convert_symbolic, example, tangent))
    else:
        return tangent
Example #12
def _lu_python(x):
    """Default LU decomposition in Python, where no better version exists."""
    m, n = x.shape[-2:]
    batch_dims = x.shape[:-2]
    if len(batch_dims) > 0:
        batch_size = onp.prod(batch_dims, dtype=onp.int64)
        pivot, lu = api.vmap(_lu_blocked)(lax.reshape(x, (batch_size, m, n)))
        pivot = lax.reshape(pivot, batch_dims + (min(m, n), ))
        lu = lax.reshape(lu, batch_dims + (m, n))
    else:
        pivot, lu = _lu_blocked(x)
    return core.pack((lu, pivot))
Example #13
def _lift_tracer(trace, tracer, is_unknown):
  t = type(is_unknown)
  if t is bool:
    if is_unknown:
      return trace.instantiate_const(tracer)
    else:
      return tracer
  elif t is tuple:
    tracers = map(trace.full_raise, tracer)
    return core.pack(map(partial(_lift_tracer, trace), tracers, is_unknown))
  else:
    raise TypeError(t)
Example #14
def cond(pred, true_operand, true_fun, false_operand, false_fun):
    def trace_jaxpr(fun, operand):
        op_flat, in_tree = pytree_to_flatjaxtuple(operand)
        fun_flat, out_tree = pytree_fun_to_flatjaxtuple_fun(
            lu.wrap_init(fun), (in_tree, ))
        jaxpr, pvout, consts = pe.trace_to_jaxpr(fun_flat,
                                                 (lax._abstractify(op_flat), ))
        return op_flat, jaxpr, consts, pvout, out_tree

    true_data = trace_jaxpr(true_fun, true_operand)
    true_op, true_jaxpr, true_consts, true_pval, true_tree = true_data
    false_data = trace_jaxpr(false_fun, false_operand)
    false_op, false_jaxpr, false_consts, false_pval, false_tree = false_data

    if true_tree() != false_tree():
        msg = "true_fun and false_fun outputs must have identical structure"
        raise TypeError(msg)

    try:
        joined_pval = pe.join_pvals(true_pval, false_pval)
    except TypeError:
        msg = "could not merge true_fun and false_fun output pvals: {} and {}."
        raise TypeError(msg.format(true_pval, false_pval))
    revis = _revise_cond_jaxpr(joined_pval, true_pval, true_jaxpr, true_consts)
    true_jaxpr, true_consts = revis
    revis = _revise_cond_jaxpr(joined_pval, false_pval, false_jaxpr,
                               false_consts)
    false_jaxpr, false_consts = revis
    aval_out, _ = joined_pval

    out = cond_p.bind(pred,
                      true_op,
                      core.pack(true_consts),
                      false_op,
                      core.pack(false_consts),
                      aval_out=aval_out,
                      true_jaxpr=true_jaxpr,
                      false_jaxpr=false_jaxpr)
    out = pe.merge_pvals(out, joined_pval)
    return tree_unflatten(true_tree(), out)
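Stripped of the tracing and jaxpr plumbing, the semantics of this (legacy, five-argument) cond are just a Python branch; a plain-Python reference sketch:

def cond_reference(pred, true_operand, true_fun, false_operand, false_fun):
    # Exactly one branch is evaluated, on its own operand.
    return true_fun(true_operand) if pred else false_fun(false_operand)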
Example #15
def _scan_transpose(ct, consts, init, xs, forward, length, jaxpr):
    assert consts is None and init is None
    assert type(xs) is tuple
    a, res = xs
    assert a is None and res is not None

    # jaxpr :: d -> c -> (a, res) ->  (c, b)
    # jaxpr_lifted :: res -> (d, c, a) -> (c, b)
    # jaxpr_lifted_trans :: res -> (CT c, CT b) -> (CT d, CT c, CT a)
    # jaxpr_trans :: * -> (CT c, CT d) -> (CT b, res) -> ((CT c, CT d), CT a)
    assert type(jaxpr.jaxpr.invars[2]) is tuple  # assume restructuring
    jaxpr_lifted = rearrange_binders(
        lambda d, c, a_res: (a_res[1], (d, c, a_res[0])), jaxpr)
    jaxpr_lifted_trans = _transpose_jaxpr(jaxpr_lifted)
    jaxpr_trans = _move_stuff_and_add_add(jaxpr_lifted_trans)

    c_aval, b_aval = jaxpr.out_aval
    d_aval, c_aval2, _ = jaxpr.in_avals
    assert c_aval == c_aval2
    bs_aval = _promote_aval_rank(length, b_aval)
    ct_d = ad_util.zeros_like_aval(d_aval)
    ct_c, ct_bs = ad.instantiate_zeros_aval(
        core.AbstractTuple((c_aval, bs_aval)), ct)
    carry_ct = core.pack((ct_c, ct_d))

    # jaxpr_trans :: * -> (CT c, CT d) -> (CT b, res) -> ((CT c, CT d), CT a)
    core.check_jaxpr(jaxpr_trans.jaxpr)
    unit_aval, (ct_c_aval, ct_d_aval), (ct_b_aval, _) = jaxpr_trans.in_avals
    assert core.lattice_join(ct_c_aval, core.get_aval(ct_c)) == ct_c_aval
    assert core.lattice_join(ct_d_aval, core.get_aval(ct_d)) == ct_d_aval

    out = scan_p.bind(core.unit,
                      carry_ct,
                      core.pack((ct_bs, res)),
                      forward=not forward,
                      length=length,
                      jaxpr=jaxpr_trans)
    (ct_init, ct_consts), ct_as = out
    return ct_consts, ct_init, (ct_as, None)
Example #16
def lu_jvp_rule(primals, tangents):
    a, = primals
    a_dot, = tangents
    lu, pivots = lu_p.bind(a)

    a_shape = np.shape(a)
    m, n = a_shape[-2:]
    dtype = lax._dtype(a)
    k = min(m, n)

    # TODO(phawkins): use a gather rather than a matrix multiplication here.
    permutation = lu_pivots_to_permutation(pivots, m)
    p = np.array(permutation[:, None] == np.arange(m), dtype=dtype)
    x = np.matmul(p, a_dot)

    # Differentiation of Matrix Functionals Using Triangular Factorization
    # F. R. De Hoog, R. S. Anderssen, and M. A. Lukas
    #
    #     LU = A
    # ==> L'U + LU' = A'
    # ==> inv(L) . L' + U' . inv(U) = inv(L) A' inv(U)
    # ==> L' = L . tril(inv(L) . A' . inv(U), -1)
    #     U' = triu(inv(L) . A' . inv(U)) . U

    ndims = len(a_shape)
    l_padding = [(0, 0, 0)] * ndims
    l_padding[-1] = (0, m - k, 0)
    zero = np._constant_like(lu, 0)
    l = lax.pad(np.tril(lu[..., :, :k], -1), zero, l_padding)
    l = l + np.eye(m, m, dtype=dtype)

    u_eye = lax.pad(np.eye(n - k, n - k, dtype=dtype), zero,
                    ((k, 0, 0), (k, 0, 0)))
    u_padding = [(0, 0, 0)] * ndims
    u_padding[-2] = (0, n - k, 0)
    u = lax.pad(np.triu(lu[..., :k, :]), zero, u_padding) + u_eye

    la = triangular_solve(l, x, left_side=True, transpose_a=False, lower=True)
    lau = triangular_solve(u,
                           la,
                           left_side=False,
                           transpose_a=False,
                           lower=False)

    l_dot = np.matmul(l, np.tril(lau, -1))
    u_dot = np.matmul(np.triu(lau), u)
    lu_dot = l_dot + u_dot
    return core.pack((lu, pivots)), ad.TangentTuple((lu_dot, ad_util.zero))
Example #17
def _scan_impl(consts, init, xs, forward, length, jaxpr):
  _, _, x_aval = jaxpr.in_avals
  _, y_aval = jaxpr.out_aval
  ys_aval = _promote_aval_rank(length, y_aval)

  def body_fun(i, vals):
    idx = i if forward else length - i - 1
    carry, ys = vals
    x = _index_arrays(idx, x_aval, xs)
    carry_out, y = core.jaxpr_as_fun(jaxpr)(consts, carry, x)
    ys_out = _update_arrays(idx, y_aval, ys, y)
    return (carry_out, ys_out)

  ys_init = _empty_arrays(ys_aval)
  carry, ys = fori_loop(0, length, body_fun, (init, ys_init))
  return core.pack((carry, ys))
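_scan_impl lowers to fori_loop with dynamic indexing and updating, but the computation it expresses is an ordinary scan; a plain-Python reference sketch:

def scan_reference(f, consts, init, xs):
    # f :: (consts, carry, x) -> (carry, y); returns final carry, stacked ys.
    carry, ys = init, []
    for x in xs:
        carry, y = f(consts, carry, x)
        ys.append(y)
    return carry, ys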
Example #18
def _revise_cond_jaxpr(new_pval, old_pval, jaxpr, consts):
    new_pv, new_const = new_pval
    old_pv, old_const = old_pval
    if new_pv == old_pv:
        # we didn't move up the lattice by joining with the other side
        return jaxpr, consts
    elif old_pv is None:
        # we moved up the lattice from totally-known, so make a new jaxpr that
        # returns a single constant JaxTuple with elements that are constants
        # drawn from consts where new_pv is unknown
        assert not jaxpr.eqns and not consts
        outvar = pe.Var(0, "_cond")
        new_jaxpr = jaxpr.copy()
        new_jaxpr.constvars = [outvar]
        new_jaxpr.outvar = outvar
        new_consts = (core.pack([
            core.unit if pv is None else old_c
            for pv, old_c in zip(new_pv, old_const)
        ]), )
        return new_jaxpr, new_consts
    else:
        # we moved up the lattice, but not from totally-constant, so adapt the
        # jaxpr to return some new constants in places that are now unknown but
        # weren't before
        eqn = jaxpr.eqns[-1]
        assert eqn.primitive == core.pack_p
        assert len(eqn.outvars) == 1 and eqn.outvars[0] == jaxpr.outvar
        newvar = pe.gensym("_cond")
        new_constvars, new_constvals = unzip2([
            (newvar(), c) for new, old, c in zip(new_pv, old_pv, old_const)
            if old is None and new is not None
        ])
        new_consts = consts + tuple(new_constvals)
        new_jaxpr = jaxpr.copy()
        new_jaxpr.constvars = tuple(jaxpr.constvars) + tuple(new_constvars)
        newvars = iter(new_constvars)
        new_invars = [
            next(newvars) if old is None and new is not None else
            (core.unitvar if new is None and old is None else v)
            for new, old, v in zip(new_pv, old_pv, eqn.invars)
        ]
        new_jaxpr.eqns = (list(jaxpr.eqns[:-1]) +
                          [_pack_eqn(new_invars, jaxpr.outvar)])
        return new_jaxpr, new_consts
Example #19
  def testShardedDeviceTuple(self):
    f = lambda x: core.pack((x, x))
    f = pmap(f)

    shape = (xla_bridge.device_count(), 4)
    x = onp.arange(prod(shape), dtype=onp.float32).reshape(shape)

    # test that we can pass in and out ShardedDeviceTuples (and unpack them)
    y = f(x)
    self.assertIsInstance(y, pxla.ShardedDeviceTuple)
    self.assertIsInstance(y, core.JaxTuple)
    self.assertAllClose(y, (x, x), check_dtypes=False)
    z = f(y)
    self.assertIsInstance(z, pxla.ShardedDeviceTuple)
    self.assertAllClose(z, (y, y), check_dtypes=True)

    # test that we can pass a ShardedDeviceTuple to a regular jit computation
    w = jit(lambda x: list(x)[0])(y)
    self.assertAllClose(w, x, check_dtypes=False)
Example #20
def _scan_batching_rule(batched_args, batch_dims, forward, length, jaxpr):
    consts, init, xs = batched_args
    consts_bdim, init_bdim, xs_bdim = batch_dims

    sizes = lax._reduce(set.union,
                        map(batching.dimsize, batch_dims, batched_args))
    size = sizes.pop()
    assert not sizes

    consts_batched = batching.where_batched(consts_bdim)
    init_batched = batching.where_batched(init_bdim)
    xs_batched = batching.where_batched(xs_bdim)

    carry_batched = init_batched
    for _ in range(1000):
        which_batched = (consts_batched, carry_batched, xs_batched)
        jaxpr_batched, batched_out = batching.batch_jaxpr(
            jaxpr, size, which_batched, instantiate=(carry_batched, False))
        carry_batched_out, ys_batched = batched_out
        if _binary_lattice_eq(carry_batched_out, carry_batched):
            break
        else:
            carry_batched = _binary_lattice_join(carry_batched_out,
                                                 carry_batched)
    else:
        raise FixedPointError

    consts_batched = batching.instantiate_bdim(size, 0, consts_batched,
                                               consts_bdim, consts)
    init_batched = batching.instantiate_bdim(size, 0, carry_batched, init_bdim,
                                             init)
    xs_batched = batching.instantiate_bdim(size, 1, xs_batched, xs_bdim, xs)

    carry_out, ys = scan_p.bind(consts_batched,
                                init_batched,
                                xs_batched,
                                forward=forward,
                                length=length,
                                jaxpr=jaxpr_batched)

    carry_out_bdim = batching.bools_to_bdims(0, carry_batched)
    ys_bdim = batching.bools_to_bdims(1, ys_batched)
    return core.pack((carry_out, ys)), (carry_out_bdim, ys_bdim)
Example #21
    def __init__(self, seed):
        """Create a new PRNG key.

    Args:
      seed: a scalar integer value used to initialize the PRNG key.

    Returns:
      A new PRNGKey object.
    """
        convert = lambda key: lax.convert_element_type(key, onp.uint32)
        if onp.shape(seed):
            raise TypeError("PRNGKey seed must be a scalar.")
        if isinstance(seed, (int, onp.ndarray)):
            # Special handling of raw integer values, which may be 64-bit even
            # when jax_enable_x64=False; we don't want to drop the top 32 bits.
            k1 = convert(onp.bitwise_and(onp.right_shift(seed, 32),
                                         0xFFFFFFFF))
        else:
            k1 = convert(lax.shift_right_logical(seed, 32))
        k2 = convert(lax.bitwise_and(seed, 0xFFFFFFFF))
        self.keypair = core.pack((k1, k2))
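What the constructor computes, as a plain-numpy sketch: the (possibly 64-bit) seed is split into a pair of 32-bit words, high then low:

import numpy as onp
seed = 0x123456789
k1 = onp.uint32((seed >> 32) & 0xFFFFFFFF)  # high word: 0x1
k2 = onp.uint32(seed & 0xFFFFFFFF)          # low word: 0x23456789
print(hex(int(k1)), hex(int(k2)))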
Example #22
def _scan_init(rng, submodule_params_dict, consts, init, xs, forward, length,
               jaxpr, reuse, reuse_only):
    # TODO update to jax==0.1.42
    assert len(consts) == 0

    _, _, x_aval = jaxpr.in_avals
    _, y_aval = jaxpr.out_aval
    ys_aval = _promote_aval_rank(length, y_aval)

    x = _index_arrays(0, x_aval, xs)
    submodule_params_dict = _get_submodule_params(rng,
                                                  jaxpr.jaxpr,
                                                  jaxpr.literals, (),
                                                  submodule_params_dict,
                                                  consts,
                                                  init,
                                                  x,
                                                  reuse=reuse,
                                                  reuse_only=reuse_only)

    if len(submodule_params_dict) == 0:
        submodule_params = ()
    else:
        primitive, = submodule_params_dict.keys()
        submodule_params = (primitive._params_namedtuple(
            submodule_params_dict[primitive]), )

    def body_fun(i, vals):
        idx = i if forward else length - i - 1
        carry, ys = vals
        x = _index_arrays(idx, x_aval, xs)
        cell = parametrized(jc.jaxpr_as_fun(jaxpr))
        carry_out, y = cell.apply(submodule_params, consts, carry, x)
        ys_out = _update_arrays(idx, y_aval, ys, y)
        return carry_out, ys_out

    ys_init = _empty_arrays(ys_aval)
    carry, ys = lax.fori_loop(0, length, body_fun, (init, ys_init))
    return jc.pack((carry, ys)), submodule_params_dict
Example #23
def _scan_apply(submodule_params_iter, consts, init, xs, forward, length,
                jaxpr):
    # TODO update to jax==0.1.42
    _, _, x_aval = jaxpr.in_avals
    _, y_aval = jaxpr.out_aval
    ys_aval = _promote_aval_rank(length, y_aval)

    # TODO fix param sharing
    cell_params = (submodule_params_iter.get_params(None), ) if len(
        submodule_params_iter.submodule_params) > 0 else ()

    def body_fun(i, vals):
        idx = i if forward else length - i - 1
        carry, ys = vals
        x = _index_arrays(idx, x_aval, xs)
        cell = parametrized(jc.jaxpr_as_fun(jaxpr))
        carry_out, y = cell.apply(cell_params, consts, carry, x)
        ys_out = _update_arrays(idx, y_aval, ys, y)
        return carry_out, ys_out

    ys_init = _empty_arrays(ys_aval)
    carry, ys = lax.fori_loop(0, length, body_fun, (init, ys_init))
    return jc.pack((carry, ys))
Example #24
def _tscan(f, a, bs, fields=(0, )):
    """
    Works like jax.lax.scan, but takes an additional `fields` argument that
    selects which fields of `a`'s structure to accumulate. Defaults to
    selecting only the first field. Other fields are filled with None.
    """
    # Note: code is copied and modified from lax.scan implementation in
    # [JAX](https://github.com/google/jax) to support the additional `fields`
    # arg. Original code has the following copyright:
    #
    # Copyright 2018 Google LLC
    #
    # Licensed under the Apache License, Version 2.0 (the "License")

    # convert pytree to flat jaxtuple
    a, a_tree = pytree_to_flatjaxtuple(a)
    bs, b_tree = pytree_to_flatjaxtuple(bs)
    fields, _ = pytree_to_flatjaxtuple(fields)
    f, out_tree = pytree_fun_to_flatjaxtuple_fun(wrap_init(f),
                                                 (a_tree, b_tree))

    # convert arrays to abstract values
    a_aval, _ = lax._abstractify(a)
    bs_aval, _ = lax._abstractify(bs)
    # convert bs to b
    b_aval = core.AbstractTuple(
        [ShapedArray(b.shape[1:], b.dtype) for b in bs_aval])

    # convert abstract values to partial values, then trace to get the jaxpr
    a_pval = partial_eval.PartialVal((a_aval, core.unit))
    b_pval = partial_eval.PartialVal((b_aval, core.unit))
    jaxpr, pval_out, consts = partial_eval.trace_to_jaxpr(f, (a_pval, b_pval))
    aval_out, _ = pval_out
    consts = core.pack(consts)

    out = tscan_p.bind(a, bs, fields, consts, aval_out=aval_out, jaxpr=jaxpr)
    return tree_unflatten(out_tree(), out)
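Ignoring tracing and packing, the semantics are scan with per-field history selection; a plain-Python reference sketch (tscan_reference is a hypothetical name):

def tscan_reference(f, a, bs, fields=(0,)):
    histories = [[] for _ in fields]
    for b in zip(*bs):  # the i-th element of each array in bs
        a = f(a, b)
        for buf, j in zip(histories, fields):
            buf.append(a[j])
    out = [None] * len(a)  # non-selected fields stay None
    for j, buf in zip(fields, histories):
        out[j] = buf
    return out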
Example #25
def while_loop(cond_fun, body_fun, init_val):
    """Call `body_fun` repeatedly in a loop while `cond_fun` is True.

  Arguments:
    cond_fun: pure function of type `T -> Bool`.
    body_fun: pure function of type `T -> T`.
    init_val: value of type `T`, a type that can be a scalar, array, or any
      (nested) Python tuple/list/dict thereof.

  Returns:
    The output from the final iteration of body_fun, of type `T`.

  The semantics of `while_loop` are given by this Python implementation::

    def while_loop(cond_fun, body_fun, init_val):
      val = init_val
      while cond_fun(val):
        val = body_fun(val)
      return val

  Unlike that pure Python version, `while_loop` is a JAX primitive and is
  lowered to a single XLA While HLO. That makes it useful for reducing
  compilation times for jit-compiled functions, since native Python loop
  constructs in an `@jit` function are unrolled, leading to large XLA
  computations.

  Another difference from using Python-native loop constructs is that
  `while_loop` is not (yet) reverse-mode differentiable because XLA computations
  require static bounds on memory requirements.
  """
    init_val_flat, in_tree = pytree_to_jaxtupletree(init_val)
    flat_body_fun, out_tree = pytree_fun_to_jaxtupletree_fun(
        lu.wrap_init(body_fun), (in_tree, ))
    flat_cond_fun, _ = pytree_fun_to_jaxtupletree_fun(lu.wrap_init(cond_fun),
                                                      (in_tree, ))

    pval_flat = lax._abstractify(init_val_flat)
    cond_jaxpr, _, cond_consts = pe.trace_to_jaxpr(flat_cond_fun,
                                                   (pval_flat, ))
    body_jaxpr, pval_out, body_consts = pe.trace_to_jaxpr(
        flat_body_fun, (pval_flat, ))
    aval_out, _ = pval_out

    # We don't want to promote literal constants as loop arguments; there are
    # sometimes many of them. We pass tracers as loop arguments, but leave
    # nontracers as constants. We also sort the constants so the nontracers are
    # first.
    def split_tracers_and_nontracers(jaxpr, consts):
        tracer = []
        nontracer = []
        for x in zip(jaxpr.constvars, consts):
            # TODO(phawkins): We avoid treating DeviceArrays as constant literals so
            # we don't copy large arrays back to the host. We probably should relax
            # this and either always copy small constants, or opportunistically use
            # DeviceArray values for which we already know npy_value.
            not_literal_const = isinstance(x[1],
                                           (core.Tracer, xla.DeviceArray))
            (tracer if not_literal_const else nontracer).append(x)
        tracer_vars, tracer_consts = unzip2(tracer)
        nontracer_vars, nontracer_consts = unzip2(nontracer)
        return nontracer_vars + tracer_vars, nontracer_consts, tracer_consts

    cond_split = split_tracers_and_nontracers(cond_jaxpr, cond_consts)
    cond_jaxpr.constvars, cond_nontracer_consts, cond_tracer_consts = cond_split
    body_split = split_tracers_and_nontracers(body_jaxpr, body_consts)
    body_jaxpr.constvars, body_nontracer_consts, body_tracer_consts = body_split

    if out_tree() != in_tree:
        raise TypeError(
            "body_fun input and output must have identical structure")
    out_flat = while_p.bind(
        init_val_flat,
        core.pack(cond_tracer_consts),
        core.pack(body_tracer_consts),
        cond_consts=lax._OpaqueParam(cond_nontracer_consts),
        body_consts=lax._OpaqueParam(body_nontracer_consts),
        aval_out=aval_out,
        cond_jaxpr=cond_jaxpr,
        body_jaxpr=body_jaxpr)
    return build_tree(out_tree(), out_flat)
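A usage sketch (the same signature survives in current jax.lax.while_loop):

from jax.lax import while_loop
print(while_loop(lambda x: x < 10, lambda x: x + 1, 0))  # 10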
Example #26
 def fun(x):
     y = pack((x, x))
     z = pack((y, x))
     y1, _ = z
     y2, _ = y1
     return y2
Example #27
 def foo(x):
     return np.tup_add(core.pack((x, y)))
Example #28
 def bar(y):
     x1, y1 = core.pack((x, y))
     return np.sin(x1 * y1)
Example #29
 def foo(x):
     x1, y1 = core.pack((x, y))
     assert y1 is y, (y1, y)
     return x1