def eval(self):
    # Predict on the held-out validation set and return per-class Dice scores
    pred = self.model.predict(self.X_val, self.batch_size, verbose=1)
    dices = dice_all(self.y_val,
                     pred.argmax(-1),
                     n_classes=self.n_classes,
                     ignore_zero=True)
    return dices
def evaluate(pred, true, n_classes, ignore_zero=False):
    # Collapse the probability/logit channel to integer class labels,
    # then compute per-class Dice scores against the true labels
    pred = pred_to_class(pred, img_dims=3, has_batch_dim=False)
    return dice_all(y_true=true,
                    y_pred=pred,
                    ignore_zero=ignore_zero,
                    n_classes=n_classes,
                    skip_if_no_y=False)
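# The helpers above delegate the actual metric to `dice_all` from
# mpunet.evaluate.metrics. Below is a minimal plain-numpy sketch of what a
# per-class Dice computation looks like, included only for illustration; the
# function name, argument names and the `smooth` term are assumptions of this
# sketch and not the mpunet implementation.
import numpy as np

def dice_per_class_sketch(y_true, y_pred, n_classes, smooth=1e-8):
    """Return one Dice score per integer class label (illustrative sketch)."""
    y_true, y_pred = np.asarray(y_true).ravel(), np.asarray(y_pred).ravel()
    dices = np.empty(n_classes)
    for c in range(n_classes):
        true_c, pred_c = (y_true == c), (y_pred == c)
        intersection = np.logical_and(true_c, pred_c).sum()
        dices[c] = (2.0 * intersection + smooth) / \
                   (true_c.sum() + pred_c.sum() + smooth)
    return dices

# Example: dice_per_class_sketch(np.array([0, 1, 1, 2]),
#                                np.array([0, 1, 2, 2]), n_classes=3)
# -> approximately [1.0, 0.667, 0.667]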
def run_pred_and_eval(dataset, out_dir, model, model_func, hparams, args, logger):
    """
    Run evaluation (predict + evaluate) on all entries of a SleepStudyDataset

    Args:
        dataset:    A SleepStudyDataset object storing one or more SleepStudy
                    objects
        out_dir:    Path to directory that will store predictions and
                    evaluation results
        model:      An initialized model used for prediction
        model_func: A callable that returns an initialized model for pred.
        hparams:    A YAMLHparams object storing all hyperparameters
        args:       Passed command-line arguments
        logger:     A Logger object
    """
    from mpunet.evaluate.metrics import dice_all, class_wise_kappa
    from utime.evaluation.dataframe import (get_eval_df, add_to_eval_df,
                                            log_eval_df, with_grand_mean_col)
    logger("\nPREDICTING ON {} STUDIES".format(len(dataset.pairs)))
    seq = get_sequencer(dataset, hparams)

    # Prepare evaluation data frames
    dice_eval_df = get_eval_df(seq)
    kappa_eval_df = get_eval_df(seq)

    # Predict on all samples
    for i, sleep_study_pair in enumerate(dataset):
        id_ = sleep_study_pair.identifier
        logger("[{}/{}] Predicting on SleepStudy: {}".format(i+1, len(dataset), id_))

        # Predict
        with logger.disabled_in_context(), sleep_study_pair.loaded_in_context():
            y, pred = predict_on(study_pair=sleep_study_pair,
                                 seq=seq,
                                 model=model,
                                 model_func=model_func,
                                 n_aug=args.num_test_time_augment)
        if args.wake_trim_min:
            # Trim long periods of wake in start/end of true & prediction
            from utime.bin.cm import wake_trim
            y, pred = wake_trim(pairs=[[y, pred]],
                                wake_trim_min=args.wake_trim_min,
                                period_length_sec=dataset.period_length_sec)[0]
        if not args.no_save:
            # Save the output
            save_dir = os.path.join(out_dir, "files/{}".format(id_))
            save(pred, fname=os.path.join(save_dir, "pred.npz"))
            if not args.no_save_true:
                save(y, fname=os.path.join(save_dir, "true.npz"))

        # Evaluate: dice scores
        dice_pr_class = dice_all(y, pred,
                                 n_classes=seq.n_classes,
                                 ignore_zero=False, smooth=0)
        logger("-- Dice scores: {}".format(np.round(dice_pr_class, 4)))
        add_to_eval_df(dice_eval_df, id_, values=dice_pr_class)

        # Evaluate: kappa
        kappa_pr_class = class_wise_kappa(y, pred,
                                          n_classes=seq.n_classes,
                                          ignore_zero=False)
        logger("-- Kappa scores: {}".format(np.round(kappa_pr_class, 4)))
        add_to_eval_df(kappa_eval_df, id_, values=kappa_pr_class)

        # Flag dependent evaluations:
        if args.plot_hypnograms:
            plot_hypnogram(out_dir, pred, id_, true=y)
        if args.plot_CMs:
            plot_cm(out_dir, pred, y, seq.n_classes, id_)

    # Log eval to file and screen
    dice_eval_df = with_grand_mean_col(dice_eval_df)
    log_eval_df(dice_eval_df.T,
                out_csv_file=os.path.join(out_dir, "evaluation_dice.csv"),
                out_txt_file=os.path.join(out_dir, "evaluation_dice.txt"),
                logger=logger, round=4, txt="EVALUATION DICE SCORES")
    kappa_eval_df = with_grand_mean_col(kappa_eval_df)
    log_eval_df(kappa_eval_df.T,
                out_csv_file=os.path.join(out_dir, "evaluation_kappa.csv"),
                out_txt_file=os.path.join(out_dir, "evaluation_kappa.txt"),
                logger=logger, round=4, txt="EVALUATION KAPPA SCORES")
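# `run_pred_and_eval` also reports a class-wise kappa via `class_wise_kappa`
# from mpunet.evaluate.metrics. The sketch below shows one common way to
# compute a one-vs-rest Cohen's kappa per class with plain numpy, again purely
# for illustration; the function and argument names are assumptions of this
# sketch, and the handling of degenerate classes may differ from mpunet.
import numpy as np

def class_wise_kappa_sketch(y_true, y_pred, n_classes):
    """Return one Cohen's kappa per class, computed on binarized labels."""
    y_true, y_pred = np.asarray(y_true).ravel(), np.asarray(y_pred).ravel()
    kappas = np.empty(n_classes)
    for c in range(n_classes):
        t, p = (y_true == c), (y_pred == c)
        p_obs = np.mean(t == p)                      # observed agreement
        p_exp = (t.mean() * p.mean()
                 + (1 - t.mean()) * (1 - p.mean()))  # chance agreement
        # When p_exp == 1 both arrays are constant; kappa is undefined, use 1.0
        kappas[c] = (p_obs - p_exp) / (1 - p_exp) if p_exp < 1 else 1.0
    return kappas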