コード例 #1
0
ファイル: utility.py プロジェクト: sungcheolkim78/py_gait
def SC_method_localpeak(signal: NDArray[float],
                        peak_window: int = 50,
                        mean_average_window: int = 30,
                        direction: str = 'down') -> (List, List, List):
    """Detect corrected local peaks and measure the spans between them.

    Candidate peaks are taken from ``processing.find_local_peaks`` and then
    refined with ``processing.correct_peaks``. Duplicate indices and indices
    outside ``[0, len(signal))`` are discarded.

    Parameters
    ----------
    signal : NDArray[float]
        Signal to analyse.
    peak_window : int
        Radius given to the peak finder; also used as the search radius of
        the peak correction.
    mean_average_window : int
        Smoothing window size given to the peak correction.
    direction : str
        Forwarded to the peak correction as ``peak_dir``.

    Returns
    -------
    correct_peaks : numpy.ndarray
        Sorted, unique, in-range peak indices.
    ptplist : list
        Peak-to-peak amplitude of ``signal`` between each pair of
        consecutive peaks (``[0]`` when fewer than two peaks remain).
    durlist : list
        Index distance between each pair of consecutive peaks (``[0]`` when
        fewer than two peaks remain).
    """
    candidates = processing.find_local_peaks(signal, peak_window)
    refined = processing.correct_peaks(signal,
                                       candidates,
                                       peak_window,
                                       mean_average_window,
                                       peak_dir=direction)

    # np.unique also sorts, so consecutive entries are neighbouring peaks.
    refined = np.unique(refined)
    in_range = (refined > -1) & (refined < len(signal))
    refined = refined[in_range]

    # Fewer than two peaks: there is no interval to measure.
    if len(refined) <= 1:
        return refined, [0], [0]

    neighbours = list(zip(refined[:-1], refined[1:]))
    durlist = [stop - start for start, stop in neighbours]
    ptplist = [np.ptp(signal[start:stop]) for start, stop in neighbours]

    return refined, ptplist, durlist
コード例 #2
0
def find_p_peak(signal, R_R_interval, onset_located):
    """Find the positions of P wave peaks.

    For every entry of ``onset_located`` the ``int(R_R_interval / 5)``
    samples that end at the onset are searched with
    ``processing.find_local_peaks`` and the first peak returned is recorded.

    Parameters
    ----------
    signal : sliceable sequence
        Signal the search windows are cut from.
    R_R_interval : number
        Interval length; one fifth of it is used as the window length.
    onset_located : numpy.ndarray
        Onset sample indices; each entry is processed in order.

    Returns
    -------
    P_located : numpy.ndarray
        Indices of the recorded P peaks.
    pop_list : list of int
        Positions in ``onset_located`` for which the search failed.
    """
    # Length of the search window that ends at each onset.
    window = int(R_R_interval / 5)

    # Marker array: it is grown so that its last element sits at the index of
    # the newest peak and holds a non-zero value; np.nonzero recovers the
    # peak indices at the end.
    countP = np.zeros(1, dtype=np.int32)
    pop_list = []
    for i, onset in enumerate(onset_located):
        window_start = onset - window
        # NOTE(review): this history is reset for every onset, so the
        # fallback below can only see an offset found in the current
        # iteration. It looks like it was meant to accumulate across onsets;
        # left unchanged because hoisting it would alter the returned peaks.
        onset_P_interval = []
        try:
            idx = processing.find_local_peaks(signal[window_start:onset],
                                              window - 1)
            onset_P_interval.append(idx[0])
            idx = idx[0] + window_start
            countP = np.concatenate(
                (countP, np.zeros(idx - countP.shape[0]), -1 * np.ones(1)))
        except Exception:
            # Not a bare ``except`` so KeyboardInterrupt/SystemExit propagate.
            # Any failure of the search or of the marker update marks this
            # onset as having no detected peak.
            pop_list.append(i)
            average_interval = np.mean(
                np.array(onset_P_interval)) if onset_P_interval else 0
            # np.mean returns a float; np.zeros needs a non-negative integer
            # length, so cast and check before growing the marker array.
            estimate = int(window_start + average_interval)
            padding = estimate - countP.shape[0]
            if average_interval != 0 and padding >= 0:
                countP = np.concatenate(
                    (countP, np.zeros(padding), -1 * np.ones(1)))
                print("      Add a P peak at {}".format(estimate))
            else:
                print("      No P peak at {}".format(onset))
    P_located = np.nonzero(countP)[0]

    return P_located, pop_list
