Example #1
 def rpeak_detection(self, window_size=None, test_size=None, names=None, combinations=None):
     if window_size is None:
         window_size = 50
     if test_size is None:
         test_size = 0.97
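     # minimum spacing enforced between detected R-peaks and width of the evaluation window, both in samples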
     min_dist = 72
     evaluation_window_size = 36
     approach_f = ['KNN_s', 'KNN_w']
     channels_f = ['1', '2', '12']
     filtered_f = ['FS', 'RS']
     results = defaultdict(list)
     if combinations is None:
         combinations = [approach_f, channels_f, filtered_f]
     if names is None:
         names = wfdb.get_record_list('mitdb')
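     # grid-search every (approach, channels, filtering) combination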
     for comb in itertools.product(*combinations):
         print(comb)
         if 'KNN_w' in comb:
             precisions, recalls, times = self.QRS_KNN(comb, min_dist, names, test_size, window_size,
                                                       evaluation_window_size)
         else:
             precisions, recalls, times = self.SSK(comb, names, evaluation_window_size)
         comb_name = comb[0] + '_' + comb[1] + '_' + comb[2]
         results[comb_name] = [np.mean(precisions), np.mean(recalls), np.mean(times)]
         print("{:s}, {:f}, {:f}, {:f}".format(comb_name, np.mean(precisions), np.mean(recalls), np.mean(times)))
     print(results)
     return results
Example #2
    def get_record_list(self, db):
        assert db in self.__avail_db, 'Unsupported database: %s' % db

        if db == 'ahadb':
            return AHADB.get_record_list()
        else:
            return wfdb.get_record_list(db)
Example #3

def load_mitdb():

    mitdblstring = wfdb.get_record_list("mitdb")

    mitdb = []
    for name in mitdblstring:
        mitdb.append(load_patient_record("mitdb", name))
    my_db = ecg_database("mitdb")

    MITBIH_classes = [
        'N', 'L', 'R', 'e', 'j', 'A', 'a', 'J', 'S', 'V', 'E', 'F'
    ]  #, 'P', '/', 'f', 'u']
    AAMI_classes = []
    AAMI_classes.append(['N', 'L', 'R'])  # N
    AAMI_classes.append(['A', 'a', 'J', 'S', 'e', 'j'])  # SVEB
    AAMI_classes.append(['V', 'E'])  # VEB
    AAMI_classes.append(['F'])  # F

    my_db.patient_records = mitdb
    my_db.MITBIH_classes = MITBIH_classes
    my_db.AAMI_classes = AAMI_classes
    my_db.filenames = mitdblstring
    #my_db.Annotations = annotations
    return my_db
Example #4
def main() -> None:
    inputs_path, targets_path, data_path = setup_directory(classifier='2shot')

    if not os.path.exists(data_path):
        get_physionet_data()

    record_list = wfdb.get_record_list('slpdb')
    # apnea_labels = np.array(['H', 'HA', 'OA', 'CA', 'CAA', 'X'])
    apnea_labels = ['H', 'HA', 'OA', 'CA', 'CAA', 'X']

    for record_number, record in enumerate(record_list):
        epochs, annotations = get_record_data(os.path.join(data_path, record))

        # binarize annotations: 1 for a sleep apnoea event, 0 otherwise
        for idx, event in enumerate(annotations):
            annotations[idx] = 1 if any(x in event for x in apnea_labels) else 0

        for epoch_number, epoch in enumerate(epochs):
            # write input and target data to files
            record_name = get_record_name(record_number, epoch_number)

            # save input and output arrays as csv files
            with open(os.path.join(inputs_path, record_name), 'w') as filehandler:
                csv_writer = csv.writer(filehandler, delimiter=' ')
                csv_writer.writerows(epoch)

            # write target values to csv files, named by record number and epoch number
            with open(os.path.join(targets_path, record_name), "w") as filehandler:
                csv_writer = csv.writer(filehandler, delimiter=' ')
                csv_writer.writerow([annotations[epoch_number]])
Example #5
def read_records(dataset_name, data_path, sample_size_seconds=30, samples_per_second=250, num_records=None):
    samples = []
    labels = []
    total_read_records = 0
    for record_name in wfdb.get_record_list(dataset_name):
        header = wfdb.rdheader(data_path + record_name)

        if header.sig_len == 0:
            continue

        offset = 0
        samples_count = 0
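        # slide a fixed-size window over the record; read_record returns None once the record is exhausted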
        while True:
            record, ann, offset = read_record(data_path, header, offset, sample_size_seconds, samples_per_second)
            if record is None:
                break
            samples.append(record)
            labels.append(ann.aux_note)
            samples_count += 1

        total_read_records += 1
        if num_records is not None and total_read_records == num_records:
            break

    # binary label per segment: 1 if '(AFIB' appears among its aux notes, else 0
    labels = np.array([1 if '(AFIB' in key else 0 for key in labels])
    return samples, labels
Example #6
 def remove_non_beat_for_all(self, signals_dir, rule_based):
     symbols = dict()
     peaks = dict()
     for name in wfdb.get_record_list('mitdb'):
         new_peaks, new_symbol = self.remove_non_beat(
             signals_dir + name, rule_based)
         symbols[name] = new_symbol
         peaks[name] = new_peaks
     return peaks, symbols
Example #7
    def data_path(self,
                  subject,
                  path=None,
                  force_update=False,
                  update_path=None,
                  verbose=None):
        if subject not in self.subject_list:
            raise ValueError("Invalid subject number")
        # Check if the .dat, .hea and .win files are present
        # The .dat and .hea files give the main data
        # .win file gives the event windows and the frequencies
        # .flash file can give the exact times of the flashes if necessary
        # Return the file paths depending on the number of sessions for each
        # subject that are denoted a, b, c, ...
        sub = "{:02d}".format(subject)
        sign = self.code.split()[1]
        if sign == "MAMEM1":
            fn = "dataset1/S0{}*.dat"
        elif sign == "MAMEM2":
            fn = "dataset2/T0{}*.dat"
        elif sign == "MAMEM3":
            fn = "dataset3/U0{}*.dat"

        key = "MNE_DATASETS_{:s}_PATH".format(sign)
        key_dest = "MNE-{:s}-data".format(sign.lower())
        path = _get_path(path, key, sign)
        path = os.path.join(path, key_dest)
        s_paths = glob.glob(os.path.join(path, fn.format(sub)))
        subject_paths = []
        for name in s_paths:
            subject_paths.append(os.path.splitext(name)[0])
        # if files for the subject are not present
        if not subject_paths or force_update:
            # if not downloaded, get the list of files to download
            datarec = wfdb.get_record_list("mssvepdb")
            datalist = []
            for ele in datarec:
                if fn.format(sub) in ele:
                    datalist.append(ele)
            wfdb.io.dl_database("mssvepdb",
                                path,
                                datalist,
                                annotators="win",
                                overwrite=force_update)
        # Return the file paths depending on the number of sessions
        s_paths = glob.glob(os.path.join(path, fn.format(sub)))
        subject_paths = []
        for name in s_paths:
            # The adaptation session has the letter x at the end in MAMEM2
            # It should be removed from the returned file names
            if (os.path.splitext(name)[0][-1]) != "x":
                subject_paths.append(os.path.splitext(name)[0])
        return subject_paths
Example #8

 def plot_beats(self, label):
     for name in wfdb.get_record_list('mitdb'):
         print(name)
         # noinspection PyRedeclaration
         record = wfdb.rdrecord(ecg_path + name)
         record = np.transpose(record.p_signal)
         peaks, symbols = ut.remove_non_beat(ecg_path + name, False)
         s_pairs = list(filter(lambda x: x[1] == label, zip(peaks,
                                                            symbols)))
         s_beats_first = [
             record[0][pair[0] - 70:pair[0] + 100] for pair in s_pairs
         ]
         if len(s_beats_first) > 0:
             for beat in s_beats_first:
                 plt.plot(beat)
     plt.title(label)
     plt.show()
Example #9
 def _ls_rec(self) -> NoReturn:
     """List all records of the database, grouped by subdirectory."""
     try:
         tmp = wfdb.get_record_list(self.db_name)
     except Exception:
         # fall back to the locally cached record list
         with open(self.metadata_files["all_records"], "r") as f:
             tmp = f.read().splitlines()
     self._all_records = {}
     for line in tmp:
         # records are listed as 'group/subrecord' relative paths
         gp, sb = line.strip("/").split("/")
         # keep only those present on the local disk
         if os.path.isdir(os.path.join(self.db_dir, gp, sb)):
             if gp in self._all_records:
                 self._all_records[gp].append(sb)
             else:
                 self._all_records[gp] = [sb]
     self._all_records = {k: sorted(v) for k, v in self._all_records.items()}
Example #10
 def rpeak_detection(self):
     channels_f = ['1', '2']
     combinations = [channels_f, ['FS']]
     ecg_path = '../../data/ecg/temp/'
     fs = 360
     sig_len = 650000
     times = dict()
     for comb in itertools.product(*combinations):
         print(comb)
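         # combination labels are 1-based; wfdb channel indices are 0-based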
         channels = [int(comb[0]) - 1]
         for name in wfdb.get_record_list('mitdb'):
             print(name)
             record = wfdb.rdrecord(ecg_path + name, channels=channels)
             record = np.transpose(record.p_signal)
             record = record[0]
             start = time.time()
             self.pan_tompkin(record, fs)
             elapsed = time.time() - start
             times[comb[0] + comb[1]] = elapsed / sig_len
     print(times)
Example #11
def download_database(
    data_folder,
    database='afdb',
):
    records = [
        ECGRPeaksRecord(r, database) for r in wfdb.get_record_list(database)
    ]
    logging.info("saving records")
    for r in records:
        try:
            r.download()
        except Exception:
            exc_type, _, traceback = sys.exc_info()
            logging.info("failed to download record:{};\n"
                         "Exception:{}\nTraceback:{}\n\n".format(
                             r.name, exc_type, traceback))
            continue
        path_to_record = os.path.join(data_folder, r.name)
        with open(path_to_record, "wb") as f:
            pickle.dump(r, f)
Example #12
 def signals_evaluation(self, threshold):
     combs = [['FS', '1'], ['FS', '2'], ['RS', '1'], ['RS', '2']]
     results = defaultdict(list)
     for comb in combs:
         comb_name = comb[0] + '_' + comb[1]
         print(comb_name)
         precisions = list()
         recalls = list()
         times = list()
         for name in wfdb.get_record_list('mitdb'):
             start_time = time.time()
             peaks = self.detect_peaks(name, threshold, comb[0], comb[1], comb)
             elapsed = time.time() - start_time
             elapsed = elapsed / 650000  # per-sample time: each MIT-BIH record is 650000 samples long
             precision, recall = rpeak.evaluate(peaks, PATH + name, eval_width, False)
             precisions.append(precision)
             recalls.append(recall)
             times.append(elapsed)
             print('{:s}, {:f}, {:f}'.format(name, precision, recall))
         print("{:s}, {:f}, {:f}, {:f}".format(comb_name, np.mean(precisions), np.mean(recalls), np.mean(times)))
         results[comb_name] = [comb_name, np.mean(precisions), np.mean(recalls), np.mean(times)]
     print(results)
Example #13
def get_ecg_records(database, channel):
    signals = []
    beat_locations = []
    beat_types = []
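    # afdb records skipped below (e.g. records without usable signal files)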
    useless_afrecord = ['00735', '03665', '04043', '08405', '08434']

    record_files = wfdb.get_record_list(database)
    print('record_files:', record_files)

    for record in record_files:
        if record in useless_afrecord:
            continue
        else:
            print('processing record:', record)
            s, f = wfdb.rdsamp(record, pn_dir=database)
            print(f)
            annotation = wfdb.rdann('{}/{}'.format(database, record),
                                    extension='atr')
            signal, annotation = resample_singlechan(s[:, channel],
                                                     annotation,
                                                     fs=f['fs'],
                                                     fs_target=250)

            print(signal)
            print(signal.shape)
            #beat_loc, beat_type = get_beats(annotation)
            signals.append(signal)
            #beat_locations.append(beat_loc)
            #beat_types.append(beat_type)
            print('size of signal list: ', len(signals))
            print('--------')
    print('---------record processed!---')

    serialization(database + '_signal', signals)
    serialization(database + '_beat_loc', beat_locations)
    serialization(database + '_beat_types', beat_types)
    #return signals, beat_locations, beat_types
    return signals
Example #14
    def _ls_rec(self,
                db_name: Optional[str] = None,
                local: bool = True) -> NoReturn:
        """ finished, checked,

        find all records (relative path without file extension),
        and save into `self._all_records` for further use

        Parameters
        ----------
        db_name: str, optional,
            name of the database for using `wfdb.get_record_list`,
            if not set, `self.db_name` will be used
        local: bool, default True,
            if True, read from local storage, prior to using `wfdb.get_record_list`
        """
        if local:
            self._ls_rec_local()
            return
        try:
            self._all_records = wfdb.get_record_list(db_name or self.db_name)
        except Exception:
            self._ls_rec_local()
Example #15

 def extract(self,
             ann_path,
             db='mitdb',
             peaks=None,
             include_vf=False,
             from_annot=True):
     """reads beat labels for each signal in a Physionet database
     :arg ann_path: path of the local folder containing the annotations files(.atr)
     :arg db: string identifier for the Physionet DB
     :arg include_vf: whether to include Ventricular Fibrillation(VF) annotations
     :arg peaks: list or numpy array containing the peaks locations. Used only if from_annot=False
     :arg from_annot: whether to associate labels to peaks from the ground truth(.atr file)
     :returns labels: a dictionary [signal_name, labels]
     """
     labels = defaultdict(list)
     if peaks is None:
         peaks = defaultdict(list)
     names = wfdb.get_record_list(db)
     for name in names:
         if name == '207' and include_vf:
             annotation = wfdb.rdann(ann_path + name + '_VF', 'atr')
             ann_samples = annotation.sample  # keep the sample locations so the from_annot branch below works
             ann_symbols = annotation.symbol
         else:
             ann_samples, ann_symbols = ut.remove_non_beat(ann_path + name)
         if from_annot:
             peaks[name] = ann_samples
             labels[name] = ann_symbols
         else:
             output_labels = list()
             for peak in peaks:
                 closest = self.take_closest(ann_samples, peak)
                 if closest == len(ann_samples):
                     output_labels.append(ann_symbols[closest - 1])
                 else:
                     output_labels.append(ann_symbols[closest])
             labels[name] = output_labels
     return labels, peaks
Example #16
 def choose_thresholds(self, thresholds):
     precisions = defaultdict(list)
     recalls = defaultdict(list)
     for name in wfdb.get_record_list('mitdb'):
         print(name)
         for thresh in thresholds:
             record, indices = self.detect_peaks(name, thresh)
             recall, precision = rpeak.evaluate(indices, PATH + name, eval_width, rule_based=False)
             precisions[thresh].append(precision)
             recalls[thresh].append(recall)
     average_prec = [np.mean(precisions[t]) for t in thresholds]
     average_rec = [np.mean(recalls[t]) for t in thresholds]
     thresh_index = np.argmax([(average_prec[j] + average_rec[j])/2 for j in range(len(average_rec))])
     best_threshold = thresholds[thresh_index]
     plt.plot([best_threshold]*2, [0, 1], label='best_threshold')
     plt.plot(thresholds, average_prec, label='precision')
     plt.plot(thresholds, average_rec, label='recall')
     plt.xlabel('threshold')
     plt.ylabel('precision/recall')
     plt.legend()
     print(average_rec)
     print(average_prec)
     plt.savefig('prec-rec-threshold-generic.png')
     return best_threshold
Example #17
'''
Created on 4 Jul 2019

@author: filipe
'''
import os
import wfdb

curr_dir = os.getcwd()
record_list = wfdb.get_record_list('mitdb', records='all')
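# build the annotation (.atr), signal (.dat) and header (.hea) file names for every record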

annotation_list = []
for i in record_list:
    annotation_list.append(i + '.atr')

signal_list = []
for i in record_list:
    signal_list.append(i + '.dat')

header_list = []
for i in record_list:
    header_list.append(i + '.hea')

data_list = annotation_list + signal_list + header_list
data_list.sort()
print(data_list)

wfdb.dl_files(db='mitdb', dl_dir=curr_dir + '/Data', files=data_list)
Example #18
def main() -> None:
    """Loads relevant data from PhysioBank using wfdb package specified in documentation and saves it to folders"""

    annotation_dict = defaultdict(
        lambda: 'error', {
            '1': [1, 0, 0, 0],
            '2': [0, 1, 0, 0],
            '3': [0, 0, 1, 0],
            '4': [0, 0, 1, 0],
            'R': [0, 0, 0, 1],
        })

    classes = defaultdict(lambda: '6', {
        '1000': '1',
        '0100': '2',
        '0010': '3',
        '0001': '4',
        '0000': '5'
    })

    class_count = defaultdict(lambda: 0, {
        '1000': 0,
        '0100': 0,
        '0010': 0,
        '0001': 0,
        '0000': 0
    })

    inputs_path, targets_path, data_path = setup_directory()

    if not os.path.exists(data_path):
        get_physionet_data()

    record_list = wfdb.get_record_list('slpdb')

    for record_index, record in enumerate(record_list):
        epochs, annotations = get_record_data(os.path.join(data_path, record))

        for annotation_index in range(len(annotations)):
            # annotations may have several labels but only one is relevant
            labels = annotations[annotation_index].split(' ')
            target = [0, 0, 0, 0]
            for label in labels:
                if annotation_dict[label] != 'error':
                    target = annotation_dict[label]

            # write input and target data to files
            record_name = get_record_name(record_index, annotation_index)

            with open(os.path.join(inputs_path, record_name),
                      'w') as filehandler:
                filehandler.write("\n".join(
                    str(num) for num in epochs[annotation_index]))

            with open(os.path.join(targets_path, record_name),
                      "w") as filehandler:
                event_class = ''.join([str(v) for v in target])
                class_count[event_class] += 1

                event_class = classes[event_class]
                filehandler.write(event_class)

        print("\r {:2d}/{:2d}".format(record_index + 1, len(record_list)),
              end=" ")

    # class count statistics
    for key, value in class_count.items():
        print("\n")
        print(key, value)
Example #19
 def horizontal_split(self,
                      classes=None,
                      timesteps=None,
                      window=None,
                      left_window=70,
                      right_window=100,
                      train_db='mitdb',
                      multiclass=True,
                      aami=True,
                      one_hot=True,
                      model=None,
                      standardize=True,
                      channels=None):
     if channels is None:
         channels = [0]
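     # fixed inter-patient split: training and test sets use disjoint MIT-BIH records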
     if train_db == 'mitdb':
         train_dataset = [
             '106', '112', '122', '201', '223', '230', "108", "109", "115",
             "116", "118", "119", "124", "205", "207", "208", "209", "215",
             '101', '114', '203', '220'
         ]
     else:
         train_dataset = wfdb.get_record_list('incartdb')
     test_dataset = [
         "100", "103", "105", "111", "113", "117", "121", "123", "200",
         "202", "210", "212", "213", "214", "219", "221", "222", "228",
         "231", "232", "233", "234"
     ]
     X_train = list()
     Y_train = list()
     X_test = list()
     Y_test = list()
     for name in train_dataset:
         beats, labels, peaks = self.extract_labeled_beats(
             name=name,
             aami=aami,
             classes=classes,
             one_hot=one_hot,
             model=model,
             train_db=train_db,
             multiclass=multiclass,
             window=window,
             left_window=left_window,
             right_window=right_window,
             channels=channels)
         X_train.extend(beats)
         Y_train.extend(labels)
     for name in test_dataset:
         beats, labels, peaks = self.extract_labeled_beats(
             name=name,
             aami=aami,
             classes=classes,
             one_hot=one_hot,
             model=model,
             train_db=train_db,
             multiclass=multiclass,
             window=window,
             left_window=left_window,
             right_window=right_window)
         X_test.extend(beats)
         Y_test.extend(labels)
     X_train = np.array(X_train)
     X_test = np.array(X_test)
     Y_train = np.array(Y_train)
     Y_test = np.array(Y_test)
     X_test, X_val, Y_test, Y_val = train_test_split(X_test,
                                                     Y_test,
                                                     test_size=0.1)
     if standardize:
         X_train, X_val, X_test = self.standardize(X_train, X_val, X_test)
     if timesteps is not None:
         X_train, Y_train = self.compute_timesteps(X_train, Y_train,
                                                   timesteps)
         X_val, Y_val = self.compute_timesteps(X_val, Y_val, timesteps)
         X_test, Y_test = self.compute_timesteps(X_test, Y_test, timesteps)
     return X_train, X_test, Y_train, Y_test
Example #20
import os
import numpy as np
from rpeakdetection.pan_tompkins.rpeak_detector import RPeakDetector
import wfdb
rpd = RPeakDetector()
evaluation_width = 36
ecg_folder = "data/ecg/mitdb/"
peaks_folder = "data/peaks/pantompkins/mitdb/"
precisions = list()
recalls = list()
for name in wfdb.get_record_list('mitdb'):
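    # load the precomputed Pan-Tompkins peak indices for this record from a .tsv file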
    peaks = list()
    print(name)
    with open(peaks_folder + name + '.tsv', "r") as file:
        for line in file:
            peaks.append(int(line.strip()))
    recall, precision = rpd.evaluate(peaks,
                                     ecg_folder + name,
                                     evaluation_width,
                                     rule_based=True)

    print('recall : ' + str(recall))
    print('precision : ' + str(precision))

    precisions.append(precision)
    recalls.append(recall)
print("av prec")
print(np.mean(precisions))
print("av recall")
print(np.mean(recalls))
Example #21

if __name__ == "__main__":

    print('Change made in develop branch')

    # initialize empty lists for evaluation metrics
    CR_arr = []
    PRD_arr = []
    R_arr = []
    SNRo_arr = []
    SNRr_arr = []
    d_SNR = []

    # loop all records in database
    for record in wfdb.get_record_list(db_name, records='all'):

        # record = 'Person_01/rec_10'
        # record = '203'

        # get data for current record
        data = wfdb.rdsamp(record, pb_dir=db_name + '/' + record.split('/')[0])
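        # rdsamp returns a (signal_array, fields_dict) tuple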

        Fs = data[1]['fs']
        ecg = data[0][:, 0]  #[0:Fs*20]

        # zero-mean
        ecg = ecg - np.mean(ecg)

        # call compression function
        CR, ecg_compressed, wc_orig = compress(ecg, Fs)
Example #22
import streamlit as st
import pandas as pd
import numpy as np
import wfdb
import scaleogram as scg
import re

st.title('ECG Visualization from https://physionet.org/')

record_list = wfdb.get_record_list('butqdb')
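# keep only records whose names end in 'ECG', reduced to their top-level folder name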
record_final_list = []
for rcd in record_list:
    if re.search(r'ECG$', rcd):
        record_final_list.append(rcd.split('/')[0])

chosen_record = st.sidebar.selectbox('Which record ?', record_final_list)

time_lap = list(range(0, 1500))
chosen_time = st.sidebar.selectbox('Which start time (min) ?', time_lap)
frequency = [1000, 500, 250]
#chosen_frequency = st.sidebar.selectbox(
#   'Which frequency (Hz) ?',
#   frequency
#)
secondstart = chosen_time * 10000
sampfromvalue, samptovalue = st.slider('Select a range of values', secondstart,
                                       secondstart + 10000,
                                       (secondstart, secondstart + 2000))


#pn_dir='butqdb'.
Example #23

def create_index_df(desired_segment_len=3600,
                    basic_arr_path="data/mit-bih-arrhythmia-database-1.0.0"):
    arr_db = wfdb.get_record_list('mitdb')
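    # read the first 30 minutes of annotations (MIT-BIH is sampled at 360 Hz)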
    num_samples_in_record = 30 * 60 * 360

    # for selection and sampling
    segment_dict_ann = {}
    record_count = 0

    for _, record_id in enumerate(arr_db):
        record_path = os.path.join(basic_arr_path, str(record_id))

        ann = wfdb.rdann(
            record_path,
            'atr',
            sampto=num_samples_in_record,
            return_label_elements=['description', 'symbol', 'label_store'])
        df = pd.DataFrame({
            'description': ann.description,
            'sample': ann.sample,
            'symbol': ann.symbol,
            'label_store': ann.label_store,
            'aux': ann.aux_note
        })
        df = dynamic_replace(df)
        counter = 0
        reset_flag = True
        allowed_labels = ['Normal beat']
        allowed_symbols = ['N']

        normal_counter = 0
        for i in range(1, df.shape[0] - 1):
            curr_label, curr_sample, curr_symbol = df.loc[
                i, ['description', 'sample', 'symbol']]
            if curr_label == 'Normal beat':
                normal_counter += 1
            if reset_flag:
                start_sample = curr_sample
                ann_num_start = i
                normal_counter = 0
                allowed_labels = ['Normal beat']
                allowed_symbols = ['N']
            next_label, next_sample, next_symbol = df.loc[
                i + 1, ['description', 'sample', 'symbol']]
            if curr_label == next_label or next_label in allowed_labels or len(
                    allowed_labels) < 2:

                if next_label not in allowed_labels:
                    allowed_labels.append(next_label)
                    allowed_symbols.append(next_symbol)
                ann_num_end = i + 1
                counter += next_sample - curr_sample
                reset_flag = False
                if counter > desired_segment_len:
                    counter = 0
                    reset_flag = True
                    signal = wfdb.rdsamp(record_path,
                                         sampfrom=start_sample,
                                         sampto=start_sample + 3600)[0][:, 0]
                    normal_ratio = normal_counter / (ann_num_end -
                                                     ann_num_start)

                    if df.loc[ann_num_start:ann_num_end]['aux'].unique(
                    ).shape[0] == 1:
                        aux_seg = df.loc[ann_num_start]['aux']
                    else:
                        aux_seg = 'invalid'

                    segment_dict_ann[record_count] = [
                        record_id, allowed_labels[-1], signal, normal_ratio,
                        allowed_symbols[-1], aux_seg
                    ]
                    record_count += 1
            else:
                counter = 0
                normal_counter = 0
                reset_flag = True
                allowed_labels = ['Normal beat']
                allowed_symbols = ['N']

    return segment_dict_ann
Example #24
'''
Created on 8 Aug 2019

@author: filipe
'''

import wfdb
import pandas as pd
import numpy as np
import separate_beats
import json

if __name__ == '__main__':
    record_list = wfdb.get_record_list(db_dir='mitdb', records='all')
    columns = [
        'Class', 'Distance to Previous Beat', 'Distance to Next Beat', 'Beat'
    ]
    signal_df = pd.DataFrame(columns=columns)

    for i in record_list:
        print(i)
        record, fields = wfdb.rdsamp(record_name='Data/' + i,
                                     sampfrom=0,
                                     channels=[0])
        annotations = wfdb.rdann(record_name='Data/' + i,
                                 extension='atr',
                                 sampfrom=0)
        signal_df = separate_beats.aha_update_beats_df(record, annotations,
                                                       signal_df)

    print(signal_df.head())
Example #25
import os
import wfdb
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from sklearn.metrics import confusion_matrix
import seaborn as sn

os.chdir('C:\\Users\\Jerry\\Desktop\\Jerry\\projects\\heartbeat_python')

#%% download all the files
test = wfdb.get_dbs()
wfdb.get_record_list('mitdb')
wfdb.dl_database(
    'mitdb',
    'C:\\Users\\Jerry\\Desktop\\Jerry\\projects\\Heartbeat Python\\data')


#%% building a loop to read in data and annotations then cutting the ECG into heartbeats
def simple_plot(x, alpha=1, grid=True):
    plt.plot(np.arange(0, len(x)), x, alpha=alpha)
    if grid:
        plt.grid(True)


sample = wfdb.rdsamp('data\\101')[0][:, 0]  # read in the record

annotation = wfdb.rdann('data\\101', 'atr')  # read the beat annotations ('atr' extension assumed)
Example #26
def main() -> None:
    """Loads relevant data from PhysioBank using wfdb package specified in documentation and saves it to folders"""

    annotation_dict = defaultdict(lambda: 5, {
        '1': 0,
        '2': 1,
        '3': 2,
        '4': 2,
        'R': 3
    })

    classes = defaultdict(lambda: '6', {
        '1000': '1',
        '0100': '2',
        '0010': '3',
        '0001': '4',
        '0000': '5'
    })

    class_count = defaultdict(lambda: 0, {
        '1000': 0,
        '0100': 0,
        '0010': 0,
        '0001': 0,
        '0000': 0
    })

    inputs_path, targets_path, data_path = folder_setup()
    record_list = wfdb.get_record_list('slpdb')

    for record_index, record in enumerate(record_list):
        # read the annotations and data then create a record
        annsamp = wfdb.rdann(str(data_path / record), extension='st', summarize_labels=True)
        signal = wfdb.rdrecord(str(data_path / record), channels=[2])
        physical_signal = signal.p_signal
        physical_signal = preprocessing.scale(physical_signal)

        # remove unannotated epochs (30 second input segments) from the start of the record and split into inputs
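        # one epoch is 7500 samples: 30 s at the slpdb sampling rate of 250 Hz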
        number_annotations = len(annsamp.aux_note)
        starting_index = int((len(physical_signal) / 7500) - number_annotations)*7500
        physical_signal = physical_signal[starting_index:]
        inputs = np.split(physical_signal, number_annotations)

        # build each target as a 4-bit vector [N1, N2, N3/N4, REM]; all zeros denotes wake
        target = [[0]*4 for _ in range(number_annotations)]

        # annotate each input and write its data and annotation to separate files
        for annotation_index in range(number_annotations):
            labels = annsamp.aux_note[annotation_index].split(' ')
            for label in labels:
                if annotation_dict[label] != 5:
                    target[annotation_index][annotation_dict[label]] = 1

            # write each input to a csv file, named by record number and input number
            with open(str(inputs_path / (str(record_index) + '_' + str(annotation_index) + '.csv')),
                      'w') as filehandler:
                filehandler.write("\n".join(str(num) for num in inputs[annotation_index]))

            with open(str(targets_path / (str(record_index) + '_' + str(annotation_index) + ".csv")), "w") as \
                    filehandler:
                event_class = ''.join([str(v) for v in target[annotation_index]])
                class_count[event_class] += 1
                event_class = classes[event_class]

                filehandler.write(event_class)

        print("\r {:2d}/{:2d}".format(record_index + 1, len(record_list)), end=" ")

    # class count statistics
    for key, value in class_count.items():
        print("\n")
        print(key, value)