def test_left2right_decode(self, beam_size, beam_alpha, temperature, top_k):
  """Decoding with a logit peak at token (step + 2) must emit [2, 3, 4, 5, 6].

  The symbols-to-logits callback puts all probability mass on token id
  `i + 2` at step `i`, so every decoding strategy under test (greedy, beam,
  sampling, top-k) should recover the same deterministic sequence.
  """
  max_decode_len = 5
  batch_size = 2
  vocab_size = 7
  context = {"encoded_states": tf.ones((batch_size, 9), tf.float32)}

  def symbols_to_logits_fn(decodes, unused_context, i):
    # One-hot logits: 1.0 at token id (i + 2), 0.0 elsewhere, for each row.
    token_ids = tf.expand_dims(tf.range(vocab_size, dtype=i.dtype), axis=0)
    tiled_ids = tf.tile(token_ids, [decodes.shape[0], 1])
    return tf.cast(tf.equal(tiled_ids, i + 2), tf.float32)

  tf.set_random_seed(0)
  decodes = decoding.left2right_decode(
      symbols_to_logits_fn,
      context,
      batch_size,
      max_decode_len,
      vocab_size,
      beam_size=beam_size,
      beam_alpha=beam_alpha,
      temperature=temperature,
      top_k=top_k)
  self.assertAllEqual([[2, 3, 4, 5, 6]] * 2, decodes)
def predict(self, features, max_decode_len, beam_size, **beam_kwargs):
  """Predict.

  Encodes `features`, then decodes left-to-right up to `max_decode_len`
  tokens via `decoding.left2right_decode`, caching per-layer decoder
  key/value states so each step only processes the newest token.

  Name suffixes encode tensor dims: B=batch, T=decode length, V=vocab,
  H=heads, D=model dim (e.g. `dec_Bx1xD` is [batch, 1, model_dim]).

  Args:
    features: model inputs forwarded to `self._encode`. Schema is defined
      by the encoder — not visible here.
    max_decode_len: maximum number of tokens to decode (T).
    beam_size: beam width passed through to the decoder.
    **beam_kwargs: extra keyword args for `decoding.left2right_decode`
      (e.g. beam_alpha, temperature).

  Returns:
    Dict with key "outputs": decoded token ids, shape [B, T].
  """
  # Run the encoder once; `cache` also serves as the mutable decoding
  # context threaded through symbols_to_logits_fn.
  cache = self._encode(features, False)
  B, _, D = cache["memory"].shape
  T, V, H = max_decode_len, self._vocab_size, self._num_heads
  # Causal mask: position i may only attend to positions <= i.
  bias_1xTxT = attention.upper_triangle_bias(T, self._dtype)
  # Pre-allocate per-layer key/value caches, keyed by layer index as a
  # string; the decoder layers fill row i at step i.
  for i in range(len(self._decoder_layers)):
    cache[str(i)] = {
        "k": tf.zeros([B, H, T, D // H], self._dtype),
        "v": tf.zeros([B, H, T, D // H], self._dtype)
    }

  def symbols_to_logits_fn(dec_BxT, context, i):
    """Decode loop."""
    # Slice out the single most recent token; max() keeps the start index
    # at 0 on the first step (i == 0), where there is no previous token.
    dec_Bx1 = tf.slice(dec_BxT, [0, tf.maximum(tf.cast(0, i.dtype), i - 1)],
                       [dec_BxT.shape[0], 1])
    # Row i of the causal mask, shaped for a single query position.
    bias_1x1xT = tf.slice(bias_1xTxT, [0, i, 0], [1, 1, T])
    dec_Bx1xD = self._embedding_layer(dec_Bx1, True)
    # Zero the embedding at step 0 so decoding starts from a null/BOS
    # input rather than a stale token.
    dec_Bx1xD *= tf.cast(tf.greater(i, 0), self._dtype)
    dec_Bx1xD = timing.add_time_signal(dec_Bx1xD, start_index=i)
    # AUTO_REUSE: share weights with the training-time decoder graph.
    with tf.variable_scope(self._decoder_scope_name, reuse=tf.AUTO_REUSE):
      dec_Bx1xD = transformer_block.stack(self._decoder_layers, False,
                                          dec_Bx1xD, bias_1x1xT,
                                          context["memory"],
                                          context["memory_bias"], context, i)
      dec_Bx1xD = contrib_layers.layer_norm(dec_Bx1xD, begin_norm_axis=2)
    # Embedding layer with second arg False projects hidden states back to
    # vocab logits (weight tying — presumed from the dual use; confirm).
    logits_Bx1xV = self._embedding_layer(dec_Bx1xD, False)
    logits_BxV = tf.squeeze(logits_Bx1xV, axis=1)
    return logits_BxV

  decodes_BxT = decoding.left2right_decode(symbols_to_logits_fn, cache, B, T,
                                           V, beam_size, **beam_kwargs)
  return {"outputs": decodes_BxT}
def test_beam_decode_2(self):
  """Regression test: beam search (size 2, alpha 0.1) on fixed logits.

  The logits are constant and ascending, so greedy would repeat token 3;
  the length-penalized beam is pinned to the known output [[3, 3, 1]].
  """
  def symbols_to_logits_fn(unused_decodes, unused_context, unused_i):
    # Same ascending distribution at every step, scaled up by 3.
    return tf.constant([[1, 2, 3, 4], [1, 2, 3, 4]], tf.float32) * 3

  tf.set_random_seed(0)
  decodes = decoding.left2right_decode(
      symbols_to_logits_fn,
      {},
      1,  # batch_size
      3,  # max_decode_len
      4,  # vocab_size
      beam_size=2,
      beam_alpha=0.1,
      temperature=0.0)
  self.assertAllEqual([[3, 3, 1]], decodes)