コード例 #3
0
def find_t_peak(signal, R_R_interval, offset_located):
    """Find the positions of T wave peaks.

    For every entry of ``offset_located`` the ``int(R_R_interval / 2)``
    samples that start at the offset are searched with
    ``processing.find_local_peaks`` and the first peak returned is recorded.

    Parameters
    ----------
    signal : sliceable sequence
        Signal the search windows are cut from.
    R_R_interval : number
        Interval length; half of it is used as the window length.
    offset_located : numpy.ndarray
        Offset sample indices; each entry is processed in order.

    Returns
    -------
    T_located : numpy.ndarray
        Indices of the recorded T peaks.
    pop_list : list of int
        Positions in ``offset_located`` for which the search failed.
    """
    # Length of the search window that starts at each offset.
    window = int(R_R_interval / 2)

    # Marker array: it is grown so that its last element sits at the index of
    # the newest peak and holds a non-zero value; np.nonzero recovers the
    # peak indices at the end.
    countT = np.zeros(1, dtype=np.int32)
    pop_list = []
    for i, offset in enumerate(offset_located):
        # NOTE(review): this history is reset for every offset, so the
        # fallback below can only see an interval found in the current
        # iteration. It looks like it was meant to accumulate across offsets;
        # left unchanged because hoisting it would alter the returned peaks.
        offset_T_interval = []
        try:
            idx = processing.find_local_peaks(signal[offset:offset + window],
                                              window - 10)
            offset_T_interval.append(idx[0])
            idx = idx[0] + offset
            countT = np.concatenate(
                (countT, np.zeros(idx - countT.shape[0]), -1 * np.ones(1)))
        except Exception:
            # Not a bare ``except`` so KeyboardInterrupt/SystemExit propagate.
            # Any failure of the search or of the marker update marks this
            # offset as having no detected peak.
            pop_list.append(i)
            average_interval = np.mean(
                np.array(offset_T_interval)) if offset_T_interval else 0
            # np.mean returns a float; np.zeros needs a non-negative integer
            # length, so cast and check before growing the marker array.
            estimate = int(offset + average_interval)
            padding = estimate - countT.shape[0]
            if average_interval != 0 and padding >= 0:
                countT = np.concatenate(
                    (countT, np.zeros(padding), -1 * np.ones(1)))
                print("      Add a T peak at {}".format(estimate))
            else:
                print("      No T peak at {}".format(offset))
    T_located = np.nonzero(countT)[0]

    return T_located, pop_list
コード例 #4
0
def fix_labels(signals, beats, labels):
    """
    Change labeling of the normal beats.

    Beat index of some normal beats doesn't occur at the local maxima
    of the ECG signal in MIT-BIH Arrhythmia database. Function checks if
    beat index occurs within 5 samples from the local maxima. If this is
    not true, beat labeling is changed to -1.

    The input label arrays are left untouched; every returned array is a
    copy.

    Parameters
    ----------
    signals : list
        List of ECG signals as numpy arrays
    beats : list
        List of numpy arrays that store beat locations
    labels : list
        List of numpy arrays that store beat types

    Returns
    -------
    fixed_labels : list
        List of numpy arrays where -1 has been added for beats that are
        not located in local maxima

    """
    fixed_labels = []
    for sig, beat_locs, beat_labels in zip(signals, beats, labels):

        # Find local maxima and snap them onto the smoothed signal's peaks
        localmax = find_local_peaks(sig=sig, radius=5)
        localmax = correct_peaks(sig=sig,
                                 peak_inds=localmax,
                                 search_radius=5,
                                 smooth_window_size=20,
                                 peak_dir='up')

        # Apply the same correction to the annotated beats so both index
        # sets are directly comparable
        fixed_p = correct_peaks(sig=sig,
                                peak_inds=beat_locs,
                                search_radius=5,
                                smooth_window_size=20,
                                peak_dir='up')

        # Check which beats coincide with a local maximum
        beat_is_local_peak = np.isin(fixed_p, localmax)

        # Copy first: assigning through the original array would silently
        # overwrite the caller's labels.
        fixed_l = np.copy(beat_labels)

        # Mark beats that are not at a local maximum with -1
        fixed_l[~beat_is_local_peak] = -1
        fixed_labels.append(fixed_l)

    return fixed_labels
