Example #1
    def plot_periods(self,
                     period_idxs=None,
                     period_secs=None,
                     out_path=None,
                     highlight_periods=True):
        """
        Plot multiple periods of data by indices or seconds

        Args:
            period_idxs:        Indices for all periods to plot
            period_secs:        The starting seconds of the periods to plot
            out_path:           Path to save the figure to
            highlight_periods:  Plot period-separating vertical lines
        """
        if (period_idxs is None) == (period_secs is None):
            raise ValueError("Must specify exactly one of period_idxs and period_secs.")
        from utime.visualization.psg_plotting import plot_periods
        period_secs = list(period_secs
                           or map(self.period_idx_to_sec, period_idxs))
        if any(np.diff(period_secs) != self.period_length_sec):
            raise ValueError("Periods to plot must be consecutive.")
        x, y = zip(*map(self.get_period_at_sec, period_secs))
        plot_periods(
            X=x,
            y=[Defaults.get_class_int_to_stage_string()[y_] for y_ in y],
            channel_names=self.select_channels,
            init_second=period_secs[0],
            sample_rate=self.sample_rate,
            out_path=out_path,
            highlight_periods=highlight_periods)
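
A minimal usage sketch, assuming 'study' is an already-loaded study object exposing this method and that the requested (consecutive) periods exist; names and paths are illustrative:

# Hedged usage sketch; 'study' is an assumed, already-loaded study object
study.plot_periods(period_idxs=[10, 11, 12],
                   out_path="periods_10_12.png",
                   highlight_periods=True)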
Example #2
def assert_project_folder(project_folder, evaluation=False):
    """
    Raises RuntimeError if a folder 'project_folder' does not seem to be a
    valid U-Time folder in the training phase (evaluation=False) or evaluation
    phase (evaluation=True).

    Args:
        project_folder: A path to a folder to check for U-Time compatibility.
        evaluation:     Apply evaluation-time checks (True) or
                        training-time checks (False).
    """
    import os
    import glob
    project_folder = os.path.abspath(project_folder)
    if not os.path.exists(Defaults.get_hparams_path(project_folder)):
        # Folder must contain a 'hparams.yaml' file in all cases.
        raise RuntimeError("Folder {} is not a valid project folder."
                           " Must contain a 'hparams.yaml' "
                           "file.".format(project_folder))
    if evaluation:
        # Folder must contain a 'model' subfolder storing saved model files
        model_path = os.path.join(project_folder, "model")
        if not os.path.exists(model_path):
            raise RuntimeError("Folder {} is not a valid project "
                               "folder. Must contain a 'model' "
                               "subfolder.".format(project_folder))
        # There must be at least one model file (.h5) in the folder
        models = glob.glob(os.path.join(model_path, "*.h5"))
        if not models:
            raise RuntimeError("Did not find any model parameter files in "
                               "model subfolder {}. Model files should have"
                               " extension '.h5' to "
                               "be recognized.".format(model_path))
Example #3
    def __init__(self,
                 annotation_dict,
                 period_length_sec,
                 no_hypnogram,
                 logger=None):
        """
        TODO

        Args:
            annotation_dict:
            period_length_sec:
            no_hypnogram:
            logger:
        """
        self.logger = logger or ScreenLogger()
        self.annotation_dict = annotation_dict
        self.no_hypnogram = no_hypnogram
        self.period_length_sec = period_length_sec or \
            Defaults.get_default_period_length(self.logger)

        # Hidden attributes controlled in property functions to limit setting
        # of these values to the load() function
        self._psg = None
        self._hypnogram = None
        self._select_channels = None
        self._alternative_select_channels = None
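
A hedged construction sketch; 'SleepStudyBase' is a hypothetical name for the class this initializer belongs to, and the annotation mapping is illustrative, not from the source:

# Hypothetical class name and illustrative mapping, not from the source
study = SleepStudyBase(annotation_dict={"Sleep stage W": 0,
                                        "Sleep stage N1": 1},
                       period_length_sec=30,
                       no_hypnogram=False)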
Example #4
def get_translated_triplet_rules(translation_map=None):
    """
    Translate the stage strings in TRIPLET_RULES into integer class labels
    using 'translation_map' (defaults to the standard stage string ->
    class int mapping).
    """
    translation_map = translation_map or Defaults.get_stage_string_to_class_int()
    # Vectorized lookup; supports scalar strings as well as arrays of strings
    translate = np.vectorize(translation_map.get)
    new_rules = []
    for s1, s2 in TRIPLET_RULES:
        new_rules.append((translate(s1), translate(s2)))
    return new_rules
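
A usage sketch; the contents of TRIPLET_RULES and of the default mapping are not shown in the source, so only the call pattern is illustrated and the stage encoding below is an assumption:

# Call pattern only; rule contents depend on TRIPLET_RULES and the map
custom_map = {"W": 0, "N1": 1, "N2": 2, "N3": 3, "REM": 4}  # assumed encoding
int_rules = get_translated_triplet_rules(translation_map=custom_map)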
Example #5
def run(args, return_prediction=False, dump_args=None):
    """
    Run the script according to args - Please refer to the argparser.
    """
    assert_args(args)
    # Check project folder is valid
    from utime.utils.scriptutils.scriptutils import assert_project_folder
    project_dir = os.path.abspath(args.project_dir)
    assert_project_folder(project_dir, evaluation=True)

    # Get a logger
    logger = get_logger(project_dir, True, name="prediction_log")
    if dump_args:
        logger("Args dump: \n{}".format(vars(args)))

    # Get hyperparameters and init all described datasets
    from utime.hyperparameters import YAMLHParams
    hparams = YAMLHParams(Defaults.get_hparams_path(project_dir),
                          logger,
                          no_version_control=True)

    # Get the sleep study
    logger("Loading and pre-processing PSG file...")
    hparams['prediction_params']['channels'] = args.channels
    study, channel_groups = get_sleep_study(psg_path=args.f,
                                            logger=logger,
                                            **hparams['prediction_params'])

    # Set GPU and get model
    set_gpu_vis(args.num_GPUs, args.force_GPU, logger)
    hparams["build"]["data_per_prediction"] = args.data_per_prediction
    logger("Predicting with {} data per prediction".format(
        args.data_per_prediction))
    model = get_and_load_one_shot_model(
        n_periods=study.n_periods,
        project_dir=project_dir,
        hparams=hparams,
        logger=logger,
        weights_file_name=hparams.get_from_anywhere('weight_file_name'))
    logger("Predicting...")
    pred = predict_study(study, model, channel_groups, args.no_argmax, logger)
    logger("--> Predicted shape: {}".format(pred.shape))
    if return_prediction:
        return pred
    else:
        save_prediction(pred=pred,
                        out_path=args.o,
                        input_file_path=study.psg_file_path,
                        logger=logger)
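
A hedged invocation sketch; the Namespace fields mirror the attributes accessed in the function body above, but their exact CLI spellings, defaults, and any extra fields required by assert_args are assumptions:

# Hedged sketch; fields mirror attributes read above, not a confirmed CLI
from argparse import Namespace
pred = run(Namespace(project_dir="./my_utime_project",
                     f="subject_01.edf",
                     o="subject_01_pred.npy",
                     channels=["EEG Fpz-Cz"],
                     data_per_prediction=128,
                     num_GPUs=0,
                     force_GPU="",
                     no_argmax=False),
           return_prediction=True)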
Example #6
def replace_before_with(scores,
                        before_stage,
                        replace,
                        to,
                        stage_string_to_class_int_map=None):
    """
    Replace all occurrences of stage 'replace' with stage 'to' in the
    section of 'scores' preceding the first occurrence of 'before_stage'.
    Stages may be passed as strings or as integer class labels.
    """
    scores = scores.copy()
    if isinstance(before_stage, str):
        # Translate stage strings to integer class labels
        stage_map = (stage_string_to_class_int_map
                     or Defaults.get_stage_string_to_class_int())
        before_stage, replace, to = map(lambda s: stage_map[s],
                                        (before_stage, replace, to))
    first_appearance = np.where(scores == before_stage)[0]
    if len(first_appearance):
        scores_considered = scores[:first_appearance[0]]
        scores_considered = np.where(scores_considered == replace, to,
                                     scores_considered)
        scores[:first_appearance[0]] = scores_considered
    return scores
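
A worked sketch with an explicit stage map passed in, so the integer encoding is unambiguous (the encoding itself is an assumption, not the package default):

# Hedged sketch with an explicit (assumed) stage encoding
import numpy as np
stage_map = {"W": 0, "N1": 1, "N2": 2, "N3": 3, "REM": 4}
scores = np.array([0, 2, 2, 1, 2, 4])
fixed = replace_before_with(scores, before_stage="N1", replace="N2", to="W",
                            stage_string_to_class_int_map=stage_map)
# N2 epochs before the first N1 become W: array([0, 0, 0, 1, 2, 4])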
Example #7
def standardize_stage_string(stage_string):
    """
    Attempts to map a string representing a sleep stage, allowing some
    (unambiguous) variability, to a fixed, standardized string
    representation.

    Standardized strings:
        "W" (awake)
        "N1" (non-rem sleep stage 1)
        "N2" (non-rem sleep stage 2)
        "N3" (non-rem sleep stage 3)
        "REM" (REM sleep)
        "UNKNOWN" (all other)

    Args:
        stage_string: A string representing a sleep stage

    Returns:
        string, A standardized string representing a sleep stage
    """
    ss = stage_string.strip().upper()

    # Check various types of matches
    matches = []
    for match_func in (check_number_match, check_wake_match,
                       check_REM_match, check_unknown_match):
        possible_match, match_value = match_func(ss)
        if possible_match:
            matches.append(match_value)

    # If exactly 1 match was found, return this, otherwise raise an error
    n_matches = len(matches)
    if n_matches == 1:
        match = matches[0]
        print("[OBS]: Mapping variable stage string '{:s}' to stage '{}' "
              "(class int {})".format(stage_string, match,
                                      Defaults.get_stage_string_to_class_int()[match]))
        return match
    elif n_matches == 0:
        raise_match_error(ss, "Found no valid matches.")
    else:
        raise_match_error(ss, "Found multiple ({}) "
                              "valid matches {}.".format(n_matches, matches))
Example #8
    def plot_period(self, period_idx=None, period_sec=None, out_path=None):
        """
        Plot a period of data by index or second

        Args:
            period_idx: Period index to plot
            period_sec: The starting second of the period to plot
            out_path:   Path to save the figure to
        """
        # Use explicit 'is None' checks so that index/second 0 is accepted
        if (period_idx is None) == (period_sec is None):
            raise ValueError("Must specify exactly one of period_idx and period_sec.")
        from utime.visualization.psg_plotting import plot_period
        period_sec = (period_sec if period_sec is not None
                      else self.period_idx_to_sec(period_idx))
        x, y = self.get_period_at_sec(period_sec)
        plot_period(X=x,
                    y=Defaults.get_class_int_to_stage_string()[y],
                    channel_names=self.select_channels,
                    init_second=period_sec,
                    sample_rate=self.sample_rate,
                    out_path=out_path)
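
With the 'is None' checks above, the first period (index 0) is also plottable; 'study' is an assumed, already-loaded study object:

# Hedged usage sketch; 'study' is an assumed, already-loaded study object
study.plot_period(period_idx=0, out_path="first_period.png")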
Example #9
def check_number_match(ss):
    possible_match, match_value = False, None

    # Check for number referring to the sleep stage
    numbers = list(map(int, re.findall(r"\d+", ss)))
    if len(numbers) not in (0, 1):
        raise_match_error(ss, "Found multiple numbers in string")
    elif len(numbers) == 1:
        num = numbers[0]
        # Map R&K-style stage numbers to AASM stage strings; the old
        # stage 4 is merged into N3 under the AASM standard
        valid_map = {1: Defaults.NON_REM_STAGE_1[0],
                     2: Defaults.NON_REM_STAGE_2[0],
                     3: Defaults.NON_REM_STAGE_3[0],
                     4: Defaults.NON_REM_STAGE_3[0]}
        assert np.all(np.isin(list(valid_map.values()),
                              list(Defaults.get_stage_string_to_class_int().keys())))
        if num in valid_map:
            possible_match = True
            match_value = valid_map[num]
        else:
            raise_match_error(ss, "Found invalid number {} in string".format(num))

    return possible_match, match_value
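
A hedged example; the returned stage string depends on the Defaults constants ("N3" is assumed here):

# Assuming Defaults.NON_REM_STAGE_3[0] == "N3" (old stage 4 merges into N3)
possible, value = check_number_match("STAGE 4")
print(possible, value)  # expected: True N3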
Example #10
    def __init__(self,
                 identifier,
                 pairs=None,
                 period_length_sec=None,
                 logger=None,
                 no_log=False):
        """
        Initialize a SleepStudyDataset from a directory storing one or more
        sub-directories each corresponding to a sleep/PSG study.
        Each sub-dir will be represented by a SleepStudy object.

        Args:
            pairs:                   (list)   SleepStudy objects to include
                                              in this dataset (may be empty)
            period_length_sec:       (int)    Ground truth segmentation
                                              period length in seconds.
            identifier:              (string) Dataset ID/name
            logger:                  (Logger) A Logger object
            no_log:                  (bool)   Do not log dataset details on init
        """
        self.logger = logger or ScreenLogger()
        self._identifier = identifier
        self._id_to_study = None
        self._study_identifiers = None
        self._pairs = pairs or []
        self._misc = {}  # May store arbitrary properties for this dataset
        self.period_length_sec = (period_length_sec
                                  or Defaults.get_default_period_length(
                                      self.logger))

        # Check that all studies in the dataset have unique identifiers
        if len(np.unique([p.identifier
                          for p in self.pairs])) != len(self.pairs):
            raise RuntimeError("Two or more SleepStudy objects share the same"
                               " identifier, but all must be unique.")
        if self.pairs:
            self.update_id_to_study_dict()
        if not no_log:
            self.log()
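
A hedged construction sketch; 'load_pairs' is a hypothetical helper producing SleepStudy objects, not part of the source:

# 'load_pairs' is a hypothetical helper, not part of the source
pairs = load_pairs("data/sleep-edf/train")
dataset = SleepStudyDataset(identifier="sleep-edf/TRAIN",
                            pairs=pairs,
                            period_length_sec=30)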
Example #11
def plot_confusion_matrix(y_true,
                          y_pred,
                          n_classes,
                          normalize=False,
                          id_=None,
                          cmap="Blues"):
    """
    Adapted from sklearn 'plot_confusion_matrix.py'.

    This function prints and plots the confusion matrix.
    Normalization can be applied by setting `normalize=True`.
    """
    from sklearn.metrics import confusion_matrix
    from sklearn.utils.multiclass import unique_labels
    if normalize:
        title = 'Normalized confusion matrix for identifier {}'.format(
            id_ or "???")
    else:
        title = 'Confusion matrix (not normalized) ' \
                'for identifier {}'.format(id_ or "???")

    # Compute confusion matrix
    classes = np.arange(n_classes)
    cm = confusion_matrix(y_true, y_pred)
    classes = classes[unique_labels(y_true, y_pred)]
    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]

    # Get transformed labels
    from utime import Defaults
    labels = [Defaults.get_class_int_to_stage_string()[i] for i in classes]

    fig, ax = plt.subplots()
    im = ax.imshow(cm, interpolation='nearest', cmap=plt.get_cmap(cmap))
    ax.figure.colorbar(im, ax=ax)
    # We want to show all ticks...
    ax.set(
        xticks=np.arange(cm.shape[1]),
        yticks=np.arange(cm.shape[0]),
        # ... and label them with the respective list entries
        xticklabels=labels,
        yticklabels=labels,
        title=title,
        ylabel='True label',
        xlabel='Predicted label')

    # Rotate the tick labels and set their alignment.
    plt.setp(ax.get_xticklabels(),
             rotation=45,
             ha="right",
             rotation_mode="anchor")

    # Loop over data dimensions and create text annotations.
    fmt = '.3f' if normalize else 'd'
    thresh = cm.max() / 2.
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            ax.text(j,
                    i,
                    format(cm[i, j], fmt),
                    ha="center",
                    va="center",
                    color="white" if cm[i, j] > thresh else "black")
    fig.tight_layout()
    return fig, ax
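
A self-contained sketch with random labels, assuming this runs in the module context (with np and plt imported) and that the default class-int to stage-string mapping covers classes 0-4:

# Sketch with random labels; assumes the default mapping covers classes 0-4
import numpy as np
rng = np.random.default_rng(0)
y_true = rng.integers(0, 5, 1000)
y_pred = rng.integers(0, 5, 1000)
fig, ax = plot_confusion_matrix(y_true, y_pred, n_classes=5,
                                normalize=True, id_="demo")
fig.savefig("confusion_matrix.png")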
Example #12
def run(args):
    """
    Run the script according to args - Please refer to the argparser.
    """
    assert_args(args)
    # Check project folder is valid
    from utime.utils.scriptutils import (assert_project_folder,
                                         get_splits_from_all_datasets)
    project_dir = os.path.abspath(args.project_dir)
    assert_project_folder(project_dir, evaluation=True)

    # Prepare output dir
    out_dir = get_out_dir(args.out_dir, args.data_split)
    prepare_output_dir(out_dir, args.overwrite)
    logger = get_logger(out_dir, args.overwrite)
    logger("Args dump: \n{}".format(vars(args)))

    # Get hyperparameters and init all described datasets
    from utime.hyperparameters import YAMLHParams
    hparams = YAMLHParams(Defaults.get_hparams_path(project_dir), logger)
    if args.channels:
        hparams["select_channels"] = args.channels
        hparams["channel_sampling_groups"] = None
        logger("Evaluating using channels {}".format(args.channels))

    # Get model
    set_gpu_vis(args.num_GPUs, args.force_GPU, logger)
    model, model_func = None, None
    if args.one_shot:
        # Model is initialized for each sleep study later
        def model_func(n_periods):
            return get_and_load_one_shot_model(n_periods,
                                               project_dir,
                                               hparams,
                                               logger,
                                               args.weights_file_name)
    else:
        model = get_and_load_model(project_dir, hparams, logger,
                                   args.weights_file_name)

    # Run predictions on all datasets
    datasets = get_splits_from_all_datasets(hparams=hparams,
                                            splits_to_load=(args.data_split,),
                                            logger=logger)
    eval_dirs = []
    for dataset in datasets:
        dataset = dataset[0]
        if "/" in dataset.identifier:
            # Multiple datasets, separate results into sub-folders
            ds_out_dir = os.path.join(out_dir,
                                      dataset.identifier.split("/")[0])
            if not os.path.exists(ds_out_dir):
                os.mkdir(ds_out_dir)
            eval_dirs.append(ds_out_dir)
        else:
            ds_out_dir = out_dir
        logger("[*] Running eval on dataset {}\n"
               "    Out dir: {}".format(dataset, ds_out_dir))
        run_pred_and_eval(dataset=dataset,
                          out_dir=ds_out_dir,
                          model=model, model_func=model_func,
                          hparams=hparams,
                          args=args,
                          logger=logger)
    if len(eval_dirs) > 1:
        cross_dataset_eval(eval_dirs, out_dir)
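
A hedged Namespace sketch mirroring the attributes read above; exact flag spellings, defaults, and any extra fields required by assert_args are assumptions:

# Hedged sketch; fields mirror attributes read above, not a confirmed CLI
from argparse import Namespace
run(Namespace(project_dir="./my_utime_project",
              out_dir="evaluation",
              data_split="val_data",
              overwrite=False,
              channels=None,
              num_GPUs=1,
              force_GPU="",
              one_shot=False,
              weights_file_name=None))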
Example #13
def stage_string_to_class(stage_string):
    """Map a (case-insensitive) sleep stage string to its integer class label."""
    return Defaults.get_stage_string_to_class_int()[stage_string.upper()]
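
For example (the resulting integer depends on the Defaults stage-string to class-int mapping):

# Result depends on the Defaults stage-string -> class-int mapping
n2_class = stage_string_to_class("n2")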
Example #14
def run(args):
    """
    Run the script according to args - Please refer to the argparser.

    Args:
        args:    (Namespace)  command-line arguments
    """
    from mpunet.logging import Logger
    from utime.hyperparameters import YAMLHParams
    from utime.utils.scriptutils import assert_project_folder
    from utime.utils.scriptutils import get_splits_from_all_datasets

    project_dir = os.path.abspath("./")
    assert_project_folder(project_dir)

    # Get logger object
    logger = Logger(project_dir + "/preprocessing_logs",
                    active_file='preprocessing',
                    overwrite_existing=args.overwrite,
                    no_sub_folder=True)
    logger("Args dump: {}".format(vars(args)))

    # Load hparams
    hparams = YAMLHParams(Defaults.get_hparams_path(project_dir),
                          logger=logger,
                          no_version_control=True)

    # Initialize and load (potentially multiple) datasets
    datasets = get_splits_from_all_datasets(hparams,
                                            splits_to_load=args.dataset_splits,
                                            logger=logger,
                                            return_data_hparams=True)

    # Check if file exists, and overwrite if specified
    if os.path.exists(args.out_path):
        if args.overwrite:
            os.remove(args.out_path)
        else:
            from sys import exit
            logger("Out file at {} exists, and --overwrite was not set."
                   "".format(args.out_path))
            exit(0)

    # Create dataset hparams output directory
    out_dir = Defaults.get_pre_processed_data_configurations_dir(project_dir)
    if not os.path.exists(out_dir):
        os.mkdir(out_dir)

    with ThreadPoolExecutor(args.num_threads) as pool:
        with h5py.File(args.out_path, "w") as h5_file:
            for dataset, dataset_hparams in datasets:
                # Create a new version of the dataset-specific hyperparameters
                # that contain only the fields needed for pre-processed data
                name = dataset[0].identifier.split("/")[0]
                hparams_out_path = os.path.join(out_dir, name + ".yaml")
                copy_dataset_hparams(dataset_hparams, hparams_out_path)

                # Update paths to dataset hparams in main hparams file
                hparams.set_value(subdir='datasets',
                                  name=name,
                                  value=hparams_out_path,
                                  overwrite=True)
                # Save the hyperparameters to the pre-processed main hparams
                hparams.save_current(
                    Defaults.get_pre_processed_hparams_path(project_dir))

                # Process each dataset
                for split in dataset:
                    # Add this split to the dataset-specific hparams
                    add_dataset_entry(hparams_out_path, args.out_path,
                                      split.identifier.split("/")[-1].lower(),
                                      split.period_length_sec)
                    # Overwrite potential load time channel sampler to None
                    split.set_load_time_channel_sampling_groups(None)

                    # Create dataset group
                    split_group = h5_file.create_group(split.identifier)

                    # Run the preprocessing
                    process_func = partial(preprocess_study, split_group)

                    logger.print_to_screen = True
                    logger("Preprocessing dataset:", split)
                    logger.print_to_screen = False
                    n_pairs = len(split.pairs)
                    for i, _ in enumerate(pool.map(process_func, split.pairs)):
                        print("  {}/{}".format(i + 1, n_pairs),
                              end='\r',
                              flush=True)
                    print("")
Example #15
def plot_hypnogram(hyp_array,
                   true_hyp_array=None,
                   seconds_per_epoch=30,
                   annotation_dict=None,
                   show_f1_scores=True,
                   wake_trim_min=None,
                   order=("N3", "N2", "N1", "REM", "W")):
    """
    Plot an ndarray hypnogram of integers, 'hyp_array', optionally above an
    expert-annotated hypnogram 'true_hyp_array'.

    Args:
        hyp_array:         ndarray, shape [N]
        true_hyp_array:    ndarray, shape [N] (optional, default=None)
        seconds_per_epoch: integer, default=30
        annotation_dict:   dict, integer -> stage string mapping
        show_f1_scores:    bool, annotate the figure with f1 scores, only if true_hyp_array is set
        wake_trim_min:     None or integer; if set, show at most this many
                           minutes of wake before the first and after the last
                           non-wake stage in the TRUE array ('zoom' in).
        order:             list-like of strings, order of sleep stages on plot y-axis

    Returns:
        fig, axes
    """
    rows = 1 if true_hyp_array is None else 2
    height = 3 if true_hyp_array is None else 4
    fig, axes = plt.subplots(nrows=rows,
                             figsize=(10, height),
                             sharex=True,
                             sharey=True)
    if not isinstance(axes, (list, np.ndarray)):
        axes = [axes]

    # Map classes to default string classes
    annotation_dict = (annotation_dict
                       or Defaults.get_class_int_to_stage_string())
    str_to_int_map = {value: key for key, value in annotation_dict.items()}

    # Define range in hours
    x_hours = np.array([seconds_per_epoch * i
                        for i in range(len(hyp_array))]) / 3600

    if wake_trim_min and true_hyp_array is not None:
        trim = int((60 / seconds_per_epoch) * wake_trim_min)
        inds = np.where(true_hyp_array != str_to_int_map["W"])[0]
        start = max(0, inds[0] - trim)
        end = inds[-1] + trim
        hyp_array = hyp_array[start:end]
        true_hyp_array = true_hyp_array[start:end]
        x_hours = x_hours[start:end]

    reordered_hyp_array = get_reordered_hypnogram(hyp_array, annotation_dict,
                                                  order)
    axes[0].step(x_hours,
                 reordered_hyp_array,
                 where='post',
                 color="black",
                 label="Predicted")

    # Set ylabels
    for ax in axes:
        ax.set_yticks(range(len(order)))
        ax.set_yticklabels(order)
        ax.tick_params(axis='both', which='major', labelsize=14)
        ax.spines['right'].set_visible(False)
        ax.spines['top'].set_visible(False)
    if true_hyp_array is not None:
        fig.text(x=0.02,
                 y=0.56,
                 s="Sleep Stage",
                 ha="center",
                 va="center",
                 fontdict={"size": 16},
                 rotation=90)
    else:
        axes[0].set_ylabel("Sleep Stage", labelpad=12, size=16)
    axes[-1].set_xlabel("Time (hours)", size=16, labelpad=12)
    axes[-1].set_xlim(x_hours[0], x_hours[-1])

    fig.tight_layout()
    if true_hyp_array is not None:
        reordered_true = get_reordered_hypnogram(true_hyp_array,
                                                 annotation_dict, order)
        axes[1].step(x_hours,
                     reordered_true,
                     where='post',
                     color="darkred",
                     label="Expert")

        if show_f1_scores:
            from sklearn.metrics import f1_score
            labels = [str_to_int_map[w] for w in reversed(order)]
            mask = np.isin(true_hyp_array.flatten(),
                           np.array(labels).flatten())
            f1s = f1_score(true_hyp_array[mask],
                           hyp_array[mask],
                           labels=labels,
                           average=None)
            f1s = [round(score, 2) for score in (list(f1s) + [np.mean(f1s)])]
            f1_labels = list(reversed(order)) + ["Mean"]

            # Plot title
            fig.text(x=0.845,
                     y=0.66,
                     s="F1-scores",
                     ha="left",
                     va="top",
                     fontdict={
                         "alpha": 1,
                         "size": 14
                     })
            for i, (stage, value) in enumerate(zip(f1_labels, f1s)):
                # Plot stage
                fig.text(x=0.845,
                         y=0.58 - (0.05 * i),
                         s=f"{stage}",
                         ha="left",
                         va="top",
                         fontdict={
                             "alpha": 1,
                             "size": 14
                         })
                # Plot score
                fig.text(x=0.9475,
                         y=0.58 - (0.05 * i),
                         s="{:.2f}".format(value),
                         ha="left",
                         va="top",
                         fontdict={
                             "alpha": 1,
                             "size": 14
                         })
        lines_labels = [ax.get_legend_handles_labels() for ax in fig.axes]
        lines, labels = [sum(ll, []) for ll in zip(*lines_labels)]
        legend = fig.legend(lines,
                            labels,
                            loc='right',
                            bbox_to_anchor=(1.01, 0.89),
                            ncol=1,
                            fontsize=14)
        legend.get_frame().set_linewidth(0)
        fig.subplots_adjust(hspace=0.13, right=0.81, left=0.10)
    return fig, axes
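
A self-contained sketch with random stages, assuming the default int-to-stage mapping covers classes 0-4 and that np and plt are imported in the module context:

# Sketch with random stages; assumes default mapping covers classes 0-4
import numpy as np
rng = np.random.default_rng(0)
pred = rng.integers(0, 5, 960)   # ~8 hours of 30-second epochs
true = rng.integers(0, 5, 960)
fig, axes = plot_hypnogram(pred, true_hyp_array=true, wake_trim_min=30)
fig.savefig("hypnogram.png")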