Esempio n. 1
0
 def evaluate_dataloader(self,
                         data,
                         criterion,
                         logger=None,
                         ratio_width=None,
                         metric=None,
                         output=None,
                         **kwargs):
     """Evaluate the model on ``data`` and report the running loss and metric.

     Args:
         data: Iterable of evaluation batches.
         criterion: Unused in this implementation; the loss comes from
             ``self.compute_loss``.
         logger: Logger for progress output.
         ratio_width: Width of the progress-ratio column.
         metric: Metric to update; built fresh when falsy, otherwise reset.
         output: NOTE(review): nothing is ever written to ``output`` here; it
             is only closed at the end — confirm this is intended.

     Returns:
         Tuple of (average loss per batch, metric).
     """
     # Disable dropout/batch-norm updates for evaluation.
     self.model.eval()
     total_loss = 0
     if not metric:
         metric = self.build_metric()
     else:
         metric.reset()
     timer = CountdownTimer(len(data))
     for idx, batch in enumerate(data):
         out, mask = self.feed_batch(batch)
         y = batch['chart_id']
         loss, span_probs = self.compute_loss(out, y, mask)
         total_loss += loss.item()
         prediction = self.decode_output(out, mask, batch, span_probs)
         self.update_metrics(metric, batch, prediction)
         # Report the running average loss, not the per-batch loss.
         timer.log(f'loss: {total_loss / (idx + 1):.4f} {metric}',
                   ratio_percentage=False,
                   logger=logger,
                   ratio_width=ratio_width)
     total_loss /= len(data)
     if output:
         output.close()
     return total_loss, metric
Esempio n. 2
0
 def fit_dataloader(self,
                    trn: DataLoader,
                    criterion,
                    optimizer,
                    metric,
                    logger: logging.Logger,
                    history: History = None,
                    gradient_accumulation=1,
                    ratio_percentage=None,
                    **kwargs):
     """Train for one epoch over ``trn`` with optional gradient accumulation.

     Args:
         trn: Training batches.
         criterion: Unused here; the loss is taken from ``feed_batch``'s output.
         optimizer: A ``(optimizer, scheduler)`` pair which is unpacked below.
         metric: Unused in this implementation.
         logger: Logger for progress output.
         history: Tracks global step count across epochs; drives when an
             optimizer step actually happens.
         gradient_accumulation: Number of micro-batches per optimizer step.
         ratio_percentage: Forwarded to the progress timer.

     Returns:
         Average loss per optimizer step.
     """
     # The caller packs optimizer and scheduler into a single tuple.
     optimizer, scheduler = optimizer
     self.model.train()
     timer = CountdownTimer(
         history.num_training_steps(
             len(trn), gradient_accumulation=gradient_accumulation))
     total_loss = 0
     for batch in trn:
         output_dict = self.feed_batch(batch)
         loss = output_dict['loss']
         # Scale before backward() so accumulated gradients average rather
         # than sum over the micro-batches.
         if gradient_accumulation and gradient_accumulation > 1:
             loss /= gradient_accumulation
         loss.backward()
         total_loss += loss.item()
         # history.step returns truthy only when a full accumulation window
         # has been consumed.
         if history.step(gradient_accumulation):
             self._step(optimizer, scheduler)
             timer.log(self.report_metrics(total_loss /
                                           (timer.current + 1)),
                       ratio_percentage=ratio_percentage,
                       logger=logger)
         # Drop references early to free the autograd graph.
         del loss
         del output_dict
     # max(..., 1) guards against division by zero on an empty loader.
     return total_loss / max(timer.total, 1)
Esempio n. 3
0
    def build_vocabs(self, dataset, logger=None, transformer=False):
        """Build the relation (and optionally token) vocabularies from ``dataset``.

        Args:
            dataset: Training samples; iterating it is presumably what fills
                the vocabularies via transforms attached to the dataset —
                TODO confirm.
            logger: Logger passed to ``self.vocabs.summary``.
            transformer: When truthy, no token vocabulary is built (token ids
                are expected to come from the transformer's tokenizer).
        """
        rel_vocab = self.vocabs.get('rel', None)
        if rel_vocab is None:
            rel_vocab = Vocab(unk_token=None, pad_token=self.config.get('pad_rel', None))
            self.vocabs.put(rel=rel_vocab)

        timer = CountdownTimer(len(dataset))
        if transformer:
            token_vocab = None
        else:
            self.vocabs.token = token_vocab = VocabCounter(unk_token=self.config.get('unk', UNK))
        for i, sample in enumerate(dataset):
            timer.log('Building vocab [blink][yellow]...[/yellow][/blink]', ratio_percentage=True)
        min_freq = self.config.get('min_freq', None)
        # Bug fix: only trim when a token vocab exists. With transformer=True,
        # token_vocab is None and calling .trim() raised AttributeError.
        if token_vocab and min_freq:
            token_vocab.trim(min_freq)
        rel_vocab.set_unk_as_safe_unk()  # Some relation in dev set is OOV
        self.vocabs.lock()
        self.vocabs.summary(logger=logger)
        if token_vocab:
            self.config.n_words = len(self.vocabs['token'])
        self.config.n_rels = len(self.vocabs['rel'])
        if token_vocab:
            self.config.pad_index = self.vocabs['token'].pad_idx
            self.config.unk_index = self.vocabs['token'].unk_idx
Esempio n. 4
0
    def compute_lens(self,
                     data: Union[List[Dict[str, Any]], str],
                     dataset: TransformDataset,
                     input_ids='token_input_ids',
                     length_field='token'):
        """

        Args:
            data: Samples to be measured or path to dataset during training time.
            dataset: During training time, use this dataset to measure the length of each sample inside.
            input_ids: Field name corresponds to input ids.
            length_field: Fall back to this field during prediction as input_ids may not be generated yet.

        Returns:

            Length list of this samples

        """
        if isinstance(data, str):
            if not dataset.cache:
                warnings.warn(
                    f'Caching for the dataset is not enabled, '
                    f'try `dataset.purge_cache()` if possible. The dataset is {dataset}.'
                )
            timer = CountdownTimer(len(dataset))
            for each in dataset:
                timer.log(
                    'Preprocessing and caching samples [blink][yellow]...[/yellow][/blink]'
                )
            timer.erase()
            return [len(x[input_ids]) for x in dataset]
        return [len(x[length_field]) for x in data]
Esempio n. 5
0
    def evaluate_dataloader(self,
                            data,
                            criterion,
                            logger=None,
                            ratio_width=None,
                            metric=None,
                            output=None,
                            **kwargs):
        """Evaluate a tagger on ``data``, optionally writing predictions out.

        Args:
            data: Iterable of evaluation batches.
            criterion: Loss function forwarded to ``self.compute_loss``.
            logger: Logger for progress output.
            ratio_width: Width of the progress-ratio column.
            metric: Metric to update; built fresh when falsy, otherwise reset.
            output: File path (opened for writing) or falsy to skip dumping
                predictions.

        Returns:
            Tuple of (average loss per batch as float, metric).
        """
        self.model.eval()
        # A string output is treated as a path and opened here; it is closed
        # at the end of evaluation.
        if isinstance(output, str):
            output = open(output, 'w')

        loss = 0
        if not metric:
            metric = self.build_metric()
        else:
            metric.reset()
        timer = CountdownTimer(len(data))
        for idx, batch in enumerate(data):
            logits, mask = self.feed_batch(batch)
            y = batch['tag_id']
            loss += self.compute_loss(criterion, logits, y, mask).item()
            prediction = self.decode_output(logits, mask, batch)
            self.update_metrics(metric, logits, y, mask, batch, prediction)
            if output:
                self.write_prediction(prediction, batch, output)
            # Running average loss, not per-batch loss.
            timer.log(f'loss: {loss / (idx + 1):.4f} {metric}',
                      ratio_percentage=False,
                      logger=logger,
                      ratio_width=ratio_width)
        loss /= len(data)
        if output:
            output.close()
        return float(loss), metric
