  def testMaskedScaledDotAttention(self):
    batch_size = 3
    num_heads = 8
    queries_length = [8, 6, 10]
    depth = 20

    queries = tf.placeholder_with_default(
        np.random.randn(batch_size, num_heads, max(queries_length), depth).astype(np.float32),
        shape=(None, num_heads, None, depth))

    # Self-attention with a future (causal) mask.
    mask = transformer.build_future_mask(queries_length, num_heads=num_heads)

    context, attn = transformer.dot_product_attention(
        queries,
        queries,
        queries,
        tf.estimator.ModeKeys.PREDICT,
        mask=mask)

    with self.test_session() as sess:
      context, attn = sess.run([context, attn])
      # Positions above the diagonal attend to the future and must receive zero weight.
      illegal_connections = np.triu_indices(max(queries_length), 1)
      for i in range(batch_size):
        for h in range(num_heads):
          self.assertEqual(0.0, np.sum(attn[i, h][illegal_connections]))
  def testScaledDotAttention(self):
    batch_size = 3
    num_heads = 8
    values_length = [5, 3, 7]
    queries_length = [8, 6, 10]
    depth = 20

    queries = tf.placeholder_with_default(
        np.random.randn(batch_size, num_heads, max(queries_length), depth).astype(np.float32),
        shape=(None, num_heads, None, depth))
    values = tf.placeholder_with_default(
        np.random.randn(batch_size, num_heads, max(values_length), depth).astype(np.float32),
        shape=(None, num_heads, None, depth))
    keys = values

    # Mask out the padding positions of the values.
    mask = transformer.build_sequence_mask(values_length, num_heads=num_heads)

    context, attn = transformer.dot_product_attention(
        queries,
        keys,
        values,
        tf.estimator.ModeKeys.PREDICT,
        mask=mask)

    with self.test_session() as sess:
      context, attn = sess.run([context, attn])
      self.assertTupleEqual(
          (batch_size, num_heads, max(queries_length), depth), context.shape)
      self.assertTupleEqual(
          (batch_size, num_heads, max(queries_length), max(values_length)), attn.shape)
      # Attention weights over padding positions must be zero.
      for i in range(batch_size):
        length = values_length[i]
        padding_length = max(values_length) - length
        if padding_length > 0:
          self.assertEqual(0.0, np.sum(attn[i, :, :, length:max(values_length)]))