def plot_periods(self, period_idxs=None, period_secs=None,
                 out_path=None, highlight_periods=True):
    """
    Plot multiple consecutive periods of data by indices or seconds.

    Args:
        period_idxs:        Indices of all periods to plot
        period_secs:        The starting seconds of the periods to plot
        out_path:           Path to save the figure to
        highlight_periods:  Plot period-separating vertical lines
    """
    if bool(period_idxs) == bool(period_secs):
        raise ValueError("Must specify either period_idxs or period_secs.")
    from utime.visualization.psg_plotting import plot_periods
    period_secs = list(period_secs or map(self.period_idx_to_sec, period_idxs))
    if any(np.diff(period_secs) != self.period_length_sec):
        raise ValueError("Periods to plot must be consecutive.")
    x, y = zip(*map(self.get_period_at_sec, period_secs))
    plot_periods(X=x,
                 y=[Defaults.get_class_int_to_stage_string()[y_] for y_ in y],
                 channel_names=self.select_channels,
                 init_second=period_secs[0],
                 sample_rate=self.sample_rate,
                 out_path=out_path,
                 highlight_periods=highlight_periods)

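# A minimal, self-contained sketch of the consecutiveness check performed by
# plot_periods above, assuming a 30 s period length (the values here are
# illustrative, not from the source):
import numpy as np

_period_length_sec = 30
_consecutive = [300, 330, 360]     # starts of three back-to-back periods
_gapped = [300, 360, 390]          # a period is missing between the first two
assert not any(np.diff(_consecutive) != _period_length_sec)  # would plot
assert any(np.diff(_gapped) != _period_length_sec)           # would raise
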
def assert_project_folder(project_folder, evaluation=False):
    """
    Raises RuntimeError if a folder 'project_folder' does not seem to be a
    valid U-Time project folder in the training phase (evaluation=False) or
    evaluation phase (evaluation=True).

    Args:
        project_folder: A path to a folder to check for U-Time compatibility.
        evaluation:     Whether the folder should adhere to train- or
                        eval-time checks.
    """
    import os
    import glob
    project_folder = os.path.abspath(project_folder)
    if not os.path.exists(Defaults.get_hparams_path(project_folder)):
        # Folder must contain a 'hparams.yaml' file in all cases
        raise RuntimeError("Folder {} is not a valid project folder."
                           " Must contain a 'hparams.yaml' "
                           "file.".format(project_folder))
    if evaluation:
        # Folder must contain a 'model' subfolder storing saved model files
        model_path = os.path.join(project_folder, "model")
        if not os.path.exists(model_path):
            raise RuntimeError("Folder {} is not a valid project "
                               "folder. Must contain a 'model' "
                               "subfolder.".format(project_folder))
        # There must be at least 1 model parameter file (.h5) in the folder
        models = glob.glob(os.path.join(model_path, "*.h5"))
        if not models:
            raise RuntimeError("Did not find any model parameter files in "
                               "model subfolder {}. Model files should have"
                               " extension '.h5' to "
                               "be recognized.".format(project_folder))

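# Hedged usage sketch (the project path below is hypothetical): a valid
# evaluation-phase folder must hold 'hparams.yaml' plus at least one
# 'model/*.h5' file, or a RuntimeError is raised.
#
#   assert_project_folder("./my_utime_project", evaluation=True)
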
def __init__(self, annotation_dict, period_length_sec, no_hypnogram, logger=None):
    """
    Initialize the base sleep study object.

    Args:
        annotation_dict:    A dictionary mapping labels as stored in a
                            hypnogram file to integer class labels
        period_length_sec:  (int) Sleep 'epoch'/period length in seconds
        no_hypnogram:       (bool) Initialize without ground truth labels
        logger:             (Logger) A Logger object
    """
    self.logger = logger or ScreenLogger()
    self.annotation_dict = annotation_dict
    self.no_hypnogram = no_hypnogram
    self.period_length_sec = period_length_sec or \
        Defaults.get_default_period_length(self.logger)
    # Hidden attributes controlled in property functions to limit setting
    # of these values to the load() function
    self._psg = None
    self._hypnogram = None
    self._select_channels = None
    self._alternative_select_channels = None

def get_translated_triplet_rules(translation_map=None):
    # Translate the stage-string TRIPLET_RULES to their integer class labels
    translation_map = translation_map or Defaults.get_stage_string_to_class_int()
    translation_map = np.vectorize(translation_map.get)
    new_map = []
    for s1, s2 in TRIPLET_RULES:
        s1 = translation_map(s1)
        s2 = translation_map(s2)
        new_map.append((s1, s2))
    return new_map

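# Standalone sketch of the translation step above: np.vectorize turns the
# dict's .get into an elementwise lookup over stage-string tuples. The
# integer map here is explicit for illustration and may differ from
# Defaults.get_stage_string_to_class_int().
import numpy as np

_stage_map = {"W": 0, "N1": 1, "N2": 2, "N3": 3, "REM": 4}
_translate = np.vectorize(_stage_map.get)
print(_translate(("N1", "W", "N1")))  # -> [1 0 1]
print(_translate("N2"))               # -> 2 (a 0-d array)
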
def run(args, return_prediction=False, dump_args=None):
    """
    Run the script according to args - Please refer to the argparser.
    """
    assert_args(args)
    # Check project folder is valid
    from utime.utils.scriptutils.scriptutils import assert_project_folder
    project_dir = os.path.abspath(args.project_dir)
    assert_project_folder(project_dir, evaluation=True)

    # Get a logger
    logger = get_logger(project_dir, True, name="prediction_log")
    if dump_args:
        logger("Args dump: \n{}".format(vars(args)))

    # Get hyperparameters and init all described datasets
    from utime.hyperparameters import YAMLHParams
    hparams = YAMLHParams(Defaults.get_hparams_path(project_dir),
                          logger, no_version_control=True)

    # Get the sleep study
    logger("Loading and pre-processing PSG file...")
    hparams['prediction_params']['channels'] = args.channels
    study, channel_groups = get_sleep_study(psg_path=args.f,
                                            logger=logger,
                                            **hparams['prediction_params'])

    # Set GPU and get model
    set_gpu_vis(args.num_GPUs, args.force_GPU, logger)
    hparams["build"]["data_per_prediction"] = args.data_per_prediction
    logger("Predicting with {} data per prediction".format(
        args.data_per_prediction))
    model = get_and_load_one_shot_model(
        n_periods=study.n_periods,
        project_dir=project_dir,
        hparams=hparams,
        logger=logger,
        weights_file_name=hparams.get_from_anywhere('weight_file_name')
    )
    logger("Predicting...")
    pred = predict_study(study, model, channel_groups, args.no_argmax, logger)
    logger("--> Predicted shape: {}".format(pred.shape))
    if return_prediction:
        return pred
    else:
        save_prediction(pred=pred,
                        out_path=args.o,
                        input_file_path=study.psg_file_path,
                        logger=logger)

def replace_before_with(scores, before_stage, replace, to,
                        stage_string_to_class_int_map=None):
    # Replace all occurrences of stage 'replace' with stage 'to' in the part
    # of 'scores' that precedes the first occurrence of 'before_stage'
    scores = scores.copy()
    if isinstance(before_stage, str):
        stage_map = (stage_string_to_class_int_map or
                     Defaults.get_stage_string_to_class_int())
        before_stage, replace, to = map(lambda s: stage_map[s],
                                        (before_stage, replace, to))
    first_appearance = np.where(scores == before_stage)[0]
    if len(first_appearance):
        scores_considered = scores[:first_appearance[0]]
        scores_considered = np.where(scores_considered == replace,
                                     to, scores_considered)
        scores[:first_appearance[0]] = scores_considered
    return scores

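# Runnable sketch of the rule implemented above, using integer stages
# directly (0=W, 1=N1, 2=N2 is an illustrative mapping): every N1 epoch
# occurring before the first N2 epoch is relabelled as W.
import numpy as np

_scores = np.array([1, 0, 1, 2, 1, 2])
print(replace_before_with(_scores, before_stage=2, replace=1, to=0))
# -> [0 0 0 2 1 2]; epochs after the first '2' are left untouched
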
def standardize_stage_string(stage_string):
    """
    Attempts to map a string representing a sleep stage, for which some
    (unambiguous) variability is allowed, to a fixed, standardized string
    representation.

    Standardized strings:
        "W"       (awake)
        "N1"      (non-REM sleep stage 1)
        "N2"      (non-REM sleep stage 2)
        "N3"      (non-REM sleep stage 3)
        "REM"     (REM sleep)
        "UNKNOWN" (all other)

    Args:
        stage_string: A string representing a sleep stage

    Returns:
        string, a standardized string representing a sleep stage
    """
    ss = stage_string.strip().upper()
    # Check the various types of matches
    matches = []
    for match_func in (check_number_match, check_wake_match,
                       check_REM_match, check_unknown_match):
        possible_match, match_value = match_func(ss)
        if possible_match:
            matches.append(match_value)
    # If exactly 1 match was found, return it; otherwise raise an error
    n_matches = len(matches)
    if n_matches == 1:
        match = matches[0]
        print("[OBS]: Mapping variable stage string '{:s}' to stage '{}' "
              "(class int {})".format(
                  stage_string, match,
                  Defaults.get_stage_string_to_class_int()[match]))
        return match
    elif n_matches == 0:
        raise_match_error(ss, "Found no valid matches.")
    else:
        raise_match_error(ss, "Found multiple ({}) "
                              "valid matches {}.".format(n_matches, matches))

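# Illustrative call, assuming the standard match functions defined in this
# module: a raw annotation like "Stage 2" carries the number 2 and should
# standardize to "N2" (an "[OBS]" mapping line is printed as a side effect).
print(standardize_stage_string("Stage 2"))  # expected: N2
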
def plot_period(self, period_idx=None, period_sec=None, out_path=None):
    """
    Plot a period of data by index or second.

    Args:
        period_idx:  Index of the period to plot
        period_sec:  The starting second of the period to plot
        out_path:    Path to save the figure to
    """
    if bool(period_idx) == bool(period_sec):
        raise ValueError("Must specify either period_idx or period_sec.")
    from utime.visualization.psg_plotting import plot_period
    period_sec = period_sec or self.period_idx_to_sec(period_idx)
    x, y = self.get_period_at_sec(period_sec)
    plot_period(X=x,
                y=Defaults.get_class_int_to_stage_string()[y],
                channel_names=self.select_channels,
                init_second=period_sec,
                sample_rate=self.sample_rate,
                out_path=out_path)

def check_number_match(ss):
    possible_match, match_value = False, None
    # Check for a number referring to the sleep stage
    numbers = list(map(int, re.findall(r"\d+", ss)))
    if len(numbers) not in (0, 1):
        raise_match_error(ss, "Found multiple numbers in string")
    elif len(numbers) == 1:
        num = numbers[0]
        # Note: R&K stage 4 is merged into N3 under AASM scoring
        valid_map = {1: Defaults.NON_REM_STAGE_1[0],
                     2: Defaults.NON_REM_STAGE_2[0],
                     3: Defaults.NON_REM_STAGE_3[0],
                     4: Defaults.NON_REM_STAGE_3[0]}
        assert np.all(np.in1d(list(valid_map.values()),
                              list(Defaults.get_stage_string_to_class_int().keys())))
        if num in valid_map:
            possible_match = True
            match_value = valid_map[num]
        else:
            raise_match_error(ss, "Found invalid number {} in string".format(num))
    return possible_match, match_value

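# Quick sketch of the mapping above (assuming Defaults.NON_REM_STAGE_3[0] is
# "N3"): stage numbers 3 and 4 both collapse to the N3 stage string, and
# strings without any digit yield no match.
print(check_number_match("SLEEP STAGE 4"))  # -> (True, 'N3')
print(check_number_match("WAKE"))           # -> (False, None)
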
def __init__(self, identifier, pairs=None, period_length_sec=None,
             logger=None, no_log=False):
    """
    Initialize a SleepStudyDataset from a directory storing one or more
    sub-directories, each corresponding to a sleep/PSG study.
    Each sub-dir will be represented by a SleepStudy object.

    Args:
        identifier:         (string) Dataset ID/name
        pairs:              (list)   SleepStudy objects, all with unique
                                     identifiers, making up this dataset
        period_length_sec:  (int)    Ground truth segmentation period length
                                     in seconds
        logger:             (Logger) A Logger object
        no_log:             (bool)   Do not log dataset details on init
    """
    self.logger = logger or ScreenLogger()
    self._identifier = identifier
    self._id_to_study = None
    self._study_identifiers = None
    self._pairs = pairs or []
    self._misc = {}  # May store arbitrary properties for this dataset
    self.period_length_sec = (period_length_sec or
                              Defaults.get_default_period_length(self.logger))
    # All SleepStudy objects must have unique identifiers
    if len(np.unique([p.identifier for p in self.pairs])) != len(self.pairs):
        raise RuntimeError("Two or more SleepStudy objects share the same"
                           " identifier, but all must be unique.")
    if self.pairs:
        self.update_id_to_study_dict()
    if not no_log:
        self.log()

def plot_confusion_matrix(y_true, y_pred, n_classes,
                          normalize=False, id_=None, cmap="Blues"):
    """
    Adapted from sklearn's 'plot_confusion_matrix.py' example.
    This function computes and plots the confusion matrix.
    Normalization can be applied by setting `normalize=True`.
    """
    from sklearn.metrics import confusion_matrix
    from sklearn.utils.multiclass import unique_labels
    if normalize:
        title = 'Normalized confusion matrix for identifier {}'.format(
            id_ or "???")
    else:
        title = 'Confusion matrix, without normalization, ' \
                'for identifier {}'.format(id_ or "???")

    # Compute confusion matrix
    classes = np.arange(n_classes)
    cm = confusion_matrix(y_true, y_pred)
    classes = classes[unique_labels(y_true, y_pred)]
    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]

    # Get transformed labels
    from utime import Defaults
    labels = [Defaults.get_class_int_to_stage_string()[i] for i in classes]

    fig, ax = plt.subplots()
    im = ax.imshow(cm, interpolation='nearest', cmap=plt.get_cmap(cmap))
    ax.figure.colorbar(im, ax=ax)
    # Show all ticks and label them with the stage strings
    ax.set(xticks=np.arange(cm.shape[1]),
           yticks=np.arange(cm.shape[0]),
           xticklabels=labels,
           yticklabels=labels,
           title=title,
           ylabel='True label',
           xlabel='Predicted label')

    # Rotate the tick labels and set their alignment
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
             rotation_mode="anchor")

    # Loop over data dimensions and create text annotations
    fmt = '.3f' if normalize else 'd'
    thresh = cm.max() / 2.
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            ax.text(j, i, format(cm[i, j], fmt),
                    ha="center", va="center",
                    color="white" if cm[i, j] > thresh else "black")
    fig.tight_layout()
    return fig, ax

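# Minimal runnable usage sketch with synthetic labels; class ints 0-4 are
# assumed to be valid keys in Defaults.get_class_int_to_stage_string():
import numpy as np

_rng = np.random.RandomState(0)
_y_true = _rng.randint(0, 5, 1000)
_y_pred = _rng.randint(0, 5, 1000)
_fig, _ax = plot_confusion_matrix(_y_true, _y_pred, n_classes=5,
                                  normalize=True, id_="demo")
_fig.savefig("confusion_matrix_demo.png")
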
def run(args):
    """
    Run the script according to args - Please refer to the argparser.
    """
    assert_args(args)
    # Check project folder is valid
    from utime.utils.scriptutils import (assert_project_folder,
                                         get_splits_from_all_datasets)
    project_dir = os.path.abspath(args.project_dir)
    assert_project_folder(project_dir, evaluation=True)

    # Prepare output dir
    out_dir = get_out_dir(args.out_dir, args.data_split)
    prepare_output_dir(out_dir, args.overwrite)
    logger = get_logger(out_dir, args.overwrite)
    logger("Args dump: \n{}".format(vars(args)))

    # Get hyperparameters and init all described datasets
    from utime.hyperparameters import YAMLHParams
    hparams = YAMLHParams(Defaults.get_hparams_path(project_dir), logger)
    if args.channels:
        hparams["select_channels"] = args.channels
        hparams["channel_sampling_groups"] = None
        logger("Evaluating using channels {}".format(args.channels))

    # Get model
    set_gpu_vis(args.num_GPUs, args.force_GPU, logger)
    model, model_func = None, None
    if args.one_shot:
        # Model is initialized for each sleep study later
        def model_func(n_periods):
            return get_and_load_one_shot_model(n_periods, project_dir,
                                               hparams, logger,
                                               args.weights_file_name)
    else:
        model = get_and_load_model(project_dir, hparams,
                                   logger, args.weights_file_name)

    # Run predictions on all datasets
    datasets = get_splits_from_all_datasets(hparams=hparams,
                                            splits_to_load=(args.data_split,),
                                            logger=logger)
    eval_dirs = []
    for dataset in datasets:
        dataset = dataset[0]
        if "/" in dataset.identifier:
            # Multiple datasets, separate results into sub-folders
            ds_out_dir = os.path.join(out_dir,
                                      dataset.identifier.split("/")[0])
            if not os.path.exists(ds_out_dir):
                os.mkdir(ds_out_dir)
            eval_dirs.append(ds_out_dir)
        else:
            ds_out_dir = out_dir
        logger("[*] Running eval on dataset {}\n"
               "    Out dir: {}".format(dataset, ds_out_dir))
        run_pred_and_eval(dataset=dataset,
                          out_dir=ds_out_dir,
                          model=model,
                          model_func=model_func,
                          hparams=hparams,
                          args=args,
                          logger=logger)
    if len(eval_dirs) > 1:
        cross_dataset_eval(eval_dirs, out_dir)

def stage_string_to_class(stage_string):
    # Map a (case-insensitive) stage string, e.g. "n1", to its class integer
    return Defaults.get_stage_string_to_class_int()[stage_string.upper()]

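# Example (the exact integer returned depends on Defaults' stage mapping):
print(stage_string_to_class("n1"))  # the class int registered for "N1"
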
def run(args):
    """
    Run the script according to args - Please refer to the argparser.

    Args:
        args: (Namespace) command-line arguments
    """
    from mpunet.logging import Logger
    from utime.hyperparameters import YAMLHParams
    from utime.utils.scriptutils import assert_project_folder
    from utime.utils.scriptutils import get_splits_from_all_datasets

    project_dir = os.path.abspath("./")
    assert_project_folder(project_dir)

    # Get logger object
    logger = Logger(project_dir + "/preprocessing_logs",
                    active_file='preprocessing',
                    overwrite_existing=args.overwrite,
                    no_sub_folder=True)
    logger("Args dump: {}".format(vars(args)))

    # Load hparams
    hparams = YAMLHParams(Defaults.get_hparams_path(project_dir),
                          logger=logger, no_version_control=True)

    # Initialize and load (potentially multiple) datasets
    datasets = get_splits_from_all_datasets(hparams,
                                            splits_to_load=args.dataset_splits,
                                            logger=logger,
                                            return_data_hparams=True)

    # Check if the output file exists, and overwrite if specified
    if os.path.exists(args.out_path):
        if args.overwrite:
            os.remove(args.out_path)
        else:
            from sys import exit
            logger("Out file at {} exists, and --overwrite was not set."
                   "".format(args.out_path))
            exit(0)

    # Create dataset hparams output directory
    out_dir = Defaults.get_pre_processed_data_configurations_dir(project_dir)
    if not os.path.exists(out_dir):
        os.mkdir(out_dir)

    with ThreadPoolExecutor(args.num_threads) as pool:
        with h5py.File(args.out_path, "w") as h5_file:
            for dataset, dataset_hparams in datasets:
                # Create a new version of the dataset-specific hyperparameters
                # that contains only the fields needed for pre-processed data
                name = dataset[0].identifier.split("/")[0]
                hparams_out_path = os.path.join(out_dir, name + ".yaml")
                copy_dataset_hparams(dataset_hparams, hparams_out_path)

                # Update paths to dataset hparams in main hparams file
                hparams.set_value(subdir='datasets',
                                  name=name,
                                  value=hparams_out_path,
                                  overwrite=True)
                # Save the hyperparameters to the pre-processed main hparams
                hparams.save_current(
                    Defaults.get_pre_processed_hparams_path(project_dir)
                )

                # Process each dataset split
                for split in dataset:
                    # Add this split to the dataset-specific hparams
                    add_dataset_entry(hparams_out_path,
                                      args.out_path,
                                      split.identifier.split("/")[-1].lower(),
                                      split.period_length_sec)
                    # Overwrite potential load-time channel sampler to None
                    split.set_load_time_channel_sampling_groups(None)

                    # Create dataset group
                    split_group = h5_file.create_group(split.identifier)

                    # Run the preprocessing
                    process_func = partial(preprocess_study, split_group)
                    logger.print_to_screen = True
                    logger("Preprocessing dataset:", split)
                    logger.print_to_screen = False
                    n_pairs = len(split.pairs)
                    for i, _ in enumerate(pool.map(process_func,
                                                   split.pairs)):
                        print("  {}/{}".format(i + 1, n_pairs),
                              end='\r', flush=True)
                    print("")

def plot_hypnogram(hyp_array, true_hyp_array=None, seconds_per_epoch=30,
                   annotation_dict=None, show_f1_scores=True,
                   wake_trim_min=None, order=("N3", "N2", "N1", "REM", "W")):
    """
    Plot an ndarray hypnogram of integers, 'hyp_array', optionally on top of
    an expert-annotated hypnogram 'true_hyp_array'.

    Args:
        hyp_array:         ndarray, shape [N]
        true_hyp_array:    ndarray, shape [N] (optional, default=None)
        seconds_per_epoch: integer, default=30
        annotation_dict:   dict, integer -> stage string mapping
        show_f1_scores:    bool, annotate the figure with F1 scores; only
                           used if true_hyp_array is set
        wake_trim_min:     None or integer; if set, the maximum number of
                           minutes of wake preceding the first and following
                           the last non-wake sleep stage in the TRUE array
                           to keep ('zoom in on')
        order:             list-like of strings, order of sleep stages on
                           the plot y-axis

    Returns:
        fig, axes
    """
    rows = 1 if true_hyp_array is None else 2
    height = 3 if true_hyp_array is None else 4
    fig, axes = plt.subplots(nrows=rows, figsize=(10, height),
                             sharex=True, sharey=True)
    if not isinstance(axes, (list, np.ndarray)):
        axes = [axes]

    # Map classes to default string classes
    annotation_dict = annotation_dict or Defaults.get_class_int_to_stage_string()
    str_to_int_map = {value: key for key, value in annotation_dict.items()}

    # Define x-axis range in hours
    x_hours = np.array([seconds_per_epoch * i
                        for i in range(len(hyp_array))]) / 3600
    if wake_trim_min and true_hyp_array is not None:
        trim = int((60 / seconds_per_epoch) * wake_trim_min)
        inds = np.where(true_hyp_array != str_to_int_map["W"])[0]
        start = max(0, inds[0] - trim)
        end = inds[-1] + trim
        hyp_array = hyp_array[start:end]
        true_hyp_array = true_hyp_array[start:end]
        x_hours = x_hours[start:end]

    reordered_hyp_array = get_reordered_hypnogram(hyp_array,
                                                  annotation_dict, order)
    axes[0].step(x_hours, reordered_hyp_array, where='post',
                 color="black", label="Predicted")

    # Set y-axis labels
    for ax in axes:
        ax.set_yticks(range(len(order)))
        ax.set_yticklabels(order)
        ax.tick_params(axis='both', which='major', labelsize=14)
        ax.spines['right'].set_visible(False)
        ax.spines['top'].set_visible(False)
    if true_hyp_array is not None:
        fig.text(x=0.02, y=0.56, s="Sleep Stage", ha="center", va="center",
                 fontdict={"size": 16}, rotation=90)
    else:
        axes[0].set_ylabel("Sleep Stage", labelpad=12, size=16)
    axes[-1].set_xlabel("Time (hours)", size=16, labelpad=12)
    axes[-1].set_xlim(x_hours[0], x_hours[-1])
    fig.tight_layout()

    if true_hyp_array is not None:
        reordered_true = get_reordered_hypnogram(true_hyp_array,
                                                 annotation_dict, order)
        axes[1].step(x_hours, reordered_true, where='post',
                     color="darkred", label="Expert")
        if show_f1_scores:
            from sklearn.metrics import f1_score
            labels = [str_to_int_map[w] for w in reversed(order)]
            mask = np.isin(true_hyp_array.flatten(),
                           np.array(labels).flatten())
            f1s = f1_score(true_hyp_array[mask], hyp_array[mask],
                           labels=labels, average=None)
            f1s = [round(f1, 2) for f1 in (list(f1s) + [np.mean(f1s)])]
            f1_labels = list(reversed(order)) + ["Mean"]
            # Plot title
            fig.text(x=0.845, y=0.66, s="F1-scores", ha="left", va="top",
                     fontdict={"alpha": 1, "size": 14})
            for i, (stage, value) in enumerate(zip(f1_labels, f1s)):
                # Plot stage name
                fig.text(x=0.845, y=0.58 - (0.05 * i), s=f"{stage}",
                         ha="left", va="top",
                         fontdict={"alpha": 1, "size": 14})
                # Plot score
                fig.text(x=0.9475, y=0.58 - (0.05 * i),
                         s="{:.2f}".format(value),
                         ha="left", va="top",
                         fontdict={"alpha": 1, "size": 14})
        lines_labels = [ax.get_legend_handles_labels() for ax in fig.axes]
        lines, labels = [sum(lol, []) for lol in zip(*lines_labels)]
        legend = fig.legend(lines, labels, loc='right',
                            bbox_to_anchor=(1.01, 0.89),
                            ncol=1, fontsize=14)
        legend.get_frame().set_linewidth(0)
        fig.subplots_adjust(hspace=0.13, right=0.81, left=0.10)
    return fig, axes

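# Runnable sketch with synthetic hypnograms; the int -> stage mapping below
# is an explicit stand-in for Defaults.get_class_int_to_stage_string() and
# may not match the library's actual integers:
import numpy as np

_int_to_stage = {0: "W", 1: "N1", 2: "N2", 3: "N3", 4: "REM"}
_rng = np.random.RandomState(42)
_pred = _rng.randint(0, 5, 960)   # ~8 hours of 30 s epochs
_true = _rng.randint(0, 5, 960)
_fig, _axes = plot_hypnogram(_pred, true_hyp_array=_true,
                             annotation_dict=_int_to_stage)
_fig.savefig("hypnogram_demo.png")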