def test_hong2014(self):
        """
        Compares ``PythonRouge`` with the ``run_rouge`` reference implementation on
        the multi-reference summaries from ``_duc2004_file_path`` and the system
        summaries from ``_centroid_file_path``. Stemming is on, stopwords are kept,
        and summaries are limited to 100 words.
        """
        python_rouge = PythonRouge()
        duc2004 = self._load_multiple_summaries(_duc2004_file_path)
        centroid = self._load_summaries(_centroid_file_path)

        # Options passed identically to both implementations.
        shared_options = {
            'use_porter_stemmer': True,
            'remove_stopwords': False,
            'max_words': 100,
            'compute_rouge_l': True,
        }
        expected_metrics = run_rouge(duc2004, centroid, max_ngram=2, **shared_options)
        actual_metrics = python_rouge.run_python_rouge(duc2004, centroid, **shared_options)

        # (metric key, absolute tolerance), checked in this order.
        tolerances = [
            (R1_PRECISION, 1e-2), (R1_RECALL, 2e-2), (R1_F1, 2e-2),
            (R2_PRECISION, 1e-2), (R2_RECALL, 1e-2), (R2_F1, 1e-2),
            # Rouge-L is a little further off, but still reasonably close enough that I'm not worried
            (RL_PRECISION, 1e-1), (RL_RECALL, 1e-1), (RL_F1, 1e-1),
        ]
        for name, tolerance in tolerances:
            assert math.isclose(expected_metrics[name], actual_metrics[name], abs_tol=tolerance)
 def test_preprocessing(self):
     """
     Checks that ``preprocess_summary`` lowercases, strips punctuation, applies
     the word budget and removes stopwords. Only four tokens survive a five-word
     budget, so the budget is applied before the stopword "and" is dropped.
     """
     python_rouge = PythonRouge()
     summary_sentences = ['Xu Wenli, Wang Youchai, and Qin Yongmin, leading dissidents and']
     processed = python_rouge.preprocess_summary(summary_sentences, max_words=5, remove_stopwords=True)
     assert [['xu', 'wenli', 'wang', 'youchai']] == processed
