def test_reader_original_span_test(self, value):
    """Check that a span of the processed text maps back to the original.

    The file at ``self.file_path`` is read with a fixed set of three
    replacement operations; the resulting pack text is checked first, then
    ``get_original_span`` is queried with the span and alignment mode taken
    from ``value``.

    Args:
        value: A ``(input_span, expected_span, mode)`` triple, where
            ``input_span`` refers to the processed text, ``expected_span``
            to the original text, and ``mode`` is the alignment mode passed
            to ``get_original_span``.
    """
    # Replacement scenario shared by every data point.
    span_ops = [
        (Span(11, 19), 'New'),
        (Span(19, 20), ' Shiny '),
        (Span(25, 25), ' Ends'),
    ]
    output = '<title>The New Shiny Title Ends </title>'

    input_span, expected_span, mode = value

    reader = PlainTextReader()
    reader.text_replace_operation = lambda _: span_ops

    packs = list(reader.parse_pack(self.file_path))
    pack = packs[0]
    self.assertEqual(pack.text, output)

    output_span = pack.get_original_span(input_span, mode)
    failure_message = (
        f"Expected: ({expected_span.begin, expected_span.end}"
        f"), Found: ({output_span.begin, output_span.end})"
        f" when Input: ({input_span.begin, input_span.end})"
        f" and Mode: {mode}"
    )
    self.assertEqual(output_span, expected_span, failure_message)
def set_span(self, begin: int, end: int):
    """Store ``Span(begin, end)`` as the span of this annotation.

    Args:
        begin: Begin index of the new span.
        end: End index of the new span; must not be smaller than ``begin``.

    Raises:
        ValueError: If ``begin`` is greater than ``end``. The current span
            is left untouched in that case.
    """
    if end < begin:
        message = f"The begin {begin} of span is greater than the end {end}"
        raise ValueError(message)

    self._span = Span(begin, end)
def modify_text_and_track_ops(original_text: str,
                              replace_operations: ReplaceOperationsType) -> \
        Tuple[str, ReplaceOperationsType, List[Tuple[Span, Span]], int]:
    r"""Apply replace operations to a text and record how to undo them.

    The operations are applied in ascending span order. Operations with
    equal spans keep the relative order in which the caller supplied them.
    The ``replace_operations`` argument itself is not modified.

    Args:
        original_text: Text to be modified.
        replace_operations: A list of ``(span, replacement)`` pairs. Each
            span refers to ``original_text`` and is replaced with the
            corresponding replacement string to obtain the modified text.
            Spans must lie inside the text and must not overlap; an empty
            span (``begin == end``) is an insertion.

    Returns:
        A 4-tuple ``(modified_text, replace_back_operations,
        processed_original_spans, orig_text_len)``:

        - modified_text: Text after all replacements.
        - replace_back_operations: A list of ``(span, text)`` pairs where
          each span refers to the modified text and ``text`` is the piece
          of the original text it has to be replaced with to obtain the
          original text again.
        - processed_original_spans: Sorted list of
          ``(span in modified text, span in original text)`` pairs, one per
          operation.
        - orig_text_len: Length of the original text.

    Raises:
        ValueError: If a span has a negative index, reaches outside the
            text, ends before it begins, or overlaps a preceding span.
    """
    orig_text_len: int = len(original_text)

    # Sort a copy so that the caller's list keeps its order. ``sorted`` is
    # stable, hence several replacements at the same span stay in the order
    # given by the caller.
    ordered_operations = sorted(replace_operations, key=lambda item: item[0])

    # Pieces of the modified text, joined once at the end instead of
    # re-slicing the whole string for every operation.
    pieces: List[str] = []
    # Offset between an index in the original text and the same position in
    # the modified text, for everything after the spans handled so far.
    increment: int = 0
    prev_span_end: int = 0
    replace_back_operations: List[Tuple[Span, str]] = []
    processed_original_spans: List[Tuple[Span, Span]] = []

    for span, replacement in ordered_operations:
        if span.begin < 0 or span.end < 0:
            raise ValueError("Negative indexing not supported")
        if span.begin > orig_text_len or span.end > orig_text_len:
            raise ValueError(
                "One of the span indices are outside the string length")
        if span.end < span.begin:
            raise ValueError(
                "One of the end indices is lesser than start index")
        if span.begin < prev_span_end:
            raise ValueError(
                "The replacement spans should be mutually exclusive")

        # Untouched text between the previous span and this one, followed
        # by the replacement itself.
        pieces.append(original_text[prev_span_end:span.begin])
        pieces.append(replacement)

        span_begin = span.begin + increment
        replacement_span = Span(span_begin, span_begin + len(replacement))
        original_span_text = original_text[span.begin:span.end]
        replace_back_operations.append((replacement_span, original_span_text))
        processed_original_spans.append((replacement_span, span))

        increment += len(replacement) - (span.end - span.begin)
        prev_span_end = span.end

    # Tail of the text after the last replaced span.
    pieces.append(original_text[prev_span_end:])

    return ("".join(pieces), replace_back_operations,
            sorted(processed_original_spans), orig_text_len)
