예제 #1
0
def _cond_jvp(primals, tangents, branches, linear):
    """JVP rule for ``cond_p``.

    Every branch is differentiated with ``ad.jvp_jaxpr`` and a single new
    ``cond_p`` is bound whose operands are the original operands followed by
    the operand tangents that survive ``_prune_zeros``.

    Args:
      primals: the branch index followed by the operands.
      tangents: one tangent per primal; the tangent of the index must be an
        ``ad_util.Zero``.
      branches: the branch jaxprs (iterated twice).
      linear: flags forwarded to ``cond_p.bind`` as ``linear``, extended with
        one ``True`` per appended tangent operand.

    Returns:
      ``(out_primals, out_tangents)``; outputs whose tangent is zero in every
      branch get ``ad_util.Zero.from_value(primal)`` instead of a bound value.
    """
    index_is_nonzero, *operand_nonzeros = [
        type(tangent) is not ad_util.Zero for tangent in tangents
    ]
    assert index_is_nonzero is False

    # Pass 1: ask each branch which of its outputs have a non-zero tangent,
    # without forcing any of them to be instantiated.
    per_branch_out_nonzeros = []
    for branch in branches:
        probe = ad.jvp_jaxpr(branch, operand_nonzeros, instantiate=False)
        per_branch_out_nonzeros.append(probe[1])
    # An output tangent is kept if at least one branch produces it.
    out_nz = [any(flags) for flags in zip(*per_branch_out_nonzeros)]

    # Pass 2: re-differentiate with a common `instantiate` mask so that every
    # branch agrees on which tangent outputs exist.
    jvp_branches = []
    for branch in branches:
        differentiated = ad.jvp_jaxpr(branch,
                                      operand_nonzeros,
                                      instantiate=out_nz)
        jvp_branches.append(differentiated[0])
    jvp_branches = tuple(jvp_branches)

    branch_index, *operands = primals
    _, *operand_tangents = tangents
    # NOTE(review): `_prune_zeros` is defined elsewhere; presumably it drops
    # the symbolically-zero tangents -- confirm against its definition.
    operand_tangents = _prune_zeros(operand_tangents)

    jvp_linear = tuple(linear) + (True, ) * len(operand_tangents)
    flat_outputs = cond_p.bind(branch_index,
                               *operands,
                               *operand_tangents,
                               branches=jvp_branches,
                               linear=jvp_linear)
    out_primals, bound_tangents = split_list(flat_outputs, [len(out_nz)])

    # Re-expand the bound tangents to one entry per primal output.
    remaining_tangents = iter(bound_tangents)
    out_tangents = []
    for primal_out, has_tangent in zip(out_primals, out_nz):
        if has_tangent:
            out_tangents.append(next(remaining_tangents))
        else:
            out_tangents.append(ad_util.Zero.from_value(primal_out))
    return out_primals, out_tangents
