示例#1
0
def TrainingModels(target_label, model_file_name, training_list,
                   random_pattern_path=None):
    '''Train a RandomWalker for one label on QT records and save the model.

    Args:
        target_label: Label name passed to RandomWalker as `target_label`.
        model_file_name: Path handed to `walker.save_model`.
        training_list: Names of the QT records to collect features from.
        random_pattern_path: Optional random pattern file. When omitted,
            `random_pattern.json` in the folder of `model_file_name` is used
            (the previous hard-wired behaviour).
    '''
    if random_pattern_path is None:
        random_pattern_path = os.path.join(
            os.path.dirname(model_file_name), 'random_pattern.json')

    qt = QTloader()
    walker = RandomWalker(target_label=target_label,
                          random_forest_config=dict(max_depth=10),
                          random_pattern_file_name=random_pattern_path)

    # Feature collection and fitting are timed together.
    start_time = time.time()
    for record_name in training_list:
        print('Collecting features from record %s.' % record_name)
        sig = qt.load(record_name)
        walker.collect_training_data(sig['sig'], qt.getExpert(record_name))
    print('random forest start training(%s)...' % target_label)
    walker.training()
    print('trianing used %.3f seconds' % (time.time() - start_time))

    start_time = time.time()
    walker.save_model(model_file_name)
    print('Serializing model time cost %f' % (time.time() - start_time))
示例#2
0
def Test1():
    '''Compare DPI QRS detections with the expert R labels of every QTdb record.

    For each record that has expert 'R' marks, the positions returned by
    GetFN(expert R positions, detected QRS positions) are plotted on top of
    the raw signal together with the detections; records for which GetFN
    returns nothing are not plotted.
    '''
    loader = QTloader()

    for rec_ind, record_name in enumerate(loader.getreclist()):
        print('Processing record[%d] %s ...' % (rec_ind, record_name))
        record = loader.load(record_name)
        raw_sig = record['sig']
        expert_labels = loader.getExpert(record_name)
        expert_r_positions = [
            mark[0] for mark in expert_labels if mark[1] == 'R'
        ]

        # Nothing to compare against when the expert marked no R peak.
        if len(expert_r_positions) == 0:
            continue

        detector = DPI()
        detected_qrs = detector.QRS_Detection(raw_sig)

        # Positions GetFN reports for the expert marks (named FN upstream).
        missed_positions = GetFN(expert_r_positions, detected_qrs)
        if len(missed_positions) == 0:
            continue

        plt.plot(raw_sig)
        detected_amplitudes = [raw_sig[pos] for pos in detected_qrs]
        plt.plot(detected_qrs, detected_amplitudes, 'ro', markersize=12)
        missed_amplitudes = [raw_sig[pos] for pos in missed_positions]
        plt.plot(missed_positions, missed_amplitudes, 'ys', markersize=14)
        plt.show()
def randomTraining(root_folder='models/round1',
                   num_training=75,
                   random_pattern_path='models/random_pattern.json'):
    '''Train one model per label on a partly random split, then test the rest.

    Args:
        root_folder: Folder receiving `training_list.json`, one `<label>.mdl`
            per label and a `results` sub-folder. Created (including missing
            parents) when absent.
        num_training: Total number of training records, counting the fixed
            must-train records. Extra records are sampled at random only when
            this exceeds the number of must-train records.
        random_pattern_path: Random pattern file forwarded to TrainingModels
            and RoundTesting.
    '''
    label_list = [
        'P', 'Ponset', 'Poffset', 'T', 'Toffset', 'Ronset', 'Roffset'
    ]

    # os.makedirs also creates missing parents such as 'models'.
    if not os.path.exists(root_folder):
        os.makedirs(root_folder)

    qt = QTloader()
    record_list = qt.getreclist()
    must_train_list = [
        "sel35", "sel36", "sel31", "sel38", "sel39", "sel820", "sel51",
        "sele0104", "sele0107", "sel223", "sele0607", "sel102", "sele0409",
        "sel41", "sel40", "sel43", "sel42", "sel45", "sel48", "sele0133",
        "sele0116", "sel14172", "sele0111", "sel213", "sel14157", "sel301"
    ]

    # Top up the fixed records with a random sample of the remaining ones.
    num_random = num_training - len(must_train_list)
    candidates = list(set(record_list) - set(must_train_list))
    training_list = list(must_train_list)
    if num_random > 0:
        training_list.extend(random.sample(candidates, num_random))

    # Save training list
    with open(os.path.join(root_folder, 'training_list.json'), 'w') as fout:
        json.dump(training_list, fout, indent=4)

    for target_label in label_list:
        model_file_name = os.path.join(root_folder, '%s.mdl' % target_label)
        TrainingModels(target_label,
                       model_file_name,
                       training_list,
                       random_pattern_path=random_pattern_path)

    # Testing: every record that was not used for training.
    testing_list = list(set(record_list) - set(training_list))
    saveresultpath = os.path.join(root_folder, 'results')
    if not os.path.exists(saveresultpath):
        os.makedirs(saveresultpath)
    from record_test import RoundTesting
    RoundTesting(saveresultpath,
                 testing_list,
                 model_folder=root_folder,
                 pattern_filename=random_pattern_path)
