Code example #1
class SimulationScaffolds(object):
    def __init__(self):
        self.markerSetBuilder = MarkerSetBuilder()
        self.img = IMG('/srv/whitlam/bio/db/checkm/img/img_metadata.tsv',
                       '/srv/whitlam/bio/db/checkm/pfam/tigrfam2pfam.tsv')

        self.contigLens = [5000, 20000, 50000]
        self.percentComps = [0.5, 0.7, 0.8, 0.9, 0.95, 1.0]
        self.percentConts = [0.0, 0.05, 0.1, 0.15, 0.2]

    def __seqLens(self, seqs):
        """Calculate lengths of seqs."""
        genomeSize = 0
        seqLens = {}
        for seqId, seq in seqs.iteritems():
            seqLens[seqId] = len(seq)
            genomeSize += len(seq)

        return seqLens, genomeSize

    def __workerThread(self, tree, metadata, genomeIdsToTest,
                       ubiquityThreshold, singleCopyThreshold, numReplicates,
                       queueIn, queueOut):
        """Process each data item in parallel."""

        while True:
            testGenomeId = queueIn.get(block=True, timeout=None)
            if testGenomeId == None:
                break

            # build marker sets for evaluating test genome
            testNode = tree.find_node_with_taxon_label('IMG_' + testGenomeId)
            binMarkerSets, refinedBinMarkerSet = self.markerSetBuilder.buildBinMarkerSet(
                tree,
                testNode.parent_node,
                ubiquityThreshold,
                singleCopyThreshold,
                bMarkerSet=True,
                genomeIdsToRemove=[testGenomeId])

            # determine distribution of all marker genes within the test genome
            geneDistTable = self.img.geneDistTable(
                [testGenomeId],
                binMarkerSets.getMarkerGenes(),
                spacingBetweenContigs=0)

            # estimate completeness of unmodified genome
            unmodifiedComp = {}
            unmodifiedCont = {}
            for ms in binMarkerSets.markerSetIter():
                hits = {}
                for mg in ms.getMarkerGenes():
                    if mg in geneDistTable[testGenomeId]:
                        hits[mg] = geneDistTable[testGenomeId][mg]
                completeness, contamination = ms.genomeCheck(
                    hits, bIndividualMarkers=True)
                unmodifiedComp[ms.lineageStr] = completeness
                unmodifiedCont[ms.lineageStr] = contamination

            # estimate completion and contamination of genome after subsampling using both the domain and lineage-specific marker sets
            testSeqs = readFasta(
                os.path.join(self.img.genomeDir, testGenomeId,
                             testGenomeId + '.fna'))
            testSeqLens, genomeSize = self.__seqLens(testSeqs)

            for contigLen in self.contigLens:
                for percentComp in self.percentComps:
                    for percentCont in self.percentConts:
                        deltaComp = defaultdict(list)
                        deltaCont = defaultdict(list)
                        deltaCompSet = defaultdict(list)
                        deltaContSet = defaultdict(list)

                        deltaCompRefined = defaultdict(list)
                        deltaContRefined = defaultdict(list)
                        deltaCompSetRefined = defaultdict(list)
                        deltaContSetRefined = defaultdict(list)

                        trueComps = []
                        trueConts = []

                        numDescendants = {}

                        for i in xrange(0, numReplicates):
                            # generate test genome with a specific level of completeness, by randomly sampling scaffolds to remove
                            # (this will sample >= the desired level of completeness)
                            retainedTestSeqs, trueComp = self.markerSetBuilder.sampleGenomeScaffoldsWithoutReplacement(
                                percentComp, testSeqLens, genomeSize)
                            trueComps.append(trueComp)

                            # select a random genome to use as a source of contamination
                            contGenomeId = random.sample(
                                genomeIdsToTest - set([testGenomeId]), 1)[0]
                            contSeqs = readFasta(
                                os.path.join(self.img.genomeDir, contGenomeId,
                                             contGenomeId + '.fna'))
                            contSeqLens, contGenomeSize = self.__seqLens(
                                contSeqs)
                            seqsToRetain, trueRetainedPer = self.markerSetBuilder.sampleGenomeScaffoldsWithoutReplacement(
                                1 - percentCont, contSeqLens, contGenomeSize)

                            contSampledSeqIds = set(
                                contSeqs.keys()).difference(seqsToRetain)
                            trueCont = 100.0 - trueRetainedPer
                            trueConts.append(trueCont)

                            for ms in binMarkerSets.markerSetIter():
                                numDescendants[ms.lineageStr] = ms.numGenomes
                                containedMarkerGenes = defaultdict(list)
                                self.markerSetBuilder.markerGenesOnScaffolds(
                                    ms.getMarkerGenes(), testGenomeId,
                                    retainedTestSeqs, containedMarkerGenes)
                                self.markerSetBuilder.markerGenesOnScaffolds(
                                    ms.getMarkerGenes(), contGenomeId,
                                    contSampledSeqIds, containedMarkerGenes)

                                completeness, contamination = ms.genomeCheck(
                                    containedMarkerGenes,
                                    bIndividualMarkers=True)
                                deltaComp[ms.lineageStr].append(completeness -
                                                                trueComp)
                                deltaCont[ms.lineageStr].append(contamination -
                                                                trueCont)

                                completeness, contamination = ms.genomeCheck(
                                    containedMarkerGenes,
                                    bIndividualMarkers=False)
                                deltaCompSet[ms.lineageStr].append(
                                    completeness - trueComp)
                                deltaContSet[ms.lineageStr].append(
                                    contamination - trueCont)

                            for ms in refinedBinMarkerSet.markerSetIter():
                                containedMarkerGenes = defaultdict(list)
                                self.markerSetBuilder.markerGenesOnScaffolds(
                                    ms.getMarkerGenes(), testGenomeId,
                                    retainedTestSeqs, containedMarkerGenes)
                                self.markerSetBuilder.markerGenesOnScaffolds(
                                    ms.getMarkerGenes(), contGenomeId,
                                    contSampledSeqIds, containedMarkerGenes)

                                completeness, contamination = ms.genomeCheck(
                                    containedMarkerGenes,
                                    bIndividualMarkers=True)
                                deltaCompRefined[ms.lineageStr].append(
                                    completeness - trueComp)
                                deltaContRefined[ms.lineageStr].append(
                                    contamination - trueCont)

                                completeness, contamination = ms.genomeCheck(
                                    containedMarkerGenes,
                                    bIndividualMarkers=False)
                                deltaCompSetRefined[ms.lineageStr].append(
                                    completeness - trueComp)
                                deltaContSetRefined[ms.lineageStr].append(
                                    contamination - trueCont)

                        taxonomy = ';'.join(metadata[testGenomeId]['taxonomy'])
                        queueOut.put(
                            (testGenomeId, contigLen, percentComp, percentCont,
                             taxonomy, numDescendants, unmodifiedComp,
                             unmodifiedCont, deltaComp, deltaCont,
                             deltaCompSet, deltaContSet, deltaCompRefined,
                             deltaContRefined, deltaCompSetRefined,
                             deltaContSetRefined, trueComps, trueConts))

    def __writerThread(self, numTestGenomes, writerQueue):
        """Store or write results of worker threads in a single thread."""

        summaryOut = open(
            '/tmp/simulation.random_scaffolds.w_refinement_50.draft.summary.tsv',
            'w')
        summaryOut.write('Genome Id\tContig len\t% comp\t% cont')
        summaryOut.write('\tTaxonomy\tMarker set\t# descendants')
        summaryOut.write('\tUnmodified comp\tUnmodified cont')
        summaryOut.write('\tIM comp\tIM comp std\tIM cont\tIM cont std')
        summaryOut.write('\tMS comp\tMS comp std\tMS cont\tMS cont std')
        summaryOut.write('\tRIM comp\tRIM comp std\tRIM cont\tRIM cont std')
        summaryOut.write('\tRMS comp\tRMS comp std\tRMS cont\tRMS cont std\n')

        fout = gzip.open(
            '/tmp/simulation.random_scaffolds.w_refinement_50.draft.tsv.gz',
            'wb')
        fout.write('Genome Id\tContig len\t% comp\t% cont')
        fout.write('\tTaxonomy\tMarker set\t# descendants')
        fout.write('\tUnmodified comp\tUnmodified cont')
        fout.write('\tIM comp\tIM cont')
        fout.write('\tMS comp\tMS cont')
        fout.write('\tRIM comp\tRIM cont')
        fout.write('\tRMS comp\tRMS cont\tTrue Comp\tTrue Cont\n')

        testsPerGenome = len(self.contigLens) * len(self.percentComps) * len(
            self.percentConts)

        itemsProcessed = 0
        while True:
            testGenomeId, contigLen, percentComp, percentCont, taxonomy, numDescendants, unmodifiedComp, unmodifiedCont, deltaComp, deltaCont, deltaCompSet, deltaContSet, deltaCompRefined, deltaContRefined, deltaCompSetRefined, deltaContSetRefined, trueComps, trueConts = writerQueue.get(
                block=True, timeout=None)
            if testGenomeId == None:
                break

            itemsProcessed += 1
            statusStr = '    Finished processing %d of %d (%.2f%%) test cases.' % (
                itemsProcessed, numTestGenomes * testsPerGenome,
                float(itemsProcessed) * 100 /
                (numTestGenomes * testsPerGenome))
            sys.stdout.write('%s\r' % statusStr)
            sys.stdout.flush()

            for markerSetId in unmodifiedComp:
                summaryOut.write(testGenomeId + '\t%d\t%.2f\t%.2f' %
                                 (contigLen, percentComp, percentCont))
                summaryOut.write('\t' + taxonomy + '\t' + markerSetId + '\t' +
                                 str(numDescendants[markerSetId]))
                summaryOut.write(
                    '\t%.3f\t%.3f' %
                    (unmodifiedComp[markerSetId], unmodifiedCont[markerSetId]))
                summaryOut.write('\t%.3f\t%.3f' %
                                 (mean(abs(deltaComp[markerSetId])),
                                  std(abs(deltaComp[markerSetId]))))
                summaryOut.write('\t%.3f\t%.3f' %
                                 (mean(abs(deltaCont[markerSetId])),
                                  std(abs(deltaCont[markerSetId]))))
                summaryOut.write('\t%.3f\t%.3f' %
                                 (mean(abs(deltaCompSet[markerSetId])),
                                  std(abs(deltaCompSet[markerSetId]))))
                summaryOut.write('\t%.3f\t%.3f' %
                                 (mean(abs(deltaContSet[markerSetId])),
                                  std(abs(deltaContSet[markerSetId]))))
                summaryOut.write('\t%.3f\t%.3f' %
                                 (mean(abs(deltaCompRefined[markerSetId])),
                                  std(abs(deltaCompRefined[markerSetId]))))
                summaryOut.write('\t%.3f\t%.3f' %
                                 (mean(abs(deltaContRefined[markerSetId])),
                                  std(abs(deltaContRefined[markerSetId]))))
                summaryOut.write('\t%.3f\t%.3f' %
                                 (mean(abs(deltaCompSetRefined[markerSetId])),
                                  std(abs(deltaCompSetRefined[markerSetId]))))
                summaryOut.write('\t%.3f\t%.3f' %
                                 (mean(abs(deltaContSetRefined[markerSetId])),
                                  std(abs(deltaContSetRefined[markerSetId]))))
                summaryOut.write('\n')

                fout.write(testGenomeId + '\t%d\t%.2f\t%.2f' %
                           (contigLen, percentComp, percentCont))
                fout.write('\t' + taxonomy + '\t' + markerSetId + '\t' +
                           str(numDescendants[markerSetId]))
                fout.write(
                    '\t%.3f\t%.3f' %
                    (unmodifiedComp[markerSetId], unmodifiedCont[markerSetId]))
                fout.write('\t%s' % ','.join(map(str, deltaComp[markerSetId])))
                fout.write('\t%s' % ','.join(map(str, deltaCont[markerSetId])))
                fout.write('\t%s' %
                           ','.join(map(str, deltaCompSet[markerSetId])))
                fout.write('\t%s' %
                           ','.join(map(str, deltaContSet[markerSetId])))
                fout.write('\t%s' %
                           ','.join(map(str, deltaCompRefined[markerSetId])))
                fout.write('\t%s' %
                           ','.join(map(str, deltaContRefined[markerSetId])))
                fout.write(
                    '\t%s' %
                    ','.join(map(str, deltaCompSetRefined[markerSetId])))
                fout.write(
                    '\t%s' %
                    ','.join(map(str, deltaContSetRefined[markerSetId])))
                fout.write('\t%s' % ','.join(map(str, trueComps)))
                fout.write('\t%s' % ','.join(map(str, trueConts)))
                fout.write('\n')

        summaryOut.close()
        fout.close()

        sys.stdout.write('\n')

    def run(self, ubiquityThreshold, singleCopyThreshold, numReplicates,
            minScaffolds, numThreads):
        random.seed(0)

        print '\n  Reading reference genome tree.'
        treeFile = os.path.join('/srv', 'db', 'checkm', 'genome_tree',
                                'genome_tree_prok.refpkg',
                                'genome_tree.final.tre')
        tree = dendropy.Tree.get_from_path(treeFile,
                                           schema='newick',
                                           as_rooted=True,
                                           preserve_underscores=True)

        print '    Number of taxa in tree: %d' % (len(tree.leaf_nodes()))

        genomesInTree = set()
        for leaf in tree.leaf_iter():
            genomesInTree.add(leaf.taxon.label.replace('IMG_', ''))

        # get all draft genomes consisting of a user-specified minimum number of scaffolds
        print ''
        metadata = self.img.genomeMetadata()
        print '  Total genomes: %d' % len(metadata)

        draftGenomeIds = genomesInTree - self.img.filterGenomeIds(
            genomesInTree, metadata, 'status', 'Finished')
        print '  Number of draft genomes: %d' % len(draftGenomeIds)

        genomeIdsToTest = set()
        for genomeId in draftGenomeIds:
            if metadata[genomeId]['scaffold count'] >= minScaffolds:
                genomeIdsToTest.add(genomeId)

        print '  Number of draft genomes with >= %d scaffolds: %d' % (
            minScaffolds, len(genomeIdsToTest))

        print ''
        start = time.time()
        self.markerSetBuilder.readLineageSpecificGenesToRemove()
        end = time.time()
        print '    readLineageSpecificGenesToRemove: %.2f' % (end - start)

        print '  Pre-computing genome information for calculating marker sets:'
        start = time.time()
        self.markerSetBuilder.precomputeGenomeFamilyScaffolds(metadata.keys())
        end = time.time()
        print '    precomputeGenomeFamilyScaffolds: %.2f' % (end - start)

        start = time.time()
        self.markerSetBuilder.cachedGeneCountTable = self.img.geneCountTable(
            metadata.keys())
        end = time.time()
        print '    globalGeneCountTable: %.2f' % (end - start)

        start = time.time()
        self.markerSetBuilder.precomputeGenomeSeqLens(metadata.keys())
        end = time.time()
        print '    precomputeGenomeSeqLens: %.2f' % (end - start)

        start = time.time()
        self.markerSetBuilder.precomputeGenomeFamilyPositions(
            metadata.keys(), 0)
        end = time.time()
        print '    precomputeGenomeFamilyPositions: %.2f' % (end - start)

        print ''
        print '  Evaluating %d test genomes.' % len(genomeIdsToTest)

        workerQueue = mp.Queue()
        writerQueue = mp.Queue()

        for testGenomeId in list(genomeIdsToTest):
            workerQueue.put(testGenomeId)

        for _ in range(numThreads):
            workerQueue.put(None)

        workerProc = [
            mp.Process(target=self.__workerThread,
                       args=(tree, metadata, genomeIdsToTest,
                             ubiquityThreshold, singleCopyThreshold,
                             numReplicates, workerQueue, writerQueue))
            for _ in range(numThreads)
        ]
        writeProc = mp.Process(target=self.__writerThread,
                               args=(len(genomeIdsToTest), writerQueue))

        writeProc.start()

        for p in workerProc:
            p.start()

        for p in workerProc:
            p.join()

        writerQueue.put((None, None, None, None, None, None, None, None, None,
                         None, None, None, None, None, None, None, None, None))
        writeProc.join()
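
A minimal driver sketch for this class, assuming the module-level imports used above (MarkerSetBuilder, IMG, readFasta, dendropy, numpy, etc.) are available and that the hard-coded /srv database paths exist; the threshold, replicate, and thread values are placeholders, not values prescribed by the project.

# Hypothetical entry point; all parameter values below are illustrative only.
if __name__ == '__main__':
    simulation = SimulationScaffolds()
    simulation.run(ubiquityThreshold=0.97,
                   singleCopyThreshold=0.97,
                   numReplicates=100,
                   minScaffolds=20,
                   numThreads=16)
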
Code example #2
File: simulation.py  Project: HadrienG/CheckM
class Simulation(object):
    def __init__(self):
        self.markerSetBuilder = MarkerSetBuilder()
        self.img = IMG('/srv/whitlam/bio/db/checkm/img/img_metadata.tsv', '/srv/whitlam/bio/db/checkm/pfam/tigrfam2pfam.tsv')

        self.contigLens = [1000, 2000, 5000, 10000, 20000, 50000]
        self.percentComps = [0.5, 0.7, 0.8, 0.9, 0.95, 1.0]
        self.percentConts = [0.0, 0.05, 0.1, 0.15, 0.2]

    def __workerThread(self, tree, metadata, ubiquityThreshold, singleCopyThreshold, numReplicates, queueIn, queueOut):
        """Process each data item in parallel."""

        while True:
            testGenomeId = queueIn.get(block=True, timeout=None)
            if testGenomeId == None:
                break

            # build marker sets for evaluating test genome
            testNode = tree.find_node_with_taxon_label('IMG_' + testGenomeId)
            binMarkerSets, refinedBinMarkerSet = self.markerSetBuilder.buildBinMarkerSet(tree, testNode.parent_node, ubiquityThreshold, singleCopyThreshold, bMarkerSet=True, genomeIdsToRemove=[testGenomeId])
            #!!!binMarkerSets, refinedBinMarkerSet = self.markerSetBuilder.buildDomainMarkerSet(tree, testNode.parent_node, ubiquityThreshold, singleCopyThreshold, bMarkerSet = False, genomeIdsToRemove = [testGenomeId])

            # determine distribution of all marker genes within the test genome
            geneDistTable = self.img.geneDistTable([testGenomeId], binMarkerSets.getMarkerGenes(), spacingBetweenContigs=0)

            print('# marker genes: ', len(binMarkerSets.getMarkerGenes()))
            print('# genes in table: ', len(geneDistTable[testGenomeId]))

            # estimate completeness of unmodified genome
            unmodifiedComp = {}
            unmodifiedCont = {}
            for ms in binMarkerSets.markerSetIter():
                hits = {}
                for mg in ms.getMarkerGenes():
                    if mg in geneDistTable[testGenomeId]:
                        hits[mg] = geneDistTable[testGenomeId][mg]
                completeness, contamination = ms.genomeCheck(hits, bIndividualMarkers=True)
                unmodifiedComp[ms.lineageStr] = completeness
                unmodifiedCont[ms.lineageStr] = contamination

            print(completeness, contamination)

            # estimate completion and contamination of genome after subsampling using both the domain and lineage-specific marker sets
            genomeSize = readFastaBases(os.path.join(self.img.genomeDir, testGenomeId, testGenomeId + '.fna'))
            print('genomeSize', genomeSize)

            for contigLen in self.contigLens:
                for percentComp in self.percentComps:
                    for percentCont in self.percentConts:
                        deltaComp = defaultdict(list)
                        deltaCont = defaultdict(list)
                        deltaCompSet = defaultdict(list)
                        deltaContSet = defaultdict(list)

                        deltaCompRefined = defaultdict(list)
                        deltaContRefined = defaultdict(list)
                        deltaCompSetRefined = defaultdict(list)
                        deltaContSetRefined = defaultdict(list)

                        trueComps = []
                        trueConts = []

                        numDescendants = {}

                        for _ in range(0, numReplicates):
                            trueComp, trueCont, startPartialGenomeContigs = self.markerSetBuilder.sampleGenome(genomeSize, percentComp, percentCont, contigLen)
                            print(contigLen, trueComp, trueCont, len(startPartialGenomeContigs))

                            trueComps.append(trueComp)
                            trueConts.append(trueCont)

                            for ms in binMarkerSets.markerSetIter():
                                numDescendants[ms.lineageStr] = ms.numGenomes

                                containedMarkerGenes = self.markerSetBuilder.containedMarkerGenes(ms.getMarkerGenes(), geneDistTable[testGenomeId], startPartialGenomeContigs, contigLen)
                                completeness, contamination = ms.genomeCheck(containedMarkerGenes, bIndividualMarkers=True)
                                deltaComp[ms.lineageStr].append(completeness - trueComp)
                                deltaCont[ms.lineageStr].append(contamination - trueCont)

                                completeness, contamination = ms.genomeCheck(containedMarkerGenes, bIndividualMarkers=False)
                                deltaCompSet[ms.lineageStr].append(completeness - trueComp)
                                deltaContSet[ms.lineageStr].append(contamination - trueCont)

                            for ms in refinedBinMarkerSet.markerSetIter():
                                containedMarkerGenes = self.markerSetBuilder.containedMarkerGenes(ms.getMarkerGenes(), geneDistTable[testGenomeId], startPartialGenomeContigs, contigLen)
                                completeness, contamination = ms.genomeCheck(containedMarkerGenes, bIndividualMarkers=True)
                                deltaCompRefined[ms.lineageStr].append(completeness - trueComp)
                                deltaContRefined[ms.lineageStr].append(contamination - trueCont)

                                completeness, contamination = ms.genomeCheck(containedMarkerGenes, bIndividualMarkers=False)
                                deltaCompSetRefined[ms.lineageStr].append(completeness - trueComp)
                                deltaContSetRefined[ms.lineageStr].append(contamination - trueCont)

                        taxonomy = ';'.join(metadata[testGenomeId]['taxonomy'])
                        queueOut.put((testGenomeId, contigLen, percentComp, percentCont, taxonomy, numDescendants, unmodifiedComp, unmodifiedCont, trueComps, trueConts, deltaComp, deltaCont, deltaCompSet, deltaContSet, deltaCompRefined, deltaContRefined, deltaCompSetRefined, deltaContSetRefined, trueComps, trueConts))

    def __writerThread(self, numTestGenomes, writerQueue):
        """Store or write results of worker threads in a single thread."""

        # summaryOut = open('/tmp/simulation.draft.summary.w_refinement_50.tsv', 'w')
        summaryOut = open('/tmp/simulation.summary.testing.tsv', 'w')
        summaryOut.write('Genome Id\tContig len\t% comp\t% cont')
        summaryOut.write('\tTaxonomy\tMarker set\t# descendants')
        summaryOut.write('\tUnmodified comp\tUnmodified cont\tTrue comp\tTrue cont')
        summaryOut.write('\tIM comp\tIM comp std\tIM cont\tIM cont std')
        summaryOut.write('\tMS comp\tMS comp std\tMS cont\tMS cont std')
        summaryOut.write('\tRIM comp\tRIM comp std\tRIM cont\tRIM cont std')
        summaryOut.write('\tRMS comp\tRMS comp std\tRMS cont\tRMS cont std\n')

        # fout = gzip.open('/tmp/simulation.draft.w_refinement_50.tsv.gz', 'wb')
        fout = gzip.open('/tmp/simulation.testing.tsv.gz', 'wb')
        fout.write('Genome Id\tContig len\t% comp\t% cont')
        fout.write('\tTaxonomy\tMarker set\t# descendants')
        fout.write('\tUnmodified comp\tUnmodified cont\tTrue comp\tTrue cont')
        fout.write('\tIM comp\tIM cont')
        fout.write('\tMS comp\tMS cont')
        fout.write('\tRIM comp\tRIM cont')
        fout.write('\tRMS comp\tRMS cont\tTrue Comp\tTrue Cont\n')

        testsPerGenome = len(self.contigLens) * len(self.percentComps) * len(self.percentConts)

        itemsProcessed = 0
        while True:
            testGenomeId, contigLen, percentComp, percentCont, taxonomy, numDescendants, unmodifiedComp, unmodifiedCont, trueComps, trueConts, deltaComp, deltaCont, deltaCompSet, deltaContSet, deltaCompRefined, deltaContRefined, deltaCompSetRefined, deltaContSetRefined, trueComps, trueConts = writerQueue.get(block=True, timeout=None)
            if testGenomeId == None:
                break

            itemsProcessed += 1
            statusStr = '    Finished processing %d of %d (%.2f%%) test cases.' % (itemsProcessed, numTestGenomes * testsPerGenome, float(itemsProcessed) * 100 / (numTestGenomes * testsPerGenome))
            sys.stdout.write('%s\r' % statusStr)
            sys.stdout.flush()

            for markerSetId in unmodifiedComp:
                summaryOut.write(testGenomeId + '\t%d\t%.2f\t%.2f' % (contigLen, percentComp, percentCont))
                summaryOut.write('\t' + taxonomy + '\t' + markerSetId + '\t' + str(numDescendants[markerSetId]))
                summaryOut.write('\t%.3f\t%.3f' % (unmodifiedComp[markerSetId], unmodifiedCont[markerSetId]))
                summaryOut.write('\t%.3f\t%.3f' % (mean(trueComps), mean(trueConts)))  # 'True comp' and 'True cont' columns
                summaryOut.write('\t%.3f\t%.3f' % (mean(abs(deltaComp[markerSetId])), std(abs(deltaComp[markerSetId]))))
                summaryOut.write('\t%.3f\t%.3f' % (mean(abs(deltaCont[markerSetId])), std(abs(deltaCont[markerSetId]))))
                summaryOut.write('\t%.3f\t%.3f' % (mean(abs(deltaCompSet[markerSetId])), std(abs(deltaCompSet[markerSetId]))))
                summaryOut.write('\t%.3f\t%.3f' % (mean(abs(deltaContSet[markerSetId])), std(abs(deltaContSet[markerSetId]))))
                summaryOut.write('\t%.3f\t%.3f' % (mean(abs(deltaCompRefined[markerSetId])), std(abs(deltaCompRefined[markerSetId]))))
                summaryOut.write('\t%.3f\t%.3f' % (mean(abs(deltaContRefined[markerSetId])), std(abs(deltaContRefined[markerSetId]))))
                summaryOut.write('\t%.3f\t%.3f' % (mean(abs(deltaCompSetRefined[markerSetId])), std(abs(deltaCompSetRefined[markerSetId]))))
                summaryOut.write('\t%.3f\t%.3f' % (mean(abs(deltaContSetRefined[markerSetId])), std(abs(deltaContSetRefined[markerSetId]))))
                summaryOut.write('\n')

                fout.write(testGenomeId + '\t%d\t%.2f\t%.2f' % (contigLen, percentComp, percentCont))
                fout.write('\t' + taxonomy + '\t' + markerSetId + '\t' + str(numDescendants[markerSetId]))
                fout.write('\t%.3f\t%.3f' % (unmodifiedComp[markerSetId], unmodifiedCont[markerSetId]))
                fout.write('\t%s' % ','.join(map(str, trueComps)))
                fout.write('\t%s' % ','.join(map(str, trueConts)))
                fout.write('\t%s' % ','.join(map(str, deltaComp[markerSetId])))
                fout.write('\t%s' % ','.join(map(str, deltaCont[markerSetId])))
                fout.write('\t%s' % ','.join(map(str, deltaCompSet[markerSetId])))
                fout.write('\t%s' % ','.join(map(str, deltaContSet[markerSetId])))
                fout.write('\t%s' % ','.join(map(str, deltaCompRefined[markerSetId])))
                fout.write('\t%s' % ','.join(map(str, deltaContRefined[markerSetId])))
                fout.write('\t%s' % ','.join(map(str, deltaCompSetRefined[markerSetId])))
                fout.write('\t%s' % ','.join(map(str, deltaContSetRefined[markerSetId])))
                fout.write('\t%s' % ','.join(map(str, trueComps)))
                fout.write('\t%s' % ','.join(map(str, trueConts)))
                fout.write('\n')

        summaryOut.close()
        fout.close()

        sys.stdout.write('\n')

    def run(self, ubiquityThreshold, singleCopyThreshold, numReplicates, numThreads):
        print('\n  Reading reference genome tree.')
        treeFile = os.path.join('/srv', 'db', 'checkm', 'genome_tree', 'genome_tree_full.refpkg', 'genome_tree.tre')
        tree = dendropy.Tree.get_from_path(treeFile, schema='newick', as_rooted=True, preserve_underscores=True)

        print('    Number of taxa in tree: %d' % (len(tree.leaf_nodes())))

        genomesInTree = set()
        for leaf in tree.leaf_iter():
            genomesInTree.add(leaf.taxon.label.replace('IMG_', ''))

        # get all draft genomes for testing
        print('')
        metadata = self.img.genomeMetadata()
        print('  Total genomes: %d' % len(metadata))

        genomeIdsToTest = genomesInTree - self.img.filterGenomeIds(genomesInTree, metadata, 'status', 'Finished')
        print('  Number of draft genomes: %d' % len(genomeIdsToTest))

        print('')
        print('  Pre-computing genome information for calculating marker sets:')
        start = time.time()
        self.markerSetBuilder.readLineageSpecificGenesToRemove()
        end = time.time()
        print('    readLineageSpecificGenesToRemove: %.2f' % (end - start))


        start = time.time()
        # self.markerSetBuilder.cachedGeneCountTable = self.img.geneCountTable(metadata.keys())
        end = time.time()
        print('    globalGeneCountTable: %.2f' % (end - start))

        start = time.time()
        # self.markerSetBuilder.precomputeGenomeSeqLens(metadata.keys())
        end = time.time()
        print('    precomputeGenomeSeqLens: %.2f' % (end - start))

        start = time.time()
        # self.markerSetBuilder.precomputeGenomeFamilyPositions(metadata.keys(), 0)
        end = time.time()
        print('    precomputeGenomeFamilyPositions: %.2f' % (end - start))

        print('')
        print('  Evaluating %d test genomes.' % len(genomeIdsToTest))
        workerQueue = mp.Queue()
        writerQueue = mp.Queue()

        for testGenomeId in genomeIdsToTest:
            workerQueue.put(testGenomeId)

        for _ in range(numThreads):
            workerQueue.put(None)

        workerProc = [mp.Process(target=self.__workerThread, args=(tree, metadata, ubiquityThreshold, singleCopyThreshold, numReplicates, workerQueue, writerQueue)) for _ in range(numThreads)]
        writeProc = mp.Process(target=self.__writerThread, args=(len(genomeIdsToTest), writerQueue))

        writeProc.start()

        for p in workerProc:
            p.start()

        for p in workerProc:
            p.join()

        writerQueue.put((None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None))
        writeProc.join()
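
Both classes use the same producer/consumer multiprocessing pattern: worker processes pull genome ids from an input queue, a single writer process drains an output queue, and None sentinels signal shutdown (one per worker on the input side, one for the writer once all workers have joined). A self-contained sketch of just that pattern, independent of CheckM, is shown below.

import multiprocessing as mp

def worker(queueIn, queueOut):
    # consume work items until the None sentinel arrives
    while True:
        item = queueIn.get(block=True, timeout=None)
        if item is None:
            break
        queueOut.put(item * item)  # stand-in for the real per-genome work

def writer(numItems, queueOut):
    # a single process owns the output stream; it stops on its own sentinel
    itemsProcessed = 0
    while True:
        result = queueOut.get(block=True, timeout=None)
        if result is None:
            break
        itemsProcessed += 1
        print('processed %d of %d: %s' % (itemsProcessed, numItems, result))

if __name__ == '__main__':
    numThreads = 4
    workerQueue = mp.Queue()
    writerQueue = mp.Queue()

    items = list(range(10))
    for item in items:
        workerQueue.put(item)
    for _ in range(numThreads):
        workerQueue.put(None)  # one sentinel per worker

    workerProc = [mp.Process(target=worker, args=(workerQueue, writerQueue))
                  for _ in range(numThreads)]
    writeProc = mp.Process(target=writer, args=(len(items), writerQueue))

    writeProc.start()
    for p in workerProc:
        p.start()
    for p in workerProc:
        p.join()

    writerQueue.put(None)  # shut the writer down only after all workers finish
    writeProc.join()
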
Code example #3
class SimulationScaffolds(object):
    def __init__(self):
        self.markerSetBuilder = MarkerSetBuilder()
        self.img = IMG('/srv/whitlam/bio/db/checkm/img/img_metadata.tsv', '/srv/whitlam/bio/db/checkm/pfam/tigrfam2pfam.tsv')
        
        self.contigLens = [5000, 20000, 50000]
        self.percentComps = [0.5, 0.7, 0.8, 0.9, 0.95, 1.0]
        self.percentConts = [0.0, 0.05, 0.1, 0.15, 0.2]

    def __seqLens(self, seqs):
        """Calculate lengths of seqs."""
        genomeSize = 0
        seqLens = {}
        for seqId, seq in seqs.iteritems():
            seqLens[seqId] = len(seq)
            genomeSize += len(seq)
    
        return seqLens, genomeSize
    
    def __workerThread(self, tree, metadata, genomeIdsToTest, ubiquityThreshold, singleCopyThreshold, numReplicates, queueIn, queueOut):
        """Process each data item in parallel."""

        while True:
            testGenomeId = queueIn.get(block=True, timeout=None)
            if testGenomeId == None:
                break
                        
            # build marker sets for evaluating test genome
            testNode = tree.find_node_with_taxon_label('IMG_' + testGenomeId)
            binMarkerSets, refinedBinMarkerSet = self.markerSetBuilder.buildBinMarkerSet(tree, testNode.parent_node, ubiquityThreshold, singleCopyThreshold, bMarkerSet = True, genomeIdsToRemove = [testGenomeId])

            # determine distribution of all marker genes within the test genome
            geneDistTable = self.img.geneDistTable([testGenomeId], binMarkerSets.getMarkerGenes(), spacingBetweenContigs=0)
                
            # estimate completeness of unmodified genome
            unmodifiedComp = {}
            unmodifiedCont = {}
            for ms in binMarkerSets.markerSetIter():     
                hits = {}
                for mg in ms.getMarkerGenes():
                    if mg in geneDistTable[testGenomeId]:
                        hits[mg] = geneDistTable[testGenomeId][mg]
                completeness, contamination = ms.genomeCheck(hits, bIndividualMarkers=True) 
                unmodifiedComp[ms.lineageStr] = completeness
                unmodifiedCont[ms.lineageStr] = contamination

            # estimate completion and contamination of genome after subsampling using both the domain and lineage-specific marker sets 
            testSeqs = readFasta(os.path.join(self.img.genomeDir, testGenomeId, testGenomeId + '.fna'))
            testSeqLens, genomeSize = self.__seqLens(testSeqs)
            
            
            for contigLen in self.contigLens: 
                for percentComp in self.percentComps:
                    for percentCont in self.percentConts:
                        deltaComp = defaultdict(list)
                        deltaCont = defaultdict(list)
                        deltaCompSet = defaultdict(list)
                        deltaContSet = defaultdict(list)
                        
                        deltaCompRefined = defaultdict(list)
                        deltaContRefined = defaultdict(list)
                        deltaCompSetRefined = defaultdict(list)
                        deltaContSetRefined = defaultdict(list)
                        
                        trueComps = []
                        trueConts = []
                        
                        numDescendants = {}
            
                        for i in xrange(0, numReplicates):
                            # generate test genome with a specific level of completeness, by randomly sampling scaffolds to remove 
                            # (this will sample >= the desired level of completeness)
                            retainedTestSeqs, trueComp = self.markerSetBuilder.sampleGenomeScaffoldsWithoutReplacement(percentComp, testSeqLens, genomeSize)
                            trueComps.append(trueComp)
    
                            # select a random genome to use as a source of contamination
                            contGenomeId = random.sample(genomeIdsToTest - set([testGenomeId]), 1)[0]
                            contSeqs = readFasta(os.path.join(self.img.genomeDir, contGenomeId, contGenomeId + '.fna'))
                            contSeqLens, contGenomeSize = self.__seqLens(contSeqs) 
                            seqsToRetain, trueRetainedPer = self.markerSetBuilder.sampleGenomeScaffoldsWithoutReplacement(1 - percentCont, contSeqLens, contGenomeSize) 
                            
                            contSampledSeqIds = set(contSeqs.keys()).difference(seqsToRetain)
                            trueCont = 100.0 - trueRetainedPer
                            trueConts.append(trueCont)
              
                            for ms in binMarkerSets.markerSetIter():  
                                numDescendants[ms.lineageStr] = ms.numGenomes
                                containedMarkerGenes= defaultdict(list)
                                self.markerSetBuilder.markerGenesOnScaffolds(ms.getMarkerGenes(), testGenomeId, retainedTestSeqs, containedMarkerGenes)
                                self.markerSetBuilder.markerGenesOnScaffolds(ms.getMarkerGenes(), contGenomeId, contSampledSeqIds, containedMarkerGenes)

                                completeness, contamination = ms.genomeCheck(containedMarkerGenes, bIndividualMarkers=True)
                                deltaComp[ms.lineageStr].append(completeness - trueComp)
                                deltaCont[ms.lineageStr].append(contamination - trueCont)
                                
                                completeness, contamination = ms.genomeCheck(containedMarkerGenes, bIndividualMarkers=False)
                                deltaCompSet[ms.lineageStr].append(completeness - trueComp)
                                deltaContSet[ms.lineageStr].append(contamination - trueCont)
                                
                            for ms in refinedBinMarkerSet.markerSetIter():  
                                containedMarkerGenes= defaultdict(list)
                                self.markerSetBuilder.markerGenesOnScaffolds(ms.getMarkerGenes(), testGenomeId, retainedTestSeqs, containedMarkerGenes)
                                self.markerSetBuilder.markerGenesOnScaffolds(ms.getMarkerGenes(), contGenomeId, contSampledSeqIds, containedMarkerGenes)
                                
                                completeness, contamination = ms.genomeCheck(containedMarkerGenes, bIndividualMarkers=True)
                                deltaCompRefined[ms.lineageStr].append(completeness - trueComp)
                                deltaContRefined[ms.lineageStr].append(contamination - trueCont)
                                
                                completeness, contamination = ms.genomeCheck(containedMarkerGenes, bIndividualMarkers=False)
                                deltaCompSetRefined[ms.lineageStr].append(completeness - trueComp)
                                deltaContSetRefined[ms.lineageStr].append(contamination - trueCont)
                                
                        taxonomy = ';'.join(metadata[testGenomeId]['taxonomy'])
                        queueOut.put((testGenomeId, contigLen, percentComp, percentCont, taxonomy, numDescendants, unmodifiedComp, unmodifiedCont, deltaComp, deltaCont, deltaCompSet, deltaContSet, deltaCompRefined, deltaContRefined, deltaCompSetRefined, deltaContSetRefined, trueComps, trueConts))
            
    def __writerThread(self, numTestGenomes, writerQueue):
        """Store or write results of worker threads in a single thread."""
        
        summaryOut = open('/tmp/simulation.random_scaffolds.w_refinement_50.draft.summary.tsv', 'w')
        summaryOut.write('Genome Id\tContig len\t% comp\t% cont')
        summaryOut.write('\tTaxonomy\tMarker set\t# descendants')
        summaryOut.write('\tUnmodified comp\tUnmodified cont')
        summaryOut.write('\tIM comp\tIM comp std\tIM cont\tIM cont std')
        summaryOut.write('\tMS comp\tMS comp std\tMS cont\tMS cont std')
        summaryOut.write('\tRIM comp\tRIM comp std\tRIM cont\tRIM cont std')
        summaryOut.write('\tRMS comp\tRMS comp std\tRMS cont\tRMS cont std\n')
        
        fout = gzip.open('/tmp/simulation.random_scaffolds.w_refinement_50.draft.tsv.gz', 'wb')
        fout.write('Genome Id\tContig len\t% comp\t% cont')
        fout.write('\tTaxonomy\tMarker set\t# descendants')
        fout.write('\tUnmodified comp\tUnmodified cont')
        fout.write('\tIM comp\tIM cont')
        fout.write('\tMS comp\tMS cont')
        fout.write('\tRIM comp\tRIM cont')
        fout.write('\tRMS comp\tRMS cont\tTrue Comp\tTrue Cont\n')
        
        testsPerGenome = len(self.contigLens) * len(self.percentComps) * len(self.percentConts)

        itemsProcessed = 0
        while True:
            testGenomeId, contigLen, percentComp, percentCont, taxonomy, numDescendants, unmodifiedComp, unmodifiedCont, deltaComp, deltaCont, deltaCompSet, deltaContSet, deltaCompRefined, deltaContRefined, deltaCompSetRefined, deltaContSetRefined, trueComps, trueConts = writerQueue.get(block=True, timeout=None)
            if testGenomeId == None:
                break

            itemsProcessed += 1
            statusStr = '    Finished processing %d of %d (%.2f%%) test cases.' % (itemsProcessed, numTestGenomes*testsPerGenome, float(itemsProcessed)*100/(numTestGenomes*testsPerGenome))
            sys.stdout.write('%s\r' % statusStr)
            sys.stdout.flush()
            
            for markerSetId in unmodifiedComp:
                summaryOut.write(testGenomeId + '\t%d\t%.2f\t%.2f' % (contigLen, percentComp, percentCont)) 
                summaryOut.write('\t' + taxonomy + '\t' + markerSetId + '\t' + str(numDescendants[markerSetId]))
                summaryOut.write('\t%.3f\t%.3f' % (unmodifiedComp[markerSetId], unmodifiedCont[markerSetId]))
                summaryOut.write('\t%.3f\t%.3f' % (mean(abs(deltaComp[markerSetId])), std(abs(deltaComp[markerSetId]))))
                summaryOut.write('\t%.3f\t%.3f' % (mean(abs(deltaCont[markerSetId])), std(abs(deltaCont[markerSetId]))))
                summaryOut.write('\t%.3f\t%.3f' % (mean(abs(deltaCompSet[markerSetId])), std(abs(deltaCompSet[markerSetId]))))
                summaryOut.write('\t%.3f\t%.3f' % (mean(abs(deltaContSet[markerSetId])), std(abs(deltaContSet[markerSetId]))))
                summaryOut.write('\t%.3f\t%.3f' % (mean(abs(deltaCompRefined[markerSetId])), std(abs(deltaCompRefined[markerSetId]))))
                summaryOut.write('\t%.3f\t%.3f' % (mean(abs(deltaContRefined[markerSetId])), std(abs(deltaContRefined[markerSetId]))))
                summaryOut.write('\t%.3f\t%.3f' % (mean(abs(deltaCompSetRefined[markerSetId])), std(abs(deltaCompSetRefined[markerSetId]))))
                summaryOut.write('\t%.3f\t%.3f' % (mean(abs(deltaContSetRefined[markerSetId])), std(abs(deltaContSetRefined[markerSetId]))))
                summaryOut.write('\n')
                
                fout.write(testGenomeId + '\t%d\t%.2f\t%.2f' % (contigLen, percentComp, percentCont)) 
                fout.write('\t' + taxonomy + '\t' + markerSetId + '\t' + str(numDescendants[markerSetId]))
                fout.write('\t%.3f\t%.3f' % (unmodifiedComp[markerSetId], unmodifiedCont[markerSetId]))
                fout.write('\t%s' % ','.join(map(str, deltaComp[markerSetId])))
                fout.write('\t%s' % ','.join(map(str, deltaCont[markerSetId])))
                fout.write('\t%s' % ','.join(map(str, deltaCompSet[markerSetId])))
                fout.write('\t%s' % ','.join(map(str, deltaContSet[markerSetId])))
                fout.write('\t%s' % ','.join(map(str, deltaCompRefined[markerSetId])))
                fout.write('\t%s' % ','.join(map(str, deltaContRefined[markerSetId])))
                fout.write('\t%s' % ','.join(map(str, deltaCompSetRefined[markerSetId])))
                fout.write('\t%s' % ','.join(map(str, deltaContSetRefined[markerSetId])))
                fout.write('\t%s' % ','.join(map(str, trueComps)))
                fout.write('\t%s' % ','.join(map(str, trueConts)))
                fout.write('\n')
            
        summaryOut.close()
        fout.close()

        sys.stdout.write('\n')

    def run(self, ubiquityThreshold, singleCopyThreshold, numReplicates, minScaffolds, numThreads):
        random.seed(0)

        print '\n  Reading reference genome tree.'
        treeFile = os.path.join('/srv', 'db', 'checkm', 'genome_tree', 'genome_tree_prok.refpkg', 'genome_tree.final.tre')
        tree = dendropy.Tree.get_from_path(treeFile, schema='newick', as_rooted=True, preserve_underscores=True)
        
        print '    Number of taxa in tree: %d' % (len(tree.leaf_nodes()))
        
        genomesInTree = set()
        for leaf in tree.leaf_iter():
            genomesInTree.add(leaf.taxon.label.replace('IMG_', ''))

        # get all draft genomes consisting of a user-specified minimum number of scaffolds
        print ''
        metadata = self.img.genomeMetadata()
        print '  Total genomes: %d' % len(metadata)
        
        draftGenomeIds = genomesInTree - self.img.filterGenomeIds(genomesInTree, metadata, 'status', 'Finished')
        print '  Number of draft genomes: %d' % len(draftGenomeIds)
        
        genomeIdsToTest = set()
        for genomeId in draftGenomeIds:
            if metadata[genomeId]['scaffold count'] >= minScaffolds:
                genomeIdsToTest.add(genomeId)
                
        
        print '  Number of draft genomes with >= %d scaffolds: %d' % (minScaffolds, len(genomeIdsToTest))

        print ''
        start = time.time()
        self.markerSetBuilder.readLineageSpecificGenesToRemove()
        end = time.time()
        print '    readLineageSpecificGenesToRemove: %.2f' % (end - start)
        
        print '  Pre-computing genome information for calculating marker sets:'
        start = time.time()
        self.markerSetBuilder.precomputeGenomeFamilyScaffolds(metadata.keys())
        end = time.time()
        print '    precomputeGenomeFamilyScaffolds: %.2f' % (end - start)
        
        start = time.time()
        self.markerSetBuilder.cachedGeneCountTable = self.img.geneCountTable(metadata.keys())
        end = time.time()
        print '    globalGeneCountTable: %.2f' % (end - start)
        
        start = time.time()
        self.markerSetBuilder.precomputeGenomeSeqLens(metadata.keys())
        end = time.time()
        print '    precomputeGenomeSeqLens: %.2f' % (end - start)
        
        start = time.time()
        self.markerSetBuilder.precomputeGenomeFamilyPositions(metadata.keys(), 0)
        end = time.time()
        print '    precomputeGenomeFamilyPositions: %.2f' % (end - start)
                     
        print ''    
        print '  Evaluating %d test genomes.' % len(genomeIdsToTest)
            
        workerQueue = mp.Queue()
        writerQueue = mp.Queue()

        for testGenomeId in list(genomeIdsToTest):
            workerQueue.put(testGenomeId)

        for _ in range(numThreads):
            workerQueue.put(None)

        workerProc = [mp.Process(target = self.__workerThread, args = (tree, metadata, genomeIdsToTest, ubiquityThreshold, singleCopyThreshold, numReplicates, workerQueue, writerQueue)) for _ in range(numThreads)]
        writeProc = mp.Process(target = self.__writerThread, args = (len(genomeIdsToTest), writerQueue))

        writeProc.start()

        for p in workerProc:
            p.start()

        for p in workerProc:
            p.join()

        writerQueue.put((None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None))
        writeProc.join()
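
Note that the writer threads call mean, std, and abs directly on plain Python lists of per-replicate deltas; this only works if those names refer to numpy's vectorised functions, which are presumably imported at module level in the original file (the built-in abs would raise a TypeError on a list). A small sketch of that summary computation under this assumption:

from numpy import mean, std, abs  # numpy's element-wise abs, not the built-in

# Toy per-replicate errors (estimated minus true completeness); values are illustrative only.
deltaComp = [-2.1, 0.5, 1.3, -0.7]

# Mean and standard deviation of the absolute errors, as written to the summary TSV.
print('%.3f\t%.3f' % (mean(abs(deltaComp)), std(abs(deltaComp))))
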
Code example #4
File: simulation.py  Project: IUEayhu/CheckM
class Simulation(object):
    def __init__(self):
        self.markerSetBuilder = MarkerSetBuilder()
        self.img = IMG(
            "/srv/whitlam/bio/db/checkm/img/img_metadata.tsv", "/srv/whitlam/bio/db/checkm/pfam/tigrfam2pfam.tsv"
        )

        self.contigLens = [1000, 2000, 5000, 10000, 20000, 50000]
        self.percentComps = [0.5, 0.7, 0.8, 0.9, 0.95, 1.0]
        self.percentConts = [0.0, 0.05, 0.1, 0.15, 0.2]

    def __workerThread(self, tree, metadata, ubiquityThreshold, singleCopyThreshold, numReplicates, queueIn, queueOut):
        """Process each data item in parallel."""

        while True:
            testGenomeId = queueIn.get(block=True, timeout=None)
            if testGenomeId == None:
                break

            # build marker sets for evaluating test genome
            testNode = tree.find_node_with_taxon_label("IMG_" + testGenomeId)
            binMarkerSets, refinedBinMarkerSet = self.markerSetBuilder.buildBinMarkerSet(
                tree,
                testNode.parent_node,
                ubiquityThreshold,
                singleCopyThreshold,
                bMarkerSet=True,
                genomeIdsToRemove=[testGenomeId],
            )
            #!!!binMarkerSets, refinedBinMarkerSet = self.markerSetBuilder.buildDomainMarkerSet(tree, testNode.parent_node, ubiquityThreshold, singleCopyThreshold, bMarkerSet = False, genomeIdsToRemove = [testGenomeId])

            # determine distribution of all marker genes within the test genome
            geneDistTable = self.img.geneDistTable(
                [testGenomeId], binMarkerSets.getMarkerGenes(), spacingBetweenContigs=0
            )

            print "# marker genes: ", len(binMarkerSets.getMarkerGenes())
            print "# genes in table: ", len(geneDistTable[testGenomeId])

            # estimate completeness of unmodified genome
            unmodifiedComp = {}
            unmodifiedCont = {}
            for ms in binMarkerSets.markerSetIter():
                hits = {}
                for mg in ms.getMarkerGenes():
                    if mg in geneDistTable[testGenomeId]:
                        hits[mg] = geneDistTable[testGenomeId][mg]
                completeness, contamination = ms.genomeCheck(hits, bIndividualMarkers=True)
                unmodifiedComp[ms.lineageStr] = completeness
                unmodifiedCont[ms.lineageStr] = contamination

            print completeness, contamination

            # estimate completion and contamination of genome after subsampling using both the domain and lineage-specific marker sets
            genomeSize = readFastaBases(os.path.join(self.img.genomeDir, testGenomeId, testGenomeId + ".fna"))
            print "genomeSize", genomeSize

            for contigLen in self.contigLens:
                for percentComp in self.percentComps:
                    for percentCont in self.percentConts:
                        deltaComp = defaultdict(list)
                        deltaCont = defaultdict(list)
                        deltaCompSet = defaultdict(list)
                        deltaContSet = defaultdict(list)

                        deltaCompRefined = defaultdict(list)
                        deltaContRefined = defaultdict(list)
                        deltaCompSetRefined = defaultdict(list)
                        deltaContSetRefined = defaultdict(list)

                        trueComps = []
                        trueConts = []

                        numDescendants = {}

                        for _ in xrange(0, numReplicates):
                            trueComp, trueCont, startPartialGenomeContigs = self.markerSetBuilder.sampleGenome(
                                genomeSize, percentComp, percentCont, contigLen
                            )
                            print contigLen, trueComp, trueCont, len(startPartialGenomeContigs)

                            trueComps.append(trueComp)
                            trueConts.append(trueCont)

                            for ms in binMarkerSets.markerSetIter():
                                numDescendants[ms.lineageStr] = ms.numGenomes

                                containedMarkerGenes = self.markerSetBuilder.containedMarkerGenes(
                                    ms.getMarkerGenes(),
                                    geneDistTable[testGenomeId],
                                    startPartialGenomeContigs,
                                    contigLen,
                                )
                                completeness, contamination = ms.genomeCheck(
                                    containedMarkerGenes, bIndividualMarkers=True
                                )
                                deltaComp[ms.lineageStr].append(completeness - trueComp)
                                deltaCont[ms.lineageStr].append(contamination - trueCont)

                                completeness, contamination = ms.genomeCheck(
                                    containedMarkerGenes, bIndividualMarkers=False
                                )
                                deltaCompSet[ms.lineageStr].append(completeness - trueComp)
                                deltaContSet[ms.lineageStr].append(contamination - trueCont)

                            for ms in refinedBinMarkerSet.markerSetIter():
                                containedMarkerGenes = self.markerSetBuilder.containedMarkerGenes(
                                    ms.getMarkerGenes(),
                                    geneDistTable[testGenomeId],
                                    startPartialGenomeContigs,
                                    contigLen,
                                )
                                completeness, contamination = ms.genomeCheck(
                                    containedMarkerGenes, bIndividualMarkers=True
                                )
                                deltaCompRefined[ms.lineageStr].append(completeness - trueComp)
                                deltaContRefined[ms.lineageStr].append(contamination - trueCont)

                                completeness, contamination = ms.genomeCheck(
                                    containedMarkerGenes, bIndividualMarkers=False
                                )
                                deltaCompSetRefined[ms.lineageStr].append(completeness - trueComp)
                                deltaContSetRefined[ms.lineageStr].append(contamination - trueCont)

                        taxonomy = ";".join(metadata[testGenomeId]["taxonomy"])
                        queueOut.put(
                            (
                                testGenomeId,
                                contigLen,
                                percentComp,
                                percentCont,
                                taxonomy,
                                numDescendants,
                                unmodifiedComp,
                                unmodifiedCont,
                                trueComps,
                                trueConts,
                                deltaComp,
                                deltaCont,
                                deltaCompSet,
                                deltaContSet,
                                deltaCompRefined,
                                deltaContRefined,
                                deltaCompSetRefined,
                                deltaContSetRefined,
                                trueComps,
                                trueConts,
                            )
                        )

    def __writerThread(self, numTestGenomes, writerQueue):
        """Store or write results of worker threads in a single thread."""

        # summaryOut = open('/tmp/simulation.draft.summary.w_refinement_50.tsv', 'w')
        summaryOut = open("/tmp/simulation.summary.testing.tsv", "w")
        summaryOut.write("Genome Id\tContig len\t% comp\t% cont")
        summaryOut.write("\tTaxonomy\tMarker set\t# descendants")
        summaryOut.write("\tUnmodified comp\tUnmodified cont\tTrue comp\tTrue cont")
        summaryOut.write("\tIM comp\tIM comp std\tIM cont\tIM cont std")
        summaryOut.write("\tMS comp\tMS comp std\tMS cont\tMS cont std")
        summaryOut.write("\tRIM comp\tRIM comp std\tRIM cont\tRIM cont std")
        summaryOut.write("\tRMS comp\tRMS comp std\tRMS cont\tRMS cont std\n")

        # fout = gzip.open('/tmp/simulation.draft.w_refinement_50.tsv.gz', 'wb')
        fout = gzip.open("/tmp/simulation.testing.tsv.gz", "wb")
        fout.write("Genome Id\tContig len\t% comp\t% cont")
        fout.write("\tTaxonomy\tMarker set\t# descendants")
        fout.write("\tUnmodified comp\tUnmodified cont\tTrue comp\tTrue cont")
        fout.write("\tIM comp\tIM cont")
        fout.write("\tMS comp\tMS cont")
        fout.write("\tRIM comp\tRIM cont")
        fout.write("\tRMS comp\tRMS cont\tTrue Comp\tTrue Cont\n")

        testsPerGenome = len(self.contigLens) * len(self.percentComps) * len(self.percentConts)

        itemsProcessed = 0
        while True:
            (testGenomeId, contigLen, percentComp, percentCont, taxonomy, numDescendants,
             unmodifiedComp, unmodifiedCont, trueComps, trueConts,
             deltaComp, deltaCont, deltaCompSet, deltaContSet,
             deltaCompRefined, deltaContRefined, deltaCompSetRefined, deltaContSetRefined,
             trueComps, trueConts) = writerQueue.get(block=True, timeout=None)
            if testGenomeId == None:
                break

            itemsProcessed += 1
            statusStr = "    Finished processing %d of %d (%.2f%%) test cases." % (
                itemsProcessed,
                numTestGenomes * testsPerGenome,
                float(itemsProcessed) * 100 / (numTestGenomes * testsPerGenome),
            )
            sys.stdout.write("%s\r" % statusStr)
            sys.stdout.flush()

            for markerSetId in unmodifiedComp:
                summaryOut.write(testGenomeId + "\t%d\t%.2f\t%.2f" % (contigLen, percentComp, percentCont))
                summaryOut.write("\t" + taxonomy + "\t" + markerSetId + "\t" + str(numDescendants[markerSetId]))
                summaryOut.write("\t%.3f\t%.3f" % (unmodifiedComp[markerSetId], unmodifiedCont[markerSetId]))
                summaryOut.write("\t%.3f\t%.3f" % (mean(trueComps), std(trueConts)))
                summaryOut.write("\t%.3f\t%.3f" % (mean(abs(deltaComp[markerSetId])), std(abs(deltaComp[markerSetId]))))
                summaryOut.write("\t%.3f\t%.3f" % (mean(abs(deltaCont[markerSetId])), std(abs(deltaCont[markerSetId]))))
                summaryOut.write(
                    "\t%.3f\t%.3f" % (mean(abs(deltaCompSet[markerSetId])), std(abs(deltaCompSet[markerSetId])))
                )
                summaryOut.write(
                    "\t%.3f\t%.3f" % (mean(abs(deltaContSet[markerSetId])), std(abs(deltaContSet[markerSetId])))
                )
                summaryOut.write(
                    "\t%.3f\t%.3f" % (mean(abs(deltaCompRefined[markerSetId])), std(abs(deltaCompRefined[markerSetId])))
                )
                summaryOut.write(
                    "\t%.3f\t%.3f" % (mean(abs(deltaContRefined[markerSetId])), std(abs(deltaContRefined[markerSetId])))
                )
                summaryOut.write(
                    "\t%.3f\t%.3f"
                    % (mean(abs(deltaCompSetRefined[markerSetId])), std(abs(deltaCompSetRefined[markerSetId])))
                )
                summaryOut.write(
                    "\t%.3f\t%.3f"
                    % (mean(abs(deltaContSetRefined[markerSetId])), std(abs(deltaContSetRefined[markerSetId])))
                )
                summaryOut.write("\n")

                fout.write(testGenomeId + "\t%d\t%.2f\t%.2f" % (contigLen, percentComp, percentCont))
                fout.write("\t" + taxonomy + "\t" + markerSetId + "\t" + str(numDescendants[markerSetId]))
                fout.write("\t%.3f\t%.3f" % (unmodifiedComp[markerSetId], unmodifiedCont[markerSetId]))
                fout.write("\t%s" % ",".join(map(str, trueComps)))
                fout.write("\t%s" % ",".join(map(str, trueConts)))
                fout.write("\t%s" % ",".join(map(str, deltaComp[markerSetId])))
                fout.write("\t%s" % ",".join(map(str, deltaCont[markerSetId])))
                fout.write("\t%s" % ",".join(map(str, deltaCompSet[markerSetId])))
                fout.write("\t%s" % ",".join(map(str, deltaContSet[markerSetId])))
                fout.write("\t%s" % ",".join(map(str, deltaCompRefined[markerSetId])))
                fout.write("\t%s" % ",".join(map(str, deltaContRefined[markerSetId])))
                fout.write("\t%s" % ",".join(map(str, deltaCompSetRefined[markerSetId])))
                fout.write("\t%s" % ",".join(map(str, deltaContSetRefined[markerSetId])))
                fout.write("\t%s" % ",".join(map(str, trueComps)))
                fout.write("\t%s" % ",".join(map(str, trueConts)))
                fout.write("\n")

        summaryOut.close()
        fout.close()

        sys.stdout.write("\n")

    def run(self, ubiquityThreshold, singleCopyThreshold, numReplicates, numThreads):
        print "\n  Reading reference genome tree."
        treeFile = os.path.join("/srv", "db", "checkm", "genome_tree", "genome_tree_full.refpkg", "genome_tree.tre")
        tree = dendropy.Tree.get_from_path(treeFile, schema="newick", as_rooted=True, preserve_underscores=True)

        print "    Number of taxa in tree: %d" % (len(tree.leaf_nodes()))

        genomesInTree = set()
        for leaf in tree.leaf_iter():
            genomesInTree.add(leaf.taxon.label.replace("IMG_", ""))

        # get all draft genomes for testing
        print ""
        metadata = self.img.genomeMetadata()
        print "  Total genomes: %d" % len(metadata)

        genomeIdsToTest = genomesInTree - self.img.filterGenomeIds(genomesInTree, metadata, "status", "Finished")
        print "  Number of draft genomes: %d" % len(genomeIdsToTest)

        print ""
        print "  Pre-computing genome information for calculating marker sets:"
        start = time.time()
        self.markerSetBuilder.readLineageSpecificGenesToRemove()
        end = time.time()
        print "    readLineageSpecificGenesToRemove: %.2f" % (end - start)

        start = time.time()
        # self.markerSetBuilder.cachedGeneCountTable = self.img.geneCountTable(metadata.keys())
        end = time.time()
        print "    globalGeneCountTable: %.2f" % (end - start)

        start = time.time()
        # self.markerSetBuilder.precomputeGenomeSeqLens(metadata.keys())
        end = time.time()
        print "    precomputeGenomeSeqLens: %.2f" % (end - start)

        start = time.time()
        # self.markerSetBuilder.precomputeGenomeFamilyPositions(metadata.keys(), 0)
        end = time.time()
        print "    precomputeGenomeFamilyPositions: %.2f" % (end - start)

        print ""
        print "  Evaluating %d test genomes." % len(genomeIdsToTest)
        workerQueue = mp.Queue()
        writerQueue = mp.Queue()

        for testGenomeId in genomeIdsToTest:
            workerQueue.put(testGenomeId)

        for _ in range(numThreads):
            workerQueue.put(None)
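        # each worker exits its processing loop when it reads one of these None sentinels,
        # so one sentinel is queued per worker thread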

        workerProc = [
            mp.Process(
                target=self.__workerThread,
                args=(tree, metadata, genomeIdsToTest, ubiquityThreshold, singleCopyThreshold, numReplicates, workerQueue, writerQueue),
            )
            for _ in range(numThreads)
        ]
        writeProc = mp.Process(target=self.__writerThread, args=(len(genomeIdsToTest), writerQueue))

        writeProc.start()

        for p in workerProc:
            p.start()

        for p in workerProc:
            p.join()

        writerQueue.put(
            (
                None,
                None,
                None,
                None,
                None,
                None,
                None,
                None,
                None,
                None,
                None,
                None,
                None,
                None,
                None,
                None,
                None,
                None,
                None,
                None,
            )
        )
        writeProc.join()
Code example #5
File: markerSetBuilder.py  Project: HadrienG/CheckM
class MarkerSetBuilder(object):
    def __init__(self):
        self.img = IMG('/srv/whitlam/bio/db/checkm/img/img_metadata.tsv',
                       '/srv/whitlam/bio/db/checkm/pfam/tigrfam2pfam.tsv')
        self.colocatedFile = './data/colocated.tsv'
        self.duplicateSeqs = self.readDuplicateSeqs()
        self.uniqueIdToLineageStatistics = self.__readNodeMetadata()

        self.cachedGeneCountTable = None

    def precomputeGenomeSeqLens(self, genomeIds):
        """Cache the length of contigs/scaffolds for all genomes."""

        # This function is intended to speed up functions, such as img.geneDistTable(),
        # that are called multiple times (typically during simulations)
        self.img.precomputeGenomeSeqLens(genomeIds)

    def precomputeGenomeFamilyPositions(self, genomeIds,
                                        spacingBetweenContigs):
        """Cache position of PFAM and TIGRFAM genes in genomes."""

        # This function is intended to speed up functions, such as img.geneDistTable(),
        # that are called multiple times (typically during simulations)
        self.img.precomputeGenomeFamilyPositions(genomeIds,
                                                 spacingBetweenContigs)

    def precomputeGenomeFamilyScaffolds(self, genomeIds):
        """Cache scaffolds of PFAM and TIGRFAM genes in genomes."""

        # This function is intended to speed up functions, such as img.geneDistTable(),
        # that are called multiple times (typically during simulations)
        self.img.precomputeGenomeFamilyScaffolds(genomeIds)

    def getLineageMarkerGenes(self, lineage, minGenomes=20, minMarkerSets=20):
        pfamIds = set()
        tigrIds = set()

        bHeader = True
        for line in open(self.colocatedFile):
            if bHeader:
                bHeader = False
                continue

            lineSplit = line.split('\t')
            curLineage = lineSplit[0]
            numGenomes = int(lineSplit[1])
            numMarkerSets = int(lineSplit[3])
            markerSets = lineSplit[4:]

            if curLineage != lineage or numGenomes < minGenomes or numMarkerSets < minMarkerSets:
                continue

            for ms in markerSets:
                markers = ms.split(',')
                for m in markers:
                    if 'pfam' in m:
                        pfamIds.add(m.strip())
                    elif 'TIGR' in m:
                        tigrIds.add(m.strip())

        return pfamIds, tigrIds

    def getCalculatedMarkerGenes(self, minGenomes=20, minMarkerSets=20):
        pfamIds = set()
        tigrIds = set()

        bHeader = True
        for line in open(self.colocatedFile):
            if bHeader:
                bHeader = False
                continue

            lineSplit = line.split('\t')
            numGenomes = int(lineSplit[1])
            numMarkerSets = int(lineSplit[3])
            markerSets = lineSplit[4:]

            if numGenomes < minGenomes or numMarkerSets < minMarkerSets:
                continue

            for ms in markerSets:
                markers = ms.split(',')
                for m in markers:
                    if 'pfam' in m:
                        pfamIds.add(m.strip())
                    elif 'TIGR' in m:
                        tigrIds.add(m.strip())

        return pfamIds, tigrIds

    def markerGenes(self, genomeIds, countTable, ubiquityThreshold,
                    singleCopyThreshold):
        if ubiquityThreshold < 1 or singleCopyThreshold < 1:
            print('[Warning] Looks like degenerate threshold.')

        # find genes meeting ubiquity and single-copy thresholds
        markers = set()
        for clusterId, genomeCounts in countTable.iteritems():
            ubiquity = 0
            singleCopy = 0

            if len(genomeCounts) < ubiquityThreshold:
                # gene is clearly not ubiquitous
                continue

            for genomeId in genomeIds:
                count = genomeCounts.get(genomeId, 0)

                if count > 0:
                    ubiquity += 1

                if count == 1:
                    singleCopy += 1

            if ubiquity >= ubiquityThreshold and singleCopy >= singleCopyThreshold:
                markers.add(clusterId)

        return markers
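
    # Worked example (illustrative values, not from the source): with
    # ubiquityThreshold=2 and singleCopyThreshold=2 (absolute counts, as passed
    # by buildMarkerGenes after multiplying the fractional thresholds by the
    # number of genomes), a cluster present once in genomes A and B and twice
    # in genome C has ubiquity=3 and singleCopy=2 over genomeIds={A, B, C},
    # so it is accepted as a marker gene.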

    def colocatedGenes(self,
                       geneDistTable,
                       distThreshold=5000,
                       genomeThreshold=0.95):
        """Identify co-located gene pairs."""

        colocatedGenes = defaultdict(int)
        for _, clusterIdToGeneLocs in geneDistTable.iteritems():
            clusterIds = clusterIdToGeneLocs.keys()
            for i, clusterId1 in enumerate(clusterIds):
                geneLocations1 = clusterIdToGeneLocs[clusterId1]

                for clusterId2 in clusterIds[i + 1:]:
                    geneLocations2 = clusterIdToGeneLocs[clusterId2]
                    bColocated = False
                    for p1 in geneLocations1:
                        for p2 in geneLocations2:
                            if abs(p1[0] - p2[0]) < distThreshold:
                                bColocated = True
                                break

                        if bColocated:
                            break

                    if bColocated:
                        if clusterId1 <= clusterId2:
                            colocatedStr = clusterId1 + '-' + clusterId2
                        else:
                            colocatedStr = clusterId2 + '-' + clusterId1
                        colocatedGenes[colocatedStr] += 1

        colocated = []
        for colocatedStr, count in colocatedGenes.iteritems():
            if float(count) / len(geneDistTable) > genomeThreshold:
                colocated.append(colocatedStr)

        return colocated

    def colocatedSets(self, colocatedGenes, markerGenes):
        # run through co-located genes once creating initial sets
        sets = []
        for cg in colocatedGenes:
            geneA, geneB = cg.split('-')
            sets.append(set([geneA, geneB]))

        # combine any sets with overlapping genes
        bProcessed = [False] * len(sets)
        finalSets = []
        for i in range(0, len(sets)):
            if bProcessed[i]:
                continue

            curSet = sets[i]
            bProcessed[i] = True

            bUpdated = True
            while bUpdated:
                bUpdated = False
                for j in range(i + 1, len(sets)):
                    if bProcessed[j]:
                        continue

                    if len(curSet.intersection(sets[j])) > 0:
                        curSet.update(sets[j])
                        bProcessed[j] = True
                        bUpdated = True

            finalSets.append(curSet)

        # add all singletons into colocated sets
        for clusterId in markerGenes:
            bFound = False
            for cs in finalSets:
                if clusterId in cs:
                    bFound = True

            if not bFound:
                finalSets.append(set([clusterId]))

        return finalSets
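
    # Worked example (hypothetical IDs): colocatedGenes=['pfamA-pfamB',
    # 'pfamB-pfamC', 'pfamD-pfamE'] with markerGenes containing pfamA through
    # pfamF yields the merged sets [{pfamA, pfamB, pfamC}, {pfamD, pfamE},
    # {pfamF}]: overlapping pairs are unioned and any marker gene not in a
    # co-located pair is added as a singleton set.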

    def genomeCheck(self, colocatedSet, genomeId, countTable):
        comp = 0.0
        cont = 0.0
        missingMarkers = set()
        duplicateMarkers = set()

        if len(colocatedSet) == 0:
            return comp, cont, missingMarkers, duplicateMarkers

        for cs in colocatedSet:
            present = 0
            multiCopy = 0
            for contigId in cs:
                count = countTable[contigId].get(genomeId, 0)
                if count == 1:
                    present += 1
                elif count > 1:
                    present += 1
                    multiCopy += (count - 1)
                    duplicateMarkers.add(contigId)
                elif count == 0:
                    missingMarkers.add(contigId)

            comp += float(present) / len(cs)
            cont += float(multiCopy) / len(cs)

        return comp / len(colocatedSet), cont / len(
            colocatedSet), missingMarkers, duplicateMarkers
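
    # Worked example (hypothetical counts): for colocatedSet=[{g1, g2}, {g3}]
    # with counts g1=1, g2=0, g3=2 in the genome of interest, the first set
    # contributes 1/2 to completeness and the second contributes 1/1 to
    # completeness plus 1/1 to contamination, giving comp=0.75 and cont=0.5.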

    def uniformity(self, genomeSize, pts):
        U = float(genomeSize) / (
            len(pts) + 1)  # distance between perfectly evenly spaced points

        # calculate distance between adjacent points
        dists = []
        pts = sorted(pts)
        for i in range(0, len(pts) - 1):
            dists.append(pts[i + 1] - pts[i])

        # calculate uniformity index
        num = 0
        den = 0
        for d in dists:
            num += abs(d - U)
            den += max(d, U)

        return 1.0 - num / den
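
    # Worked example: for genomeSize=100 and pts=[20, 50, 80], perfectly even
    # spacing is U = 100 / 4 = 25. Adjacent distances are [30, 30], so
    # num = |30-25| + |30-25| = 10 and den = 30 + 30 = 60, giving a uniformity
    # index of 1 - 10/60 ≈ 0.83 (a value of 1.0 indicates perfectly even spacing).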

    def sampleGenome(self, genomeLen, percentComp, percentCont, contigLen):
        """Sample a genome to simulate a given percent completion and contamination."""

        contigsInGenome = genomeLen / contigLen

        # determine number of contigs to achieve desired completeness and contamination
        contigsToSampleComp = int(contigsInGenome * percentComp + 0.5)
        contigsToSampleCont = int(contigsInGenome * percentCont + 0.5)

        # randomly sample contigs with contamination done via sampling with replacement
        compContigs = random.sample(range(contigsInGenome),
                                    contigsToSampleComp)
        contContigs = choice(range(contigsInGenome),
                             contigsToSampleCont,
                             replace=True)

        # determine start of each contig
        contigStarts = [c * contigLen for c in compContigs]
        contigStarts += [c * contigLen for c in contContigs]

        contigStarts.sort()

        trueComp = float(contigsToSampleComp) * contigLen * 100 / genomeLen
        trueCont = float(contigsToSampleCont) * contigLen * 100 / genomeLen

        return trueComp, trueCont, contigStarts
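
    # Worked example: genomeLen=100000 and contigLen=5000 give 20 contigs;
    # percentComp=0.9 samples 18 contigs without replacement and
    # percentCont=0.1 samples 2 contigs with replacement, so the function
    # returns trueComp=90.0 and trueCont=10.0 (both in percent).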

    def sampleGenomeScaffoldsInvLength(self, targetPer, seqLens, genomeSize):
        """Sample genome comprised of several sequences with probability inversely proportional to length."""

        # calculate probability of sampling each sequence
        seqProb = []
        for _, seqLen in seqLens.iteritems():
            prob = 1.0 / (float(seqLen) / genomeSize)
            seqProb.append(prob)

        seqProb = array(seqProb)
        seqProb /= sum(seqProb)

        # select sequences with probability inversely proportional to their length
        selectedSeqsIds = choice(seqLens.keys(),
                                 size=len(seqLens),
                                 replace=False,
                                 p=seqProb)

        sampledSeqIds = []
        truePer = 0.0
        for seqId in selectedSeqsIds:
            sampledSeqIds.append(seqId)
            truePer += float(seqLens[seqId]) / genomeSize

            if truePer >= targetPer:
                break

        return sampledSeqIds, truePer * 100

    def sampleGenomeScaffoldsWithoutReplacement(self, targetPer, seqLens,
                                                genomeSize):
        """Sample genome comprised of several sequences without replacement.

          Sampling is conducted randomly until the selected sequences comprise
          greater than or equal to the desired target percentage.
        """

        selectedSeqsIds = choice(seqLens.keys(),
                                 size=len(seqLens),
                                 replace=False)

        sampledSeqIds = []
        truePer = 0.0
        for seqId in selectedSeqsIds:
            sampledSeqIds.append(seqId)
            truePer += float(seqLens[seqId]) / genomeSize

            if truePer >= targetPer:
                break

        return sampledSeqIds, truePer * 100
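
    # Note: sampling stops at the first sequence that pushes the cumulative
    # fraction to or past targetPer, so the returned percentage can overshoot
    # the target by up to the length of the last sampled sequence.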

    def containedMarkerGenes(self, markerGenes, clusterIdToGenomePositions,
                             startPartialGenomeContigs, contigLen):
        """Determine markers contained in a set of contigs."""

        contained = {}
        for markerGene in markerGenes:
            positions = clusterIdToGenomePositions.get(markerGene, [])

            containedPos = []
            for p in positions:
                for s in startPartialGenomeContigs:
                    if (p[0] - s) >= 0 and (p[0] - s) < contigLen:
                        containedPos.append(s)

            if len(containedPos) > 0:
                contained[markerGene] = containedPos

        return contained
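
    # Worked example (hypothetical positions): with contigLen=5000 and sampled
    # contig starts [0, 10000], a marker at position 12500 satisfies
    # 0 <= 12500 - 10000 < 5000 and is reported as contained, while a marker
    # at position 7000 falls in an unsampled contig and is not.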

    def markerGenesOnScaffolds(self, markerGenes, genomeId, scaffoldIds,
                               containedMarkerGenes):
        """Determine if marker genes are found on the scaffolds of a given genome."""
        for markerGeneId in markerGenes:
            scaffoldIdsWithMarker = self.img.cachedGenomeFamilyScaffolds[
                genomeId].get(markerGeneId, [])

            for scaffoldId in scaffoldIdsWithMarker:
                if scaffoldId in scaffoldIds:
                    containedMarkerGenes[markerGeneId] += [scaffoldId]

    def readDuplicateSeqs(self):
        """Parse file indicating duplicate sequence alignments."""
        duplicateSeqs = {}
        for line in open(
                os.path.join('/srv/whitlam/bio/db/checkm/genome_tree',
                             'genome_tree.derep.txt')):
            lineSplit = line.rstrip().split()
            if len(lineSplit) > 1:
                duplicateSeqs[lineSplit[0]] = lineSplit[1:]

        return duplicateSeqs

    def __readNodeMetadata(self):
        """Read metadata for internal nodes."""

        uniqueIdToLineageStatistics = {}
        metadataFile = os.path.join('/srv/whitlam/bio/db/checkm/genome_tree',
                                    'genome_tree.metadata.tsv')
        with open(metadataFile) as f:
            f.readline()
            for line in f:
                lineSplit = line.rstrip().split('\t')

                uniqueId = lineSplit[0]

                d = {}
                d['# genomes'] = int(lineSplit[1])
                d['taxonomy'] = lineSplit[2]
                try:
                    d['bootstrap'] = float(lineSplit[3])
                except:
                    d['bootstrap'] = 'NA'
                d['gc mean'] = float(lineSplit[4])
                d['gc std'] = float(lineSplit[5])
                d['genome size mean'] = float(lineSplit[6]) / 1e6
                d['genome size std'] = float(lineSplit[7]) / 1e6
                d['gene count mean'] = float(lineSplit[8])
                d['gene count std'] = float(lineSplit[9])
                d['marker set'] = lineSplit[10].rstrip()

                uniqueIdToLineageStatistics[uniqueId] = d

        return uniqueIdToLineageStatistics

    def __getNextNamedNode(self, node, uniqueIdToLineageStatistics):
        """Get first parent node with taxonomy information."""
        parentNode = node.parent_node
        while True:
            if parentNode == None:
                break  # reached the root node so terminate

            if parentNode.label:
                trustedUniqueId = parentNode.label.split('|')[0]
                trustedStats = uniqueIdToLineageStatistics[trustedUniqueId]
                if trustedStats['taxonomy'] != '':
                    return trustedStats['taxonomy']

            parentNode = parentNode.parent_node

        return 'root'

    def __refineMarkerSet(self, markerSet, lineageSpecificMarkerSet):
        """Refine marker set to account for lineage-specific gene loss and duplication."""

        # refine marker set by finding the intersection between these two sets,
        # this removes markers that are not single-copy or ubiquitous in the
        # specific lineage of a bin
        # Note: co-localization information is taken from the trusted set

        # remove genes not present in the lineage-specific gene set
        finalMarkerSet = []
        for ms in markerSet.markerSet:
            s = set()
            for gene in ms:
                if gene in lineageSpecificMarkerSet.getMarkerGenes():
                    s.add(gene)

            if s:
                finalMarkerSet.append(s)

        refinedMarkerSet = MarkerSet(markerSet.UID, markerSet.lineageStr,
                                     markerSet.numGenomes, finalMarkerSet)

        return refinedMarkerSet

    def ____removeInvalidLineageMarkerGenes(self, markerSet,
                                            lineageSpecificMarkersToRemove):
        """Refine marker set to account for lineage-specific gene loss and duplication."""

        # refine marker set by removing marker genes subject to lineage-specific
        # gene loss and duplication
        #
        # Note: co-localization information is taken from the trusted set

        finalMarkerSet = []
        for ms in markerSet.markerSet:
            s = set()
            for gene in ms:
                if gene.startswith('PF'):
                    print('ERROR! Expected genes to start with pfam, not PF.')

                if gene not in lineageSpecificMarkersToRemove:
                    s.add(gene)

            if s:
                finalMarkerSet.append(s)

        refinedMarkerSet = MarkerSet(markerSet.UID, markerSet.lineageStr,
                                     markerSet.numGenomes, finalMarkerSet)

        return refinedMarkerSet

    def missingGenes(self, genomeIds, markerGenes, ubiquityThreshold):
        """Inferring consistently missing marker genes within a set of genomes."""

        if self.cachedGeneCountTable != None:
            geneCountTable = self.cachedGeneCountTable
        else:
            geneCountTable = self.img.geneCountTable(genomeIds)

        # find genes meeting ubiquity and single-copy thresholds
        missing = set()
        for clusterId, genomeCounts in geneCountTable.iteritems():
            if clusterId not in markerGenes:
                continue

            absence = 0
            for genomeId in genomeIds:
                count = genomeCounts.get(genomeId, 0)

                if count == 0:
                    absence += 1

            if absence >= ubiquityThreshold * len(genomeIds):
                missing.add(clusterId)

        return missing

    def duplicateGenes(self, genomeIds, markerGenes, ubiquityThreshold):
        """Inferring consistently duplicated marker genes within a set of genomes."""

        if self.cachedGeneCountTable != None:
            geneCountTable = self.cachedGeneCountTable
        else:
            geneCountTable = self.img.geneCountTable(genomeIds)

        # find genes meeting ubiquity and single-copy thresholds
        duplicate = set()
        for clusterId, genomeCounts in geneCountTable.iteritems():
            if clusterId not in markerGenes:
                continue

            duplicateCount = 0
            for genomeId in genomeIds:
                count = genomeCounts.get(genomeId, 0)

                if count > 1:
                    duplicateCount += 1

            if duplicateCount >= ubiquityThreshold * len(genomeIds):
                duplicate.add(clusterId)

        return duplicate

    def buildMarkerGenes(self, genomeIds, ubiquityThreshold,
                         singleCopyThreshold):
        """Infer marker genes from specified genomes."""

        if self.cachedGeneCountTable != None:
            geneCountTable = self.cachedGeneCountTable
        else:
            geneCountTable = self.img.geneCountTable(genomeIds)

        #counts = []
        #singleCopy = 0
        #for genomeId, count in geneCountTable['pfam01351'].iteritems():
        #    print genomeId, count
        #    counts.append(count)
        #    if count == 1:
        #        singleCopy += 1

        #print 'Ubiquity: %d of %d' % (len(counts), len(genomeIds))
        #print 'Single-copy: %d of %d' % (singleCopy, len(genomeIds))
        #print 'Mean: %.2f' % mean(counts)

        markerGenes = self.markerGenes(genomeIds, geneCountTable,
                                       ubiquityThreshold * len(genomeIds),
                                       singleCopyThreshold * len(genomeIds))
        tigrToRemove = self.img.identifyRedundantTIGRFAMs(markerGenes)
        markerGenes = markerGenes - tigrToRemove

        return markerGenes

    def buildMarkerSet(self,
                       genomeIds,
                       ubiquityThreshold,
                       singleCopyThreshold,
                       spacingBetweenContigs=5000):
        """Infer marker set from specified genomes."""

        markerGenes = self.buildMarkerGenes(genomeIds, ubiquityThreshold,
                                            singleCopyThreshold)

        geneDistTable = self.img.geneDistTable(genomeIds, markerGenes,
                                               spacingBetweenContigs)
        colocatedGenes = self.colocatedGenes(geneDistTable)
        colocatedSets = self.colocatedSets(colocatedGenes, markerGenes)
        markerSet = MarkerSet(0, 'NA', len(genomeIds), colocatedSets)

        return markerSet

    def readLineageSpecificGenesToRemove(self):
        """Get set of genes subject to lineage-specific gene loss and duplication."""

        self.lineageSpecificGenesToRemove = {}
        for line in open(
                '/srv/whitlam/bio/db/checkm/genome_tree/missing_duplicate_genes_50.tsv'
        ):
            lineSplit = line.split('\t')
            uid = lineSplit[0]
            missingGenes = eval(lineSplit[1])
            duplicateGenes = eval(lineSplit[2])
            self.lineageSpecificGenesToRemove[uid] = missingGenes.union(
                duplicateGenes)

    def buildBinMarkerSet(self,
                          tree,
                          curNode,
                          ubiquityThreshold,
                          singleCopyThreshold,
                          bMarkerSet=True,
                          genomeIdsToRemove=None):
        """Build lineage-specific marker sets for a genome in a LOO-fashion."""

        # determine marker sets for bin
        binMarkerSets = BinMarkerSets(curNode.label,
                                      BinMarkerSets.TREE_MARKER_SET)
        refinedBinMarkerSet = BinMarkerSets(curNode.label,
                                            BinMarkerSets.TREE_MARKER_SET)

        # ascend tree to root, recording all marker sets
        uniqueId = curNode.label.split('|')[0]
        lineageSpecificRefinement = self.lineageSpecificGenesToRemove[uniqueId]

        while curNode != None:
            uniqueId = curNode.label.split('|')[0]
            stats = self.uniqueIdToLineageStatistics[uniqueId]
            taxonomyStr = stats['taxonomy']
            if taxonomyStr == '':
                taxonomyStr = self.__getNextNamedNode(
                    curNode, self.uniqueIdToLineageStatistics)

            leafNodes = curNode.leaf_nodes()
            genomeIds = set()
            for leaf in leafNodes:
                genomeIds.add(leaf.taxon.label.replace('IMG_', ''))

                duplicateGenomes = self.duplicateSeqs.get(leaf.taxon.label, [])
                for dup in duplicateGenomes:
                    genomeIds.add(dup.replace('IMG_', ''))

            # remove all genomes from the same taxonomic group as the genome of interest
            if genomeIdsToRemove != None:
                genomeIds.difference_update(genomeIdsToRemove)

            if len(genomeIds) >= 2:
                if bMarkerSet:
                    markerSet = self.buildMarkerSet(genomeIds,
                                                    ubiquityThreshold,
                                                    singleCopyThreshold)
                else:
                    markerSet = MarkerSet(0, 'NA', len(genomeIds), [
                        self.buildMarkerGenes(genomeIds, ubiquityThreshold,
                                              singleCopyThreshold)
                    ])

                markerSet.lineageStr = uniqueId + ' | ' + taxonomyStr.split(
                    ';')[-1]
                binMarkerSets.addMarkerSet(markerSet)

                #refinedMarkerSet = self.__refineMarkerSet(markerSet, lineageSpecificMarkerSet)
                refinedMarkerSet = self.____removeInvalidLineageMarkerGenes(
                    markerSet, lineageSpecificRefinement)
                #print 'Refinement: %d of %d' % (len(refinedMarkerSet.getMarkerGenes()), len(markerSet.getMarkerGenes()))
                refinedBinMarkerSet.addMarkerSet(refinedMarkerSet)

            curNode = curNode.parent_node

        return binMarkerSets, refinedBinMarkerSet

    def buildDomainMarkerSet(self,
                             tree,
                             curNode,
                             ubiquityThreshold,
                             singleCopyThreshold,
                             bMarkerSet=True,
                             genomeIdsToRemove=None):
        """Build domain-specific marker sets for a genome in a LOO-fashion."""

        # determine marker sets for bin
        binMarkerSets = BinMarkerSets(curNode.label,
                                      BinMarkerSets.TREE_MARKER_SET)
        refinedBinMarkerSet = BinMarkerSets(curNode.label,
                                            BinMarkerSets.TREE_MARKER_SET)

        # calculate marker set for bacterial or archaeal node
        uniqueId = curNode.label.split('|')[0]
        lineageSpecificRefinement = self.lineageSpecificGenesToRemove[uniqueId]

        while curNode != None:
            uniqueId = curNode.label.split('|')[0]
            if uniqueId != 'UID2' and uniqueId != 'UID203':
                curNode = curNode.parent_node
                continue

            stats = self.uniqueIdToLineageStatistics[uniqueId]
            taxonomyStr = stats['taxonomy']
            if taxonomyStr == '':
                taxonomyStr = self.__getNextNamedNode(
                    curNode, self.uniqueIdToLineageStatistics)

            leafNodes = curNode.leaf_nodes()
            genomeIds = set()
            for leaf in leafNodes:
                genomeIds.add(leaf.taxon.label.replace('IMG_', ''))

                duplicateGenomes = self.duplicateSeqs.get(leaf.taxon.label, [])
                for dup in duplicateGenomes:
                    genomeIds.add(dup.replace('IMG_', ''))

            # remove all genomes from the same taxonomic group as the genome of interest
            if genomeIdsToRemove != None:
                genomeIds.difference_update(genomeIdsToRemove)

            if len(genomeIds) >= 2:
                if bMarkerSet:
                    markerSet = self.buildMarkerSet(genomeIds,
                                                    ubiquityThreshold,
                                                    singleCopyThreshold)
                else:
                    markerSet = MarkerSet(0, 'NA', len(genomeIds), [
                        self.buildMarkerGenes(genomeIds, ubiquityThreshold,
                                              singleCopyThreshold)
                    ])

                markerSet.lineageStr = uniqueId + ' | ' + taxonomyStr.split(
                    ';')[-1]
                binMarkerSets.addMarkerSet(markerSet)

                #refinedMarkerSet = self.__refineMarkerSet(markerSet, lineageSpecificMarkerSet)
                refinedMarkerSet = self.____removeInvalidLineageMarkerGenes(
                    markerSet, lineageSpecificRefinement)
                #print 'Refinement: %d of %d' % (len(refinedMarkerSet.getMarkerGenes()), len(markerSet.getMarkerGenes()))
                refinedBinMarkerSet.addMarkerSet(refinedMarkerSet)

            curNode = curNode.parent_node

        return binMarkerSets, refinedBinMarkerSet
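
# A minimal usage sketch (assumptions: the hard-coded data files referenced in
# __init__ and readLineageSpecificGenesToRemove exist, the tree path and the
# 'IMG_<genome id>' taxon label are placeholders, and 0.97/0.97 are illustrative
# ubiquity and single-copy thresholds):
#
#   import dendropy
#
#   msb = MarkerSetBuilder()
#   msb.readLineageSpecificGenesToRemove()
#   tree = dendropy.Tree.get_from_path('genome_tree.tre', schema='newick',
#                                      as_rooted=True, preserve_underscores=True)
#   node = tree.find_node_with_taxon_label('IMG_<genome id>')
#   binMarkerSets, refinedBinMarkerSet = msb.buildBinMarkerSet(
#       tree, node.parent_node, 0.97, 0.97,
#       bMarkerSet=True, genomeIdsToRemove=['<genome id>'])
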
Code example #6
class MarkerSetBuilder(object):
    def __init__(self):
        self.img = IMG('/srv/whitlam/bio/db/checkm/img/img_metadata.tsv', '/srv/whitlam/bio/db/checkm/pfam/tigrfam2pfam.tsv')
        self.colocatedFile = './data/colocated.tsv'
        self.duplicateSeqs = self.readDuplicateSeqs()
        self.uniqueIdToLineageStatistics = self.__readNodeMetadata()
        
        self.cachedGeneCountTable = None
        
    def precomputeGenomeSeqLens(self, genomeIds):
        """Cache the length of contigs/scaffolds for all genomes."""
        
        # This function is intended to speed up functions, such as img.geneDistTable(),
        # that are called multiple times (typically during simulations)
        self.img.precomputeGenomeSeqLens(genomeIds)
            
    def precomputeGenomeFamilyPositions(self, genomeIds, spacingBetweenContigs):
        """Cache position of PFAM and TIGRFAM genes in genomes."""
        
        # This function is intended to speed up functions, such as img.geneDistTable(),
        # that are called multiple times (typically during simulations)
        self.img.precomputeGenomeFamilyPositions(genomeIds, spacingBetweenContigs)
        
    def precomputeGenomeFamilyScaffolds(self, genomeIds):
        """Cache scaffolds of PFAM and TIGRFAM genes in genomes."""
        
        # This function is intended to speed up functions, such as img.geneDistTable(),
        # that are called multiple times (typically during simulations)
        self.img.precomputeGenomeFamilyScaffolds(genomeIds)
        
    def getLineageMarkerGenes(self, lineage, minGenomes = 20, minMarkerSets = 20):
        pfamIds = set()
        tigrIds = set()

        bHeader = True
        for line in open(self.colocatedFile):
            if bHeader:
                bHeader = False
                continue

            lineSplit = line.split('\t')
            curLineage = lineSplit[0]
            numGenomes = int(lineSplit[1])
            numMarkerSets = int(lineSplit[3])
            markerSets = lineSplit[4:]

            if curLineage != lineage or numGenomes < minGenomes or numMarkerSets < minMarkerSets:
                continue

            for ms in markerSets:
                markers = ms.split(',')
                for m in markers:
                    if 'pfam' in m:
                        pfamIds.add(m.strip())
                    elif 'TIGR' in m:
                        tigrIds.add(m.strip())

        return pfamIds, tigrIds

    def getCalculatedMarkerGenes(self, minGenomes = 20, minMarkerSets = 20):
        pfamIds = set()
        tigrIds = set()

        bHeader = True
        for line in open(self.colocatedFile):
            if bHeader:
                bHeader = False
                continue

            lineSplit = line.split('\t')
            numGenomes = int(lineSplit[1])
            numMarkerSets = int(lineSplit[3])
            markerSets = lineSplit[4:]

            if numGenomes < minGenomes or numMarkerSets < minMarkerSets:
                continue

            for ms in markerSets:
                markers = ms.split(',')
                for m in markers:
                    if 'pfam' in m:
                        pfamIds.add(m.strip())
                    elif 'TIGR' in m:
                        tigrIds.add(m.strip())

        return pfamIds, tigrIds

    def markerGenes(self, genomeIds, countTable, ubiquityThreshold, singleCopyThreshold):
        if ubiquityThreshold < 1 or singleCopyThreshold < 1:
            print '[Warning] Looks like degenerate threshold.'

        # find genes meeting ubiquity and single-copy thresholds
        markers = set()
        for clusterId, genomeCounts in countTable.iteritems():
            ubiquity = 0
            singleCopy = 0
            
            if len(genomeCounts) < ubiquityThreshold:
                # gene is clearly not ubiquitous
                continue
            
            for genomeId in genomeIds:
                count = genomeCounts.get(genomeId, 0)

                if count > 0:
                    ubiquity += 1

                if count == 1:
                    singleCopy += 1

            if ubiquity >= ubiquityThreshold and singleCopy >= singleCopyThreshold:
                markers.add(clusterId)

        return markers

    def colocatedGenes(self, geneDistTable, distThreshold = 5000, genomeThreshold = 0.95):
        """Identify co-located gene pairs."""
                
        colocatedGenes = defaultdict(int)
        for _, clusterIdToGeneLocs in geneDistTable.iteritems():
            clusterIds = clusterIdToGeneLocs.keys()
            for i, clusterId1 in enumerate(clusterIds):
                geneLocations1 = clusterIdToGeneLocs[clusterId1]
                
                for clusterId2 in clusterIds[i+1:]:
                    geneLocations2 = clusterIdToGeneLocs[clusterId2]
                    bColocated = False
                    for p1 in geneLocations1:
                        for p2 in geneLocations2:
                            if abs(p1[0] - p2[0]) < distThreshold:
                                bColocated = True
                                break

                        if bColocated:
                            break

                    if bColocated:
                        if clusterId1 <= clusterId2:
                            colocatedStr = clusterId1 + '-' + clusterId2
                        else:
                            colocatedStr = clusterId2 + '-' + clusterId1
                        colocatedGenes[colocatedStr] += 1

        colocated = []
        for colocatedStr, count in colocatedGenes.iteritems():
            if float(count)/len(geneDistTable) > genomeThreshold:
                colocated.append(colocatedStr)

        return colocated

    def colocatedSets(self, colocatedGenes, markerGenes):
        # run through co-located genes once creating initial sets
        sets = []
        for cg in colocatedGenes:
            geneA, geneB = cg.split('-')
            sets.append(set([geneA, geneB]))

        # combine any sets with overlapping genes
        bProcessed = [False]*len(sets)
        finalSets = []
        for i in xrange(0, len(sets)):
            if bProcessed[i]:
                continue

            curSet = sets[i]
            bProcessed[i] = True

            bUpdated = True
            while bUpdated:
                bUpdated = False
                for j in xrange(i+1, len(sets)):
                    if bProcessed[j]:
                        continue

                    if len(curSet.intersection(sets[j])) > 0:
                        curSet.update(sets[j])
                        bProcessed[j] = True
                        bUpdated = True

            finalSets.append(curSet)

        # add all singletons into colocated sets
        for clusterId in markerGenes:
            bFound = False
            for cs in finalSets:
                if clusterId in cs:
                    bFound = True

            if not bFound:
                finalSets.append(set([clusterId]))

        return finalSets

    def genomeCheck(self, colocatedSet, genomeId, countTable):
        comp = 0.0
        cont = 0.0
        missingMarkers = set()
        duplicateMarkers = set()

        if len(colocatedSet) == 0:
            return comp, cont, missingMarkers, duplicateMarkers

        for cs in colocatedSet:
            present = 0
            multiCopy = 0
            for contigId in cs:
                count = countTable[contigId].get(genomeId, 0)
                if count == 1:
                    present += 1
                elif count > 1:
                    present += 1
                    multiCopy += (count-1)
                    duplicateMarkers.add(contigId)
                elif count == 0:
                    missingMarkers.add(contigId)

            comp += float(present) / len(cs)
            cont += float(multiCopy) / len(cs)

        return comp / len(colocatedSet), cont / len(colocatedSet), missingMarkers, duplicateMarkers

    def uniformity(self, genomeSize, pts):
        U = float(genomeSize) / (len(pts)+1)  # distance between perfectly evenly spaced points

        # calculate distance between adjacent points
        dists = []
        pts = sorted(pts)
        for i in xrange(0, len(pts)-1):
            dists.append(pts[i+1] - pts[i])

        # calculate uniformity index
        num = 0
        den = 0
        for d in dists:
            num += abs(d - U)
            den += max(d, U)

        return 1.0 - num/den
    
    def sampleGenome(self, genomeLen, percentComp, percentCont, contigLen):
        """Sample a genome to simulate a given percent completion and contamination."""
        
        contigsInGenome = genomeLen / contigLen

        # determine number of contigs to achieve desired completeness and contamination
        contigsToSampleComp = int(contigsInGenome*percentComp + 0.5)
        contigsToSampleCont = int(contigsInGenome*percentCont + 0.5)

        # randomly sample contigs with contamination done via sampling with replacement
        compContigs = random.sample(xrange(contigsInGenome), contigsToSampleComp)  
        contContigs = choice(xrange(contigsInGenome), contigsToSampleCont, replace=True)
    
        # determine start of each contig
        contigStarts = [c*contigLen for c in compContigs]
        contigStarts += [c*contigLen for c in contContigs]
            
        contigStarts.sort()
        
        trueComp = float(contigsToSampleComp)*contigLen*100 / genomeLen
        trueCont = float(contigsToSampleCont)*contigLen*100 / genomeLen

        return trueComp, trueCont, contigStarts
    
    def sampleGenomeScaffoldsInvLength(self, targetPer, seqLens, genomeSize):
        """Sample genome comprised of several sequences with probability inversely proportional to length."""
        
        # calculate probability of sampling each sequence
        seqProb = []
        for _, seqLen in seqLens.iteritems():
            prob = 1.0 / (float(seqLen) / genomeSize)
            seqProb.append(prob)
            
        seqProb = array(seqProb)
        seqProb /= sum(seqProb)
            
        # select sequences with probability inversely proportional to their length
        selectedSeqsIds = choice(seqLens.keys(), size = len(seqLens), replace=False, p = seqProb)
        
        sampledSeqIds = []
        truePer = 0.0
        for seqId in selectedSeqsIds:
            sampledSeqIds.append(seqId)
            truePer += float(seqLens[seqId]) / genomeSize
            
            if truePer >= targetPer:
                break
        
        return sampledSeqIds, truePer*100
    
    def sampleGenomeScaffoldsWithoutReplacement(self, targetPer, seqLens, genomeSize):
        """Sample genome comprised of several sequences without replacement.
        
          Sampling is conducted randomly until the selected sequences comprise
          greater than or equal to the desired target percentage.
        """
 
        selectedSeqsIds = choice(seqLens.keys(), size = len(seqLens), replace=False)
        
        sampledSeqIds = []
        truePer = 0.0
        for seqId in selectedSeqsIds:
            sampledSeqIds.append(seqId)
            truePer += float(seqLens[seqId]) / genomeSize
            
            if truePer >= targetPer:
                break
        
        return sampledSeqIds, truePer*100
    
    def containedMarkerGenes(self, markerGenes, clusterIdToGenomePositions, startPartialGenomeContigs, contigLen):
        """Determine markers contained in a set of contigs."""
        
        contained = {}
        for markerGene in markerGenes:
            positions = clusterIdToGenomePositions.get(markerGene, [])

            containedPos = []
            for p in positions:
                for s in startPartialGenomeContigs:
                    if (p[0] - s) >= 0 and (p[0] - s) < contigLen:
                        containedPos.append(s)

            if len(containedPos) > 0:
                contained[markerGene] = containedPos

        return contained
    
    def markerGenesOnScaffolds(self, markerGenes, genomeId, scaffoldIds, containedMarkerGenes):
        """Determine if marker genes are found on the scaffolds of a given genome."""
        for markerGeneId in markerGenes:
            scaffoldIdsWithMarker = self.img.cachedGenomeFamilyScaffolds[genomeId].get(markerGeneId, [])

            for scaffoldId in scaffoldIdsWithMarker:
                if scaffoldId in scaffoldIds:
                    containedMarkerGenes[markerGeneId] += [scaffoldId]
        
    def readDuplicateSeqs(self):
        """Parse file indicating duplicate sequence alignments."""
        duplicateSeqs = {}
        for line in open(os.path.join('/srv/whitlam/bio/db/checkm/genome_tree', 'genome_tree.derep.txt')):
            lineSplit = line.rstrip().split()
            if len(lineSplit) > 1:
                duplicateSeqs[lineSplit[0]] = lineSplit[1:]
                
        return duplicateSeqs
        
    def __readNodeMetadata(self):
        """Read metadata for internal nodes."""
        
        uniqueIdToLineageStatistics = {}
        metadataFile = os.path.join('/srv/whitlam/bio/db/checkm/genome_tree', 'genome_tree.metadata.tsv')
        with open(metadataFile) as f:
            f.readline()
            for line in f:
                lineSplit = line.rstrip().split('\t')
                
                uniqueId = lineSplit[0]
                
                d = {}
                d['# genomes'] = int(lineSplit[1])
                d['taxonomy'] = lineSplit[2]
                try:
                    d['bootstrap'] = float(lineSplit[3])
                except:
                    d['bootstrap'] = 'NA'                 
                d['gc mean'] = float(lineSplit[4])
                d['gc std'] = float(lineSplit[5])
                d['genome size mean'] = float(lineSplit[6])/1e6
                d['genome size std'] = float(lineSplit[7])/1e6
                d['gene count mean'] = float(lineSplit[8])
                d['gene count std'] = float(lineSplit[9])
                d['marker set'] = lineSplit[10].rstrip()
                
                uniqueIdToLineageStatistics[uniqueId] = d
                
        return uniqueIdToLineageStatistics
    
    def __getNextNamedNode(self, node, uniqueIdToLineageStatistics):
        """Get first parent node with taxonomy information."""
        parentNode = node.parent_node
        while True:
            if parentNode == None:
                break # reached the root node so terminate
            
            if parentNode.label:
                trustedUniqueId = parentNode.label.split('|')[0]
                trustedStats = uniqueIdToLineageStatistics[trustedUniqueId]
                if trustedStats['taxonomy'] != '':
                    return trustedStats['taxonomy']
                
            parentNode = parentNode.parent_node            
                    
        return 'root'
    
    def __refineMarkerSet(self, markerSet, lineageSpecificMarkerSet):
        """Refine marker set to account for lineage-specific gene loss and duplication."""
                                        
        # refine marker set by finding the intersection between these two sets,
        # this removes markers that are not single-copy or ubiquitous in the 
        # specific lineage of a bin
        # Note: co-localization information is taken from the trusted set
                
        # remove genes not present in the lineage-specific gene set
        finalMarkerSet = []
        for ms in markerSet.markerSet:
            s = set()
            for gene in ms:
                if gene in lineageSpecificMarkerSet.getMarkerGenes():
                    s.add(gene)
                           
            if s:
                finalMarkerSet.append(s)

        refinedMarkerSet = MarkerSet(markerSet.UID, markerSet.lineageStr, markerSet.numGenomes, finalMarkerSet)
    
        return refinedMarkerSet
    
    def ____removeInvalidLineageMarkerGenes(self, markerSet, lineageSpecificMarkersToRemove):
        """Refine marker set to account for lineage-specific gene loss and duplication."""
                                        
        # refine marker set by removing marker genes subject to lineage-specific
        # gene loss and duplication 
        #
        # Note: co-localization information is taken from the trusted set
                
        finalMarkerSet = []
        for ms in markerSet.markerSet:
            s = set()
            for gene in ms:
                if gene.startswith('PF'):
                    print 'ERROR! Expected genes to start with pfam, not PF.'
                    
                if gene not in lineageSpecificMarkersToRemove:
                    s.add(gene)
                           
            if s:
                finalMarkerSet.append(s)

        refinedMarkerSet = MarkerSet(markerSet.UID, markerSet.lineageStr, markerSet.numGenomes, finalMarkerSet)
    
        return refinedMarkerSet
    
    def missingGenes(self, genomeIds, markerGenes, ubiquityThreshold):
        """Inferring consistently missing marker genes within a set of genomes."""
        
        if self.cachedGeneCountTable != None:
            geneCountTable = self.cachedGeneCountTable
        else:
            geneCountTable = self.img.geneCountTable(genomeIds)
        
        # find genes meeting ubiquity and single-copy thresholds
        missing = set()
        for clusterId, genomeCounts in geneCountTable.iteritems():
            if clusterId not in markerGenes:
                continue
                 
            absence = 0 
            for genomeId in genomeIds:
                count = genomeCounts.get(genomeId, 0)

                if count == 0:
                    absence += 1

            if absence >= ubiquityThreshold*len(genomeIds):
                missing.add(clusterId)

        return missing
    
    def duplicateGenes(self, genomeIds, markerGenes, ubiquityThreshold):
        """Inferring consistently duplicated marker genes within a set of genomes."""
        
        if self.cachedGeneCountTable != None:
            geneCountTable = self.cachedGeneCountTable
        else:
            geneCountTable = self.img.geneCountTable(genomeIds)
        
        # find genes meeting ubiquity and single-copy thresholds
        duplicate = set()
        for clusterId, genomeCounts in geneCountTable.iteritems():
            if clusterId not in markerGenes:
                continue
                 
            duplicateCount = 0 
            for genomeId in genomeIds:
                count = genomeCounts.get(genomeId, 0)

                if count > 1:
                    duplicateCount += 1

            if duplicateCount >= ubiquityThreshold*len(genomeIds):
                duplicate.add(clusterId)

        return duplicate
    
    def buildMarkerGenes(self, genomeIds, ubiquityThreshold, singleCopyThreshold):
        """Infer marker genes from specified genomes."""
        
        if self.cachedGeneCountTable != None:
            geneCountTable = self.cachedGeneCountTable
        else:
            geneCountTable = self.img.geneCountTable(genomeIds)
        
        #counts = []
        #singleCopy = 0
        #for genomeId, count in geneCountTable['pfam01351'].iteritems():
        #    print genomeId, count
        #    counts.append(count)
        #    if count == 1:
        #        singleCopy += 1
            
        #print 'Ubiquity: %d of %d' % (len(counts), len(genomeIds))
        #print 'Single-copy: %d of %d' % (singleCopy, len(genomeIds))
        #print 'Mean: %.2f' % mean(counts)

        markerGenes = self.markerGenes(genomeIds, geneCountTable, ubiquityThreshold*len(genomeIds), singleCopyThreshold*len(genomeIds))
        tigrToRemove = self.img.identifyRedundantTIGRFAMs(markerGenes)
        markerGenes = markerGenes - tigrToRemove

        return markerGenes
    
    def buildMarkerSet(self, genomeIds, ubiquityThreshold, singleCopyThreshold, spacingBetweenContigs = 5000):  
        """Infer marker set from specified genomes."""      

        markerGenes = self.buildMarkerGenes(genomeIds, ubiquityThreshold, singleCopyThreshold)

        geneDistTable = self.img.geneDistTable(genomeIds, markerGenes, spacingBetweenContigs)
        colocatedGenes = self.colocatedGenes(geneDistTable)
        colocatedSets = self.colocatedSets(colocatedGenes, markerGenes)
        markerSet = MarkerSet(0, 'NA', len(genomeIds), colocatedSets)

        return markerSet
    
    def readLineageSpecificGenesToRemove(self):
        """Get set of genes subject to lineage-specific gene loss and duplication."""       
    
        self.lineageSpecificGenesToRemove = {}
        for line in open('/srv/whitlam/bio/db/checkm/genome_tree/missing_duplicate_genes_50.tsv'):
            lineSplit = line.split('\t')
            uid = lineSplit[0]
            missingGenes = eval(lineSplit[1])
            duplicateGenes = eval(lineSplit[2])
            self.lineageSpecificGenesToRemove[uid] = missingGenes.union(duplicateGenes)
            
    def buildBinMarkerSet(self, tree, curNode, ubiquityThreshold, singleCopyThreshold, bMarkerSet = True, genomeIdsToRemove = None):   
        """Build lineage-specific marker sets for a genome in a LOO-fashion."""
                               
        # determine marker sets for bin      
        binMarkerSets = BinMarkerSets(curNode.label, BinMarkerSets.TREE_MARKER_SET)
        refinedBinMarkerSet = BinMarkerSets(curNode.label, BinMarkerSets.TREE_MARKER_SET)         

        # ascend tree to root, recording all marker sets 
        uniqueId = curNode.label.split('|')[0] 
        lineageSpecificRefinement = self.lineageSpecificGenesToRemove[uniqueId]
        
        while curNode != None:
            uniqueId = curNode.label.split('|')[0] 
            stats = self.uniqueIdToLineageStatistics[uniqueId]
            taxonomyStr = stats['taxonomy']
            if taxonomyStr == '':
                taxonomyStr = self.__getNextNamedNode(curNode, self.uniqueIdToLineageStatistics)

            leafNodes = curNode.leaf_nodes()
            genomeIds = set()
            for leaf in leafNodes:
                genomeIds.add(leaf.taxon.label.replace('IMG_', ''))
                
                duplicateGenomes = self.duplicateSeqs.get(leaf.taxon.label, [])
                for dup in duplicateGenomes:
                    genomeIds.add(dup.replace('IMG_', ''))

            # remove all genomes from the same taxonomic group as the genome of interest
            if genomeIdsToRemove is not None:
                genomeIds.difference_update(genomeIdsToRemove) 

            if len(genomeIds) >= 2:
                if bMarkerSet:
                    markerSet = self.buildMarkerSet(genomeIds, ubiquityThreshold, singleCopyThreshold)
                else:
                    markerSet = MarkerSet(0, 'NA', len(genomeIds), [self.buildMarkerGenes(genomeIds, ubiquityThreshold, singleCopyThreshold)])
                
                markerSet.lineageStr = uniqueId + ' | ' + taxonomyStr.split(';')[-1]
                binMarkerSets.addMarkerSet(markerSet)
        
                #refinedMarkerSet = self.__refineMarkerSet(markerSet, lineageSpecificMarkerSet)
                refinedMarkerSet = self.____removeInvalidLineageMarkerGenes(markerSet, lineageSpecificRefinement)
                #print 'Refinement: %d of %d' % (len(refinedMarkerSet.getMarkerGenes()), len(markerSet.getMarkerGenes()))
                refinedBinMarkerSet.addMarkerSet(refinedMarkerSet)
            
            curNode = curNode.parent_node
                
        return binMarkerSets, refinedBinMarkerSet
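
    # Precondition inferred from the code above: readLineageSpecificGenesToRemove()
    # must have been called so that self.lineageSpecificGenesToRemove is populated,
    # and node labels are expected to have the form '<uniqueId>|...'.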
    
    def buildDomainMarkerSet(self, tree, curNode, ubiquityThreshold, singleCopyThreshold, bMarkerSet = True, genomeIdsToRemove = None):   
        """Build domain-specific marker sets for a genome in a LOO-fashion."""
                               
        # determine marker sets for bin      
        binMarkerSets = BinMarkerSets(curNode.label, BinMarkerSets.TREE_MARKER_SET)
        refinedBinMarkerSet = BinMarkerSets(curNode.label, BinMarkerSets.TREE_MARKER_SET)         

        # calculate marker set for bacterial or archaeal node
        uniqueId = curNode.label.split('|')[0] 
        lineageSpecificRefinement = self.lineageSpecificGenesToRemove[uniqueId]
        
        while curNode is not None:
            uniqueId = curNode.label.split('|')[0] 
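            # process only the two domain-level nodes; UID2 and UID203 are
            # presumed to identify the bacterial and archaeal root nodes of
            # the reference genome tree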
            if uniqueId != 'UID2' and uniqueId != 'UID203':
                curNode = curNode.parent_node
                continue

            stats = self.uniqueIdToLineageStatistics[uniqueId]
            taxonomyStr = stats['taxonomy']
            if taxonomyStr == '':
                taxonomyStr = self.__getNextNamedNode(curNode, self.uniqueIdToLineageStatistics)

            leafNodes = curNode.leaf_nodes()
            genomeIds = set()
            for leaf in leafNodes:
                genomeIds.add(leaf.taxon.label.replace('IMG_', ''))
                
                duplicateGenomes = self.duplicateSeqs.get(leaf.taxon.label, [])
                for dup in duplicateGenomes:
                    genomeIds.add(dup.replace('IMG_', ''))

            # remove all genomes from the same taxonomic group as the genome of interest
            if genomeIdsToRemove is not None:
                genomeIds.difference_update(genomeIdsToRemove) 

            if len(genomeIds) >= 2:
                if bMarkerSet:
                    markerSet = self.buildMarkerSet(genomeIds, ubiquityThreshold, singleCopyThreshold)
                else:
                    markerSet = MarkerSet(0, 'NA', len(genomeIds), [self.buildMarkerGenes(genomeIds, ubiquityThreshold, singleCopyThreshold)])
                
                markerSet.lineageStr = uniqueId + ' | ' + taxonomyStr.split(';')[-1]
                binMarkerSets.addMarkerSet(markerSet)
        
                #refinedMarkerSet = self.__refineMarkerSet(markerSet, lineageSpecificMarkerSet)
                refinedMarkerSet = self.____removeInvalidLineageMarkerGenes(markerSet, lineageSpecificRefinement)
                #print 'Refinement: %d of %d' % (len(refinedMarkerSet.getMarkerGenes()), len(markerSet.getMarkerGenes()))
                refinedBinMarkerSet.addMarkerSet(refinedMarkerSet)
            
            curNode = curNode.parent_node
                
        return binMarkerSets, refinedBinMarkerSet
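
    # Illustrative usage (not part of the original source): both build*MarkerSet
    # methods walk a reference genome tree upwards from a starting node. Here
    # 'tree', 'startNode' and 'testGenomeId' are placeholders for objects the
    # caller already has, and 'builder' is a MarkerSetBuilder instance as in the
    # sketches above.
    #
    #   builder.readLineageSpecificGenesToRemove()
    #   domainSets, refinedDomainSets = builder.buildDomainMarkerSet(
    #       tree, startNode, 0.97, 0.97,
    #       bMarkerSet=True, genomeIdsToRemove=[testGenomeId])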