def _matchaxis_symbolic_zeros(axis_name, sz, name, src, dst, x, sum_match=False):
    """Move the batch axis of ``x`` from ``src`` to ``dst``, keeping symbolic zeros.

    Behaves like ``matchaxis`` but, when ``x`` is a ``Zero``, only rewrites the
    abstract value carried by the zero and returns another ``Zero``.

    Args:
        axis_name: Forwarded to ``matchaxis`` for non-zero inputs and included
            in the error raised for an unsupported ``Zero`` case.
        sz: Axis size passed to ``core.mapped_aval`` / ``core.unmapped_aval``.
        name: Axis name passed to ``core.unmapped_aval``.
        src: Current batch axis of ``x`` (an int or ``not_mapped``).
        dst: Requested batch axis (an int or ``not_mapped``).
        x: The value to adjust; may be a ``Zero``.
        sum_match: Forwarded to ``matchaxis``; for a ``Zero`` it also permits
            going from a mapped ``src`` to an unmapped ``dst``.

    Returns:
        ``x`` itself when ``src == dst`` and ``x`` is a ``Zero``, a new ``Zero``
        with an adjusted aval, or the result of ``matchaxis`` for non-zeros.

    Raises:
        ValueError: If ``x`` is a ``Zero`` and no rule covers ``(src, dst)``.
    """
    # TODO(mattjj): dedup with matchaxis
    if not isinstance(x, Zero):
        return matchaxis(axis_name, sz, src, dst, x, sum_match=sum_match)

    if src == dst:
        # Nothing to move.
        return x

    # Exact type comparison (not isinstance): int subclasses such as bool do
    # not take this branch.
    if type(src) == type(dst) == int:
        without_axis = core.mapped_aval(sz, src, x.aval)
        return Zero(core.unmapped_aval(sz, name, dst, without_axis))

    if src is not_mapped and dst is not not_mapped:
        return Zero(core.unmapped_aval(sz, name, dst, x.aval))

    if dst is not_mapped and sum_match:
        return Zero(core.mapped_aval(sz, src, x.aval))

    raise ValueError((axis_name, x, src, dst))
def _batch_jaxpr_axes(closed_jaxpr, axis_size, in_axes, out_axes_dest, axis_name, main_type):
    """Trace ``closed_jaxpr`` through the batching wrappers into a new closed jaxpr.

    Args:
        closed_jaxpr: The closed jaxpr to batch; its ``in_avals`` are read.
        axis_size: Size handed to the wrappers and to ``core.unmapped_aval``.
        in_axes: One entry per input; ``not_mapped`` leaves that input's aval
            untouched.
        out_axes_dest: Forwarded to ``_batch_jaxpr_inner``.
        axis_name: Forwarded to ``_batch_jaxpr_outer`` and
            ``core.unmapped_aval``.
        main_type: Forwarded to ``_batch_jaxpr_outer``.

    Returns:
        A pair ``(core.ClosedJaxpr, out_batched())`` where ``out_batched`` is
        the second value returned by ``_batch_jaxpr_inner``, called only after
        tracing has finished.
    """
    wrapped = lu.wrap_init(core.jaxpr_as_fun(closed_jaxpr))
    wrapped, out_batched = _batch_jaxpr_inner(wrapped, axis_size, out_axes_dest)
    wrapped = _batch_jaxpr_outer(wrapped, axis_name, axis_size, in_axes, main_type)

    # Build the input avals for tracing: mapped inputs get their aval
    # unmapped along the given axis, unmapped inputs are passed as they are.
    avals_in = []
    for aval, axis in zip(closed_jaxpr.in_avals, in_axes):
        if axis is not_mapped:
            avals_in.append(aval)
        else:
            avals_in.append(
                core.unmapped_aval(axis_size, axis_name, axis, aval))

    jaxpr_out, _, consts = pe.trace_to_jaxpr_dynamic(wrapped, avals_in)
    batched = core.ClosedJaxpr(jaxpr_out, consts)
    # out_batched is evaluated after the trace above has run.
    return batched, out_batched()
