Example #1
0
    def _parse_pack(self, collection: Any) -> Iterator[MultiPack]:
        """Yield one MultiPack holding three fixed, named sub-packs.

        The ``collection`` argument is accepted for interface compatibility
        but is not used; the pack layout is hard-coded.
        """
        mp: MultiPack = MultiPack()
        # (reference name, pack name) pairs for the three sub-packs.
        for ref_name, pack_name in (
            ("pack1", "1"),
            ("pack2", "2"),
            ("pack_three", "Three"),
        ):
            mp.add_pack(ref_name=ref_name).pack_name = pack_name
        yield mp
Example #2
0
    def _process(self, input_pack: MultiPack):
        """Link entity mentions with identical surface text across the
        first two packs via CrossDocEntityRelation entries.
        """
        first_pack = input_pack.get_pack_at(0)
        second_pack = input_pack.get_pack_at(1)

        mentions_a = list(first_pack.get(EntityMention))
        mentions_b = list(second_pack.get(EntityMention))

        for left in mentions_a:
            # Pair this mention with every same-text mention in the other pack.
            for right in (m for m in mentions_b if m.text == left.text):
                CrossDocEntityRelation(input_pack, left, right)
Example #3
0
    def _process_query(
        self, input_pack: MultiPack
    ) -> Tuple[DataPack, Dict[str, Any]]:
        """Build the query from the query pack text plus any user/bot
        utterance packs present in the multi pack.

        Returns:
            The query pack and the query vector built from the joined text.
        """
        query_pack: DataPack = input_pack.get_pack(self.config.query_pack_name)

        # Start from the query text, then append optional utterance packs
        # (checked in a fixed order so the context order is stable).
        parts = [query_pack.text]
        for extra_name in ("user_utterance", "bot_utterance"):
            if extra_name in input_pack.pack_names:
                parts.append(input_pack.get_pack(extra_name).text)

        query_vector = self._build_query(text=" ".join(parts))
        return query_pack, query_vector
Example #4
0
    def new_pack(self, pack_name: Optional[str] = None) -> MultiPack:
        """
        Create a new multi pack using the current pack manager.

        Args:
            pack_name (str, Optional): The name to be used for the pack. If not
              set, the pack name will remain unset.

        Returns:
            A new :class:`MultiPack` bound to this object's pack manager.
        """
        return MultiPack(self._pack_manager, pack_name)
Example #5
0
    def _process(self, input_pack: MultiPack):
        """Translate the input pack's text via the Microsoft Translator API
        and store the result as a new pack with Document/Utterance spans.

        Raises:
            RuntimeError: if the API responds with a non-200 status.
        """
        source_text = input_pack.get_pack(self.in_pack_name).text

        # Build the request URL with the translation parameters.
        query_params = urlencode(
            {'api-version': '3.0',
             'from': self.src_language,
             'to': [self.target_language]}, doseq=True)
        url = self.microsoft_translate_url + '?' + query_params

        response = requests.post(
            url, headers=self.microsoft_headers,
            json=[{"text": source_text}])

        if response.status_code != 200:
            raise RuntimeError(response.json()['error']['message'])

        translated = response.json()[0]["translations"][0]["text"]

        out_pack: DataPack = input_pack.add_pack(self.out_pack_name)
        out_pack.set_text(text=translated)

        # Cover the whole translated text with one Document and one Utterance.
        Document(out_pack, 0, len(translated))
        Utterance(out_pack, 0, len(translated))
Example #6
0
 def consume_next(self, pred_pack: MultiPack, _):
     """Record a (doc_id, passage_id, rank) triple for every result of the
     query in the configured query pack.

     Args:
         pred_pack: The multi pack containing the query pack.
         _: Unused (interface placeholder).

     Raises:
         ProcessExecutionException: if the query pack's ``pack_name`` (the
             doc id) is unset, indicating a misconfigured reader.
     """
     query_pack: DataPack = pred_pack.get_pack(self.configs.pack_name)
     query = list(query_pack.get(Query))[0]

     # The doc id is a property of the query pack, not of individual
     # results — validate it once up front instead of re-checking on every
     # iteration (the original check was also silently skipped when the
     # result set was empty).
     doc_id: Optional[str] = query_pack.pack_name
     if doc_id is None:
         raise ProcessExecutionException(
             'Doc ID of the query pack is not set, '
             'please double check the reader.')

     for rank, (pid, _score) in enumerate(query.results.items(), start=1):
         self.predicted_results.append((doc_id, pid, str(rank)))
