Esempio n. 1
0
    def open(self):
        """Open the HDMF file and set up chunks and taxonomy label.

        Opens ``self.path`` read-only (passing ``self.comm`` and the ``mpio``
        driver when a communicator is set) and reads it with all warnings
        suppressed. When ``self._world_size > 1``, restricts the file to the
        sequence subset ``balsplit`` returns for this rank. Then calls
        ``self.load``, sets the label key from ``self.hparams.tgt_tax_lvl``,
        and applies the optional window chunking, reverse-complement, and
        train/validate/test subsetting.
        """
        # Extra keyword arguments are only needed for a parallel (MPI) open.
        mpi_kwargs = {}
        if self.comm is not None:
            mpi_kwargs = {'comm': self.comm, 'driver': 'mpio'}
        self.io = get_hdf5io(self.path, 'r', **mpi_kwargs)

        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            self.orig_difile = self.io.read()

        if self._world_size > 1:
            # restrict this rank to its share of the sequences
            rank_subset = balsplit(self.orig_difile.get_seq_lengths(),
                                   self._world_size,
                                   self._global_rank)
            self.orig_difile.set_sequence_subset(rank_subset)

        self.difile = self.orig_difile
        self.load(sequence=self.load_data)
        self.difile.set_label_key(self.hparams.tgt_tax_lvl)

        if self.window is not None:
            self.set_chunks(self.window, self.step)
        if self.revcomp:
            self.set_revcomp()

        self._set_subset(train=self._train_subset,
                         validate=self._validate_subset,
                         test=self._test_subset)
