Example #1
    def _beam_decode(self, start_tokens: torch.LongTensor, end_token: int,
                     embedding_fn: Callable[
                         [torch.LongTensor, torch.LongTensor], torch.Tensor],
                     decode_length: int = 256, beam_width: int = 5,
                     length_penalty: float = 0.6) \
            -> Tuple[torch.Tensor, torch.Tensor]:
        def _symbols_to_logits_fn(ids, cache):
            batch_size = ids.size(0)
            step = ids.size(-1) - 1
            times = ids.new_full((batch_size, ), step)
            inputs = embedding_fn(ids[:, -1], times)
            return self._inputs_to_outputs(inputs, cache)

        assert self._vocab_size is not None

        outputs, log_prob = beam_search(_symbols_to_logits_fn,
                                        start_tokens,
                                        beam_width,
                                        decode_length,
                                        self._vocab_size,
                                        length_penalty,
                                        states=self._state_cache,
                                        eos_id=end_token)

        # Ignores <BOS>
        outputs = outputs[:, :, 1:]
        # shape = [batch_size, seq_length, beam_width]
        outputs = outputs.permute(0, 2, 1)
        return outputs, log_prob
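
The reshaping at the end of this example is easy to misread, so below is a minimal sketch, using a dummy tensor in place of real beam_search output, of how dropping the <BOS> column and permuting turns the [batch_size, beam_width, seq_length] ids into the [batch_size, seq_length, beam_width] layout noted in the comment. The tensor contents are made up for illustration.

    import torch

    # Dummy stand-in for beam_search output ids:
    # shape [batch_size, beam_width, seq_length], with column 0 holding <BOS>.
    batch_size, beam_width, seq_length = 2, 3, 5
    outputs = torch.arange(batch_size * beam_width * seq_length).view(
        batch_size, beam_width, seq_length)

    outputs = outputs[:, :, 1:]         # drop the <BOS> column
    outputs = outputs.permute(0, 2, 1)  # -> [batch_size, seq_length - 1, beam_width]
    print(outputs.shape)                # torch.Size([2, 4, 3])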
Example #2
    def _beam_decode(self,
                     embedding_fn,
                     start_tokens,
                     end_token,
                     memory,
                     memory_attention_bias,
                     decode_length=256,
                     beam_width=5,
                     alpha=0.6):
        cache = self._init_cache(memory, memory_attention_bias)
        symbols_to_logits_fn = self._symbols_to_logits_fn(
            embedding_fn, max_length=decode_length + 1)
        outputs, log_prob = beam_search.beam_search(
            symbols_to_logits_fn,
            start_tokens,
            beam_width,
            decode_length,
            self._vocab_size,
            alpha,
            states=cache,
            eos_id=end_token)

        # Ignores <BOS>
        outputs = outputs[:, :, 1:]
        # shape = [batch_size, seq_length, beam_width]
        outputs = tf.transpose(outputs, [0, 2, 1])
        return (outputs, log_prob)
Example #3
    def testStates(self):
        batch_size = 1
        beam_size = 1
        vocab_size = 2
        decode_length = 3

        initial_ids = torch.tensor([0] * batch_size, dtype=torch.int64)
        probabilities = torch.tensor([[[0.7, 0.3]], [[0.4, 0.6]], [[0.5,
                                                                    0.5]]])

        expected_states = torch.tensor([[[0.]], [[1.]]])

        def symbols_to_logits(ids, states):
            pos = ids.shape[1]
            logits = torch.log(probabilities[pos - 1, :]).type(torch.float)
            states["state"] += 1
            return logits, states

        states = {
            "state": torch.zeros(batch_size, 1),
        }

        final_ids, _ = beam_search.beam_search(
            symbols_to_logits_fn=symbols_to_logits,
            initial_ids=initial_ids,
            beam_size=beam_size,
            decode_length=decode_length,
            vocab_size=vocab_size,
            alpha=0.0,
            eos_id=1,
            states=states)
Example #4
    def testStateBeamTwo(self):
        batch_size = 1
        beam_size = 2
        vocab_size = 3
        decode_length = 3

        initial_ids = torch.tensor([0] * batch_size, dtype=torch.int64)
        probabilities = torch.tensor([[[0.1, 0.1, 0.8], [0.1, 0.1, 0.8]],
                                      [[0.4, 0.5, 0.1], [0.2, 0.4, 0.4]],
                                      [[0.05, 0.9, 0.05], [0.4, 0.4, 0.2]]])

        # The top beam is always selected, so we should see the top beam's
        # state at each position, which is the one that gets 3 added to it
        # each step.
        expected_states = torch.tensor([[[0.], [0.]], [[3.], [3.]], [[6.],
                                                                     [6.]]])

        def symbols_to_logits(ids, states):
            pos = ids.shape[1]
            logits = torch.log(probabilities[pos - 1, :]).type(torch.float)
            states["state"] += torch.tensor([[3.], [7.]])
            return logits, states

        states = {"state": torch.zeros(batch_size, 1)}

        final_ids, _ = beam_search.beam_search(
            symbols_to_logits_fn=symbols_to_logits,
            initial_ids=initial_ids,
            beam_size=beam_size,
            decode_length=decode_length,
            vocab_size=vocab_size,
            alpha=0.0,
            eos_id=1,
            states=states)
Example #5
    def testGreedyWithCornerCase(self):
        batch_size = 1
        beam_size = 1
        vocab_size = 3
        decode_length = 2

        initial_ids = torch.tensor([0] * batch_size, dtype=torch.int64)
        probabilities = torch.tensor([[0.2, 0.1, 0.7], [0.4, 0.1, 0.5]])

        def symbols_to_logits(ids):
            pos = ids.shape[1]
            logits = torch.log(probabilities[pos - 1, :]).type(torch.float)
            return logits

        final_ids, final_probs = beam_search.beam_search(
            symbols_to_logits_fn=symbols_to_logits,
            initial_ids=initial_ids,
            beam_size=beam_size,
            decode_length=decode_length,
            vocab_size=vocab_size,
            alpha=0.0,
            eos_id=1)

        exp_ids = [[[0, 2, 2]]]
        exp_probs = [[0.7 * 0.5]]

        self.assertEqual(final_ids.tolist(), exp_ids)
        self.assertAlmostEqual(
            np.exp(final_probs).tolist()[0][0], exp_probs[0][0])
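
Since alpha=0.0 disables length normalization, the returned log-probability of the top beam should simply be the sum of the per-step log-probabilities of the chosen tokens, which is what the final assertion relies on. A quick arithmetic check of the expected value:

    import numpy as np

    # Token 2 is chosen at both steps: p = 0.7 at step 1 and p = 0.5 at step 2.
    log_prob = np.log(0.7) + np.log(0.5)
    print(np.isclose(np.exp(log_prob), 0.7 * 0.5))  # True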
Example #6
    def beam_decode(
        self,
        embedding_fn,
        start_tokens,
        EOS,
        memory,
        encoder_decoder_attention_bias,
        segment_ids,
        offsets,
        decode_length=256,
        beam_width=5,
    ):
        cache = self._init_cache(memory, encoder_decoder_attention_bias)
        symbols_to_logits_fn = self._symbols_to_logits_fn(
            embedding_fn,
            max_length=decode_length + 1,
            segment_ids=self._expand_to_beam_width(segment_ids, beam_width),
            offsets=self._expand_to_beam_width(offsets, beam_width))
        outputs, log_probs = beam_search.beam_search(symbols_to_logits_fn,
                                                     start_tokens,
                                                     beam_width,
                                                     decode_length,
                                                     self._vocab_size,
                                                     self._hparams.alpha,
                                                     states=cache,
                                                     eos_id=EOS)

        outputs = outputs[:, :, 1:]  # ignore <BOS>
        return (outputs, log_probs)
Example #7
    def testGreedyBatchOne(self):
        batch_size = 1
        beam_size = 1
        vocab_size = 2
        decode_length = 3

        initial_ids = torch.tensor([0] * batch_size, dtype=torch.int64)

        # Test that beam search finds the most probable sequence.
        # These probabilities represent the following search
        #
        #               G0 (0)
        #                  / \
        #                /     \
        #              /         \
        #            /             \
        #         0(0.7)          1(0.3)
        #           / \
        #          /   \
        #         /     \
        #     0(0.4) 1(0.6)
        #        /\
        #       /  \
        #      /    \
        #    0(0.5) 1(0.5)
        # and the following decoding probabilities
        # 0000 - 0.7 * 0.4 * 0.5
        # 0001 - 0.7 * 0.4 * 0.5
        # 001  - 0.7 * 0.6 (Best)
        # 01   - 0.3
        #
        # 001 is the most likely sequence under these probabilities.
        probabilities = torch.tensor([[[0.7, 0.3]], [[0.4, 0.6]], [[0.5,
                                                                    0.5]]])

        def symbols_to_logits(ids):
            pos = ids.shape[1]
            logits = torch.log(probabilities[pos - 1, :]).type(torch.float)
            return logits

        final_ids, final_probs = beam_search.beam_search(
            symbols_to_logits_fn=symbols_to_logits,
            initial_ids=initial_ids,
            beam_size=beam_size,
            decode_length=decode_length,
            vocab_size=vocab_size,
            alpha=0.0,
            eos_id=1)

        exp_ids = [[[0, 0, 1]]]
        exp_probs = [[0.7 * 0.6]]

        self.assertEqual(final_ids.tolist(), exp_ids)
        self.assertAlmostEqual(
            np.exp(final_probs).tolist()[0][0], exp_probs[0][0])
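
The search tree in the comment can be verified by hand. The sketch below is purely illustrative: it enumerates the candidate sequences (labelled as in the comment, including the initial <BOS> id 0) with their probabilities under the probabilities tensor; with alpha=0.0 there is no length penalty, so 001 wins, matching exp_ids.

    # Candidate sequences (including the initial <BOS> id 0) and their probabilities.
    candidates = {
        "0000": 0.7 * 0.4 * 0.5,
        "0001": 0.7 * 0.4 * 0.5,
        "001": 0.7 * 0.6,   # best
        "01": 0.3,
    }
    best = max(candidates, key=candidates.get)
    print(best, candidates[best])  # 001 0.42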
Example #8
    def testNotGreedyBatchTwoBeamTwoWithAlpha(self):
        batch_size = 2
        beam_size = 2
        vocab_size = 3
        decode_length = 3

        initial_ids = torch.tensor([0] * batch_size, dtype=torch.int64)
        # Probabilities for position * batch * beam * vocab
        # Probabilities have been set such that with alpha = 3.5, the less
        # probable but longer sequence has a better score than the shorter
        # sequence with the higher log prob in batch 1, while the order is
        # reversed in batch 2; that is, the shorter sequence still has a
        # higher score in spite of the length penalty.
        probabilities = torch.tensor([[[[0.1, 0.1, 0.8], [0.1, 0.1, 0.8]],
                                       [[0.1, 0.1, 0.8], [0.1, 0.1, 0.8]]],
                                      [[[0.4, 0.5, 0.1], [0.2, 0.4, 0.4]],
                                       [[0.3, 0.6, 0.1], [0.2, 0.4, 0.4]]],
                                      [[[0.05, 0.9, 0.05], [0.4, 0.4, 0.2]],
                                       [[0.05, 0.9, 0.05], [0.4, 0.4, 0.2]]]])

        def symbols_to_logits(ids):
            pos = ids.shape[1]
            logits = torch.log(probabilities[pos - 1, :]).type(torch.float)
            return logits

        final_ids, final_probs = beam_search.beam_search(
            symbols_to_logits_fn=symbols_to_logits,
            initial_ids=initial_ids,
            beam_size=beam_size,
            decode_length=decode_length,
            vocab_size=vocab_size,
            alpha=3.5,
            eos_id=1)

        exp_ids = [[[0, 2, 0, 1], [0, 2, 1, 0]], [[0, 2, 1, 0], [0, 2, 0, 1]]]
        exp_probs = [[
            np.log(0.8 * 0.4 * 0.9) / (8. / 6.)**3.5,
            np.log(0.8 * 0.5) / (7. / 6.)**3.5
        ],
                     [
                         np.log(0.8 * 0.6) / (7. / 6.)**3.5,
                         np.log(0.8 * 0.3 * 0.9) / (8. / 6.)**3.5
                     ]]

        self.assertEqual(final_ids.tolist(), exp_ids)
        for i in range(2):
            for j in range(2):
                self.assertAlmostEqual(final_probs.tolist()[i][j],
                                       exp_probs[i][j])
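
The expected scores above are consistent with a GNMT-style length penalty of the form ((5 + length) / 6) ** alpha applied to the summed log-probabilities, with length counted in decoded tokens after <BOS>; judging from the (8. / 6.) and (7. / 6.) factors, that appears to be the normalization this beam_search uses. A small sketch reproducing the batch-1 comparison (the score helper is just for illustration):

    import numpy as np

    alpha = 3.5

    def score(log_prob, length):
        # Illustrative helper mirroring the penalty the expected values imply.
        return log_prob / ((5.0 + length) / 6.0) ** alpha

    long_score = score(np.log(0.8 * 0.4 * 0.9), length=3)   # ids [0, 2, 0, 1]
    short_score = score(np.log(0.8 * 0.5), length=2)        # ids [0, 2, 1]
    print(long_score > short_score)  # True: the longer sequence wins in batch 1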
Example #9
    def beam_decode(self,
                    start_tokens: torch.LongTensor,
                    end_token: int,
                    initial_state: AttentionWrapperState,
                    decode_length: int = 256,
                    beam_width: int = 5,
                    length_penalty: float = 0.6) \
            -> Tuple[torch.LongTensor, torch.Tensor]:
        def _prepare_beam_search(x):
            x = x.unsqueeze(1).repeat(1, beam_width, *([1] * (x.dim() - 1)))
            x = x.view(-1, *x.size()[2:])
            return x

        memory_beam_search = _prepare_beam_search(self.memory)
        memory_sequence_length_beam_search = _prepare_beam_search(
            self.memory_sequence_length)

        def _symbols_to_logits_fn(ids, state):
            batch_size = ids.size(0)
            step = ids.size(-1) - 1
            times = ids.new_full((batch_size, ), step)
            inputs = self.embed_tokens(ids[:, -1], times)
            wrapper_outputs, wrapper_state = self._cell(
                inputs, state, memory_beam_search,
                memory_sequence_length_beam_search)
            logits = self._output_layer(wrapper_outputs)
            return logits, wrapper_state

        assert self._vocab_size is not None
        outputs, log_prob = beam_search(
            symbols_to_logits_fn=_symbols_to_logits_fn,
            initial_ids=start_tokens,
            beam_size=beam_width,
            decode_length=decode_length,
            vocab_size=self._vocab_size,
            alpha=length_penalty,
            states=initial_state,
            eos_id=end_token)

        # Ignores <BOS>
        outputs = outputs[:, :, 1:]
        # shape = [batch_size, seq_length, beam_width]
        outputs = outputs.permute((0, 2, 1))
        return outputs, log_prob
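
_prepare_beam_search tiles the encoder memory so that each of the beam_width hypotheses gets its own copy along the batch dimension. A minimal sketch with a toy 2-D tensor (values arbitrary) showing what the unsqueeze/repeat/view combination does to the shape:

    import torch

    beam_width = 3
    memory = torch.tensor([[1.0, 2.0],    # example 0
                           [3.0, 4.0]])   # example 1; shape [batch_size, dim]

    x = memory.unsqueeze(1).repeat(1, beam_width, 1)  # [batch_size, beam_width, dim]
    x = x.view(-1, *x.size()[2:])                     # [batch_size * beam_width, dim]

    print(x.shape)  # torch.Size([6, 2])
    print(x)        # each row of `memory` repeated beam_width times, in batch order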
Example #10
    def _beam_decode(self, start_tokens, end_token, decode_length, beam_width,
                     length_penalty):
        def _symbols_to_logits_fn(ids, step, cache):
            return self._input_ids_to_outputs(ids[:, -1], step, cache)

        outputs, log_prob = beam_search.beam_search(_symbols_to_logits_fn,
                                                    start_tokens,
                                                    beam_width,
                                                    decode_length,
                                                    self._vocab_size,
                                                    length_penalty,
                                                    eos_id=end_token,
                                                    states=self._cache)

        # Ignores <BOS>
        outputs = outputs[:, :, 1:]
        # shape = [batch_size, seq_length, beam_width]
        outputs = tf.transpose(outputs, [0, 2, 1])
        return (outputs, log_prob)
Example #11
    def testNotGreedyBeamTwoWithAlpha(self):
        batch_size = 1
        beam_size = 2
        vocab_size = 3
        decode_length = 3

        initial_ids = torch.tensor([0] * batch_size, dtype=torch.int64)
        # Probabilities for position * batch * beam * vocab
        # Probabilities have been set such that with alpha = 3.5, the less
        # probable but longer sequence will have a better score than the
        # shorter sequence with the higher log prob.
        probabilities = torch.tensor([[[0.1, 0.1, 0.8], [0.1, 0.1, 0.8]],
                                      [[0.4, 0.5, 0.1], [0.2, 0.4, 0.4]],
                                      [[0.05, 0.9, 0.05], [0.4, 0.4, 0.2]]])

        def symbols_to_logits(ids):
            pos = ids.shape[1]
            logits = torch.log(probabilities[pos - 1, :]).type(torch.float)
            return logits

        final_ids, final_probs = beam_search.beam_search(
            symbols_to_logits_fn=symbols_to_logits,
            initial_ids=initial_ids,
            beam_size=beam_size,
            decode_length=decode_length,
            vocab_size=vocab_size,
            alpha=3.5,
            eos_id=1)

        exp_ids = [[[0, 2, 0, 1], [0, 2, 1, 0]]]
        exp_probs = [[
            np.log(0.8 * 0.4 * 0.9) / (8. / 6.)**3.5,
            np.log(0.8 * 0.5) / (7. / 6.)**3.5
        ]]

        self.assertEqual(final_ids.tolist(), exp_ids)
        self.assertAlmostEqual(final_probs.tolist()[0][0], exp_probs[0][0])
        self.assertAlmostEqual(final_probs.tolist()[0][1], exp_probs[0][1])
Example #12
    def testNotGreedyBeamTwoWithoutStopEarly(self):
        batch_size = 1
        beam_size = 2
        vocab_size = 3
        decode_length = 3

        initial_ids = torch.tensor([0] * batch_size, dtype=torch.int64)
        probabilities = torch.tensor([[[0.1, 0.1, 0.8], [0.1, 0.1, 0.8]],
                                      [[0.4, 0.5, 0.1], [0.2, 0.4, 0.4]],
                                      [[0.05, 0.9, 0.05], [0.4, 0.4, 0.2]]])

        def symbols_to_logits(ids):
            pos = ids.shape[1]
            logits = torch.log(probabilities[pos - 1, :]).type(torch.float)
            return logits

        final_ids, final_probs = beam_search.beam_search(
            symbols_to_logits_fn=symbols_to_logits,
            initial_ids=initial_ids,
            beam_size=beam_size,
            decode_length=decode_length,
            vocab_size=vocab_size,
            alpha=0.0,
            eos_id=1,
            stop_early=False)

        # Given stop_early=False, the algorithm returns all of the beams,
        # so we can test all of them here.

        exp_ids = [[[0, 2, 1, 0], [0, 2, 0, 1]]]
        exp_probs = [[0.8 * 0.5, 0.8 * 0.4 * 0.9]]

        self.assertEqual(final_ids.tolist(), exp_ids)
        self.assertAlmostEqual(
            np.exp(final_probs).tolist()[0][0], exp_probs[0][0])
        self.assertAlmostEqual(
            np.exp(final_probs).tolist()[0][1], exp_probs[0][1])
Example #13
    def testNotGreedyBeamTwoWithStopEarly(self):
        batch_size = 1
        beam_size = 2
        vocab_size = 3
        decode_length = 10

        initial_ids = torch.tensor([0] * batch_size, dtype=torch.int64)
        probabilities = torch.tensor([[[0.1, 0.1, 0.8], [0.1, 0.1, 0.8]],
                                      [[0.4, 0.5, 0.1], [0.2, 0.4, 0.4]],
                                      [[0.05, 0.9, 0.05], [0.4, 0.4, 0.2]]])

        def symbols_to_logits(ids):
            pos = ids.shape[1]
            logits = torch.log(probabilities[pos - 1, :]).type(torch.float)
            return logits

        final_ids, final_probs = beam_search.beam_search(
            symbols_to_logits_fn=symbols_to_logits,
            initial_ids=initial_ids,
            beam_size=beam_size,
            decode_length=decode_length,
            vocab_size=vocab_size,
            alpha=0.0,
            eos_id=1,
            stop_early=True)  # default value, but just to make this explicit

        # Given stop_early=True, the only guarantee is w.r.t. the first beam
        # (the other beams may not even be completed), so we check only the
        # first beam.
        first_beam = final_ids[:, 0]
        first_probs = final_probs[:, 0]

        exp_ids = [[0, 2, 1]]
        exp_probs = [0.8 * 0.5]

        self.assertEqual(first_beam.tolist(), exp_ids)
        self.assertAlmostEqual(np.exp(first_probs).tolist()[0], exp_probs[0])
Example #14
    def _beam_decode(self,
                     start_tokens,
                     end_token,
                     decode_length=256,
                     beam_width=5,
                     alpha=0.6):
        def _symbols_to_logits_fn(ids, step, cache):
            return self._inputs_to_outputs(
                self._prepare_tokens_to_embeds(ids[:, -1]), step, cache)

        outputs, log_prob = beam_search.beam_search(_symbols_to_logits_fn,
                                                    start_tokens,
                                                    beam_width,
                                                    decode_length,
                                                    self._vocab_size,
                                                    alpha,
                                                    states=self._cache,
                                                    eos_id=end_token)

        # Ignores <BOS>
        outputs = outputs[:, :, 1:]
        # shape = [batch_size, seq_length, beam_width]
        outputs = tf.transpose(outputs, [0, 2, 1])
        return (outputs, log_prob)
Example #15
    def testShapes(self):
        batch_size = 2
        beam_size = 3
        vocab_size = 4
        decode_length = 10

        initial_ids = torch.tensor([0, 0], dtype=torch.int64)

        def symbols_to_logits(_):
            # Just return random logits
            return torch.rand(batch_size * beam_size, vocab_size)

        final_ids, final_probs = beam_search.beam_search(
            symbols_to_logits_fn=symbols_to_logits,
            initial_ids=initial_ids,
            beam_size=beam_size,
            decode_length=decode_length,
            vocab_size=vocab_size,
            alpha=0.0,
            eos_id=1)

        self.assertEqual(final_ids.shape[1], beam_size)
        self.assertEqual(final_probs.shape, torch.Size([batch_size,
                                                        beam_size]))
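
Because the random logits above may emit <EOS> (id 1) at any step, the time dimension of final_ids is not fixed, which is presumably why only the beam dimension and the shape of final_probs are checked. Below is a sketch in the same style (assuming the same beam_search module is in scope) whose logits never choose <EOS>, so the decoded ids should keep the initial id plus decode_length generated steps:

    import torch

    batch_size, beam_size, vocab_size, decode_length = 2, 3, 4, 10
    initial_ids = torch.zeros(batch_size, dtype=torch.int64)

    def never_eos(_):
        # Put (almost) all probability mass on token 2 so <EOS> is never chosen.
        logits = torch.full((batch_size * beam_size, vocab_size), -1e9)
        logits[:, 2] = 0.0
        return logits

    final_ids, final_probs = beam_search.beam_search(
        symbols_to_logits_fn=never_eos,
        initial_ids=initial_ids,
        beam_size=beam_size,
        decode_length=decode_length,
        vocab_size=vocab_size,
        alpha=0.0,
        eos_id=1)

    # Expected, based on the other tests in this file (ids include the initial id):
    # final_ids:   [batch_size, beam_size, decode_length + 1]
    # final_probs: [batch_size, beam_size]
    print(final_ids.shape, final_probs.shape)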