Beispiel #1
0
 def add_ildj(invals, z, z_ildj):
   """Partial inverse/ILDJ rule for `z = x + y` that can only recover `y`.

   Args:
     invals: pair `(x, y)` of input values, with `None` marking an unknown.
     z: the output value, or `None` when it is unknown.
     z_ildj: ILDJ term for `z`; handed through to `y` when `y` is solved for.

   Returns:
     `((x, y), (x_ildj, y_ildj))`. When only `x` is known, `y = z - x`, the
     ILDJ for `x` is zeros and all of `z_ildj` goes to `y`. When both inputs
     are known or both are unknown, `((None, None), (0., 0.))`.

   Raises:
     custom_inverse.NonInvertibleError: if `z` is unknown, or if only `y`
       is known (this rule deliberately does not solve for `x`).
   """
   lhs, rhs = invals
   lhs_known = lhs is not None
   rhs_known = rhs is not None
   if z is None or (rhs_known and not lhs_known):
     # Nothing can be recovered without z. Knowing only y is also rejected,
     # because this rule intentionally refuses to solve for x.
     raise custom_inverse.NonInvertibleError()
   if lhs_known and not rhs_known:
     return (lhs, z - lhs), (jnp.zeros_like(z_ildj), z_ildj)
   return (None, None), (0., 0.)
Beispiel #2
0
 def dense_ildj2(invals, out, out_ildj):
   """Custom inverse/ILDJ rule for a dense op, parameterised by `(w, b)`.

   NOTE(review): the recovered input is `w * out` and the input ILDJ term
   is `ones_like(out)`. This is not the analytic inverse of `w * x + b`,
   so it is presumably a deliberately distinct rule (e.g. for testing that
   a custom rule is used) -- confirm against callers.

   Args:
     invals: `((w, b), x)`; the `x` slot is ignored. `None` marks an unknown.
     out: the output value.
     out_ildj: ILDJ term accumulated for `out`.

   Returns:
     `(((w, b), w * out),
       ((zeros_like(w), zeros_like(b)), out_ildj + ones_like(out)))`.

   Raises:
     custom_inverse.NonInvertibleError: if either `w` or `b` is unknown.
   """
   (w, b), _ = invals
   # Both parameters are used below (`w * out`, `zeros_like(w)`,
   # `zeros_like(b)`). A single missing one must therefore be reported as
   # non-invertible rather than failing on a `None` operand. The original
   # `and` guard only caught the case where both were missing.
   if w is None or b is None:
     raise custom_inverse.NonInvertibleError()
   in_ildj = jnp.ones_like(out)
   return ((w, b), w * out), ((jnp.zeros_like(w), jnp.zeros_like(b)),
                              out_ildj + in_ildj)
Beispiel #3
0
 def dense_ildj(invals, out, out_ildj):
   """Inverse/ILDJ rule for the affine map `out = w * x + b`.

   Args:
     invals: `((w, b), x)`; the `x` slot is ignored. `None` marks an unknown.
     out: the observed output.
     out_ildj: ILDJ term accumulated for `out`.

   Returns:
     `(((w, b), (out - b) / w), ((zeros_like(w), zeros_like(b)), total))`,
     where `total` is `out_ildj` plus the ILDJ of `x -> w * x + b` as
     computed by `core.ildj` at `out`.

   Raises:
     custom_inverse.NonInvertibleError: if either `w` or `b` is unknown.
   """
   param_pair, _ = invals
   weight, bias = param_pair
   if any(p is None for p in (weight, bias)):
     raise custom_inverse.NonInvertibleError()
   # ILDJ contribution of the forward affine map, evaluated at the output.
   affine_ildj = core.ildj(lambda x: weight * x + bias)(out)
   x_recovered = (out - bias) / weight
   # Parameters are treated as fixed, so they carry no ILDJ of their own.
   param_ildjs = (jnp.zeros_like(weight), jnp.zeros_like(bias))
   total_ildj = out_ildj + affine_ildj
   return ((weight, bias), x_recovered), (param_ildjs, total_ildj)
Beispiel #4
0
 def add_ildj(invals, z, z_ildj):
   """Inverse/ILDJ rule for `z = x + y` that solves for the missing addend.

   Args:
     invals: pair `(x, y)` of input values, with `None` marking an unknown.
     z: the output value, or `None` when it is unknown.
     z_ildj: ILDJ term for `z`; assigned to whichever addend is solved for.

   Returns:
     `((x, y), (x_ildj, y_ildj))`. If exactly one addend is unknown, it is
     recovered as `z` minus the known one and receives `z_ildj`, while the
     known addend gets zeros. Otherwise `((None, None), (0., 0.))`.

   Raises:
     custom_inverse.NonInvertibleError: if `z` is unknown.
   """
   lhs, rhs = invals
   if z is None:
     raise custom_inverse.NonInvertibleError()
   missing = (lhs is None, rhs is None)
   if missing == (False, True):
     return (lhs, z - lhs), (jnp.zeros_like(z_ildj), z_ildj)
   if missing == (True, False):
     return (z - rhs, rhs), (z_ildj, jnp.zeros_like(z_ildj))
   # Both inputs known or both unknown: there is nothing to solve for.
   return (None, None), (0., 0.)
Beispiel #5
0
 def add_one_inv(_):
   """Inverse hook that always reports non-invertibility.

   The single argument is ignored.

   Raises:
     custom_inverse.NonInvertibleError: unconditionally.
   """
   error = custom_inverse.NonInvertibleError()
   raise error