Ejemplo n.º 1
0
def TestAllQTdata(saveresultpath, testinglist,
                  model_folder='/home/alex/LabGit/ECG_random_walk/randomwalk/data/Lw3Np4000',
                  pattern_filename=None,
                  fs=250.0,
                  walker_iterations=100):
    '''Run pre-trained random-walk models on every record in testinglist.

    No training happens here: the model list is loaded once from
    ``model_folder`` (with its random pattern file) and applied to each QTdb
    record. One JSON file ``<record_name>.json`` per record is written to
    ``saveresultpath``.

    Args:
        saveresultpath: Folder receiving the per-record result files.
        testinglist: QTdb record names to test.
        model_folder: Folder holding the trained models.
        pattern_filename: Random pattern JSON file; defaults to
            ``random_pattern.json`` inside ``model_folder``.
        fs: Sampling frequency of the signals in Hz, passed to ``Testing``
            and used to report the data duration.
        walker_iterations: Random-walk iterations passed to ``Testing``.
    '''
    from randomwalk.test_api import GetModels
    from randomwalk.test_api import Testing

    qt_loader = QTloader()

    if pattern_filename is None:
        pattern_filename = os.path.join(model_folder, 'random_pattern.json')
    model_list = GetModels(model_folder, pattern_filename)

    for record_name in testinglist:
        raw_sig = qt_loader.load(record_name)['sig']

        start_time = time.time()
        results = Testing(raw_sig, fs, model_list,
                          walker_iterations=walker_iterations)
        time_cost = time.time() - start_time

        result_path = os.path.join(saveresultpath, '%s.json' % record_name)
        with open(result_path, 'w') as fout:
            json.dump(results, fout)
        print('Testing time %f s, data time %f s.' % (time_cost,
                                                       len(raw_sig) / fs))
Ejemplo n.º 2
0
def TrainSwtModel(save_model_folder, save_sample_folder, target_label,
                  random_relation_file_path, num_training_records=75):
    '''Train a single-label random forest model on randomly chosen QT records.

    Args:
        save_model_folder: Folder where ``trained_model.mdl`` is written.
        save_sample_folder: Passed to ECGrf as ``SaveTrainingSampleFolder``.
        target_label: The only label in ECGrf's ``allowed_label_list``.
        random_relation_file_path: Path of a file inside the random relation
            folder; only its directory is handed to ECGrf.
        num_training_records: Number of shuffled QTdb records used for
            training (75 by default).
    '''
    qt_loader = QTloader()
    # Shuffle a copy so the list returned by the loader is not reordered.
    candidate_records = list(qt_loader.getQTrecnamelist())
    random.shuffle(candidate_records)
    traininglist = candidate_records[:num_training_records]

    # ECGrf receives the directory that contains the random relation file.
    random_relation_folder = os.path.dirname(random_relation_file_path)
    rf_classifier = ECGrf(SaveTrainingSampleFolder=save_sample_folder,
                          allowed_label_list=[target_label],
                          random_relation_path=random_relation_folder)
    rf_classifier.TestRange = 'All'

    # timing_for appends the elapsed time to time_cost_output.
    time_cost_output = []
    timing_for(rf_classifier.TrainQtRecords, [traininglist],
               prompt='Total Training time:',
               time_cost_output=time_cost_output)
    log.info('Total training time cost: %.2f seconds', time_cost_output[-1])

    backupobj(rf_classifier.mdl,
              os.path.join(save_model_folder, 'trained_model.mdl'))
Ejemplo n.º 3
0
def TrainModel(save_result_folder):
    '''Train a random forest model on the first 15 QTdb records.

    Args:
        save_result_folder: Output folder handed on to ``Training``.
    '''
    first_records = QTloader().getQTrecnamelist()[:15]
    Training(save_result_folder, first_records)
Ejemplo n.º 4
0
def TestRecord(saveresultpath, num_test_records=10):
    '''Test the first QTdb records with a previously trained model.

    No training happens here. The model is unpickled from
    ``<saveresultpath>/trained_model.mdl`` and applied to the first
    ``num_test_records`` QTdb records. Each result is dumped as indented JSON
    to ``<saveresultpath>/result_<record_name>``.

    Args:
        saveresultpath: Folder holding the trained model and receiving results.
        num_test_records: How many records, from the start of the QTdb record
            list, to test (10 by default).
    '''
    qt_loader = QTloader()
    testinglist = qt_loader.getQTrecnamelist()[0:num_test_records]

    rf_classifier = ECGrf(SaveTrainingSampleFolder=saveresultpath)
    rf_classifier.TestRange = 'All'

    # NOTE: pickle can execute arbitrary code while loading; only load model
    # files produced by this project's own training runs.
    model_path = os.path.join(saveresultpath, 'trained_model.mdl')
    with open(model_path, 'rb') as fin:
        trained_model = pickle.load(fin)
    rf_classifier.mdl = trained_model

    log.info('Testing records:\n    %s', ', '.join(testinglist))
    for record_name in testinglist:
        raw_signal = qt_loader.load(record_name)['sig']
        result = rf_classifier.testing(raw_signal, trained_model)
        result_path = os.path.join(saveresultpath,
                                   'result_{}'.format(record_name))
        with open(result_path, 'w') as fout:
            json.dump(result, fout, indent=4)
Ejemplo n.º 5
0
class RecSelector():
    '''Interactive helpers for browsing QTdb records.'''

    def __init__(self):
        self.qt = QTloader()

    def inspect_recs(self):
        '''Plot every distinct QTdb record one after another.'''
        # Sorted so the records are always shown in the same order.
        out_reclist = sorted(set(self.qt.getQTrecnamelist()))
        for ind, recname in enumerate(out_reclist):
            print('{} records left.'.format(len(out_reclist) - ind - 1))
            self.qt.plotrec(recname)

    def plot_rec_dwt(self):
        '''Plot DWT coefficients of the records whose name starts with 'sele0704'.'''
        out_reclist = sorted(set(self.qt.getQTrecnamelist()))
        # The wavelet does not depend on the record, so build it only once.
        waveletobj = WTfeature().gswt_wavelet()
        for ind, recname in enumerate(out_reclist):
            print('processing record : {}'.format(recname))
            print('{} records left.'.format(len(out_reclist) - ind - 1))
            # Debug selection: only records with this prefix are plotted.
            if not recname.startswith('sele0704'):
                continue
            sig = self.qt.load(recname)
            # NOTE(review): plot_dwt_coef_with_annotation is not defined in
            # this class as shown; confirm it is provided elsewhere.
            self.plot_dwt_coef_with_annotation(sig['sig'],
                                               waveletobj=waveletobj,
                                               ECGrecordname=recname)
Ejemplo n.º 6
0
def TestAllQTdata(saveresultpath):
    # Leave Ntest out of 30 records to test
    #

    qt_loader = QTloader()
    QTreclist = qt_loader.getQTrecnamelist()
    print 'Totoal QT record number:{}'.format(len(QTreclist))
    ## Training record list
    selall0 = conf["selQTall0"]
    selrecords = list(set(QTreclist) - set(selall0))
    rf = ECGRF.ECGrf()
    # Multi Process
    rf.useParallelTest = True
    rf.TestRange = 'All'

    # ================
    # evaluate time cost for each stage
    # ================

    # clear debug logger
    ECGRF.debugLogger.clear()
    # display the time left to finish program
    one_round_time_cost = []

    # ====================
    # Training
    # ====================
    ECGRF.debugLogger.dump('\n====Test Start ====\n')

    time0 = time.time()
    # training the rf classifier with reclist
    #
    # dump to debug logger
    rf.training(selrecords)
    # timing
    time1 = time.time()
    print 'Total Training time:', time1 - time0
    ECGRF.debugLogger.dump('Total Training time: {:.2f} s\n'.format(time1 -
                                                                    time0))

    # save the trained model
    savemodelpath = os.path.join(saveresultpath, 'QT_sel.mdl')
    with open(savemodelpath, 'w') as fout:
        pickle.dump(rf.mdl, fout)
    return

    ## test
    testinglist = selall0
    print '\n>>Testing:', testinglist
    ECGRF.debugLogger.dump('\n======\n\nTest Set :{}'.format(selrecords))
    rf.testrecords(reclist=selrecords,
                   TestResultFileName=os.path.join(saveresultpath,
                                                   'TestResult.out'))
Ejemplo n.º 7
0
def TEST1(save_result_path, target_label='T'):
    '''Train on QTdb records 0-3 and test on records 41-49.

    Args:
        save_result_path: Output folder passed to ``TestAllQTdata``.
        target_label: Label to train/test for.
    '''
    record_names = QTloader().getQTrecnamelist()
    TestAllQTdata(save_result_path,
                  record_names[41:50],
                  record_names[0:4],
                  target_label=target_label)
