def test_modify_index(self, index, old_spans, new_spans, is_begin,
                      is_inclusive, aligned_index):
    r"""Check ``modify_index`` against an expected aligned index.

    ``old_spans`` and ``new_spans`` arrive as raw pairs; only the first two
    items of each pair are used, as ``Span(pair[0], pair[1])``.
    """
    def as_spans(raw_pairs):
        return [Span(pair[0], pair[1]) for pair in raw_pairs]

    result = modify_index(index, as_spans(old_spans), as_spans(new_spans),
                          is_begin, is_inclusive)
    self.assertEqual(aligned_index, result)
def predict(self, data_batch: Dict) -> Dict[str, List[Prediction]]:
    r"""Decode semantic roles for a batch and map them to character spans.

    Args:
        data_batch: reads ``data_batch["Token"]["text"]`` (one array-like of
            tokens per sentence; ``.tolist()`` is called on each) and
            ``data_batch["Token"]["span"][idx]`` (per-token ``(begin, end)``
            pairs for sentence ``idx``).

    Returns:
        ``{"predictions": [...]}`` with one entry per sentence. Each entry is
        a list of ``(predicate_span, [(argument_span, label), ...])`` tuples
        whose spans are built from the token ``(begin, end)`` pairs.
    """
    text: List[List[str]] = [
        sentence.tolist() for sentence in data_batch["Token"]["text"]]
    text_ids, length = tx.data.padded_batch([
        self.word_vocab.map_tokens_to_ids_py(sentence) for sentence in text])
    text_ids = torch.from_numpy(text_ids).to(device=self.device)
    length = torch.tensor(length, dtype=torch.long, device=self.device)
    batch_size = len(text)
    # One independent empty list per sentence: ``[[]] * batch_size`` would
    # make every row share (and mutate) the same list object.
    batch = tx.data.Batch(batch_size,
                          text=text,
                          text_ids=text_ids,
                          length=length,
                          srl=[[] for _ in range(batch_size)])
    self.model = self.model.to(self.device)
    batch_srl_spans = self.model.decode(batch)

    # Convert predictions into annotations.
    batch_predictions: List[Prediction] = []
    for idx, srl_spans in enumerate(batch_srl_spans):
        word_spans = data_batch["Token"]["span"][idx]
        predictions: Prediction = []
        # ``srl_spans`` maps a predicate token index to its arguments.
        for pred_idx, pred_args in srl_spans.items():
            begin, end = word_spans[pred_idx]
            # TODO cannot create annotation here.
            pred_span = Span(begin, end)
            # An argument covers tokens ``arg.start`` .. ``arg.end``; take the
            # begin of the first token and the end of the last one.
            arguments = [
                (Span(word_spans[arg.start][0], word_spans[arg.end][1]),
                 arg.label)
                for arg in pred_args
            ]
            predictions.append((pred_span, arguments))
        batch_predictions.append(predictions)
    return {"predictions": batch_predictions}
def parse_allennlp_srl_tags(tags: str) -> \
        Tuple[Optional[Span], List[Tuple[Span, str]]]:
    r"""Parse the tag list of a specific verb output by AllenNLP SRL processor.

    Args:
        tags (str): a whitespace-separated str of semantic role labels, each
            a one-letter prefix (``B``/``I``/``O``) optionally followed by
            ``-`` and a role label, e.g. ``"B-ARG0 I-ARG0 B-V O"``.

    Returns:
        ``(pred_span, arguments)``. ``pred_span`` is the span of the chunk
        labelled ``V`` (``None`` if there is none). ``arguments`` lists
        ``(span, label)`` for every other labelled chunk, in order. Span
        boundaries are token positions, and the end position is inclusive.
    """
    verb_span = None
    role_spans = []
    chunk_start, chunk_stop, chunk_label = -1, -1, ''
    # The trailing 'O' sentinel forces the last open chunk to be flushed.
    for position, tag in enumerate((tags + ' O').split()):
        # Everything after the first '-' ('' when there is none).
        label = tag.partition('-')[2]
        prefix = tag[0]
        opens_chunk = (prefix in ('O', 'B')
                       or (prefix == 'I' and label != chunk_label))
        if not opens_chunk:
            chunk_stop = position
            continue
        # Close the chunk that was open before this tag.
        if chunk_label == 'V':
            verb_span = Span(chunk_start, chunk_stop)
        elif chunk_label:
            role_spans.append((Span(chunk_start, chunk_stop), chunk_label))
        chunk_start = chunk_stop = position
        chunk_label = label
    return verb_span, role_spans
def test_reader_original_span_test(self, value):
    r"""Check that ``get_original_span`` maps ``input_span`` of the modified
    text back to ``expected_span`` of the original file under ``mode``."""
    replacements = [
        (Span(11, 19), "New"),
        (Span(19, 20), " Shiny "),
        (Span(25, 25), " Ends"),
    ]
    expected_text = "<title>The New Shiny Title Ends </title>"
    input_span, expected_span, mode = value

    pipeline = Pipeline()
    reader = PlainTextReader()
    reader.text_replace_operation = lambda _: replacements
    pipeline.set_reader(reader, {"file_ext": ".html"})
    pipeline.initialize()
    pack = pipeline.process_one(self.test_dir)
    self.assertEqual(pack.text, expected_text)

    output_span = pack.get_original_span(input_span, mode)
    failure_msg = (
        f"Expected: ({expected_span.begin, expected_span.end}"
        f"), Found: ({output_span.begin, output_span.end})"
        f" when Input: ({input_span.begin, input_span.end})"
        f" and Mode: {mode}"
    )
    self.assertEqual(output_span, expected_span, failure_msg)
def set_span(self, begin: int, end: int):
    r"""Set the span of the annotation.

    Args:
        begin: start offset of the span.
        end: end offset of the span; must not be smaller than ``begin``.

    Raises:
        ValueError: if ``begin`` is greater than ``end``.
    """
    inverted = begin > end
    if inverted:
        raise ValueError(
            f"The begin {begin} of span is greater than the end {end}")
    self._span = Span(begin, end)
