def _while_loop_batching_rule(batched_args, batch_dims, cond_consts, body_consts,
                              aval_out, cond_jaxpr, body_jaxpr):
  """Batching rule for `while_loop`.

  See https://github.com/google/jax/issues/441 for a discussion.

  To batch a while_loop, we need to do some masking, since the elements of the
  batch may run for different numbers of iterations. We perform that masking
  using lax.select, and keep the loop running so long as any of the batch
  elements need by effectively using an np.any(...) in the cond_fun.

  The basic strategy here is to lift `cond_jaxpr` and `body_jaxpr` back into
  traceable Python functions using `core.eval_jaxpr`. Then we can batch them
  using `batching.batch_transform` (the transform underlying `api.vmap`). This
  code also avoids broadcasting `cond_tracer_consts` and `body_tracer_consts`.
  """
  init_val, cond_tracer_consts, body_tracer_consts = batched_args
  init_val_bd, cond_tracer_consts_bd, body_tracer_consts_bd = batch_dims

  # All batched args must agree on a single batch size; `dimsize` yields a set
  # per arg and the union over them must be a singleton.
  sizes = lax._reduce(set.union, map(batching.dimsize, batch_dims, batched_args))
  size = sizes.pop()
  assert not sizes

  # TODO(mattjj): if cond_tracer_consts_bd is also None, we could keep cond_fun
  # unbatched and avoid the masking logic, but we ignore that optimization

  # Force the loop carry's batch dimension to the front (broadcasting if it had
  # none), so both lifted functions below can assume bdim 0.
  init_val = batching.bdim_at_front(init_val, init_val_bd, size,
                                    force_broadcast=True)
  init_val_bd = 0

  def batched_cond_fun(batched_loop_carry):
    # Lift the cond jaxpr back into a traceable function of the carry and its
    # tracer consts, then batch it with output bdim 0.
    @lu.wrap_init
    def lifted(loop_carry, cond_tracer_consts):
      cond_tracer_consts = tuple(x for x in cond_tracer_consts)
      return core.eval_jaxpr(cond_jaxpr, cond_consts.val + cond_tracer_consts,
                             (), loop_carry)
    f = batching.batch_transform(lifted, size,
                                 (init_val_bd, cond_tracer_consts_bd), 0)
    preds = f.call_wrapped((batched_loop_carry, cond_tracer_consts))
    # Keep looping while ANY batch element's predicate is still True
    # (reduction over the batch dimension with bitwise_or).
    return lax.reduce(preds, onp.array(False), lax.bitwise_or, [0])

  def batched_body_fun(batched_loop_carry):
    # Lift cond and body together: re-evaluate the predicate per batch element
    # and mask the update, so finished elements keep their old carry values.
    @lu.wrap_init
    def lifted(loop_carry, cond_tracer_consts, body_tracer_consts):
      cond_tracer_consts = tuple(x for x in cond_tracer_consts)
      body_tracer_consts = tuple(x for x in body_tracer_consts)
      pred = core.eval_jaxpr(cond_jaxpr, cond_consts.val + cond_tracer_consts,
                             (), loop_carry)
      new_loop_carry = core.eval_jaxpr(
          body_jaxpr, body_consts.val + body_tracer_consts, (), loop_carry)
      # Per-element select: take the new carry only where pred is still True.
      return _jaxtupletree_select(pred, new_loop_carry, loop_carry)

    f = batching.batch_transform(
        lifted, size,
        (init_val_bd, cond_tracer_consts_bd, body_tracer_consts_bd),
        init_val_bd)
    return f.call_wrapped(
        (batched_loop_carry, cond_tracer_consts, body_tracer_consts))

  # init_val_bd is 0 here (set above), so the result's batch dim is 0.
  return while_loop(batched_cond_fun, batched_body_fun, init_val), init_val_bd
def svd_batching_rule(batched_args, batch_dims, full_matrices, compute_uv):
  """Batching rule for svd: move the batch dim to the front, rebind at bdim 0."""
  (operand,), (dim,) = batched_args, batch_dims
  front_batched = batching.bdim_at_front(operand, dim)
  out = svd_p.bind(front_batched, full_matrices=full_matrices,
                   compute_uv=compute_uv)
  return out, 0
def qr_batching_rule(batched_args, batch_dims, full_matrices):
  """Batching rule for qr: move the batch dim to the front, rebind at bdim 0."""
  (operand,), (dim,) = batched_args, batch_dims
  return qr_p.bind(batching.bdim_at_front(operand, dim),
                   full_matrices=full_matrices), 0
def lu_batching_rule(batched_args, batch_dims):
  """Batching rule for lu: move the batch dim to the front, rebind at bdim 0."""
  (operand,), (dim,) = batched_args, batch_dims
  return lu_p.bind(batching.bdim_at_front(operand, dim)), 0
def eigh_batching_rule(batched_args, batch_dims, lower):
  """Batching rule for eigh: move the batch dim to the front, rebind at bdim 0."""
  (operand,), (dim,) = batched_args, batch_dims
  front_batched = batching.bdim_at_front(operand, dim)
  return eigh_p.bind(front_batched, lower=lower), 0
def cholesky_batching_rule(batched_args, batch_dims):
  """Batching rule for cholesky: front-batch the operand and call the wrapper."""
  (operand,), (dim,) = batched_args, batch_dims
  return cholesky(batching.bdim_at_front(operand, dim)), 0