def predict(self, Xnew, full_cov=False):
    """Return the likelihood's mean and (co)variance prediction at ``Xnew``.

    The whole prediction, including the optional diagonal extraction, runs
    inside ``jax.checking_leaks()``.

    Args:
        Xnew: Inputs, forwarded unchanged to ``self.likelihood.predict``.
        full_cov: When truthy, the variance is returned exactly as the
            likelihood produced it. Otherwise ``jnp.diag`` is applied to it.

    Returns:
        A ``(mean, variance)`` 2-tuple.
    """
    with jax.checking_leaks():
        # Evaluate the flag before calling the likelihood, as the original did.
        keep_full = bool(full_cov)
        mean, cov = self.likelihood.predict(Xnew)
        if keep_full:
            return mean, cov
        # NOTE(review): presumably `cov` is a full covariance matrix, so
        # jnp.diag yields the per-point variances; confirm against the likelihood.
        return mean, jnp.diag(cov)
def testCustomJVPLeak(self): # https://github.com/google/jax/issues/8171 @jax.jit def fwd(): a = jnp.array(1.) def f(hx, _): hx = jax.nn.sigmoid(hx + a) return hx, None hx = jnp.array(0.) jax.lax.scan(f, hx, None, length=2) with jax.checking_leaks(): fwd() # doesn't crash
def test_eval_shape_no_leaked_tracers_under_leak_checker(self):
    """``stateful.eval_shape`` must run under ``jax.checking_leaks()`` without raising."""
    with jax.checking_leaks():
        # Both values are built inside the leak-checking context, in the same
        # order as before (module first, then the scalar input).
        module = SquareModule()
        dummy_input = jnp.ones(())
        stateful.eval_shape(module, dummy_input)  # does not crash