Beispiel #3
0
    def test_python_rouge_metric_abstractive(self):
        """
        Checks that ``PythonRougeMetric``, fed batch by batch, reproduces the
        scores ``PythonRouge`` computes in one call for abstractive summaries
        (summaries that are not sentence tokenized).

        The abstractive summaries are fed in three forms: batched strings,
        batched tensors, and lists of per-summary tensors.
        """
        rouge_options = {
            'ngram_orders': [1, 2],
            'max_words': 100,
            'use_porter_stemmer': False,
            'remove_stopwords': True,
        }

        python_rouge = PythonRouge()
        metric = PythonRougeMetric(vocab=self.vocab, **rouge_options)
        expected_metrics = python_rouge.run_python_rouge(
            self.gold_summaries_abs,
            self.model_summaries_abs,
            **rouge_options)

        def feed_and_compare(batches):
            # Accumulate every (gold, model) batch, then compare and reset the metric.
            for gold_batch, model_batch in batches:
                metric(gold_batch, model_batch)
            actual = metric.get_metric(reset=True)
            self._assert_metrics_equal(expected_metrics, actual)

        # Batched strings
        batch_size = 2
        feed_and_compare(
            (self.gold_summaries_abs[start:start + batch_size],
             self.model_summaries_abs[start:start + batch_size])
            for start in range(0, len(self.gold_summaries_abs), batch_size))

        # Batched tensors
        feed_and_compare(zip(self.gold_summaries_tensors_abs,
                             self.model_summaries_tensors_abs))

        # Tensors split into plain Python lists of per-summary tensors
        feed_and_compare(
            (list(gold_tensor), list(model_tensor))
            for gold_tensor, model_tensor in zip(self.gold_summaries_tensors_abs,
                                                 self.model_summaries_tensors_abs))
    def __init__(self,
                 vocab: Vocabulary,
                 ngram_orders: Union[int, List[int]],
                 max_sentences: Optional[int] = None,
                 max_words: Optional[int] = None,
                 max_bytes: Optional[int] = None,
                 use_porter_stemmer: bool = True,
                 remove_stopwords: bool = False,
                 namespace: str = 'tokens') -> None:
        """
        Records the ROUGE settings and resolves the vocabulary indices of the
        special symbols used when index tensors are turned back into strings.

        A bare ``int`` for ``ngram_orders`` is wrapped in a one-element list.
        A special symbol missing from ``namespace`` gets an index of ``None``.
        """
        super().__init__()
        self.ngram_orders = [ngram_orders] if isinstance(ngram_orders, int) else ngram_orders
        self.max_sentences = max_sentences
        self.max_words = max_words
        self.max_bytes = max_bytes
        self.use_porter_stemmer = use_porter_stemmer
        self.remove_stopwords = remove_stopwords
        self.python_rouge = PythonRouge()

        self.vocab = vocab
        self.namespace = namespace
        token_to_index = vocab.get_token_to_index_vocabulary(namespace)

        def index_or_none(symbol):
            # Check membership explicitly: a vocabulary index lookup would return
            # the OOV index for a missing symbol, and OOV tokens must not be skipped.
            return token_to_index[symbol] if symbol in token_to_index else None

        self.start_index = index_or_none(START_SYMBOL)
        self.end_index = index_or_none(END_SYMBOL)
        self.pad_index = index_or_none(DEFAULT_PADDING_TOKEN)
        self.sent_start_index = index_or_none(SENT_START_SYMBOL)
        self.sent_end_index = index_or_none(SENT_END_SYMBOL)

        # Running instance count and per-metric totals, starting empty.
        self.count = 0
        self.totals = {}
    def test_normalize_and_tokenize(self):
        """
        Tests that the Python version of Rouge tokenizes exactly like the Perl
        implementation. The expected tokens were produced by editing the Perl
        script to print its tokens during processing. Hyphens were added to the
        input sentence because the Perl implementation uses them to separate words.
        All four combinations of stemming and stopword removal are covered.
        """
        rouge = PythonRouge()
        original = ('Xu Wenli, Wang Youchai, and Qin Yongmin, leading dissidents and '
                    'prominent members of the China-Democracy-Party, were found guilty of subversion '
                    'and sentenced to 13, 11, and 12 years in prison, respectively.')

        # (use_porter_stemmer, remove_stopwords, expected whitespace-joined tokens)
        cases = [
            (False, False,
             'xu wenli wang youchai and qin yongmin leading dissidents and '
             'prominent members of the china democracy party were found guilty of subversion '
             'and sentenced to 13 11 and 12 years in prison respectively'),
            (True, False,
             'xu wenli wang youchai and qin yongmin lead dissid and promin '
             'member of the china democraci parti be find guilti of subvers and sentenc '
             'to 13 11 and 12 year in prison respect'),
            (False, True,
             'xu wenli wang youchai qin yongmin leading dissidents prominent '
             'members china democracy party found guilty subversion sentenced 13 11 '
             '12 years prison'),
            (True, True,
             'xu wenli wang youchai qin yongmin lead dissid promin member china '
             'democraci parti find guilti subvers sentenc 13 11 12 year prison'),
        ]
        for stem, drop_stopwords, expected_text in cases:
            actual = rouge.normalize_and_tokenize_sentence(original,
                                                           use_porter_stemmer=stem,
                                                           remove_stopwords=drop_stopwords)
            assert expected_text.split() == actual
Beispiel #6
0
def main(args):
    """
    Streams instances from ``args.input_jsonl`` and passes them to
    ``_process_batch`` in groups of ``_BATCH_SIZE``, followed by any final
    partial group. Each call also receives a ``Parallel`` pool with
    ``args.num_cores`` jobs, a shared ``PythonRouge`` and the
    ``args.output_jsonl`` writer (presumably ``_process_batch`` writes the
    results -- confirm against its definition).
    """
    python_rouge = PythonRouge()
    with JsonlWriter(args.output_jsonl) as out, \
            JsonlReader(args.input_jsonl) as f, \
            Parallel(n_jobs=args.num_cores) as parallel:
        pending = []
        for instance in tqdm(f):
            pending.append(instance)
            if len(pending) != _BATCH_SIZE:
                continue
            _process_batch(parallel, pending, python_rouge, out)
            pending.clear()

        # Flush whatever did not fill a complete batch.
        if pending:
            _process_batch(parallel, pending, python_rouge, out)
