Example #1
0
    def test_case2(self):
        # Check that CRF Viterbi decoding reproduces the reference scores and paths stored
        # in the JSON fixture, for both a BIO and a BMES tag set.
        import json
        import torch
        from fastNLP import seq_len_to_mask
        from fastNLP.modules.decoder.crf import ConditionalRandomField, allowed_transitions

        # JSON is UTF-8 by specification; don't depend on the platform default encoding.
        with open('tests/data_for_tests/modules/decoder/crf.json', 'r', encoding='utf-8') as f:
            data = json.load(f)

        def check_decoding(prefix, labels, **transition_kwargs):
            """Decode the fixture's ``<prefix>_logits`` with a constrained CRF and compare.

            :param prefix: key prefix inside the fixture ('bio' or 'bmes').
            :param labels: tag names; a tag's list position is its index.
            :param transition_kwargs: extra keyword arguments for ``allowed_transitions``.
            """
            id2label = dict(enumerate(labels))
            mask = seq_len_to_mask(torch.LongTensor(data[prefix + '_seq_lens']))

            crf = ConditionalRandomField(
                num_tags=len(id2label),
                allowed_transitions=allowed_transitions(id2label, include_start_end=True,
                                                        **transition_kwargs))
            # Replace the CRF's own transition matrix with the reference one so that the
            # decoded result is fully determined by the fixture.
            crf.trans_m.data = torch.FloatTensor(data[prefix + '_trans_m'])

            fast_res = crf.viterbi_decode(torch.FloatTensor(data[prefix + '_logits']), mask, unpad=True)
            # Scores: decoded scores are rounded to 4 decimals before comparison.
            self.assertListEqual(data[prefix + '_scores'], [round(s, 4) for s in fast_res[1].tolist()],
                                 msg='{} scores'.format(prefix))
            # Paths: the unpadded best tag sequences must match exactly.
            self.assertListEqual(data[prefix + '_path'], fast_res[0], msg='{} paths'.format(prefix))

        # BIO scheme: ['O', 'B-X', 'I-X', 'B-Y', 'I-Y']
        bio_labels = ['O'] + ['{}-{}'.format(tag, label) for label in ['X', 'Y'] for tag in 'BI']
        check_decoding('bio', bio_labels)

        # BMES scheme: ['B-X', 'M-X', 'E-X', 'S-X', 'B-Y', 'M-Y', 'E-Y', 'S-Y']
        bmes_labels = ['{}-{}'.format(tag, label) for label in ['X', 'Y'] for tag in 'BMES']
        check_decoding('bmes', bmes_labels, encoding_type='BMES')
Example #2
0
    def __init__(self, vocab_num, embed_dim=100, bigram_vocab_num=None, bigram_embed_dim=100, num_bigram_per_char=None,
                 hidden_size=200, bidirectional=True, embed_drop_p=0.2, num_layers=1, tag_size=4):
        """
        BiLSTM encoder + MLP + CRF tagger. Tags follow the BMES scheme by default.

        :param vocab_num: character vocabulary size, forwarded to ``CWSBiLSTMEncoder``.
        :param embed_dim: character embedding dimension, forwarded to the encoder.
        :param bigram_vocab_num: bigram vocabulary size, forwarded to the encoder.
        :param bigram_embed_dim: bigram embedding dimension, forwarded to the encoder.
        :param num_bigram_per_char: bigrams per character, forwarded to the encoder.
        :param hidden_size: forwarded to the encoder and used as the MLP input size, so it is
            expected to equal the encoder's output size.
        :param bidirectional: forwarded to the encoder.
        :param embed_drop_p: embedding dropout probability, forwarded to the encoder.
        :param num_layers: number of encoder layers, forwarded to the encoder.
        :param tag_size: number of output tags (MLP output size and CRF ``num_tags``). The
            transition constraints are built for exactly four BMES tags (0=b, 1=m, 2=e, 3=s),
            so values other than 4 are presumably unsupported — TODO confirm.
        """
        super().__init__()

        self.tag_size = tag_size

        self.encoder_model = CWSBiLSTMEncoder(
            vocab_num, embed_dim, bigram_vocab_num, bigram_embed_dim, num_bigram_per_char,
            hidden_size, bidirectional, embed_drop_p, num_layers,
        )

        # Emission scores: MLP with layer sizes hidden_size -> 200 -> tag_size.
        self.decoder_model = MLP([hidden_size, 200, tag_size])

        # Constrain the CRF to tag transitions that are legal under BMES.
        bmes_id2tag = dict(enumerate('bmes'))
        self.crf = ConditionalRandomField(
            num_tags=tag_size,
            include_start_end_trans=False,
            allowed_transitions=allowed_transitions(bmes_id2tag, encoding_type='bmes'),
        )
