예제 #1
0
 def fit_dataloader(self,
                    trn: DataLoader,
                    criterion,
                    optimizer,
                    metric,
                    logger: logging.Logger,
                    history: History = None,
                    gradient_accumulation=1,
                    ratio_percentage=None,
                    **kwargs):
     """Run one training pass over ``trn`` and return the averaged loss.

     Args:
         trn: Loader yielding the training batches.
         criterion: Accepted for interface compatibility; not read here.
         optimizer: A 2-tuple that is unpacked into the optimizer and its
             scheduler, both of which are handed to ``self._step``.
         metric: Accepted for interface compatibility; not read here.
         logger: Logger passed to the progress timer.
         history: Step bookkeeping object. Although it defaults to ``None``
             it is dereferenced unconditionally, so callers must supply it.
         gradient_accumulation: When greater than one, each batch loss is
             divided by this value before back-propagation.
         ratio_percentage: Forwarded verbatim to ``timer.log``.
         **kwargs: Ignored.

     Returns:
         The summed batch losses divided by ``max(timer.total, 1)``.
     """
     optim, sched = optimizer
     self.model.train()
     planned_steps = history.num_training_steps(
         len(trn), gradient_accumulation=gradient_accumulation)
     timer = CountdownTimer(planned_steps)
     running_loss = 0
     for batch in trn:
         outputs = self.feed_batch(batch)
         batch_loss = outputs['loss']
         if gradient_accumulation and gradient_accumulation > 1:
             # Scale in place so the accumulated gradient is an average.
             batch_loss /= gradient_accumulation
         batch_loss.backward()
         running_loss += batch_loss.item()
         if history.step(gradient_accumulation):
             self._step(optim, sched)
             mean_loss = running_loss / (timer.current + 1)
             timer.log(self.report_metrics(mean_loss),
                       ratio_percentage=ratio_percentage,
                       logger=logger)
         # Drop the references so the graph can be freed before the next batch.
         del batch_loss, outputs
     return running_loss / max(timer.total, 1)
예제 #2
0
 def fit_dataloader(self,
                    trn: DataLoader,
                    criterion,
                    optimizer,
                    metric,
                    logger: logging.Logger,
                    linear_scheduler=None,
                    history: History = None,
                    gradient_accumulation=1,
                    **kwargs):
     """Train the model for one pass over ``trn``.

     Args:
         trn: Loader yielding the training batches.
         criterion: Accepted for interface compatibility; not read here.
         optimizer: Optimizer that is stepped whenever ``history.step``
             reports that an update is due.
         metric: Metric object reset at the start and updated on every batch.
         logger: Logger passed to the progress timer.
         linear_scheduler: Optional scheduler stepped right after each
             optimizer step.
         history: Step bookkeeping object. Although it defaults to ``None``
             it is dereferenced unconditionally, so callers must supply it.
         gradient_accumulation: When greater than one, each batch loss is
             divided by this value before back-propagation.
         **kwargs: Ignored.

     Returns:
         The summed batch losses divided by ``max(timer.total, 1)``.
     """
     self.model.train()
     timer = CountdownTimer(history.num_training_steps(len(trn), gradient_accumulation=gradient_accumulation))
     total_loss = 0
     self.reset_metrics(metric)
     for batch in trn:
         output_dict = self.feed_batch(batch)
         self.update_metrics(batch, output_dict, metric)
         loss = output_dict['loss']
         if gradient_accumulation and gradient_accumulation > 1:
             loss /= gradient_accumulation
         loss.backward()
         total_loss += loss.item()
         if history.step(gradient_accumulation):
             if self.config.grad_norm:
                 torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.grad_norm)
             optimizer.step()
             # Gradients are cleared only once they have been applied. Zeroing
             # at the top of every batch would throw away the gradients of the
             # earlier micro-batches whenever gradient_accumulation > 1.
             optimizer.zero_grad()
             if linear_scheduler:
                 linear_scheduler.step()
             timer.log(self.report_metrics(total_loss / (timer.current + 1), metric), ratio_percentage=None,
                       logger=logger)
         del loss
     # Guard against an empty loader, for which the timer has no steps.
     return total_loss / max(timer.total, 1)
