def _process(self, input_pack: MultiPack):
    """Run extractive question answering over every document pack in the
    multi pack and link each predicted answer span back to the question.

    Every pack except the question pack contributes one context (the text
    of its single `entry_type` entry). The question pack's full text is
    used as the question for every context. For each QA result, a
    ``Phrase`` covering the answer span is created in the corresponding
    document pack and linked to the question sentence via a
    ``MultiPackLink``.
    """
    doc_ids = []
    contexts = []
    for name in input_pack.pack_names:
        # Skip the pack that holds the question itself.
        if name == self.configs.question_pack_name:
            continue
        doc_pack = input_pack.get_pack(name)
        contexts.append(doc_pack.get_single(self.configs.entry_type).text)
        doc_ids.append(name)

    question_pack = input_pack.get_pack(self.configs.question_pack_name)
    first_question = question_pack.get_single(Sentence)
    # One copy of the question text per context, as the pipeline expects
    # parallel lists.
    questions = [question_pack.text] * len(contexts)

    results = self.extractor(
        context=contexts,
        question=questions,
        max_answer_len=self.configs.max_answer_len,
        handle_impossible_answer=self.configs.handle_impossible_answer,
    )

    for doc_name, result in zip(doc_ids, results):
        target_pack = input_pack.get_pack(doc_name)
        # The result offsets are character offsets into the context text.
        answer = Phrase(
            pack=target_pack, begin=result["start"], end=result["end"]
        )
        input_pack.add_entry(
            MultiPackLink(input_pack, first_question, answer)
        )
def pack(self, data_pack: MultiPack, output_dict):
    """
    Write the prediction results back to datapack. If :attr:`_overwrite`
    is `True`, write the predicted ner to the original tokens.
    Otherwise, create a new set of tokens and write the predicted ner
    to the new tokens (usually use this configuration for evaluation.)

    Args:
        data_pack: The multi pack holding both the input pack and the
            output pack (looked up by ``self.input_pack_name`` /
            ``self.output_pack_name``).
        output_dict: Prediction results; must contain parallel lists
            under ``"input_sents_tids"`` (tids of the source sentences)
            and ``"output_sents"`` (generated sentence strings).
    """
    assert output_dict is not None
    output_pack = data_pack.get_pack(self.output_pack_name)

    input_sent_tids = output_dict["input_sents_tids"]
    output_sentences = output_dict["output_sents"]

    text = output_pack.text
    input_pack = data_pack.get_pack(self.input_pack_name)
    for input_id, output_sentence in zip(input_sent_tids, output_sentences):
        # FIX: compute the begin offset from the locally accumulated
        # `text`, not from `output_pack.text`. The pack's text is only
        # updated once via `set_text(text)` after this loop, so reading
        # it here gave every sentence after the first the same stale
        # (wrong) offset, producing overlapping Sentence spans.
        offset = len(text)
        sent = Sentence(output_pack, offset, offset + len(output_sentence))
        text += output_sentence + "\n"

        input_sent = input_pack.get_entry(input_id)
        cross_link = MultiPackLink(data_pack, input_sent, sent)
        data_pack.add_entry(cross_link)
        # We may also consider adding two links with opposite directions.
        # Here the unidirectional link indicates the generation dependency.
    output_pack.set_text(text)
