def test_eos_masking(self):
    """Check `_mask_probs` against a 2x3x5 score tensor with EOS id 0.

    Beams flagged as not finished must come back unchanged. Beams flagged
    as finished must score 0 on the EOS token and float32-min on every
    other token.
    """
    probs = constant_op.constant([
        [[-.2, -.2, -.2, -.2, -.2], [-.3, -.3, -.3, 3, 0], [5, 6, 0, 0, 0]],
        [[-.2, -.2, -.2, -.2, 0], [-.3, -.3, -.1, 3, 0], [5, 6, 3, 0, 0]],
    ])
    eos_token = 0
    previously_finished = constant_op.constant(
        [[0, 1, 0], [0, 1, 1]], dtype=dtypes.float32)
    masked = beam_search_decoder._mask_probs(
        probs, eos_token, previously_finished)

    # (batch, beam) coordinates, split according to the finished flags above.
    live_beams = ((0, 0), (0, 2), (1, 0))
    done_beams = ((0, 1), (1, 1), (1, 2))
    float32_min = np.finfo('float32').min

    with self.test_session() as sess:
        probs = sess.run(probs)
        masked = sess.run(masked)

        # Unfinished beams pass through untouched.
        for batch, beam in live_beams:
            self.assertAllEqual(probs[batch][beam], masked[batch][beam])

        # Finished beams keep all their mass on the EOS token ...
        for batch, beam in done_beams:
            self.assertEqual(masked[batch][beam][0], 0)

        # ... and every non-EOS token is pushed to the float32 minimum.
        for token in range(1, 5):
            for batch, beam in done_beams:
                self.assertAllClose(masked[batch][beam][token], float32_min)
def test_eos_masking(self):
    """Check `_mask_probs` with a boolean NumPy finished mask.

    Rows whose finished flag is False must be returned unchanged. Rows
    whose flag is True must hold 0 at the EOS index (0) and the float32
    minimum at indices 1 through 4.
    """
    eos_token = 0
    probs = constant_op.constant([
        [[-.2, -.2, -.2, -.2, -.2], [-.3, -.3, -.3, 3, 0], [5, 6, 0, 0, 0]],
        [[-.2, -.2, -.2, -.2, 0], [-.3, -.3, -.1, 3, 0], [5, 6, 3, 0, 0]],
    ])
    previously_finished = np.array([[0, 1, 0], [0, 1, 1]], dtype=bool)
    masked = beam_search_decoder._mask_probs(
        probs, eos_token, previously_finished)

    # (batch, beam) pairs grouped by the finished mask defined above.
    untouched_rows = [(0, 0), (0, 2), (1, 0)]
    masked_rows = [(0, 1), (1, 1), (1, 2)]
    lowest = np.finfo('float32').min

    with self.cached_session() as sess:
        probs = sess.run(probs)
        masked = sess.run(masked)

        for b, k in untouched_rows:
            self.assertAllEqual(probs[b][k], masked[b][k])

        for b, k in masked_rows:
            self.assertEqual(masked[b][k][0], 0)

        for i in range(1, 5):
            for b, k in masked_rows:
                self.assertAllClose(masked[b][k][i], lowest)
def test_eos_masking(self):
    """Check `_mask_probs` using `np.testing` assertions.

    Beams not marked finished must be left as-is. Beams marked finished
    must have score 0 for the EOS token (id 0) and approximately the
    float32 minimum for each of the remaining four tokens.
    """
    probs = constant_op.constant([
        [[-.2, -.2, -.2, -.2, -.2], [-.3, -.3, -.3, 3, 0], [5, 6, 0, 0, 0]],
        [[-.2, -.2, -.2, -.2, 0], [-.3, -.3, -.1, 3, 0], [5, 6, 3, 0, 0]],
    ])
    eos_token = 0
    previously_finished = constant_op.constant(
        [[0, 1, 0], [0, 1, 1]], dtype=dtypes.float32)
    masked = beam_search_decoder._mask_probs(
        probs, eos_token, previously_finished)

    # (batch, beam) coordinates, partitioned by the finished flags above.
    open_beams = ((0, 0), (0, 2), (1, 0))
    closed_beams = ((0, 1), (1, 1), (1, 2))
    floor_value = np.finfo('float32').min

    with self.test_session() as sess:
        probs = sess.run(probs)
        masked = sess.run(masked)

        # Open beams: masking is a no-op.
        for batch, beam in open_beams:
            np.testing.assert_array_equal(
                probs[batch][beam], masked[batch][beam])

        # Closed beams: EOS token scores exactly zero.
        for batch, beam in closed_beams:
            np.testing.assert_equal(masked[batch][beam][0], 0)

        # Closed beams: all other tokens sit at the float32 floor.
        for token in range(1, 5):
            for batch, beam in closed_beams:
                np.testing.assert_approx_equal(
                    masked[batch][beam][token], floor_value)