def predict(self, data_batch: Dict) -> Dict[str, List[Prediction]]:
    """Decode a batch with ``self.model`` and return span predictions.

    Args:
        data_batch: Batch dictionary. This method reads
            ``data_batch["Token"]["text"]`` (one sequence of tokens per
            sentence, each exposing ``tolist()``) and
            ``data_batch["Token"]["span"]`` (per sentence, one
            ``(begin, end)`` pair per token).

    Returns:
        ``{"predictions": batch_predictions}`` with one entry per sentence.
        Each entry is a list of ``(predicate_span, arguments)`` tuples,
        where ``arguments`` is a list of ``(argument_span, label)`` tuples.
    """
    text: List[List[str]] = [
        sentence.tolist() for sentence in data_batch["Token"]["text"]
    ]
    text_ids, length = tx.data.padded_batch([
        self.word_vocab.map_tokens_to_ids_py(sentence) for sentence in text
    ])
    text_ids = torch.from_numpy(text_ids).to(device=self.device)
    length = torch.tensor(length, dtype=torch.long, device=self.device)
    batch_size = len(text)
    batch = tx.data.Batch(
        batch_size, text=text, text_ids=text_ids, length=length,
        # One independent list per example; ``[[]] * batch_size`` would
        # make every example share the same list object.
        srl=[[] for _ in range(batch_size)])

    # Keep the model on the same device as the input tensors instead of
    # unconditionally moving it to CUDA.
    self.model = self.model.to(self.device)
    batch_srl_spans = self.model.decode(batch)

    # Convert predictions into annotations.
    batch_predictions: List[Prediction] = []
    for idx, srl_spans in enumerate(batch_srl_spans):
        word_spans = data_batch["Token"]["span"][idx]
        predictions: Prediction = []
        for pred_idx, pred_args in srl_spans.items():
            pred_begin, pred_end = word_spans[pred_idx]
            # TODO cannot create annotation here.
            pred_span = Span(pred_begin, pred_end)
            # An argument runs from the begin of its first token to the end
            # of its last token.
            arguments = [
                (Span(word_spans[arg.start][0], word_spans[arg.end][1]),
                 arg.label)
                for arg in pred_args
            ]
            predictions.append((pred_span, arguments))
        batch_predictions.append(predictions)
    return {"predictions": batch_predictions}
def __init__(self, pack: PackType, begin: int, end: int):
    """Initialize the entry in ``pack`` and store ``Span(begin, end)``.

    The parent initializer is called with ``pack`` before the indices are
    validated.

    Args:
        pack: Passed on to the parent class initializer.
        begin: Begin index of the span.
        end: End index of the span; must not be smaller than ``begin``.

    Raises:
        ValueError: If ``begin`` is greater than ``end``.
    """
    super().__init__(pack)

    if end < begin:
        message = f"The begin {begin} of span is greater than the end {end}"
        raise ValueError(message)

    self._span = Span(begin, end)
