Example #1
    def build_dataloader(self, data, batch_size,
                         gradient_accumulation=1,
                         shuffle=False,
                         sampler_builder: SamplerBuilder = None,
                         device=None,
                         logger: logging.Logger = None,
                         **kwargs) -> DataLoader:
        dataset = self.build_dataset(data, not shuffle)
        if self.vocabs.mutable:
            self.build_vocabs(dataset, logger)
        self.finalize_dataset(dataset, logger)
        if isinstance(data, str):
            dataset.purge_cache()
            timer = CountdownTimer(len(dataset))
            max_num_tokens = 0
            # lc = Counter()
            for each in dataset:
                max_num_tokens = max(max_num_tokens, len(each['text_token_ids']))
                # lc[len(each['text_token_ids'])] += 1
                timer.log(f'Preprocessing and caching samples (longest sequence {max_num_tokens}) '
                          f'[blink][yellow]...[/yellow][/blink]')
            # print(lc.most_common())
            if self.vocabs.mutable:
                self.vocabs.lock()
                self.vocabs.summary(logger)

        if not sampler_builder:
            # Default to sorting samples by length, with at most 500 tokens per batch
            sampler_builder = SortingSamplerBuilder(batch_max_tokens=500)
        # Build length-sorted batches; pass gradient_accumulation only when the dataset is cached
        sampler = sampler_builder.build([len(x['text_token_ids']) for x in dataset], shuffle,
                                        gradient_accumulation if dataset.cache else 1)
        return self._create_dataloader(dataset, batch_size, device, sampler, shuffle)
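
The snippet above follows the pattern the remaining examples on this page repeat: compute per-sample token lengths, let the sampler builder group them into batches, then construct the DataLoader. Below is a minimal sketch of that pattern in isolation; the toy lengths and batch_size=2 are made up for illustration, and only the SortingSamplerBuilder calls already shown in the snippets are assumed.

# Minimal sketch (hypothetical toy data): build a length-sorted batch sampler directly.
from hanlp.common.dataset import SortingSamplerBuilder

lengths = [5, 12, 7, 3, 9, 4]  # e.g. len(x['text_token_ids']) for each sample
sampler = SortingSamplerBuilder(batch_size=2).build(lengths, False, 1)  # lengths, shuffle, gradient_accumulation
for batch_indices in sampler:
    # Each batch is a list of dataset indices grouped by similar length,
    # ready to be passed to a DataLoader as batch_sampler=...
    print(batch_indices)

In the examples that follow, a sampler built this way is handed to PadSequenceDataLoader (Examples #2 and #4) or to the component's own _create_dataloader (as above).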
Example #2
 def build_dataloader(self,
                      data,
                      batch_size,
                      sent_a_col=None,
                      sent_b_col=None,
                      similarity_col=None,
                      delimiter='auto',
                      gradient_accumulation=1,
                      sampler_builder=None,
                      shuffle=False,
                      device=None,
                      logger: logging.Logger = None,
                      split=None,
                      **kwargs) -> DataLoader:
     dataset = SemanticTextualSimilarityDataset(data,
                                                sent_a_col,
                                                sent_b_col,
                                                similarity_col,
                                                delimiter=delimiter,
                                                transform=self._tokenizer,
                                                cache=isinstance(data, str))
     if split == 'trn':
         scores = [x['similarity'] for x in dataset]
         self.config.max_score = max(scores)
         self.config.min_score = min(scores)
     if not sampler_builder:
         sampler_builder = SortingSamplerBuilder(batch_size=batch_size)
     lens = [len(x['input_ids']) for x in dataset]
     return PadSequenceDataLoader(dataset,
                                  batch_sampler=sampler_builder.build(
                                      lens, shuffle, gradient_accumulation),
                                  device=device,
                                  pad={'similarity': 0.0,
                                       'input_ids': self._tokenizer.tokenizer.pad_token_id})
Example #3
    def __init__(self,
                 trn: str = None,
                 dev: str = None,
                 tst: str = None,
                 sampler_builder: SamplerBuilder = None,
                 dependencies: str = None,
                 scalar_mix: ScalarMixWithDropoutBuilder = None,
                 use_raw_hidden_states=False,
                 lr=None,
                 separate_optimizer=False,
                 cls_is_bos=False,
                 sep_is_eos=False,
                 **kwargs) -> None:
        """
        A task in the multi-task learning framework

        Args:
            trn: Path to training set.
            dev: Path to dev set.
            tst: Path to test set.
            sampler_builder: A builder which builds a sampler.
            dependencies: Its dependencies on other tasks.
            scalar_mix: A builder which builds a `ScalarMixWithDropout` object.
            use_raw_hidden_states: Whether to use raw hidden states from the transformer without any pooling.
            lr: Learning rate for this task.
            separate_optimizer: Whether to use a customized, separate optimizer for this task.
            cls_is_bos: ``True`` to treat the first token as ``BOS``.
            sep_is_eos: ``True`` to treat the last token as ``EOS``.
            **kwargs: Additional config.
        """
        ConfigTracker.__init__(self, merge_locals_kwargs(locals(), kwargs))
        for f, n in zip([trn, dev, tst], ['trn', 'dev', 'tst']):
            if f and os.path.isfile(f):  # anonymize local file names
                self.config.pop(n)
        self.separate_optimizer = separate_optimizer
        self.lr = lr
        self.use_raw_hidden_states = use_raw_hidden_states
        if sampler_builder is None:
            sampler_builder = SortingSamplerBuilder(batch_size=32)
        self.sampler_builder: Union[SortingSamplerBuilder,
                                    KMeansSamplerBuilder] = sampler_builder
        self.dependencies = dependencies
        self.tst = tst
        self.dev = dev
        self.trn = trn
        self.scalar_mix = scalar_mix
        self.cls_is_bos = cls_is_bos
        self.sep_is_eos = sep_is_eos
Example #4
 def build_dataloader(self,
                      data,
                      batch_size,
                      shuffle=False,
                      device=None,
                      logger: logging.Logger = None,
                      sampler_builder=None,
                      gradient_accumulation=1,
                      transformer: ContextualWordEmbedding = None,
                      **kwargs) -> DataLoader:
     transform = [
         generate_lemma_rule, append_bos, self.vocabs,
         transformer.transform(),
         FieldLength('token')
     ]
     if not self.config.punct:
         # Add a punctuation mask so punctuation can be excluded from evaluation
         transform.append(PunctuationMask('token', 'punct_mask'))
     dataset = self.build_dataset(data, transform)
     if self.vocabs.mutable:
         # noinspection PyTypeChecker
         self.build_vocabs(dataset, logger)
     lens = [len(x['token_input_ids']) for x in dataset]
     if sampler_builder:
         sampler = sampler_builder.build(lens, shuffle,
                                         gradient_accumulation)
     else:
         # Fall back to a length-sorting sampler with the requested batch size
         sampler = SortingSamplerBuilder(batch_size).build(
             lens, shuffle, gradient_accumulation)
     return PadSequenceDataLoader(
         dataset,
         batch_size,
         shuffle,
         device=device,
         batch_sampler=sampler,
         pad={'arc': 0},
     )
Example #5
# -*- coding:utf-8 -*-
# Author: hankcs
# Date: 2020-08-11 02:47
from hanlp.common.dataset import SortingSamplerBuilder
from hanlp.components.tokenizers.transformer import TransformerTaggingTokenizer
from hanlp.datasets.tokenization.sighan2005 import SIGHAN2005_PKU_TRAIN_ALL, SIGHAN2005_PKU_TEST
from tests import cdroot

cdroot()
tokenizer = TransformerTaggingTokenizer()
save_dir = 'data/model/cws/sighan2005_pku_bert_base_96.70'
tokenizer.fit(
    SIGHAN2005_PKU_TRAIN_ALL,
    SIGHAN2005_PKU_TEST,  # Conventionally, no devset is used. See Tian et al. (2020).
    save_dir,
    'bert-base-chinese',
    max_seq_len=300,
    char_level=True,
    hard_constraint=True,
    sampler_builder=SortingSamplerBuilder(batch_size=32),
    epochs=3,
    adam_epsilon=1e-6,
    warmup_steps=0.1,
    weight_decay=0.01,
    word_dropout=0.1,
    seed=1609836303,
)
tokenizer.evaluate(SIGHAN2005_PKU_TEST, save_dir)
print(f'Model saved in {save_dir}')
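
The script above only trains and evaluates. A hedged follow-up sketch (not part of the original example, assuming the standard HanLP component load/predict interface): reload the fitted tokenizer from save_dir and apply it to raw text.

# Hedged sketch: reload the model saved above and segment a sample sentence.
reloaded = TransformerTaggingTokenizer()
reloaded.load(save_dir)
print(reloaded('商品和服务'))  # expected to print a list of segmented tokens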
Example #6
from hanlp.datasets.srl.ontonotes5.chinese import ONTONOTES5_CONLL12_CHINESE_TEST, ONTONOTES5_CONLL12_CHINESE_DEV, \
    ONTONOTES5_CONLL12_CHINESE_TRAIN
from hanlp.layers.embeddings.contextual_word_embedding import ContextualWordEmbedding
from hanlp.layers.transformers.relative_transformer import RelativeTransformerEncoder
from hanlp.utils.lang.zh.char_table import HANLP_CHAR_TABLE_JSON
from hanlp.utils.log_util import cprint
from hanlp.common.dataset import SortingSamplerBuilder
from tests import cdroot
# Note: this excerpt omits the imports for TaggingTokenization, TransformerTagging,
# NormalizeCharacter and the CTB8_* dataset constants used below.

cdroot()
tasks = {
    'tok':
    TaggingTokenization(
        CTB8_CWS_TRAIN,
        CTB8_CWS_DEV,
        CTB8_CWS_TEST,
        SortingSamplerBuilder(batch_size=32),
        max_seq_len=510,
        hard_constraint=True,
        char_level=True,
        tagging_scheme='BMES',
        lr=1e-3,
        transform=NormalizeCharacter(HANLP_CHAR_TABLE_JSON, 'token'),
    ),
    'pos':
    TransformerTagging(
        CTB8_POS_TRAIN,
        CTB8_POS_DEV,
        CTB8_POS_TEST,
        SortingSamplerBuilder(batch_size=32),
        hard_constraint=True,
        max_seq_len=510,