コード例 #5
0
def find_r_peak(fecg, wavelet=pywt.Wavelet('db8')):
    """Locate R peaks in ``fecg`` using a single wavelet detail band.

    The signal is decomposed with ``pywt.wavedec`` in 'periodization' mode.
    Each coefficient band from index ``k`` down to 2 is reconstructed on its
    own and scored against the signal. The band is chosen from the largest
    drop between consecutive scores. Peaks are then searched in
    ``|fecg * pulse|`` with a radius of 150 samples.

    NOTE(review): ``k`` (the decomposition level) is not defined inside this
    function; presumably it is a module-level value - confirm.

    Parameters
    ----------
    fecg : array-like
        Signal to analyse.
    wavelet : pywt.Wavelet
        Wavelet used for decomposition and reconstruction.

    Returns
    -------
    peak_idx
        Result of ``processing.find_local_peaks`` on the weighted signal.
    """
    mode = 'periodization'
    coeffs = pywt.wavedec(fecg, wavelet, mode, k)

    def _blank(level):
        # Zero array with the same shape as the given coefficient band.
        return np.zeros(coeffs[level].shape)

    # Coefficient list with every band silenced; one band at a time is
    # switched on below.
    isolated = [_blank(level) for level in range(k + 1)]

    score = []
    for level in range(k, 1, -1):
        isolated[level] = coeffs[level]
        pulse = pywt.waverec(isolated, wavelet, mode)
        isolated[level] = _blank(level)

        weight = np.abs(pulse)
        score.append(np.abs(np.sum(fecg * weight / np.sum(weight))))

    # Differences between consecutive scores, starting at the second score.
    score_diff = [cur - nxt for cur, nxt in zip(score[1:-1], score[2:])]
    chosen_scale = int(np.argmax(score_diff) + 2)

    isolated[-chosen_scale] = coeffs[-chosen_scale]
    pulse = pywt.waverec(isolated, wavelet, mode)

    return processing.find_local_peaks(np.abs(fecg * pulse), 150)
