def test_step(self):
        """Check one beam-search step when no beam has finished yet.

        Every beam starts with the same log probability, a length of 2 and
        `finished=False`. A few logits are raised above a 0.0001 background
        so the winning (parent beam, token) pairs are known in advance. The
        test then checks the selected ids, the new lengths and finished
        flags, and that each selected beam's log probability equals its
        parent's log probability plus the log-softmax of the chosen token.
        The hard-coded expectations are 2 x 3 arrays.
        """
        beam_shape = [self.batch_size, self.beam_width]

        # Initial beam state: identical beams, none of them finished.
        placeholder_cell_state = array_ops.zeros(beam_shape)
        uniform_log_probs = nn_ops.log_softmax(array_ops.ones(beam_shape))
        initial_lengths = constant_op.constant(
            2, shape=beam_shape, dtype=dtypes.int32)
        nothing_finished = array_ops.zeros(beam_shape, dtype=dtypes.bool)
        beam_state = beam_search_decoder.BeamSearchDecoderState(
            cell_state=placeholder_cell_state,
            log_probs=uniform_log_probs,
            lengths=initial_lengths,
            finished=nothing_finished)

        # (batch, beam, token) -> logit; everything else stays at 0.0001.
        boosted_logits = (
            ((0, 0, 2), 1.9),
            ((0, 0, 3), 2.1),
            ((0, 1, 3), 3.1),
            ((0, 1, 4), 0.9),
            ((1, 0, 1), 0.5),
            ((1, 1, 2), 2.7),
            ((1, 2, 2), 10.0),
            ((1, 2, 3), 0.2),
        )
        logits_ = np.full(beam_shape + [self.vocab_size], 0.0001)
        for index, value in boosted_logits:
            logits_[index] = value
        logits = ops.convert_to_tensor(logits_, dtype=dtypes.float32)
        log_probs = nn_ops.log_softmax(logits)

        outputs, next_beam_state = beam_search_decoder._beam_search_step(
            time=2,
            logits=logits,
            beam_state=beam_state,
            batch_size=ops.convert_to_tensor(self.batch_size),
            beam_width=self.beam_width,
            end_token=self.end_token,
            length_penalty_weight=self.length_penalty_weight)

        with self.test_session() as sess:
            fetches = [outputs, next_beam_state, beam_state, log_probs]
            outputs_, next_state_, state_, log_probs_ = sess.run(fetches)

        expected_predicted_ids = np.array([[3, 3, 2], [2, 2, 1]])
        expected_parent_ids = np.array([[1, 0, 0], [2, 1, 0]])

        np.testing.assert_array_equal(outputs_.predicted_ids,
                                      expected_predicted_ids)
        np.testing.assert_array_equal(outputs_.parent_ids,
                                      expected_parent_ids)
        np.testing.assert_array_equal(next_state_.lengths,
                                      [[3, 3, 3], [3, 3, 3]])
        np.testing.assert_array_equal(
            next_state_.finished,
            [[False, False, False], [False, False, False]])

        # Parent log prob + log prob of the chosen token, per selected beam.
        expected_log_probs = [
            state_.log_probs[batch][parents]
            + log_probs_[batch, parents, tokens]
            for batch, (parents, tokens) in enumerate(
                zip(expected_parent_ids, expected_predicted_ids))
        ]
        np.testing.assert_array_equal(next_state_.log_probs,
                                      expected_log_probs)
