Пример #1
0
    def test_get_class(self):
        module_path = 'forte.processors.lowercaser_processor'

        # A class defined in one of the listed modules is resolved by name.
        found = utils.get_class('LowerCaserProcessor', [module_path])
        self.assertEqual(found.__name__, 'LowerCaserProcessor')

        # An unknown class name raises ValueError, both without a module
        # search list and with one.
        for extra_args in ((), ([module_path],)):
            with self.assertRaises(ValueError):
                utils.get_class('NonExistProcessor', *extra_args)
Пример #2
0
    def test_get_class(self):
        module_path = "forte.processors.misc.lowercaser_processor"

        # Resolving a known class from an explicit module list succeeds.
        resolved = utils.get_class("LowerCaserProcessor", [module_path])
        self.assertEqual(resolved.__name__, "LowerCaserProcessor")

        # Looking up a class that does not exist must raise ValueError,
        # with and without the module search list.
        for call_args in (("NonExistProcessor",),
                          ("NonExistProcessor", [module_path])):
            with self.assertRaises(ValueError):
                utils.get_class(*call_args)
Пример #3
0
    def _augment(self, input_pack: MultiPack, aug_pack_names: List[str]):
        r"""
        Apply replacement-based data augmentation in place.

        The configured data augmentation op is instantiated once. Every
        entry of the configured ``augment_entry`` type, in each named pack,
        is then handed to ``self._replace`` together with that op. Only
        replacement-based methods are covered here. Subclasses implementing
        other augmentation strategies (such as Easy Data Augmentation)
        should override this method.

        Args:
            input_pack: The input MultiPack.
            aug_pack_names: The names of the DataPacks to be augmented.
        """
        op_name = self.configs["data_aug_op"]
        op_kwargs = {"configs": self.configs["data_aug_op_config"]["kwargs"]}
        replacement_op = create_class_with_kwargs(op_name, class_args=op_kwargs)
        augment_entry = get_class(self.configs["augment_entry"])

        # Packs are fetched lazily, one at a time, as the loop advances.
        target_packs = (input_pack.get_pack(name) for name in aug_pack_names)
        for data_pack in target_packs:
            for anno in data_pack.get(augment_entry):
                self._replace(replacement_op, anno)
Пример #4
0
    def _augment(self, input_pack: MultiPack, aug_pack_names: List[str]):
        r"""
        Insert replacement words at random positions in each pack.

        Candidates are the non-stopword annotations of the configured
        ``augment_entry`` type. The step below is repeated
        ``ceil(alpha * n)`` times, where ``n`` is the number of candidates:

        - a random candidate is passed to the replacement op;
        - the returned text is inserted at a random position, which is
          either 0 or the end of some candidate.

        Text inserted at a position greater than 0 gets a leading space.
        Otherwise it gets a trailing space.

        Args:
            input_pack: The input MultiPack.
            aug_pack_names: The names of the DataPacks to be augmented.
        """
        op_name = self.configs["data_aug_op"]
        op_configs = self.configs["data_aug_op_config"]["kwargs"]
        replacement_op = create_class_with_kwargs(
            op_name, class_args={"configs": op_configs})
        augment_entry = get_class(self.configs["augment_entry"])

        for pack_name in aug_pack_names:
            data_pack: DataPack = input_pack.get_pack(pack_name)
            candidates = []
            insert_positions = [0]
            for anno in data_pack.get(augment_entry):
                if anno.text in self.stopwords:
                    continue
                candidates.append(anno)
                insert_positions.append(anno.end)
            if not candidates:
                continue

            num_insertions = ceil(self.configs['alpha'] * len(candidates))
            for _ in range(num_insertions):
                chosen = random.choice(candidates)
                _, new_word = replacement_op.replace(chosen)
                position = random.choice(insert_positions)
                padded = " " + new_word if position > 0 else new_word + " "
                self._insert(padded, data_pack, position)
