コード例 #1
0
ファイル: TEST_DPI.py プロジェクト: Jessicarryly/ECG_QRS_DPI
def Test1():
    '''Plot, per QTdb record, the expert R labels reported by GetFN.

    For every record returned by ``QTloader.getreclist`` the 'R' entries of
    the expert annotation are compared with the output of
    ``DPI.QRS_Detection`` through ``GetFN`` (the original comment calls its
    result the false negatives). Records without expert 'R' labels, or with
    an empty ``GetFN`` result, are skipped; the others are shown in a
    blocking matplotlib window.
    '''
    loader = QTloader()
    record_names = loader.getreclist()

    for rec_ind, rec_name in enumerate(record_names):
        # Parenthesised form prints the same text under Python 2 and 3.
        print('Processing record[%d] %s ...' % (rec_ind, rec_name))
        ecg = loader.load(rec_name)['sig']
        annotations = loader.getExpert(rec_name)
        # Each annotation is indexed as (position, label, ...).
        expert_r = [item[0] for item in annotations if item[1] == 'R']

        # Nothing to compare against when the expert marked no R peak.
        if not expert_r:
            continue

        detector = DPI()
        detected_r = detector.QRS_Detection(ecg)

        # Expert positions selected by GetFN.
        missed_r = GetFN(expert_r, detected_r)
        if len(missed_r) == 0:
            continue

        plt.plot(ecg)
        plt.plot(detected_r, [ecg[pos] for pos in detected_r],
                 'ro', markersize=12)
        plt.plot(missed_r, [ecg[pos] for pos in missed_r],
                 'ys', markersize=14)
        plt.show()
コード例 #2
0
def Test_hog1d():
    '''Train the HOG-feature model on QTdb and plot detections on one record.

    Steps, in order:
      1. Load QTdb record index 103 and keep its ``'sig'`` channel.
      2. Train ``HogFeatureExtractor`` (target label 'T') on the first 100
         records and pickle ``hog.gbdt`` to ``./hog.mdl``.
      3. Detect QRS positions with ``DPI`` and pass them, tagged 'R', to
         ``hog.Testing``.
      4. Drop detected positions that are not valid sample indexes and plot
         the remainder over ``sig['sig']`` and ``sig['sig2']``.

    To reuse a saved model instead of retraining, unpickle ``./hog.mdl`` and
    hand it to ``hog.LoadModel``.
    '''
    target_label = 'T'
    qt = QTloader()
    rec_ind = 103
    reclist = qt.getreclist()
    sig = qt.load(reclist[rec_ind])
    raw_sig = sig['sig']
    len_sig = len(raw_sig)

    # Hog feature tester
    hog = HogFeatureExtractor(target_label=target_label)
    training_list = reclist[0:100]
    hog.Train(training_list)

    # Save the trained model
    with open('./hog.mdl', 'wb') as fout:
        pickle.dump(hog.gbdt, fout)

    dpi = DPI()
    qrs_list = dpi.QRS_Detection(raw_sig)

    start_time = time.time()
    # Explicit list of (position, 'R') pairs: same value zip() yields on
    # Python 2, and still a real list on Python 3.
    r_labels = [(pos, 'R') for pos in qrs_list]
    detected_poslist = hog.Testing(raw_sig, r_labels)
    print('Testing time: %d s.' % (time.time() - start_time))

    # Keep only positions that can index the signal. The upper bound is
    # exclusive: a position equal to len_sig would raise IndexError below.
    # Filtering the whole list also covers results that are not sorted.
    detected_poslist = [
        pos for pos in detected_poslist if 0 <= pos < len_sig
    ]

    sig = qt.load(reclist[rec_ind])
    raw_sig = sig['sig']
    plt.plot(raw_sig, label='raw signal D1')
    plt.plot(sig['sig2'], label='raw signal D2')
    amp_list = [raw_sig[int(x)] for x in detected_poslist]
    plt.plot(detected_poslist,
             amp_list,
             'ro',
             markersize=12,
             label=target_label)
    plt.title('Record name %s' % reclist[rec_ind])
    plt.legend()
    plt.show()
