Esempio n. 1
0
            def my_function(x, axis_name):
                """Runs one training-mode pass of an EMA vector quantizer on ``x``.

                ``embedding_dim`` is read from the enclosing scope; ``axis_name``
                is forwarded to the module as ``cross_replica_axis``.

                Returns:
                  A ``(embeddings, perplexity)`` pair: the module's ``embeddings``
                  attribute read after the call, and the ``'perplexity'`` entry of
                  the call's output.
                """
                quantizer_kwargs = dict(
                    embedding_dim=embedding_dim,
                    num_embeddings=7,
                    commitment_cost=0.5,
                    decay=np.array(0.9, dtype=np.float32),
                    cross_replica_axis=axis_name,
                    dtype=jnp.float32,
                )
                quantizer = vqvae.VectorQuantizerEMA(**quantizer_kwargs)
                result = quantizer(x, is_training=True)
                return quantizer.embeddings, result['perplexity']
Esempio n. 2
0
    def testEmaUpdating(self, use_jit, dtype):
        """Checks that EMA embeddings change in training mode only.

        Each of ten training-mode forward passes must change at least one
        element of ``vqvae_module.embeddings``. Ten subsequent evaluation-mode
        passes must then leave the embeddings exactly equal to their last value.

        Args:
          use_jit: If true, call the module through ``stateful.jit`` with the
            second positional argument (``is_training``) marked static.
          dtype: Dtype passed to the module and used to cast the inputs.
        """
        # Use one ``==`` comparison for both the TPU skip and the numpy dtype
        # selection, so the two always agree on what counts as float64.
        is_f64 = dtype == jnp.float64
        if jax.local_devices()[0].platform == 'tpu' and is_f64:
            self.skipTest('F64 not supported by TPU')

        embedding_dim = 6
        batch_size = 16
        np_dtype = np.float64 if is_f64 else np.float32
        decay = np.array(0.1, dtype=np_dtype)
        vqvae_module = vqvae.VectorQuantizerEMA(embedding_dim=embedding_dim,
                                                num_embeddings=7,
                                                commitment_cost=0.5,
                                                decay=decay,
                                                dtype=dtype)

        if use_jit:
            # static_argnums=1 refers to the boolean is_training flag, which is
            # passed positionally in the calls below.
            vqvae_f = stateful.jit(vqvae_module, static_argnums=1)
        else:
            vqvae_f = vqvae_module

        # A local, seeded generator makes the inputs reproducible across runs
        # and avoids consuming or depending on numpy's global random state.
        rng = np.random.RandomState(0)

        def random_inputs():
            # Same shape and cast as before: (batch_size, embedding_dim) in `dtype`.
            return rng.rand(batch_size, embedding_dim).astype(dtype)

        prev_embeddings = vqvae_module.embeddings

        # Embeddings should change with every forwards pass if is_training == True.
        for _ in range(10):
            vqvae_f(random_inputs(), True)
            current_embeddings = vqvae_module.embeddings
            self.assertFalse((prev_embeddings == current_embeddings).all())
            prev_embeddings = current_embeddings

        # Forward passes with is_training == False don't change anything.
        for _ in range(10):
            vqvae_f(random_inputs(), False)
            current_embeddings = vqvae_module.embeddings
            self.assertTrue((current_embeddings == prev_embeddings).all())