Exemple #1
0
def _merge_biases_path(workdir, biases):
    """Return ``biases`` joined to ``workdir``, or None if there are no biases.

    ``path.join`` raises AttributeError (Python 2) or TypeError (Python 3)
    when ``biases`` is None; both are translated into a None result.
    """
    try:
        return path.join(workdir, biases)
    except (AttributeError, TypeError):
        return None


def run(opts):
    """Compare two Hi-C experiments and merge their BAM files.

    For each of the two samples the BAM path and biases path are taken
    either directly from ``opts.bam1``/``opts.bam2`` (with
    ``opts.biases1``/``opts.biases2``) or, when no BAM is given, loaded with
    ``load_parameters_fromdb`` from the corresponding working directory.

    Unless ``opts.skip_comparison`` is set, both samples are loaded at
    ``opts.reso`` and compared (correlation between equidistant loci,
    eigenvector correlation and reproducibility score).  Unless
    ``opts.skip_merge`` is set, both BAM files are merged with samtools into
    ``<workdir>/03_filtered_reads/intersection_<hash>.bam`` and indexed.
    All results are finally stored through ``save_to_db``.

    :param opts: parsed command line options (attributes read here:
       samtools, bam1/2, biases1/2, workdir, workdir1/2, jobid1/2, tmpdb1/2,
       filter, reso, cpus, norm, skip_comparison, skip_merge).

    :raises Exception: if the resolutions loaded from the two databases
       differ.
    """
    check_options(opts)
    samtools = which(opts.samtools)
    launch_time = time.localtime()

    param_hash = digest_parameters(opts)

    # resolutions stay None when the BAM files are given explicitly
    reso1 = reso2 = None
    if opts.bam1:
        mreads1 = path.realpath(opts.bam1)
        biases1 = opts.biases1
    else:
        biases1, mreads1, reso1 = load_parameters_fromdb(
            opts.workdir1, opts.jobid1, opts, opts.tmpdb1)
        mreads1 = path.join(opts.workdir1, mreads1)
        biases1 = _merge_biases_path(opts.workdir1, biases1)

    if opts.bam2:
        mreads2 = path.realpath(opts.bam2)
        biases2 = opts.biases2
    else:
        biases2, mreads2, reso2 = load_parameters_fromdb(
            opts.workdir2, opts.jobid2, opts, opts.tmpdb2)
        mreads2 = path.join(opts.workdir2, mreads2)
        # same helper as for sample 1: a missing biases file for sample 2
        # must never reset the biases of sample 1
        biases2 = _merge_biases_path(opts.workdir2, biases2)

    filter_exclude = opts.filter

    if reso1 != reso2:
        raise Exception('ERROR: differing resolutions between experiments to '
                        'be merged')

    merge_dir = path.join(opts.workdir, '00_merge')
    mkdir(merge_dir)

    if not opts.skip_comparison:
        printime('  - loading first sample %s' % (mreads1))
        hic_data1 = load_hic_data_from_bam(mreads1, opts.reso, biases=biases1,
                                           tmpdir=merge_dir,
                                           ncpus=opts.cpus,
                                           filter_exclude=filter_exclude)

        printime('  - loading second sample %s' % (mreads2))
        hic_data2 = load_hic_data_from_bam(mreads2, opts.reso, biases=biases2,
                                           tmpdir=merge_dir,
                                           ncpus=opts.cpus,
                                           filter_exclude=filter_exclude)

        # valid-pair counts are only computed here when at least one of the
        # two working directories is not provided
        if opts.workdir1 and opts.workdir2:
            masked1 = {'valid-pairs': {'count': 0}}
            masked2 = {'valid-pairs': {'count': 0}}
        else:
            masked1 = {'valid-pairs': {'count': sum(hic_data1.values())}}
            masked2 = {'valid-pairs': {'count': sum(hic_data2.values())}}

        decay_corr_dat = path.join(merge_dir, 'decay_corr_dat_%s_%s.txt' % (opts.reso, param_hash))
        decay_corr_fig = path.join(merge_dir, 'decay_corr_dat_%s_%s.png' % (opts.reso, param_hash))
        eigen_corr_dat = path.join(merge_dir, 'eigen_corr_dat_%s_%s.txt' % (opts.reso, param_hash))
        eigen_corr_fig = path.join(merge_dir, 'eigen_corr_dat_%s_%s.png' % (opts.reso, param_hash))

        printime('  - comparing experiments')
        printime('    => correlation between equidistant loci')
        corr, _, scc, std, bads = correlate_matrices(
            hic_data1, hic_data2, normalized=opts.norm,
            remove_bad_columns=True, savefig=decay_corr_fig,
            savedata=decay_corr_dat, get_bads=True)
        print('         - correlation score (SCC): %.4f (+- %.7f)' % (scc, std))
        printime('    => correlation between eigenvectors')
        eig_corr = eig_correlate_matrices(hic_data1, hic_data2, normalized=opts.norm,
                                          remove_bad_columns=True, nvect=6,
                                          savefig=eigen_corr_fig,
                                          savedata=eigen_corr_dat)

        printime('    => reproducibility score')
        reprod = get_reproducibility(hic_data1, hic_data2, num_evec=20, normalized=opts.norm,
                                     verbose=False, remove_bad_columns=True)
        print('         - reproducibility score: %.4f' % (reprod))
        ncols = len(hic_data1)
    else:
        # placeholders stored in the database when no comparison is run
        ncols = 0
        decay_corr_dat = 'None'
        decay_corr_fig = 'None'
        eigen_corr_dat = 'None'
        eigen_corr_fig = 'None'
        masked1 = {}
        masked2 = {}

        corr = eig_corr = scc = std = reprod = 0
        bads = {}

    # merge inputs
    mkdir(path.join(opts.workdir, '03_filtered_reads'))

    if not opts.skip_merge:
        outbam = path.join(opts.workdir, '03_filtered_reads',
                           'intersection_%s.bam' % (param_hash))
        printime('  - Mergeing experiments')
        system(samtools  + ' merge -@ %d %s %s %s' % (opts.cpus, outbam, mreads1, mreads2))
        printime('  - Indexing new BAM file')
        # check samtools version number and modify command line: the version
        # is parsed from the 'Version' line samtools writes to stderr when
        # called without arguments
        version = LooseVersion([l.split()[1]
                                for l in Popen(samtools, stderr=PIPE,
                                               universal_newlines=True).communicate()[1].split('\n')
                                if 'Version' in l][0])
        # the threads option of "index" is only used from samtools 1.3.1 on
        if version >= LooseVersion('1.3.1'):
            system(samtools  + ' index -@ %d %s' % (opts.cpus, outbam))
        else:
            system(samtools  + ' index %s' % (outbam))
    else:
        outbam = ''

    finish_time = time.localtime()
    save_to_db (opts, mreads1, mreads2, decay_corr_dat, decay_corr_fig,
                len(list(bads.keys())), ncols, scc, std, reprod,
                eigen_corr_dat, eigen_corr_fig, outbam, corr, eig_corr,
                biases1, biases2, masked1, masked2, launch_time, finish_time)
    printime('\nDone.')
