Ejemplo n.º 1
0
def tagcount(sam, out, genemap, output_evidence_table, positional, minevidence):
    ''' Count up evidence for tagged molecules

    Every mapped alignment in ``sam`` adds a weight (from ``weigh_evidence``)
    to a ``cell,target[,position],umi`` key.  The cell barcode and the UMI are
    parsed from read names of the form ``...:CELL_<CB>:UMI_<MB>``.  Keys whose
    accumulated evidence is at least ``minevidence`` are counted per
    (cell, target) and written to ``out`` as a CSV matrix with targets as rows
    and cells as columns.

    Arguments:
        sam: path given to ``pysam.AlignmentFile`` (opened with mode ``'r'``).
        out: destination handed to ``DataFrame.to_csv``.
        genemap: optional path to a whitespace separated two column file
            mapping reference names to gene names; falsy to count per
            reference.
        output_evidence_table: optional path; when set the raw
            ``key,evidence`` lines are copied to it.
        positional: when true the alignment position is part of the key.
        minevidence: minimum accumulated evidence for a key to be counted.
    '''
    from pysam import AlignmentFile
    from cStringIO import StringIO
    import pandas as pd

    from utils import weigh_evidence

    logger.info('Reading optional files')

    gene_map = None
    if genemap:
        with open(genemap) as fh:
            gene_map = dict(p.strip().split() for p in fh)

    if positional:
        tuple_template = '{0},{1},{2},{3}'
    else:
        tuple_template = '{0},{1},{3}'

    parser_re = re.compile('.*:CELL_(?P<CB>.*):UMI_(?P<MB>.*)')

    logger.info('Tallying evidence')
    start_tally = time.time()

    evidence = collections.defaultdict(int)

    # Counted explicitly: the loop variable is undefined for an empty file and
    # is one short of the number of alignments otherwise.
    count = 0
    sam_file = AlignmentFile(sam, mode='r')
    try:
        for aln in sam_file.fetch(until_eof=True):
            count += 1
            if aln.is_unmapped:
                continue

            # NOTE(review): a read name that does not follow the
            # CELL_/UMI_ pattern leaves `match` as None and raises here.
            match = parser_re.match(aln.qname)
            CB = match.group('CB')
            MB = match.group('MB')

            txid = sam_file.getrname(aln.reference_id)
            if gene_map:
                target_name = gene_map[txid]
            else:
                target_name = txid

            e_tuple = tuple_template.format(CB, target_name, aln.pos, MB)

            # Scale evidence by number of hits
            evidence[e_tuple] += weigh_evidence(aln.tags)
    finally:
        sam_file.close()

    tally_time = time.time() - start_tally
    logger.info('Tally done - {:.3}s, {:,} alns/min'.format(tally_time, int(60. * count / tally_time)))
    logger.info('Collapsing evidence')

    buf = StringIO()
    for key in evidence:
        line = '{},{}\n'.format(key, evidence[key])
        buf.write(line)

    buf.seek(0)
    # The buffer has no header line; without header=None the first evidence
    # row would be swallowed as column names and never counted.
    evidence_table = pd.read_csv(buf, header=None)
    evidence_query = 'evidence >= %f' % minevidence
    if positional:
        # Same order as tuple_template: the position comes before the UMI
        evidence_table.columns = ['cell', 'gene', 'pos', 'umi', 'evidence']
    else:
        evidence_table.columns = ['cell', 'gene', 'umi', 'evidence']

    # Each row is one distinct key, so the group size is the molecule count
    collapsed = evidence_table.query(evidence_query).groupby(
        ['cell', 'gene']).size()

    expanded = collapsed.unstack().T

    if gene_map:
        # This Series is just for sorting the index
        genes = pd.Series(index=set(gene_map.values()))
        genes = genes.sort_index()
        # Now genes is assigned to a DataFrame
        # NOTE(review): `.ix` only exists in older pandas releases.
        genes = expanded.ix[genes.index]

    else:
        genes = expanded

    genes.replace(pd.np.nan, 0, inplace=True)

    logger.info('Output results')

    if output_evidence_table:
        import shutil
        buf.seek(0)
        with open(output_evidence_table, 'w') as etab_fh:
            shutil.copyfileobj(buf, etab_fh)

    genes.to_csv(out)
