예제 #1
0
def main(args, outs):
    """Write one read per duplicate key for cell-barcoded reads.

    Reads from ``args.input_bam`` are grouped with ``cr_utils.pos_sort_key``
    (assumes the input is sorted by that key -- TODO confirm). Within a group,
    a read is written to ``outs.output`` only if:

    * its barcode is listed in ``args.cell_barcodes``,
    * it passes ``cr_utils.is_read_dupe_candidate`` at the high-confidence
      MAPQ derived from ``args.align``, and
    * its ``(si_pcr_dupe_func(read), UMI)`` key was not already seen earlier
      in the same group.

    Reads that are not dupe candidates are dropped, not passed through.
    """
    outs.coerce_strings()

    in_bam = tk_bam.create_bam_infile(args.input_bam)
    out_bam, _ = tk_bam.create_bam_outfile(outs.output,
                                           None,
                                           None,
                                           template=in_bam)
    cell_bcs = set(cr_utils.load_barcode_tsv(args.cell_barcodes))
    # Depends only on the stage args, so compute it once rather than per read.
    high_conf_mapq = cr_utils.get_high_conf_mapq(args.align)

    for _, reads_iter in itertools.groupby(in_bam, key=cr_utils.pos_sort_key):
        # Duplicate keys are only compared within a single position group.
        dupe_keys = set()
        for read in reads_iter:
            if cr_utils.get_read_barcode(read) not in cell_bcs:
                continue
            if not cr_utils.is_read_dupe_candidate(read, high_conf_mapq):
                continue

            dupe_key = (cr_utils.si_pcr_dupe_func(read),
                        cr_utils.get_read_umi(read))
            if dupe_key in dupe_keys:
                continue

            dupe_keys.add(dupe_key)
            out_bam.write(read)

    # Close explicitly instead of relying on garbage collection to finish the
    # output file.
    out_bam.close()
    in_bam.close()
예제 #2
0
def main(args, outs):
    """Split a BAM chunk into bucket BAMs keyed by barcode prefix and gem group.

    Bucket names are ``'<prefix>-<gem_group>'``. There is one for every prefix
    returned by ``cr_utils.get_seqs(args.nbases)`` and every gem group in the
    input's library info, plus ``''`` for reads without a barcode.

    Every bucket gets its own BAM file, possibly empty. The file path is
    recorded in ``outs.buckets[bucket_name]``. Reads within a bucket are
    sorted by ``cr_utils.barcode_sort_key`` before being written.

    A read whose computed bucket name is not in that set raises ``KeyError``.
    """
    bam_in = tk_bam.create_bam_infile(args.chunk_input)

    # Get gem groups
    library_info = rna_library.get_bam_library_info(bam_in)
    gem_groups = sorted(set(lib['gem_group'] for lib in library_info))

    # Define buckets: one per (prefix, gem group), plus '' for barcode-less reads
    prefixes = cr_utils.get_seqs(args.nbases)
    bucket_names = ['%s-%d' % (prefix, gg)
                    for gg in gem_groups
                    for prefix in prefixes]
    bucket_names.append('')

    # Bucket the records (all reads are held in memory until written)
    buckets = {bucket_name: [] for bucket_name in bucket_names}
    for r in bam_in:
        barcode = cr_utils.get_read_barcode(r)
        if barcode is None:
            bucket_name = ''
        else:
            barcode_seq, gem_group = cr_utils.split_barcode_seq(barcode)
            prefix = barcode_seq[:args.nbases]
            bucket_name = '%s-%d' % (prefix, gem_group)
        buckets[bucket_name].append(r)

    # Write each bucket to its own file. Only one output file is open at a
    # time; opening every bucket file up front could hit the process
    # open-file limit when there are many prefixes or gem groups.
    outs.buckets = {}
    for bucket_name in bucket_names:
        filename = martian.make_path("bc-%s.bam" % bucket_name)
        bam_out, _ = tk_bam.create_bam_outfile(filename,
                                               None,
                                               None,
                                               template=bam_in,
                                               rgs=args.read_groups,
                                               replace_rg=True)
        outs.buckets[bucket_name] = filename

        bucket = buckets[bucket_name]
        bucket.sort(key=cr_utils.barcode_sort_key)
        for r in bucket:
            bam_out.write(r)
        bam_out.close()

    bam_in.close()
