Example #1
0
def read_bam(inbam,
             filter_exclude,
             resolution,
             min_count=2500,
             normalization='Vanilla',
             mappability=None,
             n_rsites=None,
             cg_content=None,
             sigma=2,
             ncpus=8,
             factor=1,
             outdir='.',
             extra_out='',
             only_valid=False,
             normalize_only=False,
             max_njobs=100,
             min_perc=None,
             max_perc=None,
             extra_bads=None):
    bamfile = AlignmentFile(inbam, 'rb')
    sections = OrderedDict(
        zip(bamfile.references, [x / resolution + 1 for x in bamfile.lengths]))
    total = 0
    section_pos = dict()
    for crm in sections:
        section_pos[crm] = (total, total + sections[crm])
        total += sections[crm]
    bins = []
    for crm in sections:
        len_crm = sections[crm]
        bins.extend([(crm, i) for i in xrange(len_crm)])

    start_bin = 0
    end_bin = len(bins)
    total = len(bins)

    regs = []
    begs = []
    ends = []
    njobs = min(total, max_njobs) + 1
    nbins = total / njobs + 1
    for i in range(start_bin, end_bin, nbins):
        if i + nbins > end_bin:  # make sure that we stop
            nbins = end_bin - i
        try:
            (crm1, beg1), (crm2, end2) = bins[i], bins[i + nbins - 1]
        except IndexError:
            try:
                (crm1, beg1), (crm2, end2) = bins[i], bins[-1]
            except IndexError:
                break
        if crm1 != crm2:
            end1 = sections[crm1]
            beg2 = 0
            regs.append(crm1)
            regs.append(crm2)
            begs.append(beg1 * resolution)
            begs.append(beg2 * resolution)
            ends.append(end1 * resolution + resolution)  # last nt included
            ends.append(end2 * resolution + resolution -
                        1)  # last nt not included (overlap with next window)
        else:
            regs.append(crm1)
            begs.append(beg1 * resolution)
            ends.append(end2 * resolution + resolution - 1)
    ends[-1] += 1  # last nucleotide included

    # print '\n'.join(['%s %d %d' % (a, b, c) for a, b, c in zip(regs, begs, ends)])
    printime('  - Parsing BAM (%d chunks)' % (len(regs)))
    bins_dict = dict([(j, i) for i, j in enumerate(bins)])
    pool = mu.Pool(ncpus)
    procs = []
    read_bam_frag = read_bam_frag_valid if only_valid else read_bam_frag_filter
    for i, (region, start, end) in enumerate(zip(regs, begs, ends)):
        procs.append(
            pool.apply_async(read_bam_frag,
                             args=(
                                 inbam,
                                 filter_exclude,
                                 bins,
                                 bins_dict,
                                 resolution,
                                 outdir,
                                 extra_out,
                                 region,
                                 start,
                                 end,
                             )))
    pool.close()
    print_progress(procs)
    pool.join()
    ## COLLECT RESULTS
    cisprc = {}
    printime('  - Collecting cis and total interactions per bin (%d chunks)' %
             (len(regs)))
    stdout.write('     ')
    for countbin, (region, start, end) in enumerate(zip(regs, begs, ends)):
        if not countbin % 10 and countbin:
            stdout.write(' ')
        if not countbin % 50 and countbin:
            stdout.write(' %9s\n     ' % ('%s/%s' % (countbin, len(regs))))
        stdout.write('.')
        stdout.flush()

        fname = path.join(
            outdir,
            'tmp_bins_%s:%d-%d_%s.pickle' % (region, start, end, extra_out))
        tmp_cisprc = load(open(fname))
        system('rm -f %s' % fname)
        cisprc.update(tmp_cisprc)
    stdout.write('\n')

    printime('  - Removing columns with too few or too much interactions')
    if len(bamfile.references) == 1 and min_count is None:
        raise Exception("ERROR: only one chromosome can't filter by "
                        "cis-percentage, set min_count instead")
    elif min_count is None and len(bamfile.references) > 1:
        badcol = filter_by_cis_percentage(
            cisprc,
            sigma=sigma,
            verbose=True,
            min_perc=min_perc,
            max_perc=max_perc,
            size=total,
            savefig=path.join(
                outdir, 'filtered_bins_%s_%s.png' %
                (nicer(resolution).replace(' ', ''), extra_out)))
    else:
        print(
            '      -> too few interactions defined as less than %9d '
            'interactions') % (min_count)
        badcol = {}
        countL = 0
        countZ = 0
        for c in xrange(total):
            if cisprc.get(c, [0, 0])[1] < min_count:
                badcol[c] = cisprc.get(c, [0, 0])[1]
                countL += 1
                if not c in cisprc:
                    countZ += 1
        print '      -> removed %d columns (%d/%d null/high counts) of %d (%.1f%%)' % (
            len(badcol), countZ, countL, total,
            float(len(badcol)) / total * 100)

    # no mappability will result in NaNs, better to filter out these columns
    if mappability:
        badcol.update((i, True) for i, m in enumerate(mappability) if not m)

    # add manually columns to bad columns
    if extra_bads:
        removed_manually = 0
        for ebc in extra_bads:
            c, ebc = ebc.split(':')
            b, e = map(int, ebc.split('-'))
            b = b / resolution + section_pos[c][0]
            e = e / resolution + section_pos[c][0]
            removed_manually += (e - b)
            badcol.update(dict((p, 'manual') for p in xrange(b, e)))
        printime('  - Removed %d columns manually.' % removed_manually)
    raw_cisprc = sum(
        float(cisprc[k][0]) / cisprc[k][1]
        for k in cisprc if not k in badcol) / (len(cisprc) - len(badcol))

    printime('  - Rescaling sum of interactions per bins')
    size = len(bins)
    biases = [
        float('nan') if k in badcol else cisprc.get(k, [0, 1.])[1]
        for k in xrange(size)
    ]

    if normalization == 'Vanilla':
        printime('  - Vanilla normalization')
        mean_col = nanmean(biases)
        biases = dict(
            (k, b / mean_col * mean_col**0.5) for k, b in enumerate(biases))
    elif normalization == 'oneD':
        printime('  - oneD normalization')
        if len(
                set([
                    len(biases),
                    len(mappability),
                    len(n_rsites),
                    len(cg_content)
                ])) > 1:
            print "biases", "mappability", "n_rsites", "cg_content"
            print len(biases), len(mappability), len(n_rsites), len(cg_content)
            raise Exception('Error: not all arrays have the same size')
        tmp_oneD = path.join(outdir, 'tmp_oneD_%s' % (extra_out))
        mkdir(tmp_oneD)
        biases = oneD(tmp_dir=tmp_oneD,
                      tot=biases,
                      map=mappability,
                      res=n_rsites,
                      cg=cg_content)
        biases = dict((k, b) for k, b in enumerate(biases))
        rmtree(tmp_oneD)
    else:
        raise NotImplementedError('ERROR: method %s not implemented' %
                                  normalization)

    # collect subset-matrices and write genomic one
    # out = open(os.path.join(outdir,
    #                         'hicdata_%s.abc' % (nicer(resolution).replace(' ', ''))), 'w')
    printime('  - Getting sum of normalized bins')
    pool = mu.Pool(ncpus)
    procs = []
    for i, (region, start, end) in enumerate(zip(regs, begs, ends)):
        fname = path.join(
            outdir, 'tmp_%s:%d-%d_%s.pickle' % (region, start, end, extra_out))
        procs.append(pool.apply_async(sum_nrm_matrix, args=(
            fname,
            biases,
        )))
    pool.close()
    print_progress(procs)
    pool.join()

    # to correct biases
    sumnrm = sum(p.get() for p in procs)

    target = (sumnrm / float(size * size * factor))**0.5
    biases = dict([(b, biases[b] * target) for b in biases])

    if not normalize_only:
        printime('  - Computing Cis percentage')
        # Calculate Cis percentage

        pool = mu.Pool(ncpus)
        procs = []
        for i, (region, start, end) in enumerate(zip(regs, begs, ends)):
            fname = path.join(
                outdir,
                'tmp_%s:%d-%d_%s.pickle' % (region, start, end, extra_out))
            procs.append(
                pool.apply_async(get_cis_perc,
                                 args=(fname, biases, badcol, bins)))
        pool.close()
        print_progress(procs)
        pool.join()

        # collect results
        cis = total = 0
        for proc in procs:
            c, t = proc.get()
            cis += c
            total += t
        norm_cisprc = float(cis) / total
        print '    * Cis-percentage: %.1f%%' % (norm_cisprc * 100)
    else:
        norm_cisprc = 0.

    printime('  - Rescaling decay')
    # normalize decay by size of the diagonal, and by Vanilla correction
    # (all cells must still be equals to 1 in average)

    pool = mu.Pool(ncpus)
    procs = []
    for i, (region, start, end) in enumerate(zip(regs, begs, ends)):
        fname = path.join(
            outdir, 'tmp_%s:%d-%d_%s.pickle' % (region, start, end, extra_out))
        procs.append(
            pool.apply_async(sum_dec_matrix,
                             args=(fname, biases, badcol, bins)))
    pool.close()
    print_progress(procs)
    pool.join()

    # collect results
    nrmdec = {}
    rawdec = {}
    for proc in procs:
        tmpnrm, tmpraw = proc.get()
        for c, d in tmpnrm.iteritems():
            for k, v in d.iteritems():
                try:
                    nrmdec[c][k] += v
                    rawdec[c][k] += tmpraw[c][k]
                except KeyError:
                    try:
                        nrmdec[c][k] = v
                        rawdec[c][k] = tmpraw[c][k]
                    except KeyError:
                        nrmdec[c] = {k: v}
                        rawdec[c] = {k: tmpraw[c][k]}
    # count the number of cells per diagonal
    # TODO: parallelize
    # find largest chromosome
    len_crms = dict(
        (c, section_pos[c][1] - section_pos[c][0]) for c in section_pos)
    # initialize dictionary
    ndiags = dict(
        (c, dict((k, 0) for k in xrange(len_crms[c]))) for c in sections)
    for crm in section_pos:
        beg_chr, end_chr = section_pos[crm][0], section_pos[crm][1]
        chr_size = end_chr - beg_chr
        thesebads = [b for b in badcol if beg_chr <= b <= end_chr]
        for dist in xrange(1, chr_size):
            ndiags[crm][dist] += chr_size - dist
            # from this we remove bad columns
            # bad columns will only affect if they are at least as distant from
            # a border as the distance between the longest diagonal and the
            # current diagonal.
            bad_diag = set(
            )  # 2 bad rows can point to the same bad cell in diagonal
            maxp = end_chr - dist
            minp = beg_chr + dist
            for b in thesebads:
                if b < maxp:  # not inclusive!!
                    bad_diag.add(b)
                if b >= minp:
                    bad_diag.add(b - dist)
            ndiags[crm][dist] -= len(bad_diag)
        # different behavior for longest diagonal:
        ndiags[crm][0] += chr_size - sum(beg_chr <= b < end_chr
                                         for b in thesebads)

    # normalize sum per diagonal by total number of cells in diagonal
    signal_to_noise = 0.05
    min_n = signal_to_noise**-2.  # equals 400 when default
    for crm in sections:
        if not crm in nrmdec:
            nrmdec[crm] = {}
            rawdec[crm] = {}
        tmpdec = 0  # store count by diagonal
        tmpsum = 0  # store count by diagonal
        ndiag = 0
        val = 0
        previous = [
        ]  # store diagonals to be summed in case not reaching the minimum
        for k in ndiags[crm]:
            tmpdec += nrmdec[crm].get(k, 0.)
            tmpsum += rawdec[crm].get(k, 0.)
            previous.append(k)
            if tmpsum > min_n:
                ndiag = sum(ndiags[crm][k] for k in previous)
                val = tmpdec  # backup of tmpdec kept for last ones outside the loop
                try:
                    ratio = val / ndiag
                    for k in previous:
                        nrmdec[crm][k] = ratio
                except ZeroDivisionError:  # all columns at this distance are "bad"
                    pass
                previous = []
                tmpdec = 0
                tmpsum = 0
        # last ones we average with previous result
        if len(previous) == len(ndiags[crm]):
            nrmdec[crm] = {}
        elif tmpsum < min_n:
            ndiag += sum(ndiags[crm][k] for k in previous)
            val += tmpdec
            try:
                ratio = val / ndiag
                for k in previous:
                    nrmdec[crm][k] = ratio
            except ZeroDivisionError:  # all columns at this distance are "bad"
                pass
    return biases, nrmdec, badcol, raw_cisprc, norm_cisprc
