def triangular_solve_batching_rule(batched_args, batch_dims, left_side,
                                   lower, transpose_a, conjugate_a,
                                   unit_diagonal):
  x, y = batched_args
  bx, by = batch_dims
  if bx is batching.not_mapped:
    # Only y is batched: fold the batch into y's trailing dimensions, solve
    # once against the unbatched x, and report where the batch axis landed.
    if left_side:
      y = batching.moveaxis(y, by, -1)
      y_flat = y.reshape(y.shape[:-2] + (y.shape[-2] * y.shape[-1],))
      bdim_out = y.ndim - 1
    else:
      y = batching.moveaxis(y, by, -2)
      y_flat = y.reshape(y.shape[:-3] + (y.shape[-3] * y.shape[-2], y.shape[-1]))
      bdim_out = y.ndim - 2
    out_flat = triangular_solve(x, y_flat, left_side=left_side, lower=lower,
                                transpose_a=transpose_a,
                                conjugate_a=conjugate_a,
                                unit_diagonal=unit_diagonal)
    return out_flat.reshape(y.shape), bdim_out
  else:
    # x is batched (y may or may not be): align both batch axes at the front.
    size = next(t.shape[i] for t, i in zip(batched_args, batch_dims)
                if i is not None)
    x = batching.bdim_at_front(x, bx, size)
    y = batching.bdim_at_front(y, by, size)
    return triangular_solve(x, y, left_side=left_side, lower=lower,
                            transpose_a=transpose_a, conjugate_a=conjugate_a,
                            unit_diagonal=unit_diagonal), 0
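# A hedged usage sketch (not part of the rule above): vmapping over only the
# right-hand side exercises the reshape fast path, while vmapping over the
# matrix takes the bdim_at_front path. Shapes and values are illustrative.
import jax
import jax.numpy as jnp
from jax import lax

a = jnp.tril(jnp.ones((3, 3))) + 2. * jnp.eye(3)   # unbatched triangular matrix
bs = jnp.ones((5, 3, 2))                           # batch of 5 right-hand sides

# Only `b` is batched: the rule folds the batch into b's trailing dims.
xs = jax.vmap(lambda b: lax.linalg.triangular_solve(a, b, left_side=True,
                                                    lower=True))(bs)
print(xs.shape)  # (5, 3, 2)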
def _while_loop_batching_rule(args, dims, cond_nconsts, cond_jaxpr,
                              body_nconsts, body_jaxpr):
  size, = {x.shape[d] for x, d in zip(args, dims)
           if d is not batching.not_mapped}
  orig_batched = [d is not batching.not_mapped for d in dims]
  cconst_bat, bconst_bat, init_bat = split_list(orig_batched,
                                                [cond_nconsts, body_nconsts])

  # Fixpoint computation of which carries are batched: a carry is batched if
  # its initial value is batched, or if the body makes it batched.
  carry_bat = init_bat
  for _ in range(1000):
    batched = bconst_bat + carry_bat
    body_jaxpr_batched, carry_bat_out = batching.batch_jaxpr(
        body_jaxpr, size, batched, instantiate=carry_bat)
    cond_jaxpr_batched, (pred_bat,) = batching.batch_jaxpr(
        cond_jaxpr, size, cconst_bat + carry_bat, instantiate=False)
    # If the predicate is batched, every carry must be batched so the loop
    # can keep per-example results separate.
    carry_bat_out = _map(partial(operator.or_, pred_bat), carry_bat_out)
    if carry_bat_out == carry_bat:
      break
    else:
      carry_bat = carry_bat_out
  else:
    raise FixedPointError

  consts, init = split_list(args, [cond_nconsts + body_nconsts])
  const_dims, init_dims = split_list(dims, [cond_nconsts + body_nconsts])
  new_consts = [batching.moveaxis(x, d, 0)
                if d is not batching.not_mapped and d != 0 else x
                for x, d in zip(consts, const_dims)]
  new_init = [batching.broadcast(x, size, 0) if now_bat and not was_bat
              else batching.moveaxis(x, d, 0) if now_bat
              else x
              for x, d, was_bat, now_bat
              in zip(init, init_dims, init_bat, carry_bat)]

  outs = while_p.bind(*(new_consts + new_init),
                      cond_nconsts=cond_nconsts, cond_jaxpr=cond_jaxpr_batched,
                      body_nconsts=body_nconsts, body_jaxpr=body_jaxpr_batched)
  out_bdims = [0 if b else batching.not_mapped for b in carry_bat]
  return outs, out_bdims
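# A hedged usage sketch: vmapping lax.while_loop. Because the predicate
# becomes batched, the fixpoint above forces every carry to be batched, and
# the loop runs until every batch element's predicate is False (elements that
# finish early keep their values).
import jax
import jax.numpy as jnp
from jax import lax

def count_to_ten(x):
  return lax.while_loop(lambda c: c < 10., lambda c: c + 1., x)

print(jax.vmap(count_to_ten)(jnp.array([0., 5., 9.])))  # [10. 10. 10.]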
def _scan_batching_rule(args, dims, forward, length, jaxpr, num_consts,
                        num_carry, linear):
  num_ys = len(jaxpr.out_avals) - num_carry
  size, = {x.shape[d] for x, d in zip(args, dims)
           if d is not batching.not_mapped}
  orig_batched = [d is not batching.not_mapped for d in dims]
  const_batched, init_batched, xs_batched = split_list(
      orig_batched, [num_consts, num_carry])

  # Fixpoint computation of which carries are batched: a carry is batched if
  # its initial value is batched, or if the loop body makes it batched.
  carry_batched = init_batched
  for _ in range(1000):
    batched = const_batched + carry_batched + xs_batched
    jaxpr_batched, batched_out = batching.batch_jaxpr(
        jaxpr, size, batched, instantiate=carry_batched + [False] * num_ys)
    carry_batched_out, ys_batched = (batched_out[:num_carry],
                                     batched_out[num_carry:])
    if carry_batched_out == carry_batched:
      break
    else:
      carry_batched = carry_batched_out
  else:
    raise FixedPointError

  consts, init, xs = split_list(args, [num_consts, num_carry])
  consts_bdims, init_bdims, xs_bdims = split_list(dims, [num_consts, num_carry])
  new_consts = [batching.moveaxis(x, d, 0)
                if d is not batching.not_mapped and d != 0 else x
                for x, d in zip(consts, consts_bdims)]
  new_init = [batching.broadcast(x, size, 0) if now_batched and not was_batched
              else batching.moveaxis(x, d, 0) if now_batched else x
              for x, d, was_batched, now_batched
              in zip(init, init_bdims, init_batched, carry_batched)]
  # Axis 0 of xs is the scan axis, so batch dimensions go to axis 1.
  new_xs = [batching.moveaxis(x, d, 1)
            if d is not batching.not_mapped and d != 1 else x
            for x, d in zip(xs, xs_bdims)]
  new_args = new_consts + new_init + new_xs

  outs = scan_p.bind(*new_args, forward=forward, length=length,
                     jaxpr=jaxpr_batched, num_consts=num_consts,
                     num_carry=num_carry, linear=linear)
  carry_bdims = [0 if b else batching.not_mapped for b in carry_batched]
  ys_bdims = [1 if b else batching.not_mapped for b in ys_batched]
  return outs, carry_bdims + ys_bdims
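# A hedged usage sketch: vmapping lax.scan. Batch dims of xs land on axis 1
# (axis 0 stays the scan axis), matching the moveaxis(x, d, 1) above.
import jax
import jax.numpy as jnp
from jax import lax

def cumsum(xs):
  return lax.scan(lambda c, x: (c + x, c + x), 0., xs)[1]

print(jax.vmap(cumsum)(jnp.ones((4, 3))))  # four independent cumsums of ones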
def _lu_pivots_to_permutation_batching_rule(batched_args, batch_dims, *,
                                            permutation_size):
  x, = batched_args
  bd, = batch_dims
  x = batching.moveaxis(x, bd, 0)
  return lu_pivots_to_permutation_p.bind(
      x, permutation_size=permutation_size), 0
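# A hedged usage sketch: the primitive already accepts leading batch
# dimensions, so the rule only needs to move the batch axis to the front.
import jax
import jax.numpy as jnp
from jax import lax

pivots = jnp.zeros((4, 3), jnp.int32)  # batch of 4 trivial pivot sequences
perms = jax.vmap(lambda p: lax.linalg.lu_pivots_to_permutation(p, 3))(pivots)
print(perms.shape)  # (4, 3)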
def _custom_vjp_call_jaxpr_vmap(
    args, in_dims, axis_name, main_type, *, fun_jaxpr: core.ClosedJaxpr,
    fwd_jaxpr_thunk: Callable[[], Tuple[core.Jaxpr, Sequence[Any]]],
    bwd: lu.WrappedFun, out_trees: Callable, num_consts: int):
  axis_size, = {x.shape[d] for x, d in zip(args, in_dims)
                if d is not not_mapped}
  args = [batching.moveaxis(x, d, 0) if d is not not_mapped and d != 0
          else x for x, d in zip(args, in_dims)]

  in_batched = [d is not not_mapped for d in in_dims]
  _, args_batched = split_list(in_batched, [num_consts])
  batched_fun_jaxpr, out_batched = batching.batch_jaxpr(
      fun_jaxpr, axis_size, in_batched, False, axis_name, main_type)
  out_dims1 = [0 if b else not_mapped for b in out_batched]
  out_dims2 = []  # mutable cell updated by batched_fwd_jaxpr_thunk

  @pe._memoize
  def batched_fwd_jaxpr_thunk():
    fwd_jaxpr = core.ClosedJaxpr(*fwd_jaxpr_thunk())  # consts can be tracers
    batched_fwd_jaxpr, out_batched = batching.batch_jaxpr(
        fwd_jaxpr, axis_size, args_batched, False, axis_name, main_type)
    out_dims2.append([0 if b else not_mapped for b in out_batched])
    return batched_fwd_jaxpr.jaxpr, batched_fwd_jaxpr.consts

  fwd_args_batched = [0 if b else not_mapped for b in args_batched]
  fwd_out_dims = lambda: out_dims2[0]
  batched_bwd = batching.batch_custom_vjp_bwd(bwd, axis_name, axis_size,
                                              fwd_out_dims, fwd_args_batched,
                                              main_type)

  batched_outs = custom_vjp_call_jaxpr_p.bind(
      *args, fun_jaxpr=batched_fun_jaxpr,
      fwd_jaxpr_thunk=batched_fwd_jaxpr_thunk, bwd=batched_bwd,
      out_trees=out_trees, num_consts=num_consts)
  out_dims = out_dims2[0] if out_dims2 else out_dims1
  return batched_outs, out_dims
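# A hedged usage sketch: vmap composes with jax.custom_vjp through rules like
# the one above; the fwd/bwd functions are traced and batched lazily. The
# function and its clipped gradient below are made up for illustration.
import jax
import jax.numpy as jnp

@jax.custom_vjp
def clipped_square(x):
  return jnp.square(x)

def fwd(x):
  return clipped_square(x), x

def bwd(x, g):
  # Zero out the gradient outside |x| < 1, just to make the custom rule visible.
  return (jnp.where(jnp.abs(x) < 1., 2. * x * g, 0.),)

clipped_square.defvjp(fwd, bwd)

print(jax.vmap(jax.grad(clipped_square))(jnp.array([0.5, 2.0])))  # [1. 0.]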
def _custom_jvp_call_jaxpr_vmap(
    args, in_dims, axis_name, main_type, *, fun_jaxpr: core.ClosedJaxpr,
    jvp_jaxpr_thunk: Callable[[], Tuple[core.Jaxpr, Sequence[Any]]],
    num_consts: int):
  size, = {x.shape[d] for x, d in zip(args, in_dims) if d is not not_mapped}
  args = [batching.moveaxis(x, d, 0) if d is not not_mapped and d != 0
          else x for x, d in zip(args, in_dims)]
  num_out = len(fun_jaxpr.out_avals)

  in_batched = [d is not not_mapped for d in in_dims]
  batched_fun_jaxpr, out_batched = batching.batch_jaxpr(
      fun_jaxpr, size, in_batched, False, axis_name, main_type)
  out_dims1 = [0 if b else not_mapped for b in out_batched]
  out_dims2 = []  # mutable cell updated by batched_jvp_jaxpr_thunk

  @pe._memoize
  def batched_jvp_jaxpr_thunk():
    jvp_jaxpr = core.ClosedJaxpr(*jvp_jaxpr_thunk())  # consts can be tracers
    _, args_batched = split_list(in_batched, [num_consts])
    _, all_batched = batching.batch_jaxpr(jvp_jaxpr, size, args_batched * 2,
                                          False, axis_name, main_type)
    primals_batched, tangents_batched = split_list(all_batched, [num_out])
    out_batched = map(op.or_, primals_batched, tangents_batched)
    out_dims2.append([0 if b else not_mapped for b in out_batched])
    batched_jvp_jaxpr, _ = batching.batch_jaxpr(
        jvp_jaxpr, size, args_batched * 2, out_batched * 2, axis_name,
        main_type)
    return batched_jvp_jaxpr.jaxpr, batched_jvp_jaxpr.consts

  batched_outs = custom_jvp_call_jaxpr_p.bind(
      *args, fun_jaxpr=batched_fun_jaxpr,
      jvp_jaxpr_thunk=batched_jvp_jaxpr_thunk, num_consts=num_consts)
  out_dims = out_dims2[0] if out_dims2 else out_dims1
  return batched_outs, out_dims
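# A hedged usage sketch: vmap of a jax.custom_jvp function; the JVP thunk is
# batched with primal and tangent inputs sharing the same batching pattern
# (hence args_batched * 2 above).
import jax
import jax.numpy as jnp

@jax.custom_jvp
def log1pexp(x):
  return jnp.log1p(jnp.exp(x))

@log1pexp.defjvp
def log1pexp_jvp(primals, tangents):
  x, = primals
  t, = tangents
  return log1pexp(x), t / (1. + jnp.exp(-x))

print(jax.vmap(jax.grad(log1pexp))(jnp.linspace(-1., 1., 4)))  # sigmoid values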
def _cond_batching_rule(args, dims, true_jaxpr, false_jaxpr, true_nconsts,
                        false_nconsts):
  # TODO: maybe avoid moving arg axes to front if we're promoting to select?
  args = [batching.moveaxis(x, d, 0)
          if d is not batching.not_mapped and d != 0 else x
          for x, d in zip(args, dims)]
  true_nops = len(true_jaxpr.in_avals) - true_nconsts
  (pred,), true_consts, true_ops, false_consts, false_ops = split_list(
      args, [1, true_nconsts, true_nops, false_nconsts])
  size, = {x.shape[d] for x, d in zip(args, dims)
           if d is not batching.not_mapped}
  orig_bat = [d is not batching.not_mapped for d in dims]
  (pred_bat,), tconst_bat, t_bat, fconst_bat, f_bat = split_list(
      orig_bat, [1, true_nconsts, true_nops, false_nconsts])

  _, true_out_bat = batching.batch_jaxpr(true_jaxpr, size,
                                         tconst_bat + t_bat, False)
  _, false_out_bat = batching.batch_jaxpr(false_jaxpr, size,
                                          fconst_bat + f_bat, False)
  out_bat = [a or b for a, b in zip(true_out_bat, false_out_bat)]

  true_jaxpr_batched, _ = batching.batch_jaxpr(true_jaxpr, size,
                                               tconst_bat + t_bat, out_bat)
  false_jaxpr_batched, _ = batching.batch_jaxpr(false_jaxpr, size,
                                                fconst_bat + f_bat, out_bat)

  if pred_bat:
    # The predicate is batched, so we can't pick a single branch: run both
    # and select the results elementwise on the predicate.
    true_out = core.jaxpr_as_fun(true_jaxpr_batched)(*(true_consts + true_ops))
    false_out = core.jaxpr_as_fun(false_jaxpr_batched)(*(false_consts + false_ops))
    true_out = [batching.broadcast(x, size, 0) if not b else x
                for x, b in zip(true_out, out_bat)]
    false_out = [batching.broadcast(x, size, 0) if not b else x
                 for x, b in zip(false_out, out_bat)]
    return [_cond_pred_bcast_select(pred, t, f)
            for t, f in zip(true_out, false_out)], [0] * len(true_out)
  else:
    out_dims = [0 if b else batching.not_mapped for b in out_bat]
    return cond_p.bind(
        *itertools.chain([pred], true_consts, true_ops, false_consts, false_ops),
        true_jaxpr=true_jaxpr_batched, false_jaxpr=false_jaxpr_batched,
        true_nconsts=len(true_consts),
        false_nconsts=len(false_consts)), out_dims
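# A hedged usage sketch: with a batched predicate, a two-branch cond cannot
# pick one branch per trace, so rules like the one above run both branches
# and select elementwise.
import jax
import jax.numpy as jnp
from jax import lax

def abs_via_cond(x):
  return lax.cond(x >= 0., lambda x: x, lambda x: -x, x)

print(jax.vmap(abs_via_cond)(jnp.array([-2., 3.])))  # [2. 3.]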
def eig_batching_rule(batched_args, batch_dims, *, compute_left_eigenvectors,
                      compute_right_eigenvectors):
  x, = batched_args
  bd, = batch_dims
  x = batching.moveaxis(x, bd, 0)
  return (eig_p.bind(x, compute_left_eigenvectors=compute_left_eigenvectors,
                     compute_right_eigenvectors=compute_right_eigenvectors),
          (0,) * (1 + compute_left_eigenvectors + compute_right_eigenvectors))
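# A hedged usage sketch: nonsymmetric eig (CPU-only in JAX) batches by moving
# the batch axis to the front, since the primitive accepts leading batch dims.
import jax
import jax.numpy as jnp

ms = jnp.stack([jnp.eye(2), 2. * jnp.eye(2)])
ws, vs = jax.vmap(jnp.linalg.eig)(ms)
print(ws.shape, vs.shape)  # (2, 2) (2, 2, 2)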
def svd_batching_rule(batched_args, batch_dims, full_matrices, compute_uv):
  x, = batched_args
  bd, = batch_dims
  x = batching.moveaxis(x, bd, 0)
  outs = svd_p.bind(x, full_matrices=full_matrices, compute_uv=compute_uv)
  if compute_uv:
    return outs, (0, 0, 0)
  else:
    return outs, (0,)
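# A hedged usage sketch: all three SVD outputs share the leading batch axis.
import jax
import jax.numpy as jnp

ms = jnp.ones((4, 3, 2))
u, s, vt = jax.vmap(lambda m: jnp.linalg.svd(m, full_matrices=False))(ms)
print(u.shape, s.shape, vt.shape)  # (4, 3, 2) (4, 2) (4, 2, 2)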
def _cond_batching_rule(axis_size, axis_name, main_type, args, dims, branches,
                        linear):
  index, *ops = args
  index_dim, *op_dims = dims

  if index_dim is not batching.not_mapped:
    # Convert to a lax.select. While we could get away with not broadcasting
    # some operands yet, because all outputs must be broadcast together anyway
    # for the select we broadcast the input operands for simplicity and leave
    # optimizations to XLA.
    # TODO(mattjj,frostig): assumes branches are side-effect-free, revise!
    index, *ops = (batching.bdim_at_front(x, d, axis_size)
                   for x, d in zip(args, dims))

    in_batched = [True] * len(branches[0].in_avals)
    out_batched = [True] * len(branches[0].out_avals)

    branches_batched = [
        batching.batch_jaxpr(jaxpr, axis_size, in_batched, out_batched,
                             axis_name, main_type)[0]
        for jaxpr in branches]

    branch_outs = []
    for i, jaxpr in enumerate(branches_batched):
      # Perform a select on the inputs for safety of reverse-mode autodiff; see
      # https://github.com/google/jax/issues/1052
      predicate = lax.eq(index, lax._const(index, i))
      ops_ = [_bcast_select(predicate, x, lax.stop_gradient(x)) for x in ops]
      branch_outs.append(core.jaxpr_as_fun(jaxpr)(*ops_))
    out = [_bcast_select_n(index, *outs) for outs in zip(*branch_outs)]
    return out, [0 if b else None for b in out_batched]
  else:
    ops_bat = [d is not batching.not_mapped for d in op_dims]
    ops = [batching.moveaxis(x, d, 0) if b else x
           for b, x, d in zip(ops_bat, ops, op_dims)]

    branches_out_bat = [
        batching.batch_jaxpr(jaxpr, axis_size, ops_bat, False, axis_name,
                             main_type)[1]
        for jaxpr in branches]
    out_bat = [any(bat) for bat in zip(*branches_out_bat)]
    branches_batched = tuple(
        batching.batch_jaxpr(jaxpr, axis_size, ops_bat, out_bat, axis_name,
                             main_type)[0]
        for jaxpr in branches)

    out_dims = [0 if b else batching.not_mapped for b in out_bat]
    out = cond_p.bind(index, *ops, branches=branches_batched, linear=linear)
    return out, out_dims
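# A hedged usage sketch: a batched branch index makes lax.switch run every
# branch and combine the results with a select, as in the first case above.
import jax
import jax.numpy as jnp
from jax import lax

def apply_op(i, x):
  return lax.switch(i, [jnp.sin, jnp.cos, jnp.tanh], x)

print(jax.vmap(apply_op)(jnp.array([0, 1, 2]), jnp.zeros(3)))  # [0. 1. 0.]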
def dex_call_batched(batched_args, batched_dims, func_atom):
  """Batching rule for Dex primitives.

  Args:
    batched_args: The possibly-batched arguments.
    batched_dims: A sequence of the same length as `batched_args`, where each
      entry indicates the batching axis of the corresponding entry of
      `batched_args`, or None if that argument is not batched. At least one
      entry must not be None.
    func_atom: The Dex function atom to batch.

  Returns:
    A 2-tuple containing the result of the batched function and the axis along
    which the result is batched, which is always zero.
  """
  module = func_atom.module.copy()
  # Move axes so that we only have to deal with the zero axis being batched.
  uniform_batched_args = [
      batching.moveaxis(arg, bd, 0) if bd is not batching.not_mapped else arg
      for arg, bd in zip(batched_args, batched_dims)]
  # This assumes not all entries in batched_dims are None.
  batch_size = next(arg.shape[0]
                    for arg, bd in zip(uniform_batched_args, batched_dims)
                    if bd is not batching.not_mapped)
  # Add the current function atom as a variable in the context, so that we can
  # use it to apply batching.
  func_name = func_atom.name
  assert func_name is not None
  # Only index into the arguments which are batched. `i` is the index used for
  # the Dex for-loop constructor.
  batched_fn_params = [
      f"x{param_idx}" if dim is batching.not_mapped else f"x{param_idx}.i"
      for param_idx, dim in enumerate(batched_dims)]
  # This is the actual batching expression.
  batched_fn = module.eval(
      r"\ " + " ".join(f"x{i}" for i in range(len(batched_args))) + ". " +
      f"for i:(Fin {batch_size}). {func_name} " + " ".join(batched_fn_params))
  return primitive(batched_fn)(*uniform_batched_args), 0
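# A hedged illustration (pure Python, no Dex required) of the expression the
# rule builds. `myfunc` and the dims below are made-up stand-ins, and None
# stands in for batching.not_mapped.
batch_size, func_name = 8, "myfunc"
batched_dims = [0, None]  # first arg batched, second not
params = [f"x{i}.i" if d is not None else f"x{i}"
          for i, d in enumerate(batched_dims)]
src = (r"\ " + " ".join(f"x{i}" for i in range(len(batched_dims))) + ". " +
       f"for i:(Fin {batch_size}). {func_name} " + " ".join(params))
print(src)  # \ x0 x1. for i:(Fin 8). myfunc x0.i x1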
def _approx_top_k_batch_rule(batched_args, batch_dims, *, k,
                             reduction_dimension, recall_target, is_max_k,
                             reduction_input_size_override, aggregate_to_topk):
  prototype_arg, new_bdim = next(
      (a, b) for a, b in zip(batched_args, batch_dims) if b is not None)
  new_args = []
  for arg, bdim in zip(batched_args, batch_dims):
    if bdim is None:
      # Broadcast unbatched args to match the shape of the batched prototype.
      dims = np.delete(np.arange(prototype_arg.ndim), new_bdim)
      new_args.append(lax.broadcast_in_dim(arg, prototype_arg.shape, dims))
    else:
      new_args.append(batching.moveaxis(arg, bdim, new_bdim))
  # Inserting the batch axis at or before the reduction dimension shifts the
  # reduction dimension right by one.
  new_reduction_dim = reduction_dimension + (new_bdim <= reduction_dimension)
  bdims = (new_bdim,) * len(new_args)
  return (approx_top_k_p.bind(
      *new_args,
      k=k,
      reduction_dimension=new_reduction_dim,
      recall_target=recall_target,
      is_max_k=is_max_k,
      reduction_input_size_override=reduction_input_size_override,
      aggregate_to_topk=aggregate_to_topk), bdims)
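# A hedged usage sketch: vmap of approx_max_k; the reduction dimension shifts
# by one because the batch axis is inserted in front of it.
import jax
import jax.numpy as jnp
from jax import lax

xs = jnp.arange(24.).reshape(4, 6)
vals, idxs = jax.vmap(lambda x: lax.approx_max_k(x, k=2))(xs)
print(vals.shape, idxs.shape)  # (4, 2) (4, 2)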
def qr_batching_rule(batched_args, batch_dims, full_matrices):
  x, = batched_args
  bd, = batch_dims
  x = batching.moveaxis(x, bd, 0)
  return qr_p.bind(x, full_matrices=full_matrices), (0, 0)
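# A hedged usage sketch: both QR outputs carry the batch on axis 0.
import jax
import jax.numpy as jnp

ms = jnp.ones((5, 4, 3))
q, r = jax.vmap(jnp.linalg.qr)(ms)
print(q.shape, r.shape)  # (5, 4, 3) (5, 3, 3)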
def _lu_batching_rule(batched_args, batch_dims):
  x, = batched_args
  bd, = batch_dims
  x = batching.moveaxis(x, bd, 0)
  return lu_p.bind(x), (0, 0)
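# A hedged usage sketch: LU factorization outputs share the leading batch axis.
import jax
import jax.numpy as jnp

ms = jnp.stack([jnp.eye(3), 2. * jnp.eye(3)])
lu, piv = jax.vmap(jax.scipy.linalg.lu_factor)(ms)
print(lu.shape, piv.shape)  # (2, 3, 3) (2, 3)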
def eigh_batching_rule(batched_args, batch_dims, lower):
  x, = batched_args
  bd, = batch_dims
  x = batching.moveaxis(x, bd, 0)
  return eigh_p.bind(x, lower=lower), (0, 0)
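# A hedged usage sketch: eigenvalues and eigenvectors are batched together.
import jax
import jax.numpy as jnp

ms = jnp.stack([jnp.eye(2), 3. * jnp.eye(2)])
w, v = jax.vmap(jnp.linalg.eigh)(ms)
print(w)  # [[1. 1.] [3. 3.]]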
def cholesky_batching_rule(batched_args, batch_dims):
  x, = batched_args
  bd, = batch_dims
  x = batching.moveaxis(x, bd, 0)
  return cholesky(x), 0
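# A hedged usage sketch: cholesky has a single output, batched on axis 0.
import jax
import jax.numpy as jnp

ms = jnp.stack([4. * jnp.eye(2), 9. * jnp.eye(2)])
print(jax.vmap(jnp.linalg.cholesky)(ms)[:, 0, 0])  # [2. 3.]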
def fft_batching_rule(batched_args, batch_dims, fft_type, fft_lengths):
  x, = batched_args
  bd, = batch_dims
  x = batching.moveaxis(x, bd, 0)
  return fft(x, fft_type, fft_lengths), 0
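# A hedged usage sketch: the FFT is computed over trailing axes, so a leading
# batch axis passes through untouched.
import jax
import jax.numpy as jnp

xs = jnp.ones((4, 8))
print(jax.vmap(jnp.fft.fft)(xs).shape)  # (4, 8)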
def _linear_solve_batching_rule(axis_size, axis_name, main_type, args, dims,
                                const_lengths, jaxprs):
  orig_bat = [d is not batching.not_mapped for d in dims]

  params, b = _split_linear_solve_args(args, const_lengths)
  params_dims, b_dims = _split_linear_solve_args(dims, const_lengths)
  params_bat, orig_b_bat = _split_linear_solve_args(orig_bat, const_lengths)

  (matvec, vecmat, solve, solve_t) = jaxprs
  (matvec_bat, vecmat_bat, solve_bat, solve_t_bat) = params_bat

  num_aux = len(solve.out_avals) - len(matvec.out_avals)
  # Fixpoint computation of which parts of x and b are batched; we need to
  # ensure this is consistent between all four jaxprs.
  b_bat = orig_b_bat
  x_bat = [False] * len(solve.out_avals)
  for i in range(1 + len(orig_b_bat) + len(solve.out_avals)):
    # Apply vecmat and solve -> new batched parts of x
    solve_jaxpr_batched, solve_x_bat = batching.batch_jaxpr(
        solve, axis_size, solve_bat + b_bat, instantiate=x_bat,
        axis_name=axis_name, main_type=main_type)
    if vecmat is None:
      vecmat_jaxpr_batched = None
      x_bat_out = solve_x_bat
    else:
      vecmat_jaxpr_batched, vecmat_x_bat = batching.batch_jaxpr(
          vecmat, axis_size, vecmat_bat + b_bat, instantiate=x_bat,
          axis_name=axis_name, main_type=main_type)
      # Batch all aux data by default.
      x_bat_out = _map(operator.or_, vecmat_x_bat + [True] * num_aux,
                       solve_x_bat)
    # Apply matvec and solve_t -> new batched parts of b
    matvec_jaxpr_batched, matvec_b_bat = batching.batch_jaxpr(
        matvec, axis_size, matvec_bat + x_bat_out, instantiate=b_bat,
        axis_name=axis_name, main_type=main_type)
    if solve_t is None:
      solve_t_jaxpr_batched = None
      b_bat_out = _map(operator.or_, matvec_b_bat, orig_b_bat)
    else:
      solve_t_jaxpr_batched, solve_t_b_aux_bat = batching.batch_jaxpr(
          solve_t, axis_size, solve_t_bat + x_bat_out, instantiate=b_bat,
          axis_name=axis_name, main_type=main_type)
      assert len(solve_t_b_aux_bat) == len(orig_b_bat) + num_aux
      solve_t_b_bat, _ = split_list(solve_t_b_aux_bat, [len(orig_b_bat)])
      b_bat_out = _map(lambda m, s, o: m or s or o, matvec_b_bat,
                       solve_t_b_bat, orig_b_bat)
    if x_bat_out == x_bat and b_bat_out == b_bat:
      break
    else:
      x_bat = x_bat_out
      b_bat = b_bat_out
  else:
    assert False, "Fixedpoint not reached"

  batched_jaxprs = _LinearSolveTuple(matvec_jaxpr_batched,
                                     vecmat_jaxpr_batched,
                                     solve_jaxpr_batched,
                                     solve_t_jaxpr_batched)

  # Move batched axes to the front.
  new_params = [
      batching.moveaxis(x, d, 0)
      if d is not batching.not_mapped and d != 0 else x
      for x, d in zip(_flatten(params), _flatten(params_dims))]
  # Broadcast out b if necessary.
  new_b = [
      batching.broadcast(x, axis_size, 0) if now_bat and not was_bat else
      batching.moveaxis(x, d, 0) if now_bat and d != 0 else x
      for x, d, was_bat, now_bat in zip(b, b_dims, orig_b_bat, b_bat)]

  outs = linear_solve_p.bind(*(new_params + new_b),
                             const_lengths=const_lengths,
                             jaxprs=batched_jaxprs)
  out_dims = [0 if batched else batching.not_mapped for batched in solve_x_bat]
  return outs, out_dims
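# A hedged usage sketch: vmap of lax.custom_linear_solve; the fixpoint above
# decides which of the matvec/solve jaxprs see batched inputs. Here only b is
# batched and the matrix a is closed over, unbatched.
import jax
import jax.numpy as jnp
from jax import lax

def solve(a, b):
  return lax.custom_linear_solve(
      lambda x: a @ x, b, solve=lambda _, b: jnp.linalg.solve(a, b))

a = 2. * jnp.eye(2)
bs = jnp.ones((3, 2))
print(jax.vmap(lambda b: solve(a, b))(bs))  # three copies of [0.5 0.5]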