コード例 #3
0
ファイル: TEST_DPI.py プロジェクト: Jessicarryly/ECG_QRS_DPI
def TestMit():
    '''Plot, per MITdb record, the expert marks reported by GetFN.

    Iterates the record IDs from index 3 onwards. For each record the
    rounded ``mit.markpos`` positions are compared with the output of
    ``DPI.QRS_Detection`` (called with ``fs=360``) through ``GetFN`` (the
    original comment calls its result the false negatives). Records with no
    marks, or with an empty ``GetFN`` result, are skipped; the rest are
    shown in a blocking matplotlib window.
    '''
    loader = MITdbLoader()
    record_ids = loader.getRecIDList()
    # Parenthesised form prints the same text under Python 2 and 3.
    print(dir(loader))

    for rec_ind in xrange(3, len(record_ids)):
        rec_id = record_ids[rec_ind]
        print('Processing record[%d] %s ...' % (rec_ind, rec_id))

        ecg = loader.load(rec_id)
        # markpos is read after load(); presumably load() refreshes it for
        # the current record -- TODO confirm against MITdbLoader.
        expert_r = [int(round(pos)) for pos in loader.markpos]

        # Nothing to compare against when the record carries no marks.
        if not expert_r:
            continue

        dpi_options = {'time_cost': 75410}
        detector = DPI(debug_info=dpi_options)
        detected_r = detector.QRS_Detection(ecg, fs=360)

        # Expert positions selected by GetFN.
        missed_r = GetFN(expert_r, detected_r)
        if len(missed_r) == 0:
            continue

        plt.plot(ecg)
        detected_amp = [ecg[pos] for pos in detected_r]
        plt.plot(detected_r, detected_amp, 'ro', markersize=12,
                 label='Detected R with DPI')
        missed_amp = [ecg[pos] for pos in missed_r]
        plt.plot(missed_r, missed_amp, 'ys', markersize=14,
                 label='Expert labels')
        plt.legend()
        plt.title('Record %s' % rec_id)
        plt.show()
コード例 #4
0
def Test_regression():
    '''Train the regression learner on QTdb and plot detections on one record.

    Steps, in order:
      1. Load QTdb record index 65 and keep its ``'sig'`` channel.
      2. Train ``RegressionLearner`` (target label 'T') on every record in
         the list and pickle ``rf.mdl`` to ``./tmp.mdl``.
      3. Detect QRS positions with ``DPI`` and pass them, tagged 'R', to
         ``rf.testing``.
      4. Drop detected positions that are not valid sample indexes and plot
         the remainder over the signal.

    To reuse a saved model instead of retraining, unpickle ``./tmp.mdl`` and
    hand it to ``rf.LoadModel``.
    '''
    target_label = 'T'
    qt = QTloader()
    rec_ind = 65
    reclist = qt.getreclist()
    sig = qt.load(reclist[rec_ind])
    raw_sig = sig['sig']
    len_sig = len(raw_sig)

    rf = RegressionLearner(target_label=target_label)
    # NOTE(review): the training set includes the record plotted below.
    rf.TrainQtRecords(reclist[0:])

    # Save the trained model
    with open('./tmp.mdl', 'wb') as fout:
        pickle.dump(rf.mdl, fout)

    dpi = DPI()
    qrs_list = dpi.QRS_Detection(raw_sig)

    start_time = time.time()
    # Explicit list of (position, 'R') pairs: same value zip() yields on
    # Python 2, and still a real list on Python 3.
    r_labels = [(pos, 'R') for pos in qrs_list]
    detected_poslist = rf.testing(raw_sig, r_labels)
    print('Testing time: %d s.' % (time.time() - start_time))

    # Keep only positions that can index the signal. The upper bound is
    # exclusive (index len_sig raises IndexError below) and negatives are
    # dropped too, since they would silently wrap around to the signal end.
    detected_poslist = [
        pos for pos in detected_poslist if 0 <= pos < len_sig
    ]

    sig = qt.load(reclist[rec_ind])
    raw_sig = sig['sig']
    plt.plot(raw_sig, label='raw signal')
    amp_list = [raw_sig[int(x)] for x in detected_poslist]
    plt.plot(detected_poslist,
             amp_list,
             'ro',
             markersize=12,
             label=target_label)
    plt.legend()
    plt.show()
コード例 #5
0
    def testing(self, raw_sig, fs=250.0):
        '''Testing API.

        Detects R positions with DPI, then runs the per-label models stored
        in ``self.model_dict`` in a fixed order ('T', 'Toffset', 'Ponset',
        'P', 'Poffset'). Each model receives the signal together with all
        results accumulated so far.

        Args:
            raw_sig: Signal samples; a list, a numpy array, or any other
                object supporting ``len()``.
            fs: Sampling frequency forwarded to ``DPI.QRS_Detection``.

        Returns:
            A list of (index, label) pairs. For example:
            [(1, 'R'), (24, 'T'), (35, 'Toffset'),]
        '''
        # Fall back to len() for non-ndarray inputs so len_sig is always
        # bound (it used to be undefined for e.g. tuples).
        if isinstance(raw_sig, np.ndarray):
            len_sig = raw_sig.size
        else:
            len_sig = len(raw_sig)

        detected_results = list()

        # Detect R first
        debug_info = dict()
        debug_info['time_cost'] = True
        dpi = DPI(debug_info=debug_info)
        qrs_list = dpi.QRS_Detection(raw_sig, fs=fs)
        detected_results.extend((pos, 'R') for pos in qrs_list)

        for target_label in ['T', 'Toffset', 'Ponset', 'P', 'Poffset']:
            start_time = time.time()
            detected_poslist = self.model_dict[target_label].Testing(
                raw_sig, detected_results)
            print('Testing time: %d s.' % (time.time() - start_time))

            # Keep only positions inside the signal. The upper bound is
            # exclusive because len_sig itself is not a valid sample index.
            detected_results.extend(
                (pos, target_label) for pos in detected_poslist
                if 0 <= pos < len_sig)
        return detected_results
