def res_to_mat_fromResult(recID, reslist, mat_filename):
    '''Save a QT record's raw signal with predicted and expert labels to a .mat file.

    Parameters:
        recID: QT database record name, passed to QTloader.load and
            QTloader.getexpertlabeltuple.
        reslist: iterable of (pos, label) predictions.
        mat_filename: output path handed to scipy.io.savemat.

    The saved dict maps every predicted label to its list of positions,
    every expert label to 'expert_<label>' -> positions, and 'sig' to the
    raw signal.
    '''
    loader = QTloader()
    raw_signal = loader.load(recID)['sig']
    expert_labels = loader.getexpertlabeltuple(recID)

    mat_content = {}
    # Predicted labels: label -> [pos, ...]
    for position, label_name in reslist:
        mat_content.setdefault(label_name, []).append(position)
    # Expert labels get an 'expert_' prefix so they never collide with predictions.
    for position, label_name in expert_labels:
        mat_content.setdefault('expert_' + label_name, []).append(position)

    mat_content['sig'] = raw_signal
    scipy.io.savemat(mat_filename, mat_content)
    print('mat file [{}] saved.'.format(mat_filename))
def plot_QTdb_filtered_Result_with_syntax_filter(RFfolder, TargetRecordList, ResultFilterType, showExpertLabels=False):
    '''Plot (optionally filtered) prediction results for QT database records.

    Parameters:
        RFfolder: folder whose result files are listed by getresultfilelist.
        TargetRecordList: collection of result file names to plot; None
            plots every result file.
        ResultFilterType: sequence of filter flags. 'G' at index 0 replaces
            the raw result with ResultFilter.group_local_result(cp_del_thres=1);
            'S' at index 1 additionally applies ResultFilter.syntax_filter.
        showExpertLabels: when true, also mark the expert label positions.
    '''
    # Keep the file list under its own name: the loop below loads each
    # record's own `reslist`, which previously shadowed this list.
    result_file_list = getresultfilelist(RFfolder)
    qtdb = QTloader()
    non_result_extensions = ['out', 'json', 'log', 'txt']
    for fname in result_file_list:
        # Skip auxiliary files (*.out, *.json, *.log, *.txt).
        file_extension = fname.split('.')[-1]
        if file_extension in non_result_extensions:
            continue
        print('file name: {}'.format(fname))
        currecname = os.path.split(fname)[-1]
        print(currecname)
        if TargetRecordList is not None and currecname not in TargetRecordList:
            continue

        # NOTE(security): pickle.load can execute arbitrary code; only use
        # this on result files produced locally.
        with open(fname, 'r') as fin:
            (recID, reslist) = pickle.load(fin)

        # Filter the QT result according to ResultFilterType.
        resfilter = ResultFilter(reslist)
        if len(ResultFilterType) >= 1 and ResultFilterType[0] == 'G':
            reslist = resfilter.group_local_result(cp_del_thres=1)
        reslist_syntax = reslist
        if len(ResultFilterType) >= 2 and ResultFilterType[1] == 'S':
            reslist_syntax = resfilter.syntax_filter(reslist)

        sigstruct = qtdb.load(recID)
        AdditionalPlot = None
        if showExpertLabels:
            ExpertPoslist = [expert[0] for expert in qtdb.getexpertlabeltuple(recID)]
            AdditionalPlot = [['kd', 'Expert Labels', ExpertPoslist], ]

        resploter_syntax = ECGResultPloter(sigstruct['sig'], reslist_syntax)
        resploter_syntax.plot(plotTitle='QT Record {}'.format(recID),
                              plotShow=True,
                              AdditionalPlot=AdditionalPlot)