Ejemplo n.º 2
0
def tagcount(sam, out, genemap, output_evidence_table, positional, minevidence,
             cb_histogram, cb_cutoff, no_scale_evidence, subsample, sparse):
    ''' Count up evidence for tagged molecules

    Every mapped alignment in ``sam`` adds evidence to a
    ``cell,target[,position],umi`` key.  The cell barcode and the UMI are
    parsed from read names of the form ``...:CELL_<CB>:UMI_<MB>``.  Keys whose
    accumulated evidence is at least ``minevidence`` are counted per
    (cell, target) and written to ``out``.

    Arguments:
        sam: alignment file; opened as SAM when the name ends in ``.sam``,
            as BAM otherwise.
        out: output path for the count matrix.
        genemap: optional path to a whitespace separated two column file
            mapping reference names to gene names.
        output_evidence_table: optional path; when set the raw
            ``key,evidence`` lines are copied to it.
        positional: when true the alignment position is part of the key.
        minevidence: minimum accumulated evidence for a key to be counted.
        cb_histogram: optional tab separated barcode/count file; only
            barcodes with a count above ``cb_cutoff`` are kept.
        cb_cutoff: count threshold, or ``"auto"`` to let
            ``guess_depth_cutoff`` choose one; falsy means 0.
        no_scale_evidence: when true every alignment counts 1.0 instead of
            the value returned by ``weigh_evidence``.
        subsample: optional number of reads to sample per cell barcode
            (requires ``cb_histogram``).
        sparse: when true write a Matrix Market file plus ``.rownames`` and
            ``.colnames`` files instead of a CSV.
    '''
    from pysam import AlignmentFile

    import pandas as pd

    from utils import weigh_evidence

    logger.info('Reading optional files')

    gene_map = None
    if genemap:
        with open(genemap) as fh:
            try:
                gene_map = dict(p.strip().split() for p in fh)
            except ValueError:
                logger.error('Incorrectly formatted gene_map, need to be tsv.')
                # Non-zero status so calling scripts see the failure
                sys.exit(1)

    if positional:
        tuple_template = '{0},{1},{2},{3}'
    else:
        tuple_template = '{0},{1},{3}'

    if not cb_cutoff:
        cb_cutoff = 0

    if cb_histogram and cb_cutoff == "auto":
        cb_cutoff = guess_depth_cutoff(cb_histogram)

    cb_cutoff = int(cb_cutoff)

    cb_hist = None
    filter_cb = False
    if cb_histogram:
        cb_hist = pd.read_table(cb_histogram,
                                index_col=0,
                                header=-1,
                                squeeze=True)
        total_num_cbs = cb_hist.shape[0]
        cb_hist = cb_hist[cb_hist > cb_cutoff]
        logger.info('Keeping {} out of {} cellular barcodes.'.format(
            cb_hist.shape[0], total_num_cbs))
        filter_cb = True

    parser_re = re.compile('.*:CELL_(?P<CB>.*):UMI_(?P<MB>.*)')

    if subsample:
        if cb_hist is None:
            # The per-cell reservoirs are sized from the barcode histogram
            logger.error(
                'Subsampling requires a cellular barcode histogram.')
            sys.exit(1)

        logger.info(
            'Creating reservoir of subsampled reads ({} per cell)'.format(
                subsample))
        start_sampling = time.time()

        reservoir = collections.defaultdict(list)
        cb_hist_sampled = 0 * cb_hist
        cb_obs = 0 * cb_hist

        track = stream_bamfile(sam)
        current_read = 'none_observed_yet'
        for i, aln in enumerate(track):
            if aln.qname == current_read:
                continue

            current_read = aln.qname
            match = parser_re.match(aln.qname)
            CB = match.group('CB')

            if CB not in cb_hist.index:
                continue

            cb_obs[CB] += 1
            if len(reservoir[CB]) < subsample:
                reservoir[CB].append(i)
                cb_hist_sampled[CB] += 1
            else:
                # Replace a random slot with probability subsample / seen
                s = pd.np.random.randint(0, cb_obs[CB])
                if s < subsample:
                    reservoir[CB][s] = i

        index_filter = set(itertools.chain.from_iterable(reservoir.values()))
        # No file handle is closed here: stream_bamfile only hands back the
        # alignment iterator, and `sam_file` does not exist yet at this point.
        sampling_time = time.time() - start_sampling
        logger.info('Sampling done - {:.3}s'.format(sampling_time))

    evidence = collections.defaultdict(int)

    logger.info('Tallying evidence')
    start_tally = time.time()

    sam_mode = 'r' if sam.endswith(".sam") else 'rb'
    sam_file = AlignmentFile(sam, mode=sam_mode)
    count = 0
    unmapped = 0
    kept = 0
    nomatchcb = 0
    current_read = 'none_observed_yet'
    count_this_read = True
    try:
        targets = [x["SN"] for x in sam_file.header["SQ"]]
        track = sam_file.fetch(until_eof=True)
        for i, aln in enumerate(track):
            if count and not count % 100000:
                logger.info("Processed %d alignments, kept %d." %
                            (count, kept))
                logger.info("%d were filtered for being unmapped." % unmapped)
                if filter_cb:
                    logger.info(
                        "%d were filtered for not matching known barcodes." %
                        nomatchcb)
            count += 1

            if aln.is_unmapped:
                unmapped += 1
                continue

            # The sampling decision is made once per run of alignments that
            # share a read name and reused for the rest of that run.
            if aln.qname != current_read:
                current_read = aln.qname
                if subsample and i not in index_filter:
                    count_this_read = False
                    continue
                else:
                    count_this_read = True
            else:
                if not count_this_read:
                    continue

            match = parser_re.match(aln.qname)
            CB = match.group('CB')
            if filter_cb:
                if CB not in cb_hist.index:
                    nomatchcb += 1
                    continue

            MB = match.group('MB')

            txid = sam_file.getrname(aln.reference_id)
            if gene_map:
                target_name = gene_map[txid]

            else:
                target_name = txid

            e_tuple = tuple_template.format(CB, target_name, aln.pos, MB)

            # Scale evidence by number of hits
            if no_scale_evidence:
                evidence[e_tuple] += 1.0
            else:
                evidence[e_tuple] += weigh_evidence(aln.tags)
            kept += 1
    finally:
        sam_file.close()

    tally_time = time.time() - start_tally
    logger.info('Tally done - {:.3}s, {:,} alns/min'.format(
        tally_time, int(60. * count / tally_time)))
    logger.info('Collapsing evidence')

    logger.info('Writing evidence')
    with tempfile.NamedTemporaryFile('w+t') as out_handle:
        for key in evidence:
            line = '{},{}\n'.format(key, evidence[key])
            out_handle.write(line)

        out_handle.flush()

        if output_evidence_table:
            # Copied while the temporary file is still open; it is the only
            # place the raw evidence lines exist.
            import shutil
            out_handle.seek(0)
            with open(output_evidence_table, 'w') as etab_fh:
                shutil.copyfileobj(out_handle, etab_fh)

        out_handle.seek(0)
        evidence_table = pd.read_csv(out_handle, header=None)

    del evidence

    evidence_query = 'evidence >= %f' % minevidence
    if positional:
        # Same order as tuple_template: the position comes before the UMI
        evidence_table.columns = ['cell', 'gene', 'pos', 'umi', 'evidence']
    else:
        evidence_table.columns = ['cell', 'gene', 'umi', 'evidence']

    # Each row is one distinct key, so the group size is the molecule count
    collapsed = evidence_table.query(evidence_query).groupby(
        ['cell', 'gene']).size()

    expanded = collapsed.unstack().T

    if gene_map:
        # This Series is just for sorting the index
        genes = pd.Series(index=set(gene_map.values()))
        genes = genes.sort_index()
        # Now genes is assigned to a DataFrame
        # NOTE(review): `.ix` only exists in older pandas releases.
        genes = expanded.ix[genes.index]

    else:
        # make data frame have a complete accounting of transcripts
        targets = pd.Series(index=set(targets))
        targets = targets.sort_index()
        expanded = expanded.reindex(targets.index.values, fill_value=0)
        genes = expanded

    genes.replace(pd.np.nan, 0, inplace=True)

    logger.info('Output results')

    if subsample:
        cb_hist_sampled.to_csv('ss_{}_'.format(subsample) +
                               os.path.basename(cb_histogram),
                               sep='\t')

    if sparse:
        pd.Series(genes.index).to_csv(out + ".rownames", index=False)
        pd.Series(genes.columns.values).to_csv(out + ".colnames", index=False)
        with open(out, "w+b") as out_handle:
            scipy.io.mmwrite(out_handle, scipy.sparse.csr_matrix(genes))
    else:
        genes.to_csv(out)
