コード例 #1
0
 def build_dataloader(self,
                      data,
                      transform: TransformList = None,
                      training=False,
                      device=None,
                      logger: logging.Logger = None,
                      cache=False,
                      gradient_accumulation=1,
                      **kwargs) -> DataLoader:
     """Build a length-bucketed ``PadSequenceDataLoader`` for this task.

     Only the transforms of ``transform`` up to and including its
     ``TransformerSequenceTokenizer`` are kept; ``self.transform`` (when set)
     is prepended and ``self.last_transform()`` is appended.

     Args:
         data: Data forwarded to ``build_dataset`` and ``compute_lens``.
         transform: Transform list that must contain a
             ``TransformerSequenceTokenizer``.
         training: ``True`` to shuffle batches.
         device: Device the batches are placed on.
         logger: Logger forwarded to ``build_vocabs``.
         cache: Forwarded to ``build_dataset``.
         gradient_accumulation: Forwarded to the sampler builder.
         **kwargs: Ignored.

     Returns:
         A ``PadSequenceDataLoader`` whose batch sampler is built from the
         ``token_input_ids`` lengths.

     Raises:
         ValueError: If ``transform`` is ``None`` or has no
             ``TransformerSequenceTokenizer``.
     """
     # Validate with explicit exceptions: an ``assert`` is stripped under
     # ``python -O``, after which a missing tokenizer would surface as an
     # obscure ``TypeError`` from ``None + 1`` below.
     if transform is None:
         raise ValueError('build_dataloader requires a transform list '
                          'containing a TransformerSequenceTokenizer')
     transformer_index = transform.index_by_type(
         TransformerSequenceTokenizer)
     if transformer_index is None:
         raise ValueError(
             'No TransformerSequenceTokenizer found in transform')
     # Forward only those dataset options that are present in the config.
     args = {
         k: self.config[k]
         for k in ('delimiter', 'max_seq_len', 'sent_delimiter',
                   'char_level', 'hard_constraint') if k in self.config
     }
     # Keep the transforms up to and including the tokenizer.
     transform = transform[:transformer_index + 1]
     if self.transform:
         transform.insert(0, self.transform)
     transform.append(self.last_transform())
     dataset = self.build_dataset(data,
                                  cache=cache,
                                  transform=transform,
                                  **args)
     if self.vocabs.mutable:
         self.build_vocabs(dataset, logger)
     sampler = self.sampler_builder.build(
         self.compute_lens(data, dataset, 'token_input_ids'),
         shuffle=training,
         gradient_accumulation=gradient_accumulation)
     return PadSequenceDataLoader(batch_sampler=sampler,
                                  device=device,
                                  dataset=dataset)
コード例 #2
0
ファイル: dep.py プロジェクト: zouyanjian/HanLP
 def build_dataloader(self,
                      data,
                      transform: TransformList = None,
                      training=False,
                      device=None,
                      logger: logging.Logger = None,
                      gradient_accumulation=1,
                      **kwargs) -> DataLoader:
     """Build a dependency-parsing dataloader on top of a transformer encoder.

     ``append_bos`` is inserted at the front of ``transform`` in place. The
     dataset, vocabularies and cache are then prepared by calling the
     ``BiaffineDependencyParser`` implementations with ``self``.

     Args:
         data: Input data. A ``str`` triggers cache purging and, when
             ``max_seq_len`` is configured, pruning of long samples.
         transform: Transform list to extend; it is modified in place.
         training: ``True`` to shuffle batches; also forwarded to
             ``cache_dataset``.
         device: Device the batches are placed on.
         logger: Logger forwarded to vocab building, caching and pruning.
         gradient_accumulation: Forwarded to the sampler builder.
         **kwargs: Ignored.

     Returns:
         A ``PadSequenceDataLoader`` bucketed by ``FORM`` lengths and padded
         with ``self.get_pad_dict()``.
     """
     parser_cls = BiaffineDependencyParser
     data_is_str = isinstance(data, str)
     transform.insert(0, append_bos)
     dataset = parser_cls.build_dataset(self, data, transform)
     if data_is_str:
         dataset.purge_cache()
     if self.vocabs.mutable:
         parser_cls.build_vocabs(self, dataset, logger, transformer=True)
     if dataset.cache:
         parser_cls.cache_dataset(self, dataset, CountdownTimer(len(dataset)),
                                  training, logger)
     # NOTE(review): ``max_seq_len`` only switches pruning on. The cut-off is
     # the hard-coded 510 ``token_input_ids``, not the configured value —
     # confirm this is intended.
     if self.config.get('max_seq_len', None) and data_is_str:
         dataset.prune(lambda sample: len(sample['token_input_ids']) > 510,
                       logger)
     lengths = self.compute_lens(data, dataset, length_field='FORM')
     batch_sampler = self.sampler_builder.build(
         lengths,
         shuffle=training,
         gradient_accumulation=gradient_accumulation)
     return PadSequenceDataLoader(batch_sampler=batch_sampler,
                                  device=device,
                                  dataset=dataset,
                                  pad=self.get_pad_dict())
