def test_sequence_max_label_index_decoder():
    """
    Test that SequenceMaxLabelIndexDecoder returns, for every position of every
    sequence in the batch, the index of the label with the largest logit.
    """
    label_vocabulary = LabelVocabulary(
        [["B-T", "B-T", "B-T", "I-T", "I-T", "O"]],
        padding=LabelVocabulary.PADDING)

    # The logit columns below are ordered by these indices: B-T, I-T, O.
    for expected_index, label in enumerate(("B-T", "I-T", "O")):
        ASSERT.assertEqual(expected_index, label_vocabulary.index(label))

    # Per-position argmax of each sequence:
    # [O, B, I], [B, B, I], [B, I, I], [B, I, O]
    batch_sequence_logits = torch.tensor(
        [
            [[0.2, 0.3, 0.4], [0.7, 0.2, 0.3], [0.2, 0.3, 0.1]],
            [[0.8, 0.3, 0.4], [0.7, 0.2, 0.3], [0.2, 0.3, 0.1]],
            [[0.8, 0.3, 0.4], [0.1, 0.7, 0.3], [0.2, 0.3, 0.1]],
            [[0.8, 0.3, 0.4], [0.1, 0.7, 0.3], [0.2, 0.3, 0.5]],
        ],
        dtype=torch.float,
    )

    expected_label_rows = [
        ["O", "B-T", "I-T"],
        ["B-T", "B-T", "I-T"],
        ["B-T", "I-T", "I-T"],
        ["B-T", "I-T", "O"],
    ]
    expect = [
        [label_vocabulary.index(label) for label in row]
        for row in expected_label_rows
    ]

    decoder = SequenceMaxLabelIndexDecoder(label_vocabulary=label_vocabulary)
    label_indices = decoder(logits=batch_sequence_logits, mask=None)
    ASSERT.assertEqual(expect, label_indices.tolist())
def test_decode_decode_label_index_to_span():
    """
    Test decoding golden label indices into spans.
    """
    vocabulary = LabelVocabulary([["B-T", "B-T", "B-T", "I-T", "I-T", "O"]],
                                 padding=LabelVocabulary.PADDING)

    # The golden indices below assume B-T -> 0, I-T -> 1, O -> 2.
    for expected_index, label in enumerate(("B-T", "I-T", "O")):
        ASSERT.assertEqual(expected_index, vocabulary.index(label))

    # Label sequences: [B, I, O, B] and [O, B, I, I]
    golden_labels = torch.tensor([[0, 1, 2, 0],
                                  [2, 0, 1, 1]])

    # Span "end" is exclusive: B at 0 + I at 1 yields begin=0, end=2.
    expect = [
        [{"label": "T", "begin": 0, "end": 2},
         {"label": "T", "begin": 3, "end": 4}],
        [{"label": "T", "begin": 1, "end": 4}],
    ]

    spans = BIO.decode_label_index_to_span(batch_sequence_label_index=golden_labels,
                                           mask=None,
                                           vocabulary=vocabulary)
    ASSERT.assertListEqual(expect, spans)
def test_decode():
    """
    Test decoding a batch of model logits into spans with BIO.decode.
    """
    # Per-position argmax of each sequence:
    # [O, B, I], [B, B, I], [B, I, I], [B, I, O]
    batch_sequence_logits = torch.tensor(
        [
            [[0.2, 0.3, 0.4], [0.7, 0.2, 0.3], [0.2, 0.3, 0.1]],
            [[0.8, 0.3, 0.4], [0.7, 0.2, 0.3], [0.2, 0.3, 0.1]],
            [[0.8, 0.3, 0.4], [0.1, 0.7, 0.3], [0.2, 0.3, 0.1]],
            [[0.8, 0.3, 0.4], [0.1, 0.7, 0.3], [0.2, 0.3, 0.5]],
        ],
        dtype=torch.float,
    )

    # One span list per sequence; "end" is exclusive.
    expect = [
        [{"label": "T", "begin": 1, "end": 3}],
        [{"label": "T", "begin": 0, "end": 1},
         {"label": "T", "begin": 1, "end": 3}],
        [{"label": "T", "begin": 0, "end": 3}],
        [{"label": "T", "begin": 0, "end": 2}],
    ]

    vocabulary = LabelVocabulary([["B-T", "B-T", "B-T", "I-T", "I-T", "O"]],
                                 padding=LabelVocabulary.PADDING)

    # Logit columns are ordered B-T, I-T, O.
    for expected_index, label in enumerate(("B-T", "I-T", "O")):
        ASSERT.assertEqual(expected_index, vocabulary.index(label))

    spans = BIO.decode(batch_sequence_logits=batch_sequence_logits,
                       mask=None,
                       vocabulary=vocabulary)
    ASSERT.assertListEqual(expect, spans)
