示例#1
0
    def test_auto_encoding_type_infer(self):
        """Check that SpanFPreRecMetric infers the encoding type from the tag vocabulary.

        One random vocabulary is built per scheme ('bio', 'bioes', 'bmeso') and
        each must be detected as that scheme. A 'bmes' vocabulary built from
        ``_generate_tags`` must be detected as 'bmes'. Finally, a vocabulary of
        bare digits, which matches no scheme, must make construction raise.
        """
        import random

        schemes = ['bio', 'bioes', 'bmeso']
        vocab_by_scheme = {}
        for scheme in schemes:
            scheme_vocab = Vocabulary(unknown=None, padding=None)
            n_rounds = random.randint(10, 100)
            for _ in range(n_rounds):
                label = str(random.randint(1, 10))
                for prefix in scheme:
                    # 'o' is a bare tag; every other prefix is joined to a label.
                    scheme_vocab.add_word('o' if prefix == 'o' else f'{prefix}-{label}')
            vocab_by_scheme[scheme] = scheme_vocab

        for scheme in schemes:
            with self.subTest(e=scheme):
                inferred = SpanFPreRecMetric(tag_vocab=vocab_by_scheme[scheme])
                assert inferred.encoding_type == scheme

        bmes_vocab = Vocabulary()
        for tag, _index in _generate_tags('bmes').items():
            bmes_vocab.add_word(tag)
        assert SpanFPreRecMetric(bmes_vocab).encoding_type == 'bmes'

        # Cases that cannot be inferred: plain digits carry no tag prefix.
        digit_vocab = Vocabulary()
        for digit in range(10):
            digit_vocab.add_word(str(digit))
        with self.assertRaises(Exception):
            SpanFPreRecMetric(digit_vocab)
示例#2
0
    def test_encoding_type(self):
        """Check that a tag_vocab inconsistent with ``encoding_type`` is reported.

        - If the vocabulary contains a tag prefix that the declared encoding does
          not define, construction must raise ``AssertionError``.
        - If the vocabulary only lacks some of the declared encoding's tags,
          construction must emit a ``Warning``.
        """
        import random
        from itertools import product

        schemes = ['bio', 'bioes', 'bmeso']
        vocabs = {}
        for encoding_type in schemes:
            vocab = Vocabulary(unknown=None, padding=None)
            for i in range(random.randint(10, 100)):
                label = str(random.randint(1, 10))
                for tag in encoding_type:
                    if tag != 'o':
                        vocab.add_word(f'{tag}-{label}')
                    else:
                        vocab.add_word('o')
            vocabs[encoding_type] = vocab

        for e1, e2 in product(schemes, schemes):
            with self.subTest(e1=e1, e2=e2):
                if e1 == e2:
                    SpanFPreRecMetric(vocabs[e1], encoding_type=e2)
                    continue
                # If every tag of e1 also exists in e2, the vocabulary is merely
                # incomplete for e2 (a warning, not an error), so skip it here.
                if set(e1) <= set(e2):
                    continue
                with self.assertRaises(AssertionError):
                    SpanFPreRecMetric(vocabs[e1], encoding_type=e2)

        for encoding_type in schemes:
            with self.assertRaises(AssertionError):
                SpanFPreRecMetric(vocabs[encoding_type], encoding_type='bmes')

        # Each case gets its own assertWarns. Inside one shared context the first
        # warning would satisfy the assertion and the second case would go
        # unchecked.
        with self.assertWarns(Warning):
            vocab = Vocabulary(unknown=None, padding=None).add_word_lst(list('bmes'))
            SpanFPreRecMetric(vocab, encoding_type='bmeso')
        with self.assertWarns(Warning):
            vocab = Vocabulary().add_word_lst(list('bmes'))
            SpanFPreRecMetric(vocab, encoding_type='bmeso')