Example #7
0
    def _process(self, multi_pack: MultiPack):
        """Fill the multi pack with two data packs, an example entry in
        each, and a multi-pack entry referring to the first pack's entry.
        """
        # Create two data packs inside the multi pack.
        first = multi_pack.add_pack('pack1')
        second = multi_pack.add_pack('pack2')

        # One entry per pack; only the first carries a payload value.
        entry_one: ExampleEntry = first.add_entry(ExampleEntry(first))
        entry_one.secret_number = 1
        second.add_entry(ExampleEntry(second))

        # The multi-pack level entry points at the first pack's entry.
        cross_entry = ExampleMPEntry(multi_pack)
        cross_entry.refer_entry = entry_one
Example #8
0
    def _process(self, input_pack: MultiPack):
        """Answer the query using the best-ranked document.

        The document with the highest score in the Query entry's results is
        run through the QA pipeline; the extracted answer phrase is expanded
        to its full containing sentence (via spaCy sentence segmentation).
        All other documents get an empty answer recorded.
        """
        query_pack_name = self.config.query_pack_name
        query_pack = input_pack.get_pack(self.config.query_pack_name)
        query_entry = list(query_pack.get(Query))[0]
        query_text = query_pack.text

        # Pick the document id with the maximum retrieval score.
        doc_score_dict = query_entry.results
        best_doc_id = max(doc_score_dict, key=lambda x: doc_score_dict[x])

        for doc_id in input_pack.pack_names:
            if doc_id == query_pack_name:
                continue
            pack = input_pack.get_pack(doc_id)
            doc_id_final = pack.pack_name
            if doc_id_final != best_doc_id:
                # Non-best documents are recorded with an empty answer.
                query_entry.update_results({doc_id_final: ""})
                continue

            result = self.qa_pipeline(
                {'question': query_text, 'context': pack.text})
            answer_phrase = result['answer']

            # Expand the answer phrase to the whole sentence containing it;
            # fall back to the bare phrase when no sentence matches.
            answer = next(
                (sent.text for sent in spacy_nlp(pack.text).sents
                 if answer_phrase in sent.text),
                None) or answer_phrase

            query_entry.update_results({doc_id_final: answer})
Example #9
0
    def _process(self, input_pack: MultiPack):
        """Augment the selected data packs and register the augmented copies
        as new packs on ``input_pack``, then copy multi-pack links/groups.
        """
        # Get the pack names for augmentation.
        aug_pack_names: List[str] = []

        # Check if the DataPack exists.
        for pack_name in self.configs["augment_pack_names"]["kwargs"].keys():
            if pack_name in input_pack.pack_names:
                aug_pack_names.append(pack_name)

        if len(self.configs["augment_pack_names"]["kwargs"].keys()) == 0:
            # Augment all the DataPacks if not specified.
            aug_pack_names = list(input_pack.pack_names)

        self._augment(input_pack, aug_pack_names)
        new_packs: List[Tuple[str, DataPack]] = []
        for aug_pack_name in aug_pack_names:
            # The new pack's name comes from the config mapping, defaulting
            # to "augmented_<original name>".
            new_pack_name: str = \
                self.configs["augment_pack_names"]["kwargs"].get(
                    aug_pack_name, "augmented_" + aug_pack_name
                )
            data_pack = input_pack.get_pack(aug_pack_name)
            new_pack = self._auto_align_annotations(
                data_pack=data_pack,
                replaced_annotations=self._replaced_annos[
                    data_pack.meta.pack_id])
            new_packs.append((new_pack_name, new_pack))

        # Attach the new packs only after all alignment is done.
        for new_pack_name, new_pack in new_packs:
            input_pack.add_pack_(new_pack, new_pack_name)

        # Copy the MultiPackLinks/MultiPackGroups
        for mpl in input_pack.get(MultiPackLink):
            self._copy_multi_pack_link_or_group(mpl, input_pack)
        for mpg in input_pack.get(MultiPackGroup):
            self._copy_multi_pack_link_or_group(mpg, input_pack)

        # Must be called after processing each multipack
        # to reset internal states.
        self._clear_states()
