def top_ani_score(self, prev_rid, sp_cids, cur_genomes):
        """Pick the species-cluster genome with the highest ANI score to a representative.

        ANI/AF values between `prev_rid` and every genome in `sp_cids` are
        requested in both directions from `self.fastani.pairs`, combined with
        `symmetric_ani`, and scored with each genome's `score_ani` method.

        Returns a tuple (genome ID, score, ANI, AF) describing the genome with
        the highest score; on ties the genome encountered first is kept. When
        `sp_cids` is empty the result is (None, -1e6, None, None).
        """

        # both directions are needed for every (member, representative) pairing
        gid_pairs = [pair
                     for cid in sp_cids
                     for pair in ((cid, prev_rid), (prev_rid, cid))]

        ani_af = self.fastani.pairs(gid_pairs,
                                    cur_genomes.genomic_files,
                                    check_cache=True,
                                    report_progress=False)

        # running best as (genome ID, score, ANI, AF); the sentinel score
        # sits below any score expected from score_ani
        best = (None, -1e6, None, None)
        for cid in sp_cids:
            ani, af = symmetric_ani(ani_af, prev_rid, cid)
            score = cur_genomes[cid].score_ani(ani)

            # strict comparison so an earlier genome is never displaced by a tie
            if score > best[1]:
                best = (cid, score, ani, af)

        return best
    def _nonrep_radius(self, unclustered_gids, rep_gids, ani_af_rep_vs_nonrep):
        """Calculate circumscription radius for unclustered, nontype genomes.

        Each genome in `unclustered_gids` starts with the default radius
        `self.ani_sp` and no neighbour. The radius is then raised to the
        highest symmetric ANI found to any genome in `rep_gids`, provided that
        ANI exceeds the current radius and the AF is at least `self.af_sp`.

        Parameters:
            unclustered_gids: IDs of genomes to assign a radius to.
            rep_gids: IDs of representative genomes to compare against.
            ani_af_rep_vs_nonrep: path to a pickled, nested ANI/AF lookup
                indexed as ani_af[nonrep_gid][rep_gid].

        Returns:
            dict mapping each unclustered genome ID to a GenomeRadius
            (ani, af, neighbour_gid). Empty when `unclustered_gids` is empty.
        """

        # set radius for genomes to default values
        nonrep_radius = {gid: GenomeRadius(ani=self.ani_sp,
                                           af=None,
                                           neighbour_gid=None)
                         for gid in unclustered_gids}

        # NOTE(review): unpickling can execute arbitrary code; this presumes
        # the file is a trusted, locally generated result - confirm.
        # The context manager guarantees the handle is closed after loading.
        with open(ani_af_rep_vs_nonrep, 'rb') as f:
            ani_af = pickle.load(f)

        # determine closest ANI neighbour and restrict ANI radius as necessary
        for nonrep_gid in unclustered_gids:
            if nonrep_gid not in ani_af:
                continue

            for rep_gid in rep_gids:
                if rep_gid not in ani_af[nonrep_gid]:
                    continue

                ani, af = symmetric_ani(ani_af, nonrep_gid, rep_gid)

                if ani > nonrep_radius[nonrep_gid].ani and af >= self.af_sp:
                    nonrep_radius[nonrep_gid] = GenomeRadius(ani=ani,
                                                             af=af,
                                                             neighbour_gid=rep_gid)

        # min()/max() raise on an empty sequence, so only summarise when
        # there is at least one genome
        if nonrep_radius:
            radius_anis = [d.ani for d in nonrep_radius.values()]
            self.logger.info('ANI circumscription radius: min={:.2f}, mean={:.2f}, max={:.2f}'.format(
                                    min(radius_anis),
                                    np_mean(radius_anis),
                                    max(radius_anis)))

        return nonrep_radius
    def _cluster(self, ani_af, non_reps, rep_radius):
        """Cluster non-representative to representative genomes using species specific ANI thresholds.

        For each genome in `non_reps`, the representative with the highest
        symmetric ANI (ties broken by higher AF) among those with an AF of at
        least `self.af_sp` is selected. The genome joins that representative's
        cluster only if the ANI is strictly greater than the representative's
        radius (`rep_radius[rid].ani`).

        Parameters:
            ani_af: nested ANI/AF lookup indexed as ani_af[non_rid][rid].
            non_reps: sized collection of non-representative genome IDs.
            rep_radius: mapping of representative ID to an object with an
                `ani` attribute giving its circumscription radius.

        Returns:
            dict mapping every representative ID to a list of
            self.ClusteredGenome(gid, ani, af) entries (possibly empty).
        """

        clusters = {rep_id: [] for rep_id in rep_radius}

        progress_msg = '==> Processed {:,} of {:,} genomes [no. clustered = {:,}].\r'
        num_non_reps = len(non_reps)
        num_clustered = 0
        for idx, non_rid in enumerate(non_reps):
            if idx % 100 == 0:
                sys.stdout.write(
                    progress_msg.format(idx + 1, num_non_reps, num_clustered))
                sys.stdout.flush()

            if non_rid not in ani_af:
                continue

            closest_rid = None
            closest_ani = 0
            closest_af = 0
            for rid in rep_radius:
                if rid not in ani_af[non_rid]:
                    continue

                ani, af = symmetric_ani(ani_af, rid, non_rid)

                # prefer higher ANI; fall back to higher AF on equal ANI
                if af >= self.af_sp and (ani > closest_ani
                                         or (ani == closest_ani
                                             and af > closest_af)):
                    closest_rid = rid
                    closest_ani = ani
                    closest_af = af

            if closest_rid and closest_ani > rep_radius[closest_rid].ani:
                num_clustered += 1
                clusters[closest_rid].append(
                    self.ClusteredGenome(gid=non_rid,
                                         ani=closest_ani,
                                         af=closest_af))

        # final progress line uses the collection size rather than the loop
        # variable so it is also valid when there were no genomes to process
        sys.stdout.write(
            progress_msg.format(num_non_reps, num_non_reps, num_clustered))
        sys.stdout.flush()
        sys.stdout.write('\n')

        num_unclustered = num_non_reps - num_clustered
        self.logger.info(
            'Assigned {:,} genomes to {:,} representatives; {:,} genomes remain unclustered.'
            .format(sum([len(clusters[rid]) for rid in clusters]),
                    len(clusters), num_unclustered))

        return clusters
    def symmetric_ani_cached(self, gid1, gid2, genome_file1, genome_file2):
        """Compute ANI/AF in both directions, cache them, and return the symmetric values.

        `self.fastani` is called once per direction. Everything from index 2
        onward of each result is stored in `self.ani_cache` under
        [query][reference]; the return value is whatever `symmetric_ani`
        derives from the cache for (gid1, gid2).
        """

        forward = self.fastani(gid1, gid2, genome_file1, genome_file2)
        reverse = self.fastani(gid2, gid1, genome_file2, genome_file1)

        # drop the two leading fields of each result before caching
        cache = self.ani_cache
        cache[gid1][gid2] = forward[2:]
        cache[gid2][gid1] = reverse[2:]

        return symmetric_ani(cache, gid1, gid2)
Example #5
0
    def _write_rep_info(self,
                        clusters,
                        cluster_sp_names,
                        quality_metadata,
                        genome_quality,
                        excluded_from_refseq_note,
                        ani_af,
                        output_file):
        """Write out information about selected representative genomes.

        One tab-separated row is written per representative in `clusters`:
        species name, genome ID, NCBI assembly fields, assembly quality
        statistics, and summary ANI/AF statistics between the representative
        and the members of its cluster ('n/a' when the cluster is empty).

        Parameters:
            clusters: mapping of representative ID to an iterable of clustered
                genome IDs.
            cluster_sp_names: mapping of representative ID to species name.
            quality_metadata: mapping of genome ID to a metadata record
                exposing the attributes written below.
            genome_quality: mapping of genome ID to a numeric quality score.
            excluded_from_refseq_note: mapping of genome ID to note text;
                missing IDs are written as an empty string.
            ani_af: ANI/AF lookup passed through to `symmetric_ani`.
            output_file: path of the table to create (overwritten if present).
        """

        # the context manager closes the file even if a row fails to format
        with open(output_file, 'w') as fout:
            fout.write('Species\tType genome\tNCBI assembly level\tNCBI genome category')
            fout.write('\tGenome size (bp)\tQuality score\tCompleteness (%)\tContamination (%)\tNo. scaffolds\tNo. contigs\tN50 contigs\tAmbiguous bases\tSSU count\tSSU length (bp)')
            fout.write('\tNo. genomes in cluster\tMean ANI\tMean AF\tMin ANI\tMin AF\tNCBI exclude from RefSeq\n')

            for gid in clusters:
                metadata = quality_metadata[gid]

                fout.write('%s\t%s\t%s\t%s' % (
                            cluster_sp_names[gid],
                            gid,
                            metadata.ncbi_assembly_level,
                            metadata.ncbi_genome_category))

                fout.write('\t%d\t%.2f\t%.2f\t%.2f\t%d\t%d\t%.1f\t%d\t%d\t%d' % (
                                metadata.genome_size,
                                genome_quality[gid],
                                metadata.checkm_completeness,
                                metadata.checkm_contamination,
                                metadata.scaffold_count,
                                metadata.contig_count,
                                metadata.n50_contigs,
                                metadata.ambiguous_bases,
                                metadata.ssu_count,
                                # a falsy SSU length (e.g. None) is reported as 0
                                metadata.ssu_length if metadata.ssu_length else 0))

                anis = []
                afs = []
                for cluster_id in clusters[gid]:
                    ani, af = symmetric_ani(ani_af, gid, cluster_id)
                    anis.append(ani)
                    afs.append(af)

                refseq_note = excluded_from_refseq_note.get(gid, '')
                if anis:
                    fout.write('\t%d\t%.1f\t%.2f\t%.1f\t%.2f\t%s\n' % (len(clusters[gid]),
                                                                        np_mean(anis), np_mean(afs),
                                                                        min(anis), min(afs),
                                                                        refseq_note))
                else:
                    # no cluster members, so no ANI/AF statistics to report
                    fout.write('\t%d\t%s\t%s\t%s\t%s\t%s\n' % (len(clusters[gid]),
                                                                'n/a', 'n/a', 'n/a', 'n/a',
                                                                refseq_note))
Example #6
0
    def _cluster(self, genome_files, sp_clusters, rep_radius, user_genomes,
                 mash_anis):
        """Assign User genomes to existing species clusters.

        For each genome in `user_genomes`, representatives whose Mash ANI
        (from `mash_anis`) is at least `self.min_mash_ani` are compared with
        `self.fastani.pairs`. The representative with the highest symmetric
        ANI (ties broken by AF) and an AF of at least `self.af_sp` is chosen,
        and the genome is appended to `sp_clusters` for that representative
        when the ANI is strictly greater than the representative's radius.
        A warning is logged when a genome had candidate pairs but could not be
        assigned. `sp_clusters` is modified in place; nothing is returned.
        """

        status_template = '-> Assigned %d of %d (%.2f%%) genomes.'.ljust(86)

        for idx, cur_gid in enumerate(user_genomes):
            # collect both directions for every representative passing the
            # Mash prefilter
            ani_pairs = []
            if cur_gid in mash_anis:
                for rep_gid in sp_clusters:
                    if mash_anis[cur_gid].get(rep_gid, 0) >= self.min_mash_ani:
                        ani_pairs += [(cur_gid, rep_gid), (rep_gid, cur_gid)]

            # genomes without any candidate pair are left unassigned silently
            if ani_pairs:
                ani_af = self.fastani.pairs(ani_pairs,
                                            genome_files,
                                            report_progress=False)

                best_gid, best_ani, best_af = None, 0, 0
                for rep_gid in sp_clusters:
                    ani, af = symmetric_ani(ani_af, cur_gid, rep_gid)

                    is_better = ani > best_ani or (ani == best_ani
                                                   and af > best_af)
                    if af >= self.af_sp and is_better:
                        best_gid, best_ani, best_af = rep_gid, ani, af

                if best_gid and best_ani > rep_radius[best_gid].ani:
                    sp_clusters[best_gid].append(cur_gid)
                else:
                    self.logger.warning(
                        'Failed to assign genome %s to representative.' %
                        cur_gid)

            num_done = idx + 1
            num_total = len(user_genomes)
            statusStr = status_template % (num_done, num_total,
                                           float(num_done) * 100 / num_total)
            sys.stdout.write('%s\r' % statusStr)
            sys.stdout.flush()

        sys.stdout.write('\n')
