def update_metrics(
        self,
        batch: Dict[str, Any],
        output: Union[torch.Tensor, Dict[str, torch.Tensor], Iterable[torch.Tensor], Any],
        prediction: Dict[str, Any],
        metric: Union[MetricDict, Metric],
):
    """Update ``metric`` by reusing ``TransformerTaggingTokenizer.update_metrics``.

    The borrowed implementation is called as a plain function with this object
    passed as ``self``. The arguments are reordered into its positional layout.

    Args:
        batch: Batch dict; its ``'tag_id'`` entry is forwarded as the reference
            tags (presumably gold tag ids -- TODO confirm).
        output: Model output for the batch, forwarded unchanged.
        prediction: Prediction for the batch, forwarded unchanged.
        metric: Metric (or metric dict) handed to the delegate.

    Returns:
        ``None``. Whatever the delegate returns is discarded.
    """
    delegate = TransformerTaggingTokenizer.update_metrics
    gold_tags = batch['tag_id']
    # The fourth positional slot is filled with None (presumably the mask --
    # NOTE(review): confirm against TransformerTaggingTokenizer.update_metrics).
    delegate(self, metric, output, gold_tags, None, batch, prediction)
def decode_output(
        self,
        output: Union[torch.Tensor, Dict[str, torch.Tensor], Iterable[torch.Tensor], Any],
        mask: torch.BoolTensor,
        batch: Dict[str, Any],
        decoder,
        **kwargs,
) -> Union[Dict[str, Any], Any]:
    """Decode ``output`` by reusing ``TransformerTaggingTokenizer.decode_output``.

    Args:
        output: Model output for the batch, forwarded unchanged.
        mask: Boolean mask tensor, forwarded unchanged.
        batch: Batch dict, forwarded unchanged.
        decoder: Decoder object, forwarded unchanged.
        **kwargs: Accepted for interface compatibility but NOT forwarded.

    Returns:
        Whatever the delegate returns.
    """
    decode = TransformerTaggingTokenizer.decode_output
    return decode(self, output, mask, batch, decoder)
def compute_loss(
        self,
        batch: Dict[str, Any],
        output: Union[torch.Tensor, Dict[str, torch.Tensor], Iterable[torch.Tensor], Any],
        criterion,
) -> Union[torch.FloatTensor, Dict[str, torch.FloatTensor]]:
    """Compute the loss by reusing ``TransformerTaggingTokenizer.compute_loss``.

    Args:
        batch: Batch dict. Its ``'tag_id'`` and ``'mask'`` entries are forwarded
            as the target tags and mask, respectively.
        output: Model output for the batch, forwarded unchanged.
        criterion: Loss criterion, passed first to the delegate.

    Returns:
        Whatever the delegate returns.
    """
    # Bind the delegate before reading from ``batch`` so lookups happen in the
    # same order as a single inline call would perform them.
    loss_fn = TransformerTaggingTokenizer.compute_loss
    gold_tags = batch['tag_id']
    token_mask = batch['mask']
    return loss_fn(self, criterion, output, gold_tags, token_mask)
# -*- coding:utf-8 -*-
# Author: hankcs
# Date: 2020-08-11 02:47
def main():
    """Train and evaluate a transformer tagging tokenizer on SIGHAN 2005 PKU.

    Fits a ``TransformerTaggingTokenizer`` on ``'bert-base-chinese'`` at the
    character level. It trains on ``SIGHAN2005_PKU_TRAIN_ALL``, and
    ``SIGHAN2005_PKU_TEST`` is passed in the dev-set position. It then calls
    ``evaluate`` on the PKU test set with the same ``save_dir`` and prints
    where the model was saved.
    """
    # Imports are function-scoped so that merely importing this module neither
    # triggers training nor requires the ``tests`` helper package.
    from hanlp.common.dataset import SortingSamplerBuilder
    from hanlp.components.tokenizers.transformer import TransformerTaggingTokenizer
    from hanlp.datasets.tokenization.sighan2005 import SIGHAN2005_PKU_TRAIN_ALL, SIGHAN2005_PKU_TEST
    from tests import cdroot

    # Presumably moves the working directory to the repository root so the
    # relative save_dir below resolves there -- TODO confirm.
    cdroot()
    tokenizer = TransformerTaggingTokenizer()
    save_dir = 'data/model/cws/sighan2005_pku_bert_base_96.70'
    tokenizer.fit(
        SIGHAN2005_PKU_TRAIN_ALL,
        SIGHAN2005_PKU_TEST,  # Conventionally, no devset is used. See Tian et al. (2020).
        save_dir,
        'bert-base-chinese',
        max_seq_len=300,
        char_level=True,
        hard_constraint=True,
        sampler_builder=SortingSamplerBuilder(batch_size=32),
        epochs=3,
        adam_epsilon=1e-6,
        warmup_steps=0.1,
        weight_decay=0.01,
        word_dropout=0.1,
        seed=1609836303,
    )
    tokenizer.evaluate(SIGHAN2005_PKU_TEST, save_dir)
    print(f'Model saved in {save_dir}')


if __name__ == '__main__':
    main()
def prediction_to_result(self, prediction: Dict[str, Any], batch: Dict[str, Any]) -> Union[List, Dict]:
    """Convert ``prediction`` into a result via ``TransformerTaggingTokenizer.prediction_to_human``.

    Args:
        prediction: Prediction for the batch, forwarded unchanged.
        batch: Batch dict, forwarded unchanged.

    Returns:
        Whatever the delegate returns.
    """
    to_human = TransformerTaggingTokenizer.prediction_to_human
    # The delegate's second positional argument is always None here, and span
    # rebuilding is always requested.
    return to_human(self, prediction, None, batch, rebuild_span=True)
def input_is_flat(self, data) -> bool:
    """Report whether ``data`` is flat, using ``TransformerTaggingTokenizer.input_is_flat``.

    Args:
        data: Input to inspect, forwarded unchanged.

    Returns:
        Whatever the delegate returns (annotated as ``bool``).
    """
    check_flat = TransformerTaggingTokenizer.input_is_flat
    return check_flat(self, data)
def build_criterion(self, model=None, **kwargs):
    """Build the loss criterion via ``TransformerTaggingTokenizer.build_criterion``.

    Args:
        model: Optional model, forwarded as the ``model`` keyword.
        **kwargs: Accepted for interface compatibility but NOT forwarded.

    Returns:
        Whatever the delegate returns. ``reduction`` is always fixed to ``'mean'``.
    """
    make_criterion = TransformerTaggingTokenizer.build_criterion
    return make_criterion(self, model=model, reduction='mean')
def build_metric(self, **kwargs):
    """Build the evaluation metric via ``TransformerTaggingTokenizer.build_metric``.

    Args:
        **kwargs: Accepted for interface compatibility but NOT forwarded.

    Returns:
        Whatever the delegate returns.
    """
    make_metric = TransformerTaggingTokenizer.build_metric
    return make_metric(self)