def polish_genome(image_dir, model_path, batch_size, num_workers, threads, output_dir, output_prefix, gpu_mode,
                  device_ids, callers):
    """
    Run the complete polishing pipeline: call consensus on all images, then stitch the predictions.

    A run id is derived from the current local time. Predictions are written to
    ``<output_dir>/predictions_<run id>/`` and the stitched result is written to ``output_dir`` with
    ``output_prefix`` as file prefix. Progress and per-stage timings are reported on stderr.

    :param image_dir: Path to directory where all MarginPolish images are saved.
    :param model_path: Path to a trained model.
    :param batch_size: Batch size for minibatch processing.
    :param num_workers: Number of workers for minibatch processing.
    :param threads: Number of threads for pytorch (also used for stitching).
    :param output_dir: Path to the output directory.
    :param output_prefix: Prefix of the output files.
    :param gpu_mode: If true, predict method will use GPU.
    :param device_ids: List of GPU devices.
    :param callers: Total number of callers to use.
    :return: None
    """
    output_dir = FileManager.handle_output_directory(output_dir)
    timestr = time.strftime("%m%d%Y_%H%M%S")

    # predictions of this run go into a sub-directory named after the run id
    prediction_output_directory = output_dir + "/predictions_" + str(timestr) + "/"
    prediction_output_directory = FileManager.handle_output_directory(prediction_output_directory)

    sys.stderr.write(TextColor.GREEN + "INFO: RUN-ID: " + str(timestr) + "\n" + TextColor.END)
    sys.stderr.write(TextColor.GREEN + "INFO: PREDICTION OUTPUT DIRECTORY: " + str(prediction_output_directory)
                     + "\n" + TextColor.END)

    call_consensus_start_time = time.time()
    sys.stderr.write(TextColor.GREEN + "INFO: CALL CONSENSUS STARTING\n" + TextColor.END)
    call_consensus(image_dir, model_path, batch_size, num_workers, threads, prediction_output_directory,
                   output_prefix, gpu_mode, device_ids, callers)
    call_consensus_end_time = time.time()

    stitch_start_time = time.time()
    sys.stderr.write(TextColor.GREEN + "INFO: STITCH STARTING\n" + TextColor.END)
    # all progress is logged on stderr (the prediction directory was reported above); stdout stays clean
    perform_stitch(prediction_output_directory, output_dir, output_prefix, threads)
    stitch_end_time = time.time()

    call_consensus_time = get_elapsed_time_string(call_consensus_start_time, call_consensus_end_time)
    stitch_time = get_elapsed_time_string(stitch_start_time, stitch_end_time)
    overall_time = get_elapsed_time_string(call_consensus_start_time, stitch_end_time)

    sys.stderr.write(TextColor.GREEN + "INFO: FINISHED PROCESSING.\n" + TextColor.END)
    sys.stderr.write(TextColor.GREEN + "INFO: TOTAL TIME ELAPSED: " + str(overall_time) + "\n" + TextColor.END)
    sys.stderr.write(TextColor.GREEN + "INFO: PREDICTION TIME: " + str(call_consensus_time) + "\n" + TextColor.END)
    sys.stderr.write(TextColor.GREEN + "INFO: STITCH TIME: " + str(stitch_time) + "\n" + TextColor.END)
def train_interface(train_dir, test_dir, gpu_mode, device_ids, epoch_size, batch_size, num_workers, output_dir,
                    retrain_model, retrain_model_path):
    """
    Prepare the training output directories and launch model training.

    :param train_dir: Path to the directory containing training images
    :param test_dir: Path to the directory containing test images
    :param gpu_mode: If true, train with ``TrainModule.train_model_gpu``; otherwise ``TrainModule.train_model``
    :param device_ids: Device IDs of the GPUs to use
    :param epoch_size: Number of epochs to train for
    :param batch_size: Batch size
    :param num_workers: Number of workers for data loading
    :param output_dir: Path to the directory where the model is saved
    :param retrain_model: If true, an existing model is retrained
    :param retrain_model_path: Path to the model to retrain
    :return: None
    """
    # FileManager yields the pair (model_out_dir, stats_dir) derived from output_dir
    model_out_dir, stats_dir = FileManager.handle_train_output_directory(output_dir)

    trainer = TrainModule(train_dir, test_dir, gpu_mode, device_ids, epoch_size, batch_size, num_workers,
                          retrain_model, retrain_model_path, model_out_dir, stats_dir)

    # pick the training routine matching the requested device and run it
    run_training = trainer.train_model_gpu if gpu_mode else trainer.train_model
    run_training()