示例#4
0
def Test_hog1d():
    '''Hog feature method test.

    Trains a HogFeatureExtractor for label 'T' on the first 100 QT records,
    pickles `hog.gbdt` to ./hog.mdl, runs the detector on record index 103
    seeded with the DPI QRS positions, and plots the detections over both
    signal channels.
    '''
    target_label = 'T'
    qt = QTloader()
    rec_ind = 103
    reclist = qt.getreclist()
    sig = qt.load(reclist[rec_ind])
    raw_sig = sig['sig']
    len_sig = len(raw_sig)

    # Hog feature tester
    hog = HogFeatureExtractor(target_label=target_label)
    hog.Train(reclist[0:100])

    # Save the trained model
    with open('./hog.mdl', 'wb') as fout:
        pickle.dump(hog.gbdt, fout)

    dpi = DPI()
    qrs_list = dpi.QRS_Detection(raw_sig)

    start_time = time.time()
    detected_poslist = hog.Testing(raw_sig,
                                   [(pos, 'R') for pos in qrs_list])
    print('Testing time: %d s.' % (time.time() - start_time))

    # Keep only positions that index the signal. A position equal to len_sig
    # is out of range too, and filtering every element does not rely on the
    # detector returning its positions in sorted order.
    detected_poslist = [
        pos for pos in detected_poslist if 0 <= pos < len_sig
    ]

    # The record is loaded again before plotting, as in the original flow
    # (presumably to plot an unmodified copy -- TODO confirm).
    sig = qt.load(reclist[rec_ind])
    raw_sig = sig['sig']
    plt.plot(raw_sig, label='raw signal D1')
    plt.plot(sig['sig2'], label='raw signal D2')
    amp_list = [raw_sig[int(x)] for x in detected_poslist]
    plt.plot(detected_poslist,
             amp_list,
             'ro',
             markersize=12,
             label=target_label)
    plt.title('Record name %s' % reclist[rec_ind])
    plt.legend()
    plt.show()
示例#5
0
def viewQT():
    '''Show samples 2000-7000 of each QT record from list index 19 onwards.

    Each excerpt is passed to viewCWTsignal with 250 as second argument and
    the record name as figure title.
    '''
    loader = QTloader()
    first_index = 19
    remaining_records = loader.getreclist()[first_index:]
    for index, record_name in enumerate(remaining_records, first_index):
        print('%s %s' % ('record index:', index))
        record = loader.load(record_name)
        excerpt = record['sig'][2000:7000]
        viewCWTsignal(excerpt, 250, figure_title=record_name)
