Example #1
0
 def test_usage(self):
     # Checks that normalize_token_span shrinks a span padded with punctuation
     # and whitespace tokens down to the "dm2" tokens.
     tokens = tokenizer_lib.tokenize("a&p:\n # dm2 ")
     raw_span = (3, 11)  # ":\n # dm2 "
     normalized_span = ap_parsing_utils.normalize_token_span(raw_span, tokens)
     self.assertEqual(normalized_span, (8, 10))  # "dm2"
Example #2
0
    def test_usage(self):
        # Feeds three raw spans through normalize_labeled_char_spans_iterable:
        # one already aligned to tokens, one cutting through words, and one
        # covering only a space (expected to be dropped from the output).
        #
        # Token index at the start of each token:
        #       0   12     34     56 78   90
        text = "some longer tokens in this test"
        #       0123456789012345678901234567890
        #       0         1         2         3
        tokens = tokenizer_lib.tokenize(text)

        def unknown_span(start, end):
            # All spans in this test share the UNKNOWN_TYPE span type.
            return ap_parsing_lib.LabeledCharSpan(
                span_type=ap_parsing_lib.LabeledSpanType.UNKNOWN_TYPE,
                start_char=start,
                end_char=end)

        raw_spans = [
            unknown_span(5, 18),  # "longer tokens" - already normalized.
            unknown_span(14, 25),  # "kens in thi" -> "tokens in this"
            unknown_span(18, 19),  # Invalid - only space.
        ]
        expected = [
            unknown_span(5, 18),  # "longer tokens"
            unknown_span(12, 26),  # "tokens in this"
        ]

        normalized = ap_parsing_utils.normalize_labeled_char_spans_iterable(
            raw_spans, tokens)
        self.assertEqual(normalized, expected)
Example #3
0
    def test_one_sided(self):
        # Spans that need trimming on a single side only should come back as
        # the bare "dm2" span.
        tokens = tokenizer_lib.tokenize("a&p:\n # dm2 ")

        def unknown_span(start, end):
            return ap_parsing_lib.LabeledCharSpan(
                start_char=start,
                end_char=end,
                span_type=ap_parsing_lib.LabeledSpanType.UNKNOWN_TYPE)

        dm2_span = (8, 11)  # "dm2"

        # Remove suffix only: "dm2 " -> "dm2".
        suffix_trimmed = ap_parsing_utils.normalize_labeled_char_span(
            unknown_span(8, 12), tokens)
        self.assertEqual(suffix_trimmed, unknown_span(*dm2_span))

        # Remove prefix only: ":\n # dm2" -> "dm2".
        prefix_trimmed = ap_parsing_utils.normalize_labeled_char_span(
            unknown_span(3, 11), tokens)
        self.assertEqual(prefix_trimmed, unknown_span(*dm2_span))
Example #4
0
    def test_usage(self):
        # A char span aligned to token boundaries maps onto the matching
        # token span.
        # Token index at the start of each token:
        #       0   12     34     56 78   90
        text = "some longer tokens in this test"
        tokens = tokenizer_lib.tokenize(text)

        char_span = (5, 18)  # "longer tokens"
        token_span = ap_parsing_utils.char_span_to_token_span(
            tokens, char_span)
        self.assertEqual(token_span, (2, 5))