Example #2
0
def read_bam(inbam, filter_exclude, resolution, min_count=2500, biases_path='',
             normalization='Vanilla', mappability=None, n_rsites=None,
             cg_content=None, sigma=2, ncpus=8, factor=1, outdir='.', seed=1,
             extra_out='', only_valid=False, normalize_only=False, p_fit=None,
             max_njobs=100, min_perc=None, max_perc=None, extra_bads=None):
    bamfile = AlignmentFile(inbam, 'rb')
    sections = OrderedDict(zip(bamfile.references,
                               [x / resolution + 1 for x in bamfile.lengths]))
    total = 0
    section_pos = dict()
    for crm in sections:
        section_pos[crm] = (total, total + sections[crm])
        total += sections[crm]
    bins = []
    for crm in sections:
        len_crm = sections[crm]
        bins.extend([(crm, i) for i in xrange(len_crm)])

    start_bin = 0
    end_bin   = len(bins)
    total     = len(bins)

    regs = []
    begs = []
    ends = []
    njobs = min(total, max_njobs) + 1
    nbins = total / njobs + 1
    for i in range(start_bin, end_bin, nbins):
        if i + nbins > end_bin:  # make sure that we stop
            nbins = end_bin - i
        try:
            (crm1, beg1), (crm2, end2) = bins[i], bins[i + nbins - 1]
        except IndexError:
            try:
                (crm1, beg1), (crm2, end2) = bins[i], bins[-1]
            except IndexError:
                break
        if crm1 != crm2:
            end1 = sections[crm1]
            beg2 = 0
            regs.append(crm1)
            regs.append(crm2)
            begs.append(beg1 * resolution)
            begs.append(beg2 * resolution)
            ends.append(end1 * resolution + resolution)  # last nt included
            ends.append(end2 * resolution + resolution - 1)  # last nt not included (overlap with next window)
        else:
            regs.append(crm1)
            begs.append(beg1 * resolution)
            ends.append(end2 * resolution + resolution - 1)
    ends[-1] += 1  # last nucleotide included

    # print '\n'.join(['%s %d %d' % (a, b, c) for a, b, c in zip(regs, begs, ends)])
    printime('  - Parsing BAM (%d chunks)' % (len(regs)))
    bins_dict = dict([(j, i) for i, j in enumerate(bins)])
    pool = mu.Pool(ncpus)
    procs = []
    read_bam_frag = read_bam_frag_valid if only_valid else read_bam_frag_filter
    for i, (region, start, end) in enumerate(zip(regs, begs, ends)):
        procs.append(pool.apply_async(
            read_bam_frag, args=(inbam, filter_exclude, bins, bins_dict,
                                 resolution, outdir, extra_out,
                                 region, start, end,)))
    pool.close()
    print_progress(procs)
    pool.join()
    ## COLLECT RESULTS
    cisprc = {}
    printime('  - Collecting cis and total interactions per bin (%d chunks)' % (len(regs)))
    stdout.write('     ')
    for countbin, (region, start, end) in enumerate(zip(regs, begs, ends)):
        if not countbin % 10 and countbin:
            stdout.write(' ')
        if not countbin % 50 and countbin:
            stdout.write(' %9s\n     ' % ('%s/%s' % (countbin , len(regs))))
        stdout.write('.')
        stdout.flush()

        fname = path.join(outdir,
                          'tmp_bins_%s:%d-%d_%s.pickle' % (region, start, end, extra_out))
        tmp_cisprc = load(open(fname))
        system('rm -f %s' % fname)
        cisprc.update(tmp_cisprc)
    stdout.write('\n')

    printime('  - Removing columns with too few or too much interactions')
    if len(bamfile.references) == 1 and min_count is None:
        raise Exception("ERROR: only one chromosome can't filter by "
                        "cis-percentage, set min_count instead")
    elif min_count is None and len(bamfile.references) > 1:
        badcol = filter_by_cis_percentage(
            cisprc, sigma=sigma, verbose=True, min_perc=min_perc, max_perc=max_perc,
            size=total, savefig=None)
    else:
        print ('      -> too few interactions defined as less than %9d '
               'interactions') % (min_count)
        badcol = {}
        countL = 0
        countZ = 0
        for c in xrange(total):
            if cisprc.get(c, [0, 0])[1] < min_count:
                badcol[c] = cisprc.get(c, [0, 0])[1]
                countL += 1
                if not c in cisprc:
                    countZ += 1
        print '      -> removed %d columns (%d/%d null/high counts) of %d (%.1f%%)' % (
            len(badcol), countZ, countL, total, float(len(badcol)) / total * 100)

    # no mappability will result in NaNs, better to filter out these columns
    if mappability:
        badcol.update((i, True) for i, m in enumerate(mappability) if not m)

    # add manually columns to bad columns
    if extra_bads:
        removed_manually = 0
        for ebc in extra_bads:
            c, ebc = ebc.split(':')
            b, e = map(int, ebc.split('-'))
            b = b / resolution + section_pos[c][0]
            e = e / resolution + section_pos[c][0]
            removed_manually += (e - b)
            badcol.update(dict((p, 'manual') for p in xrange(b, e)))
        printime('  - Removed %d columns manually.' % removed_manually)
    raw_cisprc = sum(float(cisprc[k][0]) / cisprc[k][1]
                     for k in cisprc if not k in badcol) / (len(cisprc) - len(badcol))

    printime('  - Rescaling sum of interactions per bins')
    size = len(bins)
    biases = [float('nan') if k in badcol else cisprc.get(k, [0, 1.])[1]
              for k in xrange(size)]

    if normalization == 'ICE':
        printime('  - ICE normalization')
        hic_data = load_hic_data_from_bam(
            inbam, resolution, filter_exclude=filter_exclude,
            tmpdir=outdir, ncpus=ncpus)
        hic_data.bads = badcol
        hic_data.normalize_hic(iterations=100, max_dev=0.000001)
        biases = hic_data.bias.copy()
        del(hic_data)
    elif normalization == 'Vanilla':
        printime('  - Vanilla normalization')
        mean_col = nanmean(biases)
        biases   = dict((k, b / mean_col * mean_col**0.5)
                        for k, b in enumerate(biases))
    elif normalization == 'SQRT':
        printime('  - Vanilla-SQRT normalization')
        biases = [b**0.5 for b in biases]
        mean_col = nanmean(biases)
        biases   = dict((k, b / mean_col * mean_col**0.5)
                        for k, b in enumerate(biases))
    elif normalization == 'oneD':
        printime('  - oneD normalization')
        if len(set([len(biases), len(mappability), len(n_rsites), len(cg_content)])) > 1:
            print "biases", "mappability", "n_rsites", "cg_content"
            print len(biases), len(mappability), len(n_rsites), len(cg_content)
            raise Exception('Error: not all arrays have the same size')
        tmp_oneD = path.join(outdir,'tmp_oneD_%s' % (extra_out))
        mkdir(tmp_oneD)
        biases = oneD(tmp_dir=tmp_oneD, p_fit=p_fit, tot=biases, map=mappability,
                      res=n_rsites, cg=cg_content, seed=seed)
        biases = dict((k, b) for k, b in enumerate(biases))
        rmtree(tmp_oneD)
    elif normalization == 'custom':
        n_pos = 0
        biases = {}
        print 'Using provided biases...'
        with open(biases_path, 'r') as r:
            r.next()
            for line in r:
                if line[0] == 'N':
                    #b = float('nan')
                    badcol[n_pos] = 0
                    biases[n_pos] = float('nan')
                else:
                    b = float(line)
                    if b == 0:
                        badcol[n_pos] = 0
                        biases[n_pos] = float('nan')
                    else:
                        biases[n_pos] = b
                n_pos += 1
        for add in range(max(biases.keys()), total + 1):
            biases[add] = float('nan')
    else:
        raise NotImplementedError('ERROR: method %s not implemented' %
                                  normalization)

    # collect subset-matrices and write genomic one
    # out = open(os.path.join(outdir,
    #                         'hicdata_%s.abc' % (nicer(resolution).replace(' ', ''))), 'w')
    printime('  - Getting sum of normalized bins')
    pool = mu.Pool(ncpus)
    procs = []
    for i, (region, start, end) in enumerate(zip(regs, begs, ends)):
        fname = path.join(outdir,
                          'tmp_%s:%d-%d_%s.pickle' % (region, start, end, extra_out))
        procs.append(pool.apply_async(sum_nrm_matrix,
                                      args=(fname, biases,)))
    pool.close()
    print_progress(procs)
    pool.join()

    # to correct biases
    sumnrm = sum(p.get() for p in procs)

    target = (sumnrm / float(size * size * factor))**0.5
    biases = dict([(b, biases[b] * target) for b in biases])

    if not normalize_only:
        printime('  - Computing Cis percentage')
        # Calculate Cis percentage

        pool = mu.Pool(ncpus)
        procs = []
        for i, (region, start, end) in enumerate(zip(regs, begs, ends)):
            fname = path.join(outdir,
                              'tmp_%s:%d-%d_%s.pickle' % (region, start, end, extra_out))
            procs.append(pool.apply_async(get_cis_perc,
                                          args=(fname, biases, badcol, bins)))
        pool.close()
        print_progress(procs)
        pool.join()

        # collect results
        cis = total = 0
        for proc in procs:
            c, t = proc.get()
            cis += c
            total += t
        norm_cisprc = float(cis) / total
        print '    * Cis-percentage: %.1f%%' % (norm_cisprc * 100)
    else:
        norm_cisprc = 0.

    printime('  - Rescaling decay')
    # normalize decay by size of the diagonal, and by Vanilla correction
    # (all cells must still be equals to 1 in average)

    pool = mu.Pool(ncpus)
    procs = []
    for i, (region, start, end) in enumerate(zip(regs, begs, ends)):
        fname = path.join(outdir,
                          'tmp_%s:%d-%d_%s.pickle' % (region, start, end, extra_out))
        procs.append(pool.apply_async(sum_dec_matrix,
                                      args=(fname, biases, badcol, bins)))
    pool.close()
    print_progress(procs)
    pool.join()

    # collect results
    nrmdec = {}
    rawdec = {}
    for proc in procs:
        tmpnrm, tmpraw = proc.get()
        for c, d in tmpnrm.iteritems():
            for k, v in d.iteritems():
                try:
                    nrmdec[c][k] += v
                    rawdec[c][k] += tmpraw[c][k]
                except KeyError:
                    try:
                        nrmdec[c][k]  = v
                        rawdec[c][k] = tmpraw[c][k]
                    except KeyError:
                        nrmdec[c] = {k: v}
                        rawdec[c] = {k: tmpraw[c][k]}
    # count the number of cells per diagonal
    # TODO: parallelize
    # find largest chromosome
    len_crms = dict((c, section_pos[c][1] - section_pos[c][0]) for c in section_pos)
    # initialize dictionary
    ndiags = dict((c, dict((k, 0) for k in xrange(len_crms[c]))) for c in sections)
    for crm in section_pos:
        beg_chr, end_chr = section_pos[crm][0], section_pos[crm][1]
        chr_size = end_chr - beg_chr
        thesebads = [b for b in badcol if beg_chr <= b <= end_chr]
        for dist in xrange(1, chr_size):
            ndiags[crm][dist] += chr_size - dist
            # from this we remove bad columns
            # bad columns will only affect if they are at least as distant from
            # a border as the distance between the longest diagonal and the
            # current diagonal.
            bad_diag = set()  # 2 bad rows can point to the same bad cell in diagonal
            maxp = end_chr - dist
            minp = beg_chr + dist
            for b in thesebads:
                if b < maxp:  # not inclusive!!
                    bad_diag.add(b)
                if b >= minp:
                    bad_diag.add(b - dist)
            ndiags[crm][dist] -= len(bad_diag)
        # different behavior for longest diagonal:
        ndiags[crm][0] += chr_size - sum(beg_chr <= b < end_chr for b in thesebads)

    # normalize sum per diagonal by total number of cells in diagonal
    signal_to_noise = 0.05
    min_n = signal_to_noise ** -2. # equals 400 when default
    for crm in sections:
        if not crm in nrmdec:
            nrmdec[crm] = {}
            rawdec[crm] = {}
        tmpdec = 0  # store count by diagonal
        tmpsum = 0  # store count by diagonal
        ndiag  = 0
        val    = 0
        previous = [] # store diagonals to be summed in case not reaching the minimum
        for k in ndiags[crm]:
            tmpdec += nrmdec[crm].get(k, 0.)
            tmpsum += rawdec[crm].get(k, 0.)
            previous.append(k)
            if tmpsum > min_n:
                ndiag = sum(ndiags[crm][k] for k in previous)
                val = tmpdec  # backup of tmpdec kept for last ones outside the loop
                try:
                    ratio = val / ndiag
                    for l in previous:
                        nrmdec[crm][l] = ratio
                except ZeroDivisionError:  # all columns at this distance are "bad"
                    pass
                previous = []
                tmpdec = 0
                tmpsum = 0
        # last ones we average with previous result
        if  len(previous) == len(ndiags[crm]):
            nrmdec[crm] = {}
        elif tmpsum < min_n:
            ndiag += sum(ndiags[crm][k] for k in previous)
            val += tmpdec
            try:
                ratio = val / ndiag
                for k in previous:
                    nrmdec[crm][k] = ratio
            except ZeroDivisionError:  # all columns at this distance are "bad"
                pass
    return biases, nrmdec, badcol, raw_cisprc, norm_cisprc
