def test_pretrained_vocabulary(pretrained_vocabulary):
    """
    Check the size, token indices and embedding matrix of a pretrained vocabulary.

    :param pretrained_vocabulary: the vocabulary object under test
    """
    embedding_dim = 3
    vocab_size = pretrained_vocabulary.size

    # Both ways of asking for the size must report four entries.
    ASSERT.assertEqual(4, vocab_size)
    ASSERT.assertEqual(4, len(pretrained_vocabulary))

    ASSERT.assertEqual(2, pretrained_vocabulary.index("我"))
    beautiful_index = pretrained_vocabulary.index("美丽")
    ASSERT.assertEqual(3, beautiful_index)

    embedding_matrix = pretrained_vocabulary.embedding_matrix
    ASSERT.assertEqual((vocab_size, embedding_dim), embedding_matrix.size())

    # Only the "美丽" entry is compared against the matrix below.
    expect_embedding_dict = {
        "a": [1.0, 2.0, 3.0],
        "b": [4.0, 5.0, 6.0],
        "美丽": [7.0, 8.0, 9.0],
    }
    ASSERT.assertListEqual(expect_embedding_dict["美丽"],
                           embedding_matrix[beautiful_index].tolist())

    # These three rows are expected to be all zeros.
    zero_vec = [0.] * embedding_dim
    zero_row_indices = [
        pretrained_vocabulary.index("我"),
        pretrained_vocabulary.padding_index,
        pretrained_vocabulary.index(pretrained_vocabulary.unk),
    ]

    for index in zero_row_indices:
        ASSERT.assertListEqual(zero_vec, embedding_matrix[index].tolist())
Exemple #2
0
    def __call__(self, instances: Iterable[Instance]) -> ModelInputs:
        """
        Collate instances into one batch of model inputs with derived labels.

        The "x" value of every instance becomes one row of a float tensor.
        The label of an instance is 1 when its "x" value is greater than 50
        and 0 otherwise.

        :param instances: instances that can be read with instance["x"]
        :return: ModelInputs holding the batch size, the "x" tensor and the labels
        """
        # Read the iterable exactly once so single-pass iterables keep working.
        values = [instance["x"] for instance in instances]

        features = torch.stack(
            [torch.tensor([value], dtype=torch.float) for value in values])
        label_values = [1 if value - 50 > 0 else 0 for value in values]

        # The stacked features must have shape (batch_size, 1).
        batch_size = features.size(0)
        ASSERT.assertEqual(features.dim(), 2)
        ASSERT.assertListEqual([batch_size, 1],
                               [features.size(0), features.size(1)])

        # The labels must be one-dimensional with one entry per instance.
        targets = torch.tensor(label_values)
        ASSERT.assertEqual(targets.dim(), 1)
        ASSERT.assertEqual(batch_size, targets.size(0))

        return ModelInputs(batch_size=batch_size,
                           model_inputs={"x": features},
                           labels=targets)
Exemple #3
0
def test_decode_decode_label_index_to_span():
    """
    Decode a batch of golden label indices into spans.
    """
    label_vocabulary = LabelVocabulary(
        [["B-T", "B-T", "B-T", "I-T", "I-T", "O"]],
        padding=LabelVocabulary.PADDING)

    # The index tensor below relies on this label -> index mapping.
    for expected_index, label in enumerate(("B-T", "I-T", "O")):
        ASSERT.assertEqual(expected_index, label_vocabulary.index(label))

    # Row 0 reads B I O B, row 1 reads O B I I.
    golden_labels = torch.tensor([[0, 1, 2, 0],
                                  [2, 0, 1, 1]])

    first_row_spans = [{"label": "T", "begin": 0, "end": 2},
                       {"label": "T", "begin": 3, "end": 4}]
    second_row_spans = [{"label": "T", "begin": 1, "end": 4}]

    decoded = BIO.decode_label_index_to_span(
        batch_sequence_label_index=golden_labels,
        mask=None,
        vocabulary=label_vocabulary)

    ASSERT.assertListEqual([first_row_spans, second_row_spans], decoded)
