def TrainingModels(target_label, model_file_name, training_list,
                   random_pattern_path=None):
    '''Train a RandomWalker model for one label and save it to disk.

    Every record in `training_list` is loaded through QTloader and its
    'sig' channel, together with the full expert annotation list, is handed
    to `walker.collect_training_data`.  After all records are collected the
    walker is trained and saved.

    Args:
        target_label: Label passed to RandomWalker as `target_label`.
        model_file_name: Destination path given to `walker.save_model`.
        training_list: Names of the QT records used for training.
        random_pattern_path: Optional path passed to RandomWalker as
            `random_pattern_file_name`.  When omitted, the previous
            behaviour is kept: `random_pattern.json` next to
            `model_file_name`.  The keyword exists because randomTraining
            calls this function with `random_pattern_path=...`.
    '''
    if random_pattern_path is None:
        random_pattern_path = os.path.join(
            os.path.dirname(model_file_name), 'random_pattern.json')

    qt = QTloader()
    walker = RandomWalker(target_label=target_label,
                          random_forest_config=dict(max_depth=10),
                          random_pattern_file_name=random_pattern_path)

    # Feature collection and training are timed together.
    start_time = time.time()
    for record_name in training_list:
        print('Collecting features from record %s.' % record_name)
        sig = qt.load(record_name)
        walker.collect_training_data(sig['sig'], qt.getExpert(record_name))
    print('random forest start training(%s)...' % target_label)
    walker.training()
    # Message text (including its spelling) is kept verbatim.
    print('trianing used %.3f seconds' % (time.time() - start_time))

    start_time = time.time()
    walker.save_model(model_file_name)
    print('Serializing model time cost %f' % (time.time() - start_time))
def Test1():
    '''Compare DPI QRS detections with the expert R labels of each QTdb record.

    For every record the positions of the expert labels tagged 'R' are
    collected; records without any are skipped.  The 'sig' channel is run
    through `DPI().QRS_Detection` and `GetFN(expert positions, detections)`
    is called.  When GetFN returns a non-empty result, the signal is plotted
    with the detections as red circles and the GetFN positions as yellow
    squares, and the figure is shown.
    '''
    qt = QTloader()
    reclist = qt.getreclist()
    for rec_ind, rec_name in enumerate(reclist):
        print('Processing record[%d] %s ...' % (rec_ind, rec_name))
        raw_sig = qt.load(rec_name)['sig']
        expert_r_positions = [
            label[0] for label in qt.getExpert(rec_name) if label[1] == 'R'
        ]
        # Nothing to compare against when the expert marked no R peak.
        if not expert_r_positions:
            continue

        detected_qrs = DPI().QRS_Detection(raw_sig)

        # Presumably the expert R peaks the detector missed (false
        # negatives) -- GetFN is defined elsewhere, confirm there.
        missed_positions = GetFN(expert_r_positions, detected_qrs)
        if len(missed_positions) > 0:
            plt.plot(raw_sig)
            plt.plot(detected_qrs,
                     [raw_sig[pos] for pos in detected_qrs],
                     'ro', markersize=12)
            plt.plot(missed_positions,
                     [raw_sig[pos] for pos in missed_positions],
                     'ys', markersize=14)
            plt.show()
