Example #1
def remove_low_cellcount_reads(inbam, outbam, mincount, log):
    """
    This function takes a bam file with barcodes in the
    RG tag as input and outputs a bam file containing
    only barcodes that meet or exceed the minimum number of
    alignments for a given barcode.

    """
    treatment = AlignmentFile(inbam, 'rb')
    header = treatment.header
    barcodecounts = {bc['ID']: 0 for bc in header['RG']}

    # first parse the file to determine the per barcode
    # alignment counts

    for aln in treatment.fetch(until_eof=True):
        if aln.is_proper_pair and aln.is_read1 or not aln.is_paired:
            rg = aln.get_tag('RG')
            barcodecounts[rg] += 1

    treatment.close()

    # make new header with the valid barcodes
    treatment = AlignmentFile(inbam, 'rb')
    header = treatment.header.to_dict().copy()
    rgheader = []
    for rg in header['RG']:
        if barcodecounts[rg['ID']] >= mincount:
            rgheader.append(rg)

    header['RG'] = rgheader

    #log summary
    log_content = {}
    log_content['below_minbarcodecounts'] = 0
    log_content['above_minbarcodecounts'] = 0
    log_content['total'] = 0
    for bc in barcodecounts:
        log_content['total'] += barcodecounts[bc]
        if barcodecounts[bc] >= mincount:
            log_content['above_minbarcodecounts'] += barcodecounts[bc]
        else:
            log_content['below_minbarcodecounts'] += barcodecounts[bc]

    bam_writer = AlignmentFile(outbam, 'wb', header=header)
    for aln in treatment.fetch(until_eof=True):
        if barcodecounts[aln.get_tag('RG')] >= mincount:
            bam_writer.write(aln)

    treatment.close()
    bam_writer.close()

    #write log file
    with open(log, 'w') as f:
        f.write('Readgroup\tcounts\n')
        for icnt in log_content:
            f.write('{}\t{}\n'.format(icnt, log_content[icnt]))
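A minimal usage sketch for the function above (hypothetical file names and threshold; the defining module is assumed to import AlignmentFile from pysam, as the code implies):

# Hypothetical inputs: an RG-tagged BAM, an output path, a minimum of
# 1000 alignments per barcode, and a log file path.
remove_low_cellcount_reads('barcoded.bam', 'filtered.bam', mincount=1000, log='filter_log.tsv')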
Example #2
File: Utilities.py  Project: emlec/pyDNA
def fetch_count_read (alignment_file, seq_name, start, end):
    """
    Count the number of reads that at least partly overlap a specified chromosomal region
    @param alignment_file Path to a sam or a bam file
    @param seq_name Name of the sequence the reads are aligned on
    @param start Start genomic coordinate of the area of alignment
    @param end End genomic coordinate of the area of alignment
    """
    # Specific imports
    from pysam import AlignmentFile
    
    # Init a generator on the sam or bam file with pysam
    if alignment_file[-3:].lower() == "bam":
        al = AlignmentFile(alignment_file, "rb")
        
    elif alignment_file[-3:].lower() == "sam":
        al = AlignmentFile(alignment_file, "r")
    
    else:
        raise Exception("Wrong file format (sam or bam)") 
    
    # Count read aligned at least partly on the specified region
    n = 0
    for i in al.fetch(seq_name, start, end):
        n += 1
        
    al.close()
    
    return n
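A brief usage sketch (hypothetical file and coordinates; fetch() with a region requires an indexed BAM):

# Count reads overlapping a hypothetical region chr1:10,000-20,000.
n = fetch_count_read("sample.bam", "chr1", 10000, 20000)
print(n)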
Example #3
def filter_bam(input_bam, pore_c_table, output_bam, clean_read_name):
    from pysam import AlignmentFile

    inbam = AlignmentFile(input_bam, "rb")
    outbam = AlignmentFile(output_bam, "wb", template=inbam)

    aligns = pd.read_parquet(pore_c_table,
                             engine=PQ_ENGINE,
                             columns=["align_idx",
                                      "pass_filter"]).set_index(["align_idx"])
    aligns = aligns[aligns["pass_filter"]]

    expected = len(aligns)
    counter = 0

    for align in inbam.fetch(until_eof=True):
        align_idx = int(align.query_name.rsplit(":")[2])
        if align_idx not in aligns.index:
            continue
        if clean_read_name:
            readname_only = align.query_name.split(":")[0]
            align.query_name = readname_only
        outbam.write(align)
        counter += 1
    if counter != expected:
        raise ValueError(
            f"Number of alignments doesn't match. Expected {expected} got {counter}"
        )
    logger.info(f"Wrote {counter} reads to {output_bam}")
Example #4
File: umis.py  Project: roryk/umis
def bamtag(sam, umi_only):
    ''' Convert a BAM/SAM with fastqtransformed read names to have UMI and
    cellular barcode tags
    '''
    from pysam import AlignmentFile

    if umi_only:
        parser_re = re.compile('.*:UMI_(?P<MB>.*)')
    else:
        parser_re = re.compile('.*:CELL_(?P<CB>.*):UMI_(?P<MB>.*)')

    start_time = time.time()

    sam_mode = 'r' if sam.endswith(".sam") else 'rb'
    sam_file = AlignmentFile(sam, mode=sam_mode)
    out_file = AlignmentFile("-", "wh", template=sam_file)

    track = sam_file.fetch(until_eof=True)

    for count, aln in enumerate(track):
        if not count % 100000:
            logger.info("Processed %d alignments.")

        match = parser_re.match(aln.qname)
        tags = aln.tags

        if not umi_only:
            aln.tags += [('XC', match.group('CB'))]

        aln.tags += [('XR', match.group('MB'))]
        out_file.write(aln)

    total_time = time.time() - start_time
    logger.info('BAM tag conversion done - {:.3}s, {:,} alns/min'.format(total_time, int(60. * count / total_time)))
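A usage sketch (hypothetical SAM path; assumes the module-level logger, re and time setup used above, and streams the XC/XR-tagged records to stdout through the '-' output handle):

# Hypothetical fastqtransformed SAM; writes tagged records to stdout.
bamtag("transformed.sam", umi_only=False)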
Example #5
def main(
    sam: str,
    output: str,
    reference2taxid: str
) -> None:
    """Write row with taxid and classification status for each alignment."""
    aln_infile = AlignmentFile(sam, "r")
    aln_outfile = AlignmentFile('-', "w", template=aln_infile)
    ref2taxid_df = pd.read_csv(
        reference2taxid, sep='\t', names=['acc', 'taxid'], index_col=0)
    output_tsv = open(output, 'w+')

    for aln in aln_infile.fetch(until_eof=True):

        mapped = 'U' if aln.is_unmapped else 'C'
        queryid = aln.query_name
        querylen = aln.query_length

        taxid = 0
        if not aln.is_unmapped:
            taxid = ref2taxid_df.at[aln.reference_name, 'taxid']

        output_tsv.write(
            '{mapped}\t{queryid}\t{taxid}\t0|{querylen}\n'.format(
                mapped=mapped, queryid=queryid, taxid=taxid, querylen=querylen
            )
        )

        aln_outfile.write(aln)
Example #6
File: mapqto0.py  Project: PySean/mutools
def umappedq2zero(bamdir):
    """
    Reads in a BAM file, setting the MAPQ value for an alignment segment
    to zero if it is unmapped.
    Opens up both infile and outfile and outputs these modified
    reads to outfile.
    """
    if not os.path.exists(bamdir):
        sys.stderr.write("Sorry, but the specified directory does not exist.")
        sys.exit(1)

    bamfiles = os.listdir(bamdir)
    bampaths = filter(lambda x: x.endswith(".bam"), bamfiles)
    bampaths = map(lambda x: os.path.join(bamdir, x), bampaths)
    for bam in bampaths:
        inbam = AlignmentFile(bam, "rb")
        # Template is specified to maintain the same header information.
        outbam = AlignmentFile("temp.bam", "wb", template=inbam)
        # Construct reads iterator using fetch.
        reads = inbam.fetch(until_eof=True)
        for read in reads:
            if read.is_unmapped:
                read.mapping_quality = 0
            outbam.write(read)  # Don't omit any reads!
        # Close both files so the output BAM is fully flushed before renaming.
        inbam.close()
        outbam.close()
        # Overwrite the original with the new file with MAPQs set to zero.
        os.rename("temp.bam", bam)
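A usage sketch (hypothetical directory; every *.bam file in it is rewritten in place):

# Hypothetical directory of BAM files whose unmapped reads should get MAPQ 0.
umappedq2zero("/data/bams")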
Example #7
def cell_scaling_factors_bam(file, selected_barcodes=None, tag='CB', mapq=10):
    """ Generates pseudo-bulk tracks.

    Parameters
    ----------
    file : str
       Input bam file.
    tag : str or callable
       Barcode tag or callable to extract barcode from the alignments. Default: 'CB'
    mapq : int
       Minimum mapping quality. Default: 10

    Returns
    -------
    pd.Series
       Series containing the barcode counts per barcode.

    """

    barcodecount = Counter()
    afile = AlignmentFile(file, 'rb')
    barcoder = Barcoder(tag)
    for aln in afile.fetch():
        if aln.mapping_quality < mapq:
            continue
        bct = barcoder(aln)
        if selected_barcodes is not None:
            if bct not in selected_barcodes:
                continue
        barcodecount[bct] += 1
    afile.close()
    return pd.Series(barcodecount)
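A usage sketch (hypothetical indexed BAM and a hypothetical barcode whitelist; assumes the Barcoder helper used above is available in the same module):

whitelist = {'AAACCCAAGAAACACT-1', 'AAACCCAAGAAACCCA-1'}   # hypothetical barcodes
counts = cell_scaling_factors_bam('atac.bam', selected_barcodes=whitelist, tag='CB', mapq=10)
print(counts.sort_values(ascending=False).head())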
Example #8
def get_counts(args):
    """function to get fragment sizes

    """
    if args.out is None:
        args.out = '.'.join(os.path.basename(args.bed).split('.')[0:-1])  
    chunks = ChunkList.read(args.bed)
    mat = np.zeros(len(chunks), dtype=int)
    bamHandle = AlignmentFile(args.bam)
    j = 0
    for chunk in chunks:
        for read in bamHandle.fetch(chunk.chrom, max(0, chunk.start - args.upper), chunk.end + args.upper):
            if read.is_proper_pair and not read.is_reverse:
                if args.atac:
                    #get left position
                    l_pos = read.pos + 4
                    #get insert size
                    #correct by 8 base pairs to be insertion to insertion
                    ilen = abs(read.template_length) - 8
                else:
                    l_pos = read.pos
                    ilen = abs(read.template_length)
                r_pos = l_pos + ilen - 1
                if _between(ilen, args.lower, args.upper) and (_between(l_pos, chunk.start, chunk.end) or _between(r_pos, chunk.start, chunk.end)):
                    mat[j] += 1
        j += 1
    bamHandle.close()
    np.savetxt(args.out + ".counts.txt.gz", mat, delimiter="\n", fmt='%i')
Example #9
def get_counts(args):
    """function to get fragment sizes

    """
    if args.out is None:
        args.out = '.'.join(os.path.basename(args.bed).split('.')[0:-1])
    chunks = ChunkList.read(args.bed)
    mat = np.zeros(len(chunks), dtype=int)
    bamHandle = AlignmentFile(args.bam)
    j = 0
    for chunk in chunks:
        for read in bamHandle.fetch(chunk.chrom,
                                    max(0, chunk.start - args.upper),
                                    chunk.end + args.upper):
            if read.is_proper_pair and not read.is_reverse:
                if args.atac:
                    #get left position
                    l_pos = read.pos + 4
                    #get insert size
                    #correct by 8 base pairs to be insertion to insertion
                    ilen = abs(read.template_length) - 8
                else:
                    l_pos = read.pos
                    ilen = abs(read.template_length)
                r_pos = l_pos + ilen - 1
                if _between(ilen, args.lower, args.upper) and (
                        _between(l_pos, chunk.start, chunk.end)
                        or _between(r_pos, chunk.start, chunk.end)):
                    mat[j] += 1
        j += 1
    bamHandle.close()
    np.savetxt(args.out + ".counts.txt.gz", mat, delimiter="\n", fmt='%i')
Example #10
def sites_coverage(alignment: pysam.AlignmentFile, sites: dict):

    coverage = dict()

    for segment in alignment.fetch():

        reference_name = segment.reference_name
        segment_type = 2 * segment.is_read2 + segment.is_reverse
        segment_pos = 0
        reference_pos = segment.reference_start + BASE

        for cigartuple in segment.cigartuples:

            if cigartuple[0] == 0:
                for inc in range(1, cigartuple[1] - 1):
                    matched_position = reference_pos + inc
                    if sites.get((reference_name, matched_position), False):
                        offset = segment_pos + inc
                        key = (reference_name + '_' + str(matched_position),
                               offset)
                        if key not in coverage:
                            coverage[key] = [0] * 4
                        coverage[key][segment_type] += 1

                segment_pos += cigartuple[1]
                reference_pos += cigartuple[1]

            elif cigartuple[0] == 1:
                segment_pos += cigartuple[1]
            elif cigartuple[0] in (2, 3):
                reference_pos += cigartuple[1]

    return coverage
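For reference, the integer CIGAR operation codes tested above follow the BAM specification; a small lookup table makes the branches easier to read:

# BAM CIGAR operation codes as returned by pysam's cigartuples.
CIGAR_OPS = {
    0: 'M',  # alignment match (consumes query and reference)
    1: 'I',  # insertion to the reference (consumes query only)
    2: 'D',  # deletion from the reference (consumes reference only)
    3: 'N',  # skipped region on the reference, e.g. an intron
    4: 'S',  # soft clipping (consumes query only)
    5: 'H',  # hard clipping (consumes neither)
}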
Example #11
def _read_bam_frag(inbam, filter_exclude, all_bins, sections1, sections2,
                   rand_hash, resolution, tmpdir, region, start, end,
                   half=False, sum_columns=False):
    bamfile = AlignmentFile(inbam, 'rb')
    refs = bamfile.references
    bam_start = start - 2
    bam_start = max(0, bam_start)
    try:
        dico = {}
        for r in bamfile.fetch(region=region,
                               start=bam_start, end=end,  # coords starts at 0
                               multiple_iterators=True):
            if r.flag & filter_exclude:
                continue
            crm1 = r.reference_name
            pos1 = r.reference_start + 1
            crm2 = refs[r.mrnm]
            pos2 = r.mpos + 1
            try:
                pos1 = sections1[(crm1, pos1 // resolution)]
                pos2 = sections2[(crm2, pos2 // resolution)]
            except KeyError:
                continue  # not in the subset matrix we want
            crm = crm1 * (crm1 == crm2)
            try:
                dico[(crm, pos1, pos2)] += 1
            except KeyError:
                dico[(crm, pos1, pos2)] = 1
            # print '%-50s %5s %9s %5s %9s' % (r.query_name,
            #                                  crm1, r.reference_start + 1,
            #                                  crm2, r.mpos + 1)
        if half:
            # iterate over a copy of the keys; deleting from a dict while
            # iterating it directly raises a RuntimeError in Python 3
            for c, i, j in list(dico):
                if i < j:
                    del dico[(c, i, j)]
        out = open(os.path.join(tmpdir, '_tmp_%s' % (rand_hash),
                                '%s:%d-%d.tsv' % (region, start, end)), 'w')
        out.write(''.join('%s\t%d\t%d\t%d\n' % (c, a, b, v)
                          for (c, a, b), v in dico.items()))
        out.close()
        if sum_columns:
            sumcol = {}
            cisprc = {}
            for (c, i, j), v in dico.items():
                # out.write('%d\t%d\t%d\n' % (i, j, v))
                try:
                    sumcol[i] += v
                    cisprc[i][all_bins[i][0] == all_bins[j][0]] += v
                except KeyError:
                    sumcol[i]  = v
                    cisprc[i]  = [0, 0]
                    cisprc[i][all_bins[i][0] == all_bins[j][0]] += v
            return sumcol, cisprc
    except Exception as e:
        exc_type, exc_obj, exc_tb = exc_info()
        fname = os.path.split(exc_tb.tb_frame.f_code.co_filename)[1]
        print(e)
        print(exc_type, fname, exc_tb.tb_lineno)
Example #12
def recount_on_file(bam_path):
    bam = AlignmentFile(bam_path)
    sample, ext = os.path.splitext(os.path.basename(bam_path))
    assert(ext == '.bam')
    output_path = os.path.join(temp_dir, sample + tmp_ext)
    ret = ['\t'.join(['chr1', 'pos1', 'str1', 'chr2',
                      'pos2', 'str2', 'count', 'hq_count', 'raw_count', 'sample'])]

    for pair in pairs.values():
        paired_rec1, paired_rec2 = pair.get_pair()
        rec1 = paired_rec1.rec
        rec2 = paired_rec2.rec
        interchrom = (rec1.chrom != rec2.chrom)
        str1 = str(paired_rec1.strand * 2 - 1)
        str2 = str(paired_rec2.strand * 2 - 1)
        loc1 = paired_rec1.get_upstream_region(win_size)
        loc2 = paired_rec2.get_upstream_region(win_size)

        # DEBUG ONLY
        # if rec1.chrom != 'chr5' or rec2.chrom != 'chr5':
        #     continue

        bam1 = list(bam.fetch(region=loc1, multiple_iterators=True))
        bam2 = list(bam.fetch(region=loc2, multiple_iterators=True))

        raw1 = {rec.query_name for rec in bam1}
        raw2 = {rec.query_name for rec in bam2}
        n_shared_raw = len(raw1.intersection(raw2))

        bam1_primary = [rec for rec in bam1 if not rec.is_duplicate]
        bam2_primary = [rec for rec in bam2 if not rec.is_duplicate]

        id1 = {rec.query_name for rec in bam1_primary}
        id2 = {rec.query_name for rec in bam2_primary}
        n_shared = len(id1.intersection(id2))

        bam1_hq = [rec for rec in bam1_primary if rec.mapq >= 30]
        bam2_hq = [rec for rec in bam2_primary if rec.mapq >= 30]

        bam1_r1 = {rec.query_name for rec in bam1_hq if rec.is_read1}
        bam1_r2 = {rec.query_name for rec in bam1_hq if rec.is_read2}
        bam2_r1 = {rec.query_name for rec in bam2_hq if rec.is_read1}
        bam2_r2 = {rec.query_name for rec in bam2_hq if rec.is_read2}
        i1 = bam1_r1.intersection(bam2_r2)
        i2 = bam1_r2.intersection(bam2_r1)
        n_shared_hq = len(i1) + len(i2)

        if n_shared:
            tokens = [rec1.chrom, str(rec1.pos), str1,
                      rec2.chrom, str(rec2.pos), str2,
                      str(n_shared), str(n_shared_hq), str(n_shared_raw),
                      sample]
            ret.append('\t'.join(tokens))

    bam.close()
    with open(output_path, 'w') as f:
        f.write('\n'.join(ret) + '\n')
    return output_path
Example #13
def read_bam_frag_filter(inbam, filter_exclude, all_bins, sections, resolution,
                         outdir, extra_out, region, start, end):
    bamfile = AlignmentFile(inbam, 'rb')
    refs = bamfile.references
    try:
        dico = {}
        for r in bamfile.fetch(
                region=region,
                start=start - (1 if start else 0),
                end=end,  # coords starts at 0
                multiple_iterators=True):
            if r.flag & filter_exclude:
                continue
            crm1 = r.reference_name
            pos1 = r.reference_start + 1
            crm2 = refs[r.mrnm]
            pos2 = r.mpos + 1
            try:
                pos1 = sections[(crm1, pos1 // resolution)]
                pos2 = sections[(crm2, pos2 // resolution)]
            except KeyError:
                continue  # not in the subset matrix we want
            try:
                dico[(pos1, pos2)] += 1
            except KeyError:
                dico[(pos1, pos2)] = 1
        cisprc = {}
        for (i, j), v in dico.items():
            if all_bins[i][0] == all_bins[j][0]:
                try:
                    cisprc[i][0] += v
                    cisprc[i][1] += v
                except KeyError:
                    cisprc[i] = [v, v]
            else:
                try:
                    cisprc[i][1] += v
                except KeyError:
                    cisprc[i] = [0, v]
        out = open(
            path.join(
                outdir,
                'tmp_%s:%d-%d_%s.pickle' % (region, start, end, extra_out)),
            'wb')
        dump(dico, out)
        out.close()
        out = open(
            path.join(
                outdir, 'tmp_bins_%s:%d-%d_%s.pickle' %
                (region, start, end, extra_out)), 'wb')
        dump(cisprc, out)
        out.close()
    except Exception as e:
        exc_type, exc_obj, exc_tb = exc_info()
        fname = path.split(exc_tb.tb_frame.f_code.co_filename)[1]
        print(e)
        print(exc_type, fname, exc_tb.tb_lineno)
Example #14
def parse_magic_blast_out(sam_output, working_dir, cutoff):
    unknown_out = os.path.join(working_dir, '%s.unknown.fasta' % sam_output)
    with open(unknown_out, 'w') as unk_out:
        bf = AlignmentFile(sam_output, 'r', check_header=False, check_sq=False)
        for r in bf.fetch(until_eof=True):
            if r.is_unmapped or (r.qual and r.qual < cutoff):
                sequences = SeqRecord(Seq(r.query, IUPAC.IUPACUnambiguousDNA), id = r.qname, description='')
                SeqIO.write(sequences, unk_out, "fasta")
    return unknown_out
Example #15
def deduplicate_reads(bamin, bamout, report, by_rg=True):
    """This script deduplicates the original bamfile.
    Deduplication removes reads align to the same position.
    If the reads in the bamfile contain a RG tag and
    by_rg=True, deduplication is done for each group separately.
    Parameters
    ----------
    bamfile : str
        Sorted bamfile containing barcoded reads.
    output : str
        Output path to a bamfile that contains the deduplicated reads.
    by_rg : boolean
        If True, the reads will be split by group tag.
    """
    bamfile = AlignmentFile(bamin, 'rb')
    output = AlignmentFile(bamout, 'wb', template=bamfile)

    log_counts = {'total': 0, 'retained': 0, 'removed': 0}

    # grep all barcodes from the header
    #barcodes = set()
    last_barcode = {}

    for aln in bamfile.fetch():
        # if the previous hash matches the current one
        # skip the read
        val = (aln.reference_id, aln.reference_start, aln.is_reverse, aln.tlen)
        if aln.has_tag('RG') and by_rg:
            rg = aln.get_tag('RG')
        else:
            rg = 'dummy'
        log_counts['total'] += 1

        if rg not in last_barcode:
            output.write(aln)
            # remember the position for this read group
            last_barcode[rg] = val
            log_counts['retained'] += 1
            continue

        if val == last_barcode[rg]:
            log_counts['removed'] += 1
            continue
        else:
            output.write(aln)
            last_barcode[rg] = val

        log_counts['retained'] += 1

        if (log_counts['retained'] % 1000000) == 0:
            print("Processed {}/{} total/removed reads".format(
                log_counts['total'], log_counts['removed']))

    bamfile.close()
    output.close()

    #write log file
    with open(report, 'w') as f:
        f.write('\tcounts\n')
        for icnt in log_counts:
            f.write('{}\t{}\n'.format(icnt, log_counts[icnt]))
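A usage sketch (hypothetical paths; the input should be position sorted and indexed so that adjacent duplicates are detected):

# Hypothetical sorted, barcoded BAM plus output and report paths.
deduplicate_reads('sorted_barcoded.bam', 'dedup.bam', 'dedup_report.tsv', by_rg=True)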
Example #16
def remove_chroms(bamin, bamout, rmchroms):
    """ Removes chromosomes from bam-file.

    The function searches for matching chromosomes
    using regular expressions.
    For example, rmchroms=['chrM', '_random']
    would remove 'chrM' as well as all random chromosomes.
    E.g. chr1_KI270706v1_random.

    Parameters
    ----------
    bamin : str
       Input bam file.
    bamout : str
       Output bam file.
    rmchroms : list(str)
       List of chromosome names or name patterns to be removed.

    Returns
    -------
    None

    """

    treatment = AlignmentFile(bamin, 'rb')

    header = copy(treatment.header.as_dict())
    newheader = []
    for seq in header['SQ']:

        if not any([x in seq['SN'] for x in rmchroms]):
            newheader.append(seq)

    header['SQ'] = newheader

    tidmap = {k['SN']: i for i, k in enumerate(header['SQ'])}

    bam_writer = AlignmentFile(bamout, 'wb', header=header)

    # write new bam files containing only valid chromosomes
    for aln in treatment.fetch(until_eof=True):
        if aln.is_unmapped:
            continue
        if aln.reference_name not in tidmap or aln.next_reference_name not in tidmap:
            continue

        refid = tidmap[aln.reference_name]
        refnextid = tidmap[aln.next_reference_name]

        aln.reference_id = refid
        aln.next_reference_id = refnextid
        bam_writer.write(aln)

    bam_writer.close()
    treatment.close()
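A usage sketch (hypothetical paths; removes the mitochondrial contig and all '_random' contigs):

# Hypothetical input/output BAMs; drops chrM and any *_random contigs.
remove_chroms('input.bam', 'no_chrM.bam', ['chrM', '_random'])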
Example #17
def fragmentlength_in_regions(file, regions, mapq, maxlen, resolution):
    """ Extract fragment lengths per region.

    Deprecated.

    Parameters
    ----------
    file : str
       Indexed input bam file.
    regions : str
       Regions in bed format. Must be genome-wide bins.
    mapq : int
       Mapping quality
    maxlen : int
       Maximum fragment length.
    resolution : int
       Base pair resolution.

    Return
    -------
        CountMatrix and annotation as pd.DataFrame
    """

    warnings.warn('fragmentlength_in_regions deprecated.',
                  category=DeprecationWarning)
    bed = BedTool(regions)
    binsize = bed[0].end - bed[0].start
    fragments = np.zeros((len(bed), maxlen // resolution))
    m = {(iv.chrom, iv.start): i for i, iv in enumerate(bed)}

    afile = AlignmentFile(file, "rb")

    for aln in afile.fetch():
        if aln.mapping_quality < mapq:
            continue
        if aln.is_proper_pair and aln.is_read1:

            pos = (min(aln.reference_start, aln.next_reference_start) //
                   binsize) * binsize

            tlen = abs(aln.tlen) // resolution
            if tlen < maxlen // resolution:
                if (aln.reference_name, pos) in m:
                    fragments[m[(aln.reference_name, pos)], tlen] += 1

    afile.close()
    cmat = fragments
    cannot = pd.DataFrame({'barcode':
                           ['{}bp'.format(bp*resolution) \
                            for bp in range(maxlen// resolution)]})

    return cmat, cannot
Example #18
def extract_barcode(sam, barcode):

    parser_re = re.compile('.*:CELL_(?P<CB>.*):UMI_(?P<MB>.*)')
    sam_file = AlignmentFile(sam, mode='r')
    filter_file = AlignmentFile("-", mode='wh', template=sam_file)
    track = sam_file.fetch(until_eof=True)
    for i, aln in enumerate(track):
        if aln.is_unmapped:
            continue
        match = parser_re.match(aln.qname)
        CB = match.group('CB')
        if CB == barcode:
            filter_file.write(aln)
Example #19
def extract_barcode(sam, barcode):

    parser_re = re.compile(".*:CELL_(?P<CB>.*):UMI_(?P<MB>.*)")
    sam_file = AlignmentFile(sam, mode="r")
    filter_file = AlignmentFile("-", mode="wh", template=sam_file)
    track = sam_file.fetch(until_eof=True)
    for i, aln in enumerate(track):
        if aln.is_unmapped:
            continue
        match = parser_re.match(aln.qname)
        CB = match.group("CB")
        if CB == barcode:
            filter_file.write(aln)
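A usage sketch (hypothetical SAM with fastqtransformed read names and a hypothetical barcode; matching alignments are streamed to stdout):

extract_barcode("transformed.sam", "AAACCCAAGAAACACT")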
Example #20
def read_bam_frag_filter(inbam, filter_exclude, all_bins, sections,
                         resolution, outdir, extra_out,region, start, end):
    bamfile = AlignmentFile(inbam, 'rb')
    refs = bamfile.references
    try:
        dico = {}
        for r in bamfile.fetch(region=region,
                               start=start - (1 if start else 0), end=end,  # coords starts at 0
                               multiple_iterators=True):
            if r.flag & filter_exclude:
                continue
            crm1 = r.reference_name
            pos1 = r.reference_start + 1
            crm2 = refs[r.mrnm]
            pos2 = r.mpos + 1
            try:
                pos1 = sections[(crm1, pos1 // resolution)]
                pos2 = sections[(crm2, pos2 // resolution)]
            except KeyError:
                continue  # not in the subset matrix we want
            try:
                dico[(pos1, pos2)] += 1
            except KeyError:
                dico[(pos1, pos2)] = 1
        cisprc = {}
        for (i, j), v in dico.items():
            if all_bins[i][0] == all_bins[j][0]:
                try:
                    cisprc[i][0] += v
                    cisprc[i][1] += v
                except KeyError:
                    cisprc[i] = [v, v]
            else:
                try:
                    cisprc[i][1] += v
                except KeyError:
                    cisprc[i] = [0, v]
        out = open(path.join(outdir,
                             'tmp_%s:%d-%d_%s.pickle' % (region, start, end, extra_out)), 'wb')
        dump(dico, out, HIGHEST_PROTOCOL)
        out.close()
        out = open(path.join(outdir, 'tmp_bins_%s:%d-%d_%s.pickle' % (
            region, start, end, extra_out)), 'wb')
        dump(cisprc, out, HIGHEST_PROTOCOL)
        out.close()
    except Exception as e:
        exc_type, exc_obj, exc_tb = exc_info()
        fname = path.split(exc_tb.tb_frame.f_code.co_filename)[1]
        print(e)
        print(exc_type, fname, exc_tb.tb_lineno)
Example #21
def alignment2counts(alignment: pysam.AlignmentFile,
                     sites: DefaultDict[Tuple[str, int], Set], unique: bool,
                     primary: bool, stranded: bool,
                     strand_mode: str) -> Counter[SiteWithOffset]:
    """
    Take an alignment and a dictionary of splice sites as input
    and return a dictionary of splice sites with their counts.
    """

    trans = prepare_ref_names(alignment)

    counts = Counter()
    segment: pysam.AlignedSegment  # just annotation line

    # iterating through the alignment file
    for segment in alignment.fetch():

        # check if read is mapped
        if segment.is_unmapped:
            continue

        # check if read is not multimapped
        if unique and segment.has_tag("NH") and segment.get_tag("NH") > 1:
            continue

        # if read is multimapped only primary alignment will be considered
        if primary and segment.is_secondary:
            continue

        # if read is multimapped only primary alignment will be considered
        if primary and segment.is_supplementary:
            continue

        ref_name = trans.get(segment.reference_name)
        if ref_name is None:
            continue

        # adding new counts to splice site
        segment2counts(segment=segment,
                       ref_name=ref_name,
                       sites=sites,
                       counts=counts,
                       stranded=stranded,
                       strand_mode=strand_mode)

    return counts
Example #22
def remove_idx_from_read_names(input_bam: Path):
    """ Replace READNAME:ALIGN_IDX with just READNAME

    Originally created because WhatsHap requires unique read names.
    """

    infile = AlignmentFile(input_bam, "rb")
    stdout = AlignmentFile("-", "wb", template=infile)
    align_iter = infile.fetch(until_eof=True)

    for read in align_iter:
        readname = read.query_name.split(":")[0]
        read.query_name = readname
        stdout.write(read)

    stdout.close()
    infile.close()
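A usage sketch (hypothetical path; the cleaned records are streamed to stdout as BAM):

from pathlib import Path

# Hypothetical BAM whose read names carry a ':ALIGN_IDX' suffix.
remove_idx_from_read_names(Path("indexed_names.bam"))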
Example #23
def pairs_to_telbam(af_pairs: AlignmentFile, af_telbam: AlignmentFile):
    read_iter = af_pairs.fetch(until_eof=True)
    while True:
        read_a = next(read_iter, None)
        if read_a is None:
            break
        read_b = next(read_iter)
        qseq = read_a.query_sequence
        if TEL_PATS[0] in qseq or TEL_PATS[1] in qseq:
            af_telbam.write(read_a)
            af_telbam.write(read_b)
        else:
            qseq = read_b.query_sequence
            if TEL_PATS[0] in qseq or TEL_PATS[1] in qseq:
                af_telbam.write(read_a)
                af_telbam.write(read_b)
    return
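A usage sketch (hypothetical paths; assumes TEL_PATS is defined in the surrounding module as the telomeric motifs to screen for, and that mates are adjacent in the name-grouped input):

from pysam import AlignmentFile

with AlignmentFile("name_grouped_pairs.bam", "rb") as af_pairs:              # hypothetical input
    with AlignmentFile("telbam.bam", "wb", template=af_pairs) as af_telbam:  # hypothetical output
        pairs_to_telbam(af_pairs, af_telbam)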
Example #24
def find_single_cds(cram: AlignmentFile, sequence_cds: CdsPos) -> str:
    """ Finds (presumed) cds sequence by parameters

    :param cram: pre-loaded file to be search
    :param sequence_cds: CDS location

    :return: a string with DNA symbols A, T, C, G
    """
    single_cds = ''

    size = 0
    # The CDS may be spliced from multiple exons.
    for cds_from, cds_to in sequence_cds.indexes:
        assert cds_from is not None
        assert cds_to is not None
        size += cds_to - cds_from

        region = '{}:{}-{}'.format(sequence_cds.molecule, cds_from, cds_to)
        index = cds_from
        for read in cram.fetch(region=region):
            if read.reference_start is None or read.reference_end is None:
                continue
            ref_from = read.positions[0]
            ref_to = read.positions[-1]
            assert ref_from is not None
            assert ref_to is not None

            if ref_from > index:
                missing_len = ref_from - index
                single_cds += '-' * missing_len
                index += missing_len

            if ref_from <= index < ref_to and index <= cds_to + 1:
                read_start = max(cds_from, index) - ref_from
                read_end = min(cds_to, ref_to) - ref_from
                single_cds += read.seq[read_start:read_end]
                index += read_end - read_start

            if index > cds_to:
                break

        # Fill missing remainder (if any) with -.
        single_cds += '-' * (size - len(single_cds))

    return single_cds
Example #25
def main():
    args = parse_args()
    with args.bam_file:
        bam_reader = AlignmentFile(args.bam_file)
        
        if not bam_reader.has_index():
            print('Adding index...')
            index_args = ['samtools', 'index', args.bam_file.name]
            run(index_args, check=True)
            args.bam_file.seek(0)  # Go back to start of header.
            bam_reader = AlignmentFile(args.bam_file)
            bam_reader.check_index()
        x = bam_reader.parse_region(region=args.target_region)
        sequences = sorted(bam_reader.fetch(region=args.target_region),
                           key=attrgetter('qname'))
        print(len(sequences))
        for seq in sequences:
            print(seq.qname)
Example #26
def fetch_count_read (alignment_file, seq_name, start, end):
    """
    Count the number of reads that at least partly overlap a specified chromosomal region
    @param alignment_file Path to a sam or a bam file
    @param seq_name Name of the sequence the reads are aligned on
    @param start Start genomic coordinate of the area of alignment
    @param end End genomic coordinate of the area of alignment
    """
    # Specific imports
    from pysam import AlignmentFile
    
    al = AlignmentFile(alignment_file, "rb")
    
    # Count read aligned at least partly on the specified region
    n = 0
    for i in al.fetch(seq_name, start, end):
        n += 1
    al.close()
    return n
Example #27
def add_idx_to_read_name(input_bam: Path):
    """ Changes the readname to be READNAME:ALIGN_IDX to have 'unique' readnames

    WhatsHap requires unique read names.
    """

    infile = AlignmentFile(input_bam, "rb")
    stdout = AlignmentFile("-", "wb", template=infile)
    align_iter = infile.fetch(until_eof=True)

    i = 0
    for read in align_iter:
        read.query_name = read.query_name + ":" + str(i)
        stdout.write(read)
        i = i + 1

    stdout.close()
    infile.close()
Example #28
def get_barcode_frequency_genomewide(bamfile, storage):
    """ This function obtains the barcode frequency
    and stores it in a table.

    Parameters
    ----------
    bamfile :  str
        Path to a bamfile. The bamfile must be indexed.
    storage : str
        Path to the output file (tab-separated), which contains the counts per barcode.
    """

    # Obtain the header information
    afile = AlignmentFile(bamfile, 'rb')

    if 'RG' in afile.header:
        use_group = True
    else:
        use_group = False

    barcodes = {}
    if use_group:
        # extract barcodes
        for idx, item in enumerate(afile.header['RG']):
            barcodes[item['ID']] = 0
    else:
        barcodes['dummy'] = 0
    print('found {} barcodes'.format(len(barcodes)))

    for aln in afile.fetch(until_eof=True):
        if aln.is_proper_pair and aln.is_read1:
            barcodes[aln.get_tag('RG') if use_group else 'dummy'] += 1

        if not aln.is_paired:
            barcodes[aln.get_tag('RG') if use_group else 'dummy'] += 1

    afile.close()

    names = [key for key in barcodes]
    counts = [barcodes[key] for key in barcodes]

    df = pd.DataFrame({'barcodes': names, 'counts': counts})

    df.to_csv(storage, sep='\t', header=True, index=False)
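A usage sketch (hypothetical RG-tagged BAM and output path; the counts come out as a tab-separated table):

get_barcode_frequency_genomewide('barcoded.bam', 'barcode_counts.tsv')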
Example #29
def gather_sv_data(options, collection):
    # Read regions of interest BED file
    regions = BedTool(options.region_file)

    # Read BAM file
    bamfile = AlignmentFile(options.bam_file, "rb")

    # Intersect regions
    for reg in regions:
        for read in bamfile.fetch(reg.chrom, reg.start, reg.end):
            #print read
            if read.query_name.endswith("2d"):
                collection[read.query_name] = []
            if read.query_name.startswith("ctg"):
                collection[read.query_name] = []
                #print read.reference_id, read.reference_start, read.reference_end
                #print read.query_name, read.query_alignment_start, read.query_alignment_end

    bamfile.close()
Example #30
def gather_sv_data(options, collection):
	# Read regions of interest BED file
	regions = BedTool(options.region_file)

	# Read BAM file
	bamfile = AlignmentFile(options.bam_file, "rb")

	# Intersect regions
	for reg in regions:
		for read in bamfile.fetch(reg.chrom, reg.start, reg.end):
			#print read
			if read.query_name.endswith("2d"):
				collection[read.query_name] = []
			if read.query_name.startswith("ctg"):
				collection[read.query_name] = []
				#print read.reference_id, read.reference_start, read.reference_end
				#print read.query_name, read.query_alignment_start, read.query_alignment_end

	bamfile.close()
Example #31
def alignment2junctions(
        alignment: pysam.AlignmentFile, unique: bool,
        primary: bool) -> Tuple[DefaultDict[RawJunction, List[int]], int]:
    """
    Find all junctions in alignment file and compute their counts.
    Also determine most common read length.

    :param alignment: alignment file
    :param unique: account only uniquely mapped (aligned) segments
    :param primary: account only primary alignment if segment is a multi-mapper
    :return: dictionary mapping junction to its counts in each offset and most common read length
    """
    segment: pysam.AlignedSegment  # just annotation line for convenience
    junctions_with_counts = defaultdict(lambda: [0] * 4)
    read_lengths_counter = Counter()

    # iterating through the alignment file
    for segment in alignment.fetch():

        # if read is not mapped
        if segment.is_unmapped:
            continue

        # if read is a multi-mapper
        if unique and segment.has_tag("NH") and segment.get_tag("NH") > 1:
            continue

        # if read is a multi-mapper consider only primary alignment
        if primary and segment.is_secondary:
            continue

        # if read is a multi-mapper consider only primary alignment
        if primary and segment.is_supplementary:
            continue

        read_lengths_counter[segment.infer_read_length()] += 1

        # adding new junctions if present and updating counts
        segment2junctions(segment=segment,
                          junctions_with_counts=junctions_with_counts)

    return junctions_with_counts, read_lengths_counter.most_common()[0][0]
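A brief usage sketch (hypothetical indexed BAM; assumes the module-level imports and the segment2junctions helper used above are available):

import pysam

with pysam.AlignmentFile("aligned.bam", "rb") as af:   # hypothetical indexed BAM
    junctions, read_length = alignment2junctions(af, unique=True, primary=True)
print(read_length, len(junctions))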
Example #32
def constructDistributions(bamName, lengths):
	'''
	Given a BAM file, constructs a coverage distribution for each long read
	Inputs
	- (str) bamName: BAM file name
	- (dict[(str) refName] = (int) read length) lengths: 
          returns the length of the long read given its read name
	Outputs
	- ( dict[(str) refName] = (numpy.array of ints) distribution ) dists: contains the coverage distributions 
          for each long read
	'''
	samfile = AlignmentFile(bamName, 'r')
	dists = {}
	for alignment in samfile.fetch():
		refName = alignment.reference_name
		start = int(alignment.reference_start)
		cigarTups = alignment.cigartuples
		updateDistribution(dists, lengths, refName, start, cigarTups)
	samfile.close()
	return dists
Example #33
def extract_barcode(sam, barcode_file, outdir):

    # Create the hash set for cell names
    barcodes_filtered = set()
    with open(barcode_file, 'r') as fin:
        for line in fin:
            barcodes_filtered.add(line.strip())

    print(len(barcodes_filtered))

    sam_file = AlignmentFile(sam, mode='r')
    #filter_file = AlignmentFile("-", mode='wh', template=sam_file)
    track = sam_file.fetch(until_eof=True)
    for i, aln in enumerate(track):
        # if aln.is_unmapped:
        # continue
        # print(i)
        # Note: query_sequence is used here rather than query_alignment_sequence.
        reads_name, reads, cell_barcode, umi, quality = aln.qname, aln.query_sequence, aln.get_tag(
            'XC'), aln.get_tag('XM'), aln.qual
        # print(reads_name, reads, cell_barcode, umi, quality)
        # print(reads)
        # print(quality)

        if cell_barcode in barcodes_filtered:
            # print(reads_name, reads, cell_barcode, umi, quality)
            if len(reads) != len(aln.qual):
                print("Error, skipped:", reads, quality)
                continue

            # Use context managers so the per-cell files are closed each iteration.
            with open(outdir + '/' + cell_barcode + '.umi', 'a+') as fout_umi:
                fout_umi.write(umi + '\n')

            with open(outdir + '/' + cell_barcode + '.fastq', 'a+') as fout_fq:
                fout_fq.write('@' + reads_name + '\n')
                fout_fq.write(reads + '\n')
                fout_fq.write('+\n')
                fout_fq.write(quality + '\n')
        if i % 100000 == 0:
            print(i / 209400000.0)
Example #34
File: snp_10x.py  Project: redst4r/scsnp
def get_bases_at_genomic_position(chrom: str,
                                  start: int,
                                  bamfile: pysam.AlignmentFile,
                                  ignore_softclipped=True):
    """
    for each read covering the given position,
    return readname, cell-barcode, umi and base
    ignore_softclipped: if the base of the read at the locus is softclipped, should we still return the read? (usually not as the base is not really aligned)
    """
    n_reads = 0
    reads_covering_position = []
    for alignment in bamfile.fetch(chrom, start, start + 1):
        n_reads += 1
        if start in alignment.get_reference_positions():

            if alignment.has_tag('CB') and alignment.has_tag('UB'):
                readname, cellbarcode, umi = parse_chromium_bamread_metadata(
                    alignment)

                base_index = pysam_reference_coordinate_2_query_coordinate(
                    alignment, start)

                if ignore_softclipped:

                    def cigar2str(a):
                        "just spells out the cigar for each base, ie a 124BP read gets a 124CIGAR string"
                        tmp = [[symbol] * freq for symbol, freq in a.cigar]
                        return "".join(str(_) for _ in itertools.chain(*tmp))

                    cig = cigar2str(alignment)
                    if cig[base_index] == '4':  # 4 is the CIGAR op code for soft-clipping
                        continue

                if alignment.is_duplicate or alignment.is_qcfail or alignment.is_secondary or alignment.mapping_quality != 255:  # 255 means uniquely mapped in STAR
                    continue

                base = alignment.query_sequence[base_index]
                reads_covering_position.append(
                    (readname, cellbarcode, umi, base))
    return reads_covering_position
Example #35
def getFragmentSizesFromChunkList(chunks, bamfile, lower, upper, atac=1):
    sizes = np.zeros(upper - lower, dtype=float)
    # loop over samfile
    bamHandle = AlignmentFile(bamfile)
    for chunk in chunks:
        for read in bamHandle.fetch(chunk.chrom, max(0, chunk.start - upper),
                                    chunk.end + upper):
            if read.is_proper_pair and not read.is_reverse:
                if atac:
                    #get left position
                    l_pos = read.pos + 4
                    #get insert size
                    #correct by 8 base pairs to be insertion to insertion
                    ilen = abs(read.template_length) - 8
                else:
                    l_pos = read.pos
                    ilen = abs(read.template_length)
                center = l_pos + (ilen - 1) // 2
                if ilen < upper and ilen >= lower and center >= chunk.start and center < chunk.end:
                    sizes[ilen - lower] += 1
    bamHandle.close()
    return sizes
Example #36
def count_mapped_bp(args, tempdir, genes):
    """ Count number of bp mapped to each gene across pangenomes.
    Return number covered genes and average gene depth per species.
    Result contains only covered species, but being a defaultdict,
    would yield 0 for any uncovered species, which is appropriate.
    """
    bam_path = f"{tempdir}/pangenomes.bam"
    bamfile = AlignmentFile(bam_path, "rb")
    covered_genes = {}

    # loop over alignments, sum values per gene
    for aln in bamfile.fetch(until_eof=True):
        gene_id = bamfile.getrname(aln.reference_id)
        gene = genes[gene_id]
        gene["aligned_reads"] += 1
        if keep_read(aln, args.aln_mapid, args.aln_readq, args.aln_mapq, args.aln_cov):
            gene["mapped_reads"] += 1
            gene["depth"] += len(aln.query_alignment_sequence) / float(gene["length"])
            covered_genes[gene_id] = gene

    tsprint("Pangenome count_mapped_bp:  total aligned reads: %s" % sum(g["aligned_reads"] for g in genes.values()))
    tsprint("Pangenome count_mapped_bp:  total mapped reads: %s" % sum(g["mapped_reads"] for g in genes.values()))

    # Filter to genes with non-zero depth, then group by species
    nonzero_gene_depths = defaultdict(list)
    for g in covered_genes.values():
        gene_depth = g["depth"]
        if gene_depth > 0:  # This should always pass, because args.aln_cov is always >0.
            species_id = g["species_id"]
            nonzero_gene_depths[species_id].append(gene_depth)

    # Compute number of covered genes per species, and average gene depth.
    num_covered_genes = defaultdict(int)
    mean_coverage = defaultdict(float)
    for species_id, non_zero_depths in nonzero_gene_depths.items():
        num_covered_genes[species_id] = len(non_zero_depths)
        mean_coverage[species_id] = np.mean(non_zero_depths)

    return num_covered_genes, mean_coverage, covered_genes
Example #37
def deduplicate_reads(bamin, bamout, tag='CB'):
    """Performs deduplication within barcodes/cells.

    Parameters
    ----------
    bamin : str
        Position sorted input bamfile.
    bamout : str
        Output file containing deduplicated reads.
    tag : str or callable
        Indicates the barcode tag or custom function to extract the barcode. Default: 'CB'

    Returns
    -------
    None

    """
    bamfile = AlignmentFile(bamin, 'rb')
    output = AlignmentFile(bamout, 'wb', template=bamfile)

    last_barcode = {}
    barcoder = Barcoder(tag)

    for aln in bamfile.fetch():
        # if the previous hash matches the current one
        # skip the read
        val = (aln.reference_id, aln.reference_start, aln.is_reverse, aln.tlen)
        barcode = barcoder(aln)

        if barcode not in last_barcode:
            output.write(aln)
            # remember the position for this barcode
            last_barcode[barcode] = val

        if val == last_barcode[barcode]:
            continue
        else:
            output.write(aln)
            last_barcode[barcode] = val
Example #38
File: umis.py  Project: Teichlab/umis
def tagcount(sam, out, genemap, output_evidence_table, positional, minevidence):
    ''' Count up evidence for tagged molecules
    '''
    from pysam import AlignmentFile
    from io import StringIO
    import pandas as pd

    from utils import weigh_evidence

    logger.info('Reading optional files')

    gene_map = None
    if genemap:
        with open(genemap) as fh:
            gene_map = dict(p.strip().split() for p in fh)

    if positional:
        tuple_template = '{0},{1},{2},{3}'
    else:
        tuple_template = '{0},{1},{3}'

    parser_re = re.compile('.*:CELL_(?P<CB>.*):UMI_(?P<MB>.*)')

    logger.info('Tallying evidence')
    start_tally = time.time()

    evidence = collections.defaultdict(int)

    sam_file = AlignmentFile(sam, mode='r')
    track = sam_file.fetch(until_eof=True)
    for i, aln in enumerate(track):
        if aln.is_unmapped:
            continue

        match = parser_re.match(aln.qname)
        CB = match.group('CB')
        MB = match.group('MB')

        txid = sam_file.getrname(aln.reference_id)
        if gene_map:
            target_name = gene_map[txid]

        else:
            target_name = txid

        e_tuple = tuple_template.format(CB, target_name, aln.pos, MB)

        # Scale evidence by number of hits
        evidence[e_tuple] += weigh_evidence(aln.tags)

    tally_time = time.time() - start_tally
    logger.info('Tally done - {:.3}s, {:,} alns/min'.format(tally_time, int(60. * i / tally_time)))
    logger.info('Collapsing evidence')

    buf = StringIO()
    for key in evidence:
        line = '{},{}\n'.format(key, evidence[key])
        buf.write(line)

    buf.seek(0)
    evidence_table = pd.read_csv(buf)
    evidence_query = 'evidence >= %f' % minevidence
    if positional:
        evidence_table.columns=['cell', 'gene', 'umi', 'pos', 'evidence']
        collapsed = evidence_table.query(evidence_query).groupby(['cell', 'gene'])['umi', 'pos'].size()

    else:
        evidence_table.columns=['cell', 'gene', 'umi', 'evidence']
        collapsed = evidence_table.query(evidence_query).groupby(['cell', 'gene'])['umi'].size()

    expanded = collapsed.unstack().T

    if gene_map:
        # This Series is just for sorting the index
        genes = pd.Series(index=set(gene_map.values()))
        genes = genes.sort_index()
        # Now genes is assigned to a DataFrame
        genes = expanded.ix[genes.index]

    else:
        genes = expanded

    genes.fillna(0, inplace=True)

    logger.info('Output results')

    if output_evidence_table:
        import shutil
        buf.seek(0)
        with open(output_evidence_table, 'w') as etab_fh:
            shutil.copyfileobj(buf, etab_fh)

    genes.to_csv(out)
Example #39
File: umis.py  Project: vals/umis
def tagcount(sam, out, genemap, output_evidence_table, positional, minevidence,
             cb_histogram, cb_cutoff, no_scale_evidence, subsample, sparse,
             parse_tags, gene_tags):
    ''' Count up evidence for tagged molecules
    '''
    from pysam import AlignmentFile

    from io import StringIO
    import pandas as pd

    from utils import weigh_evidence

    logger.info('Reading optional files')

    gene_map = None
    if genemap:
        with open(genemap) as fh:
            try:
                gene_map = dict(p.strip().split() for p in fh)
            except ValueError:
                logger.error('Incorrectly formatted gene_map, need to be tsv.')
                sys.exit()

    if positional:
        tuple_template = '{0},{1},{2},{3}'
    else:
        tuple_template = '{0},{1},{3}'

    if not cb_cutoff:
        cb_cutoff = 0

    if cb_histogram and cb_cutoff == "auto":
        cb_cutoff = guess_depth_cutoff(cb_histogram)

    cb_cutoff = int(cb_cutoff)

    cb_hist = None
    filter_cb = False
    if cb_histogram:
        cb_hist = pd.read_csv(cb_histogram, index_col=0, header=-1, squeeze=True, sep="\t")
        total_num_cbs = cb_hist.shape[0]
        cb_hist = cb_hist[cb_hist > cb_cutoff]
        logger.info('Keeping {} out of {} cellular barcodes.'.format(cb_hist.shape[0], total_num_cbs))
        filter_cb = True

    parser_re = re.compile('.*:CELL_(?P<CB>.*):UMI_(?P<MB>.*)')

    if subsample:
        logger.info('Creating reservoir of subsampled reads ({} per cell)'.format(subsample))
        start_sampling  = time.time()

        reservoir = collections.defaultdict(list)
        cb_hist_sampled = 0 * cb_hist
        cb_obs = 0 * cb_hist

        track = stream_bamfile(sam)
        current_read = 'none_observed_yet'
        for i, aln in enumerate(track):
            if aln.qname == current_read:
                continue

            current_read = aln.qname

            if parse_tags:
                CB = aln.get_tag('CR')
            else:
                match = parser_re.match(aln.qname)
                CB = match.group('CB')

            if CB not in cb_hist.index:
                continue

            cb_obs[CB] += 1
            if len(reservoir[CB]) < subsample:
                reservoir[CB].append(i)
                cb_hist_sampled[CB] += 1
            else:
                s = pd.np.random.randint(0, cb_obs[CB])
                if s < subsample:
                    reservoir[CB][s] = i

        index_filter = set(itertools.chain.from_iterable(reservoir.values()))
        sam_file.close()
        sampling_time = time.time() - start_sampling
        logger.info('Sampling done - {:.3}s'.format(sampling_time))

    evidence = collections.defaultdict(int)

    logger.info('Tallying evidence')
    start_tally = time.time()

    sam_mode = 'r' if sam.endswith(".sam") else 'rb'
    sam_file = AlignmentFile(sam, mode=sam_mode)
    targets = [x["SN"] for x in sam_file.header["SQ"]]
    track = sam_file.fetch(until_eof=True)
    count = 0
    unmapped = 0
    kept = 0
    nomatchcb = 0
    current_read = 'none_observed_yet'
    count_this_read = True
    missing_transcripts = set()
    for i, aln in enumerate(track):
        if count and not count % 1000000:
            logger.info("Processed %d alignments, kept %d." % (count, kept))
            logger.info("%d were filtered for being unmapped." % unmapped)
            if filter_cb:
                logger.info("%d were filtered for not matching known barcodes."
                            % nomatchcb)
        count += 1

        if aln.is_unmapped:
            unmapped += 1
            continue

        if gene_tags and not aln.has_tag('GX'):
            unmapped += 1
            continue

        if aln.qname != current_read:
            current_read = aln.qname
            if subsample and i not in index_filter:
                count_this_read = False
                continue
            else:
                count_this_read = True
        else:
            if not count_this_read:
                continue

        if parse_tags:
            CB = aln.get_tag('CR')
        else:
            match = parser_re.match(aln.qname)
            CB = match.group('CB')

        if filter_cb:
            if CB not in cb_hist.index:
                nomatchcb += 1
                continue

        if parse_tags:
            MB = aln.get_tag('UM')
        else:
            MB = match.group('MB')

        if gene_tags:
            target_name = aln.get_tag('GX').split(',')[0]
        else:
            txid = sam_file.getrname(aln.reference_id)
            if gene_map:
                if txid in gene_map:
                    target_name = gene_map[txid]
                else:
                    missing_transcripts.add(txid)
                    target_name = txid
            else:
                target_name = txid

        e_tuple = tuple_template.format(CB, target_name, aln.pos, MB)

        # Scale evidence by number of hits
        if no_scale_evidence:
            evidence[e_tuple] += 1.0
        else:
            evidence[e_tuple] += weigh_evidence(aln.tags)

        kept += 1

    tally_time = time.time() - start_tally
    if missing_transcripts:
        logger.warn('The following transcripts were missing gene_ids, so we added them as the transcript ids: %s' % str(missing_transcripts))
    logger.info('Tally done - {:.3}s, {:,} alns/min'.format(tally_time, int(60. * count / tally_time)))
    logger.info('Collapsing evidence')

    logger.info('Writing evidence')
    with tempfile.NamedTemporaryFile('w+t') as out_handle:
        for key in evidence:
            line = '{},{}\n'.format(key, evidence[key])
            out_handle.write(line)

        out_handle.flush()
        out_handle.seek(0)

        evidence_table = pd.read_csv(out_handle, header=None)

    del evidence

    evidence_query = 'evidence >= %f' % minevidence
    if positional:
        evidence_table.columns=['cell', 'gene', 'umi', 'pos', 'evidence']
        collapsed = evidence_table.query(evidence_query).groupby(['cell', 'gene'])['umi', 'pos'].size()

    else:
        evidence_table.columns=['cell', 'gene', 'umi', 'evidence']
        collapsed = evidence_table.query(evidence_query).groupby(['cell', 'gene'])['umi'].size()

    expanded = collapsed.unstack().T

    if gene_map:
        # This Series is just for sorting the index
        genes = pd.Series(index=set(gene_map.values()))
        genes = genes.sort_index()
        # Now genes is assigned to a DataFrame
        genes = expanded.ix[genes.index]

    elif gene_tags:
        expanded.sort_index()
        genes = expanded

    else:
        # make data frame have a complete accounting of transcripts
        targets = pd.Series(index=set(targets))
        targets = targets.sort_index()
        expanded = expanded.reindex(targets.index.values, fill_value=0)
        genes = expanded

    genes.fillna(0, inplace=True)
    genes = genes.astype(int)
    genes.index.name = "gene"

    logger.info('Output results')

    if subsample:
        cb_hist_sampled.to_csv('ss_{}_'.format(subsample) + os.path.basename(cb_histogram), sep='\t')

    if output_evidence_table:
        import shutil
        buf.seek(0)
        with open(output_evidence_table, 'w') as etab_fh:
            shutil.copyfileobj(buf, etab_fh)

    if sparse:
        pd.Series(genes.index).to_csv(out + ".rownames", index=False, header=False)
        pd.Series(genes.columns.values).to_csv(out + ".colnames", index=False, header=False)
        with open(out, "w+b") as out_handle:
            scipy.io.mmwrite(out_handle, scipy.sparse.csr_matrix(genes))
    else:
        genes.to_csv(out)
Example #40
File: umis.py  Project: vals/umis
def fasttagcount(sam, out, genemap, positional, minevidence, cb_histogram, 
                 cb_cutoff, subsample, parse_tags, gene_tags, umi_matrix):
    ''' Count up evidence for tagged molecules. This implementation assumes the
    alignment file is coordinate sorted.
    '''
    from pysam import AlignmentFile

    from io import StringIO
    import pandas as pd

    from utils import weigh_evidence

    if sam.endswith(".sam"):
        logger.error("To use the fasttagcount subcommand, the alignment file must be a "
                     "coordinate sorted, indexed BAM file.")
        sys.exit(1)

    logger.info('Reading optional files')

    gene_map = None
    if genemap:
        with open(genemap) as fh:
            try:
                gene_map = dict(p.strip().split() for p in fh)
            except ValueError:
                logger.error('Incorrectly formatted gene_map; it must be a two-column TSV.')
                sys.exit()

    if positional:
        tuple_template = '{0},{1},{2},{3}'
    else:
        tuple_template = '{0},{1},{3}'

    if not cb_cutoff:
        cb_cutoff = 0

    if cb_histogram and cb_cutoff == "auto":
        cb_cutoff = guess_depth_cutoff(cb_histogram)

    cb_cutoff = int(cb_cutoff)

    cb_hist = None
    filter_cb = False
    if cb_histogram:
        cb_hist = pd.read_csv(cb_histogram, index_col=0, header=None, sep="\t").squeeze("columns")
        total_num_cbs = cb_hist.shape[0]
        cb_hist = cb_hist[cb_hist > cb_cutoff]
        logger.info('Keeping {} out of {} cellular barcodes.'.format(cb_hist.shape[0], total_num_cbs))
        filter_cb = True

    parser_re = re.compile('.*:CELL_(?P<CB>.*):UMI_(?P<MB>.*)')
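    # e.g. a read named 'READID:CELL_ACGTACGT:UMI_TTAGG' parses to CB='ACGTACGT', MB='TTAGG'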

    if subsample:
        logger.info('Creating reservoir of subsampled reads ({} per cell)'.format(subsample))
        start_sampling  = time.time()

        reservoir = collections.defaultdict(list)
        cb_hist_sampled = 0 * cb_hist
        cb_obs = 0 * cb_hist

        sam_file = AlignmentFile(sam, mode='rb')
        track = sam_file.fetch(until_eof=True)
        current_read = 'none_observed_yet'
        for i, aln in enumerate(track):
            if aln.qname == current_read:
                continue

            current_read = aln.qname

            if parse_tags:
                CB = aln.get_tag('CR')
            else:
                match = parser_re.match(aln.qname)
                CB = match.group('CB')

            if CB not in cb_hist.index:
                continue

            cb_obs[CB] += 1
            if len(reservoir[CB]) < subsample:
                reservoir[CB].append(i)
                cb_hist_sampled[CB] += 1
            else:
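                # reservoir sampling: once the reservoir is full, replace a random
                # slot with probability subsample / cb_obs[CB]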
                s = np.random.randint(0, cb_obs[CB])
                if s < subsample:
                    reservoir[CB][s] = i

        index_filter = set(itertools.chain.from_iterable(reservoir.values()))
        sam_file.close()
        sampling_time = time.time() - start_sampling
        logger.info('Sampling done - {:.3}s'.format(sampling_time))

    evidence = collections.defaultdict(lambda: collections.defaultdict(float))
    bare_evidence = collections.defaultdict(float)
    logger.info('Tallying evidence')
    start_tally = time.time()

    sam_mode = 'r' if sam.endswith(".sam") else 'rb'
    sam_file = AlignmentFile(sam, mode=sam_mode)
    transcript_map = collections.defaultdict(set)
    sam_transcripts = [x["SN"] for x in sam_file.header["SQ"]]
    if gene_map:
        for transcript, gene in gene_map.items():
            if transcript in sam_transcripts:
                transcript_map[gene].add(transcript)
    else:
        for transcript in sam_transcripts:
            transcript_map[transcript].add(transcript)
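    # transcript_map: gene -> transcripts present in the BAM header, so the main
    # loop below emits exactly one output row per gene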
    missing_transcripts = set()
    alignments_processed = 0
    unmapped = 0
    kept = 0
    nomatchcb = 0
    current_read = 'none_observed_yet'
    current_transcript = None
    count_this_read = True
    transcripts_processed = 0
    genes_processed = 0
    cells = list(cb_hist.index)
    targets_seen = set()

    if umi_matrix:
        bare_evidence_handle = open(umi_matrix, "w")
        bare_evidence_handle.write(",".join(["gene"] + cells) + "\n")

    with open(out, "w") as out_handle:
        out_handle.write(",".join(["gene"] + cells) + "\n")
        for gene, transcripts in transcript_map.items():
            for transcript in transcripts:
                for aln in sam_file.fetch(transcript):
                    alignments_processed += 1

                    if aln.is_unmapped:
                        unmapped += 1
                        continue

                    if gene_tags and not aln.has_tag('GX'):
                        unmapped += 1
                        continue

                    if aln.qname != current_read:
                        current_read = aln.qname
                        if subsample and i not in index_filter:
                            count_this_read = False
                            continue
                        else:
                            count_this_read = True
                    else:
                        if not count_this_read:
                            continue

                    if parse_tags:
                        CB = aln.get_tag('CR')
                    else:
                        match = parser_re.match(aln.qname)
                        CB = match.group('CB')

                    if filter_cb:
                        if CB not in cb_hist.index:
                            nomatchcb += 1
                            continue

                    if parse_tags:
                        MB = aln.get_tag('UM')
                    else:
                        MB = match.group('MB')

                    if gene_tags:
                        target_name = aln.get_tag('GX').split(',')[0]
                    else:
                        txid = sam_file.getrname(aln.reference_id)
                        if gene_map:
                            if txid in gene_map:
                                target_name = gene_map[txid]
                            else:
                                missing_transcripts.add(txid)
                                continue
                        else:
                            target_name = txid
                    targets_seen.add(target_name)

                    # Scale evidence by number of hits
                    evidence[CB][MB] += weigh_evidence(aln.tags)
                    bare_evidence[CB] += weigh_evidence(aln.tags)
                    kept += 1
                transcripts_processed += 1
                if not transcripts_processed % 1000:
                    logger.info("%d genes processed." % genes_processed)
                    logger.info("%d transcripts processed." % transcripts_processed)
                    logger.info("%d alignments processed." % alignments_processed)

            earray = []
            for cell in cells:
                umis = [1 for _, v in evidence[cell].items() if v >= minevidence]
                earray.append(str(sum(umis)))
            out_handle.write(",".join([gene] + earray) + "\n")
            earray = []
            if umi_matrix:
                for cell in cells:
                    earray.append(str(int(bare_evidence[cell])))
                bare_evidence_handle.write(",".join([gene] + earray) + "\n")

            # reset per-gene evidence so memory stays bounded by a single gene
            evidence = collections.defaultdict(lambda: collections.defaultdict(float))
            bare_evidence = collections.defaultdict(float)
            genes_processed += 1

    if umi_matrix:
        bare_evidence_handle.close()

    # fill dataframe with missing values, sort and output
    df = pd.read_csv(out, index_col=0, header=0)
    targets = pd.Series(index=sorted(transcript_map.keys()), dtype=float)
    df = df.reindex(targets.index.values, fill_value=0)
    df = df.sort_index()
    df.to_csv(out)

    if umi_matrix:
        df = pd.read_csv(umi_matrix, index_col=0, header=0)
        df = df.reindex(targets.index.values, fill_value=0)
        df = df.sort_index()
        df.to_csv(umi_matrix)
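
fasttagcount above walks the alignment file one transcript at a time with sam_file.fetch(transcript), which is why it insists on a coordinate-sorted, indexed BAM. A minimal sketch of that access pattern with pysam (the file name is a placeholder):

from pysam import AlignmentFile

# fetch(reference) does random-access lookups, which needs a BAM index (.bai),
# and samtools will only index a coordinate-sorted BAM
bam = AlignmentFile("aligned.sorted.bam", "rb")  # hypothetical input
for reference in bam.references:
    n_alignments = sum(1 for _ in bam.fetch(reference))
    print(reference, n_alignments)
bam.close()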
Example #41
0
File: umis.py Project: roryk/umis
def tagcount(sam, out, genemap, output_evidence_table, positional, minevidence,
             cb_histogram, cb_cutoff, no_scale_evidence, subsample):
    ''' Count up evidence for tagged molecules
    '''
    from pysam import AlignmentFile

    from io import StringIO
    import numpy as np
    import pandas as pd

    from utils import weigh_evidence

    logger.info('Reading optional files')

    gene_map = None
    if genemap:
        with open(genemap) as fh:
            try:
                gene_map = dict(p.strip().split() for p in fh)
            except ValueError:
                logger.error('Incorrectly formatted gene_map; it must be a two-column TSV.')
                sys.exit()

    if positional:
        tuple_template = '{0},{1},{2},{3}'
    else:
        tuple_template = '{0},{1},{3}'
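    # evidence keys become 'CB,target,pos,UMI' when positional, 'CB,target,UMI' otherwise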

    if not cb_cutoff:
        cb_cutoff = 0

    if cb_histogram and cb_cutoff == "auto":
        cb_cutoff = guess_depth_cutoff(cb_histogram)

    cb_cutoff = int(cb_cutoff)

    cb_hist = None
    filter_cb = False
    if cb_histogram:
        cb_hist = pd.read_csv(cb_histogram, index_col=0, header=None, sep="\t").squeeze("columns")
        total_num_cbs = cb_hist.shape[0]
        cb_hist = cb_hist[cb_hist > cb_cutoff]
        logger.info('Keeping {} out of {} cellular barcodes.'.format(cb_hist.shape[0], total_num_cbs))
        filter_cb = True

    parser_re = re.compile('.*:CELL_(?P<CB>.*):UMI_(?P<MB>.*)')

    if subsample:
        logger.info('Creating reservoir of subsampled reads ({} per cell)'.format(subsample))
        start_sampling  = time.time()

        reservoir = collections.defaultdict(list)
        cb_hist_sampled = 0 * cb_hist
        cb_obs = 0 * cb_hist

        sam_mode = 'r' if sam.endswith(".sam") else 'rb'
        sam_file = AlignmentFile(sam, mode=sam_mode)
        track = sam_file.fetch(until_eof=True)
        current_read = 'none_observed_yet'
        for i, aln in enumerate(track):
            if aln.qname == current_read:
                continue

            current_read = aln.qname
            match = parser_re.match(aln.qname)
            CB = match.group('CB')

            if CB not in cb_hist.index:
                continue

            cb_obs[CB] += 1
            if len(reservoir[CB]) < subsample:
                reservoir[CB].append(i)
                cb_hist_sampled[CB] += 1
            else:
                s = np.random.randint(0, cb_obs[CB])
                if s < subsample:
                    reservoir[CB][s] = i

        index_filter = set(itertools.chain.from_iterable(reservoir.values()))
        sam_file.close()
        sampling_time = time.time() - start_sampling
        logger.info('Sampling done - {:.3}s'.format(sampling_time))

    evidence = collections.defaultdict(int)

    logger.info('Tallying evidence')
    start_tally = time.time()

    sam_mode = 'r' if sam.endswith(".sam") else 'rb'
    sam_file = AlignmentFile(sam, mode=sam_mode)
    track = sam_file.fetch(until_eof=True)
    count = 0
    unmapped = 0
    kept = 0
    nomatchcb = 0
    current_read = 'none_observed_yet'
    count_this_read = True
    for i, aln in enumerate(track):
        count += 1
        if not count % 100000:
            logger.info("Processed %d alignments, kept %d." % (count, kept))
            logger.info("%d were filtered for being unmapped." % unmapped)
            if filter_cb:
                logger.info("%d were filtered for not matching known barcodes."
                            % nomatchcb)

        if aln.is_unmapped:
            unmapped += 1
            continue

        if aln.qname != current_read:
            current_read = aln.qname
            if subsample and i not in index_filter:
                count_this_read = False
                continue
            else:
                count_this_read = True
        else:
            if not count_this_read:
                continue

        match = parser_re.match(aln.qname)
        CB = match.group('CB')
        if filter_cb:
            if CB not in cb_hist.index:
                nomatchcb += 1
                continue

        MB = match.group('MB')

        txid = sam_file.getrname(aln.reference_id)
        if gene_map:
            target_name = gene_map[txid]

        else:
            target_name = txid

        e_tuple = tuple_template.format(CB, target_name, aln.pos, MB)

        # Scale evidence by number of hits
        if no_scale_evidence:
            evidence[e_tuple] += 1.0
        else:
            evidence[e_tuple] += weigh_evidence(aln.tags)
        kept += 1

    tally_time = time.time() - start_tally
    logger.info('Tally done - {:.3}s, {:,} alns/min'.format(tally_time, int(60. * count / tally_time)))
    logger.info('Collapsing evidence')

    buf = StringIO()
    for key in evidence:
        line = '{},{}\n'.format(key, evidence[key])
        buf.write(line)

    buf.seek(0)
    evidence_table = pd.read_csv(buf, header=None)
    evidence_query = 'evidence >= %f' % minevidence
    if positional:
        evidence_table.columns=['cell', 'gene', 'umi', 'pos', 'evidence']
        collapsed = evidence_table.query(evidence_query).groupby(['cell', 'gene'])[['umi', 'pos']].size()

    else:
        evidence_table.columns=['cell', 'gene', 'umi', 'evidence']
        collapsed = evidence_table.query(evidence_query).groupby(['cell', 'gene'])['umi'].size()

    expanded = collapsed.unstack().T

    if gene_map:
        # This Series is just for providing a sorted gene index
        genes = pd.Series(index=sorted(set(gene_map.values())), dtype=float)
        # Reindex so mapped genes with no evidence appear as all-NaN rows (filled with 0 below)
        genes = expanded.reindex(genes.index)

    else:
        genes = expanded

    genes.fillna(0, inplace=True)

    logger.info('Output results')

    if subsample:
        cb_hist_sampled.to_csv('ss_{}_'.format(subsample) + os.path.basename(cb_histogram), sep='\t')

    if output_evidence_table:
        import shutil
        buf.seek(0)
        with open(output_evidence_table, 'w') as etab_fh:
            shutil.copyfileobj(buf, etab_fh)

    genes.to_csv(out)
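
Both tagcount variants lean on a weigh_evidence helper imported from utils, which is not included in these examples. Going only by the "Scale evidence by number of hits" comment, a plausible sketch would down-weight multi-mapped reads by the NH tag; this is a hypothetical stand-in, not necessarily the actual umis implementation:

def weigh_evidence(tags):
    # tags is the list of (tag, value) pairs from pysam's aln.tags;
    # NH is the number of reported alignments for the read (1 for unique hits)
    tag_dict = dict(tags)
    return 1.0 / float(tag_dict.get('NH', 1))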