def _cond_batching_rule(args, dims, true_jaxpr, false_jaxpr, true_nconsts,
                        false_nconsts):
  """Batching (vmap) rule for the two-branch `cond` primitive.

  Args:
    args: flat list [pred, *true_consts, *true_ops, *false_consts, *false_ops].
    dims: per-argument batch dimension, or `batching.not_mapped`.
    true_jaxpr, false_jaxpr: the branch jaxprs.
    true_nconsts, false_nconsts: number of leading constant inputs of each
      branch jaxpr.

  Returns:
    A pair (flat outputs, per-output batch dims).
  """
  # BUGFIX: read the mapped axis size off the *original* args/dims, before any
  # axes are moved.  After moveaxis every mapped arg has its batch axis at
  # position 0, so indexing x.shape with the stale d would pick up the wrong
  # axis for any arg originally batched along d != 0.
  size, = {x.shape[d] for x, d in zip(args, dims)
           if d is not batching.not_mapped}
  # TODO: maybe avoid moving arg axes to front if we're promoting to select?
  args = [batching.moveaxis(x, d, 0) if d is not batching.not_mapped and d != 0
          else x for x, d in zip(args, dims)]
  true_nops = len(true_jaxpr.in_avals) - true_nconsts
  (pred,), true_consts, true_ops, false_consts, false_ops = split_list(
      args, [1, true_nconsts, true_nops, false_nconsts])
  orig_bat = [d is not batching.not_mapped for d in dims]
  (pred_bat,), tconst_bat, t_bat, fconst_bat, f_bat = split_list(
      orig_bat, [1, true_nconsts, true_nops, false_nconsts])

  # First pass: discover which outputs come out batched from either branch ...
  _, true_out_bat = batching.batch_jaxpr(true_jaxpr, size,
                                         tconst_bat + t_bat, False)
  _, false_out_bat = batching.batch_jaxpr(false_jaxpr, size,
                                          fconst_bat + f_bat, False)
  out_bat = [a or b for a, b in zip(true_out_bat, false_out_bat)]
  # ... second pass: rebatch both branches so they agree on batched outputs.
  true_jaxpr_batched, _ = batching.batch_jaxpr(true_jaxpr, size,
                                               tconst_bat + t_bat, out_bat)
  false_jaxpr_batched, _ = batching.batch_jaxpr(false_jaxpr, size,
                                                fconst_bat + f_bat, out_bat)

  if pred_bat:
    # Batched predicate: run BOTH branches and select per-element results.
    true_out = core.jaxpr_as_fun(true_jaxpr_batched)(*(true_consts + true_ops))
    false_out = core.jaxpr_as_fun(false_jaxpr_batched)(
        *(false_consts + false_ops))
    # Broadcast any unbatched outputs up to the batch size before selecting.
    true_out = [batching.broadcast(x, size, 0) if not b else x
                for x, b in zip(true_out, out_bat)]
    false_out = [batching.broadcast(x, size, 0) if not b else x
                 for x, b in zip(false_out, out_bat)]
    return [_cond_pred_bcast_select(pred, t, f)
            for t, f in zip(true_out, false_out)], [0] * len(true_out)
  else:
    # Unbatched predicate: keep a single cond, with batched branch jaxprs.
    out_dims = [0 if b else batching.not_mapped for b in out_bat]
    return cond_p.bind(
        *itertools.chain([pred], true_consts, true_ops, false_consts,
                         false_ops),
        true_jaxpr=true_jaxpr_batched, false_jaxpr=false_jaxpr_batched,
        true_nconsts=len(true_consts),
        false_nconsts=len(false_consts)), out_dims