class SWT_GroupResult2Leads:
    ''' Find P&T peaks (and P/T boundaries) with SWT+db6.

    The raw prediction ``reslist`` is an iterable of ``(pos, label)``. It is
    grouped into runs of equal labels (group_result). Inside every group, one
    position is anchored to a characteristic point (zero crossing / local
    extremum) of an SWT detail level. Results accumulate in
    ``self.peak_dict[label]``.
    '''

    def __init__(self, recname, reslist, leadname, MaxSWTLevel=9):
        self.recres = reslist
        self.recname = recname
        self.QTdb = QTloader()
        self.sig_struct = self.QTdb.load(self.recname)
        self.rawsig = self.sig_struct[leadname]
        # [(label, [pos, ...]), ...] after group_result() has run.
        self.res_groups = None
        self.peak_dict = dict(T=[], P=[], Ponset=[], Poffset=[], Tonset=[], Toffset=[])
        # SWT coefficients. NOTE(review): NonQRS_SWT only receives recname,
        # so these may not be computed from `leadname` -- confirm.
        # MaxSWTLevel is only relevant to getSWTcoeflist, which is not called here.
        swt = NonQRS_SWT()
        swt.swt(self.recname)
        self.cDlist = swt.cDlist
        self.cAlist = swt.cAlist

    def group_result(self, white_del_thres=20, cp_del_thres=0):
        '''Group consecutive equal labels of the raw result into position groups.

        Parameters:
            white_del_thres: 'white' groups with at most this many positions
                are deleted.
            cp_del_thres: groups of any other label with at most this many
                positions are deleted.

        After deletion, neighbouring groups that now carry the same label
        are merged. Returns (and stores in self.res_groups) a list
        [(label, [pos, ...]), ...].
        '''
        # 1. Split the raw result into runs of identical labels.
        label_runs = []
        prev_label = None
        posGroup = []
        for pos, label in self.recres:
            if prev_label is not None and label != prev_label:
                label_runs.append((prev_label, posGroup))
                posGroup = []
            prev_label = label
            posGroup.append(pos)
        if len(posGroup) > 0:
            label_runs.append((prev_label, posGroup))

        # 2. Delete small groups (separate thresholds for 'white' and others).
        # 3. Merge adjacent groups that now share a label.
        filtered_local_res = []
        for label, posGroup in label_runs:
            del_thres = white_del_thres if label == 'white' else cp_del_thres
            if len(posGroup) <= del_thres:
                continue
            if filtered_local_res and filtered_local_res[-1][0] == label:
                filtered_local_res[-1][1].extend(posGroup)
            else:
                filtered_local_res.append((label, posGroup))

        self.res_groups = filtered_local_res
        return filtered_local_res

    def filter_smaller_nearby_groups(self, res_groups, group_near_dist_thres=100):
        '''Among groups closer than group_near_dist_thres, keep only the larger one.

        res_groups: [(label, [pos, ...]), ...] in ascending position order.
        Returns a new list; an empty input yields [].
        '''
        if not res_groups:
            return []
        kept_groups = [res_groups[0]]
        for group in res_groups[1:]:
            max_before = np.max(kept_groups[-1][1])
            min_after = np.min(group[1])
            if min_after - max_before <= group_near_dist_thres:
                # Too close: the group with more positions wins (ties keep the earlier one).
                if len(group[1]) > len(kept_groups[-1][1]):
                    kept_groups[-1] = group
            else:
                kept_groups.append(group)
        return kept_groups

    def crop_data_for_swt(self, rawsig):
        '''Pad rawsig with its last value up to the next power of two (SWT needs 2**k samples).

        The input is not modified. Raises Exception when len(rawsig) <= 1.
        '''
        N_data = len(rawsig)
        if N_data <= 1:
            raise Exception('len(rawsig)={},not enough for swt!'.format(N_data))
        crop_len = 1
        while crop_len < N_data:
            crop_len *= 2
        if N_data < crop_len:
            # Build a new list instead of `+=` so the caller's list is not mutated.
            rawsig = list(rawsig) + [rawsig[-1], ] * (crop_len - N_data)
        return rawsig[0:crop_len]

    def getSWTcoeflist(self, MaxLevel=9):
        '''Compute db6 SWT coefficients of the padded self.rawsig into self.cAlist/self.cDlist.'''
        self.rawsig = self.crop_data_for_swt(self.rawsig)
        print('-' * 3)
        print('len(rawsig)=  {}'.format(len(self.rawsig)))
        print('SWT maximum level = {}'.format(pywt.swt_max_level(len(self.rawsig))))
        coeflist = pywt.swt(self.rawsig, 'db6', MaxLevel)
        cAlist, cDlist = zip(*coeflist)
        self.cAlist = cAlist
        self.cDlist = cDlist

    def get_downward_cross_zero_list(self, array):
        '''Return indices i where array falls through zero: array[i-1] > 0 >= array[i].'''
        if array is None or len(array) < 2:
            return []
        values = np.asarray(array)
        mask = (values[:-1] > 0) & (values[1:] <= 0)
        return (np.flatnonzero(mask) + 1).tolist()

    def get_cross_zero_list(self, array):
        '''Return indices i where array rises through zero: array[i-1] < 0 <= array[i].'''
        if array is None or len(array) < 2:
            return []
        values = np.asarray(array)
        mask = (values[:-1] < 0) & (values[1:] >= 0)
        return (np.flatnonzero(mask) + 1).tolist()

    def get_local_minima_list(self, array):
        '''Return interior indices i with array[i] < array[i-1] and array[i+1] >= array[i].'''
        if array is None or len(array) < 2:
            return []
        steps = np.diff(np.asarray(array))
        mask = (steps[:-1] < 0) & (steps[1:] >= 0)
        return (np.flatnonzero(mask) + 1).tolist()

    def get_local_maxima_list(self, array):
        '''Return interior indices i with array[i] > array[i-1] and array[i+1] <= array[i].'''
        if array is None or len(array) < 2:
            return []
        steps = np.diff(np.asarray(array))
        mask = (steps[:-1] > 0) & (steps[1:] <= 0)
        return (np.flatnonzero(mask) + 1).tolist()

    def _nearest_element(self, pos, sorted_list):
        '''Return the element of non-empty ascending sorted_list closest to pos (ties -> smaller).'''
        insertPos = bisect.bisect_left(sorted_list, pos)
        if insertPos >= len(sorted_list):
            return sorted_list[-1]
        if insertPos == 0:
            return sorted_list[0]
        right = sorted_list[insertPos]
        left = sorted_list[insertPos - 1]
        return right if abs(pos - right) < abs(pos - left) else left

    def bw_find_nearest(self, pos, array):
        '''Return the distance from pos to the nearest element of ascending array.

        An empty array yields float('inf'), i.e. "no characteristic point nearby".
        '''
        if len(array) == 0:
            return float('inf')
        return abs(pos - self._nearest_element(pos, array))

    def _slope_bounds(self, peak_pos, e4list, upward):
        '''Return (left_bound, right_bound) of the monotonic run of e4list through peak_pos.

        upward=True follows a strictly rising run, upward=False a strictly
        falling one. A side whose run never ends defaults to the first or
        last index of e4list.
        '''
        def run_ends(step):
            # The run stops at the first step that does not continue it.
            return step <= 0 if upward else step >= 0

        n_e4list = len(e4list)
        right_bound = n_e4list - 1
        for cur_pos in range(peak_pos, n_e4list - 1):
            if run_ends(e4list[cur_pos + 1] - e4list[cur_pos]):
                right_bound = cur_pos
                break
        left_bound = 0
        for cur_pos in range(peak_pos, 0, -1):
            if run_ends(e4list[cur_pos] - e4list[cur_pos - 1]):
                left_bound = cur_pos
                break
        return left_bound, right_bound

    def _slope_steepness(self, e4list, left_bound, right_bound):
        '''Return |e4list[right]-e4list[left]| / (right-left), or 0.0 for an empty run.'''
        slope_len = right_bound - left_bound
        if slope_len <= 0:
            return 0.0
        return float(abs(e4list[right_bound] - e4list[left_bound])) / slope_len

    def get_longest_downward_slope_index(self, candidate_list, crosszerolist, e4list):
        '''Return the snapped characteristic point with the longest falling run of e4list.

        Each candidate is snapped to its nearest element of crosszerolist.
        The returned value is that snapped POSITION (an index into e4list,
        not into crosszerolist), or -1 when candidate_list is empty.
        '''
        max_slope_len = -1
        best_candidate = -1
        for prd_pos in candidate_list:
            peak_pos = self._nearest_element(prd_pos, crosszerolist)
            left_bound, right_bound = self._slope_bounds(peak_pos, e4list, upward=False)
            cur_slope_len = right_bound - left_bound
            if cur_slope_len > max_slope_len:
                max_slope_len = cur_slope_len
                best_candidate = peak_pos
        return best_candidate

    def get_current_upward_slope_steepness(self, pos, crosszerolist, e4list):
        '''Return the steepness of the rising run of e4list through the crossing nearest pos.

        crosszerolist must be non-empty. A degenerate run prints a warning
        and yields 0.0.
        '''
        peak_pos = self._nearest_element(pos, crosszerolist)
        left_bound, right_bound = self._slope_bounds(peak_pos, e4list, upward=True)
        if right_bound <= left_bound:
            print('Warning: cur_slope_len <= 0!')
        return self._slope_steepness(e4list, left_bound, right_bound)

    def get_steepest_slope_index(self, candidate_list, crosszerolist, e4list):
        '''Return the candidate (unsnapped) whose nearest crossing has the steepest rising run.

        Returns None when candidate_list is empty.
        '''
        max_slope_steepness = None
        best_candidate = None
        for prd_pos in candidate_list:
            peak_pos = self._nearest_element(prd_pos, crosszerolist)
            left_bound, right_bound = self._slope_bounds(peak_pos, e4list, upward=True)
            cur_steepness = self._slope_steepness(e4list, left_bound, right_bound)
            if max_slope_steepness is None or cur_steepness > max_slope_steepness:
                max_slope_steepness = cur_steepness
                best_candidate = prd_pos
        return best_candidate

    def get_steepest_downward_slope_index(self, candidate_list, crosszerolist, e4list):
        '''Return the candidate (unsnapped) whose nearest crossing has the steepest falling run.

        Returns -1 when candidate_list is empty.
        '''
        max_slope_steepness = None
        best_candidate = -1
        for prd_pos in candidate_list:
            peak_pos = self._nearest_element(prd_pos, crosszerolist)
            left_bound, right_bound = self._slope_bounds(peak_pos, e4list, upward=False)
            cur_steepness = self._slope_steepness(e4list, left_bound, right_bound)
            if max_slope_steepness is None or cur_steepness > max_slope_steepness:
                max_slope_steepness = cur_steepness
                best_candidate = prd_pos
        return best_candidate

    def get_longest_slope_index(self, candidate_list, crosszerolist, e4list):
        '''Return the candidate (unsnapped) whose nearest crossing has the longest rising run.

        Returns None when candidate_list is empty.
        '''
        max_slope_len = None
        best_candidate = None
        for prd_pos in candidate_list:
            peak_pos = self._nearest_element(prd_pos, crosszerolist)
            left_bound, right_bound = self._slope_bounds(peak_pos, e4list, upward=True)
            cur_slope_len = right_bound - left_bound
            if max_slope_len is None or cur_slope_len > max_slope_len:
                max_slope_len = cur_slope_len
                best_candidate = prd_pos
        return best_candidate

    def get_min_score_poslist(self, scorelist):
        '''Return (min_score, [pos, ...]) over [(score, pos), ...]; None for empty input.'''
        if not scorelist:
            return None
        minScore = min(score for score, _ in scorelist)
        return (minScore, [pos for score, pos in scorelist if score == minScore])

    def _locate_group_peak(self, posgroup, crosszerolist, coeflist, tie_breaker,
                           search_positions=None, max_score=2):
        '''Pick one SWT-anchored position for a prediction group.

        Every position of search_positions (default: posgroup) is scored by
        its distance to the nearest element of crosszerolist. Returns
        (position, True) when the best score is <= max_score, breaking ties
        with tie_breaker(candidates, crosszerolist, coeflist). Otherwise
        returns (np.mean(posgroup), False).
        '''
        if search_positions is None:
            search_positions = posgroup
        scorelist = [(self.bw_find_nearest(pos, crosszerolist), pos)
                     for pos in search_positions]
        min_score_result = self.get_min_score_poslist(scorelist)
        if min_score_result is None or min_score_result[0] > max_score:
            # Not a characteristic point: fall back to the group centre.
            return np.mean(posgroup), False
        candidate_list = min_score_result[1]
        if len(candidate_list) == 1:
            return candidate_list[0], True
        return tie_breaker(candidate_list, crosszerolist, coeflist), True

    def _plot_debug_segment(self, raw_sig, posgroup, margin, markers, coef_lines,
                            expert_label_list=None, show_range=False):
        '''Debug helper: plot the ECG around posgroup with markers and SWT detail lines.

        markers: [(position, fmt, legend_label), ...] in signal coordinates.
        coef_lines: [(cDlist_index, fmt, legend_label), ...].
        expert_label_list: optional [(pos, label), ...] drawn as red diamonds.
        '''
        seg_start = max(0, min(posgroup) - margin)
        seg_end = min(len(raw_sig) - 1, max(posgroup) + margin)
        seg = raw_sig[seg_start:seg_end]

        def to_seg(positions):
            # Shift into segment coordinates and drop points outside the segment.
            shifted = [int(pos - seg_start) for pos in positions]
            return [x for x in shifted if 0 <= x < len(seg)]

        plt.ion()
        plt.figure(1)
        plt.clf()
        plt.plot(seg, label='ECG')
        seg_posgroup = to_seg(posgroup)
        plt.plot(seg_posgroup, [seg[x] for x in seg_posgroup], label='posgroup',
                 marker='o', markersize=4, markerfacecolor='g', alpha=0.7)
        for pos, fmt, legend_label in markers:
            for x in to_seg([pos]):
                plt.plot(x, seg[x], fmt, markersize=12, alpha=0.7, label=legend_label)
        if expert_label_list is not None:
            seg_expert = to_seg([expert[0] for expert in expert_label_list
                                 if seg_start <= expert[0] < seg_end])
            if seg_expert:
                plt.plot(seg_expert, [seg[x] for x in seg_expert], 'rd',
                         markersize=12, alpha=0.7, label='expert label')
        for level, fmt, legend_label in coef_lines:
            plt.plot(self.cDlist[level][seg_start:seg_end], fmt, label=legend_label)
        if show_range:
            plt.title('{} {}'.format(self.recname, [seg_start, seg_end]))
        else:
            plt.title(self.recname)
        plt.legend()
        plt.grid(True)
        plt.show()

    def get_T_peaklist(self, debug=False):
        '''Locate one T peak per (filtered) 'T' prediction group.

        Each group is anchored independently to the rising zero crossings of
        D6 (self.cDlist[-6]) and D5 (self.cDlist[-5]). When both levels give
        an anchor, the one with the steeper upward slope wins. Otherwise the
        available anchor (or the D6 fallback, the group mean) is used.
        Results are appended to, and returned as, self.peak_dict['T'].

        debug: print diagnostics, plot every group and stop in pdb.
        '''
        if self.res_groups is None:
            # Group the raw prediction results if not grouped already.
            self.group_result()
        res_groups = [group for group in self.res_groups if group[0] == 'T']
        res_groups = self.filter_smaller_nearby_groups(res_groups)

        D6list = np.array(self.cDlist[-6])
        D5list = np.array(self.cDlist[-5])
        D6crosszerolist = self.get_cross_zero_list(D6list)
        D5crosszerolist = self.get_cross_zero_list(D5list)

        raw_sig = self.QTdb.load(self.recname)['sig'] if debug else None

        for group_ind, (_, posgroup) in enumerate(res_groups):
            D6pos, D6_swt_mark = self._locate_group_peak(
                posgroup, D6crosszerolist, D6list, self.get_longest_slope_index)
            D5pos, D5_swt_mark = self._locate_group_peak(
                posgroup, D5crosszerolist, D5list, self.get_longest_slope_index)
            for swt_mark in (D6_swt_mark, D5_swt_mark):
                if not swt_mark:
                    print('Warning: using mean group position!')

            if D5_swt_mark and D6_swt_mark:
                D5slope = self.get_current_upward_slope_steepness(
                    D5pos, D5crosszerolist, D5list)
                D6slope = self.get_current_upward_slope_steepness(
                    D6pos, D6crosszerolist, D6list)
                final_decision_peak = D6pos if D5slope < D6slope else D5pos
            elif D5_swt_mark:
                final_decision_peak = D5pos
            else:
                # D6 anchor, or the D6 default (group mean) when neither matched.
                final_decision_peak = D6pos
            self.peak_dict['T'].append(final_decision_peak)

            if debug:
                print('debug_res_group_ind {}'.format(group_ind))
                print('D5pos: {}'.format(D5pos))
                print('D6pos: {}'.format(D6pos))
                print('final_decision_peak {}'.format(final_decision_peak))
                self._plot_debug_segment(
                    raw_sig, posgroup, 100,
                    [(D6pos, 'yo', 'D6 pos'), (D5pos, 'mo', 'D5 pos'),
                     (final_decision_peak, 'r*', 'final pos')],
                    [(-6, 'y', 'D6'), (-5, 'm', 'D5')])
                pdb.set_trace()

        # Return list of T peaks.
        return self.peak_dict['T']

    def get_P_peaklist(self, debug=False):
        '''Locate one P peak per (filtered) 'P' prediction group using D5 (self.cDlist[-5]).

        Positions up to 10 samples beyond each group are also scored against
        the rising zero crossings of D5. Ties go to the steepest rising slope.
        Results are appended to, and returned as, self.peak_dict['P'].

        debug: print diagnostics, plot every group (with expert labels) and stop in pdb.
        '''
        if self.res_groups is None:
            # Group the raw prediction results if not grouped already.
            self.group_result()
        res_groups = [group for group in self.res_groups if group[0] == 'P']
        res_groups = self.filter_smaller_nearby_groups(res_groups)

        e3list = np.array(self.cDlist[-5])
        raw_sig = self.QTdb.load(self.recname)['sig']
        N_rawsig = len(raw_sig)
        crosszerolist = self.get_cross_zero_list(e3list)

        expert_label_list = None
        if debug:
            expert_label_list = self.QTdb.getexpertlabeltuple(self.recname)
            print('record name: {}'.format(self.recname))
            print('length of expert label: {}'.format(len(expert_label_list)))

        extra_search_len = 10
        for group_ind, (_, posgroup) in enumerate(res_groups):
            search_left = max(0, min(posgroup) - extra_search_len)
            search_right = min(N_rawsig, max(posgroup) + extra_search_len)
            peak, _ = self._locate_group_peak(
                posgroup, crosszerolist, e3list, self.get_steepest_slope_index,
                search_positions=range(search_left, search_right))
            self.peak_dict['P'].append(peak)

            if debug:
                print('debug_ind: {}'.format(group_ind))
                print('Pos Group: {}'.format(posgroup))
                print('mean Group: {}'.format(np.mean(posgroup)))
                print('SWT peak pos: {}'.format(peak))
                self._plot_debug_segment(
                    raw_sig, posgroup, 200,
                    [(peak, 'yo', 'P peak')],
                    [(-4, 'y', 'D4'), (-5, 'g', 'D5')],
                    expert_label_list=expert_label_list, show_range=True)
                pdb.set_trace()

        # Return list of P peaks.
        return self.peak_dict['P']

    #
    # Detect boundaries using the get_boundary_list function.
    #
    def get_Ponset_list(self):
        '''Ponset: falling zero crossings of D4+D3.'''
        label = 'Ponset'
        e3list = np.array(self.cDlist[-4]) + np.array(self.cDlist[-3])
        crossZeroFunc = self.get_downward_cross_zero_list
        LongestSlopeIndexFunc = self.get_longest_downward_slope_index
        return self.get_boundary_list(label, e3list, crossZeroFunc, LongestSlopeIndexFunc)

    def get_Poffset_list(self):
        '''Poffset: local minima of D4+D3.'''
        label = 'Poffset'
        e3list = np.array(self.cDlist[-4]) + np.array(self.cDlist[-3])
        crossZeroFunc = self.get_local_minima_list
        LongestSlopeIndexFunc = self.get_longest_downward_slope_index
        return self.get_boundary_list(label, e3list, crossZeroFunc, LongestSlopeIndexFunc)

    def get_Toffset_list(self):
        '''Toffset: falling zero crossings of D4+D5.'''
        label = 'Toffset'
        e3list = np.array(self.cDlist[-4]) + np.array(self.cDlist[-5])
        crossZeroFunc = self.get_downward_cross_zero_list
        LongestSlopeIndexFunc = self.get_longest_downward_slope_index
        return self.get_boundary_list(label, e3list, crossZeroFunc, LongestSlopeIndexFunc)

    def get_Tonset_list(self, debug=False):
        '''Tonset: falling zero crossings of D4+D5.'''
        label = 'Tonset'
        e3list = np.array(self.cDlist[-4]) + np.array(self.cDlist[-5])
        crossZeroFunc = self.get_downward_cross_zero_list
        LongestSlopeIndexFunc = self.get_longest_downward_slope_index
        return self.get_boundary_list(label, e3list, crossZeroFunc, LongestSlopeIndexFunc,
                                      debug=debug)

    def get_boundary_list(self, label, e3list, crossZeroFunc, LongestSlopeIndexFunc, debug=False):
        '''Locate one boundary position per (filtered) `label` prediction group.

        Parameters:
            label: result label / key of self.peak_dict.
            e3list: SWT-derived signal in which the characteristic points live.
            crossZeroFunc: e3list -> ascending list of characteristic positions.
            LongestSlopeIndexFunc: (candidates, points, e3list) -> POSITION
                used when several group members are equally close.
            debug: print diagnostics and stop in pdb.

        Results are appended to, and returned as, self.peak_dict[label].
        '''
        if self.res_groups is None:
            # Group the raw prediction results if not grouped already.
            self.group_result()
        res_groups = [group for group in self.res_groups if group[0] == label]
        if debug:
            print('length of result groups: {}'.format(len(res_groups)))
            pdb.set_trace()
        res_groups = self.filter_smaller_nearby_groups(res_groups)

        local_maxima_list = crossZeroFunc(e3list)
        for _, posgroup in res_groups:
            if debug:
                print('Pos Group: {}'.format(posgroup))
                print('mean Group: {}'.format(np.mean(posgroup)))
            # The tie breaker already returns a position, so it is stored directly
            # (indexing local_maxima_list with it was an IndexError/garbage bug).
            peak, _ = self._locate_group_peak(
                posgroup, local_maxima_list, e3list, LongestSlopeIndexFunc)
            self.peak_dict[label].append(peak)
            if debug:
                print('SWT peak pos: {}'.format(self.peak_dict[label][-1]))
                pdb.set_trace()

        # Return list of label positions.
        return self.peak_dict[label]

    def get_peak_list(self):
        '''Group the raw result, then compute the T and P peak lists.'''
        self.group_result()
        # Get T&P peak lists.
        self.get_T_peaklist()
        self.get_P_peaklist()
