예제 #1
0
 def __init__(self):
     """Configure file locations, IMG access and reporting levels for the
     simulation comparison.

     The active configuration uses the 'draft' simulation outputs. Two other
     variants follow the same naming scheme: to use one of them, replace the
     'draft' tag in each path below with 'scaffolds.draft' or with
     'random_scaffolds' (for example
     './simulations/simCompare.random_scaffolds.w_refinement_50.full.tsv').
     """
     # Simulation outputs; the 'refinment' spelling is part of the existing file name.
     self.plotPrefix = './simulations/simulation.draft.w_refinement_50'
     self.simCompareFile = './simulations/simCompare.draft.w_refinement_50.full.tsv'
     self.simCompareMarkerSetOut = './simulations/simCompare.draft.marker_set_table.w_refinement_50.tsv'
     self.simCompareConditionOut = './simulations/simCompare.draft.condition_table.w_refinement_50.tsv'
     self.simCompareTaxonomyTableOut = './simulations/simCompare.draft.taxonomy_table.w_refinement_50.tsv'
     self.simCompareRefinementTableOut = './simulations/simCompare.draft.refinment_table.w_refinement_50.tsv'

     imgMetadataFile = '/srv/whitlam/bio/db/checkm/img/img_metadata.tsv'
     tigrfamToPfamFile = '/srv/whitlam/bio/db/checkm/pfam/tigrfam2pfam.tsv'
     self.img = IMG(imgMetadataFile, tigrfamToPfamFile)

     # presumably completeness / contamination levels to report — confirm against callers
     self.compsToConsider = [0.5, 0.7, 0.8, 0.9]
     self.contsToConsider = [0.05, 0.1, 0.15]

     # presumably the resolution used when saving figures
     self.dpi = 1200
예제 #2
0
 def __init__(self):
     """Open the IMG database and preload duplicate-sequence and node metadata.

     The order of the steps is kept as-is because the helper readers are
     called on a partially initialised instance.
     """
     imgMetadataFile = '/srv/whitlam/bio/db/checkm/img/img_metadata.tsv'
     tigrfamToPfamFile = '/srv/whitlam/bio/db/checkm/pfam/tigrfam2pfam.tsv'
     self.img = IMG(imgMetadataFile, tigrfamToPfamFile)

     self.colocatedFile = './data/colocated.tsv'

     # eagerly loaded lookup tables
     self.duplicateSeqs = self.readDuplicateSeqs()
     self.uniqueIdToLineageStatistics = self.__readNodeMetadata()

     # filled lazily elsewhere; starts empty
     self.cachedGeneCountTable = None
예제 #3
0
    def __init__(self):
        """Create the marker-set builder and IMG accessor and set the grid of
        simulation parameters."""
        self.markerSetBuilder = MarkerSetBuilder()

        imgMetadataFile = '/srv/whitlam/bio/db/checkm/img/img_metadata.tsv'
        tigrfamToPfamFile = '/srv/whitlam/bio/db/checkm/pfam/tigrfam2pfam.tsv'
        self.img = IMG(imgMetadataFile, tigrfamToPfamFile)

        # contig lengths (presumably bp) and completeness/contamination
        # fractions to simulate — confirm units with callers
        self.contigLens = [1000, 2000, 5000, 10000, 20000, 50000]
        self.percentComps = [0.5, 0.7, 0.8, 0.9, 0.95, 1.0]
        self.percentConts = [0.0, 0.05, 0.1, 0.15, 0.2]
예제 #4
0
    def __workerThread(self, ubiquityThreshold, singleCopyThreshold,
                       minGenomes, colocatedDistThreshold,
                       colocatedGenomeThreshold, metadata, queueIn, queueOut):
        """Process each data item in parallel.

        Reads lineage names from queueIn until the None sentinel arrives. The
        pseudo-lineage 'Universal' is looked up as 'prokaryotes'. For every
        lineage a tuple (lineage, colocatedSets, numGenomes) is put on
        queueOut. colocatedSets is the built marker set, or None when fewer
        than minGenomes genomes belong to the lineage.

        colocatedGenomeThreshold is accepted for interface compatibility but
        is not used by this method.
        """
        img = IMG('/srv/whitlam/bio/db/checkm/img/img_metadata.tsv',
                  '/srv/whitlam/bio/db/checkm/pfam/tigrfam2pfam.tsv')
        markerSetBuilder = MarkerSetBuilder()

        while True:
            lineage = queueIn.get(block=True, timeout=None)
            if lineage is None:
                break

            taxon = 'prokaryotes' if lineage == 'Universal' else lineage
            genomeIds = img.genomeIdsByTaxonomy(taxon, metadata)

            colocatedSets = None
            if len(genomeIds) >= minGenomes:
                markerSet = markerSetBuilder.buildMarkerSet(
                    genomeIds, ubiquityThreshold, singleCopyThreshold,
                    colocatedDistThreshold)
                colocatedSets = markerSet.markerSet

            # allow results to be processed or written to file
            queueOut.put((lineage, colocatedSets, len(genomeIds)))
    def __workerThread(self, ubiquityThreshold, singleCopyThreshold, 
                       minGenomes, 
                       colocatedDistThreshold, colocatedGenomeThreshold, 
                       metadata, 
                       queueIn, queueOut):
        """Process each data item in parallel.

        Consumes lineage names from queueIn until a None sentinel is read.
        'Universal' is resolved to the 'prokaryotes' taxonomy. Each result is
        emitted on queueOut as (lineage, colocatedSets, numGenomes), where
        colocatedSets is None if the lineage has fewer than minGenomes genomes.

        colocatedGenomeThreshold is accepted for interface compatibility but
        is not used by this method.
        """
        img = IMG('/srv/whitlam/bio/db/checkm/img/img_metadata.tsv', '/srv/whitlam/bio/db/checkm/pfam/tigrfam2pfam.tsv')
        markerSetBuilder = MarkerSetBuilder()

        while True:
            lineage = queueIn.get(block=True, timeout=None)
            if lineage is None:
                break

            taxon = 'prokaryotes' if lineage == 'Universal' else lineage
            genomeIds = img.genomeIdsByTaxonomy(taxon, metadata)

            colocatedSets = None
            if len(genomeIds) >= minGenomes:
                markerSet = markerSetBuilder.buildMarkerSet(genomeIds, ubiquityThreshold, singleCopyThreshold, colocatedDistThreshold)
                colocatedSets = markerSet.markerSet

            # allow results to be processed or written to file
            queueOut.put((lineage, colocatedSets, len(genomeIds)))
예제 #6
0
    def __init__(self, outputDir):
        """Check for FastTree, record genome-tree file locations inside
        outputDir and load IMG genome metadata."""
        # performed before anything else, exactly as in the original flow
        self.__checkForFastTree()

        join = os.path.join
        self.derepConcatenatedAlignFile = join(outputDir, 'genome_tree.concatenated.derep.fasta')
        self.tree = join(outputDir, 'genome_tree.final.tre')

        imgMetadataFile = '/srv/whitlam/bio/db/checkm/img/img_metadata.tsv'
        tigrfamToPfamFile = '/srv/whitlam/bio/db/checkm/pfam/tigrfam2pfam.tsv'
        self.img = IMG(imgMetadataFile, tigrfamToPfamFile)
        self.metadata = self.img.genomeMetadata()
예제 #7
0
    def __getUniversalMarkerGenes(self, phyloUbiquityThreshold,
                                  phyloSingleCopyThreshold, outputGeneDir):
        """Find marker genes shared by the Archaea and Bacteria marker sets.

        A marker-gene set is built for 'Archaea' and for 'Bacteria' from all
        genomes IMG assigns to that lineage; a lineage without genomes is
        skipped. The intersection of the built sets is written (via str())
        to 'phylo_marker_set.txt' in outputGeneDir.

        Returns:
            (allTrustedGenomeIds, universalMarkerGenes): the union of genome
            IDs of the processed lineages, and the marker genes present in
            every processed lineage (an empty set if no lineage had genomes).
        """
        img = IMG('/srv/whitlam/bio/db/checkm/img/img_metadata.tsv',
                  '/srv/whitlam/bio/db/checkm/pfam/tigrfam2pfam.tsv')
        markerSetBuilder = MarkerSetBuilder()

        metadata = img.genomeMetadata()

        allTrustedGenomeIds = set()
        phyloMarkerGenes = {}
        for lineage in ['Archaea', 'Bacteria']:
            # get all genomes in lineage
            print('\nIdentifying all ' + lineage + ' genomes.')
            trustedGenomeIds = img.genomeIdsByTaxonomy(lineage, metadata)
            print('  Trusted genomes in lineage: ' +
                  str(len(trustedGenomeIds)))
            if len(trustedGenomeIds) < 1:
                print(
                    '  Skipping lineage due to insufficient number of genomes.'
                )
                continue

            allTrustedGenomeIds.update(trustedGenomeIds)

            print('  Building marker set.')
            markerGenes = markerSetBuilder.buildMarkerGenes(
                trustedGenomeIds, phyloUbiquityThreshold,
                phyloSingleCopyThreshold)
            phyloMarkerGenes[lineage] = markerGenes

        # universal marker genes: intersection over all processed lineages
        universalMarkerGenes = None
        for markerGenes in phyloMarkerGenes.values():
            if universalMarkerGenes is None:
                # copy so the per-lineage set returned by the builder is not
                # modified by intersection_update below
                universalMarkerGenes = set(markerGenes)
            else:
                universalMarkerGenes.intersection_update(markerGenes)

        if universalMarkerGenes is None:
            # no lineage had any genomes; previously this crashed on len(None)
            universalMarkerGenes = set()

        with open(os.path.join(outputGeneDir, 'phylo_marker_set.txt'), 'w') as fout:
            fout.write(str(universalMarkerGenes))

        print('')
        print('  Universal marker genes: ' + str(len(universalMarkerGenes)))

        return allTrustedGenomeIds, universalMarkerGenes
예제 #8
0
    def run(self, outputDir, ubiquityThreshold, singleCopyThreshold,
            minGenomes, colocatedDistThreshold, colocatedGenomeThreshold,
            threads):
        """Build marker sets for every IMG lineage (plus 'Universal') in
        parallel and hand the results to a single writer process.

        `threads` worker processes consume lineage names from a queue, and
        each worker stops on its own None sentinel. Once all workers have
        joined, a (None, None, None) sentinel is sent to the writer, which is
        then joined.
        """
        if not os.path.exists(outputDir):
            os.makedirs(outputDir)

        # lineages to process, with the pseudo-lineage 'Universal' appended
        img = IMG('/srv/whitlam/bio/db/checkm/img/img_metadata.tsv',
                  '/srv/whitlam/bio/db/checkm/pfam/tigrfam2pfam.tsv')
        metadata = img.genomeMetadata()
        lineages = img.lineagesSorted(metadata)
        lineages.append('Universal')

        # HMM model accession numbers
        pfamIdToPfamAcc = self.__pfamIdToPfamAcc(img)

        workerQueue = mp.Queue()
        writerQueue = mp.Queue()

        # every lineage first, then one stop sentinel per worker
        for item in lineages + [None] * threads:
            workerQueue.put(item)

        workerArgs = (ubiquityThreshold, singleCopyThreshold, minGenomes,
                      colocatedDistThreshold, colocatedGenomeThreshold,
                      metadata, workerQueue, writerQueue)
        workers = [mp.Process(target=self.__workerThread, args=workerArgs)
                   for _ in range(threads)]

        writerArgs = (pfamIdToPfamAcc, ubiquityThreshold, singleCopyThreshold,
                      colocatedDistThreshold, colocatedGenomeThreshold,
                      outputDir, len(lineages), writerQueue)
        writer = mp.Process(target=self.__writerThread, args=writerArgs)

        writer.start()
        for worker in workers:
            worker.start()
        for worker in workers:
            worker.join()

        writerQueue.put((None, None, None))
        writer.join()
예제 #9
0
 def __init__(self):
     """Configure file locations, IMG access and reporting levels for the
     simulation comparison.

     The 'draft' simulation outputs are active. The alternative variants use
     identical paths with the 'draft' tag replaced by 'scaffolds.draft' or by
     'random_scaffolds' (for example
     './simulations/simulation.scaffolds.draft.w_refinement_50').
     """
     # Simulation outputs; 'refinment' is spelled as in the existing file name.
     self.plotPrefix = './simulations/simulation.draft.w_refinement_50'
     self.simCompareFile = './simulations/simCompare.draft.w_refinement_50.full.tsv'
     self.simCompareMarkerSetOut = './simulations/simCompare.draft.marker_set_table.w_refinement_50.tsv'
     self.simCompareConditionOut = './simulations/simCompare.draft.condition_table.w_refinement_50.tsv'
     self.simCompareTaxonomyTableOut = './simulations/simCompare.draft.taxonomy_table.w_refinement_50.tsv'
     self.simCompareRefinementTableOut = './simulations/simCompare.draft.refinment_table.w_refinement_50.tsv'

     imgMetadataFile = '/srv/whitlam/bio/db/checkm/img/img_metadata.tsv'
     tigrfamToPfamFile = '/srv/whitlam/bio/db/checkm/pfam/tigrfam2pfam.tsv'
     self.img = IMG(imgMetadataFile, tigrfamToPfamFile)

     # presumably completeness / contamination levels to report — confirm against callers
     self.compsToConsider = [0.5, 0.7, 0.8, 0.9]
     self.contsToConsider = [0.05, 0.1, 0.15]

     # presumably the resolution used when saving figures
     self.dpi = 1200
 def __init__(self):
     """Create the marker-set builder and IMG accessor and set the
     simulation parameter grid."""
     self.markerSetBuilder = MarkerSetBuilder()

     imgMetadataFile = '/srv/whitlam/bio/db/checkm/img/img_metadata.tsv'
     tigrfamToPfamFile = '/srv/whitlam/bio/db/checkm/pfam/tigrfam2pfam.tsv'
     self.img = IMG(imgMetadataFile, tigrfamToPfamFile)

     # contig lengths (presumably bp) and completeness/contamination fractions
     self.contigLens = [5000, 20000, 50000]
     self.percentComps = [0.5, 0.7, 0.8, 0.9, 0.95, 1.0]
     self.percentConts = [0.0, 0.05, 0.1, 0.15, 0.2]
    def __getUniversalMarkerGenes(self, phyloUbiquityThreshold, phyloSingleCopyThreshold, outputGeneDir):
        """Find marker genes shared by the Archaea and Bacteria marker sets.

        Builds a marker-gene set for each of 'Archaea' and 'Bacteria' from the
        genomes IMG assigns to it (lineages without genomes are skipped),
        intersects them, and writes str() of the result to
        'phylo_marker_set.txt' in outputGeneDir.

        Returns:
            (allTrustedGenomeIds, universalMarkerGenes): union of genome IDs
            of the processed lineages, and the marker genes common to all of
            them (an empty set if no lineage had genomes).
        """
        img = IMG('/srv/whitlam/bio/db/checkm/img/img_metadata.tsv', '/srv/whitlam/bio/db/checkm/pfam/tigrfam2pfam.tsv')
        markerSetBuilder = MarkerSetBuilder()

        metadata = img.genomeMetadata()

        allTrustedGenomeIds = set()
        phyloMarkerGenes = {}
        for lineage in ['Archaea', 'Bacteria']:
            # get all genomes in lineage
            # (single-argument print() calls behave the same on Python 2 and 3)
            print('\nIdentifying all ' + lineage + ' genomes.')
            trustedGenomeIds = img.genomeIdsByTaxonomy(lineage, metadata)
            print('  Trusted genomes in lineage: ' + str(len(trustedGenomeIds)))
            if len(trustedGenomeIds) < 1:
                print('  Skipping lineage due to insufficient number of genomes.')
                continue

            allTrustedGenomeIds.update(trustedGenomeIds)

            print('  Building marker set.')
            markerGenes = markerSetBuilder.buildMarkerGenes(trustedGenomeIds, phyloUbiquityThreshold, phyloSingleCopyThreshold)
            phyloMarkerGenes[lineage] = markerGenes

        # universal marker genes: intersection over all processed lineages
        universalMarkerGenes = None
        for markerGenes in phyloMarkerGenes.values():
            if universalMarkerGenes is None:
                # copy so the builder's per-lineage set is not mutated below
                universalMarkerGenes = set(markerGenes)
            else:
                universalMarkerGenes.intersection_update(markerGenes)

        if universalMarkerGenes is None:
            # no lineage had genomes; avoid len(None) below
            universalMarkerGenes = set()

        with open(os.path.join(outputGeneDir, 'phylo_marker_set.txt'), 'w') as fout:
            fout.write(str(universalMarkerGenes))

        print('')
        print('  Universal marker genes: ' + str(len(universalMarkerGenes)))

        return allTrustedGenomeIds, universalMarkerGenes
    def run(self, outputDir, ubiquityThreshold, singleCopyThreshold, minGenomes, colocatedDistThreshold, colocatedGenomeThreshold, threads):
        """Compute marker sets for all lineages in parallel and write them.

        The pipeline has one writer process and `threads` worker processes.
        Workers read lineage names, plus one None stop sentinel each, from a
        shared queue. After all workers finish, the writer receives a
        (None, None, None) sentinel and is joined.
        """
        if not os.path.exists(outputDir):
            os.makedirs(outputDir)

        # every IMG lineage, followed by the pseudo-lineage 'Universal'
        img = IMG('/srv/whitlam/bio/db/checkm/img/img_metadata.tsv', '/srv/whitlam/bio/db/checkm/pfam/tigrfam2pfam.tsv')
        metadata = img.genomeMetadata()
        lineages = img.lineagesSorted(metadata)
        lineages.append('Universal')

        pfamIdToPfamAcc = self.__pfamIdToPfamAcc(img)

        workerQueue = mp.Queue()
        writerQueue = mp.Queue()

        for item in lineages + [None] * threads:
            workerQueue.put(item)

        workerArgs = (ubiquityThreshold, singleCopyThreshold,
                      minGenomes,
                      colocatedDistThreshold, colocatedGenomeThreshold,
                      metadata,
                      workerQueue, writerQueue)
        workers = [mp.Process(target=self.__workerThread, args=workerArgs) for _ in range(threads)]

        writerArgs = (pfamIdToPfamAcc,
                      ubiquityThreshold, singleCopyThreshold,
                      colocatedDistThreshold, colocatedGenomeThreshold,
                      outputDir, len(lineages), writerQueue)
        writer = mp.Process(target=self.__writerThread, args=writerArgs)

        writer.start()
        for worker in workers:
            worker.start()
        for worker in workers:
            worker.join()

        writerQueue.put((None, None, None))
        writer.join()
예제 #13
0
class DecorateTree(object):
    """Annotate internal nodes of a genome tree with taxonomy, genome
    statistics and lineage-specific marker sets."""

    def __init__(self):
        self.img = IMG('/srv/whitlam/bio/db/checkm/img/img_metadata.tsv', '/srv/whitlam/bio/db/checkm/pfam/tigrfam2pfam.tsv')
        # Pfam HMM database; only its ACC header lines are read
        self.pfamHMMs = '/srv/whitlam/bio/db/pfam/27/Pfam-A.hmm'
        self.markerSetBuilder = MarkerSetBuilder()

    def __meanStd(self, metadata, genomeIds, category):
        """Return (mean, std) of metadata[genomeId][category] over genomeIds.

        An 'IMG_' prefix on a genome label is removed before the lookup, and
        values equal to the string 'NA' are skipped.
        """
        values = []
        for genomeId in genomeIds:
            v = metadata[genomeId.replace('IMG_', '')][category]
            if v != 'NA':
                values.append(v)

        return mean(values), std(values)

    def __calculateMarkerSet(self, genomeLabels, ubiquityThreshold=0.97, singleCopyThreshold=0.97):
        """Calculate marker set for a set of genomes.

        Genome labels have any 'IMG_' prefix removed to obtain IMG genome IDs.
        Returns the `markerSet` attribute of the built marker set.
        """
        genomeIds = set(genomeLabel.replace('IMG_', '') for genomeLabel in genomeLabels)

        markerSet = self.markerSetBuilder.buildMarkerSet(genomeIds, ubiquityThreshold, singleCopyThreshold)

        return markerSet.markerSet

    def __pfamIdToPfamAcc(self, img):
        """Map unversioned Pfam IDs (e.g. 'PF00001') to versioned accessions.

        Only lines whose first field is exactly 'ACC' are used. A substring
        test would also match header lines such as 'NAME  ACC_central' and
        add bogus entries. The `img` argument is not used.
        """
        pfamIdToPfamAcc = {}
        with open(self.pfamHMMs) as fin:
            for line in fin:
                fields = line.split()
                if len(fields) >= 2 and fields[0] == 'ACC':
                    acc = fields[1]
                    pfamIdToPfamAcc[acc.split('.')[0]] = acc

        return pfamIdToPfamAcc

    def decorate(self, taxaTreeFile, derepFile, inputTreeFile, metadataOut, numThreads):
        """Write per-node statistics to metadataOut and relabel inputTreeFile.

        derepFile lists, per line, a retained taxon followed by the taxa that
        were removed as duplicates of it.
        """
        # read genome metadata
        print('  Reading metadata.')
        metadata = self.img.genomeMetadata()

        # read list of taxa with duplicate sequences
        print('  Read list of taxa with duplicate sequences.')
        duplicateTaxa = {}
        with open(derepFile) as fin:
            for line in fin:
                lineSplit = line.rstrip().split()
                if len(lineSplit) > 1:
                    duplicateTaxa[lineSplit[0]] = lineSplit[1:]

        # build gene count table (reuse the metadata read above rather than
        # reading it a second time)
        print('  Building gene count table.')
        print('    # trusted genomes = ' + str(len(metadata)))

        # calculate statistics for each internal node using multiple threads
        print('  Calculating statistics for each internal node.')
        self.__internalNodeStatistics(taxaTreeFile, inputTreeFile, duplicateTaxa, metadata, metadataOut, numThreads)

    def __internalNodeStatistics(self, taxaTreeFile, inputTreeFile, duplicateTaxa, metadata, metadataOut, numThreads):
        """Distribute internal nodes to numThreads worker processes and collect
        their results in a single writer process."""
        # determine HMM model accession numbers
        pfamIdToPfamAcc = self.__pfamIdToPfamAcc(self.img)

        taxaTree = dendropy.Tree.get_from_path(taxaTreeFile, schema='newick', as_rooted=True, preserve_underscores=True)
        inputTree = dendropy.Tree.get_from_path(inputTreeFile, schema='newick', as_rooted=True, preserve_underscores=True)

        workerQueue = mp.Queue()
        writerQueue = mp.Queue()

        # number internal nodes from 1 in the order internal_nodes() yields them
        for uniqueId, node in enumerate(inputTree.internal_nodes(), 1):
            workerQueue.put((uniqueId, node))

        # one stop sentinel per worker
        for _ in range(numThreads):
            workerQueue.put((None, None))

        calcProc = [mp.Process(target=self.__processInternalNode, args=(taxaTree, duplicateTaxa, workerQueue, writerQueue)) for _ in range(numThreads)]
        writeProc = mp.Process(target=self.__reportStatistics, args=(metadata, metadataOut, inputTree, inputTreeFile, pfamIdToPfamAcc, writerQueue))

        writeProc.start()

        for p in calcProc:
            p.start()

        for p in calcProc:
            p.join()

        writerQueue.put((None, None, None, None, None, None, None))
        writeProc.join()

    def __processInternalNode(self, taxaTree, duplicateTaxa, queueIn, queueOut):
        """Worker: label each queued internal node and compute its marker set.

        Reads (uniqueId, node) items until uniqueId is None and emits
        (uniqueId, labels, markerSet, taxaStr, bootstrap, node.oid, nodeLabel).
        """
        while True:
            uniqueId, node = queueIn.get(block=True, timeout=None)
            if uniqueId is None:
                break

            # leaf labels of the clade, including taxa removed as duplicates
            labels = []
            for leaf in node.leaf_nodes():
                labels.append(leaf.taxon.label)
                labels.extend(duplicateTaxa.get(leaf.taxon.label, []))

            # check if there is a taxonomic label
            mrca = taxaTree.mrca(taxon_labels=labels)
            taxaStr = mrca.label.replace(' ', '') if mrca.label else ''

            # give node a unique Id while retaining bootstrap value
            bootstrap = node.label if node.label else ''
            nodeLabel = 'UID' + str(uniqueId) + '|' + taxaStr + '|' + bootstrap

            # calculate marker set
            markerSet = self.__calculateMarkerSet(labels)

            queueOut.put((uniqueId, labels, markerSet, taxaStr, bootstrap, node.oid, nodeLabel))

    def __reportStatistics(self, metadata, metadataOut, inputTree, inputTreeFile, pfamIdToPfamAcc, writerQueue):
        """Store statistics for internal node.

        Writes one TSV row per node received from writerQueue (until a
        uniqueId of None arrives), with GC % values scaled by 100. It sets the
        label of the matching inputTree node (found by oid) and finally writes
        the relabelled tree back over inputTreeFile.
        """
        with open(metadataOut, 'w') as fout:
            fout.write('UID\t# genomes\tTaxonomy\tBootstrap')
            fout.write('\tGC mean\tGC std')
            fout.write('\tGenome size mean\tGenome size std')
            fout.write('\tGene count mean\tGene count std')
            fout.write('\tMarker set')
            fout.write('\n')

            numProcessedNodes = 0
            numInternalNodes = len(inputTree.internal_nodes())
            while True:
                uniqueId, labels, markerSet, taxaStr, bootstrap, nodeID, nodeLabel = writerQueue.get(block=True, timeout=None)
                if uniqueId is None:
                    break

                numProcessedNodes += 1
                statusStr = '    Finished processing %d of %d (%.2f%%) internal nodes.' % (numProcessedNodes, numInternalNodes, float(numProcessedNodes) * 100 / numInternalNodes)
                sys.stdout.write('%s\r' % statusStr)
                sys.stdout.flush()

                fout.write('UID' + str(uniqueId) + '\t' + str(len(labels)) + '\t' + taxaStr + '\t' + bootstrap)

                m, s = self.__meanStd(metadata, labels, 'GC %')
                fout.write('\t' + str(m * 100) + '\t' + str(s * 100))

                m, s = self.__meanStd(metadata, labels, 'genome size')
                fout.write('\t' + str(m) + '\t' + str(s))

                m, s = self.__meanStd(metadata, labels, 'gene count')
                fout.write('\t' + str(m) + '\t' + str(s))

                # change model names to accession numbers, and make
                # sure there is an HMM model for each PFAM
                mungedMarkerSets = []
                for geneSet in markerSet:
                    accessions = set()
                    for geneId in geneSet:
                        if 'pfam' in geneId:
                            pfamId = geneId.replace('pfam', 'PF')
                            if pfamId in pfamIdToPfamAcc:
                                accessions.add(pfamIdToPfamAcc[pfamId])
                        else:
                            accessions.add(geneId)
                    mungedMarkerSets.append(accessions)

                fout.write('\t' + str(mungedMarkerSets))

                fout.write('\n')

                node = inputTree.find_node(filter_fn=lambda n: hasattr(n, 'oid') and n.oid == nodeID)
                node.label = nodeLabel

            sys.stdout.write('\n')

        inputTree.write_to_path(inputTreeFile, schema='newick', suppress_rooting=True, unquoted_underscores=True)
예제 #14
0
 def __init__(self):
     """Open the IMG genome metadata and TIGRFAM-to-Pfam mapping files."""
     imgMetadataFile = '/srv/whitlam/bio/db/checkm/img/img_metadata.tsv'
     tigrfamToPfamFile = '/srv/whitlam/bio/db/checkm/pfam/tigrfam2pfam.tsv'
     self.img = IMG(imgMetadataFile, tigrfamToPfamFile)
예제 #15
0
class PlotScaffoldLenVsMarkers(object):
    """Relate scaffold length to the density of domain-level marker genes
    on the scaffolds of archaeal draft genomes."""

    def __init__(self):
        self.img = IMG('/srv/whitlam/bio/db/checkm/img/img_metadata.tsv', '/srv/whitlam/bio/db/checkm/pfam/tigrfam2pfam.tsv')

    @staticmethod
    def _imgFamilyId(markerId):
        """Convert a marker-set Pfam accession to the family ID used by IMG.

        'PF00001.16' -> 'pfam00001'. An accession without a version suffix is
        converted without truncation ('PF00001' -> 'pfam00001'). Other IDs
        (e.g. TIGRFAMs) are returned unchanged.
        """
        if not markerId.startswith('PF'):
            return markerId
        familyId = 'pfam' + markerId[2:]
        dot = familyId.rfind('.')
        return familyId[:dot] if dot != -1 else familyId

    @staticmethod
    def _perBase(count, length):
        """Return count / length as a float, or 0.0 when length is 0."""
        return float(count) / length if length else 0.0

    def run(self, minScaffolds=20, shortScaffoldLen=10000,
            plotFile='./experiments/plotScaffoldLenVsMarkers.png'):
        """Report marker density on short vs. all scaffolds and plot it.

        Args:
            minScaffolds: only draft genomes with at least this many scaffolds
                are considered.
            shortScaffoldLen: scaffolds shorter than this (bp) count as short.
            plotFile: path of the scatter plot to save.
        """
        # get all draft genomes consisting of at least minScaffolds scaffolds
        print('')
        metadata = self.img.genomeMetadata()
        print('  Total genomes: %d' % len(metadata))

        arGenome = set(genomeId for genomeId in metadata
                       if metadata[genomeId]['taxonomy'][0] == 'Archaea')

        draftGenomeIds = arGenome - self.img.filterGenomeIds(arGenome, metadata, 'status', 'Finished')
        print('  Number of draft genomes: %d' % len(draftGenomeIds))

        genomeIdsToTest = set(genomeId for genomeId in draftGenomeIds
                              if metadata[genomeId]['scaffold count'] >= minScaffolds)
        print('  Number of draft genomes with >= %d scaffolds: %d' % (minScaffolds, len(genomeIdsToTest)))

        print('')
        print('  Calculating genome information for calculating marker sets:')
        genomeFamilyScaffolds = self.img.precomputeGenomeFamilyScaffolds(genomeIdsToTest)

        print('  Calculating genome sequence lengths.')
        genomeSeqLens = self.img.precomputeGenomeSeqLens(genomeIdsToTest)

        print('  Determining domain-specific marker sets.')
        taxonParser = TaxonParser()
        taxonMarkerSets = taxonParser.readMarkerSets()
        bacMarkers = taxonMarkerSets['domain']['Bacteria'].getMarkerGenes()
        arMarkers = taxonMarkerSets['domain']['Archaea'].getMarkerGenes()
        print('    There are %d bacterial markers and %d archaeal markers.' % (len(bacMarkers), len(arMarkers)))

        # convert marker accessions to IMG family IDs once per domain
        bacFamilyIds = [self._imgFamilyId(markerId) for markerId in bacMarkers]
        arFamilyIds = [self._imgFamilyId(markerId) for markerId in arMarkers]

        print('  Determining percentage of markers on each scaffold.')
        totalMarkers = 0
        totalSequenceLen = 0
        markersOnShortScaffolds = 0
        totalShortScaffoldLen = 0

        scaffoldLen = {}
        # fraction of the domain's marker set found on each scaffold
        percentageMarkers = defaultdict(float)
        for genomeId, markerIds in genomeFamilyScaffolds.items():
            domain = metadata[genomeId]['taxonomy'][0]
            familyIds = bacFamilyIds if domain == 'Bacteria' else arFamilyIds
            for familyId in familyIds:
                if familyId not in markerIds:
                    continue
                for scaffoldId in markerIds[familyId]:
                    seqLen = genomeSeqLens[genomeId][scaffoldId]
                    scaffoldLen[scaffoldId] = seqLen
                    percentageMarkers[scaffoldId] += 1.0 / len(familyIds)

                    totalMarkers += 1
                    totalSequenceLen += seqLen

                    if seqLen < shortScaffoldLen:
                        markersOnShortScaffolds += 1
                        totalShortScaffoldLen += seqLen

        # lengths are summed in bases, so report them as bp
        print('Markers on short scaffolds: %d over %d bp (%f markers per base)' % (markersOnShortScaffolds, totalShortScaffoldLen, self._perBase(markersOnShortScaffolds, totalShortScaffoldLen)))
        print('Total markers on scaffolds: %d over %d bp (%f markers per base)' % (totalMarkers, totalSequenceLen, self._perBase(totalMarkers, totalSequenceLen)))

        print('  Create plot.')
        plotLens = []
        plotPerMarkers = []
        for scaffoldId, fraction in percentageMarkers.items():
            plotLens.append(scaffoldLen[scaffoldId])
            # fraction of marker set per Mbp of scaffold
            plotPerMarkers.append(fraction / scaffoldLen[scaffoldId] * 1e6)

        scatterPlot = ScatterPlot()
        scatterPlot.plot(plotLens, plotPerMarkers)
        scatterPlot.savePlot(plotFile)
예제 #16
0
 def __init__(self):
     """Load IMG genome metadata into self.metadata.

     The IMG accessor is used only for this load and is not kept.
     """
     imgMetadataFile = "/srv/whitlam/bio/db/checkm/img/img_metadata.tsv"
     tigrfamToPfamFile = "/srv/whitlam/bio/db/checkm/pfam/tigrfam2pfam.tsv"
     self.metadata = IMG(imgMetadataFile, tigrfamToPfamFile).genomeMetadata()