예제 #3
0
 def fit_dataloader(self,
                    trn: DataLoader,
                    criterion,
                    optimizer,
                    metric,
                    logger: logging.Logger,
                    history: History,
                    gradient_accumulation=1,
                    grad_norm=None,
                    transformer_grad_norm=None,
                    teacher: Tagger = None,
                    kd_criterion=None,
                    temperature_scheduler=None,
                    ratio_width=None,
                    **kwargs):
     """Train the tagger for one pass over ``trn``, optionally distilling from a teacher.

     Args:
         trn: Loader yielding batches; each batch must provide ``'tag_id'``.
         criterion: Loss function forwarded to ``self.compute_loss``.
         optimizer: A 2-tuple ``(optimizer, scheduler)``. When ``teacher`` is
             given, the second element is itself a 2-tuple
             ``(scheduler, lambda_scheduler)``.
         metric: Metric updated on every batch and rendered into the report.
         logger: Logger passed to the progress timer.
         history: Step bookkeeping object deciding when an update is due.
         gradient_accumulation: When greater than one, the batch loss is
             divided by this value before back-propagation.
         grad_norm: Forwarded to ``self._step``.
         transformer_grad_norm: Forwarded to ``self._step``.
         teacher: Optional teacher whose ``feed_batch`` output is used for the
             distillation loss.
         kd_criterion: Forwarded to ``self.compute_distill_loss``.
         temperature_scheduler: Forwarded to ``self.compute_distill_loss``.
         ratio_width: Forwarded to ``timer.log``.
         **kwargs: Ignored.
     """
     optimizer, scheduler = optimizer
     if teacher:
         scheduler, lambda_scheduler = scheduler
     else:
         lambda_scheduler = None
     self.model.train()
     timer = CountdownTimer(
         history.num_training_steps(
             len(trn), gradient_accumulation=gradient_accumulation))
     total_loss = 0
     for idx, batch in enumerate(trn):
         out, mask = self.feed_batch(batch)
         y = batch['tag_id']
         loss = self.compute_loss(criterion, out, y, mask)
         if teacher:
             # The teacher only provides targets, so no graph is built for it.
             with torch.no_grad():
                 out_T, _ = teacher.feed_batch(batch)
             # noinspection PyNoneFunctionAssignment
             kd_loss = self.compute_distill_loss(kd_criterion, out, out_T,
                                                 mask,
                                                 temperature_scheduler)
             # float(lambda_scheduler) is the weight of the supervised loss.
             _lambda = float(lambda_scheduler)
             loss = _lambda * loss + (1 - _lambda) * kd_loss
         if gradient_accumulation and gradient_accumulation > 1:
             # Scale after mixing so that the distillation term is averaged
             # over the accumulated micro-batches as well; scaling only the
             # supervised term would over-weight the KD gradient.
             loss /= gradient_accumulation
         loss.backward()
         total_loss += loss.item()
         prediction = self.decode_output(out, mask, batch)
         self.update_metrics(metric, out, y, mask, batch, prediction)
         if history.step(gradient_accumulation):
             self._step(optimizer, scheduler, grad_norm,
                        transformer_grad_norm, lambda_scheduler)
             report = f'loss: {total_loss / (idx + 1):.4f} {metric}'
             timer.log(report,
                       logger=logger,
                       ratio_percentage=False,
                       ratio_width=ratio_width)
         # Drop the references so the graph can be freed before the next batch.
         del loss
         del out
         del mask
예제 #4
0
 def execute_training_loop(self,
                           trn: DataLoader,
                           dev: DataLoader,
                           epochs,
                           criterion,
                           optimizer,
                           metric,
                           save_dir,
                           logger: logging.Logger,
                           devices,
                           gradient_accumulation=1,
                           **kwargs):
     """Train for ``epochs`` epochs and save the weights whenever the score improves.

     Args:
         trn: Training loader.
         dev: Development loader; evaluation is skipped when it is falsy.
         epochs: Number of epochs to run.
         criterion: Forwarded to ``fit_dataloader`` and ``evaluate_dataloader``.
         optimizer: A 2-tuple ``(optimizer, scheduler)``.
         metric: Metric object whose ``score`` is read after every epoch.
         save_dir: Location handed to ``self.save_weights``.
         logger: Logger for the per-epoch banner.
         devices: Not read here.
         gradient_accumulation: Forwarded to ``fit_dataloader``.
         **kwargs: Ignored.
     """
     top_score = -1
     optim, sched = optimizer
     history = History()
     timer = CountdownTimer(epochs)
     num_batches = len(trn)
     ratio_width = len(f'{num_batches}/{num_batches}')
     for epoch in range(1, epochs + 1):
         logger.info(f"[yellow]Epoch {epoch} / {epochs}:[/yellow]")
         # The scheduler is handed to fit_dataloader only when a transformer
         # is present; otherwise it is stepped on the epoch score below.
         per_update_scheduler = sched if self._get_transformer() else None
         self.fit_dataloader(trn,
                             criterion,
                             optim,
                             metric,
                             logger,
                             history=history,
                             gradient_accumulation=gradient_accumulation,
                             linear_scheduler=per_update_scheduler)
         if dev:
             self.evaluate_dataloader(dev,
                                      criterion,
                                      metric,
                                      logger,
                                      ratio_width=ratio_width)
         report = f'{timer.elapsed_human}/{timer.total_time_human}'
         epoch_score = metric.score
         if not self._get_transformer():
             sched.step(epoch_score)
         if epoch_score > top_score:
             self.save_weights(save_dir)
             top_score = epoch_score
             report += ' [red]saved[/red]'
         timer.log(report, ratio_percentage=False, newline=True, ratio=False)