Exemple #2
0
def read_bam(inbam, filter_exclude, resolution, min_count=2500, biases_path='',
             normalization='Vanilla', mappability=None, n_rsites=None,
             cg_content=None, sigma=2, ncpus=8, factor=1, outdir='.', seed=1,
             extra_out='', only_valid=False, normalize_only=False, p_fit=None,
             max_njobs=100, extra_bads=None,
             cis_limit=1, trans_limit=5, min_ratio=1.0, fast_filter=False):
    """Compute per-bin biases and the distance decay of a Hi-C BAM file.

    The genome is split in bins of ``resolution`` nucleotides and in chunks
    that are parsed in parallel. Per-bin counts are collected from the
    temporary pickles found in ``outdir``, low-quality bins are flagged as
    bad columns, biases are computed with the requested normalization and
    rescaled, and finally the normalized sum per diagonal (decay) is
    computed for each chromosome.

    :param inbam: path to the BAM file
    :param filter_exclude: filter passed unchanged to the chunk readers (and
       to ``load_hic_data_from_bam`` for ICE)
    :param resolution: bin size in nucleotides
    :param 2500 min_count: minimum count below which a bin is a bad column.
       If None (only supported when ``fast_filter`` is False) it is set to
       the 95th percentile of the counts of the bins failing ``min_ratio``.
    :param '' biases_path: file with one bias per line after a header line,
       only read with the 'custom' normalization. Lines starting with 'N'
       and null values are flagged as bad columns.
    :param 'Vanilla' normalization: one of 'ICE', 'Vanilla', 'SQRT', 'oneD'
       or 'custom'
    :param None mappability: per-bin values; bins with a false value are
       flagged as bad columns. Also passed to oneD.
    :param None n_rsites: per-bin values passed to oneD
    :param None cg_content: per-bin values passed to oneD
    :param 2 sigma: not used by this function
    :param 8 ncpus: size of the multiprocessing pools
    :param 1 factor: divides the target used to rescale the biases
    :param '.' outdir: directory of the temporary pickles and of the
       filtering summary plot
    :param 1 seed: passed to oneD
    :param '' extra_out: suffix of the temporary and output file names
    :param False only_valid: use ``read_bam_frag_valid`` instead of
       ``read_bam_frag_filter`` to parse chunks
    :param False normalize_only: skip the computation of the normalized
       cis-percentage (0. is returned instead)
    :param None p_fit: passed to oneD
    :param 100 max_njobs: bound used to decide the number of chunks
    :param None extra_bads: list of 'chromosome:begin-end' strings
       (nucleotide coordinates) to be added to the bad columns
    :param 1 cis_limit: in bins; when None, defaults to 1 Mb / resolution
    :param 5 trans_limit: in bins; when None, defaults to 5 * cis_limit
    :param 1.0 min_ratio: bins with a ratio (see below) lower than this are
       flagged as bad columns when ``fast_filter`` is False
    :param False fast_filter: only filter bins by ``min_count``

    :returns: a tuple (biases, decay, bad columns, raw cis-percentage,
       normalized cis-percentage)

    :raises NotImplementedError: on unknown ``normalization``
    :raises Exception: with 'oneD' when the input arrays differ in size
    """
    bamfile = AlignmentFile(inbam, 'rb')
    # number of bins per chromosome, in the order of the BAM header
    sections = OrderedDict(zip(bamfile.references,
                               [x // resolution + 1 for x in bamfile.lengths]))
    # genomic (begin, end) bin indexes of each chromosome
    total = 0
    section_pos = dict()
    for crm in sections:
        section_pos[crm] = (total, total + sections[crm])
        total += sections[crm]
    # genomic list of (chromosome, bin index inside the chromosome)
    bins = []
    for crm in sections:
        len_crm = sections[crm]
        bins.extend([(crm, i) for i in range(len_crm)])

    start_bin = 0
    end_bin   = len(bins)
    total     = len(bins)

    # define the chunks (chromosome, begin, end in nucleotides); a window
    # overlapping two chromosomes is split in two chunks
    regs = []
    begs = []
    ends = []
    njobs = min(total, max_njobs) + 1
    nbins = total // njobs + 1
    for i in range(start_bin, end_bin, nbins):
        if i + nbins > end_bin:  # make sure that we stop
            nbins = end_bin - i
        try:
            (crm1, beg1), (crm2, end2) = bins[i], bins[i + nbins - 1]
        except IndexError:
            try:
                (crm1, beg1), (crm2, end2) = bins[i], bins[-1]
            except IndexError:
                break
        if crm1 != crm2:
            end1 = sections[crm1]
            beg2 = 0
            regs.append(crm1)
            regs.append(crm2)
            begs.append(beg1 * resolution)
            begs.append(beg2 * resolution)
            ends.append(end1 * resolution + resolution)  # last nt included
            ends.append(end2 * resolution + resolution - 1)  # last nt not included (overlap with next window)
        else:
            regs.append(crm1)
            begs.append(beg1 * resolution)
            ends.append(end2 * resolution + resolution - 1)
    ends[-1] += 1  # last nucleotide included

    printime('  - Parsing BAM (%d chunks)' % (len(regs)))
    # define limits for cis and trans interactions if not given
    if cis_limit is None:
        cis_limit = int(1_000_000 / resolution)
    print('      -> cis interactions are defined as being bellow {}'.format(
        nicer(cis_limit * resolution)))
    if trans_limit is None:
        trans_limit = cis_limit * 5
    print('      -> trans interactions are defined as being bellow {}'.format(
        nicer(trans_limit * resolution)))

    # (chromosome, bin) -> genomic bin index
    bins_dict = dict((j, i) for i, j in enumerate(bins))
    pool = mu.Pool(ncpus)
    procs = []
    read_bam_frag = read_bam_frag_valid if only_valid else read_bam_frag_filter
    for region, start, end in zip(regs, begs, ends):
        procs.append(pool.apply_async(
            read_bam_frag, args=(inbam, filter_exclude, bins, bins_dict,
                                 resolution, outdir, extra_out,
                                 region, start, end, cis_limit, trans_limit)))
    pool.close()
    print_progress(procs)
    pool.join()
    ## COLLECT RESULTS
    # cisprc: genomic bin index -> list of 4 counts.
    # NOTE(review): from their usage below, index 0 and 1 look like the cis
    # and total interactions of the bin, and index 2 and 3 like the counts
    # within cis_limit and trans_limit -- confirm against the chunk readers.
    cisprc = {}
    printime('  - Collecting cis and total interactions per bin (%d chunks)' % (len(regs)))
    stdout.write('     ')
    for countbin, (region, start, end) in enumerate(zip(regs, begs, ends)):
        if not countbin % 10 and countbin:
            stdout.write(' ')
        if not countbin % 50 and countbin:
            stdout.write(' %9s\n     ' % ('%s/%s' % (countbin , len(regs))))
        stdout.write('.')
        stdout.flush()

        fname = path.join(outdir,
                          'tmp_bins_%s:%d-%d_%s.pickle' % (region, start, end, extra_out))
        # close the file before removing it
        with open(fname, 'rb') as handle:
            tmp_cisprc = load(handle)
        system('rm -f %s' % fname)
        cisprc.update(tmp_cisprc)
    stdout.write('\n')

    # get cis/trans ratio (stored in place of the index 3)
    for k in cisprc:
        try:
            cisprc[k][3] = cisprc[k][2] / cisprc[k][3]
        except ZeroDivisionError:
            cisprc[k][3] = 0

    # BIN FILTERINGS
    printime('  - Removing columns with too few or too much interactions')

    # bins without any entry in cisprc are counted as empty
    empty = [0, 0, 0, 0]

    # define filter for minimum interactions per bin
    if not fast_filter:
        if min_count is None:
            # hardcoded parameter (95th percentile); bins with no
            # interactions in cis are left out of the computation
            min_count = nanpercentile(
                [cisprc[k][2] for k in range(total)
                 if cisprc.get(k, empty)[3] < min_ratio
                 and cisprc.get(k, empty)[2] >= 1], 95)

        print('      -> too few interactions defined as less than %9d '
               'interactions' % (min_count))
        # default also used for the count: bins without reads no longer
        # raise a KeyError when they pass the min_ratio test
        badcol = dict((k, True) for k in range(total)
                      if cisprc.get(k, empty)[3] < min_ratio
                      or cisprc.get(k, empty)[2] < min_count)
        print('      -> removed %d columns of %d (%.1f%%)' % (
            len(badcol), total, float(len(badcol)) / total * 100))
    else:
        # fast filter: only the count at index 1 is compared to min_count
        # (min_count can not be None here)
        print('      -> too few interactions defined as less than %9d '
               'interactions' % (min_count))
        badcol = {}
        countL = 0
        countZ = 0
        for c in range(total):
            if cisprc.get(c, empty)[1] < min_count:
                badcol[c] = cisprc.get(c, empty)[1]
                countL += 1
                if not c in cisprc:
                    countZ += 1
        print('      -> removed %d columns (%d/%d null/high counts) of %d (%.1f%%)' % (
            len(badcol), countZ, countL, total, float(len(badcol)) / total * 100))

    # Plot
    plot_filtering(dict((k, cisprc[k][2]) for k in cisprc),
                   dict((k, cisprc[k][3]) for k in cisprc), total, min_count, min_ratio,
                   path.join(outdir, 'filtering_summary_plot_{}_{}.png'.format(nicer(resolution, sep=''),
                    extra_out)),
                   base_position=0, next_position=cis_limit, last_position=trans_limit, resolution=resolution,
                   legend='Filtered {} of {} bins'.format(len(badcol), total))

    # no mappability will result in NaNs, better to filter out these columns
    if mappability:
        badcol.update((i, True) for i, m in enumerate(mappability) if not m)

    # add manually columns to bad columns
    if extra_bads:
        removed_manually = 0
        for ebc in extra_bads:
            c, ebc = ebc.split(':')
            b, e = list(map(int, ebc.split('-')))
            # from nucleotides to genomic bin indexes
            b = b // resolution + section_pos[c][0]
            e = e // resolution + section_pos[c][0]
            removed_manually += (e - b)
            badcol.update(dict((p, 'manual') for p in range(b, e)))
        printime('  - Removed %d columns manually.' % removed_manually)
    raw_cisprc = sum(float(cisprc[k][0]) / cisprc[k][1]
                     for k in cisprc if not k in badcol) / (len(cisprc) - len(badcol))

    printime('  - Rescaling sum of interactions per bins')
    size = len(bins)
    # raw biases: count at index 1 of each bin (NaN in bad columns, 1. in
    # bins without reads)
    biases = [float('nan') if k in badcol else cisprc.get(k, [0, 1., 0, 0])[1]
              for k in range(size)]

    if normalization == 'ICE':
        printime('  - ICE normalization')
        hic_data = load_hic_data_from_bam(
            inbam, resolution, filter_exclude=filter_exclude,
            tmpdir=outdir, ncpus=ncpus, nchunks=max_njobs)
        hic_data.bads = badcol
        hic_data.normalize_hic(iterations=100, max_dev=0.000001)
        biases = hic_data.bias.copy()
        del(hic_data)
    elif normalization == 'Vanilla':
        printime('  - Vanilla normalization')
        mean_col = nanmean(biases)
        biases   = dict((k, b / mean_col * mean_col**0.5)
                        for k, b in enumerate(biases))
    elif normalization == 'SQRT':
        printime('  - Vanilla-SQRT normalization')
        biases = [b**0.5 for b in biases]
        mean_col = nanmean(biases)
        biases   = dict((k, b / mean_col * mean_col**0.5)
                        for k, b in enumerate(biases))
    elif normalization == 'oneD':
        printime('  - oneD normalization')
        if len(set([len(biases), len(mappability), len(n_rsites), len(cg_content)])) > 1:
            print("biases", "mappability", "n_rsites", "cg_content")
            print(len(biases), len(mappability), len(n_rsites), len(cg_content))
            raise Exception('Error: not all arrays have the same size')
        tmp_oneD = path.join(outdir,'tmp_oneD_%s' % (extra_out))
        mkdir(tmp_oneD)
        biases = oneD(tmp_dir=tmp_oneD, p_fit=p_fit, tot=biases, map=mappability,
                      res=n_rsites, cg=cg_content, seed=seed)
        biases = dict((k, b) for k, b in enumerate(biases))
        rmtree(tmp_oneD)
    elif normalization == 'custom':
        n_pos = 0
        biases = {}
        print('Using provided biases...')
        with open(biases_path, 'r') as r:
            next(r)  # skip header line
            for line in r:
                if line[0] == 'N':
                    badcol[n_pos] = 0
                    biases[n_pos] = float('nan')
                else:
                    b = float(line)
                    if b == 0:
                        badcol[n_pos] = 0
                        biases[n_pos] = float('nan')
                    else:
                        biases[n_pos] = b
                n_pos += 1
        # NOTE(review): the range starts at the last position read from the
        # file, whose bias is thus replaced by NaN -- confirm it is intended
        for add in range(max(biases.keys()), total + 1):
            biases[add] = float('nan')
    else:
        raise NotImplementedError('ERROR: method %s not implemented' %
                                  normalization)

    # collect subset-matrices
    printime('  - Getting sum of normalized bins')
    pool = mu.Pool(ncpus)
    procs = []
    for region, start, end in zip(regs, begs, ends):
        fname = path.join(outdir,
                          'tmp_%s:%d-%d_%s.pickle' % (region, start, end, extra_out))
        procs.append(pool.apply_async(sum_nrm_matrix,
                                      args=(fname, biases,)))
    pool.close()
    print_progress(procs)
    pool.join()

    # to correct biases
    sumnrm = sum(p.get() for p in procs)

    target = (sumnrm / float(size * size * factor))**0.5
    biases = dict([(b, biases[b] * target) for b in biases])

    if not normalize_only:
        printime('  - Computing Cis percentage')
        # Calculate Cis percentage

        pool = mu.Pool(ncpus)
        procs = []
        for region, start, end in zip(regs, begs, ends):
            fname = path.join(outdir,
                              'tmp_%s:%d-%d_%s.pickle' % (region, start, end, extra_out))
            procs.append(pool.apply_async(get_cis_perc,
                                          args=(fname, biases, badcol, bins)))
        pool.close()
        print_progress(procs)
        pool.join()

        # collect results (own accumulator: the number of bins kept in
        # "total" is not overwritten)
        cis = all_interactions = 0
        for proc in procs:
            c, t = proc.get()
            cis += c
            all_interactions += t
        norm_cisprc = float(cis) / all_interactions
        print('    * Cis-percentage: %.1f%%' % (norm_cisprc * 100))
    else:
        norm_cisprc = 0.

    printime('  - Rescaling decay')
    # normalize decay by size of the diagonal, and by Vanilla correction
    # (all cells must still be equals to 1 in average)

    pool = mu.Pool(ncpus)
    procs = []
    for region, start, end in zip(regs, begs, ends):
        fname = path.join(outdir,
                          'tmp_%s:%d-%d_%s.pickle' % (region, start, end, extra_out))
        procs.append(pool.apply_async(sum_dec_matrix,
                                      args=(fname, biases, badcol, bins)))
    pool.close()
    print_progress(procs)
    pool.join()

    # collect results: sum by chromosome and diagonal of the normalized
    # (nrmdec) and raw (rawdec) values returned for each chunk
    nrmdec = {}
    rawdec = {}
    for proc in procs:
        tmpnrm, tmpraw = proc.get()
        for c, d in tmpnrm.items():
            for k, v in d.items():
                try:
                    nrmdec[c][k] += v
                    rawdec[c][k] += tmpraw[c][k]
                except KeyError:
                    try:  # new diagonal in a known chromosome
                        nrmdec[c][k]  = v
                        rawdec[c][k] = tmpraw[c][k]
                    except KeyError:  # new chromosome
                        nrmdec[c] = {k: v}
                        rawdec[c] = {k: tmpraw[c][k]}
    # count the number of cells per diagonal
    # TODO: parallelize
    # size in bins of each chromosome
    len_crms = dict((c, section_pos[c][1] - section_pos[c][0]) for c in section_pos)
    # initialize dictionary
    ndiags = dict((c, dict((k, 0) for k in range(len_crms[c]))) for c in sections)
    for crm in section_pos:
        beg_chr, end_chr = section_pos[crm][0], section_pos[crm][1]
        chr_size = end_chr - beg_chr
        thesebads = [b for b in badcol if beg_chr <= b <= end_chr]
        for dist in range(1, chr_size):
            ndiags[crm][dist] += chr_size - dist
            # from this we remove bad columns
            # bad columns will only affect if they are at least as distant from
            # a border as the distance between the longest diagonal and the
            # current diagonal.
            bad_diag = set()  # 2 bad rows can point to the same bad cell in diagonal
            maxp = end_chr - dist
            minp = beg_chr + dist
            for b in thesebads:
                if b < maxp:  # not inclusive!!
                    bad_diag.add(b)
                if b >= minp:
                    bad_diag.add(b - dist)
            ndiags[crm][dist] -= len(bad_diag)
        # different behavior for longest diagonal:
        ndiags[crm][0] += chr_size - sum(beg_chr <= b < end_chr for b in thesebads)

    # normalize sum per diagonal by total number of cells in diagonal
    signal_to_noise = 0.05
    min_n = signal_to_noise ** -2. # equals 400 when default
    for crm in sections:
        if not crm in nrmdec:
            nrmdec[crm] = {}
            rawdec[crm] = {}
        tmpdec = 0  # store normalized count by diagonal
        tmpsum = 0  # store raw count by diagonal
        ndiag  = 0
        val    = 0
        previous = [] # store diagonals to be summed in case not reaching the minimum
        for k in ndiags[crm]:
            tmpdec += nrmdec[crm].get(k, 0.)
            tmpsum += rawdec[crm].get(k, 0.)
            previous.append(k)
            if tmpsum > min_n:
                ndiag = sum(ndiags[crm][k] for k in previous)
                val = tmpdec  # backup of tmpdec kept for last ones outside the loop
                try:
                    ratio = val / ndiag
                    for l in previous:
                        nrmdec[crm][l] = ratio
                except ZeroDivisionError:  # all columns at this distance are "bad"
                    pass
                previous = []
                tmpdec = 0
                tmpsum = 0
        # last ones we average with previous result
        if  len(previous) == len(ndiags[crm]):
            # the minimum was never reached: no decay for this chromosome
            nrmdec[crm] = {}
        elif tmpsum < min_n:
            ndiag += sum(ndiags[crm][k] for k in previous)
            val += tmpdec
            try:
                ratio = val / ndiag
                for k in previous:
                    nrmdec[crm][k] = ratio
            except ZeroDivisionError:  # all columns at this distance are "bad"
                pass
    return biases, nrmdec, badcol, raw_cisprc, norm_cisprc
Exemple #3
0
def read_bam(inbam, filter_exclude, resolution, min_count=2500, biases_path='',
             normalization='Vanilla', mappability=None, n_rsites=None,
             cg_content=None, sigma=2, ncpus=8, factor=1, outdir='.', seed=1,
             extra_out='', only_valid=False, normalize_only=False, p_fit=None,
             max_njobs=100, min_perc=None, max_perc=None, extra_bads=None):
    """Compute per-bin biases and the distance decay of a Hi-C BAM file.

    NOTE: this variant is written for Python 2 (print statement, xrange,
    iteritems, integer division with "/").

    The genome is split in bins of ``resolution`` nucleotides and in chunks
    parsed in parallel. Per-bin counts are collected from temporary pickles
    in ``outdir``, bad columns are defined, biases are computed with the
    requested normalization and rescaled, and the normalized sum per
    diagonal (decay) is computed for each chromosome.

    :param inbam: path to the BAM file
    :param filter_exclude: filter passed unchanged to the chunk readers (and
       to ``load_hic_data_from_bam`` for ICE)
    :param resolution: bin size in nucleotides
    :param 2500 min_count: minimum count below which a bin is a bad column.
       If None, bins are filtered by cis-percentage instead (not possible
       with a single chromosome).
    :param '' biases_path: file with one bias per line after a header line,
       only read with the 'custom' normalization
    :param 'Vanilla' normalization: one of 'ICE', 'Vanilla', 'SQRT', 'oneD'
       or 'custom'
    :param None mappability: per-bin values; bins with a false value are
       flagged as bad columns. Also passed to oneD.
    :param None n_rsites: per-bin values passed to oneD
    :param None cg_content: per-bin values passed to oneD
    :param 2 sigma: passed to ``filter_by_cis_percentage``
    :param 8 ncpus: size of the multiprocessing pools
    :param 1 factor: divides the target used to rescale the biases
    :param '.' outdir: directory of the temporary pickles
    :param 1 seed: passed to oneD
    :param '' extra_out: suffix of the temporary file names
    :param False only_valid: use ``read_bam_frag_valid`` instead of
       ``read_bam_frag_filter`` to parse chunks
    :param False normalize_only: skip the computation of the normalized
       cis-percentage (0. is returned instead)
    :param None p_fit: passed to oneD
    :param 100 max_njobs: bound used to decide the number of chunks
    :param None min_perc: passed to ``filter_by_cis_percentage``
    :param None max_perc: passed to ``filter_by_cis_percentage``
    :param None extra_bads: list of 'chromosome:begin-end' strings
       (nucleotide coordinates) to be added to the bad columns

    :returns: a tuple (biases, decay, bad columns, raw cis-percentage,
       normalized cis-percentage)
    """
    bamfile = AlignmentFile(inbam, 'rb')
    # number of bins per chromosome (relies on Python 2 integer division)
    sections = OrderedDict(zip(bamfile.references,
                               [x / resolution + 1 for x in bamfile.lengths]))
    # genomic (begin, end) bin indexes of each chromosome
    total = 0
    section_pos = dict()
    for crm in sections:
        section_pos[crm] = (total, total + sections[crm])
        total += sections[crm]
    # genomic list of (chromosome, bin index inside the chromosome)
    bins = []
    for crm in sections:
        len_crm = sections[crm]
        bins.extend([(crm, i) for i in xrange(len_crm)])

    start_bin = 0
    end_bin   = len(bins)
    total     = len(bins)

    # define the chunks (chromosome, begin, end in nucleotides); a window
    # overlapping two chromosomes is split in two chunks
    regs = []
    begs = []
    ends = []
    njobs = min(total, max_njobs) + 1
    nbins = total / njobs + 1
    for i in range(start_bin, end_bin, nbins):
        if i + nbins > end_bin:  # make sure that we stop
            nbins = end_bin - i
        try:
            (crm1, beg1), (crm2, end2) = bins[i], bins[i + nbins - 1]
        except IndexError:
            try:
                (crm1, beg1), (crm2, end2) = bins[i], bins[-1]
            except IndexError:
                break
        if crm1 != crm2:
            end1 = sections[crm1]
            beg2 = 0
            regs.append(crm1)
            regs.append(crm2)
            begs.append(beg1 * resolution)
            begs.append(beg2 * resolution)
            ends.append(end1 * resolution + resolution)  # last nt included
            ends.append(end2 * resolution + resolution - 1)  # last nt not included (overlap with next window)
        else:
            regs.append(crm1)
            begs.append(beg1 * resolution)
            ends.append(end2 * resolution + resolution - 1)
    ends[-1] += 1  # last nucleotide included

    # print '\n'.join(['%s %d %d' % (a, b, c) for a, b, c in zip(regs, begs, ends)])
    printime('  - Parsing BAM (%d chunks)' % (len(regs)))
    # (chromosome, bin) -> genomic bin index
    bins_dict = dict([(j, i) for i, j in enumerate(bins)])
    pool = mu.Pool(ncpus)
    procs = []
    read_bam_frag = read_bam_frag_valid if only_valid else read_bam_frag_filter
    for i, (region, start, end) in enumerate(zip(regs, begs, ends)):
        procs.append(pool.apply_async(
            read_bam_frag, args=(inbam, filter_exclude, bins, bins_dict,
                                 resolution, outdir, extra_out,
                                 region, start, end,)))
    pool.close()
    print_progress(procs)
    pool.join()
    ## COLLECT RESULTS
    # cisprc: genomic bin index -> pair of counts.
    # NOTE(review): from the message below and their usage, index 0 and 1
    # look like the cis and total interactions of the bin -- confirm against
    # the chunk readers.
    cisprc = {}
    printime('  - Collecting cis and total interactions per bin (%d chunks)' % (len(regs)))
    stdout.write('     ')
    for countbin, (region, start, end) in enumerate(zip(regs, begs, ends)):
        if not countbin % 10 and countbin:
            stdout.write(' ')
        if not countbin % 50 and countbin:
            stdout.write(' %9s\n     ' % ('%s/%s' % (countbin , len(regs))))
        stdout.write('.')
        stdout.flush()

        fname = path.join(outdir,
                          'tmp_bins_%s:%d-%d_%s.pickle' % (region, start, end, extra_out))
        # NOTE(review): the file object is never explicitly closed
        tmp_cisprc = load(open(fname))
        system('rm -f %s' % fname)
        cisprc.update(tmp_cisprc)
    stdout.write('\n')

    printime('  - Removing columns with too few or too much interactions')
    # without min_count, bins are filtered by cis-percentage, which needs
    # more than one chromosome
    if len(bamfile.references) == 1 and min_count is None:
        raise Exception("ERROR: only one chromosome can't filter by "
                        "cis-percentage, set min_count instead")
    elif min_count is None and len(bamfile.references) > 1:
        badcol = filter_by_cis_percentage(
            cisprc, sigma=sigma, verbose=True, min_perc=min_perc, max_perc=max_perc,
            size=total, savefig=None)
    else:
        print ('      -> too few interactions defined as less than %9d '
               'interactions') % (min_count)
        # bins without any entry in cisprc are counted as empty
        badcol = {}
        countL = 0
        countZ = 0
        for c in xrange(total):
            if cisprc.get(c, [0, 0])[1] < min_count:
                badcol[c] = cisprc.get(c, [0, 0])[1]
                countL += 1
                if not c in cisprc:
                    countZ += 1
        print '      -> removed %d columns (%d/%d null/high counts) of %d (%.1f%%)' % (
            len(badcol), countZ, countL, total, float(len(badcol)) / total * 100)

    # no mappability will result in NaNs, better to filter out these columns
    if mappability:
        badcol.update((i, True) for i, m in enumerate(mappability) if not m)

    # add manually columns to bad columns
    if extra_bads:
        removed_manually = 0
        for ebc in extra_bads:
            c, ebc = ebc.split(':')
            b, e = map(int, ebc.split('-'))
            # from nucleotides to genomic bin indexes (Python 2 integer
            # division)
            b = b / resolution + section_pos[c][0]
            e = e / resolution + section_pos[c][0]
            removed_manually += (e - b)
            badcol.update(dict((p, 'manual') for p in xrange(b, e)))
        printime('  - Removed %d columns manually.' % removed_manually)
    raw_cisprc = sum(float(cisprc[k][0]) / cisprc[k][1]
                     for k in cisprc if not k in badcol) / (len(cisprc) - len(badcol))

    printime('  - Rescaling sum of interactions per bins')
    size = len(bins)
    # raw biases: count at index 1 of each bin (NaN in bad columns, 1. in
    # bins without reads)
    biases = [float('nan') if k in badcol else cisprc.get(k, [0, 1.])[1]
              for k in xrange(size)]

    if normalization == 'ICE':
        printime('  - ICE normalization')
        hic_data = load_hic_data_from_bam(
            inbam, resolution, filter_exclude=filter_exclude,
            tmpdir=outdir, ncpus=ncpus)
        hic_data.bads = badcol
        hic_data.normalize_hic(iterations=100, max_dev=0.000001)
        biases = hic_data.bias.copy()
        del(hic_data)
    elif normalization == 'Vanilla':
        printime('  - Vanilla normalization')
        mean_col = nanmean(biases)
        biases   = dict((k, b / mean_col * mean_col**0.5)
                        for k, b in enumerate(biases))
    elif normalization == 'SQRT':
        printime('  - Vanilla-SQRT normalization')
        biases = [b**0.5 for b in biases]
        mean_col = nanmean(biases)
        biases   = dict((k, b / mean_col * mean_col**0.5)
                        for k, b in enumerate(biases))
    elif normalization == 'oneD':
        printime('  - oneD normalization')
        if len(set([len(biases), len(mappability), len(n_rsites), len(cg_content)])) > 1:
            print "biases", "mappability", "n_rsites", "cg_content"
            print len(biases), len(mappability), len(n_rsites), len(cg_content)
            raise Exception('Error: not all arrays have the same size')
        tmp_oneD = path.join(outdir,'tmp_oneD_%s' % (extra_out))
        mkdir(tmp_oneD)
        biases = oneD(tmp_dir=tmp_oneD, p_fit=p_fit, tot=biases, map=mappability,
                      res=n_rsites, cg=cg_content, seed=seed)
        biases = dict((k, b) for k, b in enumerate(biases))
        rmtree(tmp_oneD)
    elif normalization == 'custom':
        n_pos = 0
        biases = {}
        print 'Using provided biases...'
        with open(biases_path, 'r') as r:
            r.next()  # skip header line
            for line in r:
                if line[0] == 'N':
                    #b = float('nan')
                    badcol[n_pos] = 0
                    biases[n_pos] = float('nan')
                else:
                    b = float(line)
                    if b == 0:
                        badcol[n_pos] = 0
                        biases[n_pos] = float('nan')
                    else:
                        biases[n_pos] = b
                n_pos += 1
        # NOTE(review): the range starts at the last position read from the
        # file, whose bias is thus replaced by NaN -- confirm it is intended
        for add in range(max(biases.keys()), total + 1):
            biases[add] = float('nan')
    else:
        raise NotImplementedError('ERROR: method %s not implemented' %
                                  normalization)

    # collect subset-matrices and write genomic one
    # out = open(os.path.join(outdir,
    #                         'hicdata_%s.abc' % (nicer(resolution).replace(' ', ''))), 'w')
    printime('  - Getting sum of normalized bins')
    pool = mu.Pool(ncpus)
    procs = []
    for i, (region, start, end) in enumerate(zip(regs, begs, ends)):
        fname = path.join(outdir,
                          'tmp_%s:%d-%d_%s.pickle' % (region, start, end, extra_out))
        procs.append(pool.apply_async(sum_nrm_matrix,
                                      args=(fname, biases,)))
    pool.close()
    print_progress(procs)
    pool.join()

    # to correct biases
    sumnrm = sum(p.get() for p in procs)

    target = (sumnrm / float(size * size * factor))**0.5
    biases = dict([(b, biases[b] * target) for b in biases])

    if not normalize_only:
        printime('  - Computing Cis percentage')
        # Calculate Cis percentage

        pool = mu.Pool(ncpus)
        procs = []
        for i, (region, start, end) in enumerate(zip(regs, begs, ends)):
            fname = path.join(outdir,
                              'tmp_%s:%d-%d_%s.pickle' % (region, start, end, extra_out))
            procs.append(pool.apply_async(get_cis_perc,
                                          args=(fname, biases, badcol, bins)))
        pool.close()
        print_progress(procs)
        pool.join()

        # collect results (from here on "total" no longer holds the number
        # of bins)
        cis = total = 0
        for proc in procs:
            c, t = proc.get()
            cis += c
            total += t
        norm_cisprc = float(cis) / total
        print '    * Cis-percentage: %.1f%%' % (norm_cisprc * 100)
    else:
        norm_cisprc = 0.

    printime('  - Rescaling decay')
    # normalize decay by size of the diagonal, and by Vanilla correction
    # (all cells must still be equals to 1 in average)

    pool = mu.Pool(ncpus)
    procs = []
    for i, (region, start, end) in enumerate(zip(regs, begs, ends)):
        fname = path.join(outdir,
                          'tmp_%s:%d-%d_%s.pickle' % (region, start, end, extra_out))
        procs.append(pool.apply_async(sum_dec_matrix,
                                      args=(fname, biases, badcol, bins)))
    pool.close()
    print_progress(procs)
    pool.join()

    # collect results: sum by chromosome and diagonal of the normalized
    # (nrmdec) and raw (rawdec) values returned for each chunk
    nrmdec = {}
    rawdec = {}
    for proc in procs:
        tmpnrm, tmpraw = proc.get()
        for c, d in tmpnrm.iteritems():
            for k, v in d.iteritems():
                try:
                    nrmdec[c][k] += v
                    rawdec[c][k] += tmpraw[c][k]
                except KeyError:
                    try:  # new diagonal in a known chromosome
                        nrmdec[c][k]  = v
                        rawdec[c][k] = tmpraw[c][k]
                    except KeyError:  # new chromosome
                        nrmdec[c] = {k: v}
                        rawdec[c] = {k: tmpraw[c][k]}
    # count the number of cells per diagonal
    # TODO: parallelize
    # size in bins of each chromosome
    len_crms = dict((c, section_pos[c][1] - section_pos[c][0]) for c in section_pos)
    # initialize dictionary
    ndiags = dict((c, dict((k, 0) for k in xrange(len_crms[c]))) for c in sections)
    for crm in section_pos:
        beg_chr, end_chr = section_pos[crm][0], section_pos[crm][1]
        chr_size = end_chr - beg_chr
        thesebads = [b for b in badcol if beg_chr <= b <= end_chr]
        for dist in xrange(1, chr_size):
            ndiags[crm][dist] += chr_size - dist
            # from this we remove bad columns
            # bad columns will only affect if they are at least as distant from
            # a border as the distance between the longest diagonal and the
            # current diagonal.
            bad_diag = set()  # 2 bad rows can point to the same bad cell in diagonal
            maxp = end_chr - dist
            minp = beg_chr + dist
            for b in thesebads:
                if b < maxp:  # not inclusive!!
                    bad_diag.add(b)
                if b >= minp:
                    bad_diag.add(b - dist)
            ndiags[crm][dist] -= len(bad_diag)
        # different behavior for longest diagonal:
        ndiags[crm][0] += chr_size - sum(beg_chr <= b < end_chr for b in thesebads)

    # normalize sum per diagonal by total number of cells in diagonal
    signal_to_noise = 0.05
    min_n = signal_to_noise ** -2. # equals 400 when default
    for crm in sections:
        if not crm in nrmdec:
            nrmdec[crm] = {}
            rawdec[crm] = {}
        tmpdec = 0  # store normalized count by diagonal
        tmpsum = 0  # store raw count by diagonal
        ndiag  = 0
        val    = 0
        previous = [] # store diagonals to be summed in case not reaching the minimum
        for k in ndiags[crm]:
            tmpdec += nrmdec[crm].get(k, 0.)
            tmpsum += rawdec[crm].get(k, 0.)
            previous.append(k)
            if tmpsum > min_n:
                ndiag = sum(ndiags[crm][k] for k in previous)
                val = tmpdec  # backup of tmpdec kept for last ones outside the loop
                try:
                    ratio = val / ndiag
                    for l in previous:
                        nrmdec[crm][l] = ratio
                except ZeroDivisionError:  # all columns at this distance are "bad"
                    pass
                previous = []
                tmpdec = 0
                tmpsum = 0
        # last ones we average with previous result
        if  len(previous) == len(ndiags[crm]):
            # the minimum was never reached: no decay for this chromosome
            nrmdec[crm] = {}
        elif tmpsum < min_n:
            ndiag += sum(ndiags[crm][k] for k in previous)
            val += tmpdec
            try:
                ratio = val / ndiag
                for k in previous:
                    nrmdec[crm][k] = ratio
            except ZeroDivisionError:  # all columns at this distance are "bad"
                pass
    return biases, nrmdec, badcol, raw_cisprc, norm_cisprc
Exemple #4
0
def run(opts):
    """
    Search for compartments and/or TADs in a Hi-C BAM file and save results.

    Input paths (BAM file and biases file) are taken straight from ``opts``
    when ``opts.nosql`` is set or when ``opts.mreads``/``opts.biases`` are
    given; otherwise they are retrieved with ``load_parameters_fromdb``.
    Output is written under ``<opts.workdir>/06_segmentation`` (one
    ``compartments_<reso>`` and/or one ``tads_<reso>`` sub-directory) and,
    unless ``opts.nosql`` is set, a summary is passed to ``save_to_db``.

    :param opts: parsed command line options. Attributes read here: nosql,
       biases, mreads, workdir, reso, crms, cpus, all_bins, filter,
       only_tads, only_compartments, fasta, rich_in_A, n_evs, ev_index,
       savecorr, fix_corr_scale, format, max_tad_size, verbose.

    :raises Exception: if only one of ``opts.mreads`` / ``opts.biases`` is
       provided (and ``opts.nosql`` is not set).
    """
    check_options(opts)
    launch_time = time.localtime()
    param_hash = digest_parameters(opts, get_md5=True)

    if opts.nosql:
        biases = opts.biases
        mreads = opts.mreads
        inputs = []
    elif opts.biases or opts.mreads:
        if not opts.mreads:
            raise Exception('ERROR: also need to provide BAM file')
        if not opts.biases:
            raise Exception('ERROR: also need to provide biases file')
        biases = opts.biases
        mreads = opts.mreads
        inputs = ['NA', 'NA']
        mkdir(path.join(opts.workdir))
    else:
        biases, mreads, biases_id, mreads_id = load_parameters_fromdb(opts)
        # store path ids to be saved in database
        inputs = [biases_id, mreads_id]
        mreads = path.join(opts.workdir, mreads)
        biases = path.join(opts.workdir, biases)

    reso = opts.reso

    mkdir(path.join(opts.workdir, '06_segmentation'))

    print('loading %s \n    at resolution %s' % (mreads, nice(reso)))
    # restrict loading to a single region only when exactly one was requested
    region = None
    if opts.crms and len(opts.crms) == 1:
        region = opts.crms[0]
    hic_data = load_hic_data_from_bam(mreads, reso, ncpus=opts.cpus,
                                      region=region,
                                      biases=None if opts.all_bins else biases,
                                      filter_exclude=opts.filter)

    # compartments
    cmp_result = {}
    richA_stats = {}
    firsts = {}
    if not opts.only_tads:
        print('Searching compartments')
        cmprt_dir = path.join(opts.workdir, '06_segmentation',
                              'compartments_%s' % (nice(reso)))
        mkdir(cmprt_dir)
        if opts.fasta:
            print('  - Computing GC content to label compartments')
            rich_in_A = get_gc_content(parse_fasta(opts.fasta, chr_filter=opts.crms), reso,
                                       chromosomes=opts.crms,
                                       by_chrom=True, n_cpus=opts.cpus)
        elif opts.rich_in_A:
            rich_in_A = opts.rich_in_A
        else:
            rich_in_A = None
        # fall back to 3 eigenvectors when a non-positive number is requested
        n_evs = opts.n_evs if opts.n_evs > 0 else 3
        firsts, richA_stats = hic_data.find_compartments(
            crms=opts.crms, savefig=cmprt_dir, verbose=True, suffix=param_hash,
            rich_in_A=rich_in_A, show_compartment_labels=rich_in_A is not None,
            savecorr=cmprt_dir if opts.savecorr else None,
            max_ev=n_evs,
            ev_index=opts.ev_index,
            vmin=None if opts.fix_corr_scale else 'auto',
            vmax=None if opts.fix_corr_scale else 'auto')

        # one eigenvector table per chromosome for which a result exists
        for ncrm, crm in enumerate(opts.crms or hic_data.chromosomes):
            if crm not in firsts:
                continue
            ev_path = path.join(
                cmprt_dir, '%s_EigVect%d_%s.tsv' % (
                    crm, opts.ev_index[ncrm] if opts.ev_index else 1,
                    param_hash))
            # context manager: the handle is closed even if formatting fails
            with open(ev_path, 'w') as ev_file:
                ev_file.write('# %s\n' % ('\t'.join(
                    'EV_%d (%.4f)' % (i, v)
                    for i, v in enumerate(firsts[crm][0], 1))))
                ev_file.write('\n'.join(['\t'.join([str(v) for v in vs])
                                         for vs in zip(*firsts[crm][1])]))

        for ncrm, crm in enumerate(opts.crms or hic_data.chromosomes):
            ev_num = opts.ev_index[ncrm] if opts.ev_index else 1
            cmprt_file1 = path.join(cmprt_dir, '%s_%s.tsv' % (crm, param_hash))
            cmprt_file2 = path.join(cmprt_dir, '%s_EigVect%d_%s.tsv' % (
                crm, ev_num, param_hash))
            cmprt_image = path.join(cmprt_dir, '%s_EV%d_%s.%s' % (
                crm, ev_num, param_hash, opts.format))
            if opts.savecorr:
                cormat_file = path.join(cmprt_dir, '%s_corr-matrix%s.tsv' %
                                        (crm, param_hash))
            else:
                cormat_file = None
            hic_data.write_compartments(cmprt_file1, chroms=[crm])
            cmp_result[crm] = {'path_cmprt1': cmprt_file1,
                               'path_cmprt2': cmprt_file2,
                               'path_cormat': cormat_file,
                               'image_cmprt': cmprt_image,
                               'num' : len(hic_data.compartments[crm])}

    # TADs
    tad_result = {}
    if not opts.only_compartments:
        print('Searching TADs')
        tad_dir = path.join(opts.workdir, '06_segmentation',
                            'tads_%s' % (nice(reso)))
        mkdir(tad_dir)
        # Content of the biases file, loaded at most once.  The previous code
        # re-assigned `biases` to the loaded object inside the loop, so the
        # second chromosome tried to open that object as a path and crashed.
        norm_biases = None
        # With opts.all_bins, hic_data.bads is overwritten below after the
        # first chromosome; keep the bad columns as loaded so that every
        # chromosome is segmented with the same `remove` criterion.
        # NOTE(review): assumes TAD calling should use the bads present at
        # loading time for all chromosomes -- confirm.
        raw_bads = hic_data.bads
        for crm in hic_data.chromosomes:
            if opts.crms and crm not in opts.crms:
                continue
            print('  - %s' % crm)
            matrix = hic_data.get_matrix(focus=crm)
            beg, end = hic_data.section_pos[crm]
            size = len(matrix)
            if size < 10:
                print("     Chromosome too short (%d bins), skipping..." % size)
                continue
            # transform bad column in chromosome referential
            bads = raw_bads if opts.all_bins else hic_data.bads
            if bads:
                to_rm = tuple(1 if i in bads else 0 for i in range(beg, end))
            else:
                to_rm = None
            # maximum size of a TAD
            max_tad_size = (size - 1) if opts.max_tad_size is None else opts.max_tad_size
            result = tadbit([matrix], remove=to_rm,
                            n_cpus=opts.cpus, verbose=opts.verbose,
                            max_tad_size=max_tad_size,
                            no_heuristic=False)

            # use normalization to compute height on TADs called
            if opts.all_bins:
                if norm_biases is None:
                    if opts.nosql:
                        biases_path = biases
                    else:
                        biases_path = path.join(opts.workdir, biases)
                    with open(biases_path) as handle:
                        norm_biases = load(handle)
                hic_data.bads = norm_biases['badcol']
                hic_data.bias = norm_biases['biases']
            tads = load_tad_height(result, size, beg, end, hic_data)
            # header + one row per TAD; positions are written 1-based
            rows = ['%s\t%s\t%s\t%s\t%s\n' % ('#', 'start', 'end', 'score', 'density')]
            for tad in tads:
                rows.append('%s\t%s\t%s\t%s\t%s\n' % (
                    tad, int(tads[tad]['start'] + 1), int(tads[tad]['end'] + 1),
                    abs(tads[tad]['score']),
                    round(float(tads[tad]['height']), 3)))
            out_tad = path.join(tad_dir, '%s_%s.tsv' % (crm, param_hash))
            with open(out_tad, 'w') as out:
                out.write(''.join(rows))
            tad_result[crm] = {'path' : out_tad,
                               'num': len(tads)}

    finish_time = time.localtime()

    if not opts.nosql:
        try:
            save_to_db(opts, cmp_result, tad_result, reso, inputs,
                       richA_stats, firsts, param_hash,
                       launch_time, finish_time)
        except:
            # release lock anyway (bare except kept on purpose so that this
            # clean-up also runs on KeyboardInterrupt / SystemExit)
            print_exc()
            try:
                remove(path.join(opts.workdir, '__lock_db'))
            except OSError:
                pass
            exit(1)
Exemple #5
0
def run(opts):
    """
    Merge two Hi-C experiments (BAM files) and, optionally, compare them.

    Each input BAM/biases pair comes either from ``opts.bam1``/``opts.biases1``
    (resp. ``bam2``/``biases2``) or from ``load_parameters_fromdb``.  Unless
    ``opts.skip_comparison`` is set, both experiments are loaded and compared
    (correlation between equidistant loci, correlation between eigenvectors,
    reproducibility score).  The two BAM files are then merged and indexed
    with samtools into ``<opts.workdir>/03_filtered_reads`` and the outcome is
    passed to ``save_to_db``.

    :param opts: parsed command line options. Attributes read here: samtools,
       bam1, bam2, biases1, biases2, workdir, workdir1, workdir2, jobid1,
       jobid2, tmpdb1, tmpdb2, filter, skip_comparison, reso, cpus, norm.

    :raises Exception: if the resolutions associated to the two experiments
       differ.
    """
    check_options(opts)
    samtools = which(opts.samtools)
    launch_time = time.localtime()

    param_hash = digest_parameters(opts)

    reso1 = reso2 = None
    if opts.bam1:
        mreads1 = path.realpath(opts.bam1)
        biases1 = opts.biases1
    else:
        biases1, mreads1, reso1 = load_parameters_fromdb(
            opts.workdir1, opts.jobid1, opts, opts.tmpdb1)
        mreads1 = path.join(opts.workdir1, mreads1)
        try:
            biases1 = path.join(opts.workdir1, biases1)
        except (AttributeError, TypeError):
            # path.join fails on a non-string (e.g. None) biases entry:
            # AttributeError under Python 2, TypeError under Python 3
            biases1 = None

    if opts.bam2:
        mreads2 = path.realpath(opts.bam2)
        biases2 = opts.biases2
    else:
        biases2, mreads2, reso2 = load_parameters_fromdb(
            opts.workdir2, opts.jobid2, opts, opts.tmpdb2)
        mreads2 = path.join(opts.workdir2, mreads2)
        try:
            biases2 = path.join(opts.workdir2, biases2)
        except (AttributeError, TypeError):
            biases2 = None

    filter_exclude = opts.filter

    if reso1 != reso2:
        raise Exception('ERROR: differing resolutions between experiments to '
                        'be merged')

    merge_dir = path.join(opts.workdir, '00_merge')
    mkdir(merge_dir)

    if not opts.skip_comparison:
        printime('  - loading first sample %s' % (mreads1))
        hic_data1 = load_hic_data_from_bam(mreads1, opts.reso, biases=biases1,
                                           tmpdir=merge_dir,
                                           ncpus=opts.cpus,
                                           filter_exclude=filter_exclude)

        printime('  - loading second sample %s' % (mreads2))
        hic_data2 = load_hic_data_from_bam(mreads2, opts.reso, biases=biases2,
                                           tmpdir=merge_dir,
                                           ncpus=opts.cpus,
                                           filter_exclude=filter_exclude)
        decay_corr_dat = path.join(merge_dir, 'decay_corr_dat_%s_%s.txt' % (opts.reso, param_hash))
        decay_corr_fig = path.join(merge_dir, 'decay_corr_dat_%s_%s.png' % (opts.reso, param_hash))
        eigen_corr_dat = path.join(merge_dir, 'eigen_corr_dat_%s_%s.txt' % (opts.reso, param_hash))
        eigen_corr_fig = path.join(merge_dir, 'eigen_corr_dat_%s_%s.png' % (opts.reso, param_hash))

        printime('  - comparing experiments')
        printime('    => correlation between equidistant loci')
        corr, _, scc, std, bads = correlate_matrices(
            hic_data1, hic_data2, normalized=opts.norm,
            remove_bad_columns=True, savefig=decay_corr_fig,
            savedata=decay_corr_dat, get_bads=True)
        print('         - correlation score (SCC): %.4f (+- %.7f)' % (scc, std))
        printime('    => correlation between eigenvectors')
        eig_corr = eig_correlate_matrices(hic_data1, hic_data2, normalized=opts.norm,
                                          remove_bad_columns=True, nvect=6,
                                          savefig=eigen_corr_fig,
                                          savedata=eigen_corr_dat)

        printime('    => reproducibility score')
        reprod = get_reproducibility(hic_data1, hic_data2, num_evec=20, normalized=opts.norm,
                                     verbose=False, remove_bad_columns=True)
        print('         - reproducibility score: %.4f' % (reprod))
        ncols = len(hic_data1)
    else:
        ncols = 0
        decay_corr_dat = 'None'
        decay_corr_fig = 'None'
        eigen_corr_dat = 'None'
        eigen_corr_fig = 'None'

        # placeholders for every statistic passed to save_to_db below;
        # scc, std and reprod used to be left undefined here, which raised a
        # NameError whenever the comparison was skipped
        corr = eig_corr = scc = std = reprod = 0
        bads = {}

    # merge inputs
    mkdir(path.join(opts.workdir, '03_filtered_reads'))
    outbam = path.join(opts.workdir, '03_filtered_reads',
                       'intersection_%s.bam' % (param_hash))

    printime('  - Mergeing experiments')
    system(samtools  + ' merge -@ %d %s %s %s' % (opts.cpus, outbam, mreads1, mreads2))
    printime('  - Indexing new BAM file')
    # check samtools version number and modify command line
    # (universal_newlines=True so that stderr is text on Python 2 and 3)
    samtools_help = Popen(samtools, stderr=PIPE,
                          universal_newlines=True).communicate()[1]
    version = LooseVersion([l.split()[1]
                            for l in samtools_help.split('\n')
                            if 'Version' in l][0])
    # the threads option is only passed to "index" from version 1.3.1 on
    if version >= LooseVersion('1.3.1'):
        system(samtools  + ' index -@ %d %s' % (opts.cpus, outbam))
    else:
        system(samtools  + ' index %s' % (outbam))

    finish_time = time.localtime()
    save_to_db (opts, mreads1, mreads2, decay_corr_dat, decay_corr_fig,
                len(bads), ncols, scc, std, reprod,
                eigen_corr_dat, eigen_corr_fig, outbam, corr, eig_corr,
                biases1, biases2, launch_time, finish_time)
    printime('\nDone.')