예제 #17
0
class SimComparePlots(object):
    """Plot and tabulate the accuracy of CheckM estimates on simulated genomes.

    ``__readResults`` reads results from ``self.simCompareFile``. Each summary
    method then reports the error ("delta") in estimated completeness and
    contamination, either as box plots (via ``BoxPlot``) or as tab-separated
    tables. All input and output paths are fixed in ``__init__``.
    """

    def __init__(self):
        # Active experiment: draft genomes with refinement. The commented-out
        # groups below describe alternative experiments. Activate exactly one
        # group to regenerate the plots/tables for it.
        self.plotPrefix = './simulations/simulation.draft.w_refinement_50'
        self.simCompareFile = './simulations/simCompare.draft.w_refinement_50.full.tsv'
        self.simCompareMarkerSetOut = './simulations/simCompare.draft.marker_set_table.w_refinement_50.tsv'
        self.simCompareConditionOut = './simulations/simCompare.draft.condition_table.w_refinement_50.tsv'
        self.simCompareTaxonomyTableOut = './simulations/simCompare.draft.taxonomy_table.w_refinement_50.tsv'
        self.simCompareRefinementTableOut = './simulations/simCompare.draft.refinment_table.w_refinement_50.tsv'

        #self.plotPrefix = './simulations/simulation.scaffolds.draft.w_refinement_50'
        #self.simCompareFile = './simulations/simCompare.scaffolds.draft.w_refinement_50.full.tsv'
        #self.simCompareMarkerSetOut = './simulations/simCompare.scaffolds.draft.marker_set_table.w_refinement_50.tsv'
        #self.simCompareConditionOut = './simulations/simCompare.scaffolds.draft.condition_table.w_refinement_50.tsv'
        #self.simCompareTaxonomyTableOut = './simulations/simCompare.scaffolds.draft.taxonomy_table.w_refinement_50.tsv'
        #self.simCompareRefinementTableOut = './simulations/simCompare.scaffolds.draft.refinment_table.w_refinement_50.tsv'

        #self.plotPrefix = './simulations/simulation.random_scaffolds.w_refinement_50'
        #self.simCompareFile = './simulations/simCompare.random_scaffolds.w_refinement_50.full.tsv'
        #self.simCompareMarkerSetOut = './simulations/simCompare.random_scaffolds.marker_set_table.w_refinement_50.tsv'
        #self.simCompareConditionOut = './simulations/simCompare.random_scaffolds.condition_table.w_refinement_50.tsv'
        #self.simCompareTaxonomyTableOut = './simulations/simCompare.random_scaffolds.taxonomy_table.w_refinement_50.tsv'
        #self.simCompareRefinementTableOut = './simulations/simCompare.random_scaffolds.refinment_table.w_refinement_50.tsv'

        self.img = IMG('/srv/whitlam/bio/db/checkm/img/img_metadata.tsv', '/srv/whitlam/bio/db/checkm/pfam/tigrfam2pfam.tsv')

        # Simulated completeness and contamination levels shown in the box plots.
        self.compsToConsider = [0.5, 0.7, 0.8, 0.9]
        self.contsToConsider = [0.05, 0.1, 0.15]

        # Resolution (dots per inch) of saved plots.
        self.dpi = 1200

    @staticmethod
    def _meanAbs(values):
        """Return the mean absolute value of a flat sequence of numbers."""
        return mean(abs(array(values)))

    @staticmethod
    def _stdAbs(values):
        """Return the standard deviation of the absolute values of a flat sequence."""
        return std(abs(array(values)))

    @staticmethod
    def _pooledMeanAbs(groups):
        """Return the mean absolute value over every number in ``groups``.

        ``groups`` is a sequence of number sequences that may differ in
        length. They are concatenated first because building a numpy array
        from ragged nested lists does not give a numeric 2-D array.
        """
        return SimComparePlots._meanAbs([x for group in groups for x in group])

    @staticmethod
    def _percentChange(newValues, baseValues):
        """Return the % change in mean absolute delta of ``newValues`` relative to ``baseValues``."""
        baseMean = SimComparePlots._meanAbs(baseValues)
        return (SimComparePlots._meanAbs(newValues) - baseMean) * 100 / baseMean

    def __readResults(self, filename):
        """Parse the tab-separated simulation comparison file.

        The first line is a header and is skipped. Column 0 holds the
        simulation id; its first '-'-separated field is the genome id.
        Columns 6-19 each hold a comma-separated list of floats. They are
        stored in column order at indices 0-13 of each result list:

          0/1    best comp/cont deltas (IM)      2/3    best comp/cont (MS)
          4/5    domain comp/cont (IM)           6/7    domain comp/cont (MS)
          8/9    sim comp/cont (IM)              10/11  sim comp/cont (MS)
          12/13  sim comp/cont (RMS)

        Args:
            filename: path of the comparison TSV file.

        Returns:
            Mapping of simulation id -> list of 14 lists of floats.
        """
        results = defaultdict(dict)
        genomeIds = set()
        with open(filename) as f:
            f.readline()  # skip header
            for line in f:
                lineSplit = line.split('\t')

                simId = lineSplit[0]
                genomeIds.add(simId.split('-')[0])

                results[simId] = [[float(x) for x in lineSplit[col].split(',')]
                                  for col in range(6, 20)]

        print('    Number of test genomes: ' + str(len(genomeIds)))

        return results

    def markerSets(self, results):
        """Compare domain-level marker genes ('IM') with marker sets ('MS').

        Pools the domain-level IM (result indices 4/5) and MS (6/7)
        completeness/contamination deltas for each simulation condition. It
        then saves a box plot of the 20 kb conditions in ``compsToConsider``
        x ``contsToConsider``. Finally it writes the mean +/- std of the
        absolute deltas for every tabulated condition to
        ``self.simCompareMarkerSetOut``.

        Args:
            results: mapping returned by ``__readResults``.
        """
        # summarize results from IM vs MS
        print('  Tabulating results for domain-level marker genes vs marker sets.')

        itemsProcessed = 0
        compDataDict = defaultdict(lambda : defaultdict(list))
        contDataDict = defaultdict(lambda : defaultdict(list))

        genomeIds = set()
        for simId in results:
            itemsProcessed += 1
            statusStr = '    Finished processing %d of %d (%.2f%%) test cases.' % (itemsProcessed, len(results), float(itemsProcessed)*100/len(results))
            sys.stdout.write('%s\r' % statusStr)
            sys.stdout.flush()

            # simulation ids have the form genomeId-seqLen-comp-cont
            genomeId, seqLen, comp, cont = simId.split('-')
            genomeIds.add(genomeId)
            expCondStr = str(float(comp)) + '-' + str(float(cont)) + '-' + str(int(seqLen))

            compDataDict[expCondStr]['IM'] += results[simId][4]
            compDataDict[expCondStr]['MS'] += results[simId][6]

            contDataDict[expCondStr]['IM'] += results[simId][5]
            contDataDict[expCondStr]['MS'] += results[simId][7]

        print('  There are %d unique genomes.' % len(genomeIds))

        sys.stdout.write('\n')

        print('    There are %d experimental conditions.' % (len(compDataDict)))

        # plot data
        print('  Plotting results.')
        compData = []
        contData = []
        rowLabels = []

        for comp in self.compsToConsider:
            for cont in self.contsToConsider:
                for seqLen in [20000]:
                    for msStr in ['MS', 'IM']:
                        rowLabels.append(msStr +': %d%%, %d%%' % (comp*100, cont*100))

                        expCondStr = str(comp) + '-' + str(cont) + '-' + str(seqLen)
                        compData.append(compDataDict[expCondStr][msStr])
                        contData.append(contDataDict[expCondStr][msStr])

        # Rows alternate MS, IM. Pool every other row; the rows can have
        # different lengths, so they are flattened rather than stacked.
        print('MS:\t%.2f\t%.2f' % (self._pooledMeanAbs(compData[0::2]), self._pooledMeanAbs(contData[0::2])))
        print('IM:\t%.2f\t%.2f' % (self._pooledMeanAbs(compData[1::2]), self._pooledMeanAbs(contData[1::2])))

        boxPlot = BoxPlot()
        plotFilename = self.plotPrefix + '.markerSets.png'
        boxPlot.plot(plotFilename, compData, contData, rowLabels,
                        r'$\Delta$' + ' % Completion', 'Simulation Conditions',
                        r'$\Delta$' + ' % Contamination', None,
                        rowsPerCategory = 2, dpi = self.dpi)

        # print table of results
        # NOTE(review): the header has one IM and one MS column pair per
        # sequence length, but each row writes comp IM, comp MS, cont IM,
        # cont MS. Confirm which column order is intended.
        with open(self.simCompareMarkerSetOut, 'w') as tableOut:
            tableOut.write('Comp. (%)\tCont. (%)\tIM (5kb)\t\tMS (5kb)\t\tIM (20kb)\t\tMS (20kb)\t\tIM (50kb)\t\tMS (50kb)\n')

            # deltas pooled over all tabulated comp/cont levels, per sequence length
            avgComp = defaultdict(lambda : defaultdict(list))
            avgCont = defaultdict(lambda : defaultdict(list))
            for comp in [0.5, 0.7, 0.8, 0.9, 0.95, 1.0]:
                for cont in [0.0, 0.05, 0.1, 0.15, 0.2]:
                    tableOut.write('%d\t%d' % (comp*100, cont*100))

                    for seqLen in [5000, 20000, 50000]:
                        expCondStr = str(comp) + '-' + str(cont) + '-' + str(seqLen)

                        compIM = compDataDict[expCondStr]['IM']
                        contIM = contDataDict[expCondStr]['IM']
                        compMS = compDataDict[expCondStr]['MS']
                        contMS = contDataDict[expCondStr]['MS']

                        avgComp[seqLen]['IM'] += compIM
                        avgCont[seqLen]['IM'] += contIM
                        avgComp[seqLen]['MS'] += compMS
                        avgCont[seqLen]['MS'] += contMS

                        tableOut.write('\t%.1f+/-%.2f\t%.1f+/-%.2f\t%.1f+/-%.2f\t%.1f+/-%.2f' % (
                            self._meanAbs(compIM), self._stdAbs(compIM),
                            self._meanAbs(compMS), self._stdAbs(compMS),
                            self._meanAbs(contIM), self._stdAbs(contIM),
                            self._meanAbs(contMS), self._stdAbs(contMS)))
                    tableOut.write('\n')

            tableOut.write('\tAverage:')
            for seqLen in [5000, 20000, 50000]:
                compIM = avgComp[seqLen]['IM']
                contIM = avgCont[seqLen]['IM']
                compMS = avgComp[seqLen]['MS']
                contMS = avgCont[seqLen]['MS']

                tableOut.write('\t%.1f+/-%.2f\t%.1f+/-%.2f\t%.1f+/-%.2f\t%.1f+/-%.2f' % (
                    self._meanAbs(compIM), self._stdAbs(compIM),
                    self._meanAbs(compMS), self._stdAbs(compMS),
                    self._meanAbs(contIM), self._stdAbs(contIM),
                    self._meanAbs(contMS), self._stdAbs(contMS)))

            tableOut.write('\n')

    def conditionsPlot(self, results):
        """Compare best, selected and domain-level marker sets per condition.

        Pools the 'best' (result indices 2/3), 'domain' (6/7) and 'selected'
        (10/11) completeness/contamination deltas for each simulation
        condition. For the 20 kb conditions in ``compsToConsider`` x
        ``contsToConsider`` it writes the genomes whose 'best' delta lies
        outside the 1st-99th percentile to two outlier files, and it saves a
        box plot. It then writes mean absolute deltas for every tabulated
        condition to ``self.simCompareConditionOut``.

        Args:
            results: mapping returned by ``__readResults``.
        """
        # summarize results for each experimental condition
        print('  Tabulating results for each experimental condition using marker sets.')

        itemsProcessed = 0
        compDataDict = defaultdict(lambda : defaultdict(list))
        contDataDict = defaultdict(lambda : defaultdict(list))

        # [delta, genomeId] pairs for the 'best' marker set, keyed by condition
        compOutliers = defaultdict(list)
        contOutliers = defaultdict(list)

        genomeIds = set()
        for simId in results:
            itemsProcessed += 1
            statusStr = '    Finished processing %d of %d (%.2f%%) test cases.' % (itemsProcessed, len(results), float(itemsProcessed)*100/len(results))
            sys.stdout.write('%s\r' % statusStr)
            sys.stdout.flush()

            genomeId, seqLen, comp, cont = simId.split('-')
            genomeIds.add(genomeId)
            expCondStr = str(float(comp)) + '-' + str(float(cont)) + '-' + str(int(seqLen))

            compDataDict[expCondStr]['best'] += results[simId][2]
            compDataDict[expCondStr]['domain'] += results[simId][6]
            compDataDict[expCondStr]['selected'] += results[simId][10]

            for dComp in results[simId][2]:
                compOutliers[expCondStr].append([dComp, genomeId])

            contDataDict[expCondStr]['best'] += results[simId][3]
            contDataDict[expCondStr]['domain'] += results[simId][7]
            contDataDict[expCondStr]['selected'] += results[simId][11]

            for dCont in results[simId][3]:
                contOutliers[expCondStr].append([dCont, genomeId])

        print('  There are %d unique genomes.' % len(genomeIds))

        sys.stdout.write('\n')

        print('    There are %d experimental conditions.' % (len(compDataDict)))

        # plot data
        print('  Plotting results.')
        compData = []
        contData = []
        rowLabels = []

        # NOTE(review): the outlier paths are hard-coded rather than derived
        # from self.plotPrefix, so every experiment configuration writes the
        # same files. They are also named '.domain' although they hold the
        # 'best' marker set deltas.
        with open('./simulations/simulation.scaffolds.draft.comp_outliers.domain.tsv', 'w') as foutComp, \
                open('./simulations/simulation.scaffolds.draft.cont_outliers.domain.tsv', 'w') as foutCont:
            for comp in self.compsToConsider:
                for cont in self.contsToConsider:
                    for msStr in ['best', 'selected', 'domain']:
                        for seqLen in [20000]:
                            rowLabels.append(msStr +': %d%%, %d%%' % (comp*100, cont*100))

                            expCondStr = str(comp) + '-' + str(cont) + '-' + str(seqLen)
                            compData.append(compDataDict[expCondStr][msStr])
                            contData.append(contDataDict[expCondStr][msStr])

                    # Outliers are reported for the condition built last above
                    # (seqLen 20000). Genomes whose delta lies outside the
                    # 1st-99th percentile are listed, most frequent first.

                    # report completeness outliers
                    foutComp.write(expCondStr)

                    compOutliers[expCondStr].sort()

                    dComps = array([r[0] for r in compOutliers[expCondStr]])
                    perc1 = scoreatpercentile(dComps, 1)
                    perc99 = scoreatpercentile(dComps, 99)
                    print(expCondStr, perc1, perc99)

                    foutComp.write('\t%.2f\t%.2f' % (perc1, perc99))

                    outliers = [gid for delta, gid in compOutliers[expCondStr] if delta < perc1 or delta > perc99]
                    for genomeId, count in Counter(outliers).most_common():
                        foutComp.write('\t' + genomeId + ': ' + str(count))
                    foutComp.write('\n')

                    # report contamination outliers
                    foutCont.write(expCondStr)

                    contOutliers[expCondStr].sort()

                    dConts = array([r[0] for r in contOutliers[expCondStr]])
                    perc1 = scoreatpercentile(dConts, 1)
                    perc99 = scoreatpercentile(dConts, 99)

                    foutCont.write('\t%.2f\t%.2f' % (perc1, perc99))

                    outliers = [gid for delta, gid in contOutliers[expCondStr] if delta < perc1 or delta > perc99]
                    for genomeId, count in Counter(outliers).most_common():
                        foutCont.write('\t' + genomeId + ': ' + str(count))
                    foutCont.write('\n')

        # Rows cycle best, selected, domain. Pool every third row (rows may
        # differ in length, so they are flattened rather than stacked).
        print('best:\t%.2f\t%.2f' % (self._pooledMeanAbs(compData[0::3]), self._pooledMeanAbs(contData[0::3])))
        print('selected:\t%.2f\t%.2f' % (self._pooledMeanAbs(compData[1::3]), self._pooledMeanAbs(contData[1::3])))
        print('domain:\t%.2f\t%.2f' % (self._pooledMeanAbs(compData[2::3]), self._pooledMeanAbs(contData[2::3])))

        boxPlot = BoxPlot()
        plotFilename = self.plotPrefix + '.conditions.png'
        boxPlot.plot(plotFilename, compData, contData, rowLabels,
                        r'$\Delta$' + ' % Completion', 'Simulation Conditions',
                        r'$\Delta$' + ' % Contamination', None,
                        rowsPerCategory = 3, dpi = self.dpi)

        # print table of results
        # NOTE(review): the header lists best, selected, domain (two columns
        # each) per sequence length, but each row writes comp domain, comp
        # selected, comp best, cont domain, cont selected, cont best. Confirm
        # the intended column order.
        with open(self.simCompareConditionOut, 'w') as tableOut:
            tableOut.write('Comp. (%)\tCont. (%)\tbest (5kb)\t\tselected (5kb)\t\tdomain (5kb)\t\tbest (20kb)\t\tselected (20kb)\t\tdomain (20kb)\t\tbest (50kb)\t\tselected (50kb)\t\tdomain (50kb)\n')

            # deltas pooled over all tabulated comp/cont levels, per sequence length
            avgComp = defaultdict(lambda : defaultdict(list))
            avgCont = defaultdict(lambda : defaultdict(list))
            for comp in [0.5, 0.7, 0.8, 0.9, 0.95, 1.0]:
                for cont in [0.0, 0.05, 0.1, 0.15, 0.2]:
                    tableOut.write('%d\t%d' % (comp*100, cont*100))

                    for seqLen in [5000, 20000, 50000]:
                        expCondStr = str(comp) + '-' + str(cont) + '-' + str(seqLen)

                        for msStr in ['domain', 'selected', 'best']:
                            avgComp[seqLen][msStr] += compDataDict[expCondStr][msStr]
                            avgCont[seqLen][msStr] += contDataDict[expCondStr][msStr]

                        tableOut.write('\t%.1f\t%.1f\t%.1f\t%.1f\t%.1f\t%.1f' % (
                            self._meanAbs(compDataDict[expCondStr]['domain']),
                            self._meanAbs(compDataDict[expCondStr]['selected']),
                            self._meanAbs(compDataDict[expCondStr]['best']),
                            self._meanAbs(contDataDict[expCondStr]['domain']),
                            self._meanAbs(contDataDict[expCondStr]['selected']),
                            self._meanAbs(contDataDict[expCondStr]['best'])))
                    tableOut.write('\n')

            tableOut.write('\tAverage:')
            for seqLen in [5000, 20000, 50000]:
                tableOut.write('\t%.1f\t%.1f\t%.1f\t%.1f\t%.1f\t%.1f' % (
                    self._meanAbs(avgComp[seqLen]['domain']),
                    self._meanAbs(avgComp[seqLen]['selected']),
                    self._meanAbs(avgComp[seqLen]['best']),
                    self._meanAbs(avgCont[seqLen]['domain']),
                    self._meanAbs(avgCont[seqLen]['selected']),
                    self._meanAbs(avgCont[seqLen]['best'])))

            tableOut.write('\n')

    def taxonomicPlots(self, results):
        """Summarise estimate accuracy for taxonomic groups.

        Only 20 kb simulations at 50/70/80/90% completeness and 5/10/15%
        contamination are used. Deltas for the 'best' (result indices 2/3),
        'domain' (6/7) and 'selected' (10/11) marker sets are pooled per
        taxon for the first three taxonomic ranks. Taxa with at least two
        genomes are written to ``self.simCompareTaxonomyTableOut``. Taxa with
        at least ten genomes are box-plotted.

        Args:
            results: mapping returned by ``__readResults``.
        """
        # summarize results for different taxonomic groups
        print('  Tabulating results for taxonomic groups.')

        metadata = self.img.genomeMetadata()

        itemsProcessed = 0
        compDataDict = defaultdict(lambda : defaultdict(list))
        contDataDict = defaultdict(lambda : defaultdict(list))
        comps = set()
        conts = set()

        ranksToProcess = 3
        taxaByRank = [set() for _ in range(0, ranksToProcess)]

        overallComp = []
        overallCont = []

        genomeInTaxon = defaultdict(set)
        testCases = 0
        for simId in results:
            itemsProcessed += 1
            statusStr = '    Finished processing %d of %d (%.2f%%) test cases.' % (itemsProcessed, len(results), float(itemsProcessed)*100/len(results))
            sys.stdout.write('%s\r' % statusStr)
            sys.stdout.flush()

            genomeId, seqLen, comp, cont = simId.split('-')

            if seqLen != '20000':
                continue

            if str(float(comp)) not in ['0.5', '0.7', '0.8', '0.9'] or str(float(cont)) not in ['0.05', '0.1', '0.15']:
                continue

            taxonomy = metadata[genomeId]['taxonomy']

            testCases += 1

            comps.add(float(comp))
            conts.add(float(cont))

            overallComp += results[simId][10]
            overallCont += results[simId][11]

            for r in range(0, ranksToProcess):
                taxon = taxonomy[r]

                if r == 0 and taxon == 'unclassified':
                    print('*****************************Unclassified at domain-level*****************')
                    continue

                if taxon == 'unclassified':
                    continue

                # rankPrefixes is defined outside this class (not visible here)
                taxon = rankPrefixes[r] + taxon

                taxaByRank[r].add(taxon)

                compDataDict[taxon]['best'] += results[simId][2]
                compDataDict[taxon]['domain'] += results[simId][6]
                compDataDict[taxon]['selected'] += results[simId][10]

                contDataDict[taxon]['best'] += results[simId][3]
                contDataDict[taxon]['domain'] += results[simId][7]
                contDataDict[taxon]['selected'] += results[simId][11]

                genomeInTaxon[taxon].add(genomeId)

        sys.stdout.write('\n')

        print('Test cases', testCases)

        print('')
        print('Creating plots for:')
        print('  comps = ', comps)
        print('  conts = ', conts)

        print('')
        print('    There are %d taxa.' % (len(compDataDict)))

        print('')
        print('  Overall bias:')
        print('    Selected comp: %.2f' % mean(overallComp))
        print('    Selected cont: %.2f' % mean(overallCont))

        # get list of ordered taxa by rank
        orderedTaxa = []
        for taxa in taxaByRank:
            orderedTaxa += sorted(taxa)

        # plot data
        print('  Plotting results.')
        compData = []
        contData = []
        rowLabels = []
        for taxon in orderedTaxa:
            numGenomes = len(genomeInTaxon[taxon])
            if numGenomes < 10: # skip groups with only a few genomes
                continue

            for msStr in ['best', 'selected', 'domain']:
                rowLabels.append(msStr + ': ' + taxon + ' (' + str(numGenomes) + ')')
                compData.append(compDataDict[taxon][msStr])
                contData.append(contDataDict[taxon][msStr])

        for i, rowLabel in enumerate(rowLabels):
            print(rowLabel + '\t%.2f\t%.2f' % (self._meanAbs(compData[i]), self._meanAbs(contData[i])))

        # print taxonomic table of results organized by class
        with open(self.simCompareTaxonomyTableOut, 'w') as taxonomyTableOut:
            for taxon in orderedTaxa:
                numGenomes = len(genomeInTaxon[taxon])
                if numGenomes < 2: # skip groups with only a few genomes
                    continue

                taxonomyTableOut.write(taxon + '\t' + str(numGenomes))
                for msStr in ['domain', 'selected']:
                    compDeltas = compDataDict[taxon][msStr]
                    contDeltas = contDataDict[taxon][msStr]
                    taxonomyTableOut.write('\t%.1f +/- %.2f\t%.1f +/- %.2f' % (
                        self._meanAbs(compDeltas), self._stdAbs(compDeltas),
                        self._meanAbs(contDeltas), self._stdAbs(contDeltas)))
                taxonomyTableOut.write('\n')

        # create box plot
        boxPlot = BoxPlot()
        plotFilename = self.plotPrefix +  '.taxonomy.png'
        boxPlot.plot(plotFilename, compData, contData, rowLabels,
                        r'$\Delta$' + ' % Completion', None,
                        r'$\Delta$' + ' % Contamination', None,
                        rowsPerCategory = 3, dpi = self.dpi)

    def refinementPlots(self, results):
        """Compare the IM, MS and RMS estimates per taxonomic group.

        Only simulations with completeness >= 70% and contamination <= 10%
        are used. The IM (result indices 8/9), MS (10/11) and RMS (12/13)
        deltas are pooled per taxon for the first three ranks. The method
        prints overall percentage changes (MS relative to IM, RMS relative to
        MS). Taxa with at least two genomes are written to
        ``self.simCompareRefinementTableOut``, and taxa with at least ten
        genomes are box-plotted.

        Args:
            results: mapping returned by ``__readResults``.
        """
        # summarize results for different CheckM refinements
        print('  Tabulating results for different refinements.')

        metadata = self.img.genomeMetadata()

        itemsProcessed = 0
        compDataDict = defaultdict(lambda : defaultdict(list))
        contDataDict = defaultdict(lambda : defaultdict(list))
        comps = set()
        conts = set()

        ranksToProcess = 3
        taxaByRank = [set() for _ in range(0, ranksToProcess)]

        # flat lists of every delta, pooled over all retained simulations
        overallCompIM = []
        overallContIM = []

        overallCompMS = []
        overallContMS = []

        overallCompRMS = []
        overallContRMS = []

        genomeInTaxon = defaultdict(set)

        for simId in results:
            itemsProcessed += 1
            statusStr = '    Finished processing %d of %d (%.2f%%) test cases.' % (itemsProcessed, len(results), float(itemsProcessed)*100/len(results))
            sys.stdout.write('%s\r' % statusStr)
            sys.stdout.flush()

            genomeId, seqLen, comp, cont = simId.split('-')
            taxonomy = metadata[genomeId]['taxonomy']

            if float(comp) < 0.7 or float(cont) > 0.1:
                continue

            comps.add(float(comp))
            conts.add(float(cont))

            # Extend rather than append so the pooled lists stay flat even
            # when simulations report different numbers of values.
            overallCompIM += results[simId][8]
            overallContIM += results[simId][9]

            overallCompMS += results[simId][10]
            overallContMS += results[simId][11]

            overallCompRMS += results[simId][12]
            overallContRMS += results[simId][13]

            for r in range(0, ranksToProcess):
                taxon = taxonomy[r]

                if taxon == 'unclassified':
                    continue

                taxaByRank[r].add(taxon)

                compDataDict[taxon]['IM'] += results[simId][8]
                compDataDict[taxon]['MS'] += results[simId][10]
                compDataDict[taxon]['RMS'] += results[simId][12]

                contDataDict[taxon]['IM'] += results[simId][9]
                contDataDict[taxon]['MS'] += results[simId][11]
                contDataDict[taxon]['RMS'] += results[simId][13]

                genomeInTaxon[taxon].add(genomeId)

        sys.stdout.write('\n')

        print('Creating plots for:')
        print('  comps = ', comps)
        print('  conts = ', conts)

        print('')
        print('    There are %d taxon.' % (len(compDataDict)))
        print('')
        # each change is expressed relative to the baseline it is compared against
        print('Percentage change MS-IM comp: %.4f' % self._percentChange(overallCompMS, overallCompIM))
        print('Percentage change MS-IM cont: %.4f' % self._percentChange(overallContMS, overallContIM))
        print('')
        print('Percentage change RMS-MS comp: %.4f' % self._percentChange(overallCompRMS, overallCompMS))
        print('Percentage change RMS-MS cont: %.4f' % self._percentChange(overallContRMS, overallContMS))

        print('')

        # get list of ordered taxa by rank
        orderedTaxa = []
        for taxa in taxaByRank:
            orderedTaxa += sorted(taxa)

        # print table of results organized by class
        with open(self.simCompareRefinementTableOut, 'w') as refinmentTableOut:
            for taxon in orderedTaxa:
                numGenomes = len(genomeInTaxon[taxon])
                if numGenomes < 2: # skip groups with only a few genomes
                    continue

                refinmentTableOut.write(taxon + '\t' + str(numGenomes))
                for refineStr in ['IM', 'MS']:
                    compDeltas = compDataDict[taxon][refineStr]
                    contDeltas = contDataDict[taxon][refineStr]
                    refinmentTableOut.write('\t%.1f +/- %.2f\t%.1f +/- %.2f' % (
                        self._meanAbs(compDeltas), self._stdAbs(compDeltas),
                        self._meanAbs(contDeltas), self._stdAbs(contDeltas)))

                # % reduction in mean absolute delta of MS relative to IM
                meanCompIM = self._meanAbs(compDataDict[taxon]['IM'])
                meanContIM = self._meanAbs(contDataDict[taxon]['IM'])
                perCompChange = (meanCompIM - self._meanAbs(compDataDict[taxon]['MS'])) * 100 / meanCompIM
                perContChange = (meanContIM - self._meanAbs(contDataDict[taxon]['MS'])) * 100 / meanContIM
                refinmentTableOut.write('\t%.2f\t%.2f\n' % (perCompChange, perContChange))

        # plot data
        print('  Plotting results.')
        compData = []
        contData = []
        rowLabels = []
        for taxon in orderedTaxa:
            numGenomes = len(genomeInTaxon[taxon])
            if numGenomes < 10: # skip groups with only a few genomes
                continue

            for refineStr in ['RMS', 'MS', 'IM']:
                rowLabels.append(refineStr + ': ' + taxon + ' (' + str(numGenomes) + ')')
                compData.append(compDataDict[taxon][refineStr])
                contData.append(contDataDict[taxon][refineStr])

        for i, rowLabel in enumerate(rowLabels):
            print(rowLabel + '\t%.2f\t%.2f' % (self._meanAbs(compData[i]), self._meanAbs(contData[i])))

        boxPlot = BoxPlot()
        plotFilename = self.plotPrefix + '.refinements.png'
        boxPlot.plot(plotFilename, compData, contData, rowLabels,
                        r'$\Delta$' + ' % Completion', None,
                        r'$\Delta$' + ' % Contamination', None,
                        rowsPerCategory = 3, dpi = self.dpi)

    def run(self):
        """Read the simulation results and generate the enabled summaries.

        Only ``taxonomicPlots`` is currently enabled. The marker-set and
        condition summaries are commented out.
        """
        # read simulation results
        print('  Reading simulation results.')
        results = self.__readResults(self.simCompareFile)

        print('\n')
        #self.markerSets(results)

        print('\n')
        #self.conditionsPlot(results)

        self.taxonomicPlots(results)

        print('\n')
예제 #18
0
    def run(self, geneTreeDir, treeExtension, consistencyThreshold, minTaxaForAverage, outputFile, outputDir):
        """Score gene trees by taxonomic consistency and retain the consistent ones.

        For each file in geneTreeDir ending with treeExtension, the consistency of every
        taxon at the first three taxonomic ranks (domain, phylum, class) is the best value,
        over all internal nodes and both sides of each split, of
            congruent descendant taxa / (total taxa + incongruent descendant taxa).

        Outputs:
          * <outputDir>/consistency/<tree>.results.tsv with per-taxon consistency (%);
          * outputFile, a combined table with family descriptions, taxa counts and
            average consistencies for every processed tree;
          * a copy in outputDir of every tree whose mean rank consistency is
            >= consistencyThreshold.

        Args:
            geneTreeDir: directory holding the Newick gene trees.
            treeExtension: only files ending with this suffix are processed.
            consistencyThreshold: minimum mean consistency (a fraction, 0-1) needed to
                retain a tree.
            minTaxaForAverage: a taxon contributes to its rank's average only if it labels
                more than this many leaves.
            outputFile: path of the combined TSV table.
            outputDir: output directory; regular files already in it are deleted first.

        Exits with status 1 if a leaf's genome has no IMG metadata.
        """
        import shutil

        def formatConsistency(value):
            """Format a consistency fraction as a tab-prefixed percentage, or '\\tN/A'."""
            if value == 'N/A':
                return '\tN/A'
            return '\t%.2f' % (value * 100)

        # make sure output directory is empty (regular files only; subdirectories are kept)
        if not os.path.exists(outputDir):
            os.makedirs(outputDir)

        for f in os.listdir(outputDir):
            path = os.path.join(outputDir, f)
            if os.path.isfile(path):
                os.remove(path)

        # get TIGRFam info
        descDict = {}
        tigrfamInfoDir = '/srv/db/tigrfam/13.0/TIGRFAMs_13.0_INFO'
        for f in os.listdir(tigrfamInfoDir):
            acc = None
            shortDesc = longDesc = ''
            with open(os.path.join(tigrfamInfoDir, f)) as fin:
                for line in fin:
                    lineSplit = line.split('  ')
                    if lineSplit[0] == 'AC':
                        acc = lineSplit[1].strip()
                    elif lineSplit[0] == 'DE':
                        shortDesc = lineSplit[1].strip()
                    elif lineSplit[0] == 'CC':
                        longDesc = lineSplit[1].strip()

            # an INFO file without an AC line must not clobber the previous family's entry
            if acc is not None:
                descDict[acc] = [shortDesc, longDesc]

        # get PFam info
        with open('/srv/db/pfam/27/Pfam-A.clans.tsv') as fin:
            for line in fin:
                lineSplit = line.split('\t')
                descDict[lineSplit[0]] = [lineSplit[3], lineSplit[4].strip()]

        # get IMG taxonomy
        img = IMG('/srv/whitlam/bio/db/checkm/img/img_metadata.tsv', '/srv/whitlam/bio/db/checkm/pfam/tigrfam2pfam.tsv')
        metadata = img.genomeMetadata()
        genomeIdToTaxonomy = {genomeId: m['taxonomy'] for genomeId, m in metadata.items()}

        # perform analysis for each tree
        allResults = {}
        allTaxa = [set(), set(), set()]
        taxaCounts = {}
        avgConsistency = {}
        for treeFile in os.listdir(geneTreeDir):
            if not treeFile.endswith(treeExtension):
                continue

            print(treeFile)
            tree = dendropy.Tree.get_from_path(os.path.join(geneTreeDir, treeFile), schema='newick', as_rooted=True, preserve_underscores=True)

            # one {taxon: consistency} dict per rank (domain, phylum, class)
            consistencyDict = [{}, {}, {}]

            # get abundance of taxa at different taxonomic ranks
            totals = [{}, {}, {}]
            leaves = tree.leaf_nodes()
            print('  Number of leaves: ' + str(len(leaves)))
            totalValidLeaves = 0

            for leaf in leaves:
                genomeId = self.__genomeId(leaf.taxon.label)

                if genomeId not in metadata:
                    print('[Error] Genome is missing metadata: ' + genomeId)
                    sys.exit(1)  # non-zero status so callers can detect the failure

                totalValidLeaves += 1
                taxonomy = genomeIdToTaxonomy[genomeId]
                for r in range(0, 3):
                    totals[r][taxonomy[r]] = totals[r].get(taxonomy[r], 0) + 1
                    consistencyDict[r][taxonomy[r]] = 0
                    allTaxa[r].add(taxonomy[r])

            taxaCounts[treeFile] = [totalValidLeaves, totals[0].get('Bacteria', 0), totals[0].get('Archaea', 0)]

            # find highest consistency nodes (congruent descendant taxa / (total taxa + incongruent descendant taxa))
            for node in tree.internal_nodes():
                nodeLeaves = node.leaf_nodes()

                for r in range(0, 3):
                    leafCounts = {}
                    for leaf in nodeLeaves:
                        taxonomy = genomeIdToTaxonomy[self.__genomeId(leaf.taxon.label)]
                        leafCounts[taxonomy[r]] = leafCounts.get(taxonomy[r], 0) + 1

                    # calculate consistency for node
                    for taxa in consistencyDict[r]:
                        totalTaxaCount = totals[r][taxa]
                        if totalTaxaCount <= 1 or taxa == 'unclassified':
                            consistencyDict[r][taxa] = 'N/A'
                            continue

                        taxaCount = leafCounts.get(taxa, 0)
                        incongruentTaxa = len(nodeLeaves) - taxaCount
                        c = float(taxaCount) / (totalTaxaCount + incongruentTaxa)
                        if c > consistencyDict[r][taxa]:
                            consistencyDict[r][taxa] = c

                        # consider clan in other direction since the trees are unrooted
                        taxaCount = totalTaxaCount - leafCounts.get(taxa, 0)
                        incongruentTaxa = totalValidLeaves - len(nodeLeaves) - taxaCount
                        c = float(taxaCount) / (totalTaxaCount + incongruentTaxa)
                        if c > consistencyDict[r][taxa]:
                            consistencyDict[r][taxa] = c

            # write results
            consistencyDir = os.path.join(outputDir, 'consistency')
            if not os.path.exists(consistencyDir):
                os.makedirs(consistencyDir)
            with open(os.path.join(consistencyDir, treeFile + '.results.tsv'), 'w') as fout:
                fout.write('Tree')
                for r in range(0, 3):
                    for taxa in sorted(consistencyDict[r].keys()):
                        fout.write('\t' + taxa)
                fout.write('\n')

                fout.write(treeFile)
                for r in range(0, 3):
                    for taxa in sorted(consistencyDict[r].keys()):
                        fout.write(formatConsistency(consistencyDict[r][taxa]))
                fout.write('\n')

            # calculate average consistency at each taxonomic rank
            average = []
            for r in range(0, 3):
                rankValues = [c for taxa, c in consistencyDict[r].items()
                              if totals[r][taxa] > minTaxaForAverage and c != 'N/A']
                if rankValues:
                    average.append(sum(rankValues) / len(rankValues))
                else:
                    average.append(0)
            avgConsistency[treeFile] = average
            allResults[treeFile] = consistencyDict

            print('  Average consistency: ' + str(average) + ', mean = %.2f' % (sum(average) / len(average)))
            print('')

        # print out combined results
        filteredGeneTrees = 0
        retainedGeneTrees = 0
        with open(outputFile, 'w') as fout:
            # header columns must match the per-row values written below
            fout.write('Tree\tShort Desc.\tLong Desc.\t# Taxa\t# Bacteria\t# Archaea\tAvg. Consistency\tAvg. Domain Consistency\tAvg. Phylum Consistency\tAvg. Class Consistency')
            for r in range(0, 3):
                for t in sorted(allTaxa[r]):
                    fout.write('\t' + t)
            fout.write('\n')

            for treeFile in sorted(allResults.keys()):
                consistencyDict = allResults[treeFile]
                treeId = treeFile[0:treeFile.find('.')].replace('pfam', 'PF')

                if treeId not in descDict:
                    print('[Warning] No family description found for ' + treeId)
                shortDesc, longDesc = descDict.get(treeId, ['', ''])
                fout.write(treeId + '\t' + shortDesc + '\t' + longDesc)

                # taxa counts: total, bacterial and archaeal leaves
                fout.write('\t' + '\t'.join(str(count) for count in taxaCounts[treeFile]))

                avgCon = sum(avgConsistency[treeFile]) / 3
                fout.write('\t' + str(avgCon))

                if avgCon >= consistencyThreshold:
                    retainedGeneTrees += 1
                    shutil.copy(os.path.join(geneTreeDir, treeFile), os.path.join(outputDir, treeFile))
                else:
                    filteredGeneTrees += 1
                    print('Filtered %s with an average consistency of %.4f.' % (treeFile, avgCon))

                for r in range(0, 3):
                    fout.write('\t' + str(avgConsistency[treeFile][r]))

                for r in range(0, 3):
                    for t in sorted(allTaxa[r]):
                        fout.write(formatConsistency(consistencyDict[r].get(t, 'N/A')))
                fout.write('\n')

        print('Retained gene trees: ' + str(retainedGeneTrees))
        print('Filtered gene trees: ' + str(filteredGeneTrees))