def get_original_span(self, input_processed_span: Span,
                      align_mode: str = "relaxed") -> Span:
    """
    Function to obtain span of the original text that aligns with the
    given span of the processed text.

    Args:
        input_processed_span: Span of the processed text for which the
            corresponding span of the original text is desired
        align_mode: The strictness criteria for alignment in the ambiguous
            cases, that is, if the begin or the (inclusive) end of
            input_processed_span falls inside a processed (replaced) span,
            then align_mode controls whether to use that span fully or
            ignore it completely according to the following possible
            values

            - "strict" - do not allow ambiguous input, give ValueError
            - "relaxed" - consider spans on both sides
            - "forward" - align looking forward, that is, ignore the span
              towards the left, but consider the span towards the right
            - "backward" - align looking backwards, that is, ignore the
              span towards the right, but consider the span towards the
              left

    Returns:
        Span of the original text that aligns with input_processed_span

    Raises:
        AssertionError: If align_mode is not one of the four values above
            (checked with ``assert``).
        ValueError: If an index of input_processed_span cannot be aligned
            under align_mode or lies outside the processed text.

    Example:
        * Let o-up1, o-up2, ... and m-up1, m-up2, ... denote the
          unprocessed spans of the original and modified string
          respectively. Note that each o-up would have a corresponding
          m-up of the same size.
        * Let o-pr1, o-pr2, ... and m-pr1, m-pr2, ... denote the processed
          spans of the original and modified string respectively. Note
          that each o-pr is modified to a corresponding m-pr that may be
          of a different size than o-pr.
        * Original string:
          <--o-up1--> <-o-pr1-> <----o-up2----> <----o-pr2----> <-o-up3->
        * Modified string:
          <--m-up1--> <----m-pr1----> <----m-up2----> <-m-pr2-> <-m-up3->
        * Note that `self.processed_original_spans`, which is iterated
          below as pairs of (span in the modified string, span in the
          original string), would look like -
          [(m-pr1, o-pr1), (m-pr2, o-pr2)]

        >> data_pack = DataPack()
        >> original_text = "He plays in the park"
        >> data_pack.set_text(original_text,
        >>                    lambda _: [(Span(0, 2), "She")])
        >> data_pack.text
        "She plays in the park"
        >> input_processed_span = Span(0, len("She plays"))
        >> orig_span = data_pack.get_original_span(input_processed_span)
        >> data_pack.get_original_text()[orig_span.begin: orig_span.end]
        "He plays"
    """
    assert align_mode in ["relaxed", "strict", "backward", "forward"]

    req_begin = input_processed_span.begin
    req_end = input_processed_span.end

    def get_original_index(input_index: int, is_begin_index: bool,
                           mode: str) -> int:
        """
        Map one index of the processed text to an index of the original
        text.

        Args:
            input_index: begin index or inclusive end index (that is,
                ``end - 1``) of the input span
            is_begin_index: if the index is the begin index of the input
                span or the end index of the input span
            mode: alignment mode

        Returns:
            Original index that aligns with input_index

        Raises:
            ValueError: If no original index could be determined.
        """
        # Nothing was replaced, so both texts share the same indices.
        if len(self.processed_original_spans) == 0:
            return input_index

        len_processed_text = len(self._text)
        orig_index = None
        # End (in the processed text) of the previously visited processed
        # span; start of the unprocessed stretch examined next.
        prev_end = 0
        for (inverse_span, original_span) in self.processed_original_spans:
            # check if the input_index lies between one of the unprocessed
            # spans
            if prev_end <= input_index < inverse_span.begin:
                # Unprocessed text is only shifted, never resized.
                increment = original_span.begin - inverse_span.begin
                orig_index = input_index + increment
            # check if the input_index lies between one of the processed
            # spans
            elif inverse_span.begin <= input_index < inverse_span.end:
                # None of the four conditions below matches "strict", so
                # orig_index stays None for it and the ValueError at the
                # end of this function is eventually raised.

                # look backward - backward shift of input_index
                if is_begin_index and mode in ["backward", "relaxed"]:
                    orig_index = original_span.begin
                if not is_begin_index and mode == "backward":
                    # inclusive end just before the original span
                    orig_index = original_span.begin - 1

                # look forward - forward shift of input_index
                if is_begin_index and mode == "forward":
                    orig_index = original_span.end
                if not is_begin_index and mode in ["forward", "relaxed"]:
                    # inclusive end on the last index of the original span
                    orig_index = original_span.end - 1

            # break if the original index is populated
            if orig_index is not None:
                break
            prev_end = inverse_span.end

        if orig_index is None:
            # check if the input_index lies in the unprocessed text after
            # the last processed span
            inverse_span, original_span = self.processed_original_spans[-1]
            if inverse_span.end <= input_index < len_processed_text:
                increment = original_span.end - inverse_span.end
                orig_index = input_index + increment
            else:
                # the input_index is not valid given the alignment mode or
                # lies outside the processed string
                raise ValueError(f"The input span either does not adhere "
                                 f"to the {align_mode} alignment mode or "
                                 f"lies outside to the processed string.")
        return orig_index

    orig_begin = get_original_index(req_begin, True, align_mode)
    # The end is aligned through its inclusive index (end - 1) and turned
    # back into an exclusive end afterwards.
    orig_end = get_original_index(req_end - 1, False, align_mode) + 1

    return Span(orig_begin, orig_end)