Example #5
0
        def yield_sections(ap_data, note_ratings=None):
            """Yields one (key, APData) pair per retained A&P section.

            `note`, `section_markers` and `self` are taken from the enclosing
            scope rather than passed in.

            Args:
              ap_data: template record; a copy is made and filled in for every
                section.
              note_ratings: forwarded to process_rating_labels for rated
                records.

            Yields:
              Tuples of ("<note_id>|<section char_start>", APData).
            """
            for section in extract_ap_sections(note.text, section_markers):
                # Each section gets its own copy of the template record.
                section_data = dataclasses.replace(ap_data)
                section_text = note.text[section.char_start:section.char_end]
                section_data.ap_text = section_text
                section_data.tokens = tokenizer_lib.tokenize(section_text)
                section_data.char_offset = section.char_start

                # Drop sections flagged by filter_inorganic (original note:
                # filters by word counts of non CAPS tokens).
                is_inorganic = filter_inorganic(
                    section_data.tokens,
                    threshold=self.filter_inorganic_threshold)
                if is_inorganic:
                    continue

                # Unrated records are re-partitioned and labeled by
                # annotate_ap; rated ones take their labels from the ratings.
                if not section_data.is_rated:
                    section_data.partition = Partition.NONRATED
                    raw_spans = annotate_ap(section_data.ap_text)
                else:
                    raw_spans = process_rating_labels(note_ratings, section)

                section_data.labeled_char_spans = (
                    ap_parsing_utils.normalize_labeled_char_spans_iterable(
                        raw_spans, section_data.tokens))

                section_key = f"{section_data.note_id}|{section.char_start}"
                yield (section_key, section_data)
    def build(cls, text, labeled_char_spans):
        """Builds structured AP from text and labels.

        Bundles together labels into clusters based on problem titles.

        Args:
          text: str, text of A&P section
          labeled_char_spans: LabeledCharSpans, which are converted to cluster
            fragments. May be empty, in which case the whole text is treated
            as prefix text.

        Returns:
          An instance of StructuredAP.
        """
        tokens = tokenizer_lib.tokenize(text)

        # sorted() accepts any iterable, so this does not depend on the
        # normalization helper returning a list, and leaves its result
        # unmodified.
        labeled_char_spans = sorted(
            ap_parsing_utils.normalize_labeled_char_spans_iterable(
                labeled_char_spans, tokens),
            key=lambda span: span.start_char)

        structured_ap = cls(problem_clusters=[], prefix_text="")
        structured_ap._parse_problem_clusters(labeled_char_spans, text)  # pylint: disable=protected-access

        # The prefix is whatever precedes the first labeled span. When no span
        # survives normalization there is no first span to index, so the
        # whole text counts as prefix instead of raising IndexError.
        prefix_end_char = (labeled_char_spans[0].start_char
                           if labeled_char_spans else len(text))
        prefix_text_span = ap_parsing_utils.normalize_labeled_char_span(
            ap_parsing_lib.LabeledCharSpan(
                start_char=0,
                end_char=prefix_end_char,
                span_type=ap_parsing_lib.LabeledSpanType.UNKNOWN_TYPE), tokens)
        # A falsy normalized span leaves the prefix empty (presumably the
        # prefix held no usable tokens - confirm against
        # normalize_labeled_char_span).
        structured_ap.prefix_text = (
            text[prefix_text_span.start_char:prefix_text_span.end_char]
            if prefix_text_span else "")

        return structured_ap
Example #7
0
    def test_labeled_char_spans_to_token_spans(self):
        # Converts labeled spans from char space to token space and back,
        # expecting an exact round trip in both directions.
        #  Char space:
        #       0         1          2          3
        #       01234567890123456 78901234 56789012345
        text = "# DM2: on insulin\n # COPD\n- nebs prn"
        # Token:012 3456 78       90123    4567   89
        #       0                  1
        tokens = tokenizer_lib.tokenize(text)

        span_type = ap_parsing_lib.LabeledSpanType

        def char_span(kind, start, end):
            return ap_parsing_lib.LabeledCharSpan(
                span_type=kind, start_char=start, end_char=end)

        def token_span(kind, start, end):
            return ap_parsing_lib.LabeledTokenSpan(
                span_type=kind, start_token=start, end_token=end)

        labeled_char_spans = [
            char_span(span_type.PROBLEM_TITLE, 2, 17),  # "DM2: on insulin"
            char_span(span_type.PROBLEM_TITLE, 21, 25),  # "COPD"
            char_span(span_type.ACTION_ITEM, 28, 32),  # "nebs"
        ]
        labeled_token_spans = [
            token_span(span_type.PROBLEM_TITLE, 2, 9),  # "DM2: on insulin"
            token_span(span_type.PROBLEM_TITLE, 13, 14),  # "COPD"
            token_span(span_type.ACTION_ITEM, 17, 18),  # "nebs"
        ]

        # Char space -> token space.
        converted_to_tokens = [
            ap_parsing_utils.labeled_char_span_to_labeled_token_span(
                span, tokens=tokens) for span in labeled_char_spans
        ]
        self.assertEqual(labeled_token_spans, converted_to_tokens)

        # Token space -> char space.
        converted_to_chars = [
            ap_parsing_utils.labeled_token_span_to_labeled_char_span(
                span, tokens=tokens) for span in labeled_token_spans
        ]
        self.assertEqual(labeled_char_spans, converted_to_chars)
