Example #1
File: span_bio.py  Project: yehuangcn/HanLP
def build_dataloader(self,
                     data,
                     batch_size,
                     sampler_builder: SamplerBuilder = None,
                     gradient_accumulation=1,
                     shuffle=False,
                     device=None,
                     logger: logging.Logger = None,
                     **kwargs) -> DataLoader:
    # Reuse the dataset if one was passed in; otherwise build one with the
    # embedding transform, the vocab lookup and a 'token' length field.
    if isinstance(data, TransformableDataset):
        dataset = data
    else:
        dataset = self.build_dataset(data, [
            self.config.embed.transform(vocabs=self.vocabs), self.vocabs,
            FieldLength('token')
        ])
    if self.vocabs.mutable:
        # Vocabularies are only built on the first (training) pass.
        self.build_vocabs(dataset, logger)
    # Sequence lengths drive length-based bucketing when a sampler builder is given.
    lens = [len(x['token_input_ids']) for x in dataset]
    if sampler_builder:
        sampler = sampler_builder.build(lens, shuffle,
                                        gradient_accumulation)
    else:
        sampler = None
    return PadSequenceDataLoader(dataset,
                                 batch_size,
                                 shuffle,
                                 device=device,
                                 batch_sampler=sampler)
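The pattern above (measure sequence lengths, hand them to a sampler builder, pad per batch) relies on HanLP helpers, but the underlying technique is easy to reproduce in plain PyTorch. A minimal sketch follows; LengthBucketSampler and pad_collate are illustrative names, not HanLP API:

import torch
from torch.utils.data import DataLoader

def pad_collate(batch, pad_value=0):
    # Pad variable-length 'token_input_ids' to the longest sequence in the batch.
    lens = [len(x['token_input_ids']) for x in batch]
    max_len = max(lens)
    ids = torch.full((len(batch), max_len), pad_value, dtype=torch.long)
    for i, x in enumerate(batch):
        ids[i, :lens[i]] = torch.tensor(x['token_input_ids'])
    return {'token_input_ids': ids, 'token_length': torch.tensor(lens)}

class LengthBucketSampler:
    # Yields batches of indices grouped by similar length to minimize padding.
    def __init__(self, lengths, batch_size):
        order = sorted(range(len(lengths)), key=lengths.__getitem__)
        self.batches = [order[i:i + batch_size]
                        for i in range(0, len(order), batch_size)]

    def __iter__(self):
        return iter(self.batches)

    def __len__(self):
        return len(self.batches)

# Usage: dataset is any sequence of dicts holding a 'token_input_ids' list.
# lens = [len(x['token_input_ids']) for x in dataset]
# loader = DataLoader(dataset, batch_sampler=LengthBucketSampler(lens, 32),
#                     collate_fn=pad_collate)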
Example #2
def build_dataloader(self, data, batch_size, shuffle, device, logger: logging.Logger = None,
                     sampler_builder: SamplerBuilder = None, gradient_accumulation=1, **kwargs) -> DataLoader:
    if isinstance(data, TransformableDataset):
        dataset = data
    else:
        # Collect dataset-construction options from the config, defaulting to None.
        args = dict((k, self.config.get(k, None)) for k in
                    ['delimiter', 'max_seq_len', 'sent_delimiter', 'char_level', 'hard_constraint'])
        dataset = self.build_dataset(data, **args)
    if self.config.token_key is None:
        # Fall back to the first field of the first sample as the token key.
        self.config.token_key = next(iter(dataset[0]))
        if logger:  # logger defaults to None, so guard before using it
            logger.info(
                f'Guess [bold][blue]token_key={self.config.token_key}[/blue][/bold] according to the '
                f'training dataset: [blue]{dataset}[/blue]')
    dataset.append_transform(self.tokenizer_transform)
    dataset.append_transform(self.last_transform())
    if not isinstance(data, list):
        dataset.purge_cache()
    if self.vocabs.mutable:
        self.build_vocabs(dataset, logger)
    if sampler_builder is not None:
        # Gradient accumulation only matters when shuffling (i.e. training).
        sampler = sampler_builder.build([len(x[f'{self.config.token_key}_input_ids']) for x in dataset], shuffle,
                                        gradient_accumulation=gradient_accumulation if shuffle else 1)
    else:
        sampler = None
    return PadSequenceDataLoader(dataset, batch_size, shuffle, device=device, batch_sampler=sampler)
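Example #2 composes preprocessing lazily via append_transform: each transform is applied per sample on access. A rough sketch of that idea, assuming nothing about HanLP's actual TransformableDataset beyond what the call sites show (ChainedDataset is an illustrative name):

class ChainedDataset:
    # Applies a list of per-sample transforms on access, mimicking append_transform.
    def __init__(self, samples):
        self.samples = list(samples)
        self.transforms = []

    def append_transform(self, fn):
        self.transforms.append(fn)
        return self

    def __getitem__(self, index):
        sample = dict(self.samples[index])
        for fn in self.transforms:
            sample = fn(sample)
        return sample

    def __len__(self):
        return len(self.samples)

# Usage:
# vocab = {'hello': 1, 'world': 2}
# ds = ChainedDataset([{'token': ['hello', 'world']}])
# ds.append_transform(lambda s: {**s, 'token_input_ids': [vocab.get(t, 0) for t in s['token']]})
# print(ds[0])  # {'token': ['hello', 'world'], 'token_input_ids': [1, 2]}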
Example #3
def build_dataloader(self, data, batch_size,
                     gradient_accumulation=1,
                     shuffle=False,
                     sampler_builder: SamplerBuilder = None,
                     device=None,
                     logger: logging.Logger = None,
                     **kwargs) -> DataLoader:
    dataset = self.build_dataset(data, not shuffle)
    if self.vocabs.mutable:
        self.build_vocabs(dataset, logger)
    self.finalize_dataset(dataset, logger)
    if isinstance(data, str):
        # A file path was passed in: purge any stale cache, then iterate once
        # so every sample is preprocessed and cached, reporting progress.
        dataset.purge_cache()
        timer = CountdownTimer(len(dataset))
        max_num_tokens = 0
        for each in dataset:
            max_num_tokens = max(max_num_tokens, len(each['text_token_ids']))
            timer.log(f'Preprocessing and caching samples (longest sequence {max_num_tokens}) '
                      f'[blink][yellow]...[/yellow][/blink]')
        if self.vocabs.mutable:
            self.vocabs.lock()
            self.vocabs.summary(logger)

    # Default to length-sorted batching capped at 500 tokens per batch.
    if not sampler_builder:
        sampler_builder = SortingSamplerBuilder(batch_max_tokens=500)
    sampler = sampler_builder.build([len(x['text_token_ids']) for x in dataset], shuffle,
                                    gradient_accumulation if dataset.cache else 1)
    return self._create_dataloader(dataset, batch_size, device, sampler, shuffle)
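SortingSamplerBuilder(batch_max_tokens=500) caps each batch by a padded-token budget rather than a fixed sample count, so short sequences pack into large batches and long ones into small batches. A minimal re-implementation of that idea in plain Python (a sketch of the technique, not HanLP's actual class):

def token_budget_batches(lengths, batch_max_tokens=500):
    # Sort indices by sequence length, then pack greedily until the padded
    # size of the batch (longest sequence * batch size) would exceed the budget.
    order = sorted(range(len(lengths)), key=lengths.__getitem__)
    batches, current = [], []
    for idx in order:
        candidate = current + [idx]
        max_len = lengths[candidate[-1]]  # sorted order: the last item is longest
        if current and max_len * len(candidate) > batch_max_tokens:
            batches.append(current)
            current = [idx]
        else:
            current = candidate
    if current:
        batches.append(current)
    return batches

# Usage:
# lens = [3, 120, 7, 480, 12, 64]
# for batch in token_budget_batches(lens, batch_max_tokens=500):
#     print(batch, [lens[i] for i in batch])

Note that a single sequence longer than the budget still forms its own one-element batch, which matches the usual behavior of token-budget samplers.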