def SequenceAppendToken(x, x_paddings, token, extend=False):
  """Appends <token> to sequence `x`.

  Args:
    x: A sequence of tokens of shape [batch_size, x_len_max].
    x_paddings: The paddings of `x`.
    token: The token to append (of type integer).
    extend: Whether to extend `x` along the length dimension. This must be
      True if any sequence in `x` has length `x_len_max`, otherwise an
      invalid sequence will be emitted.

  Returns:
    A tuple.
      - The new sequence, Tensor of shape [batch_size, x_len_max] (or
        [batch_size, x_len_max + 1] if `extend` is True).
      - The new paddings, Tensor of the same shape as the new sequence.
  """
  batch_size = py_utils.GetShape(x)[0]
  x_len = tf.to_int32(tf.round(tf.reduce_sum(1 - x_paddings, 1)))
  if extend:
    x = tf.pad(x, [[0, 0], [0, 1]])
  # Mask all invalid entries of `x` to 0.
  x *= tf.sequence_mask(x_len, py_utils.GetShape(x)[1], x.dtype)
  # Append the <token> based on `x_len`.
  x += tf.scatter_nd(
      tf.stack([tf.range(batch_size), x_len], axis=1),
      tf.cast(tf.fill([batch_size], token), x.dtype), py_utils.GetShape(x))
  x_paddings = 1 - tf.sequence_mask(x_len + 1,
                                    py_utils.GetShape(x)[1],
                                    x_paddings.dtype)
  return x, x_paddings
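# A minimal, self-contained sketch of the scatter_nd append trick above,
# using plain TF ops instead of py_utils (the example tensors and `eos`
# value are hypothetical).
import tensorflow as tf

x = tf.constant([[7, 8, 0], [5, 0, 0]])                 # [batch=2, x_len_max=3]
x_paddings = tf.constant([[0., 0., 1.], [0., 1., 1.]])
eos = 2
x_len = tf.cast(tf.round(tf.reduce_sum(1 - x_paddings, 1)), tf.int32)  # [2, 1]
x = tf.pad(x, [[0, 0], [0, 1]])                         # the extend=True case.
# Add `eos` at position x_len[b] of each row b.
x += tf.scatter_nd(
    tf.stack([tf.range(2), x_len], axis=1), tf.fill([2], eos), tf.shape(x))
# x -> [[7, 8, 2, 0], [5, 2, 0, 0]]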
def Step(recurrent_theta, state0, inputs):
  """Computes one decoder step."""
  del inputs
  with tf.name_scope('single_sampler_step'):
    # Compute logits and states.
    bs_result, bs_state1 = pre_step_callback(
        recurrent_theta.theta,
        recurrent_theta.encoder_outputs,
        tf.expand_dims(state0.ids, 1),  # [batch, 1].
        state0.bs_state,
        num_hyps_per_beam=1)
    batch = tf.shape(bs_result.log_probs)[0]
    state1 = py_utils.NestedMap(timestep=state0.timestep + 1)
    state1.logits = bs_result.log_probs
    # Sample ids from logits. [batch].
    state1.ids = tf.reshape(
        tf.random.stateless_multinomial(
            state1.logits / p.temperature,
            num_samples=1,
            seed=tf.stack([recurrent_theta.random_seed, state0.timestep]),
            output_dtype=state0.ids.dtype,
            name='sample_next_id'), [batch])
    if 'is_last_chunk' in bs_result and p.target_eoc_id >= 0:
      state1.ids = tf.where(
          tf.logical_and(bs_result.is_last_chunk,
                         tf.equal(state1.ids, p.target_eoc_id)),
          tf.fill(tf.shape(state1.ids), p.target_eos_id), state1.ids)
    state1.bs_state = post_step_callback(recurrent_theta.theta,
                                         recurrent_theta.encoder_outputs,
                                         state1.ids, bs_state1)
  return state1, py_utils.NestedMap()
def _KeepTopP(sorted_log_probs, p):
  """Keeps the top-p probability mass of `sorted_log_probs`.

  For each row, elements that are not included in the first `p` probability
  mass are set to `LARGE_NEGATIVE_NUMBER`. The first element is always kept
  as-is.

  Args:
    sorted_log_probs: A float tensor of shape [batch, k] that represents
      log-probabilities sorted in descending order. The probabilities do not
      need to sum to 1.
    p: A float tensor of shape [batch] that represents a probability threshold
      for each batch item.

  Returns:
    A tensor like `sorted_log_probs` where elements outside the top-p
    probability mass are set to `LARGE_NEGATIVE_NUMBER`.
  """
  sorted_cum_probs = tf.math.cumsum(
      tf.exp(sorted_log_probs), exclusive=True, axis=-1)
  mask = tf.less(sorted_cum_probs, tf.expand_dims(p, axis=1))
  # Set mask[:, 0] = True to always keep the first element.
  batch_size = tf.shape(mask)[0]
  true = tf.ones([batch_size, 1], dtype=tf.bool)
  mask = tf.concat([true, mask[:, 1:]], axis=1)
  filtered_sorted_log_probs = tf.where(
      mask, sorted_log_probs,
      tf.fill(
          tf.shape(sorted_log_probs),
          tf.constant(LARGE_NEGATIVE_NUMBER, dtype=sorted_log_probs.dtype)))
  return filtered_sorted_log_probs
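# A small worked example of the masking above (a sketch; the numbers are
# hypothetical, and LARGE_NEGATIVE_NUMBER stands in for the module-level
# constant). With sorted probs [0.5, 0.3, 0.2] and p = 0.6, the exclusive
# cumsum is [0.0, 0.5, 0.8], so the mask keeps the first two entries
# (0.0 < 0.6 and 0.5 < 0.6) and drops the third (0.8 >= 0.6).
import tensorflow as tf

LARGE_NEGATIVE_NUMBER = -1e9
sorted_log_probs = tf.math.log(tf.constant([[0.5, 0.3, 0.2]]))
p = tf.constant([0.6])
cum = tf.math.cumsum(tf.exp(sorted_log_probs), exclusive=True, axis=-1)
mask = tf.less(cum, tf.expand_dims(p, axis=1))  # [[True, True, False]]
filtered = tf.where(
    mask, sorted_log_probs,
    tf.fill(tf.shape(sorted_log_probs), LARGE_NEGATIVE_NUMBER))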
def testBeamSearchDecodeBiased(self, bias, bias_only_if_consistent):
  dtype = tf.float32
  with self.session(use_gpu=True) as sess, self.SetEval(True):
    tf.random.set_seed(_TF_RANDOM_SEED)
    src_batch = 2
    p = self._DecoderParams(dtype=dtype)
    p.bias_only_if_consistent = bias_only_if_consistent
    p.target_seq_len = 6
    p.beam_search.num_hyps_per_beam = 2
    p.rnn_cell_dim = 32
    dec = p.Instantiate()
    encoder_outputs, _ = self._Inputs(dtype=dtype)
    encoder_outputs['targets'] = py_utils.NestedMap(
        labels=tf.constant([[1, 3, 0, 0], [3, 4, 5, 2]]),
        paddings=tf.constant([[0, 0, 1, 1], [0, 0, 0, 0]], dtype=dtype))
    encoder_outputs['targets']['weights'] = tf.fill(
        tf.shape(encoder_outputs.targets.labels), bias)
    decode = dec.BeamSearchDecodeBiased(encoder_outputs)

    # topk_decoded is None in MT decoder, set it to a fake tensor to pass
    # sess.run(decode).
    decode = decode._replace(topk_decoded=tf.constant(0, tf.float32))

    tf.global_variables_initializer().run()
    actual_decode = sess.run(decode)

    num_hyps = src_batch * p.beam_search.num_hyps_per_beam
    self.assertTupleEqual((p.target_seq_len, num_hyps),
                          actual_decode.done_hyps.shape)
    self.assertTupleEqual((src_batch, p.beam_search.num_hyps_per_beam),
                          actual_decode.topk_hyps.shape)
    self.assertTupleEqual((num_hyps, p.target_seq_len),
                          actual_decode.topk_ids.shape)
    self.assertTupleEqual((num_hyps,), actual_decode.topk_lens.shape)
    self.assertTupleEqual((src_batch, p.beam_search.num_hyps_per_beam),
                          actual_decode.topk_scores.shape)

    if bias == 0:
      expected_topk_ids = [[2, 0, 0, 0, 0, 0], [13, 2, 0, 0, 0, 0],
                           [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0]]
      expected_topk_lens = [1, 2, 0, 0]
      expected_topk_scores = [[-3.783162, -5.767723], [0., 0.]]
    elif bias == 1 and bias_only_if_consistent:
      expected_topk_ids = [[1, 3, 2, 0, 0, 0], [1, 3, 13, 2, 0, 0],
                           [3, 4, 5, 2, 0, 0], [0, 0, 0, 0, 0, 0]]
      expected_topk_lens = [3, 4, 4, 0]
      expected_topk_scores = [[-3.073836, -5.474799], [-0.415888, 0.]]
    elif bias == 1 and (not bias_only_if_consistent):
      expected_topk_ids = [[1, 3, 2, 0, 0, 0], [1, 3, 13, 2, 0, 0],
                           [3, 4, 5, 2, 0, 0], [3, 4, 0, 2, 0, 0]]
      expected_topk_lens = [3, 4, 4, 4]
      expected_topk_scores = [[-3.073836, -5.474799],
                              [-0.415888, -23.295631]]

    self.assertAllEqual(expected_topk_ids, actual_decode.topk_ids)
    self.assertAllEqual(expected_topk_lens, actual_decode.topk_lens)
    self.assertAllClose(expected_topk_scores, actual_decode.topk_scores)
def testMassLayer(self):
  with self.session(use_gpu=False) as sess:
    batch_size = 3
    seq_len = 10
    p = self._MassParams()
    mass_layer = data_augmenter.MASS(p)
    seq_ids = tf.fill([batch_size, seq_len], 4)
    weights = tf.ones([batch_size, seq_len])
    actual_seq_len = tf.fill([batch_size], 10)
    mass_out = mass_layer.Mask(seq_ids, weights, actual_seq_len)
    (src_ids, tgt_ids, tgt_labels, tgt_weights) = sess.run([
        mass_out.src.ids, mass_out.tgt.ids, mass_out.tgt.labels,
        mass_out.tgt.weights
    ])
    self.assertAllEqual(np.sum(src_ids == 3, axis=1), [5, 5, 5])
    self.assertAllEqual(np.sum(tgt_ids == 3, axis=1), [5, 5, 5])
    self.assertAllEqual(tgt_labels,
                        4 * np.ones([batch_size, seq_len], dtype=np.int32))
    self.assertAllEqual(np.sum(tgt_weights, axis=1), [5., 5., 5.])
def Step(recurrent_theta, state0, inputs):
  """Computes one decoder step."""
  if p.use_recurrent:
    del inputs
  with tf.name_scope('single_sampler_step'):
    # Compute logits and states.
    bs_result, bs_state1 = pre_step_callback(
        decoder_theta,
        recurrent_theta.encoder_outputs,
        tf.expand_dims(state0.ids, 1),  # [batch, 1].
        state0.bs_state,
        num_hyps_per_beam=p.num_hyps_per_beam)
    batch = tf.shape(bs_result.log_probs)[0]
    state1 = py_utils.NestedMap(timestep=state0.timestep + 1)
    state1.logits = bs_result.log_probs

    if p.top_k > 0:
      topk_logits, topk_ids = tf.math.top_k(state1.logits, k=p.top_k)
      sample_logits = tf.nn.log_softmax(
          topk_logits) if p.top_k_renormalize else topk_logits
    else:
      sample_logits = state1.logits

    # Sample ids from logits. [batch].
    ids = tf.reshape(
        tf.random.stateless_categorical(
            sample_logits / p.temperature,
            num_samples=1,
            seed=tf.stack([recurrent_theta.random_seed, state0.timestep]),
            dtype=state0.ids.dtype,
            name='sample_next_id'), [batch])
    state1.ids = tf.gather(
        topk_ids, ids, axis=1, batch_dims=1) if p.top_k > 0 else ids

    if 'is_last_chunk' in bs_result and p.target_eoc_id >= 0:
      state1.ids = tf.where(
          tf.math.logical_and(bs_result.is_last_chunk,
                              tf.equal(state1.ids, p.target_eoc_id)),
          tf.fill(tf.shape(state1.ids), p.target_eos_id), state1.ids)
    state1.bs_state = post_step_callback(decoder_theta,
                                         recurrent_theta.encoder_outputs,
                                         state1.ids, bs_state1)
  if p.use_recurrent:
    return state1, py_utils.NestedMap()
  else:
    inputs.ids = inputs.ids.write(state0.timestep, state1.ids)
    inputs.logits = inputs.logits.write(state0.timestep, state1.logits)
    return (recurrent_theta, state1, inputs)
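# A standalone sketch of the top-k sampling path above (illustrative only;
# `logits`, the seed, `temperature`, and `top_k` are hypothetical stand-ins
# for the layer params). Note the sampled index points into the top-k slots,
# so it is mapped back to vocab ids via `topk_ids`.
import tensorflow as tf

logits = tf.math.log(tf.constant([[0.05, 0.7, 0.05, 0.2]]))
temperature, top_k = 1.0, 2
topk_logits, topk_ids = tf.math.top_k(logits, k=top_k)  # keeps vocab ids 1, 3.
sample_logits = tf.nn.log_softmax(topk_logits)          # the renormalize case.
ids = tf.random.stateless_categorical(
    sample_logits / temperature, num_samples=1, seed=[1, 2], dtype=tf.int32)
sampled = tf.gather(topk_ids, tf.reshape(ids, [-1]), axis=1, batch_dims=1)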
def forward(inputs, alpha):
  with tf.name_scope("entmax_loss"):
    alpha_shape = inputs.get_shape().as_list()
    alpha_shape[axis] = 1
    alpha = tf.fill(alpha_shape, alpha)
    alpha = tf.cast(alpha, dtype=inputs.dtype)

    d = inputs.get_shape().as_list()[axis]
    alpha_m1 = alpha - 1.0

    inputs = inputs * alpha_m1

    max_val = tf.math.reduce_max(inputs, axis=axis, keepdims=True)
    tau_lo = max_val - tf.ones(alpha.get_shape().as_list(), dtype=inputs.dtype)
    tau_hi = max_val - tf.math.pow(
        tf.cast((1.0 / d), dtype=inputs.dtype), alpha_m1)

    f_lo = tf.math.reduce_sum(
        _calculate_probability(tf.math.subtract(inputs, tau_lo), alpha),
        axis) - 1.0

    dm = tau_hi - tau_lo
    for _ in range(n_iter):
      dm /= 2
      tau_m = tau_lo + dm
      p_m = _calculate_probability(inputs - tau_m, alpha)
      f_m = tf.math.reduce_sum(p_m, axis) - 1.0
      mask = tf.expand_dims(tf.math.greater(f_m * f_lo, 0), axis)
      tau_lo = tf.where(mask, tau_m, tau_lo)

    if ensure_sum_one:
      p_m /= tf.expand_dims(tf.math.reduce_sum(p_m, axis), axis)

  def grad_fn(d_outputs):
    with tf.name_scope("entmax_grad"):
      gppr = tf.where(p_m > 0, tf.math.pow(p_m, 2.0 - alpha),
                      tf.zeros_like(p_m))
      d_inputs = d_outputs * gppr
      q = tf.math.reduce_sum(d_inputs, axis) / tf.math.reduce_sum(gppr, axis)
      q = tf.expand_dims(q, axis)
      d_inputs -= q * gppr
      return d_inputs, d_inputs

  return p_m, grad_fn
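# A hedged, self-contained NumPy sketch of the bisection above for a single
# vector. It assumes _calculate_probability(t, alpha) is the clipped power
# function max(t, 0) ** (1 / (alpha - 1)), which is not shown in this file;
# the bound choices guarantee f(tau_lo) >= 0 >= f(tau_hi), so bisection on
# the root of f(tau) = sum(p(z - tau)) - 1 converges to the entmax threshold.
import numpy as np

def entmax_bisect_sketch(z, alpha=1.5, n_iter=30):
  z = z * (alpha - 1.0)
  prob = lambda t: np.maximum(t, 0.0) ** (1.0 / (alpha - 1.0))
  tau_lo = z.max() - 1.0
  tau_hi = z.max() - (1.0 / z.size) ** (alpha - 1.0)
  f_lo = prob(z - tau_lo).sum() - 1.0
  dm = tau_hi - tau_lo
  for _ in range(n_iter):
    dm /= 2.0
    tau_m = tau_lo + dm
    p_m = prob(z - tau_m)
    if (p_m.sum() - 1.0) * f_lo > 0:  # Same sign: move the lower bound up.
      tau_lo = tau_m
  return p_m / p_m.sum()  # ensure_sum_one analogue.

print(entmax_bisect_sketch(np.array([2.0, 1.0, -1.0])))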
def PreBeamSearchStepCallback(unused_theta, unused_encoder_outputs,
                              unused_step_ids, states,
                              unused_num_hyps_per_beam):
  # Same probs for each id.
  logits = tf.zeros([tgt_batch_size, vocab_size])
  # Except eoc has slightly lower score.
  logits = logits - 1.0 * tf.expand_dims(
      tf.one_hot(p.target_eoc_id, vocab_size), 0)
  # eos has very low score (can not terminate by eos).
  logits = logits + eos_score * tf.expand_dims(
      tf.one_hot(p.target_eos_id, vocab_size), 0)
  return py_utils.NestedMap(
      atten_probs=tf.zeros([tgt_batch_size, 0]),
      log_probs=logits,
      is_last_chunk=tf.fill([tgt_batch_size], value=is_last_chunk)), states
def FillPaddingPos(ids: tf.Tensor, id_len: tf.Tensor,
                   padding_value: int) -> tf.Tensor:
  """Given a batch of sequences, fills the padding positions with `padding_value`.

  Args:
    ids: a [B, max_len] int tensor.
    id_len: a [B,] int tensor.
    padding_value: an int.

  Returns:
    new_ids: new ids with the following properties:
      - new_ids[b, :id_len[b]] = ids[b, :id_len[b]]
      - new_ids[b, id_len[b]:] = padding_value
  """
  mask = py_utils.SequencePaddings(id_len, maxlen=tf.shape(ids)[1])
  mask = tf.cast(mask, dtype=tf.bool)
  new_ids = tf.where(mask, tf.fill(tf.shape(ids), padding_value), ids)
  return new_ids
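# A minimal usage sketch of FillPaddingPos (the example tensors are
# hypothetical; this relies on py_utils.SequencePaddings just as the
# function itself does).
ids = tf.constant([[3, 7, 9], [4, 6, 8]])
id_len = tf.constant([2, 1])
new_ids = FillPaddingPos(ids, id_len, padding_value=-1)
# new_ids -> [[3, 7, -1], [4, -1, -1]]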
def _InitIterator(self):
  """Override of the root's _InitIterator to support dataset repeat."""
  if self.host_id in self._dataset:
    return
  p = self.params
  self._repeat_steps = getattr(self._input_generator.params, 'repeat_steps',
                               None)
  self._repeat_with_sentinel = getattr(self._input_generator.params,
                                       'repeat_with_sentinel', None)
  with py_utils.GlobalStepContext(None):
    # Hide global_step tensor from being captured by dataset function.
    ds = self.GetDataset()
  if self._repeat_steps:
    tf.logging.info('Repeating dataset every %d steps.', self._repeat_steps)
    ds = ds.take(self._repeat_steps).repeat()
  elif self._repeat_with_sentinel:
    tf.logging.info('Attaching sentinel to end of dataset and repeat.')
    # Dataset should contain batches of type NestedMap.
    sentinel_batch = ds.element_spec.Transform(
        lambda x: tf.zeros(x.shape, dtype=x.dtype))
    # Fill the dummy sentinel batch's sentinel_key tensor with
    # sentinel_value.
    sentinel_batch[p.sentinel_key] = tf.fill(
        sentinel_batch[p.sentinel_key].shape, p.sentinel_value)
    tf.logging.info('attaching sentinel %r', sentinel_batch[p.sentinel_key])
    tf.logging.info('sentinel type %r', sentinel_batch[p.sentinel_key].dtype)
    ds = ds.concatenate(
        tf.data.Dataset.from_tensors(sentinel_batch)).repeat()
  options = tf.data.Options()
  options.experimental_deterministic = bool(self.cluster.in_unit_test)
  ds = ds.with_options(options)
  self._dataset[self.host_id] = ds
  if tf.executing_eagerly_outside_functions():
    it = iter(ds)
  else:
    it = tf.data.make_initializable_iterator(ds)
  self._iterator[self.host_id] = it
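# A self-contained sketch of the sentinel-repeat pattern above using plain
# tf.data (the dataset contents and the 'eof' key are hypothetical).
import tensorflow as tf

ds = tf.data.Dataset.from_tensor_slices({'x': [1, 2, 3], 'eof': [0, 0, 0]})
sentinel = tf.data.Dataset.from_tensors({'x': 0, 'eof': 1})
# Each pass over the data now ends with a batch whose 'eof' is 1, so a
# consumer can detect epoch boundaries even though the stream never ends.
ds = ds.concatenate(sentinel).repeat()
for batch in ds.take(8):
  print(batch['x'].numpy(), batch['eof'].numpy())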
def MergeBeamSearchOutputs(max_hyps_per_beam, beam_search_outputs):
  """Merges beam search hyps from multiple decoders.

  Args:
    max_hyps_per_beam: the number of top hyps in the merged results. Must be
      less than or equal to total number of input hyps.
    beam_search_outputs: a list of BeamSearchDecodeOutput objects. Must share
      the same source_batch and max sequence length.

  Returns:
    A BeamSearchDecodeOutput object containing max_hyps_per_beam hypotheses
    per beam.
  """
  source_batch = tf.shape(beam_search_outputs[0].topk_hyps)[0]
  value_dict = {}
  for output in beam_search_outputs:
    hyps_per_beam = py_utils.with_dependencies([
        py_utils.assert_equal(source_batch,
                              tf.shape(output.topk_hyps)[0]),
    ], tf.shape(output.topk_hyps)[1])
    for k, v in six.iteritems(output._asdict()):
      if v is None:
        continue
      if k == 'done_hyps':
        v = tf.transpose(v)
      if k not in value_dict:
        value_dict[k] = []
      value_dict[k].append(tf.reshape(v, [source_batch, hyps_per_beam, -1]))

  # Concatenate the tensors along the 'num_hyps_per_beam' dimension.
  concatenated = {}
  for k, values in six.iteritems(value_dict):
    if len(values) != len(beam_search_outputs):
      raise ValueError('Incomplete values for %s: %s' %
                       (k, beam_search_outputs))
    concatenated[k] = tf.concat(values, axis=1)

  scores = concatenated['topk_scores']
  scores = tf.where(
      tf.equal(concatenated['topk_lens'], 0),
      tf.fill(tf.shape(scores), -1e6), scores)
  scores = tf.squeeze(scores, -1)

  # Select top max_hyps_per_beam indices per beam.
  _, top_indices = tf.nn.top_k(scores, max_hyps_per_beam)
  batch_ids = tf.tile(
      tf.expand_dims(tf.range(source_batch), -1), [1, max_hyps_per_beam])
  # [source_batch, max_hyps_per_beam, 2]
  gather_indices = tf.stack([batch_ids, top_indices], axis=-1)

  # Gather the merged top hyps according to 'gather_indices'.
  top = beam_search_outputs[0]._asdict()
  total_hyps = source_batch * max_hyps_per_beam
  for k, v in six.iteritems(concatenated):
    v = tf.gather_nd(v, gather_indices)
    if k == 'done_hyps':
      v = tf.transpose(tf.reshape(v, [total_hyps, -1]))
    elif k == 'topk_hyps':
      v = tf.reshape(v, [source_batch, max_hyps_per_beam])
    elif k == 'topk_ids':
      v = tf.reshape(v, [total_hyps, -1])
    elif k in ('topk_lens', 'topk_scores', 'topk_decoded'):
      v = tf.reshape(v, [total_hyps])
    else:
      raise ValueError('Unexpected field: %s' % k)
    top[k] = v
  return BeamSearchDecodeOutput(**top)
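# A small sketch of the top-k-then-gather_nd selection pattern used above
# to pick the best hyps per beam after concatenation (all tensors here are
# hypothetical).
import tensorflow as tf

scores = tf.constant([[0.1, 0.9, 0.5],
                      [0.7, 0.2, 0.4]])           # [source_batch=2, hyps=3]
_, top_indices = tf.nn.top_k(scores, k=2)          # [[1, 2], [0, 2]]
batch_ids = tf.tile(tf.expand_dims(tf.range(2), -1), [1, 2])
gather_indices = tf.stack([batch_ids, top_indices], axis=-1)  # [2, 2, 2]
values = tf.constant([[10, 11, 12], [20, 21, 22]])
top_values = tf.gather_nd(values, gather_indices)  # [[11, 12], [20, 22]]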
def Sample(self, decoder_theta, encoder_outputs, random_seed,
           init_state_callback, pre_step_callback, post_step_callback):
  """Samples target sequences, one target sequence per source sequence.

  (Please see beam_search_helper.py for description of decoder callbacks.)

  Args:
    decoder_theta: A NestedMap object containing weights' values of the
      decoder layer and its children layers, to be passed to decoder
      callbacks.
    encoder_outputs: the outputs of the encoder, to be passed to callbacks.
    random_seed: a scalar int32 tensor representing the random seed.
    init_state_callback: decoder._InitBeamSearchStateCallback.
    pre_step_callback: decoder._PreBeamSearchStepCallback.
    post_step_callback: decoder._PostBeamSearchStepCallback.

  Returns:
    A NestedMap containing the following tensors:

    - 'logits': [batch, max_target_length, vocab_size], representing the
      distribution from which target sequences are sampled.
    - 'ids': [batch, max_target_length] of int32, representing the target
      sequence ids, not including target_sos_id, but maybe ending with
      target_eos_id if end-of-sequence is reached before target_seq_len.
    - 'paddings': [batch, max_target_length] of 0/1, where 1 represents a
      padded timestep.
  """
  p = self.params
  assert p.temperature > 0
  # 'recurrent_theta' represents all cross-timestep information used by the
  # recurrent loop below, including layer theta and encoder outputs.
  recurrent_theta = py_utils.NestedMap(
      theta=decoder_theta,
      random_seed=random_seed,
      encoder_outputs=encoder_outputs)
  bs_result, bs_state = init_state_callback(
      recurrent_theta.theta, encoder_outputs, num_hyps_per_beam=1)
  batch = tf.shape(bs_result.log_probs)[0]
  recurrent_state0 = py_utils.NestedMap(
      timestep=tf.zeros(shape=[], dtype=tf.int32),
      logits=bs_result.log_probs,
      # Start with target_sos_id.
      ids=tf.fill([batch], tf.to_int32(p.target_sos_id)),
      bs_state=bs_state)
  inputs = py_utils.NestedMap(dummy=tf.zeros([p.target_seq_len, batch]))

  def Step(recurrent_theta, state0, inputs):
    """Computes one decoder step."""
    del inputs
    with tf.name_scope('single_sampler_step'):
      # Compute logits and states.
      bs_result, bs_state1 = pre_step_callback(
          recurrent_theta.theta,
          recurrent_theta.encoder_outputs,
          tf.expand_dims(state0.ids, 1),  # [batch, 1].
          state0.bs_state,
          num_hyps_per_beam=1)
      batch = tf.shape(bs_result.log_probs)[0]
      state1 = py_utils.NestedMap(timestep=state0.timestep + 1)
      state1.logits = bs_result.log_probs
      # Sample ids from logits. [batch].
      state1.ids = tf.reshape(
          tf.random.stateless_multinomial(
              state1.logits / p.temperature,
              num_samples=1,
              seed=tf.stack([recurrent_theta.random_seed, state0.timestep]),
              output_dtype=state0.ids.dtype,
              name='sample_next_id'), [batch])
      if 'is_last_chunk' in bs_result and p.target_eoc_id >= 0:
        state1.ids = tf.where(
            tf.logical_and(bs_result.is_last_chunk,
                           tf.equal(state1.ids, p.target_eoc_id)),
            tf.fill(tf.shape(state1.ids), p.target_eos_id), state1.ids)
      state1.bs_state = post_step_callback(recurrent_theta.theta,
                                           recurrent_theta.encoder_outputs,
                                           state1.ids, bs_state1)
    return state1, py_utils.NestedMap()

  accumulated_states, _ = recurrent.Recurrent(recurrent_theta,
                                              recurrent_state0, inputs, Step)
  result = py_utils.NestedMap(
      logits=tf.transpose(accumulated_states.logits, [1, 0, 2]),
      ids=tf.transpose(accumulated_states.ids))
  result.paddings = tf.cast(
      _ComputePaddings(result.ids, p.target_eos_id), result.logits.dtype)
  # Force ids to be eos_id if the timestep is padded.
  result.ids = tf.where(
      tf.equal(result.paddings, 0), result.ids,
      tf.fill(tf.shape(result.ids), p.target_eos_id))
  static_batch_size = bs_result.log_probs.shape[0]
  result.ids.set_shape([static_batch_size, p.target_seq_len])
  result.paddings.set_shape([static_batch_size, p.target_seq_len])
  return result
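# A sketch of how paddings can be derived from sampled ids and an eos id
# (illustrative; the real _ComputePaddings helper is defined elsewhere in
# the module). A timestep is padded iff some earlier timestep already
# emitted eos.
import tensorflow as tf

ids = tf.constant([[5, 2, 7, 7],    # eos=2 at t=1 -> t>=2 padded.
                   [4, 9, 9, 9]])   # no eos -> nothing padded.
eos_id = 2
is_eos = tf.cast(tf.equal(ids, eos_id), tf.int32)
# The exclusive cumsum is > 0 only strictly after the first eos.
paddings = tf.cast(
    tf.math.cumsum(is_eos, axis=1, exclusive=True) > 0, tf.float32)
# paddings -> [[0., 0., 1., 1.], [0., 0., 0., 0.]]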
def BeamSearchDecode(self,
                     theta,
                     encoder_outputs,
                     num_hyps_per_beam_override=0,
                     init_beam_search_state=None,
                     pre_beam_search_step_callback=None,
                     post_beam_search_step_callback=None,
                     max_steps=None):
  """Performs beam-search based decoding.

  Args:
    theta: A NestedMap object containing weights' values of the decoder layer
      and its children layers.
    encoder_outputs: A NestedMap containing encoder outputs to be passed to
      the callbacks.
    num_hyps_per_beam_override: If set to a value <= 0, this parameter is
      ignored. If set to a value > 0, then this value will be used to
      override `p.num_hyps_per_beam`.
    init_beam_search_state: The `InitBeamSearchState` callback. Please refer
      to the class header comments for more details.
    pre_beam_search_step_callback: The `PreBeamSearchStepCallback` callback.
      Please refer to the class header comments for more details.
    post_beam_search_step_callback: The `PostBeamSearchStepCallback` callback.
      Please refer to the class header comments for more details.
    max_steps: maximum beam search steps. If None, use
      self.params.target_seq_len.

  Returns:
    A `BeamSearchDecodeOutput`.
  """
  p = self.params
  num_hyps_per_beam = p.num_hyps_per_beam
  if num_hyps_per_beam_override > 0:
    num_hyps_per_beam = num_hyps_per_beam_override
  if max_steps is None:
    max_steps = p.target_seq_len

  initial_results, other_states = init_beam_search_state(
      theta, encoder_outputs, num_hyps_per_beam)

  num_hyps = tf.shape(initial_results.log_probs)[0]
  num_beams = num_hyps // num_hyps_per_beam

  if 'step_ids' in initial_results:
    # [num_hyps, 1]
    step_ids = tf.ensure_shape(initial_results.step_ids, [None, 1])
  else:
    step_ids = tf.fill([num_hyps, 1],
                       tf.constant(p.target_sos_id, dtype=tf.int32))

  min_score = -1e36
  best_scores = (tf.zeros(shape=[num_beams], dtype=p.dtype) + min_score)
  cumulative_scores = tf.zeros(shape=[num_hyps], dtype=p.dtype)
  in_scores = tf.zeros([max_steps, num_hyps], dtype=p.dtype)
  in_hyps = tf.zeros([max_steps, num_hyps], dtype=tf.int32)
  in_prev_hyps = tf.zeros([max_steps, num_hyps], dtype=tf.int32)
  in_done_hyps = tf.zeros([max_steps, num_hyps], dtype=tf.string)
  bs_atten_probs = tf.zeros(
      [max_steps, num_hyps,
       tf.shape(initial_results.atten_probs)[1]],
      dtype=p.dtype)
  cur_step = tf.constant(0, dtype=tf.int32)
  all_done = tf.constant(False, dtype=tf.bool)
  core_bs_states = (best_scores, cumulative_scores, in_scores, in_hyps,
                    in_prev_hyps, in_done_hyps, bs_atten_probs)

  def LoopContinue(cur_step, all_done, unused_step_ids, unused_core_bs_states,
                   unused_other_states_list):
    return tf.logical_and(cur_step < max_steps, tf.logical_not(all_done))

  def LoopBody(cur_step, unused_all_done, step_ids, core_bs_states,
               other_states_list):
    (cur_step, all_done, new_step_ids, new_bs_states,
     new_other_states) = self._BeamSearchStep(
         theta, encoder_outputs, cur_step, step_ids, core_bs_states,
         other_states.Pack(other_states_list), num_hyps_per_beam,
         pre_beam_search_step_callback, post_beam_search_step_callback)
    return (cur_step, all_done, new_step_ids, new_bs_states,
            new_other_states.Flatten())

  flat_other_states = other_states.Flatten()
  _, _, _, final_bs_states, flat_final_other_states = tf.while_loop(
      LoopContinue,
      LoopBody,
      loop_vars=(cur_step, all_done, step_ids, core_bs_states,
                 flat_other_states),
      parallel_iterations=10,
      back_prop=False,
      swap_memory=False,
      shape_invariants=(tf.TensorShape(cur_step.get_shape()),
                        tf.TensorShape(all_done.get_shape()),
                        tf.TensorShape(step_ids.get_shape()),
                        _GetShapes(core_bs_states),
                        _GetShapes(flat_other_states, none_shapes=True)))
  # [target_seq_len, num_beams * num_hyps_per_beam].
  final_done_hyps = final_bs_states[5]
  final_other_states = other_states.Pack(flat_final_other_states)

  # TODO(rpang): avoid inspecting 'encoder_outputs'.
  source_paddings = encoder_outputs.padding
  if isinstance(source_paddings, py_utils.NestedMap):
    source_seq_lengths = tf.cast(
        tf.round(
            tf.reduce_sum(1.0 - tf.transpose(source_paddings.Flatten()[0]),
                          1)), tf.int32)
  else:
    source_seq_lengths = tf.cast(
        tf.round(tf.reduce_sum(1.0 - tf.transpose(source_paddings), 1)),
        tf.int32)

  # [num_beams, num_hyps_per_beam].
  topk_hyps = ops.top_k_terminated_hyps(
      final_done_hyps,
      source_seq_lengths,
      k=num_hyps_per_beam,
      num_hyps_per_beam=num_hyps_per_beam,
      length_normalization=p.length_normalization,
      coverage_penalty=p.coverage_penalty,
      target_seq_length_ratio=p.target_seq_length_ratio,
      eoc_id=p.target_eoc_id,
      merge_paths=p.merge_paths)
  # [num_beams * num_hyps_per_beam, ...].
  max_seq_length = 0 if isinstance(max_steps, tf.Tensor) else max_steps
  topk_ids, topk_lens, topk_scores = ops.unpack_hyp(
      tf.reshape(topk_hyps, [-1]), max_seq_length=max_seq_length)
  # [num_beams, num_hyps_per_beam].
  topk_scores = tf.reshape(topk_scores, tf.shape(topk_hyps))

  return BeamSearchDecodeOutput(final_done_hyps, topk_hyps, topk_ids,
                                topk_lens, topk_scores, None,
                                final_other_states)
def FProp(self, theta, x, x_paddings=None, eos_id=1,
          force_sample_last_token=True):
  """Applies SymbolInsertionLayer.

  We take in `x`, which represents the groundtruth sequence (i.e., English
  sequence). We return a sampled rollin (observed) canvas (i.e., random
  subset of the English sequence), as well as the target (indices) for an
  insertion-based model (i.e., the targets given the random observed subset).

  Args:
    theta: Ignored, this can be None.
    x: The symbol ids of shape `[batch_size, time_dim]`.
    x_paddings: The paddings (1 or 0) of shape `[batch_size, time_dim]` where
      0 is valid and 1 is invalid.
    eos_id: The <eos> token id to represent end-of-slot.
    force_sample_last_token: Set True to force sample the last token of `x`.

  Returns:
    A `NestedMap`.

    - canvas: The canvas (based off of the `rollin_policy`) of shape
      [batch_size, c_dim]. Note that `c_dim` <= `time_dim`, but they need
      not be equal.
    - canvas_indices: The canvas indices (into `x`).
    - canvas_paddings: The paddings of `canvas_indices`.
    - target_indices: The target indices of shape [num_targets, 3].
      `num_targets` is the number of total targets in the entire batch.
      [:, 0] captures the batch, [:, 1] captures the slot, and [:, 2]
      captures the token. Each row [batch, slot, vocab] represents the
      indices of the target -- i.e., the batch, slot and vocab combination
      of the target. Typical usage of these indices is to tf.gather_nd the
      log-probs (from the softmax layer).
    - target_weights: The target weights.

  Raises:
    ValueError: If invalid params.
  """
  p = self.params

  batch_size = py_utils.GetShape(x)[0]
  time_dim = py_utils.GetShape(x)[1]

  if x_paddings is None:
    x_paddings = tf.zeros([batch_size, time_dim], tf.float32)

  oracle_policy = p.oracle_policy
  rollin_policy = (
      oracle_policy if p.rollin_policy == 'oracle' else p.rollin_policy)

  if rollin_policy != 'uniform':
    raise ValueError('Unknown or unsupported rollin policy: %s' %
                     rollin_policy)
  if oracle_policy != 'uniform':
    raise ValueError('Unknown or unsupported oracle policy: %s' %
                     oracle_policy)

  x_len = tf.cast(tf.round(tf.reduce_sum(1 - x_paddings, 1)), tf.int32)

  # Compute the desired length per example in the batch.
  ratio = tf.random.uniform([batch_size], 0.0, 1.0, seed=p.random_seed)
  if force_sample_last_token:
    c_len = tf.minimum(
        tf.cast(ratio * tf.cast(x_len, tf.float32), tf.int32), x_len - 1) + 1
  else:
    c_len = tf.minimum(
        tf.cast(ratio * tf.cast(x_len + 1, tf.float32), tf.int32), x_len)
  # Compute the maximum length across the batch.
  c_len_max = tf.reduce_max(c_len)

  # Grab subset of random valid indices per example.
  z_logits = tf.cast(
      tf.expand_dims(tf.range(time_dim), 0) >= tf.expand_dims(x_len, 1),
      tf.float32) * -1e9
  if force_sample_last_token:
    # Force sample the last token -- i.e., as indexed by `x_len - 1`. We can
    # accomplish this by adding +LARGE_NUMBER to the logits.
    z_logits += tf.cast(
        tf.equal(tf.expand_dims(tf.range(time_dim), 0),
                 tf.expand_dims(x_len - 1, 1)), tf.float32) * 1e9
  # Gumbel-max trick to sample (we only sample valid positions per sample in
  # the batch).
  z = -tf.math.log(-tf.math.log(
      tf.random.uniform([batch_size, time_dim], seed=p.random_seed)))
  unused_c_values, c_indices = tf.nn.top_k(z_logits + z, time_dim)

  # Trim everything > c_len_max.
  c_indices = c_indices[:, :c_len_max]

  # Invalidate any indices >= c_len; we use the last index as the default
  # invalid index.
  c_indices = tf.where(
      tf.expand_dims(tf.range(c_len_max), 0) < tf.expand_dims(c_len, 1),
      c_indices, tf.fill(py_utils.GetShape(c_indices), time_dim - 1))

  # Materialize the canvas.
  c_indices = tf.sort(c_indices)
  c = tf.gather_nd(
      x,
      tf.stack([
          tf.reshape(
              tf.tile(tf.expand_dims(tf.range(batch_size), 1),
                      [1, c_len_max]), [-1]),
          tf.reshape(c_indices, [-1])
      ], 1))
  c = tf.reshape(c, [batch_size, c_len_max])

  # Compute the paddings.
  c_paddings = 1 - tf.sequence_mask(c_len, c_len_max, dtype=x_paddings.dtype)
  c *= tf.cast(1 - c_paddings, tf.int32)

  indices = tf.concat([
      tf.reshape(
          tf.tile(tf.expand_dims(tf.range(batch_size), 1), [1, c_len_max]),
          [batch_size * c_len_max, 1]),
      tf.reshape(c_indices, [batch_size * c_len_max, 1])
  ], 1)
  x_token_is_observed = tf.scatter_nd(
      indices, tf.ones([batch_size * c_len_max], tf.int32),
      py_utils.GetShape(x))
  # `x_segments` captures which slot each `x` belongs to (both observed and
  # tokens that need to be observed).
  x_segments = tf.cumsum(x_token_is_observed, 1, exclusive=True)

  x_token_is_observed = tf.cast(x_token_is_observed, tf.bool)
  prev_x_token_is_observed = tf.pad(
      x_token_is_observed[:, :-1], [[0, 0], [1, 0]], constant_values=True)
  x_token_is_observed = tf.reshape(x_token_is_observed, [-1])
  prev_x_token_is_observed = tf.reshape(prev_x_token_is_observed, [-1])
  x_is_valid = tf.cast(1 - x_paddings, tf.bool)
  x_is_valid = tf.reshape(x_is_valid, [-1])

  # Remap all the observed to <eos>; note some of these need a zero weight
  # (or else there would be <eos> and a valid token in the same slot).
  target_indices = tf.cast(tf.reshape(x, [-1, 1]), tf.int32)
  target_indices = tf.where(
      x_token_is_observed,
      tf.fill(py_utils.GetShape(target_indices), eos_id), target_indices)

  # TODO(williamchan): We give uniform 1.0 weight, however, math suggests
  # we may want to weigh this term by the original sequence length.
  target_weights = tf.ones_like(target_indices, tf.float32)

  # We need to set all the weights for <eos> which actually have valid tokens
  # in the slot to zero.
  target_weights = tf.where(
      x_token_is_observed & ~prev_x_token_is_observed,
      tf.zeros_like(target_weights), target_weights)

  # TODO(williamchan): Consider dropping the entries w/ weight zero.

  # Add the batch and slot indices.
  target_indices = tf.concat([
      tf.reshape(
          tf.tile(tf.expand_dims(tf.range(batch_size), 1), [1, time_dim]),
          [batch_size * time_dim, 1]),
      tf.reshape(x_segments, [-1, 1]), target_indices
  ], 1)

  # Select only the valid indices. The selected valid ones include slots w/
  # <eos>.
  target_indices = target_indices[x_is_valid]
  target_weights = target_weights[x_is_valid]

  return py_utils.NestedMap(
      canvas=c,
      canvas_indices=c_indices,
      canvas_paddings=c_paddings,
      target_indices=target_indices,
      target_weights=target_weights)
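# A standalone sketch of the Gumbel-max subset-sampling trick used above:
# adding Gumbel noise to (masked) logits and taking top-k yields a uniform
# random subset of the valid positions (the tensors here are hypothetical).
import tensorflow as tf

x_len = tf.constant([3, 5])  # Valid lengths per example.
time_dim = 5
# -1e9 masks out invalid positions so they are never selected first.
z_logits = tf.cast(
    tf.expand_dims(tf.range(time_dim), 0) >= tf.expand_dims(x_len, 1),
    tf.float32) * -1e9
gumbel = -tf.math.log(-tf.math.log(tf.random.uniform([2, time_dim])))
_, indices = tf.nn.top_k(z_logits + gumbel, time_dim)
# The first x_len[b] entries of indices[b] are a uniform random permutation
# of the valid positions of example b.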
def _AddNoise(self, batch):
  """Adds noise to the src (see https://arxiv.org/pdf/1711.00043).

  This function implements 3 types of noise (hyperparams defined in
  self.params.denoise):
  1) slightly shuffle the sentence following p.shuffle_tok_range
  2) randomly drop tokens with probability p.drop_tok_prob
  3) randomly mask tokens with probability p.blank_tok_prob
  The noises are added to the input with probability p.noise_sent_prob.

  Args:
    batch: a `.NestedMap` of the input batch.
  """

  def IsSpecialExample(task_ids, special_task_ids):
    """A utility function indicating whether inputs belong to specific tasks.

    Args:
      task_ids: Task ids for the input batch. Tensor of shape [batch].
      special_task_ids: A list of specified task ids.

    Returns:
      A tensor of size [batch] indicating whether each sample in the batch
      belongs to the specified tasks.
    """
    batch_size = py_utils.GetShape(task_ids)[0]
    return tf.reduce_any(
        tf.equal(
            tf.expand_dims(task_ids, -1),
            tf.cast(
                tf.broadcast_to(special_task_ids,
                                [batch_size, len(special_task_ids)]),
                tf.int32)), -1)

  p = self.params.denoise
  batch_size = tf.shape(batch.src.ids)[0]
  source_max_len = tf.shape(batch.src.ids)[1]

  # Shuffle tokens according to p.shuffle_tok_range.
  noise = tf.random.uniform([batch_size, source_max_len], 0,
                            p.shuffle_tok_range + 1)

  # Don't shuffle eos or padding.
  shuffle_tok_range = tf.fill([batch_size, source_max_len],
                              float(p.shuffle_tok_range))
  shifted_paddings = tf.pad(
      batch.src.paddings[:, 1:], [[0, 0], [0, 1]], constant_values=1)
  noise = tf.where(tf.equal(shifted_paddings, 0), noise, shuffle_tok_range)
  indices = tf.broadcast_to(
      tf.range(source_max_len, dtype=tf.int32),
      [batch_size, source_max_len])
  noisy_indices = tf.cast(indices, dtype=tf.float32) + noise
  permutations = tf.argsort(noisy_indices)
  stacked = tf.stack([batch.src.ids, permutations], axis=1)
  denoise_src_ids = tf.stack(
      tf.map_fn(lambda x: tf.gather(x[0], x[1]), stacked), axis=0)

  # Select tokens to drop with probability=p.drop_tok_prob.
  random_drop_tok = tf.random.uniform([batch_size, source_max_len])
  # Don't drop eos token.
  is_keep_tok = tf.math.logical_or(
      tf.greater(random_drop_tok, p.drop_tok_prob),
      tf.equal(denoise_src_ids, self._src_tokenizer.eos_id))
  denoise_src_ids = tf.ragged.boolean_mask(
      denoise_src_ids, is_keep_tok).to_tensor(
          default_value=0, shape=tf.shape(batch.src.ids))
  denoise_src_paddings = tf.ragged.boolean_mask(
      batch.src.paddings, is_keep_tok).to_tensor(
          default_value=1, shape=tf.shape(batch.src.ids))

  # Select tokens to blank with probability=p.blank_tok_prob.
  # Don't blank eos token.
  random_blank_tok = tf.random.uniform([batch_size, source_max_len])
  shifted_paddings = tf.pad(
      denoise_src_paddings[:, 1:], [[0, 0], [0, 1]], constant_values=1)
  is_blank_tok = tf.math.logical_and(
      tf.less(random_blank_tok, p.blank_tok_prob),
      tf.equal(shifted_paddings, 0))
  blank_id = tf.fill([batch_size, source_max_len], p.blank_id)
  denoise_src_ids = tf.where(is_blank_tok, blank_id, denoise_src_ids)

  # Select denoising task examples with probability=p.noise_sent_prob.
  random_uniform_sent = tf.random.uniform([batch_size])
  is_denoise_sent = tf.math.logical_and(
      tf.less(random_uniform_sent, p.noise_sent_prob),
      IsSpecialExample(
          self._GetTaskIds(batch.src.source_ids[:, 0]), p.task_ids))
  batch.src.ids = tf.where(is_denoise_sent, denoise_src_ids, batch.src.ids)
  batch.src.paddings = tf.where(is_denoise_sent, denoise_src_paddings,
                                batch.src.paddings)
  batch.src.ids_indicator = 1 - batch.src.paddings
  batch.src.weights = batch.src.ids_indicator
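# A minimal sketch of the noisy-shuffle step above: adding bounded random
# noise to each position index and argsorting yields a local shuffle in
# which no token moves more than shuffle_tok_range positions (the tensors
# here are hypothetical, and padding/eos masking is omitted).
import tensorflow as tf

ids = tf.constant([[10, 11, 12, 13, 14]])
shuffle_tok_range = 3
noise = tf.random.uniform([1, 5], 0., shuffle_tok_range + 1.)
noisy_indices = tf.cast(tf.range(5), tf.float32)[tf.newaxis, :] + noise
permutations = tf.argsort(noisy_indices)
shuffled = tf.gather(ids, permutations, batch_dims=1)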
def GreedySearchDecode(self,
                       theta,
                       encoder_outputs,
                       init_beam_search_state=None,
                       pre_beam_search_step_callback=None,
                       post_beam_search_step_callback=None,
                       max_steps=None):
  """Performs greedy-search based decoding.

  Args:
    theta: A NestedMap object containing weights' values of the decoder layer
      and its children layers.
    encoder_outputs: A NestedMap containing encoder outputs to be passed to
      the callbacks.
    init_beam_search_state: The `InitBeamSearchState` callback. Please refer
      to the class header comments for more details.
    pre_beam_search_step_callback: The `PreBeamSearchStepCallback` callback.
      Please refer to the class header comments for more details.
    post_beam_search_step_callback: The `PostBeamSearchStepCallback` callback.
      Please refer to the class header comments for more details.
    max_steps: maximum beam search steps. If None, use
      self.params.target_seq_len.

  Returns:
    A tuple (hyp_ids, hyp_lens, done_hyps). Note that num_hyps is same as
    src_batch_size.

    - hyp_ids: [num_hyps, max_step]. Hyps end with <eos> token if the <eos>
      token is encountered during search.
    - hyp_lens: [num_hyps].
    - done_hyps: [num_hyps], whether or not an eos is encountered.
  """
  p = self.params
  if max_steps is None:
    max_steps = p.target_seq_len

  initial_results, other_states = init_beam_search_state(
      theta,
      encoder_outputs,
      1  # num_hyps_per_beam
  )

  num_hyps = tf.shape(initial_results.log_probs)[0]

  if 'step_ids' in initial_results:
    # [num_hyps, 1]
    step_ids = tf.ensure_shape(initial_results.step_ids, [None, 1])
  else:
    step_ids = tf.fill([num_hyps, 1],
                       tf.constant(p.target_sos_id, dtype=tf.int32))

  cur_step = tf.constant(0, dtype=tf.int32)
  done_hyps = inplace_ops.empty(
      shape=[num_hyps], dtype=tf.bool, init=True, name='done_hyps')
  hyp_lens = inplace_ops.empty(
      shape=[num_hyps], dtype=tf.int32, init=True, name='hyp_lens')
  hyp_ids = inplace_ops.empty(
      shape=[max_steps, num_hyps], dtype=tf.int32, init=True, name='hyp_ids')

  def LoopContinue(cur_step, unused_step_ids, unused_hyp_ids, unused_hyp_lens,
                   done_hyps, unused_other_states_list):
    return tf.logical_and(cur_step < max_steps,
                          tf.logical_not(tf.reduce_all(done_hyps)))

  def LoopBody(cur_step, step_ids, hyp_ids, hyp_lens, done_hyps,
               other_states_list):
    (cur_step, new_step_ids, hyp_ids, hyp_lens, done_hyps,
     new_other_states) = self._GreedySearchStep(
         theta, encoder_outputs, cur_step, step_ids, hyp_ids, hyp_lens,
         done_hyps, other_states.Pack(other_states_list),
         pre_beam_search_step_callback, post_beam_search_step_callback)
    return (cur_step, new_step_ids, hyp_ids, hyp_lens, done_hyps,
            new_other_states.Flatten())

  flat_other_states = other_states.Flatten()
  _, _, final_hyp_ids, final_hyp_lens, final_done_hyps, _ = tf.while_loop(
      LoopContinue,
      LoopBody,
      loop_vars=(cur_step, step_ids, hyp_ids, hyp_lens, done_hyps,
                 flat_other_states),
      parallel_iterations=10,
      back_prop=False,
      swap_memory=False,
      shape_invariants=(tf.TensorShape(cur_step.get_shape()),
                        tf.TensorShape(step_ids.get_shape()),
                        tf.TensorShape(hyp_ids.get_shape()),
                        tf.TensorShape(hyp_lens.get_shape()),
                        tf.TensorShape(done_hyps.get_shape()),
                        _GetShapes(flat_other_states, none_shapes=True)))

  # Transpose hyp_ids so it matches BeamSearchDecode's output.
  final_hyp_ids = tf.transpose(final_hyp_ids)
  return final_hyp_ids, final_hyp_lens, final_done_hyps
def PerturbedLogProbs():
  # STEP 1: Perform top-k filtering. This is done as a performance
  # optimization to avoid sorting the entire `log_probs`, which is
  # prohibitively slow.
  top_k = tf.math.top_k(log_probs, k, sorted=True)
  # shape: [tgt_batch, k]
  top_k_log_probs = top_k.values
  # shape: [tgt_batch, k]
  top_k_ids = top_k.indices

  # STEP 2: Perform top-p filtering.
  # shape: [tgt_batch]
  top_p_threshold = encoder_outputs.stochastic_beam_search.top_p_threshold
  top_p_threshold = tf.clip_by_value(top_p_threshold, 0., 1.)
  top_p_threshold = TileForBeamAndFlatten(top_p_threshold)
  # shape: [tgt_batch, k]
  filtered_top_k_log_probs = _KeepTopP(top_k_log_probs, top_p_threshold)

  # STEP 3: Perturb cumulative log-probs.
  # shape: [tgt_batch, 1]
  last_cumulative_log_probs = states.cumulative_log_probs
  # shape: [tgt_batch, 1]
  last_perturbed_cumulative_log_probs = states.perturbed_cumulative_log_probs
  # Compute cumulative log-probs of the current step.
  # shape: [tgt_batch, k]
  cumulative_log_probs = (
      last_cumulative_log_probs + filtered_top_k_log_probs)
  # Perturb cumulative log-probs by Gumbel noises under the condition that
  # the max of the new perturbed log-probs is equal to
  # perturbed_cumulative_log_probs of the previous step.
  # shape: [tgt_batch, k]
  new_perturbed_cumulative_log_probs = _SampleGumbelWithMax(
      cumulative_log_probs, last_perturbed_cumulative_log_probs,
      encoder_outputs.stochastic_beam_search.seed, time_step,
      encoder_outputs.stochastic_beam_search.src_ids,
      encoder_outputs.stochastic_beam_search.src_paddings)

  # STEP 4: Compute updated log_probs. This step is necessary because the
  # output of PreBeamSearchStepCallback must be "per-step" log-probs,
  # whereas so far "cumulative" log-probs have been computed.
  # shape: [tgt_batch, k]
  updated_top_k_log_probs = (
      new_perturbed_cumulative_log_probs -
      last_perturbed_cumulative_log_probs)
  # Convert to the shape [tgt_batch, vocab_size].
  updated_log_probs = tf.fill(
      tf.shape(log_probs),
      tf.constant(LARGE_NEGATIVE_NUMBER, dtype=log_probs.dtype))
  updated_log_probs = _BatchScatter(updated_log_probs, top_k_ids,
                                    updated_top_k_log_probs)

  return (updated_log_probs,
          py_utils.NestedMap(
              new_perturbed_cumulative_log_probs=
              new_perturbed_cumulative_log_probs,
              top_k_log_probs=top_k_log_probs,
              top_k_ids=top_k_ids,
          ))
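# A sketch of the "Gumbel with max" trick that _SampleGumbelWithMax is
# presumed to implement, following Kool et al. (2019), "Stochastic Beams and
# Where to Find Them" (the real helper also derives its randomness from the
# seed/src_ids/src_paddings arguments, which this sketch simplifies to a
# stateless seed). Given Gumbels g with max z, shifting them so the max
# equals a target T while preserving relative order uses
#   g' = -log(exp(-T) - exp(-z) + exp(-g)).
# This is the basic (numerically unstabilized) form of the identity.
import tensorflow as tf

def sample_gumbel_with_max(phi, target_max, seed):
  """phi: [batch, k] cumulative log-probs; target_max: [batch, 1]."""
  gumbel = -tf.math.log(-tf.math.log(
      tf.random.stateless_uniform(tf.shape(phi), seed=seed)))
  g = phi + gumbel                                  # Gumbel(phi) samples.
  z = tf.reduce_max(g, axis=-1, keepdims=True)      # Their unconditioned max.
  # Shift so the max equals target_max exactly.
  return -tf.math.log(
      tf.math.exp(-target_max) - tf.math.exp(-z) + tf.math.exp(-g))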
def FProp(self, _, inputs):
  return tf.fill(tf.shape(inputs), 42)
def Step(recurrent_theta, state0, inputs):
  """Computes one decoder step."""
  if p.use_recurrent:
    del inputs
  with tf.name_scope('single_sampler_step'):
    # Compute logits and states.
    bs_result, bs_state1 = pre_step_callback(
        decoder_theta,
        recurrent_theta.encoder_outputs,
        tf.expand_dims(state0.ids, 1),  # [batch, 1].
        state0.bs_state,
        p.num_hyps_per_beam,
        0)  # cur_step
    batch = tf.shape(bs_result.log_probs)[0]
    state1 = py_utils.NestedMap(timestep=state0.timestep + 1)
    state1.logits = bs_result.log_probs
    sample_logits = state1.logits

    # Perform Nucleus Sampling. Assumes logits are in (-1e10, 1e3).
    if p.nucleus_p < 1.0:
      max_logit = 1e3
      min_logit = -1e10
      sorted_logits = tf.sort(sample_logits, direction='DESCENDING', axis=-1)
      sorted_probs = tf.nn.softmax(sorted_logits)
      cumsum_probs = tf.math.cumsum(sorted_probs, axis=-1, exclusive=True)
      masked_logits = tf.where(cumsum_probs < p.nucleus_p, sorted_logits,
                               tf.ones_like(sorted_logits) * max_logit)
      threshold = tf.math.reduce_min(masked_logits, axis=-1, keepdims=True)
      sample_logits = tf.where(sample_logits < threshold,
                               tf.ones_like(sorted_logits) * min_logit,
                               sample_logits)

    # Note that here, we retain the possibility of applying both top_k and
    # nucleus filtering.
    if p.top_k > 0:
      topk_logits, topk_ids = tf.math.top_k(sample_logits, k=p.top_k)
      sample_logits = tf.nn.log_softmax(
          topk_logits) if p.top_k_renormalize else topk_logits

    # Sample ids from logits. [batch].
    ids = tf.reshape(
        tf.random.stateless_categorical(
            sample_logits / p.temperature,
            num_samples=1,
            seed=tf.stack([recurrent_theta.random_seed, state0.timestep]),
            dtype=state0.ids.dtype,
            name='sample_next_id'), [batch])
    state1.ids = tf.gather(
        topk_ids, ids, axis=1, batch_dims=1) if p.top_k > 0 else ids

    if 'is_last_chunk' in bs_result and p.target_eoc_id >= 0:
      state1.ids = tf.where(
          tf.math.logical_and(bs_result.is_last_chunk,
                              tf.equal(state1.ids, p.target_eoc_id)),
          tf.fill(tf.shape(state1.ids), p.target_eos_id), state1.ids)
    state1.bs_state = post_step_callback(decoder_theta,
                                         recurrent_theta.encoder_outputs,
                                         state1.ids, bs_state1)
  if p.use_recurrent:
    return state1, py_utils.NestedMap()
  else:
    inputs.ids = inputs.ids.write(state0.timestep, state1.ids)
    inputs.logits = inputs.logits.write(state0.timestep, state1.logits)
    return (recurrent_theta, state1, inputs)
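# A worked sketch of the nucleus (top-p) filtering step above on a single
# row of hypothetical logits. With probs [0.5, 0.3, 0.15, 0.05] and
# nucleus_p = 0.8, the exclusive cumsum [0.0, 0.5, 0.8, 0.95] keeps only the
# first two tokens; every logit below the threshold is masked to -1e10.
import tensorflow as tf

logits = tf.math.log(tf.constant([[0.5, 0.3, 0.15, 0.05]]))
nucleus_p = 0.8
sorted_logits = tf.sort(logits, direction='DESCENDING', axis=-1)
sorted_probs = tf.nn.softmax(sorted_logits)
cumsum_probs = tf.math.cumsum(sorted_probs, axis=-1, exclusive=True)
masked = tf.where(cumsum_probs < nucleus_p, sorted_logits,
                  tf.ones_like(sorted_logits) * 1e3)
threshold = tf.math.reduce_min(masked, axis=-1, keepdims=True)
filtered = tf.where(logits < threshold,
                    tf.ones_like(logits) * -1e10, logits)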
def Sample(self,
           decoder_theta,
           encoder_outputs,
           random_seed,
           init_state_callback,
           pre_step_callback,
           post_step_callback,
           init_step_ids=None):
  """Samples target sequences, one target sequence per source sequence.

  (Please see beam_search_helper.py for description of decoder callbacks.)

  Args:
    decoder_theta: A NestedMap object containing weights' values of the
      decoder layer and its children layers, to be passed to decoder
      callbacks.
    encoder_outputs: the outputs of the encoder, to be passed to callbacks.
    random_seed: a scalar int32 tensor representing the random seed.
    init_state_callback: decoder._InitBeamSearchStateCallback.
    pre_step_callback: decoder._PreBeamSearchStepCallback.
    post_step_callback: decoder._PostBeamSearchStepCallback.
    init_step_ids: [batch], optional init step ids, default to SOS.

  Returns:
    A NestedMap containing the following tensors:

    - 'logits': [batch, max_target_length, vocab_size], representing the
      distribution from which target sequences are sampled.
    - 'ids': [batch, max_target_length] of int32, representing the target
      sequence ids, not including target_sos_id, but maybe ending with
      target_eos_id if end-of-sequence is reached before target_seq_len.
    - 'paddings': [batch, max_target_length] of 0/1, where 1 represents a
      padded timestep.
  """
  p = self.params
  assert p.temperature > 0
  assert p.top_k >= 0
  assert p.num_hyps_per_beam >= 1
  if getattr(encoder_outputs, 'segment_id', 1) is None:
    # Remove None values, which are not supported by recurrent.
    del encoder_outputs['segment_id']
  # init_state_callback may modify 'encoder_outputs', e.g., by inserting
  # 'packed_src'.
  bs_result, bs_state = init_state_callback(decoder_theta, encoder_outputs,
                                            p.num_hyps_per_beam)
  # 'recurrent_theta' represents all cross-timestep information used by the
  # recurrent loop below, including layer theta and encoder outputs.
  recurrent_theta = py_utils.NestedMap(
      random_seed=random_seed, encoder_outputs=encoder_outputs)
  batch = tf.shape(bs_result.log_probs)[0]
  recurrent_state0 = py_utils.NestedMap(
      timestep=tf.zeros(shape=[], dtype=tf.int32),
      logits=bs_result.log_probs,
      # Start with target_sos_id.
      ids=init_step_ids if init_step_ids is not None else tf.fill(
          [batch], tf.cast(p.target_sos_id, tf.int32)),
      bs_state=bs_state)

  if p.use_recurrent:
    inputs = py_utils.NestedMap(dummy=tf.zeros([p.target_seq_len, batch]))
  else:
    inputs = py_utils.NestedMap(
        ids=tf.TensorArray(dtype=tf.int32, size=p.target_seq_len),
        logits=tf.TensorArray(
            dtype=bs_result.log_probs.dtype, size=p.target_seq_len),
    )

  def Step(recurrent_theta, state0, inputs):
    """Computes one decoder step."""
    if p.use_recurrent:
      del inputs
    with tf.name_scope('single_sampler_step'):
      # Compute logits and states.
      bs_result, bs_state1 = pre_step_callback(
          decoder_theta,
          recurrent_theta.encoder_outputs,
          tf.expand_dims(state0.ids, 1),  # [batch, 1].
          state0.bs_state,
          num_hyps_per_beam=p.num_hyps_per_beam)
      batch = tf.shape(bs_result.log_probs)[0]
      state1 = py_utils.NestedMap(timestep=state0.timestep + 1)
      state1.logits = bs_result.log_probs

      if p.top_k > 0:
        topk_logits, topk_ids = tf.math.top_k(state1.logits, k=p.top_k)
        sample_logits = tf.nn.log_softmax(
            topk_logits) if p.top_k_renormalize else topk_logits
      else:
        sample_logits = state1.logits

      # Sample ids from logits. [batch].
      ids = tf.reshape(
          tf.random.stateless_categorical(
              sample_logits / p.temperature,
              num_samples=1,
              seed=tf.stack([recurrent_theta.random_seed, state0.timestep]),
              dtype=state0.ids.dtype,
              name='sample_next_id'), [batch])
      state1.ids = tf.gather(
          topk_ids, ids, axis=1, batch_dims=1) if p.top_k > 0 else ids

      if 'is_last_chunk' in bs_result and p.target_eoc_id >= 0:
        state1.ids = tf.where(
            tf.math.logical_and(bs_result.is_last_chunk,
                                tf.equal(state1.ids, p.target_eoc_id)),
            tf.fill(tf.shape(state1.ids), p.target_eos_id), state1.ids)
      state1.bs_state = post_step_callback(decoder_theta,
                                           recurrent_theta.encoder_outputs,
                                           state1.ids, bs_state1)
    if p.use_recurrent:
      return state1, py_utils.NestedMap()
    else:
      inputs.ids = inputs.ids.write(state0.timestep, state1.ids)
      inputs.logits = inputs.logits.write(state0.timestep, state1.logits)
      return (recurrent_theta, state1, inputs)

  if p.use_recurrent:

    def StopFn(t, theta, state):
      del t, theta  # Unused: this stop function only uses the state ids.
      return tf.equal(state.ids, p.target_eos_id)
  else:

    def StopFn(recurrent_theta, state, inputs):
      del recurrent_theta, inputs
      return tf.logical_not(
          tf.reduce_all(tf.equal(state.ids, p.target_eos_id)))

  if p.use_stop_fn:
    stop_fn = StopFn
  else:
    stop_fn = None

  if p.use_recurrent:
    accumulated_states, _ = recurrent.Recurrent(
        recurrent_theta,
        recurrent_state0,
        inputs,
        Step,
        stop_fn=stop_fn,
        allow_implicit_capture=True)
  else:
    loop_vars = (recurrent_theta, recurrent_state0, inputs)
    (_, _, accumulated_states) = tf.while_loop(
        StopFn,
        Step,
        loop_vars=loop_vars,
        shape_invariants=_GetShapes(loop_vars, none_shapes=True),
        back_prop=False,
        maximum_iterations=p.target_seq_len)
    accumulated_states.ids = accumulated_states.ids.stack()
    accumulated_states.logits = accumulated_states.logits.stack()

  result = py_utils.NestedMap(
      logits=tf.transpose(accumulated_states.logits, [1, 0, 2]),
      ids=tf.transpose(accumulated_states.ids))
  result.paddings = tf.cast(
      _ComputePaddings(result.ids, p.target_eos_id), result.logits.dtype)
  # Force ids to be eos_id if the timestep is padded.
  result.ids = tf.where(
      tf.equal(result.paddings, 0), result.ids,
      tf.fill(tf.shape(result.ids), p.target_eos_id))
  static_batch_size = bs_result.log_probs.shape[0]
  result.ids.set_shape([static_batch_size, p.target_seq_len])
  result.paddings.set_shape([static_batch_size, p.target_seq_len])
  return result