Exemple #1
0
def test_sequence_max_label_index_decoder():
    """
    Test SequenceMaxLabelIndexDecoder on a batch of sequence logits.

    The decoder output (converted with ``tolist()``) must equal the label
    indices of the expected label sequences listed below.
    """
    vocab = LabelVocabulary(
        [["B-T", "B-T", "B-T", "I-T", "I-T", "O"]],
        padding=LabelVocabulary.PADDING)

    # Pin the label -> index mapping that the logits columns rely on:
    # column 0 is B-T, column 1 is I-T, column 2 is O.
    ASSERT.assertEqual(0, vocab.index("B-T"))
    ASSERT.assertEqual(1, vocab.index("I-T"))
    ASSERT.assertEqual(2, vocab.index("O"))

    # Per-position maxima spell out:
    # [[O, B, I], [B, B, I], [B, I, I], [B, I, O]]
    logits = torch.tensor(
        [
            [[0.2, 0.3, 0.4], [0.7, 0.2, 0.3], [0.2, 0.3, 0.1]],
            [[0.8, 0.3, 0.4], [0.7, 0.2, 0.3], [0.2, 0.3, 0.1]],
            [[0.8, 0.3, 0.4], [0.1, 0.7, 0.3], [0.2, 0.3, 0.1]],
            [[0.8, 0.3, 0.4], [0.1, 0.7, 0.3], [0.2, 0.3, 0.5]],
        ],
        dtype=torch.float)

    expected_labels = [
        ["O", "B-T", "I-T"],
        ["B-T", "B-T", "I-T"],
        ["B-T", "I-T", "I-T"],
        ["B-T", "I-T", "O"],
    ]

    # Translate the readable labels into the indices the decoder returns.
    expected_indices = [[vocab.index(label) for label in labels]
                        for labels in expected_labels]

    decoder = SequenceMaxLabelIndexDecoder(label_vocabulary=vocab)
    decoded = decoder(logits=logits, mask=None)

    ASSERT.assertEqual(expected_indices, decoded.tolist())
Exemple #2
0
def test_decode_decode_label_index_to_span():
    """
    Test decoding golden label indices into spans.
    :return:
    """

    label_vocab = LabelVocabulary(
        [["B-T", "B-T", "B-T", "I-T", "I-T", "O"]],
        padding=LabelVocabulary.PADDING)

    # Pin the label -> index mapping used by the index tensor below.
    ASSERT.assertEqual(0, label_vocab.index("B-T"))
    ASSERT.assertEqual(1, label_vocab.index("I-T"))
    ASSERT.assertEqual(2, label_vocab.index("O"))

    # With the mapping above the rows read: B I O B  and  O B I I.
    label_index_batch = torch.tensor([
        [0, 1, 2, 0],
        [2, 0, 1, 1],
    ])

    # One list of spans per sequence in the batch.
    expected_spans = [
        [
            {"label": "T", "begin": 0, "end": 2},
            {"label": "T", "begin": 3, "end": 4},
        ],
        [
            {"label": "T", "begin": 1, "end": 4},
        ],
    ]

    decoded_spans = BIO.decode_label_index_to_span(
        batch_sequence_label_index=label_index_batch,
        mask=None,
        vocabulary=label_vocab)

    ASSERT.assertListEqual(expected_spans, decoded_spans)
Exemple #3
0
def test_decode():
    """
    Test decoding the batch logits output by a model into spans.
    :return:
    """

    label_vocab = LabelVocabulary(
        [["B-T", "B-T", "B-T", "I-T", "I-T", "O"]],
        padding=LabelVocabulary.PADDING)

    # Pin the label -> index mapping: logits column 0 is B-T, 1 is I-T, 2 is O.
    ASSERT.assertEqual(0, label_vocab.index("B-T"))
    ASSERT.assertEqual(1, label_vocab.index("I-T"))
    ASSERT.assertEqual(2, label_vocab.index("O"))

    # Per-position maxima spell out:
    # [[O, B, I], [B, B, I], [B, I, I], [B, I, O]]
    logits = torch.tensor(
        [
            [[0.2, 0.3, 0.4], [0.7, 0.2, 0.3], [0.2, 0.3, 0.1]],
            [[0.8, 0.3, 0.4], [0.7, 0.2, 0.3], [0.2, 0.3, 0.1]],
            [[0.8, 0.3, 0.4], [0.1, 0.7, 0.3], [0.2, 0.3, 0.1]],
            [[0.8, 0.3, 0.4], [0.1, 0.7, 0.3], [0.2, 0.3, 0.5]],
        ],
        dtype=torch.float)

    # Expected spans, one list per sequence in the batch.
    expected_spans = [
        [{"label": "T", "begin": 1, "end": 3}],
        [{"label": "T", "begin": 0, "end": 1},
         {"label": "T", "begin": 1, "end": 3}],
        [{"label": "T", "begin": 0, "end": 3}],
        [{"label": "T", "begin": 0, "end": 2}],
    ]

    decoded_spans = BIO.decode(
        batch_sequence_logits=logits,
        mask=None,
        vocabulary=label_vocab)

    ASSERT.assertListEqual(expected_spans, decoded_spans)
