def call(self, x):
    """Get token embeddings of x.

    Positions where `x` equals 0 are treated as padding: a binary mask is
    built from `x != 0` and the embeddings at masked positions are zeroed
    before the result is scaled by sqrt(hidden_size).

    Args:
      x: An int64 tensor with shape [batch_size, length]

    Returns:
      embeddings: float32 tensor with shape
        [batch_size, length, embedding_size]
    """
    with tf.name_scope("embedding"):
        # Binary mask of size [batch_size, length]: 1.0 where x != 0, else 0.0.
        # tf.cast is used instead of the deprecated tf.to_float.
        mask = tf.cast(tf.not_equal(x, 0), tf.float32)

        if self.method == "gather":
            embeddings = tf.gather(self.shared_weights, x)
            embeddings *= tf.expand_dims(mask, -1)
        else:  # matmul
            embeddings = tpu_utils.embedding_matmul(
                embedding_table=self.shared_weights,
                values=tf.cast(x, dtype=tf.int32),
                mask=mask)
            # embedding_matmul already zeros out masked positions, so
            # `embeddings *= tf.expand_dims(mask, -1)` is unnecessary.

        # Scale embedding by the sqrt of the hidden size
        embeddings *= self.hidden_size ** 0.5

        return embeddings
def _test_masking(self, embedding_dim, vocab_size, sequence_length,
                  batch_size, seed):
    """Check that the matmul embedding is unchanged by re-applying the mask.

    The result of `tpu_utils.embedding_matmul` is compared against itself
    multiplied by the mask (expanded along the last axis). The two can only
    be close if the masked positions were already zero.

    Args:
      embedding_dim: Forwarded to `construct_embedding_and_values`.
      vocab_size: Forwarded to `construct_embedding_and_values`.
      sequence_length: Forwarded to `construct_embedding_and_values`.
      batch_size: Forwarded to `construct_embedding_and_values`.
      seed: Forwarded to `construct_embedding_and_values`.
    """
    fixture_kwargs = dict(
        embedding_dim=embedding_dim,
        vocab_size=vocab_size,
        sequence_length=sequence_length,
        batch_size=batch_size,
        seed=seed,
    )
    with self.test_session():
        table, ids, pad_mask = self.construct_embedding_and_values(
            **fixture_kwargs)

        actual = tpu_utils.embedding_matmul(
            embedding_table=table, values=ids, mask=pad_mask)

        # Masking a second time must be a no-op if masking already happened.
        remasked = actual * tf.expand_dims(pad_mask, -1)
        self.assertAllClose(actual, remasked)
def _test_embedding(self, embedding_dim, vocab_size, sequence_length,
                    batch_size, seed):
    """Check that the matmul embedding agrees with a masked gather lookup.

    The reference value is `tf.nn.embedding_lookup` on the same table and
    ids, multiplied by the mask (expanded along the last axis). It is
    compared against the output of `tpu_utils.embedding_matmul`.

    Args:
      embedding_dim: Forwarded to `construct_embedding_and_values`.
      vocab_size: Forwarded to `construct_embedding_and_values`.
      sequence_length: Forwarded to `construct_embedding_and_values`.
      batch_size: Forwarded to `construct_embedding_and_values`.
      seed: Forwarded to `construct_embedding_and_values`.
    """
    with self.test_session():
        table, ids, pad_mask = self.construct_embedding_and_values(
            embedding_dim=embedding_dim,
            vocab_size=vocab_size,
            sequence_length=sequence_length,
            batch_size=batch_size,
            seed=seed)

        # Reference path: gather-based lookup, then zero the masked positions.
        gathered = tf.nn.embedding_lookup(params=table, ids=ids)
        expected = gathered * tf.expand_dims(pad_mask, -1)

        # Path under test.
        actual = tpu_utils.embedding_matmul(
            embedding_table=table, values=ids, mask=pad_mask)

        self.assertAllClose(expected, actual)