示例#1
0
    def _compute_single_tag(self, source_token, target_token_idx,
                            target_tokens):
        """Tags one source token against the target token sequence.

        A plain KEEP is produced when the source token equals the current
        target token. Otherwise, up to ``self._max_added_phrase_length`` target
        tokens are collected into a candidate added phrase. If the target token
        right after that phrase equals the source token, and the phrase is in
        ``self._phrase_vocabulary``, the result is ``KEEP|<phrase>``. If neither
        holds, the source token is tagged DELETE and no target token is
        consumed. Both sides are lowercased before comparison.

        Args:
            source_token: The token to be tagged.
            target_token_idx: Index of the current (first unmatched) target token.
            target_tokens: List of all target tokens.

        Returns:
            A tuple with (1) the computed tag and (2) the index of the next
            unmatched target token.
        """
        lowered_source = source_token.lower()
        candidate = target_tokens[target_token_idx].lower()
        if candidate == lowered_source:
            return tagging.Tag('KEEP'), target_token_idx + 1

        phrase = ''
        total_targets = len(target_tokens)
        max_len = self._max_added_phrase_length
        for step in range(1, max_len + 1):
            # Every token of an added phrase must be in the token vocabulary;
            # stop growing the phrase as soon as one is not.
            if candidate not in self._token_vocabulary:
                break
            phrase = phrase + ' ' + candidate if phrase else candidate
            lookahead = target_token_idx + step
            if lookahead >= total_targets:
                break
            candidate = target_tokens[lookahead].lower()
            matched = (lowered_source == candidate and
                       phrase in self._phrase_vocabulary)
            if matched:
                # The phrase is inserted before the source token, which is
                # kept and consumes the target token at `lookahead`.
                return tagging.Tag('KEEP|' + phrase), lookahead + 1
        return tagging.Tag('DELETE'), target_token_idx
示例#2
0
    def _compute_tags_fixed_order(self, source_tokens, target_tokens):
        """Computes tags when the order of sources is fixed.

        Args:
            source_tokens: List of source tokens.
            target_tokens: List of tokens to be obtained via edit operations.

        Returns:
            List of tagging.Tag objects. If the source couldn't be converted into the
            target via tagging, returns an empty list.
        """
        tags = [tagging.Tag('DELETE') for _ in source_tokens]
        # Indices of the tokens currently being processed.
        source_token_idx = 0
        target_token_idx = 0
        while target_token_idx < len(target_tokens):
            tags[
                source_token_idx], target_token_idx = self._compute_single_tag(
                    source_tokens[source_token_idx], target_token_idx,
                    target_tokens)
            # If we're adding a phrase and the previous source token(s) were deleted,
            # we could add the phrase before a previously deleted token and still get
            # the same realized output. For example:
            #    [DELETE, DELETE, KEEP|"what is"]
            # and
            #    [DELETE|"what is", DELETE, KEEP]
            # Would yield the same realized output. Experimentally, we noticed that
            # the model works better / the learning task becomes easier when phrases
            # are always added before the first deleted token. Also note that in the
            # current implementation, this way of moving the added phrase backward is
            # the only way a DELETE tag can have an added phrase, so sequences like
            # [DELETE|"What", DELETE|"is"] will never be created.
            if tags[source_token_idx].added_phrase:
                first_deletion_idx = self._find_first_deletion_idx(
                    source_token_idx, tags)
                if first_deletion_idx != source_token_idx:
                    tags[first_deletion_idx].added_phrase = tags[
                        source_token_idx].added_phrase
                    tags[source_token_idx].added_phrase = ''
            source_token_idx += 1
            if source_token_idx >= len(tags):
                break

        # If all target tokens have been consumed, we have found a conversion and
        # can return the tags. Note that if there are remaining source tokens, they
        # are already marked deleted when initializing the tag list.
        if target_token_idx >= len(target_tokens):
            return tags
        return []
示例#3
0
def get_phrase_vocabulary_from_label_map(label_map):
    """Collects every added phrase that occurs in a label map.

    Each key of ``label_map`` is parsed as a ``tagging.Tag``. The non-empty
    ``added_phrase`` values are gathered into a set.

    Args:
        label_map: Mapping from tags to tag IDs.

    Returns:
        Set of all (non-empty) phrases appearing in the label map's tags.
    """
    parsed_tags = (tagging.Tag(label) for label in label_map.keys())
    return {tag.added_phrase for tag in parsed_tags if tag.added_phrase}