예제 #19
0
class SimulationScaffolds(object):
    """Simulate incomplete and contaminated draft genomes by sampling whole scaffolds.

    For each draft test genome, scaffolds are randomly removed to reach a target
    completeness and scaffolds from a randomly chosen other test genome are added as
    contamination. Marker-set completeness/contamination estimates are compared with
    the true values, and a single writer process stores the results under /tmp.
    """

    def __init__(self):
        self.markerSetBuilder = MarkerSetBuilder()
        self.img = IMG('/srv/whitlam/bio/db/checkm/img/img_metadata.tsv',
                       '/srv/whitlam/bio/db/checkm/pfam/tigrfam2pfam.tsv')

        # NOTE(review): contigLen only labels output rows; scaffolds are sampled whole,
        # so no step of the simulation below depends on it.
        self.contigLens = [5000, 20000, 50000]
        self.percentComps = [0.5, 0.7, 0.8, 0.9, 0.95, 1.0]
        self.percentConts = [0.0, 0.05, 0.1, 0.15, 0.2]

    def __seqLens(self, seqs):
        """Return ({seqId: length}, total length of all sequences)."""
        seqLens = dict((seqId, len(seq)) for seqId, seq in seqs.items())
        return seqLens, sum(seqLens.values())

    def __workerThread(self, tree, metadata, genomeIdsToTest,
                       ubiquityThreshold, singleCopyThreshold, numReplicates,
                       queueIn, queueOut):
        """Process test genome ids from queueIn until a None sentinel arrives.

        One result tuple per (contigLen, percentComp, percentCont) combination is put
        on queueOut; its field order is the one unpacked by __writerThread.
        """

        while True:
            testGenomeId = queueIn.get(block=True, timeout=None)
            if testGenomeId is None:
                break

            # Under the fork start method every worker starts from a copy of the parent's
            # RNG state (seeded in run), so all workers would draw the same random stream.
            # Seeding per test genome decorrelates them and makes each genome's
            # simulation reproducible regardless of which worker handles it.
            random.seed(testGenomeId)

            # build marker sets for evaluating test genome
            testNode = tree.find_node_with_taxon_label('IMG_' + testGenomeId)
            binMarkerSets, refinedBinMarkerSet = self.markerSetBuilder.buildBinMarkerSet(
                tree,
                testNode.parent_node,
                ubiquityThreshold,
                singleCopyThreshold,
                bMarkerSet=True,
                genomeIdsToRemove=[testGenomeId])

            # determine distribution of all marker genes within the test genome
            geneDistTable = self.img.geneDistTable(
                [testGenomeId],
                binMarkerSets.getMarkerGenes(),
                spacingBetweenContigs=0)

            # estimate completeness of unmodified genome
            unmodifiedComp = {}
            unmodifiedCont = {}
            for ms in binMarkerSets.markerSetIter():
                hits = {}
                for mg in ms.getMarkerGenes():
                    if mg in geneDistTable[testGenomeId]:
                        hits[mg] = geneDistTable[testGenomeId][mg]
                completeness, contamination = ms.genomeCheck(
                    hits, bIndividualMarkers=True)
                unmodifiedComp[ms.lineageStr] = completeness
                unmodifiedCont[ms.lineageStr] = contamination

            # estimate completion and contamination of genome after subsampling using both the domain and lineage-specific marker sets
            testSeqs = readFasta(
                os.path.join(self.img.genomeDir, testGenomeId,
                             testGenomeId + '.fna'))
            testSeqLens, genomeSize = self.__seqLens(testSeqs)

            # candidate contamination sources, computed once per test genome and sorted
            # so the draw does not depend on set iteration order
            contCandidates = sorted(genomeIdsToTest - set([testGenomeId]))

            for contigLen in self.contigLens:
                for percentComp in self.percentComps:
                    for percentCont in self.percentConts:
                        deltaComp = defaultdict(list)
                        deltaCont = defaultdict(list)
                        deltaCompSet = defaultdict(list)
                        deltaContSet = defaultdict(list)

                        deltaCompRefined = defaultdict(list)
                        deltaContRefined = defaultdict(list)
                        deltaCompSetRefined = defaultdict(list)
                        deltaContSetRefined = defaultdict(list)

                        trueComps = []
                        trueConts = []

                        numDescendants = {}

                        for _ in range(0, numReplicates):
                            # generate test genome with a specific level of completeness, by randomly sampling scaffolds to remove
                            # (this will sample >= the desired level of completeness)
                            retainedTestSeqs, trueComp = self.markerSetBuilder.sampleGenomeScaffoldsWithoutReplacement(
                                percentComp, testSeqLens, genomeSize)
                            trueComps.append(trueComp)

                            # select a random genome to use as a source of contamination
                            contGenomeId = random.choice(contCandidates)
                            contSeqs = readFasta(
                                os.path.join(self.img.genomeDir, contGenomeId,
                                             contGenomeId + '.fna'))
                            contSeqLens, contGenomeSize = self.__seqLens(
                                contSeqs)
                            seqsToRetain, trueRetainedPer = self.markerSetBuilder.sampleGenomeScaffoldsWithoutReplacement(
                                1 - percentCont, contSeqLens, contGenomeSize)

                            # scaffolds NOT retained from the contaminant form the contamination
                            contSampledSeqIds = set(
                                contSeqs.keys()).difference(seqsToRetain)
                            trueCont = 100.0 - trueRetainedPer
                            trueConts.append(trueCont)

                            for ms in binMarkerSets.markerSetIter():
                                numDescendants[ms.lineageStr] = ms.numGenomes
                                containedMarkerGenes = defaultdict(list)
                                self.markerSetBuilder.markerGenesOnScaffolds(
                                    ms.getMarkerGenes(), testGenomeId,
                                    retainedTestSeqs, containedMarkerGenes)
                                self.markerSetBuilder.markerGenesOnScaffolds(
                                    ms.getMarkerGenes(), contGenomeId,
                                    contSampledSeqIds, containedMarkerGenes)

                                completeness, contamination = ms.genomeCheck(
                                    containedMarkerGenes,
                                    bIndividualMarkers=True)
                                deltaComp[ms.lineageStr].append(completeness - trueComp)
                                deltaCont[ms.lineageStr].append(contamination - trueCont)

                                completeness, contamination = ms.genomeCheck(
                                    containedMarkerGenes,
                                    bIndividualMarkers=False)
                                deltaCompSet[ms.lineageStr].append(completeness - trueComp)
                                deltaContSet[ms.lineageStr].append(contamination - trueCont)

                            for ms in refinedBinMarkerSet.markerSetIter():
                                containedMarkerGenes = defaultdict(list)
                                self.markerSetBuilder.markerGenesOnScaffolds(
                                    ms.getMarkerGenes(), testGenomeId,
                                    retainedTestSeqs, containedMarkerGenes)
                                self.markerSetBuilder.markerGenesOnScaffolds(
                                    ms.getMarkerGenes(), contGenomeId,
                                    contSampledSeqIds, containedMarkerGenes)

                                completeness, contamination = ms.genomeCheck(
                                    containedMarkerGenes,
                                    bIndividualMarkers=True)
                                deltaCompRefined[ms.lineageStr].append(completeness - trueComp)
                                deltaContRefined[ms.lineageStr].append(contamination - trueCont)

                                completeness, contamination = ms.genomeCheck(
                                    containedMarkerGenes,
                                    bIndividualMarkers=False)
                                deltaCompSetRefined[ms.lineageStr].append(completeness - trueComp)
                                deltaContSetRefined[ms.lineageStr].append(contamination - trueCont)

                        taxonomy = ';'.join(metadata[testGenomeId]['taxonomy'])
                        queueOut.put(
                            (testGenomeId, contigLen, percentComp, percentCont,
                             taxonomy, numDescendants, unmodifiedComp,
                             unmodifiedCont, deltaComp, deltaCont,
                             deltaCompSet, deltaContSet, deltaCompRefined,
                             deltaContRefined, deltaCompSetRefined,
                             deltaContSetRefined, trueComps, trueConts))

    def __writerThread(self, numTestGenomes, writerQueue):
        """Write worker results until a sentinel tuple (first field None) arrives.

        A summary TSV (mean/std of the absolute estimate-minus-true deltas) and a
        gzipped TSV with the raw per-replicate deltas are written under /tmp. The
        files are opened with `with` so they are closed (and the gzip trailer
        written) even if an error occurs.
        """
        testsPerGenome = len(self.contigLens) * len(self.percentComps) * len(
            self.percentConts)
        totalTests = numTestGenomes * testsPerGenome

        with open('/tmp/simulation.random_scaffolds.w_refinement_50.draft.summary.tsv', 'w') as summaryOut, \
                gzip.open('/tmp/simulation.random_scaffolds.w_refinement_50.draft.tsv.gz', 'wb') as fout:
            summaryOut.write('Genome Id\tContig len\t% comp\t% cont')
            summaryOut.write('\tTaxonomy\tMarker set\t# descendants')
            summaryOut.write('\tUnmodified comp\tUnmodified cont')
            summaryOut.write('\tIM comp\tIM comp std\tIM cont\tIM cont std')
            summaryOut.write('\tMS comp\tMS comp std\tMS cont\tMS cont std')
            summaryOut.write('\tRIM comp\tRIM comp std\tRIM cont\tRIM cont std')
            summaryOut.write('\tRMS comp\tRMS comp std\tRMS cont\tRMS cont std\n')

            fout.write('Genome Id\tContig len\t% comp\t% cont')
            fout.write('\tTaxonomy\tMarker set\t# descendants')
            fout.write('\tUnmodified comp\tUnmodified cont')
            fout.write('\tIM comp\tIM cont')
            fout.write('\tMS comp\tMS cont')
            fout.write('\tRIM comp\tRIM cont')
            fout.write('\tRMS comp\tRMS cont\tTrue Comp\tTrue Cont\n')

            itemsProcessed = 0
            while True:
                (testGenomeId, contigLen, percentComp, percentCont, taxonomy,
                 numDescendants, unmodifiedComp, unmodifiedCont,
                 deltaComp, deltaCont, deltaCompSet, deltaContSet,
                 deltaCompRefined, deltaContRefined, deltaCompSetRefined,
                 deltaContSetRefined, trueComps, trueConts) = writerQueue.get(
                     block=True, timeout=None)
                if testGenomeId is None:
                    break

                itemsProcessed += 1
                statusStr = '    Finished processing %d of %d (%.2f%%) test cases.' % (
                    itemsProcessed, totalTests,
                    float(itemsProcessed) * 100 / totalTests)
                sys.stdout.write('%s\r' % statusStr)
                sys.stdout.flush()

                # IM, MS, RIM, RMS comp/cont deltas in the same order as the header columns
                deltaTables = (deltaComp, deltaCont, deltaCompSet, deltaContSet,
                               deltaCompRefined, deltaContRefined,
                               deltaCompSetRefined, deltaContSetRefined)

                for markerSetId in unmodifiedComp:
                    rowPrefix = (testGenomeId + '\t%d\t%.2f\t%.2f' % (contigLen, percentComp, percentCont)
                                 + '\t' + taxonomy + '\t' + markerSetId + '\t' + str(numDescendants[markerSetId])
                                 + '\t%.3f\t%.3f' % (unmodifiedComp[markerSetId], unmodifiedCont[markerSetId]))

                    summaryOut.write(rowPrefix)
                    for table in deltaTables:
                        summaryOut.write('\t%.3f\t%.3f' %
                                         (mean(abs(table[markerSetId])),
                                          std(abs(table[markerSetId]))))
                    summaryOut.write('\n')

                    fout.write(rowPrefix)
                    for table in deltaTables:
                        fout.write('\t%s' % ','.join(map(str, table[markerSetId])))
                    fout.write('\t%s' % ','.join(map(str, trueComps)))
                    fout.write('\t%s' % ','.join(map(str, trueConts)))
                    fout.write('\n')

        sys.stdout.write('\n')

    def run(self, ubiquityThreshold, singleCopyThreshold, numReplicates,
            minScaffolds, numThreads):
        """Run the scaffold-sampling simulation with numThreads worker processes.

        Draft genomes in the reference tree with >= minScaffolds scaffolds are tested;
        numReplicates replicates are simulated per parameter combination.
        """
        random.seed(0)

        print('\n  Reading reference genome tree.')
        treeFile = os.path.join('/srv', 'db', 'checkm', 'genome_tree',
                                'genome_tree_prok.refpkg',
                                'genome_tree.final.tre')
        tree = dendropy.Tree.get_from_path(treeFile,
                                           schema='newick',
                                           as_rooted=True,
                                           preserve_underscores=True)

        print('    Number of taxa in tree: %d' % (len(tree.leaf_nodes())))

        genomesInTree = set()
        for leaf in tree.leaf_iter():
            genomesInTree.add(leaf.taxon.label.replace('IMG_', ''))

        # get all draft genomes consisting of a user-specific minimum number of scaffolds
        print('')
        metadata = self.img.genomeMetadata()
        print('  Total genomes: %d' % len(metadata))

        draftGenomeIds = genomesInTree - self.img.filterGenomeIds(
            genomesInTree, metadata, 'status', 'Finished')
        print('  Number of draft genomes: %d' % len(draftGenomeIds))

        genomeIdsToTest = set(
            genomeId for genomeId in draftGenomeIds
            if metadata[genomeId]['scaffold count'] >= minScaffolds)

        print('  Number of draft genomes with >= %d scaffolds: %d' % (
            minScaffolds, len(genomeIdsToTest)))

        print('')
        start = time.time()
        self.markerSetBuilder.readLineageSpecificGenesToRemove()
        end = time.time()
        print('    readLineageSpecificGenesToRemove: %.2f' % (end - start))

        print('  Pre-computing genome information for calculating marker sets:')
        start = time.time()
        self.markerSetBuilder.precomputeGenomeFamilyScaffolds(metadata.keys())
        end = time.time()
        print('    precomputeGenomeFamilyScaffolds: %.2f' % (end - start))

        start = time.time()
        self.markerSetBuilder.cachedGeneCountTable = self.img.geneCountTable(
            metadata.keys())
        end = time.time()
        print('    globalGeneCountTable: %.2f' % (end - start))

        start = time.time()
        self.markerSetBuilder.precomputeGenomeSeqLens(metadata.keys())
        end = time.time()
        print('    precomputeGenomeSeqLens: %.2f' % (end - start))

        start = time.time()
        self.markerSetBuilder.precomputeGenomeFamilyPositions(
            metadata.keys(), 0)
        end = time.time()
        print('    precomputeGenomeFamilyPositions: %.2f' % (end - start))

        print('')
        print('  Evaluating %d test genomes.' % len(genomeIdsToTest))

        workerQueue = mp.Queue()
        writerQueue = mp.Queue()

        for testGenomeId in sorted(genomeIdsToTest):
            workerQueue.put(testGenomeId)

        # one None sentinel per worker
        for _ in range(numThreads):
            workerQueue.put(None)

        workerProc = [
            mp.Process(target=self.__workerThread,
                       args=(tree, metadata, genomeIdsToTest,
                             ubiquityThreshold, singleCopyThreshold,
                             numReplicates, workerQueue, writerQueue))
            for _ in range(numThreads)
        ]
        writeProc = mp.Process(target=self.__writerThread,
                               args=(len(genomeIdsToTest), writerQueue))

        writeProc.start()

        for p in workerProc:
            p.start()

        for p in workerProc:
            p.join()

        # writer sentinel: a tuple of the same width as a result, all None
        writerQueue.put((None,) * 18)
        writeProc.join()
예제 #20
0
    def run(
        self, geneTreeDir, alignmentDir, extension, outputAlignFile, outputTree, outputTaxonomy, bSupportValues=False
    ):
        """Concatenate per-gene alignments and infer a genome tree with FastTreeMP.

        Gene ids are taken from the '.tre' files in geneTreeDir. For each, the alignment
        file in alignmentDir ending with `extension` is read, filtered with
        __filterParalogs and concatenated (genes in sorted order, '-' gaps for genomes
        lacking a gene). Concatenated sequences are named 'IMG_<genomeId>'.

        Args:
            geneTreeDir: directory whose '.tre' files select the genes to use.
            alignmentDir: directory of per-gene FASTA alignments.
            extension: filename suffix identifying alignment files.
            outputAlignFile: path for the concatenated FASTA alignment.
            outputTree: path receiving FastTreeMP's standard output; its log is written
                to the same path with the extension replaced by '.log'.
            outputTaxonomy: output path passed to __taxonomy.
            bSupportValues: if False, FastTreeMP is run with -nosupport.
        """
        import subprocess

        # read gene trees
        print("Reading gene trees.")
        geneIds = set()
        for f in os.listdir(geneTreeDir):
            if f.endswith(".tre"):
                geneIds.add(f[0 : f.find(".")])

        # write out genome tree taxonomy
        print("Reading trusted genomes.")
        img = IMG("/srv/whitlam/bio/db/checkm/img/img_metadata.tsv", "/srv/whitlam/bio/db/checkm/pfam/tigrfam2pfam.tsv")
        trustedGenomeIds = img.genomeMetadata().keys()
        self.__taxonomy(img, trustedGenomeIds, outputTaxonomy)

        print("  There are %d trusted genomes." % (len(trustedGenomeIds)))

        # get genes in genomes
        print("Reading all PFAM and TIGRFAM hits in trusted genomes.")
        genesInGenomes = self.__genesInGenomes(trustedGenomeIds)

        # read alignment files
        print("Reading alignment files.")
        alignments = {}
        genomeIds = set()
        for f in os.listdir(alignmentDir):
            geneId = f[0 : f.find(".")]
            if not f.endswith(extension) or geneId not in geneIds:
                continue

            seqs = readFasta(os.path.join(alignmentDir, f))

            # the IMG gene id uses a 'pfam' prefix in place of 'PF'
            imgGeneId = geneId
            if imgGeneId.startswith("PF"):
                imgGeneId = "pfam" + imgGeneId[2:]
            seqs = self.__filterParalogs(seqs, imgGeneId, genesInGenomes)

            if not seqs:
                # an empty alignment has no length and would break concatenation below
                print("  [Warning] No sequences left for %s after paralog filtering; skipping it." % geneId)
                continue

            genomeIds.update(seqs.keys())
            alignments[geneId] = seqs

        # create concatenated alignment (collect pieces per genome, join once at the end)
        print("Concatenating alignments:")
        seqParts = dict((genomeId, []) for genomeId in genomeIds)
        totalAlignLen = 0
        for geneId in sorted(alignments.keys()):
            seqs = alignments[geneId]
            alignLen = len(next(iter(seqs.values())))
            print("  " + str(geneId) + "," + str(alignLen))
            totalAlignLen += alignLen
            missingGene = "-" * alignLen
            for genomeId in genomeIds:
                seqParts[genomeId].append(seqs.get(genomeId, missingGene))

        concatenatedSeqs = dict(("IMG_" + genomeId, "".join(parts)) for genomeId, parts in seqParts.items())

        print("  Total alignment length: " + str(totalAlignLen))

        # save concatenated alignment
        writeFasta(concatenatedSeqs, outputAlignFile)

        # infer genome tree
        print("Inferring genome tree.")
        outputLog = os.path.splitext(outputTree)[0] + ".log"

        cmd = ["FastTreeMP"]
        if not bSupportValues:
            cmd.append("-nosupport")
        cmd += ["-wag", "-gamma", "-log", outputLog, outputAlignFile]

        # run without a shell so paths with spaces or shell metacharacters are handled safely
        with open(outputTree, "w") as treeOut:
            exitCode = subprocess.call(cmd, stdout=treeOut)
        if exitCode != 0:
            print("  [Warning] FastTreeMP exited with status %d." % exitCode)
예제 #21
0
class SimCompareDiffPlot(object):
    """Compare domain- and lineage-specific marker set errors across simulated genomes.

    run() reads ./simulations/briefSummaryOut.tsv, counts which marker set has the
    smaller absolute completeness / contamination error whenever the two errors differ
    by more than 5, prints summary percentages and saves a stacked bar plot to
    ./experiments/simCompareDiffPlot.svg.
    """

    def __init__(self):
        self.img = IMG('/srv/whitlam/bio/db/checkm/img/img_metadata.tsv',
                       '/srv/whitlam/bio/db/checkm/pfam/tigrfam2pfam.tsv')

    @staticmethod
    def _diffBin(lineageErr, domainErr, numBars, minDiff=5):
        """Return the histogram bin for the gap between |lineageErr| and |domainErr|.

        The bin is int(| |lineageErr| - |domainErr| |) clamped to numBars - 1. None is
        returned when the gap is not greater than minDiff.
        """
        diff = abs(abs(lineageErr) - abs(domainErr))
        if not diff > minDiff:
            return None
        return min(int(diff), numBars - 1)

    @staticmethod
    def _percent(count, total):
        """Return count as a percentage of total, or 0.0 when total is zero."""
        if total == 0:
            return 0.0
        return float(count) * 100 / total

    def run(self):
        """Tally, print and plot how often each marker set gives the smaller error.

        Each input line holds a genome id followed by six numbers: domain MS comp,
        lineage MS comp, lineage RMS comp, domain MS cont, lineage MS cont and
        lineage RMS cont (the RMS values are read but unused). Blank lines are skipped.
        """
        # count number of times the lineage-specific marker set results outperform
        # the domain-specific marker set for varying differences between the two sets
        numBars = 15

        lineageCountsComp = [0] * numBars
        domainCountsComp = [0] * numBars

        lineageCountsCont = [0] * numBars
        domainCountsCont = [0] * numBars

        totalCountsComp = 0
        totalCountsCont = 0

        domCompBest = 0
        lineageCompBest = 0
        domContBest = 0
        lineageContBest = 0

        metadata = self.img.genomeMetadata()
        domCompTaxon = defaultdict(int)
        lineageCompTaxon = defaultdict(int)

        with open('./simulations/briefSummaryOut.tsv') as fin:
            for line in fin:
                if not line.strip():
                    continue  # tolerate blank lines, e.g. at the end of the file

                lineSplit = line.split('\t')
                genomeId = lineSplit[0]
                phylum = metadata[genomeId]['taxonomy'][1]
                (domCompMS, lineageCompMS, _lineageCompRMS,
                 domContMS, lineageContMS, _lineageContRMS) = [
                     float(x) for x in lineSplit[1:]
                 ]

                compBin = self._diffBin(lineageCompMS, domCompMS, numBars)
                if compBin is not None:
                    if abs(domCompMS) < abs(lineageCompMS):
                        domainCountsComp[compBin] += 1
                        domCompBest += 1
                        domCompTaxon[phylum] += 1
                    else:
                        lineageCountsComp[compBin] += 1
                        lineageCompBest += 1
                        lineageCompTaxon[phylum] += 1
                    totalCountsComp += 1

                contBin = self._diffBin(lineageContMS, domContMS, numBars)
                if contBin is not None:
                    if abs(domContMS) < abs(lineageContMS):
                        domainCountsCont[contBin] += 1
                        domContBest += 1
                    else:
                        lineageCountsCont[contBin] += 1
                        lineageContBest += 1
                    totalCountsCont += 1

        # _percent guards against zero denominators when no genome exceeded the threshold
        print('%% times lineage comp better than domain: %.2f' %
              self._percent(lineageCompBest, domCompBest + lineageCompBest))
        print('%% times lineage cont better than domain: %.2f' %
              self._percent(lineageContBest, domContBest + lineageContBest))

        print('')
        print('Taxonomy breakdown (dom best, lineage best):')
        taxa = set(domCompTaxon.keys()).union(lineageCompTaxon.keys())
        for t in sorted(taxa):
            print('%s\t%.2f\t%.2f' %
                  (t, self._percent(domCompTaxon[t], domCompBest),
                   self._percent(lineageCompTaxon[t], lineageCompBest)))

        # normalize counts to percentages of all compared genomes
        for i in range(0, numBars):
            lineageCountsComp[i] = self._percent(lineageCountsComp[i], totalCountsComp)
            domainCountsComp[i] = self._percent(domainCountsComp[i], totalCountsComp)

            if domainCountsComp[i] > lineageCountsComp[i]:
                print('Domain beats lineage (comp): %d%% (%f, %f)' %
                      (i + 1, domainCountsComp[i], lineageCountsComp[i]))

            lineageCountsCont[i] = self._percent(lineageCountsCont[i], totalCountsCont)
            domainCountsCont[i] = self._percent(domainCountsCont[i], totalCountsCont)

            if domainCountsCont[i] > lineageCountsCont[i]:
                print('Domain beats lineage (cont): %d%% (%f, %f)' %
                      (i + 1, domainCountsCont[i], lineageCountsCont[i]))

        stackedBarPlot = StackedBarPlot()
        stackedBarPlot.plot(lineageCountsComp, domainCountsComp,
                            lineageCountsCont, domainCountsCont)
        stackedBarPlot.savePlot('./experiments/simCompareDiffPlot.svg')
예제 #22
0
 def __init__(self):
     """Set self.img to an IMG helper built from the IMG metadata table and the tigrfam2pfam mapping file."""
     imgMetadataFile = '/srv/whitlam/bio/db/checkm/img/img_metadata.tsv'
     tigrfamToPfamFile = '/srv/whitlam/bio/db/checkm/pfam/tigrfam2pfam.tsv'
     self.img = IMG(imgMetadataFile, tigrfamToPfamFile)
예제 #23
0
class PlotScaffoldLenVsMarkers(object):
    """Relate scaffold length to the density of domain-specific marker genes.

    Only draft archaeal genomes from the IMG metadata are examined (see run()).
    """

    def __init__(self):
        self.img = IMG('/srv/whitlam/bio/db/checkm/img/img_metadata.tsv', '/srv/whitlam/bio/db/checkm/pfam/tigrfam2pfam.tsv')

    @staticmethod
    def _normalizeMarkerId(markerId):
        """Return markerId in the form used as keys of the precomputed family/scaffold table.

        Pfam accessions such as 'PF00001.18' become 'pfam00001'; the suffix
        after the last '.' is dropped only when a '.' is present. Any other id
        is returned unchanged.
        """
        if not markerId.startswith('PF'):
            return markerId
        markerId = markerId.replace('PF', 'pfam')
        dotIndex = markerId.rfind('.')
        return markerId[:dotIndex] if dotIndex != -1 else markerId

    @staticmethod
    def _perBase(count, length):
        """Return count / length as a float, or 0.0 when length is zero."""
        return float(count) / length if length else 0.0

    def run(self, minScaffolds=20, shortScaffoldLen=10000):
        """Summarise and plot marker density per scaffold for draft archaeal genomes.

        Draft genomes are those whose first taxonomy rank is 'Archaea' and whose
        'status' is not 'Finished'; only those with at least minScaffolds
        scaffolds are tested. Prints marker-per-base summaries (scaffolds
        shorter than shortScaffoldLen bp count as short) and writes a scatter
        plot of scaffold length versus the fraction of the domain marker set
        per Mbp to ./experiments/plotScaffoldLenVsMarkers.png.

        :param minScaffolds: minimum 'scaffold count' for a genome to be tested.
        :param shortScaffoldLen: length threshold (bp) for a "short" scaffold.
        """
        # get all draft genomes consisting of a user-specific minimum number of scaffolds
        print('')
        metadata = self.img.genomeMetadata()
        print('  Total genomes: %d' % len(metadata))

        # NOTE: only archaeal genomes are selected, so the bacterial marker set below is never used
        arGenome = set(genomeId for genomeId in metadata
                       if metadata[genomeId]['taxonomy'][0] == 'Archaea')

        draftGenomeIds = arGenome - self.img.filterGenomeIds(arGenome, metadata, 'status', 'Finished')
        print('  Number of draft genomes: %d' % len(draftGenomeIds))

        genomeIdsToTest = set(genomeId for genomeId in draftGenomeIds
                              if metadata[genomeId]['scaffold count'] >= minScaffolds)
        print('  Number of draft genomes with >= %d scaffolds: %d' % (minScaffolds, len(genomeIdsToTest)))

        print('')
        print('  Calculating genome information for calculating marker sets:')
        genomeFamilyScaffolds = self.img.precomputeGenomeFamilyScaffolds(genomeIdsToTest)

        print('  Calculating genome sequence lengths.')
        genomeSeqLens = self.img.precomputeGenomeSeqLens(genomeIdsToTest)

        print('  Determining domain-specific marker sets.')
        taxonParser = TaxonParser()
        taxonMarkerSets = taxonParser.readMarkerSets()
        bacMarkers = taxonMarkerSets['domain']['Bacteria'].getMarkerGenes()
        arMarkers = taxonMarkerSets['domain']['Archaea'].getMarkerGenes()
        print('    There are %d bacterial markers and %d archaeal markers.' % (len(bacMarkers), len(arMarkers)))

        print('  Determining percentage of markers on each scaffold.')
        totalMarkers = 0
        totalSequenceLen = 0
        markersOnShortScaffolds = 0
        totalShortScaffoldLen = 0

        scaffoldLen = {}
        percentageMarkers = defaultdict(float)  # scaffold id -> fraction of the domain marker set on it
        for genomeId, markerIds in genomeFamilyScaffolds.items():
            domain = metadata[genomeId]['taxonomy'][0]
            markerGenes = bacMarkers if domain == 'Bacteria' else arMarkers
            for markerId in markerGenes:
                markerId = self._normalizeMarkerId(markerId)
                if markerId not in markerIds:
                    continue

                for scaffoldId in markerIds[markerId]:
                    seqLen = genomeSeqLens[genomeId][scaffoldId]
                    scaffoldLen[scaffoldId] = seqLen
                    percentageMarkers[scaffoldId] += 1.0 / len(markerGenes)

                    totalMarkers += 1
                    totalSequenceLen += seqLen

                    if seqLen < shortScaffoldLen:
                        markersOnShortScaffolds += 1
                        totalShortScaffoldLen += seqLen

        # lengths are summed in base pairs (the ratio is reported per base)
        print('Markers on short scaffolds: %d over %d bp (%f markers per base)' % (
            markersOnShortScaffolds, totalShortScaffoldLen,
            self._perBase(markersOnShortScaffolds, totalShortScaffoldLen)))
        print('Total markers on scaffolds: %d over %d bp (%f markers per base)' % (
            totalMarkers, totalSequenceLen, self._perBase(totalMarkers, totalSequenceLen)))

        print('  Create plot.')
        plotLens = []
        plotPerMarkers = []
        for scaffoldId, fraction in percentageMarkers.items():
            plotLens.append(scaffoldLen[scaffoldId])
            # marker-set fraction per Mbp of scaffold
            plotPerMarkers.append(fraction / scaffoldLen[scaffoldId] * 1e6)

        scatterPlot = ScatterPlot()
        scatterPlot.plot(plotLens, plotPerMarkers)
        scatterPlot.savePlot('./experiments/plotScaffoldLenVsMarkers.png')
