def res_to_mat_fromResult(recID, reslist, mat_filename):
    # load QT rawsig
    qt = QTloader()
    sig = qt.load(recID)
    rawsig = sig['sig']
    # load expert label
    expert_reslist = qt.getexpertlabeltuple(recID)
    # save sig and reslist
    label_dict = dict()
    for pos, label in reslist:
        if label in label_dict:
            label_dict[label].append(pos)
        else:
            label_dict[label] = [
                pos,
            ]
    # Expert Labels
    for pos, label in expert_reslist:
        exp_label = 'expert_' + label
        if exp_label in label_dict:
            label_dict[exp_label].append(pos)
        else:
            label_dict[exp_label] = [
                pos,
            ]
    label_dict['sig'] = rawsig
    scipy.io.savemat(mat_filename, label_dict)
    print 'mat file [{}] saved.'.format(mat_filename)
def plot_QTdb_filtered_Result_with_syntax_filter(RFfolder,
                                                 TargetRecordList,
                                                 ResultFilterType,
                                                 showExpertLabels=False):
    # exit
    # ==========================
    # plot prediction result
    # ==========================
    reslist = getresultfilelist(RFfolder)
    qtdb = QTloader()
    non_result_extensions = ['out', 'json', 'log', 'txt']
    for fi, fname in enumerate(reslist):
        # block *.out
        file_extension = fname.split('.')[-1]
        if file_extension in non_result_extensions:
            continue
        print 'file name:', fname
        currecname = os.path.split(fname)[-1]
        print currecname
        #if currecname == 'result_sel820':
        #pdb.set_trace()
        if TargetRecordList is not None:
            if currecname not in TargetRecordList:
                continue
        # load signal and reslist
        with open(fname, 'r') as fin:
            (recID, reslist) = pickle.load(fin)
        # filter result of QT
        resfilter = ResultFilter(reslist)
        if len(ResultFilterType) >= 1 and ResultFilterType[0] == 'G':
            reslist = resfilter.group_local_result(cp_del_thres=1)
        reslist_syntax = reslist
        if len(ResultFilterType) >= 2 and ResultFilterType[1] == 'S':
            reslist_syntax = resfilter.syntax_filter(reslist)
        # empty signal result
        #if reslist is None or len(reslist) == 0:
        #continue
        #pdb.set_trace()
        sigstruct = qtdb.load(recID)
        if showExpertLabels == True:
            # Expert Label AdditionalPlot
            ExpertRes = qtdb.getexpertlabeltuple(recID)
            ExpertPoslist = map(lambda x: x[0], ExpertRes)
            AdditionalPlot = [
                ['kd', 'Expert Labels', ExpertPoslist],
            ]
        else:
            AdditionalPlot = None

        # plot res
        #resploter = ECGResultPloter(sigstruct['sig'],reslist)
        #resploter.plot(plotTitle = 'QT database',plotShow = True,plotFig = 2)
        # syntax_filter
        resploter_syntax = ECGResultPloter(sigstruct['sig'], reslist_syntax)
        resploter_syntax.plot(plotTitle='QT Record {}'.format(recID),
                              plotShow=True,
                              AdditionalPlot=AdditionalPlot)
