def _beam_decode(self,
                 start_tokens: torch.LongTensor,
                 end_token: int,
                 embedding_fn: Callable[
                     [torch.LongTensor, torch.LongTensor], torch.Tensor],
                 decode_length: int = 256,
                 beam_width: int = 5,
                 length_penalty: float = 0.6) \
        -> Tuple[torch.Tensor, torch.Tensor]:

    def _symbols_to_logits_fn(ids, cache):
        batch_size = ids.size(0)
        step = ids.size(-1) - 1
        times = ids.new_full((batch_size,), step)
        # Only the most recently generated token is embedded and fed to the
        # network; earlier steps are carried in `cache`.
        inputs = embedding_fn(ids[:, -1], times)
        return self._inputs_to_outputs(inputs, cache)

    assert self._vocab_size is not None

    outputs, log_prob = beam_search(
        _symbols_to_logits_fn, start_tokens, beam_width, decode_length,
        self._vocab_size, length_penalty, states=self._state_cache,
        eos_id=end_token)

    # Ignores <BOS>
    outputs = outputs[:, :, 1:]
    # shape = [batch_size, seq_length, beam_width]
    outputs = outputs.permute(0, 2, 1)
    return outputs, log_prob
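# Illustrative sketch (not part of the original code): the `embedding_fn`
# passed to `_beam_decode` above is assumed to map the last generated token
# ids and their time steps to input vectors, e.g. a token embedding plus a
# position embedding. The embedder names and sizes below are made up for the
# example.
import torch
from torch import nn

_word_embedder = nn.Embedding(num_embeddings=1000, embedding_dim=512)
_pos_embedder = nn.Embedding(num_embeddings=1024, embedding_dim=512)

def example_embedding_fn(ids: torch.LongTensor,
                         times: torch.LongTensor) -> torch.Tensor:
    # `ids` and `times` both have shape [batch_size * beam_width].
    return _word_embedder(ids) + _pos_embedder(times)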
def _beam_decode(self,
                 embedding_fn,
                 start_tokens,
                 end_token,
                 memory,
                 memory_attention_bias,
                 decode_length=256,
                 beam_width=5,
                 alpha=0.6):
    # Build the per-step decoding cache from the encoder memory.
    cache = self._init_cache(memory, memory_attention_bias)
    symbols_to_logits_fn = self._symbols_to_logits_fn(
        embedding_fn, max_length=decode_length + 1)

    outputs, log_prob = beam_search.beam_search(
        symbols_to_logits_fn,
        start_tokens,
        beam_width,
        decode_length,
        self._vocab_size,
        alpha,
        states=cache,
        eos_id=end_token)

    # Ignores <BOS>
    outputs = outputs[:, :, 1:]
    # shape = [batch_size, seq_length, beam_width]
    outputs = tf.transpose(outputs, [0, 2, 1])
    return (outputs, log_prob)
def testStates(self):
    batch_size = 1
    beam_size = 1
    vocab_size = 2
    decode_length = 3

    initial_ids = torch.tensor([0] * batch_size, dtype=torch.int64)

    probabilities = torch.tensor([[[0.7, 0.3]], [[0.4, 0.6]], [[0.5, 0.5]]])

    expected_states = torch.tensor([[[0.]], [[1.]]])

    def symbols_to_logits(ids, states):
        pos = ids.shape[1]
        logits = torch.log(probabilities[pos - 1, :]).type(torch.float)
        states["state"] += 1
        return logits, states

    states = {
        "state": torch.zeros(batch_size, 1),
    }

    final_ids, _ = beam_search.beam_search(
        symbols_to_logits_fn=symbols_to_logits,
        initial_ids=initial_ids,
        beam_size=beam_size,
        decode_length=decode_length,
        vocab_size=vocab_size,
        alpha=0.0,
        eos_id=1,
        states=states)
def testStateBeamTwo(self):
    batch_size = 1
    beam_size = 2
    vocab_size = 3
    decode_length = 3

    initial_ids = torch.tensor([0] * batch_size, dtype=torch.int64)

    probabilities = torch.tensor([[[0.1, 0.1, 0.8], [0.1, 0.1, 0.8]],
                                  [[0.4, 0.5, 0.1], [0.2, 0.4, 0.4]],
                                  [[0.05, 0.9, 0.05], [0.4, 0.4, 0.2]]])

    # The top beam is always selected, so we should see the top beam's state
    # at each position, which is the one that gets 3 added to it each step.
    expected_states = torch.tensor([[[0.], [0.]], [[3.], [3.]], [[6.], [6.]]])

    def symbols_to_logits(ids, states):
        pos = ids.shape[1]
        logits = torch.log(probabilities[pos - 1, :]).type(torch.float)
        states["state"] += torch.tensor([[3.], [7.]])
        return logits, states

    states = {"state": torch.zeros(batch_size, 1)}

    final_ids, _ = beam_search.beam_search(
        symbols_to_logits_fn=symbols_to_logits,
        initial_ids=initial_ids,
        beam_size=beam_size,
        decode_length=decode_length,
        vocab_size=vocab_size,
        alpha=0.0,
        eos_id=1,
        states=states)
def testGreedyWithCornerCase(self):
    batch_size = 1
    beam_size = 1
    vocab_size = 3
    decode_length = 2

    initial_ids = torch.tensor([0] * batch_size, dtype=torch.int64)

    probabilities = torch.tensor([[0.2, 0.1, 0.7], [0.4, 0.1, 0.5]])

    def symbols_to_logits(ids):
        pos = ids.shape[1]
        logits = torch.log(probabilities[pos - 1, :]).type(torch.float)
        return logits

    final_ids, final_probs = beam_search.beam_search(
        symbols_to_logits_fn=symbols_to_logits,
        initial_ids=initial_ids,
        beam_size=beam_size,
        decode_length=decode_length,
        vocab_size=vocab_size,
        alpha=0.0,
        eos_id=1)

    exp_ids = [[[0, 2, 2]]]
    exp_probs = [[0.7 * 0.5]]
    self.assertEqual(final_ids.tolist(), exp_ids)
    self.assertAlmostEqual(
        np.exp(final_probs).tolist()[0][0], exp_probs[0][0])
def beam_decode(self,
                embedding_fn,
                start_tokens,
                EOS,
                memory,
                encoder_decoder_attention_bias,
                segment_ids,
                offsets,
                decode_length=256,
                beam_width=5):
    cache = self._init_cache(memory, encoder_decoder_attention_bias)
    symbols_to_logits_fn = self._symbols_to_logits_fn(
        embedding_fn,
        max_length=decode_length + 1,
        segment_ids=self._expand_to_beam_width(segment_ids, beam_width),
        offsets=self._expand_to_beam_width(offsets, beam_width))

    outputs, log_probs = beam_search.beam_search(
        symbols_to_logits_fn,
        start_tokens,
        beam_width,
        decode_length,
        self._vocab_size,
        self._hparams.alpha,
        states=cache,
        eos_id=EOS)

    outputs = outputs[:, :, 1:]  # ignore <BOS>
    return (outputs, log_probs)
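# Sketch of what `_expand_to_beam_width` above is assumed to do, mirroring the
# `_prepare_beam_search` helper in the RNN beam_decode snippet further below:
# tile each batch entry `beam_width` times and merge the batch and beam
# dimensions, so per-example side inputs such as segment_ids and offsets line
# up with the flattened [batch_size * beam_width, ...] beam layout.
import torch

def expand_to_beam_width(x: torch.Tensor, beam_width: int) -> torch.Tensor:
    x = x.unsqueeze(1).repeat(1, beam_width, *([1] * (x.dim() - 1)))
    return x.view(-1, *x.size()[2:])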
def testGreedyBatchOne(self):
    batch_size = 1
    beam_size = 1
    vocab_size = 2
    decode_length = 3

    initial_ids = torch.tensor([0] * batch_size, dtype=torch.int64)

    # Test that beam search finds the most probable sequence.
    # These probabilities represent the following search tree:
    #
    #               G0 (0)
    #                / \
    #               /   \
    #              /     \
    #             /       \
    #         0(0.7)     1(0.3)
    #           / \
    #          /   \
    #         /     \
    #     0(0.4)   1(0.6)
    #       /\
    #      /  \
    #     /    \
    #  0(0.5)  1(0.5)
    #
    # and the following decoding probabilities:
    # 0000 - 0.7 * 0.4 * 0.5
    # 0001 - 0.7 * 0.4 * 0.5
    # 001  - 0.7 * 0.6 (best)
    # 01   - 0.3
    #
    # 001 is the most likely sequence under these probabilities.
    probabilities = torch.tensor([[[0.7, 0.3]], [[0.4, 0.6]], [[0.5, 0.5]]])

    def symbols_to_logits(ids):
        pos = ids.shape[1]
        logits = torch.log(probabilities[pos - 1, :]).type(torch.float)
        return logits

    final_ids, final_probs = beam_search.beam_search(
        symbols_to_logits_fn=symbols_to_logits,
        initial_ids=initial_ids,
        beam_size=beam_size,
        decode_length=decode_length,
        vocab_size=vocab_size,
        alpha=0.0,
        eos_id=1)

    exp_ids = [[[0, 0, 1]]]
    exp_probs = [[0.7 * 0.6]]
    self.assertEqual(final_ids.tolist(), exp_ids)
    self.assertAlmostEqual(
        np.exp(final_probs).tolist()[0][0], exp_probs[0][0])
def testNotGreedyBatchTwoBeamTwoWithAlpha(self):
    batch_size = 2
    beam_size = 2
    vocab_size = 3
    decode_length = 3

    initial_ids = torch.tensor([0] * batch_size, dtype=torch.int64)

    # Probabilities for position * batch * beam * vocab.
    # Probabilities have been set such that with alpha = 3.5, the less
    # probable but longer sequence will have a better score than the shorter
    # sequence with the higher log prob in batch 1, and the order will be
    # reversed in batch 2. That is, in batch 2 the shorter sequence will
    # still have a higher score in spite of the length penalty.
    probabilities = torch.tensor([[[[0.1, 0.1, 0.8], [0.1, 0.1, 0.8]],
                                   [[0.1, 0.1, 0.8], [0.1, 0.1, 0.8]]],
                                  [[[0.4, 0.5, 0.1], [0.2, 0.4, 0.4]],
                                   [[0.3, 0.6, 0.1], [0.2, 0.4, 0.4]]],
                                  [[[0.05, 0.9, 0.05], [0.4, 0.4, 0.2]],
                                   [[0.05, 0.9, 0.05], [0.4, 0.4, 0.2]]]])

    def symbols_to_logits(ids):
        pos = ids.shape[1]
        logits = torch.log(probabilities[pos - 1, :]).type(torch.float)
        return logits

    final_ids, final_probs = beam_search.beam_search(
        symbols_to_logits_fn=symbols_to_logits,
        initial_ids=initial_ids,
        beam_size=beam_size,
        decode_length=decode_length,
        vocab_size=vocab_size,
        alpha=3.5,
        eos_id=1)

    exp_ids = [[[0, 2, 0, 1], [0, 2, 1, 0]],
               [[0, 2, 1, 0], [0, 2, 0, 1]]]
    exp_probs = [[
        np.log(0.8 * 0.4 * 0.9) / (8. / 6.) ** 3.5,
        np.log(0.8 * 0.5) / (7. / 6.) ** 3.5
    ], [
        np.log(0.8 * 0.6) / (7. / 6.) ** 3.5,
        np.log(0.8 * 0.3 * 0.9) / (8. / 6.) ** 3.5
    ]]
    self.assertEqual(final_ids.tolist(), exp_ids)
    for i in range(2):
        for j in range(2):
            self.assertAlmostEqual(final_probs.tolist()[i][j],
                                   exp_probs[i][j])
def beam_decode(self,
                start_tokens: torch.LongTensor,
                end_token: int,
                initial_state: AttentionWrapperState,
                decode_length: int = 256,
                beam_width: int = 5,
                length_penalty: float = 0.6) \
        -> Tuple[torch.LongTensor, torch.Tensor]:

    def _prepare_beam_search(x):
        # Tile each batch entry `beam_width` times, then merge the batch and
        # beam dimensions.
        x = x.unsqueeze(1).repeat(1, beam_width, *([1] * (x.dim() - 1)))
        x = x.view(-1, *x.size()[2:])
        return x

    memory_beam_search = _prepare_beam_search(self.memory)
    memory_sequence_length_beam_search = _prepare_beam_search(
        self.memory_sequence_length)

    def _symbols_to_logits_fn(ids, state):
        batch_size = ids.size(0)
        step = ids.size(-1) - 1
        times = ids.new_full((batch_size,), step)
        inputs = self.embed_tokens(ids[:, -1], times)
        wrapper_outputs, wrapper_state = self._cell(
            inputs, state, memory_beam_search,
            memory_sequence_length_beam_search)
        logits = self._output_layer(wrapper_outputs)
        return logits, wrapper_state

    assert self._vocab_size is not None

    outputs, log_prob = beam_search(
        symbols_to_logits_fn=_symbols_to_logits_fn,
        initial_ids=start_tokens,
        beam_size=beam_width,
        decode_length=decode_length,
        vocab_size=self._vocab_size,
        alpha=length_penalty,
        states=initial_state,
        eos_id=end_token)

    # Ignores <BOS>
    outputs = outputs[:, :, 1:]
    # shape = [batch_size, seq_length, beam_width]
    outputs = outputs.permute((0, 2, 1))
    return outputs, log_prob
def _beam_decode(self, start_tokens, end_token, decode_length, beam_width,
                 length_penalty):

    def _symbols_to_logits_fn(ids, step, cache):
        return self._input_ids_to_outputs(ids[:, -1], step, cache)

    outputs, log_prob = beam_search.beam_search(
        _symbols_to_logits_fn,
        start_tokens,
        beam_width,
        decode_length,
        self._vocab_size,
        length_penalty,
        eos_id=end_token,
        states=self._cache)

    # Ignores <BOS>
    outputs = outputs[:, :, 1:]
    # shape = [batch_size, seq_length, beam_width]
    outputs = tf.transpose(outputs, [0, 2, 1])
    return (outputs, log_prob)
def testNotGreedyBeamTwoWithAlpha(self):
    batch_size = 1
    beam_size = 2
    vocab_size = 3
    decode_length = 3

    initial_ids = torch.tensor([0] * batch_size, dtype=torch.int64)

    # Probabilities for position * batch * beam * vocab.
    # Probabilities have been set such that with alpha = 3.5, the less
    # probable but longer sequence will have a better score than the shorter
    # sequence with the higher log prob.
    probabilities = torch.tensor([[[0.1, 0.1, 0.8], [0.1, 0.1, 0.8]],
                                  [[0.4, 0.5, 0.1], [0.2, 0.4, 0.4]],
                                  [[0.05, 0.9, 0.05], [0.4, 0.4, 0.2]]])

    def symbols_to_logits(ids):
        pos = ids.shape[1]
        logits = torch.log(probabilities[pos - 1, :]).type(torch.float)
        return logits

    final_ids, final_probs = beam_search.beam_search(
        symbols_to_logits_fn=symbols_to_logits,
        initial_ids=initial_ids,
        beam_size=beam_size,
        decode_length=decode_length,
        vocab_size=vocab_size,
        alpha=3.5,
        eos_id=1)

    exp_ids = [[[0, 2, 0, 1], [0, 2, 1, 0]]]
    exp_probs = [[
        np.log(0.8 * 0.4 * 0.9) / (8. / 6.) ** 3.5,
        np.log(0.8 * 0.5) / (7. / 6.) ** 3.5
    ]]
    self.assertEqual(final_ids.tolist(), exp_ids)
    self.assertAlmostEqual(final_probs.tolist()[0][0], exp_probs[0][0])
    self.assertAlmostEqual(final_probs.tolist()[0][1], exp_probs[0][1])
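# Worked check (not an additional test) of how the expected scores above are
# formed: the sequence log-probability is divided by a GNMT-style length
# penalty ((5 + length) / 6) ** alpha, which is where the (8./6.)**3.5 and
# (7./6.)**3.5 terms in exp_probs come from.
import numpy as np

def length_penalty(length, alpha):
    return ((5.0 + length) / 6.0) ** alpha

# [0, 2, 0, 1]: three tokens generated after <BOS>, so length = 3.
long_score = np.log(0.8 * 0.4 * 0.9) / length_penalty(3, 3.5)
# [0, 2, 1]: two tokens generated after <BOS>, so length = 2.
short_score = np.log(0.8 * 0.5) / length_penalty(2, 3.5)
assert long_score > short_score  # with alpha = 3.5 the longer sequence wins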
def testNotGreedyBeamTwoWithoutStopEarly(self):
    batch_size = 1
    beam_size = 2
    vocab_size = 3
    decode_length = 3

    initial_ids = torch.tensor([0] * batch_size, dtype=torch.int64)

    probabilities = torch.tensor([[[0.1, 0.1, 0.8], [0.1, 0.1, 0.8]],
                                  [[0.4, 0.5, 0.1], [0.2, 0.4, 0.4]],
                                  [[0.05, 0.9, 0.05], [0.4, 0.4, 0.2]]])

    def symbols_to_logits(ids):
        pos = ids.shape[1]
        logits = torch.log(probabilities[pos - 1, :]).type(torch.float)
        return logits

    final_ids, final_probs = beam_search.beam_search(
        symbols_to_logits_fn=symbols_to_logits,
        initial_ids=initial_ids,
        beam_size=beam_size,
        decode_length=decode_length,
        vocab_size=vocab_size,
        alpha=0.0,
        eos_id=1,
        stop_early=False)

    # Given stop_early = False, the algorithm will return all the beams,
    # so we can test all of them here.
    exp_ids = [[[0, 2, 1, 0], [0, 2, 0, 1]]]
    exp_probs = [[0.8 * 0.5, 0.8 * 0.4 * 0.9]]
    self.assertEqual(final_ids.tolist(), exp_ids)
    self.assertAlmostEqual(
        np.exp(final_probs).tolist()[0][0], exp_probs[0][0])
    self.assertAlmostEqual(
        np.exp(final_probs).tolist()[0][1], exp_probs[0][1])
def testNotGreedyBeamTwoWithStopEarly(self):
    batch_size = 1
    beam_size = 2
    vocab_size = 3
    decode_length = 10

    initial_ids = torch.tensor([0] * batch_size, dtype=torch.int64)

    probabilities = torch.tensor([[[0.1, 0.1, 0.8], [0.1, 0.1, 0.8]],
                                  [[0.4, 0.5, 0.1], [0.2, 0.4, 0.4]],
                                  [[0.05, 0.9, 0.05], [0.4, 0.4, 0.2]]])

    def symbols_to_logits(ids):
        pos = ids.shape[1]
        logits = torch.log(probabilities[pos - 1, :]).type(torch.float)
        return logits

    final_ids, final_probs = beam_search.beam_search(
        symbols_to_logits_fn=symbols_to_logits,
        initial_ids=initial_ids,
        beam_size=beam_size,
        decode_length=decode_length,
        vocab_size=vocab_size,
        alpha=0.0,
        eos_id=1,
        stop_early=True)  # default value, but just to make this explicit

    # Given stop_early = True, the only 'assurance' is w.r.t. the first beam
    # (i.e., other beams may not even be completed),
    # so we check only the first beam.
    first_beam = final_ids[:, 0]
    first_probs = final_probs[:, 0]

    exp_ids = [[0, 2, 1]]
    exp_probs = [0.8 * 0.5]
    self.assertEqual(first_beam.tolist(), exp_ids)
    self.assertAlmostEqual(np.exp(first_probs).tolist()[0], exp_probs[0])
def _beam_decode(self,
                 start_tokens,
                 end_token,
                 decode_length=256,
                 beam_width=5,
                 alpha=0.6):

    def _symbols_to_logits_fn(ids, step, cache):
        return self._inputs_to_outputs(
            self._prepare_tokens_to_embeds(ids[:, -1]), step, cache)

    outputs, log_prob = beam_search.beam_search(
        _symbols_to_logits_fn,
        start_tokens,
        beam_width,
        decode_length,
        self._vocab_size,
        alpha,
        states=self._cache,
        eos_id=end_token)

    # Ignores <BOS>
    outputs = outputs[:, :, 1:]
    # shape = [batch_size, seq_length, beam_width]
    outputs = tf.transpose(outputs, [0, 2, 1])
    return (outputs, log_prob)
def testShapes(self):
    batch_size = 2
    beam_size = 3
    vocab_size = 4
    decode_length = 10

    initial_ids = torch.tensor([0, 0], dtype=torch.int64)

    def symbols_to_logits(_):
        # Just return random logits
        return torch.rand(batch_size * beam_size, vocab_size)

    final_ids, final_probs = beam_search.beam_search(
        symbols_to_logits_fn=symbols_to_logits,
        initial_ids=initial_ids,
        beam_size=beam_size,
        decode_length=decode_length,
        vocab_size=vocab_size,
        alpha=0.0,
        eos_id=1)

    self.assertEqual(final_ids.shape[1], beam_size)
    self.assertEqual(final_probs.shape, torch.Size([batch_size, beam_size]))
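# Minimal usage sketch distilled from the tests above (the import path is an
# assumption; point it at wherever beam_search lives in this codebase).
# symbols_to_logits_fn receives the ids decoded so far, presumably of shape
# [batch_size * beam_size, current_length], and returns logits over the
# vocabulary; beam_search returns ids of shape
# [batch_size, beam_size, <= decode_length + 1] (including <BOS>) and
# per-beam scores of shape [batch_size, beam_size].
import torch
from texar.torch.utils import beam_search  # assumed import path

def symbols_to_logits(ids):
    # A stand-in callback that ignores the prefix and returns random logits.
    return torch.rand(ids.size(0), 4)

final_ids, final_scores = beam_search.beam_search(
    symbols_to_logits_fn=symbols_to_logits,
    initial_ids=torch.zeros(2, dtype=torch.int64),
    beam_size=3,
    decode_length=5,
    vocab_size=4,
    alpha=0.0,
    eos_id=1)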