class EvaluationMultiLeads:
    '''Evaluation of raw detection result by random forest.

    Compares the predicted positions of one label on two leads with the QT
    database expert labels. For every expert label, the smaller of the two
    leads' errors is kept.
    '''

    def __init__(self, result_converter=None):
        self.QTdb = QTloader()
        self.labellist = []
        self.expertlist = []
        self.recname = None
        self.prdMatchList = []
        self.expMatchList = []
        # Optional callable converting a loaded result dict into the format
        # documented in loadlabellist; None means "already in that format".
        self.result_converter_ = result_converter
        # Color scheme (tableau20), scaled to matplotlib's 0..1 RGB range.
        tableau20 = [(31, 119, 180), (174, 199, 232), (255, 127, 14), (255, 187, 120),
                     (44, 160, 44), (152, 223, 138), (214, 39, 40), (255, 152, 150),
                     (148, 103, 189), (197, 176, 213), (140, 86, 75), (196, 156, 148),
                     (227, 119, 194), (247, 182, 210), (127, 127, 127), (199, 199, 199),
                     (188, 189, 34), (219, 219, 141), (23, 190, 207), (158, 218, 229)]
        self.colors = [(r / 255.0, g / 255.0, b / 255.0) for r, g, b in tableau20]
        # Error list: [expPos - prdPos, ...] in ms, filled by get_errList.
        self.errList = None

    def clear(self):
        '''Reset the per-record state.'''
        self.errList = None
        self.labellist = []
        self.expertlist = []
        self.recname = None
        self.prdMatchList = []
        self.expMatchList = []

    def loadlabellist(self, filename, TargetLabel, supress_warning=False):
        '''Load the TargetLabel positions of both leads from a json file.

        After the optional result_converter, the result should have format:
        {
            'recname': ...,
            'LeadResult': [
                {'P': [], 'T': [], ...},   # lead 1
                {'P': [], 'T': [], ...},   # lead 2
            ],
        }
        A lead dict without TargetLabel contributes an empty list. Also loads
        the expert positions of TargetLabel into self.expertlist.
        '''
        with open(filename, 'r') as fin:
            Res = json.load(fin)
        # Convert the result format only when a converter was supplied
        # (the default None used to be called and raise TypeError).
        if self.result_converter_ is not None:
            Res = self.result_converter_(Res)

        self.recname = Res['recname']
        if not supress_warning:
            print('>>loading recname: {}'.format(self.recname))
        self.leadResults = Res['LeadResult']
        self.leadPosList = (self.leadResults[0].get(TargetLabel, []),
                            self.leadResults[1].get(TargetLabel, []))

        expertlist = self.QTdb.getexpertlabeltuple(self.recname)
        self.expertlist = [expert[0] for expert in expertlist if expert[1] == TargetLabel]

    def getMatchList(self, prdposlist, MaxMatchDist=50):
        '''Return the match result between prdposlist and self.expertlist.

        NOTE: sorts both prdposlist and self.expertlist IN PLACE. Returned
        indices refer to those sorted orders, and get_errList relies on this.

        Returns (expMatchList, prdMatchList):
            expMatchList[i]: index of the prediction matched to expert i, or -1.
            prdMatchList[j]: index of the expert matched to prediction j, or -1.
        A match requires a distance strictly below MaxMatchDist. With no
        predictions, every expert label is reported unmatched (-1).
        '''
        self.expertlist.sort()
        if len(prdposlist) == 0:
            # Keep expMatchList aligned with expertlist so get_errList still
            # counts every expert label as a false negative.
            return ([-1, ] * len(self.expertlist), [])
        prdposlist.sort()

        N_prdList = len(prdposlist)
        prdMatchList = [-1, ] * N_prdList
        expMatchList = [-1, ] * len(self.expertlist)
        for expInd, exppos in enumerate(self.expertlist):
            insertPos = bisect.bisect_left(prdposlist, exppos)
            if insertPos == 0:
                matchInd = 0
            elif insertPos >= N_prdList:
                matchInd = N_prdList - 1
            else:
                leftDist = abs(exppos - prdposlist[insertPos - 1])
                rightDist = abs(exppos - prdposlist[insertPos])
                # Ties go to the left neighbour.
                matchInd = insertPos if leftDist > rightDist else insertPos - 1
            if abs(exppos - prdposlist[matchInd]) < MaxMatchDist:
                expMatchList[expInd] = matchInd
                prdMatchList[matchInd] = expInd
        return (expMatchList, prdMatchList)

    def evaluate(self, TargetLabel):
        '''Match both leads against the expert labels and compute error statistics.

        TargetLabel is currently unused; the labels were selected in loadlabellist.
        '''
        self.prdMatchList = []
        self.expMatchList = []
        for lead_poslist in self.leadPosList:
            expMatchList, prdMatchList = self.getMatchList(lead_poslist)
            self.prdMatchList.append(prdMatchList)
            self.expMatchList.append(expMatchList)
        self.get_errList()

    def getFNlist(self):
        '''Return total number of False Negatives (a count).'''
        return self.FNcnt

    def getFPlist(self):
        '''Return the smaller of the two leads' False Positive counts.'''
        return min(self.FPcnt1, self.FPcnt2)

    def plot_evaluation_result(self):
        '''Plot the raw signal with expert labels and both leads' predictions.'''
        sigStruct = self.QTdb.load(self.recname)
        rawSig = sigStruct['sig']

        plt.figure(1)
        plt.subplot(211)
        plt.plot(rawSig)
        expPosList = self.expertlist
        plt.plot(expPosList, [rawSig[x] for x in expPosList], 'd',
                 color=self.colors[2], markersize=12, label='Expert Label')
        for lead_ind, color_ind in ((0, 3), (1, 4)):
            prdPosList = self.leadPosList[lead_ind]
            plt.plot(prdPosList, [rawSig[int(x)] for x in prdPosList], '*',
                     color=self.colors[color_ind], label='prd{}'.format(lead_ind + 1),
                     markersize=12)
        plt.grid(True)
        plt.legend()
        plt.title(self.recname)
        plt.show()

    def getContinousRangeList(self, recname):
        '''Load the [[start, end], ...] ranges of continuous expert marking for recname.'''
        FileFolder = os.path.join(projhomepath, 'QTdata', 'ContinousExpertMarkRangeList',
                                  '{}_continousRange.json'.format(recname))
        with open(FileFolder, 'r') as fin:
            range_list = json.load(fin)
        return range_list

    def get_errList(self):
        '''Choose the minimum error between the two leads; output is in ms.

        Errors are expPos - prdPos multiplied by 4.0 (presumably 4 ms per
        sample, i.e. 250 Hz -- confirm against the QT database). Also sets:
            self.FNcnt: expert labels matched on neither lead.
            self.FPcnt1 / self.FPcnt2: unmatched predictions per lead that
                lie inside a continuous expert-marked range [start, end).
        '''
        self.errList = []
        FNcnt = 0
        lead1, lead2 = self.leadPosList
        for expPos, matchInd1, matchInd2 in zip(
                self.expertlist, self.expMatchList[0], self.expMatchList[1]):
            if matchInd1 == -1 and matchInd2 == -1:
                FNcnt += 1
                continue
            if matchInd2 == -1:
                err = expPos - lead1[matchInd1]
            elif matchInd1 == -1:
                err = expPos - lead2[matchInd2]
            else:
                err1 = expPos - lead1[matchInd1]
                err2 = expPos - lead2[matchInd2]
                # Choose the one with the smaller error (ties -> lead 2).
                err = err1 if abs(err1) < abs(err2) else err2
            self.errList.append(4.0 * err)
        # Total number of False Negatives.
        self.FNcnt = FNcnt

        # Exclude FPs that are not inside a continuous expert-marked range.
        # Interval tests (rather than a set of ints) also count non-integer
        # positions such as the group-mean fallbacks.
        range_list = self.getContinousRangeList(self.recname)

        def in_continous_range(pos):
            return any(cur_range[0] <= pos < cur_range[1] for cur_range in range_list)

        def count_fp(poslist, prdMatchList):
            return sum(1 for prdpos, match_index in zip(sorted(poslist), prdMatchList)
                       if match_index == -1 and in_continous_range(prdpos))

        self.FPcnt1 = count_fp(lead1, self.prdMatchList[0])
        self.FPcnt2 = count_fp(lead2, self.prdMatchList[1])
        return self.errList

    def get_total_mean(self):
        '''Mean of self.errList (ms).'''
        return np.mean(self.errList)

    def get_total_stdvar(self):
        '''Standard deviation of self.errList (ms).'''
        return np.std(self.errList)