Example #10
0
class SelectorTest(unittest.TestCase):
    """Tests for pack selectors over a three-pack MultiPack fixture."""

    def setUp(self) -> None:
        pm = PackManager()
        self.multi_pack = MultiPack(pm)

        # Reference names intentionally differ from the pack names.
        for ref, name in (("pack1", "1"), ("pack2", "2"),
                          ("pack_three", "Three")):
            self.multi_pack.add_pack(ref_name=ref).pack_name = name

    def _assert_selected_names(self, packs, expected_names) -> None:
        # Selected packs must carry the expected pack names, in order.
        for expected, pack in zip(expected_names, packs):
            self.assertEqual(expected, pack.pack_name)

    def test_name_match_selector(self) -> None:
        selector = NameMatchSelector(select_name="pack1")
        self._assert_selected_names(
            selector.select(self.multi_pack), ["1"])

    def test_regex_name_match_selector(self) -> None:
        selector = RegexNameMatchSelector(select_name="^.*\\d$")
        self._assert_selected_names(
            selector.select(self.multi_pack), ["1", "2"])

    def test_first_pack_selector(self) -> None:
        selector = FirstPackSelector()
        selected = list(selector.select(self.multi_pack))
        self.assertEqual(len(selected), 1)
        self.assertEqual(selected[0].pack_name, "1")

    def test_all_pack_selector(self) -> None:
        selector = AllPackSelector()
        self._assert_selected_names(
            selector.select(self.multi_pack), ["1", "2", "Three"])
Example #11
0
    def _process(self, input_pack: MultiPack):
        r"""Query the `Elasticsearch` index with the Query entry stored in
        the pack named `self.config.query_pack_name` and attach each hit to
        `input_pack` as a new pack containing a
        `ft.onto.base_ontology.Document`.

        Args:
             input_pack: A multipack containing query as a pack.

        Raises:
            ValueError: if the stored query value is not a dictionary.
        """
        query_pack = input_pack.get_pack(self.configs.query_pack_name)

        # The upstream query creator attached a single Query entry; get it.
        first_query: Query = query_pack.get_single(Query)
        # pylint: disable=isinstance-second-argument-not-valid-type
        # TODO: until fix: https://github.com/PyCQA/pylint/issues/3507
        if not isinstance(first_query.value, Dict):
            raise ValueError(
                "The query to the elastic indexer need to be a dictionary.")

        hits = self.index.search(first_query.value)["hits"]["hits"]
        for idx, hit in enumerate(hits):
            source_doc = hit["_source"]
            # Record the (doc_id, score) pair on the query entry itself.
            first_query.add_result(source_doc["doc_id"], hit["_score"])

            result_pack: DataPack = input_pack.add_pack(
                f"{self.configs.response_pack_name_prefix}_{idx}"
            )
            result_pack.pack_name = source_doc["doc_id"]

            body = source_doc[self.configs.field]
            result_pack.set_text(body)
            Document(pack=result_pack, begin=0, end=len(body))
Example #12
0
    def _augment(self, input_pack: MultiPack, aug_pack_names: List[str]):
        r"""
        This function splits a given word at a random position and replaces
        the original word with 2 split parts of it.

        For each selected pack, a fraction ``alpha`` of its annotations is
        sampled (only annotations longer than one character are eligible);
        each sampled word is split at a random interior position and
        reinserted as two space-separated pieces.
        """
        augment_entry = get_class(self.configs["augment_entry"])

        for pack_name in aug_pack_names:
            data_pack: DataPack = input_pack.get_pack(pack_name)
            annotations: List[Annotation] = []
            indexes: List[int] = []
            endings = []
            annos: Iterable[Annotation] = data_pack.get(augment_entry)
            for idx, anno in enumerate(annos):
                annotations.append(anno)
                indexes.append(idx)
                endings.append(anno.end)
            if len(annotations) > 0:
                # Sample ceil(alpha * N) annotations of length > 1 to split.
                annotation_to_split = random.sample(
                    [(anno, idx) for (anno, idx) in zip(annotations, indexes)
                     if (anno.end - anno.begin) > 1],
                    ceil(self.configs["alpha"] * len(annotations)),
                )
                # Process right-to-left so earlier offsets are not shifted
                # by edits made further along in the text.
                annotation_to_split = sorted(annotation_to_split,
                                             key=lambda x: x[1],
                                             reverse=True)
                for curr_anno in annotation_to_split:
                    src_anno, src_idx = curr_anno
                    # Random interior split point — never at either end.
                    splitting_position = random.randrange(
                        1, (src_anno.end - src_anno.begin))
                    word_split = [
                        src_anno.text[:splitting_position],
                        src_anno.text[splitting_position:],
                    ]
                    if src_idx != 0:
                        # NOTE(review): assumes exactly one separator char
                        # after the previous annotation — confirm.
                        first_position = endings[src_idx - 1] + 1
                        second_position = endings[src_idx]
                        word_split[1] = " " + word_split[1]
                    else:
                        first_position = 0
                        second_position = endings[0]
                        word_split[1] = " " + word_split[1]

                    # Insert second half, remove the original annotation,
                    # then insert the first half (order matters for offsets).
                    self._insert(word_split[1], data_pack, second_position)
                    self._delete(src_anno)
                    self._insert(word_split[0], data_pack, first_position)
