def build_dataloader(self, data, batch_size, shuffle, device, logger=None, **kwargs) -> DataLoader:
    vocabs = self.vocabs
    token_embed = self._convert_embed()
    dataset = data if isinstance(data, TransformableDataset) else self.build_dataset(data, transform=[vocabs])
    if vocabs.mutable:
        # Before building vocabs, let the embedding submit its transform; some embeddings may opt out
        # because their transforms are not relevant to vocabs.
        if isinstance(token_embed, Embedding):
            transform = token_embed.transform(vocabs=vocabs)
            if transform:
                dataset.transform.insert(-1, transform)
        self.build_vocabs(dataset, logger)
    if isinstance(token_embed, Embedding):
        # Vocabs are built; now add the embedding's transform to the pipeline, avoiding redundant ones.
        transform = token_embed.transform(vocabs=vocabs)
        if transform and transform not in dataset.transform:
            dataset.transform.insert(-1, transform)
    sampler = SortingSampler([len(sample[self.config.token_key]) for sample in dataset],
                             batch_size,
                             shuffle=shuffle)
    return PadSequenceDataLoader(dataset, device=device, batch_sampler=sampler, vocabs=vocabs)
def build_dataloader(self, data, batch_size, shuffle=False, device=None, logger: logging.Logger = None,
                     verbose=False, **kwargs) -> DataLoader:
    dataset = MaskedLanguageModelDataset(
        [{'token': x} for x in data],
        generate_idx=True,
        transform=TransformerTextTokenizer(self.tokenizer, text_a_key='token'))
    if verbose:
        # Reuse the verbose flag as a progress timer.
        verbose = CountdownTimer(len(dataset))
    lens = []
    # Iterate once to trigger tokenization and record lengths for the sorting sampler.
    for each in dataset:
        lens.append(len(each['token_input_ids']))
        if verbose:
            verbose.log('Preprocessing and caching samples [blink][yellow]...[/yellow][/blink]')
    dataloader = PadSequenceDataLoader(dataset,
                                       batch_sampler=SortingSampler(lens, batch_size=batch_size),
                                       device=device)
    return dataloader
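# All of these builders share one pattern: compute per-sample lengths, group samples of similar
# length into the same batch, and pad only within each batch. The sketch below is a minimal,
# self-contained illustration of that idea in plain PyTorch; LengthSortingSampler and pad_collate
# are hypothetical names for this example, not the SortingSampler/PadSequenceDataLoader used above.
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Sampler


class LengthSortingSampler(Sampler):
    """Yield batches of indices grouped by sequence length to minimise padding."""

    def __init__(self, lengths, batch_size):
        self.lengths = lengths
        self.batch_size = batch_size

    def __iter__(self):
        # Sort sample indices by length, then cut the sorted order into consecutive batches.
        order = sorted(range(len(self.lengths)), key=lambda i: self.lengths[i])
        for start in range(0, len(order), self.batch_size):
            yield order[start:start + self.batch_size]

    def __len__(self):
        return (len(self.lengths) + self.batch_size - 1) // self.batch_size


def pad_collate(batch):
    # Pad variable-length id sequences within a single batch only.
    ids = [torch.tensor(sample['input_ids']) for sample in batch]
    return {'input_ids': pad_sequence(ids, batch_first=True, padding_value=0),
            'lengths': torch.tensor([len(x) for x in ids])}


if __name__ == '__main__':
    samples = [{'input_ids': list(range(n))} for n in (5, 2, 9, 3, 7)]
    lengths = [len(s['input_ids']) for s in samples]
    loader = DataLoader(samples,
                        batch_sampler=LengthSortingSampler(lengths, batch_size=2),
                        collate_fn=pad_collate)
    for batch in loader:
        print(batch['input_ids'].shape, batch['lengths'].tolist())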
def build_dataloader(self, data, batch_size, shuffle, device, text_a_key, text_b_key, label_key,
                     logger: logging.Logger = None, sorting=True, **kwargs) -> DataLoader:
    if not batch_size:
        batch_size = self.config.batch_size
    dataset = self.build_dataset(data)
    dataset.append_transform(self.vocabs)
    if self.vocabs.mutable:
        if not any([text_a_key, text_b_key]):
            # Guess which columns hold text_a, text_b and the label from the dataset headers.
            if len(dataset.headers) == 2:
                self.config.text_a_key = dataset.headers[0]
                self.config.label_key = dataset.headers[1]
            elif len(dataset.headers) >= 3:
                self.config.text_a_key, self.config.text_b_key, self.config.label_key = (
                    dataset.headers[0], dataset.headers[1], dataset.headers[-1])
            else:
                raise ValueError('Wrong dataset format')
            report = {'text_a_key', 'text_b_key', 'label_key'}
            report = dict((k, self.config[k]) for k in report)
            report = [f'{k}={v}' for k, v in report.items() if v]
            report = ', '.join(report)
            logger.info(f'Guess [bold][blue]{report}[/blue][/bold] according to the headers of training dataset: '
                        f'[blue]{dataset}[/blue]')
        self.build_vocabs(dataset, logger)
        dataset.purge_cache()
    # if self.config.transform:
    #     dataset.append_transform(self.config.transform)
    dataset.append_transform(TransformerTextTokenizer(tokenizer=self.transformer_tokenizer,
                                                      text_a_key=self.config.text_a_key,
                                                      text_b_key=self.config.text_b_key,
                                                      max_seq_length=self.config.max_seq_length,
                                                      truncate_long_sequences=self.config.truncate_long_sequences,
                                                      output_key=''))
    batch_sampler = None
    # Sort samples by token count so each batch is tightly packed; skip sorting when a debugger is attached.
    if sorting and not isdebugging():
        if dataset.cache and len(dataset) > 1000:
            timer = CountdownTimer(len(dataset))
            lens = []
            for idx, sample in enumerate(dataset):
                lens.append(len(sample['input_ids']))
                timer.log('Pre-processing and caching dataset [blink][yellow]...[/yellow][/blink]',
                          ratio_percentage=None)
        else:
            lens = [len(sample['input_ids']) for sample in dataset]
        batch_sampler = SortingSampler(lens, batch_size=batch_size, shuffle=shuffle,
                                       batch_max_tokens=self.config.batch_max_tokens)
    return PadSequenceDataLoader(dataset, batch_size, shuffle, batch_sampler=batch_sampler, device=device)
def build_dataloader(self, data, batch_size, shuffle, device, logger: logging.Logger, generate_idx=False,
                     **kwargs) -> DataLoader:
    batch_max_tokens = self.config.batch_max_tokens
    gradient_accumulation = self.config.get('gradient_accumulation', 1)
    # Split the batch budget across gradient accumulation steps so the effective batch size stays as configured.
    if batch_size:
        batch_size //= gradient_accumulation
    if batch_max_tokens:
        batch_max_tokens //= gradient_accumulation
    dataset = self.build_dataset(data, generate_idx, logger)
    sampler = SortingSampler([x['token_length'] for x in dataset],
                             batch_size=batch_size,
                             batch_max_tokens=batch_max_tokens,
                             shuffle=shuffle)
    return PadSequenceDataLoader(batch_sampler=sampler, device=device, dataset=dataset)
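# The last builder divides both batch_size and batch_max_tokens by gradient_accumulation: each
# forward/backward step then sees a smaller batch, but the optimizer only steps after accumulating
# gradients over gradient_accumulation mini-batches, so the effective batch size matches the
# configured one. A quick check of that arithmetic with illustrative numbers (not from the config):
configured_batch_size = 32
gradient_accumulation = 4

per_step_batch_size = configured_batch_size // gradient_accumulation  # 8 samples per step

# Gradients from 4 steps are accumulated before one optimizer update.
effective_batch_size = per_step_batch_size * gradient_accumulation
assert effective_batch_size == configured_batch_size  # 8 * 4 == 32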