def test_p_val(self, p_val):
    """Check that a top-p beam search configured with ``p=p_val`` raises ``ValueError``.

    ``p_val`` is presumably an invalid probability supplied by a pytest
    parametrize decorator outside this view -- confirm against the decorator.
    """
    initial_predictions = torch.tensor([0] * 5)
    take_step = take_step_with_timestep
    beam_size = 3

    # Only the sampler construction and the search itself belong inside the
    # context manager; keeping setup outside means an unrelated ValueError
    # cannot make this test pass by accident.
    with pytest.raises(ValueError):
        p_sampler = TopPSampler(p=p_val, with_replacement=True)
        BeamSearch(
            self.end_index, beam_size=beam_size, max_steps=10, sampler=p_sampler
        ).search(initial_predictions, {}, take_step)
def test_top_p_search(self):
    """Run a beam search driven by ``TopPSampler(p=0.8)`` and check output shapes and ranges.

    The batch is ``batch_size`` identical start predictions (all index 0). The
    beam size is 3, at most 10 steps are taken, and ``take_step_with_timestep``
    is the step function.
    """
    batch_size = 5
    beam_size = 3
    initial_predictions = torch.tensor([0] * batch_size)
    take_step = take_step_with_timestep
    p_sampler = TopPSampler(p=0.8)

    top_p, log_probs = BeamSearch(
        self.end_index, beam_size=beam_size, max_steps=10, sampler=p_sampler
    ).search(initial_predictions, {}, take_step)

    # top_p should be shape `(batch_size, beam_size, max_predicted_length)`.
    # Only the leading two dims are asserted; the predicted length is left unchecked.
    assert list(top_p.size())[:-1] == [batch_size, beam_size]
    # NOTE(review): the upper bound 5 presumably matches the index range that
    # take_step_with_timestep can emit -- confirm against that helper.
    assert ((0 <= top_p) & (top_p <= 5)).all()

    # log_probs should be shape `(batch_size, beam_size)`: one value per beam.
    assert list(log_probs.size()) == [batch_size, beam_size]
def test_top_p_sampler(self):
    """Call ``TopPSampler.sample_nodes`` directly on the shared ``log_probabilities`` input.

    ``log_probabilities`` is defined outside this view. There are two rounds:
    p=0.8 / temperature=0.9 drawing 3 samples per row, then p=0.7 /
    temperature=1.0 drawing 2 samples per row. Each round checks that the
    sampled class indices of each row fall within the expected index set.
    """
    # Round 1: p=0.8, temperature=0.9, three samples per row.
    top_p_sampler = TopPSampler(p=0.8, temperature=0.9)
    probs, sampled, _state = top_p_sampler.sample_nodes(
        log_probabilities, 3, {"foo": "bar"}
    )

    assert probs.size() == sampled.size()
    assert sampled.size() == (2, 3)
    first_row, second_row = sampled[0], sampled[1]
    assert all(0 < idx < 4 for idx in first_row)
    assert all(1 < idx < 5 for idx in second_row)

    # Make sure the filtered classes include the first class that exceeds p
    top_p_sampler = TopPSampler(p=0.7, temperature=1.0)
    probs, sampled, _state = top_p_sampler.sample_nodes(
        log_probabilities, 2, {"foo": "bar"}
    )

    assert all(idx in (1, 2, 3) for idx in sampled[0])
    assert all(idx in (2, 3) for idx in sampled[1])