def build_dataloader(self, data, transform: Callable = None, training=False, device=None,
                     logger: logging.Logger = None, cache=False, gradient_accumulation=1,
                     **kwargs) -> DataLoader:
    """Construct a prefetching dataloader over AMR graph samples.

    Raw list input is first converted into samples, then into a dataset whose
    per-sample lengths drive batch sampling. Vocabularies are built on the fly
    while they are still mutable (i.e. during the first training run).

    Args:
        data: Either a list of raw inputs or a path/handle understood by
            ``build_dataset`` (exact accepted forms depend on the parser —
            not visible here).
        transform: Optional sample-level transform applied by the dataset.
        training: Whether batches should be shuffled (training mode).
        device: Device passed to ``build_batchify`` for moving batches.
        logger: Logger forwarded to dataset/vocab construction.
        cache: Accepted for interface compatibility; not used in this body.
        gradient_accumulation: Forwarded to the batch sampler builder.
        **kwargs: Ignored extras for signature compatibility.

    Returns:
        A ``PrefetchDataLoader`` wrapping a single-worker ``DataLoader``.
    """
    # Raw lists need to be normalized into sample dicts first.
    if isinstance(data, list):
        data = GraphAbstractMeaningRepresentationParser.build_samples(self, data)
    dataset, lens = GraphAbstractMeaningRepresentationParser.build_dataset(
        self, data, logger=logger, transform=transform, training=training)
    # Vocabs are only mutable before they have been frozen, so this runs
    # once during initial training.
    if self.vocabs.mutable:
        GraphAbstractMeaningRepresentationParser.build_vocabs(self, dataset, logger)
    batch_sampler = self.sampler_builder.build(
        lens, shuffle=training, gradient_accumulation=gradient_accumulation)
    inner_loader = DataLoader(batch_sampler=batch_sampler,
                              dataset=dataset,
                              collate_fn=merge_list_of_dict,
                              num_workers=0)
    # prefetch=None: the wrapper is used for its batchify hook only.
    return PrefetchDataLoader(inner_loader,
                              batchify=self.build_batchify(device, training),
                              prefetch=None)
