Example 1
    def body_infer(time, inputs, caches, outputs_tas, finished,
                   log_probs, lengths, bs_stat_ta, predicted_ids):
        """Internal while_loop body.

        Args:
          time: Scalar int32 Tensor.
          inputs: A list of input Tensors.
          caches: A list of dicts of decoder states.
          outputs_tas: A list of TensorArrays.
          finished: A bool Tensor keeping track of which beams have finished.
          log_probs: The log probability Tensor.
          lengths: The decoding length Tensor.
          bs_stat_ta: A structure of TensorArrays holding beam search statistics.
          predicted_ids: A Tensor of the ids predicted so far.

        Returns:
          `(time + 1, next_inputs, next_caches, next_outputs_tas,
          next_finished, next_log_probs, next_lengths, next_bs_stat_ta,
          next_predicted_ids)`.
        """

        # step decoder
        def _decoding(_decoder, _input, _cache, _decoder_output_remover,
                      _outputs_ta, _outputs_to_logits_fn):
            with tf.variable_scope(_decoder.name):
                _output, _next_cache = _decoder.step(_input, _cache)
                _decoder_top_features = _decoder.merge_top_features(_output)
            _ta = nest.map_structure(lambda _ta_ms, _output_ms: _ta_ms.write(time, _output_ms),
                                     _outputs_ta, _decoder_output_remover.apply(_output))
            _logit = _outputs_to_logits_fn(_decoder_top_features)
            return _output, _next_cache, _ta, _logit

        outputs, next_caches, next_outputs_tas, logits = repeat_n_times(
            num_models, _decoding,
            decoders, inputs, caches, decoder_output_removers,
            outputs_tas, outputs_to_logits_fns)

        # sample next symbols
        sample_ids, beam_ids, next_log_probs, next_lengths \
            = helper.sample_symbols(logits, log_probs, finished, lengths, time=time)

        for c in next_caches:
            c["decoding_states"] = gather_states(c["decoding_states"], beam_ids)

        infer_status = BeamSearchStateSpec(
            log_probs=next_log_probs,
            beam_ids=beam_ids)
        bs_stat_ta = nest.map_structure(lambda ta, out: ta.write(time, out),
                                        bs_stat_ta, infer_status)
        predicted_ids = gather_states(tf.reshape(predicted_ids, [-1, time + 1]), beam_ids)
        next_predicted_ids = tf.concat([predicted_ids, tf.expand_dims(sample_ids, axis=1)], axis=1)
        next_predicted_ids = tf.reshape(next_predicted_ids, [-1])
        next_predicted_ids.set_shape([None])
        next_finished, next_input_symbols = helper.next_symbols(time=time, sample_ids=sample_ids)
        next_inputs = repeat_n_times(num_models, target_to_embedding_fns,
                                     next_input_symbols, time + 1)
        next_finished = tf.logical_or(next_finished, finished)

        return time + 1, next_inputs, next_caches, next_outputs_tas, \
               next_finished, next_log_probs, next_lengths, bs_stat_ta, \
               next_predicted_ids
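Neither `gather_states` nor `repeat_n_times` is defined anywhere in this listing. Judging from the call sites above, `gather_states` reorders beam-major state tensors by `beam_ids`, and `repeat_n_times` fans one computation out across the ensemble of models. A minimal sketch consistent with those call sites (these bodies are assumptions, not the original source):

    import tensorflow as tf
    from tensorflow.python.util import nest


    def gather_states(states, beam_ids):
        """Reorders each Tensor in `states` along axis 0 by `beam_ids`.

        Every Tensor is beam-major, i.e. its first dimension is
        batch_size * beam_size, and `beam_ids` names the source row for
        each output row.
        """
        return nest.map_structure(lambda s: tf.gather(s, beam_ids), states)


    def repeat_n_times(n, fn, *args):
        """Applies `fn` once per model and transposes the results.

        `fn` may be a single callable or a list of `n` callables (as with
        `target_to_embedding_fns` above). Each element of `args` is either
        a length-`n` list holding one value per model, or a single value
        shared by all models.
        """
        fns = list(fn) if isinstance(fn, (list, tuple)) else [fn] * n
        per_model = [
            [a[i] if isinstance(a, (list, tuple)) else a for a in args]
            for i in range(n)]
        results = [fns[i](*per_model[i]) for i in range(n)]
        if isinstance(results[0], tuple):  # fan multiple outputs back out
            return tuple(list(r) for r in zip(*results))
        return results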
Example 3
    def body_traininfer(time, inputs, cache, outputs_ta, finished, *args):
        """Internal while_loop body.

        Args:
          time: scalar int32 Tensor.
          inputs: The inputs Tensor.
          cache: The decoder states.
          outputs_ta: A structure of TensorArrays.
          finished: A bool Tensor keeping track of which beams have finished.
          args: For mode==INFER, the `log_probs`, `lengths`, `bs_stat_ta`
            and `predicted_ids`.

        Returns:
          `(time + 1, next_inputs, next_cache, outputs_ta,
          next_finished, *args)`.
        """
        with tf.variable_scope(decoder.name):
            outputs, next_cache = decoder.step(inputs, cache)
        outputs_ta = nest.map_structure(lambda ta, out: ta.write(time, out),
                                        outputs_ta,
                                        decoder_output_remover.apply(outputs))
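        # Slots 1, 2 and 4 (next_inputs, next_cache, next_finished) are
        # placeholders that get filled in at the bottom of the body.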
        inner_loop_vars = [time + 1, None, None, outputs_ta, None]
        sample_ids = None
        if decoder.mode == ModeKeys.INFER:
            log_probs, lengths = args[0], args[1]
            bs_stat_ta = args[2]
            predicted_ids = args[3]
            logits = _compute_logits(decoder, target_modality, outputs)
            # sample next symbols
            sample_ids, beam_ids, next_log_probs, next_lengths \
                = helper.sample_symbols(logits, log_probs, finished, lengths, time=time)
            predicted_ids = gather_states(
                tf.reshape(predicted_ids, [-1, time + 1]), beam_ids)

            next_cache["decoding_states"] = gather_states(
                next_cache["decoding_states"], beam_ids)
            bs_stat = BeamSearchStateSpec(log_probs=next_log_probs,
                                          beam_ids=beam_ids)
            bs_stat_ta = nest.map_structure(
                lambda ta, out: ta.write(time, out), bs_stat_ta, bs_stat)
            next_predicted_ids = tf.concat(
                [predicted_ids,
                 tf.expand_dims(sample_ids, axis=1)], axis=1)
            next_predicted_ids = tf.reshape(next_predicted_ids, [-1])
            next_predicted_ids.set_shape([None])
            inner_loop_vars.extend(
                [next_log_probs, next_lengths, bs_stat_ta, next_predicted_ids])

        next_finished, next_input_symbols = helper.next_symbols(
            time=time, sample_ids=sample_ids)
        next_inputs = _embed_words(target_modality, next_input_symbols,
                                   time + 1)

        next_finished = tf.logical_or(next_finished, finished)
        inner_loop_vars[1] = next_inputs
        inner_loop_vars[2] = next_cache
        inner_loop_vars[4] = next_finished
        return inner_loop_vars
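`body_traininfer` is written as the body of a `tf.while_loop`, but the driver itself is not part of the listing. A minimal sketch of what it could look like in TRAIN/EVAL mode, where the loop carries exactly the five positional variables (`initial_inputs`, `initial_cache`, `initial_outputs_ta`, `initial_finished` and `maximum_decode_length` are illustrative names, not from the source):

    def condition(time, inputs, cache, outputs_ta, finished, *args):
        # Keep stepping until every beam has finished or the step budget
        # is exhausted.
        return tf.logical_and(time < maximum_decode_length,
                              tf.logical_not(tf.reduce_all(finished)))

    final_loop_vars = tf.while_loop(
        condition, body_traininfer,
        loop_vars=[tf.constant(0, dtype=tf.int32), initial_inputs,
                   initial_cache, initial_outputs_ta, initial_finished])

In INFER mode the `loop_vars` list would additionally carry `log_probs`, `lengths`, `bs_stat_ta` and `predicted_ids`, matching the values the body appends to `inner_loop_vars`.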
Example 4
    def body_traininfer(time, inputs, decoder_states, outputs_ta,
                        finished, *args):
        """Internal while_loop body.

        Args:
          time: scalar int32 Tensor.
          inputs: The inputs Tensor.
          decoder_states: The decoder states.
          outputs_ta: A structure of TensorArrays.
          finished: A bool Tensor keeping track of which beams have finished.
          args: For mode==INFER, the `log_probs`, `lengths` and
            `infer_status_ta`.

        Returns:
          `(time + 1, next_inputs, next_decoder_states, outputs_ta,
          next_finished, *args)`.
        """
        with tf.variable_scope(decoder.name):
            inputs = inputs_preprocessing_fn(time, inputs)
            outputs, next_decoder_states = decoder.step(inputs, decoder_states, decoding_params)
        outputs_ta = nest.map_structure(lambda ta, out: ta.write(time, out),
                                        outputs_ta, decoder_output_remover.apply(outputs))
        inner_loop_vars = [time + 1, None, None, outputs_ta, None]
        sample_ids = None
        prev_inputs = inputs
        if decoder.mode == ModeKeys.INFER:
            log_probs, lengths = args[0], args[1]
            infer_status_ta = args[2]
            logits = _compute_logits(decoder, target_modality, outputs)
            # sample next symbols
            sample_ids, beam_ids, next_log_probs, next_lengths \
                = helper.sample_symbols(logits, log_probs, finished, lengths, time=time)

            next_decoder_states = gather_states(next_decoder_states, beam_ids)
            prev_inputs = gather_states(inputs, beam_ids)
            infer_status = BeamSearchStateSpec(
                log_probs=next_log_probs,
                predicted_ids=sample_ids,
                beam_ids=beam_ids,
                lengths=next_lengths)
            infer_status_ta = nest.map_structure(lambda ta, out: ta.write(time, out),
                                                 infer_status_ta, infer_status)
            inner_loop_vars.extend([next_log_probs, next_lengths, infer_status_ta])

        next_finished, next_input_symbols = helper.next_symbols(time=time, sample_ids=sample_ids)
        next_inputs = _embed_words(target_modality, next_input_symbols, time + 1)
        with tf.variable_scope(decoder.name):
            next_inputs = inputs_postprocessing_fn(prev_inputs, next_inputs)

        next_finished = tf.logical_or(next_finished, finished)
        inner_loop_vars[1] = next_inputs
        inner_loop_vars[2] = next_decoder_states
        inner_loop_vars[4] = next_finished
        return inner_loop_vars
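`BeamSearchStateSpec` is constructed with different field sets across the examples: only `log_probs` and `beam_ids` in Example 1, all four fields here. Assuming it is a plain `namedtuple` (an assumption; the actual class is not shown), the variant used in this example would be:

    from collections import namedtuple

    # Assumed definition. A namedtuple is a flat structure to
    # nest.map_structure, which is why infer_status_ta above can be
    # written field by field with a single map_structure call.
    BeamSearchStateSpec = namedtuple(
        "BeamSearchStateSpec",
        ("log_probs", "predicted_ids", "beam_ids", "lengths"))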
Example 5
    def body_infer(time, inputs, caches, outputs_tas, finished, log_probs,
                   lengths, infer_status_ta):
        """Internal while_loop body.

        Args:
          time: Scalar int32 Tensor.
          inputs: A list of input Tensors.
          caches: A list of dicts of decoder states.
          outputs_tas: A list of TensorArrays.
          finished: A bool Tensor keeping track of which beams have finished.
          log_probs: The log probability Tensor.
          lengths: The decoding length Tensor.
          infer_status_ta: A structure of TensorArrays.

        Returns:
          `(time + 1, next_inputs, next_caches, next_outputs_tas,
          next_finished, next_log_probs, next_lengths, next_infer_status_ta)`.
        """
        # step decoder
        outputs = []
        next_caches = []
        for dec, inp, cache in zip(decoders, inputs, caches):
            with tf.variable_scope(dec.name):
                out, next_cache = dec.step(inp, cache)
                outputs.append(out)
                next_caches.append(next_cache)
        next_outputs_tas = []
        for out_ta, out, rem in zip(outputs_tas, outputs,
                                    decoder_output_removers):
            ta = nest.map_structure(lambda ta, out: ta.write(time, out),
                                    out_ta, rem.apply(out))
            next_outputs_tas.append(ta)
        logits = []
        for dec, modality, out in zip(decoders, target_modalities, outputs):
            logits.append(_compute_logits(dec, modality, out))
        # sample next symbols
        sample_ids, beam_ids, next_log_probs, next_lengths \
            = helper.sample_symbols(logits, log_probs, finished, lengths, time=time)

        for c in next_caches:
            c["decoding_states"] = gather_states(c["decoding_states"],
                                                 beam_ids)

        infer_status = BeamSearchStateSpec(log_probs=next_log_probs,
                                           predicted_ids=sample_ids,
                                           beam_ids=beam_ids,
                                           lengths=next_lengths)
        infer_status_ta = nest.map_structure(
            lambda ta, out: ta.write(time, out), infer_status_ta, infer_status)
        next_finished, next_input_symbols = helper.next_symbols(
            time=time, sample_ids=sample_ids)
        next_inputs = nest.map_structure(
            lambda modality: _embed_words(modality, next_input_symbols, time + 1),
            target_modalities)
        next_finished = tf.logical_or(next_finished, finished)

        return time + 1, next_inputs, next_caches, next_outputs_tas, \
               next_finished, next_log_probs, next_lengths, infer_status_ta
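Examples 3 to 5 also rely on two modality helpers that never appear in the listing. Assuming `target_modality` wraps the output projection and the target embedding lookup (the `top` and `embed` attribute names below are purely hypothetical), minimal sketches could look like:

    def _compute_logits(decoder, target_modality, decoder_output):
        # Project the decoder's top features to target-vocabulary logits.
        top_features = decoder.merge_top_features(decoder_output)
        with tf.variable_scope(target_modality.name):  # hypothetical attribute
            return target_modality.top(top_features)  # hypothetical method


    def _embed_words(target_modality, symbols, time):
        # Embed the sampled symbols as inputs for decoding step `time`.
        with tf.variable_scope(target_modality.name):
            return target_modality.embed(symbols, time=time)  # hypothetical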
Example 6
def sample_symbols_new(logits, log_probs, finished, lengths, time):
    """
    :param logits: [batch_size * beam_size, target_vocab_size]
    :param log_probs: [batch_size * beam_size, ]
    :param finished: [batch_size * beam_size, ]
    :param lengths: decoding length [batch_size * beam_size, ]
    :param time: scalar int32 Tensor, the current decoding step
    :return: `(word_ids, beam_ids, next_log_probs, next_lengths,
        ret_log_probs, ret_sample_ids, length_penalty)`
    """

    # [batch_size * beam_size,]
    prev_finished_float = math_ops.to_float(finished)
    # [batch_size * beam_size, ]
    prev_log_probs = log_probs
    # [batch_size * beam_size, target_vocab_size]
    probs = advanced_log_softmax(logits)  # negative

    # mask the finished beams, leaving a single open entry (target_eos_id)
    #   [target_vocab_size, ]: [float_min, float_min, float_min, ..., 0]
    #   this forces a beam that has emitted EOS to keep generating EOS
    finished_beam_bias = finished_beam_one_entry_bias(on_entry=eos_id,
                                                      num_entries=vocab_size)
    # [batch_size * beam_size, target_vocab_size]: outer product
    finished_beam_bias = expand_to_beam_size(finished_beam_bias,
                                             beam_size * batch_size,
                                             axis=0)
    finished_beam_bias *= array_ops.expand_dims(prev_finished_float, 1)
    # compute new probs, with finished flags & mask
    probs = probs * array_ops.expand_dims(1. - prev_finished_float,
                                          1) + finished_beam_bias

    # [batch_size * beam_size, target_vocab_size]
    # compute new log_probs
    log_probs = probs + array_ops.expand_dims(prev_log_probs, 1)
    # new decoding length: [batch_size * beam_size]
    lengths = lengths + 1 - math_ops.to_int32(finished)
    # compute beam score
    #  length_penalty: [batch_size * beam_size,]
    length_penalty = math_ops.pow(((5.0 + math_ops.to_float(lengths)) / 6.0),
                                  -alpha)
    scores = log_probs * array_ops.expand_dims(length_penalty, axis=1)

    # flatten
    # [batch_size, beam_size * target_vocab_size]
    scores = array_ops.reshape(array_ops.reshape(scores, [-1]),
                               [batch_size, -1])
    ret_log_probs = array_ops.reshape(array_ops.reshape(log_probs, [-1]),
                                      [batch_size, -1])

    scores_flat = control_flow_ops.cond(
        ops.convert_to_tensor(time) > 0,
        lambda: scores,  # time > 0: all
        lambda: array_ops.slice(scores, [0, 0], [-1, vocab_size])
    )  # time = 0: first logits in each batch

    # [batch_size, beam_size]: keep the top-`beam_size` candidates per batch entry
    sample_scores, sample_ids = nn_ops.top_k(scores_flat, k=beam_size)
    ret_sample_ids = array_ops.reshape(sample_ids, [-1])
    # flatten: [batch_size * beam_size,]
    sample_ids = array_ops.reshape(sample_ids, [-1])
    # because top_k runs over scores of dim [batch, beam * vocab],
    #   we need to recover the true word ids
    word_ids = math_ops.mod(sample_ids, vocab_size)

    # beam ids should be adjusted according to batch_size
    #  batch_pos, [batch_size, beam_size]: [[0, 0, ...], [1, 1, ...], ..., [batch_size - 1, ...]]
    batch_pos = compute_batch_indices(batch_size, beam_size)

    # compute new beam_ids, [batch_size * beam_size, ]
    beam_ids = math_ops.div(sample_ids, vocab_size) \
               + array_ops.reshape(batch_pos * beam_size, [-1])

    # we need to recover log_probs from score
    # flatten sample_scores: [batch_size * beam_size,]
    sample_scores_flatten = array_ops.reshape(sample_scores, [-1])
    # gather each length penalty
    length_penalty = gather_states(length_penalty, beam_ids)
    # recover log probabilities
    next_log_probs = sample_scores_flatten / length_penalty
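    # NOTE: this estimate is superseded a few lines below, where
    # next_log_probs is gathered exactly from the flattened log_probs.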
    # gather states according to beam_ids
    next_lengths = gather_states(lengths, beam_ids)

    # [batch_size * beam_size * vocab_size, ]
    log_probs_flat = array_ops.reshape(log_probs, [-1])
    log_probs_index = array_ops.reshape(
        batch_pos, [-1]) * beam_size * vocab_size + sample_ids
    next_log_probs = array_ops.gather(log_probs_flat, log_probs_index)

    return word_ids, beam_ids, next_log_probs, next_lengths, ret_log_probs, ret_sample_ids, length_penalty
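The masking helpers above are likewise absent from the listing (`advanced_log_softmax` presumably behaves like `tf.nn.log_softmax`, as the `# negative` comment suggests). Going by the shape comments, plausible sketches (assumptions, not the original code):

    def finished_beam_one_entry_bias(on_entry, num_entries):
        # [num_entries] vector: float32 min everywhere, 0 at `on_entry`,
        # so a finished beam can only continue with EOS.
        mask = tf.one_hot(on_entry, num_entries, dtype=tf.float32)
        return (1.0 - mask) * tf.float32.min


    def expand_to_beam_size(tensor, beam_size, axis=0):
        # Tile a [num_entries] vector into [beam_size, num_entries];
        # only the axis=0 case is exercised above.
        assert axis == 0
        return tf.tile(tf.expand_dims(tensor, 0), [beam_size, 1])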
Example 7
    def sample_symbols(self, logits, log_probs, finished, lengths, time):
        """ Samples symbols and returns it.

        Args:
            logits: The logits Tensor with shape [beam_size * batch_size, vocab_size],
              or a list of logits Tensors.
            log_probs: Accumulated log probabilities, a float32 Tensor with shape
              [beam_size * batch_size, ].
            finished: Finished flag of each beam in each batch, a bool Tensor with
              shape [beam_size * batch_size, ].
            lengths: The length of each beam in each batch, an int32 Tensor with
              shape [beam_size * batch_size, ].
            time: An int32 scalar, the current time.

        Returns: A tuple `(word_ids, beam_ids, next_log_probs, next_lengths)`, where
          `word_ids` is the ids of the sampled word symbols; `beam_ids` indicates which
          beam each sampled symbol comes from; `next_log_probs` is the accumulated
          log probabilities of each beam; `next_lengths` holds the decoding lengths of
          each beam.
          All of the Tensors have shape [batch_size * beam_size, ].
        """
        # [batch_size * beam_size,]
        prev_finished_float = tf.to_float(finished)
        # [batch_size * beam_size, ]
        prev_log_probs = log_probs
        # [batch_size * beam_size, target_vocab_size]
        probs = self._compute_log_probs(logits)

        # mask the finished beams, leaving a single open entry (target_eos_id)
        #   [target_vocab_size, ]: [float_min, float_min, float_min, ..., 0]
        #   this forces a beam that has emitted EOS to keep generating EOS
        finished_beam_bias = finished_beam_one_entry_bias(
            on_entry=self._vocab.eos_id, num_entries=self._vocab.vocab_size)
        # [batch_size * beam_size, target_vocab_size]: outer product
        finished_beam_bias = expand_to_beam_size(
            finished_beam_bias, self._beam_size * self._batch_size, axis=0)
        finished_beam_bias *= tf.expand_dims(prev_finished_float, 1)
        # compute new probs, with finished flags & mask
        probs = probs * tf.expand_dims(1. - prev_finished_float, 1) + finished_beam_bias

        # [batch_size * beam_size, target_vocab_size]
        # compute new log_probs
        log_probs = probs + tf.expand_dims(prev_log_probs, 1)
        # new decoding length: [batch_size * beam_size]
        lengths = lengths + 1 - tf.to_int32(finished)
        # compute beam score
        #  length_penalty: [batch_size * beam_size,]
        length_penalty = compute_length_penalty(lengths, self._alpha)
        scores = log_probs * tf.expand_dims(length_penalty, axis=1)

        # flatten: [batch_size, beam_size * target_vocab_size]
        scores = tf.reshape(tf.reshape(scores, [-1]),
                            [self._batch_size, -1])
        scores_flat = tf.cond(
            tf.convert_to_tensor(time) > 0, lambda: scores,  # time > 0: all
            lambda: tf.slice(scores, [0, 0],
                             [-1, self._vocab.vocab_size]))  # time = 0: first logits in each batch

        # [batch_size, beam_size]: keep the top-`beam_size` candidates per batch entry
        sample_scores, sample_ids = tf.nn.top_k(scores_flat, k=self._beam_size)
        # flatten: [batch_size * beam_size,]
        sample_ids = tf.reshape(sample_ids, [-1])

        # because top_k runs over scores of dim [batch, beam * vocab],
        #   we need to recover the true word ids
        word_ids = tf.mod(sample_ids, self._vocab.vocab_size)

        # find beam_ids, indicating which beam each position comes from
        #  batch_pos, [batch_size, beam_size]: [[0, 0, ...], [1, 1, ...], ..., [batch_size - 1, ...]]
        batch_pos = compute_batch_indices(self._batch_size, self._beam_size)
        #  beam_base_pos: [batch_size * beam_size,]: [0, 0, ..., beam, beam,..., 2beam, 2beam, ...]
        beam_base_pos = tf.reshape(batch_pos * self._beam_size, [-1])
        # compute new beam_ids, [batch_size * beam_size, ]
        beam_ids = tf.div(sample_ids, self._vocab.vocab_size) + beam_base_pos

        # gather states according to beam_ids
        next_lengths = gather_states(lengths, beam_ids)

        # we need to recover log_probs according to scores's topk ids
        # [batch_size * beam_size * vocab_size, ]
        log_probs_flat = tf.reshape(log_probs, [-1])
        log_probs_index = beam_base_pos * self._vocab.vocab_size + sample_ids
        next_log_probs = tf.gather(log_probs_flat, log_probs_index)

        return word_ids, beam_ids, next_log_probs, next_lengths
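`compute_length_penalty` is evidently the factored-out version of the formula Example 6 inlines, the GNMT-style normalization `((5 + length) / 6) ** -alpha`. A sketch consistent with both call sites:

    def compute_length_penalty(lengths, alpha):
        # GNMT-style length penalty. Multiplying the (negative) log
        # probabilities by this factor is equivalent to dividing the beam
        # score by ((5 + length) / 6) ** alpha; alpha = 0 disables it.
        return tf.pow((5.0 + tf.to_float(lengths)) / 6.0, -alpha)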