예제 #1
0
    def _augment(self, input_pack: MultiPack, aug_pack_names: List[str]):
        r"""
        Insert augmented texts at randomly chosen positions of the
        requested packs.

        For every pack, the annotations of the configured ``augment_entry``
        type whose text is not in ``self.stopwords`` are collected. Then
        ``ceil(alpha * number_of_collected_annotations)`` rounds are run:
        a randomly chosen collected annotation is handed to the data
        augmentation op, and the text it returns is passed to
        ``self._insert`` together with a randomly chosen position (either
        ``0`` or the ``end`` of one of the collected annotations).

        Args:
            input_pack: The input MultiPack.
            aug_pack_names: The names of the DataPacks to be augmented.
        """
        op_name = self.configs["data_aug_op"]
        op_configs = self.configs["data_aug_op_config"]["kwargs"]
        replacement_op = create_class_with_kwargs(
            op_name, class_args={"configs": op_configs})
        entry_type = get_class(self.configs["augment_entry"])

        for pack_name in aug_pack_names:
            data_pack: DataPack = input_pack.get_pack(pack_name)

            # Only annotations that are not stopwords are used as sources.
            candidates = [
                anno for anno in data_pack.get(entry_type)
                if anno.text not in self.stopwords
            ]
            if not candidates:
                continue

            # Candidate positions: 0 plus the end offset of every candidate.
            insert_positions = [0]
            insert_positions.extend(anno.end for anno in candidates)

            num_rounds = ceil(self.configs['alpha'] * len(candidates))
            for _ in range(num_rounds):
                source_anno = random.choice(candidates)
                _, new_text = replacement_op.replace(source_anno)
                insert_pos = random.choice(insert_positions)
                # The separating space is prepended, except at position 0
                # where it is appended instead.
                new_text = (
                    " " + new_text if insert_pos > 0 else new_text + " "
                )
                self._insert(new_text, data_pack, insert_pos)
예제 #2
0
    def _augment(self, input_pack: MultiPack, aug_pack_names: List[str]):
        r"""
        Run the configured data augmentation op over the requested packs,
        modifying the input in-place.

        Every annotation of the configured ``augment_entry`` type is handed
        to ``self._replace`` together with the op. This implementation only
        applies to replacement-based methods; subclasses should override
        it to implement other data augmentation methods, such as
        Easy Data Augmentation.

        Args:
            input_pack: The input MultiPack.
            aug_pack_names: The names of the DataPacks to be augmented.
        """
        op_name = self.configs["data_aug_op"]
        op_kwargs = {"configs": self.configs["data_aug_op_config"]["kwargs"]}
        replacement_op = create_class_with_kwargs(
            op_name, class_args=op_kwargs)
        entry_type = get_class(self.configs["augment_entry"])

        for name in aug_pack_names:
            pack: DataPack = input_pack.get_pack(name)
            annotation: Annotation
            for annotation in pack.get(entry_type):
                self._replace(replacement_op, annotation)
예제 #3
0
    def cofigure_sampler(self) -> None:
        r"""
        This function sets the sampler that will be used by the
        distribution replacement op. The sampler will be set according to
        the configuration values.

        The sampler data is read from ``sampler_config.kwargs``: when a
        ``data_path`` is present it is first requested as a URL and, if
        that request fails, opened as a local JSON file instead; otherwise
        the inline ``sampler_data`` value is used. The sampler itself is
        created from ``sampler_config.type``.

        A missing configuration key is reported on stdout, in which case
        ``self.sampler`` is not assigned by this call.
        """
        # Per-request timeout (seconds) so an unresponsive host cannot
        # block the configuration forever.
        request_timeout = 30

        try:
            if "data_path" in self.configs["sampler_config"]["kwargs"].keys():
                distribution_path = self.configs["sampler_config"]["kwargs"][
                    "data_path"
                ]
                try:
                    r = requests.get(
                        distribution_path, timeout=request_timeout
                    )
                    # Treat HTTP error responses as a failed request rather
                    # than parsing an error payload as sampler data.
                    r.raise_for_status()
                    data = r.json()
                except requests.exceptions.RequestException:
                    # Not reachable as a URL: fall back to a local file.
                    with open(distribution_path, encoding="utf8") as json_file:
                        data = json.load(json_file)
            else:
                data = self.configs["sampler_config"]["kwargs"]["sampler_data"]

            self.sampler = create_class_with_kwargs(
                self.configs["sampler_config"]["type"],
                {
                    "configs": {
                        "sampler_data": data,
                    }
                },
            )
        except KeyError as error:
            print("Could not configure Sampler: " + repr(error))