class SWT_GroupResult2Leads:
    ''' Find P&T peak with SWT+db6
    '''
    def __init__(self, recname, reslist, leadname, MaxSWTLevel=9):
        '''Load one QT record lead and its SWT coefficients.

        Args:
            recname: QT database record name.
            reslist: raw per-sample predictions as [(pos, label), ...]
                (the shape group_result iterates over).
            leadname: key into the loaded signal struct selecting the lead.
            MaxSWTLevel: currently unused; the SWT comes from NonQRS_SWT.
                Kept for interface compatibility.
        '''
        self.recres = reslist
        self.recname = recname

        self.QTdb = QTloader()
        self.sig_struct = self.QTdb.load(self.recname)
        self.rawsig = self.sig_struct[leadname]

        # Filled lazily by group_result().
        self.res_groups = None
        # One output list per delineated wave feature.
        self.peak_dict = {
            key: []
            for key in ('T', 'P', 'Ponset', 'Poffset', 'Tonset', 'Toffset')
        }

        swt_engine = NonQRS_SWT()
        swt_engine.swt(self.recname)
        self.cDlist = swt_engine.cDlist
        self.cAlist = swt_engine.cAlist

    def group_result(self, white_del_thres=20, cp_del_thres=0):
        #
        # 参数说明:1.white_del_thres是删除较小白色组的阈值
        #           2.cp_del_thres是删除较小其他关键点组的阈值
        # Multiple prediction point -> single point output
        ## filter output for evaluation results
        #
        # parameters
        #
        #
        # the number of the group must be greater than:
        #
        # default parameter

        recres = self.recres

        # filtered test result
        frecres = []
        # in var
        prev_label = None
        posGroup = []
        #----------------------
        # [pos,label] in recres
        #----------------------
        for pos, label in recres:
            if prev_label is not None:
                if label != prev_label:
                    frecres.append((prev_label, posGroup))
                    posGroup = []

            prev_label = label
            posGroup.append(pos)
        # add last label group
        if len(posGroup) > 0:
            frecres.append((prev_label, posGroup))
        #======================
        # 1.删除比较小的白色组和其他组(different threshold)
        # 2.合并删除后的相邻同色组
        #======================
        filtered_local_res = []
        for label, posGroup in frecres:
            if label == 'white' and len(posGroup) <= white_del_thres:
                continue
            if label != 'white' and len(posGroup) <= cp_del_thres:
                continue
            # can merge backward?
            if len(filtered_local_res
                   ) > 0 and filtered_local_res[-1][0] == label:
                filtered_local_res[-1][1].extend(posGroup)
            else:
                filtered_local_res.append((label, posGroup))

        frecres = filtered_local_res
        # [(label,[poslist])]
        self.res_groups = frecres

        return frecres

    def filter_smaller_nearby_groups(self,
                                     res_groups,
                                     group_near_dist_thres=100):
        '''Suppress groups that lie too close to the previously kept group.

        Args:
            res_groups: [(label, [poslist]), ...] ordered by position, as
                produced by group_result.
            group_near_dist_thres: if a group starts within this many samples
                of the end of the last kept group, only the one with more
                positions survives (ties keep the earlier group).

        Returns:
            A new list of the surviving groups; [] for empty/None input
            (previously an empty input raised IndexError).
        '''
        if not res_groups:
            return []
        kept = [res_groups[0]]
        for group in res_groups[1:]:
            gap = min(group[1]) - max(kept[-1][1])
            if gap > group_near_dist_thres:
                kept.append(group)
            elif len(group[1]) > len(kept[-1][1]):
                # Close neighbour with more positions replaces the kept one.
                kept[-1] = group
        return kept

    def crop_data_for_swt(self, rawsig):
        # crop rawsig
        base2 = 1
        N_data = len(rawsig)
        if len(rawsig) <= 1:
            raise Exception('len(rawsig)={},not enough for swt!', len(rawsig))
        crop_len = base2
        while base2 < N_data:
            if base2 * 2 >= N_data:
                crop_len = base2 * 2
                break
            base2 *= 2
        # pad zeros
        if N_data < crop_len:
            rawsig += [
                rawsig[-1],
            ] * (crop_len - N_data)
        rawsig = rawsig[0:crop_len]
        return rawsig

    def getSWTcoeflist(self, MaxLevel=9):
        # crop two leads
        self.rawsig = self.crop_data_for_swt(self.rawsig)

        print '-' * 3
        print 'len(rawsig)= ', len(self.rawsig)
        print 'SWT maximum level =', pywt.swt_max_level(len(self.rawsig))
        coeflist = pywt.swt(self.rawsig, 'db6', MaxLevel)

        cAlist, cDlist = zip(*coeflist)
        self.cAlist = cAlist
        self.cDlist = cDlist

    def get_downward_cross_zero_list(self, array):
        # rising slope only
        if array is None or len(array) < 2:
            return []
        crosszerolist = []
        #
        # len of array
        N_array = len(array)
        for ind in xrange(1, N_array - 1):
            # cross zero
            if array[ind - 1] > 0 and array[ind] <= 0:
                crosszerolist.append(ind)
        return crosszerolist

    def get_cross_zero_list(self, array):
        # rising slope only
        if array is None or len(array) < 2:
            return []
        crosszerolist = []
        rem = array[0]
        N_array = len(array)

        for ind in xrange(1, N_array):
            if array[ind - 1] < 0 and array[ind] >= 0:
                crosszerolist.append(ind)
        return crosszerolist

    def get_local_minima_list(self, array):
        # rising slope only
        if array is None or len(array) < 2:
            return []
        ret_list = []
        #
        # length of array
        N = len(array)
        for ind in xrange(1, N - 1):
            if array[ind] - array[ind - 1] < 0 and array[ind +
                                                         1] - array[ind] >= 0:
                # local maxima
                ret_list.append(ind)
        return ret_list

    def get_local_maxima_list(self, array):
        # rising slope only
        if array is None or len(array) < 2:
            return []
        ret_list = []
        N = len(array)
        for ind in xrange(1, N - 1):
            if array[ind] - array[ind - 1] > 0 and array[ind +
                                                         1] - array[ind] <= 0:
                # local maxima
                ret_list.append(ind)
        return ret_list

    def bw_find_nearest(self, pos, array):
        '''Return the distance from *pos* to the closest value in *array*.

        *array* must be sorted ascending (it is searched with bisect).
        Returns float('inf') for an empty *array* (previously IndexError), so
        callers that threshold the distance (``min_score > 2``) treat "no
        reference point" as "too far" and use their fallback.
        '''
        N = len(array)
        if N == 0:
            return float('inf')
        insertPos = bisect.bisect_left(array, pos)
        if insertPos >= N:
            return abs(pos - array[-1])
        if insertPos == 0:
            return abs(pos - array[0])
        return min(abs(pos - array[insertPos]),
                   abs(pos - array[insertPos - 1]))

    def get_longest_downward_slope_index(self, candidate_list, crosszerolist,
                                         e4list):
        #
        # length of cross zeros
        N = len(crosszerolist)
        max_slope_len = -1
        best_candidate = -1
        for prd_pos in candidate_list:

            # find the closest cross point in the list to prd_pos
            insertPos = bisect.bisect_left(crosszerolist, prd_pos)
            peak_pos = -1
            if insertPos >= N:
                peak_pos = N - 1
            elif insertPos == 0:
                peak_pos = 0
            else:
                if abs(prd_pos - crosszerolist[insertPos]) < abs(
                        prd_pos - crosszerolist[insertPos - 1]):
                    peak_pos = insertPos
                else:
                    peak_pos = insertPos - 1

            peak_pos = crosszerolist[peak_pos]
            # Find current slope length
            right_bound = N - 1
            left_bound = 0

            # reached N-1 or local maxima
            for cur_pos in xrange(peak_pos, N - 1):
                if e4list[cur_pos + 1] - e4list[cur_pos] >= 0:
                    # local minima
                    right_bound = cur_pos
                    break

            # reached 0
            for cur_pos in xrange(peak_pos, 0, -1):
                if e4list[cur_pos] - e4list[cur_pos - 1] >= 0:
                    # local maxima
                    left_bound = cur_pos
                    break
            cur_slope_len = right_bound - left_bound
            if cur_slope_len > max_slope_len:
                max_slope_len = cur_slope_len
                best_candidate = peak_pos
        return best_candidate

    def get_current_upward_slope_steepness(self, pos, crosszerolist, e4list):
        # Frequent used var.
        N = len(crosszerolist)
        N_e4list = len(e4list)

        # Only one pos, find its left minima&right maxima.
        prd_pos = pos
        # find the closest cross point in the list to prd_pos
        insertPos = bisect.bisect_left(crosszerolist, prd_pos)

        peak_pos = -1
        if insertPos >= N:
            peak_pos = N - 1
        elif insertPos == 0:
            peak_pos = 0
        else:
            if abs(prd_pos - crosszerolist[insertPos]) < abs(
                    prd_pos - crosszerolist[insertPos - 1]):
                peak_pos = insertPos
            else:
                peak_pos = insertPos - 1

        peak_pos = crosszerolist[peak_pos]
        # Find current slope length
        right_bound = N - 1
        left_bound = 0

        # reached N-1 or local maxima
        for cur_pos in xrange(peak_pos, N_e4list - 1):
            if e4list[cur_pos + 1] - e4list[cur_pos] <= 0:
                # local maxima
                right_bound = cur_pos
                break

        # reached 0
        for cur_pos in xrange(peak_pos, 0, -1):
            if e4list[cur_pos] - e4list[cur_pos - 1] <= 0:
                # local minima
                left_bound = cur_pos
                break
        cur_slope_len = right_bound - left_bound
        if cur_slope_len <= 0:
            plt.figure(2)
            plt.plot(e4list)
            plt.plot(crosszerolist, map(lambda x: e4list[x], crosszerolist),
                     'ro')
            plt.plot(peak_pos, e4list[peak_pos], 'y*', markersize=14)
            plt.grid(True)
            plt.show()
            print
            print 'Warning: cur_slope_len <= 0!'
            pdb.set_trace()
        cur_slope_amp = abs(e4list[right_bound] - e4list[left_bound])
        return float(cur_slope_amp) / cur_slope_len

    def get_steepest_slope_index(self, candidate_list, crosszerolist, e4list):
        N = len(crosszerolist)
        max_slope_steepness = None
        best_candidate = None
        for prd_pos in candidate_list:

            # find the closest cross point in the list to prd_pos
            insertPos = bisect.bisect_left(crosszerolist, prd_pos)
            peak_pos = -1
            if insertPos >= N:
                peak_pos = N - 1
            elif insertPos == 0:
                peak_pos = 0
            else:
                if abs(prd_pos - crosszerolist[insertPos]) < abs(
                        prd_pos - crosszerolist[insertPos - 1]):
                    peak_pos = insertPos
                else:
                    peak_pos = insertPos - 1

            peak_pos = crosszerolist[peak_pos]
            # Find current slope length
            right_bound = N - 1
            left_bound = 0
            right_amp, left_amp = None, None

            # reached N-1 or local maxima
            N_e4list = len(e4list)
            for cur_pos in xrange(peak_pos, N_e4list - 1):
                if e4list[cur_pos + 1] - e4list[cur_pos] <= 0:
                    # local maxima
                    right_bound = cur_pos
                    right_amp = e4list[cur_pos]
                    break

            # reached 0
            for cur_pos in xrange(peak_pos, 0, -1):
                if e4list[cur_pos] - e4list[cur_pos - 1] <= 0:
                    # local minima
                    left_bound = cur_pos
                    left_amp = e4list[cur_pos]
                    break
            cur_slope_len = right_bound - left_bound
            cur_steepness = float(abs(right_amp - left_amp)) / cur_slope_len

            # choose the steepest one as the finnal position
            if max_slope_steepness is None or cur_steepness > max_slope_steepness:
                max_slope_steepness = cur_steepness
                best_candidate = prd_pos
        return best_candidate

    def get_steepest_downward_slope_index(self, candidate_list, crosszerolist,
                                          e4list):
        #
        # length of cross zeros
        N = len(crosszerolist)
        max_slope_len = -1
        best_candidate = -1
        max_slope_steepness = None

        for prd_pos in candidate_list:

            # find the closest cross point in the list to prd_pos
            insertPos = bisect.bisect_left(crosszerolist, prd_pos)
            peak_pos = -1
            if insertPos >= N:
                peak_pos = N - 1
            elif insertPos == 0:
                peak_pos = 0
            else:
                if abs(prd_pos - crosszerolist[insertPos]) < abs(
                        prd_pos - crosszerolist[insertPos - 1]):
                    peak_pos = insertPos
                else:
                    peak_pos = insertPos - 1

            peak_pos = crosszerolist[peak_pos]
            # Find current slope length
            right_bound = N - 1
            left_bound = 0
            right_amp, left_amp = None, None

            # reached N-1 or local maxima
            for cur_pos in xrange(peak_pos, len(e4list)):
                if e4list[cur_pos + 1] - e4list[cur_pos] >= 0:
                    # local minima
                    right_bound = cur_pos
                    right_amp = e4list[cur_pos]
                    break

            # reached 0
            for cur_pos in xrange(peak_pos, 0, -1):
                if e4list[cur_pos] - e4list[cur_pos - 1] >= 0:
                    # local maxima
                    left_bound = cur_pos
                    left_amp = e4list[cur_pos]
                    break
            cur_slope_len = right_bound - left_bound

            if left_amp is None or right_amp is None:
                print 'left_amp,right_amp:'
                print left_amp, right_amp
                pdb.set_trace()
            cur_steepness = float(abs(right_amp - left_amp)) / cur_slope_len

            # choose the steepest one as the finnal position
            if max_slope_steepness is None or cur_steepness > max_slope_steepness:
                max_slope_steepness = cur_steepness
                best_candidate = prd_pos
        return best_candidate

    def get_longest_slope_index(self, candidate_list, crosszerolist, e4list):
        N = len(crosszerolist)
        N_e4list = len(e4list)
        max_slope_len = None
        best_candidate = None
        for prd_pos in candidate_list:

            # find the closest cross point in the list to prd_pos
            insertPos = bisect.bisect_left(crosszerolist, prd_pos)
            peak_pos = -1
            if insertPos >= N:
                peak_pos = N - 1
            elif insertPos == 0:
                peak_pos = 0
            else:
                if abs(prd_pos - crosszerolist[insertPos]) < abs(
                        prd_pos - crosszerolist[insertPos - 1]):
                    peak_pos = insertPos
                else:
                    peak_pos = insertPos - 1

            peak_pos = crosszerolist[peak_pos]
            # Find current slope length
            right_bound = N - 1
            left_bound = 0

            # reached N-1 or local maxima
            for cur_pos in xrange(peak_pos, N_e4list - 1):
                if e4list[cur_pos + 1] - e4list[cur_pos] <= 0:
                    # local maxima
                    right_bound = cur_pos
                    break

            # reached 0
            for cur_pos in xrange(peak_pos, 0, -1):
                if e4list[cur_pos] - e4list[cur_pos - 1] <= 0:
                    # local minima
                    left_bound = cur_pos
                    break
            cur_slope_len = right_bound - left_bound
            if max_slope_len is None or cur_slope_len > max_slope_len:
                max_slope_len = cur_slope_len
                best_candidate = prd_pos
        return best_candidate

    def get_T_peaklist(self):
        if self.res_groups is None:
            # group the raw predition results if not grouped already
            self.group_result()
        res_groups = filter(lambda x: x[0] == 'T', self.res_groups)
        res_groups = self.filter_smaller_nearby_groups(res_groups)

        # get T peak
        #e4list = np.array(self.cDlist[-4])+np.array(self.cDlist[-5])
        D6list = np.array(self.cDlist[-6])
        D5list = np.array(self.cDlist[-5])
        crosszerolist = self.get_cross_zero_list(D6list)
        D5crosszerolist = self.get_cross_zero_list(D5list)

        # debug :D5crosszerolist check!
        # e4list = D5list
        # plt.figure(2)
        # plt.plot(e4list)
        # plt.plot(D5crosszerolist,map(lambda x:e4list[x],D5crosszerolist),'ro')
        # plt.plot(155424,D5list[155424],'y*',markersize = 14)
        # plt.grid(True)
        # plt.show()
        # if 155424 in D5crosszerolist:
        # print '155424'
        # pdb.set_trace()

        # for debug
        sig_struct = self.QTdb.load(self.recname)
        raw_sig = sig_struct['sig']

        debug_res_group_ind = 21
        for label, posgroup in res_groups[22:]:
            debug_res_group_ind += 1
            scorelist = []
            D5scorelist = []
            for pos in posgroup:
                nearest_dist = self.bw_find_nearest(pos, crosszerolist)
                scorelist.append((nearest_dist, pos))
                # for D5
                D5nearest_dist = self.bw_find_nearest(pos, D5crosszerolist)
                D5scorelist.append((D5nearest_dist, pos))

            # get all pos with min score
            min_score, candidate_list = self.get_min_score_poslist(scorelist)
            D5min_score, D5candidate_list = self.get_min_score_poslist(
                D5scorelist)

            D5pos, D6pos = [], []
            D5_swt_mark, D6_swt_mark = True, True
            final_decision_peak = -1
            # get D6 crosszero position
            if min_score > 2:
                # not a peak
                print
                print 'Warning: using mean group position!'
                D6_swt_mark = False
                D6pos.append(np.mean(posgroup))
            elif len(candidate_list) == 1:
                # only mean score by SWT
                D6pos.append(candidate_list[0])
            else:
                # multiple min score
                longest_slope_index = self.get_longest_slope_index(
                    candidate_list, crosszerolist, D6list)
                D6pos.append(crosszerolist[longest_slope_index])
            # get D5 crosszero position
            if D5min_score > 2:
                # not a peak
                print
                print 'Warning: using mean group position!'
                D5_swt_mark = False
                D5pos.append(np.mean(posgroup))
            elif len(D5candidate_list) == 1:
                # only mean score by SWT
                D5pos.append(D5candidate_list[0])
            else:
                # multiple min score
                longest_slope_index = self.get_longest_slope_index(
                    D5candidate_list, D5crosszerolist, D5list)
                D5pos.append(D5crosszerolist[longest_slope_index])
            # plot two positions for debug --- check!

            print 'debug_res_group_ind', debug_res_group_ind
            # Get the slope for D5list and D6list.
            if D5_swt_mark and D6_swt_mark:
                print 'getting D5 slope:'
                D5slope = self.get_current_upward_slope_steepness(
                    D5pos[-1], D5crosszerolist, D5list)
                print 'getting D6 slope:'
                D6slope = self.get_current_upward_slope_steepness(
                    D6pos[-1], crosszerolist, D6list)
                print 'D6 slope value:'
                print D6slope
                print 'D5 slope value:'
                print D5slope

                pdb.set_trace()
                if D5slope < D6slope:
                    final_decision_peak = D6pos[-1]
                else:
                    final_decision_peak = D5pos[-1]

            elif D5_swt_mark:
                final_decision_peak = D5pos[-1]
            elif D6_swt_mark:
                final_decision_peak = D6pos[-1]
            else:
                # Default value is D6pos
                final_decision_peak = D6pos[-1]

            # debug plot
            # 1. get range
            seg_range = [min(posgroup), max(posgroup)]
            seg_range[0] = max(0, seg_range[0] - 100)
            seg_range[1] = min(len(raw_sig) - 1, seg_range[1] + 100)
            # 2.plot
            seg = raw_sig[seg_range[0]:seg_range[1]]
            plt.ion()
            plt.figure(1)
            plt.clf()
            plt.plot(seg, label='ECG')
            # 3.plot group
            seg_posgroup = map(lambda x: x - seg_range[0], posgroup)
            plt.plot(seg_posgroup,
                     map(lambda x: seg[x], seg_posgroup),
                     label='posgroup',
                     marker='o',
                     markersize=4,
                     markerfacecolor='g',
                     alpha=0.7)
            # 4.plot peak pos
            # plot D6 postion
            peak_pos = D6pos[-1] - seg_range[0]
            peak_pos = int(peak_pos)
            plt.plot(peak_pos,
                     seg[peak_pos],
                     'yo',
                     markersize=12,
                     alpha=0.7,
                     label='D6 pos')
            # plot D5 postion
            peak_pos = D5pos[-1] - seg_range[0]
            peak_pos = int(peak_pos)
            plt.plot(peak_pos,
                     seg[peak_pos],
                     'mo',
                     markersize=12,
                     alpha=0.7,
                     label='D5 pos')
            # plot final decision postion
            peak_pos = final_decision_peak - seg_range[0]
            peak_pos = int(peak_pos)
            plt.plot(peak_pos,
                     seg[peak_pos],
                     'r*',
                     markersize=12,
                     alpha=0.7,
                     label='D5 pos')
            # 5.plot determin line
            seg_determin_line = self.cDlist[-6][seg_range[0]:seg_range[1]]
            seg_determin_line5 = self.cDlist[-5][seg_range[0]:seg_range[1]]
            plt.plot(seg_determin_line, 'y', label='D6')
            plt.plot(seg_determin_line5, 'm', label='D5')

            plt.title(self.recname)
            plt.legend()
            plt.grid(True)
            plt.show()
            # debug stop
            if D5_swt_mark and D6_swt_mark:
                print 'D5pos:', D5pos
                print 'D6pos:', D6pos
                print 'final_decision_peak', final_decision_peak
            pdb.set_trace()

        # return list of T peaks
        return self.peak_dict['T']

    def get_min_score_poslist(self, scorelist):
        '''Return (min_score, positions) for a [(score, pos), ...] list.

        positions lists every pos whose score equals the minimum, in input
        order. Returns None (not a tuple) when scorelist is None or empty.
        '''
        if scorelist is None or len(scorelist) == 0:
            return None
        best = min(score for score, _ in scorelist)
        return (best, [pos for score, pos in scorelist if score == best])

    def get_P_peaklist(self, debug=False):

        if self.res_groups is None:
            # group the raw predition results if not grouped already
            self.group_result()
        res_groups = filter(lambda x: x[0] == 'P', self.res_groups)
        res_groups = self.filter_smaller_nearby_groups(res_groups)
        # get P peak list
        #e3list = np.array(self.cDlist[-6])+np.array(self.cDlist[-5])
        e3list = np.array(self.cDlist[-5])

        # Get raw_sig
        sig_struct = self.QTdb.load(self.recname)
        raw_sig = sig_struct['sig']
        N_rawsig = len(raw_sig)

        #local_maxima_list = self.get_local_maxima_list(e3list)
        local_maxima_list = self.get_cross_zero_list(e3list)

        # Expert Label list
        expert_label_list = self.QTdb.getexpertlabeltuple(self.recname)
        print 'record name:', self.recname
        print 'length of expert label:', len(expert_label_list)
        # debug
        #if debug == True:
        #plt.figure(2)
        #plt.plot(e3list)
        #plt.plot(local_maxima_list,map(lambda x:e3list[x],local_maxima_list),'y*',markersize = 14)
        #plt.show()

        # find local maxima point within each group
        debug_ind = 0
        for label, posgroup in res_groups:
            print 'debug_ind:', debug_ind
            debug_ind += 1
            scorelist = []

            print 'Pos Group:', posgroup
            print 'mean Group:', np.mean(posgroup)

            extra_search_len = 10
            extra_search_left = max(0, min(posgroup) - extra_search_len)
            extra_search_right = min(N_rawsig,
                                     max(posgroup) + extra_search_len)

            for pos in xrange(extra_search_left, extra_search_right):
                nearest_dist = self.bw_find_nearest(pos, local_maxima_list)
                scorelist.append((nearest_dist, pos))
            scorelist.sort(key=lambda x: x[0])

            # get all pos with min score
            min_score, candidate_list = self.get_min_score_poslist(scorelist)
            print 'candidate_list:', candidate_list
            print 'got candidate_list'

            if min_score > 2:
                # not a peak
                self.peak_dict['P'].append(np.mean(posgroup))
            elif len(candidate_list) == 1:
                # only mean score by SWT
                self.peak_dict['P'].append(candidate_list[0])
            else:
                # multiple min score
                best_candidate = self.get_steepest_slope_index(
                    candidate_list, local_maxima_list, e3list)
                self.peak_dict['P'].append(best_candidate)

            # debug
            if debug == True:
                print 'SWT peak pos:', self.peak_dict['P'][-1]
                pdb.set_trace()
            # debug plot
            # 1. get range
            seg_range = [min(posgroup), max(posgroup)]
            seg_range[0] = max(0, seg_range[0] - 200)
            seg_range[1] = min(len(raw_sig) - 1, seg_range[1] + 200)
            # 2.plot
            seg = raw_sig[seg_range[0]:seg_range[1]]
            plt.ion()
            plt.figure(1)
            plt.clf()
            plt.plot(seg, label='ECG')
            # 3.plot group
            seg_posgroup = map(lambda x: x - seg_range[0], posgroup)
            plt.plot(seg_posgroup,
                     map(lambda x: seg[x], seg_posgroup),
                     label='posgroup',
                     marker='o',
                     markersize=4,
                     markerfacecolor='g',
                     alpha=0.7)
            # 4.plot peak pos
            # plot D6 postion
            peak_pos = self.peak_dict['P'][-1] - seg_range[0]
            peak_pos = int(peak_pos)
            plt.plot(peak_pos,
                     seg[peak_pos],
                     'yo',
                     markersize=12,
                     alpha=0.7,
                     label='D6 pos')
            # plot D5 postion
            # peak_pos = D5pos[-1]-seg_range[0]
            # peak_pos = int(peak_pos)
            # plt.plot(peak_pos,seg[peak_pos],'mo',markersize = 12,alpha = 0.7,label = 'D5 pos')
            # plot final decision postion
            # peak_pos = final_decision_peak - seg_range[0]
            # peak_pos = int(peak_pos)
            # plt.plot(peak_pos,seg[peak_pos],'r*',markersize = 12,alpha = 0.7,label = 'D5 pos')

            # plot Expert label position
            segment_expertlist = filter(
                lambda x: x[0] >= seg_range[0] and x[0] < seg_range[1],
                expert_label_list)
            if len(segment_expertlist) > 0:
                segment_expert_poslist, segment_expert_labellist = zip(
                    *segment_expertlist)
                plt.plot(map(lambda x: x - seg_range[0],
                             segment_expert_poslist),
                         map(lambda x: seg[x - seg_range[0]],
                             segment_expert_poslist),
                         'rd',
                         markersize=12,
                         alpha=0.7,
                         label='expert label')

            # 5.plot determin line
            seg_determin_line = self.cDlist[-4][seg_range[0]:seg_range[1]]
            seg_determin_lineD5 = self.cDlist[-5][seg_range[0]:seg_range[1]]
            seg_determin_lineD6 = self.cDlist[-3][seg_range[0]:seg_range[1]]

            plt.plot(seg_determin_line, 'y', label='D4')
            plt.plot(seg_determin_lineD5, 'g', label='D5')
            #plt.plot(seg_determin_lineD6,'m',label = 'D3')

            plt.title('{} {}'.format(self.recname, seg_range))
            plt.legend()
            plt.grid(True)
            plt.show()
            # debug stop
            pdb.set_trace()
        # return list of P peaks
        return self.peak_dict['P']

    #
    # detect boundaries using get_boundary_list function
    #
    def get_Ponset_list(self):
        '''Locate P-wave onsets.

        Sums the SWT detail levels ``cDlist[-4]`` and ``cDlist[-3]`` (drawn
        as 'D4' / 'D3' in this module's debug plots) and lets
        ``get_boundary_list`` refine each 'Ponset' result group on it, using
        downward zero crossings as candidates and the longest downward slope
        to break ties.

        Returns:
            Whatever ``get_boundary_list`` returns for 'Ponset'.
        '''
        details = self.cDlist
        detail_sum = np.array(details[-4]) + np.array(details[-3])
        return self.get_boundary_list('Ponset',
                                      detail_sum,
                                      self.get_downward_cross_zero_list,
                                      self.get_longest_downward_slope_index)

    def get_Poffset_list(self):
        '''Locate P-wave offsets.

        Works on the same ``cDlist[-4] + cDlist[-3]`` detail sum as the
        P-onset search, but uses local minima of that signal as the candidate
        positions; ties are broken by the longest downward slope.

        Returns:
            Whatever ``get_boundary_list`` returns for 'Poffset'.
        '''
        level4 = np.array(self.cDlist[-4])
        level3 = np.array(self.cDlist[-3])
        candidates_of = self.get_local_minima_list
        tie_breaker = self.get_longest_downward_slope_index
        return self.get_boundary_list('Poffset', level4 + level3,
                                      candidates_of, tie_breaker)

    def get_Toffset_list(self):
        '''Locate T-wave offsets.

        Refines each 'Toffset' result group on the sum of SWT details
        ``cDlist[-4]`` and ``cDlist[-5]`` (labelled 'D4' / 'D5' in the debug
        plots), using downward zero crossings as candidates and the longest
        downward slope to break ties.

        Returns:
            Whatever ``get_boundary_list`` returns for 'Toffset'.
        '''
        d4_plus_d5 = np.array(self.cDlist[-4]) + np.array(self.cDlist[-5])
        return self.get_boundary_list(
            'Toffset', d4_plus_d5,
            self.get_downward_cross_zero_list,
            self.get_longest_downward_slope_index)

    def get_Tonset_list(self, debug=False):
        '''Locate T-wave onsets.

        Same detail signal (``cDlist[-4] + cDlist[-5]``) and candidate/tie
        functions as the T-offset search.

        Args:
            debug: forwarded to ``get_boundary_list`` to enable its debug
                prints and pdb breakpoints.

        Returns:
            Whatever ``get_boundary_list`` returns for 'Tonset'.
        '''
        detail_sum = np.array(self.cDlist[-4]) + np.array(self.cDlist[-5])
        zero_cross_candidates = self.get_downward_cross_zero_list
        slope_tie_breaker = self.get_longest_downward_slope_index
        return self.get_boundary_list('Tonset',
                                      detail_sum,
                                      zero_cross_candidates,
                                      slope_tie_breaker,
                                      debug=debug)

    def get_boundary_list(self,
                          label,
                          e3list,
                          crossZeroFunc,
                          LongestSlopeIndexFunc,
                          debug=False):
        '''Pick one boundary position for every result group of *label*.

        Args:
            label: target label (e.g. 'Ponset', 'Toffset'); only entries of
                ``self.res_groups`` whose first element equals it are used.
            e3list: SWT detail signal the positions are refined on.
            crossZeroFunc: callable(e3list) returning candidate positions on
                e3list (zero crossings, local minima, ...).
            LongestSlopeIndexFunc: callable(candidate_list, candidates,
                e3list) returning an index into the candidate positions; used
                when several group members tie for the best score.
            debug: print intermediate values and stop in pdb.

        Returns:
            ``self.peak_dict[label]`` after one position has been appended
            per (filtered) result group.
        '''
        if self.res_groups is None:
            # group the raw prediction results if not grouped already
            self.group_result()
        res_groups = [group for group in self.res_groups if group[0] == label]

        if debug:
            print('length of result groups: %s' % (len(res_groups),))
            pdb.set_trace()

        res_groups = self.filter_smaller_nearby_groups(res_groups)
        # candidate boundary positions on the detail signal
        local_maxima_list = crossZeroFunc(e3list)

        # Use a distinct loop name: the old code rebound the `label`
        # parameter here, so the final return depended on the last group.
        for group_label, posgroup in res_groups:
            if debug:
                print('Pos Group: %s' % (posgroup,))
                print('mean Group: %s' % (np.mean(posgroup),))

            # score = distance (as reported by bw_find_nearest) from each
            # predicted position to the nearest candidate
            scorelist = [(self.bw_find_nearest(pos, local_maxima_list), pos)
                         for pos in posgroup]

            # get all pos with min score
            min_score, candidate_list = self.get_min_score_poslist(scorelist)

            peak_list = self.peak_dict[group_label]
            if min_score > 2:
                # no candidate close enough: fall back to the group mean
                peak_list.append(np.mean(posgroup))
            elif len(candidate_list) == 1:
                # a single best position
                peak_list.append(candidate_list[0])
            else:
                # several positions tie: let the slope criterion decide
                longest_slope_index = LongestSlopeIndexFunc(
                    candidate_list, local_maxima_list, e3list)
                peak_list.append(local_maxima_list[longest_slope_index])

            if debug:
                print('SWT peak pos: %s' % (peak_list[-1],))
                pdb.set_trace()
        # return list of label peaks
        return self.peak_dict[label]

    def get_peak_list(self):
        '''(Re)group the raw results, then run the T- and P-peak extractors.

        Always regroups first, so the extractors see fresh groups. Returns
        None; the extractors' results are left wherever they store them
        (presumably ``self.peak_dict`` -- not visible here).
        '''
        self.group_result()
        for extract_peaks in (self.get_T_peaklist, self.get_P_peaklist):
            extract_peaks()
