def test_hong2014(self):
    """
    Compares ``PythonRouge`` with the reference ``run_rouge`` on the DUC 2004
    reference summaries and the centroid system summaries, using the Porter
    stemmer, keeping stopwords and truncating to 100 words. ROUGE-L is checked
    with a looser tolerance than ROUGE-1/ROUGE-2.
    """
    scorer = PythonRouge()
    references = self._load_multiple_summaries(_duc2004_file_path)
    system_summaries = self._load_summaries(_centroid_file_path)

    shared_options = {
        'use_porter_stemmer': True,
        'remove_stopwords': False,
        'max_words': 100,
        'compute_rouge_l': True,
    }
    expected_metrics = run_rouge(references, system_summaries, max_ngram=2, **shared_options)
    actual_metrics = scorer.run_python_rouge(references, system_summaries, **shared_options)

    tolerances = [
        (R1_PRECISION, 1e-2),
        (R1_RECALL, 2e-2),
        (R1_F1, 2e-2),
        (R2_PRECISION, 1e-2),
        (R2_RECALL, 1e-2),
        (R2_F1, 1e-2),
        # Rouge-L is a little further off, but still reasonably close enough that I'm not worried
        (RL_PRECISION, 1e-1),
        (RL_RECALL, 1e-1),
        (RL_F1, 1e-1),
    ]
    for key, tolerance in tolerances:
        assert math.isclose(expected_metrics[key], actual_metrics[key], abs_tol=tolerance)
def test_preprocessing(self):
    """
    Checks ``PythonRouge.preprocess_summary`` with a 5-word limit and stopword
    removal. The expected output keeps only four words, which implies the word
    limit is applied before the stopword "and" is removed.
    """
    scorer = PythonRouge()
    raw_summary = ['Xu Wenli, Wang Youchai, and Qin Yongmin, leading dissidents and']
    processed = scorer.preprocess_summary(raw_summary, max_words=5, remove_stopwords=True)
    assert [['xu', 'wenli', 'wang', 'youchai']] == processed
def test_python_rouge_metric_abstractive(self):
    """
    Checks that ``PythonRougeMetric`` accumulates the same ROUGE scores that a
    single ``PythonRouge.run_python_rouge`` call produces for abstractive
    (not sentence-tokenized) summaries. The summaries are fed three ways:
    batched strings, batched tensors, and lists of per-summary tensors.
    """
    orders = [1, 2]
    word_limit = 100
    stem = False
    drop_stopwords = True

    reference_rouge = PythonRouge()
    metric = PythonRougeMetric(vocab=self.vocab,
                               ngram_orders=orders,
                               max_words=word_limit,
                               use_porter_stemmer=stem,
                               remove_stopwords=drop_stopwords)
    expected_metrics = reference_rouge.run_python_rouge(
        self.gold_summaries_abs, self.model_summaries_abs,
        ngram_orders=orders, max_words=word_limit,
        use_porter_stemmer=stem, remove_stopwords=drop_stopwords)

    def feed_and_check(batches):
        # Push every (gold, model) batch through the metric, then compare the
        # accumulated (and reset) scores with the single-call reference.
        for gold_batch, model_batch in batches:
            metric(gold_batch, model_batch)
        self._assert_metrics_equal(expected_metrics, metric.get_metric(reset=True))

    # Batched strings
    step = 2
    feed_and_check(
        (self.gold_summaries_abs[start:start + step], self.model_summaries_abs[start:start + step])
        for start in range(0, len(self.gold_summaries_abs), step))

    # Batched tensors
    feed_and_check(zip(self.gold_summaries_tensors_abs, self.model_summaries_tensors_abs))

    # Tensors split into Python lists of per-summary tensors
    feed_and_check(
        (list(gold_tensor), list(model_tensor))
        for gold_tensor, model_tensor in zip(self.gold_summaries_tensors_abs,
                                             self.model_summaries_tensors_abs))