Example #8
0
    def test_one_sided(self):
        # Token spans with extra tokens on one side only should both
        # normalize to the "dm2" tokens.
        tokens = tokenizer_lib.tokenize("a&p:\n # dm2. ")
        dm2_tokens = (8, 10)  # "dm2"

        # Extra tokens on the left only.
        left_padded = ap_parsing_utils.normalize_token_span((6, 10), tokens)
        self.assertEqual(left_padded, dm2_tokens)

        # Extra token on the right only.
        right_padded = ap_parsing_utils.normalize_token_span((8, 11), tokens)
        self.assertEqual(right_padded, dm2_tokens)
Example #9
0
    def test_midtoken(self):
        # Char spans that start or end inside a word are expected to map to
        # token spans that include the whole word.
        # Token index at the start of each token:
        #       0   12     34     56 78   90
        text = "some longer tokens in this test"
        tokens = tokenizer_lib.tokenize(text)
        expected_token_span = (2, 7)

        # Mid word from the left.
        from_left = ap_parsing_utils.char_span_to_token_span(tokens, (7, 21))
        self.assertEqual(from_left, expected_token_span)

        # Mid word from the right.
        from_right = ap_parsing_utils.char_span_to_token_span(tokens, (5, 20))
        self.assertEqual(from_right, expected_token_span)
Example #10
0
    def test_get_token_features(self):
        # Checks every feature produced by generate_token_features for a
        # short A&P text against hand-written expectations.
        ap_text = "50 yo m with hx of copd, dm2\n#. COPD Ex"

        # Vocab index of the leading entries:
        #        0    1             23456      7    8
        vocab = [" ", "\n"] + list("-:.,#") + ["2", "50"] + [
            "abx",
            "continue",
            "copd",
            "dm",
            "ed",
            "ex",
            "hx",
            "in",
            "m",
            "of",
        ]

        tokens = tokenizer_lib.tokenize(ap_text)

        token_features = data_lib.generate_token_features(tokens, vocab)

        expected_features = {
            # Original note: "OOV is 1". The out-of-vocab words "yo" and
            # "with" are expected as id 2 below - TODO confirm the id offset
            # applied by generate_token_features.
            "token_ids": [
                11, 3, 2, 3, 20, 3, 2, 3, 18, 3, 21, 3, 14, 8, 3, 15, 10, 4, 9,
                7, 3, 14, 3, 17
            ],
            "token_type": [
                3, 5, 1, 5, 1, 5, 1, 5, 1, 5, 1, 5, 1, 2, 5, 1, 3, 5, 2, 2, 5,
                1, 5, 1
            ],
            "is_upper": [
                0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                1, 0, 0
            ],
            "is_title": [
                0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                0, 0, 1
            ]
        }

        # Compare the key sets first: looping over the actual result alone
        # would let a missing feature (or an empty result) pass unnoticed.
        self.assertCountEqual(token_features.keys(), expected_features.keys())

        for key, expected_values in expected_features.items():
            self.assertAllClose(token_features[key],
                                expected_values,
                                msg=key)
Example #11
0
    def test_midword(self):
        # Spans that cut through "COPD" should be widened to cover the whole
        # word, whichever side the cut is on.
        tokens = tokenizer_lib.tokenize("a&p:\n # COPD: on nebs ")

        def unknown_span(start, end):
            return ap_parsing_lib.LabeledCharSpan(
                start_char=start,
                end_char=end,
                span_type=ap_parsing_lib.LabeledSpanType.UNKNOWN_TYPE)

        def normalized(start, end):
            return ap_parsing_utils.normalize_labeled_char_span(
                unknown_span(start, end), tokens)

        copd_span = (8, 12)  # "COPD"

        # Extend word boundary right: "# COP" -> "COPD".
        self.assertEqual(normalized(6, 11), unknown_span(*copd_span))

        # Extend word boundary left: "OPD: " -> "COPD".
        self.assertEqual(normalized(9, 14), unknown_span(*copd_span))

        # Extend word boundary both directions: "OP" -> "COPD".
        self.assertEqual(normalized(9, 11), unknown_span(*copd_span))