Esempio n. 6
0
 def execute_training_loop(self, trn: DataLoader, dev: DataLoader, epochs,
                           criterion, optimizer, metric, save_dir,
                           logger: logging.Logger, devices, **kwargs):
     """Run the epoch loop: train, evaluate on dev, keep the best weights.

     Args:
         trn: Training batches.
         dev: Development batches; evaluation is skipped when falsy.
         epochs: Number of epochs to run.
         criterion: Loss function forwarded to fit/evaluate.
         optimizer: Optimizer forwarded to ``fit_dataloader``.
         metric: Metric object shared by training and evaluation; its
             ``score`` after evaluation drives model selection.
         save_dir: Directory where the best weights are saved.
         logger: Logger for progress output.
         devices: Unused in this implementation.

     Returns:
         None; the best model is persisted to ``save_dir`` as a side effect.
     """
     best_epoch, best_metric = 0, -1
     timer = CountdownTimer(epochs)
     # Fixed-width ratio like "123/456" keeps the progress log aligned.
     ratio_width = len(f'{len(trn)}/{len(trn)}')
     for epoch in range(1, epochs + 1):
         logger.info(f"[yellow]Epoch {epoch} / {epochs}:[/yellow]")
         self.fit_dataloader(trn, criterion, optimizer, metric, logger)
         if dev:
             self.evaluate_dataloader(dev,
                                      criterion,
                                      metric,
                                      logger,
                                      ratio_width=ratio_width)
         report = f'{timer.elapsed_human}/{timer.total_time_human}'
         dev_score = metric.score
         # Save whenever the dev score improves on the best so far.
         if dev_score > best_metric:
             self.save_weights(save_dir)
             best_metric = dev_score
             report += ' [red]saved[/red]'
         timer.log(report,
                   ratio_percentage=False,
                   newline=True,
                   ratio=False)
Esempio n. 7
0
 def execute_training_loop(self,
                           trn: DataLoader,
                           dev: DataLoader,
                           epochs,
                           criterion,
                           optimizer,
                           metric,
                           save_dir,
                           logger: logging.Logger,
                           devices,
                           gradient_accumulation=1,
                           **kwargs):
     """Epoch loop with a scheduler that depends on the encoder type.

     Args:
         trn: Training batches.
         dev: Development batches; evaluation is skipped when falsy.
         epochs: Number of epochs.
         criterion: Loss function forwarded to fit/evaluate.
         optimizer: A ``(optimizer, scheduler)`` pair which is unpacked below.
         metric: Shared metric; its ``score`` drives model selection.
         save_dir: Directory where the best weights are saved.
         logger: Logger for progress output.
         devices: Unused in this implementation.
         gradient_accumulation: Micro-batches per optimizer step, forwarded
             to ``fit_dataloader``.

     Returns:
         None; the best model is persisted to ``save_dir`` as a side effect.
     """
     best_epoch, best_metric = 0, -1
     optimizer, scheduler = optimizer
     history = History()
     timer = CountdownTimer(epochs)
     ratio_width = len(f'{len(trn)}/{len(trn)}')
     for epoch in range(1, epochs + 1):
         logger.info(f"[yellow]Epoch {epoch} / {epochs}:[/yellow]")
         # With a transformer encoder the (linear) scheduler steps inside
         # fit_dataloader; otherwise it is stepped per-epoch below.
         self.fit_dataloader(trn, criterion, optimizer, metric, logger, history=history,
                             gradient_accumulation=gradient_accumulation,
                             linear_scheduler=scheduler if self._get_transformer() else None)
         if dev:
             self.evaluate_dataloader(dev, criterion, metric, logger, ratio_width=ratio_width)
         report = f'{timer.elapsed_human}/{timer.total_time_human}'
         dev_score = metric.score
         # Non-transformer path: scheduler consumes the dev score, which
         # presumably means a plateau-style scheduler — TODO confirm.
         if not self._get_transformer():
             scheduler.step(dev_score)
         if dev_score > best_metric:
             self.save_weights(save_dir)
             best_metric = dev_score
             report += ' [red]saved[/red]'
         timer.log(report, ratio_percentage=False, newline=True, ratio=False)
Esempio n. 8
0
 def evaluate_dataloader(self,
                         data: DataLoader,
                         criterion: Callable,
                         metric,
                         logger,
                         ratio_width=None,
                         output=False,
                         **kwargs):
     """Evaluate on ``data``, optionally dumping Tokens/Pred/Gold spans.

     Args:
         data: Evaluation batches.
         criterion: Unused here; the loss comes from ``feed_batch``'s output.
         metric: Metric object, reset via ``self.reset_metrics``.
         logger: Logger for progress output.
         ratio_width: Width of the progress-ratio column.
         output: Path to write a human-readable comparison of predicted and
             gold NER spans, or falsy to skip.

     Returns:
         Tuple of (average loss per batch, metric).
     """
     self.model.eval()
     self.reset_metrics(metric)
     timer = CountdownTimer(len(data))
     total_loss = 0
     # Open once up front and close in ``finally`` so the handle is not
     # leaked if evaluation raises part-way through (the original leaked it).
     fp = open(output, 'w') if output else None
     try:
         for batch in data:
             output_dict = self.feed_batch(batch)
             if fp:
                 for sent, pred, gold in zip(batch['token'], output_dict['prediction'], batch['ner']):
                     fp.write('Tokens\t' + ' '.join(sent) + '\n')
                     fp.write('Pred\t' + '\t'.join(
                         ['[' + ' '.join(sent[x:y + 1]) + f']/{label}' for x, y, label in pred]) + '\n')
                     fp.write('Gold\t' + '\t'.join(
                         ['[' + ' '.join(sent[x:y + 1]) + f']/{label}' for x, y, label in gold]) + '\n')
                     fp.write('\n')
             self.update_metrics(batch, output_dict, metric)
             loss = output_dict['loss']
             total_loss += loss.item()
             timer.log(self.report_metrics(total_loss / (timer.current + 1), metric), ratio_percentage=None,
                       logger=logger,
                       ratio_width=ratio_width)
             del loss
     finally:
         if fp:
             fp.close()
     return total_loss / timer.total, metric