def _overlap_with_existing(self, pid: int, begin: int, end: int) -> bool:
    r"""
    Check whether ``[begin, end)`` overlaps any span already registered for
    the data pack ``pid``.

    ``self._replaced_annos[pid]`` is scanned from just before the bisect
    position of ``(Span(begin, begin), "")``. Two spans overlap unless one
    ends at or before the other begins.

    Args:
        pid: Datapack Id.
        begin: The span begin index.
        end: The span end index.

    Returns:
        True if the input span overlaps with any existing spans,
        False otherwise.
    """
    existing = self._replaced_annos[pid]
    if not existing:
        return False

    # Step back one slot: the span just before the insertion point may still
    # extend past ``begin``.
    start = max(bisect_left(existing, (Span(begin, begin), "")) - 1, 0)
    position = start
    while position < len(existing):
        span: Span = existing[position][0]
        disjoint = span.begin >= end or span.end <= begin
        if not disjoint:
            return True
        if span.begin > end:
            # Later spans begin even further right; nothing else can overlap.
            break
        position += 1
    return False
def _insert(self, inserted_text: str, data_pack: DataPack,
            pos: int) -> bool:
    r"""
    Register an insertion of ``inserted_text`` at index ``pos`` of
    ``data_pack`` for the later batch build of the new data pack.

    At most one insertion is allowed per position. An insertion that
    overlaps an already registered span is also refused.

    Args:
        inserted_text: The text string to insert.
        data_pack: The datapack for insertion.
        pos: The position(index) of insertion.

    Returns:
        A bool value. True if the insertion happened, False otherwise.
    """
    pid: int = data_pack.pack_id
    # ``or`` short-circuits, so the position table is only consulted when
    # there is no overlap, matching the original check order.
    refused = (self._overlap_with_existing(pid, pos, pos)
               or pos in self._inserted_annos_pos_len[pid])
    if refused:
        return False
    # An insertion is stored as a replacement of the empty span (pos, pos).
    self._replaced_annos[pid].add((Span(pos, pos), inserted_text))
    self._inserted_annos_pos_len[pid][pos] = len(inserted_text)
    return True
def test_reader_original_span_test(self, value):
    r"""Check ``get_original_span`` on a pack read directly via
    ``parse_pack`` with three fixed replacement operations."""
    replacements = [
        (Span(11, 19), 'New'),
        (Span(19, 20), ' Shiny '),
        (Span(25, 25), ' Ends'),
    ]
    expected_text = '<title>The New Shiny Title Ends </title>'
    input_span, expected_span, mode = value

    reader = PlainTextReader()
    reader.text_replace_operation = lambda _: replacements
    packs = list(reader.parse_pack(self.file_path))
    pack = packs[0]
    self.assertEqual(pack.text, expected_text)

    output_span = pack.get_original_span(input_span, mode)
    failure_msg = (f"Expected: ({expected_span.begin, expected_span.end}"
                   f"), Found: ({output_span.begin, output_span.end})"
                   f" when Input: ({input_span.begin, input_span.end})"
                   f" and Mode: {mode}")
    self.assertEqual(output_span, expected_span, failure_msg)
def modify_text_and_track_ops(original_text: str,
                              replace_operations: ReplaceOperationsType) -> \
        Tuple[str, ReplaceOperationsType, List[Tuple[Span, Span]], int]:
    r"""Modifies the original text using ``replace_operations`` provided by
    the user to return modified text and other data required for tracking
    original text.

    Args:
        original_text: Text to be modified.
        replace_operations: A list of spans and the corresponding replacement
            string that the span in the original string is to be replaced
            with to obtain the modified string. The operations are applied
            in span order; operations with equal spans keep their given
            order. The caller's list is not reordered.

    Returns:
        modified_text: Text after modification.
        replace_back_operations: A list of spans and the corresponding
            replacement string that the span in the modified string is to be
            replaced with to obtain the original string.
        processed_original_spans: Sorted list of processed span and its
            corresponding original span.
        orig_text_len: length of original text.

    Raises:
        ValueError: if a span has a negative index, lies outside
            ``original_text``, ends before it begins, or overlaps the
            previous span once the operations are sorted.
    """
    orig_text_len: int = len(original_text)
    increment: int = 0
    prev_span_end: int = 0
    text_pieces: List[str] = []
    replace_back_operations: List[Tuple[Span, str]] = []
    processed_original_spans: List[Tuple[Span, Span]] = []

    # Sort the spans so that the order of replacement strings is maintained:
    # ``sorted`` is stable. A sorted copy is used so that the caller's list
    # is left untouched.
    ordered_operations = sorted(replace_operations, key=lambda item: item[0])

    for span, replacement in ordered_operations:
        if span.begin < 0 or span.end < 0:
            raise ValueError("Negative indexing not supported")
        if span.begin > orig_text_len or span.end > orig_text_len:
            raise ValueError(
                "One of the span indices are outside the string length")
        if span.end < span.begin:
            raise ValueError(
                "One of the end indices is lesser than start index")
        if span.begin < prev_span_end:
            raise ValueError(
                "The replacement spans should be mutually exclusive")

        # Untouched text between the previous span and this one, then the
        # replacement. Spans are non-overlapping and sorted, so slicing the
        # original text directly is equivalent to slicing the text modified
        # so far.
        text_pieces.append(original_text[prev_span_end:span.begin])
        text_pieces.append(replacement)

        # ``increment`` is the net length change from all earlier operations.
        span_begin = span.begin + increment
        replacement_span = Span(span_begin, span_begin + len(replacement))
        replace_back_operations.append(
            (replacement_span, original_text[span.begin:span.end]))
        processed_original_spans.append((replacement_span, span))

        increment += len(replacement) - (span.end - span.begin)
        prev_span_end = span.end

    # Text after the last replaced span (or the whole text if none).
    text_pieces.append(original_text[prev_span_end:])
    mod_text: str = "".join(text_pieces)

    return (mod_text, replace_back_operations,
            sorted(processed_original_spans), orig_text_len)
def __getstate__(self):
    r"""For serializing Annotation, fold ``_begin``/``_end`` into a ``Span``
    for compatibility purposes.

    Side effect: ``self._span`` is (re)assigned from ``self._begin`` and
    ``self._end`` and remains set after this call. The ``_begin`` and
    ``_end`` keys are then removed from the state returned by the parent's
    ``__getstate__``.
    NOTE(review): if the parent returns ``self.__dict__`` itself rather
    than a copy, the pops also delete the live attributes — confirm.
    """
    self._span = Span(self._begin, self._end)
    state = super().__getstate__()
    for field_name in ("_begin", "_end"):
        state.pop(field_name)
    return state