Example #7
0
    def _cluster(self, ani_af, nontype_gids, type_radius):
        """Cluster non-type genomes to type genomes using species specific ANI thresholds.

        For each genome in `nontype_gids`, the type genome with the highest
        symmetric ANI (ties broken by higher AF) among those with an AF of at
        least `self.af_sp` is selected. The genome joins that cluster only if
        the ANI is strictly greater than the type genome's radius
        (`type_radius[gid].ani`).

        Parameters:
            ani_af: nested ANI/AF lookup indexed as ani_af[nontype_gid][type_gid].
            nontype_gids: sized collection of non-type genome IDs.
            type_radius: mapping of type genome ID to an object with an `ani`
                attribute giving its circumscription radius.

        Returns:
            dict mapping every type genome ID to a list of
            self.ClusteredGenome(gid, ani, af) entries (possibly empty).
        """

        clusters = {rep_id: [] for rep_id in type_radius}

        num_nontype = len(nontype_gids)
        for idx, nontype_gid in enumerate(nontype_gids):
            if idx % 100 == 0:
                sys.stdout.write('==> Processed %d of %d genomes.\r' %
                                 (idx + 1, num_nontype))
                sys.stdout.flush()

            if nontype_gid not in ani_af:
                continue

            closest_type_gid = None
            closest_ani = 0
            closest_af = 0
            for type_gid in type_radius:
                if type_gid not in ani_af[nontype_gid]:
                    continue

                ani, af = symmetric_ani(ani_af, type_gid, nontype_gid)

                # prefer higher ANI; fall back to higher AF on equal ANI
                if af >= self.af_sp and (ani > closest_ani
                                         or (ani == closest_ani
                                             and af > closest_af)):
                    closest_type_gid = type_gid
                    closest_ani = ani
                    closest_af = af

            if closest_type_gid and closest_ani > type_radius[closest_type_gid].ani:
                clusters[closest_type_gid].append(
                    self.ClusteredGenome(gid=nontype_gid,
                                         ani=closest_ani,
                                         af=closest_af))

        # report the full count once finished; using the collection size
        # avoids both an off-by-one and an unbound loop variable on empty input
        sys.stdout.write('==> Processed %d of %d genomes.\r' %
                         (num_nontype, num_nontype))
        sys.stdout.flush()
        sys.stdout.write('\n')

        self.logger.info(
            'Assigned %d genomes to representatives.' %
            sum([len(clusters[type_gid]) for type_gid in clusters]))

        return clusters
    def _rep_genome_stats(self, clusters, genome_files):
        """Calculate statistics relative to representative genome.

        For every representative in `clusters`, the symmetric ANI between the
        representative and each clustered genome is read from
        `self.fastani.ani_cache` and summarised.

        Parameters:
            clusters: mapping of representative ID to a sized collection of
                clustered genome IDs.
            genome_files: not used by this method; retained so existing
                callers keep working.

        Returns:
            dict mapping representative ID to self.RepStats(min_ani, mean_ani,
            std_ani, median_ani); all fields are -1 for an empty cluster.
        """

        self.logger.info('Calculating statistics to cluster representatives:')
        stats = {}
        for idx, (rid, cids) in enumerate(clusters.items()):
            if len(cids) == 0:
                # sentinel values flag representatives without cluster members
                stats[rid] = self.RepStats(min_ani=-1,
                                           mean_ani=-1,
                                           std_ani=-1,
                                           median_ani=-1)
            else:
                # ANI values are taken from the shared FastANI cache rather
                # than recomputed; this assumes the cache already holds the
                # required pairs - TODO confirm with caller
                ani_af = self.fastani.ani_cache

                # calculate statistics
                anis = [symmetric_ani(ani_af, cid, rid)[0] for cid in cids]

                stats[rid] = self.RepStats(min_ani=min(anis),
                                           mean_ani=np_mean(anis),
                                           std_ani=np_std(anis),
                                           median_ani=np_median(anis))

            statusStr = '-> Processing %d of %d (%.2f%%) clusters.'.ljust(86) % (
                                idx+1,
                                len(clusters),
                                float((idx+1)*100)/len(clusters))
            sys.stdout.write('%s\r' % statusStr)
            sys.stdout.flush()

        sys.stdout.write('\n')

        return stats
    def top_hits(self, species, rid, ani_af, genomes):
        """Report top 5 hits to species.

        The symmetric ANI/AF between `rid` and every genome listed under
        `ani_af[rid]` is computed, and the five genomes with the highest
        (ANI, AF) are written to the log together with their GTDB species.

        Parameters:
            species: species name used in the log header.
            rid: ID of the genome the hits are ranked against.
            ani_af: nested ANI/AF lookup; `ani_af[rid]` supplies the
                candidate genome IDs.
            genomes: mapping of genome ID to a record with a `gtdb_species`
                attribute.
        """

        results = {}
        for qid in ani_af[rid]:
            ani, af = symmetric_ani(ani_af, rid, qid)
            results[qid] = (ani, af)

        # rank by ANI, then AF, highest first
        ranked = sorted(results.items(),
                        key=lambda x: x[1],
                        reverse=True)

        self.logger.info(f'Closest 5 species to {species} ({rid}):')

        # slicing caps the report at exactly five entries
        for qid, (ani, af) in ranked[:5]:
            q_species = genomes[qid].gtdb_species
            self.logger.info(
                f'{q_species} ({qid}): ANI={ani:.1f}%, AF={af:.2f}')
 def _find_multiple_reps(self, clusters, cluster_radius):
     """Find, for each clustered genome, the representatives whose ANI radius contains it.

     This method assumes the ANI cache contains all relevant ANI calculations between
     representative and non-representative genomes. This is the case once the de novo
     clustering has been performed.

     Returns a defaultdict(set) mapping a clustered genome ID to a set of
     (representative ID, ANI) tuples for every representative with
     AF >= self.af_sp and ANI >= that representative's radius. Genomes with
     no such representative have no entry.
     """

     self.logger.info('Determine number of non-rep genomes within ANI radius of multiple rep genomes.')

     # flatten the cluster membership lists into one list of genome IDs
     clustered_gids = [gid for rid in clusters for gid in clusters[rid]]
     num_genomes = len(clustered_gids)

     self.logger.info('Considering %d representatives and %d non-representative genomes.' % (len(clusters), num_genomes))

     status_template = '-> Processing %d of %d (%.2f%%) clusters genomes.'.ljust(86)

     nonrep_rep_count = defaultdict(set)
     for done, gid in enumerate(clustered_gids, start=1):
         cur_ani_cache = self.fastani.ani_cache[gid]
         for rid in clusters:
             # only representatives with a cached result for this genome count
             if rid in cur_ani_cache:
                 ani, af = symmetric_ani(self.fastani.ani_cache, gid, rid)
                 if af >= self.af_sp and ani >= cluster_radius[rid].ani:
                     nonrep_rep_count[gid].add((rid, ani))

         # refresh progress every 100 genomes and on the final genome
         if done % 100 == 0 or done == num_genomes:
             statusStr = status_template % (done,
                                            num_genomes,
                                            float(done*100)/num_genomes)
             sys.stdout.write('%s\r' % statusStr)
             sys.stdout.flush()

     sys.stdout.write('\n')

     return nonrep_rep_count
    def merge_ani_radius(self, species, rid, merged_sp_cluster, genomic_files):
        """Determine ANI radius if species were merged.

        The symmetric ANI/AF between `rid` and every genome in
        `merged_sp_cluster` is computed with `self.fastani.pairs`, and the
        lowest ANI (with its AF) is written to the log. Nothing is returned.

        Parameters:
            species: species name used in the log messages.
            rid: ID of the genome acting as representative.
            merged_sp_cluster: re-iterable collection of genome IDs in the
                merged cluster.
            genomic_files: genome file lookup passed to `self.fastani.pairs`.
        """

        self.logger.info(
            f'Calculating ANI from {species} to all genomes in merged species cluster.'
        )

        gid_pairs = []
        for gid in merged_sp_cluster:
            gid_pairs.append((rid, gid))
            gid_pairs.append((gid, rid))
        merged_ani_af1 = self.fastani.pairs(gid_pairs, genomic_files)

        ani_radius = 100
        af_radius = None
        for gid in merged_sp_cluster:
            ani, af = symmetric_ani(merged_ani_af1, rid, gid)

            # the second condition records an AF even when every genome sits
            # exactly at the initial radius of 100, where a strict '<' alone
            # would never fire
            if ani < ani_radius or (af_radius is None and ani == ani_radius):
                ani_radius = ani
                af_radius = af

        # no AF is available when the cluster is empty
        af_str = 'n/a' if af_radius is None else f'{af_radius:.2f}'
        self.logger.info(
            f'Merged cluster with {species} rep: ANI radius={ani_radius:.1f}%, AF={af_str}'
        )
