Ejemplo n.º 1
0
def Test_hog1d():
    '''Hog feature method test.'''
    target_label = 'T'
    qt = QTloader()
    rec_ind = 103
    reclist = qt.getreclist()
    sig = qt.load(reclist[rec_ind])
    raw_sig = sig['sig']
    len_sig = len(raw_sig)

    # Hog feature tester
    hog = HogFeatureExtractor(target_label=target_label)
    rec_list = hog.qt.getreclist()

    training_list = reclist[0:100]
    # training_rec = list(set(rec_list) - set(testing_rec))

    hog.Train(training_list)

    # Load the trained model
    # with open('./hog.mdl', 'rb') as fin:
    # mdl = pickle.load(fin)
    # hog.LoadModel(mdl)
    # Save the trained model
    with open('./hog.mdl', 'wb') as fout:
        pickle.dump(hog.gbdt, fout)

    dpi = DPI()
    qrs_list = dpi.QRS_Detection(raw_sig)

    start_time = time.time()
    detected_poslist = hog.Testing(raw_sig,
                                   zip(qrs_list, [
                                       'R',
                                   ] * len(qrs_list)))
    print 'Testing time: %d s.' % (time.time() - start_time)
    while len(detected_poslist) > 0 and detected_poslist[-1] > len_sig:
        del detected_poslist[-1]
    while len(detected_poslist) > 0 and detected_poslist[0] < 0:
        del detected_poslist[0]

    sig = qt.load(reclist[rec_ind])
    raw_sig = sig['sig']
    plt.plot(raw_sig, label='raw signal D1')
    plt.plot(sig['sig2'], label='raw signal D2')
    amp_list = [raw_sig[int(x)] for x in detected_poslist]
    plt.plot(detected_poslist,
             amp_list,
             'ro',
             markersize=12,
             label=target_label)
    plt.title('Record name %s' % reclist[rec_ind])
    plt.legend()
    plt.show()
Ejemplo n.º 2
0
def TEST_ExpertQRS():
    '''Compare SWT detail bands of record sel103 with and without QRS removal.

    Upper panel: the signal and the 4th-from-last detail band computed by
    ``SWT_NoPredictQRS`` from the expert marks.
    Lower panel: the same band computed directly with ``pywt.swt`` (db6,
    level 9) on the cropped, freshly loaded record.
    '''
    recname = 'sel103'
    qt_db = QTloader()
    signal = qt_db.load(recname)['sig']
    expert_marks = qt_db.getExpert(recname)

    swt = SWT_NoPredictQRS(signal, expert_marks)
    swt.swt()
    band = swt.cDlist[-4]

    plt.figure(1)
    # Upper panel: signal and band produced by SWT_NoPredictQRS.
    plt.subplot(211)
    plt.plot(signal)
    plt.plot(band)
    plt.grid(True)

    # Lower panel: band of the original (cropped) ECG.
    original = swt.crop_data_for_swt(swt.QTdb.load(recname)['sig'])
    detail_bands = [detail for _, detail in pywt.swt(original, 'db6', 9)]
    plt.subplot(212)
    plt.plot(original)
    plt.plot(detail_bands[-4])
    plt.grid(True)
    plt.show()
Ejemplo n.º 3
0
def TrainingModels(target_label, model_file_name, training_list):
    '''Randomly select num_training records to train, and test others.'''
    qt = QTloader()
    record_list = qt.getreclist()
    testing_list = list(set(record_list) - set(training_list))

    random_forest_config = dict(max_depth=10)
    walker = RandomWalker(target_label=target_label,
                          random_forest_config=random_forest_config,
                          random_pattern_file_name=os.path.join(
                              os.path.dirname(model_file_name),
                              'random_pattern.json'))

    start_time = time.time()
    for record_name in training_list:
        print 'Collecting features from record %s.' % record_name
        sig = qt.load(record_name)
        walker.collect_training_data(sig['sig'], qt.getExpert(record_name))
    print 'random forest start training(%s)...' % target_label
    walker.training()
    print 'trianing used %.3f seconds' % (time.time() - start_time)

    import joblib
    start_time = time.time()
    walker.save_model(model_file_name)
    print 'Serializing model time cost %f' % (time.time() - start_time)
