Ejemplo n.º 1
0
def get_mapped_chunk(map_folder, nreads):
    """
    Iterate over the mapped reads stored in a folder, by chunks.

    Every file listed in *map_folder* is opened with ``magic_open`` and read
    line by line. Each line is expected to hold five whitespace-separated
    columns: read id, sequence, quality, an ignored column and a
    ':'-separated location whose third field is an integer position.

    :param map_folder: path to the directory containing the map files
    :param nreads: number of lines to accumulate before yielding a chunk
       (the counter runs across file boundaries)

    :yields: dictionaries mapping ``(read_id, position)`` to
       ``(sequence, quality)``, where ``read_id`` is the part of the first
       column before the first '~'. The final dictionary holds whatever is
       left over and may be empty.
    """
    chunk = {}
    printime(' - loading chunk')
    n_loaded = 0
    for fname in os.listdir(map_folder):
        printime('    - ' + fname)
        for line in magic_open(os.path.join(map_folder, fname)):
            n_loaded += 1
            read_id, seq, qal, _, location = line.split()
            key = (read_id.split('~')[0], int(location.split(':')[2]))
            chunk[key] = (seq, qal)
            if n_loaded < nreads:
                continue
            # chunk is full: hand it over and start a fresh one
            yield chunk
            printime(' - loading chunk')
            chunk = {}
            n_loaded = 0
    # remaining reads (fewer than nreads)
    yield chunk
