def test_stochastic_beam_search(self):
    initial_predictions = torch.tensor([0] * 5)
    batch_size = 5
    beam_size = 3
    take_step = take_step_with_timestep

    gumbel_sampler = GumbelSampler()

    top_k, log_probs = BeamSearch(
        self.end_index, beam_size=beam_size, max_steps=10, sampler=gumbel_sampler
    ).search(initial_predictions, {}, take_step)

    # top_k should be shape `(batch_size, beam_size, max_predicted_length)`.
    assert list(top_k.size())[:-1] == [batch_size, beam_size]
    assert ((0 <= top_k) & (top_k <= 5)).all()

    # log_probs should be shape `(batch_size, beam_size)`.
    assert list(log_probs.size()) == [batch_size, beam_size]

    # Check to make sure that once the end index is predicted, all subsequent
    # tokens must be the end index.
    for batch in top_k:
        for beam in batch:
            reached_end = False
            for token in beam:
                if token == self.end_index:
                    reached_end = True
                if reached_end:
                    assert token == self.end_index
def test_top_k_search(self):
    initial_predictions = torch.tensor([0] * 5)
    batch_size = 5
    beam_size = 3
    take_step = take_step_with_timestep

    k_sampler = TopKSampler(k=5, with_replacement=True)

    top_k, log_probs = BeamSearch(
        self.end_index, beam_size=beam_size, max_steps=10, sampler=k_sampler
    ).search(initial_predictions, {}, take_step)

    # top_k should be shape `(batch_size, beam_size, max_predicted_length)`.
    assert list(top_k.size())[:-1] == [batch_size, beam_size]
    assert ((0 <= top_k) & (top_k <= 5)).all()

    # log_probs should be shape `(batch_size, beam_size)`.
    assert list(log_probs.size()) == [batch_size, beam_size]
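
# Both tests above call a module-level `take_step_with_timestep` helper that is
# not shown in this excerpt, and they assume imports along the lines of
# `import torch` and `from allennlp.nn.beam_search import BeamSearch,
# GumbelSampler, TopKSampler`. The sketch below is a hypothetical stand-in for
# that helper, not the project's actual fixture: it assumes a 6-token
# vocabulary (ids 0-5, with 5 serving as the end index) and returns uniform
# log-probabilities, which is enough to exercise the `BeamSearch.search` API
# the assertions above check. A real fixture would encode meaningful
# transition probabilities instead.
def take_step_with_timestep(
    last_predictions: torch.Tensor,
    state: dict,
    timestep: int,
):
    # `last_predictions` has shape `(group_size,)`; return log-probabilities of
    # shape `(group_size, vocab_size)` along with the (unchanged) decoder state.
    group_size = last_predictions.size(0)
    log_probs = torch.full((group_size, 6), 1.0 / 6).log()
    return log_probs, state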