class BleuTest(AllenNlpTestCase):
    """Unit tests for the BLEU metric (2-gram weights, index 0 excluded)."""

    def setUp(self):
        super().setUp()
        self.metric = BLEU(ngram_weights=(0.5, 0.5), exclude_indices={0})

    def test_get_valid_tokens_mask(self):
        # Index 0 is in ``exclude_indices``, so every position holding 0
        # must be masked out; all other positions stay valid.
        tokens = torch.tensor([[1, 2, 3, 0], [0, 1, 1, 0]])
        mask = self.metric._get_valid_tokens_mask(tokens).long().numpy()
        expected = np.array([[1, 1, 1, 0], [0, 1, 1, 0]])
        np.testing.assert_array_equal(mask, expected)

    def test_ngrams(self):
        sequence = torch.tensor([1, 2, 3, 1, 2, 0])
        # Expected clipped n-gram counts for each n-gram size. The trailing
        # 0 is excluded, so no n-gram containing it appears. A size larger
        # than the sequence produces no n-grams at all.
        expected_by_size = {
            1: {(1,): 2, (2,): 2, (3,): 1},
            2: {(1, 2): 2, (2, 3): 1, (3, 1): 1},
            3: {(1, 2, 3): 1, (2, 3, 1): 1, (3, 1, 2): 1},
            7: {},
        }
        for size, expected in expected_by_size.items():
            assert Counter(self.metric._ngrams(sequence, size)) == expected

    def test_bleu_computed_correctly(self):
        self.metric.reset()

        # shape: (batch_size, max_sequence_length)
        predictions = torch.tensor([[1, 0, 0], [1, 1, 0], [1, 1, 1]])
        # shape: (batch_size, max_gold_sequence_length)
        gold_targets = torch.tensor([[2, 0, 0], [1, 0, 0], [1, 1, 2]])
        self.metric(predictions, gold_targets)

        assert self.metric._prediction_lengths == 6
        assert self.metric._reference_lengths == 5

        # Unigram matches against gold, clipped to the maximum occurrence of
        # each gold unigram: 0 in the first sentence, 1 in the second,
        # 2 in the third.
        assert self.metric._precision_matches[1] == 0 + 1 + 2
        # Total predicted unigrams per sentence: 1, 2, 3.
        assert self.metric._precision_totals[1] == 1 + 2 + 3

        # Clipped bigram matches per sentence: 0, 0, 1.
        assert self.metric._precision_matches[2] == 0 + 0 + 1
        # Total predicted bigrams per sentence: 0, 1, 2.
        assert self.metric._precision_totals[2] == 0 + 1 + 2

        # Predictions are at least as long as references, so there is no
        # brevity penalty.
        assert self.metric._get_brevity_penalty() == 1.0

        bleu = self.metric.get_metric(reset=True)["BLEU"]
        expected = math.exp(
            0.5 * (math.log(3) - math.log(6)) + 0.5 * (math.log(1) - math.log(3))
        )
        np.testing.assert_approx_equal(bleu, expected)

    def test_bleu_computed_with_zero_counts(self):
        # With no accumulated statistics the metric must report 0, not NaN.
        self.metric.reset()
        assert self.metric.get_metric()["BLEU"] == 0
class BleuTest(AllenNlpTestCase):
    """Tests for BLEU with ngram_weights=(0.5, 0.5) and index 0 excluded."""

    def setUp(self):
        super().setUp()
        self.metric = BLEU(ngram_weights=(0.5, 0.5), exclude_indices={0})

    def test_get_valid_tokens_mask(self):
        # Positions equal to the excluded index (0) are masked out.
        token_ids = torch.tensor([[1, 2, 3, 0], [0, 1, 1, 0]])
        valid_mask = self.metric._get_valid_tokens_mask(token_ids)
        np.testing.assert_array_equal(
            valid_mask.long().numpy(),
            np.array([[1, 1, 1, 0], [0, 1, 1, 0]]),
        )

    def test_ngrams(self):
        sequence = torch.tensor([1, 2, 3, 1, 2, 0])

        # Unigrams: the final 0 is excluded.
        assert Counter(self.metric._ngrams(sequence, 1)) == {
            (1,): 2,
            (2,): 2,
            (3,): 1,
        }
        # Bigrams: any bigram touching the excluded 0 is dropped.
        assert Counter(self.metric._ngrams(sequence, 2)) == {
            (1, 2): 2,
            (2, 3): 1,
            (3, 1): 1,
        }
        # Trigrams.
        assert Counter(self.metric._ngrams(sequence, 3)) == {
            (1, 2, 3): 1,
            (2, 3, 1): 1,
            (3, 1, 2): 1,
        }
        # An n-gram size larger than the sequence yields nothing.
        assert Counter(self.metric._ngrams(sequence, 7)) == {}

    def test_bleu_computed_correctly(self):
        self.metric.reset()

        # shape: (batch_size, max_sequence_length)
        predictions = torch.tensor([[1, 0, 0], [1, 1, 0], [1, 1, 1]])
        # shape: (batch_size, max_gold_sequence_length)
        gold_targets = torch.tensor([[2, 0, 0], [1, 0, 0], [1, 1, 2]])
        self.metric(predictions, gold_targets)

        assert self.metric._prediction_lengths == 6
        assert self.metric._reference_lengths == 5

        # Predicted unigrams matching gold, clipped to the maximum
        # occurrence of each gold unigram within the batch:
        # sentence 1 -> 0 matches, sentence 2 -> 1, sentence 3 -> 2.
        assert self.metric._precision_matches[1] == (0 + 1 + 2)
        # Total predicted unigrams: 1 + 2 + 3.
        assert self.metric._precision_totals[1] == (1 + 2 + 3)

        # Clipped bigram matches per sentence: 0, 0, 1.
        assert self.metric._precision_matches[2] == (0 + 0 + 1)
        # Total predicted bigrams: 0 + 1 + 2.
        assert self.metric._precision_totals[2] == (0 + 1 + 2)

        # No brevity penalty: predictions are longer than references.
        assert self.metric._get_brevity_penalty() == 1.0

        bleu = self.metric.get_metric(reset=True)["BLEU"]
        check = math.exp(
            0.5 * (math.log(3) - math.log(6)) + 0.5 * (math.log(1) - math.log(3))
        )
        np.testing.assert_approx_equal(bleu, check)

    def test_bleu_computed_with_zero_counts(self):
        # A freshly reset metric reports BLEU == 0.
        self.metric.reset()
        assert self.metric.get_metric()["BLEU"] == 0
