def evaluate_dataloader(self, data, criterion, logger=None, ratio_width=None, metric=None, output=None, **kwargs):
    self.model.eval()
    total_loss = 0
    if not metric:
        metric = self.build_metric()
    else:
        metric.reset()
    timer = CountdownTimer(len(data))
    for idx, batch in enumerate(data):
        out, mask = self.feed_batch(batch)
        y = batch['chart_id']
        loss, span_probs = self.compute_loss(out, y, mask)
        total_loss += loss.item()
        prediction = self.decode_output(out, mask, batch, span_probs)
        self.update_metrics(metric, batch, prediction)
        timer.log(f'loss: {total_loss / (idx + 1):.4f} {metric}', ratio_percentage=False, logger=logger,
                  ratio_width=ratio_width)
    total_loss /= len(data)
    if output:
        output.close()
    return total_loss, metric
def fit_dataloader(self, trn: DataLoader, criterion, optimizer, metric, logger: logging.Logger,
                   history: History = None, gradient_accumulation=1, ratio_percentage=None, **kwargs):
    optimizer, scheduler = optimizer
    self.model.train()
    timer = CountdownTimer(history.num_training_steps(len(trn), gradient_accumulation=gradient_accumulation))
    total_loss = 0
    for batch in trn:
        output_dict = self.feed_batch(batch)
        loss = output_dict['loss']
        if gradient_accumulation and gradient_accumulation > 1:
            loss /= gradient_accumulation
        loss.backward()
        total_loss += loss.item()
        if history.step(gradient_accumulation):
            self._step(optimizer, scheduler)
        timer.log(self.report_metrics(total_loss / (timer.current + 1)), ratio_percentage=ratio_percentage,
                  logger=logger)
        del loss
        del output_dict
    return total_loss / max(timer.total, 1)
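# The training loops in this file gate optimizer updates with a ``History`` helper when
# gradient accumulation is enabled. A minimal sketch of the semantics these loops assume
# (the class below is an illustrative assumption, not the library's actual implementation):

class _HistorySketch:
    """Hypothetical stand-in: counts mini-batches and says when to apply an optimizer step."""

    def __init__(self):
        self.num_mini_batches = 0

    def num_training_steps(self, num_batches: int, gradient_accumulation: int = 1) -> int:
        # One optimizer step per ``gradient_accumulation`` mini-batches.
        return (num_batches + gradient_accumulation - 1) // gradient_accumulation

    def step(self, gradient_accumulation: int = 1) -> bool:
        # Returns True on every ``gradient_accumulation``-th mini-batch.
        self.num_mini_batches += 1
        return self.num_mini_batches % gradient_accumulation == 0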
def build_vocabs(self, dataset, logger=None, transformer=False):
    rel_vocab = self.vocabs.get('rel', None)
    if rel_vocab is None:
        rel_vocab = Vocab(unk_token=None, pad_token=self.config.get('pad_rel', None))
        self.vocabs.put(rel=rel_vocab)
    timer = CountdownTimer(len(dataset))
    if transformer:
        token_vocab = None
    else:
        self.vocabs.token = token_vocab = VocabCounter(unk_token=self.config.get('unk', UNK))
    for i, sample in enumerate(dataset):
        # Iterating the dataset runs its transforms, which populate the vocabs as a side effect.
        timer.log('Building vocab [blink][yellow]...[/yellow][/blink]', ratio_percentage=True)
    min_freq = self.config.get('min_freq', None)
    if token_vocab and min_freq:
        token_vocab.trim(min_freq)
    rel_vocab.set_unk_as_safe_unk()  # Some relations in the dev set are OOV
    self.vocabs.lock()
    self.vocabs.summary(logger=logger)
    if token_vocab:
        self.config.n_words = len(self.vocabs['token'])
    self.config.n_rels = len(self.vocabs['rel'])
    if token_vocab:
        self.config.pad_index = self.vocabs['token'].pad_idx
        self.config.unk_index = self.vocabs['token'].unk_idx
def compute_lens(self, data: Union[List[Dict[str, Any]], str], dataset: TransformDataset,
                 input_ids='token_input_ids', length_field='token'):
    """
    Args:
        data: Samples to be measured, or a path to the dataset at training time.
        dataset: At training time, use this dataset to measure the length of each sample inside it.
        input_ids: Field name that corresponds to the input ids.
        length_field: Fall back to this field at prediction time, as input_ids may not be generated yet.

    Returns:
        A list containing the length of each sample.
    """
    if isinstance(data, str):
        if not dataset.cache:
            warnings.warn(f'Caching for the dataset is not enabled, '
                          f'try `dataset.purge_cache()` if possible. The dataset is {dataset}.')
        timer = CountdownTimer(len(dataset))
        for each in dataset:
            timer.log('Preprocessing and caching samples [blink][yellow]...[/yellow][/blink]')
        timer.erase()
        return [len(x[input_ids]) for x in dataset]
    return [len(x[length_field]) for x in data]
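# The lengths returned by ``compute_lens`` are typically used to batch samples of similar
# length together so that padding is minimized. A minimal, generic sketch of that idea
# (the helper below is an illustrative assumption, not the project's actual sampler):

def sort_indices_by_length(lens):
    """Return sample indices ordered by length so nearby samples pad to similar sizes."""
    return sorted(range(len(lens)), key=lambda i: lens[i])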
def evaluate_dataloader(self, data, criterion, logger=None, ratio_width=None, metric=None, output=None, **kwargs):
    self.model.eval()
    if isinstance(output, str):
        output = open(output, 'w')
    loss = 0
    if not metric:
        metric = self.build_metric()
    else:
        metric.reset()
    timer = CountdownTimer(len(data))
    for idx, batch in enumerate(data):
        logits, mask = self.feed_batch(batch)
        y = batch['tag_id']
        loss += self.compute_loss(criterion, logits, y, mask).item()
        prediction = self.decode_output(logits, mask, batch)
        self.update_metrics(metric, logits, y, mask, batch, prediction)
        if output:
            self.write_prediction(prediction, batch, output)
        timer.log(f'loss: {loss / (idx + 1):.4f} {metric}', ratio_percentage=False, logger=logger,
                  ratio_width=ratio_width)
    loss /= len(data)
    if output:
        output.close()
    return float(loss), metric
def execute_training_loop(self, trn: DataLoader, dev: DataLoader, epochs, criterion, optimizer, metric, save_dir,
                          logger: logging.Logger, devices, **kwargs):
    best_epoch, best_metric = 0, -1
    timer = CountdownTimer(epochs)
    ratio_width = len(f'{len(trn)}/{len(trn)}')
    for epoch in range(1, epochs + 1):
        logger.info(f"[yellow]Epoch {epoch} / {epochs}:[/yellow]")
        self.fit_dataloader(trn, criterion, optimizer, metric, logger)
        if dev:
            self.evaluate_dataloader(dev, criterion, metric, logger, ratio_width=ratio_width)
        report = f'{timer.elapsed_human}/{timer.total_time_human}'
        dev_score = metric.score
        if dev_score > best_metric:
            self.save_weights(save_dir)
            best_metric = dev_score
            report += ' [red]saved[/red]'
        timer.log(report, ratio_percentage=False, newline=True, ratio=False)
def execute_training_loop(self, trn: DataLoader, dev: DataLoader, epochs, criterion, optimizer, metric, save_dir,
                          logger: logging.Logger, devices, gradient_accumulation=1, **kwargs):
    best_epoch, best_metric = 0, -1
    optimizer, scheduler = optimizer
    history = History()
    timer = CountdownTimer(epochs)
    ratio_width = len(f'{len(trn)}/{len(trn)}')
    for epoch in range(1, epochs + 1):
        logger.info(f"[yellow]Epoch {epoch} / {epochs}:[/yellow]")
        self.fit_dataloader(trn, criterion, optimizer, metric, logger, history=history,
                            gradient_accumulation=gradient_accumulation,
                            linear_scheduler=scheduler if self._get_transformer() else None)
        if dev:
            self.evaluate_dataloader(dev, criterion, metric, logger, ratio_width=ratio_width)
        report = f'{timer.elapsed_human}/{timer.total_time_human}'
        dev_score = metric.score
        if not self._get_transformer():
            scheduler.step(dev_score)
        if dev_score > best_metric:
            self.save_weights(save_dir)
            best_metric = dev_score
            report += ' [red]saved[/red]'
        timer.log(report, ratio_percentage=False, newline=True, ratio=False)
def evaluate_dataloader(self, data: DataLoader, criterion: Callable, metric, logger, ratio_width=None, output=False,
                        **kwargs):
    self.model.eval()
    self.reset_metrics(metric)
    timer = CountdownTimer(len(data))
    total_loss = 0
    if output:
        fp = open(output, 'w')
    for batch in data:
        output_dict = self.feed_batch(batch)
        if output:
            for sent, pred, gold in zip(batch['token'], output_dict['prediction'], batch['ner']):
                fp.write('Tokens\t' + ' '.join(sent) + '\n')
                fp.write('Pred\t' + '\t'.join(
                    ['[' + ' '.join(sent[x:y + 1]) + f']/{label}' for x, y, label in pred]) + '\n')
                fp.write('Gold\t' + '\t'.join(
                    ['[' + ' '.join(sent[x:y + 1]) + f']/{label}' for x, y, label in gold]) + '\n')
                fp.write('\n')
        self.update_metrics(batch, output_dict, metric)
        loss = output_dict['loss']
        total_loss += loss.item()
        timer.log(self.report_metrics(total_loss / (timer.current + 1), metric), ratio_percentage=None,
                  logger=logger, ratio_width=ratio_width)
        del loss
    if output:
        fp.close()
    return total_loss / timer.total, metric
def fit_dataloader(self, trn: DataLoader, criterion, optimizer, metric, logger: logging.Logger,
                   linear_scheduler=None, history: History = None, gradient_accumulation=1, **kwargs):
    self.model.train()
    timer = CountdownTimer(history.num_training_steps(len(trn), gradient_accumulation=gradient_accumulation))
    total_loss = 0
    self.reset_metrics(metric)
    for batch in trn:
        optimizer.zero_grad()
        output_dict = self.feed_batch(batch)
        self.update_metrics(batch, output_dict, metric)
        loss = output_dict['loss']
        if gradient_accumulation and gradient_accumulation > 1:
            loss /= gradient_accumulation
        loss.backward()
        total_loss += loss.item()
        if history.step(gradient_accumulation):
            if self.config.grad_norm:
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.grad_norm)
            optimizer.step()
            if linear_scheduler:
                linear_scheduler.step()
        timer.log(self.report_metrics(total_loss / (timer.current + 1), metric), ratio_percentage=None,
                  logger=logger)
        del loss
    return total_loss / timer.total
def build_dataset(self, data, logger: logging.Logger = None, training=True):
    dataset = AbstractMeaningRepresentationDataset(data, generate_idx=not training)
    if self.vocabs.mutable:
        self.build_vocabs(dataset, logger)
        self.vocabs.lock()
        self.vocabs.summary(logger)
    lens = [len(x['token']) + len(x['amr']) for x in dataset]
    dataset.append_transform(functools.partial(get_concepts, vocab=self.vocabs.predictable_concept,
                                               rel_vocab=self.vocabs.rel if self.config.get('separate_rel', False)
                                               else None))
    dataset.append_transform(append_bos)
    # Tokenization will happen in batchify
    if not self.config.get('squeeze', None):
        dataset.append_transform(self.config.encoder.transform())
    if isinstance(data, str):
        dataset.purge_cache()
        timer = CountdownTimer(len(dataset))
        for each in dataset:
            timer.log('Caching samples [blink][yellow]...[/yellow][/blink]')
    return dataset, lens
def fit_dataloader(self, trn: DataLoader, criterion, optimizer, metric, logger: logging.Logger,
                   linear_scheduler=None, gradient_accumulation=1, **kwargs):
    self.model.train()
    timer = CountdownTimer(len(trn) // gradient_accumulation)
    total_loss = 0
    self.reset_metrics(metric)
    for idx, batch in enumerate(trn):
        output_dict = self.feed_batch(batch)
        self.update_metrics(batch, output_dict, metric)
        loss = output_dict['loss']
        loss = loss.sum()  # For data parallel
        if gradient_accumulation and gradient_accumulation > 1:
            # Scale before backward so the accumulated gradients are averaged over the mini-batches.
            loss /= gradient_accumulation
        loss.backward()
        if self.config.grad_norm:
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.grad_norm)
        if (idx + 1) % gradient_accumulation == 0:
            self._step(optimizer, linear_scheduler)
            timer.log(self.report_metrics(total_loss / (timer.current + 1), metric), ratio_percentage=None,
                      logger=logger)
        total_loss += loss.item()
        del loss
    if len(trn) % gradient_accumulation:
        # Flush leftover gradients when the number of batches is not divisible by the accumulation step.
        self._step(optimizer, linear_scheduler)
    return total_loss / timer.total
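# Several loops here call ``self._step(optimizer, scheduler, ...)`` without showing it. A minimal
# sketch of what such a helper is assumed to do (an illustrative assumption, not the actual
# implementation): apply the optimizer update, advance the scheduler, then clear gradients.

def _step_sketch(optimizer, scheduler=None):
    optimizer.step()
    if scheduler is not None:
        scheduler.step()
    optimizer.zero_grad()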
def evaluate_dataloader(self, loader: PadSequenceDataLoader, criterion, logger=None, filename=None, output=False,
                        ratio_width=None, metric=None, **kwargs):
    self.model.eval()
    total_loss = 0
    if not metric:
        metric = self.build_metric()
    timer = CountdownTimer(len(loader))
    for batch in loader:
        (s_arc, s_sib, s_rel), mask, puncts = self.feed_batch(batch)
        arcs, sibs, rels = batch['arc'], batch['sib_id'], batch['rel_id']
        loss, s_arc = self.compute_loss(s_arc, s_sib, s_rel, arcs, sibs, rels, mask)
        total_loss += float(loss)
        arc_preds, rel_preds = self.decode(s_arc, s_sib, s_rel, mask)
        self.update_metric(arc_preds, rel_preds, arcs, rels, mask, puncts, metric)
        report = self._report(total_loss / (timer.current + 1), metric)
        if filename:
            report = f'{os.path.basename(filename)} ' + report
        timer.log(report, ratio_percentage=False, logger=logger, ratio_width=ratio_width)
    total_loss /= len(loader)
    return total_loss, metric
def evaluate_dataloader(self, data: MultiTaskDataLoader, criterion, metric: MetricDict, logger, ratio_width=None,
                        input: str = None, **kwargs):
    self.model.eval()
    self.reset_metrics(metric)
    tasks_need_custom_eval = self.config.get('tasks_need_custom_eval', None)
    tasks_need_custom_eval = tasks_need_custom_eval or {}
    tasks_need_custom_eval = dict((k, None) for k in tasks_need_custom_eval)
    for each in tasks_need_custom_eval:
        tasks_need_custom_eval[each] = data.dataloaders.pop(each)
    timer = CountdownTimer(len(data) + len(tasks_need_custom_eval))
    total_loss = 0
    for idx, (task_name, batch) in enumerate(data):
        output_dict, _ = self.feed_batch(batch, task_name)
        loss = self.compute_loss(batch, output_dict[task_name]['output'], criterion[task_name],
                                 self.tasks[task_name])
        total_loss += loss.item()
        self.decode_output(output_dict, batch, task_name)
        self.update_metrics(batch, output_dict, metric, task_name)
        timer.log(self.report_metrics(total_loss / (timer.current + 1), metric), ratio_percentage=None,
                  logger=logger, ratio_width=ratio_width)
        del loss
        del output_dict
    for task_name, dataset in tasks_need_custom_eval.items():
        task = self.tasks[task_name]
        decoder = self.model_.decoders[task_name]
        task.evaluate_dataloader(
            dataset, task.build_criterion(decoder=decoder),
            metric=metric[task_name],
            input=task.dev if input == 'dev' else task.tst,
            split=input,
            decoder=decoder,
            h=functools.partial(self._encode, task_name=task_name, cls_is_bos=task.cls_is_bos,
                                sep_is_eos=task.sep_is_eos))
        data.dataloaders[task_name] = dataset
        timer.log(self.report_metrics(total_loss / (timer.current + 1), metric), ratio_percentage=None,
                  logger=logger, ratio_width=ratio_width)
    return total_loss / timer.total, metric, data
def execute_training_loop(self, trn: PrefetchDataLoader, dev: PrefetchDataLoader, epochs, criterion, optimizer,
                          metric, save_dir, logger: logging.Logger, devices, ratio_width=None, dev_data=None,
                          gradient_accumulation=1, **kwargs):
    best_epoch, best_metric = 0, -1
    timer = CountdownTimer(epochs)
    history = History()
    try:
        for epoch in range(1, epochs + 1):
            logger.info(f"[yellow]Epoch {epoch} / {epochs}:[/yellow]")
            trn = self.fit_dataloader(trn, criterion, optimizer, metric, logger, ratio_width=ratio_width,
                                      gradient_accumulation=gradient_accumulation, history=history,
                                      save_dir=save_dir)
            report = f'{timer.elapsed_human}/{timer.total_time_human}'
            if epoch % self.config.eval_every == 0 or epoch == epochs:
                metric = self.evaluate_dataloader(dev, logger, dev_data, ratio_width=ratio_width,
                                                  save_dir=save_dir, use_fast=True)
                if metric > best_metric:
                    self.save_weights(save_dir)
                    best_metric = metric
                    best_epoch = epoch
                    report += ' [red]saved[/red]'
            timer.log(report, ratio_percentage=False, newline=True, ratio=False)
        if best_epoch and best_epoch != epochs:
            logger.info(f'Restored the best model with {best_metric} saved {epochs - best_epoch} epochs ago')
            self.load_weights(save_dir)
    finally:
        trn.close()
        dev.close()
def fit_dataloader(self, trn: DataLoader, criterion, optimizer, metric, logger: logging.Logger, history: History,
                   gradient_accumulation=1, grad_norm=None, transformer_grad_norm=None, teacher: Tagger = None,
                   kd_criterion=None, temperature_scheduler=None, ratio_width=None, **kwargs):
    optimizer, scheduler = optimizer
    if teacher:
        scheduler, lambda_scheduler = scheduler
    else:
        lambda_scheduler = None
    self.model.train()
    timer = CountdownTimer(history.num_training_steps(len(trn), gradient_accumulation=gradient_accumulation))
    total_loss = 0
    for idx, batch in enumerate(trn):
        out, mask = self.feed_batch(batch)
        y = batch['tag_id']
        loss = self.compute_loss(criterion, out, y, mask)
        if gradient_accumulation and gradient_accumulation > 1:
            loss /= gradient_accumulation
        if teacher:
            with torch.no_grad():
                out_T, _ = teacher.feed_batch(batch)
            # noinspection PyNoneFunctionAssignment
            kd_loss = self.compute_distill_loss(kd_criterion, out, out_T, mask, temperature_scheduler)
            _lambda = float(lambda_scheduler)
            loss = _lambda * loss + (1 - _lambda) * kd_loss
        loss.backward()
        total_loss += loss.item()
        prediction = self.decode_output(out, mask, batch)
        self.update_metrics(metric, out, y, mask, batch, prediction)
        if history.step(gradient_accumulation):
            self._step(optimizer, scheduler, grad_norm, transformer_grad_norm, lambda_scheduler)
            report = f'loss: {total_loss / (idx + 1):.4f} {metric}'
            timer.log(report, logger=logger, ratio_percentage=False, ratio_width=ratio_width)
        del loss
        del out
        del mask
def build_vocabs(self, dataset, logger, vocabs, lock=True, label_vocab_name='label', **kwargs):
    vocabs[label_vocab_name] = label_vocab = Vocab(pad_token=None, unk_token=None)
    # Use null to indicate no relationship
    label_vocab.add('<null>')
    timer = CountdownTimer(len(dataset))
    for each in dataset:
        timer.log('Building NER vocab [blink][yellow]...[/yellow][/blink]')
    label_vocab.set_unk_as_safe_unk()
    if lock:
        vocabs.lock()
        vocabs.summary(logger)
def build_vocabs(self, trn, logger, **kwargs):
    self.vocabs.chart = VocabWithNone(pad_token=None, unk_token=None)
    timer = CountdownTimer(len(trn))
    max_seq_len = 0
    for each in trn:
        max_seq_len = max(max_seq_len, len(each['token_input_ids']))
        timer.log(f'Building vocab [blink][yellow]...[/yellow][/blink] (longest sequence: {max_seq_len})')
    self.vocabs.chart.set_unk_as_safe_unk()
    self.vocabs.lock()
    self.vocabs.summary(logger)
def evaluate_dataloader(self, loader: PadSequenceDataLoader, criterion, logger=None, filename=None, output=False,
                        ratio_width=None, metric=None, **kwargs):
    self.model.eval()
    loss = 0
    if not metric:
        metric = self.build_metric()
    if output:
        fp = open(output, 'w')
        predictions, build_data, data, order = self.before_outputs(None)
    timer = CountdownTimer(len(loader))
    use_pos = self.use_pos
    for batch in loader:
        arc_scores, rel_scores, mask, puncts = self.feed_batch(batch)
        if output:
            self.collect_outputs(arc_scores, rel_scores, mask, batch, predictions, order, data, use_pos,
                                 build_data)
        arcs, rels = batch['arc'], batch['rel_id']
        loss += self.compute_loss(arc_scores, rel_scores, arcs, rels, mask, criterion, batch).item()
        arc_preds, rel_preds = self.decode(arc_scores, rel_scores, mask, batch)
        self.update_metric(arc_preds, rel_preds, arcs, rels, mask, puncts, metric, batch)
        report = self._report(loss / (timer.current + 1), metric)
        if filename:
            report = f'{os.path.basename(filename)} ' + report
        timer.log(report, ratio_percentage=False, logger=logger, ratio_width=ratio_width)
    loss /= len(loader)
    if output:
        outputs = self.post_outputs(predictions, data, order, use_pos, build_data)
        for each in outputs:
            fp.write(f'{each}\n\n')
        fp.close()
        logger.info(f'Predictions saved in [underline][yellow]{output}[/yellow][/underline]')
    return loss, metric
def build_vocabs(self, trn, logger, **kwargs):
    self.vocabs.tag = Vocab(pad_token=None, unk_token=None)
    timer = CountdownTimer(len(trn))
    max_seq_len = 0
    token_key = self.config.token_key
    for each in trn:
        max_seq_len = max(max_seq_len, len(each[token_key]))
        timer.log(f'Building vocab [blink][yellow]...[/yellow][/blink] (longest sequence: {max_seq_len})')
    self.vocabs.tag.set_unk_as_safe_unk()
    self.vocabs.lock()
    self.vocabs.summary(logger)
def build_vocabs(self, dataset, logger=None, transformer=None):
    rel_vocab = self.vocabs.get('rel', None)
    if rel_vocab is None:
        rel_vocab = Vocab(unk_token=None, pad_token=self.config.get('pad_rel', None))
        self.vocabs.put(rel=rel_vocab)
    if self.config.get('feat', None) == 'pos' or self.config.get('use_pos', False):
        self.vocabs['pos'] = Vocab(unk_token=None, pad_token=None)
    timer = CountdownTimer(len(dataset))
    if transformer:
        token_vocab = None
    else:
        token_vocab = Vocab()
        self.vocabs.token = token_vocab
        unk = self.config.get('unk', None)
        if unk is not None:
            token_vocab.unk_token = unk
    if token_vocab and self.config.get('min_freq', None):
        counter = Counter()
        for sample in dataset:
            for form in sample['token']:
                counter[form] += 1
        reserved_token = [token_vocab.pad_token, token_vocab.unk_token]
        if ROOT in token_vocab:
            reserved_token.append(ROOT)
        freq_words = reserved_token + [token for token, freq in counter.items() if freq >= self.config.min_freq]
        token_vocab.token_to_idx.clear()
        for word in freq_words:
            token_vocab(word)
    else:
        for i, sample in enumerate(dataset):
            timer.log('vocab building [blink][yellow]...[/yellow][/blink]', ratio_percentage=True)
    rel_vocab.set_unk_as_safe_unk()  # Some relations in the dev set are OOV
    self.vocabs.lock()
    self.vocabs.summary(logger=logger)
    if token_vocab:
        self.config.n_words = len(self.vocabs['token'])
    if 'pos' in self.vocabs:
        self.config.n_feats = len(self.vocabs['pos'])
        self.vocabs['pos'].set_unk_as_safe_unk()
    self.config.n_rels = len(self.vocabs['rel'])
    if token_vocab:
        self.config.pad_index = self.vocabs['token'].pad_idx
        self.config.unk_index = self.vocabs['token'].unk_idx
def evaluate_dataloader(self, data: DataLoader, criterion: Callable, metric, logger, ratio_width=None,
                        filename=None, output=None, **kwargs):
    self.model.eval()
    timer = CountdownTimer(len(data))
    total_loss = 0
    metric.reset()
    num_samples = 0
    if output:
        output = open(output, 'w')
    for batch in data:
        logits = self.feed_batch(batch)
        target = batch['label_id']
        loss = self.compute_loss(criterion, logits, target, batch)
        total_loss += loss.item()
        label_ids = self.update_metric(metric, logits, target, output)
        if output:
            labels = [self.vocabs[self.config.label_key].idx_to_token[i] for i in label_ids.tolist()]
            for i, label in enumerate(labels):
                # text_a text_b pred gold
                columns = [batch[self.config.text_a_key][i]]
                if self.config.text_b_key:
                    columns.append(batch[self.config.text_b_key][i])
                columns.append(label)
                columns.append(batch[self.config.label_key][i])
                output.write('\t'.join(columns))
                output.write('\n')
        num_samples += len(target)
        report = f'loss: {total_loss / (timer.current + 1):.4f} acc: {metric.get_metric():.2%}'
        if filename:
            report = f'{filename} {report} {num_samples / timer.elapsed:.0f} samples/sec'
        timer.log(report, ratio_percentage=None, logger=logger, ratio_width=ratio_width)
    if output:
        output.close()
    return total_loss / timer.total
def build_vocabs(self, dataset, logger, **kwargs):
    self.vocabs.srl_label = Vocab(pad_token=None, unk_token=None)
    # Use null to indicate no relationship
    self.vocabs.srl_label.add('<null>')
    timer = CountdownTimer(len(dataset))
    max_seq_len = 0
    for each in dataset:
        max_seq_len = max(max_seq_len, len(each['token_input_ids']))
        timer.log(f'Building vocabs (max sequence length {max_seq_len}) [blink][yellow]...[/yellow][/blink]')
    timer.stop()
    timer.erase()
    self.vocabs['srl_label'].set_unk_as_safe_unk()
    self.vocabs.lock()
    self.vocabs.summary(logger)
def evaluate_dataloader(self, data: DataLoader, criterion: Callable, metric=None, output=False, ratio_width=None,
                        logger=None, input=None, use_fast=False, **kwargs):
    self.model.eval()
    timer = CountdownTimer(len(data))
    graphs = []
    orders = []
    smatch = 0
    for idx, batch in enumerate(data):
        graphs_per_batch = self.predict_amrs(batch)
        graphs_per_batch = [x[0] for x in graphs_per_batch]
        # Copy meta data from gold graph
        for gp, gg in zip(graphs_per_batch, batch['amr']):
            metadata = gg.metadata.copy()
            metadata['annotator'] = f'{self.config.transformer}-amr'
            metadata['date'] = str(datetime.datetime.now())
            if 'save-date' in metadata:
                del metadata['save-date']
            gp.metadata = metadata
        graphs.extend(graphs_per_batch)
        orders.extend(batch[IDX])
        if idx == timer.total - 1:
            graphs = reorder(graphs, orders)
            write_predictions(output, self._tokenizer, graphs)
            try:
                if use_fast:
                    smatch = compute_smatch(output, input)
                else:
                    smatch = smatch_eval(output, input, use_fast=False)
            except:
                pass
            timer.log(smatch.cstr() if isinstance(smatch, MetricDict) else f'{smatch:.2%}',
                      ratio_percentage=False, logger=logger)
        else:
            timer.log(ratio_percentage=False, logger=logger)
    return smatch
def parse_data(model, pp: PostProcessor, data, input_file, output_file, beam_size=8, alpha=0.6, max_time_step=100,
               h=None):
    if not output_file:
        output_file = tempfile.NamedTemporaryFile().name
    tot = 0
    levi_graph = model.decoder.levi_graph if hasattr(model, 'decoder') else False
    with open(output_file, 'w') as fo:
        timer = CountdownTimer(len(data))
        order = []
        outputs = []
        for batch in data:
            order.extend(batch[IDX])
            res = parse_batch(model, batch, beam_size, alpha, max_time_step, h=h)
            outputs.extend(list(zip(res['concept'], res['relation'], res['score'])))
            timer.log('Parsing [blink][yellow]...[/yellow][/blink]', ratio_percentage=False)
        outputs = reorder(outputs, order)
        timer = CountdownTimer(len(data))
        for concept, relation, score in outputs:
            fo.write('# ::conc ' + ' '.join(concept) + '\n')
            fo.write('# ::score %.6f\n' % score)
            fo.write(pp.postprocess(concept, relation, check_connected=levi_graph) + '\n\n')
            tot += 1
            timer.log('Post-processing [blink][yellow]...[/yellow][/blink]', ratio_percentage=False)
    match(output_file, input_file)
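# ``reorder`` above restores outputs to their original sample order after batching may have
# sorted or shuffled them. A minimal sketch of the assumed behaviour (illustrative only; the
# real helper may differ):

def reorder_sketch(outputs, order):
    """Place outputs[i] at position order[i], undoing any batch-time reordering."""
    restored = [None] * len(outputs)
    for out, idx in zip(outputs, order):
        restored[idx] = out
    return restored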
def execute_training_loop(self, trn: DataLoader, dev: DataLoader, epochs, criterion, optimizer, metric, save_dir,
                          logger, patience, **kwargs):
    max_e, max_metric = 0, -1
    criterion = self.build_criterion()
    timer = CountdownTimer(epochs)
    ratio_width = len(f'{len(trn)}/{len(trn)}')
    scheduler = self.build_scheduler(**merge_dict(self.config, optimizer=optimizer, overwrite=True))
    if not patience:
        patience = epochs
    for epoch in range(1, epochs + 1):
        logger.info(f"[yellow]Epoch {epoch} / {epochs}:[/yellow]")
        self.fit_dataloader(trn, criterion, optimizer, metric, logger, ratio_width=ratio_width)
        loss, dev_metric = self.evaluate_dataloader(dev, criterion, logger)
        if scheduler:
            if isinstance(scheduler, ReduceLROnPlateau):
                scheduler.step(dev_metric.score)
            else:
                scheduler.step(epoch)
        report_patience = f'Patience: {epoch - max_e}/{patience}'
        # Save the model if it is the best so far
        if dev_metric > max_metric:
            self.save_weights(save_dir)
            max_e, max_metric = epoch, dev_metric
            report_patience = '[red]Saved[/red] '
        stop = epoch - max_e >= patience
        if stop:
            timer.stop()
        timer.log(f'{report_patience} lr: {optimizer.param_groups[0]["lr"]:.4f}', ratio_percentage=False,
                  newline=True, ratio=False)
        if stop:
            break
    timer.stop()
    if max_e != epoch:
        self.load_weights(save_dir)
    logger.info(f"Max score of dev is {max_metric.score:.2%} at epoch {max_e}")
    logger.info(f"{timer.elapsed_human} elapsed, average time of each epoch is {timer.elapsed_average_human}")
def fit_dataloader(self, trn: DataLoader, criterion, optimizer, metric: SpanMetric, logger: logging.Logger,
                   history: History, gradient_accumulation=1, grad_norm=None, ratio_width=None, eval_trn=True,
                   **kwargs):
    optimizer, scheduler = optimizer
    metric.reset()
    self.model.train()
    timer = CountdownTimer(history.num_training_steps(len(trn), gradient_accumulation=gradient_accumulation))
    total_loss = 0
    for idx, batch in enumerate(trn):
        out, mask = self.feed_batch(batch)
        y = batch['chart_id']
        loss, span_probs = self.compute_loss(out, y, mask)
        if gradient_accumulation and gradient_accumulation > 1:
            loss /= gradient_accumulation
        loss.backward()
        total_loss += loss.item()
        if eval_trn:
            prediction = self.decode_output(out, mask, batch, span_probs)
            self.update_metrics(metric, batch, prediction)
        if history.step(gradient_accumulation):
            self._step(optimizer, scheduler, grad_norm)
            report = f'loss: {total_loss / (idx + 1):.4f} {metric}' if eval_trn \
                else f'loss: {total_loss / (idx + 1):.4f}'
            timer.log(report, logger=logger, ratio_percentage=False, ratio_width=ratio_width)
        del loss
        del out
        del mask
def fit_dataloader(self, trn: DataLoader, criterion, optimizer, metric, logger: logging.Logger, ratio_width=None,
                   **kwargs):
    self.model.train()
    timer = CountdownTimer(len(trn))
    total_loss = 0
    for idx, batch in enumerate(trn):
        optimizer.zero_grad()
        out, mask = self.feed_batch(batch)
        y = batch['tag_id']
        loss = self.compute_loss(criterion, out, y, mask)
        loss.backward()
        nn.utils.clip_grad_norm_(self.model.parameters(), 5.0)
        optimizer.step()
        total_loss += loss.item()
        prediction = self.decode_output(out, mask, batch)
        self.update_metrics(metric, out, y, mask, batch, prediction)
        # Report the running average loss, consistent with the other training loops.
        timer.log(f'loss: {total_loss / (idx + 1):.4f} {metric}', ratio_percentage=False, logger=logger,
                  ratio_width=ratio_width)
        del loss
        del out
        del mask
def fit_dataloader(self, trn, optimizer, scheduler, criterion, epoch, logger, history: History,
                   transformer_optimizer=None, transformer_scheduler=None, gradient_accumulation=1, eval_trn=False,
                   **kwargs):
    self.model.train()
    timer = CountdownTimer(history.num_training_steps(len(trn), gradient_accumulation))
    metric = self.build_metric(training=True)
    total_loss = 0
    for idx, batch in enumerate(trn):
        optimizer.zero_grad()
        (s_arc, s_sib, s_rel), mask, puncts = self.feed_batch(batch)
        arcs, sibs, rels = batch['arc'], batch['sib_id'], batch['rel_id']
        loss, s_arc = self.compute_loss(s_arc, s_sib, s_rel, arcs, sibs, rels, mask)
        if gradient_accumulation > 1:
            loss /= gradient_accumulation
        loss.backward()
        total_loss += loss.item()
        if eval_trn:
            arc_preds, rel_preds = self.decode(s_arc, s_sib, s_rel, mask)
            self.update_metric(arc_preds, rel_preds, arcs, rels, mask, puncts, metric)
        if history.step(gradient_accumulation):
            self._step(optimizer, scheduler, transformer_optimizer, transformer_scheduler)
            report = self._report(total_loss / (timer.current + 1), metric if eval_trn else None)
            lr = scheduler.get_last_lr()[0]
            report += f' lr: {lr:.4e}'
            timer.log(report, ratio_percentage=False, logger=logger)
        del loss
def fit_dataloader(self, trn: DataLoader, criterion, optimizer, metric, logger: logging.Logger, **kwargs):
    self.model.train()
    timer = CountdownTimer(len(trn))
    optimizer, scheduler = optimizer
    total_loss = 0
    metric.reset()
    for batch in trn:
        optimizer.zero_grad()
        logits = self.feed_batch(batch)
        target = batch['label_id']
        loss = self.compute_loss(criterion, logits, target, batch)
        loss.backward()
        optimizer.step()
        scheduler.step()
        total_loss += loss.item()
        self.update_metric(metric, logits, target)
        timer.log(f'loss: {total_loss / (timer.current + 1):.4f} acc: {metric.get_metric():.2%}',
                  ratio_percentage=None, logger=logger)
        del loss
    return total_loss / timer.total
def fit_dataloader(self, trn: DataLoader, criterion, optimizer, metric, logger: logging.Logger, **kwargs):
    self.model.train()
    timer = CountdownTimer(len(trn))
    total_loss = 0
    self.reset_metrics(metric)
    for batch in trn:
        optimizer.zero_grad()
        prediction = self.feed_batch(batch)
        loss = self.compute_loss(prediction, batch, criterion)
        self.update_metrics(batch, prediction, metric)
        loss.backward()
        if self.config.grad_norm:
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.grad_norm)
        optimizer.step()
        total_loss += loss.item()
        timer.log(self.report_metrics(total_loss / (timer.current + 1), metric), ratio_percentage=None,
                  logger=logger)
        del loss
    return total_loss / timer.total