def open(self):
    """Open the HDMF file and set up chunks and taxonomy label"""
    if self.comm is not None:
        self.io = get_hdf5io(self.path, 'r', comm=self.comm, driver='mpio')
    else:
        self.io = get_hdf5io(self.path, 'r')
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        self.orig_difile = self.io.read()
    if self._world_size > 1:
        self.orig_difile.set_sequence_subset(balsplit(self.orig_difile.get_seq_lengths(),
                                                      self._world_size, self._global_rank))
    self.difile = self.orig_difile
    self.load(sequence=self.load_data)
    self.difile.set_label_key(self.hparams.tgt_tax_lvl)
    if self.window is not None:
        self.set_chunks(self.window, self.step)
    if self.revcomp:
        self.set_revcomp()
    self._set_subset(train=self._train_subset, validate=self._validate_subset, test=self._test_subset)
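# Hypothetical standalone sketch (not part of the class above): the same
# read-and-label setup that open() performs, done directly against a DeepIndex
# HDF5 file. The file name and taxonomic level are placeholders.
from hdmf.common import get_hdf5io

io = get_hdf5io('train.h5', 'r')
difile = io.read()
difile.set_label_key('phylum')   # open() passes self.hparams.tgt_tax_lvl here
io.close()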
def prepare_data(argv=None):
    '''Aggregate sequence data GTDB using a file-of-files'''
    import argparse
    import io
    import logging
    import os
    import sys

    import h5py
    import numpy as np
    import pandas as pd
    from skbio import TreeNode

    from hdmf.common import get_hdf5io
    from hdmf.data_utils import DataChunkIterator

    from ..utils import get_faa_path, get_fna_path, get_genomic_path
    from exabiome.sequence.convert import AASeqIterator, AAVocabIterator, DNASeqIterator, DNAVocabIterator, DNAVocabGeneIterator
    from exabiome.sequence.dna_table import AATable, DNATable, SequenceTable, TaxaTable, DeepIndexFile, NewickString, CondensedDistanceMatrix

    parser = argparse.ArgumentParser()
    parser.add_argument('accessions', type=str,
                        help='file of the NCBI accessions of the genomes to convert')
    parser.add_argument('fadir', type=str, help='directory with NCBI sequence files')
    parser.add_argument('metadata', type=str, help='metadata file from GTDB')
    parser.add_argument('tree', type=str, help='Newick tree file')
    parser.add_argument('out', type=str, help='output HDF5')
    grp = parser.add_mutually_exclusive_group()
    parser.add_argument('-e', '--emb', type=str, help='embedding file', default=None)
    grp.add_argument('-p', '--protein', action='store_true', default=False,
                     help='get paths for protein files')
    grp.add_argument('-c', '--cds', action='store_true', default=False,
                     help='get paths for CDS files')
    grp.add_argument('-g', '--genomic', action='store_true', default=False,
                     help='get paths for genomic files (default)')
    parser.add_argument('-D', '--dist_h5', type=str, help='the distances file', default=None)
    parser.add_argument('-d', '--max_deg', type=float, default=None,
                        help='max number of degenerate characters in protein sequences')
    parser.add_argument('-l', '--min_len', type=float, default=None, help='min length of sequences')
    parser.add_argument('-V', '--vocab', action='store_true', default=False,
                        help='store sequences as vocabulary data')

    if len(sys.argv) == 1:
        parser.print_help()
        sys.exit(1)

    args = parser.parse_args(args=argv)

    if not any([args.protein, args.cds, args.genomic]):
        args.genomic = True

    logging.basicConfig(stream=sys.stdout, level=logging.INFO, format='%(asctime)s - %(message)s')
    logger = logging.getLogger()

    # read accessions
    logger.info('reading accessions %s' % args.accessions)
    with open(args.accessions, 'r') as f:
        taxa_ids = [l[:-1] for l in f.readlines()]

    # get paths to Fasta files
    fa_path_func = get_genomic_path
    if args.cds:
        fa_path_func = get_fna_path
    elif args.protein:
        fa_path_func = get_faa_path
    fapaths = [fa_path_func(acc, args.fadir) for acc in taxa_ids]

    di_kwargs = dict()
    # if a distance matrix file has been given, read and select relevant distances
    if args.dist_h5:
        #############################
        # read and filter distances
        #############################
        logger.info('reading distances from %s' % args.dist_h5)
        with h5py.File(args.dist_h5, 'r') as f:
            dist = f['distances'][:]
            dist_taxa = f['leaf_names'][:].astype('U')
        logger.info('selecting distances for taxa found in %s' % args.accessions)
        dist = select_distances(taxa_ids, dist_taxa, dist)
        dist = CondensedDistanceMatrix('distances', data=dist)
        di_kwargs['distances'] = dist

    #############################
    # read and filter taxonomies
    #############################
    logger.info('reading taxonomies from %s' % args.metadata)
    taxlevels = ['domain', 'phylum', 'class', 'order', 'family', 'genus', 'species']

    def func(row):
        dat = dict(zip(taxlevels, row['gtdb_taxonomy'].split(';')))
        dat['species'] = dat['species'].split(' ')[1]
        dat['gtdb_genome_representative'] = row['gtdb_genome_representative'][3:]
        dat['accession'] = row['accession'][3:]
        return pd.Series(data=dat)

    logger.info('selecting GTDB taxonomy for taxa found in %s' % args.accessions)
    taxdf = pd.read_csv(args.metadata, header=0, sep='\t')[['accession', 'gtdb_taxonomy', 'gtdb_genome_representative']]\
        .apply(func, axis=1)\
        .set_index('accession')\
        .filter(items=taxa_ids, axis=0)

    #############################
    # read and filter embeddings
    #############################
    emb = None
    if args.emb is not None:
        logger.info('reading embeddings from %s' % args.emb)
        with h5py.File(args.emb, 'r') as f:
            emb = f['embedding'][:]
            emb_taxa = f['leaf_names'][:]
        logger.info('selecting embeddings for taxa found in %s' % args.accessions)
        emb = select_embeddings(taxa_ids, emb_taxa, emb)

    #############################
    # read and trim tree
    #############################
    logger.info('reading tree from %s' % args.tree)
    root = TreeNode.read(args.tree, format='newick')

    logger.info('transforming leaf names for shearing')
    for tip in root.tips():
        tip.name = tip.name[3:].replace(' ', '_')

    logger.info('shearing taxa not found in %s' % args.accessions)
    rep_ids = taxdf['gtdb_genome_representative'].values
    root = root.shear(rep_ids)

    logger.info('converting tree to Newick string')
    bytes_io = io.BytesIO()
    root.write(bytes_io, format='newick')
    tree_str = bytes_io.getvalue()
    tree = NewickString('tree', data=tree_str)

    if di_kwargs.get('distances') is None:
        from scipy.spatial.distance import squareform
        tt_dmat = root.tip_tip_distances()
        if (rep_ids != taxa_ids).any():
            tt_dmat = get_nonrep_matrix(taxa_ids, rep_ids, tt_dmat)
        dmat = tt_dmat.data
        di_kwargs['distances'] = CondensedDistanceMatrix('distances', data=dmat)

    h5path = args.out

    logger.info("reading %d Fasta files" % len(fapaths))
    logger.info("Total size: %d", sum(os.path.getsize(f) for f in fapaths))

    if args.vocab:
        if args.protein:
            SeqTable = SequenceTable
            seqit = AAVocabIterator(fapaths, logger=logger, min_seq_len=args.min_len)
        else:
            SeqTable = DNATable
            if args.cds:
                logger.info("reading and writing CDS sequences")
                seqit = DNAVocabGeneIterator(fapaths, logger=logger, min_seq_len=args.min_len)
            else:
                seqit = DNAVocabIterator(fapaths, logger=logger, min_seq_len=args.min_len)
    else:
        if args.protein:
            logger.info("reading and writing protein sequences")
            seqit = AASeqIterator(fapaths, logger=logger, max_degenerate=args.max_deg, min_seq_len=args.min_len)
            SeqTable = AATable
        else:
            logger.info("reading and writing DNA sequences")
            seqit = DNASeqIterator(fapaths, logger=logger, min_seq_len=args.min_len)
            SeqTable = DNATable

    seqit_bsize = 2**25
    if args.protein:
        seqit_bsize = 2**15
    elif args.cds:
        seqit_bsize = 2**18

    # set up DataChunkIterators
    packed = DataChunkIterator.from_iterable(iter(seqit), maxshape=(None,),
                                             buffer_size=seqit_bsize, dtype=np.dtype('uint8'))
    seqindex = DataChunkIterator.from_iterable(seqit.index_iter, maxshape=(None,),
                                               buffer_size=2**0, dtype=np.dtype('int'))
    names = DataChunkIterator.from_iterable(seqit.names_iter, maxshape=(None,),
                                            buffer_size=2**0, dtype=np.dtype('U'))
    ids = DataChunkIterator.from_iterable(seqit.id_iter, maxshape=(None,),
                                          buffer_size=2**0, dtype=np.dtype('int'))
    taxa = DataChunkIterator.from_iterable(seqit.taxon_iter, maxshape=(None,),
                                           buffer_size=2**0, dtype=np.dtype('uint16'))
    seqlens = DataChunkIterator.from_iterable(seqit.seqlens_iter, maxshape=(None,),
                                              buffer_size=2**0, dtype=np.dtype('uint32'))

    io = get_hdf5io(h5path, 'w')

    tt_args = ['taxa_table', 'a table for storing taxa data', taxa_ids]
    tt_kwargs = dict()
    for t in taxlevels[1:]:
        tt_args.append(taxdf[t].values)
    if emb is not None:
        tt_kwargs['embedding'] = emb
    tt_kwargs['rep_taxon_id'] = rep_ids

    taxa_table = TaxaTable(*tt_args, **tt_kwargs)

    seq_table = SeqTable(
        'seq_table', 'a table storing sequences for computing sequence embedding',
        io.set_dataio(names, compression='gzip', chunks=(2**15,)),
        io.set_dataio(packed, compression='gzip', maxshape=(None,), chunks=(2**15,)),
        io.set_dataio(seqindex, compression='gzip', maxshape=(None,), chunks=(2**15,)),
        io.set_dataio(seqlens, compression='gzip', maxshape=(None,), chunks=(2**15,)),
        io.set_dataio(taxa, compression='gzip', maxshape=(None,), chunks=(2**15,)),
        taxon_table=taxa_table,
        id=io.set_dataio(ids, compression='gzip', maxshape=(None,), chunks=(2**15,)))

    difile = DeepIndexFile(seq_table, taxa_table, tree, **di_kwargs)

    io.write(difile, exhaust_dci=False)
    io.close()

    logger.info("reading %s" % (h5path))
    h5size = os.path.getsize(h5path)
    logger.info("HDF5 size: %d", h5size)
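# Hypothetical invocation sketch: calling prepare_data() programmatically with the
# positional and optional arguments defined by the parser above. All file names
# are placeholders.
if __name__ == '__main__':
    prepare_data(['accessions.txt',       # NCBI accessions, one per line
                  '/path/to/ncbi/fasta',  # directory with NCBI sequence files
                  'gtdb_metadata.tsv',    # metadata file from GTDB
                  'gtdb_tree.nwk',        # Newick tree file
                  'train.h5',             # output HDF5
                  '--vocab'])             # store sequences as vocabulary data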
    parser.add_argument('-i', '--index', type=int, nargs='+', default=None,
                        help='specific indices to check')

    args = parser.parse_args()

    logging.basicConfig(stream=sys.stderr, level=logging.DEBUG, format='%(asctime)s - %(message)s')
    logger = logging.getLogger()

    logger.info('opening %s' % args.h5)
    hdmfio = get_hdf5io(args.h5, 'r')
    difile = hdmfio.read()

    n_total_seqs = len(difile.seq_table)
    logger.info('found %d sequences' % n_total_seqs)

    if args.index is not None:
        idx = set(args.index)
        logger.info('checking sequences %s' % ", ".join(map(str, args.index)))
    else:
        logger.info('using seed %d' % args.seed)
        random = np.random.RandomState(args.seed)
        # interpret n_seqs as a fraction of the total if it is less than 1
        n_seqs = round(args.n_seqs * n_total_seqs) if args.n_seqs < 1.0 else int(args.n_seqs)
        idx = set(random.permutation(n_total_seqs)[:n_seqs])
        logger.info('sampling %d sequences' % n_seqs)
from hdmf.common import get_hdf5io
from exabiome.response.embedding import read_embedding
import pandas as pd

if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='substitute new embeddings in a DeepIndex input file')
    parser.add_argument('new_embeddings', type=str, help='the embeddings to add to the DeepIndex input file')
    parser.add_argument('deep_index_input', type=str, help='the DeepIndex file with embeddings to overwrite')
    args = parser.parse_args()

    emb, leaf_names = read_embedding(args.new_embeddings)
    leaf_names = [_[3:] for _ in leaf_names]
    emb_df = pd.DataFrame(data={'embedding1': emb[:, 0], 'embedding2': emb[:, 1]}, index=leaf_names)

    hdmfio = get_hdf5io(args.deep_index_input, mode='a')
    difile = hdmfio.read()

    di_taxa = difile.taxa_table['taxon_id'][:]
    di_emb = difile.taxa_table['embedding'].data   # h5py.Dataset

    emb_df = emb_df.filter(items=di_taxa, axis=0)
    di_emb[:] = emb_df.values

    hdmfio.close()
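    # Hypothetical verification sketch: re-open the DeepIndex file read-only and
    # confirm the embedding column now holds the new values.
    check_io = get_hdf5io(args.deep_index_input, mode='r')
    print(check_io.read().taxa_table['embedding'].data[:5])   # first five updated rows
    check_io.close()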
import os

from hdmf.common import get_hdf5io
from exabiome.sequence.convert import SeqConcat, pack_ohe_dna
from exabiome.sequence.dna_table import DNATable

fnapath = "../deep_index/gtdb/test_data/genomes/all/GCA/000/989/525/GCA_000989525.1_ASM98952v1/GCA_000989525.1_ASM98952v1_cds_from_genomic.fna.gz"
h5path = "seq.h5"

# ## Read Fasta sequence
print("reading %s" % (fnapath))
fasize = os.path.getsize(fnapath)
print("Fasta size:", fasize)

sc = SeqConcat()
data, seqindex, ltags = sc._read_path(fnapath)

# ## Pack sequence and write to HDF5 file
packed, padded = pack_ohe_dna(data)

with get_hdf5io(h5path, 'w') as io:
    table = DNATable('root', 'a test table',
                     io.set_dataio(ltags, compression='gzip'),
                     io.set_dataio(packed, compression='gzip'),
                     io.set_dataio(seqindex, compression='gzip'))
    io.write(table)

print("reading %s" % (h5path))
h5size = os.path.getsize(h5path)
print("HDF5 size:", h5size)
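# Hypothetical read-back sketch: inspect the datasets hdmf just wrote to seq.h5.
# h5py is used directly so no assumptions are made about how hdmf maps the file
# back to a DNATable container.
import h5py

with h5py.File(h5path, 'r') as f:
    f.visit(print)   # print the name of every group/dataset written above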
def taxonomic_accuracy(argv=None):
    #import ..sequence as seq
    import argparse
    import sys

    import h5py
    import numpy as np
    import pandas as pd
    from sklearn.preprocessing import LabelEncoder

    from hdmf.common import get_hdf5io

    from ..sequence import DeepIndexFile
    from ..utils import get_logger

    levels = DeepIndexFile.taxonomic_levels

    parser = argparse.ArgumentParser()
    parser.add_argument("summary", type=str, help='the summarized sequence NN outputs')
    parser.add_argument("input", type=str, help='the training input data')
    parser.add_argument("output", type=str, help="the path to save results to")
    parser.add_argument("-l", "--level", type=str, choices=levels, help='the taxonomic level')
    args = parser.parse_args(argv)

    logger = get_logger()

    logger.info(f'reading {args.input}')
    io = get_hdf5io(args.input, 'r')
    difile = io.read()

    with h5py.File(args.summary, 'r') as f:
        logger.info(f'loading summary results from {args.summary}')
        n_classes = None
        if 'outputs' in f:
            n_classes = f['outputs'].shape[1]
        elif 'n_classes' in f['labels'].attrs:
            n_classes = f['labels'].attrs['n_classes']

        level = None
        classes = None
        if args.level is None:
            if n_classes is None:
                print(f"Could not find number of classes in {args.summary}. "
                      "Without this, I cannot guess what the taxonomic level is")
                exit(1)
            for lvl in levels[:-1]:
                n_classes_lvl = difile.taxa_table[lvl].elements.data.shape[0]
                if n_classes == n_classes_lvl:
                    classes = difile.taxa_table[lvl].elements.data
                    level = lvl
            if level is None:
                n_classes_lvl = difile.taxa_table['species'].data.shape[0]
                if n_classes == n_classes_lvl:
                    level = 'species'
                    classes = difile.taxa_table['species'].data[:]
                else:
                    print("Cannot determine which level to use. Please specify with --level option",
                          file=sys.stderr)
                    exit(1)
        else:
            level = args.level

        logger.info(f'computing accuracy for {level}{" and higher" if level else ""}')

        seq_preds = f['preds'][:].astype(int)
        seq_labels = f['labels'][:].astype(int)
        seq_lens = f['lengths'][:].astype(int)

    mask = seq_labels != -1
    seq_preds = seq_preds[mask]
    seq_labels = seq_labels[mask]
    seq_lens = seq_lens[mask]
    logger.info(f'Keeping {mask.sum()} of {mask.shape[0]} ({mask.mean()*100:.1f}%) sequences '
                'after discarding uninitialized sequences')

    ## I used this code to double check that genus elements were correct
    # seq_ids = f['seq_ids'][:]
    # genome_ids = difile.seq_table['genome'].data[:][seq_ids]
    # taxon_ids = difile.genome_table['taxon_id'].data[:][genome_ids]
    # classes = difile.taxa_table['genus'].elements.data[:]

    logger.info('loading taxonomy table')
    # do this because h5py.Datasets cannot point-index with non-unique indices
    for col in difile.taxa_table.columns:
        col.transform(lambda x: x[:])

    to_drop = ['taxon_id']
    for lvl in levels[::-1]:
        if lvl == level:
            break
        to_drop.append(lvl)

    # orient table to index it by the taxonomic level and remove columns we cannot get predictions for
    taxdf = difile.taxa_table.to_dataframe()
    n_orig_classes = {col: np.unique(taxdf[col]).shape[0] for col in taxdf}
    taxdf = taxdf.drop(to_drop, axis=1).\
        set_index(level).\
        groupby(level).\
        nth(0).\
        filter(classes, axis=0)

    logger.info('encoding taxonomy for quicker comparisons')
    # encode into integers for faster comparisons
    encoders = dict()
    new_dat = dict()
    for col in taxdf.columns:
        enc = LabelEncoder().fit(taxdf[col])
        encoders[col] = enc
        new_dat[col] = enc.transform(taxdf[col])
    enc_df = pd.DataFrame(data=new_dat, index=taxdf.index)

    # a helper function to transform results into a DataFrame
    def get_results(true, pred, lens, n_classes):
        mask = true == pred
        n_classes = "%s / %s" % ((np.unique(true).shape[0]), n_classes)
        return {'seq-level': "%0.1f" % (100*mask.mean()),
                'base-level': "%0.1f" % (100*lens[mask].sum()/lens.sum()),
                'n_classes': n_classes}

    results = dict()
    for colname in enc_df.columns:
        logger.info(f'computing results for {colname}')
        col = enc_df[colname].values
        results[colname] = get_results(col[seq_labels], col[seq_preds], seq_lens, n_orig_classes[colname])

    logger.info(f'computing results for {level}')
    results[level] = get_results(seq_labels, seq_preds, seq_lens, n_orig_classes[level])
    results['n'] = {'seq-level': len(seq_lens), 'base-level': seq_lens.sum(), 'n_classes': '-1'}

    results = pd.DataFrame(data=results)
    results.to_csv(args.output, sep=',')
    print(results)
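# Hypothetical follow-up sketch (placeholder path 'accuracy.csv'): re-load the CSV
# written by taxonomic_accuracy(). Rows hold the 'seq-level', 'base-level', and
# 'n_classes' metrics; columns are the taxonomic levels plus the 'n' summary column.
if __name__ == '__main__':
    import pandas as pd
    acc = pd.read_csv('accuracy.csv', index_col=0)
    print(acc.T)   # one row per taxonomic level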
def prepare_data(argv=None):
    '''Aggregate sequence data GTDB using a file-of-files'''
    import argparse
    import logging
    import os
    import sys
    import tempfile
    from datetime import datetime
    from functools import partial
    from io import BytesIO

    import h5py
    import numpy as np
    import pandas as pd
    import skbio
    from skbio import TreeNode
    from skbio.sequence import DNA, Protein
    from sklearn.preprocessing import LabelEncoder
    from tqdm import tqdm

    from hdmf.common import get_hdf5io, EnumData, VectorData
    from hdmf.data_utils import DataChunkIterator

    from ..utils import get_faa_path, get_fna_path, get_genomic_path
    from deep_taxon.sequence.convert import AASeqIterator, AAVocabIterator, DNASeqIterator, DNAVocabIterator, DNAVocabGeneIterator
    from deep_taxon.sequence.dna_table import AATable, DNATable, SequenceTable, TaxaTable, DeepIndexFile, NewickString, CondensedDistanceMatrix, GenomeTable, TreeGraph

    parser = argparse.ArgumentParser()
    parser.add_argument('fadir', type=str, help='directory with NCBI sequence files')
    parser.add_argument('metadata', type=str, help='metadata file from GTDB')
    parser.add_argument('out', type=str, help='output HDF5')
    parser.add_argument('-T', '--tree', type=str, default=None,
                        help='a Newick file with a tree of representative taxa')
    parser.add_argument('-A', '--accessions', type=str, default=None,
                        help='file of the NCBI accessions of the genomes to convert')
    parser.add_argument('-d', '--max_deg', type=float, default=None,
                        help='max number of degenerate characters in protein sequences')
    parser.add_argument('-l', '--min_len', type=float, default=None, help='min length of sequences')
    parser.add_argument('--iter', action='store_true', default=False, help='convert using iterators')
    parser.add_argument('-p', '--num_procs', type=int, default=1,
                        help='the number of processes to use for counting total sequence size')
    parser.add_argument('-L', '--total_seq_len', type=int, default=None, help='the total sequence length')
    parser.add_argument('-t', '--tmpdir', type=str, default=None, help='a temporary directory to store sequences')
    parser.add_argument('-N', '--n_seqs', type=int, default=None, help='the total number of sequences')
    rep_grp = parser.add_mutually_exclusive_group()
    rep_grp.add_argument('-n', '--nonrep', action='store_true', default=False,
                         help='keep non-representative genomes only. keep both by default')
    rep_grp.add_argument('-r', '--rep', action='store_true', default=False,
                         help='keep representative genomes only. keep both by default')
    parser.add_argument('-a', '--all', action='store_true', default=False,
                        help='keep all non-representative genomes. By default, only non-reps '
                             'with the highest and lowest contig count are kept')
    grp = parser.add_mutually_exclusive_group()
    grp.add_argument('-P', '--protein', action='store_true', default=False,
                     help='get paths for protein files')
    grp.add_argument('-C', '--cds', action='store_true', default=False,
                     help='get paths for CDS files')
    grp.add_argument('-G', '--genomic', action='store_true', default=False,
                     help='get paths for genomic files (default)')
    parser.add_argument('-z', '--gzip', action='store_true', default=False, help='GZip sequence table')
    dep_grp = parser.add_argument_group(title="Legacy options you probably do not need")
    dep_grp.add_argument('-e', '--emb', type=str, help='embedding file', default=None)

    if len(sys.argv) == 1:
        parser.print_help()
        sys.exit(1)

    args = parser.parse_args(args=argv)

    if args.total_seq_len is not None:
        if args.n_seqs is None:
            sys.stderr.write("If using --total_seq_len, you must also use --n_seqs\n")
    if args.n_seqs is not None:
        if args.total_seq_len is None:
            sys.stderr.write("If using --n_seqs, you must also use --total_seq_len\n")

    if not any([args.protein, args.cds, args.genomic]):
        args.genomic = True

    logging.basicConfig(stream=sys.stderr, level=logging.INFO, format='%(asctime)s - %(message)s')
    logger = logging.getLogger()

    #############################
    # read and filter taxonomies
    #############################
    logger.info('Reading taxonomies from %s' % args.metadata)
    taxlevels = ['domain', 'phylum', 'class', 'order', 'family', 'genus', 'species']
    extra_cols = ['contig_count', 'checkm_completeness']

    def func(row):
        dat = dict(zip(taxlevels, row['gtdb_taxonomy'].split(';')))
        dat['species'] = dat['species']  # .split(' ')[1]
        dat['gtdb_genome_representative'] = row['gtdb_genome_representative'][3:]
        dat['accession'] = row['accession'][3:]
        for k in extra_cols:
            dat[k] = row[k]
        return pd.Series(data=dat)

    taxdf = pd.read_csv(args.metadata, header=0, sep='\t')[['accession', 'gtdb_taxonomy', 'gtdb_genome_representative',
                                                            'contig_count', 'checkm_completeness']]\
        .apply(func, axis=1)

    taxdf = taxdf.set_index('accession')
    dflen = len(taxdf)
    logger.info('Found %d total genomes' % dflen)
    taxdf = taxdf[taxdf['gtdb_genome_representative'].str.contains('GC[A,F]_', regex=True)]  # get rid of genomes that are not at NCBI
    taxdf = taxdf[taxdf.index.str.contains('GC[A,F]_', regex=True)]                          # get rid of genomes that are not at NCBI
    logger.info('Discarded %d non-NCBI genomes' % (dflen - len(taxdf)))

    rep_taxdf = taxdf[taxdf.index == taxdf['gtdb_genome_representative']]

    if args.accessions is not None:
        logger.info('reading accessions %s' % args.accessions)
        with open(args.accessions, 'r') as f:
            accessions = [l[:-1] for l in f.readlines()]
        dflen = len(taxdf)
        taxdf = taxdf[taxdf.index.isin(accessions)]
        logger.info('Discarded %d genomes not found in %s' % (dflen - len(taxdf), args.accessions))

    dflen = len(taxdf)
    if args.nonrep:
        taxdf = taxdf[taxdf.index != taxdf['gtdb_genome_representative']]
        logger.info('Discarded %d representative genomes' % (dflen - len(taxdf)))
        dflen = len(taxdf)
        if not args.all:
            groups = taxdf[['gtdb_genome_representative', 'contig_count']].groupby('gtdb_genome_representative')
            min_ctgs = groups.idxmin()['contig_count']
            max_ctgs = groups.idxmax()['contig_count']
            accessions = np.unique(np.concatenate([min_ctgs, max_ctgs]))
            taxdf = taxdf.filter(accessions, axis=0)
            logger.info('Discarded %d extra non-representative genomes' % (dflen - len(taxdf)))
    elif args.rep:
        taxdf = taxdf[taxdf.index == taxdf['gtdb_genome_representative']]
        logger.info('Discarded %d non-representative genomes' % (dflen - len(taxdf)))

    dflen = len(taxdf)
    logger.info('%d remaining genomes' % dflen)

    ###############################
    # Arguments for constructing the DeepIndexFile object
    ###############################
    di_kwargs = dict()

    taxa_ids = taxdf.index.values

    # get paths to Fasta files
    fa_path_func = partial(get_genomic_path, directory=args.fadir)
    if args.cds:
        fa_path_func = partial(get_fna_path, directory=args.fadir)
    elif args.protein:
        fa_path_func = partial(get_faa_path, directory=args.fadir)

    map_func = map
    if args.num_procs > 1:
        logger.info(f'using {args.num_procs} processes to locate Fasta files')
        import multiprocessing as mp
        map_func = mp.Pool(processes=args.num_procs).imap

    logger.info('Locating Fasta files for each taxa')
    fapaths = list(tqdm(map_func(fa_path_func, taxa_ids), total=len(taxa_ids)))

    logger.info('Found Fasta files for all accessions')

    #############################
    # read and filter embeddings
    #############################
    emb = None
    if args.emb is not None:
        logger.info('reading embeddings from %s' % args.emb)
        with h5py.File(args.emb, 'r') as f:
            emb = f['embedding'][:]
            emb_taxa = f['leaf_names'][:]
        logger.info('selecting embeddings for taxa found in %s' % args.accessions)
        emb = select_embeddings(taxa_ids, emb_taxa, emb)

    logger.info(f'Writing {len(rep_taxdf)} taxa to taxa table')
    tt_args = ['taxa_table', 'a table for storing taxa data', rep_taxdf.index.values]
    tt_kwargs = dict()
    for t in taxlevels[:-1]:
        enc = LabelEncoder().fit(rep_taxdf[t].values)
        _data = enc.transform(rep_taxdf[t].values).astype(np.uint32)
        _vocab = enc.classes_.astype('U')
        logger.info(f'{t} - {len(_vocab)} classes')
        tt_args.append(EnumData(name=t, description=f'label encoded {t}', data=_data, elements=_vocab))
    # we have too many species to store this as VocabData, nor does it save any space
    tt_args.append(VectorData(name='species', description='Microbial species in the form Genus species',
                              data=rep_taxdf['species'].values))

    if emb is not None:
        tt_kwargs['embedding'] = emb
    #tt_kwargs['rep_taxon_id'] = rep_taxdf['gtdb_genome_representative'].values

    taxa_table = TaxaTable(*tt_args, **tt_kwargs)

    h5path = args.out

    logger.info("reading %d Fasta files" % len(fapaths))
    logger.info("Total size: %d", sum(list(map_func(os.path.getsize, fapaths))))

    tmp_h5_file = None
    if args.protein:
        vocab_it = AAVocabIterator
        SeqTable = SequenceTable
        skbio_cls = Protein
    else:
        vocab_it = DNAVocabIterator
        SeqTable = DNATable
        skbio_cls = DNA

    vocab = np.array(list(vocab_it.characters()))
    if not args.protein:
        np.testing.assert_array_equal(vocab, list('ACYWSKDVNTGRMHB'))

    if args.total_seq_len is None:
        logger.info('counting total number of sequences')
        n_seqs, total_seq_len = np.array(list(zip(*tqdm(map_func(seqlen, fapaths), total=len(fapaths))))).sum(axis=1)
        logger.info(f'found {total_seq_len} bases across {n_seqs} sequences')
    else:
        n_seqs, total_seq_len = args.n_seqs, args.total_seq_len
        logger.info(f'As specified, there are {total_seq_len} bases across {n_seqs} sequences')

    logger.info(f'allocating uint8 array of length {total_seq_len} for sequences')

    if args.tmpdir is not None:
        if not os.path.exists(args.tmpdir):
            os.mkdir(args.tmpdir)
        tmpdir = tempfile.mkdtemp(dir=args.tmpdir)
    else:
        tmpdir = tempfile.mkdtemp()

    comp = 'gzip' if args.gzip else None
    tmp_h5_filename = os.path.join(tmpdir, 'sequences.h5')
    logger.info(f'writing temporary sequence data to {tmp_h5_filename}')
    tmp_h5_file = h5py.File(tmp_h5_filename, 'w')
    sequence = tmp_h5_file.create_dataset('sequences', shape=(total_seq_len,), dtype=np.uint8, compression=comp)
    seqindex = tmp_h5_file.create_dataset('sequences_index', shape=(n_seqs,), dtype=np.uint64, compression=comp)
    genomes = tmp_h5_file.create_dataset('genomes', shape=(n_seqs,), dtype=np.uint64, compression=comp)
    seqlens = tmp_h5_file.create_dataset('seqlens', shape=(n_seqs,), dtype=np.uint64, compression=comp)
    names = tmp_h5_file.create_dataset('seqnames', shape=(n_seqs,), dtype=h5py.special_dtype(vlen=str), compression=comp)

    taxa = np.zeros(len(fapaths), dtype=int)

    seq_i = 0
    b = 0
    for genome_i, fa in tqdm(enumerate(fapaths), total=len(fapaths)):
        kwargs = {'format': 'fasta', 'constructor': skbio_cls, 'validate': False}
        taxid = taxa_ids[genome_i]
        rep_taxid = taxdf['gtdb_genome_representative'][genome_i]
        taxa[genome_i] = np.where(rep_taxdf.index == rep_taxid)[0][0]
        for seq in skbio.io.read(fa, **kwargs):
            enc_seq = vocab_it.encode(seq)
            e = b + len(enc_seq)
            sequence[b:e] = enc_seq
            seqindex[seq_i] = e
            genomes[seq_i] = genome_i
            seqlens[seq_i] = len(enc_seq)
            names[seq_i] = vocab_it.get_seqname(seq)
            b = e
            seq_i += 1

    ids = tmp_h5_file.create_dataset('ids', data=np.arange(n_seqs), dtype=int)
    tmp_h5_file.flush()

    io = get_hdf5io(h5path, 'w')

    print([a['name'] for a in GenomeTable.__init__.__docval__['args']])

    genome_table = GenomeTable('genome_table', 'information about the genome each sequence comes from',
                               taxa_ids, taxa, taxa_table=taxa_table)

    #############################
    # read and trim tree
    #############################
    if args.tree:
        logger.info('Reading tree from %s' % args.tree)
        root = TreeNode.read(args.tree, format='newick')

        logger.info('Found %d tips' % len(list(root.tips())))

        logger.info('Transforming leaf names for shearing')
        for tip in root.tips():
            tip.name = tip.name[3:].replace(' ', '_')

        logger.info('converting tree to Newick string')
        bytes_io = BytesIO()
        root.write(bytes_io, format='newick')
        tree_str = bytes_io.getvalue()
        di_kwargs['tree'] = NewickString('tree', data=tree_str)

        # get distances from tree if they are not provided
        tt_dmat = root.tip_tip_distances().filter(rep_taxdf.index)
        di_kwargs['distances'] = CondensedDistanceMatrix('distances', data=tt_dmat.data)

        adj, gt_indices = get_tree_graph(root, rep_taxdf)
        di_kwargs['tree_graph'] = TreeGraph(data=adj, leaves=gt_indices, table=genome_table, name='tree_graph')

    if args.gzip:
        names = io.set_dataio(names, compression='gzip', chunks=True)
        sequence = io.set_dataio(sequence, compression='gzip', maxshape=(None,), chunks=True)
        seqindex = io.set_dataio(seqindex, compression='gzip', maxshape=(None,), chunks=True)
        seqlens = io.set_dataio(seqlens, compression='gzip', maxshape=(None,), chunks=True)
        genomes = io.set_dataio(genomes, compression='gzip', maxshape=(None,), chunks=True)
        ids = io.set_dataio(ids, compression='gzip', maxshape=(None,), chunks=True)

    seq_table = SeqTable('seq_table', 'a table storing sequences for computing sequence embedding',
                         names, sequence, seqindex, seqlens, genomes,
                         genome_table=genome_table, id=ids, vocab=vocab)

    difile = DeepIndexFile(seq_table, taxa_table, genome_table, **di_kwargs)

    before = datetime.now()
    io.write(difile, exhaust_dci=False, link_data=False)
    io.close()
    after = datetime.now()
    delta = (after - before).total_seconds()

    logger.info(f'Sequence totals {sequence.dtype.itemsize * sequence.size} bytes')
    logger.info(f'Took {delta} seconds to write after read')

    if tmp_h5_file is not None:
        tmp_h5_file.close()

    logger.info("reading %s" % (h5path))
    h5size = os.path.getsize(h5path)
    logger.info("HDF5 size: %d", h5size)
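# Hypothetical invocation sketch: calling this version of prepare_data()
# programmatically with the arguments defined by the parser above. All paths are
# placeholders.
if __name__ == '__main__':
    prepare_data(['/path/to/ncbi/fasta',      # directory with NCBI sequence files
                  'bac120_metadata.tsv',      # metadata file from GTDB
                  'train.h5',                 # output HDF5
                  '--tree', 'gtdb_tree.nwk',  # Newick tree of representative taxa
                  '--rep',                    # keep representative genomes only
                  '--gzip'])                  # GZip the sequence table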
def build_deployment_pkg(argv=None):
    """
    Build a deployment package from a trained model checkpoint and its training inputs
    """
    import argparse
    import json
    import os
    import shutil
    import sys
    import tempfile
    import zipfile

    from hdmf.common import get_hdf5io

    from ..utils import get_logger

    desc = "Build a deployment package from a trained model checkpoint and its training inputs"
    epi = "The output directory is zipped after all files are copied into it"
    parser = argparse.ArgumentParser(description=desc, epilog=epi)
    parser.add_argument('input', type=str, help='the input file to run inference on')
    parser.add_argument('config', type=str, help='the config file used for training')
    parser.add_argument('nn_model', type=str, help='the NN model for doing predictions')
    parser.add_argument('conf_model', type=str, help='the checkpoint file to use for running inference')
    parser.add_argument('output_dir', type=str, help='the directory to copy to before zipping')
    parser.add_argument('-f', '--force', action='store_true', default=False,
                        help='overwrite output if it exists')

    if argv is None:
        argv = sys.argv[1:]
    if len(argv) == 0:
        parser.print_help()
        sys.exit(1)

    args = parser.parse_args(argv)

    logger = get_logger()

    if os.path.exists(args.output_dir):
        if args.force:
            logger.info(f"{args.output_dir} exists, removing tree")
            shutil.rmtree(args.output_dir)
        else:
            logger.error(f"{args.output_dir} exists, exiting")
            exit(1)
    os.mkdir(args.output_dir)

    tmpdir = args.output_dir
    logger.info(f'Using temporary directory {tmpdir}')

    logger.info(f'loading sample input from {args.input}')
    io = get_hdf5io(args.input, 'r')
    difile = io.read()
    tt = difile.taxa_table
    vocab = difile.seq_table.sequence.elements.data[:].tolist()
    _load = lambda x: x[:]
    for col in tt.columns:
        col.transform(_load)
    tt_df = tt.to_dataframe().set_index('taxon_id')
    io.close()

    path = lambda x: os.path.join(tmpdir, os.path.basename(x))
    manifest = {
        'taxa_table': os.path.join(tmpdir, "taxa_table.csv"),
        'nn_model': path(args.nn_model),
        'conf_model': path(args.conf_model),
        'training_config': path(args.config),
        'vocabulary': vocab,
    }

    logger.info(f"exporting taxa table CSV to {manifest['taxa_table']}")
    tt_df.to_csv(manifest['taxa_table'])

    logger.info(f"copying {args.nn_model} to {manifest['nn_model']}")
    shutil.copyfile(args.nn_model, manifest['nn_model'])

    logger.info(f"copying {args.conf_model} to {manifest['conf_model']}")
    shutil.copyfile(args.conf_model, manifest['conf_model'])

    logger.info(f"copying {args.config} to {manifest['training_config']}")
    shutil.copyfile(args.config, manifest['training_config'])

    with open(os.path.join(tmpdir, 'manifest.json'), 'w') as f:
        json.dump(manifest, f, indent=4)

    zip_path = args.output_dir + ".zip"
    zipf = zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED)
    for root, dirs, files in os.walk(tmpdir):
        for file in files:
            path = os.path.join(root, file)
            logger.info(f'adding {path} to {zip_path}')
            zipf.write(path)
    zipf.close()

    logger.info(f'removing {tmpdir}')
    shutil.rmtree(tmpdir)
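# Hypothetical invocation sketch: building a deployment package with the arguments
# defined by the parser above. All paths are placeholders.
if __name__ == '__main__':
    build_deployment_pkg(['train.h5',        # the input file used for training
                          'config.yml',      # training config
                          'model.ckpt',      # NN model checkpoint
                          'conf_model.pkl',  # confidence model checkpoint
                          'deploy_pkg',      # output directory (zipped afterwards)
                          '--force'])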