コード例 #1
0
ファイル: CheckEnvs.py プロジェクト: HKU-BAL/Clair3
def split_extend_vcf(vcf_fn, output_fn):
    """
    Split a VCF file into per-contig region files, extending each variant by a flanking window.

    Each non-header record at 1-based position `pos` becomes the 0-based interval
    `[pos - 1 - param.no_of_positions, pos + param.no_of_positions]`. Records whose extended
    start would be negative are skipped. For every contig, the regions are written one per line
    as `contig start end` to a file named after the contig inside `output_fn`.

    Args:
        vcf_fn: path to the input VCF. It is read through `gzip -fdc`; the `-f` flag lets gzip
            pass uncompressed input through unchanged.
        output_fn: directory in which the per-contig region files are written.

    Returns:
        The set of contig names for which at least one region was written.

    Exits (sys.exit) when a record has a position smaller than 1.
    """
    expand_region_size = param.no_of_positions
    output_ctg_dict = defaultdict(list)
    unzip_process = subprocess_popen(shlex.split("gzip -fdc %s" % (vcf_fn)))

    for row_id, row in enumerate(unzip_process.stdout):
        # Skip blank lines (they would otherwise raise IndexError below) and header lines.
        if not row.strip() or row[0] == '#':
            continue
        columns = row.strip().split(maxsplit=3)
        ctg_name = columns[0]

        center_pos = int(columns[1])
        # 1-based VCF position -> 0-based half-open interval [pos - 1, pos)
        ctg_start, ctg_end = center_pos - 1, center_pos
        if ctg_start < 0:
            sys.exit(
                log_error("[ERROR] Invalid VCF input in {}-th row {} {}".format(row_id + 1, ctg_name, center_pos)))
        # Variants too close to the contig start cannot carry a full flanking window.
        if ctg_start - expand_region_size < 0:
            continue
        expand_ctg_start = ctg_start - expand_region_size
        expand_ctg_end = ctg_end + expand_region_size

        output_ctg_dict[ctg_name].append(
            ' '.join([ctg_name, str(expand_ctg_start), str(expand_ctg_end)]))

    for key, value in output_ctg_dict.items():
        ctg_output_fn = os.path.join(output_fn, key)
        with open(ctg_output_fn, 'w') as output_file:
            output_file.write('\n'.join(value))

    unzip_process.stdout.close()
    unzip_process.wait()

    know_vcf_contig_set = set(output_ctg_dict)

    return know_vcf_contig_set
コード例 #2
0
ファイル: CheckEnvs.py プロジェクト: HKU-BAL/Clair3
def check_tools_version(tool_version, required_tool_version):
    """
    Check that every detected tool satisfies its minimum required version.

    Args:
        tool_version: mapping tool name -> detected version, or None when the tool was not found.
        required_tool_version: mapping tool name -> minimum acceptable version. It must contain
            every key of `tool_version`; a missing key raises KeyError.

    A missing or too-old tool prints an error and then calls `check_python_path()`.
    `whatshap` is skipped on Darwin, where it cannot be installed on arm64.
    """

    def _padded(cells):
        # Fixed-width (10 chars) columns for the version table.
        return ' '.join(str(cell).ljust(10) for cell in cells)

    for name, found in tool_version.items():
        minimum = required_tool_version[name]
        # whatshap cannot be installed in Mac arm64 system
        if name == 'whatshap' and platform.system() == "Darwin":
            continue
        if found is None:
            print(log_error("[ERROR] {} not found, please check you are in clair3 virtual environment".format(name)))
            check_python_path()
            continue
        if not found < minimum:
            continue
        print(log_error("[ERROR] Tool version not match, please check you are in clair3 virtual environment"))
        print(_padded(["Tool", "Version", "Required"]))
        print(_padded([name, found, '>=' + str(minimum)]))
        check_python_path()
コード例 #3
0
ファイル: CheckEnvs.py プロジェクト: HKU-BAL/Clair3
def split_extend_bed(bed_fn, output_fn, contig_set=None):
    """
    Split a BED file into per-contig region files, extending each region by a flanking window.

    Each BED interval `[start, end)` becomes
    `[max(0, start - param.no_of_positions), end + param.no_of_positions]`. For every contig,
    the extended regions are written one per line as `contig start end` to a file named after
    the contig inside `output_fn`.

    Args:
        bed_fn: path to the input BED. It is read through `gzip -fdc`; the `-f` flag lets gzip
            pass uncompressed input through unchanged.
        output_fn: directory in which the per-contig region files are written.
        contig_set: optional collection of contig names to keep. A falsy value (None or empty)
            keeps every contig.

    Exits (sys.exit) when a row has a negative coordinate or end < start.
    """
    expand_region_size = param.no_of_positions
    output_ctg_dict = defaultdict(list)
    unzip_process = subprocess_popen(shlex.split("gzip -fdc %s" % (bed_fn)))
    for row_id, row in enumerate(unzip_process.stdout):
        # Skip blank lines (e.g. a trailing empty line would otherwise raise IndexError) and comments.
        if not row.strip() or row[0] == '#':
            continue
        columns = row.strip().split()
        ctg_name = columns[0]
        if contig_set and ctg_name not in contig_set:
            continue

        ctg_start, ctg_end = int(columns[1]), int(columns[2])

        if ctg_end < ctg_start or ctg_start < 0 or ctg_end < 0:
            sys.exit(log_error(
                "[ERROR] Invalid BED input in {}-th row {} {} {}".format(row_id + 1, ctg_name, ctg_start, ctg_end)))
        expand_ctg_start = max(0, ctg_start - expand_region_size)
        # ctg_end >= 0 was validated above, so no clamping is needed on the right side.
        expand_ctg_end = ctg_end + expand_region_size
        output_ctg_dict[ctg_name].append(
            ' '.join([ctg_name, str(expand_ctg_start), str(expand_ctg_end)]))

    for key, value in output_ctg_dict.items():
        ctg_output_fn = os.path.join(output_fn, key)
        with open(ctg_output_fn, 'w') as output_file:
            output_file.write('\n'.join(value))

    unzip_process.stdout.close()
    unzip_process.wait()
