Example #1
0
    def test_reader_original_span_test(self, value):
        """Check that a processed-text span maps back to the expected
        original-text span under a given alignment mode.

        Args:
            value: ``(input_span, expected_span, mode)`` triple supplied by
                the test parameterization.
        """
        # Replace "Original" with "New", the following space with " Shiny ",
        # and insert " Ends" at offset 25 (the space before "</title>").
        span_ops = [(Span(11, 19), 'New'),
                    (Span(19, 20), ' Shiny '),
                    (Span(25, 25), ' Ends')]
        output = '<title>The New Shiny Title Ends </title>'
        input_span, expected_span, mode = value

        reader = PlainTextReader()
        reader.text_replace_operation = lambda _: span_ops
        pack = list(reader.parse_pack(self.file_path))[0]
        self.assertEqual(pack.text, output)

        output_span = pack.get_original_span(input_span, mode)
        # Interpolate begin/end separately; formatting a tuple inside
        # literal parentheses would render as "((b, e))".
        self.assertEqual(
            output_span, expected_span,
            f"Expected: ({expected_span.begin}, {expected_span.end}), "
            f"Found: ({output_span.begin}, {output_span.end}) "
            f"when Input: ({input_span.begin}, {input_span.end}) "
            f"and Mode: {mode}")
Example #2
0
 def set_span(self, begin: int, end: int):
     """
     Replace this annotation's span with ``Span(begin, end)``.

     Raises:
         ValueError: when ``begin`` is larger than ``end``.
     """
     if end < begin:
         message = f"The begin {begin} of span is greater than the end {end}"
         raise ValueError(message)
     self._span = Span(begin, end)
Example #3
0
def modify_text_and_track_ops(original_text: str,
                              replace_operations: ReplaceOperationsType) -> \
        Tuple[str, ReplaceOperationsType, List[Tuple[Span, Span]], int]:
    r"""Modifies the original text using replace_operations provided by the user
    to return modified text and other data required for tracking original text.

    The caller's ``replace_operations`` is not modified; a sorted copy is used
    internally.

    Args:
        original_text: Text to be modified
        replace_operations: A list of spans and the corresponding replacement
            string that the span in the original string is to be replaced with
            to obtain the modified string. Spans index into ``original_text``,
            must be non-negative, must satisfy ``begin <= end`` and must not
            overlap. Operations with equal spans (e.g. several insertions at
            one position) are applied in the order given.

    Returns:
        A tuple ``(modified_text, replace_back_operations,
        processed_original_spans, orig_text_len)``:

        modified_text: Text after modification
        replace_back_operations: A list of spans and the corresponding
            replacement string that the span in the modified string is to be
            replaced with to obtain the original string
        processed_original_spans: Sorted list of (processed span, original
            span) pairs
        orig_text_len: length of original text

    Raises:
        ValueError: if a span uses a negative index, exceeds the length of
            ``original_text``, ends before it begins, or overlaps a
            preceding span.
    """
    orig_text_len: int = len(original_text)
    mod_text: str = original_text
    # Net length change introduced by the replacements applied so far; it
    # shifts original-text offsets into modified-text offsets.
    increment: int = 0
    prev_span_end: int = 0
    replace_back_operations: List[Tuple[Span, str]] = []
    processed_original_spans: List[Tuple[Span, Span]] = []

    # Sort a copy so the caller's list is left untouched. Python's sort is
    # stable, so operations sharing a span keep their relative order.
    sorted_operations = sorted(replace_operations, key=lambda item: item[0])

    for span, replacement in sorted_operations:
        if span.begin < 0 or span.end < 0:
            raise ValueError("Negative indexing not supported")
        if span.begin > orig_text_len or span.end > orig_text_len:
            raise ValueError(
                "One of the span indices are outside the string length")
        if span.end < span.begin:
            # Report the offending span in the exception rather than
            # printing it to stdout.
            raise ValueError(
                f"One of the end indices is lesser than start index "
                f"(begin={span.begin}, end={span.end})")
        if span.begin < prev_span_end:
            raise ValueError(
                "The replacement spans should be mutually exclusive")
        span_begin = span.begin + increment
        span_end = span.end + increment
        original_span_text = mod_text[span_begin:span_end]
        mod_text = mod_text[:span_begin] + replacement + mod_text[span_end:]
        increment += len(replacement) - (span.end - span.begin)
        # Location of the replacement inside the modified text.
        replacement_span = Span(span_begin, span_begin + len(replacement))
        replace_back_operations.append((replacement_span, original_span_text))
        processed_original_spans.append((replacement_span, span))
        prev_span_end = span.end

    return (mod_text, replace_back_operations,
            sorted(processed_original_spans), orig_text_len)