Пример #5
0
 def _augment(self, input_pack: MultiPack, aug_pack_names: List[str]):
     r"""
     Randomly swap the texts of pairs of annotations in each pack.

     For every pack in ``aug_pack_names``, ``ceil(alpha * n)`` random swaps
     are performed over the ``n`` annotations of the configured
     ``augment_entry`` type. Swaps compose: ``replace_map`` maps a position
     index to the index of the annotation whose text currently occupies
     that position. The resulting ``(span, new_text)`` pairs are added to
     ``self._replaced_annos[pack_id]``.

     Packs with fewer than two annotations are left unchanged, since no swap
     is possible.

     Args:
         input_pack: The input MultiPack.
         aug_pack_names: The names of the DataPacks to be augmented.

     Raises:
         ValueError: If ``augment_entry`` is not a subclass of ``Annotation``.
     """
     augment_entry = get_class(self.configs["augment_entry"])
     if not issubclass(augment_entry, Annotation):
         raise ValueError(f"This augmenter only accept data of "
                          f"'forte.data.ontology.top.Annotation' type, "
                          f"but {self.configs['augment_entry']} is not.")
     for pack_name in aug_pack_names:
         data_pack: DataPack = input_pack.get_pack(pack_name)
         annotations: List[Annotation] = list(data_pack.get(augment_entry))
         # A swap needs two distinct annotations. With fewer than two,
         # random.sample(range(n), 2) would raise ValueError.
         if len(annotations) < 2:
             continue
         replace_map: Dict[int, int] = {}
         for _ in range(ceil(self.configs["alpha"] * len(annotations))):
             idx_a, idx_b = random.sample(range(len(annotations)), 2)
             # Read both current sources before writing, so the swap is
             # applied to the state left by the previous swaps.
             src_for_a = replace_map.get(idx_b, idx_b)
             src_for_b = replace_map.get(idx_a, idx_a)
             replace_map[idx_a] = src_for_a
             replace_map[idx_b] = src_for_b
         pid: int = data_pack.pack_id
         for idx, src_idx in replace_map.items():
             self._replaced_annos[pid].add((
                 annotations[idx].span,
                 annotations[src_idx].text,
             ))
Пример #6
0
    def _augment(self, input_pack: MultiPack, aug_pack_names: List[str]):
        r"""
        Randomly delete annotations from each named pack.

        Every annotation of the configured ``augment_entry`` type is
        independently passed to ``self._delete`` with probability ``alpha``.

        Args:
            input_pack: The input MultiPack.
            aug_pack_names: The names of the DataPacks to be augmented.
        """
        entry_type = get_class(self.configs["augment_entry"])

        for name in aug_pack_names:
            pack: DataPack = input_pack.get_pack(name)
            for anno in pack.get(entry_type):
                # Draw first, then compare, matching a per-annotation
                # Bernoulli(alpha) trial.
                should_delete = random.random() < self.configs['alpha']
                if should_delete:
                    self._delete(anno)
    def initialize(self, resources: Resources, configs: Config):
        """
        Set up the tokenizer, the encoder and the target entry type.

        The tokenizer and encoder factories are looked up by
        ``configs.pretrained_model_name`` in ``self.name2tokenizer`` and
        ``self.name2encoder`` respectively.

        Raises:
            ValueError: If the model name is not in ``self.name2tokenizer``,
                or if ``configs.entry_type`` does not resolve to an
                Annotation type.
        """
        model_name = configs.pretrained_model_name
        if model_name not in self.name2tokenizer:
            raise ValueError("Unrecognized pre-trained model name.")

        make_tokenizer = self.name2tokenizer[model_name]
        self.tokenizer = make_tokenizer(pretrained_model_name=model_name)
        make_encoder = self.name2encoder[model_name]
        self.encoder = make_encoder(pretrained_model_name=model_name)

        self.entry_type = get_class(configs.entry_type)
        is_annotation = (isinstance(self.entry_type, Annotation)
                         or issubclass(self.entry_type, Annotation))
        if not is_annotation:
            raise ValueError("entry_type must be annotation type.")