コード例 #4
0
ファイル: CreateTensorPileup.py プロジェクト: HKU-BAL/Clair3
def CreateTensorPileup(args):
    """
    Create pileup tensors for pileup-model training or calling.

    The candidate region is scanned once with `samtools mpileup`, using a sliding window of
    `sliding_window_size` pileup columns. A position becomes a candidate in one of two ways:
      * no known-variants VCF is given, and the position passes all of these checks: the
        confident BED, a reference base in ACGT, the allele-frequency threshold, and
        `args.minCoverage`;
      * a known-variants VCF is given, and the position is one of its variants.
    Once `flanking_base_num` further positions have been read, the candidate's window is emitted,
    provided every slot in the window is filled.

    Each emitted line is tab-separated and contains:
      1. contig;
      2. candidate position;
      3. reference window;
      4. the flattened tensor values, space-separated;
      5. `<depth>-<alt> <count> ...`.
    Lines are piped into `param.zstd -c` writing `args.tensor_can_fn`, or written through
    `TensorStdout(sys.stdout)` when that value is "PIPE".

    Region selection when `args.chunk_id` (1-based) is given:
      * without a confident BED: the contig (length read from the reference `.fai`) is cut into
        `args.chunk_num` chunks;
      * with a confident BED: the span returned by `bed_tree_from` for `args.extend_bed` on this
        contig is cut into chunks;
      * with a known-variants VCF: the contig's known variants are cut into chunks. The chunk's
        min/max positions become the region, and the function returns early if the chunk is empty.

    When `args.gvcf` is set, per-site total/reference counts are also streamed into a
    `variantInfoCalculator` writing under `args.temp_file_dir`.

    Args:
        args: parsed command-line namespace (the attributes read at the top of this function).
    """
    ctg_start = args.ctgStart
    ctg_end = args.ctgEnd
    fasta_file_path = args.ref_fn
    ctg_name = args.ctgName
    samtools_execute_command = args.samtools
    bam_file_path = args.bam_fn
    chunk_id = args.chunk_id - 1 if args.chunk_id else None  # 1-base to 0-base
    chunk_num = args.chunk_num
    tensor_can_output_path = args.tensor_can_fn
    minimum_af_for_candidate = args.min_af
    minimum_snp_af_for_candidate = args.snp_min_af
    minimum_indel_af_for_candidate = args.indel_min_af
    min_coverage = args.minCoverage
    platform = args.platform
    confident_bed_fn = args.bed_fn
    is_confident_bed_file_given = confident_bed_fn is not None
    alt_fn = args.indel_fn
    extend_bed = args.extend_bed
    is_extend_bed_file_given = extend_bed is not None
    min_mapping_quality = args.minMQ
    min_base_quality = args.minBQ
    fast_mode = args.fast_mode
    vcf_fn = args.vcf_fn
    is_known_vcf_file_provided = vcf_fn is not None
    call_snp_only = args.call_snp_only

    # Module-level `test_pos` is reset here; it is not read anywhere else in this function.
    global test_pos
    test_pos = None

    # 1-based regions [start, end] (start and end inclusive)
    ref_regions = []
    reads_regions = []
    known_variants_set = set()
    tree, bed_start, bed_end = bed_tree_from(bed_file_path=extend_bed,
                                             contig_name=ctg_name,
                                             return_bed_region=True)

    fai_fn = file_path_from(fasta_file_path,
                            suffix=".fai",
                            exit_on_not_found=True,
                            sep='.')
    if not is_confident_bed_file_given and chunk_id is not None:
        # Chunk the whole contig; its length comes from the reference .fai index.
        contig_length = 0
        with open(fai_fn, 'r') as fai_fp:
            for row in fai_fp:
                columns = row.strip().split("\t")

                contig_name = columns[0]
                if contig_name != ctg_name:
                    continue
                contig_length = int(columns[1])
        # Ceiling division so that chunk_num chunks cover the whole contig.
        chunk_size = contig_length // chunk_num + 1 if contig_length % chunk_num else contig_length // chunk_num
        ctg_start = chunk_size * chunk_id  # 0-base to 1-base
        ctg_end = ctg_start + chunk_size

    if is_confident_bed_file_given and chunk_id is not None:
        # Chunk the BED span reported by bed_tree_from for the extend BED.
        chunk_size = (bed_end - bed_start) // chunk_num + 1 if (
            bed_end - bed_start) % chunk_num else (bed_end -
                                                   bed_start) // chunk_num
        ctg_start = bed_start + 1 + chunk_size * chunk_id  # 0-base to 1-base
        ctg_end = ctg_start + chunk_size

    if is_known_vcf_file_provided and chunk_id is not None:
        # Chunk the known variants instead; the region is the chunk's min..max position.
        known_variants_list = vcf_candidates_from(vcf_fn=vcf_fn,
                                                  contig_name=ctg_name)
        total_variants_size = len(known_variants_list)
        chunk_variants_size = total_variants_size // chunk_num if total_variants_size % chunk_num == 0 else total_variants_size // chunk_num + 1
        chunk_start_pos = chunk_id * chunk_variants_size
        known_variants_set = set(
            known_variants_list[chunk_start_pos:chunk_start_pos +
                                chunk_variants_size])
        if len(known_variants_set) == 0:
            return
        ctg_start, ctg_end = min(known_variants_set), max(known_variants_set)

    is_ctg_name_given = ctg_name is not None
    is_ctg_range_given = is_ctg_name_given and ctg_start is not None and ctg_end is not None
    # Bounds handed to the confident BED tree below. They stay None when no explicit range is
    # known; previously they were only bound in the range branch, so a name-only run raised
    # NameError. NOTE(review): assumes bed_tree_from treats None bounds as "unbounded".
    extend_start, extend_end = None, None
    if is_ctg_range_given:
        extend_start = ctg_start - no_of_positions
        extend_end = ctg_end + no_of_positions
        reads_regions.append(
            region_from(ctg_name=ctg_name,
                        ctg_start=extend_start,
                        ctg_end=extend_end))
        reference_start, reference_end = ctg_start - param.expandReferenceRegion, ctg_end + param.expandReferenceRegion
        reference_start = 1 if reference_start < 1 else reference_start
        ref_regions.append(
            region_from(ctg_name=ctg_name,
                        ctg_start=reference_start,
                        ctg_end=reference_end))
    elif is_ctg_name_given:
        reads_regions.append(region_from(ctg_name=ctg_name))
        ref_regions.append(region_from(ctg_name=ctg_name))
        reference_start = 1

    reference_sequence = reference_sequence_from(
        samtools_execute_command=samtools_execute_command,
        fasta_file_path=fasta_file_path,
        regions=ref_regions)

    if reference_sequence is None or len(reference_sequence) == 0:
        sys.exit(
            log_error(
                "[ERROR] Failed to load reference sequence from file ({}).".
                format(fasta_file_path)))

    # NOTE(review): `tree` was built from the extend BED, while the message names the confident BED.
    if is_confident_bed_file_given and ctg_name not in tree:
        sys.exit(
            log_error("[ERROR] ctg_name {} not exists in bed file({}).".format(
                ctg_name, confident_bed_fn)))

    # samtools mpileup options
    # reverse-del: deletion in forward/reverse strand were marked as '*'/'#'
    min_base_quality = 0 if args.gvcf else min_base_quality
    max_depth = param.max_depth_dict[
        args.platform] if args.platform else args.max_depth
    mq_option = ' --min-MQ {}'.format(min_mapping_quality)
    bq_option = ' --min-BQ {}'.format(min_base_quality)
    flags_option = ' --excl-flags {}'.format(param.SAMTOOLS_VIEW_FILTER_FLAG)
    max_depth_option = ' --max-depth {}'.format(max_depth)
    bed_option = ' -l {}'.format(
        extend_bed) if is_extend_bed_file_given else ""
    # -a also reports zero-coverage positions, needed to emit GVCF blocks
    gvcf_option = ' -a' if args.gvcf else ""
    samtools_mpileup_process = subprocess_popen(
        shlex.split("{} mpileup  {} -r {} --reverse-del".format(
            samtools_execute_command,
            bam_file_path,
            " ".join(reads_regions),
        ) + mq_option + bq_option + bed_option + flags_option +
                    max_depth_option + gvcf_option))

    if tensor_can_output_path != "PIPE":
        tensor_can_fpo = open(tensor_can_output_path, "wb")
        tensor_can_fp = subprocess_popen(shlex.split("{} -c".format(
            param.zstd)),
                                         stdin=PIPE,
                                         stdout=tensor_can_fpo)
    else:
        tensor_can_fp = TensorStdout(sys.stdout)

    # whether save all alternative information, only for debug mode
    if alt_fn:
        alt_fp = open(alt_fn, 'w')

    # Ring buffer of the last sliding_window_size pileup tensors; pos_offset is the next write slot.
    pos_offset = 0
    pre_pos = -1
    tensor = [[]] * sliding_window_size
    candidate_position = []
    all_alt_dict = {}
    depth_dict = {}
    af_dict = {}

    # to generate gvcf, it is needed to record whole genome statistical information
    if args.gvcf:
        nonVariantCaller = variantInfoCalculator(
            gvcfWritePath=args.temp_file_dir,
            ref_path=args.ref_fn,
            bp_resolution=args.bp_resolution,
            ctgName=ctg_name,
            sample_name='.'.join(
                [args.sampleName, ctg_name,
                 str(ctg_start),
                 str(ctg_end)]),
            p_err=args.base_err,
            gq_bin_size=args.gq_bin_size)

    confident_bed_tree = bed_tree_from(bed_file_path=confident_bed_fn,
                                       contig_name=ctg_name,
                                       bed_ctg_start=extend_start,
                                       bed_ctg_end=extend_end)

    empty_pileup_flag = True
    for row in samtools_mpileup_process.stdout:
        empty_pileup_flag = False
        # mpileup columns: contig, pos, ref, depth, bases, qualities
        columns = row.strip().split('\t', maxsplit=5)
        pos = int(columns[1])
        pileup_bases = columns[4]
        reference_base = reference_sequence[pos - reference_start].upper()
        # Always True at present, so the `not valid_reference_flag` branch below never runs.
        valid_reference_flag = True
        within_flag = True
        if args.gvcf:
            if not valid_reference_flag:
                nonVariantCaller.make_gvcf_online({}, push_current=True)
            if ctg_start is not None and ctg_end is not None:
                within_flag = pos >= ctg_start and pos < ctg_end
            elif ctg_start is not None and ctg_end is None:
                within_flag = pos >= ctg_start
            elif ctg_start is None and ctg_end is not None:
                within_flag = pos <= ctg_end
            else:
                within_flag = True
            # Zero-depth site: record it for the GVCF and skip tensor generation.
            if columns[3] == '0' and within_flag and valid_reference_flag:
                cur_site_info = {
                    'chr': columns[0],
                    'pos': pos,
                    'ref': reference_base,
                    'n_total': 0,
                    'n_ref': 0
                }
                nonVariantCaller.make_gvcf_online(cur_site_info)
                continue

        # start with a new region, clear all sliding windows cache, avoid memory occupation
        if pre_pos + 1 != pos:
            pos_offset = 0
            tensor = [[]] * sliding_window_size
            candidate_position = []
        pre_pos = pos

        # a condition to skip some positions creating tensor,but return allele summary
        # allele count function
        pileup_tensor, alt_dict, af, depth, pass_af, pileup_list, max_del_length = generate_tensor(
            pos=pos,
            pileup_bases=pileup_bases,
            reference_sequence=reference_sequence,
            reference_start=reference_start,
            reference_base=reference_base,
            minimum_af_for_candidate=minimum_af_for_candidate,
            minimum_snp_af_for_candidate=minimum_snp_af_for_candidate,
            minimum_indel_af_for_candidate=minimum_indel_af_for_candidate,
            platform=platform,
            fast_mode=fast_mode,
            call_snp_only=call_snp_only)
        if args.gvcf and within_flag and valid_reference_flag:
            cur_n_total = 0
            cur_n_ref = 0
            for _key, _value in pileup_list:
                if (_key == reference_base):
                    cur_n_ref = _value
                cur_n_total += _value

            cur_site_info = {
                'chr': columns[0],
                'pos': pos,
                'ref': reference_base,
                'n_total': cur_n_total,
                'n_ref': cur_n_ref
            }
            nonVariantCaller.make_gvcf_online(cur_site_info)

        pass_confident_bed = not is_confident_bed_file_given or is_region_in(
            tree=confident_bed_tree,
            contig_name=ctg_name,
            region_start=pos - 1,
            region_end=pos + max_del_length + 1)  # 0-based
        # Candidate if it passes BED/ACGT/AF/depth checks (no known VCF), or is a known variant.
        if (pass_confident_bed and reference_base in 'ACGT' and
            (pass_af and depth >= min_coverage)
                and not is_known_vcf_file_provided) or (
                    is_known_vcf_file_provided and pos in known_variants_set):
            candidate_position.append(pos)
            all_alt_dict[pos] = alt_dict
            depth_dict[pos] = depth
            af_dict[pos] = af
        tensor[pos_offset] = pileup_tensor

        # save pileup tensor for each candidate position with nearby flanking_base_num bp distance
        pos_offset = (pos_offset + 1) % sliding_window_size
        if len(candidate_position
               ) and pos - candidate_position[0] == flanking_base_num:
            center = candidate_position.pop(0)
            has_empty_tensor = any(len(item) == 0 for item in tensor)
            if not has_empty_tensor:
                depth = depth_dict[center]
                ref_seq = reference_sequence[center - (flanking_base_num) -
                                             reference_start:center +
                                             flanking_base_num + 1 -
                                             reference_start]
                # Rotate the ring buffer so the oldest column comes first.
                concat_tensor = tensor[pos_offset:] + tensor[0:pos_offset]

                alt_info = str(depth) + '-' + ' '.join([
                    ' '.join([item[0], str(item[1])])
                    for item in list(all_alt_dict[center].items())
                ])
                l = "%s\t%d\t%s\t%s\t%s" % (
                    ctg_name, center, ref_seq, " ".join(
                        " ".join("%d" % x for x in innerlist)
                        for innerlist in concat_tensor), alt_info)
                tensor_can_fp.stdin.write(l)
                tensor_can_fp.stdin.write("\n")
                if alt_fn:
                    alt_info = ' '.join([
                        ' '.join([item[0], str(item[1])])
                        for item in list(all_alt_dict[center].items())
                    ])
                    alt_fp.write('\t'.join([
                        ctg_name + ' ' + str(center),
                        str(depth), alt_info,
                        str(af_dict[center])
                    ]) + '\n')
                del all_alt_dict[center], depth_dict[center], af_dict[center]

    if args.gvcf and len(nonVariantCaller.current_block) != 0:
        nonVariantCaller.write_to_gvcf_batch(nonVariantCaller.current_block,
                                             nonVariantCaller.cur_min_DP,
                                             nonVariantCaller.cur_raw_gq)

    if args.gvcf and empty_pileup_flag:
        nonVariantCaller.write_empty_pileup(ctg_name, ctg_start, ctg_end)
    if args.gvcf:
        nonVariantCaller.close_vcf_writer()

    samtools_mpileup_process.stdout.close()
    samtools_mpileup_process.wait()

    if tensor_can_output_path != "PIPE":
        tensor_can_fp.stdin.close()
        tensor_can_fp.wait()
        tensor_can_fpo.close()

    if alt_fn:
        alt_fp.close()