def TrainingModels(target_label, model_file_name, training_list,
                   random_pattern_path=None, manual_label_folder=None):
    '''Train a RandomWalker on expert plus manual marks and save the model.

    CP: Characteristic points

    Args:
        target_label: Label to train; only marks carrying it are used.
        model_file_name: Path handed to `walker.save_model`.
        training_list: Names of the QT records to collect features from.
        random_pattern_path: Optional random pattern file. When omitted,
            `random_pattern.json` in the folder of `model_file_name` is used.
        manual_label_folder: Optional root folder of the manual label files,
            laid out as `<root>/<target_label>/<record>_poslist.json`. When
            omitted, the path that used to be hard-coded is used.
    '''
    if random_pattern_path is None:
        random_pattern_path = os.path.join(
            os.path.dirname(model_file_name), 'random_pattern.json')
    if manual_label_folder is None:
        manual_label_folder = '/home/alex/code/Python/EcgCharacterPointMarks'

    qt = QTloader()
    walker = RandomWalker(target_label=target_label,
                          random_forest_config=dict(max_depth=10),
                          random_pattern_file_name=random_pattern_path)

    start_time = time.time()
    for record_name in training_list:
        CP_file_name = os.path.join(manual_label_folder, target_label,
                                    '%s_poslist.json' % record_name)

        # Add expert marks; records without any mark of this label are skipped.
        expert_marks = qt.getExpert(record_name)
        CP_marks = [x for x in expert_marks if x[1] == target_label]
        if len(CP_marks) == 0:
            continue

        # Add manual labels if possible
        if os.path.exists(CP_file_name):
            with open(CP_file_name, 'r') as fin:
                poslist = json.load(fin)['poslist']
            # An empty manual list skips the whole record, expert marks
            # included (behaviour kept from the original code).
            if len(poslist) == 0:
                continue
            CP_marks.extend([(pos, target_label) for pos in poslist])

        print('Collecting features from record %s.' % record_name)
        sig = qt.load(record_name)
        walker.collect_training_data(sig['sig'], CP_marks)
    print('random forest start training(%s)...' % target_label)
    walker.training()
    print('trianing used %.3f seconds' % (time.time() - start_time))

    start_time = time.time()
    walker.save_model(model_file_name)
    print('Serializing model time cost %f' % (time.time() - start_time))
def ContinueAddQtTrainingSamples(walker, target_label):
    '''Add QT training samples.

    Feeds `walker.collect_training_data` with the first channel of every QT
    record that has at least one expert mark labelled `target_label`.

    Args:
        walker: Object exposing `collect_training_data(signal, marks)`.
        target_label: Label whose expert marks are collected.
    '''
    qt = QTloader()

    for record_name in qt.getreclist():
        # Add expert marks; records without any mark of this label are skipped.
        expert_marks = qt.getExpert(record_name)
        CP_marks = [x for x in expert_marks if x[1] == target_label]
        if len(CP_marks) == 0:
            continue

        print('Collecting features from QT record %s.' % record_name)
        sig = qt.load(record_name)
        walker.collect_training_data(sig['sig'], CP_marks)
示例#8
0
def Test_regression():
    '''Regression test.

    Trains a RegressionLearner for label 'T' on all QT records, pickles
    `rf.mdl` to ./tmp.mdl, runs the learner on record index 65 seeded with
    the DPI QRS positions, and plots the detections over the raw signal.
    '''
    target_label = 'T'
    qt = QTloader()
    rec_ind = 65
    reclist = qt.getreclist()
    sig = qt.load(reclist[rec_ind])
    raw_sig = sig['sig']
    len_sig = len(raw_sig)

    rf = RegressionLearner(target_label=target_label)
    rf.TrainQtRecords(reclist[0:])

    # Save the trained model
    with open('./tmp.mdl', 'wb') as fout:
        pickle.dump(rf.mdl, fout)

    dpi = DPI()
    qrs_list = dpi.QRS_Detection(raw_sig)

    start_time = time.time()
    detected_poslist = rf.testing(raw_sig,
                                  [(pos, 'R') for pos in qrs_list])
    print('Testing time: %d s.' % (time.time() - start_time))

    # Keep only positions that index the signal: a position equal to len_sig
    # raises IndexError below and a negative one silently wraps around.
    detected_poslist = [
        pos for pos in detected_poslist if 0 <= pos < len_sig
    ]

    # The record is loaded again before plotting, as in the original flow
    # (presumably to plot an unmodified copy -- TODO confirm).
    sig = qt.load(reclist[rec_ind])
    raw_sig = sig['sig']
    plt.plot(raw_sig, label='raw signal')
    amp_list = [raw_sig[int(x)] for x in detected_poslist]
    plt.plot(detected_poslist,
             amp_list,
             'ro',
             markersize=12,
             label=target_label)
    plt.legend()
    plt.show()