def randomTraining(root_folder='models/round1', num_training=75,
                   random_pattern_path='models/random_pattern.json'):
    '''Train one model per label on a partly random split, then test the rest.

    The training set always contains the fixed `must_train_list`; it is
    topped up with randomly sampled other records until it holds
    `num_training` records (or all records, if fewer are available).  The
    chosen list is written to `<root_folder>/training_list.json`, one model
    per label is trained into `<root_folder>/<label>.mdl`, and
    `record_test.RoundTesting` is run on the remaining records with results
    going to `<root_folder>/results`.

    Args:
        root_folder: Output folder for the training list, models and
            results.  Created (including parents) when missing.
        num_training: Desired total number of training records.
        random_pattern_path: Path forwarded to TrainingModels as
            `random_pattern_path` and to RoundTesting as `pattern_filename`.

    NOTE(review): requires a TrainingModels that accepts the
    `random_pattern_path` keyword.
    '''
    label_list = [
        'P', 'Ponset', 'Poffset', 'T', 'Toffset', 'Ronset', 'Roffset'
    ]
    must_train_list = [
        "sel35", "sel36", "sel31", "sel38", "sel39", "sel820", "sel51",
        "sele0104", "sele0107", "sel223", "sele0607", "sel102", "sele0409",
        "sel41", "sel40", "sel43", "sel42", "sel45", "sel48", "sele0133",
        "sele0116", "sel14172", "sele0111", "sel213", "sel14157", "sel301"
    ]

    # makedirs rather than mkdir: the default 'models/round1' is nested and
    # mkdir raises OSError when the parent folder does not exist yet.
    if not os.path.exists(root_folder):
        os.makedirs(root_folder)

    # Refresh training list
    qt = QTloader()
    candidates = list(set(qt.getreclist()) - set(must_train_list))
    training_list = list(must_train_list)
    # Clamp to the pool size; random.sample raises ValueError otherwise.
    num_random = min(num_training - len(must_train_list), len(candidates))
    if num_random > 0:
        training_list.extend(random.sample(candidates, num_random))

    # Save training list
    with open(os.path.join(root_folder, 'training_list.json'), 'w') as fout:
        json.dump(training_list, fout, indent=4)

    for target_label in label_list:
        model_file_name = os.path.join(root_folder, '%s.mdl' % target_label)
        TrainingModels(target_label, model_file_name, training_list,
                       random_pattern_path=random_pattern_path)

    # Testing
    testing_list = list(set(qt.getreclist()) - set(training_list))
    saveresultpath = os.path.join(root_folder, 'results')
    if not os.path.exists(saveresultpath):
        os.makedirs(saveresultpath)
    from record_test import RoundTesting
    RoundTesting(saveresultpath, testing_list,
                 model_folder=root_folder,
                 pattern_filename=random_pattern_path)
def Test_hog1d():
    '''Train the HOG-feature detector for 'T' and plot its output on one record.

    The extractor is trained on the first 100 QT records, its `gbdt`
    attribute is pickled to './hog.mdl', and record 103 is then tested using
    the DPI QRS detections (labelled 'R') as input.  Detected positions
    beyond the signal length at the tail, or negative at the head, are
    dropped before plotting both signal channels with the detections.
    '''
    target_label = 'T'
    rec_ind = 103

    qt = QTloader()
    reclist = qt.getreclist()
    record_name = reclist[rec_ind]
    raw_sig = qt.load(record_name)['sig']
    len_sig = len(raw_sig)

    # Hog feature tester
    hog = HogFeatureExtractor(target_label=target_label)
    hog.qt.getreclist()  # call kept from the original; its result is unused
    hog.Train(reclist[0:100])

    # Save the trained model
    with open('./hog.mdl', 'wb') as fout:
        pickle.dump(hog.gbdt, fout)

    qrs_list = DPI().QRS_Detection(raw_sig)

    start_time = time.time()
    r_labels = zip(qrs_list, ['R'] * len(qrs_list))
    detected_poslist = hog.Testing(raw_sig, r_labels)
    print('Testing time: %d s.' % (time.time() - start_time))

    # Trim out-of-range detections from the tail, then from the head.
    while len(detected_poslist) > 0 and detected_poslist[-1] > len_sig:
        del detected_poslist[-1]
    while len(detected_poslist) > 0 and detected_poslist[0] < 0:
        del detected_poslist[0]

    # The record is loaded again for plotting, as in the original.
    sig = qt.load(record_name)
    raw_sig = sig['sig']
    plt.plot(raw_sig, label='raw signal D1')
    plt.plot(sig['sig2'], label='raw signal D2')
    detected_amplitudes = [raw_sig[int(pos)] for pos in detected_poslist]
    plt.plot(detected_poslist, detected_amplitudes, 'ro',
             markersize=12, label=target_label)
    plt.title('Record name %s' % record_name)
    plt.legend()
    plt.show()