Ejemplo n.º 8
0
def TestAllQTdata(save_result_path,
                  testinglist,
                  training_list=None,
                  target_label='T'):
    '''Train a regression learner for target_label on QTdb and test it.

    When ``training_list`` is None, every QTdb record that is not in
    ``testinglist`` is used for training. The trained model is backed up to
    ``<save_result_path>/trained_model.mdl`` before the test records are run
    through ``TestQtRecords``.
    '''
    loader = QTloader()
    all_records = loader.getQTrecnamelist()
    if training_list is None:
        held_out = set(testinglist)
        training_list = list(set(all_records) - held_out)

    log.info('Totoal QTdb record number:%d, training %d, testing %d',
             len(all_records), len(training_list), len(testinglist))

    classifier = ECGrf(target_label=target_label,
                       SaveTrainingSampleFolder=save_result_path)
    classifier.TestRange = 'All'

    log.info('Start training...')
    print('training...')
    elapsed = []
    timing_for(classifier.TrainQtRecords, [training_list],
               prompt='Total Training time:',
               time_cost_output=elapsed)
    log.info('Total training time cost: %.2f seconds', elapsed[-1])
    model_path = os.path.join(save_result_path, 'trained_model.mdl')
    backupobj(classifier.mdl, model_path)

    log.info('Testing records:\n    %s', ', '.join(testinglist))
    classifier.TestQtRecords(save_result_path, reclist=testinglist)
Ejemplo n.º 9
0
def TestAllQTdata(saveresultpath):
    # Leave Ntest out of 30 records to test
    #
    qt_loader = QTloader()
    QTreclist = qt_loader.getQTrecnamelist()
    print 'Totoal QT record number:{}'.format(len(QTreclist))
    ## Training record list
    testinglist = conf["selQTall0_test_set"]
    traininglist = list(set(QTreclist) - set(testinglist))

    rf = ECGRF.ECGrf()
    # Multi Process
    rf.useParallelTest = True
    rf.TestRange = 'All'

    # clear debug logger
    ECGRF.debugLogger.clear()
    # ====================
    # Training
    # ====================
    ECGRF.debugLogger.dump('\n====Test Start ====\n')

    # training the rf classifier with reclist
    #
    # dump to debug logger
    time_cost_output = []
    timing_for(rf.training, [
        traininglist,
    ],
               prompt='Total Training time:',
               time_cost_output=time_cost_output)
    ECGRF.debugLogger.dump('Total Training time: {:.2f} s\n'.format(
        time_cost_output[-1]))
    # save trained mdl
    backupobj(rf.mdl, os.path.join(saveresultpath, 'trained_model.mdl'))

    ## test
    print '\n>>Testing:', testinglist
    ECGRF.debugLogger.dump('\n======\n\nTest Set :{}'.format(testinglist))
    rf.testrecords(saveresultpath, reclist=testinglist)
Ejemplo n.º 10
0
def Round_Test(saveresultpath,
               RoundNumber=1,
               number_of_test_record_per_round=30,
               round_start_index=1,
               target_label='T'):
    '''Repeatedly test on randomly selected QTdb records.

    For each round index from ``round_start_index`` to ``RoundNumber``
    (inclusive) a folder ``round<index>`` is created under ``saveresultpath``.
    A random set of test records is drawn from the records that are not
    required for training, and ``TestAllQTdata`` trains on the rest and tests
    on that set.

    Args:
        saveresultpath: Parent folder for the per-round result folders.
        RoundNumber: Index of the last round to run.
        number_of_test_record_per_round: Number of test records randomly
            selected per round.
        round_start_index: Index of the first round to run.
        target_label: Label passed through to ``TestAllQTdata``.

    Raises:
        ValueError: if more test records are requested than are available.
        OSError: if a round folder already exists.
    '''
    qt_loader = QTloader()
    QTreclist = qt_loader.getQTrecnamelist()

    # Records that must always stay in the training set.
    must_train_list = [
        "sel35", "sel36", "sel31", "sel38", "sel39", "sel820", "sel51",
        "sele0104", "sele0107", "sel223", "sele0607", "sel102", "sele0409",
        "sel41", "sel40", "sel43", "sel42", "sel45", "sel48", "sele0133",
        "sele0116", "sel14172", "sele0111", "sel213", "sel14157", "sel301"
    ]
    # Sorted so that a seeded random generator gives reproducible selections
    # (set iteration order is not guaranteed).
    may_testlist = sorted(set(QTreclist) - set(must_train_list))
    if number_of_test_record_per_round > len(may_testlist):
        raise ValueError(
            'Cannot select {} test records: only {} candidates.'.format(
                number_of_test_record_per_round, len(may_testlist)))

    log.info('Start Round Testing...')
    for round_ind in range(round_start_index, RoundNumber + 1):
        round_folder = os.path.join(saveresultpath,
                                    'round{}'.format(round_ind))
        os.mkdir(round_folder)
        testlist = random.sample(may_testlist,
                                 number_of_test_record_per_round)
        TestAllQTdata(round_folder, testlist, target_label=target_label)
Ejemplo n.º 11
0
class WT_choose:
    '''Load QTdb records and plot their multi-level discrete wavelet transform.'''

    def __init__(self):
        self.qt = QTloader()
        self.reclist = self.qt.getQTrecnamelist()

    def loadsig(self, selID):
        '''Load the selID-th record of the QTdb record list into ``self.sig``.'''
        self.sig = self.qt.load(self.reclist[selID])

    def PlotDWTCoef(self, sig=None, figID=1, waveletobj=None, level=7):
        '''Plot a signal together with its DWT detail and approximation coefficients.

        Args:
            sig: 1-D signal to decompose; defaults to ``self.sig['sig']`` as
                loaded by ``loadsig``.
            figID: matplotlib figure number to draw into.
            waveletobj: wavelet passed to ``pywt.wavedec``; defaults to
                ``pywt.Wavelet('sym2')``.
            level: number of decomposition levels (7 by default).

        Raises:
            AttributeError: if ``sig`` is None and ``loadsig`` was never called.
        '''
        if sig is None:
            if not hasattr(self, 'sig'):
                raise AttributeError('No signal given; call loadsig() first.')
            sig = self.sig['sig']
        if waveletobj is None:
            waveletobj = pywt.Wavelet('sym2')

        # wavedec returns [cA_level, cD_level, ..., cD_1].
        coeffs = pywt.wavedec(sig, waveletobj, level=level)

        plt.figure(figID)
        N_subplot = level + 2
        plt.subplot(N_subplot, 1, 1)
        plt.plot(sig)
        plt.title('Original Signal')
        # Detail coefficients, from the finest (cD_1) to the coarsest level.
        for detail_i in range(1, level + 1):
            detail = coeffs[level + 1 - detail_i]
            plt.subplot(N_subplot, 1, detail_i + 1)
            plt.plot(detail)
            plt.xlim(0, len(detail) - 1)
            plt.title('Detail Level {}'.format(detail_i))
        plt.subplot(N_subplot, 1, N_subplot)
        plt.plot(coeffs[0])
        plt.xlim(0, len(coeffs[0]) - 1)
        plt.title('Approximation Level')

        plt.show()
Ejemplo n.º 12
0
def TestAllQTdata(saveresultpath, testinglist):
    '''Train ECGrf on every QTdb record outside testinglist, then test testinglist.

    The trained model is backed up to ``<saveresultpath>/trained_model.mdl``;
    ``TestQtRecords`` is then given ``saveresultpath`` and the test records.
    '''
    loader = QTloader()
    all_records = loader.getQTrecnamelist()
    # Hold the testing records out of the training set.
    traininglist = list(set(all_records) - set(testinglist))
    log.info('Totoal QTdb record number:%d, training %d, testing %d',
             len(all_records), len(traininglist), len(testinglist))

    classifier = ECGrf(SaveTrainingSampleFolder=saveresultpath)
    classifier.TestRange = 'All'

    elapsed = []
    timing_for(classifier.TrainQtRecords, [traininglist],
               prompt='Total Training time:',
               time_cost_output=elapsed)
    log.info('Total training time cost: %.2f seconds', elapsed[-1])
    backupobj(classifier.mdl,
              os.path.join(saveresultpath, 'trained_model.mdl'))

    log.info('Testing records:\n    %s', ', '.join(testinglist))
    classifier.TestQtRecords(saveresultpath, reclist=testinglist)
