Example #1
0
 def test_k_val(self, k):
     """Top-k sampling must reject an invalid ``k`` with ``ConfigurationError``.

     Parametrized over ``k`` by the test framework; presumably each supplied
     value is outside the sampler's accepted range — confirm against the
     parametrize decorator (not visible in this fragment).
     """
     # Build the inputs first so that a failure here cannot be mistaken
     # for the expected ConfigurationError below.
     initial_predictions = torch.tensor([0] * 5)
     take_step = take_step_with_timestep
     beam_size = 3
     # Only the sampler construction / search call should be allowed to raise.
     with pytest.raises(ConfigurationError):
         top_k, log_probs = BeamSearch.top_k_sampling(
             self.end_index, k=k,
             beam_size=beam_size).search(initial_predictions, {}, take_step)
Example #2
0
    def test_empty_k(self):
        """An already-finished start state should warn and emit trivial output.

        Every sequence begins one step before ``end_index``, so the search
        terminates immediately with a RuntimeWarning about empty predictions.
        """
        start = torch.LongTensor(
            [self.end_index - 1, self.end_index - 1])
        step_fn = take_step_with_timestep

        with pytest.warns(RuntimeWarning, match="Empty sequences predicted"):
            preds, scores = BeamSearch.top_k_sampling(
                self.end_index, beam_size=1).search(start, {}, step_fn)

        # predictions should have shape `(batch_size, beam_size, max_predicted_length)`.
        assert preds.shape == (2, 1, 1)
        # log probs should have shape `(batch_size, beam_size)`.
        assert scores.shape == (2, 1)
        # Only the end token is ever emitted, each with log-probability zero.
        assert (preds == self.end_index).all()
        assert (scores == 0).all()
Example #3
0
    def test_top_k_search(self):
        """Top-k sampling with ``k=1`` should run and produce well-shaped output.

        With ``k=1`` the sampler can only ever pick the single most likely
        token, so the search should complete without error and return tensors
        of the documented shapes.
        """
        initial_predictions = torch.tensor([0] * 5)
        beam_size = 3
        take_step = take_step_with_timestep

        top_k, log_probs = BeamSearch.top_k_sampling(
            self.end_index, k=1,
            beam_size=beam_size).search(initial_predictions, {}, take_step)

        batch_size = 5

        # top_k should have shape `(batch_size, beam_size, max_predicted_length)`;
        # the predicted length is not fixed, so only check the first two dims.
        assert list(top_k.size())[:-1] == [batch_size, beam_size]

        # Every sampled token id must lie within the toy vocabulary [0, 5].
        assert ((0 <= top_k) & (top_k <= 5)).all()

        # log_probs should have shape `(batch_size, beam_size)`.
        assert list(log_probs.size()) == [batch_size, beam_size]