Example #1
0
    def run(self):
        """Compute indel statistics for every alignment and plot them.

        An ``IndelCounter`` is built for each alignment yielded by
        ``samIterator``. If there is at least one, the aggregate node from
        ``getAggregateIndelStats`` is written to ``indels.xml``; selected
        attributes of that node are written as a tab-separated table (one
        column per statistic, header row first) to ``indels.tsv``; and
        ``nanopore/analyses/indelPlots.R`` renders ``indel_plots.pdf``. All
        outputs go to ``self.outputDir``. ``self.finish()`` is always called.
        """
        AbstractAnalysis.run(self) #Call base method to do some logging
        refSequences = getFastaDictionary(self.referenceFastaFile) #Hash of names to sequences
        readSequences = getFastqDictionary(self.readFastqFile) #Hash of names to sequences
        sam = pysam.Samfile(self.samFile, "r" )
        try:
            #Materialise a real list (not a lazy map) so the emptiness test below
            #behaves the same on Python 2 and 3.
            indelCounters = []
            for aR in samIterator(sam): #Iterate on the sam lines
                refName = sam.getrname(aR.rname)
                indelCounters.append(IndelCounter(refName, refSequences[refName], aR.qname, readSequences[aR.qname], aR))
        finally:
            sam.close()
        #Write out the indel info
        if indelCounters:
            indelXML = getAggregateIndelStats(indelCounters)
            with open(os.path.join(self.outputDir, "indels.xml"), "w") as xmlFile:
                xmlFile.write(prettyXml(indelXML))
            #build one vector per statistic, headed by the statistic's name
            var = ["readInsertionLengths", "readDeletionLengths", "ReadSequenceLengths", "NumberReadInsertions", "NumberReadDeletions", "MedianReadInsertionLengths", "MedianReadDeletionLengths"]
            columns = [ [x] + indelXML.attrib[x].split() for x in var ]
            #transpose so each statistic is a column (R loads this far faster). Shorter
            #vectors are padded with None, exactly as Python 2's map(None, *columns) did.
            #NOTE(review): padding is written as the text "None"; confirm indelPlots.R copes with it.
            nRows = max(len(column) for column in columns)
            rows = [ [ column[i] if i < len(column) else None for column in columns ] for i in range(nRows) ]
            tsvFile = os.path.join(self.outputDir, "indels.tsv")
            with open(tsvFile, "w") as tmp:
                for line in rows:
                    tmp.write("\t".join(map(str, line)) + "\n")
            system("Rscript nanopore/analyses/indelPlots.R {} {}".format(tsvFile, os.path.join(self.outputDir, "indel_plots.pdf")))

        self.finish() #Indicates the batch is done
Example #2
0
    def run(self):
        """Rescore each alignment with cactus_realign and report match posteriors.

        For every alignment from ``samIterator`` the reference and the aligned read
        (``aR.query``) are written to temporary FASTA files. ``cactus_realign`` then
        rescores the alignment's Exonerate-format cigar using the HMM in
        ``nanopore/mappers/blasr_hmm_0.txt``. The rescored cigar's score (collected
        as the read's average posterior match probability) and its number of
        aligned match pairs are recorded per read.

        The per-read averages, the aligned-pair-weighted average and both raw
        lists are written to ``alignmentUncertainty.xml`` in ``self.outputDir``.
        If any read was processed, ``nanopore/analyses/match_hist.R`` plots the
        per-read values to ``posterior_prob_hist.pdf``.
        """
        AbstractAnalysis.run(self) #Call base method to do some logging
        refSequences = getFastaDictionary(self.referenceFastaFile) #Hash of names to sequences
        sam = pysam.Samfile(self.samFile, "r" )

        #Temporary files; the same paths are overwritten for every alignment.
        tempDir = self.getLocalTempDir()
        tempCigarFile = os.path.join(tempDir, "rescoredCigar.cig")
        tempRefFile = os.path.join(tempDir, "ref.fa")
        tempReadFile = os.path.join(tempDir, "read.fa")
        tempPosteriorProbsFile = os.path.join(tempDir, "probs.tsv")

        #Trained hmm file to use.
        hmmFile = os.path.join(pathToBaseNanoporeDir(), "nanopore", "mappers", "blasr_hmm_0.txt")

        #The data we collect
        avgPosteriorMatchProbabilityInCigar = []
        alignedPairsInCigar = []

        try:
            for aR in samIterator(sam): #Iterate on the sam lines
                #Exonerate format Cigar string
                cigarString = getExonerateCigarFormatString(aR, sam)

                #Write the temporary files.
                refName = sam.getrname(aR.rname)
                fastaWrite(tempRefFile, refName, refSequences[refName])
                fastaWrite(tempReadFile, aR.qname, aR.query)

                #Call to cactus_realign
                system("echo %s | cactus_realign %s %s --rescoreByPosteriorProbIgnoringGaps --rescoreOriginalAlignment --diagonalExpansion=10 --splitMatrixBiggerThanThis=100 --outputPosteriorProbs=%s --loadHmm=%s > %s" % \
                       (cigarString, tempRefFile, tempReadFile, tempPosteriorProbsFile, hmmFile, tempCigarFile))

                #Load the rescored cigar once; exactly one alignment is expected.
                with open(tempCigarFile) as cigarFile:
                    rescored = list(cigarRead(cigarFile))
                assert len(rescored) == 1
                pA = rescored[0]
                avgPosteriorMatchProbabilityInCigar.append(pA.score)

                #Number of aligned pairs in the cigar; must agree with the SAM record.
                alignedPairs = sum(op.length for op in pA.operationList if op.type == PairwiseAlignment.PAIRWISE_MATCH)
                alignedPairsInCigar.append(alignedPairs)
                assert alignedPairs == len([ readPos for readPos, refPos in aR.aligned_pairs if readPos is not None and refPos is not None ])
        finally:
            sam.close()

        #Write out the alignment uncertainty info
        weightedMatchProbabilitySum = float(sum(avgMatchProb*alignedPairs for avgMatchProb, alignedPairs in zip(avgPosteriorMatchProbabilityInCigar, alignedPairsInCigar)))
        node = ET.Element("alignmentUncertainty", {
                "averagePosteriorMatchProbabilityPerRead":str(self.formatRatio(sum(avgPosteriorMatchProbabilityInCigar), len(avgPosteriorMatchProbabilityInCigar))),
                "averagePosteriorMatchProbability":str(self.formatRatio(weightedMatchProbabilitySum, sum(alignedPairsInCigar))),
                "averagePosteriorMatchProbabilitesPerRead":",".join(map(str, avgPosteriorMatchProbabilityInCigar)),
                "alignedPairsInCigar":",".join(map(str, alignedPairsInCigar)) })
        with open(os.path.join(self.outputDir, "alignmentUncertainty.xml"), "w") as xmlFile:
            xmlFile.write(prettyXml(node))
        if avgPosteriorMatchProbabilityInCigar:
            uncertaintyFile = os.path.join(tempDir, "tmp_uncertainty")
            with open(uncertaintyFile, "w") as outf:
                outf.write("\t".join(map(str, avgPosteriorMatchProbabilityInCigar)) + "\n")
            system("Rscript nanopore/analyses/match_hist.R {} {}".format(uncertaintyFile, os.path.join(self.outputDir, "posterior_prob_hist.pdf")))
        #Indicate everything is all done
        self.finish()