def TrainingModels(target_label, model_file_name, training_list,
                   random_pattern_path=None, manual_label_folder=None):
    '''Train a RandomWalker on Tonset marks and save the model.

    For each record, manually labelled positions from
    `<manual_label_folder>/<record>_poslist.json` are used when that file
    exists; otherwise the expert marks labelled 'Tonset' are used. Records
    that end up with no marks are skipped.

    Args:
        target_label: Label name passed to RandomWalker as `target_label`.
            The collected marks are always labelled 'Tonset'.
        model_file_name: Path handed to `walker.save_model`.
        training_list: Names of the QT records to collect features from.
        random_pattern_path: Optional random pattern file. When omitted,
            `random_pattern.json` in the folder of `model_file_name` is used.
        manual_label_folder: Optional folder holding the manual label files.
            When omitted, the path that used to be hard-coded is used.
    '''
    if random_pattern_path is None:
        random_pattern_path = os.path.join(
            os.path.dirname(model_file_name), 'random_pattern.json')
    if manual_label_folder is None:
        manual_label_folder = '/home/alex/code/Python/Tonset/results'

    qt = QTloader()
    walker = RandomWalker(target_label=target_label,
                          random_forest_config=dict(max_depth=10),
                          random_pattern_file_name=random_pattern_path)

    start_time = time.time()
    for record_name in training_list:
        Tonset_file_name = os.path.join(manual_label_folder,
                                        '%s_poslist.json' % record_name)
        if os.path.exists(Tonset_file_name):
            # Manual labels replace the expert marks for this record.
            with open(Tonset_file_name, 'r') as fin:
                poslist = json.load(fin)['poslist']
            if len(poslist) == 0:
                continue
            Tonset_marks = [(pos, 'Tonset') for pos in poslist]
        else:
            expert_marks = qt.getExpert(record_name)
            Tonset_marks = [x for x in expert_marks if x[1] == 'Tonset']
            if len(Tonset_marks) == 0:
                continue

        print('Collecting features from record %s.' % record_name)
        sig = qt.load(record_name)
        walker.collect_training_data(sig['sig'], Tonset_marks)
    print('random forest start training(%s)...' % target_label)
    walker.training()
    print('trianing used %.3f seconds' % (time.time() - start_time))

    start_time = time.time()
    walker.save_model(model_file_name)
    print('Serializing model time cost %f' % (time.time() - start_time))
示例#10
0
def SplitQTrecords(num_training = 75):
    '''Split records for testing & training.

    Args:
        num_training: Total number of training records, counting the fixed
            must-train records. Extra records are sampled at random only when
            this exceeds the number of must-train records.

    Returns:
        Tuple `(training_list, testing_list)`. The training list starts with
        the must-train records; the testing list holds every remaining QT
        record that was not sampled for training.
    '''
    qt = QTloader()
    must_train_list = [
        "sel35", "sel36", "sel31", "sel38", "sel39", "sel820", "sel51",
        "sele0104", "sele0107", "sel223", "sele0607", "sel102", "sele0409",
        "sel41", "sel40", "sel43", "sel42", "sel45", "sel48", "sele0133",
        "sele0116", "sel14172", "sele0111", "sel213", "sel14157", "sel301"
    ]
    num_random = num_training - len(must_train_list)

    # Candidates for random selection: everything that is not must-train.
    record_list = list(set(qt.getreclist()) - set(must_train_list))
    training_list = list(must_train_list)
    # Refresh training list
    if num_random > 0:
        training_list.extend(random.sample(record_list, num_random))
    testing_list = list(set(record_list) - set(training_list))

    return (training_list, testing_list)
