Example #1
0
  def test_fori_traced_length(self):
    """Checks stateful.fori_loop when its bounds are jit tracers.

    The loop body ignores the carry and returns the module output for the
    loop index. The test checks the final loop value and that the module
    was called once per iteration.
    """
    counter = CountingModule()

    def body(i, carry):
      del carry  # The carry is deliberately unused; only `i` is fed in.
      return counter(i)

    def run(lo, hi):
      return stateful.fori_loop(lo, hi, body, 2)

    # Under jit, `lo` and `hi` arrive as tracers rather than Python ints.
    result = stateful.jit(run)(0, 3)
    self.assertEqual(result, 4)
    self.assertEqual(counter.count, 3)
Example #2
0
    def testEmaUpdating(self, use_jit, dtype):
        """Checks that VectorQuantizerEMA embeddings follow the training flag.

        Training passes (second call argument True) must change the
        embeddings on every call. Evaluation passes (second call argument
        False) must leave them unchanged.

        Args:
          use_jit: If true, the module is called through ``stateful.jit``.
          dtype: JAX floating dtype for the module. ``jnp.float64`` is
            skipped on TPU.
        """
        if jax.local_devices()[0].platform == 'tpu' and dtype == jnp.float64:
            self.skipTest('F64 not supported by TPU')

        embedding_dim = 6
        # NumPy counterpart of `dtype`, used for every host-side array below.
        np_dtype = np.float64 if dtype is jnp.float64 else np.float32
        decay = np.array(0.1, dtype=np_dtype)
        vqvae_module = vqvae.VectorQuantizerEMA(embedding_dim=embedding_dim,
                                                num_embeddings=7,
                                                commitment_cost=0.5,
                                                decay=decay,
                                                dtype=dtype)

        if use_jit:
            # Argument 1 is the True/False training flag passed below. It is
            # marked static, presumably because the module branches on it in
            # Python.
            vqvae_f = stateful.jit(vqvae_module, static_argnums=1)
        else:
            vqvae_f = vqvae_module

        batch_size = 16
        # Fixed seed so any failure reproduces with the same inputs.
        rng = np.random.RandomState(0)

        def make_inputs():
            """Return a fresh (batch_size, embedding_dim) batch in np_dtype."""
            return rng.rand(batch_size, embedding_dim).astype(np_dtype)

        prev_embeddings = vqvae_module.embeddings

        # Embeddings should change with every forward pass if is_training.
        for _ in range(10):
            vqvae_f(make_inputs(), True)
            current_embeddings = vqvae_module.embeddings
            self.assertFalse((prev_embeddings == current_embeddings).all())
            prev_embeddings = current_embeddings

        # Forward passes with is_training == False don't change anything.
        for _ in range(10):
            vqvae_f(make_inputs(), False)
            current_embeddings = vqvae_module.embeddings
            self.assertTrue((current_embeddings == prev_embeddings).all())
Example #3
0
 def test_jit_no_transform(self):
   """Checks that stateful.jit on a plain function raises ValueError.

   The error is expected to point users at jax.jit instead. The test name
   suggests this runs outside a transform — TODO confirm against the test
   class setup.
   """
   x = jnp.array(2)
   # `msg=` on assertRaises is only the text reported when nothing is raised;
   # it never inspects the exception. assertRaisesRegex really checks that
   # the error mentions jax.jit. The pattern is kept loose so it matches both
   # "jax.jit()" and "`jax.jit`" wordings.
   with self.assertRaisesRegex(ValueError, r"jax\.jit"):
     stateful.jit(lambda x: x**2)(x)
Example #4
0
 def test_jit(self):
   """Checks that a SquareModule called via stateful.jit returns x ** 2."""
   jitted_square = stateful.jit(SquareModule())
   x = jnp.array(2)
   self.assertEqual(jitted_square(x), x ** 2)
Example #5
0
 def g(x, jit=False):
     """Build a module with module_fn and apply it to x.

     When `jit` is true, the module is wrapped with stateful.jit first.
     """
     mod = module_fn()
     apply_fn = stateful.jit(mod) if jit else mod
     return apply_fn(x)