Example #13
0
    def test_entry_attribute_mp_pointer(self):
        """Round-trip the multi pack through serialization and check that
        the cross-pack entry reference survives relinking.
        """
        mpe: ExampleMPEntry = self.pack.get_single(ExampleMPEntry)
        self.assertIsInstance(mpe.refer_entry, ExampleEntry)
        self.assertIsInstance(mpe.__dict__["refer_entry"], ExampleEntry)

        # Serialize multi pack and all sub packs, then rebuild and relink.
        restored = MultiPack.from_string(
            self.pack.to_string(drop_record=True))
        restored_packs = [
            DataPack.from_string(p.to_string()) for p in self.pack.packs
        ]
        restored.relink(restored_packs)

        restored_entry: ExampleMPEntry = restored.get_single(ExampleMPEntry)
        self.assertIsInstance(restored_entry.refer_entry, ExampleEntry)
        self.assertEqual(restored_entry.refer_entry.tid, mpe.refer_entry.tid)
        self.assertEqual(restored_entry.tid, mpe.tid)
Example #14
0
    def _parse_pack(self, multi_pack_path: str) -> Iterator[MultiPack]:
        """Deserialize a multi pack from ``multi_pack_path`` (relative to
        the configured data path), load each referenced data pack that has
        not been read yet, realign the pack references, and yield it.
        """
        # pylint: disable=protected-access
        with open(os.path.join(self.configs.data_path,
                               multi_pack_path)) as m_data:
            m_pack: MultiPack = MultiPack.deserialize(m_data.read())
            for pid in m_pack._pack_ref:
                # Look up the serialized file for this pack id.
                sub_pack_path = self.__pack_index[pid]
                if self._pack_manager.get_remapped_id(pid) >= 0:
                    # This pid is already been read.
                    continue

                with open(os.path.join(self.configs.data_path,
                                       sub_pack_path)) as pack_data:
                    pack: DataPack = DataPack.deserialize(pack_data.read())
                    # Add a reference count to this pack, because the multipack
                    # needs it.
                    self._pack_manager.reference_pack(pack)
            m_pack.realign_packs()
            yield m_pack
    def consume_next(self, pred_pack: MultiPack, _):
        """Record (doc_id, passage_id, rank) triples for the query results,
        ranked by descending score.

        Args:
            pred_pack: The multi pack containing the query pack.
            _: Unused (interface placeholder).

        Raises:
            ProcessExecutionException: if the query pack's ``pack_name``
                (the doc id) is unset, indicating a misconfigured reader.
        """
        query_pack: DataPack = pred_pack.get_pack(self.configs.pack_name)
        query = list(query_pack.get(Query))[0]

        # The doc id belongs to the query pack, not to individual results —
        # validate it once instead of re-checking inside the loop (the old
        # check was also skipped entirely when there were no results).
        doc_id: Optional[str] = query_pack.pack_name
        if doc_id is None:
            raise ProcessExecutionException(
                'Doc ID of the query pack is not set, '
                'please double check the reader.')

        # Highest retrieval score first determines rank order.
        ranked_results = sorted(query.results.items(),
                                key=lambda x: x[1],
                                reverse=True)
        for rank, (pid, _score) in enumerate(ranked_results, start=1):
            self.predicted_results.append((doc_id, pid, str(rank)))
Example #16
0
    def _parse_pack(self, multi_pack_str: str) -> Iterator[MultiPack]:
        """Deserialize a multi pack string, attach all of its referenced
        data packs, and yield the multi pack only when every pack was found.
        """
        m_pack: MultiPack = MultiPack.deserialize(multi_pack_str)

        for pid in m_pack.pack_ids():
            p_content = self._get_pack_content(pid)
            if p_content is None:
                logging.warning(
                    "Cannot locate the data pack with pid %d "
                    "for multi pack %d", pid, m_pack.pack_id)
                # Breaking bypasses the for-else below, so nothing is
                # yielded when any pack is missing.
                break
            pack: DataPack
            if isinstance(p_content, str):
                pack = DataPack.deserialize(p_content)
            else:
                pack = p_content
            # Only in deserialization we can do this.
            m_pack.packs.append(pack)
        else:
            # The for-else branch runs only when the loop completed without
            # `break` — i.e. no multi pack is yielded if any pack is missing.
            yield m_pack
