def __init__(self, chromosome_name, bam_file_path, draft_file_path, truth_bam, train_mode):
    """
    Build a manager bound to one chromosome and open the file handlers it needs.

    :param chromosome_name: Name of the chromosome
    :param bam_file_path: Path to the BAM file
    :param draft_file_path: Path to the reference FASTA file
    :param truth_bam: Path to the truth sequence to reference mapping file
    :param train_mode: When truthy, a BAM handler is also opened for truth_bam
    """
    # --- plain bookkeeping values ---
    self.chromosome_name = chromosome_name
    self.bam_path = bam_file_path
    self.fasta_path = draft_file_path
    self.train_mode = train_mode
    self.downsample_rate = 1.0

    # --- handlers used to query the input files ---
    self.bam_handler = PEPPER.BAM_handler(bam_file_path)
    self.fasta_handler = PEPPER.FASTA_handler(draft_file_path)

    # the truth alignments are only opened when training; otherwise the handler stays None
    self.truth_bam_handler = PEPPER.BAM_handler(truth_bam) if self.train_mode else None
def make_train_images(bam_filepath, truth_bam_filepath, fasta_filepath, region, region_bed, output_dir, threads):
    """
    Validate the inputs and generate training images in parallel.

    The truth BAM is forwarded to the image generation step and train_mode is
    switched on. Every failed validation is reported on stderr with a timestamp
    and terminates the process with exit status 1.

    :param bam_filepath: Path to the input bam file.
    :param truth_bam_filepath: Path to the bam where truth is aligned to the assembly.
    :param fasta_filepath: Path to the input fasta file.
    :param region: Specific region of interest.
    :param region_bed: Path to a BED file of regions; forwarded to get_chromosome_list.
    :param output_dir: Path to the output directory.
    :param threads: Number of threads to use, must be greater than zero.
    :return: None
    """

    def _abort(reason):
        # Emit one timestamped error line and stop with a failing exit status.
        sys.stderr.write("[" + str(datetime.now().strftime('%m-%d-%Y %H:%M:%S')) + "] ERROR: " + reason + "\n")
        sys.exit(1)

    # check the bam file: it must exist and the handler built from it must be truthy
    if not os.path.isfile(bam_filepath) or not PEPPER.BAM_handler(bam_filepath):
        _abort("CAN NOT LOCATE BAM FILE.")

    # check the truth bam file the same way
    if not os.path.isfile(truth_bam_filepath) or not PEPPER.BAM_handler(truth_bam_filepath):
        _abort("CAN NOT LOCATE TRUTH BAM FILE.")

    # check the fasta file
    if not os.path.isfile(fasta_filepath):
        _abort("CAN NOT LOCATE FASTA FILE.")

    # check the output directory; the helper's return value replaces the given path
    output_dir = UserInterfaceSupport.handle_output_directory(os.path.abspath(output_dir))

    # check number of threads: zero is rejected too, so the message says ">0"
    if threads <= 0:
        _abort("THREAD NEEDS TO BE >0.")

    # get the list of contigs
    contig_list = UserInterfaceSupport.get_chromosome_list(region, fasta_filepath, bam_filepath, region_bed)

    # call the parallelization method to generate images in parallel
    UserInterfaceSupport.chromosome_level_parallelization(contig_list,
                                                          bam_filepath,
                                                          fasta_filepath,
                                                          truth_bam=truth_bam_filepath,
                                                          output_path=output_dir,
                                                          total_threads=threads,
                                                          train_mode=True)
def reads_to_reference_realignment(self, region_start, region_end, reads):
    """
    Locally realign a collection of reads against the reference of this chromosome.

    :param region_start: Start of the region; also the start of the reference window
    :param region_end: End of the region; the reference window is extended past it by
                       AlingerOptions.ALIGNMENT_SAFE_BASES
    :param reads: Reads to realign
    :return: The reads returned by the aligner, or an empty list when no reads are given
    """
    if reads:
        window_start = region_start
        window_end = region_end + AlingerOptions.ALIGNMENT_SAFE_BASES
        window_sequence = self.fasta_handler.get_reference_sequence(self.chromosome_name,
                                                                    window_start,
                                                                    window_end)
        read_aligner = PEPPER.ReadAligner(window_start, window_end, window_sequence)
        return read_aligner.align_reads_to_reference(reads)

    # nothing to realign
    return []