コード例 #6
0
def prepare_training_set_aha(set_len=5000, db_dir='wfdb/aha'):
    """Build fixed-length training windows from WFDB records.

    Every record under ``<cwd>/<db_dir>`` that has a ``.dat``, ``.hea`` and
    ``.atr`` file is loaded. For each annotation sample the next local peak
    of channel 0 is taken as the R peak. The record is cut so that it starts
    20 samples before the first R peak (or at sample 0 if the peak is
    earlier). It is then resampled with ``resample(..., 250, 500, ...)``,
    filtered with ``med_filter(sig, 150)`` and split into consecutive
    windows of ``set_len`` samples.

    Parameters
    ----------
    set_len : int
        Window length in samples; must be positive.
    db_dir : str
        Database directory, relative to the current working directory.

    Returns
    -------
    (input, label, typ) : tuple of lists
        Per-window signal, R-peak indicator (1 at R peaks, 0 elsewhere) and
        annotation subtype ('' where there is no R peak).

    Raises
    ------
    ValueError
        If ``set_len`` is not positive (the windowing loop would not
        terminate).
    """
    if set_len <= 0:
        raise ValueError('set_len must be a positive integer')

    db_path = os.path.join(os.getcwd(), db_dir)
    file_list = []
    for _root, _dirs, files in os.walk(db_path):
        file_list.extend(files)

    def stems_with_extension(ext):
        # Base names of files whose second dot-separated part equals ext.
        # Names without a dot are ignored instead of raising IndexError.
        stems = []
        for name in file_list:
            parts = name.split('.')
            if len(parts) > 1 and parts[1] == ext:
                stems.append(parts[0])
        return stems

    hea_names = set(stems_with_extension('hea'))
    atr_names = set(stems_with_extension('atr'))
    record_list = [
        rec for rec in stems_with_extension('dat')
        if rec in hea_names and rec in atr_names
    ]

    data_list = []
    anno_list = []
    anno_typ_list = []
    for rec in record_list:
        print('loading data', rec)
        record_path = os.path.join(db_path, rec)
        data = wfdb.rdrecord(record_path)
        anno = wfdb.rdann(record_path, 'atr')

        sig = data.p_signal
        # use local peaks to trace the R peak
        peaks = processing.find_local_peaks(sig[:, 0], 10)

        # idx2 is a cursor into peaks; it only moves forward, so each
        # annotation continues the search where the previous one stopped.
        idx2 = 0
        anno_r = []
        for idx in anno.sample:
            for idx_peak in peaks[idx2:]:
                idx2 += 1
                # find the nearest R peak after the QRS onset
                if idx_peak > idx:
                    anno_r.append(idx_peak)
                    break

        if not anno_r:
            # Nothing to align the record on; skip it instead of failing on
            # anno_r[0] below.
            print('no R peak found, skipping', rec)
            continue

        anno_r_idx = np.zeros(len(sig[:, 0]))
        anno_r_idx[anno_r] = 1
        anno_typ = anno.subtype
        anno_typ_idx = [''] * len(sig)
        for peak, subtype in zip(anno_r, anno_typ):
            anno_typ_idx[peak] = subtype

        # trunc the data: compute the start once and use it for all three
        # sequences so they stay aligned. Clamp at 0 because a negative
        # start would slice from the end of the record.
        start = max(int(anno_r[0]) - 20, 0)
        sig = sig[start:]
        anno_r_idx = anno_r_idx[start:]
        anno_typ_idx = anno_typ_idx[start:]

        # resample to 500Hz
        sig = resample(sig[:, 0], 250, 500, method='linear')
        anno_r_idx = resample(anno_r_idx, 250, 500, method='label')
        anno_typ_idx = resample(anno_typ_idx, 250, 500, method='label_str')

        # baseline removing
        sig = med_filter(sig, 150)

        data_list.append(sig)
        anno_list.append(anno_r_idx)
        anno_typ_list.append(anno_typ_idx)

    windows = []
    label = []
    typ = []
    for data, anno_r_idx, anno_typ_idx in zip(data_list, anno_list,
                                              anno_typ_list):
        # Same bound as the original ``while idx + set_len < len(data)``:
        # a window that ends exactly at the last sample is not emitted.
        for idx in range(0, len(data) - set_len, set_len):
            windows.append(data[idx:idx + set_len])
            label.append(anno_r_idx[idx:idx + set_len])
            typ.append(anno_typ_idx[idx:idx + set_len])

    return (windows, label, typ)