def _copy_multi_pack_link_or_group( self, entry: Union[MultiPackLink, MultiPackGroup], multi_pack: MultiPack ) -> bool: r""" This function copies a MultiPackLink/MultiPackGroup in the multipack. It could be used in tasks such as text generation, where MultiPackLink is used to align the source and target. Args: entry: The MultiPackLink/MultiPackGroup to copy. multi_pack: The multi_pack contains the input entry. Returns: A bool value indicating whether the copy happens. """ # The entry should be either MultiPackLink or MultiPackGroup. is_link: bool = isinstance(entry, BaseLink) children: List[Entry] if is_link: children = [entry.get_parent(), entry.get_child()] else: children = entry.get_members() # Get the copied children entries. new_children: List[Entry] = [] for child_entry in children: child_pack: DataPack = child_entry.pack child_pack_pid: int = child_pack.pack_id # The new pack should be present. if ( child_pack_pid not in self._data_pack_map or child_pack_pid not in self._entry_maps ): return False new_child_pack: DataPack = multi_pack.get_pack_at( multi_pack.get_pack_index(self._data_pack_map[child_pack_pid]) ) # The new child entry should be present. if child_entry.tid not in self._entry_maps[child_pack_pid]: return False new_child_tid: int = self._entry_maps[child_pack_pid][ child_entry.tid ] new_child_entry: Entry = new_child_pack.get_entry(new_child_tid) new_children.append(new_child_entry) # Create the new entry and add to the multi pack. new_entry: Entry if is_link: entry = cast(MultiPackLink, entry) new_link_parent: Entry = new_children[0] new_link_child: Entry = new_children[1] new_entry = type(entry)( multi_pack, new_link_parent, new_link_child # type: ignore ) else: entry = cast(MultiPackGroup, entry) new_entry = type(entry)(multi_pack, new_children) # type: ignore multi_pack.add_entry(new_entry) return True
def _process(self, input_pack: MultiPack):
    """Pair up entity mentions from the 'default' and 'duplicate' packs
    in order and connect each pair with a cross-document coreference
    relation on the multi pack.
    """
    default_pack = input_pack.get_pack('default')
    duplicate_pack = input_pack.get_pack('duplicate')
    mention_pairs = zip(
        default_pack.get(EntityMention),
        duplicate_pack.get(EntityMention),
    )
    for left_mention, right_mention in mention_pairs:
        relation = CrossDocEntityRelation(
            input_pack, left_mention, right_mention
        )
        relation.rel_type = 'coreference'
        input_pack.add_entry(relation)
def test_multi_pack_copy_link_or_group(self):
    """A MultiPackLink whose children were never copied must not be
    copied, while annotation alignment still keeps the source token."""
    processor = ReplacementDataAugmentProcessor()

    multi = MultiPack()
    source_pack = multi.add_pack("src")
    target_pack = multi.add_pack("tgt")
    source_pack.set_text("input")
    target_pack.set_text("output")

    source_token = source_pack.add_entry(
        Token(source_pack, 0, len(source_pack.text))
    )
    target_token = target_pack.add_entry(
        Token(target_pack, 0, len(target_pack.text))
    )
    link = multi.add_entry(
        MultiPackLink(multi, source_token, target_token)
    )

    # The MultiPackLink should not be copied, because its children are
    # not copied.
    self.assertEqual(
        processor._copy_multi_pack_link_or_group(link, multi), False
    )

    aligned_src_pack = processor._auto_align_annotations(source_pack, [])
    self.assertEqual(len(list(aligned_src_pack.get(Token))), 1)