def alignment_stitch(sequence_chunks):
    """
    Merge the sequence chunks of one contig into a single contiguous sequence.

    Each chunk is a 4-tuple (contig, start, end, sequence). Chunks are processed in
    (start, end) order. A chunk starting at or after the current end is appended
    as-is. A chunk starting before the current end is aligned against the tail of
    the running sequence: when the alignment yields usable positions the two are
    spliced at those positions, otherwise both are kept whole and separated by
    ten 'N' bases.

    :param sequence_chunks: Non-empty iterable of (contig, start, end, sequence) tuples
    :return: (contig, start, end, sequence) where contig and start come from the first
             chunk in sorted order and end comes from the last chunk processed
    """
    ordered_chunks = sorted(sequence_chunks, key=lambda chunk: (chunk[1], chunk[2]))
    contig, running_start, running_end, running_sequence = ordered_chunks[0]

    aligner = PEPPER.Aligner(MATCH_PENALTY, MISMATCH_PENALTY, GAP_PENALTY, GAP_EXTEND_PENALTY)
    alignment_filter = PEPPER.Filter()
    # placed between two chunks whose overlap could not be resolved
    unresolved_spacer = 10 * 'N'

    for _, this_start, this_end, this_sequence in ordered_chunks[1:]:
        if this_start >= running_end:
            # there was a gap before this chunk, which could be low read coverage in a small contig
            running_sequence = running_sequence + this_sequence
            running_end = this_end
            continue

        # the chunks overlap; widen the compared window by the expected base error rate
        overlap_bases = running_end - this_start
        overlap_bases = overlap_bases + int(overlap_bases * BASE_ERROR_RATE)

        reference_sequence = running_sequence[-overlap_bases:]
        read_sequence = this_sequence[:overlap_bases]

        alignment = PEPPER.Alignment()
        aligner.SetReferenceSequence(reference_sequence, len(reference_sequence))
        aligner.Align_cpp(read_sequence, alignment_filter, alignment, 0)

        # a zero score is handled exactly like a -1 position: no splice point
        pos_a, pos_b = -1, -1
        if alignment.best_score != 0:
            pos_a, pos_b = get_confident_positions(alignment)

        if pos_a == -1 or pos_b == -1:
            # trust both sides in full and mark the unresolved junction with N's
            running_sequence = running_sequence + unresolved_spacer + this_sequence
        else:
            # left part without the overlap, overlap up to pos_a, then the new chunk from pos_b on
            running_sequence = (running_sequence[:-overlap_bases]
                                + reference_sequence[:pos_a]
                                + this_sequence[pos_b:])
        running_end = this_end

    return contig, running_start, running_end, running_sequence