Esempio n. 2
0
def prepare_data(argv=None):
    '''Aggregate sequence data GTDB using a file-of-files

    Reads a file of NCBI accessions, locates each genome's Fasta file under
    ``fadir``, looks up its GTDB taxonomy in ``metadata``, shears the Newick
    ``tree`` down to the representative genomes, and writes sequences,
    taxonomy, the tree and pairwise distances (from ``--dist_h5`` or, failing
    that, from tree tip-to-tip distances) into one DeepIndexFile at ``out``.

    Args:
        argv: command-line arguments to parse; ``None`` means ``sys.argv[1:]``.

    Raises:
        ValueError: if accessions cannot be matched one-to-one against the
            GTDB metadata.
    '''
    import argparse
    import io
    import sys
    import logging
    import h5py
    import pandas as pd

    from skbio import TreeNode

    from hdmf.common import get_hdf5io
    from hdmf.data_utils import DataChunkIterator

    from ..utils import get_faa_path, get_fna_path, get_genomic_path
    from exabiome.sequence.convert import AASeqIterator, DNASeqIterator, DNAVocabIterator, DNAVocabGeneIterator
    from exabiome.sequence.dna_table import AATable, DNATable, SequenceTable, TaxaTable, DeepIndexFile, NewickString, CondensedDistanceMatrix

    # NOTE(review): np, os, AAVocabIterator, select_distances,
    # select_embeddings and get_nonrep_matrix are not imported in this
    # function; they must be provided at module scope -- confirm.

    parser = argparse.ArgumentParser()
    parser.add_argument(
        'accessions',
        type=str,
        help='file of the NCBI accessions of the genomes to convert')
    parser.add_argument('fadir',
                        type=str,
                        help='directory with NCBI sequence files')
    parser.add_argument('metadata', type=str, help='metadata file from GTDB')
    parser.add_argument('tree', type=str, help='a Newick file with a tree of representative taxa')
    parser.add_argument('out', type=str, help='output HDF5')
    grp = parser.add_mutually_exclusive_group()
    parser.add_argument('-e',
                        '--emb',
                        type=str,
                        help='embedding file',
                        default=None)
    grp.add_argument('-p',
                     '--protein',
                     action='store_true',
                     default=False,
                     help='get paths for protein files')
    grp.add_argument('-c',
                     '--cds',
                     action='store_true',
                     default=False,
                     help='get paths for CDS files')
    grp.add_argument('-g',
                     '--genomic',
                     action='store_true',
                     default=False,
                     help='get paths for genomic files (default)')
    parser.add_argument('-D',
                        '--dist_h5',
                        type=str,
                        help='the distances file',
                        default=None)
    parser.add_argument(
        '-d',
        '--max_deg',
        type=float,
        default=None,
        help='max number of degenerate characters in protein sequences')
    parser.add_argument('-l',
                        '--min_len',
                        type=float,
                        default=None,
                        help='min length of sequences')
    parser.add_argument('-V',
                        '--vocab',
                        action='store_true',
                        default=False,
                        help='store sequences as vocabulary data')

    # Show help when there is nothing to parse. Check the arguments that will
    # actually be parsed: when ``argv`` is passed in, the process's own
    # command line is irrelevant.
    cli_args = sys.argv[1:] if argv is None else argv
    if len(cli_args) == 0:
        parser.print_help()
        sys.exit(1)

    args = parser.parse_args(args=argv)

    if not any([args.protein, args.cds, args.genomic]):
        args.genomic = True

    logging.basicConfig(stream=sys.stdout,
                        level=logging.INFO,
                        format='%(asctime)s - %(message)s')
    logger = logging.getLogger()

    # read accessions
    logger.info('reading accessions %s' % args.accessions)
    with open(args.accessions, 'r') as f:
        # strip line endings instead of chopping the last character, which
        # would truncate an accession on a final line without a newline
        taxa_ids = [line.strip() for line in f if line.strip()]

    # get paths to Fasta Files
    fa_path_func = get_genomic_path
    if args.cds:
        fa_path_func = get_fna_path
    elif args.protein:
        fa_path_func = get_faa_path
    fapaths = [fa_path_func(acc, args.fadir) for acc in taxa_ids]

    di_kwargs = dict()
    # if a distance matrix file has been given, read and select relevant distances
    if args.dist_h5:
        #############################
        # read and filter distances
        #############################
        logger.info('reading distances from %s' % args.dist_h5)
        with h5py.File(args.dist_h5, 'r') as f:
            dist = f['distances'][:]
            dist_taxa = f['leaf_names'][:].astype('U')
        logger.info('selecting distances for taxa found in %s' %
                    args.accessions)
        dist = select_distances(taxa_ids, dist_taxa, dist)
        dist = CondensedDistanceMatrix('distances', data=dist)
        di_kwargs['distances'] = dist

    #############################
    # read and filter taxonomies
    #############################
    logger.info('reading taxonomies from %s' % args.metadata)
    taxlevels = [
        'domain', 'phylum', 'class', 'order', 'family', 'genus', 'species'
    ]

    def func(row):
        """Split one metadata row's GTDB taxonomy string into per-level columns."""
        dat = dict(zip(taxlevels, row['gtdb_taxonomy'].split(';')))
        dat['species'] = dat['species'].split(' ')[1]
        # drop the 3-character prefix from the accession columns
        dat['gtdb_genome_representative'] = row['gtdb_genome_representative'][
            3:]
        dat['accession'] = row['accession'][3:]
        return pd.Series(data=dat)

    logger.info('selecting GTDB taxonomy for taxa found in %s' %
                args.accessions)
    metadata = pd.read_csv(args.metadata, header=0, sep='\t')
    taxdf = (metadata[['accession', 'gtdb_taxonomy', 'gtdb_genome_representative']]
             .apply(func, axis=1)
             .set_index('accession')
             .filter(items=taxa_ids, axis=0))

    # Every later step (rep_ids vs taxa_ids comparison, taxa table columns)
    # assumes one metadata row per accession, in accession order.
    if len(taxdf) != len(taxa_ids):
        missing = sorted(set(taxa_ids).difference(taxdf.index))
        raise ValueError(
            '%d accession(s) from %s could not be matched one-to-one in %s; '
            'missing: %s' % (len(taxa_ids) - len(taxdf), args.accessions,
                             args.metadata, missing[:10]))

    #############################
    # read and filter embeddings
    #############################
    emb = None
    if args.emb is not None:
        logger.info('reading embeddings from %s' % args.emb)
        with h5py.File(args.emb, 'r') as f:
            emb = f['embedding'][:]
            emb_taxa = f['leaf_names'][:]
        logger.info('selecting embeddings for taxa found in %s' %
                    args.accessions)
        emb = select_embeddings(taxa_ids, emb_taxa, emb)

    #############################
    # read and trim tree
    #############################
    logger.info('reading tree from %s' % args.tree)
    root = TreeNode.read(args.tree, format='newick')

    logger.info('transforming leaf names for shearing')
    for tip in root.tips():
        tip.name = tip.name[3:].replace(' ', '_')

    logger.info('shearing taxa not found in %s' % args.accessions)
    rep_ids = taxdf['gtdb_genome_representative'].values
    root = root.shear(rep_ids)

    logger.info('converting tree to Newick string')
    bytes_io = io.BytesIO()
    root.write(bytes_io, format='newick')
    tree_str = bytes_io.getvalue()
    tree = NewickString('tree', data=tree_str)

    if di_kwargs.get('distances') is None:
        # no distance file given: derive distances from the sheared tree,
        # expanding to non-representative genomes when any are present
        tt_dmat = root.tip_tip_distances()
        if (rep_ids != taxa_ids).any():
            tt_dmat = get_nonrep_matrix(taxa_ids, rep_ids, tt_dmat)
        dmat = tt_dmat.data
        di_kwargs['distances'] = CondensedDistanceMatrix('distances',
                                                         data=dmat)

    h5path = args.out

    logger.info("reading %d Fasta files" % len(fapaths))
    logger.info("Total size: %d", sum(os.path.getsize(f) for f in fapaths))

    if args.vocab:
        if args.protein:
            SeqTable = SequenceTable
            seqit = AAVocabIterator(fapaths,
                                    logger=logger,
                                    min_seq_len=args.min_len)
        else:
            SeqTable = DNATable
            if args.cds:
                logger.info("reading and writing CDS sequences")
                seqit = DNAVocabGeneIterator(fapaths,
                                             logger=logger,
                                             min_seq_len=args.min_len)
            else:
                seqit = DNAVocabIterator(fapaths,
                                         logger=logger,
                                         min_seq_len=args.min_len)
    else:
        if args.protein:
            logger.info("reading and writing protein sequences")
            seqit = AASeqIterator(fapaths,
                                  logger=logger,
                                  max_degenerate=args.max_deg,
                                  min_seq_len=args.min_len)
            SeqTable = AATable
        else:
            logger.info("reading and writing DNA sequences")
            seqit = DNASeqIterator(fapaths,
                                   logger=logger,
                                   min_seq_len=args.min_len)
            SeqTable = DNATable

    seqit_bsize = 2**25
    if args.protein:
        seqit_bsize = 2**15
    elif args.cds:
        seqit_bsize = 2**18

    # set up DataChunkIterators
    packed = DataChunkIterator.from_iterable(iter(seqit),
                                             maxshape=(None, ),
                                             buffer_size=seqit_bsize,
                                             dtype=np.dtype('uint8'))
    seqindex = DataChunkIterator.from_iterable(seqit.index_iter,
                                               maxshape=(None, ),
                                               buffer_size=2**0,
                                               dtype=np.dtype('int'))
    names = DataChunkIterator.from_iterable(seqit.names_iter,
                                            maxshape=(None, ),
                                            buffer_size=2**0,
                                            dtype=np.dtype('U'))
    ids = DataChunkIterator.from_iterable(seqit.id_iter,
                                          maxshape=(None, ),
                                          buffer_size=2**0,
                                          dtype=np.dtype('int'))
    taxa = DataChunkIterator.from_iterable(seqit.taxon_iter,
                                           maxshape=(None, ),
                                           buffer_size=2**0,
                                           dtype=np.dtype('uint16'))
    seqlens = DataChunkIterator.from_iterable(seqit.seqlens_iter,
                                              maxshape=(None, ),
                                              buffer_size=2**0,
                                              dtype=np.dtype('uint32'))

    # Named h5io so it does not shadow the ``io`` module imported above;
    # closed in ``finally`` so a failure while writing releases the file.
    h5io = get_hdf5io(h5path, 'w')
    try:
        tt_args = ['taxa_table', 'a table for storing taxa data', taxa_ids]
        tt_kwargs = dict()
        for t in taxlevels[1:]:
            tt_args.append(taxdf[t].values)
        if emb is not None:
            tt_kwargs['embedding'] = emb
        tt_kwargs['rep_taxon_id'] = rep_ids

        taxa_table = TaxaTable(*tt_args, **tt_kwargs)

        seq_table = SeqTable(
            'seq_table',
            'a table storing sequences for computing sequence embedding',
            h5io.set_dataio(names, compression='gzip', chunks=(2**15, )),
            h5io.set_dataio(packed,
                            compression='gzip',
                            maxshape=(None, ),
                            chunks=(2**15, )),
            h5io.set_dataio(seqindex,
                            compression='gzip',
                            maxshape=(None, ),
                            chunks=(2**15, )),
            h5io.set_dataio(seqlens,
                            compression='gzip',
                            maxshape=(None, ),
                            chunks=(2**15, )),
            h5io.set_dataio(taxa,
                            compression='gzip',
                            maxshape=(None, ),
                            chunks=(2**15, )),
            taxon_table=taxa_table,
            id=h5io.set_dataio(ids,
                               compression='gzip',
                               maxshape=(None, ),
                               chunks=(2**15, )))

        difile = DeepIndexFile(seq_table, taxa_table, tree, **di_kwargs)

        h5io.write(difile, exhaust_dci=False)
    finally:
        h5io.close()

    logger.info("reading %s" % (h5path))
    h5size = os.path.getsize(h5path)
    logger.info("HDF5 size: %d", h5size)