Example #3
0
 def run(self, globalAlignment=False):
     """Compute read alignment coverage statistics and plot them.

     A ``ReadAlignmentCoverageCounter`` is built for every alignment from
     ``samIterator`` and grouped by read name. Two subsets are then reported:
     ``coverage_all`` (every alignment) and ``coverage_bestPerRead`` (for each
     read, the alignment with the highest ``readCoverage()``). For each subset an
     XML summary, a text table and an R-rendered PDF are written to
     ``self.outputDir``. ``self.finish()`` is always called.

     :param globalAlignment: passed through to ``ReadAlignmentCoverageCounter``.
     """
     AbstractAnalysis.run(self) #Call base method to do some logging
     refSequences = getFastaDictionary(self.referenceFastaFile) #Hash of names to sequences
     readSequences = getFastqDictionary(self.readFastqFile) #Hash of names to sequences
     sam = pysam.Samfile(self.samFile, "r" )
     readsToReadCoverages = {}
     try:
         for aR in samIterator(sam): #Iterate on the sam lines
             refName = sam.getrname(aR.rname)
             readAlignmentCoverageCounter = ReadAlignmentCoverageCounter(aR.qname, readSequences[aR.qname], refName, refSequences[refName], aR, globalAlignment)
             readsToReadCoverages.setdefault(aR.qname, []).append(readAlignmentCoverageCounter)
     finally:
         sam.close()
     #Write out the coverage info for differing subsets of the read alignments
     if readsToReadCoverages:
         perReadCoverages = list(readsToReadCoverages.values())
         #Flatten in one pass (reduce(+) over lists re-copies the accumulator each step: quadratic).
         allCoverages = [ counter for counters in perReadCoverages for counter in counters ]
         bestPerRead = [ max(counters, key=lambda c : c.readCoverage()) for counters in perReadCoverages ]
         #(label written to the text file, attribute read from the XML node)
         textRows = [ ("MappedReadLengths", "mappedReadLengths"),
                      ("UnmappedReadLengths", "unmappedReadLengths"),
                      ("ReadCoverage", "distributionreadCoverage"),
                      ("MismatchesPerReadBase", "distributionmismatchesPerReadBase"),
                      ("ReadIdentity", "distributionidentity"),
                      ("InsertionsPerBase", "distributioninsertionsPerReadBase"),
                      ("DeletionsPerBase", "distributiondeletionsPerReadBase") ]
         for readCoverages, outputName in [ (allCoverages, "coverage_all"), (bestPerRead, "coverage_bestPerRead") ]:
             parentNode = getAggregateCoverageStats(readCoverages, outputName, refSequences, readSequences, readsToReadCoverages, outputName)
             with open(os.path.join(self.outputDir, outputName + ".xml"), 'w') as xmlFile:
                 xmlFile.write(prettyXml(parentNode))
             #this is a ugly file format with each line being a different data type - column length is variable
             txtFile = os.path.join(self.outputDir, outputName + ".txt")
             with open(txtFile, "w") as outf:
                 for label, attribute in textRows:
                     outf.write(label + " " + parentNode.get(attribute) + "\n")
             system("Rscript nanopore/analyses/coverage_plot.R {} {}".format(txtFile, os.path.join(self.outputDir, outputName + ".pdf")))
     self.finish()
Example #4
0
 def run(self, globalAlignment=False):
     """Compute read alignment coverage statistics and plot them.

     A ``ReadAlignmentCoverageCounter`` is built for every alignment from
     ``samIterator`` and grouped by read name. Two subsets are then reported:
     ``coverage_all`` (every alignment) and ``coverage_bestPerRead`` (for each
     read, the alignment with the highest ``readCoverage()``). For each subset an
     XML summary, a text table and an R-rendered PDF are written to
     ``self.outputDir``. ``self.finish()`` is always called.

     :param globalAlignment: passed through to ``ReadAlignmentCoverageCounter``.
     """
     AbstractAnalysis.run(self)  #Call base method to do some logging
     refSequences = getFastaDictionary(
         self.referenceFastaFile)  #Hash of names to sequences
     readSequences = getFastqDictionary(
         self.readFastqFile)  #Hash of names to sequences
     sam = pysam.Samfile(self.samFile, "r")
     readsToReadCoverages = {}
     try:
         for aR in samIterator(sam):  #Iterate on the sam lines
             refName = sam.getrname(aR.rname)
             readAlignmentCoverageCounter = ReadAlignmentCoverageCounter(
                 aR.qname, readSequences[aR.qname], refName,
                 refSequences[refName], aR, globalAlignment)
             readsToReadCoverages.setdefault(
                 aR.qname, []).append(readAlignmentCoverageCounter)
     finally:
         sam.close()
     #Write out the coverage info for differing subsets of the read alignments
     if readsToReadCoverages:
         perReadCoverages = list(readsToReadCoverages.values())
         #Flatten in one pass (reduce(+) over lists re-copies the
         #accumulator each step: quadratic).
         allCoverages = [
             counter for counters in perReadCoverages for counter in counters
         ]
         bestPerRead = [
             max(counters, key=lambda c: c.readCoverage())
             for counters in perReadCoverages
         ]
         #(label written to the text file, attribute read from the XML node)
         textRows = [
             ("MappedReadLengths", "mappedReadLengths"),
             ("UnmappedReadLengths", "unmappedReadLengths"),
             ("ReadCoverage", "distributionreadCoverage"),
             ("MismatchesPerReadBase", "distributionmismatchesPerReadBase"),
             ("ReadIdentity", "distributionidentity"),
             ("InsertionsPerBase", "distributioninsertionsPerReadBase"),
             ("DeletionsPerBase", "distributiondeletionsPerReadBase"),
         ]
         for readCoverages, outputName in [
             (allCoverages, "coverage_all"),
             (bestPerRead, "coverage_bestPerRead"),
         ]:
             parentNode = getAggregateCoverageStats(
                 readCoverages, outputName, refSequences, readSequences,
                 readsToReadCoverages, outputName)
             with open(os.path.join(self.outputDir, outputName + ".xml"),
                       'w') as xmlFile:
                 xmlFile.write(prettyXml(parentNode))
             #this is a ugly file format with each line being a different data type - column length is variable
             txtFile = os.path.join(self.outputDir, outputName + ".txt")
             with open(txtFile, "w") as outf:
                 for label, attribute in textRows:
                     outf.write(label + " " + parentNode.get(attribute) +
                                "\n")
             system(
                 "Rscript nanopore/analyses/coverage_plot.R {} {}".format(
                     txtFile,
                     os.path.join(self.outputDir, outputName + ".pdf")))
     self.finish()
Example #5
0
 def run(self, args=""):
     """Map reads with blasr and write a SAM file with repaired read names.

     Runs ``blasr`` on ``self.readFastqFile`` against ``self.referenceFastaFile``
     (extra command-line options in ``args``) into a temporary SAM file. Every
     record is then copied to ``self.outputSamFile``. A read name not present in
     the FASTQ is retried with its last ``/``-separated component removed.

     :raises RuntimeError: if a read name cannot be matched even after trimming.
     """
     tempSamFile = os.path.join(self.getLocalTempDir(), "temp.sam")
     #NOTE(review): args is spliced into a shell command unescaped; only pass trusted options.
     system("blasr %s %s -sam -clipping hard %s > %s" % (self.readFastqFile, self.referenceFastaFile, args, tempSamFile))
     #Blasr seems to corrupt the names of read sequences, so lets correct them.
     sam = pysam.Samfile(tempSamFile, "r" )
     try:
         outputSam = pysam.Samfile(self.outputSamFile, "wh", template=sam)
         try:
             readSequences = getFastqDictionary(self.readFastqFile) #Hash of names to sequences
             for aR in sam: #Iterate on the sam lines
                 if aR.qname not in readSequences:
                     #Drop the trailing "/..." component ("" if there is no "/").
                     newName = aR.qname.rpartition('/')[0]
                     if newName not in readSequences:
                         raise RuntimeError("Tried to deduce correct read name: %s, %s" % (newName, readSequences.keys()))
                     aR.qname = newName
                 outputSam.write(aR)
         finally:
             outputSam.close()
     finally:
         #The input handle was previously never closed.
         sam.close()
Example #6
0
 def run(self):
     """Count reads and mapped reads per nanopore channel and plot them.

     Only read names matching ``channel_<n>_read_<m>`` (matched at the start of
     the name) are counted, grouped by ``<n>``. When both the total and the mapped
     counts are non-empty, ``channel_mappability.tsv`` is written to
     ``self.outputDir``. It has one row per channel from 1 up to at least 512
     (further if a higher channel was seen). ``nanopore/analyses/channel_plots.R``
     then renders the plots. ``self.finish()`` is always called.
     """
     AbstractAnalysis.run(self)
     readSequences = getFastqDictionary(self.readFastqFile)
     nr = re.compile(r"channel_[0-9]+_read_[0-9]+")

     def channelOf(readName):
         #"channel_12_read_3" -> 12
         return int(readName.split("_")[1])

     per_channel_read_counts = Counter(channelOf(x) for x in readSequences if nr.match(x))
     sam = pysam.Samfile(self.samFile, "r")
     try:
         mapped_read_counts = Counter(channelOf(aR.qname) for aR in samIterator(sam) if nr.match(aR.qname) and not aR.is_unmapped)
     finally:
         sam.close()
     if mapped_read_counts and per_channel_read_counts:
         #Report channels 1..512, or up to the highest channel seen if larger
         #(in case there are more than 512 in the future). The +1 below makes that
         #highest channel inclusive; it used to be silently dropped.
         max_channel = max(512, max(per_channel_read_counts), max(mapped_read_counts))
         tsvFile = os.path.join(self.outputDir, "channel_mappability.tsv")
         with open(tsvFile, "w") as outf:
             outf.write("Channel\tReadCount\tMappableReadCount\n")
             for channel in range(1, max_channel + 1):
                 #Counter yields 0 for channels with no reads.
                 outf.write("\t".join(map(str, [channel, per_channel_read_counts[channel], mapped_read_counts[channel]])))
                 outf.write("\n")
         system("Rscript nanopore/analyses/channel_plots.R {} {} {} {} {}".format(tsvFile, os.path.join(self.outputDir, "channel_mappability.pdf"), os.path.join(self.outputDir, "channel_mappability_sorted.png"), os.path.join(self.outputDir, "mappability_levelplot.png"), os.path.join(self.outputDir, "mappability_leveplot_percent.png")))
     self.finish()