예제 #2
0
def _scan_jvp(primals, tangents, forward, length, jaxpr, num_consts, num_carry,
              linear):
    """JVP rule for ``scan_p`` operating on flat argument lists.

    The flat ``primals`` / ``tangents`` / ``linear`` sequences are each split
    as ``[consts | carry | xs]`` using ``num_consts`` and ``num_carry``.

    Args:
      primals: flat primal inputs.
      tangents: flat tangents; ``ad_util.zero`` marks a symbolic zero.
      forward, length: forwarded unchanged to ``scan_p.bind``.
      jaxpr: the scan body; its ``in_avals`` / ``out_avals`` determine the
        number of ``xs`` and ``ys``.
      num_consts: number of leading constant inputs.
      num_carry: number of carry values.
      linear: per-input flags, extended with ``True`` for every tangent input.

    Returns:
      ``(primals_out, tangents_out)`` with ``ad_util.zero`` in the positions
      that ``ad.jvp_jaxpr`` reported as zero.

    Raises:
      FixedPointError: if the carry non-zero pattern has not stabilised after
        1000 iterations.
    """
    num_xs = len(jaxpr.in_avals) - num_carry - num_consts
    num_ys = len(jaxpr.out_avals) - num_carry

    tangent_is_nonzero = [t is not ad_util.zero for t in tangents]
    const_nz, init_nz, xs_nz = split_list(tangent_is_nonzero,
                                          [num_consts, num_carry])

    # Iterate until the set of carry tangents that are non-zero on entry to
    # the body equals the set that is non-zero on exit.
    carry_nz = init_nz
    for _attempt in range(1000):
        nonzeros = const_nz + carry_nz + xs_nz
        instantiate_mask = carry_nz + [False] * num_ys
        jaxpr_jvp, nonzeros_out = ad.jvp_jaxpr(jaxpr,
                                               nonzeros,
                                               instantiate=instantiate_mask)
        carry_nz_out = nonzeros_out[:num_carry]
        if carry_nz_out == carry_nz:
            break
        carry_nz = carry_nz_out
    else:
        raise FixedPointError

    # Symbolic zeros that the fixed point marked as non-zero must become real
    # values before they are passed to the primitive.
    materialized = []
    for primal, tangent, is_nonzero in zip(primals, tangents, nonzeros):
        if is_nonzero and tangent is ad_util.zero:
            tangent = ad.instantiate_zeros(primal, tangent)
        materialized.append(tangent)

    consts, init, xs = split_list(primals, [num_consts, num_carry])
    tangent_groups = split_list(materialized, [num_consts, num_carry])
    consts_dot, init_dot, xs_dot = _map(_prune_zeros, tangent_groups)

    # NOTE(review): the four lists look like (primal-in, tangent-in,
    # primal-out, tangent-out) group sizes -- confirm against
    # `ad.rearrange_binders`.
    in_primal_sizes = [num_consts, num_carry, num_xs]
    in_tangent_sizes = [len(consts_dot), len(init_dot), len(xs_dot)]
    out_primal_sizes = [num_carry, num_ys]
    out_tangent_sizes = [len(init_dot), sum(nonzeros_out) - len(init_dot)]
    jaxpr_jvp_rearranged = ad.rearrange_binders(jaxpr_jvp, in_primal_sizes,
                                                in_tangent_sizes,
                                                out_primal_sizes,
                                                out_tangent_sizes)

    consts_linear, init_linear, xs_linear = split_list(linear,
                                                       [num_consts, num_carry])
    jaxpr_jvp_linear = consts_linear + [True] * len(consts_dot)
    jaxpr_jvp_linear = jaxpr_jvp_linear + init_linear + [True] * len(init_dot)
    jaxpr_jvp_linear = jaxpr_jvp_linear + xs_linear + [True] * len(xs_dot)

    # Each primal group is immediately followed by its tangent group.
    flat_inputs = consts + consts_dot + init + init_dot + xs + xs_dot
    out_flat = scan_p.bind(*flat_inputs,
                           forward=forward,
                           length=length,
                           jaxpr=jaxpr_jvp_rearranged,
                           num_consts=num_consts + len(consts_dot),
                           num_carry=num_carry + len(init_dot),
                           linear=jaxpr_jvp_linear)

    carry, carry_dot, ys, ys_dot = split_list(
        out_flat, [num_carry, len(init_dot), num_ys])
    primals_out = carry + ys

    # Re-insert symbolic zeros where no tangent output was produced.
    bound_tangents = iter(carry_dot + ys_dot)
    tangents_out = []
    for is_nonzero in nonzeros_out:
        tangents_out.append(next(bound_tangents) if is_nonzero else ad_util.zero)
    return primals_out, tangents_out