Ejemplo n.º 13
0
class RecSelector():
    '''Interactive helpers to browse, hand-pick and sanity-check QTdb records.'''

    def __init__(self):
        self.qt = QTloader()

    def _ask_records(self, record_names, question):
        '''Plot each record, ask ``question`` and return the names answered 'y'.'''
        selected_record_list = []
        for ind, recname in enumerate(record_names):
            print('{} records left.'.format(len(record_names) - ind - 1))
            self.qt.plotrec(recname)
            if raw_input(question) == 'y':
                selected_record_list.append(recname)
        return selected_record_list

    def _save_selection(self, selected_record_list, output_filename):
        '''Dump the selected names as JSON into the directory of ``curfilepath``.'''
        output_path = os.path.join(os.path.dirname(curfilepath),
                                   output_filename)
        with open(output_path, 'w') as fout:
            json.dump(selected_record_list, fout)

    def select_records(self, display_list, output_filename):
        '''Let the user pick records from display_list; save picks to output_filename.'''
        selected = self._ask_records(display_list,
                                     'select this record?(y/n):')
        self._save_selection(selected, output_filename)

    def inspect_recs(self):
        '''Let the user pick training records; save picks to selected_records.json.'''
        # Distinct names, sorted so records are always shown in the same order.
        out_reclist = sorted(set(self.qt.getQTrecnamelist()))
        selected = self._ask_records(
            out_reclist, 'Use this record as training record?(y/n):')
        self._save_selection(selected, 'selected_records.json')

    def pick_invalid_record(self):
        '''Return the names of records whose first sample is +inf or -inf.'''
        invalidrecordlist = []
        for recname in self.qt.getQTrecnamelist():
            print('Inspecting record {}'.format(recname))
            sig = self.qt.load(recname)
            if abs(sig['sig'][0]) == float('inf'):
                invalidrecordlist.append(recname)
        return invalidrecordlist

    def save_recs_to_img(self):
        '''Call ``PlotAndSaveRec`` for every QTdb record.'''
        reclist = self.qt.getQTrecnamelist()
        for ind, recname in enumerate(reclist):
            # Count down against the list that is actually iterated.
            print('{} records left.'.format(len(reclist) - ind - 1))
            self.qt.PlotAndSaveRec(recname)

    def inspect_recname(self, tarrecname):
        '''Plot a single record.'''
        self.qt.plotrec(tarrecname)

    def inspect_selrec(self):
        '''Plot every QTdb record that is not listed in conf['selQTall0'].'''
        QTreclist = self.qt.getQTrecnamelist()
        sel0115 = conf['selQTall0']
        test_reclist = sorted(set(QTreclist) - set(sel0115))
        for ind, recname in enumerate(test_reclist):
            print('{} records left.'.format(len(test_reclist) - ind - 1))
            self.inspect_recname(recname)

    def RFtest(self, testrecname):
        '''Train on conf['sel1213'], test one record, report stats and plot it.'''
        ecgrf = ECGRF()
        ecgrf.training(conf['sel1213'])
        Results = ecgrf.testing([testrecname])
        # Evaluate the filtered result of the single tested record.
        filtered_Res = ECGRF.resfilter(Results)
        stats = ECGstats(filtered_Res[0:1])
        Err, FN = stats.eval(debug=False)
        stats.dispstat0(pFN=FN, pErr=Err)
        # Plot the prediction result.
        stats.plotevalresofrec(Results[0][0], Results)
