def test_attn_value(qkv):
    """With all-zero queries the attention logits are constant, so softmax
    yields uniform weights and every output position should equal the mean
    of the value tensor over the time axis.

    `qkv` is a fixture providing (query, key, value) tensors of shape
    (B, H, T, D) — assumed from the unpacking below; confirm against fixture.
    """
    q, k, v = qkv
    with tf.device("/cpu:0"):
        # Zero queries -> uniform attention regardless of keys.
        q = tf.zeros_like(q)
        dot_product_attention = SeqDotProductAttention(0.0)
        res = dot_product_attention((q, k, v, None))
        res, gold = res.numpy(), v.numpy()
        B, H, T, _ = q.get_shape().as_list()
        # The expected value does not depend on b, h, or t; compute it once
        # instead of re-reducing the whole tensor inside the triple loop.
        expected = np.mean(gold, axis=2)
        for b in range(B):
            for h in range(H):
                for t in range(T):
                    np.testing.assert_allclose(
                        res[b, h, t, :], expected[b, h, :], atol=1e-5
                    )
def test_attn_value_seq_mask(qkv):
    """With all-zero queries and a sequence-length mask, attention weights are
    uniform over the unmasked prefix only, so each output position should equal
    the mean of the values over the first `lens[b]` time steps of that batch.

    `qkv` is a fixture providing (query, key, value) tensors of shape
    (B, H, T, D) — assumed from the unpacking below; confirm against fixture.
    """
    q, k, v = qkv
    with tf.device("/cpu:0"):
        B, H, T, _ = q.get_shape().as_list()
        # Zero queries -> uniform attention over whatever the mask allows.
        q = tf.zeros_like(q)
        lens = np.random.randint(1, T, size=B).astype(np.int32)
        tf_lens = tf.constant(lens)
        # Mask shaped (B, 1, 1, T) so it broadcasts over heads and queries.
        mask = tf.expand_dims(
            tf.expand_dims(tf.sequence_mask(tf_lens, T, dtype=tf.float32), 1), 1
        )
        dot_product_attention = SeqDotProductAttention(0.0)
        res = dot_product_attention((q, k, v, mask))
        res, gold = res.numpy(), v.numpy()
        for b in range(B):
            # Expected mean over the unmasked prefix depends only on b; hoist
            # it out of the h/t loops instead of reducing the full tensor per
            # iteration. Indexing b first, the time axis becomes axis 1.
            expected = np.mean(gold[b, :, : lens[b], :], axis=1)
            for h in range(H):
                for t in range(T):
                    np.testing.assert_allclose(
                        res[b, h, t, :], expected[h, :], atol=1e-5
                    )