def test_allowed_transitions():
    """
    Test the (from_index, to_index) transition pairs returned by
    BIO.allowed_transitions.

    The pseudo tag START is represented by index ``label_vocabulary.label_size``
    and STOP by ``label_vocabulary.label_size + 1``.
    """
    label_vocabulary = LabelVocabulary(labels=[["B-L1", "I-L1", "B-L2", "I-L2", "O"]],
                                       padding=LabelVocabulary.PADDING)

    allowed_pairs = BIO.allowed_transitions(label_vocabulary=label_vocabulary)

    start_index = label_vocabulary.label_size
    stop_index = label_vocabulary.label_size + 1

    def to_index(label):
        # Map a label name, or the START/STOP pseudo tag, to its index.
        if label == "START":
            return start_index
        if label == "STOP":
            return stop_index
        return label_vocabulary.index(label)

    # Every legal BIO transition. An I-X tag may only follow B-X or I-X, and
    # START may not go directly to an I tag.
    expect_transition_labels = [
        ("B-L1", "B-L1"), ("B-L1", "I-L1"), ("B-L1", "B-L2"), ("B-L1", "O"), ("B-L1", "STOP"),
        ("I-L1", "B-L1"), ("I-L1", "I-L1"), ("I-L1", "B-L2"), ("I-L1", "O"), ("I-L1", "STOP"),
        ("B-L2", "B-L1"), ("B-L2", "B-L2"), ("B-L2", "I-L2"), ("B-L2", "O"), ("B-L2", "STOP"),
        ("I-L2", "B-L1"), ("I-L2", "B-L2"), ("I-L2", "I-L2"), ("I-L2", "O"), ("I-L2", "STOP"),
        ("O", "B-L1"), ("O", "B-L2"), ("O", "O"), ("O", "STOP"),
        ("START", "B-L1"), ("START", "B-L2"), ("START", "O"),
    ]

    expect = {(to_index(from_label), to_index(to_label))
              for from_label, to_label in expect_transition_labels}

    ASSERT.assertSetEqual(expect, set(allowed_pairs))
def test_decode_one_sequence_logits_to_label():
    """
    Test decoding a single sequence of logits into BIO labels and their indices
    when the per-position argmax already forms a valid BIO sequence.
    """
    # (sequence logits, expected labels); logit columns are B-T, I-T, O.
    cases = [
        # O B I: normal case
        (torch.tensor([[0.2, 0.3, 0.4], [0.7, 0.2, 0.3], [0.2, 0.3, 0.1]],
                      dtype=torch.float),
         ["O", "B-T", "I-T"]),
        (torch.tensor([[0.9, 0.3, 0.4], [0.2, 0.8, 0.3], [0.2, 0.3, 0.1]],
                      dtype=torch.float),
         ["B-T", "I-T", "I-T"]),
        (torch.tensor([[0.9, 0.3, 0.4], [0.2, 0.8, 0.3], [0.2, 0.3, 0.9]],
                      dtype=torch.float),
         ["B-T", "I-T", "O"]),
    ]

    vocabulary = LabelVocabulary([["B-T", "B-T", "B-T", "I-T", "I-T", "O"]],
                                 padding=LabelVocabulary.PADDING)
    for expected_index, label in enumerate(("B-T", "I-T", "O")):
        ASSERT.assertEqual(expected_index, vocabulary.index(label))

    for sequence_logits, expected_labels in cases:
        sequence_label, sequence_label_indices = BIO.decode_one_sequence_logits_to_label(
            sequence_logits=sequence_logits, vocabulary=vocabulary)
        ASSERT.assertListEqual(sequence_label, expected_labels)
        ASSERT.assertListEqual(sequence_label_indices,
                               [vocabulary.index(label) for label in expected_labels])
