Beispiel #1
    def _get_modified_precision_counts(
        predicted_tokens: torch.LongTensor,
        reference_tokens: torch.LongTensor,
        ngram_size: int,
        *,
        exclude_indices=None,
    ) -> Tuple[int, int]:
        """
        Compare the predicted tokens to the reference (gold) tokens at the desired
        ngram size and calculate the numerator and denominator for a modified
        form of precision.

        The numerator is the number of ngrams in the predicted sentences that match
        with an ngram in the corresponding reference sentence, clipped by the total
        count of that ngram in the reference sentence. The denominator is just
        the total count of predicted ngrams.

        Parameters
        ----------
        predicted_tokens : torch.LongTensor
            Batch of predicted token-id rows; iterated row by row.
        reference_tokens : torch.LongTensor
            Batch of reference token-id rows, paired with ``predicted_tokens``
            by position (``zip`` stops at the shorter batch).
        ngram_size : int
            Length of the contiguous ngrams to compare. Values below 1 yield
            no ngrams.
        exclude_indices : iterable of int, optional (keyword-only)
            Token ids that disqualify any ngram containing them (e.g. padding).
            ``None`` excludes nothing.

        Returns
        -------
        Tuple[int, int]
            ``(clipped_matches, total_predicted)``.
        """
        from collections import Counter

        excluded = set(exclude_indices) if exclude_indices else set()

        def _count_ngrams(row):
            # Count contiguous ngrams of one row; accepts 1-D tensors or sequences.
            tokens = row.tolist() if hasattr(row, "tolist") else list(row)
            counts = Counter()
            if ngram_size < 1:
                return counts
            for start in range(len(tokens) - ngram_size + 1):
                gram = tuple(tokens[start : start + ngram_size])
                # Any ngram touching an excluded id is dropped entirely.
                if excluded.intersection(gram):
                    continue
                counts[gram] += 1
            return counts

        clipped_matches = 0
        total_predicted = 0
        for predicted_row, reference_row in zip(predicted_tokens, reference_tokens):
            predicted_ngram_counts = _count_ngrams(predicted_row)
            reference_ngram_counts = _count_ngrams(reference_row)
            for ngram, count in predicted_ngram_counts.items():
                # Counter yields 0 for ngrams missing from the reference, so those
                # contribute nothing to the clipped numerator.
                clipped_matches += min(count, reference_ngram_counts[ngram])
                total_predicted += count
        return clipped_matches, total_predicted