Example #7
0
 def run(self, args=""):
     """Map reads with blasr and write a SAM file with repaired read names.

     Runs ``blasr`` on ``self.readFastqFile`` against ``self.referenceFastaFile``
     (extra command-line options in ``args``) into a temporary SAM file. Every
     record is then copied to ``self.outputSamFile``. A read name not present in
     the FASTQ is retried with its last ``/``-separated component removed.

     :raises RuntimeError: if a read name cannot be matched even after trimming.
     """
     tempSamFile = os.path.join(self.getLocalTempDir(), "temp.sam")
     #NOTE(review): args is spliced into a shell command unescaped; only
     #pass trusted options.
     system(
         "blasr %s %s -sam -clipping hard %s > %s" %
         (self.readFastqFile, self.referenceFastaFile, args, tempSamFile))
     #Blasr seems to corrupt the names of read sequences, so lets correct them.
     sam = pysam.Samfile(tempSamFile, "r")
     try:
         outputSam = pysam.Samfile(self.outputSamFile, "wh", template=sam)
         try:
             readSequences = getFastqDictionary(
                 self.readFastqFile)  #Hash of names to sequences
             for aR in sam:  #Iterate on the sam lines
                 if aR.qname not in readSequences:
                     #Drop the trailing "/..." component ("" if there is no "/").
                     newName = aR.qname.rpartition('/')[0]
                     if newName not in readSequences:
                         raise RuntimeError(
                             "Tried to deduce correct read name: %s, %s" %
                             (newName, readSequences.keys()))
                     aR.qname = newName
                 outputSam.write(aR)
         finally:
             outputSam.close()
     finally:
         #The input handle was previously never closed.
         sam.close()
Example #8
0
    def run(self):
        """Compute indel statistics for every alignment and plot them.

        An ``IndelCounter`` is built for each alignment yielded by
        ``samIterator``. If there is at least one, the aggregate node from
        ``getAggregateIndelStats`` is written to ``indels.xml``; selected
        attributes of that node are written as a tab-separated table (one
        column per statistic, header row first) to ``indels.tsv``; and
        ``nanopore/analyses/indelPlots.R`` renders ``indel_plots.pdf``. All
        outputs go to ``self.outputDir``. ``self.finish()`` is always called.
        """
        AbstractAnalysis.run(self)  #Call base method to do some logging
        refSequences = getFastaDictionary(
            self.referenceFastaFile)  #Hash of names to sequences
        readSequences = getFastqDictionary(
            self.readFastqFile)  #Hash of names to sequences
        sam = pysam.Samfile(self.samFile, "r")
        try:
            #Materialise a real list (not a lazy map) so the emptiness test
            #below behaves the same on Python 2 and 3.
            indelCounters = []
            for aR in samIterator(sam):  #Iterate on the sam lines
                refName = sam.getrname(aR.rname)
                indelCounters.append(
                    IndelCounter(refName, refSequences[refName], aR.qname,
                                 readSequences[aR.qname], aR))
        finally:
            sam.close()
        #Write out the indel info
        if indelCounters:
            indelXML = getAggregateIndelStats(indelCounters)
            with open(os.path.join(self.outputDir, "indels.xml"),
                      "w") as xmlFile:
                xmlFile.write(prettyXml(indelXML))
            #build one vector per statistic, headed by the statistic's name
            var = [
                "readInsertionLengths", "readDeletionLengths",
                "ReadSequenceLengths", "NumberReadInsertions",
                "NumberReadDeletions", "MedianReadInsertionLengths",
                "MedianReadDeletionLengths"
            ]
            columns = [[x] + indelXML.attrib[x].split() for x in var]
            #transpose so each statistic is a column (R loads this far
            #faster). Shorter vectors are padded with None, exactly as
            #Python 2's map(None, *columns) did.
            #NOTE(review): padding is written as the text "None"; confirm
            #indelPlots.R copes with it.
            nRows = max(len(column) for column in columns)
            rows = [[
                column[i] if i < len(column) else None for column in columns
            ] for i in range(nRows)]
            tsvFile = os.path.join(self.outputDir, "indels.tsv")
            with open(tsvFile, "w") as tmp:
                for line in rows:
                    tmp.write("\t".join(map(str, line)) + "\n")
            system("Rscript nanopore/analyses/indelPlots.R {} {}".format(
                tsvFile, os.path.join(self.outputDir, "indel_plots.pdf")))

        self.finish()  #Indicates the batch is done
Example #9
0
 def run(self, kmer=5):
     """Tabulate reference-to-read base substitutions and plot them.

     Every aligned pair from ``AlignedPair.iterator`` is added to a
     ``SubstitutionMatrix``. Its XML goes to ``substitutions.xml``, and the
     per-reference-base frequencies over ``ACGT`` (from ``getFreqs``) go to
     ``subst.tsv``. ``nanopore/analyses/substitution_plot.R`` then renders
     ``substitution_plot.pdf``. All outputs are written to ``self.outputDir``.

     :param kmer: not used by this method (kept for interface compatibility).
     """
     AbstractAnalysis.run(self) #Call base method to do some logging
     refSequences = getFastaDictionary(self.referenceFastaFile) #Hash of names to sequences
     readSequences = getFastqDictionary(self.readFastqFile) #Hash of names to sequences
     sM = SubstitutionMatrix() #The thing to store the counts in
     sam = pysam.Samfile(self.samFile, "r" )
     try:
         for aR in samIterator(sam): #Iterate on the sam lines
             for aP in AlignedPair.iterator(aR, refSequences[sam.getrname(aR.rname)], readSequences[aR.qname]): #Walk through the matches mismatches:
                 sM.addAlignedPair(aP.getRefBase(), aP.getReadBase())
     finally:
         sam.close()
     #Write out the substitution info
     with open(os.path.join(self.outputDir, "substitutions.xml"), 'w') as xmlFile:
         xmlFile.write(prettyXml(sM.getXML()))
     bases = "ACGT"
     substFile = os.path.join(self.outputDir, "subst.tsv")
     with open(substFile, "w") as outf:
         outf.write("A\tC\tG\tT\n")
         for x in bases:
             freqs = sM.getFreqs(x, bases)
             outf.write("{}\t{}\n".format(x, "\t".join(map(str, freqs))))
     #Plot title: text after the last "_" of the second-to-last "/"-separated
     #component of outputDir, plus a fixed suffix.
     analysis = self.outputDir.split("/")[-2].split("_")[-1] + "_Substitution_Levels"
     system("Rscript nanopore/analyses/substitution_plot.R {} {} {}".format(substFile, os.path.join(self.outputDir, "substitution_plot.pdf"), analysis))
     self.finish()
 def run(self):
     """Count reads and mapped reads per nanopore channel and plot them.

     Only read names matching ``channel_<n>_read_<m>`` (matched at the start of
     the name) are counted, grouped by ``<n>``. When both the total and the mapped
     counts are non-empty, ``channel_mappability.tsv`` is written to
     ``self.outputDir``. It has one row per channel from 1 up to at least 512
     (further if a higher channel was seen). ``nanopore/analyses/channel_plots.R``
     then renders the plots. ``self.finish()`` is always called.
     """
     AbstractAnalysis.run(self)
     readSequences = getFastqDictionary(self.readFastqFile)
     nr = re.compile(r"channel_[0-9]+_read_[0-9]+")

     def channelOf(readName):
         #"channel_12_read_3" -> 12
         return int(readName.split("_")[1])

     per_channel_read_counts = Counter(
         channelOf(x) for x in readSequences if nr.match(x))
     sam = pysam.Samfile(self.samFile, "r")
     try:
         mapped_read_counts = Counter(
             channelOf(aR.qname) for aR in samIterator(sam)
             if nr.match(aR.qname) and not aR.is_unmapped)
     finally:
         sam.close()
     if mapped_read_counts and per_channel_read_counts:
         #Report channels 1..512, or up to the highest channel seen if larger
         #(in case there are more than 512 in the future). The +1 below makes
         #that highest channel inclusive; it used to be silently dropped.
         max_channel = max(512, max(per_channel_read_counts),
                           max(mapped_read_counts))
         tsvFile = os.path.join(self.outputDir, "channel_mappability.tsv")
         with open(tsvFile, "w") as outf:
             outf.write("Channel\tReadCount\tMappableReadCount\n")
             for channel in range(1, max_channel + 1):
                 #Counter yields 0 for channels with no reads.
                 outf.write("\t".join(
                     map(str, [
                         channel, per_channel_read_counts[channel],
                         mapped_read_counts[channel]
                     ])))
                 outf.write("\n")
         system("Rscript nanopore/analyses/channel_plots.R {} {} {} {} {}".
                format(
                    tsvFile,
                    os.path.join(self.outputDir, "channel_mappability.pdf"),
                    os.path.join(self.outputDir,
                                 "channel_mappability_sorted.png"),
                    os.path.join(self.outputDir,
                                 "mappability_levelplot.png"),
                    os.path.join(self.outputDir,
                                 "mappability_leveplot_percent.png")))
     self.finish()
