Beispiel #1
0
def fix_labels(signals, beats, labels):
    """
    Relabel beats that are not located at a local maximum of the signal.

    Beat indices of some normal beats do not fall on a local maximum of
    the ECG signal in the MIT-BIH Arrhythmia database. For every record
    the annotated beat locations are corrected upwards (search radius of
    5 samples) and compared with the corrected local maxima of the
    signal. Beats that do not end up on one of those maxima get the
    label -1.

    The input label arrays are left untouched; a modified copy is
    returned for every record.

    Parameters
    ----------
    signals : list
        List of ECG signals as numpy arrays
    beats : list
        List of numpy arrays that store beat locations
    labels : list
        List of numpy arrays that store beat types

    Returns
    -------
    fixed_labels : list
        List of numpy arrays (copies of the input labels) where -1 has
        been set for beats that are not located in local maxima

    """
    fixed_labels = []
    for sig, beat_inds, beat_labels in zip(signals, beats, labels):

        # Find the local maxima of the signal and snap them upwards
        localmax = find_local_peaks(sig=sig, radius=5)
        localmax = correct_peaks(sig=sig,
                                 peak_inds=localmax,
                                 search_radius=5,
                                 smooth_window_size=20,
                                 peak_dir='up')

        # Snap the annotated beats upwards with the same settings
        fixed_p = correct_peaks(sig=sig,
                                peak_inds=beat_inds,
                                search_radius=5,
                                smooth_window_size=20,
                                peak_dir='up')

        # Check which corrected beats coincide with a local maximum
        beat_is_local_peak = np.isin(fixed_p, localmax)

        # Work on a copy so the caller's label array is not modified
        fixed_l = np.array(beat_labels)

        # Mark beats that are not on a local maximum with -1
        fixed_l[~beat_is_local_peak] = -1
        fixed_labels.append(fixed_l)

    return fixed_labels
Beispiel #2
0
def get_rr_peaks_indices(record, max_bpm=230):
    """
    Detect QRS peaks on the first channel of a record and return their times.

    Peaks are detected with ``processing.gqrs_detect`` and then adjusted
    with ``processing.correct_peaks``. When two consecutive corrected
    peaks share the same index, only the first one is kept.

    :param record: object exposing ``p_signal`` (2-D, channel on axis 1)
        and ``fs``
    :param max_bpm: highest heart rate considered; it sets the search
        radius (``fs * 60 / max_bpm`` samples) used for peak correction
    :return: tuple ``(timestamps, discarded_count)`` where ``timestamps``
        is a numpy array of peak indices divided by ``record.fs`` and
        ``discarded_count`` is the number of dropped consecutive duplicates
    """
    channel = record.p_signal[:, 0]
    detected = processing.gqrs_detect(sig=channel, fs=record.fs)
    radius = int(record.fs * 60 / max_bpm)
    corrected = processing.correct_peaks(channel,
                                         peak_inds=detected,
                                         search_radius=radius,
                                         smooth_window_size=150)

    kept = []
    discarded_count = 0
    last_seen = -1
    for peak in corrected:
        if peak == last_seen:
            # same index as the previous peak: count it and move on
            discarded_count += 1
            continue
        kept.append(peak)
        last_seen = peak

    # convert sample indices to timestamps
    timestamps = np.array(kept) / record.fs
    return timestamps, discarded_count
Beispiel #3
0
def SC_method_localpeak(signal: NDArray[float],
                        peak_window: int = 50,
                        mean_average_window: int = 30,
                        direction: str = 'down') -> (List, List, List):
    """
    Locate peaks in ``signal`` and measure the segments between them.

    Candidate peaks come from ``processing.find_local_peaks`` and are
    refined with ``processing.correct_peaks`` (``peak_dir=direction``).
    Duplicates and indices outside ``[0, len(signal))`` are dropped.

    Returns a tuple ``(peaks, ptplist, durlist)``: the sorted unique peak
    indices, the peak-to-peak amplitude of the signal between each pair
    of consecutive peaks, and the distance in samples between each pair.
    With fewer than two peaks both lists are ``[0]``.
    """
    candidates = processing.find_local_peaks(signal, peak_window)
    peaks = np.unique(
        processing.correct_peaks(signal,
                                 candidates,
                                 peak_window,
                                 mean_average_window,
                                 peak_dir=direction))

    inside_signal = (peaks > -1) & (peaks < len(signal))
    peaks = peaks[inside_signal]

    if len(peaks) <= 1:
        # not enough peaks to form a segment
        return peaks, [0], [0]

    neighbours = list(zip(peaks[:-1], peaks[1:]))
    durlist = [right - left for left, right in neighbours]
    ptplist = [np.ptp(signal[left:right]) for left, right in neighbours]

    return peaks, ptplist, durlist