예제 #24
0
class Simulation(object):
    """Benchmark marker-set completeness/contamination estimates on simulated draft genomes.

    For every draft genome in the reference tree and every combination of
    contig length, target completeness and target contamination, partial
    genomes are drawn with MarkerSetBuilder.sampleGenome and each marker set's
    estimates are compared with the sampled true values. Results are written
    to files under /tmp by a single writer process.
    """

    def __init__(self):
        self.markerSetBuilder = MarkerSetBuilder()
        self.img = IMG(
            "/srv/whitlam/bio/db/checkm/img/img_metadata.tsv", "/srv/whitlam/bio/db/checkm/pfam/tigrfam2pfam.tsv"
        )

        # simulation grid: each test genome is evaluated at every combination
        self.contigLens = [1000, 2000, 5000, 10000, 20000, 50000]
        self.percentComps = [0.5, 0.7, 0.8, 0.9, 0.95, 1.0]
        self.percentConts = [0.0, 0.05, 0.1, 0.15, 0.2]

    def _evaluateMarkerSets(self, markerSets, geneDist, contigs, contigLen, trueComp, trueCont,
                            deltaComp, deltaCont, deltaCompSet, deltaContSet, numDescendants=None):
        """Append (estimate - truth) errors for every marker set in markerSets.

        Errors from genomeCheck(bIndividualMarkers=True) go to deltaComp/deltaCont,
        errors from genomeCheck(bIndividualMarkers=False) go to
        deltaCompSet/deltaContSet; all are keyed by the marker set's lineageStr.
        When numDescendants is given, ms.numGenomes is recorded per lineageStr.
        """
        for ms in markerSets.markerSetIter():
            if numDescendants is not None:
                numDescendants[ms.lineageStr] = ms.numGenomes

            containedMarkerGenes = self.markerSetBuilder.containedMarkerGenes(
                ms.getMarkerGenes(), geneDist, contigs, contigLen
            )

            completeness, contamination = ms.genomeCheck(containedMarkerGenes, bIndividualMarkers=True)
            deltaComp[ms.lineageStr].append(completeness - trueComp)
            deltaCont[ms.lineageStr].append(contamination - trueCont)

            completeness, contamination = ms.genomeCheck(containedMarkerGenes, bIndividualMarkers=False)
            deltaCompSet[ms.lineageStr].append(completeness - trueComp)
            deltaContSet[ms.lineageStr].append(contamination - trueCont)

    def __workerThread(self, tree, metadata, ubiquityThreshold, singleCopyThreshold, numReplicates, queueIn, queueOut):
        """Process each data item in parallel.

        Takes test genome ids from queueIn until a None sentinel arrives. For each
        (contig length, % comp, % cont) combination it puts one 20-field tuple on
        queueOut, in the order unpacked by __writerThread.
        """
        while True:
            testGenomeId = queueIn.get(block=True, timeout=None)
            if testGenomeId is None:
                break

            # build marker sets for evaluating test genome
            testNode = tree.find_node_with_taxon_label("IMG_" + testGenomeId)
            binMarkerSets, refinedBinMarkerSet = self.markerSetBuilder.buildBinMarkerSet(
                tree,
                testNode.parent_node,
                ubiquityThreshold,
                singleCopyThreshold,
                bMarkerSet=True,
                genomeIdsToRemove=[testGenomeId],
            )
            #!!!binMarkerSets, refinedBinMarkerSet = self.markerSetBuilder.buildDomainMarkerSet(tree, testNode.parent_node, ubiquityThreshold, singleCopyThreshold, bMarkerSet = False, genomeIdsToRemove = [testGenomeId])

            # determine distribution of all marker genes within the test genome
            geneDistTable = self.img.geneDistTable(
                [testGenomeId], binMarkerSets.getMarkerGenes(), spacingBetweenContigs=0
            )
            geneDist = geneDistTable[testGenomeId]

            print("# marker genes:  %d" % len(binMarkerSets.getMarkerGenes()))
            print("# genes in table:  %d" % len(geneDist))

            # estimate completeness of unmodified genome
            unmodifiedComp = {}
            unmodifiedCont = {}
            for ms in binMarkerSets.markerSetIter():
                hits = {}
                for mg in ms.getMarkerGenes():
                    if mg in geneDist:
                        hits[mg] = geneDist[mg]
                completeness, contamination = ms.genomeCheck(hits, bIndividualMarkers=True)
                unmodifiedComp[ms.lineageStr] = completeness
                unmodifiedCont[ms.lineageStr] = contamination

            # debug output: values of the last marker set evaluated above
            print("%s %s" % (completeness, contamination))

            # estimate completion and contamination of genome after subsampling using both the domain and lineage-specific marker sets
            genomeSize = readFastaBases(os.path.join(self.img.genomeDir, testGenomeId, testGenomeId + ".fna"))
            print("genomeSize %s" % genomeSize)

            taxonomy = ";".join(metadata[testGenomeId]["taxonomy"])

            for contigLen in self.contigLens:
                for percentComp in self.percentComps:
                    for percentCont in self.percentConts:
                        deltaComp = defaultdict(list)
                        deltaCont = defaultdict(list)
                        deltaCompSet = defaultdict(list)
                        deltaContSet = defaultdict(list)

                        deltaCompRefined = defaultdict(list)
                        deltaContRefined = defaultdict(list)
                        deltaCompSetRefined = defaultdict(list)
                        deltaContSetRefined = defaultdict(list)

                        trueComps = []
                        trueConts = []

                        numDescendants = {}

                        for _ in range(numReplicates):
                            trueComp, trueCont, startPartialGenomeContigs = self.markerSetBuilder.sampleGenome(
                                genomeSize, percentComp, percentCont, contigLen
                            )
                            print("%s %s %s %s" % (contigLen, trueComp, trueCont, len(startPartialGenomeContigs)))

                            trueComps.append(trueComp)
                            trueConts.append(trueCont)

                            self._evaluateMarkerSets(
                                binMarkerSets, geneDist, startPartialGenomeContigs, contigLen, trueComp, trueCont,
                                deltaComp, deltaCont, deltaCompSet, deltaContSet, numDescendants=numDescendants,
                            )
                            self._evaluateMarkerSets(
                                refinedBinMarkerSet, geneDist, startPartialGenomeContigs, contigLen, trueComp, trueCont,
                                deltaCompRefined, deltaContRefined, deltaCompSetRefined, deltaContSetRefined,
                            )

                        queueOut.put(
                            (
                                testGenomeId,
                                contigLen,
                                percentComp,
                                percentCont,
                                taxonomy,
                                numDescendants,
                                unmodifiedComp,
                                unmodifiedCont,
                                trueComps,
                                trueConts,
                                deltaComp,
                                deltaCont,
                                deltaCompSet,
                                deltaContSet,
                                deltaCompRefined,
                                deltaContRefined,
                                deltaCompSetRefined,
                                deltaContSetRefined,
                                trueComps,
                                trueConts,
                            )
                        )

    @staticmethod
    def _rowPrefix(testGenomeId, contigLen, percentComp, percentCont, taxonomy, markerSetId,
                   numDescendants, unmodifiedComp, unmodifiedCont):
        """Return the leading columns shared by the summary and detailed output rows."""
        return (
            testGenomeId + "\t%d\t%.2f\t%.2f" % (contigLen, percentComp, percentCont)
            + "\t" + taxonomy + "\t" + markerSetId + "\t" + str(numDescendants[markerSetId])
            + "\t%.3f\t%.3f" % (unmodifiedComp[markerSetId], unmodifiedCont[markerSetId])
        )

    @staticmethod
    def _summaryColumns(trueComps, trueConts, deltaLists):
        """Return the 'True comp'/'True cont' means followed by mean/std of |error| for each list.

        NOTE(review): relies on abs/mean/std accepting a list (e.g. NumPy's);
        confirm against the file's imports.
        """
        columns = "\t%.3f\t%.3f" % (mean(trueComps), mean(trueConts))
        for deltas in deltaLists:
            absDeltas = abs(deltas)
            columns += "\t%.3f\t%.3f" % (mean(absDeltas), std(absDeltas))
        return columns

    @staticmethod
    def _detailColumns(valueLists):
        """Return one tab-prefixed, comma-joined column per list of raw values."""
        return "".join("\t%s" % ",".join(map(str, values)) for values in valueLists)

    def __writerThread(self, numTestGenomes, writerQueue):
        """Store or write results of worker threads in a single thread.

        Consumes the 20-field tuples produced by __workerThread until one whose
        first field is None arrives. For every marker set of each result, one
        summary row (means/standard deviations) is written to
        /tmp/simulation.summary.testing.tsv and one row of raw comma-separated
        values to /tmp/simulation.testing.tsv.gz.
        """
        testsPerGenome = len(self.contigLens) * len(self.percentComps) * len(self.percentConts)
        totalTests = numTestGenomes * testsPerGenome

        # summaryOut = open('/tmp/simulation.draft.summary.w_refinement_50.tsv', 'w')
        # fout = gzip.open('/tmp/simulation.draft.w_refinement_50.tsv.gz', 'wb')
        with open("/tmp/simulation.summary.testing.tsv", "w") as summaryOut, gzip.open(
            "/tmp/simulation.testing.tsv.gz", "wb"
        ) as fout:
            summaryOut.write("Genome Id\tContig len\t% comp\t% cont")
            summaryOut.write("\tTaxonomy\tMarker set\t# descendants")
            summaryOut.write("\tUnmodified comp\tUnmodified cont\tTrue comp\tTrue cont")
            summaryOut.write("\tIM comp\tIM comp std\tIM cont\tIM cont std")
            summaryOut.write("\tMS comp\tMS comp std\tMS cont\tMS cont std")
            summaryOut.write("\tRIM comp\tRIM comp std\tRIM cont\tRIM cont std")
            summaryOut.write("\tRMS comp\tRMS comp std\tRMS cont\tRMS cont std\n")

            fout.write("Genome Id\tContig len\t% comp\t% cont")
            fout.write("\tTaxonomy\tMarker set\t# descendants")
            fout.write("\tUnmodified comp\tUnmodified cont\tTrue comp\tTrue cont")
            fout.write("\tIM comp\tIM cont")
            fout.write("\tMS comp\tMS cont")
            fout.write("\tRIM comp\tRIM cont")
            fout.write("\tRMS comp\tRMS cont\tTrue Comp\tTrue Cont\n")

            itemsProcessed = 0
            while True:
                # the last two fields repeat trueComps/trueConts
                (testGenomeId, contigLen, percentComp, percentCont, taxonomy, numDescendants,
                 unmodifiedComp, unmodifiedCont, trueComps, trueConts,
                 deltaComp, deltaCont, deltaCompSet, deltaContSet,
                 deltaCompRefined, deltaContRefined, deltaCompSetRefined, deltaContSetRefined,
                 _, _) = writerQueue.get(block=True, timeout=None)
                if testGenomeId is None:
                    break

                itemsProcessed += 1
                statusStr = "    Finished processing %d of %d (%.2f%%) test cases." % (
                    itemsProcessed,
                    totalTests,
                    float(itemsProcessed) * 100 / totalTests,
                )
                sys.stdout.write("%s\r" % statusStr)
                sys.stdout.flush()

                for markerSetId in unmodifiedComp:
                    prefix = self._rowPrefix(testGenomeId, contigLen, percentComp, percentCont, taxonomy,
                                             markerSetId, numDescendants, unmodifiedComp, unmodifiedCont)
                    deltaLists = [
                        deltaComp[markerSetId], deltaCont[markerSetId],
                        deltaCompSet[markerSetId], deltaContSet[markerSetId],
                        deltaCompRefined[markerSetId], deltaContRefined[markerSetId],
                        deltaCompSetRefined[markerSetId], deltaContSetRefined[markerSetId],
                    ]

                    # 'True cont' column holds the mean of trueConts, like 'True comp'
                    summaryOut.write(prefix + self._summaryColumns(trueComps, trueConts, deltaLists) + "\n")
                    fout.write(
                        prefix + self._detailColumns([trueComps, trueConts] + deltaLists + [trueComps, trueConts]) + "\n"
                    )

        sys.stdout.write("\n")

    def run(self, ubiquityThreshold, singleCopyThreshold, numReplicates, numThreads):
        """Evaluate every draft genome of the reference tree with numThreads worker processes.

        Draft genomes are tree leaves whose IMG 'status' is not 'Finished'.
        Results are written by a single writer process (__writerThread).
        """
        print("\n  Reading reference genome tree.")
        treeFile = os.path.join("/srv", "db", "checkm", "genome_tree", "genome_tree_full.refpkg", "genome_tree.tre")
        tree = dendropy.Tree.get_from_path(treeFile, schema="newick", as_rooted=True, preserve_underscores=True)

        print("    Number of taxa in tree: %d" % (len(tree.leaf_nodes())))

        genomesInTree = set(leaf.taxon.label.replace("IMG_", "") for leaf in tree.leaf_iter())

        # get all draft genomes for testing
        print("")
        metadata = self.img.genomeMetadata()
        print("  Total genomes: %d" % len(metadata))

        genomeIdsToTest = genomesInTree - self.img.filterGenomeIds(genomesInTree, metadata, "status", "Finished")
        print("  Number of draft genomes: %d" % len(genomeIdsToTest))

        print("")
        print("  Pre-computing genome information for calculating marker sets:")
        start = time.time()
        self.markerSetBuilder.readLineageSpecificGenesToRemove()
        end = time.time()
        print("    readLineageSpecificGenesToRemove: %.2f" % (end - start))

        # the following pre-computations are currently disabled; their timings are still reported
        start = time.time()
        # self.markerSetBuilder.cachedGeneCountTable = self.img.geneCountTable(metadata.keys())
        end = time.time()
        print("    globalGeneCountTable: %.2f" % (end - start))

        start = time.time()
        # self.markerSetBuilder.precomputeGenomeSeqLens(metadata.keys())
        end = time.time()
        print("    precomputeGenomeSeqLens: %.2f" % (end - start))

        start = time.time()
        # self.markerSetBuilder.precomputeGenomeFamilyPositions(metadata.keys(), 0)
        end = time.time()
        print("    precomputeGenomeFamilyPositions: %.2f" % (end - start))

        print("")
        print("  Evaluating %d test genomes." % len(genomeIdsToTest))
        workerQueue = mp.Queue()
        writerQueue = mp.Queue()

        for testGenomeId in genomeIdsToTest:
            workerQueue.put(testGenomeId)

        # one None sentinel per worker so that every worker leaves its loop
        for _ in range(numThreads):
            workerQueue.put(None)

        workerProc = [
            mp.Process(
                target=self.__workerThread,
                args=(tree, metadata, ubiquityThreshold, singleCopyThreshold, numReplicates, workerQueue, writerQueue),
            )
            for _ in range(numThreads)
        ]
        writeProc = mp.Process(target=self.__writerThread, args=(len(genomeIdsToTest), writerQueue))

        writeProc.start()

        for p in workerProc:
            p.start()

        for p in workerProc:
            p.join()

        # writer sentinel: same 20-field shape as a result tuple, all None
        writerQueue.put((None,) * 20)
        writeProc.join()
class IdentifyGeneLossAndDuplication(object):
    """Report marker genes missing or duplicated below each internal node of the genome tree."""

    def __init__(self):
        self.markerSetBuilder = MarkerSetBuilder()
        self.img = IMG('/srv/whitlam/bio/db/checkm/img/img_metadata.tsv', '/srv/whitlam/bio/db/checkm/pfam/tigrfam2pfam.tsv')

    @staticmethod
    def _revisedMarkerGeneIds(markerGenes):
        """Return markerGenes as a set with Pfam ids rewritten to the 'pfamNNNNN' form.

        Silly hack since PFAM ids are inconsistent between the PFAM data and IMG
        data: 'PF00001.18' becomes 'pfam00001' (the suffix after the last '.' is
        dropped only when present); other ids are kept unchanged.
        """
        revised = set()
        for mg in markerGenes:
            if mg.startswith('PF'):
                dotIndex = mg.rfind('.')
                if dotIndex != -1:
                    mg = mg[:dotIndex]
                mg = mg.replace('PF', 'pfam')
            revised.add(mg)
        return revised

    @staticmethod
    def _genomeIdsForLeaves(leafLabels, duplicateTaxa):
        """Return genome ids for the given leaf labels plus taxa recorded as their duplicates.

        Each leaf contributes its own id exactly once, followed by the ids listed
        for it in duplicateTaxa; the 'IMG_' prefix is stripped from all ids.
        """
        genomeIds = []
        for label in leafLabels:
            genomeIds.append(label.replace('IMG_', ''))
            for duplicateId in duplicateTaxa.get(label, []):
                genomeIds.append(duplicateId.replace('IMG_', ''))
        return genomeIds

    def run(self, ubiquityThreshold, minGenomes):
        """Write missing and duplicated marker genes for every internal node of the genome tree.

        For each internal node whose metadata reports at least minGenomes
        genomes, the marker genes of the node and all of its ancestors are
        collected, and MarkerSetBuilder.missingGenes / duplicateGenes are
        evaluated over the genomes below the node (including genomes removed as
        duplicates). One line per node is written to
        /srv/whitlam/bio/db/checkm/genome_tree/missing_duplicate_genes_50.tsv;
        nodes below the threshold are written with empty lists.
        """
        # Pre-compute gene count table
        print('Computing gene count table.')
        start = time.time()
        metadata = self.img.genomeMetadata()
        self.markerSetBuilder.cachedGeneCountTable = self.img.geneCountTable(metadata.keys())
        end = time.time()
        print('    globalGeneCountTable: %.2f' % (end - start))

        # read selected node for defining marker set
        # NOTE(review): selectedMarkerNode is not used below; the file is still read as before
        print('Reading node defining marker set for each internal node.')
        selectedMarkerNode = {}
        with open('/srv/whitlam/bio/db/checkm/selected_marker_sets.tsv') as f:
            for line in f:
                lineSplit = line.split('\t')
                selectedMarkerNode[lineSplit[0].strip()] = lineSplit[1].strip()

        # read duplicate taxa
        print('Reading list of identical taxa in genome tree.')
        duplicateTaxa = {}
        with open('/srv/whitlam/bio/db/checkm/genome_tree/genome_tree.derep.txt') as f:
            for line in f:
                lineSplit = line.rstrip().split()
                if len(lineSplit) > 1:
                    duplicateTaxa[lineSplit[0]] = lineSplit[1:]

        # read in node metadata
        print('Reading node metadata.')
        treeParser = TreeParser()
        uniqueIdToLineageStatistics = treeParser.readNodeMetadata()

        # read genome tree
        print('Reading in genome tree.')
        treeFile = '/srv/whitlam/bio/db/checkm/genome_tree/genome_tree_prok.refpkg/genome_tree.final.tre'
        tree = dendropy.Tree.get_from_path(treeFile, schema='newick', as_rooted=True, preserve_underscores=True)

        # determine lineage-specific gene loss and duplication (relative to potential marker genes used by a node)
        print('Determining lineage-specific gene loss and duplication')

        internalNodes = tree.internal_nodes()
        numInternalNodes = len(internalNodes)
        with open('/srv/whitlam/bio/db/checkm/genome_tree/missing_duplicate_genes_50.tsv', 'w') as fout:
            for processed, node in enumerate(internalNodes, 1):
                statusStr = '    Finished processing %d of %d (%.2f%%) internal nodes.' % (processed, numInternalNodes, float(processed)*100/numInternalNodes)
                sys.stdout.write('%s\r' % statusStr)
                sys.stdout.flush()

                nodeId = node.label.split('|')[0]

                missingGenes = []
                duplicateGenes = []

                nodeStats = uniqueIdToLineageStatistics[nodeId]
                if nodeStats['# genomes'] >= minGenomes:
                    # get marker genes defined for current node along with all parental nodes
                    markerGenes = set()
                    parentNode = node
                    while parentNode is not None:
                        parentNodeId = parentNode.label.split('|')[0]

                        stats = uniqueIdToLineageStatistics[parentNodeId]
                        # NOTE(review): eval() executes the 'marker set' text from the node
                        # metadata; this is only safe while that metadata file is trusted.
                        markerSet = MarkerSet(parentNodeId, stats['taxonomy'], stats['# genomes'], eval(stats['marker set']))
                        markerGenes = markerGenes.union(markerSet.getMarkerGenes())

                        parentNode = parentNode.parent_node

                    revisedMarkerGeneIds = self._revisedMarkerGeneIds(markerGenes)

                    # all genomes below the internal node (including genomes removed as duplicates),
                    # each leaf's own id listed exactly once
                    genomeIds = self._genomeIdsForLeaves(
                        [leaf.taxon.label for leaf in node.leaf_nodes()], duplicateTaxa
                    )

                    missingGenes = self.markerSetBuilder.missingGenes(genomeIds, revisedMarkerGeneIds, ubiquityThreshold)
                    duplicateGenes = self.markerSetBuilder.duplicateGenes(genomeIds, revisedMarkerGeneIds, ubiquityThreshold)

                fout.write('%s\t%s\t%s\n' % (nodeId, str(missingGenes), str(duplicateGenes)))

        sys.stdout.write('\n')
예제 #26
0
class CreateSamllTree(object):
    """Build a reduced genome tree by collapsing nearly identical taxa.

    Reads the dereplicated concatenated alignment and the final tree from
    outputDir, keeps one randomly chosen representative per group of nearly
    identical sequences, prunes the tree to those representatives and builds
    a reference package with `taxit create`.
    """

    def __init__(self, outputDir):
        self.__checkForFastTree()

        self.derepConcatenatedAlignFile = os.path.join(outputDir, 'genome_tree.concatenated.derep.fasta')
        self.tree = os.path.join(outputDir, 'genome_tree.final.tre')

        self.img = IMG('/srv/whitlam/bio/db/checkm/img/img_metadata.tsv', '/srv/whitlam/bio/db/checkm/pfam/tigrfam2pfam.tsv')
        self.metadata = self.img.genomeMetadata()

    def __checkForFastTree(self):
        """Check to see if FastTree is on the system path; exit with status 1 if it is not."""
        try:
            exit_status = os.system('FastTree 2> /dev/null')
        except Exception:
            print('Unexpected error! %s' % sys.exc_info()[0])
            raise

        if exit_status != 0:
            print('[Error] FastTree is not on the system path')
            sys.exit(1)

    def __nearlyIdentical(self, string1, string2, max_diff_perc=0.08):
        """Return whether two aligned sequences are nearly identical.

        Positions are compared pairwise (up to the shorter length). Returns False
        as soon as the number of mismatches reaches
        int(max_diff_perc * len(string1)); otherwise returns True, so sequences
        without mismatches are always nearly identical.
        """
        max_diff = int(max_diff_perc * len(string1))

        # izip (Python 2) / zip (Python 3) pair lazily, so an early return
        # avoids walking the remainder of long alignments
        lazyZip = getattr(itertools, 'izip', zip)

        n_diff = 0
        for c1, c2 in lazyZip(string1, string2):
            if c1 != c2:
                n_diff += 1
                if n_diff >= max_diff:
                    return False

        return True

    def __nearlyIdenticalGenomes(self, seqs, outputDir):
        """Group sequence ids whose sequences are nearly identical.

        If outputDir/nearly_identical.tsv exists, groups are read from it (one
        tab-separated group per line, blank lines ignored); otherwise they are
        computed greedily over seqs and written to that file.

        :returns: list of sets of sequence ids.
        """
        identical = []

        nearlyIdenticalFile = os.path.join(outputDir, 'nearly_identical.tsv')
        if os.path.exists(nearlyIdenticalFile):
            numTaxa = 0
            with open(nearlyIdenticalFile) as f:
                for line in f:
                    if not line.strip():
                        continue
                    s = set()
                    for genomeId in line.split('\t'):
                        numTaxa += 1
                        s.add(genomeId.strip())
                    identical.append(s)
        else:
            seqIds = list(seqs.keys())
            # every input sequence is a taxon, whichever group it ends up in
            numTaxa = len(seqIds)

            processed = set()
            for i in range(0, len(seqIds)):
                print('  %d of %d' % (i, len(seqIds)))
                seqIdI = seqIds[i]
                seqI = seqs[seqIdI]

                if seqIdI in processed:
                    continue

                processed.add(seqIdI)

                s = set()
                s.add(seqIdI)
                for j in range(i + 1, len(seqIds)):
                    seqIdJ = seqIds[j]
                    if seqIdJ in processed:
                        continue

                    if self.__nearlyIdentical(seqI, seqs[seqIdJ]):
                        s.add(seqIdJ)
                        processed.add(seqIdJ)

                identical.append(s)
                print('    set size: %d' % len(s))
                if len(s) > 1:
                    for genomeId in s:
                        genomeId = genomeId.replace('IMG_', '')
                        print('%s %s' % (genomeId, self.metadata[genomeId]['taxonomy']))

            with open(nearlyIdenticalFile, 'w') as fout:
                for s in identical:
                    fout.write('\t'.join(list(s)) + '\n')

        print('  Number of taxa: %d' % numTaxa)
        print('  Number of dereplicated taxa: %d' % len(identical))

        return identical

    def run(self, outputDir):
        """Write the reduced alignment, pruned trees and reference package into outputDir.

        Produces genome_tree.fasta (one random representative per nearly
        identical group), genome_tree.tre (pruned tree, internal labels kept),
        genome_tree.small.no_internal_labels.tre, and the
        genome_tree_reduced.refpkg package built by `taxit create`.
        """
        # make sure output directory exists
        if not os.path.exists(outputDir):
            os.makedirs(outputDir)

        # remove similar taxa
        print('Filtering out highly similar taxa in order to reduce size of tree:')
        seqs = readFasta(self.derepConcatenatedAlignFile)

        nearlyIdentical = self.__nearlyIdenticalGenomes(seqs, outputDir)

        # keep one randomly chosen representative of each group
        reducedSeqs = {}
        for s in nearlyIdentical:
            rndGenome = random.choice(tuple(s))
            reducedSeqs[rndGenome] = seqs[rndGenome]

        # write out reduced alignment
        reducedAlignmentFile = os.path.join(outputDir, "genome_tree.fasta")
        writeFasta(reducedSeqs, reducedAlignmentFile)

        # prune tree to retained taxa
        print('')
        print('Pruning tree:')
        tree = dendropy.Tree.get_from_path(self.tree, schema='newick', as_rooted=False, preserve_underscores=True)

        for seqId in reducedSeqs:
            node = tree.find_node_with_taxon_label(seqId)
            if not node:
                print('Missing taxa: %s' % seqId)

        tree.retain_taxa_with_labels(list(reducedSeqs.keys()))

        outputTree = os.path.join(outputDir, 'genome_tree.tre')
        tree.write_to_path(outputTree, schema='newick', suppress_rooting=True, unquoted_underscores=True)

        for t in tree.internal_nodes():
            t.label = None

        for t in tree.leaf_nodes():
            if t.taxon.label not in reducedSeqs:
                print('missing in sequence file: %s' % t.taxon.label)

        outputTreeWithoutLabels = os.path.join(outputDir, 'genome_tree.small.no_internal_labels.tre')
        tree.write_to_path(outputTreeWithoutLabels, schema='newick', suppress_rooting=True, unquoted_underscores=True)
        print('  Pruned tree written to: %s' % outputTree)

        # calculate model parameters for pruned tree
        print('')
        print('Determining model parameters for new tree.')
        outputTreeLog = os.path.join(outputDir, 'genome_tree.log')
        # fastTreeOutput is only used by the disabled FastTreeMP command below
        fastTreeOutput = os.path.join(outputDir, 'genome_tree.no_internal_labels.fasttree.tre')
        # os.system('FastTreeMP -nome -mllen -intree %s -log %s < %s > %s' % (outputTreeWithoutLabels, outputTreeLog, reducedAlignmentFile, fastTreeOutput))

        # calculate reference package for pruned tree
        print('')
        print('Creating reference package.')
        os.system('taxit create -l %s -P %s --aln-fasta %s --tree-stats %s --tree-file %s' % ('genome_tree_reduced', os.path.join(outputDir, 'genome_tree_reduced.refpkg'), reducedAlignmentFile, outputTreeLog, outputTree))
예제 #27
0
class Simulation(object):
    """Measure completeness/contamination estimation error on simulated draft genomes.

    For every draft genome in the reference tree, partial genomes are drawn
    with MarkerSetBuilder.sampleGenome for each combination of contig length,
    target completeness and target contamination. Marker-set based estimates
    (individual markers and collocated marker sets, for both the regular and
    the refined marker sets) are compared against the true values returned by
    the sampler.
    """

    def __init__(self):
        self.markerSetBuilder = MarkerSetBuilder()
        self.img = IMG('/srv/whitlam/bio/db/checkm/img/img_metadata.tsv', '/srv/whitlam/bio/db/checkm/pfam/tigrfam2pfam.tsv')

        # simulation grid: every test genome is evaluated under each combination
        self.contigLens = [1000, 2000, 5000, 10000, 20000, 50000]
        self.percentComps = [0.5, 0.7, 0.8, 0.9, 0.95, 1.0]
        self.percentConts = [0.0, 0.05, 0.1, 0.15, 0.2]

    def __workerThread(self, tree, metadata, ubiquityThreshold, singleCopyThreshold, numReplicates, queueIn, queueOut):
        """Process test genome ids from queueIn until a None sentinel is received.

        For each test genome, marker sets are inferred from the reference tree
        with the test genome excluded, and one 20-element result tuple per
        (contig length, % comp, % cont) condition is put on queueOut.
        """

        while True:
            testGenomeId = queueIn.get(block=True, timeout=None)
            if testGenomeId is None:
                break

            # build marker sets for evaluating test genome (test genome itself is excluded)
            testNode = tree.find_node_with_taxon_label('IMG_' + testGenomeId)
            binMarkerSets, refinedBinMarkerSet = self.markerSetBuilder.buildBinMarkerSet(tree, testNode.parent_node, ubiquityThreshold, singleCopyThreshold, bMarkerSet=True, genomeIdsToRemove=[testGenomeId])

            # determine distribution of all marker genes within the test genome
            geneDistTable = self.img.geneDistTable([testGenomeId], binMarkerSets.getMarkerGenes(), spacingBetweenContigs=0)

            print('# marker genes: ', len(binMarkerSets.getMarkerGenes()))
            print('# genes in table: ', len(geneDistTable[testGenomeId]))

            # estimate completeness of unmodified genome
            unmodifiedComp = {}
            unmodifiedCont = {}
            for ms in binMarkerSets.markerSetIter():
                hits = {}
                for mg in ms.getMarkerGenes():
                    if mg in geneDistTable[testGenomeId]:
                        hits[mg] = geneDistTable[testGenomeId][mg]
                completeness, contamination = ms.genomeCheck(hits, bIndividualMarkers=True)
                unmodifiedComp[ms.lineageStr] = completeness
                unmodifiedCont[ms.lineageStr] = contamination

            # debug output: values for the last marker set evaluated above
            print(completeness, contamination)

            # estimate completion and contamination of genome after subsampling using both the domain and lineage-specific marker sets
            genomeSize = readFastaBases(os.path.join(self.img.genomeDir, testGenomeId, testGenomeId + '.fna'))
            print('genomeSize', genomeSize)

            for contigLen in self.contigLens:
                for percentComp in self.percentComps:
                    for percentCont in self.percentConts:
                        # per marker set: estimate - truth, for each replicate
                        deltaComp = defaultdict(list)
                        deltaCont = defaultdict(list)
                        deltaCompSet = defaultdict(list)
                        deltaContSet = defaultdict(list)

                        deltaCompRefined = defaultdict(list)
                        deltaContRefined = defaultdict(list)
                        deltaCompSetRefined = defaultdict(list)
                        deltaContSetRefined = defaultdict(list)

                        trueComps = []
                        trueConts = []

                        numDescendants = {}

                        for _ in range(0, numReplicates):
                            trueComp, trueCont, startPartialGenomeContigs = self.markerSetBuilder.sampleGenome(genomeSize, percentComp, percentCont, contigLen)
                            print(contigLen, trueComp, trueCont, len(startPartialGenomeContigs))

                            trueComps.append(trueComp)
                            trueConts.append(trueCont)

                            for ms in binMarkerSets.markerSetIter():
                                numDescendants[ms.lineageStr] = ms.numGenomes

                                containedMarkerGenes = self.markerSetBuilder.containedMarkerGenes(ms.getMarkerGenes(), geneDistTable[testGenomeId], startPartialGenomeContigs, contigLen)
                                completeness, contamination = ms.genomeCheck(containedMarkerGenes, bIndividualMarkers=True)
                                deltaComp[ms.lineageStr].append(completeness - trueComp)
                                deltaCont[ms.lineageStr].append(contamination - trueCont)

                                completeness, contamination = ms.genomeCheck(containedMarkerGenes, bIndividualMarkers=False)
                                deltaCompSet[ms.lineageStr].append(completeness - trueComp)
                                deltaContSet[ms.lineageStr].append(contamination - trueCont)

                            for ms in refinedBinMarkerSet.markerSetIter():
                                containedMarkerGenes = self.markerSetBuilder.containedMarkerGenes(ms.getMarkerGenes(), geneDistTable[testGenomeId], startPartialGenomeContigs, contigLen)
                                completeness, contamination = ms.genomeCheck(containedMarkerGenes, bIndividualMarkers=True)
                                deltaCompRefined[ms.lineageStr].append(completeness - trueComp)
                                deltaContRefined[ms.lineageStr].append(contamination - trueCont)

                                completeness, contamination = ms.genomeCheck(containedMarkerGenes, bIndividualMarkers=False)
                                deltaCompSetRefined[ms.lineageStr].append(completeness - trueComp)
                                deltaContSetRefined[ms.lineageStr].append(contamination - trueCont)

                        taxonomy = ';'.join(metadata[testGenomeId]['taxonomy'])
                        # NOTE: trueComps/trueConts are sent twice; the writer emits both copies as columns
                        queueOut.put((testGenomeId, contigLen, percentComp, percentCont, taxonomy, numDescendants, unmodifiedComp, unmodifiedCont, trueComps, trueConts, deltaComp, deltaCont, deltaCompSet, deltaContSet, deltaCompRefined, deltaContRefined, deltaCompSetRefined, deltaContSetRefined, trueComps, trueConts))

    def __writerThread(self, numTestGenomes, writerQueue):
        """Serialize worker results to a summary TSV and a gzipped per-replicate TSV.

        Runs in a single process so output is not interleaved. Stops on a
        result tuple whose genome id is None. Both files are closed even if a
        malformed item raises.
        """

        with open('/tmp/simulation.summary.testing.tsv', 'w') as summaryOut, gzip.open('/tmp/simulation.testing.tsv.gz', 'wb') as fout:
            summaryOut.write('Genome Id\tContig len\t% comp\t% cont')
            summaryOut.write('\tTaxonomy\tMarker set\t# descendants')
            summaryOut.write('\tUnmodified comp\tUnmodified cont\tTrue comp\tTrue cont')
            summaryOut.write('\tIM comp\tIM comp std\tIM cont\tIM cont std')
            summaryOut.write('\tMS comp\tMS comp std\tMS cont\tMS cont std')
            summaryOut.write('\tRIM comp\tRIM comp std\tRIM cont\tRIM cont std')
            summaryOut.write('\tRMS comp\tRMS comp std\tRMS cont\tRMS cont std\n')

            fout.write('Genome Id\tContig len\t% comp\t% cont')
            fout.write('\tTaxonomy\tMarker set\t# descendants')
            fout.write('\tUnmodified comp\tUnmodified cont\tTrue comp\tTrue cont')
            fout.write('\tIM comp\tIM cont')
            fout.write('\tMS comp\tMS cont')
            fout.write('\tRIM comp\tRIM cont')
            fout.write('\tRMS comp\tRMS cont\tTrue Comp\tTrue Cont\n')

            testsPerGenome = len(self.contigLens) * len(self.percentComps) * len(self.percentConts)
            totalTests = numTestGenomes * testsPerGenome

            itemsProcessed = 0
            while True:
                testGenomeId, contigLen, percentComp, percentCont, taxonomy, numDescendants, unmodifiedComp, unmodifiedCont, trueComps, trueConts, deltaComp, deltaCont, deltaCompSet, deltaContSet, deltaCompRefined, deltaContRefined, deltaCompSetRefined, deltaContSetRefined, trueComps, trueConts = writerQueue.get(block=True, timeout=None)
                if testGenomeId is None:
                    break

                itemsProcessed += 1
                statusStr = '    Finished processing %d of %d (%.2f%%) test cases.' % (itemsProcessed, totalTests, float(itemsProcessed) * 100 / totalTests)
                sys.stdout.write('%s\r' % statusStr)
                sys.stdout.flush()

                # column order: IM comp/cont, MS comp/cont, RIM comp/cont, RMS comp/cont
                deltaTables = (deltaComp, deltaCont, deltaCompSet, deltaContSet,
                               deltaCompRefined, deltaContRefined, deltaCompSetRefined, deltaContSetRefined)

                for markerSetId in unmodifiedComp:
                    summaryOut.write(testGenomeId + '\t%d\t%.2f\t%.2f' % (contigLen, percentComp, percentCont))
                    summaryOut.write('\t' + taxonomy + '\t' + markerSetId + '\t' + str(numDescendants[markerSetId]))
                    summaryOut.write('\t%.3f\t%.3f' % (unmodifiedComp[markerSetId], unmodifiedCont[markerSetId]))
                    # 'True comp' / 'True cont' columns: mean over replicates (previously std was written for cont)
                    summaryOut.write('\t%.3f\t%.3f' % (mean(trueComps), mean(trueConts)))
                    for deltas in deltaTables:
                        # abs() is applied to a list here, which relies on an element-wise abs
                        # (presumably numpy's) being in scope from the file imports
                        absDeltas = abs(deltas[markerSetId])
                        summaryOut.write('\t%.3f\t%.3f' % (mean(absDeltas), std(absDeltas)))
                    summaryOut.write('\n')

                    fout.write(testGenomeId + '\t%d\t%.2f\t%.2f' % (contigLen, percentComp, percentCont))
                    fout.write('\t' + taxonomy + '\t' + markerSetId + '\t' + str(numDescendants[markerSetId]))
                    fout.write('\t%.3f\t%.3f' % (unmodifiedComp[markerSetId], unmodifiedCont[markerSetId]))
                    rawColumns = [trueComps, trueConts] + [deltas[markerSetId] for deltas in deltaTables] + [trueComps, trueConts]
                    for values in rawColumns:
                        fout.write('\t%s' % ','.join(map(str, values)))
                    fout.write('\n')

        sys.stdout.write('\n')

    def run(self, ubiquityThreshold, singleCopyThreshold, numReplicates, numThreads):
        """Evaluate every draft genome in the reference tree using worker processes.

        :param ubiquityThreshold: passed to MarkerSetBuilder.buildBinMarkerSet
        :param singleCopyThreshold: passed to MarkerSetBuilder.buildBinMarkerSet
        :param numReplicates: number of sampled partial genomes per condition
        :param numThreads: number of worker processes
        """
        print('\n  Reading reference genome tree.')
        treeFile = os.path.join('/srv', 'db', 'checkm', 'genome_tree', 'genome_tree_full.refpkg', 'genome_tree.tre')
        tree = dendropy.Tree.get_from_path(treeFile, schema='newick', as_rooted=True, preserve_underscores=True)

        print('    Number of taxa in tree: %d' % (len(tree.leaf_nodes())))

        genomesInTree = set()
        for leaf in tree.leaf_iter():
            genomesInTree.add(leaf.taxon.label.replace('IMG_', ''))

        # get all draft genomes for testing
        print('')
        metadata = self.img.genomeMetadata()
        print('  Total genomes: %d' % len(metadata))

        genomeIdsToTest = genomesInTree - self.img.filterGenomeIds(genomesInTree, metadata, 'status', 'Finished')
        print('  Number of draft genomes: %d' % len(genomeIdsToTest))

        print('')
        print('  Pre-computing genome information for calculating marker sets:')
        start = time.time()
        self.markerSetBuilder.readLineageSpecificGenesToRemove()
        end = time.time()
        print('    readLineageSpecificGenesToRemove: %.2f' % (end - start))

        # NOTE: the following pre-computation steps are disabled, so their timers measure nothing
        start = time.time()
        # self.markerSetBuilder.cachedGeneCountTable = self.img.geneCountTable(metadata.keys())
        end = time.time()
        print('    globalGeneCountTable: %.2f' % (end - start))

        start = time.time()
        # self.markerSetBuilder.precomputeGenomeSeqLens(metadata.keys())
        end = time.time()
        print('    precomputeGenomeSeqLens: %.2f' % (end - start))

        start = time.time()
        # self.markerSetBuilder.precomputeGenomeFamilyPositions(metadata.keys(), 0)
        end = time.time()
        print('    precomputeGenomeFamilyPositions: %.2f' % (end - start))

        print('')
        print('  Evaluating %d test genomes.' % len(genomeIdsToTest))
        workerQueue = mp.Queue()
        writerQueue = mp.Queue()

        for testGenomeId in genomeIdsToTest:
            workerQueue.put(testGenomeId)

        # one stop sentinel per worker
        for _ in range(numThreads):
            workerQueue.put(None)

        workerProc = [mp.Process(target=self.__workerThread, args=(tree, metadata, ubiquityThreshold, singleCopyThreshold, numReplicates, workerQueue, writerQueue)) for _ in range(numThreads)]
        writeProc = mp.Process(target=self.__writerThread, args=(len(genomeIdsToTest), writerQueue))

        writeProc.start()

        for p in workerProc:
            p.start()

        for p in workerProc:
            p.join()

        # stop sentinel for the writer: same arity (20) as a worker result tuple
        writerQueue.put((None,) * 20)
        writeProc.join()
