Example #1
 def fit_dataloader(self,
                    trn: DataLoader,
                    criterion,
                    optimizer,
                    metric,
                    logger: logging.Logger,
                    history: History,
                    linear_scheduler=None,
                    gradient_accumulation=1,
                    **kwargs):
     self.model.train()
     timer = CountdownTimer(history.num_training_steps(len(trn), gradient_accumulation))
     total_loss = 0
     self.reset_metrics(metric)
     for idx, batch in enumerate(trn):
         output_dict = self.feed_batch(batch)
         self.update_metrics(batch, output_dict, metric)
         loss = output_dict['loss']
         if gradient_accumulation > 1:
             loss /= gradient_accumulation
         loss.backward()
         total_loss += loss.item()
         if history.step(gradient_accumulation):
             self._step(optimizer, linear_scheduler)
             timer.log(self.report_metrics(total_loss / (timer.current + 1), metric), ratio_percentage=None,
                       logger=logger)
         del loss
     return total_loss / timer.total
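
Most of the fit_dataloader variants in these examples share the same skeleton: a CountdownTimer sized to the number of optimizer steps, a running average of the loss reported through timer.log, and an optimizer step only once every gradient_accumulation batches. The sketch below reproduces that skeleton in plain PyTorch with toy placeholders for the model, data and loss; it uses no HanLP classes, and names such as num_steps are illustrative only.

    import torch
    from torch import nn
    from torch.utils.data import DataLoader, TensorDataset

    # Toy stand-ins for HanLP's model, criterion and training dataloader.
    model = nn.Linear(8, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    criterion = nn.CrossEntropyLoss()
    trn = DataLoader(TensorDataset(torch.randn(64, 8), torch.randint(0, 2, (64,))), batch_size=8)

    gradient_accumulation = 2
    num_steps = len(trn) // gradient_accumulation  # what CountdownTimer is sized to above
    total_loss = 0.0
    model.train()
    optimizer.zero_grad()
    for idx, (x, y) in enumerate(trn):
        loss = criterion(model(x), y)
        if gradient_accumulation > 1:
            loss = loss / gradient_accumulation  # scale so accumulated gradients average out
        loss.backward()
        total_loss += loss.item()
        if (idx + 1) % gradient_accumulation == 0:  # the role history.step() plays above
            optimizer.step()
            optimizer.zero_grad()
            step = (idx + 1) // gradient_accumulation
            print(f'step {step}/{num_steps} loss: {total_loss / (idx + 1):.4f}')  # running average loss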
Example #2
 def execute_training_loop(self,
                           trn: DataLoader,
                           dev: DataLoader,
                           epochs,
                           criterion,
                           optimizer,
                           metric,
                           save_dir,
                           logger: logging.Logger,
                           devices,
                           patience=0.5,
                           **kwargs):
     if isinstance(patience, float):
         patience = int(patience * epochs)
     best_epoch, best_metric = 0, -1
     timer = CountdownTimer(epochs)
     ratio_width = len(f'{len(trn)}/{len(trn)}')
     epoch = 0
     history = History()
     for epoch in range(1, epochs + 1):
         logger.info(f"[yellow]Epoch {epoch} / {epochs}:[/yellow]")
         self.fit_dataloader(trn,
                             criterion,
                             optimizer,
                             metric,
                             logger,
                             history,
                             ratio_width=ratio_width,
                             **self.config)
         if dev:
             self.evaluate_dataloader(dev,
                                      criterion,
                                      metric,
                                      logger,
                                      ratio_width=ratio_width,
                                      input='dev')
         report = f'{timer.elapsed_human}/{timer.total_time_human}'
         dev_score = metric.score
         if dev_score > best_metric:
             self.save_weights(save_dir)
             best_metric = dev_score
             best_epoch = epoch
             report += ' [red]saved[/red]'
         else:
             report += f' ({epoch - best_epoch})'
             if epoch - best_epoch >= patience:
                 report += ' early stop'
                 break
         timer.log(report,
                   ratio_percentage=False,
                   newline=True,
                   ratio=False)
     for d in [trn, dev]:
         self._close_dataloader(d)
     if best_epoch != epoch:
         logger.info(
             f'Restoring best model saved [red]{epoch - best_epoch}[/red] epochs ago'
         )
         self.load_weights(save_dir)
     return best_metric
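
Stripped of the HanLP-specific plumbing (CountdownTimer, rich markup, dataloader shutdown), the control flow of execute_training_loop above is save-on-improvement plus patience-based early stopping. Below is a minimal sketch of just that logic, with the epoch body and checkpoint I/O passed in as callables; the names are illustrative, not HanLP API.

    def train_with_patience(epochs, patience, run_epoch, dev_score, save, load):
        """Run run_epoch() each epoch, keep the best checkpoint, stop after `patience` epochs without improvement."""
        if isinstance(patience, float):      # fractional patience means a fraction of the epoch budget
            patience = int(patience * epochs)
        best_epoch, best_score = 0, -1
        epoch = 0
        for epoch in range(1, epochs + 1):
            run_epoch()
            score = dev_score()
            if score > best_score:
                best_epoch, best_score = epoch, score
                save()                       # checkpoint only when the dev score improves
            elif epoch - best_epoch >= patience:
                break                        # early stop
        if best_epoch and best_epoch != epoch:
            load()                           # restore the best weights before returning
        return best_score

With patience=0.5 and epochs=30, for example, training stops after 15 epochs without improvement, matching the float-to-int conversion above.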
Example #3
    def build_vocabs(self, dataset, logger=None, transformer=False):
        rel_vocab = self.vocabs.get('rel', None)
        if rel_vocab is None:
            rel_vocab = Vocab(unk_token=None,
                              pad_token=self.config.get('pad_rel', None))
            self.vocabs.put(rel=rel_vocab)

        timer = CountdownTimer(len(dataset))
        if transformer:
            token_vocab = None
        else:
            self.vocabs.token = token_vocab = VocabCounter(
                unk_token=self.config.get('unk', UNK))
        for i, sample in enumerate(dataset):
            # Iterating the dataset runs its transforms, which populate the mutable vocabs.
            timer.log('Building vocab [blink][yellow]...[/yellow][/blink]',
                      ratio_percentage=True)
        min_freq = self.config.get('min_freq', None)
        if token_vocab and min_freq:
            token_vocab.trim(min_freq)
        rel_vocab.set_unk_as_safe_unk()  # Some relations in the dev set may be OOV
        self.vocabs.lock()
        self.vocabs.summary(logger=logger)
        if token_vocab:
            self.config.n_words = len(self.vocabs['token'])
        self.config.n_rels = len(self.vocabs['rel'])
        if token_vocab:
            self.config.pad_index = self.vocabs['token'].pad_idx
            self.config.unk_index = self.vocabs['token'].unk_idx
Example #4
    def evaluate_dataloader(self,
                            loader: PadSequenceDataLoader,
                            criterion,
                            logger=None,
                            filename=None,
                            output=False,
                            ratio_width=None,
                            metric=None,
                            **kwargs):
        self.model.eval()

        total_loss = 0
        if not metric:
            metric = self.build_metric()

        timer = CountdownTimer(len(loader))
        for batch in loader:
            (s_arc, s_sib, s_rel), mask, puncts = self.feed_batch(batch)
            arcs, sibs, rels = batch['arc'], batch['sib_id'], batch['rel_id']
            loss, s_arc = self.compute_loss(s_arc, s_sib, s_rel, arcs, sibs,
                                            rels, mask)
            total_loss += float(loss)
            arc_preds, rel_preds = self.decode(s_arc, s_sib, s_rel, mask)
            self.update_metric(arc_preds, rel_preds, arcs, rels, mask, puncts,
                               metric)
            report = self._report(total_loss / (timer.current + 1), metric)
            if filename:
                report = f'{os.path.basename(filename)} ' + report
            timer.log(report,
                      ratio_percentage=False,
                      logger=logger,
                      ratio_width=ratio_width)
        total_loss /= len(loader)

        return total_loss, metric
Example #5
 def load_file(self, filepath: str):
     filepath = get_resource(filepath)
     if os.path.isfile(filepath):
         files = [filepath]
     else:
         assert os.path.isdir(
             filepath), f'{filepath} has to be a directory of CoNLL 2012'
         files = sorted(
             glob.glob(f'{filepath}/**/*gold_conll', recursive=True))
     timer = CountdownTimer(len(files))
     for fid, f in enumerate(files):
         timer.log(f'files loading[blink][yellow]...[/yellow][/blink]')
         # 0:DOCUMENT 1:PART 2:INDEX 3:WORD 4:POS 5:PARSE 6:LEMMA 7:FRAME 8:SENSE 9:SPEAKER 10:NE 11-N:ARGS N:COREF
         for sent in read_tsv_as_sents(f, ignore_prefix='#'):
             sense = [cell[7] for cell in sent]
             props = [cell[11:-1] for cell in sent]
             props = zip(*props)  # transpose: one column of argument labels per predicate
             prd_bio_labels = [
                 self._make_bio_labels(prop) for prop in props
             ]
             prd_bio_labels = [self._remove_B_V(x) for x in prd_bio_labels]
             prd_indices = [i for i, x in enumerate(sense) if x != '-']
             token = [x[3] for x in sent]
             srl = [None for x in token]
             for idx, labels in zip(prd_indices, prd_bio_labels):
                 srl[idx] = labels
             srl = [x if x else ['O'] * len(token) for x in srl]
             yield {'token': token, 'srl': srl}
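
The key reshaping in load_file is zip(*props): each token row contributes one argument cell per predicate (columns 11..N-1 of the CoNLL-2012 format), and transposing turns those rows into one full argument column per predicate. A toy illustration with made-up cell values:

    # Each inner list holds one token's argument cells (columns 11..N-1 of its CoNLL-2012 row).
    props = [
        ['(ARG0*', '*'],      # token 1
        ['*)',     '(V*)'],   # token 2
        ['(V*)',   '*'],      # token 3
    ]
    per_predicate = list(zip(*props))
    print(per_predicate[0])   # ('(ARG0*', '*)', '(V*)')  -> argument column for predicate 1
    print(per_predicate[1])   # ('*', '(V*)', '*')        -> argument column for predicate 2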
Example #6
 def _build_cache(self, dataset, verbose=HANLP_VERBOSE):
     timer = CountdownTimer(self.size)
     with open(self.filename, "wb") as f:
         for i, batch in enumerate(dataset):
             torch.save(batch, f, _use_new_zipfile_serialization=False)
             if verbose:
                 timer.log(f'Caching {self.filename} [blink][yellow]...[/yellow][/blink]')
Example #7
 def fit_dataloader(self,
                    trn: DataLoader,
                    criterion,
                    optimizer,
                    metric,
                    logger: logging.Logger,
                    history: History,
                    gradient_accumulation=1,
                    grad_norm=None,
                    ratio_width=None,
                    eval_trn=False,
                    **kwargs):
     optimizer, scheduler = optimizer
     self.model.train()
     timer = CountdownTimer(history.num_training_steps(len(trn), gradient_accumulation=gradient_accumulation))
     total_loss = 0
     for idx, batch in enumerate(trn):
         pred, mask = self.feed_batch(batch)
         loss = self.compute_loss(criterion, pred, batch['srl_id'], mask)
         if gradient_accumulation and gradient_accumulation > 1:
             loss /= gradient_accumulation
         loss.backward()
         total_loss += loss.item()
         if eval_trn:
             prediction = self.decode_output(pred, mask, batch)
             self.update_metrics(metric, prediction, batch)
         if history.step(gradient_accumulation):
             self._step(optimizer, scheduler, grad_norm)
             report = f'loss: {total_loss / (idx + 1):.4f} {metric}' if eval_trn else f'loss: {total_loss / (idx + 1):.4f}'
             timer.log(report, logger=logger, ratio_percentage=False, ratio_width=ratio_width)
         del loss
         del pred
         del mask
Example #8
 def fit_dataloader(self,
                    trn: DataLoader,
                    criterion,
                    optimizer,
                    metric,
                    logger: logging.Logger,
                    ratio_width=None,
                    **kwargs):
     self.model.train()
     timer = CountdownTimer(len(trn))
     total_loss = 0
     for idx, batch in enumerate(trn):
         optimizer.zero_grad()
         out, mask = self.feed_batch(batch)
         y = batch['tag_id']
         loss = self.compute_loss(criterion, out, y, mask)
         loss.backward()
         nn.utils.clip_grad_norm_(self.model.parameters(), 5.0)
         optimizer.step()
         total_loss += loss.item()
         prediction = self.decode_output(out, mask, batch)
         self.update_metrics(metric, out, y, mask, batch, prediction)
         timer.log(f'loss: {total_loss / (idx + 1):.4f} {metric}',
                   ratio_percentage=False,
                   logger=logger,
                   ratio_width=ratio_width)
         del loss
         del out
         del mask
Example #9
 def fit_dataloader(self,
                    trn: DataLoader,
                    criterion,
                    optimizer,
                    metric,
                    logger: logging.Logger,
                    linear_scheduler=None,
                    gradient_accumulation=1,
                    **kwargs):
     self.model.train()
     timer = CountdownTimer(len(trn) // gradient_accumulation)
     total_loss = 0
     self.reset_metrics(metric)
     for idx, batch in enumerate(trn):
         output_dict = self.feed_batch(batch)
         self.update_metrics(batch, output_dict, metric)
         loss = output_dict['loss']
         loss = loss.sum()  # For data parallel
         if gradient_accumulation and gradient_accumulation > 1:
             loss /= gradient_accumulation  # scale before backward so accumulated gradients are averaged
         if torch.isnan(loss):  # w/ gold pred, some batches do not have PAs at all, resulting in empty scores
             loss = torch.zeros((1,), device=loss.device)
         else:
             loss.backward()
         if self.config.grad_norm:
             torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.grad_norm)
         if (idx + 1) % gradient_accumulation == 0:
             self._step(optimizer, linear_scheduler)
             timer.log(self.report_metrics(total_loss / (timer.current + 1), metric), ratio_percentage=None,
                       logger=logger)
         total_loss += loss.item()
         del loss
     if len(trn) % gradient_accumulation:
         self._step(optimizer, linear_scheduler)
     return total_loss / timer.total
Example #10
 def evaluate_dataloader(self,
                         data: DataLoader,
                         criterion: Callable,
                         metric,
                         logger,
                         ratio_width=None,
                         output=False,
                         official=False,
                         confusion_matrix=False,
                         **kwargs):
     self.model.eval()
     self.reset_metrics(metric)
     timer = CountdownTimer(len(data))
     total_loss = 0
     if official:
         sentences = []
         gold = []
         pred = []
     for batch in data:
         output_dict = self.feed_batch(batch)
         if official:
             sentences += batch['token']
             gold += batch['srl']
             pred += output_dict['prediction']
         self.update_metrics(batch, output_dict, metric)
         loss = output_dict['loss']
         total_loss += loss.item()
         timer.log(self.report_metrics(total_loss / (timer.current + 1), metric), ratio_percentage=None,
                   logger=logger,
                   ratio_width=ratio_width)
         del loss
     if official:
         scores = compute_srl_f1(sentences, gold, pred)
         if logger:
             if confusion_matrix:
                 labels = sorted(set(y for x in scores.label_confusions.keys() for y in x))
                 headings = ['GOLD↓PRED→'] + labels
                 matrix = []
                 for i, gold in enumerate(labels):
                     row = [gold]
                     matrix.append(row)
                     for j, pred in enumerate(labels):
                         row.append(scores.label_confusions.get((gold, pred), 0))
                 matrix = markdown_table(headings, matrix)
                 logger.info(f'{"Confusion Matrix": ^{len(matrix.splitlines()[0])}}')
                 logger.info(matrix)
             headings = ['Settings', 'Precision', 'Recall', 'F1']
             data = []
             for h, (p, r, f) in zip(['Unlabeled', 'Labeled', 'Official'], [
                 [scores.unlabeled_precision, scores.unlabeled_recall, scores.unlabeled_f1],
                 [scores.precision, scores.recall, scores.f1],
                 [scores.conll_precision, scores.conll_recall, scores.conll_f1],
             ]):
                 data.append([h] + [f'{x:.2%}' for x in [p, r, f]])
             table = markdown_table(headings, data)
             logger.info(f'{"Scores": ^{len(table.splitlines()[0])}}')
             logger.info(table)
     else:
         scores = metric
     return total_loss / timer.total, scores
Example #11
 def execute_training_loop(self, trn: DataLoader, dev: DataLoader, epochs,
                           criterion, optimizer, metric, save_dir,
                           logger: logging.Logger, devices, **kwargs):
     best_epoch, best_metric = 0, -1
     timer = CountdownTimer(epochs)
     ratio_width = len(f'{len(trn)}/{len(trn)}')
     for epoch in range(1, epochs + 1):
         logger.info(f"[yellow]Epoch {epoch} / {epochs}:[/yellow]")
         self.fit_dataloader(trn, criterion, optimizer, metric, logger)
         if dev:
             self.evaluate_dataloader(dev,
                                      criterion,
                                      metric,
                                      logger,
                                      ratio_width=ratio_width)
         report = f'{timer.elapsed_human}/{timer.total_time_human}'
         dev_score = metric.get_metric()
         if dev_score > best_metric:
             self.save_weights(save_dir)
             best_metric = dev_score
             report += ' [red]saved[/red]'
         timer.log(report,
                   ratio_percentage=False,
                   newline=True,
                   ratio=False)
Example #12
 def evaluate_dataloader(self,
                         data,
                         criterion,
                         logger=None,
                         ratio_width=None,
                         metric=None,
                         output=None,
                         **kwargs):
     self.model.eval()
     total_loss = 0
     if not metric:
         metric = self.build_metric()
     else:
         metric.reset()
     timer = CountdownTimer(len(data))
     for idx, batch in enumerate(data):
         out, mask = self.feed_batch(batch)
         y = batch['chart_id']
         loss, span_probs = self.compute_loss(out, y, mask)
         total_loss += loss.item()
         prediction = self.decode_output(out, mask, batch, span_probs)
         self.update_metrics(metric, batch, prediction)
         timer.log(f'loss: {total_loss / (idx + 1):.4f} {metric}',
                   ratio_percentage=False,
                   logger=logger,
                   ratio_width=ratio_width)
     total_loss /= len(data)
     if output:
         output.close()
     return total_loss, metric
Example #13
    def build_dataloader(self, data, batch_size,
                         gradient_accumulation=1,
                         shuffle=False,
                         sampler_builder: SamplerBuilder = None,
                         device=None,
                         logger: logging.Logger = None,
                         **kwargs) -> DataLoader:
        dataset = self.build_dataset(data, not shuffle)
        if self.vocabs.mutable:
            self.build_vocabs(dataset, logger)
        self.finalize_dataset(dataset, logger)
        if isinstance(data, str):
            dataset.purge_cache()
            timer = CountdownTimer(len(dataset))
            max_num_tokens = 0
            # lc = Counter()
            for each in dataset:
                max_num_tokens = max(max_num_tokens, len(each['text_token_ids']))
                # lc[len(each['text_token_ids'])] += 1
                timer.log(f'Preprocessing and caching samples (longest sequence {max_num_tokens})'
                          f'[blink][yellow]...[/yellow][/blink]')
            # print(lc.most_common())
            if self.vocabs.mutable:
                self.vocabs.lock()
                self.vocabs.summary(logger)

        if not sampler_builder:
            sampler_builder = SortingSamplerBuilder(batch_max_tokens=500)
        sampler = sampler_builder.build([len(x['text_token_ids']) for x in dataset], shuffle,
                                        gradient_accumulation if dataset.cache else 1)
        return self._create_dataloader(dataset, batch_size, device, sampler, shuffle)
Example #14
File: mlm.py Project: lei1993/HanLP
 def build_dataloader(self,
                      data,
                      batch_size,
                      shuffle=False,
                      device=None,
                      logger: logging.Logger = None,
                      verbose=False,
                      **kwargs) -> DataLoader:
     dataset = MaskedLanguageModelDataset(
         [{
             'token': x
         } for x in data],
         generate_idx=True,
         transform=TransformerTextTokenizer(self.tokenizer,
                                            text_a_key='token'))
     if verbose:
         verbose = CountdownTimer(len(dataset))
     lens = []
     for each in dataset:
         lens.append(len(each['token_input_ids']))
         if verbose:
             verbose.log(
                 'Preprocessing and caching samples [blink][yellow]...[/yellow][/blink]'
             )
     dataloader = PadSequenceDataLoader(dataset,
                                        batch_sampler=SortingSampler(
                                            lens, batch_size=batch_size),
                                        device=device)
     return dataloader
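
SortingSampler batches samples of similar length together so that PadSequenceDataLoader pads as little as possible; that is why the loop above collects lens before building the loader. The standalone helper below sketches only the bucketing idea (it is not HanLP's SortingSampler and ignores options such as shuffle or batch_max_tokens):

    def length_sorted_batches(lengths, batch_size):
        """Yield lists of sample indices grouped by ascending length to minimize padding."""
        order = sorted(range(len(lengths)), key=lambda i: lengths[i])
        for start in range(0, len(order), batch_size):
            yield order[start:start + batch_size]

    # e.g. lengths = [len(each['token_input_ids']) for each in dataset]
    for batch in length_sorted_batches([5, 12, 7, 3, 9, 11], batch_size=2):
        print(batch)  # [3, 0], [2, 4], [5, 1]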
Example #15
 def execute_training_loop(self,
                           trn: DataLoader,
                           dev: DataLoader,
                           epochs,
                           criterion,
                           optimizer,
                           metric,
                           save_dir,
                           logger: logging.Logger,
                           devices,
                           **kwargs):
     best_epoch, best_score = 0, -1
     optimizer, scheduler = optimizer
     timer = CountdownTimer(epochs)
     _len_trn = len(trn) // self.config.gradient_accumulation
     ratio_width = len(f'{_len_trn}/{_len_trn}')
     history = History()
     for epoch in range(1, epochs + 1):
         logger.info(f"[yellow]Epoch {epoch} / {epochs}:[/yellow]")
         self.fit_dataloader(trn, criterion, optimizer, metric, logger, history,
                             linear_scheduler=scheduler if self.use_transformer else None, **kwargs)
         if dev:
             metric = self.evaluate_dataloader(dev, criterion, metric, logger, ratio_width=ratio_width)
         report = f'{timer.elapsed_human}/{timer.total_time_human}'
         dev_score = sum(x.score for x in metric) / len(metric)
         if not self.use_transformer:
             scheduler.step(dev_score)
         if dev_score > best_score:
             self.save_weights(save_dir)
             best_score = dev_score
             report += ' [red]saved[/red]'
         timer.log(report, ratio_percentage=False, newline=True, ratio=False)
Example #16
 def fit_dataloader(self,
                    trn: DataLoader,
                    criterion,
                    optimizer,
                    metric,
                    logger: logging.Logger,
                    linear_scheduler=None,
                    **kwargs):
     self.model.train()
     timer = CountdownTimer(len(trn))
     total_loss = 0
     self.reset_metrics()
     for batch in trn:
         optimizer.zero_grad()
         output_dict = self.feed_batch(batch)
         loss = output_dict['loss']
         loss.backward()
         if self.config.grad_norm:
             clip_grad_norm(self.model, self.config.grad_norm)
         optimizer.step()
         if linear_scheduler:
             linear_scheduler.step()
         total_loss += loss.item()
         timer.log(self.report_metrics(total_loss / (timer.current + 1)),
                   ratio_percentage=None,
                   logger=logger)
         del loss
     return total_loss / timer.total
Example #17
    def evaluate_dataloader(self,
                            data: DataLoader,
                            criterion: Callable,
                            metric: MetricDict = None,
                            output=False,
                            logger=None,
                            ratio_width=None,
                            **kwargs):

        metric.reset()
        self.model.eval()
        timer = CountdownTimer(len(data))
        total_loss = 0
        for idx, batch in enumerate(data):
            out, mask = self.feed_batch(batch)
            loss = out['loss']
            total_loss += loss.item()
            self.decode_output(out, mask, batch)
            self.update_metrics(metric, batch, out, mask)
            report = f'loss: {total_loss / (idx + 1):.4f} {metric.cstr()}'
            timer.log(report,
                      logger=logger,
                      ratio_percentage=False,
                      ratio_width=ratio_width)
            del loss
            del out
            del mask
        return total_loss / len(data), metric
Example #18
 def fit_dataloader(self,
                    trn: DataLoader,
                    criterion,
                    optimizer,
                    metric,
                    logger: logging.Logger,
                    history: History,
                    ratio_width=None,
                    gradient_accumulation=1,
                    encoder_grad_norm=None,
                    decoder_grad_norm=None,
                    patience=0.5,
                    eval_trn=False,
                    **kwargs):
     self.model.train()
     encoder_optimizer, encoder_scheduler, decoder_optimizers = optimizer
     timer = CountdownTimer(len(trn))
     total_loss = 0
     self.reset_metrics(metric)
     model = self.model_
     encoder_parameters = model.encoder.parameters()
     decoder_parameters = model.decoders.parameters()
     for idx, (task_name, batch) in enumerate(trn):
         decoder_optimizer = decoder_optimizers.get(task_name, None)
         output_dict, _ = self.feed_batch(batch, task_name)
         loss = self.compute_loss(batch, output_dict[task_name]['output'], criterion[task_name],
                                  self.tasks[task_name])
         if gradient_accumulation and gradient_accumulation > 1:
             loss /= gradient_accumulation
         loss.backward()
         total_loss += float(loss.item())
         if history.step(gradient_accumulation):
             if self.config.get('grad_norm', None):
                 clip_grad_norm(model, self.config.grad_norm)
             if encoder_grad_norm:
                 torch.nn.utils.clip_grad_norm_(encoder_parameters, encoder_grad_norm)
             if decoder_grad_norm:
                 torch.nn.utils.clip_grad_norm_(decoder_parameters, decoder_grad_norm)
             encoder_optimizer.step()
             encoder_optimizer.zero_grad()
             encoder_scheduler.step()
             if decoder_optimizer:
                 if isinstance(decoder_optimizer, tuple):
                     decoder_optimizer, decoder_scheduler = decoder_optimizer
                 else:
                     decoder_scheduler = None
                 decoder_optimizer.step()
                 decoder_optimizer.zero_grad()
                 if decoder_scheduler:
                     decoder_scheduler.step()
         if eval_trn:
             self.decode_output(output_dict, batch, task_name)
             self.update_metrics(batch, output_dict, metric, task_name)
         timer.log(self.report_metrics(total_loss / (timer.current + 1), metric if eval_trn else None),
                   ratio_percentage=None,
                   ratio_width=ratio_width,
                   logger=logger)
         del loss
         del output_dict
     return total_loss / timer.total
Example #19
    def evaluate_dataloader(self,
                            data,
                            criterion,
                            logger=None,
                            ratio_width=None,
                            metric=None,
                            output=None,
                            **kwargs):
        self.model.eval()
        if isinstance(output, str):
            output = open(output, 'w')

        loss = 0
        if not metric:
            metric = self.build_metric()
        else:
            metric.reset()
        timer = CountdownTimer(len(data))
        for idx, batch in enumerate(data):
            logits, mask = self.feed_batch(batch)
            y = batch['tag_id']
            loss += self.compute_loss(criterion, logits, y, mask).item()
            prediction = self.decode_output(logits, mask, batch)
            self.update_metrics(metric, logits, y, mask, batch, prediction)
            if output:
                self.write_prediction(prediction, batch, output)
            timer.log(f'loss: {loss / (idx + 1):.4f} {metric}',
                      ratio_percentage=False,
                      logger=logger,
                      ratio_width=ratio_width)
        loss /= len(data)
        if output:
            output.close()
        return float(loss), metric
Example #20
 def fit_dataloader(self,
                    trn: DataLoader,
                    criterion,
                    optimizer,
                    metric,
                    logger: logging.Logger,
                    history: History,
                    gradient_accumulation=1,
                    grad_norm=None,
                    transformer_grad_norm=None,
                    teacher: Tagger = None,
                    kd_criterion=None,
                    temperature_scheduler=None,
                    ratio_width=None,
                    eval_trn=True,
                    **kwargs):
     optimizer, scheduler = optimizer
     if teacher:
         scheduler, lambda_scheduler = scheduler
     else:
         lambda_scheduler = None
     self.model.train()
     timer = CountdownTimer(
         history.num_training_steps(
             len(trn), gradient_accumulation=gradient_accumulation))
     total_loss = 0
     for idx, batch in enumerate(trn):
         out, mask = self.feed_batch(batch)
         y = batch['tag_id']
         loss = self.compute_loss(criterion, out, y, mask)
         if gradient_accumulation and gradient_accumulation > 1:
             loss /= gradient_accumulation
         if teacher:
             with torch.no_grad():
                 out_T, _ = teacher.feed_batch(batch)
             # noinspection PyNoneFunctionAssignment
             kd_loss = self.compute_distill_loss(kd_criterion, out, out_T,
                                                 mask,
                                                 temperature_scheduler)
             _lambda = float(lambda_scheduler)
             loss = _lambda * loss + (1 - _lambda) * kd_loss
         loss.backward()
         total_loss += loss.item()
         if eval_trn:
             prediction = self.decode_output(out, mask, batch)
             self.update_metrics(metric, out, y, mask, batch, prediction)
         if history.step(gradient_accumulation):
             self._step(optimizer, scheduler, grad_norm,
                        transformer_grad_norm, lambda_scheduler)
             report = f'loss: {total_loss / (idx + 1):.4f} {metric if eval_trn else ""}'
             timer.log(report,
                       logger=logger,
                       ratio_percentage=False,
                       ratio_width=ratio_width)
         del loss
         del out
         del mask
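
When a teacher is given, the example above mixes the supervised loss with a distillation loss as _lambda * loss + (1 - _lambda) * kd_loss. For reference, the sketch below shows the standard temperature-scaled soft-target term; it is a generic formulation, not HanLP's compute_distill_loss or its temperature_scheduler.

    import torch
    import torch.nn.functional as F

    def soft_target_loss(student_logits, teacher_logits, temperature=2.0):
        """KL divergence between temperature-softened distributions, scaled by T^2."""
        log_p_student = F.log_softmax(student_logits / temperature, dim=-1)
        p_teacher = F.softmax(teacher_logits / temperature, dim=-1)
        return F.kl_div(log_p_student, p_teacher, reduction='batchmean') * temperature ** 2

    student_logits = torch.randn(4, 10)
    teacher_logits = torch.randn(4, 10)
    hard_loss = torch.tensor(1.23)          # placeholder for the supervised loss
    _lambda = 0.7
    mixed = _lambda * hard_loss + (1 - _lambda) * soft_target_loss(student_logits, teacher_logits)
    print(float(mixed))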
Example #21
 def build(self, del_dataloader_in_memory=True, verbose=HANLP_VERBOSE):
     timer = CountdownTimer(self.size)
     for i, batch in enumerate(self.dataset):
         filename = self._filename(i)
         torch.save(batch, filename)
         if verbose:
             timer.log(
                 f'Caching {filename} [blink][yellow]...[/yellow][/blink]')
     if del_dataloader_in_memory:
         del self.dataset
Example #22
 def build_vocabs(self, trn, logger, **kwargs):
     self.vocabs.tag = Vocab(pad_token=None, unk_token=None)
     timer = CountdownTimer(len(trn))
     max_seq_len = 0
     token_key = self.config.token_key
     for each in trn:
         max_seq_len = max(max_seq_len, len(each[token_key]))
         timer.log(f'Building vocab [blink][yellow]...[/yellow][/blink] (longest sequence: {max_seq_len})')
     self.vocabs.tag.set_unk_as_safe_unk()
     self.vocabs.lock()
     self.vocabs.summary(logger)
Example #23
 def execute_training_loop(self,
                           trn: DataLoader,
                           dev: DataLoader,
                           epochs,
                           criterion,
                           optimizer,
                           metric,
                           save_dir,
                           logger: logging.Logger,
                           devices,
                           ratio_width=None,
                           patience=5,
                           teacher=None,
                           kd_criterion=None,
                           eval_trn=True,
                           **kwargs):
     best_epoch, best_metric = 0, -1
     timer = CountdownTimer(epochs)
     history = History()
     for epoch in range(1, epochs + 1):
         logger.info(f"[yellow]Epoch {epoch} / {epochs}:[/yellow]")
         self.fit_dataloader(trn,
                             criterion,
                             optimizer,
                             metric,
                             logger,
                             history=history,
                             ratio_width=ratio_width,
                             eval_trn=eval_trn,
                             **self.config)
         loss, dev_metric = self.evaluate_dataloader(
             dev, criterion, logger=logger, ratio_width=ratio_width)
         timer.update()
         report = f"{timer.elapsed_human} / {timer.total_time_human} ETA: {timer.eta_human}"
         if dev_metric > best_metric:
             best_epoch, best_metric = epoch, dev_metric
             self.save_weights(save_dir)
             report += ' [red](saved)[/red]'
         else:
             report += f' ({epoch - best_epoch})'
             if epoch - best_epoch >= patience:
                 report += ' early stop'
         logger.info(report)
         if epoch - best_epoch >= patience:
             break
     if not best_epoch:
         self.save_weights(save_dir)
     elif best_epoch != epoch:
         self.load_weights(save_dir)
     logger.info(f"Max score of dev is {best_metric} at epoch {best_epoch}")
     logger.info(
         f"Average time of each epoch is {timer.elapsed_average_human}")
     logger.info(f"{timer.elapsed_human} elapsed")
     return best_metric
Example #24
 def build_vocabs(self, trn, logger, **kwargs):
     self.vocabs.chart = VocabWithNone(pad_token=None, unk_token=None)
     timer = CountdownTimer(len(trn))
     max_seq_len = 0
     for each in trn:
         max_seq_len = max(max_seq_len, len(each['token_input_ids']))
         timer.log(
             f'Building vocab [blink][yellow]...[/yellow][/blink] (longest sequence: {max_seq_len})'
         )
     self.vocabs.chart.set_unk_as_safe_unk()
     self.vocabs.lock()
     self.vocabs.summary(logger)
Example #25
def clean_ctb_bracketed(ctb_root, out_root):
    os.makedirs(out_root, exist_ok=True)
    ctb_root = join(ctb_root, 'bracketed')
    chtbs = _list_treebank_root(ctb_root)
    timer = CountdownTimer(len(chtbs))
    for f in chtbs:
        with open(join(ctb_root, f), encoding='utf-8') as src, open(join(out_root, f + '.txt'), 'w',
                                                                    encoding='utf-8') as out:
            for line in src:
                if not line.strip().startswith('<'):
                    out.write(line)
        timer.log('Cleaning up CTB [blink][yellow]...[/yellow][/blink]', erase=False)
Example #26
    def evaluate_dataloader(self,
                            loader: PadSequenceDataLoader,
                            criterion,
                            logger=None,
                            filename=None,
                            output=False,
                            ratio_width=None,
                            metric=None,
                            **kwargs):
        self.model.eval()

        loss = 0
        if not metric:
            metric = self.build_metric()
        if output:
            fp = open(output, 'w')
            predictions, build_data, data, order = self.before_outputs(None)

        timer = CountdownTimer(len(loader))
        use_pos = self.use_pos
        for batch in loader:
            arc_scores, rel_scores, mask, puncts = self.feed_batch(batch)
            if output:
                self.collect_outputs(arc_scores, rel_scores, mask, batch,
                                     predictions, order, data, use_pos,
                                     build_data)
            arcs, rels = batch['arc'], batch['rel_id']
            loss += self.compute_loss(arc_scores, rel_scores, arcs, rels, mask,
                                      criterion, batch).item()
            arc_preds, rel_preds = self.decode(arc_scores, rel_scores, mask,
                                               batch)
            self.update_metric(arc_preds, rel_preds, arcs, rels, mask, puncts,
                               metric, batch)
            report = self._report(loss / (timer.current + 1), metric)
            if filename:
                report = f'{os.path.basename(filename)} ' + report
            timer.log(report,
                      ratio_percentage=False,
                      logger=logger,
                      ratio_width=ratio_width)
        loss /= len(loader)
        if output:
            outputs = self.post_outputs(predictions, data, order, use_pos,
                                        build_data)
            for each in outputs:
                fp.write(f'{each}\n\n')
            fp.close()
            logger.info(
                f'Predictions saved in [underline][yellow]{output}[/yellow][/underline]'
            )

        return loss, metric
Example #27
    def build_vocabs(self, dataset, logger=None, transformer=None):
        rel_vocab = self.vocabs.get('rel', None)
        if rel_vocab is None:
            rel_vocab = Vocab(unk_token=None,
                              pad_token=self.config.get('pad_rel', None))
            self.vocabs.put(rel=rel_vocab)
        if self.config.get('feat', None) == 'pos' or self.config.get(
                'use_pos', False):
            self.vocabs['pos'] = Vocab(unk_token=None, pad_token=None)

        timer = CountdownTimer(len(dataset))
        if transformer:
            token_vocab = None
        else:
            token_vocab = Vocab()
            self.vocabs.token = token_vocab
            unk = self.config.get('unk', None)
            if unk is not None:
                token_vocab.unk_token = unk
        if token_vocab and self.config.get('min_freq', None):
            counter = Counter()
            for sample in dataset:
                for form in sample['token']:
                    counter[form] += 1
            reserved_token = [token_vocab.pad_token, token_vocab.unk_token]
            if ROOT in token_vocab:
                reserved_token.append(ROOT)
            freq_words = reserved_token + [
                token for token, freq in counter.items()
                if freq >= self.config.min_freq
            ]
            token_vocab.token_to_idx.clear()
            for word in freq_words:
                token_vocab(word)
        else:
            for i, sample in enumerate(dataset):
                timer.log('vocab building [blink][yellow]...[/yellow][/blink]',
                          ratio_percentage=True)
        rel_vocab.set_unk_as_safe_unk()  # Some relations in the dev set may be OOV
        self.vocabs.lock()
        self.vocabs.summary(logger=logger)
        if token_vocab:
            self.config.n_words = len(self.vocabs['token'])
        if 'pos' in self.vocabs:
            self.config.n_feats = len(self.vocabs['pos'])
            self.vocabs['pos'].set_unk_as_safe_unk()
        self.config.n_rels = len(self.vocabs['rel'])
        if token_vocab:
            self.config.pad_index = self.vocabs['token'].pad_idx
            self.config.unk_index = self.vocabs['token'].unk_idx
Example #28
 def build_dataloader(self, data, batch_size, shuffle, device, text_a_key, text_b_key,
                      label_key,
                      logger: logging.Logger = None,
                      sorting=True,
                      **kwargs) -> DataLoader:
     if not batch_size:
         batch_size = self.config.batch_size
     dataset = self.build_dataset(data)
     dataset.append_transform(self.vocabs)
     if self.vocabs.mutable:
         if not any([text_a_key, text_b_key]):
             if len(dataset.headers) == 2:
                 self.config.text_a_key = dataset.headers[0]
                 self.config.label_key = dataset.headers[1]
             elif len(dataset.headers) >= 3:
                 self.config.text_a_key, self.config.text_b_key, self.config.label_key = dataset.headers[0], \
                                                                                         dataset.headers[1], \
                                                                                         dataset.headers[-1]
             else:
                 raise ValueError('Wrong dataset format')
             report = {'text_a_key', 'text_b_key', 'label_key'}
             report = dict((k, self.config[k]) for k in report)
             report = [f'{k}={v}' for k, v in report.items() if v]
             report = ', '.join(report)
             logger.info(f'Guess [bold][blue]{report}[/blue][/bold] according to the headers of training dataset: '
                         f'[blue]{dataset}[/blue]')
         self.build_vocabs(dataset, logger)
         dataset.purge_cache()
     # if self.config.transform:
     #     dataset.append_transform(self.config.transform)
     dataset.append_transform(TransformerTextTokenizer(tokenizer=self.transformer_tokenizer,
                                                       text_a_key=self.config.text_a_key,
                                                       text_b_key=self.config.text_b_key,
                                                       max_seq_length=self.config.max_seq_length,
                                                       truncate_long_sequences=self.config.truncate_long_sequences,
                                                       output_key=''))
     batch_sampler = None
     if sorting and not isdebugging():
         if dataset.cache and len(dataset) > 1000:
             timer = CountdownTimer(len(dataset))
             lens = []
             for idx, sample in enumerate(dataset):
                 lens.append(len(sample['input_ids']))
                 timer.log('Pre-processing and caching dataset [blink][yellow]...[/yellow][/blink]',
                           ratio_percentage=None)
         else:
             lens = [len(sample['input_ids']) for sample in dataset]
         batch_sampler = SortingSampler(lens, batch_size=batch_size, shuffle=shuffle,
                                        batch_max_tokens=self.config.batch_max_tokens)
     return PadSequenceDataLoader(dataset, batch_size, shuffle, batch_sampler=batch_sampler, device=device)
Example #29
def make_ctb_tasks(chtbs, out_root, part):
    for task in ['cws', 'pos', 'par', 'dep']:
        os.makedirs(join(out_root, task), exist_ok=True)
    timer = CountdownTimer(len(chtbs))
    par_path = join(out_root, 'par', f'{part}.txt')
    with open(join(out_root, 'cws', f'{part}.txt'), 'w', encoding='utf-8') as cws, \
            open(join(out_root, 'pos', f'{part}.tsv'), 'w', encoding='utf-8') as pos, \
            open(par_path, 'w', encoding='utf-8') as par:
        for f in chtbs:
            with open(f, encoding='utf-8') as src:
                content = src.read()
                trees = split_str_to_trees(content)
                for tree in trees:
                    try:
                        tree = Tree.fromstring(tree)
                    except ValueError:
                        print(tree)
                        exit(1)
                    words = []
                    for word, tag in tree.pos():
                        if tag == '-NONE-' or not tag:
                            continue
                        tag = tag.split('-')[0]
                        if tag == 'X':  # 铜_NN 30_CD x_X 25_CD x_X 14_CD cm_NT 1999_NT
                            tag = 'FW'
                        pos.write('{}\t{}\n'.format(word, tag))
                        words.append(word)
                    cws.write(' '.join(words))
                    par.write(tree.pformat(margin=sys.maxsize))
                    for fp in cws, pos, par:
                        fp.write('\n')
            timer.log(
                f'Preprocesing the [blue]{part}[/blue] set of CTB [blink][yellow]...[/yellow][/blink]',
                erase=False)
    remove_all_ec(par_path)
    dep_path = join(out_root, 'dep', f'{part}.conllx')
    convert_to_stanford_dependency_330(par_path, dep_path)
    sents = list(read_conll(dep_path))
    with open(dep_path, 'w') as out:
        for sent in sents:
            for i, cells in enumerate(sent):
                tag = cells[3]
                tag = tag.split('-')[0]  # NT-SHORT ---> NT
                if tag == 'X':  # 铜_NN 30_CD x_X 25_CD x_X 14_CD cm_NT 1999_NT
                    tag = 'FW'
                cells[3] = cells[4] = tag
                out.write('\t'.join(str(x) for x in cells))
                out.write('\n')
            out.write('\n')
Example #30
 def evaluate_dataloader(self, data: DataLoader, criterion: Callable, metric, logger, ratio_width=None,
                         filename=None, **kwargs):
     self.model.eval()
     timer = CountdownTimer(len(data))
     total_loss = 0
     metric.reset()
     for idx, batch in enumerate(data):
         pred, mask = self.feed_batch(batch)
         loss = self.compute_loss(criterion, pred, batch['srl_id'], mask)
         total_loss += loss.item()
         prediction = self.decode_output(pred, mask, batch)
         self.update_metrics(metric, prediction, batch)
         report = f'loss: {total_loss / (idx + 1):.4f} {metric}'
         timer.log(report, logger=logger, ratio_percentage=False, ratio_width=ratio_width)
     return total_loss / timer.total, metric