Esempio n. 9
0
 def fit_dataloader(self,
                    trn: DataLoader,
                    criterion,
                    optimizer,
                    metric,
                    logger: logging.Logger,
                    linear_scheduler=None,
                    history: History = None,
                    gradient_accumulation=1,
                    **kwargs):
     """Train for one epoch over ``trn`` with gradient accumulation.

     Args:
         trn: Training batches.
         criterion: Unused here; the loss comes from ``feed_batch``'s output.
         optimizer: The optimizer to step.
         metric: Metric updated per batch.
         logger: Logger for progress output.
         linear_scheduler: Optional LR scheduler stepped after each
             optimizer step.
         history: Tracks global steps; decides when an optimizer step occurs.
         gradient_accumulation: Micro-batches per optimizer step.

     Returns:
         Average loss per optimizer step.
     """
     self.model.train()
     timer = CountdownTimer(history.num_training_steps(len(trn), gradient_accumulation=gradient_accumulation))
     total_loss = 0
     self.reset_metrics(metric)
     # Clear stale gradients once up front. Afterwards gradients are only
     # zeroed right after each optimizer step so that they can accumulate
     # across ``gradient_accumulation`` micro-batches.
     optimizer.zero_grad()
     for batch in trn:
         output_dict = self.feed_batch(batch)
         self.update_metrics(batch, output_dict, metric)
         loss = output_dict['loss']
         # Scale so the accumulated gradients average rather than sum.
         if gradient_accumulation and gradient_accumulation > 1:
             loss /= gradient_accumulation
         loss.backward()
         total_loss += loss.item()
         if history.step(gradient_accumulation):
             if self.config.grad_norm:
                 torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.grad_norm)
             optimizer.step()
             # Bug fix: the original called zero_grad() at the top of every
             # batch, wiping the gradients of earlier micro-batches and
             # defeating gradient accumulation.
             optimizer.zero_grad()
             if linear_scheduler:
                 linear_scheduler.step()
             timer.log(self.report_metrics(total_loss / (timer.current + 1), metric), ratio_percentage=None,
                       logger=logger)
         del loss
     return total_loss / timer.total
Esempio n. 10
0
 def build_dataset(self,
                   data,
                   logger: logging.Logger = None,
                   training=True):
     """Build an AMR dataset with its transform pipeline and sample lengths.

     Args:
         data: Raw samples or a path string (a path triggers caching below).
         logger: Logger for vocab building and summaries.
         training: When False, samples keep a generated index for reordering.

     Returns:
         Tuple of (dataset, lens) where ``lens`` holds token+amr lengths
         measured before the transforms below are appended.
     """
     dataset = AbstractMeaningRepresentationDataset(
         data, generate_idx=not training)
     # Vocabs are built only while still mutable (i.e. first call).
     if self.vocabs.mutable:
         self.build_vocabs(dataset, logger)
         self.vocabs.lock()
         self.vocabs.summary(logger)
     lens = [len(x['token']) + len(x['amr']) for x in dataset]
     dataset.append_transform(
         functools.partial(get_concepts,
                           vocab=self.vocabs.predictable_concept,
                           rel_vocab=self.vocabs.rel if self.config.get(
                               'separate_rel', False) else None))
     dataset.append_transform(append_bos)
     # Tokenization will happen in batchify
     if not self.config.get('squeeze', None):
         dataset.append_transform(self.config.encoder.transform())
     # A path input means a file-backed dataset: iterate once to run the
     # transforms and warm the cache.
     if isinstance(data, str):
         dataset.purge_cache()
         timer = CountdownTimer(len(dataset))
         for each in dataset:
             timer.log(
                 'Caching samples [blink][yellow]...[/yellow][/blink]')
     return dataset, lens
Esempio n. 11
0
 def fit_dataloader(self,
                    trn: DataLoader,
                    criterion,
                    optimizer,
                    metric,
                    logger: logging.Logger,
                    linear_scheduler=None,
                    gradient_accumulation=1,
                    **kwargs):
     """Train for one epoch over ``trn`` with gradient accumulation.

     Args:
         trn: Training batches.
         criterion: Unused here; the loss comes from ``feed_batch``'s output.
         optimizer: Optimizer passed to ``self._step``.
         metric: Metric updated per batch.
         logger: Logger for progress output.
         linear_scheduler: Optional LR scheduler passed to ``self._step``.
         gradient_accumulation: Micro-batches per optimizer step.

     Returns:
         Average loss per optimizer step.
     """
     self.model.train()
     timer = CountdownTimer(len(trn) // gradient_accumulation)
     total_loss = 0
     self.reset_metrics(metric)
     for idx, batch in enumerate(trn):
         output_dict = self.feed_batch(batch)
         self.update_metrics(batch, output_dict, metric)
         loss = output_dict['loss']
         loss = loss.sum()  # For data parallel
         # Bug fix: scale BEFORE backward(). The original divided after
         # backward(), so the gradients were never scaled and only the
         # logged total_loss was affected.
         if gradient_accumulation and gradient_accumulation > 1:
             loss /= gradient_accumulation
         loss.backward()
         # NOTE(review): clipping runs every micro-batch, i.e. on partially
         # accumulated gradients — confirm this is intended.
         if self.config.grad_norm:
             torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                            self.config.grad_norm)
         if (idx + 1) % gradient_accumulation == 0:
             self._step(optimizer, linear_scheduler)
             timer.log(self.report_metrics(total_loss / (timer.current + 1),
                                           metric),
                       ratio_percentage=None,
                       logger=logger)
         total_loss += loss.item()
         del loss
     # Flush a trailing partial accumulation window.
     if len(trn) % gradient_accumulation:
         self._step(optimizer, linear_scheduler)
     return total_loss / timer.total
Esempio n. 12
0
    def evaluate_dataloader(self, loader: PadSequenceDataLoader, criterion, logger=None, filename=None, output=False,
                            ratio_width=None,
                            metric=None,
                            **kwargs):
        """Evaluate a second-order dependency parser on ``loader``.

        Args:
            loader: Evaluation batches.
            criterion: Unused here; loss comes from ``self.compute_loss``.
            logger: Logger for progress output.
            filename: When given, prefixed (basename only) to each report line.
            output: Unused in this implementation.
            ratio_width: Width of the progress-ratio column.
            metric: Metric to update; built fresh when falsy.
                NOTE(review): unlike sibling evaluators, a supplied metric is
                NOT reset here — confirm callers pass a fresh one.

        Returns:
            Tuple of (average loss per batch, metric).
        """
        self.model.eval()

        total_loss = 0
        if not metric:
            metric = self.build_metric()

        timer = CountdownTimer(len(loader))
        for batch in loader:
            # Scores for arcs, sibling arcs and relations; puncts marks
            # punctuation to exclude from the metric.
            (s_arc, s_sib, s_rel), mask, puncts = self.feed_batch(batch)
            arcs, sibs, rels = batch['arc'], batch['sib_id'], batch['rel_id']
            loss, s_arc = self.compute_loss(s_arc, s_sib, s_rel, arcs, sibs, rels, mask)
            total_loss += float(loss)
            arc_preds, rel_preds = self.decode(s_arc, s_sib, s_rel, mask)
            self.update_metric(arc_preds, rel_preds, arcs, rels, mask, puncts, metric)
            report = self._report(total_loss / (timer.current + 1), metric)
            if filename:
                report = f'{os.path.basename(filename)} ' + report
            timer.log(report, ratio_percentage=False, logger=logger, ratio_width=ratio_width)
        total_loss /= len(loader)

        return total_loss, metric
