def test_metric_summary():
    y_true = np.array([[0., 1., 1., 1.], [1., 0., 0., 0.]])
    y_scores = np.array([[0.1, 0.6, 0.75, 0.8], [0.1, 0.6, 0.75, 0.8]])
    f_max, auc, f_scores, average_precisions, average_recalls, thresholds = metric_summary(
        y_true, y_scores
    )
    assert len(f_scores) == 10 and len(average_precisions) == 10 and len(average_recalls) == 10
    assert f_max == 2/3
    assert auc == 0.5
    # The last two entries are NaN (no positive predictions at those thresholds),
    # so they are sliced off before the element-wise comparison.
    assert (f_scores == [2/3, 0.5, 0.5, 0.5, 0.5, 0.5, 0.4, 0.25, np.nan, np.nan])[:-2].all()
    assert (average_precisions == [0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, np.nan, np.nan])[:-2].all()
    assert (average_recalls == [1., 0.5, 0.5, 0.5, 0.5, 0.5, 1/3, 1/6, 0., 0.]).all()
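
# A minimal sketch of the threshold sweep the test above assumes -- NOT the
# repo's metric_summary implementation. It only illustrates the interface:
# ten thresholds 0.0..0.9, per-threshold F1 / precision / recall, plus the
# maximum F1 and the ROC AUC. The exact averaging scheme (micro vs. per-sample)
# used by the real function is an assumption here.
import numpy as np
from sklearn.metrics import precision_score, recall_score, roc_auc_score


def metric_summary_sketch(y_true, y_scores, thresholds=np.arange(0.0, 1.0, 0.1)):
    f_scores, precisions, recalls = [], [], []
    for threshold in thresholds:
        y_pred = (y_scores >= threshold).astype(int)
        precision = precision_score(y_true.ravel(), y_pred.ravel(), zero_division=0)
        recall = recall_score(y_true.ravel(), y_pred.ravel(), zero_division=0)
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else np.nan
        f_scores.append(f1)
        precisions.append(precision)
        recalls.append(recall)
    f_max = np.nanmax(f_scores)
    auc = roc_auc_score(y_true.ravel(), y_scores.ravel())
    return f_max, auc, np.array(f_scores), np.array(precisions), np.array(recalls), thresholds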
def test_epoch_end(self, outputs):
    if hasattr(self, 'model_checkpoint') and int(os.environ.get('LOCAL_RANK', 0)) == 0:
        print("Symlinking model checkpoint to this run's version directory...")
        filedir = self.version_directory
        symlinked_model_path = Path(filedir, 'checkpoints', self.model_checkpoint.split('/')[-1])
        os.symlink(Path(self.model_checkpoint).resolve(), symlinked_model_path)
        hparams_path = Path(Path(self.model_checkpoint).parent.parent.resolve(), 'hparams.yaml')
        symlinked_hparams_path = Path(filedir, 'hparams.yaml')
        os.symlink(hparams_path, symlinked_hparams_path)
        print(
            f'...done. Symlinked {self.model_checkpoint} to {symlinked_model_path} and '
            f'{hparams_path} to {symlinked_hparams_path}.'
        )
    test_loss_mean, test_acc_mean, test_auc_mean, probs, targets = self.all_gather_outputs(
        outputs, detach=True
    ).values()
    self.test_outputs = {
        'loss': test_loss_mean,
        'acc': test_acc_mean,
        'auc': test_auc_mean,
        'probs': probs,
        'targets': targets
    }
    if not hasattr(self, 'labels'):
        self.labels = self.trainer.datamodule.labels  # Needed for confusion matrix labels
    if self.show_heatmaps:
        with torch.set_grad_enabled(True):
            self.visualize_best_and_worst_heatmaps(probs.cpu().numpy(), targets.cpu().numpy())
    if hasattr(self, 'model_checkpoint') and int(os.environ.get('LOCAL_RANK', 0)) == 0:
        print('Deleting model checkpoint and symlink to save disk space...')
        symlinked_model_path.unlink()
        os.remove(self.model_checkpoint)
        print(
            f'...done. Removed link at {symlinked_model_path}, '
            f'deleted {self.model_checkpoint}.'
        )
    f_max = metric_summary(targets.cpu().numpy(), probs.cpu().numpy())[0]
    self.logger.plot_confusion_matrix(
        targets.cpu().numpy(), (probs > 0.5).cpu().numpy(), self.labels
    )
    self.logger.plot_roc(targets.long().cpu().numpy(), probs.cpu().numpy(), self.labels)
    self.log_dict(
        {
            'test_epoch_loss': test_loss_mean,
            'test_epoch_acc': test_acc_mean,
            'test_epoch_auc': test_auc_mean,
            'test_epoch_f_max': f_max
        },
        on_step=False,
        on_epoch=True
    )
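
# Hypothetical sketch of the aggregation all_gather_outputs is assumed to
# perform (the real method lives on this module and may also all-gather across
# ranks): average the scalar step metrics and concatenate probs / targets.
import torch


def all_gather_outputs_sketch(outputs, detach=True):
    def prep(tensor):
        return tensor.detach().cpu() if detach else tensor.cpu()

    return {
        'loss': torch.stack([prep(o['loss']) for o in outputs]).mean(),
        'acc': torch.stack([prep(o['acc']) for o in outputs]).mean(),
        'auc': torch.stack([prep(o['auc']) for o in outputs]).mean(),
        'probs': torch.cat([prep(o['probs']) for o in outputs]),
        'targets': torch.cat([prep(o['targets']) for o in outputs]),
    }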
def ensemble_outputs(all_test_outputs, model):
    ensemble_acc = torch.stack([x['acc'] for outputs in all_test_outputs for x in outputs]).mean()
    ensemble_auc = torch.stack([x['auc'] for outputs in all_test_outputs for x in outputs]).mean()
    ensemble_probs = torch.stack(
        [torch.cat([x['probs'] for x in outputs]).cpu() for outputs in all_test_outputs],
        dim=-1
    ).mean(-1)
    targets = torch.cat([x['targets'] for x in all_test_outputs[0]]).cpu()
    # The mean of the per-model AUCs is replaced by the AUC computed on the
    # averaged (ensembled) probabilities.
    ensemble_f_max, ensemble_auc, _, _, _, _ = metric_summary(
        targets.cpu().numpy(), ensemble_probs.cpu().numpy()
    )
    model.logger.plot_confusion_matrix(
        targets.cpu().numpy(), (ensemble_probs > 0.5).cpu().numpy(), model.labels
    )
    model.logger.plot_roc(targets.long().cpu().numpy(), ensemble_probs.cpu().numpy(), model.labels)
    metrics = {
        'acc': ensemble_acc,
        'auc': ensemble_auc,
        'f_max': ensemble_f_max
    }
    model.logger.log_hyperparams(model.hparams, metrics)
    return metrics
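
# Hypothetical usage sketch of ensemble_outputs -- trained_models, trainer and
# datamodule are assumed to exist and are not part of this file. Each model
# stores self.test_outputs in test_epoch_end above, so a one-element list per
# model matches the nesting ensemble_outputs expects.
all_test_outputs = []
for trained_model in trained_models:
    trainer.test(trained_model, datamodule=datamodule)
    all_test_outputs.append([trained_model.test_outputs])
ensemble_metrics = ensemble_outputs(all_test_outputs, trained_models[0])
print(ensemble_metrics)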
def validation_epoch_end(self, outputs):
    val_loss_mean, val_acc_mean, val_auc_mean, probs, targets = self.all_gather_outputs(
        outputs, detach=True
    ).values()
    (
        f_max, _, f_scores, average_precisions, average_recalls, thresholds
    ) = metric_summary(targets.cpu().numpy(), probs.cpu().numpy())
    self.logger.log_metrics(
        {
            "Validation/val_loss": val_loss_mean,
            "Validation/val_acc": val_acc_mean,
            "Validation/val_auc": val_auc_mean,
        },
        self.validation_log_epoch + 1
    )
    try:
        self.logger.add_scalars(
            "Losses", {"val_loss": val_loss_mean}, self.validation_log_epoch + 1
        )
        self.logger.add_scalars(
            "Accuracies", {"val_acc": val_acc_mean}, self.validation_log_epoch + 1
        )
        self.logger.add_scalars(
            "AUCs", {"val_auc": val_auc_mean}, self.validation_log_epoch + 1
        )
        self.logger.add_scalars(
            "F1 Max Scores", {"val_f1-max": f_max}, self.validation_log_epoch + 1
        )
    except AttributeError as e:
        print(f'In (validation) LR find, error ignored: {str(e)}')
    metric_appendix = {}
    for f1_score, average_precision, average_recall, threshold in zip(
        f_scores, average_precisions, average_recalls, thresholds
    ):
        metric_appendix.update({
            f'Validation Metric Appendix/F1-Score ({threshold:0.1f})': f1_score,
            f'Validation Metric Appendix/Average Precision ({threshold:0.1f})': average_precision,
            f'Validation Metric Appendix/Average Recall ({threshold:0.1f})': average_recall,
        })
    self.logger.log_metrics(metric_appendix, self.validation_log_epoch + 1)
    self.update_hyperparams_and_metrics(
        {
            'val_epoch_loss': val_loss_mean,
            'val_epoch_acc': val_acc_mean,
            'val_epoch_auc': val_auc_mean,
            'val_epoch_f_max': torch.tensor(f_max)
        }
    )
    if self.profiler and len(self.model.timings) > 0:
        hook_reports = []
        for hook_name in self.model.timings:
            for callback_name, times in self.model.timings[hook_name].items():
                times = np.array(times)
                mean_time, sum_time = times.mean(), times.sum()
                hook_reports.append({
                    'Callback Name': callback_name,
                    'Mean time (s)': mean_time,
                    'Sum time (s)': sum_time
                })
        hook_reports = pd.DataFrame(hook_reports).set_index('Callback Name')
        print(hook_reports)
        self.model.timings = {}
    self.validation_log_epoch += 1
    return {
        'val_epoch_loss': val_loss_mean,
        'val_epoch_auc': val_auc_mean,
        'val_epoch_f_max': torch.tensor(f_max)
    }
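
# Hypothetical sketch of the timing structure the profiler report above
# assumes: self.model.timings maps hook name -> callback name -> list of
# wall-clock seconds. The hook names and values below are illustrative only.
import numpy as np
import pandas as pd

timings = {
    'on_validation_epoch_end': {
        'LearningRateMonitor': [0.002, 0.003],
        'ModelCheckpoint': [0.81, 0.94],
    },
}
rows = []
for hook_name, callbacks in timings.items():
    for callback_name, times in callbacks.items():
        times = np.array(times)
        rows.append({
            'Callback Name': callback_name,
            'Mean time (s)': times.mean(),
            'Sum time (s)': times.sum(),
        })
print(pd.DataFrame(rows).set_index('Callback Name'))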
def training_epoch_end(self, outputs):
    loss_mean, acc_mean, auc_mean, probs, targets = self.all_gather_outputs(
        outputs, detach=True
    ).values()
    (
        f_max, _, f_scores, average_precisions, average_recalls, thresholds
    ) = metric_summary(targets.cpu().numpy(), probs.cpu().numpy())
    self.logger.log_metrics(
        {
            "Train/train_loss": loss_mean,
            "Train/train_acc": acc_mean,
            "Train/train_auc": auc_mean
        },
        self.training_log_epoch + 1
    )
    try:
        self.logger.add_scalars(
            "Losses", {"train_loss": loss_mean}, self.training_log_epoch + 1
        )
        self.logger.add_scalars(
            "Accuracies", {"train_acc": acc_mean}, self.training_log_epoch + 1
        )
        self.logger.add_scalars(
            "AUCs", {"train_auc": auc_mean}, self.training_log_epoch + 1
        )
        self.logger.add_scalars(
            "F1 Max Scores", {"train_f1-max": f_max}, self.training_log_epoch + 1
        )
    except AttributeError as e:
        print(f'In (train) LR find, error ignored: {str(e)}')
    metric_appendix = {}
    for f1_score, average_precision, average_recall, threshold in zip(
        f_scores, average_precisions, average_recalls, thresholds
    ):
        metric_appendix.update({
            f'Train Metric Appendix/F1-Score ({threshold:0.1f})': f1_score,
            f'Train Metric Appendix/Average Precision ({threshold:0.1f})': average_precision,
            f'Train Metric Appendix/Average Recall ({threshold:0.1f})': average_recall,
        })
    self.logger.log_metrics(metric_appendix, self.training_log_epoch + 1)
    if self.profiler and len(self.model.timings) > 0:
        hook_reports = []
        for hook_name in self.model.timings:
            for callback_name, times in self.model.timings[hook_name].items():
                times = np.array(times)
                mean_time, sum_time = times.mean(), times.sum()
                hook_reports.append({
                    'Callback Name': callback_name,
                    'Mean time (s)': mean_time,
                    'Sum time (s)': sum_time
                })
        hook_reports = pd.DataFrame(hook_reports).set_index('Callback Name')
        print(hook_reports)
        self.model.timings = {}
    if self.swa:
        if not hasattr(self, 'swa_model'):
            print('Initializing SWA model...')
            self.swa_model = AveragedModel(self.model.model)
            print('...done.')
        self.swa_model.update_parameters(self.model.model)
        torch.optim.swa_utils.update_bn(
            self.trainer.datamodule.train_dataloader(), self.swa_model, device=self.device
        )
    self.training_log_epoch += 1
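
# Hypothetical sketch of how the SWA weights accumulated above might be
# persisted once training finishes: update_bn has already refreshed the
# BatchNorm statistics, so the wrapped module's state_dict can be saved
# directly. export_swa_weights and the output path are illustrative, not part
# of the repo.
import torch


def export_swa_weights(lightning_module, path='swa_model.pt'):
    # torch.optim.swa_utils.AveragedModel keeps the averaged network under .module
    torch.save(lightning_module.swa_model.module.state_dict(), path)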