示例#3
0
    def test(self, file_path):
        """Run the saved pipeline on a CoNLL-X file and score it with SpanFPreRecMetric.

        :param file_path: path of the test file, read with ``ConllxDataLoader``.
        :return: dict with keys ``"F1"``, ``"precision"`` and ``"recall"``. Each
            holds the metric's ``f`` / ``pre`` / ``rec`` multiplied by 100 and
            rounded to 2 decimals.
        """
        test_data = ConllxDataLoader().load(file_path)

        save_dict = self._dict
        tag_vocab = save_dict["tag_vocab"]
        pipeline = save_dict["pipeline"]
        # Index the gold "tag" strings into a new "truth" field.
        index_tag = IndexerProcessor(vocab=tag_vocab,
                                     field_name="tag",
                                     new_added_field_name="truth",
                                     is_input=False)

        test_data.rename_field("pos_tags", "tag")
        # Prepend the indexer for this run only, then restore the saved step
        # list. Assigning it permanently would stack another indexer onto
        # self._dict["pipeline"] on every call.
        saved_steps = pipeline.pipeline
        pipeline.pipeline = [index_tag] + saved_steps
        try:
            pipeline(test_data)
        finally:
            pipeline.pipeline = saved_steps

        test_data.set_target("truth")
        prediction = test_data.field_arrays["predict"].content
        truth = test_data.field_arrays["truth"].content
        seq_len = test_data.field_arrays["word_seq_origin_len"].content

        # Pad by hand to one length covering both predictions and gold tags so
        # both can be turned into rectangular tensors.
        max_length = max(max(len(seq) for seq in prediction),
                         max(len(seq) for seq in truth))
        for idx in range(len(prediction)):
            prediction[idx] = list(prediction[idx]) + [0] * (max_length - len(prediction[idx]))
            truth[idx] = list(truth[idx]) + [0] * (max_length - len(truth[idx]))

        evaluator = SpanFPreRecMetric(tag_vocab=tag_vocab,
                                      pred="predict",
                                      target="truth",
                                      seq_lens="word_seq_origin_len")
        evaluator(
            {
                "predict": torch.Tensor(prediction),
                "word_seq_origin_len": torch.Tensor(seq_len)
            }, {"truth": torch.Tensor(truth)})
        test_result = evaluator.get_metric()
        f1 = round(test_result['f'] * 100, 2)
        pre = round(test_result['pre'] * 100, 2)
        rec = round(test_result['rec'] * 100, 2)

        return {"F1": f1, "precision": pre, "recall": rec}
示例#4
0
    def test_case3(self):
        """Check SpanFPreRecMetric on a fixed BIO example, including per-label scores.

        The metric is built with ``only_gross=False``; the expected result below
        contains per-label ``pre-*``/``rec-*``/``f-*`` entries as well as the
        overall ``pre``/``rec``/``f``.
        """
        number_labels = 4
        # bio tag
        fastnlp_bio_vocab = Vocabulary(unknown=None, padding=None)
        # word_count is assigned directly from _generate_tags; presumably its
        # values act as counts that fix the index order (assumption, confirm).
        fastnlp_bio_vocab.word_count = Counter(_generate_tags('BIO', number_labels))
        fastnlp_bio_metric = SpanFPreRecMetric(tag_vocab=fastnlp_bio_vocab, only_gross=False)
        # Fixed logits: 2 sequences x 6 steps x 9 tag scores.
        bio_sequence = torch.FloatTensor([[[-0.4424, -0.4579, -0.7376,  1.8129,  0.1316,  1.6566, -1.2169,
          -0.3782,  0.8240],
         [-1.2348, -0.1876, -0.1462, -0.4834, -0.6692, -0.9735,  1.1563,
          -0.3562, -1.4116],
         [ 1.6550, -0.9555,  0.3782, -1.3160, -1.5835, -0.3443, -1.7858,
           2.0023,  0.7075],
         [-0.3772, -0.5447, -1.5631,  1.1614,  1.4598, -1.2764,  0.5186,
           0.3832, -0.1540],
         [-0.1011,  0.0600,  1.1090, -0.3545,  0.1284,  1.1484, -1.0120,
          -1.3508, -0.9513],
         [ 1.8948,  0.8627, -2.1359,  1.3740, -0.7499,  1.5019,  0.6919,
          -0.0842, -0.4294]],

        [[-0.2802,  0.6941, -0.4788, -0.3845,  1.7752,  1.2950, -1.9490,
          -1.4138, -0.8853],
         [-1.3752, -0.5457, -0.5305,  0.4018,  0.2934,  0.7931,  2.3845,
          -1.0726,  0.0364],
         [ 0.3621,  0.2609,  0.1269, -0.5950,  0.7212,  0.5959,  1.6264,
          -0.8836, -0.9320],
         [ 0.2003, -1.0758, -1.1560, -0.6472, -1.7549,  0.1264,  0.6044,
          -1.6857,  1.1571],
         [ 1.4277, -0.4915,  0.4496,  2.2027,  0.0730, -3.1792, -0.5125,
          -0.5837,  1.0184],
         [ 1.9495,  1.7145, -0.2143, -0.1230, -0.2205,  0.8250,  0.4943,
          -0.9025,  0.0864]]])
        # Gold tag indices, one row per sequence.
        bio_target = torch.LongTensor([[3, 6, 0, 8, 2, 4],
                                        [4, 1, 7, 0, 4, 7]])
        fastnlp_bio_metric({'pred': bio_sequence, 'seq_len': torch.LongTensor([6, 6])}, {'target': bio_target})
        # Expected values are written to 6 decimals and compared exactly, so this
        # relies on the metric rounding its output (assumption, confirm).
        expect_bio_res = {'pre-1': 0.333333, 'rec-1': 0.333333, 'f-1': 0.333333, 'pre-2': 0.5, 'rec-2': 0.5,
                          'f-2': 0.5, 'pre-0': 0.0, 'rec-0': 0.0, 'f-0': 0.0, 'pre-3': 0.0, 'rec-3': 0.0,
                          'f-3': 0.0, 'pre': 0.222222, 'rec': 0.181818, 'f': 0.2}

        self.assertDictEqual(expect_bio_res, fastnlp_bio_metric.get_metric())
