def test_select_output_token_sample(self):
    """Check the token picked when the second argument (sampling flag) is True.

    The logits row is ``[0., 0.5, 1.]``, so index 2 holds the largest value
    and is the token the test expects back.

    NOTE(review): the trailing positional arguments ``1, 1, 1.0`` are
    presumably top-k / top-p / temperature style knobs that make the draw
    deterministic -- confirm against ``sample_output_token``'s signature.
    """
    logits = torch.tensor([[0., 0.5, 1.]])
    expected_token = torch.tensor(2)

    chosen_token = sample_output_token(logits, True, 1, 1, 1.0)

    assert chosen_token == expected_token
def test_select_output_token_argmax(self):
    """Check the token picked when the second argument (sampling flag) is False.

    The logits are the 1-D tensor ``[0., 1.]``, so index 1 holds the largest
    value and is the token the test expects back.

    NOTE(review): the trailing positional arguments are all zero here;
    presumably they are ignored on the non-sampling path -- confirm against
    ``sample_output_token``'s signature.
    """
    logits = torch.tensor([0., 1.])
    expected_token = torch.tensor(1)

    chosen_token = sample_output_token(logits, False, 0, 0, 0)

    assert chosen_token == expected_token