Ejemplo n.º 1
0
 def predict(self, Xnew, full_cov=False):
     """Predict at ``Xnew`` by delegating to ``self.likelihood.predict``.

     The delegated call runs inside ``jax.checking_leaks()``.

     Args:
         Xnew: Inputs forwarded unchanged to ``self.likelihood.predict``.
         full_cov: If true, the second output of the likelihood is returned
             as-is; otherwise ``jnp.diag`` is applied to it first.

     Returns:
         A ``(mu, var)`` pair, where ``mu`` is the likelihood's first output.
     """
     with jax.checking_leaks():
         mu, var = self.likelihood.predict(Xnew)
         # NOTE(review): presumably ``var`` is a square covariance matrix, so
         # ``jnp.diag`` extracts the per-point variances -- confirm.
         return mu, (var if full_cov else jnp.diag(var))
Ejemplo n.º 2
0
Archivo: nn_test.py Proyecto: gtr8/jax
  def testCustomJVPLeak(self):
    # https://github.com/google/jax/issues/8171
    @jax.jit
    def fwd():
      a = jnp.array(1.)

      def f(hx, _):
        hx = jax.nn.sigmoid(hx + a)
        return hx, None

      hx = jnp.array(0.)
      jax.lax.scan(f, hx, None, length=2)

    with jax.checking_leaks():
      fwd()  # doesn't crash
Ejemplo n.º 3
0
 def test_eval_shape_no_leaked_tracers_under_leak_checker(self):
   """Check ``stateful.eval_shape`` completes with leak checking enabled."""
   with jax.checking_leaks():
     module = SquareModule()
     scalar_one = jnp.ones(())
     stateful.eval_shape(module, scalar_one)  # does not crash