def extract_msa(base_dname, base_fname, locus_list, min_freq, verbose):
    # Download human genome and HISAT2 index
    HISAT2_fnames = ["grch38", "genome.fa", "genome.fa.fai"]
    if not typing_common.check_files(HISAT2_fnames):
        typing_common.download_genome_and_index(ex_path)

    # Load allele frequency information
    allele_freq = {}
    if min_freq > 0.0:
        excel = openpyxl.load_workbook(
            "hisatgenotype_db/CODIS/NIST-US1036-AlleleFrequencies.xlsx")
        sheet = excel.get_sheet_by_name(u'All data, n=1036')
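        # Sheet layout: row 3 holds the locus names starting at column 2
        # (read until an empty cell); column 1 holds allele IDs from row 4 on,
        # and each locus column holds that allele's population frequency.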
        for col in range(2, 100):
            locus_name = sheet.cell(row=3, column=col).value
            if not locus_name:
                break
            locus_name = locus_name.encode('ascii', 'ignore')
            locus_name = locus_name.upper()
            assert locus_name not in allele_freq
            allele_freq[locus_name] = {}

            for row in range(4, 101):
                allele_id = sheet.cell(row=row, column=1).value
                allele_id = str(allele_id)
                freq = sheet.cell(row=row, column=col).value
                if not freq:
                    continue
                allele_freq[locus_name][allele_id] = float(freq)
        excel.close()

    CODIS_seq = orig_CODIS_seq
    if len(locus_list) > 0:
        new_CODIS_seq = {}
        for locus_name, fields in CODIS_seq.items():
            if locus_name in locus_list:
                new_CODIS_seq[locus_name] = fields
        CODIS_seq = new_CODIS_seq

    # Add some additional sequences to allele sequences to make them reasonably long for typing and assembly
    for locus_name, fields in CODIS_seq.items():
        _, left_seq, repeat_seq, right_seq = fields
        allele_seq = left_seq + repeat_seq + right_seq
        left_flank_seq, right_flank_seq = get_flanking_seqs(allele_seq)
        CODIS_seq[locus_name][1] = left_flank_seq + left_seq
        CODIS_seq[locus_name][3] = right_seq + right_flank_seq

        print >> sys.stderr, "%s is found on the reference genome (GRCh38)" % locus_name

    for locus_name in CODIS_seq.keys():
        alleles = []
        for line in open("hisatgenotype_db/CODIS/codis.dat"):
            locus_name2, allele_id, repeat_st = line.strip().split('\t')
            if locus_name != locus_name2:
                continue
            if min_freq > 0.0:
                assert locus_name in allele_freq
                if allele_id not in allele_freq[locus_name] or \
                   allele_freq[locus_name][allele_id] < min_freq:
                    continue

            alleles.append([allele_id, repeat_st])

        # From   [TTTC]3TTTTTTCT[CTTT]20CTCC[TTCC]2
        # To     [[{'TTTC'}, {3}], [{'TTTTTTCT'}, {1}], [{'CTTT'}, {20}], [{'CTCC'}, {1}], [{'TTCC'}, {2}]]
        # (each element is [set of repeat units, set of repeat counts])
        def read_allele(repeat_st):
            allele = []
            s = 0
            while s < len(repeat_st):
                ch = repeat_st[s]
                if ch == ' ':
                    s += 1
                    continue
                assert ch in "[ACGT"
                if ch == '[':
                    s += 1
                    repeat = ""
                    while s < len(repeat_st):
                        nt = repeat_st[s]
                        if nt in "ACGT":
                            repeat += nt
                            s += 1
                        else:
                            assert nt == ']'
                            s += 1
                            break
                    assert s < len(repeat_st)
                    num = 0
                    while s < len(repeat_st):
                        digit = repeat_st[s]
                        if digit.isdigit():
                            num = num * 10 + int(digit)
                            s += 1
                        else:
                            break
                    assert num > 0
                    allele.append([set([repeat]), set([num])])
                else:
                    repeat = ""
                    while s < len(repeat_st):
                        nt = repeat_st[s]
                        if nt in "ACGT":
                            repeat += nt
                            s += 1
                        else:
                            assert nt == ' ' or nt == '['
                            break
                    allele.append([set([repeat]), set([1])])

            # Sanity check
            cmp_repeat_st = ""
            for repeats, repeat_nums in allele:
                repeat = list(repeats)[0]
                repeat_num = list(repeat_nums)[0]
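                # The codis.dat entries for D8S1179 appear to bracket even
                # single-copy repeats, so mirror that format when rebuilding
                # the string for this check.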
                if repeat_num > 1 or locus_name == "D8S1179":
                    cmp_repeat_st += "["
                cmp_repeat_st += repeat
                if repeat_num > 1 or locus_name == "D8S1179":
                    cmp_repeat_st += "]%d" % repeat_num

            assert repeat_st.replace(' ', '') == cmp_repeat_st.replace(' ', '')
            return allele

        alleles = [[allele_id, read_allele(repeat_st)]
                   for allele_id, repeat_st in alleles]

        def to_sequence(repeat_st):
            sequence = ""
            for repeats, repeat_nums in repeat_st:
                repeat = list(repeats)[0]
                repeat_num = list(repeat_nums)[0]
                sequence += (repeat * repeat_num)
            return sequence

        def remove_redundant_alleles(alleles):
            seq_to_ids = {}
            new_alleles = []
            for allele_id, repeat_st in alleles:
                allele_seq = to_sequence(repeat_st)
                if allele_seq in seq_to_ids:
                    print >> sys.stderr, "Warning) %s: %s has the same sequence as %s" % \
                        (locus_name, allele_id, seq_to_ids[allele_seq])
                    continue
                seq_to_ids[allele_seq] = [allele_id]
                new_alleles.append([allele_id, repeat_st])

            return new_alleles

        alleles = remove_redundant_alleles(alleles)

        allele_seqs = [[allele_id, to_sequence(repeat_st)]
                       for allele_id, repeat_st in alleles]

        ref_allele_st, ref_allele_left, ref_allele, ref_allele_right = CODIS_seq[
            locus_name]
        ref_allele_st = read_allele(ref_allele_st)
        for allele_id, allele_seq in allele_seqs:
            if ref_allele == allele_seq:
                CODIS_ref_name[locus_name] = allele_id
                break

        # Add GRCh38 allele
        if locus_name not in CODIS_ref_name:
            allele_id = "GRCh38"
            CODIS_ref_name[locus_name] = allele_id
            allele_seqs = [[allele_id, ref_allele]] + allele_seqs
            alleles = [[allele_id, ref_allele_st]] + alleles

        print >> sys.stderr, "%s: %d alleles with reference allele as %s" % (
            locus_name, len(alleles), CODIS_ref_name[locus_name])
        if verbose:
            print >> sys.stderr, "\t", ref_allele_left, ref_allele, ref_allele_right
            for allele_id, allele in alleles:
                print >> sys.stderr, allele_id, "\t", allele

        # Create a backbone sequence
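        # The backbone is built by repeatedly merging each allele's repeat
        # structure (combine_alleles), so that every allele can later be
        # aligned against a single common sequence (msf_alignment below).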
        assert len(alleles) > 0
        backbone_allele = deepcopy(alleles[-1][1])
        for allele_id, allele_st in reversed(alleles[:-1]):
            if verbose:
                print >> sys.stderr
                print >> sys.stderr, allele_id
                print >> sys.stderr, "backbone         :", backbone_allele
                print >> sys.stderr, "allele           :", allele_st
            backbone_allele = combine_alleles(backbone_allele, allele_st)
            msf_allele_seq, msf_backbone_seq = msf_alignment(
                backbone_allele, allele_st)
            if verbose:
                print >> sys.stderr, "combined backbone:", backbone_allele
                print >> sys.stderr, "msf_allele_seq  :", msf_allele_seq
                print >> sys.stderr, "msf_backbone_seq:", msf_backbone_seq
                print >> sys.stderr

        allele_dic = {}
        for allele_id, allele_seq in allele_seqs:
            allele_dic[allele_id] = allele_seq

        allele_repeat_msf = {}
        for allele_id, allele_st in alleles:
            msf_allele_seq, msf_backbone_seq = msf_alignment(
                backbone_allele, allele_st)
            allele_repeat_msf[allele_id] = msf_allele_seq

        # Sanity check
        assert len(allele_dic) == len(allele_repeat_msf)
        repeat_len = None
        for allele_id, repeat_msf in allele_repeat_msf.items():
            if not repeat_len:
                repeat_len = len(repeat_msf)
            else:
                assert repeat_len == len(repeat_msf)

        # Create full multiple sequence alignment
        ref_allele_id = CODIS_ref_name[locus_name]
        allele_msf = {}
        for allele_id, repeat_msf in allele_repeat_msf.items():
            allele_msf[
                allele_id] = ref_allele_left + repeat_msf + ref_allele_right

        # Make sure the length of allele ID is short, less than 20 characters
        max_allele_id_len = max(
            [len(allele_id) for allele_id in allele_dic.keys()])
        assert max_allele_id_len < 20

        # Write MSF (multiple sequence alignment file)
        msf_len = len(ref_allele_left) + len(ref_allele_right) + repeat_len
        msf_fname = "%s_gen.msf" % locus_name
        msf_file = open(msf_fname, 'w')
        for s in range(0, msf_len, 50):
            for allele_id, msf in allele_msf.items():
                assert len(msf) == msf_len
                allele_name = "%s*%s" % (locus_name, allele_id)
                print >> msf_file, "%20s" % allele_name,
                for s2 in range(s, min(msf_len, s + 50), 10):
                    print >> msf_file, " %s" % msf[s2:s2 + 10],
                print >> msf_file

            if s + 50 >= msf_len:
                break
            print >> msf_file
        msf_file.close()

        # Write FASTA file
        fasta_fname = "%s_gen.fasta" % locus_name
        fasta_file = open(fasta_fname, 'w')
        for allele_id, allele_seq in allele_seqs:
            gen_seq = ref_allele_left + allele_seq + ref_allele_right
            print >> fasta_file, ">%s*%s %d bp" % (locus_name, allele_id,
                                                   len(gen_seq))
            for s in range(0, len(gen_seq), 60):
                print >> fasta_file, gen_seq[s:s + 60]
        fasta_file.close()


def build_genotype_genome(base_fname, inter_gap, intra_gap, threads,
                          database_list, use_clinvar, use_commonvar, aligner,
                          graph_index, verbose):
    # Download HISAT2 index
    typing_common.download_genome_and_index()

    # Load genomic sequences
    chr_dic, chr_names, chr_full_names = typing_common.read_genome("genome.fa")

    genotype_vars = {}
    genotype_haplotypes = {}
    genotype_clnsig = {}
    if use_clinvar:
        # Extract variants from the ClinVar database
        CLINVAR_fnames = [
            "clinvar.vcf.gz", "clinvar.snp", "clinvar.haplotype",
            "clinvar.clnsig"
        ]

        if not typing_common.check_files(CLINVAR_fnames):
            if not os.path.exists("clinvar.vcf.gz"):
                os.system("wget ftp://ftp.ncbi.nlm.nih.gov/pub/clinvar/"\
                            "vcf_GRCh38/archive/2017/clinvar_20170404.vcf.gz")
            assert os.path.exists("clinvar.vcf.gz")

            extract_cmd = ["hisat2_extract_snps_haplotypes_VCF.py"]
            extract_cmd += [
                "--inter-gap",
                str(inter_gap), "--intra-gap",
                str(intra_gap), "--genotype-vcf", "clinvar.vcf.gz",
                "genome.fa", "/dev/null", "clinvar"
            ]
            if verbose:
                print("\tRunning:", ' '.join(extract_cmd), file=sys.stderr)
            proc = subprocess.Popen(extract_cmd,
                                    stdout=open("/dev/null", 'w'),
                                    stderr=open("/dev/null", 'w'))
            proc.communicate()
            if not typing_common.check_files(CLINVAR_fnames):
                print("Error: extracting variants from ClinVar failed!",
                      file=sys.stderr)
                sys.exit(1)

        # Read variants to be genotyped
        genotype_vars = typing_common.read_variants("clinvar.snp")

        # Read haplotypes
        genotype_haplotypes = typing_common.read_haplotypes(
            "clinvar.haplotype")

        # Read information about clinical significance
        genotype_clnsig = read_clnsig("clinvar.clnsig")

    if use_commonvar:
        # Extract variants from dbSNP database
        # TODO: CB Write script to make a local up-to-date SNP database from dbSNP
        # ftp://ftp.ncbi.nlm.nih.gov/snp/database/README.create_local_dbSNP.txt
        commonvar_fbase = "snp144Common"
        commonvar_fnames = [
            "%s.snp" % commonvar_fbase,
            "%s.haplotype" % commonvar_fbase
        ]
        if not typing_common.check_files(commonvar_fnames):
            if not os.path.exists("%s.txt.gz" % commonvar_fbase):
                os.system("wget http://hgdownload.cse.ucsc.edu/goldenPath/hg38/"\
                               "database/%s.txt.gz" % commonvar_fbase)
            assert os.path.exists("%s.txt.gz" % commonvar_fbase)
            os.system("gzip -cd %s.txt.gz "\
                         "| awk 'BEGIN{OFS=\"\t\"} "\
                             "{if($2 ~ /^chr/) {$2 = substr($2, 4)}; "\
                              "if($2 == \"M\") {$2 = \"MT\"} print}' > %s.txt" \
                                  % (commonvar_fbase, commonvar_fbase))
            extract_cmd = [
                "hisat2_extract_snps_haplotypes_UCSC.py", "--inter-gap",
                str(inter_gap), "--intra-gap",
                str(intra_gap), "genome.fa",
                "%s.txt" % commonvar_fbase, commonvar_fbase
            ]
            if verbose:
                print("\tRunning:", ' '.join(extract_cmd), file=sys.stderr)
            proc = subprocess.Popen(extract_cmd,
                                    stdout=open("/dev/null", 'w'),
                                    stderr=open("/dev/null", 'w'))
            proc.communicate()
            if not typing_common.check_files(commonvar_fnames):
                print("Error: extracting common variants from dbSNP failed!",
                      file=sys.stderr)
                sys.exit(1)

        # Read variants to be genotyped
        genotype_vars = typing_common.read_variants(commonvar_fnames[0])

        # Read haplotypes
        genotype_haplotypes = typing_common.read_haplotypes(
            commonvar_fnames[1])

    # Genes to be genotyped
    genotype_genes = {}

    # Read genes or genomics regions
    for database_name in database_list:
        # Extract the database's variants, backbone sequence, and other sequences
        typing_common.extract_database_if_not_exists(
            database_name,
            [],  # locus_list
            inter_gap,
            intra_gap,
            True,  # partial?
            verbose)
        locus_fname = "%s.locus" % database_name
        assert os.path.exists(locus_fname)
        for line in open(locus_fname):
            locus_name, \
              chr, \
              left, \
              right, \
              length, \
              exon_str, \
              strand \
                   = line.strip().split()
            left = int(left)
            right = int(right)
            length = int(length)
            if chr not in chr_names:
                continue
            if chr not in genotype_genes:
                genotype_genes[chr] = []
            genotype_genes[chr].append([
                left, right, length, locus_name, database_name, exon_str,
                strand
            ])

    # Write genotype genome
    var_num = 0
    haplotype_num = 0
    genome_out_file = open("%s.fa" % base_fname, 'w')
    locus_out_file = open("%s.locus" % base_fname, 'w')
    var_out_file = open("%s.snp" % base_fname, 'w')
    index_var_out_file = open("%s.index.snp" % base_fname, 'w')
    haplotype_out_file = open("%s.haplotype" % base_fname, 'w')
    link_out_file = open("%s.link" % base_fname, 'w')
    coord_out_file = open("%s.coord" % base_fname, 'w')
    clnsig_out_file = open("%s.clnsig" % base_fname, 'w')
    for c in range(len(chr_names)):
        chr = chr_names[c]
        chr_full_name = chr_full_names[c]
        assert chr in chr_dic
        chr_seq = chr_dic[chr]
        chr_len = len(chr_seq)
        if chr in genotype_genes:
            chr_genes = genotype_genes[chr]
            chr_genes = sorted(chr_genes, key=lambda x: (x[1], x[2], x[3]))
        else:
            chr_genes = []

        chr_genotype_vars = []
        chr_genotype_vari = 0
        if graph_index:
            if chr in genotype_vars:
                chr_genotype_vars = genotype_vars[chr]
            chr_genotype_haplotypes = []
            chr_genotype_hti = 0
            if chr in genotype_haplotypes:
                chr_genotype_haplotypes = genotype_haplotypes[chr]

        def add_vars(left, right, chr_genotype_vari, chr_genotype_hti,
                     haplotype_num):
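            # Stream through the genome-wide variant and haplotype lists up to
            # the current gene region, writing entries with the cumulative
            # coordinate offset `off`, and return the updated cursors.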
            # Output genome-wide variants located before the current gene
            # region (entries overlapping it are skipped), along with any
            # clinical significance annotation they may have
            while chr_genotype_vari < len(chr_genotype_vars):
                var_left, \
                  var_type, \
                  var_data, \
                  var_id \
                    = chr_genotype_vars[chr_genotype_vari]
                var_right = var_left
                if var_type == "deletion":
                    var_right += var_data
                if var_right > right:
                    break
                if var_right >= left:
                    chr_genotype_vari += 1
                    continue

                out_str = "%s\t%s\t%s\t%d\t%s" % (var_id, var_type, chr,
                                                  var_left + off, var_data)
                print(out_str, file=var_out_file)
                print(out_str, file=index_var_out_file)

                if var_id in genotype_clnsig:
                    var_gene, clnsig = genotype_clnsig[var_id]
                    print("%s\t%s\t%s" \
                             % (var_id, var_gene, clnsig), file=clnsig_out_file)

                chr_genotype_vari += 1

            # Output haplotypes
            while chr_genotype_hti < len(chr_genotype_haplotypes):
                ht_left, ht_right, ht_vars = chr_genotype_haplotypes[
                    chr_genotype_hti]
                if ht_right > right:
                    break
                if ht_right >= left:
                    chr_genotype_hti += 1
                    continue

                print("ht%d\t%s\t%d\t%d\t%s" \
                        % (haplotype_num,
                           chr,
                           ht_left + off,
                           ht_right + off,
                           ','.join(ht_vars)),
                      file=haplotype_out_file)
                chr_genotype_hti += 1
                haplotype_num += 1

            return chr_genotype_vari, chr_genotype_hti, haplotype_num

        out_chr_seq = ""
        off = 0
        prev_right = 0
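        # `off` accumulates how much longer the backbone allele sequences are
        # than the GRCh38 regions they replace; it converts GRCh38 coordinates
        # into genotype-genome coordinates.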
        for gene in chr_genes:
            left, right, length, name, family, exon_str, strand = gene

            if not graph_index:
                # Output gene (genotype_genome.locus)
                print("%s\t%s\t%s\t%d\t%d\t%s\t%s" \
                        % (family.upper(),
                           name,
                           chr,
                           left,
                           right,
                           exon_str,
                           strand),
                      file=locus_out_file)
                continue

            chr_genotype_vari, \
              chr_genotype_hti, \
              haplotype_num \
                = add_vars(left,
                           right,
                           chr_genotype_vari,
                           chr_genotype_hti,
                           haplotype_num)

            # Read gene family sequences and information
            allele_seqs = typing_common.read_allele_seq("%s_backbone.fa" %
                                                        family)
            allele_vars = typing_common.read_variants("%s.snp" % family)
            allele_index_vars = typing_common.read_variants("%s.index.snp" %
                                                            family)
            allele_haplotypes = typing_common.read_haplotypes("%s.haplotype" %
                                                              family)
            links = typing_common.read_links("%s.link" % family, True)

            if name not in allele_seqs:
                continue
            if name not in allele_vars or name not in allele_index_vars:
                vars = []
                index_vars = []
            else:
                vars = allele_vars[name]
                index_vars = allele_index_vars[name]

            allele_seq = allele_seqs[name]
            index_var_ids = set()
            for _, _, _, var_id in index_vars:
                index_var_ids.add(var_id)

            if name not in allele_haplotypes:
                haplotypes = []
            else:
                haplotypes = allele_haplotypes[name]
            assert length == len(allele_seq)
            assert left < chr_len and right < chr_len
            # Skipping overlapping genes
            if left < prev_right:
                print("Warning: skipping %s ..." % (name), file=sys.stderr)
                continue

            varID2htID = {}
            assert left < right
            prev_length = right - left + 1
            assert prev_length <= length

            if prev_right < left:
                out_chr_seq += chr_seq[prev_right:left]

            # Output gene (genotype_genome.locus)
            print("%s\t%s\t%s\t%d\t%d\t%s\t%s" \
                    % (family.upper(),
                       name,
                       chr,
                       len(out_chr_seq),
                       len(out_chr_seq) + length - 1,
                       exon_str,
                       strand),
                  file=locus_out_file)

            # Output coord (genotype_genome.coord)
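            # Each line: chromosome, offset in the genotype genome, offset in
            # GRCh38, and the length of the corresponding GRCh38 block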
            print("%s\t%d\t%d\t%d" \
                    % (chr,
                       len(out_chr_seq),
                       left,
                       right - left + 1),
                  file=coord_out_file)
            out_chr_seq += allele_seq

            # Output variants (genotype_genome.snp and genotype_genome.index.snp)
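            # Gene-local variants are renamed to genome-wide "hv<N>" IDs and
            # shifted into genotype-genome coordinates; varID2htID records the
            # mapping for the haplotype and link output below.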
            for var in vars:
                var_left, var_type, var_data, var_id = var
                new_var_id = "hv%d" % var_num
                varID2htID[var_id] = new_var_id
                new_var_left = var_left + left + off
                assert var_type in ["single", "deletion", "insertion"]
                assert new_var_left < len(out_chr_seq)
                if var_type == "single":
                    assert out_chr_seq[new_var_left] != var_data
                elif var_type == "deletion":
                    assert new_var_left + var_data <= len(out_chr_seq)
                else:
                    assert var_type == "insertion"

                out_str = "%s\t%s\t%s\t%d\t%s" \
                            % (new_var_id, var_type, chr, new_var_left, var_data)
                print(out_str, file=var_out_file)
                if var_id in index_var_ids:
                    print(out_str, file=index_var_out_file)
                var_num += 1

            # Output haplotypes (genotype_genome.haplotype)
            for haplotype in haplotypes:
                ht_left, ht_right, ht_vars = haplotype
                new_ht_left = ht_left + left + off
                assert new_ht_left < len(out_chr_seq)
                new_ht_right = ht_right + left + off
                assert new_ht_left <= new_ht_right
                assert new_ht_right <= len(out_chr_seq)
                new_ht_vars = []
                for var_id in ht_vars:
                    assert var_id in varID2htID
                    new_ht_vars.append(varID2htID[var_id])
                print("ht%d\t%s\t%d\t%d\t%s" \
                        % (haplotype_num,
                           chr,
                           new_ht_left,
                           new_ht_right,
                           ','.join(new_ht_vars)),
                      file=haplotype_out_file)
                haplotype_num += 1

            # Output link information between alleles and variants (genotype_genome.link)
            for link in links:
                var_id, allele_names = link
                if var_id not in varID2htID:
                    continue
                new_var_id = varID2htID[var_id]
                print("%s\t%s" % (new_var_id, " ".join(allele_names)),
                      file=link_out_file)

            off += (length - prev_length)
            prev_right = right + 1

        if not graph_index:
            continue

        # Write the remaining genome-wide variants and haplotypes
        chr_genotype_vari, \
          chr_genotype_hti, \
          haplotype_num \
            = add_vars(sys.maxsize,
                       sys.maxsize,
                       chr_genotype_vari,
                       chr_genotype_hti,
                       haplotype_num)

        print("%s\t%d\t%d\t%d" \
                % (chr,
                   len(out_chr_seq),
                   prev_right,
                   len(chr_seq) - prev_right),
              file=coord_out_file)
        out_chr_seq += chr_seq[prev_right:]

        assert len(out_chr_seq) == len(chr_seq) + off

        # Output chromosome sequence
        print(">%s" % (chr_full_name), file=genome_out_file)
        line_width = 60
        for s in range(0, len(out_chr_seq), line_width):
            print(out_chr_seq[s:s + line_width], file=genome_out_file)

    genome_out_file.close()
    locus_out_file.close()
    var_out_file.close()
    index_var_out_file.close()
    haplotype_out_file.close()
    link_out_file.close()
    coord_out_file.close()
    clnsig_out_file.close()

    allele_out_file = open("%s.allele" % base_fname, 'w')
    if graph_index:
        for database in database_list:
            for line in open("%s.allele" % database):
                allele_name = line.strip()
                print("%s\t%s" % (database.upper(), allele_name),
                      file=allele_out_file)
    allele_out_file.close()

    partial_out_file = open("%s.partial" % base_fname, 'w')
    if graph_index:
        for database in database_list:
            for line in open("%s.partial" % database):
                allele_name = line.strip()
                print("%s\t%s" % (database.upper(), allele_name),
                      file=partial_out_file)
    partial_out_file.close()

    if not graph_index:
        shutil.copyfile("genome.fa", "%s.fa" % base_fname)

    # Index genotype_genome.fa
    index_cmd = ["samtools", "faidx", "%s.fa" % base_fname]
    subprocess.call(index_cmd)

    # Build indexes based on the above information
    if graph_index:
        assert aligner == "hisat2"
        build_cmd = [
            "hisat2-build", "-p",
            str(threads), "--snp",
            "%s.index.snp" % base_fname, "--haplotype",
            "%s.haplotype" % base_fname,
            "%s.fa" % base_fname,
            "%s" % base_fname
        ]
    else:
        assert aligner in ["hisat2", "bowtie2"]
        build_cmd = [
            "%s-build" % aligner, "-p" if aligner == "hisat2" else "--threads",
            str(threads),
            "%s.fa" % base_fname,
            "%s" % base_fname
        ]
    if verbose:
        print("\tRunning:", ' '.join(build_cmd), file=sys.stderr)

    subprocess.call(build_cmd,
                    stdout=open("/dev/null", 'w'),
                    stderr=open("/dev/null", 'w'))

    if aligner == "hisat2":
        index_fnames = ["%s.%d.ht2" % (base_fname, i + 1) for i in range(8)]
    else:
        index_fnames = ["%s.%d.bt2" % (base_fname, i + 1) for i in range(4)]
        index_fnames += [
            "%s.rev.%d.bt2" % (base_fname, i + 1) for i in range(2)
        ]
    if not typing_common.check_files(index_fnames):
        print("Error: indexing failed! "\
               "Perhaps you forgot to build the %s executables?" \
                    % aligner,
              file=sys.stderr)
        sys.exit(1)


def genotype(base_fname,
             fastq,
             read_fnames,
             threads,
             num_mismatch,
             verbose,
             daehwan_debug):
    # Load genomic sequences
    chr_dic, chr_names, chr_full_names = typing_common.read_genome(open("%s.fa" % base_fname))

    # Variants, backbone sequence, and other sequences
    genotype_fnames = ["%s.fa" % base_fname,
                       "%s.gene" % base_fname,
                       "%s.snp" % base_fname,
                       "%s.index.snp" % base_fname,
                       "%s.haplotype" % base_fname,
                       "%s.link" % base_fname,
                       "%s.coord" % base_fname,
                       "%s.clnsig" % base_fname]
    # hisat2 graph index files
    genotype_fnames += ["%s.%d.ht2" % (base_fname, i+1) for i in range(8)]
    if not typing_common.check_files(genotype_fnames):
        print >> sys.stderr, "Error: some of the following files are missing!"
        for fname in genotype_fnames:
            print >> sys.stderr, "\t%s" % fname
        sys.exit(1)

    # Align reads, and sort the alignments into a BAM file
    align_reads(base_fname,
                read_fnames,
                fastq,
                threads,
                verbose)

    # Read HLA alleles (names and sequences)
    genes, gene_loci, gene_seqs = {}, {}, {}
    for line in open("%s.gene" % base_fname):
        family, allele_name, chr, left, right = line.strip().split()
        gene_name = "%s-%s" % (family, allele_name.split('*')[0])
        assert gene_name not in genes
        genes[gene_name] = allele_name
        left, right = int(left), int(right)
        """
        exons = []
        for exon in exon_str.split(','):
            exon_left, exon_right = exon.split('-')
            exons.append([int(exon_left), int(exon_right)])
        """
        gene_loci[gene_name] = [allele_name, chr, left, right]
        assert chr in chr_dic
        chr_seq = chr_dic[chr]
        assert left < right
        assert right < len(chr_seq)
        gene_seqs[gene_name] = chr_dic[chr][left:right+1]

    # Read link information
    Links, var_genes, allele_vars = {}, {}, {}
    for line in open("%s.link" % base_fname):
        var_id, alleles = line.strip().split('\t')
        alleles = alleles.split()
        assert not var_id in Links
        Links[var_id] = alleles
        for allele in alleles:
            if allele not in allele_vars:
                allele_vars[allele] = set()
            allele_vars[allele].add(var_id)
            gene_name = "HLA-%s" % (allele.split('*')[0])
            var_genes[var_id] = gene_name

    # gene alleles
    allele_names = {}
    for gene_name in genes.keys():
        if gene_name not in allele_names:
            allele_names[gene_name] = []
        gene_name2 = gene_name.split('-')[1]
        for allele_name in allele_vars.keys():
            allele_name1 = allele_name.split('*')[0]
            if gene_name2 == allele_name1:
                allele_names[gene_name].append(allele_name)


    # Read HLA variants, and link information
    Vars, Var_list = {}, {}
    for line in open("%s.snp" % base_fname):
        var_id, var_type, chr, pos, data = line.strip().split('\t')
        pos = int(pos)

        # daehwan - for debugging purposes
        if var_id not in var_genes:
            continue
        
        assert var_id in var_genes
        gene_name = var_genes[var_id]
        if not gene_name in Vars:
            Vars[gene_name] = {}
            assert not gene_name in Var_list
            Var_list[gene_name] = []
            
        assert not var_id in Vars[gene_name]
        Vars[gene_name][var_id] = [var_type, pos, data]
        Var_list[gene_name].append([pos, var_id])

    for gene_name, in_var_list in Var_list.items():
        Var_list[gene_name] = sorted(in_var_list)
    def lower_bound(Var_list, pos):
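        # Binary search over a list of (pos, var_id) pairs sorted by position:
        # return the index of the first variant at or after `pos`.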
        low, high = 0, len(Var_list)
        while low < high:
            m = (low + high) // 2
            m_pos = Var_list[m][0]
            if m_pos < pos:
                low = m + 1
            elif m_pos > pos:
                high = m
            else:
                assert m_pos == pos
                while m > 0:
                    if Var_list[m-1][0] < pos:
                        break
                    m -= 1
                return m
        return low       
           
    
    # HLA gene allele lengths
    """
    HLA_lengths = {}
    for HLA_gene, HLA_alleles in HLAs.items():
        HLA_lengths[HLA_gene] = {}
        for allele_name, seq in HLA_alleles.items():
            HLA_lengths[HLA_gene][allele_name] = len(seq)
    """

    # Cigar regular expression
    cigar_re = re.compile(r'\d+\w')

    test_list = [[sorted(genes.keys())]]
    for test_i in range(len(test_list)):
        test_HLA_list = test_list[test_i]
        for test_HLA_names in test_HLA_list:
            print >> sys.stderr, "\t%s" % (test_HLA_names)
            for gene in test_HLA_names:
                ref_allele = genes[gene]
                ref_seq = gene_seqs[gene]
                # ref_exons = refHLA_loci[gene][-1]

                # Read alignments
                alignview_cmd = ["samtools",
                                 "view"]
                alignview_cmd += ["hla_output.bam"]
                base_locus = 0
                _, chr, left, right = gene_loci[gene]
                base_locus = left
                alignview_cmd += ["%s:%d-%d" % (chr, left + 1, right + 1)]

                bamview_proc = subprocess.Popen(alignview_cmd,
                                                stdout=subprocess.PIPE,
                                                stderr=open("/dev/null", 'w'))

                sort_read_cmd = ["sort", "-k", "1", "-n"]
                alignview_proc = subprocess.Popen(sort_read_cmd,
                                                  stdin=bamview_proc.stdout,
                                                  stdout=subprocess.PIPE,
                                                  stderr=open("/dev/null", 'w'))

                # Count alleles
                HLA_counts, HLA_cmpt = {}, {}
                coverage = [0 for i in range(len(ref_seq) + 1)]
                num_reads, total_read_len = 0, 0
                prev_read_id = None
                prev_exon = False
                for line in alignview_proc.stdout:
                    cols = line.strip().split()
                    read_id, flag, chr, pos, mapQ, cigar_str = cols[:6]
                    origin_read_id = read_id
                    if read_id.find('|') != -1:
                        tmp_read_id = read_id.split('|')[0]
                        try:
                            read_id = int(tmp_read_id)
                        except ValueError:
                            pass
                       
                    read_seq, qual = cols[9], cols[10]
                    num_reads += 1
                    total_read_len += len(read_seq)
                    flag, pos = int(flag), int(pos)
                    pos -= 1
                    if pos < 0:
                        continue

                    if flag & 0x4 != 0:
                        continue

                    NM, Zs, MD = "", "", ""
                    for i in range(11, len(cols)):
                        col = cols[i]
                        if col.startswith("Zs"):
                            Zs = col[5:]
                        elif col.startswith("MD"):
                            MD = col[5:]
                        elif col.startswith("NM"):
                            NM = int(col[5:])

                    if NM > num_mismatch:
                        continue

                    # daehwan - for debugging purposes
                    debug = False
                    if read_id in ["2339"] and False:
                        debug = True
                        print "read_id: %s)" % read_id, pos, cigar_str, "NM:", NM, MD, Zs
                        print "            ", read_seq

                    vars = []
                    if Zs:
                        vars = Zs.split(',')

                    assert MD != ""
                    MD_str_pos, MD_len = 0, 0
                    read_pos, left_pos = 0, pos
                    right_pos = left_pos
                    cigars = cigar_re.findall(cigar_str)
                    cigars = [[cigar[-1], int(cigar[:-1])] for cigar in cigars]
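                    # Walk the CIGAR and MD strings together to build cmp_list:
                    # (op, reference position, length) entries with op one of
                    # match/mismatch/insertion/deletion/soft/intron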
                    cmp_list = []
                    for i in range(len(cigars)):
                        cigar_op, length = cigars[i]
                        if cigar_op == 'M':
                            first = True
                            MD_len_used = 0
                            while True:
                                if not first or MD_len == 0:
                                    if MD[MD_str_pos].isdigit():
                                        num = int(MD[MD_str_pos])
                                        MD_str_pos += 1
                                        while MD_str_pos < len(MD):
                                            if MD[MD_str_pos].isdigit():
                                                num = num * 10 + int(MD[MD_str_pos])
                                                MD_str_pos += 1
                                            else:
                                                break
                                        MD_len += num
                                # An insertion or a full match follows
                                if MD_len >= length:
                                    MD_len -= length
                                    cmp_list.append(["match", right_pos + MD_len_used, length - MD_len_used])
                                    break
                                first = False
                                read_base = read_seq[read_pos + MD_len]
                                MD_ref_base = MD[MD_str_pos]
                                MD_str_pos += 1
                                assert MD_ref_base in "ACGT"
                                cmp_list.append(["match", right_pos + MD_len_used, MD_len - MD_len_used])
                                cmp_list.append(["mismatch", right_pos + MD_len, 1])
                                MD_len_used = MD_len + 1
                                MD_len += 1
                                # Full match
                                if MD_len == length:
                                    MD_len = 0
                                    break
                        elif cigar_op == 'I':
                            cmp_list.append(["insertion", right_pos, length])
                        elif cigar_op == 'D':
                            if MD[MD_str_pos] == '0':
                                MD_str_pos += 1
                            assert MD[MD_str_pos] == '^'
                            MD_str_pos += 1
                            while MD_str_pos < len(MD):
                                if not MD[MD_str_pos] in "ACGT":
                                    break
                                MD_str_pos += 1
                            cmp_list.append(["deletion", right_pos, length])
                        elif cigar_op == 'S':
                            cmp_list.append(["soft", right_pos, length])
                        else:                    
                            assert cigar_op == 'N'
                            cmp_list.append(["intron", right_pos, length])

                        if cigar_op in "MND":
                            right_pos += length

                        if cigar_op in "MIS":
                            read_pos += length

                    """
                    exon = False
                    for exon in ref_exons:
                        exon_left, exon_right = exon
                        if right_pos <= exon_left or pos > exon_right:
                            continue
                        else:
                            exon = True
                            break
                    """

                    if left_pos < base_locus or \
                            right_pos - base_locus > len(ref_seq):
                        continue
                
                    def add_stat(HLA_cmpt, HLA_counts, HLA_count_per_read, exon = True):
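                        # The alleles with the maximum count for this read form
                        # its compatibility class (cur_cmpt); HLA_counts tallies
                        # per-allele best hits and HLA_cmpt tallies reads per
                        # class for the abundance estimation below.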
                        max_count = max(HLA_count_per_read.values())
                        cur_cmpt = set()
                        for allele, count in HLA_count_per_read.items():
                            if count < max_count:
                                continue
                            """
                            if allele in exclude_allele_list:
                                continue
                            """
                            cur_cmpt.add(allele)                    
                            if not allele in HLA_counts:
                                HLA_counts[allele] = 1
                            else:
                                HLA_counts[allele] += 1

                        if len(cur_cmpt) == 0:
                            return

                        # daehwan - for debugging purposes                            
                        alleles = ["", ""]
                        # alleles = ["B*40:304", "B*40:02:01"]
                        allele1_found, allele2_found = False, False
                        for allele, count in HLA_count_per_read.items():
                            if count < max_count:
                                continue
                            if allele == alleles[0]:
                                allele1_found = True
                            elif allele == alleles[1]:
                                allele2_found = True
                        if allele1_found != allele2_found:
                            print alleles[0], HLA_count_per_read[alleles[0]]
                            print alleles[1], HLA_count_per_read[alleles[1]]
                            if allele1_found:
                                print ("%s\tread_id %s - %d vs. %d]" % (alleles[0], prev_read_id, max_count, HLA_count_per_read[alleles[1]]))
                            else:
                                print ("%s\tread_id %s - %d vs. %d]" % (alleles[1], prev_read_id, max_count, HLA_count_per_read[alleles[0]]))
                            print read_seq

                        cur_cmpt = sorted(list(cur_cmpt))
                        cur_cmpt = '-'.join(cur_cmpt)
                        add = 1
                        """
                        if partial and not exon:
                            add *= 0.2
                        """
                        if not cur_cmpt in HLA_cmpt:
                            HLA_cmpt[cur_cmpt] = add
                        else:
                            HLA_cmpt[cur_cmpt] += add

                    if read_id != prev_read_id:
                        if prev_read_id != None:
                            add_stat(HLA_cmpt, HLA_counts, HLA_count_per_read, prev_exon)

                        HLA_count_per_read = {}
                        for HLA_name in allele_names[gene]:
                            if HLA_name.find("BACKBONE") != -1:
                                continue
                            HLA_count_per_read[HLA_name] = 0

                    def add_count(var_id, add):
                        assert var_id in Links
                        alleles = Links[var_id]
                        for allele in alleles:
                            if allele.find("BACKBONE") != -1:
                                continue
                            HLA_count_per_read[allele] += add
                            # daehwan - for debugging purposes
                            if debug:
                                if allele in ["DQA1*05:05:01:01", "DQA1*05:05:01:02"]:
                                    print allele, add, var_id

                    # Decide which allele(s) a read most likely came from
                    # also sanity check - read length, cigar string, and MD string
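                    # A read aligned entirely within a known deletion is
                    # evidence against the alleles carrying that deletion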
                    for var_id, data in Vars[gene].items():
                        var_type, var_pos, var_data = data
                        if var_type != "deletion":
                            continue
                        if left_pos >= var_pos and right_pos <= var_pos + int(var_data):
                            add_count(var_id, -1)                            
                    ref_pos, read_pos, cmp_cigar_str, cmp_MD = left_pos, 0, "", ""
                    cigar_match_len, MD_match_len = 0, 0            
                    for cmp in cmp_list:
                        type = cmp[0]
                        length = cmp[2]
                        if type == "match":
                            var_idx = lower_bound(Var_list[gene], ref_pos)
                            while var_idx < len(Var_list[gene]):
                                var_pos, var_id = Var_list[gene][var_idx]
                                if ref_pos + length <= var_pos:
                                    break
                                if ref_pos <= var_pos:
                                    var_type, _, var_data = Vars[gene][var_id]
                                    if var_type == "insertion":
                                        if ref_pos < var_pos and ref_pos + length > var_pos + len(var_data):
                                            add_count(var_id, -1)
                                            # daehwan - for debugging purposes
                                            if debug:
                                                print cmp, var_id, Links[var_id]
                                    elif var_type == "deletion":
                                        del_len = int(var_data)
                                        if ref_pos < var_pos and ref_pos + length > var_pos + del_len:
                                            # daehwan - for debugging purposes
                                            if debug:
                                                print cmp, var_id, Links[var_id], -1, Vars[gene][var_id]
                                            # Check if this might be one of the two tandem repeats (the same left coordinate)
                                            cmp_left, cmp_right = cmp[1], cmp[1] + cmp[2]
                                            test1_seq1 = ref_seq[cmp_left-base_locus:cmp_right-base_locus]
                                            test1_seq2 = ref_seq[cmp_left-base_locus:var_pos-base_locus] + ref_seq[var_pos + del_len - base_locus:cmp_right + del_len - base_locus]
                                            # Check if this happens due to small repeats (the same right coordinate - e.g. 19 times of TTTC in DQA1*05:05:01:02)
                                            cmp_left -= read_pos
                                            cmp_right += (len(read_seq) - read_pos - cmp[2])
                                            test2_seq1 = ref_seq[cmp_left+int(var_data)-base_locus:cmp_right-base_locus]
                                            test2_seq2 = ref_seq[cmp_left-base_locus:var_pos-base_locus] + ref_seq[var_pos+int(var_data)-base_locus:cmp_right-base_locus]
                                            if test1_seq1 != test1_seq2 and test2_seq1 != test2_seq2:
                                                add_count(var_id, -1)
                                    else:
                                        if debug:
                                            print cmp, var_id, Links[var_id], -1
                                        add_count(var_id, -1)
                                var_idx += 1

                            read_pos += length
                            ref_pos += length
                            cigar_match_len += length
                            MD_match_len += length
                        elif type == "mismatch":
                            read_base = read_seq[read_pos]
                            var_idx = lower_bound(Var_list[gene], ref_pos)
                            while var_idx < len(Var_list[gene]):
                                var_pos, var_id = Var_list[gene][var_idx]
                                if ref_pos < var_pos:
                                    break
                                if ref_pos == var_pos:
                                    var_type, _, var_data = Vars[gene][var_id]
                                    if var_type == "single":
                                        if var_data == read_base:
                                            # daehwan - for debugging purposes
                                            if debug:
                                                print cmp, var_id, 1, var_data, read_base, Links[var_id]

                                            # daehwan - for debugging purposes
                                            if False:
                                                read_qual = ord(qual[read_pos])
                                                add_count(var_id, (read_qual - 60) / 60.0)
                                            else:
                                                add_count(var_id, 1)
                                        # daehwan - check out if this routine is appropriate
                                        # else:
                                        #    add_count(var_id, -1)
                                var_idx += 1
                            cmp_MD += ("%d%s" % (MD_match_len, ref_seq[ref_pos-base_locus]))
                            MD_match_len = 0
                            cigar_match_len += 1
                            read_pos += 1
                            ref_pos += 1
                        elif type == "insertion":
                            ins_seq = read_seq[read_pos:read_pos+length]
                            var_idx = lower_bound(Var_list[gene], ref_pos)
                            # daehwan - for debugging purposes
                            if debug:
                                print left_pos, cigar_str, MD, vars
                                print ref_pos, ins_seq, Var_list[gene][var_idx], Vars[gene][Var_list[gene][var_idx][1]]
                                # sys.exit(1)
                            while var_idx < len(Var_list[gene]):
                                var_pos, var_id = Var_list[gene][var_idx]
                                if ref_pos < var_pos:
                                    break
                                if ref_pos == var_pos:
                                    var_type, _, var_data = Vars[gene][var_id]
                                    if var_type == "insertion":                                
                                        if var_data == ins_seq:
                                            # daehwan - for debugging purposes
                                            if debug:
                                                print cmp, var_id, 1, Links[var_id]
                                            add_count(var_id, 1)
                                var_idx += 1

                            if cigar_match_len > 0:
                                cmp_cigar_str += ("%dM" % cigar_match_len)
                                cigar_match_len = 0
                            read_pos += length
                            cmp_cigar_str += ("%dI" % length)
                        elif type == "deletion":
                            del_len = length
                            # Deletions can be shifted bidirectionally
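                            # (a deletion inside a repeat can be reported at
                            # several equivalent positions, so shift it left to
                            # a canonical position before matching known
                            # variants)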
                            temp_ref_pos = ref_pos
                            while temp_ref_pos > 0:
                                last_bp = ref_seq[temp_ref_pos + del_len - 1 - base_locus]
                                prev_bp = ref_seq[temp_ref_pos - 1 - base_locus]
                                if last_bp != prev_bp:
                                    break
                                temp_ref_pos -= 1
                            var_idx = lower_bound(Var_list[gene], temp_ref_pos)
                            while var_idx < len(Var_list[gene]):
                                var_pos, var_id = Var_list[gene][var_idx]
                                if temp_ref_pos < var_pos:
                                    first_bp = ref_seq[temp_ref_pos - base_locus]
                                    next_bp = ref_seq[temp_ref_pos + del_len - base_locus]
                                    if first_bp == next_bp:
                                        temp_ref_pos += 1
                                        continue
                                    else:
                                        break
                                if temp_ref_pos == var_pos:
                                    var_type, _, var_data = Vars[gene][var_id]
                                    if var_type == "deletion":
                                        var_len = int(var_data)
                                        if var_len == length:
                                            if debug:
                                                print cmp, var_id, 1, Links[var_id]
                                                print ref_seq[var_pos - 10-base_locus:var_pos-base_locus], ref_seq[var_pos-base_locus:var_pos+int(var_data)-base_locus], ref_seq[var_pos+int(var_data)-base_locus:var_pos+int(var_data)+10-base_locus]
                                            add_count(var_id, 1)
                                var_idx += 1

                            if cigar_match_len > 0:
                                cmp_cigar_str += ("%dM" % cigar_match_len)
                                cigar_match_len = 0
                            cmp_MD += ("%d" % MD_match_len)
                            MD_match_len = 0
                            cmp_cigar_str += ("%dD" % length)
                            cmp_MD += ("^%s" % ref_seq[ref_pos-base_locus:ref_pos+length-base_locus])
                            ref_pos += length
                        elif type == "soft":
                            if cigar_match_len > 0:
                                cmp_cigar_str += ("%dM" % cigar_match_len)
                                cigar_match_len = 0
                            read_pos += length
                            cmp_cigar_str += ("%dS" % length)
                        else:
                            assert type == "intron"
                            if cigar_match_len > 0:
                                cmp_cigar_str += ("%dM" % cigar_match_len)
                                cigar_match_len = 0
                            cmp_cigar_str += ("%dN" % length)
                            ref_pos += length                    
                    if cigar_match_len > 0:
                        cmp_cigar_str += ("%dM" % cigar_match_len)
                    cmp_MD += ("%d" % MD_match_len)
                    if read_pos != len(read_seq) or \
                            cmp_cigar_str != cigar_str or \
                            cmp_MD != MD:
                        print >> sys.stderr, "Error:", cigar_str, MD
                        print >> sys.stderr, "\tcomputed:", cmp_cigar_str, cmp_MD
                        print >> sys.stderr, "\tcmp list:", cmp_list
                        assert False            

                    prev_read_id = read_id
                    # prev_exon = exon

                if num_reads <= 0:
                    continue

                if prev_read_id != None:
                    add_stat(HLA_cmpt, HLA_counts, HLA_count_per_read)

                HLA_counts = [[allele, count] for allele, count in HLA_counts.items()]
                def HLA_count_cmp(a, b):
                    if a[1] != b[1]:
                        return b[1] - a[1]
                    assert a[0] != b[0]
                    if a[0] < b[0]:
                        return -1
                    else:
                        return 1
                HLA_counts = sorted(HLA_counts, cmp=HLA_count_cmp)
                for count_i in range(len(HLA_counts)):
                    count = HLA_counts[count_i]
                    print >> sys.stderr, "\t\t\t\t%d %s (count: %d)" % (count_i + 1, count[0], count[1])
                    if count_i >= 9:
                        break
                print >> sys.stderr

                def normalize(prob):
                    total = sum(prob.values())
                    for allele, mass in prob.items():
                        prob[allele] = mass / total

                def normalize2(prob, length):
                    total = 0
                    for allele, mass in prob.items():
                        assert allele in length
                        total += (mass / length[allele])
                    for allele, mass in prob.items():
                        assert allele in length
                        prob[allele] = mass / length[allele] / total

                def prob_diff(prob1, prob2):
                    diff = 0.0
                    for allele in prob1.keys():
                        if allele in prob2:
                            diff += abs(prob1[allele] - prob2[allele])
                        else:
                            diff += prob1[allele]
                    return diff

                def HLA_prob_cmp(a, b):
                    if a[1] != b[1]:
                        if a[1] < b[1]:
                            return 1
                        else:
                            return -1
                    assert a[0] != b[0]
                    if a[0] < b[0]:
                        return -1
                    else:
                        return 1

                HLA_prob, HLA_prob_next = {}, {}
                for cmpt, count in HLA_cmpt.items():
                    alleles = cmpt.split('-')
                    for allele in alleles:
                        if allele not in HLA_prob:
                            HLA_prob[allele] = 0.0
                        HLA_prob[allele] += (float(count) / len(alleles))

                """
                assert gene in HLA_lengths
                HLA_length = HLA_lengths[gene]
                """
                HLA_length = {}
                
                # normalize2(HLA_prob, HLA_length)
                normalize(HLA_prob)
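                # EM-style abundance update: each compatibility class (the set of
                # alleles a group of reads is consistent with) redistributes its
                # read count among those alleles in proportion to their current
                # abundance estimates, and the estimates are then renormalized.
                # The loop below repeats this until the total absolute change
                # drops below 1e-4 or 1,000 iterations have been performed.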
                def next_prob(HLA_cmpt, HLA_prob, HLA_length):
                    HLA_prob_next = {}
                    for cmpt, count in HLA_cmpt.items():
                        alleles = cmpt.split('-')
                        alleles_prob = 0.0
                        for allele in alleles:
                            assert allele in HLA_prob
                            alleles_prob += HLA_prob[allele]
                        for allele in alleles:
                            if allele not in HLA_prob_next:
                                HLA_prob_next[allele] = 0.0
                            HLA_prob_next[allele] += (float(count) * HLA_prob[allele] / alleles_prob)
                    # normalize2(HLA_prob_next, HLA_length)
                    normalize(HLA_prob_next)
                    return HLA_prob_next

                diff, iter = 1.0, 0
                while diff > 0.0001 and iter < 1000:
                    HLA_prob_next = next_prob(HLA_cmpt, HLA_prob, HLA_length)
                    diff = prob_diff(HLA_prob, HLA_prob_next)
                    HLA_prob = HLA_prob_next
                    iter += 1

                """
                for allele, prob in HLA_prob.items():
                    allele_len = len(HLAs[gene][allele])
                    HLA_prob[allele] /= float(allele_len)
                normalize(HLA_prob)
                """
                HLA_prob = [[allele, prob] for allele, prob in HLA_prob.items()]

                HLA_prob = sorted(HLA_prob, cmp=HLA_prob_cmp)
                success = [False for i in range(len(test_HLA_names))]
                found_list = [False for i in range(len(test_HLA_names))]
                for prob_i in range(len(HLA_prob)):
                    prob = HLA_prob[prob_i]
                    print >> sys.stderr, "\t\t\t\t%d ranked %s (abundance: %.2f%%)" % (prob_i + 1, prob[0], prob[1] * 100.0)
                    if prob_i >= 9:
                        break
                print >> sys.stderr

                """
                if len(test_HLA_names) == 2:
                    HLA_prob, HLA_prob_next = {}, {}
                    for cmpt, count in HLA_cmpt.items():
                        alleles = cmpt.split('-')
                        for allele1 in alleles:
                            for allele2 in HLA_names[gene]:
                                if allele1 < allele2:
                                    allele_pair = "%s-%s" % (allele1, allele2)
                                else:
                                    allele_pair = "%s-%s" % (allele2, allele1)
                                if not allele_pair in HLA_prob:
                                    HLA_prob[allele_pair] = 0.0
                                HLA_prob[allele_pair] += (float(count) / len(alleles))

                    if len(HLA_prob) <= 0:
                        continue

                    # Choose top allele pairs
                    def choose_top_alleles(HLA_prob):
                        HLA_prob_list = [[allele_pair, prob] for allele_pair, prob in HLA_prob.items()]
                        HLA_prob_list = sorted(HLA_prob_list, cmp=HLA_prob_cmp)
                        HLA_prob = {}
                        best_prob = HLA_prob_list[0][1]
                        for i in range(len(HLA_prob_list)):
                            allele_pair, prob = HLA_prob_list[i]
                            if prob * 2 <= best_prob:
                                break                        
                            HLA_prob[allele_pair] = prob
                        normalize(HLA_prob)
                        return HLA_prob
                    HLA_prob = choose_top_alleles(HLA_prob)

                    def next_prob(HLA_cmpt, HLA_prob):
                        HLA_prob_next = {}
                        for cmpt, count in HLA_cmpt.items():
                            alleles = cmpt.split('-')
                            prob = 0.0
                            for allele in alleles:
                                for allele_pair in HLA_prob.keys():
                                    if allele in allele_pair:
                                        prob += HLA_prob[allele_pair]
                            for allele in alleles:
                                for allele_pair in HLA_prob.keys():
                                    if not allele in allele_pair:
                                        continue
                                    if allele_pair not in HLA_prob_next:
                                        HLA_prob_next[allele_pair] = 0.0
                                    HLA_prob_next[allele_pair] += (float(count) * HLA_prob[allele_pair] / prob)
                        normalize(HLA_prob_next)
                        return HLA_prob_next

                    diff, iter = 1.0, 0
                    while diff > 0.0001 and iter < 1000:
                        HLA_prob_next = next_prob(HLA_cmpt, HLA_prob)
                        diff = prob_diff(HLA_prob, HLA_prob_next)
                        HLA_prob = HLA_prob_next
                        HLA_prob = choose_top_alleles(HLA_prob)
                        iter += 1

                    HLA_prob = [[allele_pair, prob] for allele_pair, prob in HLA_prob.items()]
                    HLA_prob = sorted(HLA_prob, cmp=HLA_prob_cmp)

                    success = [False]
                    for prob_i in range(len(HLA_prob)):
                        allele_pair, prob = HLA_prob[prob_i]
                        allele1, allele2 = allele_pair.split('-')
                        if best_alleles and prob_i < 1:
                            print >> sys.stdout, "PairModel %s (abundance: %.2f%%)" % (allele_pair, prob * 100.0)
                        if simulation:
                            if allele1 in test_HLA_names and allele2 in test_HLA_names:
                                rank_i = prob_i
                                while rank_i > 0:
                                    if HLA_prob[rank_i-1][1] == prob:                                        
                                        rank_i -= 1
                                    else:
                                        break
                                print >> sys.stderr, "\t\t\t*** %d ranked %s (abundance: %.2f%%)" % (rank_i + 1, allele_pair, prob * 100.0)
                                if rank_i == 0:
                                    success[0] = True
                                break
                        print >> sys.stderr, "\t\t\t\t%d ranked %s (abundance: %.2f%%)" % (prob_i + 1, allele_pair, prob * 100.0)
                        if not simulation and prob_i >= 9:
                            break
                    print >> sys.stderr
                """

    # Read variants with clinical significance
    clnsigs = {}
    for line in open("%s.clnsig" % base_fname):
        var_id, var_gene, var_clnsig = line.strip().split('\t')
        clnsigs[var_id] = [var_gene, var_clnsig]

    vars, Var_list = {}, {}
    for line in open("%s.snp" % base_fname):
        var_id, type, chr, left, data = line.strip().split()
        if var_id not in clnsigs:
            continue
        left = int(left)
        if type == "deletion":
            data = int(data)
        vars[var_id] = [chr, left, type, data]
        if chr not in Var_list:
            Var_list[chr] = []
        Var_list[chr].append([left, var_id])

    var_counts = {}

    # Read alignments
    alignview_cmd = ["samtools",
                     "view",
                     "hla_output.bam"]
    bamview_proc = subprocess.Popen(alignview_cmd,
                                    stdout=subprocess.PIPE,
                                    stderr=open("/dev/null", 'w'))

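    # For each alignment, count observations supporting each clinically
    # significant variant (reported in the Zs tag) and observations
    # contradicting it (reads that align through the site without the variant).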
    for line in bamview_proc.stdout:
        cols = line.strip().split()
        read_id, flag, chr, pos, mapQ, cigar_str = cols[:6]
        read_seq, qual = cols[9], cols[10]
        flag, pos = int(flag), int(pos)
        pos -= 1
        if pos < 0:
            continue

        if flag & 0x4 != 0:
            continue

        if chr not in Var_list:
            continue

        assert chr in chr_dic
        chr_seq = chr_dic[chr]

        NM, Zs, MD, NH = "", "", "", ""
        for i in range(11, len(cols)):
            col = cols[i]
            if col.startswith("Zs"):
                Zs = col[5:]
            elif col.startswith("MD"):
                MD = col[5:]
            elif col.startswith("NM"):
                NM = int(col[5:])
            elif col.startswith("NH"):
                NH = int(col[5:])

        assert NH != ""
        NH = int(NH)
        if NH > 1:
            continue

        if NM > num_mismatch:
            continue

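        # The Zs tag lists the known variants that the read's alignment includes;
        # each one with clinical significance gets a supporting observation
        # (var_counts[var_id][0]).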
        read_vars = []
        if Zs:
            read_vars = Zs.split(',')
        for read_var in read_vars:
            _, _, var_id = read_var.split('|')
            if var_id not in clnsigs:
                continue
            if var_id not in var_counts:
                var_counts[var_id] = [1, 0]
            else:
                var_counts[var_id][0] += 1

        assert MD != ""
        MD_str_pos, MD_len = 0, 0
        read_pos, left_pos = 0, pos
        right_pos = left_pos
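        # Parse the CIGAR string into [operation, length] pairs,
        # e.g. "10M2D90M" -> [['M', 10], ['D', 2], ['M', 90]].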
        cigars = cigar_re.findall(cigar_str)
        cigars = [[cigar[-1], int(cigar[:-1])] for cigar in cigars]
        cmp_list = []
        for i in range(len(cigars)):
            cigar_op, length = cigars[i]
            if cigar_op == 'M':
                chr_var_list = Var_list[chr]
                var_idx = lower_bound(chr_var_list, right_pos)
                while var_idx < len(chr_var_list):
                    var_pos, var_id = chr_var_list[var_idx]
                    if var_pos >= right_pos + length:
                        break
                    if var_pos >= right_pos:
                        assert var_id in vars
                        _, _, var_type, var_data = vars[var_id]
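                        # A read whose aligned block covers the variant position
                        # without carrying the variant contradicts it: an SNV is
                        # contradicted when the read shows the reference base, and
                        # an insertion or deletion when the read aligns straight
                        # through the site (var_counts[var_id][1]).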
                        contradict = False
                        if var_type == "single":
                            contradict = (read_seq[read_pos + var_pos - right_pos] == chr_seq[var_pos])
                        elif var_type == "insertion":
                            contradict = (right_pos < var_pos)
                        else:
                            contradict = True
                        if contradict:
                            if var_id not in var_counts:
                                var_counts[var_id] = [0, 1]
                            else:
                                var_counts[var_id][1] += 1
                    
                    var_idx += 1
                    
            if cigar_op in "MND":
                right_pos += length

            if cigar_op in "MIS":
                read_pos += length

    for var_id, counts in var_counts.items():
        if counts[0] < 2: # or counts[0] * 3 < counts[1]:
            continue
        assert var_id in vars
        var_chr, var_left, var_type, var_data = vars[var_id]
        assert var_id in clnsigs
        var_gene, var_clnsig = clnsigs[var_id]
        print >> sys.stderr, "\t\t\t%s %s: %s:%d %s %s (%s): %d-%d" % \
                (var_gene, var_id, var_chr, var_left, var_type, var_data, var_clnsig, counts[0], counts[1])
def extract_reads(base_fname, # base file name for index, variants, haplotypes, etc. (e.g., /Users/katiecampbell/hisat_test/tm/hisat2-hisat2_v2.2.0_beta/evaluation/hla-analysis/genotype_genome)
                  database_list, # comma-separated, use hla
                  read_dir,
                  out_dir,
                  suffix,
                  read_fname,
                  fastq,
                  paired,
                  simulation,
                  threads,
                  threads_aprocess,
                  max_sample,
                  job_range,
                  aligner,
                  block_size,
                  is_rna,
                  rna_strandness,
                  verbose):
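    # Align each input FASTQ/FASTA file against the genotype genome and split
    # its reads into per-database (e.g. per HLA family) gzip-compressed files,
    # plus optional genome-wide per-window files when block_size > 0.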
    if block_size > 0:
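        # One gzip writer is opened per region and per genomic window, so raise
        # the open-file and process limits to accommodate them.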
        resource.setrlimit(resource.RLIMIT_NOFILE, (1000, 1000))
        resource.setrlimit(resource.RLIMIT_NPROC, (1000, 1000))

    genotype_fnames = ["%s.fa" % base_fname, # Filenames for genotypes in /Users/katiecampbell/hisat_test/tm/hisat2-hisat2_v2.2.0_beta/evaluation/hla-analysis/genotype_genome
                       "%s.locus" % base_fname,
                       "%s.snp" % base_fname,
                       "%s.haplotype" % base_fname,
                       "%s.link" % base_fname,
                       "%s.coord" % base_fname,
                       "%s.clnsig" % base_fname]
    # graph index files
    if aligner == "hisat2":
        genotype_fnames += ["%s.%d.ht2" % (base_fname, i+1) for i in range(8)] # gets the hisat index files
    else:
        assert aligner == "bowtie2"
        genotype_fnames = ["%s.%d.bt2" % (base_fname, i+1) for i in range(4)]
        genotype_fnames += ["%s.rev.%d.bt2" % (base_fname, i+1) for i in range(2)]

    if not typing_common.check_files(genotype_fnames): # Checks to see if file exists
        print >> sys.stderr, "Error: %s related files do not exist as follows:" % base_fname
        for fname in genotype_fnames:
            print >> sys.stderr, "\t%s" % fname
        sys.exit(1)

    filter_region = len(database_list) > 0
    ranges = []
    regions, region_loci = {}, {}
    for line in open("%s.locus" % base_fname):
        family, allele_name, chr, left, right = line.strip().split()[:5] # pulls the HLA family, allele name, and coordinates from the locus file
        if filter_region and family.lower() not in database_list:
            continue
        region_name = "%s-%s" % (family, allele_name.split('*')[0]) # specifies which HLA gene
        assert region_name not in regions
        regions[region_name] = allele_name
        left, right = int(left), int(right)
        """
        exons = []
        for exon in exon_str.split(','):
            exon_left, exon_right = exon.split('-')
            exons.append([int(exon_left), int(exon_right)])
        """
        if chr not in region_loci:
            region_loci[chr] = {}
        region_loci[chr][region_name] = [allele_name, chr, left, right] # Returns chromosomal region of chosen databases
        database_list.add(family.lower())

    if out_dir != "" and not os.path.exists(out_dir):
        os.mkdir(out_dir)

    # Extract reads
    if len(read_fname) > 0:
        if paired:
            fq_fnames = [read_fname[0]]
            fq_fnames2 = [read_fname[1]]
        else:
            fq_fnames = read_fname
    else:
        if paired:
            fq_fnames = glob.glob("%s/*.1.%s" % (read_dir, suffix))
        else:
            fq_fnames = glob.glob("%s/*.%s" % (read_dir, suffix))
    count = 0
    pids = [0 for i in range(threads)]
    for file_i in range(len(fq_fnames)):
        if file_i >= max_sample:
            break
        fq_fname = fq_fnames[file_i]
        if job_range[1] > 1:
            if job_range[0] != (file_i % job_range[1]):
                continue

        fq_fname_base = fq_fname.split('/')[-1]
        one_suffix = ".1." + suffix
        if fq_fname_base.find(one_suffix) != -1:
            fq_fname_base = fq_fname_base[:fq_fname_base.find(one_suffix)]
        else:
            fq_fname_base = fq_fname_base.split('.')[0]

        if paired:
            if read_dir == "":
                fq_fname2 = fq_fnames2[file_i]
            else:
                fq_fname2 = "%s/%s.2.%s" % (read_dir, fq_fname_base, suffix)
            if not os.path.exists(fq_fname2):
                print >> sys.stderr, "%s does not exist." % fq_fname2
                continue
        else:
            fq_fname2 = ""

        if paired:
            if out_dir != "":
                if os.path.exists("%s/%s.extracted.1.fq.gz" % (out_dir, fq_fname_base)):
                    continue
        else:
            if out_dir != "":
                if os.path.exists("%s/%s.extracted.fq.gz" % (out_dir, fq_fname_base)):
                    continue
        count += 1

        print >> sys.stderr, "\t%d: Extracting reads from %s" % (count, fq_fname_base)
        def work(fq_fname_base,
                 fq_fname,
                 fq_fname2,
                 ranges,
                 simulation,
                 verbose):
            aligner_cmd = [aligner]
            if threads_aprocess > 1:
                aligner_cmd += ["-p", "%d" % threads_aprocess]
            if not fastq:
                aligner_cmd += ["-f"]
            aligner_cmd += ["-x", base_fname]
            if aligner == "hisat2":
                if not is_rna:
                    aligner_cmd += ["--no-spliced-alignment"]
                else:
                    if rna_strandness != "":
                        aligner_cmd += ["--rna-strandness", rna_strandness]
                    else:
                        pass
                # aligner_cmd += ["--max-altstried", "64"]
            if not is_rna:
                aligner_cmd += ["-X", "1000"]
            if paired:
                aligner_cmd += ["-1", fq_fname,
                                "-2", fq_fname2]
            else:
                aligner_cmd += ["-U", fq_fname]
            if verbose:
                print >> sys.stderr, "\t\trunning", ' '.join(aligner_cmd)
            align_proc = subprocess.Popen(aligner_cmd,
                                          stdout=subprocess.PIPE,
                                          stderr=open("/dev/null", 'w'))

            gzip_dic = {}
            out_dir_slash = out_dir
            if out_dir != "":
                out_dir_slash += "/"
            for database in database_list:
                if paired:
                    # LP6005041-DNA_A01.extracted.1.fq.gz
                    gzip1_proc = subprocess.Popen(["gzip"],
                                                  stdin=subprocess.PIPE,
                                                  stdout=open("%s%s.%s.extracted.1.fq.gz" % (out_dir_slash, fq_fname_base, database), 'w'),
                                                  stderr=open("/dev/null", 'w'))

                    # LP6005041-DNA_A01.extracted.2.fq.gz
                    gzip2_proc = subprocess.Popen(["gzip"],
                                                  stdin=subprocess.PIPE,
                                                  stdout=open("%s%s.%s.extracted.2.fq.gz" % (out_dir_slash, fq_fname_base, database), 'w'),
                                                  stderr=open("/dev/null", 'w'))
                else:
                    # LP6005041-DNA_A01.extracted.fq.gz
                    gzip1_proc = subprocess.Popen(["gzip"],
                                                  stdin=subprocess.PIPE,
                                                  stdout=open("%s%s.%s.extracted.fq.gz" % (out_dir_slash, fq_fname_base, database), 'w'),
                                                  stderr=open("/dev/null", 'w'))
                gzip_dic[database] = [gzip1_proc, gzip2_proc if paired else None]

            whole_gzip_dic = {}
            if block_size > 0:
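                # Partition each main chromosome (1-22, X, Y, MT) into
                # block_size-bp windows and open one gzip writer per window so
                # reads can also be binned genome-wide.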
                mult = block_size / 1000000
                for chr_line in open("%s.fa.fai" % base_fname):
                    chr, length = chr_line.strip().split('\t')[:2]
                    length = int(length)
                    if chr not in [str(i+1) for i in range(22)] + ['X', 'Y', 'MT']:
                        continue
                    length = (length + block_size - 1) / block_size
                    assert chr not in whole_gzip_dic
                    whole_gzip_dic[chr] = []
                    for region_i in range(length):
                        if paired:
                            # LP6005041-DNA_A01.extracted.1.fq.gz
                            gzip1_proc = subprocess.Popen(["gzip"],
                                                          stdin=subprocess.PIPE,
                                                          stdout=open("%s%s.%s.%d_%dM.extracted.1.fq.gz" % (out_dir_slash, fq_fname_base, chr, region_i * mult, (region_i + 1) * mult), 'w'),
                                                          stderr=open("/dev/null", 'w'))

                            # LP6005041-DNA_A01.extracted.2.fq.gz
                            gzip2_proc = subprocess.Popen(["gzip"],
                                                          stdin=subprocess.PIPE,
                                                          stdout=open("%s%s.%s.%d_%dM.extracted.2.fq.gz" % (out_dir_slash, fq_fname_base, chr, region_i * mult, (region_i + 1) * mult), 'w'),
                                                          stderr=open("/dev/null", 'w'))
                        else:
                            # LP6005041-DNA_A01.extracted.fq.gz
                            gzip1_proc = subprocess.Popen(["gzip"],
                                                          stdin=subprocess.PIPE,
                                                          stdout=open("%s%s.%s.%d_%dM.extracted.fq.gz" % (out_dir_slash, fq_fname_base, chr, region_i * mult, (region_i + 1) * mult), 'w'),
                                                          stderr=open("/dev/null", 'w'))
                        whole_gzip_dic[chr].append([gzip1_proc, gzip2_proc if paired else None])


            def write_read(gzip_proc, read_name, seq, qual):
                if fastq:
                    gzip_proc.stdin.write("@%s\n" % read_name)
                    gzip_proc.stdin.write("%s\n" % seq)
                    gzip_proc.stdin.write("+\n")
                    gzip_proc.stdin.write("%s\n" % qual)
                else:
                    gzip_proc.stdin.write(">%s\n" % prev_read_name)
                    gzip_proc.stdin.write("%s\n" % seq)

            prev_read_name, extract_read, whole_extract_read, read1, read2, read1_first, read2_first = "", set(), set(), [], [], True, True
            for line in align_proc.stdout:
                if line.startswith('@'):
                    continue
                line = line.strip()
                cols = line.split()
                read_name, flag, chr, pos, mapQ, cigar, _, _, _, read, qual = cols[:11]
                flag, pos = int(flag), int(pos) - 1
                strand = '-' if flag & 0x10 else '+'
                AS, XS, NH = "", "", ""
                for i in range(11, len(cols)):
                    col = cols[i]
                    if col.startswith("AS"):
                        AS = int(col[5:])
                    elif col.startswith("XS"):
                        XS = col[5:]
                    elif col.startswith("NH"):
                        NH = int(col[5:])

                if (not simulation and read_name != prev_read_name) or \
                   (simulation and read_name.split('|')[0] != prev_read_name.split('|')[0]):
                    for region in extract_read:
                        print region
                        print prev_read_name
                        write_read(gzip_dic[region][0], prev_read_name, read1[0], read1[1])
                        if paired:
                            write_read(gzip_dic[region][1], prev_read_name, read2[0], read2[1])

                    for chr_region_num in whole_extract_read:
                        region_chr, region_num = chr_region_num.split('-')
                        region_num = int(region_num)
                        if region_chr not in whole_gzip_dic:
                            continue

                        assert region_num < len(whole_gzip_dic[region_chr])
                        write_read(whole_gzip_dic[region_chr][region_num][0], prev_read_name, read1[0], read1[1])
                        if paired:
                            write_read(whole_gzip_dic[region_chr][region_num][1], prev_read_name, read2[0], read2[1])

                    prev_read_name, extract_read, whole_extract_read, read1, read2, read1_first, read2_first = read_name, set(), set(), [], [], True, True

                if flag & 0x4 == 0 and \
                   ((aligner == "hisat2" and NH == 1) or (aligner == "bowtie2" and AS > XS and read1_first if flag & 0x40 or not paired else read2_first)):
                    if chr in region_loci:
                        for region, loci in region_loci[chr].items():
                            region = region.split('-')[0].lower()
                            _, _, loci_left, loci_right = loci
                            # there might be a different candidate region for each of left and right reads
                            if pos >= loci_left and pos < loci_right:
                                extract_read.add(region)
                                break
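                    # Record the block_size-bp window ("chr-windowIndex") the read
                    # falls into for the genome-wide partitioning set up above.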
                    if block_size > 0:
                        chr_region_num = "%s-%d" % (chr, pos / block_size)
                        whole_extract_read.add(chr_region_num)

                if flag & 0x40 or not paired: # left read
                    read1_first = False
                    if not read1:
                        if flag & 0x10: # reverse complement
                            read1 = [typing_common.reverse_complement(read), qual[::-1]]
                        else:
                            read1 = [read, qual]
                else:
                    assert flag & 0x80 # right read
                    read2_first = False
                    if flag & 0x10: # reverse complement
                        read2 = [typing_common.reverse_complement(read), qual[::-1]]
                    else:
                        read2 = [read, qual]

            for region in extract_read:
                write_read(gzip_dic[region][0], prev_read_name, read1[0], read1[1])
                if paired:
                    write_read(gzip_dic[region][1], prev_read_name, read2[0], read2[1])

            for chr_region_num in whole_extract_read:
                region_chr, region_num = chr_region_num.split('-')
                region_num = int(region_num)
                if region_chr not in whole_gzip_dic:
                    continue
                assert region_num < len(whole_gzip_dic[region_chr])
                write_read(whole_gzip_dic[region_chr][region_num][0], prev_read_name, read1[0], read1[1])
                if paired:
                    write_read(whole_gzip_dic[region_chr][region_num][1], prev_read_name, read2[0], read2[1])

            for gzip1_proc, gzip2_proc in gzip_dic.values():
                gzip1_proc.stdin.close()
                if paired:
                    gzip2_proc.stdin.close()

            for gzip_list in whole_gzip_dic.values():
                for gzip1_proc, gzip2_proc in gzip_list:
                    gzip1_proc.stdin.close()
                    if paired:
                        gzip2_proc.stdin.close()


        if threads <= 1:
            work(fq_fname_base,
                 fq_fname,
                 fq_fname2,
                 ranges,
                 simulation,
                 verbose)
        else:
            parallel_work(pids,
                          work,
                          fq_fname_base,
                          fq_fname,
                          fq_fname2,
                          ranges,
                          simulation,
                          verbose)

    if threads > 1:
        wait_pids(pids)
def genotype(base_fname,
             target_region_list,
             fastq,
             read_fnames,
             alignment_fname,
             threads,
             num_editdist,
             assembly,
             local_database,
             verbose,
             debug):
    # variants, backbone sequence, and other sequences
    genotype_fnames = ["%s.fa" % base_fname,
                       "%s.locus" % base_fname,
                       "%s.snp" % base_fname,
                       "%s.index.snp" % base_fname,
                       "%s.haplotype" % base_fname,
                       "%s.link" % base_fname,
                       "%s.coord" % base_fname,
                       "%s.clnsig" % base_fname]
    # hisat2 graph index files
    genotype_fnames += ["%s.%d.ht2" % (base_fname, i+1) for i in range(8)]
    if not typing_common.check_files(genotype_fnames):
        print("Error: some of the following files are missing!", 
              file=sys.stderr)
        for fname in genotype_fnames:
            print("\t%s" % fname, 
                  file=sys.stderr)
        sys.exit(1)

    # Read region alleles (names and sequences)
    regions     = {}
    region_loci = {}
    for line in open("%s.locus" % base_fname):
        family, allele_name, chr, left, right = line.strip().split()[:5]
        family = family.lower()
        if len(target_region_list) > 0 \
                and family not in target_region_list:
            continue
        
        locus_name = allele_name.split('*')[0]
        if family in target_region_list \
                and len(target_region_list[family]) > 0 \
                and locus_name not in target_region_list[family]:
            continue
        
        left  = int(left)
        right = int(right)
        if family not in region_loci:
            region_loci[family] = []
        region_loci[family].append([locus_name, allele_name, chr, left, right])

    if len(region_loci) <= 0:
        print("Warning: no region exists!", 
              file=sys.stderr)
        sys.exit(1)

    # Align reads, and sort the alignments into a BAM file
    if len(read_fnames) > 0:
        alignment_fname = align_reads(base_fname,
                                      read_fnames,
                                      fastq,
                                      threads,
                                      verbose)
    assert alignment_fname != "" and os.path.exists(alignment_fname)
    if not os.path.exists(alignment_fname + ".bai"):
        index_bam(alignment_fname,
                  verbose)
    assert os.path.exists(alignment_fname + ".bai")

    # Extract reads and perform genotyping
    for family, loci in region_loci.items():
        print("Analyzing %s ..." % family.upper(), file=sys.stderr)
        for locus_name, allele_name, chr, left, right in loci:
            out_read_fname = "%s.%s" % (family, locus_name)
            if verbose:
                print("\tExtracting reads beloning to %s-%s ..." \
                        % (family, locus_name), 
                      file=sys.stderr)

            extracted_read_fnames = extract_reads(alignment_fname,
                                                  chr,
                                                  left,
                                                  right,
                                                  out_read_fname,
                                                  len(read_fnames) != 1, # paired?
                                                  fastq,
                                                  verbose)

            perform_genotyping(base_fname,
                               family,
                               [locus_name],
                               extracted_read_fnames,
                               fastq,
                               num_editdist,
                               assembly,
                               local_database,
                               threads,
                               verbose,
                               debug)
        print("\n", file=sys.stderr)
def build_genotype_genome(base_fname,                          
                          inter_gap,
                          intra_gap,
                          threads,
                          database_list,
                          use_clinvar,
                          use_commonvar,
                          verbose):    
    # Download HISAT2 index
    HISAT2_fnames = ["grch38",
                     "genome.fa",
                     "genome.fa.fai"]
    if not typing_common.check_files(HISAT2_fnames):
        typing_common.download_genome_and_index()

    # Load genomic sequences
    chr_dic, chr_names, chr_full_names = typing_common.read_genome(open("genome.fa"))

    genotype_vars, genotype_haplotypes, genotype_clnsig = {}, {}, {}
    if use_clinvar:
        # Extract variants from the ClinVar database
        CLINVAR_fnames = ["clinvar.vcf.gz",
                          "clinvar.snp",
                          "clinvar.haplotype",
                          "clinvar.clnsig"]

        if not typing_common.check_files(CLINVAR_fnames):
            if not os.path.exists("clinvar.vcf.gz"):
                os.system("wget ftp://ftp.ncbi.nlm.nih.gov/pub/clinvar/vcf_GRCh38/archive/2017/clinvar_20170404.vcf.gz")
            assert os.path.exists("clinvar.vcf.gz")

            extract_cmd = ["hisat2_extract_snps_haplotypes_VCF.py"]
            extract_cmd += ["--inter-gap", str(inter_gap),
                            "--intra-gap", str(intra_gap),
                            "--genotype-vcf", "clinvar.vcf.gz",
                            "genome.fa", "/dev/null", "clinvar"]
            if verbose:
                print >> sys.stderr, "\tRunning:", ' '.join(extract_cmd)
            proc = subprocess.Popen(extract_cmd, stdout=open("/dev/null", 'w'), stderr=open("/dev/null", 'w'))
            proc.communicate()
            if not typing_common.check_files(CLINVAR_fnames):
                print >> sys.stderr, "Error: extract variants from clinvar failed!"
                sys.exit(1)

        # Read variants to be genotyped
        genotype_vars = typing_common.read_variants("clinvar.snp")

        # Read haplotypes
        genotype_haplotypes = typing_common.read_haplotypes("clinvar.haplotype")

        # Read information about clinical significance
        genotype_clnsig = typing_common.read_clnsig("clinvar.clnsig")

    if use_commonvar:
        # Extract variants from dbSNP database
        commonvar_fbase = "snp144Common"
        commonvar_fnames = ["%s.snp" % commonvar_fbase,
                            "%s.haplotype" % commonvar_fbase]
        if not typing_common.check_files(commonvar_fnames):
            if not os.path.exists("%s.txt.gz" % commonvar_fbase):
                os.system("wget http://hgdownload.cse.ucsc.edu/goldenPath/hg38/database/%s.txt.gz" % commonvar_fbase)
            assert os.path.exists("%s.txt.gz" % commonvar_fbase)
            os.system("gzip -cd %s.txt.gz | awk 'BEGIN{OFS=\"\t\"} {if($2 ~ /^chr/) {$2 = substr($2, 4)}; if($2 == \"M\") {$2 = \"MT\"} print}' > %s.txt" % (commonvar_fbase, commonvar_fbase))
            extract_cmd = ["hisat2_extract_snps_haplotypes_UCSC.py",
                           "--inter-gap", str(inter_gap),
                           "--intra-gap", str(intra_gap),
                           "genome.fa", "%s.txt" % commonvar_fbase, commonvar_fbase]
            if verbose:
                print >> sys.stderr, "\tRunning:", ' '.join(extract_cmd)
            proc = subprocess.Popen(extract_cmd, stdout=open("/dev/null", 'w'), stderr=open("/dev/null", 'w'))
            proc.communicate()
            if not typing_common.check_files(commonvar_fnames):
                print >> sys.stderr, "Error: extract variants from clinvar failed!"
                sys.exit(1)

        # Read variants to be genotyped
        genotype_vars = typing_common.read_variants("%s.snp" % commonvar_fbase)

        # Read haplotypes
        genotype_haplotypes = typing_common.read_haplotypes("%s.haplotype" % commonvar_fbase)

    # Genes to be genotyped
    genotype_genes = {}

    # Read genes or genomics regions
    for database_name in database_list:
        # Extract HLA variants, backbone sequence, and other sequences
        typing_common.extract_database_if_not_exists(database_name,
                                                     [],            # locus_list
                                                     inter_gap,
                                                     intra_gap,
                                                     True,          # partial?
                                                     verbose)
        locus_fname = "%s.locus" % database_name
        assert os.path.exists(locus_fname)
        for line in open(locus_fname):
            HLA_name, chr, left, right, length, exon_str, strand = line.strip().split()
            left, right = int(left), int(right)
            length = int(length)
            if chr not in chr_names:
                continue
            if chr not in genotype_genes:
                genotype_genes[chr] = []
            genotype_genes[chr].append([left, right, length, HLA_name, database_name, exon_str, strand])

    # Write genotype genome
    var_num, haplotype_num = 0, 0
    genome_out_file = open("%s.fa" % base_fname, 'w')
    locus_out_file = open("%s.locus" % base_fname, 'w')
    var_out_file = open("%s.snp" % base_fname, 'w')
    index_var_out_file = open("%s.index.snp" % base_fname, 'w')
    haplotype_out_file = open("%s.haplotype" % base_fname, 'w')
    link_out_file = open("%s.link" % base_fname, 'w')
    coord_out_file = open("%s.coord" % base_fname, 'w')
    clnsig_out_file = open("%s.clnsig" % base_fname, 'w')
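    # For each chromosome, splice the allele backbone sequences into the
    # reference sequence, remapping variant, haplotype, and coordinate records
    # into the new coordinate system as we go.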
    for c in range(len(chr_names)):
        chr = chr_names[c]
        chr_full_name = chr_full_names[c]
        assert chr in chr_dic
        chr_seq = chr_dic[chr]
        chr_len = len(chr_seq)
        if chr in genotype_genes:
            chr_genes = genotype_genes[chr]
            def gene_cmp(a, b):
                a_left, a_right, a_length = a[:3]
                b_left, b_right, b_length = b[:3]
                if a_left != b_left:
                    return a_left - b_left
                if a_right != b_right:
                    return a_right - b_right
                return a_length - b_length
            chr_genes = sorted(chr_genes, cmp=gene_cmp)
        else:
            chr_genes = []

        chr_genotype_vars, chr_genotype_vari = [], 0
        if chr in genotype_vars:
            chr_genotype_vars = genotype_vars[chr]
        chr_genotype_haplotypes, chr_genotype_hti = [], 0
        if chr in genotype_haplotypes:
            chr_genotype_haplotypes = genotype_haplotypes[chr]

        def add_vars(left, right, chr_genotype_vari, chr_genotype_hti, haplotype_num):
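            # Flush genotype variants and haplotypes that end before the given
            # gene region, shifting their coordinates by the running offset
            # (off); entries overlapping the region are skipped because the
            # gene's own allele variants are written instead.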
            # Output variants with clinical significance
            while chr_genotype_vari < len(chr_genotype_vars):
                var_left, var_type, var_data, var_id =  chr_genotype_vars[chr_genotype_vari]
                var_right = var_left
                if var_type == "deletion":
                    var_right += var_data
                if var_right > right:
                    break
                if var_right >= left:
                    chr_genotype_vari += 1
                    continue

                out_str = "%s\t%s\t%s\t%d\t%s" % (var_id, var_type, chr, var_left + off, var_data)
                print >> var_out_file, out_str
                print >> index_var_out_file, out_str

                if var_id in genotype_clnsig:
                    var_gene, clnsig = genotype_clnsig[var_id]
                    print >> clnsig_out_file, "%s\t%s\t%s" % \
                        (var_id, var_gene, clnsig)
                
                chr_genotype_vari += 1

            # Output haplotypes
            while chr_genotype_hti < len(chr_genotype_haplotypes):
                ht_left, ht_right, ht_vars =  chr_genotype_haplotypes[chr_genotype_hti]
                if ht_right > right:
                    break
                if ht_right >= left:
                    chr_genotype_hti += 1
                    continue

                print >> haplotype_out_file, "ht%d\t%s\t%d\t%d\t%s" % \
                    (haplotype_num, chr, ht_left + off, ht_right + off, ','.join(ht_vars))
                chr_genotype_hti += 1
                haplotype_num += 1

            return chr_genotype_vari, chr_genotype_hti, haplotype_num

        out_chr_seq = ""
        
        off = 0
        prev_right = 0
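        # off tracks how much longer the output chromosome is than the reference
        # so far: each gene's backbone allele sequence replaces the gene's span
        # on the reference, and the length difference accumulates so downstream
        # coordinates can be shifted accordingly.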
        for gene in chr_genes:
            left, right, length, name, family, exon_str, strand = gene

            chr_genotype_vari, chr_genotype_hti, haplotype_num = add_vars(left, right, chr_genotype_vari, chr_genotype_hti, haplotype_num)

            # Read HLA backbone sequences
            allele_seqs = typing_common.read_allele_sequences("%s_backbone.fa" % family)

            # Read HLA variants
            allele_vars = typing_common.read_variants("%s.snp" % family)
            allele_index_vars = typing_common.read_variants("%s.index.snp" % family)
                
            # Read HLA haplotypes
            allele_haplotypes = typing_common.read_haplotypes("%s.haplotype" % family)

            # Read HLA link information between haplotypes and variants
            links = typing_common.read_links("%s.link" % family)

            if name not in allele_seqs or \
                    name not in allele_vars or \
                    name not in allele_haplotypes:
                continue
            allele_seq = allele_seqs[name]
            vars, index_vars = allele_vars[name], allele_index_vars[name]
            index_var_ids = set()
            for _, _, _, var_id in index_vars:
                index_var_ids.add(var_id)

            haplotypes = allele_haplotypes[name]
            assert length == len(allele_seq)
            assert left < chr_len and right < chr_len
            # Skipping overlapping genes
            if left < prev_right:
                print >> sys.stderr, "Warning: skipping %s ..." % (name)
                continue

            varID2htID = {}

            assert left < right
            prev_length = right - left + 1
            assert prev_length <= length

            if prev_right < left:
                out_chr_seq += chr_seq[prev_right:left]

            # Output gene (genotype_genome.gene)
            print >> locus_out_file, "%s\t%s\t%s\t%d\t%d\t%s\t%s" % \
                (family.upper(), name, chr, len(out_chr_seq), len(out_chr_seq) + length - 1, exon_str, strand)

            # Output coord (genotype_genome.coord)
            print >> coord_out_file, "%s\t%d\t%d\t%d" % \
                (chr, len(out_chr_seq), left, right - left + 1)
            out_chr_seq += allele_seq

            # Output variants (genotype_genome.snp and genotype_genome.index.snp)
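            # Allele-level variants are renamed hv<N> and their coordinates are
            # shifted by the gene's start (left) plus the cumulative offset (off).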
            for var in vars:
                var_left, var_type, var_data, var_id = var
                new_var_id = "hv%d" % var_num
                varID2htID[var_id] = new_var_id
                new_var_left = var_left + left + off
                assert var_type in ["single", "deletion", "insertion"]
                assert new_var_left < len(out_chr_seq)
                if var_type == "single":                    
                    assert out_chr_seq[new_var_left] != var_data
                elif var_type == "deletion":
                    assert new_var_left + var_data <= len(out_chr_seq)
                else:
                    assert var_type == "insertion"

                out_str = "%s\t%s\t%s\t%d\t%s" % (new_var_id, var_type, chr, new_var_left, var_data)
                print >> var_out_file, out_str
                if var_id in index_var_ids:
                    print >> index_var_out_file, out_str
                var_num += 1
                
            # Output haplotypes (genotype_genome.haplotype)
            for haplotype in haplotypes:
                ht_left, ht_right, ht_vars = haplotype
                new_ht_left = ht_left + left + off
                assert new_ht_left < len(out_chr_seq)
                new_ht_right = ht_right + left + off
                assert new_ht_left <= new_ht_right
                assert new_ht_right <= len(out_chr_seq)
                new_ht_vars = []
                for var_id in ht_vars:
                    assert var_id in varID2htID
                    new_ht_vars.append(varID2htID[var_id])
                print >> haplotype_out_file, "ht%d\t%s\t%d\t%d\t%s" % \
                    (haplotype_num, chr, new_ht_left, new_ht_right, ','.join(new_ht_vars))
                haplotype_num += 1

            # Output link information between alleles and variants (genotype_genome.link)
            for link in links:
                var_id, allele_names = link
                if var_id not in varID2htID:
                    continue
                new_var_id = varID2htID[var_id]
                print >> link_out_file, "%s\t%s" % (new_var_id, allele_names)
                
            off += (length - prev_length)

            prev_right = right + 1

        # Write the rest of the Vars
        chr_genotype_vari, chr_genotype_hti, haplotype_num = add_vars(sys.maxint, sys.maxint, chr_genotype_vari, chr_genotype_hti, haplotype_num)            
            
        print >> coord_out_file, "%s\t%d\t%d\t%d" % \
            (chr, len(out_chr_seq), prev_right, len(chr_seq) - prev_right)
        out_chr_seq += chr_seq[prev_right:]

        assert len(out_chr_seq) == len(chr_seq) + off

        # Output chromosome sequence
        print >> genome_out_file, ">%s" % (chr_full_name)
        line_width = 60
        for s in range(0, len(out_chr_seq), line_width):
            print >> genome_out_file, out_chr_seq[s:s+line_width]

    genome_out_file.close()
    locus_out_file.close()
    var_out_file.close()
    index_var_out_file.close()
    haplotype_out_file.close()
    link_out_file.close()
    coord_out_file.close()
    clnsig_out_file.close()

    partial_out_file = open("%s.partial" % base_fname, 'w')
    for database in database_list:
        for line in open("%s.partial" % database):
            allele_name = line.strip()
            print >> partial_out_file, "%s\t%s" % (database.upper(), allele_name)
    partial_out_file.close()

    # Index genotype_genome.fa
    index_cmd = ["samtools", "faidx", "%s.fa" % base_fname]
    subprocess.call(index_cmd)

    # Build HISAT-genotype graph indexes based on the above information
    hisat2_index_fnames = ["%s.%d.ht2" % (base_fname, i+1) for i in range(8)]
    build_cmd = ["hisat2-build",
                 "-p", str(threads),
                 "--snp", "%s.index.snp" % base_fname,
                 "--haplotype", "%s.haplotype" % base_fname,
                 "%s.fa" % base_fname,
                 "%s" % base_fname]
    if verbose:
        print >> sys.stderr, "\tRunning:", ' '.join(build_cmd)
        
    subprocess.call(build_cmd, stdout=open("/dev/null", 'w'), stderr=open("/dev/null", 'w'))
    if not typing_common.check_files(hisat2_index_fnames):
        print >> sys.stderr, "Error: indexing failed!  Perhaps, you may have forgotten to build hisat2 executables?"
        sys.exit(1)
def extract_reads(base_fname,
                  database_list,
                  read_dir,
                  out_dir,
                  suffix,
                  read_fname,
                  fastq,
                  paired,
                  simulation,
                  threads,
                  max_sample,
                  job_range,
                  verbose):
    genotype_fnames = ["%s.fa" % base_fname,
                       "%s.locus" % base_fname,
                       "%s.snp" % base_fname,
                       "%s.haplotype" % base_fname,
                       "%s.link" % base_fname,
                       "%s.coord" % base_fname,
                       "%s.clnsig" % base_fname]
    # hisat2 graph index files
    genotype_fnames += ["%s.%d.ht2" % (base_fname, i+1) for i in range(8)]
    if not typing_common.check_files(genotype_fnames):        
        print >> sys.stderr, "Error: %s related files do not exist as follows:" % base_fname
        for fname in genotype_fnames:
            print >> sys.stderr, "\t%s" % fname
        sys.exit(1)

    filter_region = len(database_list) > 0
    ranges = []
    regions, region_loci = {}, {}
    for line in open("%s.locus" % base_fname):
        family, allele_name, chr, left, right = line.strip().split()[:5]
        if filter_region and family.lower() not in database_list:
            continue
        region_name = "%s-%s" % (family, allele_name.split('*')[0])
        assert region_name not in regions
        regions[region_name] = allele_name
        left, right = int(left), int(right)
        """
        exons = []
        for exon in exon_str.split(','):
            exon_left, exon_right = exon.split('-')
            exons.append([int(exon_left), int(exon_right)])
        """
        if chr not in region_loci:
            region_loci[chr] = {}
        region_loci[chr][region_name] = [allele_name, chr, left, right]
        database_list.add(family.lower())

    if out_dir != "" and not os.path.exists(out_dir):
        os.mkdir(out_dir)

    # Extract reads
    if len(read_fname) > 0:
        if paired:
            fq_fnames = [read_fname[0]]
            fq_fnames2 = [read_fname[1]]
        else:
            fq_fnames = read_fname
    else:
        if paired:
            fq_fnames = glob.glob("%s/*.1.%s" % (read_dir, suffix))
        else:
            fq_fnames = glob.glob("%s/*.%s" % (read_dir, suffix))
    count = 0
    pids = [0 for i in range(threads)]
    for file_i in range(len(fq_fnames)):
        if file_i >= max_sample:
            break
        fq_fname = fq_fnames[file_i]
        if job_range[1] > 1:
            if job_range[0] != (file_i % job_range[1]):
                continue

        fq_fname_base = fq_fname.split('/')[-1]
        one_suffix = ".1." + suffix
        if fq_fname_base.find(one_suffix) != -1:
            fq_fname_base = fq_fname_base[:fq_fname_base.find(one_suffix)]
        else:
            fq_fname_base = fq_fname_base.split('.')[0]
            
        if paired:
            if read_dir == "":
                fq_fname2 = fq_fnames2[file_i]
            else:
                fq_fname2 = "%s/%s.2.%s" % (read_dir, fq_fname_base, suffix)
            if not os.path.exists(fq_fname2):
                print >> sys.stderr, "%s does not exist." % fq_fname2
                continue
        else:
            fq_fname2 = ""

        if paired:
            if out_dir != "":
                if os.path.exists("%s/%s.extracted.1.fq.gz" % (out_dir, fq_fname_base)):
                    continue
        else:
            if out_dir != "":
                if os.path.exists("%s/%s.extracted.fq.gz" % (out_dir, fq_fname_base)):
                    continue
        count += 1

        print >> sys.stderr, "\t%d: Extracting reads from %s" % (count, fq_fname_base)
        def work(fq_fname_base,
                 fq_fname, 
                 fq_fname2, 
                 ranges,
                 simulation,
                 verbose):
            aligner_cmd = ["hisat2"]
            if not fastq:
                aligner_cmd += ["-f"]
            aligner_cmd += ["-x", base_fname]
            aligner_cmd += ["--no-spliced-alignment",
                            "--max-altstried", "64"]
            if paired:
                aligner_cmd += ["-1", fq_fname,
                                "-2", fq_fname2]
            else:
                aligner_cmd += ["-U", fq_fname]
            if verbose:
                print >> sys.stderr, "\t\trunning", ' '.join(aligner_cmd)
            align_proc = subprocess.Popen(aligner_cmd,
                                          stdout=subprocess.PIPE,
                                          stderr=open("/dev/null", 'w'))

            gzip_dic = {}
            out_dir_slash = out_dir
            if out_dir != "":
                out_dir_slash += "/"
            for database in database_list:
                if paired:
                    # LP6005041-DNA_A01.extracted.1.fq.gz
                    gzip1_proc = subprocess.Popen(["gzip"],
                                                  stdin=subprocess.PIPE,
                                                  stdout=open("%s%s.%s.extracted.1.fq.gz" % (out_dir_slash, fq_fname_base, database), 'w'),
                                                  stderr=open("/dev/null", 'w'))

                    # LP6005041-DNA_A01.extracted.2.fq.gz
                    gzip2_proc = subprocess.Popen(["gzip"],
                                                  stdin=subprocess.PIPE,
                                                  stdout=open("%s%s.%s.extracted.2.fq.gz" % (out_dir_slash, fq_fname_base, database), 'w'),
                                                  stderr=open("/dev/null", 'w'))
                else:
                    # LP6005041-DNA_A01.extracted.fq.gz
                    gzip1_proc = subprocess.Popen(["gzip"],
                                                  stdin=subprocess.PIPE,
                                                  stdout=open("%s%s.%s.extracted.fq.gz" % (out_dir_slash, fq_fname_base, database), 'w'),
                                                  stderr=open("/dev/null", 'w'))
                gzip_dic[database] = [gzip1_proc, gzip2_proc if paired else None]

            def write_read(gzip_proc, read_name, seq, qual):
                if fastq:
                    gzip_proc.stdin.write("@%s\n" % read_name)
                    gzip_proc.stdin.write("%s\n" % seq)
                    gzip_proc.stdin.write("+\n")
                    gzip_proc.stdin.write("%s\n" % qual)
                else:
                    gzip_proc.stdin.write(">%s\n" % prev_read_name)
                    gzip_proc.stdin.write("%s\n" % seq)                    

            prev_read_name, extract_read, read1, read2 = "", False, [], []
            for line in align_proc.stdout:
                if line.startswith('@'):
                    continue
                line = line.strip()
                cols = line.split()
                read_name, flag, chr, pos, mapQ, cigar, _, _, _, read, qual = cols[:11]
                flag, pos = int(flag), int(pos)
                strand = '-' if flag & 0x10 else '+'                   
                AS, NH = "", ""
                for i in range(11, len(cols)):
                    col = cols[i]
                    if col.startswith("AS"):
                        AS = int(col[5:])
                    elif col.startswith("NH"):
                        NH = int(col[5:])

                if (not simulation and read_name != prev_read_name) or \
                   (simulation and read_name.split('|')[0] != prev_read_name.split('|')[0]):
                    if extract_read:
                        write_read(gzip_dic[region][0], prev_read_name, read1[0], read1[1])
                        if paired:
                            write_read(gzip_dic[region][1], prev_read_name, read2[0], read2[1])
                    prev_read_name, extract_read, read1, read2 = read_name, False, [], []

                if flag & 0x4 == 0 and NH == 1 and chr in region_loci:                    
                    for region, loci in region_loci[chr].items():
                        region = region.split('-')[0].lower()
                        _, _, loci_left, loci_right = loci
                        if pos >= loci_left and pos < loci_right:
                            extract_read = True
                            break

                if flag & 0x40 or not paired: # left read
                    if not read1:
                        if flag & 0x10: # reverse complement
                            read1 = [typing_common.reverse_complement(read), qual[::-1]]
                        else:
                            read1 = [read, qual]
                else:
                    assert flag & 0x80 # right read
                    if flag & 0x10: # reverse complement
                        read2 = [typing_common.reverse_complement(read), qual[::-1]]
                    else:
                        read2 = [read, qual]

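            # Flush the last buffered read (pair) once the SAM stream ends.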
            if extract_read:
                write_read(gzip_dic[region][0], prev_read_name, read1[0], read1[1])
                if paired:
                    write_read(gzip_dic[region][1], prev_read_name, read2[0], read2[1])

            for gzip1_proc, gzip2_proc in gzip_dic.values():
                gzip1_proc.stdin.close()
                if paired:
                    gzip2_proc.stdin.close()                        

        if threads <= 1:
            work(fq_fname_base, 
                 fq_fname, 
                 fq_fname2,
                 ranges,
                 simulation,
                 verbose)
        else:
            parallel_work(pids, 
                          work, 
                          fq_fname_base, 
                          fq_fname, 
                          fq_fname2, 
                          ranges,
                          simulation,
                          verbose)

    if threads > 1:
        wait_pids(pids)
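
# Both extraction routines in this listing hand per-sample jobs to parallel_work
# and reap them with wait_pids; those helpers are defined elsewhere in the
# original scripts.  The fork-based code below is only a sketch of what they
# could look like: the slot-reuse scheme and local names are assumptions.
import os  # already available in this script; repeated so the sketch stands alone

def parallel_work(pids, work, *args):
    # Find a free worker slot; if all slots are busy, wait for any child to
    # exit and take over its slot.
    child = -1
    for i in range(len(pids)):
        if pids[i] == 0:
            child = i
            break
    while child == -1:
        finished_pid, _ = os.waitpid(-1, 0)
        for i in range(len(pids)):
            if pids[i] == finished_pid:
                child = i
                break
    # Run the work function in a forked child and record its pid in the slot.
    pid = os.fork()
    if pid == 0:
        work(*args)
        os._exit(0)
    pids[child] = pid

def wait_pids(pids):
    # Reap any workers that are still running.
    for pid in pids:
        if pid > 0:
            os.waitpid(pid, 0)
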
def extract_reads(base_fname, database_list, read_dir, out_dir, suffix,
                  read_fname, fastq, paired, simulation, threads, max_sample,
                  job_range, verbose):
    genotype_fnames = [
        "%s.fa" % base_fname,
        "%s.locus" % base_fname,
        "%s.snp" % base_fname,
        "%s.haplotype" % base_fname,
        "%s.link" % base_fname,
        "%s.coord" % base_fname,
        "%s.clnsig" % base_fname
    ]
    # hisat2 graph index files
    genotype_fnames += ["%s.%d.ht2" % (base_fname, i + 1) for i in range(8)]
    if not typing_common.check_files(genotype_fnames):
        print >> sys.stderr, "Error: %s related files do not exist as follows:" % base_fname
        for fname in genotype_fnames:
            print >> sys.stderr, "\t%s" % fname
        sys.exit(1)

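    # Build region_loci from the *.locus file:
    # chromosome -> region name -> [allele name, chr, left, right],
    # used below to decide whether an alignment overlaps a target locus.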
    filter_region = len(database_list) > 0
    ranges = []
    regions, region_loci = {}, {}
    for line in open("%s.locus" % base_fname):
        family, allele_name, chr, left, right = line.strip().split()[:5]
        if filter_region and family.lower() not in database_list:
            continue
        region_name = "%s-%s" % (family, allele_name.split('*')[0])
        assert region_name not in regions
        regions[region_name] = allele_name
        left, right = int(left), int(right)
        """
        exons = []
        for exon in exon_str.split(','):
            exon_left, exon_right = exon.split('-')
            exons.append([int(exon_left), int(exon_right)])
        """
        if chr not in region_loci:
            region_loci[chr] = {}
        region_loci[chr][region_name] = [allele_name, chr, left, right]
        database_list.add(family.lower())

    if out_dir != "" and not os.path.exists(out_dir):
        os.mkdir(out_dir)

    # Extract reads
    if len(read_fname) > 0:
        if paired:
            fq_fnames = [read_fname[0]]
            fq_fnames2 = [read_fname[1]]
        else:
            fq_fnames = read_fname
    else:
        if paired:
            fq_fnames = glob.glob("%s/*.1.%s" % (read_dir, suffix))
        else:
            fq_fnames = glob.glob("%s/*.%s" % (read_dir, suffix))
    count = 0
    pids = [0 for i in range(threads)]
    for file_i in range(len(fq_fnames)):
        if file_i >= max_sample:
            break
        fq_fname = fq_fnames[file_i]
        if job_range[1] > 1:
            if job_range[0] != (file_i % job_range[1]):
                continue

        fq_fname_base = fq_fname.split('/')[-1]
        one_suffix = ".1." + suffix
        if fq_fname_base.find(one_suffix) != -1:
            fq_fname_base = fq_fname_base[:fq_fname_base.find(one_suffix)]
        else:
            fq_fname_base = fq_fname_base.split('.')[0]

        if paired:
            if read_dir == "":
                fq_fname2 = fq_fnames2[file_i]
            else:
                fq_fname2 = "%s/%s.2.%s" % (read_dir, fq_fname_base, suffix)
            if not os.path.exists(fq_fname2):
                print >> sys.stderr, "%s does not exist." % fq_fname2
                continue
        else:
            fq_fname2 = ""

        if paired:
            if out_dir != "":
                if os.path.exists("%s/%s.extracted.1.fq.gz" %
                                  (out_dir, fq_fname_base)):
                    continue
        else:
            if out_dir != "":
                if os.path.exists("%s/%s.extracted.fq.gz" %
                                  (out_dir, fq_fname_base)):
                    continue
        count += 1

        print >> sys.stderr, "\t%d: Extracting reads from %s" % (count,
                                                                 fq_fname_base)

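        # Worker: align one sample with HISAT2, then route reads that map
        # uniquely inside a target locus into per-database gzip'd files.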
        def work(fq_fname_base, fq_fname, fq_fname2, ranges, simulation,
                 verbose):
            aligner_cmd = ["hisat2"]
            if not fastq:
                aligner_cmd += ["-f"]
            aligner_cmd += ["-x", base_fname]
            aligner_cmd += ["--no-spliced-alignment", "--max-altstried", "64"]
            if paired:
                aligner_cmd += ["-1", fq_fname, "-2", fq_fname2]
            else:
                aligner_cmd += ["-U", fq_fname]
            if verbose:
                print >> sys.stderr, "\t\trunning", ' '.join(aligner_cmd)
            align_proc = subprocess.Popen(aligner_cmd,
                                          stdout=subprocess.PIPE,
                                          stderr=open("/dev/null", 'w'))

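            # One gzip child process per output file: separate mate-1/mate-2
            # files for paired-end data, a single file otherwise.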
            gzip_dic = {}
            out_dir_slash = out_dir
            if out_dir != "":
                out_dir_slash += "/"
            for database in database_list:
                if paired:
                    # LP6005041-DNA_A01.extracted.1.fq.gz
                    gzip1_proc = subprocess.Popen(
                        ["gzip"],
                        stdin=subprocess.PIPE,
                        stdout=open(
                            "%s%s.%s.extracted.1.fq.gz" %
                            (out_dir_slash, fq_fname_base, database), 'w'),
                        stderr=open("/dev/null", 'w'))

                    # LP6005041-DNA_A01.extracted.2.fq.gz
                    gzip2_proc = subprocess.Popen(
                        ["gzip"],
                        stdin=subprocess.PIPE,
                        stdout=open(
                            "%s%s.%s.extracted.2.fq.gz" %
                            (out_dir_slash, fq_fname_base, database), 'w'),
                        stderr=open("/dev/null", 'w'))
                else:
                    # LP6005041-DNA_A01.extracted.fq.gz
                    gzip1_proc = subprocess.Popen(
                        ["gzip"],
                        stdin=subprocess.PIPE,
                        stdout=open(
                            "%s%s.%s.extracted.fq.gz" %
                            (out_dir_slash, fq_fname_base, database), 'w'),
                        stderr=open("/dev/null", 'w'))
                gzip_dic[database] = [
                    gzip1_proc, gzip2_proc if paired else None
                ]

            def write_read(gzip_proc, read_name, seq, qual):
                if fastq:
                    gzip_proc.stdin.write("@%s\n" % read_name)
                    gzip_proc.stdin.write("%s\n" % seq)
                    gzip_proc.stdin.write("+\n")
                    gzip_proc.stdin.write("%s\n" % qual)
                else:
                    gzip_proc.stdin.write(">%s\n" % prev_read_name)
                    gzip_proc.stdin.write("%s\n" % seq)

            prev_read_name, extract_read, read1, read2 = "", False, [], []
            for line in align_proc.stdout:
                if line.startswith('@'):
                    continue
                line = line.strip()
                cols = line.split()
                read_name, flag, chr, pos, mapQ, cigar, _, _, _, read, qual = cols[:11]
                flag, pos = int(flag), int(pos)
                strand = '-' if flag & 0x10 else '+'
                AS, NH = "", ""
                for i in range(11, len(cols)):
                    col = cols[i]
                    if col.startswith("AS"):
                        AS = int(col[5:])
                    elif col.startswith("NH"):
                        NH = int(col[5:])

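                # Alignments of the same read (pair) arrive consecutively, so a
                # change of read name means the previous pair is complete.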
                if (not simulation and read_name != prev_read_name) or \
                   (simulation and read_name.split('|')[0] != prev_read_name.split('|')[0]):
                    if extract_read:
                        write_read(gzip_dic[region][0], prev_read_name,
                                   read1[0], read1[1])
                        if paired:
                            write_read(gzip_dic[region][1], prev_read_name,
                                       read2[0], read2[1])
                    prev_read_name, extract_read, read1, read2 = read_name, False, [], []

                if flag & 0x4 == 0 and NH == 1 and chr in region_loci:
                    for region, loci in region_loci[chr].items():
                        region = region.split('-')[0].lower()
                        _, _, loci_left, loci_right = loci
                        if pos >= loci_left and pos < loci_right:
                            extract_read = True
                            break

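                # Buffer the mate sequences in their original orientation,
                # undoing the reverse-complementing applied by the aligner.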
                if flag & 0x40 or not paired:  # left read
                    if not read1:
                        if flag & 0x10:  # reverse complement
                            read1 = [
                                typing_common.reverse_complement(read),
                                qual[::-1]
                            ]
                        else:
                            read1 = [read, qual]
                else:
                    assert flag & 0x80  # right read
                    if flag & 0x10:  # reverse complement
                        read2 = [
                            typing_common.reverse_complement(read), qual[::-1]
                        ]
                    else:
                        read2 = [read, qual]

            if extract_read:
                write_read(gzip_dic[region][0], prev_read_name, read1[0],
                           read1[1])
                if paired:
                    write_read(gzip_dic[region][1], prev_read_name, read2[0],
                               read2[1])

            for gzip1_proc, gzip2_proc in gzip_dic.values():
                gzip1_proc.stdin.close()
                if paired:
                    gzip2_proc.stdin.close()

        if threads <= 1:
            work(fq_fname_base, fq_fname, fq_fname2, ranges, simulation,
                 verbose)
        else:
            parallel_work(pids, work, fq_fname_base, fq_fname, fq_fname2,
                          ranges, simulation, verbose)

    if threads > 1:
        wait_pids(pids)
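
# A call to the extractor above might look like the following.  Every value
# here is purely illustrative (hypothetical paths, sample layout, and settings,
# not taken from the original pipeline); database_list is a set because the
# function calls .add() on it.
extract_reads(base_fname="genotype_genome",
              database_list=set(["hla"]),
              read_dir="reads",
              out_dir="extracted",
              suffix="fq.gz",
              read_fname=[],            # empty: glob read_dir/*.1.fq.gz instead
              fastq=True,
              paired=True,
              simulation=False,
              threads=4,
              max_sample=sys.maxint,    # no cap on the number of samples
              job_range=[0, 1],         # single job, no striping across jobs
              verbose=True)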
Example #9
def genotype(base_fname,
             target_region_list,
             fastq,
             read_fnames,
             alignment_fname,
             threads,
             num_editdist,
             assembly,
             local_database,
             verbose,
             debug):
    # variants, backbone sequence, and other sequences
    genotype_fnames = ["%s.fa" % base_fname,
                       "%s.locus" % base_fname,
                       "%s.snp" % base_fname,
                       "%s.index.snp" % base_fname,
                       "%s.haplotype" % base_fname,
                       "%s.link" % base_fname,
                       "%s.coord" % base_fname,
                       "%s.clnsig" % base_fname]
    # hisat2 graph index files
    genotype_fnames += ["%s.%d.ht2" % (base_fname, i+1) for i in range(8)]
    if not typing_common.check_files(genotype_fnames):
        print >> sys.stderr, "Error: some of the following files are missing!"
        for fname in genotype_fnames:
            print >> sys.stderr, "\t%s" % fname
        sys.exit(1)

    # Read target loci (family and allele names with their genomic coordinates)
    regions, region_loci = {}, {}
    for line in open("%s.locus" % base_fname):
        family, allele_name, chr, left, right = line.strip().split()[:5]
        family = family.lower()
        if len(target_region_list) > 0 and \
           family not in target_region_list:
            continue
        
        locus_name = allele_name.split('*')[0]
        if family in target_region_list and \
           len(target_region_list[family]) > 0 and \
           locus_name not in target_region_list[family]:
            continue
        
        left, right = int(left), int(right)
        if family not in region_loci:
            region_loci[family] = []
        region_loci[family].append([locus_name, allele_name, chr, left, right])

    if len(region_loci) <= 0:
        print >> sys.stderr, "Warning: no region exists!"
        sys.exit(1)

    # Align reads, and sort the alignments into a BAM file
    if len(read_fnames) > 0:
        alignment_fname = align_reads(base_fname,
                                      read_fnames,
                                      fastq,
                                      threads,
                                      verbose)
    assert alignment_fname != "" and os.path.exists(alignment_fname)
    if not os.path.exists(alignment_fname + ".bai"):
        index_bam(alignment_fname,
                  verbose)
    assert os.path.exists(alignment_fname + ".bai")

    # Extract reads and perform genotyping
    for family, loci in region_loci.items():
        print >> sys.stderr, "Analyzing %s ..." % family.upper()
        for locus_name, allele_name, chr, left, right in loci:
            out_read_fname = "%s.%s" % (family, locus_name)
            if verbose:
                print >> sys.stderr, "\tExtracting reads beloning to %s-%s ..." % \
                    (family, locus_name)

            extracted_read_fnames = extract_reads(alignment_fname,
                                                  chr,
                                                  left,
                                                  right,
                                                  out_read_fname,
                                                  len(read_fnames) != 1, # paired?
                                                  fastq,
                                                  verbose)

            perform_genotyping(base_fname,
                               family,
                               [locus_name],
                               extracted_read_fnames,
                               fastq,
                               num_editdist,
                               assembly,
                               local_database,
                               threads,
                               verbose)
        print >> sys.stderr
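
# The extract_reads() called just above is not the whole-sample extractor shown
# earlier in this listing: it takes the indexed BAM plus one locus and returns
# the names of the read files it wrote for that region.  The sketch below is a
# much-simplified, single-end illustration of that behavior, assuming samtools
# is on PATH; it ignores mate pairing and does not restore the original strand
# of reverse-mapped reads (the function name and defaults are assumptions).
import subprocess  # already available in this script; repeated so the sketch stands alone

def extract_region_reads(alignment_fname, chr, left, right, out_read_fname,
                         fastq=True):
    # samtools regions are 1-based and inclusive; left/right are assumed to be
    # 0-based half-open here, hence the +1 on the start coordinate.
    region = "%s:%d-%d" % (chr, left + 1, right)
    view_proc = subprocess.Popen(["samtools", "view", alignment_fname, region],
                                 stdout=subprocess.PIPE,
                                 stderr=open("/dev/null", 'w'))
    out_fname = "%s.%s" % (out_read_fname, "fq" if fastq else "fa")
    out_file = open(out_fname, 'w')
    for line in view_proc.stdout:
        cols = line.strip().split('\t')
        read_name, seq, qual = cols[0], cols[9], cols[10]
        if fastq:
            out_file.write("@%s\n%s\n+\n%s\n" % (read_name, seq, qual))
        else:
            out_file.write(">%s\n%s\n" % (read_name, seq))
    out_file.close()
    return [out_fname]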