예제 #1
0
 def evaluate_dataloader(self,
                         data: DataLoader,
                         criterion: Callable,
                         metric,
                         logger,
                         ratio_width=None,
                         output=False,
                         official=False,
                         confusion_matrix=False,
                         **kwargs):
     """Evaluate the model on ``data`` and optionally compute the official SRL scores.

     Args:
         data: Batches to evaluate; each batch is passed to ``self.feed_batch``.
         criterion: Not read in this body (kept for interface compatibility).
         metric: Metric object reset via ``self.reset_metrics`` and updated via ``self.update_metrics``.
         logger: Forwarded to the progress timer; when truthy and ``official`` is set, the score
             tables are also logged through it.
         ratio_width: Forwarded to ``timer.log``.
         output: Not read in this body.
         official: When ``True``, gather ``batch['token']``, ``batch['srl']`` and
             ``output_dict['prediction']`` across all batches and score them with ``compute_srl_f1``.
         confusion_matrix: When ``True`` (together with ``official`` and a truthy ``logger``), also
             log a gold-vs-predicted label confusion matrix.
         **kwargs: Ignored.

     Returns:
         A tuple ``(average_loss, scores)``. ``average_loss`` is the summed batch loss divided by
         ``timer.total`` (the timer was built from ``len(data)``). ``scores`` is the result of
         ``compute_srl_f1`` when ``official`` is set, otherwise ``metric`` itself.
     """
     self.model.eval()
     self.reset_metrics(metric)
     timer = CountdownTimer(len(data))
     total_loss = 0
     if official:
         sentences = []
         gold = []
         pred = []
     for batch in data:
         output_dict = self.feed_batch(batch)
         if official:
             sentences += batch['token']
             gold += batch['srl']
             pred += output_dict['prediction']
         self.update_metrics(batch, output_dict, metric)
         loss = output_dict['loss']
         total_loss += loss.item()
         # Running average over the batches seen so far (assumes timer.current is 0-based — TODO confirm).
         timer.log(self.report_metrics(total_loss / (timer.current + 1), metric), ratio_percentage=None,
                   logger=logger,
                   ratio_width=ratio_width)
         # Drop the reference to the (presumably tensor) loss before the next batch.
         del loss
     if official:
         scores = compute_srl_f1(sentences, gold, pred)
         if logger:
             if confusion_matrix:
                 # Every label occurring on either side of a (gold, predicted) confusion pair.
                 labels = sorted(set(label for pair in scores.label_confusions.keys() for label in pair))
                 # One row per gold label: the label itself, then its count against each predicted label.
                 confusion_rows = [
                     [gold_label] + [scores.label_confusions.get((gold_label, pred_label), 0)
                                     for pred_label in labels]
                     for gold_label in labels
                 ]
                 matrix = markdown_table(['GOLD↓PRED→'] + labels, confusion_rows)
                 logger.info(f'{"Confusion Matrix": ^{len(matrix.splitlines()[0])}}')
                 logger.info(matrix)
             settings = [
                 ('Unlabeled', (scores.unlabeled_precision, scores.unlabeled_recall, scores.unlabeled_f1)),
                 ('Labeled', (scores.precision, scores.recall, scores.f1)),
                 ('Official', (scores.conll_precision, scores.conll_recall, scores.conll_f1)),
             ]
             score_rows = [[name] + [f'{value:.2%}' for value in prf] for name, prf in settings]
             table = markdown_table(['Settings', 'Precision', 'Recall', 'F1'], score_rows)
             logger.info(f'{"Scores": ^{len(table.splitlines()[0])}}')
             logger.info(table)
     else:
         scores = metric
     return total_loss / timer.total, scores
예제 #2
0
    def to_markdown(self, headings: Union[str, List[str]] = 'auto') -> str:
        r"""Convert into markdown string.

        Each word is rendered with ``str(word)`` and split on tabs into one table row. Every ``|`` inside a
        cell is replaced with ``⎮`` so it cannot be read as a markdown column separator. This matters for CoNLL
        fields such as ``FEATS`` (e.g. ``Case=Nom|Number=Sing``), ``DEPS`` and ``MISC``, which use ``|``
        internally.

        Args:
            headings: ``auto`` to automatically detect the word type. When passed a list of string, they are treated as
                        headings for each field.

        Returns:
            A markdown representation of this sentence.
        """
        cells = [[field.replace('|', '⎮') for field in str(word).split('\t')] for word in self]
        if headings == 'auto':
            # An empty sentence has no first word to inspect, so it falls back to the CoNLL-U headings.
            if self and isinstance(self[0], CoNLLWord):
                headings = [
                    'ID', 'FORM', 'LEMMA', 'CPOS', 'POS', 'FEATS', 'HEAD',
                    'DEPREL', 'PHEAD', 'PDEPREL'
                ]
            else:  # conllu
                headings = [
                    'ID', 'FORM', 'LEMMA', 'UPOS', 'XPOS', 'FEATS', 'HEAD',
                    'DEPREL', 'DEPS', 'MISC'
                ]
        # Per-column (heading, cell) alignment: ID and HEAD right-aligned, the rest left-aligned.
        alignment = [('^', '>'), ('^', '<'), ('^', '<'), ('^', '<'),
                     ('^', '<'), ('^', '<'), ('^', '>'), ('^', '<'),
                     ('^', '<'), ('^', '<')]
        return markdown_table(headings, cells, alignment=alignment)