コード例 #5
0
def sort_vcf_from(args):
    """
    Merge per-contig VCF/GVCF files from `args.input_dir` into one output file sorted by contig and position.

    File selection and ordering:
      * files are filtered by `args.vcf_fn_prefix` and/or `args.vcf_fn_suffix` (each optional);
      * contigs are visited in `major_contigs_order`, then in the order listed in `args.contigs_fn`;
      * within a contig, records are written sorted by position.

    Header handling: header lines ('#') are de-duplicated, and '##commandline=' lines are dropped.
    They are written once, after the first contig for which any header line was collected.

    Compression:
      * when the suffix contains 'gvcf' and `lz4` is on PATH, inputs are read with `lz4 -fdc`;
      * in that case the output is also lz4-compressed, unless the suffix is exactly '.gvcf';
      * otherwise the output is plain text and is finally passed to `compress_index_vcf`, except
        for the '.tmp.gvcf' suffix.

    When no input file matches, or no data record is found, a header-only VCF is written,
    compressed and indexed instead.

    Args:
        args: parsed command-line namespace providing output_fn, input_dir, vcf_fn_prefix,
            vcf_fn_suffix, sampleName, ref_fn and contigs_fn.
    """
    output_fn = args.output_fn
    input_dir = args.input_dir
    vcf_fn_prefix = args.vcf_fn_prefix
    vcf_fn_suffix = args.vcf_fn_suffix
    sample_name = args.sampleName
    ref_fn = args.ref_fn
    contigs_fn = args.contigs_fn

    if not os.path.exists(input_dir):
        exit(
            log_error("[ERROR] Input directory: {} not exists!".format(
                input_dir)))
    all_files = os.listdir(input_dir)

    if vcf_fn_prefix is not None:
        all_files = [
            item for item in all_files if item.startswith(vcf_fn_prefix)
        ]
        if len(all_files) == 0:
            output_header(output_fn=output_fn,
                          reference_file_path=ref_fn,
                          sample_name=sample_name)
            print(
                log_warning(
                    "[WARNING] No vcf file found with prefix:{}/{}, output empty vcf file"
                    .format(input_dir, vcf_fn_prefix)))
            compress_index_vcf(output_fn)
            print_calling_step(output_fn=output_fn)
            return

    if vcf_fn_suffix is not None:
        all_files = [
            item for item in all_files if item.endswith(vcf_fn_suffix)
        ]
        if len(all_files) == 0:
            output_header(output_fn=output_fn,
                          reference_file_path=ref_fn,
                          sample_name=sample_name)
            # Report the suffix that failed to match (previously the prefix was printed here).
            print(
                log_warning(
                    "[WARNING] No vcf file found with suffix:{}/{}, output empty vcf file"
                    .format(input_dir, vcf_fn_suffix)))
            compress_index_vcf(output_fn)
            print_calling_step(output_fn=output_fn)
            return

    all_contigs_list = []
    if contigs_fn and os.path.exists(contigs_fn):
        with open(contigs_fn) as f:
            all_contigs_list = [item.rstrip() for item in f.readlines()]
    else:
        exit(
            log_error("[ERROR] Cannot find contig file {}. Exit!".format(
                contigs_fn)))

    contigs_order = major_contigs_order + all_contigs_list
    contigs_order_list = sorted(all_contigs_list,
                                key=lambda x: contigs_order.index(x))

    row_count = 0
    header = []
    no_vcf_output = True
    need_write_header = True

    # only compress intermediate gvcf using lz4 output and keep final gvcf in bgzip format
    output_bgzip_gvcf = vcf_fn_suffix == '.gvcf'
    # Guard against a None suffix, which is an accepted value (see the check above).
    compress_gvcf = vcf_fn_suffix is not None and 'gvcf' in vcf_fn_suffix
    if compress_gvcf:
        lz4_path = subprocess.run("which lz4",
                                  stdout=subprocess.PIPE,
                                  shell=True).stdout.decode().rstrip()
        compress_gvcf = lz4_path != ""
    is_lz4_format = compress_gvcf
    compress_gvcf_output = compress_gvcf and not output_bgzip_gvcf
    if compress_gvcf_output:
        write_fpo = open(output_fn, 'w')
        write_proc = subprocess_popen(shlex.split("lz4 -c"),
                                      stdin=subprocess.PIPE,
                                      stdout=write_fpo,
                                      stderr=subprocess.DEVNULL)
        output = write_proc.stdin
    else:
        output = open(output_fn, 'w')

    for contig in contigs_order_list:
        contig_dict = defaultdict(str)
        # Substring match may pick up other contigs' files (chr1 vs chr11); filtered by name below.
        contig_vcf_fns = [fn for fn in all_files if contig in fn]
        for vcf_fn in contig_vcf_fns:
            file = os.path.join(input_dir, vcf_fn)
            if is_lz4_format:
                read_proc = subprocess_popen(shlex.split("{} {}".format(
                    "lz4 -fdc", file)),
                                             stderr=subprocess.DEVNULL)
                fn = read_proc.stdout
            else:
                fn = open(file, 'r')
            for row in fn:
                row_count += 1
                if row[0] == '#':
                    # skip phasing command line only occur with --enable_phasing, otherwise would lead to hap.py evaluation failure
                    if row.startswith('##commandline='):
                        continue
                    if row not in header:
                        header.append(row)
                    continue
                # use the first vcf header
                columns = row.strip().split(maxsplit=3)
                ctg_name, pos = columns[0], columns[1]
                # skip vcf file sharing same contig prefix, ie, chr1 and chr11
                if ctg_name != contig:
                    break
                contig_dict[int(pos)] = row
                no_vcf_output = False
            fn.close()
            if is_lz4_format:
                read_proc.wait()
        if need_write_header and len(header):
            if output_bgzip_gvcf:
                header = check_header_in_gvcf(header=header,
                                              contigs_list=all_contigs_list)
            output.write(''.join(header))
            need_write_header = False
        all_pos = sorted(contig_dict.keys())
        for pos in all_pos:
            output.write(contig_dict[pos])

    if compress_gvcf_output:
        # lz4 intermediate output is final as-is; no empty-file handling or indexing.
        write_proc.stdin.close()
        write_proc.wait()
        write_fpo.close()
        return
    else:
        output.close()

    if row_count == 0:
        print(
            log_warning("[WARNING] No vcf file found, output empty vcf file"))
        output_header(output_fn=output_fn,
                      reference_file_path=ref_fn,
                      sample_name=sample_name)
        compress_index_vcf(output_fn)
        print_calling_step(output_fn=output_fn)
        return
    if no_vcf_output:
        output_header(output_fn=output_fn,
                      reference_file_path=ref_fn,
                      sample_name=sample_name)
        print(log_warning("[WARNING] No variant found, output empty vcf file"))
        compress_index_vcf(output_fn)
        print_calling_step(output_fn=output_fn)
        return

    if vcf_fn_suffix == ".tmp.gvcf":
        return
    if vcf_fn_suffix == ".gvcf":
        print("[INFO] Need some time to compress and index GVCF file...")
    compress_index_vcf(output_fn)