Example #11
0
    def run(self):
        """Plot summaries of the trained HMM stored next to the SAM file.

        If ``hmm.txt.xml`` exists in the directory of ``self.samFile``, the
        following are written into ``self.outputDir``:

        * ``hmm.dot``: a graphviz graph of the states, with an edge for each
          transition whose ``avg`` > 0, labelled "avg,std".
        * ``matchEmissions.tsv``: the state-0 emission averages.
        * ``indelEmissions.tsv``: state-2 emissions summed per ``x`` base and
          state-1 emissions summed per ``y`` base.
        * ``runninglikelihoods.tsv``: one row per ``hmm`` element.

        An R script is run to plot each table. ``self.finish()`` is called
        whether or not the HMM file exists.
        """
        #Call base method to do some logging
        AbstractAnalysis.run(self)

        #Hmm file
        hmmFile = os.path.join(os.path.split(self.samFile)[0], "hmm.txt.xml")
        if os.path.exists(hmmFile):
            #Load the hmm
            hmmsNode = ET.parse(hmmFile).getroot()
            bases = "ACGT"

            #Plot graphviz version of nanopore hmm, showing transitions and variances.
            with open(os.path.join(self.outputDir, "hmm.dot"), 'w') as fH:
                setupGraphFile(fH)
                #Make states
                for nodeName, stateLabel in (("n0n", "match"), ("n1n", "short delete"), ("n2n", "short insert"),
                                             ("n3n", "long insert"), ("n4n", "long delete")):
                    addNodeToGraph(nodeName, fH, stateLabel)

                #Make edges with labelled transition probs.
                for transition in hmmsNode.findall("transition"):
                    if float(transition.attrib["avg"]) > 0.0:
                        addEdgeToGraph("n%sn" % transition.attrib["from"],
                                       "n%sn" % transition.attrib["to"],
                                       fH, dir="arrow", style='""',
                                       label="%.3f,%.3f" % (float(transition.attrib["avg"]), float(transition.attrib["std"])))

                #Finish up
                finishGraphFile(fH)

            #Plot match emission data; values are kept as the raw attribute strings.
            emissions = { (emission.attrib["x"], emission.attrib["y"]) : emission.attrib["avg"]
                          for emission in hmmsNode.findall("emission") if emission.attrib["state"] == '0' }

            matchEmissionsFile = os.path.join(self.outputDir, "matchEmissions.tsv")
            with open(matchEmissionsFile, "w") as outf:
                outf.write("\t".join(bases) + "\n")
                for base in bases:
                    outf.write("\t".join([ base ] + [ emissions[(base, x)] for x in bases ]) + "\n")
            #The title contains spaces, so quote it for the shell; otherwise R
            #receives it as four separate arguments.
            system("Rscript nanopore/analyses/substitution_plot.R %s %s \"%s\"" % (matchEmissionsFile, os.path.join(self.outputDir, "substitution_plot.pdf"), "Per-Base Substitutions after HMM"))

            #Plot of insert and deletion gap emissions
            insertEmissions = dict.fromkeys(bases, 0.0)
            deleteEmissions = dict.fromkeys(bases, 0.0)
            for emission in hmmsNode.findall("emission"):
                if emission.attrib["state"] == '2':
                    insertEmissions[emission.attrib["x"]] += float(emission.attrib["avg"])
                elif emission.attrib["state"] == '1':
                    deleteEmissions[emission.attrib["y"]] += float(emission.attrib["avg"])
            indelEmissionsFile = os.path.join(self.outputDir, "indelEmissions.tsv")
            with open(indelEmissionsFile, "w") as outf:
                outf.write("\t".join(bases) + "\n")
                outf.write("\t".join(str(insertEmissions[x]) for x in bases) + "\n")
                outf.write("\t".join(str(deleteEmissions[x]) for x in bases) + "\n")
            system("Rscript nanopore/analyses/emissions_plot.R {} {}".format(indelEmissionsFile, os.path.join(self.outputDir, "indelEmissions_plot.pdf")))

            #Plot convergence of likelihoods
            likelihoodsFile = os.path.join(self.outputDir, "runninglikelihoods.tsv")
            with open(likelihoodsFile, "w") as outf:
                for hmmNode in hmmsNode.findall("hmm"): #This is a loop over trials
                    #List of floats ordered from the first iteration to last.
                    runningLikelihoods = [ float(x) for x in hmmNode.attrib["runningLikelihoods"].split() ]
                    outf.write("\t".join(map(str, runningLikelihoods)) + "\n")
            system("Rscript nanopore/analyses/running_likelihood.R {} {}".format(likelihoodsFile, os.path.join(self.outputDir, "running_likelihood.pdf")))

        self.finish() #Indicates the batch is done
    def run(self):
        AbstractAnalysis.run(self) #Call base method to do some logging
        refSequences = getFastaDictionary(self.referenceFastaFile) #Hash of names to sequences
        readSequences = getFastqDictionary(self.readFastqFile) #Hash of names to sequences
        
        node = ET.Element("marginAlignComparison")
        for hmmType in ("cactus", "trained_0",  "trained_20", "trained_40"): 
            for coverage in (1000000, 120, 60, 30, 10): 
                for replicate in xrange(3 if coverage < 1000000 else 1): #Do replicates, unless coverage is all
                    sam = pysam.Samfile(self.samFile, "r" )
                    
                    #Trained hmm file to use.q
                    hmmFile0 = os.path.join(pathToBaseNanoporeDir(), "nanopore", "mappers", "blasr_hmm_0.txt")
                    hmmFile20 = os.path.join(pathToBaseNanoporeDir(), "nanopore", "mappers", "blasr_hmm_20.txt")
                    hmmFile40 = os.path.join(pathToBaseNanoporeDir(), "nanopore", "mappers", "blasr_hmm_40.txt")
              
                    #Get substitution matrices
                    nullSubstitionMatrix = getNullSubstitutionMatrix()
                    flatSubstitutionMatrix = getJukesCantorTypeSubstitutionMatrix()
                    hmmErrorSubstitutionMatrix = loadHmmErrorSubstitutionMatrix(hmmFile20)
                
                    #Load the held out snps
                    snpSet = {}
                    referenceAlignmentFile = self.referenceFastaFile + "_Index.txt"
                    if os.path.exists(referenceAlignmentFile):
                        seqsAndMutatedSeqs = getFastaDictionary(referenceAlignmentFile)
                        count = 0
                        for name in seqsAndMutatedSeqs:
                            if name in refSequences:
                                count += 1
                                trueSeq = seqsAndMutatedSeqs[name]
                                mutatedSeq = seqsAndMutatedSeqs[name + "_mutated"]
                                assert mutatedSeq == refSequences[name]
                                for i in xrange(len(trueSeq)):
                                    if trueSeq[i] != mutatedSeq[i]:
                                        snpSet[(name, i)] = trueSeq[i] 
                            else:
                                assert name.split("_")[-1] == "mutated"
                        assert count == len(refSequences.keys())
                    
                    #The data we collect
                    expectationsOfBasesAtEachPosition = {}
                    frequenciesOfAlignedBasesAtEachPosition = {}
                    
                    totalSampledReads = 0
                    totalAlignedPairs = 0
                    totalReadLength = 0
                    totalReferenceLength = sum(map(len, refSequences.values()))
                    
                    #Get a randomised ordering for the reads
                    reads = [ aR for aR in samIterator(sam) ]
                    random.shuffle(reads)
                    
                    for aR in reads: #Iterate on the sam lines
                        if totalReadLength/totalReferenceLength >= coverage: #Stop when coverage exceeds the quota
                            break
                        totalReadLength += len(readSequences[aR.qname])
                        totalSampledReads += 1
                        
                        #Temporary files
                        tempCigarFile = os.path.join(self.getLocalTempDir(), "rescoredCigar.cig")
                        tempRefFile = os.path.join(self.getLocalTempDir(), "ref.fa")
                        tempReadFile = os.path.join(self.getLocalTempDir(), "read.fa")
                        tempPosteriorProbsFile = os.path.join(self.getLocalTempDir(), "probs.tsv")
                        
                        #Ref name
                        refSeqName = sam.getrname(aR.rname)
                        
                        #Sequences
                        refSeq = refSequences[sam.getrname(aR.rname)]
                        
                        #Walk through the aligned pairs to collate the bases of aligned positions
                        for aP in AlignedPair.iterator(aR, refSeq, readSequences[aR.qname]): 
                            totalAlignedPairs += 1 #Record an aligned pair
                            key = (refSeqName, aP.refPos)
                            if key not in frequenciesOfAlignedBasesAtEachPosition:
                                frequenciesOfAlignedBasesAtEachPosition[key] = dict(zip(bases, [0.0]*len(bases))) 
                            readBase = aP.getReadBase() #readSeq[aP.readPos].upper() #Use the absolute read, ins
                            if readBase in bases:
                                frequenciesOfAlignedBasesAtEachPosition[key][readBase] += 1
                        
                        #Write the temporary files.
                        readSeq = aR.query #This excludes bases that were soft-clipped and is always of positive strand coordinates
                        fastaWrite(tempRefFile, refSeqName, refSeq) 
                        fastaWrite(tempReadFile, aR.qname, readSeq)
                        
                        #Exonerate format Cigar string, which is in readSeq coordinates (positive strand).
                        assert aR.pos == 0
                        assert aR.qstart == 0
                        assert aR.qend == len(readSeq)
                        assert aR.aend == len(refSeq)

                        cigarString = getExonerateCigarFormatString(aR, sam)
                        
                        #Call to cactus_realign
                        if hmmType == "trained_0":
                            system("echo %s | cactus_realign %s %s --diagonalExpansion=10 --splitMatrixBiggerThanThis=100 --outputAllPosteriorProbs=%s --loadHmm=%s > %s" % \
                                   (cigarString, tempRefFile, tempReadFile, tempPosteriorProbsFile, hmmFile0, tempCigarFile))
                        elif hmmType == "trained_20":
                            system("echo %s | cactus_realign %s %s --diagonalExpansion=10 --splitMatrixBiggerThanThis=100 --outputAllPosteriorProbs=%s --loadHmm=%s > %s" % \
                                   (cigarString, tempRefFile, tempReadFile, tempPosteriorProbsFile, hmmFile20, tempCigarFile))
                        elif hmmType == "trained_40":
                            system("echo %s | cactus_realign %s %s --diagonalExpansion=10 --splitMatrixBiggerThanThis=100 --outputAllPosteriorProbs=%s --loadHmm=%s > %s" % \
                                   (cigarString, tempRefFile, tempReadFile, tempPosteriorProbsFile, hmmFile40, tempCigarFile))
                        else:
                            system("echo %s | cactus_realign %s %s --diagonalExpansion=10 --splitMatrixBiggerThanThis=100 --outputAllPosteriorProbs=%s > %s" % \
                                   (cigarString, tempRefFile, tempReadFile, tempPosteriorProbsFile, tempCigarFile))
                        
                        #Now collate the reference position expectations
                        for refPosition, readPosition, posteriorProb in map(lambda x : map(float, x.split()), open(tempPosteriorProbsFile, 'r')):
                            key = (refSeqName, int(refPosition))
                            if key not in expectationsOfBasesAtEachPosition:
                                expectationsOfBasesAtEachPosition[key] = dict(zip(bases, [0.0]*len(bases))) 
                            readBase = readSeq[int(readPosition)].upper()
                            if readBase in bases:
                                expectationsOfBasesAtEachPosition[key][readBase] += posteriorProb
                        
                        #Collate aligned positions from cigars
            
                    sam.close()
                    
                    totalHeldOut = len(snpSet)
                    totalNotHeldOut = totalReferenceLength - totalHeldOut
                    
                    class SnpCalls:
                        def __init__(self):
                            self.falsePositives = []
                            self.truePositives = []
                            self.falseNegatives = []
                            self.notCalled = 0
                        
                        @staticmethod
                        def bucket(calls):
                            calls = calls[:]
                            calls.sort()
                            buckets = [0.0]*101
                            for prob in calls: #Discretize
                                buckets[int(round(prob*100))] += 1
                            for i in xrange(len(buckets)-2, -1, -1): #Make cumulative
                                buckets[i] += buckets[i+1]
                            return buckets
                        
                        def getPrecisionByProbability(self):
                            tPs = self.bucket(map(lambda x : x[0], self.truePositives)) 
                            fPs = self.bucket(map(lambda x : x[0], self.falsePositives))
                            return map(lambda i : float(tPs[i]) / (tPs[i] + fPs[i]) if tPs[i] + fPs[i] != 0 else 0, xrange(len(tPs)))
                        
                        def getRecallByProbability(self):
                            return map(lambda i : i/totalHeldOut if totalHeldOut != 0 else 0, self.bucket(map(lambda x : x[0], self.truePositives)))
                        
                        def getTruePositiveLocations(self):
                            return map(lambda x : x[1], self.truePositives)
                        
                        def getFalsePositiveLocations(self):
                            return map(lambda x : x[1], self.falsePositives)
                        
                        def getFalseNegativeLocations(self):
                            return map(lambda x : x[0], self.falseNegatives)
            
                    #The different call sets
                    marginAlignMaxExpectedSnpCalls = SnpCalls()
                    marginAlignMaxLikelihoodSnpCalls = SnpCalls()
                    maxFrequencySnpCalls = SnpCalls()
                    maximumLikelihoodSnpCalls = SnpCalls()
                    
                    #Now calculate the calls
                    for refSeqName in refSequences:
                        refSeq = refSequences[refSeqName]
                        for refPosition in xrange(len(refSeq)):
                            mutatedRefBase = refSeq[refPosition].upper()
                            trueRefBase = (mutatedRefBase if not (refSeqName, refPosition) in snpSet else snpSet[(refSeqName, refPosition)]).upper()
                            key = (refSeqName, refPosition)
                            
                            
                            #Get base calls
                            for errorSubstitutionMatrix, evolutionarySubstitutionMatrix, baseExpectations, snpCalls in \
                            ((flatSubstitutionMatrix, nullSubstitionMatrix, expectationsOfBasesAtEachPosition, marginAlignMaxExpectedSnpCalls),
                             (hmmErrorSubstitutionMatrix, nullSubstitionMatrix, expectationsOfBasesAtEachPosition, marginAlignMaxLikelihoodSnpCalls),
                             (flatSubstitutionMatrix, nullSubstitionMatrix, frequenciesOfAlignedBasesAtEachPosition, maxFrequencySnpCalls),
                             (hmmErrorSubstitutionMatrix, nullSubstitionMatrix, frequenciesOfAlignedBasesAtEachPosition, maximumLikelihoodSnpCalls)):
                                
                                if key in baseExpectations:
                                    #Get posterior likelihoods
                                    expectations = baseExpectations[key]
                                    totalExpectation = sum(expectations.values())
                                    if totalExpectation > 0.0: #expectationCallingThreshold:
                                        posteriorProbs = calcBasePosteriorProbs(dict(zip(bases, map(lambda x : float(expectations[x])/totalExpectation, bases))), mutatedRefBase, 
                                                               evolutionarySubstitutionMatrix, errorSubstitutionMatrix)
                                        probs = [ posteriorProbs[base] for base in "ACGT" ]
                                        #posteriorProbs.pop(mutatedRefBase) #Remove the ref base.
                                        #maxPosteriorProb = max(posteriorProbs.values())
                                        #chosenBase = random.choice([ base for base in posteriorProbs if posteriorProbs[base] == maxPosteriorProb ]).upper() #Very naive way to call the base

                                        for chosenBase in "ACGT":
                                            if chosenBase != mutatedRefBase:
                                                maxPosteriorProb = posteriorProbs[chosenBase]
                                                if trueRefBase != mutatedRefBase and trueRefBase == chosenBase:
                                                    snpCalls.truePositives.append((maxPosteriorProb, refPosition)) #True positive
                                                else:
                                                    snpCalls.falsePositives.append((maxPosteriorProb, refPosition)) #False positive
                                                """
                                                    snpCalls.falseNegatives.append((refPosition, trueRefBase, mutatedRefBase, probs)) #False negative
                                                if trueRefBase != mutatedRefBase:
                                                    if trueRefBase == chosenBase:
                                                        snpCalls.truePositives.append((maxPosteriorProb, refPosition)) #True positive
                                                    else:
                                                        snpCalls.falseNegatives.append((refPosition, trueRefBase, mutatedRefBase, probs)) #False negative
                                                else:
                                                    snpCalls.falsePositives.append((maxPosteriorProb, refPosition)) #False positive
                                                """
                                else:
                                    snpCalls.notCalled += 1
                        
                    #Now find max-fscore point
                    
                    
                    for snpCalls, tagName in ((marginAlignMaxExpectedSnpCalls, "marginAlignMaxExpectedSnpCalls"), 
                                              (marginAlignMaxLikelihoodSnpCalls, "marginAlignMaxLikelihoodSnpCalls"),
                                              (maxFrequencySnpCalls, "maxFrequencySnpCalls"),
                                              (maximumLikelihoodSnpCalls, "maximumLikelihoodSnpCalls")):
                        recall = snpCalls.getRecallByProbability()
                        precision = snpCalls.getPrecisionByProbability()
                        assert len(recall) == len(precision)
                        fScore, pIndex = max(map(lambda i : (2 * recall[i] * precision[i] / (recall[i] + precision[i]) if recall[i] + precision[i] > 0 else 0.0, i), range(len(recall))))
                        truePositives = snpCalls.getRecallByProbability()[pIndex]
                        falsePositives = snpCalls.getPrecisionByProbability()[pIndex]
                        optimumProbThreshold = float(pIndex)/100.0
                        
                        #Write out the substitution info
                        node2 = ET.SubElement(node, tagName + "_" + hmmType, {  
                                "coverage":str(coverage),
                                "actualCoverage":str(float(totalAlignedPairs)/totalReferenceLength),
                                "totalAlignedPairs":str(totalAlignedPairs),
                                "totalReferenceLength":str(totalReferenceLength),
                                "replicate":str(replicate),
                                "totalReads":str(len(reads)),
                                "avgSampledReadLength":str(float(totalReadLength)/totalSampledReads),
                                "totalSampledReads":str(totalSampledReads),
                                
                                "totalHeldOut":str(totalHeldOut),
                                "totalNonHeldOut":str(totalNotHeldOut),
                                
                                "recall":str(recall[pIndex]),
                                "precision":str(precision[pIndex]),
                                "fScore":str(fScore),
                                "optimumProbThreshold":str(optimumProbThreshold),
                                "totalNoCalls":str(snpCalls.notCalled),

                                "recallByProbability":" ".join(map(str, snpCalls.getRecallByProbability())),
                                "precisionByProbability":" ".join(map(str, snpCalls.getPrecisionByProbability())) })
                                
                                #"falsePositiveLocations":" ".join(map(str, snpCalls.getFalsePositiveLocations())),
                                #"falseNegativeLocations":" ".join(map(str, snpCalls.getFalseNegativeLocations())),
                                #"truePositiveLocations":" ".join(map(str, snpCalls.getTruePositiveLocations())) })
                        for refPosition, trueRefBase, mutatedRefBase, posteriorProbs in snpCalls.falseNegatives:
                            ET.SubElement(node2, "falseNegative_%s_%s" % (trueRefBase, mutatedRefBase), { "posteriorProbs":" ".join(map(str, posteriorProbs))})
                        for falseNegativeBase in bases:
                            for mutatedBase in bases:
                                posteriorProbsArray = [ posteriorProbs for refPosition, trueRefBase, mutatedRefBase, posteriorProbs in snpCalls.falseNegatives if (trueRefBase.upper() == falseNegativeBase.upper() and mutatedBase.upper() == mutatedRefBase.upper() ) ]
                                if len(posteriorProbsArray) > 0:
                                    summedProbs = reduce(lambda x, y : map(lambda i : x[i] + y[i], xrange(len(x))), posteriorProbsArray)
                                    summedProbs = map(lambda x : float(x)/sum(summedProbs), summedProbs)
                                    ET.SubElement(node2, "combinedFalseNegative_%s_%s" % (falseNegativeBase, mutatedBase), { "posteriorProbs":" ".join(map(str, summedProbs))})
                        
        open(os.path.join(self.outputDir, "marginaliseConsensus.xml"), "w").write(prettyXml(node))
        
        
        #Indicate everything is all done
        self.finish()