예제 #3
0
    def build_dataloader(self,
                         data,
                         batch_size,
                         shuffle=False,
                         device=None,
                         logger: logging.Logger = None,
                         gradient_accumulation=1,
                         tau: float = 0.8,
                         prune=None,
                         prefetch=None,
                         tasks_need_custom_eval=None,
                         cache=False,
                         debug=False,
                         **kwargs) -> DataLoader:
        """Build a ``MultiTaskDataLoader`` holding one child dataloader per task.

        This method is only called during training or evaluation but not prediction.

        Args:
            data: ``'trn'``, ``'dev'`` or ``'tst'`` to load each task's own split. Any other value is passed
                unchanged to every task's ``build_dataloader``.
            batch_size: Not read in this body.
            shuffle: Passed to ``MultiTaskDataLoader`` as ``training``.
            device: Forwarded to each task's ``build_dataloader``.
            logger: Receives progress messages and the samples-distribution table. May be ``None``, in which
                case nothing is logged.
            gradient_accumulation: Forwarded to each task's ``build_dataloader``.
            tau: Forwarded to ``MultiTaskDataLoader``.
            prune: Accepted for interface compatibility; currently unused.
            prefetch: When truthy, the result is wrapped in ``PrefetchDataLoader``. For ``'trn'`` this always
                happens; otherwise only when ``tasks_need_custom_eval`` is falsy.
            tasks_need_custom_eval: See ``prefetch``.
            cache: When truthy and ``data`` is ``'trn'`` or ``'dev'``, each task dataloader is wrapped in
                ``CachedDataLoader``. A string value is used as the directory of the cache file.
            debug: When ``True`` and ``data == 'trn'``, each task's dev set is used instead of its training set.
            **kwargs: Ignored.

        Returns:
            The multi-task dataloader, possibly wrapped in a ``PrefetchDataLoader``.
        """
        dataloader = MultiTaskDataLoader(training=shuffle, tau=tau)
        # A string names a split to load from each task; anything else is raw data.
        named_split = isinstance(data, str)
        for i, (task_name, task) in enumerate(self.tasks.items()):
            encoder_transform, transform = self.build_transform(task)
            training = None
            if data == 'trn':
                _data = task.dev if debug else task.trn
                training = True
            elif data == 'dev':
                _data = task.dev
                training = False
            elif data == 'tst':
                _data = task.tst
                training = False
            else:
                _data = data
            if logger and named_split:
                logger.info(
                    f'[yellow]{i + 1} / {len(self.tasks)}[/yellow] Building [blue]{data}[/blue] dataset for '
                    f'[cyan]{task_name}[/cyan] ...')
            # Adjust Tokenizer according to task config
            config = copy(task.config)
            config.pop('transform', None)
            task_dataloader: DataLoader = task.build_dataloader(
                _data,
                transform,
                training,
                device,
                logger,
                tokenizer=encoder_transform.tokenizer,
                gradient_accumulation=gradient_accumulation,
                cache=named_split,
                **config)
            # NOTE: dataset pruning (the `prune` argument) was disabled and is not applied here.
            if cache and data in ('trn', 'dev'):
                task_dataloader: CachedDataLoader = CachedDataLoader(
                    task_dataloader,
                    f'{cache}/{os.getpid()}-{data}-{task_name.replace("/", "-")}-cache.pt'
                    if isinstance(cache, str) else None)
            dataloader.dataloaders[task_name] = task_dataloader
        # The samples-distribution table exists only for logging, so it is skipped when there is no logger.
        if data == 'trn' and logger:
            sampling_weights, total_size = dataloader.sampling_weights
            headings = [
                'task', '#batches', '%batches', '#scaled', '%scaled', '#epoch'
            ]
            matrix = []
            epochs_per_task = []
            for (task_name, task_dataloader), weight in zip(dataloader.dataloaders.items(), sampling_weights):
                epochs = len(task_dataloader) / weight / total_size
                matrix.append([
                    f'{task_name}',
                    len(task_dataloader), f'{len(task_dataloader) / total_size:.2%}',
                    int(total_size * weight), f'{weight:.2%}', f'{epochs:.2f}'
                ])
                epochs_per_task.append(epochs)
            # Index of the task that needs the most epochs; its '#epoch' cell is highlighted below.
            longest = int(torch.argmax(torch.tensor(epochs_per_task)))
            table = markdown_table(headings, matrix)
            rows = table.splitlines()
            # Assumes markdown_table emits a heading row and a separator row before the data rows,
            # and that each row ends with '|' so cells[-2] is the last ('#epoch') column — TODO confirm.
            cells = rows[longest + 2].split('|')
            cells[-2] = cells[-2].replace(
                f'{epochs_per_task[longest]:.2f}',
                f'[bold][red]{epochs_per_task[longest]:.2f}[/red][/bold]')
            rows[longest + 2] = '|'.join(cells)
            logger.info(
                f'[bold][yellow]{"Samples Distribution": ^{len(rows[0])}}[/yellow][/bold]'
            )
            logger.info('\n'.join(rows))
        if prefetch and (data == 'trn' or not tasks_need_custom_eval):
            dataloader = PrefetchDataLoader(dataloader, prefetch=prefetch)

        return dataloader