示例#1
0
def _reshape_axis_into(src, dst, x):
  """Merge axis `src` of `x` into another axis and return the reshaped array.

  Args:
    src: index of the axis of `x` to eliminate. Negative values count from
      the end of `x.shape`.
    dst: index of the axis that `src` is merged into, counted in the list of
      `x`'s dimensions *after* `src` has been removed. Negative values count
      from the end of that shortened list.
    x: the array to reshape.

  Returns:
    An array with `x.ndim - 1` dimensions whose axis `dst` has size
    `x.shape[src] * shape_without_src[dst]`. Because `src` is moved to sit
    immediately before the target axis, it becomes the major (slower-varying)
    factor of the merged axis.

  Raises:
    IndexError: if `src` or `dst` is out of range.
  """
  # NB: `dst` is the number of the dimension that we should reshape into
  # *after* `src` is removed from `x`'s list of dimensions. For example, if
  # `src` is an added batch dimension, `dst` might name a target dimension in
  # the unbatched list of dimensions.
  ndim = x.ndim
  # Normalize negative indices first: neither the `i != src` filter nor
  # `list.insert` interpret negative axes the way NumPy indexing does.
  if src < 0:
    src += ndim
  if not 0 <= src < ndim:
    raise IndexError(f"src axis out of range for array of rank {ndim}")
  if dst < 0:
    dst += ndim - 1
  if not 0 <= dst < ndim - 1:
    raise IndexError(
        f"dst axis out of range for {ndim - 1} remaining dimensions")
  perm = [i for i in range(ndim) if i != src]
  perm.insert(dst, src)
  new_shape = [d for i, d in enumerate(x.shape) if i != src]
  new_shape[dst] *= x.shape[src]
  return lax.reshape(x, new_shape, perm)
示例#2
0
文件: jet.py 项目: jamestwebber/jax
 def chooser_taylor_rule(primals_in, series_in, **params):
   """Jet (Taylor-series) rule for a reduction that selects input elements.

   `chooser_fun` is a free variable supplied by the enclosing scope (not
   visible here); presumably a max/min-style reduction -- TODO confirm.
   Each series term is summed over the positions where `operand` equals the
   broadcast primal output, then divided by the number of such positions, so
   tied positions share the coefficient evenly.

   Args:
     primals_in: one-element sequence holding the operand.
     series_in: one-element sequence holding the operand's series terms.
     **params: primitive parameters forwarded to `chooser_fun`; expected to
       include "axes", the reduced axis indices.

   Returns:
     A `(primal_out, series_out)` pair with one `series_out` entry per
     input series term.
   """
   (operand,) = primals_in
   (terms,) = series_in
   primal_out = chooser_fun(operand, **params)
   # Popped only after the call above, because `chooser_fun` consumes "axes".
   reduce_axes = params.pop("axes", None)
   term_dtype = terms[0].dtype
   # Operand shape with every reduced axis collapsed to 1, so `primal_out`
   # can be compared elementwise against `operand`.
   kept_shape = [
       1 if axis in reduce_axes else size
       for axis, size in enumerate(operand.shape)
   ]
   hits = lax._eq_meet(operand, lax.reshape(primal_out, kept_shape))
   location_indicators = lax.convert_element_type(hits, term_dtype)
   counts = lax._reduce_sum(location_indicators, reduce_axes)

   series_out = []
   for term in terms:
     picked = lax._reduce_sum(lax.mul(term, location_indicators), reduce_axes)
     series_out.append(lax.div(picked, counts))
   return primal_out, series_out
示例#3
0
def _reshape_axis_out_of(src, size1, x):
    """Split axis `src` of `x` into two adjacent axes of sizes `(size1, size2)`.

    `size2` is `x.shape[src] // size1`. With a row-major reshape, `size1` is
    the major (slower-varying) factor of the split.

    Args:
      src: index of the axis to split. Negative values count from the end.
      size1: size of the leading new axis; must evenly divide `x.shape[src]`.
      x: the array to reshape.

    Returns:
      An array with `x.ndim + 1` dimensions.

    Raises:
      IndexError: if `src` is out of range.
      ValueError: if `x.shape[src]` is not divisible by `size1`.
      ZeroDivisionError: if `size1` is 0.
    """
    shape = list(x.shape)
    ndim = len(shape)
    # Slice assignment does not treat negative indices like NumPy axes: for
    # src == -1, `shape[-1:0]` is empty, so the axis would be duplicated
    # rather than replaced. Normalize first.
    if src < 0:
        src += ndim
    if not 0 <= src < ndim:
        raise IndexError(f"src axis out of range for array of rank {ndim}")
    size2, ragged = divmod(shape[src], size1)
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if ragged:
        raise ValueError(
            f"axis {src} of size {shape[src]} is not divisible by {size1}")
    shape[src:src + 1] = [size1, size2]
    return lax.reshape(x, shape)
示例#4
0
def _reshape_axis_into(src, dst, x):
    """Merge axis `src` of `x` into another axis and return the reshaped array.

    `dst` names the target axis in the list of `x`'s dimensions *after*
    `src` has been removed (e.g. a dimension of the unbatched shape when
    `src` is an added batch dimension).

    Args:
      src: index of the axis of `x` to eliminate. Negative values count from
        the end of `x.shape`.
      dst: index of the target axis among the remaining dimensions. Negative
        values count from the end of that shortened list.
      x: the array to reshape.

    Returns:
      An array with `x.ndim - 1` dimensions whose axis `dst` has size
      `x.shape[src] * shape_without_src[dst]`, with `src` as the major
      (slower-varying) factor because it is placed just before the target.

    Raises:
      IndexError: if `src` or `dst` is out of range.
    """
    ndim = x.ndim
    # Normalize negative indices first: neither the `i != src` filter nor
    # `list.insert` interpret negative axes the way NumPy indexing does.
    if src < 0:
        src += ndim
    if not 0 <= src < ndim:
        raise IndexError(f"src axis out of range for array of rank {ndim}")
    if dst < 0:
        dst += ndim - 1
    if not 0 <= dst < ndim - 1:
        raise IndexError(
            f"dst axis out of range for {ndim - 1} remaining dimensions")
    perm = [i for i in range(ndim) if i != src]
    perm.insert(dst, src)
    new_shape = [d for i, d in enumerate(x.shape) if i != src]
    new_shape[dst] *= x.shape[src]
    return lax.reshape(x, new_shape, perm)