Exemple #4
0
def test_decode():
    """
    Decode a batch of model output logits into spans.
    """
    label_vocabulary = LabelVocabulary(
        [["B-T", "B-T", "B-T", "I-T", "I-T", "O"]],
        padding=LabelVocabulary.PADDING)

    # The logits columns are read with this label -> index mapping.
    for expected_index, label in enumerate(("B-T", "I-T", "O")):
        ASSERT.assertEqual(expected_index, label_vocabulary.index(label))

    # Per-step argmax of each sequence: [O, B, I], [B, B, I], [B, I, I], [B, I, O]
    logits_rows = [
        [[0.2, 0.3, 0.4], [0.7, 0.2, 0.3], [0.2, 0.3, 0.1]],
        [[0.8, 0.3, 0.4], [0.7, 0.2, 0.3], [0.2, 0.3, 0.1]],
        [[0.8, 0.3, 0.4], [0.1, 0.7, 0.3], [0.2, 0.3, 0.1]],
        [[0.8, 0.3, 0.4], [0.1, 0.7, 0.3], [0.2, 0.3, 0.5]],
    ]
    batch_sequence_logits = torch.tensor(logits_rows, dtype=torch.float)

    expected_spans = [
        [{"label": "T", "begin": 1, "end": 3}],
        [{"label": "T", "begin": 0, "end": 1},
         {"label": "T", "begin": 1, "end": 3}],
        [{"label": "T", "begin": 0, "end": 3}],
        [{"label": "T", "begin": 0, "end": 2}],
    ]

    decoded = BIO.decode(batch_sequence_logits=batch_sequence_logits,
                         mask=None,
                         vocabulary=label_vocabulary)

    ASSERT.assertListEqual(expected_spans, decoded)
Exemple #5
0
def test_decode_one_sequence_label_to_span():
    """
    Decode single label sequences into lists of span dictionaries.
    """
    # Each case pairs a label sequence with the spans expected from it.
    cases = [
        (["B-T", "I-T", "O-T"],
         [{"label": "T", "begin": 0, "end": 2}]),
        (["B-T", "I-T", "I-T"],
         [{"label": "T", "begin": 0, "end": 3}]),
        (["B-T", "I-T", "I-T", "B-T"],
         [{"label": "T", "begin": 0, "end": 3},
          {"label": "T", "begin": 3, "end": 4}]),
    ]

    for sequence_label, expected_spans in cases:
        decoded = BIO.decode_one_sequence_label_to_span(sequence_label)
        ASSERT.assertListEqual(expected_spans, decoded)
Exemple #6
0
def test_bmes_to_bio():
    """
    Convert a BMES label sequence to the BIO schema.
    """
    source_labels = ["B-T", "M-T", "E-T", "O", "S-T", "B-T", "E-T"]
    converted = bio_schema.bmes_to_bio(source_labels)

    # Expected: M and E map to I, S maps to B, while B and O stay as they are.
    ASSERT.assertListEqual(["B-T", "I-T", "I-T", "O", "B-T", "B-T", "I-T"],
                           converted)
def test_crf_label_index_decoder_with_constraint(crf_data):
    """
    Decode label indices using the constrained CRF.

    :param crf_data: object providing constraint_crf, label_vocabulary and logits
    """
    label_vocabulary = crf_data.label_vocabulary
    decoder = CRFLabelIndexDecoder(crf=crf_data.constraint_crf,
                                   label_vocabulary=label_vocabulary)

    # The last step of the second sequence is masked out.
    mask = torch.tensor([[1, 1, 1], [1, 1, 0]], dtype=torch.uint8)
    decoded = decoder(logits=crf_data.logits, mask=mask)

    # The masked position is expected to hold the padding index.
    pad = label_vocabulary.padding_index
    ASSERT.assertListEqual([[2, 3, 3], [2, 3, pad]], decoded.tolist())
def test_max_label_index_decoder():
    """
    Check the label indices produced by MaxLabelIndexDecoder.
    """
    scores = torch.tensor([[0.1, 0.9],
                           [0.3, 0.7],
                           [0.8, 0.2]])

    predicted = MaxLabelIndexDecoder()(logits=scores)

    # Expected: the position of the largest score in each row.
    ASSERT.assertListEqual([1, 1, 0], predicted.tolist())
def test_crf_label_index_decoder(crf_data):
    """
    Decode label indices using the CRF label index decoder.

    :param crf_data: object providing crf, label_vocabulary and logits
    """
    label_vocabulary = crf_data.label_vocabulary
    decoder = CRFLabelIndexDecoder(crf=crf_data.crf,
                                   label_vocabulary=label_vocabulary)

    # The last step of the second sequence is masked out.
    mask = torch.tensor([[1, 1, 1], [1, 1, 0]], dtype=torch.long)
    decoded = decoder(logits=crf_data.logits, mask=mask)

    # The masked position is expected to hold the padding index.
    pad = label_vocabulary.padding_index
    ASSERT.assertListEqual([[2, 4, 3], [4, 2, pad]], decoded.tolist())