コード例 #6
0
ファイル: MergeVcf.py プロジェクト: HKU-BAL/Clair3
def MergeVcf_illumina(args):
    """
    Merge pileup and full-alignment VCF calls for one contig (Illumina mode).

    Read realignment can shift candidate variants or make them disappear, so calls are merged by
    region rather than by exact position.

    Steps:
      1. All files in `args.bed_fn_prefix` whose name starts with '<contig>.' are concatenated
         into `full_aln_regions_<contig>` in the same directory.
      2. An interval tree is built from that file with `padding=param.no_of_positions`.
      3. Pileup calls (`args.pileup_vcf_fn`) outside the tree's regions are kept.
      4. Full-alignment calls (`args.full_alignment_vcf_fn`) inside the tree's regions are kept.

    Non-reference calls go through `MarkLowQual` with threshold `args.qual`. Reference calls
    (ALT '.' or ALT == REF) are kept only when `args.print_ref_calls` is set. Haploid
    precise/sensitive genotype rewriting is applied to each record when enabled.

    The output file contains the pileup VCF's header lines followed by all kept records, sorted
    by position.

    Args:
        args: parsed command-line namespace providing bed_fn_prefix, output_fn,
            full_alignment_vcf_fn, pileup_vcf_fn, ctgName, qual, haploid_precise,
            haploid_sensitive and print_ref_calls.
    """
    # region vcf merge for illumina, as read realignment will make candidate variants shift and missing.
    bed_fn_prefix = args.bed_fn_prefix
    output_fn = args.output_fn
    full_alignment_vcf_fn = args.full_alignment_vcf_fn
    pileup_vcf_fn = args.pileup_vcf_fn  # true vcf var
    contig_name = args.ctgName
    QUAL = args.qual
    bed_fn = None
    if not os.path.exists(bed_fn_prefix):
        exit(
            log_error("[ERROR] Input directory: {} not exists!".format(
                bed_fn_prefix)))

    all_files = os.listdir(bed_fn_prefix)
    all_files = [
        item for item in all_files if item.startswith(contig_name + '.')
    ]
    if len(all_files) != 0:
        # Concatenate this contig's full-alignment region files into a single BED.
        bed_fn = os.path.join(bed_fn_prefix,
                              "full_aln_regions_{}".format(contig_name))
        with open(bed_fn, 'w') as output_file:
            for file in all_files:
                with open(os.path.join(bed_fn_prefix, file)) as f:
                    output_file.write(f.read())

    is_haploid_precise_mode_enabled = args.haploid_precise
    is_haploid_sensitive_mode_enabled = args.haploid_sensitive
    print_ref = args.print_ref_calls

    tree = bed_tree_from(bed_file_path=bed_fn,
                         padding=param.no_of_positions,
                         contig_name=contig_name)
    unzip_process = subprocess_popen(
        shlex.split("gzip -fdc %s" % (pileup_vcf_fn)))
    output_dict = {}
    header = []
    pileup_count = 0
    for row in unzip_process.stdout:
        if row[0] == '#':
            header.append(row)
            continue
        columns = row.strip().split()
        ctg_name = columns[0]
        if contig_name is not None and ctg_name != contig_name:
            continue
        pos = int(columns[1])
        qual = float(columns[5])
        pass_bed = is_region_in(tree, ctg_name, pos)
        ref_base, alt_base = columns[3], columns[4]
        is_reference = (alt_base == "." or ref_base == alt_base)
        if is_haploid_precise_mode_enabled:
            row = update_haploid_precise_genotype(columns)
        if is_haploid_sensitive_mode_enabled:
            row = update_haploid_sensitive_genotype(columns)

        # Pileup calls are only trusted outside the full-alignment regions.
        if not pass_bed:
            if not is_reference:
                row = MarkLowQual(row, QUAL, qual)
                output_dict[pos] = row
                pileup_count += 1
            elif print_ref:
                output_dict[pos] = row
                pileup_count += 1

    unzip_process.stdout.close()
    unzip_process.wait()

    realigned_vcf_unzip_process = subprocess_popen(
        shlex.split("gzip -fdc %s" % (full_alignment_vcf_fn)))
    realigned_count = 0
    for row in realigned_vcf_unzip_process.stdout:
        if row[0] == '#':
            continue
        columns = row.strip().split()
        ctg_name = columns[0]
        if contig_name is not None and ctg_name != contig_name:
            continue

        pos = int(columns[1])
        qual = float(columns[5])
        ref_base, alt_base = columns[3], columns[4]
        is_reference = (alt_base == "." or ref_base == alt_base)

        if is_haploid_precise_mode_enabled:
            row = update_haploid_precise_genotype(columns)
        if is_haploid_sensitive_mode_enabled:
            row = update_haploid_sensitive_genotype(columns)

        # Full-alignment calls are only taken inside the full-alignment regions.
        if is_region_in(tree, ctg_name, pos):
            if not is_reference:
                row = MarkLowQual(row, QUAL, qual)
                output_dict[pos] = row
                realigned_count += 1
            elif print_ref:
                output_dict[pos] = row
                realigned_count += 1

    logging.info('[INFO] Pileup positions variants proceeded in {}: {}'.format(
        contig_name, pileup_count))
    logging.info(
        '[INFO] Realigned positions variants proceeded in {}: {}'.format(
            contig_name, realigned_count))
    realigned_vcf_unzip_process.stdout.close()
    realigned_vcf_unzip_process.wait()

    with open(output_fn, 'w') as output_file:
        output_list = header + [
            output_dict[pos] for pos in sorted(output_dict.keys())
        ]
        output_file.write(''.join(output_list))