Esempio n. 3
0
parser.add_argument('-i',
                    '--index',
                    type=int,
                    nargs='+',
                    default=None,
                    help='specific indices to check')

args = parser.parse_args()

logging.basicConfig(stream=sys.stderr,
                    level=logging.DEBUG,
                    format='%(asctime)s - %(message)s')
logger = logging.getLogger()

logger.info('opening %s' % args.h5)
hdmfio = get_hdf5io(args.h5, 'r')
difile = hdmfio.read()
n_total_seqs = len(difile.seq_table)
logger.info('found %d sequences' % n_total_seqs)

if args.index is not None:
    # explicit indices were requested
    idx = set(args.index)
    logger.info('checking sequences %s' % ", ".join(map(str, args.index)))
else:
    logger.info('using seed %d' % args.seed)
    random = np.random.RandomState(args.seed)
    # n_seqs below 1.0 is treated as a fraction of all sequences, otherwise
    # as an absolute count. (The math module has no round(); use the builtin.)
    if args.n_seqs < 1.0:
        n_seqs = int(round(args.n_seqs * n_total_seqs))
    else:
        n_seqs = int(args.n_seqs)
    idx = set(random.permutation(n_total_seqs)[:n_seqs])
    logger.info('sampling %d sequences' % n_seqs)
from hdmf.common import get_hdf5io
from exabiome.response.embedding import read_embedding

import pandas as pd

if __name__ == '__main__':

    import argparse

    parser = argparse.ArgumentParser(description='substitute new embeddings in a DeepIndex input file')
    parser.add_argument('new_embeddings', type=str, help='the embeddings to add to the DeepIndex input file')
    parser.add_argument('deep_index_input', type=str, help='the DeepIndex file with embeddings to overwrite')
    args = parser.parse_args()

    emb, leaf_names = read_embedding(args.new_embeddings)
    # drop the first three characters of each leaf name
    # (presumably a source prefix such as 'RS_'/'GB_' -- confirm)
    leaf_names = [name[3:] for name in leaf_names]

    emb_df = pd.DataFrame(data={'embedding1': emb[:, 0], 'embedding2': emb[:, 1]}, index=leaf_names)

    hdmfio = get_hdf5io(args.deep_index_input, mode='a')
    try:
        difile = hdmfio.read()
        di_taxa = difile.taxa_table['taxon_id'][:]
        di_emb = difile.taxa_table['embedding'].data   # h5py.Dataset

        # Every taxon in the file needs a replacement row: filtering would
        # silently drop the missing ones, and the in-place write below would
        # then fail on a shape mismatch.
        missing = pd.Index(di_taxa).difference(emb_df.index)
        if len(missing) > 0:
            raise ValueError('%d taxa in %s have no embedding in %s, e.g. %s'
                             % (len(missing), args.deep_index_input,
                                args.new_embeddings, list(missing[:5])))

        # align rows with the order of the taxa table
        emb_df = emb_df.reindex(di_taxa)

        di_emb[:] = emb_df.values
    finally:
        hdmfio.close()
Esempio n. 5
0
from hdmf.common import get_hdf5io
from exabiome.sequence.convert import SeqConcat, pack_ohe_dna
from exabiome.sequence.dna_table import DNATable

# os is needed for os.path.getsize below but is not among this script's imports
import os

fnapath = "../deep_index/gtdb/test_data/genomes/all/GCA/000/989/525/GCA_000989525.1_ASM98952v1/GCA_000989525.1_ASM98952v1_cds_from_genomic.fna.gz"
h5path = "seq.h5"

