def test_case2(self):
    """Check that the CRF works properly.

    Reference logits, transition matrices, sequence lengths, expected scores
    and expected paths are read from ``crf.json``. For a BIO tag set and a
    BMES tag set, a constrained ConditionalRandomField is built, its
    transition matrix is overwritten with the stored one, and the result of
    ``viterbi_decode`` is compared with the stored scores and paths.
    """
    import json
    import torch
    from fastNLP import seq_len_to_mask

    with open('tests/data_for_tests/modules/decoder/crf.json', 'r') as f:
        data = json.load(f)

    # Load every fixture up front: BIO first, then BMES.
    bio_logits = torch.FloatTensor(data['bio_logits'])
    expected_bio_scores = data['bio_scores']
    expected_bio_paths = data['bio_path']
    bio_transitions = torch.FloatTensor(data['bio_trans_m'])
    bio_lengths = torch.LongTensor(data['bio_seq_lens'])

    bmes_logits = torch.FloatTensor(data['bmes_logits'])
    expected_bmes_scores = data['bmes_scores']
    expected_bmes_paths = data['bmes_path']
    bmes_transitions = torch.FloatTensor(data['bmes_trans_m'])
    bmes_lengths = torch.LongTensor(data['bmes_seq_lens'])

    def decode_and_compare(tag_names, logits, transitions, lengths,
                           expected_scores, expected_paths, **transition_options):
        # Shared body of both scenarios: build the constrained CRF, decode, compare.
        id2label = dict(enumerate(tag_names))
        mask = seq_len_to_mask(lengths)
        from fastNLP.modules.decoder.crf import ConditionalRandomField, allowed_transitions
        constraints = allowed_transitions(id2label, include_start_end=True, **transition_options)
        crf = ConditionalRandomField(num_tags=len(id2label), allowed_transitions=constraints)
        # Replace the freshly initialised transition scores with the stored ones.
        crf.trans_m.data = transitions
        decoded = crf.viterbi_decode(logits, mask, unpad=True)

        # score equal (compared after rounding to 4 decimals)
        rounded_scores = [round(score, 4) for score in decoded[1].tolist()]
        self.assertListEqual(expected_scores, rounded_scores)
        # seq equal
        self.assertListEqual(expected_paths, decoded[0])

    # BIO scenario: 'O' followed by B-/I- tags for the entity types X and Y.
    bio_tags = ['O'] + ['{}-{}'.format(tag, label) for label in ['X', 'Y'] for tag in 'BI']
    decode_and_compare(bio_tags, bio_logits, bio_transitions, bio_lengths,
                       expected_bio_scores, expected_bio_paths)

    # BMES scenario: B-/M-/E-/S- tags for X and Y, encoding type passed explicitly.
    bmes_tags = ['{}-{}'.format(tag, label) for label in ['X', 'Y'] for tag in 'BMES']
    decode_and_compare(bmes_tags, bmes_logits, bmes_transitions, bmes_lengths,
                       expected_bmes_scores, expected_bmes_paths, encoding_type='BMES')
def __init__(self, vocab_num, embed_dim=100, bigram_vocab_num=None, bigram_embed_dim=100,
             num_bigram_per_char=None, hidden_size=200, bidirectional=True, embed_drop_p=0.2,
             num_layers=1, tag_size=4):
    """
    Assemble the encoder, the MLP decoder and the CRF layer.

    The CRF is constrained with BMES transitions (tags 0..3 are b, m, e, s).

    :param vocab_num: forwarded to CWSBiLSTMEncoder
    :param embed_dim: forwarded to CWSBiLSTMEncoder
    :param bigram_vocab_num: forwarded to CWSBiLSTMEncoder
    :param bigram_embed_dim: forwarded to CWSBiLSTMEncoder
    :param num_bigram_per_char: forwarded to CWSBiLSTMEncoder
    :param hidden_size: forwarded to CWSBiLSTMEncoder; also the first layer size given to the MLP
    :param bidirectional: forwarded to CWSBiLSTMEncoder
    :param embed_drop_p: forwarded to CWSBiLSTMEncoder
    :param num_layers: forwarded to CWSBiLSTMEncoder
    :param tag_size: last layer size given to the MLP and ``num_tags`` of the CRF
    """
    super(CWSBiLSTMCRF, self).__init__()

    self.tag_size = tag_size

    self.encoder_model = CWSBiLSTMEncoder(
        vocab_num, embed_dim, bigram_vocab_num, bigram_embed_dim, num_bigram_per_char,
        hidden_size, bidirectional, embed_drop_p, num_layers,
    )

    # MLP layer sizes: hidden_size -> 200 -> tag_size.
    self.decoder_model = MLP([hidden_size, 200, tag_size])

    bmes_constraints = allowed_transitions({0: 'b', 1: 'm', 2: 'e', 3: 's'}, encoding_type='bmes')
    self.crf = ConditionalRandomField(
        num_tags=tag_size,
        include_start_end_trans=False,
        allowed_transitions=bmes_constraints,
    )
