def search_step(state):
    """ Beam search step. """
    # [batch_size * beam_size, vocab_size]
    log_probs = _calculate_log_probs(
        state=state,
        symbols_to_logits_fn=symbols_to_logits_fn,
        eos_id=eos_id,
        unk_id=unk_id,
        ensemble_weights=ensemble_weights)
    # masking out the EOS in the probability when decoding length < min_length
    eos_beam_bias = layer_utils.one_entry_bias(
        on_entry=eos_id,
        num_entries=log_probs.get_shape().as_list()[-1],
        on_value=compat.FLOAT_MIN,
        off_value=0.,
        dtype=log_probs.dtype)
    eos_beam_bias = layer_utils.tile_tensor(
        eos_beam_bias, tf.shape(log_probs)[0], axis=0)
    log_probs = tf.cond(
        tf.less(state[_StateKeys.TIME_STEP], minimum_decode_length - 1),
        lambda: log_probs + eos_beam_bias,
        lambda: log_probs)
    # compute log probs and generate next token ids according to beam scores
    sample_ids, beam_ids, next_log_probs, next_length = _sample_next_word(
        state=state,
        log_probs=log_probs,
        beam_size=beam_size,
        length_penalty=length_penalty)
    # re-order beams by beam_ids
    next_predicted_ids = tf.gather(state[_StateKeys.PREDICTED_IDS], beam_ids)
    if padded_decode:
        next_predicted_ids = tf.transpose(
            tf.tensor_scatter_nd_update(
                tf.transpose(next_predicted_ids),
                [[state[_StateKeys.TIME_STEP]]],
                tf.expand_dims(sample_ids, axis=0)))
    else:
        next_predicted_ids = tf.concat(
            [next_predicted_ids, tf.expand_dims(sample_ids, axis=1)], axis=1)
    next_cache = tf.nest.map_structure(
        lambda x: tf.gather(x, beam_ids), state[_StateKeys.CACHE])
    next_finished = tf.equal(eos_id, sample_ids)
    state.update({
        _StateKeys.TIME_STEP: state[_StateKeys.TIME_STEP] + 1,
        _StateKeys.INPUT_IDS: sample_ids,
        _StateKeys.CACHE: next_cache,
        _StateKeys.FINISHED_FLAGS: next_finished,
        _StateKeys.LOG_PROBS: next_log_probs,
        _StateKeys.DECODING_LENGTH: next_length,
        _StateKeys.PREDICTED_IDS: next_predicted_ids
    })
    return [state]
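
# --- Hedged usage sketch (not part of the library) ---
# `search_step` updates the state dict and returns it wrapped in a list, which is
# exactly the body contract of `tf.while_loop`. The toy below reproduces only that
# loop contract with made-up keys (TIME_STEP, FINISHED_FLAGS); the real state also
# carries INPUT_IDS, CACHE, LOG_PROBS, DECODING_LENGTH and PREDICTED_IDS.
import tensorflow as tf


def _toy_search_step(state):
    state = dict(state)
    state["TIME_STEP"] = state["TIME_STEP"] + 1
    # pretend every beam finishes once five steps have been decoded
    state["FINISHED_FLAGS"] = tf.broadcast_to(
        state["TIME_STEP"] >= 5, tf.shape(state["FINISHED_FLAGS"]))
    return [state]


def _not_finished(state):
    return tf.logical_not(tf.reduce_all(state["FINISHED_FLAGS"]))


_init_state = {"TIME_STEP": tf.constant(0),
               "FINISHED_FLAGS": tf.zeros([4], dtype=tf.bool)}
_final_state = tf.while_loop(
    _not_finished, _toy_search_step, loop_vars=[_init_state])[0]
# _final_state["TIME_STEP"] == 5 once all toy beams are finished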
def test_fn_expand_tensor():
    vocab_size = 10
    eos_id = 9
    batch_size = 3
    beam_size = 4
    batch_beam_size = batch_size * beam_size
    finished_beam_bias = tf1codebase_finished_beam_one_entry_bias(
        on_entry=eos_id, num_entries=vocab_size, dtype=tf.float32)
    assert (tf1codebase_expand_to_beam_size(
        finished_beam_bias, batch_beam_size, axis=0).numpy()
            == tile_tensor(finished_beam_bias, batch_beam_size, axis=0).numpy()).all()
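
# Hedged reference sketch: the tiling behaviour the test above relies on, assuming
# tile_tensor(x, n, axis=0) (like the TF1 expand_to_beam_size helper) repeats a
# rank-1 bias along a new leading axis, [num_entries] -> [n, num_entries].
import tensorflow as tf


def naive_tile(bias, batch_beam_size):
    return tf.tile(tf.expand_dims(bias, axis=0), [batch_beam_size, 1])


_bias = tf.one_hot(9, depth=10, on_value=0.0, off_value=-1e9)  # EOS-style one-entry bias
assert naive_tile(_bias, 12).shape == (12, 10)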
def create_decoding_internal_cache(self,
                                   encoder_outputs,
                                   encoder_inputs_padding,
                                   is_inference=False,
                                   decode_padded_length=None):
    """ Creates the internal cache for decoding.

    Args:
        encoder_outputs: The output tensor from the encoder, with shape
            [batch_size, max_input_length, hidden_size].
        encoder_inputs_padding: A float tensor with shape [batch_size, max_length],
            indicating the padding positions of `encoder_outputs`, where 1.0 is for
            padding and 0.0 for non-padding.
        is_inference: A boolean scalar, whether in inference mode or not.
        decode_padded_length: The maximum decoding length during inference, used to
            create a static-shape cache.

    Returns:
        `cache`, a dictionary containing static tensors (e.g. encoder hidden states
        for attention) and dynamic tensors (e.g. the transformer decoding cache)
        used during decoding, which will be passed to `call()`. Note that the
        dynamic tensors must be stored in cache["decoding_states"] for beam search.
    """
    if is_inference:
        decoding_states = {}
        batch_size = tf.shape(encoder_outputs)[0]
        # initialize decoder self-attention keys/values
        for lid, layer in enumerate(self._stacking_layers):
            # Ensure shape invariance for tf.while_loop.
            decoding_states[f"layer_{lid}"] = layer.create_decoding_internal_cache(
                decode_padded_length)
        decoding_states = tf.nest.map_structure(
            lambda ts: tile_tensor(ts, batch_size, axis=0), decoding_states)
        for lid, layer in enumerate(self._stacking_layers):
            decoding_states[f"layer_{lid}"].update(
                layer.memorize_memory(encoder_outputs))
    else:
        decoding_states = None
    cache = dict(decoding_states=decoding_states)
    if encoder_inputs_padding is not None:
        cache["memory"] = encoder_outputs
        # [batch_size, max_length], FLOAT_MIN for padding, 0.0 for non-padding
        cache["memory_bias"] = layer_utils.input_padding_to_bias(
            encoder_inputs_padding)
    return cache
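
# Hedged sketch of the padding-to-bias conversion noted in the comment above
# (1.0 marks padding, 0.0 marks real tokens; padded positions receive FLOAT_MIN so
# the bias can simply be added to attention logits). The real helper is
# layer_utils.input_padding_to_bias; the constant below is an assumption.
import tensorflow as tf

FLOAT_MIN_SKETCH = -1e9  # stand-in for compat.FLOAT_MIN


def input_padding_to_bias_sketch(encoder_inputs_padding):
    # [batch_size, max_length] -> [batch_size, max_length]
    return encoder_inputs_padding * FLOAT_MIN_SKETCH


_padding = tf.constant([[0., 0., 1., 1.],
                        [0., 1., 1., 1.]])
_memory_bias = input_padding_to_bias_sketch(_padding)  # 0.0 for tokens, -1e9 for padding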
def incremental_encode(self, inputs, cache, time=None):
    """ Encoding function for streaming input.

    Args:
        inputs: The embedded input at time t, a float tensor with shape
            [batch, embedding_dim] or [batch, length, embedding_dim].
        cache: A dict containing cached tensors.
        time: The start time of the inputs.

    Returns:
        The incremental encoder output with shape [batch, t+1, dim],
        and the updated cache dict.
    """
    params = self.get_config()
    assert params["attention_monotonic"], (
        "function `incremental_encode` only available when attention_monotonic=True")
    if cache is None:
        cache = {}
    if cache is not None and len(cache) == 0:
        batch_size = tf.shape(inputs)[0]
        for lid in range(params["num_layers"]):
            cache[f"layer_{lid}"] = self._stacking_layers[lid].create_internal_cache()
        cache = tf.nest.map_structure(
            lambda ts: layer_utils.tile_tensor(ts, batch_size, axis=0), cache)
    if inputs.get_shape().ndims == 2:
        x = tf.expand_dims(inputs, axis=1)
        x_bias = None
    else:
        x = inputs
        if time is None:
            time = 0
        x_bias = layer_utils.lower_triangle_attention_bias(
            time + tf.shape(x)[1])[:, :, -tf.shape(x)[1]:]
    for idx, layer in enumerate(self._stacking_layers):
        layer_cache = None if cache is None else cache[f"layer_{idx}"]
        x = layer(x, x_bias, layer_cache, is_training=False)
    outputs = x
    if not params["post_normalize"]:
        outputs = self.quant(self._output_norm_layer(x), name="output_ln")
    return outputs, cache
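
# Hedged usage sketch: streaming pre-embedded features through `incremental_encode`
# chunk by chunk while threading the cache, following the signature documented above.
# `encoder`, `embedded_frames`, `num_chunks` and `chunk_len` are illustrative names,
# not defined by the library; `encoder` is assumed to be an instance of this class
# built with attention_monotonic=True.
def stream_encode(encoder, embedded_frames, num_chunks, chunk_len):
    cache = None
    outputs = []
    for step in range(num_chunks):
        chunk = embedded_frames[:, step * chunk_len:(step + 1) * chunk_len, :]
        out, cache = encoder.incremental_encode(chunk, cache, time=step * chunk_len)
        outputs.append(out)
    return outputs, cache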
def _calculate_log_probs(state,
                         symbols_to_logits_fn,
                         eos_id,
                         unk_id,
                         ensemble_weights=None):
    """ Calculates the one-step log probabilities. Finished beams are masked, and
        UNK is masked if strategy == BASIC_NO_UNK.

    Args:
        state: A dictionary containing the current state of beam search.
        symbols_to_logits_fn: A callable that takes (input_ids, cache, time_step)
            and returns the logits for the next step (or a list of logits when
            ensembling multiple models).
        eos_id: An int scalar, indicating the end-of-sentence token id, used to
            determine when a sequence has finished.
        unk_id: An int scalar, indicating the unknown token id.
        ensemble_weights: A list of float values, indicating the weight of each
            submodel's probability.

    Returns:
        A float tensor of log probabilities with shape
        [batch_size * beam_size, vocab_size].
    """
    logits = symbols_to_logits_fn(state[_StateKeys.INPUT_IDS],
                                  state[_StateKeys.CACHE],
                                  state[_StateKeys.TIME_STEP])
    logits = tf.nest.flatten(logits)
    vocab_size = logits[0].get_shape().as_list()[-1]
    batch_beam_size = tf.shape(logits[0])[0]
    if len(logits) == 1:
        # [batch_size * beam_size, target_vocab_size]
        log_probs = tf.nn.log_softmax(logits[0])
    else:
        probs = tf.nest.map_structure(
            lambda x: tf.expand_dims(tf.reshape(tf.nn.softmax(x), shape=[-1]), axis=0),
            logits)
        original_shape = tf.shape(logits[0])
        # [num_models, batch_size * beam_size * target_vocab_size]
        probs = tf.concat(probs, axis=0)
        # [1, num_models]
        weights = tf.expand_dims(
            tf.convert_to_tensor(ensemble_weights, dtype=probs.dtype), axis=0)
        # weighted sum of the submodels' probabilities, then back to log space
        probs = tf.matmul(weights, probs)
        log_probs = tf.math.log(tf.reshape(probs, original_shape))

    # [batch_size * beam_size, ]
    prev_finished_float = tf.cast(state[_StateKeys.FINISHED_FLAGS], log_probs.dtype)
    # mask the finished beams except for a single entry (eos_id)
    #   [target_vocab_size, ]: [float_min, float_min, float_min, ..., 0]
    # this forces a beam that has emitted EOS to keep generating EOS
    finished_beam_bias = layer_utils.one_entry_bias(
        on_entry=eos_id,
        num_entries=vocab_size,
        on_value=0.,
        off_value=compat.FLOAT_MIN,
        dtype=log_probs.dtype)
    # [batch_size * beam_size, target_vocab_size]: outer product
    finished_beam_bias = layer_utils.tile_tensor(
        finished_beam_bias, batch_beam_size, axis=0)
    finished_beam_bias *= tf.expand_dims(prev_finished_float, 1)
    # compute the new log probs, taking the finished flags and mask into account
    log_probs = log_probs * tf.expand_dims(1. - prev_finished_float, 1) + finished_beam_bias
    # mask out UNK in the probabilities when an unk_id is given
    if unk_id is not None:
        unk_beam_bias = layer_utils.one_entry_bias(
            on_entry=unk_id,
            num_entries=vocab_size,
            on_value=compat.FLOAT_MIN,
            off_value=0.,
            dtype=log_probs.dtype)
        unk_beam_bias = layer_utils.tile_tensor(unk_beam_bias, batch_beam_size, axis=0)
        log_probs += unk_beam_bias
    return log_probs
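
# Self-contained sketch of the ensemble branch above: per-model softmax distributions
# are flattened, mixed with `ensemble_weights`, and only then mapped back to log
# space. Shapes follow the comments in `_calculate_log_probs`; the tensors here are
# random toys.
import tensorflow as tf

batch_beam_size, vocab_size = 6, 10
logits_a = tf.random.normal([batch_beam_size, vocab_size])
logits_b = tf.random.normal([batch_beam_size, vocab_size])
ensemble_weights = tf.constant([[0.7, 0.3]])  # [1, num_models]

# [num_models, batch_beam_size * vocab_size]
flat_probs = tf.stack(
    [tf.reshape(tf.nn.softmax(x), [-1]) for x in (logits_a, logits_b)], axis=0)
# [1, batch_beam_size * vocab_size] -> [batch_beam_size, vocab_size]
mixed = tf.matmul(ensemble_weights, flat_probs)
log_probs_sketch = tf.math.log(tf.reshape(mixed, [batch_beam_size, vocab_size]))

# each row remains a proper distribution because the weights sum to 1
row_sums = tf.reduce_sum(tf.exp(log_probs_sketch), axis=-1)  # ~1.0 per row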