def compute_sari_scores(
        sources,
        predictions,
        target_lists,
        ignore_wikisplit_separators=True
):
    print("compute_sari_scores",compute_sari_scores)
    """Computes SARI scores.

    Wraps the t2t implementation of SARI computation.

    Args:
      sources: List of sources.
      predictions: List of predictions.
      target_lists: List of targets (1 or more per prediction).
      ignore_wikisplit_separators: Whether to ignore "<::::>" tokens, used as
        sentence separators in Wikisplit, when evaluating. For the numbers
        reported in the paper, we accidentally ignored those tokens. Ignoring them
        does not affect the Exact score (since there is almost always a period
        before the separator to indicate the sentence break), but it decreases the
        SARI score (since the Addition score goes down as the model doesn't get
        points for correctly adding <::::> anymore).

    Returns:
      Tuple (SARI score, keep score, addition score, deletion score).
    """
    sari_sum = 0
    keep_sum = 0
    add_sum = 0
    del_sum = 0
    for source, pred, targets in zip(sources, predictions, target_lists):
        if ignore_wikisplit_separators:
            source = re.sub(' <::::> ', ' ', source)
            pred = re.sub(' <::::> ', ' ', pred)
            targets = [re.sub(' <::::> ', ' ', t) for t in targets]
        source_ids = utils.get_token_list(source)
        pred_ids = utils.get_token_list(pred)
        list_of_targets = [utils.get_token_list(t) for t in targets]
        sari, keep, addition, deletion = sari_hook.get_sari_score(
            source_ids, pred_ids, list_of_targets, beta_for_deletion=1)
        sari_sum += sari
        keep_sum += keep
        add_sum += addition
        del_sum += deletion
    n = max(len(sources), 0.1)  # Avoids 0/0.
    return (sari_sum / n, keep_sum / n, add_sum / n, del_sum / n)
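
A toy usage sketch (not from the original source), assuming the function above and its utils / sari_hook dependencies are in scope; the sentences are invented for illustration.

sources = ['He was born in 1980 . <::::> He died in 2010 .']
predictions = ['He was born in 1980 and died in 2010 .']
target_lists = [['He was born in 1980 and he died in 2010 .']]
sari, keep, addition, deletion = compute_sari_scores(
    sources, predictions, target_lists, ignore_wikisplit_separators=True)
print('SARI: %.3f (keep %.3f, add %.3f, del %.3f)' %
      (sari, keep, addition, deletion))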



Example 2
    def compute_tags(self, task, target):
        """Computes tags needed for converting the source into the target.

    Args:
      task: tagging.EditingTask that specifies the input.
      target: Target text.

    Returns:
      List of tagging.Tag objects. If the source couldn't be converted into the
      target via tagging, returns an empty list.
    """
        target_tokens = utils.get_token_list(target.lower())
        tags = self._compute_tags_fixed_order(task.source_tokens,
                                              target_tokens)
        # If conversion fails, try to obtain the target after swapping the source
        # order.
        if not tags and len(task.sources) == 2 and self._do_swap:
            swapped_task = tagging.EditingTask(task.sources[::-1])
            tags = self._compute_tags_fixed_order(swapped_task.source_tokens,
                                                  target_tokens)
            if tags:
                tags = (tags[swapped_task.first_tokens[1]:] +
                        tags[:swapped_task.first_tokens[1]])
                # We assume that the last token (typically a period) is never deleted,
                # so we can overwrite the tag_type with SWAP (which keeps the token,
                # moving it and the sentence it's part of to the end).
                tags[task.first_tokens[1] - 1].tag_type = tagging.TagType.SWAP
        return tags
    def __init__(self,
                 phrase_vocabulary,
                 do_swap=True,
                 arbitrary_reordering=True):
        """Initializes an instance of TaggingConverter.

    Args:
      phrase_vocabulary: Iterable of phrase vocabulary items (strings).
      do_swap: Whether to enable the SWAP tag.
      arbitrary_reordering: Whether to use arbitrary reordering
    """
        self._phrase_vocabulary = set(phrase.lower()
                                      for phrase in phrase_vocabulary)
        self._do_swap = do_swap
        # Maximum number of tokens in an added phrase (inferred from the
        # vocabulary).
        self._max_added_phrase_length = 0
        # Set of tokens that are part of a phrase in self.phrase_vocabulary.
        self._token_vocabulary = set()
        for phrase in self._phrase_vocabulary:
            tokens = utils.get_token_list(phrase)
            self._token_vocabulary |= set(tokens)
            if len(tokens) > self._max_added_phrase_length:
                self._max_added_phrase_length = len(tokens)

        self._arbitrary_reordering = arbitrary_reordering

        # Select the tagging routine; reordering is only applied when enabled.
        self._compute_tags_fixed_order = (
            self._compute_tags_fixed_order_without_reordering)
        if self._arbitrary_reordering:
            self._compute_tags_fixed_order = (
                self._compute_tags_fixed_order_with_reordering)
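
The swap handling in compute_tags above rotates the tag list with tags[k:] + tags[:k]. A standalone sketch of that rotation, with plain strings in place of tagging.Tag objects and an invented first_tokens value:

# Suppose the swapped task has first_tokens == [0, 3]: the first three tags
# belong to the source that was originally second.
swapped_tags = ['B0', 'B1', 'B2', 'A0', 'A1']   # tags computed for sources[::-1]
k = 3                                           # swapped_task.first_tokens[1]
restored = swapped_tags[k:] + swapped_tags[:k]  # tags back in the original order
assert restored == ['A0', 'A1', 'B0', 'B1', 'B2']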
Example 4
    def __init__(self, sources, arbitrary_reordering=True, extra_tags=10):
        """Initializes an instance of EditingTask.

    Args:
      sources: A list of source strings. Typically contains only one string but
        for sentence fusion it contains two strings to be fused (whose order may
        be swapped).
      arbitrary_reordering: whether arbitrary reordering is used
      extra_tags: the number of extra tags used to pad the source tokens
    """
        self.sources = sources
        source_token_lists = [
            utils.get_token_list(text) for text in self.sources
        ]
        # Tokens of the source texts concatenated into a single list.
        self.source_tokens = []
        # The indices of the first tokens of each source text.
        self.first_tokens = []
        for token_list in source_token_lists:
            self.first_tokens.append(len(self.source_tokens))
            self.source_tokens.extend(token_list)

        self._arbitrary_reordering = arbitrary_reordering

        self._realize_sequence = self._realize_sequence_without_reordering
        if arbitrary_reordering:
            self._realize_sequence = self._realize_sequence_with_reordering

        self._extra_tags = extra_tags
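
A self-contained sketch of the bookkeeping above, showing how source_tokens and first_tokens are built for a two-source task; whitespace splitting stands in for utils.get_token_list.

# Standalone sketch: str.split stands in for utils.get_token_list.
sources = ['Turing was born in 1912 .', 'Turing died in 1954 .']
source_tokens, first_tokens = [], []
for token_list in (text.split() for text in sources):
    first_tokens.append(len(source_tokens))
    source_tokens.extend(token_list)
assert first_tokens == [0, 6]                    # second source starts at token 6
assert source_tokens[first_tokens[1]] == 'Turing'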
def _get_added_phrases(source: Text,
                       target: Text,
                       lang: str = 'en') -> Sequence[Text]:
    """Computes the phrases that need to be added to the source to get the target.

  This is done by aligning each token in the LCS to the first match in the
  target and checking which phrases in the target remain unaligned.

  TODO(b/142853960): The LCS tokens should ideally be aligned to consecutive
  target tokens whenever possible, instead of aligning them always to the first
  match. This should result in a more meaningful phrase vocabulary with a higher
  coverage.

  Note that the algorithm is case-insensitive and the resulting phrases are
  always lowercase.

  Args:
    source: Source text.
    target: Target text.

  Returns:
    List of added phrases.
  """
    source_tokens = utils.get_token_list(source.lower(), lang)
    target_tokens = utils.get_token_list(target.lower(), lang)
    kept_tokens = _compute_lcs(source_tokens, target_tokens)
    added_phrases = []
    # Index of the `kept_tokens` element that we are currently looking for.
    kept_idx = 0
    phrase = []
    for token in target_tokens:
        if kept_idx < len(kept_tokens) and token == kept_tokens[kept_idx]:
            kept_idx += 1
            if phrase:
                if lang == 'zh':
                    added_phrases.append(''.join(phrase))
                else:
                    added_phrases.append(' '.join(phrase))
                phrase = []
        else:
            phrase.append(token)
    if phrase:
        if lang == 'zh':
            added_phrases.append(''.join(phrase))
        else:
            added_phrases.append(' '.join(phrase))
    return added_phrases
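
_compute_lcs is referenced above but not included in these examples. As an assumption about its interface (two token lists in, their longest common subsequence out), a minimal dynamic-programming sketch might look like the following; it is not the original implementation.

def _compute_lcs_sketch(source, target):
    """Returns the longest common subsequence of two token lists (sketch)."""
    # dp[i][j] = length of the LCS of source[:i] and target[:j].
    dp = [[0] * (len(target) + 1) for _ in range(len(source) + 1)]
    for i, s in enumerate(source, 1):
        for j, t in enumerate(target, 1):
            dp[i][j] = (dp[i - 1][j - 1] + 1 if s == t
                        else max(dp[i - 1][j], dp[i][j - 1]))
    # Backtrack to recover the subsequence itself.
    lcs, i, j = [], len(source), len(target)
    while i and j:
        if source[i - 1] == target[j - 1]:
            lcs.append(source[i - 1])
            i, j = i - 1, j - 1
        elif dp[i - 1][j] >= dp[i][j - 1]:
            i -= 1
        else:
            j -= 1
    return lcs[::-1]

assert _compute_lcs_sketch('a big cat'.split(), 'a cat sat'.split()) == ['a', 'cat']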
Example 6
def _get_added_phrases(source: Text, target: Text) -> Sequence[Text]:
    """Computes the phrases that need to be added to the source to get the target.

    This is done by aligning each token in the LCS to the first match in the
    target and checking which phrases in the target remain unaligned.

    TODO(b/142853960): The LCS tokens should ideally be aligned to consecutive
    target tokens whenever possible, instead of aligning them always to the first
    match. This should result in a more meaningful phrase vocabulary with a higher
    coverage.

    Note that the algorithm is case-insensitive and the resulting phrases are
    always lowercase.

    Args:
      source: Source text.
      target: Target text.

    Returns:
      List of added phrases.
    """
    # For English, tokens are words (sep = ' '); for Chinese, tokens are
    # characters (sep = '').
    sep = ' '
    # For Chinese, list(source.lower()) could be used to split into characters.
    source_tokens = utils.get_token_list(source.lower())
    target_tokens = utils.get_token_list(target.lower())
    # Tokens shared by the source and the target (their LCS).
    kept_tokens = _compute_lcs(source_tokens, target_tokens)
    added_phrases = []
    # Index of the `kept_tokens` element that we are currently looking for.
    kept_idx = 0
    phrase = []
    for token in target_tokens:
        if kept_idx < len(kept_tokens) and token == kept_tokens[kept_idx]:
            kept_idx += 1
            if phrase:
                added_phrases.append(sep.join(phrase))
                phrase = []
        else:
            phrase.append(token)
    #print("phrase_vocabulary_optimization sep",sep)
    if phrase:
        added_phrases.append(sep.join(phrase))
    return added_phrases
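
A worked, self-contained trace of the alignment loop above, with the tokenization and the LCS hard-coded so it does not depend on utils or _compute_lcs.

source_tokens = 'the cat sat on the mat'.split()
target_tokens = 'the black cat sat quietly on the mat'.split()
kept_tokens = ['the', 'cat', 'sat', 'on', 'the', 'mat']  # LCS of the two
sep = ' '
added_phrases, kept_idx, phrase = [], 0, []
for token in target_tokens:
    if kept_idx < len(kept_tokens) and token == kept_tokens[kept_idx]:
        kept_idx += 1
        if phrase:
            added_phrases.append(sep.join(phrase))
            phrase = []
    else:
        phrase.append(token)
if phrase:
    added_phrases.append(sep.join(phrase))
assert added_phrases == ['black', 'quietly']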
Example 7
  def __init__(self, sources):
    """Initializes an instance of EditingTask.

    Args:
      sources: A list of source strings. Typically contains only one string but
        for sentence fusion it contains two strings to be fused (whose order may
        be swapped).
    """
    self.sources = sources
    source_token_lists = [utils.get_token_list(text) for text in self.sources]
    # Tokens of the source texts concatenated into a single list.
    self.source_tokens = []
    # The indices of the first tokens of each source text.
    self.first_tokens = []
    for token_list in source_token_lists:
      self.first_tokens.append(len(self.source_tokens))
      self.source_tokens.extend(token_list)
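
A brief, hypothetical usage sketch of the class above, assuming utils.get_token_list performs word-level tokenization; it shows how first_tokens slices the concatenated token list back into per-source spans.

task = EditingTask(['First sentence .', 'Second sentence .'])
start_of_second = task.first_tokens[1]
tokens_of_first = task.source_tokens[:start_of_second]   # tokens of sources[0]
tokens_of_second = task.source_tokens[start_of_second:]  # tokens of sources[1]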
Example 8
    def __init__(self, phrase_vocabulary, do_swap=True):
        """Initializes an instance of TaggingConverter.

        Args:
          phrase_vocabulary: Iterable of phrase vocabulary items (strings).
          do_swap: Whether to enable the SWAP tag.
        """
        self._phrase_vocabulary = set(phrase.lower()
                                      for phrase in phrase_vocabulary)
        self._do_swap = do_swap
        # Maximum number of tokens in an added phrase (inferred from the
        # vocabulary).
        self._max_added_phrase_length = 0
        # Set of tokens that are part of a phrase in self.phrase_vocabulary.
        self._token_vocabulary = set()  # A set of word pieces.
        for phrase in self._phrase_vocabulary:
            tokens = utils.get_token_list(phrase)
            self._token_vocabulary |= set(tokens)  # Union with this phrase's tokens.
            if len(tokens) > self._max_added_phrase_length:
                self._max_added_phrase_length = len(tokens)
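
A standalone sketch of the vocabulary bookkeeping above, with str.split standing in for utils.get_token_list and an invented phrase vocabulary.

phrase_vocabulary = {'and', ', and', 'because of this'}
max_added_phrase_length = 0
token_vocabulary = set()
for phrase in phrase_vocabulary:
    tokens = phrase.split()
    token_vocabulary |= set(tokens)
    max_added_phrase_length = max(max_added_phrase_length, len(tokens))
assert max_added_phrase_length == 3           # 'because of this'
assert token_vocabulary == {'and', ',', 'because', 'of', 'this'}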