# Example #1
0
# -*- coding:utf-8 -*-
# Author: hankcs
# Date: 2019-12-28 23:15
from hanlp.components.taggers.transformers.transformer_tagger import TransformerTagger
from tests import cdroot

# Fine-tune an ALBERT-based TransformerTagger on the shortened CTB9 POS files,
# reload the saved model, tag one sample sentence and score the test file.
# NOTE(review): cdroot() presumably switches the working directory to the
# project root so the relative 'data/...' paths resolve -- confirm in tests/.
cdroot()
save_dir = 'data/model/pos/ctb9_albert_base_zh_epoch_20'
tagger = TransformerTagger()

# Hyper-parameters forwarded verbatim to fit() as keyword arguments.
fit_options = dict(
    transformer='albert_base_zh',
    max_seq_length=130,
    warmup_steps_ratio=0.1,
    epochs=20,
    learning_rate=5e-5,
)
tagger.fit(
    'data/pos/ctb9/train.short.tsv',
    'data/pos/ctb9/dev.short.tsv',
    save_dir,
    **fit_options,
)

# Reload from save_dir so the calls below run against the persisted model.
tagger.load(save_dir)
sample_tokens = ['我', '的', '希望', '是', '希望', '和平']
print(tagger(sample_tokens))

tagger.evaluate('data/pos/ctb9/test.short.tsv', save_dir=save_dir)
print(f'Model saved in {save_dir}')
# -*- coding:utf-8 -*-
# Author: hankcs
# Date: 2019-10-25 21:34

from hanlp.components.taggers.transformers.transformer_tagger import TransformerTagger
from hanlp.datasets.ner.conll03 import CONLL03_EN_TRAIN, CONLL03_EN_VALID, CONLL03_EN_TEST
from tests import cdroot

# Fine-tune a BERT-based TransformerTagger on the CoNLL-03 English NER splits,
# reload the saved model and evaluate it on the test split.
# NOTE(review): cdroot() presumably switches the working directory to the
# project root so the relative save_dir resolves -- confirm in tests/.
cdroot()
save_dir = 'data/model/ner-rnn-debug'
tagger = TransformerTagger()

# Keyword arguments forwarded verbatim to fit().
fit_options = dict(transformer='bert-base-uncased', metrics='f1')
tagger.fit(CONLL03_EN_TRAIN, CONLL03_EN_VALID, save_dir, **fit_options)

# Reload from save_dir so evaluation runs against the persisted model.
tagger.load(save_dir)

# Disabled manual smoke tests, kept for reference.
# NOTE(review): the sample output below looks like POS tags rather than NER
# labels -- possibly left over from another script; confirm before relying on it.
# print(tagger.predict('West Indian all-rounder Phil Simmons eats apple .'.split()))
# print(tagger.predict([['This', 'is', 'an', 'old', 'story'],
#                       ['Not', 'this', 'year', '.']]))
# [['DT', 'VBZ', 'DT', 'JJ', 'NN'], ['RB', 'DT', 'NN', '.']]

evaluate_options = dict(save_dir=save_dir, output=False, batch_size=32)
tagger.evaluate(CONLL03_EN_TEST, **evaluate_options)