Пример #8
0
    def _augment(self, input_pack: MultiPack, aug_pack_names: List[str]):
        r"""
        This function splits a given word at a random position and replaces
        the original word with 2 split parts of it.

        Each pack in ``aug_pack_names`` is processed as follows:

        - Let ``n`` be the number of annotations of the configured
          ``augment_entry`` type in the pack. ``ceil(alpha * n)`` of them
          are chosen at random and split.
        - Only annotations spanning at least 2 characters can be split
          into two non-empty parts. The count is therefore capped at the
          number of such annotations.

        Args:
            input_pack: The input MultiPack.
            aug_pack_names: The names of the DataPacks to be augmented.
        """
        augment_entry = get_class(self.configs["augment_entry"])

        for pack_name in aug_pack_names:
            data_pack: DataPack = input_pack.get_pack(pack_name)
            annotations: List[Annotation] = list(data_pack.get(augment_entry))
            if not annotations:
                continue
            endings: List[int] = [anno.end for anno in annotations]

            # (annotation, index) pairs that can be split into two
            # non-empty parts.
            candidates = [
                (anno, idx)
                for idx, anno in enumerate(annotations)
                if (anno.end - anno.begin) > 1
            ]
            # random.sample raises ValueError when asked for more items than
            # the population holds. Never request more than the number of
            # splittable candidates.
            num_to_split = min(
                ceil(self.configs["alpha"] * len(annotations)),
                len(candidates),
            )
            # Process the chosen annotations from the last one to the first.
            annotation_to_split = sorted(
                random.sample(candidates, num_to_split),
                key=lambda pair: pair[1],
                reverse=True,
            )
            for src_anno, src_idx in annotation_to_split:
                splitting_position = random.randrange(
                    1, (src_anno.end - src_anno.begin))
                first_part = src_anno.text[:splitting_position]
                # The second part is separated from the first by a space.
                second_part = " " + src_anno.text[splitting_position:]
                # NOTE(review): ``endings[src_idx - 1] + 1`` only equals
                # ``src_anno.begin`` when exactly one character separates
                # consecutive annotations. It does not hold, e.g., for a
                # token glued to the previous one, or for a first annotation
                # that does not start at 0. Confirm this is intended.
                first_position = (
                    endings[src_idx - 1] + 1 if src_idx != 0 else 0
                )
                second_position = endings[src_idx]

                # Insert the tail after the word, delete the word, then
                # insert the head at the word's start position.
                self._insert(second_part, data_pack, second_position)
                self._delete(src_anno)
                self._insert(first_part, data_pack, first_position)
Пример #9
0
 def _augment(self, input_pack: MultiPack, aug_pack_names: List[str]):
     r"""
     Randomly swap the texts of pairs of annotations in each pack.

     For every pack in ``aug_pack_names``, ``ceil(alpha * n)`` random swaps
     are performed over the ``n`` annotations of the configured
     ``augment_entry`` type. Swaps compose: ``replace_map`` maps a position
     index to the index of the annotation whose text currently occupies
     that position. The resulting ``(span, new_text)`` pairs are added to
     ``self._replaced_annos[pack_id]``.

     Packs with fewer than two annotations are left unchanged, since no swap
     is possible.

     Args:
         input_pack: The input MultiPack.
         aug_pack_names: The names of the DataPacks to be augmented.
     """
     augment_entry = get_class(self.configs["augment_entry"])
     for pack_name in aug_pack_names:
         data_pack: DataPack = input_pack.get_pack(pack_name)
         annotations = list(data_pack.get(augment_entry))
         # A swap needs two distinct annotations. With fewer than two,
         # random.sample(range(n), 2) would raise ValueError.
         if len(annotations) < 2:
             continue
         replace_map: Dict[int, int] = {}
         for _ in range(ceil(self.configs['alpha'] * len(annotations))):
             idx_a, idx_b = random.sample(range(len(annotations)), 2)
             # Read both current sources before writing, so the swap is
             # applied to the state left by the previous swaps.
             src_for_a = replace_map.get(idx_b, idx_b)
             src_for_b = replace_map.get(idx_a, idx_a)
             replace_map[idx_a] = src_for_a
             replace_map[idx_b] = src_for_b
         pid: int = data_pack.pack_id
         for idx, src_idx in replace_map.items():
             self._replaced_annos[pid].add(
                 (annotations[idx].span, annotations[src_idx].text))