def test_case1(self):
    """Check that allowed_transitions() can be used correctly.

    Five id-to-label mappings are passed in: plain BIO, plain BMES, BIO plus
    ``<pad>``/``<unk>`` (call only, result not asserted), typed BIO and typed
    BMES. Where a result is asserted it is compared, as a set, with a
    hard-coded set of ``(from_id, to_id)`` pairs.
    """
    from fastNLP.modules.decoder.crf import allowed_transitions

    # --- plain BIO ---
    bio_map = {0: 'B', 1: 'I', 2: 'O'}
    actual = set(allowed_transitions(bio_map))
    self.assertSetEqual(
        {(0, 0), (0, 1), (0, 2), (0, 4),
         (1, 0), (1, 1), (1, 2), (1, 4),
         (2, 0), (2, 2), (2, 4),
         (3, 0), (3, 2)},
        actual)

    # --- plain BMES ---
    bmes_map = {0: 'B', 1: 'M', 2: 'E', 3: 'S'}
    actual = set(allowed_transitions(bmes_map, encoding_type='BMES'))
    self.assertSetEqual(
        {(0, 1), (0, 2),
         (1, 1), (1, 2),
         (2, 0), (2, 3), (2, 5),
         (3, 0), (3, 3), (3, 5),
         (4, 0), (4, 3)},
        actual)

    # --- BIO with <pad>/<unk> entries: only verifies that the call completes ---
    allowed_transitions({0: 'B', 1: 'I', 2: 'O', 3: '<pad>', 4: "<unk>"})

    # --- typed BIO: 'O' plus B-/I- tags for X and Y ---
    typed_bio = ['O'] + ['{}-{}'.format(tag, label) for label in ['X', 'Y'] for tag in 'BI']
    actual = set(allowed_transitions(dict(enumerate(typed_bio))))
    self.assertSetEqual(
        {(0, 0), (0, 1), (0, 3), (0, 6),
         (1, 0), (1, 1), (1, 2), (1, 3), (1, 6),
         (2, 0), (2, 1), (2, 2), (2, 3), (2, 6),
         (3, 0), (3, 1), (3, 3), (3, 4), (3, 6),
         (4, 0), (4, 1), (4, 3), (4, 4), (4, 6),
         (5, 0), (5, 1), (5, 3)},
        actual)

    # --- typed BMES: B-/M-/E-/S- tags for X and Y ---
    typed_bmes = ['{}-{}'.format(tag, label) for label in ['X', 'Y'] for tag in 'BMES']
    actual = set(allowed_transitions(dict(enumerate(typed_bmes)), encoding_type='BMES'))
    self.assertSetEqual(
        {(0, 1), (0, 2),
         (1, 1), (1, 2),
         (2, 0), (2, 3), (2, 4), (2, 7), (2, 9),
         (3, 0), (3, 3), (3, 4), (3, 7), (3, 9),
         (4, 5), (4, 6),
         (5, 5), (5, 6),
         (6, 0), (6, 3), (6, 4), (6, 7), (6, 9),
         (7, 0), (7, 3), (7, 4), (7, 7), (7, 9),
         (8, 0), (8, 3), (8, 4), (8, 7)},
        actual)
def __init__(self, vocab_num, embed_dim=100, bigram_vocab_num=None, bigram_embed_dim=100,
             num_bigram_per_char=None, embed_drop_p=0.3, hidden_size=200, kernel_size=3,
             dilate='none', num_layers=1, num_heads=8, tag_size=4, relative_pos_embed_dim=0):
    """
    Build the embedding layers, the dilated transformer encoder and the CRF layer.

    The CRF is constrained with BMES transitions (tags 0..3 are b, m, e, s).

    :param vocab_num: number of rows of the character embedding
    :param embed_dim: width of the character embedding
    :param bigram_vocab_num: number of rows of the bigram embedding; when falsy no bigram
        embedding is created
    :param bigram_embed_dim: width of the bigram embedding
    :param num_bigram_per_char: multiplied by ``bigram_embed_dim`` and added to the input
        width of ``fc1``; must be given whenever ``bigram_vocab_num`` is truthy
    :param embed_drop_p: probability passed to the (in-place) dropout layer
    :param hidden_size: output width of ``fc1``; also passed to the encoder as both
        ``model_size`` and ``hidden_size``, and used as the input width of ``fc2``
    :param kernel_size: forwarded to TransformerDilateEncoder
    :param dilate: forwarded to TransformerDilateEncoder
    :param num_layers: forwarded to TransformerDilateEncoder
    :param num_heads: forwarded to TransformerDilateEncoder
    :param tag_size: output width of ``fc2`` and ``num_tags`` of the CRF
    :param relative_pos_embed_dim: forwarded to TransformerDilateEncoder
    :raises TypeError: if ``bigram_vocab_num`` is truthy but ``num_bigram_per_char`` is None
    """
    super().__init__()

    self.embedding = nn.Embedding(vocab_num, embed_dim)
    input_size = embed_dim
    if bigram_vocab_num:
        # Fail early with an actionable message. Previously this surfaced as an opaque
        # "NoneType * int" TypeError after the bigram embedding had been allocated.
        # The exception type is kept as TypeError so existing handlers are unaffected.
        if num_bigram_per_char is None:
            raise TypeError(
                "num_bigram_per_char must be provided when bigram_vocab_num is set"
            )
        self.bigram_embedding = nn.Embedding(bigram_vocab_num, bigram_embed_dim)
        input_size += num_bigram_per_char * bigram_embed_dim

    self.drop = nn.Dropout(embed_drop_p, inplace=True)

    self.fc1 = nn.Linear(input_size, hidden_size)

    self.transformer = TransformerDilateEncoder(
        num_layers=num_layers,
        model_size=hidden_size,
        num_heads=num_heads,
        hidden_size=hidden_size,
        kernel_size=kernel_size,
        dilate=dilate,
        relative_pos_embed_dim=relative_pos_embed_dim)
    self.fc2 = nn.Linear(hidden_size, tag_size)

    allowed_trans = allowed_transitions({
        0: 'b',
        1: 'm',
        2: 'e',
        3: 's'
    }, encoding_type='bmes')
    self.crf = ConditionalRandomField(num_tags=tag_size,
                                      include_start_end_trans=False,
                                      allowed_transitions=allowed_trans)
