Exemplo n.º 1
0
    def new_sample(self, bytes, independent_sites, blur_cap=100):
        """Draw a new posterior sample of site calls and peak/call counts.

        For every site, the per-observation call distributions obtained from
        the current posterior are multiplied together with ``self.gdist``,
        normalized, and used to sample one call for the site.  That call is
        credited to every observation at the site (per strand), and the
        per-site calls are tallied to resample the call distribution from
        ``hyperprior``.

        Args:
            bytes: sequence of sites; each site is a sequence of
                ``(byte, strand)`` pairs where ``strand`` is ``'+'`` or
                ``'-'``.  (The name shadows the builtin; it is kept for
                caller compatibility.)
            independent_sites: not used by this method; kept for interface
                compatibility.
            blur_cap: maximum total observation weight kept per strand after
                the blurring step.  Defaults to 100, the previously
                hard-coded value.

        Returns:
            ``Sample(gdist, counts, site_dists)`` where ``counts`` maps each
            strand to its (possibly rescaled) count tables and ``site_dists``
            holds the normalized call distribution of each site, in order.
        """
        # Posterior sample given the previous peak/call counts (or zero
        # counts, if this is the initial sample).
        peak_posterior = self.posterior()
        # Fresh per-strand peak/call count tables for this sample.
        counts = self.zero_counts()
        # Tally of imputed calls, indexed through callidxs.
        call_counts = scipy.zeros(10)
        site_dists = []
        total_observations = dict((s, 0) for s in '+-')
        for site_idx, cbytes in enumerate(bytes):
            # Posterior call distribution given each observed peak at this
            # site; call_dist also returns the peak it matched.
            peaks, dists = zip(*[peak_posterior[s].call_dist(b, s)
                                 for b, s in cbytes])
            # Combine the prior with every observation's evidence, then
            # normalize and sample a call for the site.
            site_dist = self.gdist * scipy.product(dists, axis=0)
            mass = sum(site_dist)
            assert mass, 'site %d has zero posterior mass' % site_idx
            site_dist = site_dist / mass
            site_dists.append(site_dist)
            call = SNPPrior.sample_call(site_dist)
            # Observations on the reverse strand are counted against the
            # reverse-complemented call (via dict_revcalls).
            strandcall = {'+': call, '-': dict_revcalls[call]}
            for peak, (_, strand) in zip(peaks, cbytes):
                ccall = strandcall[strand]
                ccounts = counts[strand]
                # Flat peak/call count.
                ccounts[ccall][SNPPrior.peakidxs[peak.peaks]] += 1
                # Call -> (peak, score bin) count.
                ccounts['%s-%s' % (ccall, peak.peaks)][peak.bin] += 1
                total_observations[strand] += 1
            # Count of imputed calls.
            call_counts[callidxs[call]] += 1
        # Blur the conditional distribution to speed convergence: rescale
        # each strand's tables so their total weight is at most blur_cap.
        # Strands at or under the cap are left as-is.  This also covers a
        # strand with zero observations, where dividing by the total would
        # compute 0/0 and turn all of that strand's counts into NaN.
        for strand, ccounts in counts.items():
            total = total_observations[strand]
            if total <= blur_cap:
                continue
            # float() keeps this a true division under Python 2 as well.
            scale = float(blur_cap) / total
            for val in ccounts.values():
                val *= scale
        gdist = hyperprior.posterior(call_counts).sample()

        return Sample(gdist, counts, site_dists)
Exemplo n.º 2
0
 def zero_counts(cls):
     """Return a mapping from each strand ('+', '-') to fresh SNPPrior counts.

     SNPPrior.counts() is called once per strand, '+' first, so each strand
     gets its own independent count object.
     """
     fresh = {}
     for strand in ('+', '-'):
         fresh[strand] = SNPPrior.counts()
     return fresh