Beispiel #4
0
    def test_correct_peaks(self):
        """Corrected annotation peaks of record 100 match the reference list.

        Only peaks within the first 10000 samples are compared.
        """
        record_path = 'sample-data/100'
        sig, fields = wfdb.rdsamp(record_path)
        ann = wfdb.rdann(record_path, 'atr')

        fs = fields['fs']
        min_bpm, max_bpm = 10, 350
        min_gap = fs * 60 / min_bpm  # not used below
        max_gap = fs * 60 / max_bpm

        corrected = processing.correct_peaks(sig=sig[:, 0],
                                             peak_inds=ann.sample,
                                             search_radius=int(max_gap),
                                             smooth_window_size=150)

        # Mark corrected positions in an indicator vector; this also
        # collapses duplicate indices and yields them in sorted order.
        indicator = np.zeros(sig.shape[0])
        indicator[corrected] = 1
        found = np.where(indicator[:10000] == 1)[0]

        expected_peaks = [
            77, 370, 663, 947, 1231, 1515, 1809, 2045, 2403, 2706,
            2998, 3283, 3560, 3863, 4171, 4466, 4765, 5061, 5347, 5634,
            5919, 6215, 6527, 6824, 7106, 7393, 7670, 7953, 8246, 8539,
            8837, 9142, 9432, 9710, 9998,
        ]

        assert np.array_equal(found, expected_peaks)
Beispiel #5
0
def get_peaks(raw_signal: np.ndarray, fs: int) -> np.ndarray:
    """
    Find upward peaks in ``raw_signal`` and refine them twice.

    Candidates come from ``find_peaks`` with a minimum spacing derived
    from a maximum rate of 220 beats per minute. They are then passed
    through ``processing.correct_peaks`` (search radius 30, smoothing
    window 35, ``peak_dir='up'``) two times. If the second pass raises
    ``ValueError`` the result of the first pass is returned instead.
    NaN entries are removed from whatever is returned.
    """
    MAX_BPM = 220
    min_distance = int((60 / MAX_BPM) / (1 / fs))
    raw_peaks, _ = find_peaks(raw_signal, distance=min_distance)

    def refine(peaks):
        return processing.correct_peaks(raw_signal,
                                        peaks,
                                        30,
                                        35,
                                        peak_dir='up')

    def without_nan(peaks):
        return peaks[~np.isnan(peaks)]

    med_peaks = refine(raw_peaks)

    if len(med_peaks) > 0:
        try:
            wel_peaks = refine(med_peaks)
        except ValueError:
            # second pass failed: fall back to the first-pass result
            return without_nan(med_peaks)
    else:
        # nothing to refine a second time
        wel_peaks = raw_peaks

    return without_nan(wel_peaks)
Beispiel #6
0
    def get_qrs_inds(self, signals, fs):
        """Return corrected QRS indices for the first channel of ``signals``.

        QRS complexes are detected with ``processing.xqrs_detect``
        (configured for 20-230 bpm and a QRS width of 0.5) and then
        adjusted with ``processing.correct_peaks`` using a search radius
        of ``fs * 60 // 230`` samples, a smoothing window of ``fs // 2``
        samples and the ``'compare'`` peak direction.
        """
        first_channel = signals[:, 0]

        detector_conf = processing.XQRS.Conf(hr_min=20,
                                             hr_max=230,
                                             qrs_width=0.5)
        detected = processing.xqrs_detect(first_channel,
                                          fs,
                                          conf=detector_conf)

        search_radius = fs * 60 // 230
        smooth_window = fs // 2
        return processing.correct_peaks(first_channel, detected,
                                        search_radius, smooth_window,
                                        'compare')