Exemple #10
0
def test_decode_one_sequence_logits_to_label():
    """
    Decode the logits of single sequences into labels and label indices.
    """
    # Each case pairs per-step logits with the labels expected from them.
    cases = [
        # Per-step argmax: O B I.
        ([[0.2, 0.3, 0.4], [0.7, 0.2, 0.3], [0.2, 0.3, 0.1]],
         ["O", "B-T", "I-T"]),
        # Per-step argmax: B I I.
        ([[0.9, 0.3, 0.4], [0.2, 0.8, 0.3], [0.2, 0.3, 0.1]],
         ["B-T", "I-T", "I-T"]),
        # Per-step argmax: B I O.
        ([[0.9, 0.3, 0.4], [0.2, 0.8, 0.3], [0.2, 0.3, 0.9]],
         ["B-T", "I-T", "O"]),
    ]

    label_vocabulary = LabelVocabulary(
        [["B-T", "B-T", "B-T", "I-T", "I-T", "O"]],
        padding=LabelVocabulary.PADDING)

    # The logits columns are read with this label -> index mapping.
    for expected_index, label in enumerate(("B-T", "I-T", "O")):
        ASSERT.assertEqual(expected_index, label_vocabulary.index(label))

    for logits_rows, expected_labels in cases:
        decoded_labels, decoded_indices = BIO.decode_one_sequence_logits_to_label(
            sequence_logits=torch.tensor(logits_rows, dtype=torch.float),
            vocabulary=label_vocabulary)

        ASSERT.assertListEqual(decoded_labels, expected_labels)
        ASSERT.assertListEqual(
            decoded_indices,
            [label_vocabulary.index(label) for label in expected_labels])
Exemple #11
0
def test_ibo1_to_bio():
    """
    Convert an IBO1 label sequence to the BIO schema.
    """
    source_labels = [
        "I-L1", "I-L1", "O",
        "I-L1", "I-L2", "O",
        "I-L1", "I-L1", "I-L1",
        "B-L1", "I-L1", "O",
        "B-L1", "I-L1", "O",
    ]

    # Same positions as above, with the labels expected after conversion.
    expected_labels = [
        "B-L1", "I-L1", "O",
        "B-L1", "B-L2", "O",
        "B-L1", "I-L1", "I-L1",
        "B-L1", "I-L1", "O",
        "B-L1", "I-L1", "O",
    ]

    converted = bio_schema.ibo1_to_bio(source_labels)

    ASSERT.assertListEqual(expected_labels, converted)
Exemple #12
0
def test_synchronized_data():
    """
    Round-trip the counters of _DemoF1Metric through to_synchronized_data and
    from_synchronized_data.

    The synchronized tensors must list the counter values in dictionary order,
    and loading the same data back must leave all three counters unchanged.
    """

    demo_metric = _DemoF1Metric()

    sync_data, op = demo_metric.to_synchronized_data()

    # All three counters are read through their public accessors, including
    # false_negatives, so the test does not depend on a private attribute name.
    counters = {
        "true_positives": demo_metric.true_positives,
        "false_positives": demo_metric.false_positives,
        "false_negatives": demo_metric.false_negatives,
    }

    for key, counter in counters.items():
        ASSERT.assertListEqual(list(counter.values()), sync_data[key].tolist())

    # Copy the counters before loading the data back, so the comparison below
    # is made against the values from before the call.
    expect_true_positives = dict(demo_metric.true_positives)
    expect_false_positives = dict(demo_metric.false_positives)
    expect_false_negatives = dict(demo_metric.false_negatives)

    demo_metric.from_synchronized_data(sync_data=sync_data, reduce_op=op)

    ASSERT.assertDictEqual(expect_true_positives, demo_metric.true_positives)
    ASSERT.assertDictEqual(expect_false_positives, demo_metric.false_positives)
    ASSERT.assertDictEqual(expect_false_negatives, demo_metric.false_negatives)
Exemple #13
0
def test_span_intersection():
    """
    Check BIO.span_intersection for a shared span, disjoint spans and an empty list.
    """

    def build_first_spans():
        # A fresh list per case, so no case can see changes made by another.
        return [{"label": "T", "begin": 0, "end": 2},
                {"label": "T", "begin": 3, "end": 4}]

    # Each case: (span_list2, expected intersection with the first span list).
    cases = [
        ([{"label": "T", "begin": 0, "end": 2}],
         [{"label": "T", "begin": 0, "end": 2}]),
        ([{"label": "T", "begin": 9, "end": 10}], []),
        ([], []),
    ]

    for second_spans, expected in cases:
        intersection = BIO.span_intersection(span_list1=build_first_spans(),
                                             span_list2=second_spans)
        ASSERT.assertListEqual(expected, intersection)