def test_case12(self):
    """Check that the transition constraints can be generated from a Vocabulary.

    The same five scenarios are used: plain BIO, plain BMES, BIO plus
    ``<pad>``/``<unk>`` (call only, result not asserted), typed BIO and typed
    BMES. Each vocabulary is passed to ``allowed_transitions`` with
    ``include_start_end=True`` and no explicit ``encoding_type``.
    """
    from fastNLP.modules.decoder.crf import allowed_transitions

    def vocab_without_specials(tags):
        # Vocabulary with no unknown/padding entries, filled in the given order.
        built = Vocabulary(unknown=None, padding=None)
        for tag in tags:
            built.add_word(tag)
        return built

    def transitions_of(vocab):
        return set(allowed_transitions(vocab, include_start_end=True))

    # --- plain BIO ---
    self.assertSetEqual(
        {(0, 0), (0, 1), (0, 2), (0, 4),
         (1, 0), (1, 1), (1, 2), (1, 4),
         (2, 0), (2, 2), (2, 4),
         (3, 0), (3, 2)},
        transitions_of(vocab_without_specials(['B', 'I', 'O'])))

    # --- plain BMES ---
    self.assertSetEqual(
        {(0, 1), (0, 2),
         (1, 1), (1, 2),
         (2, 0), (2, 3), (2, 5),
         (3, 0), (3, 3), (3, 5),
         (4, 0), (4, 3)},
        transitions_of(vocab_without_specials(['B', 'M', 'E', 'S'])))

    # --- default Vocabulary() with <pad>/<unk> added as words: call must simply complete ---
    default_vocab = Vocabulary()
    for tag in ['B', 'I', 'O', '<pad>', "<unk>"]:
        default_vocab.add_word(tag)
    allowed_transitions(default_vocab, include_start_end=True)

    # --- typed BIO: 'O' plus B-/I- tags for X and Y ---
    typed_bio = ['O'] + ['{}-{}'.format(tag, label) for label in ['X', 'Y'] for tag in 'BI']
    self.assertSetEqual(
        {(0, 0), (0, 1), (0, 3), (0, 6),
         (1, 0), (1, 1), (1, 2), (1, 3), (1, 6),
         (2, 0), (2, 1), (2, 2), (2, 3), (2, 6),
         (3, 0), (3, 1), (3, 3), (3, 4), (3, 6),
         (4, 0), (4, 1), (4, 3), (4, 4), (4, 6),
         (5, 0), (5, 1), (5, 3)},
        transitions_of(vocab_without_specials(typed_bio)))

    # --- typed BMES: B-/M-/E-/S- tags for X and Y ---
    typed_bmes = ['{}-{}'.format(tag, label) for label in ['X', 'Y'] for tag in 'BMES']
    self.assertSetEqual(
        {(0, 1), (0, 2),
         (1, 1), (1, 2),
         (2, 0), (2, 3), (2, 4), (2, 7), (2, 9),
         (3, 0), (3, 3), (3, 4), (3, 7), (3, 9),
         (4, 5), (4, 6),
         (5, 5), (5, 6),
         (6, 0), (6, 3), (6, 4), (6, 7), (6, 9),
         (7, 0), (7, 3), (7, 4), (7, 7), (7, 9),
         (8, 0), (8, 3), (8, 4), (8, 7)},
        transitions_of(vocab_without_specials(typed_bmes)))