def join(args, outs, chunk_defs, chunk_outs):
    """Merge per-chunk barcode counts and call cells for each species.

    Steps:
      1. Sum the per-chunk ``barcode_counts`` and ``targeted_counts`` pickles by
         the species of each chunk's contig, and total ``fragment_depth``.
      2. When ``args.excluded_barcodes`` is given (a JSON file indexed by
         species), drop those barcodes for the matching species.
      3. Per species, either fit a model to the shifted targeted counts
         (``estimate_parameters`` / ``estimate_threshold``) or, when cells are
         forced via ``args.force_cells`` or the fit calls more than
         ``MAXIMUM_CELLS_PER_SPECIES`` cells, use the threshold that selects the
         requested number of top barcodes.
      4. For exactly two species with no forced cells, compute a secondary
         threshold per species.
      5. Write ``outs.cell_barcodes``, ``outs.singlecell`` and
         ``outs.cell_calling_summary``.

    When ``args.fragments`` is None all three outputs are set to None.

    Raises:
        ValueError: if a forced cell count is not positive, or if the number of
            barcodes passing a threshold disagrees with ``cells_detected``.
    """
    args.coerce_strings()
    outs.coerce_strings()
    if args.fragments is None:
        outs.cell_barcodes = None
        outs.cell_calling_summary = None
        outs.singlecell = None
        return

    if args.excluded_barcodes is not None:
        with open(args.excluded_barcodes, 'r') as infile:
            excluded_barcodes = json.load(infile)
    else:
        excluded_barcodes = None

    # Merge the chunk inputs
    ref = ReferenceManager(args.reference_path)
    species_list = ref.list_species()

    barcode_counts_by_species = {species: Counter() for species in species_list}
    targeted_counts_by_species = {species: Counter() for species in species_list}
    fragment_depth = 0
    for chunk_in, chunk_out in zip(chunk_defs, chunk_outs):
        species = ref.species_from_contig(chunk_in.contig)
        with open(chunk_out.barcode_counts, 'r') as infile:
            barcode_counts_by_species[species] += pickle.load(infile)
        with open(chunk_out.targeted_counts, 'r') as infile:
            targeted_counts_by_species[species] += pickle.load(infile)
        fragment_depth += chunk_out.fragment_depth
    print('Total fragments across all chunks: {}'.format(fragment_depth))

    barcodes = list({bc for species in species_list for bc in barcode_counts_by_species[species]})

    # Per-species barcode lists, kept in the order of ``barcodes``. The write-out
    # below zips these against retained_counts, so both must share this order.
    # Without an exclusion file every barcode is kept.
    if excluded_barcodes is None:
        non_excluded_barcodes = {species: barcodes for species in species_list}
    else:
        non_excluded_barcodes = {}
        for species in species_list:
            # Build the exclusion set once so each membership test is O(1).
            excluded = set(excluded_barcodes[species])
            non_excluded_barcodes[species] = [bc for bc in barcodes if bc not in excluded]
    print('Total barcodes observed: {}'.format(len(barcodes)))

    retained_counts = {}
    for species in species_list:
        targeted = targeted_counts_by_species[species]
        retained_counts[species] = np.array([targeted[bc] for bc in non_excluded_barcodes[species]])
        if excluded_barcodes is not None:
            print('Barcodes excluded for species {}: {}'.format(
                species, len(excluded_barcodes[species])))
            print('Barcodes remaining for species {}: {}'.format(
                species, len(non_excluded_barcodes[species])))

    parameters = {}

    whitelist_length = len(load_barcode_whitelist(args.barcode_whitelist))
    count_shift = max(
        MINIMUM_COUNT,
        int(fragment_depth * WHITELIST_CONTAM_RATE / whitelist_length))
    print('Count shift for whitelist contamination: {}'.format(count_shift))

    for species, count_data in retained_counts.items():
        print('Analyzing species {}'.format(species))
        # Drop barcodes below count_shift and subtract count_shift from the rest,
        # to remove the effects of whitelist contamination.
        shifted_data = count_data[count_data >= count_shift] - count_shift
        print('Number of barcodes analyzed: {}'.format(len(shifted_data)))
        count_dict = Counter(shifted_data)
        parameters[species] = {}

        forced_cell_count = None
        if args.force_cells is not None:
            if species in args.force_cells:
                forced_cell_count = int(args.force_cells[species])
            elif "default" in args.force_cells:
                forced_cell_count = int(args.force_cells["default"])
            # force_cells may name neither this species nor "default"; never order
            # None against an int.
            if forced_cell_count is not None and forced_cell_count > MAXIMUM_CELLS_PER_SPECIES:
                # Report the requested value before clamping it.
                martian.log_info(
                    'Attempted to force cells to {}. Overriding to maximum allowed cells.'
                    .format(forced_cell_count))
                forced_cell_count = MAXIMUM_CELLS_PER_SPECIES

        # Initialize parameters to empty
        parameters[species]['noise_mean'] = None
        parameters[species]['noise_dispersion'] = None
        parameters[species]['signal_mean'] = None
        parameters[species]['signal_dispersion'] = None
        parameters[species]['fraction_noise'] = None
        parameters[species]['cell_threshold'] = None
        parameters[species]['goodness_of_fit'] = None
        parameters[species]['estimated_cells_present'] = 0

        # Corner case where FRIP is 0 because the number of peaks is tiny (fuzzer tests)
        if len(count_dict) < 10:
            parameters[species]['cells_detected'] = 0
            forced_cell_count = None
        elif forced_cell_count is None:
            print('Estimating parameters')
            fitted_params = estimate_parameters(count_dict)
            # The threshold is fitted on shifted counts; shift it back to raw counts.
            signal_threshold = estimate_threshold(fitted_params, CELL_CALLING_THRESHOLD) + count_shift
            print('Primary threshold: {}'.format(signal_threshold))
            parameters[species]['noise_mean'] = fitted_params.mu_noise
            parameters[species]['noise_dispersion'] = fitted_params.alpha_noise
            parameters[species]['signal_mean'] = fitted_params.mu_signal
            parameters[species]['signal_dispersion'] = fitted_params.alpha_signal
            parameters[species]['fraction_noise'] = fitted_params.frac_noise
            parameters[species]['cell_threshold'] = signal_threshold
            parameters[species]['goodness_of_fit'] = goodness_of_fit(shifted_data, fitted_params)
            called_cell_count = np.sum(count_data >= signal_threshold)
            parameters[species]['cells_detected'] = called_cell_count
            parameters[species]['estimated_cells_present'] = int(
                (1 - fitted_params.frac_noise) * len(shifted_data))
            if called_cell_count > MAXIMUM_CELLS_PER_SPECIES:
                # Abort the model fitting and instead force cells to the maximum
                forced_cell_count = MAXIMUM_CELLS_PER_SPECIES

        if forced_cell_count is not None:
            print('Forcing cells to {}'.format(forced_cell_count))
            if forced_cell_count <= 0:
                raise ValueError("Force cells must be positive")
            adj_data = shifted_data[shifted_data > 0]
            print('Total barcodes considered for forcing cells: {}'.format(len(adj_data)))
            # Threshold = the forced_cell_count-th largest shifted count, or the
            # smallest one when fewer barcodes are available than requested.
            if forced_cell_count >= len(adj_data):
                forced_threshold = min(adj_data)
            else:
                forced_threshold = sorted(adj_data, reverse=True)[forced_cell_count - 1]
            parameters[species]['cell_threshold'] = forced_threshold + count_shift
            parameters[species]['cells_detected'] = np.sum(
                count_data >= parameters[species]['cell_threshold'])

    # For barnyard samples, mask out the noise distribution and re-fit to get cleaner separation
    if len(retained_counts) == 2 and not args.force_cells:
        print('Estimating secondary thresholds')
        sp1, sp2 = species_list

        # NOTE(review): these conditionals give -1 when a primary threshold exists
        # and None otherwise, which reads as inverted. As written, the count
        # filters below then keep every non-excluded barcode (targeted counts are
        # >= 0, and under Python 2 ``int > None`` is True), so 'cells_detected'
        # stays consistent with the write-out check further down. Confirm intent
        # before changing them; the filters and that check would need to change
        # together.
        sp1_threshold = -1 if parameters[sp1]['cell_threshold'] is not None else parameters[sp1]['cell_threshold']
        sp2_threshold = -1 if parameters[sp2]['cell_threshold'] is not None else parameters[sp2]['cell_threshold']

        if parameters[sp1]['cell_threshold'] is not None:
            sp1_counts = np.array([
                targeted_counts_by_species[sp1][bc] for bc in non_excluded_barcodes[sp1]
                if (targeted_counts_by_species[sp1][bc] > sp1_threshold)
                and (targeted_counts_by_species[sp2][bc] > sp2_threshold)
            ])
            sp1_params = estimate_parameters(Counter(sp1_counts), threshold=sp1_threshold)
            if not np.isnan(sp1_params.frac_noise):
                parameters[sp1]['cell_threshold'] = max(sp1_threshold, estimate_threshold(sp1_params, 20))
            parameters[sp1]['cells_detected'] = np.sum(sp1_counts >= parameters[sp1]['cell_threshold'])
        else:
            parameters[sp1]['cells_detected'] = 0

        if parameters[sp2]['cell_threshold'] is not None:
            sp2_counts = np.array([
                targeted_counts_by_species[sp2][bc] for bc in non_excluded_barcodes[sp2]
                if (targeted_counts_by_species[sp1][bc] > sp1_threshold)
                and (targeted_counts_by_species[sp2][bc] > sp2_threshold)
            ])
            sp2_params = estimate_parameters(Counter(sp2_counts), threshold=sp2_threshold)
            if not np.isnan(sp2_params.frac_noise):
                parameters[sp2]['cell_threshold'] = max(sp2_threshold, estimate_threshold(sp2_params, 20))
            parameters[sp2]['cells_detected'] = np.sum(sp2_counts >= parameters[sp2]['cell_threshold'])
        else:
            parameters[sp2]['cells_detected'] = 0

        print('Secondary threshold ({}): {}'.format(sp1, parameters[sp1]['cell_threshold']))
        print('Secondary threshold ({}): {}'.format(sp2, parameters[sp2]['cell_threshold']))

    print('Writing out cell barcodes')
    cell_barcodes = {}
    for species, count_data in retained_counts.items():
        threshold = parameters[species]['cell_threshold']
        cell_barcodes[species] = {}
        print('Cell threshold for species {}: {}'.format(species, threshold))
        if threshold is not None:
            # count_data and non_excluded_barcodes[species] share the same order.
            for count, barcode in zip(count_data, non_excluded_barcodes[species]):
                if count >= threshold:
                    print('{} - Total {}, Targeted {}, Count {}, Threshold {}'.format(
                        barcode, barcode_counts_by_species[species][barcode],
                        targeted_counts_by_species[species][barcode], count, threshold))
                    cell_barcodes[species][barcode] = count
        if len(cell_barcodes[species]) != parameters[species]['cells_detected']:
            print(len(cell_barcodes[species]), parameters[species]['cells_detected'])
            raise ValueError('Mismatch in called cells identified - failure in threshold setting')
        print('Selected {} barcodes of species {}'.format(len(cell_barcodes[species]), species))

    with open(outs.cell_barcodes, 'w') as outfile:
        # One line per species: the species name, then its cell barcodes,
        # comma-separated (a species with no cells yields "species,").
        for species in cell_barcodes.keys():
            outfile.write(species + ",")
            outfile.write(",".join(cell_barcodes[species]) + "\n")

    cell_index = compute_cell_index(species_list, cell_barcodes)

    with open(outs.singlecell, 'w') as outfile:
        outfile.write("barcode,cell_id,")
        outfile.write(",".join(["is_{}_cell_barcode".format(species) for species in species_list]))
        if len(species_list) > 1:
            for species in species_list:
                outfile.write(",passed_filters_{}".format(species))
                outfile.write(",peak_region_fragments_{}".format(species))
        outfile.write("\n")
        for barcode in [NO_BARCODE] + sorted(barcodes):
            outfile.write("{},".format(barcode))
            outfile.write("{},".format(cell_index.get(barcode, "None")))
            values = [str(int(species in cell_barcodes and barcode in cell_barcodes[species]))
                      for species in species_list]
            outfile.write(",".join(values))
            if len(species_list) > 1:
                for species in species_list:
                    outfile.write(",{:d}".format(barcode_counts_by_species[species][barcode]))
                    outfile.write(",{:d}".format(targeted_counts_by_species[species][barcode]))
            outfile.write("\n")

    # process data into summary metrics
    summary_info = {}
    summary_info.update(generate_cell_calling_metrics(parameters, cell_barcodes))
    summary_info.update(generate_gb_metrics(cell_barcodes, excluded_barcodes))

    with open(outs.cell_calling_summary, 'w') as outfile:
        outfile.write(json.dumps(summary_info, indent=4))