Example #12
0
    def test_get_converted_labels(self):
        # Checks the per-token "fragment_type" and "action_item_type" arrays
        # that generate_model_labels derives from three labeled char spans.
        ap_text = ("50 yo m with hx of copd, dm2\n"
                   "#. COPD ex: started on abx in ED.\n"
                   "  - continue abx.")
        tokens = tokenizer_lib.tokenize(ap_text)

        span_type = ap_parsing_lib.LabeledSpanType
        labels = [
            # span_text="COPD ex"
            ap_parsing_lib.LabeledCharSpan(
                span_type=span_type.PROBLEM_TITLE, start_char=32, end_char=39),
            # span_text="started on abx in ED.\n"
            ap_parsing_lib.LabeledCharSpan(
                span_type=span_type.PROBLEM_DESCRIPTION,
                start_char=41,
                end_char=63),
            # span_text="continue abx."
            ap_parsing_lib.LabeledCharSpan(
                span_type=span_type.ACTION_ITEM,
                start_char=67,
                end_char=80,
                action_item_type=ap_parsing_lib.ActionItemType.MEDICATIONS),
        ]

        converted_labels = data_lib.generate_model_labels(labels, tokens)

        n_tokens = 45
        expected_fragment_labels = np.zeros(n_tokens)
        # (first token, end token, begin label); inside label = begin + 1.
        fragments = (
            (21, 24, 1),  # B-PT / I-PT: COPD ex
            (26, 35, 3),  # B-PD / I-PD: started on abx in ED
            (41, 44, 5),  # B-AI / I-AI: continue abx
        )
        for first_token, end_token, begin_label in fragments:
            expected_fragment_labels[first_token] = begin_label
            expected_fragment_labels[first_token + 1:end_token] = (
                begin_label + 1)

        expected_ai_labels = np.zeros(n_tokens)
        expected_ai_labels[41:44] = 1  # continue abx - medications

        self.assertAllEqual(converted_labels["fragment_type"],
                            expected_fragment_labels)
        self.assertAllEqual(converted_labels["action_item_type"],
                            expected_ai_labels)
Example #13
0
 def test_metadata(self):
     # Normalization should move the span boundaries while keeping the span
     # type and the action item type on the result.
     action_item = ap_parsing_lib.LabeledSpanType.ACTION_ITEM
     medications = ap_parsing_lib.ActionItemType.MEDICATIONS
     tokens = tokenizer_lib.tokenize("a&p:\n - nebs ")

     raw_span = ap_parsing_lib.LabeledCharSpan(
         start_char=3,
         end_char=11,
         span_type=action_item,
         action_item_type=medications)  # ":\n - neb"
     expected_span = ap_parsing_lib.LabeledCharSpan(
         start_char=8,
         end_char=12,
         span_type=action_item,
         action_item_type=medications)  # "nebs"

     self.assertEqual(
         ap_parsing_utils.normalize_labeled_char_span(raw_span, tokens),
         expected_span)
Example #14
0
    def process(self, element, augmentation_config, random_seed=0):
        """Yields the input record, then augmented copies of it.

        Args:
          element: (key, ap_data) tuple.
          augmentation_config: supplies the number of augmentations, the
            candidate augmentation sequences and their sampling
            probabilities.
          random_seed: combined with the key via hash_key to seed the
            per-record random generator.

        Yields:
          The unchanged element first; then, for records whose partition
          compares below Partition.VAL, one (key, augmented ap_data) tuple
          per augmentation.
        """
        key, ap_data = element

        # The non augmented record always passes through.
        yield element

        # Apply augmentations to train only.
        if not ap_data.partition < Partition.VAL:
            return

        seed = hash_key(key, random_seed)
        rng = np.random.default_rng(seed=seed)

        n_augmentations = augmentation_config.get_n_augmentations(rng)
        # Update distribution.
        self.n_augmentations_dist.update(n_augmentations)

        for _ in range(n_augmentations):
            # A fresh StructuredAP is built for every augmentation round.
            base_ap = aug_lib.StructuredAP.build(ap_data.ap_text,
                                                 ap_data.labeled_char_spans)
            aug_seq = rng.choice(
                augmentation_config.augmentation_sequences,
                p=augmentation_config.augmentation_sample_probabilities)
            augmented_ap = aug_lib.apply_augmentations(base_ap,
                                                       aug_seq,
                                                       seed=seed)

            # Update counters.
            self.total_augmented_counter.inc()
            beam.metrics.Metrics.counter("augmentations", aug_seq.name).inc()

            # Construct new record with augmented text.
            augmented_text, augmented_spans = augmented_ap.compile()
            new_ap_data = dataclasses.replace(ap_data)
            new_ap_data.ap_text = augmented_text
            new_ap_data.labeled_char_spans = augmented_spans
            new_ap_data.tokens = tokenizer_lib.tokenize(augmented_text)
            new_ap_data.augmentation_name = aug_seq.name

            yield (key, new_ap_data)