def parse_allennlp_srl_tags(
    tags: str,
) -> Tuple[Optional[Span], List[Tuple[Span, str]]]:
    r"""Parse the tag list of a specific verb output by AllenNLP SRL processor.

    Args:
        tags (str): a whitespace-separated str of semantic role labels, each
            a prefix letter (``B``/``I``/``O``) optionally followed by ``-``
            and a role label.

    Returns:
        the span of the verb (the chunk labelled ``V``, or ``None``) and a
        list of ``(span, label)`` for its other semantic role arguments.
        Spans use token positions with an inclusive end.
    """
    predicate = None
    found_args = []
    left, right, active = -1, -1, ""
    padded = tags + " O"  # sentinel that closes the last open chunk
    for idx, tag in enumerate(padded.split()):
        role = "-".join(tag.split("-")[1:])
        # An "I" tag with the same role, or any non-B/I/O prefix, extends
        # the current chunk; everything else starts a new one.
        extends_chunk = tag[0] not in "OB" and (
            tag[0] != "I" or role == active
        )
        if extends_chunk:
            right = idx
        else:
            if active == "V":
                predicate = Span(left, right)
            elif active != "":
                found_args.append((Span(left, right), active))
            left = right = idx
            active = role
    return predicate, found_args
def set_span(self, begin: int, end: int):
    r"""Set the span of the annotation.

    Args:
        begin: start offset; must be a non-negative ``int``.
        end: end offset; must be an ``int`` not smaller than ``begin``.

    Raises:
        ValueError: if either bound is not an ``int``, if ``begin`` is
            greater than ``end``, or if ``begin`` is negative (checked in
            that order).
    """
    if not (isinstance(begin, int) and isinstance(end, int)):
        raise ValueError(
            f"Begin and End for an annotation must be integer, "
            f"got {begin}:{type(begin)} and {end}:{type(end)}")

    if begin > end:
        message = f"The begin {begin} of span is greater than the end {end}"
    elif begin < 0:
        message = 'The begin cannot be negative.'
    else:
        self._span = Span(begin, end)
        return
    raise ValueError(message)
def test_span(self):
    r"""Spans with identical bounds are equal. Spans order by begin first,
    then by end."""
    self.assertEqual(Span(1, 2), Span(1, 2))

    ordered_pairs = [
        (Span(1, 2), Span(1, 3)),  # same begin, smaller end first
        (Span(1, 2), Span(2, 3)),  # smaller begin first
    ]
    for smaller, larger in ordered_pairs:
        self.assertLess(smaller, larger)
def set_span(self, begin: int, end: int):
    r"""Set the span of the annotation to ``Span(begin, end)``.

    This method itself performs no bounds checks on ``begin``/``end``.
    """
    new_span = Span(begin, end)
    self._span = new_span
def __init__(self, pack: PackType, begin: int, end: int):
    r"""Attach the annotation to ``pack`` and record its span.

    Raises:
        ValueError: if ``begin`` is greater than ``end``. The parent
            initializer has already run at that point.
    """
    super().__init__(pack)
    span_is_inverted = begin > end
    if span_is_inverted:
        raise ValueError(
            f"The begin {begin} of span is greater than the end {end}")
    self._span = Span(begin, end)
class PlainTextReaderTest(unittest.TestCase):
    r"""Tests for ``PlainTextReader.parse_pack`` on a small HTML file, with
    and without user-supplied text replacement operations."""

    def setUp(self):
        # Create a temporary directory
        self.test_dir = tempfile.mkdtemp()
        self.orig_text = "<title>The Original Title </title>"
        self.file_path = os.path.join(self.test_dir, 'test.html')
        self.mod_file_path = os.path.join(self.test_dir, 'mod_test.html')
        with open(self.file_path, 'w') as f:
            f.write(self.orig_text)

    def tearDown(self):
        # Remove the directory after the test
        shutil.rmtree(self.test_dir)

    def test_reader_no_replace_test(self):
        # Read with no replacements
        reader = PlainTextReader()
        PackManager().set_input_source(reader.component_name)
        pack = list(reader.parse_pack(self.file_path))[0]
        self.assertEqual(pack.text, self.orig_text)

    @data(
        # No replacement
        ([], '<title>The Original Title </title>'),
        # Insertion
        ([(Span(11, 11), 'New ')], '<title>The New Original Title </title>'),
        # Single, sorted multiple and unsorted multiple replacements
        ([(Span(11, 19), 'New')], '<title>The New Title </title>'),
        ([(Span(0, 7), ''), (Span(26, 34), '')], 'The Original Title '),
        ([(Span(26, 34), ''), (Span(0, 7), '')], 'The Original Title '),
    )
    def test_reader_replace_back_test(self, value):
        # Reading with replacements - replacing a span and changing it back
        span_ops, output = value
        reader = PlainTextReader()
        reader.text_replace_operation = lambda _: span_ops
        pack = list(reader.parse_pack(self.file_path))[0]
        self.assertEqual(pack.text, output)

        orig_text_from_pack = pack.get_original_text()
        self.assertEqual(self.orig_text, orig_text_from_pack)

    @data(
        # before span starts
        (Span(1, 6), Span(1, 6), "relaxed"),
        (Span(1, 6), Span(1, 6), "strict"),
        # after span ends
        (Span(15, 22), Span(19, 21), "relaxed"),
        # span itself
        (Span(11, 14), Span(11, 19), "relaxed"),
        # complete string
        (Span(0, 40), Span(0, 34), "strict"),
        # cases ending to or starting from between the span
        (Span(11, 40), Span(11, 34), "relaxed"),
        (Span(13, 40), Span(11, 34), "relaxed"),
        (Span(14, 40), Span(19, 34), "relaxed"),
        (Span(13, 40), Span(11, 34), "backward"),
        (Span(13, 40), Span(19, 34), "forward"),
        (Span(0, 12), Span(0, 19), "relaxed"),
        (Span(0, 13), Span(0, 11), "backward"),
        (Span(0, 14), Span(0, 19), "forward"),
        # same begin and end
        (Span(38, 38), Span(32, 32), "relaxed"),
        (Span(38, 38), Span(32, 32), "strict"),
        (Span(38, 38), Span(32, 32), "backward"),
        (Span(38, 38), Span(32, 32), "forward"))
    def test_reader_original_span_test(self, value):
        span_ops, output = ([(Span(11, 19), 'New'),
                             (Span(19, 20), ' Shiny '),
                             (Span(25, 25), ' Ends')],
                            '<title>The New Shiny Title Ends </title>')
        input_span, expected_span, mode = value
        reader = PlainTextReader()
        PackManager().set_input_source(reader.component_name)
        reader.text_replace_operation = lambda _: span_ops
        pack = list(reader.parse_pack(self.file_path))[0]
        self.assertEqual(pack.text, output)

        output_span = pack.get_original_span(input_span, mode)
        self.assertEqual(
            output_span, expected_span,
            f"Expected: ({expected_span.begin, expected_span.end}"
            f"), Found: ({output_span.begin, output_span.end})"
            f" when Input: ({input_span.begin, input_span.end})"
            f" and Mode: {mode}")

    @data(
        ([(Span(5, 8), ''), (Span(6, 10), '')], None),  # overlap
        ([(Span(5, 8), ''), (Span(6, 1000), '')], None),  # outside limit
        ([(Span(-1, 8), '')], None),  # does not support negative indexing
        ([(Span(8, -1), '')], None),  # does not support negative indexing
        ([(Span(2, 1), '')], None)  # start should be lesser than end
    )
    def test_reader_replace_error_test(self, value):
        # Read with errors in span replacements; any other exception type
        # propagates and is reported as a test error.
        span_ops, _ = value
        reader = PlainTextReader()
        reader.text_replace_operation = lambda _: span_ops
        with self.assertRaises(ValueError):
            list(reader.parse_pack(self.file_path))[0]