Ejemplo n.º 3
0
Archivo: umis.py Proyecto: lylamha/umis
def tagcount(sam, out, genemap, output_evidence_table, positional, minevidence,
             cb_histogram, cb_cutoff, no_scale_evidence):
    ''' Count up evidence for tagged molecules

    Every mapped alignment in ``sam`` adds evidence to a
    ``cell,target[,position],umi`` key.  The cell barcode and the UMI are
    parsed from read names of the form ``...:CELL_<CB>:UMI_<MB>``.  Keys whose
    accumulated evidence is at least ``minevidence`` are counted per
    (cell, target) and written to ``out`` as a CSV matrix with targets as rows
    and cells as columns.

    Arguments:
        sam: alignment file; opened as SAM when the name ends in ``.sam``,
            as BAM otherwise.
        out: destination handed to ``DataFrame.to_csv``.
        genemap: optional path to a whitespace separated two column file
            mapping reference names to gene names.
        output_evidence_table: optional path; when set the raw
            ``key,evidence`` lines are copied to it.
        positional: when true the alignment position is part of the key.
        minevidence: minimum accumulated evidence for a key to be counted.
        cb_histogram: optional whitespace separated barcode/count file; only
            barcodes with a count above ``cb_cutoff`` are kept.
        cb_cutoff: count threshold compared against the histogram counts.
        no_scale_evidence: when true every alignment counts 1.0 instead of
            the value returned by ``weigh_evidence``.
    '''
    from pysam import AlignmentFile
    from cStringIO import StringIO
    import pandas as pd

    from utils import weigh_evidence

    logger.info('Reading optional files')

    gene_map = None
    if genemap:
        with open(genemap) as fh:
            gene_map = dict(p.strip().split() for p in fh)

    if positional:
        tuple_template = '{0},{1},{2},{3}'
    else:
        tuple_template = '{0},{1},{3}'

    # A separate flag keeps filtering active even when the cutoff leaves no
    # barcode at all; testing the set's truthiness would disable it then.
    filter_cb = False
    cb_set = set()
    if cb_histogram:
        with open(cb_histogram) as fh:
            cb_map = dict(p.strip().split() for p in fh)
        cb_set = set(k for k, v in cb_map.items() if int(v) > cb_cutoff)
        # Kept barcodes first, total second, as the message reads
        logger.info('Keeping %d out of %d cellular barcodes.' %
                    (len(cb_set), len(cb_map)))
        filter_cb = True

    parser_re = re.compile('.*:CELL_(?P<CB>.*):UMI_(?P<MB>.*)')

    logger.info('Tallying evidence')
    start_tally = time.time()

    evidence = collections.defaultdict(int)

    sam_mode = 'r' if sam.endswith(".sam") else 'rb'
    sam_file = AlignmentFile(sam, mode=sam_mode)
    count = 0
    kept = 0
    try:
        for aln in sam_file.fetch(until_eof=True):
            count += 1
            if not count % 100000:
                logger.info("Processed %d alignments, kept %d." %
                            (count, kept))

            if aln.is_unmapped:
                continue

            match = parser_re.match(aln.qname)
            CB = match.group('CB')
            if filter_cb and CB not in cb_set:
                continue
            MB = match.group('MB')

            txid = sam_file.getrname(aln.reference_id)
            if gene_map:
                target_name = gene_map[txid]

            else:
                target_name = txid

            e_tuple = tuple_template.format(CB, target_name, aln.pos, MB)

            # Scale evidence by number of hits
            if no_scale_evidence:
                evidence[e_tuple] += 1.0
            else:
                evidence[e_tuple] += weigh_evidence(aln.tags)
            kept += 1
    finally:
        sam_file.close()

    tally_time = time.time() - start_tally
    # `count` rather than the loop index: defined even for an empty file
    logger.info('Tally done - {:.3}s, {:,} alns/min'.format(
        tally_time, int(60. * count / tally_time)))
    logger.info('Collapsing evidence')

    buf = StringIO()
    for key in evidence:
        line = '{},{}\n'.format(key, evidence[key])
        buf.write(line)

    buf.seek(0)
    # The buffer has no header line; without header=None the first evidence
    # row would be swallowed as column names and never counted.
    evidence_table = pd.read_csv(buf, header=None)
    evidence_query = 'evidence >= %f' % minevidence
    if positional:
        # Same order as tuple_template: the position comes before the UMI
        evidence_table.columns = ['cell', 'gene', 'pos', 'umi', 'evidence']
    else:
        evidence_table.columns = ['cell', 'gene', 'umi', 'evidence']

    # Each row is one distinct key, so the group size is the molecule count
    collapsed = evidence_table.query(evidence_query).groupby(
        ['cell', 'gene']).size()

    expanded = collapsed.unstack().T

    if gene_map:
        # This Series is just for sorting the index
        genes = pd.Series(index=set(gene_map.values()))
        genes = genes.sort_index()
        # Now genes is assigned to a DataFrame
        # NOTE(review): `.ix` only exists in older pandas releases.
        genes = expanded.ix[genes.index]

    else:
        genes = expanded

    genes.replace(pd.np.nan, 0, inplace=True)

    logger.info('Output results')

    if output_evidence_table:
        import shutil
        buf.seek(0)
        with open(output_evidence_table, 'w') as etab_fh:
            shutil.copyfileobj(buf, etab_fh)

    genes.to_csv(out)