コード例 #3
0
 def build_transform(self, task: Task) -> Tuple[TransformerSequenceTokenizer, TransformList]:
     """Create the task-specific tokenizer and the full transform pipeline.

     Args:
         task: Task whose ``build_tokenizer`` wraps the encoder's transform.

     Returns:
         A pair ``(tokenizer, pipeline)``. ``pipeline`` is ordered as
         ``[config.transform (only if set), tokenizer,
         FieldLength('token', 'token_length')]``.
     """
     encoder: ContextualWordEmbedding = self.config.encoder
     tokenizer: TransformerSequenceTokenizer = task.build_tokenizer(encoder.transform())
     pipeline = TransformList(tokenizer, FieldLength('token', 'token_length'))
     prefix = self.config.get('transform', None)
     if not prefix:
         return tokenizer, pipeline
     pipeline.insert(0, prefix)
     return tokenizer, pipeline
コード例 #4
0
ファイル: biaffine_sdp.py プロジェクト: yehuangcn/HanLP
 def build_dataset(self, data, transform=None):
     """Build a semantic dependency dataset.

     The pipeline runs ``append_bos_to_form_pos`` (with ``pos_key='UPOS'``),
     then ``unpack_deps_to_head_deprel`` (with ``pad_rel=config.pad_rel``), and
     finally the optional ``transform``.
     """
     add_bos = functools.partial(append_bos_to_form_pos, pos_key='UPOS')
     unpack = functools.partial(unpack_deps_to_head_deprel,
                                pad_rel=self.config.pad_rel)
     pipeline = TransformList(add_bos, unpack)
     if transform:
         pipeline.append(transform)
     parent = super(BiaffineSemanticDependencyParser, self)
     return parent.build_dataset(data, pipeline)
