def shogun_utree_lca(input, output, utree_indx, threads):
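    """Align each .fna file in `input` against the given UTree index and write a
    per-sample table of LCA taxon counts to `output`.

    Note: `threads` is accepted for interface consistency; `utree_search` is
    assumed to manage its own parallelism.
    """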
    verify_make_dir(output)

    basenames = [filename[:-4] for filename in os.listdir(input) if filename.endswith('.fna')]

    for basename in basenames:
        fna_file = os.path.join(input, basename + '.fna')
        tsv_outf = os.path.join(output, basename + '.utree.tsv')
        if not os.path.isfile(tsv_outf):
            print(utree_search(utree_indx, fna_file, tsv_outf))
        else:
            print("Found the output file \"%s\". Skipping the alignment phase for this file." % tsv_outf)

    utree_outf = os.path.join(output, 'taxon_counts.csv')
    # Tally LCA taxon counts per sample from the UTree alignment results
    if not os.path.isfile(utree_outf):
        counts = []
        for basename in basenames:
            lcas = []
            utree_tsv = os.path.join(output, basename + '.utree.tsv')
            with open(utree_tsv) as inf:
                tsv_parser = csv.reader(inf, delimiter='\t')
                for line in tsv_parser:
                    if line[1]:
                        lcas.append(line[1].replace('; ', ';'))
            counts.append(Counter(filter(None, lcas)))

        df = pd.DataFrame(counts, index=basenames)
        df.T.to_csv(utree_outf)
    else:
        print("Found the output file \"%s\". Skipping the counting phase." % utree_outf)


def shogun_bugbase(input, output, img_database_folder):
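    """Align each .fna file in `input` against the IMG UTree index in
    `img_database_folder`, map the resulting LCA taxa to IMG OIDs, and write an
    OTU-table-style taxa count table (for BugBase) to `output`.
    """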
    verify_make_dir(output)
    utree_outf = os.path.join(output, 'taxa_counts.txt')
    # Skip everything if the final taxa count table already exists
    if not os.path.isfile(utree_outf):

        utree_indx = os.path.join(img_database_folder, 'img.genes.ctr')
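        # Pickled map from taxonomy strings to IMG OIDs (the name suggests GreenGenes -> IMG)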
        with open(os.path.join(img_database_folder, 'img_map.pkl'), 'rb') as inf:
            gg2img_oid = pickle.load(inf)

        basenames = [filename[:-4] for filename in os.listdir(input) if filename.endswith('.fna')]

        for basename in basenames:
            fna_file = os.path.join(input, basename + '.fna')
            tsv_outf = os.path.join(output, basename + '.utree.tsv')
            if not os.path.isfile(tsv_outf):
                print(utree_search(utree_indx, fna_file, tsv_outf))
            else:
                print("Found the output file \"%s\". Skipping the alignment phase for this file." % tsv_outf)

        counts = []

        # Count IMG OIDs per sample, keeping only LCA taxa present in the IMG map
        for basename in basenames:
            lcas = []
            utree_tsv = os.path.join(output, basename + '.utree.tsv')
            with open(utree_tsv) as inf:
                tsv_parser = csv.reader(inf, delimiter='\t')
                for line in tsv_parser:
                    if line[1]:
                        taxon = line[1].replace('; ', ';')
                        if taxon in gg2img_oid:
                            lcas.append(gg2img_oid[taxon])
            counts.append(Counter(filter(None, lcas)))

        df = pd.DataFrame(counts, index=basenames).fillna(0).astype(int).T
        df.to_csv(utree_outf, sep='\t', index_label='#OTU ID')
    else:
        print("Found the output file \"%s\". Skipping all steps." % utree_outf)


def shogun_utree_capitalist(input, output, utree_indx, reference_fasta, reference_map, extract_ncbi_tid, threads):
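    """Align reads against a UTree index, group them by LCA taxon, re-align each
    group against that taxon's reference sequences with embalmer, and write the
    combined strain-level alignments to `output`.

    Note: `threads` is accepted for interface consistency; the underlying
    aligners are assumed to manage their own parallelism.
    """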
    verify_make_dir(output)

    basenames = [filename[:-4] for filename in os.listdir(input) if filename.endswith('.fna')]

    for basename in basenames:
        fna_file = os.path.join(input, basename + '.fna')
        tsv_outf = os.path.join(output, basename + '.utree.tsv')
        if not os.path.isfile(tsv_outf):
            print(utree_search(utree_indx, fna_file, tsv_outf))
        else:
            print("Found the output file \"%s\". Skipping the alignment phase for this file." % tsv_outf)

    embalmer_outf = os.path.join(output, 'embalmer_out.txt')
    # Run the per-species embalmer alignments only if the concatenated output is missing
    if not os.path.isfile(embalmer_outf):
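        # lca_maps[taxon][sample] -> list of read headers assigned to that taxon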
        lca_maps = defaultdict(lambda: defaultdict(list))
        for basename in basenames:
            utree_tsv = os.path.join(output, basename + '.utree.tsv')
            with open(utree_tsv) as inf:
                tsv_parser = csv.reader(inf, delimiter='\t')
                for line in tsv_parser:
                    if line[1]:
                        lca_maps[line[1].replace('; ', ';')][basename].append(line[0])

        # Random-access FASTA handles for pulling query reads by header
        fna_faidx = {}
        for basename in basenames:
            fna_faidx[basename] = pyfaidx.Fasta(os.path.join(input, basename + '.fna'))

        # dict_reference_map[taxon] -> reference sequence names for that taxon
        dict_reference_map = defaultdict(list)

        with open(reference_map) as inf:
            tsv_in = csv.reader(inf, delimiter='\t')
            for line in tsv_in:
                dict_reference_map[';'.join(line[1].split('; '))].append(line[0])

        # Random-access handle for the reference FASTA
        references_faidx = pyfaidx.Fasta(reference_fasta)

        # Per-species query/reference/output files are rewritten in this temporary directory
        tmpdir = tempfile.mkdtemp()
        print(tmpdir)
        with open(embalmer_outf, 'w') as embalmer_cat:
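            # Align each taxon's reads against only that taxon's reference sequences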
            for species in lca_maps.keys():

                queries_fna_filename = os.path.join(tmpdir, 'queries.fna')
                references_fna_filename = os.path.join(tmpdir, 'reference.fna')
                output_filename = os.path.join(tmpdir, 'output.txt')

                # Query headers are written as filename|<sample>|<read name>
                with open(queries_fna_filename, 'w') as queries_fna:
                    for basename in lca_maps[species].keys():
                        for header in lca_maps[species][basename]:
                            record = fna_faidx[basename][header][:]
                            queries_fna.write('>filename|%s|%s\n%s\n' % (basename, record.name, record.seq))

                with open(references_fna_filename, 'w') as references_fna:
                    for i in dict_reference_map[species]:
                        record = references_faidx[i][:]
                        references_fna.write('>%s\n%s\n' % (record.name, record.seq))

                print(embalmer_align(queries_fna_filename, references_fna_filename, output_filename))

                # Append this species' alignments to the concatenated output
                with open(output_filename) as embalmer_out:
                    for line in embalmer_out:
                        embalmer_cat.write(line)

                os.remove(queries_fna_filename)
                os.remove(references_fna_filename)
                os.remove(output_filename)

        os.rmdir(tmpdir)
    else:
        print("Found the output file \"%s\". Skipping the strain alignment phase." % embalmer_outf)


    # Convert the results from embalmer into CSV
    # sparse_ncbi_dict[query name][ncbi_tid] -> percent identity of the alignment
    sparse_ncbi_dict = defaultdict(dict)

    # extract_ncbi_tid gives the two delimiters flanking the NCBI taxon ID in
    # the reference names, as "begin,end"
    begin, end = extract_ncbi_tid.split(',')
    # build query by NCBI_TID DataFrame
    with open(embalmer_outf) as embalmer_cat:
        embalmer_csv = csv.reader(embalmer_cat, delimiter='\t')
        for line in embalmer_csv:
            # line[0] = qname, line[1] = rname, line[2] = %match
            ncbi_tid = int(find_between(line[1], begin, end))
            sparse_ncbi_dict[line[0]][ncbi_tid] = float(line[2])

    df = pd.DataFrame.from_dict(sparse_ncbi_dict)
    df.to_csv(os.path.join(output, 'strain_alignments.csv'))