Esempio n. 13
0
    def evaluate_dataloader(self,
                            data: MultiTaskDataLoader,
                            criterion,
                            metric: MetricDict,
                            logger,
                            ratio_width=None,
                            input: str = None,
                            **kwargs):
        """Evaluate a multi-task model, delegating some tasks to themselves.

        Tasks listed in ``config.tasks_need_custom_eval`` are temporarily
        removed from the joint dataloader and evaluated by their own
        ``evaluate_dataloader`` using this model's encoder; all remaining
        tasks are evaluated jointly below.

        Args:
            data: Multi-task dataloader yielding (task_name, batch) pairs.
            criterion: Per-task losses, indexed by task name.
            metric: Per-task metrics, indexed by task name.
            logger: Logger for progress output.
            ratio_width: Width of the progress-ratio column.
            input: Split name; 'dev' selects ``task.dev``, anything else
                selects ``task.tst`` for the custom-eval tasks.

        Returns:
            Tuple of (average loss, metric, data).
        """
        self.model.eval()
        self.reset_metrics(metric)
        tasks_need_custom_eval = self.config.get('tasks_need_custom_eval',
                                                 None)
        tasks_need_custom_eval = tasks_need_custom_eval or {}
        tasks_need_custom_eval = dict(
            (k, None) for k in tasks_need_custom_eval)
        # Pop the custom-eval dataloaders so the joint loop below skips them;
        # they are restored after their own evaluation.
        for each in tasks_need_custom_eval:
            tasks_need_custom_eval[each] = data.dataloaders.pop(each)
        timer = CountdownTimer(len(data) + len(tasks_need_custom_eval))
        total_loss = 0
        for idx, (task_name, batch) in enumerate(data):
            output_dict, _ = self.feed_batch(batch, task_name)
            loss = self.compute_loss(batch, output_dict[task_name]['output'],
                                     criterion[task_name],
                                     self.tasks[task_name])
            total_loss += loss.item()
            self.decode_output(output_dict, batch, task_name)
            self.update_metrics(batch, output_dict, metric, task_name)
            timer.log(self.report_metrics(total_loss / (timer.current + 1),
                                          metric),
                      ratio_percentage=None,
                      logger=logger,
                      ratio_width=ratio_width)
            # Free the autograd graph and per-batch outputs eagerly.
            del loss
            del output_dict

        for task_name, dataset in tasks_need_custom_eval.items():
            task = self.tasks[task_name]
            decoder = self.model_.decoders[task_name]
            # The task evaluates itself, but encoding is done by this model
            # (see the ``h`` partial), so representations stay shared.
            task.evaluate_dataloader(
                dataset,
                task.build_criterion(decoder=decoder),
                metric=metric[task_name],
                input=task.dev if input == 'dev' else task.tst,
                split=input,
                decoder=decoder,
                h=functools.partial(self._encode,
                                    task_name=task_name,
                                    cls_is_bos=task.cls_is_bos,
                                    sep_is_eos=task.sep_is_eos))
            # Restore the dataloader popped above.
            data.dataloaders[task_name] = dataset
            timer.log(self.report_metrics(total_loss / (timer.current + 1),
                                          metric),
                      ratio_percentage=None,
                      logger=logger,
                      ratio_width=ratio_width)

        return total_loss / timer.total, metric, data
Esempio n. 14
0
 def execute_training_loop(self,
                           trn: PrefetchDataLoader,
                           dev: PrefetchDataLoader,
                           epochs,
                           criterion,
                           optimizer,
                           metric,
                           save_dir,
                           logger: logging.Logger,
                           devices,
                           ratio_width=None,
                           dev_data=None,
                           gradient_accumulation=1,
                           **kwargs):
     """Epoch loop with periodic evaluation and best-model restoration.

     Args:
         trn: Prefetching training loader; rebound each epoch because
             ``fit_dataloader`` returns (possibly a fresh) loader.
         dev: Prefetching dev loader.
         epochs: Number of epochs.
         criterion: Loss forwarded to ``fit_dataloader``.
         optimizer: Optimizer forwarded to ``fit_dataloader``.
         metric: Rebound to the dev metric after each evaluation.
         save_dir: Directory where the best weights are saved.
         logger: Logger for progress output.
         devices: Unused in this implementation.
         ratio_width: Width of the progress-ratio column.
         dev_data: Raw dev data forwarded to ``evaluate_dataloader``.
         gradient_accumulation: Micro-batches per optimizer step.

     Returns:
         None; the best weights end up loaded and saved in ``save_dir``.
     """
     best_epoch, best_metric = 0, -1
     timer = CountdownTimer(epochs)
     history = History()
     try:
         for epoch in range(1, epochs + 1):
             logger.info(f"[yellow]Epoch {epoch} / {epochs}:[/yellow]")
             trn = self.fit_dataloader(
                 trn,
                 criterion,
                 optimizer,
                 metric,
                 logger,
                 ratio_width=ratio_width,
                 gradient_accumulation=gradient_accumulation,
                 history=history,
                 save_dir=save_dir)
             report = f'{timer.elapsed_human}/{timer.total_time_human}'
             # Evaluation only every ``eval_every`` epochs (and at the end),
             # since dev evaluation is expensive for this model.
             if epoch % self.config.eval_every == 0 or epoch == epochs:
                 metric = self.evaluate_dataloader(dev,
                                                   logger,
                                                   dev_data,
                                                   ratio_width=ratio_width,
                                                   save_dir=save_dir,
                                                   use_fast=True)
                 if metric > best_metric:
                     self.save_weights(save_dir)
                     best_metric = metric
                     best_epoch = epoch
                     report += ' [red]saved[/red]'
             timer.log(report,
                       ratio_percentage=False,
                       newline=True,
                       ratio=False)
         # Reload the best checkpoint unless the last epoch was the best.
         if best_epoch and best_epoch != epochs:
             logger.info(
                 f'Restored the best model with {best_metric} saved {epochs - best_epoch} epochs ago'
             )
             self.load_weights(save_dir)
     finally:
         # Prefetch loaders own worker threads/processes; always shut down.
         trn.close()
         dev.close()