def _while_loop_batching_rule(args, dims, cond_nconsts, cond_jaxpr,
                              body_nconsts, body_jaxpr):
  """Batching (vmap) rule for the `while` primitive.

  `args`/`dims` are flat: [*cond_consts, *body_consts, *init_carry], with one
  batch dimension (or `batching.not_mapped`) per argument.

  Returns:
    A pair (loop outputs, per-output batch dims).
  """
  # The single mapped axis size shared by all batched arguments.
  size, = {x.shape[d] for x, d in zip(args, dims)
           if d is not batching.not_mapped}
  orig_batched = [d is not batching.not_mapped for d in dims]
  cconst_bat, bconst_bat, init_bat = split_list(orig_batched,
                                                [cond_nconsts, body_nconsts])

  # Fixpoint over which carry elements are batched: running the body can make
  # more of the carry batched, and a batched predicate forces every carry
  # element to be batched (the or_ below), so iterate until stable.
  carry_bat = init_bat
  for _ in range(1000):  # generous iteration bound; normally converges fast
    batched = bconst_bat + carry_bat
    body_jaxpr_batched, carry_bat_out = batching.batch_jaxpr(
        body_jaxpr, size, batched, instantiate=carry_bat)
    cond_jaxpr_batched, (pred_bat,) = batching.batch_jaxpr(
        cond_jaxpr, size, cconst_bat + carry_bat, instantiate=False)
    # If the predicate is batched, all carry elements must be batched too.
    carry_bat_out = _map(partial(operator.or_, pred_bat), carry_bat_out)
    if carry_bat_out == carry_bat:
      break
    else:
      carry_bat = carry_bat_out
  else:
    raise FixedPointError

  consts, init = split_list(args, [cond_nconsts + body_nconsts])
  const_dims, init_dims = split_list(dims, [cond_nconsts + body_nconsts])
  # Move every batched const's axis to the front ...
  new_consts = [batching.moveaxis(x, d, 0)
                if d is not batching.not_mapped and d != 0 else x
                for x, d in zip(consts, const_dims)]
  # ... and broadcast carry elements that only became batched during the
  # fixpoint (now_bat and not was_bat); otherwise just normalize to axis 0.
  new_init = [batching.broadcast(x, size, 0) if now_bat and not was_bat
              else batching.moveaxis(x, d, 0) if now_bat
              else x
              for x, d, was_bat, now_bat
              in zip(init, init_dims, init_bat, carry_bat)]

  outs = while_p.bind(*(new_consts + new_init),
                      cond_nconsts=cond_nconsts, cond_jaxpr=cond_jaxpr_batched,
                      body_nconsts=body_nconsts, body_jaxpr=body_jaxpr_batched)
  out_bdims = [0 if b else batching.not_mapped for b in carry_bat]
  return outs, out_bdims
def _cond_batching_rule(axis_size, axis_name, main_type, args, dims, branches,
                        linear):
  """Batching (vmap) rule for the N-branch `cond` primitive.

  `args` is [index, *ops]; `branches` is a tuple of branch jaxprs that share
  input/output signatures (we read avals off `branches[0]`).

  Returns:
    A pair (flat outputs, per-output batch dims).
  """
  index, *ops = args
  index_dim, *op_dims = dims
  if index_dim is not batching.not_mapped:
    # Convert to a lax.select. While we could get away with not broadcasting
    # some operands yet, because all outputs must be broadcast together anyway
    # for the select we broadcast the input operands for simplicity and leave
    # optimizations to XLA.
    # TODO(mattjj,frostig): assumes branches are side-effect-free, revise!
    index, *ops = (
        batching.bdim_at_front(x, d, axis_size) for x, d in zip(args, dims))

    # After bdim_at_front, every operand (and so every output) is batched.
    in_batched  = [True] * len(branches[0].in_avals)
    out_batched = [True] * len(branches[0].out_avals)

    branches_batched = [
        batching.batch_jaxpr(jaxpr, axis_size, in_batched, out_batched,
                             axis_name, main_type)[0]
        for jaxpr in branches]

    branch_outs = []
    for i, jaxpr in enumerate(branches_batched):
      # Perform a select on the inputs for safety of reverse-mode autodiff; see
      # https://github.com/google/jax/issues/1052
      predicate = lax.eq(index, lax._const(index, i))
      ops_ = [_bcast_select(predicate, x, lax.stop_gradient(x)) for x in ops]
      branch_outs.append(core.jaxpr_as_fun(jaxpr)(*ops_))
    # Element-wise pick branch i's output wherever index == i.
    out = [_bcast_select_n(index, *outs) for outs in zip(*branch_outs)]
    return out, [0 if b else None for b in out_batched]
  else:
    # Unbatched index: keep a single cond, with batched branch jaxprs.
    ops_bat = [d is not batching.not_mapped for d in op_dims]
    ops = [batching.moveaxis(x, d, 0) if b else x
           for b, x, d in zip(ops_bat, ops, op_dims)]

    # First pass discovers which outputs are batched in ANY branch; second
    # pass rebatches so that all branches agree on the batched outputs.
    branches_out_bat = [
        batching.batch_jaxpr(jaxpr, axis_size, ops_bat, False, axis_name,
                             main_type)[1]
        for jaxpr in branches]
    out_bat = [any(bat) for bat in zip(*branches_out_bat)]
    branches_batched = tuple(
        batching.batch_jaxpr(jaxpr, axis_size, ops_bat, out_bat, axis_name,
                             main_type)[0]
        for jaxpr in branches)

    out_dims = [0 if b else batching.not_mapped for b in out_bat]
    out = cond_p.bind(index, *ops, branches=branches_batched, linear=linear)
    return out, out_dims