def polish(bam_filepath, fasta_filepath, output_path, threads, region, model_path, batch_size, gpu_mode, device_ids,
           num_workers):
    """
    Run all the sub-modules to polish an input assembly.

    The arguments are validated first; a failed check is reported on stderr with a
    timestamp and the process exits with status 1. Then three steps run in order:
    image generation (make_images), inference (call_consensus) and stitching
    (perform_stitch). Images and predictions are written to sub-directories of the
    output directory whose names carry a run id derived from the current time.

    :param bam_filepath: Path to the input bam file.
    :param fasta_filepath: Path to the input fasta file.
    :param output_path: Path to the output directory.
    :param threads: Number of threads to use, must be greater than zero.
    :param region: Specific region of interest, forwarded to make_images.
    :param model_path: Path to the model file.
    :param batch_size: Batch size for inference, must be greater than zero.
    :param gpu_mode: When truthy, CUDA availability and the requested devices are checked.
    :param device_ids: Comma separated device ids or None. Parsed to a list of ints only
                       when gpu_mode is truthy; forwarded to call_consensus.
    :param num_workers: Number of workers for inference, must not be negative.
    :return: None
    """

    def _log(message):
        # Every status line is prefixed with the current local time.
        sys.stderr.write("[" + str(datetime.now().strftime('%m-%d-%Y %H:%M:%S')) + "] " + message)

    def _abort(reason, hint=None):
        # Report the error, optionally followed by an un-stamped hint, and stop with a failing status.
        _log("ERROR: " + reason + "\n")
        if hint is not None:
            sys.stderr.write(hint)
        sys.exit(1)

    # check the bam file: it must exist and the handler built from it must be truthy
    if not os.path.isfile(bam_filepath) or not PEPPER.BAM_handler(bam_filepath):
        _abort("CAN NOT LOCATE BAM FILE.")

    # check the fasta file
    if not os.path.isfile(fasta_filepath):
        _abort("CAN NOT LOCATE FASTA FILE.")

    # check the model file
    if not os.path.isfile(model_path):
        _abort("CAN NOT LOCATE MODEL FILE.")

    # check number of threads: zero is rejected too, so the message says ">0"
    if threads <= 0:
        _abort("THREAD NEEDS TO BE >0.")

    # check batch_size
    if batch_size <= 0:
        _abort("batch_size NEEDS TO BE >0.")

    # check num_workers
    if num_workers < 0:
        _abort("num_workers NEEDS TO BE >=0.")

    # one caller per thread; the quotient only drops to zero for a fractional thread count
    callers = threads
    threads_per_caller = int(threads / max(1, callers))
    if threads_per_caller <= 0:
        _abort("THREAD PER CALLER NEEDS TO BE >0.")

    # check if gpu inference can be done
    if gpu_mode:
        if not torch.cuda.is_available():
            _abort("TORCH IS NOT BUILT WITH CUDA.",
                   "SEE TORCH CAPABILITY:\n$ python3\n"
                   ">>> import torch \n"
                   ">>> torch.cuda.is_available()\n If true then cuda is avilable")

        # check if all devices are available
        if device_ids is not None:
            device_ids = [int(i) for i in device_ids.split(',')]
            for device_id in device_ids:
                major_capable, minor_capable = torch.cuda.get_device_capability(device=device_id)
                if major_capable < 0:
                    _abort("GPU DEVICE: " + str(device_id) + " IS NOT CUDA CAPABLE.",
                           "Try running: $ python3\n"
                           ">>> import torch \n"
                           ">>> torch.cuda.get_device_capability(device=" + str(device_id) + ")\n")
                _log("INFO: CAPABILITY OF GPU#" + str(device_id) + ":\t" + str(major_capable)
                     + "-" + str(minor_capable) + "\n")

    # the run id keeps the outputs of separate runs apart
    timestr = time.strftime("%m%d%Y_%H%M%S")

    # run directories
    output_dir = UserInterfaceSupport.handle_output_directory(output_path)
    image_output_directory = output_dir + "images_" + str(timestr) + "/"
    prediction_output_directory = output_dir + "predictions_" + str(timestr) + "/"

    _log("INFO: RUN-ID: " + str(timestr) + "\n")
    _log("INFO: IMAGE OUTPUT: " + str(image_output_directory) + "\n")

    _log("STEP 1: GENERATING IMAGES\n")
    sys.stderr.flush()
    # call the parallelization method to generate images in parallel
    make_images(bam_filepath, fasta_filepath, region, image_output_directory, threads)

    _log("STEP 2: RUNNING INFERENCE\n")
    _log("INFO: PREDICTION OUTPUT: " + str(prediction_output_directory) + "\n")
    sys.stderr.flush()
    call_consensus(image_output_directory,
                   model_path,
                   batch_size,
                   num_workers,
                   prediction_output_directory,
                   device_ids,
                   gpu_mode,
                   threads)

    _log("STEP 3: RUNNING STITCH\n")
    _log("INFO: STITCH OUTPUT: " + str(output_dir) + "\n")
    sys.stderr.flush()
    perform_stitch(prediction_output_directory, output_dir, threads)