Example #15
0
    def test_usage(self):
        # Runs ApplyAugmentations over one record with a single deterministic
        # delimiter-changing augmentation; the output should hold the
        # original record plus one augmented copy.
        span_type = ap_parsing_lib.LabeledSpanType
        medications = ap_parsing_lib.ActionItemType.MEDICATIONS

        change_delims = aug_lib.ChangeDelimAugmentation(
            fragment_types=[span_type.PROBLEM_TITLE], delims=["\n"])
        augmentation_config = aug_lib.AugmentationConfig(
            augmentation_sequences=[
                aug_lib.AugmentationSequence(
                    name="test", augmentation_sequence=[change_delims])
            ],
            augmentation_number_deterministic=1)

        record_key = "0|10"
        original_record = (
            record_key,
            data_lib.APData(
                note_id=0,
                subject_id=0,
                ap_text="a&p:\n # dm2:\n-RISS",
                labeled_char_spans=[
                    # span_text="dm2"
                    ap_parsing_lib.LabeledCharSpan(
                        span_type=span_type.PROBLEM_TITLE,
                        start_char=8,
                        end_char=11),
                    # span_text="RISS"
                    ap_parsing_lib.LabeledCharSpan(
                        span_type=span_type.ACTION_ITEM,
                        action_item_type=medications,
                        start_char=14,
                        end_char=18),
                ]))

        augmented_text = "a&p\ndm2:\n- RISS"
        augmented_record = (
            record_key,
            data_lib.APData(
                note_id=0,
                subject_id=0,
                ap_text=augmented_text,
                tokens=tokenizer_lib.tokenize(augmented_text),
                labeled_char_spans=[
                    # span_text="dm2"
                    ap_parsing_lib.LabeledCharSpan(
                        span_type=span_type.PROBLEM_TITLE,
                        start_char=4,
                        end_char=7),
                    # span_text="RISS"
                    ap_parsing_lib.LabeledCharSpan(
                        span_type=span_type.ACTION_ITEM,
                        action_item_type=medications,
                        start_char=11,
                        end_char=15),
                ],
                augmentation_name="test"))

        ap_data = [original_record]
        expected = [original_record, augmented_record]

        with test_pipeline.TestPipeline() as p:
            augment = beam.ParDo(data_lib.ApplyAugmentations(),
                                 augmentation_config)
            results = p | beam.Create(ap_data) | augment
            util.assert_that(results, util.equal_to(expected))