def join(args, outs, chunk_defs, chunk_outs):
    """Merge per-chunk outputs, flag low-targeting barcodes and write summaries.

    Per species, the fraction of the genome in peaks is the summed
    ``chunk_out.peak_coverage`` divided by the summed lengths of the chunk
    contigs. A barcode is flagged "low_targeting" when its ratio of targeted
    fragments to total fragments is below that genomic fraction.

    Outputs (all JSON):
        barcode_counts: fragment counts per barcode summed over species.
        low_targeting_summary: per-species flagged-barcode counts and genome fractions.
        low_targeting_barcodes: {"label": "low_targeting", "data": {species: {barcode: frac}}}.
        fragment_lengths, covered_bases: merged Counters keyed by padding value.

    When ``args.fragments`` is None every output is set to None.
    """
    args.coerce_strings()
    outs.coerce_strings()
    if args.fragments is None:
        outs.low_targeting_barcodes = None
        outs.low_targeting_summary = None
        # These are only written further down, so null them as well instead of
        # leaving paths to files that are never created.
        outs.barcode_counts = None
        outs.fragment_lengths = None
        outs.covered_bases = None
        return

    # Merge the chunk inputs
    ref = ReferenceManager(args.reference_path)
    species_list = ref.list_species()

    barcode_counts_by_species = {species: Counter() for species in species_list}
    targeted_counts_by_species = {species: Counter() for species in species_list}
    peak_bp_by_species = {species: 0 for species in species_list}
    genome_bp_by_species = {species: 0 for species in species_list}
    fragment_lengths = {padding: Counter() for padding in PADDING_VALUES}
    covered_bases = {padding: Counter() for padding in PADDING_VALUES}

    for chunk_in, chunk_out in zip(chunk_defs, chunk_outs):
        species = ref.species_from_contig(chunk_in.contig)
        with open(chunk_out.fragment_counts, "r") as infile:
            barcode_counts_by_species[species] += pickle.load(infile)
        with open(chunk_out.targeted_counts, "r") as infile:
            targeted_counts_by_species[species] += pickle.load(infile)
        with open(chunk_out.fragment_lengths, "r") as infile:
            data = pickle.load(infile)
            for padding in PADDING_VALUES:
                fragment_lengths[padding] += data[padding]
        with open(chunk_out.covered_bases, "r") as infile:
            data = pickle.load(infile)
            for padding in PADDING_VALUES:
                covered_bases[padding] += data[padding]
        peak_bp_by_species[species] += chunk_out.peak_coverage
        genome_bp_by_species[species] += ref.contig_lengths[chunk_in.contig]

    # float() keeps this a true ratio even under Python 2 integer division. A
    # species that received no chunks (0 bp) gets 0.0 instead of raising.
    frac_genome_in_peaks_by_species = {}
    for species in species_list:
        genome_bp = genome_bp_by_species[species]
        frac_genome_in_peaks_by_species[species] = (
            float(peak_bp_by_species[species]) / genome_bp if genome_bp else 0.0)

    # Identify barcodes that have lower fraction of reads overlapping peaks than the
    # genomic coverage of the peaks
    low_targeting_barcodes = {
        "label": "low_targeting",
        "data": {species: {} for species in species_list},
    }
    for species in species_list:
        targeted_counts = targeted_counts_by_species[species]
        genome_frac = frac_genome_in_peaks_by_species[species]
        flagged = low_targeting_barcodes["data"][species]
        for barcode, total_count in barcode_counts_by_species[species].iteritems():
            # total_count > 0 here: Counter addition keeps only positive counts.
            # float() avoids Python 2 integer division flooring the ratio to 0.
            barcode_frac_peaks = float(targeted_counts[barcode]) / total_count
            if barcode_frac_peaks < genome_frac:
                flagged[barcode] = barcode_frac_peaks

    # Sum up the total fragment counts per barcode across all species
    total_barcode_counts = Counter()
    for barcode_counts in barcode_counts_by_species.values():
        total_barcode_counts += barcode_counts
    with open(outs.barcode_counts, "w") as outfile:
        outfile.write(json.dumps(total_barcode_counts, indent=4))

    summary_data = {}
    for species in species_list:
        key_suffix = "" if len(species_list) == 1 else "_{}".format(species)
        summary_data["number_of_low_targeting_barcodes{}".format(key_suffix)] = len(
            low_targeting_barcodes["data"][species])
        summary_data[
            "fraction_of_genome_within_{}bp_of_peaks{}".format(DISTANCE, key_suffix)
        ] = frac_genome_in_peaks_by_species[species]
    with open(outs.low_targeting_summary, "w") as outfile:
        outfile.write(json.dumps(summary_data, indent=4))

    with open(outs.low_targeting_barcodes, "w") as outfile:
        outfile.write(json.dumps(low_targeting_barcodes, indent=4))

    with open(outs.fragment_lengths, "w") as outfile:
        outfile.write(json.dumps(fragment_lengths, indent=4))

    with open(outs.covered_bases, "w") as outfile:
        outfile.write(json.dumps(covered_bases, indent=4))
