def nonrep_radius(self, unclustered_gids, rep_gids, ani_af_rep_vs_nonrep):
    """Calculate circumscription radius for unclustered, nontype genomes.

    Every genome starts with the default radius (self.ani_sp). The radius is
    then raised to the ANI of any representative whose ANI exceeds the current
    radius and whose AF satisfies self.af_sp.

    Args:
        unclustered_gids: genome IDs requiring a radius (iterated twice, so
            this must be a re-iterable collection).
        rep_gids: representative genome IDs to compare against.
        ani_af_rep_vs_nonrep: path to a pickle file with ANI/AF results
            indexed as ani_af[nonrep_gid][rep_gid].

    Returns:
        Dictionary mapping each unclustered genome ID to a GenomeRadius.
    """

    # set radius for genomes to default values
    radius = {}
    for gid in unclustered_gids:
        radius[gid] = GenomeRadius(ani=self.ani_sp,
                                   af=None,
                                   neighbour_gid=None)

    # NOTE: pickle must only be used with files produced by this pipeline
    with open(ani_af_rep_vs_nonrep, 'rb') as f:
        ani_af = pickle.load(f)

    # determine closest type ANI neighbour and restrict ANI radius as necessary
    for nonrep_gid in unclustered_gids:
        if nonrep_gid not in ani_af:
            continue

        for rep_gid in rep_gids:
            if rep_gid not in ani_af[nonrep_gid]:
                continue

            ani, af = FastANI.symmetric_ani(ani_af, nonrep_gid, rep_gid)
            if ani > radius[nonrep_gid].ani and af >= self.af_sp:
                radius[nonrep_gid] = GenomeRadius(ani=ani,
                                                  af=af,
                                                  neighbour_gid=rep_gid)

    # summary statistics are undefined (min/max raise) without any genomes
    if radius:
        anis = [d.ani for d in radius.values()]
        self.logger.info('ANI circumscription radius: min={:.2f}, mean={:.2f}, max={:.2f}'.format(
            min(anis),
            np_mean(anis),
            max(anis)))

    return radius
def _cluster(self, ani_af, non_reps, rep_radius): """Cluster non-representative to representative genomes using species specific ANI thresholds.""" clusters = {} for rep_id in rep_radius: clusters[rep_id] = [] num_clustered = 0 for idx, non_rid in enumerate(non_reps): if idx % 100 == 0: sys.stdout.write('==> Processed {:,} of {:,} genomes [no. clustered = {:,}].\r'.format( idx+1, len(non_reps), num_clustered)) sys.stdout.flush() if non_rid not in ani_af: continue closest_rid = None closest_ani = 0 closest_af = 0 for rid in rep_radius: if rid not in ani_af[non_rid]: continue ani, af = FastANI.symmetric_ani(ani_af, rid, non_rid) if af >= self.af_sp: if ani > closest_ani or (ani == closest_ani and af > closest_af): closest_rid = rid closest_ani = ani closest_af = af if closest_rid: if closest_ani > rep_radius[closest_rid].ani: num_clustered += 1 clusters[closest_rid].append(self.ClusteredGenome(gid=non_rid, ani=closest_ani, af=closest_af)) sys.stdout.write('==> Processed {:,} of {:,} genomes [no. clustered = {:,}].\r'.format( len(non_reps), len(non_reps), num_clustered)) sys.stdout.flush() sys.stdout.write('\n') num_unclustered = len(non_reps) - num_clustered self.logger.info('Assigned {:,} genomes to {:,} representatives; {:,} genomes remain unclustered.'.format( sum([len(clusters[rid]) for rid in clusters]), len(clusters), num_unclustered)) return clusters
def run(self, target_genus, gtdb_metadata_file, genomic_path_file):
    """Calculate pairwise ANI/AF between GTDB representatives of a genus.

    Results are written to '<genus>_rep_ani.tsv' in self.output_dir, with one
    row per ordered pair of representatives (self comparisons included).

    Args:
        target_genus: genus to process; compared against gtdb_taxa.genus and
            stripped of any 'g__' prefix when naming the output file.
        gtdb_metadata_file: GTDB metadata file used to build the genome set.
        genomic_path_file: file indicating path to genomic FASTA files.
    """

    # create GTDB genome sets
    self.logger.info('Creating GTDB genome set.')
    genomes = Genomes()
    genomes.load_from_metadata_file(gtdb_metadata_file)
    genomes.load_genomic_file_paths(genomic_path_file)
    self.logger.info(
        ' - genome set has {:,} species clusters spanning {:,} genomes.'.
        format(len(genomes.sp_clusters),
               genomes.sp_clusters.total_num_genomes()))

    # identify GTDB representatives from target genus
    self.logger.info('Identifying GTDB representatives from target genus.')
    target_gids = set()
    for gid in genomes:
        if (genomes[gid].is_gtdb_sp_rep()
                and genomes[gid].gtdb_taxa.genus == target_genus):
            target_gids.add(gid)
    self.logger.info(' - identified {:,} genomes.'.format(
        len(target_gids)))

    # calculate FastANI ANI/AF between target genomes
    self.logger.info('Calculating pairwise ANI between target genomes.')
    ani_af = self.fastani.pairwise(target_gids,
                                   genomes.genomic_files,
                                   check_cache=True)
    self.fastani.write_cache(silence=True)

    # write out results; the context manager guarantees the file is closed
    # even if a lookup or ANI calculation fails part way through
    genus_label = target_genus.replace('g__', '').lower()
    out_file = os.path.join(self.output_dir,
                            '{}_rep_ani.tsv'.format(genus_label))
    with open(out_file, 'w') as fout:
        fout.write(
            'Query ID\tQuery species\tTarget ID\tTarget species\tANI\tAF\n')
        for qid in target_gids:
            for rid in target_gids:
                ani, af = FastANI.symmetric_ani(ani_af, qid, rid)

                fout.write('{}\t{}\t{}\t{}\t{:.3f}\t{:.3f}\n'.format(
                    qid, genomes[qid].gtdb_taxa.species, rid,
                    genomes[rid].gtdb_taxa.species, ani, af))
