コード例 #1
0
    def test_pipeline(self, texts):
        """Translate each input file from German to English and check results.

        One ``<n>.txt`` file is written per entry of ``texts``. The pipeline
        reads them with ``MultiPackSentenceReader`` (into an ``input`` pack and
        an ``output`` pack) and runs ``MicrosoftBingTranslator`` on ``input``,
        writing the translation into a new ``result`` pack.

        ``process_dataset`` gives no guarantee about the order in which it
        visits the files in ``self.test_dir``. The translated texts are
        therefore compared against the expected ones as a multiset rather than
        by position. This also makes the test fail when the pipeline produces
        fewer (or more) packs than expected, instead of passing vacuously.
        """
        for idx, text in enumerate(texts):
            file_path = os.path.join(self.test_dir, f"{idx + 1}.txt")
            with open(file_path, 'w') as f:
                f.write(text)

        nlp = Pipeline()
        reader_config = HParams(
            {
                "input_pack_name": "input",
                "output_pack_name": "output"
            }, MultiPackSentenceReader.default_hparams())
        nlp.set_reader(reader=MultiPackSentenceReader(), config=reader_config)
        translator_config = HParams(
            {
                "src_language": "de",
                "target_language": "en",
                "in_pack_name": "input",
                "out_pack_name": "result"
            }, None)

        nlp.add_processor(MicrosoftBingTranslator(), config=translator_config)
        nlp.initialize()

        english_results = ["Hey good morning", "This is Forte. A tool for NLP"]
        translated = []
        for m_pack in nlp.process_dataset(self.test_dir):
            self.assertEqual(set(m_pack._pack_names),
                             {"input", "output", "result"})
            translated.append(m_pack.get_pack("result").text)

        # Each translated text is expected to end with a trailing newline.
        # Compare without relying on the (unspecified) file visiting order.
        self.assertCountEqual(
            translated, [result + "\n" for result in english_results])
コード例 #2
0
    def test_pipeline(self, texts):
        """Check that ``BertBasedQueryCreator`` attaches one query per pack.

        One ``<n>.txt`` file is written per entry of ``texts``. The reader puts
        each file into a ``query`` pack, and the query creator (configured with
        ``bert-base-uncased``) is expected to add exactly one ``Query`` generic
        entry whose value has shape ``(1, 768)``.

        The number of processed multi-packs is also checked, so the test cannot
        pass vacuously when the pipeline yields nothing.
        """
        for idx, text in enumerate(texts):
            file_path = os.path.join(self.test_dir, f"{idx + 1}.txt")
            with open(file_path, 'w') as f:
                f.write(text)

        nlp = Pipeline()
        reader_config = HParams(
            {
                "input_pack_name": "query",
                "output_pack_name": "output"
            }, MultiPackSentenceReader.default_hparams())
        nlp.set_reader(reader=MultiPackSentenceReader(), config=reader_config)
        config = HParams(
            {
                "model": {
                    "name": "bert-base-uncased"
                },
                "tokenizer": {
                    "name": "bert-base-uncased"
                },
                "max_seq_length": 128,
                "query_pack_name": "query"
            }, None)
        nlp.add_processor(BertBasedQueryCreator(), config=config)

        nlp.initialize()

        num_packs = 0
        for m_pack in nlp.process_dataset(self.test_dir):
            num_packs += 1
            query_pack = m_pack.get_pack("query")
            self.assertEqual(len(query_pack.generics), 1)
            self.assertIsInstance(query_pack.generics[0], Query)
            query = query_pack.generics[0].value
            # One row, with the bert-base hidden size (768) as the width.
            self.assertEqual(query.shape, (1, 768))

        # The reader is expected to yield one multi-pack per input file.
        self.assertEqual(num_packs, len(texts))
コード例 #3
0
    def test_pipeline(self, texts):
        """Read one file per text and check the resulting multi-packs.

        Each multi-pack must contain exactly the ``input`` and ``output``
        packs. Across all packs, the ``input`` texts must equal ``texts``, each
        followed by a newline.

        ``process_dataset`` gives no guarantee about the order in which it
        visits the files in ``self.test_dir``, so the texts are compared as a
        multiset. That comparison also fails if any file is missing from the
        output.
        """
        for idx, text in enumerate(texts):
            file_path = os.path.join(self.test_dir, f"{idx + 1}.txt")
            with open(file_path, 'w') as f:
                f.write(text)

        nlp = Pipeline()
        reader_config = HParams({"input_pack_name": "input",
                                 "output_pack_name": "output"},
                                MultiPackSentenceReader.default_hparams())
        nlp.set_reader(reader=MultiPackSentenceReader(), config=reader_config)
        nlp.initialize()

        input_texts = []
        for m_pack in nlp.process_dataset(self.test_dir):
            self.assertEqual(m_pack._pack_names, ["input", "output"])
            input_texts.append(m_pack.get_pack("input").text)

        self.assertCountEqual(input_texts, [text + "\n" for text in texts])