Beispiel #2
    def _get_rouge_n_stats(
        predicted_tokens: torch.LongTensor,
        reference_tokens: torch.LongTensor,
        ngram_size: int,
        *,
        exclude_indices=None,
    ) -> Tuple[float, float, float]:
        """
        Compare the predicted tokens to the reference (gold) tokens at the desired
        ngram size and compute recall, precision and f1 sums.

        For each aligned (predicted, reference) sequence pair, the clipped ngram
        overlap is divided by the reference ngram count (recall) and by the
        predicted ngram count (precision); f1 is their harmonic mean. The
        per-pair values are summed, not averaged.

        Parameters
        ----------
        predicted_tokens : torch.LongTensor
            Batch of predicted token-id sequences; ``.device`` is read when
            reducing across processes.
        reference_tokens : torch.LongTensor
            Batch of reference token-id sequences, paired with
            ``predicted_tokens`` by position.
        ngram_size : int
            Length of the contiguous ngrams to compare. Values below 1 yield
            no ngrams.
        exclude_indices : iterable of int, optional (keyword-only)
            Token ids that disqualify any ngram containing them (e.g. padding).
            ``None`` excludes nothing.

        Returns
        -------
        Tuple[float, float, float]
            ``(total_recall, total_precision, total_f1)``. When
            ``is_distributed()`` is true, these are summed over all processes.
        """
        from collections import Counter

        excluded = set(exclude_indices) if exclude_indices else set()

        def _count_ngrams(seq):
            # Count contiguous ngrams of one sequence; accepts 1-D tensors or sequences.
            tokens = seq.tolist() if hasattr(seq, "tolist") else list(seq)
            counts = Counter()
            if ngram_size < 1:
                return counts
            for start in range(len(tokens) - ngram_size + 1):
                gram = tuple(tokens[start : start + ngram_size])
                # Any ngram touching an excluded id is dropped entirely.
                if excluded.intersection(gram):
                    continue
                counts[gram] += 1
            return counts

        total_recall = 0.0
        total_precision = 0.0
        total_f1 = 0.0

        for predicted_seq, reference_seq in zip(predicted_tokens, reference_tokens):
            predicted_ngram_counts = _count_ngrams(predicted_seq)
            reference_ngram_counts = _count_ngrams(reference_seq)

            matches = 0
            total_reference_ngrams = 0
            for ngram, count in reference_ngram_counts.items():
                # Counter yields 0 for ngrams the prediction lacks.
                matches += min(predicted_ngram_counts[ngram], count)
                total_reference_ngrams += count

            # matches <= both ngram totals, so a zero overlap also covers empty
            # prediction/reference. Such a pair adds 0 to every sum, and skipping
            # it avoids dividing by zero below.
            if matches == 0:
                continue

            total_predicted_ngrams = sum(predicted_ngram_counts.values())
            recall = matches / total_reference_ngrams
            precision = matches / total_predicted_ngrams
            f1 = 2.0 * recall * precision / (recall + precision)

            # Accumulate stats
            total_recall += recall
            total_precision += precision
            total_f1 += f1

        if is_distributed():
            # One collective on a stacked tensor instead of three separate ones.
            stats = torch.tensor(
                [total_recall, total_precision, total_f1], device=predicted_tokens.device
            )
            dist.all_reduce(stats, op=dist.ReduceOp.SUM)
            total_recall, total_precision, total_f1 = stats.tolist()

        return total_recall, total_precision, total_f1
Beispiel #3
 def _get_modified_precision_counts(self, predicted_tokens, reference_tokens, ngram_size):
     """Return ``(clipped_matches, total_predicted)`` for modified ngram precision.

     Rows of ``predicted_tokens`` and ``reference_tokens`` are paired by position.
     For each pair, every predicted ngram count is clipped by the count that the
     reference row's ngram mapping reports for it. The clipped counts form the
     numerator and the raw predicted counts form the denominator. NOTE(review):
     the reference mapping is presumably a Counter, so absent ngrams count as 0;
     confirm this against ``ngrams``.
     """
     numerator = 0
     denominator = 0
     for hyp_row, ref_row in zip(predicted_tokens, reference_tokens):
         hyp_counts = ngrams(hyp_row, ngram_size, self._exclude_indices)
         ref_counts = ngrams(ref_row, ngram_size, self._exclude_indices)
         numerator += sum(min(n, ref_counts[gram]) for gram, n in hyp_counts.items())
         denominator += sum(hyp_counts.values())
     return numerator, denominator
Beispiel #4
    def test_ngrams(self, device: str):
        """Check ``ngrams`` on a fixed six-token sequence for several ngram sizes.

        The expected counts omit every ngram that contains id 0. This assumes
        0 is in ``self.metric._exclude_indices``. A size larger than the
        sequence must yield no ngrams at all.
        """
        token_ids = torch.tensor([1, 2, 3, 1, 2, 0], device=device)
        excluded = self.metric._exclude_indices

        expected_by_size = {
            1: {(1,): 2, (2,): 2, (3,): 1},
            2: {(1, 2): 2, (2, 3): 1, (3, 1): 1},
            3: {(1, 2, 3): 1, (2, 3, 1): 1, (3, 1, 2): 1},
            # ngram size too big, no ngrams produced.
            7: {},
        }
        for size, expected in expected_by_size.items():
            observed: Counter = Counter(ngrams(token_ids, size, excluded))
            assert observed == expected