class PointBrowser(object):
    """Interactive browser for grouped SWT T-wave results on QT records.

    The upper axes (``ax``) shows lead I; the lower axes (``ax2``) shows
    lead II (``DrawLeadII``, used by ``reDraw``) or lead I with the raw,
    ungrouped results (``DrawRawResults``).

    Keys handled by ``onpress``:
        n      -- go to the next result file
        space  -- redraw the current record
        x      -- undo the last marked region
        a / d  -- scroll both axes 200 samples left / right
        p      -- accepted, but currently only refreshes the canvas

    Clicking on a pickable line (``onpick``) marks the start, then the
    end, of a region; regions are kept in ``whiteRegionList`` and their
    total sample count in ``totalWhiteCount``.
    """

    def __init__(self, fig, ax, ax2):
        def GetRecordIndexbyName(resultlist, recname):
            '''Return the index of ``<recname>.json`` in resultlist, or None.'''
            for ind, filepath in enumerate(resultlist):
                current_recname = os.path.split(filepath)[-1].split('.json')[0]
                if recname == current_recname:
                    return ind
            print('record name {} not found!'.format(recname))
            return None

        self.fig = fig
        self.ax = ax
        self.ax2 = ax2
        self.text = self.ax.text(0.05, 0.95, 'selected: none',
                                 transform=self.ax.transAxes, va='top')
        # ============================
        # QTdb
        self.QTdb = QTloader()
        # Round Index
        RoundIndex = 1
        target_record_name = 'sel30'
        self.resultlist = glob.glob(
            os.path.join(curfolderpath, 'M4_SWT',
                         'SWT_GroupRound_T_improved{}'.format(RoundIndex),
                         '*.json'))
        # Raw string: same path value as before, without relying on
        # unrecognised backslash escapes being kept literally.
        self.raw_resultlist = glob.glob(
            os.path.join(
                r'F:\LabGit\ECG_RSWT\TestResult\paper\MultiRound4\Round{}'.format(
                    RoundIndex),
                'result_*'))
        # Start at the target record; if it is missing fall back to the
        # first result file instead of storing None as the index.
        target_index = GetRecordIndexbyName(self.resultlist, target_record_name)
        self.recInd = 0 if target_index is None else target_index
        self._load_current_record()

        tableau20 = [(31, 119, 180), (174, 199, 232), (255, 127, 14),
                     (255, 187, 120), (44, 160, 44), (152, 223, 138),
                     (214, 39, 40), (255, 152, 150), (148, 103, 189),
                     (197, 176, 213), (140, 86, 75), (196, 156, 148),
                     (227, 119, 194), (247, 182, 210), (127, 127, 127),
                     (199, 199, 199), (188, 189, 34), (219, 219, 141),
                     (23, 190, 207), (158, 218, 229)]
        # Matplotlib expects RGB components in [0, 1].
        self.colors = [(r / 255.0, g / 255.0, b / 255.0)
                       for r, g, b in tableau20]
        # ===========================
        # Mark list
        self.whiteRegionList = []
        self.totalWhiteCount = 0
        self.getCoefList()
        self.GetExpertLabelRange()

    def _load_current_record(self):
        '''Load signals and expert labels for self.resultlist[self.recInd].'''
        self.recname = os.path.split(
            self.resultlist[self.recInd])[-1].split('.json')[0]
        self.sigStruct = self.QTdb.load(self.recname)
        self.rawSig = self.sigStruct['sig']
        self.rawSig2 = self.sigStruct['sig2']
        self.expLabels = self.QTdb.getexpertlabeltuple(self.recname)

    def GetExpertLabelRange(self):
        '''Return [first, last] expert label position widened by 1000 samples.

        The range is clipped to [0, len(self.rawSig) - 1] and also stored
        in self.expertlabel_range.
        '''
        poslist = sorted(pos for pos, _ in self.expLabels)
        ret = [max(0, poslist[0] - 1000),
               min(len(self.rawSig) - 1, poslist[-1] + 1000)]
        self.expertlabel_range = ret
        return ret

    def crop_data_for_swt(self, rawsig):
        '''Return rawsig extended to the next power of two for the SWT.

        The target length is the smallest power of two >= len(rawsig).
        Padding repeats the last sample. A new list is returned and the
        input is left unmodified.

        Raises:
            Exception: if rawsig has fewer than 2 samples.
        '''
        N_data = len(rawsig)
        if N_data <= 1:
            raise Exception(
                'len(rawsig)={},not enough for swt!'.format(N_data))
        crop_len = 1
        while crop_len < N_data:
            crop_len *= 2
        # Build a new list: extending in place would also modify the
        # caller's signal (e.g. self.sigStruct['sig']).
        return list(rawsig) + [rawsig[-1]] * (crop_len - N_data)

    def getCoefList(self, MaxLevel=9):
        '''Pad both leads and compute their 'db6' SWT coefficient lists.'''
        self.rawSig = self.crop_data_for_swt(self.rawSig)
        self.coeflist1 = pywt.swt(self.rawSig, 'db6', MaxLevel)
        self.rawSig2 = self.crop_data_for_swt(self.rawSig2)
        self.coeflist2 = pywt.swt(self.rawSig2, 'db6', MaxLevel)

    def PlotTpeakDetermineLines(self):
        '''T wave judge WT coef: overlay detail levels -5 and -6 of both leads.'''
        _, cDlist1 = zip(*self.coeflist1)
        _, cDlist2 = zip(*self.coeflist2)
        color_index = 15
        for detail_ind in range(-5, -7, -1):
            label = 'detail{}'.format(detail_ind)
            color = self.colors[color_index]
            self.ax.plot(cDlist1[detail_ind], color=color, label=label)
            self.ax.legend()
            self.ax2.plot(cDlist2[detail_ind], color=color, label=label)
            self.ax2.legend()
            # Use a different color for the next detail level.
            color_index += 1
        self.fig.canvas.draw()

    def plot_e4list(self):
        '''T wave judge WT coef: overlay detail level -7 of both leads.'''
        for axes, coeflist, label in ((self.ax, self.coeflist1, 'e4list'),
                                      (self.ax2, self.coeflist2, 'e4list II')):
            _, cDlist = zip(*coeflist)
            e4list = np.array(cDlist[-7])
            axes.plot(e4list, color=self.colors[15], label=label)
            axes.legend()
        self.fig.canvas.draw()

    def plot_e3list(self):
        '''P wave judge WT coef: overlay detail level -4 of both leads.

        Expert label positions are marked on the coefficient curves.
        '''
        for axes, coeflist, label in ((self.ax, self.coeflist1, 'e3list'),
                                      (self.ax2, self.coeflist2, 'e3list II')):
            _, cDlist = zip(*coeflist)
            e3list = np.array(cDlist[-4])
            axes.plot(e3list, color=self.colors[15], label=label)
            self.plotExpertLabels(axes, e3list)
            axes.legend()
        self.fig.canvas.draw()

    def onpress(self, event):
        '''Handle key presses; see the class docstring for the bindings.'''
        if event.key not in ('n', 'p', ' ', 'x', 'a', 'd'):
            return
        if event.key == 'n':
            self.saveWhiteMarkList2Json()
            self.next_record()
            self.clearWhiteMarkList()
            self.reDraw()
            return None
        elif event.key == ' ':
            self.reDraw()
            return None
        elif event.key == 'x':
            self._undo_last_region()
        elif event.key == 'a':
            self._scroll(-200)
        elif event.key == 'd':
            self._scroll(200)
        self.update()

    def _undo_last_region(self):
        '''Drop the last marked region and subtract its sample count.'''
        if not self.whiteRegionList:
            return
        start, end = self.whiteRegionList.pop()
        # An open region (end == -1) was never added to totalWhiteCount;
        # a closed one added end - start + 1 samples in addMarkx.
        if end != -1:
            self.totalWhiteCount -= end - start + 1

    def _scroll(self, step):
        '''Shift the x limits of both axes by step samples.'''
        xlims = self.ax.get_xlim()
        new_xlims = [xlims[0] + step, xlims[1] + step]
        self.ax.set_xlim(new_xlims)
        self.ax2.set_xlim(new_xlims)

    def saveWhiteMarkList2Json(self):
        '''Placeholder: marked regions are currently not saved.'''
        pass

    def clearWhiteMarkList(self):
        '''Forget all marked regions.'''
        self.whiteRegionList = []
        self.totalWhiteCount = 0

    def addMarkx(self, x):
        '''Use x as the start of a new region or the end of the open one.

        An open region is stored as [start, -1]. Closing it stores
        [min, max], adds its inclusive length to totalWhiteCount and
        highlights it on self.ax. Samples outside the signal are drawn at 0.
        '''
        if not self.whiteRegionList or self.whiteRegionList[-1][-1] != -1:
            self.whiteRegionList.append([int(x), -1])
            return
        start, end = self.whiteRegionList[-1][0], int(x)
        if end < start:
            start, end = end, start
        self.whiteRegionList[-1] = [start, end]
        self.totalWhiteCount += end - start + 1
        # draw markers
        xlist = list(range(start, end + 1))
        N_rawsig = len(self.rawSig)
        ylist = [self.rawSig[xval] if 0 <= xval < N_rawsig else 0
                 for xval in xlist]
        self.ax.plot(xlist, ylist, lw=6, color=self.colors[0], alpha=0.3,
                     label='whiteRegion')

    def onpick(self, event):
        '''Mark a region boundary at the clicked x position and redraw.'''
        self.addMarkx(event.mouseevent.xdata)
        self.fig.canvas.draw()

    def _reset_axes(self, lower_title):
        '''Clear both axes, recreate status text and grids, set the titles.'''
        self.ax.cla()
        self.ax2.cla()
        self.text = self.ax.text(0.05, 0.95, 'selected: none',
                                 transform=self.ax.transAxes, va='top')
        grid_style = dict(color=(0.8, 0.8, 0.8), linestyle='--', linewidth=2)
        self.ax.grid(**grid_style)
        self.ax2.grid(**grid_style)
        self.ax.set_title('QT {} (Index = {})'.format(self.recname,
                                                      self.recInd))
        self.ax2.set_title(lower_title)

    def DrawRawResults(self):
        ''' Draw Raw results on ax2 to compare the grouping result.'''
        self._reset_axes('Lead I Raw Results')
        # picker=5: 5 points tolerance for onpick.
        self.ax.plot(self.rawSig, picker=5)
        self.ax2.plot(self.rawSig, picker=5)
        self.plotExpertLabels(self.ax, self.rawSig)
        self.plotExpertLabels(self.ax2, self.rawSig)
        self.plotResultLabels(self.ax, 0, self.rawSig)
        self.plotRawResultLabels(self.ax2, 1, self.rawSig)
        self.plotErrorList(self.ax, 0, self.rawSig, TargetLabel='T')
        # Alternatives: self.plot_e3list() / self.plot_e4list()
        self.PlotTpeakDetermineLines()
        self.fig.canvas.draw()

    def reDraw(self):
        '''Redraw the current record (lead I / lead II view).'''
        self.DrawLeadII()

    def DrawLeadII(self):
        '''Draw lead I on ax and lead II on ax2 with labels and SWT details.'''
        self._reset_axes('Lead II')
        # picker=5: 5 points tolerance for onpick.
        self.ax.plot(self.rawSig, picker=5)
        self.ax2.plot(self.rawSig2, picker=5)
        self.plotExpertLabels(self.ax, self.rawSig)
        self.plotExpertLabels(self.ax2, self.rawSig2)
        self.plotResultLabels(self.ax, 0, self.rawSig)
        self.plotResultLabels(self.ax2, 1, self.rawSig2)
        self.plotErrorList(self.ax, 0, self.rawSig, TargetLabel='T')
        # Alternatives: self.plot_e3list() / self.plot_e4list()
        self.PlotTpeakDetermineLines()
        self.fig.canvas.draw()

    def plotErrorList(self, ax, leadnum, rawSig, TargetLabel='P'):
        '''Annotate every expert TargetLabel with its 2-lead matching error.

        Expert labels matched in neither lead are annotated with -1. The
        mean and std of Evaluation2Leads.errList are shown in self.text.
        leadnum is not used.

        Raises:
            Exception: if the error list and expert list lengths disagree.
        '''
        # Text Bias Function: vertical offset of the annotation text.
        textBiasFunc = lambda x: x - 0.9
        resultfilename = self.resultlist[self.recInd]
        label = TargetLabel
        eva = Evaluation2Leads()
        eva.loadlabellist(resultfilename, label)
        eva.evaluate(label)
        eva_errorList = eva.errList
        # eva.errList only holds matched labels; re-insert -1 placeholders
        # so errorList lines up with the expert positions.
        errorList = []
        errInd = 0
        for matchInd1, matchInd2 in zip(eva.expMatchList[0],
                                        eva.expMatchList[1]):
            if matchInd1 == -1 and matchInd2 == -1:
                errorList.append(-1)
            else:
                errorList.append(eva_errorList[errInd])
                errInd += 1
        if errInd != len(eva_errorList):
            raise Exception('error Ind != len(eva_errorList)')
        TargetExpPosList = [pos for pos, lab in self.expLabels if lab == label]
        if len(errorList) != len(TargetExpPosList):
            raise Exception('if len(errorList) != len(TargetExpPosList)')
        for errVal, expPos in zip(errorList, TargetExpPosList):
            ax.annotate('error[{}]'.format(errVal),
                        xy=(expPos, rawSig[expPos]),
                        xytext=(expPos, textBiasFunc(rawSig[expPos])),
                        xycoords='data', textcoords='data',
                        arrowprops=dict(arrowstyle='->',
                                        connectionstyle='arc3,rad = -0.5'))
        self.text.set_text('Error mean,std = ({0:.3f},{1:.3f})'.format(
            np.mean(eva_errorList), np.std(eva_errorList)))

    def update(self):
        '''Redraw the canvas.'''
        self.fig.canvas.draw()

    def next_record(self):
        '''Advance to the next result file and reload signals and SWT.

        Does nothing (besides a message) when already at the last file.
        '''
        if self.recInd + 1 >= len(self.resultlist):
            print('Already at the last record (Index = {}).'.format(
                self.recInd))
            return
        self.recInd += 1
        self._load_current_record()
        self.getCoefList()

    def _pred_label_style(self, label):
        '''Return (color, marker) used to draw a predicted label.'''
        color_by_wave = {'T': 7, 'P': 9, 'R': 12}
        color = self.colors[color_by_wave.get(label[0], 14)]
        if 'onset' in label:
            marker = '<'
        elif 'offset' in label:
            marker = '>'
        elif len(label) == 1:
            marker = 'o'
        else:
            marker = '.'
        return color, marker

    def _plot_pred_dict(self, ax, labelDict, rawSig):
        '''Plot {label: [pos, ...]} predictions on ax (positions cast to int).'''
        for label, posList in labelDict.items():
            color, marker = self._pred_label_style(label)
            ax.plot(posList, [rawSig[int(x)] for x in posList], marker=marker,
                    color=color, linestyle='none', markersize=6,
                    label='Pred[{}]'.format(label), alpha=0.8)
        ax.legend(numpoints=1)

    def plotRawResultLabels(self, ax, prdInd, rawSig):
        '''Plot raw results: entry [prdInd][1] of the raw result json file.

        That entry is read as a list of (pos, label) pairs.
        '''
        with open(self.raw_resultlist[self.recInd], 'r') as fin:
            resStruct = json.load(fin)
        labelDict = {}
        for pos, label in resStruct[prdInd][1]:
            labelDict.setdefault(label, []).append(pos)
        self._plot_pred_dict(ax, labelDict, rawSig)

    def plotResultLabels(self, ax, prdInd, rawSig, resultlist_input=None):
        '''Plot grouped results: resStruct['LeadResult'][prdInd] of the file.

        resultlist_input defaults to self.resultlist; the file at
        self.recInd is used.
        '''
        if resultlist_input is None:
            resultlist_input = self.resultlist
        with open(resultlist_input[self.recInd], 'r') as fin:
            resStruct = json.load(fin)
        self._plot_pred_dict(ax, resStruct['LeadResult'][prdInd], rawSig)

    def plotExpertLabels(self, ax, rawSig):
        '''Plot self.expLabels on ax, taking y values from rawSig.'''
        labelDict = {}
        for pos, label in self.expLabels:
            labelDict.setdefault(label, []).append(pos)
        color_by_wave = {'T': 4, 'P': 5, 'R': 6}
        for label, posList in labelDict.items():
            # Default color for other labels; otherwise `color` would be
            # unbound or left over from the previous label.
            color = self.colors[color_by_wave.get(label[0], 14)]
            if 'onset' in label:
                marker = '<'
            elif 'offset' in label:
                marker = '>'
            else:
                marker = 'o'
            ax.plot(posList, [rawSig[x] for x in posList], marker=marker,
                    color=color, linestyle='none', markersize=14, label=label)
        ax.legend(numpoints=1)