예제 #5
0
 def execute_training_loop(self, trn, dev, devices, epochs, logger,
                           patience, save_dir, optimizer,
                           gradient_accumulation, **kwargs):
     """Train with early stopping and keep the weights of the best epoch.

     Args:
         trn: Training loader.
         dev: Development loader evaluated after every epoch.
         devices: Not read here.
         epochs: Maximum number of epochs.
         logger: Logger receiving banners and reports.
         patience: Number of epochs without improvement after which training
             stops; ``None`` disables early stopping.
         save_dir: Location used by ``save_weights`` / ``load_weights``.
         optimizer: A 4-tuple ``(optimizer, scheduler, transformer_optimizer,
             transformer_scheduler)``.
         gradient_accumulation: Forwarded to ``fit_dataloader`` and used to
             size the progress ratio.
         **kwargs: Ignored.
     """
     (optim, sched,
      transformer_optim, transformer_sched) = optimizer
     criterion = self.build_criterion()
     best_e = 0
     best_metric = self.build_metric()
     timer = CountdownTimer(epochs)
     history = History()
     steps_per_epoch = len(trn) // gradient_accumulation
     ratio_width = len(f'{steps_per_epoch}/{steps_per_epoch}')
     for epoch in range(1, epochs + 1):
         # One epoch of parameter updates.
         logger.info(f"[yellow]Epoch {epoch} / {epochs}:[/yellow]")
         self.fit_dataloader(trn, optim, sched, criterion, epoch, logger,
                             history, transformer_optim, transformer_sched,
                             gradient_accumulation=gradient_accumulation)
         loss, dev_metric = self.evaluate_dataloader(dev,
                                                     criterion,
                                                     ratio_width=ratio_width,
                                                     logger=logger)
         timer.update()
         report = f"{timer.elapsed_human} / {timer.total_time_human} ETA: {timer.eta_human}"
         if dev_metric > best_metric:
             # New best: remember it and persist the weights.
             best_e, best_metric = epoch, dev_metric
             self.save_weights(save_dir)
             report += ' ([red]saved[/red])'
         elif patience != epochs:
             report += f' ({epoch - best_e}/{patience})'
         else:
             report += f' ({epoch - best_e})'
         logger.info(report)
         if patience is not None and epoch - best_e >= patience:
             logger.info(
                 f'LAS has stopped improving for {patience} epochs, early stop.'
             )
             break
     timer.stop()
     if not best_e:
         # Nothing was ever saved, so persist the current weights.
         self.save_weights(save_dir)
     elif best_e != epoch:
         # The last epoch was not the best one: restore the saved weights.
         self.load_weights(save_dir)
     logger.info(
         f"Max score of dev is {best_metric.score:.2%} at epoch {best_e}")
     logger.info(
         f"Average time of each epoch is {timer.elapsed_average_human}")
     logger.info(f"{timer.elapsed_human} elapsed")
