Exemplo n.º 1
0
    def __init__(self, chromosome_name, bam_file_path, draft_file_path,
                 truth_bam, train_mode):
        """
        Set up the file handlers used to process a single chromosome/contig.

        :param chromosome_name: Name of the chromosome (contig) to process
        :param bam_file_path: Path to the reads-to-draft BAM file
        :param draft_file_path: Path to the draft/reference FASTA file
        :param truth_bam: Path to the truth-sequence-to-draft BAM file; it is
                          only opened when ``train_mode`` is truthy
        :param train_mode: When truthy, a handler for ``truth_bam`` is also
                           created; otherwise ``truth_bam_handler`` is None
        """
        # remember the raw paths next to the handlers built from them
        self.bam_path = bam_file_path
        self.fasta_path = draft_file_path

        # handlers for the read alignments and the draft sequence
        self.bam_handler = PEPPER.BAM_handler(bam_file_path)
        self.fasta_handler = PEPPER.FASTA_handler(draft_file_path)

        self.train_mode = train_mode
        self.downsample_rate = 1.0
        # the truth alignment is only opened in train mode
        self.truth_bam_handler = (PEPPER.BAM_handler(truth_bam)
                                  if self.train_mode else None)

        # the chromosome/contig this object works on
        self.chromosome_name = chromosome_name
Exemplo n.º 2
0
def make_train_images(bam_filepath, truth_bam_filepath, fasta_filepath, region,
                      region_bed, output_dir, threads):
    """
    GENERATE TRAINING IMAGES (train_mode=True, USING THE TRUTH BAM). THIS IS USED BY pepper.py

    All inputs are validated before anything is written to disk. Any invalid
    input writes a timestamped ERROR line to stderr and terminates the
    process with exit status 1.

    :param bam_filepath: Path to the input bam file.
    :param truth_bam_filepath: Path to the bam where truth is aligned to the assembly.
    :param fasta_filepath: Path to the input fasta file.
    :param region: Specific region of interest (forwarded to get_chromosome_list).
    :param region_bed: Optional BED file of regions (forwarded to get_chromosome_list).
    :param output_dir: Path to the output directory.
    :param threads: Number of threads to use; must be > 0.
    :return: None
    """

    def _fail(message):
        # Report a fatal input error with a timestamp and stop with status 1.
        sys.stderr.write("[" + datetime.now().strftime('%m-%d-%Y %H:%M:%S') +
                         "] ERROR: " + message + "\n")
        sys.exit(1)

    # check the bam file
    if not os.path.isfile(bam_filepath) or not PEPPER.BAM_handler(
            bam_filepath):
        _fail("CAN NOT LOCATE BAM FILE.")

    # check the truth bam file
    if not os.path.isfile(truth_bam_filepath) or not PEPPER.BAM_handler(
            truth_bam_filepath):
        _fail("CAN NOT LOCATE TRUTH BAM FILE.")

    # check the fasta file
    if not os.path.isfile(fasta_filepath):
        _fail("CAN NOT LOCATE FASTA FILE.")

    # check number of threads BEFORE creating the output directory, so an
    # invalid invocation leaves nothing behind on disk
    if threads <= 0:
        _fail("THREAD NEEDS TO BE >0.")

    # check / create the output directory
    output_dir = UserInterfaceSupport.handle_output_directory(
        os.path.abspath(output_dir))

    # get the list of contigs
    contig_list = UserInterfaceSupport.get_chromosome_list(
        region, fasta_filepath, bam_filepath, region_bed)

    # call the parallelization method to generate images in parallel
    UserInterfaceSupport.chromosome_level_parallelization(
        contig_list,
        bam_filepath,
        fasta_filepath,
        truth_bam=truth_bam_filepath,
        output_path=output_dir,
        total_threads=threads,
        train_mode=True)