Пример #10
0
    def _auto_align_annotations(
        self,
        data_pack: DataPack,
        replaced_annotations: SortedList,
    ) -> DataPack:
        r"""
        Function to replace some annotations with new strings.
        It will copy and update the text of datapack and
        auto-align the annotation spans.

        The links are also copied if its parent & child are
        both present in the new pack.

        The groups are copied if all its members are present
        in the new pack.

        Inserted annotations (from ``self._inserted_annos_pos_len``) are
        added to the new pack. Annotations whose tid is in
        ``self._deleted_annos_id`` are skipped. Both of these apply only
        to the configured ``augment_entry`` type.

        On the main path this method also records the original-to-new
        pack id mapping in ``self._data_pack_map`` and the original-to-new
        entry tid mapping in ``self._entry_maps``.

        Args:
            data_pack: The Datapack holding the replaced annotations.
            replaced_annotations: A SortedList of tuples(span, new string).
                The text and span of the annotations will be updated
                with the new string.

        Returns:
            A new data_pack holding the text after replacement. The
            annotations in the original data pack will be copied and
            auto-aligned as instructed by the "other_entry_policy" in the
            configuration. The links and groups will be copied if their
            members are copied.
        """
        # Nothing was replaced: return an independent copy of the pack.
        # NOTE(review): this early return skips the insertion/deletion
        # handling below and does not update self._data_pack_map or
        # self._entry_maps. Presumably insertions and deletions are also
        # recorded in replaced_annotations, so this path is only hit for an
        # untouched pack. Confirm.
        if len(replaced_annotations) == 0:
            return deepcopy(data_pack)

        # Split the (span, replacement) pairs into two parallel lists.
        # The text rebuild below assumes the spans are in ascending order
        # and non-overlapping.
        spans: List[Span] = [span for span, _ in replaced_annotations]
        replacement_strs: List[str] = [
            replacement_str for _, replacement_str in replaced_annotations
        ]

        # Get the new text for the new data pack.
        new_text: str = ""
        for i, span in enumerate(spans):
            new_span_str = replacement_strs[i]
            # First, get the gap text between last and this span.
            last_span_end: int = spans[i - 1].end if i > 0 else 0
            gap_text: str = data_pack.text[last_span_end:span.begin]
            new_text += gap_text
            # Then, append the replaced new text.
            new_text += new_span_str
        # Finally, append to new_text the text after the last span.
        new_text += data_pack.text[spans[-1].end:]

        # Get the span (begin, end) before and after replacement.
        # new_spans[i] is the position of replacement_strs[i] in new_text.
        new_spans: List[Span] = []

        # Bias is the delta between the beginning
        # indices before & after replacement.
        bias: int = 0
        for i, span in enumerate(spans):
            old_begin: int = spans[i].begin
            old_end: int = spans[i].end
            new_begin: int = old_begin + bias
            new_end = new_begin + len(replacement_strs[i])
            new_spans.append(Span(new_begin, new_end))
            # Cumulative shift after this span: applies to every later span.
            bias = new_end - old_end

        new_pack: DataPack = DataPack()
        new_pack.set_text(new_text)

        # Entry types (as class-name strings) to copy: every type that has an
        # "other entry" policy, plus the augmented entry type itself.
        entries_to_copy: List[str] = \
            list(self._other_entry_policy.keys()) + \
            [self.configs['augment_entry']]

        # Maps original entry tid -> tid of its copy in new_pack; used below
        # to re-link links and groups.
        entry_map: Dict[int, int] = {}
        # Cursor into inserted_annos for the merge below.
        insert_ind: int = 0
        pid: int = data_pack.pack_id

        # (position, length) pairs of inserted texts for this pack.
        # NOTE(review): the 2-way merge below assumes these are ordered by
        # position. Presumably _inserted_annos_pos_len[pid] is a sorted
        # mapping. Confirm.
        inserted_annos: List[Tuple[int, int]] = list(
            self._inserted_annos_pos_len[pid].items())

        def _insert_new_span(insert_ind: int, inserted_annos: List[Tuple[int,
                                                                         int]],
                             new_pack: DataPack, spans: List[Span],
                             new_spans: List[Span]):
            r"""
            An internal helper function for insertion.

            Creates one annotation in ``new_pack`` covering the text inserted
            at ``inserted_annos[insert_ind]``. The end of that annotation is
            the inserted position mapped into the new text, inclusive of the
            inserted span itself. Its begin is that end minus the inserted
            length.

            The annotation class is taken from ``entry`` in the enclosing
            scope, read at call time. Every call site runs only while
            ``entry`` is the configured augment_entry.
            """
            pos: int
            length: int
            pos, length = inserted_annos[insert_ind]
            insert_end: int = modify_index(
                pos,
                spans,
                new_spans,
                is_begin=False,
                # Include the inserted span itself.
                is_inclusive=True)
            insert_begin: int = insert_end - length
            new_anno = create_class_with_kwargs(entry, {
                "pack": new_pack,
                "begin": insert_begin,
                "end": insert_end
            })
            new_pack.add_entry(new_anno)

        # Iterate over all the original entries and modify their spans.
        for entry in entries_to_copy:
            for orig_anno in data_pack.get(get_class(entry)):
                # Dealing with insertion/deletion only for augment_entry.
                if entry == self.configs['augment_entry']:
                    while insert_ind < len(inserted_annos) and \
                            inserted_annos[insert_ind][0] <= orig_anno.begin:
                        # Preserve the order of the spans with merging sort.
                        # It is a 2-way merging from the inserted spans
                        # and original spans based on the begin index.
                        _insert_new_span(insert_ind, inserted_annos, new_pack,
                                         spans, new_spans)
                        insert_ind += 1

                    # Deletion: deleted annotations are not copied, so they
                    # are also absent from entry_map.
                    if orig_anno.tid in self._deleted_annos_id[pid]:
                        continue

                # Auto align the spans. Entries that are not auto-aligned
                # keep their original begin/end offsets.
                span_new_begin: int = orig_anno.begin
                span_new_end: int = orig_anno.end

                if entry == self.configs['augment_entry'] \
                        or self._other_entry_policy[entry] \
                        == 'auto_align':
                    # Only inclusive when the entry is not augmented.
                    # E.g.: A Sentence include the inserted Token on the edge.
                    # E.g.: A Token shouldn't include a nearby inserted Token.
                    is_inclusive = entry != self.configs['augment_entry']
                    span_new_begin = modify_index(orig_anno.begin, spans,
                                                  new_spans, True,
                                                  is_inclusive)
                    span_new_end = modify_index(orig_anno.end, spans,
                                                new_spans, False, is_inclusive)

                new_anno = create_class_with_kwargs(entry, {
                    "pack": new_pack,
                    "begin": span_new_begin,
                    "end": span_new_end
                })
                new_pack.add_entry(new_anno)
                entry_map[orig_anno.tid] = new_anno.tid

            # Deal with spans after the last annotation in the original pack.
            if entry == self.configs['augment_entry']:
                while insert_ind < len(inserted_annos):
                    _insert_new_span(insert_ind, inserted_annos, new_pack,
                                     spans, new_spans)
                    insert_ind += 1

        # Iterate over and copy the links/groups in the datapack.
        for link in data_pack.get(Link):
            self._copy_link_or_group(link, entry_map, new_pack)
        for group in data_pack.get(Group):
            self._copy_link_or_group(group, entry_map, new_pack)

        # Record the original -> new pack id and entry tid mappings.
        self._data_pack_map[pid] = new_pack.pack_id
        self._entry_maps[pid] = entry_map
        return new_pack