class EvaluationMultiLeads:
    '''Evaluation of raw detection result by random forest.

    Compares the two per-lead detection lists of one record with the QTdb
    expert annotations for a single target label. Typical use::

        eva = EvaluationMultiLeads(converter)
        eva.loadlabellist(json_path, 'T')
        eva.evaluate('T')
        eva.errList, eva.getFNlist(), eva.getFPlist()
    '''

    # Tableau-20 palette as 0-255 RGB triples; scaled to [0, 1] in __init__.
    _TABLEAU20 = [(31, 119, 180), (174, 199, 232), (255, 127, 14),
                  (255, 187, 120), (44, 160, 44), (152, 223, 138),
                  (214, 39, 40), (255, 152, 150), (148, 103, 189),
                  (197, 176, 213), (140, 86, 75), (196, 156, 148),
                  (227, 119, 194), (247, 182, 210), (127, 127, 127),
                  (199, 199, 199), (188, 189, 34), (219, 219, 141),
                  (23, 190, 207), (158, 218, 229)]

    # Sample-to-millisecond factor applied to errors. 4 ms/sample means a
    # 250 Hz sampling rate (QTdb's rate) -- NOTE(review): confirm if reused.
    _MS_PER_SAMPLE = 4.0

    def __init__(self, result_converter=None):
        '''
        Args:
            result_converter: optional callable applied to the JSON object
                read by ``loadlabellist`` to bring it into the documented
                format; ``None`` means the file is already in that format.
        '''
        self.QTdb = QTloader()
        self.labellist = []
        self.expertlist = []
        self.recname = None
        self.prdMatchList = []
        self.expMatchList = []
        # Converter that formatting given result format.
        self.result_converter_ = result_converter
        # color schemes (matplotlib expects RGB components in [0, 1])
        self.colors = [tuple(component / 255.0 for component in rgb)
                       for rgb in self._TABLEAU20]

        # error List: [expPos - prdPos, ...] in ms
        self.errList = None

    def clear(self):
        '''Reset the per-record state.'''
        # error List: [expPos - prdPos, ...] in ms
        self.errList = None
        self.labellist = []
        self.expertlist = []
        self.recname = None
        self.prdMatchList = []
        self.expMatchList = []

    def loadlabellist(self, filename, TargetLabel, supress_warning=False):
        '''Load detections and expert positions of *TargetLabel*.

        After conversion by ``result_converter`` (when given) the result
        must have the format:
            {
                'recname': ...,
                'LeadResult': [
                    {'P': [...], 'T': [...], ...},   # lead I
                    {'P': [...], 'T': [...], ...},   # lead II
                ],
            }

        Sets ``recname``, ``leadResults``, ``leadPosList`` (lead I and
        lead II positions; a lead without *TargetLabel* contributes ``[]``)
        and ``expertlist`` (expert positions labelled *TargetLabel*).
        '''
        with open(filename, 'r') as fin:
            Res = json.load(fin)
        # Convert result format. The default converter is None, which used
        # to be called unconditionally and raise TypeError.
        if self.result_converter_ is not None:
            Res = self.result_converter_(Res)

        self.recname = Res['recname']
        if not supress_warning:
            print('>>loading recname: %s' % (self.recname,))

        self.leadResults = Res['LeadResult']

        # A lead without detections of TargetLabel may omit the key; check
        # each lead separately (previously only lead I was checked, so a
        # missing key in lead II raised KeyError).
        self.leadPosList = (self.leadResults[0].get(TargetLabel, []),
                            self.leadResults[1].get(TargetLabel, []))

        expertlist = self.QTdb.getexpertlabeltuple(self.recname)
        self.expertlist = [item[0] for item in expertlist
                           if item[1] == TargetLabel]

    def getMatchList(self, prdposlist, MaxMatchDist=50):
        '''Match every expert position to its nearest prediction.

        NOTE: sorts *prdposlist* and ``self.expertlist`` IN PLACE. The
        returned indices refer to the sorted lists, and ``get_errList``
        relies on ``leadPosList`` having been sorted this way.

        Args:
            prdposlist: predicted positions of one lead (list).
            MaxMatchDist: a match needs a distance strictly below this many
                samples.

        Returns:
            (expMatchList, prdMatchList): ``expMatchList[i]`` is the index
            of the prediction matched to expert label i, or -1;
            ``prdMatchList[j]`` is the index of the expert label last
            matched to prediction j, or -1.
        '''
        self.expertlist.sort()
        if len(prdposlist) == 0:
            # No predictions: every expert label is unmatched. Returning a
            # full-length list keeps get_errList's zip() from dropping all
            # expert labels (and their false negatives).
            return ([-1] * len(self.expertlist), [])
        prdposlist.sort()

        prdMatchList = [-1] * len(prdposlist)
        expMatchList = [-1] * len(self.expertlist)

        N_prdList = len(prdposlist)
        for expInd, exppos in enumerate(self.expertlist):
            insertPos = bisect.bisect_left(prdposlist, exppos)
            # nearest prediction on each side; min() keeps the left one on ties
            neighbours = [ind for ind in (insertPos - 1, insertPos)
                          if 0 <= ind < N_prdList]
            matchInd = min(neighbours,
                           key=lambda ind: abs(exppos - prdposlist[ind]))
            if abs(exppos - prdposlist[matchInd]) < MaxMatchDist:
                expMatchList[expInd] = matchInd
                prdMatchList[matchInd] = expInd
        return (expMatchList, prdMatchList)

    def evaluate(self, TargetLabel):
        '''Match both leads against the expert labels and compute errors.

        *TargetLabel* is unused; the label is fixed by ``loadlabellist``.
        '''
        self.prdMatchList = []
        self.expMatchList = []

        # get 2 Leads match lists
        for lead_positions in self.leadPosList[:2]:
            expMatchList, prdMatchList = self.getMatchList(lead_positions)
            self.prdMatchList.append(prdMatchList)
            self.expMatchList.append(expMatchList)

        # get error statistics
        self.get_errList()

    def getFNlist(self):
        '''Return the number of expert labels matched in neither lead.'''
        return self.FNcnt

    def getFPlist(self):
        '''Return the smaller of the two leads' false-positive counts.'''
        return min(self.FPcnt1, self.FPcnt2)

    def plot_evaluation_result(self):
        '''Plot lead I with expert labels and both leads' predictions.'''
        sigStruct = self.QTdb.load(self.recname)
        rawSig = sigStruct['sig']

        plt.figure(1)
        plt.subplot(211)
        plt.plot(rawSig)

        # plot expert labels
        expPosList = self.expertlist
        plt.plot(expPosList, [rawSig[x] for x in expPosList], 'd',
                 color=self.colors[2], markersize=12, label='Expert Label')

        # predictions of lead I and lead II
        for lead_index, color_index in ((0, 3), (1, 4)):
            prdPosList = self.leadPosList[lead_index]
            plt.plot(prdPosList, [rawSig[int(x)] for x in prdPosList], '*',
                     color=self.colors[color_index],
                     label='prd{}'.format(lead_index + 1),
                     markersize=12)

        plt.grid(True)
        plt.legend()
        plt.title(self.recname)
        plt.show()

    def getContinousRangeList(self, recname):
        '''Load the [start, end) ranges where *recname* is continuously annotated.'''
        FileFolder = os.path.join(projhomepath, 'QTdata',
                                  'ContinousExpertMarkRangeList',
                                  '{}_continousRange.json'.format(recname))
        with open(FileFolder, 'r') as fin:
            range_list = json.load(fin)
        return range_list

    def get_errList(self):
        '''Choose the minimum error between the two leads, output is in ms.

        For each expert label, the error (expert - prediction) of the lead
        with the smaller absolute error is kept (lead II wins ties). Also
        sets ``FNcnt`` (expert labels matched in neither lead) and
        ``FPcnt1``/``FPcnt2`` (unmatched predictions of each lead that lie
        inside the record's continuous expert-annotated ranges).

        Returns:
            ``self.errList``.
        '''
        self.errList = []
        FNcnt = 0
        for expPos, matchInd1, matchInd2 in zip(
                self.expertlist, self.expMatchList[0], self.expMatchList[1]):
            if matchInd1 == -1 and matchInd2 == -1:
                FNcnt += 1
                continue
            # lead II first so that min() keeps it on ties, as before
            candidate_errors = []
            if matchInd2 != -1:
                candidate_errors.append(expPos - self.leadPosList[1][matchInd2])
            if matchInd1 != -1:
                candidate_errors.append(expPos - self.leadPosList[0][matchInd1])
            best_err = min(candidate_errors, key=abs)
            self.errList.append(self._MS_PER_SAMPLE * best_err)

        # Total number of False Negtives
        self.FNcnt = FNcnt

        # Exclude FP that are not in the Continous Range. An interval test
        # (rather than a set of integer samples) also handles non-integer
        # positions and avoids building a large set.
        range_list = self.getContinousRangeList(self.recname)

        def in_continuous_range(pos):
            return any(r[0] <= pos < r[1] for r in range_list)

        def count_false_positives(lead_index):
            return sum(1 for prdpos, match_index in zip(
                sorted(self.leadPosList[lead_index]),
                self.prdMatchList[lead_index])
                       if match_index == -1 and in_continuous_range(prdpos))

        self.FPcnt1 = count_false_positives(0)
        self.FPcnt2 = count_false_positives(1)

        return self.errList

    def get_total_mean(self):
        '''Mean of ``errList`` (ms).'''
        return np.mean(self.errList)

    def get_total_stdvar(self):
        '''Standard deviation of ``errList`` (ms).'''
        return np.std(self.errList)