Exemple #4
0
def test_allowed_transitions():
    """
    Test the allowed transition pairs produced by BIO.allowed_transitions.
    :return:
    """

    label_vocabulary = LabelVocabulary(
        labels=[["B-L1", "I-L1", "B-L2", "I-L2", "O"]],
        padding=LabelVocabulary.PADDING)

    allowed_pairs = BIO.allowed_transitions(label_vocabulary=label_vocabulary)

    # This test represents the virtual START tag as index ``label_size`` and
    # the virtual STOP tag as index ``label_size + 1``.
    start_idx = label_vocabulary.label_size
    stop_idx = label_vocabulary.label_size + 1

    # Dump every returned pair in readable form (handy when the assertion
    # below fails and the expected table has to be compared by eye).
    for src_idx, dst_idx in allowed_pairs:
        from_label = ("START" if src_idx == start_idx
                      else label_vocabulary.token(src_idx))
        to_label = ("STOP" if dst_idx == stop_idx
                    else label_vocabulary.token(dst_idx))
        print(f"(\"{from_label}\", \"{to_label}\"),")

    expected_transition_labels = [
        ("B-L1", "B-L1"), ("B-L1", "I-L1"), ("B-L1", "B-L2"), ("B-L1", "O"), ("B-L1", "STOP"),
        ("I-L1", "B-L1"), ("I-L1", "I-L1"), ("I-L1", "B-L2"), ("I-L1", "O"), ("I-L1", "STOP"),
        ("B-L2", "B-L1"), ("B-L2", "B-L2"), ("B-L2", "I-L2"), ("B-L2", "O"), ("B-L2", "STOP"),
        ("I-L2", "B-L1"), ("I-L2", "B-L2"), ("I-L2", "I-L2"), ("I-L2", "O"), ("I-L2", "STOP"),
        ("O", "B-L1"), ("O", "B-L2"), ("O", "O"), ("O", "STOP"),
        ("START", "B-L1"), ("START", "B-L2"), ("START", "O"),
    ]

    # Convert the readable table into (from_idx, to_idx) pairs using the same
    # START / STOP index convention as above.
    expected_pairs = {
        (start_idx if src == "START" else label_vocabulary.index(src),
         stop_idx if dst == "STOP" else label_vocabulary.index(dst))
        for src, dst in expected_transition_labels
    }

    # Order is irrelevant, so compare as sets.
    ASSERT.assertSetEqual(expected_pairs, set(allowed_pairs))
Exemple #5
0
def test_decode_one_sequence_logits_to_label():
    """
    Test decoding the logits of a single sequence into labels.
    :return:
    """

    # Each case pairs the logits of one sequence with the labels expected
    # from decoding it.
    cases = [
        # O B I: the regular case
        (torch.tensor([[0.2, 0.3, 0.4], [0.7, 0.2, 0.3], [0.2, 0.3, 0.1]],
                      dtype=torch.float),
         ["O", "B-T", "I-T"]),
        (torch.tensor([[0.9, 0.3, 0.4], [0.2, 0.8, 0.3], [0.2, 0.3, 0.1]],
                      dtype=torch.float),
         ["B-T", "I-T", "I-T"]),
        (torch.tensor([[0.9, 0.3, 0.4], [0.2, 0.8, 0.3], [0.2, 0.3, 0.9]],
                      dtype=torch.float),
         ["B-T", "I-T", "O"]),
    ]

    label_vocab = LabelVocabulary(
        [["B-T", "B-T", "B-T", "I-T", "I-T", "O"]],
        padding=LabelVocabulary.PADDING)

    # Pin the label -> index mapping: logits column 0 is B-T, 1 is I-T, 2 is O.
    ASSERT.assertEqual(0, label_vocab.index("B-T"))
    ASSERT.assertEqual(1, label_vocab.index("I-T"))
    ASSERT.assertEqual(2, label_vocab.index("O"))

    for logits, expected_labels in cases:
        decoded_labels, decoded_indices = BIO.decode_one_sequence_logits_to_label(
            sequence_logits=logits,
            vocabulary=label_vocab)

        ASSERT.assertListEqual(decoded_labels, expected_labels)

        # The returned indices must be the vocabulary indices of the labels.
        expected_indices = [label_vocab.index(label)
                            for label in expected_labels]
        ASSERT.assertListEqual(decoded_indices, expected_indices)