예제 #3
0
def _for_jvp(primals, tangents, *, jaxpr, nsteps, reverse, which_linear):
    """JVP rule for ``for_p``.

    Args:
      primals: primal inputs of the loop.
      tangents: one tangent per primal; ``ad_util.Zero`` instances are
        symbolic zeros.
      jaxpr: the loop body jaxpr.
      nsteps, reverse: forwarded unchanged to ``for_p.bind``.
      which_linear: tuple of flags; extended with ``False`` for each constant
        of the differentiated body and ``True`` for each tangent input.

    Returns:
      ``(out_primals, out_tangents)`` with ``ad_util.Zero.from_value(primal)``
      where the tangent is zero after the fixed point.

    Raises:
      Exception: ``"Invalid fixpoint"`` if the non-zero pattern does not
        stabilise within ``len(tangents)`` iterations.
    """
    nonzero_tangents = [not isinstance(t, ad_util.Zero) for t in tangents]

    # To learn which `Ref`s carry a non-zero tangent once the loop has run we
    # would normally iterate the body jaxpr to a fixed point. A `for` body is
    # stateful and has no outputs, so the state effect is discharged first:
    # the resulting jaxpr is "symmetric" (inputs line up with outputs) and can
    # be used for the fixed point.
    discharged_jaxpr, body_consts = discharge_state(jaxpr, ())
    for _ in range(len(nonzero_tangents)):
        probe_jaxpr = core.ClosedJaxpr(discharged_jaxpr, body_consts)
        # The leading `False` is an extra first input with no tangent.
        jvp_probe = ad.jvp_jaxpr(probe_jaxpr, [False] + nonzero_tangents,
                                 instantiate=nonzero_tangents)
        _, out_nonzero_tangents = jvp_probe
        if out_nonzero_tangents == nonzero_tangents:
            break
        # NOTE(review): the result is concatenated with a list on the next
        # iteration, so `map` must return a list here -- presumably the module
        # rebinds `map` to a list-returning variant; confirm.
        nonzero_tangents = map(operator.or_, nonzero_tangents,
                               out_nonzero_tangents)
    else:
        raise Exception("Invalid fixpoint")

    # Materialise the tangents the fixed point needs and drop the ones that
    # are still symbolic zeros.
    bound_tangents = []
    for tangent, must_instantiate in zip(tangents, nonzero_tangents):
        if must_instantiate:
            tangent = ad.instantiate_zeros(tangent)
        if type(tangent) is not ad_util.Zero:
            bound_tangents.append(tangent)

    jvp_closed, _ = ad.jvp_jaxpr(core.ClosedJaxpr(jaxpr, ()),
                                 [False] + nonzero_tangents, [])
    jvp_jaxpr = jvp_closed.jaxpr
    jvp_consts = jvp_closed.consts

    consts_linear = (False, ) * len(jvp_consts)
    tangents_linear = (True, ) * len(bound_tangents)
    jvp_which_linear = consts_linear + which_linear + tangents_linear

    out_flat = for_p.bind(*jvp_consts,
                          *primals,
                          *bound_tangents,
                          jaxpr=jvp_jaxpr,
                          nsteps=nsteps,
                          reverse=reverse,
                          which_linear=jvp_which_linear)

    # `out_flat` also echoes the constant inputs of the `for_loop` as outputs;
    # they are irrelevant for AD and are discarded.
    _, out_primals, nonzero_out_tangents = split_list(
        out_flat, [len(jvp_consts), len(primals)])

    remaining_tangents = iter(nonzero_out_tangents)
    out_tangents = []
    for primal_out, is_nonzero in zip(out_primals, nonzero_tangents):
        if is_nonzero:
            out_tangents.append(next(remaining_tangents))
        else:
            out_tangents.append(ad_util.Zero.from_value(primal_out))
    return out_primals, out_tangents