Ejemplo n.º 4
0
def Test():
    '''Visual check of HogClass.GetRealHogArray on a 900-sample ECG segment.'''
    # Project import kept local to this test.
    from QTdata.loadQTdata import QTloader

    segment = QTloader().load('sel17152')['sig'][10000:10900]

    # Real-valued HOG with 20-sample segments and diff_step=5. debug_plot is
    # presumably handled by HogClass itself (draws its own figure).
    HogClass(segment_len=20).GetRealHogArray(segment,
                                             diff_step=5,
                                             debug_plot=True)

    plt.figure(2)
    plt.plot(np.array(segment) * 10)
    plt.grid(True)
    plt.show()
Ejemplo n.º 5
0
def load_qt():
    '''Plot lead 1 of QTdb record sel100.'''
    record = QTloader().load('sel100')
    plt.plot(record['sig'])
    plt.show()
Ejemplo n.º 6
0
def Test1():
    '''Comparing to expert labels in QTdb.'''
    qt = QTloader()
    reclist = qt.getreclist()

    rec_ind = 0
    for rec_ind in xrange(0, len(reclist)):

        print 'Processing record[%d] %s ...' % (rec_ind, reclist[rec_ind])
        sig = qt.load(reclist[rec_ind])
        raw_sig = sig['sig']
        expert_labels = qt.getExpert(reclist[rec_ind])
        R_pos_list = [
            x[0] for x in filter(lambda item: item[1] == 'R', expert_labels)
        ]

        # Skip empty expert lists
        if len(R_pos_list) == 0:
            continue

        dpi = DPI()

        qrs_list = dpi.QRS_Detection(raw_sig)

        # Find FN
        FN_arr = GetFN(R_pos_list, qrs_list)
        R_pos_list = FN_arr

        if len(R_pos_list) > 0:
            plt.plot(raw_sig)
            amp_list = [raw_sig[x] for x in qrs_list]
            plt.plot(qrs_list, amp_list, 'ro', markersize=12)
            amp_list = [raw_sig[x] for x in R_pos_list]
            plt.plot(R_pos_list, amp_list, 'ys', markersize=14)
            plt.show()
Ejemplo n.º 7
0
def TestAllQTdata(saveresultpath, testinglist):
    '''Test all records in testinglist, training on remaining records in QTdb.'''
    qt_loader = QTloader()
    QTreclist = qt_loader.getQTrecnamelist()
    # Get training record list
    traininglist = list(set(QTreclist) - set(testinglist))

    # Testing
    from randomwalk.test_api import GetModels
    from randomwalk.test_api import Testing
    # pattern_filename = os.path.join(os.path.dirname(saveresultpath), 'randrel.json')
    pattern_filename = '/home/alex/LabGit/ECG_random_walk/randomwalk/data/Lw3Np4000/random_pattern.json'
    model_folder = '/home/alex/LabGit/ECG_random_walk/randomwalk/data/Lw3Np4000'
    model_list = GetModels(model_folder, pattern_filename)

    for record_name in testinglist:
        sig = qt_loader.load(record_name)
        raw_sig = sig['sig']

        start_time = time.time()
        results = Testing(raw_sig, 250.0, model_list, walker_iterations=100)
        time_cost = time.time() - start_time

        with open(os.path.join(saveresultpath, '%s.json' % record_name),
                  'w') as fout:
            json.dump(results, fout)
            print 'Testing time %f s, data time %f s.' % (time_cost,
                                                          len(raw_sig) / 250.0)
Ejemplo n.º 8
0
def TestQT(record_name, save_result_folder, model_folder, random_pattern_file_name):
    '''Test case1.'''
    fs = 250.0
    qt = QTloader()

    sig = qt.load(record_name)
    expert_annotations = qt.getExpert(record_name)
    pos_list, label_list = zip(*expert_annotations)
    test_range = [np.min(pos_list) - 100, np.max(pos_list) + 100]
    
    result_mat = list()

    print 'Lead1'
    raw_sig = sig['sig']
    results = TestSignal(raw_sig, fs, test_range, model_folder, random_pattern_file_name)
    for ind in xrange(0, len(results)):
        results[ind] = [results[ind][0] + test_range[0], results[ind][1]]
    result_mat.append((record_name, results))

    print 'Lead2'
    raw_sig = sig['sig2']
    results = TestSignal(raw_sig, fs, test_range, model_folder, random_pattern_file_name)
    for ind in xrange(0, len(results)):
        results[ind] = [results[ind][0] + test_range[0], results[ind][1]]
    result_mat.append((record_name + '_sig2', results))
    
    result_file_name = os.path.join(save_result_folder, '%s.json' % record_name)
    with open(result_file_name, 'w') as fout:
        json.dump(result_mat, fout, indent = 4)
        print 'Results saved as %s.' % result_file_name
