Example #1
def polish_genome(image_dir, model_path, batch_size, num_workers, threads, output_dir, output_prefix, gpu_mode,
                  device_ids, callers):
    """
    This method provides an interface to call the predict method that generates the prediction HDF5 file.
    :param image_dir: Path to directory where all MarginPolish images are saved.
    :param model_path: Path to a trained model.
    :param batch_size: Batch size for minibatch processing.
    :param num_workers: Number of workers for minibatch processing.
    :param threads: Number of threads for pytorch.
    :param output_dir: Path to the output directory.
    :param output_prefix: Prefix of the output HDF5 file.
    :param gpu_mode: If true, predict method will use GPU.
    :param device_ids: List of GPU devices.
    :param callers: Total number of callers to use.
    :return:
    """
    output_dir = FileManager.handle_output_directory(output_dir)
    timestr = time.strftime("%m%d%Y_%H%M%S")

    prediction_output_directory = output_dir + "/predictions_" + str(timestr) + "/"
    prediction_output_directory = FileManager.handle_output_directory(prediction_output_directory)

    sys.stderr.write(TextColor.GREEN + "INFO: RUN-ID: " + str(timestr) + "\n" + TextColor.END)
    sys.stderr.write(TextColor.GREEN + "INFO: PREDICTION OUTPUT DIRECTORY: "
                     + str(prediction_output_directory) + "\n" + TextColor.END)

    call_consensus_start_time = time.time()
    sys.stderr.write(TextColor.GREEN + "INFO: CALL CONSENSUS STARTING\n" + TextColor.END)
    call_consensus(image_dir,
                   model_path,
                   batch_size,
                   num_workers,
                   threads,
                   prediction_output_directory,
                   output_prefix,
                   gpu_mode,
                   device_ids,
                   callers)
    call_consensus_end_time = time.time()

    stitch_start_time = time.time()
    sys.stderr.write(TextColor.GREEN + "INFO: STITCH STARTING\n" + TextColor.END)
    sys.stderr.write(TextColor.GREEN + "INFO: STITCH INPUT DIRECTORY: "
                     + str(prediction_output_directory) + "\n" + TextColor.END)
    perform_stitch(prediction_output_directory,
                   output_dir,
                   output_prefix,
                   threads)
    stitch_end_time = time.time()

    call_consensus_time = get_elapsed_time_string(call_consensus_start_time, call_consensus_end_time)
    stitch_time = get_elapsed_time_string(stitch_start_time, stitch_end_time)
    overall_time = get_elapsed_time_string(call_consensus_start_time, stitch_end_time)

    sys.stderr.write(TextColor.GREEN + "INFO: FINISHED PROCESSING.\n" + TextColor.END)
    sys.stderr.write(TextColor.GREEN + "INFO: TOTAL TIME ELAPSED: " + str(overall_time) + "\n" + TextColor.END)
    sys.stderr.write(TextColor.GREEN + "INFO: PREDICTION TIME: " + str(call_consensus_time) + "\n" + TextColor.END)
    sys.stderr.write(TextColor.GREEN + "INFO: STITCH TIME: " + str(stitch_time) + "\n" + TextColor.END)
Example #2
def test_interface(test_file, batch_size, gpu_mode, num_workers, model_path,
                   output_directory, print_details):
    """
    Test a trained model
    :param test_file: Path to directory containing test images
    :param batch_size: Batch size for testing
    :param gpu_mode: If true, the model will be tested on GPU
    :param num_workers: Number of workers for data loading
    :param model_path: Path to a saved model
    :param output_directory: Path to output_directory
    :param print_details: If true then logs will be printed
    :return:
    """
    sys.stderr.write(TextColor.PURPLE + 'Loading data\n' + TextColor.END)

    output_directory = FileManager.handle_output_directory(output_directory)

    if not os.path.isfile(model_path):
        sys.stderr.write(TextColor.RED + "ERROR: INVALID PATH TO MODEL\n" + TextColor.END)
        exit(1)

    sys.stderr.write(TextColor.GREEN + "INFO: MODEL LOADING\n" + TextColor.END)

    transducer_model, hidden_size, gru_layers, prev_ite = \
        ModelHandler.load_simple_model(model_path,
                                       input_channels=ImageSizeOptions.IMAGE_CHANNELS,
                                       image_features=ImageSizeOptions.IMAGE_HEIGHT,
                                       seq_len=ImageSizeOptions.SEQ_LENGTH,
                                       num_base_classes=ImageSizeOptions.TOTAL_BASE_LABELS,
                                       num_rle_classes=ImageSizeOptions.TOTAL_RLE_LABELS)

    sys.stderr.write(TextColor.GREEN + "INFO: MODEL LOADED\n" + TextColor.END)

    if print_details and gpu_mode:
        sys.stderr.write(TextColor.GREEN +
                         "INFO: GPU MODE NOT AVAILABLE WHEN PRINTING DETAILS. "
                         "SETTING GPU MODE TO FALSE.\n" + TextColor.END)
        gpu_mode = False

    if gpu_mode:
        transducer_model = transducer_model.cuda()

    stats_dictionary = test(
        test_file,
        batch_size,
        gpu_mode,
        transducer_model,
        num_workers,
        gru_layers,
        hidden_size,
        num_base_classes=ImageSizeOptions.TOTAL_BASE_LABELS,
        num_rle_classes=ImageSizeOptions.TOTAL_RLE_LABELS,
        output_directory=output_directory,
        print_details=print_details)

    save_rle_confusion_matrix(stats_dictionary, output_directory)
    save_base_confusion_matrix(stats_dictionary, output_directory)

    sys.stderr.write(TextColor.PURPLE + 'DONE\n' + TextColor.END)
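
A usage sketch for test_interface; every path and value below is a placeholder, not taken from the original:

# Hypothetical invocation; substitute your own test images and model.
test_interface(test_file="/data/test_images/",
               batch_size=128,
               gpu_mode=False,
               num_workers=4,
               model_path="/models/helen_model.pkl",
               output_directory="/output/test_run/",
               print_details=False)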
Example #3
def download_models(output_dir):
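    """
    Download the model description CSV and every model it lists into output_dir.
    :param output_dir: Path to the directory where the models will be saved.
    :return:
    """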
    output_dir = FileManager.handle_output_directory(output_dir)
    sys.stderr.write(TextColor.YELLOW + "DOWNLOADING MODEL DESCRIPTION FILE" +
                     TextColor.END + "\n")
    description_file = "https://storage.googleapis.com/kishwar-helen/models_helen/mp_helen_model_description.csv"
    wget.download(description_file, output_dir)
    sys.stderr.write("\n")
    sys.stderr.flush()

    with open(output_dir + '/mp_helen_model_description.csv') as f:
        models = [line.rstrip() for line in f]

    os.remove(output_dir + '/mp_helen_model_description.csv')

    for model in models:
        model_name, model_url = model.split(',')
        sys.stderr.write("INFO: DOWNLOADING FILE: " + str(model_name) +
                         ".pkl\n")
        sys.stderr.write("INFO: DOWNLOADING LINK: " + str(model_url) + "\n")
        wget.download(model_url, output_dir)
        sys.stderr.write("\n")
        sys.stderr.flush()
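
The parsing loop above assumes each row of the description CSV contains exactly one comma. A sketch of an alternative using the csv module, which also tolerates quoted fields (read_model_description is a hypothetical name, not part of the original):

import csv

def read_model_description(description_path):
    # Parse rows of the form: model_name,model_url
    with open(description_path, newline='') as f:
        return [(row[0], row[1]) for row in csv.reader(f) if row]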
Example #4
def perform_stitch(input_directory, output_path, output_prefix, threads):
    """
    This method gathers all contigs and calls the stitch module for each contig.
    :param input_directory: Path to the directory containing input files.
    :param output_path: Path to the directory where the consensus sequence is written
    :param output_prefix: Output file's prefix
    :param threads: Number of threads to use
    :return:
    """
    # get all the files
    all_prediction_files = get_file_paths_from_directory(input_directory)

    # we gather all the contigs
    all_contigs = set()

    # get contigs from all of the files
    for prediction_file in sorted(all_prediction_files):
        with h5py.File(prediction_file, 'r') as hdf5_file:
            if 'predictions' in hdf5_file:
                contigs = list(hdf5_file['predictions'].keys())
                all_contigs.update(contigs)
            else:
                raise ValueError(
                    TextColor.RED +
                    "ERROR: INVALID HDF5 FILE, FILE DOES NOT CONTAIN predictions KEY.\n"
                    + TextColor.END)
    # convert set to a list
    all_contigs = list(all_contigs)

    # get output directory
    output_dir = FileManager.handle_output_directory(output_path)

    # open an output fasta file
    # we should really use a fasta handler for this, I don't like this.
    output_filename = os.path.join(output_dir, output_prefix + '.fa')
    consensus_fasta_file = open(output_filename, 'w')
    sys.stderr.write(TextColor.GREEN + "INFO: OUTPUT FILE: " +
                     output_filename + "\n" + TextColor.END)

    # for each contig
    for i, contig in enumerate(sorted(all_contigs)):
        log_prefix = "{:04d}/{:04d}:".format(i, len(all_contigs))
        sys.stderr.write(TextColor.GREEN + "INFO: " + str(log_prefix) +
                         " PROCESSING CONTIG: " + contig + "\n" +
                         TextColor.END)

        # get all the chunk keys
        chunk_name_tuple = list()
        for prediction_file in all_prediction_files:
            with h5py.File(prediction_file, 'r') as hdf5_file:
                # check if the contig is contained in this file
                if contig not in hdf5_file['predictions']:
                    continue

                # if contained then get the chunks
                chunk_keys = sorted(hdf5_file['predictions'][contig].keys())
                for chunk_key in chunk_keys:
                    chunk = hdf5_file['predictions'][contig][chunk_key]
                    chunk_contig_start = chunk['contig_start'][()]
                    chunk_contig_end = chunk['contig_end'][()]
                    chunk_name_tuple.append(
                        (prediction_file, chunk_key, chunk_contig_start,
                         chunk_contig_end))

        # call stitch to generate a sequence for this contig
        stitch_object = Stitch()
        consensus_sequence = stitch_object.create_consensus_sequence(
            contig, chunk_name_tuple, threads)
        sys.stderr.write(TextColor.BLUE + "INFO: " + str(log_prefix) +
                         " FINISHED PROCESSING " + contig +
                         ", POLISHED SEQUENCE LENGTH: " +
                         str(len(consensus_sequence)) + ".\n" + TextColor.END)

        # if there's a sequence, write it to the output file
        if consensus_sequence is not None and len(consensus_sequence) > 0:
            consensus_fasta_file.write('>' + contig + "\n")
            consensus_fasta_file.write(consensus_sequence + "\n")

    consensus_fasta_file.close()
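
get_file_paths_from_directory is called by Examples #4 and #5 but not shown. A minimal sketch, assuming it returns the sorted paths of all regular files in a directory (the real helper may additionally filter by extension, e.g. only HDF5 prediction files):

import os

def get_file_paths_from_directory(directory_path):
    # Hypothetical helper: list regular files in directory_path.
    return [os.path.join(directory_path, name)
            for name in sorted(os.listdir(directory_path))
            if os.path.isfile(os.path.join(directory_path, name))]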
Example #5
def call_consensus(image_dir, model_path, batch_size, num_workers, threads,
                   output_dir, output_prefix, gpu_mode, device_ids, callers):
    """
    This method provides an interface to call the predict method that generates the prediction HDF5 file.
    :param image_dir: Path to directory where all MarginPolish images are saved
    :param model_path: Path to a trained model
    :param batch_size: Batch size for minibatch processing
    :param num_workers: Number of workers for minibatch processing
    :param threads: Number of threads for pytorch
    :param output_dir: Path to the output directory
    :param output_prefix: Prefix of the output HDF5 file
    :param gpu_mode: If true, predict method will use GPU.
    :param device_ids: List of CUDA devices to use.
    :param callers: Total number of callers.
    :return:
    """
    # check the model file
    if not os.path.isfile(model_path):
        sys.stderr.write(TextColor.RED +
                         "ERROR: CAN NOT LOCATE MODEL FILE.\n" + TextColor.END)
        exit(1)

    # check the input directory
    if not os.path.isdir(image_dir):
        sys.stderr.write(TextColor.RED +
                         "ERROR: CAN NOT LOCATE IMAGE DIRECTORY.\n" +
                         TextColor.END)
        exit(1)

    # check batch_size
    if batch_size <= 0:
        sys.stderr.write(TextColor.RED +
                         "ERROR: batch_size NEEDS TO BE >0.\n" + TextColor.END)
        exit(1)

    # check num_workers
    if num_workers < 0:
        sys.stderr.write(TextColor.RED +
                         "ERROR: num_workers NEEDS TO BE >=0.\n" +
                         TextColor.END)
        exit(1)

    # check number of threads
    if threads <= 0:
        sys.stderr.write(TextColor.RED + "ERROR: threads NEEDS TO BE >0.\n" +
                         TextColor.END)
        exit(1)

    output_dir = FileManager.handle_output_directory(output_dir)

    # create a filename for the output file
    output_filename = os.path.join(output_dir, output_prefix)

    # report the output file path
    sys.stderr.write(TextColor.GREEN + "INFO: " + TextColor.END +
                     "OUTPUT FILE: " + output_filename + "\n")

    if gpu_mode:
        # make sure that CUDA is available
        if not torch.cuda.is_available():
            sys.stderr.write(TextColor.RED +
                             "ERROR: TORCH IS NOT BUILT WITH CUDA.\n" +
                             TextColor.END)
            sys.stderr.write(
                TextColor.RED + "SEE TORCH CAPABILITY:\n$ python3\n"
                ">>> import torch \n"
                ">>> torch.cuda.is_available()\n"
                "If True, then CUDA is available.\n" + TextColor.END)
            exit(1)

        # Now see which devices to use
        if device_ids is None:
            total_gpu_devices = torch.cuda.device_count()
            sys.stderr.write(TextColor.GREEN + "INFO: TOTAL GPU AVAILABLE: " +
                             str(total_gpu_devices) + "\n" + TextColor.END)
            device_ids = list(range(total_gpu_devices))
            callers = total_gpu_devices
        else:
            device_ids = [int(i) for i in device_ids.split(',')]
            for device_id in device_ids:
                major_capable, minor_capable = torch.cuda.get_device_capability(
                    device=device_id)
                if major_capable <= 0:
                    sys.stderr.write(TextColor.RED + "ERROR: GPU DEVICE: " +
                                     str(device_id) +
                                     " IS NOT CUDA CAPABLE.\n" + TextColor.END)
                    sys.stderr.write(
                        TextColor.GREEN + "Try running: $ python3\n"
                        ">>> import torch \n"
                        ">>> torch.cuda.get_device_capability(device=" +
                        str(device_id) + ")\n" + TextColor.END)
                else:
                    sys.stderr.write(TextColor.GREEN +
                                     "INFO: CAPABILITY OF GPU#" +
                                     str(device_id) + ":\t" +
                                     str(major_capable) + "-" +
                                     str(minor_capable) + "\n" + TextColor.END)
            callers = len(device_ids)

        sys.stderr.write(TextColor.GREEN + "INFO: AVAILABLE GPU DEVICES: " +
                         str(device_ids) + "\n" + TextColor.END)
        threads_per_caller = 0
    else:
        # calculate how many threads each caller can use, at least one each
        threads_per_caller = max(1, threads // callers)
        device_ids = []

    # chunk the inputs
    input_files = get_file_paths_from_directory(image_dir)

    # generate file chunks to process in parallel
    file_chunks = [[] for _ in range(callers)]
    for i, input_file in enumerate(input_files):
        file_chunks[i % callers].append(input_file)

    # drop empty chunks
    file_chunks = [chunk for chunk in file_chunks if chunk]

    callers = len(file_chunks)

    if gpu_mode:
        # Distributed GPU setup
        predict_gpu(file_chunks, output_filename, model_path, batch_size,
                    callers, device_ids, num_workers)
    else:
        # distributed CPU setup, call the prediction function
        predict_cpu(file_chunks, output_filename, model_path, batch_size,
                    callers, threads_per_caller, num_workers)

    # notify the user that process has completed successfully
    sys.stderr.write(TextColor.GREEN + "INFO: " + TextColor.END +
                     "PREDICTION GENERATED SUCCESSFULLY.\n")