def batched_jvp_jaxpr_thunk():
  """Lazily build the batched JVP jaxpr for a custom-JVP call.

  NOTE(review): this is a closure fragment — it reads `jvp_jaxpr_thunk`,
  `in_batched`, `num_consts`, `size`, `num_out`, `axis_name`, `main_type`,
  `op`, `not_mapped` and the mutable cell `out_dims2` from an enclosing
  scope; presumably it is defined nested inside the custom_jvp vmap rule
  (confirm against the caller).

  Returns:
    (jaxpr, consts) of the batched JVP jaxpr.
  """
  jvp_jaxpr = core.ClosedJaxpr(*jvp_jaxpr_thunk())  # consts can be tracers
  _, args_batched = split_list(in_batched, [num_consts])
  # A JVP jaxpr takes (primals, tangents), hence the doubled batched flags.
  _, all_batched = batching.batch_jaxpr(jvp_jaxpr, size, args_batched * 2,
                                        False, axis_name, main_type)
  primals_batched, tangents_batched = split_list(all_batched, [num_out])
  # Materialize as a list: out_batched is used twice below (the dims list and
  # `out_batched * 2`), which would break if `map` were the lazy builtin
  # rather than a list-returning safe_map rebound elsewhere in this file.
  out_batched = list(map(op.or_, primals_batched, tangents_batched))
  out_dims2.append([0 if b else not_mapped for b in out_batched])
  batched_jvp_jaxpr, _ = batching.batch_jaxpr(
      jvp_jaxpr, size, args_batched * 2, out_batched * 2, axis_name,
      main_type)
  return batched_jvp_jaxpr.jaxpr, batched_jvp_jaxpr.consts
def batched_fwd_jaxpr_thunk():
  """Lazily build the batched forward jaxpr for a custom-VJP call,
  recording its output batch dims into the `out_dims2` cell as a side
  effect.  Relies on closure variables from the enclosing vmap rule.
  """
  # consts can be tracers, so rebuild the closed jaxpr at call time
  closed_fwd = core.ClosedJaxpr(*fwd_jaxpr_thunk())
  batched_closed, fwd_out_batched = batching.batch_jaxpr(
      closed_fwd, axis_size, args_batched, False, axis_name, main_type)
  dims = [0 if is_bat else not_mapped for is_bat in fwd_out_batched]
  out_dims2.append(dims)
  return batched_closed.jaxpr, batched_closed.consts