Exemple #14
0
def test_synchronized_data():
    """
    Check AccMetric.to_synchronized_data and AccMetric.from_synchronized_data.
    """
    metric = AccMetric()

    # Before any update the synchronized data holds two zeros.
    initial_data, _ = metric.to_synchronized_data()
    ASSERT.assertEqual((2,), initial_data.size())
    for position in (0, 1):
        ASSERT.assertEqual(0, initial_data[position].item())

    # The argmax of the logits gives the predictions [1, 1, 0, 1]; only the
    # second one matches the golden labels, so the accuracy is 1/4.
    logits = torch.tensor([[1., 2.], [3., 4.], [5, 4.], [3., 7.]],
                          dtype=torch.float)
    golden_labels = torch.tensor([0, 1, 1, 0], dtype=torch.long)

    metric(prediction_labels=torch.argmax(logits, dim=-1),
           gold_labels=golden_labels,
           mask=None)

    sync_data, op = metric.to_synchronized_data()
    ASSERT.assertListEqual([1, 4], sync_data.tolist())

    # Loading the same data back must keep the accuracy at 1/4.
    metric.from_synchronized_data(sync_data=sync_data, reduce_op=op)
    ASSERT.assertAlmostEqual(1/4, metric.metric[AccMetric.ACC])

    # Exporting again must reproduce the data that was just loaded.
    exported_again, op = metric.to_synchronized_data()
    ASSERT.assertListEqual(sync_data.tolist(), exported_again.tolist())
Exemple #15
0
def test_decode_one_sequence_logits_to_label_abnormal():
    """
    Decode logits whose per-step argmax does not start with B or O.
    """
    label_vocabulary = LabelVocabulary(
        [["B-T", "B-T", "B-T", "I-T", "I-T", "O"]],
        padding=LabelVocabulary.PADDING)

    # The logits columns are read with this label -> index mapping.
    for expected_index, label in enumerate(("B-T", "I-T", "O")):
        ASSERT.assertEqual(expected_index, label_vocabulary.index(label))

    cases = [
        # Per-step argmax is I B I. The first step is expected to come back as
        # O, the label with its second-highest score (0.4).
        ([[0.2, 0.5, 0.4], [0.7, 0.2, 0.3], [0.2, 0.3, 0.1]],
         ["O", "B-T", "I-T"]),
        # Per-step argmax is I I I. The expected corrected labels are O O B.
        ([[0.2, 0.5, 0.4], [0.2, 0.9, 0.3], [0.2, 0.3, 0.1]],
         ["O", "O", "B-T"]),
    ]

    for logits_rows, expected_labels in cases:
        decoded_labels, decoded_indices = BIO.decode_one_sequence_logits_to_label(
            sequence_logits=torch.tensor(logits_rows, dtype=torch.float),
            vocabulary=label_vocabulary)

        expected_indices = [label_vocabulary.index(label)
                            for label in expected_labels]
        ASSERT.assertListEqual(expected_labels, decoded_labels)
        ASSERT.assertListEqual(expected_indices, decoded_indices)
def test_sequence_label_decoder():
    """
    Decode padded batches of label indices into spans with SequenceLabelDecoder.
    """
    # Each case pairs a label sequence with the spans expected from it.
    cases = [
        (["B-T", "I-T", "O-T"],
         [{"label": "T", "begin": 0, "end": 2}]),
        (["B-T", "I-T", "I-T"],
         [{"label": "T", "begin": 0, "end": 3}]),
        (["B-T", "I-T", "I-T", "B-T"],
         [{"label": "T", "begin": 0, "end": 3},
          {"label": "T", "begin": 3, "end": 4}]),
    ]
    sequence_label_list = [labels for labels, _ in cases]
    expect_spans = [spans for _, spans in cases]

    label_vocabulary = LabelVocabulary(sequence_label_list,
                                       padding=LabelVocabulary.PADDING)

    # Pad every sequence to the same length. The mask is 1 on real labels and
    # 0 on the padded positions.
    max_sequence_len = 4
    padded_indices = []
    mask_rows = []
    for sequence_labels in sequence_label_list:
        indices = [label_vocabulary.index(label) for label in sequence_labels]
        pad_len = max_sequence_len - len(indices)

        mask_rows.append([1] * len(indices) + [0] * pad_len)
        padded_indices.append(
            indices + [label_vocabulary.padding_index] * pad_len)

    label_index_tensor = torch.tensor(padded_indices, dtype=torch.long)
    mask = torch.tensor(mask_rows, dtype=torch.uint8)

    decoder = SequenceLabelDecoder(label_vocabulary=label_vocabulary)
    decoded = decoder(label_indices=label_index_tensor, mask=mask)

    ASSERT.assertListEqual(expect_spans, decoded)