class PointBrowser(object):
    """Interactive browser for the final two-lead results of QT records.

    Lead I is drawn on ``ax`` and the second lead on ``ax2``, each with the
    predicted labels read from ``<result_folder_path>/<recname>.json``.

    Keys handled by ``onpress``:
        n      -- go to the next record in the QT record list
        space  -- redraw the current record
        x      -- undo the last marked region
        a / d  -- scroll both axes 200 samples left / right
        p      -- accepted, but currently only refreshes the canvas

    Clicking on a pickable line (``onpick``) marks the start, then the
    end, of a region; regions are kept in ``whiteRegionList``.
    """

    def __init__(self, fig, ax, ax2):
        self.fig = fig
        self.ax = ax
        self.ax2 = ax2
        self.SaveFolder = os.path.join(curfolderpath, 'QTwhiteMarkList')
        self.text = self.ax.text(0.05, 0.95, 'selected: none',
                                 transform=self.ax.transAxes, va='top')
        # ============================
        # QTdb
        self.QTdb = QTloader()
        self.reclist = self.QTdb.reclist
        self.result_folder_path = os.path.join(
            os.path.dirname(curfolderpath), 'MultiLead4', 'GroupRound2')
        # Start at sel116. Keep recInd pointing at it so 'n' continues
        # with the record that follows it in reclist.
        self.recname = 'sel116'
        reclist = list(self.reclist)
        self.recInd = (reclist.index(self.recname)
                       if self.recname in reclist else 0)
        self._load_current_record()

        tableau20 = [(31, 119, 180), (174, 199, 232), (255, 127, 14),
                     (255, 187, 120), (44, 160, 44), (152, 223, 138),
                     (214, 39, 40), (255, 152, 150), (148, 103, 189),
                     (197, 176, 213), (140, 86, 75), (196, 156, 148),
                     (227, 119, 194), (247, 182, 210), (127, 127, 127),
                     (199, 199, 199), (188, 189, 34), (219, 219, 141),
                     (23, 190, 207), (158, 218, 229)]
        # Matplotlib expects RGB components in [0, 1].
        self.colors = [(r / 255.0, g / 255.0, b / 255.0)
                       for r, g, b in tableau20]
        # Region marks must exist before the first click or 'x' key press.
        self.clearWhiteMarkList()

    def _load_current_record(self):
        '''Load signals, expert labels and result path for self.recname.'''
        self.result_file_path = os.path.join(
            self.result_folder_path, '{}.json'.format(self.recname))
        self.sigStruct = self.QTdb.load(self.recname)
        self.rawSig = self.sigStruct['sig']
        self.rawSig2 = self.sigStruct['sig2']
        self.expLabels = self.QTdb.getexpertlabeltuple(self.recname)

    def onpress(self, event):
        '''Handle key presses; see the class docstring for the bindings.'''
        if event.key not in ('n', 'p', ' ', 'x', 'a', 'd'):
            return
        if event.key == 'n':
            self.saveWhiteMarkList2Json()
            self.next_record()
            self.clearWhiteMarkList()
            self.reDraw()
            return None
        elif event.key == ' ':
            self.reDraw()
            return None
        elif event.key == 'x':
            self._undo_last_region()
        elif event.key == 'a':
            self._scroll(-200)
        elif event.key == 'd':
            self._scroll(200)
        self.update()

    def _undo_last_region(self):
        '''Drop the last marked region and subtract its sample count.'''
        if not self.whiteRegionList:
            return
        start, end = self.whiteRegionList.pop()
        # An open region (end == -1) was never counted; a closed one added
        # end - start + 1 samples in addMarkx.
        if end != -1:
            self.totalWhiteCount -= end - start + 1

    def _scroll(self, step):
        '''Shift the x limits of both axes by step samples.'''
        xlims = self.ax.get_xlim()
        new_xlims = [xlims[0] + step, xlims[1] + step]
        self.ax.set_xlim(new_xlims)
        self.ax2.set_xlim(new_xlims)

    def saveWhiteMarkList2Json(self):
        '''Placeholder: marked regions are currently not saved.'''
        pass

    def clearWhiteMarkList(self):
        '''Forget all marked regions.'''
        self.whiteRegionList = []
        self.totalWhiteCount = 0

    def addMarkx(self, x):
        '''Use x as the start of a new region or the end of the open one.

        An open region is stored as [start, -1]. Closing it stores
        [min, max], adds its inclusive length to totalWhiteCount and
        highlights it on self.ax. Samples outside the signal are drawn at 0.
        '''
        if not self.whiteRegionList or self.whiteRegionList[-1][-1] != -1:
            self.whiteRegionList.append([int(x), -1])
            return
        start, end = self.whiteRegionList[-1][0], int(x)
        if end < start:
            start, end = end, start
        self.whiteRegionList[-1] = [start, end]
        self.totalWhiteCount += end - start + 1
        # draw markers
        xlist = list(range(start, end + 1))
        N_rawsig = len(self.rawSig)
        ylist = [self.rawSig[xval] if 0 <= xval < N_rawsig else 0
                 for xval in xlist]
        self.ax.plot(xlist, ylist, lw=6, color=self.colors[0], alpha=0.3,
                     label='whiteRegion')

    def onpick(self, event):
        '''Mark a region boundary at the click, update the status text.'''
        self.addMarkx(event.mouseevent.xdata)
        self.text.set_text(
            'Marking region: ({}) [whiteCnt {}][expertCnt {}]'.format(
                self.whiteRegionList[-1], self.totalWhiteCount,
                len(self.expLabels)))
        self.fig.canvas.draw()

    def reDraw(self):
        '''Redraw both leads with their predicted labels.'''
        ax = self.ax
        ax.cla()
        self.ax2.cla()
        self.text = self.ax.text(0.05, 0.95, 'selected: none',
                                 transform=self.ax.transAxes, va='top')
        grid_style = dict(color=(0.8, 0.8, 0.8), linestyle='--', linewidth=2)
        ax.grid(**grid_style)
        self.ax2.grid(**grid_style)
        ax.set_title('QT {} (Index = {})'.format(self.recname, self.recInd))
        # picker=5: 5 points tolerance for onpick.
        ax.plot(self.rawSig, picker=5)
        self.ax2.set_title('QT record {}:Lead V1'.format(self.recname))
        self.ax2.plot(self.rawSig2, picker=5)
        # Expert labels are not drawn here (see plotExpertLabels).
        self.plotResultLabels(ax, 0, self.rawSig)
        self.plotResultLabels(self.ax2, 1, self.rawSig2)
        self.fig.canvas.draw()

    def update(self):
        '''Redraw the canvas.'''
        self.fig.canvas.draw()

    def next_record(self):
        ''' When press n, plot next record in QT database.

        Also switches result_file_path to the new record. Does nothing
        (besides a message) when already at the last record.
        '''
        if self.recInd + 1 >= len(self.reclist):
            print('Already at the last record (Index = {}).'.format(
                self.recInd))
            return
        self.recInd += 1
        self.recname = self.reclist[self.recInd]
        self._load_current_record()

    def plotResultLabels(self, ax, prdInd, rawSig):
        '''Plot Final Output labels of lead prdInd on ax.

        The json file has the form
        {'LeadResult': [{label: poslist}, {label: poslist}], 'recname': ...}.
        Labels are drawn in sorted order. A missing result file is reported
        and skipped.
        '''
        if not os.path.exists(self.result_file_path):
            print('Result file {} not found, skipped.'.format(
                self.result_file_path))
            return
        with open(self.result_file_path, 'r') as fin:
            resStruct = json.load(fin)
        labelDict = resStruct["LeadResult"][prdInd]
        color_by_wave = {'T': 1, 'P': 4, 'R': 9}
        for label in sorted(labelDict):
            posList = [int(x) for x in labelDict[label]]
            color = self.colors[color_by_wave.get(label[0], 14)]
            if 'onset' in label:
                marker = '<'
            elif 'offset' in label:
                marker = '>'
            elif len(label) == 1:
                marker = 'o'
            else:
                marker = '.'
            ax.plot(posList, [rawSig[x] for x in posList], marker=marker,
                    color=color, linestyle='none', markersize=14,
                    label='Pred[{}]'.format(label), alpha=0.8)
        ax.legend(numpoints=1)

    def plotExpertLabels(self, ax, rawSig):
        '''Plot self.expLabels on ax, taking y values from rawSig.'''
        labelDict = {}
        for pos, label in self.expLabels:
            labelDict.setdefault(label, []).append(pos)
        color_by_wave = {'T': 4, 'P': 5, 'R': 6}
        for label, posList in labelDict.items():
            # Default color for other labels; otherwise `color` would be
            # unbound or left over from the previous label.
            color = self.colors[color_by_wave.get(label[0], 14)]
            if 'onset' in label:
                marker = '<'
            elif 'offset' in label:
                marker = '>'
            else:
                marker = 'o'
            ax.plot(posList, [rawSig[x] for x in posList], marker=marker,
                    color=color, linestyle='none', markersize=14, label=label)
        ax.legend(numpoints=1)