Example #17
0
 def _augment(self, input_pack: MultiPack, aug_pack_names: List[str]):
     """Randomly swap pairs of annotations within each selected pack,
     recording the swaps into ``self._replaced_annos``.
     """
     augment_entry = get_class(self.configs["augment_entry"])
     for pack_name in aug_pack_names:
         data_pack: DataPack = input_pack.get_pack(pack_name)
         annotations = list(data_pack.get(augment_entry))
         if len(annotations) > 0:
             # Map: original index -> index whose text replaces it.
             replace_map: Dict = {}
             for _ in range(ceil(self.configs['alpha'] * len(annotations))):
                 # Pick two distinct indices; follow any existing mapping so
                 # that chained swaps stay consistent.
                 swap_idx = random.sample(range(len(annotations)), 2)
                 new_idx_0 = swap_idx[1] if swap_idx[1] not in replace_map \
                     else replace_map[swap_idx[1]]
                 new_idx_1 = swap_idx[0] if swap_idx[0] not in replace_map \
                     else replace_map[swap_idx[0]]
                 replace_map[swap_idx[0]] = new_idx_0
                 replace_map[swap_idx[1]] = new_idx_1
             pid: int = data_pack.pack_id
             # Record: each original span receives its partner's text.
             for idx in replace_map:
                 self._replaced_annos[pid]\
                     .add((annotations[idx].span,
                              annotations[replace_map[idx]].text))
Example #18
0
    def _parse_pack(self, file_path: str) -> Iterator[MultiPack]:
        """Read up to 20 non-empty lines from ``file_path`` as sentences
        into an input pack, pair it with an empty output pack, and yield
        the resulting multi pack.
        """
        m_pack: MultiPack = MultiPack()

        input_pack_name = "input_src"
        output_pack_name = "output_tgt"

        input_pack = DataPack(doc_id=file_path)
        with open(file_path, "r", encoding="utf8") as doc:
            text = ""
            offset = 0
            sentence_cnt = 0

            for raw_line in doc:
                stripped = raw_line.strip()
                if not stripped:
                    # skip empty lines
                    continue
                # One Sentence annotation per non-empty line.
                input_pack.add_entry(
                    Sentence(input_pack, offset, offset + len(stripped)))
                text += stripped + '\n'
                offset += len(stripped) + 1

                sentence_cnt += 1
                if sentence_cnt >= 20:
                    # Cap the input at 20 sentences.
                    break

            input_pack.set_text(text, replace_func=self.text_replace_operation)

        output_pack = DataPack()

        m_pack.update_pack({
            input_pack_name: input_pack,
            output_pack_name: output_pack
        })

        yield m_pack
Example #19
0
    def _augment(self, input_pack: MultiPack, aug_pack_names: List[str]):
        r"""
        This function calls the data augmentation ops and
        modifies the input in-place. This function only applies for
        replacement-based methods. The subclasses should override
        this function to implement other data augmentation methods, such
        as Easy Data Augmentation.

        Args:
            input_pack: The input MultiPack.
            aug_pack_names: The packs names for DataPacks to be augmented.
        """
        # Instantiate the configured replacement op once for all packs.
        op = create_class_with_kwargs(
            self.configs["data_aug_op"],
            class_args={
                "configs": self.configs["data_aug_op_config"]["kwargs"]
            })
        entry_type = get_class(self.configs["augment_entry"])

        for name in aug_pack_names:
            pack: DataPack = input_pack.get_pack(name)
            # Replace every matching entry in this pack, in document order.
            for annotation in pack.get(entry_type):
                self._replace(op, annotation)