예제 #6
0
 def execute_training_loop(self,
                           trn: PrefetchDataLoader,
                           dev: PrefetchDataLoader,
                           epochs,
                           criterion,
                           optimizer,
                           metric,
                           save_dir,
                           logger: logging.Logger,
                           devices,
                           ratio_width=None,
                           dev_data=None,
                           gradient_accumulation=1,
                           **kwargs):
     """Train for ``epochs`` epochs, evaluating periodically, then restore the best weights.

     Evaluation runs every ``self.config.eval_every`` epochs and on the last
     epoch. Both loaders are closed when the loop ends, including on error.

     Args:
         trn: Training loader. It is rebound to whatever ``fit_dataloader``
             returns after each epoch.
         dev: Development loader.
         epochs: Number of epochs to run.
         criterion: Forwarded to ``fit_dataloader``.
         optimizer: Forwarded to ``fit_dataloader``.
         metric: Forwarded to ``fit_dataloader``.
         save_dir: Location used by ``save_weights`` / ``load_weights``.
         logger: Logger receiving banners and reports.
         devices: Not read here.
         ratio_width: Forwarded to ``fit_dataloader`` and ``evaluate_dataloader``.
         dev_data: Forwarded to ``evaluate_dataloader``.
         gradient_accumulation: Forwarded to ``fit_dataloader``.
         **kwargs: Ignored.
     """
     best_epoch, best_metric = 0, -1
     timer = CountdownTimer(epochs)
     history = History()
     try:
         for epoch in range(1, epochs + 1):
             logger.info(f"[yellow]Epoch {epoch} / {epochs}:[/yellow]")
             trn = self.fit_dataloader(
                 trn,
                 criterion,
                 optimizer,
                 metric,
                 logger,
                 ratio_width=ratio_width,
                 gradient_accumulation=gradient_accumulation,
                 history=history,
                 save_dir=save_dir)
             report = f'{timer.elapsed_human}/{timer.total_time_human}'
             if epoch % self.config.eval_every == 0 or epoch == epochs:
                 # NOTE(review): ``metric`` is rebound to the evaluation
                 # result here, so later epochs pass that result to
                 # fit_dataloader - confirm this is intended.
                 metric = self.evaluate_dataloader(dev,
                                                   logger,
                                                   dev_data,
                                                   ratio_width=ratio_width,
                                                   save_dir=save_dir,
                                                   use_fast=True)
                 if metric > best_metric:
                     self.save_weights(save_dir)
                     best_metric = metric
                     best_epoch = epoch
                     report += ' [red]saved[/red]'
             timer.log(report,
                       ratio_percentage=False,
                       newline=True,
                       ratio=False)
         if best_epoch and best_epoch != epochs:
             logger.info(
                 f'Restored the best model with {best_metric} saved {epochs - best_epoch} epochs ago'
             )
             self.load_weights(save_dir)
     finally:
         # Close the dev loader even if closing the training loader raises,
         # otherwise the dev loader would be leaked.
         try:
             trn.close()
         finally:
             dev.close()
예제 #7
0
 def fit_dataloader(self,
                    trn: DataLoader,
                    criterion,
                    optimizer,
                    metric: SpanMetric,
                    logger: logging.Logger,
                    history: History,
                    gradient_accumulation=1,
                    grad_norm=None,
                    ratio_width=None,
                    eval_trn=True,
                    **kwargs):
     """Train the model for one pass over ``trn``.

     Args:
         trn: Loader yielding batches; each batch must provide ``'chart_id'``.
         criterion: Accepted for interface compatibility; not read here.
         optimizer: A 2-tuple ``(optimizer, scheduler)`` handed to ``self._step``.
         metric: Metric that is reset first and, when ``eval_trn`` is set,
             updated on every batch and appended to the report.
         logger: Logger passed to the progress timer.
         history: Step bookkeeping object deciding when an update is due.
         gradient_accumulation: When greater than one, each batch loss is
             divided by this value before back-propagation.
         grad_norm: Forwarded to ``self._step``.
         ratio_width: Forwarded to ``timer.log``.
         eval_trn: Whether to decode training batches and track the metric.
         **kwargs: Ignored.
     """
     optim, sched = optimizer
     metric.reset()
     self.model.train()
     planned_steps = history.num_training_steps(
         len(trn), gradient_accumulation=gradient_accumulation)
     timer = CountdownTimer(planned_steps)
     running_loss = 0
     for seen, batch in enumerate(trn, start=1):
         scores, mask = self.feed_batch(batch)
         gold = batch['chart_id']
         batch_loss, span_probs = self.compute_loss(scores, gold, mask)
         if gradient_accumulation and gradient_accumulation > 1:
             batch_loss /= gradient_accumulation
         batch_loss.backward()
         running_loss += batch_loss.item()
         if eval_trn:
             decoded = self.decode_output(scores, mask, batch, span_probs)
             self.update_metrics(metric, batch, decoded)
         if history.step(gradient_accumulation):
             self._step(optim, sched, grad_norm)
             report = f'loss: {running_loss / seen:.4f}'
             if eval_trn:
                 report += f' {metric}'
             timer.log(report,
                       logger=logger,
                       ratio_percentage=False,
                       ratio_width=ratio_width)
         # Drop the references so the graph can be freed before the next batch.
         del batch_loss, scores, mask