def viewQT():
    '''Show samples 2000-7000 of each QT record, starting at record index 19.

    Each record's 'sig' channel slice is passed to `viewCWTsignal` with the
    value 250 as second argument and the record name as figure title.
    '''
    first_index = 19
    qt = QTloader()
    remaining = qt.getreclist()[first_index:]
    for index, record_name in enumerate(remaining, first_index):
        # Same output as the Python 2 statement `print 'record index:', index`.
        print('%s %s' % ('record index:', index))
        segment = qt.load(record_name)['sig'][2000:7000]
        viewCWTsignal(segment, 250, figure_title=record_name)
def TrainingModels(target_label, model_file_name, training_list,
                   random_pattern_path=None,
                   manual_marks_folder='/home/alex/code/Python/'
                                       'EcgCharacterPointMarks'):
    '''Train a RandomWalker on expert plus manual characteristic-point marks.

    CP: Characteristic points

    For each training record the expert marks whose label equals
    `target_label` are taken; records without any are skipped.  If a manual
    mark file `<manual_marks_folder>/<target_label>/<record>_poslist.json`
    exists, its 'poslist' positions are appended as extra marks -- and a
    record whose manual 'poslist' is empty is skipped entirely.

    Args:
        target_label: Label to train for; also used to filter expert marks
            and to locate/tag the manual marks.
        model_file_name: Destination path given to `walker.save_model`.
        training_list: Names of the QT records used for training.
        random_pattern_path: Optional path passed to RandomWalker as
            `random_pattern_file_name`.  Defaults to `random_pattern.json`
            next to `model_file_name` (the previous hard-wired behaviour).
            The keyword exists because randomTraining calls this function
            with `random_pattern_path=...`.
        manual_marks_folder: Root folder of the manual mark files.  The
            default is the path that used to be hard-coded.
    '''
    if random_pattern_path is None:
        random_pattern_path = os.path.join(
            os.path.dirname(model_file_name), 'random_pattern.json')

    qt = QTloader()
    walker = RandomWalker(target_label=target_label,
                          random_forest_config=dict(max_depth=10),
                          random_pattern_file_name=random_pattern_path)

    start_time = time.time()
    for record_name in training_list:
        # Add expert marks
        CP_marks = [mark for mark in qt.getExpert(record_name)
                    if mark[1] == target_label]
        if len(CP_marks) == 0:
            continue

        # Add manual labels if possible
        CP_file_name = os.path.join(manual_marks_folder, target_label,
                                    '%s_poslist.json' % record_name)
        if os.path.exists(CP_file_name):
            with open(CP_file_name, 'r') as fin:
                poslist = json.load(fin)['poslist']
            # NOTE(review): an empty manual list drops the whole record,
            # expert marks included -- kept as in the original.
            if len(poslist) == 0:
                continue
            CP_marks.extend((pos, target_label) for pos in poslist)

        print('Collecting features from record %s.' % record_name)
        sig = qt.load(record_name)
        walker.collect_training_data(sig['sig'], CP_marks)
    print('random forest start training(%s)...' % target_label)
    walker.training()
    # Message text (including its spelling) is kept verbatim.
    print('trianing used %.3f seconds' % (time.time() - start_time))

    start_time = time.time()
    walker.save_model(model_file_name)
    print('Serializing model time cost %f' % (time.time() - start_time))