def __init__(self,
             vocab: Vocabulary,
             ngram_orders: Union[int, List[int]],
             max_sentences: Optional[int] = None,
             max_words: Optional[int] = None,
             max_bytes: Optional[int] = None,
             use_porter_stemmer: bool = True,
             remove_stopwords: bool = False,
             namespace: str = 'tokens') -> None:
    """
    Stores the ROUGE configuration, creates the ``PythonRouge`` scorer, and
    caches the vocabulary indices of the special symbols used when converting
    index tensors back into token strings. A symbol absent from the vocabulary
    namespace is recorded as ``None``.
    """
    super().__init__()
    self.ngram_orders = [ngram_orders] if isinstance(ngram_orders, int) else ngram_orders
    self.max_sentences = max_sentences
    self.max_words = max_words
    self.max_bytes = max_bytes
    self.use_porter_stemmer = use_porter_stemmer
    self.remove_stopwords = remove_stopwords
    self.python_rouge = PythonRouge()
    self.vocab = vocab
    self.namespace = namespace

    token_to_index = vocab.get_token_to_index_vocabulary(namespace)

    def index_or_none(symbol):
        # Only read the index if the symbol really exists; otherwise we would end
        # up with the OOV symbol, which must not be skipped when converting from
        # indices to strings.
        return token_to_index[symbol] if symbol in token_to_index else None

    self.start_index = index_or_none(START_SYMBOL)
    self.end_index = index_or_none(END_SYMBOL)
    self.pad_index = index_or_none(DEFAULT_PADDING_TOKEN)
    self.sent_start_index = index_or_none(SENT_START_SYMBOL)
    self.sent_end_index = index_or_none(SENT_END_SYMBOL)

    # Running instance count and per-metric weighted totals.
    self.count = 0
    self.totals = {}
def test_normalize_and_tokenize(self):
    """
    Verifies that the Python ROUGE tokenizer reproduces the perl implementation
    for every combination of stemming and stopword removal. The expected tokens
    were produced by instrumenting the perl script to print its tokens. The
    hyphens in the input exercise the perl behavior of treating them as word
    separators.
    """
    rouge = PythonRouge()
    original = 'Xu Wenli, Wang Youchai, and Qin Yongmin, leading dissidents and '\
               'prominent members of the China-Democracy-Party, were found guilty of subversion ' \
               'and sentenced to 13, 11, and 12 years in prison, respectively.'

    # (use_porter_stemmer, remove_stopwords, expected whitespace-separated tokens)
    cases = [
        (False, False,
         'xu wenli wang youchai and qin yongmin leading dissidents and ' \
         'prominent members of the china democracy party were found guilty of subversion ' \
         'and sentenced to 13 11 and 12 years in prison respectively'),
        (True, False,
         'xu wenli wang youchai and qin yongmin lead dissid and promin '\
         'member of the china democraci parti be find guilti of subvers and sentenc '\
         'to 13 11 and 12 year in prison respect'),
        (False, True,
         'xu wenli wang youchai qin yongmin leading dissidents prominent '\
         'members china democracy party found guilty subversion sentenced 13 11 '\
         '12 years prison'),
        (True, True,
         'xu wenli wang youchai qin yongmin lead dissid promin member china '\
         'democraci parti find guilti subvers sentenc 13 11 12 year prison'),
    ]
    for stem, drop_stopwords, expected_text in cases:
        expected = expected_text.split()
        actual = rouge.normalize_and_tokenize_sentence(original,
                                                       use_porter_stemmer=stem,
                                                       remove_stopwords=drop_stopwords)
        assert expected == actual
def main(args):
    """
    Reads instances from ``args.input_jsonl`` and hands them, in batches of
    ``_BATCH_SIZE``, to ``_process_batch``. That call receives a shared joblib
    ``Parallel`` pool with ``args.num_cores`` jobs, one ``PythonRouge`` scorer,
    and the output writer for ``args.output_jsonl``.
    """
    scorer = PythonRouge()
    with JsonlWriter(args.output_jsonl) as out, \
            JsonlReader(args.input_jsonl) as f, \
            Parallel(n_jobs=args.num_cores) as parallel:
        pending = []
        for instance in tqdm(f):
            pending.append(instance)
            if len(pending) == _BATCH_SIZE:
                _process_batch(parallel, pending, scorer, out)
                pending.clear()
        # Flush the final, partially filled batch (if any).
        if pending:
            _process_batch(parallel, pending, scorer, out)