コード例 #6
0
ファイル: TEST_DPI.py プロジェクト: Jessicarryly/ECG_QRS_DPI
def TestZN():
    '''Run DPI QRS detection on lead 'II' of ./data.pkl and plot the result.

    The pickled object is converted with ``tolist()`` and indexed by 'II';
    the squeezed samples are divided by 100.0 before detection at
    ``fs=500.0``. Detected positions are drawn over the scaled signal in a
    blocking matplotlib window.
    '''
    # NOTE: pickle executes arbitrary code on load; only use trusted files.
    with open('./data.pkl', 'rb') as fin:
        payload = pickle.load(fin)

    # presumably a 0-d object array wrapping a dict of leads -- TODO confirm
    leads = payload.tolist()
    lead_ii = np.squeeze(leads['II'])
    ecg = [sample / 100.0 for sample in lead_ii]

    dpi_options = {'time_cost': 75410, 'decision_plot': 3517}
    detector = DPI(debug_info=dpi_options)
    detected_r = detector.QRS_Detection(ecg, fs=500.0)

    plt.plot(ecg)
    detected_amp = [ecg[pos] for pos in detected_r]
    plt.plot(detected_r, detected_amp, 'ro', markersize=12,
             label='Detected R with DPI')
    plt.legend()
    plt.show()
コード例 #7
0
def Test_Mit():
    '''Apply a saved HOG-feature model to one MITdb record and plot it.

    Steps, in order:
      1. Build ``HogFeatureExtractor`` (target label 'T') and load the model
         pickled in ``./hog.mdl`` through ``hog.LoadModel``.
      2. Load the record at index 9 of the MITdb record list and resample it
         to ``len * 250 / 360`` samples.
      3. Detect QRS positions with ``DPI`` and pass them, tagged 'R', to
         ``hog.Testing``.
      4. Drop detected positions that are not valid sample indexes and plot
         the remainder over the resampled signal and ``mit.sigd2``.

    To retrain instead of loading, call ``hog.Train`` with a QTdb record
    list and pickle ``hog.gbdt`` to ``./hog.mdl``.
    '''
    target_label = 'T'

    # Hog feature tester
    hog = HogFeatureExtractor(target_label=target_label)

    # Load the trained model
    # NOTE: pickle executes arbitrary code on load; only use trusted files.
    with open('./hog.mdl', 'rb') as fin:
        mdl = pickle.load(fin)
        hog.LoadModel(mdl)

    # Testing on mit database
    mit = MITdbLoader()

    # NOTE(review): TestMit lists records with getRecIDList(); confirm that
    # MITdbLoader also provides getreclist().
    rec_name = mit.getreclist()[9]
    raw_sig = mit.load(rec_name)
    # presumably converts the 360 Hz MITdb rate to 250 Hz -- TODO confirm
    resample_length = int(len(raw_sig) * 250.0 / 360.0)
    raw_sig = scipy.signal.resample(raw_sig, resample_length)
    len_sig = len(raw_sig)

    debug_info = dict()
    debug_info['time_cost'] = True
    dpi = DPI(debug_info=debug_info)
    qrs_list = dpi.QRS_Detection(raw_sig)

    start_time = time.time()
    # Explicit list of (position, 'R') pairs: same value zip() yields on
    # Python 2, and still a real list on Python 3.
    r_labels = [(pos, 'R') for pos in qrs_list]
    detected_poslist = hog.Testing(raw_sig, r_labels)
    print('Testing time: %d s.' % (time.time() - start_time))

    # Keep only positions that can index the signal. The upper bound is
    # exclusive: a position equal to len_sig would raise IndexError below.
    # Filtering the whole list also covers results that are not sorted.
    detected_poslist = [
        pos for pos in detected_poslist if 0 <= pos < len_sig
    ]

    raw_sig = mit.load(rec_name)
    raw_sig = scipy.signal.resample(raw_sig, resample_length)
    sigd2 = scipy.signal.resample(mit.sigd2, resample_length)

    plt.plot(raw_sig, label='raw signal D1')
    plt.plot(sigd2, label='raw signal D2')
    amp_list = [raw_sig[int(x)] for x in detected_poslist]
    plt.plot(detected_poslist,
             amp_list,
             'ro',
             markersize=12,
             label=target_label)
    plt.title('Record name %s' % rec_name)
    plt.legend()
    plt.show()