def test_attn_value(qkv):
    """Zero queries should make the attention output a plain average of the values.

    ``q`` is replaced by zeros and no mask is passed. The test expects every
    output position to equal the mean of ``v`` over axis 2, i.e. uniform
    attention weights over all key positions.

    Args:
        qkv: pytest fixture yielding a ``(q, k, v)`` tuple of tensors
            (presumably laid out ``(B, H, T, D)`` as the shape unpacking below
            assumes — TODO confirm against the fixture).
    """
    q, k, v = qkv
    with tf.device("/cpu:0"):
        q = tf.zeros_like(q)
        # The 0.0 argument is presumably the dropout rate, keeping the output deterministic.
        dot_product_attention = SeqDotProductAttention(0.0)
        res = dot_product_attention((q, k, v, None))
        res, gold = res.numpy(), v.numpy()
        B, H, T, _ = q.get_shape().as_list()

    # Check the shape explicitly so extra rows in `res` cannot slip past the comparison.
    assert res.shape == (B, H, T, gold.shape[-1])
    # Mean over axis 2 computed once, then broadcast to every query position;
    # equivalent to comparing each res[b, h, t, :] against it in a nested loop.
    expected = np.mean(gold, axis=2, keepdims=True)
    np.testing.assert_allclose(res, np.broadcast_to(expected, res.shape), atol=1e-5)


def test_attn_value_seq_mask(qkv):
    """Zero queries plus a length mask should average only the unmasked values.

    ``q`` is replaced by zeros and each batch entry ``b`` gets a random length
    ``lens[b]``. The test expects every output position of entry ``b`` to equal
    the mean of ``v[b, :, :lens[b], :]`` over the time axis. Positions at or
    beyond the length must not contribute.

    Args:
        qkv: pytest fixture yielding a ``(q, k, v)`` tuple of tensors
            (presumably laid out ``(B, H, T, D)`` as the shape unpacking below
            assumes — TODO confirm against the fixture).
    """
    q, k, v = qkv
    with tf.device("/cpu:0"):
        B, H, T, _ = q.get_shape().as_list()
        q = tf.zeros_like(q)
        # randint's upper bound is exclusive, so every length is in [1, T - 1] and each
        # sequence has at least one masked-out position. NOTE: this requires T >= 2.
        lens = np.random.randint(1, T, size=B).astype(np.int32)
        tf_lens = tf.constant(lens)
        # sequence_mask gives (B, T); the two expand_dims make it (B, 1, 1, T),
        # presumably so it broadcasts over heads and query positions.
        mask = tf.expand_dims(tf.expand_dims(tf.sequence_mask(tf_lens, T, dtype=tf.float32), 1), 1)
        # The 0.0 argument is presumably the dropout rate, keeping the output deterministic.
        dot_product_attention = SeqDotProductAttention(0.0)
        res = dot_product_attention((q, k, v, mask))
        res, gold = res.numpy(), v.numpy()

    # Check the shape explicitly so extra rows in `res` cannot slip past the comparison.
    assert res.shape == (B, H, T, gold.shape[-1])
    for b, length in enumerate(lens):
        # Mean over the first `length` positions of this batch entry only, shape (H, 1, D).
        # This is computed once per entry instead of over the whole batch per (h, t).
        expected = gold[b, :, :length, :].mean(axis=1, keepdims=True)
        np.testing.assert_allclose(res[b], np.broadcast_to(expected, res[b].shape), atol=1e-5)