Example #12
0
    def _cluster_genomes(self,
                            genome_files,
                            rep_genomes,
                            type_gids,
                            passed_qc,
                            final_cluster_radius):
        """Cluster all non-type/representative genomes to selected type/representatives genomes.

        Mash is used as a prefilter: only genome/representative pairs with a
        Mash ANI of at least `self.min_mash_ani` are passed to
        `self.fastani.pairs`. Each remaining genome is assigned to the
        representative with the highest symmetric ANI (ties broken by AF)
        among those with ANI >= the representative's radius and
        AF >= `self.af_sp`. Genomes that cannot be assigned are reported
        through logger warnings.

        Parameters:
            genome_files: genome file lookup passed to Mash and FastANI.
            rep_genomes: set of representative genome IDs.
            type_gids: collection of type genome IDs (unioned with `rep_genomes`).
            passed_qc: set of genome IDs that passed QC.
            final_cluster_radius: mapping of representative ID to an object
                with an `ani` attribute giving its circumscription radius.

        Returns:
            (clusters, ani_af) where `clusters` maps each representative ID to
            a list of self.ClusteredGenome(gid, ani, af) entries and `ani_af`
            is the result returned by `self.fastani.pairs`.
        """

        all_reps = rep_genomes.union(type_gids)

        # calculate MASH distance between non-type/representative genomes and selected type/representatives genomes
        mash = Mash(self.cpus)

        mash_type_rep_sketch_file = os.path.join(self.output_dir, 'gtdb_rep_genomes.msh')
        type_rep_genome_list_file = os.path.join(self.output_dir, 'gtdb_rep_genomes.lst')
        mash.sketch(all_reps, genome_files, type_rep_genome_list_file, mash_type_rep_sketch_file)

        mash_none_rep_sketch_file = os.path.join(self.output_dir, 'gtdb_nonrep_genomes.msh')
        type_none_rep_file = os.path.join(self.output_dir, 'gtdb_nonrep_genomes.lst')
        mash.sketch(passed_qc - all_reps, genome_files, type_none_rep_file, mash_none_rep_sketch_file)

        # get Mash distances; the ANI percentage threshold is converted to a
        # distance in [0, 1]
        mash_dist_file = os.path.join(self.output_dir, 'gtdb_rep_vs_nonrep_genomes.dst')
        mash.dist(float(100 - self.min_mash_ani)/100, mash_type_rep_sketch_file, mash_none_rep_sketch_file, mash_dist_file)

        # read Mash distances
        mash_ani = mash.read_ani(mash_dist_file)

        # calculate ANI between non-type/representative genomes and selected type/representatives genomes
        clusters = {gid: [] for gid in all_reps}

        genomes_to_cluster = passed_qc - set(clusters)
        ani_pairs = []
        for gid in genomes_to_cluster:
            if gid in mash_ani:
                for rep_gid in clusters:
                    if mash_ani[gid].get(rep_gid, 0) >= self.min_mash_ani:
                        ani_pairs.append((gid, rep_gid))
                        ani_pairs.append((rep_gid, gid))

        self.logger.info('Calculating ANI between %d species clusters and %d unclustered genomes (%d pairs):' % (
                            len(clusters),
                            len(genomes_to_cluster),
                            len(ani_pairs)))
        ani_af = self.fastani.pairs(ani_pairs, genome_files)

        # assign genomes to closest representatives
        # that is within the representatives ANI radius
        self.logger.info('Assigning genomes to closest representative.')
        for idx, cur_gid in enumerate(genomes_to_cluster):
            # best representative whose radius contains the genome
            closest_rep_gid = None
            closest_rep_ani = 0
            closest_rep_af = 0

            # best AF-passing representative regardless of radius; used only
            # to explain a failed assignment
            nearest_rep_gid = None
            nearest_rep_ani = 0
            nearest_rep_af = 0

            for rep_gid in clusters:
                ani, af = symmetric_ani(ani_af, cur_gid, rep_gid)
                within_radius = ani >= final_cluster_radius[rep_gid].ani

                if af >= self.af_sp:
                    if ani > nearest_rep_ani or (ani == nearest_rep_ani and af > nearest_rep_af):
                        nearest_rep_gid = rep_gid
                        nearest_rep_ani = ani
                        nearest_rep_af = af

                    if within_radius and (ani > closest_rep_ani
                                          or (ani == closest_rep_ani and af > closest_rep_af)):
                        closest_rep_gid = rep_gid
                        closest_rep_ani = ani
                        closest_rep_af = af

            if closest_rep_gid:
                clusters[closest_rep_gid].append(self.ClusteredGenome(gid=cur_gid,
                                                                        ani=closest_rep_ani,
                                                                        af=closest_rep_af))
            else:
                self.logger.warning('Failed to assign genome %s to representative.' % cur_gid)
                if nearest_rep_gid:
                    # an AF-passing representative exists, but the genome
                    # falls outside its ANI radius
                    self.logger.warning(' ...closest_rep_gid = %s' % nearest_rep_gid)
                    self.logger.warning(' ...closest_rep_ani = %.2f' % nearest_rep_ani)
                    self.logger.warning(' ...closest_rep_af = %.2f' % nearest_rep_af)
                    self.logger.warning(' ...closest rep radius = %.2f' % final_cluster_radius[nearest_rep_gid].ani)
                else:
                    self.logger.warning(' ...no representative with an AF >%.2f identified.' % self.af_sp)

            statusStr = '-> Assigned %d of %d (%.2f%%) genomes.'.ljust(86) % (idx+1,
                                                                                len(genomes_to_cluster),
                                                                                float(idx+1)*100/len(genomes_to_cluster))
            sys.stdout.write('%s\r' % statusStr)
            sys.stdout.flush()
        sys.stdout.write('\n')

        return clusters, ani_af
    def run(self, 
                cur_gtdb_metadata_file,
                cur_genomic_path_file,
                qc_passed_file,
                ncbi_genbank_assembly_file,
                ltp_taxonomy_file,
                gtdb_type_strains_ledger,
                untrustworthy_type_ledger):
        """Resolve cases where a species has multiple genomes assembled from the type strain."""
        
        # get species in LTP reference database
        self.logger.info('Determining species defined in LTP reference database.')
        ltp_defined_species = self.ltp_defined_species(ltp_taxonomy_file)
        self.logger.info(f' ... identified {len(ltp_defined_species):,} species.')
        
        # create current GTDB genome sets
        self.logger.info('Creating current GTDB genome set.')
        cur_genomes = Genomes()
        cur_genomes.load_from_metadata_file(cur_gtdb_metadata_file,
                                                gtdb_type_strains_ledger=gtdb_type_strains_ledger,
                                                create_sp_clusters=False,
                                                uba_genome_file=None,
                                                qc_passed_file=qc_passed_file,
                                                ncbi_genbank_assembly_file=ncbi_genbank_assembly_file,
                                                untrustworthy_type_ledger=untrustworthy_type_ledger)
        cur_genomes.load_genomic_file_paths(cur_genomic_path_file)
        self.logger.info(f' ... current genome set contains {len(cur_genomes):,} genomes.')
        
        # update current genomes with GTDB-Tk classifications
        self.logger.info('Updating current genomes with GTDB-Tk classifications.')
        num_updated, num_ncbi_sp = cur_genomes.set_gtdbtk_classification(gtdbtk_classify_file, prev_genomes)
        self.logger.info(f' ... set GTDB taxa for {num_updated:,} genomes with {num_ncbi_sp:,} genomes using NCBI genus and species name.')
        
        # parsing genomes manually established to be untrustworthy as type
        self.logger.info('Determining genomes manually annotated as untrustworthy as type.')
        manual_untrustworthy_types = {}
        with open(untrustworthy_type_ledger) as f:
            header = f.readline().strip().split('\t')
            
            ncbi_sp_index = header.index('NCBI species')
            reason_index = header.index('Reason for declaring untrustworthy')
            
            for line in f:
                tokens = line.strip().split('\t')
                
                gid = canonical_gid(tokens[0])
                manual_untrustworthy_types[gid] = (tokens[ncbi_sp_index], tokens[reason_index])
        self.logger.info(f' ... identified {len(manual_untrustworthy_types):,} genomes manually annotated as untrustworthy as type.')

        # identify NCBI species with multiple genomes assembled from type strain of species
        self.logger.info('Determining number of type strain genomes in each NCBI species.')
        sp_type_strain_genomes = defaultdict(set)
        for gid in cur_genomes:
            if cur_genomes[gid].is_effective_type_strain():
                ncbi_sp = cur_genomes[gid].ncbi_taxa.species
                if ncbi_sp != 's__':
                    # yes, NCBI has genomes marked as assembled from type material
                    # that do not actually have a binomial species name
                    sp_type_strain_genomes[ncbi_sp].add(gid)

        multi_type_strains_sp = [ncbi_sp for ncbi_sp, gids in sp_type_strain_genomes.items() if len(gids) > 1]
        self.logger.info(f' ... identified {len(multi_type_strains_sp):,} NCBI species with multiple assemblies indicated as being type strain genomes.')
        
        # sort by number of genome assemblies
        self.logger.info('Calculating ANI between type strain genomes in each species.')
        
        fout = open(os.path.join(self.output_dir, 'multi_type_strain_species.tsv'), 'w')
        fout.write('NCBI species\tNo. type strain genomes\t>=99% ANI\tMean ANI\tStd ANI\tMean AF\tStd AF\tResolution\tGenome IDs\n')
        
        fout_genomes = open(os.path.join(self.output_dir, 'type_strain_genomes.tsv'), 'w')
        fout_genomes.write('Genome ID\tUntrustworthy\tNCBI species\tGTDB genus\tGTDB species\tLTP species\tConflict with prior GTDB assignment')
        fout_genomes.write('\tMean ANI\tStd ANI\tMean AF\tStd AF\tExclude from RefSeq\tNCBI taxonomy\tGTDB taxonomy\n')
        
        fout_unresolved = open(os.path.join(self.output_dir, 'unresolved_type_strain_genomes.tsv'), 'w')
        fout_unresolved.write('Genome ID\tNCBI species\tGTDB genus\tGTDB species\tLTP species')
        fout_unresolved.write('\tMean ANI\tStd ANI\tMean AF\tStd AF\tExclude from RefSeq\tNCBI taxonomy\tGTDB taxonomy\n')
        
        fout_high_divergence = open(os.path.join(self.output_dir, 'highly_divergent_type_strain_genomes.tsv'), 'w')
        fout_high_divergence.write('Genome ID\tNCBI species\tGTDB genus\tGTDB species\tLTP species\tMean ANI\tStd ANI\tMean AF\tStd AF\tExclude from RefSeq\tNCBI taxonomy\tGTDB taxonomy\n')
        
        fout_untrustworthy = open(os.path.join(self.output_dir, 'untrustworthy_type_material.tsv'), 'w')
        fout_untrustworthy.write('Genome ID\tNCBI species\tGTDB species\tLTP species\tReason for declaring untrustworthy\n')
        for gid in manual_untrustworthy_types:
            ncbi_sp, reason = manual_untrustworthy_types[gid]
            fout_untrustworthy.write('{}\t{}\t{}\t{}\t{}\n'.format(
                                        gid, 
                                        ncbi_sp, 
                                        cur_genomes[gid].gtdb_taxa.species,
                                        '<not tested>',
                                        'n/a',
                                        'Manual curation: ' + reason))
        
        processed = 0
        num_divergent = 0
        unresolved_sp_count = 0
        
        ncbi_ltp_resolved = 0
        intra_ani_resolved = 0
        ncbi_type_resolved = 0
        gtdb_family_resolved = 0
        gtdb_genus_resolved = 0
        gtdb_sp_resolved = 0
        ltp_resolved = 0
        
        use_pickled_results = False #***
        if use_pickled_results:
            self.logger.warning('Using previously calculated ANI results in: {}'.format(self.ani_pickle_dir))
        
        prev_gtdb_sp_conflicts = 0
        for ncbi_sp, type_gids in sorted(sp_type_strain_genomes.items(), key=lambda kv: len(kv[1])):
            if len(type_gids) == 1:
                continue
                
            status_str = '-> Processing {} with {:,} type strain genomes [{:,} of {:,} ({:.2f}%)].'.format(
                                ncbi_sp, 
                                len(type_gids),
                                processed+1, 
                                len(multi_type_strains_sp),
                                (processed+1)*100.0/len(multi_type_strains_sp)).ljust(128)
            sys.stdout.write('{}\r'.format(status_str))
            sys.stdout.flush()
            processed += 1

            # calculate ANI between type strain genomes
            ncbi_sp_str = ncbi_sp[3:].lower().replace(' ', '_')
            if not use_pickled_results: #***
                ani_af = self.fastani.pairwise(type_gids, cur_genomes.genomic_files)
                pickle.dump(ani_af, open(os.path.join(self.ani_pickle_dir, f'{ncbi_sp_str}.pkl'), 'wb'))
            else:
                ani_af = pickle.load(open(os.path.join(self.ani_pickle_dir, f'{ncbi_sp_str}.pkl'), 'rb'))
            
            anis = []
            afs = []
            gid_anis = defaultdict(lambda: {})
            gid_afs = defaultdict(lambda: {})
            all_similar = True
            for gid1, gid2 in combinations(type_gids, 2):
                ani, af = symmetric_ani(ani_af, gid1, gid2)
                if ani < 99 or af < 0.65:
                    all_similar = False
                    
                anis.append(ani)
                afs.append(af)
                
                gid_anis[gid1][gid2] = ani
                gid_anis[gid2][gid1] = ani
                
                gid_afs[gid1][gid2] = af
                gid_afs[gid2][gid1] = af
                
            note = 'All type strain genomes have ANI >99% and AF >65%.'
            unresolved_species = False
            
            # read LTP metadata for genomes
            ltp_metadata = self.parse_ltp_metadata(type_gids, cur_genomes)

            untrustworthy_gids = {}
            gtdb_resolved_sp_conflict = False
            if not all_similar:
                # need to establish which genomes are untrustworthy as type
                num_divergent += 1
                unresolved_species = True
                
                # write out highly divergent cases for manual inspection; 
                # these should be compared to the automated selection
                if np_mean(anis) < 95:
                    for gid in type_gids:
                        ltp_species = self.ltp_species(gid, ltp_metadata)
                            
                        fout_high_divergence.write('{}\t{}\t{}\t{}\t{}\t{:.2f}\t{:.3f}\t{:.3f}\t{:.4f}\t{}\t{}\t{}\n'.format(
                                                        gid,
                                                        ncbi_sp,
                                                        cur_genomes[gid].gtdb_taxa.genus,
                                                        cur_genomes[gid].gtdb_taxa.species,
                                                        ' / '.join(ltp_species),
                                                        np_mean(list(gid_anis[gid].values())),
                                                        np_std(list(gid_anis[gid].values())),
                                                        np_mean(list(gid_afs[gid].values())),
                                                        np_std(list(gid_afs[gid].values())),
                                                        cur_genomes[gid].excluded_from_refseq_note,
                                                        cur_genomes[gid].ncbi_taxa,
                                                        cur_genomes[gid].gtdb_taxa))
                
                # filter genomes marked as `untrustworthy as type` at NCBI and where the LTP
                # assignment also suggest the asserted type material is incorrect
                resolved, untrustworthy_gids = self.resolve_validated_untrustworthy_ncbi_genomes(gid_anis, 
                                                                                                    ncbi_sp, 
                                                                                                    type_gids, 
                                                                                                    ltp_metadata, 
                                                                                                    ltp_defined_species,
                                                                                                    cur_genomes)
                if resolved:
                    note = "Species resolved by removing genomes considered `untrustworthy as type` and with a LTP BLAST hit confirming the assembly is likely untrustworthy"
                    ncbi_ltp_resolved += 1

                # try to resolve by LTP 16S BLAST results
                if not resolved:
                    resolved, untrustworthy_gids = self.resolve_ltp_conflict(gid_anis, ncbi_sp, type_gids, ltp_metadata, 0)
                    if resolved:
                        note = 'Species resolved by identifying conflicting or lack of LTP BLAST results'
                        ltp_resolved += 1

                # try to resolve species using intra-specific ANI test
                if not resolved:
                    resolved, untrustworthy_gids = self.resolve_by_intra_specific_ani(gid_anis)
                    if resolved:
                        note = 'Species resolved by intra-specific ANI test'
                        intra_ani_resolved += 1

                # try to resolve by GTDB family assignment
                if not resolved:
                    resolved, untrustworthy_gids = self.resolve_gtdb_family(gid_anis, ncbi_sp, type_gids, cur_genomes)
                    if resolved:
                        note = 'Species resolved by consulting GTDB family classifications'
                        gtdb_family_resolved += 1
                
                # try to resolve by GTDB genus assignment
                if not resolved:
                    resolved, untrustworthy_gids = self.resolve_gtdb_genus(gid_anis, ncbi_sp, type_gids, cur_genomes)
                    if resolved:
                        note = 'Species resolved by consulting GTDB genus classifications'
                        gtdb_genus_resolved += 1
                           
                # try to resolve by GTDB species assignment
                if not resolved:
                    resolved, untrustworthy_gids = self.resolve_gtdb_species(gid_anis, ncbi_sp, type_gids, cur_genomes)
                    if resolved:
                        note = 'Species resolved by consulting GTDB species classifications'
                        gtdb_sp_resolved += 1
                        
                # try to resolve by considering genomes annotated as type material at NCBI,
                # which includes considering if genomes are marked as untrustworthy as type
                if not resolved:
                    resolved, untrustworthy_gids = self.resolve_by_ncbi_types(gid_anis, type_gids, cur_genomes)
                    if resolved:
                        note = 'Species resolved by consulting NCBI assembled from type metadata'
                        ncbi_type_resolved += 1

                if resolved:
                    unresolved_species = False
                    
                    # check if type strain genomes marked as trusted or untrusted conflict
                    # with current GTDB species assignment
                    untrustworthy_gtdb_sp_match = False
                    trusted_gtdb_sp_match = False
                    for gid in type_gids:
                        gtdb_canonical_epithet = canonical_taxon(specific_epithet(cur_genomes[gid].gtdb_taxa.species))
                        if gtdb_canonical_epithet == specific_epithet(ncbi_sp):
                            if gid in untrustworthy_gids:
                                untrustworthy_gtdb_sp_match = True
                            else:
                                trusted_gtdb_sp_match = True

                    if untrustworthy_gtdb_sp_match and not trusted_gtdb_sp_match:
                        prev_gtdb_sp_conflicts += 1
                        gtdb_resolved_sp_conflict = True

                    # write results to file
                    for gid, reason in untrustworthy_gids.items():
                        ltp_species = self.ltp_species(gid, ltp_metadata)
                        
                        if 'untrustworthy as type' in cur_genomes[gid].excluded_from_refseq_note:
                            reason += "; considered `untrustworthy as type` at NCBI"
                        fout_untrustworthy.write('{}\t{}\t{}\t{}\t{}\n'.format(gid,
                                                                                ncbi_sp,
                                                                                cur_genomes[gid].gtdb_taxa.species,
                                                                                ' / '.join(ltp_species),
                                                                                reason))
                                                                                
                        # Sanity check that if the untrustworthy genome has an LTP to only the
                        # expected species, that all other genomes also have a hit to the 
                        # expected species (or potentially no hit). Otherwise, more consideration
                        # should be given to the genome with the conflicting LTP hit.
                        if len(ltp_species) == 1 and ncbi_sp in ltp_species:
                            other_sp = set()
                            for test_gid in type_gids:
                                ltp_species = self.ltp_species(test_gid, ltp_metadata)
                                if ltp_species and ncbi_sp not in ltp_species:
                                    other_sp.update(ltp_species)
                                
                            if other_sp:
                                self.logger.warning(f'Genome {gid} marked as untrustworthy, but this conflicts with high confidence LTP 16S rRNA assignment.')
                                
                    num_ncbi_untrustworthy = sum([1 for gid in type_gids if 'untrustworthy as type' in cur_genomes[gid].excluded_from_refseq_note])
                    if num_ncbi_untrustworthy != len(type_gids):
                        for gid in type_gids:
                            if (gid not in untrustworthy_gids 
                                and 'untrustworthy as type' in cur_genomes[gid].excluded_from_refseq_note):
                                self.logger.warning("Retaining genome {} from {} despite being marked as `untrustworthy as type` at NCBI [{:,} of {:,} considered untrustworthy].".format(
                                                        gid, 
                                                        ncbi_sp,
                                                        num_ncbi_untrustworthy,
                                                        len(type_gids)))
                else:
                    note = 'Species is unresolved; manual curation is required!'
                    unresolved_sp_count += 1
                    
                if unresolved_species:
                    for gid in type_gids:
                        ltp_species = self.ltp_species(gid, ltp_metadata)
                            
                        fout_unresolved.write('{}\t{}\t{}\t{}\t{}\t{:.2f}\t{:.3f}\t{:.3f}\t{:.4f}\t{}\t{}\t{}\n'.format(
                                    gid,
                                    ncbi_sp,
                                    cur_genomes[gid].gtdb_taxa.genus,
                                    cur_genomes[gid].gtdb_taxa.species,
                                    ' / '.join(ltp_species),
                                    np_mean(list(gid_anis[gid].values())),
                                    np_std(list(gid_anis[gid].values())),
                                    np_mean(list(gid_afs[gid].values())),
                                    np_std(list(gid_afs[gid].values())),
                                    cur_genomes[gid].excluded_from_refseq_note,
                                    cur_genomes[gid].ncbi_taxa,
                                    cur_genomes[gid].gtdb_taxa))

            for gid in type_gids:
                ltp_species = self.ltp_species(gid, ltp_metadata)
                    
                fout_genomes.write('{}\t{}\t{}\t{}\t{}\t{}\t{}\t{:.2f}\t{:.3f}\t{:.3f}\t{:.4f}\t{}\t{}\t{}\n'.format(
                            gid,
                            gid in untrustworthy_gids,
                            ncbi_sp,
                            cur_genomes[gid].gtdb_taxa.genus,
                            cur_genomes[gid].gtdb_taxa.species,
                            ' / '.join(ltp_species),
                            gtdb_resolved_sp_conflict,
                            np_mean(list(gid_anis[gid].values())),
                            np_std(list(gid_anis[gid].values())),
                            np_mean(list(gid_afs[gid].values())),
                            np_std(list(gid_afs[gid].values())),
                            cur_genomes[gid].excluded_from_refseq_note,
                            cur_genomes[gid].ncbi_taxa,
                            cur_genomes[gid].gtdb_taxa))

            fout.write('{}\t{}\t{}\t{:.2f}\t{:.3f}\t{:.3f}\t{:.4f}\t{}\t{}\n'.format(
                        ncbi_sp,
                        len(type_gids),
                        all_similar,
                        np_mean(anis),
                        np_std(anis),
                        np_mean(afs),
                        np_std(afs),
                        note,
                        ', '.join(type_gids)))

        sys.stdout.write('\n')
        fout.close()
        fout_unresolved.close()
        fout_high_divergence.close()
        fout_genomes.close()
        fout_untrustworthy.close()
        
        self.logger.info(f'Identified {num_divergent:,} species with 1 or more divergent type strain genomes.')
        self.logger.info(f' ... resolved {ncbi_ltp_resolved:,} species by removing NCBI `untrustworthy as type` genomes with a conflicting LTP 16S rRNA classifications.')
        self.logger.info(f' ... resolved {ltp_resolved:,} species by considering conflicting LTP 16S rRNA classifications.')
        self.logger.info(f' ... resolved {intra_ani_resolved:,} species by considering intra-specific ANI values.')
        self.logger.info(f' ... resolved {gtdb_family_resolved:,} species by considering conflicting GTDB family classifications.')
        self.logger.info(f' ... resolved {gtdb_genus_resolved:,} species by considering conflicting GTDB genus classifications.')
        self.logger.info(f' ... resolved {gtdb_sp_resolved:,} species by considering conflicting GTDB species classifications.')
        self.logger.info(f' ... resolved {ncbi_type_resolved:,} species by considering type material designations at NCBI.')

        if unresolved_sp_count > 0:
            self.logger.warning(f'There are {unresolved_sp_count:,} unresolved species with multiple type strain genomes.')
            self.logger.warning('These should be handled before proceeding with the next step of GTDB species updating.')
            self.logger.warning("This can be done by manual curation and adding genomes to 'untrustworthy_type_ledger'.")
        
        self.logger.info(f'Identified {prev_gtdb_sp_conflicts:,} cases where resolved type strain conflicts with prior GTDB assignment.')
    def run(self, gtdb_metadata_file, genome_path_file, species1, species2):
        """Report ANI and AF statistics relevant to merging two species.

        The representative genomes of `species1` and `species2` are looked up
        in the GTDB species clusters, ANI/AF is calculated in both directions
        between each of them and every other species representative assigned
        to the same GTDB genus, and the values between the two representatives
        are logged. `top_hits` and `merge_ani_radius` are then called for each
        species.

        The process is terminated with exit status -1 if either species has no
        representative, or if the two representatives are assigned to
        different GTDB genera.
        """

        log = self.logger

        # load genomes along with their species clusters
        log.info('Reading GTDB species clusters.')
        genomes = Genomes()
        genomes.load_from_metadata_file(gtdb_metadata_file)
        genomes.load_genomic_file_paths(genome_path_file)
        num_clusters = len(genomes.sp_clusters)
        num_genomes = genomes.sp_clusters.total_num_genomes()
        log.info(f' ... identified {num_clusters:,} species clusters spanning {num_genomes:,} genomes.')

        # locate representative of each species of interest; species1 is
        # tested first so it takes precedence if both names are identical
        gid1 = gid2 = None
        for rid, sp_name in genomes.sp_clusters.species():
            if sp_name == species1:
                gid1 = rid
                continue
            if sp_name == species2:
                gid2 = rid

        queries = ((gid1, species1), (gid2, species2))
        for rep_gid, sp_name in queries:
            if rep_gid is None:
                log.error(f'Unable to find representative genome for {sp_name}.')
                sys.exit(-1)

        for rep_gid, sp_name in queries:
            num_in_cluster = len(genomes.sp_clusters[rep_gid])
            log.info(f' ... identified {num_in_cluster:,} genomes in {sp_name}.')

        # both representatives must be assigned to the same GTDB genus
        genus1 = genomes[gid1].gtdb_genus
        genus2 = genomes[gid2].gtdb_genus
        if genus1 != genus2:
            log.error(f'Genomes must be from same genus: {genus1} {genus2}')
            sys.exit(-1)

        log.info(f'Identifying {genus1} species representatives.')
        reps_in_genera = {rid for rid in genomes.sp_clusters
                          if genomes[rid].gtdb_genus == genus1}
        log.info(f' ... identified {len(reps_in_genera):,} representatives.')

        # calculate ANI, in both directions, between each species of interest
        # and all other representatives in the genus
        ani_af_results = []
        for rep_gid, sp_name in queries:
            log.info(f'Calculating ANI to {sp_name}.')
            gid_pairs = []
            for rid in reps_in_genera:
                if rid == rep_gid:
                    continue
                gid_pairs.extend([(rep_gid, rid), (rid, rep_gid)])
            ani_af_results.append(
                self.fastani.pairs(gid_pairs, genomes.genomic_files))
        ani_af1, ani_af2 = ani_af_results

        # report directional and symmetric results between the two species
        ani12, af12 = ani_af1[gid1][gid2]
        ani21, af21 = ani_af2[gid2][gid1]
        ani, af = symmetric_ani(ani_af1, gid1, gid2)

        log.info(f'{species1} ({gid1}) -> {species2} ({gid2}): ANI={ani12:.1f}%, AF={af12:.2f}')
        log.info(f'{species2} ({gid2}) -> {species1} ({gid1}): ANI={ani21:.1f}%, AF={af21:.2f}')
        log.info(f'Max. ANI={ani:.1f}%, Max. AF={af:.2f}')

        # report top hits for each species
        self.top_hits(species1, gid1, ani_af1, genomes)
        self.top_hits(species2, gid2, ani_af2, genomes)

        # calculate ANI from each representative to all genomes in
        # the union of the two species clusters
        cluster1 = genomes.sp_clusters[gid1]
        merged_sp_cluster = cluster1.union(genomes.sp_clusters[gid2])
        for rep_gid, sp_name in queries:
            self.merge_ani_radius(sp_name, rep_gid, merged_sp_cluster,
                                  genomes.genomic_files)
    def _selected_rep_genomes(self,
                                cur_genomes,
                                nonrep_radius,
                                unclustered_qc_gids,
                                mash_ani):
        """Select de novo representatives for species clusters in a greedy fashion using species-specific ANI thresholds.

        Genomes are processed in descending order of their `score_type_strain()`
        value, with ties broken by genome ID (also descending). A genome is
        clustered, and hence not selected as a representative, when its closest
        previously selected representative (highest ANI, ties broken by AF,
        considering only representatives with AF >= `self.af_sp`) has an ANI
        greater than the ANI radius of that representative. Otherwise the
        genome becomes a new representative.

        Selected representatives are cached in `cluster_reps.tsv` in the output
        directory. If this file already exists, the representatives are read
        from it and no clustering is performed.

        :param cur_genomes: Genomes indexed by genome ID; also provides `genomic_files`.
        :param nonrep_radius: Map from genome ID to `GenomeRadius`. Updated in place:
                                the radius of each processed genome is raised to its ANI
                                to any representative exceeding the current radius with
                                sufficient AF.
        :param unclustered_qc_gids: IDs of genomes to consider as representatives.
        :param mash_ani: Mash ANI estimates indexed as mash_ani[query_gid][rep_gid]; only
                            pairs with a value >= `self.min_mash_ani` are passed to FastANI.
        :return: Set of genome IDs selected as representatives.
        """

        # sort genomes by quality score
        self.logger.info('Selecting de novo representatives in a greedy manner based on quality.')
        q = {gid:cur_genomes[gid].score_type_strain() for gid in unclustered_qc_gids}
        q_sorted = sorted(q.items(), key=lambda kv: (kv[1], kv[0]), reverse=True)

        # greedily determine representatives for new species clusters
        cluster_rep_file = os.path.join(self.output_dir, 'cluster_reps.tsv')
        clusters = set()
        if not os.path.exists(cluster_rep_file):
            clustered_genomes = 0
            max_ani_pairs = 0
            for idx, (cur_gid, _score) in enumerate(q_sorted):

                # determine reference genomes to calculate ANI between
                ani_pairs = []
                if cur_gid in mash_ani:
                    for rep_gid in clusters:
                        if mash_ani[cur_gid].get(rep_gid, 0) >= self.min_mash_ani:
                            ani_pairs.append((cur_gid, rep_gid))
                            ani_pairs.append((rep_gid, cur_gid))

                # determine if genome clusters with representative
                clustered = False
                if ani_pairs:
                    max_ani_pairs = max(max_ani_pairs, len(ani_pairs))

                    ani_af = self.fastani.pairs(ani_pairs, cur_genomes.genomic_files, report_progress=False)

                    closest_rep_gid = None
                    closest_rep_ani = 0
                    closest_rep_af = 0
                    for rep_gid in clusters:
                        ani, af = symmetric_ani(ani_af, cur_gid, rep_gid)

                        if af >= self.af_sp:
                            if ani > closest_rep_ani or (ani == closest_rep_ani and af > closest_rep_af):
                                closest_rep_gid = rep_gid
                                closest_rep_ani = ani
                                closest_rep_af = af

                        # tighten radius of current genome so that, should it become a
                        # representative, it does not extend past this representative
                        if ani > nonrep_radius[cur_gid].ani and af >= self.af_sp:
                            nonrep_radius[cur_gid] = GenomeRadius(ani = ani,
                                                                         af = af,
                                                                         neighbour_gid = rep_gid)

                    # genome clusters only if it falls within the ANI radius
                    # of its closest representative
                    if closest_rep_gid and closest_rep_ani > nonrep_radius[closest_rep_gid].ani:
                        clustered = True

                if not clustered:
                    # genome is a new species cluster representative
                    clusters.add(cur_gid)
                else:
                    clustered_genomes += 1

                # report progress every 10 genomes; the reported number of ANI pairs
                # is the maximum observed since the previous progress report
                if (idx+1) % 10 == 0 or idx+1 == len(q_sorted):
                    statusStr = '-> Clustered {:,} of {:,} ({:.2f}%) genomes [ANI pairs: {:,}; clustered genomes: {:,}; clusters: {:,}].'.format(
                                    idx+1,
                                    len(q_sorted),
                                    float(idx+1)*100/len(q_sorted),
                                    max_ani_pairs,
                                    clustered_genomes,
                                    len(clusters)).ljust(96)
                    sys.stdout.write('{}\r'.format(statusStr))
                    sys.stdout.flush()
                    max_ani_pairs = 0
            sys.stdout.write('\n')

            # write out selected cluster representative; the context manager
            # ensures the file is flushed and closed even if a write fails
            with open(cluster_rep_file, 'w') as fout:
                for gid in clusters:
                    fout.write('{}\n'.format(gid))
        else:
            # read cluster reps from file
            self.logger.warning('Using previously determined cluster representatives.')
            with open(cluster_rep_file) as fin:
                for line in fin:
                    gid = line.strip()
                    if gid:
                        # skip blank lines so an empty string is
                        # never treated as a genome ID
                        clusters.add(gid)

        self.logger.info('Selected {:,} representative genomes for de novo species clusters.'.format(len(clusters)))

        return clusters
    def _cluster_genomes(self,
                            cur_genomes,
                            de_novo_rep_gids,
                            named_rep_gids,
                            final_cluster_radius):
        """Cluster non-representative genomes to named and de novo species representatives.

        Every genome in `cur_genomes` that is not a representative is assigned
        to the representative with the highest ANI (ties broken by AF) among
        those where the ANI is >= the ANI radius of the representative and the
        AF is >= `self.af_sp`. Genomes without such a representative are
        reported with a warning and left unassigned.

        Mash is used as a prefilter: FastANI is only run for pairs with a Mash
        ANI estimate >= `self.min_mash_ani`. The resulting clusters and ANI/AF
        values are pickled to `clusters.pkl` and
        `ani_af_rep_vs_nonrep.de_novo.pkl` in the output directory.

        :param cur_genomes: Current genomes; provides `genomes`, `genomic_files`, and `user_uba_id_map`.
        :param de_novo_rep_gids: Set of genome IDs for de novo representatives.
        :param named_rep_gids: Set of genome IDs for representatives of named species.
        :param final_cluster_radius: Map from representative ID to an object with an `ani` radius.
        :return: Tuple (clusters, ani_af), where clusters maps each representative ID to a
                    list of `ClusteredGenome`, and ani_af is the FastANI result for the
                    representative vs. non-representative pairs.
        """

        all_reps = de_novo_rep_gids.union(named_rep_gids)
        nonrep_gids = set(cur_genomes.genomes.keys()) - all_reps
        self.logger.info('Clustering {:,} genomes to {:,} named and de novo representatives.'.format(
                            len(nonrep_gids), len(all_reps)))

        clusters_file = os.path.join(self.output_dir, 'clusters.pkl')
        ani_af_file = os.path.join(self.output_dir, 'ani_af_rep_vs_nonrep.de_novo.pkl')

        # manual development switch: change to False to reuse the pickled
        # results of a previous run instead of recalculating them
        if True: #***
            # calculate MASH distance between non-representatives and representatives genomes
            mash = Mash(self.cpus)

            mash_rep_sketch_file = os.path.join(self.output_dir, 'gtdb_rep_genomes.msh')
            rep_genome_list_file = os.path.join(self.output_dir, 'gtdb_rep_genomes.lst')
            mash.sketch(all_reps, cur_genomes.genomic_files, rep_genome_list_file, mash_rep_sketch_file)

            mash_none_rep_sketch_file = os.path.join(self.output_dir, 'gtdb_nonrep_genomes.msh')
            non_rep_file = os.path.join(self.output_dir, 'gtdb_nonrep_genomes.lst')
            mash.sketch(nonrep_gids, cur_genomes.genomic_files, non_rep_file, mash_none_rep_sketch_file)

            # get Mash distances
            mash_dist_file = os.path.join(self.output_dir, 'gtdb_rep_vs_nonrep_genomes.dst')
            mash.dist(float(100 - self.min_mash_ani)/100,
                        mash_rep_sketch_file,
                        mash_none_rep_sketch_file,
                        mash_dist_file)

            # read Mash distances
            mash_ani = mash.read_ani(mash_dist_file)

            # every representative starts with an empty cluster
            clusters = {gid: [] for gid in all_reps}

            # determine pairs passing the Mash prefilter; ANI is
            # calculated in both directions for each pair
            mash_ani_pairs = []
            for qid in mash_ani:
                n_qid = cur_genomes.user_uba_id_map.get(qid, qid)
                assert n_qid in nonrep_gids

                for rid in mash_ani[qid]:
                    n_rid = cur_genomes.user_uba_id_map.get(rid, rid)
                    assert n_rid in all_reps

                    if (mash_ani[qid][rid] >= self.min_mash_ani
                        and n_qid != n_rid):
                        mash_ani_pairs.append((n_qid, n_rid))
                        mash_ani_pairs.append((n_rid, n_qid))

            # calculate ANI between non-representatives and representatives genomes
            self.logger.info('Calculating ANI between {:,} species clusters and {:,} unclustered genomes ({:,} pairs):'.format(
                                len(clusters),
                                len(nonrep_gids),
                                len(mash_ani_pairs)))
            ani_af = self.fastani.pairs(mash_ani_pairs, cur_genomes.genomic_files)

            # assign genomes to closest representatives
            # that is within the representatives ANI radius
            self.logger.info('Assigning genomes to closest representative.')
            for idx, cur_gid in enumerate(nonrep_gids):
                closest_rep_gid = None
                closest_rep_ani = 0
                closest_rep_af = 0
                for rep_gid in clusters:
                    ani, af = symmetric_ani(ani_af, cur_gid, rep_gid)

                    if ani >= final_cluster_radius[rep_gid].ani and af >= self.af_sp:
                        if ani > closest_rep_ani or (ani == closest_rep_ani and af > closest_rep_af):
                            closest_rep_gid = rep_gid
                            closest_rep_ani = ani
                            closest_rep_af = af

                if closest_rep_gid:
                    clusters[closest_rep_gid].append(ClusteredGenome(gid=cur_gid,
                                                                            ani=closest_rep_ani,
                                                                            af=closest_rep_af))
                else:
                    # NOTE: a closest representative is only recorded when both the
                    # ANI radius and AF criteria are met, so reaching this branch
                    # means no representative satisfied both criteria
                    self.logger.warning('Failed to assign genome {} to representative.'.format(cur_gid))
                    self.logger.warning(' ...no representative with an AF >{:.2f} identified.'.format(self.af_sp))

                statusStr = '-> Assigned {:,} of {:,} ({:.2f}%) genomes.'.format(idx+1,
                                                                                    len(nonrep_gids),
                                                                                    float(idx+1)*100/len(nonrep_gids)).ljust(86)
                sys.stdout.write('{}\r'.format(statusStr))
                sys.stdout.flush()
            sys.stdout.write('\n')

            # context managers ensure results are flushed to disk and handles released
            with open(clusters_file, 'wb') as fout:
                pickle.dump(clusters, fout)
            with open(ani_af_file, 'wb') as fout:
                pickle.dump(ani_af, fout)
        else:
            # NOTE: pickle files must only be loaded from a trusted output
            # directory since unpickling can execute arbitrary code
            self.logger.warning('Using previously calculated results in: {}'.format('clusters.pkl'))
            with open(clusters_file, 'rb') as fin:
                clusters = pickle.load(fin)

            self.logger.warning('Using previously calculated results in: {}'.format('ani_af_rep_vs_nonrep.de_novo.pkl'))
            with open(ani_af_file, 'rb') as fin:
                ani_af = pickle.load(fin)

        return clusters, ani_af
    def write_synonym_table(self, synonyms, cur_genomes, ani_af,
                            sp_priority_ledger):
        """Create table indicating species names that should be considered synonyms.

        Writes `synonyms.tsv` to the output directory with one row per
        (representative, synonym) genome pair. Each row has nine columns
        describing the representative, the same nine columns for the synonym,
        the ANI and AF between the two genomes, and, only when a problem is
        detected, a trailing warning column. Totals of the detected problems
        are logged as warnings.

        :param synonyms: Map from representative genome ID to the IDs of genomes
                            whose species should be considered synonyms.
        :param cur_genomes: Genomes indexed by genome ID.
        :param ani_af: ANI and AF values passed to `symmetric_ani`.
        :param sp_priority_ledger: Ledger used to construct the `SpeciesPriorityManager`.
        """

        sp_priority_mngr = SpeciesPriorityManager(sp_priority_ledger)

        def genome_columns(genome_id):
            """Format the nine tab-separated columns describing a single genome."""
            genome = cur_genomes[genome_id]
            type_sources = ','.join(sorted(genome.gtdb_type_sources()))
            # upper() also capitalizes StrainInfo, so restore its preferred case
            type_sources = type_sources.upper().replace('STRAININFO', 'StrainInfo')
            return '{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}'.format(
                genome.ncbi_taxa.species,
                genome.gtdb_taxa.species,
                genome_id,
                ','.join(sorted(genome.strain_ids())),
                type_sources,
                genome.year_of_priority(),
                genome.is_gtdb_type_species(),
                genome.is_gtdb_type_strain(),
                genome.ncbi_type_material)

        incorrect_priority = 0
        failed_type_strain_priority = 0

        out_file = os.path.join(self.output_dir, 'synonyms.tsv')

        # context manager guarantees the table is flushed and closed
        with open(out_file, 'w') as fout:
            fout.write(
                'NCBI species\tGTDB species\tRepresentative\tStrain IDs\tRepresentative type sources\tPriority year\tGTDB type species\tGTDB type strain\tNCBI assembly type'
            )
            fout.write(
                '\tNCBI synonym\tGTDB synonym\tSynonym genome\tSynonym strain IDs\tSynonym type sources\tPriority year\tGTDB type species\tGTDB type strain\tSynonym NCBI assembly type'
            )
            fout.write('\tANI\tAF\tWarnings\n')

            for rid, synonym_ids in synonyms.items():
                for gid in synonym_ids:
                    ani, af = symmetric_ani(ani_af, rid, gid)

                    fout.write(genome_columns(rid))
                    fout.write('\t' + genome_columns(gid))
                    fout.write('\t{:.3f}\t{:.4f}'.format(ani, af))

                    rep_is_type_strain = cur_genomes[rid].is_gtdb_type_strain()
                    syn_is_type_strain = cur_genomes[gid].is_gtdb_type_strain()
                    if rep_is_type_strain and syn_is_type_strain:
                        # both are type strains, so the representative
                        # should be the genome with naming priority
                        priority_gid, note = sp_priority_mngr.priority(
                            cur_genomes, rid, gid)
                        if priority_gid != rid:
                            incorrect_priority += 1
                            fout.write('\tIncorrect priority: {}'.format(note))
                    elif syn_is_type_strain:
                        # synonym is a type strain, but the representative is not
                        failed_type_strain_priority += 1
                        fout.write('\tFailed to prioritize type strain of species')

                    fout.write('\n')

        if incorrect_priority:
            self.logger.warning(
                f'Identified {incorrect_priority:,} synonyms with incorrect priority.'
            )

        if failed_type_strain_priority:
            self.logger.warning(
                f'Identified {failed_type_strain_priority:,} synonyms that failed to priotize the type strain of the species.'
            )
    def _intragenus_pairwise_ani(self, clusters, species, genome_files, gtdb_taxonomy):
        """Determine pairwise intra-genus ANI between representative genomes.

        Three files are written to the output directory:
          type_genomes_ani_af.pkl      - pickled result of self.fastani.pairs()
          intra_genus_pairwise_ani.tsv - one row per ordered pair of
                                         representatives from the same genus
          closest_intragenus_rep.tsv   - closest same-genus neighbour of each
                                         representative with an ANI result

        clusters      -- container of representative genome IDs (iterated
                         several times, so it must be re-iterable)
        species       -- dict: representative genome ID -> species name of
                         the form 's__<genus> <epithet>'
        genome_files  -- passed unchanged to self.fastani.pairs()
        gtdb_taxonomy -- dict: genome ID -> taxa list, with the genus
                         ('g__<genus>') at index 5
        """

        self.logger.info('Calculating pairwise intra-genus ANI values between GTDB representatives.')

        # get genus for each representative
        genus = {}
        for rid, sp in species.items():
            genus[rid] = sp.split()[0].replace('s__', '')
            assert genus[rid] == gtdb_taxonomy[rid][5].replace('g__', '')

        # get ordered pairs of representatives from the same genus; the
        # nested iteration already visits both (A, B) and (B, A), so each
        # ordered pair must only be queued once
        self.logger.info('Determining intra-genus genome pairs.')
        ani_pairs = [(qid, rid)
                     for qid in clusters
                     for rid in clusters
                     if qid != rid and genus[qid] == genus[rid]]

        self.logger.info('Identified {:,} intra-genus genome pairs.'.format(len(ani_pairs)))

        # calculate ANI between pairs and keep a pickled copy of the results
        self.logger.info('Calculating ANI between {:,} genome pairs:'.format(len(ani_pairs)))
        ani_af = self.fastani.pairs(ani_pairs, genome_files)
        with open(os.path.join(self.output_dir, 'type_genomes_ani_af.pkl'), 'wb') as f:
            pickle.dump(ani_af, f)

        # find closest intra-genus pair for each rep
        closest_intragenus_rep = {}
        with open(os.path.join(self.output_dir, 'intra_genus_pairwise_ani.tsv'), 'w') as fout:
            fout.write('Genus\tSpecies 1\tGenome ID 1\tSpecies 2\tGenome ID2\tANI\tAF\n')
            for qid in clusters:
                genusA = genus[qid]

                closest_ani = 0
                closest_af = 0
                closest_gid = None
                for rid in clusters:
                    if qid == rid or genus[rid] != genusA:
                        continue

                    # pairs without a FastANI result are still reported,
                    # but with placeholder values
                    ani, af = ('n/a', 'n/a')
                    if qid in ani_af and rid in ani_af[qid]:
                        ani, af = symmetric_ani(ani_af, qid, rid)

                        if ani > closest_ani:
                            closest_ani = ani
                            closest_af = af
                            closest_gid = rid

                    fout.write('{}\t{}\t{}\t{}\t{}\t{}\t{}\n'.format(
                                genusA,
                                species[qid],
                                qid,
                                species[rid],
                                rid,
                                ani,
                                af))

                if closest_gid:
                    closest_intragenus_rep[qid] = (closest_gid, closest_ani, closest_af)

        # write out closest intra-genus species to each representative
        with open(os.path.join(self.output_dir, 'closest_intragenus_rep.tsv'), 'w') as fout:
            fout.write('Genome ID\tSpecies\tIntra-genus neighbour\tIntra-genus species\tANI\tAF\n')
            for qid, (rid, ani, af) in closest_intragenus_rep.items():
                fout.write('%s\t%s\t%s\t%s\t%.2f\t%.3f\n' % (
                                qid,
                                species[qid],
                                rid,
                                species[rid],
                                ani,
                                af))
    def identify_misclassified_genomes_ani(self, cur_genomes, cur_clusters):
        """Identify genomes with erroneous NCBI species assignments, based on ANI to type strain genomes.

        A genome is reported when it carries the NCBI species name of a
        type-strain-anchored species cluster, resides in a different
        cluster, and its ANI to that type strain representative is below
        self.ani_ncbi_erroneous. Reported genomes are also written to
        'ncbi_misclassified_sp.ani_<threshold>.tsv' in the output directory.

        cur_genomes  -- genome collection: iterable over genome IDs and
                        indexable by genome ID; must expose genomic_files
        cur_clusters -- dict: representative genome ID -> clustered genome IDs

        Returns the set of genome IDs deemed misclassified.
        """

        forbidden_names = {'cyanobacterium'}

        # get mapping from genomes to their representatives
        gid_to_rid = {}
        for rid, cids in cur_clusters.items():
            for cid in cids:
                gid_to_rid[cid] = rid

        # get genomes with NCBI species assignment
        ncbi_sp_gids = defaultdict(list)
        for gid in cur_genomes:
            ncbi_species = cur_genomes[gid].ncbi_taxa.species
            ncbi_specific = specific_epithet(ncbi_species)

            if ncbi_species != 's__' and ncbi_specific not in forbidden_names:
                ncbi_sp_gids[ncbi_species].append(gid)

        # get NCBI species anchored by a type strain genome
        ncbi_type_anchored_species = {}
        for rid in cur_clusters:
            if cur_genomes[rid].is_effective_type_strain():
                ncbi_type_species = cur_genomes[rid].ncbi_taxa.species
                if ncbi_type_species != 's__':
                    ncbi_type_anchored_species[ncbi_type_species] = rid
        self.logger.info(
            ' - identified {:,} NCBI species anchored by a type strain genome.'
            .format(len(ncbi_type_anchored_species)))

        # identify genomes with erroneous NCBI species assignments; the
        # context manager guarantees the table is flushed and closed even
        # if an ANI calculation or lookup raises part way through
        out_file = os.path.join(
            self.output_dir,
            'ncbi_misclassified_sp.ani_{}.tsv'.format(self.ani_ncbi_erroneous))

        misclassified_gids = set()
        with open(out_file, 'w') as fout:
            fout.write(
                'Genome ID\tNCBI species\tGenome cluster\tType species cluster\tANI to type strain\tAF to type strain\n'
            )

            for idx, (ncbi_species,
                      species_gids) in enumerate(ncbi_sp_gids.items()):
                if ncbi_species not in ncbi_type_anchored_species:
                    continue

                # need to check genomes that have the same NCBI species name
                # as a type strain genome, but reside in a different GTDB
                # species cluster
                type_rid = ncbi_type_anchored_species[ncbi_species]
                gids_to_check = [gid for gid in species_gids
                                 if gid_to_rid[gid] != type_rid]
                if not gids_to_check:
                    continue

                gid_pairs = []
                for gid in gids_to_check:
                    gid_pairs.append((type_rid, gid))
                    gid_pairs.append((gid, type_rid))

                statusStr = '-> Establishing erroneous assignments for {} [ANI pairs: {:,}; {:,} of {:,} species].'.format(
                    ncbi_species, len(gid_pairs), idx + 1,
                    len(ncbi_sp_gids)).ljust(96)
                sys.stdout.write('{}\r'.format(statusStr))
                sys.stdout.flush()

                ani_af = self.fastani.pairs(gid_pairs,
                                            cur_genomes.genomic_files,
                                            report_progress=False,
                                            check_cache=True)

                for gid in gids_to_check:
                    ani, af = symmetric_ani(ani_af, type_rid, gid)
                    if ani < self.ani_ncbi_erroneous:
                        misclassified_gids.add(gid)
                        fout.write('{}\t{}\t{}\t{}\t{:.2f}\t{:.3f}\n'.format(
                            gid, ncbi_species, gid_to_rid[gid], type_rid, ani,
                            af))

            # terminate the carriage-return progress line
            sys.stdout.write('\n')

        misclassified_species = {
            cur_genomes[gid].ncbi_taxa.species for gid in misclassified_gids}
        self.logger.info(
            ' - identified {:,} genomes from {:,} species as having misclassified NCBI species assignments.'
            .format(len(misclassified_gids), len(misclassified_species)))

        return misclassified_gids
    def _pairwise_stats(self, clusters, genome_files, rep_stats):
        """Calculate statistics for all pairwise comparisons in a species cluster.

        clusters     -- dict: representative genome ID -> set of clustered
                        genome IDs (set methods are used on the values)
        genome_files -- unused; ANI values are read from
                        self.fastani.ani_cache rather than recomputed
                        (NOTE(review): presumably the cache was filled by
                        earlier FastANI calls - confirm against callers)
        rep_stats    -- unused; retained for interface compatibility

        Returns a dict: representative genome ID -> self.PairwiseStats.
        Clusters without any members get -1 in every field. Clusters with
        more than self.max_genomes_for_stats members are randomly
        subsampled before statistics are calculated.
        """

        self.logger.info('Restricting pairwise comparisons to %d randomly selected genomes.' % self.max_genomes_for_stats)
        self.logger.info('Calculating statistics for all pairwise comparisons in a species cluster:')

        # all ANI values are looked up in the FastANI cache
        ani_af = self.fastani.ani_cache

        stats = {}
        for idx, (rid, cids) in enumerate(clusters.items()):
            statusStr = '-> Processing %d of %d (%.2f%%) clusters (size = %d).'.ljust(86) % (
                                idx+1, 
                                len(clusters), 
                                float((idx+1)*100)/len(clusters),
                                len(cids))
            sys.stdout.write('%s\r' % statusStr)
            sys.stdout.flush()

            if len(cids) == 0:
                stats[rid] = self.PairwiseStats(min_ani = -1,
                                                mean_ani = -1,
                                                std_ani = -1,
                                                median_ani = -1,
                                                ani_to_medoid = -1,
                                                mean_ani_to_medoid = -1,
                                                mean_ani_to_rep = -1,
                                                ani_below_95 = -1)
                continue

            if len(cids) > self.max_genomes_for_stats:
                # random.sample() requires a sequence; passing a set raises
                # TypeError on Python 3.11+
                cids = set(random.sample(list(cids), self.max_genomes_for_stats))

            gids = list(cids.union([rid]))

            # calculate medoid point
            if len(gids) > 2:
                dist_mat = np_zeros((len(gids), len(gids)))
                for i, gid1 in enumerate(gids):
                    for j, gid2 in enumerate(gids):
                        if i < j:
                            ani, af = symmetric_ani(ani_af, gid1, gid2)
                            dist_mat[i, j] = 100 - ani
                            dist_mat[j, i] = 100 - ani

                medoid_idx = np_argmin(dist_mat.sum(axis=0))
                medoid_gid = gids[medoid_idx]
            else:
                # with only 2 genomes in a cluster, the representative is the
                # natural medoid at least for reporting statistics for the
                # individual species cluster
                medoid_gid = rid

            mean_ani_to_medoid = np_mean([symmetric_ani(ani_af, gid, medoid_gid)[0] 
                                            for gid in gids if gid != medoid_gid])

            mean_ani_to_rep = np_mean([symmetric_ani(ani_af, gid, rid)[0] 
                                            for gid in gids if gid != rid])

            # sanity check: the medoid minimises the summed distance to all
            # other genomes, so it cannot be further away than the rep
            if mean_ani_to_medoid < mean_ani_to_rep:
                self.logger.error('mean_ani_to_medoid < mean_ani_to_rep')
                sys.exit(-1)

            # calculate statistics
            anis = [symmetric_ani(ani_af, gid1, gid2)[0]
                    for gid1, gid2 in combinations(gids, 2)]

            stats[rid] = self.PairwiseStats(min_ani = min(anis),
                                            mean_ani = np_mean(anis),
                                            std_ani = np_std(anis),
                                            median_ani = np_median(anis),
                                            ani_to_medoid = symmetric_ani(ani_af, rid, medoid_gid)[0],
                                            mean_ani_to_medoid = mean_ani_to_medoid,
                                            mean_ani_to_rep = mean_ani_to_rep,
                                            ani_below_95 = sum([1 for ani in anis if ani < 95]))

        sys.stdout.write('\n')

        return stats