コード例 #5
0
    def build_dataloader(self,
                         data,
                         shuffle,
                         device,
                         embed: Embedding,
                         training=False,
                         logger=None,
                         gradient_accumulation=1,
                         sampler_builder=None,
                         batch_size=None,
                         bos='\0',
                         **kwargs) -> DataLoader:
        """Build a padded dataloader whose samples get a BOS token and embedding transforms.

        Args:
            data: Input data. When it is a ``str``, the dataset cache is purged,
                and datasets with more than 1000 samples are cached eagerly.
            shuffle: Forwarded to ``sampler_builder.build``.
            device: Device the batches are placed on.
            embed: Embedding whose ``transform(vocabs=self.vocabs)`` is applied.
            training: Forwarded to ``cache_dataset``.
            logger: Logger forwarded to ``build_vocabs`` and ``cache_dataset``.
            gradient_accumulation: Forwarded to ``sampler_builder.build``.
            sampler_builder: Optional builder of a batch sampler over the
                ``sent_length`` values. When ``None``, no batch sampler is used.
            batch_size: Forwarded to ``PadSequenceDataLoader`` (presumably only
                relevant when no batch sampler is built — confirm).
            bos: Value passed as ``bos`` to ``append_bos``.
            **kwargs: Ignored.

        Returns:
            A ``PadSequenceDataLoader`` over the built dataset.
        """
        # The BOS transform comes first, before any embedding transform.
        first_transform = TransformList(functools.partial(append_bos, bos=bos))
        embed_transform = embed.transform(vocabs=self.vocabs)
        # The transformer tokenizer transform (if any) among the embedding transforms.
        transformer_transform = self._get_transformer_transform_from_transforms(
            embed_transform)
        if embed_transform:
            # Detach the tokenizer while the dataset and vocabularies are built.
            # It is re-appended (at the end of embed_transform) further below.
            # NOTE(review): presumably this keeps vocab building from running
            # transformer tokenization — confirm.
            if transformer_transform and isinstance(embed_transform,
                                                    TransformList):
                embed_transform.remove(transformer_transform)

            first_transform.append(embed_transform)
        dataset = self.build_dataset(data, first_transform=first_transform)
        # Optional user-configured transform, appended to the dataset's transforms.
        if self.config.get('transform', None):
            dataset.append_transform(self.config.transform)

        if self.vocabs.mutable:
            self.build_vocabs(dataset, logger, self._transformer_trainable())
        # Re-attach the tokenizer. embed_transform was appended to
        # first_transform as an object, so (assuming the dataset keeps that
        # object) the dataset pipeline runs the tokenizer again from here on.
        if transformer_transform and isinstance(embed_transform,
                                                TransformList):
            embed_transform.append(transformer_transform)

        # Writes the 'sent_length' field read by the sampler below.
        dataset.append_transform(FieldLength('token', 'sent_length'))
        if isinstance(data, str):
            dataset.purge_cache()
        # Eagerly cache large datasets given as a str (presumably a file path).
        if len(dataset) > 1000 and isinstance(data, str):
            timer = CountdownTimer(len(dataset))
            self.cache_dataset(dataset, timer, training, logger)
        if sampler_builder:
            lens = [sample['sent_length'] for sample in dataset]
            sampler = sampler_builder.build(lens, shuffle,
                                            gradient_accumulation)
        else:
            sampler = None
        loader = PadSequenceDataLoader(dataset=dataset,
                                       batch_sampler=sampler,
                                       batch_size=batch_size,
                                       pad=self.get_pad_dict(),
                                       device=device,
                                       vocabs=self.vocabs)
        return loader
コード例 #6
0
 def build_dataloader(self,
                      data,
                      batch_size,
                      sampler_builder: SamplerBuilder = None,
                      gradient_accumulation=1,
                      shuffle=False,
                      device=None,
                      logger: logging.Logger = None,
                      **kwargs) -> DataLoader:
     """Build a ``PadSequenceDataLoader`` over encoder-tokenized samples.

     Args:
         data: A ready ``TransformableDataset`` (used as is), or raw data that
             is passed to ``build_dataset`` with the encoder transform. The
             encoder transform is preceded by ``config.transform`` when that is set.
         batch_size: Forwarded to ``PadSequenceDataLoader``.
         sampler_builder: Optional builder of a batch sampler over the
             ``token_input_ids`` lengths.
         gradient_accumulation: Forwarded to ``sampler_builder.build``.
         shuffle: Forwarded to ``sampler_builder.build`` and the loader.
         device: Device the batches are placed on.
         logger: Logger forwarded to ``build_dataset`` and ``build_vocabs``.
         **kwargs: Ignored.
     """
     def _make_dataset():
         # Encoder transform, optionally preceded by the configured transform.
         pipeline = self.config.encoder.transform()
         if self.config.get('transform', None):
             pipeline = TransformList(self.config.transform, pipeline)
         return self.build_dataset(data, pipeline, logger)

     dataset = data if isinstance(data, TransformableDataset) else _make_dataset()
     if self.vocabs.mutable:
         # noinspection PyTypeChecker
         self.build_vocabs(dataset, logger)
     lens = [len(sample['token_input_ids']) for sample in dataset]
     sampler = (sampler_builder.build(lens, shuffle, gradient_accumulation)
                if sampler_builder else None)
     return PadSequenceDataLoader(dataset,
                                  batch_size,
                                  shuffle,
                                  device=device,
                                  batch_sampler=sampler)
コード例 #7
0
ファイル: dataset.py プロジェクト: longgege/HanLP
    def append_transform(self, transform: Callable):
        """Add ``transform`` at the end of the transform chain.

        With no transform set, a new ``TransformList`` holding only ``transform``
        is created. A single non-list transform is combined with ``transform``
        into a ``TransformList`` unless the two compare equal. A transform that is
        already in the list is not added again.

        Args:
            transform: The transform to append; must not be ``None``.

        Returns: Itself.

        """
        assert transform is not None, 'None transform not allowed'
        current = self.transform
        if not current:
            self.transform = TransformList(transform)
        elif isinstance(current, TransformList):
            if transform not in current:
                current.append(transform)
        elif current != transform:
            self.transform = TransformList(current, transform)
        return self
