Example #1
0
File: pos.py — Project: lei1993/HanLP
 def update_metrics(self, batch: Dict[str, Any],
                    output: Union[torch.Tensor, Dict[str, torch.Tensor],
                                  Iterable[torch.Tensor], Any],
                    prediction: Dict[str, Any], metric: Union[MetricDict,
                                                              Metric]):
     """Accumulate tagging metrics by delegating to ``TransformerTagger``.

     ``prediction`` is accepted for interface compatibility but is not
     used here; the parent implementation scores ``output`` directly
     against the gold tag ids under ``batch['tag_id']``.
     """
     gold_tags = batch['tag_id']
     token_mask = batch['mask']
     # NOTE: the parent expects (metric, output, gold, mask) order.
     return TransformerTagger.update_metrics(self, metric, output,
                                             gold_tags, token_mask)
Example #2
0
File: pos.py — Project: lei1993/HanLP
 def decode_output(self, output: Union[torch.Tensor, Dict[str,
                                                          torch.Tensor],
                                       Iterable[torch.Tensor], Any],
                   mask: torch.BoolTensor, batch: Dict[str, Any], decoder,
                   **kwargs) -> Union[Dict[str, Any], Any]:
     """Decode raw model output into predictions via the parent tagger.

     Extra ``**kwargs`` are accepted for interface compatibility and
     are intentionally ignored.
     """
     parent_decode = TransformerTagger.decode_output
     return parent_decode(self, output, mask, batch, decoder)
Example #3
0
File: pos.py — Project: lei1993/HanLP
 def compute_loss(
         self, batch: Dict[str, Any],
         output: Union[torch.Tensor, Dict[str, torch.Tensor],
                       Iterable[torch.Tensor], Any], criterion
 ) -> Union[torch.FloatTensor, Dict[str, torch.FloatTensor]]:
     """Compute the tagging loss by delegating to ``TransformerTagger``.

     Note the argument order the parent expects: criterion first, then
     output, gold tag ids and the token mask.
     """
     gold_tags, token_mask = batch['tag_id'], batch['mask']
     return TransformerTagger.compute_loss(self, criterion, output,
                                           gold_tags, token_mask)
 def feed_batch(self, batch: dict):
     """Encode a batch, then trim special-token positions.

     Returns the encoder hidden states with the first time step and the
     last two time steps removed — strips [CLS], [SEP] and [unused_i] —
     together with the unchanged mask.
     """
     hidden, token_mask = TransformerTagger.feed_batch(self, batch)
     # strip [CLS], [SEP] and [unused_i]
     trimmed = hidden[:, 1:-2, :]
     return trimmed, token_mask
Example #5
0
# -*- coding:utf-8 -*-
# Author: hankcs
# Date: 2019-12-28 23:15
from hanlp.components.taggers.transformers.transformer_tagger import TransformerTagger
from tests import cdroot

# Change to the repository root so the relative data paths below resolve.
cdroot()
tagger = TransformerTagger()
save_dir = 'data/model/pos/ctb9_albert_base_zh_epoch_20'
# Fine-tune an ALBERT-based POS tagger on the CTB9 short-sentence subset.
tagger.fit('data/pos/ctb9/train.short.tsv',
           'data/pos/ctb9/dev.short.tsv',
           save_dir,
           transformer='albert_base_zh',
           max_seq_length=130,
           warmup_steps_ratio=0.1,
           epochs=20,
           learning_rate=5e-5)
# Reload the saved checkpoint before prediction and evaluation.
tagger.load(save_dir)
print(tagger(['我', '的', '希望', '是', '希望', '和平']))
tagger.evaluate('data/pos/ctb9/test.short.tsv', save_dir=save_dir)
print(f'Model saved in {save_dir}')
Example #6
0
File: pos.py — Project: lei1993/HanLP
 def prediction_to_result(self, prediction: Dict[str, Any],
                          batch: Dict[str, Any]) -> Union[List, Dict]:
     """Map predicted tag ids back to human-readable tag strings.

     Uses the id-to-token table of the ``'tag'`` vocabulary and delegates
     the actual conversion to ``TransformerTagger.prediction_to_human``.
     """
     id_to_tag = self.vocabs['tag'].idx_to_token
     return TransformerTagger.prediction_to_human(self, prediction,
                                                  id_to_tag, batch)
Example #7
0
File: pos.py — Project: lei1993/HanLP
 def input_is_flat(self, data) -> bool:
     """Tell whether ``data`` is a single sample (flat) rather than a batch."""
     is_flat = TransformerTagger.input_is_flat(self, data)
     return is_flat
Example #8
0
File: pos.py — Project: lei1993/HanLP
 def build_metric(self, **kwargs):
     """Build the evaluation metric by delegating to ``TransformerTagger``."""
     parent = TransformerTagger
     return parent.build_metric(self, **kwargs)
Example #9
0
# -*- coding:utf-8 -*-
# Author: hankcs
# Date: 2019-10-25 21:34

from hanlp.components.taggers.transformers.transformer_tagger import TransformerTagger
from hanlp.datasets.ner.conll03 import CONLL03_EN_TRAIN, CONLL03_EN_VALID, CONLL03_EN_TEST
from tests import cdroot

# Change to the repository root so relative data paths resolve.
cdroot()
tagger = TransformerTagger()
save_dir = 'data/model/ner-rnn-debug'
# Fine-tune a BERT tagger on CoNLL-2003 English NER, tracking span F1.
tagger.fit(CONLL03_EN_TRAIN, CONLL03_EN_VALID, save_dir, transformer='bert-base-uncased',
           metrics='f1'
           )
# Reload the saved checkpoint before evaluating on the test split.
tagger.load(save_dir)
# print(tagger.predict('West Indian all-rounder Phil Simmons eats apple .'.split()))
# print(tagger.predict([['This', 'is', 'an', 'old', 'story'],
#                       ['Not', 'this', 'year', '.']]))
# [['DT', 'VBZ', 'DT', 'JJ', 'NN'], ['RB', 'DT', 'NN', '.']]
tagger.evaluate(CONLL03_EN_TEST, save_dir=save_dir, output=False, batch_size=32)