# ## Read Fasta sequence

print("reading %s" % (fnapath))
fasize = os.path.getsize(fnapath)
print("Fasta size:", fasize)
sc = SeqConcat()
# returns the concatenated sequence data, the sequence index and the
# sequence tags (names taken from how they are used below)
data, seqindex, ltags = sc._read_path(fnapath)

# ## Pack sequence and write to HDF5 file

packed, padded = pack_ohe_dna(data)

with get_hdf5io(h5path, 'w') as io:
    table = DNATable('root', 'a test table',
                     io.set_dataio(ltags, compression='gzip'),
                     io.set_dataio(packed, compression='gzip'),
                     io.set_dataio(seqindex, compression='gzip'))
    io.write(table)

print("reading %s" % (h5path))
h5size = os.path.getsize(h5path)
print("HDF5 size:", h5size)
Esempio n. 6
0
def taxonomic_accuracy(argv=None):
    """Compute sequence- and base-level accuracy of taxonomic predictions.

    Reads per-sequence ``preds``, ``labels`` and ``lengths`` from the summary
    HDF5 file, discards sequences whose label is -1, and compares predictions
    with labels at the chosen taxonomic level and at every level that
    precedes it in ``DeepIndexFile.taxonomic_levels``. If ``--level`` is not
    given, the level is guessed by matching the number of classes in the
    summary file against the taxa table. Results are written as CSV to
    ``output`` and printed.

    Args:
        argv: command-line arguments; ``None`` means ``sys.argv[1:]``.
    """
    import argparse
    import sys
    from ..sequence import DeepIndexFile
    from ..utils import get_logger
    from hdmf.common import get_hdf5io
    import h5py
    import numpy as np
    import pandas as pd
    from sklearn.preprocessing import LabelEncoder


    levels = DeepIndexFile.taxonomic_levels

    parser = argparse.ArgumentParser()
    parser.add_argument("summary", type=str, help='the summarized sequence NN outputs')
    parser.add_argument("input", type=str, help='the training input data')
    parser.add_argument("output", type=str, help="the path to save resutls to")
    parser.add_argument("-l", "--level", type=str, choices=levels, help='the taxonomic level')

    args = parser.parse_args(argv)

    logger = get_logger()

    logger.info(f'reading {args.input}')
    hdmfio = get_hdf5io(args.input, 'r')
    difile = hdmfio.read()

    def get_classes(lvl):
        """Return the class names the taxa table stores for taxonomic level *lvl*."""
        # species is stored as plain data; the other levels keep their class
        # names in ``elements`` (see the n_classes lookups below)
        if lvl == 'species':
            return difile.taxa_table['species'].data[:]
        return difile.taxa_table[lvl].elements.data[:]

    with h5py.File(args.summary, 'r') as f:

        logger.info(f'loading summary results from {args.summary}')

        n_classes = None
        if 'outputs' in f:
            n_classes = f['outputs'].shape[1]
        elif 'n_classes' in f['labels'].attrs:
            n_classes = f['labels'].attrs['n_classes']

        level = None
        classes = None
        if args.level is None:
            if n_classes is None:
                print(f"Could not find number of classes in {args.summary}. Without this, I cannot guess what the taxonomic level is", file=sys.stderr)
                sys.exit(1)
            # if several levels match, the last one in ``levels`` wins
            for lvl in levels[:-1]:
                n_classes_lvl = difile.taxa_table[lvl].elements.data.shape[0]
                if n_classes == n_classes_lvl:
                    classes = get_classes(lvl)
                    level = lvl
            if level is None:
                n_classes_lvl = difile.taxa_table['species'].data.shape[0]
                if n_classes == n_classes_lvl:
                    level = 'species'
                    classes = get_classes('species')
                else:
                    print("Cannot determine which level to use. Please specify with --level option", file=sys.stderr)
                    sys.exit(1)
        else:
            level = args.level
            # classes are needed below to select and order taxa-table rows;
            # leaving them None made DataFrame.filter raise TypeError
            classes = get_classes(level)

        logger.info(f'computing accuracy for {level}{" and higher" if level else ""}')

        seq_preds = f['preds'][:].astype(int)
        seq_labels = f['labels'][:].astype(int)
        seq_lens = f['lengths'][:].astype(int)

    # a label of -1 marks a sequence that was never filled in
    mask = seq_labels != -1
    seq_preds = seq_preds[mask]
    seq_labels = seq_labels[mask]
    seq_lens = seq_lens[mask]

    logger.info(f'Keeping {mask.sum()} of {mask.shape[0]} ({mask.mean()*100:.1f}%) sequences after discarding uninitialized sequences')

    logger.info('loading taxonomy table')
    # do this because h5py.Datasets cannot point-index with non-unique indices
    for col in difile.taxa_table.columns:
        col.transform(lambda x: x[:])

    # drop taxon_id and every level after ``level`` in ``levels``
    to_drop = ['taxon_id']
    for lvl in levels[::-1]:
        if lvl == level:
            break
        to_drop.append(lvl)

    # orient table to index it by the taxonomic level and remove columns we cannot get predictions for
    taxdf = difile.taxa_table.to_dataframe()

    n_orig_classes = {col: np.unique(taxdf[col]).shape[0] for col in taxdf}

    taxdf = taxdf.drop(to_drop, axis=1).\
                  set_index(level).\
                  groupby(level).\
                  nth(0).\
                  filter(classes, axis=0)

    # everything needed from the input file is now in memory
    hdmfio.close()

    logger.info('encoding taxonomy for quicker comparisons')
    # encode into integers for faster comparisons
    encoders = dict()
    new_dat = dict()
    for col in taxdf.columns:
        enc = LabelEncoder().fit(taxdf[col])
        encoders[col] = enc
        new_dat[col] = enc.transform(taxdf[col])
    enc_df = pd.DataFrame(data=new_dat, index=taxdf.index)

    # a helper function to transform results into a DataFrame
    def get_results(true, pred, lens, n_classes):
        """Return seq-level and base-level accuracy (in %) plus a class-count summary."""
        mask = true == pred
        n_classes = "%s / %s" % ((np.unique(true).shape[0]), n_classes)
        return {'seq-level': "%0.1f" % (100*mask.mean()), 'base-level': "%0.1f" % (100*lens[mask].sum()/lens.sum()), 'n_classes': n_classes}

    results = dict()
    for colname in enc_df.columns:
        logger.info(f'computing results for {colname}')
        col = enc_df[colname].values
        # labels/preds are used as positional indices into this column
        results[colname] = get_results(col[seq_labels], col[seq_preds], seq_lens, n_orig_classes[colname])

    logger.info(f'computing results for {level}')
    results[level] = get_results(seq_labels, seq_preds, seq_lens, n_orig_classes[level])

    results['n'] = {'seq-level': len(seq_lens), 'base-level': seq_lens.sum(), 'n_classes': '-1'}

    results = pd.DataFrame(data=results)
    results.to_csv(args.output, sep=',')
    print(results)