def _custom_vjp_call_jaxpr_vmap(
    args, in_dims, axis_name, main_type, *, fun_jaxpr: core.ClosedJaxpr,
    fwd_jaxpr_thunk: Callable[[], Tuple[core.Jaxpr, Sequence[Any]]],
    bwd: lu.WrappedFun, out_trees: Callable, num_consts: int):
  """Batching (vmap) rule for the custom_vjp_call_jaxpr primitive.

  Batches the primal jaxpr eagerly and the fwd/bwd rules lazily (the fwd
  jaxpr via a memoized thunk, the bwd via batch_custom_vjp_bwd).

  Returns:
    A pair (batched outputs, per-output batch dims).
  """
  # The single mapped axis size shared by all batched arguments.
  axis_size, = {x.shape[d] for x, d in zip(args, in_dims)
                if d is not not_mapped}
  # Normalize every batched arg so its batch axis is at the front.
  args = [batching.moveaxis(x, d, 0) if d is not not_mapped and d != 0 else x
          for x, d in zip(args, in_dims)]
  in_batched = [d is not not_mapped for d in in_dims]
  _, args_batched = split_list(in_batched, [num_consts])

  batched_fun_jaxpr, out_batched = batching.batch_jaxpr(
      fun_jaxpr, axis_size, in_batched, False, axis_name, main_type)
  # out_dims1 is the fallback if the fwd thunk is never forced; out_dims2 is a
  # mutable cell the thunk appends the authoritative dims into.
  out_dims1 = [0 if b else not_mapped for b in out_batched]
  out_dims2 = []

  @pe._memoize
  def batched_fwd_jaxpr_thunk():
    fwd_jaxpr = core.ClosedJaxpr(*fwd_jaxpr_thunk())  # consts can be tracers
    batched_fwd_jaxpr, out_batched = batching.batch_jaxpr(
        fwd_jaxpr, axis_size, args_batched, False, axis_name, main_type)
    out_dims2.append([0 if b else not_mapped for b in out_batched])
    return batched_fwd_jaxpr.jaxpr, batched_fwd_jaxpr.consts

  fwd_args_batched = [0 if b else not_mapped for b in args_batched]
  # Deferred: out_dims2[0] only exists after the fwd thunk has run.
  fwd_out_dims = lambda: out_dims2[0]
  batched_bwd = batching.batch_custom_vjp_bwd(bwd, axis_name, axis_size,
                                              fwd_out_dims, fwd_args_batched,
                                              main_type)

  batched_outs = custom_vjp_call_jaxpr_p.bind(
      *args, fun_jaxpr=batched_fun_jaxpr,
      fwd_jaxpr_thunk=batched_fwd_jaxpr_thunk, bwd=batched_bwd,
      out_trees=out_trees, num_consts=num_consts)
  out_dims = out_dims2[0] if out_dims2 else out_dims1
  return batched_outs, out_dims
def _custom_jvp_call_jaxpr_vmap(
    args, in_dims, axis_name, main_type, *, fun_jaxpr: core.ClosedJaxpr,
    jvp_jaxpr_thunk: Callable[[], Tuple[core.Jaxpr, Sequence[Any]]],
    num_consts: int):
  """Batching (vmap) rule for the custom_jvp_call_jaxpr primitive.

  Batches the primal jaxpr eagerly and the JVP jaxpr lazily via a memoized
  thunk that also records the authoritative output batch dims.

  Returns:
    A pair (batched outputs, per-output batch dims).
  """
  # The single mapped axis size shared by all batched arguments.
  size, = {x.shape[d] for x, d in zip(args, in_dims) if d is not not_mapped}
  # Normalize every batched arg so its batch axis is at the front.
  args = [batching.moveaxis(x, d, 0) if d is not not_mapped and d != 0 else x
          for x, d in zip(args, in_dims)]
  num_out = len(fun_jaxpr.out_avals)

  in_batched = [d is not not_mapped for d in in_dims]
  batched_fun_jaxpr, out_batched = batching.batch_jaxpr(
      fun_jaxpr, size, in_batched, False, axis_name, main_type)
  # out_dims1 is the fallback if the JVP thunk is never forced; out_dims2 is
  # the mutable cell the thunk appends the authoritative dims into.
  out_dims1 = [0 if b else not_mapped for b in out_batched]
  out_dims2 = []  # mutable cell updated by batched_jvp_jaxpr_thunk

  @pe._memoize
  def batched_jvp_jaxpr_thunk():
    jvp_jaxpr = core.ClosedJaxpr(*jvp_jaxpr_thunk())  # consts can be tracers
    _, args_batched = split_list(in_batched, [num_consts])
    # A JVP jaxpr takes (primals, tangents), hence the doubled batched flags.
    _, all_batched = batching.batch_jaxpr(jvp_jaxpr, size, args_batched * 2,
                                          False, axis_name, main_type)
    primals_batched, tangents_batched = split_list(all_batched, [num_out])
    # Materialize as a list: out_batched is used twice below (the dims list
    # and `out_batched * 2`), which would break if `map` were the lazy
    # builtin rather than a list-returning safe_map rebound in this file.
    out_batched = list(map(op.or_, primals_batched, tangents_batched))
    out_dims2.append([0 if b else not_mapped for b in out_batched])
    batched_jvp_jaxpr, _ = batching.batch_jaxpr(
        jvp_jaxpr, size, args_batched * 2, out_batched * 2, axis_name,
        main_type)
    return batched_jvp_jaxpr.jaxpr, batched_jvp_jaxpr.consts

  batched_outs = custom_jvp_call_jaxpr_p.bind(
      *args, fun_jaxpr=batched_fun_jaxpr,
      jvp_jaxpr_thunk=batched_jvp_jaxpr_thunk, num_consts=num_consts)
  out_dims = out_dims2[0] if out_dims2 else out_dims1
  return batched_outs, out_dims