示例#4
0
                "max_steps": max_steps,
                "lr": args.lr,
                "weight_decay": args.weight_decay,
            },
            batches_per_step=args.iter_per_step,
        )

    elif args.command == 'infer':
        tensors_pred = create_pipeline(test_examples,
                                       args.batch_size,
                                       mode="infer")
        computed_tensors = nf.infer(tensors=tensors_pred,
                                    checkpoint_dir=args.work_dir)

        id_2_tag = {
            tag_id: tagging.Tag(tag)
            for tag, tag_id in label_map.items()
        }

        results = []
        for i in computed_tensors[0]:
            if args.use_t2t_decoder:
                results.extend((i[:, 1:]).cpu().numpy().tolist())
            else:
                results.extend(
                    torch.argmax(i, dim=-1).int().cpu().numpy().tolist())

        # compute and realize predictions with LaserTagger
        sources, predictions, target_lists = [], [], []
        logging.info("Saving predictions to " + args.work_dir + "/pred.txt")
        with open(args.work_dir + "/pred.txt", 'w') as f:
示例#5
0
文件: bert_example.py 项目: yoks/NeMo
    def build_bert_example(
        self,
        sources,
        target=None,
        use_arbitrary_target_ids_for_infeasible_examples=False,
        save_tokens=True,
        infer=False,
    ):
        """Constructs a BERT Example.

        Args:
            sources: List of source texts.
            target: Target text or None when building an example during inference.
            use_arbitrary_target_ids_for_infeasible_examples: Whether to build an
                example with arbitrary target ids even if the target can't be obtained
                via tagging.
            save_tokens: Whether to record, as keys of ``self._task_tokens``, the
                source tokens that the tokenizer maps to id 100 (treated here as
                the out-of-vocabulary id).
            infer: If True, tag labels are not computed from ``target``; every
                source token is labeled KEEP instead.

        Returns:
            BertExample, or None if the conversion from text to tags was infeasible
            and use_arbitrary_target_ids_for_infeasible_examples == False.
        """
        # Compute target labels.
        task = tagging.EditingTask(sources)
        if (target is not None) and (not infer):
            tags = self._converter.compute_tags(task, target)
            if not tags:
                if use_arbitrary_target_ids_for_infeasible_examples:
                    # Create a tag sequence [KEEP, DELETE, KEEP, DELETE, ...] which is
                    # unlikely to be predicted by chance.
                    tags = [
                        tagging.Tag('KEEP') if i % 2 == 0 else tagging.Tag('DELETE')
                        for i, _ in enumerate(task.source_tokens)
                    ]
                else:
                    return None
        else:
            # Without a target, or at inference time, all labels are KEEP.
            tags = [tagging.Tag('KEEP') for _ in task.source_tokens]
        labels = [self._label_map[str(tag)] for tag in tags]
        tokens, labels, token_start_indices = self._split_to_wordpieces(
            task.source_tokens, labels)

        tokens = self._truncate_list(tokens)
        labels = self._truncate_list(labels)

        # Wrap the wordpieces in [CLS]/[SEP]; those two positions get label 0
        # and are masked out of the labels mask.
        input_tokens = ['[CLS]'] + tokens + ['[SEP]']
        labels_mask = [0] + [1] * len(labels) + [0]
        labels = [0] + labels + [0]

        input_ids = self._tokenizer.tokens_to_ids(input_tokens)
        input_mask = [1] * len(input_ids)
        segment_ids = [0] * len(input_ids)

        # Decoder targets are the (truncated) target ids wrapped in BOS/EOS.
        # `target` may be None (see docstring), so it is not handed to the
        # tokenizer in that case; only BOS/EOS are emitted.
        if target is None:
            target_body = []
        else:
            target_body = self._truncate_list(self._tokenizer.text_to_ids(target))
        tgt_ids = [self._tokenizer.bos_id] + target_body + [self._tokenizer.eos_id]

        if save_tokens:
            for token in task.source_tokens:
                # Check for out-of-vocabulary tokens and save them. Id 100 is
                # presumably the [UNK] id of the BERT vocabulary — TODO confirm
                # against the tokenizer in use.
                if self._tokenizer.token_to_id(token) == 100:
                    self._task_tokens[token] = None

        example = BertExample(
            input_ids=input_ids,
            input_mask=input_mask,
            segment_ids=segment_ids,
            tgt_ids=tgt_ids,
            labels=labels,
            labels_mask=labels_mask,
            token_start_indices=token_start_indices,
            task=task,
            default_label=self._keep_tag_id,
        )
        example.pad_to_max_length(self._max_seq_length, self._pad_id)
        return example