Ejemplo n.º 14
0
class FeatureVis:
    def __init__(self, rfmdl, RandomPairList_Path):
        '''Store the forest model and load its random relation pairs.

        Args:
            rfmdl: trained forest; its ``estimators_`` are kept as
                ``self.trees`` (presumably a scikit-learn forest - confirm).
            RandomPairList_Path: JSON file with the random relation pairs,
                one list per layer.
        '''
        with open(RandomPairList_Path, 'r') as relation_file:
            self.randrels = json.load(relation_file)

        self.rfmdl = rfmdl
        self.trees = rfmdl.estimators_
        self.qt = QTloader()
        self.qt_reclist = self.qt.getQTrecnamelist()

    def info(self):
        '''Print attribute names of the first tree and of its ``tree_``, then enter pdb.'''
        first_tree = self.trees[0]
        for name in dir(first_tree):
            print(name)
        print('tree-------------------------')
        for name in dir(first_tree.tree_):
            print(name)
        # Deliberate breakpoint for interactive inspection.
        pdb.set_trace()

    def get_pair_importance_list(self):
        return self.feature_importance_test()

    def feature_importance_test(self):
        # return pairs&their importance
        rID = 1
        sig = self.qt.load(self.qt_reclist[rID])
        # random relations
        randrels = self.randrels
        fimp = self.rfmdl.feature_importances_
        sum_rr = 0
        for rr in randrels:
            sum_rr += len(rr)
        print 'len(feature importance) = {}, len(random relations) = {}'.format(
            len(fimp), sum_rr)
        # ===========================
        # get a struct of importance:
        # ------------
        # [#layer0:[((23, 2), 0.2# importance), ((pair), #importance), ()...], #layer1:[], ..]
        #
        cur_feature_ind = 0
        relation_importance = []
        for rel_layer in randrels:
            layerSZ = len(rel_layer)
            diff_arr = fimp[cur_feature_ind:cur_feature_ind + layerSZ]
            absdiff_arr = fimp[cur_feature_ind + layerSZ:cur_feature_ind +
                               2 * layerSZ]
            imp_arr = map(lambda x: max(x[0], x[1]),
                          zip(diff_arr, absdiff_arr))
            relation_importance.append(zip(rel_layer, imp_arr))
            print 'layer len = {}'.format(len(imp_arr))
            cur_feature_ind += 2 * layerSZ

        return relation_importance

    def plot_dwt_pairs_arrow_partial_window_compare(self,
                                                    rawsig,
                                                    relation_importance,
                                                    Window_Left=1200,
                                                    savefigname=None,
                                                    figsize=(10, 8),
                                                    figtitle='ECG Sample',
                                                    showFigure=True):
        '''Show the RSWT schematic: windowing before vs. after the DWT.

        Left column: the raw window ``rawsig[Window_Left:Window_Left + W]``
        and, below it, the approximation coefficients of repeatedly applying
        ``pywt.dwt`` to that window. Right column: the same raw window, then
        for each of 5 levels the detail coefficients of the DWT of the whole
        signal, cropped to the halved window and annotated with one arrow per
        relation pair. Arrow opacity is the pair's importance divided by the
        largest importance. W is ``conf['winlen_ratio_to_fs'] * conf['fs']``.

        Args:
            rawsig: 1-D indexable signal.
            relation_importance: per-layer lists of ``(pair, importance)``
                (as from ``feature_importance_test``); at least 5 layers.
            Window_Left: left index of the window in rawsig.
            savefigname: if given, the figure is saved to this path.
            figsize: matplotlib figure size.
            figtitle: not used by this plot.
            showFigure: call ``plt.show()`` when true.
        '''
        N = 5  # number of DWT levels plotted
        figureID = 1
        fs = conf['fs']
        FixedWindowLen = conf['winlen_ratio_to_fs'] * fs
        print('Fixed Window Length:{}'.format(FixedWindowLen))
        xL = Window_Left
        # range() and slicing need an integer window length.
        xR = xL + int(FixedWindowLen)

        # Largest importance, used to scale arrow opacity into [0, 1].
        IMP_MAX = max(imp for rel_layer in relation_importance
                      for _, imp in rel_layer)

        waveletobj = WTfeature().gswt_wavelet()
        pltxLim = range(xL, xR)
        sigAmp = [rawsig[x] for x in pltxLim]
        cA = rawsig

        Fig_main = plt.figure(figureID, figsize=figsize)
        # Row 1, left: the raw window.
        plt.subplot(N + 1, 2, 1)
        plt.plot(pltxLim, sigAmp)
        plt.title('Window the signal before DWT')
        # Row 1, right: the same raw window.
        plt.subplot(N + 1, 2, 2)
        plt.plot(pltxLim, sigAmp)
        cA_col2 = cA[xL:xR]
        plt.title('Window the signal after DWT')

        for i in range(2, N + 2):
            # Left column: DWT of the already-windowed signal.
            cA_col2, _ = pywt.dwt(cA_col2, waveletobj)
            plt.subplot(N + 1, 2, i * 2 - 1)
            plt.plot(cA_col2)

            # Right column: DWT of the whole signal, cropped to the window.
            rel_layer = relation_importance[i - 2]
            cA, cD = pywt.dwt(cA, waveletobj)
            # Each DWT level halves the length, so halve the window bounds.
            xL //= 2
            xR //= 2
            cDamp = [cD[x] for x in range(xL, xR)]
            cur_N = len(cDamp)
            rel_x = []
            rel_y = []
            fig = plt.subplot(N + 1, 2, 2 * i)
            for rel_pair, imp in rel_layer:
                # All-zero importances would otherwise divide by zero.
                alpha = imp / IMP_MAX if IMP_MAX > 0 else 0.0
                arrowprops = dict(width=1,
                                  headwidth=4,
                                  facecolor='r',
                                  edgecolor='r',
                                  alpha=alpha,
                                  shrink=0)
                # Pair indices outside the cropped window are drawn at 0.
                for pos in rel_pair[:2]:
                    rel_x.append(pos)
                    rel_y.append(cDamp[pos] if pos < cur_N else 0)
                fig.annotate('',
                             xy=(rel_x[-2], rel_y[-2]),
                             xytext=(rel_x[-1], rel_y[-1]),
                             arrowprops=arrowprops)

            plt.plot(rel_x, rel_y, '.b')
            plt.plot(cDamp)
            plt.xlim(0, len(cDamp) - 1)
            plt.title('DWT Level ({}):'.format(i - 1))

        if showFigure:
            plt.show()
        if savefigname is not None:
            Fig_main.savefig(savefigname, dpi=Fig_main.dpi)
            Fig_main.clf()

    def get_min_Importance_threshold(self,
                                     relation_importance,
                                     top_importance_ratio=9.5 / 10):
        if relation_importance == None or len(relation_importance) == 0:
            raise Exception('relation_importance is empty!')
        imp_list = []
        for layer in relation_importance:
            # unzip the lists
            pairs, imps = zip(*layer)
            imp_list.extend(imps)
        # get the min_importance threshold
        N = len(imp_list)
        midpos = int(top_importance_ratio * float(N))
        imp_list.sort()
        return imp_list[midpos]

    def plot_dwt_pairs_arrow_transparent(self,
                                         rawsig,
                                         relation_importance,
                                         Window_Left=1200,
                                         savefigname=None,
                                         figsize=(8, 6),
                                         figtitle='ECG Sample',
                                         showFigure=True):
        '''Plot the RSWT schematic with semi-transparent relation arrows.

        The top-left panel shows the raw window
        ``rawsig[Window_Left:Window_Left + W]``. The next 5 panels show the
        detail coefficients of successive ``pywt.dwt`` levels of the whole
        signal, cropped to the halved window. The bottom row shows the
        approximation coefficients of the last level. Only pairs whose
        importance is at least ``get_min_Importance_threshold`` are drawn.
        Opacity grows linearly from that threshold up to the largest
        importance; detail-level arrows get an extra alpha boost. W is
        ``conf['winlen_ratio_to_fs'] * conf['fs']``.

        Args:
            rawsig: 1-D indexable signal.
            relation_importance: per-layer lists of ``(pair, importance)``;
                layers 0-4 are used for the details, the last for the
                approximation. The caller's lists are not modified.
            Window_Left: left index of the window in rawsig.
            savefigname: if given, the figure is saved to this path.
            figsize: matplotlib figure size.
            figtitle: title of the raw-signal panel.
            showFigure: call ``plt.show()`` when true.
        '''
        N = 5  # number of DWT detail levels plotted
        N_subplot = N + 2
        n_rows = (N_subplot + 1) // 2
        # Pairs with importance below this threshold are not drawn.
        Thres_min_importance = self.get_min_Importance_threshold(
            relation_importance)
        figureID = 1
        fs = conf['fs']
        FixedWindowLen = conf['winlen_ratio_to_fs'] * fs
        print('Fixed Window Length:{}'.format(FixedWindowLen))
        xL = Window_Left
        # range() needs an integer window length.
        xR = xL + int(FixedWindowLen)

        IMP_MAX = max(imp for rel_layer in relation_importance
                      for _, imp in rel_layer)
        imp_span = IMP_MAX - Thres_min_importance

        waveletobj = WTfeature().gswt_wavelet()

        def draw_pairs(ax, rel_layer, amp, alpha_increase_ratio):
            '''Draw arrows for pairs at/above the threshold, then mark their endpoints.'''
            rel_x = []
            rel_y = []
            cur_N = len(amp)
            # Least important first, so the strongest arrows end up on top.
            for rel_pair, imp in sorted(rel_layer, key=lambda x: x[1]):
                if imp < Thres_min_importance:
                    continue
                # Linear opacity from the threshold (0) to IMP_MAX (1); when
                # they coincide every drawn pair is fully opaque.
                if imp_span > 0:
                    alpha = (imp - Thres_min_importance) / imp_span
                else:
                    alpha = 1.0
                if alpha_increase_ratio is not None:
                    # Increase alpha for a better visual effect.
                    alpha = (alpha * alpha_increase_ratio + 1.0 -
                             alpha_increase_ratio)
                arrow_color = (1, 0, 0)
                arrowprops = dict(alpha=alpha,
                                  width=0.15,
                                  linewidth=0.15,
                                  headwidth=1.5,
                                  headlength=1.5,
                                  facecolor=arrow_color,
                                  edgecolor=arrow_color,
                                  shrink=0)
                # Pair indices outside the cropped window are drawn at 0.
                for pos in rel_pair[:2]:
                    rel_x.append(pos)
                    rel_y.append(amp[pos] if pos < cur_N else 0)
                ax.annotate('',
                            xy=(rel_x[-2], rel_y[-2]),
                            xytext=(rel_x[-1], rel_y[-1]),
                            arrowprops=arrowprops)
            plt.plot(rel_x, rel_y, '.b')

        pltxLim = range(xL, xR)
        sigAmp = [rawsig[x] for x in pltxLim]
        cA = rawsig

        Fig_main = plt.figure(figureID, figsize=figsize)
        # Top-left panel: the raw signal window.
        plt.subplot(n_rows, 2, 1)
        plt.plot(pltxLim, sigAmp)
        plt.title(figtitle)
        plt.xlim(pltxLim[0], pltxLim[-1])

        # Panels 2..N+1: detail coefficients of each level, cropped to window.
        for i in range(2, N_subplot):
            rel_layer = relation_importance[i - 2]
            cA, cD = pywt.dwt(cA, waveletobj)
            # Each DWT level halves the length, so halve the window bounds.
            xL //= 2
            xR //= 2
            xi = range(xL, xR)
            cDamp = [cD[x] for x in xi]
            fig = plt.subplot(n_rows, 2, i)
            plt.title('DWT Detail Coefficient {}'.format(i - 1))
            draw_pairs(fig, rel_layer, cDamp, 0.95)
            plt.plot(cDamp)
            plt.xlim(0, len(cDamp) - 1)

        # Bottom row: approximation coefficients of the last level.
        fig = plt.subplot(4, 1, 4)
        plt.title('Approximation Coefficient')
        cAamp = [cA[x] for x in xi]
        draw_pairs(fig, relation_importance[-1], cAamp, None)
        plt.plot(cAamp)
        plt.xlim(0, len(cAamp) - 1)

        if showFigure:
            plt.show()
        if savefigname is not None:
            Fig_main.savefig(savefigname, dpi=Fig_main.dpi)
            Fig_main.clf()

    def plot_dwt_pairs_arrow(self,
                             rawsig,
                             relation_importance,
                             Window_Left=1200,
                             savefigname=None,
                             figsize=(10, 8),
                             figtitle='ECG Sample',
                             showFigure=True):
        '''Plot the RSWT illustration: raw ECG plus DWT levels with arrows.

        The raw window ``[Window_Left, Window_Left + window length)`` is shown
        first, then ``pywt.dwt`` is applied five times. For every detail level
        (and the final approximation level) the pairs whose importance is at
        least ``get_min_Importance_threshold(relation_importance)`` are drawn
        as arrows whose colour is blended from white towards red according to
        their importance.

        Args:
            rawsig: Raw ECG samples (indexable sequence).
            relation_importance: List of layers, each a list of
                ``((pos0, pos1), importance)`` tuples. Layers 0..4 are drawn
                on the five detail levels, the last layer on the
                approximation level.
            Window_Left: First raw-signal index of the displayed window.
            savefigname: If not None, the figure is saved to this path.
            figsize: Matplotlib figure size.
            figtitle: Title of the raw-signal subplot.
            showFigure: Call ``plt.show()`` when truthy.
        '''
        # Number of DWT decomposition levels shown as detail subplots.
        N = 5
        N_subplot = N + 2
        # Grid rows; ``//`` keeps this an int on Python 3 as well.
        n_rows = (N_subplot + 1) // 2
        # Importance pairs lower than this threshold are not shown.
        Thres_min_importance = self.get_min_Importance_threshold(
            relation_importance)
        figureID = 1
        fs = conf['fs']
        FixedWindowLen = conf['winlen_ratio_to_fs'] * fs
        print('Fixed Window Length:{}'.format(FixedWindowLen))
        xL = Window_Left
        xR = xL + FixedWindowLen
        # Largest importance over all layers: mapped to full red.
        IMP_MAX = max(imp for rel_layer in relation_importance
                      for _, imp in rel_layer)
        waveletobj = WTfeature().gswt_wavelet()
        pltxLim = range(xL, xR)
        sigAmp = [rawsig[x] for x in pltxLim]
        cA = rawsig

        # Raw signal input.
        Fig_main = plt.figure(figureID, figsize=figsize)
        plt.subplot(n_rows, 2, 1)
        plt.plot(pltxLim, sigAmp)
        plt.title(figtitle)
        plt.xlim(pltxLim[0], pltxLim[-1])

        for i in range(2, N_subplot):
            rel_layer = relation_importance[i - 2]
            cA, cD = pywt.dwt(cA, waveletobj)
            # Each DWT step halves the coefficient count, so halve the window.
            xL //= 2
            xR //= 2
            xi = range(xL, xR)
            cDamp = [cD[x] for x in xi]
            fig = plt.subplot(n_rows, 2, i)
            plt.title('DWT Detail Coefficient {}'.format(i - 1))
            rel_x, rel_y = self._annotate_dwt_importance_pairs(
                fig, cDamp, rel_layer, Thres_min_importance, IMP_MAX)
            plt.plot(rel_x, rel_y, '.b')
            plt.plot(cDamp)
            plt.xlim(0, len(cDamp) - 1)

        # Approximation level after the last DWT step, spanning the bottom row.
        fig = plt.subplot(n_rows, 1, n_rows)
        plt.title('Approximation Coefficient')
        cAamp = [cA[x] for x in xi]
        rel_x, rel_y = self._annotate_dwt_importance_pairs(
            fig, cAamp, relation_importance[-1], Thres_min_importance, IMP_MAX)
        plt.plot(rel_x, rel_y, '.b')
        plt.plot(cAamp)
        plt.xlim(0, len(cAamp) - 1)

        if showFigure:
            plt.show()
        if savefigname is not None:
            Fig_main.savefig(savefigname, dpi=Fig_main.dpi)
            Fig_main.clf()

    def _annotate_dwt_importance_pairs(self, fig, coef_amp, rel_layer,
                                       thres_min_importance, imp_max):
        '''Draw arrows for the important pairs of one layer on ``fig``.

        Pairs below ``thres_min_importance`` are skipped; the rest are drawn
        in ascending importance order so the strongest arrows end up on top.
        A pair position at or beyond ``len(coef_amp)`` is drawn at amplitude 0.

        Returns:
            (rel_x, rel_y): end-point coordinates of every drawn pair.
        '''
        cur_N = len(coef_amp)
        imp_range = imp_max - thres_min_importance
        rel_x = []
        rel_y = []
        # sorted() rather than list.sort(): leave the caller's layer untouched.
        for rel_pair, imp in sorted(rel_layer, key=lambda x: x[1]):
            if imp < thres_min_importance:
                continue
            # Scale into [0, 1]; when every shown pair has the maximum
            # importance the range is 0, so give them full colour instead of
            # dividing by zero.
            if imp_range > 0:
                alpha = (imp - thres_min_importance) / imp_range
            else:
                alpha = 1.0
            arrow_color = self.get_RGB_from_Alpha((1, 0, 0), alpha, (1, 1, 1))
            arrowprops = dict(width=0.15,
                              linewidth=0.15,
                              headwidth=1.5,
                              headlength=1.5,
                              facecolor=arrow_color,
                              edgecolor=arrow_color,
                              shrink=0)
            for pos in (rel_pair[0], rel_pair[1]):
                rel_x.append(pos)
                rel_y.append(coef_amp[pos] if pos < cur_N else 0)
            fig.annotate('',
                         xy=(rel_x[-2], rel_y[-2]),
                         xytext=(rel_x[-1], rel_y[-1]),
                         arrowprops=arrowprops)
        return rel_x, rel_y

    def get_RGB_from_Alpha(self, color, alpha, bgcolor):
        '''Blend ``color`` over ``bgcolor`` with opacity ``alpha``.

        Each channel becomes ``(1 - alpha) * background + alpha * foreground``
        (both converted to float); the result is a list of floats.
        '''
        return [(1.0 - alpha) * float(bg) + alpha * float(fg)
                for fg, bg in zip(color, bgcolor)]

    def plot_dwt_pairs_no_arrow(self, rawsig, relation_importance,
                                Window_Left=1000):
        '''Plot the RSWT illustration with pair end points only (no arrows).

        The raw window starting at ``Window_Left`` is shown first, then
        ``pywt.dwt`` is applied five times; for each detail level the end
        points of every pair in the matching ``relation_importance`` layer
        are marked with blue dots over the detail coefficients.

        Args:
            rawsig: Raw ECG samples (indexable sequence).
            relation_importance: List of layers of ``((pos0, pos1), imp)``;
                layers 0..4 are used.
            Window_Left: First raw-signal index of the displayed window
                (defaults to 1000, the previously hard-coded value).
        '''
        N = 5
        figureID = 1
        fs = conf['fs']
        FixedWindowLen = conf['winlen_ratio_to_fs'] * fs
        print('Fixed Window Length:{}'.format(FixedWindowLen))
        xL = Window_Left
        xR = xL + FixedWindowLen
        waveletobj = WTfeature().gswt_wavelet()
        pltxLim = range(xL, xR)
        sigAmp = [rawsig[x] for x in pltxLim]
        cA = rawsig

        # Raw signal input.
        plt.figure(figureID)
        plt.subplot(N + 1, 1, 1)
        plt.plot(pltxLim, sigAmp)
        plt.title('ECG sample')

        for i in range(2, N + 2):
            rel_layer = relation_importance[i - 2]
            cA, cD = pywt.dwt(cA, waveletobj)
            # Each DWT step halves the coefficient count; ``//`` keeps ints.
            xL //= 2
            xR //= 2
            cDamp = [cD[x] for x in range(xL, xR)]
            cur_N = len(cDamp)
            rel_x = []
            rel_y = []
            for rel_pair, _ in rel_layer:
                for pos in (rel_pair[0], rel_pair[1]):
                    rel_x.append(pos)
                    # Positions outside the cropped window are drawn at 0.
                    rel_y.append(cDamp[pos] if pos < cur_N else 0)
            plt.subplot(N + 1, 1, i)
            plt.plot(rel_x, rel_y, '.b')
            plt.plot(cDamp)
            plt.xlim(0, len(cDamp) - 1)
            plt.title('DWT Level ({}):'.format(i - 1))
        plt.show()

    def plot_fv_importance(self):
        '''Save one pair-importance figure per window offset 0, 10, ..., 120.

        Loads QT record index 2, obtains pair importances from
        ``feature_importance_test`` and writes ``range_<offset>.png`` into
        ``curfolderpath`` for windows starting at ``1180 + offset``.
        '''
        record_index = 2
        record = self.qt.load(self.qt_reclist[record_index])
        pair_importance = self.feature_importance_test()
        for offset in range(0, 130, 10):
            out_path = os.path.join(curfolderpath,
                                    'range_{}.png'.format(offset))
            self.plot_dwt_pairs_arrow(record['sig'],
                                      pair_importance,
                                      Window_Left=1180 + offset,
                                      savefigname=out_path,
                                      figsize=(16, 12),
                                      figtitle='Window Start[{}]'.format(offset))

    def plot_fv_importance_gswt(self,
                                savefigname,
                                showFigure,
                                WindowLeftBias=10):
        '''Plot GSWT pair importances for QT record index 2.

        ``get_pair_importance_list`` supplies the layers, each a list of
        ``(pair, importance)`` tuples. The plotted window starts at
        ``54100 + WindowLeftBias``.
        '''
        record_name = self.qt_reclist[2]
        record = self.qt.load(record_name)
        pair_importance = self.get_pair_importance_list()
        self.plot_dwt_pairs_arrow_transparent(
            record['sig'],
            pair_importance,
            Window_Left=54100 + WindowLeftBias,
            savefigname=savefigname,
            figsize=(14, 14),
            figtitle='ECG from QTdb Record {}'.format(record_name),
            showFigure=showFigure)

    def load_sig_test(self):
        '''Debugging scratchpad: load a QT record and its random pairs.

        NOTE(review): this method looks unfinished ("todo"). After the
        ``pdb.set_trace()`` breakpoint it uses names that are not defined in
        this method (``wtfobj``, ``normwinsig``, ``WTrelList``, ``features``,
        ``dwtECG``, ``recname``); unless they exist as module globals
        elsewhere in the file it will raise NameError. Also confirm that the
        keyword arguments passed to ``plot_dwt_rswt`` match its signature.
        '''
        pass
        rID = 1
        sig = self.qt.load(self.qt_reclist[rID])
        # random relations
        randrel_path = os.path.join(curfolderpath, 'rand_relations.json')
        with open(randrel_path, 'r') as fin:
            randrels = json.load(fin)
        # Interactive breakpoint for inspecting the loaded data.
        pdb.set_trace()
        # get rand positions and WT coef
        dslist = wtfobj.getWTcoef_gswt(normwinsig)
        # NOTE(review): the loop variable ``randrels`` shadows the list loaded
        # from JSON above.
        for detailsignal, randrels in zip(dslist, WTrelList):
            # debug: stop in the debugger when a pair index falls outside
            # detailsignal.
            for x in randrels:
                for xval in x:
                    if xval < 0 or xval >= len(detailsignal):
                        print 'x = ', x
                        pdb.set_trace()

            # Signed difference of each pair's two coefficients.
            fv = [detailsignal[x[0]] - detailsignal[x[1]] for x in randrels]
            features.extend(fv)
            # Absolute difference of each pair's two coefficients.
            fv = [
                abs(detailsignal[x[0]] - detailsignal[x[1]]) for x in randrels
            ]
            features.extend(fv)
        # todo..
        waveletobj = dwtECG.gswt_wavelet()
        self.plot_dwt_rswt(sig['sig'],
                           waveletobj=waveletobj,
                           ECGrecordname=recname,
                           auto_plot=True)

    def plot_dwt_rswt(self,
                      rawsig=None,
                      waveletobj=None,
                      ECGrecordname=None,
                      auto_plot=True,
                      figureID=1):
        '''Draw the RSWT illustration with two random pairs per DWT level.

        Plots the raw window ``rawsig[1000:1600]`` and five successive
        ``pywt.dwt`` detail levels. On each level two random pairs are drawn:
        "pair(ka)" (pink, from the 10%-40% part of the level) and "pair(kb)"
        (green, from the 30%-80% part), each as two arrows from a text label
        to the pair's end points. All axes are hidden.

        Args:
            rawsig: Raw ECG samples. Required.
            waveletobj: Wavelet passed to ``pywt.dwt``; defaults to
                ``WTfeature().gswt_wavelet()`` as in the other plot methods.
            ECGrecordname: Optional record name appended to the raw title.
            auto_plot: Call ``plt.show()`` when truthy.
            figureID: Matplotlib figure number.

        Raises:
            ValueError: if ``rawsig`` is not given.
        '''
        if rawsig is None:
            raise ValueError('plot_dwt_rswt requires rawsig.')
        if waveletobj is None:
            waveletobj = WTfeature().gswt_wavelet()
        N = 5
        xL, xR = 1000, 1600
        pltxLim = range(xL, xR)
        sigAmp = [rawsig[x] for x in pltxLim]
        cA = rawsig

        # Raw signal input.
        plt.figure(figureID)
        plt.subplot(N + 1, 1, 1)
        self._hide_rswt_axes()
        plt.plot(pltxLim, sigAmp)
        title = 'ECG sample'
        if ECGrecordname is not None:
            title = '{} ({})'.format(title, ECGrecordname)
        plt.title(title)

        for i in range(2, N + 2):
            cA, cD = pywt.dwt(cA, waveletobj)
            # Each DWT step halves the coefficient count; ``//`` keeps ints.
            xL //= 2
            xR //= 2
            cDamp = [cD[x] for x in range(xL, xR)]
            fig = plt.subplot(N + 1, 1, i)
            amp_min = min(cDamp)
            amp_max = max(cDamp)
            # Labels sit at 80% of this level's amplitude range.
            label_y = amp_min + (amp_max - amp_min) * 0.8
            self._draw_rswt_random_pair(fig, cDamp, 0.1, 0.4, label_y,
                                        'pair({}a)'.format(i - 1), '#fe3b8d')
            self._draw_rswt_random_pair(fig, cDamp, 0.3, 0.8, label_y,
                                        'pair({}b)'.format(i - 1), 'g')
            self._hide_rswt_axes()
            plt.plot(cDamp)
            plt.xlim(0, len(cDamp) - 1)
            plt.title('DWT Level ({}):'.format(i - 1))
        if auto_plot:
            plt.show()

    def _draw_rswt_random_pair(self, fig, cDamp, low_ratio, high_ratio,
                               label_y, label, color):
        '''Pick a random pair inside ``cDamp`` and draw labelled arrows to it.

        Pair positions are drawn with ``random.randint`` from the index range
        ``[len * low_ratio, len * high_ratio]``; the label is placed midway
        between them at height ``label_y``.
        '''
        n = len(cDamp)
        L, R = int(n * low_ratio), int(n * high_ratio)
        pair_head = (random.randint(L, R - 1), random.randint(L + 1, R))
        label_xy = [(pair_head[0] + pair_head[1]) // 2, label_y]
        plt.text(label_xy[0], label_xy[1], label)
        arrowprops = dict(width=1,
                          headwidth=4,
                          facecolor=color,
                          edgecolor=color,
                          shrink=0)
        for head in pair_head:
            fig.annotate('',
                         xy=(head, cDamp[head]),
                         xytext=label_xy,
                         arrowprops=arrowprops)

    def _hide_rswt_axes(self):
        '''Hide the x and y axes of the current matplotlib axes.'''
        frame = plt.gca()
        frame.axes.get_xaxis().set_visible(False)
        frame.axes.get_yaxis().set_visible(False)

    def loadsig(self, sig):
        '''Placeholder hook: ignores ``sig``, does nothing and returns None.'''
class FeatureVisualizationSwt:
    '''Visualize random-forest importance of SWT random-pair features.

    The model's ``feature_importances_`` are matched with the random pairs
    stored in a JSON file (one list of pairs per wavelet layer) and drawn as
    arrows on top of SWT detail coefficients of a QT database record.
    '''

    def __init__(self, rfmdl, RandomPairList_Path):
        '''Load the random pairs and keep references to the model.

        Args:
            rfmdl: Fitted random-forest model; ``estimators_`` and
                ``feature_importances_`` are read from it.
            RandomPairList_Path: Path of the JSON file with the random-pair
                layers.
        '''
        with open(RandomPairList_Path, 'r') as fin:
            self.randrels = json.load(fin)

        self.rfmdl = rfmdl
        self.trees = rfmdl.estimators_
        self.qt = QTloader()
        self.qt_reclist = self.qt.getQTrecnamelist()

    def info(self):
        '''Print the attributes of the first tree, then open the debugger.'''
        t0 = self.trees[0]
        for attr in dir(t0):
            print(attr)
        print('tree-------------------------')
        for attr in dir(t0.tree_):
            print(attr)
        # Deliberate interactive breakpoint for inspecting the tree.
        pdb.set_trace()

    def GetImportancePairs(self):
        '''Zip the random pairs with their importance from the model.

        The importance vector is read layer by layer as
        ``[diff(p0), abs_diff(p0), diff(p1), abs_diff(p1), ...]``; a pair's
        importance is the larger of its two feature importances.

        Returns:
            One list per layer of ``(pair, importance)`` tuples, e.g.
            ``[[((23, 2), 0.2), ...], [...], ...]``.

        Raises:
            Exception: if the number of importances is not twice the total
                number of pairs.
        '''
        randrels = self.randrels
        importance_list = self.rfmdl.feature_importances_
        sum_rr = sum(len(rr) for rr in randrels)
        print('len(feature importance) = {}, len(random relations) = {}'.format(
            len(importance_list), sum_rr))
        print('len(importance_list) = {}'.format(len(importance_list)))
        print('sum_rr * 2 = {}'.format(sum_rr * 2))
        if len(importance_list) != 2 * sum_rr:
            raise Exception('Length of important features is not equal '
                            'to the total number of pairs!')

        relation_importance = []
        layer_index_start = 0
        for rel_layer in randrels:
            layer_index_end = layer_index_start + 2 * len(rel_layer)
            pair_diff_importance_list = importance_list[
                layer_index_start:layer_index_end:2]
            abs_diff_importance_list = importance_list[
                layer_index_start + 1:layer_index_end:2]
            # Keep the larger of the pair-difference and absolute
            # pair-difference importances.
            imp_arr = [max(diff_imp, abs_imp)
                       for diff_imp, abs_imp in zip(pair_diff_importance_list,
                                                    abs_diff_importance_list)]
            # A real list (not a lazy zip) so a layer can be iterated twice.
            relation_importance.append(list(zip(rel_layer, imp_arr)))
            layer_index_start = layer_index_end

        return relation_importance

    def get_min_Importance_threshold(self,
                                     relation_importance,
                                     top_importance_ratio=9.5 / 10):
        '''Return the importance at quantile ``top_importance_ratio``.

        All importances of all layers are sorted ascending and the value at
        index ``int(top_importance_ratio * N)`` is returned, clamped to the
        valid index range (so a ratio of 1.0 yields the maximum). Empty
        layers are ignored.

        Raises:
            Exception: if ``relation_importance`` is None/empty or holds no
                importance values at all.
        '''
        if relation_importance is None or len(relation_importance) == 0:
            raise Exception('relation_importance is empty!')
        imp_list = sorted(imp for layer in relation_importance
                          for _, imp in layer)
        if not imp_list:
            raise Exception('relation_importance is empty!')
        midpos = int(top_importance_ratio * float(len(imp_list)))
        midpos = max(0, min(midpos, len(imp_list) - 1))
        return imp_list[midpos]

    def plot_dwt_pairs_arrow_transparent(self,
                                         rawsig,
                                         relation_importance,
                                         Window_Left=1200,
                                         savefigname=None,
                                         figureID=1,
                                         figsize=(8, 6),
                                         figtitle='ECG Sample',
                                         showFigure=True,
                                         wavelet='db6',
                                         swt_level=6,
                                         window_length_limit=350):
        '''Plot SWT detail levels with semi-transparent importance arrows.

        ``rawsig`` is padded by ``crop_data_for_swt`` and decomposed with
        ``pywt.swt(rawsig, wavelet, swt_level)``. The first subplot shows the
        raw window ``[Window_Left, Window_Left + window length)`` with its
        centre marked. The next five subplots show detail coefficients
        ``self.cDlist[1..5]`` (so ``swt_level`` must be at least 6), with
        arrows for the pairs above the 0.96-quantile importance threshold.
        Every subplot is zoomed to ``window_length_limit`` samples around the
        window centre.

        Side effects: sets ``self.cAlist`` / ``self.cDlist`` (finest level
        first) and shows and/or saves the figure.
        '''
        N_subplot = 7
        # Grid rows; ``//`` keeps this an int on Python 3 as well.
        n_rows = N_subplot // 2
        # Importance pairs lower than this threshold are not shown.
        Thres_min_importance = self.get_min_Importance_threshold(
            relation_importance, top_importance_ratio=0.96)
        fs = conf['fs']
        FixedWindowLen = conf['winlen_ratio_to_fs'] * fs
        print('Fixed Window Length:{}'.format(FixedWindowLen))

        xL = Window_Left
        xR = xL + FixedWindowLen
        tarpos = (xL + xR) // 2
        IMP_MAX = max(imp for rel_layer in relation_importance
                      for _, imp in rel_layer)

        # SWT detail/approximation coefficient lists.
        rawsig = self.crop_data_for_swt(rawsig)
        coeflist = pywt.swt(rawsig, wavelet, swt_level)
        cAlist, cDlist = zip(*coeflist)
        # pywt.swt lists the deepest level first; reverse to finest-first.
        self.cAlist = cAlist[::-1]
        self.cDlist = cDlist[::-1]

        xi = range(xL, xR)
        sigAmp = [rawsig[x] for x in xi]

        # Raw signal input.
        Fig_main = plt.figure(figureID, figsize=figsize)
        plt.subplot(n_rows, 2, 1)
        plt.plot(sigAmp)
        # sigAmp is plotted against 0-based x, so the reference point must
        # be shifted into window coordinates too.
        plt.plot(tarpos - xL, rawsig[tarpos], 'ro')
        plt.title(figtitle)
        window_limit_left, window_limit_right = self._window_limits(
            len(sigAmp), window_length_limit)
        plt.xlim(window_limit_left, window_limit_right)
        plt.grid(True)

        # Make sure the detail subplots land in the same figure.
        plt.figure(figureID)

        detail_index = 1
        for i in range(2, N_subplot):
            # NOTE(review): relation layer ``detail_index - 1`` is paired with
            # ``cDlist[detail_index]`` (layer 0 with the second-finest level).
            # Kept as-is; confirm it matches the feature extraction.
            rel_layer = relation_importance[detail_index - 1]
            cD = self.cDlist[detail_index]
            detail_index += 1

            cDamp = [cD[x] for x in xi]
            window_limit_left, window_limit_right = self._window_limits(
                len(cDamp), window_length_limit)

            fig = plt.subplot(n_rows, 2, i)
            plt.title('DWT Detail Coefficient {}'.format(detail_index))
            rel_x, rel_y = self._annotate_layer_pairs(
                fig, cDamp, rel_layer, Thres_min_importance, IMP_MAX,
                window_limit_left, window_limit_right)

            plt.plot(rel_x, rel_y, '.b')
            plt.plot(cDamp)
            # Reference point (window centre).
            plt.plot(tarpos - xL,
                     cDamp[tarpos - xL],
                     'yo',
                     markersize=12,
                     mec='b')
            plt.xlim(window_limit_left, window_limit_right)
            plt.grid(True)

        if showFigure:
            plt.show()
        if savefigname is not None:
            Fig_main.savefig(savefigname, dpi=Fig_main.dpi)
            Fig_main.clf()

    def _window_limits(self, length, window_length_limit):
        '''Return (left, right) x-limits of a window centred in ``length``.'''
        window_center = length // 2
        return (int(window_center - window_length_limit / 2),
                int(window_center + window_length_limit / 2))

    def _annotate_layer_pairs(self, fig, cDamp, rel_layer,
                              Thres_min_importance, IMP_MAX,
                              window_limit_left, window_limit_right):
        '''Draw transparent arrows for the important pairs of one layer.

        Pairs below the threshold are dropped and the remaining importances
        are L2-normalized within the layer. Each pair whose normalized value
        is still at least the threshold contributes its two end points to the
        returned lists. An arrow is drawn only when both ends are inside
        ``[window_limit_left, window_limit_right]``.

        Returns:
            (rel_x, rel_y): end-point coordinates of the kept pairs.
        '''
        rel_x = []
        rel_y = []
        rel_layer = [x for x in rel_layer if x[1] >= Thres_min_importance]
        if not rel_layer:
            # Nothing above the threshold in this layer.
            return rel_x, rel_y
        relation_pair_list, layer_importance_list = zip(*rel_layer)
        # Normalize the importances within this layer.
        layer_importance_list = np.array(layer_importance_list).reshape(1, -1)
        layer_importance_list = sk_pre.normalize(layer_importance_list,
                                                 norm='l2').tolist()[0]

        cur_N = len(cDamp)
        imp_range = IMP_MAX - Thres_min_importance
        for rel_pair, imp in zip(relation_pair_list, layer_importance_list):
            if imp < Thres_min_importance:
                continue
            # NOTE(review): ``imp`` is normalized per layer while the
            # threshold and IMP_MAX are raw importances, so this ratio can
            # leave [0, 1]; clamp it because matplotlib rejects such alpha.
            if imp_range > 0:
                alpha = (imp - Thres_min_importance) / imp_range
            else:
                alpha = 1.0
            alpha = min(max(alpha, 0.0), 1.0)
            # Increase alpha for better visual effect.
            alpha_increase_ratio = 0.8
            alpha = alpha * alpha_increase_ratio + 1.0 - alpha_increase_ratio

            arrow_color = (1, 0, 0)
            arrowprops = dict(alpha=alpha,
                              width=0.15,
                              linewidth=0.15,
                              headwidth=1.5,
                              headlength=1.5,
                              facecolor=arrow_color,
                              edgecolor=arrow_color,
                              shrink=0)
            for pos in (rel_pair[0], rel_pair[1]):
                rel_x.append(pos)
                rel_y.append(cDamp[pos] if pos < cur_N else 0)
            # Skip arrows with an end outside the visible window.
            if any(x > window_limit_right or x < window_limit_left
                   for x in rel_x[-2:]):
                continue
            fig.annotate('',
                         xy=(rel_x[-2], rel_y[-2]),
                         xytext=(rel_x[-1], rel_y[-1]),
                         arrowprops=arrowprops)
        return rel_x, rel_y

    def PlotApproximationLevel(self, importance_list, xi=None,
                               Thres_min_importance=None, IMP_MAX=None):
        '''Plot importance arrows on the deepest SWT approximation level.

        Requires ``self.cAlist`` (set by ``plot_dwt_pairs_arrow_transparent``)
        and draws into subplot (4, 2, 8) of the current figure.

        Args:
            importance_list: One layer of ``(pair, importance)`` tuples; it is
                not modified.
            xi: Indices of the approximation coefficients to show; defaults
                to all of them.
            Thres_min_importance: Pairs below this are hidden; defaults to
                ``get_min_Importance_threshold([importance_list])``.
            IMP_MAX: Importance mapped to full opacity; defaults to the
                largest importance in ``importance_list``.
        '''
        cA = self.cAlist[-1]
        if xi is None:
            xi = range(len(cA))
        if Thres_min_importance is None:
            Thres_min_importance = self.get_min_Importance_threshold(
                [importance_list])
        if IMP_MAX is None:
            IMP_MAX = max(imp for _, imp in importance_list)

        fig = plt.subplot(4, 2, 8)
        plt.title('Approximation Coefficient')
        cAamp = [cA[x] for x in xi]
        cur_N = len(cAamp)
        imp_range = IMP_MAX - Thres_min_importance
        rel_x = []
        rel_y = []
        # Ascending importance, so the strongest arrows are drawn last.
        for rel_pair, imp in sorted(importance_list, key=lambda x: x[1]):
            if imp < Thres_min_importance:
                continue
            if imp_range > 0:
                alpha = (imp - Thres_min_importance) / imp_range
            else:
                alpha = 1.0
            alpha = min(max(alpha, 0.0), 1.0)
            arrow_color = (1, 0, 0)
            arrowprops = dict(alpha=alpha,
                              width=0.15,
                              linewidth=0.15,
                              headwidth=1.5,
                              headlength=1.5,
                              facecolor=arrow_color,
                              edgecolor=arrow_color,
                              shrink=0)
            for pos in (rel_pair[0], rel_pair[1]):
                rel_x.append(pos)
                rel_y.append(cAamp[pos] if pos < cur_N else 0)
            fig.annotate('',
                         xy=(rel_x[-2], rel_y[-2]),
                         xytext=(rel_x[-1], rel_y[-1]),
                         arrowprops=arrowprops)

        plt.plot(rel_x, rel_y, '.b')
        plt.plot(cAamp)
        plt.xlim(0, len(cAamp) - 1)

    def crop_data_for_swt(self, rawsig):
        '''Extend ``rawsig`` to the next power-of-two length for ``pywt.swt``.

        Padding repeats the last sample. The caller's sequence is never
        modified: a new list is returned when padding is needed, otherwise
        ``rawsig`` itself is returned.

        Raises:
            Exception: if ``rawsig`` has fewer than two samples.
        '''
        N_data = len(rawsig)
        if N_data <= 1:
            raise Exception(
                'len(rawsig)={},not enough for swt!'.format(N_data))
        crop_len = 1
        while crop_len < N_data:
            crop_len *= 2
        if N_data < crop_len:
            # Build a new list rather than ``+=``: avoids mutating the
            # caller's signal and element-wise adding to numpy arrays.
            rawsig = list(rawsig) + [rawsig[-1]] * (crop_len - N_data)
        return rawsig

    def get_RGB_from_Alpha(self, color, alpha, bgcolor):
        '''Blend ``color`` over ``bgcolor`` with opacity ``alpha`` (floats).'''
        return [(1.0 - alpha) * float(bg) + alpha * float(fg)
                for fg, bg in zip(color, bgcolor)]

    def plot_fv_importance(self):
        '''Save pair-importance figures for window offsets 0, 10, ..., 120.

        Uses QT record index 2 and writes ``range_<offset>.png`` into
        ``curfolderpath`` via ``plot_dwt_pairs_arrow_transparent``.
        '''
        rID = 2
        sig = self.qt.load(self.qt_reclist[rID])
        rel_imp = self.GetImportancePairs()
        for i in range(0, 130, 10):
            savefigname = os.path.join(curfolderpath, 'range_{}.png'.format(i))
            # This class only provides the transparent-arrow plotter.
            self.plot_dwt_pairs_arrow_transparent(
                sig['sig'],
                rel_imp,
                Window_Left=1180 + i,
                savefigname=savefigname,
                figsize=(16, 12),
                figtitle='Window Start[{}]'.format(i))

    def plot_fv_importance_gswt(self,
                                savefigname,
                                showFigure,
                                WindowLeftBias=10,
                                record_name='sel103'):
        '''Plot pair importances on QT record ``record_name``.

        The window starts at ``54100 + WindowLeftBias``; the figure title
        names the record that is actually plotted.
        '''
        sig = self.qt.load(record_name)
        rel_imp = self.GetImportancePairs()
        figtitle = 'ECG from QTdb Record {}'.format(record_name)
        self.plot_dwt_pairs_arrow_transparent(sig['sig'],
                                              rel_imp,
                                              Window_Left=54100 +
                                              WindowLeftBias,
                                              savefigname=savefigname,
                                              figsize=(14, 14),
                                              figtitle=figtitle,
                                              showFigure=showFigure)
def get_QTdb_recordname(index=1):
    """Return the QT database record name at position ``index``.

    A new ``QTloader`` is created on every call. The name is taken from its
    ``getQTrecnamelist()`` result by plain indexing.
    """
    record_names = QTloader().getQTrecnamelist()
    return record_names[index]
Ejemplo n.º 17
0
class RecSelector():
    def __init__(self):
        """Create the selector with its own QT loader, stored as ``self.qt``."""
        loader = QTloader()
        self.qt = loader

    def inspect_recs(self):
        """Interactively choose training records from the QT database.

        Every record not listed in ``conf['selQTall0_test_set']`` is plotted
        with its expert labels. The display range runs from the smallest to
        the largest labelled position. The user is then asked whether to keep
        the record. Records answered with ``'y'`` are written as a JSON list to
        ``selected_training_records.json`` in ``os.path.dirname(curfilepath)``.

        A record with no expert labels has no label range to display. It is
        skipped with a message instead of aborting the whole session.
        """
        reclist = self.qt.getQTrecnamelist()
        set_testing = set(conf['selQTall0_test_set'])
        # Sorted so the prompts (and the saved list) follow a reproducible
        # order instead of arbitrary set-iteration order.
        out_reclist = sorted(set(reclist) - set_testing)

        selected_record_list = []
        for ind, recname in enumerate(out_reclist):
            print('{} records left.'.format(len(out_reclist) - ind - 1))
            rawsig = self.qt.load(recname)['sig']
            # Expert labels: assumed to be a list of (position, label) pairs.
            testresult = self.qt.getexpertlabeltuple(recname)
            if not testresult:
                print('Record {} has no expert labels, skipped.'.format(
                    recname))
                continue
            positions = [pos for pos, _ in testresult]
            dispRange = (min(positions), max(positions))
            resplt = ECGResultPloter(rawsig, testresult)
            resplt.plot(plotTitle='QT record {}'.format(recname),
                        dispRange=dispRange)
            usethis = raw_input('Use this record as training record?(y/n):')
            if usethis == 'y':
                selected_record_list.append(recname)

        out_path = os.path.join(os.path.dirname(curfilepath),
                                'selected_training_records.json')
        with open(out_path, 'w') as fout:
            json.dump(selected_record_list, fout)

    def pick_invalid_record(self):
        reclist = self.qt.getQTrecnamelist()
        invalidrecordlist = []
        for ind, recname in enumerate(reclist):
            # inspect
            print 'Inspecting record ' '{}' ''.format(recname)
            sig = self.qt.load(recname)
            if abs(sig['sig'][0]) == float('inf'):
                invalidrecordlist.append(recname)
        return invalidrecordlist

    def save_recs_to_img(self):
        reclist = self.qt.getQTrecnamelist()
        sel1213 = conf['sel1213']
        sel1213set = set(sel1213)

        out_reclist = set(reclist) - sel1213set

        for ind, recname in enumerate(reclist):
            # inspect
            print '{} records left.'.format(len(out_reclist) - ind - 1)
            #self.qt.plotrec(recname)
            self.qt.PlotAndSaveRec(recname)
            # debug
            if ind > 9:
                pass
                #print 'debug break'
                #break

    def inspect_recname(self, tarrecname):
        self.qt.plotrec(tarrecname)

    def inspect_selrec(self):
        """Re-inspect every record listed in ``selected_training_records.json``.

        The JSON list is read from ``curfolderpath``. For each record name in
        it, a countdown is printed and the name is passed to
        ``inspect_recname``.
        """
        # The full QT record list used to be fetched here but never used;
        # that loader call and the dead commented-out code were removed.
        selpath = os.path.join(curfolderpath, 'selected_training_records.json')
        with open(selpath, 'r') as fin:
            sel0115 = json.load(fin)
        for ind, recname in enumerate(sel0115):
            print('{} records left.'.format(len(sel0115) - ind - 1))
            self.inspect_recname(recname)

    def RFtest(self, testrecname):
        ecgrf = ECGRF()
        sel1213 = conf['sel1213']
        ecgrf.training(sel1213)
        Results = ecgrf.testing([
            testrecname,
        ])
        # Evaluate result
        filtered_Res = ECGRF.resfilter(Results)
        stats = ECGstats(filtered_Res[0:1])
        Err, FN = stats.eval(debug=False)

        # write to log file
        EvalLogfilename = os.path.join(projhomepath, 'res.log')
        stats.dispstat0(\
                pFN = FN,\
                pErr = Err)
        # plot prediction result
        stats.plotevalresofrec(Results[0][0], Results)