def ContinueAddQtTrainingSamples(walker, target_label):
    '''Feed the `target_label` expert marks of every QT record to `walker`.

    Records whose expert annotation has no mark with that label are
    skipped.  For the others, the 'sig' channel and the filtered marks are
    passed to `walker.collect_training_data`.
    '''
    qt = QTloader()
    all_records = qt.getreclist()
    start_time = time.time()  # kept from the original; never read afterwards
    for record_name in all_records:
        # Add expert marks
        label_marks = [
            mark for mark in qt.getExpert(record_name)
            if mark[1] == target_label
        ]
        if not label_marks:
            continue
        print('Collecting features from QT record %s.' % record_name)
        signal = qt.load(record_name)
        walker.collect_training_data(signal['sig'], label_marks)
def Test_regression():
    '''Train a RegressionLearner for 'T' on all QT records and plot one test.

    The learner's `mdl` attribute is pickled to './tmp.mdl'.  Record 65 is
    then tested using the DPI QRS detections (labelled 'R') as input;
    trailing detections beyond the signal length are dropped and the rest
    are plotted on top of the 'sig' channel.
    '''
    target_label = 'T'
    rec_ind = 65

    qt = QTloader()
    reclist = qt.getreclist()
    record_name = reclist[rec_ind]
    raw_sig = qt.load(record_name)['sig']
    len_sig = len(raw_sig)

    rf = RegressionLearner(target_label=target_label)
    rf.TrainQtRecords(reclist[0:])

    # Save the trained model
    with open('./tmp.mdl', 'wb') as fout:
        pickle.dump(rf.mdl, fout)

    qrs_list = DPI().QRS_Detection(raw_sig)

    start_time = time.time()
    r_labels = zip(qrs_list, ['R'] * len(qrs_list))
    detected_poslist = rf.testing(raw_sig, r_labels)
    print('Testing time: %d s.' % (time.time() - start_time))

    # Drop detections that fall past the end of the signal.
    while len(detected_poslist) > 0 and detected_poslist[-1] > len_sig:
        del detected_poslist[-1]

    # The record is loaded again for plotting, as in the original.
    raw_sig = qt.load(record_name)['sig']
    plt.plot(raw_sig, label='raw signal')
    detected_amplitudes = [raw_sig[int(pos)] for pos in detected_poslist]
    plt.plot(detected_poslist, detected_amplitudes, 'ro',
             markersize=12, label=target_label)
    plt.legend()
    plt.show()
def TrainingModels(target_label, model_file_name, training_list,
                   random_pattern_path=None,
                   tonset_marks_folder='/home/alex/code/Python/Tonset/results'):
    '''Train a RandomWalker on 'Tonset' marks and save the model.

    For each training record the marks come from
    `<tonset_marks_folder>/<record>_poslist.json` when that file exists
    (its 'poslist' positions, tagged 'Tonset'), otherwise from the expert
    marks labelled 'Tonset'.  Records that end up with no marks are
    skipped.

    Args:
        target_label: Label passed to RandomWalker and used in the log
            message.  Note the marks themselves are always 'Tonset'
            regardless of this value.
        model_file_name: Destination path given to `walker.save_model`.
        training_list: Names of the QT records used for training.
        random_pattern_path: Optional path passed to RandomWalker as
            `random_pattern_file_name`.  Defaults to `random_pattern.json`
            next to `model_file_name` (the previous hard-wired behaviour).
            The keyword exists because randomTraining calls this function
            with `random_pattern_path=...`.
        tonset_marks_folder: Folder holding the manual Tonset position
            files.  The default is the path that used to be hard-coded.
    '''
    if random_pattern_path is None:
        random_pattern_path = os.path.join(
            os.path.dirname(model_file_name), 'random_pattern.json')

    qt = QTloader()
    walker = RandomWalker(target_label=target_label,
                          random_forest_config=dict(max_depth=10),
                          random_pattern_file_name=random_pattern_path)

    start_time = time.time()
    for record_name in training_list:
        Tonset_file_name = os.path.join(tonset_marks_folder,
                                        '%s_poslist.json' % record_name)
        if os.path.exists(Tonset_file_name):
            # Manual positions take precedence over the expert marks.
            with open(Tonset_file_name, 'r') as fin:
                poslist = json.load(fin)['poslist']
            Tonset_marks = [(pos, 'Tonset') for pos in poslist]
        else:
            Tonset_marks = [mark for mark in qt.getExpert(record_name)
                            if mark[1] == 'Tonset']
        if len(Tonset_marks) == 0:
            continue

        print('Collecting features from record %s.' % record_name)
        sig = qt.load(record_name)
        walker.collect_training_data(sig['sig'], Tonset_marks)
    print('random forest start training(%s)...' % target_label)
    walker.training()
    # Message text (including its spelling) is kept verbatim.
    print('trianing used %.3f seconds' % (time.time() - start_time))

    start_time = time.time()
    walker.save_model(model_file_name)
    print('Serializing model time cost %f' % (time.time() - start_time))