Ejemplo n.º 9
0
def res_to_mat_fromResult(recID, reslist, mat_filename):
    # load QT rawsig
    qt = QTloader()
    sig = qt.load(recID)
    rawsig = sig['sig']
    # load expert label
    expert_reslist = qt.getexpertlabeltuple(recID)
    # save sig and reslist
    label_dict = dict()
    for pos, label in reslist:
        if label in label_dict:
            label_dict[label].append(pos)
        else:
            label_dict[label] = [
                pos,
            ]
    # Expert Labels
    for pos, label in expert_reslist:
        exp_label = 'expert_' + label
        if exp_label in label_dict:
            label_dict[exp_label].append(pos)
        else:
            label_dict[exp_label] = [
                pos,
            ]
    label_dict['sig'] = rawsig
    scipy.io.savemat(mat_filename, label_dict)
    print 'mat file [{}] saved.'.format(mat_filename)
Ejemplo n.º 10
0
def DPI():
    '''Plot the DPI curve of the high-pass-filtered QTdb record sel100.

    For each sample ``ind`` the DPI value is ``|fsig[ind]|`` divided by the
    sum of ``|fsig|`` over ``[ind + m1 + 1, ind + m1 + m2]`` normalised by
    ``m2``, then scaled by 1/5. Both the DPI curve and the filtered signal
    are plotted.
    '''
    qt = QTloader()
    sig = qt.load('sel100')
    raw_sig = sig['sig']
    fsig = HPF(raw_sig)  # presumably a numpy array (uses .size) — confirm

    # DPI window parameters, in samples.
    m1 = 100
    m2 = 300
    len_sig = fsig.size
    dpi_arr = list()
    for ind in xrange(0, len_sig):
        lower_index = ind + m1 + 1
        upper_index = ind + m1 + m2

        if upper_index >= lower_index:
            s_avg = float(np.sum(np.abs(fsig[lower_index:upper_index + 1])))
            s_avg /= m2
        else:
            s_avg = 1.0

        # Near the end of the signal the window slice is empty (or all
        # zeros), giving s_avg == 0. Guard against dividing by zero, which
        # would put inf/nan into the curve.
        if s_avg < 1e-6:
            s_avg = 1.0

        dpi_val = np.abs(fsig[ind]) / s_avg

        dpi_val /= 5.0
        dpi_arr.append(dpi_val)

    plt.plot(dpi_arr, label='DPI')
    plt.plot(fsig, label='fsig')
    plt.title('DPI')
    plt.legend()
    plt.show()
Ejemplo n.º 11
0
def TestRecord(saveresultpath):
    '''Run a pickled ECGrf model on the first 10 QTdb records.

    No training happens here. The model is read from
    ``<saveresultpath>/trained_model.mdl``, and each record's result is
    written as JSON to ``<saveresultpath>/result_<record>``.
    '''
    qt_loader = QTloader()
    testinglist = qt_loader.getQTrecnamelist()[0:10]

    rf_classifier = ECGrf(SaveTrainingSampleFolder=saveresultpath)
    # TestRange 'All' — presumably test the full record; confirm in ECGrf.
    rf_classifier.TestRange = 'All'

    # Load the pickled classification model.
    model_path = os.path.join(saveresultpath, 'trained_model.mdl')
    with open(model_path, 'rb') as fin:
        trained_model = pickle.load(fin)
        rf_classifier.mdl = trained_model

    log.info('Testing records:\n    %s', ', '.join(testinglist))
    for record_name in testinglist:
        raw_signal = qt_loader.load(record_name)['sig']
        result = rf_classifier.testing(raw_signal, trained_model)
        result_path = os.path.join(saveresultpath,
                                   'result_{}'.format(record_name))
        with open(result_path, 'w') as fout:
            json.dump(result, fout, indent=4)
Ejemplo n.º 12
0
def TEST():
    '''Run WaveDelineator on samples 1000-2000 of QTdb record sel31 and plot.'''
    segment = QTloader().load('sel31')['sig'][1000:2000]

    delineator = WaveDelineator(segment, fs=250.0)
    delineation = delineator.run()
    delineator.plot_results(segment, delineation)