def _scan_batching_rule(args, dims, forward, length, jaxpr, num_consts,
                        num_carry, linear):
  """Batching (vmap) rule for the `scan` primitive.

  `args`/`dims` are flat: [*consts, *init_carry, *xs], with one batch
  dimension (or `batching.not_mapped`) per argument.  Note the xs (and ys)
  use axis 1 for the batch dimension, since axis 0 is the scan axis.

  Returns:
    A pair (scan outputs, per-output batch dims).
  """
  num_ys = len(jaxpr.out_avals) - num_carry
  # The single mapped axis size shared by all batched arguments.
  size, = {x.shape[d] for x, d in zip(args, dims)
           if d is not batching.not_mapped}
  orig_batched = [d is not batching.not_mapped for d in dims]
  const_batched, init_batched, xs_batched = split_list(
      orig_batched, [num_consts, num_carry])

  # Fixpoint over which carry elements are batched: running the body can make
  # more of the carry batched, so iterate until stable.
  carry_batched = init_batched
  for _ in range(1000):  # generous iteration bound; normally converges fast
    batched = const_batched + carry_batched + xs_batched
    jaxpr_batched, batched_out = batching.batch_jaxpr(
        jaxpr, size, batched,
        instantiate=carry_batched + [False] * num_ys)
    carry_batched_out, ys_batched = (batched_out[:num_carry],
                                     batched_out[num_carry:])
    if carry_batched_out == carry_batched:
      break
    else:
      carry_batched = carry_batched_out
  else:
    raise FixedPointError

  consts, init, xs = split_list(args, [num_consts, num_carry])
  consts_bdims, init_bdims, xs_bdims = split_list(dims,
                                                  [num_consts, num_carry])
  # Move every batched const's axis to the front ...
  new_consts = [
      batching.moveaxis(x, d, 0)
      if d is not batching.not_mapped and d != 0 else x
      for x, d in zip(consts, consts_bdims)]
  # ... broadcast carry elements that only became batched during the fixpoint
  # (now_batched and not was_batched), otherwise normalize to axis 0 ...
  new_init = [
      batching.broadcast(x, size, 0) if now_batched and not was_batched
      else batching.moveaxis(x, d, 0) if now_batched
      else x
      for x, d, was_batched, now_batched
      in zip(init, init_bdims, init_batched, carry_batched)]
  # ... and put each xs batch axis at 1 (axis 0 is the scanned-over axis).
  new_xs = [
      batching.moveaxis(x, d, 1)
      if d is not batching.not_mapped and d != 1 else x
      for x, d in zip(xs, xs_bdims)]
  new_args = new_consts + new_init + new_xs

  outs = scan_p.bind(*new_args, forward=forward, length=length,
                     jaxpr=jaxpr_batched, num_consts=num_consts,
                     num_carry=num_carry, linear=linear)
  carry_bdims = [0 if b else batching.not_mapped for b in carry_batched]
  ys_bdims = [1 if b else batching.not_mapped for b in ys_batched]
  return outs, carry_bdims + ys_bdims
