Example #1
 def build_dataloader(self,
                      data,
                      batch_size,
                      shuffle,
                      device,
                      logger=None,
                      **kwargs) -> DataLoader:
     vocabs = self.vocabs
     token_embed = self._convert_embed()
     dataset = data if isinstance(
         data, TransformableDataset) else self.build_dataset(
             data, transform=[vocabs])
     if vocabs.mutable:
         # Before building vocabs, let embeddings submit their vocabs; some embeddings may opt out
         # because their transforms are not relevant to vocabs.
         if isinstance(token_embed, Embedding):
             transform = token_embed.transform(vocabs=vocabs)
             if transform:
                 dataset.transform.insert(-1, transform)
         self.build_vocabs(dataset, logger)
     if isinstance(token_embed, Embedding):
         # Vocabs built, now add all transforms to the pipeline. Be careful about redundant ones.
         transform = token_embed.transform(vocabs=vocabs)
         if transform and transform not in dataset.transform:
             dataset.transform.insert(-1, transform)
     sampler = SortingSampler(
         [len(sample[self.config.token_key]) for sample in dataset],
         batch_size,
         shuffle=shuffle)
     return PadSequenceDataLoader(dataset,
                                  device=device,
                                  batch_sampler=sampler,
                                  vocabs=vocabs)
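
Example #1 follows HanLP's standard recipe: build (or reuse) a TransformableDataset, let the token embedding contribute its transforms before and after the vocabularies are built, then batch samples sorted by token length and pad each batch. Below is a minimal, HanLP-free sketch of that batching pattern in plain PyTorch; LengthSortedBatchSampler and pad_collate are illustrative stand-ins for SortingSampler and PadSequenceDataLoader, not their actual implementations.

import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader


class LengthSortedBatchSampler:
    """Yield batches of indices ordered by sequence length (stand-in for SortingSampler)."""

    def __init__(self, lengths, batch_size):
        self.order = sorted(range(len(lengths)), key=lambda i: lengths[i])
        self.batch_size = batch_size

    def __iter__(self):
        for start in range(0, len(self.order), self.batch_size):
            yield self.order[start:start + self.batch_size]

    def __len__(self):
        return (len(self.order) + self.batch_size - 1) // self.batch_size


def pad_collate(batch):
    # batch is a list of 1-D LongTensors of token ids; pad to the longest sequence in the batch
    return pad_sequence(batch, batch_first=True, padding_value=0)


samples = [torch.randint(1, 100, (n,)) for n in (5, 12, 7, 3)]
sampler = LengthSortedBatchSampler([len(s) for s in samples], batch_size=2)
loader = DataLoader(samples, batch_sampler=sampler, collate_fn=pad_collate)
for batch in loader:
    print(batch.shape)  # similarly long sequences end up together, so padding waste stays small
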
Example #2
File: mlm.py  Project: lei1993/HanLP
 def build_dataloader(self,
                      data,
                      batch_size,
                      shuffle=False,
                      device=None,
                      logger: logging.Logger = None,
                      verbose=False,
                      **kwargs) -> DataLoader:
     dataset = MaskedLanguageModelDataset(
         [{
             'token': x
         } for x in data],
         generate_idx=True,
         transform=TransformerTextTokenizer(self.tokenizer,
                                            text_a_key='token'))
     if verbose:
         verbose = CountdownTimer(len(dataset))
     lens = []
     for each in dataset:
         lens.append(len(each['token_input_ids']))
         if verbose:
             verbose.log(
                 'Preprocessing and caching samples [blink][yellow]...[/yellow][/blink]'
             )
     dataloader = PadSequenceDataLoader(dataset,
                                        batch_sampler=SortingSampler(
                                            lens, batch_size=batch_size),
                                        device=device)
     return dataloader
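
Example #2 wraps each raw text into a dict keyed 'token', lets TransformerTextTokenizer add the subword ids, and iterates the dataset once so every sample is tokenized (and cached) while the lengths for SortingSampler are collected. A rough, dependency-free sketch of that preparation step is below; fake_tokenize is a hypothetical stand-in for the real subword tokenizer.

def to_samples(data):
    # mirror Example #2: each raw text becomes {'token': text}
    return [{'token': x} for x in data]


def fake_tokenize(sample):
    # stand-in for the transformer tokenizer: one integer id per whitespace token
    sample['token_input_ids'] = [hash(w) % 30000 for w in sample['token'].split()]
    return sample


dataset = [fake_tokenize(s) for s in to_samples(['HanLP is an NLP toolkit',
                                                 'masked language modelling'])]
lens = [len(s['token_input_ids']) for s in dataset]
print(lens)  # [5, 3]; these lengths feed the sorting sampler so similarly long texts share a batch
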
Example #3
 def build_dataloader(self, data, batch_size, shuffle, device, text_a_key, text_b_key,
                      label_key,
                      logger: logging.Logger = None,
                      sorting=True,
                      **kwargs) -> DataLoader:
     if not batch_size:
         batch_size = self.config.batch_size
     dataset = self.build_dataset(data)
     dataset.append_transform(self.vocabs)
     if self.vocabs.mutable:
         if not any([text_a_key, text_b_key]):
             if len(dataset.headers) == 2:
                 self.config.text_a_key = dataset.headers[0]
                 self.config.label_key = dataset.headers[1]
             elif len(dataset.headers) >= 3:
                 self.config.text_a_key, self.config.text_b_key, self.config.label_key = dataset.headers[0], \
                                                                                         dataset.headers[1], \
                                                                                         dataset.headers[-1]
             else:
                 raise ValueError('Wrong dataset format')
             report = {'text_a_key', 'text_b_key', 'label_key'}
             report = dict((k, self.config[k]) for k in report)
             report = [f'{k}={v}' for k, v in report.items() if v]
             report = ', '.join(report)
             logger.info(f'Guess [bold][blue]{report}[/blue][/bold] according to the headers of training dataset: '
                         f'[blue]{dataset}[/blue]')
         self.build_vocabs(dataset, logger)
         dataset.purge_cache()
     # if self.config.transform:
     #     dataset.append_transform(self.config.transform)
     dataset.append_transform(TransformerTextTokenizer(tokenizer=self.transformer_tokenizer,
                                                       text_a_key=self.config.text_a_key,
                                                       text_b_key=self.config.text_b_key,
                                                       max_seq_length=self.config.max_seq_length,
                                                       truncate_long_sequences=self.config.truncate_long_sequences,
                                                       output_key=''))
     batch_sampler = None
     if sorting and not isdebugging():
         if dataset.cache and len(dataset) > 1000:
             timer = CountdownTimer(len(dataset))
             lens = []
             for idx, sample in enumerate(dataset):
                 lens.append(len(sample['input_ids']))
                 timer.log('Pre-processing and caching dataset [blink][yellow]...[/yellow][/blink]',
                           ratio_percentage=None)
         else:
             lens = [len(sample['input_ids']) for sample in dataset]
         batch_sampler = SortingSampler(lens, batch_size=batch_size, shuffle=shuffle,
                                        batch_max_tokens=self.config.batch_max_tokens)
     return PadSequenceDataLoader(dataset, batch_size, shuffle, batch_sampler=batch_sampler, device=device)
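
Example #3 first guesses text_a_key/text_b_key/label_key from the dataset headers, then precomputes input lengths so SortingSampler can also respect batch_max_tokens. The sketch below shows one plausible way a token-budget batcher can work, i.e. what capping batch_max_tokens implies; it is an illustration, not HanLP's actual SortingSampler logic.

def budget_batches(lengths, batch_max_tokens):
    # sort indices by length, then grow each batch until its padded cost would exceed the budget
    order = sorted(range(len(lengths)), key=lambda i: lengths[i])
    batch, max_len = [], 0
    for idx in order:
        new_max = max(max_len, lengths[idx])
        # padded cost of a batch is (number of samples) * (longest sample in it)
        if batch and new_max * (len(batch) + 1) > batch_max_tokens:
            yield batch
            batch, max_len = [], 0
            new_max = lengths[idx]
        batch.append(idx)
        max_len = new_max
    if batch:
        yield batch


print(list(budget_batches([5, 12, 7, 3, 9], batch_max_tokens=20)))  # [[3, 0], [2, 4], [1]]
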
Example #4
    def build_dataloader(self, data, batch_size, shuffle, device, logger: logging.Logger,
                         generate_idx=False, **kwargs) -> DataLoader:
        batch_max_tokens = self.config.batch_max_tokens
        gradient_accumulation = self.config.get('gradient_accumulation', 1)
        if batch_size:
            batch_size //= gradient_accumulation
        if batch_max_tokens:
            batch_max_tokens //= gradient_accumulation
        dataset = self.build_dataset(data, generate_idx, logger)

        sampler = SortingSampler([x['token_length'] for x in dataset],
                                 batch_size=batch_size,
                                 batch_max_tokens=batch_max_tokens,
                                 shuffle=shuffle)
        return PadSequenceDataLoader(batch_sampler=sampler,
                                     device=device,
                                     dataset=dataset)
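
Example #4 divides both batch_size and batch_max_tokens by gradient_accumulation, so each forward/backward pass sees a smaller micro-batch while the optimizer still steps on the configured effective batch. A quick arithmetic check of that scaling (the numbers are made up for illustration):

batch_size = 32                # configured batch size per optimizer step
batch_max_tokens = 5000        # configured token budget per optimizer step
gradient_accumulation = 4

micro_batch_size = batch_size // gradient_accumulation          # 8 samples per forward pass
micro_batch_tokens = batch_max_tokens // gradient_accumulation  # 1250 tokens per forward pass

# Accumulating gradients over 4 micro-batches recovers the configured effective batch.
print(micro_batch_size * gradient_accumulation)    # 32
print(micro_batch_tokens * gradient_accumulation)  # 5000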