Example #1
def _while_loop_batching_rule(batched_args, batch_dims, cond_consts,
                              body_consts, aval_out, cond_jaxpr, body_jaxpr):
    # See https://github.com/google/jax/issues/441 for a discussion.
    # To batch a while_loop, we need to do some masking, since the elements of the
    # batch may run for different numbers of iterations. We perform that masking
    # using lax.select, and keep the loop running so long as any of the batch
    # elements still need to run, effectively using an np.any(...) in the
    # cond_fun.
    # The basic strategy here is to lift `cond_jaxpr` and `body_jaxpr` back into
    # traceable Python functions using `core.eval_jaxpr`. Then we can batch them
    # using `batching.batch_transform` (the transform underlying `api.vmap`). This
    # code also avoids broadcasting `cond_tracer_consts` and `body_tracer_consts`.
    init_val, cond_tracer_consts, body_tracer_consts = batched_args
    init_val_bd, cond_tracer_consts_bd, body_tracer_consts_bd = batch_dims

    sizes = lax._reduce(set.union,
                        map(batching.dimsize, batch_dims, batched_args))
    size = sizes.pop()
    assert not sizes

    # TODO(mattjj): if cond_tracer_consts_bd is also None, we could keep cond_fun
    # unbatched and avoid the masking logic, but we ignore that optimization
    init_val = batching.bdim_at_front(init_val,
                                      init_val_bd,
                                      size,
                                      force_broadcast=True)
    init_val_bd = 0

    def batched_cond_fun(batched_loop_carry):
        @lu.wrap_init
        def lifted(loop_carry, cond_tracer_consts):
            cond_tracer_consts = tuple(x for x in cond_tracer_consts)
            return core.eval_jaxpr(cond_jaxpr,
                                   cond_consts.val + cond_tracer_consts, (),
                                   loop_carry)

        f = batching.batch_transform(lifted, size,
                                     (init_val_bd, cond_tracer_consts_bd), 0)
        preds = f.call_wrapped((batched_loop_carry, cond_tracer_consts))
        return lax.reduce(preds, onp.array(False), lax.bitwise_or, [0])

    def batched_body_fun(batched_loop_carry):
        @lu.wrap_init
        def lifted(loop_carry, cond_tracer_consts, body_tracer_consts):
            cond_tracer_consts = tuple(x for x in cond_tracer_consts)
            body_tracer_consts = tuple(x for x in body_tracer_consts)
            pred = core.eval_jaxpr(cond_jaxpr,
                                   cond_consts.val + cond_tracer_consts, (),
                                   loop_carry)
            new_loop_carry = core.eval_jaxpr(
                body_jaxpr, body_consts.val + body_tracer_consts, (),
                loop_carry)
            return _jaxtupletree_select(pred, new_loop_carry, loop_carry)

        f = batching.batch_transform(
            lifted, size,
            (init_val_bd, cond_tracer_consts_bd, body_tracer_consts_bd),
            init_val_bd)
        return f.call_wrapped(
            (batched_loop_carry, cond_tracer_consts, body_tracer_consts))

    return while_loop(batched_cond_fun, batched_body_fun,
                      init_val), init_val_bd
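
For intuition about the behavior this rule implements, here is a minimal usage sketch, assuming the modern public API (jax.vmap and lax.while_loop) rather than the internal machinery above: each batch element runs a different number of iterations, and the masking keeps finished elements' carries fixed until every element's predicate is False.

import jax
import jax.numpy as jnp
from jax import lax

def count_to(limit):
    # Each batch element iterates `limit` times, accumulating 0 + 1 + ...
    def cond_fun(carry):
        i, _ = carry
        return i < limit

    def body_fun(carry):
        i, acc = carry
        return i + 1, acc + i

    return lax.while_loop(cond_fun, body_fun, (0, 0))

limits = jnp.array([1, 3, 5])
# The batched loop runs max(limits) times overall; already-finished
# elements are masked with select so their carries stop changing.
print(jax.vmap(count_to)(limits))  # (Array([1, 3, 5]), Array([0, 3, 10]))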
Example #2
def svd_batching_rule(batched_args, batch_dims, full_matrices, compute_uv):
    x, = batched_args
    bd, = batch_dims
    # Move the batch dimension to the leading axis; the primitive itself
    # operates on stacks of matrices, so it can be bound directly.
    x = batching.bdim_at_front(x, bd)
    return svd_p.bind(x, full_matrices=full_matrices, compute_uv=compute_uv), 0
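
The remaining rules (Examples #3 through #6 below) follow the same pattern: move the batch dimension to the leading axis with bdim_at_front, bind the primitive on the stacked input, and report the output as batched along axis 0. A minimal sketch of the effect, assuming the modern public API (jax.vmap, jnp.linalg.svd):

import jax
import jax.numpy as jnp

# A stack of 5x5 matrices whose batch axis sits in position 1, not 0.
xs = jnp.ones((4, 3, 5, 5))

# vmapping over axis 1 exercises the rule above: the batch dimension is
# moved to the front, svd runs on the stacked input, and the outputs are
# reported as batched along axis 0.
u, s, vt = jax.vmap(jnp.linalg.svd, in_axes=1)(xs)
print(s.shape)  # (3, 4, 5): the leading axis is the mapped one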
Example #3
def qr_batching_rule(batched_args, batch_dims, full_matrices):
    x, = batched_args
    bd, = batch_dims
    x = batching.bdim_at_front(x, bd)
    return qr_p.bind(x, full_matrices=full_matrices), 0
Example #4
def lu_batching_rule(batched_args, batch_dims):
    x, = batched_args
    bd, = batch_dims
    x = batching.bdim_at_front(x, bd)
    return lu_p.bind(x), 0
Example #5
def eigh_batching_rule(batched_args, batch_dims, lower):
    x, = batched_args
    bd, = batch_dims
    x = batching.bdim_at_front(x, bd)
    return eigh_p.bind(x, lower=lower), 0
Example #6
def cholesky_batching_rule(batched_args, batch_dims):
    x, = batched_args
    bd, = batch_dims
    x = batching.bdim_at_front(x, bd)
    # Unlike the rules above, this calls the `cholesky` wrapper rather than
    # binding `cholesky_p` directly; the result is still batched along axis 0.
    return cholesky(x), 0
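
As a quick sanity check of this shared pattern, again assuming the modern public API: vmapping a factorization should agree with factoring each batch element separately.

import jax
import jax.numpy as jnp

# A batch of symmetric positive-definite matrices.
a = jnp.eye(3) + 0.1 * jnp.ones((3, 3))
batch = jnp.stack([a, 2.0 * a, 3.0 * a])

# The batching rule makes vmap equivalent to a per-element loop.
batched = jax.vmap(jnp.linalg.cholesky)(batch)
looped = jnp.stack([jnp.linalg.cholesky(m) for m in batch])
print(jnp.allclose(batched, looped))  # True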