예제 #4
0
def _scan_jvp(primals, tangents, forward, length, jaxpr):
    """JVP rule for ``scan_p`` on packed ``(consts, init, xs)`` arguments.

    Args:
      primals: a ``(consts, init, xs)`` triple.
      tangents: a ``(consts_dot, init_dot, xs_dot)`` triple.
      forward, length: forwarded unchanged to ``scan_p.bind``.
      jaxpr: the scan body; it must have exactly three ``in_avals`` and a
        two-element ``out_aval``.

    Returns:
      A pair: ``core.pack((carry_out, ys))`` and an ``ad.TangentTuple`` of the
      matching tangents, with zeros restored by ``ad.put_zeros``.

    Raises:
      FixedPointError: if the carry non-zero structure has not stabilised
        after 1000 iterations.
    """
    consts, init, xs = primals
    consts_dot, init_dot, xs_dot = tangents
    # These unpacks only check the arity of the body's avals; the unpacked
    # values are not used below.
    _consts_aval, _carry_aval, _x_aval = jaxpr.in_avals
    _, _y_aval = jaxpr.out_aval

    consts_nonzeros = ad.get_nonzeros(consts_dot)
    xs_nonzeros = ad.get_nonzeros(xs_dot)  # same as x_nonzeros b/c arrays

    # Grow the carry's non-zero structure until the body maps it to itself.
    carry_nonzeros = ad.get_nonzeros(init_dot)
    for _attempt in range(1000):
        nonzeros = (consts_nonzeros, carry_nonzeros, xs_nonzeros)
        instantiate = (carry_nonzeros, False)
        jaxpr_jvp, nonzeros_out = ad.jvp_jaxpr(jaxpr,
                                               nonzeros,
                                               instantiate=instantiate)
        carry_nonzeros_out, ys_nonzeros = nonzeros_out
        if carry_nonzeros_out == carry_nonzeros:
            break
        carry_nonzeros = _binary_lattice_join(carry_nonzeros_out,
                                              carry_nonzeros)
    else:
        raise FixedPointError

    # `_convert_zeros` resembles `strip_zeros` but consults the explicit
    # lattice information so that some zeros are instantiated -- notably those
    # in `init_dot` that the fixed point marked as non-zero.
    nonzero_init_dot = _convert_zeros(carry_nonzeros, init, init_dot)
    nonzero_consts_dot = _convert_zeros(consts_nonzeros, consts, consts_dot)
    nonzero_xs_dot = _convert_zeros(xs_nonzeros, xs, xs_dot)

    # Pair each primal with its tangent before binding.
    consts_dual = core.pack((consts, nonzero_consts_dot))
    init_dual = core.pack((init, nonzero_init_dot))
    xs_dual = core.pack((xs, nonzero_xs_dot))

    bound = scan_p.bind(consts_dual,
                        init_dual,
                        xs_dual,
                        forward=forward,
                        length=length,
                        jaxpr=jaxpr_jvp)
    carry_out_dual, ys_dual = bound

    ys, stripped_ys_dot = ys_dual
    ys_dot = ad.put_zeros(ad.TangentTuple, ys_nonzeros, stripped_ys_dot)

    carry_out, stripped_carry_dot = carry_out_dual
    carry_out_dot = ad.put_zeros(ad.TangentTuple, carry_nonzeros_out,
                                 stripped_carry_dot)

    primals_out = core.pack((carry_out, ys))
    tangents_out = ad.TangentTuple((carry_out_dot, ys_dot))
    return primals_out, tangents_out
