def _compute_single_tag(
        self, source_token, target_token_idx, target_tokens):
    """Computes a single tag.

    The tag may match multiple target tokens (via tag.added_phrase) so we
    return the next unmatched target token.

    Args:
        source_token: The token to be tagged.
        target_token_idx: Index of the current target token.
        target_tokens: List of all target tokens.

    Returns:
        A tuple with (1) the computed tag and (2) the next target_token_idx.
    """
    source_token = source_token.lower()
    target_token = target_tokens[target_token_idx].lower()
    if source_token == target_token:
        return tagging.Tag('KEEP'), target_token_idx + 1

    added_phrase = ''
    for num_added_tokens in range(1, self._max_added_phrase_length + 1):
        if target_token not in self._token_vocabulary:
            break
        added_phrase += (' ' if added_phrase else '') + target_token
        next_target_token_idx = target_token_idx + num_added_tokens
        if next_target_token_idx >= len(target_tokens):
            break
        target_token = target_tokens[next_target_token_idx].lower()
        if (source_token == target_token
                and added_phrase in self._phrase_vocabulary):
            return (tagging.Tag('KEEP|' + added_phrase),
                    next_target_token_idx + 1)
    return tagging.Tag('DELETE'), target_token_idx
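# A minimal, self-contained sketch (not part of the original module) of the
# greedy phrase search above. It assumes a toy phrase vocabulary and omits the
# per-token vocabulary check, but shows how a KEEP tag can absorb the target
# tokens that precede the next match of the source token.
def _toy_single_tag(source_token, target_token_idx, target_tokens,
                    phrase_vocabulary, max_added_phrase_length=3):
    source_token = source_token.lower()
    target_token = target_tokens[target_token_idx].lower()
    if source_token == target_token:
        return 'KEEP', target_token_idx + 1
    added_phrase = ''
    for num_added_tokens in range(1, max_added_phrase_length + 1):
        added_phrase += (' ' if added_phrase else '') + target_token
        next_idx = target_token_idx + num_added_tokens
        if next_idx >= len(target_tokens):
            break
        target_token = target_tokens[next_idx].lower()
        if source_token == target_token and added_phrase in phrase_vocabulary:
            return 'KEEP|' + added_phrase, next_idx + 1
    return 'DELETE', target_token_idx

# 'turing' only matches after "what is" is inserted before it:
print(_toy_single_tag('turing', 0, ['what', 'is', 'turing', '?'],
                      {'what is'}))  # -> ('KEEP|what is', 3)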
def _compute_tags_fixed_order(self, source_tokens, target_tokens):
    """Computes tags when the order of sources is fixed.

    Args:
        source_tokens: List of source tokens.
        target_tokens: List of tokens to be obtained via edit operations.

    Returns:
        List of tagging.Tag objects. If the source couldn't be converted
        into the target via tagging, returns an empty list.
    """
    tags = [tagging.Tag('DELETE') for _ in source_tokens]
    # Indices of the tokens currently being processed.
    source_token_idx = 0
    target_token_idx = 0
    while target_token_idx < len(target_tokens):
        tags[source_token_idx], target_token_idx = self._compute_single_tag(
            source_tokens[source_token_idx], target_token_idx, target_tokens)
        # If we're adding a phrase and the previous source token(s) were
        # deleted, we could add the phrase before a previously deleted token
        # and still get the same realized output. For example:
        #     [DELETE, DELETE, KEEP|"what is"]
        # and
        #     [DELETE|"what is", DELETE, KEEP]
        # would yield the same realized output. Experimentally, we noticed
        # that the model works better / the learning task becomes easier when
        # phrases are always added before the first deleted token. Also note
        # that in the current implementation, this way of moving the added
        # phrase backward is the only way a DELETE tag can have an added
        # phrase, so sequences like [DELETE|"What", DELETE|"is"] will never
        # be created.
        if tags[source_token_idx].added_phrase:
            first_deletion_idx = self._find_first_deletion_idx(
                source_token_idx, tags)
            if first_deletion_idx != source_token_idx:
                tags[first_deletion_idx].added_phrase = (
                    tags[source_token_idx].added_phrase)
                tags[source_token_idx].added_phrase = ''
        source_token_idx += 1
        if source_token_idx >= len(tags):
            break

    # If all target tokens have been consumed, we have found a conversion and
    # can return the tags. Note that if there are remaining source tokens,
    # they are already marked deleted when initializing the tag list.
    if target_token_idx >= len(target_tokens):
        return tags
    return []
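# A minimal sketch (not part of the original module) of the equivalence noted
# in the comment above: moving an added phrase back to the first deleted token
# does not change the realized text. The tiny realizer below is hypothetical
# and only illustrates the idea; tags are modeled as (operation, phrase) pairs.
def _toy_realize(source_tokens, tags):
    output = []
    for token, (op, phrase) in zip(source_tokens, tags):
        if phrase:
            output.append(phrase)
        if op == 'KEEP':
            output.append(token)
    return ' '.join(output)

source = ['hey', 'there', 'turing']
tagging_a = [('DELETE', ''), ('DELETE', ''), ('KEEP', 'what is')]
tagging_b = [('DELETE', 'what is'), ('DELETE', ''), ('KEEP', '')]
assert _toy_realize(source, tagging_a) == _toy_realize(source, tagging_b)
print(_toy_realize(source, tagging_b))  # -> "what is turing"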
def get_phrase_vocabulary_from_label_map(label_map):
    """Extracts the set of all phrases from the label map.

    Args:
        label_map: Mapping from tags to tag IDs.

    Returns:
        Set of all phrases appearing in the label map.
    """
    phrase_vocabulary = set()
    for label in label_map.keys():
        tag = tagging.Tag(label)
        if tag.added_phrase:
            phrase_vocabulary.add(tag.added_phrase)
    return phrase_vocabulary
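# Example usage, assuming label strings in the 'TAG' / 'TAG|phrase' format
# parsed by tagging.Tag; the label map below is hypothetical.
label_map = {'KEEP': 0, 'DELETE': 1, 'KEEP|and': 2, 'DELETE|, but': 3}
print(sorted(get_phrase_vocabulary_from_label_map(label_map)))
# -> [', but', 'and']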
"max_steps": max_steps, "lr": args.lr, "weight_decay": args.weight_decay, }, batches_per_step=args.iter_per_step, ) elif args.command == 'infer': tensors_pred = create_pipeline(test_examples, args.batch_size, mode="infer") computed_tensors = nf.infer(tensors=tensors_pred, checkpoint_dir=args.work_dir) id_2_tag = { tag_id: tagging.Tag(tag) for tag, tag_id in label_map.items() } results = [] for i in computed_tensors[0]: if args.use_t2t_decoder: results.extend((i[:, 1:]).cpu().numpy().tolist()) else: results.extend( torch.argmax(i, dim=-1).int().cpu().numpy().tolist()) # compute and realize predictions with LaserTagger sources, predictions, target_lists = [], [], [] logging.info("Saving predictions to " + args.work_dir + "/pred.txt") with open(args.work_dir + "/pred.txt", 'w') as f:
def build_bert_example(
    self,
    sources,
    target=None,
    use_arbitrary_target_ids_for_infeasible_examples=False,
    save_tokens=True,
    infer=False,
):
    """Constructs a BERT Example.

    Args:
        sources: List of source texts.
        target: Target text or None when building an example during
            inference.
        use_arbitrary_target_ids_for_infeasible_examples: Whether to build an
            example with arbitrary target ids even if the target can't be
            obtained via tagging.
        save_tokens: Whether to record out-of-vocabulary source tokens.
        infer: Whether the example is built for inference, in which case the
            target is not converted to tags.

    Returns:
        BertExample, or None if the conversion from text to tags was
        infeasible and use_arbitrary_target_ids_for_infeasible_examples ==
        False.
    """
    # Compute target labels.
    task = tagging.EditingTask(sources)
    if (target is not None) and (not infer):
        tags = self._converter.compute_tags(task, target)
        if not tags:
            if use_arbitrary_target_ids_for_infeasible_examples:
                # Create a tag sequence [KEEP, DELETE, KEEP, DELETE, ...]
                # which is unlikely to be predicted by chance.
                tags = [
                    tagging.Tag('KEEP') if i % 2 == 0 else tagging.Tag('DELETE')
                    for i, _ in enumerate(task.source_tokens)
                ]
            else:
                return None
    else:
        # If no target is provided or we are in inference mode, set all
        # target labels to KEEP.
        tags = [tagging.Tag('KEEP') for _ in task.source_tokens]
    labels = [self._label_map[str(tag)] for tag in tags]

    tokens, labels, token_start_indices = self._split_to_wordpieces(
        task.source_tokens, labels)

    tokens = self._truncate_list(tokens)
    labels = self._truncate_list(labels)

    input_tokens = ['[CLS]'] + tokens + ['[SEP]']
    labels_mask = [0] + [1] * len(labels) + [0]
    labels = [0] + labels + [0]

    input_ids = self._tokenizer.tokens_to_ids(input_tokens)
    input_mask = [1] * len(input_ids)
    segment_ids = [0] * len(input_ids)

    tgt_ids = self._truncate_list(self._tokenizer.text_to_ids(target))
    tgt_ids = [self._tokenizer.bos_id] + tgt_ids + [self._tokenizer.eos_id]

    if save_tokens:
        for t in task.source_tokens:
            # Check for out-of-vocabulary tokens and record them; id 100
            # corresponds to [UNK] in the standard BERT vocabulary.
            if self._tokenizer.token_to_id(t) == 100:
                self._task_tokens[t] = None

    example = BertExample(
        input_ids=input_ids,
        input_mask=input_mask,
        segment_ids=segment_ids,
        tgt_ids=tgt_ids,
        labels=labels,
        labels_mask=labels_mask,
        token_start_indices=token_start_indices,
        task=task,
        default_label=self._keep_tag_id,
    )
    example.pad_to_max_length(self._max_seq_length, self._pad_id)
    return example
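# A minimal sketch (not part of the original module) of how the [CLS]/[SEP]
# framing above lines up labels with labels_mask, using hypothetical
# wordpieces and tag ids (0 = KEEP, 1 = DELETE).
tokens = ['what', 'is', 'tur', '##ing']
labels = [1, 1, 0, 0]

input_tokens = ['[CLS]'] + tokens + ['[SEP]']
labels_mask = [0] + [1] * len(labels) + [0]  # special tokens excluded from the loss
labels = [0] + labels + [0]                  # placeholder labels for [CLS] / [SEP]

for tok, lab, mask in zip(input_tokens, labels, labels_mask):
    print(f'{tok:8s} label={lab} mask={mask}')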