def _while_loop_batching_rule(batched_args, batch_dims, aval_out, cond_jaxpr,
                              body_jaxpr):
  """Batching (vmap) rule for the while_loop primitive.

  See https://github.com/google/jax/issues/441 for a discussion.
  To batch a while_loop, we need to do some masking, since the elements of the
  batch may run for different numbers of iterations. We perform that masking
  using lax.select, and keep the loop running so long as any of the batch
  elements need by effectively using an np.any(...) in the cond_fun.
  The basic strategy here is to lift `cond_jaxpr` and `body_jaxpr` back into
  traceable Python functions using `core.eval_jaxpr`. Then we can batch them
  using `batching.batch_transform` (the transform underlying `api.vmap`).

  Returns a pair of (batched while_loop output, output batch dim), which is
  always 0 here since the carry is forced to have a leading batch axis.
  """
  # TODO(mattjj): Revise this using scan machinery (and fixed-point the loop
  # carry instead of lifting it all the way!)
  init_val, cond_consts, body_consts = batched_args
  init_val_bd, cond_consts_bd, body_consts_bd = batch_dims

  # All batched args must agree on a single batch size; dimsize of an
  # unbatched arg contributes the empty set, so the union has one element.
  sizes = lax._reduce(set.union, map(batching.dimsize, batch_dims, batched_args))
  size = sizes.pop()
  assert not sizes

  # TODO(mattjj): if cond_consts_bd is also None, we could keep cond_fun
  # unbatched and avoid the masking logic, but we ignore that optimization
  # Force the carry to carry a leading batch axis (broadcasting if it was
  # unbatched) so the loop body can operate uniformly with bdim 0.
  init_val = batching.bdim_at_front(init_val, init_val_bd, size,
                                    force_broadcast=True)
  init_val_bd = 0

  def batched_cond_fun(batched_loop_carry):
    # Lift cond_jaxpr back into a traceable Python callable and batch it.
    @lu.wrap_init
    def lifted(loop_carry, cond_consts):
      return core.eval_jaxpr(cond_jaxpr, cond_consts, (), loop_carry)
    f = batching.batch_transform(lifted, size, (init_val_bd, cond_consts_bd), 0)
    preds = f.call_wrapped((batched_loop_carry, cond_consts))
    # Keep iterating while ANY batch element's predicate is still True
    # (a reduction with bitwise_or over the batch axis, i.e. np.any).
    return lax.reduce(preds, onp.array(False), lax.bitwise_or, [0])

  def batched_body_fun(batched_loop_carry):
    # Lift both jaxprs; per-element masking: elements whose own predicate is
    # already False keep their old carry via _jaxtupletree_select, so extra
    # iterations run for them without changing their result.
    @lu.wrap_init
    def lifted(loop_carry, cond_consts, body_consts):
      pred = core.eval_jaxpr(cond_jaxpr, cond_consts, (), loop_carry)
      new_loop_carry = core.eval_jaxpr(body_jaxpr, body_consts, (), loop_carry)
      return _jaxtupletree_select(pred, new_loop_carry, loop_carry)
    f = batching.batch_transform(
        lifted, size, (init_val_bd, cond_consts_bd, body_consts_bd), init_val_bd)
    return f.call_wrapped((batched_loop_carry, cond_consts, body_consts))

  return while_loop(batched_cond_fun, batched_body_fun, init_val), init_val_bd
def _scan_batching_rule(batched_args, batch_dims, forward, length, jaxpr):
  """Batching (vmap) rule for the scan primitive.

  Computes, by fixed-point iteration, which components of the loop carry must
  be batched: a carry component becomes batched if the body produces a batched
  value for it, and that can in turn make more components batched on the next
  pass. Once the batched-ness pattern stabilizes, the jaxpr is batched once
  with that pattern and `scan_p` is re-bound on instantiated operands.

  Returns (packed (carry_out, ys), (carry_out_bdim, ys_bdim)) where the carry
  is batched along axis 0 and the stacked ys along axis 1 (axis 0 of ys being
  the scan/length axis).
  """
  consts, init, xs = batched_args
  consts_bdim, init_bdim, xs_bdim = batch_dims

  # All batched args must agree on a single batch size.
  sizes = lax._reduce(set.union, map(batching.dimsize, batch_dims, batched_args))
  size = sizes.pop()
  assert not sizes

  # Boolean (pytree-structured) masks of which leaves carry a batch dim.
  consts_batched = batching.where_batched(consts_bdim)
  init_batched = batching.where_batched(init_bdim)
  xs_batched = batching.where_batched(xs_bdim)

  # Fixed point: batching the body can make more carry components batched.
  # The mask only ever grows (lattice join), so this terminates; 1000 is just
  # a safety bound against a non-converging join.
  carry_batched = init_batched
  for _ in range(1000):
    which_batched = (consts_batched, carry_batched, xs_batched)
    jaxpr_batched, batched_out = batching.batch_jaxpr(
        jaxpr, size, which_batched, instantiate=(carry_batched, False))
    carry_batched_out, ys_batched = batched_out
    if _binary_lattice_eq(carry_batched_out, carry_batched):
      break
    else:
      carry_batched = _binary_lattice_join(carry_batched_out, carry_batched)
  else:
    raise FixedPointError

  # Materialize batch dims where the fixed point demands them: axis 0 for
  # consts and the carry, axis 1 for xs (axis 0 of xs is the scan axis).
  consts_batched = batching.instantiate_bdim(size, 0, consts_batched, consts_bdim, consts)
  init_batched = batching.instantiate_bdim(size, 0, carry_batched, init_bdim, init)
  xs_batched = batching.instantiate_bdim(size, 1, xs_batched, xs_bdim, xs)

  carry_out, ys = scan_p.bind(consts_batched, init_batched, xs_batched,
                              forward=forward, length=length, jaxpr=jaxpr_batched)

  # Convert the boolean masks back into per-leaf batch-dim indices.
  carry_out_bdim = batching.bools_to_bdims(0, carry_batched)
  ys_bdim = batching.bools_to_bdims(1, ys_batched)
  return core.pack((carry_out, ys)), (carry_out_bdim, ys_bdim)