예제 #1
0
def _scan_jvp(primals, tangents, forward, length, jaxpr, num_consts, num_carry,
              linear):
    """JVP rule for ``scan_p``.

    Works out which loop-carried values need a tangent (via a fixed point
    over the body jaxpr), materializes the zero tangents that must become
    real values, and binds a new ``scan_p`` whose body computes primals and
    nonzero tangents together.

    Args:
      primals: primal operands, laid out as consts, then initial carry, then
        scanned-over ``xs``.
      tangents: tangents aligned with ``primals``; ``ad_util.zero`` marks a
        symbolic zero tangent.
      forward, length: forwarded unchanged to the new ``scan_p`` bind.
      jaxpr: body jaxpr of the primal scan; its ``in_avals`` / ``out_avals``
        determine the number of ``xs`` and ``ys``.
      num_consts, num_carry: sizes of the const and carry operand groups.
      linear: per-operand linearity flags of the primal scan, laid out like
        ``primals``.

    Returns:
      ``(primals_out, tangents_out)``: ``primals_out`` is the final carry
      followed by ``ys``; ``tangents_out`` is aligned with it and holds
      ``ad_util.zero`` wherever the tangent is known to be zero.

    Raises:
      FixedPointError: if the carry nonzero pattern does not converge.
    """
    num_xs = len(jaxpr.in_avals) - num_carry - num_consts
    num_ys = len(jaxpr.out_avals) - num_carry
    nonzeros = [t is not ad_util.zero for t in tangents]
    const_nz, init_nz, xs_nz = split_list(nonzeros, [num_consts, num_carry])

    # Fixed point over which carries have nonzero tangents. A carry tangent is
    # nonzero if it is nonzero on entry or becomes nonzero after a body step.
    # The pattern only ever grows (it is OR-ed below), so each non-converged
    # iteration promotes at least one carry. That gives at most ``num_carry``
    # promotions plus one final iteration that confirms convergence and
    # leaves ``jaxpr_jvp`` built for the final pattern.
    carry_nz = init_nz
    for _ in range(1 + num_carry):
        nonzeros = const_nz + carry_nz + xs_nz
        jaxpr_jvp, nonzeros_out = ad.jvp_jaxpr(
            jaxpr, nonzeros, instantiate=carry_nz + [False] * num_ys)
        carry_nz_out = nonzeros_out[:num_carry]
        if carry_nz_out == carry_nz:
            break
        # Grow monotonically so the iteration bound above always holds.
        carry_nz = [a or b for a, b in zip(carry_nz, carry_nz_out)]
    else:
        raise FixedPointError

    # Materialize symbolic zeros for operands the JVP body treats as nonzero.
    tangents = [
        ad.instantiate_zeros(x, t) if t is ad_util.zero and nz else t
        for x, t, nz in zip(primals, tangents, nonzeros)
    ]

    consts, init, xs = split_list(primals, [num_consts, num_carry])
    all_tangents = split_list(tangents, [num_consts, num_carry])
    consts_dot, init_dot, xs_dot = _map(_prune_zeros, all_tangents)

    # Rearrange the JVP jaxpr's binders into the operand order used by the
    # bind below: consts, consts_dot, init, init_dot, xs, xs_dot.
    jaxpr_jvp_rearranged = ad.rearrange_binders(
        jaxpr_jvp, [num_consts, num_carry, num_xs],
        [len(consts_dot), len(init_dot), len(xs_dot)],
        [num_carry, num_ys],
        [len(init_dot), sum(nonzeros_out) - len(init_dot)])

    consts_linear, init_linear, xs_linear = split_list(
        linear, [num_consts, num_carry])
    # Tangent operands are marked linear; primal flags are kept as given.
    jaxpr_jvp_linear = (consts_linear + [True] * len(consts_dot) +
                        init_linear + [True] * len(init_dot) +
                        xs_linear + [True] * len(xs_dot))

    out_flat = scan_p.bind(
        *(consts + consts_dot + init + init_dot + xs + xs_dot),
        forward=forward,
        length=length,
        jaxpr=jaxpr_jvp_rearranged,
        num_consts=num_consts + len(consts_dot),
        num_carry=num_carry + len(init_dot),
        linear=jaxpr_jvp_linear)

    carry, carry_dot, ys, ys_dot = split_list(
        out_flat, [num_carry, len(init_dot), num_ys])
    primals_out = carry + ys
    # Re-expand the packed nonzero tangents, filling known zeros back in.
    nonzero_dots = iter(carry_dot + ys_dot)
    tangents_out = [
        next(nonzero_dots) if nz else ad_util.zero for nz in nonzeros_out
    ]
    return primals_out, tangents_out
