def _cond_batching_rule(args, dims, true_jaxpr, false_jaxpr, true_nconsts,
                        false_nconsts):
  """Batching rule for ``cond_p``.

  Args:
    args: operands laid out as
      ``[pred, *true_consts, *true_ops, *false_consts, *false_ops]``.
    dims: batch dimension of each entry of ``args``, or
      ``batching.not_mapped`` for an unbatched operand.
    true_jaxpr, false_jaxpr: closed jaxprs of the two branches.
    true_nconsts, false_nconsts: number of leading const operands taken by
      each branch.

  Returns:
    ``(outs, out_dims)``. If the predicate is batched, both branches are
    evaluated on the batched operands and merged elementwise with
    ``_cond_pred_bcast_select``; every output is then batched along axis 0.
    Otherwise ``cond_p`` is re-bound with batched branch jaxprs, and an output
    is batched (along axis 0) iff either branch produces it batched.
  """
  # Read the batch size *before* moving axes: afterwards every batched operand
  # has its batch dimension at position 0, so ``x.shape[d]`` would index the
  # wrong axis whenever ``d != 0`` (giving a wrong size, or two distinct sizes
  # and a failed unpack).
  size, = {x.shape[d] for x, d in zip(args, dims)
           if d is not batching.not_mapped}

  # TODO: maybe avoid moving arg axes to front if we're promoting to select?
  args = [batching.moveaxis(x, d, 0) if d is not batching.not_mapped and d != 0
          else x for x, d in zip(args, dims)]
  true_nops = len(true_jaxpr.in_avals) - true_nconsts
  (pred,), true_consts, true_ops, false_consts, false_ops = split_list(
      args, [1, true_nconsts, true_nops, false_nconsts])

  orig_bat = [d is not batching.not_mapped for d in dims]
  (pred_bat,), tconst_bat, t_bat, fconst_bat, f_bat = split_list(
      orig_bat, [1, true_nconsts, true_nops, false_nconsts])

  # First pass: find which outputs each branch produces batched on its own.
  _, true_out_bat = batching.batch_jaxpr(true_jaxpr, size, tconst_bat + t_bat,
                                         False)
  _, false_out_bat = batching.batch_jaxpr(false_jaxpr, size,
                                          fconst_bat + f_bat, False)
  # Both branches must agree on output batching, so an output is batched if
  # either branch batches it; re-batch both branches forcing that agreement.
  out_bat = [a or b for a, b in zip(true_out_bat, false_out_bat)]
  true_jaxpr_batched, _ = batching.batch_jaxpr(
      true_jaxpr, size, tconst_bat + t_bat, out_bat)
  false_jaxpr_batched, _ = batching.batch_jaxpr(
      false_jaxpr, size, fconst_bat + f_bat, out_bat)

  if pred_bat:
    # A batched predicate cannot drive a single cond: evaluate both branches
    # and select per batch element.
    true_out = core.jaxpr_as_fun(true_jaxpr_batched)(*(true_consts + true_ops))
    false_out = core.jaxpr_as_fun(false_jaxpr_batched)(
        *(false_consts + false_ops))
    true_out = [batching.broadcast(x, size, 0) if not b else x
                for x, b in zip(true_out, out_bat)]
    false_out = [batching.broadcast(x, size, 0) if not b else x
                 for x, b in zip(false_out, out_bat)]
    return ([_cond_pred_bcast_select(pred, t, f)
             for t, f in zip(true_out, false_out)],
            [0] * len(true_out))
  else:
    out_dims = [0 if b else batching.not_mapped for b in out_bat]
    return cond_p.bind(
        *itertools.chain([pred], true_consts, true_ops, false_consts,
                         false_ops),
        true_jaxpr=true_jaxpr_batched, false_jaxpr=false_jaxpr_batched,
        true_nconsts=len(true_consts),
        false_nconsts=len(false_consts)), out_dims