예제 #8
0
 def execute_training_loop(self,
                           trn: DataLoader,
                           dev: DataLoader,
                           epochs,
                           criterion,
                           optimizer,
                           metric,
                           save_dir,
                           logger: logging.Logger,
                           devices,
                           ratio_width=None,
                           patience=0.5,
                           eval_trn=True,
                           **kwargs):
     """Train with early stopping and keep the weights of the best epoch.

     Args:
         trn: Training loader.
         dev: Development loader evaluated after every epoch.
         epochs: Maximum number of epochs.
         criterion: Forwarded to ``fit_dataloader`` and ``evaluate_dataloader``.
         optimizer: Forwarded to ``fit_dataloader``.
         metric: Forwarded to ``fit_dataloader``.
         save_dir: Location used by ``save_weights`` / ``load_weights``.
         logger: Logger receiving banners and reports.
         devices: Not read here.
         ratio_width: Forwarded to ``fit_dataloader`` and ``evaluate_dataloader``.
         patience: Epochs without improvement tolerated before stopping. A
             float is interpreted as a fraction of ``epochs``.
         eval_trn: Forwarded to ``fit_dataloader``.
         **kwargs: Ignored.
     """
     if isinstance(patience, float):
         # A fractional patience is relative to the total epoch budget.
         patience = int(patience * epochs)
     best_epoch = 0
     best_metric = -1
     timer = CountdownTimer(epochs)
     history = History()
     for epoch in range(1, epochs + 1):
         logger.info(f"[yellow]Epoch {epoch} / {epochs}:[/yellow]")
         self.fit_dataloader(trn, criterion, optimizer, metric, logger,
                             history=history,
                             ratio_width=ratio_width,
                             eval_trn=eval_trn,
                             **self.config)
         loss, dev_metric = self.evaluate_dataloader(dev,
                                                     criterion,
                                                     logger=logger,
                                                     ratio_width=ratio_width)
         timer.update()
         report = f"{timer.elapsed_human} / {timer.total_time_human} ETA: {timer.eta_human}"
         if dev_metric > best_metric:
             best_epoch, best_metric = epoch, dev_metric
             self.save_weights(save_dir)
             report += ' [red](saved)[/red]'
         else:
             stale_epochs = epoch - best_epoch
             report += f' ({stale_epochs})'
             if stale_epochs >= patience:
                 report += ' early stop'
         logger.info(report)
         if epoch - best_epoch >= patience:
             break
     if not best_epoch:
         # Nothing was ever saved, so persist the current weights.
         self.save_weights(save_dir)
     elif best_epoch != epoch:
         # The last epoch was not the best one: restore the saved weights.
         self.load_weights(save_dir)
     logger.info(f"Max score of dev is {best_metric} at epoch {best_epoch}")
     logger.info(
         f"Average time of each epoch is {timer.elapsed_average_human}")
     logger.info(f"{timer.elapsed_human} elapsed")
예제 #9
0
    def fit_dataloader(self,
                       trn,
                       optimizer,
                       scheduler,
                       criterion,
                       epoch,
                       logger,
                       history: History,
                       transformer_optimizer=None,
                       transformer_scheduler=None,
                       gradient_accumulation=1,
                       eval_trn=False,
                       **kwargs):
        """Train the parser for one pass over ``trn``.

        Args:
            trn: Loader yielding batches; each batch must provide ``'arc'``,
                ``'sib_id'`` and ``'rel_id'``.
            optimizer: Main optimizer, handed to ``self._step``.
            scheduler: Main scheduler; its last learning rate is reported.
            criterion: Accepted for interface compatibility; not read here.
            epoch: Accepted for interface compatibility; not read here.
            logger: Logger passed to the progress timer.
            history: Step bookkeeping object deciding when an update is due.
            transformer_optimizer: Forwarded to ``self._step``.
            transformer_scheduler: Forwarded to ``self._step``.
            gradient_accumulation: When greater than one, each batch loss is
                divided by this value before back-propagation.
            eval_trn: Whether to decode training batches and track the metric.
            **kwargs: Ignored.
        """
        self.model.train()

        timer = CountdownTimer(history.num_training_steps(len(trn), gradient_accumulation))
        metric = self.build_metric(training=True)
        total_loss = 0
        for idx, batch in enumerate(trn):
            (s_arc, s_sib, s_rel), mask, puncts = self.feed_batch(batch)
            arcs, sibs, rels = batch['arc'], batch['sib_id'], batch['rel_id']

            loss, s_arc = self.compute_loss(s_arc, s_sib, s_rel, arcs, sibs, rels, mask)
            if gradient_accumulation > 1:
                loss /= gradient_accumulation
            loss.backward()
            total_loss += loss.item()
            if eval_trn:
                arc_preds, rel_preds = self.decode(s_arc, s_sib, s_rel, mask)
                self.update_metric(arc_preds, rel_preds, arcs, rels, mask, puncts, metric)
            if history.step(gradient_accumulation):
                self._step(optimizer, scheduler, transformer_optimizer, transformer_scheduler)
                # Gradients are cleared only once they have been applied.
                # Zeroing at the top of every batch would throw away the
                # gradients of the earlier micro-batches whenever
                # gradient_accumulation > 1.
                optimizer.zero_grad()
                report = self._report(total_loss / (timer.current + 1), metric if eval_trn else None)
                lr = scheduler.get_last_lr()[0]
                report += f' lr: {lr:.4e}'
                timer.log(report, ratio_percentage=False, logger=logger)
            del loss