コード例 #8
0
ファイル: dataset.py プロジェクト: Zilig/HanLP
    def __init__(self, transform: Union[Callable, List] = None) -> None:
        """Keep a transform (a single callable or a list of callables) for later use on samples.

        Args:
            transform: A transform function or a list of functions. A plain
                ``list`` is turned into a ``TransformList``. Any other value,
                including an existing ``TransformList``, is kept as is.
        """
        super().__init__()
        is_plain_list = isinstance(transform, list) and not isinstance(
            transform, TransformList)
        self.transform: Union[Callable, TransformList] = (
            TransformList(*transform) if is_plain_list else transform)
コード例 #9
0
ファイル: dataset.py プロジェクト: longgege/HanLP
    def insert_transform(self, index: int, transform: Callable):
        """Insert ``transform`` into the transform chain at ``index``.

        With no transform set, a new ``TransformList`` holding only ``transform``
        is created. A single non-list transform is first wrapped in a
        ``TransformList`` (unless it equals ``transform``). A transform that is
        already in the list is not inserted again.

        Args:
            index: Position passed to ``TransformList.insert``.
            transform: The transform to insert; must not be ``None``.

        Returns: Dataset itself.

        """
        assert transform is not None, 'None transform not allowed'
        current = self.transform
        if not current:
            self.transform = TransformList(transform)
        elif isinstance(current, TransformList):
            if transform not in current:
                current.insert(index, transform)
        elif current != transform:
            self.transform = TransformList(current)
            self.transform.insert(index, transform)
        return self
コード例 #10
0
ファイル: biaffine_ner.py プロジェクト: cfy42584125/HanLP-1
 def build_dataloader(self,
                      data,
                      transform: TransformList = None,
                      training=False,
                      device=None,
                      logger: logging.Logger = None,
                      gradient_accumulation=1,
                      **kwargs) -> DataLoader:
     """Build a biaffine NER dataloader.

     ``unpack_ner`` is appended to a copy of ``transform``. The dataset and
     vocabularies are then prepared by calling the
     ``BiaffineNamedEntityRecognizer`` implementations with ``self``.

     Args:
         data: Data forwarded to ``build_dataset`` and ``compute_lens``.
         transform: Transform list to copy and extend.
         training: ``True`` to shuffle batches.
         device: Device the batches are placed on.
         logger: Logger forwarded to ``build_vocabs``.
         gradient_accumulation: Forwarded to the sampler builder.
         **kwargs: Ignored.
     """
     recognizer_cls = BiaffineNamedEntityRecognizer
     # Extend a copy so the appended step does not land in the caller's list.
     pipeline = copy(transform)
     pipeline.append(unpack_ner)
     dataset = recognizer_cls.build_dataset(self, data, self.vocabs, pipeline)
     if self.vocabs.mutable:
         recognizer_cls.build_vocabs(self, dataset, logger, self.vocabs)
     lengths = self.compute_lens(data, dataset)
     batch_sampler = self.sampler_builder.build(
         lengths,
         shuffle=training,
         gradient_accumulation=gradient_accumulation)
     return PadSequenceDataLoader(batch_sampler=batch_sampler,
                                  device=device,
                                  dataset=dataset)