コード例 #7
0
ファイル: CheckEnvs.py プロジェクト: HKU-BAL/Clair3
def check_python_path():
    """
    Report which `python` is first on PATH, then terminate the process.

    Runs `which python` through the shell and exits via `sys.exit`, passing the
    `log_error`-wrapped message as the exit argument.
    """
    which_result = subprocess.run("which python", stdout=subprocess.PIPE, shell=True)
    interpreter = which_result.stdout.decode().rstrip()
    message = "[ERROR] Current python execution path: {}".format(interpreter)
    sys.exit(log_error(message))
コード例 #8
0
ファイル: CheckEnvs.py プロジェクト: HKU-BAL/Clair3
def CheckEnvs(args):
    """Validate inputs and tools, then write the contig / chunk work lists.

    What this function does, in order:
      * resolves the BAM, reference, reference ``.fai`` and BAM index
        (``.bai`` or ``.csi``) paths, exiting when a required file is missing;
      * creates the temporary folder layout under ``args.output_fn_prefix``;
      * range-checks the numeric options and checks tool versions;
      * derives the contig set from ``--ctg_name``, ``--bed_fn``, ``--vcf_fn``
        and the ``.fai`` (honouring ``include_all_ctgs`` and
        ``min_contig_size``), and the number of chunks for each contig;
      * writes ``<prefix>/tmp/CONTIGS`` (one contig per line) and
        ``<prefix>/tmp/CHUNK_LIST`` (``<contig> <chunk_id> <chunk_num>`` lines).

    When no usable contig remains, a header-only ``merge_output.vcf`` is written
    and compressed, ``CONTIGS`` is written empty, and the function returns
    early without writing ``CHUNK_LIST``.

    Args:
        args: parsed command-line namespace; every attribute read is visible
            in the body below.
    """
    bam_fn = file_path_from(args.bam_fn, exit_on_not_found=True)
    ref_fn = file_path_from(args.ref_fn, exit_on_not_found=True)
    fai_fn = file_path_from(args.ref_fn, suffix=".fai", exit_on_not_found=True, sep='.')
    bai_fn = file_path_from(args.bam_fn, suffix=".bai", sep='.')
    csi_fn = file_path_from(args.bam_fn, suffix=".csi", sep='.')
    if bai_fn is None and csi_fn is None:
        sys.exit(log_error("[ERROR] Neither BAM index file {} nor {} found".format(
            args.bam_fn + '.bai', args.bam_fn + '.csi')))
    bed_fn = file_path_from(args.bed_fn)
    vcf_fn = file_path_from(args.vcf_fn)
    tree = bed_tree_from(bed_file_path=bed_fn)

    # Create the temp folder layout. Several folders are only created here;
    # presumably later pipeline stages write into them.
    output_fn_prefix = args.output_fn_prefix
    output_fn_prefix = folder_path_from(output_fn_prefix, create_not_found=True)
    log_path = folder_path_from(os.path.join(output_fn_prefix, 'log'), create_not_found=True)
    tmp_file_path = folder_path_from(os.path.join(output_fn_prefix, 'tmp'), create_not_found=True)
    split_bed_path = folder_path_from(os.path.join(tmp_file_path, 'split_beds'),
                                      create_not_found=True) if bed_fn or vcf_fn else None
    pileup_vcf_path = folder_path_from(os.path.join(tmp_file_path, 'pileup_output'), create_not_found=True)
    merge_vcf_path = folder_path_from(os.path.join(tmp_file_path, 'merge_output'), create_not_found=True)
    phase_output_path = folder_path_from(os.path.join(tmp_file_path, 'phase_output'), create_not_found=True)
    gvcf_temp_output_path = folder_path_from(os.path.join(tmp_file_path, 'gvcf_tmp_output'), create_not_found=True)
    full_alignment_output_path = folder_path_from(os.path.join(tmp_file_path, 'full_alignment_output'),
                                                  create_not_found=True)
    phase_vcf_path = folder_path_from(os.path.join(phase_output_path, 'phase_vcf'), create_not_found=True)
    phase_bam_path = folder_path_from(os.path.join(phase_output_path, 'phase_bam'), create_not_found=True)
    candidate_bed_path = folder_path_from(os.path.join(full_alignment_output_path, 'candidate_bed'),
                                          create_not_found=True)

    # environment parameters
    pypy = args.pypy
    samtools = args.samtools
    whatshap = args.whatshap
    parallel = args.parallel
    qual = args.qual
    var_pct_full = args.var_pct_full
    ref_pct_full = args.ref_pct_full
    snp_min_af = args.snp_min_af
    indel_min_af = args.indel_min_af
    min_contig_size = args.min_contig_size
    sample_name = args.sampleName
    contig_name_list = os.path.join(tmp_file_path, 'CONTIGS')
    chunk_list = os.path.join(tmp_file_path, 'CHUNK_LIST')

    legal_range_from(param_name="qual", x=qual, min_num=0, exit_out_of_range=True)
    legal_range_from(param_name="var_pct_full", x=var_pct_full, min_num=0, max_num=1, exit_out_of_range=True)
    legal_range_from(param_name="ref_pct_full", x=ref_pct_full, min_num=0, max_num=1, exit_out_of_range=True)
    legal_range_from(param_name="snp_min_af", x=snp_min_af, min_num=0, max_num=1, exit_out_of_range=True)
    legal_range_from(param_name="indel_min_af", x=indel_min_af, min_num=0, max_num=1, exit_out_of_range=True)
    if ref_pct_full > 0.3:
        print(log_warning(
            "[WARNING] For efficiency, we use a maximum 30% reference candidates for full-alignment calling"))
    tool_version = {
        'python': LooseVersion(sys.version.split()[0]),
        'pypy': check_version(tool=pypy, pos=0, is_pypy=True),
        'samtools': check_version(tool=samtools, pos=1),
        'whatshap': check_version(tool=whatshap, pos=1),
        'parallel': check_version(tool=parallel, pos=2),
    }
    check_tools_version(tool_version, required_tool_version)

    is_include_all_contigs = args.include_all_ctgs
    is_bed_file_provided = bed_fn is not None
    is_known_vcf_file_provided = vcf_fn is not None

    if is_known_vcf_file_provided and is_bed_file_provided:
        sys.exit(log_error("[ERROR] Please provide either --vcf_fn or --bed_fn only"))

    # Contigs present in the known-variant VCF; only populated when --vcf_fn is given.
    know_vcf_contig_set = set()
    if is_known_vcf_file_provided:
        know_vcf_contig_set = split_extend_vcf(vcf_fn=vcf_fn, output_fn=split_bed_path)

    ctg_name_list = args.ctg_name
    is_ctg_name_list_provided = ctg_name_list is not None and ctg_name_list != "EMPTY"
    contig_set = set(ctg_name_list.split(',')) if is_ctg_name_list_provided else set()

    if is_ctg_name_list_provided and is_bed_file_provided:
        print(log_warning("[WARNING] both --ctg_name and --bed_fn provided, will only proceed contigs in intersection"))

    if is_ctg_name_list_provided and is_known_vcf_file_provided:
        print(log_warning("[WARNING] both --ctg_name and --vcf_fn provided, will only proceed contigs in intersection"))

    if is_ctg_name_list_provided:
        # An explicit --ctg_name list is narrowed by BED/VCF contigs ...
        contig_set = contig_set.intersection(
            set(tree.keys())) if is_bed_file_provided else contig_set

        contig_set = contig_set.intersection(
            know_vcf_contig_set) if is_known_vcf_file_provided else contig_set
    else:
        # ... otherwise BED/VCF contigs define the set directly.
        contig_set = contig_set.union(
            set(tree.keys())) if is_bed_file_provided else contig_set

        contig_set = contig_set.union(
            know_vcf_contig_set) if is_known_vcf_file_provided else contig_set

    # Each contig is split into ceil(length / chunk_size) chunks. --chunk_num is
    # only used below to warn when the resulting chunk sizes look unreasonable.
    default_chunk_num = args.chunk_num
    DEFAULT_CHUNK_SIZE = args.chunk_size
    contig_length_list = []
    contig_chunk_num = {}

    with open(fai_fn, 'r') as fai_fp:
        for row in fai_fp:
            columns = row.strip().split("\t")
            contig_name, contig_length = columns[0], int(columns[1])
            # Without any contig selection option, keep only the major contigs
            # unless --include_all_ctgs is set.
            if not is_include_all_contigs and (
            not (is_bed_file_provided or is_ctg_name_list_provided or is_known_vcf_file_provided)) and str(
                    contig_name) not in major_contigs:
                continue

            if is_bed_file_provided and contig_name not in tree:
                continue
            if is_ctg_name_list_provided and contig_name not in contig_set:
                continue
            if is_known_vcf_file_provided and contig_name not in contig_set:
                continue

            if min_contig_size > 0 and contig_length < min_contig_size:
                print(log_warning(
                    "[WARNING] {} contig length {} is smaller than minimum contig size {}, will skip it!".format(contig_name, contig_length, min_contig_size)))
                if contig_name in contig_set:
                    contig_set.remove(contig_name)
                continue
            contig_set.add(contig_name)
            contig_length_list.append(contig_length)
            # ceil(contig_length / DEFAULT_CHUNK_SIZE), at least 1
            chunk_num = int(
                contig_length / float(DEFAULT_CHUNK_SIZE)) + 1 if contig_length % DEFAULT_CHUNK_SIZE else int(
                contig_length / float(DEFAULT_CHUNK_SIZE))
            contig_chunk_num[contig_name] = max(chunk_num, 1)

    # Major contigs keep their canonical order; any other contig sorts after them.
    contigs_order = major_contigs_order + list(contig_set)

    sorted_contig_list = sorted(contig_set, key=contigs_order.index)

    found_contig = True
    if not contig_set:
        if is_bed_file_provided:
            all_contig_in_bed = ' '.join(list(tree.keys()))
            print(log_warning("[WARNING] No contig intersection found by --bed_fn, contigs in BED {}: {}".format(bed_fn, all_contig_in_bed)))
        if is_known_vcf_file_provided:
            all_contig_in_vcf = ' '.join(list(know_vcf_contig_set))
            print(log_warning("[WARNING] No contig intersection found by --vcf_fn, contigs in VCF {}: {}".format(vcf_fn, all_contig_in_vcf)))
        if is_ctg_name_list_provided:
            all_contig_in_ctg_name = ' '.join(ctg_name_list.split(','))
            print(log_warning("[WARNING] No contig intersection found by --ctg_name, contigs in contigs list: {}".format(all_contig_in_ctg_name)))
        found_contig = False
    else:
        for c in sorted_contig_list:
            if c not in contig_chunk_num:
                print(log_warning(("[WARNING] Contig {} given but not found in reference fai file".format(c))))

        # Contigs missing from the .fai have no chunk count; drop them so the
        # contig_chunk_num lookups below cannot raise KeyError.
        sorted_contig_list = [c for c in sorted_contig_list if c in contig_chunk_num]

        if sorted_contig_list:
            # check contig in bam have support reads
            sorted_contig_list, found_contig = check_contig_in_bam(bam_fn=bam_fn, sorted_contig_list=sorted_contig_list,
                                                                   samtools=samtools)
        else:
            found_contig = False

    if not found_contig:
        # output header only to merge_output.vcf.gz
        output_fn = os.path.join(output_fn_prefix, "merge_output.vcf")
        output_header(output_fn=output_fn, reference_file_path=ref_fn, sample_name=sample_name)
        compress_index_vcf(output_fn)
        print(log_warning(
            ("[WARNING] No contig intersection found, output header only in {}").format(output_fn + ".gz")))
        with open(contig_name_list, 'w') as output_file:
            output_file.write("")
        return

    print('[INFO] Call variant in contigs: {}'.format(' '.join(sorted_contig_list)))
    print('[INFO] Chunk number for each contig: {}'.format(
        ' '.join([str(contig_chunk_num[c]) for c in sorted_contig_list])))

    # contig_length_list is non-empty here: at least one contig passed the .fai
    # filter, otherwise we returned above.
    if default_chunk_num > 0:
        min_chunk_length = min(contig_length_list) / float(default_chunk_num)
        max_chunk_length = max(contig_length_list) / float(default_chunk_num)
        if max_chunk_length > MAX_CHUNK_LENGTH:
            print(log_warning(
                '[WARNING] Current maximum chunk size {} is larger than default maximum chunk size {}, You may set a larger chunk_num by setting --chunk_num=$ for better parallelism.'.format(
                    max_chunk_length, MAX_CHUNK_LENGTH)))
        elif min_chunk_length < MIN_CHUNK_LENGTH:
            print(log_warning(
                '[WARNING] Current minimum chunk size {} is smaller than default minimum chunk size {}, You may set a smaller chunk_num by setting --chunk_num=$.'.format(
                    min_chunk_length, MIN_CHUNK_LENGTH)))
    elif default_chunk_num == 0 and max(contig_length_list) < DEFAULT_CHUNK_SIZE / 5:
        print(log_warning(
            '[WARNING] Current maximum contig length {} is much smaller than default chunk size {}, You may set a smaller chunk size by setting --chunk_size=$ for better parallelism.'.format(
                max(contig_length_list), DEFAULT_CHUNK_SIZE)))

    if is_bed_file_provided:
        split_extend_bed(bed_fn=bed_fn, output_fn=split_bed_path, contig_set=contig_set)

    with open(contig_name_list, 'w') as output_file:
        output_file.write('\n'.join(sorted_contig_list))

    # One line per chunk: "<contig> <1-based chunk id> <total chunks>"
    with open(chunk_list, 'w') as output_file:
        for contig_name in sorted_contig_list:
            chunk_num = contig_chunk_num[contig_name]
            for chunk_id in range(1, chunk_num + 1):
                output_file.write(contig_name + ' ' + str(chunk_id) + ' ' + str(chunk_num) + '\n')