def _while_loop_batching_rule(args, dims, cond_nconsts, cond_jaxpr,
                              body_nconsts, body_jaxpr):
  """Batching rule for ``while_p``.

  Operands are ``[*cond_consts, *body_consts, *init_carry]`` with matching
  batch dimensions in ``dims``. A fixed-point iteration decides which carry
  components must be batched. A component is batched if it starts batched,
  if the body makes it batched, or if the loop predicate is batched (in which
  case every component is forced to be batched).

  Returns:
    ``(outs, out_dims)`` where batched carry outputs have their batch axis at
    position 0 and the rest are marked ``batching.not_mapped``.

  Raises:
    FixedPointError: if the carry batching does not stabilize.
  """
  not_mapped = batching.not_mapped
  size, = {x.shape[d] for x, d in zip(args, dims) if d is not not_mapped}
  num_consts = cond_nconsts + body_nconsts

  cconst_bat, bconst_bat, init_bat = split_list(
      [d is not not_mapped for d in dims], [cond_nconsts, body_nconsts])

  # Grow the set of batched carry components until it stops changing.
  carry_bat = init_bat
  for _ in range(1000):
    body_jaxpr_batched, body_out_bat = batching.batch_jaxpr(
        body_jaxpr, size, bconst_bat + carry_bat, instantiate=carry_bat)
    cond_jaxpr_batched, (pred_bat,) = batching.batch_jaxpr(
        cond_jaxpr, size, cconst_bat + carry_bat, instantiate=False)
    # A batched predicate means every carry component may diverge per batch
    # element, so all of them become batched.
    new_carry_bat = [pred_bat | b for b in body_out_bat]
    if new_carry_bat == carry_bat:
      break
    carry_bat = new_carry_bat
  else:
    raise FixedPointError

  consts, init = split_list(args, [num_consts])
  const_dims, init_dims = split_list(dims, [num_consts])

  # Consts: bring any batch axis to the front.
  new_consts = [
      x if d is not_mapped or d == 0 else batching.moveaxis(x, d, 0)
      for x, d in zip(consts, const_dims)
  ]

  # Initial carry: leave unbatched components alone, move existing batch axes
  # to the front, and broadcast components that only became batched through
  # the fixed point.
  new_init = []
  for x, d, was_bat, now_bat in zip(init, init_dims, init_bat, carry_bat):
    if not now_bat:
      new_init.append(x)
    elif was_bat:
      new_init.append(batching.moveaxis(x, d, 0))
    else:
      new_init.append(batching.broadcast(x, size, 0))

  outs = while_p.bind(*new_consts, *new_init,
                      cond_nconsts=cond_nconsts, cond_jaxpr=cond_jaxpr_batched,
                      body_nconsts=body_nconsts, body_jaxpr=body_jaxpr_batched)
  return outs, [0 if b else not_mapped for b in carry_bat]