def get_chromosome_list(chromosome_names, ref_file, bam_file, region_bed):
    """
    PARSES THROUGH THE CHROMOSOME PARAMETER TO FIND OUT WHICH REGIONS TO PROCESS

    Three sources are tried in this order:
    1. Neither chromosome_names nor region_bed is given: every contig present in both
       the FASTA and the BAM file is returned, each with region None.
    2. region_bed is given: one entry per BED line, region is the sorted [start, end].
       Blank lines and lines starting with '#' are skipped.
    3. Otherwise chromosome_names is parsed as a comma separated list. An entry may
       carry a region ("chr1:100-200") and may be a numeric range ("chr1-5" expands to
       chr1 ... chr5). A hyphenated name with a part that has no digits is kept as-is.

    An invalid region (wrong shape, non-integer coordinate or start > end) is reported
    on stderr and terminates the process with exit status 1.

    :param chromosome_names: NAME OF CHROMOSOME
    :param ref_file: PATH TO THE REFERENCE FILE
    :param bam_file: PATH TO BAM FILE
    :param region_bed: PATH TO A TAB SEPARATED BED FILE (name, start, end), OR A FALSY VALUE
    :return: LIST OF (name, region) TUPLES WHERE region IS None OR [start, end]
    """

    def _invalid_region():
        # An unusable --region is a failure, so the exit status must be non-zero.
        sys.stderr.write("ERROR: --region INVALID value.\n")
        sys.exit(1)

    if not chromosome_names and not region_bed:
        fasta_handler = PEPPER.FASTA_handler(ref_file)
        bam_handler = PEPPER.BAM_handler(bam_file)
        bam_contigs = bam_handler.get_chromosome_sequence_names()
        fasta_contigs = fasta_handler.get_chromosome_names()
        common_contigs = list(set(fasta_contigs) & set(bam_contigs))

        if len(common_contigs) == 0:
            sys.stderr.write("[" + datetime.now().strftime('%m-%d-%Y %H:%M:%S') + "] "
                             + "ERROR: NO COMMON CONTIGS FOUND BETWEEN THE BAM FILE AND THE FASTA FILE.")
            sys.stderr.flush()
            sys.exit(1)

        common_contigs = sorted(common_contigs, key=UserInterfaceSupport.natural_key)
        sys.stderr.write("[" + datetime.now().strftime('%m-%d-%Y %H:%M:%S') + "] INFO: COMMON CONTIGS FOUND: "
                         + str(common_contigs) + "\n")
        sys.stderr.flush()

        return [(contig_name, None) for contig_name in common_contigs]

    if region_bed:
        chromosome_name_list = []
        with open(region_bed) as bed_file:
            for line in bed_file:
                # tolerate blank lines and comments instead of failing on them
                if not line.strip() or line.startswith('#'):
                    continue
                fields = line.rstrip().split('\t')
                chr_name, start_pos, end_pos = fields[0], int(fields[1]), int(fields[2])
                chromosome_name_list.append((chr_name, sorted([start_pos, end_pos])))
        return chromosome_name_list

    chromosome_name_list = []
    for name in (piece.strip() for piece in chromosome_names.strip().split(',')):
        # split on region
        region = None
        if ':' in name:
            name_region = name.split(':')
            if len(name_region) != 2:
                _invalid_region()
            name, region_text = name_region
            try:
                region = [int(pos) for pos in region_text.strip().split('-')]
            except ValueError:
                # a non-integer coordinate gets the same error as any other bad region
                _invalid_region()
            if len(region) != 2 or region[0] > region[1]:
                _invalid_region()

        range_split = name.split('-')
        digit_groups = [''.join(ch for ch in item if ch.isdigit()) for item in range_split]

        # only expand a range when every hyphen separated part carries a number
        if len(range_split) > 1 and all(digit_groups):
            # everything before the first digit of the name is the shared prefix
            chr_prefix = ''
            for ch in name:
                if ch.isdigit():
                    break
                chr_prefix = chr_prefix + ch

            int_ranges = sorted(int(digits) for digits in digit_groups)
            for chr_seq in range(int_ranges[0], int_ranges[-1] + 1):
                chromosome_name_list.append((chr_prefix + str(chr_seq), region))
        else:
            chromosome_name_list.append((name, region))

    return chromosome_name_list
