Example #1
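# Context note: this batching rule is excerpted from JAX's internal
# lax_control_flow module (circa 2019). The imports it relies on are not part
# of the excerpt; based on that era's module layout they would be roughly:
#
#   import numpy as onp
#   from jax import core, lax
#   from jax import linear_util as lu
#   from jax.interpreters import batching
#
# `while_loop` and `_jaxtupletree_select` are helpers defined in the same
# module.
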
def _while_loop_batching_rule(batched_args, batch_dims, aval_out, cond_jaxpr,
                              body_jaxpr):
    # See https://github.com/google/jax/issues/441 for a discussion.
    # To batch a while_loop, we need to do some masking, since the elements of
    # the batch may run for different numbers of iterations. We perform that
    # masking using lax.select, and keep the loop running as long as any batch
    # element still needs to run, effectively by using an np.any(...) in the
    # cond_fun. (A standalone sketch of this strategy follows the function.)
    # The basic strategy here is to lift `cond_jaxpr` and `body_jaxpr` back into
    # traceable Python functions using `core.eval_jaxpr`. Then we can batch them
    # using `batching.batch_transform` (the transform underlying `api.vmap`).
    # TODO(mattjj): Revise this using scan machinery (and fixed-point the loop
    # carry instead of lifting it all the way!)
    init_val, cond_consts, body_consts = batched_args
    init_val_bd, cond_consts_bd, body_consts_bd = batch_dims

    # All batched arguments must agree on a single batch size: take the union
    # of the per-argument sizes and check that exactly one remains.
    sizes = lax._reduce(set.union,
                        map(batching.dimsize, batch_dims, batched_args))
    size = sizes.pop()
    assert not sizes

    # TODO(mattjj): if cond_consts_bd is also None, we could keep cond_fun
    # unbatched and avoid the masking logic, but we ignore that optimization
    init_val = batching.bdim_at_front(init_val,
                                      init_val_bd,
                                      size,
                                      force_broadcast=True)
    init_val_bd = 0

    # The batched cond_fun: evaluate the original cond under a batching
    # transform, then OR-reduce the per-element predicates so the loop keeps
    # running while any batch element is still live.
    def batched_cond_fun(batched_loop_carry):
        @lu.wrap_init
        def lifted(loop_carry, cond_consts):
            return core.eval_jaxpr(cond_jaxpr, cond_consts, (), loop_carry)

        f = batching.batch_transform(lifted, size,
                                     (init_val_bd, cond_consts_bd), 0)
        preds = f.call_wrapped((batched_loop_carry, cond_consts))
        return lax.reduce(preds, onp.array(False), lax.bitwise_or, [0])

    # The batched body_fun: step every element, then use the per-element
    # predicate to select the new carry for live elements while leaving
    # finished elements' carries unchanged (the masking).
    def batched_body_fun(batched_loop_carry):
        @lu.wrap_init
        def lifted(loop_carry, cond_consts, body_consts):
            pred = core.eval_jaxpr(cond_jaxpr, cond_consts, (), loop_carry)
            new_loop_carry = core.eval_jaxpr(body_jaxpr, body_consts, (),
                                             loop_carry)
            return _jaxtupletree_select(pred, new_loop_carry, loop_carry)

        f = batching.batch_transform(
            lifted, size, (init_val_bd, cond_consts_bd, body_consts_bd),
            init_val_bd)
        return f.call_wrapped((batched_loop_carry, cond_consts, body_consts))

    return while_loop(batched_cond_fun, batched_body_fun,
                      init_val), init_val_bd
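
# --- Illustration (not part of the original excerpt) ---
# A minimal runnable sketch of the masking strategy described in the comments
# above, written against JAX's public API rather than the internals this rule
# uses. Each batch element counts down from a different start, so the loop
# must keep running until every element is done: the cond ORs the per-element
# predicates together (the np.any(...) idea), and lax.select holds finished
# elements fixed while the stragglers keep stepping.
import jax.numpy as jnp
from jax import lax

def masked_batched_countdown(xs):
    def cond(carry):
        return jnp.any(carry > 0)  # keep looping while any element is live

    def body(carry):
        stepped = carry - 1
        # Mask: advance only the still-live elements; freeze the rest.
        return lax.select(carry > 0, stepped, carry)

    return lax.while_loop(cond, body, xs)

# masked_batched_countdown(jnp.array([3, 1, 5])) runs five iterations and
# returns [0, 0, 0]; elements 0 and 1 stop changing once they reach zero.
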
Example #2
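# Context note: as with Example #1, this rule is excerpted from JAX's internal
# lax_control_flow module. `scan_p`, `_binary_lattice_eq`,
# `_binary_lattice_join`, and `FixedPointError` are module-level helpers
# defined alongside it and not shown here.
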
def _scan_batching_rule(batched_args, batch_dims, forward, length, jaxpr):
    consts, init, xs = batched_args
    consts_bdim, init_bdim, xs_bdim = batch_dims

    sizes = lax._reduce(set.union,
                        map(batching.dimsize, batch_dims, batched_args))
    size = sizes.pop()
    assert not sizes

    consts_batched = batching.where_batched(consts_bdim)
    init_batched = batching.where_batched(init_bdim)
    xs_batched = batching.where_batched(xs_bdim)

    # Fixed point: running the body can make more parts of the carry batched
    # (per-component batchedness only grows in a two-point lattice), so join
    # with the output batchedness until it stabilizes, with 1000 iterations as
    # a safety bound. (A toy sketch of this loop follows the function.)
    carry_batched = init_batched
    for _ in range(1000):
        which_batched = (consts_batched, carry_batched, xs_batched)
        jaxpr_batched, batched_out = batching.batch_jaxpr(
            jaxpr, size, which_batched, instantiate=(carry_batched, False))
        carry_batched_out, ys_batched = batched_out
        if _binary_lattice_eq(carry_batched_out, carry_batched):
            break
        else:
            carry_batched = _binary_lattice_join(carry_batched_out,
                                                 carry_batched)
    else:
        raise FixedPointError

    # Materialize batch dimensions at canonical positions: axis 0 for consts
    # and the carry, axis 1 for the scanned xs (axis 0 is the scan axis). Note
    # that init is instantiated to the fixed-point carry_batched, not its own
    # original batchedness.
    consts_batched = batching.instantiate_bdim(size, 0, consts_batched,
                                               consts_bdim, consts)
    init_batched = batching.instantiate_bdim(size, 0, carry_batched, init_bdim,
                                             init)
    xs_batched = batching.instantiate_bdim(size, 1, xs_batched, xs_bdim, xs)

    carry_out, ys = scan_p.bind(consts_batched,
                                init_batched,
                                xs_batched,
                                forward=forward,
                                length=length,
                                jaxpr=jaxpr_batched)

    carry_out_bdim = batching.bools_to_bdims(0, carry_batched)
    ys_bdim = batching.bools_to_bdims(1, ys_batched)
    return core.pack((carry_out, ys)), (carry_out_bdim, ys_bdim)
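
# --- Illustration (not part of the original excerpt) ---
# A toy, self-contained sketch of the fixed-point computation above. Per carry
# component, "is it batched?" lives in a two-point lattice (False below True),
# and the rule joins the carry's input batchedness with whatever the body
# produces until nothing changes. The names below are illustrative stand-ins,
# not the internal helpers (batch_jaxpr, _binary_lattice_join, ...) the rule
# uses.

def batchedness_fixed_point(body_batchedness, carry_batched, max_iters=1000):
    # body_batchedness maps per-component input bools to per-component output
    # bools, standing in for the batched_out that batch_jaxpr reports.
    for _ in range(max_iters):
        out_batched = body_batchedness(carry_batched)
        joined = tuple(a or b for a, b in zip(out_batched, carry_batched))
        if joined == carry_batched:  # lattice join reached a fixed point
            return carry_batched
        carry_batched = joined
    raise RuntimeError("fixed point not reached")  # cf. FixedPointError

# Example: inside the body, component 0 flows into component 1, so once
# component 0 is batched, component 1 must become batched on the next pass.
body = lambda cb: (cb[0], cb[0] or cb[1])
print(batchedness_fixed_point(body, (True, False)))  # -> (True, True)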