def rep_genome_stats(self, clusters, genome_files):
    """Calculate statistics relative to representative genome.

    Args:
        clusters: dictionary mapping representative genome ID to the IDs of
            the genomes clustered with it.
        genome_files: genomic file information passed to self.fastani.pairs.

    Returns:
        Dictionary mapping each representative ID to a RepStats with the
        min, mean, std, and median ANI of cluster members to the
        representative. All fields are -1 for clusters without members.
    """

    self.logger.info('Calculating statistics to cluster representatives:')

    stats = {}
    for idx, (rid, cids) in enumerate(clusters.items()):
        if len(cids) == 0:
            stats[rid] = self.RepStats(min_ani=-1,
                                       mean_ani=-1,
                                       std_ani=-1,
                                       median_ani=-1)
        else:
            # calculate ANI to representative genome
            gid_pairs = []
            for cid in cids:
                gid_pairs.append((cid, rid))
                gid_pairs.append((rid, cid))

            # developer toggle: the alternative reuses the in-memory ANI cache
            if True:  # *** DEBUGGING
                ani_af = self.fastani.pairs(gid_pairs,
                                            genome_files,
                                            report_progress=False)
            else:
                ani_af = self.fastani.ani_cache

            # calculate statistics
            anis = [FastANI.symmetric_ani(ani_af, cid, rid)[0]
                    for cid in cids]

            stats[rid] = self.RepStats(min_ani=min(anis),
                                       mean_ani=np_mean(anis),
                                       std_ani=np_std(anis),
                                       median_ani=np_median(anis))

        # pad AFTER formatting so every status line has the same width and
        # fully overwrites the previous one after the carriage return
        statusStr = ('-> Processing %d of %d (%.2f%%) clusters.' % (
            idx+1,
            len(clusters),
            float((idx+1)*100)/len(clusters))).ljust(86)
        sys.stdout.write('%s\r' % statusStr)
        sys.stdout.flush()

    sys.stdout.write('\n')

    return stats
def find_multiple_reps(self, clusters, cluster_radius):
    """Determine number of non-rep genomes within ANI radius of multiple rep genomes.

    This method assumes the ANI cache contains all relevant ANI calculations
    between representative and non-representative genomes. This is the case
    once the de novo clustering has been performed.

    Args:
        clusters: dictionary mapping representative genome ID to the genomes
            clustered with it.
        cluster_radius: dictionary mapping representative genome ID to an
            object whose 'ani' attribute is the ANI radius.

    Returns:
        defaultdict mapping genome ID to the set of (representative ID, ANI)
        tuples for every representative satisfying the AF criterion
        (self.af_sp) and whose ANI radius contains the genome.
    """

    self.logger.info(
        'Determine number of non-rep genomes within ANI radius of multiple rep genomes.')

    # get clustered genomes IDs
    clustered_gids = [gid for rid in clusters for gid in clusters[rid]]

    self.logger.info('Considering {:,} representatives and {:,} non-representative genomes.'.format(
        len(clusters),
        len(clustered_gids)))

    nonrep_rep_count = defaultdict(set)
    for idx, gid in enumerate(clustered_gids):
        cur_ani_cache = self.fastani.ani_cache[gid]
        for rid in clusters:
            if rid not in cur_ani_cache:
                continue

            ani, af = FastANI.symmetric_ani(
                self.fastani.ani_cache, gid, rid)
            if af >= self.af_sp and ani >= cluster_radius[rid].ani:
                nonrep_rep_count[gid].add((rid, ani))

        if (idx+1) % 100 == 0 or (idx+1) == len(clustered_gids):
            # pad AFTER formatting so every status line has the same width
            statusStr = ('-> Processing %d of %d (%.2f%%) clusters genomes.' % (
                idx+1,
                len(clustered_gids),
                float((idx+1)*100)/len(clustered_gids))).ljust(86)
            sys.stdout.write('%s\r' % statusStr)
            sys.stdout.flush()

    sys.stdout.write('\n')

    return nonrep_rep_count
def top_hits(self, species, rid, ani_af, genomes):
    """Report top 5 hits to species.

    Args:
        species: name of the species being reported (used for logging only).
        rid: genome ID whose hits are reported.
        ani_af: ANI/AF results; the keys of ani_af[rid] are the candidate
            hits.
        genomes: mapping from genome ID to an object with a 'gtdb_species'
            attribute.
    """

    num_hits = 5

    results = {}
    for qid in ani_af[rid]:
        ani, af = FastANI.symmetric_ani(ani_af, rid, qid)
        results[qid] = (ani, af)

    self.logger.info(f'Closest 5 species to {species} ({rid}):')

    # rank by ANI and then AF, highest first; slicing limits the report to
    # exactly 5 hits (the previous counter-based break reported 6)
    ranked = sorted(results.items(), key=lambda x: x[1], reverse=True)
    for qid, (ani, af) in ranked[:num_hits]:
        q_species = genomes[qid].gtdb_species
        self.logger.info(
            f'{q_species} ({qid}): ANI={ani:.1f}%, AF={af:.2f}')
def merge_ani_radius(self, species, rid, merged_sp_cluster, genomic_files):
    """Determine ANI radius if species were merged.

    The radius is the lowest ANI between the representative and any genome
    in the merged cluster; the AF reported is that of the same genome pair.

    Args:
        species: name of the species whose representative is evaluated.
        rid: representative genome ID.
        merged_sp_cluster: genome IDs in the merged species cluster.
        genomic_files: genomic file information passed to self.fastani.pairs.
    """

    self.logger.info(
        f'Calculating ANI from {species} to all genomes in merged species cluster.'
    )

    gid_pairs = []
    for gid in merged_sp_cluster:
        gid_pairs.append((rid, gid))
        gid_pairs.append((gid, rid))
    merged_ani_af1 = self.fastani.pairs(gid_pairs, genomic_files)

    # af_radius doubles as the "nothing seen yet" marker so that it is always
    # bound, even when no genome has an ANI below the initial radius of 100
    ani_radius = 100
    af_radius = None
    for gid in merged_sp_cluster:
        ani, af = FastANI.symmetric_ani(merged_ani_af1, rid, gid)
        if af_radius is None or ani < ani_radius:
            ani_radius = ani
            af_radius = af

    if af_radius is None:
        self.logger.warning(
            f'Merged cluster with {species} rep is empty; unable to determine ANI radius.'
        )
        return

    self.logger.info(
        f'Merged cluster with {species} rep: ANI radius={ani_radius:.1f}%, AF={af_radius:.2f}'
    )
