Пример #1
0
 def test_top_p(self):
   """Top-p (nucleus) filtering masks tokens outside the kept set.

   Masked positions are expected to hold the float32 minimum, while kept
   positions retain their original log-probability.
   """
   probs = [[0.01, 0.02, 0.3, 0.07, 0.6]]
   log_probs = tf.math.log(tf.constant(probs, dtype=tf.float32))
   lowest = tf.float32.min

   # p=0.8: the expected output keeps only the 0.3 and 0.6 entries.
   kept_80 = decoding.process_logits(log_probs, top_p=0.8)
   want_80 = [[lowest, lowest, math.log(0.3), lowest, math.log(0.6)]]
   self.assertAllClose(want_80, kept_80)

   # p=0.1: the expected output keeps only the single 0.6 entry.
   kept_10 = decoding.process_logits(log_probs, top_p=0.1)
   want_10 = [[lowest, lowest, lowest, lowest, math.log(0.6)]]
   self.assertAllClose(want_10, kept_10)
Пример #2
0
 def test_top_k(self):
     """With k=3 only the three largest logits survive top-k filtering.

     The remaining positions are expected to be set to the float32 minimum.
     """
     lowest = tf.float32.min
     raw_logits = tf.constant([[1, 2, 3, 4, 5]], dtype=tf.float32)
     filtered = decoding.process_logits(raw_logits, top_k=3)
     want = [[lowest, lowest, 3, 4, 5]]
     self.assertAllEqual(want, filtered)
 def logits_loop(i, decode_BxT, logits_BxTxV):
     """Run decode step ``i``: filter that step's logits and pick the argmax.

     Returns its three arguments with ``i`` advanced, which suggests it is a
     ``tf.while_loop`` body -- NOTE(review): confirm against the caller.
     Reads ``batchsize``, ``vocab_size``, ``top_k``, ``top_p``,
     ``temperature``, ``process_logits`` and ``inplace_update_i`` from the
     enclosing scope.

     Args:
       i: Current decode position.
       decode_BxT: Decoded ids so far; presumably [batch, time] -- TODO confirm.
       logits_BxTxV: Precomputed logits; assumed [batch, time, vocab] per the
         naming suffix -- TODO confirm.

     Returns:
       ``(i + 1, sampled_BxT, logits_BxTxV)``. ``sampled_BxT`` is the result
       of ``inplace_update_i`` applied with the argmax of the processed logits
       at position ``i`` (presumably writing column ``i``). ``logits_BxTxV``
       is passed through unchanged.
     """
     # Take step i for every batch row. Indexing ``[0][i]`` would select batch
     # row 0 only, leaving vocab_size elements, so the reshape to
     # [batchsize, vocab_size] would fail whenever batchsize > 1. For
     # batchsize == 1 the two forms give identical values.
     logits_BxV = tf.reshape(logits_BxTxV[:, i], [batchsize, vocab_size])
     logits_BxV = process_logits(logits_BxV, top_k, top_p, temperature)
     # Greedy choice over the filtered logits (argmax along the vocab axis).
     sampled_BxT = inplace_update_i(decode_BxT, tf.argmax(logits_BxV, -1),
                                    i)
     return i + 1, sampled_BxT, logits_BxTxV