def Test1(target_label='P', num_training=25, model_file_name=None):
    '''Test case 1: random walk.

    Trains a RandomWalker on `num_training` randomly chosen QT records, then
    for every other record runs 20 walks from increasing seed positions and
    plots each walk path with its predicted position. Execution stops in pdb
    after each plot.

    Args:
        target_label: Label name passed to RandomWalker as `target_label`.
        num_training: Number of records sampled for training.
        model_file_name: Optional path; when given, the trained model is
            saved there after the walks. The former code referenced this
            name without defining it, which raised NameError.
    '''
    qt = QTloader()
    record_list = qt.getreclist()
    training_list = random.sample(record_list, num_training)
    testing_list = list(set(record_list) - set(training_list))

    random_forest_config = dict(max_depth=10)
    walker = RandomWalker(target_label=target_label,
                          random_forest_config=random_forest_config)

    start_time = time.time()
    for record_name in training_list:
        print('Collecting features from record %s.' % record_name)
        sig = qt.load(record_name)
        walker.collect_training_data(sig['sig'], qt.getExpert(record_name))
    print('random forest start training...')
    walker.training()
    print('trianing used %.3f seconds' % (time.time() - start_time))

    for record_name in testing_list:
        sig = qt.load(record_name)
        raw_sig = sig['sig']

        seed_position = random.randint(100, len(raw_sig) - 100)
        plt.figure(1)
        plt.clf()
        plt.plot(sig['sig'], label=record_name)
        plt.title(target_label)
        for ti in xrange(0, 20):
            seed_position += random.randint(1, 200)
            print('testing...(position: %d)' % seed_position)
            start_time = time.time()
            results = walker.testing_walk(sig['sig'],
                                          seed_position,
                                          iterations=100,
                                          stepsize=10)
            print('testing finished in %.3f seconds.' % (time.time() -
                                                         start_time))

            pos_list, values = zip(*results)
            # The prediction is the mean of the second half of the walk;
            # `//` keeps the slice index an integer.
            predict_pos = np.mean(pos_list[len(pos_list) // 2:])

            # Draw the walk path as a slowly descending line that starts at
            # the signal amplitude of the first walk position.
            amp_list = []
            bias = raw_sig[pos_list[0]]
            for _ in pos_list:
                amp_list.append(bias)
                bias -= 0.01

            plt.plot(predict_pos,
                     raw_sig[int(predict_pos)],
                     'ro',
                     markersize=14,
                     label='predict position')
            plt.plot(pos_list,
                     amp_list,
                     'r',
                     label='walk path',
                     markersize=3,
                     linewidth=8,
                     alpha=0.3)
            plt.xlim(min(pos_list) - 100, max(pos_list) + 100)
            plt.grid(True)
            plt.legend()
            plt.show(block=False)
            pdb.set_trace()

    # Save the model that was just evaluated. No second training pass is
    # run, since no new data has been collected since the first one.
    if model_file_name is not None:
        start_time = time.time()
        walker.save_model(model_file_name)
        print('Serializing model time cost %f' % (time.time() - start_time))


# Script entry: build a training list of QT records (fixed must-train records
# plus a random sample) and store it as JSON under root_folder.
if __name__ == '__main__':
    root_folder = 'data/Lw3Np4000/improved'
    # Refresh training list
    # num_training is the total training size, must-train records included.
    num_training = 105
    trianing_list = list()
    qt = QTloader()
    record_list = qt.getreclist()
    must_train_list = [
        "sel35", "sel36", "sel31", "sel38", "sel39", "sel820", "sel51",
        "sele0104", "sele0107", "sel223", "sele0607", "sel102", "sele0409",
        "sel41", "sel40", "sel43", "sel42", "sel45", "sel48", "sele0133",
        "sele0116", "sel14172", "sele0111", "sel213", "sel14157", "sel301"
    ]
    # From here on num_training counts only the records still to be sampled.
    num_training -= len(must_train_list)
    # Remaining candidates: every record that is not in the must-train list.
    record_list = list(set(record_list) - set(must_train_list))
    # NOTE: training_list aliases must_train_list, so the extend below also
    # grows must_train_list.
    training_list = must_train_list
    if num_training > 0:
        training_list.extend(random.sample(record_list, num_training))
    # Save training list
    # NOTE(review): root_folder is not created here; open() fails if it does
    # not exist yet.
    with open(os.path.join(root_folder, 'training_list.json'), 'w') as fout:
        json.dump(training_list, fout, indent=4)
    target_label = 'P'
示例#13
0
            plt.legend()

            plt.figure(3)
            plt.plot(raw_sig, label='raw signal')
            amp_list = [raw_sig[x] for x in qrs_arr]
            plt.plot(qrs_arr, amp_list, 'r^', markersize=12)
            plt.title('Raw signal')
            plt.legend()

        return qrs_arr


# Script entry: run the DPI QRS detector on one QT record with debug output.
if __name__ == '__main__':
    from QTdata.loadQTdata import QTloader
    qt = QTloader()
    # Record at list index 67 is used as the demo input.
    recname = qt.getreclist()[67]
    print 'record name:', recname

    sig = qt.load(recname)
    # [0:] takes the signal from its first sample to the end.
    raw_sig = sig['sig'][0:]

    # Flags handed to the detector through its debug_info argument.
    debug_info = dict()
    debug_info['time_cost'] = True
    debug_info['plot_results'] = True
    # debug_info['decision_plot'] = 25262
    detector = DPI_QRS_Detector(debug_info=debug_info)
    qrs_arr = detector.QRS_Detection(np.array(raw_sig))

    # Plot R-R histogram
    import matplotlib.pyplot as plt
    plt.figure(2)
示例#14
0
def RoundTest(target_label, result_folder, num_training = 75):
    '''Train a RandomWalker on a partly random record split and test the rest.

    The training set is the fixed must-train records plus a random sample
    that brings the total to num_training. For every other record, the
    result of `testing(walker, signal)` for both channels is written to
    `<result_folder>/<record>.json`.
    '''
    loader = QTloader()
    all_records = loader.getreclist()
    must_train_list = [
        "sel35", "sel36", "sel31", "sel38", "sel39", "sel820", "sel51",
        "sele0104", "sele0107", "sel223", "sele0607", "sel102", "sele0409",
        "sel41", "sel40", "sel43", "sel42", "sel45", "sel48", "sele0133",
        "sele0116", "sel14172", "sele0111", "sel213", "sel14157", "sel301"
    ]

    # Number of records still to be drawn at random.
    extra_count = num_training - len(must_train_list)
    candidates = list(set(all_records) - set(must_train_list))
    training_list = list(must_train_list)
    if extra_count > 0:
        training_list.extend(random.sample(candidates, extra_count))
    testing_list = list(set(candidates) - set(training_list))

    forest_options = dict(max_depth=10)
    walker = RandomWalker(target_label=target_label,
                          random_forest_config=forest_options)

    # Feature collection and fitting are timed together.
    training_started = time.time()
    for record_name in training_list:
        print('Collecting features from record %s.' % record_name)
        record = loader.load(record_name)
        walker.collect_training_data(record['sig'],
                                     loader.getExpert(record_name))
    print('random forest start training...')
    walker.training()
    print('trianing used %.3f seconds' % (time.time() - training_started))

    # Channel key in the loaded record -> suffix appended to the result name.
    channels = (('sig', ''), ('sig2', '_sig2'))
    for record_name in testing_list:
        print('testing record %s...' % record_name)
        record = loader.load(record_name)
        record_result = list()
        for channel_key, name_suffix in channels:
            channel_signal = record[channel_key]
            record_result.append((record_name + name_suffix,
                                  testing(walker, channel_signal)))
        # Write to json
        output_path = os.path.join(result_folder, '%s.json' % record_name)
        with open(output_path, 'w') as fout:
            json.dump(record_result, fout, indent=4)
示例#15
0
def QRS_Detection(fs=250.0):
    '''Detect QRS positions on QT record 'sel40' with the DPI search and plot.

    The first 5500 samples of the record are passed through HPF. Starting at
    sample 10, a DPI curve is computed for the current sample, its largest
    peak-to-valley swing locates the next search window, and the largest
    filtered amplitude in that window becomes the next QRS position. The
    loop ends when no swing is found or the position stops advancing.

    Args:
        fs: Sampling frequency used to convert the time constants
            (1.71 s DPI span, 0.3 s minimum distance, 285 ms search radius)
            into sample counts.
    '''
    qt = QTloader()
    print(qt.getreclist())

    sig = qt.load('sel40')
    raw_sig = sig['sig'][0:5500]
    fsig = HPF(raw_sig)

    # DPI
    m1 = -2
    len_sig = fsig.size
    N_m2 = int(fs * 1.71)

    qrs_arr = list()
    # Defined up front so the final plot also works if the loop never runs.
    dpi_arr = list()
    ind = 10

    while ind < len_sig:
        dpi_arr = list()
        for m2 in xrange(0, N_m2):
            lower_index = max(0, ind + m1 + 1)
            upper_index = min(len_sig - 1, ind + m1 + m2)

            if upper_index >= lower_index:
                s_avg = float(np.sum(np.abs(fsig[lower_index:upper_index +
                                                 1])))
                s_avg /= math.pow(m2, 0.5)
            else:
                s_avg = 1.0

            # Prevent 0 division error
            if s_avg < 1e-6:
                s_avg = 1.0

            dpi_val = np.abs(fsig[ind]) / s_avg
            dpi_arr.append(dpi_val)

        # Find cross zeros of the DPI slope:
        # 1 = slope turns negative, -1 = slope turns positive, 2 = zero slope.
        dpi_difference = [
            nxt - cur for cur, nxt in zip(dpi_arr, dpi_arr[1:])
        ]
        cross_zero_positions = [
            0,
        ] * len(dpi_difference)
        for diff_ind in xrange(1, len(cross_zero_positions)):
            if dpi_difference[diff_ind] == 0:
                cross_zero_positions[diff_ind] = 2
            elif dpi_difference[diff_ind - 1] * dpi_difference[diff_ind] < 0:
                if dpi_difference[diff_ind] > 0:
                    cross_zero_positions[diff_ind] = -1
                else:
                    cross_zero_positions[diff_ind] = 1

        # Find max swing: the largest drop from a 1-mark to the following
        # -1-mark that lies at least min_distance_to_current_QRS away.
        min_distance_to_current_QRS = fs * 0.3
        max_swing_value = None
        max_swing_pair = [0, 0]
        prev_peak_position = None
        for cross_ind, val in enumerate(cross_zero_positions):
            if val == 1:
                prev_peak_position = cross_ind
            elif val == -1:
                if prev_peak_position is not None:
                    cur_amplitude_difference = dpi_arr[
                        prev_peak_position] - dpi_arr[cross_ind]
                    if cross_ind >= min_distance_to_current_QRS:
                        if (max_swing_value is None or
                                max_swing_value < cur_amplitude_difference):
                            max_swing_value = cur_amplitude_difference
                            max_swing_pair = [prev_peak_position, cross_ind]
                prev_peak_position = None
        if max_swing_value is None:
            break

        # center_pos is an offset relative to ind, as the search bounds below
        # show by adding ind.
        center_pos = sum(max_swing_pair) / 2.0
        search_radius = fs * 285.0 / 1000

        search_left = int(max(0, center_pos - search_radius + ind))
        search_right = int(min(len_sig - 1, center_pos + search_radius + ind))

        # Seed the search with the absolute, integer centre sample. Indexing
        # fsig with the relative float offset read the wrong sample and is
        # rejected by current NumPy.
        center_index = int(min(len_sig - 1, center_pos + ind))
        max_qrs_amplitude = fsig[center_index]
        qrs_position = center_index

        for sig_ind in xrange(search_left, search_right + 1):
            sig_val = fsig[sig_ind]
            if sig_val > max_qrs_amplitude:
                max_qrs_amplitude = sig_val
                qrs_position = sig_ind

        # Stop when the detection does not move forward.
        if qrs_position >= len_sig or qrs_position <= ind:
            break
        qrs_arr.append(qrs_position)
        ind = qrs_position

    # Plot the last DPI curve, the filtered signal and the detections.
    plt.plot(xrange(ind, ind + len(dpi_arr)), dpi_arr, label='DPI')
    plt.plot(fsig, label='fsig')
    amp_list = [fsig[x] for x in qrs_arr]
    plt.plot(qrs_arr, amp_list, 'r^', markersize=12)
    plt.title('DPI')
    plt.legend()
    plt.show()