示例#5
0
    def test(self, filepath):
        """Evaluate word segmentation on a file and return F1, precision and recall.

        The file should look like::

            1	编者按	编者按	NN	O	11	nmod:topic
            2	:	:	PU	O	11	punct
            3	7月	7月	NT	DATE	4	compound:nn
            4	12日	12日	NT	DATE	11	nmod:tmod
            5	,	,	PU	O	11	punct

            1	这	这	DT	O	3	det
            2	款	款	M	O	1	mark:clf
            3	飞行	飞行	NN	O	8	nsubj
            4	从	从	P	O	5	case
            5	外型	外型	NN	O	8	nmod:prep

        Sentences are separated by a blank line; every non-empty line has 7 columns.

        :param filepath: str, path of the file to evaluate.
        :return: dict ``{"F1": f, "precision": pre, "recall": rec}``, holding the
            ``f``, ``pre`` and ``rec`` values reported by ``SpanFPreRecMetric``.
        """
        tag_proc = self._dict['tag_proc']
        # The second-to-last pipeline step holds the segmentation model. The
        # Tester runs that model itself, so it and the final step are dropped
        # from the preprocessing pipeline below.
        cws_model = self.pipeline.pipeline[-2].model
        # Slicing copies the step list, so the insert below leaves
        # self.pipeline unchanged.
        pipeline = self.pipeline.pipeline[:-2]

        # Run the tag processor right after the first step (presumably so the
        # gold tags get indexed for the metric; confirm).
        pipeline.insert(1, tag_proc)
        pp = Pipeline(pipeline)

        reader = ConllCWSReader()
        te_dataset = reader.load(filepath)
        pp(te_dataset)

        from ..core.tester import Tester
        from ..core.metrics import SpanFPreRecMetric

        tester = Tester(data=te_dataset,
                        model=cws_model,
                        metrics=SpanFPreRecMetric(tag_proc.get_vocab()),
                        batch_size=64,
                        verbose=0)
        eval_res = tester.test()

        # Tester results are keyed by the metric's class name.
        span_res = eval_res['SpanFPreRecMetric']
        f1 = span_res['f']
        pre = span_res['pre']
        rec = span_res['rec']

        return {"F1": f1, "precision": pre, "recall": rec}
                if args.init == 'uniform':
                    nn.init.xavier_uniform_(p)
                    print_info('xavier uniform init:{}'.format(n))
                elif args.init == 'norm':
                    print_info('xavier norm init:{}'.format(n))
                    nn.init.xavier_normal_(p)
            except:
                print_info(n)
                exit(1208)
    print_info('{}init pram{}'.format('*' * 15, '*' * 15))

loss = LossInForward()
encoding_type = 'bmeso'

# Span-level F/P/R over the 'label' vocabulary, decoded with the BMESO scheme.
f1_metric = SpanFPreRecMetric(
    vocabs['label'],
    pred='pred',
    target='target',
    seq_len='seq_len',
    encoding_type=encoding_type,
)

# Token accuracy on the same predictions, reported as 'label_acc'.
acc_metric = AccuracyMetric(pred='pred', target='target', seq_len='seq_len')
acc_metric.set_metric_name('label_acc')

metrics = [f1_metric, acc_metric]