def SplitQTrecords(num_training = 75):
    '''Split the QT records into a training list and a testing list.

    The training list always starts with the fixed `must_train_list` and is
    topped up with randomly sampled other records until it holds
    `num_training` records.  If fewer records are available than requested,
    all of them are used instead of letting `random.sample` raise
    ValueError.

    Args:
        num_training: Desired total number of training records.

    Returns:
        Tuple `(training_list, testing_list)`; the testing list holds every
        record returned by `QTloader().getreclist()` that is not used for
        training.
    '''
    must_train_list = [
        "sel35", "sel36", "sel31", "sel38", "sel39", "sel820", "sel51",
        "sele0104", "sele0107", "sel223", "sele0607", "sel102", "sele0409",
        "sel41", "sel40", "sel43", "sel42", "sel45", "sel48", "sele0133",
        "sele0116", "sel14172", "sele0111", "sel213", "sel14157", "sel301"
    ]
    qt = QTloader()
    candidates = list(set(qt.getreclist()) - set(must_train_list))

    # Refresh training list
    training_list = list(must_train_list)
    num_random = min(num_training - len(must_train_list), len(candidates))
    if num_random > 0:
        training_list.extend(random.sample(candidates, num_random))

    testing_list = list(set(candidates) - set(training_list))
    return (training_list, testing_list)
def Test1(target_label='P', num_training=25):
    '''Test case 1: random walk.

    Trains a RandomWalker on `num_training` randomly chosen QT records and
    then, for every remaining record, runs 20 test walks from increasing
    seed positions, plotting each walk and dropping into pdb after each one.

    Args:
        target_label: Label passed to RandomWalker and used as plot title.
        num_training: Number of records sampled for training.

    This is an interactive routine: it opens matplotlib figures and calls
    `pdb.set_trace()`.
    '''
    qt = QTloader()
    record_list = qt.getreclist()
    training_list = random.sample(record_list, num_training)
    testing_list = list(set(record_list) - set(training_list))
    random_forest_config = dict(max_depth=10)
    walker = RandomWalker(target_label=target_label,
                          random_forest_config=random_forest_config)
    # Feature collection and training are timed together.
    start_time = time.time()
    for record_name in training_list:
        print 'Collecting features from record %s.' % record_name
        sig = qt.load(record_name)
        walker.collect_training_data(sig['sig'], qt.getExpert(record_name))
    print 'random forest start training...'
    walker.training()
    print 'trianing used %.3f seconds' % (time.time() - start_time)
    for record_name in testing_list:
        sig = qt.load(record_name)
        raw_sig = sig['sig']
        # Random start; each walk below moves it further right.
        seed_position = random.randint(100, len(raw_sig) - 100)
        plt.figure(1)
        plt.clf()
        plt.plot(sig['sig'], label=record_name)
        plt.title(target_label)
        for ti in xrange(0, 20):
            # NOTE(review): nothing here keeps seed_position below
            # len(raw_sig) -- confirm testing_walk tolerates that.
            seed_position += random.randint(1, 200)
            print 'testing...(position: %d)' % seed_position
            start_time = time.time()
            results = walker.testing_walk(sig['sig'],
                                          seed_position,
                                          iterations=100,
                                          stepsize=10)
            print 'testing finished in %.3f seconds.' % (time.time() -
                                                         start_time)
            # results is unpacked as (position, value) pairs.
            pos_list, values = zip(*results)
            # Prediction = mean of the second half of the walk positions
            # (`/` is integer division under Python 2).
            predict_pos = np.mean(pos_list[len(pos_list) / 2:])
            # amp_list = [raw_sig[int(x)] for x in pos_list]
            # Draw the walk path as a slowly descending line starting at the
            # amplitude of the first position, instead of the real amplitudes.
            amp_list = []
            bias = raw_sig[pos_list[0]]
            for pos in pos_list:
                amp_list.append(bias)
                bias -= 0.01
            plt.plot(predict_pos,
                     raw_sig[int(predict_pos)],
                     'ro',
                     markersize=14,
                     label='predict position')
            plt.plot(pos_list,
                     amp_list,
                     'r',
                     label='walk path',
                     markersize=3,
                     linewidth=8,
                     alpha=0.3)
            plt.xlim(min(pos_list) - 100, max(pos_list) + 100)
            plt.grid(True)
            plt.legend()
            plt.show(block=False)
            # Pause so the figure can be inspected before the next walk.
            pdb.set_trace()