# Example #21
# 0
    def _selected_rep_genomes(self,
                                genome_files,
                                nontype_radius, 
                                unclustered_qc_gids, 
                                mash_ani,
                                quality_metadata,
                                rnd_type_genome):
        """Select representative genomes for species clusters in a greedy fashion using species-specific ANI thresholds.

        genome_files        -- passed unchanged to self.fastani.pairs()
        nontype_radius      -- dict: genome ID -> GenomeRadius; updated in
                               place when a closer representative is found
        unclustered_qc_gids -- collection of genome IDs to cluster
        mash_ani            -- dict: genome ID -> {genome ID -> Mash ANI};
                               used to pre-filter FastANI comparisons
        quality_metadata    -- passed to quality_score() to order genomes
        rnd_type_genome     -- if true, process genomes in random order
                               instead of by decreasing quality score

        Returns the set of genome IDs selected as representatives. If
        'cluster_reps.tsv' already exists in the output directory the
        representatives are read from it instead of being recalculated.
        """

        # sort genomes by quality score
        if rnd_type_genome:
            self.logger.info('Selecting random de novo type genomes.')
            # random.sample() requires a sequence; passing a set raises
            # TypeError on Python 3.11+
            shuffled_gids = random.sample(list(unclustered_qc_gids), len(unclustered_qc_gids))
            sorted_gids = [(gid, 0) for gid in shuffled_gids]
        else:
            self.logger.info('Selecting de novo type genomes in a greedy manner based on quality.')
            qscore = quality_score(unclustered_qc_gids, quality_metadata)
            sorted_gids = sorted(qscore.items(), key=operator.itemgetter(1), reverse=True)

        # greedily determine representatives for new species clusters
        cluster_rep_file = os.path.join(self.output_dir, 'cluster_reps.tsv')
        clusters = set()
        if not os.path.exists(cluster_rep_file):
            self.logger.info('Clustering genomes to identify representatives.')
            clustered_genomes = 0
            max_ani_pairs = 0
            for idx, (cur_gid, _score) in enumerate(sorted_gids):

                # determine reference genomes to calculate ANI between
                ani_pairs = []
                if cur_gid in mash_ani:
                    for rep_gid in clusters:
                        if mash_ani[cur_gid].get(rep_gid, 0) >= self.min_mash_ani:
                            ani_pairs.append((cur_gid, rep_gid))
                            ani_pairs.append((rep_gid, cur_gid))

                # determine if genome clusters with representative
                clustered = False
                if ani_pairs:
                    max_ani_pairs = max(max_ani_pairs, len(ani_pairs))

                    ani_af = self.fastani.pairs(ani_pairs, genome_files, report_progress=False)

                    closest_rep_gid = None
                    closest_rep_ani = 0
                    closest_rep_af = 0
                    for rep_gid in clusters:
                        ani, af = symmetric_ani(ani_af, cur_gid, rep_gid)

                        if af >= self.af_sp:
                            # prefer highest ANI; break ties with highest AF
                            if ani > closest_rep_ani or (ani == closest_rep_ani and af > closest_rep_af):
                                closest_rep_gid = rep_gid
                                closest_rep_ani = ani
                                closest_rep_af = af

                        # tighten the radius of the current genome if this
                        # representative is closer than its current radius
                        if ani > nontype_radius[cur_gid].ani and af >= self.af_sp:
                            nontype_radius[cur_gid] = GenomeRadius(ani = ani, 
                                                                         af = af,
                                                                         neighbour_gid = rep_gid)

                    if closest_rep_gid and closest_rep_ani > nontype_radius[closest_rep_gid].ani:
                        clustered = True

                if not clustered:
                    # genome is a new species cluster representative
                    clusters.add(cur_gid)
                else:
                    clustered_genomes += 1

                if (idx+1) % 10 == 0 or idx+1 == len(sorted_gids):
                    statusStr = '-> Clustered %d of %d (%.2f%%) genomes [ANI pairs: %d; clustered genomes: %d; clusters: %d].'.ljust(96) % (
                                    idx+1, 
                                    len(sorted_gids), 
                                    float(idx+1)*100/len(sorted_gids),
                                    max_ani_pairs,
                                    clustered_genomes,
                                    len(clusters))
                    sys.stdout.write('%s\r' % statusStr)
                    sys.stdout.flush()
                    # maximum is reported per progress interval
                    max_ani_pairs = 0
            sys.stdout.write('\n')

            # write out selected cluster representative
            with open(cluster_rep_file, 'w') as fout:
                for gid in clusters:
                    fout.write('%s\n' % gid)
        else:
            # read cluster reps from file
            self.logger.warning('Using previously determined cluster representatives.')
            with open(cluster_rep_file) as fin:
                for line in fin:
                    gid = line.strip()
                    if gid:
                        # blank lines must not become an empty genome ID
                        clusters.add(gid)

        self.logger.info('Selected %d representative genomes for de novo species clusters.' % len(clusters))

        return clusters