Example #3
0
    def test_case1(self):
        # Check that allowed_transitions() produces the expected (from_idx, to_idx) pairs.
        from fastNLP.modules.decoder.crf import allowed_transitions

        def assert_transitions(id2label, expected, **kwargs):
            # Compare the returned pairs as a set, so their order is irrelevant.
            self.assertSetEqual(expected, set(allowed_transitions(id2label, **kwargs)))

        # Bare BIO tags.
        assert_transitions(dict(enumerate(['B', 'I', 'O'])),
                           {(0, 0), (0, 1), (0, 2), (0, 4), (1, 0), (1, 1), (1, 2), (1, 4), (2, 0), (2, 2),
                            (2, 4), (3, 0), (3, 2)})

        # Bare BMES tags.
        assert_transitions(dict(enumerate(['B', 'M', 'E', 'S'])),
                           {(0, 1), (0, 2), (1, 1), (1, 2), (2, 0), (2, 3), (2, 5), (3, 0), (3, 3), (3, 5),
                            (4, 0), (4, 3)},
                           encoding_type='BMES')

        # BIO tags mixed with special tokens: the result is discarded, this only checks
        # that the call does not raise.
        allowed_transitions(dict(enumerate(['B', 'I', 'O', '<pad>', '<unk>'])))

        # Typed BIO tags: ['O', 'B-X', 'I-X', 'B-Y', 'I-Y']
        bio_labels = ['O'] + ['{}-{}'.format(tag, label) for label in ['X', 'Y'] for tag in 'BI']
        assert_transitions(dict(enumerate(bio_labels)),
                           {(0, 0), (0, 1), (0, 3), (0, 6), (1, 0), (1, 1), (1, 2), (1, 3), (1, 6), (2, 0),
                            (2, 1), (2, 2), (2, 3), (2, 6), (3, 0), (3, 1), (3, 3), (3, 4), (3, 6), (4, 0),
                            (4, 1), (4, 3), (4, 4), (4, 6), (5, 0), (5, 1), (5, 3)})

        # Typed BMES tags: ['B-X', 'M-X', 'E-X', 'S-X', 'B-Y', 'M-Y', 'E-Y', 'S-Y']
        bmes_labels = ['{}-{}'.format(tag, label) for label in ['X', 'Y'] for tag in 'BMES']
        assert_transitions(dict(enumerate(bmes_labels)),
                           {(0, 1), (0, 2), (1, 1), (1, 2), (2, 0), (2, 3), (2, 4), (2, 7), (2, 9), (3, 0),
                            (3, 3), (3, 4), (3, 7), (3, 9), (4, 5), (4, 6), (5, 5), (5, 6), (6, 0), (6, 3),
                            (6, 4), (6, 7), (6, 9), (7, 0), (7, 3), (7, 4), (7, 7), (7, 9), (8, 0), (8, 3),
                            (8, 4), (8, 7)},
                           encoding_type='BMES')
Example #4
0
    def __init__(self,
                 vocab_num,
                 embed_dim=100,
                 bigram_vocab_num=None,
                 bigram_embed_dim=100,
                 num_bigram_per_char=None,
                 embed_drop_p=0.3,
                 hidden_size=200,
                 kernel_size=3,
                 dilate='none',
                 num_layers=1,
                 num_heads=8,
                 tag_size=4,
                 relative_pos_embed_dim=0):
        """
        Sequence tagger: embeddings -> fc1 -> TransformerDilateEncoder -> fc2 -> CRF (BMES tags).

        :param vocab_num: character vocabulary size.
        :param embed_dim: character embedding dimension.
        :param bigram_vocab_num: bigram vocabulary size; if falsy, no bigram embedding is created.
        :param bigram_embed_dim: bigram embedding dimension.
        :param num_bigram_per_char: number of bigram embeddings per character; together with
            ``bigram_embed_dim`` it widens the input of ``fc1``. Required when
            ``bigram_vocab_num`` is set.
        :param embed_drop_p: dropout probability of ``self.drop``.
        :param hidden_size: output size of ``fc1``; model and hidden size of the transformer.
        :param kernel_size: forwarded to ``TransformerDilateEncoder``.
        :param dilate: forwarded to ``TransformerDilateEncoder``.
        :param num_layers: forwarded to ``TransformerDilateEncoder``.
        :param num_heads: forwarded to ``TransformerDilateEncoder``.
        :param tag_size: number of output tags (``fc2`` output size and CRF ``num_tags``). The CRF
            constraints cover exactly the four BMES tags (0=b, 1=m, 2=e, 3=s), so values other
            than 4 are presumably unsupported — TODO confirm.
        :param relative_pos_embed_dim: forwarded to ``TransformerDilateEncoder``.
        :raises TypeError: if ``bigram_vocab_num`` is set but ``num_bigram_per_char`` is None.
        """
        super().__init__()

        self.embedding = nn.Embedding(vocab_num, embed_dim)
        input_size = embed_dim
        if bigram_vocab_num:
            if num_bigram_per_char is None:
                # Without this check the failure is an opaque "NoneType * int" TypeError below.
                raise TypeError('num_bigram_per_char must be given when bigram_vocab_num is set')
            self.bigram_embedding = nn.Embedding(bigram_vocab_num, bigram_embed_dim)
            # Presumably the bigram embeddings are concatenated to the character embedding in forward.
            input_size += num_bigram_per_char * bigram_embed_dim

        # inplace=True: dropout overwrites its input tensor instead of allocating a new one.
        self.drop = nn.Dropout(embed_drop_p, inplace=True)

        self.fc1 = nn.Linear(input_size, hidden_size)
        self.transformer = TransformerDilateEncoder(
            num_layers=num_layers,
            model_size=hidden_size,
            num_heads=num_heads,
            hidden_size=hidden_size,
            kernel_size=kernel_size,
            dilate=dilate,
            relative_pos_embed_dim=relative_pos_embed_dim)
        self.fc2 = nn.Linear(hidden_size, tag_size)

        # Restrict CRF transitions to those legal under the BMES scheme.
        allowed_trans = allowed_transitions({0: 'b', 1: 'm', 2: 'e', 3: 's'}, encoding_type='bmes')
        self.crf = ConditionalRandomField(num_tags=tag_size,
                                          include_start_end_trans=False,
                                          allowed_transitions=allowed_trans)