예제 #10
0
    def fit_dataloader(self,
                       trn,
                       optimizer,
                       scheduler,
                       criterion,
                       epoch,
                       logger,
                       history: History,
                       transformer_optimizer=None,
                       transformer_scheduler=None,
                       gradient_accumulation=1,
                       **kwargs):
        """Train the parser for one pass over ``trn``.

        Args:
            trn: Loader yielding batches; each batch must provide ``'arc'``
                and ``'rel_id'``.
            optimizer: Main optimizer, handed to ``self._step``.
            scheduler: Main scheduler, handed to ``self._step``.
            criterion: Forwarded to ``self.compute_loss``.
            epoch: Accepted for interface compatibility; not read here.
            logger: Logger passed to the progress timer.
            history: Step bookkeeping object deciding when an update is due.
            transformer_optimizer: Forwarded to ``self._step``.
            transformer_scheduler: Forwarded to ``self._step``.
            gradient_accumulation: When greater than one, each batch loss is
                divided by this value before back-propagation.
            **kwargs: Ignored.
        """
        self.model.train()

        planned_steps = history.num_training_steps(len(trn),
                                                   gradient_accumulation)
        timer = CountdownTimer(planned_steps)
        metric = self.build_metric(training=True)
        running_loss = 0
        for batch in trn:
            arc_scores, rel_scores, mask, puncts = self.feed_batch(batch)
            gold_arcs = batch['arc']
            gold_rels = batch['rel_id']
            batch_loss = self.compute_loss(arc_scores, rel_scores, gold_arcs,
                                           gold_rels, mask, criterion, batch)
            if gradient_accumulation > 1:
                batch_loss /= gradient_accumulation
            batch_loss.backward()
            running_loss += batch_loss.item()
            pred_arcs, pred_rels = self.decode(arc_scores, rel_scores, mask,
                                               batch)
            self.update_metric(pred_arcs, pred_rels, gold_arcs, gold_rels,
                               mask, puncts, metric, batch)
            if history.step(gradient_accumulation):
                self._step(optimizer, scheduler, transformer_optimizer,
                           transformer_scheduler)
                mean_loss = running_loss / (timer.current + 1)
                timer.log(self._report(mean_loss, metric),
                          ratio_percentage=False,
                          logger=logger)
            # Drop the reference so the graph can be freed before the next batch.
            del batch_loss