Esempio n. 15
0
 def fit_dataloader(self,
                    trn: DataLoader,
                    criterion,
                    optimizer,
                    metric,
                    logger: logging.Logger,
                    history: History,
                    gradient_accumulation=1,
                    grad_norm=None,
                    transformer_grad_norm=None,
                    teacher: Tagger = None,
                    kd_criterion=None,
                    temperature_scheduler=None,
                    ratio_width=None,
                    **kwargs):
     """Train a tagger for one epoch, optionally with knowledge distillation.

     Args:
         trn: Training batches.
         criterion: Loss function for the hard (gold-label) loss.
         optimizer: A ``(optimizer, scheduler)`` pair; with a teacher the
             scheduler itself is a ``(scheduler, lambda_scheduler)`` pair.
         metric: Metric updated per batch.
         logger: Logger for progress output.
         history: Tracks global steps; decides when an optimizer step occurs.
         gradient_accumulation: Micro-batches per optimizer step.
         grad_norm: Gradient-clipping norm forwarded to ``self._step``.
         transformer_grad_norm: Separate clipping norm for the transformer,
             forwarded to ``self._step``.
         teacher: Teacher tagger whose outputs are distilled into this model.
         kd_criterion: Distillation loss between student and teacher outputs.
         temperature_scheduler: Softmax-temperature schedule for distillation.
         ratio_width: Width of the progress-ratio column.

     Returns:
         None. NOTE(review): sibling implementations return the average
         loss — confirm callers do not rely on a return value here.
     """
     optimizer, scheduler = optimizer
     if teacher:
         scheduler, lambda_scheduler = scheduler
     else:
         lambda_scheduler = None
     self.model.train()
     timer = CountdownTimer(
         history.num_training_steps(
             len(trn), gradient_accumulation=gradient_accumulation))
     total_loss = 0
     for idx, batch in enumerate(trn):
         out, mask = self.feed_batch(batch)
         y = batch['tag_id']
         loss = self.compute_loss(criterion, out, y, mask)
         # Scale before backward() so accumulated gradients average.
         if gradient_accumulation and gradient_accumulation > 1:
             loss /= gradient_accumulation
         if teacher:
             # Teacher forward pass needs no gradients.
             with torch.no_grad():
                 out_T, _ = teacher.feed_batch(batch)
             # noinspection PyNoneFunctionAssignment
             kd_loss = self.compute_distill_loss(kd_criterion, out, out_T,
                                                 mask,
                                                 temperature_scheduler)
             # _lambda interpolates between the hard loss and the
             # distillation loss.
             _lambda = float(lambda_scheduler)
             loss = _lambda * loss + (1 - _lambda) * kd_loss
         loss.backward()
         total_loss += loss.item()
         prediction = self.decode_output(out, mask, batch)
         self.update_metrics(metric, out, y, mask, batch, prediction)
         if history.step(gradient_accumulation):
             self._step(optimizer, scheduler, grad_norm,
                        transformer_grad_norm, lambda_scheduler)
             report = f'loss: {total_loss / (idx + 1):.4f} {metric}'
             timer.log(report,
                       logger=logger,
                       ratio_percentage=False,
                       ratio_width=ratio_width)
         # Free the autograd graph and batch tensors eagerly.
         del loss
         del out
         del mask
Esempio n. 16
0
 def build_vocabs(self, dataset, logger, vocabs, lock=True, label_vocab_name='label', **kwargs):
     """Build the NER label vocabulary by iterating ``dataset`` once.

     Args:
         dataset: Training samples; iterating them is what populates the
             vocabulary (via transforms attached to the dataset).
         logger: Logger passed to ``vocabs.summary``.
         vocabs: Vocabulary container to register the label vocab in.
         lock: Whether to lock and summarize the vocabs afterwards.
         label_vocab_name: Key under which the label vocab is stored.
     """
     label_vocab = Vocab(pad_token=None, unk_token=None)
     vocabs[label_vocab_name] = label_vocab
     # Use null to indicate no relationship
     label_vocab.add('<null>')
     progress = CountdownTimer(len(dataset))
     for _ in dataset:
         progress.log('Building NER vocab [blink][yellow]...[/yellow][/blink]')
     label_vocab.set_unk_as_safe_unk()
     if lock:
         vocabs.lock()
         vocabs.summary(logger)
Esempio n. 17
0
 def build_vocabs(self, trn, logger, **kwargs):
     """Build the chart vocabulary from the training set.

     Iterating ``trn`` populates the vocab via the dataset's transforms;
     the longest input-id sequence seen is reported along the way.

     Args:
         trn: Training samples.
         logger: Logger passed to ``self.vocabs.summary``.
     """
     self.vocabs.chart = VocabWithNone(pad_token=None, unk_token=None)
     progress = CountdownTimer(len(trn))
     longest = 0
     for sample in trn:
         seq_len = len(sample['token_input_ids'])
         if seq_len > longest:
             longest = seq_len
         progress.log(
             f'Building vocab [blink][yellow]...[/yellow][/blink] (longest sequence: {longest})'
         )
     self.vocabs.chart.set_unk_as_safe_unk()
     self.vocabs.lock()
     self.vocabs.summary(logger)
Esempio n. 18
0
    def evaluate_dataloader(self,
                            loader: PadSequenceDataLoader,
                            criterion,
                            logger=None,
                            filename=None,
                            output=False,
                            ratio_width=None,
                            metric=None,
                            **kwargs):
        """Evaluate a dependency parser, optionally writing CoNLL-style output.

        Args:
            loader: Evaluation batches.
            criterion: Loss function forwarded to ``self.compute_loss``.
            logger: Logger for progress output.
            filename: When given, prefixed (basename only) to report lines.
            output: Path to write predictions to, or falsy to skip.
            ratio_width: Width of the progress-ratio column.
            metric: Metric to update; built fresh when falsy.
                NOTE(review): a supplied metric is not reset here — confirm
                callers pass a fresh one.

        Returns:
            Tuple of (average loss per batch, metric).
        """
        self.model.eval()

        loss = 0
        if not metric:
            metric = self.build_metric()
        # NOTE(review): ``fp`` is not closed if an exception escapes the loop
        # below — consider try/finally in a follow-up.
        if output:
            fp = open(output, 'w')
            predictions, build_data, data, order = self.before_outputs(None)

        timer = CountdownTimer(len(loader))
        use_pos = self.use_pos
        for batch in loader:
            arc_scores, rel_scores, mask, puncts = self.feed_batch(batch)
            # Collect raw scores; decoding to text happens in post_outputs.
            if output:
                self.collect_outputs(arc_scores, rel_scores, mask, batch,
                                     predictions, order, data, use_pos,
                                     build_data)
            arcs, rels = batch['arc'], batch['rel_id']
            loss += self.compute_loss(arc_scores, rel_scores, arcs, rels, mask,
                                      criterion, batch).item()
            arc_preds, rel_preds = self.decode(arc_scores, rel_scores, mask,
                                               batch)
            self.update_metric(arc_preds, rel_preds, arcs, rels, mask, puncts,
                               metric, batch)
            report = self._report(loss / (timer.current + 1), metric)
            if filename:
                report = f'{os.path.basename(filename)} ' + report
            timer.log(report,
                      ratio_percentage=False,
                      logger=logger,
                      ratio_width=ratio_width)
        loss /= len(loader)
        if output:
            outputs = self.post_outputs(predictions, data, order, use_pos,
                                        build_data)
            for each in outputs:
                fp.write(f'{each}\n\n')
            fp.close()
            logger.info(
                f'Predictions saved in [underline][yellow]{output}[/yellow][/underline]'
            )

        return loss, metric
Esempio n. 19
0
 def build_vocabs(self, trn, logger, **kwargs):
     """Build the tag vocabulary from the training set.

     Iterating ``trn`` populates the vocab via the dataset's transforms;
     the longest token sequence seen is reported along the way.

     Args:
         trn: Training samples.
         logger: Logger passed to ``self.vocabs.summary``.
     """
     self.vocabs.tag = Vocab(pad_token=None, unk_token=None)
     progress = CountdownTimer(len(trn))
     longest = 0
     key = self.config.token_key
     for sample in trn:
         seq_len = len(sample[key])
         if seq_len > longest:
             longest = seq_len
         progress.log(
             f'Building vocab [blink][yellow]...[/yellow][/blink] (longest sequence: {longest})'
         )
     self.vocabs.tag.set_unk_as_safe_unk()
     self.vocabs.lock()
     self.vocabs.summary(logger)