def get_original_span(self, input_processed_span: Span,
                      align_mode: str = "relaxed"):
    r"""Function to obtain span of the original text that aligns with the
    given span of the processed text.

    Args:
        input_processed_span: Span of the processed text for which the
            corresponding span of the original text is desired.
        align_mode: The strictness criteria for alignment in the ambiguous
            cases, that is, if a part of input_processed_span spans a part
            of the inserted span, then align_mode controls whether to use
            the span fully or ignore it completely according to the
            following possible values

            - "strict" - do not allow ambiguous input, give ValueError
            - "relaxed" - consider spans on both sides
            - "forward" - align looking forward, that is, ignore the span
              towards the left, but consider the span towards the right
            - "backward" - align looking backwards, that is, ignore the span
              towards the right, but consider the span towards the left

    Returns:
        Span of the original text that aligns with input_processed_span.
        The end is computed by mapping the last covered character
        (``input_processed_span.end - 1``) and adding 1.

    Raises:
        AssertionError: if ``align_mode`` is not one of the four values
            above. This is checked with ``assert``, so the check is skipped
            under ``python -O``.
        ValueError: when at least one processed span exists and an endpoint
            either falls inside a processed span in "strict" mode or lies
            outside the processed text.

    Example:
        * Let o-up1, o-up2, ... and m-up1, m-up2, ... denote the unprocessed
          spans of the original and modified string respectively. Note that
          each o-up would have a corresponding m-up of the same size.
        * Let o-pr1, o-pr2, ... and m-pr1, m-pr2, ... denote the processed
          spans of the original and modified string respectively. Note that
          each o-pr is modified to a corresponding m-pr that may be of a
          different size than o-pr.
        * Original string:
          <--o-up1--> <-o-pr1-> <----o-up2----> <----o-pr2----> <-o-up3->
        * Modified string:
          <--m-up1--> <----m-pr1----> <----m-up2----> <-m-pr2-> <-m-up3->
        * Note that `self.__processed_original_spans`, which pairs each
          modified processed span with its corresponding original span,
          would look like - [(m-pr1, o-pr1), (m-pr2, o-pr2)]

        >> data_pack = DataPack()
        >> original_text = "He plays in the park"
        >> data_pack.set_text(original_text,\
        >>                    lambda _: [(Span(0, 2), "She")])
        >> data_pack.text
        "She plays in the park"
        >> input_processed_span = Span(0, len("She plays"))
        >> orig_span = data_pack.get_original_span(input_processed_span)
        >> data_pack.get_original_text()[orig_span.begin: orig_span.end]
        "He plays"
    """
    assert align_mode in ["relaxed", "strict", "backward", "forward"]

    req_begin = input_processed_span.begin
    req_end = input_processed_span.end

    def get_original_index(input_index: int, is_begin_index: bool,
                           mode: str) -> int:
        r"""
        Map a single index of the processed text to the original text.

        Args:
            input_index: begin or end index of the input span
            is_begin_index: if the index is the begin index of the input
                span or the end index of the input span
            mode: alignment mode

        Returns:
            Original index that aligns with input_index
        """
        # No replacements were made: processed and original indices match.
        if len(self.__processed_original_spans) == 0:
            return input_index

        len_processed_text = len(self._text)
        orig_index = None
        prev_end = 0
        # Each pair is (span in processed text, span in original text).
        for (inverse_span, original_span) in self.__processed_original_spans:
            # check if the input_index lies between one of the unprocessed
            # spans
            if prev_end <= input_index < inverse_span.begin:
                # Unprocessed text keeps its length, so shift by the offset
                # between the two spans' begins.
                increment = original_span.begin - inverse_span.begin
                orig_index = input_index + increment
            # check if the input_index lies between one of the processed
            # spans
            elif inverse_span.begin <= input_index < inverse_span.end:
                # look backward - backward shift of input_index
                if is_begin_index and mode in ["backward", "relaxed"]:
                    orig_index = original_span.begin
                if not is_begin_index and mode == "backward":
                    orig_index = original_span.begin - 1

                # look forward - forward shift of input_index
                if is_begin_index and mode == "forward":
                    orig_index = original_span.end
                if not is_begin_index and mode in ["forward", "relaxed"]:
                    orig_index = original_span.end - 1
                # In "strict" mode none of the branches above fire, so
                # orig_index stays None and the ValueError below is raised.

            # break if the original index is populated
            if orig_index is not None:
                break
            prev_end = inverse_span.end

        if orig_index is None:
            # check if the input_index lies between the last unprocessed
            # span
            inverse_span, original_span = self.__processed_original_spans[
                -1]
            if inverse_span.end <= input_index < len_processed_text:
                increment = original_span.end - inverse_span.end
                orig_index = input_index + increment
            else:
                # the input_index is either not valid given the alignment
                # mode or lies outside the processed string
                raise ValueError(f"The input span either does not adhere "
                                 f"to the {align_mode} alignment mode or "
                                 f"lies outside to the processed string.")
        return orig_index

    orig_begin = get_original_index(req_begin, True, align_mode)
    # Map the last covered character, then convert back to an exclusive end.
    orig_end = get_original_index(req_end - 1, False, align_mode) + 1

    return Span(orig_begin, orig_end)