class PointBrowser(object):
    """
    Interactive viewer for two-lead QTdb delineation results.

    ``ax`` shows lead I; ``ax2`` shows lead II (``DrawLeadII``) or lead I
    with raw results (``DrawRawResults``). Both are overlaid with expert
    labels, predicted labels and SWT detail coefficients. Keys handled by
    ``onpress``: 'n' next record, ' ' redraw, 'x' undo the last white-region
    mark, 'a'/'d' scroll both axes by -200/+200 samples ('p' is accepted but
    only refreshes the canvas). Picking a point on a plotted signal
    (``onpick``) marks the start, then the end, of a "white region".

    NOTE(review): a second ``class PointBrowser`` is defined later in this
    module and replaces this one at import time, so this definition is not
    reachable by name unless one of them is renamed.
    """

    # Tableau-20 palette as 0-255 RGB triples; scaled to [0, 1] in __init__.
    _TABLEAU20 = [(31, 119, 180), (174, 199, 232), (255, 127, 14),
                  (255, 187, 120), (44, 160, 44), (152, 223, 138),
                  (214, 39, 40), (255, 152, 150), (148, 103, 189),
                  (197, 176, 213), (140, 86, 75), (196, 156, 148),
                  (227, 119, 194), (247, 182, 210), (127, 127, 127),
                  (199, 199, 199), (188, 189, 34), (219, 219, 141),
                  (23, 190, 207), (158, 218, 229)]

    # Colour indices (into self.colors) by first letter of a label.
    _RESULT_PALETTE = {'T': 7, 'P': 9, 'R': 12}
    _EXPERT_PALETTE = {'T': 4, 'P': 5, 'R': 6}
    # Colour index for labels whose first letter is none of T/P/R.
    _FALLBACK_COLOR = 14

    def __init__(self, fig, ax, ax2):
        self.fig = fig
        self.ax = ax
        self.ax2 = ax2
        self.text = self._make_status_text()
        # ============================
        # QTdb
        self.QTdb = QTloader()

        # Round Index
        RoundIndex = 1
        target_record_name = 'sel30'
        self.recInd = 0

        # glob() order is arbitrary; sort so record indices are reproducible.
        self.resultlist = sorted(glob.glob(
            os.path.join(curfolderpath, 'M4_SWT',
                         'SWT_GroupRound_T_improved{}'.format(RoundIndex),
                         '*.json')))
        self.raw_resultlist = sorted(glob.glob(
            os.path.join(
                r'F:\LabGit\ECG_RSWT\TestResult\paper\MultiRound4\Round{}'.
                format(RoundIndex), 'result_*')))

        # start on the target record when it is present
        target_index = self._find_record_index(target_record_name)
        if target_index is not None:
            self.recInd = target_index
        self._load_record()

        self.colors = [tuple(component / 255.0 for component in rgb)
                       for rgb in self._TABLEAU20]
        # ===========================
        # Mark list: [[start, end], ...]; end == -1 while a region is open
        self.whiteRegionList = []
        self.totalWhiteCount = 0

        self.getCoefList()
        self.GetExpertLabelRange()

    # ------------------------------------------------------------------
    # record handling
    @staticmethod
    def _recname_from_path(filepath):
        '''Record name of a result file: its basename up to '.json'.'''
        return os.path.split(filepath)[-1].split('.json')[0]

    def _find_record_index(self, recname):
        '''Index of *recname* in ``self.resultlist``, or None if absent.'''
        for ind, filepath in enumerate(self.resultlist):
            if self._recname_from_path(filepath) == recname:
                return ind
        print('record name {} not found!'.format(recname))
        return None

    def _load_record(self):
        '''Load signals and expert labels of the record at ``self.recInd``.'''
        self.recname = self._recname_from_path(self.resultlist[self.recInd])
        self.sigStruct = self.QTdb.load(self.recname)
        self.rawSig = self.sigStruct['sig']
        self.rawSig2 = self.sigStruct['sig2']
        self.expLabels = self.QTdb.getexpertlabeltuple(self.recname)

    def _raw_result_path(self):
        '''Raw-result file of the current record.

        Looks for 'result_<recname>' first and falls back to the file at
        ``self.recInd`` (the original positional lookup).
        '''
        wanted = 'result_' + self.recname
        for path in self.raw_resultlist:
            if os.path.basename(path) == wanted:
                return path
        return self.raw_resultlist[self.recInd]

    def next_record(self):
        '''Advance to the next result file and reload signals and SWT coefs.

        Stays on the last record (with a message) instead of running past
        the end of ``resultlist``.
        '''
        if self.recInd + 1 >= len(self.resultlist):
            print('Already at the last record ({}).'.format(self.recname))
            return
        self.recInd += 1
        self._load_record()

        # SWT coef
        self.getCoefList()

    def GetExpertLabelRange(self):
        '''Span of the expert labels widened by 1000 samples on each side.

        The result is clipped to ``[0, len(self.rawSig) - 1]``, stored in
        ``self.expertlabel_range`` and returned as ``[first, last]``.
        '''
        positions = sorted(item[0] for item in self.expLabels)
        first = max(0, positions[0] - 1000)
        last = min(len(self.rawSig) - 1, positions[-1] + 1000)
        self.expertlabel_range = [first, last]
        return self.expertlabel_range

    # ------------------------------------------------------------------
    # SWT
    def crop_data_for_swt(self, rawsig):
        '''Pad *rawsig* with its last sample up to the next power of two.

        ``pywt.swt`` needs a length divisible by 2**level; the next power of
        two satisfies that once it is at least 2**level. The input sequence
        is not modified.

        Raises:
            Exception: if *rawsig* has fewer than two samples.
        '''
        N_data = len(rawsig)
        if N_data <= 1:
            raise Exception(
                'len(rawsig)={},not enough for swt!'.format(N_data))
        # smallest power of two >= N_data
        crop_len = 1 << (N_data - 1).bit_length()
        if N_data < crop_len:
            # build a new list instead of extending the caller's in place
            rawsig = list(rawsig) + [rawsig[-1]] * (crop_len - N_data)
        return rawsig[0:crop_len]

    def getCoefList(self, MaxLevel=9):
        '''Pad both leads and compute their db6 SWT up to *MaxLevel*.'''
        self.rawSig = self.crop_data_for_swt(self.rawSig)
        self.coeflist1 = pywt.swt(self.rawSig, 'db6', MaxLevel)

        self.rawSig2 = self.crop_data_for_swt(self.rawSig2)
        self.coeflist2 = pywt.swt(self.rawSig2, 'db6', MaxLevel)

    # ------------------------------------------------------------------
    # SWT plots
    def PlotTpeakDetermineLines(self):
        '''T wave judge WT coef: plot details cD[-5] and cD[-6] of both leads.'''
        _, cDlist1 = zip(*self.coeflist1)
        _, cDlist2 = zip(*self.coeflist2)

        for color_index, detail_ind in enumerate((-5, -6), 15):
            color = self.colors[color_index]
            label = 'detail{}'.format(detail_ind)
            self.ax.plot(cDlist1[detail_ind], color=color, label=label)
            self.ax.legend()
            self.ax2.plot(cDlist2[detail_ind], color=color, label=label)
            self.ax2.legend()

        # update draw
        self.fig.canvas.draw()

    def plot_e4list(self):
        '''T wave judge WT coef: plot detail cD[-7] of both leads.'''
        _, cDlist = zip(*self.coeflist1)
        self.ax.plot(np.array(cDlist[-7]), color=self.colors[15],
                     label='e4list')
        self.ax.legend()

        _, cDlist = zip(*self.coeflist2)
        self.ax2.plot(np.array(cDlist[-7]), color=self.colors[15],
                      label='e4list II')
        self.ax2.legend()

        # update draw
        self.fig.canvas.draw()

    def plot_e3list(self):
        '''P wave judge WT coef: plot detail cD[-4] plus expert labels on both leads.'''
        _, cDlist = zip(*self.coeflist1)
        e3list = np.array(cDlist[-4])
        self.ax.plot(e3list, color=self.colors[15], label='e3list')
        self.plotExpertLabels(self.ax, e3list)
        self.ax.legend()

        _, cDlist = zip(*self.coeflist2)
        e3list = np.array(cDlist[-4])
        self.ax2.plot(e3list, color=self.colors[15], label='e3list II')
        self.ax2.legend()
        # mark expert label position on e3list
        self.plotExpertLabels(self.ax2, e3list)

        # update canvas
        self.fig.canvas.draw()

    # ------------------------------------------------------------------
    # interaction
    def onpress(self, event):
        '''Key handler; see the class docstring for the bindings.'''
        if event.key not in ('n', 'p', ' ', 'x', 'a', 'd'):
            return

        if event.key == 'n':
            self.saveWhiteMarkList2Json()
            self.next_record()
            self.clearWhiteMarkList()
            self.reDraw()
            return None
        elif event.key == ' ':
            self.reDraw()
            return None
        elif event.key == 'x':
            self._undo_last_white_region()
        elif event.key in ('a', 'd'):
            self._scroll(-200 if event.key == 'a' else 200)

        self.update()

    def _undo_last_white_region(self):
        '''Drop the last white region and undo its contribution to the count.'''
        if not self.whiteRegionList:
            return
        start, end = self.whiteRegionList.pop()
        # Only closed regions were counted, with inclusive length (addMarkx).
        if end != -1:
            self.totalWhiteCount -= end - start + 1

    def _scroll(self, step):
        '''Shift the x-limits of both axes by *step* samples.'''
        xlims = self.ax.get_xlim()
        new_xlims = [xlims[0] + step, xlims[1] + step]
        self.ax.set_xlim(new_xlims)
        self.ax2.set_xlim(new_xlims)

    def saveWhiteMarkList2Json(self):
        '''Persist the white-region marks (currently a no-op).'''

    def clearWhiteMarkList(self):
        '''Forget all white-region marks.'''
        self.whiteRegionList = []
        self.totalWhiteCount = 0

    def addMarkx(self, x):
        '''Record a white-region boundary at sample *x*.

        The first call opens a region ``[start, -1]``. The next call closes
        it, orders it as ``[low, high]``, adds its inclusive length to
        ``totalWhiteCount`` and draws it on ``ax``.
        '''
        if not self.whiteRegionList or self.whiteRegionList[-1][-1] != -1:
            self.whiteRegionList.append([int(x), -1])
            return

        start = self.whiteRegionList[-1][0]
        end = int(x)
        if end < start:
            start, end = end, start
        self.whiteRegionList[-1] = [start, end]
        self.totalWhiteCount += end - start + 1

        # draw markers; samples outside the signal are drawn at 0
        N_rawsig = len(self.rawSig)
        xlist = range(start, end + 1)
        ylist = [self.rawSig[xval] if 0 <= xval < N_rawsig else 0
                 for xval in xlist]
        self.ax.plot(xlist,
                     ylist,
                     lw=6,
                     color=self.colors[0],
                     alpha=0.3,
                     label='whiteRegion')

    def onpick(self, event):
        '''Pick handler: add a white-region boundary at the picked x.'''
        self.addMarkx(event.mouseevent.xdata)
        self.fig.canvas.draw()

    def update(self):
        '''Redraw the figure canvas.'''
        self.fig.canvas.draw()

    # ------------------------------------------------------------------
    # drawing
    def _make_status_text(self):
        '''Create the status text artist at the top-left of ``ax``.'''
        return self.ax.text(0.05,
                            0.95,
                            'selected: none',
                            transform=self.ax.transAxes,
                            va='top')

    def _reset_axes(self, ax2_title, ax2_signal):
        '''Clear both axes, restore grid/status text, plot the base signals.'''
        ax = self.ax
        ax.cla()
        self.ax2.cla()

        self.text = self._make_status_text()
        ax.grid(color=(0.8, 0.8, 0.8), linestyle='--', linewidth=2)
        self.ax2.grid(color=(0.8, 0.8, 0.8), linestyle='--', linewidth=2)

        ax.set_title('QT {} (Index = {})'.format(self.recname, self.recInd))
        ax.plot(self.rawSig, picker=5)  # 5 points tolerance

        self.ax2.set_title(ax2_title)
        self.ax2.plot(ax2_signal, picker=5)  # 5 points tolerance

    def DrawRawResults(self):
        ''' Draw Raw results on ax2 to compare the grouping result.'''
        self._reset_axes('Lead I Raw Results', self.rawSig)

        self.plotExpertLabels(self.ax, self.rawSig)
        self.plotExpertLabels(self.ax2, self.rawSig)
        self.plotResultLabels(self.ax, 0, self.rawSig)
        self.plotRawResultLabels(self.ax2, 1, self.rawSig)
        self.plotErrorList(self.ax, 0, self.rawSig, TargetLabel='T')
        self.PlotTpeakDetermineLines()

        self.fig.canvas.draw()

    def reDraw(self):
        '''Redraw the current record (lead I / lead II view).'''
        self.DrawLeadII()

    def DrawLeadII(self):
        '''Draw lead I on ``ax`` and lead II on ``ax2`` with all overlays.'''
        self._reset_axes('Lead II', self.rawSig2)

        self.plotExpertLabels(self.ax, self.rawSig)
        self.plotExpertLabels(self.ax2, self.rawSig2)
        self.plotResultLabels(self.ax, 0, self.rawSig)
        self.plotResultLabels(self.ax2, 1, self.rawSig2)
        self.plotErrorList(self.ax, 0, self.rawSig, TargetLabel='T')
        self.PlotTpeakDetermineLines()

        self.fig.canvas.draw()

    def plotErrorList(self, ax, leadnum, rawSig, TargetLabel='P'):
        '''Annotate each expert *TargetLabel* position with its error.

        Evaluates the current result file with ``Evaluation2Leads``;
        expert labels matched in neither lead are annotated with -1. The
        error mean/std is written to the status text. *leadnum* is unused.
        '''
        resultfilename = self.resultlist[self.recInd]
        label = TargetLabel

        # Evaluation
        eva = Evaluation2Leads()
        eva.loadlabellist(resultfilename, label)
        eva.evaluate(label)

        eva_errorList = eva.errList
        eva_expMatchList = eva.expMatchList

        # errList only holds matched labels; re-expand to one entry per
        # expert label, using -1 for labels matched in neither lead.
        errorList = []
        errInd = 0
        for matchInd1, matchInd2 in zip(eva_expMatchList[0],
                                        eva_expMatchList[1]):
            if matchInd1 == -1 and matchInd2 == -1:
                errorList.append(-1)
            else:
                errorList.append(eva_errorList[errInd])
                errInd += 1
        if errInd != len(eva_errorList):
            print('errInd != len(eva_errorList)')
            pdb.set_trace()
            raise Exception('error Ind != len(eva_errorList)')

        # expert poslist with Target Label
        TargetExpPosList = [item[0] for item in self.expLabels
                            if item[1] == label]
        if len(errorList) != len(TargetExpPosList):
            print('if len(errorList) != len(TargetExpPosList)')
            pdb.set_trace()
            raise Exception('if len(errorList) != len(TargetExpPosList)')

        for errVal, expPos in zip(errorList, TargetExpPosList):
            ax.annotate('error[{}]'.format(errVal),
                        xy=(expPos, rawSig[expPos]),
                        # text sits 0.9 below the annotated sample
                        xytext=(expPos, rawSig[expPos] - 0.9),
                        xycoords='data',
                        textcoords='data',
                        arrowprops=dict(arrowstyle='->',
                                        connectionstyle='arc3,rad = -0.5'))

        # disp error mean and std
        self.text.set_text('Error mean,std = ({0:.3f},{1:.3f})'.format(
            np.mean(eva_errorList), np.std(eva_errorList)))

    # ------------------------------------------------------------------
    # label plots
    @staticmethod
    def _group_by_label(pos_label_pairs):
        '''Return {label: [pos, ...]} from (pos, label) pairs, keeping order.'''
        grouped = dict()
        for pos, label in pos_label_pairs:
            grouped.setdefault(label, []).append(pos)
        return grouped

    def _label_style(self, label, palette, other_marker):
        '''Return (color, marker) for a label such as 'T' or 'Ponset'.

        The colour comes from *palette* by first letter (fallback colour for
        any other letter). The marker is '<' for onsets, '>' for offsets,
        'o' for single-letter labels and *other_marker* otherwise.
        '''
        color = self.colors[palette.get(label[:1], self._FALLBACK_COLOR)]
        if 'onset' in label:
            marker = '<'
        elif 'offset' in label:
            marker = '>'
        elif len(label) == 1:
            marker = 'o'
        else:
            marker = other_marker
        return color, marker

    def _plot_prediction_groups(self, ax, labelDict, ycoord):
        '''Plot predicted positions grouped by label; ycoord(pos) gives the height.'''
        for label, posList in labelDict.items():
            color, marker = self._label_style(label, self._RESULT_PALETTE, '.')
            ax.plot(posList,
                    [ycoord(x) for x in posList],
                    marker=marker,
                    color=color,
                    linestyle='none',
                    markersize=6,
                    label='Pred[{}]'.format(label),
                    alpha=0.8)
        ax.legend(numpoints=1)

    def plotRawResultLabels(self, ax, prdInd, rawSig):
        '''Plot the ungrouped (raw) predictions of entry *prdInd* on *ax*.'''
        with open(self._raw_result_path(), 'r') as fin:
            resStruct = json.load(fin)

        resLabels = resStruct[prdInd][1]
        self._plot_prediction_groups(ax, self._group_by_label(resLabels),
                                     lambda x: rawSig[x])

    def plotResultLabels(self, ax, prdInd, rawSig, resultlist_input=None):
        '''Plot the grouped predictions of lead *prdInd* on *ax*.

        *resultlist_input* defaults to ``self.resultlist``; the file at
        ``self.recInd`` is read.
        '''
        if resultlist_input is None:
            resultlist_input = self.resultlist

        with open(resultlist_input[self.recInd], 'r') as fin:
            resStruct = json.load(fin)

        labelDict = resStruct['LeadResult'][prdInd]
        self._plot_prediction_groups(ax, labelDict,
                                     lambda x: rawSig[int(x)])

    def plotExpertLabels(self, ax, rawSig):
        '''Plot the current record's expert labels on *ax* at rawSig heights.'''
        grouped = self._group_by_label(self.expLabels)
        for label, posList in grouped.items():
            # labels not starting with T/P/R previously left `color` unbound
            # (or stale from the previous label); they now get a fallback.
            color, marker = self._label_style(label, self._EXPERT_PALETTE,
                                              'o')
            ax.plot(posList,
                    [rawSig[x] for x in posList],
                    marker=marker,
                    color=color,
                    linestyle='none',
                    markersize=14,
                    label=label)
        ax.legend(numpoints=1)