예제 #5
0
def remat_jvp(primals, tangents, jaxpr, prevent_cse, differentiated, policy):
  """JVP rule for ``remat_p``.

  Differentiates ``jaxpr`` with ``ad.jvp_jaxpr`` and binds ``remat_p`` on the
  result, passing the constants of the differentiated jaxpr, the primals and
  the non-``Zero`` tangents as operands.

  Args:
    primals: primal inputs.
    tangents: one tangent per primal; ``ad_util.Zero`` instances are skipped.
    jaxpr: the rematerialised jaxpr; it must have no ``constvars``.
    prevent_cse, differentiated, policy: forwarded unchanged to
      ``remat_p.bind``.

  Returns:
    ``(out_primals, out_tangents)``; outputs reported as zero by
    ``ad.jvp_jaxpr`` get ``ad_util.Zero.from_value(primal)``.
  """
  assert not jaxpr.constvars

  in_nonzeros = []
  for tangent in tangents:
    in_nonzeros.append(type(tangent) is not ad_util.Zero)

  closed_jaxpr = core.ClosedJaxpr(jaxpr, ())
  closed_jvp, out_nonzeros = ad.jvp_jaxpr(closed_jaxpr, in_nonzeros, False)

  nonzero_tangents = [
      tangent for tangent in tangents if type(tangent) is not ad_util.Zero
  ]

  # The consts of the differentiated jaxpr are passed as leading operands.
  jvp_jaxpr = pe.convert_constvars_jaxpr(closed_jvp.jaxpr)
  flat_outputs = remat_p.bind(
      *closed_jvp.consts, *primals, *nonzero_tangents,
      jaxpr=jvp_jaxpr, prevent_cse=prevent_cse,
      differentiated=differentiated, policy=policy)

  out_primals, bound_tangents = split_list(flat_outputs, [len(jaxpr.outvars)])

  # Re-expand to one tangent per primal output.
  remaining_tangents = iter(bound_tangents)
  out_tangents = []
  for primal_out, is_nonzero in zip(out_primals, out_nonzeros):
    if is_nonzero:
      out_tangents.append(next(remaining_tangents))
    else:
      out_tangents.append(ad_util.Zero.from_value(primal_out))
  return out_primals, out_tangents