def _scan_batching_rule(args, dims, forward, length, jaxpr, num_consts,
                        num_carry, linear):
  """Batching rule for ``scan_p``.

  Operands are ``[*consts, *init_carry, *xs]`` with matching batch dimensions
  in ``dims``. A fixed-point iteration decides which carry components must be
  batched. Consts and carries get their batch axis at position 0. Scanned
  inputs ``xs`` get it at position 1, leaving axis 0 for the scanned
  dimension.

  Returns:
    ``(outs, out_dims)``: carry outputs batched along axis 0, stacked outputs
    ``ys`` batched along axis 1, unbatched entries marked
    ``batching.not_mapped``.

  Raises:
    FixedPointError: if the carry batching does not stabilize.
  """
  not_mapped = batching.not_mapped
  num_ys = len(jaxpr.out_avals) - num_carry
  size, = {x.shape[d] for x, d in zip(args, dims) if d is not not_mapped}

  const_bat, init_bat, xs_bat = split_list(
      [d is not not_mapped for d in dims], [num_consts, num_carry])

  # Grow the set of batched carry components until the body agrees with it.
  carry_bat = init_bat
  for _ in range(1000):
    jaxpr_batched, out_bat = batching.batch_jaxpr(
        jaxpr, size, const_bat + carry_bat + xs_bat,
        instantiate=carry_bat + [False] * num_ys)
    new_carry_bat = out_bat[:num_carry]
    ys_bat = out_bat[num_carry:]
    if new_carry_bat == carry_bat:
      break
    carry_bat = new_carry_bat
  else:
    raise FixedPointError

  consts, init, xs = split_list(args, [num_consts, num_carry])
  const_dims, init_dims, xs_dims = split_list(dims, [num_consts, num_carry])

  new_consts = [
      x if d is not_mapped or d == 0 else batching.moveaxis(x, d, 0)
      for x, d in zip(consts, const_dims)
  ]

  # Carry: untouched if unbatched, existing batch axis moved to the front,
  # broadcast if it only became batched through the fixed point.
  new_init = []
  for x, d, was_bat, now_bat in zip(init, init_dims, init_bat, carry_bat):
    if not now_bat:
      new_init.append(x)
    elif was_bat:
      new_init.append(batching.moveaxis(x, d, 0))
    else:
      new_init.append(batching.broadcast(x, size, 0))

  new_xs = [
      x if d is not_mapped or d == 1 else batching.moveaxis(x, d, 1)
      for x, d in zip(xs, xs_dims)
  ]

  outs = scan_p.bind(*new_consts, *new_init, *new_xs,
                     forward=forward, length=length, jaxpr=jaxpr_batched,
                     num_consts=num_consts, num_carry=num_carry, linear=linear)
  out_dims = ([0 if b else not_mapped for b in carry_bat] +
              [1 if b else not_mapped for b in ys_bat])
  return outs, out_dims