Ejemplo n.º 2
0
    def test_step_with_eos(self):
        """Check one beam-search step when some beams are already finished.

        Beam 1 of batch 0 and beam 2 of batch 1 start out finished (with
        length 1); the others are unfinished with length 2. The test checks
        the selected parent and token ids, that finished beams keep their
        length while continuing beams grow to 3, and that only beams
        descending from an unfinished parent have a token log probability
        added to the parent's log probability.
        """
        beam_shape = [self.batch_size, self.beam_width]

        placeholder_cell_state = array_ops.zeros(beam_shape)
        uniform_log_probs = nn_ops.log_softmax(array_ops.ones(beam_shape))
        initial_lengths = ops.convert_to_tensor(
            [[2, 1, 2], [2, 2, 1]], dtype=dtypes.int64)
        initially_finished = ops.convert_to_tensor(
            [[False, True, False], [False, False, True]], dtype=dtypes.bool)
        beam_state = beam_search_decoder.BeamSearchDecoderState(
            cell_state=placeholder_cell_state,
            log_probs=uniform_log_probs,
            lengths=initial_lengths,
            finished=initially_finished,
            accumulated_attention_probs=())

        # (batch, beam, token) -> logit; everything else stays at 0.0001.
        boosted_logits = (
            ((0, 0, 2), 1.9),
            ((0, 0, 3), 2.1),
            ((0, 1, 3), 3.1),
            ((0, 1, 4), 0.9),
            ((1, 0, 1), 0.5),
            # Open question from the original author: the expectations
            # below do not hold when this logit is 2.7 instead of 5.7.
            ((1, 1, 2), 5.7),
            ((1, 2, 2), 1.0),
            ((1, 2, 3), 0.2),
        )
        logits_ = np.full(beam_shape + [self.vocab_size], 0.0001)
        for index, value in boosted_logits:
            logits_[index] = value
        logits = ops.convert_to_tensor(logits_, dtype=dtypes.float32)
        log_probs = nn_ops.log_softmax(logits)

        outputs, next_beam_state = beam_search_decoder._beam_search_step(
            time=2,
            logits=logits,
            next_cell_state=placeholder_cell_state,
            beam_state=beam_state,
            batch_size=ops.convert_to_tensor(self.batch_size),
            beam_width=self.beam_width,
            end_token=self.end_token,
            length_penalty_weight=self.length_penalty_weight,
            coverage_penalty_weight=self.coverage_penalty_weight)

        with self.cached_session() as sess:
            fetches = [outputs, next_beam_state, beam_state, log_probs]
            outputs_, next_state_, state_, log_probs_ = sess.run(fetches)

        self.assertAllEqual(outputs_.parent_ids, [[1, 0, 0], [1, 2, 0]])
        self.assertAllEqual(outputs_.predicted_ids, [[0, 3, 2], [2, 0, 1]])
        self.assertAllEqual(next_state_.lengths, [[1, 3, 3], [3, 1, 3]])
        self.assertAllEqual(next_state_.finished,
                            [[True, False, False], [False, True, False]])

        # Start from the parents' log probs (fancy indexing yields copies).
        expected_log_probs = [
            state_.log_probs[0][[1, 0, 0]],
            state_.log_probs[1][[1, 2, 0]],
        ]
        # Only beams whose parent was unfinished gain a token log prob:
        # (batch, output slot) -> (batch, parent beam, token).
        token_contributions = (
            ((0, 1), (0, 0, 3)),
            ((0, 2), (0, 0, 2)),
            ((1, 0), (1, 1, 2)),
            ((1, 2), (1, 0, 1)),
        )
        for (batch, slot), step_index in token_contributions:
            expected_log_probs[batch][slot] += log_probs_[step_index]
        self.assertAllEqual(next_state_.log_probs, expected_log_probs)
    def initialize(self, name=None):
        """Build the tensors that start beam-search decoding.

        Args:
            name: Name scope for any created operations. It is accepted for
                interface compatibility; this implementation does not use it.

        Returns:
            `(finished, start_inputs, initial_state)`, where `initial_state`
            is a `BeamSearchDecoderState` holding the initial cell state,
            per-beam log probabilities, finished flags, sequence lengths and
            accumulated attention probabilities (an empty tuple when
            `get_attention_probs` returns None).
        """
        finished, start_inputs = self._finished, self._start_inputs

        # Use the dtype of the first tensor in the (possibly nested) cell
        # state for the log probabilities.
        dtype = nest.flatten(self._initial_cell_state)[0].dtype

        if self._start_token_logits is None:
            # Only beam 0 of every batch entry starts "alive" (log prob 0);
            # all other beams get -inf.
            # `np.inf` is used instead of `np.Inf`: the capitalised alias was
            # removed in NumPy 2.0.
            log_probs = array_ops.one_hot(  # shape(batch_sz, beam_sz)
                array_ops.zeros([self._batch_size], dtype=dtypes.int32),
                depth=self._beam_width,
                on_value=ops.convert_to_tensor(0.0, dtype=dtype),
                off_value=ops.convert_to_tensor(-np.inf, dtype=dtype),
                dtype=dtype)
        else:
            # Caller-supplied start-token logits are used directly as the
            # initial per-beam log probabilities.
            log_probs = self._start_token_logits

        sequence_lengths = array_ops.zeros(
            [self._batch_size, self._beam_width], dtype=dtypes.int64)

        # Start tokens are part of the output if no _GO token is used, so the
        # finished flags and lengths must account for them.
        if not self._use_go_tokens:
            # A beam whose start token already equals the end token is
            # finished from the outset.
            finished = math_ops.equal(self._start_tokens, self._raw_end_token)

            # Unfinished beams already contain one token (the start token);
            # finished beams keep length 0.
            sequence_lengths = array_ops.where(
                math_ops.logical_not(finished),
                array_ops.fill(array_ops.shape(sequence_lengths),
                               tf.constant(1, dtype=dtypes.int64)),
                sequence_lengths)

        init_attention_probs = beam_search_decoder.get_attention_probs(
            self._initial_cell_state, self._coverage_penalty_weight)
        if init_attention_probs is None:
            # The state uses an empty tuple, not None, for "no attention".
            init_attention_probs = ()

        initial_state = beam_search_decoder.BeamSearchDecoderState(
            cell_state=self._initial_cell_state,
            log_probs=log_probs,
            finished=finished,
            lengths=sequence_lengths,
            accumulated_attention_probs=init_attention_probs)

        return (finished, start_inputs, initial_state)
    def test_step(self):
        """Check one beam-search step from an initialize-like state.

        Only beam 0 of each batch entry is alive: it has log probability 0,
        length 2 and `finished=False`, while every other beam has log
        probability -inf, length 0 and `finished=True`. After one step the
        test checks the first predicted ids, that the last three beams still
        have -inf log probability and zero length, and that all earlier
        beams have a finite log probability and a positive length.
        """
        def get_probs():
            """this simulates the initialize method in BeamSearchDecoder."""
            log_prob_mask = array_ops.one_hot(array_ops.zeros(
                [self.batch_size], dtype=dtypes.int32),
                                              depth=self.beam_width,
                                              on_value=True,
                                              off_value=False,
                                              dtype=dtypes.bool)

            log_prob_zeros = array_ops.zeros(
                [self.batch_size, self.beam_width], dtype=dtypes.float32)
            # `np.inf` rather than `np.Inf`: the capitalised alias was
            # removed in NumPy 2.0.
            log_prob_neg_inf = array_ops.ones(
                [self.batch_size, self.beam_width],
                dtype=dtypes.float32) * -np.inf

            # 0 for beam 0, -inf for every other beam.
            return array_ops.where(log_prob_mask, log_prob_zeros,
                                   log_prob_neg_inf)

        initial_log_probs = get_probs()
        dummy_cell_state = array_ops.zeros([self.batch_size, self.beam_width])

        # Beam 0 is unfinished, all other beams start finished.
        # pylint: disable=invalid-name
        _finished = array_ops.one_hot(array_ops.zeros([self.batch_size],
                                                      dtype=dtypes.int32),
                                      depth=self.beam_width,
                                      on_value=False,
                                      off_value=True,
                                      dtype=dtypes.bool)
        # Beam 0 has length 2, all other beams have length 0.
        _lengths = np.zeros([self.batch_size, self.beam_width], dtype=np.int64)
        _lengths[:, 0] = 2
        _lengths = constant_op.constant(_lengths, dtype=dtypes.int64)

        beam_state = beam_search_decoder.BeamSearchDecoderState(
            cell_state=dummy_cell_state,
            log_probs=initial_log_probs,
            lengths=_lengths,
            finished=_finished)

        # A few logits are raised above the 0.0001 background.
        logits_ = np.full([self.batch_size, self.beam_width, self.vocab_size],
                          0.0001)
        logits_[0, 0, 2] = 1.9
        logits_[0, 0, 3] = 2.1
        logits_[0, 1, 3] = 3.1
        logits_[0, 1, 4] = 0.9
        logits_[1, 0, 1] = 0.5
        logits_[1, 1, 2] = 2.7
        logits_[1, 2, 2] = 10.0
        logits_[1, 2, 3] = 0.2
        logits = constant_op.constant(logits_, dtype=dtypes.float32)
        step_log_probs = nn_ops.log_softmax(logits)

        outputs, next_beam_state = beam_search_decoder._beam_search_step(
            time=2,
            logits=logits,
            next_cell_state=dummy_cell_state,
            beam_state=beam_state,
            batch_size=ops.convert_to_tensor(self.batch_size),
            beam_width=self.beam_width,
            end_token=self.end_token,
            length_penalty_weight=self.length_penalty_weight)

        with self.test_session() as sess:
            # The input state and step log probs are fetched but not needed.
            outputs_, next_state_, _, _ = sess.run(
                [outputs, next_beam_state, beam_state, step_log_probs])

        self.assertEqual(outputs_.predicted_ids[0, 0], 3)
        self.assertEqual(outputs_.predicted_ids[0, 1], 2)
        self.assertEqual(outputs_.predicted_ids[1, 0], 1)

        neg_inf = -np.inf
        # The last three beams remain at -inf with zero length ...
        self.assertAllEqual(
            next_state_.log_probs[:, -3:],
            [[neg_inf, neg_inf, neg_inf], [neg_inf, neg_inf, neg_inf]])
        # ... while every earlier beam is finite with a positive length.
        self.assertTrue((next_state_.log_probs[:, :-3] > neg_inf).all())
        self.assertTrue((next_state_.lengths[:, :-3] > 0).all())
        self.assertAllEqual(next_state_.lengths[:, -3:],
                            [[0, 0, 0], [0, 0, 0]])