def plot_QTdb_filtered_Result_with_syntax_filter(RFfolder,
                                                 TargetRecordList,
                                                 ResultFilterType,
                                                 showExpertLabels=False):
    # exit
    # ==========================
    # plot prediction result
    # ==========================
    reslist = getresultfilelist(RFfolder)
    qtdb = QTloader()
    non_result_extensions = ['out', 'json', 'log', 'txt']
    for fi, fname in enumerate(reslist):
        # block *.out
        file_extension = fname.split('.')[-1]
        if file_extension in non_result_extensions:
            continue
        print 'file name:', fname
        currecname = os.path.split(fname)[-1]
        print currecname
        #if currecname == 'result_sel820':
        #pdb.set_trace()
        if TargetRecordList is not None:
            if currecname not in TargetRecordList:
                continue
        # load signal and reslist
        with open(fname, 'r') as fin:
            (recID, reslist) = pickle.load(fin)
        # filter result of QT
        resfilter = ResultFilter(reslist)
        if len(ResultFilterType) >= 1 and ResultFilterType[0] == 'G':
            reslist = resfilter.group_local_result(cp_del_thres=1)
        reslist_syntax = reslist
        if len(ResultFilterType) >= 2 and ResultFilterType[1] == 'S':
            reslist_syntax = resfilter.syntax_filter(reslist)
        # empty signal result
        #if reslist is None or len(reslist) == 0:
        #continue
        #pdb.set_trace()
        sigstruct = qtdb.load(recID)
        if showExpertLabels == True:
            # Expert Label AdditionalPlot
            ExpertRes = qtdb.getexpertlabeltuple(recID)
            ExpertPoslist = map(lambda x: x[0], ExpertRes)
            AdditionalPlot = [
                ['kd', 'Expert Labels', ExpertPoslist],
            ]
        else:
            AdditionalPlot = None

        # plot res
        #resploter = ECGResultPloter(sigstruct['sig'],reslist)
        #resploter.plot(plotTitle = 'QT database',plotShow = True,plotFig = 2)
        # syntax_filter
        resploter_syntax = ECGResultPloter(sigstruct['sig'], reslist_syntax)
        resploter_syntax.plot(plotTitle='QT Record {}'.format(recID),
                              plotShow=True,
                              AdditionalPlot=AdditionalPlot)
Ejemplo n.º 14
0
def Test_regression():
    '''Regression test.'''
    target_label = 'T'
    qt = QTloader()
    rec_ind = 65
    reclist = qt.getreclist()
    sig = qt.load(reclist[rec_ind])
    raw_sig = sig['sig']
    len_sig = len(raw_sig)

    rf = RegressionLearner(target_label=target_label)
    rf.TrainQtRecords(reclist[0:])

    # Load the trained model
    # with open('./tmp.mdl', 'rb') as fin:
    # mdl = pickle.load(fin)
    # rf.LoadModel(mdl)
    # Save the trained model
    with open('./tmp.mdl', 'wb') as fout:
        pickle.dump(rf.mdl, fout)

    dpi = DPI()
    qrs_list = dpi.QRS_Detection(raw_sig)

    start_time = time.time()
    detected_poslist = rf.testing(raw_sig,
                                  zip(qrs_list, [
                                      'R',
                                  ] * len(qrs_list)))
    print 'Testing time: %d s.' % (time.time() - start_time)
    while len(detected_poslist) > 0 and detected_poslist[-1] > len_sig:
        del detected_poslist[-1]

    sig = qt.load(reclist[rec_ind])
    raw_sig = sig['sig']
    plt.plot(raw_sig, label='raw signal')
    amp_list = [raw_sig[int(x)] for x in detected_poslist]
    plt.plot(detected_poslist,
             amp_list,
             'ro',
             markersize=12,
             label=target_label)
    plt.legend()
    plt.show()
Ejemplo n.º 15
0
def Test1():
    '''Test case1.'''
    # data = sio.loadmat('./data/ft.mat')
    # v2 = np.squeeze(data['II'])
    # raw_sig = v2
    # fs = 500
    qt = QTloader()
    sig = qt.load('sel32')
    raw_sig = sig['sig'][1000:3000]
    fs = 250
    # raw_sig = scipy.signal.resample(raw_sig, len(raw_sig) / 2)
    # fs = 250

    model_folder = '/home/alex/LabGit/ECG_random_walk/randomwalk/data/m3_full_models'
    pattern_file_name = '/home/alex/LabGit/ECG_random_walk/randomwalk/data/m3_full_models/random_pattern.json'
    model_list = GetModels(model_folder, pattern_file_name)
    start_time = time.time()

    # First: QRS detection
    dpi = DPI(debug_info=dict())
    r_list = dpi.QRS_Detection(raw_sig, fs=fs)
    results = zip(r_list,
                  len(r_list) * [
                      'R',
                  ])
    results.extend(Testing_QS(raw_sig, fs, r_list))
    walk_results = Testing_random_walk(raw_sig, fs, r_list, model_list)
    results.extend(walk_results)

    # results = Testing(raw_sig, fs, model_list)
    print 'Testing time cost %f secs.' % (time.time() - start_time)

    samples_count = len(raw_sig)
    time_span = samples_count / fs
    print 'Span of testing range: %f samples(%f seconds).' % (samples_count,
                                                              time_span)

    with open('./data/new_result.json', 'w') as fout:
        json.dump(results, fout, indent=4)

    # Display results
    plt.figure(1)
    plt.plot(raw_sig, label='ECG')
    pos_list, label_list = zip(*results)
    labels = set(label_list)
    for label in labels:
        pos_list = [
            int(x[0]) for x in results if x[1] == label and x[0] < len(raw_sig)
        ]
        amp_list = [raw_sig[x] for x in pos_list]
        plt.plot(pos_list, amp_list, 'o', markersize=15, label=label)
    plt.title('ECG')
    plt.grid(True)
    plt.legend()
    plt.show()