def chromosome_level_parallelization(chr_list, bam_file, draft_file, truth_bam, output_path, total_threads,
                                     train_mode):
    """
    Cut the requested regions into fixed-size intervals and generate images for them
    with a pool of worker processes.

    :param chr_list: Iterable of (chromosome name, region) pairs; a falsy region means the whole contig
    :param bam_file: Path to the BAM file, forwarded to the workers
    :param draft_file: Path to the FASTA file, used here for contig lengths and forwarded to the workers
    :param truth_bam: Path to the truth BAM file, forwarded to the workers
    :param output_path: Output directory, forwarded to the workers
    :param total_threads: Number of worker processes and of submitted tasks
    :param train_mode: Forwarded to the workers
    :return: None, progress and errors are written to stderr
    """
    # both training and inference currently use the same interval length
    max_size = 1000

    start_time = time.time()
    fasta_handler = PEPPER.FASTA_handler(draft_file)

    contigs = set()
    all_intervals = []

    # first calculate all the intervals that we need to process
    for chr_name, region in chr_list:
        contigs.add(str(chr_name))

        if region:
            # clamp the requested region to the contig
            interval_start, interval_end = tuple(region)
            interval_start = max(0, interval_start)
            interval_end = min(interval_end, fasta_handler.get_chromosome_sequence_length(str(chr_name)) - 1)
        else:
            interval_start = 0
            interval_end = fasta_handler.get_chromosome_sequence_length(str(chr_name)) - 1

        # every interval covers max_size bases and is padded on both sides by
        # MIN_IMAGE_OVERLAP, never reaching outside [interval_start, interval_end]
        for pos in range(interval_start, interval_end, max_size):
            padded_start = max(interval_start, pos - ImageSizeOptions.MIN_IMAGE_OVERLAP)
            padded_end = min(interval_end, pos + max_size + ImageSizeOptions.MIN_IMAGE_OVERLAP)
            all_intervals.append((chr_name, padded_start, padded_end))

    # all intervals calculated now, report the totals
    sys.stderr.write("[" + datetime.now().strftime('%m-%d-%Y %H:%M:%S') + "] "
                     + "INFO: TOTAL CONTIGS: " + str(len(contigs))
                     + " TOTAL INTERVALS: " + str(len(all_intervals)) + "\n")
    sys.stderr.flush()

    args = (output_path, bam_file, draft_file, truth_bam, train_mode)
    with concurrent.futures.ProcessPoolExecutor(max_workers=total_threads) as executor:
        # every worker receives the complete interval list together with its own id
        futures = []
        for worker_id in range(total_threads):
            futures.append(executor.submit(UserInterfaceSupport.image_generator,
                                           args, all_intervals, total_threads, worker_id))

        for fut in concurrent.futures.as_completed(futures):
            failure = fut.exception()
            if failure is not None:
                sys.stderr.write("[" + str(datetime.now().strftime('%m-%d-%Y %H:%M:%S')) + "] ERROR: "
                                 + str(failure) + "\n")
            else:
                # the worker hands back its id
                thread_id = fut.result()
                sys.stderr.write("[" + datetime.now().strftime('%m-%d-%Y %H:%M:%S') + "] "
                                 + "INFO: THREAD " + str(thread_id) + " FINISHED SUCCESSFULLY.\n")
            fut._result = None  # python issue 27144

    end_time = time.time()
    mins = int((end_time - start_time) / 60)
    secs = int((end_time - start_time)) % 60

    sys.stderr.write("[" + str(datetime.now().strftime('%m-%d-%Y %H:%M:%S')) + "] INFO: FINISHED IMAGE GENERATION\n")
    sys.stderr.write("[" + str(datetime.now().strftime('%m-%d-%Y %H:%M:%S')) + "] INFO: ELAPSED TIME: "
                     + str(mins) + " Min " + str(secs) + " Sec\n")