def __init__(self, image_directory):
    """
    Index every image stored in the HDF5 files of ``image_directory``.

    Builds ``self.all_images``, a flat list of ``(hdf5_file_path, image_name)`` tuples, so images can be
    fetched one at a time through ``__getitem__``. Files without an ``images`` group are skipped with a
    warning on stderr.

    :param image_directory: Path to a directory where all the images are saved.
    """
    # converts loaded objects into tensors
    self.transform = transforms.Compose([transforms.ToTensor()])

    indexed_images = []
    for file_path in FileManager.get_file_paths_from_directory(image_directory):
        with h5py.File(file_path, 'r') as h5_handle:
            # guard against an empty file (no 'images' group) produced by MarginPolish
            if 'images' not in h5_handle:
                sys.stderr.write(TextColor.YELLOW + "WARN: NO IMAGES FOUND IN FILE: " + file_path + "\n"
                                 + TextColor.END)
                continue
            indexed_images.extend((file_path, image_name) for image_name in list(h5_handle['images'].keys()))

    # store the index so other methods can look images up by position
    self.all_images = indexed_images
def test_interface(test_file, batch_size, gpu_mode, num_workers, model_path, output_directory, print_details):
    """
    Evaluate a trained model on test images and save base and RLE confusion matrices.

    Exits the process with status 1 if ``model_path`` is not an existing file.

    :param test_file: Path to directory containing test images
    :param batch_size: Batch size used during evaluation
    :param gpu_mode: If true the model is evaluated on GPU (forced off when print_details is set)
    :param num_workers: Number of workers for data loading
    :param model_path: Path to a saved model
    :param output_directory: Path to the output directory for the results
    :param print_details: If true then detailed logs will be printed
    :return: None
    """
    sys.stderr.write(TextColor.PURPLE + 'Loading data\n' + TextColor.END)

    # validate the model path before creating anything on disk
    if not os.path.isfile(model_path):
        sys.stderr.write(TextColor.RED + "ERROR: INVALID PATH TO MODEL\n" + TextColor.END)
        sys.exit(1)

    output_directory = FileManager.handle_output_directory(output_directory)

    sys.stderr.write(TextColor.GREEN + "INFO: MODEL LOADING\n" + TextColor.END)
    transducer_model, hidden_size, gru_layers, _ = \
        ModelHandler.load_simple_model(model_path,
                                       input_channels=ImageSizeOptions.IMAGE_CHANNELS,
                                       image_features=ImageSizeOptions.IMAGE_HEIGHT,
                                       seq_len=ImageSizeOptions.SEQ_LENGTH,
                                       num_base_classes=ImageSizeOptions.TOTAL_BASE_LABELS,
                                       num_rle_classes=ImageSizeOptions.TOTAL_RLE_LABELS)
    sys.stderr.write(TextColor.GREEN + "INFO: MODEL LOADED\n" + TextColor.END)

    # detailed printing is not available in GPU mode, so it takes precedence
    if print_details and gpu_mode:
        sys.stderr.write(TextColor.GREEN + "INFO: GPU MODE NOT AVAILABLE WHEN PRINTING DETAILS. "
                                           "SETTING GPU MODE TO FALSE.\n" + TextColor.END)
        gpu_mode = False

    if gpu_mode:
        transducer_model = transducer_model.cuda()

    stats_dictionary = test(test_file, batch_size, gpu_mode, transducer_model, num_workers, gru_layers, hidden_size,
                            num_base_classes=ImageSizeOptions.TOTAL_BASE_LABELS,
                            num_rle_classes=ImageSizeOptions.TOTAL_RLE_LABELS,
                            output_directory=output_directory,
                            print_details=print_details)

    save_rle_confusion_matrix(stats_dictionary, output_directory)
    save_base_confusion_matrix(stats_dictionary, output_directory)

    sys.stderr.write(TextColor.PURPLE + 'DONE\n' + TextColor.END)
