def my_function(x, axis_name):
  """Quantizes `x` with a cross-replica EMA codebook, one training step.

  Args:
    x: Batch of vectors to quantize; trailing dim must match the codebook's
      `embedding_dim`.
    axis_name: `pmap` axis name over which EMA statistics are synchronized.

  Returns:
    A pair of (codebook embeddings after the update, perplexity of the
    code assignments from this forward pass).
  """
  # NOTE(review): `embedding_dim` is a free variable resolved from the
  # enclosing scope — confirm it is bound where this helper is defined.
  ema_decay = np.array(0.9, dtype=np.float32)
  quantizer = vqvae.VectorQuantizerEMA(
      embedding_dim=embedding_dim,
      num_embeddings=7,
      commitment_cost=0.5,
      decay=ema_decay,
      cross_replica_axis=axis_name,
      dtype=jnp.float32)
  result = quantizer(x, is_training=True)
  return quantizer.embeddings, result['perplexity']
def testEmaUpdating(self, use_jit, dtype):
  """EMA codebook moves on training passes and freezes on eval passes."""
  if jax.local_devices()[0].platform == 'tpu' and dtype == jnp.float64:
    self.skipTest('F64 not supported by TPU')

  embedding_dim = 6
  # The decay constant is fed in as a concrete numpy scalar matching dtype.
  np_dtype = np.float64 if dtype is jnp.float64 else np.float32
  decay = np.array(0.1, dtype=np_dtype)
  vqvae_module = vqvae.VectorQuantizerEMA(
      embedding_dim=embedding_dim,
      num_embeddings=7,
      commitment_cost=0.5,
      decay=decay,
      dtype=dtype)

  # is_training is the second positional arg, hence static_argnums=1.
  vqvae_f = (
      stateful.jit(vqvae_module, static_argnums=1)
      if use_jit else vqvae_module)

  batch_size = 16
  prev_embeddings = vqvae_module.embeddings

  # Training passes: every forward call should shift the EMA codebook.
  for _ in range(10):
    batch = np.random.rand(batch_size, embedding_dim).astype(dtype)
    vqvae_f(batch, True)
    updated = vqvae_module.embeddings
    self.assertFalse((prev_embeddings == updated).all())
    prev_embeddings = updated

  # Evaluation passes: is_training=False must leave the codebook untouched.
  for _ in range(10):
    batch = np.random.rand(batch_size, embedding_dim).astype(dtype)
    vqvae_f(batch, False)
    self.assertTrue((vqvae_module.embeddings == prev_embeddings).all())