コード例 #11
0
ファイル: biaffine_ner.py プロジェクト: cfy42584125/HanLP-1
 def build_dataloader(self,
                      data,
                      batch_size,
                      shuffle,
                      device,
                      logger: logging.Logger = None,
                      vocabs=None,
                      sampler_builder=None,
                      gradient_accumulation=1,
                      **kwargs) -> DataLoader:
     """Build a dataloader for biaffine named entity recognition.

     Args:
         data: Data forwarded to ``build_dataset``.
         batch_size: Not used by this method.
         shuffle: Forwarded to ``sampler_builder.build``.
         device: Device the batches are placed on.
         logger: Logger forwarded to ``build_vocabs``.
         vocabs: Vocabularies to index samples with and, when mutable, to
             build. Defaults to ``self.vocabs``.
         sampler_builder: Optional builder of a length-based batch sampler.
             When ``None``, no batch sampler is passed to the loader.
         gradient_accumulation: Forwarded to ``sampler_builder.build``.
         **kwargs: Ignored.

     Returns:
         A ``PadSequenceDataLoader`` over the built dataset.
     """
     if vocabs is None:
         vocabs = self.vocabs
     transform = TransformList(unpack_ner, FieldLength('token'))
     if isinstance(self.config.embed, Embedding):
         transform.append(self.config.embed.transform(vocabs=vocabs))
     # Index with the same ``vocabs`` that the embedding transform above and
     # ``build_vocabs`` below use, so a caller-supplied ``vocabs`` is not
     # mixed with ``self.vocabs``.
     transform.append(vocabs)
     dataset = self.build_dataset(data, vocabs, transform)
     if vocabs.mutable:
         self.build_vocabs(dataset, logger, vocabs)
     if 'token' in vocabs:
         # Lengths must be integers, as in the ``token_input_ids`` branch;
         # ``x['token']`` is the token sequence itself, so take its length.
         lens = [len(x['token']) for x in dataset]
     else:
         lens = [len(x['token_input_ids']) for x in dataset]
     if sampler_builder:
         sampler = sampler_builder.build(lens, shuffle,
                                         gradient_accumulation)
     else:
         sampler = None
     return PadSequenceDataLoader(batch_sampler=sampler,
                                  device=device,
                                  dataset=dataset)
コード例 #12
0
 def build_dataset(self, data, bos_transform=None):
     """Build a dataset for second-order dependency parsing.

     The pipeline, in order:

     1. ``append_bos`` (``pos_key='UPOS'``).
     2. ``unpack_deps_to_head_deprel``, writing ``arc_2nd``/``rel_2nd`` and
        padded with ``config.pad_rel``.
     3. ``merge_head_deprel_with_2nd``, when ``config.joint`` is set.
     4. The optional ``bos_transform``.
     """
     prepend_bos = functools.partial(append_bos, pos_key='UPOS')
     unpack_secondary = functools.partial(unpack_deps_to_head_deprel,
                                          pad_rel=self.config.pad_rel,
                                          arc_key='arc_2nd',
                                          rel_key='rel_2nd')
     pipeline = TransformList(prepend_bos, unpack_secondary)
     if self.config.joint:
         pipeline.append(merge_head_deprel_with_2nd)
     if bos_transform:
         pipeline.append(bos_transform)
     return super().build_dataset(data, pipeline)
コード例 #13
0
ファイル: embedding.py プロジェクト: 1032998/LM2
 def transform(self, **kwargs):
     """Combine the transforms of all wrapped embeddings into one ``TransformList``.

     Each embedding's ``transform`` is called with ``kwargs`` in order, and
     falsy results are skipped.
     """
     parts = []
     for embedding in self._embeddings:
         part = embedding.transform(**kwargs)
         if part:
             parts.append(part)
     return TransformList(*parts)
コード例 #14
0
ファイル: transformer_tagger.py プロジェクト: yehuangcn/HanLP
 def last_transform(self):
     """Build the trailing transforms: ``self.vocabs`` followed by a ``FieldLength`` over ``config.token_key``."""
     vocabs = self.vocabs
     length_of_tokens = FieldLength(self.config.token_key)
     return TransformList(vocabs, length_of_tokens)
コード例 #15
0
 def build_dataset(self, data, transform=None, **kwargs):
     """Build the dataset with ``add_lemma_rules_to_sample`` as the final transform.

     Args:
         data: Forwarded to the parent ``build_dataset``.
         transform: ``None``, a single transform callable, or a list of
             transforms. A list is appended to in place. A single callable is
             wrapped in a new ``TransformList`` so it still runs before the
             lemma-rule step instead of being discarded.
         **kwargs: Forwarded to the parent ``build_dataset``.
     """
     if not isinstance(transform, list):
         # Wrap a lone callable rather than replacing it with an empty list.
         transform = TransformList() if transform is None else TransformList(transform)
     transform.append(add_lemma_rules_to_sample)
     return super().build_dataset(data, transform, **kwargs)