def create_consensus_sequence(self, contig, sequence_chunk_keys, threads):
    """
    Build the consensus sequence of one contig from its predicted chunks.

    The chunk descriptors are ordered by (start, end), grouped, and each group is stitched in a separate
    worker process via ``self.small_chunk_stitch``. The partial results are ordered again and combined
    by ``self.alignment_stitch``. Failed groups are reported on stderr and left out.

    :param contig: Contig name
    :param sequence_chunk_keys: Iterable of (hdf5_file, chunk_key, start, end) tuples for the contig's chunks
    :param threads: Maximum number of worker processes
    :return: The sequence part of the tuple returned by ``self.alignment_stitch``
    """
    # prepend the contig name and cast positions to int so the ordering is numeric
    ordered_keys = sorted(
        ((contig, h5_path, chunk_key, int(start), int(stop))
         for h5_path, chunk_key, start, stop in sequence_chunk_keys),
        key=lambda item: (item[3], item[4]))

    partial_sequences = []
    with concurrent.futures.ProcessPoolExecutor(max_workers=threads) as pool:
        # roughly len/threads keys per group, but never fewer than the configured minimum
        group_size = max(StitchOptions.MIN_SEQUENCE_REQUIRED_FOR_MULTITHREADING,
                         int(len(ordered_keys) / threads) + 1)
        pending = [pool.submit(self.small_chunk_stitch, contig, group)
                   for group in FileManager.chunks(ordered_keys, group_size)]

        # collect results in completion order; they are re-sorted below
        for done in concurrent.futures.as_completed(pending):
            error = done.exception()
            if error is not None:
                sys.stderr.write("ERROR: " + str(error) + "\n")
            else:
                part_contig, part_start, part_end, part_sequence = done.result()
                partial_sequences.append((part_contig, part_start, part_end, part_sequence))
            # release the stored result (python issue 27144)
            done._result = None

    partial_sequences.sort(key=lambda item: (item[1], item[2]))

    # final stitch over all partial sequences; only the sequence is returned
    _, _, _, consensus = self.alignment_stitch(partial_sequences)
    return consensus
def download_models(output_dir):
    """
    Download every model listed in the remote model description file into ``output_dir``.

    The description file is a CSV with one ``model_name,model_url`` entry per line. It is downloaded into
    ``output_dir``, read, and deleted; each listed model URL is then downloaded into ``output_dir``.
    Blank lines are ignored and lines without a comma are skipped with a warning.

    :param output_dir: Directory to save the models in (passed through FileManager.handle_output_directory).
    :return: None
    """
    output_dir = FileManager.handle_output_directory(output_dir)
    sys.stderr.write(TextColor.YELLOW + "DOWNLOADING MODEL DESCRIPTION FILE" + TextColor.END + "\n")
    description_file = "https://storage.googleapis.com/kishwar-helen/models_helen/mp_helen_model_description.csv"
    # wget.download returns the path it actually wrote; if a file with the same name already exists it
    # saves under a suffixed name, so read and delete the returned path instead of a guessed one.
    description_path = wget.download(description_file, output_dir)
    sys.stderr.write("\n")
    sys.stderr.flush()

    with open(description_path) as description:
        models = [line.strip() for line in description]
    os.remove(description_path)

    for model in models:
        # ignore blank lines, e.g. trailing empty lines in the CSV
        if not model:
            continue
        if ',' not in model:
            sys.stderr.write(TextColor.YELLOW + "WARN: SKIPPING MALFORMED MODEL DESCRIPTION LINE: " + model + "\n"
                             + TextColor.END)
            continue
        # split on the first comma only so a URL containing commas stays intact
        model_name, model_url = (field.strip() for field in model.split(',', 1))
        sys.stderr.write("INFO: DOWNLOADING FILE: " + str(model_name) + ".pkl\n")
        sys.stderr.write("INFO: DOWNLOADING LINK: " + str(model_url) + "\n")
        wget.download(model_url, output_dir)
        sys.stderr.write("\n")
        sys.stderr.flush()