Esempio n. 20
0
    def build_vocabs(self, dataset, logger=None, transformer=None):
        """Build relation, optional POS, and optional token vocabularies.

        Args:
            dataset: Training samples; iterated to populate vocabularies
                (via dataset transforms) or to count token frequencies.
            logger: Logger passed to ``self.vocabs.summary``.
            transformer: When truthy, no token vocabulary is built (token
                ids come from the transformer's tokenizer).
        """
        rel_vocab = self.vocabs.get('rel', None)
        if rel_vocab is None:
            rel_vocab = Vocab(unk_token=None,
                              pad_token=self.config.get('pad_rel', None))
            self.vocabs.put(rel=rel_vocab)
        # A POS vocab is needed either as an input feature or as 'use_pos'.
        wants_pos = self.config.get('feat', None) == 'pos' \
                    or self.config.get('use_pos', False)
        if wants_pos:
            self.vocabs['pos'] = Vocab(unk_token=None, pad_token=None)

        progress = CountdownTimer(len(dataset))
        if transformer:
            token_vocab = None
        else:
            token_vocab = Vocab()
            self.vocabs.token = token_vocab
            unk = self.config.get('unk', None)
            if unk is not None:
                token_vocab.unk_token = unk
        if token_vocab and self.config.get('min_freq', None):
            # Frequency-filtered vocab: count every surface form, then keep
            # reserved tokens plus the frequent ones, in first-seen order.
            counts = Counter()
            for sample in dataset:
                counts.update(sample['token'])
            kept = [token_vocab.pad_token, token_vocab.unk_token]
            if ROOT in token_vocab:
                kept.append(ROOT)
            kept += [
                form for form, freq in counts.items()
                if freq >= self.config.min_freq
            ]
            token_vocab.token_to_idx.clear()
            for form in kept:
                token_vocab(form)
        else:
            # No frequency filter: a single pass over the dataset populates
            # the vocabularies via its transforms.
            for sample in dataset:
                progress.log('vocab building [blink][yellow]...[/yellow][/blink]',
                             ratio_percentage=True)
        rel_vocab.set_unk_as_safe_unk()  # Some relation in dev set is OOV
        self.vocabs.lock()
        self.vocabs.summary(logger=logger)
        if token_vocab:
            self.config.n_words = len(self.vocabs['token'])
        if 'pos' in self.vocabs:
            self.config.n_feats = len(self.vocabs['pos'])
            self.vocabs['pos'].set_unk_as_safe_unk()
        self.config.n_rels = len(self.vocabs['rel'])
        if token_vocab:
            self.config.pad_index = self.vocabs['token'].pad_idx
            self.config.unk_index = self.vocabs['token'].unk_idx
Esempio n. 21
0
 def evaluate_dataloader(self,
                         data: DataLoader,
                         criterion: Callable,
                         metric,
                         logger,
                         ratio_width=None,
                         filename=None,
                         output=None,
                         **kwargs):
     """Evaluate a text classifier, optionally writing a TSV of predictions.

     Args:
         data: Evaluation batches.
         criterion: Loss function forwarded to ``self.compute_loss``.
         metric: Accuracy-style metric; reset here before evaluation.
         logger: Logger for progress output.
         ratio_width: Width of the progress-ratio column.
         filename: When given, a throughput figure is appended to reports.
         output: Path to write "text_a [text_b] pred gold" rows, or falsy.

     Returns:
         Average loss per batch.
         NOTE(review): unlike sibling evaluators, the metric is NOT
         returned — confirm callers read it from the argument instead.
     """
     self.model.eval()
     timer = CountdownTimer(len(data))
     total_loss = 0
     metric.reset()
     num_samples = 0
     if output:
         output = open(output, 'w')
     for batch in data:
         logits = self.feed_batch(batch)
         target = batch['label_id']
         loss = self.compute_loss(criterion, logits, target, batch)
         total_loss += loss.item()
         label_ids = self.update_metric(metric, logits, target, output)
         if output:
             # Map predicted ids back to label strings for the TSV dump.
             labels = [
                 self.vocabs[self.config.label_key].idx_to_token[i]
                 for i in label_ids.tolist()
             ]
             for i, label in enumerate(labels):
                 # text_a text_b pred gold
                 columns = [batch[self.config.text_a_key][i]]
                 if self.config.text_b_key:
                     columns.append(batch[self.config.text_b_key][i])
                 columns.append(label)
                 columns.append(batch[self.config.label_key][i])
                 output.write('\t'.join(columns))
                 output.write('\n')
         num_samples += len(target)
         report = f'loss: {total_loss / (timer.current + 1):.4f} acc: {metric.get_metric():.2%}'
         if filename:
             report = f'(unknown) {report} {num_samples / timer.elapsed:.0f} samples/sec'
         timer.log(report,
                   ratio_percentage=None,
                   logger=logger,
                   ratio_width=ratio_width)
     if output:
         output.close()
     return total_loss / timer.total
Esempio n. 22
0
 def build_vocabs(self, dataset, logger, **kwargs):
     """Build the SRL label vocabulary by iterating ``dataset`` once.

     Iteration populates the vocab via the dataset's transforms; the
     longest input-id sequence seen is reported along the way.

     Args:
         dataset: Training samples.
         logger: Logger passed to ``self.vocabs.summary``.
     """
     self.vocabs.srl_label = Vocab(pad_token=None, unk_token=None)
     # Use null to indicate no relationship
     self.vocabs.srl_label.add('<null>')
     progress = CountdownTimer(len(dataset))
     longest = 0
     for sample in dataset:
         seq_len = len(sample['token_input_ids'])
         if seq_len > longest:
             longest = seq_len
         progress.log(
             f'Building vocabs (max sequence length {longest}) [blink][yellow]...[/yellow][/blink]'
         )
     progress.stop()
     progress.erase()
     self.vocabs['srl_label'].set_unk_as_safe_unk()
     self.vocabs.lock()
     self.vocabs.summary(logger)
Esempio n. 23
0
    def evaluate_dataloader(self,
                            data: DataLoader,
                            criterion: Callable,
                            metric=None,
                            output=False,
                            ratio_width=None,
                            logger=None,
                            input=None,
                            use_fast=False,
                            **kwargs):
        """Evaluate an AMR parser by predicting graphs and scoring Smatch.

        Args:
            data: Evaluation batches.
            criterion: Unused in this implementation.
            metric: Unused in this implementation.
            output: Path the predicted graphs are written to.
            ratio_width: Unused in this implementation.
            logger: Logger for progress output.
            input: Path to the gold file compared against ``output``.
            use_fast: Use the fast Smatch computation instead of full eval.

        Returns:
            The Smatch result, or 0 if scoring failed.
        """
        self.model.eval()
        timer = CountdownTimer(len(data))
        graphs = []
        orders = []
        smatch = 0
        for idx, batch in enumerate(data):
            graphs_per_batch = self.predict_amrs(batch)
            # Keep only the top-scoring graph per sample.
            graphs_per_batch = [x[0] for x in graphs_per_batch]
            # Copy meta data from gold graph
            for gp, gg in zip(graphs_per_batch, batch['amr']):
                metadata = gg.metadata.copy()
                metadata['annotator'] = f'{self.config.transformer}-amr'
                metadata['date'] = str(datetime.datetime.now())
                if 'save-date' in metadata:
                    del metadata['save-date']
                gp.metadata = metadata
            graphs.extend(graphs_per_batch)
            orders.extend(batch[IDX])
            # On the last batch, restore corpus order, dump predictions and
            # score them against the gold file.
            if idx == timer.total - 1:
                graphs = reorder(graphs, orders)
                write_predictions(output, self._tokenizer, graphs)
                try:
                    if use_fast:
                        smatch = compute_smatch(output, input)
                    else:
                        smatch = smatch_eval(output, input, use_fast=False)
                except Exception:
                    # Bug fix: was a bare `except:`, which also swallowed
                    # KeyboardInterrupt/SystemExit. Scoring stays best-effort:
                    # on failure, smatch remains 0.
                    pass
                timer.log(smatch.cstr() if isinstance(smatch, MetricDict) else
                          f'{smatch:.2%}',
                          ratio_percentage=False,
                          logger=logger)
            else:
                timer.log(ratio_percentage=False, logger=logger)

        return smatch
