Code Example #1
    def search_step(state):
        """ Beam search step. """
        # [batch_size * beam_size, vocab_size]
        log_probs = _calculate_log_probs(
            state=state,
            symbols_to_logits_fn=symbols_to_logits_fn,
            eos_id=eos_id,
            unk_id=unk_id,
            ensemble_weights=ensemble_weights)
        # mask out EOS in the probabilities while decoding length < minimum_decode_length
        eos_beam_bias = layer_utils.one_entry_bias(
            on_entry=eos_id,
            num_entries=log_probs.get_shape().as_list()[-1],
            on_value=compat.FLOAT_MIN,
            off_value=0.,
            dtype=log_probs.dtype)
        eos_beam_bias = layer_utils.tile_tensor(eos_beam_bias,
                                                tf.shape(log_probs)[0],
                                                axis=0)
        log_probs = tf.cond(
            tf.less(state[_StateKeys.TIME_STEP], minimum_decode_length - 1),
            lambda: log_probs + eos_beam_bias, lambda: log_probs)

        # compute log probs and generate next token ids according to beam scores
        sample_ids, beam_ids, next_log_probs, next_length = _sample_next_word(
            state=state,
            log_probs=log_probs,
            beam_size=beam_size,
            length_penalty=length_penalty)
        # re-order beams by beam_ids
        next_predicted_ids = tf.gather(state[_StateKeys.PREDICTED_IDS],
                                       beam_ids)
        if padded_decode:
            next_predicted_ids = tf.transpose(
                tf.tensor_scatter_nd_update(tf.transpose(next_predicted_ids),
                                            [[state[_StateKeys.TIME_STEP]]],
                                            tf.expand_dims(sample_ids,
                                                           axis=0)))
        else:
            next_predicted_ids = tf.concat(
                [next_predicted_ids,
                 tf.expand_dims(sample_ids, axis=1)],
                axis=1)
        next_cache = tf.nest.map_structure(lambda x: tf.gather(x, beam_ids),
                                           state[_StateKeys.CACHE])
        next_finished = tf.equal(eos_id, sample_ids)
        state.update({
            _StateKeys.TIME_STEP: state[_StateKeys.TIME_STEP] + 1,
            _StateKeys.INPUT_IDS: sample_ids,
            _StateKeys.CACHE: next_cache,
            _StateKeys.FINISHED_FLAGS: next_finished,
            _StateKeys.LOG_PROBS: next_log_probs,
            _StateKeys.DECODING_LENGTH: next_length,
            _StateKeys.PREDICTED_IDS: next_predicted_ids
        })
        return [state]
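The min-length trick above hinges on `one_entry_bias` producing a vector that is FLOAT_MIN at the EOS entry and 0 elsewhere, so adding it to the log probabilities makes EOS unselectable until `minimum_decode_length` is reached. Below is a minimal self-contained sketch of that masking step, assuming FLOAT_MIN can be approximated by -1e9 and substituting `tf.one_hot` and broadcasting for the library helpers:

import tensorflow as tf

FLOAT_MIN = -1e9  # assumed stand-in for compat.FLOAT_MIN
eos_id, vocab_size = 2, 6

# Uniform log probs over a toy vocab for batch_size * beam_size = 4 rows.
log_probs = tf.math.log(tf.fill([4, vocab_size], 1.0 / vocab_size))

# FLOAT_MIN at the EOS entry, 0 elsewhere; broadcasting stands in for tile_tensor.
eos_beam_bias = tf.one_hot(eos_id, vocab_size,
                           on_value=FLOAT_MIN, off_value=0.0)

masked = log_probs + eos_beam_bias
# EOS can no longer win the top-k while the bias is being added.
print(tf.argmax(masked, axis=-1).numpy())  # no entry equals eos_id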
Code Example #2
def test_fn_expand_tensor():
    vocab_size = 10
    eos_id = 9
    batch_size = 3
    beam_size = 4
    batch_beam_size = batch_size * beam_size
    finished_beam_bias = tf1codebase_finished_beam_one_entry_bias(
        on_entry=eos_id, num_entries=vocab_size, dtype=tf.float32)
    assert (tf1codebase_expand_to_beam_size(finished_beam_bias,
                                            batch_beam_size, axis=0).numpy()
            == tile_tensor(finished_beam_bias,
                           batch_beam_size, axis=0).numpy()).all()
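For reference, here is a plausible stand-in for `tile_tensor` that reproduces the behavior this test checks against the TF1 `expand_to_beam_size` helper: insert a new dimension at `axis`, then repeat the tensor along it. This is an assumption inferred from the test, not the library's actual implementation:

import tensorflow as tf

def tile_tensor_sketch(tensor, size, axis=0):
    # Insert a new dimension at `axis`, then repeat `size` times along it,
    # e.g. a [vocab_size] bias becomes [size, vocab_size] for axis=0.
    expanded = tf.expand_dims(tensor, axis=axis)
    multiples = [1] * expanded.get_shape().ndims
    multiples[axis] = size
    return tf.tile(expanded, multiples)

bias = tf.one_hot(9, 10, on_value=0.0, off_value=-1e9)  # [10]
print(tile_tensor_sketch(bias, 12, axis=0).shape)  # (12, 10)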
Code Example #3
    def create_decoding_internal_cache(self,
                                       encoder_outputs,
                                       encoder_inputs_padding,
                                       is_inference=False,
                                       decode_padded_length=None):
        """ Creates internal cache for decoding.

        Args:
            encoder_outputs: The output tensor from encoder
                with shape [batch_size, max_input_length, hidden_size].
            encoder_inputs_padding: A float tensor with shape [batch_size, max_length],
                indicating the padding positions of `encoder_outputs`, where 1.0 is
                for padding and 0.0 for non-padding.
            is_inference: A boolean scalar, whether in inference mode or not.
            decode_padded_length: The maximum decoding length when inference, for creating
                static-shape cache.

        Returns:
            `cache`, a dictionary containing static tensors (e.g. encoder hidden
            states for attention) and dynamic tensors (e.g. the transformer
            decoding cache) used during decoding; it is passed to `call()`. Note
            that the dynamic tensors must be stored in cache["decoding_states"]
            for beam search to use.
        """
        if is_inference:
            decoding_states = {}
            batch_size = tf.shape(encoder_outputs)[0]
            # initialize decoder self attention keys/values
            for lid, layer in enumerate(self._stacking_layers):
                # Ensure shape invariance for tf.while_loop.
                decoding_states[
                    f"layer_{lid}"] = layer.create_decoding_internal_cache(
                        decode_padded_length)
            decoding_states = tf.nest.map_structure(
                lambda ts: tile_tensor(ts, batch_size, axis=0),
                decoding_states)
            for lid, layer in enumerate(self._stacking_layers):
                decoding_states[f"layer_{lid}"].update(
                    layer.memorize_memory(encoder_outputs))
        else:
            decoding_states = None
        cache = dict(decoding_states=decoding_states)
        if encoder_inputs_padding is not None:
            cache["memory"] = encoder_outputs
            # [batch_size, max_length], FLOAT_MIN for padding, 0.0 for non-padding
            cache["memory_bias"] = layer_utils.input_padding_to_bias(
                encoder_inputs_padding)
        return cache
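The key move here is `tf.nest.map_structure`, which applies the batch tiling to every tensor in the nested per-layer cache at once. A condensed sketch with a hypothetical two-layer cache (the state names and shapes are illustrative, not neurst's actual layout):

import tensorflow as tf

# Hypothetical per-layer states without a batch dimension, shaped
# [current_length, hidden] as a decoder layer might create them.
decoding_states = {
    "layer_0": {"keys": tf.zeros([0, 8]), "values": tf.zeros([0, 8])},
    "layer_1": {"keys": tf.zeros([0, 8]), "values": tf.zeros([0, 8])},
}
batch_size = 4
# Mirrors tile_tensor(ts, batch_size, axis=0): add a batch axis, then repeat.
tiled = tf.nest.map_structure(
    lambda ts: tf.tile(tf.expand_dims(ts, 0), [batch_size, 1, 1]),
    decoding_states)
print(tiled["layer_0"]["keys"].shape)  # (4, 0, 8)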
Code Example #4
File: transformer_encoder.py  Project: lileicc/neurst
    def incremental_encode(self, inputs, cache, time=None):
        """ Encoding function for streaming input.

        Args:
            inputs: The embedded input at time t, a float tensor with shape
                [batch, embedding_dim] or [batch, length, embedding_dim].
            cache: A dict containing cached tensors.
            time: The start time of the inputs.

        Returns: The encoder output up to time t+1, with shape [batch, t+1, dim],
            and the updated cache dict.
        """
        params = self.get_config()
        assert params["attention_monotonic"], (
            "function `incremental_encode` only available when attention_monotonic=True"
        )
        if cache is None:
            cache = {}
        if len(cache) == 0:
            batch_size = tf.shape(inputs)[0]
            for lid in range(params["num_layers"]):
                cache[f"layer_{lid}"] = self._stacking_layers[
                    lid].create_internal_cache()
            cache = tf.nest.map_structure(
                lambda ts: layer_utils.tile_tensor(ts, batch_size, axis=0),
                cache)
        if inputs.get_shape().ndims == 2:
            x = tf.expand_dims(inputs, axis=1)
            x_bias = None
        else:
            x = inputs
            if time is None:
                time = 0
            x_bias = layer_utils.lower_triangle_attention_bias(
                time + tf.shape(x)[1])[:, :, -tf.shape(x)[1]:]
        for idx, layer in enumerate(self._stacking_layers):
            layer_cache = None if cache is None else cache[f"layer_{idx}"]
            x = layer(x, x_bias, layer_cache, is_training=False)
        outputs = x
        if not params["post_normalize"]:
            outputs = self.quant(self._output_norm_layer(x), name="output_ln")
        return outputs, cache
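The bias slice `lower_triangle_attention_bias(time + length)[:, :, -length:]` keeps only the rows for the newly arrived positions, so each new query may attend to every earlier position but never to future ones. A self-contained sketch of that slicing, assuming the helper returns a [1, 1, length, length] causal bias built from a large negative constant:

import tensorflow as tf

def lower_triangle_bias_sketch(length, neg_inf=-1e9):
    # [1, 1, length, length]: 0 on and below the diagonal, neg_inf above it.
    valid = tf.linalg.band_part(tf.ones([length, length]), -1, 0)
    return (1.0 - valid)[tf.newaxis, tf.newaxis] * neg_inf

time, chunk = 3, 2  # 3 steps already encoded, 2 new steps arriving
x_bias = lower_triangle_bias_sketch(time + chunk)[:, :, -chunk:]
print(x_bias.shape)  # (1, 1, 2, 5): each new query sees all 5 positions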
Code Example #5
def _calculate_log_probs(state,
                         symbols_to_logits_fn,
                         eos_id,
                         unk_id,
                         ensemble_weights=None):
    """ Calculates one-step log probability.

    Finished beams will be masked so that they only continue to generate EOS,
    and UNK will be masked if strategy == BASIC_NO_UNK.

    Args:
        state: A dictionary containing current state of beam search.
        symbols_to_logits_fn: A callable that takes (current token ids, cache,
            time step) and returns the next-step logits.
        eos_id: An int scalar, indicating the end-of-sentence token id, used to determine when a
            sequence has finished.
        unk_id: An int scalar, indicating the unknown token id.
        ensemble_weights: A list of float values, indicating the weights of each submodel's probability.

    Returns:
        A float tensor of log probabilities with shape
        [batch_size * beam_size, vocab_size].
    """
    logits = symbols_to_logits_fn(state[_StateKeys.INPUT_IDS],
                                  state[_StateKeys.CACHE],
                                  state[_StateKeys.TIME_STEP])
    logits = tf.nest.flatten(logits)
    vocab_size = logits[0].get_shape().as_list()[-1]
    batch_beam_size = tf.shape(logits[0])[0]
    if len(logits) == 1:
        # [batch_size * beam_size, target_vocab_size]
        log_probs = tf.nn.log_softmax(logits[0])
    else:
        probs = tf.nest.map_structure(
            lambda x: tf.expand_dims(tf.reshape(tf.nn.softmax(x), shape=[-1]),
                                     axis=0), logits)
        original_shape = tf.shape(logits[0])
        # [num_models, batch_size * beam_size * vocab_size]
        probs = tf.concat(probs, axis=0)
        # [1, num_models]
        weights = tf.expand_dims(tf.convert_to_tensor(ensemble_weights,
                                                      dtype=probs.dtype),
                                 axis=0)
        probs = tf.matmul(weights, probs)
        log_probs = tf.math.log(tf.reshape(probs, original_shape))

    # [batch_size * beam_size,]
    prev_finished_float = tf.cast(state[_StateKeys.FINISHED_FLAGS],
                                  log_probs.dtype)
    # mask the finished beams, leaving only one open entry (eos_id)
    #   [vocab_size, ]: [float_min, float_min, float_min, ..., 0]
    #   this forces a beam that has emitted EOS to keep generating EOS
    finished_beam_bias = layer_utils.one_entry_bias(on_entry=eos_id,
                                                    num_entries=vocab_size,
                                                    on_value=0.,
                                                    off_value=compat.FLOAT_MIN,
                                                    dtype=log_probs.dtype)
    # [batch_size * beam_size, target_vocab_size]: outer product
    finished_beam_bias = layer_utils.tile_tensor(finished_beam_bias,
                                                 batch_beam_size,
                                                 axis=0)
    finished_beam_bias *= tf.expand_dims(prev_finished_float, 1)
    # combine: running beams keep their log probs, finished beams get the bias
    log_probs = log_probs * tf.expand_dims(1. - prev_finished_float,
                                           1) + finished_beam_bias

    # mask out UNK in the probabilities by adding FLOAT_MIN to its entry
    if unk_id is not None:
        unk_beam_bias = layer_utils.one_entry_bias(on_entry=unk_id,
                                                   num_entries=vocab_size,
                                                   on_value=compat.FLOAT_MIN,
                                                   off_value=0.,
                                                   dtype=log_probs.dtype)
        unk_beam_bias = layer_utils.tile_tensor(unk_beam_bias,
                                                batch_beam_size,
                                                axis=0)
        log_probs += unk_beam_bias
    return log_probs
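The ensemble branch above reduces to a weighted average of the submodels' probabilities, computed as a single matmul after flattening each model's softmax into a row. A toy re-run of that arithmetic with two hypothetical models and weights 0.7/0.3:

import tensorflow as tf

# Two models' logits over a toy vocab: [batch_size * beam_size, vocab_size].
logits = [tf.constant([[2.0, 0.0, 1.0]]), tf.constant([[0.0, 2.0, 1.0]])]
ensemble_weights = [[0.7, 0.3]]  # [1, num_models]

# Flatten each softmax into a row and stack: [num_models, batch*beam*vocab].
probs = tf.concat(
    [tf.reshape(tf.nn.softmax(x), [1, -1]) for x in logits], axis=0)
# [1, num_models] x [num_models, N] -> weighted average of probabilities.
mixed = tf.matmul(tf.constant(ensemble_weights), probs)
log_probs = tf.math.log(tf.reshape(mixed, tf.shape(logits[0])))
print(log_probs.numpy())  # one-step ensemble log probabilities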