def calculate_type_strain_ani(self, ncbi_sp, type_gids, cur_genomes,
                              use_pickled_results):
    """Calculate pairwise ANI between type strain genomes.

    Args:
        ncbi_sp: NCBI species name; the first 3 characters are dropped and
            the remainder is used to name the pickle file.
        type_gids: type strain genome IDs to compare.
        cur_genomes: object exposing 'genomic_files'; only used when results
            are calculated.
        use_pickled_results: if true, load previously pickled ANI/AF results
            from self.ani_pickle_dir instead of calculating them.

    Returns:
        Tuple (all_similar, anis, afs, gid_anis, gid_afs) where all_similar
        is False if any pair has ANI < 99 or AF < 0.65, anis/afs list the
        values for each unordered pair, and gid_anis/gid_afs index the same
        values symmetrically by genome ID.
    """

    ncbi_sp_str = ncbi_sp[3:].lower().replace(' ', '_')
    pickle_file = os.path.join(self.ani_pickle_dir, f'{ncbi_sp_str}.pkl')

    if not use_pickled_results:  # ***
        ani_af = self.fastani.pairwise(type_gids,
                                       cur_genomes.genomic_files)
        with open(pickle_file, 'wb') as f:
            pickle.dump(ani_af, f)
    else:
        # NOTE: pickle must only be used with files produced by this pipeline
        with open(pickle_file, 'rb') as f:
            ani_af = pickle.load(f)

    anis = []
    afs = []
    gid_anis = defaultdict(dict)
    gid_afs = defaultdict(dict)
    all_similar = True
    for gid1, gid2 in combinations(type_gids, 2):
        ani, af = FastANI.symmetric_ani(ani_af, gid1, gid2)
        if ani < 99 or af < 0.65:
            all_similar = False

        anis.append(ani)
        afs.append(af)

        gid_anis[gid1][gid2] = ani
        gid_anis[gid2][gid1] = ani

        gid_afs[gid1][gid2] = af
        gid_afs[gid2][gid1] = af

    return all_similar, anis, afs, gid_anis, gid_afs
def run(self, gtdb_metadata_file, genome_path_file, species1, species2):
    """Produce information relevant to merging two sister species.

    Reports the ANI/AF between the representatives of the two species, the
    closest species to each representative within the genus, and the ANI
    radius each representative would have if the two clusters were merged.
    Exits if either species has no representative or the genera differ.
    """

    # load the genome set along with its species clusters
    self.logger.info('Reading GTDB species clusters.')
    genomes = Genomes()
    genomes.load_from_metadata_file(gtdb_metadata_file)
    genomes.load_genomic_file_paths(genome_path_file)
    num_sp_clusters = len(genomes.sp_clusters)
    num_genomes = genomes.sp_clusters.total_num_genomes()
    self.logger.info(
        ' - identified {:,} species clusters spanning {:,} genomes.'.
        format(num_sp_clusters, num_genomes))

    # locate the representative of each species of interest
    gid1 = gid2 = None
    for rep_gid, sp_name in genomes.sp_clusters.species():
        if sp_name == species1:
            gid1 = rep_gid
        elif sp_name == species2:
            gid2 = rep_gid

    targets = ((gid1, species1), (gid2, species2))
    for rep_gid, sp_name in targets:
        if rep_gid is None:
            self.logger.error(
                f'Unable to find representative genome for {sp_name}.')
            sys.exit(-1)

    for rep_gid, sp_name in targets:
        self.logger.info(' - identified {:,} genomes in {}.'.format(
            len(genomes.sp_clusters[rep_gid]), sp_name))

    # both species must belong to the same genus
    genus1 = genomes[gid1].gtdb_genus
    genus2 = genomes[gid2].gtdb_genus
    if genus1 != genus2:
        self.logger.error(
            f'Genomes must be from same genus: {genus1} {genus2}')
        sys.exit(-1)

    self.logger.info(f'Identifying {genus1} species representatives.')
    reps_in_genera = {rep_gid
                      for rep_gid in genomes.sp_clusters
                      if genomes[rep_gid].gtdb_genus == genus1}
    self.logger.info(
        f' - identified {len(reps_in_genera):,} representatives.')

    def _pairs_with(focal_gid):
        """Both orderings of the focal genome with every other representative."""
        pairs = []
        for other_gid in reps_in_genera:
            if other_gid == focal_gid:
                continue
            pairs.append((focal_gid, other_gid))
            pairs.append((other_gid, focal_gid))
        return pairs

    # calculate ANI between each species and the genus representatives
    self.logger.info(f'Calculating ANI to {species1}.')
    ani_af1 = self.fastani.pairs(_pairs_with(gid1), genomes.genomic_files)

    self.logger.info(f'Calculating ANI to {species2}.')
    ani_af2 = self.fastani.pairs(_pairs_with(gid2), genomes.genomic_files)

    # report directional and symmetric results between the two species
    ani12, af12 = ani_af1[gid1][gid2]
    ani21, af21 = ani_af2[gid2][gid1]
    ani, af = FastANI.symmetric_ani(ani_af1, gid1, gid2)

    self.logger.info(
        f'{species1} ({gid1}) -> {species2} ({gid2}): ANI={ani12:.1f}%, AF={af12:.2f}'
    )
    self.logger.info(
        f'{species2} ({gid2}) -> {species1} ({gid1}): ANI={ani21:.1f}%, AF={af21:.2f}'
    )
    self.logger.info(f'Max. ANI={ani:.1f}%, Max. AF={af:.2f}')

    # report closest species to each representative
    self.top_hits(species1, gid1, ani_af1, genomes)
    self.top_hits(species2, gid2, ani_af2, genomes)

    # ANI radius of each representative over the merged species cluster
    merged_sp_cluster = genomes.sp_clusters[gid1].union(
        genomes.sp_clusters[gid2])

    for rep_gid, sp_name in targets:
        self.merge_ani_radius(sp_name, rep_gid, merged_sp_cluster,
                              genomes.genomic_files)
