def test_fori_traced_length(self):
  """fori_loop with traced bounds runs the module once per iteration."""
  counter = CountingModule()

  def body_fn(i, carry):
    del carry  # The loop carry is ignored; each step just calls the module.
    return counter(i)

  def loop(lower, upper):
    return stateful.fori_loop(lower, upper, body_fn, 2)

  # Wrapping in `stateful.jit` means `lower` and `upper` arrive at
  # `fori_loop` as tracers rather than concrete Python ints.
  result = stateful.jit(loop)(0, 3)
  self.assertEqual(result, 4)
  self.assertEqual(counter.count, 3)
def testEmaUpdating(self, use_jit, dtype):
  """Checks EMA embeddings change when training and stay fixed otherwise.

  Args:
    use_jit: if True, the quantizer is wrapped in `stateful.jit` with
      positional argument 1 (the training flag) marked static.
    dtype: JAX dtype passed to the quantizer; `decay` and the random inputs
      use the corresponding NumPy dtype.
  """
  if jax.local_devices()[0].platform == 'tpu' and dtype == jnp.float64:
    self.skipTest('F64 not supported by TPU')

  embedding_dim = 6
  # NumPy counterpart of `dtype`, used for every host-side array we build so
  # that `decay` and the inputs agree on precision.
  np_dtype = np.float64 if dtype == jnp.float64 else np.float32
  decay = np.array(0.1, dtype=np_dtype)
  vqvae_module = vqvae.VectorQuantizerEMA(
      embedding_dim=embedding_dim,
      num_embeddings=7,
      commitment_cost=0.5,
      decay=decay,
      dtype=dtype)
  if use_jit:
    # The second positional argument is a Python bool, so mark it static.
    vqvae_f = stateful.jit(vqvae_module, static_argnums=1)
  else:
    vqvae_f = vqvae_module
  batch_size = 16

  prev_embeddings = vqvae_module.embeddings

  # Embeddings should change with every forwards pass if is_training == True.
  for _ in range(10):
    inputs = np.random.rand(batch_size, embedding_dim).astype(np_dtype)
    vqvae_f(inputs, True)
    current_embeddings = vqvae_module.embeddings
    self.assertFalse((prev_embeddings == current_embeddings).all())
    prev_embeddings = current_embeddings

  # Forward passes with is_training == False don't change anything.
  for _ in range(10):
    inputs = np.random.rand(batch_size, embedding_dim).astype(np_dtype)
    vqvae_f(inputs, False)
    current_embeddings = vqvae_module.embeddings
    self.assertTrue((current_embeddings == prev_embeddings).all())
def test_jit_no_transform(self):
  """Using `stateful.jit` with no enclosing transform raises ValueError.

  The error text is expected to point the user at `jax.jit()`.
  """
  x = jnp.array(2)
  # `assertRaises(..., msg=...)` only sets the failure message shown when no
  # exception is raised; it never inspects the exception text.
  # `assertRaisesRegex` actually verifies the message.
  with self.assertRaisesRegex(ValueError, r"Use jax\.jit\(\) instead"):
    stateful.jit(lambda v: v ** 2)(x)
def test_jit(self):
  """A module wrapped in `stateful.jit` returns the square of its input."""
  square = SquareModule()
  jitted_square = stateful.jit(square)
  value = jnp.array(2)
  self.assertEqual(jitted_square(value), value ** 2)
def g(x, jit=False):
  """Builds a module with `module_fn` and applies it to `x`.

  When `jit` is True the module is first wrapped in `stateful.jit`.
  """
  mod = module_fn()
  apply_fn = stateful.jit(mod) if jit else mod
  return apply_fn(x)