class PointBrowser(object):
    """
    Interactive two-axes viewer for marking "white regions" on QT records.

    ``ax`` shows lead 1 (``sig``) and ``ax2`` shows lead 2 (``sig2``) of the
    current record, each overlaid with the predicted labels read from the
    json result file.  Presumably ``onpress`` and ``onpick`` are connected
    by the caller to the figure's key-press and pick events.

    * Picking a point on a trace adds a region boundary: the first pick
      starts a region, the second one closes it and highlights it on ``ax``.
    * ``n``: go to the next record of the QT database and clear the marks.
    * space: redraw the current record.
    * ``x``: undo the most recent (complete or half-open) region.
    * ``a`` / ``d``: scroll both axes 200 samples left / right.
    """
    def __init__(self, fig, ax, ax2):
        self.fig = fig
        self.ax = ax
        self.ax2 = ax2

        self.SaveFolder = os.path.join(curfolderpath, 'QTwhiteMarkList')

        self.text = self.ax.text(0.05,
                                 0.95,
                                 'selected: none',
                                 transform=self.ax.transAxes,
                                 va='top')
        # ============================
        # QTdb
        self.QTdb = QTloader()
        self.reclist = self.QTdb.reclist
        self.result_folder_path = os.path.join(os.path.dirname(curfolderpath),
                                               'MultiLead4', 'GroupRound2')

        # NOTE(review): recInd starts at 0 while recname is pinned to
        # 'sel116', so the title shows index 0 and 'n' jumps to reclist[1]
        # rather than to the record after 'sel116' -- confirm intended.
        self.recInd = 0
        self.recname = 'sel116'
        self.result_file_path = os.path.join(self.result_folder_path,
                                             '{}.json'.format(self.recname))
        self.sigStruct = self.QTdb.load(self.recname)
        self.rawSig = self.sigStruct['sig']
        self.rawSig2 = self.sigStruct['sig2']
        self.expLabels = self.QTdb.getexpertlabeltuple(self.recname)

        tableau20 = [(31, 119, 180), (174, 199, 232), (255, 127, 14),
                     (255, 187, 120), (44, 160, 44), (152, 223, 138),
                     (214, 39, 40), (255, 152, 150), (148, 103, 189),
                     (197, 176, 213), (140, 86, 75), (196, 156, 148),
                     (227, 119, 194), (247, 182, 210), (127, 127, 127),
                     (199, 199, 199), (188, 189, 34), (219, 219, 141),
                     (23, 190, 207), (158, 218, 229)]
        # tableau20 palette scaled to matplotlib's [0, 1] RGB range
        self.colors = [(r / 255.0, g / 255.0, b / 255.0)
                       for r, g, b in tableau20]

        # The white-region marks must exist before the first pick / 'x' key.
        self.clearWhiteMarkList()

    def onpress(self, event):
        '''Handle a key press (see the class docstring for the bindings).'''
        if event.key not in ('n', 'p', ' ', 'x', 'a', 'd'):
            return

        if event.key == 'n':
            self.saveWhiteMarkList2Json()
            # At the last record nothing changes, so the marks are kept.
            if self.next_record():
                # clear Marker List
                self.clearWhiteMarkList()
                self.reDraw()
            return None
        elif event.key == ' ':
            self.reDraw()
            return None
        elif event.key == 'x':
            self.removeLastWhiteRegion()
        elif event.key in ('a', 'd'):
            # scroll both axes together
            step = -200 if event.key == 'a' else 200
            xlims = self.ax.get_xlim()
            new_xlims = [xlims[0] + step, xlims[1] + step]
            self.ax.set_xlim(new_xlims)
            self.ax2.set_xlim(new_xlims)
        # 'p' only refreshes the canvas.
        self.update()

    def saveWhiteMarkList2Json(self):
        '''Placeholder: saving the white-region marks is not implemented.'''
        pass

    def clearWhiteMarkList(self):
        '''Forget all white-region marks and reset the marked-sample count.'''
        # [[startInd, endInd], ...]; endInd == -1 while a region is half-open
        self.whiteRegionList = []
        self.totalWhiteCount = 0

    def removeLastWhiteRegion(self):
        '''Undo the most recent white-region mark, if there is one.

        A complete region ``[start, end]`` gives back the ``end - start + 1``
        samples ``addMarkx`` added to ``totalWhiteCount``; a half-open region
        ``[start, -1]`` was never counted and is simply dropped.
        '''
        if not self.whiteRegionList:
            return
        startInd, endInd = self.whiteRegionList.pop()
        if endInd != -1:
            self.totalWhiteCount -= endInd - startInd + 1

    def addMarkx(self, x):
        '''Add a region boundary at sample ``int(x)``.

        The first call opens a region ``[start, -1]``; the next call closes
        it (swapping the ends if needed), adds its inclusive length to
        ``totalWhiteCount`` and highlights it on ``self.ax``.
        '''
        if not self.whiteRegionList or self.whiteRegionList[-1][-1] != -1:
            self.whiteRegionList.append([int(x), -1])
            return

        self.whiteRegionList[-1][-1] = int(x)
        # keep the pair ordered as [startInd, endInd]
        startInd, endInd = sorted(self.whiteRegionList[-1])
        self.whiteRegionList[-1] = [startInd, endInd]
        self.totalWhiteCount += endInd - startInd + 1

        # draw the region on top of the signal (0 outside the signal range)
        xlist = range(startInd, endInd + 1)
        N_rawsig = len(self.rawSig)
        ylist = [self.rawSig[xval] if 0 <= xval < N_rawsig else 0
                 for xval in xlist]
        self.ax.plot(xlist,
                     ylist,
                     lw=6,
                     color=self.colors[0],
                     alpha=0.3,
                     label='whiteRegion')

    def onpick(self, event):
        '''Handle a pick on a trace: add a region boundary at the click x.'''
        self.addMarkx(event.mouseevent.xdata)
        self.text.set_text(
            'Marking region: ({}) [whiteCnt {}][expertCnt {}]'.format(
                self.whiteRegionList[-1], self.totalWhiteCount,
                len(self.expLabels)))

        # update canvas
        self.fig.canvas.draw()

    def reDraw(self):
        '''Clear both axes and redraw the signals with predicted labels.'''
        ax = self.ax
        ax.cla()
        self.ax2.cla()

        self.text = self.ax.text(0.05,
                                 0.95,
                                 'selected: none',
                                 transform=self.ax.transAxes,
                                 va='top')
        ax.grid(color=(0.8, 0.8, 0.8), linestyle='--', linewidth=2)
        self.ax2.grid(color=(0.8, 0.8, 0.8), linestyle='--', linewidth=2)

        # ====================================
        # load ECG signal
        # ====================================
        ax.set_title('QT {} (Index = {})'.format(self.recname, self.recInd))
        ax.plot(self.rawSig, picker=5)  # 5 points tolerance

        # plot ax2
        self.ax2.set_title('QT record {}:Lead V1'.format(self.recname))
        self.ax2.plot(self.rawSig2, picker=5)  # 5 points tolerance

        # Expert labels are currently not drawn (see plotExpertLabels).
        # plot Result labels
        self.plotResultLabels(ax, 0, self.rawSig)
        self.plotResultLabels(self.ax2, 1, self.rawSig2)

        # update draw
        self.fig.canvas.draw()

    def update(self):
        '''Redraw the figure canvas.'''
        self.fig.canvas.draw()

    def next_record(self):
        '''When press n, load the next record in the QT database.

        Returns:
            True if a new record was loaded, False when already at the last
            record (state is left unchanged).
        '''
        if self.recInd + 1 >= len(self.reclist):
            print('Already at the last record of the QT database.')
            return False
        self.recInd += 1
        self.recname = self.reclist[self.recInd]
        # NOTE(review): result_file_path is not updated, so every record is
        # overlaid with the 'sel116' result labels -- confirm intended.
        self.sigStruct = self.QTdb.load(self.recname)
        self.rawSig = self.sigStruct['sig']
        self.rawSig2 = self.sigStruct['sig2']
        self.expLabels = self.QTdb.getexpertlabeltuple(self.recname)
        return True

    def plotResultLabels(self, ax, prdInd, rawSig):
        '''Plot the final output labels of lead ``prdInd`` on ``ax``.

        The result json (``self.result_file_path``) is re-read on each call.
        Its format is:
            {"LeadResult": [{label: poslist}, {label: poslist}],
             "recname": "sel100"}
        Positions are cast to int and used to index ``rawSig``.
        '''
        with open(self.result_file_path, 'r') as fin:
            labelDict = json.load(fin)["LeadResult"][prdInd]
        # plot to axes, sorted by label name
        for label, posList in sorted(labelDict.items(),
                                     key=lambda item: item[0]):
            posList = [int(pos) for pos in posList]
            # color by wave type
            if label[0] == 'T':
                color = self.colors[1]
            elif label[0] == 'P':
                color = self.colors[4]
            elif label[0] == 'R':
                color = self.colors[9]
            else:
                color = self.colors[14]
            # marker by boundary type
            if 'onset' in label:
                marker = '<'
            elif 'offset' in label:
                marker = '>'
            elif len(label) == 1:
                marker = 'o'
            else:
                marker = '.'
            ax.plot(posList,
                    [rawSig[pos] for pos in posList],
                    marker=marker,
                    color=color,
                    linestyle='none',
                    markersize=14,
                    label='Pred[{}]'.format(label),
                    alpha=0.8)
        ax.legend(numpoints=1)

    def plotExpertLabels(self, ax, rawSig):
        '''Plot the expert annotations ``self.expLabels`` on ``ax``.

        T/P/R labels use ``self.colors[4/5/6]``; any other label falls back
        to ``self.colors[14]``.  Onsets are drawn as ``<``, offsets as ``>``
        and everything else as ``o``.
        '''
        # group positions by label name
        labelDict = {}
        for pos, label in self.expLabels:
            labelDict.setdefault(label, []).append(pos)

        # plot to axes
        for label, posList in labelDict.items():
            if label[0] == 'T':
                color = self.colors[4]
            elif label[0] == 'P':
                color = self.colors[5]
            elif label[0] == 'R':
                color = self.colors[6]
            else:
                color = self.colors[14]
            if 'onset' in label:
                marker = '<'
            elif 'offset' in label:
                marker = '>'
            else:
                marker = 'o'
            ax.plot(posList,
                    [rawSig[pos] for pos in posList],
                    marker=marker,
                    color=color,
                    linestyle='none',
                    markersize=14,
                    label=label)
        ax.legend(numpoints=1)