def pairwise_stats(self, clusters, genome_files):
    """Calculate statistics for all pairwise comparisons in a species cluster.

    Clusters with more than self.max_genomes_for_stats members are randomly
    subsampled to that many members before the comparisons are made.

    Args:
        clusters: dictionary mapping representative genome ID to the set of
            genome IDs clustered with it.
        genome_files: genomic file information passed to self.fastani.pairs.

    Returns:
        Dictionary mapping each representative ID to a PairwiseStats. All
        fields are -1 when a cluster provides no pairwise comparisons.
    """

    self.logger.info(
        f'Restricting pairwise comparisons to {self.max_genomes_for_stats:,} randomly selected genomes.')

    self.logger.info(
        'Calculating statistics for all pairwise comparisons in a species cluster:')

    def _no_stats():
        """Placeholder statistics for clusters without pairwise comparisons."""
        return self.PairwiseStats(min_ani=-1,
                                  mean_ani=-1,
                                  std_ani=-1,
                                  median_ani=-1,
                                  ani_to_medoid=-1,
                                  mean_ani_to_medoid=-1,
                                  mean_ani_to_rep=-1,
                                  ani_below_95=-1)

    stats = {}
    for idx, (rid, cids) in enumerate(clusters.items()):
        # format first ('{:.2f}' = 2 decimal places) and pad afterwards so
        # every status line has the same width
        statusStr = '-> Processing {:,} of {:,} ({:.2f}%) clusters (size = {:,}).'.format(
            idx+1,
            len(clusters),
            float((idx+1)*100)/len(clusters),
            len(cids)).ljust(86)
        sys.stdout.write('{}\r'.format(statusStr))
        sys.stdout.flush()

        if len(cids) == 0:
            stats[rid] = _no_stats()
            continue

        if len(cids) > self.max_genomes_for_stats:
            # random.sample requires a sequence (sets are rejected as of
            # Python 3.11); sorting makes the draw reproducible for a seed
            cids = set(random.sample(sorted(cids),
                                     self.max_genomes_for_stats))

        gids = list(cids.union([rid]))
        if len(gids) < 2:
            # cluster only contains the representative itself
            stats[rid] = _no_stats()
            continue

        # calculate ANI between all pairs of genomes
        gid_pairs = []
        for gid1, gid2 in combinations(gids, 2):
            gid_pairs.append((gid1, gid2))
            gid_pairs.append((gid2, gid1))

        # developer toggle: the alternative reuses the in-memory ANI cache
        if True:  # ***DEBUGGING
            ani_af = self.fastani.pairs(gid_pairs,
                                        genome_files,
                                        report_progress=False)
        else:
            ani_af = self.fastani.ani_cache

        # calculate medoid point
        if len(gids) > 2:
            dist_mat = np_zeros((len(gids), len(gids)))
            for i, gid1 in enumerate(gids):
                for j, gid2 in enumerate(gids):
                    if i < j:
                        ani, _af = FastANI.symmetric_ani(
                            ani_af, gid1, gid2)
                        dist_mat[i, j] = 100 - ani
                        dist_mat[j, i] = 100 - ani

            medoid_idx = np_argmin(dist_mat.sum(axis=0))
            medoid_gid = gids[medoid_idx]
        else:
            # with only 2 genomes in a cluster, the representative is the
            # natural medoid at least for reporting statistics for the
            # individual species cluster
            medoid_gid = rid

        mean_ani_to_medoid = np_mean([FastANI.symmetric_ani(ani_af, gid, medoid_gid)[0]
                                      for gid in gids if gid != medoid_gid])

        mean_ani_to_rep = np_mean([FastANI.symmetric_ani(ani_af, gid, rid)[0]
                                   for gid in gids if gid != rid])

        # sanity check: the medoid minimises summed distance, so its mean ANI
        # should never be below that of the representative
        if mean_ani_to_medoid < mean_ani_to_rep:
            self.logger.error('mean_ani_to_medoid < mean_ani_to_rep')
            sys.exit(-1)

        # calculate statistics
        anis = []
        for gid1, gid2 in combinations(gids, 2):
            ani, _af = FastANI.symmetric_ani(ani_af, gid1, gid2)
            anis.append(ani)

        stats[rid] = self.PairwiseStats(
            min_ani=min(anis),
            mean_ani=np_mean(anis),
            std_ani=np_std(anis),
            median_ani=np_median(anis),
            ani_to_medoid=FastANI.symmetric_ani(
                ani_af, rid, medoid_gid)[0],
            mean_ani_to_medoid=mean_ani_to_medoid,
            mean_ani_to_rep=mean_ani_to_rep,
            ani_below_95=sum([1 for ani in anis if ani < 95]))

    sys.stdout.write('\n')

    return stats
def intragenus_pairwise_ani(self, clusters, species, genome_files, gtdb_taxonomy):
    """Determine pairwise intra-genus ANI between representative genomes.

    Writes 'intra_genus_pairwise_ani.tsv' (all intra-genus pairs) and
    'closest_intragenus_rep.tsv' (closest intra-genus neighbour of each
    representative) to self.output_dir, along with a pickle of the ANI/AF
    results.

    Args:
        clusters: mapping whose keys are representative genome IDs.
        species: dictionary mapping representative ID to species name.
        genome_files: genomic file information passed to self.fastani.pairs.
        gtdb_taxonomy: dictionary mapping genome ID to taxonomy list; index 5
            is checked against the genus derived from the species name.
    """

    self.logger.info(
        'Calculating pairwise intra-genus ANI values between GTDB representatives.')

    # get genus for each representative
    genus = {}
    for rid, sp in species.items():
        genus[rid] = sp.split()[0].replace('s__', '')
        assert genus[rid] == gtdb_taxonomy[rid][5].replace('g__', '')

    # get intra-genus pairs
    # NOTE(review): the nested loops visit both (qid, rid) and (rid, qid), so
    # every ordered pair is appended twice - confirm this is intended
    self.logger.info('Determining intra-genus genome pairs.')
    ani_pairs = []
    for qid in clusters:
        for rid in clusters:
            if qid == rid:
                continue

            if genus[qid] != genus[rid]:
                continue

            ani_pairs.append((qid, rid))
            ani_pairs.append((rid, qid))

    self.logger.info(
        'Identified {:,} intra-genus genome pairs.'.format(len(ani_pairs)))

    # calculate ANI between pairs
    self.logger.info(
        'Calculating ANI between {:,} genome pairs:'.format(len(ani_pairs)))
    ani_af_file = os.path.join(self.output_dir, 'type_genomes_ani_af.pkl')

    # developer toggle: the alternative reloads previously pickled results
    if True:  # ***DEBUGGING
        ani_af = self.fastani.pairs(ani_pairs, genome_files)
        with open(ani_af_file, 'wb') as f:
            pickle.dump(ani_af, f)
    else:
        with open(ani_af_file, 'rb') as f:
            ani_af = pickle.load(f)

    # find closest intra-genus pair for each rep
    closest_intragenus_rep = {}
    pairwise_file = os.path.join(self.output_dir,
                                 'intra_genus_pairwise_ani.tsv')
    with open(pairwise_file, 'w') as fout:
        fout.write(
            'Genus\tSpecies 1\tGenome ID 1\tSpecies 2\tGenome ID2\tANI\tAF\n')

        for qid in clusters:
            genusA = genus[qid]

            closest_ani = 0
            closest_af = 0
            closest_gid = None
            for rid in clusters:
                if qid == rid:
                    continue

                if genusA != genus[rid]:
                    continue

                ani, af = ('n/a', 'n/a')
                if qid in ani_af and rid in ani_af[qid]:
                    ani, af = FastANI.symmetric_ani(ani_af, qid, rid)

                    # only numeric results are candidates for the closest
                    # hit; the 'n/a' placeholder cannot be compared
                    if ani > closest_ani:
                        closest_ani = ani
                        closest_af = af
                        closest_gid = rid

                fout.write('{}\t{}\t{}\t{}\t{}\t{}\t{}\n'.format(
                    genusA,
                    species[qid],
                    qid,
                    species[rid],
                    rid,
                    ani,
                    af))

            if closest_gid:
                closest_intragenus_rep[qid] = (
                    closest_gid, closest_ani, closest_af)

    # write out closest intra-genus species to each representative
    closest_file = os.path.join(self.output_dir,
                                'closest_intragenus_rep.tsv')
    with open(closest_file, 'w') as fout:
        fout.write(
            'Genome ID\tSpecies\tIntra-genus neighbour\tIntra-genus species\tANI\tAF\n')
        for qid in closest_intragenus_rep:
            rid, ani, af = closest_intragenus_rep[qid]

            fout.write('%s\t%s\t%s\t%s\t%.2f\t%.3f\n' % (
                qid,
                species[qid],
                rid,
                species[rid],
                ani,
                af))