def test_label_vocabulary():
    """
    Test LabelVocabulary sizes and index assignment.
    :return:
    """
    # With an empty padding string the size is just the five distinct labels.
    without_padding = LabelVocabulary([["A", "B", "C"], ["D", "E"]],
                                      padding="")
    ASSERT.assertEqual(without_padding.size, 5)

    # With the PADDING token, ``size`` counts one extra entry while
    # ``label_size`` still counts only the five labels.
    with_padding = LabelVocabulary(
        [["A", "B", "C"], ["D", "E"]],
        padding=LabelVocabulary.PADDING)
    ASSERT.assertEqual(with_padding.size, 6)
    ASSERT.assertEqual(with_padding.label_size, 5)

    # The padding token is expected at index 5, after all labels.
    ASSERT.assertEqual(with_padding.index(with_padding.padding), 5)

    # Labels are expected to be indexed in order of first appearance.
    for expected_index, label in enumerate("ABCDE"):
        ASSERT.assertEqual(with_padding.index(label), expected_index)
Exemple #7
0
def test_decode_one_sequence_logits_to_label_abnormal():
    """
    Test abnormal cases, where the plain argmax labels do not form a valid
    BIO sequence and the decoder is expected to repair them.
    :return:
    """

    cases = [
        # First position [0.2, 0.5, 0.4]: argmax gives I, which is abnormal at
        # the start of a sequence (the raw argmax sequence is I B I).
        # The decoder is expected to output O (score 0.4) there instead.
        (torch.tensor([[0.2, 0.5, 0.4], [0.7, 0.2, 0.3], [0.2, 0.3, 0.1]],
                      dtype=torch.float),
         ["O", "B-T", "I-T"]),
        # Argmax gives I I I; after correction the expected result is O O B.
        (torch.tensor([[0.2, 0.5, 0.4], [0.2, 0.9, 0.3], [0.2, 0.3, 0.1]],
                      dtype=torch.float),
         ["O", "O", "B-T"]),
    ]

    label_vocab = LabelVocabulary(
        [["B-T", "B-T", "B-T", "I-T", "I-T", "O"]],
        padding=LabelVocabulary.PADDING)

    # Pin the label -> index mapping: logits column 0 is B-T, 1 is I-T, 2 is O.
    ASSERT.assertEqual(0, label_vocab.index("B-T"))
    ASSERT.assertEqual(1, label_vocab.index("I-T"))
    ASSERT.assertEqual(2, label_vocab.index("O"))

    for logits, expected_labels in cases:
        decoded_labels, decoded_indices = BIO.decode_one_sequence_logits_to_label(
            sequence_logits=logits,
            vocabulary=label_vocab)

        expected_indices = [label_vocab.index(label)
                            for label in expected_labels]
        ASSERT.assertListEqual(expected_labels, decoded_labels)
        ASSERT.assertListEqual(expected_indices, decoded_indices)
def test_sequence_label_decoder():
    """
    Test SequenceLabelDecoder: padded label indices plus a mask go in,
    one list of spans per sequence comes out.
    :return:
    """
    # (label sequence, spans expected from decoding it)
    cases = [
        (["B-T", "I-T", "O-T"],
         [{"label": "T", "begin": 0, "end": 2}]),
        (["B-T", "I-T", "I-T"],
         [{"label": "T", "begin": 0, "end": 3}]),
        (["B-T", "I-T", "I-T", "B-T"],
         [{"label": "T", "begin": 0, "end": 3},
          {"label": "T", "begin": 3, "end": 4}]),
    ]
    sequence_label_list = [labels for labels, _ in cases]
    expect_spans = [spans for _, spans in cases]

    label_vocabulary = LabelVocabulary(sequence_label_list,
                                       padding=LabelVocabulary.PADDING)

    # Every sequence is right-padded to this length with the padding index;
    # the mask holds 1 for real positions and 0 for padded ones.
    max_sequence_len = 4
    padded_index_rows = []
    mask_rows = []

    for labels in sequence_label_list:
        indices = [label_vocabulary.index(label) for label in labels]
        pad_len = max_sequence_len - len(indices)

        mask_rows.append([1] * len(indices) + [0] * pad_len)
        padded_index_rows.append(
            indices + [label_vocabulary.padding_index] * pad_len)

    label_index_batch = torch.tensor(padded_index_rows, dtype=torch.long)
    mask = torch.tensor(mask_rows, dtype=torch.uint8)

    decoder = SequenceLabelDecoder(label_vocabulary=label_vocabulary)
    decoded_spans = decoder(label_indices=label_index_batch, mask=mask)

    ASSERT.assertListEqual(expect_spans, decoded_spans)