예제 #11
0
 def execute_training_loop(self,
                           trn: DataLoader,
                           dev: DataLoader,
                           epochs,
                           criterion,
                           optimizer,
                           metric,
                           save_dir,
                           logger: logging.Logger,
                           devices,
                           ratio_width=None,
                           dev_data=None,
                           eval_after=None,
                           **kwargs):
     """Train for ``epochs`` epochs, evaluating only after a warm-up, and keep the best weights.

     Args:
         trn: Training loader.
         dev: Development loader.
         epochs: Number of epochs to run.
         criterion: Forwarded to ``fit_dataloader`` and ``evaluate_dataloader``.
         optimizer: Forwarded to ``fit_dataloader``.
         metric: Forwarded to ``fit_dataloader``.
         save_dir: Location used by ``save_weights`` / ``load_weights``; the
             dev predictions are written to ``dev.pred.txt`` inside it.
         logger: Logger receiving banners and reports.
         devices: Not read here.
         ratio_width: Forwarded to ``fit_dataloader`` and ``evaluate_dataloader``.
         dev_data: Forwarded to ``evaluate_dataloader`` as ``input``.
         eval_after: Evaluation only happens on epochs strictly greater than
             this value. A float is interpreted as a fraction of ``epochs``;
             ``None`` means evaluate after every epoch.
         **kwargs: Ignored.

     Returns:
         The best development metric seen, or ``-1`` if no evaluation ran.
     """
     best_epoch, best_metric = 0, -1
     if eval_after is None:
         # The default must not reach ``epoch > eval_after``: comparing an
         # int with None raises TypeError.
         eval_after = 0
     elif isinstance(eval_after, float):
         eval_after = int(epochs * eval_after)
     timer = CountdownTimer(epochs)
     history = History()
     for epoch in range(1, epochs + 1):
         logger.info(f"[yellow]Epoch {epoch} / {epochs}:[/yellow]")
         self.fit_dataloader(trn,
                             criterion,
                             optimizer,
                             metric,
                             logger,
                             history=history,
                             ratio_width=ratio_width,
                             **self.config)
         if epoch > eval_after:
             dev_metric = self.evaluate_dataloader(dev,
                                                   criterion,
                                                   logger=logger,
                                                   ratio_width=ratio_width,
                                                   output=os.path.join(
                                                       save_dir,
                                                       'dev.pred.txt'),
                                                   input=dev_data,
                                                   use_fast=True)
         timer.update()
         report = f"{timer.elapsed_human} / {timer.total_time_human} ETA: {timer.eta_human}"
         if epoch > eval_after:
             if dev_metric > best_metric:
                 best_epoch, best_metric = epoch, dev_metric
                 self.save_weights(save_dir)
                 report += ' [red](saved)[/red]'
             else:
                 report += f' ({epoch - best_epoch})'
         logger.info(report)
     if not best_epoch:
         # Nothing was ever saved, so persist the current weights.
         self.save_weights(save_dir)
     elif best_epoch != epoch:
         # The last epoch was not the best one: restore the saved weights.
         self.load_weights(save_dir)
     logger.info(f"Max score of dev is {best_metric} at epoch {best_epoch}")
     logger.info(
         f"Average time of each epoch is {timer.elapsed_average_human}")
     logger.info(f"{timer.elapsed_human} elapsed")
     return best_metric
예제 #12
0
 def fit_dataloader(self,
                    trn: DataLoader,
                    criterion,
                    optimizer,
                    metric,
                    logger: logging.Logger,
                    history: History,
                    ratio_width=None,
                    gradient_accumulation=1,
                    encoder_grad_norm=None,
                    decoder_grad_norm=None,
                    patience=0.5,
                    eval_trn=False,
                    **kwargs):
     """Train the multi-task model for one pass over ``trn``.

     Args:
         trn: Loader yielding ``(task_name, batch)`` pairs.
         criterion: Mapping indexed by task name; the entry is forwarded to
             ``self.compute_loss``.
         optimizer: A 3-tuple ``(encoder_optimizer, encoder_scheduler,
             decoder_optimizers)``. ``decoder_optimizers`` is looked up by
             task name; each entry is either an optimizer or an
             ``(optimizer, scheduler)`` tuple.
         metric: Metric object reset first and updated when ``eval_trn`` is set.
         logger: Logger passed to the progress timer.
         history: Step bookkeeping object deciding when an update is due.
         ratio_width: Forwarded to ``timer.log``.
         gradient_accumulation: When greater than one, each batch loss is
             divided by this value before back-propagation.
         encoder_grad_norm: If set, max norm used to clip encoder gradients.
         decoder_grad_norm: If set, max norm used to clip decoder gradients.
         patience: Not read here.
         eval_trn: Whether to decode training batches and track the metric.
         **kwargs: Ignored.

     Returns:
         The summed batch losses divided by ``max(timer.total, 1)``.
     """
     self.model.train()
     encoder_optimizer, encoder_scheduler, decoder_optimizers = optimizer
     timer = CountdownTimer(len(trn))
     total_loss = 0
     self.reset_metrics(metric)
     model = self.model_
     # ``parameters()`` returns a one-shot iterator. It must be materialised,
     # otherwise the first clipping call exhausts it and every later call
     # clips nothing.
     encoder_parameters = list(model.encoder.parameters())
     decoder_parameters = list(model.decoders.parameters())
     for idx, (task_name, batch) in enumerate(trn):
         decoder_optimizer = decoder_optimizers.get(task_name, None)
         output_dict, _ = self.feed_batch(batch, task_name)
         loss = self.compute_loss(batch, output_dict[task_name]['output'],
                                  criterion[task_name],
                                  self.tasks[task_name])
         if gradient_accumulation and gradient_accumulation > 1:
             loss /= gradient_accumulation
         loss.backward()
         total_loss += float(loss.item())
         if history.step(gradient_accumulation):
             if self.config.get('grad_norm', None):
                 clip_grad_norm(model, self.config.grad_norm)
             if encoder_grad_norm:
                 torch.nn.utils.clip_grad_norm_(encoder_parameters,
                                                encoder_grad_norm)
             if decoder_grad_norm:
                 torch.nn.utils.clip_grad_norm_(decoder_parameters,
                                                decoder_grad_norm)
             encoder_optimizer.step()
             encoder_optimizer.zero_grad()
             encoder_scheduler.step()
             if decoder_optimizer:
                 # A decoder entry may bundle its own scheduler.
                 if isinstance(decoder_optimizer, tuple):
                     decoder_optimizer, decoder_scheduler = decoder_optimizer
                 else:
                     decoder_scheduler = None
                 decoder_optimizer.step()
                 decoder_optimizer.zero_grad()
                 if decoder_scheduler:
                     decoder_scheduler.step()
         if eval_trn:
             self.decode_output(output_dict, batch, task_name)
             self.update_metrics(batch, output_dict, metric, task_name)
         timer.log(self.report_metrics(total_loss / (timer.current + 1),
                                       metric if eval_trn else None),
                   ratio_percentage=None,
                   ratio_width=ratio_width,
                   logger=logger)
         # Drop the references so the graph can be freed before the next batch.
         del loss
         del output_dict
     # Guard against an empty loader, for which the timer has no steps.
     return total_loss / max(timer.total, 1)