Ejemplo n.º 16
0
def viewQT():
    qt = QTloader()
    record_list = qt.getreclist()
    index = 19
    for record_name in record_list[index:]:
        print 'record index:', index
        # if record_name != 'sele0612':
        # continue
        sig = qt.load(record_name)
        raw_sig = sig['sig'][2000:7000]
        viewCWTsignal(raw_sig, 250, figure_title=record_name)

        index += 1
Ejemplo n.º 17
0
def plot_QTdb_filtered_Result_with_syntax_filter():
    # exit
    RFfolder = os.path.join(\
           projhomepath,\
           'TestResult',\
           'pc',\
           'r5')
    TargetRecordList = [
        'result_sel39',
        'result_sel41',
        'result_sel48',
    ]  #'sel38','sel42','result_sel821','result_sel14046']
    # ==========================
    # plot prediction result
    # ==========================
    reslist = glob.glob(os.path.join(\
           RFfolder,'*'))
    qtdb = QTloader()
    non_result_extensions = ['out', 'json', 'log']
    for fi, fname in enumerate(reslist):
        # block *.out
        file_extension = fname.split('.')[-1]
        if file_extension in non_result_extensions:
            continue
        print 'file name:', fname
        currecname = os.path.split(fname)[-1]
        print currecname
        #if currecname == 'result_sel820':
        #pdb.set_trace()
        if currecname not in TargetRecordList:
            pass
            continue
        # load signal and reslist
        with open(fname, 'r') as fin:
            (recID, reslist) = pickle.load(fin)
        # filter result of QT
        resfilter = ResultFilter(reslist)
        reslist = resfilter.group_local_result(cp_del_thres=1)
        reslist_syntax = resfilter.syntax_filter(reslist)
        # empty signal result
        #if reslist is None or len(reslist) == 0:
        #continue
        #pdb.set_trace()
        sigstruct = qtdb.load(recID)
        # plot res
        #resploter = ECGResultPloter(sigstruct['sig'],reslist)
        #resploter.plot(plotTitle = 'QT database',plotShow = True,plotFig = 2)
        # syntax_filter
        resploter_syntax = ECGResultPloter(sigstruct['sig'], reslist_syntax)
        resploter_syntax.plot(plotTitle='QT database syntax_filter',
                              plotShow=True)
Ejemplo n.º 18
0
def TestQT(record_index):
    '''Test case'''
    result_folder = '/home/alex/LabGit/ECG_random_walk/randomwalk/data/test_results/r2'
    qt = QTloader()
    P_width = 50
    result_files = glob.glob(os.path.join(result_folder, 'sel*.json'))
    with open(result_files[record_index], 'r') as fin:
        data = json.load(fin)
        record_name = data[0][0]
        print 'Record name:', record_name

    sig = qt.load(record_name)
    raw_sig = np.array(sig['sig']) / 40.0 + 1.0
    results = data[0][1]
    fs = 250
    if abs(fs - 250.0) > 1e-6:
        raw_sig = scipy.signal.resample(raw_sig,
                                        int(len(raw_sig) / float(fs) * 250.0))
        fs_recover = fs
        fs = 250.0

    P_positions = results
    P_positions.sort(key=lambda x: x[0])

    show_count = 1
    segment_list = list()
    wholewave_list = list()
    for ind in xrange(0, len(P_positions)):
        if show_count > 2:
            break
        pos, label = P_positions[ind]
        if label == 'P':
            Pon = pos - P_width / 2.0
            Poff = pos + P_width / 2.0
            for pi in xrange(ind - 1, -1, -1):
                if P_positions[pi][1] == 'Ponset':
                    Pon = P_positions[pi][0]
                    break
                elif (P_positions[pi][1] == 'Roffset'
                      or P_positions[pi][1] == 'Toffset'
                      or P_positions[pi][1] == 'T'):
                    Pon = P_positions[pi][0]
                    break
            for pi in xrange(ind + 1, len(P_positions)):
                if P_positions[pi][1] == 'Poffset':
                    Poff = P_positions[pi][0]
                    break
                elif P_positions[pi][1] == 'Ronset':
                    Poff = P_positions[pi][0]
                    break
            Pshape(raw_sig, (Pon, pos, Poff))
