def test_1(self):

        # Read data using WFDB python package
        annotation = wfdb.rdann('sampledata/100', 'atr')

        
        # This is not the fault of the script. The annotation file specifies a
        # length-3 aux field with a null written after '(N', which the script
        # correctly picks up. The null is removed here so that the field can be
        # compared with the regexp output below, which has no null to detect in
        # the text output file of rdann.
        annotation.aux[0] = '(N'

        # Target data from WFDB software package
        lines = tuple(open('tests/targetoutputdata/anntarget1', 'r'))
        nannot = len(lines)

        Ttime = [None] * nannot
        Tannsamp = np.empty(nannot, dtype='object')
        Tanntype = [None] * nannot
        Tsubtype = np.empty(nannot, dtype='object')
        Tchan = np.empty(nannot, dtype='object')
        Tnum = np.empty(nannot, dtype='object')
        Taux = [None] * nannot

        RXannot = re.compile(
            '[ \t]*(?P<time>[\[\]\w\.:]+) +(?P<annsamp>\d+) +(?P<anntype>.) +(?P<subtype>\d+) +(?P<chan>\d+) +(?P<num>\d+)\t?(?P<aux>.*)')

        for i in range(0, nannot):
            Ttime[i], Tannsamp[i], Tanntype[i], Tsubtype[i], Tchan[
                i], Tnum[i], Taux[i] = RXannot.findall(lines[i])[0]

        # Convert objects into integers
        Tannsamp = Tannsamp.astype('int')
        Tnum = Tnum.astype('int')
        Tsubtype = Tsubtype.astype('int')
        Tchan = Tchan.astype('int')

        # Compare
        comp = [np.array_equal(annotation.annsamp, Tannsamp), 
                np.array_equal(annotation.anntype, Tanntype), 
                np.array_equal(annotation.subtype, Tsubtype), 
                np.array_equal(annotation.chan, Tchan), 
                np.array_equal(annotation.num, Tnum), 
                annotation.aux == Taux]

        # Test file streaming
        pbannotation = wfdb.rdann('100', 'atr', pbdir = 'mitdb')
        pbannotation.aux[0] = '(N'
        
        # Test file writing
        annotation.wrann(writefs=True)
        annotationwrite = wfdb.rdann('100', 'atr')

        assert (comp == [True] * 6)
        assert annotation.__eq__(pbannotation)
        assert annotation.__eq__(annotationwrite)
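For reference, the regular expression above parses the plain-text output of the WFDB `rdann` command-line tool. A minimal, self-contained illustration on a made-up line (the field values are illustrative, not taken from record 100):

import re

RXannot = re.compile(
    r'[ \t]*(?P<time>[\[\]\w\.:]+) +(?P<annsamp>\d+) +(?P<anntype>.) +'
    r'(?P<subtype>\d+) +(?P<chan>\d+) +(?P<num>\d+)\t?(?P<aux>.*)')

# One illustrative rdann text line: time, sample, symbol, subtype, chan, num, aux
line = '    0:00.050       18     +    0    0    0\t(N'
print(RXannot.findall(line)[0])
# -> ('0:00.050', '18', '+', '0', '0', '0', '(N')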
Example #2
    def test_3(self):
        """
        Annotation file with custom annotation types

        Target file created with:
            rdann -r sample-data/1003 -a atr > ann-3
        """
        annotation = wfdb.rdann('sample-data/1003', 'atr')

        # Target data from WFDB software package
        lines = tuple(open('tests/target-output/ann-3', 'r'))
        nannot = len(lines)

        target_time = [None] * nannot
        target_sample = np.empty(nannot, dtype='object')
        target_symbol = [None] * nannot
        target_subtype = np.empty(nannot, dtype='object')
        target_chan = np.empty(nannot, dtype='object')
        target_num = np.empty(nannot, dtype='object')
        target_aux_note = [None] * nannot

        RXannot = re.compile(
            '[ \t]*(?P<time>[\[\]\w\.:]+) +(?P<sample>\d+) +(?P<symbol>.) +(?P<subtype>\d+) +(?P<chan>\d+) +(?P<num>\d+)\t?(?P<aux_note>.*)')

        for i in range(0, nannot):
            target_time[i], target_sample[i], target_symbol[i], target_subtype[i], target_chan[
                i], target_num[i], target_aux_note[i] = RXannot.findall(lines[i])[0]

        # Convert objects into integers
        target_sample = target_sample.astype('int')
        target_num = target_num.astype('int')
        target_subtype = target_subtype.astype('int')
        target_chan = target_chan.astype('int')

        # Compare
        comp = [np.array_equal(annotation.sample, target_sample),
                np.array_equal(annotation.symbol, target_symbol),
                np.array_equal(annotation.subtype, target_subtype),
                np.array_equal(annotation.chan, target_chan),
                np.array_equal(annotation.num, target_num),
                annotation.aux_note == target_aux_note]

        # Test file streaming
        pbannotation = wfdb.rdann('1003', 'atr', pb_dir='challenge/2014/set-p2', return_label_elements=['label_store', 'symbol'])
        pbannotation.create_label_map()

        # Test file writing
        annotation.wrann(write_fs=True)
        writeannotation = wfdb.rdann('1003', 'atr', return_label_elements=['label_store', 'symbol'])
        writeannotation.create_label_map()

        assert (comp == [True] * 6)
        assert annotation.__eq__(pbannotation)
        assert annotation.__eq__(writeannotation)
    def test_3(self):

        # Read data using WFDB python package
        annotation = wfdb.rdann('sampledata/1003', 'atr')

        # Target data from WFDB software package
        lines = tuple(open('tests/targetoutputdata/anntarget3', 'r'))
        nannot = len(lines)

        Ttime = [None] * nannot
        Tannsamp = np.empty(nannot, dtype='object')
        Tanntype = [None] * nannot
        Tsubtype = np.empty(nannot, dtype='object')
        Tchan = np.empty(nannot, dtype='object')
        Tnum = np.empty(nannot, dtype='object')
        Taux = [None] * nannot

        RXannot = re.compile(
            '[ \t]*(?P<time>[\[\]\w\.:]+) +(?P<annsamp>\d+) +(?P<anntype>.) +(?P<subtype>\d+) +(?P<chan>\d+) +(?P<num>\d+)\t?(?P<aux>.*)')

        for i in range(0, nannot):
            Ttime[i], Tannsamp[i], Tanntype[i], Tsubtype[i], Tchan[
                i], Tnum[i], Taux[i] = RXannot.findall(lines[i])[0]

        # Convert objects into integers
        Tannsamp = Tannsamp.astype('int')
        Tnum = Tnum.astype('int')
        Tsubtype = Tsubtype.astype('int')
        Tchan = Tchan.astype('int')

        # Compare
        comp = [np.array_equal(annotation.annsamp, Tannsamp), 
                np.array_equal(annotation.anntype, Tanntype), 
                np.array_equal(annotation.subtype, Tsubtype), 
                np.array_equal(annotation.chan, Tchan), 
                np.array_equal(annotation.num, Tnum), 
                annotation.aux == Taux]

        # Test file streaming
        pbannotation = wfdb.rdann('1003', 'atr', pbdir = 'challenge/2014/set-p2')
        
        # Test file writing
        annotation.wrann(writefs=True)
        annotationwrite = wfdb.rdann('1003', 'atr')

        assert (comp == [True] * 6)
        assert annotation.__eq__(pbannotation)
        assert annotation.__eq__(annotationwrite)
Example #4
def show_objective():
    """ For the model """
    # Choose a record
    records = dm.get_records()
    path = records[17]
    record = wf.rdsamp(path)
    ann = wf.rdann(path, 'atr')

    chid = 0
    print('Channel:', record.signame[chid])

    cha = record.p_signals[:, chid]

    # These were found manually
    sta = 184000
    end = sta + 1000
    times = np.arange(end-sta, dtype = 'float')
    times /= record.fs

    # Extract the annotations for that fragment
    where = (sta < ann.annsamp) & (ann.annsamp < end)
    samples = ann.annsamp[where] - sta
    print(samples)

    # Prepare dirac-comb type of labels
    qrs_values = np.zeros_like(times)
    qrs_values[samples] = 1

    # Prepare gaussian-comb type of labels
    kernel = ss.hamming(36)
    qrs_gauss = np.convolve(kernel,
                            qrs_values,
                            mode = 'same')

    # Make the plots
    fig = plt.figure()
    ax1 = fig.add_subplot(3,1,1)
    ax1.plot(times, cha[sta : end])

    ax2 = fig.add_subplot(3,1,2, sharex=ax1)
    ax2.plot(times,
             qrs_values,
             'C1',
             lw = 4,
             alpha = 0.888)
    ax3 = fig.add_subplot(3,1,3, sharex=ax1)
    ax3.plot(times,
             qrs_gauss,
             'C3',
             lw = 4,
             alpha = 0.888)
    plt.setp(ax1.get_xticklabels(), visible=False)
    plt.setp(ax2.get_xticklabels(), visible=False)
    plt.xlabel('Time [s]')
    plt.xlim([0, 2.5])
    plt.show()
    def test_1(self):
        sig, fields = wfdb.srdsamp('sampledata/100')
        ann = wfdb.rdann('sampledata/100', 'atr')

        fs = fields['fs']
        fs_target = 50

        new_sig, new_ann = wfdb.processing.resample_singlechan(sig[:, 0], ann, fs, fs_target)

        expected_length = int(sig.shape[0]*fs_target/fs)

        assert new_sig.shape[0] == expected_length
Example #6
    def test_resample_multi(self):
        sig, fields = wfdb.rdsamp('sample-data/100')
        ann = wfdb.rdann('sample-data/100', 'atr')

        fs = fields['fs']
        fs_target = 50

        new_sig, new_ann = processing.resample_multichan(sig, ann, fs, fs_target)

        expected_length = int(sig.shape[0]*fs_target/fs)

        assert new_sig.shape[0] == expected_length
        assert new_sig.shape[1] == sig.shape[1]
Example #7
def show_path(path):
    """ As a plot """
    # Read in the data
    record = wf.rdsamp(path)
    annotation = wf.rdann(path, 'atr')
    data = record.p_signals
    cha = data[:, 0]
    print('Channel type:', record.signame[0])
    times = np.arange(len(cha), dtype = float)
    times /= record.fs
    plt.plot(times, cha)
    plt.xlabel('Time [s]')
    plt.show()
    def test_7(self):
        sig, fields = wfdb.srdsamp('sampledata/100', channels = [0, 1])
        ann = wfdb.rdann('sampledata/100', 'atr')
        fs = fields['fs']
        min_bpm = 10
        max_bpm = 350
        min_gap = fs*60/min_bpm
        max_gap = fs*60/max_bpm

        y_idxs = wfdb.processing.correct_peaks(sig[:,0], ann.annsamp, min_gap, max_gap, smooth_window=150)

        yz = numpy.zeros(sig.shape[0])
        yz[y_idxs] = 1
        yz = numpy.where(yz[:10000]==1)[0]

        assert numpy.array_equal(yz, [77, 370, 663, 947, 1231, 1515, 1809, 2045, 2403, 2706, 2998, 3283, 3560, 3863, 4171, 4466, 4765, 5061, 5347, 5634, 5919, 6215, 6527, 6824, 7106, 7393, 7670, 7953, 8246, 8539, 8837, 9142, 9432, 9710, 9998])
Example #9
    def __getitem__(self, idx):
        folder_name = os.path.join(self.root_dir,
                                   self.landmarks_frame.iloc[idx, 0])
        file_name = self.landmarks_frame.iloc[idx, 0]
        #         print(file_name)
        #         print(folder_name)
        #         file_name='tr03-0005/'
        #         folder_name='../data/training/tr03-0005/'
        signals = wfdb.rdrecord(os.path.join(folder_name, file_name[:-1]))
        arousals = h5py.File(
            os.path.join(folder_name, file_name[:-1] + '-arousal.mat'), 'r')
        tst_ann = wfdb.rdann(os.path.join(folder_name, file_name[:-1]),
                             'arousal')

        POI = []
        for typ in ['arousal_rera', 'resp_hypopnea', 'resp_centralapnea']:
            start_idx = np.where(np.array(tst_ann.aux_note) == '(' + typ)
            end_idx = np.where(np.array(tst_ann.aux_note) == typ + ')')

            _starts = tst_ann.sample[start_idx]
            _ends = tst_ann.sample[end_idx]

            _width = np.subtract(_ends, _starts)
            _centers = _starts + _width // 2
            POI = np.append(POI, _centers)

        W = tst_ann.sample[np.where(np.array(tst_ann.aux_note) == 'W')]
        N1 = tst_ann.sample[np.where(np.array(tst_ann.aux_note) == 'N1')]
        N2 = tst_ann.sample[np.where(np.array(tst_ann.aux_note) == 'N2')]
        N3 = tst_ann.sample[np.where(np.array(tst_ann.aux_note) == 'N3')]

        POI = np.append(POI, W)
        POI = np.append(POI, N1)
        POI = np.append(POI, N2)
        POI = np.append(POI, N3)
        np.random.shuffle(POI)
        #         POI = arousal_centers
        interested = []
        for i in range(13):
            if signals.sig_name[i] in [
                    'SaO2', 'ABD', 'F4-M1', 'C4-M1', 'O2-M1', 'AIRFLOW'
            ]:
                interested.append(i)
        sample = ((signals.p_signal[:, interested], POI),
                  arousals['data']['arousals'][()].ravel())
        return sample
Example #10
def load_wfdb(path, components, *args, **kwargs):
    """Load given components from wfdb file.

    Parameters
    ----------
    path : str
        Path to .hea file.
    components : iterable
        Components to load.
    ann_ext: str
        Extension of the annotation file.

    Returns
    -------
    ecg_data : list
        List of ecg data components.
    """
    _ = args

    ann_ext = kwargs.get("ann_ext")

    path = os.path.splitext(path)[0]
    record = wfdb.rdrecord(path)
    signal = record.__dict__.pop("p_signal").T
    record_meta = record.__dict__
    nsig = record_meta["n_sig"]

    if "annotation" in components and ann_ext is not None:
        annotation = wfdb.rdann(path, ann_ext)
        annot = {"annsamp": annotation.sample,
                 "anntype": annotation.symbol}
    else:
        annot = {}

    # Initialize meta with defined keys, load values from record
    # meta and preprocess to our format.
    meta = dict(zip(META_KEYS, [None] * len(META_KEYS)))
    meta.update(record_meta)

    meta["signame"] = check_signames(meta.pop("sig_name"), nsig)
    meta["units"] = check_units(meta["units"], nsig)

    data = {"signal": signal,
            "annotation": annot,
            "meta": meta}
    return [data[comp] for comp in components]
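A hypothetical call (the record path and component list are assumptions; META_KEYS, check_signames and check_units are expected to be defined elsewhere in the same module):

signal, annot, meta = load_wfdb('mitdb/100.hea',
                                ['signal', 'annotation', 'meta'],
                                ann_ext='atr')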
Example #11
def _read_signal(self, file_path, sample_freq):
    # Note: this looks like a method excerpt; `file_path` and `sample_freq` are
    # assumed parameters here, and `low_freq`/`high_freq` are instance attributes.
    record = wfdb.rdrecord(file_path)
    annotation = wfdb.rdann(file_path, 'atr')
    annotated_intervals = list(zip(annotation.sample, annotation.aux_note))
    
    signal_ch1 = record.p_signal[:, 0][3000:-3000]
    signal_ch2 = record.p_signal[:, 2][3000:-3000]
    signal_ch3 = record.p_signal[:, 4][3000:-3000]
    
    signal_ch1 = butter_bandpass_filter(signal_ch1, self.low_freq, 
                                        self.high_freq, sample_freq, order=4)
    signal_ch2 = butter_bandpass_filter(signal_ch2, self.low_freq, 
                                        self.high_freq, sample_freq, order=4)
    signal_ch3 = butter_bandpass_filter(signal_ch3, self.low_freq, 
                                        self.high_freq, sample_freq, order=4)

    return signal_ch1, signal_ch2, signal_ch3, annotated_intervals
Example #12
    def test_xqrs(self):
        """
        Run xqrs detector on record 100 and compare to reference annotations
        """
        sig, fields = wfdb.rdsamp('sample-data/100', channels=[0])
        ann_ref = wfdb.rdann('sample-data/100','atr')

        xqrs = processing.XQRS(sig=sig[:,0], fs=fields['fs'])
        xqrs.detect()

        comparitor = processing.compare_annotations(ann_ref.sample[1:],
                                                    xqrs.qrs_inds,
                                                    int(0.1 * fields['fs']))

        assert comparitor.specificity > 0.99
        assert comparitor.positive_predictivity > 0.99
        assert comparitor.false_positive_rate < 0.01
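The same detection is also available as a one-call convenience wrapper; a sketch assuming the wfdb.processing API used in this test:

qrs_inds = processing.xqrs_detect(sig=sig[:, 0], fs=fields['fs'], verbose=False)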
Example #13
        def dataprocess():
            input_size = config.input_size
            #for num in tqdm(dataSet):
            num = '119'
            from wfdb import rdrecord, rdann
            record = rdrecord('dataset/' + num, smooth_frames=True)
            from sklearn import preprocessing
            signals0 = preprocessing.scale(np.nan_to_num(
                record.p_signal[:, 0])).tolist()
            signals1 = preprocessing.scale(np.nan_to_num(
                record.p_signal[:, 1])).tolist()
            from scipy.signal import find_peaks
            peaks, _ = find_peaks(signals0, distance=150)

            feature0, feature1 = record.sig_name[0], record.sig_name[
                1]  #feature name e.g. ML11 V1 V2

            global lappend0, lappend1, dappend0, dappend1
            lappend0 = datalabel[feature0].append
            lappend1 = datalabel[feature1].append
            dappend0 = datadict[feature0].append
            dappend1 = datadict[feature1].append

            # skip a first peak to have enough range of the sample
            for peak in peaks[1:-1]:
                #for peak in tqdm(peaks[1:-1]):
                start, end = peak - input_size // 2, peak + input_size // 2
                ann = rdann('dataset/' + num,
                            extension='atr',
                            sampfrom=start,
                            sampto=end,
                            return_label_elements=['symbol'])

                def to_dict(chosenSym):
                    y = [0] * Nclass
                    y[classes.index(chosenSym)] = 1
                    lappend0(y)
                    lappend1(y)
                    dappend0(signals0[start:end])
                    dappend1(signals1[start:end])

                annSymbol = ann.symbol
                # remove some of "N" which breaks the balance of dataset
                if len(annSymbol) == 1 and (annSymbol[0] in classes) and (
                        annSymbol[0] != "N" or np.random.random() < 0.15):
                    to_dict(annSymbol[0])
Example #14
    def load_beat_ann(self, rec:str, sampfrom:Optional[int]=None, sampto:Optional[int]=None, keep_original:bool=False) -> Dict[str, np.ndarray]:
        """ finished, checked,

        load beat annotations,
        which are stored in the `symbol` attribute of corresponding annotation files
        
        Parameters
        ----------
        rec: str,
            name of the record
        sampfrom: int, optional,
            start index of the annotations to be loaded
        sampto: int, optional,
            end index of the annotations to be loaded
        keep_original: bool, default False,
            if True, indices will keep the same with the annotation file
            otherwise subtract `sampfrom` if specified
        
        Returns
        -------
        ann, dict,
            locations (indices) of all the beat types ("A", "N", "Q", "V")
        """
        fp = os.path.join(self.db_dir, rec)
        header = wfdb.rdheader(fp)
        sig_len = header.sig_len
        sf = sampfrom or 0
        st = sampto or sig_len
        assert st > sf, "`sampto` should be greater than `sampfrom`!"

        wfdb_ann = wfdb.rdann(
            fp,
            extension=self.manual_ann_ext,
            sampfrom=sampfrom or 0,
            sampto=sampto,
        )
        ann = ED({k: [] for k in self.all_beat_types})
        for idx, bt in zip(wfdb_ann.sample, wfdb_ann.symbol):
            if bt not in self.all_beat_types:
                continue
            ann[bt].append(idx)
        if not keep_original and sampfrom is not None:
            ann = ED({k: np.array(v, dtype=int) - sampfrom for k, v in ann.items()})
        else:
            ann = ED({k: np.array(v, dtype=int) for k, v in ann.items()})
        return ann
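A hypothetical call, assuming `reader` is an instance of the enclosing database-reader class and '100' is a valid record name:

beat_locs = reader.load_beat_ann('100', sampfrom=0, sampto=10000)
print({bt: len(inds) for bt, inds in beat_locs.items()})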
Example #15
    def test_xqrs(self):
        """
        Run xqrs detector on record 100 and compare to reference annotations
        """
        sig, fields = wfdb.rdsamp('sample-data/100', channels=[0])
        ann_ref = wfdb.rdann('sample-data/100', 'atr')

        xqrs = processing.XQRS(sig=sig[:, 0], fs=fields['fs'])
        xqrs.detect()

        comparitor = processing.compare_annotations(ann_ref.sample[1:],
                                                    xqrs.qrs_inds,
                                                    int(0.1 * fields['fs']))

        assert comparitor.specificity > 0.99
        assert comparitor.positive_predictivity > 0.99
        assert comparitor.false_positive_rate < 0.01
Example #16
def creat_data():
    for i in range(len(target_class)):
        s=data[target_class[i]]
        R_poses = [np.array([]) for i in range(len(s))]
        beat = [[] for i in range(len(s))]  # record, beat, lead
        class_ID = [[] for i in range(len(s))]
        valid_R = [np.array([]) for i in range(len(s))]
        for k in range(len(s)):
            print(data_path+'/'+ str(s[k])+target_class[i])
            record = wfdb.rdrecord(data_path + '/' + str(s[k]), sampfrom=0, channel_names=['MLII'])
            sigal = record.p_signal
            record = wfdb.rdrecord(data_path + '/' + str(s[k]), sampfrom=0, channels=[1])
            sigal2=record.p_signal
            sigal=pre_pro(sigal)
            np.save('D:/python/svm/mit-bih/'+ str(s[k])+'.npy',sigal)
            sigal2=pre_pro(sigal2)
            annotation = wfdb.rdann(data_path+'/' + str(s[k]),'atr')
            for j in range(annotation.ann_len-1):
                pos=annotation.sample[j]
                if pos >= size_RR_max and pos + size_RR_max <= sigal.shape[0]:
                    index, value = max(enumerate(sigal[pos - size_RR_max: pos + size_RR_max]), key=operator.itemgetter(1))
                    pos = (pos - size_RR_max) + index
                R_poses[k] = np.append(R_poses[k], pos)
                if annotation.symbol[j] in MITBIH_classes:
                    if (pos > 90 and pos < (len(sigal) - 90)):
                        sign = sigal[pos - 90:pos + 90]
                        sign2=sigal2[pos - 90:pos + 90]
                        beat[k].append((sign,sign2))
                        valid_R[k] = np.append(valid_R[k], 1)
                        if annotation.symbol[j]=='N' or annotation.symbol[j]=='L' or annotation.symbol[j]=='R' :
                            class_ID[k].append(0)
                        elif annotation.symbol[j] == 'A' or annotation.symbol[j] == 'a' or annotation.symbol[j] == 'J' or annotation.symbol[j]=='e' or annotation.symbol[j]=='j' or annotation.symbol[j]=='S':
                            class_ID[k].append(1)
                        elif annotation.symbol[j] == 'V' or annotation.symbol[j] == 'E':
                            class_ID[k].append(2)
                        elif annotation.symbol[j] == 'F' :
                            class_ID[k].append(3)
                    else:
                        valid_R[k] = np.append(valid_R[k], 0)
                else:
                    valid_R[k] = np.append(valid_R[k], 0)
        np.save('D:/python/svm/npy_180/data_'+target_class[i]+'_seg.npy', beat)
        np.save('D:/python/svm/npy_180/label_'+target_class[i]+'_seg.npy', class_ID)
        np.save('D:/python/svm/npy_180/R_'+target_class[i]+'.npy', R_poses)
        np.save('D:/python/svm/npy_180/valid_R_' + target_class[i] + '.npy', valid_R)
        print(len(beat))
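The if/elif chain above maps MIT-BIH beat symbols to four AAMI-style classes; the same mapping expressed as a lookup table (a sketch mirroring the branches in creat_data):

SYMBOL_TO_CLASS = {
    **{s: 0 for s in 'NLR'},     # normal and bundle-branch-block beats
    **{s: 1 for s in 'AaJejS'},  # supraventricular ectopic beats
    **{s: 2 for s in 'VE'},      # ventricular ectopic beats
    'F': 3,                      # fusion beats
}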
Example #17
def main():
    FILE_DIR = "data/"

    normal_final = []
    abnormal_final = []
    abnormal_label_final = []

    for filename in os.listdir(FILE_DIR):
        if filename.endswith(".dat"):
            
            fn = FILE_DIR + filename[:-4]
            print(fn)

            sample, _ = wfdb.rdsamp(fn)
            annotation = wfdb.rdann(fn, 'atr')

            sample_array, _ = zip(*sample)
            xqrs = wdpc.XQRS(sig=np.asarray(sample_array), fs=360)
            xqrs.detect()

            n_s, a_s, a_l = split_samples_by_annotation(sample_array, annotation, xqrs.qrs_inds)

            print(len(n_s))
            print(len(a_s))
            for sig in n_s:
                normal_final.append(sig)

            for asig in a_s:
                abnormal_final.append(asig) 
            
            for label in a_l:
                abnormal_label_final.append(label)

            print(np.shape(normal_final))
            print(np.shape(abnormal_final))
            print(np.shape(abnormal_label_final))
    
    with open('normal.pickle', 'wb') as norm_file:
        pickle.dump(normal_final, norm_file)

    with open('abnormal.pickle', 'wb') as abnorm_file:
        pickle.dump(abnormal_final, abnorm_file)

    with open('abnormal_label.pickle', 'wb') as abnorm_label_file:
        pickle.dump(abnormal_label_final, abnorm_label_file)
Example #18
    def read_data(self, record):
        """
        This method reads files with the WFDB package to deal with .hea, .atr, .dat files...
        :return: data_set, label_set
        """

        # Initialization
        _tmp_data_set = []
        _tmp_label_set = []

        print('reading case #{}'.format(record))
        record = self._dir + record
        rec = wfdb.rdrecord(record_name=record, channels=[0], physical=False)
        ann = wfdb.rdann(record_name=record,
                         extension='atr',
                         return_label_elements=['symbol'])

        # Make a screener to fire 5 types of annotation
        ann_ids = np.in1d(ann.symbol, self._ann)
        beats = np.array(ann.sample)[ann_ids]
        label = np.array(ann.symbol)[ann_ids]
        label = DataPreprocessing.str_to_int(label)
        sig = rec.d_signal.ravel()

        # Abandon head and tail [1: -1]
        for j, beat in enumerate(beats[1:-1]):

            # Beat width = 256 sample points
            _from, _to = beat - 128, beat + 128

            if _from < 0:
                # skip uncompleted beat
                continue
            buffer = self.process(sig, _from, _to)

            # Append sample to _tmp_data_set and _tmp_label_set
            _tmp_data_set = np.concatenate(
                (_tmp_data_set,
                 [buffer])) if len(_tmp_data_set) else [buffer]
            _tmp_label_set = np.concatenate(
                (_tmp_label_set, [label[j]]),
                axis=-1) if len(_tmp_label_set) else [label[j]]

        return _tmp_data_set, _tmp_label_set
Example #19
def resample(records):

    lead2_signal = []
    lead2_annotation = []

    for i in tqdm(records, desc='Resample'):
        _, fields = wfdb.rdsamp(i, sampto=1)
        if 'MLII' in fields['sig_name']:
            ch = fields['sig_name'].index('MLII')
            signals, _ = wfdb.rdsamp(i, channels=[ch])
            signal_250Hz = signal.resample_poly(signals, 25, 36)
            lead2_signal.append(signal_250Hz)
            ann = wfdb.rdann(i, 'atr')
            sample_250Hz = ann.sample * 25 / 36
            annotation = list(zip(sample_250Hz, ann.symbol))
            lead2_annotation.append(annotation)

    return lead2_signal, lead2_annotation
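Note that `ann.sample * 25 / 36` yields floating-point positions; if integer sample indices at 250 Hz are needed, a minimal sketch of the rounding step:

import numpy as np
# round the rescaled annotation positions back to integer sample indices
sample_250Hz = np.round(ann.sample * 25 / 36).astype(int)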
Example #20
def load_data(pr):
    try:
        file_name = path_to_database + '/' + pr
        record = wfdb.rdrecord(record_name=file_name)
        annotations = wfdb.rdann(file_name, 'atr')
        rec_signal = record.p_signal[:, 0]
        ant_symbol = annotations.symbol
        ant_sample = annotations.sample

        bad_signal = []
        for item, position in zip(ant_symbol, ant_sample):
            if item in abnormal_beats:
                bad_signal.append(position)

        return rec_signal, ant_symbol, ant_sample, bad_signal

    except ValueError:
        print('Error while loading data!')
Example #21
def OpenApneaECG():
    db_path = 'D:\\StudiumAddidtional\\ACML\\datasets\\db1_apnea-ecg\\a01'
    #db_path = 'datasets\\db3_ucddb\\ucddb002'

    record = wfdb.rdrecord(db_path)

    signals, fields = wfdb.rdsamp(db_path)
    print('Fields: ', fields)

    ann = wfdb.rdann(db_path, 'apn')
    print("ann.sample len = ", len(ann.sample))
    print(ann.sample)

    print("ann.symbol len = ", len(ann.symbol))
    print(ann.symbol)

    wfdb.plot_wfdb(record=record, annotation=ann, plot_sym=True, title='Record a1 from Physionet Apnea')
    return
Example #22
 def remove_non_beat(self, sample_name, rule_based):
     if rule_based:
         self.BEAT_ANN.extend(['[', '!', ']', '(BII\x00'])
     annotation = wfdb.rdann(sample_name, "atr")
     beat_ann = list()
     beat_sym = list()
     samples = annotation.sample
     symbols = annotation.symbol
     for j in range(len(annotation.sample)):
         if symbols[j] == '+' and rule_based:
             symbols[j] = annotation.aux_note[j]
         if symbols[j] in self.BEAT_ANN:
             symbol = symbols[j]
             peak = samples[j]
             beat_ann.append(peak)
             beat_sym.append(symbol)
     assert len(beat_ann) == len(beat_sym)
     return beat_ann, beat_sym
Example #23
def getdata(datapath):
    pac = ['a', 'J', 'A', 'S']
    pvc = ['V', 'r']
    mitdb = []
    mitdb_a = []
    mitdb_v = []
    sample = wfdb.rdsamp(datapath)
    ln = (sample[0].shape[0]) // 360
    s = []
    s.append(resample(sample[0][:ln * 360, 0], ln * 400).astype(np.float16))  #
    s.append(resample(sample[0][:ln * 360, 1], ln * 400).astype(np.float16))  #
    ann = wfdb.rdann(datapath, 'atr')  # read annotations for the same record
    beats_a = ann.sample[np.isin(ann.symbol, pac)] * 400 // 360
    beats_v = ann.sample[np.isin(ann.symbol, pvc)] * 400 // 360
    mitdb.append(s)
    mitdb_a.append(beats_a)
    mitdb_v.append(beats_v)
    return mitdb, mitdb_a, mitdb_v
Example #24
 def __init__(self, num_ECG, dataBase='MIT', ti=0, tf=20000  ):
     
     self.__dataBase__ = dataBase
     
     if dataBase == 'MIT' :
         signals, fields = wfdb.io.rdsamp( num_ECG, pb_dir='mitdb', sampfrom = ti, sampto = tf)
         ann = wfdb.rdann(     num_ECG, pb_dir='mitdb', sampfrom = ti, sampto = tf, extension = 'atr'   )
         
         self.signal = signals[:,0]
         self.qrs    = ann.sample
         self.fs     = fields['fs']
         self.len    = len(self.signal)
         self.time   = np.arange( ti, tf, 1 ) / self.fs
         self.dfsignal  = pd.DataFrame({     'Signal' : self.signal,
                                             'Time'   : self.time
                                             })
     else:
         print("ERROR: Incompatible format")
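A hypothetical instantiation (the class name `ECG` is an assumption, since the enclosing class is not shown), streaming the first 5000 samples of record '100' from PhysioNet's mitdb:

ecg = ECG('100', dataBase='MIT', ti=0, tf=5000)  # class name assumed
print(ecg.fs, ecg.len, len(ecg.qrs))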
Example #25
    def single_classifier_test(self, detector, tolerance=0):
        max_delay_in_samples = 350 / 5
        dat_files = []
        for file in os.listdir(self.mitdb_dir):
            if file.endswith(".dat"):
                dat_files.append(file)

        mit_records = [w.replace(".dat", "") for w in dat_files]

        results = np.zeros((len(mit_records), 5), dtype=int)

        i = 0
        for record in mit_records:
            progress = int(i / float(len(mit_records)) * 100.0)
            print("MITDB progress: %i%%" % progress)

            sig, fields = wfdb.rdsamp(self.mitdb_dir + '/' + record)
            unfiltered_ecg = sig[:, 0]

            ann = wfdb.rdann(str(self.mitdb_dir + '/' + record), 'atr')
            anno = _tester_utils.sort_MIT_annotations(ann)

            r_peaks = detector(unfiltered_ecg)

            delay = _tester_utils.calcMedianDelay(r_peaks, unfiltered_ecg,
                                                  max_delay_in_samples)

            if delay > 1:

                TP, FP, FN = _tester_utils.evaluate_detector(r_peaks,
                                                             anno,
                                                             delay,
                                                             tol=tolerance)
                TN = len(unfiltered_ecg) - (TP + FP + FN)

                results[i, 0] = int(record)
                results[i, 1] = TP
                results[i, 2] = FP
                results[i, 3] = FN
                results[i, 4] = TN

            i = i + 1

        return results
Example #26
 def show_table(self):
      # Number of rows in the table
     self.timer.stop()
     self.bottom_layout.setCurrentIndex(4)
     rows = self.patient
     for row in range(0, rows):
         item = QTableWidgetItem(str(100 + row))
         self.patient_table.setItem(row, 0, item)
         head = wfdb.rdheader('MIT-BIH/mit-bih-database/' + str(100 + row))
         age, gender, _, _, _ = head.comments[0].split(" ")
         item = QTableWidgetItem(str(age))
         self.patient_table.setItem(row, 1, item)
         item = QTableWidgetItem(str(gender))
         self.patient_table.setItem(row, 2, item)
         drugs = head.comments[1]
         item = QTableWidgetItem(str(drugs))
         self.patient_table.setItem(row, 3, item)
         record = wfdb.rdann('MIT-BIH/mit-bih-database/' + str(100 + row),
                             "atr",
                             sampfrom=0,
                             sampto=650000)
         A, V, F, R, L = 0, 0, 0, 0, 0
         for index in record.symbol:
             if index == 'A':
                 A += 1
             if index == "V":
                 V += 1
             if index == "F":
                 F += 1
             if index == "R":
                 R += 1
             if index == "L":
                 L += 1
         item = QTableWidgetItem(str(A))
         self.patient_table.setItem(row, 4, item)
         item = QTableWidgetItem(str(V))
         self.patient_table.setItem(row, 5, item)
         item = QTableWidgetItem(str(F))
         self.patient_table.setItem(row, 6, item)
         item = QTableWidgetItem(str(R))
         self.patient_table.setItem(row, 7, item)
         item = QTableWidgetItem(str(L))
         self.patient_table.setItem(row, 8, item)
         self.patient_table.resizeColumnsToContents()
Example #27
def read_file(file, participant):
    """Utility function
    """
    # Get signal
    data = pd.DataFrame({"ECG": wfdb.rdsamp(file[:-4])[0][:, 0]})
    data["Participant"] = "MIT-Arrhythmia_%.2i" %(participant)
    data["Sample"] = range(len(data))
    data["Sampling_Rate"] = 360
    data["Database"] = "MIT-Arrhythmia-x" if "x_mitdb" in file else "MIT-Arrhythmia"

    # getting annotations
    anno = wfdb.rdann(file[:-4], 'atr')
    anno = np.unique(anno.sample[np.in1d(anno.symbol, ['N', 'L', 'R', 'B', 'A', 'a', 'J', 'S', 'V', 'r', 'F', 'e', 'j', 'n', 'E', '/', 'f', 'Q', '?'])])
    anno = pd.DataFrame({"Rpeaks": anno})
    anno["Participant"] = "MIT-Arrhythmia_%.2i" %(participant)
    anno["Sampling_Rate"] = 360
    anno["Database"] = "MIT-Arrhythmia-x" if "x_mitdb" in file else "MIT-Arrhythmia"

    return data, anno
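A hypothetical call, assuming a local copy of the MIT-BIH Arrhythmia Database:

data, anno = read_file('mitdb/100.dat', participant=0)
print(data.shape, anno.shape)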
Example #28
def segmentation(records,beat):
    Normal = []
    for e in records:
        signals, fields = wfdb.rdsamp(e)
        ann = wfdb.rdann(e, 'atr')
        good = [beat]
        ids = np.in1d(ann.symbol, good)
        imp_beats = ann.sample[ids]
        beats = (ann.sample)
        for i in imp_beats:
            beats = list(beats)
            j = beats.index(i)
            if(j!=0 and j!=(len(beats)-1)):
                x = beats[j-1]
                y = beats[j+1]
                diff1 = abs(x - beats[j])//2
                diff2 = abs(y - beats[j])//2
                Normal.append(signals[beats[j] - diff1: beats[j] + diff2, 0])
    return Normal
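A hypothetical call collecting all normal ('N') beats from two local records:

normal_beats = segmentation(['mitdb/100', 'mitdb/101'], 'N')
print(len(normal_beats))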
Example #29
    def read_annot(self):

        annotation = wfdb.rdann(self.path.split(".")[0], 'atr')
        arryth = [None for i in range(650000)]
        beat = [None for i in range(650000)]
        self.n_ann = len(annotation.sample)
        self.inds = annotation.sample
        self.arrhyt = annotation.aux_note
        self.beat = annotation.symbol

        c = 0
        n = []

        for i in self.arrhyt:
            if i != "":
                n.append(i)
            else:
                n.append(0)
        self.arrhyt = n
Example #30
def create_dataset_from_records(records):
    norm = [] # stores normal heartbeat data
    pvc  = [] # stores PVC heartbeat data
    lbbb = [] # stores LBBB heartbeat data
    rbbb = [] # stores RBBB heartbeat data

    for record in records:
        signals, fields = wfdb.rdsamp('mitdb/' + str(record)) # read record signals
        annotations     = wfdb.rdann('mitdb/' + str(record), 'atr') # read record annotations
        window_size     = 90 # size of window for each sample

        MLII = [sig[0] for sig in signals] # using MLII data

        # slice off data that doesn't fit in the window
        start = 0
        while annotations.sample[start] < window_size:
            start += 1

        end = 0
        while fields['sig_len'] - annotations.sample[end - 1] < window_size:
            end -= 1

        annos = zip(annotations.sample[start : end], annotations.symbol[start : end])

        # extracts normal and pvc data
        for sample, symbol in annos:
            if symbol == 'N':
                norm.append(MLII[sample - window_size : sample + window_size])
            elif symbol == 'V':
                pvc.append(MLII[sample - window_size : sample + window_size])
            elif symbol == 'L':
                lbbb.append(MLII[sample - window_size : sample + window_size])
            elif symbol == 'R':
                rbbb.append(MLII[sample - window_size : sample + window_size])

    # write to HDF5
    min_len = len(rbbb)
    with h5py.File('mitdb.hdf5', 'w') as f:
        f.create_dataset('normal', data=np.array(norm)[:min_len])
        f.create_dataset('pvc', data=np.array(pvc)[:min_len])
        f.create_dataset('lbbb', data=np.array(lbbb)[:min_len])
        f.create_dataset('rbbb', data=np.array(rbbb))
    return
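A hypothetical call, assuming the listed records are available under a local mitdb/ directory:

create_dataset_from_records([100, 106, 118, 124])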
Example #31
def read_r_peak(filename, sampfrom=None, sampto=None):
    annotation = wfdb.rdann(filename, 'atr', sampfrom=sampfrom, sampto=sampto)
    sample = annotation.sample
    symbol = annotation.symbol
    # Remove annotations that are not R-wave labels
    AAMI_MIT_MAP = {
        'N': 'Nfe/jnBLR',  # group the 19 R-peak-related beat types into five major classes
        'S': 'SAJa',
        'V': 'VEr',
        'F': 'F',
        'Q': 'Q?'
    }
    MIT2AAMI = {c: k for k in AAMI_MIT_MAP.keys() for c in AAMI_MIT_MAP[k]}
    mit_beat_codes = list(MIT2AAMI.keys())
    symbol = np.array(symbol)
    print(symbol)
    isin = np.isin(symbol, mit_beat_codes)
    sample = sample[isin]
    return sample
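A hypothetical call returning R-peak sample indices for the first 10 seconds of record 100 (fs = 360 Hz assumed):

r_peaks = read_r_peak('mitdb/100', sampfrom=0, sampto=3600)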
Example #32
def splitup_signal_from_file(file_path, splice_size=None):
    ann_file_name = os.path.join(file_path)
    ann_file_exists = os.path.exists(ann_file_name +
                                     '.atr') and os.path.isfile(ann_file_name +
                                                                '.atr')
    rec_file_name = os.path.join(file_path)
    rec_file_exists = os.path.exists(ann_file_name +
                                     '.dat') and os.path.isfile(ann_file_name +
                                                                '.dat')
    if ann_file_exists and rec_file_exists:
        annotations = wfdb.rdann(ann_file_name, extension='atr')
        sample, meta = wfdb.rdsamp(rec_file_name, channels=[0])
        meta['file_name'] = file_path.rsplit(os.path.sep, 1)[1]
    else:
        return None, None, None
    if splice_size:
        return meta, splice_per_beat_type(sample, annotations, splice_size)
    else:
        return meta, splitup_signal_by_beat_type(sample, annotations)
Example #33
def show_annotations(path):
    """ Exemplary code """
    record = wf.rdsamp(path)
    annotation = wf.rdann(path, 'atr')

    # Get data and annotations for the first 2000 samples
    howmany = 2000
    channel = record.p_signals[:howmany, 0]

    # Extract all of the annotation-related information
    where = annotation.annsamp < howmany
    samp = annotation.annsamp[where]

    # Convert to numpy.array to get fancy indexing access
    types = np.array(annotation.anntype)
    types = types[where]

    times = np.arange(howmany, dtype = 'float') / record.fs
    plt.plot(times, channel)

    # Prepare qrs information for the plot
    qrs_times = times[samp]

    # Scale to show markers at the top 
    qrs_values = np.ones_like(qrs_times)
    qrs_values *= channel.max() * 1.4

    plt.plot(qrs_times, qrs_values, 'ro')

    # Also show annotation code
    # And their words
    for it, sam in enumerate(samp):
        # Get the annotation position
        xa = times[sam]
        ya = channel.max() * 1.1

        # Use just the first letter 
        a_txt = types[it]
        plt.annotate(a_txt, xy = (xa, ya))

    plt.xlim([0, 4])
    plt.xlabel('Time [s]')
    plt.show()
Example #34
File: patient.py  Project: fpp123/sim_gan
    def get_annotations(self):
        """Get signal annotation using the wfdb package.

        :return:
        """
        ann = wfdb.rdann(
            self.patient_number,
            'atr',
            pb_dir='mitdb',
            return_label_elements=['symbol', 'label_store', 'description'],
            summarize_labels=True)

        mit_bih_labels_str = ann.symbol

        labels_locations = ann.sample

        labels_description = ann.description

        return mit_bih_labels_str, labels_locations, labels_description
Example #35
def segmentation(records):
    # Normal = []
    # Abnormal = []

    mainPatient = []
    otherPatients = []

    isMainPatient = True

    for e in records:
        signals, fields = wfdb.rdsamp(e, channels=[0])
        ann = wfdb.rdann(e, 'atr')
        all_beats = ann.sample[:]
        beats = ann.sample

        if isMainPatient:
            for i in all_beats:
                beats = list(beats)
                j = beats.index(i)
                if j != 0 and j != (len(beats) - 1):
                    # print(j-1)
                    x = beats[j - 1]
                    y = beats[j + 1]
                    diff1 = abs(x - beats[j]) // 2
                    diff2 = abs(y - beats[j]) // 2
                    a = signals[beats[j] - diff1: beats[j] + diff2, 0]
                    for k in a:
                        mainPatient.append(k)
            isMainPatient = False
        else:
            for i in all_beats:
                beats = list(beats)
                j = beats.index(i)
                if j != 0 and j != (len(beats) - 1):
                    # print(j-1)
                    x = beats[j - 1]
                    y = beats[j + 1]
                    diff1 = abs(x - beats[j]) // 2
                    diff2 = abs(y - beats[j]) // 2
                    a = signals[beats[j] - diff1: beats[j] + diff2, 0]
                    for k in a:
                        otherPatients.append(k)
    return np.array(mainPatient, dtype=np.float), np.array(otherPatients, dtype=np.float)
Example #36
def extract_R(dataset, span=100, sampfrom=0, sampto=None):
    # Get data and classifications from record
    record = wfdb.rdrecord(('data_raw/' + dataset), sampfrom, sampto)
    R_class = wfdb.rdann(('data_raw/' + dataset), 'atr', sampfrom, sampto)
    pd.DataFrame(record.p_signal[:, 0]).to_csv("temp.csv")
    R_error = 0
    sampto = len(record.p_signal[:, 0])

    # Detect QRS
    qrs = QRSDetectorOffline(ecg_data_path="temp.csv",
                             verbose=True,
                             log_data=False,
                             plot_data=False,
                             show_plot=False)
    peaks = qrs.detected_peaks_indices
    signal = np.zeros((span, len(peaks)))
    last_idx = 0

    # Check if all QRS are detected - avoid desynchronization
    if len(peaks) < len(R_class.symbol):
        R_error = 1

    # Move QRS data to array - one element in array contains centered QRS with span defined in arg
    for peak in range(0, len(peaks)):
        if (peaks[peak] + round(span / 2) <= (sampto - sampfrom)):
            last_idx = peak
            for i in range(0, span):
                signal[i, peak] = record.p_signal[(peaks[peak] -
                                                   round(span / 2) + i), 0]

    # Check if the signal has desired length - avoid QRS in start/end of the frame (i.e when QRS is centered at t = 499 and frame has T = 500)
    R_signal = np.zeros((span, last_idx + 1))
    for peak in range(0, last_idx + 1):
        for i in range(0, span):
            R_signal[i, peak] = signal[i, peak]

    # Get classification of each QRS
    R_reference = np.empty((last_idx + 1), dtype='object')
    for i in range(0, last_idx + 1):
        R_reference[i] = R_class.symbol[i]

    # Return classifications and signals
    return (R_error, R_reference, R_signal)
Example #37
 def sigcontainer(self, annotators: Iterable[str] = None) -> SigContainer:
     c = SigContainer.from_signal_array(signals=np.transpose(self.record.p_signal),
                                        channels=self.record.sig_name,
                                        units=self.record.units, fs=self.record.fs)
     if annotators is not None:
         with h5py.File("physionet_cache.h5", "a") as store:
             for annotator in annotators:
                 if annotator not in self.annotations:
                     annopath = f"{self.path}/{annotator}"
                     if annopath not in store:
                         self.annotations[annotator] = wfdb.rdann(self.name, annotator,
                                                                  pb_dir=self.database)
                         store.create_dataset(annopath, data=dumpa(self.annotations[annotator]),
                                              compression="gzip")
                     else:
                         self.annotations[annotator] = pickle.loads(store[annopath][:])
                 data = self.annotations[annotator]
                 c.add_annotation(annotator, data.sample, data.symbol, data.aux_note)
     return c
Example #38
def show_annotations(path):
    """ Exemplary code """
    record = wf.rdsamp(path)
    annotation = wf.rdann(path, 'atr')

    # Get data and annotations for the first 2000 samples
    howmany = 2000
    channel = record.p_signals[:howmany, 0]

    # Extract all of the annotation-related information
    where = annotation.annsamp < howmany
    samp = annotation.annsamp[where]

    # Convert to numpy.array to get fancy indexing access
    types = np.array(annotation.anntype)
    types = types[where]

    times = np.arange(howmany, dtype='float') / record.fs
    plt.plot(times, channel)

    # Prepare qrs information for the plot
    qrs_times = times[samp]

    # Scale to show markers at the top
    qrs_values = np.ones_like(qrs_times)
    qrs_values *= channel.max() * 1.4

    plt.plot(qrs_times, qrs_values, 'ro')

    # Also show annotation code
    # And their words
    for it, sam in enumerate(samp):
        # Get the annotation position
        xa = times[sam]
        ya = channel.max() * 1.1

        # Use just the first letter
        a_txt = types[it]
        plt.annotate(a_txt, xy=(xa, ya))

    plt.xlim([0, 4])
    plt.xlabel('Time [s]')
    plt.show()
Example #39
def create_img_from_sign_filtered(size=(128, 128), size_paa=100, augmentation=True):
    """
        For each beat of each patient, create an image and apply some filters.
        :param size: the image size
        :param size_paa: number of points kept by the piecewise aggregate approximation
        :param augmentation: if True, create another nine images for each side of each image
    """
    if not os.path.exists(_directory):
        os.makedirs(_directory)

    files = [f[:-4] for f in listdir(_directory) if isfile(join(_directory, f)) if (f.find('.dat') != -1)]

    for file in files:
        sig, _ = wfdb.rdsamp(_directory + file)
        ann = wfdb.rdann(_directory + file, extension='atr')
        for i in tqdm.tqdm(range(1, len(ann.sample) - 1)):

            if ann.symbol[i] not in lb.original_labels:
                continue
            label = lb.original_labels[ann.symbol[i]]
            ''' Get the Q-peak intervall '''
            start = ann.sample[i - 1] + _range_to_ignore
            end = ann.sample[i + 1] - _range_to_ignore

            signal = [sig[i][0] for i in range(start, end)]
            paa = piecewise_aggregate_approximation(signal, size_paa)
            plot_x = paa
            plot_y = [i for i in range(len(paa))]
            ''' Plot and save the beat'''
            fig = plt.figure(frameon=False)
            plt.plot(plot_y, plot_x)
            plt.xticks([]), plt.yticks([])
            for spine in plt.gca().spines.values():
                spine.set_visible(False)

            filename = '{}{}_{}{}{}0.png'.format(_dataset_dir, label, file[-3:], start, end)
            fig.savefig(filename, bbox_inches='tight')
            im_gray = cv2.imread(filename, cv2.IMREAD_GRAYSCALE)
            im_gray = cv2.resize(im_gray, size, interpolation=cv2.INTER_LANCZOS4)
            cv2.imwrite(filename, im_gray)
            plt.cla()
            plt.clf()
            plt.close('all')
Example #40
    def load_rpeak_indices(self, rec:str, sampfrom:Optional[int]=None, sampto:Optional[int]=None, use_manual:bool=True, keep_original:bool=False) -> np.ndarray:
        """ finished, checked,

        load rpeak indices, or equivalently qrs complex locations,
        which are stored in the `symbol` attribute of corresponding annotation files,
        regardless of their beat types,
        
        Parameters
        ----------
        rec: str,
            name of the record
        sampfrom: int, optional,
            start index of the annotations to be loaded
        sampto: int, optional,
            end index of the annotations to be loaded
        use_manual: bool, default True,
            use manually annotated beat annotations (qrs),
            instead of those generated by algorithms
        keep_original: bool, default False,
            if True, indices will keep the same with the annotation file
            otherwise subtract `sampfrom` if specified
        
        Returns
        -------
        ann, ndarray,
            locations (indices) of all the rpeaks (qrs complexes)
        """
        fp = os.path.join(self.db_dir, rec)
        if use_manual:
            ext = self.manual_ann_ext
        else:
            ext = self.auto_ann_ext
        wfdb_ann = wfdb.rdann(
            fp,
            extension=ext,
            sampfrom=sampfrom or 0,
            sampto=sampto,
        )
        rpeak_inds = wfdb_ann.sample[np.isin(wfdb_ann.symbol, self.all_beat_types)]
        if not keep_original and sampfrom is not None:
            rpeak_inds = rpeak_inds - sampfrom
        return rpeak_inds
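A hypothetical call, assuming `reader` is an instance of the enclosing database class:

rpeaks = reader.load_rpeak_indices('100', sampfrom=0, sampto=10000, use_manual=True)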
Example #41
    def test_correct_peaks(self):
        sig, fields = wfdb.rdsamp('sample-data/100')
        ann = wfdb.rdann('sample-data/100', 'atr')
        fs = fields['fs']
        min_bpm = 10
        max_bpm = 350
        min_gap = fs*60/min_bpm
        max_gap = fs * 60 / max_bpm

        y_idxs = processing.correct_peaks(sig=sig[:,0], peak_inds=ann.sample,
                                          search_radius=int(max_gap),
                                          smooth_window_size=150)

        yz = np.zeros(sig.shape[0])
        yz[y_idxs] = 1
        yz = np.where(yz[:10000]==1)[0]

        expected_peaks = [77, 370, 663, 947, 1231, 1515, 1809, 2045, 2403,
                          2706, 2998, 3283, 3560, 3863, 4171, 4466, 4765, 5061,
                          5347, 5634, 5919, 6215, 6527, 6824, 7106, 7393, 7670,
                          7953, 8246, 8539, 8837, 9142, 9432, 9710, 9998]

        assert np.array_equal(yz, expected_peaks)
Example #42
    def test_1(self):
        """
        Target file created with:
            rdann -r sample-data/100 -a atr > ann-1
        """
        annotation = wfdb.rdann('sample-data/100', 'atr')


        # This is not the fault of the script. The annotation file specifies a
        # length-3 aux_note field with a null written after '(N', which the script
        # correctly picks up. The null is removed here so that the field can be
        # compared with the regexp output below, which has no null to detect in
        # the text output file of rdann.
        annotation.aux_note[0] = '(N'

        # Target data from WFDB software package
        lines = tuple(open('tests/target-output/ann-1', 'r'))
        nannot = len(lines)

        target_time = [None] * nannot
        target_sample = np.empty(nannot, dtype='object')
        target_symbol = [None] * nannot
        target_subtype = np.empty(nannot, dtype='object')
        target_chan = np.empty(nannot, dtype='object')
        target_num = np.empty(nannot, dtype='object')
        target_aux_note = [None] * nannot

        RXannot = re.compile(
            '[ \t]*(?P<time>[\[\]\w\.:]+) +(?P<sample>\d+) +(?P<symbol>.) +(?P<subtype>\d+) +(?P<chan>\d+) +(?P<num>\d+)\t?(?P<aux_note>.*)')

        for i in range(0, nannot):
            target_time[i], target_sample[i], target_symbol[i], target_subtype[i], target_chan[
                i], target_num[i], target_aux_note[i] = RXannot.findall(lines[i])[0]

        # Convert objects into integers
        target_sample = target_sample.astype('int')
        target_num = target_num.astype('int')
        target_subtype = target_subtype.astype('int')
        target_chan = target_chan.astype('int')

        # Compare
        comp = [np.array_equal(annotation.sample, target_sample),
                np.array_equal(annotation.symbol, target_symbol),
                np.array_equal(annotation.subtype, target_subtype),
                np.array_equal(annotation.chan, target_chan),
                np.array_equal(annotation.num, target_num),
                annotation.aux_note == target_aux_note]

        # Test file streaming
        pbannotation = wfdb.rdann('100', 'atr', pb_dir='mitdb', return_label_elements=['label_store', 'symbol'])
        pbannotation.aux_note[0] = '(N'
        pbannotation.create_label_map()

        # Test file writing
        annotation.wrann(write_fs=True)
        writeannotation = wfdb.rdann('100', 'atr', return_label_elements=['label_store', 'symbol'])
        writeannotation.create_label_map()

        assert (comp == [True] * 6)
        assert annotation.__eq__(pbannotation)
        assert annotation.__eq__(writeannotation)