def perform_stitch(input_directory, output_path, output_prefix, threads):
    """
    Stitch the predictions of every contig and write them to one FASTA file.

    Contig names are gathered from all prediction files. For each contig (in sorted order) the
    ``(prediction_file, chunk_key, contig_start, contig_end)`` of every chunk is collected and passed to
    ``Stitch.create_consensus_sequence``. Non-empty sequences are written to
    ``<output_path>/<output_prefix>.fa``.

    :param input_directory: Path to the directory containing the prediction HDF5 files.
    :param output_path: Path to the directory the consensus FASTA is written to.
    :param output_prefix: Output file's prefix.
    :param threads: Number of threads passed to the stitcher for each contig.
    :return: None
    :raises ValueError: If a prediction file does not contain a ``predictions`` group.
    """
    all_prediction_files = get_file_paths_from_directory(input_directory)

    # gather the names of all contigs present in any prediction file
    all_contigs = set()
    for prediction_file in sorted(all_prediction_files):
        with h5py.File(prediction_file, 'r') as hdf5_file:
            if 'predictions' not in hdf5_file:
                raise ValueError(TextColor.RED + "ERROR: INVALID HDF5 FILE, FILE DOES NOT CONTAIN predictions KEY.\n"
                                 + TextColor.END)
            all_contigs.update(hdf5_file['predictions'].keys())
    all_contigs = sorted(all_contigs)
    total_contigs = len(all_contigs)

    output_dir = FileManager.handle_output_directory(output_path)
    output_filename = os.path.join(output_dir, output_prefix + '.fa')

    # the context manager flushes and closes the FASTA even if stitching a contig raises
    with open(output_filename, 'w') as consensus_fasta_file:
        sys.stderr.write(TextColor.GREEN + "INFO: OUTPUT FILE: " + output_filename + "\n" + TextColor.END)

        for contig_index, contig in enumerate(all_contigs, start=1):
            # progress is reported against the number of contigs across ALL prediction files
            log_prefix = "{:04d}/{:04d}:".format(contig_index, total_contigs)
            sys.stderr.write(TextColor.GREEN + "INFO: " + str(log_prefix) + " PROCESSING CONTIG: " + contig + "\n"
                             + TextColor.END)

            # collect (file, chunk_key, contig_start, contig_end) for every chunk of this contig
            chunk_name_tuple = []
            for prediction_file in all_prediction_files:
                with h5py.File(prediction_file, 'r') as hdf5_file:
                    predictions = hdf5_file['predictions']
                    if contig not in predictions:
                        continue
                    contig_group = predictions[contig]
                    for chunk_key in sorted(contig_group.keys()):
                        chunk = contig_group[chunk_key]
                        chunk_name_tuple.append((prediction_file, chunk_key,
                                                 chunk['contig_start'][()], chunk['contig_end'][()]))

            # stitch the chunks into one sequence for this contig
            stitch_object = Stitch()
            consensus_sequence = stitch_object.create_consensus_sequence(contig, chunk_name_tuple, threads)

            # the stitcher may return None; do not call len() on it
            sequence_length = len(consensus_sequence) if consensus_sequence is not None else 0
            sys.stderr.write(TextColor.BLUE + "INFO: " + str(log_prefix) + " FINISHED PROCESSING " + contig
                             + ", POLISHED SEQUENCE LENGTH: " + str(sequence_length) + ".\n" + TextColor.END)

            # only write contigs for which a non-empty sequence was produced
            if consensus_sequence is not None and len(consensus_sequence) > 0:
                consensus_fasta_file.write('>' + contig + "\n")
                consensus_fasta_file.write(consensus_sequence + "\n")
