Ejemplo n.º 1
0
    def test_p_val(self, p_val):
        """Check that an invalid top-p value is rejected with ``ValueError``.

        ``p_val`` is presumably supplied by a pytest parametrization that is
        not visible in this snippet -- TODO confirm against the decorator.
        """
        # Plain setup stays outside ``pytest.raises`` so that an unrelated
        # failure here can never be mistaken for the expected error.
        initial_predictions = torch.tensor([0] * 5)
        take_step = take_step_with_timestep
        beam_size = 3

        # Both sampler construction and the search itself are guarded, because
        # either one may be where the invalid ``p`` is detected.
        with pytest.raises(ValueError):
            p_sampler = TopPSampler(p=p_val, with_replacement=True)
            BeamSearch(
                self.end_index,
                beam_size=beam_size,
                max_steps=10,
                sampler=p_sampler,
            ).search(initial_predictions, {}, take_step)
Ejemplo n.º 2
0
    def test_top_p_search(self):
        """Run a beam search with a ``TopPSampler`` and check the output shapes."""
        batch_size = 5
        beam_size = 3
        # One start token (0) per batch element.
        initial_predictions = torch.tensor([0] * batch_size)
        p_sampler = TopPSampler(p=0.8)

        top_p, log_probs = BeamSearch(
            self.end_index,
            beam_size=beam_size,
            max_steps=10,
            sampler=p_sampler,
        ).search(initial_predictions, {}, take_step_with_timestep)

        # top_p should be shape `(batch_size, beam_size, max_predicted_length)`;
        # the last dimension is not pinned, so only the first two are compared.
        assert list(top_p.size())[:-1] == [batch_size, beam_size]

        # Every predicted value must lie in the closed range [0, 5].
        assert ((0 <= top_p) & (top_p <= 5)).all()

        # log_probs should be shape `(batch_size, beam_size)`.
        assert list(log_probs.size()) == [batch_size, beam_size]
Ejemplo n.º 3
0
    def test_top_p_sampler(self):
        """Check the sizes and the class indices returned by ``TopPSampler.sample_nodes``."""
        dummy_state = {"foo": "bar"}

        top_p_sampler = TopPSampler(p=0.8, temperature=0.9)
        sampled_probs, sampled_classes, _ = top_p_sampler.sample_nodes(
            log_probabilities, 3, dummy_state
        )

        assert sampled_probs.size() == sampled_classes.size()
        assert sampled_classes.size() == (2, 3)

        # Each row may only contain indices strictly inside its expected range.
        assert all(0 < idx < 4 for idx in sampled_classes[0])
        assert all(1 < idx < 5 for idx in sampled_classes[1])

        # Make sure the filtered classes include the first class that exceeds p
        top_p_sampler = TopPSampler(p=0.7, temperature=1.0)
        _, sampled_classes, _ = top_p_sampler.sample_nodes(
            log_probabilities, 2, {"foo": "bar"}
        )

        assert all(idx in (2, 3, 1) for idx in sampled_classes[0])
        assert all(idx in (2, 3) for idx in sampled_classes[1])