def testMaskedScaledDotAttention(self):
    """Future positions must get zero attention weight under a future mask.

    Runs ``transformer.scaled_dot_attention`` in self-attention form (the same
    tensor is passed as queries, keys and values) with the mask returned by
    ``transformer.build_future_mask``. For every batch entry and every head,
    it checks that the strictly upper-triangular part of the attention matrix
    sums to zero.
    """
    batch = 3
    heads = 8
    lengths = [8, 6, 10]
    depth = 20
    longest = max(lengths)

    random_values = np.random.randn(batch, heads, longest, depth).astype(np.float32)
    # Dynamic batch and time dimensions; heads and depth stay static.
    queries = tf.placeholder_with_default(
        random_values, shape=(None, heads, None, depth))

    future_mask = transformer.build_future_mask(lengths, num_heads=heads)
    ctx, weights = transformer.scaled_dot_attention(
        queries,
        queries,
        queries,
        tf.estimator.ModeKeys.PREDICT,
        mask=future_mask)

    with self.test_session() as sess:
        ctx, weights = sess.run([ctx, weights])

    # Offset 1 selects entries strictly above the diagonal, i.e. a query
    # position attending to a later key position.
    future_positions = np.triu_indices(longest, 1)
    for per_example in weights[:batch]:
        for per_head in per_example[:heads]:
            self.assertEqual(0.0, np.sum(per_head[future_positions]))
def testBuildFutureMask(self):
    """Check ``transformer.build_future_mask`` against a hand-written mask.

    For sequence lengths ``[2, 4, 3]``, each batch entry's mask is
    lower-triangular, and its columns past the sequence length are zeroed.
    Every one of the ``num_heads`` heads must carry that same mask.
    """
    num_heads = 4
    length = [2, 4, 3]
    expected = [
        [[1.0, 0.0, 0.0, 0.0],
         [1.0, 1.0, 0.0, 0.0],
         [1.0, 1.0, 0.0, 0.0],
         [1.0, 1.0, 0.0, 0.0]],
        [[1.0, 0.0, 0.0, 0.0],
         [1.0, 1.0, 0.0, 0.0],
         [1.0, 1.0, 1.0, 0.0],
         [1.0, 1.0, 1.0, 1.0]],
        [[1.0, 0.0, 0.0, 0.0],
         [1.0, 1.0, 0.0, 0.0],
         [1.0, 1.0, 1.0, 0.0],
         [1.0, 1.0, 1.0, 0.0]],
    ]

    mask = transformer.build_future_mask(tf.constant(length), num_heads=num_heads)
    with self.test_session() as sess:
        mask = sess.run(mask)

    # Swap the first two axes so that mask[h] is the (batch, max_len, max_len)
    # slice for head h, matching the layout of `expected`.
    mask = np.transpose(mask, (1, 0, 2, 3))

    # Iterate over heads, not over batch entries. Looping over
    # range(len(length)) (3) instead of range(num_heads) (4) would never
    # check the last head.
    self.assertEqual(num_heads, mask.shape[0])
    for h in range(num_heads):
        self.assertAllEqual(expected, mask[h])
def _self_attention_stack(self,
                          inputs,
                          sequence_length=None,
                          mode=tf.estimator.ModeKeys.TRAIN,
                          cache=None,
                          memory=None,
                          memory_sequence_length=None):
    """Run the stack of decoder layers over ``inputs``.

    Each of the ``self.num_layers`` layers applies three sublayers. Each one
    normalizes its input with ``transformer.norm`` and combines its output
    with that input through ``transformer.drop_and_add``:

    1. masked multi-head self-attention;
    2. multi-head attention over ``memory``, only when ``memory`` is given;
    3. a feed-forward sublayer.

    A final ``transformer.norm`` is applied to the last layer's output.

    Args:
      inputs: The decoder input tensor.
      sequence_length: Optional lengths of ``inputs``. When set, a future mask
        is built with ``transformer.build_future_mask`` and used by the
        self-attention sublayers.
      mode: A ``tf.estimator.ModeKeys`` value. The dropout applied directly to
        ``inputs`` is only active when this is ``TRAIN``.
      cache: Optional mapping indexed by ``"layer_<i>"``. Each entry is
        forwarded to that layer's self-attention.
      memory: Optional tensor to attend to (presumably the encoder outputs).
        When ``None``, the memory-attention sublayer is skipped.
      memory_sequence_length: Optional lengths of ``memory``. It is only used
        when ``memory`` is given.

    Returns:
      The normalized output of the last layer.
    """
    inputs = tf.layers.dropout(
        inputs,
        rate=self.dropout,
        training=mode == tf.estimator.ModeKeys.TRAIN)

    decoder_mask = None
    memory_mask = None
    if sequence_length is not None:
        decoder_mask = transformer.build_future_mask(
            sequence_length, num_heads=self.num_heads, dtype=inputs.dtype)
    # The memory mask is only consumed by the memory attention below. It also
    # reads memory.dtype, so it is only built when a memory is present.
    if memory is not None and memory_sequence_length is not None:
        memory_mask = transformer.build_sequence_mask(
            memory_sequence_length, num_heads=self.num_heads, dtype=memory.dtype)

    for layer_index in range(self.num_layers):
        layer_name = "layer_{}".format(layer_index)
        layer_cache = cache[layer_name] if cache is not None else None
        with tf.variable_scope(layer_name):
            with tf.variable_scope("masked_multi_head"):
                inputs_norm = transformer.norm(inputs)
                encoded = transformer.multi_head_attention(
                    self.num_heads,
                    inputs_norm,
                    inputs_norm,
                    mode,
                    num_units=self.num_units,
                    mask=decoder_mask,
                    cache=layer_cache,
                    dropout=self.attention_dropout)
                encoded = transformer.drop_and_add(
                    inputs, encoded, mode, dropout=self.dropout)

            # Input of the feed-forward sublayer. Without a memory there is no
            # memory attention, so the FFN consumes the self-attention output
            # directly. Previously `context` was unbound in that case and
            # raised a NameError.
            last_context = encoded
            if memory is not None:
                with tf.variable_scope("multi_head"):
                    context = transformer.multi_head_attention(
                        self.num_heads,
                        transformer.norm(encoded),
                        memory,
                        mode,
                        mask=memory_mask,
                        dropout=self.attention_dropout)
                    last_context = transformer.drop_and_add(
                        encoded, context, mode, dropout=self.dropout)

            with tf.variable_scope("ffn"):
                transformed = transformer.feed_forward(
                    transformer.norm(last_context),
                    self.ffn_inner_dim,
                    mode,
                    dropout=self.relu_dropout)
                transformed = transformer.drop_and_add(
                    last_context, transformed, mode, dropout=self.dropout)

            inputs = transformed

    outputs = transformer.norm(inputs)
    return outputs