예제 #1
0
파일: batching.py 프로젝트: wayfeng/jax
 def aval(self):
     aval = raise_to_shaped(core.get_aval(self.val))
     if self.batch_dim is not_mapped or aval is core.abstract_unit:
         return aval
     else:
         return core.mapped_aval(aval.shape[self.batch_dim], self.batch_dim,
                                 aval)
예제 #2
0
파일: batching.py 프로젝트: xueeinstein/jax
def _matchaxis_symbolic_zeros(axis_name, sz, name, src, dst, x, sum_match=False):
  # Just like `matchaxis`, but handles symbolic zeros using ad_util.py
  # TODO(mattjj): dedup with matchaxis
  if isinstance(x, Zero):
    if src == dst:
      return x
    elif type(src) == type(dst) == int:
      aval = core.mapped_aval(sz, src, x.aval)
      return Zero(core.unmapped_aval(sz, name, dst, aval))
    elif src is not_mapped and dst is not not_mapped:
      return Zero(core.unmapped_aval(sz, name, dst, x.aval))
    elif dst is not_mapped and sum_match:
      return Zero(core.mapped_aval(sz, src, x.aval))
    else:
      raise ValueError((axis_name, x, src, dst))
  else:
    return matchaxis(axis_name, sz, src, dst, x, sum_match=sum_match)
예제 #3
0
파일: batching.py 프로젝트: xueeinstein/jax
 def aval(self):
   aval = raise_to_shaped(core.get_aval(self.val))
   return core.mapped_aval(aval.shape[self.batch_dim], self.batch_dim, aval)
예제 #4
0
파일: unzip.py 프로젝트: yli96/probability
def mapped_aval(*args, **kwargs):
    return jax_core.mapped_aval(*args, **kwargs)