Beispiel #1
0
 def test_p_val(self, p):
     """Check that an invalid ``p`` is rejected with a ``ConfigurationError``.

     ``p`` is supplied by the caller (presumably a pytest parametrization
     outside this view — confirm). Only the top-p sampler construction and
     the search call sit inside ``pytest.raises``. That pins the expected
     error to the code under test rather than to the fixture setup.
     """
     initial_predictions = torch.tensor([0] * 5)
     take_step = take_step_with_timestep
     beam_size = 3

     with pytest.raises(ConfigurationError):
         # The error may come from construction or from search; both are
         # covered, and the (never-produced) results are not bound.
         BeamSearch.top_p_sampling(
             self.end_index, p=p,
             beam_size=beam_size).search(initial_predictions, {}, take_step)
Beispiel #2
0
    def test_empty_p(self):
        """Check that a search producing empty sequences warns and returns end tokens.

        Both batch rows start at ``self.end_index - 1``. The search must emit a
        ``RuntimeWarning`` mentioning empty sequences. It must return a single
        end token per row with a log-probability of exactly zero.
        """
        start_token = self.end_index - 1
        batch = torch.LongTensor([start_token, start_token])
        step_fn = take_step_with_timestep

        with pytest.warns(RuntimeWarning, match="Empty sequences predicted"):
            sampler = BeamSearch.top_p_sampling(self.end_index, beam_size=1)
            preds, scores = sampler.search(batch, {}, step_fn)

        # Predictions: (batch_size, beam_size, max_predicted_length).
        assert tuple(preds.shape) == (2, 1, 1)
        # Log-probabilities: (batch_size, beam_size).
        assert tuple(scores.shape) == (2, 1)
        assert torch.all(preds == self.end_index)
        assert torch.all(scores == 0)
Beispiel #3
0
    def test_top_p_search(self):
        """Check that top-p sampling returns well-formed predictions and log-probs.

        Runs a top-p search with ``beam_size=3`` over a batch of five start
        tokens. It then checks:

        * the leading dimensions of the predictions;
        * that every predicted index lies in ``[0, 5]``;
        * the shape of the returned log-probabilities.
        """
        initial_predictions = torch.tensor([0] * 5)
        # Derive the batch size from the input rather than repeating the literal.
        batch_size = initial_predictions.size(0)
        beam_size = 3
        take_step = take_step_with_timestep

        top_p, log_probs = BeamSearch.top_p_sampling(
            self.end_index,
            beam_size=beam_size).search(initial_predictions, {}, take_step)

        # top_p should be shape `(batch_size, beam_size, max_predicted_length)`.
        # The decoded length is not fixed, so only the leading dims are pinned.
        assert list(top_p.size())[:-1] == [batch_size, beam_size]

        # NOTE(review): 5 is presumably the largest valid token index produced
        # by take_step_with_timestep — confirm against its definition.
        assert ((0 <= top_p) & (top_p <= 5)).all()

        # log_probs should be shape `(batch_size, beam_size)`.
        assert list(log_probs.size()) == [batch_size, beam_size]