def join(args, outs, chunk_defs, chunk_outs):
    """Merge per-contig chunk counts and call cell barcodes for each species.

    Aggregates the pickled per-chunk total and targeted fragment counts by
    species, then, per species, either fits a signal/noise model to the
    (shifted) targeted counts or honours ``args.force_cells`` to choose a
    count threshold. Barcodes at or above the threshold are cell barcodes.
    For two-species references without forced cells a secondary threshold
    pass is run. Writes:

    * ``outs.cell_barcodes``: one line per species, ``species,bc1,bc2,...``
    * ``outs.singlecell``: per-barcode CSV of cell-call flags; for
      multi-species references also per-species total and targeted counts
      (the latter under the ``peak_region_fragments_<species>`` header)
    * ``outs.cell_calling_summary``: JSON summary metrics

    If ``args.fragments`` is None, all three outputs are set to None and
    nothing is written.

    Raises:
        ValueError: if a forced cell count is not positive, or if the number
            of barcodes passing the final threshold disagrees with the
            computed ``cells_detected``.
    """
    args.coerce_strings()
    outs.coerce_strings()
    if args.fragments is None:
        outs.cell_barcodes = None
        outs.cell_calling_summary = None
        outs.singlecell = None
        return

    if args.excluded_barcodes is not None:
        with open(args.excluded_barcodes, 'r') as infile:
            excluded_barcodes = json.load(infile)
    else:
        excluded_barcodes = None

    # Merge the chunk inputs
    ref = ReferenceManager(args.reference_path)
    species_list = ref.list_species()

    barcode_counts_by_species = {
        species: Counter()
        for species in species_list
    }
    targeted_counts_by_species = {
        species: Counter()
        for species in species_list
    }
    fragment_depth = 0
    for chunk_in, chunk_out in zip(chunk_defs, chunk_outs):
        species = ref.species_from_contig(chunk_in.contig)
        with open(chunk_out.barcode_counts, 'r') as infile:
            barcode_counts_by_species[species] += pickle.load(infile)
        with open(chunk_out.targeted_counts, 'r') as infile:
            targeted_counts_by_species[species] += pickle.load(infile)
        fragment_depth += chunk_out.fragment_depth
    print('Total fragments across all chunks: {}'.format(fragment_depth))

    barcodes = list({
        bc
        for species in species_list
        for bc in barcode_counts_by_species[species]
    })

    # Per-species barcodes that survive exclusion. With no exclusion list
    # every observed barcode is kept; the previous code indexed into None
    # here and crashed. Sets are built once so the membership test is O(1)
    # whether the JSON stored a list or a dict per species.
    if excluded_barcodes is None:
        non_excluded_barcodes = {
            species: list(barcodes)
            for species in species_list
        }
    else:
        excluded_sets = {
            species: set(excluded_barcodes[species])
            for species in species_list
        }
        non_excluded_barcodes = {
            species:
            [bc for bc in barcodes if bc not in excluded_sets[species]]
            for species in species_list
        }
    print('Total barcodes observed: {}'.format(len(barcodes)))

    # retained_counts[species] is aligned index-for-index with
    # non_excluded_barcodes[species]; the cell-writing loop below zips them.
    retained_counts = {}
    for species in species_list:
        retained_counts[species] = np.array([
            targeted_counts_by_species[species][bc]
            for bc in non_excluded_barcodes[species]
        ])
        if excluded_barcodes is not None:
            print('Barcodes excluded for species {}: {}'.format(
                species, len(excluded_barcodes[species])))
            print('Barcodes remaining for species {}: {}'.format(
                species, len(non_excluded_barcodes[species])))

    parameters = {}

    whitelist_length = len(load_barcode_whitelist(args.barcode_whitelist))
    count_shift = max(
        MINIMUM_COUNT,
        int(fragment_depth * WHITELIST_CONTAM_RATE / whitelist_length))
    print('Count shift for whitelist contamination: {}'.format(count_shift))

    for (species, count_data) in retained_counts.iteritems():
        print('Analyzing species {}'.format(species))
        # Drop barcodes below count_shift and subtract count_shift from the
        # rest to remove the effects of whitelist contamination
        shifted_data = count_data[count_data >= count_shift] - count_shift
        print('Number of barcodes analyzed: {}'.format(len(shifted_data)))
        count_dict = Counter(shifted_data)
        parameters[species] = {}

        forced_cell_count = None
        if args.force_cells is not None:
            if species in args.force_cells:
                forced_cell_count = int(args.force_cells[species])
            elif "default" in args.force_cells:
                forced_cell_count = int(args.force_cells["default"])
            # forced_cell_count stays None when neither key is present;
            # comparing None with an int raises TypeError on Python 3.
            if (forced_cell_count is not None
                    and forced_cell_count > MAXIMUM_CELLS_PER_SPECIES):
                # Log the requested value before it is clamped.
                martian.log_info(
                    'Attempted to force cells to {}.  Overriding to maximum allowed cells.'
                    .format(forced_cell_count))
                forced_cell_count = MAXIMUM_CELLS_PER_SPECIES

        # Initialize parameters to empty
        parameters[species]['noise_mean'] = None
        parameters[species]['noise_dispersion'] = None
        parameters[species]['signal_mean'] = None
        parameters[species]['signal_dispersion'] = None
        parameters[species]['fraction_noise'] = None
        parameters[species]['cell_threshold'] = None
        parameters[species]['goodness_of_fit'] = None
        parameters[species]['estimated_cells_present'] = 0

        # Corner case where FRIP is 0 because the number of peaks is tiny (fuzzer tests)
        if len(count_dict) < 10:
            parameters[species]['cells_detected'] = 0
            forced_cell_count = None
        elif forced_cell_count is None:
            print('Estimating parameters')
            fitted_params = estimate_parameters(count_dict)
            signal_threshold = estimate_threshold(
                fitted_params, CELL_CALLING_THRESHOLD) + count_shift
            print('Primary threshold: {}'.format(signal_threshold))
            parameters[species]['noise_mean'] = fitted_params.mu_noise
            parameters[species]['noise_dispersion'] = fitted_params.alpha_noise
            parameters[species]['signal_mean'] = fitted_params.mu_signal
            parameters[species][
                'signal_dispersion'] = fitted_params.alpha_signal
            parameters[species]['fraction_noise'] = fitted_params.frac_noise
            parameters[species]['cell_threshold'] = signal_threshold
            parameters[species]['goodness_of_fit'] = goodness_of_fit(
                shifted_data, fitted_params)
            called_cell_count = np.sum(count_data >= signal_threshold)
            parameters[species]['cells_detected'] = called_cell_count
            parameters[species]['estimated_cells_present'] = int(
                (1 - fitted_params.frac_noise) * len(shifted_data))
            if called_cell_count > MAXIMUM_CELLS_PER_SPECIES:
                # Abort the model fitting and instead force cells to the maximum
                forced_cell_count = MAXIMUM_CELLS_PER_SPECIES

        if forced_cell_count is not None:
            print('Forcing cells to {}'.format(forced_cell_count))

            if forced_cell_count <= 0:
                raise ValueError("Force cells must be positive")
            else:
                adj_data = shifted_data[shifted_data > 0]
                print('Total barcodes considered for forcing cells: {}'.format(
                    len(adj_data)))
                # Threshold = the forced_cell_count-th largest shifted count,
                # or the smallest one when fewer barcodes are available.
                parameters[species]['cell_threshold'] = min(adj_data) if forced_cell_count >= len(adj_data) else \
                    sorted(adj_data, reverse=True)[forced_cell_count - 1]
                parameters[species]['cell_threshold'] += count_shift
                parameters[species]['cells_detected'] = np.sum(
                    count_data >= parameters[species]['cell_threshold'])

    # For barnyard samples, mask out the noise distribution and re-fit to get cleaner separation
    if len(retained_counts) == 2 and (args.force_cells is None
                                      or not args.force_cells):
        print('Estimating secondary thresholds')
        sp1, sp2 = species_list

        # NOTE(review): this conditional looks inverted. When a primary
        # threshold exists the filter value becomes -1, so every barcode is
        # kept; when it is None the filter value is None, which is a
        # TypeError on Python 3.
        # Swapping the branches would make cells_detected count only
        # barcodes above *both* species' thresholds. That would disagree
        # with the cell-writing loop below and trip its mismatch check.
        # Left unchanged pending confirmation of the intended behaviour.
        sp1_threshold = -1 if parameters[sp1][
            'cell_threshold'] is not None else parameters[sp1]['cell_threshold']
        sp2_threshold = -1 if parameters[sp2][
            'cell_threshold'] is not None else parameters[sp2]['cell_threshold']

        if parameters[sp1]['cell_threshold'] is not None:
            sp1_counts = np.array([
                targeted_counts_by_species[sp1][bc]
                for bc in non_excluded_barcodes[sp1]
                if (targeted_counts_by_species[sp1][bc] > sp1_threshold) and (
                    targeted_counts_by_species[sp2][bc] > sp2_threshold)
            ])
            sp1_params = estimate_parameters(Counter(sp1_counts),
                                             threshold=sp1_threshold)
            if not np.isnan(sp1_params.frac_noise):
                parameters[sp1]['cell_threshold'] = max(
                    sp1_threshold, estimate_threshold(sp1_params, 20))
            parameters[sp1]['cells_detected'] = np.sum(
                sp1_counts >= parameters[sp1]['cell_threshold'])
        else:
            parameters[sp1]['cells_detected'] = 0

        if parameters[sp2]['cell_threshold'] is not None:
            sp2_counts = np.array([
                targeted_counts_by_species[sp2][bc]
                for bc in non_excluded_barcodes[sp2]
                if (targeted_counts_by_species[sp1][bc] > sp1_threshold) and (
                    targeted_counts_by_species[sp2][bc] > sp2_threshold)
            ])
            sp2_params = estimate_parameters(Counter(sp2_counts),
                                             threshold=sp2_threshold)
            if not np.isnan(sp2_params.frac_noise):
                parameters[sp2]['cell_threshold'] = max(
                    sp2_threshold, estimate_threshold(sp2_params, 20))
            parameters[sp2]['cells_detected'] = np.sum(
                sp2_counts >= parameters[sp2]['cell_threshold'])
        else:
            parameters[sp2]['cells_detected'] = 0

        print('Secondary threshold ({}): {}'.format(
            sp1, parameters[sp1]['cell_threshold']))
        print('Secondary threshold ({}): {}'.format(
            sp2, parameters[sp2]['cell_threshold']))

    print('Writing out cell barcodes')
    cell_barcodes = {}
    for (species, count_data) in retained_counts.iteritems():
        threshold = parameters[species]['cell_threshold']
        cell_barcodes[species] = {}
        print('Cell threshold for species {}: {}'.format(species, threshold))
        if threshold is not None:
            for count, barcode in zip(count_data,
                                      non_excluded_barcodes[species]):
                if count >= threshold:
                    print('{} - Total {}, Targeted {}, Count {}, Threshold {}'.
                          format(barcode,
                                 barcode_counts_by_species[species][barcode],
                                 targeted_counts_by_species[species][barcode],
                                 count, threshold))
                    cell_barcodes[species][barcode] = count
        # Sanity check: the barcodes selected here must agree with the
        # cells_detected count computed when the threshold was set.
        if len(cell_barcodes[species]
               ) != parameters[species]['cells_detected']:
            print(len(cell_barcodes[species]),
                  parameters[species]['cells_detected'])
            raise ValueError(
                'Mismatch in called cells identified - failure in threshold setting'
            )
        print('Selected {} barcodes of species {}'.format(
            len(cell_barcodes[species]), species))

    with open(outs.cell_barcodes, 'w') as outfile:
        # One line per species: "species,bc1,bc2,..." (the barcode list is
        # empty, leaving "species,", when no cells were called).
        for species in cell_barcodes.keys():
            outfile.write(species + ",")
            outfile.write(",".join(cell_barcodes[species]) + "\n")

    cell_index = compute_cell_index(species_list, cell_barcodes)

    with open(outs.singlecell, 'w') as outfile:
        outfile.write("barcode,cell_id,")
        outfile.write(",".join([
            "is_{}_cell_barcode".format(species) for species in species_list
        ]))
        if len(species_list) > 1:
            for species in species_list:
                outfile.write(",passed_filters_{}".format(species))
                outfile.write(",peak_region_fragments_{}".format(species))
        outfile.write("\n")
        for barcode in [NO_BARCODE] + sorted(barcodes):
            outfile.write("{},".format(barcode))
            outfile.write("{},".format(cell_index.get(barcode, "None")))
            values = [
                str(
                    int(species in cell_barcodes
                        and barcode in cell_barcodes[species]))
                for species in species_list
            ]
            outfile.write(",".join(values))
            if len(species_list) > 1:
                for species in species_list:
                    outfile.write(",{:d}".format(
                        barcode_counts_by_species[species][barcode]))
                    outfile.write(",{:d}".format(
                        targeted_counts_by_species[species][barcode]))
            outfile.write("\n")

    # process data into summary metrics
    summary_info = {}
    summary_info.update(
        generate_cell_calling_metrics(parameters, cell_barcodes))
    summary_info.update(generate_gb_metrics(cell_barcodes, excluded_barcodes))

    with open(outs.cell_calling_summary, 'w') as outfile:
        outfile.write(json.dumps(summary_info, indent=4))