Example #20
0
    def _parse_pack(self,
                    base_and_path: Tuple[str, str]) -> Iterator[MultiPack]:
        """Read one document into an input pack (one Sentence per non-empty
        line), add an empty output pack, and yield the multi pack.

        Args:
            base_and_path: ``(base_dir, file_path)``; the doc id is the
                file path made relative to ``base_dir``.
        """
        base_dir, file_path = base_and_path

        m_pack: MultiPack = MultiPack()

        text = ""
        offset = 0
        with open(file_path, "r", encoding="utf8") as doc:
            # Strip the long base directory prefix to get the doc id.
            if file_path.startswith(base_dir):
                doc_id = file_path[len(base_dir):]
            else:
                doc_id = file_path
            doc_id = doc_id.strip(os.path.sep)

            input_pack = m_pack.add_pack(self.configs.input_pack_name)
            input_pack.pack_name = doc_id

            for raw_line in doc:
                stripped = raw_line.strip()
                if not stripped:
                    continue

                # One Sentence annotation per non-empty line.
                Sentence(input_pack, offset, offset + len(stripped))
                text += stripped + '\n'
                offset += len(stripped) + 1

            input_pack.set_text(
                text, replace_func=self.text_replace_operation)
            # Create a output pack without text.
            m_pack.add_pack(self.configs.output_pack_name)
            yield m_pack
    def _parse_pack(self, file_path: str) -> Iterator[MultiPack]:
        """Load ``file_path`` into an input DataPack (one Sentence per
        non-empty line), pair it with an empty output pack, and yield the
        combined MultiPack.
        """
        m_pack: MultiPack = MultiPack()

        text = ""
        offset = 0
        with open(file_path, "r", encoding="utf8") as doc:

            input_pack = DataPack(doc_id=file_path)

            for raw_line in doc:
                stripped = raw_line.strip()
                if not stripped:
                    continue

                # Annotate each non-empty line as a Sentence.
                input_pack.add_entry(
                    Sentence(input_pack, offset, offset + len(stripped)))
                text += stripped + '\n'
                offset += len(stripped) + 1

            input_pack.set_text(
                text, replace_func=self.text_replace_operation)

            # The output pack starts empty; downstream processors fill it.
            m_pack.update_pack({
                self.config.input_pack_name: input_pack,
                self.config.output_pack_name: DataPack()
            })

            yield m_pack
Example #22
0
class DataPackTest(unittest.TestCase):
    """Tests for MultiPack operations: pack management, multi-pack groups,
    links, and entry deletion, over a fixture of two small text packs.
    """

    def setUp(self) -> None:
        # Note: input source is created automatically by the system, but we
        #  can also set it manually at test cases.
        pm = PackManager()
        self.multi_pack = MultiPack(pm)
        self.data_pack1 = self.multi_pack.add_pack(ref_name="left pack")
        self.data_pack2 = self.multi_pack.add_pack(ref_name="right pack")

        self.data_pack1.pack_name = "some pack"
        self.data_pack1.set_text("This pack contains some sample data.")

        self.data_pack2.pack_name = "another pack"
        self.data_pack2.set_text("This pack contains some other sample data.")

    def test_serialization(self):
        # Smoke test: serialization should not raise.
        ser_str: str = self.multi_pack.serialize()
        print(ser_str)

    def test_add_pack(self):
        # Adding a third pack should grow the pack list and name set.
        data_pack3 = self.multi_pack.add_pack(ref_name="new pack")
        data_pack3.pack_name = "the third pack"
        data_pack3.set_text("Test to see if we can add new packs..")

        self.assertEqual(len(self.multi_pack.packs), 3)
        self.assertEqual(self.multi_pack.pack_names,
                         {'left pack', 'right pack', 'new pack'})

    def test_rename_pack(self):
        # Renaming replaces the reference name, keeping the pack count.
        self.multi_pack.rename_pack('right pack', 'last pack')
        self.assertEqual(self.multi_pack.pack_names,
                         {'left pack', 'last pack'})

    def test_multipack_groups(self):
        """
        Test some multi pack group.

        Groups same-text tokens across the two packs and verifies the
        resulting group members.
        """
        # Add tokens to each pack.
        for pack in self.multi_pack.packs:
            _space_token(pack)

        # Create some group.
        token: Annotation
        left_tokens = {}
        for token in self.multi_pack.packs[0].get(Token):
            left_tokens[token.text] = token

        right_tokens = {}
        for token in self.multi_pack.packs[1].get(Token):
            right_tokens[token.text] = token

        for key, lt in left_tokens.items():
            if key in right_tokens:
                rt = right_tokens[key]
                self.multi_pack.add_entry(MultiPackGroup(
                    self.multi_pack, [lt, rt]))

        # Check the groups.
        expected_content = [("This", "This"), ("pack", "pack"),
                            ("contains", "contains"), ("some", "some"),
                            ("sample", "sample"), ("data.", "data.")]

        group_content = []
        g: MultiPackGroup
        for g in self.multi_pack.get(MultiPackGroup):
            e: Annotation
            group_content.append(tuple([e.text for e in g.get_members()]))

        self.assertListEqual(expected_content, group_content)

    def test_multipack_entries(self):
        """
        Test some multi pack entry.

        Links same-text tokens across the two packs, verifies the links via
        both access paths, and checks deletion.
        """
        # 1. Add tokens to each pack.
        for pack in self.multi_pack.packs:
            _space_token(pack)

        left_tokens = [t.text for t in self.multi_pack.packs[0].get(Token)]
        right_tokens = [t.text for t in self.multi_pack.packs[1].get(Token)]

        self.assertListEqual(left_tokens,
                             ["This", "pack", "contains", "some", "sample",
                              "data."])
        self.assertListEqual(right_tokens,
                             ["This", "pack", "contains", "some", "other",
                              "sample", "data."])

        # 2. Link the same words from two packs.
        token: Annotation
        left_tokens = {}
        for token in self.multi_pack.packs[0].get(Token):
            left_tokens[token.text] = token

        right_tokens = {}
        for token in self.multi_pack.packs[1].get(Token):
            right_tokens[token.text] = token

        for key, lt in left_tokens.items():
            if key in right_tokens:
                rt = right_tokens[key]
                self.multi_pack.add_entry(MultiPackLink(
                    self.multi_pack, lt, rt))

        # One way to link tokens.
        linked_tokens = []
        for link in self.multi_pack.links:
            parent_text = link.get_parent().text
            child_text = link.get_child().text
            linked_tokens.append((parent_text, child_text))

        self.assertListEqual(
            linked_tokens,
            [("This", "This"), ("pack", "pack"), ("contains", "contains"),
             ("some", "some"), ("sample", "sample"), ("data.", "data.")])

        # Another way to get the links
        linked_tokens = []
        for link in self.multi_pack.get(MultiPackLink):
            parent_text = link.get_parent().text
            child_text = link.get_child().text
            linked_tokens.append((parent_text, child_text))

        self.assertListEqual(
            linked_tokens,
            [("This", "This"), ("pack", "pack"), ("contains", "contains"),
             ("some", "some"), ("sample", "sample"), ("data.", "data.")])

        # 3. Test deletion

        # Delete the second link.
        self.multi_pack.delete_entry(self.multi_pack.links[1])

        linked_tokens = []
        for link in self.multi_pack.links:
            parent_text = link.get_parent().text
            child_text = link.get_child().text
            linked_tokens.append((parent_text, child_text))

        self.assertListEqual(
            linked_tokens,
            [("This", "This"), ("contains", "contains"),
             ("some", "some"), ("sample", "sample"), ("data.", "data.")])