def PlotrawPredictionLabels(picklefilename, target_recname=None, ResID=0,
                            showExpertLabel=True, xLimtoLabelRange=True):
    '''Plot the raw (unfiltered) prediction labels of one QT record.

    Args:
        picklefilename: pickle file whose content is indexed as
            Results[ResID] = (recname, [(pos, label), ...]).
        target_recname: if not None, plot only when Results[ResID]
            belongs to this record (otherwise return without plotting).
        ResID: index of the pickled entry to plot.
        showExpertLabel: draw a vertical line at each expert label position.
        xLimtoLabelRange: zoom the x axis to the predicted labels
            +/- 100 samples (skipped when there are no predictions).

    NOTE: pickle.load can execute arbitrary code -- only load trusted
    result files.
    '''
    # Pickles are binary data; text mode can corrupt them on Windows.
    with open(picklefilename, 'rb') as fin:
        Results = pickle.load(fin)
    recname = Results[ResID][0]
    # only plot target rec
    if target_recname is not None:
        print('Result[{}] QTrecord name:{}'.format(ResID, recname))
        if recname != target_recname:
            return
    recLoader = QTloader()
    sig = recLoader.load(recname)
    rawReslist = Results[ResID][1]
    rawsig = sig['sig']
    plt.figure(1)
    plt.plot(rawsig)
    # Group (pos, amplitude) pairs by the plot marker of their label.
    PlotMarkerdict = {x: [] for x in PlotMarkerList()}
    for res in rawReslist:
        pos = res[0]
        PlotMarkerdict.setdefault(Label2PlotMarker(res[1]), []).append(
            (pos, rawsig[pos]))
    label_positions = []
    for mker, posAmpList in PlotMarkerdict.items():
        if not posAmpList:
            continue
        poslist, Amplist = zip(*posAmpList)
        label_positions.extend(poslist)
        plt.plot(poslist, Amplist, mker, label='{} Label'.format(mker))
    # plot expert marks
    if showExpertLabel:
        for explabel in recLoader.getexpertlabeltuple(recname):
            xpos = explabel[0]
            plt.plot([xpos, xpos], [-10, 10], 'black')
    if xLimtoLabelRange and label_positions:
        plt.xlim(min(label_positions) - 100, max(label_positions) + 100)
    plt.xlabel('Samples')
    plt.ylabel('Amplitude')
    plt.title(recname)
    plt.show()