Beispiel #7
0
    def test_correct_peaks(self):
        """Check corrected annotation peaks of record 100 against known values.

        The comparison covers the first 10000 samples of the record.
        """
        record_path = 'sample-data/100'
        sig, fields = wfdb.rdsamp(record_path)
        ann = wfdb.rdann(record_path, 'atr')

        fs = fields['fs']
        min_bpm, max_bpm = 10, 350
        min_gap = fs * 60 / min_bpm  # not used below
        max_gap = fs * 60 / max_bpm

        corrected = processing.correct_peaks(sig=sig[:, 0],
                                             peak_inds=ann.sample,
                                             search_radius=int(max_gap),
                                             smooth_window_size=150)

        # An indicator vector de-duplicates and orders the peak indices.
        indicator = np.zeros(sig.shape[0])
        indicator[corrected] = 1
        found = np.where(indicator[:10000] == 1)[0]

        expected_peaks = [
            77, 370, 663, 947, 1231, 1515, 1809, 2045, 2403, 2706,
            2998, 3283, 3560, 3863, 4171, 4466, 4765, 5061, 5347, 5634,
            5919, 6215, 6527, 6824, 7106, 7393, 7670, 7953, 8246, 8539,
            8837, 9142, 9432, 9710, 9998,
        ]

        assert np.array_equal(found, expected_peaks)
Beispiel #8
0
def detect_qrs(ecg_measurements, fs, win_len_min=10, win_overlap_min=1):
    """
    Detect QRS peaks in a long recording using overlapping windows.

    The signal is cut into windows of ``win_len_min`` minutes that
    overlap by ``win_overlap_min`` minutes. In every window QRS
    complexes are detected with ``processing.xqrs_detect`` and adjusted
    with ``processing.correct_peaks``; the window offset is added and
    the indices are merged into a set, so peaks found twice in an
    overlap are kept once. A progress line with timings is printed for
    every window.

    Returns the merged peak indices as a sorted numpy array.
    """
    samples_per_minute = fs * 60
    win_len = int(samples_per_minute * win_len_min)
    win_overlap = int(samples_per_minute * win_overlap_min)
    win_skip = win_len - win_overlap
    wins = int((len(ecg_measurements) - win_len) / win_skip) + 2

    min_bpm = 20  # not used below
    max_bpm = 230
    search_radius = int(samples_per_minute / max_bpm)

    indices = set()
    start_time_all = time.time()
    for i in range(wins):
        w_start = i * win_skip
        w_end = w_start + win_len

        start_time = time.time()
        calc_data = ecg_measurements[w_start:w_end]
        qrs_inds = processing.xqrs_detect(sig=calc_data,
                                          fs=fs,
                                          learn=True,
                                          verbose=False)
        corrected_peak_inds = processing.correct_peaks(
            calc_data,
            peak_inds=qrs_inds,
            search_radius=search_radius,
            smooth_window_size=150)
        start_upd_dict_time = time.time()  # not used below

        # shift window-relative indices to positions in the full signal
        corrected_peak_inds = corrected_peak_inds + w_start
        indices.update(corrected_peak_inds)
        end_time = time.time()

        window_elapsed = end_time - start_time
        total_elapsed = end_time - start_time_all
        print(
            "{:5} {:10} - {:10} czas wykonania: {:10.2} całkowity czas wykonania: {:10.2}  liczba probek: {} liczba qrs: {} liczba probek w zbiorze {}"
            .format(i, w_start, w_end, window_elapsed, total_elapsed,
                    len(calc_data), len(corrected_peak_inds), len(indices)),
            flush=True)

    result = np.array(list(indices))
    result.sort()
    return result