def PlotrawPredictionLabels(picklefilename,
                            target_recname=None,
                            ResID=0,
                            showExpertLabel=True,
                            xLimtoLabelRange=True):
    '''Plot the raw prediction labels of one pickled result on its QT signal.

    Args:
        picklefilename: pickle file holding a list of results; each result
            is indexed as ``result[0]`` = QT record name and ``result[1]`` =
            sequence of ``(pos, label)`` predictions.
        target_recname: when not None, only plot if result ``ResID`` belongs
            to this record.
        ResID: index of the result to plot inside the pickled list.
        showExpertLabel: also draw the expert labels as black vertical lines.
        xLimtoLabelRange: limit the x axis to the predicted-label range with
            100 samples of margin on each side.

    Returns None.  NOTE: unpickling can execute arbitrary code -- only load
    result files you produced yourself.
    '''
    # Binary mode: pickles written with protocol >= 1 are binary data.
    with open(picklefilename, 'rb') as fin:
        Results = pickle.load(fin)
    recname = Results[ResID][0]
    # only plot target rec
    if target_recname is not None:
        print('Result[{}] QTrecord name:{}'.format(ResID, recname))
        if recname != target_recname:
            return

    rawReslist = Results[ResID][1]
    if not rawReslist:
        # Nothing to plot and no label range to zoom to.
        print('Result[{}] of record {} has no prediction labels.'.format(
            ResID, recname))
        return

    #
    # show raw results on the signal
    #
    recLoader = QTloader()
    rawsig = recLoader.load(recname)['sig']

    # plot signal
    plt.figure(1)
    plt.plot(rawsig)

    # group (pos, amplitude) points by the plot marker of their label
    PlotMarkerdict = {mker: [] for mker in PlotMarkerList()}
    for res in rawReslist:
        pos = res[0]
        PlotMarkerdict[Label2PlotMarker(res[1])].append((pos, rawsig[pos]))

    # x range of the predicted labels (for xlim)
    Label_xmin = Label_xmax = rawReslist[0][0]
    for mker, posAmpList in PlotMarkerdict.items():
        if not posAmpList:
            continue
        poslist, Amplist = zip(*posAmpList)
        Label_xmin = min(Label_xmin, min(poslist))
        Label_xmax = max(Label_xmax, max(poslist))
        plt.plot(poslist, Amplist, mker, label='{} Label'.format(mker))

    # plot expert marks as short vertical lines
    if showExpertLabel:
        for explabel in recLoader.getexpertlabeltuple(recname):
            xpos = explabel[0]
            plt.plot([xpos, xpos], [-10, 10], 'black')
    if xLimtoLabelRange:
        plt.xlim(Label_xmin - 100, Label_xmax + 100)

    plt.xlabel('Samples')
    plt.ylabel('Amplitude')
    plt.title(recname)
    plt.show()
