def chooser_taylor_rule(primals_in, series_in, **params):
    """Propagate Taylor-series terms through a "chooser" reduction.

    The primal result is ``chooser_fun(operand, **params)``. For each series
    term, only the positions where ``operand`` matches the chosen primal value
    are kept. A match is whatever ``lax._eq_meet`` reports against the
    broadcast primal output. The kept entries are summed over ``axes`` and
    divided by the number of matching positions, so tied positions share the
    contribution equally.

    NOTE(review): ``chooser_fun`` is not a parameter. It is a free name that
    must be bound in an enclosing or module scope. It presumably reduces
    ``operand`` over ``params["axes"]`` (e.g. a max/min); confirm at the
    definition site.

    Args:
      primals_in: one-element sequence holding the operand.
      series_in: one-element sequence holding the list of series terms for
        that operand.
      **params: forwarded unchanged to ``chooser_fun``. ``axes`` is also read
        here to build the keep-dims shape and the reductions.

    Returns:
      ``(primal_out, series_out)``, where ``series_out`` holds one reduced
      term per input series term.
    """
    (operand,) = primals_in
    (terms,) = series_in
    primal_out = chooser_fun(operand, **params)
    # ``params`` is a fresh dict built from **params and is not used again,
    # so reading the key is equivalent to popping it.
    reduced_axes = params.get("axes")
    term_dtype = terms[0].dtype
    # Keep each reduced dimension as size 1 so the primal output lines up
    # with the operand for the elementwise comparison.
    keepdims_shape = [
        1 if axis in reduced_axes else size
        for axis, size in enumerate(operand.shape)
    ]
    broadcast_out = lax.reshape(primal_out, keepdims_shape)
    # Indicator of matching positions, cast to the series dtype so it can
    # multiply the series terms directly.
    is_chosen = lax.convert_element_type(
        lax._eq_meet(operand, broadcast_out), term_dtype)
    num_chosen = lax._reduce_sum(is_chosen, reduced_axes)
    return primal_out, [
        lax.div(
            lax._reduce_sum(lax.mul(term, is_chosen), reduced_axes),
            num_chosen,
        )
        for term in terms
    ]
def psum(x, axis_name, *, axis_index_groups=None):
    """All-reduce sum of ``x`` across the pmapped axis named ``axis_name``.

    When ``x`` is a pytree, every leaf is reduced independently and the
    original tree structure is rebuilt around the results. Leaves with a
    boolean dtype are cast to ``int32`` before being summed.

    Args:
      x: array or pytree of arrays carrying a mapped axis called
        ``axis_name``.
      axis_name: hashable Python object naming a pmapped axis (see the
        :func:`jax.pmap` documentation for details).
      axis_index_groups: optional list of lists of axis indices. For an axis
        of size 4, ``[[0, 1], [2, 3]]`` sums over the first two replicas and,
        separately, over the last two. Groups must cover all axis indices
        exactly once, and all groups must be the same size.

    Returns:
      Array(s) shaped like ``x`` that hold the all-reduce sum along
      ``axis_name``.

    For example, with 4 XLA devices available:

    >>> x = np.arange(4)
    >>> y = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(x)
    >>> print(y)
    [6 6 6 6]
    >>> y = jax.pmap(lambda x: x / jax.lax.psum(x, 'i'), axis_name='i')(x)
    >>> print(y)
    [ 0. 0.16666667 0.33333334 0.5 ]
    """
    _validate_axis_index_groups(axis_index_groups)
    flat_inputs, tree_def = tree_util.tree_flatten(x)

    def _as_summable(leaf):
        # Boolean leaves are promoted to int32 before the reduction.
        if dtypes.dtype(leaf) == np.bool_:
            return lax.convert_element_type(leaf, np.int32)
        return leaf

    summed = psum_p.bind(
        *map(_as_summable, flat_inputs),
        axis_name=axis_name,
        axis_index_groups=axis_index_groups,
    )
    return tree_util.tree_unflatten(tree_def, summed)
def _convert_element_type_papply_rule(name, size, vals, dims, new_dtype, **params): operand, = vals dim, = dims return lax.convert_element_type(operand, new_dtype), dim