Beispiel #9
0
def eval_model(test_exs,
               eval_fun,
               params,
               plot_examples=True,
               exs=None,
               nb=2,
               threshold=None,
               nearest_fpr=None,
               eval_margin=10):
    """
    Evaluate a peak-detection model on test examples and print a report.

    For every selected example the input/target pair is loaded with
    ``load_steps``, the model output is thresholded and cast to int32,
    the peaks are corrected with ``correct_peaks`` and per-example
    metrics are computed with ``roc_auc``. Finally an overall ROC is
    computed with ``roc_auc`` over all examples, plotted with
    ``plot_roc`` and summarised on stdout.

    Parameters
    ----------
    test_exs : sequence
        Entries of the form ``(db, k, j)``; ``load_steps(db, k, params)[j]``
        must yield an ``(X, Y)`` pair that reshapes to 5000 samples.
    eval_fun : callable
        Model inference function; ``eval_fun(X)[0][0]`` is used as the
        per-sample output.
    params : dict
        Must contain ``min_gap``, ``max_gap``, ``left_border``,
        ``right_border`` and ``smooth_window_correct``. It is also
        forwarded to ``load_steps``.
    plot_examples : bool
        Plot one axes per evaluated example.
    exs : list or None
        Indices into ``test_exs``; ``nb`` random indices are drawn when
        None.
    nb : int
        Number of random examples when ``exs`` is None.
    threshold : float or None
        Outputs ``>= threshold`` are set to 1.0 before the int32 cast.
    nearest_fpr : float or None
        When given, only the ROC point whose FPR is closest to this
        value is printed; otherwise every threshold is printed.
    eval_margin : int
        Forwarded to ``roc_auc`` as ``margin``.

    Returns
    -------
    None

    Raises
    ------
    AssertionError
        If both ``threshold`` and ``nearest_fpr`` are None.
    """
    assert threshold is not None or nearest_fpr is not None
    min_gap = params['min_gap']
    max_gap = params['max_gap']
    left_border = params['left_border']
    right_border = params['right_border']
    smooth_window_correct = params['smooth_window_correct']

    # Region that is scored. ``-right_border or None`` keeps the slice
    # correct for right_border == 0, where ``[a:-0]`` would be empty.
    inner = slice(left_border, -right_border or None)

    if exs is None:
        exs = numpy.random.randint(len(test_exs), size=nb).tolist()
    if plot_examples:
        # squeeze=False always returns a 2-D array of axes, so indexing
        # with ax[i] also works when a single example is plotted.
        fig, ax = plt.subplots(len(exs),
                               figsize=(30, 10 * len(exs)),
                               squeeze=False)
        ax = ax[:, 0]

    y_trues = []
    y_preds = []

    print('Evaluating', end='')
    for i, ex_id in enumerate(exs):
        print('.', end='')
        db, k, j = test_exs[ex_id]
        XY = load_steps(db, k, params)[j]
        X, Y = numpy.reshape(XY[0],
                             (1, 1, 5000)), numpy.reshape(XY[1], (5000, ))

        res = eval_fun(X)
        res = res[0][0]

        if threshold is not None:
            x = numpy.where(res >= threshold)
            numpy.put(res, x, 1.0)
        res = res.astype('int32')

        best_peaks_idxs = correct_peaks(x=X[0][0],
                                        peak_indexes=res,
                                        min_gap=min_gap,
                                        max_gap=max_gap,
                                        smooth_window=smooth_window_correct)
        best_peaks_vals = X[0][0][best_peaks_idxs]

        y_true = Y
        y_pred = numpy.zeros(len(res))
        y_pred[best_peaks_idxs] = 1

        # NOTE(review): the global lists accumulate the thresholded model
        # output ``res`` rather than the corrected ``y_pred`` - kept as in
        # the original; confirm this is intended.
        y_trues += y_true.tolist()[inner]
        y_preds += res.tolist()[inner]

        rp, fp, tp, fpr, tpr, thresholds, auc = roc_auc(y_true[inner],
                                                        y_pred[inner],
                                                        margin=eval_margin)

        if plot_examples:
            ax[i].plot(Y + 1, color='blue')
            ax[i].plot(X[0][0], color='green')
            ax[i].plot(res - 1, color='red')
            b_peaks_idx = numpy.where(Y == 1)[0]
            ax[i].plot(b_peaks_idx, X[0][0][b_peaks_idx], 'b+')
            ax[i].plot(best_peaks_idxs, best_peaks_vals, 'r+')
            # vertical markers for the borders of the scored region
            ax[i].plot([left_border, left_border], [-1, 2], 'm-')
            ax[i].plot([len(res) - right_border,
                        len(res) - right_border], [-1, 2], 'm-')
            ax[i].set_title(
                'Example {} ({}/{}/{}) (TP={}/{}, FP={}/0, TPR={}, FPR={})'.
                format(ex_id, db, k, j, tp, rp, fp, tpr, fpr))
    print()
    if plot_examples:
        plt.show()

    rp, fp, tp, fpr, tpr, thresholds, auc = roc_auc(numpy.asarray(y_trues),
                                                    numpy.asarray(y_preds),
                                                    margin=eval_margin)

    # The row format previously contained the invalid escape "\+t",
    # which printed a literal backslash; two tabs match the header.
    row_format = '{:.6f}\t\t{:.6f}\t\t{:.7f}'
    print('FPR\t\t\tTPR\t\t\tThreshold')
    if nearest_fpr is not None:
        idx = (numpy.abs(fpr - nearest_fpr)).argmin()
        print(row_format.format(fpr[idx], tpr[idx], thresholds[idx]))
    else:
        for i, t in enumerate(thresholds):
            print(row_format.format(fpr[i], tpr[i], t))

    plot_roc(fpr, tpr, auc, figsize=(10, 10))

    print('Samples:\t\t{} samples'.format(len(y_trues)))
    print('Beats:')
    print('  - {} labelized'.format(rp))
    print('  - {} detected'.format(fp + tp))
    print('  - TP:  {}/{}'.format(tp, rp))
    print('  - FP:  {}/{}'.format(fp, 0))
    print('  - TPR: {:.4f}'.format(tp / rp))
