Code Example #1
File: amr.py  Project: lei1993/HanLP
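This override builds the DataLoader for the AMR parsing task inside a multi-task pipeline: raw inputs are converted into samples, a dataset with per-sample lengths is built, vocabularies are finalized on the first pass, and the result is wrapped in a prefetching loader.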
 def build_dataloader(self,
                      data,
                      transform: Callable = None,
                      training=False,
                      device=None,
                      logger: logging.Logger = None,
                      cache=False,
                      gradient_accumulation=1,
                      **kwargs) -> DataLoader:
     if isinstance(data, list):
         # Raw inputs (e.g., lists of tokens) are converted into sample dicts first
         data = GraphAbstractMeaningRepresentationParser.build_samples(
             self, data)
     dataset, lens = GraphAbstractMeaningRepresentationParser.build_dataset(
         self, data, logger=logger, transform=transform, training=training)
     if self.vocabs.mutable:
         # Vocabularies are built only on the first pass, while they are still mutable
         GraphAbstractMeaningRepresentationParser.build_vocabs(
             self, dataset, logger)
     # Wrap the DataLoader so each batch is merged into a dict, then padded and
     # moved to the target device; the sampler groups samples by length
     dataloader = PrefetchDataLoader(
         DataLoader(batch_sampler=self.sampler_builder.build(
             lens,
             shuffle=training,
             gradient_accumulation=gradient_accumulation),
                    dataset=dataset,
                    collate_fn=merge_list_of_dict,
                    num_workers=0),
         batchify=self.build_batchify(device, training),
         prefetch=None)
     return dataloader
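A minimal usage sketch, not taken from the source: parser is assumed to be an already-fitted GraphAbstractMeaningRepresentationParser whose vocabs and sampler_builder were configured during training, and the tokenized input format is an assumption about what build_samples accepts.

# Hedged sketch; `parser` and the input format are assumptions for
# illustration, not part of the snippet above.
tokens = [['The', 'boy', 'wants', 'to', 'go', '.']]
dataloader = parser.build_dataloader(tokens, training=False, device='cpu')
for batch in dataloader:
    # each batch is a dict merged by merge_list_of_dict, then padded and
    # moved to the device by build_batchify
    print(sorted(batch.keys()))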
Code Example #2
File: multi_task_learning.py  Project: turkeymz/HanLP
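The multi-task variant assembles one DataLoader per registered task: it resolves the requested split for each task, delegates to that task's own build_dataloader (example #1 above is one such implementation), optionally caches batches to disk, and, for the training split, logs how samples are distributed across tasks.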
    def build_dataloader(self,
                         data,
                         batch_size,
                         shuffle=False,
                         device=None,
                         logger: logging.Logger = None,
                         gradient_accumulation=1,
                         tau: float = 0.8,
                         prune=None,
                         prefetch=None,
                         tasks_need_custom_eval=None,
                         cache=False,
                         debug=False,
                         **kwargs) -> DataLoader:
        # This method is called only during training or evaluation, never during prediction
        # tau is a temperature that smooths the per-task sampling weights
        # (see the sampling_weights computation below)
        dataloader = MultiTaskDataLoader(training=shuffle, tau=tau)
        for i, (task_name, task) in enumerate(self.tasks.items()):
            encoder_transform, transform = self.build_transform(task)
            training = None
            # Resolve the split name ('trn'/'dev'/'tst') to each task's own data;
            # any other value is passed through as raw data
            if data == 'trn':
                # In debug mode, train on the smaller dev set for faster iteration
                if debug:
                    _data = task.dev
                else:
                    _data = task.trn
                training = True
            elif data == 'dev':
                _data = task.dev
                training = False
            elif data == 'tst':
                _data = task.tst
                training = False
            else:
                _data = data
            if isinstance(data, str):
                logger.info(
                    f'[yellow]{i + 1} / {len(self.tasks)}[/yellow] Building [blue]{data}[/blue] dataset for '
                    f'[cyan]{task_name}[/cyan] ...')
            # Adjust Tokenizer according to task config
            config = copy(task.config)
            config.pop('transform', None)
            task_dataloader: DataLoader = task.build_dataloader(
                _data,
                transform,
                training,
                device,
                logger,
                tokenizer=encoder_transform.tokenizer,
                gradient_accumulation=gradient_accumulation,
                cache=isinstance(data, str),
                **config)
            # if prune:
            #     # noinspection PyTypeChecker
            #     task_dataset: TransformDataset = task_dataloader.dataset
            #     size_before = len(task_dataset)
            #     task_dataset.prune(prune)
            #     size_after = len(task_dataset)
            #     num_pruned = size_before - size_after
            #     logger.info(f'Pruned [yellow]{num_pruned} ({num_pruned / size_before:.1%})[/yellow] '
            #                 f'samples out of {size_before}.')
            # Optionally cache materialized batches to disk (one .pt file per task)
            # so later passes skip the transform pipeline
            if cache and data in ('trn', 'dev'):
                task_dataloader: CachedDataLoader = CachedDataLoader(
                    task_dataloader,
                    f'{cache}/{os.getpid()}-{data}-{task_name.replace("/", "-")}-cache.pt'
                    if isinstance(cache, str) else None)
            dataloader.dataloaders[task_name] = task_dataloader
        if data == 'trn':
            # Log a table showing how training samples are distributed across tasks
            sampling_weights, total_size = dataloader.sampling_weights
            headings = [
                'task', '#batches', '%batches', '#scaled', '%scaled', '#epoch'
            ]
            matrix = []
            min_epochs = []
            for (task_name,
                 dataset), weight in zip(dataloader.dataloaders.items(),
                                         sampling_weights):
                epochs = len(dataset) / weight / total_size
                matrix.append([
                    f'{task_name}',
                    len(dataset), f'{len(dataset) / total_size:.2%}',
                    int(total_size * weight), f'{weight:.2%}', f'{epochs:.2f}'
                ])
                min_epochs.append(epochs)
            # Highlight in red the task that needs the most epochs to consume its data
            longest = int(torch.argmax(torch.tensor(min_epochs)))
            table = markdown_table(headings, matrix)
            rows = table.splitlines()
            cells = rows[longest + 2].split('|')
            cells[-2] = cells[-2].replace(
                f'{min_epochs[longest]:.2f}',
                f'[bold][red]{min_epochs[longest]:.2f}[/red][/bold]')
            rows[longest + 2] = '|'.join(cells)
            logger.info(
                f'[bold][yellow]{"Samples Distribution": ^{len(rows[0])}}[/yellow][/bold]'
            )
            logger.info('\n'.join(rows))
        # Prefetch only for training, or for evaluation when no task requires custom evaluation
        if prefetch and (data == 'trn' or not tasks_need_custom_eval):
            dataloader = PrefetchDataLoader(dataloader, prefetch=prefetch)

        return dataloader
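A hedged sketch of how this method could be driven, assuming mtl is a MultiTaskLearning instance with its tasks already registered; the argument values and the logger are illustrative assumptions, not taken from the source.

import logging

# Hedged sketch; `mtl` and all argument values are assumptions for illustration.
logger = logging.getLogger('mtl')
trn = mtl.build_dataloader(
    'trn',
    batch_size=32,  # accepted for API compatibility; unused in the body above
    shuffle=True,   # also becomes the training flag of MultiTaskDataLoader
    device='cpu',
    logger=logger,
    gradient_accumulation=2,
    tau=0.8,        # task-sampling temperature
    prefetch=10)
for batch in trn:
    # Task batches are interleaved according to the sampling weights logged
    # in the "Samples Distribution" table.
    pass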