import numpy as np

# NOTE: these tests assume a TF-style test fixture providing self.session()
# and assertAllClose, plus the attention_util module under test; the
# enclosing test class and remaining imports are project-specific.


def testBasics(self):
    batch_size = 3
    source_length = 5
    target_length = 4
    num_heads = 2
    dim_per_head = 3
    q = np.random.rand(batch_size, target_length, num_heads,
                       dim_per_head).astype(np.float32)
    k = np.random.rand(batch_size, source_length, num_heads,
                       dim_per_head).astype(np.float32)
    v = np.random.rand(batch_size, source_length, num_heads,
                       dim_per_head).astype(np.float32)
    # attention window = 2
    sparsity_indices = np.concatenate([
        np.zeros([batch_size, target_length, num_heads, 1], dtype=np.int32),
        np.ones([batch_size, target_length, num_heads, 1], dtype=np.int32),
    ], axis=-1)
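    # Each slot along the last axis of sparsity_indices names one source
    # position a query may attend to, per (batch, target, head); here every
    # query sees only source positions 0 and 1.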

    with self.session() as sess:
      out, probs = sess.run(
          attention_util.ComputeSparseAttention(q, k, v, sparsity_indices))
      self.assertEqual(out.shape,
                       (batch_size, target_length, num_heads, dim_per_head))
      self.assertEqual(probs.shape,
                       (batch_size, target_length, num_heads, source_length))
      # attention weights sum to 1.
      self.assertAllClose(
          np.sum(probs, axis=-1),
          np.ones([batch_size, target_length, num_heads]))
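      # probs spans the full source length, so with an attention window of 2
      # all of the probability mass should fall on source positions 0 and 1.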

    # attention window = 4, but the last two slots are -1 and are therefore
    # always treated as padding.
    sparsity_indices = np.concatenate([
        sparsity_indices,
        -np.ones([batch_size, target_length, num_heads, 2], dtype=np.int32),
    ], axis=-1)
    with self.session() as sess:
      out2, probs2 = sess.run(
          attention_util.ComputeSparseAttention(q, k, v, sparsity_indices))
      # We assert that the encoded outputs and attention probabilities are
      # unchanged, i.e. the -1 (padding) slots receive zero attention weight.
      self.assertAllClose(out, out2)
      self.assertAllClose(probs, probs2)

    # attention window = 4.
    sparsity_indices = np.tile(
        np.arange(4, dtype=np.int32), [batch_size, target_length, num_heads, 1])
    # but source positions 2 and 3 are padded.
    paddings = np.tile([0., 0., 1., 1., 0.], [batch_size, 1])
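    # paddings has shape [batch_size, source_length], with 1. marking a padded
    # source position; positions 2 and 3 therefore contribute nothing even
    # though they appear in sparsity_indices.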
    with self.session() as sess:
      out3, probs3 = sess.run(
          attention_util.ComputeSparseAttention(q, k, v, sparsity_indices,
                                                paddings))
      # We assert that the encoded outputs and attention weights are the same
      # as before.
      self.assertAllClose(out2, out3)
      self.assertAllClose(probs2, probs3)
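
# A minimal numpy sketch of the sparse-attention semantics the assertions
# above rely on: each query attends only to the source positions listed in
# sparsity_indices, while -1 slots and padded source positions receive zero
# weight. The helper name and masking details are illustrative assumptions,
# not the library's actual implementation.
def _NumpySparseAttentionReference(q, k, v, sparsity_indices, paddings=None):
    batch_size, target_length, num_heads, dim_per_head = q.shape
    source_length = k.shape[1]
    # Scaled dot-product logits over the full source: [B, T, N, S].
    logits = np.einsum('BTNH, BSNH -> BTNS', q, k) / np.sqrt(dim_per_head)
    # Mark which (query, key) pairs are selected and not padded away. Assumes
    # every query selects at least one valid source position.
    mask = np.zeros([batch_size, target_length, num_heads, source_length],
                    dtype=bool)
    for b in range(batch_size):
      for t in range(target_length):
        for n in range(num_heads):
          for s in sparsity_indices[b, t, n]:
            if s >= 0 and (paddings is None or paddings[b, s] == 0.):
              mask[b, t, n, s] = True
    # Softmax over the selected positions only; exact zeros elsewhere.
    logits = np.where(mask, logits, -np.inf)
    exp = np.exp(logits - np.max(logits, axis=-1, keepdims=True))
    probs = exp / np.sum(exp, axis=-1, keepdims=True)
    out = np.einsum('BTNS, BSNH -> BTNH', probs, v)
    return out, probs
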
def testFullAttention(self):
    batch_size = 4
    source_length = 7
    target_length = 6
    num_heads = 3
    dim_per_head = 5
    q = np.random.rand(batch_size, target_length, num_heads,
                       dim_per_head).astype(np.float32)
    k = np.random.rand(batch_size, source_length, num_heads,
                       dim_per_head).astype(np.float32)
    v = np.random.rand(batch_size, source_length, num_heads,
                       dim_per_head).astype(np.float32)
    # attention window = source_length: a randomly permuted
    # np.arange(source_length), so every source position is selected.
    sparsity_indices = np.tile(
        np.random.permutation(source_length).astype(np.int32),
        [batch_size, target_length, num_heads, 1])

    with self.session() as sess:
      out, probs = sess.run(
          attention_util.ComputeSparseAttention(q, k, v, sparsity_indices))

    # compute full attention in numpy
    expected_logit = np.einsum('BTNH, BSNH -> BTNS', q, k)
    expected_logit /= np.sqrt(dim_per_head)
    elexp = np.exp(expected_logit)
    expected_probs = elexp / np.sum(elexp, axis=-1, keepdims=True)
    expected_output = np.einsum('BTNS, BSNH -> BTNH', expected_probs, v)

    # We assert that the output matches full attention, since
    # sparsity_indices covers every source position.
    self.assertAllClose(probs, expected_probs)
    self.assertAllClose(out, expected_output)
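
# Hypothetical usage: the numpy sketch above should reproduce the
# full-attention expectation computed in testFullAttention, e.g.
#   ref_out, ref_probs = _NumpySparseAttentionReference(
#       q, k, v, sparsity_indices)
#   np.testing.assert_allclose(ref_probs, expected_probs, rtol=1e-5)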