Beispiel #7
0
def main(args):
    """
    For each instance in ``args.input_jsonl``, runs the greedy ROUGE oracle over
    ``instance['document']`` against ``instance['summary']``. The selected
    sentence indices are stored under ``'labels'`` and the instance is written
    to ``args.output_jsonl``.
    """
    python_rouge = PythonRouge()

    with JsonlWriter(args.output_jsonl) as out, JsonlReader(args.input_jsonl) as f:
        for instance in f:
            # Only the selected indices are kept; the oracle text is discarded.
            _, instance['labels'] = get_greedy_oracle_summary(
                instance['document'],
                instance['summary'],
                args.metric,
                max_sentences=args.max_sentences,
                max_tokens=args.max_tokens,
                max_bytes=args.max_bytes,
                use_porter_stemmer=args.use_stemmer,
                remove_stopwords=args.remove_stopwords,
                python_rouge=python_rouge)
            out.write(instance)
Beispiel #8
0
def main(args):
    """
    For each instance in ``args.input_jsonl``, runs the greedy ROUGE oracle over
    ``instance['document']``, using the single ``instance['cloze']`` as the
    reference.

    With ``args.cloze_only``, only ``{'cloze': <space-joined oracle sentences>}``
    is written. Otherwise the selected sentence indices are stored under
    ``'labels'`` and the whole instance is written to ``args.output_jsonl``.
    """
    python_rouge = PythonRouge()

    with JsonlWriter(args.output_jsonl) as out, JsonlReader(args.input_jsonl) as f:
        for instance in f:
            oracle, labels = get_greedy_oracle_summary(
                instance['document'],
                [instance['cloze']],
                args.metric,
                max_sentences=args.max_sentences,
                max_tokens=args.max_tokens,
                max_bytes=args.max_bytes,
                use_porter_stemmer=args.use_stemmer,
                remove_stopwords=args.remove_stopwords,
                python_rouge=python_rouge)
            if not args.cloze_only:
                instance['labels'] = labels
                out.write(instance)
                continue
            out.write({'cloze': ' '.join(oracle)})
    def test_python_rouge(self):
        """
        Checks that ``PythonRouge`` matches the ``run_rouge`` reference
        implementation to two decimal places. The comparison covers ROUGE-1,
        ROUGE-2 and ROUGE-L precision/recall/F1 on one instance with four
        reference summaries, under every combination of Porter stemming and
        stopword removal.
        """
        python_rouge = PythonRouge()
        summary = [
            "His tenacity holds despite the summary trials and harsh punishments for Xu, Wang Youcai and Qin Yongmin prominent party principals from the provinces who were sentenced to 11 and 12 years and despite threatening signs from the ruling Communist Party.",
            "The dissidents Xu Wenli, who was sentenced Monday to 13 years in prison, Wang Youcai, who received an 11-year sentence, and Qin Yongming, who was reported to have received 12 years were charged with subversion.",
            "As police moved against Xu's friends, labor rights campaigner Liu Nianchun was taken from a prison camp outside Beijing and, with his wife and daughter, was put on a plane to Canada and then New York, his first taste of freedom in more than 3 1/2 years."
        ]
        gold_summaries = [
            [
                "While China plans to sign the International Covenant on Civil and Political Rights at the U.N., it is still harassing and arresting human rights campaigners.",
                "Three prominent leaders of the China Democratic Party were put to trial and sentenced to 11-, 12- and 13-year prison terms.",
                "Germany and the U.S. condemned the arrests.",
                "A labor rights activist was released and exiled to the U.S. to blunt any opposition to Communist rule.",
                "U.S. policy to encourage trade and diplomacy in hope of democratic reforms evidences failure, but the U.S. is continuing its policy of encouragement.",
                "Friends of jailed dissidents state that they will continue to campaign for change."
            ],
            [
                "The US trade-driven policy of expanded ties encouraging Chinese democracy is questioned.",
                "China signed rights treaties and dissidents used new laws to set up China Democracy Party, but China violates the new laws by persecuting dissidents.",
                "It regularly frees activists from prison then exiles them so they lose local influence.",
                "It arrested an activist trying to register a rights monitoring group.",
                "CP leader Jiang's hard-line speech and publicity for activists sentenced to long prison terms signals a renewed Chinese crackdown.",
                "A rights activist expected to be sacrificed in the cause of democracy.",
                "Germany called China's sentencing of dissidents unacceptable."
            ],
            [
                "After 2 years of wooing the West by signing international accords, apparently relaxing controls on free speech, and releasing and exiling three dissenters, China cracked down against political dissent in Dec 1998.",
                "Leaders of the China Democracy Party (CDP) were arrested and three were sentenced to jail terms of 11 to 13 years.",
                "The West, including the US, UK and Germany, reacted strongly.",
                "Clinton's China policy of engagement was questioned.",
                "China's Jiang Zemin stated economic reform is not a prelude to democracy and vowed to crush any challenges to the Communist Party or \"social stability\".",
                "The CDP vowed to keep working, as more leaders awaited arrest."
            ],
            [
                "Xu Wenli, Wang Youchai, and Qin Yongmin, leading dissidents and prominent members of the China Democracy Party, were found guilty of subversion and sentenced to 13, 11, and 12 years in prison, respectively.",
                "Soon after the sentencing, China's president, Jiang Zemin, delivered speeches in which he asserted that Western political system must not be adopted and vowed to crush challenges to Communist Party rule.",
                "The harsh sentences and speeches signal a crackdown on dissent, but Zha Jianguo, another Democracy Party leader, says he will continue to push for change.",
                "Western nations condemned the sentences as violations of U.N. rights treaties signed by China."
            ]
        ]

        compute_rouge_l = True
        # Every metric key is compared, in this order, for each configuration.
        metric_keys = [
            R1_PRECISION, R1_RECALL, R1_F1,
            R2_PRECISION, R2_RECALL, R2_F1,
            RL_PRECISION, RL_RECALL, RL_F1,
        ]
        # (use_porter_stemmer, remove_stopwords)
        configurations = [(False, False), (False, True), (True, False), (True, True)]

        for use_stemmer, remove_stopwords in configurations:
            expected_metrics = run_rouge([gold_summaries], [summary],
                                         use_porter_stemmer=use_stemmer, remove_stopwords=remove_stopwords,
                                         max_ngram=2, compute_rouge_l=compute_rouge_l)
            actual_metrics = python_rouge.run_python_rouge([gold_summaries], [summary],
                                                           use_porter_stemmer=use_stemmer,
                                                           remove_stopwords=remove_stopwords,
                                                           compute_rouge_l=compute_rouge_l)
            for key in metric_keys:
                self.assertAlmostEqual(expected_metrics[key], actual_metrics[key], places=2)