Example #13
0
    def run(self):
        """Compare SNP-calling strategies across HMMs, read coverages and replicates.

        For every (hmmType, coverage, replicate) combination this:
          1. samples reads from ``self.samFile`` in a random order until the
             summed read length reaches ``coverage`` times the total reference
             length (or the reads run out);
          2. tallies, per reference position, the read bases aligned to it,
             both as raw counts from the SAM alignment and as posterior-weighted
             expectations from ``cactus_realign``;
          3. scores every non-reference base at every reference position as a
             candidate SNP under four calling strategies, judged against the
             held-out SNPs (see ``_loadHeldOutSnps``);
          4. adds one XML element per strategy holding the recall/precision
             curves by probability threshold and the max-F-score point.

        All elements are written to ``<outputDir>/marginaliseConsensus.xml``.
        """
        AbstractAnalysis.run(self)  #Call base method to do some logging
        refSequences = getFastaDictionary(self.referenceFastaFile)  #Hash of names to sequences
        readSequences = getFastqDictionary(self.readFastqFile)  #Hash of names to sequences

        #Trained hmm files, keyed by hmmType; other hmm types run cactus_realign without --loadHmm
        hmmDir = os.path.join(pathToBaseNanoporeDir(), "nanopore", "mappers")
        hmmFiles = {"trained_0": os.path.join(hmmDir, "blasr_hmm_0.txt"),
                    "trained_20": os.path.join(hmmDir, "blasr_hmm_20.txt"),
                    "trained_40": os.path.join(hmmDir, "blasr_hmm_40.txt")}

        #The held out snps and reference totals do not depend on hmmType, coverage or replicate
        snpSet = self._loadHeldOutSnps(refSequences)
        totalReferenceLength = sum(map(len, refSequences.values()))
        totalHeldOut = len(snpSet)
        totalNotHeldOut = totalReferenceLength - totalHeldOut

        class SnpCalls(object):
            """Candidate SNP calls for one strategy; each call is a (posteriorProb, refPosition) pair."""

            def __init__(self, totalHeldOut):
                self.totalHeldOut = totalHeldOut
                self.falsePositives = []
                self.truePositives = []
                self.falseNegatives = []  #Nothing in run() currently appends to this list
                self.notCalled = 0

            @staticmethod
            def bucket(calls):
                """Return 101 cumulative counts: entry i counts calls whose prob*100 rounds to >= i."""
                buckets = [0.0] * 101
                for prob in calls:  #Discretize
                    buckets[int(round(prob * 100))] += 1
                for i in xrange(len(buckets) - 2, -1, -1):  #Make cumulative
                    buckets[i] += buckets[i + 1]
                return buckets

            def getPrecisionByProbability(self):
                """Precision of the calls at or above each threshold i/100 (0 where there are no calls)."""
                tPs = self.bucket([prob for prob, _ in self.truePositives])
                fPs = self.bucket([prob for prob, _ in self.falsePositives])
                return [float(tP) / (tP + fP) if tP + fP != 0 else 0
                        for tP, fP in zip(tPs, fPs)]

            def getRecallByProbability(self):
                """Fraction of held-out snps called at or above each threshold i/100."""
                return [tP / self.totalHeldOut if self.totalHeldOut != 0 else 0
                        for tP in self.bucket([prob for prob, _ in self.truePositives])]

        node = ET.Element("marginAlignComparison")
        for hmmType in ("cactus", "trained_0", "trained_20", "trained_40"):
            #Extra cactus_realign argument selecting the trained hmm, if any
            loadHmmOption = (" --loadHmm=%s" % hmmFiles[hmmType]) if hmmType in hmmFiles else ""
            for coverage in (1000000, 120, 60, 30, 10):
                for replicate in xrange(3 if coverage < 1000000 else 1):  #Do replicates, unless coverage is all
                    #Get substitution matrices (built fresh for each run rather than shared between runs)
                    nullSubstitionMatrix = getNullSubstitutionMatrix()
                    flatSubstitutionMatrix = getJukesCantorTypeSubstitutionMatrix()
                    #NOTE(review): the error matrix always comes from the 20 hmm, whatever hmmType is - confirm intended
                    hmmErrorSubstitutionMatrix = loadHmmErrorSubstitutionMatrix(hmmFiles["trained_20"])

                    #Get a randomised ordering for the reads and collate the bases aligned to each position
                    sam = pysam.Samfile(self.samFile, "r")
                    try:
                        reads = [aR for aR in samIterator(sam)]
                        random.shuffle(reads)
                        (expectationsOfBasesAtEachPosition, frequenciesOfAlignedBasesAtEachPosition,
                         totalSampledReads, totalAlignedPairs, totalReadLength) = self._collateReadBases(
                            sam, reads, refSequences, readSequences, coverage,
                            totalReferenceLength, loadHmmOption)
                    finally:
                        sam.close()

                    #The different call sets
                    marginAlignMaxExpectedSnpCalls = SnpCalls(totalHeldOut)
                    marginAlignMaxLikelihoodSnpCalls = SnpCalls(totalHeldOut)
                    maxFrequencySnpCalls = SnpCalls(totalHeldOut)
                    maximumLikelihoodSnpCalls = SnpCalls(totalHeldOut)

                    #(error substitution matrix, per-position base tallies, call set) for each strategy;
                    #all strategies use nullSubstitionMatrix as the evolutionary substitution matrix
                    callingStrategies = (
                        (flatSubstitutionMatrix, expectationsOfBasesAtEachPosition, marginAlignMaxExpectedSnpCalls),
                        (hmmErrorSubstitutionMatrix, expectationsOfBasesAtEachPosition, marginAlignMaxLikelihoodSnpCalls),
                        (flatSubstitutionMatrix, frequenciesOfAlignedBasesAtEachPosition, maxFrequencySnpCalls),
                        (hmmErrorSubstitutionMatrix, frequenciesOfAlignedBasesAtEachPosition, maximumLikelihoodSnpCalls))

                    #Now calculate the calls
                    for refSeqName in refSequences:
                        refSeq = refSequences[refSeqName]
                        for refPosition in xrange(len(refSeq)):
                            key = (refSeqName, refPosition)
                            mutatedRefBase = refSeq[refPosition].upper()
                            trueRefBase = snpSet.get(key, mutatedRefBase).upper()

                            for errorSubstitutionMatrix, baseExpectations, snpCalls in callingStrategies:
                                if key not in baseExpectations:
                                    snpCalls.notCalled += 1
                                    continue
                                #Get posterior likelihoods
                                expectations = baseExpectations[key]
                                totalExpectation = sum(expectations.values())
                                if totalExpectation > 0.0:  #expectationCallingThreshold:
                                    normalisedExpectations = dict(
                                        (base, float(expectations[base]) / totalExpectation) for base in bases)
                                    posteriorProbs = calcBasePosteriorProbs(
                                        normalisedExpectations, mutatedRefBase,
                                        nullSubstitionMatrix, errorSubstitutionMatrix)
                                    #Every non-reference base is a candidate call, scored by its posterior;
                                    #it is a true positive only if it is the held-out true base
                                    for chosenBase in "ACGT":
                                        if chosenBase != mutatedRefBase:
                                            call = (posteriorProbs[chosenBase], refPosition)
                                            if trueRefBase == chosenBase:
                                                snpCalls.truePositives.append(call)  #True positive
                                            else:
                                                snpCalls.falsePositives.append(call)  #False positive

                    #Now find max-fscore point and write out the results for each strategy
                    for snpCalls, tagName in (
                            (marginAlignMaxExpectedSnpCalls, "marginAlignMaxExpectedSnpCalls"),
                            (marginAlignMaxLikelihoodSnpCalls, "marginAlignMaxLikelihoodSnpCalls"),
                            (maxFrequencySnpCalls, "maxFrequencySnpCalls"),
                            (maximumLikelihoodSnpCalls, "maximumLikelihoodSnpCalls")):
                        recall = snpCalls.getRecallByProbability()
                        precision = snpCalls.getPrecisionByProbability()
                        assert len(recall) == len(precision)
                        #Ties on fScore are broken in favour of the highest threshold index
                        fScore, pIndex = max(
                            (2 * r * p / (r + p) if r + p > 0 else 0.0, i)
                            for i, (r, p) in enumerate(zip(recall, precision)))
                        optimumProbThreshold = float(pIndex) / 100.0

                        #Write out the substitution info
                        node2 = ET.SubElement(node, tagName + "_" + hmmType, {
                            "coverage": str(coverage),
                            "actualCoverage": str(float(totalAlignedPairs) / totalReferenceLength),
                            "totalAlignedPairs": str(totalAlignedPairs),
                            "totalReferenceLength": str(totalReferenceLength),
                            "replicate": str(replicate),
                            "totalReads": str(len(reads)),
                            #Guard against a SAM file with no reads, where nothing is sampled
                            "avgSampledReadLength": str(float(totalReadLength) / totalSampledReads
                                                        if totalSampledReads > 0 else 0.0),
                            "totalSampledReads": str(totalSampledReads),
                            "totalHeldOut": str(totalHeldOut),
                            "totalNonHeldOut": str(totalNotHeldOut),
                            "recall": str(recall[pIndex]),
                            "precision": str(precision[pIndex]),
                            "fScore": str(fScore),
                            "optimumProbThreshold": str(optimumProbThreshold),
                            "totalNoCalls": str(snpCalls.notCalled),
                            "recallByProbability": " ".join(map(str, recall)),
                            "precisionByProbability": " ".join(map(str, precision))})

                        #NOTE: nothing above appends to falseNegatives, so these loops currently write nothing
                        for _, trueRefBase, mutatedRefBase, posteriorProbs in snpCalls.falseNegatives:
                            ET.SubElement(node2, "falseNegative_%s_%s" % (trueRefBase, mutatedRefBase),
                                          {"posteriorProbs": " ".join(map(str, posteriorProbs))})
                        for falseNegativeBase in bases:
                            for mutatedBase in bases:
                                posteriorProbsArray = [
                                    posteriorProbs
                                    for _, trueRefBase, mutatedRefBase, posteriorProbs in snpCalls.falseNegatives
                                    if (trueRefBase.upper() == falseNegativeBase.upper()
                                        and mutatedBase.upper() == mutatedRefBase.upper())]
                                if posteriorProbsArray:
                                    #Sum the probability vectors element-wise, then normalise to sum to one
                                    summedProbs = [sum(column) for column in zip(*posteriorProbsArray)]
                                    summedTotal = sum(summedProbs)
                                    summedProbs = [float(x) / summedTotal for x in summedProbs]
                                    ET.SubElement(node2, "combinedFalseNegative_%s_%s" % (falseNegativeBase, mutatedBase),
                                                  {"posteriorProbs": " ".join(map(str, summedProbs))})

        with open(os.path.join(self.outputDir, "marginaliseConsensus.xml"), "w") as outFile:
            outFile.write(prettyXml(node))

        #Indicate everything is all done
        self.finish()

    def _loadHeldOutSnps(self, refSequences):
        """Return a dict mapping (refSeqName, position) -> true base at each held-out snp.

        Reads ``<referenceFastaFile>_Index.txt`` if it exists. For every
        sequence ``name`` in *refSequences* that fasta must contain the true
        sequence under ``name`` and the mutated sequence, which must equal
        ``refSequences[name]``, under ``name + "_mutated"``. Positions where the
        two differ are the held-out snps. Returns an empty dict if the file
        does not exist; raises AssertionError if its contents do not match.
        """
        snpSet = {}
        referenceAlignmentFile = self.referenceFastaFile + "_Index.txt"
        if os.path.exists(referenceAlignmentFile):
            seqsAndMutatedSeqs = getFastaDictionary(referenceAlignmentFile)
            count = 0
            for name in seqsAndMutatedSeqs:
                if name in refSequences:
                    count += 1
                    trueSeq = seqsAndMutatedSeqs[name]
                    mutatedSeq = seqsAndMutatedSeqs[name + "_mutated"]
                    assert mutatedSeq == refSequences[name]
                    for i in xrange(len(trueSeq)):
                        if trueSeq[i] != mutatedSeq[i]:
                            snpSet[(name, i)] = trueSeq[i]
                else:
                    assert name.split("_")[-1] == "mutated"
            assert count == len(refSequences.keys())
        return snpSet

    def _collateReadBases(self, sam, reads, refSequences, readSequences, coverage,
                          totalReferenceLength, loadHmmOption):
        """Tally the read bases aligned to each reference position for a coverage-limited sample of reads.

        Reads are consumed in the given order until the summed read length
        reaches ``coverage`` times ``totalReferenceLength``. For each sampled
        read this counts the read bases the SAM alignment places on each
        reference position. It then runs ``cactus_realign`` on that alignment,
        with ``loadHmmOption`` appended to its arguments, and sums the
        posterior probability of each aligned read base.

        Returns ``(expectationsOfBasesAtEachPosition,
        frequenciesOfAlignedBasesAtEachPosition, totalSampledReads,
        totalAlignedPairs, totalReadLength)``. Both dicts map
        ``(refSeqName, refPosition)`` to a dict of base -> accumulated value.
        """
        expectationsOfBasesAtEachPosition = {}
        frequenciesOfAlignedBasesAtEachPosition = {}
        totalSampledReads = 0
        totalAlignedPairs = 0
        totalReadLength = 0

        for aR in reads:  #Iterate on the sam lines
            if totalReadLength / totalReferenceLength >= coverage:  #Stop when coverage exceeds the quota
                break
            totalReadLength += len(readSequences[aR.qname])
            totalSampledReads += 1

            #Temporary files
            tempCigarFile = os.path.join(self.getLocalTempDir(), "rescoredCigar.cig")
            tempRefFile = os.path.join(self.getLocalTempDir(), "ref.fa")
            tempReadFile = os.path.join(self.getLocalTempDir(), "read.fa")
            tempPosteriorProbsFile = os.path.join(self.getLocalTempDir(), "probs.tsv")

            #Ref name and sequence
            refSeqName = sam.getrname(aR.rname)
            refSeq = refSequences[refSeqName]

            #Walk through the aligned pairs to collate the bases of aligned positions
            for aP in AlignedPair.iterator(aR, refSeq, readSequences[aR.qname]):
                totalAlignedPairs += 1  #Record an aligned pair
                key = (refSeqName, aP.refPos)
                if key not in frequenciesOfAlignedBasesAtEachPosition:
                    frequenciesOfAlignedBasesAtEachPosition[key] = dict.fromkeys(bases, 0.0)
                readBase = aP.getReadBase()
                if readBase in bases:
                    frequenciesOfAlignedBasesAtEachPosition[key][readBase] += 1

            #Write the temporary files.
            readSeq = aR.query  #This excludes bases that were soft-clipped and is always of positive strand coordinates
            fastaWrite(tempRefFile, refSeqName, refSeq)
            fastaWrite(tempReadFile, aR.qname, readSeq)

            #Exonerate format Cigar string, which is in readSeq coordinates (positive strand).
            assert aR.pos == 0
            assert aR.qstart == 0
            assert aR.qend == len(readSeq)
            assert aR.aend == len(refSeq)
            cigarString = getExonerateCigarFormatString(aR, sam)

            #Call to cactus_realign
            system("echo %s | cactus_realign %s %s --diagonalExpansion=10 --splitMatrixBiggerThanThis=100 --outputAllPosteriorProbs=%s%s > %s" %
                   (cigarString, tempRefFile, tempReadFile, tempPosteriorProbsFile, loadHmmOption, tempCigarFile))

            #Now collate the reference position expectations; each line is "refPosition readPosition posteriorProb"
            with open(tempPosteriorProbsFile, 'r') as probsFile:
                for line in probsFile:
                    refPosition, readPosition, posteriorProb = map(float, line.split())
                    key = (refSeqName, int(refPosition))
                    if key not in expectationsOfBasesAtEachPosition:
                        expectationsOfBasesAtEachPosition[key] = dict.fromkeys(bases, 0.0)
                    readBase = readSeq[int(readPosition)].upper()
                    if readBase in bases:
                        expectationsOfBasesAtEachPosition[key][readBase] += posteriorProb

        return (expectationsOfBasesAtEachPosition, frequenciesOfAlignedBasesAtEachPosition,
                totalSampledReads, totalAlignedPairs, totalReadLength)
