def _body(step, finished, inputs, lengths, log_probs, state):
  logits, state = symbols_to_logits_fn(inputs, step, state)
  probs = tf.nn.log_softmax(tf.to_float(logits))
  if min_decode_length > 0:
    probs = tf.cond(
        step < min_decode_length,
        true_fn=lambda: beam_search.penalize_token(probs, end_id),
        false_fn=lambda: probs)

  # Sample best prediction.
  sample_ids = tf.argmax(probs, axis=-1, output_type=inputs.dtype)
  sample_probs = tf.reduce_max(probs, axis=-1)

  # Don't update finished batches.
  masked_lengths_inc = 1 - tf.cast(finished, lengths.dtype)
  masked_probs = sample_probs * (1.0 - tf.cast(finished, sample_probs.dtype))

  step = step + 1
  inputs = tf.concat([inputs, tf.expand_dims(sample_ids, 1)], -1)
  lengths += masked_lengths_inc
  log_probs += masked_probs
  finished = tf.logical_or(finished, tf.equal(sample_ids, end_id))
  if decode_length is not None:
    finished = tf.logical_or(finished, step >= decode_length)
  return step, finished, inputs, lengths, log_probs, state
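For context, a minimal sketch of how this loop body could be driven with tf.while_loop (TF1 style). The batch_size, start_id, and fixed-shape state below are assumptions for illustration, and symbols_to_logits_fn, end_id, min_decode_length, and decode_length are expected to come from the enclosing scope; only inputs needs a relaxed shape invariant since it grows by one token per step.

# Hypothetical driver for the loop body above; all initial values are
# assumed for illustration, not part of the original code.
batch_size = 8
start_id = 1  # assumed start-of-sequence token id
state = tf.zeros([batch_size, 4])  # assumed fixed-shape decoder state

step, finished, inputs, lengths, log_probs, state = tf.while_loop(
    lambda step, finished, *_: tf.logical_not(tf.reduce_all(finished)),
    _body,
    loop_vars=(
        tf.constant(0),
        tf.zeros([batch_size], dtype=tf.bool),
        tf.fill([batch_size, 1], start_id),      # dtype drives sample_ids
        tf.zeros([batch_size], dtype=tf.int32),
        tf.zeros([batch_size], dtype=tf.float32),
        state),
    shape_invariants=(
        tf.TensorShape([]),
        tf.TensorShape([batch_size]),
        tf.TensorShape([batch_size, None]),      # inputs grow each step
        tf.TensorShape([batch_size]),
        tf.TensorShape([batch_size]),
        state.shape))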
def _body(step, finished, inputs, outputs, lengths, cum_log_probs, state):
  # Run next step.
  logits, state = symbols_to_logits_fn(inputs, step, state)
  log_probs = tf.nn.log_softmax(tf.cast(logits, tf.float32))
  if min_decode_length > 0:
    log_probs = tf.cond(
        step < min_decode_length,
        true_fn=lambda: beam_search.penalize_token(log_probs, end_id),
        false_fn=lambda: log_probs)

  # Sample best prediction.
  sampled_log_probs, sampled_ids = tf.nn.top_k(log_probs, k=1)
  sampled_log_probs = tf.squeeze(sampled_log_probs, axis=1)
  sampled_ids = tf.squeeze(sampled_ids, axis=1)
  outputs = outputs.write(step, sampled_ids)

  # Don't update finished batches.
  lengths += 1 - tf.cast(finished, lengths.dtype)
  cum_log_probs += sampled_log_probs * (
      1.0 - tf.cast(finished, sampled_log_probs.dtype))

  finished = tf.logical_or(finished, tf.equal(sampled_ids, end_id))
  if last_step_as_input:
    next_inputs = sampled_ids
  else:
    next_inputs = tf.concat(
        [inputs, tf.expand_dims(sampled_ids, 1)], axis=1)
  return step + 1, finished, next_inputs, outputs, lengths, cum_log_probs, state
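A sketch of how the decoded ids might be read back once a tf.while_loop over this body terminates. The TensorArray is written time-major (one [batch] vector per step), so stacking and transposing recovers [batch, time] ids. The driver assumes last_step_as_input=True so inputs keep a fixed [batch] shape and no shape invariants are needed; batch_size, start_id, the state, and the iteration cap are all illustrative.

# Hypothetical driver; initial values are assumptions for illustration.
batch_size = 8
start_id = 1  # assumed start token id
state = tf.zeros([batch_size, 4])  # assumed fixed-shape decoder state
outputs_ta = tf.TensorArray(tf.int32, size=0, dynamic_size=True)

_, _, _, outputs_ta, lengths, cum_log_probs, _ = tf.while_loop(
    lambda step, finished, *_: tf.logical_not(tf.reduce_all(finished)),
    _body,
    loop_vars=(
        tf.constant(0),
        tf.zeros([batch_size], dtype=tf.bool),
        tf.fill([batch_size], start_id),          # last_step_as_input=True
        outputs_ta,
        tf.zeros([batch_size], dtype=tf.int32),
        tf.zeros([batch_size], dtype=tf.float32),
        state),
    maximum_iterations=200)  # assumed hard cap on decoding steps

ids = tf.transpose(outputs_ta.stack())  # [time, batch] -> [batch, time]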
def testPenalizeToken(self):
  log_probs = tf.zeros([4, 6])
  token_id = 1
  log_probs = beam_search.penalize_token(log_probs, token_id)
  log_probs = self.evaluate(log_probs)
  self.assertTrue(np.all(log_probs[:, token_id] < 0))
  non_penalized = np.delete(log_probs, token_id, axis=1)
  self.assertEqual(np.sum(non_penalized), 0)
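For reference, one way penalize_token could be implemented so that this test passes: add a large negative constant to the penalized token's log-probability and leave every other entry untouched. This is an assumed sketch, not necessarily the library's actual implementation.

def penalize_token(log_probs, token_id, penalty=-1e7):
  # Assumed implementation: broadcast a [vocab]-shaped penalty row that is
  # `penalty` at token_id and 0 elsewhere over the [batch, vocab] input.
  vocab_size = tf.shape(log_probs)[-1]
  return log_probs + tf.one_hot(
      token_id, vocab_size, on_value=penalty, off_value=0.0)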
def _body(step, finished, inputs, outputs, lengths, cum_log_probs, state):
  # Run next step.
  logits, state = symbols_to_logits_fn(inputs, step, state)
  logits = tf.cast(logits, tf.float32)
  if sample_temperature != 1:
    logits /= tf.cast(sample_temperature, logits.dtype)
  log_probs = tf.nn.log_softmax(logits)
  if min_decode_length > 0:
    log_probs = tf.cond(
        step < min_decode_length,
        true_fn=lambda: beam_search.penalize_token(log_probs, end_id),
        false_fn=lambda: log_probs)

  if sample_from == 1:
    # Sample best prediction.
    sampled_ids = tf.argmax(log_probs, axis=-1, output_type=inputs.dtype)
  elif sample_from == 0:
    # Sample from the full output distribution.
    distribution = tf.distributions.Categorical(
        probs=tf.exp(log_probs), dtype=inputs.dtype)
    sampled_ids = distribution.sample()
  else:
    # Sample from the top K.
    topk_log_probs, topk_ids = tf.nn.top_k(log_probs, k=sample_from)
    topk_ids = tf.cast(topk_ids, inputs.dtype)
    distribution = tf.distributions.Categorical(
        logits=topk_log_probs, dtype=inputs.dtype)
    topk_sampled_ids = distribution.sample()
    sampled_ids = tf.gather_nd(
        topk_ids, tf.stack([batch_ids, topk_sampled_ids], axis=-1))

  sampled_log_probs = tf.gather_nd(
      log_probs, tf.stack([batch_ids, sampled_ids], axis=-1))
  outputs = outputs.write(step, sampled_ids)

  # Don't update finished batches.
  lengths += 1 - tf.cast(finished, lengths.dtype)
  cum_log_probs += sampled_log_probs * (
      1.0 - tf.cast(finished, sampled_log_probs.dtype))

  finished = tf.logical_or(finished, tf.equal(sampled_ids, end_id))
  if last_step_as_input:
    next_inputs = sampled_ids
  else:
    next_inputs = tf.concat(
        [inputs, tf.expand_dims(sampled_ids, 1)], axis=1)
  return step + 1, finished, next_inputs, outputs, lengths, cum_log_probs, state
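To make the top-K branch concrete, here is a standalone sketch of the same trick: sample an index into the top-K log-probabilities with a Categorical distribution, then map it back to a vocabulary id via tf.gather_nd. The values and shapes are illustrative (TF1 APIs).

import tensorflow as tf

log_probs = tf.log([[0.5, 0.2, 0.2, 0.1],
                    [0.1, 0.6, 0.2, 0.1]])  # [batch=2, vocab=4]
k = 2
topk_log_probs, topk_ids = tf.nn.top_k(log_probs, k=k)  # both [batch, k]
batch_ids = tf.range(tf.shape(log_probs)[0])            # [0, 1]
distribution = tf.distributions.Categorical(logits=topk_log_probs)
topk_sampled_ids = distribution.sample()                # index into top-K, [batch]
sampled_ids = tf.gather_nd(
    topk_ids, tf.stack([batch_ids, topk_sampled_ids], axis=-1))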