コード例 #4
0
    def test_parse_pack(self, text, annotation_length):
        """Parse one file directly with ``MultiPackSentenceReader.parse_pack``.

        No pipeline or explicit reader config is used. The first multi-pack
        parsed from the file must hold exactly the ``input_src`` and
        ``output_tgt`` packs. The ``input_src`` pack must carry
        ``annotation_length`` annotations, and its text must be ``text``
        followed by a newline.
        """
        text_file = os.path.join(self.test_dir, 'test.txt')
        with open(text_file, 'w') as handle:
            handle.write(text)

        reader = MultiPackSentenceReader()
        parsed_packs = list(reader.parse_pack(text_file))
        multipack = parsed_packs[0]
        source_pack = multipack.get_pack('input_src')

        expected_names = ['input_src', 'output_tgt']
        self.assertEqual(len(multipack.packs), len(expected_names))
        self.assertEqual(multipack._pack_names, expected_names)
        self.assertEqual(len(source_pack.annotations), annotation_length)
        self.assertEqual(source_pack.text, text + "\n")
コード例 #5
0
    def test_parse_pack(self, text, annotation_length):
        """Run a reader-only pipeline over the test directory and check the pack.

        A single ``test.txt`` file is written. A ``Pipeline`` holding only a
        ``MultiPackSentenceReader`` processes the directory with
        ``process_one``. The returned multi-pack must hold exactly the
        ``input_src`` and ``output_tgt`` packs. The ``input_src`` pack must
        carry ``annotation_length`` annotations and contain ``text`` plus a
        trailing newline.
        """
        target_path = os.path.join(self.test_dir, 'test.txt')
        with open(target_path, 'w') as sink:
            sink.write(text)

        pipeline = Pipeline()
        pipeline.set_reader(MultiPackSentenceReader())
        pipeline.initialize()

        result: MultiPack = pipeline.process_one(self.test_dir)
        source_pack = result.get_pack('input_src')

        expected_names = ['input_src', 'output_tgt']
        self.assertEqual(len(result.packs), len(expected_names))
        self.assertEqual(result._pack_names, expected_names)
        self.assertEqual(len(source_pack.annotations), annotation_length)
        self.assertEqual(source_pack.text, text + "\n")
コード例 #6
0
    def test_pipeline(self, texts):
        """Read one file per text and check each multi-pack against its source.

        Files are named ``<idx + 1>.txt``. Because ``process_dataset`` does not
        guarantee a visiting order, each pack's position in ``texts`` is
        recovered from the doc id of its ``input`` pack. The test then checks
        that every written file produced exactly one multi-pack.
        """
        for idx, text in enumerate(texts):
            file_path = os.path.join(self.test_dir, f"{idx + 1}.txt")
            with open(file_path, 'w') as f:
                f.write(text)

        nlp = Pipeline[MultiPack]()
        reader_config = {
            "input_pack_name": "input",
            "output_pack_name": "output"
        }
        nlp.set_reader(reader=MultiPackSentenceReader(), config=reader_config)
        nlp.initialize()

        seen_indices = []
        m_pack: MultiPack
        for m_pack in nlp.process_dataset(self.test_dir):
            # Recover the test sentence order from the doc id, which is
            # expected to be the path of the "<idx + 1>.txt" source file.
            # os.path.splitext removes exactly the extension; str.rstrip('.txt')
            # would strip any trailing '.', 't' or 'x' characters instead.
            docid = m_pack.get_pack("input").meta.doc_id
            idx = int(os.path.splitext(os.path.basename(docid))[0]) - 1
            seen_indices.append(idx)
            self.assertEqual(m_pack._pack_names, ["input", "output"])
            self.assertEqual(m_pack.get_pack("input").text, texts[idx] + "\n")

        # Every written file must have produced exactly one multi-pack.
        self.assertCountEqual(seen_indices, range(len(texts)))
