# Code example #1
# File: Controller.py  Project: JohnCEarls/AUREA
class Controller:
    def __init__(self, workspace ):
        This is the controller (MVC design pattern) for the GUI.
        workspace is the folder that contains the config file and where
        temporary files and data files will be brought.
        self.softFile = []
        self.geneNetworkFile = None
        self.geneSynonymFile = None
        self.softparser = []
        self.datatable = []
        self.datapackage = None
        self.workspace = workspace
        configFile = os.path.join(self.workspace, 'data', 'config.xml')
        logo = os.path.join(self.workspace, 'data', 'AUREA-logo-200.pgm')
        if not os.path.exists(configFile):
            raise Exception, configFile + " not found.  Exiting"
        if not os.path.exists(logo):
            raise Exception, logo + " not found.  Exiting"
        self.config = SettingsParser(configFile)
        self.dirac = None
        self.tsp = None
        self.ktsp = None
        self.tst = None
        self.adaptive = None

        #for classification page
        self.tsp_classified_results = []
        self.ktsp_classified_results = []
        self.tst_classified_results = []
        self.dirac_classified_results = []
        self.adaptive_classified_results = []

        #stores truth tables from crossValidation
        self.tsp_tt = None
        self.ktsp_tt = None
        self.tst_tt = None
        self.dirac_tt = None
        self.adaptive_tt = None

        #stores tuples with apparent accuracy of learners
        self.tsp_acc = None
        self.ktsp_acc = None
        self.tst_acc = None
        self.dirac_acc = None
        self.adaptive_acc = None       

        self.dependency_state = [0 for x in range(AUREARemote.NumStates)]#see App.AUREARemote for mappings

    def setSOFTFile(self, softFile):
        """
        Deprecated; always raises Exception.  Use addSOFTFile instead.
        """
        raise Exception("controller.setSOFTFile deprecated")

    def addSOFTFile(self, softFile):   
        self.softFile.append( softFile ) 

    def setGeneNetworkFile(self, gnFile):
        self.geneNetworkFile = gnFile

    def setSynonymFile(self, sFile):
        self.geneSynonymFile = sFile

    def getNetworkInfo(self):
        if self.datapackage is not None:
            fname = self.geneNetworkFile
            gnc = self.datapackage.getGeneNetCount()
            if gnc is None:
                return None
            count, ave, max, min = gnc
            return (fname, count, ave, max, min)
            return None

    def setApp(self, app):
        self.app = app
        self.remote = app.remote
        self.queue = self.app.thread_message_queue

    def updateState(self, dependency, satisfied):
        Given a dependency, update to satisfied in global dependencies, propagating change
        #update provided state
        self.dependency_state[dependency] = satisfied
        #clear all dependents
        for i,d in enumerate(self.remote.getDependents(dependency)):
            if d == 1:
                self.updateState(i, 0)

    def initWorkspace(self):
        """
        Initialize the workspace files.
        Copy from the system any necessary files. (config, gene_syn, etc)

        Deprecated; always raises Exception.
        """
        raise Exception("initWorkspace is deprecated, download the workspace.zip file")

    def downloadSOFT(self, softfilename):
        """
        Download softfilename into the application's data directory and
        return the path reported by the downloader.

        Note raises urllib2.URLError when the download attempt fails
        """
        self.queue.put(('statusbarset',"Downloading " + softfilename))
        dl = SOFTDownloader(softfilename, output_directory=self.app.data_dir)
        return dl.getFilePath()

    def loadFiles(self):
        Loads all of the specified files
        if self.geneNetworkFile:
        if self.geneSynonymFile:
            self.queue.put(('statusbarset',"Loading Synonyms"))
        self.queue.put(('statusbarset',"Data import complete"))

    def validFileFormat(self, filename):
        """
        Takes a file name (not a full path) and returns a true value
        (the regex match object) if it is a parseable format, otherwise
        None.

        Accepted names: GDS<3-4 digits>.soft, optionally followed by .gz,
        and <word characters>.csv.
        NOTE: if you add a parser, edit this method
        """
        import re
        # (?:\.gz)? is an optional literal ".gz"; the previous [\.gz]? was
        # a character class.  Both patterns are anchored at the end so
        # trailing junk such as "x.csv.bak" is rejected.
        gdssoft = re.compile(r'GDS\d{3,4}\.soft(?:\.gz)?$')
        csv = re.compile(r'\w+\.csv$')
        return gdssoft.match(filename) or csv.match(filename)
    def unloadFiles(self):
        self.softparser = []
        self.geneNetworkParser = None
        self.datapackage = None
        self.datatable = []
        self.softFile = []

    def parseSOFTFiles(self):
        adds SOFTparser objects to softparser
        for fil in self.softFile:
            self.queue.put(('statusbarset',"Parsing " + fil))
            if fil[-3:] == "csv":
                gc = self.config.getSetting("datatable",fil +  "(Gene Column)")[0]
                pc = self.config.getSetting("datatable",fil +  "(Probe Column)")[0]
                self.softparser.append(CSVParser(fil,probe_column_name=pc, gene_column_name=gc ))

    def buildDataTables(self):
        Builds a table for each data file
        self.queue.put(('statusbarset',"Building tables"))
        collision = self.config.getSetting("datatable", "Gene Collision Rule")[0]
        bad_data =self.config.getSetting("datatable", "Bad Data Value")[0] 
        gene_column = self.config.getSetting("datatable", "Gene Column")[0]
        probe_column = self.config.getSetting("datatable", "Probe Column")[0]
        for sf in self.softparser:
            if isinstance(sf, SOFTParser):
                self.datatable.append( DataTable(probe_column, gene_column, collision, bad_data))
                self.datatable.append( DataTable(probe_column, gene_column, collision, bad_data))      

    def buildDataPackage(self):
        """
        Merges all the data tables into one package from which we pull our
        data for learning
        """
        self.queue.put(('statusbarset',"Building data package"))
        self.datapackage = dataPackager()
        for dt in self.datatable:
            # NOTE(review): the loop body was missing from the copy under
            # review.  addDataTable is the presumed packager call for
            # merging a table -- confirm against dataPackager's API.
            self.datapackage.addDataTable(dt)

    def getDataPackagingResults(self):
        Returns a list of tuples with
        (genes in merge, probes in merge)
        if self.datapackage is not None:
            return self.datapackage.getDataCount()
            return (None,None)

    def getLearnerAccuracy(self):
        Returns the apparent accuracy of the learners over the training set
        (TSP,kTSP,TST,DiRaC, Adaptive)
        return (self.tsp_acc, self.ktsp_acc,self.tst_acc,self.dirac_acc,self.adaptive_acc)

    def getCrossValidationResults(self):
        Returns the cross validation accuracy of the learners over 
        the training set
        (TSP,kTSP,TST,DiRaC, Adaptive)
        return (self.tsp_tt,self.ktsp_tt,self.tst_tt,self.dirac_tt,self.adaptive_tt)

    def getCVTruthTables(self):
        return (self.tsp_tt,self.ktsp_cv,self.tst_cv,self.dirac_cv,self.adaptive_cv)
    def parseNetworkFile(self):
        """
        Parse network file and add networks to datapackage
        """
        self.queue.put(('statusbarset',"Parsing network file"))
        self.geneNetworkParser = GMTParser(self.geneNetworkFile)
        self.queue.put(('statusbarset',"Loading gene networks"))
        # NOTE(review): the step that hands the parsed networks to the
        # data package is not visible in the copy under review -- TODO
        # restore from upstream.
    def createClassification(self, page):
        Gets and sets the class labels
        c1 = self.class1name = page.className1.get().strip()
        c2 = self.class2name = page.className2.get().strip()

    def getClassificationInfo(self):
        Returns the names and sizes of the partitioned classes 
        (c1name, c1size, c2name, c2size)
        ('',0,'',0) is returned if classifications have not been created
        if self.datapackage is not None:
            classinfo = self.datapackage.getClassifications()
            if len(classinfo) > 0:
                return (classinfo[0][0], len(classinfo[0][1]), classinfo[1][0], len(classinfo[1][1]))
        return ('',0,'',0)

    def partitionClasses(self, class1List, class2List):
        Puts the samples into their chosen classes
        for table, sample in class1List:
            self.datapackage.addToClassification(self.class1name, table, sample)

        for table, sample in class2List:
            self.datapackage.addToClassification(self.class2name, table, sample)

    def getSamples(self):
        Returns a list of strings describing all available samples in the data tables
        [ '[dt1].samp_name', '[dt2].samp_name', ...]
        sample_list = []
        for table in self.datapackage.getTables():
            table_id = table.dt_id
            for sample_id in table.getSamples():
                sample_list.append(self._makeSampleString(table_id, sample_id))
        return sample_list
    def _makeSampleString(self, table_id, sample_id):
        helper to keep sample strings consistent
        return "[" + table_id + "]." + sample_id

    def getSubsets(self):
        Returns a list of 2-tuples, 
        (description, list of sample names formatted to match getSamples),
        sample_list = self.getSamples()[:]
        sample_set = set(self.getSamples())
        subset_list = []        

        for table in self.datapackage.getTables():
            table_id = table.dt_id
            for ssetdesc, ssetsamples in table.subsets:
                sschecked_list = []
                for sample_id in ssetsamples:
                    sssid = self._makeSampleString(table_id, sample_id)
                    if sssid in sample_set:
                subset_list.append((ssetdesc, sschecked_list))

        return subset_list

    def clearClassSamples(self):
        Removes samples from any classifications
        if self.datapackage is not None:
    def getUntrainedSamples(self):
        This is basically getSamples with the training set removed
        currentClassifications = self.datapackage.getClassifications()
        classified_samples = []
        for cc_class, cc_samples in currentClassifications:
            for samp in cc_samples:
                 text = "[" + samp[0] + "]." + samp[1]
        all_samples = self.getSamples()
        #do an inorder comparison to build list
        unclassified_samples = []
        curr_class_samp_i = 0
        for sample in all_samples:
            if curr_class_samp_i == len(classified_samples) or sample != classified_samples[curr_class_samp_i]:
                curr_class_samp_i += 1
        return unclassified_samples        

    def addUnclassified(self, table, sample_name):
        Adds an unclassified sample to the data package
        self.datapackage.setUnclassified(table, sample_name)
    def getSampleInfo(self, table, sample_id):
        Gets the information about a sample if it is available
        table = self.datapackage.getTable(table)        
        return table.getSampleDescription(sample_id)
    def _getLearnerAccuracy(self, learner, row_key):
        Takes a trained learner (and its row_key gene/probe)
         and returns the results of
        classifying the Trained Data
        Returns a tuple (T0,T1,F0, F1)
        Note T0 = True Positive = True class 1
        import math
        dp = self.datapackage
        class1, class2 = dp.getClassifications()
        T0 = 0
        F0 = 0
        T1 = 0
        F1 = 0
        for table, sample in class1[1]:
            self.addUnclassified(table, sample)
            if learner.classify() == 0:
                T0 += 1
                F1 += 1  

        for table, sample in class2[1]:
            self.addUnclassified(table, sample)
            if learner.classify() == 1:
                T1 += 1
                F0 += 1  
        myacc =  float(T0+T1)/(T0+T1+F0+F1)
        return (T0,T1,F0, F1)

    def trainDirac(self, crossValidate=False):
        """
        Build a Dirac learner from the "dirac" settings and the current
        data package.

        With crossValidate=True the newly constructed learner is returned
        and nothing is stored.  Otherwise it is kept as self.dirac and the
        result of _getLearnerAccuracy is stored in self.dirac_acc.
        """
        notify = self.queue.put
        notify(('statusbarset',"Preparing Dirac"))
        setting = self.config.getSetting
        min_net = setting("dirac","Minimum Network Size")[0]
        row_key = setting("dirac","Row Key(genes/probes)")[0]
        numTopNetworks = setting("dirac","Number of Top Networks")[0]

        package = self.datapackage
        data_vector, num_genes = package.getDataVector(row_key)
        class_vector = package.getClassVector()
        gene_net, gene_net_size = package.getGeneNetVector(min_net)
        netMap = package.gene_net_map

        learner = dirac.Dirac(data_vector, num_genes,class_vector, gene_net, gene_net_size, numTopNetworks, netMap)
        if crossValidate:
            return learner
        notify(('statusbarset',"Training Dirac"))
        # NOTE(review): no explicit training call on the learner is
        # visible between these two status messages -- confirm upstream.
        self.dirac = learner
        notify(('statusbarset',"Training Complete, Checking Accuracy"))
        self.dirac_acc = self._getLearnerAccuracy(self.dirac, row_key)
        notify(('statusbarset',"Accuracy Check Complete"))

    def trainTSP(self, crossValidate=False):
        """
        Performs the training of TSP

        With crossValidate=True the newly constructed learner is returned
        and nothing is stored.  Otherwise it is kept as self.tsp and the
        result of _getLearnerAccuracy is stored in self.tsp_acc.
        """
        self.queue.put(('statusbarset',"Preparing TSP"))
        filters = self.config.getSetting("tsp","filters")
        row_key = self.config.getSetting("tsp","Row Key(genes/probes)")[0]
        data_vector, num_genes = self.datapackage.getDataVector(row_key)
        class_vector = self.datapackage.getClassVector()
        vecFilter = tsp.IntVector()
        for val in filters:
            # NOTE(review): the loop body was missing from the copy under
            # review; push_back is the presumed way to fill the vector
            # wrapper -- confirm against the tsp module.
            vecFilter.push_back(val)
        self.queue.put(('statusbarset',"Init TSP"))

        t = tsp.TSP(data_vector, num_genes, class_vector, vecFilter)
        if crossValidate:
            return t
        self.queue.put(('statusbarset',"Training TSP"))
        # NOTE(review): no explicit training call on the learner is
        # visible here -- confirm upstream.
        self.tsp = t
        self.queue.put(('statusbarset',"Training Complete, Checking Accuracy"))
        self.tsp_acc = self._getLearnerAccuracy(self.tsp, row_key)
        self.queue.put(('statusbarset',"Accuracy Check Complete"))

    def trainTST(self, crossValidate=False):
        """
        Performs the training of tst

        With crossValidate=True the newly constructed learner is returned
        and nothing is stored.  Otherwise it is kept as self.tst and the
        result of _getLearnerAccuracy is stored in self.tst_acc.
        """
        self.queue.put(('statusbarset',"Preparing TST"))
        filters = self.config.getSetting("tst","filters")
        row_key = self.config.getSetting("tst","Row Key(genes/probes)")[0]
        data_vector, num_genes = self.datapackage.getDataVector(row_key)
        class_vector = self.datapackage.getClassVector()
        vecFilter = tst.IntVector()
        for val in filters:
            # NOTE(review): the loop body was missing from the copy under
            # review; push_back is the presumed way to fill the vector
            # wrapper -- confirm against the tst module.
            vecFilter.push_back(val)
        t = tst.TST(data_vector, num_genes, class_vector, vecFilter)
        if crossValidate:
            return t
        self.tst = t
        self.queue.put(('statusbarset',"Training TST"))
        # NOTE(review): no explicit training call on the learner is
        # visible here -- confirm upstream.
        self.queue.put(('statusbarset',"Training Complete, Checking Accuracy"))
        self.tst_acc = self._getLearnerAccuracy(self.tst, row_key)
        self.queue.put(('statusbarset',"Accuracy Check Complete"))

    def trainkTSP(self, crossValidate=False):
        """
        Performs the training of k-TSP

        Raises Exception when any configured filter is smaller than twice
        the "Maximum K value" setting.

        With crossValidate=True the newly constructed learner is returned
        and nothing is stored.  Otherwise it is kept as self.ktsp and the
        result of _getLearnerAccuracy is stored in self.ktsp_acc.
        """
        self.queue.put(('statusbarset',"Preparing k-TSP"))
        maxk = self.config.getSetting("ktsp","Maximum K value")[0]
        cross_remove = self.config.getSetting("ktsp","Remove for Cross Validation")[0]
        num_cross = self.config.getSetting("ktsp","Number of Cross Validation Runs")[0]
        row_key = self.config.getSetting("ktsp","Row Key(genes/probes)")[0]
        data_vector, num_genes = self.datapackage.getDataVector(row_key)
        filters = self.config.getSetting("ktsp","filters")
        class_vector = self.datapackage.getClassVector()
        vecFilter = tst.IntVector()
        for x in filters:
            if x < 2*maxk:
                raise Exception("Ktsp setting error.  The filters must be at least twice the Maximum K value")
            # NOTE(review): in the copy under review the vector was never
            # filled, so KTSP always received an empty filter.  push_back
            # is the presumed fill call -- confirm against the tst module.
            vecFilter.push_back(x)

        k = ktsp.KTSP( data_vector, num_genes, class_vector, vecFilter, maxk, cross_remove, num_cross)
        if crossValidate:
            return k
        self.queue.put(('statusbarset',"Training k-TSP"))
        # NOTE(review): no explicit training call on the learner is
        # visible here -- confirm upstream.
        self.ktsp = k
        self.queue.put(('statusbarset',"Training Complete, Checking Accuracy"))
        self.ktsp_acc = self._getLearnerAccuracy(self.ktsp, row_key)
        self.queue.put(('statusbarset',"Accuracy Check Complete"))

    def trainAdaptive(self, target_accuracy, maxTime):
        """
        Run the adaptive learner search and store its outcome.

        target_accuracy is converted with float() and maxTime with int()
        before being passed to Adaptive.getLearner.  The history, best
        learner, best settings, best score and setting string are stored
        on the controller.  When a learner was found its accuracy over the
        training set is stored in self.adaptive_acc; otherwise a failure
        message is posted to the status bar.
        """
        self.queue.put(('statusbarset',"Configuring adaptive training"))
        target_accuracy = float(target_accuracy)
        maxTime = int(maxTime)

        #build learner queue
        # NOTE(review): the statement under this comment was missing from
        # the copy under review; _adaptiveSetup() is the sibling that
        # configures self.learnerqueue -- confirm against upstream.
        self._adaptiveSetup()
        #create adaptive object
        adaptive = Adaptive(self.learnerqueue, app_status_bar = self.queue)
        top_acc, top_settings, top_learner = adaptive.getLearner(target_accuracy, maxTime)
        #store adaptive results (really should be in adaptive)
        self.adaptive_history = adaptive.getHistory()
        self.adaptive = top_learner
        self.adaptive_settings = top_settings
        self.adaptive_top_mcc = top_acc
        self.adaptive_setting_string = adaptive.getSettingString(top_settings)
        if self.adaptive is not None:
            row_key = top_settings['data_type']
            self.queue.put(('statusbarset',"Training Complete, Checking Accuracy"))
            self.adaptive_acc = self._getLearnerAccuracy(self.adaptive, row_key)
            self.queue.put(('statusbarset',"Accuracy Check Complete"))
        else:
            #none of the algorithms ran, maybe timeout is to low
            self.queue.put(('statusbarset',"Adaptive failed to run. Is the timeout too low?"))

    def _adaptiveSetup(self):
        self.queue.put(('statusbarset',"Configuring dirac"))
        self.queue.put(('statusbarset',"Configuring tsp"))
        self.queue.put(('statusbarset',"Configuring tst"))
        self.queue.put(('statusbarset',"Configuring ktsp"))
        self.queue.put(('statusbarset',"Relational learners configured"))

    def _adaptiveSetupLearnerQueue(self):
        """
        Create self.learnerqueue from the "adaptive" settings and the
        current data package.
        """
        def adaptive_setting(key):
            return self.config.getSetting("adaptive", key)

        package = self.datapackage
        #Learner Queue Settings
        wilc_data_type = adaptive_setting("Wilcoxon Row Key (gene/probe)")[0]
        weight = adaptive_setting("Initial Weight (dirac,tsp,tst,ktsp)")
        # The "Initial Scale (dirac,tsp,tst,ktsp)" setting is not read;
        # None is passed as the scale.
        scale = None
        min_weight = adaptive_setting("Minimum Weight")[0]
        self.learnerqueue = LearnerQueue(package, wilc_data_type, weight, scale, min_weight)

    def _adaptiveSetupDirac(self):
        #dirac settings
        d_row_key = self.config.getSetting("adaptive", "Dirac-Row Key(gene/probe)")[0]
        d_min_net = self.config.getSetting("adaptive", "Dirac-Min. Network Size Range")
        d_num_top_net = self.config.getSetting("adaptive", "Dirac-Num Top Networks Range")
        self.learnerqueue.genDirac(d_min_net, d_num_top_net, d_row_key)        

    def _adaptiveSetupTSP(self):
        #tsp settings
        p_row_key = self.config.getSetting("adaptive", "TSP-Row Key(gene/probe)")[0]
        p_equijoin = self.config.getSetting("adaptive", "TSP-Only use equal filters")[0]
        p_filter_1 = self.config.getSetting("adaptive", "TSP-Filter 1 Range")
        p_filter_2 = self.config.getSetting("adaptive", "TSP-Filter 2 Range")
        self.learnerqueue.genTSP(tuple(p_filter_1), tuple(p_filter_2), p_equijoin, p_row_key)

    def _adaptiveSetupTST(self):
        #tst settings
        t_row_key = self.config.getSetting("adaptive", "TST-Row Key(gene/probe)")[0]
        t_equijoin = self.config.getSetting("adaptive", "TST-Only use equal filters")[0]
        t_filter_1 = self.config.getSetting("adaptive", "TST-Filter 1 Range")
        t_filter_2 = self.config.getSetting("adaptive", "TST-Filter 2 Range")
        t_filter_3 = self.config.getSetting("adaptive", "TST-Filter 3 Range")           
        self.learnerqueue.genTST(tuple(t_filter_1), tuple(t_filter_2), tuple(t_filter_3), t_equijoin, t_row_key)

    def _adaptiveSetupKTSP(self):
        #ktsp settings
        k_row_key = self.config.getSetting("adaptive", "k-TSP-Row Key(gene/probe)")[0]
        k_equijoin = self.config.getSetting("adaptive", "k-TSP-Only use equal filters")[0]
        k_filter_1 = self.config.getSetting("adaptive", "k-TSP-Filter 1 Range")
        k_filter_2 = self.config.getSetting("adaptive", "k-TSP-Filter 2 Range")
        k_maxK = self.config.getSetting("adaptive", "Maximum k Range")
        k_ncv = self.config.getSetting("adaptive", "Number of Internal CV to find k Range")
        k_nlo = self.config.getSetting("adaptive", "Number to leave out on Internal CV to find k Range")
        self.learnerqueue.genKTSP(tuple(k_maxK),tuple(k_ncv), tuple(k_nlo), tuple(k_filter_1), tuple(k_filter_2), k_equijoin, k_row_key)
    def clearLearningAlg(self):
        Set all learning algorithms to None.
        Happens when we change something further up the dependency
        self.dirac = None
        self.tsp = None
        self.tst = None
        self.ktsp = None
        self.adaptive = None
        self.tsp_acc = None
        self.ktsp_acc = None
        self.tst_acc = None
        self.dirac_acc = None
        self.adaptive_acc = None
        self.tsp_tt = None
        self.ktsp_tt = None
        self.tst_tt = None
        self.dirac_tt = None
        self.adaptive_tt = None

    def classifyDirac(self):
        dp = self.datapackage
        row_key = self.config.getSetting("dirac","Row Key(genes/probes)")[0]
        self.dirac.addUnclassified(dp.getUnclassifiedDataVector(row_key ))
        self.dirac_classification = self.dirac.classify()
        return self.dirac_classification

    def classifyTSP(self):
        dp = self.datapackage
        row_key = self.config.getSetting("tsp","Row Key(genes/probes)")[0]
        self.tsp_classification = self.tsp.classify()
        return self.tsp_classification

    def classifyTST(self):
        dp = self.datapackage
        row_key = self.config.getSetting("tst","Row Key(genes/probes)")[0]
        self.tst_classification = self.tst.classify()
        return self.tst_classification

    def classifykTSP(self):
        dp = self.datapackage
        row_key = self.config.getSetting("ktsp","Row Key(genes/probes)")[0]
        self.ktsp_classification = self.ktsp.classify()
        return self.ktsp_classification

    def classifyAdaptive(self):
        dp = self.datapackage
        learner = self.adaptive
        settings = self.adaptive_settings
        row_key = settings['data_type']
        self.adaptive_classification = learner.classify()
        return self.adaptive_classification

    def _acc(self, truth_table):
        Given truth_table compute accuracy
        tpos = truth_table[0]
        tneg = truth_table[1]
        fpos = truth_table[2]
        fneg = truth_table[3]
        return float(tpos+tneg)/(tpos+tneg+fpos+fneg)

    def _mcc(self, truth_table):
        Given truth_table, compute MCC
        import math
        tpos = truth_table[0]
        tneg = truth_table[1]
        fpos = truth_table[2]
        fneg = truth_table[3]

        den = math.sqrt(float((tpos+fpos)*(tpos+fneg)*(tneg+fpos)*(tneg+fneg)))
        if den < .000001:
            den = 1.0
        return float(tpos*tneg - fpos*fneg)/den

    def crossValidateDirac(self):
        dirac = self.trainDirac(crossValidate = True)
        self.queue.put(('statusbarset',"Cross Validating"))
        tt = dirac.truth_table
        self.dirac_tt = [tt[0],tt[1],tt[2], tt[3]]#sometimes intvectors do not want to iterate
        self.queue.put(('statusbarset',"Dirac had an Accuracy of " + str(self._acc(self.dirac_tt))[:4]))

    def crossValidateTSP(self):
        tsp = self.trainTSP(crossValidate = True)
        self.queue.put(('statusbarset',"Cross Validating"))
        tt = tsp.truth_table
        self.tsp_tt = [tt[0],tt[1],tt[2], tt[3]]#sometimes intvectors do not want to iterate
        self.queue.put(('statusbarset',"TSP had an Accuracy of " + str(self._acc(self.tsp_tt))[:4]))

    def crossValidateTST(self):
        tst = self.trainTST(crossValidate = True)
        self.queue.put(('statusbarset',"Cross Validating"))
        tt = tst.truth_table
        self.tst_tt = [tt[0],tt[1],tt[2], tt[3]]#sometimes intvectors do not want to iterate
        self.queue.put(('statusbarset',"TST had an Accuracy of " + str(self._acc(self.tst_tt))[:4]))

    def crossValidateKTSP(self):
        ktsp = self.trainkTSP(crossValidate = True)
        self.queue.put(('statusbarset',"Cross Validating"))
        tt = ktsp.truth_table
        self.ktsp_tt = [tt[0],tt[1],tt[2], tt[3]]#sometimes intvectors do not want to iterate
        self.queue.put(('statusbarset',"kTSP had an Accuracy of " + str(self._acc(self.ktsp_tt))[:4]))
    def crossValidateAdaptive(self, target_acc, maxtime):
        """
        Cross validate the adaptive learner over self.learnerqueue, keep
        a copy of its truth table in self.adaptive_tt and report the
        accuracy on the status bar.
        """
        # NOTE(review): self.learnerqueue must already exist when this is
        # called; nothing in this method builds it.
        #create adaptive object
        validator = Adaptive(self.learnerqueue, app_status_bar = self.queue)
        #using accuracy, because we are reporting accuracy - acc is not used to choose learner
        validator.crossValidate(target_acc, maxtime)
        self.adaptive_tt = validator.truth_table[:]
        shown = str(self._acc(self.adaptive_tt))[:4]
        self.queue.put(('statusbarset',"Adaptive had an Accuracy of " + shown))

    def _checkRowKey(self, row_key, srcStr="Not Given"):
        if row_key not in ['gene', 'probe']:
            raise InputError(srcStr, "Given key " + row_key + " is invalid" )