예제 #3
0
def main(args, outs):
    """Write one read per duplicate key for reads from the given cluster barcodes.

    Reads come from the ``(args.chunk_start, args.chunk_end)`` chunk of
    ``args.possorted_bam`` and are grouped with ``cr_utils.pos_sort_key``.
    Within a group, a read is written to ``outs.filtered_bam`` only if:

    * its barcode is in ``args.cluster_bcs``,
    * it passes ``cr_utils.is_read_dupe_candidate`` at a high-confidence MAPQ
      derived from a hard-coded value of 60, and
    * its ``(si_pcr_dupe_func(read), UMI)`` key was not already seen earlier
      in the same group.

    Each written read has its duplicate flag cleared. Reads that are not dupe
    candidates are dropped.
    """
    outs.coerce_strings()

    in_bam = tk_bam.create_bam_infile(args.possorted_bam)
    in_bam_chunk = tk_bam.read_bam_chunk(in_bam, (args.chunk_start, args.chunk_end))
    out_bam, _ = tk_bam.create_bam_outfile(outs.filtered_bam, None, None, template=in_bam)
    cluster_bcs = set(args.cluster_bcs)
    # Constant threshold: compute once instead of once per read.
    high_conf_mapq = cr_utils.get_high_conf_mapq({"high_conf_mapq": 60})

    for _, reads_iter in itertools.groupby(in_bam_chunk, key=cr_utils.pos_sort_key):
        # Duplicate keys are only compared within a single position group.
        dupe_keys = set()
        for read in reads_iter:
            if cr_utils.get_read_barcode(read) not in cluster_bcs:
                continue
            if not cr_utils.is_read_dupe_candidate(read, high_conf_mapq):
                continue

            dupe_key = (cr_utils.si_pcr_dupe_func(read), cr_utils.get_read_umi(read))
            if dupe_key in dupe_keys:
                continue

            dupe_keys.add(dupe_key)
            read.is_duplicate = False
            out_bam.write(read)

    # Close explicitly instead of relying on garbage collection to finish the
    # output file.
    out_bam.close()
    in_bam.close()
예제 #4
0
def main(args, outs):
    """Split a BAM chunk into bucket BAMs keyed by the barcode's first bases.

    There is one bucket per prefix returned by
    ``cr_utils.get_seqs(args.nbases)``, plus ``''`` for reads without a
    barcode. A barcode's bucket is ``barcode[:args.nbases]``.

    Every bucket gets its own BAM file, possibly empty. The file path is
    recorded in ``outs.buckets[prefix]``. Reads within a bucket are sorted by
    ``cr_utils.barcode_sort_key`` before being written.

    A read whose prefix is not a known bucket raises ``KeyError``.
    """
    # Copy before appending: the list returned by get_seqs may be shared (its
    # implementation is not visible here), and appending '' in place would
    # corrupt it for every other caller.
    prefixes = list(cr_utils.get_seqs(args.nbases))
    prefixes.append('')

    bam_in = tk_bam.create_bam_infile(args.chunk_input)

    # Bucket the records (all reads are held in memory until written)
    buckets = {prefix: [] for prefix in prefixes}
    for r in bam_in:
        barcode = cr_utils.get_read_barcode(r)
        prefix = '' if barcode is None else barcode[:args.nbases]
        buckets[prefix].append(r)

    # Write each bucket to its own file. Only one output file is open at a
    # time, so a large args.nbases cannot exhaust file handles.
    outs.buckets = {}
    for prefix in prefixes:
        filename = martian.make_path("bc_%s.bam" % prefix)
        bam_out, _ = tk_bam.create_bam_outfile(filename,
                                               None,
                                               None,
                                               template=bam_in,
                                               rgs=args.read_groups,
                                               replace_rg=True)
        outs.buckets[prefix] = filename

        bucket = buckets[prefix]
        bucket.sort(key=cr_utils.barcode_sort_key)
        for r in bucket:
            bam_out.write(r)
        bam_out.close()

    bam_in.close()