Esempio n. 7
0
def prepare_data(argv=None):
    '''Aggregate sequence data GTDB using a file-of-files'''
    from io import BytesIO
    import tempfile
    import h5py

    from datetime import datetime

    from tqdm import tqdm

    from skbio import TreeNode
    from skbio.sequence import DNA, Protein

    from hdmf.common import get_hdf5io
    from hdmf.data_utils import DataChunkIterator

    from ..utils import get_faa_path, get_fna_path, get_genomic_path
    from deep_taxon.sequence.convert import AASeqIterator, DNASeqIterator, DNAVocabIterator, DNAVocabGeneIterator
    from deep_taxon.sequence.dna_table import AATable, DNATable, SequenceTable, TaxaTable, DeepIndexFile, NewickString, CondensedDistanceMatrix, GenomeTable, TreeGraph

    parser = argparse.ArgumentParser()
    parser.add_argument('fadir',
                        type=str,
                        help='directory with NCBI sequence files')
    parser.add_argument('metadata', type=str, help='metadata file from GTDB')
    parser.add_argument('out', type=str, help='output HDF5')
    parser.add_argument(
        '-T',
        '--tree',
        type=str,
        help='a Newick file with a tree of representative taxa',
        default=None)
    parser.add_argument(
        '-A',
        '--accessions',
        type=str,
        default=None,
        help='file of the NCBI accessions of the genomes to convert')
    parser.add_argument(
        '-d',
        '--max_deg',
        type=float,
        default=None,
        help='max number of degenerate characters in protein sequences')
    parser.add_argument('-l',
                        '--min_len',
                        type=float,
                        default=None,
                        help='min length of sequences')
    parser.add_argument('--iter',
                        action='store_true',
                        default=False,
                        help='convert using iterators')
    parser.add_argument(
        '-p',
        '--num_procs',
        type=int,
        default=1,
        help='the number of processes to use for counting total sequence size')
    parser.add_argument('-L',
                        '--total_seq_len',
                        type=int,
                        default=None,
                        help='the total sequence length')
    parser.add_argument('-t',
                        '--tmpdir',
                        type=str,
                        default=None,
                        help='a temporary directory to store sequences')
    parser.add_argument('-N',
                        '--n_seqs',
                        type=int,
                        default=None,
                        help='the total number of sequences')
    rep_grp = parser.add_mutually_exclusive_group()
    rep_grp.add_argument(
        '-n',
        '--nonrep',
        action='store_true',
        default=False,
        help='keep non-representative genomes only. keep both by default')
    rep_grp.add_argument(
        '-r',
        '--rep',
        action='store_true',
        default=False,
        help='keep representative genomes only. keep both by default')
    parser.add_argument(
        '-a',
        '--all',
        action='store_true',
        default=False,
        help=
        'keep all non-representative genomes. By default, only non-reps with the highest and lowest contig count are kept'
    )
    grp = parser.add_mutually_exclusive_group()
    grp.add_argument('-P',
                     '--protein',
                     action='store_true',
                     default=False,
                     help='get paths for protein files')
    grp.add_argument('-C',
                     '--cds',
                     action='store_true',
                     default=False,
                     help='get paths for CDS files')
    grp.add_argument('-G',
                     '--genomic',
                     action='store_true',
                     default=False,
                     help='get paths for genomic files (default)')
    parser.add_argument('-z',
                        '--gzip',
                        action='store_true',
                        default=False,
                        help='GZip sequence table')
    dep_grp = parser.add_argument_group(
        title="Legacy options you probably do not need")
    dep_grp.add_argument('-e',
                         '--emb',
                         type=str,
                         help='embedding file',
                         default=None)

    if len(sys.argv) == 1:
        parser.print_help()
        sys.exit(1)

    args = parser.parse_args(args=argv)

    if args.total_seq_len is not None:
        if args.n_seqs is None:
            sys.stderr.write(
                "If using --total_seq_len, you must also use --n_seqs\n")
    if args.n_seqs is not None:
        if args.total_seq_len is None:
            sys.stderr.write(
                "If using --n_seqs, you must also use --total_seq_len\n")

    if not any([args.protein, args.cds, args.genomic]):
        args.genomic = True

    logging.basicConfig(stream=sys.stderr,
                        level=logging.INFO,
                        format='%(asctime)s - %(message)s')
    logger = logging.getLogger()

    #############################
    # read and filter taxonomies
    #############################
    logger.info('Reading taxonomies from %s' % args.metadata)
    taxlevels = [
        'domain', 'phylum', 'class', 'order', 'family', 'genus', 'species'
    ]
    extra_cols = ['contig_count', 'checkm_completeness']

    def func(row):
        dat = dict(zip(taxlevels, row['gtdb_taxonomy'].split(';')))
        dat['species'] = dat['species']  # .split(' ')[1]
        dat['gtdb_genome_representative'] = row['gtdb_genome_representative'][
            3:]
        dat['accession'] = row['accession'][3:]
        for k in extra_cols:
            dat[k] = row[k]
        return pd.Series(data=dat)

    taxdf = pd.read_csv(args.metadata, header=0, sep='\t')[['accession', 'gtdb_taxonomy', 'gtdb_genome_representative', 'contig_count', 'checkm_completeness']]\
                        .apply(func, axis=1)

    taxdf = taxdf.set_index('accession')
    dflen = len(taxdf)
    logger.info('Found %d total genomes' % dflen)
    taxdf = taxdf[taxdf['gtdb_genome_representative'].str.contains(
        'GC[A,F]_', regex=True)]  # get rid of genomes that are not at NCBI
    taxdf = taxdf[taxdf.index.str.contains(
        'GC[A,F]_', regex=True)]  # get rid of genomes that are not at NCBI
    logger.info('Discarded %d non-NCBI genomes' % (dflen - len(taxdf)))

    rep_taxdf = taxdf[taxdf.index == taxdf['gtdb_genome_representative']]

    if args.accessions is not None:
        logger.info('reading accessions %s' % args.accessions)
        with open(args.accessions, 'r') as f:
            accessions = [l[:-1] for l in f.readlines()]
        dflen = len(taxdf)
        taxdf = taxdf[taxdf.index.isin(accessions)]
        logger.info('Discarded %d genomes not found in %s' %
                    (dflen - len(taxdf), args.accessions))

    dflen = len(taxdf)
    if args.nonrep:
        taxdf = taxdf[taxdf.index != taxdf['gtdb_genome_representative']]
        logger.info('Discarded %d representative genomes' %
                    (dflen - len(taxdf)))
        dflen = len(taxdf)
        if not args.all:
            groups = taxdf[['gtdb_genome_representative', 'contig_count'
                            ]].groupby('gtdb_genome_representative')
            min_ctgs = groups.idxmin()['contig_count']
            max_ctgs = groups.idxmax()['contig_count']
            accessions = np.unique(np.concatenate([min_ctgs, max_ctgs]))
            taxdf = taxdf.filter(accessions, axis=0)
            logger.info('Discarded %d extra non-representative genomes' %
                        (dflen - len(taxdf)))
    elif args.rep:
        taxdf = taxdf[taxdf.index == taxdf['gtdb_genome_representative']]
        logger.info('Discarded %d non-representative genomes' %
                    (dflen - len(taxdf)))

    dflen = len(taxdf)
    logger.info('%d remaining genomes' % dflen)

    ###############################
    # Arguments for constructing the DeepIndexFile object
    ###############################
    di_kwargs = dict()

    taxa_ids = taxdf.index.values

    # get paths to Fasta Files
    fa_path_func = partial(get_genomic_path, directory=args.fadir)
    if args.cds:
        fa_path_func = partial(get_fna_path, directory=args.fadir)
    elif args.protein:
        fa_path_func = partial(get_faa_path, directory=args.fadir)

    map_func = map
    if args.num_procs > 1:
        logger.info(f'using {args.num_procs} processes to locate Fasta files')
        import multiprocessing as mp
        map_func = mp.Pool(processes=args.num_procs).imap

    logger.info('Locating Fasta files for each taxa')
    fapaths = list(tqdm(map_func(fa_path_func, taxa_ids), total=len(taxa_ids)))

    logger.info('Found Fasta files for all accessions')

    #############################
    # read and filter embeddings
    #############################
    emb = None
    if args.emb is not None:
        logger.info('reading embeddings from %s' % args.emb)
        with h5py.File(args.emb, 'r') as f:
            emb = f['embedding'][:]
            emb_taxa = f['leaf_names'][:]
        logger.info('selecting embeddings for taxa found in %s' %
                    args.accessions)
        emb = select_embeddings(taxa_ids, emb_taxa, emb)

    logger.info(f'Writing {len(rep_taxdf)} taxa to taxa table')
    tt_args = [
        'taxa_table', 'a table for storing taxa data', rep_taxdf.index.values
    ]
    tt_kwargs = dict()
    for t in taxlevels[:-1]:
        enc = LabelEncoder().fit(rep_taxdf[t].values)
        _data = enc.transform(rep_taxdf[t].values).astype(np.uint32)
        _vocab = enc.classes_.astype('U')
        logger.info(f'{t} - {len(_vocab)} classes')
        tt_args.append(
            EnumData(name=t,
                     description=f'label encoded {t}',
                     data=_data,
                     elements=_vocab))
    # we have too many species to store this as VocabData, nor does it save any spaces
    tt_args.append(
        VectorData(name='species',
                   description=f'Microbial species in the form Genus species',
                   data=rep_taxdf['species'].values))

    if emb is not None:
        tt_kwargs['embedding'] = emb
    #tt_kwargs['rep_taxon_id'] = rep_taxdf['gtdb_genome_representative'].values

    taxa_table = TaxaTable(*tt_args, **tt_kwargs)

    h5path = args.out

    logger.info("reading %d Fasta files" % len(fapaths))
    logger.info("Total size: %d", sum(list(map_func(os.path.getsize,
                                                    fapaths))))

    tmp_h5_file = None
    if args.protein:
        vocab_it = AAVocabIterator
        SeqTable = SequenceTable
        skbio_cls = Protein
    else:
        vocab_it = DNAVocabIterator
        SeqTable = DNATable
        skbio_cls = DNA

    vocab = np.array(list(vocab_it.characters()))
    if not args.protein:
        np.testing.assert_array_equal(vocab, list('ACYWSKDVNTGRMHB'))

    if args.total_seq_len is None:
        logger.info('counting total number of sqeuences')
        n_seqs, total_seq_len = np.array(
            list(zip(
                *tqdm(map_func(seqlen, fapaths), total=len(fapaths))))).sum(
                    axis=1)
        logger.info(f'found {total_seq_len} bases across {n_seqs} sequences')
    else:
        n_seqs, total_seq_len = args.n_seqs, args.total_seq_len
        logger.info(
            f'As specified, there are {total_seq_len} bases across {n_seqs} sequences'
        )

    logger.info(
        f'allocating uint8 array of length {total_seq_len} for sequences')

    if args.tmpdir is not None:
        if not os.path.exists(args.tmpdir):
            os.mkdir(args.tmpdir)
        tmpdir = tempfile.mkdtemp(dir=args.tmpdir)
    else:
        tmpdir = tempfile.mkdtemp()

    comp = 'gzip' if args.gzip else None
    tmp_h5_filename = os.path.join(tmpdir, 'sequences.h5')
    logger.info(f'writing temporary sequence data to {tmp_h5_filename}')
    tmp_h5_file = h5py.File(tmp_h5_filename, 'w')
    sequence = tmp_h5_file.create_dataset('sequences',
                                          shape=(total_seq_len, ),
                                          dtype=np.uint8,
                                          compression=comp)
    seqindex = tmp_h5_file.create_dataset('sequences_index',
                                          shape=(n_seqs, ),
                                          dtype=np.uint64,
                                          compression=comp)
    genomes = tmp_h5_file.create_dataset('genomes',
                                         shape=(n_seqs, ),
                                         dtype=np.uint64,
                                         compression=comp)
    seqlens = tmp_h5_file.create_dataset('seqlens',
                                         shape=(n_seqs, ),
                                         dtype=np.uint64,
                                         compression=comp)
    names = tmp_h5_file.create_dataset('seqnames',
                                       shape=(n_seqs, ),
                                       dtype=h5py.special_dtype(vlen=str),
                                       compression=comp)

    taxa = np.zeros(len(fapaths), dtype=int)

    seq_i = 0
    b = 0
    for genome_i, fa in tqdm(enumerate(fapaths), total=len(fapaths)):
        kwargs = {
            'format': 'fasta',
            'constructor': skbio_cls,
            'validate': False
        }
        taxid = taxa_ids[genome_i]
        rep_taxid = taxdf['gtdb_genome_representative'][genome_i]
        taxa[genome_i] = np.where(rep_taxdf.index == rep_taxid)[0][0]
        for seq in skbio.io.read(fa, **kwargs):
            enc_seq = vocab_it.encode(seq)
            e = b + len(enc_seq)
            sequence[b:e] = enc_seq
            seqindex[seq_i] = e
            genomes[seq_i] = genome_i
            seqlens[seq_i] = len(enc_seq)
            names[seq_i] = vocab_it.get_seqname(seq)
            b = e
            seq_i += 1
    ids = tmp_h5_file.create_dataset('ids', data=np.arange(n_seqs), dtype=int)
    tmp_h5_file.flush()

    io = get_hdf5io(h5path, 'w')

    print([a['name'] for a in GenomeTable.__init__.__docval__['args']])

    genome_table = GenomeTable(
        'genome_table',
        'information about the genome each sequence comes from',
        taxa_ids,
        taxa,
        taxa_table=taxa_table)

    #############################
    # read and trim tree
    #############################
    if args.tree:
        logger.info('Reading tree from %s' % args.tree)
        root = TreeNode.read(args.tree, format='newick')

        logger.info('Found %d tips' % len(list(root.tips())))

        logger.info('Transforming leaf names for shearing')
        for tip in root.tips():
            tip.name = tip.name[3:].replace(' ', '_')

        logger.info('converting tree to Newick string')
        bytes_io = BytesIO()
        root.write(bytes_io, format='newick')
        tree_str = bytes_io.getvalue()
        di_kwargs['tree'] = NewickString('tree', data=tree_str)

        # get distances from tree if they are not provided
        tt_dmat = root.tip_tip_distances().filter(rep_taxdf.index)
        di_kwargs['distances'] = CondensedDistanceMatrix('distances',
                                                         data=tt_dmat.data)

        adj, gt_indices = get_tree_graph(root, rep_taxdf)
        di_kwargs['tree_graph'] = TreeGraph(data=adj,
                                            leaves=gt_indices,
                                            table=genome_table,
                                            name='tree_graph')

    if args.gzip:
        names = io.set_dataio(names, compression='gzip', chunks=True)
        sequence = io.set_dataio(sequence,
                                 compression='gzip',
                                 maxshape=(None, ),
                                 chunks=True)
        seqindex = io.set_dataio(seqindex,
                                 compression='gzip',
                                 maxshape=(None, ),
                                 chunks=True)
        seqlens = io.set_dataio(seqlens,
                                compression='gzip',
                                maxshape=(None, ),
                                chunks=True)
        genomes = io.set_dataio(genomes,
                                compression='gzip',
                                maxshape=(None, ),
                                chunks=True)
        ids = io.set_dataio(ids,
                            compression='gzip',
                            maxshape=(None, ),
                            chunks=True)

    seq_table = SeqTable(
        'seq_table',
        'a table storing sequences for computing sequence embedding',
        names,
        sequence,
        seqindex,
        seqlens,
        genomes,
        genome_table=genome_table,
        id=ids,
        vocab=vocab)

    difile = DeepIndexFile(seq_table, taxa_table, genome_table, **di_kwargs)

    before = datetime.now()
    io.write(difile, exhaust_dci=False, link_data=False)
    io.close()
    after = datetime.now()
    delta = (after - before).total_seconds()

    logger.info(
        f'Sequence totals {sequence.dtype.itemsize * sequence.size} bytes')
    logger.info(f'Took {delta} seconds to write after read')

    if tmp_h5_file is not None:
        tmp_h5_file.close()

    logger.info("reading %s" % (h5path))
    h5size = os.path.getsize(h5path)
    logger.info("HDF5 size: %d", h5size)