def remat_vmap(axis_size, axis_name, main_type, args, dims, *, jaxpr,
               **params):
  """Batching (vmap) rule for the remat primitive: batch the wrapped jaxpr
  and rebind, forwarding any consts the batching step produced.
  """
  assert not jaxpr.constvars
  closed = core.ClosedJaxpr(jaxpr, ())
  which_batched = [d is not batching.not_mapped for d in dims]
  batched_closed, out_batched = batching.batch_jaxpr(
      closed, axis_size, which_batched, instantiate=False,
      axis_name=axis_name, main_type=main_type)
  out_dims = [0 if is_bat else None for is_bat in out_batched]
  outs = remat_p.bind(*batched_closed.consts, *args,
                      jaxpr=batched_closed.jaxpr, **params)
  return outs, out_dims
def _scan_batching_rule(batched_args, batch_dims, forward, length, jaxpr):
  """Batching (vmap) rule for a legacy tuple-structured `scan` primitive.

  NOTE(review): this uses an older API surface (`core.pack`, tuple-shaped
  args, `batching.where_batched` / `instantiate_bdim` / `bools_to_bdims`,
  `_binary_lattice_*`) — presumably from an earlier version of this file;
  confirm these helpers still exist before touching it.
  """
  consts, init, xs = batched_args
  consts_bdim, init_bdim, xs_bdim = batch_dims
  # All batched leaves must share one mapped axis size.
  sizes = lax._reduce(set.union,
                      map(batching.dimsize, batch_dims, batched_args))
  size = sizes.pop()
  assert not sizes

  consts_batched = batching.where_batched(consts_bdim)
  init_batched = batching.where_batched(init_bdim)
  xs_batched = batching.where_batched(xs_bdim)
  # Fixpoint over which carry elements are batched, joining on the binary
  # (batched-or-not) lattice until it stabilizes.
  carry_batched = init_batched
  for _ in range(1000):  # generous iteration bound; normally converges fast
    which_batched = (consts_batched, carry_batched, xs_batched)
    jaxpr_batched, batched_out = batching.batch_jaxpr(
        jaxpr, size, which_batched, instantiate=(carry_batched, False))
    carry_batched_out, ys_batched = batched_out
    if _binary_lattice_eq(carry_batched_out, carry_batched):
      break
    else:
      carry_batched = _binary_lattice_join(carry_batched_out, carry_batched)
  else:
    raise FixedPointError

  # Materialize batch dims where required: consts/init at axis 0, xs at axis 1
  # (axis 0 of xs is the scanned-over axis).  Note init is instantiated to the
  # fixpoint carry_batched, not its original batchedness.
  consts_batched = batching.instantiate_bdim(size, 0, consts_batched,
                                             consts_bdim, consts)
  init_batched = batching.instantiate_bdim(size, 0, carry_batched,
                                           init_bdim, init)
  xs_batched = batching.instantiate_bdim(size, 1, xs_batched, xs_bdim, xs)

  carry_out, ys = scan_p.bind(consts_batched, init_batched, xs_batched,
                              forward=forward, length=length,
                              jaxpr=jaxpr_batched)

  carry_out_bdim = batching.bools_to_bdims(0, carry_batched)
  ys_bdim = batching.bools_to_bdims(1, ys_batched)
  return core.pack((carry_out, ys)), (carry_out_bdim, ys_bdim)
