def encode(self,
               inputs,
               sequence_length=None,
               mode=tf.estimator.ModeKeys.TRAIN):
        """Encodes ``inputs`` with a stack of pre-norm self-attention layers.

        Args:
          inputs: The input tensor. It is passed through
            ``self.position_encoder`` (when set) and dropout before the stack.
          sequence_length: Optional lengths of the sequences in ``inputs``,
            used for the position encoder and the attention mask.
          mode: A ``tf.estimator.ModeKeys`` value. The input dropout is only
            active in ``TRAIN`` mode; ``mode`` is also forwarded to the
            ``transformer`` helpers.

        Returns:
          A tuple ``(outputs, state, sequence_length)``:
          ``outputs`` is the normalized output of the last layer and ``state``
          is a tuple with, for each layer, the mean of that layer's output over
          axis 1 (this mean is not masked by ``sequence_length``).
        """
        if self.position_encoder is not None:
            inputs = self.position_encoder(inputs,
                                           sequence_length=sequence_length)

        is_training = mode == tf.estimator.ModeKeys.TRAIN
        inputs = tf.layers.dropout(inputs,
                                   rate=self.dropout,
                                   training=is_training)

        # The mask is sized to the dynamic time dimension of ``inputs``.
        time_dim = tf.shape(inputs)[1]
        attention_mask = transformer.build_sequence_mask(
            sequence_length,
            num_heads=self.num_heads,
            maximum_length=time_dim,
            dtype=inputs.dtype)

        layer_means = []
        for layer_index in range(self.num_layers):
            with tf.variable_scope("layer_{}".format(layer_index)):
                # Sub-layer 1: self-attention on the normalized input, with a
                # residual connection back to the un-normalized input.
                with tf.variable_scope("multi_head"):
                    normed = transformer.norm(inputs)
                    attended = transformer.multi_head_attention(
                        self.num_heads,
                        normed,
                        normed,
                        mode,
                        num_units=self.num_units,
                        mask=attention_mask,
                        dropout=self.attention_dropout)
                    attended = transformer.drop_and_add(
                        inputs, attended, mode, dropout=self.dropout)

                # Sub-layer 2: position-wise feed-forward, again pre-norm with
                # a residual connection.
                with tf.variable_scope("ffn"):
                    hidden = transformer.feed_forward(
                        transformer.norm(attended),
                        self.ffn_inner_dim,
                        mode,
                        dropout=self.relu_dropout)
                    inputs = transformer.drop_and_add(
                        attended, hidden, mode, dropout=self.dropout)

                layer_means.append(tf.reduce_mean(inputs, axis=1))

        return transformer.norm(inputs), tuple(layer_means), sequence_length
    def testBuildSequenceMask(self):
        """Checks the sequence mask for every attention head."""
        num_heads = 4
        length = [5, 3, 7]
        # One row per batch entry: 1.0 for steps within the entry's length,
        # 0.0 for the remaining steps up to the maximum length.
        expected = [[1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0],
                    [1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0],
                    [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]]

        mask = transformer.build_sequence_mask(tf.constant(length),
                                               num_heads=num_heads)

        with self.test_session() as sess:
            mask = sess.run(mask)
            # Swap the first two axes so each leading slice is a
            # (batch, time) matrix comparable with ``expected``; the leading
            # axis is then presumably the head axis.
            mask = np.transpose(mask, (1, 0, 2))
            # Check every head. Looping over range(len(length)) would only
            # cover 3 of the 4 heads.
            self.assertEqual(num_heads, mask.shape[0])
            for head_mask in mask:
                self.assertAllEqual(expected, head_mask)
Example #3
0
    def testScaledDotAttention(self):
        """Checks output shapes and that padded key positions get no weight."""
        batch_size = 3
        num_heads = 8
        values_length = [5, 3, 7]
        queries_length = [8, 6, 10]
        depth = 20
        max_queries_length = max(queries_length)
        max_values_length = max(values_length)

        def _random_input(time_steps):
            # Random float32 data with dynamic batch and time dimensions.
            data = np.random.randn(batch_size, num_heads, time_steps,
                                   depth).astype(np.float32)
            return tf.placeholder_with_default(
                data, shape=(None, num_heads, None, depth))

        queries = _random_input(max_queries_length)
        values = _random_input(max_values_length)
        keys = values

        mask = transformer.build_sequence_mask(values_length,
                                               num_heads=num_heads)
        context, attn = transformer.scaled_dot_attention(
            queries, keys, values, tf.estimator.ModeKeys.PREDICT, mask=mask)

        with self.test_session() as sess:
            context, attn = sess.run([context, attn])
            self.assertTupleEqual(
                (batch_size, num_heads, max_queries_length, depth),
                context.shape)
            self.assertTupleEqual(
                (batch_size, num_heads, max_queries_length,
                 max_values_length),
                attn.shape)

            # Attention weights over the padded key positions must sum to 0.
            for batch_attn, valid_length in zip(attn, values_length):
                if valid_length < max_values_length:
                    padded = batch_attn[:, :, valid_length:max_values_length]
                    self.assertEqual(0.0, np.sum(padded))