class RecSelector(object):
    '''Helpers to inspect QT records and pick training / invalid records.'''

    def __init__(self):
        self.qt = QTloader()

    def inspect_recs(self):
        '''Show each record outside the test set and ask whether to keep it.

        Records answered with 'y' are written as a JSON list to
        selected_training_records.json in the directory of curfilepath.
        '''
        try:
            ask = raw_input  # Python 2
        except NameError:
            ask = input  # Python 3
        reclist = self.qt.getQTrecnamelist()
        set_testing = set(conf['selQTall0_test_set'])
        # Sorted so the inspection order is reproducible between runs.
        out_reclist = sorted(set(reclist) - set_testing)
        selected_record_list = []
        for ind, recname in enumerate(out_reclist):
            print('{} records left.'.format(len(out_reclist) - ind - 1))
            rawsig = self.qt.load(recname)['sig']
            # expert labels
            testresult = self.qt.getexpertlabeltuple(recname)
            poslist = sorted(pos for pos, _ in testresult)
            dispRange = (poslist[0], poslist[-1])
            resplt = ECGResultPloter(rawsig, testresult)
            resplt.plot(plotTitle='QT record {}'.format(recname),
                        dispRange=dispRange)
            usethis = ask('Use this record as training record?(y/n):')
            if usethis == 'y':
                selected_record_list.append(recname)
        with open(os.path.join(os.path.dirname(curfilepath),
                               'selected_training_records.json'), 'w') as fout:
            json.dump(selected_record_list, fout)

    def pick_invalid_record(self):
        '''Return records whose first sample of 'sig' is +/- infinity.'''
        invalidrecordlist = []
        for recname in self.qt.getQTrecnamelist():
            print('Inspecting record {}'.format(recname))
            sig = self.qt.load(recname)
            if abs(sig['sig'][0]) == float('inf'):
                invalidrecordlist.append(recname)
        return invalidrecordlist

    def save_recs_to_img(self):
        '''Plot and save every QT record via QTloader.PlotAndSaveRec.'''
        reclist = self.qt.getQTrecnamelist()
        for ind, recname in enumerate(reclist):
            # Count against the list actually being iterated.
            print('{} records left.'.format(len(reclist) - ind - 1))
            self.qt.PlotAndSaveRec(recname)

    def inspect_recname(self, tarrecname):
        '''Plot a single record.'''
        self.qt.plotrec(tarrecname)

    def inspect_selrec(self):
        '''Plot every record listed in selected_training_records.json.'''
        with open(os.path.join(curfolderpath,
                               'selected_training_records.json'), 'r') as fin:
            sel0115 = json.load(fin)
        for ind, recname in enumerate(sel0115):
            print('{} records left.'.format(len(sel0115) - ind - 1))
            self.inspect_recname(recname)

    def RFtest(self, testrecname):
        '''Train ECGRF on conf['sel1213'], test one record, show statistics.'''
        ecgrf = ECGRF()
        sel1213 = conf['sel1213']
        ecgrf.training(sel1213)
        Results = ecgrf.testing([testrecname, ])
        # Evaluate result
        filtered_Res = ECGRF.resfilter(Results)
        stats = ECGstats(filtered_Res[0:1])
        Err, FN = stats.eval(debug=False)
        stats.dispstat0(pFN=FN, pErr=Err)
        # plot prediction result
        stats.plotevalresofrec(Results[0][0], Results)
class NonQRS_SWT(object):
    '''SWT of a QT lead-I signal with its QRS complexes linearly flattened.'''

    def __init__(self):
        self.QTdb = QTloader()

    def LoadMarkListFromJson(self, jsonfilepath):
        '''Placeholder: not implemented, returns None.'''
        pass

    def getNonQRSsig(self, recname, MarkList=None):
        '''Return a copy of recname's 'sig' with QRS regions flattened.

        Each Ronset immediately followed (by position) by an Roffset less
        than QRS_width_threshold samples later is replaced by a straight
        line from sig[Ronset] towards sig[Roffset].

        Args:
            recname: QT record name passed to QTdb.load.
            MarkList: (pos, label) pairs to use instead of the expert labels.
        '''
        # Ronset/Roffset pairs this many samples apart or more are ignored.
        QRS_width_threshold = 180
        sig_struct = self.QTdb.load(recname)
        # Work on a copy so the structure returned by QTdb.load is untouched.
        rawsig = list(sig_struct['sig'])
        if MarkList is None:
            MarkList = self.QTdb.getexpertlabeltuple(recname)
        # Keep QRS boundary labels only ('R' plus a suffix, e.g. Ronset).
        expert_marklist = sorted(
            (mark for mark in MarkList if 'R' in mark[1] and len(mark[1]) > 1),
            key=lambda mark: mark[0])
        for (pos, label), (next_pos, next_label) in zip(expert_marklist,
                                                        expert_marklist[1:]):
            if label != 'Ronset' or next_label != 'Roffset':
                continue
            width = next_pos - pos
            if width >= QRS_width_threshold:
                print('Warning: QRS region [{}, {}] is wider than {} samples,'
                      ' not flattened!'.format(pos, next_pos,
                                               QRS_width_threshold))
                continue
            print('Adding: {} {}'.format(pos, next_pos))
            amp_start = rawsig[pos]
            amp_end = rawsig[next_pos]
            for offset in range(width):
                rawsig[pos + offset] = (
                    amp_start + float(offset) * (amp_end - amp_start) / width)
        return rawsig

    def crop_data_for_swt(self, rawsig):
        '''Return rawsig extended to the next power of two for the SWT.

        The target length is the smallest power of two >= len(rawsig).
        Padding repeats the last sample. A new list is returned and the
        input is left unmodified.

        Raises:
            Exception: if rawsig has fewer than 2 samples.
        '''
        N_data = len(rawsig)
        if N_data <= 1:
            raise Exception(
                'len(rawsig)={},not enough for swt!'.format(N_data))
        crop_len = 1
        while crop_len < N_data:
            crop_len *= 2
        return list(rawsig) + [rawsig[-1]] * (crop_len - N_data)

    def swt(self, recname, wavelet='db6', MaxLevel=9):
        '''Compute the SWT of recname's QRS-flattened signal.

        The approximation and detail coefficient sequences returned by
        pywt.swt are stored in self.cAlist and self.cDlist.
        '''
        rawsig = self.crop_data_for_swt(self.getNonQRSsig(recname))
        coeflist = pywt.swt(rawsig, wavelet, MaxLevel)
        self.cAlist, self.cDlist = zip(*coeflist)
class Evaluation2Leads(object):
    '''Match two leads' predicted positions against QT expert labels.

    Typical use: loadlabellist(result_json, label), then evaluate(label),
    then read errList / getFNlist() / get_total_mean() / get_total_stdvar().
    '''

    def __init__(self):
        # NOTE(review): positions are not converted to int here; only the
        # plotting code casts them with int().
        print('[Warning]Evaluator will convert prediction indexes to Integer.')
        self.QTdb = QTloader()
        self.clear()
        # color schemes
        tableau20 = [(31, 119, 180), (174, 199, 232), (255, 127, 14),
                     (255, 187, 120), (44, 160, 44), (152, 223, 138),
                     (214, 39, 40), (255, 152, 150), (148, 103, 189),
                     (197, 176, 213), (140, 86, 75), (196, 156, 148),
                     (227, 119, 194), (247, 182, 210), (127, 127, 127),
                     (199, 199, 199), (188, 189, 34), (219, 219, 141),
                     (23, 190, 207), (158, 218, 229)]
        self.colors = [(r / 255.0, g / 255.0, b / 255.0)
                       for r, g, b in tableau20]

    def clear(self):
        '''Reset labels, match lists and the error list.'''
        # error List: [4.0 * (expPos - prdPos), ...]
        self.errList = None
        self.labellist = []
        self.expertlist = []
        self.recname = None
        self.prdMatchList = []
        self.expMatchList = []

    def loadlabellist(self, filename, TargetLabel):
        '''Load both leads' TargetLabel positions from a result json file.

        The file must contain 'recname' and 'LeadResult' (a list of two
        {label: poslist} dicts). Expert positions with TargetLabel are
        loaded from the QT database into self.expertlist.
        '''
        with open(filename, 'r') as fin:
            Res = json.load(fin)
        self.recname = Res['recname']
        print('>>loading recname: {}'.format(self.recname))
        self.leadResults = Res['LeadResult']
        self.leadPosList = (self.leadResults[0][TargetLabel],
                            self.leadResults[1][TargetLabel])
        expertlist = self.QTdb.getexpertlabeltuple(self.recname)
        self.expertlist = [pos for pos, label in expertlist
                           if label == TargetLabel]

    def getMatchList(self, prdposlist, MaxMatchDist=50):
        '''Match each expert position to its nearest predicted position.

        prdposlist and self.expertlist are sorted IN PLACE; the returned
        indexes refer to these sorted orders (get_errList relies on this).

        Returns:
            (expMatchList, prdMatchList): expMatchList[i] is the index in
            prdposlist matched to expert i, or -1 when there is no
            prediction closer than MaxMatchDist samples; prdMatchList[j]
            is the last expert index matched to prediction j, or -1.
        '''
        prdposlist.sort()
        self.expertlist.sort()
        prdMatchList = [-1] * len(prdposlist)
        expMatchList = [-1] * len(self.expertlist)
        N_prdList = len(prdposlist)
        if N_prdList == 0:
            # No predictions: every expert label stays unmatched.
            return (expMatchList, prdMatchList)
        for expInd, exppos in enumerate(self.expertlist):
            # Nearest candidates are just left and right of the insert point.
            insertPos = bisect.bisect_left(prdposlist, exppos)
            left_ind = insertPos - 1
            right_ind = insertPos
            if left_ind < 0:
                matchInd = right_ind
            elif right_ind >= N_prdList:
                matchInd = left_ind
            elif (abs(exppos - prdposlist[left_ind]) >
                  abs(exppos - prdposlist[right_ind])):
                matchInd = right_ind
            else:
                matchInd = left_ind
            if abs(exppos - prdposlist[matchInd]) < MaxMatchDist:
                expMatchList[expInd] = matchInd
                prdMatchList[matchInd] = expInd
        return (expMatchList, prdMatchList)

    def evaluate(self, TargetLabel):
        '''Match both leads against the expert list and fill errList.

        TargetLabel is not used; the label was chosen in loadlabellist.
        '''
        self.prdMatchList = []
        self.expMatchList = []
        for lead_index in (0, 1):
            expMatchList, prdMatchList = self.getMatchList(
                self.leadPosList[lead_index])
            self.prdMatchList.append(prdMatchList)
            self.expMatchList.append(expMatchList)
        # get error statistics
        self.get_errList()

    def getFNlist(self):
        '''Return total number of False Negtives.'''
        return self.FNcnt

    def getFPlist(self):
        '''False positives are not computed by this evaluator; returns -1.'''
        return -1

    def plot_evaluation_result(self):
        '''Plot lead I with expert labels and both leads' predictions.'''
        sigStruct = self.QTdb.load(self.recname)
        rawSig = sigStruct['sig']
        plt.figure(1)
        plt.subplot(211)
        plt.plot(rawSig)
        # plot expert labels
        expPosList = self.expertlist
        plt.plot(expPosList, [rawSig[x] for x in expPosList], 'd',
                 color=self.colors[2], markersize=12, label='Expert Label')
        # predictions of both leads (drawn on the lead I signal)
        for lead_index, color_index, label in ((0, 3, 'prd1'), (1, 4, 'prd2')):
            prdPosList = self.leadPosList[lead_index]
            plt.plot(prdPosList, [rawSig[int(x)] for x in prdPosList], '*',
                     color=self.colors[color_index], label=label,
                     markersize=12)
        plt.grid(True)
        plt.legend()
        plt.title(self.recname)
        plt.show()

    def get_errList(self):
        '''Compute per-expert-label errors, keeping the closer lead.

        Each matched expert label contributes 4.0 * (expPos - prdPos) using
        the lead with the smaller absolute error (ties go to lead II).
        Labels matched by neither lead are counted in self.FNcnt. The 4.0
        factor presumably converts samples to ms (250 Hz) -- TODO confirm.

        Returns:
            self.errList
        '''
        self.errList = []
        FNcnt = 0
        for expPos, matchInd1, matchInd2 in zip(self.expertlist,
                                                self.expMatchList[0],
                                                self.expMatchList[1]):
            if matchInd1 == -1 and matchInd2 == -1:
                FNcnt += 1
                continue
            elif matchInd1 == -1:
                # Only lead II matched; never index lead I with -1.
                err = expPos - self.leadPosList[1][matchInd2]
            elif matchInd2 == -1:
                err = expPos - self.leadPosList[0][matchInd1]
            else:
                err1 = expPos - self.leadPosList[0][matchInd1]
                err2 = expPos - self.leadPosList[1][matchInd2]
                # chooose the one with smaller error
                err = err1 if abs(err1) < abs(err2) else err2
            self.errList.append(4.0 * err)
        # total number of False Negtives
        self.FNcnt = FNcnt
        return self.errList

    def get_total_mean(self):
        '''Mean of errList.'''
        return np.mean(self.errList)

    def get_total_stdvar(self):
        '''Standard deviation of errList.'''
        return np.std(self.errList)