예제 #2
0
def _for_jvp(primals, tangents, *, jaxpr, nsteps, reverse, which_linear):
    """JVP rule for ``for_p``.

    Args:
      primals: primal operands of the loop (the ``Ref`` values).
      tangents: tangents aligned with ``primals``; ``ad_util.Zero`` instances
        mark symbolic zeros.
      jaxpr: the stateful loop-body jaxpr (it has no outputs).
      nsteps, reverse: forwarded unchanged to the new ``for_p`` bind.
      which_linear: tuple of linearity flags aligned with ``primals``.

    Returns:
      ``(out_primals, out_tangents)``; ``out_tangents`` is aligned with
      ``out_primals`` and holds ``ad_util.Zero.from_value(p)`` wherever the
      tangent is known to be zero.

    Raises:
      Exception: ``"Invalid fixpoint"`` if the nonzero-tangent pattern does
        not converge.
    """
    nonzero_tangents = [not isinstance(t, ad_util.Zero) for t in tangents]
    # We need to find out which `Ref`s have nonzero tangents after running the
    # for loop. Ordinarily we do this with a fixed point on the body jaxpr but
    # a `for` body jaxpr is stateful and has no outputs. We therefore discharge
    # the state effect from the jaxpr and we will now have a "symmetric" jaxpr
    # where the inputs line up with the outputs. We use this discharged jaxpr
    # for the fixed point.
    discharged_jaxpr, body_consts = discharge_state(jaxpr, ())
    discharged_closed = core.ClosedJaxpr(discharged_jaxpr, body_consts)
    # The pattern only ever grows (it is OR-ed below), so every non-converged
    # iteration promotes at least one tangent to nonzero. ``len + 1``
    # iterations therefore always suffice, including the final confirming
    # iteration. A bound of ``len`` alone fails when there are no Refs.
    for _ in range(len(nonzero_tangents) + 1):
        # The leading ``False`` is for the body's first input, which never
        # receives a tangent (presumably the loop index).
        _, out_nonzero_tangents = ad.jvp_jaxpr(
            discharged_closed, [False] + nonzero_tangents,
            instantiate=nonzero_tangents)
        if out_nonzero_tangents == nonzero_tangents:
            break
        nonzero_tangents = [
            a or b for a, b in zip(nonzero_tangents, out_nonzero_tangents)
        ]
    else:
        raise Exception("Invalid fixpoint")
    # Materialize zeros for Refs whose tangent became nonzero, then drop the
    # remaining symbolic zeros (they are not passed to the loop).
    tangents = [
        ad.instantiate_zeros(t) if inst else t
        for t, inst in zip(tangents, nonzero_tangents)
    ]
    tangents = [t for t in tangents if type(t) is not ad_util.Zero]
    closed_jaxpr = core.ClosedJaxpr(jaxpr, ())
    jvp_jaxpr_, _ = ad.jvp_jaxpr(closed_jaxpr, [False] + nonzero_tangents, [])
    jvp_jaxpr, jvp_consts = jvp_jaxpr_.jaxpr, jvp_jaxpr_.consts
    # Consts are nonlinear, primals keep their flags, tangents are linear.
    jvp_which_linear = ((False, ) * len(jvp_consts) + which_linear +
                        (True, ) * len(tangents))
    out_flat = for_p.bind(*jvp_consts,
                          *primals,
                          *tangents,
                          jaxpr=jvp_jaxpr,
                          nsteps=nsteps,
                          reverse=reverse,
                          which_linear=jvp_which_linear)
    # `out_flat` includes constant inputs into the `for_loop` which are
    # converted into outputs as well. We don't care about these in AD so we
    # throw them out.
    _, out_primals, out_tangents = split_list(
        out_flat, [len(jvp_consts), len(primals)])
    out_tangents_iter = iter(out_tangents)
    out_tangents = [
        next(out_tangents_iter) if nz else ad_util.Zero.from_value(p)
        for p, nz in zip(out_primals, nonzero_tangents)
    ]
    return out_primals, out_tangents
예제 #3
0
def _convert_zeros(instantiate, example, tangent):
  t = type(instantiate)
  if t is bool:
    if instantiate:
      return ad.instantiate_zeros(example, tangent)
    elif tangent is ad_util.zero:
      return core.unit
    else:
      raise TypeError(tangent)  # not clear if ever reachable
  elif t is tuple:
    if type(tangent) is ad.TangentTuple:
      return core.pack(map(_convert_zeros, instantiate, example, tangent))
    elif tangent is ad_util.zero:
      zeros = [ad_util.zero] * len(instantiate)
      return core.pack(map(_convert_zeros, instantiate, example, zeros))
    else:
      raise TypeError(tangent)
  else:
    raise TypeError(t)