import jax
from jax import core
from jax import linear_util as lu
from jax import tree_util
from jax import util as jax_util
from jax.interpreters import batching

# Module-level alias assumed throughout this file.
not_mapped = batching.not_mapped


def custom_layer_cau_batch(trace, f, tracers, params):
  """Batching rule for layer_cau primitive to handle custom layers."""
  vals, dims = jax_util.unzip2((t.val, t.batch_dim) for t in tracers)
  if all(dim is not_mapped for dim in dims):
    return layer_cau_p.bind(f, *vals, **params)
  args = tree_util.tree_unflatten(params['in_tree'], vals)
  dims_ = [not_mapped if dim is None else dim for dim in dims]
  layer, args = args[0], args[1:]
  if hasattr(layer, '_call_and_update_batched'):
    num_params = len(tree_util.tree_leaves(layer))
    layer_dims, arg_dims = dims_[:num_params], dims_[num_params:]
    if params['kwargs']['has_rng']:
      rng, args = args[0], args[1:]
      rng_dim, arg_dims = arg_dims[0], arg_dims[1:]
    mapping_over_layer = all(layer_dim is not not_mapped
                             for layer_dim in layer_dims)
    mapping_over_args = all(arg_dim is not not_mapped
                            for arg_dim in arg_dims)
    assert mapping_over_layer or mapping_over_args, (layer_dims, arg_dims)
    if not mapping_over_layer and mapping_over_args:
      if params['kwargs']['has_rng']:
        if rng_dim is not not_mapped:
          arg_dims = tuple(None if dim is not_mapped else dim
                           for dim in arg_dims)
          map_fun = jax.vmap(
              lambda layer, rng, *args: _layer_cau_batched(  # pylint: disable=unnecessary-lambda, g-long-lambda
                  layer, rng, *args, **params['kwargs']),
              in_axes=(None, rng_dim) + (None,) * len(arg_dims))
        else:
          # `rng` is unmapped here, so it is passed through `*args`.
          map_fun = lambda layer, *args: _layer_cau_batched(  # pylint: disable=unnecessary-lambda, g-long-lambda
              layer, *args, **params['kwargs'])
        vals_out, update_out = map_fun(layer, rng, *args)
      else:
        vals_out, update_out = _layer_cau_batched(layer, *args,
                                                  **params['kwargs'])
      vals_out = tree_util.tree_leaves(vals_out)
      update_out = tree_util.tree_leaves(update_out)
      assert all(dim == 0 for dim in arg_dims)
      # Assume dimensions out are consistent.
      dims_out = (0,) * len(vals_out)
      dims_update = (None,) * len(update_out)
      # Call wrapped function to avoid linear_util error.
      f.call_wrapped(*tracers)
      return [
          batching.BatchTracer(trace, v, d)
          for v, d in zip(vals_out + update_out, dims_out + dims_update)
      ]
  f, dims_out = batching.batch_subtrace(f, trace.master, dims)
  vals_out = layer_cau_p.subcall('batch').bind(f, *vals, **params)
  return [
      batching.BatchTracer(trace, v, d)
      for v, d in zip(vals_out, dims_out())
  ]
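
# --- Hedged illustration (not part of the original module) -----------------
# The rule above distinguishes "mapping over the layer" from "mapping over
# the arguments". Plain `jax.vmap` expresses the same distinction through
# `in_axes`: `None` means not mapped, an integer names the batch axis. The
# names below (`apply_layer`, `layer_params`, `xs`) are illustrative
# stand-ins, not identifiers from this module.
import jax.numpy as jnp


def apply_layer(params, x):
  # Stand-in for a layer's call: a single dense transform.
  w, b = params
  return jnp.tanh(x @ w + b)


layer_params = (jnp.ones((3, 2)), jnp.zeros(2))  # unbatched layer parameters
xs = jnp.ones((5, 3))                            # batch of five inputs

# in_axes=(None, 0) broadcasts the layer and maps over the inputs, i.e. the
# `not mapping_over_layer and mapping_over_args` case handled above.
ys = jax.vmap(apply_layer, in_axes=(None, 0))(layer_params, xs)
assert ys.shape == (5, 2)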
def batch_jaxpr(jaxpr, axis_size, in_dims):
  dimvars = {v: v.aval for v in jaxpr.in_dim_binders}
  in_avals = [_replace_vars_with_avals(dimvars, v.aval)
              for v in jaxpr.in_binders]
  # Each mapped input's aval gains a batch axis of size `axis_size` before
  # the jaxpr is re-traced.
  in_avals = [core.unmapped_aval(axis_size, d, aval)
              if d is not batching.not_mapped else aval
              for d, aval in zip(in_dims, in_avals)]
  fun, out_dims = batching.batch_subtrace(lu.wrap_init(jaxpr_as_fun(jaxpr)))
  f = _batch_fun(fun, in_dims)
  jaxpr, consts, _ = trace_to_jaxpr_dynamic(f, in_avals)
  return jaxpr, consts, out_dims()
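
# --- Hedged illustration (not part of the original module) -----------------
# `batch_jaxpr` gives every mapped input aval a leading batch axis and
# re-traces. The public `jax.make_jaxpr` shows the analogous effect: `x` is
# mapped along axis 0 of size 4, so its aval in the printed jaxpr carries the
# extra leading dimension, while the unmapped `y` does not. `mul` is an
# illustrative name only.
def mul(x, y):
  return x * y


print(jax.make_jaxpr(jax.vmap(mul, in_axes=(0, None)))(
    jnp.ones((4, 3)), jnp.ones(3)))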
def vmap_unrestricted(f: lu.WrappedFun, *args, in_axes, axis_name, axis_size):
  """Batches a `lu.WrappedFun` over flat args, returning outputs and out axes."""
  f, out_axes = batching.batch_subtrace(f)
  f = batching._batch_outer(f, axis_name, axis_size, in_axes,  # pylint: disable=protected-access
                            batching.BatchTrace)
  outs = f.call_wrapped(*args)
  return outs, out_axes()
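
# --- Hedged usage sketch (assumes JAX internals of the same vintage) --------
# `vmap_unrestricted` works on flat argument lists: the wrapped function is
# expected to return a flat sequence of outputs, and `out_axes` reports one
# batch dimension per output. `add_one` is an illustrative name; the exact
# `lu.wrap_init`/`_batch_outer` signatures vary across JAX versions, so treat
# this as a sketch rather than a pinned API.
def add_one(x):
  return (x + 1.,)  # flat tuple of outputs, as the subtrace machinery expects


outs, out_axes = vmap_unrestricted(
    lu.wrap_init(add_one), jnp.arange(3.),
    in_axes=(0,), axis_name='batch', axis_size=3)
# `outs` is the flat sequence of batched outputs; `out_axes` gives each
# output's batch dimension (here (0,)).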
def batch_fun(fun: lu.WrappedFun, in_dims):
  fun, out_dims = batching.batch_subtrace(fun)
  # `out_dims` is the aux thunk from `transformation_with_aux`; it can only
  # be resolved after the returned function has been called.
  return _batch_fun(fun, in_dims), out_dims
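
# Hedged usage note (hypothetical names): `out_dims` is returned un-called on
# purpose. The output batch dimensions only exist once the transformed
# function has actually run, so a caller runs the function first and resolves
# the thunk afterwards:
#
#   fun, out_dims = batch_fun(lu.wrap_init(some_fn), in_dims)
#   out_vals = fun.call_wrapped(*in_vals)  # running populates the aux store
#   dims = out_dims()                      # only valid after call_wrapped
#
# Compare `batch_jaxpr` above, which calls `trace_to_jaxpr_dynamic(f, ...)`
# before resolving `out_dims()`.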