class PlainTextReaderTest(unittest.TestCase):
    """Tests for ``PlainTextReader`` with and without text replacements.

    Each test reads a small temporary HTML file created in ``setUp``.

    NOTE(review): the ``@data`` decorators below presumably come from the
    ``ddt`` package and only feed ``value`` into the tests when the class
    itself is decorated accordingly; no class decorator is visible in this
    excerpt - confirm that it is applied.
    """

    def setUp(self):
        # Create a temporary directory
        self.test_dir = tempfile.mkdtemp()
        self.orig_text = "<title>The Original Title </title>"
        self.file_path = os.path.join(self.test_dir, 'test.html')
        self.mod_file_path = os.path.join(self.test_dir, 'mod_test.html')
        with open(self.file_path, 'w') as f:
            f.write(self.orig_text)

    def tearDown(self):
        # Remove the directory after the test
        shutil.rmtree(self.test_dir)

    def test_reader_no_replace_test(self):
        """Without replace operations the pack text equals the file text."""
        # Read with no replacements
        pack = list(PlainTextReader().parse_pack(self.file_path))[0]
        self.assertEqual(pack.text, self.orig_text)

    @data(
        # No replacement
        ([], '<title>The Original Title </title>'),
        # Insertion
        ([(Span(11, 11), 'New ')], '<title>The New Original Title </title>'),
        # Single, sorted multiple and unsorted multiple replacements
        ([(Span(11, 19), 'New')], '<title>The New Title </title>'),
        ([(Span(0, 7), ''), (Span(26, 34), '')], 'The Original Title '),
        ([(Span(26, 34), ''), (Span(0, 7), '')], 'The Original Title '),
    )
    def test_reader_replace_back_test(self, value):
        """Replaced text matches ``output`` and can be turned back.

        Args:
            value: A ``(span_ops, output)`` pair: the replace operations to
                apply and the expected pack text.
        """
        # Reading with replacements - replacing a span and changing it back
        span_ops, output = value
        reader = PlainTextReader()
        reader.text_replace_operation = lambda _: span_ops
        pack = list(reader.parse_pack(self.file_path))[0]
        self.assertEqual(pack.text, output)

        orig_text_from_pack = pack.get_original_text()
        self.assertEqual(self.orig_text, orig_text_from_pack)

    @data(
        # before span starts
        (Span(1, 6), Span(1, 6), "relaxed"),
        (Span(1, 6), Span(1, 6), "strict"),
        # after span ends
        (Span(15, 22), Span(19, 21), "relaxed"),
        # span itself
        (Span(11, 14), Span(11, 19), "relaxed"),
        # complete string
        (Span(0, 40), Span(0, 34), "strict"),
        # cases ending to or starting from between the span
        (Span(11, 40), Span(11, 34), "relaxed"),
        (Span(13, 40), Span(11, 34), "relaxed"),
        (Span(14, 40), Span(19, 34), "relaxed"),
        (Span(13, 40), Span(11, 34), "backward"),
        (Span(13, 40), Span(19, 34), "forward"),
        (Span(0, 12), Span(0, 19), "relaxed"),
        (Span(0, 13), Span(0, 11), "backward"),
        (Span(0, 14), Span(0, 19), "forward"),
        # same begin and end
        (Span(38, 38), Span(32, 32), "relaxed"),
        (Span(38, 38), Span(32, 32), "strict"),
        (Span(38, 38), Span(32, 32), "backward"),
        (Span(38, 38), Span(32, 32), "forward")
    )
    def test_reader_original_span_test(self, value):
        """A span of the processed text maps to the expected original span.

        Args:
            value: An ``(input_span, expected_span, mode)`` triple;
                ``input_span`` refers to the processed text and ``mode`` is
                the alignment mode handed to ``get_original_span``.
        """
        span_ops, output = ([(Span(11, 19), 'New'),
                             (Span(19, 20), ' Shiny '),
                             (Span(25, 25), ' Ends')],
                            '<title>The New Shiny Title Ends </title>')
        input_span, expected_span, mode = value

        reader = PlainTextReader()
        reader.text_replace_operation = lambda _: span_ops
        pack = list(reader.parse_pack(self.file_path))[0]
        self.assertEqual(pack.text, output)

        output_span = pack.get_original_span(input_span, mode)
        self.assertEqual(output_span, expected_span,
                         f"Expected: ({expected_span.begin, expected_span.end}"
                         f"), Found: ({output_span.begin, output_span.end})"
                         f" when Input: ({input_span.begin, input_span.end})"
                         f" and Mode: {mode}")

    @data(
        ([(Span(5, 8), ''), (Span(6, 10), '')], None),  # overlap
        ([(Span(5, 8), ''), (Span(6, 1000), '')], None),  # outside limit
        ([(Span(-1, 8), '')], None),  # does not support negative indexing
        ([(Span(8, -1), '')], None),  # does not support negative indexing
        ([(Span(2, 1), '')], None)  # start should be lesser than end
    )
    def test_reader_replace_error_test(self, value):
        """Invalid replace operations make reading raise ``ValueError``.

        Args:
            value: A ``(span_ops, output)`` pair; only ``span_ops`` is used,
                the second element is ``None`` for every case.
        """
        # Read with errors in span replacements
        span_ops, _ = value
        reader = PlainTextReader()
        reader.text_replace_operation = lambda _: span_ops
        # assertRaises fails the test when no ValueError is raised and lets
        # any other exception surface with its full traceback.
        with self.assertRaises(ValueError):
            list(reader.parse_pack(self.file_path))