예제 #4
0
 def __init__(self, configs: Config):
     r"""
     Create the forward and backward models from the configuration.

     ``model_to`` is created with the configured source and target
     languages, ``model_back`` with the two languages swapped. Both
     receive the configured device.

     Args:
         configs: The configuration. It is passed to the parent
             constructor and to ``self._validate_configs`` before any
             model is created.
     """
     super().__init__(configs)
     self._validate_configs(configs)

     forward_model_name = configs["model_to"]
     source_language = configs["src_language"]
     target_language = configs["tgt_language"]
     device = configs["device"]

     forward_args = {
         "src_lang": source_language,
         "tgt_lang": target_language,
         "device": device,
     }
     self.model_to = create_class_with_kwargs(
         forward_model_name, class_args=forward_args
     )

     # The backward model translates in the opposite direction.
     backward_args = {
         "src_lang": target_language,
         "tgt_lang": source_language,
         "device": device,
     }
     self.model_back = create_class_with_kwargs(
         configs["model_back"], class_args=backward_args
     )
예제 #5
0
    def test_create_class_with_kwargs(self):
        """Check that an object created from a full class name, with no
        constructor arguments, exposes that same name as ``name``."""
        full_class_name = (
            'forte.processors.lowercaser_processor'
            '.LowerCaserProcessor'
        )
        processor = utils.create_class_with_kwargs(
            class_name=full_class_name, class_args={})

        self.assertEqual(processor.name, full_class_name)
예제 #6
0
    def test_create_class_with_kwargs(self):
        """Check that an object created from a full class name, with no
        constructor arguments, exposes that same name as ``name``."""
        full_class_name = (
            "forte.processors.misc.lowercaser_processor"
            ".LowerCaserProcessor"
        )
        processor = utils.create_class_with_kwargs(
            class_name=full_class_name,
            class_args={},
        )

        self.assertEqual(processor.name, full_class_name)
예제 #7
0
        def _insert_new_span(
            entry_class: str,
            insert_ind: int,
            inserted_annos: List[Tuple[int, int]],
            new_pack: DataPack,
            spans: List[Span],
            new_spans: List[Span],
        ):
            """
            Create the annotation for one inserted span and add it to the
            new data pack.

            Args:
                entry_class: The new annotation type to be created.
                insert_ind: The index into ``inserted_annos`` of the span
                  to be inserted.
                inserted_annos: The ``(position, length)`` pairs of the
                  inserted annotation spans.
                new_pack: The new data pack to insert the annotation.
                spans: The original spans before replacement, should be
                  a sorted ascending list.
                new_spans: The counterparts of ``spans`` handed to
                  ``modify_index`` (presumably the spans after
                  replacement, in the same order; confirm against the
                  caller).
            """
            anchor: int
            span_length: int
            anchor, span_length = inserted_annos[insert_ind]

            # Map the recorded position to the end of the span in the new
            # pack; the begin follows from the recorded length.
            end_index: int = modify_index(
                anchor,
                spans,
                new_spans,
                is_begin=False,
                # Include the inserted span itself.
                is_inclusive=True,
            )
            begin_index: int = end_index - span_length

            anno_kwargs = {
                "pack": new_pack,
                "begin": begin_index,
                "end": end_index,
            }
            inserted_anno = create_class_with_kwargs(entry_class, anno_kwargs)
            new_pack.add_entry(inserted_anno)
