def _reshape_axis_into(src, dst, x):
    """Merge axis ``src`` of ``x`` into the axis named by ``dst``.

    NB: ``dst`` is the number of the dimension that we should reshape into
    *after* ``src`` is removed from ``x``'s list of dimensions. For example, if
    ``src`` is an added batch dimension, ``dst`` might name a target dimension
    in the unbatched list of dimensions.

    Axis ``src`` is moved to sit immediately before the target dimension and
    the two are then flattened together by ``lax.reshape``. As a result,
    ``src`` becomes the major (outer) part of the merged axis.

    NOTE(review): this module defines ``_reshape_axis_into`` a second time
    further down. That later definition shadows this one at import time, so
    consider removing one of them.

    Args:
      src: axis of ``x`` to merge away. Negative values count from the end.
      dst: index of the target axis among the remaining ``x.ndim - 1`` axes.
      x: array with ``ndim`` and ``shape`` attributes.

    Returns:
      The ``lax.reshape`` result of rank ``x.ndim - 1``. In it, axis ``dst``
      has size ``x.shape[src]`` times the original size of that target axis.

    Raises:
      IndexError: if ``src`` is outside ``[-x.ndim, x.ndim)``.
    """
    ndim = x.ndim
    if not -ndim <= src < ndim:
        raise IndexError(
            f"src axis {src} is out of bounds for an array of rank {ndim}")
    # Normalize a negative axis so that the `i != src` filter below removes
    # the right dimension. The permutation must only contain indices in
    # [0, ndim).
    src %= ndim
    perm = [i for i in range(ndim) if i != src]
    perm.insert(dst, src)
    new_shape = list(np.delete(x.shape, src))
    new_shape[dst] *= x.shape[src]
    return lax.reshape(x, new_shape, perm)
def chooser_taylor_rule(primals_in, series_in, **params):
    """Propagate series terms through a reduction that selects elements.

    NOTE(review): the name suggests this is a Taylor-mode (jet) rule for a
    max/min-style reduction. Confirm this against where it is registered.

    ``chooser_fun`` is not defined in this function. It must come from an
    enclosing scope (presumably a factory that builds this rule; verify).

    Each incoming series term is multiplied by an indicator mask and summed
    over ``params["axes"]``. The sum is then divided by the mask's own sum
    over those axes. The mask comes from ``lax._eq_meet(operand, primal_out)``,
    which is presumably an elementwise equality test. If so, tied positions
    are averaged.

    Returns:
      ``(primal_out, series_out)``. ``primal_out`` is
      ``chooser_fun(operand, **params)``. ``series_out`` is a list with one
      reduced value per incoming series term.
    """
    (operand,) = primals_in
    (terms,) = series_in
    primal_out = chooser_fun(operand, **params)
    # `params` is a fresh per-call dict and is not used afterwards, so a plain
    # lookup is equivalent to popping the key.
    reduced_axes = params.get("axes")
    term_dtype = terms[0].dtype
    # Same rank as `operand`, with every reduced axis collapsed to size 1.
    keepdims_shape = [
        1 if axis in reduced_axes else extent
        for axis, extent in enumerate(operand.shape)
    ]
    # Indicator mask, cast to the dtype of the series terms.
    hit_mask = lax.convert_element_type(
        lax._eq_meet(operand, lax.reshape(primal_out, keepdims_shape)),
        term_dtype,
    )
    hit_count = lax._reduce_sum(hit_mask, reduced_axes)
    series_out = [
        lax.div(lax._reduce_sum(lax.mul(term, hit_mask), reduced_axes), hit_count)
        for term in terms
    ]
    return primal_out, series_out
def _reshape_axis_out_of(src, size1, x):
    """Split axis ``src`` of ``x`` into two consecutive axes.

    The axis at ``src`` (of size ``n``) is replaced by a pair of axes of sizes
    ``(size1, n // size1)``, in that order. The data is regrouped with
    ``lax.reshape``.

    Args:
      src: axis of ``x`` to split. Negative values count from the end.
      size1: size of the leading (outer) axis of the split pair.
      x: array with a ``shape`` attribute.

    Returns:
      The ``lax.reshape`` result, whose rank is one greater than that of ``x``.

    Raises:
      IndexError: if ``src`` is outside ``[-x.ndim, x.ndim)``.
      ValueError: if ``x.shape[src]`` is not divisible by ``size1``.
    """
    shape = list(x.shape)
    ndim = len(shape)
    if not -ndim <= src < ndim:
        raise IndexError(
            f"src axis {src} is out of bounds for an array of rank {ndim}")
    # A negative src would turn `shape[src:src + 1]` into an empty slice (for
    # example `shape[-1:0]`), which inserts the new sizes instead of replacing
    # the axis.
    src %= ndim
    size2, ragged = divmod(shape[src], size1)
    # Raise explicitly rather than assert: assertions are stripped under -O.
    if ragged:
        raise ValueError(
            f"axis {src} of size {shape[src]} is not divisible by {size1}")
    shape[src:src + 1] = [size1, size2]
    return lax.reshape(x, shape)
def _reshape_axis_into(src, dst, x):
    """Merge axis ``src`` of ``x`` into the axis named by ``dst``.

    NB: ``dst`` is the number of the dimension that we should reshape into
    *after* ``src`` is removed from ``x``'s list of dimensions. For example, if
    ``src`` is an added batch dimension, ``dst`` might name a target dimension
    in the unbatched list of dimensions.

    Axis ``src`` is moved to sit immediately before the target dimension and
    the two are then flattened together by ``lax.reshape``. As a result,
    ``src`` becomes the major (outer) part of the merged axis.

    Args:
      src: axis of ``x`` to merge away. Negative values count from the end.
      dst: index of the target axis among the remaining ``x.ndim - 1`` axes.
      x: array with ``ndim`` and ``shape`` attributes.

    Returns:
      The ``lax.reshape`` result of rank ``x.ndim - 1``. In it, axis ``dst``
      has size ``x.shape[src]`` times the original size of that target axis.

    Raises:
      IndexError: if ``src`` is outside ``[-x.ndim, x.ndim)``.
    """
    ndim = x.ndim
    if not -ndim <= src < ndim:
        raise IndexError(
            f"src axis {src} is out of bounds for an array of rank {ndim}")
    # Normalize a negative axis so that the `i != src` filter below removes
    # the right dimension. The permutation must only contain indices in
    # [0, ndim).
    src %= ndim
    perm = [i for i in range(ndim) if i != src]
    perm.insert(dst, src)
    new_shape = list(np.delete(x.shape, src))
    new_shape[dst] *= x.shape[src]
    return lax.reshape(x, new_shape, perm)