Ejemplo n.º 19
0
def TrainingModels(target_label, model_file_name, training_list):
    '''Randomly select num_training records to train, and test others.
    CP: Characteristic points
    '''
    qt = QTloader()
    record_list = qt.getreclist()
    testing_list = list(set(record_list) - set(training_list))

    random_forest_config = dict(max_depth=10)
    walker = RandomWalker(target_label=target_label,
                          random_forest_config=random_forest_config,
                          random_pattern_file_name=os.path.join(
                              os.path.dirname(model_file_name),
                              'random_pattern.json'))

    start_time = time.time()
    for record_name in training_list:
        CP_file_name = os.path.join(
            '/home/alex/code/Python/EcgCharacterPointMarks', target_label,
            '%s_poslist.json' % record_name)

        # Add expert marks
        expert_marks = qt.getExpert(record_name)
        CP_marks = [x for x in expert_marks if x[1] == target_label]
        if len(CP_marks) == 0:
            continue

        # Add manual labels if possible
        if os.path.exists(CP_file_name) == True:
            with open(CP_file_name, 'r') as fin:
                CP_info = json.load(fin)
                poslist = CP_info['poslist']
                if len(poslist) == 0:
                    continue
                CP_marks.extend(zip(poslist, [
                    target_label,
                ] * len(poslist)))

        print 'Collecting features from record %s.' % record_name
        sig = qt.load(record_name)
        walker.collect_training_data(sig['sig'], CP_marks)
    print 'random forest start training(%s)...' % target_label
    walker.training()
    print 'trianing used %.3f seconds' % (time.time() - start_time)

    import joblib
    start_time = time.time()
    walker.save_model(model_file_name)
    print 'Serializing model time cost %f' % (time.time() - start_time)
Ejemplo n.º 20
0
def ContinueAddQtTrainingSamples(walker, target_label):
    '''Add QT training samples.'''
    qt = QTloader()
    record_list = qt.getreclist()

    start_time = time.time()
    for record_name in record_list:

        # Add expert marks
        expert_marks = qt.getExpert(record_name)
        CP_marks = [x for x in expert_marks if x[1] == target_label]
        if len(CP_marks) == 0:
            continue

        print 'Collecting features from QT record %s.' % record_name
        sig = qt.load(record_name)
        walker.collect_training_data(sig['sig'], CP_marks)
Ejemplo n.º 21
0
class RecSelector():
    def __init__(self):
        self.qt = QTloader()

    def inspect_recs(self):
        reclist = self.qt.getQTrecnamelist()
        sel1213 = conf['sel1213']
        sel1213set = set(sel1213)

        out_reclist = set(reclist)  # - sel1213set

        # records selected
        selected_record_list = []
        for ind, recname in enumerate(out_reclist):
            # inspect
            print '{} records left.'.format(len(out_reclist) - ind - 1)
            self.qt.plotrec(recname)
            # debug
            if ind > 2:
                pass
                #print 'debug break'
                #break
    def plot_rec_dwt(self):
        reclist = self.qt.getQTrecnamelist()
        sel1213 = conf['sel1213']
        sel1213set = set(sel1213)

        out_reclist = set(reclist)  # - sel1213set

        # dwt coef
        dwtECG = WTfeature()
        # records selected
        selected_record_list = []
        for ind, recname in enumerate(out_reclist):
            # inspect
            print 'processing record : {}'.format(recname)
            print '{} records left.'.format(len(out_reclist) - ind - 1)
            # debug selection
            if not recname.startswith(ur'sele0704'):
                continue
            sig = self.qt.load(recname)
            # wavelet
            waveletobj = dwtECG.gswt_wavelet()
            self.plot_dwt_coef_with_annotation(sig['sig'],
                                               waveletobj=waveletobj,
                                               ECGrecordname=recname)