def dereplicate_species(self, species, rid, cids, genomes, mash_out_dir):
    """Dereplicate genomes within a GTDB species.

    Genomes are ordered by priority with the species representative first.
    Species with more than self.max_genomes_per_sp genomes are first reduced
    using Mash, after which genomes are greedily dereplicated with FastANI
    using the self.derep_ani and self.derep_af criteria.

    Args:
        species: species name; the first 3 characters are dropped when
            naming Mash output files.
        rid: representative genome ID of the species.
        cids: set of genome IDs in the species cluster.
        genomes: genome set exposing 'genomic_files'.
        mash_out_dir: directory for Mash output files.

    Returns:
        Dictionary mapping each selected representative to the list of
        genomes assigned to it; the representative itself is listed first.
    """

    # greedily dereplicate genomes based on genome priority
    sorted_gids = self.order_genomes_by_priority(cids.difference([rid]),
                                                 genomes)
    sorted_gids = [rid] + sorted_gids

    # calculate Mash ANI between genomes
    mash_ani = []
    if len(sorted_gids) > 1:
        # calculate MASH distances between genomes
        out_prefix = os.path.join(mash_out_dir,
                                  species[3:].lower().replace(' ', '_'))
        mash_ani = self.mash_sp_ani(sorted_gids, genomes, out_prefix)

    # perform initial dereplication using Mash for species with excessive
    # numbers of genomes
    if len(sorted_gids) > self.max_genomes_per_sp:
        self.logger.info(
            ' - limiting species to <={:,} genomes based on priority and Mash dereplication.'
            .format(self.max_genomes_per_sp))

        # progressively more lenient thresholds; the trailing None marks
        # that even the most lenient threshold retained too many genomes
        ani_thresholds = [
            99.75, 99.5, 99.25, 99.0, 98.75, 98.5, 98.25, 98.0, 97.75,
            97.5, 97.0, 96.5, 96.0, 95.0, None
        ]

        prev_mash_rep_gids = None
        for ani_threshold in ani_thresholds:
            if ani_threshold is None:
                # fall back to the first genomes of the last dereplicated set
                self.logger.warning(
                    ' - selected {:,} highest priority genomes from final Mash dereplication.'
                    .format(self.max_genomes_per_sp))
                sorted_gids = mash_rep_gids[0:self.max_genomes_per_sp]
                break

            mash_rep_gids = self.mash_sp_dereplicate(
                mash_ani, sorted_gids, ani_threshold)

            self.logger.info(
                ' - dereplicated {} from {:,} to {:,} genomes at {:.2f}% ANI using Mash.'
                .format(species, len(cids), len(mash_rep_gids),
                        ani_threshold))

            if len(mash_rep_gids) <= self.max_genomes_per_sp:
                if not prev_mash_rep_gids:
                    # corner case where dereplication is occurring at 99.75%
                    prev_mash_rep_gids = sorted_gids

                # select maximum allowed number of genomes by taking all genomes in the
                # current Mash dereplicated set and then the highest priority genomes in the
                # previous Mash dereplicated set which have not been selected
                cur_sel_gids = set(mash_rep_gids)
                prev_sel_gids = set(prev_mash_rep_gids)
                num_prev_to_sel = self.max_genomes_per_sp - len(
                    cur_sel_gids)
                num_prev_selected = 0
                sel_sorted_gids = []
                for gid in sorted_gids:
                    if gid in cur_sel_gids:
                        sel_sorted_gids.append(gid)
                    elif (gid in prev_sel_gids
                          and num_prev_selected < num_prev_to_sel):
                        num_prev_selected += 1
                        sel_sorted_gids.append(gid)

                    if len(sel_sorted_gids) == self.max_genomes_per_sp:
                        break

                assert len(cur_sel_gids - set(sel_sorted_gids)) == 0
                assert num_prev_to_sel == num_prev_selected
                assert len(sel_sorted_gids) == self.max_genomes_per_sp

                sorted_gids = sel_sorted_gids
                self.logger.info(
                    ' - selected {:,} highest priority genomes from Mash dereplication at an ANI = {:.2f}%.'
                    .format(len(sorted_gids), ani_threshold))
                break

            prev_mash_rep_gids = mash_rep_gids

    # calculate FastANI ANI/AF between genomes passing Mash filtering
    ani_pairs = set()
    for gid1, gid2 in permutations(sorted_gids, 2):
        if gid1 in mash_ani and gid2 in mash_ani[gid1]:
            if mash_ani[gid1][gid2] >= self.min_mash_intra_sp_ani:
                ani_pairs.add((gid1, gid2))
                ani_pairs.add((gid2, gid1))

    self.logger.info(
        ' - calculating FastANI between {:,} pairs with Mash ANI >= {:.1f}%.'
        .format(len(ani_pairs), self.min_mash_intra_sp_ani))
    ani_af = self.fastani.pairs(ani_pairs,
                                genomes.genomic_files,
                                report_progress=False,
                                check_cache=True)
    self.fastani.write_cache(silence=True)

    # perform greedy dereplication
    sp_reps = []
    for gid in sorted_gids:
        # determine if genome clusters with existing representative
        clustered = False
        for rep_gid in sp_reps:
            ani, af = FastANI.symmetric_ani(ani_af, gid, rep_gid)
            if ani >= self.derep_ani and af >= self.derep_af:
                clustered = True
                break

        if not clustered:
            sp_reps.append(gid)

    self.logger.info(
        ' - dereplicated {} from {:,} to {:,} genomes.'.format(
            species, len(sorted_gids), len(sp_reps)))

    # assign clustered genomes to most similar representative
    subsp_clusters = {}
    for rep_gid in sp_reps:
        subsp_clusters[rep_gid] = [rep_gid]

    # process genomes in priority order so cluster membership lists are
    # deterministic rather than dependent on set iteration order
    sp_rep_set = set(sp_reps)
    non_rep_gids = [gid for gid in sorted_gids if gid not in sp_rep_set]
    for gid in non_rep_gids:
        closest_rid = None
        max_ani = 0
        max_af = 0
        for rep_gid in sp_reps:
            ani, af = FastANI.symmetric_ani(ani_af, gid, rep_gid)
            if af < self.derep_af:
                continue

            # highest ANI wins; ties go to the hit with the higher (or
            # equal) AF
            if ani > max_ani or (ani == max_ani and af >= max_af):
                max_ani = ani
                max_af = af
                closest_rid = rep_gid

        assert closest_rid is not None
        subsp_clusters[closest_rid].append(gid)

    return subsp_clusters