# With self-supervision enabled, also track accuracy of 'chars_pred' against
# 'chars_target', reported as 'chars_acc'.
if args.self_supervised:
    chars_acc_metric = AccuracyMetric(
        pred='chars_pred',
        target='chars_target',
        seq_len='seq_len',
    )
    chars_acc_metric.set_metric_name('chars_acc')
    metrics += [chars_acc_metric]
示例#7
0
    def test_case4(self):
        """Compare SpanFPreRecMetric on BMES tags with a hand-written reference.

        Three random BMES-segmented tag sequences are generated and random
        logits are used as predictions. The f / pre / rec computed by
        ``get_eval`` must match the metric to 5 decimal places.
        """
        # bmes tag
        def _generate_samples():
            """Build 3 random BMES tag sequences of 9 words each.

            :return: (index matrix zero-padded to the longest sequence, raw tag
                lists, per-sequence lengths, Vocabulary containing the tags)
            """
            target = []
            seq_len = []
            vocab = Vocabulary(unknown=None, padding=None)
            for i in range(3):
                target_i = []
                seq_len_i = 0
                for _ in range(9):
                    # Word length 1-4: 'S' for a single tag, otherwise B M* E.
                    word_len = np.random.randint(1, 5)
                    seq_len_i += word_len
                    if word_len == 1:
                        target_i.append('S')
                    else:
                        target_i.append('B')
                        target_i.extend(['M'] * (word_len - 2))
                        target_i.append('E')
                vocab.add_word_lst(target_i)
                target.append(target_i)
                seq_len.append(seq_len_i)
            target_ = np.zeros((3, max(seq_len)))
            for i in range(3):
                target_i = [vocab.to_index(t) for t in target[i]]
                target_[i, :seq_len[i]] = target_i
            return target_, target, seq_len, vocab

        def get_eval(raw_target, pred, vocab, seq_len):
            """Reference span scores for BMES predictions.

            Argmax tags are decoded into spans. An 'M'/'E' that does not follow
            'B'/'M' starts a new span. The spans are re-rendered as canonical
            BMES, and a gold span counts as a true positive when every tag inside
            it matches the re-rendered prediction.
            """
            pred = pred.argmax(dim=-1).tolist()
            tp = 0
            gold = 0
            seg = 0
            pred_target = []
            for i in range(len(seq_len)):
                tags = [vocab.to_word(p) for p in pred[i][:seq_len[i]]]
                spans = []
                prev_bmes_tag = None
                for idx, tag in enumerate(tags):
                    if tag in ('B', 'S'):
                        spans.append([idx, idx])
                    elif tag in ('M', 'E') and prev_bmes_tag in ('B', 'M'):
                        spans[-1][1] = idx
                    else:
                        spans.append([idx, idx])
                    prev_bmes_tag = tag
                tmp = []
                for span in spans:
                    if span[1] - span[0] > 0:
                        tmp.extend(['B'] + ['M'] * (span[1] - span[0] - 1) +
                                   ['E'])
                    else:
                        tmp.append('S')
                pred_target.append(tmp)
            for i in range(len(seq_len)):
                raw_pred = pred_target[i]
                start = 0
                for j in range(seq_len[i]):
                    # 'E'/'S' closes a gold span [start, j].
                    if raw_target[i][j] in ('E', 'S'):
                        flag = True
                        for k in range(start, j + 1):
                            if raw_target[i][k] != raw_pred[k]:
                                flag = False
                                break
                        if flag:
                            tp += 1
                        start = j + 1
                        gold += 1
                    if raw_pred[j] in ('E', 'S'):
                        seg += 1

            pre = round(tp / seg, 6)
            rec = round(tp / gold, 6)
            # With no true positives pre + rec == 0. Report f = 0 instead of
            # raising ZeroDivisionError (presumably what the metric reports too;
            # confirm).
            f = round(2 * pre * rec / (pre + rec), 6) if pre + rec else 0.0
            return {
                'f': f,
                'pre': pre,
                'rec': rec
            }

        target, raw_target, seq_len, vocab = _generate_samples()
        pred = torch.randn(3, max(seq_len), 4)

        expected_metric = get_eval(raw_target, pred, vocab, seq_len)
        metric = SpanFPreRecMetric(vocab, encoding_type='bmes')
        metric({
            'pred': pred,
            'seq_len': torch.LongTensor(seq_len)
        }, {'target': torch.from_numpy(target)})
        metric_value = metric.get_metric()
        for key, value in expected_metric.items():
            self.assertAlmostEqual(value, metric_value[key], places=5)
