def test_allowed_transitions():
    """
    Test the allowed transition mask pairs produced for the BIO schema.
    :return:
    """
    label_vocabulary = LabelVocabulary(
        labels=[["B-L1", "I-L1", "B-L2", "I-L2", "O"]],
        padding=LabelVocabulary.PADDING)
    allowed_pairs = BIO.allowed_transitions(label_vocabulary=label_vocabulary)

    # Virtual boundary tags live just past the real label indices.
    start_index = label_vocabulary.label_size
    stop_index = label_vocabulary.label_size + 1

    # Dump the decoded pairs for eyeballing / regenerating the fixture below.
    for from_idx, to_idx in allowed_pairs:
        from_label = "START" if from_idx == start_index \
            else label_vocabulary.token(from_idx)
        to_label = "STOP" if to_idx == stop_index \
            else label_vocabulary.token(to_idx)
        print(f"(\"{from_label}\", \"{to_label}\"),")

    # Every legal BIO transition for the two entity types plus START/STOP.
    expect_transition_labels = [
        ("B-L1", "B-L1"), ("B-L1", "I-L1"), ("B-L1", "B-L2"),
        ("B-L1", "O"), ("B-L1", "STOP"),
        ("I-L1", "B-L1"), ("I-L1", "I-L1"), ("I-L1", "B-L2"),
        ("I-L1", "O"), ("I-L1", "STOP"),
        ("B-L2", "B-L1"), ("B-L2", "B-L2"), ("B-L2", "I-L2"),
        ("B-L2", "O"), ("B-L2", "STOP"),
        ("I-L2", "B-L1"), ("I-L2", "B-L2"), ("I-L2", "I-L2"),
        ("I-L2", "O"), ("I-L2", "STOP"),
        ("O", "B-L1"), ("O", "B-L2"), ("O", "O"), ("O", "STOP"),
        ("START", "B-L1"), ("START", "B-L2"), ("START", "O")]

    expect = list()
    for from_label, to_label in expect_transition_labels:
        from_idx = start_index if from_label == "START" \
            else label_vocabulary.index(from_label)
        to_idx = stop_index if to_label == "STOP" \
            else label_vocabulary.index(to_label)
        expect.append((from_idx, to_idx))

    ASSERT.assertSetEqual(set(expect), set(allowed_pairs))
def allowed_transitions(
        label_vocabulary: LabelVocabulary) -> List[Tuple[int, int]]:
    """
    Given a label vocabulary, compute the transition-matrix mask for the
    BIO schema. Under BIO, transitions such as ["O", "I-Per"] or
    ["B-Per", "I-Loc"] are illegal; this function computes the mask and
    returns every allowed pair as (from_label_id, to_label_id).

    Note that the returned pairs include the virtual START and STOP tags,
    whose indices are label_vocabulary.label_size and
    label_vocabulary.label_size + 1 respectively. Although
    label_vocabulary.padding_index equals label_vocabulary.label_size as
    well, START is only used in this transition-matrix context, so the
    indices do not clash.

    :param label_vocabulary: the label vocabulary
    :return: list of all allowed (from_label_id, to_label_id) pairs
    """
    num_labels = label_vocabulary.label_size
    start_index = num_labels
    end_index = num_labels + 1

    # Real labels followed by the two virtual boundary tags.
    labels_with_boundaries = [(index, label_vocabulary.token(index))
                              for index in range(num_labels)]
    labels_with_boundaries.append((start_index, "START"))
    labels_with_boundaries.append((end_index, "END"))

    def _tag_and_entity(label: str) -> Tuple[str, str]:
        # Boundary tags carry no entity; for real labels the tag is the
        # first character and the entity is the remainder (including the
        # "-" separator, which is kept consistently on both sides).
        if label in ("START", "END"):
            return label, ""
        return label[0], label[1:]

    allowed = []
    for from_index, from_label in labels_with_boundaries:
        from_tag, from_entity = _tag_and_entity(from_label)
        for to_index, to_label in labels_with_boundaries:
            to_tag, to_entity = _tag_and_entity(to_label)
            if _is_transition_allowed(from_tag, from_entity, to_tag,
                                      to_entity):
                allowed.append((from_index, to_index))
    return allowed