def write_synonym_table(self, type_strain_synonyms, consensus_synonyms,
                        ani_af, sp_priority_ledger, genus_priority_ledger,
                        lpsn_gss_file):
    """Create table indicating species names that should be considered synonyms.

    Writes 'synonyms.tsv' to self.output_dir with one row per
    (representative, synonym) pair and logs the number of synonyms with
    priority problems.

    Args:
        type_strain_synonyms: dictionary mapping representative genome ID to
            synonym genome IDs; reported as TYPE_STRAIN_SYNONYM.
        consensus_synonyms: dictionary mapping representative genome ID to
            synonym genome IDs; reported as MAJORITY_VOTE_SYNONYM.
        ani_af: ANI/AF results between representatives and synonyms.
        sp_priority_ledger: passed to SpeciesPriorityManager.
        genus_priority_ledger: passed to SpeciesPriorityManager.
        lpsn_gss_file: passed to SpeciesPriorityManager.
    """

    sp_priority_mngr = SpeciesPriorityManager(sp_priority_ledger,
                                              genus_priority_ledger,
                                              lpsn_gss_file,
                                              self.output_dir)

    def _genome_fields(gid, priority_year):
        """Format the 8 tab-prefixed columns describing a genome."""
        genome = self.cur_genomes[gid]
        return '\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}'.format(
            genome.ncbi_taxa.species,
            gid,
            ','.join(sorted(genome.strain_ids())),
            ','.join(sorted(genome.gtdb_type_sources())).upper().replace(
                'STRAININFO', 'StrainInfo'),
            priority_year,
            genome.is_gtdb_type_species(),
            genome.is_gtdb_type_strain(),
            genome.ncbi_type_material)

    incorrect_priority = 0
    failed_type_strain_priority = 0

    # the context manager ensures the table is flushed and closed; the file
    # was previously never closed
    out_file = os.path.join(self.output_dir, 'synonyms.tsv')
    with open(out_file, 'w') as fout:
        fout.write(
            'Synonym type\tNCBI species\tGTDB representative\tStrain IDs\tType sources\tPriority year'
        )
        fout.write('\tGTDB type species\tGTDB type strain\tNCBI assembly type')
        fout.write(
            '\tNCBI synonym\tHighest-quality synonym genome\tSynonym strain IDs\tSynonym type sources\tSynonym priority year'
        )
        fout.write(
            '\tSynonym GTDB type species\tSynonym GTDB type strain\tSynonym NCBI assembly type'
        )
        fout.write('\tANI\tAF\tWarnings\n')

        for synonyms, synonym_type in [
                (type_strain_synonyms, 'TYPE_STRAIN_SYNONYM'),
                (consensus_synonyms, 'MAJORITY_VOTE_SYNONYM')
        ]:
            for rid, synonym_ids in synonyms.items():
                for gid in synonym_ids:
                    ani, af = FastANI.symmetric_ani(ani_af, rid, gid)

                    fout.write(synonym_type)

                    # representative: priority year is written as reported
                    fout.write(_genome_fields(
                        rid,
                        sp_priority_mngr.species_priority_year(
                            self.cur_genomes, rid)))

                    # synonym: a missing priority year is reported as 'n/a'
                    synonym_priority_year = sp_priority_mngr.species_priority_year(
                        self.cur_genomes, gid)
                    if synonym_priority_year == Genome.NO_PRIORITY_YEAR:
                        synonym_priority_year = 'n/a'
                    fout.write(_genome_fields(gid, synonym_priority_year))

                    fout.write('\t{:.3f}\t{:.4f}'.format(ani, af))

                    if (self.cur_genomes[rid].is_effective_type_strain()
                            and self.cur_genomes[gid].is_effective_type_strain()):
                        priority_gid, note = sp_priority_mngr.species_priority(
                            self.cur_genomes, rid, gid)
                        if priority_gid != rid:
                            incorrect_priority += 1
                            fout.write('\tIncorrect priority: {}'.format(note))
                    elif (not self.cur_genomes[rid].is_gtdb_type_strain()
                          and self.cur_genomes[gid].is_gtdb_type_strain()):
                        failed_type_strain_priority += 1
                        fout.write(
                            '\tFailed to prioritize type strain of species')

                    fout.write('\n')

    if incorrect_priority:
        self.logger.warning(
            f' - identified {incorrect_priority:,} synonyms with incorrect priority.'
        )

    if failed_type_strain_priority:
        self.logger.warning(
            f' - identified {failed_type_strain_priority:,} synonyms that failed to priotize the type strain of the species.'
        )