示例#8
0
    def tese_case3(self):
        """Compare SpanFPreRecMetric with fixed expected values for BIO and BMES tags.

        NOTE(review): the name starts with ``tese_`` rather than ``test_``, so
        unittest's default loader does not collect this method. Confirm whether
        it is disabled on purpose before renaming it.
        """
        from fastNLP.core.vocabulary import Vocabulary
        from collections import Counter
        from fastNLP.core.metrics import SpanFPreRecMetric

        # Check against allennlp that the f metric is computed correctly.
        #
        def generate_allen_tags(encoding_type, number_labels=4):
            """Map '{tag}-{label}' (and a single 'O') to increasing integers.

            The integers are later assigned as the Vocabulary's word counts.
            """
            vocab = {}
            for i in range(number_labels):
                label = str(i)
                for tag in encoding_type:
                    if tag == 'O':
                        if tag not in vocab:
                            vocab['O'] = len(vocab) + 1
                        continue
                    vocab['{}-{}'.format(
                        tag, label)] = len(vocab) + 1  # the value actually serves as this tag's count
            return vocab

        number_labels = 4
        # bio tag
        fastnlp_bio_vocab = Vocabulary(unknown=None, padding=None)
        fastnlp_bio_vocab.word_count = Counter(
            generate_allen_tags('BIO', number_labels))
        fastnlp_bio_metric = SpanFPreRecMetric(tag_vocab=fastnlp_bio_vocab,
                                               only_gross=False)
        # Fixed logits: 2 sequences x 5 steps x 9 tag scores.
        bio_sequence = torch.FloatTensor(
            [[[
                -0.9543, -1.4357, -0.2365, 0.2438, 1.0312, -1.4302, 0.3011,
                0.0470, 0.0971
            ],
              [
                  -0.6638, -0.7116, -1.9804, 0.2787, -0.2732, -0.9501, -1.4523,
                  0.7987, -0.3970
              ],
              [
                  0.2939, 0.8132, -0.0903, -2.8296, 0.2080, -0.9823, -0.1898,
                  0.6880, 1.4348
              ],
              [
                  -0.1886, 0.0067, -0.6862, -0.4635, 2.2776, 0.0710, -1.6793,
                  -1.6876, -0.8917
              ],
              [
                  -0.7663, 0.6377, 0.8669, 0.1237, 1.7628, 0.0313, -1.0824,
                  1.4217, 0.2622
              ]],
             [[
                 0.1529, 0.7474, -0.9037, 1.5287, 0.2771, 0.2223, 0.8136,
                 1.3592, -0.8973
             ],
              [
                  0.4515, -0.5235, 0.3265, -1.1947, 0.8308, 1.8754, -0.4887,
                  -0.4025, -0.3417
              ],
              [
                  -0.7855, 0.1615, -0.1272, -1.9289, -0.5181, 1.9742, -0.9698,
                  0.2861, -0.3966
              ],
              [
                  -0.8291, -0.8823, -1.1496, 0.2164, 1.3390, -0.3964, -0.5275,
                  0.0213, 1.4777
              ],
              [
                  -1.1299, 0.0627, -0.1358, -1.5951, 0.4484, -0.6081, -1.9566,
                  1.3024, 0.2001
              ]]])
        bio_target = torch.LongTensor([[5., 0., 3., 3., 3.],
                                       [5., 6., 8., 6., 0.]])
        # NOTE(review): this passes key 'seq_lens', while test_case3 passes
        # 'seq_len'. Check which key the metric expects.
        fastnlp_bio_metric(
            {
                'pred': bio_sequence,
                'seq_lens': torch.LongTensor([5, 5])
            }, {'target': bio_target})
        # NOTE(review): unrounded floats compared exactly with assertDictEqual;
        # brittle if the metric rounds its output.
        expect_bio_res = {
            'pre-1': 0.24999999999999373,
            'rec-1': 0.499999999999975,
            'f-1': 0.33333333333327775,
            'pre-2': 0.0,
            'rec-2': 0.0,
            'f-2': 0.0,
            'pre-3': 0.0,
            'rec-3': 0.0,
            'f-3': 0.0,
            'pre-0': 0.0,
            'rec-0': 0.0,
            'f-0': 0.0,
            'pre': 0.12499999999999845,
            'rec': 0.12499999999999845,
            'f': 0.12499999999994846
        }
        self.assertDictEqual(expect_bio_res, fastnlp_bio_metric.get_metric())

        #bmes tag
        # Fixed logits: 2 sequences x 5 steps x 16 tag scores.
        bmes_sequence = torch.FloatTensor(
            [[[
                0.6536, -0.7179, 0.6579, 1.2503, 0.4176, 0.6696, 0.2352,
                -0.4085, 0.4084, -0.4185, 1.4172, -0.9162, -0.2679, 0.3332,
                -0.3505, -0.6002
            ],
              [
                  0.3238, -1.2378, -1.3304, -0.4903, 1.4518, -0.1868, -0.7641,
                  1.6199, -0.8877, 0.1449, 0.8995, -0.5810, 0.1041, 0.1002,
                  0.4439, 0.2514
              ],
              [
                  -0.8362, 2.9526, 0.8008, 0.1193, 1.0488, 0.6670, 1.1696,
                  -1.1006, -0.8540, -0.1600, -0.9519, -0.2749, -0.4948,
                  -1.4753, 0.5802, -0.0516
              ],
              [
                  -0.8383, -1.7292, -1.4079, -1.5023, 0.5383, 0.6653, 0.3121,
                  4.1249, -0.4173, -0.2043, 1.7755, 1.1110, -1.7069, -0.0390,
                  -0.9242, -0.0333
              ],
              [
                  0.9088, -0.4955, -0.5076, 0.3732, 0.0283, -0.0263, -1.0393,
                  0.7734, 1.0968, 0.4132, -1.3647, -0.5762, 0.6678, 0.8809,
                  -0.3779, -0.3195
              ]],
             [[
                 -0.4638, -0.5939, -0.1052, -0.5573, 0.4600, -1.3484, 0.1753,
                 0.0685, 0.3663, -0.6789, 0.0097, 1.0327, -0.0212, -0.9957,
                 -0.1103, 0.4417
             ],
              [
                  -0.2903, 0.9205, -1.5758, -1.0421, 0.2921, -0.2142, -0.3049,
                  -0.0879, -0.4412, -1.3195, -0.0657, -0.2986, 0.7214, 0.0631,
                  -0.6386, 0.2797
              ],
              [
                  0.6440, -0.3748, 1.2912, -0.0170, 0.7447, 1.4075, -0.4947,
                  0.4123, -0.8447, -0.5502, 0.3520, -0.2832, 0.5019, -0.1522,
                  1.1237, -1.5385
              ],
              [
                  0.2839, -0.7649, 0.9067, -0.1163, -1.3789, 0.2571, -1.3977,
                  -0.3680, -0.8902, -0.6983, -1.1583, 1.2779, 0.2197, 0.1376,
                  -0.0591, -0.2461
              ],
              [
                  -0.2977, -1.8564, -0.5347, 1.0011, -1.1260, 0.4252, -2.0097,
                  2.6973, -0.8308, -1.4939, 0.9865, -0.3935, 0.2743, 0.1142,
                  -0.7344, -1.2046
              ]]])
        bmes_target = torch.LongTensor([[9., 6., 1., 9., 15.],
                                        [6., 15., 6., 15., 5.]])

        fastnlp_bmes_vocab = Vocabulary(unknown=None, padding=None)
        fastnlp_bmes_vocab.word_count = Counter(
            generate_allen_tags('BMES', number_labels))
        fastnlp_bmes_metric = SpanFPreRecMetric(tag_vocab=fastnlp_bmes_vocab,
                                                only_gross=False,
                                                encoding_type='bmes')
        # NOTE(review): lengths of 20 exceed the 5 time steps in bmes_sequence;
        # confirm this is intended.
        fastnlp_bmes_metric(
            {
                'pred': bmes_sequence,
                'seq_lens': torch.LongTensor([20, 20])
            }, {'target': bmes_target})

        expect_bmes_res = {
            'f-3': 0.6666666666665778,
            'pre-3': 0.499999999999975,
            'rec-3': 0.9999999999999001,
            'f-0': 0.0,
            'pre-0': 0.0,
            'rec-0': 0.0,
            'f-1': 0.33333333333327775,
            'pre-1': 0.24999999999999373,
            'rec-1': 0.499999999999975,
            'f-2': 0.7499999999999314,
            'pre-2': 0.7499999999999812,
            'rec-2': 0.7499999999999812,
            'f': 0.49999999999994504,
            'pre': 0.499999999999995,
            'rec': 0.499999999999995
        }

        self.assertDictEqual(fastnlp_bmes_metric.get_metric(), expect_bmes_res)