예제 #5
0
def _count_umi_bases(in_bam, record, ref_base, alt_base, bc_set, min_snp_base_qual):
    """Collapse the pileup at ``record`` into per-barcode REF/ALT UMI counts.

    Each (barcode, UMI) pair contributes at most one call. The call is
    whichever of REF/ALT the pair was seen with more often (ties go to REF).
    The call is kept only if the highest base quality seen for that base is at
    least ``min_snp_base_qual``. Bases other than REF/ALT are ignored.

    Returns a ``defaultdict(Counter)`` mapping barcode -> {base: UMI count}.
    """
    pos = record.POS - 1
    # (bc, umi) -> 2x2 array. Row 0 is REF, row 1 is ALT.
    # Column 0 is the read count; column 1 is the max base quality seen.
    umi_base_stats = collections.defaultdict(lambda: np.zeros((2, 2)))
    for col in in_bam.pileup(record.CHROM, pos, pos + 1):
        if col.pos != pos:
            continue

        for read in col.pileups:
            bc = cr_utils.get_read_barcode(read.alignment)
            umi = cr_utils.get_read_umi(read.alignment)
            assert bc in bc_set and umi is not None

            # Overlaps an exon junction
            qpos = get_read_qpos(read)
            if qpos is None:
                continue

            offset = qpos - read.alignment.qstart
            base = str(read.alignment.query[offset])
            base_qual = ord(read.alignment.qual[offset]) - tk_constants.ILLUMINA_QUAL_OFFSET

            if base == ref_base:
                base_index = 0
            elif base == alt_base:
                base_index = 1
            else:
                continue

            stats = umi_base_stats[(bc, umi)]
            stats[base_index, 0] += 1
            stats[base_index, 1] = max(base_qual, stats[base_index, 1])

    bcs_bases = collections.defaultdict(collections.Counter)
    for (bc, _), stats in umi_base_stats.iteritems():
        base_index = np.argmax(stats[:, 0])
        if stats[base_index, 1] < min_snp_base_qual:
            continue
        base = ref_base if base_index == 0 else alt_base
        bcs_bases[bc][base] += 1
    return bcs_bases


def _genotype_log_likelihoods(ref_obs, alt_obs, base_error_rate):
    """Return normalized log-likelihoods in order [hom-ref, het, hom-alt].

    Binomial model with n = ref_obs + alt_obs:

    * hom-ref: ref_obs ~ Binom(n, 1 - base_error_rate)
    * het: ref_obs ~ Binom(n, 0.5)
    * hom-alt: alt_obs ~ Binom(n, 1 - base_error_rate)

    The values are normalized with logsumexp so their exponentials sum to 1.
    """
    total_obs = ref_obs + alt_obs
    log_p = np.array([
        sp_stats.binom.logpmf(ref_obs, total_obs, 1 - base_error_rate),
        sp_stats.binom.logpmf(ref_obs, total_obs, 0.5),
        sp_stats.binom.logpmf(alt_obs, total_obs, 1 - base_error_rate),
    ])
    return log_p - sp_misc.logsumexp(log_p)