コード例 #7
0
def load_aha(db_dir='wfdb', database='aha'):
    """Load, align and resample every complete record of a WFDB database.

    Every record under ``<db_dir>/<database>`` that has a ``.dat``, ``.hea``
    and ``.atr`` file is loaded. For each annotation sample the next local
    peak of channel 0 is taken as the R peak. The record is cut so that it
    starts 200 samples before the first R peak (or at sample 0 if the peak
    is earlier). Signals and label sequences are then passed to
    ``resample(..., 250, 500, ...)`` through ``MultiTask``.

    Parameters
    ----------
    db_dir : str
        Directory that contains the database folder.
    database : str
        Name of the database folder inside ``db_dir``.

    Returns
    -------
    name_list : list of str
        Record names, in loading order.
    rs_data : list
        Resampled signals (method 'linear').
    rs_anno : list
        Resampled R-peak indicator sequences (method 'label').
    rs_anno_typ : list
        Resampled annotation-subtype sequences (method 'label_str').
    """
    db_path = os.path.join(db_dir, database)
    file_list = []
    for _root, _dirs, files in os.walk(db_path):
        file_list.extend(files)

    def stems_with_extension(ext):
        # Base names of files whose second dot-separated part equals ext.
        # Names without a dot are ignored instead of raising IndexError.
        stems = []
        for name in file_list:
            parts = name.split('.')
            if len(parts) > 1 and parts[1] == ext:
                stems.append(parts[0])
        return stems

    hea_names = set(stems_with_extension('hea'))
    atr_names = set(stems_with_extension('atr'))
    record_list = [
        rec for rec in stems_with_extension('dat')
        if rec in hea_names and rec in atr_names
    ]

    name_list = []
    data_list = []
    anno_list = []
    anno_typ_list = []

    for rec in record_list:
        print('loading data', rec)
        record_path = os.path.join(db_path, rec)
        data = wfdb.rdrecord(record_path)
        anno = wfdb.rdann(record_path, 'atr')

        sig = data.p_signal
        # aha: local peaks to trace the R peak
        # NOTE: NaN samples are only reported further down, not repaired; an
        # interpolation-based repair was prototyped here earlier and is
        # disabled.
        peaks = processing.find_local_peaks(sig[:, 0], 10)

        # idx2 is a cursor into peaks; it only moves forward, so each
        # annotation continues the search where the previous one stopped.
        idx2 = 0
        anno_r = []
        for idx in anno.sample:
            for idx_peak in peaks[idx2:]:
                idx2 += 1
                # aha: find the nearest R peak after the QRS onset
                if idx_peak > idx:
                    anno_r.append(idx_peak)
                    break

        if not anno_r:
            # Nothing to align the record on; skip it instead of failing on
            # anno_r[0] below.
            print('no R peak found, skipping', rec)
            continue

        anno_r_idx = np.zeros(len(sig[:, 0]))
        anno_r_idx[anno_r] = 1
        anno_typ = anno.subtype
        anno_typ_idx = [''] * len(sig)
        for peak, subtype in zip(anno_r, anno_typ):
            anno_typ_idx[peak] = subtype

        # trunc the data: compute the start once and use it for all three
        # sequences so they stay aligned. Clamp at 0 because a negative
        # start would slice from the end of the record.
        start = max(int(anno_r[0]) - 200, 0)
        sig = sig[start:]
        anno_r_idx = anno_r_idx[start:]
        anno_typ_idx = anno_typ_idx[start:]

        data_list.append(sig)
        anno_list.append(anno_r_idx)
        anno_typ_list.append(anno_typ_idx)

        name_list.append(rec)

        # Report the first sample (row) that contains a NaN in any channel.
        nan_rows = np.flatnonzero(np.isnan(sig).any(axis=1))
        if nan_rows.size:
            print(rec, 'nan ', nan_rows[0])

    def resample_all(items, method):
        # Queue one resample call per item and collect the results.
        # NOTE(review): assumes MultiTask.subscribe() yields results in
        # submission order - confirm against MultiTask.
        multiTask = MultiTask(pool_size=40, queue_size=5000)
        for cnt, item in enumerate(items):
            multiTask.submit(cnt, resample, (item, 250, 500, method))
        return list(multiTask.subscribe())

    # resample before splitting
    print('start resampling')
    rs_data = resample_all(data_list, 'linear')
    rs_anno = resample_all(anno_list, 'label')
    rs_anno_typ = resample_all(anno_typ_list, 'label_str')

    return name_list, rs_data, rs_anno, rs_anno_typ