示例#1
0
文件: jet.py 项目: yangliuy/jax
 def chooser_taylor_rule(primals_in, series_in, **params):
   """Push a Taylor series through the reduction performed by ``chooser_fun``.

   ``chooser_fun`` is a free variable supplied by the enclosing scope. The
   primal output is ``chooser_fun(operand, **params)``. Each series term is
   reduced by summing it over ``axes`` at the positions where the operand
   matches the primal output (per ``lax._eq_meet``), then dividing by the
   number of matching positions.

   Args:
     primals_in: one-element sequence holding the operand.
     series_in: one-element sequence holding the series terms for the operand.
     **params: keyword arguments forwarded unchanged to ``chooser_fun``; the
       ``"axes"`` entry is also used here as the reduction axes.

   Returns:
     A ``(primal_out, series_out)`` pair, where ``series_out`` is a list with
     one reduced term per entry of the input series.
   """
   (operand,) = primals_in
   (gs,) = series_in
   primal_out = chooser_fun(operand, **params)
   axes = params.pop("axes", None)
   primal_dtype = gs[0].dtype
   # Operand shape with every reduced axis collapsed to size 1, used to reshape
   # the primal output before comparing it with the operand.
   keepdims_shape = [d if i not in axes else 1
                     for i, d in enumerate(operand.shape)]
   matches = lax._eq_meet(operand, lax.reshape(primal_out, keepdims_shape))
   # Cast the match mask to the series dtype so it can weight the series terms.
   location_indicators = lax.convert_element_type(matches, primal_dtype)
   counts = lax._reduce_sum(location_indicators, axes)
   series_out = [
       lax.div(lax._reduce_sum(lax.mul(g, location_indicators), axes), counts)
       for g in gs
   ]
   return primal_out, series_out
示例#2
0
def psum(x, axis_name, *, axis_index_groups=None):
    """Compute an all-reduce sum on ``x`` over the pmapped axis ``axis_name``.

    When ``x`` is a pytree, the result is the same as applying this function
    to every leaf of the tree.

    Boolean leaves are converted to ``int32`` before the reduction.

    Args:
      x: array(s) with a mapped axis named ``axis_name``.
      axis_name: hashable Python object used to name a pmapped axis (see the
        :func:`jax.pmap` documentation for more details).
      axis_index_groups: optional list of lists containing axis indices (e.g.
        for an axis of size 4, [[0, 1], [2, 3]] would perform psums over the
        first two and last two replicas). Groups must cover all axis indices
        exactly once, and all groups must be the same size.

    Returns:
      Array(s) with the same shape as ``x`` representing the result of an
      all-reduce sum along the axis ``axis_name``.

    For example, with 4 XLA devices available:

    >>> x = np.arange(4)
    >>> y = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(x)
    >>> print(y)
    [6 6 6 6]
    >>> y = jax.pmap(lambda x: x / jax.lax.psum(x, 'i'), axis_name='i')(x)
    >>> print(y)
    [ 0.          0.16666667  0.33333334  0.5       ]
    """
    _validate_axis_index_groups(axis_index_groups)
    flat_inputs, structure = tree_util.tree_flatten(x)

    def _promote_bool(leaf):
        # Boolean leaves are summed as int32; everything else passes through.
        if dtypes.dtype(leaf) == np.bool_:
            return lax.convert_element_type(leaf, np.int32)
        return leaf

    operands = [_promote_bool(leaf) for leaf in flat_inputs]
    flat_outputs = psum_p.bind(
        *operands,
        axis_name=axis_name,
        axis_index_groups=axis_index_groups,
    )
    return tree_util.tree_unflatten(structure, flat_outputs)
示例#3
0
def _convert_element_type_papply_rule(name, size, vals, dims, new_dtype,
                                      **params):
    """Papply rule for ``convert_element_type``.

    Converts the single operand to ``new_dtype`` and returns it together with
    the operand's dimension entry, unchanged. ``name``, ``size`` and any extra
    ``params`` are accepted but not used.

    Args:
      name: unused here.
      size: unused here.
      vals: one-element sequence holding the operand.
      dims: one-element sequence holding the operand's dimension entry.
      new_dtype: target element type for the conversion.
      **params: ignored.

    Returns:
      A ``(converted_operand, dim)`` pair.
    """
    # Both sequences must hold exactly one element; unpacking enforces that.
    (operand,) = vals
    (dim,) = dims
    converted = lax.convert_element_type(operand, new_dtype)
    return converted, dim