Example #16
0
    def test_multiratings(self):
        # One note carrying two ratings should produce two output records;
        # here both ratings are expected to end up with identical spans.
        span_type = ap_parsing_lib.LabeledSpanType

        def char_span(kind, start, end):
            return ap_parsing_lib.LabeledCharSpan(
                span_type=kind, start_char=start, end_char=end)

        section_markers = {
            "hpi": ["history of present illness"],
            "a&p": ["assessment and plan"],
        }
        ap_text = "a&p:\n # dm2:\n-RISS"

        note = data_lib.Note(note_id=0,
                             text="blablabla\n" + ap_text,
                             subject_id=0,
                             category="PHYSICIAN")
        first_rating = [
            char_span(span_type.PROBLEM_TITLE, 19, 22),
            char_span(span_type.ACTION_ITEM, 24, 28),
        ]
        second_rating = [
            char_span(span_type.PROBLEM_TITLE, 18, 22),
            char_span(span_type.ACTION_ITEM, 25, 28),
        ]
        notes_with_ratings = [("0", {
            "notes": [note],
            "ratings": [first_rating, second_rating],
            "note_partition": ["test", "test"],
        })]

        expected_record = (
            "0|10",
            data_lib.APData(
                partition=data_lib.Partition.TEST,
                note_id="0",
                subject_id="0",
                ap_text=ap_text,
                char_offset=10,
                tokens=tokenizer_lib.tokenize(ap_text),
                labeled_char_spans=[
                    char_span(span_type.PROBLEM_TITLE, 8, 11),
                    char_span(span_type.ACTION_ITEM, 14, 18),
                ]))
        # One expected record per rating.
        expected = [expected_record] * 2

        with test_pipeline.TestPipeline() as p:
            process_ap = beam.ParDo(
                data_lib.ProcessAPData(filter_inorganic_threshold=0),
                section_markers)
            results = p | beam.Create(notes_with_ratings) | process_ap
            util.assert_that(results, util.equal_to(expected))
Example #17
0
    def test_usage(self):
        # Runs ProcessAPData over one rated note (partition "val") and one
        # note without ratings, which is expected in the NONRATED partition.
        span_type = ap_parsing_lib.LabeledSpanType
        medications = ap_parsing_lib.ActionItemType.MEDICATIONS

        def char_span(kind, start, end, **extra):
            # `extra` forwards optional fields such as action_item_type.
            return ap_parsing_lib.LabeledCharSpan(
                span_type=kind, start_char=start, end_char=end, **extra)

        def physician_note(note_id, ap_text):
            return data_lib.Note(note_id=note_id,
                                 text="blablabla\n" + ap_text,
                                 subject_id=note_id,
                                 category="PHYSICIAN")

        section_markers = {
            "hpi": ["history of present illness"],
            "a&p": ["assessment and plan"],
        }
        ap_texts = ["a&p:\n # dm2:\n-RISS", "a&p:\n # COPD:\n-nebs"]

        rated_note = ("0", {
            "notes": [physician_note(0, ap_texts[0])],
            "ratings": [[
                char_span(span_type.PROBLEM_TITLE, 19, 22),
                char_span(span_type.ACTION_ITEM, 24, 28,
                          action_item_type=medications),
            ]],
            "note_partition": ["val"],
        })
        unrated_note = ("1", {
            "notes": [physician_note(1, ap_texts[1])],
            "ratings": [],
            "note_partition": [],
        })
        notes_with_ratings = [rated_note, unrated_note]

        expected_rated = (
            "0|10",
            data_lib.APData(
                partition=data_lib.Partition.VAL,
                note_id="0",
                subject_id="0",
                ap_text=ap_texts[0],
                char_offset=10,
                tokens=tokenizer_lib.tokenize(ap_texts[0]),
                labeled_char_spans=[
                    char_span(span_type.PROBLEM_TITLE, 8, 11),
                    char_span(span_type.ACTION_ITEM, 14, 18,
                              action_item_type=medications),
                ]))
        expected_unrated = (
            "1|10",
            data_lib.APData(
                partition=data_lib.Partition.NONRATED,
                note_id="1",
                subject_id="1",
                ap_text=ap_texts[1],
                char_offset=10,
                tokens=tokenizer_lib.tokenize(ap_texts[1]),
                labeled_char_spans=[
                    char_span(span_type.PROBLEM_TITLE, 8, 12),
                    char_span(span_type.ACTION_ITEM, 15, 19),
                ]))
        expected = [expected_rated, expected_unrated]

        with test_pipeline.TestPipeline() as p:
            process_ap = beam.ParDo(
                data_lib.ProcessAPData(filter_inorganic_threshold=0),
                section_markers)
            results = p | beam.Create(notes_with_ratings) | process_ap
            util.assert_that(results, util.equal_to(expected))