Example #4
0
    def predict(self, data_batch: Dict) -> Dict[str, List[Prediction]]:
        """Decode semantic-role spans for a batch of sentences.

        Args:
            data_batch: Batch dict. ``data_batch["Token"]["text"]`` holds one
                token sequence per sentence (each must support ``.tolist()``).
                ``data_batch["Token"]["span"][i]`` holds one ``(begin, end)``
                pair per token of sentence ``i`` (presumably character
                offsets into the pack text; TODO confirm).

        Returns:
            ``{"predictions": batch_predictions}``, with one list per decoded
            sentence. Each item is ``(predicate_span, arguments)``, where
            ``arguments`` is a list of ``(argument_span, label)`` pairs.
        """
        text: List[List[str]] = [
            sentence.tolist() for sentence in data_batch["Token"]["text"]
        ]
        text_ids, length = tx.data.padded_batch([
            self.word_vocab.map_tokens_to_ids_py(sentence) for sentence in text
        ])
        text_ids = torch.from_numpy(text_ids).to(device=self.device)
        length = torch.tensor(length, dtype=torch.long, device=self.device)
        batch_size = len(text)
        batch = tx.data.Batch(batch_size,
                              text=text,
                              text_ids=text_ids,
                              length=length,
                              # One independent empty list per sentence;
                              # ``[[]] * n`` would alias a single list.
                              srl=[[] for _ in range(batch_size)])
        # Move the model to the same device as the input tensors instead of
        # hard-coding CUDA, so CPU-only configurations work as well.
        self.model = self.model.to(self.device)
        batch_srl_spans = self.model.decode(batch)

        # Convert predictions into annotations.
        batch_predictions: List[Prediction] = []
        for idx, srl_spans in enumerate(batch_srl_spans):
            word_spans = data_batch["Token"]["span"][idx]
            predictions: Prediction = []
            # srl_spans maps a predicate token index to its arguments; each
            # argument carries .start/.end token indices and a .label.
            for pred_idx, pred_args in srl_spans.items():
                pred_begin, pred_end = word_spans[pred_idx]
                # TODO cannot create annotation here.
                pred_span = Span(pred_begin, pred_end)
                arguments = []
                for arg in pred_args:
                    # An argument covers tokens arg.start..arg.end
                    # inclusive: begin of the first, end of the last.
                    arg_begin = word_spans[arg.start][0]
                    arg_end = word_spans[arg.end][1]
                    arguments.append((Span(arg_begin, arg_end), arg.label))
                predictions.append((pred_span, arguments))
            batch_predictions.append(predictions)
        return {"predictions": batch_predictions}
Example #5
0
 def __init__(self, pack: PackType, begin: int, end: int):
     """
     Initialize the parent with ``pack``, then store ``Span(begin, end)``.

     Raises:
         ValueError: when ``begin`` is larger than ``end``.
     """
     super().__init__(pack)
     if end < begin:
         message = f"The begin {begin} of span is greater than the end {end}"
         raise ValueError(message)
     self._span = Span(begin, end)