def test_label_vocabulary():
    """
    Test LabelVocabulary sizes with and without a padding label, and the index
    assigned to each label.
    """
    # With padding="" the vocabulary holds only the 5 distinct labels.
    vocabulary = LabelVocabulary([["A", "B", "C"], ["D", "E"]], padding="")
    ASSERT.assertEqual(vocabulary.size, 5)

    # With the standard padding label, size counts padding but label_size does not.
    vocabulary = LabelVocabulary([["A", "B", "C"], ["D", "E"]],
                                 padding=LabelVocabulary.PADDING)
    ASSERT.assertEqual(vocabulary.size, 6)
    ASSERT.assertEqual(vocabulary.label_size, 5)
    # The padding label gets the index right after the real labels.
    ASSERT.assertEqual(vocabulary.index(vocabulary.padding), 5)

    # Real labels are indexed in the order they first appear.
    for expected_index, label in enumerate("ABCDE"):
        ASSERT.assertEqual(vocabulary.index(label), expected_index)
def test_decode_one_sequence_logits_to_label_abnormal():
    """
    Test that BIO.decode_one_sequence_logits_to_label repairs logits whose
    per-position argmax is not a valid BIO sequence.
    """
    vocabulary = LabelVocabulary([["B-T", "B-T", "B-T", "I-T", "I-T", "O"]],
                                 padding=LabelVocabulary.PADDING)

    # Logit columns are ordered B-T, I-T, O.
    for expected_index, label in enumerate(("B-T", "I-T", "O")):
        ASSERT.assertEqual(expected_index, vocabulary.index(label))

    cases = [
        # The per-position argmax is I-T, B-T, I-T. The leading I-T is invalid
        # because an I tag cannot start a sequence. The expected output therefore
        # replaces it with the next-best label at position 0, O (logit 0.4),
        # giving O, B-T, I-T.
        (torch.tensor([[0.2, 0.5, 0.4], [0.7, 0.2, 0.3], [0.2, 0.3, 0.1]],
                      dtype=torch.float),
         ["O", "B-T", "I-T"]),
        # The per-position argmax is I-T, I-T, I-T. Each I-T that cannot follow
        # the previous label is replaced by the next-best label, giving O, O, B-T.
        (torch.tensor([[0.2, 0.5, 0.4], [0.2, 0.9, 0.3], [0.2, 0.3, 0.1]],
                      dtype=torch.float),
         ["O", "O", "B-T"]),
    ]

    for sequence_logits, expect in cases:
        sequence_label, sequence_label_indices = BIO.decode_one_sequence_logits_to_label(
            sequence_logits=sequence_logits, vocabulary=vocabulary)
        expect_indices = [vocabulary.index(label) for label in expect]
        ASSERT.assertListEqual(expect, sequence_label)
        ASSERT.assertListEqual(expect_indices, sequence_label_indices)
def test_sequence_label_decoder():
    """
    Test SequenceLabelDecoder on a right-padded batch of label-index sequences
    together with a 1/0 mask that marks the real (non-padding) positions.
    """
    # (label sequence, expected spans); span "end" is exclusive.
    cases = [
        # NOTE(review): "O-T" is an unusual outside tag (plain "O" is used
        # elsewhere in this file). The expected span implies it ends the entity;
        # confirm this is intentional.
        (["B-T", "I-T", "O-T"],
         [{"label": "T", "begin": 0, "end": 2}]),
        (["B-T", "I-T", "I-T"],
         [{"label": "T", "begin": 0, "end": 3}]),
        (["B-T", "I-T", "I-T", "B-T"],
         [{"label": "T", "begin": 0, "end": 3},
          {"label": "T", "begin": 3, "end": 4}]),
    ]
    sequence_label_list = [labels for labels, _ in cases]
    expect_spans = [spans for _, spans in cases]

    label_vocabulary = LabelVocabulary(sequence_label_list, padding=LabelVocabulary.PADDING)

    max_sequence_len = 4
    padded_rows = []
    mask_rows = []
    for sequence_labels in sequence_label_list:
        indices = [label_vocabulary.index(label) for label in sequence_labels]
        pad_len = max_sequence_len - len(indices)
        mask_rows.append([1] * len(indices) + [0] * pad_len)
        padded_rows.append(indices + [label_vocabulary.padding_index] * pad_len)

    sequence_label_indices = torch.tensor(padded_rows, dtype=torch.long)
    mask = torch.tensor(mask_rows, dtype=torch.uint8)

    decoder = SequenceLabelDecoder(label_vocabulary=label_vocabulary)
    spans = decoder(label_indices=sequence_label_indices, mask=mask)
    ASSERT.assertListEqual(expect_spans, spans)