Beispiel #10
0
    if saveto is not None:
        plt.savefig(saveto, dpi=600)
    plt.show()
# Load the ECG record (base name of the .hea/.dat file pair)
record = wfdb.rdrecord('./2')

# Locate QRS complexes with the gqrs algorithm (positions not yet corrected)
qrs_inds = processing.gqrs_detect(sig=record.p_signal[:, 0], fs=record.fs)

# Correct the detected peaks so that they sit on local maxima.
# The highest plausible heart rate defines the search radius.
min_bpm = 20
max_bpm = 230
search_radius = int(record.fs * 60 / max_bpm)
corrected_peak_inds = processing.correct_peaks(
    record.p_signal[:, 0],
    peak_inds=qrs_inds,
    search_radius=search_radius,
    smooth_window_size=150)

# Print the corrected QRS peak positions
print('Corrected gqrs detected peak indices:', sorted(corrected_peak_inds))


# Feature 1: largest sample value found at any corrected peak,
# taken over all channels of the record (starting from -100)
signal = record.p_signal
R_peak = -100
for x in corrected_peak_inds:
    R_peak = max(R_peak, max(signal[x]))
print(R_peak)

Beispiel #11
0
def _collect_beat_rows(record_file, label, record_name):
    """Build the dataset rows for one record, or return None if it is unusable.

    The first channel of the record is resampled to 200 Hz, QRS peaks
    are detected with gqrs and corrected, and a fixed window around
    every inner peak is extracted. Windows overlapping a neighbouring
    peak are ignored. If fewer than 8 windows remain, None is returned.
    Otherwise the 8 windows that correlate best with the average window
    are normalised to [-1, 1], and 100 randomly chosen orderings of them
    are concatenated into one-row DataFrames carrying ``label`` and
    ``record`` columns.
    """
    fs = 200
    max_bpm = 230
    beats_per_row = 8
    rows_per_record = 100

    # load record and resample its first channel
    signal, info = wfdb.rdsamp(record_file)
    signal = processing.resample_sig(signal[:, 0], info['fs'], fs)[0]

    window_size_half = int(fs * 0.125 / 2)

    # detect QRS peaks
    qrs_inds = processing.gqrs_detect(signal, fs=fs)
    search_radius = int(fs * 60 / max_bpm)
    corrected_qrs_inds = processing.correct_peaks(
        signal,
        peak_inds=qrs_inds,
        search_radius=search_radius,
        smooth_window_size=150)

    # windows[k] belongs to peak k + 1; None marks a window that would
    # reach past one of the neighbouring peaks
    windows = []
    for i in range(1, len(corrected_qrs_inds) - 1):
        start_ind = corrected_qrs_inds[i] - window_size_half
        end_ind = corrected_qrs_inds[i] + window_size_half + 1
        if (start_ind < corrected_qrs_inds[i - 1]
                or end_ind > corrected_qrs_inds[i + 1]):
            windows.append(None)
        else:
            windows.append(signal[start_ind:end_ind])

    usable = [w for w in windows if w is not None]
    if len(usable) < beats_per_row:
        return None

    average_qrs = sum(usable) / len(usable)

    # -100 keeps overlapping windows below every real correlation value
    corrcoefs = [
        -100 if w is None else pearsonr(w, average_qrs)[0] for w in windows
    ]

    max_corr = list(
        map(corrcoefs.index, heapq.nlargest(beats_per_row, corrcoefs)))
    index_corr = random.sample(
        list(itertools.permutations(max_corr, beats_per_row)),
        rows_per_record)

    # normalise each selected beat once instead of once per permutation
    normalized = {}
    for i in set(max_corr):
        start_ind = corrected_qrs_inds[i + 1] - window_size_half
        end_ind = corrected_qrs_inds[i + 1] + window_size_half + 1
        normalized[i] = processing.normalize_bound(signal[start_ind:end_ind],
                                                   -1, 1)

    rows = []
    for index in index_corr:
        signal_temp = np.concatenate([normalized[i] for i in index])
        row = pd.DataFrame(signal_temp.reshape(1, -1))
        row['label'] = label
        row['record'] = record_name
        rows.append(row)
    return rows