walker.training() print 'trianing used %.3f seconds' % (time.time() - start_time) import joblib start_time = time.time() walker.save_model(model_file_name) print 'Serializing model time cost %f' % (time.time() - start_time) if __name__ == '__main__': root_folder = 'data/Lw3Np4000/improved' # Refresh training list num_training = 105 trianing_list = list() qt = QTloader() record_list = qt.getreclist() must_train_list = [ "sel35", "sel36", "sel31", "sel38", "sel39", "sel820", "sel51", "sele0104", "sele0107", "sel223", "sele0607", "sel102", "sele0409", "sel41", "sel40", "sel43", "sel42", "sel45", "sel48", "sele0133", "sele0116", "sel14172", "sele0111", "sel213", "sel14157", "sel301" ] num_training -= len(must_train_list) record_list = list(set(record_list) - set(must_train_list)) training_list = must_train_list if num_training > 0: training_list.extend(random.sample(record_list, num_training)) # Save training list with open(os.path.join(root_folder, 'training_list.json'), 'w') as fout: json.dump(training_list, fout, indent=4) target_label = 'P'
plt.legend() plt.figure(3) plt.plot(raw_sig, label='raw signal') amp_list = [raw_sig[x] for x in qrs_arr] plt.plot(qrs_arr, amp_list, 'r^', markersize=12) plt.title('Raw signal') plt.legend() return qrs_arr if __name__ == '__main__': from QTdata.loadQTdata import QTloader qt = QTloader() recname = qt.getreclist()[67] print 'record name:', recname sig = qt.load(recname) raw_sig = sig['sig'][0:] debug_info = dict() debug_info['time_cost'] = True debug_info['plot_results'] = True # debug_info['decision_plot'] = 25262 detector = DPI_QRS_Detector(debug_info=debug_info) qrs_arr = detector.QRS_Detection(np.array(raw_sig)) # Plot R-R histogram import matplotlib.pyplot as plt plt.figure(2)
def RoundTest(target_label, result_folder, num_training = 75):
    '''Train on a partly random record split and write per-record test results.

    The training set always contains the fixed `must_train_list` and is
    topped up with randomly sampled other records until it holds
    `num_training` records.  If fewer records are available than requested,
    all of them are used instead of letting `random.sample` raise
    ValueError.  A RandomWalker is trained on the 'sig' channel of the
    training records; every remaining record is then passed through
    `testing(walker, signal)` for both 'sig' and 'sig2', and the two results
    are written to `<result_folder>/<record_name>.json`.

    Args:
        target_label: Label passed to RandomWalker as `target_label`.
        result_folder: Existing folder receiving one JSON file per tested
            record.
        num_training: Desired total number of training records.
    '''
    must_train_list = [
        "sel35", "sel36", "sel31", "sel38", "sel39", "sel820", "sel51",
        "sele0104", "sele0107", "sel223", "sele0607", "sel102", "sele0409",
        "sel41", "sel40", "sel43", "sel42", "sel45", "sel48", "sele0133",
        "sele0116", "sel14172", "sele0111", "sel213", "sel14157", "sel301"
    ]
    qt = QTloader()
    candidates = list(set(qt.getreclist()) - set(must_train_list))

    training_list = list(must_train_list)
    num_random = min(num_training - len(must_train_list), len(candidates))
    if num_random > 0:
        training_list.extend(random.sample(candidates, num_random))
    testing_list = list(set(candidates) - set(training_list))

    walker = RandomWalker(target_label = target_label,
                          random_forest_config = dict(max_depth = 10))

    # Feature collection and training are timed together.
    start_time = time.time()
    for record_name in training_list:
        print('Collecting features from record %s.' % record_name)
        sig = qt.load(record_name)
        walker.collect_training_data(sig['sig'], qt.getExpert(record_name))
    print('random forest start training...')
    walker.training()
    # Message text (including its spelling) is kept verbatim.
    print('trianing used %.3f seconds' % (time.time() - start_time))

    for record_name in testing_list:
        print('testing record %s...' % record_name)
        sig = qt.load(record_name)
        record_result = [
            (record_name, testing(walker, sig['sig'])),
            (record_name + '_sig2', testing(walker, sig['sig2'])),
        ]
        # Write to json
        result_path = os.path.join(result_folder, '%s.json' % record_name)
        with open(result_path, 'w') as fout:
            json.dump(record_result, fout, indent = 4)