def cluster_genomes(self, cur_genomes, de_novo_rep_gids, named_rep_gids,
                    final_cluster_radius):
    """Cluster new representatives to representatives of named GTDB species clusters.

    Mash is used to identify candidate (genome, representative) pairs, ANI/AF
    is calculated for these pairs, and each genome is assigned to the closest
    representative whose ANI radius contains it and which meets the AF
    criterion. Results are pickled to self.output_dir.

    Args:
        cur_genomes: genome set exposing 'genomes' and 'genomic_files'.
        de_novo_rep_gids: set of de novo representative genome IDs.
        named_rep_gids: set of representative genome IDs of named species.
        final_cluster_radius: dictionary mapping representative ID to an
            object whose 'ani' attribute is the ANI radius.

    Returns:
        Tuple (clusters, ani_af) where clusters maps each representative to
        the list of ClusteredGenome assigned to it.
    """

    all_reps = de_novo_rep_gids.union(named_rep_gids)
    nonrep_gids = set(cur_genomes.genomes.keys()) - all_reps
    self.logger.info('Clustering {:,} genomes to {:,} named and de novo representatives.'.format(
        len(nonrep_gids), len(all_reps)))

    clusters_file = os.path.join(self.output_dir, 'clusters.pkl')
    ani_af_file = os.path.join(self.output_dir,
                               'ani_af_rep_vs_nonrep.de_novo.pkl')

    # developer toggle: the alternative reloads previously pickled results
    if True:  # ***
        # calculate MASH distance between non-representatives and representatives genomes
        mash = Mash(self.cpus)

        mash_rep_sketch_file = os.path.join(
            self.output_dir, 'gtdb_rep_genomes.msh')
        rep_genome_list_file = os.path.join(
            self.output_dir, 'gtdb_rep_genomes.lst')
        mash.sketch(all_reps, cur_genomes.genomic_files,
                    rep_genome_list_file, mash_rep_sketch_file)

        mash_none_rep_sketch_file = os.path.join(
            self.output_dir, 'gtdb_nonrep_genomes.msh')
        non_rep_file = os.path.join(
            self.output_dir, 'gtdb_nonrep_genomes.lst')
        mash.sketch(nonrep_gids, cur_genomes.genomic_files,
                    non_rep_file, mash_none_rep_sketch_file)

        # get Mash distances
        mash_dist_file = os.path.join(
            self.output_dir, 'gtdb_rep_vs_nonrep_genomes.dst')
        mash.dist(float(100 - self.min_mash_ani)/100,
                  mash_rep_sketch_file,
                  mash_none_rep_sketch_file,
                  mash_dist_file)

        # read Mash distances
        mash_ani = mash.read_ani(mash_dist_file)

        # calculate ANI between non-representatives and representatives genomes
        clusters = {}
        for gid in all_reps:
            clusters[gid] = []

        mash_ani_pairs = []
        for qid in mash_ani:
            assert qid in nonrep_gids

            for rid in mash_ani[qid]:
                assert rid in all_reps

                if (mash_ani[qid][rid] >= self.min_mash_ani
                        and qid != rid):
                    mash_ani_pairs.append((qid, rid))
                    mash_ani_pairs.append((rid, qid))

        self.logger.info('Calculating ANI between {:,} species clusters and {:,} unclustered genomes ({:,} pairs):'.format(
            len(clusters),
            len(nonrep_gids),
            len(mash_ani_pairs)))
        ani_af = self.fastani.pairs(
            mash_ani_pairs, cur_genomes.genomic_files)

        # assign genomes to closest representatives
        # that is within the representatives ANI radius
        self.logger.info('Assigning genomes to closest representative.')

        # the isclose_abs_tol factor is used in order to avoid missing genomes due to
        # small rounding errors when comparing floating point values. In particular,
        # the ANI radius for named GTDB representatives is read from file so small
        # rounding errors could occur. This has only been observed once, but seems
        # like good practice to use isclose here.
        isclose_abs_tol = 1e-4

        for idx, cur_gid in enumerate(nonrep_gids):
            closest_rep_gid = None
            closest_rep_ani = 0
            closest_rep_af = 0
            for rep_gid in clusters:
                ani, af = FastANI.symmetric_ani(ani_af, cur_gid, rep_gid)

                if (ani >= final_cluster_radius[rep_gid].ani - isclose_abs_tol
                        and af >= self.af_sp - isclose_abs_tol):
                    if ani > closest_rep_ani or (ani == closest_rep_ani and af > closest_rep_af):
                        closest_rep_gid = rep_gid
                        closest_rep_ani = ani
                        closest_rep_af = af

            if closest_rep_gid:
                clusters[closest_rep_gid].append(ClusteredGenome(gid=cur_gid,
                                                                 ani=closest_rep_ani,
                                                                 af=closest_rep_af))
            else:
                # no representative was identified on this path, so there is
                # no closest hit to report
                self.logger.warning(
                    'Failed to assign genome {} to representative.'.format(cur_gid))
                self.logger.warning(
                    ' - no representative with an AF >{:.2f} identified.'.format(self.af_sp))

            statusStr = '-> Assigned {:,} of {:,} ({:.2f}%) genomes.'.format(
                idx+1,
                len(nonrep_gids),
                float(idx+1)*100/len(nonrep_gids)).ljust(86)
            sys.stdout.write('{}\r'.format(statusStr))
            sys.stdout.flush()
        sys.stdout.write('\n')

        # context managers ensure the pickles are flushed and closed
        with open(clusters_file, 'wb') as f:
            pickle.dump(clusters, f)
        with open(ani_af_file, 'wb') as f:
            pickle.dump(ani_af, f)
    else:
        self.logger.warning(
            'Using previously calculated results in: {}'.format('clusters.pkl'))
        with open(clusters_file, 'rb') as f:
            clusters = pickle.load(f)

        self.logger.warning('Using previously calculated results in: {}'.format(
            'ani_af_rep_vs_nonrep.de_novo.pkl'))
        with open(ani_af_file, 'rb') as f:
            ani_af = pickle.load(f)

    return clusters, ani_af