def build_dataloader(self, data, batch_size, shuffle=False, device=None,
                     logger: logging.Logger = None, gradient_accumulation=1,
                     tau: float = 0.8, prune=None, prefetch=None,
                     tasks_need_custom_eval=None, cache=False, debug=False,
                     **kwargs) -> DataLoader:
    """Build one :class:`MultiTaskDataLoader` aggregating a dataloader per task.

    Args:
        data: ``'trn'`` / ``'dev'`` / ``'tst'`` to select each task's own
            split, or a concrete dataset passed to every task as-is.
        batch_size: Accepted for interface compatibility; not referenced in
            this body (each task sizes its own batches).
        shuffle: True enables training mode on the multi-task loader.
        device: Forwarded to each task's ``build_dataloader``.
        logger: Used for progress and the sampling-distribution table.
            NOTE(review): dereferenced unconditionally when ``data`` is a
            str — presumably callers always supply it then; verify.
        gradient_accumulation: Forwarded to each task.
        tau: Sampling temperature for ``MultiTaskDataLoader`` — exact
            semantics live in that class, not visible here.
        prune: Currently unused (the pruning code below is commented out).
        prefetch: If truthy, wrap the result in a ``PrefetchDataLoader``.
        tasks_need_custom_eval: When truthy, suppresses prefetching for
            non-training data.
        cache: Truthy enables ``CachedDataLoader`` for trn/dev; a str is
            used as the cache directory, otherwise an anonymous cache file.
        debug: During training, substitute the (smaller) dev set.
        **kwargs: Ignored extras for signature compatibility.

    Returns:
        The multi-task dataloader, possibly wrapped for prefetching.
    """
    # This method is only called during training or evaluation but not prediction
    dataloader = MultiTaskDataLoader(training=shuffle, tau=tau)
    for i, (task_name, task) in enumerate(self.tasks.items()):
        encoder_transform, transform = self.build_transform(task)
        # training stays None when a concrete dataset (not a split name)
        # was passed in.
        training = None
        if data == 'trn':
            if debug:
                _data = task.dev
            else:
                _data = task.trn
            training = True
        elif data == 'dev':
            _data = task.dev
            training = False
        elif data == 'tst':
            _data = task.tst
            training = False
        else:
            _data = data
        if isinstance(data, str):
            logger.info(
                f'[yellow]{i + 1} / {len(self.tasks)}[/yellow] Building [blue]{data}[/blue] dataset for '
                f'[cyan]{task_name}[/cyan] ...')
        # Adjust Tokenizer according to task config; 'transform' is popped so
        # the explicit transform argument below is not overridden by config.
        config = copy(task.config)
        config.pop('transform', None)
        task_dataloader: DataLoader = task.build_dataloader(
            _data, transform, training, device, logger,
            tokenizer=encoder_transform.tokenizer,
            gradient_accumulation=gradient_accumulation,
            # Only cache when loading a named split, not an ad-hoc dataset.
            cache=isinstance(data, str), **config)
        # if prune:
        #     # noinspection PyTypeChecker
        #     task_dataset: TransformDataset = task_dataloader.dataset
        #     size_before = len(task_dataset)
        #     task_dataset.prune(prune)
        #     size_after = len(task_dataset)
        #     num_pruned = size_before - size_after
        #     logger.info(f'Pruned [yellow]{num_pruned} ({num_pruned / size_before:.1%})[/yellow] '
        #                 f'samples out of {size_before}.')
        if cache and data in ('trn', 'dev'):
            # A str cache is treated as a directory; the filename embeds the
            # pid, split, and task name (slashes sanitized for the filesystem).
            task_dataloader: CachedDataLoader = CachedDataLoader(
                task_dataloader,
                f'{cache}/{os.getpid()}-{data}-{task_name.replace("/", "-")}-cache.pt'
                if isinstance(cache, str) else None)
        dataloader.dataloaders[task_name] = task_dataloader
    if data == 'trn':
        # Log a markdown table showing how batches are distributed across
        # tasks under the sampling weights.
        sampling_weights, total_size = dataloader.sampling_weights
        headings = [
            'task', '#batches', '%batches', '#scaled', '%scaled',
            '#epoch'
        ]
        matrix = []
        min_epochs = []
        for (task_name, dataset), weight in zip(dataloader.dataloaders.items(),
                                                sampling_weights):
            # Number of sampling epochs needed to make one full pass over
            # this task: len(dataset) / (weight * total_size).
            epochs = len(dataset) / weight / total_size
            matrix.append([
                f'{task_name}', len(dataset), f'{len(dataset) / total_size:.2%}',
                int(total_size * weight), f'{weight:.2%}', f'{epochs:.2f}'
            ])
            min_epochs.append(epochs)
        # Despite the name, min_epochs holds every task's epoch count;
        # argmax picks the task needing the MOST epochs to be covered.
        longest = int(torch.argmax(torch.tensor(min_epochs)))
        table = markdown_table(headings, matrix)
        rows = table.splitlines()
        # Row longest + 2 skips the header and separator rows; cells[-2] is
        # the '#epoch' column (a trailing pipe yields an empty last cell) —
        # assumes markdown_table emits that layout; TODO confirm.
        cells = rows[longest + 2].split('|')
        cells[-2] = cells[-2].replace(
            f'{min_epochs[longest]:.2f}',
            f'[bold][red]{min_epochs[longest]:.2f}[/red][/bold]')
        rows[longest + 2] = '|'.join(cells)
        logger.info(
            f'[bold][yellow]{"Samples Distribution": ^{len(rows[0])}}[/yellow][/bold]'
        )
        logger.info('\n'.join(rows))
    # Prefetching is skipped for eval data when tasks run custom evaluation,
    # since prefetched batches may not suit per-task eval loops.
    if prefetch and (data == 'trn' or not tasks_need_custom_eval):
        dataloader = PrefetchDataLoader(dataloader, prefetch=prefetch)
    return dataloader