def get_counts_by_barcode(reference_path, peaks, fragments, fragments_index=None, contig=None, known_cells=None):
    """Generate targeting, raw and dup counts per barcode.

    If cell identity is known, then also return that as part of the counts.

    Args:
        reference_path: path given to ``ReferenceManager``. Its TSS (padded by
            2000), CTCF (padded by 250), DNase, enhancer, promoter and blacklist
            tracks are loaded when present.
        peaks: peaks track path (or None), loaded like the reference tracks.
        fragments: fragments file to iterate.
        fragments_index: index passed to ``parsed_fragments_from_contig``.
        contig: if given, only fragments of this contig are read; otherwise the
            whole file is read with ``open_fragment_file``.
        known_cells: optional path to a text file whose lines are
            ``species,barcode1,barcode2,...``. Empty and "null" barcode fields
            are ignored.

    Returns:
        Tuple ``(read_count, counts_by_barcode, tss_relpos, ctcf_relpos)``.
        ``read_count`` adds 2 per fragment record. ``counts_by_barcode`` maps
        each barcode to a Counter of metrics. ``tss_relpos`` and ``ctcf_relpos``
        count fragment ends by the value of ``get_relative_position`` within the
        containing TSS / CTCF region.
    """
    def load_reference_track(track, padding=0):
        # An absent track is represented as None.
        if track is None:
            return None
        with open(track, 'r') as infile:
            return regtools.get_target_regions(infile, padding=padding)

    def point_is_in_target(chrom, position, target_regions):
        if target_regions is None or chrom not in target_regions:
            return False
        return target_regions[chrom].contains_point(position)

    def fragment_overlaps_target(chrom, start, stop, target_regions):
        if target_regions is None or chrom not in target_regions:
            return False
        return target_regions[chrom].overlaps_region(start, stop)

    ref_manager = ReferenceManager(reference_path)

    # Load in and pad TSS/CTCF regions if present
    tss_regions = load_reference_track(ref_manager.tss_track, padding=2000)
    ctcf_regions = load_reference_track(ref_manager.ctcf_track, padding=250)

    # Load in regions from reference-associated tracks
    dnase_regions = load_reference_track(ref_manager.dnase_track)
    enhancer_regions = load_reference_track(ref_manager.enhancer_track)
    promoter_regions = load_reference_track(ref_manager.promoter_track)
    blacklist_regions = load_reference_track(ref_manager.blacklist_track)
    peak_regions = load_reference_track(peaks)

    # load cell - species map
    cell_barcodes = {}
    species_list = ref_manager.list_species()
    if known_cells is not None:
        with open(known_cells, 'r') as infile:
            for line in infile:
                items = line.rstrip("\r\n").split(",")
                for barcode in items[1:]:
                    # Skip empty fields (a species with no cells is written as
                    # "species,") as well as "null" placeholders.
                    if barcode and barcode != "null":
                        cell_barcodes.setdefault(barcode, []).append(items[0])

    # get cell index; numbering follows the iteration order of cell_barcodes
    cell_index = {}
    spnum = {species: 0 for species in species_list}
    for species in species_list:
        for barcode in cell_barcodes:
            if species in cell_barcodes[barcode]:
                label = "{}_cell_{}".format(species, spnum[species])
                spnum[species] += 1
                cell_index[barcode] = label if barcode not in cell_index else '_'.join([cell_index[barcode], label])

    counts_by_barcode = {}
    tss_relpos = Counter()
    ctcf_relpos = Counter()
    split_by_species = known_cells is not None and len(species_list) > 1

    read_count = 0
    iterator = open_fragment_file(fragments) if contig is None else \
        parsed_fragments_from_contig(contig, fragments, index=fragments_index)
    # frag_contig avoids shadowing the ``contig`` argument used just above.
    for frag_contig, start, stop, barcode, dups in iterator:
        read_count += 2
        barcode_counts = counts_by_barcode.get(barcode)
        if barcode_counts is None:
            barcode_counts = counts_by_barcode[barcode] = Counter()
            if known_cells is not None:
                cell_species = cell_barcodes.get(barcode, [])
                barcode_counts["cell_id"] = cell_index.get(barcode, "None")
                for species in species_list:
                    barcode_counts["is_{}_cell_barcode".format(species)] = 1 if species in cell_species else 0

        # Needed by both the species split and the final peak count; compute once.
        overlaps_peak = fragment_overlaps_target(frag_contig, start, stop, peak_regions)

        # species splits
        if split_by_species:
            contig_species = ref_manager.species_from_contig(frag_contig)
            barcode_counts["passed_filters_{}".format(contig_species)] += 1
            if overlaps_peak:
                barcode_counts["peak_region_fragments_{}".format(contig_species)] += 1

        # raw mapping: "total" accumulates dups, "duplicate" accumulates dups - 1
        barcode_counts["passed_filters"] += 1
        barcode_counts["total"] += dups
        barcode_counts["duplicate"] += dups - 1

        # Count up transposition site targeting at both fragment ends
        for position in (start, stop):
            if point_is_in_target(frag_contig, position, tss_regions):
                region = tss_regions[frag_contig].get_region_containing_point(position)
                tss_relpos[region.get_relative_position(position)] += 1
            if point_is_in_target(frag_contig, position, ctcf_regions):
                region = ctcf_regions[frag_contig].get_region_containing_point(position)
                ctcf_relpos[region.get_relative_position(position)] += 1
            if point_is_in_target(frag_contig, position, peak_regions):
                barcode_counts["peak_region_cutsites"] += 1

        # Count up fragment overlap targeting
        is_targeted = False
        if fragment_overlaps_target(frag_contig, start, stop, tss_regions):
            barcode_counts["TSS_fragments"] += 1
            is_targeted = True
        if fragment_overlaps_target(frag_contig, start, stop, dnase_regions):
            barcode_counts["DNase_sensitive_region_fragments"] += 1
            is_targeted = True
        if fragment_overlaps_target(frag_contig, start, stop, enhancer_regions):
            barcode_counts["enhancer_region_fragments"] += 1
            is_targeted = True
        if fragment_overlaps_target(frag_contig, start, stop, promoter_regions):
            barcode_counts["promoter_region_fragments"] += 1
            is_targeted = True
        if is_targeted:
            barcode_counts["on_target_fragments"] += 1
        if fragment_overlaps_target(frag_contig, start, stop, blacklist_regions):
            barcode_counts["blacklist_region_fragments"] += 1
        if overlaps_peak:
            barcode_counts["peak_region_fragments"] += 1
    return read_count, counts_by_barcode, tss_relpos, ctcf_relpos