def main(args, outs):
    """Filter candidate SNPs in a variant chunk and fill per-barcode allele matrices.

    Steps for each record from ``vcf_record_iter(args.chunk_variants,
    min_snp_call_qual)``:

    1. Collapse the pileup in ``args.reads`` to per-barcode REF/ALT UMI counts
       (see ``_count_umi_bases``).
    2. Drop the record if fewer than ``min_bcs_per_snp`` barcodes, or fewer
       than ``min_snp_obs`` UMI observations, remain.
    3. For each barcode of a kept record, store the REF/ALT counts in the raw
       matrices, store the genotype log-likelihoods in the likelihood
       matrices, and write the record to ``outs.filtered_variants``.

    Filter parameters that are ``None`` in ``args`` fall back to
    ``snp_constants`` defaults.
    """
    in_bam = tk_bam.create_bam_infile(args.reads)

    snps = load_snps(args.snps)
    bcs = cr_utils.load_barcode_tsv(args.cell_barcodes)
    # Build the membership set once; it is checked for every pileup read.
    bc_set = set(bcs)

    raw_matrix_types = snp_constants.SNP_BASE_TYPES
    raw_matrix_snps = [snps for _ in snp_constants.SNP_BASE_TYPES]
    raw_allele_bc_matrices = cr_matrix.GeneBCMatrices(raw_matrix_types,
                                                      raw_matrix_snps, bcs)

    likelihood_matrix_types = snp_constants.ALLELES
    likelihood_matrix_snps = [snps for _ in snp_constants.ALLELES]
    likelihood_allele_bc_matrices = cr_matrix.GeneBCMatrices(
        likelihood_matrix_types, likelihood_matrix_snps, bcs, dtype=np.float64)

    # Every matrix above is built from the same `snps` list and `bcs`, and
    # only their `.m` contents are changed below. So the SNP/barcode index
    # lookups can all go through one matrix, fetched once.
    index_matrix = raw_allele_bc_matrices.matrices.values()[0]

    # Configurable SNP filter parameters
    min_snp_call_qual = args.min_snp_call_qual if args.min_snp_call_qual is not None else snp_constants.DEFAULT_MIN_SNP_CALL_QUAL
    min_bcs_per_snp = args.min_bcs_per_snp if args.min_bcs_per_snp is not None else snp_constants.DEFAULT_MIN_BCS_PER_SNP
    min_snp_obs = args.min_snp_obs if args.min_snp_obs is not None else snp_constants.DEFAULT_MIN_SNP_OBS
    base_error_rate = args.base_error_rate if args.base_error_rate is not None else snp_constants.DEFAULT_BASE_ERROR_RATE
    min_snp_base_qual = args.min_snp_base_qual if args.min_snp_base_qual is not None else snp_constants.DEFAULT_MIN_SNP_BASE_QUAL

    # `with` guarantees both VCF handles are flushed and closed. The original
    # code leaked them.
    with open(outs.filtered_variants, 'w') as out_vcf_file, \
            open(args.chunk_variants) as template_vcf_file:
        out_vcf = tk_io.VariantFileWriter(out_vcf_file,
                                          template_file=template_vcf_file)

        for record in vcf_record_iter(args.chunk_variants, min_snp_call_qual):
            ref_base = str(record.REF)
            alt_base = str(record.ALT[0])

            bcs_bases = _count_umi_bases(in_bam, record, ref_base, alt_base,
                                         bc_set, min_snp_base_qual)

            # Filter if not enough unique barcodes
            if len(bcs_bases) < min_bcs_per_snp:
                continue

            # Filter if not enough observed bases
            snp_obs = sum(sum(counts.itervalues())
                          for counts in bcs_bases.itervalues())
            if snp_obs < min_snp_obs:
                continue

            for bc, bases in bcs_bases.iteritems():
                ref_obs = bases[ref_base]
                alt_obs = bases[alt_base]
                obs = np.array([ref_obs, alt_obs])
                log_p = _genotype_log_likelihoods(ref_obs, alt_obs, base_error_rate)

                snp_index = index_matrix.gene_id_to_int(format_record(record))
                bc_index = index_matrix.bc_to_int(bc)

                for i, base_type in enumerate(snp_constants.SNP_BASE_TYPES):
                    raw_allele_bc_matrices.get_matrix(base_type).m[
                        snp_index, bc_index] = obs[i]

                for i, allele in enumerate(snp_constants.ALLELES):
                    likelihood_allele_bc_matrices.get_matrix(allele).m[
                        snp_index, bc_index] = log_p[i]

            out_vcf.write_record(record)

    in_bam.close()

    raw_allele_bc_matrices.save_h5(outs.raw_allele_bc_matrices_h5)
    likelihood_allele_bc_matrices.save_h5(
        outs.likelihood_allele_bc_matrices_h5)