Example #6
0
    def get_original_span(self,
                          input_processed_span: Span,
                          align_mode: str = "relaxed") -> Span:
        """
        Function to obtain span of the original text that aligns with the
        given span of the processed text.

        Args:
            input_processed_span: Span of the processed text for which the
                corresponding span of the original text is desired
            align_mode: The strictness criteria for alignment in the ambiguous
                cases, that is, if a part of input_processed_span spans a part
                of the inserted span, then align_mode controls whether to use
                the span fully or ignore it completely according to the
                following possible values

                - "strict" - do not allow ambiguous input, give ValueError
                - "relaxed" - consider spans on both sides
                - "forward" - align looking forward, that is, ignore the span
                  towards the left, but consider the span towards the right
                - "backward" - align looking backwards, that is, ignore the
                  span towards the right, but consider the span towards the
                  left

        Returns:
            Span of the original text that aligns with input_processed_span

        Raises:
            ValueError: if ``align_mode`` is not one of the values above.
                Also raised, when the text has processed spans, if an endpoint
                of ``input_processed_span`` cannot be aligned under
                ``align_mode`` (e.g. it falls inside a processed span in
                "strict" mode) or lies outside the processed text.

        Example:
            * Let o-up1, o-up2, ... and m-up1, m-up2, ... denote the unprocessed
              spans of the original and modified string respectively. Note that
              each o-up would have a corresponding m-up of the same size.
            * Let o-pr1, o-pr2, ... and m-pr1, m-pr2, ... denote the processed
              spans of the original and modified string respectively. Note that
              each o-pr is modified to a corresponding m-pr that may be of a
              different size than o-pr.
            * Original string:
              <--o-up1--> <-o-pr1-> <----o-up2----> <----o-pr2----> <-o-up3->
            * Modified string:
              <--m-up1--> <----m-pr1----> <----m-up2----> <-m-pr2-> <-m-up3->
            * Note that `self.processed_original_spans`, which contains the
              modified processed spans and their corresponding original spans,
              would look like - [(m-pr1, o-pr1), (m-pr2, o-pr2)]

            >> data_pack = DataPack()
            >> original_text = "He plays in the park"
            >> data_pack.set_text(original_text,\
            >>                    lambda _: [(Span(0, 2), "She")])
            >> data_pack.text
            "She plays in the park"
            >> input_processed_span = Span(0, len("She plays"))
            >> orig_span = data_pack.get_original_span(input_processed_span)
            >> data_pack.get_original_text()[orig_span.begin: orig_span.end]
            "He plays"
        """
        # Explicit check instead of `assert`, which is stripped under -O.
        if align_mode not in ("relaxed", "strict", "backward", "forward"):
            raise ValueError(
                f"Unsupported align_mode {align_mode!r}; expected one of "
                f"'relaxed', 'strict', 'backward', 'forward'.")

        req_begin = input_processed_span.begin
        req_end = input_processed_span.end

        def get_original_index(input_index: int, is_begin_index: bool,
                               mode: str) -> int:
            """
            Args:
                input_index: begin or end index of the input span
                is_begin_index: if the index is the begin index of the input
                span or the end index of the input span
                mode: alignment mode
            Returns:
                Original index that aligns with input_index
            """
            # No replacements were applied: processed and original offsets
            # coincide.
            if len(self.processed_original_spans) == 0:
                return input_index

            len_processed_text = len(self._text)
            orig_index = None
            prev_end = 0
            for (inverse_span, original_span) in self.processed_original_spans:
                # check if the input_index lies between one of the unprocessed
                # spans
                if prev_end <= input_index < inverse_span.begin:
                    increment = original_span.begin - inverse_span.begin
                    orig_index = input_index + increment
                # check if the input_index lies between one of the processed
                # spans
                elif inverse_span.begin <= input_index < inverse_span.end:
                    # look backward - backward shift of input_index
                    if is_begin_index and mode in ["backward", "relaxed"]:
                        orig_index = original_span.begin
                    if not is_begin_index and mode == "backward":
                        orig_index = original_span.begin - 1

                    # look forward - forward shift of input_index
                    if is_begin_index and mode == "forward":
                        orig_index = original_span.end
                    if not is_begin_index and mode in ["forward", "relaxed"]:
                        orig_index = original_span.end - 1
                    # In "strict" mode none of the branches above fire, so
                    # orig_index stays None and the error below is raised.

                # break if the original index is populated
                if orig_index is not None:
                    break
                prev_end = inverse_span.end

            if orig_index is None:
                # check if the input_index lies between the last unprocessed
                # span
                inverse_span, original_span = self.processed_original_spans[-1]
                if inverse_span.end <= input_index < len_processed_text:
                    increment = original_span.end - inverse_span.end
                    orig_index = input_index + increment
                else:
                    # check if there input_index is not valid given the
                    # alignment mode or lies outside the processed string
                    raise ValueError(f"The input span either does not adhere "
                                     f"to the {align_mode} alignment mode or "
                                     f"lies outside to the processed string.")
            return orig_index

        orig_begin = get_original_index(req_begin, True, align_mode)
        # The end is exclusive, so align the last covered character and step
        # one past it. NOTE(review): an empty input span at offset 0 yields
        # index -1 here, which raises when replacements exist — confirm
        # whether that is intended.
        orig_end = get_original_index(req_end - 1, False, align_mode) + 1

        return Span(orig_begin, orig_end)