Ejemplo n.º 2
0
def read_bam(inbam, filter_exclude, resolution, min_count=2500, biases_path='',
             normalization='Vanilla', mappability=None, n_rsites=None,
             cg_content=None, sigma=2, ncpus=8, factor=1, outdir='.', seed=1,
             extra_out='', only_valid=False, normalize_only=False, p_fit=None,
             max_njobs=100, extra_bads=None,
             cis_limit=1, trans_limit=5, min_ratio=1.0, fast_filter=False):
    """
    Compute normalization biases, filtered bins and interaction decay from a
    BAM file.

    The genome is binned at *resolution*, split into chunks that are parsed in
    parallel by worker functions writing temporary pickle files in *outdir*,
    and the per-bin statistics gathered from these files are used to filter
    bins, compute biases and compute the decay per diagonal.

    :param inbam: path to the BAM file
    :param filter_exclude: filter definition forwarded to the workers parsing
       the BAM file
    :param resolution: bin size, in nucleotides
    :param 2500 min_count: minimum count for a bin to be kept. In the default
       (not *fast_filter*) mode ``None`` means the 95th percentile of the
       per-bin value stored at index 2 of the worker statistics, computed
       over bins whose ratio is below *min_ratio*
    :param '' biases_path: file with one bias per line (after a header line),
       only read when normalization is 'custom'
    :param 'Vanilla' normalization: one of 'Vanilla', 'SQRT', 'ICE', 'oneD' or
       'custom'
    :param None mappability: per-bin mappability; bins with a false value are
       filtered. Also required for 'oneD'
    :param None n_rsites: per-bin number of restriction sites ('oneD' only)
    :param None cg_content: per-bin GC content ('oneD' only)
    :param 2 sigma: not used by the current implementation
    :param 8 ncpus: size of the multiprocessing pools
    :param 1 factor: divisor applied when computing the rescaling target of
       the biases
    :param '.' outdir: directory where temporary files and the filtering plot
       are written
    :param 1 seed: forwarded to ``oneD``
    :param '' extra_out: suffix used in the name of temporary/output files
    :param False only_valid: use ``read_bam_frag_valid`` instead of
       ``read_bam_frag_filter`` to parse the BAM file
    :param False normalize_only: skip the computation of the cis-percentage
    :param None p_fit: forwarded to ``oneD``
    :param 100 max_njobs: upper bound used to size the chunks
    :param None extra_bads: list of 'chromosome:begin-end' strings, bins in
       these ranges are filtered manually
    :param 1 cis_limit: in bins; ``None`` means 1 Mb / resolution
    :param 5 trans_limit: in bins; ``None`` means 5 times *cis_limit*
    :param 1.0 min_ratio: bins with a ratio (index 2 over index 3 of the
       worker statistics) below this value are filtered (not *fast_filter*)
    :param False fast_filter: filter only on the value at index 1 of the
       worker statistics being lower than *min_count*

    :returns: a tuple (biases, decay, bad columns, raw cis-percentage,
       normalized cis-percentage). Biases and bad columns are dictionaries
       keyed by genomic bin, the decay is a dictionary of dictionaries
       (chromosome, then diagonal)

    :raises NotImplementedError: if *normalization* is not supported
    """
    # function-scope import: needed to delete temporary files without a shell
    from os import remove as remove_file

    bamfile = AlignmentFile(inbam, 'rb')
    sections = OrderedDict(zip(bamfile.references,
                               [x // resolution + 1 for x in bamfile.lengths]))
    # only the header was needed, release the file handle
    bamfile.close()
    total = 0
    section_pos = dict()
    for crm in sections:
        section_pos[crm] = (total, total + sections[crm])
        total += sections[crm]
    bins = []
    for crm in sections:
        len_crm = sections[crm]
        bins.extend([(crm, i) for i in range(len_crm)])

    start_bin = 0
    end_bin   = len(bins)
    total     = len(bins)

    # split the genome in chunks (region, begin, end) to be parsed in parallel
    regs = []
    begs = []
    ends = []
    njobs = min(total, max_njobs) + 1
    nbins = total // njobs + 1
    for i in range(start_bin, end_bin, nbins):
        if i + nbins > end_bin:  # make sure that we stop
            nbins = end_bin - i
        try:
            (crm1, beg1), (crm2, end2) = bins[i], bins[i + nbins - 1]
        except IndexError:
            try:
                (crm1, beg1), (crm2, end2) = bins[i], bins[-1]
            except IndexError:
                break
        if crm1 != crm2:
            # chunk spanning two chromosomes is split in two
            end1 = sections[crm1]
            beg2 = 0
            regs.append(crm1)
            regs.append(crm2)
            begs.append(beg1 * resolution)
            begs.append(beg2 * resolution)
            ends.append(end1 * resolution + resolution)  # last nt included
            ends.append(end2 * resolution + resolution - 1)  # last nt not included (overlap with next window)
        else:
            regs.append(crm1)
            begs.append(beg1 * resolution)
            ends.append(end2 * resolution + resolution - 1)
    ends[-1] += 1  # last nucleotide included
    chunks = list(zip(regs, begs, ends))

    printime('  - Parsing BAM (%d chunks)' % (len(regs)))
    # define limits for cis and trans interactions if not given
    if cis_limit is None:
        cis_limit = int(1_000_000 / resolution)
    print('      -> cis interactions are defined as being bellow {}'.format(
        nicer(cis_limit * resolution)))
    if trans_limit is None:
        trans_limit = cis_limit * 5
    print('      -> trans interactions are defined as being bellow {}'.format(
        nicer(trans_limit * resolution)))

    bins_dict = dict([(j, i) for i, j in enumerate(bins)])
    pool = mu.Pool(ncpus)
    procs = []
    read_bam_frag = read_bam_frag_valid if only_valid else read_bam_frag_filter
    for region, start, end in chunks:
        procs.append(pool.apply_async(
            read_bam_frag, args=(inbam, filter_exclude, bins, bins_dict,
                                 resolution, outdir, extra_out,
                                 region, start, end, cis_limit, trans_limit)))
    pool.close()
    print_progress(procs)
    pool.join()
    ## COLLECT RESULTS
    cisprc = {}
    printime('  - Collecting cis and total interactions per bin (%d chunks)' % (len(regs)))
    stdout.write('     ')
    for countbin, (region, start, end) in enumerate(chunks):
        if not countbin % 10 and countbin:
            stdout.write(' ')
        if not countbin % 50 and countbin:
            stdout.write(' %9s\n     ' % ('%s/%s' % (countbin , len(regs))))
        stdout.write('.')
        stdout.flush()

        fname = path.join(outdir,
                          'tmp_bins_%s:%d-%d_%s.pickle' % (region, start, end, extra_out))
        # NOTE: pickle is only safe here because the file was written by our
        # own workers
        with open(fname, 'rb') as handle:
            tmp_cisprc = load(handle)
        # no shell involved: chromosome names with special characters (e.g.
        # '|') would break a 'rm -f' command line
        try:
            remove_file(fname)
        except OSError:
            pass
        cisprc.update(tmp_cisprc)
    stdout.write('\n')

    # get cis/trans ratio
    for k in cisprc:
        try:
            cisprc[k][3] = cisprc[k][2] / cisprc[k][3]
        except ZeroDivisionError:
            cisprc[k][3] = 0

    # BIN FILTERINGS
    printime('  - Removing columns with too few or too much interactions')

    # statistics used for bins absent from the worker results
    empty_bin = (0, 0, 0, 0)

    # define filter for minimum interactions per bin
    if not fast_filter:
        if min_count is None:
            # hardcoded parameter (95th percentile); bins with no
            # interactions in cis are left out of the computation
            min_count = nanpercentile(
                [cisprc[k][2] for k in range(total)
                 if cisprc.get(k, empty_bin)[3] < min_ratio
                 and cisprc.get(k, empty_bin)[2] >= 1], 95)

        print('      -> too few interactions defined as less than %9d '
               'interactions' % (min_count))
        # .get on both sides: a bin without data must never raise KeyError
        # (it used to when min_ratio was not positive)
        badcol = dict((k, True) for k in range(total)
                      if cisprc.get(k, empty_bin)[3] < min_ratio
                      or cisprc.get(k, empty_bin)[2] < min_count)
        print('      -> removed %d columns of %d (%.1f%%)' % (
            len(badcol), total, float(len(badcol)) / total * 100))
    else:
        print('      -> too few interactions defined as less than %9d '
               'interactions' % (min_count))
        badcol = {}
        countL = 0
        countZ = 0
        for c in range(total):
            if cisprc.get(c, empty_bin)[1] < min_count:
                badcol[c] = cisprc.get(c, empty_bin)[1]
                countL += 1
                if not c in cisprc:
                    countZ += 1
        print('      -> removed %d columns (%d/%d null/high counts) of %d (%.1f%%)' % (
            len(badcol), countZ, countL, total, float(len(badcol)) / total * 100))

    # Plot
    plot_filtering(dict((k, cisprc[k][2]) for k in cisprc),
                   dict((k, cisprc[k][3]) for k in cisprc), total, min_count, min_ratio,
                   path.join(outdir, 'filtering_summary_plot_{}_{}.png'.format(nicer(resolution, sep=''),
                    extra_out)),
                   base_position=0, next_position=cis_limit, last_position=trans_limit, resolution=resolution,
                   legend='Filtered {} of {} bins'.format(len(badcol), total))

    # no mappability will result in NaNs, better to filter out these columns
    if mappability:
        badcol.update((i, True) for i, m in enumerate(mappability) if not m)

    # add manually columns to bad columns
    if extra_bads:
        removed_manually = 0
        for ebc in extra_bads:
            c, ebc = ebc.split(':')
            b, e = list(map(int, ebc.split('-')))
            b = b // resolution + section_pos[c][0]
            e = e // resolution + section_pos[c][0]
            removed_manually += (e - b)
            badcol.update(dict((p, 'manual') for p in range(b, e)))
        printime('  - Removed %d columns manually.' % removed_manually)

    # mean of the per-bin ratio over the bins that are kept. The divisor is
    # the number of ratios actually summed: badcol also holds bins absent from
    # cisprc, so len(cisprc) - len(badcol) is not the number of kept bins.
    kept_ratios = [float(cisprc[k][0]) / cisprc[k][1]
                   for k in cisprc if not k in badcol]
    if kept_ratios:
        raw_cisprc = sum(kept_ratios) / len(kept_ratios)
    else:
        raw_cisprc = float('nan')

    printime('  - Rescaling sum of interactions per bins')
    size = len(bins)
    biases = [float('nan') if k in badcol else cisprc.get(k, [0, 1., 0, 0])[1]
              for k in range(size)]

    if normalization == 'ICE':
        printime('  - ICE normalization')
        hic_data = load_hic_data_from_bam(
            inbam, resolution, filter_exclude=filter_exclude,
            tmpdir=outdir, ncpus=ncpus, nchunks=max_njobs)
        hic_data.bads = badcol
        hic_data.normalize_hic(iterations=100, max_dev=0.000001)
        biases = hic_data.bias.copy()
        del(hic_data)
    elif normalization == 'Vanilla':
        printime('  - Vanilla normalization')
        mean_col = nanmean(biases)
        biases   = dict((k, b / mean_col * mean_col**0.5)
                        for k, b in enumerate(biases))
    elif normalization == 'SQRT':
        printime('  - Vanilla-SQRT normalization')
        biases = [b**0.5 for b in biases]
        mean_col = nanmean(biases)
        biases   = dict((k, b / mean_col * mean_col**0.5)
                        for k, b in enumerate(biases))
    elif normalization == 'oneD':
        printime('  - oneD normalization')
        if len(set([len(biases), len(mappability), len(n_rsites), len(cg_content)])) > 1:
            print("biases", "mappability", "n_rsites", "cg_content")
            print(len(biases), len(mappability), len(n_rsites), len(cg_content))
            raise Exception('Error: not all arrays have the same size')
        tmp_oneD = path.join(outdir,'tmp_oneD_%s' % (extra_out))
        mkdir(tmp_oneD)
        biases = oneD(tmp_dir=tmp_oneD, p_fit=p_fit, tot=biases, map=mappability,
                      res=n_rsites, cg=cg_content, seed=seed)
        biases = dict((k, b) for k, b in enumerate(biases))
        rmtree(tmp_oneD)
    elif normalization == 'custom':
        n_pos = 0
        biases = {}
        print('Using provided biases...')
        with open(biases_path, 'r') as r:
            next(r)  # skip header line
            for line in r:
                if line[0] == 'N':
                    # lines starting with 'N' mark bins without bias
                    badcol[n_pos] = 0
                    biases[n_pos] = float('nan')
                else:
                    b = float(line)
                    if b == 0:
                        badcol[n_pos] = 0
                        biases[n_pos] = float('nan')
                    else:
                        biases[n_pos] = b
                n_pos += 1
        # pad the bins not described in the file. Start right after the last
        # parsed bin: starting at max(biases) overwrote the last bias with NaN.
        # The upper bound is kept as it was.
        for add in range(n_pos, total + 1):
            biases[add] = float('nan')
    else:
        raise NotImplementedError('ERROR: method %s not implemented' %
                                  normalization)

    # temporary files written by the workers, one per chunk
    tmp_files = [path.join(outdir,
                           'tmp_%s:%d-%d_%s.pickle' % (region, start, end, extra_out))
                 for region, start, end in chunks]

    printime('  - Getting sum of normalized bins')
    pool = mu.Pool(ncpus)
    procs = [pool.apply_async(sum_nrm_matrix, args=(fname, biases,))
             for fname in tmp_files]
    pool.close()
    print_progress(procs)
    pool.join()

    # to correct biases
    sumnrm = sum(p.get() for p in procs)

    target = (sumnrm / float(size * size * factor))**0.5
    biases = dict([(b, biases[b] * target) for b in biases])

    if not normalize_only:
        printime('  - Computing Cis percentage')
        # Calculate Cis percentage

        pool = mu.Pool(ncpus)
        procs = [pool.apply_async(get_cis_perc,
                                  args=(fname, biases, badcol, bins))
                 for fname in tmp_files]
        pool.close()
        print_progress(procs)
        pool.join()

        # collect results (dedicated names: 'total' is the number of bins)
        cis_sum = all_sum = 0
        for proc in procs:
            c, t = proc.get()
            cis_sum += c
            all_sum += t
        norm_cisprc = float(cis_sum) / all_sum
        print('    * Cis-percentage: %.1f%%' % (norm_cisprc * 100))
    else:
        norm_cisprc = 0.

    printime('  - Rescaling decay')
    # normalize decay by size of the diagonal, and by Vanilla correction
    # (all cells must still be equals to 1 in average)

    pool = mu.Pool(ncpus)
    procs = [pool.apply_async(sum_dec_matrix,
                              args=(fname, biases, badcol, bins))
             for fname in tmp_files]
    pool.close()
    print_progress(procs)
    pool.join()

    # collect results: sum per chromosome and per diagonal over all chunks
    nrmdec = {}
    rawdec = {}
    for proc in procs:
        tmpnrm, tmpraw = proc.get()
        for c, d in tmpnrm.items():
            for k, v in d.items():
                try:
                    nrmdec[c][k] += v
                    rawdec[c][k] += tmpraw[c][k]
                except KeyError:
                    try:
                        nrmdec[c][k]  = v
                        rawdec[c][k] = tmpraw[c][k]
                    except KeyError:
                        nrmdec[c] = {k: v}
                        rawdec[c] = {k: tmpraw[c][k]}
    # count the number of cells per diagonal
    # TODO: parallelize
    len_crms = dict((c, section_pos[c][1] - section_pos[c][0]) for c in section_pos)
    # initialize dictionary
    ndiags = dict((c, dict((k, 0) for k in range(len_crms[c]))) for c in sections)
    for crm in section_pos:
        beg_chr, end_chr = section_pos[crm][0], section_pos[crm][1]
        chr_size = end_chr - beg_chr
        # end_chr is excluded: it is the first bin of the next chromosome
        thesebads = [b for b in badcol if beg_chr <= b < end_chr]
        for dist in range(1, chr_size):
            ndiags[crm][dist] += chr_size - dist
            # from this we remove bad columns
            # bad columns will only affect if they are at least as distant from
            # a border as the distance between the longest diagonal and the
            # current diagonal.
            bad_diag = set()  # 2 bad rows can point to the same bad cell in diagonal
            maxp = end_chr - dist
            minp = beg_chr + dist
            for b in thesebads:
                if b < maxp:  # not inclusive!!
                    bad_diag.add(b)
                if b >= minp:
                    bad_diag.add(b - dist)
            ndiags[crm][dist] -= len(bad_diag)
        # different behavior for longest diagonal:
        ndiags[crm][0] += chr_size - len(thesebads)

    # normalize sum per diagonal by total number of cells in diagonal
    signal_to_noise = 0.05
    min_n = signal_to_noise ** -2. # equals 400 when default
    for crm in sections:
        if not crm in nrmdec:
            nrmdec[crm] = {}
            rawdec[crm] = {}
        tmpdec = 0  # store normalized count by diagonal
        tmpsum = 0  # store raw count by diagonal
        ndiag  = 0
        val    = 0
        previous = [] # store diagonals to be summed in case not reaching the minimum
        for k in ndiags[crm]:
            tmpdec += nrmdec[crm].get(k, 0.)
            tmpsum += rawdec[crm].get(k, 0.)
            previous.append(k)
            if tmpsum > min_n:
                ndiag = sum(ndiags[crm][p] for p in previous)
                val = tmpdec  # backup of tmpdec kept for last ones outside the loop
                try:
                    ratio = val / ndiag
                    for l in previous:
                        nrmdec[crm][l] = ratio
                except ZeroDivisionError:  # all columns at this distance are "bad"
                    pass
                previous = []
                tmpdec = 0
                tmpsum = 0
        # last ones we average with previous result
        if  len(previous) == len(ndiags[crm]):
            # the minimum was never reached: no decay for this chromosome
            nrmdec[crm] = {}
        elif tmpsum < min_n:
            ndiag += sum(ndiags[crm][p] for p in previous)
            val += tmpdec
            try:
                ratio = val / ndiag
                for k in previous:
                    nrmdec[crm][k] = ratio
            except ZeroDivisionError:  # all columns at this distance are "bad"
                pass
    return biases, nrmdec, badcol, raw_cisprc, norm_cisprc
Ejemplo n.º 3
0
def run(opts):
    """
    Normalize a Hi-C experiment and store biases, decay and filtered bins.

    The BAM file is either given by ``opts.bam`` or retrieved from the
    working-directory database. For the 'oneD' normalization, mappability,
    GC content and number of restriction sites per bin are computed first
    (``opts.fasta``, ``opts.renz`` and ``opts.mappability`` are then
    mandatory). Results are pickled in the '04_normalization' folder of
    ``opts.workdir`` and registered in the database.

    :param opts: parsed command line options

    :raises Exception: if an input needed by 'oneD' is missing, or if the
       chromosomes of the BAM file are not all present in the FASTA file
    """
    check_options(opts)
    launch_time = time.localtime()

    param_hash = digest_parameters(opts)
    if opts.bam:
        mreads = path.realpath(opts.bam)
    else:
        mreads = path.join(opts.workdir, load_parameters_fromdb(opts))

    filter_exclude = opts.filter

    outdir = path.join(opts.workdir, '04_normalization')
    mkdir(outdir)

    mappability = gc_content = n_rsites = None
    if opts.normalization == 'oneD':
        if not opts.fasta:
            raise Exception('ERROR: missing path to FASTA for oneD normalization')
        if not opts.renz:
            raise Exception('ERROR: missing restriction enzyme name for oneD normalization')
        if not opts.mappability:
            raise Exception('ERROR: missing path to mappability for oneD normalization')
        bamfile = AlignmentFile(mreads, 'rb')
        refs = bamfile.references
        bamfile.close()

        # get genome sequence ~1 min
        printime('  - parsing FASTA')
        genome = parse_fasta(opts.fasta, verbose=False)

        fas = set(genome.keys())
        bam = set(refs)
        only_fasta = fas - bam
        only_bam = bam - fas
        if only_fasta:
            print('WARNING: %d extra chromosomes in FASTA (removing them)' % (len(only_fasta)))
            if len(only_fasta) <= 50:
                print('\n'.join([('  - ' + c) for c in only_fasta]))
        if only_bam:
            txt = ('\n'.join([('  - ' + c) for c in only_bam])
                   if len(only_bam) <= 50 else '')
            raise Exception('ERROR: %d extra chromosomes in BAM (remove them):\n%s\n' % (
                len(only_bam), txt))
        refs = [crm for crm in refs if crm in genome]
        if len(refs) == 0:
            raise Exception("ERROR: chromosomes in FASTA different the ones"
                            " in BAM")

        # get mappability ~2 min
        printime('  - Parsing mappability')
        mappability = parse_mappability_bedGraph(
            opts.mappability, opts.reso,
            wanted_chrom=refs[0] if len(refs)==1 else None)
        # resize chromosomes: each one must span its own number of bins. The
        # length of the chromosome is used here, not the number of
        # chromosomes (len(refs)), which yielded a single bin.
        for c in refs:
            n_bins = len(genome[c]) // opts.reso + 1
            if not c in mappability:
                mappability[c] = [float('nan')] * n_bins
            if len(mappability[c]) < n_bins:
                mappability[c] += [float('nan')] * (
                    n_bins - len(mappability[c]))
        # concatenates (single pass instead of repeated list additions)
        mappability = [m for c in refs for m in mappability.get(c, [])]

        printime('  - Computing GC content per bin (removing Ns)')
        gc_content = get_gc_content(genome, opts.reso, chromosomes=refs,
                                    n_cpus=opts.cpus)
        # pad mappability at the end if the size is close to gc_content
        if len(mappability)<len(gc_content) and len(mappability)/len(gc_content) > 0.95:
            mappability += [float('nan')] * (len(gc_content)-len(mappability))

        # compute r_sites ~30 sec
        # TODO: read from DB
        printime('  - Computing number of RE sites per bin (+/- 200 bp)')
        n_rsites  = []
        re_site = RESTRICTION_ENZYMES[opts.renz].replace('|', '')
        for crm in refs:
            for pos in range(200, len(genome[crm]) + 200, opts.reso):
                seq = genome[crm][pos-200:pos + opts.reso + 200]
                n_rsites.append(seq.count(re_site))

    biases, decay, badcol, raw_cisprc, norm_cisprc = read_bam(
        mreads, filter_exclude, opts.reso, min_count=opts.min_count, sigma=2,
        factor=1, outdir=outdir, extra_out=param_hash, ncpus=opts.cpus,
        normalization=opts.normalization, mappability=mappability,
        p_fit=opts.p_fit, cg_content=gc_content, n_rsites=n_rsites,
        seed=opts.seed,
        normalize_only=opts.normalize_only, max_njobs=opts.max_njobs,
        extra_bads=opts.badcols, biases_path=opts.biases_path,
        cis_limit=opts.cis_limit, trans_limit=opts.trans_limit,
        min_ratio=opts.ratio_limit, fast_filter=opts.fast_filter)

    inter_vs_gcoord = path.join(opts.workdir, '04_normalization',
                                'interactions_vs_genomic-coords.png_%s_%s.png' % (
                                    opts.reso, param_hash))

    # get and plot decay
    if not opts.normalize_only:
        printime('  - Computing interaction decay vs genomic distance')
        (_, _, _), (a2, _, _), (_, _, _) = plot_distance_vs_interactions(
            decay, max_diff=10000, resolution=opts.reso, normalized=not opts.filter_only,
            savefig=inter_vs_gcoord)

        print ('    -> Decay slope 0.7-10 Mb\t%s' % a2)
    else:
        a2 = 0.

    printime('  - Saving biases and badcol columns')
    # biases
    bias_file = path.join(outdir, 'biases_%s_%s.pickle' % (
        nicer(opts.reso).replace(' ', ''), param_hash))
    # context manager: the file is closed even if pickling fails
    with open(bias_file, 'wb') as out:
        dump({'biases'    : biases,
              'decay'     : decay,
              'badcol'    : badcol,
              'resolution': opts.reso}, out, HIGHEST_PROTOCOL)

    finish_time = time.localtime()

    try:
        save_to_db(opts, bias_file, mreads, len(badcol),
                   len(biases), raw_cisprc, norm_cisprc,
                   inter_vs_gcoord, a2, opts.filter,
                   launch_time, finish_time)
    except BaseException:
        # deliberately broad (also catches interruptions): the lock must be
        # released whatever happened, then the program exits with an error
        print_exc()
        try:
            remove(path.join(opts.workdir, '__lock_db'))
        except OSError:
            pass
        exit(1)
Ejemplo n.º 4
0
def run(opts):
    """
    Extract Hi-C (sub-)matrices from a BAM file, and write and/or plot them.

    The BAM file and the biases are either given in the options
    (``opts.bam``, ``opts.biases``) or retrieved from the working-directory
    database. Outputs go to the '05_sub-matrices' folder of ``opts.workdir``
    and are registered in the database unless running in interactive mode.

    :param opts: parsed command line options. ``opts.coord1`` and
       ``opts.coord2`` are either a chromosome name or
       'chromosome:begin-end'

    :raises Exception: if an external BAM file is given without biases while
       none of the wanted normalizations is 'raw'
    """
    def _split_coord(coord):
        """Split 'crm:beg-end' into (crm, beg, end); any other string is taken
        as a region name with no coordinates."""
        try:
            crm, pos = coord.split(':')
            beg, end = pos.split('-')
            return crm, int(beg), int(end)
        except ValueError:
            return coord, None, None

    check_options(opts)
    launch_time = time.localtime()
    param_hash = digest_parameters(opts, extra=['quiet'])

    if opts.zrange:
        vmin = float(opts.zrange.split(',')[0])
        vmax = float(opts.zrange.split(',')[1])
    else:
        vmin = vmax = None

    if opts.figsize:
        opts.figsize = list(map(float, opts.figsize.split(',')))

    clean = True  # change for debug
    biases = None
    if opts.bam:
        mreads = path.realpath(opts.bam)
        if not opts.biases and all(v != 'raw' for v in opts.normalizations):
            raise Exception('ERROR: external BAM input, should provide path to'
                            ' biases file.')
    else:
        biases, mreads = load_parameters_fromdb(opts)
        mreads = path.join(opts.workdir, mreads)
        biases = path.join(opts.workdir, biases) if biases else None
    if opts.biases:
        biases = opts.biases

    coord1 = opts.coord1
    coord2 = opts.coord2

    # a single coordinate is always handled as the first one
    if coord2 and not coord1:
        coord1, coord2 = coord2, coord1

    region1 = start1 = end1 = None
    region2 = start2 = end2 = None
    if coord1:
        region1, start1, end1 = _split_coord(coord1)
        if coord2:
            region2, start2, end2 = _split_coord(coord2)

    # limit on the matrix size, only when plotting without --force
    if opts.plot and not opts.force_plot:
        if opts.interactive:
            max_size = 3500**2
        else:
            max_size = 5000**2
    else:
        max_size = None

    outdir = path.join(opts.workdir, '05_sub-matrices')
    mkdir(outdir)
    tmpdir = path.join(opts.workdir, '05_sub-matrices',
                       '_tmp_sub-matrices_%s' % param_hash)
    mkdir(tmpdir)

    if not opts.quiet:
        if region1:
            stdout.write('\nExtraction of %s' % (region1))
            if start1:
                stdout.write(':%s-%s' % (start1, end1))
            else:
                stdout.write(' (full chromosome)')
            if region2:
                stdout.write(' intersection with %s' % (region2))
                if start2:
                    stdout.write(':%s-%s\n' % (start2, end2))
                else:
                    stdout.write(' (full chromosome)\n')
            else:
                stdout.write('\n')
        else:
            stdout.write('\nExtraction of %s genome\n' %
                         ('partial' if opts.chr_name else 'full'))

    out_files = {}
    out_plots = {}

    if opts.matrix or opts.plot:
        sections, section_pos = get_sections(mreads, opts.chr_name)
        for norm in opts.normalizations:
            norm_string = ('RAW' if norm == 'raw' else
                           'NRM' if norm == 'norm' else 'DEC')
            printime('Getting %s matrices' % norm)
            # biases are reloaded for each normalization (file closed right
            # after loading)
            if biases and norm != 'raw':
                with open(biases, 'rb') as bias_handle:
                    loaded_biases = load(bias_handle)
            else:
                loaded_biases = None
            try:
                matrix, bads1, bads2, regions, name, bin_coords = get_matrix(
                    mreads,
                    opts.reso,
                    loaded_biases,
                    normalization=norm,
                    filter_exclude=opts.filter,
                    region1=region1,
                    start1=start1,
                    end1=end1,
                    region2=region2,
                    start2=start2,
                    end2=end2,
                    tmpdir=tmpdir,
                    ncpus=opts.cpus,
                    return_headers=True,
                    nchunks=opts.nchunks,
                    verbose=not opts.quiet,
                    clean=clean,
                    max_size=max_size,
                    chr_order=opts.chr_name)
            except NotImplementedError:
                if norm == "raw&decay":
                    warn('WARNING: raw&decay normalization not implemented '
                         'for matrices\n... skipping\n')
                    continue
                raise

            # shift bin coordinates so that both axes start at zero
            b1, e1, b2, e2 = bin_coords
            b1, e1 = 0, e1 - b1
            b2, e2 = 0, e2 - b2

            if opts.row_names:
                starts = [start1, start2]
                ends = [end1, end2]
                # lazy generator, consumed one row at a time when writing
                row_names = ((reg, p + 1, p + opts.reso)
                             for r, reg in enumerate(regions) for p in range(
                                 starts[r] if r < len(starts) and starts[r]
                                 else 0, ends[r] if r < len(ends) and ends[r]
                                 else sections[reg], opts.reso))

            if opts.matrix:
                printime(' - Writing: %s' % norm)
                fnam = '%s_%s_%s%s.mat' % (norm, name, nicer(
                    opts.reso, sep=''), ('_' + param_hash))
                out_files[norm_string] = path.join(outdir, fnam)
                with open(path.join(outdir, fnam), 'w') as out:
                    for reg in regions:
                        out.write('# CRM %s\t%d\n' % (reg, sections[reg]))
                    if region2:
                        out.write('# BADROWS %s\n' %
                                  (','.join([str(b) for b in bads1])))
                        out.write('# BADCOLS %s\n' %
                                  (','.join([str(b) for b in bads2])))
                    else:
                        out.write('# MASKED %s\n' %
                                  (','.join([str(b) for b in bads1])))
                    if opts.row_names:
                        out.write('\n'.join('%s\t%d\t%d\t' %
                                            (next(row_names)) + '\t'.join(
                                                str(matrix.get((i, j), 0))
                                                for i in range(b1, e1))
                                            for j in range(b2, e2)) + '\n')
                    else:
                        out.write('\n'.join('\t'.join(
                            str(matrix.get((i, j), 0)) for i in range(b1, e1))
                                            for j in range(b2, e2)) + '\n')
            if opts.plot:
                # transform matrix
                matrix = array([
                    array([matrix.get((i, j), 0) for i in range(b1, e1)])
                    for j in range(b2, e2)
                ])
                # mask filtered columns and rows. The two loops are
                # independent: nesting the second in the first left the rows
                # unmasked whenever bads1 was empty.
                m = zeros_like(matrix)
                for bad1 in bads1:
                    m[:, bad1] = 1
                for bad2 in bads2:
                    m[bad2, :] = 1
                matrix = ma.masked_array(matrix, m)
                printime(' - Plotting: %s' % norm)
                fnam = '%s_%s_%s%s%s.%s' % (
                    'nrm' if norm == 'norm' else norm[:3], name,
                    nicer(opts.reso, sep=''), ('_' + param_hash),
                    '_tri' if opts.triangular else '', opts.format)
                out_plots[norm_string] = path.join(outdir, fnam)
                pltbeg1 = 0 if start1 is None else start1
                pltend1 = sections[regions[0]] if end1 is None else end1
                pltbeg2 = 0 if start2 is None else start2
                pltend2 = sections[regions[-1]] if end2 is None else end2
                xlabel = '{}:{:,}-{:,}'.format(regions[0],
                                               pltbeg1 if pltbeg1 else 1,
                                               pltend1)
                ylabel = '{}:{:,}-{:,}'.format(regions[-1],
                                               pltbeg2 if pltbeg2 else 1,
                                               pltend2)
                # keep only the chromosomes present in the matrix
                section_pos = OrderedDict(
                    (k, section_pos[k]) for k in section_pos if k in regions)
                transform = (log2 if opts.transform == 'log2' else
                             log if opts.transform == 'log' else lambda x: x)
                tads = None
                if opts.tad_def and not region2:
                    tads = load_tads_fromdb(opts)
                    if tads and start1:
                        # keep TADs inside the region, in matrix coordinates
                        tads = dict([
                            (t, tads[t]) for t in tads
                            if (int(tads[t]['start']) >= start1 // opts.reso
                                and int(tads[t]['end']) <= end1 // opts.reso)
                        ])
                        for tad in tads:
                            tads[tad]['start'] -= start1 // opts.reso
                            tads[tad]['end'] -= start1 // opts.reso
                ax1, _ = plot_HiC_matrix(
                    matrix,
                    triangular=opts.triangular,
                    vmin=vmin,
                    vmax=vmax,
                    cmap=opts.cmap,
                    figsize=opts.figsize,
                    transform=transform,
                    bad_color=opts.bad_color if norm != 'raw' else None,
                    tad_def=tads)
                ax1.set_title('Region: %s, normalization: %s, resolution: %s' %
                              (name, norm, nicer(opts.reso)),
                              y=1.05)
                format_HiC_axes(ax1,
                                start1,
                                end1,
                                start2,
                                end2,
                                opts.reso,
                                regions,
                                section_pos,
                                sections,
                                opts.xtick_rotation,
                                triangular=False)
                if opts.interactive:
                    plt.show()
                    plt.close('all')
                else:
                    tadbit_savefig(path.join(outdir, fnam))
    if not opts.matrix and not opts.only_plot:
        printime('Getting and writing matrices')
        if biases:
            with open(biases, 'rb') as bias_handle:
                loaded_biases = load(bias_handle)
        else:
            loaded_biases = None
        out_files.update(
            write_matrix(mreads,
                         opts.reso,
                         loaded_biases,
                         outdir,
                         filter_exclude=opts.filter,
                         normalizations=opts.normalizations,
                         region1=region1,
                         start1=start1,
                         end1=end1,
                         region2=region2,
                         start2=start2,
                         end2=end2,
                         tmpdir=tmpdir,
                         append_to_tar=None,
                         ncpus=opts.cpus,
                         nchunks=opts.nchunks,
                         verbose=not opts.quiet,
                         extra=param_hash,
                         cooler=opts.cooler,
                         clean=clean,
                         chr_order=opts.chr_name))

    if clean:
        printime('Cleaning')
        system('rm -rf %s ' % tmpdir)

    if not opts.interactive:
        printime('Saving to DB')
        finish_time = time.localtime()
        save_to_db(opts, launch_time, finish_time, out_files, out_plots)
Ejemplo n.º 5
0
def run(opts):
    """
    Compare two Hi-C experiments and merge their BAM files into a single one.

    For each experiment the BAM file and biases are either taken directly
    from the options (``opts.bam1`` / ``opts.bam2``) or looked up through
    ``load_parameters_fromdb`` in the corresponding working directory.
    Unless ``opts.skip_comparison`` is set, both samples are loaded at
    ``opts.reso`` and compared (correlation between equidistant loci,
    correlation between eigenvectors, reproducibility score). Unless
    ``opts.skip_merge`` is set, both BAM files are merged and indexed with
    samtools. Results are finally stored with ``save_to_db``.

    :param opts: parsed command-line options

    :raises Exception: when the resolutions retrieved for the two
       experiments differ
    """
    check_options(opts)
    samtools = which(opts.samtools)
    launch_time = time.localtime()

    param_hash = digest_parameters(opts)

    reso1 = reso2 = None
    if opts.bam1:
        mreads1 = path.realpath(opts.bam1)
        biases1 = opts.biases1
    else:
        biases1, mreads1, reso1 = load_parameters_fromdb(
            opts.workdir1, opts.jobid1, opts, opts.tmpdb1)
        mreads1 = path.join(opts.workdir1, mreads1)
        try:
            biases1 = path.join(opts.workdir1, biases1)
        except (AttributeError, TypeError):
            # no biases stored for this job (the exception type raised by
            # path.join on a non-string differs between Python 2 and 3)
            biases1 = None

    if opts.bam2:
        mreads2 = path.realpath(opts.bam2)
        biases2 = opts.biases2
    else:
        biases2, mreads2, reso2 = load_parameters_fromdb(
            opts.workdir2, opts.jobid2, opts, opts.tmpdb2)
        mreads2 = path.join(opts.workdir2, mreads2)
        try:
            biases2 = path.join(opts.workdir2, biases2)
        except (AttributeError, TypeError):
            # must reset the biases of the *second* experiment; resetting
            # biases1 here would silently discard the first one's biases
            biases2 = None

    filter_exclude = opts.filter

    # NOTE(review): a resolution is only retrieved on the DB path; mixing one
    # BAM given directly with one loaded from a DB may trip this check --
    # confirm this is intended.
    if reso1 != reso2:
        raise Exception('ERROR: differing resolutions between experiments to '
                        'be merged')

    merge_dir = path.join(opts.workdir, '00_merge')
    mkdir(merge_dir)

    if not opts.skip_comparison:
        printime('  - loading first sample %s' % (mreads1))
        hic_data1 = load_hic_data_from_bam(mreads1,
                                           opts.reso,
                                           biases=biases1,
                                           tmpdir=merge_dir,
                                           ncpus=opts.cpus,
                                           filter_exclude=filter_exclude)

        printime('  - loading second sample %s' % (mreads2))
        hic_data2 = load_hic_data_from_bam(mreads2,
                                           opts.reso,
                                           biases=biases2,
                                           tmpdir=merge_dir,
                                           ncpus=opts.cpus,
                                           filter_exclude=filter_exclude)

        if opts.workdir1 and opts.workdir2:
            masked1 = {'valid-pairs': {'count': 0}}
            masked2 = {'valid-pairs': {'count': 0}}
        else:
            masked1 = {'valid-pairs': {'count': sum(hic_data1.values())}}
            masked2 = {'valid-pairs': {'count': sum(hic_data2.values())}}

        # output files for the comparison (data tables and figures)
        decay_corr_dat = path.join(
            merge_dir, 'decay_corr_dat_%s_%s.txt' % (opts.reso, param_hash))
        decay_corr_fig = path.join(
            merge_dir, 'decay_corr_dat_%s_%s.png' % (opts.reso, param_hash))
        eigen_corr_dat = path.join(
            merge_dir, 'eigen_corr_dat_%s_%s.txt' % (opts.reso, param_hash))
        eigen_corr_fig = path.join(
            merge_dir, 'eigen_corr_dat_%s_%s.png' % (opts.reso, param_hash))

        printime('  - comparing experiments')
        printime('    => correlation between equidistant loci')
        corr, _, scc, std, bads = correlate_matrices(hic_data1,
                                                     hic_data2,
                                                     normalized=opts.norm,
                                                     remove_bad_columns=True,
                                                     savefig=decay_corr_fig,
                                                     savedata=decay_corr_dat,
                                                     get_bads=True)
        print('         - correlation score (SCC): %.4f (+- %.7f)' %
              (scc, std))
        printime('    => correlation between eigenvectors')
        eig_corr = eig_correlate_matrices(hic_data1,
                                          hic_data2,
                                          normalized=opts.norm,
                                          remove_bad_columns=True,
                                          nvect=6,
                                          savefig=eigen_corr_fig,
                                          savedata=eigen_corr_dat)

        printime('    => reproducibility score')
        reprod = get_reproducibility(hic_data1,
                                     hic_data2,
                                     num_evec=20,
                                     normalized=opts.norm,
                                     verbose=False,
                                     remove_bad_columns=True)
        print('         - reproducibility score: %.4f' % (reprod))
        ncols = len(hic_data1)
    else:
        # comparison skipped: store placeholder values in the DB
        ncols = 0
        decay_corr_dat = 'None'
        decay_corr_fig = 'None'
        eigen_corr_dat = 'None'
        eigen_corr_fig = 'None'
        masked1 = {}
        masked2 = {}

        corr = eig_corr = scc = std = reprod = 0
        bads = {}

    # merge inputs
    filtered_dir = path.join(opts.workdir, '03_filtered_reads')
    mkdir(filtered_dir)

    if not opts.skip_merge:
        outbam = path.join(filtered_dir, 'intersection_%s.bam' % (param_hash))
        printime('  - Mergeing experiments')
        system(samtools + ' merge -@ %d %s %s %s' %
               (opts.cpus, outbam, mreads1, mreads2))
        printime('  - Indexing new BAM file')
        # check samtools version number and modify command line: the version
        # is parsed from the usage text samtools prints on stderr
        usage = Popen(samtools, stderr=PIPE,
                      universal_newlines=True).communicate()[1]
        version = LooseVersion(
            [l.split()[1] for l in usage.split('\n') if 'Version' in l][0])
        if version >= LooseVersion('1.3.1'):
            system(samtools + ' index -@ %d %s' % (opts.cpus, outbam))
        else:
            system(samtools + ' index %s' % (outbam))
    else:
        outbam = ''

    finish_time = time.localtime()
    save_to_db(opts, mreads1, mreads2, decay_corr_dat, decay_corr_fig,
               len(bads), ncols, scc, std, reprod, eigen_corr_dat,
               eigen_corr_fig, outbam, corr, eig_corr, biases1, biases2,
               masked1, masked2, launch_time, finish_time)
    printime('\nDone.')
Ejemplo n.º 6
0
def main():
    """
    Write a tab-separated, SAM-like text file from a TADbit BAM file.

    Mapped reads (sequence and quality) are loaded in chunks of
    ``opts.nreads`` million entries from ``opts.map_folder``. For every
    chunk the whole TADbit BAM file (``opts.tadbit_bam``) is scanned; each
    read that passes the filters and whose (read name, position) is found in
    the chunk is written as one line to ``opts.hicup_bam``.
    """
    opts = get_options()
    filter_exclude = filters_to_bin(opts.filter)
    tadbit_bam = opts.tadbit_bam
    hicup_bam = opts.hicup_bam
    map_folder = opts.map_folder
    nreads = opts.nreads * 1_000_000

    # keyed by (s1, s2); only the first element of each pair is used below
    # NOTE(review): values look like SAM flags for first/second read-end --
    # confirm.
    tag_dict = {
        (1, 1): (67, 131),
        (0, 0): (115, 179),
        (1, 0): (99, 147),
        (0, 1): (83, 163),
    }

    # context manager guarantees the output is flushed and closed on error
    with open(hicup_bam, 'w') as out:
        for seqs in get_mapped_chunk(map_folder, nreads):
            # the BAM file is re-read from the start for every chunk
            bamfile = AlignmentFile(tadbit_bam, 'rb')
            try:
                refs = bamfile.references
                printime(f' - processing BAM (for {len(seqs) / 1_000_000}M reads)')
                for r in bamfile.fetch(multiple_iterators=False):
                    if r.flag & filter_exclude:
                        continue
                    rid = r.qname
                    ridname = rid.split('#')[0]
                    pos1 = r.reference_start + 1
                    which, len1 = r.cigar[0]
                    tags = dict(r.tags)
                    if which == 6:  # first read-end
                        s1, s2 = tags['S1'], tags['S2']
                    else:
                        s2, s1 = tags['S1'], tags['S2']
                    # presumably S1/S2 hold strand (0: reverse) -- TODO confirm
                    if s1 == 0:
                        pos1 = pos1 - len1 + 1
                    try:
                        seq, qal = seqs[ridname, pos1]
                    except KeyError:
                        # read not part of the current chunk
                        continue
                    crm1 = r.reference_name
                    crm2 = refs[r.mrnm]
                    pos2 = r.mpos + 1
                    len2 = r.tlen

                    # distance is computed before pos2 is shifted below
                    dist = 0 if crm1 != crm2 else abs(pos2 - pos1)

                    if s2 == 0:
                        pos2 = pos2 - len2 + 1

                    flag = tag_dict[s1, s2][0]

                    out.write((f'{r.qname}\t{flag}\t{crm1}\t{pos1}\t{len1}\t'
                               f'{len(seq)}M\t{crm2}\t{pos2}\t{dist}\t{seq}\t'
                               f'{qal}\tMD:Z:{len1}\tPG:Z:MarkDuplicates\tNM:i:0\t'
                               f'AS:i:{len1}\tXS:i:1\n'))
            finally:
                bamfile.close()
            # free the memory of the chunk before loading the next one
            seqs.clear()