Example #4
0
    def _self_attention_stack(self,
                              inputs,
                              sequence_length=None,
                              mode=tf.estimator.ModeKeys.TRAIN,
                              cache=None,
                              memory=None,
                              memory_sequence_length=None):
        """Runs the stack of pre-norm decoder layers.

        Each layer applies masked self-attention, then (only when ``memory``
        is given) attention over ``memory``, then a feed-forward sub-layer.
        Each sub-layer output is combined with its input through
        ``transformer.drop_and_add``.

        Args:
          inputs: The decoder input tensor.
          sequence_length: Optional lengths of ``inputs``. When set, a future
            mask is built for the self-attention.
          mode: A ``tf.estimator.ModeKeys`` value. The input dropout is only
            active in ``TRAIN`` mode; ``mode`` is also forwarded to the
            ``transformer`` helpers.
          cache: Optional mapping from ``"layer_<i>"`` to the cache passed to
            that layer's self-attention.
          memory: Optional tensor attended to by every layer (presumably the
            encoder outputs -- confirm with callers).
          memory_sequence_length: Optional lengths of ``memory``, used to mask
            the memory attention.

        Returns:
          The normalized output of the last layer.
        """
        inputs = tf.layers.dropout(
            inputs,
            rate=self.dropout,
            training=mode == tf.estimator.ModeKeys.TRAIN)

        decoder_mask = None
        memory_mask = None

        if sequence_length is not None:
            decoder_mask = transformer.build_future_mask(
                sequence_length, num_heads=self.num_heads, dtype=inputs.dtype)
        # The memory mask is only consumed by the memory attention, and its
        # dtype comes from ``memory``, so it can only be built with a memory.
        if memory is not None and memory_sequence_length is not None:
            memory_mask = transformer.build_sequence_mask(
                memory_sequence_length,
                num_heads=self.num_heads,
                dtype=memory.dtype)

        for layer_index in range(self.num_layers):
            layer_name = "layer_{}".format(layer_index)
            layer_cache = cache[layer_name] if cache is not None else None
            with tf.variable_scope(layer_name):
                with tf.variable_scope("masked_multi_head"):
                    inputs_norm = transformer.norm(inputs)
                    encoded = transformer.multi_head_attention(
                        self.num_heads,
                        inputs_norm,
                        inputs_norm,
                        mode,
                        num_units=self.num_units,
                        mask=decoder_mask,
                        cache=layer_cache,
                        dropout=self.attention_dropout)
                    encoded = transformer.drop_and_add(inputs,
                                                       encoded,
                                                       mode,
                                                       dropout=self.dropout)

                if memory is not None:
                    with tf.variable_scope("multi_head"):
                        # NOTE(review): unlike the self-attention call,
                        # num_units is not passed here; presumably
                        # multi_head_attention derives it from the queries.
                        context = transformer.multi_head_attention(
                            self.num_heads,
                            transformer.norm(encoded),
                            memory,
                            mode,
                            mask=memory_mask,
                            dropout=self.attention_dropout)
                        context = transformer.drop_and_add(
                            encoded, context, mode, dropout=self.dropout)
                else:
                    # Without memory, the feed-forward sub-layer consumes the
                    # self-attention output directly. Previously ``context``
                    # was left unbound here and raised UnboundLocalError.
                    context = encoded

                with tf.variable_scope("ffn"):
                    transformed = transformer.feed_forward(
                        transformer.norm(context),
                        self.ffn_inner_dim,
                        mode,
                        dropout=self.relu_dropout)
                    transformed = transformer.drop_and_add(
                        context, transformed, mode, dropout=self.dropout)

                inputs = transformed

        outputs = transformer.norm(inputs)
        return outputs