def create_summary(self, truth_bam_handler, train_mode, realignment_flag=True):
    """
    Generate the image chunks for the region of this object.

    In train mode the truth reads of the region are fetched, conflicting truth regions
    are removed and one training summary is generated per kept truth region. Otherwise
    a single summary is generated for the whole region. In both modes regions holding
    more than AlingerOptions.MAX_READS_IN_REGION reads are down-sampled first.

    :param truth_bam_handler: BAM handler of the truth alignments, only used when train_mode is truthy
    :param train_mode: Truthy to generate training summaries from the truth reads
    :param realignment_flag: When truthy, reads (and truth reads) are locally realigned to the reference
    :return: Four lists: images, labels, positions and image chunk ids.
             All four are empty when there is no truth region (train mode) or no read (otherwise).
    """

    def _limit_read_depth(reads):
        # Reservoir sampling with a fixed seed, so the same reads are kept on every run.
        # https://github.com/google/nucleus/blob/master/nucleus/util/utils.py
        max_reads = AlingerOptions.MAX_READS_IN_REGION
        if len(reads) <= max_reads:
            return reads

        rng = np.random.RandomState(AlingerOptions.RANDOM_SEED)
        sample = []
        for index, read in enumerate(reads):
            if len(sample) < max_reads:
                sample.append(read)
            else:
                slot = rng.randint(0, index + 1)
                if slot < max_reads:
                    sample[slot] = read
        return sample

    all_images = []
    all_labels = []
    all_positions = []
    all_image_chunk_ids = []

    if train_mode:
        # get the truth reads from the bam file
        include_supplementary = True
        min_mapq = 60
        min_baseq = 0
        truth_reads = truth_bam_handler.get_reads(self.chromosome_name,
                                                  self.region_start_position,
                                                  self.region_end_position,
                                                  include_supplementary,
                                                  min_mapq,
                                                  min_baseq)

        # do a local realignment of truth reads to reference
        if realignment_flag:
            truth_reads = self.reads_to_reference_realignment(self.region_start_position,
                                                              self.region_end_position,
                                                              truth_reads)

        # each entry: start, end, read, is_kept
        truth_regions = [[read.pos, read.pos_end - 1, read, True] for read in truth_reads]

        # these are all the regions we will use to generate summaries from.
        # It's important to notice that we need to realign the reads to the reference before we do that.
        truth_regions = self.remove_conflicting_regions(truth_regions)

        if not truth_regions:
            return [], [], [], []

        for region in truth_regions:
            region_start, region_end, truth_read, is_kept = tuple(region)

            if not is_kept:
                continue

            ref_start = region_start
            # ref_seq should contain the region_end base
            ref_end = region_end + 1
            ref_seq = self.fasta_handler.get_reference_sequence(self.chromosome_name, ref_start, ref_end)

            read_start = max(0, region_start)
            read_end = region_end
            include_supplementary = False
            all_reads = self.bam_handler.get_reads(self.chromosome_name,
                                                   read_start,
                                                   read_end,
                                                   include_supplementary,
                                                   0,
                                                   0)
            if len(all_reads) == 0:
                continue

            all_reads = _limit_read_depth(all_reads)

            if realignment_flag:
                all_reads = self.reads_to_reference_realignment(read_start, read_end, all_reads)

            summary_generator = PEPPER.SummaryGenerator(ref_seq, self.chromosome_name, ref_start, ref_end)
            summary_generator.generate_train_summary(all_reads, region_start, region_end, truth_read)

            images, labels, positions, chunk_ids = self.chunk_images_train(
                summary_generator,
                chunk_size=ImageSizeOptions.SEQ_LENGTH,
                chunk_overlap=ImageSizeOptions.SEQ_OVERLAP)

            all_images.extend(images)
            all_labels.extend(labels)
            all_positions.extend(positions)
            all_image_chunk_ids.extend(chunk_ids)
    else:
        # HERE REALIGN THE READS TO THE REFERENCE THEN GENERATE THE SUMMARY TO GET A POLISHED HAPLOTYPE
        read_start = max(0, self.region_start_position)
        read_end = self.region_end_position
        include_supplementary = False
        all_reads = self.bam_handler.get_reads(self.chromosome_name,
                                               read_start,
                                               read_end,
                                               include_supplementary,
                                               0,
                                               0)
        if len(all_reads) == 0:
            return [], [], [], []

        all_reads = _limit_read_depth(all_reads)

        if realignment_flag:
            all_reads = self.reads_to_reference_realignment(self.region_start_position,
                                                            self.region_end_position,
                                                            all_reads)

        # ref_seq should contain region_end_position base
        ref_seq = self.fasta_handler.get_reference_sequence(self.chromosome_name,
                                                            self.region_start_position,
                                                            self.region_end_position + 1)

        summary_generator = PEPPER.SummaryGenerator(ref_seq,
                                                    self.chromosome_name,
                                                    self.region_start_position,
                                                    self.region_end_position)
        summary_generator.generate_summary(all_reads,
                                           self.region_start_position,
                                           self.region_end_position)

        images, labels, positions, chunk_ids = self.chunk_images(summary_generator,
                                                                 chunk_size=ImageSizeOptions.SEQ_LENGTH,
                                                                 chunk_overlap=ImageSizeOptions.SEQ_OVERLAP)

        all_images.extend(images)
        all_labels.extend(labels)
        all_positions.extend(positions)
        all_image_chunk_ids.extend(chunk_ids)

    # internal consistency check: one label and one chunk id per image
    assert (len(all_images) == len(all_labels) == len(all_image_chunk_ids))

    return all_images, all_labels, all_positions, all_image_chunk_ids