def selected_rep_genomes(self, cur_genomes, nonrep_radius, unclustered_qc_gids, mash_ani):
    """Select de novo representatives for species clusters in a greedy fashion using species-specific ANI thresholds.

    Genomes are processed from highest to lowest quality score. A genome
    becomes a new representative unless its closest already selected
    representative (meeting the AF criterion) has an ANI radius containing
    it. Selected representatives are written to 'cluster_reps.tsv' in
    self.output_dir; if this file already exists, it is read instead.

    Args:
        cur_genomes: genome set indexed by genome ID and exposing
            'genomic_files'.
        nonrep_radius: dictionary mapping genome ID to GenomeRadius. This is
            updated in place when a closer neighbour is identified.
        unclustered_qc_gids: genome IDs to consider.
        mash_ani: Mash ANI results indexed as mash_ani[gid1][gid2].

    Returns:
        Set with the genome IDs of the selected representatives.
    """

    # sort genomes by quality score
    self.logger.info(
        'Selecting de novo representatives in a greedy manner based on quality.')
    q = {gid: cur_genomes[gid].score_type_strain()
         for gid in unclustered_qc_gids}
    q_sorted = sorted(q.items(), key=lambda kv: (kv[1], kv[0]), reverse=True)

    # greedily determine representatives for new species clusters
    cluster_rep_file = os.path.join(self.output_dir, 'cluster_reps.tsv')
    clusters = set()
    if not os.path.exists(cluster_rep_file):
        clustered_genomes = 0
        max_ani_pairs = 0
        for idx, (cur_gid, _score) in enumerate(q_sorted):

            # determine reference genomes to calculate ANI between
            ani_pairs = []
            if cur_gid in mash_ani:
                for rep_gid in clusters:
                    if mash_ani[cur_gid].get(rep_gid, 0) >= self.min_mash_ani:
                        ani_pairs.append((cur_gid, rep_gid))
                        ani_pairs.append((rep_gid, cur_gid))

            # determine if genome clusters with representative
            clustered = False
            if ani_pairs:
                if len(ani_pairs) > max_ani_pairs:
                    max_ani_pairs = len(ani_pairs)

                ani_af = self.fastani.pairs(
                    ani_pairs, cur_genomes.genomic_files, report_progress=False)

                closest_rep_gid = None
                closest_rep_ani = 0
                closest_rep_af = 0
                for rep_gid in clusters:
                    ani, af = FastANI.symmetric_ani(
                        ani_af, cur_gid, rep_gid)

                    if af >= self.af_sp:
                        if ani > closest_rep_ani or (ani == closest_rep_ani and af > closest_rep_af):
                            closest_rep_gid = rep_gid
                            closest_rep_ani = ani
                            closest_rep_af = af

                    # tighten the radius of the current genome
                    if ani > nonrep_radius[cur_gid].ani and af >= self.af_sp:
                        nonrep_radius[cur_gid] = GenomeRadius(ani=ani,
                                                              af=af,
                                                              neighbour_gid=rep_gid)

                if closest_rep_gid and closest_rep_ani > nonrep_radius[closest_rep_gid].ani:
                    clustered = True

            if not clustered:
                # genome is a new species cluster representative
                clusters.add(cur_gid)
            else:
                clustered_genomes += 1

            if (idx+1) % 10 == 0 or idx+1 == len(q_sorted):
                statusStr = '-> Clustered {:,} of {:,} ({:.2f}%) genomes [ANI pairs: {:,}; clustered genomes: {:,}; clusters: {:,}].'.format(
                    idx+1,
                    len(q_sorted),
                    float(idx+1)*100/len(q_sorted),
                    max_ani_pairs,
                    clustered_genomes,
                    len(clusters)).ljust(96)
                sys.stdout.write('{}\r'.format(statusStr))
                sys.stdout.flush()

                # report the maximum per reporting interval
                max_ani_pairs = 0
        sys.stdout.write('\n')

        # write out selected cluster representative
        with open(cluster_rep_file, 'w') as fout:
            for gid in clusters:
                fout.write('{}\n'.format(gid))
    else:
        # read cluster reps from file
        self.logger.warning(
            'Using previously determined cluster representatives.')
        with open(cluster_rep_file) as f:
            for line in f:
                gid = line.strip()
                if gid:
                    # blank lines do not denote a genome
                    clusters.add(gid)

    self.logger.info(
        'Selected {:,} representative genomes for de novo species clusters.'.format(len(clusters)))

    return clusters
def run(self, gtdb_metadata_file, genomic_path_file):
    """Calculate ANI/AF between GTDB representatives within each genus.

    Results are written to 'intragenus_ani_af_reps.tsv' in self.output_dir.

    Args:
        gtdb_metadata_file: GTDB metadata file used to build the genome set.
        genomic_path_file: file indicating path to genomic FASTA files.
    """

    # create GTDB genome sets
    self.logger.info('Creating GTDB genome set.')
    genomes = Genomes()
    genomes.load_from_metadata_file(gtdb_metadata_file)
    genomes.load_genomic_file_paths(genomic_path_file)
    self.logger.info(
        ' - genome set has {:,} species clusters spanning {:,} genomes.'.
        format(len(genomes.sp_clusters),
               genomes.sp_clusters.total_num_genomes()))

    # get GTDB representatives from same genus
    self.logger.info('Identifying GTDB representatives in the same genus.')
    genus_gids = defaultdict(list)
    num_reps = 0
    for gid in genomes:
        if not genomes[gid].gtdb_is_rep:
            continue

        gtdb_genus = genomes[gid].gtdb_taxa.genus
        genus_gids[gtdb_genus].append(gid)
        num_reps += 1
    self.logger.info(
        f' - identified {len(genus_gids):,} genera spanning {num_reps:,} representatives'
    )

    # get all intragenus comparisons
    self.logger.info('Determining all intragenus comparisons.')
    gid_pairs = []
    for gids in genus_gids.values():
        if len(gids) < 2:
            continue

        # permutations yields both orderings of each pair
        gid_pairs.extend(permutations(gids, 2))
    self.logger.info(
        f' - identified {len(gid_pairs):,} intragenus comparisons')

    # calculate FastANI ANI/AF between target genomes
    self.logger.info('Calculating ANI between intragenus pairs.')
    ani_af = self.fastani.pairs(gid_pairs,
                                genomes.genomic_files,
                                report_progress=True,
                                check_cache=True)
    self.fastani.write_cache(silence=True)

    # write out results; the context manager guarantees the file is closed
    # even if a lookup or ANI calculation fails part way through
    out_file = os.path.join(self.output_dir, 'intragenus_ani_af_reps.tsv')
    with open(out_file, 'w') as fout:
        fout.write(
            'Query ID\tQuery species\tTarget ID\tTarget species\tANI\tAF\n')
        # NOTE(review): this reports every combination of query genomes in
        # the results, including pairs from different genera - confirm that
        # FastANI.symmetric_ani handles pairs that were never compared
        for qid in ani_af:
            for rid in ani_af:
                ani, af = FastANI.symmetric_ani(ani_af, qid, rid)

                fout.write('{}\t{}\t{}\t{}\t{:.3f}\t{:.3f}\n'.format(
                    qid, genomes[qid].gtdb_taxa.species, rid,
                    genomes[rid].gtdb_taxa.species, ani, af))