예제 #13
0
 def execute_training_loop(self,
                           trn: DataLoader,
                           dev: DataLoader,
                           epochs,
                           criterion,
                           optimizer,
                           metric,
                           save_dir,
                           logger: logging.Logger,
                           devices,
                           patience=0.5,
                           **kwargs):
     """Train with early stopping and restore the weights of the best epoch.

     Args:
         trn: Training loader.
         dev: Development loader; evaluation is skipped when it is falsy.
         epochs: Maximum number of epochs.
         criterion: Forwarded to ``fit_dataloader`` and ``evaluate_dataloader``.
         optimizer: Forwarded to ``fit_dataloader``.
         metric: Metric object whose ``score`` is read after every epoch.
         save_dir: Location used by ``save_weights`` / ``load_weights``.
         logger: Logger receiving banners and the restore notice.
         devices: Not read here.
         patience: Epochs without improvement tolerated before stopping. A
             float is interpreted as a fraction of ``epochs``.
         **kwargs: Ignored.

     Returns:
         The best score seen, or ``-1`` if no epoch improved on it.
     """
     if isinstance(patience, float):
         # A fractional patience is relative to the total epoch budget.
         patience = int(patience * epochs)
     best_epoch, best_metric = 0, -1
     timer = CountdownTimer(epochs)
     ratio_width = len(f'{len(trn)}/{len(trn)}')
     # Pre-bind ``epoch`` so the checks after the loop work when epochs == 0.
     epoch = 0
     history = History()
     for epoch in range(1, epochs + 1):
         logger.info(f"[yellow]Epoch {epoch} / {epochs}:[/yellow]")
         self.fit_dataloader(trn,
                             criterion,
                             optimizer,
                             metric,
                             logger,
                             history,
                             ratio_width=ratio_width,
                             **self.config)
         if dev:
             self.evaluate_dataloader(dev,
                                      criterion,
                                      metric,
                                      logger,
                                      ratio_width=ratio_width,
                                      input='dev')
         report = f'{timer.elapsed_human}/{timer.total_time_human}'
         dev_score = metric.score
         early_stop = False
         if dev_score > best_metric:
             self.save_weights(save_dir)
             best_metric = dev_score
             best_epoch = epoch
             report += ' [red]saved[/red]'
         else:
             report += f' ({epoch - best_epoch})'
             if epoch - best_epoch >= patience:
                 report += ' early stop'
                 early_stop = True
         # Log before leaving the loop; breaking first would silently drop
         # the final report carrying the " early stop" marker.
         timer.log(report,
                   ratio_percentage=False,
                   newline=True,
                   ratio=False)
         if early_stop:
             break
     for d in [trn, dev]:
         if isinstance(d, PrefetchDataLoader):
             d.close()
     if best_epoch != epoch:
         logger.info(
             f'Restoring best model saved [red]{epoch - best_epoch}[/red] epochs ago'
         )
         self.load_weights(save_dir)
     return best_metric