Esempio n. 24
0
def parse_data(model,
               pp: PostProcessor,
               data,
               input_file,
               output_file,
               beam_size=8,
               alpha=0.6,
               max_time_step=100,
               h=None):
    """Parse ``data`` into AMR concept/relation outputs, post-process them to
    ``output_file``, then score against ``input_file``.

    Args:
        model: Parser whose ``decoder.levi_graph`` flag (if present) controls connectivity checks.
        pp: Post-processor that serializes concepts/relations into graph notation.
        data: Dataloader of batches carrying ``IDX`` order keys.
        input_file: Gold file passed to ``match`` for scoring.
        output_file: Destination path; a temporary file is created when falsy.
        beam_size: Beam width for decoding.
        alpha: Length-normalization exponent for beam scoring.
        max_time_step: Decoding step cap.
        h: Optional extra state forwarded to ``parse_batch``.

    Returns:
        Whatever ``match(output_file, input_file)`` returns (previously discarded).
    """
    import os
    if not output_file:
        # mkstemp instead of NamedTemporaryFile().name: the latter deletes the
        # file as soon as the handle is garbage-collected, so reopening by name
        # is racy (and fails outright on Windows while the handle is open).
        fd, output_file = tempfile.mkstemp()
        os.close(fd)
    tot = 0
    levi_graph = model.decoder.levi_graph if hasattr(model,
                                                     'decoder') else False
    with open(output_file, 'w') as fo:
        timer = CountdownTimer(len(data))
        order = []
        outputs = []
        for batch in data:
            order.extend(batch[IDX])
            res = parse_batch(model,
                              batch,
                              beam_size,
                              alpha,
                              max_time_step,
                              h=h)
            outputs.extend(
                list(zip(res['concept'], res['relation'], res['score'])))
            timer.log('Parsing [blink][yellow]...[/yellow][/blink]',
                      ratio_percentage=False)
        # Restore corpus order before writing so predictions align with gold.
        outputs = reorder(outputs, order)
        timer = CountdownTimer(len(data))
        for concept, relation, score in outputs:
            fo.write('# ::conc ' + ' '.join(concept) + '\n')
            fo.write('# ::score %.6f\n' % score)
            fo.write(
                pp.postprocess(concept, relation, check_connected=levi_graph) +
                '\n\n')
            tot += 1
            timer.log('Post-processing [blink][yellow]...[/yellow][/blink]',
                      ratio_percentage=False)
    return match(output_file, input_file)
Esempio n. 25
0
    def execute_training_loop(self, trn: DataLoader, dev: DataLoader, epochs, criterion,
                              optimizer,
                              metric,
                              save_dir,
                              logger,
                              patience,
                              **kwargs):
        """Epoch-level training driver: train, evaluate on dev, save the best
        checkpoint, and early-stop after ``patience`` epochs without improvement.

        Args:
            trn: Training dataloader.
            dev: Development dataloader used for model selection.
            epochs: Maximum number of epochs.
            criterion: Ignored — rebuilt below via ``self.build_criterion()``;
                kept in the signature for interface compatibility.
            optimizer: Optimizer, also fed to the scheduler builder.
            metric: Unused here; ``evaluate_dataloader`` builds its own.
            save_dir: Directory the best weights are saved to / reloaded from.
            logger: Logger for per-epoch reports.
            patience: Epochs to wait without dev improvement before stopping;
                falsy means "never stop early" (set to ``epochs``).
        """
        # max_e / max_metric track the best epoch seen so far; -1 guarantees
        # the first dev result is saved.
        max_e, max_metric = 0, -1

        criterion = self.build_criterion()
        timer = CountdownTimer(epochs)
        ratio_width = len(f'{len(trn)}/{len(trn)}')
        scheduler = self.build_scheduler(**merge_dict(self.config, optimizer=optimizer, overwrite=True))
        if not patience:
            patience = epochs
        for epoch in range(1, epochs + 1):
            logger.info(f"[yellow]Epoch {epoch} / {epochs}:[/yellow]")
            self.fit_dataloader(trn, criterion, optimizer, metric, logger, ratio_width=ratio_width)
            loss, dev_metric = self.evaluate_dataloader(dev, criterion, logger)
            if scheduler:
                # ReduceLROnPlateau steps on the dev score; other schedulers
                # step on the epoch index.
                if isinstance(scheduler, ReduceLROnPlateau):
                    scheduler.step(dev_metric.score)
                else:
                    scheduler.step(epoch)
            report_patience = f'Patience: {epoch - max_e}/{patience}'
            # save the model if it is the best so far
            if dev_metric > max_metric:
                self.save_weights(save_dir)
                max_e, max_metric = epoch, dev_metric
                report_patience = '[red]Saved[/red] '
            stop = epoch - max_e >= patience
            if stop:
                timer.stop()
            timer.log(f'{report_patience} lr: {optimizer.param_groups[0]["lr"]:.4f}',
                      ratio_percentage=False, newline=True, ratio=False)
            if stop:
                break
        # NOTE: timer.stop() may run twice when early-stopping; presumably
        # CountdownTimer tolerates repeated stops — confirm.
        timer.stop()
        if max_e != epoch:
            # Training ended on a worse epoch; restore the best checkpoint.
            self.load_weights(save_dir)
        logger.info(f"Max score of dev is {max_metric.score:.2%} at epoch {max_e}")
        logger.info(f"{timer.elapsed_human} elapsed, average time of each epoch is {timer.elapsed_average_human}")
Esempio n. 26
0
 def fit_dataloader(self,
                    trn: DataLoader,
                    criterion,
                    optimizer,
                    metric: SpanMetric,
                    logger: logging.Logger,
                    history: History,
                    gradient_accumulation=1,
                    grad_norm=None,
                    ratio_width=None,
                    eval_trn=True,
                    **kwargs):
     """Train the span (chart) parser for one epoch with gradient accumulation.

     ``optimizer`` is an (optimizer, scheduler) pair; parameters update only
     when ``history.step`` signals an accumulation boundary. When ``eval_trn``
     is on, training predictions are decoded and folded into ``metric``.
     """
     optimizer, scheduler = optimizer
     metric.reset()
     self.model.train()
     num_steps = history.num_training_steps(
         len(trn), gradient_accumulation=gradient_accumulation)
     countdown = CountdownTimer(num_steps)
     running_loss = 0
     for batch_idx, batch in enumerate(trn):
         scores, mask = self.feed_batch(batch)
         gold = batch['chart_id']
         loss, span_probs = self.compute_loss(scores, gold, mask)
         # Scale so accumulated gradients average over the virtual batch.
         if gradient_accumulation and gradient_accumulation > 1:
             loss /= gradient_accumulation
         loss.backward()
         running_loss += loss.item()
         if eval_trn:
             self.update_metrics(
                 metric, batch,
                 self.decode_output(scores, mask, batch, span_probs))
         if history.step(gradient_accumulation):
             self._step(optimizer, scheduler, grad_norm)
             avg = running_loss / (batch_idx + 1)
             if eval_trn:
                 report = f'loss: {avg:.4f} {metric}'
             else:
                 report = f'loss: {avg:.4f}'
             countdown.log(report,
                           logger=logger,
                           ratio_percentage=False,
                           ratio_width=ratio_width)
         # Free graph tensors eagerly to cap peak memory.
         del loss
         del scores
         del mask