def main(args):
    """
    Labels every instance of ``args.input_jsonl`` with the sentence indices of
    its greedy oracle summary (from ``get_greedy_oracle_summary``) and writes
    the augmented instances to ``args.output_jsonl``.
    """
    # One shared scorer so the external ROUGE resources are not reloaded per instance.
    scorer = PythonRouge()
    with JsonlWriter(args.output_jsonl) as out, JsonlReader(args.input_jsonl) as f:
        for instance in f:
            _, oracle_labels = get_greedy_oracle_summary(
                instance['document'],
                instance['summary'],
                args.metric,
                max_sentences=args.max_sentences,
                max_tokens=args.max_tokens,
                max_bytes=args.max_bytes,
                use_porter_stemmer=args.use_stemmer,
                remove_stopwords=args.remove_stopwords,
                python_rouge=scorer)
            instance['labels'] = oracle_labels
            out.write(instance)
def main(args):
    """
    Computes, for each instance of ``args.input_jsonl``, the greedy oracle
    extract of ``instance['document']`` against its single cloze reference.

    With ``args.cloze_only`` each output line is just ``{'cloze': <oracle
    sentences joined by spaces>}``. Otherwise the instance is written back with
    the selected sentence indices stored under ``'labels'``.
    """
    # One shared scorer so the external ROUGE resources are not reloaded per instance.
    scorer = PythonRouge()
    with JsonlWriter(args.output_jsonl) as out, JsonlReader(args.input_jsonl) as f:
        for instance in f:
            oracle_sentences, oracle_labels = get_greedy_oracle_summary(
                instance['document'],
                [instance['cloze']],
                args.metric,
                max_sentences=args.max_sentences,
                max_tokens=args.max_tokens,
                max_bytes=args.max_bytes,
                use_porter_stemmer=args.use_stemmer,
                remove_stopwords=args.remove_stopwords,
                python_rouge=scorer)
            if not args.cloze_only:
                instance['labels'] = oracle_labels
                out.write(instance)
            else:
                out.write({'cloze': ' '.join(oracle_sentences)})