class PythonRougeMetric(Metric):
    """
    The ``PythonRougeMetric`` is an implementation of ROUGE that can be used
    while training models. Each batch's scores are accumulated, weighted by the
    number of instances in the batch. The metrics computed over several
    different batches should therefore be identical to calling the
    ``PythonRouge`` class once for the entire set of summaries.

    Parameters
    ----------
    vocab: ``Vocabulary``
        Vocabulary used to map token indices back to strings when summaries are
        passed as tensors.
    ngram_orders: ``Union[int, List[int]]``
        The n-gram orders to compute. A single ``int`` is treated as a
        one-element list. List only the orders you need, since fewer orders
        mean faster computation.
    max_sentences: ``int``, optional (default = ``None``)
        The maximum number of sentences to use for ROUGE. If ``None``, all are used.
    max_words: ``int``, optional (default = ``None``)
        The maximum number of words to use for ROUGE. If ``None``, all are used.
    max_bytes: ``int``, optional (default = ``None``)
        The maximum number of bytes to use for ROUGE. If ``None``, all are used.
    use_porter_stemmer: ``bool``, optional (default = ``True``)
        Indicates if the Porter Stemmer should be used.
    remove_stopwords: ``bool``, optional (default = ``False``)
        Indicates if stopwords should be removed from the summaries.
    namespace: ``str``, optional (default = ``"tokens"``)
        The summary token namespace.
    """
    def __init__(self,
                 vocab: Vocabulary,
                 ngram_orders: Union[int, List[int]],
                 max_sentences: Optional[int] = None,
                 max_words: Optional[int] = None,
                 max_bytes: Optional[int] = None,
                 use_porter_stemmer: bool = True,
                 remove_stopwords: bool = False,
                 namespace: str = 'tokens') -> None:
        super().__init__()
        if isinstance(ngram_orders, int):
            ngram_orders = [ngram_orders]
        self.ngram_orders = ngram_orders
        self.max_sentences = max_sentences
        self.max_words = max_words
        self.max_bytes = max_bytes
        self.use_porter_stemmer = use_porter_stemmer
        self.remove_stopwords = remove_stopwords
        self.python_rouge = PythonRouge()

        self.vocab = vocab
        self.namespace = namespace
        vocab_tokens = vocab.get_token_to_index_vocabulary(namespace)

        # Extract the special tokens from the vocabulary. Check that each one
        # exists explicitly; otherwise the lookup would yield the OOV symbol,
        # which must not be skipped when converting from indices to strings.
        self.start_index = vocab_tokens[START_SYMBOL] if START_SYMBOL in vocab_tokens else None
        self.end_index = vocab_tokens[END_SYMBOL] if END_SYMBOL in vocab_tokens else None
        self.pad_index = (vocab_tokens[DEFAULT_PADDING_TOKEN]
                          if DEFAULT_PADDING_TOKEN in vocab_tokens else None)
        self.sent_start_index = (vocab_tokens[SENT_START_SYMBOL]
                                 if SENT_START_SYMBOL in vocab_tokens else None)
        self.sent_end_index = (vocab_tokens[SENT_END_SYMBOL]
                               if SENT_END_SYMBOL in vocab_tokens else None)

        # Number of instances seen, and per-metric sums of (score * batch size).
        self.count = 0
        self.totals = {}

    def _get_string_from_tensor(self, tensor: torch.Tensor) -> List[str]:
        """
        Converts a 1-dimensional tensor of token indices into its list of token
        strings. The start, sentence-start and sentence-end symbols are skipped.
        Conversion stops at the first end or padding symbol.
        """
        assert tensor.dim() == 1
        # Any of these may be ``None`` when the symbol is absent from the
        # vocabulary; ``None`` never matches an integer index.
        skip_indices = {self.start_index, self.sent_start_index, self.sent_end_index}
        stop_indices = {self.end_index, self.pad_index}
        tokens = []
        # ``tolist`` converts all indices to Python scalars in one call instead
        # of calling ``.item()`` per element.
        for index in tensor.tolist():
            if index in skip_indices:
                continue
            if index in stop_indices:
                break
            tokens.append(self.vocab.get_token_from_index(index, self.namespace))
        return tokens

    def _convert_to_strings(
        self, summaries: Union[List[str], List[List[str]], List[torch.Tensor],
                               torch.Tensor]
    ) -> List[List[str]]:
        """
        Converts the summaries into ``List[List[str]]``, where each individual
        summary is a ``List[str]``. Abstractive summaries become lists of length 1.
        An empty list converts to an empty list.
        """
        # If the inner-most element is a string, the summaries are already in
        # the right form; abstractive ones only need an extra dimension.
        # Anything else is converted from tensors into strings.
        if isinstance(summaries, list):
            if not summaries:
                return []
            first = summaries[0]
            if isinstance(first, list):
                if isinstance(first[0], str):
                    # Extractive strings
                    return summaries
            elif isinstance(first, str):
                # Abstractive strings
                return [[summary] for summary in summaries]

        # Otherwise, this is a tensor or a list of tensors
        return self._convert_tensor_to_strings(summaries)

    def _convert_tensor_to_strings(
            self, summaries: Union[List[torch.Tensor],
                                   torch.Tensor]) -> List[List[str]]:
        """
        Converts the summaries represented as tensors to ``List[List[str]]``, where
        each summary is a ``List[str]``.

        Raises ``Exception`` if a summary converts to no tokens, or if it is not
        1- or 2-dimensional.
        """
        # Iterating over the batch yields one summary at a time. A 1-dimensional
        # summary holds its tokens as a single sequence (typical for abstractive
        # summarization). A 2-dimensional summary is (num_sents, num_tokens)
        # (typical for extractive summarization).
        summaries_strings = []
        for summary in summaries:
            if summary.dim() == 1:
                # Abstractive
                tokens = self._get_string_from_tensor(summary)
                if not tokens:
                    raise Exception(f'Summary has no tokens: {summary}')
                summaries_strings.append([' '.join(tokens)])
            elif summary.dim() == 2:
                # Extractive: sentences that convert to no tokens are dropped
                sentence_strings = []
                for sentence in summary:
                    tokens = self._get_string_from_tensor(sentence)
                    if tokens:
                        sentence_strings.append(' '.join(tokens))
                if not sentence_strings:
                    raise Exception(f'Summary has no tokens: {summary}')
                summaries_strings.append(sentence_strings)
            else:
                raise Exception(f'Summaries must be 1- or 2-dimensional, '
                                f'found {summary.dim()} dimensions')
        return summaries_strings

    @overrides
    def __call__(self, gold_summaries: Union[List[str], List[List[str]],
                                             List[torch.Tensor], torch.Tensor],
                 model_summaries: Union[List[str], List[List[str]],
                                        List[torch.Tensor], torch.Tensor],
                 **kwargs) -> None:
        """
        Computes ROUGE based on the batched input summaries. The summaries can be represented
        with several different types, depending on the use case.

            - ``List[str]``: Each summary is a ``str``, used with abstractive summaries
            - ``List[List[str]]``: Each summary is a ``List[str]``, used with extractive summaries
            - ``List[torch.Tensor]``: Each summary is a ``torch.Tensor`` that is either 1- or 2-dimensional
                    for abstractive or extractive summaries, respectively
            - ``torch.Tensor``: Each summary is a ``torch.Tensor``. If the input is 2-dimensional, each
                    row is assumed to be a 1-dimensional abstractive summary. If the input is 3-dimensional,
                    each matrix is assumed to be a 2-dimensional extractive summary.

        An empty batch is ignored.

        Parameters
        ----------
        gold_summaries: ``Union[List[str], List[List[str]], List[torch.Tensor], torch.Tensor]``
            See above.
        model_summaries: ``Union[List[str], List[List[str]], List[torch.Tensor], torch.Tensor]``
            See above.
        """
        gold_summaries = self._convert_to_strings(gold_summaries)
        model_summaries = self._convert_to_strings(model_summaries)
        num_instances = len(gold_summaries)
        if num_instances == 0:
            # Nothing to score; leave the running totals untouched.
            return

        metrics = self.python_rouge.run_python_rouge(
            gold_summaries,
            model_summaries,
            ngram_orders=self.ngram_orders,
            max_sentences=self.max_sentences,
            max_words=self.max_words,
            max_bytes=self.max_bytes,
            use_porter_stemmer=self.use_porter_stemmer,
            remove_stopwords=self.remove_stopwords)

        # Weight each batch score by its size. If ``run_python_rouge`` returns
        # per-instance averages (as the class docstring implies), dividing by
        # ``count`` in ``get_metric`` gives the average over all instances.
        self.count += num_instances
        for metric, value in metrics.items():
            self.totals[metric] = self.totals.get(metric, 0) + value * num_instances

    @overrides
    def get_metric(self, reset: bool = False) -> Dict[str, float]:
        """
        Returns each accumulated metric divided by the number of instances seen.
        Returns an empty dict if nothing has been accumulated. Optionally resets
        the accumulators afterwards.
        """
        metrics = {metric: total / self.count for metric, total in self.totals.items()}
        if reset:
            self.reset()
        return metrics

    @overrides
    def reset(self) -> None:
        """Clears the instance count and all accumulated metric totals."""
        self.count = 0
        self.totals.clear()
