def _compute_single_tag(
    self, source_token, target_token_idx, target_tokens):
  """Computes a single tag.

  The tag may match multiple target tokens (via tag.added_phrase) so we return
  the next unmatched target token.

  Args:
    source_token: The token to be tagged.
    target_token_idx: Index of the current target token.
    target_tokens: List of all target tokens.

  Returns:
    A tuple with (1) the computed tag and (2) the next target_token_idx.
  """
  source_token = source_token.lower()
  target_token = target_tokens[target_token_idx].lower()
  if source_token == target_token:
    return tagging.Tag('KEEP'), target_token_idx + 1

  # Case: source_token != target_token. Try to consume target tokens as an
  # added phrase that precedes a kept source token.
  added_phrase = ''
  for num_added_tokens in range(1, self._max_added_phrase_length + 1):
    if target_token not in self._token_vocabulary:
      break
    added_phrase += (' ' if added_phrase else '') + target_token
    next_target_token_idx = target_token_idx + num_added_tokens
    if next_target_token_idx >= len(target_tokens):
      # All target tokens have been consumed.
      break
    target_token = target_tokens[next_target_token_idx].lower()
    if (source_token == target_token and
        added_phrase in self._phrase_vocabulary):
      return tagging.Tag('KEEP|' + added_phrase), next_target_token_idx + 1
  return tagging.Tag('DELETE'), target_token_idx

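# Illustrative sketch of _compute_single_tag (not part of the original file).
# It assumes the upstream TaggingConverter constructor, which takes an
# iterable phrase vocabulary; the toy phrase 'and' is an assumption made only
# for this example.
converter = tagging_converter.TaggingConverter(['and'])

# 'died' matches the target token that follows the in-vocabulary phrase 'and',
# so the source token is kept and the phrase is attached to the tag.
tag, next_idx = converter._compute_single_tag('died', 0, ['and', 'died', 'in'])
# Expected: str(tag) == 'KEEP|and', next_idx == 2

# 'Turing' does not match and no vocabulary phrase bridges the gap, so the
# token is marked for deletion and the target index is left unchanged.
tag, next_idx = converter._compute_single_tag('Turing', 0, ['and', 'died'])
# Expected: str(tag) == 'DELETE', next_idx == 0
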
def test_first_deletion_idx_computation(self):
  converter = tagging_converter.TaggingConverter([])
  tag_strs = ['KEEP', 'DELETE', 'DELETE', 'KEEP']
  tags = [tagging.Tag(s) for s in tag_strs]
  source_token_idx = 3
  idx = converter._find_first_deletion_idx(source_token_idx, tags)
  self.assertEqual(idx, 1)

def __init__(self, tf_predictor, example_builder, label_map):
  """Initializes an instance of LaserTaggerPredictor.

  Args:
    tf_predictor: Loaded Tensorflow model.
    example_builder: BERT example builder.
    label_map: Mapping from tags to tag IDs.
  """
  self._predictor = tf_predictor
  self._example_builder = example_builder
  self._id_2_tag = {
      tag_id: tagging.Tag(tag) for tag, tag_id in label_map.items()
  }

def _compute_tags_fixed_order(self, source_tokens, target_tokens):
  """Computes tags when the order of sources is fixed.

  Args:
    source_tokens: List of source tokens.
    target_tokens: List of tokens to be obtained via edit operations.

  Returns:
    List of tagging.Tag objects. If the source couldn't be converted into the
    target via tagging, returns an empty list.
  """
  tags = [tagging.Tag('DELETE') for _ in source_tokens]
  # Indices of the tokens currently being processed.
  source_token_idx = 0
  target_token_idx = 0
  while target_token_idx < len(target_tokens):
    tags[source_token_idx], target_token_idx = self._compute_single_tag(
        source_tokens[source_token_idx], target_token_idx, target_tokens)
    # TODO: There can be multiple valid taggings that convert the source into
    # the target; currently we commit to a single one.
    # If we're adding a phrase and the previous source token(s) were deleted,
    # we could add the phrase before a previously deleted token and still get
    # the same realized output. For example:
    #   [DELETE, DELETE, KEEP|"what is"]
    # and
    #   [DELETE|"what is", DELETE, KEEP]
    # would yield the same realized output. Experimentally, we noticed that
    # the model works better / the learning task becomes easier when phrases
    # are always added before the first deleted token. Also note that in the
    # current implementation, this way of moving the added phrase backward is
    # the only way a DELETE tag can have an added phrase, so sequences like
    # [DELETE|"What", DELETE|"is"] will never be created.
    if tags[source_token_idx].added_phrase:
      first_deletion_idx = self._find_first_deletion_idx(
          source_token_idx, tags)
      if first_deletion_idx != source_token_idx:
        tags[first_deletion_idx].added_phrase = (
            tags[source_token_idx].added_phrase)
        tags[source_token_idx].added_phrase = ''
    source_token_idx += 1
    if source_token_idx >= len(tags):
      break

  # If all target tokens have been consumed, we have found a conversion and
  # can return the tags. Note that if there are remaining source tokens, they
  # are already marked deleted when initializing the tag list.
  if target_token_idx >= len(target_tokens):
    return tags
  return []  # The target could not be obtained via tagging.

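# Illustrative sketch of the phrase-moving behavior described above (not part
# of the original file). It assumes the upstream whitespace tokenization of
# tagging.EditingTask and a toy single-phrase vocabulary; both are assumptions
# made only for this example.
converter = tagging_converter.TaggingConverter(['and'])
task = tagging.EditingTask(
    ['Turing was born in 1912 . Turing died in 1954 .'])
tags = converter.compute_tags(
    task, 'Turing was born in 1912 and died in 1954 .')
# Expected tags (note the added phrase is moved back to the first deleted
# token, the '.' after '1912', rather than staying on the kept token 'died'):
#   KEEP KEEP KEEP KEEP KEEP DELETE|and DELETE KEEP KEEP KEEP KEEP
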
def get_phrase_vocabulary_from_label_map(label_map):
  """Extract the set of all phrases from label map.

  Args:
    label_map: Mapping from tags to tag IDs.

  Returns:
    Set of all phrases appearing in the label map.
  """
  phrase_vocabulary = set()
  for label in label_map.keys():
    tag = tagging.Tag(label)
    if tag.added_phrase:
      phrase_vocabulary.add(tag.added_phrase)
  return phrase_vocabulary

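# Minimal usage sketch (the toy label map below is an assumption made only for
# this example; real label maps are read from the label map file).
label_map = {'KEEP': 0, 'DELETE': 1, 'KEEP|and': 2, 'DELETE|and': 3}
phrases = get_phrase_vocabulary_from_label_map(label_map)
# Expected: phrases == {'and'}
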
def build_bert_example(
    self,
    sources,
    target=None,
    use_arbitrary_target_ids_for_infeasible_examples=False,
    location=None):
  """Constructs a BERT Example.

  Args:
    sources: List of source texts.
    target: Target text or None when building an example during inference.
    use_arbitrary_target_ids_for_infeasible_examples: Whether to build an
      example with arbitrary target ids even if the target can't be obtained
      via tagging.
    location: Optional location information forwarded to tagging.EditingTask.

  Returns:
    BertExample, or None if the conversion from text to tags was infeasible
    and use_arbitrary_target_ids_for_infeasible_examples == False.
  """
  # Compute target labels.
  task = tagging.EditingTask(sources, location=location)
  if target is not None:
    tags = self._converter.compute_tags(task, target)
    if not tags:
      # The target cannot be obtained via tagging; what happens next depends
      # on use_arbitrary_target_ids_for_infeasible_examples.
      if use_arbitrary_target_ids_for_infeasible_examples:
        # Create a tag sequence [KEEP, DELETE, KEEP, DELETE, ...] which is
        # unlikely to be predicted by chance.
        tags = [
            tagging.Tag('KEEP') if i % 2 == 0 else tagging.Tag('DELETE')
            for i, _ in enumerate(task.source_tokens)
        ]
      else:
        return None
  else:
    # If target is not provided, we set all target labels to KEEP.
    tags = [tagging.Tag('KEEP') for _ in task.source_tokens]
  labels = [self._label_map[str(tag)] for tag in tags]

  # Tags are assigned per word, so after wordpiece splitting every piece of a
  # word gets the same label as that word.
  tokens, labels, token_start_indices = self._split_to_wordpieces(
      task.source_tokens, labels)

  if len(tokens) > self._max_seq_length - 2:
    print(curLine(), "%d tokens is too long," % len(task.source_tokens),
          "truncate task.source_tokens:", task.source_tokens)
    # Truncate to self._max_seq_length - 2, leaving room for [CLS] and [SEP].
    tokens = self._truncate_list(tokens)
    labels = self._truncate_list(labels)

  input_tokens = ['[CLS]'] + tokens + ['[SEP]']
  labels_mask = [0] + [1] * len(labels) + [0]
  labels = [0] + labels + [0]

  input_ids = self._tokenizer.convert_tokens_to_ids(input_tokens)
  input_mask = [1] * len(input_ids)
  segment_ids = [0] * len(input_ids)
  example = BertExample(
      input_ids=input_ids,
      input_mask=input_mask,
      segment_ids=segment_ids,
      labels=labels,
      labels_mask=labels_mask,
      token_start_indices=token_start_indices,
      task=task,
      default_label=self._keep_tag_id)
  example.pad_to_max_length(self._max_seq_length, self._pad_id)
  return example

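# Illustration of the feature layout produced above for a hypothetical
# 4-wordpiece input, shown before padding to max_seq_length. The t* and l*
# values are placeholders, not real tokens or label ids.
#   tokens       = ['t1', 't2', 't3', 't4']
#   labels       = [l1,   l2,   l3,   l4]
#   input_tokens = ['[CLS]', 't1', 't2', 't3', 't4', '[SEP]']
#   labels_mask  = [0,       1,    1,    1,    1,    0]
#   labels       = [0,       l1,   l2,   l3,   l4,   0]
#   input_mask   = [1,       1,    1,    1,    1,    1]
#   segment_ids  = [0,       0,    0,    0,    0,    0]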