Exemplo n.º 3
0
def polish(bam_filepath, fasta_filepath, output_path, threads, region,
           model_path, batch_size, gpu_mode, device_ids, num_workers):
    """
    Run all the sub-modules to polish an input assembly.

    Steps (all outputs go under the directory returned for ``output_path``):
      1. generate images (``make_images``) into ``images_<RUN-ID>/``,
      2. run inference (``call_consensus``) into ``predictions_<RUN-ID>/``,
      3. stitch the predictions (``perform_stitch``) into the output directory.

    Inputs are validated first; any invalid input writes a timestamped ERROR
    line to stderr and terminates the process with exit status 1.

    :param bam_filepath: Path to the input BAM file.
    :param fasta_filepath: Path to the input (draft assembly) FASTA file.
    :param output_path: Output directory for this run.
    :param threads: Number of threads to use; must be > 0.
    :param region: Region of interest, forwarded to ``make_images``.
    :param model_path: Path to the model file.
    :param batch_size: Batch size passed to ``call_consensus``; must be > 0.
    :param gpu_mode: If truthy, require CUDA to be available.
    :param device_ids: Optional comma-separated GPU device ids, e.g. "0,1".
    :param num_workers: Worker count passed to ``call_consensus``; must be >= 0.
    :return: None
    """

    def _log(message):
        # Prefix every status line with the current wall-clock time.
        sys.stderr.write("[" + datetime.now().strftime('%m-%d-%Y %H:%M:%S') +
                         "] " + message)

    def _fail(message):
        # Report a fatal input error and stop with a non-zero exit status.
        _log("ERROR: " + message + "\n")
        sys.exit(1)

    # --- validate inputs ---
    if not os.path.isfile(bam_filepath) or not PEPPER.BAM_handler(
            bam_filepath):
        _fail("CAN NOT LOCATE BAM FILE.")

    if not os.path.isfile(fasta_filepath):
        _fail("CAN NOT LOCATE FASTA FILE.")

    if not os.path.isfile(model_path):
        _fail("CAN NOT LOCATE MODEL FILE.")

    if threads <= 0:
        _fail("THREAD NEEDS TO BE >0.")

    if batch_size <= 0:
        _fail("batch_size NEEDS TO BE >0.")

    if num_workers < 0:
        _fail("num_workers NEEDS TO BE >=0.")

    # --- check if gpu inference can be done ---
    if gpu_mode and not torch.cuda.is_available():
        _log("ERROR: TORCH IS NOT BUILT WITH CUDA.\n")
        sys.stderr.write(
            "SEE TORCH CAPABILITY:\n$ python3\n"
            ">>> import torch \n"
            ">>> torch.cuda.is_available()\n If true then cuda is available\n"
        )
        sys.exit(1)

    # --- check if all requested devices are available ---
    if device_ids is not None:
        try:
            device_ids = [int(i) for i in device_ids.split(',')]
        except ValueError:
            _fail("INVALID GPU DEVICE ID LIST: " + str(device_ids) + ".")

        for device_id in device_ids:
            try:
                major_capable, minor_capable = \
                    torch.cuda.get_device_capability(device=device_id)
            except (AssertionError, RuntimeError):
                # get_device_capability raises (instead of returning a
                # negative capability) for an invalid device id or when
                # CUDA is unavailable; route that to the error path below.
                major_capable, minor_capable = -1, -1

            if major_capable < 0:
                _log("ERROR: GPU DEVICE: " + str(device_id) +
                     " IS NOT CUDA CAPABLE.\n")
                sys.stderr.write(
                    "Try running: $ python3\n"
                    ">>> import torch \n"
                    ">>> torch.cuda.get_device_capability(device=" +
                    str(device_id) + ")\n")
                sys.exit(1)

            _log("INFO: CAPABILITY OF GPU#" + str(device_id) + ":\t" +
                 str(major_capable) + "-" + str(minor_capable) + "\n")

    timestr = time.strftime("%m%d%Y_%H%M%S")

    # --- run directories ---
    output_dir = UserInterfaceSupport.handle_output_directory(output_path)

    # NOTE(review): assumes handle_output_directory returns a path ending in
    # a separator, since the sub-directories are built by concatenation.
    image_output_directory = output_dir + "images_" + timestr + "/"
    prediction_output_directory = output_dir + "predictions_" + timestr + "/"

    _log("INFO: RUN-ID: " + timestr + "\n")
    _log("INFO: IMAGE OUTPUT: " + str(image_output_directory) + "\n")

    # --- step 1: images ---
    _log("STEP 1: GENERATING IMAGES\n")
    sys.stderr.flush()
    make_images(bam_filepath, fasta_filepath, region, image_output_directory,
                threads)

    # --- step 2: inference ---
    _log("STEP 2: RUNNING INFERENCE\n")
    _log("INFO: PREDICTION OUTPUT: " + str(prediction_output_directory) + "\n")
    sys.stderr.flush()
    call_consensus(image_output_directory, model_path, batch_size, num_workers,
                   prediction_output_directory, device_ids, gpu_mode, threads)

    # --- step 3: stitch ---
    _log("STEP 3: RUNNING STITCH\n")
    _log("INFO: STITCH OUTPUT: " + str(output_dir) + "\n")
    sys.stderr.flush()
    perform_stitch(prediction_output_directory, output_dir, threads)