Ejemplo n.º 4
0
def fasttagcount(sam, out, genemap, positional, minevidence, cb_histogram,
                 cb_cutoff, subsample, parse_tags, gene_tags, umi_matrix):
    ''' Count up evidence for tagged molecules, this implementation assumes the
    alignment file is coordinate sorted

    Alignments are fetched one reference at a time, grouped by gene, and one
    row per gene is written to ``out`` as soon as that gene is finished, so
    only a single gene's evidence is held in memory.

    Arguments:
        sam: coordinate sorted, indexed BAM file (``.sam`` is rejected).
        out: path of the CSV count matrix (genes as rows, cells as columns).
        genemap: optional path to a whitespace separated two column file
            mapping reference names to gene names.
        positional: accepted for interface compatibility; not used here.
        minevidence: minimum accumulated evidence for a UMI to be counted.
        cb_histogram: tab separated barcode/count file.  Required: the
            column header of the output is taken from it.
        cb_cutoff: count threshold, or ``"auto"`` to let
            ``guess_depth_cutoff`` choose one; falsy means 0.
        subsample: optional number of reads to sample per cell barcode.
        parse_tags: when true read the barcode and UMI from the ``CR`` and
            ``UM`` tags instead of the read name.
        gene_tags: when true skip alignments that carry no ``GX`` tag.
        umi_matrix: optional path; when set the summed evidence per
            gene and cell is written there as well.
    '''
    from pysam import AlignmentFile

    import pandas as pd

    from utils import weigh_evidence

    if sam.endswith(".sam"):
        logger.error(
            "To use the fasttagcount subcommand, the alignment file must be a "
            "coordinate sorted, indexed BAM file.")
        sys.exit(1)

    if not cb_histogram:
        # The header row lists every cell before any alignment is read, so
        # the barcode list has to be known up front.
        logger.error(
            "To use the fasttagcount subcommand, a cellular barcode "
            "histogram must be supplied.")
        sys.exit(1)

    logger.info('Reading optional files')

    gene_map = None
    if genemap:
        with open(genemap) as fh:
            try:
                gene_map = dict(p.strip().split() for p in fh)
            except ValueError:
                logger.error('Incorrectly formatted gene_map, need to be tsv.')
                # Non-zero status so calling scripts see the failure
                sys.exit(1)

    if not cb_cutoff:
        cb_cutoff = 0

    if cb_cutoff == "auto":
        cb_cutoff = guess_depth_cutoff(cb_histogram)

    cb_cutoff = int(cb_cutoff)

    cb_hist = pd.read_table(cb_histogram,
                            index_col=0,
                            header=-1,
                            squeeze=True)
    total_num_cbs = cb_hist.shape[0]
    cb_hist = cb_hist[cb_hist > cb_cutoff]
    logger.info('Keeping {} out of {} cellular barcodes.'.format(
        cb_hist.shape[0], total_num_cbs))

    parser_re = re.compile('.*:CELL_(?P<CB>.*):UMI_(?P<MB>.*)')

    sampled_reads = None
    if subsample:
        logger.info(
            'Creating reservoir of subsampled reads ({} per cell)'.format(
                subsample))
        start_sampling = time.time()

        # Reservoirs hold read names, not stream positions: the tally below
        # fetches per reference, so positions in this pass cannot be matched
        # against the order in which alignments are seen there.
        reservoir = collections.defaultdict(list)
        cb_obs = 0 * cb_hist

        current_read = 'none_observed_yet'
        for aln in stream_bamfile(sam):
            if aln.qname == current_read:
                continue

            current_read = aln.qname

            if parse_tags:
                CB = aln.get_tag('CR')
            else:
                match = parser_re.match(aln.qname)
                CB = match.group('CB')

            if CB not in cb_hist.index:
                continue

            cb_obs[CB] += 1
            if len(reservoir[CB]) < subsample:
                reservoir[CB].append(aln.qname)
            else:
                # Replace a random slot with probability subsample / seen
                s = pd.np.random.randint(0, cb_obs[CB])
                if s < subsample:
                    reservoir[CB][s] = aln.qname

        sampled_reads = set(itertools.chain.from_iterable(reservoir.values()))
        sampling_time = time.time() - start_sampling
        logger.info('Sampling done - {:.3}s'.format(sampling_time))

    logger.info('Tallying evidence')

    alignments_processed = 0
    transcripts_processed = 0
    genes_processed = 0
    cells = list(cb_hist.index)
    header_line = ",".join(["gene"] + cells) + "\n"

    sam_file = AlignmentFile(sam, mode='rb')
    bare_evidence_handle = None
    try:
        sam_transcripts = [x["SN"] for x in sam_file.header["SQ"]]
        transcript_map = collections.defaultdict(set)
        if gene_map:
            # Set membership: the header list can hold many references
            known_transcripts = set(sam_transcripts)
            for transcript, gene in gene_map.items():
                if transcript in known_transcripts:
                    transcript_map[gene].add(transcript)
        else:
            for transcript in sam_transcripts:
                transcript_map[transcript].add(transcript)

        if umi_matrix:
            bare_evidence_handle = open(umi_matrix, "w")
            bare_evidence_handle.write(header_line)

        with open(out, "w") as out_handle:
            out_handle.write(header_line)
            for gene, transcripts in transcript_map.items():
                # Evidence is kept for one gene at a time
                evidence = collections.defaultdict(
                    lambda: collections.defaultdict(float))
                bare_evidence = collections.defaultdict(float)

                for transcript in transcripts:
                    for aln in sam_file.fetch(transcript):
                        alignments_processed += 1

                        if aln.is_unmapped:
                            continue

                        if gene_tags and not aln.has_tag('GX'):
                            continue

                        if subsample and aln.qname not in sampled_reads:
                            continue

                        if parse_tags:
                            CB = aln.get_tag('CR')
                        else:
                            match = parser_re.match(aln.qname)
                            CB = match.group('CB')

                        if CB not in cb_hist.index:
                            continue

                        if parse_tags:
                            MB = aln.get_tag('UM')
                        else:
                            MB = match.group('MB')

                        if gene_map and not gene_tags:
                            txid = sam_file.getrname(aln.reference_id)
                            if txid not in gene_map:
                                continue

                        # Scale evidence by number of hits
                        weight = weigh_evidence(aln.tags)
                        evidence[CB][MB] += weight
                        bare_evidence[CB] += weight

                    transcripts_processed += 1
                    if not transcripts_processed % 1000:
                        logger.info("%d genes processed." % genes_processed)
                        logger.info("%d transcripts processed." %
                                    transcripts_processed)
                        logger.info("%d alignments processed." %
                                    alignments_processed)

                earray = [
                    str(sum(1 for v in evidence[cell].values()
                            if v >= minevidence))
                    for cell in cells
                ]
                out_handle.write(",".join([gene] + earray) + "\n")

                if umi_matrix:
                    earray = [str(int(bare_evidence[cell])) for cell in cells]
                    bare_evidence_handle.write(
                        ",".join([gene] + earray) + "\n")

                genes_processed += 1
    finally:
        if bare_evidence_handle is not None:
            bare_evidence_handle.close()
        sam_file.close()

    # fill dataframe with missing values, sort and output
    df = pd.read_csv(out, index_col=0, header=0)
    targets = pd.Series(index=set(transcript_map.keys()))
    targets = targets.sort_index()
    df = df.reindex(targets.index.values, fill_value=0)
    df = df.sort_index()
    df.to_csv(out)

    if umi_matrix:
        df = pd.read_csv(umi_matrix, index_col=0, header=0)
        df = df.reindex(targets.index.values, fill_value=0)
        df = df.sort_index()
        df.to_csv(umi_matrix)