Esempio n. 27
0
 def fit_dataloader(self, trn: DataLoader, criterion, optimizer, metric, logger: logging.Logger, ratio_width=None,
                    **kwargs):
     """Train the tagger for one epoch.

     Each batch: forward, loss, backward, clip gradients to 5.0, step the
     optimizer, then decode predictions into ``metric`` and log the running
     average loss.

     Args:
         trn: Training dataloader with ``'tag_id'`` gold labels.
         criterion: Loss function passed to ``compute_loss``.
         optimizer: A plain optimizer (no scheduler pair here).
         metric: Metric updated with per-batch predictions.
         logger: Progress logger.
         ratio_width: Fixed width for the progress ratio display.
     """
     self.model.train()
     timer = CountdownTimer(len(trn))
     total_loss = 0
     for idx, batch in enumerate(trn):
         optimizer.zero_grad()
         out, mask = self.feed_batch(batch)
         y = batch['tag_id']
         loss = self.compute_loss(criterion, out, y, mask)
         loss.backward()
         nn.utils.clip_grad_norm_(self.model.parameters(), 5.0)
         optimizer.step()
         total_loss += loss.item()
         prediction = self.decode_output(out, mask, batch)
         self.update_metrics(metric, out, y, mask, batch, prediction)
         # BUG FIX: the report previously divided the *current batch* loss
         # tensor by (idx + 1) — neither the batch loss nor a running average.
         # Use total_loss for the running mean, matching the other loops.
         timer.log(f'loss: {total_loss / (idx + 1):.4f} {metric}', ratio_percentage=False, logger=logger,
                   ratio_width=ratio_width)
         del loss
         del out
         del mask
Esempio n. 28
0
    def fit_dataloader(self,
                       trn,
                       optimizer,
                       scheduler,
                       criterion,
                       epoch,
                       logger,
                       history: History,
                       transformer_optimizer=None,
                       transformer_scheduler=None,
                       gradient_accumulation=1,
                       eval_trn=False,
                       **kwargs):
        """Train the (sibling-aware) dependency parser for one epoch.

        Args:
            trn: Training dataloader with ``'arc'``, ``'sib_id'``, ``'rel_id'`` gold fields.
            optimizer: Main optimizer.
            scheduler: LR scheduler; its last LR is reported in the progress log.
            criterion: Unused here; loss comes from ``self.compute_loss``.
            epoch: Unused here; kept for interface compatibility.
            logger: Progress logger.
            history: Step counter deciding accumulation boundaries.
            transformer_optimizer: Optional separate optimizer for the encoder.
            transformer_scheduler: Optional scheduler paired with it.
            gradient_accumulation: Number of batches accumulated per update.
            eval_trn: Decode and score training batches when True.
        """
        self.model.train()

        timer = CountdownTimer(history.num_training_steps(len(trn), gradient_accumulation))
        metric = self.build_metric(training=True)
        total_loss = 0
        for idx, batch in enumerate(trn):
            (s_arc, s_sib, s_rel), mask, puncts = self.feed_batch(batch)
            arcs, sibs, rels = batch['arc'], batch['sib_id'], batch['rel_id']

            loss, s_arc = self.compute_loss(s_arc, s_sib, s_rel, arcs, sibs, rels, mask)
            if gradient_accumulation > 1:
                loss /= gradient_accumulation
            loss.backward()
            total_loss += loss.item()
            if eval_trn:
                arc_preds, rel_preds = self.decode(s_arc, s_sib, s_rel, mask)
                self.update_metric(arc_preds, rel_preds, arcs, rels, mask, puncts, metric)
            if history.step(gradient_accumulation):
                self._step(optimizer, scheduler, transformer_optimizer, transformer_scheduler)
                # BUG FIX: zero_grad used to run at the top of EVERY iteration,
                # wiping gradients between accumulation steps and thus defeating
                # gradient_accumulation > 1. Zero only after a parameter update
                # so gradients accumulate across the virtual batch.
                optimizer.zero_grad()
                report = self._report(total_loss / (timer.current + 1), metric if eval_trn else None)
                lr = scheduler.get_last_lr()[0]
                report += f' lr: {lr:.4e}'
                timer.log(report, ratio_percentage=False, logger=logger)
            del loss
Esempio n. 29
0
 def fit_dataloader(self, trn: DataLoader, criterion, optimizer, metric,
                    logger: logging.Logger, **kwargs):
     """Train the classifier for one epoch and return the mean batch loss.

     ``optimizer`` is an (optimizer, scheduler) pair; the scheduler steps
     after every batch. Accuracy is accumulated in ``metric`` and reported
     alongside the running average loss.
     """
     self.model.train()
     progress = CountdownTimer(len(trn))
     optimizer, scheduler = optimizer
     loss_sum = 0
     metric.reset()
     for batch in trn:
         optimizer.zero_grad()
         logits = self.feed_batch(batch)
         target = batch['label_id']
         batch_loss = self.compute_loss(criterion, logits, target, batch)
         batch_loss.backward()
         optimizer.step()
         scheduler.step()
         loss_sum += batch_loss.item()
         self.update_metric(metric, logits, target)
         mean_loss = loss_sum / (progress.current + 1)
         progress.log(
             f'loss: {mean_loss:.4f} acc: {metric.get_metric():.2%}',
             ratio_percentage=None,
             logger=logger)
         del batch_loss
     return loss_sum / progress.total
Esempio n. 30
0
 def fit_dataloader(self, trn: DataLoader, criterion, optimizer, metric,
                    logger: logging.Logger, **kwargs):
     """Run one training epoch and return the mean batch loss.

     Per batch: forward, loss, metric update, backward, optional gradient
     clipping (when ``self.config.grad_norm`` is set), optimizer step, and a
     progress report built by ``self.report_metrics``.
     """
     self.model.train()
     progress = CountdownTimer(len(trn))
     loss_sum = 0
     self.reset_metrics(metric)
     for batch in trn:
         optimizer.zero_grad()
         prediction = self.feed_batch(batch)
         batch_loss = self.compute_loss(prediction, batch, criterion)
         self.update_metrics(batch, prediction, metric)
         batch_loss.backward()
         # Clip only when a norm threshold is configured.
         if self.config.grad_norm:
             torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                            self.config.grad_norm)
         optimizer.step()
         loss_sum += batch_loss.item()
         progress.log(self.report_metrics(loss_sum / (progress.current + 1),
                                          metric),
                      ratio_percentage=None,
                      logger=logger)
         del batch_loss
     return loss_sum / progress.total