Example #23
0
 def _process(self, input_pack: MultiPack):
     """Append a new pack whose text copies the first pack's text."""
     source_text = input_pack.get_pack_at(0).text
     input_pack.add_pack().set_text(source_text)
Example #24
0
 def _process(self, input_pack: MultiPack):
     """Count the multi pack's cross-document entity relations."""
     relations = list(input_pack.get_entries_of(CrossDocEntityRelation))
     self.coref_count = self.coref_count + len(relations)
Example #25
0
 def _process(self, input_pack: MultiPack):
     """Accumulate the number of cross-document entity relations."""
     found = input_pack.get_entries_by_type(CrossDocEntityRelation)
     self.coref_count = self.coref_count + len(found)
Example #26
0
 def _parse_pack(self, name: str) -> Iterator[MultiPack]:
     """Yield a single empty MultiPack named after ``name``."""
     pack = MultiPack()
     pack.pack_name = name
     yield pack
Example #27
0
 def _process(self, input_pack: MultiPack):
     """Create one pack per entry in ``docs``, each holding the document
     text and a Document annotation spanning it.
     """
     # NOTE(review): ``docs`` is resolved from enclosing scope — confirm
     # it is a module-level mapping of name -> text.
     for doc_name, doc_text in docs.items():
         new_pack = input_pack.add_pack(ref_name=doc_name)
         new_pack.set_text(doc_text)
         Document(new_pack, 0, len(new_pack.text))
Example #28
0
 def new_pack(self, pack_name: Optional[str] = None) -> MultiPack:
     """Create a new multi pack bound to the current pack manager.

     Args:
         pack_name: Optional name for the pack; left unset when ``None``.

     Returns:
         The newly created :class:`MultiPack`.
     """
     return MultiPack(self._pack_manager, pack_name)
Example #29
0
 def _process_query(
         self, input_pack: MultiPack) -> Tuple[DataPack, Dict[str, Any]]:
     """Build a query from the query pack's text.

     Returns:
         The query pack together with the constructed query.
     """
     pack = input_pack.get_pack(self.configs.query_pack_name)
     return pack, self._build_query(text=pack.text)
