def _get_modified_precision_counts(
    self,
    predicted_tokens: torch.LongTensor,
    reference_tokens: torch.LongTensor,
    ngram_size: int,
) -> Tuple[int, int]:
    """
    Compare each predicted sentence to its reference at the given ngram size
    and return the two parts of a modified precision: the number of predicted
    ngrams that also appear in the corresponding reference (each match clipped
    at the reference's own count for that ngram), and the total number of
    predicted ngrams.
    """
    from allennlp.training.util import ngrams

    matched = 0
    predicted_total = 0
    for predicted_row, reference_row in zip(predicted_tokens, reference_tokens):
        predicted_counts = ngrams(predicted_row, ngram_size, self._exclude_indices)
        reference_counts = ngrams(reference_row, ngram_size, self._exclude_indices)
        # Denominator: every predicted ngram counts once.
        predicted_total += sum(predicted_counts.values())
        # Numerator: matches clipped by the reference count (missing keys
        # look up as 0 on the Counter returned by `ngrams`).
        matched += sum(
            min(count, reference_counts[gram])
            for gram, count in predicted_counts.items()
        )
    return matched, predicted_total
def _get_rouge_n_stats(
    self,
    predicted_tokens: torch.LongTensor,
    reference_tokens: torch.LongTensor,
    ngram_size: int,
) -> Tuple[float, float, float]:
    """
    Compare the predicted tokens to the reference (gold) tokens at the desired
    ngram size and compute the sums of per-sequence recall, precision and f1.

    # Parameters

    predicted_tokens : `torch.LongTensor`
        Batch of predicted token sequences.
    reference_tokens : `torch.LongTensor`
        Batch of gold token sequences, aligned row-for-row with the predictions.
    ngram_size : `int`
        Size of the ngrams to compare.

    # Returns

    `Tuple[float, float, float]`
        Sums of recall, precision and f1 over the batch (summed across workers
        when running distributed).
    """
    # Hoisted out of the loop: the original re-ran this import on every
    # iteration (cached, but a pointless per-sequence module lookup).
    from allennlp.training.util import ngrams

    total_recall = 0.0
    total_precision = 0.0
    total_f1 = 0.0

    for predicted_seq, reference_seq in zip(predicted_tokens, reference_tokens):
        predicted_ngram_counts = ngrams(predicted_seq, ngram_size, self._exclude_indices)
        reference_ngram_counts = ngrams(reference_seq, ngram_size, self._exclude_indices)

        # Clipped matches: each reference ngram can be matched at most as
        # many times as it occurs in the prediction.
        matches = 0
        total_reference_ngrams = 0
        for ngram, count in reference_ngram_counts.items():
            matches += min(predicted_ngram_counts[ngram], count)
            total_reference_ngrams += count

        total_predicted_ngrams = sum(predicted_ngram_counts.values())

        # Degenerate rows (no ngrams on either side, or zero matches) would
        # divide by zero; they contribute 0 to every sum, so skip them.
        if total_reference_ngrams == 0 or total_predicted_ngrams == 0 or matches == 0:
            continue

        recall = matches / total_reference_ngrams
        precision = matches / total_predicted_ngrams
        f1 = 2.0 * recall * precision / (recall + precision)

        total_recall += recall
        total_precision += precision
        total_f1 += f1

    if is_distributed():
        # One fused all_reduce instead of three separate collective calls;
        # the summed result is identical.
        device = predicted_tokens.device
        stats = torch.tensor([total_recall, total_precision, total_f1], device=device)
        dist.all_reduce(stats, op=dist.ReduceOp.SUM)
        total_recall, total_precision, total_f1 = stats.tolist()

    return total_recall, total_precision, total_f1
def _get_modified_precision_counts(self, predicted_tokens, reference_tokens, ngram_size):
    """
    Return the (numerator, denominator) of modified ngram precision over a
    batch: clipped matches against the reference, and total predicted ngrams.
    """
    numerator = 0
    denominator = 0
    for prediction, reference in zip(predicted_tokens, reference_tokens):
        prediction_counts = ngrams(prediction, ngram_size, self._exclude_indices)
        reference_counts = ngrams(reference, ngram_size, self._exclude_indices)
        # Every predicted ngram counts toward the denominator.
        denominator += sum(prediction_counts.values())
        # Matches are clipped at the reference's count for each ngram.
        for gram, occurrences in prediction_counts.items():
            numerator += min(occurrences, reference_counts[gram])
    return numerator, denominator
def test_ngrams(self, device: str):
    """Verify `ngrams` output for several ngram sizes, including one longer than the input."""
    tensor = torch.tensor([1, 2, 3, 1, 2, 0], device=device)
    exclude_indices = self.metric._exclude_indices

    expected_by_size = {
        # Unigrams.
        1: {(1,): 2, (2,): 2, (3,): 1},
        # Bigrams.
        2: {(1, 2): 2, (2, 3): 1, (3, 1): 1},
        # Trigrams.
        3: {(1, 2, 3): 1, (2, 3, 1): 1, (3, 1, 2): 1},
    }
    for size, expected in expected_by_size.items():
        assert Counter(ngrams(tensor, size, exclude_indices)) == expected

    # An ngram size larger than the sequence produces no ngrams at all.
    assert Counter(ngrams(tensor, 7, exclude_indices)) == {}