Ejemplo n.º 22
0
def plotResult():
    '''Plot lead 1 of record sel808 with a stored delineation result.

    The result is read from ``./round1/round2/sel808.json`` and drawn with
    ``plotExpertLabels``.
    '''
    record_name = 'sel808'
    result_file_path = './round1/round2/sel808.json'
    qt = QTloader()
    sig = qt.load(record_name)
    raw_sig = sig['sig']

    # Load the delineation result.
    with open(result_file_path, 'r') as fin:
        annots = json.load(fin)
    fig, ax = plt.subplots(1, 1)
    plt.plot(raw_sig)
    plt.grid(True)
    plotExpertLabels(ax, raw_sig, annots)
    plt.show()
Ejemplo n.º 23
0
def res_to_mat_fromfilename(res_filename, mat_filename):
    recID, reslist = load_result(res_filename)
    # load QT rawsig
    qt = QTloader()
    sig = qt.load(recID)
    rawsig = sig['sig']
    # save sig and reslist
    label_dict = dict()
    for pos, label in reslist:
        if label in label_dict:
            label_dict[label].append(pos)
        else:
            label_dict[label] = [
                pos,
            ]
    label_dict['sig'] = rawsig
    scipy.io.savemat(mat_filename, label_dict)
    print 'mat file [{}] saved.'.format(mat_filename)
Ejemplo n.º 24
0
    def TrainQtRecords(self, record_list):
        '''API for QTdb: training model with given record_list.'''
        QTdb = QTloader()

        training_count = 1
        # Extracting feature from each record.
        for record_name in record_list:
            sig_struct = QTdb.load(record_name)
            raw_signal = sig_struct['sig']
            expert_labels = QTdb.getExpert(record_name)
            self.AddNewTrainingSignal(raw_signal, expert_labels)
            # Logging
            log.info('Extracted features from %s' % record_name)
            print '.' * training_count, '(%d/%d)' % (training_count, len(record_list))
            training_count += 1

        # Training with feature pool
        self.training()
def TrainingModels(target_label, model_file_name, training_list):
    '''Randomly select num_training records to train, and test others.'''
    qt = QTloader()
    record_list = qt.getreclist()
    testing_list = list(set(record_list) - set(training_list))

    random_forest_config = dict(max_depth=10)
    walker = RandomWalker(target_label=target_label,
                          random_forest_config=random_forest_config,
                          random_pattern_file_name=os.path.join(
                              os.path.dirname(model_file_name),
                              'random_pattern.json'))

    start_time = time.time()
    for record_name in training_list:
        Tonset_file_name = os.path.join(
            '/home/alex/code/Python/Tonset/results',
            '%s_poslist.json' % record_name)
        if os.path.exists(Tonset_file_name) == True:
            with open(Tonset_file_name, 'r') as fin:
                Tonset_info = json.load(fin)
                poslist = Tonset_info['poslist']
                if len(poslist) == 0:
                    continue
                Tonset_marks = zip(poslist, [
                    'Tonset',
                ] * len(poslist))
        else:
            expert_marks = qt.getExpert(record_name)
            Tonset_marks = [x for x in expert_marks if x[1] == 'Tonset']
            if len(Tonset_marks) == 0:
                continue

        print 'Collecting features from record %s.' % record_name
        sig = qt.load(record_name)
        walker.collect_training_data(sig['sig'], Tonset_marks)
    print 'random forest start training(%s)...' % target_label
    walker.training()
    print 'trianing used %.3f seconds' % (time.time() - start_time)

    import joblib
    start_time = time.time()
    walker.save_model(model_file_name)
    print 'Serializing model time cost %f' % (time.time() - start_time)
Ejemplo n.º 26
0
def DPI(fs=250.0):
    '''Plot how the DPI value at sample 140 of sel100 varies with window length.

    The first 1000 samples of lead 1 are high-pass filtered. For each window
    length m2 in [0, int(fs * 1.71)), the DPI value is ``|fsig[140]|``
    divided by the sum of ``|fsig|`` over ``[140 + m1 + 1, 140 + m1 + m2]``
    (clipped to the signal), normalised by sqrt(m2). An empty window or a
    near-zero sum uses a denominator of 1.0 instead.
    '''
    raw_sig = QTloader().load('sel100')['sig'][0:1000]
    fsig = HPF(raw_sig)

    m1 = -2
    len_sig = fsig.size
    probe = 140
    max_window = int(fs * 1.71)

    dpi_arr = list()
    for m2 in xrange(0, max_window):
        lower_index = max(0, probe + m1 + 1)
        upper_index = min(len_sig - 1, probe + m1 + m2)

        if upper_index < lower_index:
            s_avg = 1.0
        else:
            window_sum = float(np.sum(np.abs(fsig[lower_index:upper_index + 1])))
            s_avg = window_sum / math.pow(m2, 0.5)

        # Prevent division by zero.
        if s_avg < 1e-6:
            s_avg = 1.0

        dpi_arr.append(np.abs(fsig[probe]) / s_avg)

    plt.plot(xrange(probe, probe + len(dpi_arr)), dpi_arr, label='DPI')
    plt.plot(fsig, label='fsig')
    plt.title('DPI')
    plt.legend()
    plt.show()