def dataGeneration(data_path, csv_path, record_path):
    """Generate a beat dataset from WFDB records and write it to a CSV file.

    Parameters
    ----------
    data_path : str
        Directory that is scanned. When ``record_path`` is None it
        contains the ``.hea`` records directly; otherwise it contains
        one folder per patient and must end with a path separator,
        because patient folder names are appended to it directly.
    csv_path : str
        Destination of the generated CSV file.
    record_path : str or None
        When not None, only patient folders whose name starts with this
        prefix are processed and rows are labelled with the patient
        folder name. When None, rows are labelled with the record name.

    Notes
    -----
    Records with fewer than 8 usable beats are reported on stdout and
    skipped. The CSV has the columns ``label``, ``record`` followed by
    the sample columns.
    """
    frames = []

    if record_path is None:

        detail_path = data_path + '/'
        record_files = [
            i.split('.')[0] for i in os.listdir(detail_path)
            if (not i.startswith('.') and i.endswith('.hea'))
        ]

        Bar.check_tty = False
        bar = Bar('Processing',
                  max=len(record_files),
                  fill='#',
                  suffix='%(percent)d%%')

        # a loop for each record
        for record_name in record_files:
            rows = _collect_beat_rows(detail_path + record_name,
                                      label=record_name,
                                      record_name=record_name)
            if rows is None:
                print('\noutlier detected, discard ' + record_name)
                continue
            frames.extend(rows)
            bar.next()
        bar.finish()
    else:
        patient_folders = [
            i for i in os.listdir(data_path)
            if (not i.startswith('.') and i.startswith(record_path))
        ]

        Bar.check_tty = False
        bar = Bar('Processing',
                  max=len(patient_folders),
                  fill='#',
                  suffix='%(percent)d%%')

        # a loop for each patient
        for patient_name in patient_folders:
            detail_path = data_path + patient_name + '/'
            record_files = [
                i.split('.')[0] for i in os.listdir(detail_path)
                if i.endswith('.hea')
            ]

            # a loop for each record
            for record_name in record_files:
                rows = _collect_beat_rows(detail_path + record_name,
                                          label=patient_name,
                                          record_name=record_name)
                if rows is None:
                    print('\noutlier detected, discard ' + record_name +
                          ' of ' + patient_name)
                    continue
                frames.extend(rows)

            bar.next()
        bar.finish()

    # DataFrame.append no longer exists in pandas 2.x; concatenate once
    # and restore the column order that repeated appends produced.
    leading = ['label', 'record']
    if frames:
        dataset = pd.concat(frames, ignore_index=True, sort=False)
        dataset = dataset[
            leading + [c for c in dataset.columns if c not in leading]]
    else:
        dataset = pd.DataFrame(columns=leading)

    # save for further use
    dataset.to_csv(csv_path, index=False)

    print('processing completed')