def QRS_Detection(fs=250.0):
    '''Detect QRS positions in record 'sel40' with a DPI-style search and plot.

    Loads the first 5500 samples of the 'sig' channel of QT record 'sel40',
    filters them with HPF, then repeatedly: builds a DPI curve starting at
    the current index, finds the largest peak-to-valley swing in it, and
    searches the filtered signal around that swing for the maximum, which
    becomes the next QRS position and the next current index.  The final
    DPI curve, the filtered signal and the detected positions are plotted.

    Args:
        fs: Value used to scale all window lengths (presumably the sampling
            frequency in Hz -- confirm).

    Returns:
        None; results are only plotted.
    '''
    qt = QTloader()
    print qt.getreclist()
    sig = qt.load('sel40')
    raw_sig = sig['sig'][0:5500]
    fsig = HPF(raw_sig)
    # DPI
    m1 = -2
    len_sig = fsig.size
    qrs_arr = list()
    ind = 10
    while ind < len_sig:
        # DPI curve for the current index: |fsig[ind]| divided by a scaled
        # sum of |fsig| over a window that grows with m2.
        dpi_arr = list()
        N_m2 = int(fs * 1.71)
        for m2 in xrange(0, N_m2):
            lower_index = ind + m1 + 1
            upper_index = ind + m1 + m2
            # Clamp the window to the signal.
            lower_index = max(0, lower_index)
            upper_index = min(len_sig - 1, upper_index)
            if upper_index >= lower_index:
                s_avg = float(np.sum(np.abs(fsig[lower_index:upper_index +
                                                 1])))
                # s_avg /= m2 ** 0.5
                s_avg /= math.pow(m2, 0.5)
            else:
                # Empty window (e.g. m2 == 0): use a neutral divisor.
                s_avg = 1.0
            # Prevent 0 division error
            if s_avg < 1e-6:
                s_avg = 1.0
            dpi_val = np.abs(fsig[ind]) / s_avg
            dpi_arr.append(dpi_val)
        # Find cross zeros
        # Mark sign changes of the first difference of the DPI curve:
        #   1 = difference turns non-positive (peak), -1 = turns positive
        #   (valley), 2 = difference exactly zero, 0 = no change.
        dpi_difference = [x[1] - x[0] for x in zip(dpi_arr, dpi_arr[1:])]
        cross_zero_positions = [
            0,
        ] * len(dpi_difference)
        for diff_ind in xrange(1, len(cross_zero_positions)):
            if dpi_difference[diff_ind] == 0:
                cross_zero_positions[diff_ind] = 2
            elif dpi_difference[diff_ind - 1] * dpi_difference[diff_ind] < 0:
                if dpi_difference[diff_ind] > 0:
                    cross_zero_positions[diff_ind] = -1
                else:
                    cross_zero_positions[diff_ind] = 1
        # Find max swing
        # Largest drop from a peak to the following valley, considering only
        # valleys at least fs * 0.3 samples into the DPI curve.
        min_distance_to_current_QRS = fs * 0.3
        max_swing_value = None
        max_swing_pair = [0, 0]
        prev_peak_position = None
        for cross_ind, val in enumerate(cross_zero_positions):
            if val == 1:
                prev_peak_position = cross_ind
            elif val == -1:
                if prev_peak_position is not None:
                    cur_amplitude_difference = dpi_arr[
                        prev_peak_position] - dpi_arr[cross_ind]
                    if cross_ind >= min_distance_to_current_QRS:
                        if max_swing_value is None or max_swing_value < cur_amplitude_difference:
                            max_swing_value = cur_amplitude_difference
                            max_swing_pair = [prev_peak_position, cross_ind]
                # A valley consumes the pending peak.
                prev_peak_position = None
        # No qualifying swing: stop detecting.
        if max_swing_value is None:
            break
        # Swing centre, as an offset into dpi_arr (i.e. relative to ind).
        center_pos = sum(max_swing_pair) / 2.0
        search_radius = fs * 285.0 / 1000
        search_left = int(max(0, center_pos - search_radius + ind))
        search_right = int(min(len_sig - 1, center_pos + search_radius + ind))
        # NOTE(review): center_pos is a float offset relative to ind, yet it
        # is used here as an absolute index into fsig (no `+ ind`, no int()).
        # This looks unintended -- confirm before relying on it.
        max_qrs_amplitude = fsig[center_pos]
        qrs_position = center_pos
        # Take the largest filtered sample inside the search window.
        for sig_ind in xrange(search_left, search_right + 1):
            sig_val = fsig[sig_ind]
            if sig_val > max_qrs_amplitude:
                max_qrs_amplitude = sig_val
                qrs_position = sig_ind
        # debug
        # plt.plot(xrange(ind, ind + len(dpi_arr)), dpi_arr, label = 'DPI')
        # plt.plot(fsig, label = 'fsig')
        # amp_list = [fsig[x] for x in qrs_arr]
        # plt.plot(qrs_arr, amp_list, 'g^', markersize = 12, label = 'QRS detected')
        # plt.plot(np.array(max_swing_pair) + ind, [dpi_arr[x] for x in max_swing_pair], 'md', markersize = 12,
        # label = 'max_swing_pair')
        # plt.plot(ind, fsig[ind], 'mx', markersize = 12, label = 'current index position')
        # plt.title('DPI')
        # plt.legend()
        # plt.show()
        # Stop when the search did not move forward (or left the signal);
        # this also guarantees the while loop terminates.
        if qrs_position >= len_sig or qrs_position <= ind:
            break
        qrs_arr.append(qrs_position)
        ind = qrs_position
    # Plot the last DPI curve, the filtered signal and the detections.
    plt.plot(xrange(ind, ind + len(dpi_arr)), dpi_arr, label='DPI')
    plt.plot(fsig, label='fsig')
    amp_list = [fsig[x] for x in qrs_arr]
    plt.plot(qrs_arr, amp_list, 'r^', markersize=12)
    plt.title('DPI')
    plt.legend()
    plt.show()