Beispiel #11
0
def get_greedy_oracle_summary(document: List[str],
                              summary: Union[List[str], List[List[str]]],
                              metric: str,
                              max_sentences: Optional[int] = None,
                              max_tokens: Optional[int] = None,
                              max_bytes: Optional[int] = None,
                              use_porter_stemmer: bool = True,
                              remove_stopwords: bool = False,
                              python_rouge: Optional[PythonRouge] = None) -> Tuple[List[str], List[int]]:
    """
    Computes the greedy oracle summary. Sentences are selected from the input
    document one at a time, each time adding the one that most increases the
    metric, until no remaining sentence improves the score. Exactly one of
    ``max_sentences``, ``max_tokens``, and ``max_bytes`` is expected to be not
    None; this is not checked here.

    Parameters
    ----------
    document:
        The sentence-tokenized input document from which to extract the summary.
    summary:
        The ground-truth summary (``List[str]``) or summaries (``List[List[str]]``)
    metric:
        The name of the metric to greedily optimize. The metrics currently supported are
        the ROUGE-1/2/3/4 and ROUGE-L precision, recall and F1 metrics (see `rouge.py`).
    max_sentences:
        The maximum number of allowed sentences to take.
    max_tokens:
        The maximum number of allowed tokens to take.
    max_bytes:
        The maximum number of allowed bytes to take.
    use_porter_stemmer:
        Whether ROUGE should apply the Porter stemmer.
    remove_stopwords:
        Whether ROUGE should remove stopwords before scoring.
    python_rouge:
        The PythonRouge object to use to compute the metrics. Passing one avoids
        reloading the external resources on each call.

    Returns
    -------
    The summary (``List[str]``) and the corresponding sentence indices which
    were selected from the input document (``List[int]``), both in document order.

    Raises
    ------
    Exception
        If ``metric`` is not one of the supported ROUGE metrics.
    """
    if python_rouge is None:
        python_rouge = PythonRouge()

    if metric in [R1_RECALL, R1_PRECISION, R1_F1]:
        ngram_orders = [1]
        compute_rouge_l = False
    elif metric in [R2_RECALL, R2_PRECISION, R2_F1]:
        ngram_orders = [2]
        compute_rouge_l = False
    elif metric in [R3_RECALL, R3_PRECISION, R3_F1]:
        ngram_orders = [3]
        compute_rouge_l = False
    elif metric in [R4_RECALL, R4_PRECISION, R4_F1]:
        ngram_orders = [4]
        compute_rouge_l = False
    elif metric in [RL_RECALL, RL_PRECISION, RL_F1]:
        ngram_orders = []
        compute_rouge_l = True
    else:
        raise Exception(f'Unknown metric: {metric}')

    # Kept in ascending order so ties are always broken toward the earliest sentence.
    remaining = list(range(len(document)))
    selected: List[int] = []
    current_score = None

    while remaining:
        best_index, best_score = None, None
        for candidate in remaining:
            # Score the candidate together with everything already chosen, in document order.
            candidate_summary = [document[i] for i in sorted(selected + [candidate])]
            metrics = python_rouge.run_python_rouge([summary], [candidate_summary],
                                                    ngram_orders=ngram_orders,
                                                    max_sentences=max_sentences,
                                                    max_words=max_tokens,
                                                    max_bytes=max_bytes,
                                                    use_porter_stemmer=use_porter_stemmer,
                                                    remove_stopwords=remove_stopwords,
                                                    compute_rouge_l=compute_rouge_l)
            score = metrics[metric]
            if best_score is None or score > best_score:
                best_index, best_score = candidate, score

        # Always take the first sentence; afterwards, stop once nothing improves the score.
        if current_score is None or best_score > current_score:
            current_score = best_score
            selected.append(best_index)
            remaining.remove(best_index)
        else:
            break

    selected = sorted(selected)
    oracle_summary = [document[i] for i in selected]
    return oracle_summary, selected