Example #1
    def testMaskedScaledDotAttention(self):
        batch_size = 3
        num_heads = 8
        queries_length = [8, 6, 10]
        depth = 20

        queries = tf.placeholder_with_default(
            np.random.randn(batch_size, num_heads, max(queries_length),
                            depth).astype(np.float32),
            shape=(None, num_heads, None, depth))

        # Lower-triangular mask: position i may only attend to positions
        # j <= i, further restricted by each example's true length.
        mask = transformer.build_future_mask(queries_length,
                                             num_heads=num_heads)
        context, attn = transformer.scaled_dot_attention(
            queries,
            queries,
            queries,
            tf.estimator.ModeKeys.PREDICT,
            mask=mask)

        with self.test_session() as sess:
            context, attn = sess.run([context, attn])
            # Indices strictly above the diagonal, i.e. links to the future.
            illegal_connections = np.triu_indices(max(queries_length), 1)
            for i in range(batch_size):
                for h in range(num_heads):
                    self.assertEqual(0.0,
                                     np.sum(attn[i, h][illegal_connections]))
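
The property this test asserts can be reproduced outside the OpenNMT-tf API. Below is a minimal NumPy sketch of masked scaled dot attention for a single head; the function name and the large-negative masking trick are illustrative assumptions, not OpenNMT-tf's implementation. Masked scores are pushed to -1e9, so their softmax weights underflow to exactly zero:

import numpy as np

def masked_scaled_dot_attention(q, k, v, mask):
    # q, k, v: [time, depth]; mask: [time, time], 1.0 where attention is allowed.
    scores = q @ k.T / np.sqrt(q.shape[-1])
    scores = np.where(mask > 0.0, scores, -1e9)  # disallowed links get a huge negative score
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ v, weights

time_steps, depth = 5, 8
x = np.random.randn(time_steps, depth).astype(np.float32)
causal = np.tril(np.ones((time_steps, time_steps), dtype=np.float32))
_, attn = masked_scaled_dot_attention(x, x, x, causal)
# As in the test: no attention mass above the diagonal.
assert np.sum(attn[np.triu_indices(time_steps, 1)]) == 0.0
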
Example #2
    def testBuildFutureMask(self):
        num_heads = 4
        length = [2, 4, 3]
        expected = [[[1.0, 0.0, 0.0, 0.0], [1.0, 1.0, 0.0, 0.0],
                     [1.0, 1.0, 0.0, 0.0], [1.0, 1.0, 0.0, 0.0]],
                    [[1.0, 0.0, 0.0, 0.0], [1.0, 1.0, 0.0, 0.0],
                     [1.0, 1.0, 1.0, 0.0], [1.0, 1.0, 1.0, 1.0]],
                    [[1.0, 0.0, 0.0, 0.0], [1.0, 1.0, 0.0, 0.0],
                     [1.0, 1.0, 1.0, 0.0], [1.0, 1.0, 1.0, 0.0]]]

        mask = transformer.build_future_mask(tf.constant(length),
                                             num_heads=num_heads)

        with self.test_session() as sess:
            mask = sess.run(mask)
            # Move the heads dimension first: [num_heads, batch, time, time].
            mask = np.transpose(mask, (1, 0, 2, 3))
            # The mask does not depend on the head, so every head's slice
            # must equal the per-example expected masks.
            for h in range(num_heads):
                self.assertAllEqual(expected, mask[h])
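
For reference, the `expected` tensor above can be reconstructed in plain NumPy: a future mask is the elementwise product of a causal lower triangle (j <= i) and a per-example key-length mask (j < length). This is a hedged sketch of the semantics, not OpenNMT-tf's implementation; the num_heads dimension is only a tile across heads (as the test confirms) and is omitted here:

import numpy as np

def future_mask(lengths, maximum_length=None):
    lengths = np.asarray(lengths)
    t = maximum_length or lengths.max()
    causal = np.tril(np.ones((t, t), dtype=np.float32))  # allows j <= i
    in_range = (np.arange(t)[None, :] < lengths[:, None]).astype(np.float32)  # allows j < length
    # allowed(b, i, j) = (j <= i) and (j < lengths[b])
    return causal[None, :, :] * in_range[:, None, :]

print(future_mask([2, 4, 3]))  # reproduces `expected` above
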
Example #3
    def _self_attention_stack(self,
                              inputs,
                              sequence_length=None,
                              mode=tf.estimator.ModeKeys.TRAIN,
                              cache=None,
                              memory=None,
                              memory_sequence_length=None):
        inputs = tf.layers.dropout(
            inputs,
            rate=self.dropout,
            training=mode == tf.estimator.ModeKeys.TRAIN)

        decoder_mask = None
        memory_mask = None

        if sequence_length is not None:
            # Causal mask: each target position attends only to itself and
            # to earlier positions.
            decoder_mask = transformer.build_future_mask(
                sequence_length, num_heads=self.num_heads, dtype=inputs.dtype)
        if memory_sequence_length is not None:
            # Padding mask over the encoder memory.
            memory_mask = transformer.build_sequence_mask(
                memory_sequence_length,
                num_heads=self.num_heads,
                dtype=memory.dtype)

        for l in range(self.num_layers):
            layer_name = "layer_{}".format(l)
            layer_cache = cache[layer_name] if cache is not None else None
            with tf.variable_scope(layer_name):
                with tf.variable_scope("masked_multi_head"):
                    inputs_norm = transformer.norm(inputs)
                    encoded = transformer.multi_head_attention(
                        self.num_heads,
                        inputs_norm,
                        inputs_norm,
                        mode,
                        num_units=self.num_units,
                        mask=decoder_mask,
                        cache=layer_cache,
                        dropout=self.attention_dropout)
                    encoded = transformer.drop_and_add(inputs,
                                                       encoded,
                                                       mode,
                                                       dropout=self.dropout)

                if memory is not None:
                    with tf.variable_scope("multi_head"):
                        context = transformer.multi_head_attention(
                            self.num_heads,
                            transformer.norm(encoded),
                            memory,
                            mode,
                            mask=memory_mask,
                            dropout=self.attention_dropout)
                        context = transformer.drop_and_add(
                            encoded, context, mode, dropout=self.dropout)
                else:
                    # Without encoder memory there is no cross-attention;
                    # the self-attention output feeds the FFN directly.
                    context = encoded

                with tf.variable_scope("ffn"):
                    transformed = transformer.feed_forward(
                        transformer.norm(context),
                        self.ffn_inner_dim,
                        mode,
                        dropout=self.relu_dropout)
                    transformed = transformer.drop_and_add(
                        context, transformed, mode, dropout=self.dropout)

                inputs = transformed

        outputs = transformer.norm(inputs)
        return outputs
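
Every sub-layer in the stack above follows the same pre-norm residual pattern: normalize the input, apply the sub-layer, then add the residual (judging from its call sites, transformer.drop_and_add applies dropout to the sub-layer output before the addition). A minimal NumPy sketch of that pattern, with dropout and the layer norm's learned scale and offset omitted:

import numpy as np

def layer_norm(x, eps=1e-6):
    # Normalize over the feature dimension (no learned scale/offset here).
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def pre_norm_residual(x, sublayer):
    # y = x + Sublayer(LayerNorm(x)); dropout omitted for brevity.
    return x + sublayer(layer_norm(x))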