Example no. 1
0
    def call(self, x):
        """Look up token embeddings for `x`.

        Args:
          x: An int64 tensor with shape [batch_size, length]. Token id 0 is
            treated as padding.

        Returns:
          embeddings: float32 tensor with shape
            [batch_size, length, embedding_size], scaled by
            sqrt(hidden_size), with padding positions zeroed out.
        """
        with tf.name_scope("embedding"):
            # Binary mask of size [batch_size, length]: 1.0 for real tokens,
            # 0.0 where x == 0 (padding). tf.cast replaces the deprecated
            # TF1-only tf.to_float (removed in TF2); it is the exact
            # equivalent and matches the tf.cast usage below.
            mask = tf.cast(tf.not_equal(x, 0), tf.float32)

            if self.method == "gather":
                embeddings = tf.gather(self.shared_weights, x)
                # Zero out the embedding rows at padding positions.
                embeddings *= tf.expand_dims(mask, -1)
            else:  # matmul
                embeddings = tpu_utils.embedding_matmul(
                    embedding_table=self.shared_weights,
                    values=tf.cast(x, dtype=tf.int32),
                    mask=mask)
                # embedding_matmul already zeros out masked positions, so
                # `embeddings *= tf.expand_dims(mask, -1)` is unnecessary.

            # Scale embedding by the sqrt of the hidden size (standard
            # Transformer embedding scaling).
            embeddings *= self.hidden_size**0.5

            return embeddings
Example no. 2
0
    def _test_masking(self, embedding_dim, vocab_size, sequence_length,
                      batch_size, seed):
        """Check that the matmul embedding zeros out masked positions."""
        with self.test_session():
            table, ids, mask = self.construct_embedding_and_values(
                embedding_dim=embedding_dim,
                vocab_size=vocab_size,
                sequence_length=sequence_length,
                batch_size=batch_size,
                seed=seed)

            result = tpu_utils.embedding_matmul(
                embedding_table=table, values=ids, mask=mask)

            # If masked rows are already zero, re-applying the mask must be
            # a no-op.
            self.assertAllClose(result,
                                result * tf.expand_dims(mask, -1))
  def _test_embedding(self, embedding_dim, vocab_size,
                      sequence_length, batch_size, seed):
    """Check that the matmul embedding agrees with a masked gather lookup."""

    with self.test_session():
      table, ids, mask = self.construct_embedding_and_values(
          embedding_dim=embedding_dim,
          vocab_size=vocab_size,
          sequence_length=sequence_length,
          batch_size=batch_size,
          seed=seed
      )

      # Reference result: a plain gather-style lookup with the padding
      # positions zeroed out by the mask.
      expected = tf.nn.embedding_lookup(params=table, ids=ids)
      expected = expected * tf.expand_dims(mask, -1)

      actual = tpu_utils.embedding_matmul(
          embedding_table=table, values=ids, mask=mask)

      self.assertAllClose(expected, actual)