def read_bam(inbam,
             filter_exclude,
             resolution,
             min_count=2500,
             sigma=2,
             ncpus=8,
             factor=1,
             outdir='.',
             check_sum=False):
    bamfile = AlignmentFile(inbam, 'rb')
    sections = OrderedDict(
        zip(bamfile.references, [x / resolution + 1 for x in bamfile.lengths]))
    total = 0
    section_pos = dict()
    for crm in sections:
        section_pos[crm] = (total, total + sections[crm])
        total += sections[crm] + 1
    bins = []
    for crm in sections:
        len_crm = sections[crm]
        bins.extend([(crm, i) for i in xrange(len_crm + 1)])

    start_bin = 0
    end_bin = len(bins) + 1
    total = len(bins)

    total = end_bin - start_bin + 1
    regs = []
    begs = []
    ends = []
    njobs = min(total, 100) + 1
    nbins = total / njobs + 1
    for i in range(start_bin, end_bin, nbins):
        if i + nbins > end_bin:  # make sure that we stop at the right place
            nbins = end_bin - i
        try:
            (crm1, beg1), (crm2, end2) = bins[i], bins[i + nbins - 1]
        except IndexError:
            (crm1, beg1), (crm2, end2) = bins[i], bins[-1]
        if crm1 != crm2:
            end1 = sections[crm1]
            beg2 = 0
            regs.append(crm1)
            regs.append(crm2)
            begs.append(beg1 * resolution)
            begs.append(beg2 * resolution)
            ends.append(end1 * resolution + resolution)  # last nt included
            ends.append(end2 * resolution + resolution -
                        1)  # last nt not included (overlap with next window)
        else:
            regs.append(crm1)
            begs.append(beg1 * resolution)
            ends.append(end2 * resolution + resolution - 1)
    ends[-1] += 1  # last nucleotide included

    # print '\n'.join(['%s %d %d' % (a, b, c) for a, b, c in zip(regs, begs, ends)])
    printime('\n  - Parsing BAM (%d chunks)' % (len(regs)))
    bins_dict = dict([(j, i) for i, j in enumerate(bins)])
    pool = mu.Pool(ncpus)
    procs = []
    for i, (region, start, end) in enumerate(zip(regs, begs, ends)):
        procs.append(
            pool.apply_async(read_bam_frag,
                             args=(
                                 inbam,
                                 filter_exclude,
                                 bins,
                                 bins_dict,
                                 resolution,
                                 outdir,
                                 region,
                                 start,
                                 end,
                             )))
    pool.close()
    print_progress(procs)
    pool.join()

    ## COLLECT RESULTS
    verbose = True
    cisprc = {}
    for countbin, (region, start, end) in enumerate(zip(regs, begs, ends)):
        if verbose:
            if not countbin % 10 and countbin:
                sys.stdout.write(' ')
            if not countbin % 50 and countbin:
                sys.stdout.write(' %9s\n     ' % ('%s/%s' %
                                                  (countbin, len(regs))))
            sys.stdout.write('.')
            sys.stdout.flush()

        fname = os.path.join(outdir,
                             'tmp_bins_%s:%d-%d.pickle' % (region, start, end))
        tmp_cisprc = load(open(fname))
        cisprc.update(tmp_cisprc)
    if verbose:
        print '%s %9s\n' % (' ' * (54 - (countbin % 50) -
                                   (countbin % 50) / 10), '%s/%s' %
                            (len(regs), len(regs)))

    # out = open(os.path.join(outdir, 'dicos_%s.pickle' % (
    #     nicer(resolution).replace(' ', ''))), 'w')
    # dump(cisprc, out)
    # out.close()
    # bad columns
    def func_gen(x, *args):
        cmd = "zzz = " + func_restring % (args)
        exec(cmd) in globals(), locals()
        #print cmd
        try:
            return np.lib.asarray_chkfinite(zzz)
        except:
            # avoid the creation of NaNs when invalid values for power or log
            return x

    print '  - Removing columns with too few or too much interactions'
    if not min_count:

        badcol = filter_by_cis_percentage(
            cisprc,
            sigma=sigma,
            verbose=True,
            savefig=os.path.join(outdir + 'filtered_bins_%s.png' %
                                 (nicer(resolution).replace(' ', ''))))
    else:
        print '      -> too few  interactions defined as less than %9d interactions' % (
            min_count)
        for k in cisprc:
            cisprc[k] = cisprc[k][1]
        badcol = {}
        countL = 0
        countZ = 0
        for c in xrange(total):
            if cisprc.get(c, 0) < min_count:
                badcol[c] = cisprc.get(c, 0)
                countL += 1
                if not c in cisprc:
                    countZ += 1
        print '      -> removed %d columns (%d/%d null/high counts) of %d (%.1f%%)' % (
            len(badcol), countZ, countL, total,
            float(len(badcol)) / total * 100)

    printime('  - Rescaling biases')
    size = len(bins)
    biases = [cisprc.get(k, 1.) for k in range(size)]
    mean_col = float(sum(biases)) / len(biases)
    biases = dict([(k, b / mean_col * mean_col**0.5)
                   for k, b in enumerate(biases)])

    # collect subset-matrices and write genomic one
    # out = open(os.path.join(outdir,
    #                         'hicdata_%s.abc' % (nicer(resolution).replace(' ', ''))), 'w')
    pool = mu.Pool(ncpus)
    procs = []
    for i, (region, start, end) in enumerate(zip(regs, begs, ends)):
        fname = os.path.join(outdir,
                             'tmp_%s:%d-%d.pickle' % (region, start, end))
        procs.append(pool.apply_async(sum_nrm_matrix, args=(
            fname,
            biases,
        )))
    pool.close()
    print_progress(procs)
    pool.join()

    # to correct biases
    sumnrm = sum(p.get() for p in procs)

    target = (sumnrm / float(size * size * factor))**0.5
    biases = dict([(b, biases[b] * target) for b in biases])

    # check the sum
    if check_sum:
        pool = mu.Pool(ncpus)
        procs = []
        for i, (region, start, end) in enumerate(zip(regs, begs, ends)):
            fname = os.path.join(outdir,
                                 'tmp_%s:%d-%d.pickle' % (region, start, end))
            procs.append(
                pool.apply_async(sum_nrm_matrix, args=(
                    fname,
                    biases,
                )))
        pool.close()
        print_progress(procs)
        pool.join()

        # to correct biases
        sumnrm = sum(p.get() for p in procs)
        print 'SUM:', sumnrm

    printime('  - Rescaling decay')
    # normalize decay by size of the diagonal, and by Vanilla correction
    # (all cells must still be equals to 1 in average)

    pool = mu.Pool(ncpus)
    procs = []
    for i, (region, start, end) in enumerate(zip(regs, begs, ends)):
        fname = os.path.join(outdir,
                             'tmp_%s:%d-%d.pickle' % (region, start, end))
        procs.append(
            pool.apply_async(sum_dec_matrix,
                             args=(fname, biases, badcol, bins)))
    pool.close()
    print_progress(procs)
    pool.join()

    # collect results
    sumdec = {}
    for proc in procs:
        for k, v in proc.get().iteritems():
            try:
                sumdec[k] += v
            except KeyError:
                sumdec[k] = v

    # count the number of cells per diagonal
    # TODO: parallelize
    # find larget chromsome
    len_big = max(section_pos[c][1] - section_pos[c][0] for c in section_pos)
    # initialize dictionary
    ndiags = dict((k, 0) for k in xrange(len_big))
    for crm in section_pos:
        beg_chr, end_chr = section_pos[crm][0], section_pos[crm][1]
        chr_size = end_chr - beg_chr
        thesebads = [b for b in badcol if beg_chr <= b <= end_chr]
        for dist in xrange(1, chr_size):
            ndiags[dist] += chr_size - dist
            # from this we remove bad columns
            # bad columns will only affect if they are at least as distant from
            # a border as the distance between the longest diagonal and the
            # current diagonal.
            bad_diag = set(
            )  # 2 bad rows can point to the same bad cell in diagonal
            maxp = end_chr - dist
            minp = beg_chr + dist
            for b in thesebads:
                if b <= maxp:
                    bad_diag.add(b)
                if b >= minp:
                    bad_diag.add(b - dist)
            ndiags[dist] -= len(bad_diag)
        # chr_sizeerent behavior for longest diagonal:
        ndiags[0] += chr_size - len(thesebads)

    # normalize sum per diagonal by total number of cells in diagonal
    for k in sumdec:
        try:
            sumdec[k] /= ndiags[k]
        except ZeroDivisionError:  # all columns at this distance are "bad"
            pass

    return biases, sumdec, badcol
