Example #1
def test_metric_summary():
    y_true = np.array([[0., 1., 1., 1.], [1., 0., 0., 0.]])
    y_scores = np.array([[0.1, 0.6, 0.75, 0.8], [0.1, 0.6, 0.75, 0.8]])
    f_max, auc, f_scores, average_precisions, average_recalls, thresholds = metric_summary(
        y_true, y_scores
    )
    assert len(f_scores) == 10 and len(average_precisions) == 10 and len(average_recalls) == 10
    assert f_max == 2/3
    assert auc == 0.5
    # The last two thresholds produce no positive predictions, so F1 and precision are
    # NaN there; slice those entries off before comparing (NaN never compares equal).
    assert (f_scores == [2/3, 0.5, 0.5, 0.5, 0.5, 0.5, 0.4, 0.25, np.nan, np.nan])[:-2].all()
    assert (average_precisions == [0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, np.nan, np.nan])[:-2].all()
    assert (average_recalls == [1., 0.5, 0.5, 0.5, 0.5, 0.5, 1/3, 1/6, 0., 0.]).all()
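# A minimal sketch (assumption, not the source implementation) of the metric_summary
# behaviour this test implies: ten evenly spaced thresholds on [0, 1], per-sample
# precision/recall averaged over samples, F1 computed from those averages, f_max as the
# best F1 over thresholds, and a micro-averaged ROC AUC. All names below are hypothetical.
import numpy as np
from sklearn.metrics import roc_auc_score


def metric_summary_sketch(y_true, y_scores, n_thresholds=10):
    thresholds = np.linspace(0., 1., n_thresholds)
    f_scores, average_precisions, average_recalls = [], [], []
    for t in thresholds:
        y_pred = (y_scores > t).astype(float)
        tp = (y_pred * y_true).sum(axis=1)  # true positives per sample
        with np.errstate(invalid='ignore', divide='ignore'):
            precision = (tp / y_pred.sum(axis=1)).mean()  # NaN once nothing is predicted
            recall = (tp / y_true.sum(axis=1)).mean()
            f1 = 2 * precision * recall / (precision + recall)
        f_scores.append(f1)
        average_precisions.append(precision)
        average_recalls.append(recall)
    f_max = np.nanmax(f_scores)
    auc = roc_auc_score(y_true.ravel(), y_scores.ravel())
    return (f_max, auc, np.array(f_scores), np.array(average_precisions),
            np.array(average_recalls), thresholds)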
Example #2
 def test_epoch_end(self, outputs):
     if hasattr(self, 'model_checkpoint') and int(os.environ.get('LOCAL_RANK', 0)) == 0:
         print("Symlinking model checkpoint to this run's version directory...")
         filedir = self.version_directory
         symlinked_model_path = Path(filedir, 'checkpoints', self.model_checkpoint.split('/')[-1])
         os.symlink(Path(self.model_checkpoint).resolve(), symlinked_model_path)
         hparams_path = Path(Path(self.model_checkpoint).parent.parent.resolve(), 'hparams.yaml')
         symlinked_hparams_path = Path(filedir, 'hparams.yaml')
         os.symlink(hparams_path, symlinked_hparams_path)
         print(
             f'...done. Symlinked {self.model_checkpoint} to {symlinked_model_path} and '
             f'{hparams_path} to {symlinked_hparams_path}.'
         )
     test_loss_mean, test_acc_mean, test_auc_mean, probs, targets = self.all_gather_outputs(outputs, detach=True).values()
     self.test_outputs = {
         'loss': test_loss_mean,
         'acc': test_acc_mean,
         'auc': test_auc_mean,
         'probs': probs,
         'targets': targets
     }
     if not hasattr(self, 'labels'):
         self.labels = self.trainer.datamodule.labels  # Needed for confusion matrix labels
     if self.show_heatmaps:
         with torch.set_grad_enabled(True):
             self.visualize_best_and_worst_heatmaps(probs.cpu().numpy(), targets.cpu().numpy())
         if hasattr(self, 'model_checkpoint') and int(os.environ.get('LOCAL_RANK', 0)) == 0:
             print('Deleting the model checkpoint and its symlink to save disk space...')
             symlinked_model_path.unlink()
             os.remove(self.model_checkpoint)
             print(
                 f'...done. Removed link at {symlinked_model_path}, '
                 f'deleted {self.model_checkpoint}.'
             )
     f_max = metric_summary(targets.cpu().numpy(), probs.cpu().numpy())[0]
     self.logger.plot_confusion_matrix(
         targets.cpu().numpy(), (probs > 0.5).cpu().numpy(), self.labels
     )
     self.logger.plot_roc(targets.long().cpu().numpy(), probs.cpu().numpy(), self.labels)
     self.log_dict(
         {
             'test_epoch_loss': test_loss_mean,
             'test_epoch_acc': test_acc_mean,
             'test_epoch_auc': test_auc_mean,
             'test_epoch_f_max': f_max
         },
         on_step=False,
         on_epoch=True
     )
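# A rough sketch (assumption, hypothetical name and per-batch keys) of the contract
# `self.all_gather_outputs` appears to satisfy in the epoch-end hooks: reduce the
# per-batch outputs into a mapping whose values unpack, in order, as
# (loss_mean, acc_mean, auc_mean, probs, targets). The real method presumably also
# gathers tensors across ranks before reducing.
import torch


def all_gather_outputs_sketch(outputs, detach=True):
    summary = {  # insertion order is exactly what the `.values()` unpacking relies on
        'loss': torch.stack([o['loss'] for o in outputs]).mean(),
        'acc': torch.stack([o['acc'] for o in outputs]).mean(),
        'auc': torch.stack([o['auc'] for o in outputs]).mean(),
        'probs': torch.cat([o['probs'] for o in outputs]),
        'targets': torch.cat([o['targets'] for o in outputs]),
    }
    if detach:
        summary = {key: value.detach() for key, value in summary.items()}
    return summary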
Example #3
def ensemble_outputs(all_test_outputs, model):
    ensemble_acc = torch.stack([x['acc'] for outputs in all_test_outputs for x in outputs]).mean()
    ensemble_auc = torch.stack([x['auc'] for outputs in all_test_outputs for x in outputs]).mean()
    ensemble_probs = torch.stack(
        [torch.cat([x['probs'] for x in outputs]).cpu() for outputs in all_test_outputs],
        dim=-1
    ).mean(-1)
    targets = torch.cat([x['targets'] for x in all_test_outputs[0]]).cpu()
    ensemble_f_max, ensemble_auc, _, _, _, _ = metric_summary(
        targets.cpu().numpy(), ensemble_probs.cpu().numpy()
    )
    model.logger.plot_confusion_matrix(
        targets.numpy(), (ensemble_probs > 0.5).cpu().numpy(), model.labels
    )
    model.logger.plot_roc(targets.long().cpu().numpy(), ensemble_probs.cpu().numpy(), model.labels)
    metrics = {
        'acc': ensemble_acc,
        'auc': ensemble_auc,
        'f_max': ensemble_f_max
    }
    model.logger.log_hyperparams(model.hparams, metrics)
    return metrics
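# The core of ensemble_outputs is averaging the per-model probability tensors. A tiny
# self-contained illustration of that step with made-up values (not from the source):
import torch

model_a_batches = [torch.tensor([[0.9, 0.1], [0.2, 0.8]]), torch.tensor([[0.6, 0.4]])]
model_b_batches = [torch.tensor([[0.7, 0.3], [0.4, 0.6]]), torch.tensor([[0.8, 0.2]])]
per_model_probs = [torch.cat(batches) for batches in (model_a_batches, model_b_batches)]
ensemble_probs = torch.stack(per_model_probs, dim=-1).mean(-1)  # shape (3 samples, 2 labels)
# tensor([[0.8000, 0.2000],
#         [0.3000, 0.7000],
#         [0.7000, 0.3000]])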
    def validation_epoch_end(self, outputs):
        val_loss_mean, val_acc_mean, val_auc_mean, probs, targets = self.all_gather_outputs(outputs, detach=True).values()
        (
            f_max,
            _,
            f_scores,
            average_precisions,
            average_recalls,
            thresholds
        ) = metric_summary(targets.cpu().numpy(), probs.cpu().numpy())

        self.logger.log_metrics(
            {
                "Validation/val_loss": val_loss_mean,
                "Validation/val_acc": val_acc_mean,
                "Validation/val_auc": val_auc_mean,
            },
            self.validation_log_epoch + 1
        )
        try:
            self.logger.add_scalars(
                "Losses", {"val_loss": val_loss_mean}, self.validation_log_epoch + 1
            )
            self.logger.add_scalars(
                "Accuracies", {"val_acc": val_acc_mean}, self.validation_log_epoch + 1
            )
            self.logger.add_scalars(
                "AUCs", {"val_auc": val_auc_mean}, self.validation_log_epoch + 1
            )
            self.logger.add_scalars(
                "F1 Max Scores", {"val_f1-max": f_max}, self.validation_log_epoch + 1
            )
        except AttributeError as e:
            print(f'In (validation) LR find, error ignored: {str(e)}')
        metric_appendix = {}
        for f1_score, average_precision, average_recall, threshold in zip(
            f_scores, average_precisions, average_recalls, thresholds
        ):
            metric_appendix.update({
                f'Validation Metric Appendix/F1-Score ({threshold:0.1f})': f1_score,
                f'Validation Metric Appendix/Average Precision ({threshold:0.1f})': average_precision,
                f'Validation Metric Appendix/Average Recall ({threshold:0.1f})': average_recall,
            })
        self.logger.log_metrics(metric_appendix, self.validation_log_epoch + 1)
        self.update_hyperparams_and_metrics(
            {
                'val_epoch_loss': val_loss_mean,
                'val_epoch_acc': val_acc_mean,
                'val_epoch_auc': val_auc_mean,
                'val_epoch_f_max': torch.tensor(f_max)
            }
        )
        if self.profiler and len(self.model.timings) > 0:
            hook_reports = []
            for hook_name in self.model.timings:
                for callback_name, times in self.model.timings[hook_name].items():
                    times = np.array(times)
                    mean_time, sum_time = times.mean(), times.sum()
                    hook_reports.append({
                        'Callback Name': callback_name,
                        'Mean time (s)': mean_time,
                        'Sum time (s)': sum_time
                    })
            hook_reports = pd.DataFrame(hook_reports).set_index('Callback Name')
            print(hook_reports)
            self.model.timings = {}
        self.validation_log_epoch += 1
        return {
            'val_epoch_loss': val_loss_mean,
            'val_epoch_auc': val_auc_mean,
            'val_epoch_f_max': torch.tensor(f_max)
        }
    def training_epoch_end(self, outputs):
        loss_mean, acc_mean, auc_mean, probs, targets = self.all_gather_outputs(outputs, detach=True).values()
        (
            f_max,
            _,
            f_scores,
            average_precisions,
            average_recalls,
            thresholds
        ) = metric_summary(targets.cpu().numpy(), probs.cpu().numpy())

        self.logger.log_metrics(
            {
                "Train/train_loss": loss_mean,
                "Train/train_acc": acc_mean,
                "Train/train_auc": auc_mean
            },
            self.training_log_epoch + 1
        )
        try:
            self.logger.add_scalars(
                "Losses", {"train_loss": loss_mean}, self.training_log_epoch + 1
            )
            self.logger.add_scalars(
                "Accuracies", {"train_acc": acc_mean}, self.training_log_epoch + 1
            )
            self.logger.add_scalars(
                "AUCs", {"train_auc": auc_mean}, self.training_log_epoch + 1
            )
            self.logger.add_scalars(
                "F1 Max Scores", {"train_f1-max": f_max}, self.training_log_epoch + 1
            )
        except AttributeError as e:
            print(f'In (train) LR find, error ignored: {str(e)}')
        metric_appendix = {}
        for f1_score, average_precision, average_recall, threshold in zip(
            f_scores, average_precisions, average_recalls, thresholds
        ):
            metric_appendix.update({
                f'Train Metric Appendix/F1-Score ({threshold:0.1f})': f1_score,
                f'Train Metric Appendix/Average Precision ({threshold:0.1f})': average_precision,
                f'Train Metric Appendix/Average Recall ({threshold:0.1f})': average_recall,
            })
        self.logger.log_metrics(metric_appendix, self.training_log_epoch + 1)
        if self.profiler and len(self.model.timings) > 0:
            hook_reports = []
            for hook_name in self.model.timings:
                for callback_name, times in self.model.timings[hook_name].items():
                    times = np.array(times)
                    mean_time, sum_time = times.mean(), times.sum()
                    hook_reports.append({
                        'Callback Name': callback_name,
                        'Mean time (s)': mean_time,
                        'Sum time (s)': sum_time
                    })
            hook_reports = pd.DataFrame(hook_reports).set_index('Callback Name')
            print(hook_reports)
            self.model.timings = {}

        if self.swa:
            if not hasattr(self, 'swa_model'):
                print('Initializing SWA model...')
                self.swa_model = AveragedModel(self.model.model)
                print('...done.')
            self.swa_model.update_parameters(self.model.model)
            torch.optim.swa_utils.update_bn(
                self.trainer.datamodule.train_dataloader(), self.swa_model, device=self.device
            )
        self.training_log_epoch += 1
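        # Hypothetical follow-up (assumption, not shown in the source): after the final
        # epoch the averaged weights would typically be copied back into the live network,
        # or the AveragedModel used directly for evaluation, e.g.
        #     self.model.model.load_state_dict(self.swa_model.module.state_dict())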