コード例 #9
0
def call_variants_from_cffi(args, output_config, output_utilities):
    """Build tensors through the CFFI tensor creators and run model prediction.

    Pileup mode (``args.pileup``) uses ``CreateTensorPileup`` and the pileup
    model, while full-alignment mode uses ``CreateTensorFullAlignment`` and
    the full-alignment model. Inference runs either on a local Clair3 Keras
    model loaded from ``args.chkpnt_fn`` or, with ``args.use_gpu``, on a
    Triton server reached over gRPC at ``localhost:8001``.

    Predictions are written batch by batch through ``batch_output``. If the
    resulting ``args.call_fn`` contains only header (``#``) lines, the file
    is removed.

    Args:
        args: parsed command-line namespace (``use_gpu``, ``pileup``,
            ``add_indel_length``, ``chkpnt_fn``, ``chunk_id``, ``chunk_num``,
            ``ctgName``, ``call_fn`` and whatever the tensor creators read).
        output_config: passed through to ``batch_output``.
        output_utilities: output helper; ``gen_output_file``,
            ``output_header`` and ``close_opened_files`` are called on it.
    """
    use_gpu = args.use_gpu
    if use_gpu:
        import tritonclient.grpc as tritongrpcclient
        server_url = 'localhost:8001'
        try:
            triton_client = tritongrpcclient.InferenceServerClient(
                url=server_url, verbose=False)
        except Exception as e:
            print("channel creation failed: " + str(e))
            # Non-zero status so callers can detect the failure.
            sys.exit(1)
    else:
        # Hide GPUs from the local model when GPU inference is not requested.
        os.environ["CUDA_VISIBLE_DEVICES"] = ""

    # The mode-specific parameter module is bound to the module-level `param`.
    global param
    if args.pileup:
        import shared.param_p as param
        if use_gpu:
            model_name = 'pileup'
            input_dtype = 'INT32'
        else:
            from clair3.model import Clair3_P
            m = Clair3_P(add_indel_length=args.add_indel_length, predict=True)
    else:
        import shared.param_f as param
        if use_gpu:
            model_name = 'alignment'
            input_dtype = 'INT8'
        else:
            from clair3.model import Clair3_F
            m = Clair3_F(add_indel_length=args.add_indel_length, predict=True)

    if not use_gpu:
        m.load_weights(args.chkpnt_fn)
    output_utilities.gen_output_file()
    output_utilities.output_header()
    chunk_id = args.chunk_id - 1 if args.chunk_id else None  # 1-base to 0-base
    chunk_num = args.chunk_num
    full_alignment_mode = not args.pileup

    logging.info("Calling variants ...")
    variant_call_start_time = time()

    batch_output_method = batch_output
    total = 0

    if args.pileup:
        from preprocess.CreateTensorPileupFromCffi import CreateTensorPileup as CT
    else:
        from preprocess.CreateTensorFullAlignmentFromCffi import CreateTensorFullAlignment as CT

    tensor, all_position, all_alt_info = CT(args)

    def tensor_generator_from(tensor, all_position, all_alt_info):
        """Yield aligned (tensor, position, alt_info) slices of at most
        ``param.predictBatchSize`` items each."""
        total_data = len(tensor)
        assert total_data == len(all_alt_info)
        assert total_data == len(all_position)
        batch_size = param.predictBatchSize
        for batch_start in range(0, total_data, batch_size):
            batch_end = min(batch_start + batch_size, total_data)
            yield (tensor[batch_start:batch_end],
                   all_position[batch_start:batch_end],
                   all_alt_info[batch_start:batch_end])

    tensor_generator = tensor_generator_from(tensor, all_position,
                                             all_alt_info)

    for (X, position, alt_info_list) in tensor_generator:
        total += len(X)

        if use_gpu:
            inputs = [tritongrpcclient.InferInput('input_1', X.shape, input_dtype)]
            outputs = [tritongrpcclient.InferRequestedOutput('output_1')]

            inputs[0].set_data_from_numpy(X)
            results = triton_client.infer(model_name=model_name,
                                          inputs=inputs,
                                          outputs=outputs)
            Y = results.as_numpy('output_1')
        else:
            Y = m.predict_on_batch(X)

        batch_output_method(position, alt_info_list, Y, output_config,
                            output_utilities)

    if chunk_id is not None:
        logging.info(
            "Total processed positions in {} (chunk {}/{}) : {}".format(
                args.ctgName, chunk_id + 1, chunk_num, total))
    elif full_alignment_mode:
        # Best effort: recover "<chunk_id>_<chunk_num>" from the second-to-last
        # dot-separated field of call_fn; fall back to a chunk-less message.
        try:
            chunk_infos = args.call_fn.split('.')[-2]
            c_id, c_num = chunk_infos.split('_')
            c_id = int(c_id) + 1  # 0-index to 1-index
        except (AttributeError, IndexError, ValueError):
            logging.info("Total processed positions in {} : {}".format(
                args.ctgName, total))
        else:
            logging.info(
                "Total processed positions in {} (chunk {}/{}) : {}".format(
                    args.ctgName, c_id, c_num, total))
    else:
        logging.info("Total processed positions in {} : {}".format(
            args.ctgName, total))

    if full_alignment_mode and total == 0:
        logging.info(
            log_error("[ERROR] No full-alignment output for file {}/{}".format(
                args.ctgName, args.call_fn)))

    logging.info("Total time elapsed: %.2f s" %
                 (time() - variant_call_start_time))

    output_utilities.close_opened_files()
    # remove file if no variant in output
    if os.path.exists(args.call_fn):
        with open(args.call_fn, 'r') as call_fp:
            has_record = any(row[0] != '#' for row in call_fp)
        if has_record:
            return
        logging.info(
            "[INFO] No vcf output for file {}, remove empty file".format(
                args.call_fn))
        os.remove(args.call_fn)