示例#9
0
def train(train_data_path, dev_data_path, checkpoint=None, save=None):
    """Train an AdvSeqLabel tagger and save the model together with its pipeline.

    Config sections ``train`` and ``model`` are loaded from ``cfgfile``. The
    model section supplies the model hyper-parameters; vocab_size and
    num_classes are filled in from the data.

    :param train_data_path: CoNLL-X file used for training.
    :param dev_data_path: CoNLL-X file used as the development set.
    :param checkpoint: optional path of a model saved with ``torch.save``. When
        given, that model is loaded instead of building a new one.
    :param save: directory passed to the Trainer as ``save_path`` and where
        ``model_pp.pkl`` is written. When None, the Trainer still receives None
        and ``model_pp.pkl`` is written to the current directory, instead of
        crashing in ``os.path.join``.
    """
    # load config
    train_param = ConfigSection()
    model_param = ConfigSection()
    ConfigLoader().load_config(cfgfile, {
        "train": train_param,
        "model": model_param
    })
    print("config loaded")

    # Data Loader
    print("loading training set...")
    dataset = ConllxDataLoader().load(train_data_path, return_dataset=True)
    print("loading dev set...")
    dev_data = ConllxDataLoader().load(dev_data_path, return_dataset=True)
    print(dataset)
    print("================= dataset ready =====================")

    dataset.rename_field("tag", "truth")
    dev_data.rename_field("tag", "truth")

    vocab_proc = VocabIndexerProcessor("words",
                                       new_added_filed_name="word_seq")
    tag_proc = VocabIndexerProcessor("truth", is_input=True)
    seq_len_proc = SeqLenProcessor(field_name="word_seq",
                                   new_added_field_name="word_seq_origin_len",
                                   is_input=True)
    set_input_proc = SetInputProcessor("word_seq", "word_seq_origin_len")

    vocab_proc(dataset)
    tag_proc(dataset)
    seq_len_proc(dataset)

    # Index the dev set with the vocabularies built from the training set.
    word_vocab, tag_vocab = vocab_proc.vocab, tag_proc.vocab
    dev_data.apply(lambda ins: [word_vocab.to_index(w) for w in ins["words"]],
                   new_field_name="word_seq")
    dev_data.apply(lambda ins: [tag_vocab.to_index(w) for w in ins["truth"]],
                   new_field_name="truth")
    dev_data.apply(lambda ins: len(ins["word_seq"]),
                   new_field_name="word_seq_origin_len")

    # set input & target
    dataset.set_input("word_seq", "word_seq_origin_len", "truth")
    dev_data.set_input("word_seq", "word_seq_origin_len", "truth")
    dataset.set_target("truth", "word_seq_origin_len")
    dev_data.set_target("truth", "word_seq_origin_len")

    model_param["vocab_size"] = vocab_proc.get_vocab_size()
    model_param["num_classes"] = tag_proc.get_vocab_size()
    print("vocab_size={}  num_classes={}".format(model_param["vocab_size"],
                                                 model_param["num_classes"]))

    # define a model
    if checkpoint is None:
        pre_trained = None
        model = AdvSeqLabel(model_param, id2words=None, emb=pre_trained)
        print(model)
    else:
        model = torch.load(checkpoint)

    # call trainer to train
    trainer = Trainer(dataset,
                      model,
                      loss=None,
                      metrics=SpanFPreRecMetric(
                          tag_proc.vocab,
                          pred="predict",
                          target="truth",
                          seq_lens="word_seq_origin_len"),
                      dev_data=dev_data,
                      metric_key="f",
                      use_tqdm=True,
                      use_cuda=True,
                      print_every=10,
                      n_epochs=20,
                      save_path=save)
    trainer.train(load_best_model=True)

    # save model & pipeline
    model_proc = ModelProcessor(model,
                                seq_len_field_name="word_seq_origin_len")
    id2tag = Index2WordProcessor(tag_proc.vocab, "predict", "tag")

    pp = Pipeline(
        [vocab_proc, seq_len_proc, set_input_proc, model_proc, id2tag])
    save_dict = {"pipeline": pp, "model": model, "tag_vocab": tag_proc.vocab}
    # os.path.join(None, ...) raises TypeError, so fall back to the current
    # directory when no save directory was given.
    pipeline_dir = save if save is not None else "."
    torch.save(save_dict, os.path.join(pipeline_dir, "model_pp.pkl"))
    print("pipeline saved")
