def _compute_single_tag(self, source_token, target_token_idx, target_tokens):
  """Computes a single tag.

  The tag may match multiple target tokens (via tag.added_phrase) so we return
  the next unmatched target token.

  Args:
    source_token: The token to be tagged.
    target_token_idx: Index of the current target token.
    target_tokens: List of all target tokens.

  Returns:
    A tuple with (1) the computed tag and (2) the next target_token_idx.
  """
  source_token = source_token.lower()
  target_token = target_tokens[target_token_idx].lower()
  if source_token == target_token:
    return tagging.Tag('KEEP'), target_token_idx + 1

  # Handle the case where source_token != target_token.
  added_phrase = ''
  for num_added_tokens in range(1, self._max_added_phrase_length + 1):
    if target_token not in self._token_vocabulary:
      break
    added_phrase += (' ' if added_phrase else '') + target_token
    next_target_token_idx = target_token_idx + num_added_tokens
    if next_target_token_idx >= len(target_tokens):
      # All target tokens have been consumed.
      break
    target_token = target_tokens[next_target_token_idx].lower()
    if (source_token == target_token and
        added_phrase in self._phrase_vocabulary):
      return tagging.Tag('KEEP|' + added_phrase), next_target_token_idx + 1
  return tagging.Tag('DELETE'), target_token_idx
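# A minimal usage sketch for _compute_single_tag (illustrative only: the
# phrase vocabulary {'and he'} is a made-up example, and the assertions assume
# the standard LaserTagger Tag.__str__ formatting).
import tagging_converter

converter = tagging_converter.TaggingConverter({'and he'})
tag, next_idx = converter._compute_single_tag(
    'was', target_token_idx=0, target_tokens=['and', 'he', 'was'])
assert str(tag) == 'KEEP|and he'  # "and he" is added before the kept token.
assert next_idx == 3  # All three target tokens have been consumed.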
def test_copying(self):
  input_texts = ['Turing was born in 1912 in London .']
  tag_strs = ['KEEP'] * 8
  tags = [tagging.Tag(s) for s in tag_strs]
  task = tagging.EditingTask(input_texts)
  self.assertEqual(task.realize_output(tags), input_texts[0])

  # With multiple inputs.
  input_texts = ['a B', 'c D e', 'f g']
  tag_strs = ['KEEP'] * 7
  tags = [tagging.Tag(s) for s in tag_strs]
  task = tagging.EditingTask(input_texts)
  self.assertEqual(task.realize_output(tags), 'a B c D e f g')
def test_casing(self):
  input_texts = ['A b .', 'Cc dd .']

  # Test lowercasing after a period has been removed.
  tag_strs = ['KEEP', 'KEEP', 'DELETE', 'KEEP', 'KEEP', 'KEEP']
  tags = [tagging.Tag(s) for s in tag_strs]
  task = tagging.EditingTask(input_texts)
  self.assertEqual(task.realize_output(tags), 'A b cc dd .')

  # Test uppercasing after the first uppercased token has been removed.
  tag_strs = ['KEEP', 'KEEP', 'KEEP', 'DELETE', 'KEEP', 'KEEP']
  tags = [tagging.Tag(s) for s in tag_strs]
  task = tagging.EditingTask(input_texts)
  self.assertEqual(task.realize_output(tags), 'A b . Dd .')
def test_tag_parsing(self):
  tag = tagging.Tag('KEEP')
  self.assertEqual(tag.tag_type, tagging.TagType.KEEP)
  self.assertEqual(tag.added_phrase, '')

  tag = tagging.Tag('DELETE|, and she')
  self.assertEqual(tag.tag_type, tagging.TagType.DELETE)
  self.assertEqual(tag.added_phrase, ', and she')

  tag = tagging.Tag('SWAP|asdf | foo')
  self.assertEqual(tag.tag_type, tagging.TagType.SWAP)
  self.assertEqual(tag.added_phrase, 'asdf | foo')

  with self.assertRaises(ValueError):
    tagging.Tag('NON_EXISTING_TAG')
def test_construct_example(self):
  vocab_file = "gs://bert_traning_yechen/trained_bert_uncased/bert_POS/vocab.txt"
  label_map_file = "gs://publicly_available_models_yechen/best_hypertuned_POS/label_map.txt"
  enable_masking = False
  do_lower_case = True
  embedding_type = "POS"

  label_map = utils.read_label_map(label_map_file)
  converter = tagging_converter.TaggingConverter(
      tagging_converter.get_phrase_vocabulary_from_label_map(label_map), True)
  id_2_tag = {tag_id: tagging.Tag(tag) for tag, tag_id in label_map.items()}
  builder = bert_example.BertExampleBuilder(label_map, vocab_file, 10,
                                            do_lower_case, converter,
                                            embedding_type, enable_masking)

  inputs, example = construct_example("This is a test", builder)
  self.assertEqual(
      inputs, {
          'input_ids': [2, 12, 1016, 6, 9, 6, 9, 10, 12, 3],
          'input_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
          'segment_ids': [2, 16, 14, 14, 32, 14, 32, 5, 14, 41]
      })
def test_first_deletion_idx_computation(self):
  converter = tagging_converter.TaggingConverter([])
  tag_strs = ['KEEP', 'DELETE', 'DELETE', 'KEEP']
  tags = [tagging.Tag(s) for s in tag_strs]
  source_token_idx = 3
  idx = converter._find_first_deletion_idx(source_token_idx, tags)
  self.assertEqual(idx, 1)
def _compute_tags_fixed_order_without_reordering(self, source_tokens,
                                                 target_tokens):
  """Computes tags when the order of source tokens is fixed.

  Args:
    source_tokens: List of source tokens.
    target_tokens: List of tokens to be obtained via edit operations.

  Returns:
    List of tagging.Tag objects. If the source couldn't be converted into the
    target via tagging, returns an empty list.
  """
  tags = [tagging.Tag('DELETE') for _ in source_tokens]
  # Indices of the tokens currently being processed.
  source_token_idx = 0
  target_token_idx = 0
  while target_token_idx < len(target_tokens):
    tags[source_token_idx], target_token_idx = (
        self._compute_single_tag_without_reordering(
            source_tokens[source_token_idx], target_token_idx, target_tokens))
    # If we're adding a phrase and the previous source token(s) were deleted,
    # we could add the phrase before a previously deleted token and still get
    # the same realized output. For example:
    #   [DELETE, DELETE, KEEP|"what is"]
    # and
    #   [DELETE|"what is", DELETE, KEEP]
    # would yield the same realized output. Experimentally, we noticed that
    # the model works better / the learning task becomes easier when phrases
    # are always added before the first deleted token. Also note that in the
    # current implementation, this way of moving the added phrase backward is
    # the only way a DELETE tag can have an added phrase, so sequences like
    # [DELETE|"What", DELETE|"is"] will never be created.
    if tags[source_token_idx].added_phrase and not tags[
        source_token_idx].added_phrase.isnumeric():
      first_deletion_idx = self._find_first_deletion_idx(
          source_token_idx, tags)
      if first_deletion_idx != source_token_idx:
        tags[first_deletion_idx].added_phrase = (
            tags[source_token_idx].added_phrase)
        tags[source_token_idx].added_phrase = ''
    source_token_idx += 1
    if source_token_idx >= len(tags):
      break
  # If all target tokens have been consumed, we have found a conversion and
  # can return the tags. Note that if there are remaining source tokens, they
  # are already marked deleted when initializing the tag list.
  if target_token_idx >= len(target_tokens):
    return tags
  return []
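# Illustrative sketch of the phrase backshifting described above (the phrase
# vocabulary {'what is'} is hypothetical, and the assertion assumes the
# standard Tag.__str__ formatting):
import tagging_converter

converter = tagging_converter.TaggingConverter({'what is'})
tags = converter._compute_tags_fixed_order_without_reordering(
    'the x y time'.split(), 'the what is time'.split())
# The added phrase ends up on the first deleted token, not on the kept token
# that triggered the match:
assert [str(t) for t in tags] == ['KEEP', 'DELETE|what is', 'DELETE', 'KEEP']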
def test_deletion(self):
  input_texts = ['Turing was born in 1912 in London .']
  tag_strs = [
      'KEEP', 'DELETE', 'KEEP', 'KEEP', 'KEEP', 'KEEP', 'KEEP', 'DELETE'
  ]
  tags = [tagging.Tag(s) for s in tag_strs]
  task = tagging.EditingTask(input_texts)
  # "was" and "." should have been removed.
  self.assertEqual(task.realize_output(tags), 'Turing born in 1912 in London')
def test_phrase_adding(self):
  input_texts = ['Turing was born in 1912 in London .']
  tag_strs = [
      'KEEP', 'DELETE|, a pioneer in TCS ,', 'KEEP', 'KEEP', 'KEEP', 'KEEP',
      'KEEP', 'KEEP'
  ]
  tags = [tagging.Tag(s) for s in tag_strs]
  task = tagging.EditingTask(input_texts)
  self.assertEqual(
      task.realize_output(tags),
      'Turing , a pioneer in TCS , born in 1912 in London .')
def test_swapping_complex(self):
  input_texts = [
      'Dylan won Nobel prize .', 'Dylan is an American musician .'
  ]
  tag_strs = [
      'DELETE', 'KEEP', 'KEEP', 'KEEP', 'SWAP', 'KEEP', 'DELETE|,', 'KEEP',
      'KEEP', 'KEEP', 'DELETE|,'
  ]
  tags = [tagging.Tag(s) for s in tag_strs]
  task = tagging.EditingTask(input_texts)
  self.assertEqual(task.realize_output(tags),
                   'Dylan , an American musician , won Nobel prize .')
def test_swapping(self):
  input_texts = [
      'Turing was born in 1912 in London .', 'Turing died in 1954 .'
  ]
  tag_strs = [
      'KEEP', 'KEEP', 'KEEP', 'KEEP', 'KEEP', 'KEEP', 'KEEP', 'SWAP', 'KEEP',
      'KEEP', 'KEEP', 'KEEP', 'KEEP'
  ]
  tags = [tagging.Tag(s) for s in tag_strs]
  task = tagging.EditingTask(input_texts)
  self.assertEqual(
      task.realize_output(tags),
      'Turing died in 1954 . Turing was born in 1912 in London .')
def __init__(self, tf_predictor, example_builder, label_map):
  """Initializes an instance of LaserTaggerPredictor.

  Args:
    tf_predictor: Loaded Tensorflow model.
    example_builder: BERT example builder.
    label_map: Mapping from tags to tag IDs.
  """
  self._predictor = tf_predictor
  self._example_builder = example_builder
  self._id_2_tag = {
      tag_id: tagging.Tag(tag) for tag, tag_id in label_map.items()
  }
def get_phrase_vocabulary_from_label_map(label_map):
  """Extracts the set of all phrases from the label map.

  Args:
    label_map: Mapping from tags to tag IDs.

  Returns:
    Set of all phrases appearing in the label map.
  """
  phrase_vocabulary = set()
  for label in label_map.keys():
    tag = tagging.Tag(label)
    if tag.added_phrase:
      phrase_vocabulary.add(tag.added_phrase)
  return phrase_vocabulary
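# Minimal usage sketch (the label map below is a made-up example):
label_map = {'KEEP': 0, 'DELETE': 1, 'KEEP|and': 2, 'DELETE|, she': 3}
assert get_phrase_vocabulary_from_label_map(label_map) == {'and', ', she'}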
def test_invalid_swapping(self):
  # When the SWAP tag is assigned to a token other than the last token of the
  # first of the two sentences, it should be treated as KEEP.
  input_texts = [
      'Turing was born in 1912 in London .', 'Turing died in 1954 .'
  ]
  tag_strs = [
      'KEEP', 'KEEP', 'KEEP', 'KEEP', 'KEEP', 'KEEP', 'SWAP', 'KEEP', 'KEEP',
      'KEEP', 'KEEP', 'KEEP', 'KEEP'
  ]
  tags = [tagging.Tag(s) for s in tag_strs]
  task = tagging.EditingTask(input_texts)
  self.assertEqual(
      task.realize_output(tags),
      'Turing was born in 1912 in London . Turing died in 1954 .')
def test_realize_output_in_order(self):
  """Tests the case where source tokens occur in the same relative order in
  the target string."""
  editing_task = tagging.EditingTask(["word1 word2 <::::> word3 "])
  tags_str = ['KEEP|0', 'KEEP|1', 'KEEP|and', 'DELETE', 'KEEP|3']
  tags = [tagging.Tag(tag) for tag in tags_str]
  result = editing_task.realize_output([tags])
  expected = "word1 word2 and word3 "
  self.assertEqual(expected, result)
def _compute_single_tag_with_reordering(self, source_token_idx,
                                        target_token_idx, target_tokens,
                                        source_tokens):
  """Computes a single tag.

  The predicted tags can be:
    - KEEP|<position>: position is the position at which the source token
      occurs in the target string.
    - KEEP|<phrase_to_add>: phrase_to_add is the phrase that is being added
      from the edit vocabulary.
    - DELETE: the source token does not occur in the target and hence is
      tagged as DELETE.

  Args:
    source_token_idx: Current index of the source token in the source tokens
      list.
    target_token_idx: Current index of the target token in the target tokens
      list.
    target_tokens: List of target tokens.
    source_tokens: List of source tokens.

  Returns:
    tag: The computed tag.
    target_token_idx: Updated target token index.
    source_token_idx: Updated source token index.
    position_to_add_at: A non-negative integer only when a phrase is added,
      -1 otherwise. This is necessary to ensure the correct position of this
      tag in the final tags list.
  """
  source_token = source_tokens[min(source_token_idx,
                                   len(source_tokens) - 1)].lower()
  # Skip any null tokens.
  while (target_token_idx < len(target_tokens) - 1 and
         target_tokens[target_token_idx] == "<NULL>"):
    target_token_idx += 1
  target_token = target_tokens[target_token_idx].lower()

  # If the target token doesn't exist in the remaining source tokens, it must
  # come from the edit vocabulary; if it doesn't, the generation is
  # infeasible.
  if target_token not in source_tokens[source_token_idx:]:
    if target_token in self._phrase_vocabulary:
      target_tokens[target_token_idx] = "<NULL>"
      return (tagging.Tag("KEEP|" + target_token), target_token_idx + 1,
              source_token_idx, target_token_idx)
    else:
      # Signals that the conversion is infeasible.
      return 0
  # If the source token is in the target tokens, return KEEP with the
  # position at which it occurs; otherwise tag it as DELETE.
  elif source_token in target_tokens:
    idx = target_tokens.index(source_token)
    target_tokens[idx] = "<NULL>"
    return (tagging.Tag("KEEP|" + str(idx)), target_token_idx,
            source_token_idx + 1, -1)
  else:
    return (tagging.Tag("DELETE"), target_token_idx, source_token_idx + 1,
            -1)
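# Illustrative sketch of the reordering tags (the phrase vocabulary {'and'}
# is hypothetical and only needed by the constructor; the private method is
# called directly for demonstration):
import tagging_converter

converter = tagging_converter.TaggingConverter({'and'})
target_tokens = ['word2', 'word1']  # Source order is reversed in the target.
tag, target_idx, source_idx, pos = (
    converter._compute_single_tag_with_reordering(0, 0, target_tokens,
                                                  ['word1', 'word2']))
assert str(tag) == 'KEEP|1'  # "word1" occurs at position 1 in the target.
assert target_tokens[1] == '<NULL>'  # The matched target slot is consumed.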
def build_bert_example(
    self,
    sources,
    target=None,
    use_arbitrary_target_ids_for_infeasible_examples=False):
  """Constructs a BERT Example.

  Args:
    sources: List of source texts.
    target: Target text or None when building an example during inference.
    use_arbitrary_target_ids_for_infeasible_examples: Whether to build an
      example with arbitrary target ids even if the target can't be obtained
      via tagging.

  Returns:
    BertExample, or None if the conversion from text to tags was infeasible
    and use_arbitrary_target_ids_for_infeasible_examples == False.
  """
  # Compute target labels.
  task = tagging.EditingTask(sources)
  if target is not None:
    tags = self._converter.compute_tags(task, target)
    if not tags:
      if use_arbitrary_target_ids_for_infeasible_examples:
        # Create a tag sequence [KEEP, DELETE, KEEP, DELETE, ...] which is
        # unlikely to be predicted by chance.
        tags = [
            tagging.Tag('KEEP') if i % 2 == 0 else tagging.Tag('DELETE')
            for i, _ in enumerate(task.source_tokens)
        ]
      else:
        return None
  else:
    # If target is not provided, we set all target labels to KEEP.
    tags = [tagging.Tag('KEEP') for _ in task.source_tokens]
  labels = [self._label_map[str(tag)] for tag in tags]

  tokens, labels, token_start_indices = self._split_to_wordpieces(
      task.source_tokens, labels)
  tokens = self._truncate_list(tokens)
  labels = self._truncate_list(labels)

  input_tokens = ['[CLS]'] + tokens + ['[SEP]']
  labels_mask = [0] + [1] * len(labels) + [0]
  labels = [0] + labels + [0]

  input_ids = self._tokenizer.convert_tokens_to_ids(input_tokens)
  input_mask = [1] * len(input_ids)
  segment_ids = [0] * len(input_ids)
  example = BertExample(
      input_ids=input_ids,
      input_mask=input_mask,
      segment_ids=segment_ids,
      labels=labels,
      labels_mask=labels_mask,
      token_start_indices=token_start_indices,
      task=task,
      default_label=self._keep_tag_id)
  example.pad_to_max_length(self._max_seq_length, self._pad_id)
  return example
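# Minimal usage sketch for build_bert_example (the file paths are
# hypothetical; the builder arguments mirror the seven-argument signature
# used elsewhere in this codebase):
import bert_example
import tagging_converter
import utils

label_map = utils.read_label_map('label_map.txt')  # Hypothetical local path.
converter = tagging_converter.TaggingConverter(
    tagging_converter.get_phrase_vocabulary_from_label_map(label_map), True)
builder = bert_example.BertExampleBuilder(label_map, 'vocab.txt', 128, True,
                                          converter, 'Normal', False)
# At training time, pass the target text; None is returned if tagging is
# infeasible and use_arbitrary_target_ids_for_infeasible_examples is False.
example = builder.build_bert_example(
    ['Turing was born in 1912 .'], target='Turing born in 1912 .')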
except FileExistsError:
  print("NLTK averaged_perceptron_tagger exists")

if embedding_type == "Normal" or embedding_type == "Sentence":
  vocab_file = "gs://lasertagger_training_yechen/cased_L-12_H-768_A-12/vocab.txt"
elif embedding_type == "POS":
  vocab_file = "gs://bert_traning_yechen/trained_bert_uncased/bert_POS/vocab.txt"
elif embedding_type == "POS_concise":
  vocab_file = "gs://bert_traning_yechen/trained_bert_uncased/bert_POS_concise/vocab.txt"
else:
  raise ValueError("Unrecognized embedding type")

label_map = utils.read_label_map(label_map_file)
converter = tagging_converter.TaggingConverter(
    tagging_converter.get_phrase_vocabulary_from_label_map(label_map), True)
id_2_tag = {tag_id: tagging.Tag(tag) for tag, tag_id in label_map.items()}
builder = bert_example.BertExampleBuilder(label_map, vocab_file, 128,
                                          do_lower_case, converter,
                                          embedding_type, enable_masking)

grammar_vocab_file = "gs://publicly_available_models_yechen/grammar_checker/vocab.txt"
grammar_builder = bert_example_classifier.BertGrammarExampleBuilder(
    grammar_vocab_file, 128, False)


def predict_json(project, model, instances, version=None):
  """Sends a json object to a GCP-deployed model for prediction.

  Args:
    project: Name of the project where the model is deployed.
    model: The name of the deployed model.
def test_wrong_number_of_tags(self):
  input_texts = ['1 2']
  tags = [tagging.Tag('KEEP')]
  task = tagging.EditingTask(input_texts)
  with self.assertRaises(ValueError):
    task.realize_output(tags)