コード例 #7
0
    def test_replace_token(self, texts, expected_outputs, expected_tokens, expected_links):
        """Exercise ``ReplacementDataAugmentProcessor`` end to end.

        One ``<n>.txt`` file is written per entry of ``texts``. For each
        resulting multi-pack the test:

        * mirrors the source pack's annotations into the target pack, linking
          each pair with a ``MultiPackLink`` and a ``MultiPackGroup``;
        * builds a chain of ``Link``s in the source pack and of ``Group``s in
          the target pack, plus two extra Links/Groups anchored on the
          ``Document``;
        * queues replacements, insertions and a deletion, then runs the
          processor;
        * checks the augmented packs against ``expected_outputs``,
          ``expected_tokens`` and ``expected_links``, which are all indexed by
          the position of the text in ``texts``.

        ``process_dataset`` does not guarantee the order in which files are
        visited. Each pack's index is therefore recovered from the doc id of
        its source pack, not from the iteration order.
        """
        for idx, text in enumerate(texts):
            file_path = os.path.join(self.test_dir, f"{idx + 1}.txt")
            with open(file_path, 'w') as f:
                f.write(text)

        nlp = Pipeline[MultiPack]()
        reader_config = {
            "input_pack_name": "input_src",
            "output_pack_name": "output_tgt"
        }
        nlp.set_reader(reader=MultiPackSentenceReader(), config=reader_config)

        nlp.add(component=NLTKWordTokenizer(), selector=AllPackSelector())
        nlp.add(component=NLTKPOSTagger(), selector=AllPackSelector())

        nlp.initialize()

        processor_config = {
            'augment_entry': "ft.onto.base_ontology.Token",
            'other_entry_policy': {
                "kwargs": {
                    "ft.onto.base_ontology.Sentence": "auto_align"
                }
            },
            'type': 'data_augmentation_op',
            'data_aug_op': 'tests.forte.processors.base.data_augment_replacement_processor_test.TmpReplacer',
            "data_aug_op_config": {
                'kwargs': {}
            },
            'augment_pack_names': {
                'kwargs': {}
            }
        }

        processor = ReplacementDataAugmentProcessor()
        processor.initialize(resources=None, configs=processor_config)

        seen_indices = []
        for m_pack in nlp.process_dataset(self.test_dir):
            src_pack = m_pack.get_pack('input_src')
            tgt_pack = m_pack.get_pack('output_tgt')

            # Recover which input file (and hence which expectations) this
            # pack belongs to. The doc id of the source pack is expected to be
            # the path of the "<idx + 1>.txt" file it was read from.
            idx = int(os.path.splitext(
                os.path.basename(src_pack.meta.doc_id))[0]) - 1
            seen_indices.append(idx)

            num_mpl_orig, num_mpg_orig = 0, 0
            # Copy the source pack to target pack.
            tgt_pack.set_text(src_pack.text)

            src_pack.add_entry(Document(src_pack, 0, len(src_pack.text)))
            for anno in src_pack.get(Annotation):
                new_anno = type(anno)(
                    tgt_pack, anno.begin, anno.end
                )
                tgt_pack.add_entry(new_anno)

                # Create MultiPackLink.
                m_pack.add_entry(
                    MultiPackLink(
                        m_pack, anno, new_anno
                    )
                )

                # Create MultiPackGroup.
                m_pack.add_entry(
                    MultiPackGroup(
                        m_pack, [anno, new_anno]
                    )
                )

                # Count the number of MultiPackLink/MultiPackGroup.
                num_mpl_orig += 1
                num_mpg_orig += 1

            # Create Links in the source pack.
            # The Links should be a tree:
            #
            #                           Link 3
            #                    _________|_________
            #                   |                  |
            #                 Link 2               |
            #            _______|________          |
            #           |               |          |
            #         Link 1            |          |
            #     ______|_____          |          |
            #    |           |          |          |
            # token 1     token 2    token 3    token 4 ... ...
            prev_entry = None
            for i, token in enumerate(src_pack.get(Token)):
                # Avoid overlapping with deleted tokens.
                if i < 10:
                    continue
                if prev_entry:
                    link = Link(src_pack, prev_entry, token)
                    src_pack.add_entry(
                        link
                    )
                    prev_entry = link
                else:
                    prev_entry = token

            # Create Groups in the target pack.
            # The Groups should be a tree like the Links.
            prev_entry = None
            for i, token in enumerate(tgt_pack.get(Token)):
                # Avoid overlapping with deleted tokens.
                if i < 10:
                    continue
                if prev_entry:
                    group = Group(tgt_pack, [prev_entry, token])
                    tgt_pack.add_entry(
                        group
                    )
                    prev_entry = group
                else:
                    prev_entry = token

            doc_src = list(src_pack.get(Document))[0]
            doc_tgt = list(tgt_pack.get(Document))[0]

            sent_src = list(src_pack.get(Sentence))[0]
            sent_tgt = list(tgt_pack.get(Sentence))[0]

            # Insert two extra Links in the src_pack.
            # They should not be copied to new_src_pack, because the Document is not copied.
            link_src_low = src_pack.add_entry(Link(src_pack, doc_src, sent_src))
            src_pack.add_entry(Link(src_pack, link_src_low, sent_src))

            # Insert two extra Groups in the tgt_pack.
            # They should not be copied to new_tgt_pack, because the Document is not copied.
            group_tgt_low = tgt_pack.add_entry(Group(tgt_pack, [doc_tgt, sent_tgt]))
            tgt_pack.add_entry(Group(tgt_pack, [group_tgt_low, sent_tgt]))

            # Call the augment function explicitly for duplicate replacement
            # to test the False case of _replace function.
            processor._augment(m_pack, ["input_src", "output_tgt"])

            # Test the insertion and deletion
            for pack in (src_pack, tgt_pack):
                # Insert " NLP " at offset 0, at offset 18 and two characters
                # before the end, then "NLP" one character before the end.
                processor._insert(" NLP ", pack, 0)
                processor._insert(" NLP ", pack, 18)
                processor._insert(" NLP ", pack, len(pack.text) - 2)
                processor._insert("NLP", pack, len(pack.text) - 1)
                # Delete the second token ("and" in the test texts).
                processor._delete(list(pack.get(Token))[1])

                # This duplicate insertion should be invalid.
                processor._insert(" NLP ", pack, 0)
                # This insertion overlaps with a replacement.
                # It should be invalid.
                processor._insert(" NLP ", pack, 2)

            processor._process(m_pack)

            new_src_pack = m_pack.get_pack('augmented_input_src')
            new_tgt_pack = m_pack.get_pack('augmented_output_tgt')

            self.assertEqual(new_src_pack.text, expected_outputs[idx] + "\n")

            for j, token in enumerate(new_src_pack.get(Token)):
                self.assertEqual(token.text, expected_tokens[idx][j])

            for sent in new_src_pack.get(Sentence):
                self.assertEqual(sent.text, expected_outputs[idx])

            # Test the copied Links.
            prev_link = None
            for i, link in enumerate(new_src_pack.get(Link)):
                if prev_link:
                    self.assertEqual(link.get_parent().tid, prev_link.tid)
                    self.assertEqual(link.get_child().text, expected_links[idx][i])
                prev_link = link

            # Test the copied Groups.
            prev_group = None
            for i, group in enumerate(new_tgt_pack.get(Group)):
                members = group.get_members()
                if isinstance(members[0], Token):
                    member_token = members[0]
                    member_group = members[1]
                else:
                    member_token = members[1]
                    member_group = members[0]

                if prev_group:
                    self.assertIsInstance(member_token, Token)
                    self.assertIsInstance(member_group, Group)
                    self.assertEqual(member_group.tid, prev_group.tid)
                    self.assertEqual(member_token.text, expected_links[idx][i])

                prev_group = group

            # The two extra Links should not be copied, because of missing Document.
            self.assertEqual(len(list(src_pack.get(Link))) - 2, len(list(new_src_pack.get(Link))))
            # The two extra Groups should not be copied, because of missing Document.
            self.assertEqual(len(list(tgt_pack.get(Group))) - 2, len(list(new_tgt_pack.get(Group))))

            # Test the MultiPackLink/MultiPackGroup
            num_mpl_aug, num_mpg_aug = 0, 0
            for mpl in m_pack.get(MultiPackLink):
                parent = mpl.get_parent()
                child = mpl.get_child()
                num_mpl_aug += 1
                self.assertEqual(parent.text, child.text)
                self.assertNotEqual(parent.pack.meta.pack_id, child.pack.meta.pack_id)

            for mpg in m_pack.get(MultiPackGroup):
                members = mpg.get_members()
                num_mpg_aug += 1
                self.assertEqual(members[0].text, members[1].text)
                self.assertNotEqual(members[0].pack.meta.pack_id, members[1].pack.meta.pack_id)

            # Test the number of MultiPackLink/MultiPackGroup.
            # Subtract 1 from both the aug and orig counters, because the
            # Document is not copied, so the MPL and MPG between Documents
            # are ignored.
            # The number should be doubled, except for one deletion.
            self.assertEqual(num_mpl_aug - 1, (num_mpl_orig - 1) * 2 - 1)
            self.assertEqual(num_mpg_aug - 1, (num_mpg_orig - 1) * 2 - 1)

        # Every written file must have produced exactly one multi-pack.
        self.assertCountEqual(seen_indices, range(len(texts)))