def test_python_rouge(self):
    """
    Checks ``PythonRouge.run_python_rouge`` against the reference ``run_rouge``
    on one instance: a three-sentence system summary scored against four
    multi-sentence reference summaries, with ROUGE-L enabled. Every combination
    of stemming and stopword removal is tested, and all nine
    precision/recall/F1 scores must agree to two decimal places.
    """
    scorer = PythonRouge()
    system_sentences = [
        "His tenacity holds despite the summary trials and harsh punishments for Xu, Wang Youcai and Qin Yongmin prominent party principals from the provinces who were sentenced to 11 and 12 years and despite threatening signs from the ruling Communist Party.",
        "The dissidents Xu Wenli, who was sentenced Monday to 13 years in prison, Wang Youcai, who received an 11-year sentence, and Qin Yongming, who was reported to have received 12 years were charged with subversion.",
        "As police moved against Xu's friends, labor rights campaigner Liu Nianchun was taken from a prison camp outside Beijing and, with his wife and daughter, was put on a plane to Canada and then New York, his first taste of freedom in more than 3 1/2 years."
    ]
    reference_summaries = [
        [
            "While China plans to sign the International Covenant on Civil and Political Rights at the U.N., it is still harassing and arresting human rights campaigners.",
            "Three prominent leaders of the China Democratic Party were put to trial and sentenced to 11-, 12- and 13-year prison terms.",
            "Germany and the U.S. condemned the arrests.",
            "A labor rights activist was released and exiled to the U.S. to blunt any opposition to Communist rule.",
            "U.S. policy to encourage trade and diplomacy in hope of democratic reforms evidences failure, but the U.S. is continuing its policy of encouragement.",
            "Friends of jailed dissidents state that they will continue to campaign for change."
        ],
        [
            "The US trade-driven policy of expanded ties encouraging Chinese democracy is questioned.",
            "China signed rights treaties and dissidents used new laws to set up China Democracy Party, but China violates the new laws by persecuting dissidents.",
            "It regularly frees activists from prison then exiles them so they lose local influence.",
            "It arrested an activist trying to register a rights monitoring group.",
            "CP leader Jiang's hard-line speech and publicity for activists sentenced to long prison terms signals a renewed Chinese crackdown.",
            "A rights activist expected to be sacrificed in the cause of democracy.",
            "Germany called China's sentencing of dissidents unacceptable."
        ],
        [
            "After 2 years of wooing the West by signing international accords, apparently relaxing controls on free speech, and releasing and exiling three dissenters, China cracked down against political dissent in Dec 1998.",
            "Leaders of the China Democracy Party (CDP) were arrested and three were sentenced to jail terms of 11 to 13 years.",
            "The West, including the US, UK and Germany, reacted strongly.",
            "Clinton's China policy of engagement was questioned.",
            "China's Jiang Zemin stated economic reform is not a prelude to democracy and vowed to crush any challenges to the Communist Party or \"social stability\".",
            "The CDP vowed to keep working, as more leaders awaited arrest."
        ],
        [
            "Xu Wenli, Wang Youchai, and Qin Yongmin, leading dissidents and prominent members of the China Democracy Party, were found guilty of subversion and sentenced to 13, 11, and 12 years in prison, respectively.",
            "Soon after the sentencing, China's president, Jiang Zemin, delivered speeches in which he asserted that Western political system must not be adopted and vowed to crush challenges to Communist Party rule.",
            "The harsh sentences and speeches signal a crackdown on dissent, but Zha Jianguo, another Democracy Party leader, says he will continue to push for change.",
            "Western nations condemned the sentences as violations of U.N. rights treaties signed by China."
        ]
    ]
    with_rouge_l = True
    checked_keys = [
        R1_PRECISION, R1_RECALL, R1_F1,
        R2_PRECISION, R2_RECALL, R2_F1,
        RL_PRECISION, RL_RECALL, RL_F1,
    ]
    # (use_porter_stemmer, remove_stopwords)
    configurations = [(False, False), (False, True), (True, False), (True, True)]
    for stem, drop_stopwords in configurations:
        expected_metrics = run_rouge([reference_summaries], [system_sentences],
                                     use_porter_stemmer=stem,
                                     remove_stopwords=drop_stopwords,
                                     max_ngram=2,
                                     compute_rouge_l=with_rouge_l)
        actual_metrics = scorer.run_python_rouge([reference_summaries], [system_sentences],
                                                 use_porter_stemmer=stem,
                                                 remove_stopwords=drop_stopwords,
                                                 compute_rouge_l=with_rouge_l)
        for key in checked_keys:
            self.assertAlmostEqual(expected_metrics[key], actual_metrics[key], places=2)