class SimulationScaffolds(object):
    def __init__(self):
        self.markerSetBuilder = MarkerSetBuilder()
        self.img = IMG('/srv/whitlam/bio/db/checkm/img/img_metadata.tsv', '/srv/whitlam/bio/db/checkm/pfam/tigrfam2pfam.tsv')
        
        self.contigLens = [5000, 20000, 50000]
        self.percentComps = [0.5, 0.7, 0.8, 0.9, 0.95, 1.0]
        self.percentConts = [0.0, 0.05, 0.1, 0.15, 0.2]

    def __seqLens(self, seqs):
        """Calculate lengths of seqs."""
        genomeSize = 0
        seqLens = {}
        for seqId, seq in seqs.iteritems():
            seqLens[seqId] = len(seq)
            genomeSize += len(seq)
    
        return seqLens, genomeSize
    
    def __workerThread(self, tree, metadata, genomeIdsToTest, ubiquityThreshold, singleCopyThreshold, numReplicates, queueIn, queueOut):
        """Process each data item in parallel."""

        while True:
            testGenomeId = queueIn.get(block=True, timeout=None)
            if testGenomeId == None:
                break
                        
            # build marker sets for evaluating test genome
            testNode = tree.find_node_with_taxon_label('IMG_' + testGenomeId)
            binMarkerSets, refinedBinMarkerSet = self.markerSetBuilder.buildBinMarkerSet(tree, testNode.parent_node, ubiquityThreshold, singleCopyThreshold, bMarkerSet = True, genomeIdsToRemove = [testGenomeId])

            # determine distribution of all marker genes within the test genome
            geneDistTable = self.img.geneDistTable([testGenomeId], binMarkerSets.getMarkerGenes(), spacingBetweenContigs=0)
                
            # estimate completeness of unmodified genome
            unmodifiedComp = {}
            unmodifiedCont = {}
            for ms in binMarkerSets.markerSetIter():     
                hits = {}
                for mg in ms.getMarkerGenes():
                    if mg in geneDistTable[testGenomeId]:
                        hits[mg] = geneDistTable[testGenomeId][mg]
                completeness, contamination = ms.genomeCheck(hits, bIndividualMarkers=True) 
                unmodifiedComp[ms.lineageStr] = completeness
                unmodifiedCont[ms.lineageStr] = contamination

            # estimate completion and contamination of genome after subsampling using both the domain and lineage-specific marker sets 
            testSeqs = readFasta(os.path.join(self.img.genomeDir, testGenomeId, testGenomeId + '.fna'))
            testSeqLens, genomeSize = self.__seqLens(testSeqs)
            
            
            for contigLen in self.contigLens: 
                for percentComp in self.percentComps:
                    for percentCont in self.percentConts:
                        deltaComp = defaultdict(list)
                        deltaCont = defaultdict(list)
                        deltaCompSet = defaultdict(list)
                        deltaContSet = defaultdict(list)
                        
                        deltaCompRefined = defaultdict(list)
                        deltaContRefined = defaultdict(list)
                        deltaCompSetRefined = defaultdict(list)
                        deltaContSetRefined = defaultdict(list)
                        
                        trueComps = []
                        trueConts = []
                        
                        numDescendants = {}
            
                        for i in xrange(0, numReplicates):
                            # generate test genome with a specific level of completeness, by randomly sampling scaffolds to remove 
                            # (this will sample >= the desired level of completeness)
                            retainedTestSeqs, trueComp = self.markerSetBuilder.sampleGenomeScaffoldsWithoutReplacement(percentComp, testSeqLens, genomeSize)
                            trueComps.append(trueComp)
    
                            # select a random genome to use as a source of contamination
                            contGenomeId = random.sample(genomeIdsToTest - set([testGenomeId]), 1)[0]
                            contSeqs = readFasta(os.path.join(self.img.genomeDir, contGenomeId, contGenomeId + '.fna'))
                            contSeqLens, contGenomeSize = self.__seqLens(contSeqs) 
                            seqsToRetain, trueRetainedPer = self.markerSetBuilder.sampleGenomeScaffoldsWithoutReplacement(1 - percentCont, contSeqLens, contGenomeSize) 
                            
                            contSampledSeqIds = set(contSeqs.keys()).difference(seqsToRetain)
                            trueCont = 100.0 - trueRetainedPer
                            trueConts.append(trueCont)
              
                            for ms in binMarkerSets.markerSetIter():  
                                numDescendants[ms.lineageStr] = ms.numGenomes
                                containedMarkerGenes= defaultdict(list)
                                self.markerSetBuilder.markerGenesOnScaffolds(ms.getMarkerGenes(), testGenomeId, retainedTestSeqs, containedMarkerGenes)
                                self.markerSetBuilder.markerGenesOnScaffolds(ms.getMarkerGenes(), contGenomeId, contSampledSeqIds, containedMarkerGenes)

                                completeness, contamination = ms.genomeCheck(containedMarkerGenes, bIndividualMarkers=True)
                                deltaComp[ms.lineageStr].append(completeness - trueComp)
                                deltaCont[ms.lineageStr].append(contamination - trueCont)
                                
                                completeness, contamination = ms.genomeCheck(containedMarkerGenes, bIndividualMarkers=False)
                                deltaCompSet[ms.lineageStr].append(completeness - trueComp)
                                deltaContSet[ms.lineageStr].append(contamination - trueCont)
                                
                            for ms in refinedBinMarkerSet.markerSetIter():  
                                containedMarkerGenes= defaultdict(list)
                                self.markerSetBuilder.markerGenesOnScaffolds(ms.getMarkerGenes(), testGenomeId, retainedTestSeqs, containedMarkerGenes)
                                self.markerSetBuilder.markerGenesOnScaffolds(ms.getMarkerGenes(), contGenomeId, contSampledSeqIds, containedMarkerGenes)
                                
                                completeness, contamination = ms.genomeCheck(containedMarkerGenes, bIndividualMarkers=True)
                                deltaCompRefined[ms.lineageStr].append(completeness - trueComp)
                                deltaContRefined[ms.lineageStr].append(contamination - trueCont)
                                
                                completeness, contamination = ms.genomeCheck(containedMarkerGenes, bIndividualMarkers=False)
                                deltaCompSetRefined[ms.lineageStr].append(completeness - trueComp)
                                deltaContSetRefined[ms.lineageStr].append(contamination - trueCont)
                                
                        taxonomy = ';'.join(metadata[testGenomeId]['taxonomy'])
                        queueOut.put((testGenomeId, contigLen, percentComp, percentCont, taxonomy, numDescendants, unmodifiedComp, unmodifiedCont, deltaComp, deltaCont, deltaCompSet, deltaContSet, deltaCompRefined, deltaContRefined, deltaCompSetRefined, deltaContSetRefined, trueComps, trueConts))
            
    def __writerThread(self, numTestGenomes, writerQueue):
        """Store or write results of worker threads in a single thread."""
        
        summaryOut = open('/tmp/simulation.random_scaffolds.w_refinement_50.draft.summary.tsv', 'w')
        summaryOut.write('Genome Id\tContig len\t% comp\t% cont')
        summaryOut.write('\tTaxonomy\tMarker set\t# descendants')
        summaryOut.write('\tUnmodified comp\tUnmodified cont')
        summaryOut.write('\tIM comp\tIM comp std\tIM cont\tIM cont std')
        summaryOut.write('\tMS comp\tMS comp std\tMS cont\tMS cont std')
        summaryOut.write('\tRIM comp\tRIM comp std\tRIM cont\tRIM cont std')
        summaryOut.write('\tRMS comp\tRMS comp std\tRMS cont\tRMS cont std\n')
        
        fout = gzip.open('/tmp/simulation.random_scaffolds.w_refinement_50.draft.tsv.gz', 'wb')
        fout.write('Genome Id\tContig len\t% comp\t% cont')
        fout.write('\tTaxonomy\tMarker set\t# descendants')
        fout.write('\tUnmodified comp\tUnmodified cont')
        fout.write('\tIM comp\tIM cont')
        fout.write('\tMS comp\tMS cont')
        fout.write('\tRIM comp\tRIM cont')
        fout.write('\tRMS comp\tRMS cont\tTrue Comp\tTrue Cont\n')
        
        testsPerGenome = len(self.contigLens) * len(self.percentComps) * len(self.percentConts)

        itemsProcessed = 0
        while True:
            testGenomeId, contigLen, percentComp, percentCont, taxonomy, numDescendants, unmodifiedComp, unmodifiedCont, deltaComp, deltaCont, deltaCompSet, deltaContSet, deltaCompRefined, deltaContRefined, deltaCompSetRefined, deltaContSetRefined, trueComps, trueConts = writerQueue.get(block=True, timeout=None)
            if testGenomeId == None:
                break

            itemsProcessed += 1
            statusStr = '    Finished processing %d of %d (%.2f%%) test cases.' % (itemsProcessed, numTestGenomes*testsPerGenome, float(itemsProcessed)*100/(numTestGenomes*testsPerGenome))
            sys.stdout.write('%s\r' % statusStr)
            sys.stdout.flush()
            
            for markerSetId in unmodifiedComp:
                summaryOut.write(testGenomeId + '\t%d\t%.2f\t%.2f' % (contigLen, percentComp, percentCont)) 
                summaryOut.write('\t' + taxonomy + '\t' + markerSetId + '\t' + str(numDescendants[markerSetId]))
                summaryOut.write('\t%.3f\t%.3f' % (unmodifiedComp[markerSetId], unmodifiedCont[markerSetId]))
                summaryOut.write('\t%.3f\t%.3f' % (mean(abs(deltaComp[markerSetId])), std(abs(deltaComp[markerSetId]))))
                summaryOut.write('\t%.3f\t%.3f' % (mean(abs(deltaCont[markerSetId])), std(abs(deltaCont[markerSetId]))))
                summaryOut.write('\t%.3f\t%.3f' % (mean(abs(deltaCompSet[markerSetId])), std(abs(deltaCompSet[markerSetId]))))
                summaryOut.write('\t%.3f\t%.3f' % (mean(abs(deltaContSet[markerSetId])), std(abs(deltaContSet[markerSetId]))))
                summaryOut.write('\t%.3f\t%.3f' % (mean(abs(deltaCompRefined[markerSetId])), std(abs(deltaCompRefined[markerSetId]))))
                summaryOut.write('\t%.3f\t%.3f' % (mean(abs(deltaContRefined[markerSetId])), std(abs(deltaContRefined[markerSetId]))))
                summaryOut.write('\t%.3f\t%.3f' % (mean(abs(deltaCompSetRefined[markerSetId])), std(abs(deltaCompSetRefined[markerSetId]))))
                summaryOut.write('\t%.3f\t%.3f' % (mean(abs(deltaContSetRefined[markerSetId])), std(abs(deltaContSetRefined[markerSetId]))))
                summaryOut.write('\n')
                
                fout.write(testGenomeId + '\t%d\t%.2f\t%.2f' % (contigLen, percentComp, percentCont)) 
                fout.write('\t' + taxonomy + '\t' + markerSetId + '\t' + str(numDescendants[markerSetId]))
                fout.write('\t%.3f\t%.3f' % (unmodifiedComp[markerSetId], unmodifiedCont[markerSetId]))
                fout.write('\t%s' % ','.join(map(str, deltaComp[markerSetId])))
                fout.write('\t%s' % ','.join(map(str, deltaCont[markerSetId])))
                fout.write('\t%s' % ','.join(map(str, deltaCompSet[markerSetId])))
                fout.write('\t%s' % ','.join(map(str, deltaContSet[markerSetId])))
                fout.write('\t%s' % ','.join(map(str, deltaCompRefined[markerSetId])))
                fout.write('\t%s' % ','.join(map(str, deltaContRefined[markerSetId])))
                fout.write('\t%s' % ','.join(map(str, deltaCompSetRefined[markerSetId])))
                fout.write('\t%s' % ','.join(map(str, deltaContSetRefined[markerSetId])))
                fout.write('\t%s' % ','.join(map(str, trueComps)))
                fout.write('\t%s' % ','.join(map(str, trueConts)))
                fout.write('\n')
            
        summaryOut.close()
        fout.close()

        sys.stdout.write('\n')

    def run(self, ubiquityThreshold, singleCopyThreshold, numReplicates, minScaffolds, numThreads):
        random.seed(0)

        print '\n  Reading reference genome tree.'
        treeFile = os.path.join('/srv', 'db', 'checkm', 'genome_tree', 'genome_tree_prok.refpkg', 'genome_tree.final.tre')
        tree = dendropy.Tree.get_from_path(treeFile, schema='newick', as_rooted=True, preserve_underscores=True)
        
        print '    Number of taxa in tree: %d' % (len(tree.leaf_nodes()))
        
        genomesInTree = set()
        for leaf in tree.leaf_iter():
            genomesInTree.add(leaf.taxon.label.replace('IMG_', ''))

        # get all draft genomes consisting of a user-specific minimum number of scaffolds
        print ''
        metadata = self.img.genomeMetadata()
        print '  Total genomes: %d' % len(metadata)
        
        draftGenomeIds = genomesInTree - self.img.filterGenomeIds(genomesInTree, metadata, 'status', 'Finished')
        print '  Number of draft genomes: %d' % len(draftGenomeIds)
        
        genomeIdsToTest = set()
        for genomeId in draftGenomeIds:
            if metadata[genomeId]['scaffold count'] >= minScaffolds:
                genomeIdsToTest.add(genomeId)
                
        
        print '  Number of draft genomes with >= %d scaffolds: %d' % (minScaffolds, len(genomeIdsToTest))

        print ''
        start = time.time()
        self.markerSetBuilder.readLineageSpecificGenesToRemove()
        end = time.time()
        print '    readLineageSpecificGenesToRemove: %.2f' % (end - start)
        
        print '  Pre-computing genome information for calculating marker sets:'
        start = time.time()
        self.markerSetBuilder.precomputeGenomeFamilyScaffolds(metadata.keys())
        end = time.time()
        print '    precomputeGenomeFamilyScaffolds: %.2f' % (end - start)
        
        start = time.time()
        self.markerSetBuilder.cachedGeneCountTable = self.img.geneCountTable(metadata.keys())
        end = time.time()
        print '    globalGeneCountTable: %.2f' % (end - start)
        
        start = time.time()
        self.markerSetBuilder.precomputeGenomeSeqLens(metadata.keys())
        end = time.time()
        print '    precomputeGenomeSeqLens: %.2f' % (end - start)
        
        start = time.time()
        self.markerSetBuilder.precomputeGenomeFamilyPositions(metadata.keys(), 0)
        end = time.time()
        print '    precomputeGenomeFamilyPositions: %.2f' % (end - start)
                     
        print ''    
        print '  Evaluating %d test genomes.' % len(genomeIdsToTest)
            
        workerQueue = mp.Queue()
        writerQueue = mp.Queue()

        for testGenomeId in list(genomeIdsToTest):
            workerQueue.put(testGenomeId)

        for _ in range(numThreads):
            workerQueue.put(None)

        workerProc = [mp.Process(target = self.__workerThread, args = (tree, metadata, genomeIdsToTest, ubiquityThreshold, singleCopyThreshold, numReplicates, workerQueue, writerQueue)) for _ in range(numThreads)]
        writeProc = mp.Process(target = self.__writerThread, args = (len(genomeIdsToTest), writerQueue))

        writeProc.start()

        for p in workerProc:
            p.start()

        for p in workerProc:
            p.join()

        writerQueue.put((None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None))
        writeProc.join()
예제 #29
0
 def __init__(self):
     """Load the IMG genome metadata table and keep it on the instance as self.metadata."""
     imgMetadataFile = '/srv/whitlam/bio/db/checkm/img/img_metadata.tsv'
     tigrfamToPfamFile = '/srv/whitlam/bio/db/checkm/pfam/tigrfam2pfam.tsv'
     self.metadata = IMG(imgMetadataFile, tigrfamToPfamFile).genomeMetadata()
예제 #30
0
    def run(self, geneTreeDir, acceptPer, extension, outputDir):
        # make sure output directory is empty
        if not os.path.exists(outputDir):
            os.makedirs(outputDir)

        files = os.listdir(outputDir)
        for f in files:
            os.remove(os.path.join(outputDir, f))

        img = IMG('/srv/whitlam/bio/db/checkm/img/img_metadata.tsv', '/srv/whitlam/bio/db/checkm/pfam/tigrfam2pfam.tsv')
        metadata = img.genomeMetadata()

        files = os.listdir(geneTreeDir)
        print 'Identifying gene trees with only conspecific paralogous genes:'
        filteredGeneTrees = 0
        retainedGeneTrees = 0
        for f in files:
            if not f.endswith(extension):
                continue

            geneId = f[0:f.find('.')]
            print '  Testing gene tree: ' + geneId

            tree = dendropy.Tree.get_from_path(os.path.join(geneTreeDir, f), schema='newick', as_rooted=False, preserve_underscores=True)

            taxa = tree.leaf_nodes()
            numTaxa = len(taxa)
            print '  Genes in tree: ' + str(numTaxa)

            # root tree with archaeal genomes
            rerootTree = RerootTree()
            rerootTree.reroot(tree)
            
            # get species name of each taxa
            leafNodeToSpeciesName = {}
            for t in taxa:
                genomeId = t.taxon.label.split('|')[0]
                genus = metadata[genomeId]['taxonomy'][5]
                sp = metadata[genomeId]['taxonomy'][6].lower()

                leafNodeToSpeciesName[t.taxon.label] = genus + ' ' + sp
                
            # find all paralogous genes
            print '  Finding paralogous genes.'

            paralogs = defaultdict(set)
            for i in xrange(0, len(taxa)):
                genomeId = taxa[i].taxon.label.split('|')[0]
                for j in xrange(i+1, len(taxa)):
                    # genes from the same genome are paralogs, but we filter out
                    # those that are identical (distance of 0 on the tree) to
                    # speed up computation and because these clearly do not
                    # adversely effect phylogenetic inference
                    if genomeId == taxa[j].taxon.label.split('|')[0] and self.__patristicDist(tree, taxa[i], taxa[j]) > 0:
                        paralogs[genomeId].add(taxa[i].taxon.label)
                        paralogs[genomeId].add(taxa[j].taxon.label)
                        
            print '    Paralogous genes: ' + str(len(paralogs))

            # check if paralogous genes are conspecific
            print '  Determining if paralogous genes are conspecific.'
            nonConspecificGenomes = []
            for genomeId, taxaLabels in paralogs.iteritems():
                lcaNode = tree.mrca(taxon_labels = taxaLabels)

                children = lcaNode.leaf_nodes()
                species = set()
                for child in children:
                    childGenomeId = child.taxon.label.split('|')[0]

                    genus = metadata[childGenomeId]['taxonomy'][5]
                    sp = metadata[childGenomeId]['taxonomy'][6].lower()
                    if sp != '' and sp != 'unclassified' and genus != 'unclassified':
                        species.add(genus + ' ' + sp)

                if len(species) > 1:
                    nonConspecificGenomes.append(genomeId)

            if len(nonConspecificGenomes) > acceptPer*numTaxa:
                filteredGeneTrees += 1
                print '  Tree is not conspecific for the following genome: ' + str(nonConspecificGenomes)
            else:
                retainedGeneTrees += 1

                if len(nonConspecificGenomes) > 1:
                    print '  An acceptable number of genomes are not conspecific: ' + str(nonConspecificGenomes)
                else:
                    print '  Tree is conspecific.'

                os.system('cp ' + os.path.join(geneTreeDir, f) + ' ' + os.path.join(outputDir, f))

            print ''

        print 'Filtered gene trees: ' + str(filteredGeneTrees)
        print 'Retained gene trees: ' + str(retainedGeneTrees)
예제 #31
0
class SimCompareDiffPlot(object):
    """Compare lineage-specific and domain-specific marker set simulation results.

    Reads per-genome values from ./simulations/briefSummaryOut.tsv, counts how often
    the lineage-specific or the domain-specific marker set gives the value with the
    smaller absolute magnitude (when the two differ by more than 5), and plots the
    distribution of those differences as a stacked bar chart.
    """

    def __init__(self):
        self.img = IMG('/srv/whitlam/bio/db/checkm/img/img_metadata.tsv', '/srv/whitlam/bio/db/checkm/pfam/tigrfam2pfam.tsv')

    @staticmethod
    def _percent(count, total):
        """Return count as a percentage of total, or 0.0 when total is 0."""
        if total == 0:
            return 0.0
        return float(count) * 100 / total

    @staticmethod
    def _tallyDifference(lineageValue, domainValue, lineageCounts, domainCounts, minDiff=5):
        """Record which marker set wins for one genome, binned by the size of the difference.

        The difference is abs(abs(lineageValue) - abs(domainValue)). Differences not
        greater than minDiff are ignored and None is returned. Otherwise the bin
        min(int(diff), len(lineageCounts) - 1) is incremented in domainCounts (and True
        returned) when the domain value has the smaller magnitude, or in lineageCounts
        (and False returned) otherwise.
        """
        diff = abs(abs(lineageValue) - abs(domainValue))
        if not diff > minDiff:
            return None

        # differences beyond the last bar are accumulated in the last bar
        binIndex = min(int(diff), len(lineageCounts) - 1)
        if abs(domainValue) < abs(lineageValue):
            domainCounts[binIndex] += 1
            return True

        lineageCounts[binIndex] += 1
        return False

    def run(self):
        """Tally and plot how often lineage-specific marker sets beat domain-specific ones.

        Each non-blank line of ./simulations/briefSummaryOut.tsv must contain a genome id
        followed by six numeric columns: domain, lineage and refined-lineage values for
        completeness, then the same three for contamination (the refined-lineage columns
        are parsed but unused). The chart is saved to ./experiments/simCompareDiffPlot.svg.
        """
        numBars = 15

        lineageCountsComp = [0] * numBars
        domainCountsComp = [0] * numBars
        lineageCountsCont = [0] * numBars
        domainCountsCont = [0] * numBars

        totalCountsComp = 0
        totalCountsCont = 0

        domCompBest = 0
        lineageCompBest = 0
        domContBest = 0
        lineageContBest = 0

        metadata = self.img.genomeMetadata()
        domCompTaxon = defaultdict(int)
        lineageCompTaxon = defaultdict(int)

        with open('./simulations/briefSummaryOut.tsv') as summaryFile:
            for line in summaryFile:
                if not line.strip():
                    continue  # tolerate blank (e.g. trailing) lines

                lineSplit = line.split('\t')
                genomeId = lineSplit[0]
                phylum = metadata[genomeId]['taxonomy'][1]
                domCompMS, lineageCompMS, _, domContMS, lineageContMS, _ = [float(x) for x in lineSplit[1:]]

                domainBest = self._tallyDifference(lineageCompMS, domCompMS, lineageCountsComp, domainCountsComp)
                if domainBest is not None:
                    totalCountsComp += 1
                    if domainBest:
                        domCompBest += 1
                        domCompTaxon[phylum] += 1
                    else:
                        lineageCompBest += 1
                        lineageCompTaxon[phylum] += 1

                domainBest = self._tallyDifference(lineageContMS, domContMS, lineageCountsCont, domainCountsCont)
                if domainBest is not None:
                    totalCountsCont += 1
                    if domainBest:
                        domContBest += 1
                    else:
                        lineageContBest += 1

        # _percent guards against an empty summary (no differences above the threshold)
        print('%% times lineage comp better than domain: %.2f' % self._percent(lineageCompBest, domCompBest + lineageCompBest))
        print('%% times lineage cont better than domain: %.2f' % self._percent(lineageContBest, domContBest + lineageContBest))

        print('')
        print('Taxonomy breakdown (dom best, lineage best):')
        taxa = set(domCompTaxon.keys()).union(lineageCompTaxon.keys())
        for t in taxa:
            print('%s\t%.2f\t%.2f' % (t, self._percent(domCompTaxon[t], domCompBest), self._percent(lineageCompTaxon[t], lineageCompBest)))

        # normalize counts to percentages of all recorded comparisons
        for i in range(numBars):
            lineageCountsComp[i] = self._percent(lineageCountsComp[i], totalCountsComp)
            domainCountsComp[i] = self._percent(domainCountsComp[i], totalCountsComp)

            if domainCountsComp[i] > lineageCountsComp[i]:
                print('Domain beats lineage (comp): %d%% (%f, %f)' % (i + 1, domainCountsComp[i], lineageCountsComp[i]))

            lineageCountsCont[i] = self._percent(lineageCountsCont[i], totalCountsCont)
            domainCountsCont[i] = self._percent(domainCountsCont[i], totalCountsCont)

            if domainCountsCont[i] > lineageCountsCont[i]:
                print('Domain beats lineage (cont): %d%% (%f, %f)' % (i + 1, domainCountsCont[i], lineageCountsCont[i]))

        stackedBarPlot = StackedBarPlot()
        stackedBarPlot.plot(lineageCountsComp, domainCountsComp, lineageCountsCont, domainCountsCont)
        stackedBarPlot.savePlot('./experiments/simCompareDiffPlot.svg')