Ejemplo n.º 27
0
def Test():
    '''Visual check of HogClass.DiscretiseHog on a 900-sample ECG segment.'''
    segment = QTloader().load('sel17152')['sig'][10000:10900]

    # Discretised HOG with 15-sample segments and the debug plot enabled.
    HogClass(segment_len=15).DiscretiseHog(segment, debug_plot=True)

    plt.figure(2)
    plt.plot(np.array(segment) * 10)
    plt.grid(True)
    plt.show()
Ejemplo n.º 28
0
def getRRhisto():
    ## plot RR histogram
    from MedEvaluationClass import MedEvaluation
    # get test result list
    testlist = getresultfilelist()
    # filter testlist
    invalidlist = conf['InvalidRecords']
    for testrec in testlist:
        print 'processing record:{}'.format(testrec)
        with open(testrec,'r') as fin:
            (recname,resdata) = pickle.load(fin)
            if recname in invalidlist:
                continue
            # RR histo
            medeval = MedEvaluation(resdata)
            medeval.RRhistogram()
            # load signal rawdata
            qtloader = QTloader()
            ECGsigstruct = qtloader.load(recname = recname)
            ECGsig = ECGsigstruct['sig']
            medeval.RRhisto_check(ECGsig)
Ejemplo n.º 29
0
def TEST_PredictionQRS():
    '''Compare SWT detail bands of sel873 with and without predicted-QRS removal.

    Marks come from the first entry of ``LeadResult`` in
    ``<curfolderpath>/MultiLead4/GroupRound1/sel873.json``, converted with
    ``Convert2ExpertFormat``.
    Upper panel: the signal and the 4th-from-last detail band from
    ``SWT_NoPredictQRS``.
    Lower panel: the same band computed directly with ``pywt.swt`` (db6,
    level 9) on the cropped, freshly loaded record.
    '''
    recname = 'sel873'
    result_folder = os.path.join(curfolderpath, 'MultiLead4', 'GroupRound1')
    qt_db = QTloader()
    signal = qt_db.load(recname)['sig']

    result_path = os.path.join(result_folder, '{}.json'.format(recname))
    with open(result_path, 'r') as fin:
        raw_result = json.load(fin)
    marks = Convert2ExpertFormat(raw_result['LeadResult'][0])

    # Display with 2 subplots.
    swt = SWT_NoPredictQRS(signal, marks)
    swt.swt()
    band = swt.cDlist[-4]

    plt.figure(1)
    # Upper panel: signal and band produced by SWT_NoPredictQRS.
    plt.subplot(211)
    plt.plot(signal)
    plt.plot(band)
    plt.grid(True)

    # Lower panel: band of the original (cropped) ECG.
    original = swt.crop_data_for_swt(swt.QTdb.load(recname)['sig'])
    detail_bands = [detail for _, detail in pywt.swt(original, 'db6', 9)]
    plt.subplot(212)
    plt.plot(original)
    plt.plot(detail_bands[-4])
    plt.grid(True)
    plt.show()
Ejemplo n.º 30
0
def test0():
    '''Run FastTester on the first 1000 samples of sel100 lead 2 and plot.'''
    from QTdata.loadQTdata import QTloader
    qt = QTloader()
    sig = qt.load('sel100')

    range_right = 1000
    raw_sig = sig['sig2'][0:range_right]

    ft = FastTester()
    res_list = ft.testing(raw_sig, fs=250.0)
    labels = set([x[1] for x in res_list])
    plt.plot(raw_sig)
    for label in labels:
        # Filter positions before looking up amplitudes, so the x and y
        # sequences given to plt.plot have the same length. Bounding by
        # len(raw_sig) also covers records shorter than range_right.
        poslist = [
            x[0] for x in res_list
            if x[1] == label and 0 <= x[0] < len(raw_sig)
        ]
        amplist = [raw_sig[int(x)] for x in poslist]
        plt.plot(poslist, amplist, 'o', markersize=12, alpha=0.5, label=label)
    plt.title('sel100')
    plt.legend()
    plt.show()