Ejemplo n.º 1
0
 def chooser_taylor_rule(primals_in, series_in, **params):
   """Propagate series terms through a reduction computed by ``chooser_fun``.

   Args:
     primals_in: one-element sequence holding the operand.
     series_in: one-element sequence holding the operand's series terms
       (an indexable, non-empty sequence; its first entry supplies the dtype).
     **params: forwarded unchanged to ``chooser_fun``; the ``"axes"`` entry
       is also read here as the set of reduced axes.

   Returns:
     ``(primal_out, series_out)`` where ``primal_out`` is
     ``chooser_fun(operand, **params)`` and ``series_out`` is a list with one
     entry per input term: the term masked by the location indicators,
     summed over ``axes`` and divided by the indicator count.
   """
   (operand,) = primals_in
   (terms,) = series_in
   primal_out = chooser_fun(operand, **params)
   axes = params.pop("axes", None)
   term_dtype = terms[0].dtype

   # Give the reduced axes extent 1 so primal_out has the operand's rank.
   keepdims_shape = []
   for axis, extent in enumerate(operand.shape):
     keepdims_shape.append(1 if axis in axes else extent)

   # Indicator (cast to the series dtype) of where operand meets the
   # reshaped primal output, as reported by lax._eq_meet.
   matches = lax._eq_meet(operand, lax.reshape(primal_out, keepdims_shape))
   indicators = lax.convert_element_type(matches, term_dtype)
   counts = lax._reduce_sum(indicators, axes)

   def _reduce_chooser_taylor_rule(g):
     # Sum of g over the indicated positions, normalised by their count.
     masked_total = lax._reduce_sum(lax.mul(g, indicators), axes)
     return lax.div(masked_total, counts)

   series_out = list(map(_reduce_chooser_taylor_rule, terms))
   return primal_out, series_out
Ejemplo n.º 2
0
Archivo: jet.py Proyecto: 0x0is1/jax
 def _reduce_chooser_taylor_rule(g):
     """Return ``g`` masked by ``location_indicators``, summed over ``axes``
     and divided by ``counts``.

     ``location_indicators``, ``axes`` and ``counts`` are taken from the
     enclosing scope.
     """
     masked = lax.mul(g, location_indicators)
     masked_total = lax._reduce_sum(masked, axes)
     return lax.div(masked_total, counts)
Ejemplo n.º 3
0
def _psum_serial_pmap_rule(vals, axes):
    """Sum the single value in ``vals`` over the single axis in ``axes``.

    Both arguments must be one-element sequences; unpacking raises
    ``ValueError`` otherwise.

    Returns:
        A pair ``(summed, None)`` where ``summed`` is
        ``lax._reduce_sum(val, [axis])``. The second element is always
        ``None`` (presumably the output axis -- confirm against the caller).
    """
    (val,) = vals
    (axis,) = axes
    summed = lax._reduce_sum(val, [axis])
    return summed, None
Ejemplo n.º 4
0
 def f(x, y):
     """Return a 1-tuple holding ``sin(x * y)`` summed over axis 0."""
     product = lax.mul(x, y)
     total = lax._reduce_sum(lax.sin(product), [0])
     return (total,)