Esempio n. 8
0
def build_deployment_pkg(argv=None):
    """
    Build a zipped deployment package for a trained model.

    Steps performed:

    * stage files in ``output_dir``:
      * the taxa table of the ``input`` file, written as CSV indexed by
        ``taxon_id``;
      * copies of the NN model, the confidence model and the training
        config;
      * a ``manifest.json`` that records the staged paths and the sequence
        vocabulary;
    * archive ``output_dir`` to ``output_dir + ".zip"``;
    * remove ``output_dir``.

    Args:
        argv: command-line arguments, without the program name. Defaults to
            ``sys.argv[1:]`` when None.

    Exits with status 1 when no arguments are given, or when ``output_dir``
    already exists and ``--force`` was not passed. Exits with status 2
    (argparse error) when the staged files would collide by name.
    """

    import argparse
    import json
    import os
    import shutil
    import sys
    import zipfile

    desc = "Build a zipped deployment package for a trained model"
    epi = ("The files are staged in OUTPUT_DIR, archived to OUTPUT_DIR.zip, "
           "and OUTPUT_DIR is removed afterwards")

    parser = argparse.ArgumentParser(description=desc, epilog=epi)
    parser.add_argument('input',
                        type=str,
                        help='the input file to run inference on')
    parser.add_argument('config',
                        type=str,
                        help='the config file used for training')
    parser.add_argument('nn_model',
                        type=str,
                        help='the NN model for doing predictions')
    parser.add_argument(
        'conf_model',
        type=str,
        help='the checkpoint file to use for running inference')
    parser.add_argument('output_dir',
                        type=str,
                        help='the directory to copy to before zipping')
    parser.add_argument('-f',
                        '--force',
                        action='store_true',
                        default=False,
                        help='overwrite output if it exists')

    # Calling with the default argv=None previously crashed on len(None).
    if argv is None:
        argv = sys.argv[1:]

    if len(argv) == 0:
        parser.print_help()
        sys.exit(1)

    args = parser.parse_args(argv)

    # Every file is staged flat in output_dir under its basename. Duplicate
    # names would silently overwrite each other, so reject them up front.
    staged_names = [
        os.path.basename(p)
        for p in (args.nn_model, args.conf_model, args.config)
    ]
    staged_names += ['taxa_table.csv', 'manifest.json']
    if len(set(staged_names)) != len(staged_names):
        parser.error("nn_model, conf_model and config must have distinct "
                     "file names other than 'taxa_table.csv' and "
                     "'manifest.json'")

    # Deferred so that printing help or argument errors does not need hdmf.
    from hdmf.common import get_hdf5io

    logger = get_logger()

    if os.path.exists(args.output_dir):
        if args.force:
            logger.info(f"{args.output_dir} exists, removing tree")
            shutil.rmtree(args.output_dir)
        else:
            logger.error(f"{args.output_dir} exists, exiting")
            sys.exit(1)

    os.mkdir(args.output_dir)
    tmpdir = args.output_dir

    logger.info(f'Using temporary directory {tmpdir}')
    logger.info(f'loading sample input from {args.input}')

    io = get_hdf5io(args.input, 'r')
    try:
        difile = io.read()
        tt = difile.taxa_table
        vocab = difile.seq_table.sequence.elements.data[:].tolist()
        # Slice (x[:]) each column's data before building the DataFrame,
        # presumably to load it into memory rather than read it lazily.
        for col in tt.columns:
            col.transform(lambda x: x[:])
        tt_df = tt.to_dataframe().set_index('taxon_id')
    finally:
        # Close the file even if reading or conversion fails.
        io.close()

    def staged_path(src):
        # Destination of src inside the staging directory.
        return os.path.join(tmpdir, os.path.basename(src))

    manifest = {
        'taxa_table': os.path.join(tmpdir, "taxa_table.csv"),
        'nn_model': staged_path(args.nn_model),
        'conf_model': staged_path(args.conf_model),
        'training_config': staged_path(args.config),
        'vocabulary': vocab,
    }

    logger.info(f"exporting taxa table CSV to {manifest['taxa_table']}")
    tt_df.to_csv(manifest['taxa_table'])
    logger.info(f"copying {args.nn_model} to {manifest['nn_model']}")
    shutil.copyfile(args.nn_model, manifest['nn_model'])
    logger.info(f"copying {args.conf_model} to {manifest['conf_model']}")
    shutil.copyfile(args.conf_model, manifest['conf_model'])
    logger.info(f"copying {args.config} to {manifest['training_config']}")
    shutil.copyfile(args.config, manifest['training_config'])

    with open(os.path.join(tmpdir, 'manifest.json'), 'w') as f:
        json.dump(manifest, f, indent=4)

    # Normalize first: with a trailing separator ("pkg/"), the raw
    # concatenation gives "pkg/.zip". That archive would sit inside the
    # directory being walked and then be deleted by rmtree below.
    zip_path = os.path.normpath(args.output_dir) + ".zip"
    with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
        for root, _dirs, files in os.walk(tmpdir):
            for file in files:
                file_path = os.path.join(root, file)
                logger.info(f'adding {file_path} to {zip_path}')
                # Entries keep their on-disk path as the archive name, the
                # same paths that were recorded in the manifest.
                zipf.write(file_path)

    logger.info(f'removing {tmpdir}')
    shutil.rmtree(tmpdir)