Example #1

# TF 1.x contrib-era test; assumes module-level imports along the lines of
# `import numpy as np`, `constant_op` and `dtypes` from `tensorflow.python.framework`,
# and `beam_search_decoder` from `tensorflow.contrib.seq2seq.python.ops`.
def test_eos_masking(self):
        probs = constant_op.constant([
            [[-.2, -.2, -.2, -.2, -.2], [-.3, -.3, -.3, 3, 0], [5, 6, 0, 0, 0]],
            [[-.2, -.2, -.2, -.2, 0], [-.3, -.3, -.1, 3, 0], [5, 6, 3, 0, 0]],
        ])

        eos_token = 0
        previously_finished = constant_op.constant([[0, 1, 0], [0, 1, 1]],
                                                   dtype=dtypes.float32)
        masked = beam_search_decoder._mask_probs(probs, eos_token,
                                                 previously_finished)

        with self.test_session() as sess:
            probs = sess.run(probs)
            masked = sess.run(masked)

            self.assertAllEqual(probs[0][0], masked[0][0])
            self.assertAllEqual(probs[0][2], masked[0][2])
            self.assertAllEqual(probs[1][0], masked[1][0])

            self.assertEqual(masked[0][1][0], 0)
            self.assertEqual(masked[1][1][0], 0)
            self.assertEqual(masked[1][2][0], 0)

            for i in range(1, 5):
                self.assertAllClose(masked[0][1][i], np.finfo('float32').min)
                self.assertAllClose(masked[1][1][i], np.finfo('float32').min)
                self.assertAllClose(masked[1][2][i], np.finfo('float32').min)
Example #2

  def test_eos_masking(self):
    probs = constant_op.constant([
        [[-.2, -.2, -.2, -.2, -.2], [-.3, -.3, -.3, 3, 0], [5, 6, 0, 0, 0]],
        [[-.2, -.2, -.2, -.2, 0], [-.3, -.3, -.1, 3, 0], [5, 6, 3, 0, 0]],
    ])

    eos_token = 0
    previously_finished = np.array([[0, 1, 0], [0, 1, 1]], dtype=bool)
    masked = beam_search_decoder._mask_probs(probs, eos_token,
                                             previously_finished)

    with self.cached_session() as sess:
      probs = sess.run(probs)
      masked = sess.run(masked)

      self.assertAllEqual(probs[0][0], masked[0][0])
      self.assertAllEqual(probs[0][2], masked[0][2])
      self.assertAllEqual(probs[1][0], masked[1][0])

      self.assertEqual(masked[0][1][0], 0)
      self.assertEqual(masked[1][1][0], 0)
      self.assertEqual(masked[1][2][0], 0)

      for i in range(1, 5):
        self.assertAllClose(masked[0][1][i], np.finfo('float32').min)
        self.assertAllClose(masked[1][1][i], np.finfo('float32').min)
        self.assertAllClose(masked[1][2][i], np.finfo('float32').min)
Example #3

  def test_eos_masking(self):
    probs = constant_op.constant([
        [[-.2, -.2, -.2, -.2, -.2], [-.3, -.3, -.3, 3, 0], [5, 6, 0, 0, 0]],
        [[-.2, -.2, -.2, -.2, 0], [-.3, -.3, -.1, 3, 0], [5, 6, 3, 0, 0]],
    ])

    eos_token = 0
    previously_finished = constant_op.constant(
        [[0, 1, 0], [0, 1, 1]], dtype=dtypes.float32)
    masked = beam_search_decoder._mask_probs(probs, eos_token,
                                             previously_finished)

    with self.test_session() as sess:
      probs = sess.run(probs)
      masked = sess.run(masked)

      np.testing.assert_array_equal(probs[0][0], masked[0][0])
      np.testing.assert_array_equal(probs[0][2], masked[0][2])
      np.testing.assert_array_equal(probs[1][0], masked[1][0])

      np.testing.assert_equal(masked[0][1][0], 0)
      np.testing.assert_equal(masked[1][1][0], 0)
      np.testing.assert_equal(masked[1][2][0], 0)

      for i in range(1, 5):
        np.testing.assert_approx_equal(masked[0][1][i], np.finfo('float32').min)
        np.testing.assert_approx_equal(masked[1][1][i], np.finfo('float32').min)
        np.testing.assert_approx_equal(masked[1][2][i], np.finfo('float32').min)
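
Taken together, the three tests above pin down the contract of `beam_search_decoder._mask_probs`: rows belonging to unfinished beams pass through untouched, while rows belonging to already-finished beams keep only the EOS token (with score 0) and have every other token forced to the smallest float32 value. As a rough illustration only (not the TensorFlow implementation), the expected behaviour can be sketched in plain NumPy:

import numpy as np

def mask_probs_sketch(probs, eos_token, finished):
    """Illustrative NumPy-only version of the behaviour the tests assert.

    probs:    float32 array [batch, beam, vocab] of per-token scores.
    finished: bool array [batch, beam]; True for beams that already emitted EOS.
    """
    masked = np.array(probs, dtype=np.float32, copy=True)
    # For finished beams, push every token down to the float32 minimum ...
    masked[finished] = np.finfo(np.float32).min
    # ... except the EOS token, which keeps score 0 so the beam can only repeat EOS.
    masked[finished, eos_token] = 0.0
    return masked

probs = np.array([
    [[-.2, -.2, -.2, -.2, -.2], [-.3, -.3, -.3, 3, 0], [5, 6, 0, 0, 0]],
    [[-.2, -.2, -.2, -.2, 0], [-.3, -.3, -.1, 3, 0], [5, 6, 3, 0, 0]],
], dtype=np.float32)
finished = np.array([[False, True, False], [False, True, True]])

masked = mask_probs_sketch(probs, eos_token=0, finished=finished)
assert np.array_equal(masked[0][0], probs[0][0])             # unfinished beam unchanged
assert masked[0][1][0] == 0                                  # finished beam keeps EOS at 0
assert np.all(masked[0][1][1:] == np.finfo(np.float32).min)  # everything else masked out
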
Example #4
def _beam_search_step(time, logits, next_cell_state, beam_state, peptide_mass,
                      batch_size, beam_width, end_token, length_penalty_weight,
                      suffix_dp_table, aa_weight_table):
    """Performs a single step of Beam Search Decoding.
  Args:
    time: Beam search time step, should start at 0. At time 0 we assume
      that all beams are equal and consider only the first beam for
      continuations.
    logits: Logits at the current time step. A tensor of shape
      `[batch_size, beam_width, vocab_size]`
    next_cell_state: The next state from the cell, e.g. an instance of
      AttentionWrapperState if the cell is attentional.
    beam_state: Current state of the beam search.
      An instance of `IonBeamSearchDecoderState`.
    batch_size: The batch size for this input.
    beam_width: Python int.  The size of the beams.
    end_token: The int32 end token.
    length_penalty_weight: Float weight to penalize length. Disabled with 0.0.
  Returns:
    A new beam state.
  """
    static_batch_size = tensor_util.constant_value(batch_size)

    # Calculate the current lengths of the predictions
    prediction_lengths = beam_state.lengths
    previously_finished = beam_state.finished

    # Calculate the total log probs for the new hypotheses
    # Final Shape: [batch_size, beam_width, vocab_size]
    step_log_probs = nn_ops.log_softmax(logits)
    step_log_probs = beam_search_decoder._mask_probs(step_log_probs, end_token,
                                                     previously_finished)
    total_probs = array_ops.expand_dims(beam_state.log_probs,
                                        2) + step_log_probs

    # Penalize beams with invalid total mass according to dp array of possible mass
    new_total_probs = penalize_invalid_mass(beam_state.prefix_mass,
                                            total_probs, peptide_mass,
                                            beam_width, suffix_dp_table,
                                            aa_weight_table)
    total_probs = new_total_probs

    # Calculate the continuation lengths by adding to all continuing beams.
    vocab_size = logits.shape[-1].value or array_ops.shape(logits)[-1]
    lengths_to_add = array_ops.one_hot(
        indices=array_ops.fill([batch_size, beam_width], end_token),
        depth=vocab_size,
        on_value=np.int64(0),
        off_value=np.int64(1),
        dtype=dtypes.int64)
    add_mask = math_ops.to_int64(math_ops.logical_not(previously_finished))
    lengths_to_add *= array_ops.expand_dims(add_mask, 2)
    new_prediction_lengths = (lengths_to_add +
                              array_ops.expand_dims(prediction_lengths, 2))

    # Calculate the scores for each beam
    scores = beam_search_decoder._get_scores(
        log_probs=total_probs,
        sequence_lengths=new_prediction_lengths,
        length_penalty_weight=length_penalty_weight)

    time = ops.convert_to_tensor(time, name="time")
    # During the first time step we only consider the initial beam
    scores_shape = array_ops.shape(scores)
    scores_flat = array_ops.reshape(scores, [batch_size, -1])

    # Pick the next beams according to the specified successors function
    next_beam_size = ops.convert_to_tensor(beam_width,
                                           dtype=dtypes.int32,
                                           name="beam_width")
    next_beam_scores, word_indices = nn_ops.top_k(scores_flat,
                                                  k=next_beam_size)

    # word_indices = tf.Print(word_indices, [tf.shape(scores_flat)], message="** next beam shape")
    next_beam_scores.set_shape([static_batch_size, beam_width])
    word_indices.set_shape([static_batch_size, beam_width])

    # Pick out the probs, beam_ids, and states according to the chosen predictions
    next_beam_probs = beam_search_decoder._tensor_gather_helper(
        gather_indices=word_indices,
        gather_from=total_probs,
        batch_size=batch_size,
        range_size=beam_width * vocab_size,
        gather_shape=[-1],
        name="next_beam_probs")

    # Pick out log probs for each step according to next_beam_id
    next_step_probs = beam_search_decoder._tensor_gather_helper(
        gather_indices=word_indices,
        gather_from=step_log_probs,
        batch_size=batch_size,
        range_size=beam_width * vocab_size,
        gather_shape=[-1],
        name="next_step_probs")
    # Note: just doing the following
    #   math_ops.to_int32(word_indices % vocab_size,
    #       name="next_beam_word_ids")
    # would be a lot cleaner but for reasons unclear, that hides the results of
    # the op which prevents capturing it with tfdbg debug ops.
    raw_next_word_ids = math_ops.mod(word_indices,
                                     vocab_size,
                                     name="next_beam_word_ids")
    next_word_ids = math_ops.to_int32(raw_next_word_ids)
    next_beam_ids = math_ops.to_int32(word_indices / vocab_size,
                                      name="next_beam_parent_ids")

    # Append new ids to current predictions
    previously_finished = beam_search_decoder._tensor_gather_helper(
        gather_indices=next_beam_ids,
        gather_from=previously_finished,
        batch_size=batch_size,
        range_size=beam_width,
        gather_shape=[-1])
    next_finished = math_ops.logical_or(previously_finished,
                                        math_ops.equal(next_word_ids,
                                                       end_token),
                                        name="next_beam_finished")

    # Calculate the length of the next predictions.
    # 1. Finished beams remain unchanged.
    # 2. Beams that are now finished (EOS predicted) have their length
    #    increased by 1.
    # 3. Beams that are not yet finished have their length increased by 1.
    lengths_to_add = math_ops.to_int64(
        math_ops.logical_not(previously_finished))
    next_prediction_len = beam_search_decoder._tensor_gather_helper(
        gather_indices=next_beam_ids,
        gather_from=beam_state.lengths,
        batch_size=batch_size,
        range_size=beam_width,
        gather_shape=[-1])
    next_prediction_len += lengths_to_add

    # Pick prefix_mass according to the next_beam_id
    next_prefix_mass = beam_search_decoder._tensor_gather_helper(
        gather_indices=next_beam_ids,
        gather_from=beam_state.prefix_mass,
        batch_size=batch_size,
        range_size=beam_width,
        gather_shape=[-1])
    next_prefix_mass = next_prefix_mass + tf.gather(aa_weight_table,
                                                    next_word_ids)

    # Pick out the cell_states according to the next_beam_ids. We use a
    # different gather_shape here because the cell_state tensors, i.e.
    # the tensors that would be gathered from, all have dimension
    # greater than two and we need to preserve those dimensions.
    # pylint: disable=g-long-lambda
    next_cell_state = nest.map_structure(
        lambda gather_from: beam_search_decoder._maybe_tensor_gather_helper(
            gather_indices=next_beam_ids,
            gather_from=gather_from,
            batch_size=batch_size,
            range_size=beam_width,
            gather_shape=[batch_size * beam_width, -1]), next_cell_state)
    # pylint: enable=g-long-lambda

    next_state = IonBeamSearchDecoderState(cell_state=next_cell_state,
                                           log_probs=next_beam_probs,
                                           lengths=next_prediction_len,
                                           prefix_mass=next_prefix_mass,
                                           finished=next_finished)

    output = BeamSearchDecoderOutput(scores=next_beam_scores,
                                     predicted_ids=next_word_ids,
                                     parent_ids=next_beam_ids,
                                     step_log_probs=next_step_probs)

    return output, next_state
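
Two bookkeeping steps in `_beam_search_step` are easy to sanity-check in isolation: the flat `top_k` indices taken over the `[beam_width * vocab_size]` scores decompose into a parent beam id (`word_indices // vocab_size`) and an emitted token id (`word_indices % vocab_size`), and each surviving beam's prefix mass is its parent's prefix mass plus the mass of the newly emitted token. A minimal NumPy sketch with made-up numbers (the real code does this with TF ops, and the amino-acid masses below are hypothetical stand-ins for `aa_weight_table`):

import numpy as np

vocab_size = 5
beam_width = 3

# Flat indices returned by top_k over scores reshaped to [batch, beam_width * vocab_size].
word_indices = np.array([[7, 12, 3]])          # one batch entry, beam_width winners

next_beam_ids = word_indices // vocab_size     # parent beam each winner came from
next_word_ids = word_indices % vocab_size      # token each winner emitted
print(next_beam_ids)                           # [[1 2 0]]
print(next_word_ids)                           # [[2 2 3]]

# Prefix-mass update: gather the parent's accumulated mass, then add the mass of the
# emitted token (hypothetical residue masses standing in for aa_weight_table).
aa_weight_table = np.array([0.0, 71.03711, 156.10111, 114.04293, 115.02694])
prefix_mass = np.array([[100.0, 250.0, 420.0]])  # mass accumulated so far, per beam

next_prefix_mass = (np.take_along_axis(prefix_mass, next_beam_ids, axis=1)
                    + aa_weight_table[next_word_ids])
print(next_prefix_mass)                        # [[406.10111 576.10111 214.04293]]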