    def _compute_single_tag(
            self, source_token, target_token_idx, target_tokens):
        """Computes a single tag.

        The tag may match multiple target tokens (via tag.added_phrase) so we also
        return the index of the next unmatched target token.

        Args:
          source_token: The token to be tagged.
          target_token_idx: Index of the current target token.
          target_tokens: List of all target tokens.

        Returns:
          A tuple with (1) the computed tag and (2) the next target_token_idx.
        """
        source_token = source_token.lower()
        target_token = target_tokens[target_token_idx].lower()
        if source_token == target_token:
            return tagging.Tag('KEEP'), target_token_idx + 1
        # Handle the case where source_token != target_token.
        added_phrase = ''
        for num_added_tokens in range(1, self._max_added_phrase_length + 1):
            if target_token not in self._token_vocabulary:
                break
            added_phrase += (' ' if added_phrase else '') + target_token
            next_target_token_idx = target_token_idx + num_added_tokens
            if next_target_token_idx >= len(target_tokens):  # All target tokens consumed.
                break
            target_token = target_tokens[next_target_token_idx].lower()
            if (source_token == target_token and
                    added_phrase in self._phrase_vocabulary):
                return tagging.Tag('KEEP|' + added_phrase), next_target_token_idx + 1
        return tagging.Tag('DELETE'), target_token_idx
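    # A minimal usage sketch (hypothetical inputs; assumes the converter was built
    # with a phrase vocabulary containing 'what is', so 'what' and 'is' are in the
    # token vocabulary):
    #   tag, next_idx = self._compute_single_tag('turing', 0, ['what', 'is', 'turing'])
    #   # -> a KEEP tag with added_phrase 'what is', next_idx == 3
    #   tag, next_idx = self._compute_single_tag('hello', 0, ['world'])
    #   # -> a DELETE tag, next_idx == 0 (no target token matched)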
    def test_first_deletion_idx_computation(self):
        converter = tagging_converter.TaggingConverter([])
        tag_strs = ['KEEP', 'DELETE', 'DELETE', 'KEEP']
        tags = [tagging.Tag(s) for s in tag_strs]
        source_token_idx = 3
        idx = converter._find_first_deletion_idx(source_token_idx, tags)
        self.assertEqual(idx, 1)
    def __init__(self, tf_predictor, example_builder, label_map):
        """Initializes an instance of LaserTaggerPredictor.

    Args:
      tf_predictor: Loaded Tensorflow model.
      example_builder: BERT example builder.
      label_map: Mapping from tags to tag IDs.
    """
        self._predictor = tf_predictor
        self._example_builder = example_builder
        self._id_2_tag = {
            tag_id: tagging.Tag(tag)
            for tag, tag_id in label_map.items()
        }
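        # Illustration (hypothetical label map): with label_map = {'KEEP': 0, 'DELETE': 1},
        # self._id_2_tag becomes {0: tagging.Tag('KEEP'), 1: tagging.Tag('DELETE')}, so
        # predicted label IDs can be mapped back to Tag objects.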
    def _compute_tags_fixed_order(self, source_tokens, target_tokens):
        """Computes tags when the order of sources is fixed.

        Args:
          source_tokens: List of source tokens.
          target_tokens: List of tokens to be obtained via edit operations.

        Returns:
          List of tagging.Tag objects. If the source couldn't be converted into the
          target via tagging, returns an empty list.
        """
        tags = [tagging.Tag('DELETE') for _ in source_tokens]
        # Indices of the tokens currently being processed.
        source_token_idx = 0
        target_token_idx = 0
        while target_token_idx < len(target_tokens):
            tags[source_token_idx], target_token_idx = self._compute_single_tag(
                source_tokens[source_token_idx], target_token_idx, target_tokens)
            # TODO: Multiple tagging sequences can convert source to target; currently
            # we produce only one of them.
            # If we're adding a phrase and the previous source token(s) were deleted,
            # we could add the phrase before a previously deleted token and still get
            # the same realized output. For example:
            #    [DELETE, DELETE, KEEP|"what is"]
            # and
            #    [DELETE|"what is", DELETE, KEEP]
            # Would yield the same realized output. Experimentally, we noticed that
            # the model works better / the learning task becomes easier when phrases
            # are always added before the first deleted token. Also note that in the
            # current implementation, this way of moving the added phrase backward is
            # the only way a DELETE tag can have an added phrase, so sequences like
            # [DELETE|"What", DELETE|"is"] will never be created.
            if tags[source_token_idx].added_phrase:
                first_deletion_idx = self._find_first_deletion_idx(
                    source_token_idx, tags)
                if first_deletion_idx != source_token_idx:
                    tags[first_deletion_idx].added_phrase = (
                        tags[source_token_idx].added_phrase)
                    tags[source_token_idx].added_phrase = ''
            source_token_idx += 1
            if source_token_idx >= len(tags):
                break

        # If all target tokens have been consumed, we have found a conversion and
        # can return the tags. Note that if there are remaining source tokens, they
        # are already marked deleted when initializing the tag list.
        if target_token_idx >= len(target_tokens):  # all target tokens have been consumed
            return tags
        return []  # Infeasible: the target cannot be obtained via tagging.
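    # Usage sketch (hypothetical example; assumes 'and' is in the phrase vocabulary):
    #   source_tokens: ['Turing', 'was', 'born', 'in', '1912', '.', 'Turing', 'died', 'in', '1954', '.']
    #   target_tokens: ['Turing', 'was', 'born', 'in', '1912', 'and', 'died', 'in', '1954', '.']
    #   result tags:   [KEEP, KEEP, KEEP, KEEP, KEEP, DELETE|and, DELETE, KEEP, KEEP, KEEP, KEEP]
    # The added phrase 'and' ends up on the first deleted token (index 5) rather than on
    # the KEEP tag at index 7 where the match was found, as described in the comment above.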
def get_phrase_vocabulary_from_label_map(
        label_map):
    """Extract the set of all phrases from label map.

    Args:
      label_map: Mapping from tags to tag IDs.

    Returns:
      Set of all phrases appearing in the label map.
    """
    phrase_vocabulary = set()
    for label in label_map.keys():
        tag = tagging.Tag(label)
        if tag.added_phrase:
            phrase_vocabulary.add(tag.added_phrase)
    return phrase_vocabulary
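# Example (hypothetical label map): assuming tag labels use the 'TYPE|phrase' format,
# {'KEEP': 0, 'DELETE': 1, 'KEEP|and': 2, 'DELETE|and': 3} yields the phrase
# vocabulary {'and'}.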
    def build_bert_example(
            self,
            sources,
            target=None,
            use_arbitrary_target_ids_for_infeasible_examples=False,
            location=None):
        """Constructs a BERT Example.

    Args:
      sources: List of source texts.
      target: Target text or None when building an example during inference.
      use_arbitrary_target_ids_for_infeasible_examples: Whether to build an
        example with arbitrary target ids even if the target can't be obtained
        via tagging.

    Returns:
      BertExample, or None if the conversion from text to tags was infeasible
      and use_arbitrary_target_ids_for_infeasible_examples == False.
    """
        # Compute target labels.
        task = tagging.EditingTask(sources, location=location)
        if target is not None:
            tags = self._converter.compute_tags(task, target)
            if not tags:  # Infeasible; behavior depends on use_arbitrary_target_ids_for_infeasible_examples.
                if use_arbitrary_target_ids_for_infeasible_examples:
                    # Create a tag sequence [KEEP, DELETE, KEEP, DELETE, ...] which is
                    # unlikely to be predicted by chance.
                    tags = [
                        tagging.Tag('KEEP') if i % 2 == 0 else tagging.Tag('DELETE')
                        for i, _ in enumerate(task.source_tokens)
                    ]
                else:
                    return None
        else:
            # If target is not provided, we set all target labels to KEEP.
            tags = [tagging.Tag('KEEP') for _ in task.source_tokens]
        labels = [self._label_map[str(tag)] for tag in tags]

        # Tags are assigned per word; every wordpiece of a word inherits that word's label.
        tokens, labels, token_start_indices = self._split_to_wordpieces(
            task.source_tokens, labels)
        if len(tokens) > self._max_seq_length - 2:
            print(curLine(), "%d tokens is too long," % len(task.source_tokens),
                  "truncating task.source_tokens:", task.source_tokens)
        # Truncate to self._max_seq_length - 2 tokens, leaving room for [CLS] and [SEP].
        tokens = self._truncate_list(tokens)
        labels = self._truncate_list(labels)

        input_tokens = ['[CLS]'] + tokens + ['[SEP]']
        labels_mask = [0] + [1] * len(labels) + [0]
        labels = [0] + labels + [0]

        input_ids = self._tokenizer.convert_tokens_to_ids(input_tokens)
        input_mask = [1] * len(input_ids)
        segment_ids = [0] * len(input_ids)
        example = BertExample(input_ids=input_ids,
                              input_mask=input_mask,
                              segment_ids=segment_ids,
                              labels=labels,
                              labels_mask=labels_mask,
                              token_start_indices=token_start_indices,
                              task=task,
                              default_label=self._keep_tag_id)
        example.pad_to_max_length(self._max_seq_length, self._pad_id)
        return example
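    # Usage sketch (hypothetical builder and inputs; the builder is normally constructed
    # with a label map, a BERT tokenizer, a max sequence length and a tagging converter):
    #   example = builder.build_bert_example(
    #       ['Turing was born in 1912 . Turing died in 1954 .'],
    #       target='Turing was born in 1912 and died in 1954 .')
    #   # Returns a padded BertExample; returns None if the target cannot be reached via
    #   # tagging and use_arbitrary_target_ids_for_infeasible_examples is False.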