def inner_loop(
    i,
    hit_eos,
    next_id,
    next_id_tag,
    decoded_ids,
    decoded_ids_tag,
    cache,
    log_prob,
):
  """One step of greedy decoding over the token stream and the tag stream.

  Args:
    i: current decoding step index, forwarded to `symbols_to_logits_fn`.
    hit_eos: boolean tensor; OR-ed with `next_id == eos_id` at each step, so
      it stays True once the token stream has emitted `eos_id`.
    next_id: token ids fed into this step.
    next_id_tag: tag ids fed into this step.
    decoded_ids: token ids decoded so far; this step's id is appended on
      axis 1.
    decoded_ids_tag: tag ids decoded so far; this step's tag id is appended on
      axis 1.
    cache: state threaded through `symbols_to_logits_fn`.
    log_prob: accumulated log-probability of the decoded *token* ids (tag
      log-probabilities are not accumulated).

  Returns:
    A tuple with the same structure as the arguments:
    (i + 1, hit_eos, next_id, next_id_tag, decoded_ids, decoded_ids_tag,
     cache, log_prob), where `next_id` / `next_id_tag` are the newly sampled
    ids expanded on axis 1 and the decoded tensors include them.
  """
  logits, logits_tag, cache = symbols_to_logits_fn(
      next_id, next_id_tag, i, cache)
  log_probs = common_layers.log_prob_from_logits(logits)

  # Argmax decoding is expressed as sampling at temperature 0; every other
  # method uses the configured sampling temperature.
  if hparams.sampling_method == 'argmax':
    temperature = 0.0
  else:
    temperature = sampling_temperature

  def _sample(step_logits):
    """Draws ids from `step_logits` using the configured sampling method."""
    if hparams.sampling_method == 'random_per_example':
      return common_layers.sample_temperature_per_example(
          step_logits, temperature, top_k)
    return common_layers.sample_with_temperature(
        step_logits, temperature, top_k)

  # Tokens and tags share one sampling policy (method, temperature, top_k).
  next_id = _sample(logits)
  next_id_tag = _sample(logits_tag)

  # next_id is stacked against tf.range(batch_size), so it is expected to be
  # a vector with one id per batch element at this point.
  log_prob_indices = tf.stack(
      [tf.range(tf.to_int64(batch_size)), next_id], axis=1)
  # Only add this step's log-prob for sequences that had not yet hit EOS
  # before this step (hit_eos is updated afterwards).
  log_prob += tf.gather_nd(
      log_probs, log_prob_indices) * (1 - tf.to_float(hit_eos))
  hit_eos |= tf.equal(next_id, eos_id)

  next_id = tf.expand_dims(next_id, axis=1)
  decoded_ids = tf.concat([decoded_ids, next_id], axis=1)
  next_id_tag = tf.expand_dims(next_id_tag, axis=1)
  decoded_ids_tag = tf.concat([decoded_ids_tag, next_id_tag], axis=1)

  return (
      i + 1,
      hit_eos,
      next_id,
      next_id_tag,
      decoded_ids,
      decoded_ids_tag,
      cache,
      log_prob,
  )
def testSampleTemperaturePerExampleDynamicBatchSize(self):
  """An unknown (None) batch dimension is preserved in the output shape."""
  unknown_batch = None
  num_classes = 7
  logits_ph = tf.placeholder(tf.float32, shape=(unknown_batch, num_classes))
  temperature_ph = tf.placeholder(tf.float32, shape=(unknown_batch, 1))
  top_k_ph = tf.placeholder(tf.int32, shape=(unknown_batch, 1))
  samples = common_layers.sample_temperature_per_example(
      logits_ph, temperature_ph, top_k_ph)
  static_shape = samples.shape.as_list()
  self.assertAllEqual(static_shape, [unknown_batch])
def testSampleTemperaturePerExampleWithTopK2(self):
  """Mixed per-example top_k values (including -1) give one sample per row."""
  n_rows, n_cols = 3, 7
  row_logits = np.random.randn(n_rows, n_cols)
  row_temperatures = np.random.rand(n_rows)
  row_top_k = np.array([3, -1, 4], dtype=np.int32)
  sampled = common_layers.sample_temperature_per_example(
      row_logits, row_temperatures, row_top_k)
  sampled_shape = self.evaluate(tf.shape(sampled))
  self.assertAllEqual(sampled_shape, [n_rows])
def testSampleTemperaturePerExample(self):
  """Rank-5 logits are sampled down to their four leading dimensions."""
  batch, length, vocab = 3, 5, 7
  rank5_logits = np.random.randn(batch, length, 1, 1, vocab)
  per_example_temps = np.random.rand(batch)
  samples = common_layers.sample_temperature_per_example(
      rank5_logits, per_example_temps)
  expected_shape = [batch, length, 1, 1]
  self.assertAllEqual(self.evaluate(tf.shape(samples)), expected_shape)