Example #1
def _matchaxis_symbolic_zeros(axis_name, sz, name, src, dst, x, sum_match=False):
  # Just like `matchaxis`, but handles symbolic zeros using ad_util.py
  # TODO(mattjj): dedup with matchaxis
  if isinstance(x, Zero):
    if src == dst:
      return x
    elif type(src) == type(dst) == int:
      aval = core.mapped_aval(sz, src, x.aval)
      return Zero(core.unmapped_aval(sz, name, dst, aval))
    elif src is not_mapped and dst is not not_mapped:
      return Zero(core.unmapped_aval(sz, name, dst, x.aval))
    elif dst is not_mapped and sum_match:
      return Zero(core.mapped_aval(sz, src, x.aval))
    else:
      raise ValueError((axis_name, x, src, dst))
  else:
    return matchaxis(axis_name, sz, src, dst, x, sum_match=sum_match)
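All of these snippets lean on the same shape bookkeeping: core.mapped_aval strips a mapped (batch) dimension from an abstract value, and core.unmapped_aval inserts one. As a point of reference, here is a minimal sketch of that round trip, assuming the unmapped_aval(size, axis_name, axis, aval) signature these snippets call; the concrete shapes and the 'i' axis name are illustrative only, not taken from the original code.

import numpy as np
from jax import core

# Unbatched abstract value: f32[3, 4].
aval = core.ShapedArray((3, 4), np.dtype('float32'))

# unmapped_aval inserts a batch dimension of size 8 at position 0 -> f32[8, 3, 4].
batched = core.unmapped_aval(8, 'i', 0, aval)
assert batched.shape == (8, 3, 4)

# mapped_aval removes that dimension again, recovering f32[3, 4].
unbatched = core.mapped_aval(8, 0, batched)
assert unbatched.shape == (3, 4)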
Example #2
def _batch_jaxpr_axes(closed_jaxpr, axis_size, in_axes, out_axes_dest,
                      axis_name, main_type):
  f = lu.wrap_init(core.jaxpr_as_fun(closed_jaxpr))
  f, out_batched = _batch_jaxpr_inner(f, axis_size, out_axes_dest)
  f = _batch_jaxpr_outer(f, axis_name, axis_size, in_axes, main_type)
  avals_in = [core.unmapped_aval(axis_size, axis_name, b, aval) if b is not not_mapped
              else aval for aval, b in zip(closed_jaxpr.in_avals, in_axes)]
  jaxpr_out, _, consts = pe.trace_to_jaxpr_dynamic(f, avals_in)
  return core.ClosedJaxpr(jaxpr_out, consts), out_batched()
Example #3
def _update_annotation(
    f: lu.WrappedFun,
    orig_type: Optional[Tuple[Tuple[core.AbstractValue, bool], ...]],
    axis_size: int, axis_name: core.AxisName, in_dims: Sequence[Optional[int]]
  ) -> lu.WrappedFun:
  if orig_type is None:
    return f
  batched_in_type = [(core.unmapped_aval(axis_size, axis_name, dim, aval), keep)
                     for dim, (aval, keep) in zip(in_dims, orig_type)]
  return lu.annotate(f, tuple(batched_in_type))
Example #4
def batch_jaxpr(jaxpr, axis_size, in_dims):
  dimvars = dict((v, v.aval) for v in jaxpr.in_dim_binders)
  in_avals = [_replace_vars_with_avals(dimvars, v.aval) for v in jaxpr.in_binders]

  in_avals = [core.unmapped_aval(axis_size, d, aval)
              if d is not batching.not_mapped else aval
              for d, aval in zip(in_dims, in_avals)]

  fun, out_dims = batching.batch_subtrace(lu.wrap_init(jaxpr_as_fun(jaxpr)))
  f = _batch_fun(fun, in_dims)
  jaxpr, consts, _ = trace_to_jaxpr_dynamic(f, in_avals)
  return jaxpr, consts, out_dims()
Example #5
def unmap_zero(zero, in_axis):
  return (zero if in_axis is None else
          Zero(core.unmapped_aval(params['axis_size'], params['axis_name'],
                                  in_axis, zero.aval)))
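This last snippet is a closure lifted out of its enclosing rule, so params and the imports are not shown; params would be the map primitive's parameter dict carrying axis_size and axis_name. Below is a self-contained sketch of the same idea, where the stand-in params values and the import paths are assumptions rather than part of the original:

import numpy as np
from jax import core
from jax.interpreters.ad import Zero  # symbolic-zero tangent/cotangent

# Stand-in for the enclosing rule's params dict (illustrative values only).
params = {'axis_size': 8, 'axis_name': 'i'}

def unmap_zero(zero, in_axis):
  return (zero if in_axis is None else
          Zero(core.unmapped_aval(params['axis_size'], params['axis_name'],
                                  in_axis, zero.aval)))

ct = Zero(core.ShapedArray((3, 4), np.dtype('float32')))
assert unmap_zero(ct, 0).aval.shape == (8, 3, 4)  # batch axis restored at dim 0
assert unmap_zero(ct, None) is ct                 # unbatched zeros pass through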