Example #1
File: CheckEnvs.py Project: HKU-BAL/Clair3
def check_contig_in_bam(bam_fn, sorted_contig_list, samtools):
    bai_process = subprocess_popen(shlex.split("{} idxstats {}".format(samtools, bam_fn)))
    contig_with_read_support_set = set()
    for row_id, row in enumerate(bai_process.stdout):
        row = row.split('\t')
        if len(row) != 4:
            continue
        contig_name, contig_length, mapped_reads, unmapped_reads = row
        if contig_name not in sorted_contig_list:
            continue
        if int(mapped_reads) > 0:
            contig_with_read_support_set.add(contig_name)
    for contig_name in sorted_contig_list:
        if contig_name not in contig_with_read_support_set:
            print(log_warning(
                "[WARNING] Contig name {} provided but no mapped reads in BAM, skip!".format(contig_name)))
    filtered_sorted_contig_list = [item for item in sorted_contig_list if item in contig_with_read_support_set]

    found_contig = True
    if len(filtered_sorted_contig_list) == 0:
        found_contig = False
        print(log_warning(
            "[WARNING] No mapped reads support in BAM for provided contigs set {}".format(
                ' '.join(sorted_contig_list))))
    return filtered_sorted_contig_list, found_contig
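
Note: every snippet on this page calls a project-level subprocess_popen helper that is not reproduced here. A minimal sketch of such a helper, assuming it simply wraps subprocess.Popen with text-mode pipes so callers can iterate over process.stdout line by line, might look like the following (the real helper in Clair3 may differ in defaults and buffering):

import subprocess

def subprocess_popen(args, stdin=None, stdout=subprocess.PIPE, stderr=None):
    # Hypothetical wrapper: start the command with text-mode pipes so the
    # surrounding examples can read and write line-oriented output.
    return subprocess.Popen(args, stdin=stdin, stdout=stdout, stderr=stderr,
                            universal_newlines=True)
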
Example #2
def reference_sequence_from(samtools_execute_command, fasta_file_path,
                            regions):
    reference_sequences = []
    region_value_for_faidx = " ".join(regions)

    samtools_faidx_process = subprocess_popen(
        shlex.split("{} faidx {} {}".format(samtools_execute_command,
                                            fasta_file_path,
                                            region_value_for_faidx)))
    while True:
        row = samtools_faidx_process.stdout.readline()
        is_finish_reading_output = row == '' and samtools_faidx_process.poll() is not None
        if is_finish_reading_output:
            break
        if row:
            reference_sequences.append(row.rstrip())

    # the first line is the reference name ">xxxx" and needs to be ignored
    reference_sequence = "".join(reference_sequences[1:])

    # uppercase for masked sequences
    reference_sequence = reference_sequence.upper()

    samtools_faidx_process.stdout.close()
    samtools_faidx_process.wait()
    if samtools_faidx_process.returncode != 0:
        return None

    return reference_sequence
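
A possible usage sketch for reference_sequence_from; the samtools command, FASTA path, and region below are illustrative placeholders rather than values taken from the project:

# Hypothetical call: fetch a 1-based, inclusive region from a reference FASTA.
sequence = reference_sequence_from(
    samtools_execute_command="samtools",
    fasta_file_path="ref.fa",
    regions=["chr20:100000-100999"])
if sequence is not None:
    print(len(sequence), sequence[:10])
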
Example #3
def bed_tree_from(bed_file_path):
    """
    0-based interval tree [start, end)
    """

    tree = {}
    if bed_file_path is None:
        return tree

    unzip_process = subprocess_popen(
        shlex.split("gzip -fdc %s" % (bed_file_path)))
    while True:
        row = unzip_process.stdout.readline()
        is_finish_reading_output = row == '' and unzip_process.poll() is not None
        if is_finish_reading_output:
            break

        if row:
            columns = row.strip().split()

            ctg_name = columns[0]
            if ctg_name not in tree:
                tree[ctg_name] = IntervalTree()

            ctg_start, ctg_end = int(columns[1]), int(columns[2])
            if ctg_start == ctg_end:
                ctg_end += 1

            tree[ctg_name].addi(ctg_start, ctg_end)

    unzip_process.stdout.close()
    unzip_process.wait()

    return tree
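
The returned dict maps each contig name to an intervaltree.IntervalTree, so a 0-based position can be tested for BED membership roughly as follows (the BED path, contig name, and position are made-up examples):

# Hypothetical membership check against the interval tree built above.
tree = bed_tree_from("regions.bed.gz")
ctg, pos = "chr1", 12345  # 0-based position
is_in_bed = ctg in tree and len(tree[ctg].at(pos)) > 0
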
Example #4
def variants_map_from(variant_file_path):
    """
    variants map with 1-based position as key
    """
    if variant_file_path == None:
        return {}

    variants_map = {}
    f = subprocess_popen(shlex.split("gzip -fdc %s" % (variant_file_path)))

    while True:
        row = f.stdout.readline()
        is_finish_reading_output = row == '' and f.poll() is not None
        if is_finish_reading_output:
            break

        if row:
            columns = row.split()
            ctg_name, position_str = columns[0], columns[1]
            key = ctg_name + ":" + position_str

            variants_map[key] = True

    f.stdout.close()
    f.wait()

    return variants_map
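
A brief usage sketch; the gzipped variant file path is a placeholder:

# Hypothetical call: build a lookup keyed by "ctg:pos" (1-based position).
variants = variants_map_from("var.txt.gz")
print("chr1:10177" in variants)
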
Example #5
File: CheckEnvs.py Project: HKU-BAL/Clair3
def split_extend_vcf(vcf_fn, output_fn):
    expand_region_size = param.no_of_positions
    output_ctg_dict = defaultdict(list)
    unzip_process = subprocess_popen(shlex.split("gzip -fdc %s" % (vcf_fn)))

    for row_id, row in enumerate(unzip_process.stdout):
        if row[0] == '#':
            continue
        columns = row.strip().split(maxsplit=3)
        ctg_name = columns[0]

        center_pos = int(columns[1])
        ctg_start, ctg_end = center_pos - 1, center_pos
        if ctg_start < 0:
            sys.exit(
                log_error("[ERROR] Invalid VCF input in {}-th row {} {}".format(row_id + 1, ctg_name, center_pos)))
        if ctg_start - expand_region_size < 0:
            continue
        expand_ctg_start = ctg_start - expand_region_size
        expand_ctg_end = ctg_end + expand_region_size

        output_ctg_dict[ctg_name].append(
            ' '.join([ctg_name, str(expand_ctg_start), str(expand_ctg_end)]))

    for key, value in output_ctg_dict.items():
        ctg_output_fn = os.path.join(output_fn, key)
        with open(ctg_output_fn, 'w') as output_file:
            output_file.write('\n'.join(value))

    unzip_process.stdout.close()
    unzip_process.wait()

    known_vcf_contig_set = set(output_ctg_dict.keys())

    return known_vcf_contig_set
Example #6
File: CheckEnvs.py Project: HKU-BAL/Clair3
def split_extend_bed(bed_fn, output_fn, contig_set=None):
    expand_region_size = param.no_of_positions
    output_ctg_dict = defaultdict(list)
    unzip_process = subprocess_popen(shlex.split("gzip -fdc %s" % (bed_fn)))
    for row_id, row in enumerate(unzip_process.stdout):
        if row[0] == '#':
            continue
        columns = row.strip().split()
        ctg_name = columns[0]
        if contig_set and ctg_name not in contig_set:
            continue

        ctg_start, ctg_end = int(columns[1]), int(columns[2])

        if ctg_end < ctg_start or ctg_start < 0 or ctg_end < 0:
            sys.exit(log_error(
                "[ERROR] Invalid BED input in {}-th row {} {} {}".format(row_id + 1, ctg_name, ctg_start, ctg_end)))
        expand_ctg_start = max(0, ctg_start - expand_region_size)
        expand_ctg_end = max(0, ctg_end + expand_region_size)
        output_ctg_dict[ctg_name].append(
            ' '.join([ctg_name, str(expand_ctg_start), str(expand_ctg_end)]))

    for key, value in output_ctg_dict.items():
        ctg_output_fn = os.path.join(output_fn, key)
        with open(ctg_output_fn, 'w') as output_file:
            output_file.write('\n'.join(value))

    unzip_process.stdout.close()
    unzip_process.wait()
Example #7
File: utils.py Project: HKU-BAL/Clair3
def variant_map_from(var_fn, tree, is_tree_empty):
    Y = {}
    truth_alt_dict = {}
    miss_variant_set = set()
    if var_fn is None:
        return Y, miss_variant_set, truth_alt_dict

    f = subprocess_popen(shlex.split("gzip -fdc %s" % (var_fn)))
    for row in f.stdout:
        if row[0] == "#":
            continue
        columns = row.strip().split()
        ctg_name, position_str, ref_base, alt_base, genotype1, genotype2 = columns
        key = ctg_name + ":" + position_str
        if genotype1 == '-1' or genotype2 == '-1':
            miss_variant_set.add(key)
            continue
        if not (is_tree_empty or is_region_in(tree, ctg_name, int(position_str))):
            continue

        Y[key] = output_labels_from_vcf_columns(columns)
        ref_base_list, alt_base_list = decode_alt(ref_base, alt_base)
        truth_alt_dict[int(position_str)] = (ref_base_list, alt_base_list)
    f.stdout.close()
    f.wait()
    return Y, miss_variant_set, truth_alt_dict
Example #8
    def read_input(self):
        if self.compress:
            self.read_proc = subprocess_popen(
                shlex.split("{} {}".format(LZ4_DECOMPRESS, self.input_path)),
                stderr=subprocess.DEVNULL)
            self.reader = self.read_proc.stdout
        else:
            self.reader = open(self.input_path, 'r')
        return self.reader
Example #9
def FiterHeteSnpPhasing(args):

    """
    Filter heterozygous snp variant for phasing, currently, we only filter snp variant with low quality socore as low
    quality variant contains more false positive variant that would lead to a larger minimum error correction loss.
    """
    qual_fn = args.qual_fn if args.qual_fn is not None else 'phase_qual'
    vcf_fn = args.vcf_fn
    var_pct_full = args.var_pct_full
    contig_name = args.ctgName
    split_folder = args.split_folder
    variant_dict = defaultdict(str)
    qual_set = defaultdict(int)
    found_qual_cut_off = False
    header = []

    # try to find the global quality cut-off:
    f_qual = os.path.join(split_folder, qual_fn)
    if os.path.exists(f_qual):
        phase_qual_cut_off = float(open(f_qual, 'r').read().rstrip())
        found_qual_cut_off = True

    unzip_process = subprocess_popen(shlex.split("gzip -fdc %s" % (vcf_fn)))
    for row in unzip_process.stdout:
        row = row.rstrip()
        if row[0] == '#':
            header.append(row + '\n')
            continue
        columns = row.strip().split()
        ctg_name = columns[0]
        if contig_name and contig_name != ctg_name:
            continue
        pos = int(columns[1])
        ref_base = columns[3]
        alt_base = columns[4]
        genotype = columns[9].split(':')[0].replace('|', '/')

        if len(ref_base) == 1 and len(alt_base) == 1:
            if genotype == '0/1' or genotype=='1/0':
                variant_dict[pos] = row
                qual = float(columns[5])
                qual_set[pos] = qual

    if found_qual_cut_off:
        remove_low_qual_list = [[k,v] for k,v in qual_set.items() if v < phase_qual_cut_off ]
    else:
        remove_low_qual_list = sorted(qual_set.items(), key=lambda x: x[1])[:int(var_pct_full * len(qual_set))]
    for pos, qual in remove_low_qual_list:
        del variant_dict[pos]

    print('[INFO] Total heterozygous SNP positions selected: {}: {}'.format(contig_name, len(variant_dict)))

    f = open(os.path.join(split_folder, '{}.vcf'.format(contig_name)), 'w')
    f.write(''.join(header))
    for key,row in sorted(variant_dict.items(), key=lambda x: x[0]):
        f.write(row +'\n')
    f.close()
Example #10
File: GetTruth.py Project: pythseq/Clair
def GetBase(chromosome, position, ref_fn):
    fp = subprocess_popen(
        shlex.split("samtools faidx %s %s:%s-%s" %
                    (ref_fn, chromosome, position, position)))
    for line in fp.stdout:
        if line[0] == ">":
            continue
        else:
            return line.strip()
Example #11
def samtools_view_process_from(ctg_name, ctg_start, ctg_end, samtools,
                               bam_file_path):
    have_start_and_end_position = ctg_start != None and ctg_end != None
    region_str = ("%s:%d-%d" % (ctg_name, ctg_start, ctg_end)
                  ) if have_start_and_end_position else ctg_name

    return subprocess_popen(
        shlex.split("%s view -F %d %s %s" %
                    (samtools, param.SAMTOOLS_VIEW_FILTER_FLAG, bam_file_path,
                     region_str)))
Example #12
    def write_output(self):
        if self.compress:
            self.write_fpo = open(self.output_path, 'w')
            self.write_proc = subprocess_popen(shlex.split(LZ4_COMPRESS),
                                               stdin=subprocess.PIPE,
                                               stdout=self.write_fpo,
                                               stderr=subprocess.DEVNULL)
            self.writer = self.write_proc.stdin

        else:
            self.writer = open(self.output_path, 'w')
        return self.writer
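
Examples #8 and #12 appear to be methods of the same I/O helper class, and LZ4_COMPRESS / LZ4_DECOMPRESS are command strings defined elsewhere in the project. A minimal sketch of the surrounding state, with placeholder command values, could be:

import shlex
import subprocess

# Placeholder command strings; the actual values are defined elsewhere in the project.
LZ4_COMPRESS = "lz4 -c"
LZ4_DECOMPRESS = "lz4 -dc"

class CompressedIO(object):
    # Hypothetical owner of the read_input/write_output methods shown above.
    def __init__(self, input_path=None, output_path=None, compress=False):
        self.input_path = input_path
        self.output_path = output_path
        self.compress = compress
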
Example #13
def split_extend_bed(args):

    """
    Split bed file regions according to the contig name and extend bed region with no_of_positions =
    flankingBaseNum + 1 + flankingBaseNum, which allow samtools mpileup submodule to scan the flanking windows.
    """

    bed_fn = args.bed_fn
    output_fn = args.output_fn
    contig_name = args.ctgName
    region_start = args.ctgStart
    region_end = args.ctgEnd
    expand_region_size = args.expand_region_size
    if bed_fn is None:
        return
    output = []
    unzip_process = subprocess_popen(shlex.split("gzip -fdc %s" % (bed_fn)))
    pre_end, pre_start = -1, -1

    for row in unzip_process.stdout:

        if row[0] == '#':
            continue
        columns = row.strip().split()
        ctg_name = columns[0]
        if contig_name != None and ctg_name != contig_name:
            continue
        ctg_start, ctg_end = int(columns[1]), int(columns[2])
        if region_start and ctg_end < region_start:
            continue
        if region_end and ctg_start > region_end:
            break
        if pre_start == -1:
            pre_start = ctg_start - expand_region_size
            pre_end = ctg_end + expand_region_size
            continue
        if pre_end >= ctg_start - expand_region_size:
            pre_end = ctg_end + expand_region_size
            continue
        else:
            output.append(' '.join([contig_name, str(pre_start), str(pre_end)]))
            pre_start = ctg_start - expand_region_size
            pre_end = ctg_end + expand_region_size

    with open(output_fn, 'w') as output_file:
        output_file.write('\n'.join(output))

    unzip_process.stdout.close()
    unzip_process.wait()
Example #14
def bed_tree_from(bed_file_path, expand_region=None, contig_name=None, bed_ctg_start=None, bed_ctg_end=None,
                  return_bed_region=False, padding=None):
    """
    0-based interval tree [start, end)
    """

    tree = {}
    if bed_file_path is None:
        if return_bed_region:
            return tree, None, None
        return tree

    bed_start, bed_end = float('inf'), 0
    unzip_process = subprocess_popen(shlex.split("gzip -fdc %s" % (bed_file_path)))
    for row_id, row in enumerate(unzip_process.stdout):
        if row[0] == '#':
            continue
        columns = row.strip().split()

        ctg_name = columns[0]
        if contig_name != None and ctg_name != contig_name:
            continue
        if ctg_name not in tree:
            tree[ctg_name] = IntervalTree()

        ctg_start, ctg_end = int(columns[1]), int(columns[2])

        if ctg_end < ctg_start or ctg_start < 0 or ctg_end < 0:
            sys.exit("[ERROR] Invalid bed input in {}-th row {} {} {}".format(row_id+1, ctg_name, ctg_start, ctg_end))

        if bed_ctg_start and bed_ctg_end:
            if ctg_end < bed_ctg_start or ctg_start > bed_ctg_end:
                continue
        if padding:
            ctg_start += padding
            ctg_end -= padding
        bed_start = min(ctg_start, bed_start)
        bed_end = max(ctg_end, bed_end)
        if ctg_start == ctg_end:
            ctg_end += 1

        tree[ctg_name].addi(ctg_start, ctg_end)

    unzip_process.stdout.close()
    unzip_process.wait()
    if return_bed_region:
        return tree, bed_start, bed_end
    return tree
Example #15
def reference_result_from(
    ctg_name,
    ctg_start,
    ctg_end,
    samtools,
    reference_file_path,
    expand_reference_region
):
    region_str = ""
    reference_start, reference_end = None, None
    have_start_and_end_positions = ctg_start != None and ctg_end != None
    if have_start_and_end_positions:
        reference_start, reference_end = ctg_start - expand_reference_region, ctg_end + expand_reference_region
        reference_start = 1 if reference_start < 1 else reference_start
        region_str = "%s:%d-%d" % (ctg_name, reference_start, reference_end)
    else:
        region_str = ctg_name

    faidx_process = subprocess_popen(shlex.split("%s faidx %s %s" % (samtools, reference_file_path, region_str)),)
    if faidx_process is None:
        return None

    reference_name = None
    reference_sequences = []
    for row in faidx_process.stdout:
        if reference_name is None:
            reference_name = row.rstrip().lstrip(">") or ""
        else:
            reference_sequences.append(row.rstrip())
    reference_sequence = "".join(reference_sequences)

    # uppercase for masked sequences
    reference_sequence = reference_sequence.upper()

    faidx_process.stdout.close()
    faidx_process.wait()

    return ReferenceResult(
        name=reference_name,
        start=reference_start,
        end=reference_end,
        sequence=reference_sequence,
        is_faidx_process_have_error=faidx_process.returncode != 0,
    )
Example #16
File: utils.py Project: pythseq/Clair
def variant_map_from(var_fn, tree, is_tree_empty):
    Y = {}
    if var_fn is None:
        return Y

    f = subprocess_popen(shlex.split("gzip -fdc %s" % (var_fn)))
    for row in f.stdout:
        columns = row.split()
        ctg_name, position_str = columns[0], columns[1]

        if not (is_tree_empty or is_region_in(tree, ctg_name, int(position_str))):
            continue

        key = ctg_name + ":" + position_str
        Y[key] = output_labels_from_vcf_columns(columns)

    f.stdout.close()
    f.wait()
    return Y
Example #17
File: utils.py Project: Yufeng98/Clair
def tensor_generator_from(tensor_file_path, batch_size):
    if tensor_file_path != "PIPE":
        f = subprocess_popen(shlex.split("gzip -fdc %s" % (tensor_file_path)))
        fo = f.stdout
    else:
        fo = sys.stdin

    processed_tensors = 0

    def item_from(row):
        # print(row)
        columns = row.split()
        return columns[:-input_tensor_size], np.array(
            columns[-input_tensor_size:], dtype=np.float32)

    for batch in batches_from(fo, item_from=item_from, batch_size=batch_size):
        # tmp_time = time()
        tensors = np.empty((batch_size, input_tensor_size), dtype=np.float32)
        non_tensor_infos = []
        for non_tensor_info, tensor in batch:
            _, _, sequence = non_tensor_info
            if sequence[param.flankingBaseNum] not in BASE2NUM:
                continue
            tensors[len(non_tensor_infos)] = tensor
            non_tensor_infos.append(non_tensor_info)

        current_batch_size = len(non_tensor_infos)
        X = np.reshape(tensors,
                       (batch_size, no_of_positions, matrix_row, matrix_num))
        for i in range(1, matrix_num):
            X[:current_batch_size, :, :, i] -= X[:current_batch_size, :, :, 0]

        processed_tensors += current_batch_size
        # print("Processed %d tensors takes %.4f s" % (processed_tensors, time() - tmp_time), file=sys.stderr)
        print("Processed %d tensors" % processed_tensors, file=sys.stderr)
        if current_batch_size <= 0:
            continue
        yield X[:current_batch_size], non_tensor_infos[:current_batch_size]

    if tensor_file_path != "PIPE":
        fo.close()
        f.wait()
Example #18
def candidate_position_generator_from(candidate_file_path, ctg_start, ctg_end,
                                      is_consider_left_edge, flanking_base_num,
                                      begin_to_end):
    is_read_file_from_standard_input = candidate_file_path == "PIPE"
    if is_read_file_from_standard_input:
        candidate_file_path_output = sys.stdin
    else:
        candidate_file_path_process = subprocess_popen(
            shlex.split("gzip -fdc %s" % (candidate_file_path)))
        candidate_file_path_output = candidate_file_path_process.stdout

    is_ctg_region_provided = ctg_start is not None and ctg_end is not None

    for row in candidate_file_path_output:
        row = row.split()
        position = int(row[1])  # 1-based position

        if is_ctg_region_provided and not (ctg_start <= position <= ctg_end):
            continue

        if is_consider_left_edge:
            # i is 0-based
            for i in range(position - (flanking_base_num + 1),
                           position + (flanking_base_num + 1)):
                if i not in begin_to_end:
                    begin_to_end[i] = [(position + (flanking_base_num + 1),
                                        position)]
                else:
                    begin_to_end[i].append(
                        (position + (flanking_base_num + 1), position))
        else:
            begin_to_end[position - (flanking_base_num + 1)] = [
                (position + (flanking_base_num + 1), position)
            ]

        yield position

    if not is_read_file_from_standard_input:
        candidate_file_path_output.close()
        candidate_file_path_process.wait()
    yield -1
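
A possible driver loop for the generator above; the candidate file path and the flanking_base_num value are illustrative assumptions:

# Hypothetical driver: collect 1-based candidate positions from a gzipped file
# while the generator fills begin_to_end with its window bookkeeping.
begin_to_end = {}
positions = []
for position in candidate_position_generator_from(
        candidate_file_path="candidates.txt.gz", ctg_start=None, ctg_end=None,
        is_consider_left_edge=True, flanking_base_num=16,
        begin_to_end=begin_to_end):
    if position == -1:  # sentinel yielded after the input is exhausted
        break
    positions.append(position)
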
Example #19
def Run(args):
    tree = bed_tree_from(bed_file_path=args.bed_fn)

    logging.info("Counting the number of Truth Variants in %s ..." % args.tensor_var_fn)
    v = 0
    d = {}
    f = subprocess_popen(shlex.split("gzip -fdc %s" % (args.tensor_var_fn)))
    for row in f.stdout:
        row = row.strip().split()
        ctgName = row[0]
        pos = int(row[1])
        key = "-".join([ctgName, str(pos)])
        v += 1
        d[key] = 1
    f.stdout.close()
    f.wait()

    logging.info("%d Truth Variants" % v)
    t = v * args.amp
    logging.info("%d non-variants to be picked" % t)

    logging.info("Counting the number of usable non-variants in %s ..." % args.tensor_can_fn)
    c = 0
    f = subprocess_popen(shlex.split("gzip -fdc %s" % (args.tensor_can_fn)))
    for row in f.stdout:
        row = row.strip().split()
        ctgName = row[0]
        pos = int(row[1])
        if args.bed_fn != None:
            if not is_region_in(tree, ctgName, pos):
                continue
        key = "-".join([ctgName, str(pos)])
        if key in d:
            continue
        c += 1
    f.stdout.close()
    f.wait()
    logging.info("%d usable non-variant" % c)

    r = float(t) / c
    r = r if r <= 1 else 1
    logging.info("%.2f of all non-variants are selected" % r)

    o1 = 0
    o2 = 0
    output_fpo = open(args.output_fn, "wb")
    output_fh = subprocess_popen(shlex.split("gzip -c"), stdin=PIPE, stdout=output_fpo)
    f = subprocess_popen(shlex.split("gzip -fdc %s" % (args.tensor_var_fn)))
    for row in f.stdout:
        row = row.strip()
        output_fh.stdin.write(row)
        output_fh.stdin.write("\n")
        o1 += 1
    f.stdout.close()
    f.wait()
    f = subprocess_popen(shlex.split("gzip -fdc %s" % (args.tensor_can_fn)))
    for row in f.stdout:
        rawRow = row.strip()
        row = rawRow.split()
        ctgName = row[0]
        pos = int(row[1])
        if args.bed_fn != None:
            if not is_region_in(tree, ctgName, pos):
                continue
        key = "-".join([ctgName, str(pos)])
        if key in d:
            continue
        if random() < r:
            output_fh.stdin.write(rawRow)
            output_fh.stdin.write("\n")
            o2 += 1
    f.stdout.close()
    f.wait()
    output_fh.stdin.close()
    output_fh.wait()
    output_fpo.close()
    logging.info("%.2f/%.2f Truth Variants/Non-variants outputed" % (o1, o2))
Example #20
File: MergeVcf.py Project: HKU-BAL/Clair3
def MergeVcf_illumina(args):
    # region-based VCF merge for Illumina, as read realignment can shift candidate variants or cause them to be missed.
    bed_fn_prefix = args.bed_fn_prefix
    output_fn = args.output_fn
    full_alignment_vcf_fn = args.full_alignment_vcf_fn
    pileup_vcf_fn = args.pileup_vcf_fn  # true vcf var
    contig_name = args.ctgName
    QUAL = args.qual
    bed_fn = None
    if not os.path.exists(bed_fn_prefix):
        exit(
            log_error("[ERROR] Input directory: {} does not exist!".format(
                bed_fn_prefix)))

    all_files = os.listdir(bed_fn_prefix)
    all_files = [
        item for item in all_files if item.startswith(contig_name + '.')
    ]
    if len(all_files) != 0:
        bed_fn = os.path.join(bed_fn_prefix,
                              "full_aln_regions_{}".format(contig_name))
        with open(bed_fn, 'w') as output_file:
            for file in all_files:
                with open(os.path.join(bed_fn_prefix, file)) as f:
                    output_file.write(f.read())

    is_haploid_precise_mode_enabled = args.haploid_precise
    is_haploid_sensitive_mode_enabled = args.haploid_sensitive
    print_ref = args.print_ref_calls

    tree = bed_tree_from(bed_file_path=bed_fn,
                         padding=param.no_of_positions,
                         contig_name=contig_name)
    unzip_process = subprocess_popen(
        shlex.split("gzip -fdc %s" % (pileup_vcf_fn)))
    output_dict = {}
    header = []
    pileup_count = 0
    for row in unzip_process.stdout:
        if row[0] == '#':
            header.append(row)
            continue
        columns = row.strip().split()
        ctg_name = columns[0]
        if contig_name != None and ctg_name != contig_name:
            continue
        pos = int(columns[1])
        qual = float(columns[5])
        pass_bed = is_region_in(tree, ctg_name, pos)
        ref_base, alt_base = columns[3], columns[4]
        is_reference = (alt_base == "." or ref_base == alt_base)
        if is_haploid_precise_mode_enabled:
            row = update_haploid_precise_genotype(columns)
        if is_haploid_sensitive_mode_enabled:
            row = update_haploid_sensitive_genotype(columns)

        if not pass_bed:
            if not is_reference:
                row = MarkLowQual(row, QUAL, qual)
                output_dict[pos] = row
                pileup_count += 1
            elif print_ref:
                output_dict[pos] = row
                pileup_count += 1

    unzip_process.stdout.close()
    unzip_process.wait()

    realigned_vcf_unzip_process = subprocess_popen(
        shlex.split("gzip -fdc %s" % (full_alignment_vcf_fn)))
    realiged_read_num = 0
    for row in realigned_vcf_unzip_process.stdout:
        if row[0] == '#':
            continue
        columns = row.strip().split()
        ctg_name = columns[0]
        if contig_name != None and ctg_name != contig_name:
            continue

        pos = int(columns[1])
        qual = float(columns[5])
        ref_base, alt_base = columns[3], columns[4]
        is_reference = (alt_base == "." or ref_base == alt_base)

        if is_haploid_precise_mode_enabled:
            row = update_haploid_precise_genotype(columns)
        if is_haploid_sensitive_mode_enabled:
            row = update_haploid_sensitive_genotype(columns)

        if is_region_in(tree, ctg_name, pos):
            if not is_reference:
                row = MarkLowQual(row, QUAL, qual)
                output_dict[pos] = row
                realiged_read_num += 1
            elif print_ref:
                output_dict[pos] = row
                realiged_read_num += 1

    logging.info('[INFO] Pileup variants processed in {}: {}'.format(
        contig_name, pileup_count))
    logging.info(
        '[INFO] Realigned variants processed in {}: {}'.format(
            contig_name, realiged_read_num))
    realigned_vcf_unzip_process.stdout.close()
    realigned_vcf_unzip_process.wait()

    with open(output_fn, 'w') as output_file:
        output_list = header + [
            output_dict[pos] for pos in sorted(output_dict.keys())
        ]
        output_file.write(''.join(output_list))
Example #21
File: MergeVcf.py Project: HKU-BAL/Clair3
def MergeVcf(args):
    """
    Merge pileup and full alignment vcf output. We merge the low quality score pileup candidates
    recalled by full-alignment model with high quality score pileup output.
    """

    output_fn = args.output_fn
    full_alignment_vcf_fn = args.full_alignment_vcf_fn
    pileup_vcf_fn = args.pileup_vcf_fn  # true vcf var
    contig_name = args.ctgName
    QUAL = args.qual
    is_haploid_precise_mode_enabled = args.haploid_precise
    is_haploid_sensitive_mode_enabled = args.haploid_sensitive
    print_ref = args.print_ref_calls
    full_alignment_vcf_unzip_process = subprocess_popen(
        shlex.split("gzip -fdc %s" % (full_alignment_vcf_fn)))

    full_alignment_output = []
    full_alignment_output_set = set()
    header = []

    for row in full_alignment_vcf_unzip_process.stdout:
        if row[0] == '#':
            header.append(row)
            continue
        columns = row.strip().split()
        ctg_name = columns[0]
        if contig_name != None and ctg_name != contig_name:
            continue
        pos = int(columns[1])
        qual = float(columns[5])
        ref_base, alt_base = columns[3], columns[4]
        is_reference = (alt_base == "." or ref_base == alt_base)

        full_alignment_output_set.add((ctg_name, pos))

        if is_haploid_precise_mode_enabled:
            row = update_haploid_precise_genotype(columns)
        if is_haploid_sensitive_mode_enabled:
            row = update_haploid_sensitive_genotype(columns)

        if not is_reference:
            row = MarkLowQual(row, QUAL, qual)
            full_alignment_output.append((pos, row))

        elif print_ref:
            full_alignment_output.append((pos, row))

    full_alignment_vcf_unzip_process.stdout.close()
    full_alignment_vcf_unzip_process.wait()

    pileup_vcf_unzip_process = subprocess_popen(
        shlex.split("gzip -fdc %s" % (pileup_vcf_fn)))

    output_file = open(output_fn, 'w')
    output_file.write(''.join(header))

    def pileup_vcf_generator_from(pileup_vcf_unzip_process):
        pileup_row_count = 0
        for row in pileup_vcf_unzip_process.stdout:
            if row[0] == '#':
                continue

            columns = row.rstrip().split('\t')
            ctg_name = columns[0]
            if contig_name and contig_name != ctg_name:
                continue
            pos = int(columns[1])
            qual = float(columns[5])
            ref_base, alt_base = columns[3], columns[4]
            is_reference = (alt_base == "." or ref_base == alt_base)

            if (ctg_name, pos) in full_alignment_output_set:
                continue

            if is_haploid_precise_mode_enabled:
                row = update_haploid_precise_genotype(columns)
            if is_haploid_sensitive_mode_enabled:
                row = update_haploid_sensitive_genotype(columns)

            if not is_reference:
                row = MarkLowQual(row, QUAL, qual)
                pileup_row_count += 1
                yield (pos, row)
            elif print_ref:
                pileup_row_count += 1
                yield (pos, row)

        logging.info('[INFO] Pileup variants processed in {}: {}'.format(
            contig_name, pileup_row_count))

    pileup_vcf_generator = pileup_vcf_generator_from(
        pileup_vcf_unzip_process=pileup_vcf_unzip_process)
    full_alignment_vcf_generator = iter(full_alignment_output)
    for vcf_infos in heapq.merge(full_alignment_vcf_generator,
                                 pileup_vcf_generator):
        if len(vcf_infos) != 2:
            continue
        pos, row = vcf_infos
        output_file.write(row)

    logging.info('[INFO] Full-alignment variants processed in {}: {}'.format(
        contig_name, len(full_alignment_output)))

    pileup_vcf_unzip_process.stdout.close()
    pileup_vcf_unzip_process.wait()
    output_file.close()
Example #22
def reads_realignment(args):
    bed_file_path = args.full_aln_regions
    extend_bed = args.extend_bed
    fasta_file_path = args.ref_fn
    ctg_name = args.ctgName
    ctg_start = args.ctgStart
    ctg_end = args.ctgEnd
    chunk_id = args.chunk_id - 1 if args.chunk_id else None  # 1-base to 0-base
    chunk_num = args.chunk_num
    samtools_execute_command = args.samtools
    bam_file_path = args.bam_fn
    minMQ = args.minMQ
    min_coverage = args.minCoverage
    is_bed_file_given = bed_file_path is not None
    is_ctg_name_given = ctg_name is not None
    read_fn = args.read_fn

    global test_pos
    test_pos = None
    if is_bed_file_given:
        candidate_file_path_process = subprocess_popen(
            shlex.split("gzip -fdc %s" % (bed_file_path)))
        candidate_file_path_output = candidate_file_path_process.stdout

        ctg_start, ctg_end = float('inf'), 0
        for row in candidate_file_path_output:
            row = row.rstrip().split('\t')
            if row[0] != ctg_name: continue
            position = int(row[1]) + 1
            end = int(row[2]) + 1
            ctg_start = min(position, ctg_start)
            ctg_end = max(end, ctg_end)

        candidate_file_path_output.close()
        candidate_file_path_process.wait()

    if chunk_id is not None:
        fai_fn = file_path_from(fasta_file_path,
                                suffix=".fai",
                                exit_on_not_found=True,
                                sep='.')
        contig_length = 0
        with open(fai_fn, 'r') as fai_fp:
            for row in fai_fp:
                columns = row.strip().split("\t")

                contig_name = columns[0]
                if contig_name != ctg_name:
                    continue
                contig_length = int(columns[1])
        chunk_size = contig_length // chunk_num + 1 if contig_length % chunk_num else contig_length // chunk_num
        ctg_start = chunk_size * chunk_id  # 0-base to 1-base
        ctg_end = ctg_start + chunk_size

    is_ctg_range_given = is_ctg_name_given and ctg_start is not None and ctg_end is not None

    # 1-based regions [start, end] (start and end inclusive)
    ref_regions = []
    reads_regions = []
    reference_start, reference_end = None, None

    if is_ctg_range_given:
        extend_start = ctg_start - max_window_size
        extend_end = ctg_end + max_window_size
        reads_regions.append(
            region_from(ctg_name=ctg_name,
                        ctg_start=extend_start,
                        ctg_end=extend_end))
        reference_start, reference_end = ctg_start - param.expandReferenceRegion, ctg_end + param.expandReferenceRegion
        reference_start = 1 if reference_start < 1 else reference_start
        ref_regions.append(
            region_from(ctg_name=ctg_name,
                        ctg_start=reference_start,
                        ctg_end=reference_end))
    elif is_ctg_name_given:
        reads_regions.append(region_from(ctg_name=ctg_name))
        ref_regions.append(region_from(ctg_name=ctg_name))
        reference_start = 1

    reference_sequence = reference_sequence_from(
        samtools_execute_command=samtools_execute_command,
        fasta_file_path=fasta_file_path,
        regions=ref_regions)
    if reference_sequence is None or len(reference_sequence) == 0:
        sys.exit(
            "[ERROR] Failed to load reference sequence from file ({}).".format(
                fasta_file_path))

    tree = bed_tree_from(bed_file_path=bed_file_path)
    if is_bed_file_given and ctg_name not in tree:
        sys.exit("[ERROR] ctg_name({}) not exists in bed file({}).".format(
            ctg_name, bed_file_path))

    bed_option = ' -L {}'.format(extend_bed) if extend_bed else ""
    bed_option = ' -L {}'.format(
        bed_file_path) if is_bed_file_given else bed_option
    mq_option = ' -q {}'.format(minMQ) if minMQ > 0 else ""
    samtools_view_command = "{} view -h {} {}".format(
        samtools_execute_command, bam_file_path,
        " ".join(reads_regions)) + mq_option + bed_option
    samtools_view_process = subprocess_popen(
        shlex.split(samtools_view_command))

    if read_fn and read_fn == 'PIPE':
        save_file_fp = TensorStdout(sys.stdout)
    elif read_fn:
        save_file_fp = subprocess_popen(shlex.split(
            "{} view -bh - -o {}".format(
                samtools_execute_command,
                read_fn + ('.{}_{}'.format(ctg_start, ctg_end)
                           if is_ctg_range_given and not test_pos else ""))),
                                        stdin=PIPE,
                                        stdout=PIPE)

    reference_start_0_based = 0 if reference_start is None else (
        reference_start - 1)

    header = []
    add_header = False
    aligned_reads = defaultdict()
    pileup = defaultdict(lambda: {"X": 0})
    samtools_view_generator = samtools_view_generator_from(
        samtools_view_process=samtools_view_process,
        aligned_reads=aligned_reads,
        pileup=pileup,
        ctg_name=ctg_name,
        reference_sequence=reference_sequence,
        reference_start_0_based=reference_start_0_based,
        header=header)
    pre_aligned_reads = defaultdict()

    while True:
        chunk_start, chunk_end = next(samtools_view_generator)
        if chunk_start is None:
            break
        if not add_header:
            save_file_fp.stdin.write(''.join(header))
            add_header = True

        variant_allele_list = [[position, pileup[position]["X"]]
                               for position in list(pileup.keys())]
        candidate_position_list = [
            (position, support_allele_count)
            for position, support_allele_count in variant_allele_list
            if support_allele_count >= min_coverage
            and position >= chunk_start - region_expansion_in_bp -
            1 and position <= chunk_end + region_expansion_in_bp - 1
        ]
        candidate_position_list.sort(key=(lambda x: x[0]))

        if not len(aligned_reads) or not len(candidate_position_list):
            continue
        if len(pre_aligned_reads):  # update the read in previous chunk
            for read_name, read in pre_aligned_reads.items():
                aligned_reads[read_name] = read

        region_dict = {}
        split_region_size = max_window_size
        region_tree = IntervalTree()
        for split_idx in range((chunk_end - chunk_start) // split_region_size):
            split_start = chunk_start + split_idx * split_region_size - region_expansion_in_bp - 1
            split_end = split_start + split_region_size + region_expansion_in_bp * 2 + 1
            region_dict[(split_start, split_end)] = []
            region_tree.addi(split_start, split_end)
        for candidate_position in candidate_position_list:
            for region in region_tree.at(candidate_position[0]):
                region_dict[(region.begin,
                             region.end)].append(candidate_position[0])

        for key, split_candidate_position_list in region_dict.items():
            start_pos, end_pos = None, None
            windows = []
            read_windows_dict = {}
            for pos in split_candidate_position_list:
                if start_pos is None:
                    start_pos = pos
                    end_pos = pos

                elif pos > end_pos + 2 * min_windows_distance:
                    temp_window = (start_pos - min_windows_distance,
                                   end_pos + min_windows_distance)
                    windows.append(temp_window)
                    read_windows_dict[temp_window] = []

                    start_pos = pos
                    end_pos = pos
                else:
                    end_pos = pos

            if start_pos is not None:
                temp_window = (start_pos - min_windows_distance,
                               end_pos + min_windows_distance)
                windows.append(temp_window)
                read_windows_dict[temp_window] = []
            if not len(windows): continue
            windows = sorted(windows, key=lambda x: x[0])
            max_window_end = max([item[1] for item in windows])
            # find read/window overlap pairs
            for read_name, read in aligned_reads.items():
                if read.read_start > max_window_end: continue
                argmax_window_idx = find_max_overlap_index(
                    (read.read_start, read.read_end), windows)
                if argmax_window_idx is not None:
                    read_windows_dict[windows[argmax_window_idx]].append(
                        read_name)

            # realignment
            for window in windows:
                start_pos, end_pos = window
                if end_pos - start_pos > max_window_size:  # or (window not in need_align_windows_set):
                    continue

                ref_start = start_pos - reference_start_0_based
                ref_end = end_pos - reference_start_0_based
                ref = reference_sequence[ref_start:ref_end]
                reads = []
                low_base_quality_pos_list = []
                # pypy binding with ctypes for DBG building
                for read_name in read_windows_dict[window]:
                    read = aligned_reads[read_name]
                    if (
                            not read.graph_mq
                    ) or read.read_start > end_pos or read.read_end < start_pos:
                        continue
                    reads.append(read.seq)
                    low_base_quality_pos_list.append(' '.join([
                        str(bq_idx)
                        for bq_idx, item in enumerate(read.base_quality)
                        if int(item) < 15
                    ]))
                totoal_read_num = len(reads)
                c_ref = byte(ref)
                read_list1 = ctypes.c_char_p(byte(','.join(reads)))
                low_base_quality_pos_array = ctypes.c_char_p(
                    byte(','.join(low_base_quality_pos_list)))

                dbg.get_consensus.restype = ctypes.POINTER(DBGPointer)
                dbg.get_consensus.argtypes = [
                    ctypes.c_char_p, ctypes.c_char_p, ctypes.c_char_p,
                    ctypes.c_int
                ]

                dbg_p = dbg.get_consensus(ctypes.c_char_p(c_ref), read_list1,
                                          low_base_quality_pos_array,
                                          totoal_read_num)

                c_consensus, consensus_size = dbg_p.contents.consensus, dbg_p.contents.consensus_size
                consensus = [
                    item.decode() for item in c_consensus[:consensus_size]
                ]

                if len(consensus) == 0 or len(
                        consensus) == 1 and consensus[0] == ref or len(
                            read_windows_dict[window]) == 0:
                    continue
                min_read_start = min([
                    aligned_reads[item].read_start
                    for item in read_windows_dict[window]
                ])
                max_read_end = max([
                    aligned_reads[item].read_end
                    for item in read_windows_dict[window]
                ])
                tmp_ref_start = max(
                    0,
                    min(min_read_start, start_pos) - expand_align_ref_region)
                tmp_ref_end = max(max_read_end,
                                  end_pos) + expand_align_ref_region

                ref_prefix = get_reference_seq(reference_sequence,
                                               tmp_ref_start, start_pos,
                                               reference_start_0_based)
                ref_center = get_reference_seq(reference_sequence, start_pos,
                                               end_pos,
                                               reference_start_0_based)
                if tmp_ref_end < end_pos:
                    continue
                ref_suffix = get_reference_seq(reference_sequence, end_pos,
                                               tmp_ref_end,
                                               reference_start_0_based)
                ref_seq = ref_prefix + ref_center + ref_suffix

                # pypy binding with ctypes for realignment
                read_name_list = []
                totoal_read_num = min(max_region_reads_num,
                                      len(read_windows_dict[window]))
                seq_list = (ctypes.c_char_p * totoal_read_num)()
                position_list = (ctypes.c_int * totoal_read_num)()
                cigars_list = (ctypes.c_char_p * totoal_read_num)()

                for read_idx, read_name in enumerate(
                        read_windows_dict[window]):
                    read = aligned_reads[read_name]
                    if read_idx >= totoal_read_num: break
                    seq_list[read_idx] = byte(read.seq.upper())
                    position_list[read_idx] = read.read_start
                    cigars_list[read_idx] = byte(read.cigar)
                    read_name_list.append(read_name)
                haplotypes_list = [
                    ref_prefix + cons + ref_suffix for cons in consensus
                ]
                haplotypes = ' '.join(haplotypes_list)

                realigner.realign_reads.restype = ctypes.POINTER(StructPointer)
                realigner.realign_reads.argtypes = [
                    ctypes.c_char_p * totoal_read_num,
                    ctypes.c_int * totoal_read_num,
                    ctypes.c_char_p * totoal_read_num, ctypes.c_char_p,
                    ctypes.c_char_p, ctypes.c_int, ctypes.c_int, ctypes.c_int,
                    ctypes.c_int
                ]

                realigner_p = realigner.realign_reads(
                    seq_list, position_list, cigars_list,
                    ctypes.c_char_p(byte(ref_seq)),
                    ctypes.c_char_p(byte(haplotypes)), tmp_ref_start,
                    len(ref_prefix), len(ref_suffix), totoal_read_num)

                realign_positions, realign_cigars = realigner_p.contents.position, realigner_p.contents.cigar_string
                read_position_list = realign_positions[:totoal_read_num]
                read_cigar_list = [
                    item.decode() for item in realign_cigars[:totoal_read_num]
                ]

                if len(read_name_list):
                    for read_id, read_name in enumerate(read_name_list):
                        if read_cigar_list[read_id] == "" or (
                                aligned_reads[read_name].cigar
                                == read_cigar_list[read_id]
                                and aligned_reads[read_name].read_start
                                == read_position_list[read_id]):
                            continue
                        # update cigar and read start position
                        aligned_reads[read_name].test_pos = test_pos
                        realignment_start = read_position_list[read_id]
                        realignment_cigar = read_cigar_list[read_id].replace(
                            'X', 'M')
                        if realignment_cigar == aligned_reads[
                                read_name].cigar and realignment_start == aligned_reads[
                                    read_name].read_start:
                            continue
                        aligned_reads[read_name].set_realignment_info(
                            split_start, read_cigar_list[read_id],
                            read_position_list[read_id])

                realigner.free_memory.restype = ctypes.POINTER(ctypes.c_void_p)
                realigner.free_memory.argtypes = [
                    ctypes.POINTER(StructPointer), ctypes.c_int
                ]
                realigner.free_memory(realigner_p, totoal_read_num)
        # # realignment end

        if read_fn:
            sorted_key = sorted([(key, item.best_pos)
                                 for key, item in aligned_reads.items()],
                                key=lambda x: x[1])
            for read_name, read_start in sorted_key:
                read = aligned_reads[read_name]
                if read_start < chunk_start - region_expansion_in_bp - max_window_size:  # safe distance for save reads
                    phasing_info = 'HP:i:{}'.format(
                        read.phasing) if read.phasing else ""
                    pass
                    read_str = '\t'.join([
                        read_name, read.flag, ctg_name,
                        str(read_start + 1),
                        str(read.mapping_quality), read.best_cigar, read.RNEXT,
                        read.PNEXT, read.TLEN, read.seq, read.raw_base_quality,
                        phasing_info
                    ])
                    save_file_fp.stdin.write(read_str + '\n')
                    del aligned_reads[read_name]
                for pile_pos in list(pileup.keys()):
                    if pile_pos < chunk_start - region_expansion_in_bp - max_window_size:
                        del pileup[pile_pos]

    if read_fn and aligned_reads:
        sorted_key = sorted([(key, item.best_pos)
                             for key, item in aligned_reads.items()],
                            key=lambda x: x[1])
        for read_name, read_start in sorted_key:
            read = aligned_reads[read_name]
            phasing_info = 'HP:i:{}'.format(
                read.phasing) if read.phasing else ""
            read_str = '\t'.join([
                read_name, read.flag, ctg_name,
                str(read_start + 1),
                str(read.mapping_quality), read.best_cigar, read.RNEXT,
                read.PNEXT, read.TLEN, read.seq, read.raw_base_quality,
                phasing_info
            ])
            save_file_fp.stdin.write(read_str + '\n')
            del aligned_reads[read_name]
        if read_fn != 'PIPE':
            save_file_fp.stdin.close()
            save_file_fp.wait()
    samtools_view_process.stdout.close()
    samtools_view_process.wait()

    if test_pos:
        save_file_fp = subprocess_popen(shlex.split("samtools index {}".format(
            read_fn + ('.{}_{}'.format(ctg_start, ctg_end)
                       if is_ctg_range_given and not test_pos else ""))),
                                        stdin=PIPE,
                                        stdout=PIPE)
        save_file_fp.stdin.close()
        save_file_fp.wait()
Example #23
def CreateTensorPileup(args):
    """
    Create pileup tensor for pileup model training or calling.
    Use slide window to scan the whole candidate regions, keep all candidates over specific minimum allelic frequency
    and minimum depth, use samtools mpileup to store pileup info for pileup tensor generation. Only scan candidate
    regions once, we could directly get all variant candidates directly.
    """
    ctg_start = args.ctgStart
    ctg_end = args.ctgEnd
    fasta_file_path = args.ref_fn
    ctg_name = args.ctgName
    samtools_execute_command = args.samtools
    bam_file_path = args.bam_fn
    chunk_id = args.chunk_id - 1 if args.chunk_id else None  # 1-base to 0-base
    chunk_num = args.chunk_num
    tensor_can_output_path = args.tensor_can_fn
    minimum_af_for_candidate = args.min_af
    minimum_snp_af_for_candidate = args.snp_min_af
    minimum_indel_af_for_candidate = args.indel_min_af
    min_coverage = args.minCoverage
    platform = args.platform
    confident_bed_fn = args.bed_fn
    is_confident_bed_file_given = confident_bed_fn is not None
    alt_fn = args.indel_fn
    extend_bed = args.extend_bed
    is_extend_bed_file_given = extend_bed is not None
    min_mapping_quality = args.minMQ
    min_base_quality = args.minBQ
    fast_mode = args.fast_mode
    vcf_fn = args.vcf_fn
    is_known_vcf_file_provided = vcf_fn is not None
    call_snp_only = args.call_snp_only

    global test_pos
    test_pos = None

    # 1-based regions [start, end] (start and end inclusive)
    ref_regions = []
    reads_regions = []
    known_variants_set = set()
    tree, bed_start, bed_end = bed_tree_from(bed_file_path=extend_bed,
                                             contig_name=ctg_name,
                                             return_bed_region=True)

    fai_fn = file_path_from(fasta_file_path,
                            suffix=".fai",
                            exit_on_not_found=True,
                            sep='.')
    if not is_confident_bed_file_given and chunk_id is not None:
        contig_length = 0
        with open(fai_fn, 'r') as fai_fp:
            for row in fai_fp:
                columns = row.strip().split("\t")

                contig_name = columns[0]
                if contig_name != ctg_name:
                    continue
                contig_length = int(columns[1])
        chunk_size = contig_length // chunk_num + 1 if contig_length % chunk_num else contig_length // chunk_num
        ctg_start = chunk_size * chunk_id  # 0-base to 1-base
        ctg_end = ctg_start + chunk_size

    if is_confident_bed_file_given and chunk_id is not None:
        chunk_size = (bed_end - bed_start) // chunk_num + 1 if (
            bed_end - bed_start) % chunk_num else (bed_end -
                                                   bed_start) // chunk_num
        ctg_start = bed_start + 1 + chunk_size * chunk_id  # 0-base to 1-base
        ctg_end = ctg_start + chunk_size

    if is_known_vcf_file_provided and chunk_id is not None:
        known_variants_list = vcf_candidates_from(vcf_fn=vcf_fn,
                                                  contig_name=ctg_name)
        total_variants_size = len(known_variants_list)
        chunk_variants_size = total_variants_size // chunk_num if total_variants_size % chunk_num == 0 else total_variants_size // chunk_num + 1
        chunk_start_pos = chunk_id * chunk_variants_size
        known_variants_set = set(
            known_variants_list[chunk_start_pos:chunk_start_pos +
                                chunk_variants_size])
        if len(known_variants_set) == 0:
            return
        ctg_start, ctg_end = min(known_variants_set), max(known_variants_set)

    is_ctg_name_given = ctg_name is not None
    is_ctg_range_given = is_ctg_name_given and ctg_start is not None and ctg_end is not None
    if is_ctg_range_given:
        extend_start = ctg_start - no_of_positions
        extend_end = ctg_end + no_of_positions
        reads_regions.append(
            region_from(ctg_name=ctg_name,
                        ctg_start=extend_start,
                        ctg_end=extend_end))
        reference_start, reference_end = ctg_start - param.expandReferenceRegion, ctg_end + param.expandReferenceRegion
        reference_start = 1 if reference_start < 1 else reference_start
        ref_regions.append(
            region_from(ctg_name=ctg_name,
                        ctg_start=reference_start,
                        ctg_end=reference_end))
    elif is_ctg_name_given:
        reads_regions.append(region_from(ctg_name=ctg_name))
        ref_regions.append(region_from(ctg_name=ctg_name))
        reference_start = 1

    reference_sequence = reference_sequence_from(
        samtools_execute_command=samtools_execute_command,
        fasta_file_path=fasta_file_path,
        regions=ref_regions)

    if reference_sequence is None or len(reference_sequence) == 0:
        sys.exit(
            log_error(
                "[ERROR] Failed to load reference sequence from file ({}).".
                format(fasta_file_path)))

    if is_confident_bed_file_given and ctg_name not in tree:
        sys.exit(
            log_error("[ERROR] ctg_name {} not exists in bed file({}).".format(
                ctg_name, confident_bed_fn)))

    # samtools mpileup options
    # reverse-del: deletions on the forward/reverse strand are marked as '*'/'#'
    min_base_quality = 0 if args.gvcf else min_base_quality
    max_depth = param.max_depth_dict[
        args.platform] if args.platform else args.max_depth
    mq_option = ' --min-MQ {}'.format(min_mapping_quality)
    bq_option = ' --min-BQ {}'.format(min_base_quality)
    flags_option = ' --excl-flags {}'.format(param.SAMTOOLS_VIEW_FILTER_FLAG)
    max_depth_option = ' --max-depth {}'.format(max_depth)
    bed_option = ' -l {}'.format(
        extend_bed) if is_extend_bed_file_given else ""
    gvcf_option = ' -a' if args.gvcf else ""
    samtools_mpileup_process = subprocess_popen(
        shlex.split("{} mpileup  {} -r {} --reverse-del".format(
            samtools_execute_command,
            bam_file_path,
            " ".join(reads_regions),
        ) + mq_option + bq_option + bed_option + flags_option +
                    max_depth_option + gvcf_option))

    if tensor_can_output_path != "PIPE":
        tensor_can_fpo = open(tensor_can_output_path, "wb")
        tensor_can_fp = subprocess_popen(shlex.split("{} -c".format(
            param.zstd)),
                                         stdin=PIPE,
                                         stdout=tensor_can_fpo)
    else:
        tensor_can_fp = TensorStdout(sys.stdout)

    # whether to save all alternative information; only used in debug mode
    if alt_fn:
        alt_fp = open(alt_fn, 'w')

    pos_offset = 0
    pre_pos = -1
    tensor = [[]] * sliding_window_size
    candidate_position = []
    all_alt_dict = {}
    depth_dict = {}
    af_dict = {}

    # to generate gvcf, it is needed to record whole genome statistical information
    if args.gvcf:
        nonVariantCaller = variantInfoCalculator(
            gvcfWritePath=args.temp_file_dir,
            ref_path=args.ref_fn,
            bp_resolution=args.bp_resolution,
            ctgName=ctg_name,
            sample_name='.'.join(
                [args.sampleName, ctg_name,
                 str(ctg_start),
                 str(ctg_end)]),
            p_err=args.base_err,
            gq_bin_size=args.gq_bin_size)

    confident_bed_tree = bed_tree_from(bed_file_path=confident_bed_fn,
                                       contig_name=ctg_name,
                                       bed_ctg_start=extend_start,
                                       bed_ctg_end=extend_end)

    empty_pileup_flag = True
    for row in samtools_mpileup_process.stdout:
        empty_pileup_flag = False
        columns = row.strip().split('\t', maxsplit=5)
        pos = int(columns[1])
        pileup_bases = columns[4]
        reference_base = reference_sequence[pos - reference_start].upper()
        valid_reference_flag = True
        within_flag = True
        if args.gvcf:
            if not valid_reference_flag:
                nonVariantCaller.make_gvcf_online({}, push_current=True)
            if ctg_start != None and ctg_end != None:
                within_flag = pos >= ctg_start and pos < ctg_end
            elif ctg_start != None and ctg_end == None:
                within_flag = pos >= ctg_start
            elif ctg_start == None and ctg_end != None:
                within_flag = pos <= ctg_end
            else:
                within_flag = True
            if columns[3] == '0' and within_flag and valid_reference_flag:
                cur_site_info = {
                    'chr': columns[0],
                    'pos': pos,
                    'ref': reference_base,
                    'n_total': 0,
                    'n_ref': 0
                }
                nonVariantCaller.make_gvcf_online(cur_site_info)
                continue

        # when starting a new region, clear the sliding-window cache to avoid extra memory usage
        if pre_pos + 1 != pos:
            pos_offset = 0
            tensor = [[]] * sliding_window_size
            candidate_position = []
        pre_pos = pos

        # a condition to skip tensor creation at some positions while still returning
        # the allele summary from the allele count function
        pileup_tensor, alt_dict, af, depth, pass_af, pileup_list, max_del_length = generate_tensor(
            pos=pos,
            pileup_bases=pileup_bases,
            reference_sequence=reference_sequence,
            reference_start=reference_start,
            reference_base=reference_base,
            minimum_af_for_candidate=minimum_af_for_candidate,
            minimum_snp_af_for_candidate=minimum_snp_af_for_candidate,
            minimum_indel_af_for_candidate=minimum_indel_af_for_candidate,
            platform=platform,
            fast_mode=fast_mode,
            call_snp_only=call_snp_only)
        if args.gvcf and within_flag and valid_reference_flag:
            cur_n_total = 0
            cur_n_ref = 0
            for _key, _value in pileup_list:
                if (_key == reference_base):
                    cur_n_ref = _value
                cur_n_total += _value

            cur_site_info = {
                'chr': columns[0],
                'pos': pos,
                'ref': reference_base,
                'n_total': cur_n_total,
                'n_ref': cur_n_ref
            }
            nonVariantCaller.make_gvcf_online(cur_site_info)

        pass_confident_bed = not is_confident_bed_file_given or is_region_in(
            tree=confident_bed_tree,
            contig_name=ctg_name,
            region_start=pos - 1,
            region_end=pos + max_del_length + 1)  # 0-based
        if (pass_confident_bed and reference_base in 'ACGT' and
            (pass_af and depth >= min_coverage)
                and not is_known_vcf_file_provided) or (
                    is_known_vcf_file_provided and pos in known_variants_set):
            candidate_position.append(pos)
            all_alt_dict[pos] = alt_dict
            depth_dict[pos] = depth
            af_dict[pos] = af
        tensor[pos_offset] = pileup_tensor

        # emit the pileup tensor for a candidate once flanking_base_num downstream positions have been collected
        pos_offset = (pos_offset + 1) % sliding_window_size
        if len(candidate_position
               ) and pos - candidate_position[0] == flanking_base_num:
            center = candidate_position.pop(0)
            has_empty_tensor = sum([True for item in tensor if not len(item)])
            if not has_empty_tensor:
                depth = depth_dict[center]
                ref_seq = reference_sequence[center - (flanking_base_num) -
                                             reference_start:center +
                                             flanking_base_num + 1 -
                                             reference_start]
                concat_tensor = tensor[pos_offset:] + tensor[0:pos_offset]

                alt_info = str(depth) + '-' + ' '.join([
                    ' '.join([item[0], str(item[1])])
                    for item in list(all_alt_dict[center].items())
                ])
                l = "%s\t%d\t%s\t%s\t%s" % (
                    ctg_name, center, ref_seq, " ".join(
                        " ".join("%d" % x for x in innerlist)
                        for innerlist in concat_tensor), alt_info)
                tensor_can_fp.stdin.write(l)
                tensor_can_fp.stdin.write("\n")
                if alt_fn:
                    alt_info = ' '.join([
                        ' '.join([item[0], str(item[1])])
                        for item in list(all_alt_dict[center].items())
                    ])
                    alt_fp.write('\t'.join([
                        ctg_name + ' ' + str(center),
                        str(depth), alt_info,
                        str(af_dict[center])
                    ]) + '\n')
                del all_alt_dict[center], depth_dict[center], af_dict[center]

    if args.gvcf and len(nonVariantCaller.current_block) != 0:
        nonVariantCaller.write_to_gvcf_batch(nonVariantCaller.current_block,
                                             nonVariantCaller.cur_min_DP,
                                             nonVariantCaller.cur_raw_gq)

    if args.gvcf and empty_pileup_flag:
        nonVariantCaller.write_empty_pileup(ctg_name, ctg_start, ctg_end)
    if args.gvcf:
        nonVariantCaller.close_vcf_writer()

    samtools_mpileup_process.stdout.close()
    samtools_mpileup_process.wait()

    if tensor_can_output_path != "PIPE":
        tensor_can_fp.stdin.close()
        tensor_can_fp.wait()
        tensor_can_fpo.close()

    if alt_fn:
        alt_fp.close()
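A note on the sliding-window bookkeeping above: the function keeps a fixed-size ring buffer of per-position pileup tensors (indexed with pos_offset modulo the window size) and only emits a candidate's tensor once flanking_base_num downstream positions have arrived. The following is a minimal, self-contained sketch of that buffering pattern using collections.deque instead of an index-based ring buffer; the window size, positions, and toy tensors are illustrative, not Clair3's actual values.

from collections import deque

def window_tensors(positions_and_tensors, flank=2):
    """Yield (center, window) once `flank` positions have arrived after a position.

    A simplified stand-in for the ring buffer above: every position is treated
    as a candidate and the window covers 2 * flank + 1 consecutive tensors.
    """
    window_size = 2 * flank + 1
    buffer = deque(maxlen=window_size)  # the oldest tensor drops off automatically
    for pos, tensor in positions_and_tensors:
        buffer.append((pos, tensor))
        if len(buffer) == window_size:
            center_pos, _ = buffer[flank]
            yield center_pos, [t for _, t in buffer]

# toy usage: positions 100..106 with dummy one-element tensors
stream = [(100 + i, [i]) for i in range(7)]
for center, window in window_tensors(stream):
    print(center, window)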
Example #24
def Run(args):
    basedir = dirname(__file__)
    EVCBin = basedir + "/../clair.py ExtractVariantCandidates"
    GTBin = basedir + "/../clair.py GetTruth"
    CTBin = basedir + "/../clair.py CreateTensor"
    CVBin = basedir + "/../clair.py call_var"

    pypyBin = executable_command_string_from(args.pypy, exit_on_not_found=True)
    samtoolsBin = executable_command_string_from(args.samtools,
                                                 exit_on_not_found=True)

    chkpnt_fn = file_path_from(args.chkpnt_fn,
                               suffix=".meta",
                               exit_on_not_found=True)
    bam_fn = file_path_from(args.bam_fn, exit_on_not_found=True)
    ref_fn = file_path_from(args.ref_fn, exit_on_not_found=True)
    vcf_fn = file_path_from(args.vcf_fn)
    bed_fn = file_path_from(args.bed_fn)

    dcov = args.dcov
    call_fn = args.call_fn
    af_threshold = args.threshold
    minCoverage = int(args.minCoverage)
    sampleName = args.sampleName
    ctgName = args.ctgName
    if ctgName is None:
        sys.exit(
            "--ctgName must be specified. You can call variants on multiple chromosomes simultaneously."
        )

    stop_consider_left_edge = command_option_from(args.stop_consider_left_edge,
                                                  'stop_consider_left_edge')
    log_path = command_option_from(args.log_path,
                                   'log_path',
                                   option_value=args.log_path)
    pysam_for_all_indel_bases = command_option_from(
        args.pysam_for_all_indel_bases, 'pysam_for_all_indel_bases')
    haploid_precision_mode = command_option_from(args.haploid_precision,
                                                 'haploid_precision')
    haploid_sensitive_mode = command_option_from(args.haploid_sensitive,
                                                 'haploid_sensitive')
    output_for_ensemble = command_option_from(args.output_for_ensemble,
                                              'output_for_ensemble')
    pipe_line = command_option_from(args.pipe_line, 'pipe_line')
    store_loaded_mini_match = command_option_from(args.store_loaded_mini_match,
                                                  'store_loaded_mini_match')
    only_prediction = command_option_from(args.only_prediction,
                                          'only_prediction')
    debug = command_option_from(args.debug, 'debug')
    qual = command_option_from(args.qual, 'qual', option_value=args.qual)
    fast_plotting = command_option_from(args.fast_plotting, 'fast_plotting')

    ctgStart = None
    ctgEnd = None
    if args.ctgStart is not None and args.ctgEnd is not None and int(
            args.ctgStart) <= int(args.ctgEnd):
        ctgStart = CommandOption('ctgStart', args.ctgStart)
        ctgEnd = CommandOption('ctgEnd', args.ctgEnd)

    if args.threads is None:
        numCpus = multiprocessing.cpu_count()
    else:
        numCpus = args.threads if args.threads < multiprocessing.cpu_count(
        ) else multiprocessing.cpu_count()

    maxCpus = multiprocessing.cpu_count()
    _cpuSet = ",".join(
        str(x) for x in random.sample(range(0, maxCpus), numCpus))

    taskSet = "taskset -c %s" % (_cpuSet)
    try:
        subprocess.check_output("which %s" % ("taskset"), shell=True)
    except:
        taskSet = ""

    if args.delay > 0:
        delay = random.randrange(0, args.delay)
        print("Delay %d seconds before starting variant calling ..." % (delay),
              file=sys.stderr)
        sleep(delay)

    extract_variant_candidate_command_options = [
        pypyBin, EVCBin,
        CommandOption('bam_fn', bam_fn),
        CommandOption('ref_fn', ref_fn),
        CommandOption('bed_fn', bed_fn),
        CommandOption('ctgName', ctgName), ctgStart, ctgEnd,
        CommandOption('threshold', af_threshold),
        CommandOption('minCoverage', minCoverage),
        CommandOption('samtools', samtoolsBin)
    ]
    get_truth_command_options = [
        pypyBin, GTBin,
        CommandOption('vcf_fn', vcf_fn),
        CommandOption('ref_fn', ref_fn),
        CommandOption('ctgName', ctgName), ctgStart, ctgEnd
    ]

    create_tensor_command_options = [
        pypyBin, CTBin,
        CommandOption('bam_fn', bam_fn),
        CommandOption('ref_fn', ref_fn),
        CommandOption('ctgName', ctgName), ctgStart, ctgEnd,
        stop_consider_left_edge,
        CommandOption('samtools', samtoolsBin),
        CommandOption('dcov', dcov)
    ]

    call_variant_command_options = [
        taskSet,
        ExecuteCommand('python', CVBin),
        CommandOption('chkpnt_fn', chkpnt_fn),
        CommandOption('call_fn', call_fn),
        CommandOption('bam_fn', bam_fn),
        CommandOption('sampleName', sampleName),
        CommandOption('time_counter_file_name', args.time_counter_file_name),
        CommandOption('threads', numCpus),
        CommandOption('ref_fn', ref_fn), pysam_for_all_indel_bases,
        haploid_precision_mode, haploid_sensitive_mode, output_for_ensemble,
        pipe_line, store_loaded_mini_match, only_prediction, qual, debug
    ]
    call_variant_with_activation_command_options = [
        CommandOptionWithNoValue('activation_only'),
        log_path,
        CommandOption('max_plot', args.max_plot),
        CommandOption('parallel_level', args.parallel_level),
        CommandOption('workers', args.workers),
        fast_plotting,
    ] if args.activation_only else []

    is_true_variant_call = vcf_fn is not None
    try:
        c.extract_variant_candidate = subprocess_popen(
            shlex.split(
                command_string_from(
                    get_truth_command_options if is_true_variant_call else
                    extract_variant_candidate_command_options)))

        c.create_tensor = subprocess_popen(
            shlex.split(command_string_from(create_tensor_command_options)),
            stdin=c.extract_variant_candidate.stdout)

        c.call_variant = subprocess_popen(shlex.split(
            command_string_from(call_variant_command_options +
                                call_variant_with_activation_command_options)),
                                          stdin=c.create_tensor.stdout,
                                          stdout=sys.stderr)
    except Exception as e:
        print(e, file=sys.stderr)
        sys.exit("Failed to start required processes. Exiting...")

    signal.signal(signal.SIGALRM, check_return_code)
    signal.alarm(2)

    try:
        c.call_variant.wait()
        c.create_tensor.stdout.close()
        c.create_tensor.wait()
        c.extract_variant_candidate.stdout.close()
        c.extract_variant_candidate.wait()
    except KeyboardInterrupt as e:
        print(
            "KeyboardInterrupt received when waiting at CallVarBam, terminating all scripts."
        )
        try:
            c.call_variant.terminate()
            c.create_tensor.terminate()
            c.extract_variant_candidate.terminate()
        except Exception as e:
            print(e)

        raise KeyboardInterrupt
    except Exception as e:
        print(
            "Exception received when waiting at CallVarBam, terminating all scripts."
        )
        print(e)
        try:
            c.call_variant.terminate()
            c.create_tensor.terminate()
            c.extract_variant_candidate.terminate()
        except Exception as e:
            print(e)

        raise e
Example #25
def SelectCandidates(args):
    """
    Select low-quality and low-sequence-entropy candidate variants for full alignment. False positive pileup variants
    and true variants missed by pileup calling mostly have low quality scores (the reference quality score for missed
    variants), so only a proportion of low-quality variants is sent to full alignment while the high-quality pileup
    output is kept, as full-alignment calling is substantially slower than pileup calling.
    """

    phased_vcf_fn = args.phased_vcf_fn
    pileup_vcf_fn = args.pileup_vcf_fn
    var_pct_full = args.var_pct_full
    ref_pct_full = args.ref_pct_full
    seq_entropy_pro = args.seq_entropy_pro
    contig_name = args.ctgName
    phasing_window_size = param.phasing_window_size
    platform = args.platform
    split_bed_size = args.split_bed_size
    split_folder = args.split_folder
    extend_bp = param.extend_bp
    call_low_seq_entropy = args.call_low_seq_entropy
    phasing_info_in_bam = args.phasing_info_in_bam
    need_phasing_list = []
    need_phasing_set = set()
    ref_call_pos_list = []
    variant_dict = defaultdict(str)
    flankingBaseNum = param.flankingBaseNum
    qual_fn = args.qual_fn if args.qual_fn is not None else 'qual'
    fasta_file_path = args.ref_fn
    samtools_execute_command = args.samtools

    found_qual_cut_off = False
    low_sequence_entropy_list = []
    # try to find the global quality cut off:
    f_qual = os.path.join(split_folder, qual_fn)
    if os.path.exists(f_qual):
        with open(f_qual, 'r') as f:
            line = f.read().rstrip().split(' ')
        var_qual, ref_qual = float(line[0]), float(line[1])
        found_qual_cut_off = True

    all_full_aln_regions = []
    if phased_vcf_fn and os.path.exists(phased_vcf_fn):
        unzip_process = subprocess_popen(
            shlex.split("gzip -fdc %s" % (phased_vcf_fn)))
        for row in unzip_process.stdout:
            row = row.rstrip()
            if row[0] == '#':
                continue
            columns = row.strip().split('\t')

            ctg_name = columns[0]
            if contig_name and contig_name != ctg_name:
                continue
            pos = int(columns[1])
            ref_base = columns[3]
            alt_base = columns[4]
            genotype_info = columns[9].split(':')
            genotype, phase_set = genotype_info[0], genotype_info[-1]
            if '|' not in genotype:  # unphasable
                continue
            variant_dict[pos] = '-'.join([
                ref_base, alt_base, ('1' if genotype == '0|1' else '2'),
                phase_set
            ])

    if pileup_vcf_fn and os.path.exists(pileup_vcf_fn):
        # vcf format
        unzip_process = subprocess_popen(
            shlex.split("gzip -fdc %s" % (pileup_vcf_fn)))
        for row in unzip_process.stdout:
            if row[0] == '#':
                continue
            columns = row.rstrip().split('\t')
            ctg_name = columns[0]
            if contig_name and contig_name != ctg_name:
                continue
            pos = int(columns[1])
            ref_base = columns[3]
            alt_base = columns[4]
            qual = float(columns[5])

            # reference calling
            if alt_base == "." or ref_base == alt_base:
                ref_call_pos_list.append((pos, qual))
            else:
                need_phasing_list.append((pos, qual))
                need_phasing_set.add(pos)

        if found_qual_cut_off:
            low_qual_ref_list = [[k, v] for k, v in ref_call_pos_list
                                 if v < ref_qual]
            low_qual_variant_list = [[k, v] for k, v in need_phasing_list
                                     if v < var_qual]
        else:
            low_qual_ref_list = sorted(
                ref_call_pos_list,
                key=lambda x: x[1])[:int(ref_pct_full *
                                         len(ref_call_pos_list))]
            low_qual_variant_list = sorted(
                need_phasing_list,
                key=lambda x: x[1])[:int(var_pct_full *
                                         len(need_phasing_list))]

        if call_low_seq_entropy:
            candidate_positions = sorted(
                ref_call_pos_list, key=lambda x: x[1])[:int(
                    (var_pct_full + seq_entropy_pro) * len(ref_call_pos_list)
                )] + sorted(need_phasing_list, key=lambda x: x[1])[:int(
                    (var_pct_full + seq_entropy_pro) * len(need_phasing_list))]
            candidate_positions = set(
                [item[0] for item in candidate_positions])

            candidate_positions_entropy_list = sqeuence_entropy_from(
                samtools_execute_command=samtools_execute_command,
                fasta_file_path=fasta_file_path,
                contig_name=contig_name,
                candidate_positions=candidate_positions)

            low_sequence_entropy_list = sorted(
                candidate_positions_entropy_list, key=lambda x: x[1]
            )[:int(seq_entropy_pro * len(candidate_positions_entropy_list))]

        # calling with phasing_info_in_bam: select low-quality reference calls and low-quality variants for phased calling
        if phasing_info_in_bam:
            logging.info(
                '[INFO] Low quality reference calls to be processed in {}: {}'.
                format(contig_name, len(low_qual_ref_list)))
            logging.info(
                '[INFO] Low quality variants to be processed in {}: {}'.format(
                    contig_name, len(low_qual_variant_list)))
            if call_low_seq_entropy:
                logging.info(
                    '[INFO] Total low sequence entropy variants to be processed in {}: {}'
                    .format(contig_name, len(low_sequence_entropy_list)))

            need_phasing_row_list = set(
                [item[0] for item in low_qual_ref_list] +
                [item[0] for item in low_qual_variant_list] +
                [item[0] for item in low_sequence_entropy_list])
            need_phasing_row_list = sorted(list(need_phasing_row_list))

            if len(need_phasing_row_list) == 0:
                print(
                    log_warning(
                        "[WARNING] Cannot find any low-quality 0/0, 0/1 or 1/1 variant in pileup output in contig {}"
                        .format(contig_name)))

            region_num = len(
                need_phasing_row_list) // split_bed_size + 1 if len(
                    need_phasing_row_list) % split_bed_size else len(
                        need_phasing_row_list) // split_bed_size

            for idx in range(region_num):
                # a window region for tensor creation; samtools mpileup does not include the last position
                split_output = need_phasing_row_list[idx *
                                                     split_bed_size:(idx + 1) *
                                                     split_bed_size]

                if platform == 'ilmn':
                    region_size = param.split_region_size
                    split_output = [(item // region_size * region_size -
                                     param.no_of_positions,
                                     item // region_size * region_size +
                                     region_size + param.no_of_positions)
                                    for item in split_output]
                else:
                    split_output = [(item - flankingBaseNum,
                                     item + flankingBaseNum + 2)
                                    for item in split_output]

                split_output = sorted(split_output, key=lambda x: x[0])

                # using ctgName.start_end as the file name is deprecated, as it would rerun similar regions several times when start and end differ slightly
                # output_path = os.path.join(split_folder, '{}.{}_{}'.format(contig_name, split_output[0][0], split_output[-1][1]))
                output_path = os.path.join(
                    split_folder, '{}.{}_{}'.format(contig_name, idx,
                                                    region_num))
                all_full_aln_regions.append(output_path)
                with open(output_path, 'w') as output_file:
                    output_file.write('\n'.join([
                        '\t'.join([
                            contig_name,
                            str(x[0] - 1),
                            str(x[1] - 1),
                        ]) for x in split_output
                    ]) + '\n')  # bed format

            if len(all_full_aln_regions) > 0:
                all_full_aln_regions_path = os.path.join(
                    split_folder, 'FULL_ALN_FILE_{}'.format(contig_name))
                with open(all_full_aln_regions_path, 'w') as output_file:
                    output_file.write('\n'.join(all_full_aln_regions) + '\n')
            return

        for pos, qual in low_qual_ref_list:
            need_phasing_set.add(pos)

    # Call variant in all candidate position
    elif args.all_alt_fn is not None:
        unzip_process = subprocess_popen(
            shlex.split("gzip -fdc %s" % (args.all_alt_fn)))
        for row in unzip_process.stdout:
            if row[0] == '#':
                continue
            columns = row.rstrip().split('\t')
            ctg_name, pos = columns[0].split()
            pos = int(pos)
            if contig_name and contig_name != ctg_name:
                continue
            need_phasing_set.add(pos)

    need_phasing_row_list = sorted(list(set(need_phasing_set)))
    snp_tree = IntervalTree()
    hete_snp_row_list = sorted(
        list(
            set(variant_dict.keys()).intersection(set(need_phasing_row_list))))
    print(
        '[INFO] Total hete snp with reads support in {}: '.format(contig_name),
        len(hete_snp_row_list))
    print(
        '[INFO] Total candidates need to be processed in {}: '.format(
            contig_name), len(need_phasing_row_list))

    for item in hete_snp_row_list:
        snp_tree.addi(item, item + 1)

    region_num = len(need_phasing_row_list) // split_bed_size + 1 if len(
        need_phasing_row_list) % split_bed_size else len(
            need_phasing_row_list) // split_bed_size
    for idx in range(region_num):
        split_output = need_phasing_row_list[idx * split_bed_size:(idx + 1) *
                                             split_bed_size]

        start = split_output[0]
        end = split_output[-1]
        extend_start, extend_end = start - phasing_window_size, end + phasing_window_size
        overlaps = snp_tree.overlap(extend_start, extend_end)
        snp_split_out = []
        for overlap in overlaps:
            snp_split_out.append((contig_name, overlap[0] - extend_bp - 1 - 1,
                                  overlap[0] + 1 + extend_bp - 1,
                                  variant_dict[overlap[0]]))  # bed format
        split_output = [(contig_name, item - flankingBaseNum - 1,
                         item + flankingBaseNum + 1 - 1)
                        for item in split_output
                        ]  # a window region for tensor creation, in bed format

        split_output += snp_split_out
        split_output = sorted(split_output, key=lambda x: x[1])

        with open(
                os.path.join(split_folder,
                             '{}.{}_{}'.format(contig_name, start, end)),
                'w') as output_file:
            output_file.write(
                '\n'.join(['\t'.join(map(str, x))
                           for x in split_output]) + '\n')  # bed format
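When no precomputed quality cutoff file is found, the selection above falls back to sorting calls by quality and keeping only the lowest ref_pct_full / var_pct_full fraction for full-alignment re-calling. A minimal sketch of that proportional selection follows; the helper name and the toy (position, quality) pairs are illustrative only.

def lowest_quality_fraction(pos_qual_pairs, fraction):
    """Return the (pos, qual) pairs in the lowest-quality `fraction` of the input."""
    ranked = sorted(pos_qual_pairs, key=lambda x: x[1])
    return ranked[:int(fraction * len(ranked))]

calls = [(101, 20.0), (205, 3.5), (310, 12.0), (412, 1.2), (518, 30.0)]
print(lowest_quality_fraction(calls, 0.4))  # -> [(412, 1.2), (205, 3.5)]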
Example #26
File: GetTruth.py Project: HKU-BAL/Clair3
def OutputVariant(args):
    var_fn = args.var_fn
    vcf_fn = args.vcf_fn
    truth_vcf_fn = args.truth_vcf_fn
    ctg_name = args.ctgName
    ctg_start = args.ctgStart
    ctg_end = args.ctgEnd

    truth_vcf_set = set()
    variant_set = set()
    if args.truth_vcf_fn is not None:
        truth_vcf_set = set(
            vcf_candidates_from(vcf_fn=truth_vcf_fn, contig_name=ctg_name))
    if args.var_fn != "PIPE":
        var_fpo = open(var_fn, "wb")
        var_fp = subprocess_popen(shlex.split("gzip -c"),
                                  stdin=PIPE,
                                  stdout=var_fpo)
    else:
        var_fp = TruthStdout(sys.stdout)

    is_ctg_region_provided = ctg_start is not None and ctg_end is not None

    vcf_fp = subprocess_popen(shlex.split("gzip -fdc %s" % (vcf_fn)))

    for row in vcf_fp.stdout:
        columns = row.strip().split()
        if columns[0][0] == "#":
            continue

        # position in vcf is 1-based
        chromosome, position = columns[0], columns[1]
        if chromosome != ctg_name:
            continue
        if is_ctg_region_provided and not (ctg_start <= int(position) <=
                                           ctg_end):
            continue
        reference, alternate, last_column = columns[3], columns[4], columns[-1]
        # normal GetTruth
        genotype = last_column.split(":")[0].replace("/", "|").replace(
            ".", "0").split("|")
        genotype_1, genotype_2 = genotype

        # 1000 Genome GetTruth (format problem) (no genotype is given)
        if int(genotype_1) > int(genotype_2):
            genotype_1, genotype_2 = genotype_2, genotype_1

        # remove '*' to guarantee VCF matching
        if '*' in alternate:
            alternate = alternate.split(',')
            if int(genotype_1) + int(genotype_2) != 3 or len(alternate) != 2:
                print('error with variant representation')
                continue
            alternate = ''.join(
                [alt_base for alt_base in alternate if alt_base != '*'])
            # '*' always has genotype 1/2

            genotype_1, genotype_2 = '0', '1'

        variant_set.add(int(position))
        var_fp.stdin.write(" ".join((chromosome, position, reference,
                                     alternate, genotype_1, genotype_2)))
        var_fp.stdin.write("\n")

    for position in truth_vcf_set:
        if position not in variant_set:
            # missed variants, recorded for use in Tensor2Bin
            var_fp.stdin.write(" ".join(
                (chromosome, str(position), "None", "None", "-1", "-1")))
            var_fp.stdin.write("\n")

    vcf_fp.stdout.close()
    vcf_fp.wait()

    if args.var_fn != "PIPE":
        var_fp.stdin.close()
        var_fp.wait()
        var_fpo.close()
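OutputVariant normalizes the truth genotype by treating '/' and '|' separators alike, mapping missing alleles ('.') to the reference allele, and ordering the two alleles so the smaller one comes first. A small sketch of that normalization, assuming a plain bi-allelic GT string as input (the helper name is illustrative):

def normalized_genotype(gt_field):
    """Normalize a VCF GT value such as '1/0', '.|1' or '0/1' to an ordered pair."""
    gt = gt_field.replace("/", "|").replace(".", "0").split("|")
    genotype_1, genotype_2 = gt
    if int(genotype_1) > int(genotype_2):
        genotype_1, genotype_2 = genotype_2, genotype_1
    return genotype_1, genotype_2

print(normalized_genotype("1/0"))  # -> ('0', '1')
print(normalized_genotype(".|1"))  # -> ('0', '1')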
Example #27
def make_candidates(args):

    gen4Training = args.gen4Training
    variant_file_path = args.var_fn
    bed_file_path = args.bed_fn
    fasta_file_path = args.ref_fn
    ctg_name = args.ctgName
    ctg_start = args.ctgStart
    ctg_end = args.ctgEnd
    output_probability = args.outputProb
    samtools_execute_command = args.samtools
    minimum_depth_for_candidate = args.minCoverage
    minimum_af_for_candidate = args.threshold
    minimum_mapping_quality = args.minMQ
    bam_file_path = args.bam_fn
    candidate_output_path = args.can_fn
    is_using_stdout_for_output_candidate = candidate_output_path == "PIPE"

    is_building_training_dataset = gen4Training == True
    is_variant_file_given = variant_file_path is not None
    is_bed_file_given = bed_file_path is not None
    is_ctg_name_given = ctg_name is not None
    is_ctg_range_given = is_ctg_name_given and ctg_start is not None and ctg_end is not None

    if is_building_training_dataset:
        # minimum_depth_for_candidate = 0
        minimum_af_for_candidate = 0

    # preparation for candidates near variants
    need_consider_candidates_near_variant = is_building_training_dataset and is_variant_file_given
    variants_map = variants_map_from(
        variant_file_path) if need_consider_candidates_near_variant else {}
    non_variants_map = non_variants_map_near_variants_from(variants_map)
    no_of_candidates_near_variant = 0
    no_of_candidates_outside_variant = 0

    # update output probabilities for candidates near variants
    # original: (7000000.0 * 2.0 / 3000000000)
    ratio_of_candidates_near_variant_to_candidates_outside_variant = 1.0
    output_probability_near_variant = (
        3500000.0 *
        ratio_of_candidates_near_variant_to_candidates_outside_variant *
        RATIO_OF_NON_VARIANT_TO_VARIANT / 14000000)
    output_probability_outside_variant = 3500000.0 * RATIO_OF_NON_VARIANT_TO_VARIANT / (
        3000000000 - 14000000)

    if not isfile("{}.fai".format(fasta_file_path)):
        print("Fasta index {}.fai doesn't exist.".format(fasta_file_path),
              file=sys.stderr)
        sys.exit(1)

    # 1-based regions [start, end] (start and end inclusive)
    regions = []
    reference_start, reference_end = None, None
    if is_ctg_range_given:
        reference_start, reference_end = ctg_start - param.expandReferenceRegion, ctg_end + param.expandReferenceRegion
        reference_start = 1 if reference_start < 1 else reference_start
        regions.append(
            region_from(ctg_name=ctg_name,
                        ctg_start=reference_start,
                        ctg_end=reference_end))
    elif is_ctg_name_given:
        regions.append(region_from(ctg_name=ctg_name))

    reference_sequence = reference_sequence_from(
        samtools_execute_command=samtools_execute_command,
        fasta_file_path=fasta_file_path,
        regions=regions)
    if reference_sequence is None or len(reference_sequence) == 0:
        print(
            "[ERROR] Failed to load reference seqeunce from file ({}).".format(
                fasta_file_path),
            file=sys.stderr)
        sys.exit(1)

    tree = bed_tree_from(bed_file_path=bed_file_path)
    if is_bed_file_given and ctg_name not in tree:
        print("[ERROR] ctg_name({}) not exists in bed file({}).".format(
            ctg_name, bed_file_path),
              file=sys.stderr)
        sys.exit(1)

    samtools_view_process = subprocess_popen(
        shlex.split("{} view -F {} {} {}".format(
            samtools_execute_command, param.SAMTOOLS_VIEW_FILTER_FLAG,
            bam_file_path, " ".join(regions))))

    if is_using_stdout_for_output_candidate:
        can_fp = CandidateStdout(sys.stdout)
    else:
        can_fpo = open(candidate_output_path, "wb")
        can_fp = subprocess_popen(shlex.split("gzip -c"),
                                  stdin=PIPE,
                                  stdout=can_fpo)

    pileup = defaultdict(lambda: {
        "A": 0,
        "C": 0,
        "G": 0,
        "T": 0,
        "I": 0,
        "D": 0,
        "N": 0
    })
    POS = 0
    number_of_reads_processed = 0

    while True:
        row = samtools_view_process.stdout.readline()
        is_finish_reading_output = row == '' and samtools_view_process.poll(
        ) is not None

        if row:
            columns = row.strip().split()
            if columns[0][0] == "@":
                continue

            RNAME = columns[2]
            if RNAME != ctg_name:
                continue

            POS = int(
                columns[3]
            ) - 1  # switch from 1-base to 0-base to match sequence index
            MAPQ = int(columns[4])
            CIGAR = columns[5]
            SEQ = columns[9].upper(
            )  # uppercase for SEQ (regexp is \*|[A-Za-z=.]+)

            reference_position = POS
            query_position = 0

            if MAPQ < minimum_mapping_quality:
                continue
            if CIGAR == "*" or is_too_many_soft_clipped_bases_for_a_read_from(
                    CIGAR):
                continue

            number_of_reads_processed += 1

            advance = 0
            for c in str(CIGAR):
                if c.isdigit():
                    advance = advance * 10 + int(c)
                    continue

                if c == "S":
                    query_position += advance

                elif c == "M" or c == "=" or c == "X":
                    for _ in range(advance):
                        base = evc_base_from(SEQ[query_position])
                        pileup[reference_position][base] += 1

                        # those CIGAR operations consume both query and reference
                        reference_position += 1
                        query_position += 1

                elif c == "I":
                    pileup[reference_position - 1]["I"] += 1

                    # insertion consumes query
                    query_position += advance

                elif c == "D":
                    pileup[reference_position - 1]["D"] += 1

                    # deletion consumes reference
                    reference_position += advance

                # reset advance
                advance = 0

        positions = [x for x in pileup.keys() if x < POS
                     ] if not is_finish_reading_output else list(pileup.keys())
        positions.sort()
        for zero_based_position in positions:
            base_count = depth = reference_base = temp_key = None

            # ctg and bed checking (region [ctg_start, ctg_end] is 1-based, inclusive start and end positions)
            pass_ctg = not is_ctg_range_given or ctg_start <= zero_based_position + 1 <= ctg_end
            pass_bed = not is_bed_file_given or is_region_in(
                tree, ctg_name, zero_based_position)
            if not pass_bed or not pass_ctg:
                continue

            # output probability checking
            pass_output_probability = True
            if is_building_training_dataset and is_variant_file_given:
                temp_key = ctg_name + ":" + str(zero_based_position + 1)
                pass_output_probability = (temp_key not in variants_map and (
                    (temp_key in non_variants_map and
                     random.uniform(0, 1) <= output_probability_near_variant)
                    or (temp_key not in non_variants_map and random.uniform(
                        0, 1) <= output_probability_outside_variant)))
            elif is_building_training_dataset:
                pass_output_probability = random.uniform(
                    0, 1) <= output_probability
            if not pass_output_probability:
                continue

            # for depth checking and af checking
            try:
                reference_base = evc_base_from(reference_sequence[
                    zero_based_position -
                    (0 if reference_start is None else (reference_start - 1))])
                position_dict = pileup[zero_based_position]
            except:
                continue

            # depth checking
            base_count = list(position_dict.items())
            depth = sum(
                x[1]
                for x in base_count) - position_dict["I"] - position_dict["D"]
            if depth < minimum_depth_for_candidate:
                continue

            # af checking
            denominator = depth if depth > 0 else 1
            base_count.sort(
                key=lambda x: -x[1])  # sort base_count in descending order
            pass_af = (base_count[0][0] != reference_base
                       or (float(base_count[1][1]) / denominator) >=
                       minimum_af_for_candidate)
            if not pass_af:
                continue

            # output 1-based candidate
            if temp_key is not None and temp_key in non_variants_map:
                no_of_candidates_near_variant += 1
            elif temp_key is not None and temp_key not in non_variants_map:
                no_of_candidates_outside_variant += 1

            output = [ctg_name, zero_based_position + 1, reference_base, depth]
            output.extend(["%s %d" % x for x in base_count])
            output = " ".join([str(x) for x in output]) + "\n"

            can_fp.stdin.write(output)

        for zero_based_position in positions:
            del pileup[zero_based_position]

        if is_finish_reading_output:
            break

    if need_consider_candidates_near_variant:
        print("# of candidates near variant: ", no_of_candidates_near_variant)
        print("# of candidates outside variant: ",
              no_of_candidates_outside_variant)

    samtools_view_process.stdout.close()
    samtools_view_process.wait()

    if not is_using_stdout_for_output_candidate:
        can_fp.stdin.close()
        can_fp.wait()
        can_fpo.close()

    if number_of_reads_processed == 0:
        print(
            "No read has been process, either the genome region you specified has no read cover, or please check the correctness of your BAM input (%s)."
            % (bam_file_path),
            file=sys.stderr)
        sys.exit(0)
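The CIGAR loop in make_candidates advances two cursors, one on the reference and one on the read, and updates per-position counts as each operation is consumed: soft clips advance the read only, M/=/X advance both, and insertions and deletions are anchored to the previous reference base. Below is a compact, self-contained sketch of the same walk; it handles only the operations used above and the example read is made up.

import re
from collections import defaultdict

def pileup_from_alignment(pos_0based, cigar, seq, pileup=None):
    """Accumulate base / insertion / deletion counts for one aligned read."""
    if pileup is None:
        pileup = defaultdict(lambda: {"A": 0, "C": 0, "G": 0, "T": 0, "I": 0, "D": 0, "N": 0})
    ref_pos, query_pos = pos_0based, 0
    for length, op in re.findall(r"(\d+)([MIDNSHP=X])", cigar):
        length = int(length)
        if op == "S":                # soft clip consumes the query only
            query_pos += length
        elif op in "M=X":            # match/mismatch consumes both cursors
            for _ in range(length):
                base = seq[query_pos].upper()
                pileup[ref_pos][base if base in "ACGT" else "N"] += 1
                ref_pos += 1
                query_pos += 1
        elif op == "I":              # insertion anchored to the previous reference base
            pileup[ref_pos - 1]["I"] += 1
            query_pos += length
        elif op == "D":              # deletion anchored to the previous reference base
            pileup[ref_pos - 1]["D"] += 1
            ref_pos += length
    return pileup

counts = pileup_from_alignment(100, "3M1I2M", "ACGTAA")
print({pos: base_counts for pos, base_counts in counts.items()})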
Example #28
def OutputAlnTensor(args):
    available_slots = 5000000
    samtools = args.samtools
    tensor_file_path = args.tensor_fn
    bam_file_path = args.bam_fn
    reference_file_path = args.ref_fn
    candidate_file_path = args.can_fn
    dcov = args.dcov
    is_consider_left_edge = not args.stop_consider_left_edge
    min_coverage = args.minCoverage
    minimum_mapping_quality = args.minMQ
    ctg_name = args.ctgName
    ctg_start = args.ctgStart
    ctg_end = args.ctgEnd

    reference_result = reference_result_from(
        ctg_name=ctg_name,
        ctg_start=ctg_start,
        ctg_end=ctg_end,
        samtools=samtools,
        reference_file_path=reference_file_path,
        expand_reference_region=param.expandReferenceRegion,
    )

    reference_sequence = reference_result.sequence if reference_result is not None else ""
    is_faidx_process_have_error = reference_result is None or reference_result.is_faidx_process_have_error
    have_reference_sequence = reference_result is not None and len(reference_sequence) > 0

    if reference_result is None or is_faidx_process_have_error or not have_reference_sequence:
        print("Failed to load reference seqeunce. Please check if the provided reference fasta %s and the ctgName %s are correct." % (
            reference_file_path,
            ctg_name
        ), file=sys.stderr)
        sys.exit(1)

    reference_start = reference_result.start
    reference_start_0_based = 0 if reference_start is None else (reference_start - 1)
    begin_to_end = {}
    candidate_position = 0
    candidate_position_generator = candidate_position_generator_from(
        candidate_file_path=candidate_file_path,
        ctg_start=ctg_start,
        ctg_end=ctg_end,
        is_consider_left_edge=is_consider_left_edge,
        flanking_base_num=param.flankingBaseNum,
        begin_to_end=begin_to_end
    )

    samtools_view_process = samtools_view_process_from(
        ctg_name=ctg_name,
        ctg_start=ctg_start,
        ctg_end=ctg_end,
        samtools=samtools,
        bam_file_path=bam_file_path
    )

    center_to_alignment = {}

    if tensor_file_path != "PIPE":
        tensor_fpo = open(tensor_file_path, "wb")
        tensor_fp = subprocess_popen(shlex.split("gzip -c"), stdin=PIPE, stdout=tensor_fpo)
    else:
        tensor_fp = TensorStdout(sys.stdout)

    previous_position = 0
    depthCap = 0
    for l in samtools_view_process.stdout:
        l = l.split()
        if l[0][0] == "@":
            continue

        FLAG = int(l[1])
        POS = int(l[3]) - 1  # switch from 1-base to 0-base to match sequence index
        MQ = int(l[4])
        CIGAR = l[5]
        SEQ = l[9].upper()   # uppercase for SEQ (regexp is \*|[A-Za-z=.]+)
        reference_position = POS
        query_position = 0
        STRAND = (16 == (FLAG & 16))

        if MQ < minimum_mapping_quality:
            continue

        end_to_center = {}
        active_set = set()

        while candidate_position != -1 and candidate_position < (POS + len(SEQ) + 100000):
            candidate_position = next(candidate_position_generator)

        if previous_position != POS:
            previous_position = POS
            depthCap = 0
        else:
            depthCap += 1
            if depthCap >= dcov:
                #print >> sys.stderr, "Bypassing POS %d at depth %d\n" % (POS, depthCap)
                continue

        advance = 0
        for c in str(CIGAR):
            if available_slots <= 0:
                break

            if c.isdigit():
                advance = advance * 10 + int(c)
                continue

            # soft clip
            if c == "S":
                query_position += advance

            # match / mismatch
            if c == "M" or c == "=" or c == "X":
                for _ in range(advance):
                    if reference_position in begin_to_end:
                        for rEnd, rCenter in begin_to_end[reference_position]:
                            if rCenter in active_set:
                                continue
                            end_to_center[rEnd] = rCenter
                            active_set.add(rCenter)
                            center_to_alignment.setdefault(rCenter, [])
                            center_to_alignment[rCenter].append([])
                    for center in list(active_set):
                        if available_slots <= 0:
                            break
                        available_slots -= 1

                        center_to_alignment[center][-1].append((
                            reference_position,
                            0,
                            reference_sequence[reference_position - reference_start_0_based],
                            SEQ[query_position],
                            STRAND
                        ))
                    if reference_position in end_to_center:
                        center = end_to_center[reference_position]
                        active_set.remove(center)
                    reference_position += 1
                    query_position += 1

            # insertion
            if c == "I":
                for queryAdv in range(advance):
                    for center in list(active_set):
                        if available_slots <= 0:
                            break
                        available_slots -= 1

                        center_to_alignment[center][-1].append((
                            reference_position,
                            queryAdv,
                            "-",
                            SEQ[query_position],
                            STRAND
                        ))
                    query_position += 1

            # deletion
            if c == "D":
                for _ in range(advance):
                    for center in list(active_set):
                        if available_slots <= 0:
                            break
                        available_slots -= 1

                        center_to_alignment[center][-1].append((
                            reference_position,
                            0,
                            reference_sequence[reference_position - reference_start_0_based],
                            "-",
                            STRAND
                        ))
                    if reference_position in begin_to_end:
                        for rEnd, rCenter in begin_to_end[reference_position]:
                            if rCenter in active_set:
                                continue
                            end_to_center[rEnd] = rCenter
                            active_set.add(rCenter)
                            center_to_alignment.setdefault(rCenter, [])
                            center_to_alignment[rCenter].append([])
                    if reference_position in end_to_center:
                        center = end_to_center[reference_position]
                        active_set.remove(center)
                    reference_position += 1

            # reset advance
            advance = 0

        if depthCap == 0:
            for center in list(center_to_alignment.keys()):
                if center + (param.flankingBaseNum + 1) >= POS:
                    continue
                l = generate_tensor(
                    ctg_name, center_to_alignment[center], center, reference_sequence, reference_start_0_based, min_coverage
                )
                if l != None:
                    tensor_fp.stdin.write(l)
                    tensor_fp.stdin.write("\n")
                available_slots += sum(len(i) for i in center_to_alignment[center])
                #print >> sys.stderr, "POS %d: remaining slots %d" % (center, available_slots)
                del center_to_alignment[center]

    for center in center_to_alignment.keys():
        l = generate_tensor(
            ctg_name, center_to_alignment[center], center, reference_sequence, reference_start_0_based, min_coverage
        )
        if l != None:
            tensor_fp.stdin.write(l)
            tensor_fp.stdin.write("\n")

    samtools_view_process.stdout.close()
    samtools_view_process.wait()
    if tensor_file_path != "PIPE":
        tensor_fp.stdin.close()
        tensor_fp.wait()
        tensor_fpo.close()
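OutputAlnTensor activates a candidate window when the reference cursor reaches its recorded begin position and deactivates it when the cursor reaches the matching end position; begin_to_end maps each window start to (end, center) pairs, and active windows collect the bases the read contributes. A reduced sketch of that bookkeeping, with hypothetical window coordinates:

from collections import defaultdict

def collect_covered_positions(begin_to_end, ref_positions):
    """Track which candidate centers are active while a cursor walks reference positions."""
    end_to_center = {}
    active = set()
    covered = defaultdict(list)  # center -> reference positions seen inside its window
    for ref_pos in ref_positions:
        for window_end, center in begin_to_end.get(ref_pos, []):
            end_to_center[window_end] = center
            active.add(center)
        for center in active:
            covered[center].append(ref_pos)
        if ref_pos in end_to_center:
            active.discard(end_to_center[ref_pos])
    return covered

# one window centered at 105, activated at 103 and closed at 108
windows = {103: [(108, 105)]}
print(dict(collect_covered_positions(windows, range(100, 112))))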
Example #29
def sort_vcf_from(args):
    """
    Sort vcf file from providing vcf filename prefix.
    """
    output_fn = args.output_fn
    input_dir = args.input_dir
    vcf_fn_prefix = args.vcf_fn_prefix
    vcf_fn_suffix = args.vcf_fn_suffix
    sample_name = args.sampleName
    ref_fn = args.ref_fn
    contigs_fn = args.contigs_fn

    if not os.path.exists(input_dir):
        exit(
            log_error("[ERROR] Input directory: {} not exists!").format(
                input_dir))
    all_files = os.listdir(input_dir)

    if vcf_fn_prefix is not None:
        all_files = [
            item for item in all_files if item.startswith(vcf_fn_prefix)
        ]
        if len(all_files) == 0:
            output_header(output_fn=output_fn,
                          reference_file_path=ref_fn,
                          sample_name=sample_name)
            print(
                log_warning(
                    "[WARNING] No vcf file found with prefix:{}/{}, output empty vcf file"
                    .format(input_dir, vcf_fn_prefix)))
            compress_index_vcf(output_fn)
            print_calling_step(output_fn=output_fn)
            return

    if vcf_fn_suffix is not None:
        all_files = [
            item for item in all_files if item.endswith(vcf_fn_suffix)
        ]
        if len(all_files) == 0:
            output_header(output_fn=output_fn,
                          reference_file_path=ref_fn,
                          sample_name=sample_name)
            print(
                log_warning(
                    "[WARNING] No vcf file found with suffix:{}/{}, output empty vcf file"
                    .format(input_dir, vcf_fn_suffix)))
            compress_index_vcf(output_fn)
            print_calling_step(output_fn=output_fn)
            return

    all_contigs_list = []
    if contigs_fn and os.path.exists(contigs_fn):
        with open(contigs_fn) as f:
            all_contigs_list = [item.rstrip() for item in f.readlines()]
    else:
        exit(
            log_error("[ERROR] Cannot find contig file {}. Exit!").format(
                contigs_fn))

    contigs_order = major_contigs_order + all_contigs_list
    contigs_order_list = sorted(all_contigs_list,
                                key=lambda x: contigs_order.index(x))

    row_count = 0
    header = []
    no_vcf_output = True
    need_write_header = True

    # compress intermediate gvcf output with lz4 only; keep the final gvcf in bgzip format
    output_bgzip_gvcf = vcf_fn_suffix == '.gvcf'
    compress_gvcf = 'gvcf' in vcf_fn_suffix
    if compress_gvcf:
        lz4_path = subprocess.run("which lz4",
                                  stdout=subprocess.PIPE,
                                  shell=True).stdout.decode().rstrip()
        compress_gvcf = True if lz4_path != "" else False
    is_lz4_format = compress_gvcf
    compress_gvcf_output = compress_gvcf and not output_bgzip_gvcf
    if compress_gvcf_output:
        write_fpo = open(output_fn, 'w')
        write_proc = subprocess_popen(shlex.split("lz4 -c"),
                                      stdin=subprocess.PIPE,
                                      stdout=write_fpo,
                                      stderr=subprocess.DEVNULL)
        output = write_proc.stdin
    else:
        output = open(output_fn, 'w')

    for contig in contigs_order_list:
        contig_dict = defaultdict(str)
        contig_vcf_fns = [fn for fn in all_files if contig in fn]
        for vcf_fn in contig_vcf_fns:
            file = os.path.join(input_dir, vcf_fn)
            if is_lz4_format:
                read_proc = subprocess_popen(shlex.split("{} {}".format(
                    "lz4 -fdc", file)),
                                             stderr=subprocess.DEVNULL)
                fn = read_proc.stdout
            else:
                fn = open(file, 'r')
            for row in fn:
                row_count += 1
                if row[0] == '#':
                    # skip the phasing command line, which only occurs with --enable_phasing; otherwise it would lead to hap.py evaluation failure
                    if row.startswith('##commandline='):
                        continue
                    if row not in header:
                        header.append(row)
                    continue
                # use the first vcf header
                columns = row.strip().split(maxsplit=3)
                ctg_name, pos = columns[0], columns[1]
                # skip vcf files sharing the same contig prefix, e.g., chr1 and chr11
                if ctg_name != contig:
                    break
                contig_dict[int(pos)] = row
                no_vcf_output = False
            fn.close()
            if is_lz4_format:
                read_proc.wait()
        if need_write_header and len(header):
            if output_bgzip_gvcf:
                header = check_header_in_gvcf(header=header,
                                              contigs_list=all_contigs_list)
            output.write(''.join(header))
            need_write_header = False
        all_pos = sorted(contig_dict.keys())
        for pos in all_pos:
            output.write(contig_dict[pos])

    if compress_gvcf_output:
        write_proc.stdin.close()
        write_proc.wait()
        write_fpo.close()
        return
    else:
        output.close()

    if row_count == 0:
        print(
            log_warning("[WARNING] No vcf file found, output empty vcf file"))
        output_header(output_fn=output_fn,
                      reference_file_path=ref_fn,
                      sample_name=sample_name)
        compress_index_vcf(output_fn)
        print_calling_step(output_fn=output_fn)
        return
    if no_vcf_output:
        output_header(output_fn=output_fn,
                      reference_file_path=ref_fn,
                      sample_name=sample_name)
        print(log_warning("[WARNING] No variant found, output empty vcf file"))
        compress_index_vcf(output_fn)
        print_calling_step(output_fn=output_fn)
        return

    if vcf_fn_suffix == ".tmp.gvcf":
        return
    if vcf_fn_suffix == ".gvcf":
        print("[INFO] Need some time to compress and index GVCF file...")
    compress_index_vcf(output_fn)
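sort_vcf_from orders contigs by building a preference list that places the canonical chromosomes (major_contigs_order) first and every remaining contig after them, then sorting by index in that combined list. A small sketch of the same ordering trick, with an illustrative (shortened) major-contig list:

def ordered_contigs(contigs, major_contigs_order):
    """Sort contigs so the preferred (major) contigs come first, in their given order."""
    order = major_contigs_order + contigs
    return sorted(contigs, key=lambda contig: order.index(contig))

major = ["chr1", "chr2", "chr3"]  # illustrative; the real list covers all canonical chromosomes
print(ordered_contigs(["scaffold_7", "chr2", "chr1"], major))
# -> ['chr1', 'chr2', 'scaffold_7']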
Example #30
File: utils.py Project: Yufeng98/Clair
def get_training_array(tensor_fn,
                       var_fn,
                       bed_fn,
                       shuffle=True,
                       is_allow_duplicate_chr_pos=False):
    tree = bed_tree_from(bed_file_path=bed_fn)
    is_tree_empty = len(tree.keys()) == 0

    Y = variant_map_from(var_fn, tree, is_tree_empty)

    X = {}
    f = subprocess_popen(shlex.split("gzip -fdc %s" % (tensor_fn)))
    total = 0
    mat = np.empty(input_tensor_size, dtype=np.float32)
    for row in f.stdout:
        chrom, coord, seq, mat = unpack_a_tensor_record(*(row.split()))
        if not (is_tree_empty or is_region_in(tree, chrom, int(coord))):
            continue
        seq = seq.upper()
        if seq[param.flankingBaseNum] not in BASIC_BASES:
            continue
        key = chrom + ":" + coord

        x = np.reshape(mat, (no_of_positions, matrix_row, matrix_num))
        for i in range(1, matrix_num):
            x[:, :, i] -= x[:, :, 0]

        if key not in X:
            X[key] = np.copy(x)
        elif is_allow_duplicate_chr_pos:
            new_key = ""
            for character in PREFIX_CHAR_STR:
                tmp_key = character + key
                if tmp_key not in X:
                    new_key = tmp_key
                    break
            if len(new_key) > 0:
                X[new_key] = np.copy(x)

        is_reference = key not in Y
        if is_reference:
            Y[key] = output_labels_from_reference(
                BASE2ACGT[seq[param.flankingBaseNum]])

        total += 1
        if total % 100000 == 0:
            print("Processed %d tensors" % total, file=sys.stderr)
    f.stdout.close()
    f.wait()

    # print "[INFO] size of X: {}, size of Y: {}".format(len(X), len(Y))

    all_chr_pos = sorted(X.keys())
    if shuffle == True:
        np.random.shuffle(all_chr_pos)

    X_compressed, Y_compressed, pos_compressed = [], [], []
    X_array, Y_array, pos_array = [], [], []
    count = 0
    total = 0
    for key in all_chr_pos:
        total += 1

        X_array.append(X[key])
        del X[key]

        if key in Y:
            Y_array.append(Y[key])
            pos_array.append(key)
            if not is_allow_duplicate_chr_pos:
                del Y[key]
        elif is_allow_duplicate_chr_pos:
            tmp_key = key[1:]
            Y_array.append(Y[tmp_key])
            pos_array.append(tmp_key)

        count += 1
        if count == param.bloscBlockSize:
            X_compressed.append(blosc_pack_array(np.array(X_array)))
            Y_compressed.append(blosc_pack_array(np.array(Y_array)))
            pos_compressed.append(blosc_pack_array(np.array(pos_array)))
            X_array, Y_array, pos_array = [], [], []
            count = 0

        if total % 50000 == 0:
            print("Compressed %d/%d tensor" % (total, len(all_chr_pos)),
                  file=sys.stderr)

    if count > 0:
        X_compressed.append(blosc_pack_array(np.array(X_array)))
        Y_compressed.append(blosc_pack_array(np.array(Y_array)))
        pos_compressed.append(blosc_pack_array(np.array(pos_array)))

    return total, X_compressed, Y_compressed, pos_compressed
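get_training_array reshapes each flattened tensor record to (positions, rows, channels) and expresses every non-reference channel relative to channel 0 before batching and compression. A minimal numpy-only sketch of that transform with toy dimensions (the sizes here are illustrative, not Clair's actual no_of_positions, matrix_row, matrix_num):

import numpy as np

def normalize_against_reference_channel(flat, no_of_positions=5, matrix_row=4, matrix_num=3):
    """Reshape a flat tensor and subtract the reference channel from the other channels."""
    x = np.reshape(np.asarray(flat, dtype=np.float32),
                   (no_of_positions, matrix_row, matrix_num))
    for i in range(1, matrix_num):
        x[:, :, i] -= x[:, :, 0]
    return x

flat = np.arange(5 * 4 * 3, dtype=np.float32)
print(normalize_against_reference_channel(flat).shape)  # (5, 4, 3)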