Exemplo n.º 4
0
    def get_chromosome_list(chromosome_names, ref_file, bam_file, region_bed):
        """
        PARSES THROUGH THE CHROMOSOME PARAMETER TO FIND OUT WHICH REGIONS TO PROCESS

        Resolution order:
          1. Neither ``chromosome_names`` nor ``region_bed`` given: every contig
             present in BOTH the FASTA and the BAM, naturally sorted, each with
             region ``None``.
          2. ``region_bed`` given: one ``(contig, [start, end])`` entry per BED
             line (coordinates sorted). Blank lines and ``#`` / ``track`` /
             ``browser`` header lines are skipped.
          3. Otherwise ``chromosome_names`` is a comma-separated list whose
             items are ``name``, ``name:start-end``, or a numeric contig range
             such as ``chr1-chr3`` (expanded to chr1, chr2, chr3).

        Invalid input writes an ERROR to stderr and exits with status 1.

        :param chromosome_names: NAME OF CHROMOSOME(S) / REGION STRING
        :param ref_file: PATH TO THE REFERENCE FILE
        :param bam_file: PATH TO BAM FILE
        :param region_bed: PATH TO A BED FILE OF REGIONS (OPTIONAL)
        :return: LIST OF (CHROMOSOME_NAME, [START, END] OR None) TUPLES
        """
        if not chromosome_names and not region_bed:
            fasta_handler = PEPPER.FASTA_handler(ref_file)
            bam_handler = PEPPER.BAM_handler(bam_file)
            bam_contigs = bam_handler.get_chromosome_sequence_names()
            fasta_contigs = fasta_handler.get_chromosome_names()
            common_contigs = list(set(fasta_contigs) & set(bam_contigs))

            if not common_contigs:
                sys.stderr.write(
                    "[" + datetime.now().strftime('%m-%d-%Y %H:%M:%S') + "] " +
                    "ERROR: NO COMMON CONTIGS FOUND BETWEEN THE BAM FILE AND THE FASTA FILE.\n"
                )
                sys.stderr.flush()
                sys.exit(1)

            common_contigs = sorted(common_contigs,
                                    key=UserInterfaceSupport.natural_key)
            sys.stderr.write("[" +
                             datetime.now().strftime('%m-%d-%Y %H:%M:%S') +
                             "] INFO: COMMON CONTIGS FOUND: " +
                             str(common_contigs) + "\n")
            sys.stderr.flush()

            return [(contig_name, None) for contig_name in common_contigs]

        if region_bed:
            chromosome_name_list = []
            with open(region_bed) as fp:
                for line in fp:
                    fields = line.split()
                    # skip blank lines and BED comment/header lines
                    if not fields or fields[0].startswith('#') or \
                            fields[0] in ('track', 'browser'):
                        continue
                    chr_name = fields[0]
                    region = sorted([int(fields[1]), int(fields[2])])
                    chromosome_name_list.append((chr_name, region))
            return chromosome_name_list

        def _invalid_region():
            # A malformed --region is an error: exit non-zero (was 0).
            sys.stderr.write("ERROR: --region INVALID value.\n")
            sys.exit(1)

        split_names = [name.strip()
                       for name in chromosome_names.strip().split(',')]

        chromosome_name_list = []
        for name in split_names:
            # split on region
            region = None
            if ':' in name:
                name_region = name.strip().split(':')

                if len(name_region) != 2:
                    _invalid_region()

                name, region_text = name_region
                try:
                    region = [int(pos)
                              for pos in region_text.strip().split('-')]
                except ValueError:
                    # non-numeric coordinates, e.g. "chr1:a-b"
                    _invalid_region()

                if len(region) != 2 or region[0] > region[1]:
                    _invalid_region()

            range_split = name.split('-')
            # Treat "a-b" as a numeric contig range only when every part has
            # digits; otherwise the dash is simply part of the contig name.
            if len(range_split) > 1 and all(
                    any(ch.isdigit() for ch in part) for part in range_split):
                # prefix = everything before the first digit, e.g. "chr"
                chr_prefix = ''
                for ch in name:
                    if ch.isdigit():
                        break
                    chr_prefix += ch

                int_ranges = sorted(
                    int(''.join(ch for ch in part if ch.isdigit()))
                    for part in range_split)

                for chr_seq in range(int_ranges[0], int_ranges[-1] + 1):
                    chromosome_name_list.append(
                        (chr_prefix + str(chr_seq), region))
            else:
                chromosome_name_list.append((name, region))

        return chromosome_name_list