def test_global_embedding_input_out_of_range(self):
    """Out-of-range ids are zeroed by the global table, handled locally.

    Ids beyond ``total_vocab_size`` should produce all-zero global
    embeddings while the local (OOV-bucket) embedding produces a
    non-zero vector for them.
    """
    embedding_size = 3
    num_oov_buckets = 2
    vocab_size = 2
    # +3 accounts for the pad/bos/eos special tokens.
    total_vocab_size = vocab_size + 3

    global_embedding = models.GlobalEmbedding(
        total_vocab_size=total_vocab_size,
        embedding_dim=embedding_size,
        mask_zero=True,
        name='global_embedding_layer')
    local_embedding = models.LocalEmbedding(
        input_dim=num_oov_buckets,
        embedding_dim=embedding_size,
        total_vocab_size=total_vocab_size,
        mask_zero=True,
        name='local_embedding_layer')

    # Both ids fall outside [0, total_vocab_size), i.e. they are OOV.
    oov_inputs = tf.constant([[5], [6]], dtype=tf.int64)
    global_output = global_embedding(oov_inputs)
    local_output = local_embedding(oov_inputs)

    self.assertAllEqual(global_output, tf.zeros_like(input=global_output))
    self.assertNotAllEqual(local_output, tf.zeros_like(input=local_output))
  def test_add_embeddings(self):
    """Adding global and local embeddings preserves the padding mask.

    The summed output, passed through ``PassMask``, should be masked
    exactly where the input id is 0 (the padding id).
    """
    embedding_size = 3
    num_oov_buckets = 2
    vocab_size = 2
    # +3 accounts for the pad/bos/eos special tokens.
    total_vocab_size = vocab_size + 3

    global_embedding = models.GlobalEmbedding(
        total_vocab_size=total_vocab_size,
        embedding_dim=embedding_size,
        mask_zero=True,
        name='global_embedding_layer')
    local_embedding = models.LocalEmbedding(
        input_dim=num_oov_buckets,
        embedding_dim=embedding_size,
        total_vocab_size=total_vocab_size,
        mask_zero=True,
        name='local_embedding_layer')

    # Mix of in-vocab (1), padding (0) and OOV (5) ids.
    token_ids = tf.constant([[1], [0], [5], [0]], dtype=tf.int64)
    summed = tf.keras.layers.Add()(
        [global_embedding(token_ids),
         local_embedding(token_ids)])
    mask = PassMask()(summed)

    # Only the padding positions (id == 0) should be masked out.
    self.assertAllEqual(mask, [[True], [False], [True], [False]])
  def test_project_on_known_embeddings(self):
    """Lookups hit the expected rows of constant-initialized tables.

    With deterministic constant initializers, in-range ids select rows
    from the global table (OOV and padding map to zeros there), while
    only the OOV id selects the single local OOV-bucket row.
    """
    embedding_size = 3
    num_oov_buckets = 1
    vocab_size = 2
    # +3 accounts for the pad/bos/eos special tokens.
    total_vocab_size = vocab_size + 3

    global_embedding = models.GlobalEmbedding(
        total_vocab_size=total_vocab_size,
        embedding_dim=embedding_size,
        mask_zero=True,
        initializer=tf.keras.initializers.Constant([[1, 1, 1], [2, 2, 2],
                                                    [3, 3, 3], [4, 4, 4],
                                                    [5, 5, 5]]),
        name='global_embedding_layer')
    local_embedding = models.LocalEmbedding(
        input_dim=num_oov_buckets,
        embedding_dim=embedding_size,
        total_vocab_size=total_vocab_size,
        mask_zero=True,
        initializer=tf.keras.initializers.Constant([[6, 6, 6]]),
        name='local_embedding_layer')

    # Ids 1, 0, 2, 0 are in range of the global table; 5 is OOV.
    token_ids = tf.constant([[1], [0], [2], [0], [5]], dtype=tf.int64)
    global_lookup = global_embedding(token_ids)
    local_lookup = local_embedding(token_ids)

    self.assertAllEqual(
        global_lookup,
        [[[2, 2, 2]], [[1, 1, 1]], [[3, 3, 3]], [[1, 1, 1]], [[0, 0, 0]]])
    self.assertAllEqual(
        local_lookup,
        [[[0, 0, 0]], [[0, 0, 0]], [[0, 0, 0]], [[0, 0, 0]], [[6, 6, 6]]])
  def test_embedding_mask(self):
    """Both embedding layers compute the same padding mask.

    With ``mask_zero=True``, ``compute_mask`` should be False exactly
    at positions whose input id is 0, for global and local layers alike.
    """
    embedding_size = 3
    num_oov_buckets = 2
    vocab_size = 2
    # +3 accounts for the pad/bos/eos special tokens.
    total_vocab_size = vocab_size + 3

    global_embedding = models.GlobalEmbedding(
        total_vocab_size=total_vocab_size,
        embedding_dim=embedding_size,
        mask_zero=True,
        name='global_embedding_layer')
    local_embedding = models.LocalEmbedding(
        input_dim=num_oov_buckets,
        embedding_dim=embedding_size,
        total_vocab_size=total_vocab_size,
        mask_zero=True,
        name='local_embedding_layer')

    # Trailing zeros are padding; leading ids are real tokens.
    token_ids = tf.constant([[1], [2], [0], [0]], dtype=tf.int64)
    expected_mask = [[True], [True], [False], [False]]

    self.assertAllEqual(
        global_embedding.compute_mask(token_ids), expected_mask)
    self.assertAllEqual(
        local_embedding.compute_mask(token_ids), expected_mask)