Example #1
 def execute_training_loop(self,
                           trn: PrefetchDataLoader,
                           dev: PrefetchDataLoader,
                           epochs,
                           criterion,
                           optimizer,
                           metric,
                           save_dir,
                           logger: logging.Logger,
                           devices,
                           ratio_width=None,
                           dev_data=None,
                           gradient_accumulation=1,
                           **kwargs):
     best_epoch, best_metric = 0, -1
     timer = CountdownTimer(epochs)
     history = History()
     try:
         for epoch in range(1, epochs + 1):
             logger.info(f"[yellow]Epoch {epoch} / {epochs}:[/yellow]")
             trn = self.fit_dataloader(
                 trn,
                 criterion,
                 optimizer,
                 metric,
                 logger,
                 ratio_width=ratio_width,
                 gradient_accumulation=gradient_accumulation,
                 history=history,
                 save_dir=save_dir)
             report = f'{timer.elapsed_human}/{timer.total_time_human}'
             if epoch % self.config.eval_every == 0 or epoch == epochs:
                 metric = self.evaluate_dataloader(dev,
                                                   logger,
                                                   dev_data,
                                                   ratio_width=ratio_width,
                                                   save_dir=save_dir,
                                                   use_fast=True)
                 if metric > best_metric:
                     self.save_weights(save_dir)
                     best_metric = metric
                     best_epoch = epoch
                     report += ' [red]saved[/red]'
             timer.log(report,
                       ratio_percentage=False,
                       newline=True,
                       ratio=False)
         if best_epoch and best_epoch != epochs:
             logger.info(
                 f'Restored the best model with {best_metric} saved {epochs - best_epoch} epochs ago'
             )
             self.load_weights(save_dir)
     finally:
         trn.close()
         dev.close()
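
This loop evaluates on the dev set every `eval_every` epochs, saves the weights whenever the dev metric improves, and reloads the best checkpoint once training finishes. Note that the `optimizer` argument is actually an `(optimizer, scheduler)` pair, as `fit_dataloader` in Example #5 unpacks it. Below is a minimal, self-contained sketch of the same evaluate/save-best/restore pattern in plain PyTorch; `model`, `evaluate`, and the `best.pt` file name are placeholders, not part of the original code.

import torch

def train_with_best_checkpoint(model, trn_loader, dev_loader, epochs, evaluate, save_dir='.'):
    # (optimizer, scheduler) pair, mirroring what fit_dataloader unpacks
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda step: 1.0)
    best_metric, best_epoch = -1, 0
    for epoch in range(1, epochs + 1):
        model.train()
        for batch in trn_loader:
            optimizer.zero_grad()
            loss = model(**batch)  # assumed to return a scalar loss
            loss.backward()
            optimizer.step()
            scheduler.step()
        metric = evaluate(model, dev_loader)  # higher is better
        if metric > best_metric:  # keep only the best checkpoint
            best_metric, best_epoch = metric, epoch
            torch.save(model.state_dict(), f'{save_dir}/best.pt')
    if best_epoch and best_epoch != epochs:  # restore the best weights, as in the loop above
        model.load_state_dict(torch.load(f'{save_dir}/best.pt'))
    return best_metric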
Example #2
 def build_dataloader(self,
                      data,
                      batch_size,
                      shuffle=False,
                      device=None,
                      logger: logging.Logger = None,
                      gradient_accumulation=1,
                      batch_max_tokens=None,
                      **kwargs) -> DataLoader:
     dataset, lens = self.build_dataset(data, logger, training=shuffle)
     if batch_max_tokens:
         batch_max_tokens //= gradient_accumulation
         # Only scale the budget when it is set; halving None would raise a TypeError
         if not shuffle:
             batch_max_tokens //= 2
     sampler = SortingSampler(lens,
                              batch_size=None,
                              batch_max_tokens=batch_max_tokens,
                              shuffle=shuffle)
     dataloader = PrefetchDataLoader(
         DataLoader(batch_sampler=sampler,
                    dataset=dataset,
                    collate_fn=merge_list_of_dict,
                    num_workers=0),
         batchify=self.build_batchify(device, shuffle))
     return dataloader
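
Here `SortingSampler` receives the per-sample lengths and a `batch_max_tokens` budget instead of a fixed `batch_size`, so the number of samples per batch adapts to sequence length. Below is a minimal sketch of what such a token-budget sampler plausibly does, assuming batches are packed greedily over length-sorted indices; the real `SortingSampler` may differ in details such as padding-aware accounting.

import random

def token_budget_batches(lens, batch_max_tokens, shuffle=False):
    # Sort indices by length so each batch holds similarly sized samples
    order = sorted(range(len(lens)), key=lambda i: lens[i])
    batches, batch, tokens = [], [], 0
    for i in order:
        if batch and tokens + lens[i] > batch_max_tokens:
            batches.append(batch)
            batch, tokens = [], 0
        batch.append(i)
        tokens += lens[i]
    if batch:
        batches.append(batch)
    if shuffle:
        random.shuffle(batches)  # shuffle whole batches to keep the packing tight
    return batches

# token_budget_batches([5, 12, 7, 30, 9], batch_max_tokens=20) -> [[0, 2], [4], [1], [3]]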
Example #3
 def build_dataloader(self,
                      data,
                      transform: Callable = None,
                      training=False,
                      device=None,
                      logger: logging.Logger = None,
                      cache=False,
                      gradient_accumulation=1,
                      **kwargs) -> DataLoader:
     if isinstance(data, list):
         data = GraphAbstractMeaningRepresentationParser.build_samples(
             self, data)
     dataset, lens = GraphAbstractMeaningRepresentationParser.build_dataset(
         self, data, logger=logger, transform=transform, training=training)
     if self.vocabs.mutable:
         GraphAbstractMeaningRepresentationParser.build_vocabs(
             self, dataset, logger)
     dataloader = PrefetchDataLoader(
         DataLoader(batch_sampler=self.sampler_builder.build(
             lens,
             shuffle=training,
             gradient_accumulation=gradient_accumulation),
                    dataset=dataset,
                    collate_fn=merge_list_of_dict,
                    num_workers=0),
         batchify=self.build_batchify(device, training),
         prefetch=None)
     return dataloader
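
Note that this override calls `GraphAbstractMeaningRepresentationParser.build_samples(self, ...)` and related methods by naming the class explicitly rather than going through `self`, presumably to pin the AMR-specific implementations even if the current class overrides those methods. A minimal sketch of the difference, with made-up class names:

class Parser:
    def build_dataset(self, data):
        return f'base dataset from {data}'

class JointParser(Parser):
    def build_dataset(self, data):
        # self.build_dataset(data) would dispatch back to this override;
        # naming the parent class pins the call to the base implementation.
        base = Parser.build_dataset(self, data)  # same effect as super().build_dataset(data)
        return base + ' + joint extras'

print(JointParser().build_dataset('amr'))  # base dataset from amr + joint extras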
Example #4
 def evaluate_dataloader(self,
                         data: PrefetchDataLoader,
                         logger,
                         input,
                         output=False,
                         ratio_width=None,
                         save_dir=None,
                         use_fast=False,
                         test=False,
                         **kwargs):
     self.model.eval()
     pp = PostProcessor(self.vocabs['rel'])
     if not output:
         output = os.path.join(save_dir, os.path.basename(input) + '.pred')
     # Squeezing tokens and concepts into one transformer effectively reduces the maximum number of inputs it can handle
     parse_data(self.model,
                pp,
                data,
                input,
                output,
                max_time_step=80 if self.model.squeeze else 100)
     # noinspection PyBroadException
     try:
         output = post_process(output,
                               amr_version=self.config.get(
                                   'amr_version', '2.0'))
         scores = smatch_eval(output,
                              input.replace('.features.preproc', ''),
                              use_fast=use_fast)
     except Exception:
         eprint('Evaluation failed due to the following error:')
         traceback.print_exc()
         eprint(
             'As smatch usually fails on erroneous outputs produced at early epochs, '
             'it might be OK to ignore it. Now `nan` will be returned as the score.'
         )
         scores = F1_(float("nan"), float("nan"), float("nan"))
     if logger:
         header = f'{len(data)}/{len(data)}'
         if not ratio_width:
             ratio_width = len(header)
         logger.info(header.rjust(ratio_width) + f' {scores}')
     if test:
         data.close()
     return scores
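
`smatch_eval` relies on external Smatch scoring, which can crash on the malformed graphs a model produces in early epochs, so the `except` block substitutes NaN scores instead of aborting training. A minimal sketch of the same fallback, using a namedtuple as a stand-in for `F1_` (the real class may carry extra formatting logic):

import math
import traceback
from collections import namedtuple

F1 = namedtuple('F1', 'precision recall score')

def safe_eval(scorer, pred_file, gold_file):
    try:
        return scorer(pred_file, gold_file)
    except Exception:
        traceback.print_exc()
        # Early-epoch outputs often break external scorers; return NaN so training continues.
        return F1(math.nan, math.nan, math.nan)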
Example #5
    def fit_dataloader(self,
                       trn: PrefetchDataLoader,
                       criterion,
                       optimizer,
                       metric,
                       logger: logging.Logger,
                       gradient_accumulation=1,
                       ratio_width=None,
                       history=None,
                       save_dir=None,
                       **kwargs):
        self.model.train()
        num_training_steps = len(
            trn) * self.config.epochs // gradient_accumulation
        shuffle_sibling_steps = self.config.shuffle_sibling_steps
        if isinstance(shuffle_sibling_steps, float):
            shuffle_sibling_steps = int(shuffle_sibling_steps *
                                        num_training_steps)
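        # Count only the mini-batches in this epoch that trigger an optimizer step (every gradient_accumulation-th one)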
        timer = CountdownTimer(
            len([
                i for i in range(history.num_mini_batches +
                                 1, history.num_mini_batches + len(trn) + 1)
                if i % gradient_accumulation == 0
            ]))
        total_loss = 0
        optimizer, scheduler = optimizer
        correct_conc, total_conc, correct_rel, total_rel = 0, 0, 0, 0
        for idx, batch in enumerate(trn):
            loss = self.compute_loss(batch)
            if self.config.joint_arc_concept or self.model.squeeze or self.config.bart:
                loss, (concept_correct, concept_total), rel_out = loss
                correct_conc += concept_correct
                total_conc += concept_total
                if rel_out is not None:
                    rel_correct, rel_total = rel_out
                    correct_rel += rel_correct
                    total_rel += rel_total
            loss /= gradient_accumulation
            # loss = loss.sum()  # For data parallel
            loss.backward()
            total_loss += loss.item()
            history.num_mini_batches += 1
            if history.num_mini_batches % gradient_accumulation == 0:
                self._step(optimizer, scheduler)
                metric = ''
                if self.config.joint_arc_concept or self.model.squeeze or self.model.bart:
                    metric = f' Concept acc: {correct_conc / total_conc:.2%}'
                    if not self.config.levi_graph:
                        metric += f' Relation acc: {correct_rel / total_rel:.2%}'
                timer.log(
                    f'loss: {total_loss / (timer.current + 1):.4f} lr: {optimizer.param_groups[0]["lr"]:.2e}'
                    + metric,
                    ratio_percentage=None,
                    ratio_width=ratio_width,
                    logger=logger)

                if history.num_mini_batches // gradient_accumulation == shuffle_sibling_steps:
                    trn.batchify = self.build_batchify(self.device,
                                                       shuffle=True,
                                                       shuffle_sibling=False)
                    timer.print(
                        f'Switched to [bold]deterministic order[/bold] after {shuffle_sibling_steps} steps',
                        newline=True)
            del loss
        return trn
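
The core pattern above is gradient accumulation: each mini-batch loss is divided by `gradient_accumulation`, gradients accumulate across several backward passes, and the optimizer and scheduler only step once every `gradient_accumulation` mini-batches (via `self._step`). A minimal, self-contained PyTorch sketch of that schedule; the model, loader, and gradient-clipping threshold are placeholders, and the real `_step` may differ.

import torch

def fit_one_epoch(model, loader, optimizer, scheduler, gradient_accumulation=1, max_grad_norm=1.0):
    model.train()
    optimizer.zero_grad()
    for step, batch in enumerate(loader, start=1):
        loss = model(**batch)                      # assumed to return a scalar loss
        (loss / gradient_accumulation).backward()  # scale so accumulated gradients average out
        if step % gradient_accumulation == 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            optimizer.step()                       # one update per accumulated "large batch"
            scheduler.step()
            optimizer.zero_grad()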
Example #6
    def build_dataloader(self,
                         data,
                         batch_size,
                         shuffle=False,
                         device=None,
                         logger: logging.Logger = None,
                         gradient_accumulation=1,
                         tau: float = 0.8,
                         prune=None,
                         prefetch=None,
                         tasks_need_custom_eval=None,
                         **kwargs) -> DataLoader:
        dataloader = MultiTaskDataLoader(training=shuffle, tau=tau)
        for i, (task_name, task) in enumerate(self.tasks.items()):
            encoder_transform, transform = self.build_transform(task)
            training = None
            if data == 'trn':
                _data = task.trn
                training = True
            elif data == 'dev':
                _data = task.dev
                training = False
            elif data == 'tst':
                _data = task.tst
                training = False
            else:
                _data = data
            if isinstance(data, str):
                logger.info(
                    f'[yellow]{i + 1} / {len(self.tasks)}[/yellow] Building [blue]{data}[/blue] dataset for '
                    f'[cyan]{task_name}[/cyan] ...')
            # Adjust Tokenizer according to task config
            config = copy(task.config)
            config.pop('transform', None)
            task_dataloader: DataLoader = task.build_dataloader(
                _data,
                transform,
                training,
                device,
                logger,
                tokenizer=encoder_transform.tokenizer,
                gradient_accumulation=gradient_accumulation,
                cache=isinstance(data, str),
                **config)
            # if prune:
            #     # noinspection PyTypeChecker
            #     task_dataset: TransformDataset = task_dataloader.dataset
            #     size_before = len(task_dataset)
            #     task_dataset.prune(prune)
            #     size_after = len(task_dataset)
            #     num_pruned = size_before - size_after
            #     logger.info(f'Pruned [yellow]{num_pruned} ({num_pruned / size_before:.1%})[/yellow] '
            #                 f'samples out of {size_before}.')
            dataloader.dataloaders[task_name] = task_dataloader
        if data == 'trn':
            sampling_weights, total_size = dataloader.sampling_weights
            headings = [
                'task', '#batches', '%batches', '#scaled', '%scaled', '#epoch'
            ]
            matrix = []
            min_epochs = []
            for (task_name,
                 dataset), weight in zip(dataloader.dataloaders.items(),
                                         sampling_weights):
                epochs = len(dataset) / weight / total_size
                matrix.append([
                    f'{task_name}',
                    len(dataset), f'{len(dataset) / total_size:.2%}',
                    int(total_size * weight), f'{weight:.2%}', f'{epochs:.2f}'
                ])
                min_epochs.append(epochs)
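            # Highlight (in red, below) the task that needs the most multi-task epochs to see all of its data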
            longest = int(torch.argmax(torch.tensor(min_epochs)))
            table = markdown_table(headings, matrix)
            rows = table.splitlines()
            cells = rows[longest + 2].split('|')
            cells[-2] = cells[-2].replace(
                f'{min_epochs[longest]:.2f}',
                f'[bold][red]{min_epochs[longest]:.2f}[/red][/bold]')
            rows[longest + 2] = '|'.join(cells)
            logger.info(
                f'[bold][yellow]{"Samples Distribution": ^{len(rows[0])}}[/yellow][/bold]'
            )
            logger.info('\n'.join(rows))
        if prefetch and isinstance(
                data, str) and (data == 'trn' or not tasks_need_custom_eval):
            dataloader = PrefetchDataLoader(dataloader, prefetch=prefetch)

        return dataloader
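
`tau` is a temperature that controls how much task sampling probabilities are flattened relative to task sizes, and the "Samples Distribution" table reports each task's share of batches, its sampled share, and how many multi-task epochs it takes to cover the task once (`len(dataset) / (weight * total_size)`). A minimal sketch of the size**tau weighting such a loader plausibly uses (the library's exact formula may differ):

def sampling_weights(task_sizes, tau=0.8):
    # tau=1 keeps proportional sampling, tau=0 samples every task uniformly
    scaled = {name: size ** tau for name, size in task_sizes.items()}
    z = sum(scaled.values())
    return {name: s / z for name, s in scaled.items()}

sizes = {'tok': 50_000, 'pos': 20_000, 'amr': 5_000}
weights = sampling_weights(sizes, tau=0.8)
total = sum(sizes.values())
for name, w in weights.items():
    epochs = sizes[name] / (w * total)  # mixture epochs needed to see this task's data once
    print(f'{name}: weight={w:.2%} epochs={epochs:.2f}')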