def chooser_taylor_rule(primals_in, series_in, **params):
    """Push Taylor series terms through a reduce-and-choose primitive.

    The primal result comes from ``chooser_fun``, a name resolved outside this
    block (presumably an enclosing factory's argument or a module global).
    Each series term is then reduced by averaging it over every operand
    position whose value equals the chosen output along ``axes``. When
    several positions tie, they share the term equally.

    Args:
      primals_in: one-element sequence holding the operand.
      series_in: one-element sequence holding the list of series terms.
      **params: forwarded unchanged to ``chooser_fun``. ``params["axes"]`` is
        also read here. If it is absent, the membership test below raises
        ``TypeError``.

    Returns:
      ``(primal_out, series_out)``, where ``series_out`` has one entry per
      input series term.
    """
    (x,) = primals_in
    (terms,) = series_in
    primal_out = chooser_fun(x, **params)
    # ``params`` is a fresh dict built from **params, so reading the key here
    # is equivalent to popping it.
    reduced_axes = params.get("axes")
    term_dtype = terms[0].dtype
    # Keep-dims shape: reduced axes become size 1 so the chosen output can be
    # broadcast back against the operand.
    keepdims_shape = [
        1 if axis in reduced_axes else size for axis, size in enumerate(x.shape)
    ]
    # 1 where the operand equals the chosen value, 0 elsewhere, cast to the
    # series dtype. ``lax._eq_meet`` is assumed to be an elementwise equality
    # (TODO confirm).
    is_chosen = lax.convert_element_type(
        lax._eq_meet(x, lax.reshape(primal_out, keepdims_shape)), term_dtype)
    # Number of tied positions per output element; used as the divisor.
    n_chosen = lax._reduce_sum(is_chosen, reduced_axes)
    series_out = [
        lax.div(lax._reduce_sum(lax.mul(term, is_chosen), reduced_axes), n_chosen)
        for term in terms
    ]
    return primal_out, series_out
def _reduce_chooser_taylor_rule(g):
    """Average series term ``g`` over the positions marked by ``location_indicators``.

    ``location_indicators``, ``axes`` and ``counts`` are free names resolved
    outside this block (NOTE(review): in SOURCE they only appear as locals of
    ``chooser_taylor_rule``; confirm they exist at this scope).
    """
    masked = lax.mul(g, location_indicators)
    masked_total = lax._reduce_sum(masked, axes)
    return lax.div(masked_total, counts)
def _psum_serial_pmap_rule(vals, axes):
    """Sum the sole value along its sole axis.

    Args:
      vals: iterable containing exactly one array.
      axes: iterable containing exactly one axis index.

    Returns:
      ``(summed, None)``. The second element is always ``None``; presumably
      this means the result carries no mapped axis (TODO confirm against the
      pmap rule contract).

    Raises:
      ValueError: if ``vals`` or ``axes`` does not hold exactly one item.
    """
    [val] = vals
    [axis] = axes
    summed = lax._reduce_sum(val, [axis])
    return summed, None
def f(x, y):
    """Return a 1-tuple holding ``sum(sin(x * y))`` reduced over axis 0."""
    return (lax._reduce_sum(lax.sin(lax.mul(x, y)), [0]),)