Example #7
0
class PlainTextReaderTest(unittest.TestCase):
    """Tests for PlainTextReader text replacement and span tracking."""

    def setUp(self):
        # Create a temporary directory holding a small HTML-like input file.
        self.test_dir = tempfile.mkdtemp()
        self.orig_text = "<title>The Original Title </title>"
        self.file_path = os.path.join(self.test_dir, 'test.html')
        self.mod_file_path = os.path.join(self.test_dir, 'mod_test.html')
        with open(self.file_path, 'w') as f:
            f.write(self.orig_text)

    def tearDown(self):
        # Remove the directory after the test
        shutil.rmtree(self.test_dir)

    def test_reader_no_replace_test(self):
        """Without replacements the pack text equals the file contents."""
        pack = list(PlainTextReader().parse_pack(self.file_path))[0]
        self.assertEqual(pack.text, self.orig_text)

    @data(
        # No replacement
        ([], '<title>The Original Title </title>'),
        # Insertion
        ([(Span(11, 11), 'New ')], '<title>The New Original Title </title>'),
        # Single, sorted multiple and unsorted multiple replacements
        ([(Span(11, 19), 'New')], '<title>The New Title </title>'),
        ([(Span(0, 7), ''), (Span(26, 34), '')], 'The Original Title '),
        ([(Span(26, 34), ''), (Span(0, 7), '')], 'The Original Title '),
    )
    def test_reader_replace_back_test(self, value):
        """Replacements produce the expected text, and the original text can
        be recovered from the pack afterwards."""
        span_ops, output = value
        reader = PlainTextReader()
        reader.text_replace_operation = lambda _: span_ops
        pack = list(reader.parse_pack(self.file_path))[0]
        self.assertEqual(pack.text, output)

        orig_text_from_pack = pack.get_original_text()
        self.assertEqual(self.orig_text, orig_text_from_pack)

    @data(
        # before span starts
        (Span(1, 6), Span(1, 6), "relaxed"),
        (Span(1, 6), Span(1, 6), "strict"),
        # after span ends
        (Span(15, 22), Span(19, 21), "relaxed"),
        # span itself
        (Span(11, 14), Span(11, 19), "relaxed"),
        # complete string
        (Span(0, 40), Span(0, 34), "strict"),
        # cases ending to or starting from between the span
        (Span(11, 40), Span(11, 34), "relaxed"),
        (Span(13, 40), Span(11, 34), "relaxed"),
        (Span(14, 40), Span(19, 34), "relaxed"),
        (Span(13, 40), Span(11, 34), "backward"),
        (Span(13, 40), Span(19, 34), "forward"),
        (Span(0, 12), Span(0, 19), "relaxed"),
        (Span(0, 13), Span(0, 11), "backward"),
        (Span(0, 14), Span(0, 19), "forward"),
        # same begin and end
        (Span(38, 38), Span(32, 32), "relaxed"),
        (Span(38, 38), Span(32, 32), "strict"),
        (Span(38, 38), Span(32, 32), "backward"),
        (Span(38, 38), Span(32, 32), "forward")
    )
    def test_reader_original_span_test(self, value):
        """A processed-text span maps back to the expected original span
        under the given alignment mode."""
        # Replace "Original" with "New", the following space with " Shiny ",
        # and insert " Ends" at offset 25 (the space before "</title>").
        span_ops = [(Span(11, 19), 'New'),
                    (Span(19, 20), ' Shiny '),
                    (Span(25, 25), ' Ends')]
        output = '<title>The New Shiny Title Ends </title>'
        input_span, expected_span, mode = value
        reader = PlainTextReader()
        reader.text_replace_operation = lambda _: span_ops
        pack = list(reader.parse_pack(self.file_path))[0]
        self.assertEqual(pack.text, output)

        output_span = pack.get_original_span(input_span, mode)
        # Interpolate begin/end separately; formatting a tuple inside
        # literal parentheses would render as "((b, e))".
        self.assertEqual(
            output_span, expected_span,
            f"Expected: ({expected_span.begin}, {expected_span.end}), "
            f"Found: ({output_span.begin}, {output_span.end}) "
            f"when Input: ({input_span.begin}, {input_span.end}) "
            f"and Mode: {mode}")

    @data(
        ([(Span(5, 8), ''), (Span(6, 10), '')], None),  # overlap
        ([(Span(5, 8), ''), (Span(6, 1000), '')], None),  # outside limit
        ([(Span(-1, 8), '')], None),  # does not support negative indexing
        ([(Span(8, -1), '')], None),  # does not support negative indexing
        ([(Span(2, 1), '')], None)  # start should be lesser than end
    )
    def test_reader_replace_error_test(self, value):
        """Invalid replacement spans make parsing raise ValueError."""
        # The second element of each data tuple is an unused placeholder.
        span_ops = value[0]
        reader = PlainTextReader()
        reader.text_replace_operation = lambda _: span_ops
        with self.assertRaises(ValueError):
            list(reader.parse_pack(self.file_path))