Ejemplo n.º 5
0
def _beam_search_step(time, logits, next_cell_state, beam_state, batch_size,
                      beam_width, end_token, length_penalty_weight,
                      coverage_penalty_weight):
    """Advance beam search decoding by a single step.

    Args:
        time: Beam search time step, should start at 0. It is converted to a
            tensor named "time" but is not otherwise consumed by this
            function.
        logits: Logits at the current time step. A tensor of shape
            `[batch_size, beam_width, vocab_size]`.
        next_cell_state: The next state from the cell, e.g. an instance of
            AttentionWrapperState if the cell is attentional.
        beam_state: Current state of the beam search. An instance of
            `BeamSearchDecoderState`.
        batch_size: The batch size for this input.
        beam_width: Python int. The size of the beams.
        end_token: The int32 end token.
        length_penalty_weight: Float weight to penalize length. Disabled
            with 0.0.
        coverage_penalty_weight: Float weight to penalize the coverage of
            the source sentence. Disabled with 0.0.

    Returns:
        A tuple `(output, next_state)`. `output` is a
        `BeamSearchDecoderOutput` carrying the selected scores, predicted
        token ids and parent beam ids; `next_state` is the updated
        `BeamSearchDecoderState`.
    """
    static_batch_size = tensor_util.constant_value(batch_size)

    # Snapshot of the incoming beams.
    beam_lengths = beam_state.lengths
    beam_finished = beam_state.finished
    beam_active = math_ops.logical_not(beam_finished)

    # Total log probability of every candidate continuation.
    # Final shape: [batch_size, beam_width, vocab_size]
    token_log_probs = _mask_probs(
        nn_ops.log_softmax(logits), end_token, beam_finished)
    candidate_log_probs = (
        array_ops.expand_dims(beam_state.log_probs, 2) + token_log_probs)

    # Candidate lengths: the one-hot puts 0 at the end-token position and 1
    # at every other token; finished beams are then masked to add nothing.
    vocab_size = logits.shape[-1].value or array_ops.shape(logits)[-1]
    end_token_per_beam = array_ops.fill([batch_size, beam_width], end_token)
    length_increments = array_ops.one_hot(indices=end_token_per_beam,
                                          depth=vocab_size,
                                          on_value=np.int64(0),
                                          off_value=np.int64(1),
                                          dtype=dtypes.int64)
    active_as_int = math_ops.to_int64(beam_active)
    length_increments = (
        length_increments * array_ops.expand_dims(active_as_int, 2))
    candidate_lengths = (
        length_increments + array_ops.expand_dims(beam_lengths, 2))

    # Accumulate attention probabilities when the helper returns any;
    # finished beams contribute nothing.
    step_attention = get_attention_probs(next_cell_state,
                                         coverage_penalty_weight)
    if step_attention is None:
        accumulated_attention_probs = None
    else:
        active_as_float = array_ops.expand_dims(
            math_ops.to_float(beam_active), 2)
        step_attention *= active_as_float
        accumulated_attention_probs = (
            beam_state.accumulated_attention_probs + step_attention)

    # Score every candidate.
    scores = _get_scores(
        log_probs=candidate_log_probs,
        sequence_lengths=candidate_lengths,
        length_penalty_weight=length_penalty_weight,
        coverage_penalty_weight=coverage_penalty_weight,
        finished=beam_finished,
        accumulated_attention_probs=accumulated_attention_probs)

    # The converted time tensor is not used below.
    time = ops.convert_to_tensor(time, name="time")

    # Flatten the (beam, token) axes so top_k ranks all candidates of a
    # batch entry together.
    flat_scores = array_ops.reshape(scores, [batch_size, -1])
    top_k_width = ops.convert_to_tensor(beam_width,
                                        dtype=dtypes.int32,
                                        name="beam_width")
    top_scores, flat_indices = nn_ops.top_k(flat_scores, k=top_k_width)

    static_beam_shape = [static_batch_size, beam_width]
    top_scores.set_shape(static_beam_shape)
    flat_indices.set_shape(static_beam_shape)

    # Log probabilities of the winning candidates.
    selected_log_probs = _tensor_gather_helper(
        gather_indices=flat_indices,
        gather_from=candidate_log_probs,
        batch_size=batch_size,
        range_size=beam_width * vocab_size,
        gather_shape=[-1],
        name="next_beam_probs")

    # Decompose each flat index into (parent beam, token id).
    # Note: casting `flat_indices % vocab_size` directly under the name
    # "next_beam_word_ids" would be cleaner, but for reasons unclear that
    # hides the result of the op, which prevents capturing it with tfdbg
    # debug ops. Hence the separately named mod op.
    raw_token_ids = math_ops.mod(flat_indices,
                                 vocab_size,
                                 name="next_beam_word_ids")
    token_ids = math_ops.to_int32(raw_token_ids)
    parent_beam_ids = math_ops.to_int32(flat_indices / vocab_size,
                                        name="next_beam_parent_ids")

    # A selected beam is finished if its parent already was, or if it has
    # just emitted the end token.
    parent_finished = _tensor_gather_helper(
        gather_indices=parent_beam_ids,
        gather_from=beam_finished,
        batch_size=batch_size,
        range_size=beam_width,
        gather_shape=[-1])
    finished_after_step = math_ops.logical_or(
        parent_finished,
        math_ops.equal(token_ids, end_token),
        name="next_beam_finished")

    # Lengths of the selected beams:
    # 1. Beams whose parent was finished remain unchanged.
    # 2. Beams that finish now (EOS predicted) grow by 1.
    # 3. Beams that are still unfinished grow by 1.
    growth = math_ops.to_int64(math_ops.logical_not(parent_finished))
    lengths_after_step = _tensor_gather_helper(
        gather_indices=parent_beam_ids,
        gather_from=beam_state.lengths,
        batch_size=batch_size,
        range_size=beam_width,
        gather_shape=[-1])
    lengths_after_step += growth

    # Reorder accumulated attention to follow the selected parents; the
    # state stores an empty tuple when there is no attention to track.
    if accumulated_attention_probs is None:
        attention_after_step = ()
    else:
        attention_after_step = _tensor_gather_helper(
            gather_indices=parent_beam_ids,
            gather_from=accumulated_attention_probs,
            batch_size=batch_size,
            range_size=beam_width,
            gather_shape=[batch_size * beam_width, -1],
            name="next_accumulated_attention_probs")

    def _follow_parents(cell_state_part):
        """Reorder one cell-state component by the selected parent beams."""
        return _maybe_tensor_gather_helper(
            gather_indices=parent_beam_ids,
            gather_from=cell_state_part,
            batch_size=batch_size,
            range_size=beam_width,
            gather_shape=[batch_size * beam_width, -1])

    # Cell-state tensors use a different gather_shape because they all have
    # dimension greater than two and those dimensions must be preserved.
    cell_state_after_step = nest.map_structure(_follow_parents,
                                               next_cell_state)

    next_state = beam_search_decoder.BeamSearchDecoderState(
        cell_state=cell_state_after_step,
        log_probs=selected_log_probs,
        lengths=lengths_after_step,
        finished=finished_after_step,
        accumulated_attention_probs=attention_after_step)

    output = beam_search_decoder.BeamSearchDecoderOutput(
        scores=top_scores,
        predicted_ids=token_ids,
        parent_ids=parent_beam_ids)

    return output, next_state