# Example #22
# 0
    def write_synonym_table(self,
                            type_strain_synonyms,
                            consensus_synonyms,
                            ani_af,
                            sp_priority_ledger,
                            genus_priority_ledger,
                            dsmz_bacnames_file):
        """Create table indicating species names that should be considered synonyms.

        The table is written to 'synonyms.tsv' in the output directory with
        one row per (representative, synonym genome) pair.

        type_strain_synonyms  -- dict: representative genome ID -> synonym
                                 genome IDs; rows tagged TYPE_STRAIN_SYNONYM
        consensus_synonyms    -- dict: representative genome ID -> synonym
                                 genome IDs; rows tagged CONSENSUS_SYNONYM
        ani_af                -- ANI/AF results passed to symmetric_ani()
        sp_priority_ledger, genus_priority_ledger, dsmz_bacnames_file
                              -- forwarded to SpeciesPriorityManager

        Warnings are logged for synonyms where the representative does not
        have naming priority, or where a type strain genome was not
        selected as the representative.
        """

        sp_priority_mngr = SpeciesPriorityManager(sp_priority_ledger,
                                                    genus_priority_ledger,
                                                    dsmz_bacnames_file)

        def genome_columns(gid, priority_year):
            """Format the eight descriptive columns reported for a genome."""
            genome = self.cur_genomes[gid]
            return '\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}'.format(
                        genome.ncbi_taxa.species,
                        gid,
                        ','.join(sorted(genome.strain_ids())),
                        ','.join(sorted(genome.gtdb_type_sources())).upper().replace('STRAININFO', 'StrainInfo'),
                        priority_year,
                        genome.is_gtdb_type_species(),
                        genome.is_gtdb_type_strain(),
                        genome.ncbi_type_material)

        incorrect_priority = 0
        failed_type_strain_priority = 0

        # context manager ensures the table is flushed and closed; the
        # file handle was previously never closed
        out_file = os.path.join(self.output_dir, 'synonyms.tsv')
        with open(out_file, 'w') as fout:
            fout.write('Synonym type\tNCBI species\tGTDB representative\tStrain IDs\tType sources\tPriority year')
            fout.write('\tGTDB type species\tGTDB type strain\tNCBI assembly type')
            fout.write('\tNCBI synonym\tHighest-quality synonym genome\tSynonym strain IDs\tSynonym type sources\tSynonym priority year')
            fout.write('\tSynonym GTDB type species\tSynonym GTDB type strain\tSynonym NCBI assembly type')
            fout.write('\tANI\tAF\tWarnings\n')

            for synonyms, synonym_type in [(type_strain_synonyms, 'TYPE_STRAIN_SYNONYM'), 
                                            (consensus_synonyms, 'CONSENSUS_SYNONYM')]:
                for rid, synonym_ids in synonyms.items():
                    for gid in synonym_ids:
                        ani, af = symmetric_ani(ani_af, rid, gid)

                        fout.write(synonym_type)
                        fout.write(genome_columns(
                                    rid,
                                    sp_priority_mngr.species_priority_year(self.cur_genomes, rid)))

                        # only the synonym's priority year is reported as
                        # 'n/a' when no priority year is available
                        synonym_priority_year = sp_priority_mngr.species_priority_year(self.cur_genomes, gid)
                        if synonym_priority_year == Genome.NO_PRIORITY_YEAR:
                            synonym_priority_year = 'n/a'

                        fout.write(genome_columns(gid, synonym_priority_year))
                        fout.write('\t{:.3f}\t{:.4f}'.format(ani, af))

                        if self.cur_genomes[rid].is_effective_type_strain() and self.cur_genomes[gid].is_effective_type_strain():
                            priority_gid, note = sp_priority_mngr.species_priority(self.cur_genomes, rid, gid)
                            if priority_gid != rid:
                                incorrect_priority += 1
                                fout.write('\tIncorrect priority: {}'.format(note))
                        elif not self.cur_genomes[rid].is_gtdb_type_strain() and self.cur_genomes[gid].is_gtdb_type_strain():
                            failed_type_strain_priority += 1
                            fout.write('\tFailed to prioritize type strain of species')

                        fout.write('\n')

        if incorrect_priority:
            self.logger.warning(f' - identified {incorrect_priority:,} synonyms with incorrect priority.')

        if failed_type_strain_priority:
            self.logger.warning(f' - identified {failed_type_strain_priority:,} synonyms that failed to priotize the type strain of the species.')
    def dereplicate_species(self, species, rid, cids, genomes, mash_out_dir):
        """Dereplicate genomes within a GTDB species.

        species      -- species name; characters after the first three
                        (i.e. after 's__') name the Mash output prefix
        rid          -- representative genome ID; always processed first
        cids         -- set of genome IDs in the species cluster
        genomes      -- genome collection passed to the priority and Mash
                        helpers; must expose genomic_files
        mash_out_dir -- directory for Mash output files

        Species with more than self.max_genomes_per_sp genomes are first
        reduced using Mash dereplication at progressively lower ANI
        thresholds. The remaining genomes are greedily dereplicated with
        FastANI using self.derep_ani and self.derep_af.

        Returns a dict: dereplicated representative genome ID -> list of
        genome IDs assigned to it (the representative itself is first).
        """

        # greedily dereplicate genomes based on genome priority
        sorted_gids = self.order_genomes_by_priority(cids.difference([rid]),
                                                     genomes)
        sorted_gids = [rid] + sorted_gids

        # calculate Mash ANI between genomes
        mash_ani = []
        if len(sorted_gids) > 1:
            # calculate MASH distances between genomes
            out_prefix = os.path.join(mash_out_dir,
                                      species[3:].lower().replace(' ', '_'))
            mash_ani = self.mash_sp_ani(sorted_gids, genomes, out_prefix)

        # perform initial dereplication using Mash for species with excessive
        # numbers of genomes
        if len(sorted_gids) > self.max_genomes_per_sp:
            self.logger.info(
                ' - limiting species to <={:,} genomes based on priority and Mash dereplication.'
                .format(self.max_genomes_per_sp))

            prev_mash_rep_gids = None
            for ani_threshold in [
                    99.75, 99.5, 99.25, 99.0, 98.75, 98.5, 98.25, 98.0, 97.75,
                    97.5, 97.0, 96.5, 96.0, 95.0, None
            ]:
                if ani_threshold is None:
                    # no threshold reduced the species sufficiently, so fall
                    # back to truncating the last Mash dereplicated set
                    self.logger.warning(
                        ' - selected {:,} highest priority genomes from final Mash dereplication.'
                        .format(self.max_genomes_per_sp))
                    sorted_gids = mash_rep_gids[0:self.max_genomes_per_sp]
                    break

                mash_rep_gids = self.mash_sp_dereplicate(
                    mash_ani, sorted_gids, ani_threshold)

                self.logger.info(
                    ' - dereplicated {} from {:,} to {:,} genomes at {:.2f}% ANI using Mash.'
                    .format(species, len(cids), len(mash_rep_gids),
                            ani_threshold))

                if len(mash_rep_gids) <= self.max_genomes_per_sp:
                    if not prev_mash_rep_gids:
                        # corner case where dereplication is occurring at 99.75%
                        prev_mash_rep_gids = sorted_gids

                    # select maximum allowed number of genomes by taking all genomes in the
                    # current Mash dereplicated set and then the highest priority genomes in the
                    # previous Mash dereplicated set which have not been selected
                    cur_sel_gids = set(mash_rep_gids)
                    prev_sel_gids = set(prev_mash_rep_gids)
                    num_prev_to_sel = self.max_genomes_per_sp - len(
                        cur_sel_gids)
                    num_prev_selected = 0
                    sel_sorted_gids = []
                    for gid in sorted_gids:
                        if gid in cur_sel_gids:
                            sel_sorted_gids.append(gid)
                        elif (gid in prev_sel_gids
                              and num_prev_selected < num_prev_to_sel):
                            num_prev_selected += 1
                            sel_sorted_gids.append(gid)

                        if len(sel_sorted_gids) == self.max_genomes_per_sp:
                            break

                    assert len(cur_sel_gids - set(sel_sorted_gids)) == 0
                    assert num_prev_to_sel == num_prev_selected
                    assert len(sel_sorted_gids) == self.max_genomes_per_sp

                    sorted_gids = sel_sorted_gids
                    self.logger.info(
                        ' - selected {:,} highest priority genomes from Mash dereplication at an ANI = {:.2f}%.'
                        .format(len(sorted_gids), ani_threshold))
                    break

                prev_mash_rep_gids = mash_rep_gids

        # calculate FastANI ANI/AF between genomes passing Mash filtering
        ani_pairs = set()
        for gid1, gid2 in permutations(sorted_gids, 2):
            if gid1 in mash_ani and gid2 in mash_ani[gid1]:
                if mash_ani[gid1][gid2] >= self.min_mash_intra_sp_ani:
                    ani_pairs.add((gid1, gid2))
                    ani_pairs.add((gid2, gid1))

        self.logger.info(
            ' - calculating FastANI between {:,} pairs with Mash ANI >= {:.1f}%.'
            .format(len(ani_pairs), self.min_mash_intra_sp_ani))
        ani_af = self.fastani.pairs(ani_pairs,
                                    genomes.genomic_files,
                                    report_progress=False,
                                    check_cache=True)
        self.fastani.write_cache(silence=True)

        # perform greedy dereplication; genomes are visited in priority
        # order so higher priority genomes become the representatives
        sp_reps = []
        for gid in sorted_gids:
            # determine if genome clusters with existing representative
            clustered = False
            for rep_gid in sp_reps:
                ani, af = symmetric_ani(ani_af, gid, rep_gid)

                if ani >= self.derep_ani and af >= self.derep_af:
                    clustered = True
                    break

            if not clustered:
                sp_reps.append(gid)

        self.logger.info(
            ' - dereplicated {} from {:,} to {:,} genomes.'.format(
                species, len(sorted_gids), len(sp_reps)))

        # assign clustered genomes to most similar representative
        subsp_clusters = {rep_gid: [rep_gid] for rep_gid in sp_reps}

        non_rep_gids = set(sorted_gids) - set(sp_reps)
        for gid in non_rep_gids:
            closest_rid = None
            max_ani = 0
            max_af = 0
            for rep_gid in sp_reps:
                ani, af = symmetric_ani(ani_af, gid, rep_gid)
                # highest ANI wins; ties are broken by highest AF, and the
                # AF criterion must always be satisfied
                if ((ani > max_ani and af >= self.derep_af) or
                    (ani == max_ani and af >= max_af and af >= self.derep_af)):
                    max_ani = ani
                    max_af = af
                    closest_rid = rep_gid

            assert closest_rid is not None
            subsp_clusters[closest_rid].append(gid)

        return subsp_clusters