Example #1
    def evaluate_dataloader(self,
                            data,
                            criterion,
                            logger=None,
                            ratio_width=None,
                            metric=None,
                            output=None,
                            **kwargs):
        self.model.eval()
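        # output may be a file path or an open handle; normalize it to a writable handle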
        if isinstance(output, str):
            output = open(output, 'w')

        loss = 0
        if not metric:
            metric = self.build_metric()
        else:
            metric.reset()
        timer = CountdownTimer(len(data))
        for idx, batch in enumerate(data):
            logits, mask = self.feed_batch(batch)
            y = batch['tag_id']
            loss += self.compute_loss(criterion, logits, y, mask).item()
            prediction = self.decode_output(logits, mask, batch)
            self.update_metrics(metric, logits, y, mask, batch, prediction)
            if output:
                self.write_prediction(prediction, batch, output)
            timer.log(f'loss: {loss / (idx + 1):.4f} {metric}',
                      ratio_percentage=False,
                      logger=logger,
                      ratio_width=ratio_width)
        loss /= len(data)
        if output:
            output.close()
        return float(loss), metric
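All of these snippets share the same progress pattern built on HanLP's CountdownTimer. Here is a minimal sketch of that pattern; the CountdownTimer constructor and log() signature are inferred from the calls in these examples, and process_batch is a hypothetical stand-in for the per-batch work:

    # Minimal sketch, assuming HanLP's CountdownTimer API as used in the examples.
    from hanlp.utils.time_util import CountdownTimer

    def run(batches, process_batch, logger=None):
        timer = CountdownTimer(len(batches))  # total number of steps to count down
        total_loss = 0
        for idx, batch in enumerate(batches):
            total_loss += process_batch(batch)  # hypothetical per-batch loss
            # Each log() call advances the timer and renders progress plus an ETA
            timer.log(f'loss: {total_loss / (idx + 1):.4f}',
                      ratio_percentage=False, logger=logger)
        return total_loss / max(timer.total, 1)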
Example #2
    def evaluate_dataloader(self, loader: PadSequenceDataLoader, criterion, logger=None, filename=None, output=False,
                            ratio_width=None,
                            metric=None,
                            **kwargs):
        self.model.eval()

        total_loss = 0
        if not metric:
            metric = self.build_metric()

        timer = CountdownTimer(len(loader))
        for batch in loader:
            (s_arc, s_sib, s_rel), mask, puncts = self.feed_batch(batch)
            arcs, sibs, rels = batch['arc'], batch['sib_id'], batch['rel_id']
            loss, s_arc = self.compute_loss(s_arc, s_sib, s_rel, arcs, sibs, rels, mask)
            total_loss += float(loss)
            arc_preds, rel_preds = self.decode(s_arc, s_sib, s_rel, mask)
            self.update_metric(arc_preds, rel_preds, arcs, rels, mask, puncts, metric)
            report = self._report(total_loss / (timer.current + 1), metric)
            if filename:
                report = f'{os.path.basename(filename)} ' + report
            timer.log(report, ratio_percentage=False, logger=logger, ratio_width=ratio_width)
        total_loss /= len(loader)

        return total_loss, metric
Example #3
 def fit_dataloader(self,
                    trn: DataLoader,
                    criterion,
                    optimizer,
                    metric,
                    logger: logging.Logger,
                    linear_scheduler=None,
                    gradient_accumulation=1,
                    **kwargs):
     self.model.train()
     timer = CountdownTimer(len(trn) // gradient_accumulation)
     total_loss = 0
     self.reset_metrics(metric)
     for idx, batch in enumerate(trn):
         output_dict = self.feed_batch(batch)
         self.update_metrics(batch, output_dict, metric)
         loss = output_dict['loss']
         loss = loss.sum()  # For data parallel
         if gradient_accumulation and gradient_accumulation > 1:
             # Scale before backward so gradients are averaged over accumulated batches
             loss /= gradient_accumulation
         loss.backward()
         if self.config.grad_norm:
             torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                            self.config.grad_norm)
         if (idx + 1) % gradient_accumulation == 0:
             self._step(optimizer, linear_scheduler)
             timer.log(self.report_metrics(total_loss / (timer.current + 1),
                                           metric),
                       ratio_percentage=None,
                       logger=logger)
         total_loss += loss.item()
         del loss
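     # Flush leftover gradients when len(trn) is not a multiple of gradient_accumulation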
     if len(trn) % gradient_accumulation:
         self._step(optimizer, linear_scheduler)
     return total_loss / timer.total
Example #4
 def evaluate_dataloader(self,
                         data,
                         criterion,
                         logger=None,
                         ratio_width=None,
                         metric=None,
                         output=None,
                         **kwargs):
     self.model.eval()
     total_loss = 0
     if not metric:
         metric = self.build_metric()
     else:
         metric.reset()
     timer = CountdownTimer(len(data))
     for idx, batch in enumerate(data):
         out, mask = self.feed_batch(batch)
         y = batch['chart_id']
         loss, span_probs = self.compute_loss(out, y, mask)
         total_loss += loss.item()
         prediction = self.decode_output(out, mask, batch, span_probs)
         self.update_metrics(metric, batch, prediction)
         timer.log(f'loss: {total_loss / (idx + 1):.4f} {metric}',
                   ratio_percentage=False,
                   logger=logger,
                   ratio_width=ratio_width)
     total_loss /= len(data)
     if output:
         output.close()
     return total_loss, metric
Example #5
 def execute_training_loop(self,
                           trn: DataLoader,
                           dev: DataLoader,
                           epochs,
                           criterion,
                           optimizer,
                           metric,
                           save_dir,
                           logger: logging.Logger,
                           devices,
                           gradient_accumulation=1,
                           **kwargs):
     best_epoch, best_metric = 0, -1
     optimizer, scheduler = optimizer
     history = History()
     timer = CountdownTimer(epochs)
     ratio_width = len(f'{len(trn)}/{len(trn)}')
     for epoch in range(1, epochs + 1):
         logger.info(f"[yellow]Epoch {epoch} / {epochs}:[/yellow]")
         self.fit_dataloader(trn, criterion, optimizer, metric, logger, history=history,
                             gradient_accumulation=gradient_accumulation,
                             linear_scheduler=scheduler if self._get_transformer() else None)
         if dev:
             self.evaluate_dataloader(dev, criterion, metric, logger, ratio_width=ratio_width)
         report = f'{timer.elapsed_human}/{timer.total_time_human}'
         dev_score = metric.score
         if not self._get_transformer():
             scheduler.step(dev_score)
         if dev_score > best_metric:
             self.save_weights(save_dir)
             best_metric = dev_score
             report += ' [red]saved[/red]'
         timer.log(report, ratio_percentage=False, newline=True, ratio=False)
Example #6
 def fit_dataloader(self,
                    trn: DataLoader,
                    criterion,
                    optimizer,
                    metric,
                    logger: logging.Logger,
                    history: History = None,
                    gradient_accumulation=1,
                    ratio_percentage=None,
                    **kwargs):
     optimizer, scheduler = optimizer
     self.model.train()
     timer = CountdownTimer(
         history.num_training_steps(
             len(trn), gradient_accumulation=gradient_accumulation))
     total_loss = 0
     for batch in trn:
         output_dict = self.feed_batch(batch)
         loss = output_dict['loss']
         if gradient_accumulation and gradient_accumulation > 1:
             loss /= gradient_accumulation
         loss.backward()
         total_loss += loss.item()
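          # history.step returns True once per gradient_accumulation calls, i.e. at optimizer-step boundaries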
         if history.step(gradient_accumulation):
             self._step(optimizer, scheduler)
             timer.log(self.report_metrics(total_loss /
                                           (timer.current + 1)),
                       ratio_percentage=ratio_percentage,
                       logger=logger)
         del loss
         del output_dict
     return total_loss / max(timer.total, 1)
Example #7
 def execute_training_loop(self, trn: DataLoader, dev: DataLoader, epochs,
                           criterion, optimizer, metric, save_dir,
                           logger: logging.Logger, devices, **kwargs):
     best_epoch, best_metric = 0, -1
     timer = CountdownTimer(epochs)
     ratio_width = len(f'{len(trn)}/{len(trn)}')
     for epoch in range(1, epochs + 1):
         logger.info(f"[yellow]Epoch {epoch} / {epochs}:[/yellow]")
         self.fit_dataloader(trn, criterion, optimizer, metric, logger)
         if dev:
             self.evaluate_dataloader(dev,
                                      criterion,
                                      metric,
                                      logger,
                                      ratio_width=ratio_width)
         report = f'{timer.elapsed_human}/{timer.total_time_human}'
         dev_score = metric.score
         if dev_score > best_metric:
             self.save_weights(save_dir)
             best_metric = dev_score
             report += ' [red]saved[/red]'
         timer.log(report,
                   ratio_percentage=False,
                   newline=True,
                   ratio=False)
Example #8
 def evaluate_dataloader(self,
                         data: DataLoader,
                         criterion: Callable,
                         metric,
                         logger,
                         ratio_width=None,
                         output=False,
                         **kwargs):
     self.model.eval()
     self.reset_metrics(metric)
     timer = CountdownTimer(len(data))
     total_loss = 0
     if output:
         fp = open(output, 'w')
     for batch in data:
         output_dict = self.feed_batch(batch)
         if output:
             for sent, pred, gold in zip(batch['token'], output_dict['prediction'], batch['ner']):
                 fp.write('Tokens\t' + ' '.join(sent) + '\n')
                 fp.write('Pred\t' + '\t'.join(
                     ['[' + ' '.join(sent[x:y + 1]) + f']/{label}' for x, y, label in pred]) + '\n')
                 fp.write('Gold\t' + '\t'.join(
                     ['[' + ' '.join(sent[x:y + 1]) + f']/{label}' for x, y, label in gold]) + '\n')
                 fp.write('\n')
         self.update_metrics(batch, output_dict, metric)
         loss = output_dict['loss']
         total_loss += loss.item()
         timer.log(self.report_metrics(total_loss / (timer.current + 1), metric), ratio_percentage=None,
                   logger=logger,
                   ratio_width=ratio_width)
         del loss
     if output:
         fp.close()
     return total_loss / timer.total, metric
Example #9
 def fit_dataloader(self,
                    trn: DataLoader,
                    criterion,
                    optimizer,
                    metric,
                    logger: logging.Logger,
                    linear_scheduler=None,
                    history: History = None,
                    gradient_accumulation=1,
                    **kwargs):
     self.model.train()
     timer = CountdownTimer(history.num_training_steps(len(trn), gradient_accumulation=gradient_accumulation))
     total_loss = 0
     self.reset_metrics(metric)
      optimizer.zero_grad()
      for batch in trn:
         output_dict = self.feed_batch(batch)
         self.update_metrics(batch, output_dict, metric)
         loss = output_dict['loss']
         if gradient_accumulation and gradient_accumulation > 1:
             loss /= gradient_accumulation
         loss.backward()
         total_loss += loss.item()
         if history.step(gradient_accumulation):
             if self.config.grad_norm:
                 torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.grad_norm)
              optimizer.step()
              # Zero the gradients only after stepping so they accumulate across mini-batches
              optimizer.zero_grad()
             if linear_scheduler:
                 linear_scheduler.step()
             timer.log(self.report_metrics(total_loss / (timer.current + 1), metric), ratio_percentage=None,
                       logger=logger)
         del loss
     return total_loss / timer.total
Example #10
 def build_dataset(self,
                   data,
                   logger: logging.Logger = None,
                   training=True):
     dataset = AbstractMeaningRepresentationDataset(
         data, generate_idx=not training)
     if self.vocabs.mutable:
         self.build_vocabs(dataset, logger)
         self.vocabs.lock()
         self.vocabs.summary(logger)
     lens = [len(x['token']) + len(x['amr']) for x in dataset]
     dataset.append_transform(
         functools.partial(get_concepts,
                           vocab=self.vocabs.predictable_concept,
                           rel_vocab=self.vocabs.rel if self.config.get(
                               'separate_rel', False) else None))
     dataset.append_transform(append_bos)
     # Tokenization will happen in batchify
     if not self.config.get('squeeze', None):
         dataset.append_transform(self.config.encoder.transform())
     if isinstance(data, str):
         dataset.purge_cache()
         timer = CountdownTimer(len(dataset))
         for each in dataset:
             timer.log(
                 'Caching samples [blink][yellow]...[/yellow][/blink]')
     return dataset, lens
Example #11
    def build_vocabs(self, dataset, logger=None, transformer=False):
        rel_vocab = self.vocabs.get('rel', None)
        if rel_vocab is None:
            rel_vocab = Vocab(unk_token=None, pad_token=self.config.get('pad_rel', None))
            self.vocabs.put(rel=rel_vocab)

        timer = CountdownTimer(len(dataset))
        if transformer:
            token_vocab = None
        else:
            self.vocabs.token = token_vocab = VocabCounter(unk_token=self.config.get('unk', UNK))
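        # Iterating the dataset runs its transforms, which is what fills the mutable vocabs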
        for i, sample in enumerate(dataset):
            timer.log('Building vocab [blink][yellow]...[/yellow][/blink]', ratio_percentage=True)
        min_freq = self.config.get('min_freq', None)
        if min_freq:
            token_vocab.trim(min_freq)
        rel_vocab.set_unk_as_safe_unk()  # Some relations in the dev set may be OOV
        self.vocabs.lock()
        self.vocabs.summary(logger=logger)
        if token_vocab:
            self.config.n_words = len(self.vocabs['token'])
        self.config.n_rels = len(self.vocabs['rel'])
        if token_vocab:
            self.config.pad_index = self.vocabs['token'].pad_idx
            self.config.unk_index = self.vocabs['token'].unk_idx
Example #12
 def execute_training_loop(self,
                           trn: PrefetchDataLoader,
                           dev: PrefetchDataLoader,
                           epochs,
                           criterion,
                           optimizer,
                           metric,
                           save_dir,
                           logger: logging.Logger,
                           devices,
                           ratio_width=None,
                           dev_data=None,
                           gradient_accumulation=1,
                           **kwargs):
     best_epoch, best_metric = 0, -1
     timer = CountdownTimer(epochs)
     history = History()
     try:
         for epoch in range(1, epochs + 1):
             logger.info(f"[yellow]Epoch {epoch} / {epochs}:[/yellow]")
             trn = self.fit_dataloader(
                 trn,
                 criterion,
                 optimizer,
                 metric,
                 logger,
                 ratio_width=ratio_width,
                 gradient_accumulation=gradient_accumulation,
                 history=history,
                 save_dir=save_dir)
             report = f'{timer.elapsed_human}/{timer.total_time_human}'
             if epoch % self.config.eval_every == 0 or epoch == epochs:
                 metric = self.evaluate_dataloader(dev,
                                                   logger,
                                                   dev_data,
                                                   ratio_width=ratio_width,
                                                   save_dir=save_dir,
                                                   use_fast=True)
                 if metric > best_metric:
                     self.save_weights(save_dir)
                     best_metric = metric
                     best_epoch = epoch
                     report += ' [red]saved[/red]'
             timer.log(report,
                       ratio_percentage=False,
                       newline=True,
                       ratio=False)
         if best_epoch and best_epoch != epochs:
             logger.info(
                 f'Restored the best model with {best_metric} saved {epochs - best_epoch} epochs ago'
             )
             self.load_weights(save_dir)
     finally:
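          # Always shut down the prefetch workers, even if training fails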
         trn.close()
         dev.close()
Example #13
 def fit_dataloader(self,
                    trn: DataLoader,
                    criterion,
                    optimizer,
                    metric,
                    logger: logging.Logger,
                    history: History,
                    gradient_accumulation=1,
                    grad_norm=None,
                    transformer_grad_norm=None,
                    teacher: Tagger = None,
                    kd_criterion=None,
                    temperature_scheduler=None,
                    ratio_width=None,
                    **kwargs):
     optimizer, scheduler = optimizer
     if teacher:
         scheduler, lambda_scheduler = scheduler
     else:
         lambda_scheduler = None
     self.model.train()
     timer = CountdownTimer(
         history.num_training_steps(
             len(trn), gradient_accumulation=gradient_accumulation))
     total_loss = 0
     for idx, batch in enumerate(trn):
         out, mask = self.feed_batch(batch)
         y = batch['tag_id']
         loss = self.compute_loss(criterion, out, y, mask)
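          # Knowledge distillation: blend the hard-label loss with the teacher's soft predictions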
          if teacher:
              with torch.no_grad():
                  out_T, _ = teacher.feed_batch(batch)
              # noinspection PyNoneFunctionAssignment
              kd_loss = self.compute_distill_loss(kd_criterion, out, out_T,
                                                  mask,
                                                  temperature_scheduler)
              _lambda = float(lambda_scheduler)
              loss = _lambda * loss + (1 - _lambda) * kd_loss
          if gradient_accumulation and gradient_accumulation > 1:
              # Scale after blending so both loss terms are averaged over accumulated batches
              loss /= gradient_accumulation
          loss.backward()
         total_loss += loss.item()
         prediction = self.decode_output(out, mask, batch)
         self.update_metrics(metric, out, y, mask, batch, prediction)
         if history.step(gradient_accumulation):
             self._step(optimizer, scheduler, grad_norm,
                        transformer_grad_norm, lambda_scheduler)
             report = f'loss: {total_loss / (idx + 1):.4f} {metric}'
             timer.log(report,
                       logger=logger,
                       ratio_percentage=False,
                       ratio_width=ratio_width)
         del loss
         del out
         del mask
Example #14
 def build_vocabs(self, dataset, logger, vocabs, lock=True, label_vocab_name='label', **kwargs):
     vocabs[label_vocab_name] = label_vocab = Vocab(pad_token=None, unk_token=None)
     # Use null to indicate no relationship
     label_vocab.add('<null>')
     timer = CountdownTimer(len(dataset))
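      # Each pass over the samples triggers the dataset's transforms, which populate label_vocab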
     for each in dataset:
         timer.log('Building NER vocab [blink][yellow]...[/yellow][/blink]')
     label_vocab.set_unk_as_safe_unk()
     if lock:
         vocabs.lock()
         vocabs.summary(logger)
Example #15
 def build_vocabs(self, trn, logger, **kwargs):
     self.vocabs.chart = VocabWithNone(pad_token=None, unk_token=None)
     timer = CountdownTimer(len(trn))
     max_seq_len = 0
     for each in trn:
         max_seq_len = max(max_seq_len, len(each['token_input_ids']))
         timer.log(
             f'Building vocab [blink][yellow]...[/yellow][/blink] (longest sequence: {max_seq_len})'
         )
     self.vocabs.chart.set_unk_as_safe_unk()
     self.vocabs.lock()
     self.vocabs.summary(logger)
Example #16
 def execute_training_loop(self,
                           trn: DataLoader,
                           dev: DataLoader,
                           epochs,
                           criterion,
                           optimizer,
                           metric,
                           save_dir,
                           logger: logging.Logger,
                           devices,
                           ratio_width=None,
                           patience=0.5,
                           eval_trn=True,
                           **kwargs):
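      # A float patience is interpreted as a fraction of the total number of epochs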
     if isinstance(patience, float):
         patience = int(patience * epochs)
     best_epoch, best_metric = 0, -1
     timer = CountdownTimer(epochs)
     history = History()
     for epoch in range(1, epochs + 1):
         logger.info(f"[yellow]Epoch {epoch} / {epochs}:[/yellow]")
         self.fit_dataloader(trn,
                             criterion,
                             optimizer,
                             metric,
                             logger,
                             history=history,
                             ratio_width=ratio_width,
                             eval_trn=eval_trn,
                             **self.config)
         loss, dev_metric = self.evaluate_dataloader(
             dev, criterion, logger=logger, ratio_width=ratio_width)
         timer.update()
         report = f"{timer.elapsed_human} / {timer.total_time_human} ETA: {timer.eta_human}"
         if dev_metric > best_metric:
             best_epoch, best_metric = epoch, dev_metric
             self.save_weights(save_dir)
             report += ' [red](saved)[/red]'
         else:
             report += f' ({epoch - best_epoch})'
             if epoch - best_epoch >= patience:
                 report += ' early stop'
         logger.info(report)
         if epoch - best_epoch >= patience:
             break
     if not best_epoch:
         self.save_weights(save_dir)
     elif best_epoch != epoch:
         self.load_weights(save_dir)
     logger.info(f"Max score of dev is {best_metric} at epoch {best_epoch}")
     logger.info(
         f"Average time of each epoch is {timer.elapsed_average_human}")
     logger.info(f"{timer.elapsed_human} elapsed")
Example #17
    def evaluate_dataloader(self,
                            loader: PadSequenceDataLoader,
                            criterion,
                            logger=None,
                            filename=None,
                            output=False,
                            ratio_width=None,
                            metric=None,
                            **kwargs):
        self.model.eval()

        loss = 0
        if not metric:
            metric = self.build_metric()
        if output:
            fp = open(output, 'w')
            predictions, build_data, data, order = self.before_outputs(None)

        timer = CountdownTimer(len(loader))
        use_pos = self.use_pos
        for batch in loader:
            arc_scores, rel_scores, mask, puncts = self.feed_batch(batch)
            if output:
                self.collect_outputs(arc_scores, rel_scores, mask, batch,
                                     predictions, order, data, use_pos,
                                     build_data)
            arcs, rels = batch['arc'], batch['rel_id']
            loss += self.compute_loss(arc_scores, rel_scores, arcs, rels, mask,
                                      criterion, batch).item()
            arc_preds, rel_preds = self.decode(arc_scores, rel_scores, mask,
                                               batch)
            self.update_metric(arc_preds, rel_preds, arcs, rels, mask, puncts,
                               metric, batch)
            report = self._report(loss / (timer.current + 1), metric)
            if filename:
                report = f'{os.path.basename(filename)} ' + report
            timer.log(report,
                      ratio_percentage=False,
                      logger=logger,
                      ratio_width=ratio_width)
        loss /= len(loader)
        if output:
            outputs = self.post_outputs(predictions, data, order, use_pos,
                                        build_data)
            for each in outputs:
                fp.write(f'{each}\n\n')
            fp.close()
            logger.info(
                f'Predictions saved in [underline][yellow]{output}[/yellow][/underline]'
            )

        return loss, metric
Example #18
 def build_vocabs(self, trn, logger, **kwargs):
     self.vocabs.tag = Vocab(pad_token=None, unk_token=None)
     timer = CountdownTimer(len(trn))
     max_seq_len = 0
     token_key = self.config.token_key
     for each in trn:
         max_seq_len = max(max_seq_len, len(each[token_key]))
         timer.log(
             f'Building vocab [blink][yellow]...[/yellow][/blink] (longest sequence: {max_seq_len})'
         )
     self.vocabs.tag.set_unk_as_safe_unk()
     self.vocabs.lock()
     self.vocabs.summary(logger)
Example #19
    def build_vocabs(self, dataset, logger=None, transformer=None):
        rel_vocab = self.vocabs.get('rel', None)
        if rel_vocab is None:
            rel_vocab = Vocab(unk_token=None,
                              pad_token=self.config.get('pad_rel', None))
            self.vocabs.put(rel=rel_vocab)
        if self.config.get('feat', None) == 'pos' or self.config.get(
                'use_pos', False):
            self.vocabs['pos'] = Vocab(unk_token=None, pad_token=None)

        timer = CountdownTimer(len(dataset))
        if transformer:
            token_vocab = None
        else:
            token_vocab = Vocab()
            self.vocabs.token = token_vocab
            unk = self.config.get('unk', None)
            if unk is not None:
                token_vocab.unk_token = unk
        if token_vocab and self.config.get('min_freq', None):
            counter = Counter()
            for sample in dataset:
                for form in sample['token']:
                    counter[form] += 1
            reserved_token = [token_vocab.pad_token, token_vocab.unk_token]
            if ROOT in token_vocab:
                reserved_token.append(ROOT)
            freq_words = reserved_token + [
                token for token, freq in counter.items()
                if freq >= self.config.min_freq
            ]
            token_vocab.token_to_idx.clear()
            for word in freq_words:
                token_vocab(word)
        else:
            for i, sample in enumerate(dataset):
                timer.log('vocab building [blink][yellow]...[/yellow][/blink]',
                          ratio_percentage=True)
        rel_vocab.set_unk_as_safe_unk()  # Some relations in the dev set may be OOV
        self.vocabs.lock()
        self.vocabs.summary(logger=logger)
        if token_vocab:
            self.config.n_words = len(self.vocabs['token'])
        if 'pos' in self.vocabs:
            self.config.n_feats = len(self.vocabs['pos'])
            self.vocabs['pos'].set_unk_as_safe_unk()
        self.config.n_rels = len(self.vocabs['rel'])
        if token_vocab:
            self.config.pad_index = self.vocabs['token'].pad_idx
            self.config.unk_index = self.vocabs['token'].unk_idx
Example #20
 def build_dataloader(self,
                      data,
                      transform: Callable = None,
                      training=False,
                      device=None,
                      logger: logging.Logger = None,
                      gradient_accumulation=1,
                      **kwargs) -> DataLoader:
     dataset = BiaffineSecondaryParser.build_dataset(self, data, transform)
     if isinstance(data, str):
         dataset.purge_cache()
     if self.vocabs.mutable:
         BiaffineSecondaryParser.build_vocabs(self,
                                              dataset,
                                              logger,
                                              transformer=True)
     max_seq_len = self.config.get('max_seq_len', None)
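      # 510 leaves room for the two special tokens of a 512-token transformer encoder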
     if max_seq_len and isinstance(data, str):
         dataset.prune(lambda x: len(x['token_input_ids']) > 510, logger)
     if dataset.cache:
         timer = CountdownTimer(len(dataset))
         BiaffineSecondaryDependencyParsing.cache_dataset(
             self, dataset, timer, training, logger)
     return PadSequenceDataLoader(batch_sampler=self.sampler_builder.build(
         self.compute_lens(data, dataset),
         shuffle=training,
         gradient_accumulation=gradient_accumulation),
                                  device=device,
                                  dataset=dataset,
                                  pad={
                                      'arc': 0,
                                      'arc_2nd': False
                                  })
Example #21
 def build_dataloader(self,
                      data,
                      transform: TransformList = None,
                      training=False,
                      device=None,
                      logger: logging.Logger = None,
                      gradient_accumulation=1,
                      **kwargs) -> DataLoader:
     transform.insert(0, append_bos)
     dataset = BiaffineDependencyParser.build_dataset(self, data, transform)
     if isinstance(data, str):
         dataset.purge_cache()
     if self.vocabs.mutable:
         BiaffineDependencyParser.build_vocabs(self,
                                               dataset,
                                               logger,
                                               transformer=True)
     if dataset.cache:
         timer = CountdownTimer(len(dataset))
         BiaffineDependencyParser.cache_dataset(self, dataset, timer,
                                                training, logger)
     max_seq_len = self.config.get('max_seq_len', None)
     if max_seq_len and isinstance(data, str):
         dataset.prune(lambda x: len(x['token_input_ids']) > 510, logger)
     return PadSequenceDataLoader(batch_sampler=self.sampler_builder.build(
         self.compute_lens(data, dataset, length_field='FORM'),
         shuffle=training,
         gradient_accumulation=gradient_accumulation),
                                  device=device,
                                  dataset=dataset,
                                  pad=self.get_pad_dict())
Example #22
 def build_dataloader(self,
                      data,
                      transform: Callable = None,
                      training=False,
                      device=None,
                      logger: logging.Logger = None,
                      cache=False,
                      gradient_accumulation=1,
                      **kwargs) -> DataLoader:
     dataset = CRFConstituencyParsing.build_dataset(self, data, transform)
     if isinstance(data, str):
         dataset.purge_cache()
     if self.vocabs.mutable:
         CRFConstituencyParsing.build_vocabs(self, dataset, logger)
     if dataset.cache:
         timer = CountdownTimer(len(dataset))
         # noinspection PyCallByClass
         BiaffineDependencyParser.cache_dataset(self, dataset, timer,
                                                training, logger)
     return PadSequenceDataLoader(batch_sampler=self.sampler_builder.build(
         self.compute_lens(data, dataset),
         shuffle=training,
         gradient_accumulation=gradient_accumulation),
                                  device=device,
                                  dataset=dataset)
Example #23
 def evaluate_dataloader(self,
                         data: DataLoader,
                         criterion: Callable,
                         metric,
                         logger,
                         ratio_width=None,
                         filename=None,
                         output=None,
                         **kwargs):
     self.model.eval()
     timer = CountdownTimer(len(data))
     total_loss = 0
     metric.reset()
     num_samples = 0
     if output:
         output = open(output, 'w')
     for batch in data:
         logits = self.feed_batch(batch)
         target = batch['label_id']
         loss = self.compute_loss(criterion, logits, target, batch)
         total_loss += loss.item()
         label_ids = self.update_metric(metric, logits, target, output)
         if output:
             labels = [
                 self.vocabs[self.config.label_key].idx_to_token[i]
                 for i in label_ids.tolist()
             ]
             for i, label in enumerate(labels):
                 # text_a text_b pred gold
                 columns = [batch[self.config.text_a_key][i]]
                 if self.config.text_b_key:
                     columns.append(batch[self.config.text_b_key][i])
                 columns.append(label)
                 columns.append(batch[self.config.label_key][i])
                 output.write('\t'.join(columns))
                 output.write('\n')
         num_samples += len(target)
         report = f'loss: {total_loss / (timer.current + 1):.4f} acc: {metric.get_metric():.2%}'
         if filename:
             report = f'{filename} {report} {num_samples / timer.elapsed:.0f} samples/sec'
         timer.log(report,
                   ratio_percentage=None,
                   logger=logger,
                   ratio_width=ratio_width)
     if output:
         output.close()
     return total_loss / timer.total
Example #24
 def fit_dataloader(self,
                    trn: DataLoader,
                    criterion,
                    optimizer,
                    metric: SpanMetric,
                    logger: logging.Logger,
                    history: History,
                    gradient_accumulation=1,
                    grad_norm=None,
                    ratio_width=None,
                    eval_trn=True,
                    **kwargs):
     optimizer, scheduler = optimizer
     metric.reset()
     self.model.train()
     timer = CountdownTimer(
         history.num_training_steps(
             len(trn), gradient_accumulation=gradient_accumulation))
     total_loss = 0
     for idx, batch in enumerate(trn):
         out, mask = self.feed_batch(batch)
         y = batch['chart_id']
         loss, span_probs = self.compute_loss(out, y, mask)
         if gradient_accumulation and gradient_accumulation > 1:
             loss /= gradient_accumulation
         loss.backward()
         total_loss += loss.item()
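          # Scoring the training set is optional; skipping the decode speeds up each step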
         if eval_trn:
             prediction = self.decode_output(out, mask, batch, span_probs)
             self.update_metrics(metric, batch, prediction)
         if history.step(gradient_accumulation):
             self._step(optimizer, scheduler, grad_norm)
             report = f'loss: {total_loss / (idx + 1):.4f} {metric}' if eval_trn \
                 else f'loss: {total_loss / (idx + 1):.4f}'
             timer.log(report,
                       logger=logger,
                       ratio_percentage=False,
                       ratio_width=ratio_width)
         del loss
         del out
         del mask
Example #25
 def fit_dataloader(self, trn: DataLoader, criterion, optimizer, metric, logger: logging.Logger, ratio_width=None,
                    **kwargs):
     self.model.train()
     timer = CountdownTimer(len(trn))
     total_loss = 0
     for idx, batch in enumerate(trn):
         optimizer.zero_grad()
         out, mask = self.feed_batch(batch)
         y = batch['tag_id']
         loss = self.compute_loss(criterion, out, y, mask)
         loss.backward()
         nn.utils.clip_grad_norm_(self.model.parameters(), 5.0)
         optimizer.step()
         total_loss += loss.item()
         prediction = self.decode_output(out, mask, batch)
         self.update_metrics(metric, out, y, mask, batch, prediction)
          timer.log(f'loss: {total_loss / (idx + 1):.4f} {metric}', ratio_percentage=False, logger=logger,
                   ratio_width=ratio_width)
         del loss
         del out
         del mask
Example #26
    def compute_lens(self,
                     data: Union[List[Dict[str, Any]], str],
                     dataset: TransformDataset,
                     input_ids='token_input_ids',
                     length_field='token'):
        """

        Args:
            data: Samples to be measured or path to dataset during training time.
            dataset: During training time, use this dataset to measure the length of each sample inside.
            input_ids: Field name corresponds to input ids.
            length_field: Fall back to this field during prediction as input_ids may not be generated yet.

        Returns:

            Length list of this samples

        """
        if isinstance(data, str):
            if not dataset.cache:
                warnings.warn(
                    f'Caching for the dataset is not enabled, '
                    f'try `dataset.purge_cache()` if possible. The dataset is {dataset}.'
                )
            timer = CountdownTimer(len(dataset))
            for each in dataset:
                timer.log(
                    'Preprocessing and caching samples [blink][yellow]...[/yellow][/blink]'
                )
            timer.erase()
            return [len(x[input_ids]) for x in dataset]
        return [len(x[length_field]) for x in data]
Example #27
    def fit_dataloader(self,
                       trn,
                       optimizer,
                       scheduler,
                       criterion,
                       epoch,
                       logger,
                       history: History,
                       transformer_optimizer=None,
                       transformer_scheduler=None,
                       gradient_accumulation=1,
                       eval_trn=False,
                       **kwargs):
        self.model.train()

        timer = CountdownTimer(history.num_training_steps(len(trn), gradient_accumulation))
        metric = self.build_metric(training=True)
        total_loss = 0
        optimizer.zero_grad()
        for idx, batch in enumerate(trn):
            (s_arc, s_sib, s_rel), mask, puncts = self.feed_batch(batch)
            arcs, sibs, rels = batch['arc'], batch['sib_id'], batch['rel_id']

            loss, s_arc = self.compute_loss(s_arc, s_sib, s_rel, arcs, sibs, rels, mask)
            if gradient_accumulation > 1:
                loss /= gradient_accumulation
            loss.backward()
            total_loss += loss.item()
            if eval_trn:
                arc_preds, rel_preds = self.decode(s_arc, s_sib, s_rel, mask)
                self.update_metric(arc_preds, rel_preds, arcs, rels, mask, puncts, metric)
            if history.step(gradient_accumulation):
                self._step(optimizer, scheduler, transformer_optimizer, transformer_scheduler)
                # Clear gradients only at accumulation boundaries so they can accumulate in between
                optimizer.zero_grad()
                if transformer_optimizer:
                    transformer_optimizer.zero_grad()
                report = self._report(total_loss / (timer.current + 1), metric if eval_trn else None)
                lr = scheduler.get_last_lr()[0]
                report += f' lr: {lr:.4e}'
                timer.log(report, ratio_percentage=False, logger=logger)
            del loss
Example #28
 def fit_dataloader(self, trn: DataLoader, criterion, optimizer, metric,
                    logger: logging.Logger, **kwargs):
     self.model.train()
     timer = CountdownTimer(len(trn))
     optimizer, scheduler = optimizer
     total_loss = 0
     metric.reset()
     for batch in trn:
         optimizer.zero_grad()
         logits = self.feed_batch(batch)
         target = batch['label_id']
         loss = self.compute_loss(criterion, logits, target, batch)
         loss.backward()
         optimizer.step()
         scheduler.step()
         total_loss += loss.item()
         self.update_metric(metric, logits, target)
         timer.log(
             f'loss: {total_loss / (timer.current + 1):.4f} acc: {metric.get_metric():.2%}',
             ratio_percentage=None,
             logger=logger)
         del loss
     return total_loss / timer.total
Example #29
 def fit_dataloader(self, trn: DataLoader, criterion, optimizer, metric,
                    logger: logging.Logger, **kwargs):
     self.model.train()
     timer = CountdownTimer(len(trn))
     total_loss = 0
     self.reset_metrics(metric)
     for batch in trn:
         optimizer.zero_grad()
         prediction = self.feed_batch(batch)
         loss = self.compute_loss(prediction, batch, criterion)
         self.update_metrics(batch, prediction, metric)
         loss.backward()
         if self.config.grad_norm:
             torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                            self.config.grad_norm)
         optimizer.step()
         total_loss += loss.item()
         timer.log(self.report_metrics(total_loss / (timer.current + 1),
                                       metric),
                   ratio_percentage=None,
                   logger=logger)
         del loss
     return total_loss / timer.total
Example #30
    def fit_dataloader(self,
                       trn,
                       optimizer,
                       scheduler,
                       criterion,
                       epoch,
                       logger,
                       history: History,
                       transformer_optimizer=None,
                       transformer_scheduler=None,
                       gradient_accumulation=1,
                       **kwargs):
        self.model.train()

        timer = CountdownTimer(
            history.num_training_steps(len(trn), gradient_accumulation))
        metric = self.build_metric(training=True)
        total_loss = 0
        for idx, batch in enumerate(trn):
            arc_scores, rel_scores, mask, puncts = self.feed_batch(batch)
            arcs, rels = batch['arc'], batch['rel_id']
            loss = self.compute_loss(arc_scores, rel_scores, arcs, rels, mask,
                                     criterion, batch)
            if gradient_accumulation > 1:
                loss /= gradient_accumulation
            loss.backward()
            total_loss += loss.item()
            arc_preds, rel_preds = self.decode(arc_scores, rel_scores, mask,
                                               batch)
            self.update_metric(arc_preds, rel_preds, arcs, rels, mask, puncts,
                               metric, batch)
            if history.step(gradient_accumulation):
                self._step(optimizer, scheduler, transformer_optimizer,
                           transformer_scheduler)
                report = self._report(total_loss / (timer.current + 1), metric)
                timer.log(report, ratio_percentage=False, logger=logger)
            del loss