コード例 #16
0
ファイル: dataset.py プロジェクト: zhoumo99133/HanLP
class Transformable(ABC):
    """Base class for objects whose samples pass through a chain of transforms.

    The chain is held in ``self.transform``: ``None``, a single callable, or a
    ``TransformList`` (plain ``list`` inputs are converted to one).
    """

    def __init__(self, transform: Union[Callable, List] = None) -> None:
        """Create the object with an optional initial transform.

        Args:
            transform: A transform function or a list of functions. A plain
                ``list`` (not already a ``TransformList``) is converted into a
                ``TransformList``; anything else is stored unchanged.
        """
        super().__init__()
        if not isinstance(transform, list) or isinstance(transform, TransformList):
            stored = transform
        else:
            stored = TransformList(*transform)
        self.transform: Union[Callable, TransformList] = stored

    def append_transform(self, transform: Callable):
        """Add ``transform`` at the end of the transform chain.

        With no transform set, a new ``TransformList`` holding only ``transform``
        is created. A single non-list transform is combined with ``transform``
        into a ``TransformList`` unless the two compare equal. A transform that
        is already in the list is not added again.

        Args:
            transform: The transform to append; must not be ``None``.

        Returns:
            Itself.
        """
        assert transform is not None, 'None transform not allowed'
        current = self.transform
        if not current:
            self.transform = TransformList(transform)
        elif isinstance(current, TransformList):
            if transform not in current:
                current.append(transform)
        elif current != transform:
            self.transform = TransformList(current, transform)
        return self

    def insert_transform(self, index: int, transform: Callable):
        """Insert ``transform`` into the transform chain at ``index``.

        With no transform set, a new ``TransformList`` holding only ``transform``
        is created. A single non-list transform is first wrapped in a
        ``TransformList`` (unless it equals ``transform``). A transform that is
        already in the list is not inserted again.

        Args:
            index: Position passed to ``TransformList.insert``.
            transform: The transform to insert; must not be ``None``.

        Returns:
            Itself.
        """
        assert transform is not None, 'None transform not allowed'
        current = self.transform
        if not current:
            self.transform = TransformList(transform)
        elif isinstance(current, TransformList):
            if transform not in current:
                current.insert(index, transform)
        elif current != transform:
            self.transform = TransformList(current)
            self.transform.insert(index, transform)
        return self

    def transform_sample(self, sample: dict, inplace=False) -> dict:
        """Run ``sample`` through the transforms.

        Args:
            sample: A sample, i.e. a ``dict`` of features.
            inplace: ``True`` to hand ``sample`` itself to the transforms
                instead of a copy of it.

        .. Attention::
            Transforms that modify existing features in place will do so again
            on every call when ``inplace=True``. For example, a transform that
            inserts a ``BOS`` token into a list in place, called twice, leaves
            two ``BOS`` tokens, which is probably not what was intended.

        Returns:
            The transformed sample, or the (copied) sample unchanged when no
            transform is set.
        """
        working = sample if inplace else copy(sample)
        return self.transform(working) if self.transform else working
コード例 #17
0
 def last_transform(self):
     """Prepend subtoken tag generation to the parent's trailing transforms.

     Returns:
         A ``TransformList`` that runs ``generate_tags_for_subtokens`` (with
         ``config.tagging_scheme``) before ``super().last_transform()``.
     """
     tag_subtokens = functools.partial(generate_tags_for_subtokens,
                                       tagging_scheme=self.config.tagging_scheme)
     inherited = super().last_transform()
     return TransformList(tag_subtokens, inherited)
コード例 #18
0
 def transform(self, **kwargs) -> Callable:
     """Build the transform pipeline for contextual string embeddings.

     A character ``Vocab`` is loaded from ``vocab.json`` inside the directory
     returned by ``get_resource(self.path)``. The pipeline runs
     ``ContextualStringEmbeddingTransform`` on ``self.field``, then maps the
     ``<field>_f_char`` and ``<field>_b_char`` features through that vocab.
     """
     field = self.field
     char_vocab = Vocab()
     resource_dir = get_resource(self.path)
     char_vocab.load(os.path.join(resource_dir, 'vocab.json'))
     steps = [
         ContextualStringEmbeddingTransform(field),
         FieldToIndex(f'{field}_f_char', char_vocab),
         FieldToIndex(f'{field}_b_char', char_vocab),
     ]
     return TransformList(*steps)