class BleuTest(AllenNlpTestCase):
    """Device-parameterized and distributed tests for the BLEU metric."""

    def setup_method(self):
        super().setup_method()
        self.metric = BLEU(ngram_weights=(0.5, 0.5), exclude_indices={0})

    @multi_device
    def test_get_valid_tokens_mask(self, device: str):
        # Positions equal to the excluded index (0) are masked out.
        token_ids = torch.tensor([[1, 2, 3, 0], [0, 1, 1, 0]], device=device)
        mask = get_valid_tokens_mask(token_ids, self.metric._exclude_indices).long()
        expected = torch.tensor([[1, 1, 1, 0], [0, 1, 1, 0]], device=device)
        assert_allclose(mask, expected)

    @multi_device
    def test_ngrams(self, device: str):
        sequence = torch.tensor([1, 2, 3, 1, 2, 0], device=device)
        exclude_indices = self.metric._exclude_indices

        # Expected counts per n-gram size. The trailing 0 is excluded, so
        # no n-gram containing it appears; a size larger than the sequence
        # produces no n-grams at all.
        expected_by_size = {
            1: {(1,): 2, (2,): 2, (3,): 1},
            2: {(1, 2): 2, (2, 3): 1, (3, 1): 1},
            3: {(1, 2, 3): 1, (2, 3, 1): 1, (3, 1, 2): 1},
            7: {},
        }
        for size, expected in expected_by_size.items():
            counts: Counter = Counter(ngrams(sequence, size, exclude_indices))
            assert counts == expected

    @multi_device
    def test_bleu_computed_correctly(self, device: str):
        self.metric.reset()

        # shape: (batch_size, max_sequence_length)
        predictions = torch.tensor([[1, 0, 0], [1, 1, 0], [1, 1, 1]], device=device)
        # shape: (batch_size, max_gold_sequence_length)
        gold_targets = torch.tensor([[2, 0, 0], [1, 0, 0], [1, 1, 2]], device=device)
        self.metric(predictions, gold_targets)

        assert self.metric._prediction_lengths == 6
        assert self.metric._reference_lengths == 5

        # Predicted unigrams matching gold, clipped to the maximum
        # occurrence of each gold unigram within the batch:
        # sentence 1 -> 0, sentence 2 -> 1, sentence 3 -> 2.
        assert self.metric._precision_matches[1] == (0 + 1 + 2)
        # Total predicted unigrams: 1 + 2 + 3.
        assert self.metric._precision_totals[1] == (1 + 2 + 3)

        # Clipped bigram matches per sentence: 0, 0, 1.
        assert self.metric._precision_matches[2] == (0 + 0 + 1)
        # Total predicted bigrams: 0 + 1 + 2.
        assert self.metric._precision_totals[2] == (0 + 1 + 2)

        # No brevity penalty: predictions are longer than references.
        assert self.metric._get_brevity_penalty() == 1.0

        bleu = self.metric.get_metric(reset=True)["BLEU"]
        expected = math.exp(
            0.5 * (math.log(3) - math.log(6)) + 0.5 * (math.log(1) - math.log(3))
        )
        assert_allclose(bleu, expected)

    @multi_device
    def test_bleu_computed_with_zero_counts(self, device: str):
        # A freshly reset metric reports BLEU == 0.
        self.metric.reset()
        assert self.metric.get_metric()["BLEU"] == 0

    def test_distributed_bleu(self):
        # Split the same batch used in test_bleu_computed_correctly across
        # two workers; the aggregated BLEU must match the single-process value.
        predictions = [
            torch.tensor([[1, 0, 0], [1, 1, 0]]),
            torch.tensor([[1, 1, 1]]),
        ]
        gold_targets = [
            torch.tensor([[2, 0, 0], [1, 0, 0]]),
            torch.tensor([[1, 1, 2]]),
        ]
        expected_bleu = math.exp(
            0.5 * (math.log(3) - math.log(6)) + 0.5 * (math.log(1) - math.log(3))
        )
        metric_kwargs = {"predictions": predictions, "gold_targets": gold_targets}
        run_distributed_test(
            [-1, -1],
            global_distributed_metric,
            BLEU(ngram_weights=(0.5, 0.5), exclude_indices={0}),
            metric_kwargs,
            {"BLEU": expected_bleu},
            exact=False,
        )