Example #2
0
def join(args, outs, chunk_defs, chunk_outs):
    """Merge per-contig chunk results and flag low-targeting barcodes.

    For each species, a barcode is "low targeting" when its fraction of
    fragments that are targeted is below the fraction of the species'
    genome (summed over its chunk contigs) covered by peaks. Writes:

    * ``outs.low_targeting_barcodes``: JSON
      ``{"label": "low_targeting", "data": {species: {bc: frac}}}``
    * ``outs.low_targeting_summary``: JSON summary metrics
    * ``outs.barcode_counts``: JSON of total fragment counts per barcode,
      summed over species
    * ``outs.fragment_lengths`` / ``outs.covered_bases``: JSON of the merged
      per-padding Counters

    If ``args.fragments`` is None, every output is set to None and nothing
    is written.
    """
    args.coerce_strings()
    outs.coerce_strings()
    if args.fragments is None:
        outs.low_targeting_barcodes = None
        outs.low_targeting_summary = None
        # These are also written below on the normal path; null them here
        # too so no output is left pointing at a file that was never made.
        outs.barcode_counts = None
        outs.fragment_lengths = None
        outs.covered_bases = None
        return

    # Merge the chunk inputs
    ref = ReferenceManager(args.reference_path)
    species_list = ref.list_species()

    barcode_counts_by_species = {species: Counter() for species in species_list}
    targeted_counts_by_species = {species: Counter() for species in species_list}

    peak_bp_by_species = {species: 0 for species in species_list}
    genome_bp_by_species = {species: 0 for species in species_list}

    fragment_lengths = {padding: Counter() for padding in PADDING_VALUES}
    covered_bases = {padding: Counter() for padding in PADDING_VALUES}

    for chunk_in, chunk_out in zip(chunk_defs, chunk_outs):
        species = ref.species_from_contig(chunk_in.contig)

        with open(chunk_out.fragment_counts, "r") as infile:
            barcode_counts_by_species[species] += pickle.load(infile)
        with open(chunk_out.targeted_counts, "r") as infile:
            targeted_counts_by_species[species] += pickle.load(infile)

        with open(chunk_out.fragment_lengths, "r") as infile:
            data = pickle.load(infile)
            for padding in PADDING_VALUES:
                fragment_lengths[padding] += data[padding]

        with open(chunk_out.covered_bases, "r") as infile:
            data = pickle.load(infile)
            for padding in PADDING_VALUES:
                covered_bases[padding] += data[padding]

        peak_bp_by_species[species] += chunk_out.peak_coverage
        genome_bp_by_species[species] += ref.contig_lengths[chunk_in.contig]

    # float() forces true division: on Python 2 without
    # `from __future__ import division`, int / int truncates to 0, which
    # would make every comparison below false. A species with no chunks has
    # no genome length; report 0.0 for it rather than raising
    # ZeroDivisionError.
    frac_genome_in_peaks_by_species = {}
    for species in species_list:
        genome_bp = genome_bp_by_species[species]
        frac_genome_in_peaks_by_species[species] = (
            float(peak_bp_by_species[species]) / genome_bp if genome_bp else 0.0
        )

    # Identify barcodes that have lower fraction of reads overlapping peaks than the
    # genomic coverage of the peaks
    low_targeting_barcodes = {
        "label": "low_targeting",
        "data": {species: {} for species in species_list}
    }
    for species in species_list:
        species_targeted = targeted_counts_by_species[species]
        species_frac_in_peaks = frac_genome_in_peaks_by_species[species]
        # total_count is always positive: Counter addition keeps only
        # positive counts.
        for barcode, total_count in barcode_counts_by_species[species].iteritems():
            barcode_frac_peaks = float(species_targeted[barcode]) / total_count
            if barcode_frac_peaks < species_frac_in_peaks:
                low_targeting_barcodes["data"][species][barcode] = barcode_frac_peaks

    # Sum up the total fragment counts per barcode across all species
    total_barcode_counts = Counter()
    for barcode_counts in barcode_counts_by_species.values():
        total_barcode_counts += barcode_counts
    with open(outs.barcode_counts, "w") as outfile:
        outfile.write(json.dumps(total_barcode_counts, indent=4))

    summary_data = {}
    for species in species_list:
        key_suffix = "" if len(species_list) == 1 else "_{}".format(species)
        summary_data["number_of_low_targeting_barcodes{}".format(key_suffix)] = len(
            low_targeting_barcodes["data"][species]
        )
        summary_data[
            "fraction_of_genome_within_{}bp_of_peaks{}".format(DISTANCE, key_suffix)
        ] = frac_genome_in_peaks_by_species[species]
    with open(outs.low_targeting_summary, "w") as outfile:
        outfile.write(json.dumps(summary_data, indent=4))
    with open(outs.low_targeting_barcodes, "w") as outfile:
        outfile.write(json.dumps(low_targeting_barcodes, indent=4))
    with open(outs.fragment_lengths, "w") as outfile:
        outfile.write(json.dumps(fragment_lengths, indent=4))
    with open(outs.covered_bases, "w") as outfile:
        outfile.write(json.dumps(covered_bases, indent=4))