def call_consensus(image_dir, model_path, batch_size, num_workers, threads, output_dir, output_prefix, gpu_mode,
                   device_ids, callers):
    """
    Validate the inputs, distribute the image files over callers and run CPU or GPU prediction.

    Invalid arguments are reported on stderr and terminate the process with exit status 1.

    :param image_dir: Path to directory where all MarginPolish images are saved
    :param model_path: Path to a trained model
    :param batch_size: Batch size for minibatch processing
    :param num_workers: Number of workers for minibatch processing
    :param threads: Number of threads for pytorch (divided between callers in CPU mode)
    :param output_dir: Path to the output directory
    :param output_prefix: Prefix of the output HDF5 file
    :param gpu_mode: If true, predict method will use GPU.
    :param device_ids: Comma-separated string of CUDA device ids, or None to use every visible device.
    :param callers: Total number of callers (CPU mode; in GPU mode one caller per device is used).
    :return: None
    """
    # check the model file
    if not os.path.isfile(model_path):
        sys.stderr.write(TextColor.RED + "ERROR: CAN NOT LOCATE MODEL FILE.\n" + TextColor.END)
        sys.exit(1)

    # check the input directory
    if not os.path.isdir(image_dir):
        sys.stderr.write(TextColor.RED + "ERROR: CAN NOT LOCATE IMAGE DIRECTORY.\n" + TextColor.END)
        sys.exit(1)

    # check batch_size
    if batch_size <= 0:
        sys.stderr.write(TextColor.RED + "ERROR: batch_size NEEDS TO BE >0.\n" + TextColor.END)
        sys.exit(1)

    # check num_workers
    if num_workers < 0:
        sys.stderr.write(TextColor.RED + "ERROR: num_workers NEEDS TO BE >=0.\n" + TextColor.END)
        sys.exit(1)

    # check number of threads
    if threads <= 0:
        sys.stderr.write(TextColor.RED + "ERROR: THREAD NEEDS TO BE >0.\n" + TextColor.END)
        sys.exit(1)

    # check number of callers; CPU mode divides threads and files by it (GPU mode derives it from devices)
    if not gpu_mode and (callers is None or callers <= 0):
        sys.stderr.write(TextColor.RED + "ERROR: callers NEEDS TO BE >0.\n" + TextColor.END)
        sys.exit(1)

    output_dir = FileManager.handle_output_directory(output_dir)
    # create a filename for the output file
    output_filename = os.path.join(output_dir, output_prefix)
    # inform the output directory
    sys.stderr.write(TextColor.GREEN + "INFO: " + TextColor.END + "OUTPUT FILE: " + output_filename + "\n")

    if gpu_mode:
        # make sure torch can use CUDA
        if not torch.cuda.is_available():
            sys.stderr.write(TextColor.RED + "ERROR: TORCH IS NOT BUILT WITH CUDA.\n" + TextColor.END)
            sys.stderr.write(TextColor.RED + "SEE TORCH CAPABILITY:\n$ python3\n"
                                             ">>> import torch \n"
                                             ">>> torch.cuda.is_available()\n If true then cuda is available\n"
                             + TextColor.END)
            sys.exit(1)

        # now see which devices to use
        if device_ids is None:
            total_gpu_devices = torch.cuda.device_count()
            sys.stderr.write(TextColor.GREEN + "INFO: TOTAL GPU AVAILABLE: " + str(total_gpu_devices) + "\n"
                             + TextColor.END)
            device_ids = list(range(total_gpu_devices))
        else:
            device_ids = [int(i) for i in device_ids.split(',')]
            for device_id in device_ids:
                major_capable, minor_capable = torch.cuda.get_device_capability(device=device_id)
                if major_capable < 0:
                    sys.stderr.write(TextColor.RED + "ERROR: GPU DEVICE: " + str(device_id)
                                     + " IS NOT CUDA CAPABLE.\n" + TextColor.END)
                    sys.stderr.write(TextColor.GREEN + "Try running: $ python3\n"
                                                       ">>> import torch \n"
                                                       ">>> torch.cuda.get_device_capability(device="
                                     + str(device_id) + ")\n" + TextColor.END)
                else:
                    sys.stderr.write(TextColor.GREEN + "INFO: CAPABILITY OF GPU#" + str(device_id) + ":\t"
                                     + str(major_capable) + "-" + str(minor_capable) + "\n" + TextColor.END)
        # one caller per selected device
        callers = len(device_ids)

        sys.stderr.write(TextColor.GREEN + "INFO: AVAILABLE GPU DEVICES: " + str(device_ids) + "\n" + TextColor.END)
        threads_per_caller = 0
    else:
        # calculate how many threads each caller can use
        # NOTE(review): this is 0 when callers > threads; confirm predict_cpu handles that case.
        threads_per_caller = int(threads / callers)
        device_ids = []

    input_files = get_file_paths_from_directory(image_dir)

    # distribute the files round-robin over the callers, then drop callers that received no files
    file_chunks = [input_files[caller::callers] for caller in range(callers)]
    file_chunks = [chunk for chunk in file_chunks if chunk]
    callers = len(file_chunks)

    if gpu_mode:
        # distributed GPU setup
        predict_gpu(file_chunks, output_filename, model_path, batch_size, callers, device_ids, num_workers)
    else:
        # distributed CPU setup, call the prediction function
        predict_cpu(file_chunks, output_filename, model_path, batch_size, callers, threads_per_caller, num_workers)

    # notify the user that process has completed successfully
    sys.stderr.write(TextColor.GREEN + "INFO: " + TextColor.END + "PREDICTION GENERATED SUCCESSFULLY.\n")