class DataPackTest(unittest.TestCase):
    """Unit tests for MultiPack behavior: pack management (add/rename),
    serialization, and cross-pack groups and links built over tokens of
    two fixture packs created in :meth:`setUp`.
    """

    def setUp(self) -> None:
        # Note: input source is created automatically by the system, but we
        # can also set it manually at test cases.
        pm = PackManager()
        self.multi_pack = MultiPack(pm)
        # Two packs with overlapping word content; "other" only appears
        # in the second pack, which the link/group tests rely on.
        self.data_pack1 = self.multi_pack.add_pack(ref_name="left pack")
        self.data_pack2 = self.multi_pack.add_pack(ref_name="right pack")
        self.data_pack1.pack_name = "some pack"
        self.data_pack1.set_text("This pack contains some sample data.")
        self.data_pack2.pack_name = "another pack"
        self.data_pack2.set_text("This pack contains some other sample data.")

    def test_serialization(self):
        # Smoke test: serialization should succeed without raising.
        ser_str: str = self.multi_pack.serialize()
        print(ser_str)

    def test_add_pack(self):
        data_pack3 = self.multi_pack.add_pack(ref_name="new pack")
        data_pack3.pack_name = "the third pack"
        data_pack3.set_text("Test to see if we can add new packs..")
        self.assertEqual(len(self.multi_pack.packs), 3)
        self.assertEqual(self.multi_pack.pack_names,
                         {'left pack', 'right pack', 'new pack'})

    def test_rename_pack(self):
        self.multi_pack.rename_pack('right pack', 'last pack')
        self.assertEqual(self.multi_pack.pack_names,
                         {'left pack', 'last pack'})

    def test_multipack_groups(self):
        """
        Test some multi pack group.
        Returns:

        """
        # Add tokens to each pack.
        for pack in self.multi_pack.packs:
            _space_token(pack)

        # Create some group.
        token: Annotation
        # Index tokens of each pack by their surface text so matching
        # words across the two packs can be paired up.
        left_tokens = {}
        for token in self.multi_pack.packs[0].get(Token):
            left_tokens[token.text] = token

        right_tokens = {}
        for token in self.multi_pack.packs[1].get(Token):
            right_tokens[token.text] = token

        # Group each word that appears in both packs.
        for key, lt in left_tokens.items():
            if key in right_tokens:
                rt = right_tokens[key]
                self.multi_pack.add_entry(MultiPackGroup(
                    self.multi_pack, [lt, rt]))

        # Check the groups.
        # "other" only occurs in the right pack, so it is absent here.
        expected_content = [("This", "This"), ("pack", "pack"),
                            ("contains", "contains"), ("some", "some"),
                            ("sample", "sample"), ("data.", "data.")]

        group_content = []
        g: MultiPackGroup
        for g in self.multi_pack.get(MultiPackGroup):
            e: Annotation
            group_content.append(tuple([e.text for e in g.get_members()]))

        self.assertListEqual(expected_content, group_content)

    def test_multipack_entries(self):
        """
        Test some multi pack entry.
        Returns:

        """
        # 1. Add tokens to each pack.
        for pack in self.multi_pack.packs:
            _space_token(pack)

        left_tokens = [t.text for t in self.multi_pack.packs[0].get(Token)]
        right_tokens = [t.text for t in self.multi_pack.packs[1].get(Token)]

        self.assertListEqual(left_tokens,
                             ["This", "pack", "contains", "some", "sample",
                              "data."])
        self.assertListEqual(right_tokens,
                             ["This", "pack", "contains", "some", "other",
                              "sample", "data."])

        # 2. Link the same words from two packs.
        token: Annotation
        left_tokens = {}
        for token in self.multi_pack.packs[0].get(Token):
            left_tokens[token.text] = token

        right_tokens = {}
        for token in self.multi_pack.packs[1].get(Token):
            right_tokens[token.text] = token

        for key, lt in left_tokens.items():
            if key in right_tokens:
                rt = right_tokens[key]
                self.multi_pack.add_entry(MultiPackLink(
                    self.multi_pack, lt, rt))

        # One way to link tokens.
        linked_tokens = []
        for link in self.multi_pack.links:
            parent_text = link.get_parent().text
            child_text = link.get_child().text
            linked_tokens.append((parent_text, child_text))

        self.assertListEqual(
            linked_tokens,
            [("This", "This"), ("pack", "pack"), ("contains", "contains"),
             ("some", "some"), ("sample", "sample"), ("data.", "data.")])

        # Another way to get the links
        linked_tokens = []
        for link in self.multi_pack.get(MultiPackLink):
            parent_text = link.get_parent().text
            child_text = link.get_child().text
            linked_tokens.append((parent_text, child_text))

        self.assertListEqual(
            linked_tokens,
            [("This", "This"), ("pack", "pack"), ("contains", "contains"),
             ("some", "some"), ("sample", "sample"), ("data.", "data.")])

        # 3. Test deletion

        # Delete the second link.
        self.multi_pack.delete_entry(self.multi_pack.links[1])

        # The ("pack", "pack") link was at index 1 and should be gone.
        linked_tokens = []
        for link in self.multi_pack.links:
            parent_text = link.get_parent().text
            child_text = link.get_child().text
            linked_tokens.append((parent_text, child_text))

        self.assertListEqual(
            linked_tokens,
            [("This", "This"), ("contains", "contains"),
             ("some", "some"), ("sample", "sample"), ("data.", "data.")])