Example #18
0
  def test_evaluate(self):
    # Compares the metrics from evaluate_from_labeled_token_spans against
    # hand-computed values for one short note.
    # Char space:
    #       0         1          2          3
    #       01234567890123456 78901234 56789012345
    text = "# DM2: on insulin\n # COPD\n- nebs prn"
    # Token:012 3456 78       90123    4567   89
    #       0                  1
    tokens = tokenizer_lib.tokenize(text)

    span_type = ap_parsing_lib.LabeledSpanType
    medications = ap_parsing_lib.ActionItemType.MEDICATIONS

    def token_span(kind, start, end, **extra):
      # `extra` forwards optional fields such as action_item_type.
      return ap_parsing_lib.LabeledTokenSpan(
          span_type=kind, start_token=start, end_token=end, **extra)

    def metric_group(prefix, precision, recall, f1, jaccard, tp, total_true,
                     total_pred):
      # Expands one "<level>/<label>" prefix into its seven metric entries.
      return {
          f"{prefix}/precision": precision,
          f"{prefix}/recall": recall,
          f"{prefix}/f1": f1,
          f"{prefix}/jaccard": jaccard,
          f"{prefix}/tp": tp,
          f"{prefix}/total_true": total_true,
          f"{prefix}/total_pred": total_pred,
      }

    truth_token_spans = [
        token_span(span_type.PROBLEM_TITLE, 2, 4),  # "DM2"
        token_span(span_type.PROBLEM_DESCRIPTION, 6, 9),  # "on insulin"
        token_span(span_type.PROBLEM_TITLE, 13, 14),  # "COPD"
        token_span(span_type.ACTION_ITEM, 17, 20,
                   action_item_type=medications),  # "nebs prn"
    ]

    predicted_token_spans = [
        token_span(span_type.PROBLEM_TITLE, 2, 9),  # "DM2: on insulin"
        token_span(span_type.PROBLEM_TITLE, 13, 14),  # "COPD"
        token_span(span_type.ACTION_ITEM, 17, 18,
                   action_item_type=medications),  # "nebs"
        token_span(span_type.ACTION_ITEM, 18, 20,
                   action_item_type=medications),  # " prn"
    ]

    conf_mat = np.zeros((9, 9))
    conf_mat[1, 1] = 1

    expected_metrics = {
        **metric_group("span_relaxed/PROBLEM_TITLE",
                       precision=1, recall=1, f1=1, jaccard=1,
                       tp=2, total_true=2, total_pred=2),
        **metric_group("token_relaxed/PROBLEM_TITLE",
                       precision=0.5, recall=1, f1=2 / 3, jaccard=0.5,
                       tp=3, total_true=3, total_pred=6),
        **metric_group("span_relaxed/PROBLEM_DESCRIPTION",
                       precision=np.nan, recall=0, f1=np.nan, jaccard=0,
                       tp=0, total_true=1, total_pred=0),
        **metric_group("token_relaxed/PROBLEM_DESCRIPTION",
                       precision=np.nan, recall=0, f1=np.nan, jaccard=0,
                       tp=0, total_true=2, total_pred=0),
        **metric_group("span_relaxed/ACTION_ITEM",
                       precision=0.5, recall=1, f1=2 / 3, jaccard=0.5,
                       tp=1, total_true=1, total_pred=2),
        **metric_group("token_relaxed/ACTION_ITEM",
                       precision=1, recall=1, f1=1, jaccard=1,
                       tp=2, total_true=2, total_pred=2),
        # action_item_type/MEDICATIONS
        "action_item_type/MEDICATIONS/f1": 2 / 3,
        "action_item_type/MEDICATIONS/jaccard": 0.5,
        "action_item_type/MEDICATIONS/precision": 0.5,
        "action_item_type/MEDICATIONS/recall": 1,
        "action_item_type/MEDICATIONS/tp": 1,
        "action_item_type/MEDICATIONS/total_true": 1,
        "action_item_type/MEDICATIONS/total_pred": 2,
        "action_item_type/ALL/confusion_matrix": conf_mat,
    }
    calculated_metrics = eval_lib.evaluate_from_labeled_token_spans(
        truth_token_spans, predicted_token_spans, tokens=tokens)

    for name, want in expected_metrics.items():
      self.assertIn(name, calculated_metrics)
      got = calculated_metrics[name]
      # equal_nan lets the NaN precision/f1 expectations compare as equal.
      self.assertTrue(
          np.allclose(want, got, equal_nan=True),
          msg=f"{name} - expected: {want} got {got}")