예제 #8
0
 def _insert_new_span(
     insert_ind: int,
     inserted_annos: List[Tuple[int, int]],
     new_pack: DataPack,
     spans: List[Span],
     new_spans: List[Span],
 ):
     r"""
     An internal helper that creates the annotation for one inserted span
     and adds it to ``new_pack``.

     The span is described by the ``(position, length)`` pair stored at
     ``inserted_annos[insert_ind]``.
     """
     # NOTE(review): ``entry`` is not a parameter; it is looked up in an
     # enclosing scope that is not part of this helper.
     anchor: int
     span_length: int
     anchor, span_length = inserted_annos[insert_ind]

     # Map the recorded position to the end of the span in the new pack;
     # the begin follows from the recorded length.
     end_index: int = modify_index(
         anchor,
         spans,
         new_spans,
         is_begin=False,
         # Include the inserted span itself.
         is_inclusive=True,
     )
     begin_index: int = end_index - span_length

     anno_kwargs = {
         "pack": new_pack,
         "begin": begin_index,
         "end": end_index,
     }
     inserted_anno = create_class_with_kwargs(entry, anno_kwargs)
     new_pack.add_entry(inserted_anno)
예제 #9
0
    def _auto_align_annotations(
        self,
        data_pack: DataPack,
        replaced_annotations: SortedList,
    ) -> DataPack:
        r"""
        Function to replace some annotations with new strings.
        It will copy and update the text of datapack and
        auto-align the annotation spans.

        The links are also copied if its parent & child are
        both present in the new pack.

        The groups are copied if all its members are present
        in the new pack.

        Args:
            data_pack: The Datapack holding the replaced annotations.
            replaced_annotations: A SortedList of tuples(span, new string).
                The text and span of the annotations will be updated
                with the new string.

        Returns:
            A new data_pack holds the text after replacement. The annotations
            in the original data pack will be copied and auto-aligned as
            instructed by the "other_entry_policy" in the configuration.
            The links and groups will be copied if their members are copied.
            When there is nothing to replace, a deep copy of the input
            pack is returned instead.
        """
        if len(replaced_annotations) == 0:
            return deepcopy(data_pack)

        spans: List[Span] = [span for span, _ in replaced_annotations]
        replacement_strs: List[str] = [
            replacement_str for _, replacement_str in replaced_annotations
        ]

        # Build the new text and the spans (begin, end) after replacement
        # in a single pass over the replaced spans.
        orig_text: str = data_pack.text
        text_pieces: List[str] = []
        new_spans: List[Span] = []
        last_span_end: int = 0
        # Bias is the delta between the indices before & after replacement.
        bias: int = 0
        for span, replacement_str in zip(spans, replacement_strs):
            # First, the gap text between the last and this span,
            # then the replaced new text.
            text_pieces.append(orig_text[last_span_end:span.begin])
            text_pieces.append(replacement_str)
            last_span_end = span.end

            new_begin: int = span.begin + bias
            new_end: int = new_begin + len(replacement_str)
            new_spans.append(Span(new_begin, new_end))
            bias = new_end - span.end
        # Finally, the text after the last span.
        text_pieces.append(orig_text[last_span_end:])
        new_text: str = "".join(text_pieces)

        new_pack: DataPack = DataPack()
        new_pack.set_text(new_text)

        augment_entry: str = self.configs['augment_entry']
        entries_to_copy: List[str] = \
            list(self._other_entry_policy.keys()) + [augment_entry]

        entry_map: Dict[int, int] = {}
        insert_ind: int = 0
        pid: int = data_pack.pack_id

        inserted_annos: List[Tuple[int, int]] = list(
            self._inserted_annos_pos_len[pid].items())

        def _insert_new_span(
            entry_class: str,
            insert_ind: int,
            inserted_annos: List[Tuple[int, int]],
            new_pack: DataPack,
            spans: List[Span],
            new_spans: List[Span],
        ) -> None:
            r"""
            An internal helper function for insertion.

            The entry type is passed explicitly rather than read from the
            surrounding loop variable, so the helper does not depend on
            the state of the loop at call time.

            Args:
                entry_class: The annotation type to be created.
                insert_ind: The index into ``inserted_annos``.
                inserted_annos: The ``(position, length)`` pairs of the
                    inserted spans.
                new_pack: The data pack receiving the new annotation.
                spans: The spans before replacement.
                new_spans: The spans after replacement.
            """
            pos: int
            length: int
            pos, length = inserted_annos[insert_ind]
            insert_end: int = modify_index(
                pos,
                spans,
                new_spans,
                is_begin=False,
                # Include the inserted span itself.
                is_inclusive=True)
            insert_begin: int = insert_end - length
            inserted_anno = create_class_with_kwargs(entry_class, {
                "pack": new_pack,
                "begin": insert_begin,
                "end": insert_end
            })
            new_pack.add_entry(inserted_anno)

        # Iterate over all the original entries and modify their spans.
        for entry in entries_to_copy:
            is_augment_entry: bool = entry == augment_entry
            for orig_anno in data_pack.get(get_class(entry)):
                # Dealing with insertion/deletion only for augment_entry.
                if is_augment_entry:
                    while insert_ind < len(inserted_annos) and \
                            inserted_annos[insert_ind][0] <= orig_anno.begin:
                        # Preserve the order of the spans with merging sort.
                        # It is a 2-way merging from the inserted spans
                        # and original spans based on the begin index.
                        _insert_new_span(entry, insert_ind, inserted_annos,
                                         new_pack, spans, new_spans)
                        insert_ind += 1

                    # Deletion
                    if orig_anno.tid in self._deleted_annos_id[pid]:
                        continue

                # Auto align the spans.
                span_new_begin: int = orig_anno.begin
                span_new_end: int = orig_anno.end

                if is_augment_entry \
                        or self._other_entry_policy[entry] == 'auto_align':
                    # Only inclusive when the entry is not augmented.
                    # E.g.: A Sentence include the inserted Token on the edge.
                    # E.g.: A Token shouldn't include a nearby inserted Token.
                    is_inclusive = not is_augment_entry
                    span_new_begin = modify_index(orig_anno.begin, spans,
                                                  new_spans, True,
                                                  is_inclusive)
                    span_new_end = modify_index(orig_anno.end, spans,
                                                new_spans, False, is_inclusive)

                new_anno = create_class_with_kwargs(entry, {
                    "pack": new_pack,
                    "begin": span_new_begin,
                    "end": span_new_end
                })
                new_pack.add_entry(new_anno)
                entry_map[orig_anno.tid] = new_anno.tid

            # Deal with spans after the last annotation in the original pack.
            if is_augment_entry:
                while insert_ind < len(inserted_annos):
                    _insert_new_span(entry, insert_ind, inserted_annos,
                                     new_pack, spans, new_spans)
                    insert_ind += 1

        # Iterate over and copy the links/groups in the datapack.
        for link in data_pack.get(Link):
            self._copy_link_or_group(link, entry_map, new_pack)
        for group in data_pack.get(Group):
            self._copy_link_or_group(group, entry_map, new_pack)

        # Record the mapping from the original pack/entries to the new ones.
        self._data_pack_map[pid] = new_pack.pack_id
        self._entry_maps[pid] = entry_map
        return new_pack
예제 #10
0
 def __init__(self, configs: Union[Config, Dict[str, Any]]):
     r"""
     Create the dictionary object named in the configuration.

     Args:
         configs: The configuration. Its ``dictionary_class`` value is
             passed to ``create_class_with_kwargs`` with empty
             ``class_args`` and the result is stored as
             ``self.dictionary``.
     """
     super().__init__(configs)
     dictionary_class = configs["dictionary_class"]
     self.dictionary = create_class_with_kwargs(
         dictionary_class, class_args={})
 def __init__(self, configs: Config):
     r"""
     Create the dictionary object named in the configuration.

     Args:
         configs: The configuration. Its ``dictionary_class`` value is
             passed to ``create_class_with_kwargs`` with empty
             ``class_args`` and the result is stored as
             ``self.dictionary``.
     """
     super().__init__(configs)
     dictionary_class = configs["dictionary_class"]
     self.dictionary = create_class_with_kwargs(
         dictionary_class, class_args={})