Beispiel #12
0
    def find_peaks(self, signal, verbose=False):
        """
        Run the peak detection algorithm on a single channel ECG.

        The function uses an LSTM model that was trained with 1000 sample
        windows of simulated noisy ECG data with a sampling rate of
        250 Hz. The following steps are executed:
        1. Input ECG is resampled to 250 Hz (only if ``self.resample``)
        2. The signal is divided into overlapping windows
        3. Every window is normalized to the -1, 1 range and the model
        makes predictions for the windows
        4. Overlapping predictions are averaged and R-peak locations are
        decided based on them
        5. If the signal was resampled, R-peak locations are mapped back
        into the original sampling frequency and corrected upwards

        Parameters
        ----------
        signal : array
            Single channel ECG signal.
        verbose : bool
            Whether to print progress information.

        Returns
        -------
        orig_peaks : array
            Indices of the R-peaks as samples from the beginning of the
            original signal (not the resampled signal).
        filtered_proba : array
            Probability values (probability that point is an R-peak) for
            all points in orig_peaks array.

        """
        if not self.resample:
            sig = signal
        else:
            if verbose:
                print("Resampling signal from ", self.iput_fs, "Hz to 250 Hz")
            sig = resample_poly(signal, up=250, down=self.iput_fs)

        if verbose:
            print("Extracting windows, window size:", self.win_size,
                  " stride:", self.stride)
        window_positions, windows = self._extract_windows(signal=sig)

        # Scale every window (row) to the -1, 1 range
        scale_window = partial(processing.normalize_bound, lb=-1, ub=1)
        windows = np.apply_along_axis(scale_window, 1, windows)

        if verbose:
            print("Predicting peaks")
        window_preds = self.model.predict(windows, verbose=0)

        if verbose:
            print("Calculating means for overlapping predictions (windows)")
        sample_preds = self._mean_preds(win_idx=window_positions,
                                        preds=window_preds,
                                        orig_len=sig.shape[0])

        if verbose:
            print("Filtering out predictions below probabilty threshold ",
                  self.threshold)
        filtered_peaks, filtered_proba = self._filter_predictions(
            signal=sig, preds=sample_preds)

        if not self.resample:
            orig_peaks = filtered_peaks
        else:
            # Positions on the resampled axis, one per original sample
            n_positions = int(sig.shape[0] * (self.iput_fs / 250))
            resampled_pos = np.round(
                np.linspace(0, (sig.shape[0] - 0.5), n_positions),
                decimals=1)

            # Map the peaks back to the original sampling frequency
            orig_peaks = processing.resample_ann(resampled_pos, filtered_peaks)

            # Correct the mapped peaks against the original signal
            radius = int(self.iput_fs / 50)
            orig_peaks = processing.correct_peaks(sig=signal,
                                                  peak_inds=orig_peaks,
                                                  search_radius=radius,
                                                  smooth_window_size=20,
                                                  peak_dir='up')

            # In some very rare cases the final correction can introduce
            # duplicate peak values. Probabilities of duplicates are
            # averaged before the duplicates themselves are removed.
            filtered_proba = self._calculate_means(indices=orig_peaks,
                                                   values=filtered_proba)
            orig_peaks = np.unique(orig_peaks)

        if verbose:
            print("Everything done")

        return orig_peaks, filtered_proba
Beispiel #13
0
    def _filter_predictions(self, signal, preds):
        """
        Reduce per-sample model predictions to R-peak locations.

        The following steps are applied:
        1. Keep only the predictions above ``self.threshold``.
        2. Correct the kept sample positions upwards with respect to the
        given ECG.
        3. Accept a location as an R-peak if at least five kept points
        were corrected into it.
        4. The probability of an accepted location is the mean of the
        probabilities of the points corrected into it.

        These steps act as a noise reducing measure, as in the original
        training data every R-peak was labeled with 5 points.

        Parameters
        ----------
        signal : array
            Same signal that was used with extract_windows function. It
            is passed to the correct_peaks function.
        preds : array
            Predictions for the sample points of the signal; must have
            the same shape as ``signal``.

        Returns
        -------
        filtered_peaks : array
            locations of the filtered peaks.
        filtered_probs : array
            probability that filtered peak is an R-peak.

        """
        assert (signal.shape == preds.shape)

        # Probabilities and sample indices of points above the threshold
        is_candidate = preds > self.threshold
        candidate_probs = preds[is_candidate]
        candidate_idx = np.where(is_candidate)[0]

        # Move every candidate point upwards
        corrected = processing.correct_peaks(sig=signal,
                                             peak_inds=candidate_idx,
                                             search_radius=5,
                                             smooth_window_size=20,
                                             peak_dir='up')

        min_support = 5
        filtered_peaks = []
        filtered_probs = []
        for location in np.unique(corrected):
            members = np.nonzero(corrected == location)[0]
            if members.shape[0] < min_support:
                # too few points ended up here
                continue
            filtered_probs.append(candidate_probs[members].mean())
            filtered_peaks.append(location)

        return np.asarray(filtered_peaks), np.asarray(filtered_probs)