def _update_annotation(
    f: lu.WrappedFun,
    orig_type: Optional[Tuple[Tuple[core.AbstractValue, bool], ...]],
    axis_size: int,
    axis_name: core.AxisName,
    in_dims: Sequence[Optional[int]],
) -> lu.WrappedFun:
    """Re-annotate ``f`` with input avals passed through ``core.unmapped_aval``.

    Args:
        f: The wrapped function to annotate.
        orig_type: ``(aval, keep)`` pairs describing the original inputs, or
            ``None`` when there is no annotation to update.
        axis_size: Size passed to ``core.unmapped_aval``.
        axis_name: Axis name passed to ``core.unmapped_aval``.
        in_dims: One dimension entry per element of ``orig_type``.

    Returns:
        ``f`` unchanged when ``orig_type`` is ``None``; otherwise the result of
        ``lu.annotate`` with the rewritten ``(aval, keep)`` tuple.
    """
    if orig_type is None:
        return f

    # The ``keep`` flag of every entry is carried over untouched; only the
    # aval is rewritten.
    batched_in_type = tuple(
        (core.unmapped_aval(axis_size, axis_name, dim, aval), keep)
        for dim, (aval, keep) in zip(in_dims, orig_type)
    )
    return lu.annotate(f, batched_in_type)
def batch_jaxpr(jaxpr, axis_size, in_dims):
    """Trace a batched version of ``jaxpr``.

    Args:
        jaxpr: Jaxpr to batch; its ``in_dim_binders`` and ``in_binders`` are
            read to build the input avals.
        axis_size: Size passed to ``core.unmapped_aval``.
        in_dims: One entry per input binder; ``batching.not_mapped`` leaves
            that input's aval untouched.

    Returns:
        A triple ``(jaxpr, consts, out_dims())`` taken from
        ``trace_to_jaxpr_dynamic`` and the ``out_dims`` thunk returned by
        ``batching.batch_subtrace``.
    """
    dimvars = {v: v.aval for v in jaxpr.in_dim_binders}
    plain_avals = [
        _replace_vars_with_avals(dimvars, binder.aval)
        for binder in jaxpr.in_binders
    ]

    # NOTE(review): core.unmapped_aval is called with three arguments here
    # (size, dim, aval) and no axis name -- confirm against the signature of
    # the core module this function is used with.
    in_avals = []
    for dim, aval in zip(in_dims, plain_avals):
        if dim is batching.not_mapped:
            in_avals.append(aval)
        else:
            in_avals.append(core.unmapped_aval(axis_size, dim, aval))

    fun, out_dims = batching.batch_subtrace(lu.wrap_init(jaxpr_as_fun(jaxpr)))
    batched_fun = _batch_fun(fun, in_dims)

    # NOTE(review): the unqualified trace_to_jaxpr_dynamic is unpacked as
    # (jaxpr, consts, _) -- verify its return order.
    batched_jaxpr, consts, _ = trace_to_jaxpr_dynamic(batched_fun, in_avals)
    # out_dims is evaluated after the trace above has run.
    return batched_jaxpr, consts, out_dims()
def unmap_zero(zero, in_axis):
    """Return ``zero`` with its aval unmapped along ``in_axis``.

    ``params`` is not an argument: it is read from the surrounding scope
    (presumably the enclosing function's parameter dict -- confirm at the
    definition site) and must provide ``'axis_size'`` and ``'axis_name'``.

    Args:
        zero: Object exposing an ``aval`` attribute.
        in_axis: Axis to unmap along, or ``None`` to leave ``zero`` as is.

    Returns:
        ``zero`` itself when ``in_axis`` is ``None``; otherwise a new ``Zero``
        wrapping the aval returned by ``core.unmapped_aval``.
    """
    if in_axis is None:
        return zero
    return Zero(
        core.unmapped_aval(
            params['axis_size'], params['axis_name'], in_axis, zero.aval
        )
    )