Example #5
0
    def test_case12(self):
        # Check that allowed_transitions() accepts a Vocabulary of tags.
        # NOTE(review): encoding_type is never passed here, so these checks rely on
        # allowed_transitions inferring the tagging scheme from the vocabulary's tags.
        from fastNLP.modules.decoder.crf import allowed_transitions

        def tags_to_vocab(tags, **vocab_kwargs):
            # Build a Vocabulary(**vocab_kwargs) and add ``tags`` in the given order.
            vocab = Vocabulary(**vocab_kwargs)
            for tag in tags:
                vocab.add_word(tag)
            return vocab

        # Bare BIO tags.
        vocab = tags_to_vocab(['B', 'I', 'O'], unknown=None, padding=None)
        expected_res = {(0, 0), (0, 1), (0, 2), (0, 4), (1, 0), (1, 1), (1, 2), (1, 4), (2, 0), (2, 2),
                        (2, 4), (3, 0), (3, 2)}
        self.assertSetEqual(expected_res, set(allowed_transitions(vocab, include_start_end=True)))

        # Bare BMES tags.
        vocab = tags_to_vocab(['B', 'M', 'E', 'S'], unknown=None, padding=None)
        expected_res = {(0, 1), (0, 2), (1, 1), (1, 2), (2, 0), (2, 3), (2, 5), (3, 0), (3, 3), (3, 5), (4, 0), (4, 3)}
        self.assertSetEqual(expected_res, set(allowed_transitions(vocab, include_start_end=True)))

        # Default Vocabulary (with its own padding/unknown entries) plus BIO tags and special
        # tokens: the result is discarded, this only checks that the call does not raise.
        vocab = tags_to_vocab(['B', 'I', 'O', '<pad>', '<unk>'])
        allowed_transitions(vocab, include_start_end=True)

        # Typed BIO tags: ['O', 'B-X', 'I-X', 'B-Y', 'I-Y']
        labels = ['O'] + ['{}-{}'.format(tag, label) for label in ['X', 'Y'] for tag in 'BI']
        expected_res = {(0, 0), (0, 1), (0, 3), (0, 6), (1, 0), (1, 1), (1, 2), (1, 3), (1, 6), (2, 0), (2, 1),
                        (2, 2), (2, 3), (2, 6), (3, 0), (3, 1), (3, 3), (3, 4), (3, 6), (4, 0), (4, 1), (4, 3),
                        (4, 4), (4, 6), (5, 0), (5, 1), (5, 3)}
        vocab = tags_to_vocab(labels, unknown=None, padding=None)
        self.assertSetEqual(expected_res, set(allowed_transitions(vocab, include_start_end=True)))

        # Typed BMES tags: ['B-X', 'M-X', 'E-X', 'S-X', 'B-Y', 'M-Y', 'E-Y', 'S-Y']
        labels = ['{}-{}'.format(tag, label) for label in ['X', 'Y'] for tag in 'BMES']
        vocab = tags_to_vocab(labels, unknown=None, padding=None)
        expected_res = {(0, 1), (0, 2), (1, 1), (1, 2), (2, 0), (2, 3), (2, 4), (2, 7), (2, 9), (3, 0), (3, 3), (3, 4),
                        (3, 7), (3, 9), (4, 5), (4, 6), (5, 5), (5, 6), (6, 0), (6, 3), (6, 4), (6, 7), (6, 9), (7, 0),
                        (7, 3), (7, 4), (7, 7), (7, 9), (8, 0), (8, 3), (8, 4), (8, 7)}
        self.assertSetEqual(expected_res, set(allowed_transitions(vocab, include_start_end=True)))