示例#10
0
def train(checkpoint=None, train_path="/home/hyan/train.conllx", save_dir="./save"):
    """Train an AdvSeqLabel POS tagger and save the model and its pipeline.

    Config sections ``train`` and ``model`` are loaded from ``cfgfile``. The
    training set is also passed as ``dev_data``, so the reported dev scores are
    training-set scores.

    :param checkpoint: optional path of a model saved with ``torch.save``. When
        given, that model is loaded instead of building a new one.
    :param train_path: file read with ``ZhConllPOSReader``. Defaults to the
        previously hard-coded path.
    :param save_dir: Trainer ``save_path``; ``best_model.pkl`` is also written
        inside it. Defaults to the previously hard-coded ``"./save"``.
        ``model_pp.pkl`` is still written to the current directory.
    """
    # load config
    train_param = ConfigSection()
    model_param = ConfigSection()
    ConfigLoader().load_config(cfgfile, {
        "train": train_param,
        "model": model_param
    })
    print("config loaded")

    # Data Loader
    dataset = ZhConllPOSReader().load(train_path)
    print(dataset)
    print("dataset transformed")

    dataset.rename_field("tag", "truth")

    vocab_proc = VocabIndexerProcessor("words",
                                       new_added_filed_name="word_seq")
    tag_proc = VocabIndexerProcessor("truth")
    seq_len_proc = SeqLenProcessor(field_name="word_seq",
                                   new_added_field_name="word_seq_origin_len",
                                   is_input=True)

    vocab_proc(dataset)
    tag_proc(dataset)
    seq_len_proc(dataset)

    dataset.set_input("word_seq", "word_seq_origin_len", "truth")
    dataset.set_target("truth", "word_seq_origin_len")

    print("processors defined")

    model_param["vocab_size"] = vocab_proc.get_vocab_size()
    model_param["num_classes"] = tag_proc.get_vocab_size()
    print("vocab_size={}  num_classes={}".format(model_param["vocab_size"],
                                                 model_param["num_classes"]))

    # define a model
    if checkpoint is None:
        pre_trained = None
        model = AdvSeqLabel(model_param,
                            id2words=tag_proc.vocab.idx2word,
                            emb=pre_trained)
        print(model)
    else:
        model = torch.load(checkpoint)

    # call trainer to train
    trainer = Trainer(dataset,
                      model,
                      loss=None,
                      metrics=SpanFPreRecMetric(
                          tag_proc.vocab,
                          pred="predict",
                          target="truth",
                          seq_lens="word_seq_origin_len"),
                      dev_data=dataset,
                      metric_key="f",
                      use_tqdm=True,
                      use_cuda=True,
                      print_every=5,
                      n_epochs=6,
                      save_path=save_dir)
    trainer.train(load_best_model=True)

    # save model & pipeline
    model_proc = ModelProcessor(model,
                                seq_len_field_name="word_seq_origin_len")
    id2tag = Index2WordProcessor(tag_proc.vocab, "predict", "tag")

    pp = Pipeline([vocab_proc, seq_len_proc, model_proc, id2tag])
    save_dict = {"pipeline": pp, "model": model, "tag_vocab": tag_proc.vocab}
    torch.save(save_dict, "model_pp.pkl")
    print("pipeline saved")

    torch.save(model, "{}/best_model.pkl".format(save_dir))