def read_bam(inbam, filter_exclude, resolution, min_count=2500,
             sigma=2, ncpus=8, factor=1, outdir='.', check_sum=False):
    bamfile = AlignmentFile(inbam, 'rb')
    sections = OrderedDict(zip(bamfile.references,
                               [x / resolution + 1 for x in bamfile.lengths]))
    total = 0
    section_pos = dict()
    for crm in sections:
        section_pos[crm] = (total, total + sections[crm])
        total += sections[crm] + 1
    bins = []
    for crm in sections:
        len_crm = sections[crm]
        bins.extend([(crm, i) for i in xrange(len_crm + 1)])

    start_bin = 0
    end_bin   = len(bins) + 1
    total = len(bins)

    total = end_bin - start_bin + 1
    regs = []
    begs = []
    ends = []
    njobs = min(total, 100) + 1
    nbins = total / njobs + 1
    for i in range(start_bin, end_bin, nbins):
        if i + nbins > end_bin:  # make sure that we stop at the right place
            nbins = end_bin - i
        try:
            (crm1, beg1), (crm2, end2) = bins[i], bins[i + nbins - 1]
        except IndexError:
            (crm1, beg1), (crm2, end2) = bins[i], bins[-1]
        if crm1 != crm2:
            end1 = sections[crm1]
            beg2 = 0
            regs.append(crm1)
            regs.append(crm2)
            begs.append(beg1 * resolution)
            begs.append(beg2 * resolution)
            ends.append(end1 * resolution + resolution)  # last nt included
            ends.append(end2 * resolution + resolution - 1)  # last nt not included (overlap with next window)
        else:
            regs.append(crm1)
            begs.append(beg1 * resolution)
            ends.append(end2 * resolution + resolution - 1)
    ends[-1] += 1  # last nucleotide included

    # print '\n'.join(['%s %d %d' % (a, b, c) for a, b, c in zip(regs, begs, ends)])
    printime('\n  - Parsing BAM (%d chunks)' % (len(regs)))
    bins_dict = dict([(j, i) for i, j in enumerate(bins)])
    pool = mu.Pool(ncpus)
    procs = []
    for i, (region, start, end) in enumerate(zip(regs, begs, ends)):
        procs.append(pool.apply_async(
            read_bam_frag, args=(inbam, filter_exclude, bins, bins_dict,
                                 resolution, outdir, region, start, end,)))
    pool.close()
    print_progress(procs)
    pool.join()

    ## COLLECT RESULTS
    verbose = True
    cisprc = {}
    for countbin, (region, start, end) in enumerate(zip(regs, begs, ends)):
        if verbose:
            if not countbin % 10 and countbin:
                sys.stdout.write(' ')
            if not countbin % 50 and countbin:
                sys.stdout.write(' %9s\n     ' % ('%s/%s' % (countbin , len(regs))))
            sys.stdout.write('.')
            sys.stdout.flush()

        fname = os.path.join(outdir,
                             'tmp_bins_%s:%d-%d.pickle' % (region, start, end))
        tmp_cisprc = load(open(fname))
        cisprc.update(tmp_cisprc)
    if verbose:
        print '%s %9s\n' % (' ' * (54 - (countbin % 50) - (countbin % 50) / 10),
                            '%s/%s' % (len(regs),len(regs)))

    # out = open(os.path.join(outdir, 'dicos_%s.pickle' % (
    #     nicer(resolution).replace(' ', ''))), 'w')
    # dump(cisprc, out)
    # out.close()
    # bad columns
    def func_gen(x, *args):
        cmd = "zzz = " + func_restring % (args)
        exec(cmd) in globals(), locals()
        #print cmd
        try:
            return np.lib.asarray_chkfinite(zzz)
        except:
            # avoid the creation of NaNs when invalid values for power or log
            return x
    print '  - Removing columns with too few or too much interactions'
    if not min_count:

        badcol = filter_by_cis_percentage(
            cisprc, sigma=sigma, verbose=True,
            savefig=os.path.join(outdir + 'filtered_bins_%s.png' % (
                nicer(resolution).replace(' ', ''))))
    else:
        print '      -> too few  interactions defined as less than %9d interactions' % (
            min_count)
        for k in cisprc:
            cisprc[k] = cisprc[k][1]
        badcol = {}
        countL = 0
        countZ = 0
        for c in xrange(total):
            if cisprc.get(c, 0) < min_count:
                badcol[c] = cisprc.get(c, 0)
                countL += 1
                if not c in cisprc:
                    countZ += 1
        print '      -> removed %d columns (%d/%d null/high counts) of %d (%.1f%%)' % (
            len(badcol), countZ, countL, total, float(len(badcol)) / total * 100)

    printime('  - Rescaling biases')
    size = len(bins)
    biases = [cisprc.get(k, 1.) for k in range(size)]
    mean_col = float(sum(biases)) / len(biases)
    biases = dict([(k, b / mean_col * mean_col**0.5)
                   for k, b in enumerate(biases)])

    # collect subset-matrices and write genomic one
    # out = open(os.path.join(outdir,
    #                         'hicdata_%s.abc' % (nicer(resolution).replace(' ', ''))), 'w')
    pool = mu.Pool(ncpus)
    procs = []
    for i, (region, start, end) in enumerate(zip(regs, begs, ends)):
        fname = os.path.join(outdir, 'tmp_%s:%d-%d.pickle' % (region, start, end))
        procs.append(pool.apply_async(sum_nrm_matrix, args=(fname, biases, )))
    pool.close()
    print_progress(procs)
    pool.join()

    # to correct biases
    sumnrm = sum(p.get() for p in procs)

    target = (sumnrm / float(size * size * factor))**0.5
    biases = dict([(b, biases[b] * target) for b in biases])

    # check the sum
    if check_sum:
        pool = mu.Pool(ncpus)
        procs = []
        for i, (region, start, end) in enumerate(zip(regs, begs, ends)):
            fname = os.path.join(outdir, 'tmp_%s:%d-%d.pickle' % (region, start, end))
            procs.append(pool.apply_async(sum_nrm_matrix, args=(fname, biases, )))
        pool.close()
        print_progress(procs)
        pool.join()

        # to correct biases
        sumnrm = sum(p.get() for p in procs)
        print 'SUM:', sumnrm

    printime('  - Rescaling decay')
    # normalize decay by size of the diagonal, and by Vanilla correction
    # (all cells must still be equals to 1 in average)

    pool = mu.Pool(ncpus)
    procs = []
    for i, (region, start, end) in enumerate(zip(regs, begs, ends)):
        fname = os.path.join(outdir,
                             'tmp_%s:%d-%d.pickle' % (region, start, end))
        procs.append(pool.apply_async(sum_dec_matrix,
                                      args=(fname, biases, badcol, bins)))
    pool.close()
    print_progress(procs)
    pool.join()

    # collect results
    sumdec = {}
    for proc in procs:
        for k, v in proc.get().iteritems():
            try:
                sumdec[k] += v
            except KeyError:
                sumdec[k]  = v

    # count the number of cells per diagonal
    # TODO: parallelize
    # find larget chromsome
    len_big = max(section_pos[c][1] - section_pos[c][0] for c in section_pos)
    # initialize dictionary
    ndiags = dict((k, 0) for k in xrange(len_big))
    for crm in section_pos:
        beg_chr, end_chr = section_pos[crm][0], section_pos[crm][1]
        chr_size = end_chr - beg_chr
        thesebads = [b for b in badcol if beg_chr <= b <= end_chr]
        for dist in xrange(1, chr_size):
            ndiags[dist] += chr_size - dist
            # from this we remove bad columns
            # bad columns will only affect if they are at least as distant from
            # a border as the distance between the longest diagonal and the
            # current diagonal.
            bad_diag = set()  # 2 bad rows can point to the same bad cell in diagonal
            maxp = end_chr - dist
            minp = beg_chr + dist
            for b in thesebads:
                if b <= maxp:
                    bad_diag.add(b)
                if b >= minp:
                    bad_diag.add(b - dist)
            ndiags[dist] -= len(bad_diag)
        # chr_sizeerent behavior for longest diagonal:
        ndiags[0] += chr_size - len(thesebads)

    # normalize sum per diagonal by total number of cells in diagonal
    for k in sumdec:
        try:
            sumdec[k] /= ndiags[k]
        except ZeroDivisionError:  # all columns at this distance are "bad"
            pass

    return biases, sumdec, badcol