Ejemplo n.º 5
0
Archivo: umis.py Proyecto: roryk/umis
def tagcount(sam, out, genemap, output_evidence_table, positional, minevidence,
             cb_histogram, cb_cutoff, no_scale_evidence, subsample):
    ''' Count up evidence for tagged molecules

    Every mapped alignment in ``sam`` adds evidence to a
    ``cell,target[,position],umi`` key.  The cell barcode and the UMI are
    parsed from read names of the form ``...:CELL_<CB>:UMI_<MB>``.  Keys whose
    accumulated evidence is at least ``minevidence`` are counted per
    (cell, target) and written to ``out`` as a CSV matrix with targets as rows
    and cells as columns.

    Arguments:
        sam: alignment file; opened as SAM when the name ends in ``.sam``,
            as BAM otherwise.
        out: destination handed to ``DataFrame.to_csv``.
        genemap: optional path to a whitespace separated two column file
            mapping reference names to gene names.
        output_evidence_table: optional path; when set the raw
            ``key,evidence`` lines are copied to it.
        positional: when true the alignment position is part of the key.
        minevidence: minimum accumulated evidence for a key to be counted.
        cb_histogram: optional tab separated barcode/count file; only
            barcodes with a count above ``cb_cutoff`` are kept.
        cb_cutoff: count threshold, or ``"auto"`` to let
            ``guess_depth_cutoff`` choose one; falsy means 0.
        no_scale_evidence: when true every alignment counts 1.0 instead of
            the value returned by ``weigh_evidence``.
        subsample: optional number of reads to sample per cell barcode
            (requires ``cb_histogram``).
    '''
    from pysam import AlignmentFile

    from io import StringIO
    import pandas as pd

    from utils import weigh_evidence

    logger.info('Reading optional files')

    gene_map = None
    if genemap:
        with open(genemap) as fh:
            try:
                gene_map = dict(p.strip().split() for p in fh)
            except ValueError:
                logger.error('Incorrectly formatted gene_map, need to be tsv.')
                # Non-zero status so calling scripts see the failure
                sys.exit(1)

    if positional:
        tuple_template = '{0},{1},{2},{3}'
    else:
        tuple_template = '{0},{1},{3}'

    if not cb_cutoff:
        cb_cutoff = 0

    if cb_histogram and cb_cutoff == "auto":
        cb_cutoff = guess_depth_cutoff(cb_histogram)

    cb_cutoff = int(cb_cutoff)

    cb_hist = None
    filter_cb = False
    if cb_histogram:
        cb_hist = pd.read_table(cb_histogram, index_col=0, header=-1, squeeze=True)
        total_num_cbs = cb_hist.shape[0]
        cb_hist = cb_hist[cb_hist > cb_cutoff]
        logger.info('Keeping {} out of {} cellular barcodes.'.format(cb_hist.shape[0], total_num_cbs))
        filter_cb = True

    parser_re = re.compile('.*:CELL_(?P<CB>.*):UMI_(?P<MB>.*)')

    if subsample:
        if cb_hist is None:
            # The per-cell reservoirs are sized from the barcode histogram
            logger.error(
                'Subsampling requires a cellular barcode histogram.')
            sys.exit(1)

        logger.info('Creating reservoir of subsampled reads ({} per cell)'.format(subsample))
        start_sampling = time.time()

        reservoir = collections.defaultdict(list)
        cb_hist_sampled = 0 * cb_hist
        cb_obs = 0 * cb_hist

        sam_mode = 'r' if sam.endswith(".sam") else 'rb'
        sam_file = AlignmentFile(sam, mode=sam_mode)
        try:
            track = sam_file.fetch(until_eof=True)
            current_read = 'none_observed_yet'
            for i, aln in enumerate(track):
                if aln.qname == current_read:
                    continue

                current_read = aln.qname
                match = parser_re.match(aln.qname)
                CB = match.group('CB')

                if CB not in cb_hist.index:
                    continue

                cb_obs[CB] += 1
                if len(reservoir[CB]) < subsample:
                    reservoir[CB].append(i)
                    cb_hist_sampled[CB] += 1
                else:
                    # Replace a random slot with probability
                    # subsample / seen
                    s = pd.np.random.randint(0, cb_obs[CB])
                    if s < subsample:
                        reservoir[CB][s] = i
        finally:
            sam_file.close()

        index_filter = set(itertools.chain.from_iterable(reservoir.values()))
        sampling_time = time.time() - start_sampling
        logger.info('Sampling done - {:.3}s'.format(sampling_time))

    evidence = collections.defaultdict(int)

    logger.info('Tallying evidence')
    start_tally = time.time()

    sam_mode = 'r' if sam.endswith(".sam") else 'rb'
    sam_file = AlignmentFile(sam, mode=sam_mode)
    count = 0
    unmapped = 0
    kept = 0
    nomatchcb = 0
    current_read = 'none_observed_yet'
    count_this_read = True
    try:
        track = sam_file.fetch(until_eof=True)
        for i, aln in enumerate(track):
            count += 1
            if not count % 100000:
                logger.info("Processed %d alignments, kept %d." % (count, kept))
                logger.info("%d were filtered for being unmapped." % unmapped)
                if filter_cb:
                    logger.info("%d were filtered for not matching known barcodes."
                                % nomatchcb)

            if aln.is_unmapped:
                unmapped += 1
                continue

            # The sampling decision is made once per run of alignments that
            # share a read name and reused for the rest of that run.
            if aln.qname != current_read:
                current_read = aln.qname
                if subsample and i not in index_filter:
                    count_this_read = False
                    continue
                else:
                    count_this_read = True
            else:
                if not count_this_read:
                    continue

            match = parser_re.match(aln.qname)
            CB = match.group('CB')
            if filter_cb:
                if CB not in cb_hist.index:
                    nomatchcb += 1
                    continue

            MB = match.group('MB')

            txid = sam_file.getrname(aln.reference_id)
            if gene_map:
                target_name = gene_map[txid]

            else:
                target_name = txid

            e_tuple = tuple_template.format(CB, target_name, aln.pos, MB)

            # Scale evidence by number of hits
            if no_scale_evidence:
                evidence[e_tuple] += 1.0
            else:
                evidence[e_tuple] += weigh_evidence(aln.tags)
            kept += 1
    finally:
        sam_file.close()

    tally_time = time.time() - start_tally
    logger.info('Tally done - {:.3}s, {:,} alns/min'.format(tally_time, int(60. * count / tally_time)))
    logger.info('Collapsing evidence')

    buf = StringIO()
    for key in evidence:
        line = '{},{}\n'.format(key, evidence[key])
        # io.StringIO only accepts text; decode when the formatted line is a
        # byte string (write() itself takes a single argument).
        if isinstance(line, bytes):
            line = line.decode('utf-8')
        buf.write(line)

    buf.seek(0)
    # The buffer has no header line; without header=None the first evidence
    # row would be swallowed as column names and never counted.
    evidence_table = pd.read_csv(buf, header=None)
    evidence_query = 'evidence >= %f' % minevidence
    if positional:
        # Same order as tuple_template: the position comes before the UMI
        evidence_table.columns = ['cell', 'gene', 'pos', 'umi', 'evidence']
    else:
        evidence_table.columns = ['cell', 'gene', 'umi', 'evidence']

    # Each row is one distinct key, so the group size is the molecule count
    collapsed = evidence_table.query(evidence_query).groupby(
        ['cell', 'gene']).size()

    expanded = collapsed.unstack().T

    if gene_map:
        # This Series is just for sorting the index
        genes = pd.Series(index=set(gene_map.values()))
        genes = genes.sort_index()
        # Now genes is assigned to a DataFrame
        # NOTE(review): `.ix` only exists in older pandas releases.
        genes = expanded.ix[genes.index]

    else:
        genes = expanded

    genes.replace(pd.np.nan, 0, inplace=True)

    logger.info('Output results')

    if subsample:
        cb_hist_sampled.to_csv('ss_{}_'.format(subsample) + os.path.basename(cb_histogram), sep='\t')

    if output_evidence_table:
        import shutil
        buf.seek(0)
        with open(output_evidence_table, 'w') as etab_fh:
            shutil.copyfileobj(buf, etab_fh)

    genes.to_csv(out)