예제 #6
0
def custom_vmap_jvp(primals, tangents, *, call, rule, in_tree, out_tree):
    """JVP rule for ``custom_vmap_p``.

    Binds a new ``custom_vmap_p`` whose ``call`` is the JVP of ``call`` and
    whose batching rule is ``jvp_of_rule_rule`` (defined below), a rule
    derived from the user-supplied ``rule``.

    Args:
      primals: flat primal inputs.
      tangents: flat tangents; every entry is passed through
        ``ad.instantiate_zeros`` before binding.
      call: the jaxpr to differentiate.
      rule: the custom batching rule for ``call``.
      in_tree: input tree structure; doubled into ``(in_tree, in_tree)`` for
        the new primitive.
      out_tree: output tree structure; doubled into ``(out_tree, out_tree)``.

    Returns:
      ``(out_primals, out_tangents)``: the first and second halves of the
      flat outputs of the bound primitive.
    """
    def jvp_of_rule_rule(axis_size, in_batched, primals, tangents):
        """Batching rule installed on the differentiated primitive.

        ``in_batched`` is a ``(primal flags, tangent flags)`` pair. Inputs
        batched on both sides are handled by ``rule`` (through ``call_rule``);
        inputs batched on only one side are handled by an outer
        ``vmap_unrestricted`` over axis 0.

        Returns ``((out_ps, out_ts), (out_batched_ps, out_batched_ts))``.
        """
        in_batched_ps, in_batched_ts = in_batched

        # Batched in both the primal and its tangent.
        mutually_batched = tree_map(operator.and_, in_batched_ps,
                                    in_batched_ts)
        # in_axes for the outer vmap: 0 where only the primal is batched ...
        extra_batched_ps = tree_map(
            lambda pb, tb: 0
            if pb and not tb else None, in_batched_ps, in_batched_ts)
        # ... and 0 where only the tangent is batched; None everywhere else.
        extra_batched_ts = tree_map(
            lambda pb, tb: 0
            if tb and not pb else None, in_batched_ps, in_batched_ts)

        # Side channel through which `to_jvp` reports the `out_batched` value
        # returned by `rule`.
        out_mutually_batched = lu.Store()
        flat_ps_ts, tree_ps_ts = tree_flatten((primals, tangents))
        # `None` entries are treated as leaves so this flattening has one
        # entry per input, matching `flat_ps_ts`.
        flat_extra_batched_ps_ts, tree_ps_ts2 = tree_flatten(
            (extra_batched_ps, extra_batched_ts), is_leaf=lambda x: x is None)

        # TODO(frostig): assert these also equal:
        #   treedef_tuple((in_tree, in_tree))
        # once https://github.com/google/jax/issues/9066 is fixed
        assert tree_ps_ts == tree_ps_ts2
        del tree_ps_ts2

        def to_jvp(*primals):
            """Apply ``rule`` to the mutually-batched inputs and return its output."""
            out, out_batched = call_rule(rule, axis_size, mutually_batched,
                                         primals)
            check_vmap_rule_trees(rule, out_tree, tree_structure(out),
                                  tree_structure(out_batched))
            # Recorded here, read back after the vmap below has run.
            out_mutually_batched.store(out_batched)
            return out

        def to_vmap_over_extra_batched_dims(primals, tangents):
            """Differentiate ``to_jvp`` at ``primals`` along ``tangents``."""
            return jax.jvp(to_jvp, primals, tangents)

        # `out_tree2` is called below (after the function has run) to obtain
        # the output tree.
        to_vmap_over_extra_batched_dims_flat, out_tree2 = flatten_fun_nokwargs(
            lu.wrap_init(to_vmap_over_extra_batched_dims), tree_ps_ts)

        # Map over the axes that are batched on only one side of each
        # (primal, tangent) pair.
        flat_out_ps_ts, flat_out_axes = vmap_unrestricted(
            to_vmap_over_extra_batched_dims_flat,
            *flat_ps_ts,
            in_axes=flat_extra_batched_ps_ts,
            axis_name=core.no_axis_name,
            axis_size=axis_size)

        # The flat outputs are the primal outputs followed by an equal number
        # of tangent outputs.
        n, ragged = divmod(len(flat_out_ps_ts), 2)
        assert not ragged
        flat_out_ps, flat_out_ts = flat_out_ps_ts[:n], flat_out_ps_ts[n:]
        flat_out_axes_p, flat_out_axes_t = flat_out_axes[:n], flat_out_axes[n:]
        # NOTE(review): `maybe_bdim_at_front` presumably moves a mapped axis to
        # position 0 and leaves unmapped values alone -- confirm.
        # NOTE(review): `map` results are star-unpacked below; if `map` is the
        # builtin they are single-use iterators -- check the module's alias.
        flat_out_ps = map(maybe_bdim_at_front, flat_out_ps, flat_out_axes_p)
        flat_out_extra_batched_ps = [
            d is not not_mapped for d in flat_out_axes_p
        ]
        flat_out_ts = map(maybe_bdim_at_front, flat_out_ts, flat_out_axes_t)
        flat_out_extra_batched_ts = [
            d is not not_mapped for d in flat_out_axes_t
        ]

        out_ps, out_ts = tree_unflatten(out_tree2(),
                                        [*flat_out_ps, *flat_out_ts])
        out_extra_batched_ps, out_extra_batched_ts = tree_unflatten(
            out_tree2(),
            [*flat_out_extra_batched_ps, *flat_out_extra_batched_ts])

        # An output is batched if `rule` reported it batched or the outer vmap
        # mapped it.
        out_batched_ps = tree_map(operator.or_, out_mutually_batched.val,
                                  out_extra_batched_ps)
        out_batched_ts = tree_map(operator.or_, out_mutually_batched.val,
                                  out_extra_batched_ts)

        return (out_ps, out_ts), (out_batched_ps, out_batched_ts)

    # Every tangent is instantiated, and `call` is differentiated with all
    # inputs marked non-zero, so primals and tangents pair up one-to-one.
    tangents = map(ad.instantiate_zeros, tangents)
    jvp_call, _ = ad.jvp_jaxpr(call, [True] * len(primals), True)
    jvp_in_tree = treedef_tuple((in_tree, in_tree))
    jvp_out_tree = treedef_tuple((out_tree, out_tree))
    outs = custom_vmap_p.bind(*primals,
                              *tangents,
                              call=jvp_call,
                              rule=jvp_of_rule_rule,
                              in_tree=jvp_in_tree,
                              out_tree=jvp_out_tree)
    # First half: primal outputs; second half: tangent outputs.
    assert len(outs) % 2 == 0, len(outs)
    out_primals, out_tangents = util.split_list(outs, [len(outs) // 2])
    return out_primals, out_tangents