示例#11
0

# Model hyper-parameters: class count comes from vocab[1], vocabulary size
# from vocab[0].
args = dict(
    word_emb_dim=300,
    rnn_hidden_units=300,
    num_classes=len(vocab[1]),
    init_embedding=embedding,
    vocab_size=len(vocab[0]),
)

print(args)


model = SeqLabelingForSLSTM(args)

# Use DataParallel only when more than one CUDA device is visible.
if torch.cuda.device_count() > 1:
    model = torch.nn.DataParallel(model)

# 'ner' is scored with span F/P/R (key "f"), 'pos' with accuracy (key "acc");
# any other dataset leaves both as None.
metrics = metric_key = None
if arg.dataset == "ner":
    metric_key = "f"
    metrics = SpanFPreRecMetric(vocab[1], pred='predict', target='truth', seq_lens='word_seq_origin_len')
elif arg.dataset == "pos":
    metric_key = "acc"
    metrics = AccuracyMetric(pred='predict', target='truth', seq_lens='word_seq_origin_len')

trainer = Trainer(
    train_data=train_dataset,
    model=model,
    loss=None,
    # loss=CrossEntropyLoss(pred='predict', target='truth'),
    metrics=metrics,
    n_epochs=20,
    batch_size=arg.batch_size,
    print_every=1,