Example #14
0
    def run(self):
        """Write plots and tables summarising a trained HMM, if one exists.

        Looks for ``hmm.txt.xml`` in the directory that contains
        ``self.samFile``. When present, the following are written into
        ``self.outputDir``:

        * ``hmm.dot`` -- graphviz graph of the five HMM states; every
          ``transition`` whose ``avg`` is > 0 becomes an edge labelled
          ``"avg,std"`` (3 decimal places each).
        * ``matchEmissions.tsv`` and ``substitution_plot.pdf`` -- the
          ``avg`` of every state-'0' ``emission``, laid out as a base-by-base
          (x rows, y columns) table over ``ACGT``.
        * ``indelEmissions.tsv`` and ``indelEmissions_plot.pdf`` -- per-base
          sums of ``avg`` for state '2' (keyed on the ``x`` base) and state
          '1' (keyed on the ``y`` base).
        * ``runninglikelihoods.tsv`` and ``running_likelihood.pdf`` -- one
          row per ``hmm`` node, holding its ``runningLikelihoods`` values.

        The PDFs are produced by external ``Rscript`` invocations, each run
        only after its input file has been closed. ``self.finish()`` is
        called whether or not the HMM file exists.

        Raises:
            KeyError: if a match-state (x, y) pair over ``ACGT`` is missing,
                or an indel emission uses a base outside ``ACGT``.
        """
        #Call base method to do some logging
        AbstractAnalysis.run(self)

        hmmFile = os.path.join(os.path.dirname(self.samFile), "hmm.txt.xml")
        if os.path.exists(hmmFile):
            hmmsNode = ET.parse(hmmFile).getroot()
            bases = "ACGT"

            #Plot graphviz version of nanopore hmm, showing transitions and variances.
            #`with` guarantees the dot file is closed even if a helper raises.
            with open(os.path.join(self.outputDir, "hmm.dot"), 'w') as fH:
                setupGraphFile(fH)
                #Make states
                addNodeToGraph("n0n", fH, "match")
                addNodeToGraph("n1n", fH, "short delete")
                addNodeToGraph("n2n", fH, "short insert")
                addNodeToGraph("n3n", fH, "long insert")
                addNodeToGraph("n4n", fH, "long delete")

                #Make edges with labelled transition probs; zero-probability
                #transitions are left out of the graph.
                for transition in hmmsNode.findall("transition"):
                    avgProb = float(transition.attrib["avg"])
                    if avgProb > 0.0:
                        addEdgeToGraph("n%sn" % transition.attrib["from"],
                                       "n%sn" % transition.attrib["to"],
                                       fH,
                                       dir="arrow",
                                       style='""',
                                       label="%.3f,%.3f" %
                                       (avgProb,
                                        float(transition.attrib["std"])))

                finishGraphFile(fH)

            #Plot match emission data. Values are kept as the raw attribute
            #strings so they are written to the table exactly as stored.
            emissions = {(emission.attrib["x"], emission.attrib["y"]):
                         emission.attrib["avg"]
                         for emission in hmmsNode.findall("emission")
                         if emission.attrib["state"] == '0'}

            matchEmissionsFile = os.path.join(self.outputDir,
                                              "matchEmissions.tsv")
            with open(matchEmissionsFile, "w") as outf:
                #The header has one column fewer than the data rows; each row
                #starts with its own base label (presumably consumed by R as
                #row names -- confirm against substitution_plot.R).
                outf.write("\t".join(bases) + "\n")
                for base in bases:
                    row = [base] + [emissions[(base, other)] for other in bases]
                    outf.write("\t".join(row) + "\n")
            system("Rscript nanopore/analyses/substitution_plot.R %s %s %s" %
                   (matchEmissionsFile,
                    os.path.join(self.outputDir, "substitution_plot.pdf"),
                    "Per-Base Substitutions after HMM"))

            #Sum insert-state (by read base x) and delete-state (by reference
            #base y) emission probabilities per base.
            insertEmissions = {"A": 0.0, 'C': 0.0, 'G': 0.0, 'T': 0.0}
            deleteEmissions = {"A": 0.0, 'C': 0.0, 'G': 0.0, 'T': 0.0}
            for emission in hmmsNode.findall("emission"):
                if emission.attrib["state"] == '2':
                    insertEmissions[emission.attrib["x"]] += float(
                        emission.attrib["avg"])
                elif emission.attrib["state"] == '1':
                    deleteEmissions[emission.attrib["y"]] += float(
                        emission.attrib["avg"])

            #Plot insert and delete emissions: header row, insert row, delete row.
            indelEmissionsFile = os.path.join(self.outputDir,
                                              "indelEmissions.tsv")
            with open(indelEmissionsFile, "w") as outf:
                outf.write("\t".join(bases) + "\n")
                outf.write(
                    "\t".join(str(insertEmissions[b]) for b in bases) + "\n")
                outf.write(
                    "\t".join(str(deleteEmissions[b]) for b in bases) + "\n")
            system("Rscript nanopore/analyses/emissions_plot.R {} {}".format(
                indelEmissionsFile,
                os.path.join(self.outputDir, "indelEmissions_plot.pdf")))

            #Plot convergence of likelihoods
            runningLikelihoodsFile = os.path.join(self.outputDir,
                                                  "runninglikelihoods.tsv")
            with open(runningLikelihoodsFile, "w") as outf:
                #This is a loop over trials; each row lists that trial's
                #likelihoods ordered from the first iteration to the last.
                for hmmNode in hmmsNode.findall("hmm"):
                    #Round-trip through float so values are written in
                    #Python's canonical float formatting.
                    runningLikelihoods = [
                        float(value) for value in
                        hmmNode.attrib["runningLikelihoods"].split()
                    ]
                    outf.write("\t".join(map(str, runningLikelihoods)) + "\n")
            system(
                "Rscript nanopore/analyses/running_likelihood.R {} {}".format(
                    runningLikelihoodsFile,
                    os.path.join(self.outputDir, "running_likelihood.pdf")))

        self.finish()  #Indicates the batch is done