def _linear_solve_batching_rule(axis_size, axis_name, main_type, args, dims,
                                const_lengths, jaxprs):
  """Batching rule for ``linear_solve_p``.

  ``args``/``dims`` are split by ``_split_linear_solve_args`` into the
  constants of the four jaxprs (``matvec``, ``vecmat``, ``solve``,
  ``solve_t``) and the right-hand side ``b``. A fixed-point iteration decides
  which components of the solution ``x`` (including ``solve``'s auxiliary
  outputs) and of ``b`` must be batched so that all four batched jaxprs agree
  with each other.

  Args:
    axis_size: size of the mapped axis.
    axis_name, main_type: forwarded unchanged to ``batching.batch_jaxpr``.
    args, dims: operands and their batch dimensions
      (``batching.not_mapped`` when unbatched).
    const_lengths: per-jaxpr constant counts, as used by
      ``_split_linear_solve_args``.
    jaxprs: ``(matvec, vecmat, solve, solve_t)``; ``vecmat`` and ``solve_t``
      may be ``None``.

  Returns:
    ``(outs, out_dims)``: the ``linear_solve_p`` outputs, batched entries
    along axis 0.

  Raises:
    AssertionError: if the fixed point is not reached within the iteration
      bound.
  """
  orig_bat = [d is not batching.not_mapped for d in dims]

  params, b = _split_linear_solve_args(args, const_lengths)
  params_dims, b_dims = _split_linear_solve_args(dims, const_lengths)
  params_bat, orig_b_bat = _split_linear_solve_args(orig_bat, const_lengths)

  (matvec, vecmat, solve, solve_t) = jaxprs
  (matvec_bat, vecmat_bat, solve_bat, solve_t_bat) = params_bat

  # ``solve`` returns the operator-shaped part of x followed by ``num_aux``
  # auxiliary outputs; ``matvec`` returns only operator-shaped values.
  num_operator_out = len(matvec.out_avals)
  num_aux = len(solve.out_avals) - num_operator_out

  # Fixpoint computation of which parts of x and b are batched; we need to
  # ensure this is consistent between all four jaxprs
  b_bat = orig_b_bat
  x_bat = [False] * len(solve.out_avals)
  for _ in range(1 + len(orig_b_bat) + len(solve.out_avals)):
    # Apply vecmat and solve -> new batched parts of x
    solve_jaxpr_batched, solve_x_bat = batching.batch_jaxpr(
        solve, axis_size, solve_bat + b_bat, instantiate=x_bat,
        axis_name=axis_name, main_type=main_type)
    if vecmat is None:
      vecmat_jaxpr_batched = None
      x_bat_out = solve_x_bat
    else:
      # vecmat produces only the non-aux part of x, so it gets one instantiate
      # flag per operator output (the aux flags are appended below).
      vecmat_jaxpr_batched, vecmat_x_bat = batching.batch_jaxpr(
          vecmat, axis_size, vecmat_bat + b_bat,
          instantiate=x_bat[:num_operator_out],
          axis_name=axis_name, main_type=main_type)
      # batch all aux data by default
      x_bat_out = _map(operator.or_, vecmat_x_bat + [True] * num_aux,
                       solve_x_bat)

    # The linear operators consume only the non-aux part of x; passing the
    # aux flags as well would hand them more flags than they have x inputs.
    x_bat_noaux = x_bat_out[:num_operator_out]

    # Apply matvec and solve_t -> new batched parts of b
    matvec_jaxpr_batched, matvec_b_bat = batching.batch_jaxpr(
        matvec, axis_size, matvec_bat + x_bat_noaux, instantiate=b_bat,
        axis_name=axis_name, main_type=main_type)
    if solve_t is None:
      solve_t_jaxpr_batched = None
      b_bat_out = _map(operator.or_, matvec_b_bat, orig_b_bat)
    else:
      # solve_t returns b-shaped values followed by ``num_aux`` aux values
      # (checked by the assert below), so ``instantiate`` needs one flag per
      # output; the aux outputs are discarded here and are not forced batched.
      solve_t_jaxpr_batched, solve_t_b_aux_bat = batching.batch_jaxpr(
          solve_t, axis_size, solve_t_bat + x_bat_noaux,
          instantiate=b_bat + [False] * num_aux,
          axis_name=axis_name, main_type=main_type)
      assert len(solve_t_b_aux_bat) == len(orig_b_bat) + num_aux
      solve_t_b_bat, _ = split_list(solve_t_b_aux_bat, [len(orig_b_bat)])
      b_bat_out = _map(lambda m, s, o: m or s or o, matvec_b_bat,
                       solve_t_b_bat, orig_b_bat)
    if x_bat_out == x_bat and b_bat_out == b_bat:
      break
    else:
      x_bat = x_bat_out
      b_bat = b_bat_out
  else:
    # Explicit raise (not ``assert False``) so this guard survives ``python -O``.
    raise AssertionError("Fixedpoint not reached")

  batched_jaxprs = _LinearSolveTuple(matvec_jaxpr_batched, vecmat_jaxpr_batched,
                                     solve_jaxpr_batched, solve_t_jaxpr_batched)

  # Move batched axes to the front
  new_params = [
      batching.moveaxis(x, d, 0)
      if d is not batching.not_mapped and d != 0 else x
      for x, d in zip(_flatten(params), _flatten(params_dims))
  ]
  # Broadcast out b if necessary
  new_b = [
      batching.broadcast(x, axis_size, 0) if now_bat and not was_bat else
      batching.moveaxis(x, d, 0) if now_bat and d != 0 else x
      for x, d, was_bat, now_bat in zip(b, b_dims, orig_b_bat, b_bat)
  ]

  outs = linear_solve_p.bind(*(new_params + new_b),
                             const_lengths=const_lengths,
                             jaxprs=batched_jaxprs)
  out_dims = [0 if batched else batching.not_mapped for batched in solve_x_bat]
  return outs, out_dims