class PointBrowser(object):
    """Interactive tool for manually marking 'Tonset' positions on QTdb records.

    The signal is drawn with a pick tolerance; picking it (``onpick``) drops a
    Tonset marker at that sample.  Key bindings handled by ``onpress``:

    * ``n``      save the markers to JSON and move to the next record
    * space      redraw the current record
    * ``x``      drop the most recently added marker from the list (the drawn
                 marker disappears on the next redraw)
    * ``a``/``d`` scroll the view 200 samples left/right
    * ``p``      accepted, but currently only refreshes the canvas

    NOTE(review): ``onpress``/``onpick`` are presumably connected to the figure
    canvas by the caller; that wiring is not part of this class.
    """

    # Horizontal scroll distance (samples) for the 'a' / 'd' keys.
    _SCROLL_STEP = 200
    # Index into self.colors chosen by the first letter of an expert label.
    _LABEL_COLOR_INDEX = {'T': 4, 'P': 5, 'R': 6}
    # Fallback index (tableau grey) for any other expert label.
    _DEFAULT_LABEL_COLOR_INDEX = 14

    def __init__(self, fig, ax, start_index):
        self.fig = fig
        self.ax = ax
        # Results are written under <curfolderpath>/results (module global).
        self.SaveFolder = os.path.join(curfolderpath, 'results')
        self.text = self.ax.text(0.05, 0.95, 'selected: none',
                                 transform=self.ax.transAxes, va='top')
        # ============================
        # QTdb
        self.QTdb = QTloader()
        self.reclist = self.QTdb.reclist
        self.recInd = start_index
        self._load_current_record()

        tableau20 = [(31, 119, 180), (174, 199, 232), (255, 127, 14),
                     (255, 187, 120), (44, 160, 44), (152, 223, 138),
                     (214, 39, 40), (255, 152, 150), (148, 103, 189),
                     (197, 176, 213), (140, 86, 75), (196, 156, 148),
                     (227, 119, 194), (247, 182, 210), (127, 127, 127),
                     (199, 199, 199), (188, 189, 34), (219, 219, 141),
                     (23, 190, 207), (158, 218, 229)]
        # matplotlib expects RGB components in [0, 1].
        self.colors = [(r / 255.0, g / 255.0, b / 255.0)
                       for r, g, b in tableau20]
        # ===========================
        # Mark list
        self.poslist = []
        self.totalWhiteCount = 0

    def _load_current_record(self):
        """(Re)load the signal and expert labels of ``reclist[recInd]``."""
        self.recname = self.reclist[self.recInd]
        self.sigStruct = self.QTdb.load(self.recname)
        self.rawSig = self.sigStruct['sig']
        self.expLabels = self.QTdb.getexpertlabeltuple(self.recname)

    def _result_file_name(self):
        """Path of the JSON result file for the current record."""
        return os.path.join(self.SaveFolder,
                            '{}_poslist.json'.format(self.recname))

    def _plot_marker(self, pos):
        """Draw one Tonset marker ('x') on the signal at sample ``pos``."""
        self.ax.plot(pos, self.rawSig[pos], marker='x', color=self.colors[7],
                     markersize=22, markeredgewidth=4, alpha=0.9,
                     label='Tonset')

    def onpress(self, event):
        """Key-press handler; see the class docstring for the key map."""
        key = event.key
        if key not in ('n', 'p', ' ', 'x', 'a', 'd'):
            return
        if key == 'n':
            self.saveWhiteMarkList2Json()
            self.next_record()
            self.clearWhiteMarkList()  # clear Marker List
            self.reDraw()
            return None
        if key == ' ':
            self.reDraw()
            return None
        if key == 'x':
            # Delete the most recent marker in the stack.
            if self.poslist:
                del self.poslist[-1]
        elif key in ('a', 'd'):
            step = -self._SCROLL_STEP if key == 'a' else self._SCROLL_STEP
            left, right = self.ax.get_xlim()
            self.ax.set_xlim([left + step, right + step])
        self.update()

    def saveWhiteMarkList2Json(self):
        """Write the markers to ``<SaveFolder>/<recname>_poslist.json``.

        The save folder is created first if it does not exist.
        """
        if not os.path.isdir(self.SaveFolder):
            os.makedirs(self.SaveFolder)
        result_info = dict(ID=self.recname, database='QTdb',
                           poslist=self.poslist, type='Tonset')
        with open(self._result_file_name(), 'w') as fout:
            json.dump(result_info, fout, indent=4, sort_keys=True)
        print('Json file for record {} saved.'.format(self.recname))

    def clearWhiteMarkList(self):
        """Forget all markers of the current record."""
        self.poslist = []
        self.totalWhiteCount = 0

    def addMarkx(self, x):
        """Add a Tonset marker at sample ``int(x)`` and zoom the view around it.

        The position is clamped to the valid sample range, so a pick landing
        just outside the signal (within the picker tolerance) can neither index
        past the end nor wrap around through a negative index.
        """
        pos = int(x)
        pos = max(0, min(pos, len(self.rawSig) - 1))
        self.poslist.append(pos)
        self._plot_marker(pos)
        self.ax.set_xlim(pos - 500, pos + 500)

    def onpick(self, event):
        '''Mouse click to mark target points.'''
        x = event.mouseevent.xdata
        if x is None:
            # Click outside the data coordinates: nothing to mark.
            return
        self.addMarkx(x)
        self.text.set_text('Marking Tonset: ({}) [whiteCnt {}]'.format(
            self.poslist[-1], len(self.poslist)))
        # update canvas
        self.fig.canvas.draw()

    def _skip_marked_records(self):
        """Advance past records that already have a saved result file.

        Shows one message box per already-marked record.  Stops on the first
        unmarked record, or stays on the last record when every remaining one
        is marked (previously this recursed forever).

        Returns:
            True if the current record changed.
        """
        moved = False
        while os.path.exists(self._result_file_name()):
            window = Tkinter.Tk()
            window.wm_withdraw()
            tkMessageBox.showinfo(title='Repeat',
                                  message='The record %s is already marked!'
                                  % self.recname)
            window.destroy()
            # Go to next record
            if not self.next_record():
                break
            self.clearWhiteMarkList()
            moved = True
        return moved

    def RepeatCheck(self):
        '''Check repeat results.

        If the current record already has a result file, skip forward to the
        next unmarked record (or stop at the last one) and redraw.
        '''
        if self._skip_marked_records():
            self.reDraw()

    def reDraw(self):
        """Clear the axes and redraw the current record from scratch.

        Records that already have a saved result are skipped first.
        """
        self._skip_marked_records()
        ax = self.ax
        ax.cla()
        self.text = ax.text(0.05, 0.95, 'selected: none',
                            transform=ax.transAxes, va='top')
        ax.grid(color=(0.8, 0.8, 0.8), linestyle='--', linewidth=2)
        # ====================================
        # load ECG signal
        ax.set_title('QT {} (Index = {})'.format(self.recname, self.recInd))
        ax.plot(self.rawSig, picker=5)  # 5 points tolerance
        # plot Expert Labels
        self.plotExpertLabels(ax)
        # draw Markers
        for pos in self.poslist:
            self._plot_marker(pos)
        ax.set_xlim(0, len(self.rawSig))
        # update draw
        self.fig.canvas.draw()

    def update(self):
        """Redraw the canvas without replotting."""
        self.fig.canvas.draw()

    def next_record(self):
        """Load the next record of ``reclist``.

        Returns:
            True on success; False when already at the last record, in which
            case the current record and ``recInd`` are left unchanged.
        """
        if self.recInd + 1 >= len(self.reclist):
            return False
        self.recInd += 1
        self._load_current_record()
        return True

    def plotExpertLabels(self, ax):
        """Plot the expert labels of the current record onto ``ax``.

        One series per distinct label.  The colour comes from the label's first
        letter (T/P/R, grey for anything else); the marker is '<' for labels
        containing 'onset', '>' for 'offset' and 'o' otherwise.
        """
        # Group positions by label.
        labelDict = {}
        for pos, label in self.expLabels:
            labelDict.setdefault(label, []).append(pos)
        # plot to axes
        for label, posList in labelDict.items():
            colorInd = self._LABEL_COLOR_INDEX.get(
                label[:1], self._DEFAULT_LABEL_COLOR_INDEX)
            color = self.colors[colorInd]
            if 'onset' in label:
                marker = '<'
            elif 'offset' in label:
                marker = '>'
            else:
                marker = 'o'
            ax.plot(posList, [self.rawSig[x] for x in posList],
                    marker=marker, color=color, linestyle='none',
                    markersize=14, label=label)
        ax.legend(numpoints=1)