Example #30
0
class SelectorTest(unittest.TestCase):
    """Tests for pack selectors, covering config-based initialization,
    backward-compatible constructor arguments, and reverse selection.
    """

    def setUp(self) -> None:
        self.multi_pack = MultiPack()

        # Reference names intentionally differ from pack names so the tests
        # can tell which lookup path a selector used.
        data_pack1 = self.multi_pack.add_pack(ref_name="pack1")
        data_pack2 = self.multi_pack.add_pack(ref_name="pack2")
        data_pack3 = self.multi_pack.add_pack(ref_name="pack_three")

        data_pack1.pack_name = "1"
        data_pack2.pack_name = "2"
        data_pack3.pack_name = "Three"

    def test_name_match_selector(self) -> None:
        # Select by exact reference name via configs.
        selector = NameMatchSelector()
        selector.initialize(
            configs={"select_name": "pack1"},
        )
        packs = selector.select(self.multi_pack)
        doc_ids = ["1"]
        for doc_id, pack in zip(doc_ids, packs):
            self.assertEqual(doc_id, pack.pack_name)

        # Test reverse selection.
        selector.initialize(
            configs={"select_name": "pack1", "reverse_selection": True},
        )
        packs = selector.select(self.multi_pack)
        doc_ids = ["2", "Three"]
        for doc_id, pack in zip(doc_ids, packs):
            self.assertEqual(doc_id, pack.pack_name)

    def test_name_match_selector_backward_compatability(self) -> None:
        # Old style: pass select_name to the constructor instead of configs.
        selector = NameMatchSelector(select_name="pack1")
        selector.initialize()
        packs = selector.select(self.multi_pack)
        doc_ids = ["1"]
        for doc_id, pack in zip(doc_ids, packs):
            self.assertEqual(doc_id, pack.pack_name)

        selector = NameMatchSelector("pack1")
        selector.initialize()
        packs = selector.select(self.multi_pack)
        doc_ids = ["1"]
        for doc_id, pack in zip(doc_ids, packs):
            self.assertEqual(doc_id, pack.pack_name)

    def test_regex_name_match_selector(self) -> None:
        # Select every pack whose reference name ends with a digit.
        selector = RegexNameMatchSelector()
        selector.initialize(
            configs={"select_name": "^.*\\d$"},
        )
        packs = selector.select(self.multi_pack)
        doc_ids = ["1", "2"]
        for doc_id, pack in zip(doc_ids, packs):
            self.assertEqual(doc_id, pack.pack_name)

        # Test reverse selection.
        selector.initialize(
            {"select_name": "^.*\\d$", "reverse_selection": True}
        )
        packs = selector.select(self.multi_pack)
        doc_ids = ["Three"]
        for doc_id, pack in zip(doc_ids, packs):
            self.assertEqual(doc_id, pack.pack_name)

    def test_regex_name_match_selector_backward_compatability(self) -> None:
        selector = RegexNameMatchSelector(select_name="^.*\\d$")
        selector.initialize()
        packs = selector.select(self.multi_pack)
        doc_ids = ["1", "2"]
        for doc_id, pack in zip(doc_ids, packs):
            self.assertEqual(doc_id, pack.pack_name)

        # Test different configuration method (backward compatibility)
        selector = RegexNameMatchSelector("^.*\\d$")
        selector.initialize()
        packs = selector.select(self.multi_pack)
        doc_ids = ["1", "2"]
        for doc_id, pack in zip(doc_ids, packs):
            self.assertEqual(doc_id, pack.pack_name)

        # Test reverse selection.
        selector.initialize({"reverse_selection": True})
        packs = selector.select(self.multi_pack)
        doc_ids = ["Three"]
        for doc_id, pack in zip(doc_ids, packs):
            self.assertEqual(doc_id, pack.pack_name)

    def test_first_pack_selector(self) -> None:
        # Forward: exactly the first pack; reversed: all but the first.
        selector = FirstPackSelector()
        selector.initialize()
        packs = list(selector.select(self.multi_pack))
        self.assertEqual(len(packs), 1)
        self.assertEqual(packs[0].pack_name, "1")

        # Test reverse selection.
        selector.initialize({"reverse_selection": True})
        packs = list(selector.select(self.multi_pack))
        self.assertEqual(len(packs), len(self.multi_pack.packs) - 1)

    def test_all_pack_selector(self) -> None:
        selector = AllPackSelector()
        selector.initialize()
        packs = selector.select(self.multi_pack)
        doc_ids = ["1", "2", "Three"]
        for doc_id, pack in zip(doc_ids, packs):
            self.assertEqual(doc_id, pack.pack_name)