Ejemplo n.º 6
0
Archivo: umis.py Proyecto: vals/umis
def fasttagcount(sam, out, genemap, positional, minevidence, cb_histogram,
                 cb_cutoff, subsample, parse_tags, gene_tags, umi_matrix):
    ''' Count up evidence for tagged molecules, this implementation assumes the
    alignment file is coordinate sorted

    Alignments are fetched one reference at a time, grouped by gene, and one
    row per gene is written to ``out`` as soon as that gene is finished, so
    only a single gene's evidence is held in memory.

    Arguments:
        sam: coordinate sorted, indexed BAM file (``.sam`` is rejected).
        out: path of the CSV count matrix (genes as rows, cells as columns).
        genemap: optional path to a whitespace separated two column file
            mapping reference names to gene names.
        positional: accepted for interface compatibility; not used here.
        minevidence: minimum accumulated evidence for a UMI to be counted.
        cb_histogram: tab separated barcode/count file.  Required: the
            column header of the output is taken from it.
        cb_cutoff: count threshold, or ``"auto"`` to let
            ``guess_depth_cutoff`` choose one; falsy means 0.
        subsample: optional number of reads to sample per cell barcode.
        parse_tags: when true read the barcode and UMI from the ``CR`` and
            ``UM`` tags instead of the read name.
        gene_tags: when true skip alignments that carry no ``GX`` tag.
        umi_matrix: optional path; when set the summed evidence per
            gene and cell is written there as well.
    '''
    from pysam import AlignmentFile

    import pandas as pd

    from utils import weigh_evidence

    if sam.endswith(".sam"):
        logger.error("To use the fasttagcount subcommand, the alignment file must be a "
                     "coordinate sorted, indexed BAM file.")
        sys.exit(1)

    if not cb_histogram:
        # The header row lists every cell before any alignment is read, so
        # the barcode list has to be known up front.
        logger.error("To use the fasttagcount subcommand, a cellular barcode "
                     "histogram must be supplied.")
        sys.exit(1)

    logger.info('Reading optional files')

    gene_map = None
    if genemap:
        with open(genemap) as fh:
            try:
                gene_map = dict(p.strip().split() for p in fh)
            except ValueError:
                logger.error('Incorrectly formatted gene_map, need to be tsv.')
                # Non-zero status so calling scripts see the failure
                sys.exit(1)

    if not cb_cutoff:
        cb_cutoff = 0

    if cb_cutoff == "auto":
        cb_cutoff = guess_depth_cutoff(cb_histogram)

    cb_cutoff = int(cb_cutoff)

    cb_hist = pd.read_csv(cb_histogram, index_col=0, header=-1, squeeze=True, sep="\t")
    total_num_cbs = cb_hist.shape[0]
    cb_hist = cb_hist[cb_hist > cb_cutoff]
    logger.info('Keeping {} out of {} cellular barcodes.'.format(cb_hist.shape[0], total_num_cbs))

    parser_re = re.compile('.*:CELL_(?P<CB>.*):UMI_(?P<MB>.*)')

    sampled_reads = None
    if subsample:
        logger.info('Creating reservoir of subsampled reads ({} per cell)'.format(subsample))
        start_sampling = time.time()

        # Reservoirs hold read names, not stream positions: the tally below
        # fetches per reference, so positions in this pass cannot be matched
        # against the order in which alignments are seen there.
        reservoir = collections.defaultdict(list)
        cb_obs = 0 * cb_hist

        current_read = 'none_observed_yet'
        for aln in stream_bamfile(sam):
            if aln.qname == current_read:
                continue

            current_read = aln.qname

            if parse_tags:
                CB = aln.get_tag('CR')
            else:
                match = parser_re.match(aln.qname)
                CB = match.group('CB')

            if CB not in cb_hist.index:
                continue

            cb_obs[CB] += 1
            if len(reservoir[CB]) < subsample:
                reservoir[CB].append(aln.qname)
            else:
                # Replace a random slot with probability subsample / seen
                s = pd.np.random.randint(0, cb_obs[CB])
                if s < subsample:
                    reservoir[CB][s] = aln.qname

        sampled_reads = set(itertools.chain.from_iterable(reservoir.values()))
        sampling_time = time.time() - start_sampling
        logger.info('Sampling done - {:.3}s'.format(sampling_time))

    logger.info('Tallying evidence')

    alignments_processed = 0
    transcripts_processed = 0
    genes_processed = 0
    cells = list(cb_hist.index)
    header_line = ",".join(["gene"] + cells) + "\n"

    sam_file = AlignmentFile(sam, mode='rb')
    bare_evidence_handle = None
    try:
        sam_transcripts = [x["SN"] for x in sam_file.header["SQ"]]
        transcript_map = collections.defaultdict(set)
        if gene_map:
            # Set membership: the header list can hold many references
            known_transcripts = set(sam_transcripts)
            for transcript, gene in gene_map.items():
                if transcript in known_transcripts:
                    transcript_map[gene].add(transcript)
        else:
            for transcript in sam_transcripts:
                transcript_map[transcript].add(transcript)

        if umi_matrix:
            bare_evidence_handle = open(umi_matrix, "w")
            bare_evidence_handle.write(header_line)

        with open(out, "w") as out_handle:
            out_handle.write(header_line)
            for gene, transcripts in transcript_map.items():
                # Evidence is kept for one gene at a time
                evidence = collections.defaultdict(lambda: collections.defaultdict(float))
                bare_evidence = collections.defaultdict(float)

                for transcript in transcripts:
                    for aln in sam_file.fetch(transcript):
                        alignments_processed += 1

                        if aln.is_unmapped:
                            continue

                        if gene_tags and not aln.has_tag('GX'):
                            continue

                        if subsample and aln.qname not in sampled_reads:
                            continue

                        if parse_tags:
                            CB = aln.get_tag('CR')
                        else:
                            match = parser_re.match(aln.qname)
                            CB = match.group('CB')

                        if CB not in cb_hist.index:
                            continue

                        if parse_tags:
                            MB = aln.get_tag('UM')
                        else:
                            MB = match.group('MB')

                        if gene_map and not gene_tags:
                            txid = sam_file.getrname(aln.reference_id)
                            if txid not in gene_map:
                                continue

                        # Scale evidence by number of hits
                        weight = weigh_evidence(aln.tags)
                        evidence[CB][MB] += weight
                        bare_evidence[CB] += weight

                    transcripts_processed += 1
                    if not transcripts_processed % 1000:
                        logger.info("%d genes processed." % genes_processed)
                        logger.info("%d transcripts processed." % transcripts_processed)
                        logger.info("%d alignments processed." % alignments_processed)

                earray = [
                    str(sum(1 for v in evidence[cell].values() if v >= minevidence))
                    for cell in cells
                ]
                out_handle.write(",".join([gene] + earray) + "\n")

                if umi_matrix:
                    earray = [str(int(bare_evidence[cell])) for cell in cells]
                    bare_evidence_handle.write(",".join([gene] + earray) + "\n")

                genes_processed += 1
    finally:
        if bare_evidence_handle is not None:
            bare_evidence_handle.close()
        sam_file.close()

    # fill dataframe with missing values, sort and output
    df = pd.read_csv(out, index_col=0, header=0)
    targets = pd.Series(index=set(transcript_map.keys()))
    targets = targets.sort_index()
    df = df.reindex(targets.index.values, fill_value=0)
    df = df.sort_index()
    df.to_csv(out)

    if umi_matrix:
        df = pd.read_csv(umi_matrix, index_col=0, header=0)
        df = df.reindex(targets.index.values, fill_value=0)
        df = df.sort_index()
        df.to_csv(umi_matrix)