class PythonRougeMetric(Metric):
    """
    The ``PythonRougeMetric`` is an implementation of ROUGE that can be used
    while training models. The metrics computed over several different batches
    should be identical to calling the ``PythonRouge`` class once for the entire
    set of summaries.

    Parameters
    ----------
    vocab: ``Vocabulary``
        The vocabulary used to convert token indices back into strings when the
        summaries are passed as tensors.
    ngram_orders: ``Union[int, List[int]]``
        The n-gram orders that should be computed. This should be the minimum
        number required for the fastest computation.
    max_sentences: ``int``, optional (default = ``None``)
        The maximum number of sentences to use for ROUGE. If ``None``, all are used.
    max_words: ``int``, optional (default = ``None``)
        The maximum number of words to use for ROUGE. If ``None``, all are used.
    max_bytes: ``int``, optional (default = ``None``)
        The maximum number of bytes to use for ROUGE. If ``None``, all are used.
    use_porter_stemmer: ``bool``, optional (default = ``True``)
        Indicates if the Porter Stemmer should be used.
    remove_stopwords: ``bool``, optional (default = ``False``)
        Indicates if stopwords should be removed from the summaries.
    namespace: ``str``, optional (default = ``"tokens"``)
        The summary token namespace.
    """
    def __init__(self,
                 vocab: Vocabulary,
                 ngram_orders: Union[int, List[int]],
                 max_sentences: Optional[int] = None,
                 max_words: Optional[int] = None,
                 max_bytes: Optional[int] = None,
                 use_porter_stemmer: bool = True,
                 remove_stopwords: bool = False,
                 namespace: str = 'tokens') -> None:
        super().__init__()
        if isinstance(ngram_orders, int):
            ngram_orders = [ngram_orders]
        self.ngram_orders = ngram_orders
        self.max_sentences = max_sentences
        self.max_words = max_words
        self.max_bytes = max_bytes
        self.use_porter_stemmer = use_porter_stemmer
        self.remove_stopwords = remove_stopwords
        self.python_rouge = PythonRouge()
        self.vocab = vocab
        self.namespace = namespace

        vocab_tokens = vocab.get_token_to_index_vocabulary(namespace)

        # Extract the special tokens from the vocabulary. We need to check and
        # ensure each one exists, otherwise we would get the OOV symbol, which
        # we don't want to skip when converting from indices to strings.
        self.start_index = None
        if START_SYMBOL in vocab_tokens:
            self.start_index = vocab_tokens[START_SYMBOL]
        self.end_index = None
        if END_SYMBOL in vocab_tokens:
            self.end_index = vocab_tokens[END_SYMBOL]
        self.pad_index = None
        if DEFAULT_PADDING_TOKEN in vocab_tokens:
            self.pad_index = vocab_tokens[DEFAULT_PADDING_TOKEN]
        self.sent_start_index = None
        if SENT_START_SYMBOL in vocab_tokens:
            self.sent_start_index = vocab_tokens[SENT_START_SYMBOL]
        self.sent_end_index = None
        if SENT_END_SYMBOL in vocab_tokens:
            self.sent_end_index = vocab_tokens[SENT_END_SYMBOL]

        # Number of instances seen, and per-metric sums of ``value * batch size``.
        # ``get_metric`` divides the sums by the count to get instance-weighted means.
        self.count = 0
        self.totals: Dict[str, float] = {}

    def _get_string_from_tensor(self, tensor: torch.Tensor) -> List[str]:
        """
        Converts a 1-dimensional tensor of token indices into its list of tokens.

        The start, sentence-start and sentence-end symbols are skipped, and
        conversion stops at the first end or padding symbol. Any of these
        indices may be ``None`` (symbol not in the vocabulary); ``None`` never
        equals a token index, so it simply never matches.
        """
        assert tensor.dim() == 1
        skip_indices = {self.start_index, self.sent_start_index, self.sent_end_index}
        stop_indices = {self.end_index, self.pad_index}
        tokens = []
        # ``tolist`` converts the whole tensor to Python numbers at once instead
        # of one ``.item()`` call (a device synchronization on GPU) per token.
        for index in tensor.tolist():
            if index in skip_indices:
                continue
            if index in stop_indices:
                break
            tokens.append(self.vocab.get_token_from_index(index, self.namespace))
        return tokens

    def _convert_to_strings(
        self,
        summaries: Union[List[str], List[List[str]], List[torch.Tensor], torch.Tensor]
    ) -> List[List[str]]:
        """
        Converts the summaries into ``List[List[str]]``, where each individual
        summary is a ``List[str]``. Abstractive summaries will be a list of
        length 1. An empty list of summaries converts to an empty list.
        """
        # If the inner-most element is a string, these are already ok and they
        # just might need to be added to a new dimension (for abstractive).
        # Otherwise, we need to convert them from tensors into strings.
        if isinstance(summaries, list):
            if not summaries:
                return []
            if isinstance(summaries[0], list):
                if isinstance(summaries[0][0], str):
                    # Extractive strings
                    return summaries
            elif isinstance(summaries[0], str):
                # Abstractive strings
                return [[summary] for summary in summaries]
        # Otherwise, this is a tensor (or a list of tensors)
        return self._convert_tensor_to_strings(summaries)

    def _convert_tensor_to_strings(
        self,
        summaries: Union[List[torch.Tensor], torch.Tensor]
    ) -> List[List[str]]:
        """
        Converts the summaries represented as tensors to ``List[List[str]]``,
        where each summary is a ``List[str]``.

        Raises ``Exception`` if a summary has no tokens left after conversion,
        or if a summary is not 1- or 2-dimensional.
        """
        # If the summaries have 2 dimensions, they are assumed to be (batch_size, num_tokens)
        # objects that have the tokens as a single sequence. This is generally done for
        # abstractive summarization. If the summaries have 3 dimensions, they are
        # assumed to be of size (batch_size, num_sents, num_tokens), which is common
        # with extractive summarization.
        summaries_strings = []
        for summary in summaries:
            if summary.dim() == 1:
                # Abstractive
                tokens = self._get_string_from_tensor(summary)
                if not tokens:
                    raise Exception(f'Summary has no tokens: {summary}')
                summaries_strings.append([' '.join(tokens)])
            elif summary.dim() == 2:
                # Extractive: sentences that end up empty are dropped.
                sentence_strings = []
                for sentence in summary:
                    tokens = self._get_string_from_tensor(sentence)
                    if tokens:
                        sentence_strings.append(' '.join(tokens))
                if not sentence_strings:
                    raise Exception(f'Summary has no tokens: {summary}')
                summaries_strings.append(sentence_strings)
            else:
                raise Exception(f'Summaries must be 1- or 2-dimensional, '
                                f'found {summary.dim()} dimensions')
        return summaries_strings

    @overrides
    def __call__(self,
                 gold_summaries: Union[List[str], List[List[str]], List[torch.Tensor], torch.Tensor],
                 model_summaries: Union[List[str], List[List[str]], List[torch.Tensor], torch.Tensor],
                 **kwargs) -> None:
        """
        Computes ROUGE based on the batched input summaries. The summaries can be
        represented with several different types, depending on the use case.

        - ``List[str]``: Each summary is a ``str``, used with abstractive summaries
        - ``List[List[str]]``: Each summary is a ``List[str]``, used with extractive summaries
        - ``List[torch.Tensor]``: Each summary is a ``torch.Tensor`` that is either 1- or
          2-dimensional for abstractive or extractive summaries, respectively
        - ``torch.Tensor``: Each summary is a ``torch.Tensor``. If the input is 2-dimensional,
          each row is assumed to be a 1-dimensional abstractive summary. If the input is
          3-dimensional, each matrix is assumed to be a 2-dimensional extractive summary.

        An empty batch is a no-op. Extra keyword arguments are accepted and ignored.

        Parameters
        ----------
        gold_summaries: ``Union[List[str], List[List[str]], List[torch.Tensor], torch.Tensor]``
            See above.
        model_summaries: ``Union[List[str], List[List[str]], List[torch.Tensor], torch.Tensor]``
            See above.

        Raises
        ------
        ValueError
            If the numbers of gold and model summaries differ.
        """
        gold_summaries = self._convert_to_strings(gold_summaries)
        model_summaries = self._convert_to_strings(model_summaries)
        if len(gold_summaries) != len(model_summaries):
            # The batch is weighted by ``len(gold_summaries)`` below, so a mismatch
            # would silently corrupt the running averages.
            raise ValueError(f'Expected the same number of gold and model summaries, '
                             f'got {len(gold_summaries)} and {len(model_summaries)}')
        if not gold_summaries:
            # An empty batch contributes nothing to the running totals.
            return

        metrics = self.python_rouge.run_python_rouge(
            gold_summaries, model_summaries,
            ngram_orders=self.ngram_orders,
            max_sentences=self.max_sentences,
            max_words=self.max_words,
            max_bytes=self.max_bytes,
            use_porter_stemmer=self.use_porter_stemmer,
            remove_stopwords=self.remove_stopwords)

        # Weight each batch's scores by its size so the final value equals the
        # score over all instances, regardless of how they were batched.
        num_instances = len(gold_summaries)
        self.count += num_instances
        for metric, value in metrics.items():
            if metric not in self.totals:
                self.totals[metric] = value * num_instances
            else:
                self.totals[metric] += value * num_instances

    @overrides
    def get_metric(self, reset: bool = False) -> Dict[str, float]:
        """
        Returns the instance-weighted average of every accumulated metric, and
        optionally resets the accumulators. Returns an empty dict if nothing has
        been accumulated.
        """
        metrics = {metric: value / self.count for metric, value in self.totals.items()}
        if reset:
            self.reset()
        return metrics

    @overrides
    def reset(self) -> None:
        """Clears the accumulated instance count and metric totals."""
        self.count = 0
        self.totals.clear()