def decode_label_index_to_span(
        batch_sequence_label_index: torch.Tensor, mask: torch.ByteTensor,
        vocabulary: LabelVocabulary) -> List[List[Dict]]:
    """
    Decode a batch of label-index sequences into spans.

    With batch_sequence_label of shape (B, seq_len) and
    (B-T: 0, I-T: 1, O: 2), the input [[0, 1, 2], [2, 0, 1]] corresponds
    to the label sequences [[B, I, O], [O, B, I]] and decodes into
    [[{"label": T, "begin": 0, "end": 2}],
     [{"label": T, "begin": 1, "end": 3}]].

    :param batch_sequence_label_index: shape (B, seq_len), label index
        sequences
    :param mask: mask over batch_sequence_label; if None, every position
        is treated as valid
    :param vocabulary: the label vocabulary
    :return: list of decoded span lists, one per batch item
    """
    if mask is None:
        # No mask supplied: treat every position of every sequence as valid.
        mask = torch.ones(size=(batch_sequence_label_index.shape[0],
                                batch_sequence_label_index.shape[1]),
                          dtype=torch.long)
    sequence_lengths = mask.sum(dim=-1)

    spans = list()
    for batch_index in range(batch_sequence_label_index.size(0)):
        valid_length = sequence_lengths[batch_index]
        label_indices = batch_sequence_label_index[
            batch_index, :valid_length].tolist()
        sequence_label = [vocabulary.token(index)
                          for index in label_indices]
        spans.append(
            decode_one_sequence_label_to_span(sequence_label=sequence_label))
    return spans
def __init__(self, label_vocabulary: LabelVocabulary) -> None:
    """
    Initialize from a BIO label vocabulary.

    Extracts the entity names from the "B-Label" / "I-Label" tokens
    (skipping "O") and forwards them to the parent constructor.

    :param label_vocabulary: the BIO label vocabulary
    """
    labels = set()

    # Collect the entity name of every B-/I- label.
    for index in range(label_vocabulary.label_size):
        bio_label: str = label_vocabulary.token(index)

        if bio_label == "O":
            continue

        # Split only on the FIRST "-" so entity names that themselves
        # contain a hyphen (e.g. "B-Some-Label" -> "Some-Label") are kept
        # intact; the previous split("-")[1] truncated them.
        label = bio_label.split("-", 1)[1]
        labels.add(label)

    # Sort for a deterministic label order across runs (plain set
    # iteration order depends on hash randomization).
    super().__init__(labels=sorted(labels))
    self.label_vocabulary = label_vocabulary
def decode_one_sequence_logits_to_label( sequence_logits: torch.Tensor, vocabulary: LabelVocabulary) -> Tuple[List[str], List[int]]: """ 对 输出 sequence logits 进行解码, 是仅仅一个 sequence 进行解码,而不是 batch sequence 进行解码。 batch sequence 解码需要进行循环 :param sequence_logits: shape: (seq_len, label_num), 是 mask 之后的有效 sequence,而不是包含 mask 的 sequecne logits. :return: sequence label, B, I, O 的list 以及 label 对应的 index list """ if len(sequence_logits.shape) != 2: raise RuntimeError( f"sequence_logits shape 是 (seq_len, label_num), 现在是 {sequence_logits.shape}" ) idel_state, span_state = 0, 1 sequence_length = sequence_logits.size(0) state = idel_state # 按权重进行排序 indices shape: (seq_len, label_num) sorted_sequence_indices = torch.argsort(sequence_logits, dim=-1, descending=True) sequence_label = list() sequence_label_indices = list() for i in range(sequence_length): indices = sorted_sequence_indices[i, :].tolist() if state == idel_state: # 循环寻找,直到找到一个合理的标签 for index in indices: label = vocabulary.token(index) if label[0] == "O": sequence_label.append(label) sequence_label_indices.append(index) state = idel_state break elif label[0] == "B": sequence_label.append(label) sequence_label_indices.append(index) state = span_state break else: # 其他情况 "I" 这是不合理的,所以这个逻辑是找到一个合理的标签 pass elif state == span_state: for index in indices: label = vocabulary.token(index) if label[0] == "B": sequence_label.append(label) sequence_label_indices.append(index) state = span_state break elif label[0] == "O": sequence_label.append(label) sequence_label_indices.append(index) state = idel_state break elif label[0] == "I": sequence_label.append(label) sequence_label_indices.append(index) state = span_state break else: raise RuntimeError(f"{label} 不符合 BIO 格式") else: raise RuntimeError(f"state is error: {state}") return sequence_label, sequence_label_indices