def _parse_pack(self, collection: Any) -> Iterator[MultiPack]:
    multi_pack: MultiPack = MultiPack()
    data_pack1 = multi_pack.add_pack(ref_name="pack1")
    data_pack2 = multi_pack.add_pack(ref_name="pack2")
    data_pack3 = multi_pack.add_pack(ref_name="pack_three")

    data_pack1.pack_name = "1"
    data_pack2.pack_name = "2"
    data_pack3.pack_name = "Three"
    yield multi_pack

def _process(self, input_pack: MultiPack):
    first_pack = input_pack.get_pack_at(0)
    second_pack = input_pack.get_pack_at(1)
    nes1 = list(first_pack.get(EntityMention))
    nes2 = list(second_pack.get(EntityMention))

    # Link entity mentions across the two packs when their surface
    # text matches exactly.
    for ne1 in nes1:
        for ne2 in nes2:
            if ne1.text == ne2.text:
                CrossDocEntityRelation(input_pack, ne1, ne2)

def _process_query(
        self, input_pack: MultiPack
) -> Tuple[DataPack, Dict[str, Any]]:
    query_pack: DataPack = input_pack.get_pack(self.config.query_pack_name)
    context = [query_pack.text]

    # Use the context to build the query.
    if "user_utterance" in input_pack.pack_names:
        user_pack = input_pack.get_pack("user_utterance")
        context.append(user_pack.text)

    if "bot_utterance" in input_pack.pack_names:
        bot_pack = input_pack.get_pack("bot_utterance")
        context.append(bot_pack.text)

    text = " ".join(context)
    query_vector = self._build_query(text=text)
    return query_pack, query_vector

def new_pack(self, pack_name: Optional[str] = None) -> MultiPack:
    """
    Create a new multi pack using the current pack manager.

    Args:
        pack_name (str, Optional): The name to be used for the pack. If not
            set, the pack name will remain unset.

    Returns:
        A new `MultiPack` created by the current pack manager.
    """
    return MultiPack(self._pack_manager, pack_name)

def _process(self, input_pack: MultiPack):
    query = input_pack.get_pack(self.in_pack_name).text
    params = '?' + urlencode(
        {'api-version': '3.0',
         'from': self.src_language,
         'to': [self.target_language]},
        doseq=True)
    microsoft_constructed_url = self.microsoft_translate_url + params

    response = requests.post(
        microsoft_constructed_url,
        headers=self.microsoft_headers,
        json=[{"text": query}])

    if response.status_code != 200:
        raise RuntimeError(response.json()['error']['message'])

    text = response.json()[0]["translations"][0]["text"]
    pack: DataPack = input_pack.add_pack(self.out_pack_name)
    pack.set_text(text=text)

    Document(pack, 0, len(text))
    Utterance(pack, 0, len(text))

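# A minimal sketch of how a caller might read the translated text back out
# of the multi pack after the processor above runs. The pack name "result"
# is a hypothetical stand-in for the processor's `out_pack_name` setting.
from forte.data.multi_pack import MultiPack
from ft.onto.base_ontology import Utterance


def read_translation(input_pack: MultiPack,
                     out_pack_name: str = "result") -> str:
    # The processor stores the translation in a new DataPack, covered
    # end-to-end by a single Utterance annotation.
    pack = input_pack.get_pack(out_pack_name)
    return pack.get_single(Utterance).text
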
def consume_next(self, pred_pack: MultiPack, _):
    query_pack: DataPack = pred_pack.get_pack(self.configs.pack_name)
    query = list(query_pack.get(Query))[0]

    rank = 1
    for pid, _ in query.results.items():
        doc_id: Optional[str] = query_pack.pack_name
        if doc_id is None:
            raise ProcessExecutionException(
                'Doc ID of the query pack is not set, '
                'please double check the reader.')
        self.predicted_results.append((doc_id, pid, str(rank)))
        rank += 1

def _process(self, multi_pack: MultiPack):
    # Add two packs.
    p1 = multi_pack.add_pack('pack1')
    p2 = multi_pack.add_pack('pack2')

    # Add some entries into one pack.
    e1: ExampleEntry = p1.add_entry(ExampleEntry(p1))
    e1.secret_number = 1
    p2.add_entry(ExampleEntry(p2))

    # Add the multi pack entry.
    mp_entry = ExampleMPEntry(multi_pack)
    mp_entry.refer_entry = e1

def _process(self, input_pack: MultiPack):
    max_len = self.config.max_seq_length
    query_pack_name = self.config.query_pack_name
    query_pack = input_pack.get_pack(self.config.query_pack_name)
    query_entry = list(query_pack.get(Query))[0]
    query_text = query_pack.text

    doc_score_dict = query_entry.results
    best_doc_id = max(doc_score_dict, key=lambda x: doc_score_dict[x])

    packs = {}
    for doc_id in input_pack.pack_names:
        if doc_id == query_pack_name:
            continue

        pack = input_pack.get_pack(doc_id)
        doc_id_final = pack.pack_name
        if doc_id_final != best_doc_id:
            query_entry.update_results({doc_id_final: ""})
            continue

        query_doc_input = {'question': query_text, 'context': pack.text}
        result = self.qa_pipeline(query_doc_input)

        # Expand the answer phrase to the whole sentence where it appears.
        # `spacy_nlp` is a spaCy language model loaded elsewhere.
        answer_phrase = result['answer']
        ans_sents = [sent.text for sent in spacy_nlp(pack.text).sents]

        answer = None
        for sent in ans_sents:
            if answer_phrase in sent:
                answer = sent
                break
        if not answer:
            answer = answer_phrase

        query_entry.update_results({doc_id_final: answer})
        packs[doc_id] = pack

def _process(self, input_pack: MultiPack):
    # Get the pack names for augmentation.
    aug_pack_names: List[str] = []

    # Check if the DataPack exists.
    for pack_name in self.configs["augment_pack_names"]["kwargs"].keys():
        if pack_name in input_pack.pack_names:
            aug_pack_names.append(pack_name)

    if len(self.configs["augment_pack_names"]["kwargs"].keys()) == 0:
        # Augment all the DataPacks if not specified.
        aug_pack_names = list(input_pack.pack_names)

    self._augment(input_pack, aug_pack_names)

    new_packs: List[Tuple[str, DataPack]] = []
    for aug_pack_name in aug_pack_names:
        new_pack_name: str = \
            self.configs["augment_pack_names"]["kwargs"].get(
                aug_pack_name, "augmented_" + aug_pack_name
            )
        data_pack = input_pack.get_pack(aug_pack_name)
        new_pack = self._auto_align_annotations(
            data_pack=data_pack,
            replaced_annotations=self._replaced_annos[
                data_pack.meta.pack_id])
        new_packs.append((new_pack_name, new_pack))

    for new_pack_name, new_pack in new_packs:
        input_pack.add_pack_(new_pack, new_pack_name)

    # Copy the MultiPackLinks/MultiPackGroups.
    for mpl in input_pack.get(MultiPackLink):
        self._copy_multi_pack_link_or_group(mpl, input_pack)
    for mpg in input_pack.get(MultiPackGroup):
        self._copy_multi_pack_link_or_group(mpg, input_pack)

    # Must be called after processing each multi pack
    # to reset internal states.
    self._clear_states()

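# For reference, a hedged sketch of the configuration shape read above. The
# pack names are hypothetical; with an empty "kwargs" mapping, every DataPack
# is augmented and the new packs get the default "augmented_" prefix.
example_configs = {
    "augment_pack_names": {
        "kwargs": {
            # original pack name -> name for its augmented copy
            "input_src": "augmented_input_src",
        }
    }
}
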
class SelectorTest(unittest.TestCase):

    def setUp(self) -> None:
        pm = PackManager()
        self.multi_pack = MultiPack(pm)

        data_pack1 = self.multi_pack.add_pack(ref_name="pack1")
        data_pack2 = self.multi_pack.add_pack(ref_name="pack2")
        data_pack3 = self.multi_pack.add_pack(ref_name="pack_three")
        data_pack1.pack_name = "1"
        data_pack2.pack_name = "2"
        data_pack3.pack_name = "Three"

    def test_name_match_selector(self) -> None:
        selector = NameMatchSelector(select_name="pack1")
        packs = selector.select(self.multi_pack)
        doc_ids = ["1"]
        for doc_id, pack in zip(doc_ids, packs):
            self.assertEqual(doc_id, pack.pack_name)

    def test_regex_name_match_selector(self) -> None:
        selector = RegexNameMatchSelector(select_name="^.*\\d$")
        packs = selector.select(self.multi_pack)
        doc_ids = ["1", "2"]
        for doc_id, pack in zip(doc_ids, packs):
            self.assertEqual(doc_id, pack.pack_name)

    def test_first_pack_selector(self) -> None:
        selector = FirstPackSelector()
        packs = list(selector.select(self.multi_pack))
        self.assertEqual(len(packs), 1)
        self.assertEqual(packs[0].pack_name, "1")

    def test_all_pack_selector(self) -> None:
        selector = AllPackSelector()
        packs = selector.select(self.multi_pack)
        doc_ids = ["1", "2", "Three"]
        for doc_id, pack in zip(doc_ids, packs):
            self.assertEqual(doc_id, pack.pack_name)

def _process(self, input_pack: MultiPack): r"""Searches `Elasticsearch` indexer to fetch documents for a query. This query should be contained in the input multipack with name `self.config.query_pack_name`. This method adds new packs to `input_pack` containing the retrieved results. Each result is added as a `ft.onto.base_ontology.Document`. Args: input_pack: A multipack containing query as a pack. """ query_pack = input_pack.get_pack(self.configs.query_pack_name) # ElasticSearchQueryCreator adds a Query entry to query pack. We now # fetch it as the first element. first_query: Query = query_pack.get_single(Query) # pylint: disable=isinstance-second-argument-not-valid-type # TODO: until fix: https://github.com/PyCQA/pylint/issues/3507 if not isinstance(first_query.value, Dict): raise ValueError( "The query to the elastic indexer need to be a dictionary.") results = self.index.search(first_query.value) hits = results["hits"]["hits"] for idx, hit in enumerate(hits): document = hit["_source"] first_query.add_result(document["doc_id"], hit["_score"]) pack: DataPack = input_pack.add_pack( f"{self.configs.response_pack_name_prefix}_{idx}" ) pack.pack_name = document["doc_id"] content = document[self.configs.field] pack.set_text(content) Document(pack=pack, begin=0, end=len(content))
def _augment(self, input_pack: MultiPack, aug_pack_names: List[str]):
    r"""
    This function splits a given word at a random position and replaces
    the original word with its two split parts.
    """
    augment_entry = get_class(self.configs["augment_entry"])

    for pack_name in aug_pack_names:
        data_pack: DataPack = input_pack.get_pack(pack_name)
        annotations: List[Annotation] = []
        indexes: List[int] = []
        endings = []

        annos: Iterable[Annotation] = data_pack.get(augment_entry)
        for idx, anno in enumerate(annos):
            annotations.append(anno)
            indexes.append(idx)
            endings.append(anno.end)

        if len(annotations) > 0:
            annotation_to_split = random.sample(
                [(anno, idx) for (anno, idx) in zip(annotations, indexes)
                 if (anno.end - anno.begin) > 1],
                ceil(self.configs["alpha"] * len(annotations)),
            )
            annotation_to_split = sorted(
                annotation_to_split, key=lambda x: x[1], reverse=True)

            for curr_anno in annotation_to_split:
                src_anno, src_idx = curr_anno
                # Pick a split point strictly inside the word.
                splitting_position = random.randrange(
                    1, (src_anno.end - src_anno.begin))
                word_split = [
                    src_anno.text[:splitting_position],
                    src_anno.text[splitting_position:],
                ]
                if src_idx != 0:
                    first_position = endings[src_idx - 1] + 1
                    second_position = endings[src_idx]
                else:
                    first_position = 0
                    second_position = endings[0]
                # The second half is re-inserted with a leading space.
                word_split[1] = " " + word_split[1]

                self._insert(word_split[1], data_pack, second_position)
                self._delete(src_anno)
                self._insert(word_split[0], data_pack, first_position)

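# A standalone rehearsal of the split computed above, to make the index
# arithmetic concrete. "splitting" is just a sample token.
import random

word = "splitting"
# Mirror the randrange call above: a split point strictly inside the word.
position = random.randrange(1, len(word))
parts = [word[:position], word[position:]]
# e.g. position == 4 gives ["spli", "tting"]; after augmentation the second
# part is re-inserted with a leading space, turning one token into two.
print(parts)
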
def test_entry_attribute_mp_pointer(self):
    mpe: ExampleMPEntry = self.pack.get_single(ExampleMPEntry)
    self.assertIsInstance(mpe.refer_entry, ExampleEntry)
    self.assertIsInstance(mpe.__dict__["refer_entry"], ExampleEntry)

    # Round-trip the multi pack and its data packs through serialization.
    serialized_mp = self.pack.to_string(drop_record=True)
    recovered_mp = MultiPack.from_string(serialized_mp)

    s_packs = [p.to_string() for p in self.pack.packs]
    recovered_packs = [DataPack.from_string(s) for s in s_packs]

    recovered_mp.relink(recovered_packs)

    re_mpe: ExampleMPEntry = recovered_mp.get_single(ExampleMPEntry)
    self.assertIsInstance(re_mpe.refer_entry, ExampleEntry)
    self.assertEqual(re_mpe.refer_entry.tid, mpe.refer_entry.tid)
    self.assertEqual(re_mpe.tid, mpe.tid)

def _parse_pack(self, multi_pack_path: str) -> Iterator[MultiPack]:
    # pylint: disable=protected-access
    with open(os.path.join(
            self.configs.data_path, multi_pack_path)) as m_data:
        m_pack: MultiPack = MultiPack.deserialize(m_data.read())

        for pid in m_pack._pack_ref:
            sub_pack_path = self.__pack_index[pid]

            if self._pack_manager.get_remapped_id(pid) >= 0:
                # This pid has already been read.
                continue

            with open(os.path.join(
                    self.configs.data_path, sub_pack_path)) as pack_data:
                pack: DataPack = DataPack.deserialize(pack_data.read())
                # Add a reference count to this pack, because the
                # multi pack needs it.
                self._pack_manager.reference_pack(pack)

        m_pack.realign_packs()
        yield m_pack

def consume_next(self, pred_pack: MultiPack, _):
    query_pack: DataPack = pred_pack.get_pack(self.configs.pack_name)
    query = list(query_pack.get(Query))[0]

    # Rank documents by their retrieval scores, highest first.
    sorted_query_results = sorted(
        list(query.results.items()), key=lambda x: x[1], reverse=True)

    rank = 1
    for pid, _ in sorted_query_results:
        doc_id: Optional[str] = query_pack.pack_name
        if doc_id is None:
            raise ProcessExecutionException(
                'Doc ID of the query pack is not set, '
                'please double check the reader.')
        self.predicted_results.append((doc_id, pid, str(rank)))
        rank += 1

def _parse_pack(self, multi_pack_str: str) -> Iterator[MultiPack]:
    m_pack: MultiPack = MultiPack.deserialize(multi_pack_str)

    for pid in m_pack.pack_ids():
        p_content = self._get_pack_content(pid)
        if p_content is None:
            logging.warning(
                "Cannot locate the data pack with pid %d "
                "for multi pack %d", pid, m_pack.pack_id)
            break

        pack: DataPack
        if isinstance(p_content, str):
            pack = DataPack.deserialize(p_content)
        else:
            pack = p_content
        # Only in deserialization can we do this.
        m_pack.packs.append(pack)
    else:
        # The `for ... else` only reaches this branch when the loop did
        # not break, so no multi pack will be yielded if any data pack
        # could not be located.
        yield m_pack

def _augment(self, input_pack: MultiPack, aug_pack_names: List[str]):
    augment_entry = get_class(self.configs["augment_entry"])

    for pack_name in aug_pack_names:
        data_pack: DataPack = input_pack.get_pack(pack_name)
        annotations = list(data_pack.get(augment_entry))

        if len(annotations) > 0:
            replace_map: Dict = {}
            for _ in range(ceil(self.configs['alpha'] * len(annotations))):
                swap_idx = random.sample(range(len(annotations)), 2)
                # Follow existing chains so that repeated swaps compose.
                new_idx_0 = swap_idx[1] if swap_idx[1] not in replace_map \
                    else replace_map[swap_idx[1]]
                new_idx_1 = swap_idx[0] if swap_idx[0] not in replace_map \
                    else replace_map[swap_idx[0]]
                replace_map[swap_idx[0]] = new_idx_0
                replace_map[swap_idx[1]] = new_idx_1

            pid: int = data_pack.pack_id
            for idx in replace_map:
                self._replaced_annos[pid].add(
                    (annotations[idx].span,
                     annotations[replace_map[idx]].text))

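# A hedged, self-contained rehearsal of the replace_map bookkeeping above on
# a plain list of strings, to show how the index swaps compose.
import random

words = ["alpha", "beta", "gamma", "delta"]
replace_map = {}
for _ in range(2):
    i, j = random.sample(range(len(words)), 2)
    # Follow existing chains, exactly as in _augment above.
    new_i = j if j not in replace_map else replace_map[j]
    new_j = i if i not in replace_map else replace_map[i]
    replace_map[i], replace_map[j] = new_i, new_j

# Each entry says: the word at position idx is replaced by the text that
# currently sits at position replace_map[idx].
swapped = [words[replace_map.get(k, k)] for k in range(len(words))]
print(swapped)
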
def _parse_pack(self, file_path: str) -> Iterator[MultiPack]:
    m_pack: MultiPack = MultiPack()

    input_pack_name = "input_src"
    output_pack_name = "output_tgt"

    with open(file_path, "r", encoding="utf8") as doc:
        text = ""
        offset = 0
        sentence_cnt = 0

        input_pack = DataPack(doc_id=file_path)

        for line in doc:
            line = line.strip()
            if len(line) == 0:
                # Skip empty lines.
                continue

            # Add a sentence.
            sent = Sentence(input_pack, offset, offset + len(line))
            input_pack.add_entry(sent)

            text += line + '\n'
            offset = offset + len(line) + 1
            sentence_cnt += 1

            if sentence_cnt >= 20:
                break

        input_pack.set_text(text, replace_func=self.text_replace_operation)

        output_pack = DataPack()

        m_pack.update_pack({
            input_pack_name: input_pack,
            output_pack_name: output_pack
        })

        yield m_pack

def _augment(self, input_pack: MultiPack, aug_pack_names: List[str]):
    r"""
    This function calls the data augmentation ops and modifies the input
    in-place. It only applies to replacement-based methods. Subclasses
    should override this function to implement other data augmentation
    methods, such as Easy Data Augmentation.

    Args:
        input_pack: The input MultiPack.
        aug_pack_names: The names of the DataPacks to be augmented.
    """
    replacement_op = create_class_with_kwargs(
        self.configs["data_aug_op"],
        class_args={
            "configs": self.configs["data_aug_op_config"]["kwargs"]
        })
    augment_entry = get_class(self.configs["augment_entry"])

    for pack_name in aug_pack_names:
        data_pack: DataPack = input_pack.get_pack(pack_name)
        for anno in data_pack.get(augment_entry):
            self._replace(replacement_op, anno)

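# A hedged illustration of the configuration this method expects. The op
# class path and entry type below are hypothetical examples, not the only
# valid values.
example_configs = {
    # The annotation type whose spans are fed to the replacement op.
    "augment_entry": "ft.onto.base_ontology.Token",
    # Fully qualified class name of a replacement-based augmentation op.
    "data_aug_op":
        "forte.processors.data_augment.algorithms."
        "dictionary_replacement_op.DictionaryReplacementOp",
    "data_aug_op_config": {"kwargs": {}},
}
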
def _parse_pack(
        self, base_and_path: Tuple[str, str]) -> Iterator[MultiPack]:
    base_dir, file_path = base_and_path

    m_pack: MultiPack = MultiPack()

    input_pack_name = self.configs.input_pack_name
    output_pack_name = self.configs.output_pack_name

    text = ""
    offset = 0

    with open(file_path, "r", encoding="utf8") as doc:
        # Remove the long base directory path from the beginning to get
        # a relative document id.
        doc_id = file_path[
            file_path.startswith(base_dir) and len(base_dir):]
        doc_id = doc_id.strip(os.path.sep)

        input_pack = m_pack.add_pack(input_pack_name)
        input_pack.pack_name = doc_id

        for line in doc:
            line = line.strip()
            if len(line) == 0:
                continue

            # Add a sentence.
            Sentence(input_pack, offset, offset + len(line))
            text += line + '\n'
            offset = offset + len(line) + 1

        input_pack.set_text(
            text, replace_func=self.text_replace_operation)

        # Create an output pack without text.
        m_pack.add_pack(output_pack_name)

        yield m_pack

def _parse_pack(self, file_path: str) -> Iterator[MultiPack]:
    m_pack: MultiPack = MultiPack()

    input_pack_name = self.config.input_pack_name
    output_pack_name = self.config.output_pack_name

    text = ""
    offset = 0

    with open(file_path, "r", encoding="utf8") as doc:
        input_pack = DataPack(doc_id=file_path)

        for line in doc:
            line = line.strip()
            if len(line) == 0:
                continue

            # Add a sentence.
            sent = Sentence(input_pack, offset, offset + len(line))
            input_pack.add_entry(sent)

            text += line + '\n'
            offset = offset + len(line) + 1

        input_pack.set_text(
            text, replace_func=self.text_replace_operation)

        output_pack = DataPack()

        m_pack.update_pack({
            input_pack_name: input_pack,
            output_pack_name: output_pack
        })

        yield m_pack

class DataPackTest(unittest.TestCase):

    def setUp(self) -> None:
        # Note: the input source is created automatically by the system,
        # but we can also set it manually in test cases.
        pm = PackManager()
        self.multi_pack = MultiPack(pm)
        self.data_pack1 = self.multi_pack.add_pack(ref_name="left pack")
        self.data_pack2 = self.multi_pack.add_pack(ref_name="right pack")
        self.data_pack1.pack_name = "some pack"
        self.data_pack1.set_text("This pack contains some sample data.")
        self.data_pack2.pack_name = "another pack"
        self.data_pack2.set_text("This pack contains some other sample data.")

    def test_serialization(self):
        ser_str: str = self.multi_pack.serialize()
        print(ser_str)

    def test_add_pack(self):
        data_pack3 = self.multi_pack.add_pack(ref_name="new pack")
        data_pack3.pack_name = "the third pack"
        data_pack3.set_text("Test to see if we can add new packs..")

        self.assertEqual(len(self.multi_pack.packs), 3)
        self.assertEqual(self.multi_pack.pack_names,
                         {'left pack', 'right pack', 'new pack'})

    def test_rename_pack(self):
        self.multi_pack.rename_pack('right pack', 'last pack')
        self.assertEqual(self.multi_pack.pack_names,
                         {'left pack', 'last pack'})

    def test_multipack_groups(self):
        """Test multi pack groups."""
        # Add tokens to each pack.
        for pack in self.multi_pack.packs:
            _space_token(pack)

        # Create some groups.
        token: Annotation
        left_tokens = {}
        for token in self.multi_pack.packs[0].get(Token):
            left_tokens[token.text] = token

        right_tokens = {}
        for token in self.multi_pack.packs[1].get(Token):
            right_tokens[token.text] = token

        for key, lt in left_tokens.items():
            if key in right_tokens:
                rt = right_tokens[key]
                self.multi_pack.add_entry(
                    MultiPackGroup(self.multi_pack, [lt, rt]))

        # Check the groups.
        expected_content = [
            ("This", "This"), ("pack", "pack"), ("contains", "contains"),
            ("some", "some"), ("sample", "sample"), ("data.", "data.")]

        group_content = []
        g: MultiPackGroup
        for g in self.multi_pack.get(MultiPackGroup):
            e: Annotation
            group_content.append(tuple(e.text for e in g.get_members()))

        self.assertListEqual(expected_content, group_content)

    def test_multipack_entries(self):
        """Test multi pack entries."""
        # 1. Add tokens to each pack.
        for pack in self.multi_pack.packs:
            _space_token(pack)

        left_tokens = [t.text for t in self.multi_pack.packs[0].get(Token)]
        right_tokens = [t.text for t in self.multi_pack.packs[1].get(Token)]

        self.assertListEqual(
            left_tokens,
            ["This", "pack", "contains", "some", "sample", "data."])
        self.assertListEqual(
            right_tokens,
            ["This", "pack", "contains", "some", "other", "sample",
             "data."])

        # 2. Link the same words from the two packs.
        token: Annotation
        left_tokens = {}
        for token in self.multi_pack.packs[0].get(Token):
            left_tokens[token.text] = token

        right_tokens = {}
        for token in self.multi_pack.packs[1].get(Token):
            right_tokens[token.text] = token

        for key, lt in left_tokens.items():
            if key in right_tokens:
                rt = right_tokens[key]
                self.multi_pack.add_entry(
                    MultiPackLink(self.multi_pack, lt, rt))

        # One way to visit the links.
        linked_tokens = []
        for link in self.multi_pack.links:
            parent_text = link.get_parent().text
            child_text = link.get_child().text
            linked_tokens.append((parent_text, child_text))

        self.assertListEqual(
            linked_tokens,
            [("This", "This"), ("pack", "pack"), ("contains", "contains"),
             ("some", "some"), ("sample", "sample"), ("data.", "data.")])

        # Another way to get the links.
        linked_tokens = []
        for link in self.multi_pack.get(MultiPackLink):
            parent_text = link.get_parent().text
            child_text = link.get_child().text
            linked_tokens.append((parent_text, child_text))

        self.assertListEqual(
            linked_tokens,
            [("This", "This"), ("pack", "pack"), ("contains", "contains"),
             ("some", "some"), ("sample", "sample"), ("data.", "data.")])

        # 3. Test deletion: delete the second link.
        self.multi_pack.delete_entry(self.multi_pack.links[1])

        linked_tokens = []
        for link in self.multi_pack.links:
            parent_text = link.get_parent().text
            child_text = link.get_child().text
            linked_tokens.append((parent_text, child_text))

        self.assertListEqual(
            linked_tokens,
            [("This", "This"), ("contains", "contains"),
             ("some", "some"), ("sample", "sample"), ("data.", "data.")])

def _process(self, input_pack: MultiPack):
    # Create a new pack holding a copy of the first pack's text.
    pack = input_pack.add_pack()
    pack.set_text(input_pack.get_pack_at(0).text)

def _process(self, input_pack: MultiPack):
    rels = list(input_pack.get_entries_of(CrossDocEntityRelation))
    self.coref_count += len(rels)

def _process(self, input_pack: MultiPack):
    rels = input_pack.get_entries_by_type(CrossDocEntityRelation)
    self.coref_count += len(rels)

def _parse_pack(self, name: str) -> Iterator[MultiPack]:
    p = MultiPack()
    p.pack_name = name
    yield p

def _process(self, input_pack: MultiPack):
    # `docs` maps document names to their text; it is defined outside
    # this snippet.
    for doc_i in docs:
        pack = input_pack.add_pack(ref_name=doc_i)
        pack.set_text(docs[doc_i])
        Document(pack, 0, len(pack.text))

def new_pack(self, pack_name: Optional[str] = None) -> MultiPack:
    return MultiPack(self._pack_manager, pack_name)

def _process_query(
        self, input_pack: MultiPack) -> Tuple[DataPack, Dict[str, Any]]:
    query_pack = input_pack.get_pack(self.configs.query_pack_name)
    query = self._build_query(text=query_pack.text)
    return query_pack, query

class SelectorTest(unittest.TestCase):

    def setUp(self) -> None:
        self.multi_pack = MultiPack()

        data_pack1 = self.multi_pack.add_pack(ref_name="pack1")
        data_pack2 = self.multi_pack.add_pack(ref_name="pack2")
        data_pack3 = self.multi_pack.add_pack(ref_name="pack_three")
        data_pack1.pack_name = "1"
        data_pack2.pack_name = "2"
        data_pack3.pack_name = "Three"

    def test_name_match_selector(self) -> None:
        selector = NameMatchSelector()
        selector.initialize(
            configs={"select_name": "pack1"},
        )
        packs = selector.select(self.multi_pack)
        doc_ids = ["1"]
        for doc_id, pack in zip(doc_ids, packs):
            self.assertEqual(doc_id, pack.pack_name)

        # Test reverse selection.
        selector.initialize(
            configs={"select_name": "pack1", "reverse_selection": True},
        )
        packs = selector.select(self.multi_pack)
        doc_ids = ["2", "Three"]
        for doc_id, pack in zip(doc_ids, packs):
            self.assertEqual(doc_id, pack.pack_name)

    def test_name_match_selector_backward_compatibility(self) -> None:
        selector = NameMatchSelector(select_name="pack1")
        selector.initialize()
        packs = selector.select(self.multi_pack)
        doc_ids = ["1"]
        for doc_id, pack in zip(doc_ids, packs):
            self.assertEqual(doc_id, pack.pack_name)

        selector = NameMatchSelector("pack1")
        selector.initialize()
        packs = selector.select(self.multi_pack)
        doc_ids = ["1"]
        for doc_id, pack in zip(doc_ids, packs):
            self.assertEqual(doc_id, pack.pack_name)

    def test_regex_name_match_selector(self) -> None:
        selector = RegexNameMatchSelector()
        selector.initialize(
            configs={"select_name": "^.*\\d$"},
        )
        packs = selector.select(self.multi_pack)
        doc_ids = ["1", "2"]
        for doc_id, pack in zip(doc_ids, packs):
            self.assertEqual(doc_id, pack.pack_name)

        # Test reverse selection.
        selector.initialize(
            {"select_name": "^.*\\d$", "reverse_selection": True}
        )
        packs = selector.select(self.multi_pack)
        doc_ids = ["Three"]
        for doc_id, pack in zip(doc_ids, packs):
            self.assertEqual(doc_id, pack.pack_name)

    def test_regex_name_match_selector_backward_compatibility(self) -> None:
        selector = RegexNameMatchSelector(select_name="^.*\\d$")
        selector.initialize()
        packs = selector.select(self.multi_pack)
        doc_ids = ["1", "2"]
        for doc_id, pack in zip(doc_ids, packs):
            self.assertEqual(doc_id, pack.pack_name)

        # Test a different configuration method (backward compatibility).
        selector = RegexNameMatchSelector("^.*\\d$")
        selector.initialize()
        packs = selector.select(self.multi_pack)
        doc_ids = ["1", "2"]
        for doc_id, pack in zip(doc_ids, packs):
            self.assertEqual(doc_id, pack.pack_name)

        # Test reverse selection.
        selector.initialize({"reverse_selection": True})
        packs = selector.select(self.multi_pack)
        doc_ids = ["Three"]
        for doc_id, pack in zip(doc_ids, packs):
            self.assertEqual(doc_id, pack.pack_name)

    def test_first_pack_selector(self) -> None:
        selector = FirstPackSelector()
        selector.initialize()
        packs = list(selector.select(self.multi_pack))
        self.assertEqual(len(packs), 1)
        self.assertEqual(packs[0].pack_name, "1")

        # Test reverse selection.
        selector.initialize({"reverse_selection": True})
        packs = list(selector.select(self.multi_pack))
        self.assertEqual(len(packs), len(self.multi_pack.packs) - 1)

    def test_all_pack_selector(self) -> None:
        selector = AllPackSelector()
        selector.initialize()
        packs = selector.select(self.multi_pack)
        doc_ids = ["1", "2", "Three"]
        for doc_id, pack in zip(doc_ids, packs):
            self.assertEqual(doc_id, pack.pack_name)

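# Outside of unit tests, a selector can be exercised directly on a multi
# pack. A minimal runnable sketch, assuming NameMatchSelector is importable
# from forte.data.selector; the pack names are arbitrary examples mirroring
# the fixture above.
from forte.data.multi_pack import MultiPack
from forte.data.selector import NameMatchSelector

multi_pack = MultiPack()
for ref_name, pack_name in [("pack1", "1"), ("pack2", "2")]:
    pack = multi_pack.add_pack(ref_name=ref_name)
    pack.pack_name = pack_name

selector = NameMatchSelector()
selector.initialize(configs={"select_name": "pack1"})
for selected in selector.select(multi_pack):
    print(selected.pack_name)  # -> "1"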