def _linear_solve_batching_rule(axis_size, axis_name, main_type, args, dims,
                                const_lengths, jaxprs):
  """Batching (vmap) rule for the `linear_solve` primitive.

  `jaxprs` holds four (possibly None) jaxprs: matvec, vecmat, solve, solve_t.
  Which parts of the solution x and the right-hand side b are batched must be
  decided jointly so that all four jaxprs agree; this is done by fixpoint
  iteration below.

  Returns:
    A pair (solve outputs, per-output batch dims).
  """
  orig_bat = [d is not batching.not_mapped for d in dims]

  params, b = _split_linear_solve_args(args, const_lengths)
  params_dims, b_dims = _split_linear_solve_args(dims, const_lengths)
  params_bat, orig_b_bat = _split_linear_solve_args(orig_bat, const_lengths)

  (matvec, vecmat, solve, solve_t) = jaxprs
  (matvec_bat, vecmat_bat, solve_bat, solve_t_bat) = params_bat

  # solve may return auxiliary outputs beyond the solution itself.
  num_aux = len(solve.out_avals) - len(matvec.out_avals)
  # Fixpoint computation of which parts of x and b are batched; we need to
  # ensure this is consistent between all four jaxprs
  b_bat = orig_b_bat
  x_bat = [False] * len(solve.out_avals)
  # Each iteration can only flip flags False -> True, so the flag count bounds
  # the number of iterations needed.
  for i in range(1 + len(orig_b_bat) + len(solve.out_avals)):
    # Apply vecmat and solve -> new batched parts of x
    solve_jaxpr_batched, solve_x_bat = batching.batch_jaxpr(
        solve, axis_size, solve_bat + b_bat, instantiate=x_bat,
        axis_name=axis_name, main_type=main_type)
    if vecmat is None:
      vecmat_jaxpr_batched = None
      x_bat_out = solve_x_bat
    else:
      vecmat_jaxpr_batched, vecmat_x_bat = batching.batch_jaxpr(
          vecmat, axis_size, vecmat_bat + b_bat, instantiate=x_bat,
          axis_name=axis_name, main_type=main_type)
      # batch all aux data by default
      x_bat_out = _map(operator.or_, vecmat_x_bat + [True] * num_aux,
                       solve_x_bat)
    # Apply matvec and solve_t -> new batched parts of b
    matvec_jaxpr_batched, matvec_b_bat = batching.batch_jaxpr(
        matvec, axis_size, matvec_bat + x_bat_out, instantiate=b_bat,
        axis_name=axis_name, main_type=main_type)
    if solve_t is None:
      solve_t_jaxpr_batched = None
      b_bat_out = _map(operator.or_, matvec_b_bat, orig_b_bat)
    else:
      solve_t_jaxpr_batched, solve_t_b_aux_bat = batching.batch_jaxpr(
          solve_t, axis_size, solve_t_bat + x_bat_out, instantiate=b_bat,
          axis_name=axis_name, main_type=main_type)
      assert len(solve_t_b_aux_bat) == len(orig_b_bat) + num_aux
      # Drop the aux flags: only the b-shaped outputs feed back into b_bat.
      solve_t_b_bat, _ = split_list(solve_t_b_aux_bat, [len(orig_b_bat)])
      b_bat_out = _map(lambda m, s, o: m or s or o, matvec_b_bat,
                       solve_t_b_bat, orig_b_bat)
    if x_bat_out == x_bat and b_bat_out == b_bat:
      break
    else:
      x_bat = x_bat_out
      b_bat = b_bat_out
  else:
    assert False, "Fixedpoint not reached"

  batched_jaxprs = _LinearSolveTuple(matvec_jaxpr_batched,
                                     vecmat_jaxpr_batched,
                                     solve_jaxpr_batched,
                                     solve_t_jaxpr_batched)

  # Move batched axes to the front
  new_params = [
      batching.moveaxis(x, d, 0)
      if d is not batching.not_mapped and d != 0 else x
      for x, d in zip(_flatten(params), _flatten(params_dims))
  ]
  # Broadcast out b if necessary
  new_b = [
      batching.broadcast(x, axis_size, 0) if now_bat and not was_bat else
      batching.moveaxis(x, d, 0) if now_bat and d != 0 else
      x for x, d, was_bat, now_bat in zip(b, b_dims, orig_b_bat, b_bat)
  ]

  outs = linear_solve_p.bind(*(new_params + new_b),
                             const_lengths=const_lengths,
                             jaxprs=batched_jaxprs)
  out_dims = [0 if batched else batching.not_mapped for batched in solve_x_bat]
  return outs, out_dims