def collect_span(self, begin, end):
    r"""Append ``(Span(begin, end), '')`` to ``self.spans``.

    NOTE(review): the empty string presumably makes this a replace
    operation that deletes the span — confirm against the consumer of
    ``self.spans``.
    """
    deletion = (Span(begin, end), '')
    self.spans.append(deletion)
class HTMLReaderPipelineTest(unittest.TestCase):
    r"""Tests for ``HTMLReader`` inside a pipeline.

    ``pl1`` reads raw HTML and appends results to a cache directory.
    ``pl2`` reads from that same cache directory.
    """

    def setUp(self):
        self._cache_directory = Path.cwd() / "cache_html"
        self.reader = HTMLReader(cache_directory=self._cache_directory,
                                 append_to_cache=True)

        self.pl1 = Pipeline[DataPack]()
        self.pl1.set_reader(self.reader)
        self.pl1.initialize()

        self.pl2 = Pipeline[DataPack]()
        self.pl2.set_reader(
            HTMLReader(from_cache=True,
                       cache_directory=self._cache_directory))
        self.pl2.initialize()

    def tearDown(self):
        shutil.rmtree(self._cache_directory)

    @data(
        ("<title>The Original Title </title>", "The Original Title "),
        (
            "<!DOCTYPE html><html><title>Page Title</title><body><p>This is a "
            "paragraph</p></body></html>",
            "Page TitleThis is a paragraph",
        ),
        (
            """<!DOCTYPE html> <html> <head> <title>Page Title</title> </head> <body> <h1>This is a Heading</h1> <p>This is a paragraph.</p> </body> </html> """,
            """ \n \n Page Title\n \n \n This is a Heading This is a paragraph.\n \n \n """,
        ),
        (
            """<!DOCTYPE html> <h1 id="section1" class="bar">Section 1</h1> <p class="foo">foo bar\nbaz blah </p> <!-- cool beans! --> <hr/> <br> <p><em>The <strong>End!</strong></em></p> <p><em>error</p></em>weird < q <*****@*****.**> """,
            """ Section 1 foo bar\nbaz blah \n \n \n \n The End! errorweird < q \n """,
        ),
        (
            """<!DOCTYPE html PUBLIC "-//W3C//DTD HTML 4.01//EN"> <html<head> <title// <p ltr<span id=p>Text</span</p> </>""",
            """\n \n Text """,
        ),
    )
    def test_reader(self, value):
        # Also writes to cache so that we can read from cache directory
        # during caching test
        html_input, expected_output = value

        for pack in self.pl1.process_dataset(html_input):
            self.assertEqual(expected_output, pack.text)

    @data(
        ("<title>The Original Title </title>"),
        ("""<!DOCTYPE html PUBLIC "-//W3C//DTD HTML 4.01//EN"> <html<head> <title// <p ltr<span id=p>Text</span</p> </>"""),
        ("""<!DOCTYPE html> <h1 id="section1" class="bar">Section 1</h1> <p class="foo">foo bar\nbaz blah </p> <!-- cool beans! --> <hr/> <br> <p><em>The <strong>End!</strong></em></p> <p><em>error</p></em>weird < q <*****@*****.**> """),
    )
    def test_reader_replace_back(self, value):
        input_data = value

        for pack in self.pl1.process_dataset(input_data):
            original_text = pack.get_original_text()
            self.assertEqual(original_text, input_data)

    @data(
        (
            Span(0, 3),
            Span(7, 10),
            "<title>The Original Title </title>",
            "strict",
        ),
        (
            Span(18, 22),
            Span(101, 105),
            """<!DOCTYPE html PUBLIC "-//W34.01//EN"> <html<head> <title// <p ltr<span id=p>Text</span</p> </>""",
            "relaxed",
        ),
        # # cases ending to or starting from between the span
        (
            Span(15, 30),
            Span(60, 95),
            """<!DOCTYPE html> <h1 id="section1" class="bar">Section 1</h1> <p class="foo">foo bar\nbaz blah </p> <!-- cool beans! --> <hr/> <br> <p><em>The <strong>End!</strong></em></p> <p><em>error</p></em>weird < q <*****@*****.**>""",
            "forward",
        ),
        # before span starts
        (
            Span(0, 3),
            Span(0, 3),
            "Some text<title>The Original Title </title>",
            "relaxed",
        ),
        (
            Span(0, 3),
            Span(0, 3),
            "Some text<title>The Original Title </title>",
            "strict",
        ),
        # after span ends
        # There's an issue with this #TODO (assign) mansi
        # returns a span of (43, 35) which is wrong.
        # (Span(28, 28), Span(43, 43),
        # 'Some text<title>The Original Title </title>T',
        # "strict"),
        # same begin and end
        (
            Span(14, 14),
            Span(21, 21),
            "Some text<title>The Original Title </title>",
            "strict",
        ),
        (
            Span(14, 14),
            Span(21, 21),
            "Some text<title>The Original Title </title>",
            "relaxed",
        ),
        (
            Span(14, 14),
            Span(21, 21),
            "Some text<title>The Original Title </title>",
            "backward",
        ),
        (
            Span(14, 14),
            Span(21, 21),
            "Some text<title>The Original Title </title>",
            "forward",
        ),
    )
    def test_reader_original_span(self, value):
        new_span, expected_orig_span, html_input, mode = value

        for pack in self.pl1.process_dataset(html_input):
            # Retrieve original text
            original_text = pack.get_original_text()
            self.assertEqual(original_text, html_input)

            # Retrieve original span
            original_span = pack.get_original_span(new_span, mode)
            self.assertEqual(expected_orig_span, original_span)

    @data(
        [
            "<title>The Original Title </title>",
            "<!DOCTYPE html><html><title>Page Title</title><body><p>This is a "
            "paragraph</p></body></html>",
        ],
        ["<html>Test1</html>", "<html>Test12</html>", "<html>Test3</html>"],
    )
    def test_reader_caching(self, value):
        # First pass writes one cache file per input document.
        content = [pack.text for pack in self.pl1.process_dataset(value)]
        count_orig = len(content)

        num_files = len(os.listdir(self._cache_directory))
        self.assertEqual(num_files, count_orig)

        # Test Caching: the second pipeline must reproduce the same texts.
        content_cached = [
            pack.text for pack in self.pl2.process_dataset(value)]

        self.assertEqual(len(content_cached), count_orig)
        self.assertEqual(content_cached, content)

    def test_reader_with_dir(self):
        # The context manager removes the directory even if an assertion or
        # a download fails.
        with tempfile.TemporaryDirectory() as tmp_dir:
            maybe_download(
                "https://en.wikipedia.org/wiki/Machine_learning",
                tmp_dir,
                "test_wikipedia.html",
            )
            maybe_download("https://www.yahoo.com/", tmp_dir,
                           "test_yahoo.html")
            for pack in self.pl1.process_dataset(tmp_dir):
                self.assertIsInstance(pack, DataPack)

    def test_reader_with_filepath(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            filepath = maybe_download("https://www.yahoo.com/", tmp_dir,
                                      "test_yahoo.html")
            for pack in self.pl1.process_dataset(filepath):
                self.assertIsInstance(pack, DataPack)

    @data(
        [
            "<title>The Original Title </title>",
            "<!DOCTYPE html><html><title>Page Title</title><body><p>This is a "
            "paragraph</p></body></html>",
        ],
        ["<html>Test1</html>", "<html>Test12</html>", "<html>Test3</html>"],
    )
    def test_reader_with_list(self, value):
        # One pack is expected per input string.
        count_orig = sum(1 for _ in self.pl1.process_dataset(value))
        self.assertEqual(count_orig, len(value))
class PlainTextReaderTest(unittest.TestCase):
    r"""Pipeline-level tests for ``PlainTextReader`` on a temporary
    directory holding one small ``.html`` file."""

    def setUp(self):
        # Create a temporary directory
        self.test_dir = tempfile.mkdtemp()
        self.orig_text = "<title>The Original Title </title>"
        file_path = os.path.join(self.test_dir, "test.html")
        with open(file_path, "w") as f:
            f.write(self.orig_text)

    def tearDown(self):
        # Remove the directory after the test
        shutil.rmtree(self.test_dir)

    def _build_pipeline(self, span_ops=None):
        r"""Build and initialize a pipeline reading ``.html`` files.

        When ``span_ops`` is given (including an empty list), the reader's
        ``text_replace_operation`` returns it for every input.
        """
        pipeline = Pipeline()
        reader = PlainTextReader()
        if span_ops is not None:
            reader.text_replace_operation = lambda _: span_ops
        pipeline.set_reader(reader, {"file_ext": ".html"})
        pipeline.initialize()
        return pipeline

    def test_reader_no_replace_test(self):
        # Read with no replacements
        pipeline = self._build_pipeline()
        pack = pipeline.process_one(self.test_dir)
        self.assertEqual(pack.text, self.orig_text)

    @data(
        # No replacement
        ([], "<title>The Original Title </title>"),
        # Insertion
        ([(Span(11, 11), "New ")], "<title>The New Original Title </title>"),
        # Single, sorted multiple and unsorted multiple replacements
        ([(Span(11, 19), "New")], "<title>The New Title </title>"),
        ([(Span(0, 7), ""), (Span(26, 34), "")], "The Original Title "),
        ([(Span(26, 34), ""), (Span(0, 7), "")], "The Original Title "),
    )
    def test_reader_replace_back_test(self, value):
        # Reading with replacements - replacing a span and changing it back
        span_ops, output = value
        pipeline = self._build_pipeline(span_ops)
        pack: DataPack = pipeline.process_one(self.test_dir)
        self.assertEqual(pack.text, output)

        orig_text_from_pack = pack.get_original_text()
        self.assertEqual(self.orig_text, orig_text_from_pack)

    @data(
        # before span starts
        (Span(1, 6), Span(1, 6), "relaxed"),
        (Span(1, 6), Span(1, 6), "strict"),
        # after span ends
        (Span(15, 22), Span(19, 21), "relaxed"),
        # span itself
        (Span(11, 14), Span(11, 19), "relaxed"),
        # complete string
        (Span(0, 40), Span(0, 34), "strict"),
        # cases ending to or starting from between the span
        (Span(11, 40), Span(11, 34), "relaxed"),
        (Span(13, 40), Span(11, 34), "relaxed"),
        (Span(14, 40), Span(19, 34), "relaxed"),
        (Span(13, 40), Span(11, 34), "backward"),
        (Span(13, 40), Span(19, 34), "forward"),
        (Span(0, 12), Span(0, 19), "relaxed"),
        (Span(0, 13), Span(0, 11), "backward"),
        (Span(0, 14), Span(0, 19), "forward"),
        # same begin and end
        (Span(38, 38), Span(32, 32), "relaxed"),
        (Span(38, 38), Span(32, 32), "strict"),
        (Span(38, 38), Span(32, 32), "backward"),
        (Span(38, 38), Span(32, 32), "forward"),
    )
    def test_reader_original_span_test(self, value):
        span_ops, output = (
            [
                (Span(11, 19), "New"),
                (Span(19, 20), " Shiny "),
                (Span(25, 25), " Ends"),
            ],
            "<title>The New Shiny Title Ends </title>",
        )
        input_span, expected_span, mode = value

        pipeline = self._build_pipeline(span_ops)
        pack = pipeline.process_one(self.test_dir)
        self.assertEqual(pack.text, output)

        output_span = pack.get_original_span(input_span, mode)
        self.assertEqual(
            output_span,
            expected_span,
            f"Expected: ({expected_span.begin, expected_span.end}"
            f"), Found: ({output_span.begin, output_span.end})"
            f" when Input: ({input_span.begin, input_span.end})"
            f" and Mode: {mode}",
        )

    @data(
        ([(Span(5, 8), ""), (Span(6, 10), "")], None),  # overlap
        ([(Span(5, 8), ""), (Span(6, 1000), "")], None),  # outside limit
    )
    def test_reader_replace_error_test(self, value):
        # Read with errors in span replacements
        span_ops, _ = value
        pipeline = self._build_pipeline(span_ops)

        with self.assertRaises(ValueError):
            pipeline.process(self.test_dir)
def _auto_align_annotations(
        self,
        data_pack: DataPack,
        replaced_annotations: SortedList,
) -> DataPack:
    r"""
    Function to replace some annotations with new strings.
    It will copy and update the text of datapack and
    auto-align the annotation spans.

    The links are also copied if its parent & child are
    both present in the new pack.

    The groups are copied if all its members are present
    in the new pack.

    Args:
        data_pack: The Datapack holding the replaced annotations.
        replaced_annotations: A SortedList of tuples(span, new string).
            The text and span of the annotations will be updated
            with the new string.

    Returns:
        A new data pack holding the text after replacement. The annotations
        in the original data pack will be copied and auto-aligned as
        instructed by the "other_entry_policy" in the configuration.
        The links and groups will be copied if their members are copied.
        When there is nothing to replace, a deep copy of ``data_pack`` is
        returned and the pack/entry maps are not updated.
    """
    if len(replaced_annotations) == 0:
        return deepcopy(data_pack)

    spans: List[Span] = [span for span, _ in replaced_annotations]
    replacement_strs: List[str] = [
        replacement_str for _, replacement_str in replaced_annotations
    ]

    # Get the new text for the new data pack: for each replaced span, keep
    # the untouched gap before it and append its replacement. Finally, add
    # the text after the last span.
    old_text: str = data_pack.text
    text_pieces: List[str] = []
    last_span_end: int = 0
    for span, new_span_str in zip(spans, replacement_strs):
        text_pieces.append(old_text[last_span_end:span.begin])
        text_pieces.append(new_span_str)
        last_span_end = span.end
    text_pieces.append(old_text[last_span_end:])
    new_text: str = "".join(text_pieces)

    # Get the span (begin, end) of each replaced region after replacement.
    # ``bias`` is the cumulative delta between old and new positions
    # accumulated from all previous replacements.
    new_spans: List[Span] = []
    bias: int = 0
    for span, replacement_str in zip(spans, replacement_strs):
        new_begin: int = span.begin + bias
        new_end: int = new_begin + len(replacement_str)
        new_spans.append(Span(new_begin, new_end))
        bias = new_end - span.end

    new_pack: DataPack = DataPack()
    new_pack.set_text(new_text)

    augment_entry: str = self.configs['augment_entry']
    entries_to_copy: List[str] = \
        list(self._other_entry_policy.keys()) + [augment_entry]
    entry_map: Dict[int, int] = {}
    insert_ind: int = 0
    pid: int = data_pack.pack_id

    # (position, length) of every insertion registered for this pack.
    inserted_annos: List[Tuple[int, int]] = list(
        self._inserted_annos_pos_len[pid].items())

    def _insert_new_span(entry_name: str,
                         insert_ind: int,
                         inserted_annos: List[Tuple[int, int]],
                         new_pack: DataPack,
                         spans: List[Span],
                         new_spans: List[Span]):
        r"""
        An internal helper function for insertion: add an ``entry_name``
        annotation covering the ``insert_ind``-th inserted text.
        """
        pos: int
        length: int
        pos, length = inserted_annos[insert_ind]
        insert_end: int = modify_index(
            pos,
            spans,
            new_spans,
            is_begin=False,
            # Include the inserted span itself.
            is_inclusive=True)
        insert_begin: int = insert_end - length
        new_anno = create_class_with_kwargs(entry_name, {
            "pack": new_pack,
            "begin": insert_begin,
            "end": insert_end
        })
        new_pack.add_entry(new_anno)

    # Iterate over all the original entries and modify their spans.
    for entry in entries_to_copy:
        is_augment_entry: bool = entry == augment_entry
        for orig_anno in data_pack.get(get_class(entry)):
            # Dealing with insertion/deletion only for augment_entry.
            if is_augment_entry:
                while insert_ind < len(inserted_annos) and \
                        inserted_annos[insert_ind][0] <= orig_anno.begin:
                    # Preserve the order of the spans with merging sort.
                    # It is a 2-way merging from the inserted spans
                    # and original spans based on the begin index.
                    _insert_new_span(entry, insert_ind, inserted_annos,
                                     new_pack, spans, new_spans)
                    insert_ind += 1

                # Deletion
                if orig_anno.tid in self._deleted_annos_id[pid]:
                    continue

            # Auto align the spans.
            span_new_begin: int = orig_anno.begin
            span_new_end: int = orig_anno.end

            if is_augment_entry \
                    or self._other_entry_policy[entry] == 'auto_align':
                # Only inclusive when the entry is not augmented.
                # E.g.: A Sentence include the inserted Token on the edge.
                # E.g.: A Token shouldn't include a nearby inserted Token.
                is_inclusive = not is_augment_entry
                span_new_begin = modify_index(orig_anno.begin, spans,
                                              new_spans, True, is_inclusive)
                span_new_end = modify_index(orig_anno.end, spans,
                                            new_spans, False, is_inclusive)

            new_anno = create_class_with_kwargs(entry, {
                "pack": new_pack,
                "begin": span_new_begin,
                "end": span_new_end
            })
            new_pack.add_entry(new_anno)
            entry_map[orig_anno.tid] = new_anno.tid

        # Deal with spans after the last annotation in the original pack.
        if is_augment_entry:
            while insert_ind < len(inserted_annos):
                _insert_new_span(entry, insert_ind, inserted_annos,
                                 new_pack, spans, new_spans)
                insert_ind += 1

    # Iterate over and copy the links/groups in the datapack.
    for link in data_pack.get(Link):
        self._copy_link_or_group(link, entry_map, new_pack)
    for group in data_pack.get(Group):
        self._copy_link_or_group(group, entry_map, new_pack)

    self._data_pack_map[pid] = new_pack.pack_id
    self._entry_maps[pid] = entry_map
    return new_pack
def modify_index(
        index: int,
        # Both of the following spans should be SortedList.
        # Use List to avoid typing errors.
        old_spans: List[Span],
        new_spans: List[Span],
        is_begin: bool,
        is_inclusive: bool) -> int:
    r"""
    A helper function to map an index before replacement to the index after
    replacement.

    An index is the character offset in the data pack. The old_spans are the
    inputs of replacement, and the new_spans are the outputs. Each span has
    a start and an end index. The old_spans and new_spans are anchors for
    the mapping, because we depend on them to determine the position change
    of the index. Given an index, the function finds the nearest old span
    whose begin index is at or before the index, and calculates the
    difference between the position of that old span and its corresponding
    new span. The position change is then applied to the input index, and
    the updated index is returned. If there are no old spans at all, nothing
    was replaced and the index is returned unchanged.

    An inserted span might be included as a part of another span. For
    example, given a sentence "I love NLP.", if we insert a token "Yeah" at
    the beginning of the sentence (index=0), the Sentence should include the
    new Token, i.e., the Sentence will have a start index equal to 0. In
    this case, the parameter is_inclusive should be True. However, the Token
    "I" should not include the new token, so its start index will be larger
    than 0. In that case, the parameter is_inclusive should be False.

    The input index could be the start or end index of a span, i.e., the
    left or right boundary of the span. If there is an insertion in the
    span, we should treat the two boundaries in different ways. For example,
    we have a paragraph with two sentences "I love NLP! You love NLP too."
    If we append another "!" to the end of the first sentence, when
    modifying the end index of the first Sentence, it should be pushed right
    to include the extra exclamation. In this case, is_begin is False.
    However, if we prepend an "And" to the second sentence, when modifying
    the start index of the second Sentence, it should be pushed left to
    include the new Token. In this case, is_begin is True.

    Args:
        index (int): The index to map.
        old_spans (SortedList): The spans before replacement. It should be
            a sorted list in ascending order.
        new_spans (SortedList): The spans after replacement. It should be
            a sorted list in ascending order.
        is_begin (bool): True if the input index is the start index of a
            span.
        is_inclusive (bool): True if the span constructed by the aligned
            index should include inserted spans.

    Returns:
        The aligned index.

    If the old spans are [0, 1], [2, 3], [4, 6], the new spans are [0, 4],
    [5, 7], [8, 11], the input index is 3, and there are no insertions, the
    algorithm will first locate the last span with a begin index less than
    or equal to the target index ([2, 3]), and find the corresponding span
    in the new spans ([5, 7]). Then we calculate the delta index (7 - 3 = 4)
    and update our input index (3 + 4 = 7). The output then is 7.

    Note that when the input index lies inside an old span, instead of on
    its boundary, we compute the return index so that it keeps the same
    offset from the begin of the span it belongs to. In the above example,
    if we change the input index from 3 to 5, the output will become 9,
    because we locate the input index in the third span [4, 6] and use the
    same offset 5 - 4 = 1 to calculate the output 8 + 1 = 9.

    When insertion is considered, there will be spans with the same begin
    index, for example, [0, 1], [1, 1], [1, 2]. The span [1, 1] indicates an
    insertion at index 1, because the insertion can be considered as a
    replacement of an empty input span, with a length of 0. The output will
    be affected by whether to include the inserted span (is_inclusive), and
    whether the input index is a begin or end index of its span (is_begin).

    If the old spans are [0, 1], [1, 1], [1, 2], the new spans are [0, 2],
    [2, 4], [4, 5], and the input index is 1, the output will be 2 if both
    is_inclusive and is_begin are True, because the inserted [1, 1] should
    be included in the span. If is_inclusive=True but is_begin=False, the
    output will be 4 because the index is an end index of the span.
    """
    if not old_spans:
        # Nothing was replaced, so every index maps to itself. This guard
        # also prevents an IndexError on ``old_spans[-1]`` below.
        return index

    # Get the max index for binary search. It is meant to make
    # Span(index, max_index) sort after every old span that begins at
    # ``index`` (assumes the last span has the largest end, which holds
    # for sorted, non-overlapping spans).
    max_index: int = old_spans[-1].end + 1
    max_index = max(max_index, index)

    # This is the last span whose start index is less than or equal to
    # the input index. The position change of this span determines the
    # modification we will apply to the input index.
    last_span_ind: int = bisect_right(old_spans, Span(index, max_index)) - 1

    # If there is an inserted span, it will always be the first of
    # those spans with the same begin index. For example, given spans
    # [1, 1], [1, 2], the inserted span [1, 1] will be in front of the
    # replaced span [1, 2], because it has the smallest end index.
    if last_span_ind >= 0:
        if is_inclusive:
            if is_begin:
                # When inclusive, move the begin index
                # to the left to include the inserted span.
                if last_span_ind > 0 and \
                        old_spans[last_span_ind - 1].begin == index:
                    # Old spans: [0, 1], [1, 1], [1, 3]
                    # Target index: 1
                    # Change last_span_index from 2 to 1
                    # to include the [1, 1] span.
                    last_span_ind -= 1
                # Otherwise, e.g. old spans [0, 1], [1, 1], [2, 3] with
                # target index 1, last_span_index is already 1 and needs
                # no change.
        else:
            if not is_begin:
                # When exclusive, move the end index
                # to the left to exclude the inserted span.
                if last_span_ind > 0 and \
                        old_spans[last_span_ind - 1].begin == index:
                    # Old spans: [0, 1], [1, 1], [1, 3]
                    # Target index: 1
                    # Change last_span_index from 2 to 0
                    # to exclude the [1, 1] span.
                    last_span_ind -= 2
                elif old_spans[last_span_ind].begin == index and \
                        old_spans[last_span_ind].end == index:
                    # Old spans: [0, 1], [1, 1], [2, 3]
                    # Target index: 1
                    # Change last_span_index from 1 to 0
                    # to exclude the [1, 1] span.
                    last_span_ind -= 1

    if last_span_ind < 0:
        # There is no replacement before this index.
        return index

    # Find the nearest anchor point on the left of current index.
    # Start from the span's begin index.
    delta_index: int = new_spans[last_span_ind].begin - \
        old_spans[last_span_ind].begin

    # A begin index that sits exactly on an inserted (empty) span and must
    # include it is anchored on the new span's begin.
    if old_spans[last_span_ind].begin == old_spans[last_span_ind].end \
            and old_spans[last_span_ind].begin == index \
            and is_begin \
            and is_inclusive:
        return index + delta_index

    if old_spans[last_span_ind].end <= index:
        # Use the span's end index as anchor, if possible.
        delta_index = new_spans[last_span_ind].end - \
            old_spans[last_span_ind].end

    return index + delta_index
def span(self) -> Span:
    """Return the span of this entry, creating it lazily on first use.

    The ``Span`` is only built from ``self._begin`` and ``self._end`` the
    first time it is requested; it is then stored in ``self._span`` and
    the same object is returned on every later call.
    """
    cached = self._span
    if cached is None:
        # Delay span creation until it is actually needed.
        cached = Span(self._begin, self._end)
        self._span = cached
    return cached