Example #1
 def fit_dataloader(self,
                    trn: DataLoader,
                    criterion,
                    optimizer,
                    metric,
                    logger: logging.Logger,
                    history: History = None,
                    gradient_accumulation=1,
                    ratio_percentage=None,
                    **kwargs):
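     # The optimizer argument is expected to be an (optimizer, scheduler) pair.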
     optimizer, scheduler = optimizer
     self.model.train()
     timer = CountdownTimer(
         history.num_training_steps(
             len(trn), gradient_accumulation=gradient_accumulation))
     total_loss = 0
     for batch in trn:
         output_dict = self.feed_batch(batch)
         loss = output_dict['loss']
         if gradient_accumulation and gradient_accumulation > 1:
             loss /= gradient_accumulation
         loss.backward()
         total_loss += loss.item()
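         # Parameters are updated only once every `gradient_accumulation` batches;
         # self._step presumably applies the optimizer and scheduler and zeroes gradients.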
         if history.step(gradient_accumulation):
             self._step(optimizer, scheduler)
             timer.log(self.report_metrics(total_loss /
                                           (timer.current + 1)),
                       ratio_percentage=ratio_percentage,
                       logger=logger)
         del loss
         del output_dict
     return total_loss / max(timer.total, 1)
Example #2
 def fit_dataloader(self,
                    trn: DataLoader,
                    criterion,
                    optimizer,
                    metric,
                    logger: logging.Logger,
                    linear_scheduler=None,
                    history: History = None,
                    gradient_accumulation=1,
                    **kwargs):
     self.model.train()
     timer = CountdownTimer(history.num_training_steps(len(trn), gradient_accumulation=gradient_accumulation))
     total_loss = 0
     self.reset_metrics(metric)
     for batch in trn:
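         # Note: gradients are zeroed before every batch here, so with
         # gradient_accumulation > 1 earlier batches are discarded rather than
         # accumulated; calling zero_grad() after optimizer.step() instead would
         # preserve true accumulation.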
         optimizer.zero_grad()
         output_dict = self.feed_batch(batch)
         self.update_metrics(batch, output_dict, metric)
         loss = output_dict['loss']
         if gradient_accumulation and gradient_accumulation > 1:
             loss /= gradient_accumulation
         loss.backward()
         total_loss += loss.item()
         if history.step(gradient_accumulation):
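             # Clip gradients when a max norm is configured, then update the
             # parameters and the learning-rate schedule.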
             if self.config.grad_norm:
                 torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.grad_norm)
             optimizer.step()
             if linear_scheduler:
                 linear_scheduler.step()
             timer.log(self.report_metrics(total_loss / (timer.current + 1), metric), ratio_percentage=None,
                       logger=logger)
         del loss
     return total_loss / timer.total
Example #3
 def fit_dataloader(self,
                    trn: DataLoader,
                    criterion,
                    optimizer,
                    metric,
                    logger: logging.Logger,
                    history: History,
                    gradient_accumulation=1,
                    grad_norm=None,
                    transformer_grad_norm=None,
                    teacher: Tagger = None,
                    kd_criterion=None,
                    temperature_scheduler=None,
                    ratio_width=None,
                    **kwargs):
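     # The optimizer argument packs (optimizer, scheduler); when training with a
     # teacher, the scheduler slot additionally carries a lambda scheduler that
     # weights the distillation loss.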
     optimizer, scheduler = optimizer
     if teacher:
         scheduler, lambda_scheduler = scheduler
     else:
         lambda_scheduler = None
     self.model.train()
     timer = CountdownTimer(
         history.num_training_steps(
             len(trn), gradient_accumulation=gradient_accumulation))
     total_loss = 0
     for idx, batch in enumerate(trn):
         out, mask = self.feed_batch(batch)
         y = batch['tag_id']
         loss = self.compute_loss(criterion, out, y, mask)
         if gradient_accumulation and gradient_accumulation > 1:
             loss /= gradient_accumulation
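         # Knowledge distillation: blend the hard-label loss with a soft-target
         # loss computed against the teacher's output, weighted by the current lambda.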
         if teacher:
             with torch.no_grad():
                 out_T, _ = teacher.feed_batch(batch)
             # noinspection PyNoneFunctionAssignment
             kd_loss = self.compute_distill_loss(kd_criterion, out, out_T,
                                                 mask,
                                                 temperature_scheduler)
             _lambda = float(lambda_scheduler)
             loss = _lambda * loss + (1 - _lambda) * kd_loss
         loss.backward()
         total_loss += loss.item()
         prediction = self.decode_output(out, mask, batch)
         self.update_metrics(metric, out, y, mask, batch, prediction)
         if history.step(gradient_accumulation):
             self._step(optimizer, scheduler, grad_norm,
                        transformer_grad_norm, lambda_scheduler)
             report = f'loss: {total_loss / (idx + 1):.4f} {metric}'
             timer.log(report,
                       logger=logger,
                       ratio_percentage=False,
                       ratio_width=ratio_width)
         del loss
         del out
         del mask
Example #4
 def fit_dataloader(self,
                    trn: DataLoader,
                    criterion,
                    optimizer,
                    metric: SpanMetric,
                    logger: logging.Logger,
                    history: History,
                    gradient_accumulation=1,
                    grad_norm=None,
                    ratio_width=None,
                    eval_trn=True,
                    **kwargs):
     optimizer, scheduler = optimizer
     metric.reset()
     self.model.train()
     timer = CountdownTimer(
         history.num_training_steps(
             len(trn), gradient_accumulation=gradient_accumulation))
     total_loss = 0
     for idx, batch in enumerate(trn):
         out, mask = self.feed_batch(batch)
         y = batch['chart_id']
         loss, span_probs = self.compute_loss(out, y, mask)
         if gradient_accumulation and gradient_accumulation > 1:
             loss /= gradient_accumulation
         loss.backward()
         total_loss += loss.item()
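         # Decoding and scoring the training data is optional; disabling eval_trn
         # speeds up each epoch at the cost of not reporting a training metric.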
         if eval_trn:
             prediction = self.decode_output(out, mask, batch, span_probs)
             self.update_metrics(metric, batch, prediction)
         if history.step(gradient_accumulation):
             self._step(optimizer, scheduler, grad_norm)
             report = f'loss: {total_loss / (idx + 1):.4f} {metric}' if eval_trn \
                 else f'loss: {total_loss / (idx + 1):.4f}'
             timer.log(report,
                       logger=logger,
                       ratio_percentage=False,
                       ratio_width=ratio_width)
         del loss
         del out
         del mask
Example #5
    def fit_dataloader(self,
                       trn,
                       optimizer,
                       scheduler,
                       criterion,
                       epoch,
                       logger,
                       history: History,
                       transformer_optimizer=None,
                       transformer_scheduler=None,
                       gradient_accumulation=1,
                       eval_trn=False,
                       **kwargs):
        self.model.train()

        timer = CountdownTimer(history.num_training_steps(len(trn), gradient_accumulation))
        metric = self.build_metric(training=True)
        total_loss = 0
        for idx, batch in enumerate(trn):
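            # Note: zero_grad() runs every batch, so gradient_accumulation > 1
            # only rescales the loss here instead of accumulating gradients.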
            optimizer.zero_grad()
            (s_arc, s_sib, s_rel), mask, puncts = self.feed_batch(batch)
            arcs, sibs, rels = batch['arc'], batch['sib_id'], batch['rel_id']

            loss, s_arc = self.compute_loss(s_arc, s_sib, s_rel, arcs, sibs, rels, mask)
            if gradient_accumulation > 1:
                loss /= gradient_accumulation
            loss.backward()
            total_loss += loss.item()
            if eval_trn:
                arc_preds, rel_preds = self.decode(s_arc, s_sib, s_rel, mask)
                self.update_metric(arc_preds, rel_preds, arcs, rels, mask, puncts, metric)
            if history.step(gradient_accumulation):
                self._step(optimizer, scheduler, transformer_optimizer, transformer_scheduler)
                report = self._report(total_loss / (timer.current + 1), metric if eval_trn else None)
                lr = scheduler.get_last_lr()[0]
                report += f' lr: {lr:.4e}'
                timer.log(report, ratio_percentage=False, logger=logger)
            del loss
Example #6
    def fit_dataloader(self,
                       trn,
                       optimizer,
                       scheduler,
                       criterion,
                       epoch,
                       logger,
                       history: History,
                       transformer_optimizer=None,
                       transformer_scheduler=None,
                       gradient_accumulation=1,
                       **kwargs):
        self.model.train()

        timer = CountdownTimer(
            history.num_training_steps(len(trn), gradient_accumulation))
        metric = self.build_metric(training=True)
        total_loss = 0
        for idx, batch in enumerate(trn):
            arc_scores, rel_scores, mask, puncts = self.feed_batch(batch)
            arcs, rels = batch['arc'], batch['rel_id']
            loss = self.compute_loss(arc_scores, rel_scores, arcs, rels, mask,
                                     criterion, batch)
            if gradient_accumulation > 1:
                loss /= gradient_accumulation
            loss.backward()
            total_loss += loss.item()
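            # Decode arc/relation predictions and update the training metric on every batch.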
            arc_preds, rel_preds = self.decode(arc_scores, rel_scores, mask,
                                               batch)
            self.update_metric(arc_preds, rel_preds, arcs, rels, mask, puncts,
                               metric, batch)
            if history.step(gradient_accumulation):
                self._step(optimizer, scheduler, transformer_optimizer,
                           transformer_scheduler)
                report = self._report(total_loss / (timer.current + 1), metric)
                timer.log(report, ratio_percentage=False, logger=logger)
            del loss
Example #7
 def fit_dataloader(self,
                    trn: DataLoader,
                    criterion,
                    optimizer,
                    metric,
                    logger: logging.Logger,
                    history: History,
                    ratio_width=None,
                    gradient_accumulation=1,
                    encoder_grad_norm=None,
                    decoder_grad_norm=None,
                    patience=0.5,
                    eval_trn=False,
                    **kwargs):
     self.model.train()
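     # For multi-task training, the optimizer argument packs the shared encoder's
     # optimizer and scheduler plus a dict of per-task decoder optimizers.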
     encoder_optimizer, encoder_scheduler, decoder_optimizers = optimizer
     timer = CountdownTimer(len(trn))
     total_loss = 0
     self.reset_metrics(metric)
     model = self.model_
     encoder_parameters = model.encoder.parameters()
     decoder_parameters = model.decoders.parameters()
     for idx, (task_name, batch) in enumerate(trn):
         decoder_optimizer = decoder_optimizers.get(task_name, None)
         output_dict, _ = self.feed_batch(batch, task_name)
         loss = self.compute_loss(batch, output_dict[task_name]['output'],
                                  criterion[task_name],
                                  self.tasks[task_name])
         if gradient_accumulation and gradient_accumulation > 1:
             loss /= gradient_accumulation
         loss.backward()
         total_loss += float(loss.item())
         if history.step(gradient_accumulation):
             if self.config.get('grad_norm', None):
                 clip_grad_norm(model, self.config.grad_norm)
             if encoder_grad_norm:
                 torch.nn.utils.clip_grad_norm_(encoder_parameters,
                                                encoder_grad_norm)
             if decoder_grad_norm:
                 torch.nn.utils.clip_grad_norm_(decoder_parameters,
                                                decoder_grad_norm)
             encoder_optimizer.step()
             encoder_optimizer.zero_grad()
             encoder_scheduler.step()
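             # A per-task decoder optimizer may be packed together with its own
             # scheduler as an (optimizer, scheduler) tuple.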
             if decoder_optimizer:
                 if isinstance(decoder_optimizer, tuple):
                     decoder_optimizer, decoder_scheduler = decoder_optimizer
                 else:
                     decoder_scheduler = None
                 decoder_optimizer.step()
                 decoder_optimizer.zero_grad()
                 if decoder_scheduler:
                     decoder_scheduler.step()
         if eval_trn:
             self.decode_output(output_dict, batch, task_name)
             self.update_metrics(batch, output_dict, metric, task_name)
         timer.log(self.report_metrics(total_loss / (timer.current + 1),
                                       metric if eval_trn else None),
                   ratio_percentage=None,
                   ratio_width=ratio_width,
                   logger=logger)
         del loss
         del output_dict
     return total_loss / timer.total