def get_greedy_oracle_summary(document: List[str],
                              summary: Union[List[str], List[List[str]]],
                              metric: str,
                              max_sentences: Optional[int] = None,
                              max_tokens: Optional[int] = None,
                              max_bytes: Optional[int] = None,
                              use_porter_stemmer: bool = True,
                              remove_stopwords: bool = False,
                              python_rouge: Optional[PythonRouge] = None) -> Tuple[List[str], List[int]]:
    """
    Computes the greedy oracle summary by selecting sentences from the input
    document while greedily increasing the metric until the summary budget is met.

    Each step scores every remaining sentence together with the sentences
    already selected (kept in document order). The best-scoring sentence is
    added. Selection stops when the best addition does not strictly improve the
    score, or when every sentence has been selected. Ties are broken in favor of
    the sentence with the lowest index.

    ``max_sentences``, ``max_tokens`` and ``max_bytes`` are forwarded to
    ``PythonRouge.run_python_rouge`` as the ROUGE length limits. Normally
    exactly one of them is set, but this is not validated here.

    Parameters
    ----------
    document:
        The sentence-tokenized input document from which to extract the summary.
    summary:
        The ground-truth summary (``List[str]``) or summaries (``List[List[str]]``).
    metric:
        The name of the metric to greedily optimize. The metrics currently
        supported are the ROUGE-1/2/3/4 and ROUGE-L precision, recall and F1
        metrics (see `rouge.py`).
    max_sentences:
        The maximum number of allowed sentences to take.
    max_tokens:
        The maximum number of allowed tokens to take.
    max_bytes:
        The maximum number of allowed bytes to take.
    use_porter_stemmer:
        Whether ROUGE should apply the Porter stemmer.
    remove_stopwords:
        Whether ROUGE should remove stopwords.
    python_rouge:
        The PythonRouge object to use to compute the metrics, useful to avoid
        reloading the external resources on each call. A new one is created if
        ``None``.

    Returns
    -------
    The summary (``List[str]``) and the corresponding sentence indices, in
    ascending order, which were selected from the input document (``List[int]``).

    Raises
    ------
    Exception
        If ``metric`` is not one of the supported metric names.
    """
    if python_rouge is None:
        python_rouge = PythonRouge()

    if metric in [R1_RECALL, R1_PRECISION, R1_F1]:
        ngram_orders = [1]
        compute_rouge_l = False
    elif metric in [R2_RECALL, R2_PRECISION, R2_F1]:
        ngram_orders = [2]
        compute_rouge_l = False
    elif metric in [R3_RECALL, R3_PRECISION, R3_F1]:
        ngram_orders = [3]
        compute_rouge_l = False
    elif metric in [R4_RECALL, R4_PRECISION, R4_F1]:
        ngram_orders = [4]
        compute_rouge_l = False
    elif metric in [RL_RECALL, RL_PRECISION, RL_F1]:
        ngram_orders = []
        compute_rouge_l = True
    else:
        raise Exception(f'Unknown metric: {metric}')

    candidates = set(range(len(document)))
    selected: List[int] = []
    current_score = None
    while candidates:
        best_index, best_score = None, None
        # Visit candidates in ascending order so that ties (strict ``>`` below)
        # deterministically go to the earliest sentence rather than depending
        # on set iteration order.
        for candidate in sorted(candidates):
            candidate_summary = [document[i] for i in sorted(selected + [candidate])]
            metrics = python_rouge.run_python_rouge([summary], [candidate_summary],
                                                    ngram_orders=ngram_orders,
                                                    max_sentences=max_sentences,
                                                    max_words=max_tokens,
                                                    max_bytes=max_bytes,
                                                    use_porter_stemmer=use_porter_stemmer,
                                                    remove_stopwords=remove_stopwords,
                                                    compute_rouge_l=compute_rouge_l)
            score = metrics[metric]
            if best_score is None or score > best_score:
                best_index, best_score = candidate, score

        # Stop once the best possible addition no longer strictly improves the score.
        if current_score is not None and not best_score > current_score:
            break
        current_score = best_score
        selected.append(best_index)
        candidates.remove(best_index)

    selected = sorted(selected)
    oracle_summary = [document[i] for i in selected]
    return oracle_summary, selected