class RecSelector():
    '''Interactive helpers to inspect QT records and choose training records.'''
    def __init__(self):
        self.qt = QTloader()

    def inspect_recs(self):
        '''Plot each non-test record and ask whether to train on it.

        Candidates are the QT records minus ``conf['selQTall0_test_set']``.
        Records answered with ``y`` are dumped as a JSON list to
        ``selected_training_records.json`` in ``dirname(curfilepath)``.
        Records without expert labels are skipped (no display range).
        '''
        reclist = self.qt.getQTrecnamelist()
        set_testing = set(conf['selQTall0_test_set'])
        out_reclist = set(reclist) - set_testing

        # records selected
        selected_record_list = []
        N_out = len(out_reclist)
        for ind, recname in enumerate(out_reclist):
            print('{} records left.'.format(N_out - ind - 1))
            rawsig = self.qt.load(recname)['sig']
            # expert labels, as (pos, label) pairs
            testresult = self.qt.getexpertlabeltuple(recname)
            if not testresult:
                print('Record {} has no expert labels, skipped.'.format(
                    recname))
                continue
            poslist = [pos for pos, _ in testresult]
            dispRange = (min(poslist), max(poslist))
            resplt = ECGResultPloter(rawsig, testresult)
            resplt.plot(plotTitle='QT record {}'.format(recname),
                        dispRange=dispRange)
            usethis = raw_input('Use this record as training record?(y/n):')
            if usethis == 'y':
                selected_record_list.append(recname)
        # NOTE(review): inspect_selrec reads this file from curfolderpath;
        # presumably the same folder -- confirm.
        with open(
                os.path.join(os.path.dirname(curfilepath),
                             'selected_training_records.json'), 'w') as fout:
            json.dump(selected_record_list, fout)

    def pick_invalid_record(self):
        '''Return the names of records whose signal is empty or starts at +/-inf.'''
        invalidrecordlist = []
        for recname in self.qt.getQTrecnamelist():
            print('Inspecting record {}'.format(recname))
            rawsig = self.qt.load(recname)['sig']
            # an empty signal would otherwise raise IndexError here
            if len(rawsig) == 0 or abs(rawsig[0]) == float('inf'):
                invalidrecordlist.append(recname)
        return invalidrecordlist

    def save_recs_to_img(self):
        '''Plot and save every QT record via ``QTloader.PlotAndSaveRec``.'''
        reclist = self.qt.getQTrecnamelist()
        N_rec = len(reclist)
        for ind, recname in enumerate(reclist):
            # count from the list actually iterated (was a different list)
            print('{} records left.'.format(N_rec - ind - 1))
            self.qt.PlotAndSaveRec(recname)

    def inspect_recname(self, tarrecname):
        '''Plot a single QT record.'''
        self.qt.plotrec(tarrecname)

    def inspect_selrec(self):
        '''Plot each record listed in ``selected_training_records.json``.'''
        with open(
                os.path.join(curfolderpath, 'selected_training_records.json'),
                'r') as fin:
            sel0115 = json.load(fin)
        for ind, recname in enumerate(sel0115):
            print('{} records left.'.format(len(sel0115) - ind - 1))
            self.inspect_recname(recname)

    def RFtest(self, testrecname):
        '''Train ECGRF on ``conf['sel1213']``, test on one record, report it.

        Displays the error / false-negative statistics of the filtered result
        and plots the prediction result of the tested record.
        '''
        ecgrf = ECGRF()
        ecgrf.training(conf['sel1213'])
        Results = ecgrf.testing([testrecname])
        # Evaluate result
        filtered_Res = ECGRF.resfilter(Results)
        stats = ECGstats(filtered_Res[0:1])
        Err, FN = stats.eval(debug=False)
        stats.dispstat0(pFN=FN, pErr=Err)
        # plot prediction result
        stats.plotevalresofrec(Results[0][0], Results)
# Example #9
# 0
class NonQRS_SWT:
    '''Stationary wavelet transform of a QT record with QRS spans flattened.'''
    def __init__(self):
        self.QTdb = QTloader()

    def LoadMarkListFromJson(self, jsonfilepath):
        '''Placeholder: loading a mark list from json is not implemented.'''
        pass

    def getNonQRSsig(self, recname, MarkList=None):
        '''Return the lead-1 signal of ``recname`` with QRS spans flattened.

        Every ``Ronset`` immediately followed (in position order) by an
        ``Roffset`` less than 180 samples later has the samples
        ``[Ronset, Roffset)`` replaced by a straight line from the Ronset
        amplitude to the Roffset amplitude.

        Args:
            recname: QT record name.
            MarkList: ``(pos, label)`` pairs; defaults to the expert labels.

        NOTE: the signal returned by ``self.QTdb.load`` is modified in place.
        '''
        # QRS width threshold (samples)
        QRS_width_threshold = 180

        rawsig = self.QTdb.load(recname)['sig']
        if MarkList is None:
            MarkList = self.QTdb.getexpertlabeltuple(recname)

        # Keep only multi-char labels containing 'R' (QRS boundaries),
        # ordered by position.
        boundary_marks = sorted(
            (mark for mark in MarkList if 'R' in mark[1] and len(mark[1]) > 1),
            key=lambda mark: mark[0])

        for (pos, label), (next_pos, next_label) in zip(boundary_marks,
                                                        boundary_marks[1:]):
            if label != 'Ronset' or next_label != 'Roffset':
                continue
            if next_pos - pos >= QRS_width_threshold:
                print('Warning: no matching Roffset found!')
                continue
            print('Adding: {} {}'.format(pos, next_pos))
            # flatten the signal by linear interpolation
            amp_start = rawsig[pos]
            amp_end = rawsig[next_pos]
            width = next_pos - pos
            for offset in range(width):
                rawsig[pos + offset] = (amp_start + float(offset) *
                                        (amp_end - amp_start) / width)
        return rawsig

    def crop_data_for_swt(self, rawsig):
        '''Pad ``rawsig`` to the next power-of-two length for ``pywt.swt``.

        Padding repeats the last sample.  Returns ``rawsig`` itself when its
        length is already a power of two, otherwise a new list (the input
        is not modified).

        Raises:
            Exception: if ``rawsig`` has fewer than 2 samples.
        '''
        N_data = len(rawsig)
        if N_data <= 1:
            raise Exception(
                'len(rawsig)={},not enough for swt!'.format(N_data))
        # smallest power of two >= N_data
        crop_len = 1
        while crop_len < N_data:
            crop_len *= 2
        if N_data < crop_len:
            rawsig = list(rawsig) + [rawsig[-1]] * (crop_len - N_data)
        return rawsig

    def swt(self, recname, wavelet='db6', MaxLevel=9):
        '''Run ``pywt.swt`` on the QRS-flattened, padded signal of ``recname``.

        Stores the approximation coefficients of every level in
        ``self.cAlist`` and the detail coefficients in ``self.cDlist``.
        '''
        rawsig = self.crop_data_for_swt(self.getNonQRSsig(recname))
        coeflist = pywt.swt(rawsig, wavelet, MaxLevel)
        self.cAlist, self.cDlist = zip(*coeflist)
