Exemple #1
0
    def test_left2right_decode(self, beam_size, beam_alpha, temperature,
                               top_k):
        """Decode with step-dependent one-hot logits and check the tokens.

        At step `i` the logits function scores token id `i + 2` with 1.0 and
        every other id with 0.0. The test expects both decoded rows to equal
        [2, 3, 4, 5, 6].

        Args:
            beam_size: forwarded to `decoding.left2right_decode`.
            beam_alpha: forwarded to `decoding.left2right_decode`.
            temperature: forwarded to `decoding.left2right_decode`.
            top_k: forwarded to `decoding.left2right_decode`.

        NOTE(review): the arguments are presumably supplied by a
        parameterized-test decorator that is not visible here -- confirm.
        """
        num_steps = 5
        num_rows = 2
        num_tokens = 7
        context = {"encoded_states": tf.ones((num_rows, 9), tf.float32)}

        def symbols_to_logits_fn(decodes, unused_context, i):
            # Build one row of token ids, then repeat it once per decode row.
            token_ids = tf.range(num_tokens, dtype=i.dtype)
            token_ids = tf.expand_dims(token_ids, axis=0)
            token_ids = tf.tile(token_ids, [decodes.shape[0], 1])
            # 1.0 where the token id equals i + 2, 0.0 everywhere else.
            is_target = tf.equal(token_ids, i + 2)
            return tf.cast(is_target, tf.float32)

        # Seed is set after the helper ops above and before the decode graph.
        tf.set_random_seed(0)
        decodes = decoding.left2right_decode(
            symbols_to_logits_fn,
            context,
            num_rows,
            num_steps,
            num_tokens,
            beam_size=beam_size,
            beam_alpha=beam_alpha,
            temperature=temperature,
            top_k=top_k)

        expected = [[2, 3, 4, 5, 6]] * 2
        self.assertAllEqual(expected, decodes)
    def predict(self, features, max_decode_len, beam_size, **beam_kwargs):
        """Encode `features`, then decode an output sequence step by step.

        The encoder output from `self._encode` doubles as the decode cache:
        for every decoder layer a zero-filled "k"/"v" buffer is added to it
        under the layer index (as a string) before decoding starts.

        Args:
            features: passed unchanged to `self._encode`.
            max_decode_len: number of decode positions (`T` below).
            beam_size: forwarded positionally to
                `decoding.left2right_decode`.
            **beam_kwargs: forwarded to `decoding.left2right_decode`.

        Returns:
            A dict `{"outputs": decodes}` where `decodes` is whatever
            `decoding.left2right_decode` returns.
        """
        # NOTE(review): the second argument is presumably a training flag
        # -- confirm against `_encode`.
        cache = self._encode(features, False)
        B, _, D = cache["memory"].shape
        T = max_decode_len
        V = self._vocab_size
        H = self._num_heads

        causal_bias_1xTxT = attention.upper_triangle_bias(T, self._dtype)

        # One zero-initialised key buffer and value buffer per decoder layer.
        for layer_idx in range(len(self._decoder_layers)):
            kv_shape = [B, H, T, D // H]
            layer_keys = tf.zeros(kv_shape, self._dtype)
            layer_values = tf.zeros(kv_shape, self._dtype)
            cache[str(layer_idx)] = {"k": layer_keys, "v": layer_values}

        def symbols_to_logits_fn(dec_BxT, context, i):
            """Return the logits for decode step `i`."""
            # Read the token at position i - 1, clamped to 0 on the first step.
            prev_pos = tf.maximum(tf.cast(0, i.dtype), i - 1)
            prev_ids_Bx1 = tf.slice(dec_BxT, [0, prev_pos],
                                    [dec_BxT.shape[0], 1])
            # Row `i` of the triangular bias.
            step_bias_1x1xT = tf.slice(causal_bias_1xTxT, [0, i, 0], [1, 1, T])

            hidden_Bx1xD = self._embedding_layer(prev_ids_Bx1, True)
            # Multiplying by (i > 0) zeroes the embedding on step 0.
            not_first_step = tf.cast(tf.greater(i, 0), self._dtype)
            hidden_Bx1xD *= not_first_step
            hidden_Bx1xD = timing.add_time_signal(hidden_Bx1xD, start_index=i)

            with tf.variable_scope(self._decoder_scope_name,
                                   reuse=tf.AUTO_REUSE):
                hidden_Bx1xD = transformer_block.stack(
                    self._decoder_layers, False, hidden_Bx1xD,
                    step_bias_1x1xT, context["memory"],
                    context["memory_bias"], context, i)
                hidden_Bx1xD = contrib_layers.layer_norm(hidden_Bx1xD,
                                                         begin_norm_axis=2)

            # Same embedding layer, called with False, yields the logits.
            logits_Bx1xV = self._embedding_layer(hidden_Bx1xD, False)
            return tf.squeeze(logits_Bx1xV, axis=1)

        decoded_BxT = decoding.left2right_decode(
            symbols_to_logits_fn, cache, B, T, V, beam_size, **beam_kwargs)
        return {"outputs": decoded_BxT}
Exemple #3
0
    def test_beam_decode_2(self):
        """Decode with constant logits, beam size 2, and check the result.

        The logits function ignores its inputs and always returns two
        identical rows, [1, 2, 3, 4] scaled by 3. The test expects the single
        decoded row to be [3, 3, 1].
        """
        num_beams = 2
        length_alpha = 0.1
        sampling_temperature = 0.0
        num_steps = 3
        num_rows = 1
        num_tokens = 4

        def symbols_to_logits_fn(unused_decodes, unused_context, unused_i):
            # Same scores on every step, one row per beam entry.
            base_logits = tf.constant([[1, 2, 3, 4], [1, 2, 3, 4]], tf.float32)
            return base_logits * 3

        tf.set_random_seed(0)
        decodes = decoding.left2right_decode(
            symbols_to_logits_fn, {},
            num_rows,
            num_steps,
            num_tokens,
            beam_size=num_beams,
            beam_alpha=length_alpha,
            temperature=sampling_temperature)

        expected = [[3, 3, 1]]
        self.assertAllEqual(expected, decodes)