def _beam_search_step(time, logits, next_cell_state, beam_state, peptide_mass,
                      batch_size, beam_width, end_token,
                      length_penalty_weight, suffix_dp_table,
                      aa_weight_table):
    """Performs a single step of Beam Search Decoding.

    Args:
      time: Beam search time step. It is converted to a tensor (which
        validates the argument) but is not otherwise read in this function;
        this step applies no special handling for time 0.
        NOTE(review): first-step behaviour presumably comes from how the
        initial `beam_state.log_probs` is built -- confirm in the caller.
      logits: Logits at the current time step. A tensor of shape
        `[batch_size, beam_width, vocab_size]`
      next_cell_state: The next state from the cell, e.g. an instance of
        AttentionWrapperState if the cell is attentional.
      beam_state: Current state of the beam search. An instance of
        `IonBeamSearchDecoderState`; its `lengths`, `finished`, `log_probs`
        and `prefix_mass` fields are read here.
      peptide_mass: Forwarded unchanged to `penalize_invalid_mass`.
        NOTE(review): presumably the target total mass per batch entry --
        confirm against `penalize_invalid_mass`.
      batch_size: The batch size for this input.
      beam_width: Python int. The size of the beams.
      end_token: The int32 end token.
      length_penalty_weight: Float weight to penalize length. Disabled with
        0.0.
      suffix_dp_table: Forwarded unchanged to `penalize_invalid_mass`
        (the DP array of possible masses used for the mass penalty).
      aa_weight_table: Lookup table indexed by token id. The entry for each
        selected token is added to the parent beam's `prefix_mass`; the table
        is also forwarded to `penalize_invalid_mass`.

    Returns:
      A tuple `(output, next_state)` where `output` is a
      `BeamSearchDecoderOutput` (scores, predicted_ids, parent_ids,
      step_log_probs) and `next_state` is the updated
      `IonBeamSearchDecoderState`.
    """
    static_batch_size = tensor_util.constant_value(batch_size)

    # Calculate the current lengths of the predictions
    prediction_lengths = beam_state.lengths
    previously_finished = beam_state.finished

    # Calculate the total log probs for the new hypotheses
    # Final Shape: [batch_size, beam_width, vocab_size]
    step_log_probs = nn_ops.log_softmax(logits)
    step_log_probs = beam_search_decoder._mask_probs(
        step_log_probs, end_token, previously_finished)
    total_probs = (array_ops.expand_dims(beam_state.log_probs, 2) +
                   step_log_probs)

    # Penalize beams with invalid total mass according to dp array of
    # possible mass
    total_probs = penalize_invalid_mass(
        beam_state.prefix_mass, total_probs, peptide_mass, beam_width,
        suffix_dp_table, aa_weight_table)

    # Calculate the continuation lengths by adding to all continuing beams.
    # The one-hot yields 0 at `end_token` and 1 for every other token; the
    # mask below then zeroes the increment for beams that already finished.
    vocab_size = logits.shape[-1].value or array_ops.shape(logits)[-1]
    lengths_to_add = array_ops.one_hot(
        indices=array_ops.fill([batch_size, beam_width], end_token),
        depth=vocab_size,
        on_value=np.int64(0),
        off_value=np.int64(1),
        dtype=dtypes.int64)
    add_mask = math_ops.to_int64(math_ops.logical_not(previously_finished))
    lengths_to_add *= array_ops.expand_dims(add_mask, 2)
    new_prediction_lengths = (
        lengths_to_add + array_ops.expand_dims(prediction_lengths, 2))

    # Calculate the scores for each beam
    scores = beam_search_decoder._get_scores(
        log_probs=total_probs,
        sequence_lengths=new_prediction_lengths,
        length_penalty_weight=length_penalty_weight)

    # Kept for argument validation; the converted tensor is not used below.
    time = ops.convert_to_tensor(time, name="time")

    # Flatten the (beam, vocab) axes so top_k ranks every continuation of
    # every beam within a batch entry together.
    scores_flat = array_ops.reshape(scores, [batch_size, -1])

    # Pick the next beams according to the specified successors function
    next_beam_size = ops.convert_to_tensor(
        beam_width, dtype=dtypes.int32, name="beam_width")
    next_beam_scores, word_indices = nn_ops.top_k(
        scores_flat, k=next_beam_size)

    next_beam_scores.set_shape([static_batch_size, beam_width])
    word_indices.set_shape([static_batch_size, beam_width])

    # Pick out the probs, beam_ids, and states according to the chosen
    # predictions
    next_beam_probs = beam_search_decoder._tensor_gather_helper(
        gather_indices=word_indices,
        gather_from=total_probs,
        batch_size=batch_size,
        range_size=beam_width * vocab_size,
        gather_shape=[-1],
        name="next_beam_probs")

    # Pick out log probs for each step according to next_beam_id
    next_step_probs = beam_search_decoder._tensor_gather_helper(
        gather_indices=word_indices,
        gather_from=step_log_probs,
        batch_size=batch_size,
        range_size=beam_width * vocab_size,
        gather_shape=[-1],
        name="next_step_probs")

    # Note: just doing the following
    #   math_ops.to_int32(word_indices % vocab_size,
    #       name="next_beam_word_ids")
    # would be a lot cleaner but for reasons unclear, that hides the results
    # of the op which prevents capturing it with tfdbg debug ops.
    raw_next_word_ids = math_ops.mod(
        word_indices, vocab_size, name="next_beam_word_ids")
    next_word_ids = math_ops.to_int32(raw_next_word_ids)
    next_beam_ids = math_ops.to_int32(
        word_indices / vocab_size, name="next_beam_parent_ids")

    def _gather_from_parents(per_beam_values):
        """Reorders a per-beam tensor so each slot holds its parent's value."""
        return beam_search_decoder._tensor_gather_helper(
            gather_indices=next_beam_ids,
            gather_from=per_beam_values,
            batch_size=batch_size,
            range_size=beam_width,
            gather_shape=[-1])

    # Append new ids to current predictions
    previously_finished = _gather_from_parents(previously_finished)
    next_finished = math_ops.logical_or(
        previously_finished,
        math_ops.equal(next_word_ids, end_token),
        name="next_beam_finished")

    # Calculate the length of the next predictions.
    # 1. Finished beams remain unchanged.
    # 2. Beams that are now finished (EOS predicted) have their length
    #    increased by 1.
    # 3. Beams that are not yet finished have their length increased by 1.
    lengths_to_add = math_ops.to_int64(
        math_ops.logical_not(previously_finished))
    next_prediction_len = _gather_from_parents(beam_state.lengths)
    next_prediction_len += lengths_to_add

    # Pick prefix_mass according to the next_beam_id, then add the weight
    # of the token each surviving beam just emitted.
    next_prefix_mass = _gather_from_parents(beam_state.prefix_mass)
    next_prefix_mass = next_prefix_mass + tf.gather(
        aa_weight_table, next_word_ids)

    # Pick out the cell_states according to the next_beam_ids. We use a
    # different gather_shape here because the cell_state tensors, i.e.
    # the tensors that would be gathered from, all have dimension
    # greater than two and we need to preserve those dimensions.
    # pylint: disable=g-long-lambda
    next_cell_state = nest.map_structure(
        lambda gather_from: beam_search_decoder._maybe_tensor_gather_helper(
            gather_indices=next_beam_ids,
            gather_from=gather_from,
            batch_size=batch_size,
            range_size=beam_width,
            gather_shape=[batch_size * beam_width, -1]),
        next_cell_state)
    # pylint: enable=g-long-lambda

    next_state = IonBeamSearchDecoderState(
        cell_state=next_cell_state,
        log_probs=next_beam_probs,
        lengths=next_prediction_len,
        prefix_mass=next_prefix_mass,
        finished=next_finished)

    output = BeamSearchDecoderOutput(
        scores=next_beam_scores,
        predicted_ids=next_word_ids,
        parent_ids=next_beam_ids,
        step_log_probs=next_step_probs)

    return output, next_state