# Example #10
# 0
class Evaluation2Leads:
    '''Evaluate one label type predicted on two leads against QTdb experts.

    Typical use: ``loadlabellist(json_file, label)``, ``evaluate(label)``,
    then read ``errList`` / ``getFNlist()`` / ``get_total_mean()`` /
    ``get_total_stdvar()``.  An expert label is detected when the nearest
    prediction on either lead lies closer than ``MaxMatchDist`` samples; its
    error comes from the closer of the matched leads.
    '''
    def __init__(self):
        print('[Warning]Evaluator will convert prediction indexes to Integer.')
        self.QTdb = QTloader()
        # color schemes: tableau20 palette scaled to matplotlib's [0, 1] range
        tableau20 = [(31, 119, 180), (174, 199, 232), (255, 127, 14),
                     (255, 187, 120), (44, 160, 44), (152, 223, 138),
                     (214, 39, 40), (255, 152, 150), (148, 103, 189),
                     (197, 176, 213), (140, 86, 75), (196, 156, 148),
                     (227, 119, 194), (247, 182, 210), (127, 127, 127),
                     (199, 199, 199), (188, 189, 34), (219, 219, 141),
                     (23, 190, 207), (158, 218, 229)]
        self.colors = [(r / 255.0, g / 255.0, b / 255.0)
                       for r, g, b in tableau20]
        # per-record state
        self.clear()

    def clear(self):
        '''Reset all per-record state so the evaluator can be reused.'''
        # error List: [4.0 * (expPos - prdPos), ...], filled by get_errList
        self.errList = None
        self.labellist = []
        self.expertlist = []
        self.recname = None
        self.prdMatchList = []
        self.expMatchList = []
        # number of expert labels matched on neither lead
        self.FNcnt = 0

    def loadlabellist(self, filename, TargetLabel):
        '''Load both leads' TargetLabel predictions from a json result file.

        The json object provides ``recname`` and ``LeadResult``, whose first
        two entries map label names to position lists.  The expert positions
        of ``TargetLabel`` for that record are loaded from QTdb.
        '''
        with open(filename, 'r') as fin:
            Res = json.load(fin)

        self.recname = Res['recname']
        print('>>loading recname: {}'.format(self.recname))

        self.leadResults = Res['LeadResult']
        self.leadPosList = (self.leadResults[0][TargetLabel],
                            self.leadResults[1][TargetLabel])

        expertlist = self.QTdb.getexpertlabeltuple(self.recname)
        self.expertlist = [pos for pos, label in expertlist
                           if label == TargetLabel]

    def getMatchList(self, prdposlist, MaxMatchDist=50):
        '''Match every expert position to its nearest prediction.

        NOTE: sorts ``prdposlist`` and ``self.expertlist`` in place; the
        returned indexes refer to those sorted lists (get_errList relies on
        this).

        Returns:
            (expMatchList, prdMatchList): ``expMatchList[i]`` is the index of
            the prediction matched to expert label ``i``, or -1 when the
            nearest prediction is ``MaxMatchDist`` or more samples away;
            ``prdMatchList[j]`` is the (last) expert index matched to
            prediction ``j``, or -1.
        '''
        prdposlist.sort()
        self.expertlist.sort()

        prdMatchList = [-1] * len(prdposlist)
        expMatchList = [-1] * len(self.expertlist)
        N_prdList = len(prdposlist)
        if N_prdList == 0:
            # no predictions on this lead: every expert label is unmatched
            return (expMatchList, prdMatchList)

        for expInd, exppos in enumerate(self.expertlist):
            # nearest prediction is just left or right of the insert point
            insertPos = bisect.bisect_left(prdposlist, exppos)
            left_ind = insertPos - 1
            right_ind = insertPos
            if left_ind < 0:
                matchInd = right_ind
            elif right_ind >= N_prdList:
                matchInd = left_ind
            elif (abs(exppos - prdposlist[left_ind]) >
                  abs(exppos - prdposlist[right_ind])):
                matchInd = right_ind
            else:
                # ties go to the left neighbour
                matchInd = left_ind
            if abs(exppos - prdposlist[matchInd]) < MaxMatchDist:
                expMatchList[expInd] = matchInd
                prdMatchList[matchInd] = expInd
        return (expMatchList, prdMatchList)

    def evaluate(self, TargetLabel, MaxMatchDist=50):
        '''Match both leads against the expert labels and compute the errors.

        ``TargetLabel`` is not used here (the label is selected in
        ``loadlabellist``); it is kept for interface compatibility.
        '''
        self.prdMatchList = []
        self.expMatchList = []

        # get 2 Leads match lists
        for leadInd in (0, 1):
            expMatchList, prdMatchList = self.getMatchList(
                self.leadPosList[leadInd], MaxMatchDist)
            self.prdMatchList.append(prdMatchList)
            self.expMatchList.append(expMatchList)

        # get error statistics
        self.get_errList()

    def getFNlist(self):
        '''Return total number of False Negatives (0 before evaluate).'''
        return self.FNcnt

    def getFPlist(self):
        '''Not implemented: always returns -1.'''
        return -1

    def plot_evaluation_result(self):
        '''Plot lead-1 signal with expert labels and both leads' predictions.'''
        rawSig = self.QTdb.load(self.recname)['sig']

        plt.figure(1)
        plt.subplot(211)
        plt.plot(rawSig)

        # plot expert labels
        expPosList = self.expertlist
        plt.plot(expPosList,
                 [rawSig[pos] for pos in expPosList],
                 'd',
                 color=self.colors[2],
                 markersize=12,
                 label='Expert Label')

        # predictions of both leads, drawn on the lead-1 signal
        for prdPosList, color, name in ((self.leadPosList[0], self.colors[3],
                                         'prd1'),
                                        (self.leadPosList[1], self.colors[4],
                                         'prd2')):
            plt.plot(prdPosList,
                     [rawSig[int(pos)] for pos in prdPosList],
                     '*',
                     color=color,
                     label=name,
                     markersize=12)

        plt.grid(True)
        plt.legend()
        plt.title(self.recname)
        plt.show()

    def get_errList(self):
        '''Compute ``errList`` and ``FNcnt`` from the current match lists.

        For each expert label matched on at least one lead, the error
        ``expPos - prdPos`` of the matched lead (the smaller one in absolute
        value when both match, lead 2 on ties) is scaled by 4.0 --
        presumably samples -> ms at 250 Hz; TODO confirm.  Only matched
        leads are used: an unmatched lead has no valid prediction index.
        '''
        self.errList = []
        FNcnt = 0
        for expPos, matchInd1, matchInd2 in zip(self.expertlist,
                                                self.expMatchList[0],
                                                self.expMatchList[1]):
            if matchInd1 == -1 and matchInd2 == -1:
                FNcnt += 1
                continue
            if matchInd2 == -1:
                err = expPos - self.leadPosList[0][matchInd1]
            elif matchInd1 == -1:
                err = expPos - self.leadPosList[1][matchInd2]
            else:
                err1 = expPos - self.leadPosList[0][matchInd1]
                err2 = expPos - self.leadPosList[1][matchInd2]
                # choose the one with smaller error
                err = err1 if abs(err1) < abs(err2) else err2
            self.errList.append(4.0 * err)

        # total number of False Negatives
        self.FNcnt = FNcnt

        return self.errList

    def get_total_mean(self):
        '''Mean of ``errList``.'''
        return np.mean(self.errList)

    def get_total_stdvar(self):
        '''Standard deviation of ``errList``.'''
        return np.std(self.errList)
# Example #11
# 0
class PointBrowser(object):
    """
    Click on a point to select and highlight it -- the data that
    generated the point will be shown in the lower axes.  Use the 'n'
    and 'p' keys to browse through the next and previous points
    """
    def __init__(self, fig, ax, start_index):
        """Attach the browser to *fig*/*ax* and load record *start_index*.

        Results are saved under ``<curfolderpath>/results``. The record list,
        signal and expert labels come from ``QTloader``. ``self.colors`` holds
        the Tableau-20 palette as 0..1 RGB tuples, and ``self.poslist``
        collects the positions marked by the user.
        """
        self.fig = fig
        self.ax = ax
        self.SaveFolder = os.path.join(curfolderpath, 'results')

        # status line in the top-left corner of the axes
        self.text = ax.text(0.05, 0.95, 'selected: none',
                            transform=ax.transAxes, va='top')

        # ---- QT database and the record currently shown ----
        self.QTdb = QTloader()
        self.reclist = self.QTdb.reclist
        self.recInd = start_index
        self.recname = self.reclist[start_index]
        self.sigStruct = self.QTdb.load(self.recname)
        self.rawSig = self.sigStruct['sig']
        self.expLabels = self.QTdb.getexpertlabeltuple(self.recname)

        # ---- Tableau-20 palette, converted from 0..255 to 0..1 ----
        tableau20 = [(31, 119, 180), (174, 199, 232), (255, 127, 14),
                     (255, 187, 120), (44, 160, 44), (152, 223, 138),
                     (214, 39, 40), (255, 152, 150), (148, 103, 189),
                     (197, 176, 213), (140, 86, 75), (196, 156, 148),
                     (227, 119, 194), (247, 182, 210), (127, 127, 127),
                     (199, 199, 199), (188, 189, 34), (219, 219, 141),
                     (23, 190, 207), (158, 218, 229)]
        self.colors = [tuple(channel / 255.0 for channel in rgb)
                       for rgb in tableau20]

        # ---- user marks for the current record ----
        self.poslist = []
        self.totalWhiteCount = 0

    def onpress(self, event):
        if event.key not in ('n', 'p', ' ', 'x', 'a', 'd'):
            return

        if event.key == 'n':
            self.saveWhiteMarkList2Json()
            self.next_record()
            self.clearWhiteMarkList()

            # clear Marker List
            self.reDraw()
            return None
        elif event.key == ' ':
            self.reDraw()

            return None
        elif event.key == 'x':
            # Delete markers in stack
            if len(self.poslist) > 0:
                del self.poslist[-1]

        elif event.key == 'a':
            step = -200
            xlims = self.ax.get_xlim()
            new_xlims = [xlims[0] + step, xlims[1] + step]
            self.ax.set_xlim(new_xlims)
        elif event.key == 'd':
            step = 200
            xlims = self.ax.get_xlim()
            new_xlims = [xlims[0] + step, xlims[1] + step]
            self.ax.set_xlim(new_xlims)
        else:
            pass

        self.update()

    def saveWhiteMarkList2Json(self):
        with open(
                os.path.join(self.SaveFolder,
                             '{}_poslist.json'.format(self.recname)),
                'w') as fout:
            result_info = dict(ID=self.recname,
                               database='QTdb',
                               poslist=self.poslist,
                               type='Tonset')
            json.dump(result_info, fout, indent=4, sort_keys=True)
            print 'Json file for record {} saved.'.format(self.recname)

    def clearWhiteMarkList(self):
        self.poslist = []
        self.totalWhiteCount = 0

    def addMarkx(self, x):
        # mark data
        pos = int(x)
        self.poslist.append(pos)

        self.ax.plot(pos,
                     self.rawSig[pos],
                     marker='x',
                     color=self.colors[7],
                     markersize=22,
                     markeredgewidth=4,
                     alpha=0.9,
                     label='Tonset')
        self.ax.set_xlim(pos - 500, pos + 500)

    def onpick(self, event):
        '''Mouse click to mark target points.'''
        # The click locations
        x = event.mouseevent.xdata
        y = event.mouseevent.ydata

        # add white Mark
        self.addMarkx(x)
        self.text.set_text('Marking Tonset: ({}) [whiteCnt {}]'.format(
            self.poslist[-1], len(self.poslist)))

        # update canvas
        self.fig.canvas.draw()

    def RepeatCheck(self):
        '''Check repeat results.'''
        result_file_name = os.path.join(self.SaveFolder,
                                        '{}_poslist.json'.format(self.recname))
        if os.path.exists(result_file_name):
            window = Tkinter.Tk()
            window.wm_withdraw()
            tkMessageBox.showinfo(title='Repeat',
                                  message='The record %s is already marked!' %
                                  self.recname)
            window.destroy()

            # Go to next record
            self.next_record()
            self.clearWhiteMarkList()
            self.reDraw()

    def reDraw(self):

        self.RepeatCheck()

        ax = self.ax
        ax.cla()

        self.text = self.ax.text(0.05,
                                 0.95,
                                 'selected: none',
                                 transform=self.ax.transAxes,
                                 va='top')
        ax.grid(color=(0.8, 0.8, 0.8), linestyle='--', linewidth=2)

        # ====================================
        # load ECG signal

        ax.set_title('QT {} (Index = {})'.format(self.recname, self.recInd))
        ax.plot(self.rawSig, picker=5)  # 5 points tolerance
        # plot Expert Labels
        self.plotExpertLabels(ax)

        # draw Markers
        for pos in self.poslist:
            # draw markers
            self.ax.plot(pos,
                         self.rawSig[pos],
                         marker='x',
                         color=self.colors[7],
                         markersize=22,
                         markeredgewidth=4,
                         alpha=0.9,
                         label='Tonset')

        self.ax.set_xlim(0, len(self.rawSig))
        # update draw
        self.fig.canvas.draw()

    def update(self):
        #self.ax2.text(0.05, 0.9, 'mu=%1.3f\nsigma=%1.3f' % (xs[dataind], ys[dataind]),
        #transform=self.ax2.transAxes, va='top')
        #self.ax2.set_ylim(-0.5, 1.5)

        self.fig.canvas.draw()

    def next_record(self):
        self.recInd += 1
        if self.recInd >= len(self.reclist):
            return False
        self.recname = self.reclist[self.recInd]
        self.sigStruct = self.QTdb.load(self.recname)
        self.rawSig = self.sigStruct['sig']
        self.expLabels = self.QTdb.getexpertlabeltuple(self.recname)
        return True

    def plotExpertLabels(self, ax):
        """Plot the expert annotations of the current record onto *ax*.

        Groups ``self.expLabels`` (iterated as ``(pos, label)`` pairs) by
        label and draws one marker series per label at ``self.rawSig[pos]``.
        Color is picked from the label's first letter ('T', 'P', 'R').
        Marker shape is '<' if the label contains 'onset', '>' if it contains
        'offset', and 'o' otherwise. Finishes by adding a legend to *ax*.

        NOTE(review): a label whose first letter is not T/P/R leaves ``color``
        unassigned. That raises NameError on the first such label, or reuses
        the previous label's color. Confirm the label set is limited to T/P/R.
        NOTE(review): ``dict.iteritems`` below is Python 2 only.
        """

        # group positions by label (labelSet mirrors labelDict's keys)
        labelSet = set()
        labelDict = dict()
        for pos, label in self.expLabels:
            if label in labelSet:
                labelDict[label].append(pos)
            else:
                labelSet.add(label)
                labelDict[label] = [
                    pos,
                ]

        # plot to axes: one marker series per label
        for label, posList in labelDict.iteritems():
            # color from the wave type (first letter of the label)
            if label[0] == 'T':
                color = self.colors[4]
            elif label[0] == 'P':
                color = self.colors[5]
            elif label[0] == 'R':
                color = self.colors[6]
            # marker shape from onset / offset / other
            if 'onset' in label:
                marker = '<'
            elif 'offset' in label:
                marker = '>'
            else:
                marker = 'o'
            ax.plot(posList,
                    map(lambda x: self.rawSig[x], posList),
                    marker=marker,
                    color=color,
                    linestyle='none',
                    markersize=14,
                    label=label)
        ax.legend(numpoints=1)