Example #3
0
def get_counts_by_barcode(reference_path, peaks, fragments, fragments_index=None, contig=None, known_cells=None):
    """Generate targeting, raw and dup counts per barcode.

    If cell identity is known, it is also returned as part of the counts.

    Args:
        reference_path: reference directory, opened with ReferenceManager.
        peaks: peaks track path (or None), loaded as target regions.
        fragments: fragments file path.
        fragments_index: optional index passed to parsed_fragments_from_contig.
        contig: if given, only fragments from this contig are read; otherwise
            the whole fragments file is iterated.
        known_cells: optional path to a cell-barcode file with one line per
            species, ``species,bc1,bc2,...``. Tokens that are empty or equal
            to "null" are ignored.

    Returns:
        Tuple ``(read_count, counts_by_barcode, tss_relpos, ctcf_relpos)``.
        ``read_count`` is incremented by 2 per fragment. ``counts_by_barcode``
        maps barcode -> Counter of per-barcode metrics. ``tss_relpos`` and
        ``ctcf_relpos`` are Counters keyed by cut-site position relative to
        the containing padded TSS / CTCF region.
    """
    def load_reference_track(track, padding=0):
        """Load target regions from ``track`` via regtools, or None if no track."""
        if track is not None:
            with open(track, 'r') as infile:
                regions = regtools.get_target_regions(infile, padding=padding)
        else:
            regions = None
        return regions

    def point_is_in_target(contig, position, target_regions):
        """Return True if ``position`` on ``contig`` lies inside a target region."""
        if target_regions is None:
            return False
        if contig not in target_regions:
            return False
        return target_regions[contig].contains_point(position)

    def fragment_overlaps_target(contig, start, stop, target_regions):
        """Return True if ``[start, stop]`` on ``contig`` overlaps a target region."""
        if target_regions is None:
            return False
        if contig not in target_regions:
            return False
        return target_regions[contig].overlaps_region(start, stop)

    ref_manager = ReferenceManager(reference_path)

    # Load in and pad TSS/CTCF regions if present
    tss_regions = load_reference_track(ref_manager.tss_track, padding=2000)
    ctcf_regions = load_reference_track(ref_manager.ctcf_track, padding=250)

    # Load in regions from reference-associated tracks
    dnase_regions = load_reference_track(ref_manager.dnase_track)
    enhancer_regions = load_reference_track(ref_manager.enhancer_track)
    promoter_regions = load_reference_track(ref_manager.promoter_track)
    blacklist_regions = load_reference_track(ref_manager.blacklist_track)
    peak_regions = load_reference_track(peaks)

    # Load the cell -> species map: barcode -> list of species it is a cell for
    cell_barcodes = {}
    species_list = ref_manager.list_species()
    if known_cells is not None:
        with open(known_cells, 'r') as infile:
            for line in infile:
                items = line.strip("\n").split(",")
                for barcode in items[1:]:
                    # Skip "null" placeholders and the empty token produced by
                    # a species line with no barcodes ("species,").
                    if barcode and barcode != "null":
                        cell_barcodes.setdefault(barcode, []).append(items[0])

    # Assign each cell a label "<species>_cell_<n>"; barcodes that are cells
    # for several species get their labels joined with "_".
    cell_index = {}
    spnum = {species: 0 for species in species_list}
    for species in species_list:
        for barcode in cell_barcodes:
            if species in cell_barcodes[barcode]:
                label = "{}_cell_{}".format(species, spnum[species])
                spnum[species] += 1
                cell_index[barcode] = label if barcode not in cell_index else '_'.join([cell_index[barcode], label])

    counts_by_barcode = {}
    tss_relpos = Counter()
    ctcf_relpos = Counter()

    read_count = 0
    # Per-species split counts are only tracked for multi-species references
    # with known cells; this does not change inside the loop.
    track_species_splits = known_cells is not None and len(species_list) > 1

    iterator = open_fragment_file(fragments) if contig is None else \
        parsed_fragments_from_contig(contig, fragments, index=fragments_index)
    # A distinct loop name keeps the ``contig`` argument from being clobbered.
    for frag_contig, start, stop, barcode, dups in iterator:
        read_count += 2
        barcode_counts = counts_by_barcode.get(barcode)
        if barcode_counts is None:
            barcode_counts = counts_by_barcode[barcode] = Counter()
            if known_cells is not None:
                cell_species = cell_barcodes.get(barcode, [])
                barcode_counts["cell_id"] = cell_index.get(barcode, "None")
                for species in species_list:
                    barcode_counts["is_{}_cell_barcode".format(species)] = int(species in cell_species)

        # Peak overlap feeds both the species split and the overall count.
        in_peak = fragment_overlaps_target(frag_contig, start, stop, peak_regions)

        # species splits
        if track_species_splits:
            contig_species = ref_manager.species_from_contig(frag_contig)
            barcode_counts["passed_filters_{}".format(contig_species)] += 1
            if in_peak:
                barcode_counts["peak_region_fragments_{}".format(contig_species)] += 1

        # raw mapping
        barcode_counts["passed_filters"] += 1
        barcode_counts["total"] += dups
        barcode_counts["duplicate"] += dups - 1

        # Count up transposition site targeting at both fragment ends
        for position in (start, stop):
            if point_is_in_target(frag_contig, position, tss_regions):
                region = tss_regions[frag_contig].get_region_containing_point(position)
                tss_relpos[region.get_relative_position(position)] += 1
            if point_is_in_target(frag_contig, position, ctcf_regions):
                region = ctcf_regions[frag_contig].get_region_containing_point(position)
                ctcf_relpos[region.get_relative_position(position)] += 1
            if point_is_in_target(frag_contig, position, peak_regions):
                barcode_counts["peak_region_cutsites"] += 1

        # Count up fragment overlap targeting; "on target" means overlapping
        # any of the TSS / DNase / enhancer / promoter tracks.
        is_targeted = False
        if fragment_overlaps_target(frag_contig, start, stop, tss_regions):
            barcode_counts["TSS_fragments"] += 1
            is_targeted = True
        if fragment_overlaps_target(frag_contig, start, stop, dnase_regions):
            barcode_counts["DNase_sensitive_region_fragments"] += 1
            is_targeted = True
        if fragment_overlaps_target(frag_contig, start, stop, enhancer_regions):
            barcode_counts["enhancer_region_fragments"] += 1
            is_targeted = True
        if fragment_overlaps_target(frag_contig, start, stop, promoter_regions):
            barcode_counts["promoter_region_fragments"] += 1
            is_targeted = True
        if is_targeted:
            barcode_counts["on_target_fragments"] += 1
        if fragment_overlaps_target(frag_contig, start, stop, blacklist_regions):
            barcode_counts["blacklist_region_fragments"] += 1
        if in_peak:
            barcode_counts["peak_region_fragments"] += 1
    return read_count, counts_by_barcode, tss_relpos, ctcf_relpos