Ejemplo n.º 7
0
Archivo: umis.py Proyecto: vals/umis
def tagcount(sam, out, genemap, output_evidence_table, positional, minevidence,
             cb_histogram, cb_cutoff, no_scale_evidence, subsample, sparse,
             parse_tags, gene_tags):
    ''' Count up evidence for tagged molecules

    Walks every alignment in `sam`, accumulates evidence per
    (cell barcode, target, [position,] UMI) key, keeps the keys whose summed
    evidence reaches `minevidence`, and writes a target x cell matrix holding
    the number of surviving keys.

    Parameters
    ----------
    sam : str
        Path to the alignment file. Opened as text SAM when the name ends in
        ".sam", otherwise as BAM.
    out : str
        Path of the count matrix to write. With `sparse`, this is a Matrix
        Market file and `out + ".rownames"` / `out + ".colnames"` are written
        next to it; otherwise it is a CSV.
    genemap : str or None
        Optional path to a two-column, whitespace-separated file mapping
        transcript id to gene id.
    output_evidence_table : str or None
        Optional path that receives the raw evidence rows
        (`key fields..., evidence`, no header).
    positional : bool
        When true, the alignment position is part of the evidence key.
    minevidence : float
        Minimum summed evidence for a key to be counted.
    cb_histogram : str or None
        Optional path to a tab-separated, header-less file whose first column
        is the cellular barcode and whose second column is compared against
        `cb_cutoff`. Required when `subsample` is set.
    cb_cutoff : int, str or None
        Barcodes are kept when their histogram value is strictly greater than
        this. "auto" delegates to `guess_depth_cutoff`; a falsy value means 0.
    no_scale_evidence : bool
        When true every alignment contributes 1.0 instead of
        `weigh_evidence(aln.tags)`.
    subsample : int or None
        When set, at most this many reads per cellular barcode are used.
    sparse : bool
        Selects Matrix Market output instead of CSV.
    parse_tags : bool
        Read the cell barcode from the 'CR' tag and the UMI from the 'UM' tag
        instead of parsing them out of the read name.
    gene_tags : bool
        Take the target name from the first comma-separated entry of the 'GX'
        tag; alignments without that tag are skipped.
    '''
    from pysam import AlignmentFile

    import numpy as np
    import pandas as pd

    from utils import weigh_evidence

    logger.info('Reading optional files')

    gene_map = None
    if genemap:
        with open(genemap) as fh:
            try:
                gene_map = dict(p.strip().split() for p in fh)
            except ValueError:
                logger.error('Incorrectly formatted gene_map, need to be tsv.')
                # Non-zero status so callers/pipelines see the failure.
                sys.exit(1)

    # Field 2 (the alignment position) is only part of the key in positional
    # mode; the format call below always supplies all four fields.
    if positional:
        tuple_template = '{0},{1},{2},{3}'
    else:
        tuple_template = '{0},{1},{3}'

    if not cb_cutoff:
        cb_cutoff = 0

    if cb_histogram and cb_cutoff == "auto":
        cb_cutoff = guess_depth_cutoff(cb_histogram)

    cb_cutoff = int(cb_cutoff)

    cb_hist = None
    filter_cb = False
    known_cbs = set()
    if cb_histogram:
        # header=None: the file has no header row. `.iloc[:, 0]` picks the
        # single value column as a Series indexed by barcode.
        cb_hist = pd.read_csv(cb_histogram, index_col=0, header=None, sep="\t").iloc[:, 0]
        total_num_cbs = cb_hist.shape[0]
        cb_hist = cb_hist[cb_hist > cb_cutoff]
        logger.info('Keeping {} out of {} cellular barcodes.'.format(cb_hist.shape[0], total_num_cbs))
        filter_cb = True
        # Membership is tested once per alignment; a set keeps that cheap.
        known_cbs = set(cb_hist.index)

    parser_re = re.compile('.*:CELL_(?P<CB>.*):UMI_(?P<MB>.*)')

    if subsample:
        if cb_hist is None:
            # The per-barcode counters below are derived from the histogram.
            logger.error('Subsampling requires a cellular barcode histogram.')
            sys.exit(1)

        logger.info('Creating reservoir of subsampled reads ({} per cell)'.format(subsample))
        start_sampling = time.time()

        # Reservoir sampling per barcode: reservoir[CB] holds the stream
        # indices of the reads currently selected for that barcode.
        reservoir = collections.defaultdict(list)
        cb_hist_sampled = 0 * cb_hist
        cb_obs = 0 * cb_hist

        track = stream_bamfile(sam)
        current_read = 'none_observed_yet'
        for i, aln in enumerate(track):
            # Only the first alignment of each run of identically named
            # alignments is a sampling candidate.
            if aln.qname == current_read:
                continue

            current_read = aln.qname

            if parse_tags:
                CB = aln.get_tag('CR')
            else:
                match = parser_re.match(aln.qname)
                CB = match.group('CB')

            if CB not in known_cbs:
                continue

            cb_obs[CB] += 1
            if len(reservoir[CB]) < subsample:
                reservoir[CB].append(i)
                cb_hist_sampled[CB] += 1
            else:
                # Replace a random slot with probability subsample / cb_obs[CB].
                s = np.random.randint(0, cb_obs[CB])
                if s < subsample:
                    reservoir[CB][s] = i

        # NOTE(review): these indices are matched against enumerate() of the
        # second pass below, which presumes both passes yield alignments in
        # the same order - confirm against stream_bamfile.
        index_filter = set(itertools.chain.from_iterable(reservoir.values()))
        sampling_time = time.time() - start_sampling
        logger.info('Sampling done - {:.3}s'.format(sampling_time))

    evidence = collections.defaultdict(int)

    logger.info('Tallying evidence')
    start_tally = time.time()

    sam_mode = 'r' if sam.endswith(".sam") else 'rb'
    sam_file = AlignmentFile(sam, mode=sam_mode)
    count = 0
    unmapped = 0
    kept = 0
    nomatchcb = 0
    current_read = 'none_observed_yet'
    count_this_read = True
    missing_transcripts = set()
    try:
        targets = [x["SN"] for x in sam_file.header["SQ"]]
        track = sam_file.fetch(until_eof=True)
        for i, aln in enumerate(track):
            if count and not count % 1000000:
                logger.info("Processed %d alignments, kept %d." % (count, kept))
                logger.info("%d were filtered for being unmapped." % unmapped)
                if filter_cb:
                    logger.info("%d were filtered for not matching known barcodes."
                                % nomatchcb)
            count += 1

            if aln.is_unmapped:
                unmapped += 1
                continue

            # Alignments without a gene tag are tallied with the unmapped ones.
            if gene_tags and not aln.has_tag('GX'):
                unmapped += 1
                continue

            # The keep/skip decision is made on the first alignment seen for
            # a read name and then applied to the alignments that follow it.
            if aln.qname != current_read:
                current_read = aln.qname
                if subsample and i not in index_filter:
                    count_this_read = False
                    continue
                else:
                    count_this_read = True
            else:
                if not count_this_read:
                    continue

            if parse_tags:
                CB = aln.get_tag('CR')
            else:
                match = parser_re.match(aln.qname)
                CB = match.group('CB')

            if filter_cb:
                if CB not in known_cbs:
                    nomatchcb += 1
                    continue

            if parse_tags:
                MB = aln.get_tag('UM')
            else:
                MB = match.group('MB')

            if gene_tags:
                target_name = aln.get_tag('GX').split(',')[0]
            else:
                txid = sam_file.getrname(aln.reference_id)
                if gene_map:
                    if txid in gene_map:
                        target_name = gene_map[txid]
                    else:
                        missing_transcripts.add(txid)
                        target_name = txid
                else:
                    target_name = txid

            e_tuple = tuple_template.format(CB, target_name, aln.pos, MB)

            # Scale evidence by number of hits
            if no_scale_evidence:
                evidence[e_tuple] += 1.0
            else:
                evidence[e_tuple] += weigh_evidence(aln.tags)

            kept += 1
    finally:
        sam_file.close()

    tally_time = time.time() - start_tally
    if missing_transcripts:
        logger.warning('The following transcripts were missing gene_ids, so we added them as the transcript ids: %s' % str(missing_transcripts))
    # Guard the rate against a zero elapsed time on coarse clocks.
    alns_per_min = int(60. * count / tally_time) if tally_time else 0
    logger.info('Tally done - {:.3}s, {:,} alns/min'.format(tally_time, alns_per_min))
    logger.info('Collapsing evidence')

    logger.info('Writing evidence')
    # Round-trip through a temporary CSV so pandas can parse the comma-joined
    # keys into separate columns.
    with tempfile.NamedTemporaryFile('w+t') as out_handle:
        for key, value in evidence.items():
            out_handle.write('{},{}\n'.format(key, value))

        out_handle.flush()
        out_handle.seek(0)

        if output_evidence_table:
            import shutil
            # The raw rows only exist while the temp file is open, so the
            # optional copy has to happen here.
            with open(output_evidence_table, 'w') as etab_fh:
                shutil.copyfileobj(out_handle, etab_fh)
            out_handle.seek(0)

        evidence_table = pd.read_csv(out_handle, header=None)

    del evidence

    # Column order follows tuple_template: cell, target, [pos,] umi.
    if positional:
        evidence_table.columns = ['cell', 'gene', 'pos', 'umi', 'evidence']
    else:
        evidence_table.columns = ['cell', 'gene', 'umi', 'evidence']

    # Each surviving row is one distinct key, so the group size is the
    # molecule count for that (cell, gene) pair.
    evidence_query = 'evidence >= %f' % minevidence
    collapsed = evidence_table.query(evidence_query).groupby(['cell', 'gene']).size()

    expanded = collapsed.unstack().T

    if gene_map:
        # One row per gene named in the map, in sorted order.
        # NOTE(review): rows keyed by transcript ids that were missing from
        # gene_map are not part of this index, so the reindex drops them -
        # confirm this is intended given the warning logged above.
        genes = expanded.reindex(sorted(set(gene_map.values())))

    elif gene_tags:
        genes = expanded.sort_index()

    else:
        # make data frame have a complete accounting of transcripts
        genes = expanded.reindex(sorted(set(targets)), fill_value=0)

    genes = genes.fillna(0).astype(int)
    genes.index.name = "gene"

    logger.info('Output results')

    if subsample:
        cb_hist_sampled.to_csv('ss_{}_'.format(subsample) + os.path.basename(cb_histogram), sep='\t')

    if sparse:
        pd.Series(genes.index).to_csv(out + ".rownames", index=False, header=False)
        pd.Series(genes.columns.values).to_csv(out + ".colnames", index=False, header=False)
        with open(out, "w+b") as out_handle:
            scipy.io.mmwrite(out_handle, scipy.sparse.csr_matrix(genes))
    else:
        genes.to_csv(out)