예제 #6
0
def main(args, outs):
    """Build a MoleculeCounter file from a BAM chunk.

    Reads are grouped by ``cr_utils.barcode_sort_key`` into
    (gem group, barcode, gene ids) groups. This assumes the input is sorted
    by that key -- TODO confirm. Groups with no barcode or gem group are
    skipped.

    Inside a group, each read contributes to a molecule keyed by
    (UMI, gene int id). Reads that are secondary, have no UMI, or are read 2
    are ignored. Non-confidently-mapped and unmapped reads use the sentinel
    gene id ``len(gene_index.get_genes())``.

    Per-molecule columns count:

    * reads by mapping type,
    * UMI-corrected reads,
    * barcode-corrected reads,
    * unique (tid, pos) positions among confidently mapped reads.

    Confidently mapped reads from filtered (cell) barcodes are also counted
    per gem group in the GEM_GROUPS_METRIC.
    """
    outs.coerce_strings()

    in_bam = tk_bam.create_bam_infile(args.chunk_input)

    counter = cr_mol_counter.MoleculeCounter.open(outs.output, mode='w')

    mol_data_keys = cr_mol_counter.MoleculeCounter.get_data_columns()
    mol_data_columns = {key: idx for idx, key in enumerate(mol_data_keys)}

    gene_index = cr_reference.GeneIndex.load_pickle(
        cr_utils.get_reference_genes_index(args.reference_path))
    genomes = cr_utils.get_reference_genomes(args.reference_path)
    genome_index = cr_reference.get_genome_index(genomes)
    # Sentinel gene id (one past the last real gene) for reads not assigned
    # to a gene.
    none_gene_id = len(gene_index.get_genes())

    # store reference index columns
    # NOTE - these must be cast to str first, as unicode is not supported
    counter.set_ref_column('genome_ids', [str(genome) for genome in genomes])
    counter.set_ref_column('gene_ids',
                           [str(gene.id) for gene in gene_index.genes])
    counter.set_ref_column('gene_names',
                           [str(gene.name) for gene in gene_index.genes])

    # Union of the filtered barcodes across all genomes.
    filtered_bcs_per_genome = cr_utils.load_barcode_csv(args.filtered_barcodes)
    filtered_bcs = set()
    for bcs in filtered_bcs_per_genome.itervalues():
        filtered_bcs.update(bcs)

    gg_metrics = collections.defaultdict(
        lambda: {cr_mol_counter.GG_CONF_MAPPED_FILTERED_BC_READS_METRIC: 0})

    # Depends only on the stage args; compute once instead of per read.
    high_conf_mapq = cr_utils.get_high_conf_mapq(args.align)

    for (gem_group, barcode, gene_ids), reads_iter in itertools.groupby(
            in_bam, key=cr_utils.barcode_sort_key):
        if barcode is None or gem_group is None:
            continue
        is_cell_barcode = cr_utils.format_barcode_seq(
            barcode, gem_group) in filtered_bcs
        molecules = collections.defaultdict(
            lambda: np.zeros(len(mol_data_columns), dtype=np.uint64))

        compressed_barcode = cr_mol_counter.MoleculeCounter.compress_barcode_seq(
            barcode)
        gem_group = cr_mol_counter.MoleculeCounter.compress_gem_group(
            gem_group)

        # mol_key -> set of (tid, pos) already seen for confidently mapped reads
        read_positions = collections.defaultdict(set)
        for read in reads_iter:
            umi = cr_utils.get_read_umi(read)
            # ignore read2 to avoid double-counting. the mapping + annotation should be equivalent.
            if read.is_secondary or umi is None or read.is_read2:
                continue

            raw_umi = cr_utils.get_read_raw_umi(read)
            raw_bc, _ = cr_utils.split_barcode_seq(
                cr_utils.get_read_raw_barcode(read))
            proc_bc, _ = cr_utils.split_barcode_seq(
                cr_utils.get_read_barcode(read))

            if cr_utils.is_read_conf_mapped_to_transcriptome(read, high_conf_mapq):
                assert len(gene_ids) == 1

                mol_key = (umi, gene_index.gene_id_to_int(gene_ids[0]))
                map_type = 'reads'

                read_pos = (read.tid, read.pos)
                uniq_read_pos = read_pos not in read_positions[mol_key]
                read_positions[mol_key].add(read_pos)

                if is_cell_barcode:
                    gg_metrics[int(gem_group)][
                        cr_mol_counter.
                        GG_CONF_MAPPED_FILTERED_BC_READS_METRIC] += 1
            else:
                mol_key = (umi, none_gene_id)
                if read.is_unmapped:
                    map_type = 'unmapped_reads'
                else:
                    map_type = 'nonconf_mapped_reads'
                uniq_read_pos = False

            # The defaultdict hands back the stored array, so in-place
            # updates below modify the molecule's counts.
            molecule = molecules[mol_key]
            molecule[mol_data_columns[map_type]] += 1
            molecule[mol_data_columns['umi_corrected_reads']] += int(
                raw_umi != umi)
            molecule[mol_data_columns[
                'barcode_corrected_reads']] += int(raw_bc != proc_bc)
            molecule[mol_data_columns[
                'conf_mapped_uniq_read_pos']] += int(uniq_read_pos)

        for mol_key, molecule in sorted(molecules.items()):
            umi, gene_id = mol_key
            genome = cr_utils.get_genome_from_str(
                gene_index.int_to_gene_id(gene_id), genomes)
            genome_id = cr_reference.get_genome_id(genome, genome_index)
            counter.add(
                barcode=compressed_barcode,
                gem_group=gem_group,
                umi=cr_mol_counter.MoleculeCounter.compress_umi_seq(umi),
                gene=gene_id,
                genome=genome_id,
                **{
                    key: molecule[col_idx]
                    for key, col_idx in mol_data_columns.iteritems()
                })

    in_bam.close()

    counter.set_metric(cr_mol_counter.GEM_GROUPS_METRIC, dict(gg_metrics))

    counter.save()