예제 #1
0
 def test_select_output_token_sample(self):
     """Check that index 2 is selected from a one-row logits tensor.

     The logits ``[[0., 0.5, 1.]]`` have their largest entry at index 2.
     The remaining positional arguments (``True, 1, 1, 1.0``) are
     presumably a sampling flag, top-k, top-p and temperature -- confirm
     against the signature of ``sample_output_token``.
     """
     logits = torch.tensor([[0., 0.5, 1.]])
     chosen = sample_output_token(logits, True, 1, 1, 1.0)
     expected = torch.tensor(2)
     assert chosen == expected
예제 #2
0
 def test_select_output_token_argmax(self):
     """Check that index 1 is selected from the 1-D logits ``[0., 1.]``.

     The second positional argument is ``False`` here, which presumably
     disables sampling so the highest-scoring index is returned; the
     trailing zeros would then be unused sampling settings -- confirm
     against the signature of ``sample_output_token``.
     """
     logits = torch.tensor([0., 1.])
     chosen = sample_output_token(logits, False, 0, 0, 0)
     expected = torch.tensor(1)
     assert chosen == expected