예제 #32
0
class MarkerSetBuilder(object):
    """Infer, refine and simulate marker gene sets from IMG reference genomes.

    On construction this reads IMG metadata, the genome tree duplicate-sequence list and
    the genome tree internal-node metadata from fixed paths under
    /srv/whitlam/bio/db/checkm.
    """

    def __init__(self):
        self.img = IMG('/srv/whitlam/bio/db/checkm/img/img_metadata.tsv', '/srv/whitlam/bio/db/checkm/pfam/tigrfam2pfam.tsv')
        self.colocatedFile = './data/colocated.tsv'
        self.duplicateSeqs = self.readDuplicateSeqs()
        self.uniqueIdToLineageStatistics = self.__readNodeMetadata()

        # if a caller assigns a gene count table here, missingGenes, duplicateGenes and
        # buildMarkerGenes use it instead of querying IMG
        self.cachedGeneCountTable = None

    def precomputeGenomeSeqLens(self, genomeIds):
        """Cache the length of contigs/scaffolds for all genomes."""

        # This function is intended to speed up functions, such as img.geneDistTable(),
        # that are called multiple times (typically during simulations)
        self.img.precomputeGenomeSeqLens(genomeIds)

    def precomputeGenomeFamilyPositions(self, genomeIds, spacingBetweenContigs):
        """Cache position of PFAM and TIGRFAM genes in genomes."""

        # This function is intended to speed up functions, such as img.geneDistTable(),
        # that are called multiple times (typically during simulations)
        self.img.precomputeGenomeFamilyPositions(genomeIds, spacingBetweenContigs)

    def precomputeGenomeFamilyScaffolds(self, genomeIds):
        """Cache scaffolds of PFAM and TIGRFAM genes in genomes."""

        # This function is intended to speed up functions, such as img.geneDistTable(),
        # that are called multiple times (typically during simulations)
        self.img.precomputeGenomeFamilyScaffolds(genomeIds)

    def _readColocatedMarkers(self, minGenomes, minMarkerSets, bFilterLineage=False, lineage=None):
        """Collect PFAM and TIGRFAM ids from qualifying rows of self.colocatedFile.

        The first line is a header. Each row is tab-separated: lineage, number of genomes,
        (unused), number of marker sets, then one comma-separated marker set per column.
        Rows with fewer than minGenomes genomes or minMarkerSets marker sets are skipped,
        as are rows whose lineage differs from lineage when bFilterLineage is True.

        Returns (pfamIds, tigrIds) as sets.
        """
        pfamIds = set()
        tigrIds = set()

        with open(self.colocatedFile) as f:
            f.readline()  # skip header
            for line in f:
                lineSplit = line.split('\t')
                curLineage = lineSplit[0]
                numGenomes = int(lineSplit[1])
                numMarkerSets = int(lineSplit[3])
                markerSets = lineSplit[4:]

                if bFilterLineage and curLineage != lineage:
                    continue
                if numGenomes < minGenomes or numMarkerSets < minMarkerSets:
                    continue

                for ms in markerSets:
                    for m in ms.split(','):
                        if 'pfam' in m:
                            pfamIds.add(m.strip())
                        elif 'TIGR' in m:
                            tigrIds.add(m.strip())

        return pfamIds, tigrIds

    def getLineageMarkerGenes(self, lineage, minGenomes = 20, minMarkerSets = 20):
        """Return (pfamIds, tigrIds) from co-located marker sets of the given lineage."""
        return self._readColocatedMarkers(minGenomes, minMarkerSets, bFilterLineage=True, lineage=lineage)

    def getCalculatedMarkerGenes(self, minGenomes = 20, minMarkerSets = 20):
        """Return (pfamIds, tigrIds) from co-located marker sets of all lineages."""
        return self._readColocatedMarkers(minGenomes, minMarkerSets)

    def markerGenes(self, genomeIds, countTable, ubiquityThreshold, singleCopyThreshold):
        """Return the cluster ids that are ubiquitous and single-copy across genomeIds.

        countTable maps cluster id -> {genome id: copy number}. Thresholds are absolute
        genome counts: a cluster is kept when it is present in at least ubiquityThreshold
        of genomeIds and exactly single-copy in at least singleCopyThreshold of them.
        """
        if ubiquityThreshold < 1 or singleCopyThreshold < 1:
            # thresholds are genome counts; a value below 1 suggests a fraction was passed
            print('[Warning] Looks like degenerate threshold.')

        markers = set()
        for clusterId, genomeCounts in countTable.items():
            if len(genomeCounts) < ubiquityThreshold:
                # gene is clearly not ubiquitous
                continue

            ubiquity = 0
            singleCopy = 0
            for genomeId in genomeIds:
                count = genomeCounts.get(genomeId, 0)
                if count > 0:
                    ubiquity += 1
                if count == 1:
                    singleCopy += 1

            if ubiquity >= ubiquityThreshold and singleCopy >= singleCopyThreshold:
                markers.add(clusterId)

        return markers

    def colocatedGenes(self, geneDistTable, distThreshold = 5000, genomeThreshold = 0.95):
        """Identify co-located gene pairs.

        geneDistTable maps genome id -> {cluster id: list of gene locations}; the first
        element of each location is used as its position. Two clusters are co-located in a
        genome if any pair of their genes is closer than distThreshold. Returns
        'idA-idB' strings (lexicographically smaller id first) for pairs co-located in
        more than the fraction genomeThreshold of genomes.
        """
        colocatedCounts = defaultdict(int)
        for clusterIdToGeneLocs in geneDistTable.values():
            # list() so the ids can be sliced on both Python 2 and 3
            clusterIds = list(clusterIdToGeneLocs.keys())
            for i, clusterId1 in enumerate(clusterIds):
                geneLocations1 = clusterIdToGeneLocs[clusterId1]

                for clusterId2 in clusterIds[i+1:]:
                    geneLocations2 = clusterIdToGeneLocs[clusterId2]
                    bColocated = any(abs(p1[0] - p2[0]) < distThreshold
                                     for p1 in geneLocations1
                                     for p2 in geneLocations2)

                    if bColocated:
                        if clusterId1 <= clusterId2:
                            colocatedStr = clusterId1 + '-' + clusterId2
                        else:
                            colocatedStr = clusterId2 + '-' + clusterId1
                        colocatedCounts[colocatedStr] += 1

        numGenomes = len(geneDistTable)
        return [colocatedStr for colocatedStr, count in colocatedCounts.items()
                if float(count) / numGenomes > genomeThreshold]

    def colocatedSets(self, colocatedGenes, markerGenes):
        """Merge co-located gene pairs into sets and add every other marker as a singleton.

        colocatedGenes are 'idA-idB' strings (as returned by colocatedGenes); pairs sharing
        a gene, directly or transitively, end up in the same set.
        """
        # run through co-located genes once creating initial sets
        sets = []
        for cg in colocatedGenes:
            geneA, geneB = cg.split('-')
            sets.append(set([geneA, geneB]))

        # combine any sets with overlapping genes
        bProcessed = [False]*len(sets)
        finalSets = []
        for i in range(len(sets)):
            if bProcessed[i]:
                continue

            curSet = sets[i]
            bProcessed[i] = True

            # repeat until no further set can be absorbed, so merges are transitive
            bUpdated = True
            while bUpdated:
                bUpdated = False
                for j in range(i+1, len(sets)):
                    if bProcessed[j]:
                        continue

                    if curSet.intersection(sets[j]):
                        curSet.update(sets[j])
                        bProcessed[j] = True
                        bUpdated = True

            finalSets.append(curSet)

        # add all singletons into colocated sets
        for clusterId in markerGenes:
            if not any(clusterId in cs for cs in finalSets):
                finalSets.append(set([clusterId]))

        return finalSets

    def genomeCheck(self, colocatedSet, genomeId, countTable):
        """Estimate completeness and contamination of a genome against a list of marker sets.

        Each marker set contributes (present markers / set size) to completeness and
        (extra copies / set size) to contamination; both are averaged over the sets.
        countTable maps cluster id -> {genome id: copy number}.

        Returns (comp, cont, missingMarkers, duplicateMarkers); comp and cont are 0.0
        when colocatedSet is empty.
        """
        comp = 0.0
        cont = 0.0
        missingMarkers = set()
        duplicateMarkers = set()

        if len(colocatedSet) == 0:
            return comp, cont, missingMarkers, duplicateMarkers

        for cs in colocatedSet:
            present = 0
            multiCopy = 0
            for clusterId in cs:
                count = countTable[clusterId].get(genomeId, 0)
                if count == 1:
                    present += 1
                elif count > 1:
                    present += 1
                    multiCopy += (count-1)
                    duplicateMarkers.add(clusterId)
                elif count == 0:
                    missingMarkers.add(clusterId)

            comp += float(present) / len(cs)
            cont += float(multiCopy) / len(cs)

        return comp / len(colocatedSet), cont / len(colocatedSet), missingMarkers, duplicateMarkers

    def uniformity(self, genomeSize, pts):
        """Return a uniformity index for the spacing of points along a genome.

        With U = genomeSize / (len(pts) + 1) and d the gaps between consecutive sorted
        points, returns 1 - sum(|d - U|) / sum(max(d, U)). Raises ZeroDivisionError when
        fewer than two points are given (there are no gaps).
        """
        U = float(genomeSize) / (len(pts)+1)  # distance between perfectly evenly spaced points

        # calculate distance between adjacent points
        pts = sorted(pts)
        dists = [pts[i+1] - pts[i] for i in range(len(pts)-1)]

        # calculate uniformity index
        num = 0
        den = 0
        for d in dists:
            num += abs(d - U)
            den += max(d, U)

        return 1.0 - num/den

    def sampleGenome(self, genomeLen, percentComp, percentCont, contigLen):
        """Sample a genome to simulate a given percent completion and contamination.

        The genome is cut into genomeLen // contigLen contigs of length contigLen.
        percentComp and percentCont are used as fractions of that contig count: the
        completeness contigs are drawn without replacement, the contamination contigs
        with replacement.

        Returns (trueComp, trueCont, contigStarts): achieved completeness and
        contamination as percentages of genomeLen, and the sorted start positions of all
        sampled contigs.
        """
        # whole contigs only; int() also keeps range() valid if genomeLen is a float
        contigsInGenome = int(genomeLen // contigLen)

        # determine number of contigs to achieve desired completeness and contamination
        contigsToSampleComp = int(contigsInGenome*percentComp + 0.5)
        contigsToSampleCont = int(contigsInGenome*percentCont + 0.5)

        # randomly sample contigs with contamination done via sampling with replacement
        compContigs = random.sample(range(contigsInGenome), contigsToSampleComp)
        contContigs = choice(range(contigsInGenome), contigsToSampleCont, replace=True)

        # determine start of each contig
        contigStarts = [c*contigLen for c in compContigs]
        contigStarts += [c*contigLen for c in contContigs]
        contigStarts.sort()

        trueComp = float(contigsToSampleComp)*contigLen*100 / genomeLen
        trueCont = float(contigsToSampleCont)*contigLen*100 / genomeLen

        return trueComp, trueCont, contigStarts

    def _sampleUntilTarget(self, orderedSeqIds, seqLens, targetPer, genomeSize):
        """Take sequences in the given order until they cover at least targetPer of the genome.

        Returns (sampledSeqIds, percentage of genomeSize covered by them).
        """
        sampledSeqIds = []
        truePer = 0.0
        for seqId in orderedSeqIds:
            sampledSeqIds.append(seqId)
            truePer += float(seqLens[seqId]) / genomeSize

            if truePer >= targetPer:
                break

        return sampledSeqIds, truePer*100

    def sampleGenomeScaffoldsInvLength(self, targetPer, seqLens, genomeSize):
        """Sample genome comprised of several sequences with probability inversely proportional to length.

        seqLens maps sequence id -> length. Sequences are drawn in a random order, weighted
        towards shorter sequences, until they cover at least targetPer (a fraction) of
        genomeSize. Returns (sampledSeqIds, percentage of the genome covered).
        """
        seqIds = list(seqLens.keys())

        # weight each sequence by the inverse of its share of the genome
        seqProb = array([1.0 / (float(seqLens[seqId]) / genomeSize) for seqId in seqIds])
        seqProb /= sum(seqProb)

        # random ordering of all sequences, with shorter sequences more likely to come first
        selectedSeqsIds = choice(seqIds, size = len(seqIds), replace=False, p = seqProb)

        return self._sampleUntilTarget(selectedSeqsIds, seqLens, targetPer, genomeSize)

    def sampleGenomeScaffoldsWithoutReplacement(self, targetPer, seqLens, genomeSize):
        """Sample genome comprised of several sequences without replacement.

          Sampling is conducted randomly until the selected sequences comprise
          greater than or equal to the desired target percentage.
        """
        selectedSeqsIds = choice(list(seqLens.keys()), size = len(seqLens), replace=False)

        return self._sampleUntilTarget(selectedSeqsIds, seqLens, targetPer, genomeSize)

    def containedMarkerGenes(self, markerGenes, clusterIdToGenomePositions, startPartialGenomeContigs, contigLen):
        """Determine markers contained in a set of contigs.

        A gene at position p[0] lies in the contig starting at s when s <= p[0] < s + contigLen.
        Returns {marker id: list of contig starts}, with one entry per (gene, contig) hit;
        markers with no hits are omitted.
        """
        contained = {}
        for markerGene in markerGenes:
            containedPos = []
            for p in clusterIdToGenomePositions.get(markerGene, []):
                for s in startPartialGenomeContigs:
                    if 0 <= (p[0] - s) < contigLen:
                        containedPos.append(s)

            if containedPos:
                contained[markerGene] = containedPos

        return contained

    def markerGenesOnScaffolds(self, markerGenes, genomeId, scaffoldIds, containedMarkerGenes):
        """Determine if marker genes are found on the scaffolds of a given genome.

        Appends each matching scaffold id to containedMarkerGenes[markerGeneId] in place.
        A plain dict raises KeyError for markers not yet present, so callers presumably
        pass a defaultdict(list) — NOTE(review): confirm.
        """
        for markerGeneId in markerGenes:
            scaffoldIdsWithMarker = self.img.cachedGenomeFamilyScaffolds[genomeId].get(markerGeneId, [])

            for scaffoldId in scaffoldIdsWithMarker:
                if scaffoldId in scaffoldIds:
                    containedMarkerGenes[markerGeneId] += [scaffoldId]

    def readDuplicateSeqs(self):
        """Parse file indicating duplicate sequence alignments.

        Returns {representative id: [duplicate ids]} for every whitespace-separated line
        with more than one field.
        """
        duplicateSeqs = {}
        with open(os.path.join('/srv/whitlam/bio/db/checkm/genome_tree', 'genome_tree.derep.txt')) as f:
            for line in f:
                lineSplit = line.rstrip().split()
                if len(lineSplit) > 1:
                    duplicateSeqs[lineSplit[0]] = lineSplit[1:]

        return duplicateSeqs

    def __readNodeMetadata(self):
        """Read metadata for internal nodes.

        Returns {unique node id: statistics dict}. A bootstrap value that is not a number
        is stored as 'NA'; genome sizes are converted to Mbp.
        """
        uniqueIdToLineageStatistics = {}
        metadataFile = os.path.join('/srv/whitlam/bio/db/checkm/genome_tree', 'genome_tree.metadata.tsv')
        with open(metadataFile) as f:
            f.readline()  # skip header
            for line in f:
                lineSplit = line.rstrip().split('\t')

                uniqueId = lineSplit[0]

                d = {}
                d['# genomes'] = int(lineSplit[1])
                d['taxonomy'] = lineSplit[2]
                try:
                    d['bootstrap'] = float(lineSplit[3])
                except ValueError:
                    d['bootstrap'] = 'NA'
                d['gc mean'] = float(lineSplit[4])
                d['gc std'] = float(lineSplit[5])
                d['genome size mean'] = float(lineSplit[6])/1e6
                d['genome size std'] = float(lineSplit[7])/1e6
                d['gene count mean'] = float(lineSplit[8])
                d['gene count std'] = float(lineSplit[9])
                d['marker set'] = lineSplit[10].rstrip()

                uniqueIdToLineageStatistics[uniqueId] = d

        return uniqueIdToLineageStatistics

    def __getNextNamedNode(self, node, uniqueIdToLineageStatistics):
        """Get first parent node with taxonomy information, or 'root' if none has any."""
        parentNode = node.parent_node
        while parentNode is not None:
            if parentNode.label:
                trustedUniqueId = parentNode.label.split('|')[0]
                trustedStats = uniqueIdToLineageStatistics[trustedUniqueId]
                if trustedStats['taxonomy'] != '':
                    return trustedStats['taxonomy']

            parentNode = parentNode.parent_node

        return 'root'

    def __refineMarkerSet(self, markerSet, lineageSpecificMarkerSet):
        """Refine marker set to account for lineage-specific gene loss and duplication.

        Keeps only genes also present in lineageSpecificMarkerSet, dropping marker sets
        that become empty. Co-localization information is taken from markerSet.
        """
        # fetched once; it does not change while the sets are being filtered
        lineageGenes = lineageSpecificMarkerSet.getMarkerGenes()

        finalMarkerSet = []
        for ms in markerSet.markerSet:
            s = set(gene for gene in ms if gene in lineageGenes)
            if s:
                finalMarkerSet.append(s)

        return MarkerSet(markerSet.UID, markerSet.lineageStr, markerSet.numGenomes, finalMarkerSet)

    def ____removeInvalidLineageMarkerGenes(self, markerSet, lineageSpecificMarkersToRemove):
        """Refine marker set to account for lineage-specific gene loss and duplication.

        Removes genes listed in lineageSpecificMarkersToRemove, dropping marker sets that
        become empty. Co-localization information is taken from markerSet.
        """
        finalMarkerSet = []
        for ms in markerSet.markerSet:
            s = set()
            for gene in ms:
                if gene.startswith('PF'):
                    # PFAM ids are expected in IMG form ('pfamNNNNN')
                    print('ERROR! Expected genes to start with pfam, not PF.')

                if gene not in lineageSpecificMarkersToRemove:
                    s.add(gene)

            if s:
                finalMarkerSet.append(s)

        return MarkerSet(markerSet.UID, markerSet.lineageStr, markerSet.numGenomes, finalMarkerSet)

    def _geneCountTable(self, genomeIds):
        """Return self.cachedGeneCountTable if set, otherwise query IMG for genomeIds."""
        if self.cachedGeneCountTable is not None:
            return self.cachedGeneCountTable
        return self.img.geneCountTable(genomeIds)

    def missingGenes(self, genomeIds, markerGenes, ubiquityThreshold):
        """Inferring consistently missing marker genes within a set of genomes.

        Returns the marker genes absent from at least ubiquityThreshold (a fraction) of
        genomeIds.
        """
        geneCountTable = self._geneCountTable(genomeIds)

        missing = set()
        for clusterId, genomeCounts in geneCountTable.items():
            if clusterId not in markerGenes:
                continue

            absence = sum(1 for genomeId in genomeIds if genomeCounts.get(genomeId, 0) == 0)
            if absence >= ubiquityThreshold*len(genomeIds):
                missing.add(clusterId)

        return missing

    def duplicateGenes(self, genomeIds, markerGenes, ubiquityThreshold):
        """Inferring consistently duplicated marker genes within a set of genomes.

        Returns the marker genes with more than one copy in at least ubiquityThreshold
        (a fraction) of genomeIds.
        """
        geneCountTable = self._geneCountTable(genomeIds)

        duplicate = set()
        for clusterId, genomeCounts in geneCountTable.items():
            if clusterId not in markerGenes:
                continue

            duplicateCount = sum(1 for genomeId in genomeIds if genomeCounts.get(genomeId, 0) > 1)
            if duplicateCount >= ubiquityThreshold*len(genomeIds):
                duplicate.add(clusterId)

        return duplicate

    def buildMarkerGenes(self, genomeIds, ubiquityThreshold, singleCopyThreshold):
        """Infer marker genes from specified genomes.

        Thresholds are fractions of genomeIds. TIGRFAMs that IMG reports as redundant with
        the selected markers are removed.
        """
        geneCountTable = self._geneCountTable(genomeIds)

        markerGenes = self.markerGenes(genomeIds, geneCountTable, ubiquityThreshold*len(genomeIds), singleCopyThreshold*len(genomeIds))
        tigrToRemove = self.img.identifyRedundantTIGRFAMs(markerGenes)
        return markerGenes - tigrToRemove

    def buildMarkerSet(self, genomeIds, ubiquityThreshold, singleCopyThreshold, spacingBetweenContigs = 5000):
        """Infer marker set from specified genomes, grouping co-located markers."""

        markerGenes = self.buildMarkerGenes(genomeIds, ubiquityThreshold, singleCopyThreshold)

        geneDistTable = self.img.geneDistTable(genomeIds, markerGenes, spacingBetweenContigs)
        colocatedGenes = self.colocatedGenes(geneDistTable)
        colocatedSets = self.colocatedSets(colocatedGenes, markerGenes)
        return MarkerSet(0, 'NA', len(genomeIds), colocatedSets)

    def readLineageSpecificGenesToRemove(self):
        """Get set of genes subject to lineage-specific gene loss and duplication.

        Fills self.lineageSpecificGenesToRemove with {node uid: missing | duplicate genes}.
        """
        self.lineageSpecificGenesToRemove = {}
        with open('/srv/whitlam/bio/db/checkm/genome_tree/missing_duplicate_genes_50.tsv') as f:
            for line in f:
                lineSplit = line.split('\t')
                uid = lineSplit[0]
                # NOTE(review): eval executes arbitrary code; acceptable only because this
                # is a trusted local database file. Never point this at external input.
                missingGenes = eval(lineSplit[1])
                duplicateGenes = eval(lineSplit[2])
                self.lineageSpecificGenesToRemove[uid] = missingGenes.union(duplicateGenes)

    def _lineageSpecificRefinement(self, node):
        """Return the genes to remove for node, reading the refinement table if not yet loaded."""
        if getattr(self, 'lineageSpecificGenesToRemove', None) is None:
            self.readLineageSpecificGenesToRemove()
        return self.lineageSpecificGenesToRemove[node.label.split('|')[0]]

    def _addNodeMarkerSets(self, curNode, ubiquityThreshold, singleCopyThreshold, bMarkerSet,
                           genomeIdsToRemove, lineageSpecificRefinement,
                           binMarkerSets, refinedBinMarkerSet):
        """Build the marker set for one tree node; add it raw and refined to the bin sets.

        Genomes are the node's leaves plus their dereplicated duplicates, minus
        genomeIdsToRemove. Nothing is added when fewer than two genomes remain.
        """
        uniqueId = curNode.label.split('|')[0]
        taxonomyStr = self.uniqueIdToLineageStatistics[uniqueId]['taxonomy']
        if taxonomyStr == '':
            taxonomyStr = self.__getNextNamedNode(curNode, self.uniqueIdToLineageStatistics)

        genomeIds = set()
        for leaf in curNode.leaf_nodes():
            genomeIds.add(leaf.taxon.label.replace('IMG_', ''))
            for dup in self.duplicateSeqs.get(leaf.taxon.label, []):
                genomeIds.add(dup.replace('IMG_', ''))

        # remove all genomes from the same taxonomic group as the genome of interest
        if genomeIdsToRemove is not None:
            genomeIds.difference_update(genomeIdsToRemove)

        if len(genomeIds) < 2:
            return

        if bMarkerSet:
            markerSet = self.buildMarkerSet(genomeIds, ubiquityThreshold, singleCopyThreshold)
        else:
            markerSet = MarkerSet(0, 'NA', len(genomeIds), [self.buildMarkerGenes(genomeIds, ubiquityThreshold, singleCopyThreshold)])

        markerSet.lineageStr = uniqueId + ' | ' + taxonomyStr.split(';')[-1]
        binMarkerSets.addMarkerSet(markerSet)

        refinedMarkerSet = self.____removeInvalidLineageMarkerGenes(markerSet, lineageSpecificRefinement)
        refinedBinMarkerSet.addMarkerSet(refinedMarkerSet)

    def buildBinMarkerSet(self, tree, curNode, ubiquityThreshold, singleCopyThreshold, bMarkerSet = True, genomeIdsToRemove = None):
        """Build lineage-specific marker sets for a genome in a LOO-fashion.

        Ascends from curNode to the root, adding a marker set for every node. The refined
        sets remove genes listed for curNode's own uid in the lineage refinement table.
        Returns (binMarkerSets, refinedBinMarkerSet). The tree argument is unused.
        """
        binMarkerSets = BinMarkerSets(curNode.label, BinMarkerSets.TREE_MARKER_SET)
        refinedBinMarkerSet = BinMarkerSets(curNode.label, BinMarkerSets.TREE_MARKER_SET)

        lineageSpecificRefinement = self._lineageSpecificRefinement(curNode)

        while curNode is not None:
            self._addNodeMarkerSets(curNode, ubiquityThreshold, singleCopyThreshold, bMarkerSet,
                                    genomeIdsToRemove, lineageSpecificRefinement,
                                    binMarkerSets, refinedBinMarkerSet)
            curNode = curNode.parent_node

        return binMarkerSets, refinedBinMarkerSet

    def buildDomainMarkerSet(self, tree, curNode, ubiquityThreshold, singleCopyThreshold, bMarkerSet = True, genomeIdsToRemove = None):
        """Build domain-specific marker sets for a genome in a LOO-fashion.

        Like buildBinMarkerSet, but only the ancestor nodes with uid 'UID2' or 'UID203'
        contribute marker sets. The tree argument is unused.
        """
        binMarkerSets = BinMarkerSets(curNode.label, BinMarkerSets.TREE_MARKER_SET)
        refinedBinMarkerSet = BinMarkerSets(curNode.label, BinMarkerSets.TREE_MARKER_SET)

        lineageSpecificRefinement = self._lineageSpecificRefinement(curNode)

        while curNode is not None:
            uniqueId = curNode.label.split('|')[0]
            if uniqueId in ('UID2', 'UID203'):
                self._addNodeMarkerSets(curNode, ubiquityThreshold, singleCopyThreshold, bMarkerSet,
                                        genomeIdsToRemove, lineageSpecificRefinement,
                                        binMarkerSets, refinedBinMarkerSet)
            curNode = curNode.parent_node

        return binMarkerSets, refinedBinMarkerSet
예제 #33
0
    def run(self,
            geneTreeDir,
            alignmentDir,
            extension,
            outputAlignFile,
            outputTree,
            outputTaxonomy,
            bSupportValues=False):
        """Concatenate per-gene alignments and infer a genome tree with FastTreeMP.

        A gene is used when geneTreeDir contains '<geneId>.*.tre' and alignmentDir contains
        a file '<geneId>.*' ending with extension (gene id = file name up to the first '.').
        Each alignment is filtered with self.__filterParalogs. Genomes lacking a gene are
        padded with gaps, and alignments left empty by filtering are skipped. The
        concatenated alignment (ids prefixed with 'IMG_') is written to outputAlignFile.
        The tree goes to outputTree, and the FastTreeMP log goes to outputTree with its
        extension replaced by '.log'. FastTreeMP gets -nosupport unless bSupportValues
        is True.
        """
        import subprocess

        # read gene trees
        print('Reading gene trees.')
        geneIds = set()
        for f in os.listdir(geneTreeDir):
            if f.endswith('.tre'):
                geneIds.add(f[0:f.find('.')])

        # write out genome tree taxonomy
        print('Reading trusted genomes.')
        img = IMG('/srv/whitlam/bio/db/checkm/img/img_metadata.tsv',
                  '/srv/whitlam/bio/db/checkm/pfam/tigrfam2pfam.tsv')
        trustedGenomeIds = img.genomeMetadata().keys()
        self.__taxonomy(img, trustedGenomeIds, outputTaxonomy)

        print('  There are %d trusted genomes.' % (len(trustedGenomeIds)))

        # get genes in genomes
        print('Reading all PFAM and TIGRFAM hits in trusted genomes.')
        genesInGenomes = self.__genesInGenomes(trustedGenomeIds)

        # read alignment files
        print('Reading alignment files.')
        alignments = {}
        alignedGenomeIds = set()
        for f in os.listdir(alignmentDir):
            geneId = f[0:f.find('.')]
            if not f.endswith(extension) or geneId not in geneIds:
                continue

            seqs = readFasta(os.path.join(alignmentDir, f))

            # paralog filtering uses IMG-style PFAM ids ('pfam' rather than 'PF')
            imgGeneId = geneId
            if imgGeneId.startswith('PF'):
                imgGeneId = imgGeneId.replace('PF', 'pfam')
            seqs = self.__filterParalogs(seqs, imgGeneId, genesInGenomes)

            if not seqs:
                # an empty alignment has no length to pad missing genomes with
                print('  Skipping ' + str(geneId) + ': no sequences remain after paralog filtering.')
                continue

            alignedGenomeIds.update(seqs.keys())
            alignments[geneId] = seqs

        # create concatenated alignment; pieces are collected in lists and joined once
        print('Concatenating alignments:')
        concatenatedParts = {}
        totalAlignLen = 0
        for geneId in sorted(alignments.keys()):
            seqs = alignments[geneId]
            alignLen = len(next(iter(seqs.values())))
            print('  ' + str(geneId) + ',' + str(alignLen))
            totalAlignLen += alignLen

            missingGene = '-' * alignLen
            for genomeId in alignedGenomeIds:
                # genomes lacking this gene are padded with gaps to keep columns aligned
                concatenatedParts.setdefault('IMG_' + genomeId, []).append(seqs.get(genomeId, missingGene))

        concatenatedSeqs = dict((seqId, ''.join(parts)) for seqId, parts in concatenatedParts.items())

        print('  Total alignment length: ' + str(totalAlignLen))

        # save concatenated alignment
        writeFasta(concatenatedSeqs, outputAlignFile)

        # infer genome tree
        print('Inferring genome tree.')
        outputLog = os.path.splitext(outputTree)[0] + '.log'

        cmd = ['FastTreeMP']
        if not bSupportValues:
            cmd.append('-nosupport')
        cmd += ['-wag', '-gamma', '-log', outputLog, outputAlignFile]

        # no shell: paths with spaces or shell metacharacters are passed through safely
        with open(outputTree, 'w') as treeFile:
            returnCode = subprocess.call(cmd, stdout=treeFile)
        if returnCode != 0:
            print('  [Warning] FastTreeMP exited with status %d.' % returnCode)
예제 #34
0
    def run(self, geneTreeDir, acceptPer, extension, outputDir):
        """Copy gene trees whose paralogous genes are conspecific to ``outputDir``.

        Two leaves are treated as paralogs when they come from the same genome
        (the leaf-label prefix before ``|``) and their patristic distance is
        non-zero. A genome is non-conspecific when the smallest clade containing
        its paralogs (``tree.mrca``) holds leaves from more than one classified
        species. Species are genus (taxonomy[5]) plus species (taxonomy[6]) from
        the IMG metadata. A tree is filtered when more than
        ``acceptPer * <number of leaves>`` genomes are non-conspecific; otherwise
        it is copied.

        Args:
            geneTreeDir: directory of Newick gene trees.
            acceptPer: tolerated number of non-conspecific genomes, as a fraction
                of the number of leaves in the tree.
            extension: only files ending with this suffix are processed.
            outputDir: created if missing; files already in it are removed first.
        """
        import shutil

        # make sure output directory exists and holds no old files
        if not os.path.exists(outputDir):
            os.makedirs(outputDir)

        for f in os.listdir(outputDir):
            path = os.path.join(outputDir, f)
            # subdirectories cannot be removed with os.remove; leave them alone
            if os.path.isfile(path):
                os.remove(path)

        img = IMG('/srv/whitlam/bio/db/checkm/img/img_metadata.tsv',
                  '/srv/whitlam/bio/db/checkm/pfam/tigrfam2pfam.tsv')
        metadata = img.genomeMetadata()

        print('Identifying gene trees with only conspecific paralogous genes:')
        filteredGeneTrees = 0
        retainedGeneTrees = 0
        for f in os.listdir(geneTreeDir):
            if not f.endswith(extension):
                continue

            geneId = f[0:f.find('.')]
            print('  Testing gene tree: ' + geneId)

            tree = dendropy.Tree.get_from_path(os.path.join(geneTreeDir, f),
                                               schema='newick',
                                               as_rooted=False,
                                               preserve_underscores=True)

            taxa = tree.leaf_nodes()
            numTaxa = len(taxa)
            print('  Genes in tree: ' + str(numTaxa))

            # root tree with archaeal genomes
            rerootTree = RerootTree()
            rerootTree.reroot(tree)

            # find all paralogous genes
            print('  Finding paralogous genes.')

            # only leaves from the same genome can be paralogs, so group first
            # instead of testing every pair of leaves in the tree
            leavesByGenome = defaultdict(list)
            for t in taxa:
                leavesByGenome[t.taxon.label.split('|')[0]].append(t)

            paralogs = defaultdict(set)
            for genomeId, genomeLeaves in leavesByGenome.items():
                for i, leafA in enumerate(genomeLeaves):
                    for leafB in genomeLeaves[i + 1:]:
                        # identical copies (distance 0) are skipped: they speed
                        # nothing up to test and do not affect phylogenetic inference
                        if self.__patristicDist(tree, leafA, leafB) > 0:
                            paralogs[genomeId].add(leafA.taxon.label)
                            paralogs[genomeId].add(leafB.taxon.label)

            print('    Paralogous genes: ' + str(len(paralogs)))

            # check if paralogous genes are conspecific
            print('  Determining if paralogous genes are conspecific.')
            nonConspecificGenomes = []
            for genomeId, taxaLabels in paralogs.items():
                lcaNode = tree.mrca(taxon_labels=taxaLabels)

                species = set()
                for child in lcaNode.leaf_nodes():
                    childGenomeId = child.taxon.label.split('|')[0]

                    genus = metadata[childGenomeId]['taxonomy'][5]
                    sp = metadata[childGenomeId]['taxonomy'][6].lower()
                    if sp != '' and sp != 'unclassified' and genus != 'unclassified':
                        species.add(genus + ' ' + sp)

                if len(species) > 1:
                    nonConspecificGenomes.append(genomeId)

            if len(nonConspecificGenomes) > acceptPer * numTaxa:
                filteredGeneTrees += 1
                print('  Tree is not conspecific for the following genome: ' +
                      str(nonConspecificGenomes))
            else:
                retainedGeneTrees += 1

                # any non-conspecific genome (even one) means the tree is only
                # acceptably, not fully, conspecific
                if nonConspecificGenomes:
                    print(
                        '  An acceptable number of genomes are not conspecific: '
                        + str(nonConspecificGenomes))
                else:
                    print('  Tree is conspecific.')

                shutil.copy(os.path.join(geneTreeDir, f),
                            os.path.join(outputDir, f))

            print('')

        print('Filtered gene trees: ' + str(filteredGeneTrees))
        print('Retained gene trees: ' + str(retainedGeneTrees))
예제 #35
0
    def run(self, geneTreeDir, treeExtension, consistencyThreshold,
            minTaxaForAverage, outputFile, outputDir):
        """Score gene trees by taxonomic consistency and keep the consistent ones.

        For every tree in ``geneTreeDir`` ending in ``treeExtension``, each taxon
        at taxonomy ranks 0-2 (domain, phylum, class) gets its best consistency
        over all internal nodes. Consistency is congruent descendant taxa divided
        by (total taxa + incongruent descendant taxa). Both the clade below a
        node and its complement are considered. Taxa seen at most once, and
        'unclassified', are reported as 'N/A'.

        Per-tree results go to ``<outputDir>/consistency/<tree>.results.tsv``.
        A summary table goes to ``outputFile``. Trees whose mean rank average is
        at least ``consistencyThreshold`` are copied into ``outputDir``. Rank
        averages only include taxa with more than ``minTaxaForAverage`` leaves.

        Exits with status 1 if a leaf's genome has no IMG metadata.
        """
        import shutil

        # make sure output directory exists; remove files left from a previous run
        if not os.path.exists(outputDir):
            os.makedirs(outputDir)

        for f in os.listdir(outputDir):
            if os.path.isfile(os.path.join(outputDir, f)):
                os.remove(os.path.join(outputDir, f))

        # get TIGRFam info
        descDict = {}
        tigrInfoDir = '/srv/db/tigrfam/13.0/TIGRFAMs_13.0_INFO'
        for f in os.listdir(tigrInfoDir):
            acc = None
            shortDesc = longDesc = ''
            with open(os.path.join(tigrInfoDir, f)) as fin:
                for line in fin:
                    lineSplit = line.split('  ')
                    if lineSplit[0] == 'AC':
                        acc = lineSplit[1].strip()
                    elif lineSplit[0] == 'DE':
                        shortDesc = lineSplit[1].strip()
                    elif lineSplit[0] == 'CC':
                        longDesc = lineSplit[1].strip()

            # a file without an AC record must not reuse the previous file's accession
            if acc is not None:
                descDict[acc] = [shortDesc, longDesc]

        # get PFam info
        with open('/srv/db/pfam/27/Pfam-A.clans.tsv') as fin:
            for line in fin:
                lineSplit = line.split('\t')
                if len(lineSplit) < 5:
                    continue  # blank or malformed line
                descDict[lineSplit[0]] = [lineSplit[3], lineSplit[4].strip()]

        # get IMG taxonomy
        img = IMG('/srv/whitlam/bio/db/checkm/img/img_metadata.tsv',
                  '/srv/whitlam/bio/db/checkm/pfam/tigrfam2pfam.tsv')
        metadata = img.genomeMetadata()
        genomeIdToTaxonomy = dict(
            (genomeId, m['taxonomy']) for genomeId, m in metadata.items())

        # perform analysis for each tree
        allResults = {}
        allTaxa = [set(), set(), set()]
        taxaCounts = {}
        avgConsistency = {}
        for treeFile in os.listdir(geneTreeDir):
            if not treeFile.endswith(treeExtension):
                continue

            print(treeFile)
            tree = dendropy.Tree.get_from_path(os.path.join(geneTreeDir, treeFile),
                                               schema='newick',
                                               as_rooted=True,
                                               preserve_underscores=True)

            # per-rank {taxon: best consistency}: [domain, phylum, class]
            consistencyDict = [{}, {}, {}]

            # get abundance of taxa at different taxonomic ranks
            totals = [{}, {}, {}]
            leaves = tree.leaf_nodes()
            print('  Number of leaves: ' + str(len(leaves)))
            totalValidLeaves = 0

            for leaf in leaves:
                genomeId = self.__genomeId(leaf.taxon.label)

                if genomeId not in metadata:
                    print('[Error] Genome is missing metadata: ' + genomeId)
                    sys.exit(1)  # non-zero status: this is an error

                totalValidLeaves += 1
                taxonomy = genomeIdToTaxonomy[genomeId]
                for r in range(0, 3):
                    totals[r][taxonomy[r]] = totals[r].get(taxonomy[r], 0) + 1
                    consistencyDict[r][taxonomy[r]] = 0
                    allTaxa[r].add(taxonomy[r])

            taxaCounts[treeFile] = [
                totalValidLeaves, totals[0].get('Bacteria', 0),
                totals[0].get('Archaea', 0)
            ]

            # find highest consistency nodes (congruent descendant taxa / (total taxa + incongruent descendant taxa))
            for node in tree.internal_nodes():
                nodeLeaves = node.leaf_nodes()
                numNodeLeaves = len(nodeLeaves)

                # count descendant taxa at all three ranks in a single pass
                leafCounts = [{}, {}, {}]
                for leaf in nodeLeaves:
                    taxonomy = genomeIdToTaxonomy[self.__genomeId(leaf.taxon.label)]
                    for r in range(0, 3):
                        leafCounts[r][taxonomy[r]] = leafCounts[r].get(
                            taxonomy[r], 0) + 1

                # calculate consistency for node
                for r in range(0, 3):
                    for taxa in consistencyDict[r]:
                        totalTaxaCount = totals[r][taxa]
                        if totalTaxaCount <= 1 or taxa == 'unclassified':
                            consistencyDict[r][taxa] = 'N/A'
                            continue

                        # clade below this node
                        taxaCount = leafCounts[r].get(taxa, 0)
                        incongruentTaxa = numNodeLeaves - taxaCount
                        c = float(taxaCount) / (totalTaxaCount + incongruentTaxa)
                        if c > consistencyDict[r][taxa]:
                            consistencyDict[r][taxa] = c

                        # consider clan in other direction since the trees are unrooted
                        taxaCount = totalTaxaCount - leafCounts[r].get(taxa, 0)
                        incongruentTaxa = totalValidLeaves - numNodeLeaves - taxaCount
                        c = float(taxaCount) / (totalTaxaCount + incongruentTaxa)
                        if c > consistencyDict[r][taxa]:
                            consistencyDict[r][taxa] = c

            # write results
            consistencyDir = os.path.join(outputDir, 'consistency')
            if not os.path.exists(consistencyDir):
                os.makedirs(consistencyDir)
            with open(os.path.join(consistencyDir, treeFile + '.results.tsv'),
                      'w') as fout:
                fout.write('Tree')
                for r in range(0, 3):
                    for taxa in sorted(consistencyDict[r].keys()):
                        fout.write('\t' + taxa)
                fout.write('\n')

                fout.write(treeFile)
                for r in range(0, 3):
                    for taxa in sorted(consistencyDict[r].keys()):
                        value = consistencyDict[r][taxa]
                        if value != 'N/A':
                            fout.write('\t%.2f' % (value * 100))
                        else:
                            fout.write('\tN/A')

            # calculate average consistency at each taxonomic rank
            average = []
            for r in range(0, 3):
                values = [
                    c for taxa, c in consistencyDict[r].items()
                    if totals[r][taxa] > minTaxaForAverage and c != 'N/A'
                ]
                if values:
                    average.append(sum(values) / len(values))
                else:
                    average.append(0)
            avgConsistency[treeFile] = average
            allResults[treeFile] = consistencyDict

            print('  Average consistency: ' + str(average) +
                  ', mean = %.2f' % (sum(average) / len(average)))
            print('')

        # print out combined results
        filteredGeneTrees = 0
        retainedGeneTrees = 0
        with open(outputFile, 'w') as fout:
            # one header column per value written below (no alignment length is computed here)
            fout.write(
                'Tree\tShort Desc.\tLong Desc.\t# Taxa\t# Bacteria\t# Archaea\tAvg. Consistency\tAvg. Domain Consistency\tAvg. Phylum Consistency\tAvg. Class Consistency'
            )
            for r in range(0, 3):
                for t in sorted(allTaxa[r]):
                    fout.write('\t' + t)
            fout.write('\n')

            for treeFile in sorted(allResults.keys()):
                consistencyDict = allResults[treeFile]
                treeId = treeFile[0:treeFile.find('.')].replace('pfam', 'PF')

                # families without a description get empty fields instead of a KeyError
                shortDesc, longDesc = descDict.get(treeId, ['', ''])
                fout.write(treeId + '\t' + shortDesc + '\t' + longDesc)

                # Taxa count: total, bacterial, archaeal
                fout.write('\t' + '\t'.join(str(x) for x in taxaCounts[treeFile]))

                avgCon = sum(avgConsistency[treeFile]) / 3
                fout.write('\t' + str(avgCon))

                if avgCon >= consistencyThreshold:
                    retainedGeneTrees += 1
                    shutil.copy(os.path.join(geneTreeDir, treeFile),
                                os.path.join(outputDir, treeFile))
                else:
                    filteredGeneTrees += 1
                    print('Filtered %s with an average consistency of %.4f.' %
                          (treeFile, avgCon))

                for r in range(0, 3):
                    fout.write('\t' + str(avgConsistency[treeFile][r]))

                for r in range(0, 3):
                    for t in sorted(allTaxa[r]):
                        value = consistencyDict[r].get(t, 'N/A')
                        if value != 'N/A':
                            fout.write('\t%.2f' % (value * 100))
                        else:
                            fout.write('\tN/A')
                fout.write('\n')

        print('Retained gene trees: ' + str(retainedGeneTrees))
        print('Filtered gene trees: ' + str(filteredGeneTrees))
예제 #36
0
class DecorateTree(object):
    """Decorate internal nodes of a genome tree with taxonomy and marker-set statistics.

    For every internal node of the input tree, a row of statistics is written
    to a tab-separated metadata file. The row holds genome count, taxonomy
    label, bootstrap, GC, genome size, gene count and marker set. The node is
    relabelled ``UID<n>|<taxonomy>|<bootstrap>`` before the tree is written
    back to its input path.
    """

    def __init__(self):
        self.img = IMG('/srv/whitlam/bio/db/checkm/img/img_metadata.tsv',
                       '/srv/whitlam/bio/db/checkm/pfam/tigrfam2pfam.tsv')
        self.pfamHMMs = '/srv/whitlam/bio/db/pfam/27/Pfam-A.hmm'
        self.markerSetBuilder = MarkerSetBuilder()

    def __meanStd(self, metadata, genomeIds, category):
        """Return (mean, std) of ``metadata[genomeId][category]`` over genomeIds.

        Labels may carry an ``IMG_`` prefix, which is stripped before lookup.
        Values equal to ``'NA'`` are ignored.
        """
        values = []
        for genomeLabel in genomeIds:
            v = metadata[genomeLabel.replace('IMG_', '')][category]
            if v != 'NA':
                values.append(v)

        return mean(values), std(values)

    def __calculateMarkerSet(self,
                             genomeLabels,
                             ubiquityThreshold=0.97,
                             singleCopyThreshold=0.97):
        """Calculate marker set for a set of genomes.

        ``IMG_`` prefixes are removed from the labels. Returns the
        ``markerSet`` attribute of the result of
        ``MarkerSetBuilder.buildMarkerSet``.
        """
        genomeIds = set(label.replace('IMG_', '') for label in genomeLabels)

        markerSet = self.markerSetBuilder.buildMarkerSet(
            genomeIds, ubiquityThreshold, singleCopyThreshold)

        return markerSet.markerSet

    def __pfamIdToPfamAcc(self, img):
        """Map PFAM ids to full accessions read from the ``ACC`` records of ``self.pfamHMMs``.

        The key is the accession with its '.'-suffix removed. ``img`` is unused
        and kept only for call compatibility.
        """
        pfamIdToPfamAcc = {}
        with open(self.pfamHMMs) as fin:
            for line in fin:
                fields = line.split()
                # require an actual ACC record; a substring test would also
                # match DESC or other text containing 'ACC'
                if len(fields) >= 2 and fields[0] == 'ACC':
                    acc = fields[1].strip()
                    pfamIdToPfamAcc[acc.split('.')[0]] = acc

        return pfamIdToPfamAcc

    def decorate(self, taxaTreeFile, derepFile, inputTreeFile, metadataOut,
                 numThreads):
        """Write per-node statistics to ``metadataOut`` and relabel ``inputTreeFile`` in place.

        Args:
            taxaTreeFile: Newick tree whose internal-node labels give taxonomy.
            derepFile: whitespace-separated lines; the first token is a retained
                taxon and the remaining tokens are taxa it represents.
            inputTreeFile: Newick tree to decorate; overwritten with new labels.
            metadataOut: tab-separated statistics file to write.
            numThreads: number of worker processes.
        """
        # read genome metadata
        print('  Reading metadata.')
        metadata = self.img.genomeMetadata()

        # read list of taxa with duplicate sequences
        print('  Read list of taxa with duplicate sequences.')
        duplicateTaxa = {}
        with open(derepFile) as fin:
            for line in fin:
                lineSplit = line.rstrip().split()
                if len(lineSplit) > 1:
                    duplicateTaxa[lineSplit[0]] = lineSplit[1:]

        # build gene count table
        print('  Building gene count table.')
        # reuse the metadata already read instead of querying IMG again
        genomeIds = list(metadata.keys())
        print('    # trusted genomes = ' + str(len(genomeIds)))

        # calculate statistics for each internal node using multiple threads
        print('  Calculating statistics for each internal node.')
        self.__internalNodeStatistics(taxaTreeFile, inputTreeFile,
                                      duplicateTaxa, metadata, metadataOut,
                                      numThreads)

    def __internalNodeStatistics(self, taxaTreeFile, inputTreeFile,
                                 duplicateTaxa, metadata, metadataOut,
                                 numThreads):
        """Compute node statistics in ``numThreads`` worker processes, written by one writer process."""

        # determine HMM model accession numbers
        pfamIdToPfamAcc = self.__pfamIdToPfamAcc(self.img)

        taxaTree = dendropy.Tree.get_from_path(taxaTreeFile,
                                               schema='newick',
                                               as_rooted=True,
                                               preserve_underscores=True)
        inputTree = dendropy.Tree.get_from_path(inputTreeFile,
                                                schema='newick',
                                                as_rooted=True,
                                                preserve_underscores=True)

        workerQueue = mp.Queue()
        writerQueue = mp.Queue()

        for uniqueId, node in enumerate(inputTree.internal_nodes(), 1):
            workerQueue.put((uniqueId, node))

        # one stop sentinel per worker
        for _ in range(numThreads):
            workerQueue.put((None, None))

        calcProc = [
            mp.Process(target=self.__processInternalNode,
                       args=(taxaTree, duplicateTaxa, workerQueue,
                             writerQueue)) for _ in range(numThreads)
        ]
        writeProc = mp.Process(target=self.__reportStatistics,
                               args=(metadata, metadataOut, inputTree,
                                     inputTreeFile, pfamIdToPfamAcc,
                                     writerQueue))

        writeProc.start()

        for p in calcProc:
            p.start()

        for p in calcProc:
            p.join()

        # all workers are done: tell the writer to finish
        writerQueue.put((None, None, None, None, None, None, None))
        writeProc.join()

    def __processInternalNode(self, taxaTree, duplicateTaxa, queueIn,
                              queueOut):
        """Worker loop: label internal nodes and compute their marker sets.

        Consumes ``(uniqueId, node)`` items from ``queueIn`` until a
        ``(None, None)`` sentinel arrives. For each node it puts
        ``(uniqueId, labels, markerSet, taxaStr, bootstrap, node.oid, nodeLabel)``
        on ``queueOut``.
        """
        while True:
            uniqueId, node = queueIn.get(block=True, timeout=None)
            if uniqueId is None:
                break

            # leaf labels below the node, plus the taxa each leaf stands in for
            labels = []
            for leaf in node.leaf_nodes():
                labels.append(leaf.taxon.label)
                if leaf.taxon.label in duplicateTaxa:
                    labels.extend(duplicateTaxa[leaf.taxon.label])

            # check if the corresponding node in the taxa tree has a taxonomic label
            mrca = taxaTree.mrca(taxon_labels=labels)
            taxaStr = ''
            if mrca.label:
                taxaStr = mrca.label.replace(' ', '')

            # give node a unique Id while retaining bootstrap value
            bootstrap = ''
            if node.label:
                bootstrap = node.label
            nodeLabel = 'UID' + str(uniqueId) + '|' + taxaStr + '|' + bootstrap

            # calculate marker set
            markerSet = self.__calculateMarkerSet(labels)

            queueOut.put((uniqueId, labels, markerSet, taxaStr, bootstrap,
                          node.oid, nodeLabel))

    def __reportStatistics(self, metadata, metadataOut, inputTree,
                           inputTreeFile, pfamIdToPfamAcc, writerQueue):
        """Store statistics for internal node.

        Writes one row per result from ``writerQueue`` (until a sentinel with
        ``uniqueId`` None) to ``metadataOut``. Each node's new label is applied
        to ``inputTree``, which is then written back to ``inputTreeFile``.
        """
        # index nodes by oid once; the first match in preorder wins, like find_node
        oidToNode = {}
        for n in inputTree.preorder_node_iter():
            if hasattr(n, 'oid'):
                oidToNode.setdefault(n.oid, n)

        numProcessedNodes = 0
        numInternalNodes = len(inputTree.internal_nodes())
        with open(metadataOut, 'w') as fout:
            fout.write('UID\t# genomes\tTaxonomy\tBootstrap')
            fout.write('\tGC mean\tGC std')
            fout.write('\tGenome size mean\tGenome size std')
            fout.write('\tGene count mean\tGene count std')
            fout.write('\tMarker set')
            fout.write('\n')

            while True:
                uniqueId, labels, markerSet, taxaStr, bootstrap, nodeID, nodeLabel = writerQueue.get(
                    block=True, timeout=None)
                if uniqueId is None:
                    break

                numProcessedNodes += 1
                statusStr = '    Finished processing %d of %d (%.2f%%) internal nodes.' % (
                    numProcessedNodes, numInternalNodes,
                    float(numProcessedNodes) * 100 / numInternalNodes)
                sys.stdout.write('%s\r' % statusStr)
                sys.stdout.flush()

                fout.write('UID' + str(uniqueId) + '\t' + str(len(labels)) +
                           '\t' + taxaStr + '\t' + bootstrap)

                m, sd = self.__meanStd(metadata, labels, 'GC %')
                fout.write('\t' + str(m * 100) + '\t' + str(sd * 100))

                m, sd = self.__meanStd(metadata, labels, 'genome size')
                fout.write('\t' + str(m) + '\t' + str(sd))

                m, sd = self.__meanStd(metadata, labels, 'gene count')
                fout.write('\t' + str(m) + '\t' + str(sd))

                # change model names to accession numbers, and make
                # sure there is an HMM model for each PFAM
                mungedMarkerSets = []
                for geneSet in markerSet:
                    accessions = set()
                    for geneId in geneSet:
                        if 'pfam' in geneId:
                            pfamId = geneId.replace('pfam', 'PF')
                            if pfamId in pfamIdToPfamAcc:
                                accessions.add(pfamIdToPfamAcc[pfamId])
                        else:
                            accessions.add(geneId)
                    mungedMarkerSets.append(accessions)

                fout.write('\t' + str(mungedMarkerSets))
                fout.write('\n')

                oidToNode[nodeID].label = nodeLabel

            sys.stdout.write('\n')

        inputTree.write_to_path(inputTreeFile,
                                schema='newick',
                                suppress_rooting=True,
                                unquoted_underscores=True)
예제 #37
0
class MarkerSetBuilder(object):
    def __init__(self):
        """Create the IMG wrapper and load cached data via readDuplicateSeqs() and __readNodeMetadata().

        ``cachedGeneCountTable`` starts as None.
        """
        imgMetadataFile = '/srv/whitlam/bio/db/checkm/img/img_metadata.tsv'
        tigrToPfamFile = '/srv/whitlam/bio/db/checkm/pfam/tigrfam2pfam.tsv'
        self.img = IMG(imgMetadataFile, tigrToPfamFile)
        self.colocatedFile = './data/colocated.tsv'

        # loaders run after img/colocatedFile are set, in case they use them
        self.duplicateSeqs = self.readDuplicateSeqs()
        self.uniqueIdToLineageStatistics = self.__readNodeMetadata()

        self.cachedGeneCountTable = None

    def precomputeGenomeSeqLens(self, genomeIds):
        """Cache the length of contigs/scaffolds for all genomes."""

        # This function is intended to speed up functions, such as img.geneDistTable(),
        # that are called multiple times (typically during simulations)
        self.img.precomputeGenomeSeqLens(genomeIds)

    def precomputeGenomeFamilyPositions(self, genomeIds,
                                        spacingBetweenContigs):
        """Cache position of PFAM and TIGRFAM genes in genomes."""

        # This function is intended to speed up functions, such as img.geneDistTable(),
        # that are called multiple times (typically during simulations)
        self.img.precomputeGenomeFamilyPositions(genomeIds,
                                                 spacingBetweenContigs)

    def precomputeGenomeFamilyScaffolds(self, genomeIds):
        """Cache scaffolds of PFAM and TIGRFAM genes in genomes."""

        # This function is intended to speed up functions, such as img.geneDistTable(),
        # that are called multiple times (typically during simulations)
        self.img.precomputeGenomeFamilyScaffolds(genomeIds)

    def getLineageMarkerGenes(self, lineage, minGenomes=20, minMarkerSets=20):
        pfamIds = set()
        tigrIds = set()

        bHeader = True
        for line in open(self.colocatedFile):
            if bHeader:
                bHeader = False
                continue

            lineSplit = line.split('\t')
            curLineage = lineSplit[0]
            numGenomes = int(lineSplit[1])
            numMarkerSets = int(lineSplit[3])
            markerSets = lineSplit[4:]

            if curLineage != lineage or numGenomes < minGenomes or numMarkerSets < minMarkerSets:
                continue

            for ms in markerSets:
                markers = ms.split(',')
                for m in markers:
                    if 'pfam' in m:
                        pfamIds.add(m.strip())
                    elif 'TIGR' in m:
                        tigrIds.add(m.strip())

        return pfamIds, tigrIds

    def getCalculatedMarkerGenes(self, minGenomes=20, minMarkerSets=20):
        pfamIds = set()
        tigrIds = set()

        bHeader = True
        for line in open(self.colocatedFile):
            if bHeader:
                bHeader = False
                continue

            lineSplit = line.split('\t')
            numGenomes = int(lineSplit[1])
            numMarkerSets = int(lineSplit[3])
            markerSets = lineSplit[4:]

            if numGenomes < minGenomes or numMarkerSets < minMarkerSets:
                continue

            for ms in markerSets:
                markers = ms.split(',')
                for m in markers:
                    if 'pfam' in m:
                        pfamIds.add(m.strip())
                    elif 'TIGR' in m:
                        tigrIds.add(m.strip())

        return pfamIds, tigrIds

    def markerGenes(self, genomeIds, countTable, ubiquityThreshold,
                    singleCopyThreshold):
        if ubiquityThreshold < 1 or singleCopyThreshold < 1:
            print('[Warning] Looks like degenerate threshold.')

        # find genes meeting ubiquity and single-copy thresholds
        markers = set()
        for clusterId, genomeCounts in countTable.iteritems():
            ubiquity = 0
            singleCopy = 0

            if len(genomeCounts) < ubiquityThreshold:
                # gene is clearly not ubiquitous
                continue

            for genomeId in genomeIds:
                count = genomeCounts.get(genomeId, 0)

                if count > 0:
                    ubiquity += 1

                if count == 1:
                    singleCopy += 1

            if ubiquity >= ubiquityThreshold and singleCopy >= singleCopyThreshold:
                markers.add(clusterId)

        return markers

    def colocatedGenes(self,
                       geneDistTable,
                       distThreshold=5000,
                       genomeThreshold=0.95):
        """Identify co-located gene pairs."""

        colocatedGenes = defaultdict(int)
        for _, clusterIdToGeneLocs in geneDistTable.iteritems():
            clusterIds = clusterIdToGeneLocs.keys()
            for i, clusterId1 in enumerate(clusterIds):
                geneLocations1 = clusterIdToGeneLocs[clusterId1]

                for clusterId2 in clusterIds[i + 1:]:
                    geneLocations2 = clusterIdToGeneLocs[clusterId2]
                    bColocated = False
                    for p1 in geneLocations1:
                        for p2 in geneLocations2:
                            if abs(p1[0] - p2[0]) < distThreshold:
                                bColocated = True
                                break

                        if bColocated:
                            break

                    if bColocated:
                        if clusterId1 <= clusterId2:
                            colocatedStr = clusterId1 + '-' + clusterId2
                        else:
                            colocatedStr = clusterId2 + '-' + clusterId1
                        colocatedGenes[colocatedStr] += 1

        colocated = []
        for colocatedStr, count in colocatedGenes.iteritems():
            if float(count) / len(geneDistTable) > genomeThreshold:
                colocated.append(colocatedStr)

        return colocated

    def colocatedSets(self, colocatedGenes, markerGenes):
        # run through co-located genes once creating initial sets
        sets = []
        for cg in colocatedGenes:
            geneA, geneB = cg.split('-')
            sets.append(set([geneA, geneB]))

        # combine any sets with overlapping genes
        bProcessed = [False] * len(sets)
        finalSets = []
        for i in range(0, len(sets)):
            if bProcessed[i]:
                continue

            curSet = sets[i]
            bProcessed[i] = True

            bUpdated = True
            while bUpdated:
                bUpdated = False
                for j in range(i + 1, len(sets)):
                    if bProcessed[j]:
                        continue

                    if len(curSet.intersection(sets[j])) > 0:
                        curSet.update(sets[j])
                        bProcessed[j] = True
                        bUpdated = True

            finalSets.append(curSet)

        # add all singletons into colocated sets
        for clusterId in markerGenes:
            bFound = False
            for cs in finalSets:
                if clusterId in cs:
                    bFound = True

            if not bFound:
                finalSets.append(set([clusterId]))

        return finalSets

    def genomeCheck(self, colocatedSet, genomeId, countTable):
        comp = 0.0
        cont = 0.0
        missingMarkers = set()
        duplicateMarkers = set()

        if len(colocatedSet) == 0:
            return comp, cont, missingMarkers, duplicateMarkers

        for cs in colocatedSet:
            present = 0
            multiCopy = 0
            for contigId in cs:
                count = countTable[contigId].get(genomeId, 0)
                if count == 1:
                    present += 1
                elif count > 1:
                    present += 1
                    multiCopy += (count - 1)
                    duplicateMarkers.add(contigId)
                elif count == 0:
                    missingMarkers.add(contigId)

            comp += float(present) / len(cs)
            cont += float(multiCopy) / len(cs)

        return comp / len(colocatedSet), cont / len(
            colocatedSet), missingMarkers, duplicateMarkers

    def uniformity(self, genomeSize, pts):
        """Return a uniformity index for points placed along a genome.

        The reference spacing divides ``genomeSize`` into ``len(pts) + 1``
        equal gaps. Each gap between consecutive (sorted) points is compared
        with that spacing, giving ``1 - sum|gap - U| / sum max(gap, U)``.
        For non-negative sizes the index lies in [0, 1], where 1.0 means the
        points are perfectly evenly spaced.

        Note: with fewer than two points there are no gaps to compare, and a
        ZeroDivisionError is raised.
        """
        expectedGap = float(genomeSize) / (len(pts) + 1)

        orderedPts = sorted(pts)
        gaps = [right - left for left, right in zip(orderedPts, orderedPts[1:])]

        deviation = sum(abs(gap - expectedGap) for gap in gaps)
        normaliser = sum(max(gap, expectedGap) for gap in gaps)

        return 1.0 - deviation / normaliser

    def sampleGenome(self, genomeLen, percentComp, percentCont, contigLen):
        """Sample a genome to simulate a given percent completion and contamination.

        The genome is cut into ``genomeLen // contigLen`` fixed-length contigs.
        Completeness contigs are drawn without replacement. Contamination
        contigs are drawn with replacement, so they may repeat each other or
        the completeness contigs.

        Args:
            genomeLen: genome length (bp).
            percentComp: target completeness as a fraction (e.g. 0.9).
            percentCont: target contamination as a fraction (e.g. 0.05).
            contigLen: length of each simulated contig (bp).

        Returns:
            (trueComp, trueCont, contigStarts): the achieved completeness and
            contamination as percentages of ``genomeLen``, and the sorted start
            coordinates of all sampled contigs.
        """
        # Floor division keeps this an integer count on Python 2 and 3 alike;
        # true division (Python 3) would make range() below raise TypeError.
        contigsInGenome = int(genomeLen // contigLen)

        # number of contigs needed to reach the requested fractions (rounded)
        contigsToSampleComp = int(contigsInGenome * percentComp + 0.5)
        contigsToSampleCont = int(contigsInGenome * percentCont + 0.5)

        # completeness: without replacement; contamination: with replacement
        compContigs = random.sample(range(contigsInGenome),
                                    contigsToSampleComp)
        contContigs = choice(range(contigsInGenome),
                             contigsToSampleCont,
                             replace=True)

        # start coordinate of each sampled contig
        contigStarts = [c * contigLen for c in compContigs]
        contigStarts += [c * contigLen for c in contContigs]
        contigStarts.sort()

        trueComp = float(contigsToSampleComp) * contigLen * 100 / genomeLen
        trueCont = float(contigsToSampleCont) * contigLen * 100 / genomeLen

        return trueComp, trueCont, contigStarts

    def sampleGenomeScaffoldsInvLength(self, targetPer, seqLens, genomeSize):
        """Sample genome comprised of several sequences with probability inversely proportional to length.

        All sequences are put into a random order (without replacement) in
        which shorter sequences are more likely to come first. They are then
        taken in that order until their combined length reaches ``targetPer``
        of ``genomeSize``.

        Args:
            targetPer: target fraction of the genome to sample (e.g. 0.7).
            seqLens: mapping sequence ID -> sequence length (bp); lengths must
                be non-zero.
            genomeSize: total genome length (bp).

        Returns:
            (sampledSeqIds, truePer): the selected sequence IDs and the
            percentage (0-100) of ``genomeSize`` they cover.
        """
        # Take IDs from a single snapshot so the probability vector is aligned
        # with the IDs handed to choice(). list() is also required on
        # Python 3, where numpy cannot sample directly from a dict view.
        seqIds = list(seqLens.keys())

        # weight genomeSize / seqLen: inversely proportional to length
        seqProb = array([1.0 / (float(seqLens[seqId]) / genomeSize)
                         for seqId in seqIds])
        seqProb /= sum(seqProb)

        # random ordering of every sequence, favouring short ones early
        selectedSeqsIds = choice(seqIds,
                                 size=len(seqIds),
                                 replace=False,
                                 p=seqProb)

        sampledSeqIds = []
        truePer = 0.0
        for seqId in selectedSeqsIds:
            sampledSeqIds.append(seqId)
            truePer += float(seqLens[seqId]) / genomeSize

            if truePer >= targetPer:
                break

        return sampledSeqIds, truePer * 100

    def sampleGenomeScaffoldsWithoutReplacement(self, targetPer, seqLens,
                                                genomeSize):
        """Sample genome comprised of several sequences without replacement.

          Sampling is conducted randomly until the selected sequences comprise
          greater than or equal to the desired target percentage.

        Args:
            targetPer: target fraction of the genome to sample (e.g. 0.7).
            seqLens: mapping sequence ID -> sequence length (bp).
            genomeSize: total genome length (bp).

        Returns:
            (sampledSeqIds, truePer): the selected sequence IDs and the
            percentage (0-100) of ``genomeSize`` they cover.
        """
        # list() so numpy receives a 1-D sequence; on Python 3 keys() is a
        # view, which choice() rejects (on Python 2 it is already a list)
        seqIds = list(seqLens.keys())

        # uniformly random permutation of all sequence IDs
        selectedSeqsIds = choice(seqIds,
                                 size=len(seqIds),
                                 replace=False)

        sampledSeqIds = []
        truePer = 0.0
        for seqId in selectedSeqsIds:
            sampledSeqIds.append(seqId)
            truePer += float(seqLens[seqId]) / genomeSize

            if truePer >= targetPer:
                break

        return sampledSeqIds, truePer * 100

    def containedMarkerGenes(self, markerGenes, clusterIdToGenomePositions,
                             startPartialGenomeContigs, contigLen):
        """Determine markers contained in a set of contigs.

        A gene position ``pos`` lies on the contig starting at ``start`` when
        ``start <= pos[0] < start + contigLen``; only the first element of each
        position is examined. Every (position, start) match is recorded, so
        repeated starts yield repeated entries.

        Returns:
            dict: marker gene -> list of matching contig starts. Markers without
            any match are omitted.
        """
        contained = {}
        for markerGene in markerGenes:
            hits = [start
                    for pos in clusterIdToGenomePositions.get(markerGene, [])
                    for start in startPartialGenomeContigs
                    if 0 <= pos[0] - start < contigLen]
            if hits:
                contained[markerGene] = hits
        return contained

    def markerGenesOnScaffolds(self, markerGenes, genomeId, scaffoldIds,
                               containedMarkerGenes):
        """Determine if marker genes are found on the scaffolds of a given genome.

        For every marker, the scaffolds that ``self.img.cachedGenomeFamilyScaffolds``
        lists for ``genomeId`` are filtered down to those in ``scaffoldIds``.
        The hits are appended, in place, to ``containedMarkerGenes[marker]``.
        That mapping is only touched for markers with at least one hit, so it
        must already provide a list for those keys (e.g. a defaultdict(list)).
        Returns None.
        """
        for markerGeneId in markerGenes:
            familyScaffolds = self.img.cachedGenomeFamilyScaffolds[genomeId]
            hits = [scaffoldId
                    for scaffoldId in familyScaffolds.get(markerGeneId, [])
                    if scaffoldId in scaffoldIds]
            if hits:
                containedMarkerGenes[markerGeneId] += hits

    def readDuplicateSeqs(self, derepFile=None):
        """Parse file indicating duplicate sequence alignments.

        Each whitespace-separated line lists a representative sequence ID
        followed by the IDs that duplicate it. Lines with fewer than two
        fields are ignored.

        Args:
            derepFile: path of the dereplication file. Defaults to the
                hard-coded ``genome_tree.derep.txt`` in the genome tree
                directory.

        Returns:
            dict: representative ID -> list of duplicate IDs.
        """
        if derepFile is None:
            derepFile = os.path.join('/srv/whitlam/bio/db/checkm/genome_tree',
                                     'genome_tree.derep.txt')

        duplicateSeqs = {}
        # context manager so the handle is closed even if parsing fails
        with open(derepFile) as f:
            for line in f:
                lineSplit = line.rstrip().split()
                if len(lineSplit) > 1:
                    duplicateSeqs[lineSplit[0]] = lineSplit[1:]

        return duplicateSeqs

    def __readNodeMetadata(self):
        """Read metadata for internal nodes.

        Parses the tab-separated ``genome_tree.metadata.tsv``, whose first line
        is skipped as a header. The result maps each node unique ID to a dict
        with these keys:

        - '# genomes', 'taxonomy', 'bootstrap'
        - 'gc mean', 'gc std'
        - 'genome size mean', 'genome size std' (both divided by 1e6)
        - 'gene count mean', 'gene count std'
        - 'marker set'

        A bootstrap value that cannot be parsed as a float is stored as 'NA'.
        """
        uniqueIdToLineageStatistics = {}
        metadataFile = os.path.join('/srv/whitlam/bio/db/checkm/genome_tree',
                                    'genome_tree.metadata.tsv')
        with open(metadataFile) as f:
            f.readline()  # skip header line
            for line in f:
                lineSplit = line.rstrip().split('\t')

                uniqueId = lineSplit[0]

                d = {}
                d['# genomes'] = int(lineSplit[1])
                d['taxonomy'] = lineSplit[2]
                # Only a non-numeric bootstrap maps to 'NA'. The former bare
                # 'except:' also swallowed KeyboardInterrupt/SystemExit.
                try:
                    d['bootstrap'] = float(lineSplit[3])
                except ValueError:
                    d['bootstrap'] = 'NA'
                d['gc mean'] = float(lineSplit[4])
                d['gc std'] = float(lineSplit[5])
                d['genome size mean'] = float(lineSplit[6]) / 1e6
                d['genome size std'] = float(lineSplit[7]) / 1e6
                d['gene count mean'] = float(lineSplit[8])
                d['gene count std'] = float(lineSplit[9])
                d['marker set'] = lineSplit[10].rstrip()

                uniqueIdToLineageStatistics[uniqueId] = d

        return uniqueIdToLineageStatistics

    def __getNextNamedNode(self, node, uniqueIdToLineageStatistics):
        """Get first parent node with taxonomy information.

        Walks from ``node``'s parent towards the root. For each labelled
        ancestor, the text before the first '|' in its label is used as a key
        into ``uniqueIdToLineageStatistics``. The first non-empty 'taxonomy'
        string found is returned.

        Returns:
            str: taxonomy string of the nearest named ancestor, or 'root' when
            no ancestor carries taxonomy information.
        """
        parentNode = node.parent_node
        while parentNode is not None:
            if parentNode.label:
                trustedUniqueId = parentNode.label.split('|')[0]
                trustedStats = uniqueIdToLineageStatistics[trustedUniqueId]
                if trustedStats['taxonomy'] != '':
                    return trustedStats['taxonomy']

            parentNode = parentNode.parent_node

        # reached the root without finding a named ancestor
        return 'root'

    def __refineMarkerSet(self, markerSet, lineageSpecificMarkerSet):
        """Refine marker set to account for lineage-specific gene loss and duplication.

        Keeps only the genes of ``markerSet`` that also occur in
        ``lineageSpecificMarkerSet``. This removes markers that are not
        single-copy or ubiquitous in the specific lineage of a bin. Groups that
        become empty are dropped. Co-localization information is taken from
        ``markerSet`` (the trusted set).

        Returns:
            MarkerSet: a new set with the UID, lineage string and genome count
            of ``markerSet``.
        """
        # fetch the lineage-specific genes once instead of once per gene
        lineageGenes = lineageSpecificMarkerSet.getMarkerGenes()

        finalMarkerSet = []
        for ms in markerSet.markerSet:
            s = set(gene for gene in ms if gene in lineageGenes)
            if s:
                finalMarkerSet.append(s)

        refinedMarkerSet = MarkerSet(markerSet.UID, markerSet.lineageStr,
                                     markerSet.numGenomes, finalMarkerSet)

        return refinedMarkerSet

    def ____removeInvalidLineageMarkerGenes(self, markerSet,
                                            lineageSpecificMarkersToRemove):
        """Refine marker set to account for lineage-specific gene loss and duplication.

        Drops every gene listed in ``lineageSpecificMarkersToRemove`` from each
        co-location group of ``markerSet``; groups left empty are omitted.
        Co-localization information is taken from ``markerSet`` (the trusted
        set). A warning is printed for each gene whose ID starts with 'PF'.

        Returns:
            MarkerSet: a new set with the UID, lineage string and genome count
            of ``markerSet``.
        """

        def keepValidGenes(group):
            # genes of one co-location group that are not flagged for removal
            kept = set()
            for geneId in group:
                if geneId.startswith('PF'):
                    # marker IDs are expected in 'pfam...' form, not 'PF...'
                    print('ERROR! Expected genes to start with pfam, not PF.')
                if geneId not in lineageSpecificMarkersToRemove:
                    kept.add(geneId)
            return kept

        filteredGroups = [kept for kept in
                          (keepValidGenes(group) for group in markerSet.markerSet)
                          if kept]

        return MarkerSet(markerSet.UID, markerSet.lineageStr,
                         markerSet.numGenomes, filteredGroups)

    def missingGenes(self, genomeIds, markerGenes, ubiquityThreshold):
        """Inferring consistently missing marker genes within a set of genomes.

        A gene counts as missing when it has zero copies, or no entry, in at
        least ``ubiquityThreshold * len(genomeIds)`` of the given genomes. Only
        genes that are in ``markerGenes`` and are keys of the gene count table
        are considered.

        Args:
            genomeIds: genomes to inspect.
            markerGenes: collection of candidate marker gene IDs.
            ubiquityThreshold: fraction (0-1) of genomes that must lack a gene.

        Returns:
            set: IDs of consistently missing marker genes.
        """
        if self.cachedGeneCountTable is not None:
            geneCountTable = self.cachedGeneCountTable
        else:
            geneCountTable = self.img.geneCountTable(genomeIds)

        minAbsent = ubiquityThreshold * len(genomeIds)

        # find marker genes absent from enough genomes
        missing = set()
        # items() instead of iteritems() so this also runs on Python 3
        for clusterId, genomeCounts in geneCountTable.items():
            if clusterId not in markerGenes:
                continue

            absence = sum(1 for genomeId in genomeIds
                          if genomeCounts.get(genomeId, 0) == 0)
            if absence >= minAbsent:
                missing.add(clusterId)

        return missing

    def duplicateGenes(self, genomeIds, markerGenes, ubiquityThreshold):
        """Inferring consistently duplicated marker genes within a set of genomes.

        A gene counts as duplicated when it has more than one copy in at least
        ``ubiquityThreshold * len(genomeIds)`` of the given genomes. Only genes
        that are in ``markerGenes`` and are keys of the gene count table are
        considered.

        Args:
            genomeIds: genomes to inspect.
            markerGenes: collection of candidate marker gene IDs.
            ubiquityThreshold: fraction (0-1) of genomes that must carry
                multiple copies.

        Returns:
            set: IDs of consistently duplicated marker genes.
        """
        if self.cachedGeneCountTable is not None:
            geneCountTable = self.cachedGeneCountTable
        else:
            geneCountTable = self.img.geneCountTable(genomeIds)

        minDuplicated = ubiquityThreshold * len(genomeIds)

        # find marker genes that are multi-copy in enough genomes
        duplicate = set()
        # items() instead of iteritems() so this also runs on Python 3
        for clusterId, genomeCounts in geneCountTable.items():
            if clusterId not in markerGenes:
                continue

            duplicateCount = sum(1 for genomeId in genomeIds
                                 if genomeCounts.get(genomeId, 0) > 1)
            if duplicateCount >= minDuplicated:
                duplicate.add(clusterId)

        return duplicate

    def buildMarkerGenes(self, genomeIds, ubiquityThreshold,
                         singleCopyThreshold):
        """Infer marker genes from specified genomes.

        The thresholds are fractions of ``genomeIds``. They are converted to
        absolute genome counts (threshold * len(genomeIds)) before being passed
        to ``self.markerGenes``. TIGRFAMs that ``self.img`` reports as
        redundant are then removed.

        Args:
            genomeIds: genomes used to infer markers.
            ubiquityThreshold: ubiquity threshold as a fraction of genomes.
            singleCopyThreshold: single-copy threshold as a fraction of genomes.

        Returns:
            The marker genes from ``self.markerGenes`` minus the redundant
            TIGRFAMs (set difference).
        """
        # reuse the cached gene count table when one is available
        if self.cachedGeneCountTable is not None:
            geneCountTable = self.cachedGeneCountTable
        else:
            geneCountTable = self.img.geneCountTable(genomeIds)

        numGenomes = len(genomeIds)
        markerGenes = self.markerGenes(genomeIds, geneCountTable,
                                       ubiquityThreshold * numGenomes,
                                       singleCopyThreshold * numGenomes)
        tigrToRemove = self.img.identifyRedundantTIGRFAMs(markerGenes)

        return markerGenes - tigrToRemove

    def buildMarkerSet(self,
                       genomeIds,
                       ubiquityThreshold,
                       singleCopyThreshold,
                       spacingBetweenContigs=5000):
        """Infer marker set from specified genomes.

        Marker genes are inferred from ``genomeIds`` and grouped into
        co-located sets. The grouping uses the gene distance table that
        ``self.img`` builds with ``spacingBetweenContigs``.

        Returns:
            MarkerSet: UID 0, lineage 'NA', the number of genomes and the
            co-located sets.
        """
        markerGenes = self.buildMarkerGenes(genomeIds, ubiquityThreshold,
                                            singleCopyThreshold)

        distTable = self.img.geneDistTable(genomeIds, markerGenes,
                                           spacingBetweenContigs)
        groups = self.colocatedSets(self.colocatedGenes(distTable), markerGenes)

        return MarkerSet(0, 'NA', len(genomeIds), groups)

    def readLineageSpecificGenesToRemove(self, filename=None):
        """Get set of genes subject to lineage-specific gene loss and duplication.

        Fills ``self.lineageSpecificGenesToRemove`` with a mapping from
        lineage UID to the union of its missing and duplicated genes. Each
        tab-separated line holds the UID, then the missing genes and the
        duplicated genes written as Python set expressions (evaluated with
        eval). Blank lines are skipped.

        Args:
            filename: path of the table. Defaults to the hard-coded
                ``missing_duplicate_genes_50.tsv`` in the genome tree directory.
        """
        if filename is None:
            filename = '/srv/whitlam/bio/db/checkm/genome_tree/missing_duplicate_genes_50.tsv'

        self.lineageSpecificGenesToRemove = {}
        # context manager so the handle is closed even if parsing fails
        with open(filename) as f:
            for line in f:
                if not line.strip():
                    continue  # e.g. a trailing empty line would crash indexing

                lineSplit = line.split('\t')
                uid = lineSplit[0]
                # SECURITY NOTE(review): eval() executes arbitrary code and is
                # only acceptable if this file is a trusted local resource.
                # ast.literal_eval cannot simply replace it: it rejects the
                # Python 2 set repr 'set([...])'.
                missing = eval(lineSplit[1])
                duplicated = eval(lineSplit[2])
                self.lineageSpecificGenesToRemove[uid] = missing.union(
                    duplicated)

    def buildBinMarkerSet(self,
                          tree,
                          curNode,
                          ubiquityThreshold,
                          singleCopyThreshold,
                          bMarkerSet=True,
                          genomeIdsToRemove=None):
        """Build lineage-specific marker sets for a genome in a LOO-fashion.

        Starting at ``curNode`` and ascending to the root via ``parent_node``,
        a marker set is inferred at every node that has at least two genomes
        after filtering. Each marker set is also refined: the genes listed in
        ``self.lineageSpecificGenesToRemove`` for the UID of the *starting*
        node are removed (that mapping is filled by
        ``readLineageSpecificGenesToRemove``).

        Args:
            tree: not used in this method body.
            curNode: starting tree node. The text before the first '|' of each
                node label is used as the node's unique ID.
            ubiquityThreshold: forwarded to marker gene/set inference.
            singleCopyThreshold: forwarded to marker gene/set inference.
            bMarkerSet: if True, build co-located marker sets; otherwise each
                node's MarkerSet holds one group containing all marker genes.
            genomeIdsToRemove: optional collection of genome IDs excluded at
                every node (the leave-one-out step).

        Returns:
            (binMarkerSets, refinedBinMarkerSet): BinMarkerSets to which the
            unrefined and refined marker sets were added, in the order the
            nodes were visited (``curNode`` towards the root).

        NOTE(review): relies on ``self.uniqueIdToLineageStatistics`` and
        ``self.duplicateSeqs`` being set elsewhere -- confirm in __init__.
        """

        # determine marker sets for bin
        binMarkerSets = BinMarkerSets(curNode.label,
                                      BinMarkerSets.TREE_MARKER_SET)
        refinedBinMarkerSet = BinMarkerSets(curNode.label,
                                            BinMarkerSets.TREE_MARKER_SET)

        # ascend tree to root, recording all marker sets
        # (the refinement genes come from the starting node only and are
        # reused for every ancestor)
        uniqueId = curNode.label.split('|')[0]
        lineageSpecificRefinement = self.lineageSpecificGenesToRemove[uniqueId]

        while curNode != None:
            uniqueId = curNode.label.split('|')[0]
            stats = self.uniqueIdToLineageStatistics[uniqueId]
            taxonomyStr = stats['taxonomy']
            # unnamed nodes borrow the taxonomy of the nearest named ancestor
            if taxonomyStr == '':
                taxonomyStr = self.__getNextNamedNode(
                    curNode, self.uniqueIdToLineageStatistics)

            # genome IDs of all leaves under this node ('IMG_' prefix removed),
            # plus sequences recorded as duplicates of those leaves
            leafNodes = curNode.leaf_nodes()
            genomeIds = set()
            for leaf in leafNodes:
                genomeIds.add(leaf.taxon.label.replace('IMG_', ''))

                duplicateGenomes = self.duplicateSeqs.get(leaf.taxon.label, [])
                for dup in duplicateGenomes:
                    genomeIds.add(dup.replace('IMG_', ''))

            # remove all genomes from the same taxonomic group as the genome of interest
            if genomeIdsToRemove != None:
                genomeIds.difference_update(genomeIdsToRemove)

            # at least two genomes are needed to infer a marker set
            if len(genomeIds) >= 2:
                if bMarkerSet:
                    markerSet = self.buildMarkerSet(genomeIds,
                                                    ubiquityThreshold,
                                                    singleCopyThreshold)
                else:
                    # a single group holding every inferred marker gene
                    markerSet = MarkerSet(0, 'NA', len(genomeIds), [
                        self.buildMarkerGenes(genomeIds, ubiquityThreshold,
                                              singleCopyThreshold)
                    ])

                # label as '<UID> | <last taxonomy rank>'; set before the marker
                # set is added so the stored object carries it
                markerSet.lineageStr = uniqueId + ' | ' + taxonomyStr.split(
                    ';')[-1]
                binMarkerSets.addMarkerSet(markerSet)

                #refinedMarkerSet = self.__refineMarkerSet(markerSet, lineageSpecificMarkerSet)
                refinedMarkerSet = self.____removeInvalidLineageMarkerGenes(
                    markerSet, lineageSpecificRefinement)
                #print 'Refinement: %d of %d' % (len(refinedMarkerSet.getMarkerGenes()), len(markerSet.getMarkerGenes()))
                refinedBinMarkerSet.addMarkerSet(refinedMarkerSet)

            curNode = curNode.parent_node

        return binMarkerSets, refinedBinMarkerSet

    def buildDomainMarkerSet(self,
                             tree,
                             curNode,
                             ubiquityThreshold,
                             singleCopyThreshold,
                             bMarkerSet=True,
                             genomeIdsToRemove=None):
        """Build domain-specific marker sets for a genome in a LOO-fashion.

        Ascends from ``curNode`` to the root via ``parent_node``. Marker sets
        are built only at nodes whose unique ID is 'UID2' or 'UID203' (the
        bacterial / archaeal nodes), and only if they still have at least two
        genomes after filtering. Each marker set is also refined: the genes
        listed in ``self.lineageSpecificGenesToRemove`` for the UID of the
        *starting* node are removed.

        Args:
            tree: not used in this method body.
            curNode: starting tree node. The text before the first '|' of each
                node label is used as the node's unique ID.
            ubiquityThreshold: forwarded to marker gene/set inference.
            singleCopyThreshold: forwarded to marker gene/set inference.
            bMarkerSet: if True, build co-located marker sets; otherwise the
                MarkerSet holds one group containing all marker genes.
            genomeIdsToRemove: optional collection of genome IDs excluded from
                the domain node (the leave-one-out step).

        Returns:
            (binMarkerSets, refinedBinMarkerSet): BinMarkerSets to which the
            unrefined and refined domain marker sets were added.

        NOTE(review): relies on ``self.uniqueIdToLineageStatistics`` and
        ``self.duplicateSeqs`` being set elsewhere -- confirm in __init__.
        """

        # determine marker sets for bin
        binMarkerSets = BinMarkerSets(curNode.label,
                                      BinMarkerSets.TREE_MARKER_SET)
        refinedBinMarkerSet = BinMarkerSets(curNode.label,
                                            BinMarkerSets.TREE_MARKER_SET)

        # calculate marker set for bacterial or archaeal node
        # (refinement genes are taken from the starting node only)
        uniqueId = curNode.label.split('|')[0]
        lineageSpecificRefinement = self.lineageSpecificGenesToRemove[uniqueId]

        while curNode != None:
            uniqueId = curNode.label.split('|')[0]
            # skip every node other than the two domain nodes
            if uniqueId != 'UID2' and uniqueId != 'UID203':
                curNode = curNode.parent_node
                continue

            stats = self.uniqueIdToLineageStatistics[uniqueId]
            taxonomyStr = stats['taxonomy']
            # unnamed nodes borrow the taxonomy of the nearest named ancestor
            if taxonomyStr == '':
                taxonomyStr = self.__getNextNamedNode(
                    curNode, self.uniqueIdToLineageStatistics)

            # genome IDs of all leaves under this node ('IMG_' prefix removed),
            # plus sequences recorded as duplicates of those leaves
            leafNodes = curNode.leaf_nodes()
            genomeIds = set()
            for leaf in leafNodes:
                genomeIds.add(leaf.taxon.label.replace('IMG_', ''))

                duplicateGenomes = self.duplicateSeqs.get(leaf.taxon.label, [])
                for dup in duplicateGenomes:
                    genomeIds.add(dup.replace('IMG_', ''))

            # remove all genomes from the same taxonomic group as the genome of interest
            if genomeIdsToRemove != None:
                genomeIds.difference_update(genomeIdsToRemove)

            # at least two genomes are needed to infer a marker set
            if len(genomeIds) >= 2:
                if bMarkerSet:
                    markerSet = self.buildMarkerSet(genomeIds,
                                                    ubiquityThreshold,
                                                    singleCopyThreshold)
                else:
                    # a single group holding every inferred marker gene
                    markerSet = MarkerSet(0, 'NA', len(genomeIds), [
                        self.buildMarkerGenes(genomeIds, ubiquityThreshold,
                                              singleCopyThreshold)
                    ])

                # label as '<UID> | <last taxonomy rank>'; set before the marker
                # set is added so the stored object carries it
                markerSet.lineageStr = uniqueId + ' | ' + taxonomyStr.split(
                    ';')[-1]
                binMarkerSets.addMarkerSet(markerSet)

                #refinedMarkerSet = self.__refineMarkerSet(markerSet, lineageSpecificMarkerSet)
                refinedMarkerSet = self.____removeInvalidLineageMarkerGenes(
                    markerSet, lineageSpecificRefinement)
                #print 'Refinement: %d of %d' % (len(refinedMarkerSet.getMarkerGenes()), len(markerSet.getMarkerGenes()))
                refinedBinMarkerSet.addMarkerSet(refinedMarkerSet)

            curNode = curNode.parent_node

        return binMarkerSets, refinedBinMarkerSet
 def __init__(self):
     """Create the helper objects used by this class.

     NOTE(review): in this chunk this __init__ is not inside a visible class
     body (it looks like scraped residue) -- confirm which class it belongs to.
     """
     self.markerSetBuilder = MarkerSetBuilder()
     # IMG accessor built from the hard-coded img_metadata.tsv and
     # tigrfam2pfam.tsv files
     self.img = IMG('/srv/whitlam/bio/db/checkm/img/img_metadata.tsv', '/srv/whitlam/bio/db/checkm/pfam/tigrfam2pfam.tsv')
예제 #39
0
class SimComparePlots(object):
    def __init__(self, dataset='draft'):
        """Configure input/output paths and plotting parameters.

        Args:
            dataset: simulation dataset label embedded in every file name under
                ./simulations/. The default 'draft' reproduces the original
                paths. 'scaffolds.draft' and 'random_scaffolds' select the
                alternative runs that were previously toggled by commenting
                code in and out.
        """
        simDir = './simulations/'
        refinementTag = 'w_refinement_50'

        self.plotPrefix = simDir + 'simulation.%s.%s' % (dataset, refinementTag)
        self.simCompareFile = simDir + 'simCompare.%s.%s.full.tsv' % (dataset, refinementTag)

        def tablePath(tableName):
            # e.g. ./simulations/simCompare.draft.marker_set_table.w_refinement_50.tsv
            return simDir + 'simCompare.%s.%s.%s.tsv' % (dataset, tableName, refinementTag)

        self.simCompareMarkerSetOut = tablePath('marker_set_table')
        self.simCompareConditionOut = tablePath('condition_table')
        self.simCompareTaxonomyTableOut = tablePath('taxonomy_table')
        # 'refinment' (sic) is kept so existing output file names do not change
        self.simCompareRefinementTableOut = tablePath('refinment_table')

        self.img = IMG('/srv/whitlam/bio/db/checkm/img/img_metadata.tsv', '/srv/whitlam/bio/db/checkm/pfam/tigrfam2pfam.tsv')

        # completeness / contamination levels (fractions) shown in the plots
        self.compsToConsider = [0.5, 0.7, 0.8, 0.9]
        self.contsToConsider = [0.05, 0.1, 0.15]

        self.dpi = 1200
  
    def __readResults(self, filename):
        results = defaultdict(dict)
        genomeIds = set()
        with open(filename) as f:
            f.readline()
            for line in f:
                lineSplit = line.split('\t')
                
                simId = lineSplit[0]
                genomeId = simId.split('-')[0]
                genomeIds.add(genomeId)
                
                bestCompIM = [float(x) for x in lineSplit[6].split(',')]
                bestContIM = [float(x) for x in lineSplit[7].split(',')]
                
                bestCompMS = [float(x) for x in lineSplit[8].split(',')]
                bestContMS = [float(x) for x in lineSplit[9].split(',')]
                                
                domCompIM = [float(x) for x in lineSplit[10].split(',')]
                domContIM = [float(x) for x in lineSplit[11].split(',')]
                
                domCompMS = [float(x) for x in lineSplit[12].split(',')]
                domContMS = [float(x) for x in lineSplit[13].split(',')]
                
                simCompIM = [float(x) for x in lineSplit[14].split(',')]
                simContIM = [float(x) for x in lineSplit[15].split(',')]
                
                simCompMS = [float(x) for x in lineSplit[16].split(',')]
                simContMS = [float(x) for x in lineSplit[17].split(',')]
                
                simCompRMS = [float(x) for x in lineSplit[18].split(',')]
                simContRMS = [float(x) for x in lineSplit[19].split(',')]
                
                results[simId] = [bestCompIM, bestContIM, bestCompMS, bestContMS, domCompIM, domContIM, domCompMS, domContMS, simCompIM, simContIM, simCompMS, simContMS, simCompRMS, simContRMS]
                
        print '    Number of test genomes: ' + str(len(genomeIds))
        
        return results
    
    def markerSets(self, results):
        """Compare domain-level marker genes (IM) against marker sets (MS).

        Steps:

        1. Pool the domain-level completeness/contamination errors per
           experimental condition ('<comp>-<cont>-<seqLen>').
        2. Box-plot the 20 kb conditions from ``self.compsToConsider`` x
           ``self.contsToConsider``.
        3. Write a table of mean+/-std absolute errors to
           ``self.simCompareMarkerSetOut``.

        Args:
            results: mapping simulation ID ('genomeId-seqLen-comp-cont') ->
                list of result lists. Indices 4-7 are used: IM completeness,
                IM contamination, MS completeness, MS contamination.
        """
        # summarize results from IM vs MS
        print('  Tabulating results for domain-level marker genes vs marker sets.')

        itemsProcessed = 0
        compDataDict = defaultdict(lambda: defaultdict(list))
        contDataDict = defaultdict(lambda: defaultdict(list))

        genomeIds = set()
        for simId in results:
            itemsProcessed += 1
            statusStr = '    Finished processing %d of %d (%.2f%%) test cases.' % (itemsProcessed, len(results), float(itemsProcessed)*100/len(results))
            sys.stdout.write('%s\r' % statusStr)
            sys.stdout.flush()

            genomeId, seqLen, comp, cont = simId.split('-')
            genomeIds.add(genomeId)
            expCondStr = str(float(comp)) + '-' + str(float(cont)) + '-' + str(int(seqLen))

            compDataDict[expCondStr]['IM'] += results[simId][4]
            compDataDict[expCondStr]['MS'] += results[simId][6]

            contDataDict[expCondStr]['IM'] += results[simId][5]
            contDataDict[expCondStr]['MS'] += results[simId][7]

        # terminate the '\r' progress line before anything else is printed,
        # otherwise the next message overwrites it
        sys.stdout.write('\n')

        print('  There are %d unique genomes.' % len(genomeIds))
        print('    There are %d experimental conditions.' % (len(compDataDict)))

        # plot data
        print('  Plotting results.')
        compData = []
        contData = []
        rowLabels = []

        for comp in self.compsToConsider:
            for cont in self.contsToConsider:
                for seqLen in [20000]:
                    for msStr in ['MS', 'IM']:
                        rowLabels.append(msStr + ': %d%%, %d%%' % (comp*100, cont*100))

                        expCondStr = str(comp) + '-' + str(cont) + '-' + str(seqLen)
                        compData.append(compDataDict[expCondStr][msStr])
                        contData.append(contDataDict[expCondStr][msStr])

        # rows alternate MS, IM: even indices are MS, odd indices are IM
        print('MS:\t%.2f\t%.2f' % (mean(abs(array(compData[0::2]))), mean(abs(array(contData[0::2])))))
        print('IM:\t%.2f\t%.2f' % (mean(abs(array(compData[1::2]))), mean(abs(array(contData[1::2])))))

        boxPlot = BoxPlot()
        plotFilename = self.plotPrefix + '.markerSets.png'
        boxPlot.plot(plotFilename, compData, contData, rowLabels,
                     r'$\Delta$' + ' % Completion', 'Simulation Conditions',
                     r'$\Delta$' + ' % Contamination', None,
                     rowsPerCategory=2, dpi=self.dpi)

        def formatCells(compIM, contIM, compMS, contMS):
            # One sequence length's cells, in header order: each method spans
            # two columns (completeness, then contamination). The previous
            # code wrote IM comp, MS comp, IM cont, MS cont, which did not
            # match the header.
            stats = []
            for values in (compIM, contIM, compMS, contMS):
                absValues = abs(array(values))
                stats.extend([mean(absValues), std(absValues)])
            return '\t%.1f+/-%.2f\t%.1f+/-%.2f\t%.1f+/-%.2f\t%.1f+/-%.2f' % tuple(stats)

        tableSeqLens = [5000, 20000, 50000]
        avgComp = defaultdict(lambda: defaultdict(list))
        avgCont = defaultdict(lambda: defaultdict(list))

        # print table of results; 'with' closes the file even on error
        with open(self.simCompareMarkerSetOut, 'w') as tableOut:
            tableOut.write('Comp. (%)\tCont. (%)\tIM (5kb)\t\tMS (5kb)\t\tIM (20kb)\t\tMS (20kb)\t\tIM (50kb)\t\tMS (50kb)\n')

            for comp in [0.5, 0.7, 0.8, 0.9, 0.95, 1.0]:
                for cont in [0.0, 0.05, 0.1, 0.15, 0.2]:
                    tableOut.write('%d\t%d' % (comp*100, cont*100))

                    for seqLen in tableSeqLens:
                        expCondStr = str(comp) + '-' + str(cont) + '-' + str(seqLen)

                        for msStr in ('IM', 'MS'):
                            avgComp[seqLen][msStr] += compDataDict[expCondStr][msStr]
                            avgCont[seqLen][msStr] += contDataDict[expCondStr][msStr]

                        tableOut.write(formatCells(compDataDict[expCondStr]['IM'],
                                                   contDataDict[expCondStr]['IM'],
                                                   compDataDict[expCondStr]['MS'],
                                                   contDataDict[expCondStr]['MS']))
                    tableOut.write('\n')

            # pooled errors over all table conditions, per sequence length
            tableOut.write('\tAverage:')
            for seqLen in tableSeqLens:
                tableOut.write(formatCells(avgComp[seqLen]['IM'],
                                           avgCont[seqLen]['IM'],
                                           avgComp[seqLen]['MS'],
                                           avgCont[seqLen]['MS']))
            tableOut.write('\n')
    
    def conditionsPlot(self, results):
        # summarize results for each experimental condition  
        print '  Tabulating results for each experimental condition using marker sets.'
        
        itemsProcessed = 0      
        compDataDict = defaultdict(lambda : defaultdict(list))
        contDataDict = defaultdict(lambda : defaultdict(list))
        comps = set()
        conts = set()
        seqLens = set()
        
        compOutliers = defaultdict(list)
        contOutliers = defaultdict(list)
        
        genomeIds = set()
        for simId in results:
            itemsProcessed += 1
            statusStr = '    Finished processing %d of %d (%.2f%%) test cases.' % (itemsProcessed, len(results), float(itemsProcessed)*100/len(results))
            sys.stdout.write('%s\r' % statusStr)
            sys.stdout.flush()
            
            genomeId, seqLen, comp, cont = simId.split('-')
            genomeIds.add(genomeId)
            expCondStr = str(float(comp)) + '-' + str(float(cont)) + '-' + str(int(seqLen))
            
            comps.add(float(comp))
            conts.add(float(cont))
            seqLens.add(int(seqLen))
            
            compDataDict[expCondStr]['best'] += results[simId][2]
            compDataDict[expCondStr]['domain'] += results[simId][6]
            compDataDict[expCondStr]['selected'] += results[simId][10]
            
            for dComp in results[simId][2]:
                compOutliers[expCondStr] += [[dComp, genomeId]]
            
            contDataDict[expCondStr]['best'] += results[simId][3]
            contDataDict[expCondStr]['domain'] += results[simId][7]
            contDataDict[expCondStr]['selected'] += results[simId][11]
            
            for dCont in results[simId][3]:
                contOutliers[expCondStr] += [[dCont, genomeId]]
                
        print '  There are %d unique genomes.' % len(genomeIds)
              
        sys.stdout.write('\n')
        
        print '    There are %d experimental conditions.' % (len(compDataDict))
                
        # plot data
        print '  Plotting results.'
        compData = []
        contData = []
        rowLabels = []
        
        foutComp = open('./simulations/simulation.scaffolds.draft.comp_outliers.domain.tsv', 'w')
        foutCont = open('./simulations/simulation.scaffolds.draft.cont_outliers.domain.tsv', 'w')
        for comp in self.compsToConsider:
            for cont in self.contsToConsider:
                for msStr in ['best', 'selected', 'domain']:
                    for seqLen in [20000]: 
                        rowLabels.append(msStr +': %d%%, %d%%' % (comp*100, cont*100))
                        
                        expCondStr = str(comp) + '-' + str(cont) + '-' + str(seqLen)
                        compData.append(compDataDict[expCondStr][msStr])
                        contData.append(contDataDict[expCondStr][msStr])  
                    
                # report completenes outliers
                foutComp.write(expCondStr)

                compOutliers[expCondStr].sort()
                
                dComps = array([r[0] for r in compOutliers[expCondStr]])
                perc1 = scoreatpercentile(dComps, 1)
                perc99 = scoreatpercentile(dComps, 99)
                print expCondStr, perc1, perc99
                
                foutComp.write('\t%.2f\t%.2f' % (perc1, perc99))
                
                outliers = []
                for item in compOutliers[expCondStr]:
                    if item[0] < perc1 or item[0] > perc99:
                        outliers.append(item[1])
                        
                outlierCount = Counter(outliers)
                for genomeId, count in outlierCount.most_common():
                    foutComp.write('\t' + genomeId + ': ' + str(count))
                foutComp.write('\n')
                
                # report contamination outliers
                foutCont.write(expCondStr)

                contOutliers[expCondStr].sort()
                
                dConts = array([r[0] for r in contOutliers[expCondStr]])
                perc1 = scoreatpercentile(dConts, 1)
                perc99 = scoreatpercentile(dConts, 99)
                
                foutCont.write('\t%.2f\t%.2f' % (perc1, perc99))
                
                outliers = []
                for item in contOutliers[expCondStr]:
                    if item[0] < perc1 or item[0] > perc99:
                        outliers.append(item[1])
                        
                outlierCount = Counter(outliers)
                for genomeId, count in outlierCount.most_common():
                    foutCont.write('\t' + genomeId + ': ' + str(count))
                foutCont.write('\n')
                
        foutComp.close()
        foutCont.close()
                               
        print 'best:\t%.2f\t%.2f' % (mean(abs(array(compData[0::3]))), mean(abs(array(contData[0::3]))))
        print 'selected:\t%.2f\t%.2f' % (mean(abs(array(compData[1::3]))), mean(abs(array(contData[1::3]))))   
        print 'domain:\t%.2f\t%.2f' % (mean(abs(array(compData[2::3]))), mean(abs(array(contData[2::3]))))   

        boxPlot = BoxPlot()
        plotFilename = self.plotPrefix + '.conditions.png'
        boxPlot.plot(plotFilename, compData, contData, rowLabels, 
                        r'$\Delta$' + ' % Completion', 'Simulation Conditions', 
                        r'$\Delta$' + ' % Contamination', None,
                        rowsPerCategory = 3, dpi = self.dpi)
        
        
        # print table of results 
        tableOut = open(self.simCompareConditionOut, 'w')
        tableOut.write('Comp. (%)\tCont. (%)\tbest (5kb)\t\tselected (5kb)\t\tdomain (5kb)\t\tbest (20kb)\t\tselected (20kb)\t\tdomain (20kb)\t\tbest (50kb)\t\tselected (50kb)\t\tdomain (50kb)\n')
        
        avgComp = defaultdict(lambda : defaultdict(list))
        avgCont = defaultdict(lambda : defaultdict(list))
        for comp in [0.5, 0.7, 0.8, 0.9, 0.95, 1.0]:
            for cont in [0.0, 0.05, 0.1, 0.15, 0.2]:
                
                tableOut.write('%d\t%d' % (comp*100, cont*100))
                
                for seqLen in [5000, 20000, 50000]:
                    expCondStr = str(comp) + '-' + str(cont) + '-' + str(seqLen)
                   
                    meanCompD = mean(abs(array(compDataDict[expCondStr]['domain'])))
                    stdCompD = std(abs(array(compDataDict[expCondStr]['domain'])))
                    meanContD = mean(abs(array(contDataDict[expCondStr]['domain'])))
                    stdContD = std(abs(array(contDataDict[expCondStr]['domain'])))
                    
                    avgComp[seqLen]['domain'] += compDataDict[expCondStr]['domain']
                    avgCont[seqLen]['domain'] += contDataDict[expCondStr]['domain']
                    
                    meanCompS = mean(abs(array(compDataDict[expCondStr]['selected'])))
                    stdCompS = std(abs(array(compDataDict[expCondStr]['selected'])))
                    meanContS = mean(abs(array(contDataDict[expCondStr]['selected'])))
                    stdContS = std(abs(array(contDataDict[expCondStr]['selected'])))
                    
                    avgComp[seqLen]['selected'] += compDataDict[expCondStr]['selected']
                    avgCont[seqLen]['selected'] += contDataDict[expCondStr]['selected']
                    
                    meanCompB = mean(abs(array(compDataDict[expCondStr]['best'])))
                    stdCompB = std(abs(array(compDataDict[expCondStr]['best'])))
                    meanContB = mean(abs(array(contDataDict[expCondStr]['best'])))
                    stdContB = std(abs(array(contDataDict[expCondStr]['best'])))
                    
                    avgComp[seqLen]['best'] += compDataDict[expCondStr]['best']
                    avgCont[seqLen]['best'] += contDataDict[expCondStr]['best']
                    
                    tableOut.write('\t%.1f\t%.1f\t%.1f\t%.1f\t%.1f\t%.1f' % (meanCompD, meanCompS, meanCompB, meanContD, meanContS, meanContB))
                tableOut.write('\n')
                
        tableOut.write('\tAverage:')
        for seqLen in [5000, 20000, 50000]: 
            meanCompD = mean(abs(array(avgComp[seqLen]['domain'])))
            stdCompD = std(abs(array(avgComp[seqLen]['domain'])))
            meanContD = mean(abs(array(avgCont[seqLen]['domain'])))
            stdContD = std(abs(array(avgCont[seqLen]['domain'])))
            
            meanCompS = mean(abs(array(avgComp[seqLen]['selected'])))
            stdCompS = std(abs(array(avgComp[seqLen]['selected'])))
            meanContS = mean(abs(array(avgCont[seqLen]['selected'])))
            stdContS = std(abs(array(avgCont[seqLen]['selected'])))
            
            meanCompB = mean(abs(array(avgComp[seqLen]['best'])))
            stdCompB = std(abs(array(avgComp[seqLen]['best'])))
            meanContB = mean(abs(array(avgCont[seqLen]['best'])))
            stdContB = std(abs(array(avgCont[seqLen]['best'])))
            
            tableOut.write('\t%.1f\t%.1f\t%.1f\t%.1f\t%.1f\t%.1f' % (meanCompD, meanCompS, meanCompB, meanContD, meanContS, meanContB))
                        
        tableOut.write('\n')     
                
        tableOut.close()
        
    def taxonomicPlots(self, results):
        # summarize results for different taxonomic groups  
        print '  Tabulating results for taxonomic groups.'
        
        metadata = self.img.genomeMetadata()
        
        itemsProcessed = 0      
        compDataDict = defaultdict(lambda : defaultdict(list))
        contDataDict = defaultdict(lambda : defaultdict(list))
        comps = set()
        conts = set()
        seqLens = set()
        
        ranksToProcess = 3
        taxaByRank = [set() for _ in xrange(0, ranksToProcess)]
        
        overallComp = []
        overallCont = []
                
        genomeInTaxon = defaultdict(set)
        testCases = 0
        for simId in results:
            itemsProcessed += 1
            statusStr = '    Finished processing %d of %d (%.2f%%) test cases.' % (itemsProcessed, len(results), float(itemsProcessed)*100/len(results))
            sys.stdout.write('%s\r' % statusStr)
            sys.stdout.flush()
            
            genomeId, seqLen, comp, cont = simId.split('-')
            
            if seqLen != '20000':
                continue
            
            if str(float(comp)) in ['0.5', '0.7', '0.8', '0.9'] and str(float(cont)) in ['0.05', '0.10', '0.1', '0.15']:
                print comp, cont
                taxonomy = metadata[genomeId]['taxonomy']
                
                testCases += 1
                
                comps.add(float(comp))
                conts.add(float(cont))
                seqLens.add(int(seqLen))
                
                overallComp += results[simId][10]
                overallCont += results[simId][11]
                
                for r in xrange(0, ranksToProcess):
                    taxon = taxonomy[r]
                    
                    if r == 0 and taxon == 'unclassified':
                        print '*****************************Unclassified at domain-level*****************'
                        continue
                    
                    if taxon == 'unclassified':
                        continue
                    
                    taxon = rankPrefixes[r] + taxon
                    
                    taxaByRank[r].add(taxon)
                                                    
                    compDataDict[taxon]['best'] += results[simId][2]
                    compDataDict[taxon]['domain'] += results[simId][6]
                    compDataDict[taxon]['selected'] += results[simId][10]
                    
                    contDataDict[taxon]['best'] += results[simId][3]
                    contDataDict[taxon]['domain'] += results[simId][7]
                    contDataDict[taxon]['selected'] += results[simId][11]
                    
                    genomeInTaxon[taxon].add(genomeId)
            
        sys.stdout.write('\n')
        
        print 'Test cases', testCases
        
        print ''        
        print 'Creating plots for:'
        print '  comps = ', comps
        print '  conts = ', conts
        
        print ''
        print '    There are %d taxa.' % (len(compDataDict))
        
        print ''
        print '  Overall bias:'
        print '    Selected comp: %.2f' % mean(overallComp)
        print '    Selected cont: %.2f' % mean(overallCont)
        
        # get list of ordered taxa by rank
        orderedTaxa = []
        for taxa in taxaByRank:
            orderedTaxa += sorted(taxa)
                
        # plot data
        print '  Plotting results.'
        compData = []
        contData = []
        rowLabels = []
        for taxon in orderedTaxa:
            for msStr in ['best', 'selected', 'domain']:
                numGenomes = len(genomeInTaxon[taxon])
                if numGenomes < 10: # skip groups with only a few genomes
                    continue
                
                rowLabels.append(msStr + ': ' + taxon + ' (' + str(numGenomes) + ')')
                compData.append(compDataDict[taxon][msStr])
                contData.append(contDataDict[taxon][msStr])        
                
        for i, rowLabel in enumerate(rowLabels):
            print rowLabel + '\t%.2f\t%.2f' % (mean(abs(array(compData[i]))), mean(abs(array(contData[i]))))            
                  
        # print taxonomic table of results organized by class
        taxonomyTableOut = open(self.simCompareTaxonomyTableOut, 'w')
        for taxon in orderedTaxa:
            numGenomes = len(genomeInTaxon[taxon])
            if numGenomes < 2: # skip groups with only a few genomes
                continue
                
            taxonomyTableOut.write(taxon + '\t' + str(numGenomes))
            for msStr in ['domain', 'selected']:                
                meanTaxonComp = mean(abs(array(compDataDict[taxon][msStr])))
                stdTaxonComp = std(abs(array(compDataDict[taxon][msStr])))
                meanTaxonCont = mean(abs(array(contDataDict[taxon][msStr])))
                stdTaxonCont = std(abs(array(contDataDict[taxon][msStr])))
                
                taxonomyTableOut.write('\t%.1f +/- %.2f\t%.1f +/- %.2f' % (meanTaxonComp, stdTaxonComp, meanTaxonCont, stdTaxonCont))
            taxonomyTableOut.write('\n')
        taxonomyTableOut.close()
        
        # create box plot
        boxPlot = BoxPlot()
        plotFilename = self.plotPrefix +  '.taxonomy.png'
        boxPlot.plot(plotFilename, compData, contData, rowLabels, 
                        r'$\Delta$' + ' % Completion', None, 
                        r'$\Delta$' + ' % Contamination', None,
                        rowsPerCategory = 3, dpi = self.dpi)
    
    
    def refinementPlots(self, results):
        # summarize results for different CheckM refinements 
        print '  Tabulating results for different refinements.'
        
        metadata = self.img.genomeMetadata()
        
        itemsProcessed = 0      
        compDataDict = defaultdict(lambda : defaultdict(list))
        contDataDict = defaultdict(lambda : defaultdict(list))
        comps = set()
        conts = set()
        seqLens = set()
        
        ranksToProcess = 3
        taxaByRank = [set() for _ in xrange(0, ranksToProcess)]
        
        overallCompIM = []
        overallContIM = [] 
        
        overallCompMS = []
        overallContMS = [] 
        
        overallCompRMS = []
        overallContRMS = [] 
        
        genomeInTaxon = defaultdict(set)
        
        testCases = 0
        for simId in results:
            itemsProcessed += 1
            statusStr = '    Finished processing %d of %d (%.2f%%) test cases.' % (itemsProcessed, len(results), float(itemsProcessed)*100/len(results))
            sys.stdout.write('%s\r' % statusStr)
            sys.stdout.flush()
            
            genomeId, seqLen, comp, cont = simId.split('-')
            taxonomy = metadata[genomeId]['taxonomy']
            
            if float(comp) < 0.7 or float(cont) > 0.1:
                continue
            
            comps.add(float(comp))
            conts.add(float(cont))
            seqLens.add(int(seqLen))
            
            overallCompIM.append(results[simId][8])
            overallContIM.append(results[simId][9])
            
            overallCompMS.append(results[simId][10])
            overallContMS.append(results[simId][11])
            
            overallCompRMS.append(results[simId][12])
            overallContRMS.append(results[simId][13])
            
            for r in xrange(0, ranksToProcess):
                taxon = taxonomy[r]
                
                if taxon == 'unclassified':
                    continue
                
                taxaByRank[r].add(taxon)
                
                compDataDict[taxon]['IM'] += results[simId][8]
                compDataDict[taxon]['MS'] += results[simId][10]
                compDataDict[taxon]['RMS'] += results[simId][12]
                
                contDataDict[taxon]['IM'] += results[simId][9]
                contDataDict[taxon]['MS'] += results[simId][11]
                contDataDict[taxon]['RMS'] += results[simId][13]
                                
                genomeInTaxon[taxon].add(genomeId)
            
        sys.stdout.write('\n')
        
        print 'Creating plots for:'
        print '  comps = ', comps
        print '  conts = ', conts
        
        print ''
        print '    There are %d taxon.' % (len(compDataDict))
        print ''
        print 'Percentage change MS-IM comp: %.4f' % ((mean(abs(array(overallCompMS))) - mean(abs(array(overallCompIM)))) * 100 / mean(abs(array(overallCompIM))))
        print 'Percentage change MS-IM cont: %.4f' % ((mean(abs(array(overallContMS))) - mean(abs(array(overallContIM)))) * 100 / mean(abs(array(overallContIM))))
        print ''
        print 'Percentage change RMS-MS comp: %.4f' % ((mean(abs(array(overallCompRMS))) - mean(abs(array(overallCompMS)))) * 100 / mean(abs(array(overallCompIM))))
        print 'Percentage change RMS-MS cont: %.4f' % ((mean(abs(array(overallContRMS))) - mean(abs(array(overallContMS)))) * 100 / mean(abs(array(overallContIM))))
        
        print ''
        
        # get list of ordered taxa by rank
        orderedTaxa = []
        for taxa in taxaByRank:
            orderedTaxa += sorted(taxa)
             
        # print table of results organized by class
        refinmentTableOut = open(self.simCompareRefinementTableOut, 'w')
        for taxon in orderedTaxa:
            numGenomes = len(genomeInTaxon[taxon])
            if numGenomes < 2: # skip groups with only a few genomes
                continue
                
            refinmentTableOut.write(taxon + '\t' + str(numGenomes))
            for refineStr in ['IM', 'MS']:               
                meanTaxonComp = mean(abs(array(compDataDict[taxon][refineStr])))
                stdTaxonComp = std(abs(array(compDataDict[taxon][refineStr])))
                meanTaxonCont = mean(abs(array(contDataDict[taxon][refineStr])))
                stdTaxonCont = std(abs(array(contDataDict[taxon][refineStr])))
                
                refinmentTableOut.write('\t%.1f +/- %.2f\t%.1f +/- %.2f' % (meanTaxonComp, stdTaxonComp, meanTaxonCont, stdTaxonCont))
            
            perCompChange = (mean(abs(array(compDataDict[taxon]['IM']))) - meanTaxonComp) * 100 / mean(abs(array(compDataDict[taxon]['IM'])))
            perContChange = (mean(abs(array(contDataDict[taxon]['IM']))) - meanTaxonCont) * 100 / mean(abs(array(contDataDict[taxon]['IM'])))
            refinmentTableOut.write('\t%.2f\t%.2f\n' % (perCompChange, perContChange))
        refinmentTableOut.close()
       
        # plot data
        print '  Plotting results.'
        compData = []
        contData = []
        rowLabels = []
        for taxon in orderedTaxa:
            for refineStr in ['RMS', 'MS', 'IM']:
                numGenomes = len(genomeInTaxon[taxon])
                if numGenomes < 10: # skip groups with only a few genomes
                    continue

                rowLabels.append(refineStr + ': ' + taxon + ' (' + str(numGenomes) + ')')
                compData.append(compDataDict[taxon][refineStr])
                contData.append(contDataDict[taxon][refineStr])       
                
        for i, rowLabel in enumerate(rowLabels):
            print rowLabel + '\t%.2f\t%.2f' % (mean(abs(array(compData[i]))), mean(abs(array(contData[i]))))
            
        boxPlot = BoxPlot()
        plotFilename = self.plotPrefix + '.refinements.png'
        boxPlot.plot(plotFilename, compData, contData, rowLabels, 
                        r'$\Delta$' + ' % Completion', None, 
                        r'$\Delta$' + ' % Contamination', None,
                        rowsPerCategory = 3, dpi = self.dpi)
        
    def run(self):
        # read simulation results
        print '  Reading simulation results.'
        results = self.__readResults(self.simCompareFile)
        
        print '\n'         
        #self.markerSets(results)
                   
        print '\n'         
        #self.conditionsPlot(results)
        
        #print '\n'
        self.taxonomicPlots(results)
        
        print '\n'