class PMC_ClusterStats(object):
    """Calculate statistics for species cluster.

    Reads a species cluster file, calculates ANI statistics for each cluster
    (relative to the representative genome and between pairs of genomes) and
    writes the results as tab-separated files to the output directory.
    """

    def __init__(self, af_sp, max_genomes, ani_cache_file, cpus, output_dir):
        """Initialization.

        Args:
            af_sp: Minimum AF required for a non-representative genome to be
                counted as within the ANI radius of a representative.
            max_genomes: Maximum number of randomly selected genomes to consider
                when calculating pairwise statistics for a cluster.
            ani_cache_file: ANI cache file passed to FastANI.
            cpus: Number of CPUs passed to FastANI.
            output_dir: Directory to which output files are written.
        """

        check_dependencies(['fastANI', 'mash'])

        self.cpus = cpus
        self.output_dir = output_dir

        self.logger = logging.getLogger('timestamp')

        self.af_sp = af_sp

        self.fastani = FastANI(ani_cache_file, cpus)

        # maximum number of randomly selected genomes to
        # consider when calculating pairwise statistics
        self.max_genomes_for_stats = max_genomes

        self.RepStats = namedtuple(
            'RepStats', 'min_ani mean_ani std_ani median_ani')
        self.PairwiseStats = namedtuple('PairwiseStats', ('min_ani',
                                                          'mean_ani',
                                                          'std_ani',
                                                          'median_ani',
                                                          'ani_to_medoid',
                                                          'mean_ani_to_medoid',
                                                          'mean_ani_to_rep',
                                                          'ani_below_95'))

    def find_multiple_reps(self, clusters, cluster_radius):
        """Determine number of non-rep genomes within ANI radius of multiple rep genomes.

        This method assumes the ANI cache contains all relevant ANI calculations between
        representative and non-representative genomes. This is the case once the de novo
        clustering has been performed.

        Args:
            clusters: Dict mapping representative ID -> collection of clustered genome IDs.
            cluster_radius: Dict mapping representative ID -> object with an `ani` radius.

        Returns:
            defaultdict(set) mapping non-representative genome ID -> set of
            (representative ID, ANI) tuples, one for each representative whose
            ANI radius the genome falls within (with AF >= af_sp).
        """

        self.logger.info(
            'Determine number of non-rep genomes within ANI radius of multiple rep genomes.')

        # get clustered genomes IDs
        clustered_gids = []
        for rid in clusters:
            clustered_gids += clusters[rid]

        self.logger.info('Considering {:,} representatives and {:,} non-representative genomes.'.format(
            len(clusters),
            len(clustered_gids)))

        nonrep_rep_count = defaultdict(set)
        for idx, gid in enumerate(clustered_gids):
            cur_ani_cache = self.fastani.ani_cache[gid]
            for rid in clusters:
                # only representatives with a cached ANI value can be considered
                if rid not in cur_ani_cache:
                    continue

                ani, af = FastANI.symmetric_ani(
                    self.fastani.ani_cache, gid, rid)
                if af >= self.af_sp and ani >= cluster_radius[rid].ani:
                    nonrep_rep_count[gid].add((rid, ani))

            if (idx+1) % 100 == 0 or (idx+1) == len(clustered_gids):
                statusStr = '-> Processing %d of %d (%.2f%%) clusters genomes.'.ljust(86) % (
                    idx+1,
                    len(clustered_gids),
                    float((idx+1)*100)/len(clustered_gids))
                sys.stdout.write('%s\r' % statusStr)
                sys.stdout.flush()

        sys.stdout.write('\n')

        return nonrep_rep_count

    def intragenus_pairwise_ani(self, clusters, species, genome_files, gtdb_taxonomy):
        """Determine pairwise intra-genus ANI between representative genomes.

        Writes to the output directory:
          - type_genomes_ani_af.pkl: pickled ANI/AF results returned by FastANI;
          - intra_genus_pairwise_ani.tsv: ANI/AF for every ordered pair of
            representatives from the same genus ('n/a' when no value is available);
          - closest_intragenus_rep.tsv: the intra-genus representative with the
            highest ANI to each representative.

        Args:
            clusters: Dict keyed by representative genome ID.
            species: Dict mapping representative ID -> species name ('s__Genus species').
            genome_files: Genome FASTA paths passed to FastANI.
            gtdb_taxonomy: Dict mapping genome ID -> taxonomy list (index 5 is the genus).
        """

        self.logger.info(
            'Calculating pairwise intra-genus ANI values between GTDB representatives.')

        # get genus for each representative
        genus = {}
        for rid, sp in species.items():
            genus[rid] = sp.split()[0].replace('s__', '')
            assert genus[rid] == gtdb_taxonomy[rid][5].replace('g__', '')

        # group representatives by genus, preserving the iteration order of `clusters`
        genus_reps = defaultdict(list)
        for rid in clusters:
            genus_reps[genus[rid]].append(rid)

        # get all intra-genus pairs; every unordered pair contributes each
        # orientation exactly once so no comparison is requested twice
        self.logger.info('Determining intra-genus genome pairs.')
        ani_pairs = []
        for reps in genus_reps.values():
            for gid1, gid2 in combinations(reps, 2):
                ani_pairs.append((gid1, gid2))
                ani_pairs.append((gid2, gid1))

        self.logger.info(
            'Identified {:,} intra-genus genome pairs.'.format(len(ani_pairs)))

        # calculate ANI between pairs
        self.logger.info(
            'Calculating ANI between {:,} genome pairs:'.format(len(ani_pairs)))
        ani_af = self.fastani.pairs(ani_pairs, genome_files)
        with open(os.path.join(self.output_dir, 'type_genomes_ani_af.pkl'), 'wb') as f:
            pickle.dump(ani_af, f)

        # find closest intra-genus pair for each rep
        closest_intragenus_rep = {}
        with open(os.path.join(self.output_dir,
                               'intra_genus_pairwise_ani.tsv'), 'w') as fout:
            fout.write(
                'Genus\tSpecies 1\tGenome ID 1\tSpecies 2\tGenome ID2\tANI\tAF\n')
            for qid in clusters:
                genusA = genus[qid]

                closest_ani = 0
                closest_af = 0
                closest_gid = None
                for rid in genus_reps[genusA]:
                    if qid == rid:
                        continue

                    ani, af = ('n/a', 'n/a')
                    if qid in ani_af and rid in ani_af[qid]:
                        ani, af = FastANI.symmetric_ani(ani_af, qid, rid)

                        if ani > closest_ani:
                            closest_ani = ani
                            closest_af = af
                            closest_gid = rid

                    fout.write('{}\t{}\t{}\t{}\t{}\t{}\t{}\n'.format(
                        genusA,
                        species[qid],
                        qid,
                        species[rid],
                        rid,
                        ani,
                        af))

                if closest_gid is not None:
                    closest_intragenus_rep[qid] = (
                        closest_gid, closest_ani, closest_af)

        # write out closest intra-genus species to each representative
        with open(os.path.join(self.output_dir,
                               'closest_intragenus_rep.tsv'), 'w') as fout:
            fout.write(
                'Genome ID\tSpecies\tIntra-genus neighbour\tIntra-genus species\tANI\tAF\n')
            for qid, (rid, ani, af) in closest_intragenus_rep.items():
                fout.write('%s\t%s\t%s\t%s\t%.2f\t%.3f\n' % (
                    qid,
                    species[qid],
                    rid,
                    species[rid],
                    ani,
                    af))

    def parse_clusters(self, cluster_file):
        """Parse species clustering information.

        Args:
            cluster_file: Tab-separated file with a header row containing (at least)
                the columns 'NCBI species', 'Type genome', 'No. clustered genomes',
                'Clustered genomes', 'Closest type genome', 'ANI radius' and 'AF closest'.

        Returns:
            Tuple (clusters, species, cluster_radius): `clusters` maps each
            representative ID to a set of clustered genome IDs, `species` maps each
            representative ID to its species name, and `cluster_radius` maps each
            representative ID to a GenomeRadius. Genome IDs are passed through
            canonical_gid(), except the 'N/A' placeholder for a missing neighbour.
        """

        species = {}
        clusters = {}
        cluster_radius = {}
        with open(cluster_file) as f:
            headers = f.readline().strip().split('\t')

            type_sp_index = headers.index('NCBI species')
            type_genome_index = headers.index('Type genome')
            num_clustered_index = headers.index('No. clustered genomes')
            clustered_genomes_index = headers.index('Clustered genomes')
            closest_type_index = headers.index('Closest type genome')
            ani_radius_index = headers.index('ANI radius')
            af_index = headers.index('AF closest')

            for line in f:
                line_split = line.strip().split('\t')

                rid = canonical_gid(line_split[type_genome_index])

                species[rid] = line_split[type_sp_index]

                clusters[rid] = set()
                num_clustered = int(line_split[num_clustered_index])
                if num_clustered > 0:
                    for gid in [g.strip() for g in line_split[clustered_genomes_index].split(',')]:
                        clusters[rid].add(canonical_gid(gid))

                neighbour_gid = line_split[closest_type_index]
                if neighbour_gid != 'N/A':
                    # keep neighbour IDs in the same canonical form as the keys of
                    # `species`, which write_cluster_stats() indexes with this ID
                    neighbour_gid = canonical_gid(neighbour_gid)

                cluster_radius[rid] = GenomeRadius(ani=float(line_split[ani_radius_index]),
                                                   af=float(line_split[af_index]),
                                                   neighbour_gid=neighbour_gid)

        return clusters, species, cluster_radius

    def rep_genome_stats(self, clusters, genome_files):
        """Calculate statistics relative to representative genome.

        Args:
            clusters: Dict mapping representative ID -> collection of clustered genome IDs.
            genome_files: Genome FASTA paths passed to FastANI.

        Returns:
            Dict mapping representative ID -> RepStats. All fields are -1 for
            clusters without any clustered genomes.
        """

        self.logger.info('Calculating statistics to cluster representatives:')
        stats = {}
        for idx, (rid, cids) in enumerate(clusters.items()):
            if len(cids) == 0:
                stats[rid] = self.RepStats(min_ani=-1,
                                           mean_ani=-1,
                                           std_ani=-1,
                                           median_ani=-1)
            else:
                # calculate ANI to representative genome in both directions
                gid_pairs = []
                for cid in cids:
                    gid_pairs.append((cid, rid))
                    gid_pairs.append((rid, cid))

                ani_af = self.fastani.pairs(gid_pairs,
                                            genome_files,
                                            report_progress=False)

                # calculate statistics
                anis = [FastANI.symmetric_ani(ani_af, cid, rid)[0]
                        for cid in cids]

                stats[rid] = self.RepStats(min_ani=min(anis),
                                           mean_ani=np_mean(anis),
                                           std_ani=np_std(anis),
                                           median_ani=np_median(anis))

            statusStr = '-> Processing %d of %d (%.2f%%) clusters.'.ljust(86) % (
                idx+1,
                len(clusters),
                float((idx+1)*100)/len(clusters))
            sys.stdout.write('%s\r' % statusStr)
            sys.stdout.flush()

        sys.stdout.write('\n')

        return stats

    def pairwise_stats(self, clusters, genome_files):
        """Calculate statistics for all pairwise comparisons in a species cluster.

        Clusters larger than `max_genomes_for_stats` are randomly subsampled
        (the representative is always included).

        Args:
            clusters: Dict mapping representative ID -> set of clustered genome IDs.
            genome_files: Genome FASTA paths passed to FastANI.

        Returns:
            Dict mapping representative ID -> PairwiseStats. All fields are -1 for
            clusters without any clustered genomes.
        """

        # tolerance for the medoid sanity check; both means are computed from the
        # same ANI values, so any larger discrepancy indicates a real problem
        mean_tolerance = 1e-6

        self.logger.info(
            f'Restricting pairwise comparisons to {self.max_genomes_for_stats:,} randomly selected genomes.')
        self.logger.info(
            'Calculating statistics for all pairwise comparisons in a species cluster:')

        stats = {}
        for idx, (rid, cids) in enumerate(clusters.items()):
            statusStr = '-> Processing {:,} of {:,} ({:.2f}%) clusters (size = {:,}).'.ljust(86).format(
                idx+1,
                len(clusters),
                float((idx+1)*100)/len(clusters),
                len(cids))
            sys.stdout.write('{}\r'.format(statusStr))
            sys.stdout.flush()

            if len(cids) == 0:
                stats[rid] = self.PairwiseStats(min_ani=-1,
                                                mean_ani=-1,
                                                std_ani=-1,
                                                median_ani=-1,
                                                ani_to_medoid=-1,
                                                mean_ani_to_medoid=-1,
                                                mean_ani_to_rep=-1,
                                                ani_below_95=-1)
                continue

            if len(cids) > self.max_genomes_for_stats:
                # random.sample() requires a sequence (sets raise TypeError on
                # Python >= 3.11); sorting also makes sampling reproducible under a seed
                cids = set(random.sample(sorted(cids), self.max_genomes_for_stats))

            # all pairwise comparisons between sampled genomes and the representative;
            # sorted so medoid tie-breaking does not depend on set hash order
            gid_pairs = []
            gids = sorted(set(cids) | {rid})
            for gid1, gid2 in combinations(gids, 2):
                gid_pairs.append((gid1, gid2))
                gid_pairs.append((gid2, gid1))

            ani_af = self.fastani.pairs(gid_pairs,
                                        genome_files,
                                        report_progress=False)

            # calculate medoid point: genome with the smallest summed distance (100 - ANI)
            if len(gids) > 2:
                dist_mat = np_zeros((len(gids), len(gids)))
                for i, gid1 in enumerate(gids):
                    for j, gid2 in enumerate(gids):
                        if i < j:
                            ani, _af = FastANI.symmetric_ani(
                                ani_af, gid1, gid2)
                            dist_mat[i, j] = 100 - ani
                            dist_mat[j, i] = 100 - ani

                medoid_idx = np_argmin(dist_mat.sum(axis=0))
                medoid_gid = gids[medoid_idx]
            else:
                # with only 2 genomes in a cluster, the representative is the
                # natural medoid at least for reporting statistics for the
                # individual species cluster
                medoid_gid = rid

            mean_ani_to_medoid = np_mean([FastANI.symmetric_ani(ani_af, gid, medoid_gid)[0]
                                          for gid in gids if gid != medoid_gid])

            mean_ani_to_rep = np_mean([FastANI.symmetric_ani(ani_af, gid, rid)[0]
                                       for gid in gids if gid != rid])

            # by construction the medoid maximizes mean ANI, so this only fails
            # if the ANI values are inconsistent (beyond floating-point rounding)
            if mean_ani_to_medoid < mean_ani_to_rep - mean_tolerance:
                self.logger.error('mean_ani_to_medoid < mean_ani_to_rep')
                sys.exit(-1)

            # calculate statistics
            anis = []
            for gid1, gid2 in combinations(gids, 2):
                ani, _af = FastANI.symmetric_ani(ani_af, gid1, gid2)
                anis.append(ani)

            stats[rid] = self.PairwiseStats(
                min_ani=min(anis),
                mean_ani=np_mean(anis),
                std_ani=np_std(anis),
                median_ani=np_median(anis),
                ani_to_medoid=FastANI.symmetric_ani(
                    ani_af, rid, medoid_gid)[0],
                mean_ani_to_medoid=mean_ani_to_medoid,
                mean_ani_to_rep=mean_ani_to_rep,
                ani_below_95=sum(1 for ani in anis if ani < 95))

        sys.stdout.write('\n')

        return stats

    def write_cluster_stats(self,
                            stats_file,
                            clusters,
                            species,
                            cluster_radius,
                            rep_stats,
                            pairwise_stats):
        """Write file with cluster statistics.

        Args:
            stats_file: Path of the tab-separated output file.
            clusters: Dict mapping representative ID -> collection of clustered genome IDs.
            species: Dict mapping genome ID -> species name.
            cluster_radius: Dict mapping representative ID -> object with `ani`, `af`
                and `neighbour_gid` ('N/A' when there is no neighbour).
            rep_stats: Dict mapping representative ID -> RepStats.
            pairwise_stats: Dict mapping representative ID -> PairwiseStats.
        """

        with open(stats_file, 'w') as fout:
            fout.write('Species\tRep genome\tNo. clustered genomes')
            fout.write(
                '\tMin ANI to rep\tMean ANI to rep\tStd ANI to rep\tMedian ANI to rep')
            fout.write(
                '\tMin pairwise ANI\tMean pairwise ANI\tStd pairwise ANI\tMedian pairwise ANI')
            fout.write(
                '\tANI to medoid\tMean ANI to medoid\tMean ANI to rep (w/ subsampling)\tANI pairs <95%')
            fout.write(
                '\tClosest species\tClosest rep genome\tANI radius\tAF closest')
            fout.write('\tClustered genomes\n')

            for rid in clusters:
                rs = rep_stats[rid]
                ps = pairwise_stats[rid]
                radius = cluster_radius[rid]

                fout.write('%s\t%s\t%d' % (species[rid], rid, len(clusters[rid])))
                fout.write('\t%.2f\t%.2f\t%.3f\t%.2f' % (
                    rs.min_ani,
                    rs.mean_ani,
                    rs.std_ani,
                    rs.median_ani))

                fout.write('\t%.2f\t%.2f\t%.3f\t%.2f\t%.2f\t%.2f\t%.2f\t%d' % (
                    ps.min_ani,
                    ps.mean_ani,
                    ps.std_ani,
                    ps.median_ani,
                    ps.ani_to_medoid,
                    ps.mean_ani_to_medoid,
                    ps.mean_ani_to_rep,
                    ps.ani_below_95))

                if radius.neighbour_gid != 'N/A':
                    fout.write('\t%s\t%s\t%.2f\t%.2f' % (
                        species[radius.neighbour_gid],
                        radius.neighbour_gid,
                        radius.ani,
                        radius.af))
                else:
                    # no neighbouring representative: report default radius values
                    fout.write('\t%s\t%s\t%.2f\t%.2f' % ('N/A', 'N/A', 95, 0))

                fout.write('\t%s\n' % ','.join(clusters[rid]))

    def run(self, cluster_file, genome_path_file, metadata_file):
        """Calculate statistics for species cluster.

        Args:
            cluster_file: Species cluster file (see parse_clusters()).
            genome_path_file: File listing paths to genome FASTA files.
            metadata_file: GTDB metadata file used to read the GTDB taxonomy.
        """

        # read the GTDB taxonomy
        self.logger.info('Reading GTDB taxonomy from metadata file.')
        gtdb_taxonomy = read_gtdb_taxonomy(metadata_file)

        # get path to genome FASTA files
        self.logger.info('Reading path to genome FASTA files.')
        genome_files = read_genome_path(genome_path_file)
        self.logger.info(f'Read path for {len(genome_files):,} genomes.')

        # determine type genomes and genomes clustered to type genomes
        self.logger.info('Reading species clusters.')
        clusters, species, cluster_radius = self.parse_clusters(cluster_file)
        self.logger.info(f'Identified {len(clusters):,} species clusters.')

        # determine species assignment for clustered genomes
        clustered_species = {}
        for rid, cids in clusters.items():
            for cid in cids:
                clustered_species[cid] = species[rid]

        # determine number of non-rep genomes within ANI radius of multiple rep genomes
        nonrep_rep_count = self.find_multiple_reps(clusters, cluster_radius)

        with open(os.path.join(self.output_dir,
                               'nonrep_rep_ani_radius_count.tsv'), 'w') as fout:
            fout.write('Genome ID\tSpecies\tNo. rep radii\tMean radii')
            fout.write('\t<0.25%\t<0.5%\t<0.75%\t<1%\t<1.5%\t<2%')
            fout.write('\tRep genomes IDs\n')
            for gid, rid_info in nonrep_rep_count.items():
                rids = [rid for rid, ani in rid_info]
                anis = [ani for rid, ani in rid_info]

                fout.write('%s\t%s\t%d\t%.2f' % (
                    gid,
                    clustered_species[gid],
                    len(rids),
                    np_mean([cluster_radius[rid].ani for rid in rids])))

                if len(anis) >= 2:
                    # difference between the highest and second highest ANI
                    top_anis = sorted(anis, reverse=True)
                    diff = top_anis[0] - top_anis[1]
                    for threshold in (0.25, 0.5, 0.75, 1.0, 1.5, 2.0):
                        fout.write('\t%s' % (diff < threshold))
                else:
                    fout.write('\tFalse\tFalse\tFalse\tFalse\tFalse\tFalse')

                fout.write('\t%s\n' % ','.join(rids))

        # find closest representative genome to each representative genome
        self.intragenus_pairwise_ani(
            clusters, species, genome_files, gtdb_taxonomy)

        # identify statistics relative to representative genome
        rep_stats = self.rep_genome_stats(clusters, genome_files)

        # identify pairwise statistics
        pairwise_stats = self.pairwise_stats(clusters, genome_files)

        # report statistics
        stats_file = os.path.join(self.output_dir, 'cluster_stats.tsv')
        self.write_cluster_stats(stats_file,
                                 clusters,
                                 species,
                                 cluster_radius,
                                 rep_stats,
                                 pairwise_stats)
class MergeTest():
    """Produce information relevant to merging two sister species."""
    def __init__(self, ani_cache_file, cpus, output_dir):
        """Initialization.

        Args:
            ani_cache_file: ANI cache file passed to FastANI.
            cpus: Number of CPUs passed to FastANI.
            output_dir: Output directory (stored on the instance).
        """

        check_dependencies(['fastANI'])

        self.cpus = cpus
        self.output_dir = output_dir

        self.logger = logging.getLogger('timestamp')

        self.fastani = FastANI(ani_cache_file, cpus)

    def top_hits(self, species, rid, ani_af, genomes):
        """Report top 5 hits to species.

        Args:
            species: Name of the species being reported (used in log output).
            rid: Representative genome ID of the species.
            ani_af: ANI/AF results containing an entry for `rid`.
            genomes: Genome collection providing `gtdb_species` for each hit.
        """

        results = {}
        for qid in ani_af[rid]:
            ani, af = FastANI.symmetric_ani(ani_af, rid, qid)
            results[qid] = (ani, af)

        self.logger.info(f'Closest 5 species to {species} ({rid}):')

        # rank by (ANI, AF), highest first, and report exactly 5 hits
        ranked = sorted(results.items(), key=lambda x: x[1], reverse=True)
        for qid, (ani, af) in ranked[:5]:
            q_species = genomes[qid].gtdb_species
            self.logger.info(
                f'{q_species} ({qid}): ANI={ani:.1f}%, AF={af:.2f}')

    def merge_ani_radius(self, species, rid, merged_sp_cluster, genomic_files):
        """Determine ANI radius if species were merged.

        The radius is the lowest ANI between the representative `rid` and any
        other genome in the merged cluster, reported with that genome's AF.

        Args:
            species: Name of the species whose representative is used.
            rid: Representative genome ID.
            merged_sp_cluster: Collection of genome IDs in the merged cluster.
            genomic_files: Genome FASTA paths passed to FastANI.

        Returns:
            Tuple (ani_radius, af_radius), or None if the merged cluster contains
            no genome other than `rid`.
        """

        self.logger.info(
            f'Calculating ANI from {species} to all genomes in merged species cluster.'
        )

        # comparing the representative to itself adds nothing to the radius
        gids = [gid for gid in merged_sp_cluster if gid != rid]
        if not gids:
            self.logger.info(
                f'Merged cluster with {species} rep contains no other genomes.')
            return None

        gid_pairs = []
        for gid in gids:
            gid_pairs.append((rid, gid))
            gid_pairs.append((gid, rid))
        merged_ani_af = self.fastani.pairs(gid_pairs, genomic_files)

        # lowest ANI (first genome in case of ties); always defined since gids is
        # non-empty, unlike the previous "ani < 100" scan which could leave AF unset
        ani_radius, af_radius = min(
            (FastANI.symmetric_ani(merged_ani_af, rid, gid) for gid in gids),
            key=lambda ani_af: ani_af[0])

        self.logger.info(
            f'Merged cluster with {species} rep: ANI radius={ani_radius:.1f}%, AF={af_radius:.2f}'
        )

        return ani_radius, af_radius

    def run(self, gtdb_metadata_file, genome_path_file, species1, species2):
        """Produce information relevant to merging two sister species.

        Args:
            gtdb_metadata_file: GTDB metadata file defining species clusters.
            genome_path_file: File listing paths to genome FASTA files.
            species1: Name of the first species.
            species2: Name of the second species; must be in the same genus.
        """

        # read GTDB species clusters
        self.logger.info('Reading GTDB species clusters.')
        genomes = Genomes()
        genomes.load_from_metadata_file(gtdb_metadata_file)
        genomes.load_genomic_file_paths(genome_path_file)
        self.logger.info(
            ' - identified {:,} species clusters spanning {:,} genomes.'.
            format(len(genomes.sp_clusters),
                   genomes.sp_clusters.total_num_genomes()))

        # find species of interest
        gid1 = None
        gid2 = None
        for gid, species in genomes.sp_clusters.species():
            if species == species1:
                gid1 = gid
            elif species == species2:
                gid2 = gid

        if gid1 is None:
            self.logger.error(
                f'Unable to find representative genome for {species1}.')
            sys.exit(-1)

        if gid2 is None:
            self.logger.error(
                f'Unable to find representative genome for {species2}.')
            sys.exit(-1)

        self.logger.info(' - identified {:,} genomes in {}.'.format(
            len(genomes.sp_clusters[gid1]), species1))
        self.logger.info(' - identified {:,} genomes in {}.'.format(
            len(genomes.sp_clusters[gid2]), species2))

        # both species must belong to the same genus
        genus1 = genomes[gid1].gtdb_genus
        genus2 = genomes[gid2].gtdb_genus
        if genus1 != genus2:
            self.logger.error(
                f'Genomes must be from same genus: {genus1} {genus2}')
            sys.exit(-1)

        self.logger.info(f'Identifying {genus1} species representatives.')
        reps_in_genera = set()
        for rid in genomes.sp_clusters:
            if genomes[rid].gtdb_genus == genus1:
                reps_in_genera.add(rid)

        self.logger.info(
            f' - identified {len(reps_in_genera):,} representatives.')

        # calculate ANI between each species representative and all other
        # representatives in the genus
        self.logger.info(f'Calculating ANI to {species1}.')
        gid_pairs = []
        for gid in reps_in_genera:
            if gid != gid1:
                gid_pairs.append((gid1, gid))
                gid_pairs.append((gid, gid1))
        ani_af1 = self.fastani.pairs(gid_pairs, genomes.genomic_files)

        self.logger.info(f'Calculating ANI to {species2}.')
        gid_pairs = []
        for gid in reps_in_genera:
            if gid != gid2:
                gid_pairs.append((gid2, gid))
                gid_pairs.append((gid, gid2))
        ani_af2 = self.fastani.pairs(gid_pairs, genomes.genomic_files)

        # report results
        ani12, af12 = ani_af1[gid1][gid2]
        ani21, af21 = ani_af2[gid2][gid1]
        ani, af = FastANI.symmetric_ani(ani_af1, gid1, gid2)

        self.logger.info(
            f'{species1} ({gid1}) -> {species2} ({gid2}): ANI={ani12:.1f}%, AF={af12:.2f}'
        )
        self.logger.info(
            f'{species2} ({gid2}) -> {species1} ({gid1}): ANI={ani21:.1f}%, AF={af21:.2f}'
        )
        self.logger.info(f'Max. ANI={ani:.1f}%, Max. AF={af:.2f}')

        # report top hits
        self.top_hits(species1, gid1, ani_af1, genomes)
        self.top_hits(species2, gid2, ani_af2, genomes)

        # calculate ANI from species to all genomes in merged species cluster
        merged_sp_cluster = genomes.sp_clusters[gid1].union(
            genomes.sp_clusters[gid2])
        self.merge_ani_radius(species1, gid1, merged_sp_cluster,
                              genomes.genomic_files)
        self.merge_ani_radius(species2, gid2, merged_sp_cluster,
                              genomes.genomic_files)
class UpdateClusterNamedReps(object):
    """Cluster genomes to selected GTDB representatives."""
    def __init__(self, ani_sp, af_sp, ani_cache_file, cpus, output_dir):
        """Initialization.

        Parameters
        ----------
        ani_sp : float
            Default ANI circumscription radius assigned to every representative.
        af_sp : float
            Minimum AF a genome pair must have to restrict a radius or to
            cluster a genome.
        ani_cache_file : str
            ANI cache file passed through to FastANI.
        cpus : int
            Number of CPUs used by Mash and FastANI.
        output_dir : str
            Directory for intermediate and output files.
        """

        check_dependencies(['fastANI', 'mash'])

        self.cpus = cpus
        self.output_dir = output_dir

        self.logger = logging.getLogger('timestamp')

        self.ani_sp = ani_sp
        self.af_sp = af_sp

        # representative pairs meeting both of these thresholds are reported
        # with a warning in _rep_radius()
        self.max_ani_neighbour = 97.0
        self.max_af_neighbour = 0.65

        # Mash ANI pre-filter applied before running FastANI
        self.min_mash_ani = 90.0

        self.ClusteredGenome = namedtuple('ClusteredGenome', 'ani af gid')

        self.fastani = FastANI(ani_cache_file, cpus)

    def _rep_radius(self, rep_gids, rep_ani_file):
        """Calculate circumscription radius for representative genomes.

        Every representative starts with a radius of self.ani_sp. For each
        row of `rep_ani_file` between two representatives, the radius of
        'Representative 1' is replaced by the pair's ANI when that ANI
        exceeds the current radius and the AF is >= self.af_sp.

        Parameters
        ----------
        rep_gids : set
            Representative genome IDs; rows involving other genomes are ignored.
        rep_ani_file : str
            Tab-separated file with 'Representative 1', 'Representative 2',
            'ANI' and 'AF' columns.

        Returns
        -------
        dict
            GenomeRadius for each representative genome.
        """

        # set radius for all representative genomes to default values
        rep_radius = {}
        for gid in rep_gids:
            rep_radius[gid] = GenomeRadius(ani=self.ani_sp,
                                           af=None,
                                           neighbour_gid=None)

        # determine closest ANI neighbour and restrict ANI radius as necessary
        af_warning_count = 0
        with open(rep_ani_file) as f:
            header = f.readline().strip().split('\t')

            rep_gid1_index = header.index('Representative 1')
            rep_gid2_index = header.index('Representative 2')
            ani_index = header.index('ANI')
            af_index = header.index('AF')

            for line in f:
                line_split = line.strip().split('\t')

                rep_gid1 = line_split[rep_gid1_index]
                rep_gid2 = line_split[rep_gid2_index]

                if rep_gid1 not in rep_gids or rep_gid2 not in rep_gids:
                    continue

                ani = float(line_split[ani_index])
                af = float(line_split[af_index])

                if ani >= self.max_ani_neighbour and af >= self.max_af_neighbour:
                    # typically, representative genomes should not exceed this ANI and AF
                    # criteria as they should have been declared synonyms in
                    # the u_sel_reps step if they are this similar to each other.
                    # However, a 'fudge factor' is used to allow previous GTDB clusters
                    # to remain as separate clusters if they exceed these thresholds by
                    # a small margin as this can simply be due to differences in the
                    # version of FastANI used to calculate ANI and AF.
                    self.logger.warning(
                        'ANI neighbours {} and {} have ANI={:.2f} and AF={:.2f}.'
                        .format(rep_gid1, rep_gid2, ani, af))

                if ani > rep_radius[rep_gid1].ani:
                    if af < self.af_sp:
                        # pair is close in ANI but aligns too little to
                        # restrict the radius; only counted for the summary
                        af_warning_count += 1
                        continue

                    rep_radius[rep_gid1] = GenomeRadius(ani=ani,
                                                        af=af,
                                                        neighbour_gid=rep_gid2)

        self.logger.info(
            'ANI circumscription radius: min={:.2f}, mean={:.2f}, max={:.2f}'.
            format(min([d.ani for d in rep_radius.values()]),
                   np_mean([d.ani for d in rep_radius.values()]),
                   max([d.ani for d in rep_radius.values()])))

        self.logger.warning(
            'Identified {:,} genome pairs meeting ANI radius criteria, but with an AF <{:.2f}'
            .format(af_warning_count, self.af_sp))

        return rep_radius

    def _calculate_ani(self, cur_genomes, rep_gids, rep_mash_sketch_file):
        """Calculate ANI between representative and non-representative genomes.

        Mash is used as a pre-filter so FastANI is only run on genome pairs
        with a Mash ANI >= self.min_mash_ani. The FastANI results are also
        pickled to 'ani_af_rep_vs_nonrep.pkl' in the output directory.

        Parameters
        ----------
        cur_genomes : Genomes
            Current genome set; every genome not in `rep_gids` is treated as
            a non-representative.
        rep_gids : set
            Representative genome IDs.
        rep_mash_sketch_file : str or None
            Existing Mash sketch of the representatives. It is created in the
            output directory when not given or missing.

        Returns
        -------
        dict
            ANI/AF results as returned by FastANI.pairs().
        """

        mash = Mash(self.cpus)

        # create Mash sketch for representative genomes
        if not rep_mash_sketch_file or not os.path.exists(
                rep_mash_sketch_file):
            rep_genome_list_file = os.path.join(self.output_dir,
                                                'gtdb_reps.lst')
            rep_mash_sketch_file = os.path.join(self.output_dir,
                                                'gtdb_reps.msh')
            mash.sketch(rep_gids, cur_genomes.genomic_files,
                        rep_genome_list_file, rep_mash_sketch_file)

        # create Mash sketch for non-representative genomes
        nonrep_gids = set()
        for gid in cur_genomes:
            if gid not in rep_gids:
                nonrep_gids.add(gid)

        nonrep_genome_list_file = os.path.join(self.output_dir,
                                               'gtdb_nonreps.lst')
        nonrep_genome_sketch_file = os.path.join(self.output_dir,
                                                 'gtdb_nonreps.msh')
        mash.sketch(nonrep_gids, cur_genomes.genomic_files,
                    nonrep_genome_list_file, nonrep_genome_sketch_file)

        # get Mash distances
        mash_dist_file = os.path.join(self.output_dir,
                                      'gtdb_reps_vs_nonreps.dst')
        mash.dist(
            float(100 - self.min_mash_ani) / 100, rep_mash_sketch_file,
            nonrep_genome_sketch_file, mash_dist_file)

        # read Mash distances
        mash_ani = mash.read_ani(mash_dist_file)

        # get pairs above Mash threshold; IDs are translated through
        # cur_genomes.user_uba_id_map (presumably alternative User/UBA IDs
        # to the IDs used elsewhere -- TODO confirm)
        mash_ani_pairs = []
        for qid in mash_ani:
            for rid in mash_ani[qid]:
                if mash_ani[qid][rid] >= self.min_mash_ani:
                    n_qid = cur_genomes.user_uba_id_map.get(qid, qid)
                    n_rid = cur_genomes.user_uba_id_map.get(rid, rid)
                    if n_qid != n_rid:
                        mash_ani_pairs.append((n_qid, n_rid))
                        mash_ani_pairs.append((n_rid, n_qid))

        self.logger.info(
            'Identified {:,} genome pairs with a Mash ANI >= {:.1f}%.'.
            format(len(mash_ani_pairs), self.min_mash_ani))

        # calculate ANI between pairs
        self.logger.info(
            'Calculating ANI between {:,} genome pairs:'.format(
                len(mash_ani_pairs)))
        ani_af = self.fastani.pairs(mash_ani_pairs,
                                    cur_genomes.genomic_files)

        # keep a copy of the results; the file handle is closed explicitly
        pkl_file = os.path.join(self.output_dir, 'ani_af_rep_vs_nonrep.pkl')
        with open(pkl_file, 'wb') as f_pkl:
            pickle.dump(ani_af, f_pkl)

        return ani_af

    def _cluster(self, ani_af, non_reps, rep_radius):
        """Cluster non-representative to representative genomes using species specific ANI thresholds.

        Each non-representative is compared with every representative it has
        an ANI/AF entry for. Among representatives with AF >= self.af_sp, the
        one with the highest ANI (ties broken by AF) is selected. The genome
        is assigned to it only if that ANI exceeds the representative's radius.

        Parameters
        ----------
        ani_af : dict
            ANI/AF results indexed as ani_af[non_rep_gid][rep_gid].
        non_reps : collection
            Non-representative genome IDs (may be empty).
        rep_radius : dict
            GenomeRadius for each representative genome.

        Returns
        -------
        dict
            Maps each representative to a list of ClusteredGenome entries.
        """

        clusters = {rep_id: [] for rep_id in rep_radius}

        num_clustered = 0
        for idx, non_rid in enumerate(non_reps):
            if idx % 100 == 0:
                sys.stdout.write(
                    '==> Processed {:,} of {:,} genomes [no. clustered = {:,}].\r'
                    .format(idx + 1, len(non_reps), num_clustered))
                sys.stdout.flush()

            if non_rid not in ani_af:
                continue

            closest_rid = None
            closest_ani = 0
            closest_af = 0
            for rid in rep_radius:
                if rid not in ani_af[non_rid]:
                    continue

                ani, af = symmetric_ani(ani_af, rid, non_rid)

                if af >= self.af_sp:
                    if ani > closest_ani or (ani == closest_ani
                                             and af > closest_af):
                        closest_rid = rid
                        closest_ani = ani
                        closest_af = af

            if closest_rid:
                if closest_ani > rep_radius[closest_rid].ani:
                    num_clustered += 1
                    clusters[closest_rid].append(
                        self.ClusteredGenome(gid=non_rid,
                                             ani=closest_ani,
                                             af=closest_af))

        # report the total processed rather than the loop index, which is
        # unbound when non_reps is empty
        sys.stdout.write(
            '==> Processed {:,} of {:,} genomes [no. clustered = {:,}].\r'.
            format(len(non_reps), len(non_reps), num_clustered))
        sys.stdout.flush()
        sys.stdout.write('\n')

        num_unclustered = len(non_reps) - num_clustered
        self.logger.info(
            'Assigned {:,} genomes to {:,} representatives; {:,} genomes remain unclustered.'
            .format(sum([len(clusters[rid]) for rid in clusters]),
                    len(clusters), num_unclustered))

        return clusters

    def run(self, named_rep_file, cur_gtdb_metadata_file,
            cur_genomic_path_file, uba_genome_paths, qc_passed_file,
            ncbi_genbank_assembly_file, untrustworthy_type_file,
            rep_mash_sketch_file, rep_ani_file, gtdb_type_strains_ledger):
        """Cluster genomes to selected GTDB representatives.

        Reads representatives from `named_rep_file` (columns 'Representative'
        and 'Proposed species'), computes their ANI radii from `rep_ani_file`,
        calculates ANI to all other genomes and writes
        'gtdb_rep_ani_radius.tsv' and 'gtdb_named_rep_clusters.tsv' to the
        output directory.
        """

        # create current GTDB genome sets
        self.logger.info('Creating current GTDB genome set.')
        cur_genomes = Genomes()
        cur_genomes.load_from_metadata_file(
            cur_gtdb_metadata_file,
            gtdb_type_strains_ledger=gtdb_type_strains_ledger,
            create_sp_clusters=False,
            uba_genome_file=uba_genome_paths,
            qc_passed_file=qc_passed_file,
            ncbi_genbank_assembly_file=ncbi_genbank_assembly_file,
            untrustworthy_type_ledger=untrustworthy_type_file)
        self.logger.info(
            f' ... current genome set contains {len(cur_genomes):,} genomes.')

        # get path to previous and current genomic FASTA files
        self.logger.info('Reading path to current genomic FASTA files.')
        cur_genomes.load_genomic_file_paths(cur_genomic_path_file)
        cur_genomes.load_genomic_file_paths(uba_genome_paths)

        # get representative genomes
        rep_gids = set()
        with open(named_rep_file) as f:
            header = f.readline().strip().split('\t')
            rep_index = header.index('Representative')
            # not used below, but header.index() still requires the column
            sp_index = header.index('Proposed species')

            for line in f:
                line_split = line.strip().split('\t')
                gid = line_split[rep_index]
                assert gid in cur_genomes
                rep_gids.add(gid)

        self.logger.info(
            'Identified representative genomes for {:,} species.'.format(
                len(rep_gids)))

        # calculate circumscription radius for representative genomes
        self.logger.info(
            'Determining ANI species circumscription for {:,} representative genomes.'
            .format(len(rep_gids)))
        rep_radius = self._rep_radius(rep_gids, rep_ani_file)
        write_rep_radius(
            rep_radius, cur_genomes,
            os.path.join(self.output_dir, 'gtdb_rep_ani_radius.tsv'))

        # calculate ANI between representative and non-representative genomes
        self.logger.info(
            'Calculating ANI between representative and non-representative genomes.'
        )
        ani_af = self._calculate_ani(cur_genomes, rep_gids,
                                     rep_mash_sketch_file)
        self.logger.info(
            ' ... ANI values determined for {:,} query genomes.'.format(
                len(ani_af)))
        self.logger.info(
            ' ... ANI values determined for {:,} genome pairs.'.format(
                sum([len(ani_af[qid]) for qid in ani_af])))

        # cluster remaining genomes to representatives
        non_reps = set(cur_genomes.genomes) - set(rep_radius)
        self.logger.info(
            'Clustering {:,} non-representatives to {:,} representatives using species-specific ANI radii.'
            .format(len(non_reps), len(rep_radius)))
        clusters = self._cluster(ani_af, non_reps, rep_radius)

        # write out clusters
        write_clusters(
            clusters, rep_radius, cur_genomes,
            os.path.join(self.output_dir, 'gtdb_named_rep_clusters.tsv'))
class IntraSpeciesDereplication(object):
    """Dereplicate GTDB species clusters using ANI/AF criteria."""
    def __init__(self, derep_ani, derep_af, max_genomes_per_sp, ani_cache_file,
                 cpus, output_dir):
        """Initialization.

        Parameters
        ----------
        derep_ani : float
            ANI at or above which a genome is clustered with an existing
            representative.
        derep_af : float
            AF at or above which a genome is clustered with an existing
            representative.
        max_genomes_per_sp : int
            Maximum number of genomes per species passed to FastANI. Larger
            species are first reduced using Mash.
        ani_cache_file : str
            ANI cache file passed through to FastANI.
        cpus : int
            Number of CPUs used by Mash and FastANI.
        output_dir : str
            Directory for output files.
        """

        check_dependencies(['fastANI', 'mash'])

        self.cpus = cpus
        self.output_dir = output_dir

        self.logger = logging.getLogger('timestamp')

        self.max_genomes_per_sp = max_genomes_per_sp
        self.derep_ani = derep_ani
        self.derep_af = derep_af

        # minimum MASH ANI value for dereplicating within a species
        self.min_mash_intra_sp_ani = derep_ani - 1.0

        self.mash = Mash(self.cpus)
        self.fastani = FastANI(ani_cache_file, cpus)

    def mash_sp_ani(self, gids, genomes, output_prefix):
        """Calculate pairwise Mash ANI estimates between genomes.

        Writes '<output_prefix>.msh', '.lst' and '.dst' files. Mash distances
        are computed with a cutoff corresponding to 95% ANI.

        Returns
        -------
        dict
            Mash ANI as returned by Mash.read_ani(), indexed as
            mash_ani[qid][rid].
        """

        INIT_MASH_ANI_FILTER = 95.0

        # create Mash sketch for all genomes
        mash_sketch_file = f'{output_prefix}.msh'
        genome_list_file = f'{output_prefix}.lst'
        self.mash.sketch(gids,
                         genomes.genomic_files,
                         genome_list_file,
                         mash_sketch_file,
                         silence=True)

        # get Mash distances
        mash_dist_file = f'{output_prefix}.dst'
        self.mash.dist_pairwise(float(100 - INIT_MASH_ANI_FILTER) / 100,
                                mash_sketch_file,
                                mash_dist_file,
                                silence=True)

        # read Mash distances
        mash_ani = self.mash.read_ani(mash_dist_file)

        # number of non-self pairs reported by Mash
        count = sum(1 for qid in mash_ani for rid in mash_ani[qid]
                    if qid != rid)

        self.logger.info(
            ' - identified {:,} pairs passing Mash filtering of ANI >= {:.1f}%.'
            .format(count, INIT_MASH_ANI_FILTER))

        return mash_ani

    def priority_score(self, gid, genomes):
        """Get priority score of genome.

        The score is the genome's assembly score, plus 1e4 if it is a GTDB
        type subspecies genome.
        """

        score = genomes[gid].score_assembly()
        if genomes[gid].is_gtdb_type_subspecies():
            score += 1e4

        return score

    def order_genomes_by_priority(self, gids, genomes):
        """Order genomes by decreasing priority score."""

        genome_priority = {}
        for gid in gids:
            genome_priority[gid] = self.priority_score(gid, genomes)

        sorted_by_priority = sorted(genome_priority.items(),
                                    key=operator.itemgetter(1),
                                    reverse=True)

        return [d[0] for d in sorted_by_priority]

    def mash_sp_dereplicate(self, mash_ani, sorted_gids, ani_threshold):
        """Dereplicate genomes in species using Mash distances.

        Genomes are processed in the given order. A genome becomes a new
        representative unless mash_ani[gid][rep] >= ani_threshold for an
        already selected representative.

        Returns
        -------
        list
            Selected representatives, in the order of `sorted_gids`.
        """

        # perform greedy selection of new representatives
        sp_reps = []
        for gid in sorted_gids:
            clustered = False
            for rep_id in sp_reps:
                if gid in mash_ani:
                    ani = mash_ani[gid].get(rep_id, 0)
                else:
                    ani = 0

                if ani >= ani_threshold:
                    clustered = True
                    break

            if not clustered:
                # genome was not assigned to an existing representative,
                # so make it a new representative genome
                sp_reps.append(gid)

        return sp_reps

    def dereplicate_species(self, species, rid, cids, genomes, mash_out_dir):
        """Dereplicate genomes within a GTDB species.

        Genomes are ordered by priority, with the species representative
        first. They are then greedily dereplicated: a genome becomes a new
        representative unless it has ANI >= self.derep_ani and
        AF >= self.derep_af to an existing one. Species with more than
        self.max_genomes_per_sp genomes are first reduced using Mash ANI.

        Parameters
        ----------
        species : str
            GTDB species name. `species[3:]` is used to name the Mash output
            files (presumably stripping an 's__' prefix).
        rid : str
            Representative genome of the species.
        cids : set
            Genomes in the species cluster.
        genomes : Genomes
        mash_out_dir : str
            Directory for Mash output files.

        Returns
        -------
        dict
            Maps each selected representative to the genomes assigned to it,
            with the representative itself listed first.
        """

        # greedily dereplicate genomes based on genome priority
        sorted_gids = self.order_genomes_by_priority(cids.difference([rid]),
                                                     genomes)
        sorted_gids = [rid] + sorted_gids

        # calculate Mash ANI between genomes (empty when there is only the
        # representative)
        mash_ani = {}
        if len(sorted_gids) > 1:
            # calculate MASH distances between genomes
            out_prefix = os.path.join(mash_out_dir,
                                      species[3:].lower().replace(' ', '_'))
            mash_ani = self.mash_sp_ani(sorted_gids, genomes, out_prefix)

        # perform initial dereplication using Mash for species with excessive
        # numbers of genomes
        if len(sorted_gids) > self.max_genomes_per_sp:
            self.logger.info(
                ' - limiting species to <={:,} genomes based on priority and Mash dereplication.'
                .format(self.max_genomes_per_sp))

            prev_mash_rep_gids = None
            for ani_threshold in [
                    99.75, 99.5, 99.25, 99.0, 98.75, 98.5, 98.25, 98.0, 97.75,
                    97.5, 97.0, 96.5, 96.0, 95.0, None
            ]:
                if ani_threshold is None:
                    # no threshold reduced the species sufficiently, so keep
                    # the highest priority genomes from the final (95% ANI)
                    # Mash dereplication
                    self.logger.warning(
                        ' - selected {:,} highest priority genomes from final Mash dereplication.'
                        .format(self.max_genomes_per_sp))
                    sorted_gids = mash_rep_gids[0:self.max_genomes_per_sp]
                    break

                mash_rep_gids = self.mash_sp_dereplicate(
                    mash_ani, sorted_gids, ani_threshold)

                self.logger.info(
                    ' - dereplicated {} from {:,} to {:,} genomes at {:.2f}% ANI using Mash.'
                    .format(species, len(cids), len(mash_rep_gids),
                            ani_threshold))

                if len(mash_rep_gids) <= self.max_genomes_per_sp:
                    if not prev_mash_rep_gids:
                        # corner case where dereplication is occurring at 99.75%
                        prev_mash_rep_gids = sorted_gids

                    # select maximum allowed number of genomes by taking all genomes in the
                    # current Mash dereplicated set and then the highest priority genomes in the
                    # previous Mash dereplicated set which have not been selected
                    cur_sel_gids = set(mash_rep_gids)
                    prev_sel_gids = set(prev_mash_rep_gids)
                    num_prev_to_sel = self.max_genomes_per_sp - len(
                        cur_sel_gids)
                    num_prev_selected = 0
                    sel_sorted_gids = []
                    for gid in sorted_gids:
                        if gid in cur_sel_gids:
                            sel_sorted_gids.append(gid)
                        elif (gid in prev_sel_gids
                              and num_prev_selected < num_prev_to_sel):
                            num_prev_selected += 1
                            sel_sorted_gids.append(gid)

                        if len(sel_sorted_gids) == self.max_genomes_per_sp:
                            break

                    assert len(cur_sel_gids - set(sel_sorted_gids)) == 0
                    assert num_prev_to_sel == num_prev_selected
                    assert len(sel_sorted_gids) == self.max_genomes_per_sp

                    sorted_gids = sel_sorted_gids
                    self.logger.info(
                        ' - selected {:,} highest priority genomes from Mash dereplication at an ANI = {:.2f}%.'
                        .format(len(sorted_gids), ani_threshold))
                    break

                prev_mash_rep_gids = mash_rep_gids

        # calculate FastANI ANI/AF between genomes passing Mash filtering
        ani_pairs = set()
        for gid1, gid2 in permutations(sorted_gids, 2):
            if gid1 in mash_ani and gid2 in mash_ani[gid1]:
                if mash_ani[gid1][gid2] >= self.min_mash_intra_sp_ani:
                    ani_pairs.add((gid1, gid2))
                    ani_pairs.add((gid2, gid1))

        self.logger.info(
            ' - calculating FastANI between {:,} pairs with Mash ANI >= {:.1f}%.'
            .format(len(ani_pairs), self.min_mash_intra_sp_ani))
        ani_af = self.fastani.pairs(ani_pairs,
                                    genomes.genomic_files,
                                    report_progress=False,
                                    check_cache=True)
        self.fastani.write_cache(silence=True)

        # perform greedy dereplication
        sp_reps = []
        for gid in sorted_gids:
            # determine if genome clusters with existing representative
            clustered = False
            for rep_gid in sp_reps:
                ani, af = FastANI.symmetric_ani(ani_af, gid, rep_gid)

                if ani >= self.derep_ani and af >= self.derep_af:
                    clustered = True
                    break

            if not clustered:
                sp_reps.append(gid)

        self.logger.info(
            ' - dereplicated {} from {:,} to {:,} genomes.'.format(
                species, len(sorted_gids), len(sp_reps)))

        # assign clustered genomes to most similar representative
        subsp_clusters = {}
        for rep_gid in sp_reps:
            subsp_clusters[rep_gid] = [rep_gid]

        non_rep_gids = set(sorted_gids) - set(sp_reps)
        for gid in non_rep_gids:
            closest_rid = None
            max_ani = 0
            max_af = 0
            for rep_gid in sp_reps:
                ani, af = FastANI.symmetric_ani(ani_af, gid, rep_gid)
                if ((ani > max_ani and af >= self.derep_af) or
                    (ani == max_ani and af >= max_af and af >= self.derep_af)):
                    max_ani = ani
                    max_af = af
                    closest_rid = rep_gid

            # every non-representative was clustered with some representative
            # above, so a closest representative must exist
            assert closest_rid is not None
            subsp_clusters[closest_rid].append(gid)

        return subsp_clusters

    def derep_sp_clusters(self, genomes):
        """Dereplicate each GTDB species cluster.

        Returns
        -------
        dict
            Maps each GTDB species name to the result of
            dereplicate_species() for that species.
        """

        mash_out_dir = os.path.join(self.output_dir, 'mash')
        os.makedirs(mash_out_dir, exist_ok=True)

        derep_genomes = {}
        for rid, cids in genomes.sp_clusters.items():
            species = genomes[rid].gtdb_taxa.species

            self.logger.info(
                'Dereplicating {} with {:,} genomes [{:,} of {:,} ({:.2f}%) species].'
                .format(species, len(cids), len(derep_genomes),
                        len(genomes.sp_clusters),
                        len(derep_genomes) * 100.0 / len(genomes.sp_clusters)))

            subsp_clusters = self.dereplicate_species(species, rid, cids,
                                                      genomes, mash_out_dir)

            derep_genomes[species] = subsp_clusters

        return derep_genomes

    def run(self, gtdb_metadata_file, genomic_path_file):
        """Dereplicate GTDB species clusters using ANI/AF criteria.

        Writes one row per selected representative to 'subsp_clusters.tsv'
        in the output directory.
        """

        # create GTDB genome sets
        self.logger.info('Creating GTDB genome set.')
        genomes = Genomes()
        genomes.load_from_metadata_file(gtdb_metadata_file)
        genomes.load_genomic_file_paths(genomic_path_file)
        self.logger.info(
            ' - genome set has {:,} species clusters spanning {:,} genomes.'.
            format(len(genomes.sp_clusters),
                   genomes.sp_clusters.total_num_genomes()))

        # dereplicate each species cluster
        self.logger.info(
            'Performing dereplication with ANI={:.1f}, AF={:.2f}, Mash ANI={:.2f}, max genomes={:,}.'
            .format(self.derep_ani, self.derep_af, self.min_mash_intra_sp_ani,
                    self.max_genomes_per_sp))
        derep_genomes = self.derep_sp_clusters(genomes)

        # write out `subspecies` clusters; the header has one column per
        # field written in each row below
        out_file = os.path.join(self.output_dir, 'subsp_clusters.tsv')
        with open(out_file, 'w') as fout:
            fout.write(
                'Genome ID\tGTDB Species\tGTDB Taxonomy\tPriority score\tNo. clustered genomes\tClustered genomes\n'
            )
            for species, subsp_clusters in derep_genomes.items():
                for rid, cids in subsp_clusters.items():
                    assert species == genomes[rid].gtdb_taxa.species
                    fout.write('{}\t{}\t{}\t{:.3f}\t{}\t{}\n'.format(
                        rid, genomes[rid].gtdb_taxa.species,
                        genomes[rid].gtdb_taxa, self.priority_score(rid, genomes),
                        len(cids), ','.join(cids)))
Esempio n. 5
0
class ClusterNamedTypes(object):
    """Cluster genomes to selected GTDB type genomes."""
    def __init__(self, ani_sp, af_sp, ani_cache_file, cpus, output_dir):
        """Store clustering thresholds and create the FastANI wrapper.

        check_dependencies() is called for 'fastANI' and 'mash' before any
        attribute is set.

        Parameters
        ----------
        ani_sp : float
            Default ANI circumscription radius of each type genome.
        af_sp : float
            Minimum AF required for a pair to be considered.
        ani_cache_file : str
            ANI cache file passed through to FastANI.
        cpus : int
            Number of CPUs used by Mash and FastANI.
        output_dir : str
            Directory for intermediate and output files.
        """

        check_dependencies(['fastANI', 'mash'])

        # run-time settings
        self.cpus, self.output_dir = cpus, output_dir
        self.logger = logging.getLogger('timestamp')

        # species circumscription criteria
        self.ani_sp, self.af_sp = ani_sp, af_sp

        # neighbour ANI above which an error is logged, and the Mash ANI
        # pre-filter applied before FastANI
        self.max_ani_neighbour = 97.0
        self.min_mash_ani = 90.0

        self.ClusteredGenome = namedtuple('ClusteredGenome', 'ani af gid')

        self.fastani = FastANI(ani_cache_file, cpus)

    def _type_genome_radius(self, type_gids, type_genome_ani_file):
        """Determine the ANI circumscription radius of each type genome.

        Every type genome starts with a radius of self.ani_sp. For each row
        between two type genomes, the radius of 'Type genome 1' is replaced
        by the pair's ANI when that ANI exceeds the current radius and the
        AF is >= self.af_sp.

        Parameters
        ----------
        type_gids : set
            Type genome IDs; rows involving other genomes are ignored.
        type_genome_ani_file : str
            Tab-separated file with 'Type genome 1', 'Type genome 2', 'ANI'
            and 'AF' columns.

        Returns
        -------
        dict
            GenomeRadius for each type genome.
        """

        type_radius = {
            gid: GenomeRadius(ani=self.ani_sp, af=None, neighbour_gid=None)
            for gid in type_gids
        }

        with open(type_genome_ani_file) as f:
            header = f.readline().strip().split('\t')
            col_gid1 = header.index('Type genome 1')
            col_gid2 = header.index('Type genome 2')
            col_ani = header.index('ANI')
            col_af = header.index('AF')

            for line in f:
                fields = line.strip().split('\t')
                gid1 = fields[col_gid1]
                gid2 = fields[col_gid2]

                if gid1 not in type_gids or gid2 not in type_gids:
                    continue

                ani = float(fields[col_ani])
                af = float(fields[col_af])

                # only neighbours with an ANI above the current radius matter
                if not ani > type_radius[gid1].ani:
                    continue

                if af < self.af_sp:
                    # insufficient AF: the pair is ignored, with a warning
                    # when it would otherwise meet the species ANI
                    if ani >= self.ani_sp:
                        self.logger.warning(
                            'ANI for %s and %s is >%.2f, but AF <%.2f [pair skipped].'
                            % (gid1, gid2, ani, af))
                    continue

                if ani > self.max_ani_neighbour:
                    self.logger.error('ANI neighbour %s is >%.2f for %s.' %
                                      (gid2, ani, gid1))

                type_radius[gid1] = GenomeRadius(ani=ani,
                                                 af=af,
                                                 neighbour_gid=gid2)

        radii = [radius.ani for radius in type_radius.values()]
        self.logger.info(
            'ANI circumscription radius: min=%.2f, mean=%.2f, max=%.2f' %
            (min(radii), np_mean(radii), max(radii)))

        return type_radius

    def _calculate_ani(self, type_gids, genome_files, ncbi_taxonomy,
                       type_genome_sketch_file):
        """Calculate ANI between type and non-type genomes.

        Mash is used as a pre-filter so FastANI is only run on genome pairs
        with a Mash ANI >= self.min_mash_ani. The FastANI results are also
        pickled to 'ani_af_type_vs_nontype.pkl' in the output directory.

        Parameters
        ----------
        type_gids : set
            Type genome IDs.
        genome_files : dict
            Keyed by genome ID (presumably mapping to genomic FASTA paths).
            Every key not in `type_gids` is treated as a non-type genome.
        ncbi_taxonomy : dict
            Not used by this method.
        type_genome_sketch_file : str or None
            Existing Mash sketch of the type genomes. It is created in the
            output directory when not given or missing.

        Returns
        -------
        dict
            ANI/AF results as returned by FastANI.pairs().
        """

        mash = Mash(self.cpus)

        # create Mash sketch for type genomes
        if not type_genome_sketch_file or not os.path.exists(
                type_genome_sketch_file):
            type_genome_list_file = os.path.join(self.output_dir,
                                                 'gtdb_type_genomes.lst')
            type_genome_sketch_file = os.path.join(self.output_dir,
                                                   'gtdb_type_genomes.msh')
            mash.sketch(type_gids, genome_files, type_genome_list_file,
                        type_genome_sketch_file)

        # create Mash sketch for non-type genomes
        nontype_gids = set()
        for gid in genome_files:
            if gid not in type_gids:
                nontype_gids.add(gid)

        nontype_genome_list_file = os.path.join(self.output_dir,
                                                'gtdb_nontype_genomes.lst')
        nontype_genome_sketch_file = os.path.join(self.output_dir,
                                                  'gtdb_nontype_genomes.msh')
        mash.sketch(nontype_gids, genome_files, nontype_genome_list_file,
                    nontype_genome_sketch_file)

        # get Mash distances
        mash_dist_file = os.path.join(self.output_dir,
                                      'gtdb_type_vs_nontype_genomes.dst')
        mash.dist(
            float(100 - self.min_mash_ani) / 100, type_genome_sketch_file,
            nontype_genome_sketch_file, mash_dist_file)

        # read Mash distances
        mash_ani = mash.read_ani(mash_dist_file)

        # get pairs above Mash threshold (both orientations)
        mash_ani_pairs = []
        for qid in mash_ani:
            for rid in mash_ani[qid]:
                if mash_ani[qid][rid] >= self.min_mash_ani:
                    if qid != rid:
                        mash_ani_pairs.append((qid, rid))
                        mash_ani_pairs.append((rid, qid))

        self.logger.info(
            'Identified %d genome pairs with a Mash ANI >= %.1f%%.' %
            (len(mash_ani_pairs), self.min_mash_ani))

        # calculate ANI between pairs
        self.logger.info('Calculating ANI between %d genome pairs:' %
                         len(mash_ani_pairs))
        ani_af = self.fastani.pairs(mash_ani_pairs, genome_files)

        # keep a copy of the results; the file handle is closed explicitly
        pkl_file = os.path.join(self.output_dir, 'ani_af_type_vs_nontype.pkl')
        with open(pkl_file, 'wb') as f_pkl:
            pickle.dump(ani_af, f_pkl)

        return ani_af

    def _cluster(self, ani_af, nontype_gids, type_radius):
        """Cluster non-type genomes to type genomes using species specific ANI thresholds.

        Each non-type genome is compared against every type genome for which
        ``ani_af`` holds a value. Among type genomes with an AF >= ``self.af_sp``,
        the one with the highest ANI (ties broken by higher AF) is selected. The
        genome is assigned to it only if that ANI is strictly greater than the
        type genome's ANI radius.

        Parameters
        ----------
        ani_af : dict
            Nested ANI/AF results; ``ani_af[nontype_gid]`` must contain a
            type genome ID for that pair to be considered.
        nontype_gids : iterable
            Genome IDs to assign.
        type_radius : dict
            Type genome ID -> radius record exposing an ``ani`` attribute.

        Returns
        -------
        dict
            Type genome ID -> list of ClusteredGenome(gid, ani, af).
        """

        clusters = {rep_id: [] for rep_id in type_radius}

        num_genomes = len(nontype_gids)
        for idx, nontype_gid in enumerate(nontype_gids):
            if idx % 100 == 0:
                sys.stdout.write('==> Processed %d of %d genomes.\r' %
                                 (idx + 1, num_genomes))
                sys.stdout.flush()

            if nontype_gid not in ani_af:
                continue

            closest_type_gid = None
            closest_ani = 0
            closest_af = 0
            for type_gid in type_radius:
                if type_gid not in ani_af[nontype_gid]:
                    continue

                ani, af = symmetric_ani(ani_af, type_gid, nontype_gid)

                if af >= self.af_sp:
                    if ani > closest_ani or (ani == closest_ani
                                             and af > closest_af):
                        closest_type_gid = type_gid
                        closest_ani = ani
                        closest_af = af

            if (closest_type_gid is not None
                    and closest_ani > type_radius[closest_type_gid].ani):
                clusters[closest_type_gid].append(
                    self.ClusteredGenome(gid=nontype_gid,
                                         ani=closest_ani,
                                         af=closest_af))

        # report the final count; 'idx' is undefined when nontype_gids is empty
        # and would otherwise be one short of the total
        sys.stdout.write('==> Processed %d of %d genomes.\r' %
                         (num_genomes, num_genomes))
        sys.stdout.flush()
        sys.stdout.write('\n')

        self.logger.info(
            'Assigned %d genomes to representatives.' %
            sum([len(clusters[type_gid]) for type_gid in clusters]))

        return clusters

    def run(self, qc_file, metadata_file, genome_path_file,
            named_type_genome_file, type_genome_ani_file, mash_sketch_file,
            species_exception_file):
        """Cluster genomes to selected GTDB type genomes.

        Pipeline:
          1. read the genomes passing QC;
          2. read the type genome table ('Type genome' / 'NCBI species' columns);
          3. compute each type genome's ANI radius and write it out;
          4. keep only genomic FASTA paths for QC-passing genomes;
          5. read NCBI taxonomy;
          6. compute ANI between type and non-type genomes;
          7. assign non-type genomes to type genomes and write the clusters.
        """

        # step 1: quality-passing genomes
        self.logger.info('Reading QC file.')
        passed_qc = read_qc_file(qc_file)
        self.logger.info('Identified %d genomes passing QC.' % len(passed_qc))

        # step 2: type genomes and the NCBI species each represents
        type_gids = set()
        species_type_gid = {}
        with open(named_type_genome_file) as type_fh:
            columns = type_fh.readline().strip().split('\t')
            gid_col = columns.index('Type genome')
            species_col = columns.index('NCBI species')

            for row in type_fh:
                fields = row.strip().split('\t')
                type_gid = fields[gid_col]
                type_gids.add(type_gid)
                species_type_gid[type_gid] = fields[species_col]
        self.logger.info('Identified type genomes for %d species.' %
                         len(species_type_gid))

        # step 3: ANI circumscription radius of every type genome
        self.logger.info(
            'Determining ANI species circumscription for %d type genomes.' %
            len(type_gids))
        type_radius = self._type_genome_radius(type_gids, type_genome_ani_file)
        assert len(type_radius) == len(species_type_gid)

        radius_file = os.path.join(self.output_dir,
                                   'gtdb_type_genome_ani_radius.tsv')
        write_rep_radius(type_radius, species_type_gid, radius_file)

        # step 4: genomic FASTA paths, dropping genomes that failed QC
        self.logger.info('Reading path to genome FASTA files.')
        genome_files = read_genome_path(genome_path_file)
        self.logger.info('Read path for %d genomes.' % len(genome_files))
        failed_qc_gids = [gid for gid in genome_files if gid not in passed_qc]
        for gid in failed_qc_gids:
            genome_files.pop(gid)
        self.logger.info(
            'Considering %d genomes after removing unwanted User genomes.' %
            len(genome_files))
        assert len(genome_files) == len(passed_qc)

        # step 5: NCBI taxonomy strings
        self.logger.info('Reading NCBI taxonomy from GTDB metadata file.')
        ncbi_taxonomy, ncbi_update_count = read_gtdb_ncbi_taxonomy(
            metadata_file, species_exception_file)
        self.logger.info(
            'Read NCBI taxonomy for %d genomes with %d manually defined updates.'
            % (len(ncbi_taxonomy), ncbi_update_count))

        # step 6: ANI between type and non-type genomes
        self.logger.info('Calculating ANI between type and non-type genomes.')
        ani_af = self._calculate_ani(type_gids, genome_files, ncbi_taxonomy,
                                     mash_sketch_file)

        # step 7: assign the remaining genomes and write the clusters
        nontype_gids = set(genome_files).difference(type_radius)
        self.logger.info(
            'Clustering %d non-type genomes to type genomes using species specific ANI radii.'
            % len(nontype_gids))
        clusters = self._cluster(ani_af, nontype_gids, type_radius)

        clusters_file = os.path.join(self.output_dir,
                                     'gtdb_type_genome_clusters.tsv')
        write_clusters(clusters, type_radius, species_type_gid, clusters_file)
class UpdateClusterDeNovo(object):
    """Infer de novo species clusters and representatives for remaining genomes.

    Genomes not already part of a named GTDB species cluster are visited in
    order of decreasing quality score. Each one becomes a new representative
    unless it falls within the ANI radius of an existing de novo representative.
    Every non-representative genome is then assigned to the closest named or
    de novo representative whose ANI radius it satisfies.
    """

    def __init__(self, ani_sp, af_sp, ani_cache_file, cpus, output_dir):
        """Initialization.

        Parameters
        ----------
        ani_sp : float
            Default ANI radius given to genomes without a closer representative.
        af_sp : float
            Minimum alignment fraction for an ANI value to be considered.
        ani_cache_file : str
            ANI cache file passed to FastANI.
        cpus : int
            Number of CPUs used by Mash and FastANI.
        output_dir : str
            Directory for intermediate and output files.
        """

        check_dependencies(['fastANI', 'mash'])

        self.cpus = cpus
        self.output_dir = output_dir

        self.logger = logging.getLogger('timestamp')

        self.true_str = ['t', 'T', 'true', 'True']

        self.ani_sp = ani_sp
        self.af_sp = af_sp

        # Mash ANI estimate a genome pair must reach before an exact
        # FastANI calculation is performed
        self.min_mash_ani = 90.0

        self.fastani = FastANI(ani_cache_file, cpus)

    def _parse_named_clusters(self, named_cluster_file):
        """Parse named GTDB species clusters.

        Parameters
        ----------
        named_cluster_file : str
            Tab-separated file with one named species cluster per line.

        Returns
        -------
        tuple
            (rep_gids, rep_clustered_gids, rep_radius): set of representative
            genome IDs, set of genome IDs clustered with a representative, and
            dict mapping each representative to its GenomeRadius.
        """

        rep_gids = set()
        rep_clustered_gids = set()
        rep_radius = {}
        with open(named_cluster_file) as f:
            headers = f.readline().strip().split('\t')

            rep_index = headers.index('Representative')
            num_clustered_index = headers.index('No. clustered genomes')
            clustered_genomes_index = headers.index('Clustered genomes')
            closest_type_index = headers.index('Closest representative')
            ani_radius_index = headers.index('ANI radius')
            af_index = headers.index('AF closest')

            for line in f:
                line_split = line.strip().split('\t')

                rep_gid = line_split[rep_index]
                rep_gids.add(rep_gid)

                # the clustered genome list is only read when the count is positive
                num_clustered = int(line_split[num_clustered_index])
                if num_clustered > 0:
                    for gid in line_split[clustered_genomes_index].split(','):
                        rep_clustered_gids.add(gid.strip())

                rep_radius[rep_gid] = GenomeRadius(ani=float(line_split[ani_radius_index]),
                                                   af=float(line_split[af_index]),
                                                   neighbour_gid=line_split[closest_type_index])

        return rep_gids, rep_clustered_gids, rep_radius

    def _nonrep_radius(self, unclustered_gids, rep_gids, ani_af_rep_vs_nonrep):
        """Calculate circumscription radius for unclustered, nontype genomes.

        Every genome starts with radius ``self.ani_sp``. The radius is raised to
        the ANI of any named representative whose ANI exceeds it with an
        AF >= ``self.af_sp``.

        Parameters
        ----------
        unclustered_gids : iterable
            Genomes needing a radius.
        rep_gids : iterable
            Named representative genome IDs.
        ani_af_rep_vs_nonrep : str
            Pickle file with precomputed ANI/AF values. NOTE: pickle can
            execute arbitrary code; only load locally produced files.

        Returns
        -------
        dict
            Genome ID -> GenomeRadius.
        """

        # set radius for genomes to default values
        nonrep_radius = {}
        for gid in unclustered_gids:
            nonrep_radius[gid] = GenomeRadius(ani=self.ani_sp,
                                              af=None,
                                              neighbour_gid=None)

        with open(ani_af_rep_vs_nonrep, 'rb') as f:
            ani_af = pickle.load(f)

        # only representatives with a precomputed ANI value can change the
        # radius, so visit those directly rather than scanning every rep
        rep_gid_set = set(rep_gids)
        for nonrep_gid in unclustered_gids:
            if nonrep_gid not in ani_af:
                continue

            for rep_gid in rep_gid_set.intersection(ani_af[nonrep_gid]):
                ani, af = symmetric_ani(ani_af, nonrep_gid, rep_gid)

                if ani > nonrep_radius[nonrep_gid].ani and af >= self.af_sp:
                    nonrep_radius[nonrep_gid] = GenomeRadius(ani=ani,
                                                             af=af,
                                                             neighbour_gid=rep_gid)

        if nonrep_radius:
            radii = [d.ani for d in nonrep_radius.values()]
            self.logger.info('ANI circumscription radius: min={:.2f}, mean={:.2f}, max={:.2f}'.format(
                                min(radii),
                                np_mean(radii),
                                max(radii)))
        else:
            # min()/max() of an empty list would raise ValueError
            self.logger.warning('No unclustered genomes; ANI circumscription radius statistics not calculated.')

        return nonrep_radius

    def _mash_ani_unclustered(self, cur_genomes, gids):
        """Calculate pairwise Mash ANI estimates between genomes.

        Mash reports genomes under the IDs used when sketching, which may differ
        from the IDs used by ``cur_genomes`` (see ``user_uba_id_map``). The
        returned estimates are re-keyed to ``cur_genomes`` IDs so callers can
        look them up with the same IDs they iterate over. Self-pairs are dropped.

        Returns
        -------
        dict
            mash_ani[gid_a][gid_b] -> Mash ANI estimate.
        """

        mash = Mash(self.cpus)

        # create Mash sketch for potential representative genomes
        mash_nontype_sketch_file = os.path.join(self.output_dir, 'gtdb_unclustered_genomes.msh')
        genome_list_file = os.path.join(self.output_dir, 'gtdb_unclustered_genomes.lst')
        mash.sketch(gids, cur_genomes.genomic_files, genome_list_file, mash_nontype_sketch_file)

        # get Mash distances
        mash_dist_file = os.path.join(self.output_dir, 'gtdb_unclustered_genomes.dst')
        mash.dist_pairwise(float(100 - self.min_mash_ani)/100, mash_nontype_sketch_file, mash_dist_file)

        # read Mash distances
        raw_mash_ani = mash.read_ani(mash_dist_file)

        # re-key to cur_genomes IDs and count pairs above the Mash threshold
        mash_ani = {}
        num_pairs = 0
        for qid in raw_mash_ani:
            n_qid = cur_genomes.user_uba_id_map.get(qid, qid)
            for rid in raw_mash_ani[qid]:
                n_rid = cur_genomes.user_uba_id_map.get(rid, rid)
                if n_qid == n_rid:
                    continue

                est = raw_mash_ani[qid][rid]
                row = mash_ani.setdefault(n_qid, {})
                row[n_rid] = max(est, row.get(n_rid, est))

                if est >= self.min_mash_ani:
                    num_pairs += 2

        self.logger.info('Identified {:,} genome pairs with a Mash ANI >= {:.1f}%.'.format(
                            num_pairs,
                            self.min_mash_ani))

        return mash_ani

    def _selected_rep_genomes(self,
                                cur_genomes,
                                nonrep_radius,
                                unclustered_qc_gids,
                                mash_ani):
        """Select de novo representatives for species clusters in a greedy fashion using species-specific ANI thresholds.

        Genomes are visited in order of decreasing ``score_type_strain()``, with
        ties broken by genome ID in descending order. A genome becomes a new
        representative unless its closest existing representative, chosen by ANI
        and then AF among pairs with AF >= ``self.af_sp``, has an ANI strictly
        above that representative's radius. As a side effect,
        ``nonrep_radius[cur_gid]`` is raised to the ANI of any representative
        exceeding it with sufficient AF.

        If ``<output_dir>/cluster_reps.tsv`` exists, representatives are read from
        it instead of being recomputed. Otherwise the selected representatives
        are written to it.

        Parameters
        ----------
        cur_genomes : Genomes
        nonrep_radius : dict
            Genome ID -> GenomeRadius; updated in place.
        unclustered_qc_gids : iterable
            Genomes eligible to become representatives.
        mash_ani : dict
            mash_ani[gid_a][gid_b] -> Mash ANI estimate, keyed by cur_genomes IDs.

        Returns
        -------
        set
            Genome IDs selected as de novo representatives.
        """

        # sort genomes by quality score
        self.logger.info('Selecting de novo representatives in a greedy manner based on quality.')
        q = {gid: cur_genomes[gid].score_type_strain() for gid in unclustered_qc_gids}
        q_sorted = sorted(q.items(), key=lambda kv: (kv[1], kv[0]), reverse=True)

        # greedily determine representatives for new species clusters
        cluster_rep_file = os.path.join(self.output_dir, 'cluster_reps.tsv')
        clusters = set()
        if not os.path.exists(cluster_rep_file):
            clustered_genomes = 0
            max_ani_pairs = 0
            for idx, (cur_gid, _score) in enumerate(q_sorted):

                # existing representatives close enough by Mash to warrant FastANI;
                # visiting only Mash hits avoids scanning every representative
                candidate_reps = []
                if cur_gid in mash_ani:
                    candidate_reps = [rep_gid
                                      for rep_gid, est in mash_ani[cur_gid].items()
                                      if rep_gid in clusters and est >= self.min_mash_ani]

                ani_pairs = []
                for rep_gid in candidate_reps:
                    ani_pairs.append((cur_gid, rep_gid))
                    ani_pairs.append((rep_gid, cur_gid))

                # determine if genome clusters with representative
                clustered = False
                if ani_pairs:
                    max_ani_pairs = max(max_ani_pairs, len(ani_pairs))

                    ani_af = self.fastani.pairs(ani_pairs, cur_genomes.genomic_files, report_progress=False)

                    closest_rep_gid = None
                    closest_rep_ani = 0
                    closest_rep_af = 0
                    for rep_gid in candidate_reps:
                        ani, af = symmetric_ani(ani_af, cur_gid, rep_gid)

                        if af >= self.af_sp:
                            if ani > closest_rep_ani or (ani == closest_rep_ani and af > closest_rep_af):
                                closest_rep_gid = rep_gid
                                closest_rep_ani = ani
                                closest_rep_af = af

                        if ani > nonrep_radius[cur_gid].ani and af >= self.af_sp:
                            nonrep_radius[cur_gid] = GenomeRadius(ani=ani,
                                                                  af=af,
                                                                  neighbour_gid=rep_gid)

                    if closest_rep_gid is not None and closest_rep_ani > nonrep_radius[closest_rep_gid].ani:
                        clustered = True

                if not clustered:
                    # genome is a new species cluster representative
                    clusters.add(cur_gid)
                else:
                    clustered_genomes += 1

                if (idx+1) % 10 == 0 or idx+1 == len(q_sorted):
                    statusStr = '-> Clustered {:,} of {:,} ({:.2f}%) genomes [ANI pairs: {:,}; clustered genomes: {:,}; clusters: {:,}].'.format(
                                    idx+1,
                                    len(q_sorted),
                                    float(idx+1)*100/len(q_sorted),
                                    max_ani_pairs,
                                    clustered_genomes,
                                    len(clusters)).ljust(96)
                    sys.stdout.write('{}\r'.format(statusStr))
                    sys.stdout.flush()
                    max_ani_pairs = 0
            sys.stdout.write('\n')

            # write out selected cluster representative
            with open(cluster_rep_file, 'w') as fout:
                for gid in clusters:
                    fout.write('{}\n'.format(gid))
        else:
            # read cluster reps from file, ignoring blank lines
            self.logger.warning('Using previously determined cluster representatives.')
            with open(cluster_rep_file) as fin:
                for line in fin:
                    gid = line.strip()
                    if gid:
                        clusters.add(gid)

        self.logger.info('Selected {:,} representative genomes for de novo species clusters.'.format(len(clusters)))

        return clusters

    def _cluster_genomes(self,
                            cur_genomes,
                            de_novo_rep_gids,
                            named_rep_gids,
                            final_cluster_radius):
        """Cluster all non-representative genomes to named and de novo representatives.

        Candidate pairs are those with a Mash ANI >= ``self.min_mash_ani``. Each
        non-representative is assigned to the representative with the highest
        ANI (ties: higher AF) among those with ANI >= the representative's radius
        and AF >= ``self.af_sp``. Results are also pickled to ``clusters.pkl`` and
        ``ani_af_rep_vs_nonrep.de_novo.pkl`` in ``output_dir``.

        Returns
        -------
        tuple
            (clusters, ani_af): representative ID -> list of ClusteredGenome,
            and the FastANI results for all candidate pairs.
        """

        all_reps = de_novo_rep_gids.union(named_rep_gids)
        nonrep_gids = set(cur_genomes.genomes.keys()) - all_reps
        self.logger.info('Clustering {:,} genomes to {:,} named and de novo representatives.'.format(
                            len(nonrep_gids), len(all_reps)))

        # calculate MASH distance between non-representatives and representatives genomes
        mash = Mash(self.cpus)

        mash_rep_sketch_file = os.path.join(self.output_dir, 'gtdb_rep_genomes.msh')
        rep_genome_list_file = os.path.join(self.output_dir, 'gtdb_rep_genomes.lst')
        mash.sketch(all_reps, cur_genomes.genomic_files, rep_genome_list_file, mash_rep_sketch_file)

        mash_none_rep_sketch_file = os.path.join(self.output_dir, 'gtdb_nonrep_genomes.msh')
        non_rep_file = os.path.join(self.output_dir, 'gtdb_nonrep_genomes.lst')
        mash.sketch(nonrep_gids, cur_genomes.genomic_files, non_rep_file, mash_none_rep_sketch_file)

        # get Mash distances
        mash_dist_file = os.path.join(self.output_dir, 'gtdb_rep_vs_nonrep_genomes.dst')
        mash.dist(float(100 - self.min_mash_ani)/100,
                    mash_rep_sketch_file,
                    mash_none_rep_sketch_file,
                    mash_dist_file)

        # read Mash distances
        mash_ani = mash.read_ani(mash_dist_file)

        clusters = {gid: [] for gid in all_reps}

        # genome pairs passing the Mash threshold (expressed with cur_genomes IDs)
        # and the candidate representatives of each non-representative genome
        mash_ani_pairs = []
        candidate_reps = {}
        for qid in mash_ani:
            n_qid = cur_genomes.user_uba_id_map.get(qid, qid)
            assert n_qid in nonrep_gids

            for rid in mash_ani[qid]:
                n_rid = cur_genomes.user_uba_id_map.get(rid, rid)
                assert n_rid in all_reps

                if (mash_ani[qid][rid] >= self.min_mash_ani
                        and n_qid != n_rid):
                    mash_ani_pairs.append((n_qid, n_rid))
                    mash_ani_pairs.append((n_rid, n_qid))
                    candidate_reps.setdefault(n_qid, set()).add(n_rid)

        self.logger.info('Calculating ANI between {:,} species clusters and {:,} unclustered genomes ({:,} pairs):'.format(
                            len(clusters),
                            len(nonrep_gids),
                            len(mash_ani_pairs)))
        ani_af = self.fastani.pairs(mash_ani_pairs, cur_genomes.genomic_files)

        # assign genomes to closest representatives
        # that is within the representatives ANI radius
        self.logger.info('Assigning genomes to closest representative.')
        for idx, cur_gid in enumerate(nonrep_gids):
            closest_rep_gid = None
            closest_rep_ani = 0
            closest_rep_af = 0

            # closest representative with sufficient AF regardless of radius;
            # only used to explain assignment failures
            nearest_rep_gid = None
            nearest_rep_ani = 0
            nearest_rep_af = 0

            for rep_gid in candidate_reps.get(cur_gid, ()):
                ani, af = symmetric_ani(ani_af, cur_gid, rep_gid)
                if af < self.af_sp:
                    continue

                if ani > nearest_rep_ani or (ani == nearest_rep_ani and af > nearest_rep_af):
                    nearest_rep_gid = rep_gid
                    nearest_rep_ani = ani
                    nearest_rep_af = af

                if ani >= final_cluster_radius[rep_gid].ani:
                    if ani > closest_rep_ani or (ani == closest_rep_ani and af > closest_rep_af):
                        closest_rep_gid = rep_gid
                        closest_rep_ani = ani
                        closest_rep_af = af

            if closest_rep_gid is not None:
                clusters[closest_rep_gid].append(ClusteredGenome(gid=cur_gid,
                                                                 ani=closest_rep_ani,
                                                                 af=closest_rep_af))
            else:
                self.logger.warning('Failed to assign genome {} to representative.'.format(cur_gid))
                if nearest_rep_gid is not None:
                    self.logger.warning(' ...closest_rep_gid = {}'.format(nearest_rep_gid))
                    self.logger.warning(' ...closest_rep_ani = {:.2f}'.format(nearest_rep_ani))
                    self.logger.warning(' ...closest_rep_af = {:.2f}'.format(nearest_rep_af))
                    self.logger.warning(' ...closest rep radius = {:.2f}'.format(final_cluster_radius[nearest_rep_gid].ani))
                else:
                    self.logger.warning(' ...no representative with an AF >{:.2f} identified.'.format(self.af_sp))

            statusStr = '-> Assigned {:,} of {:,} ({:.2f}%) genomes.'.format(idx+1,
                                                                            len(nonrep_gids),
                                                                            float(idx+1)*100/len(nonrep_gids)).ljust(86)
            sys.stdout.write('{}\r'.format(statusStr))
            sys.stdout.flush()
        sys.stdout.write('\n')

        with open(os.path.join(self.output_dir, 'clusters.pkl'), 'wb') as f:
            pickle.dump(clusters, f)
        with open(os.path.join(self.output_dir, 'ani_af_rep_vs_nonrep.de_novo.pkl'), 'wb') as f:
            pickle.dump(ani_af, f)

        return clusters, ani_af

    def run(self, named_cluster_file,
                    cur_gtdb_metadata_file,
                    cur_genomic_path_file,
                    uba_genome_paths,
                    qc_passed_file,
                    ncbi_genbank_assembly_file,
                    untrustworthy_type_file,
                    ani_af_rep_vs_nonrep,
                    gtdb_type_strains_ledger):
        """Infer de novo species clusters and representatives for remaining genomes.

        Writes ``gtdb_clusters_de_novo.tsv`` and ``gtdb_ani_radius_de_novo.tsv``
        to ``output_dir``.
        """

        # create current GTDB genome sets
        self.logger.info('Creating current GTDB genome set.')
        cur_genomes = Genomes()
        cur_genomes.load_from_metadata_file(cur_gtdb_metadata_file,
                                                gtdb_type_strains_ledger=gtdb_type_strains_ledger,
                                                create_sp_clusters=False,
                                                uba_genome_file=uba_genome_paths,
                                                qc_passed_file=qc_passed_file,
                                                ncbi_genbank_assembly_file=ncbi_genbank_assembly_file,
                                                untrustworthy_type_ledger=untrustworthy_type_file)
        self.logger.info(f' ... current genome set contains {len(cur_genomes):,} genomes.')

        # get path to previous and current genomic FASTA files
        self.logger.info('Reading path to current genomic FASTA files.')
        cur_genomes.load_genomic_file_paths(cur_genomic_path_file)
        cur_genomes.load_genomic_file_paths(uba_genome_paths)

        # determine representatives and genomes clustered to each representative
        self.logger.info('Reading named GTDB species clusters.')
        named_rep_gids, rep_clustered_gids, rep_radius = self._parse_named_clusters(named_cluster_file)
        self.logger.info(' ... identified {:,} representative genomes.'.format(len(named_rep_gids)))
        self.logger.info(' ... identified {:,} clustered genomes.'.format(len(rep_clustered_gids)))

        # determine genomes left to be clustered
        unclustered_gids = set(cur_genomes.genomes.keys()) - named_rep_gids - rep_clustered_gids
        self.logger.info('Identified {:,} unclustered genomes passing QC.'.format(len(unclustered_gids)))

        # establish closest representative for each unclustered genome
        self.logger.info('Determining ANI circumscription for {:,} unclustered genomes.'.format(len(unclustered_gids)))
        nonrep_radius = self._nonrep_radius(unclustered_gids, named_rep_gids, ani_af_rep_vs_nonrep)

        # calculate Mash ANI estimates between unclustered genomes
        self.logger.info('Calculating Mash ANI estimates between unclustered genomes.')
        mash_anis = self._mash_ani_unclustered(cur_genomes, unclustered_gids)

        # select de novo species representatives in a greedy fashion based on genome quality
        de_novo_rep_gids = self._selected_rep_genomes(cur_genomes,
                                                        nonrep_radius,
                                                        unclustered_gids,
                                                        mash_anis)

        # cluster all non-representative genomes to representative genomes
        final_cluster_radius = rep_radius.copy()
        final_cluster_radius.update(nonrep_radius)

        final_clusters, ani_af = self._cluster_genomes(cur_genomes,
                                                        de_novo_rep_gids,
                                                        named_rep_gids,
                                                        final_cluster_radius)

        # remove genomes that are not representatives of a species cluster and then write out representative ANI radius
        for gid in set(final_cluster_radius) - set(final_clusters):
            del final_cluster_radius[gid]

        self.logger.info('Writing {:,} species clusters to file.'.format(len(final_clusters)))
        self.logger.info('Writing {:,} cluster radius information to file.'.format(len(final_cluster_radius)))

        write_clusters(final_clusters,
                        final_cluster_radius,
                        cur_genomes,
                        os.path.join(self.output_dir, 'gtdb_clusters_de_novo.tsv'))

        write_rep_radius(final_cluster_radius,
                            cur_genomes,
                            os.path.join(self.output_dir, 'gtdb_ani_radius_de_novo.tsv'))
# Example no. 7 (originally "Esempio n. 7")
# 0
class ClusterDeNovo(object):
    """Infer de novo species clusters and type genomes for remaining genomes."""

    def __init__(self, ani_sp, af_sp, ani_cache_file, cpus, output_dir):
        """Store clustering parameters and create the FastANI wrapper.

        The external fastANI and mash programs are checked for before any
        state is set.

        Parameters
        ----------
        ani_sp : float
            Default ANI radius used for genomes without a closer type genome.
        af_sp : float
            Minimum alignment fraction for an ANI value to be considered.
        ani_cache_file : str
            ANI cache file passed to FastANI.
        cpus : int
            Number of CPUs for Mash and FastANI.
        output_dir : str
            Directory for intermediate and output files.
        """

        check_dependencies(['fastANI', 'mash'])

        # runtime resources
        self.cpus, self.output_dir = cpus, output_dir
        self.logger = logging.getLogger('timestamp')

        # presumably spellings accepted as boolean true when parsing text
        # (not used within the visible part of this class)
        self.true_str = ['t', 'T', 'true', 'True']

        # species-level ANI / AF thresholds
        self.ani_sp, self.af_sp = ani_sp, af_sp

        # Mash ANI estimate required before an exact FastANI calculation
        self.min_mash_ani = 90.0

        self.ClusteredGenome = namedtuple('ClusteredGenome', 'ani af gid')

        self.fastani = FastANI(ani_cache_file, cpus)
        
    def _parse_type_clusters(self, type_genome_cluster_file):
        """Read the tab-separated type genome clustering table.

        Returns
        -------
        tuple
            (type_species, species_type_gid, type_gids, type_clustered_gids, type_radius):
            the set of NCBI species names; a dict mapping each type genome ID to
            its NCBI species; the set of type genome IDs; the set of genome IDs
            listed in the 'Clustered genomes' column; and a dict mapping each type
            genome ID to a GenomeRadius built from the 'ANI radius', 'AF closest'
            and 'Closest type genome' columns.
        """

        type_species = set()
        species_type_gid = {}
        type_gids = set()
        type_clustered_gids = set()
        type_radius = {}

        with open(type_genome_cluster_file) as cluster_fh:
            header = cluster_fh.readline().strip().split('\t')
            col = {name: header.index(name)
                   for name in ('NCBI species',
                                'Type genome',
                                'No. clustered genomes',
                                'Clustered genomes',
                                'Closest type genome',
                                'ANI radius',
                                'AF closest')}

            for row in cluster_fh:
                fields = row.strip().split('\t')

                sp = fields[col['NCBI species']]
                type_gid = fields[col['Type genome']]

                type_species.add(sp)
                type_gids.add(type_gid)
                species_type_gid[type_gid] = sp

                # the clustered genome list is only read when the count is positive
                if int(fields[col['No. clustered genomes']]) > 0:
                    type_clustered_gids.update(
                        gid.strip() for gid in fields[col['Clustered genomes']].split(','))

                type_radius[type_gid] = GenomeRadius(
                    ani=float(fields[col['ANI radius']]),
                    af=float(fields[col['AF closest']]),
                    neighbour_gid=fields[col['Closest type genome']])

        return type_species, species_type_gid, type_gids, type_clustered_gids, type_radius
        
    def _parse_synonyms(self, type_genome_synonym_file):
        """Parse synonyms."""
        
        synonyms = set()
        with open(type_genome_synonym_file) as f:
            headers = f.readline().strip().split('\t')
            
            synonym_index = headers.index('Synonym')
            
            for line in f:
                line_split = line.strip().split('\t')
                
                synonym = line_split[synonym_index]
                synonyms.add(synonym)
                
        return synonyms
        
    def _nontype_radius(self, unclustered_gids, type_gids, ani_af_nontype_vs_type):
        """Calculate circumscription radius for unclustered, nontype genomes."""
        
        # set type radius for all type genomes to default values
        nontype_radius = {}
        for gid in unclustered_gids:
            nontype_radius[gid] = GenomeRadius(ani = self.ani_sp, 
                                                     af = None,
                                                     neighbour_gid = None)

        # determine closest type ANI neighbour and restrict ANI radius as necessary
        ani_af = pickle.load(open(ani_af_nontype_vs_type, 'rb'))
        for nontype_gid in unclustered_gids:
            if nontype_gid not in ani_af:
                continue
                    
            for type_gid in type_gids:
                if type_gid not in ani_af[nontype_gid]:
                    continue
                    
                ani, af = symmetric_ani(ani_af, nontype_gid, type_gid)

                if ani > nontype_radius[nontype_gid].ani and af >= self.af_sp:
                    nontype_radius[nontype_gid] = GenomeRadius(ani = ani, 
                                                                 af = af,
                                                                 neighbour_gid = type_gid)
                    
        self.logger.info('ANI circumscription radius: min=%.2f, mean=%.2f, max=%.2f' % (
                                min([d.ani for d in nontype_radius.values()]), 
                                np_mean([d.ani for d in nontype_radius.values()]), 
                                max([d.ani for d in nontype_radius.values()])))
                        
        return nontype_radius
        
    def _mash_ani_unclustered(self, genome_files, gids):
        """Sketch the given genomes with Mash and return their pairwise ANI estimates.

        Sketch, genome list and distance files are written to ``output_dir``. Only
        distances within ``(100 - min_mash_ani) / 100`` are requested. The number of
        ordered pairs (excluding self-pairs) at or above ``min_mash_ani`` is logged.

        Returns
        -------
        dict
            Mash ANI estimates as returned by ``Mash.read_ani``.
        """

        mash = Mash(self.cpus)

        # sketch all candidate genomes into a single file
        sketch_file = os.path.join(self.output_dir, 'gtdb_unclustered_genomes.msh')
        list_file = os.path.join(self.output_dir, 'gtdb_unclustered_genomes.lst')
        mash.sketch(gids, genome_files, list_file, sketch_file)

        # all-vs-all distances, limited to the Mash ANI threshold
        dist_file = os.path.join(self.output_dir, 'gtdb_unclustered_genomes.dst')
        max_dist = float(100 - self.min_mash_ani) / 100
        mash.dist_pairwise(max_dist, sketch_file, dist_file)

        mash_ani = mash.read_ani(dist_file)

        # each qualifying unordered hit is counted in both directions
        num_pairs = 2 * sum(1
                            for qid, hits in mash_ani.items()
                            for rid, est in hits.items()
                            if est >= self.min_mash_ani and qid != rid)

        self.logger.info('Identified %d genome pairs with a Mash ANI >= %.1f%%.' % (num_pairs, self.min_mash_ani))

        return mash_ani
        
    def _selected_rep_genomes(self,
                                genome_files,
                                nontype_radius, 
                                unclustered_qc_gids, 
                                mash_ani,
                                quality_metadata,
                                rnd_type_genome):
        """Select representative genomes for species clusters in a greedy fashion using species-specific ANI thresholds.

        Genomes are visited in order (random, or by decreasing quality score).
        Each genome becomes a new representative unless its closest existing
        representative (by ANI, with AF >= ``self.af_sp``) has an ANI to it
        greater than that representative's radius. Only representatives with
        a Mash ANI estimate >= ``self.min_mash_ani`` are compared with FastANI.

        Args:
            genome_files: dict mapping genome ID to genomic FASTA path.
            nontype_radius: dict mapping genome ID to a ``GenomeRadius``. The
                entry of a visited genome is replaced in place whenever a
                representative with higher ANI and AF >= ``self.af_sp`` is found.
            unclustered_qc_gids: collection of genome IDs to consider.
            mash_ani: nested dict of Mash ANI estimates, ``mash_ani[a][b]``.
            quality_metadata: quality metadata passed to ``quality_score``.
            rnd_type_genome: if True, visit genomes in random order instead of
                by quality score.

        Returns:
            set of genome IDs selected as representatives.

        Side effects:
            Writes ``cluster_reps.tsv`` to the output directory. If that file
            already exists, its representatives are reused and no clustering
            is performed.
        """

        # order genomes randomly or by decreasing quality score
        if rnd_type_genome:
            self.logger.info('Selecting random de novo type genomes.')
            # random.sample() only accepts sequences (sets raise TypeError
            # from Python 3.11 on); sorting first also makes the order
            # reproducible for a fixed random seed
            candidate_gids = sorted(unclustered_qc_gids)
            sorted_gids = [(gid, 0) for gid in random.sample(candidate_gids, len(candidate_gids))]
        else:
            self.logger.info('Selecting de novo type genomes in a greedy manner based on quality.')
            qscore = quality_score(unclustered_qc_gids, quality_metadata)
            sorted_gids = sorted(qscore.items(), key=operator.itemgetter(1), reverse=True)

        # greedily determine representatives for new species clusters
        cluster_rep_file = os.path.join(self.output_dir, 'cluster_reps.tsv')
        clusters = set()
        if not os.path.exists(cluster_rep_file):
            self.logger.info('Clustering genomes to identify representatives.')
            clustered_genomes = 0
            max_ani_pairs = 0
            for idx, (cur_gid, _score) in enumerate(sorted_gids):

                # only compare against representatives with a sufficiently high Mash ANI estimate
                ani_pairs = []
                if cur_gid in mash_ani:
                    for rep_gid in clusters:
                        if mash_ani[cur_gid].get(rep_gid, 0) >= self.min_mash_ani:
                            ani_pairs.append((cur_gid, rep_gid))
                            ani_pairs.append((rep_gid, cur_gid))

                # determine if genome clusters with representative
                clustered = False
                if ani_pairs:
                    if len(ani_pairs) > max_ani_pairs:
                        max_ani_pairs = len(ani_pairs)

                    ani_af = self.fastani.pairs(ani_pairs, genome_files, report_progress=False)

                    closest_rep_gid = None
                    closest_rep_ani = 0
                    closest_rep_af = 0
                    for rep_gid in clusters:
                        ani, af = symmetric_ani(ani_af, cur_gid, rep_gid)

                        if af >= self.af_sp:
                            if ani > closest_rep_ani or (ani == closest_rep_ani and af > closest_rep_af):
                                closest_rep_gid = rep_gid
                                closest_rep_ani = ani
                                closest_rep_af = af

                        # shrink the genome's own radius to its closest representative
                        if ani > nontype_radius[cur_gid].ani and af >= self.af_sp:
                            nontype_radius[cur_gid] = GenomeRadius(ani = ani, 
                                                                         af = af,
                                                                         neighbour_gid = rep_gid)

                    if closest_rep_gid and closest_rep_ani > nontype_radius[closest_rep_gid].ani:
                        clustered = True

                if not clustered:
                    # genome is a new species cluster representative
                    clusters.add(cur_gid)
                else:
                    clustered_genomes += 1

                if (idx+1) % 10 == 0 or idx+1 == len(sorted_gids):
                    statusStr = '-> Clustered %d of %d (%.2f%%) genomes [ANI pairs: %d; clustered genomes: %d; clusters: %d].'.ljust(96) % (
                                    idx+1, 
                                    len(sorted_gids), 
                                    float(idx+1)*100/len(sorted_gids),
                                    max_ani_pairs,
                                    clustered_genomes,
                                    len(clusters))
                    sys.stdout.write('%s\r' % statusStr)
                    sys.stdout.flush()
                    max_ani_pairs = 0
            sys.stdout.write('\n')

            # write out selected cluster representatives
            with open(cluster_rep_file, 'w') as fout:
                for gid in clusters:
                    fout.write('%s\n' % gid)
        else:
            # read cluster reps from file
            self.logger.warning('Using previously determined cluster representatives.')
            with open(cluster_rep_file) as fin:
                for line in fin:
                    gid = line.strip()
                    if gid:
                        clusters.add(gid)

        self.logger.info('Selected %d representative genomes for de novo species clusters.' % len(clusters))

        return clusters
        
    def _cluster_genomes(self, 
                            genome_files,
                            rep_genomes,
                            type_gids, 
                            passed_qc,
                            final_cluster_radius):
        """Cluster all non-type/representative genomes to selected type/representatives genomes.

        Each QC-passing genome that is not a representative is assigned to the
        representative with the highest ANI (ties broken by AF). The ANI must
        be >= that representative's radius and the AF must be
        >= ``self.af_sp``. FastANI is only run for pairs whose Mash ANI
        estimate is >= ``self.min_mash_ani``.

        Args:
            genome_files: dict mapping genome ID to genomic FASTA path.
            rep_genomes: set of de novo representative genome IDs.
            type_gids: set of type genome IDs (also treated as representatives).
            passed_qc: set of genome IDs passing QC.
            final_cluster_radius: dict mapping representative ID to a radius
                object exposing an ``ani`` attribute.

        Returns:
            tuple ``(clusters, ani_af)``. ``clusters`` maps each representative
            ID to a list of ``self.ClusteredGenome`` records. ``ani_af`` holds
            the FastANI results for the compared pairs.
        """

        all_reps = rep_genomes.union(type_gids)

        # calculate MASH distance between non-type/representative genomes and selected type/representatives genomes
        mash = Mash(self.cpus)

        mash_type_rep_sketch_file = os.path.join(self.output_dir, 'gtdb_rep_genomes.msh')
        type_rep_genome_list_file = os.path.join(self.output_dir, 'gtdb_rep_genomes.lst')
        mash.sketch(all_reps, genome_files, type_rep_genome_list_file, mash_type_rep_sketch_file)

        mash_none_rep_sketch_file = os.path.join(self.output_dir, 'gtdb_nonrep_genomes.msh')
        type_none_rep_file = os.path.join(self.output_dir, 'gtdb_nonrep_genomes.lst')
        mash.sketch(passed_qc - all_reps, genome_files, type_none_rep_file, mash_none_rep_sketch_file)

        # get Mash distances
        mash_dist_file = os.path.join(self.output_dir, 'gtdb_rep_vs_nonrep_genomes.dst')
        mash.dist(float(100 - self.min_mash_ani)/100, mash_type_rep_sketch_file, mash_none_rep_sketch_file, mash_dist_file)

        # read Mash distances
        mash_ani = mash.read_ani(mash_dist_file)

        # calculate ANI between non-type/representative genomes and selected type/representatives genomes
        clusters = {}
        for gid in all_reps:
            clusters[gid] = []

        genomes_to_cluster = passed_qc - set(clusters)
        ani_pairs = []
        for gid in genomes_to_cluster:
            if gid in mash_ani:
                for rep_gid in clusters:
                    if mash_ani[gid].get(rep_gid, 0) >= self.min_mash_ani:
                        ani_pairs.append((gid, rep_gid))
                        ani_pairs.append((rep_gid, gid))

        self.logger.info('Calculating ANI between %d species clusters and %d unclustered genomes (%d pairs):' % (
                            len(clusters), 
                            len(genomes_to_cluster),
                            len(ani_pairs)))
        ani_af = self.fastani.pairs(ani_pairs, genome_files)

        # assign genomes to closest representatives 
        # that is within the representatives ANI radius
        self.logger.info('Assigning genomes to closest representative.')
        for idx, cur_gid in enumerate(genomes_to_cluster):
            closest_rep_gid = None
            closest_rep_ani = 0
            closest_rep_af = 0

            # closest representative meeting the AF criterion irrespective of
            # its ANI radius; only used to explain assignment failures
            nearest_gid = None
            nearest_ani = 0
            nearest_af = 0

            for rep_gid in clusters:
                ani, af = symmetric_ani(ani_af, cur_gid, rep_gid)
                if af < self.af_sp:
                    continue

                if ani > nearest_ani or (ani == nearest_ani and af > nearest_af):
                    nearest_gid = rep_gid
                    nearest_ani = ani
                    nearest_af = af

                if ani >= final_cluster_radius[rep_gid].ani:
                    if ani > closest_rep_ani or (ani == closest_rep_ani and af > closest_rep_af):
                        closest_rep_gid = rep_gid
                        closest_rep_ani = ani
                        closest_rep_af = af

            if closest_rep_gid:
                clusters[closest_rep_gid].append(self.ClusteredGenome(gid=cur_gid, 
                                                                        ani=closest_rep_ani, 
                                                                        af=closest_rep_af))
            else:
                self.logger.warning('Failed to assign genome %s to representative.' % cur_gid)
                if nearest_gid:
                    # a representative met the AF criterion, but the genome
                    # fell outside that representative's ANI radius
                    self.logger.warning(' ...closest_rep_gid = %s' % nearest_gid)
                    self.logger.warning(' ...closest_rep_ani = %.2f' % nearest_ani)
                    self.logger.warning(' ...closest_rep_af = %.2f' % nearest_af)
                    self.logger.warning(' ...closest rep radius = %.2f' % final_cluster_radius[nearest_gid].ani)
                else:
                    self.logger.warning(' ...no representative with an AF >%.2f identified.' % self.af_sp)

            statusStr = '-> Assigned %d of %d (%.2f%%) genomes.'.ljust(86) % (idx+1, 
                                                                                len(genomes_to_cluster), 
                                                                                float(idx+1)*100/len(genomes_to_cluster))
            sys.stdout.write('%s\r' % statusStr)
            sys.stdout.flush()
        sys.stdout.write('\n')

        return clusters, ani_af
        
    def _assign_species_names(self, clusters, names_in_use, gtdb_taxonomy, gtdb_user_to_genbank):
        """Assign a species name to each species cluster.

        Clusters are processed from largest to smallest. The most common GTDB
        species name among a cluster's genomes (representative included) is
        assigned if three conditions hold:

        * the cluster has such a name;
        * it is held by at least half of the genomes that have a species
          assignment;
        * it is not already in ``names_in_use``.

        Otherwise a placeholder ``s__<genus> sp<accession>`` is derived from
        the most common genus and the representative's accession.

        Args:
            clusters: dict mapping representative genome ID to a list of
                clustered genome records exposing a ``gid`` attribute.
            names_in_use: set of species names that may not be reused; it is
                updated in place with every name assigned here.
            gtdb_taxonomy: dict mapping genome ID to a list of GTDB rank
                strings (index 5 = genus, index 6 = species).
            gtdb_user_to_genbank: dict mapping GTDB User genome IDs
                (``U_`` prefix) to GenBank accessions.

        Returns:
            dict mapping representative genome ID to its assigned species name.

        Side effects:
            Writes ``gtdb_assigned_sp.tsv`` to the output directory and exits
            the process if a derived placeholder name is already in use.
        """

        cluster_sp_names = {}
        with open(os.path.join(self.output_dir, 'gtdb_assigned_sp.tsv'), 'w') as fout:
            fout.write('Representative genome\tAssigned species\tGTDB taxonomy\tNo. clustered genomes\tClustered GTDB genera\tClustered GTDB species\tSpecies name in use\tMost common name in use\tClustered genomes\n')
            for rid in sorted(clusters, key=lambda x: len(clusters[x]), reverse=True):
                clustered_gids = [c.gid for c in clusters[rid]]

                # find most common genus name in cluster
                gtdb_genera = [gtdb_taxonomy[gid][5] for gid in clustered_gids] + [gtdb_taxonomy[rid][5]]
                gtdb_genus_counter = Counter(gtdb_genera)
                gtdb_common_genus = None
                gtdb_common_genus_count = 0
                for genus, count in gtdb_genus_counter.most_common():
                    if genus != 'g__':
                        gtdb_common_genus = genus
                        gtdb_common_genus_count = count
                        break

                # in case of ties involving genus of representative genome, 
                # defer to classification of representative
                rep_genus = gtdb_taxonomy[rid][5]
                if gtdb_genus_counter[rep_genus] == gtdb_common_genus_count and rep_genus != 'g__':
                    gtdb_common_genus = rep_genus

                # get most common GTDB species name 
                gtdb_sp = [gtdb_taxonomy[gid][6] for gid in clustered_gids] + [gtdb_taxonomy[rid][6]]
                gtdb_sp_counter = Counter(gtdb_sp)
                gtdb_common_sp = None
                gtdb_common_sp_count = 0
                for sp, count in gtdb_sp_counter.most_common():
                    if sp != 's__':
                        gtdb_common_sp = sp
                        gtdb_common_sp_count = count
                        break

                most_common_in_use = gtdb_common_sp in names_in_use

                # assign common species if it occurs in >=50% of the clustered
                # genomes, excluding genomes with no species assignment; a
                # cluster without any named genome has no common species and
                # always receives a derived name
                min_req_genomes = 0.5*(sum(gtdb_sp_counter.values()) - gtdb_sp_counter.get('s__', 0))
                if (gtdb_common_sp is not None
                        and gtdb_common_sp_count >= min_req_genomes
                        and not most_common_in_use):
                    names_in_use.add(gtdb_common_sp)
                    cluster_sp_names[rid] = gtdb_common_sp
                else:
                    # derive new species name from genus, if possible, 
                    # and accession number of representative genome
                    genus = '{unresolved}'
                    if gtdb_common_genus and gtdb_common_genus != 'g__':
                        genus = gtdb_common_genus[3:]

                    acc = rid
                    if rid.startswith('U_'):
                        if rid in gtdb_user_to_genbank:
                            acc = gtdb_user_to_genbank[rid]
                        else:
                            # create accession from GTDB User ID of the form:
                            # U_<number>u.0 which will give 'sp<number>u'
                            acc = 'U_' + rid.replace('U_', '') + 'u.0'

                    # accession text between the last '_' and the last '.'
                    derived_sp = 's__' + '%s sp%s' % (genus, acc[acc.rfind('_')+1:acc.rfind('.')])
                    if derived_sp in names_in_use:
                        self.logger.error('Derived species name already in use: %s, %s' % (derived_sp, acc))
                        sys.exit(-1)

                    names_in_use.add(derived_sp)
                    cluster_sp_names[rid] = derived_sp

                fout.write('%s\t%s\t%s\t%d\t%s\t%s\t%s\t%s\t%s\n' % (
                            rid, 
                            cluster_sp_names[rid],
                            '; '.join(gtdb_taxonomy[rid]),
                            len(clustered_gids),
                            ', '.join("%s=%r" % (genus, count) for (genus, count) in gtdb_genus_counter.most_common()),
                            ', '.join("%s=%r" % (sp, count) for (sp, count) in gtdb_sp_counter.most_common()),
                            ', '.join("%s=%s" % (sp, sp in names_in_use) for sp, _count in gtdb_sp_counter.most_common()),
                            '%s=%d' % (gtdb_common_sp, gtdb_common_sp_count) if most_common_in_use else 'n/a',
                            ', '.join(clustered_gids)))

        return cluster_sp_names
        
    def _write_rep_info(self, 
                        clusters, 
                        cluster_sp_names, 
                        quality_metadata, 
                        genome_quality,
                        excluded_from_refseq_note,
                        ani_af,
                        output_file):
        """Write out information about selected representative genomes.

        Args:
            clusters: dict mapping representative genome ID to its clustered
                genomes, given as records with a ``gid`` attribute (as built by
                ``_cluster_genomes``) or as bare genome IDs.
            cluster_sp_names: dict mapping representative ID to species name.
            quality_metadata: dict mapping genome ID to a quality metadata record.
            genome_quality: dict mapping genome ID to its quality score.
            excluded_from_refseq_note: dict mapping genome ID to an NCBI
                "excluded from RefSeq" note; missing IDs are written as empty.
            ani_af: ANI/AF results queried through ``symmetric_ani``.
            output_file: path of the tab-separated table to write.
        """

        with open(output_file, 'w') as fout:
            fout.write('Species\tType genome\tNCBI assembly level\tNCBI genome category')
            fout.write('\tGenome size (bp)\tQuality score\tCompleteness (%)\tContamination (%)\tNo. scaffolds\tNo. contigs\tN50 contigs\tAmbiguous bases\tSSU count\tSSU length (bp)')
            fout.write('\tNo. genomes in cluster\tMean ANI\tMean AF\tMin ANI\tMin AF\tNCBI exclude from RefSeq\n')

            for gid in clusters:
                meta = quality_metadata[gid]
                fout.write('%s\t%s\t%s\t%s' % (
                            cluster_sp_names[gid], 
                            gid, 
                            meta.ncbi_assembly_level,
                            meta.ncbi_genome_category))

                fout.write('\t%d\t%.2f\t%.2f\t%.2f\t%d\t%d\t%.1f\t%d\t%d\t%d' % (
                                meta.genome_size,
                                genome_quality[gid], 
                                meta.checkm_completeness,
                                meta.checkm_contamination,
                                meta.scaffold_count,
                                meta.contig_count,
                                meta.n50_contigs,
                                meta.ambiguous_bases,
                                meta.ssu_count,
                                meta.ssu_length if meta.ssu_length else 0))

                anis = []
                afs = []
                for member in clusters[gid]:
                    # cluster members are genome records rather than bare IDs,
                    # so ANI must be looked up by the member's genome ID
                    member_gid = getattr(member, 'gid', member)
                    ani, af = symmetric_ani(ani_af, gid, member_gid)
                    anis.append(ani)
                    afs.append(af)

                if anis:
                    fout.write('\t%d\t%.1f\t%.2f\t%.1f\t%.2f\t%s\n' % (len(clusters[gid]),
                                                                        np_mean(anis), np_mean(afs),
                                                                        min(anis), min(afs),
                                                                        excluded_from_refseq_note.get(gid, '')))
                else:
                    fout.write('\t%d\t%s\t%s\t%s\t%s\t%s\n' % (len(clusters[gid]),
                                                                'n/a', 'n/a', 'n/a', 'n/a',
                                                                excluded_from_refseq_note.get(gid, '')))
        
    def _gtdb_user_genomes(self, gtdb_user_genomes_file, metadata_file):
        """Get map between GTDB User genomes and GenBank accessions.

        Args:
            gtdb_user_genomes_file: tab-separated file whose first column is a
                GenBank accession and whose fifth column is a UBA identifier.
            metadata_file: GTDB metadata file. A genome is linked to a UBA
                identifier when its organism name contains ``(UBA``; the
                identifier is the text after the first ``(`` up to the final
                character.

        Returns:
            dict mapping genome ID to GenBank accession for genomes whose UBA
            identifier appears in ``gtdb_user_genomes_file``.
        """

        uba_to_genbank = {}
        with open(gtdb_user_genomes_file) as f:
            for line in f:
                line_split = line.strip().split('\t')
                if line_split == ['']:
                    # ignore blank lines (e.g. a trailing newline at EOF)
                    continue
                gb_acc = line_split[0]
                uba_id = line_split[4]
                uba_to_genbank[uba_id] = gb_acc

        user_to_genbank = {}
        m = read_gtdb_metadata(metadata_file, ['organism_name'])
        for gid, metadata in m.items():
            if '(UBA' in str(metadata.organism_name):
                uba_id = metadata.organism_name[metadata.organism_name.find('(')+1:-1]
                if uba_id in uba_to_genbank:
                    user_to_genbank[gid] = uba_to_genbank[uba_id]

        return user_to_genbank

    def run(self, qc_file,
                metadata_file,
                gtdb_user_genomes_file,
                genome_path_file,
                type_genome_cluster_file,
                type_genome_synonym_file,
                ncbi_refseq_assembly_file,
                ncbi_genbank_assembly_file,
                ani_af_nontype_vs_type,
                species_exception_file,
                rnd_type_genome):
        """Infer de novo species clusters and type genomes for remaining genomes.

        The pipeline proceeds as follows:

        1. Genomes passing QC that are neither type genomes nor clustered with
           a type genome have an ANI radius established.
        2. Those genomes are greedily grouped into de novo species clusters.
        3. Every QC-passing non-representative genome is assigned to its
           closest type or de novo representative.
        4. The de novo clusters are named.
        5. Results are written to the output directory
           (``gtdb_rep_genome_info.tsv``, ``gtdb_clusters_final.tsv``,
           ``gtdb_ani_radius_final.tsv``).

        Args:
            qc_file: file of genomes passing QC (read by ``read_qc_file``).
            metadata_file: GTDB metadata file (taxonomy, quality, organism names).
            gtdb_user_genomes_file: tab-separated file whose first column is a
                GenBank accession and whose fifth column is a UBA identifier.
            genome_path_file: file giving the path to each genome's FASTA file.
            type_genome_cluster_file: species clusters around type genomes.
            type_genome_synonym_file: synonyms whose names may not be reused.
            ncbi_refseq_assembly_file: NCBI RefSeq assembly report.
            ncbi_genbank_assembly_file: NCBI GenBank assembly report.
            ani_af_nontype_vs_type: ANI/AF results between unclustered and type
                genomes, used to set the initial ANI radius.
            species_exception_file: manually defined NCBI species updates.
            rnd_type_genome: select de novo representatives randomly rather
                than by quality score.
        """

        # identify genomes failing quality criteria
        self.logger.info('Reading QC file.')
        passed_qc = read_qc_file(qc_file)
        self.logger.info('Identified %d genomes passing QC.' % len(passed_qc))

        # get NCBI taxonomy strings for each genome
        self.logger.info('Reading NCBI taxonomy from GTDB metadata file.')
        ncbi_taxonomy, ncbi_update_count = read_gtdb_ncbi_taxonomy(metadata_file, species_exception_file)
        gtdb_taxonomy = read_gtdb_taxonomy(metadata_file)
        self.logger.info('Read NCBI taxonomy for %d genomes with %d manually defined updates.' % (len(ncbi_taxonomy), ncbi_update_count))
        self.logger.info('Read GTDB taxonomy for %d genomes.' % len(gtdb_taxonomy))

        # parse NCBI assembly files
        self.logger.info('Parsing NCBI assembly files.')
        excluded_from_refseq_note = exclude_from_refseq(ncbi_refseq_assembly_file, ncbi_genbank_assembly_file)

        # get path to genome FASTA files, keeping only genomes that passed QC
        self.logger.info('Reading path to genome FASTA files.')
        genome_files = read_genome_path(genome_path_file)
        self.logger.info('Read path for %d genomes.' % len(genome_files))
        for gid in set(genome_files):
            if gid not in passed_qc:
                genome_files.pop(gid)
        self.logger.info('Considering %d genomes as potential representatives after removing unwanted User genomes.' % len(genome_files))
        assert(len(genome_files) == len(passed_qc))

        # determine type genomes and genomes clustered to type genomes
        type_species, species_type_gid, type_gids, type_clustered_gids, type_radius = self._parse_type_clusters(type_genome_cluster_file)
        assert(len(type_species) == len(type_gids))
        self.logger.info('Identified %d type genomes.' % len(type_gids))
        self.logger.info('Identified %d clustered genomes.' % len(type_clustered_gids))

        # calculate quality score for genomes
        self.logger.info('Parse quality statistics for all genomes.')
        quality_metadata = read_quality_metadata(metadata_file)

        # calculate genome quality score
        self.logger.info('Calculating genome quality score.')
        genome_quality = quality_score(quality_metadata.keys(), quality_metadata)

        # determine genomes left to be clustered
        unclustered_gids = passed_qc - type_gids - type_clustered_gids
        self.logger.info('Identified %d unclustered genomes passing QC.' % len(unclustered_gids))

        # establish closest type genome for each unclustered genome
        self.logger.info('Determining ANI circumscription for %d unclustered genomes.' % len(unclustered_gids))
        nontype_radius = self._nontype_radius(unclustered_gids, type_gids, ani_af_nontype_vs_type)

        # calculate Mash ANI estimates between unclustered genomes
        self.logger.info('Calculating Mash ANI estimates between unclustered genomes.')
        mash_anis = self._mash_ani_unclustered(genome_files, unclustered_gids)

        # select species representatives genomes in a greedy fashion based on genome quality
        # (nontype_radius may be updated in place by this step)
        rep_genomes = self._selected_rep_genomes(genome_files,
                                                    nontype_radius, 
                                                    unclustered_gids, 
                                                    mash_anis,
                                                    quality_metadata,
                                                    rnd_type_genome)

        # cluster all non-type/non-rep genomes to species type/rep genomes
        final_cluster_radius = type_radius.copy()
        final_cluster_radius.update(nontype_radius)

        final_clusters, ani_af = self._cluster_genomes(genome_files,
                                                        rep_genomes,
                                                        type_gids, 
                                                        passed_qc,
                                                        final_cluster_radius)

        # restrict to the de novo clusters (type clusters are already named)
        rep_clusters = {}
        for gid in rep_genomes:
            rep_clusters[gid] = final_clusters[gid]

        # get list of synonyms in order to restrict usage of species names
        synonyms = self._parse_synonyms(type_genome_synonym_file)
        self.logger.info('Identified %d synonyms.' % len(synonyms))

        # determine User genomes with NCBI accession number that may form species names
        gtdb_user_to_genbank = self._gtdb_user_genomes(gtdb_user_genomes_file, metadata_file)
        self.logger.info('Identified %d GTDB User genomes with NCBI accessions.' % len(gtdb_user_to_genbank))

        # assign species names to de novo species clusters
        names_in_use = synonyms.union(type_species)
        self.logger.info('Identified %d species names already in use.' % len(names_in_use))
        self.logger.info('Assigning species name to each de novo species cluster.')
        cluster_sp_names = self._assign_species_names(rep_clusters, 
                                                        names_in_use, 
                                                        gtdb_taxonomy,
                                                        gtdb_user_to_genbank)

        # write out file with details about selected representative genomes
        self._write_rep_info(rep_clusters, 
                                cluster_sp_names,
                                quality_metadata,
                                genome_quality,
                                excluded_from_refseq_note,
                                ani_af,
                                os.path.join(self.output_dir, 'gtdb_rep_genome_info.tsv'))

        # remove genomes that are not representatives of a species cluster and then write out representative ANI radius
        for gid in set(final_cluster_radius) - set(final_clusters):
            del final_cluster_radius[gid]

        # combine into a new dict so the de novo name mapping is not mutated
        all_species = dict(cluster_sp_names)
        all_species.update(species_type_gid)

        self.logger.info('Writing %d species clusters to file.' % len(all_species))
        self.logger.info('Writing %d cluster radius information to file.' % len(final_cluster_radius))

        write_clusters(final_clusters, 
                        final_cluster_radius, 
                        all_species, 
                        os.path.join(self.output_dir, 'gtdb_clusters_final.tsv'))

        write_rep_radius(final_cluster_radius, 
                            all_species, 
                            os.path.join(self.output_dir, 'gtdb_ani_radius_final.tsv'))
        
class UpdateErroneousNCBI(object):
    """Identify genomes with erroneous NCBI species assignments."""

    # specific epithets ignored when grouping genomes by NCBI species
    FORBIDDEN_NAMES = frozenset(['cyanobacterium'])

    def __init__(self, ani_ncbi_erroneous, ani_cache_file, cpus, output_dir):
        """Initialization.

        Args:
            ani_ncbi_erroneous: ANI threshold; genomes below it relative to the
                type strain genome of their NCBI species are flagged.
            ani_cache_file: ANI cache file handed to FastANI.
            cpus: number of CPUs handed to FastANI.
            output_dir: directory receiving the output tables.
        """

        self.output_dir = output_dir
        self.logger = logging.getLogger('timestamp')

        self.ani_ncbi_erroneous = ani_ncbi_erroneous
        self.fastani = FastANI(ani_cache_file, cpus)

    @staticmethod
    def _gid_to_rid(cur_clusters):
        """Map every genome listed in a cluster to that cluster's representative ID."""

        gid_to_rid = {}
        for rid, cids in cur_clusters.items():
            for cid in cids:
                gid_to_rid[cid] = rid

        return gid_to_rid

    def _ncbi_species_gids(self, cur_genomes):
        """Group genome IDs by NCBI species, skipping unassigned and forbidden species."""

        ncbi_sp_gids = defaultdict(list)
        for gid in cur_genomes:
            ncbi_species = cur_genomes[gid].ncbi_taxa.species
            ncbi_specific = specific_epithet(ncbi_species)

            if ncbi_species != 's__' and ncbi_specific not in self.FORBIDDEN_NAMES:
                ncbi_sp_gids[ncbi_species].append(gid)

        return ncbi_sp_gids

    def identify_misclassified_genomes_ani(self, cur_genomes, cur_clusters):
        """Identify genomes with erroneous NCBI species assignments, based on ANI to type strain genomes.

        A genome is flagged when all three of the following hold:

        * it has the NCBI species of a cluster representative that is an
          effective type strain;
        * it resides in a different GTDB cluster;
        * its ANI to that representative is below ``self.ani_ncbi_erroneous``.

        Args:
            cur_genomes: genome collection indexable by genome ID and exposing
                ``genomic_files``.
            cur_clusters: dict mapping representative ID to its clustered genome IDs.

        Returns:
            set of genome IDs with misclassified NCBI species assignments.
        """

        gid_to_rid = self._gid_to_rid(cur_clusters)
        ncbi_sp_gids = self._ncbi_species_gids(cur_genomes)

        # get NCBI species anchored by a type strain genome
        ncbi_type_anchored_species = {}
        for rid in cur_clusters:
            if cur_genomes[rid].is_effective_type_strain():
                ncbi_type_species = cur_genomes[rid].ncbi_taxa.species
                if ncbi_type_species != 's__':
                    ncbi_type_anchored_species[ncbi_type_species] = rid
        self.logger.info(
            ' - identified {:,} NCBI species anchored by a type strain genome.'
            .format(len(ncbi_type_anchored_species)))

        # identify genomes with erroneous NCBI species assignments
        out_file = os.path.join(
            self.output_dir, 'ncbi_misclassified_sp.ani_{}.tsv'.format(
                self.ani_ncbi_erroneous))
        misclassified_gids = set()
        with open(out_file, 'w') as fout:
            fout.write(
                'Genome ID\tNCBI species\tGenome cluster\tType species cluster\tANI to type strain\tAF to type strain\n'
            )

            for idx, (ncbi_species,
                      species_gids) in enumerate(ncbi_sp_gids.items()):
                if ncbi_species not in ncbi_type_anchored_species:
                    continue

                # genomes with the same NCBI species name as a type strain
                # genome, but residing in a different GTDB species cluster
                type_rid = ncbi_type_anchored_species[ncbi_species]
                gids_to_check = [
                    gid for gid in species_gids if gid_to_rid[gid] != type_rid
                ]
                if not gids_to_check:
                    continue

                gid_pairs = []
                for gid in gids_to_check:
                    gid_pairs.append((type_rid, gid))
                    gid_pairs.append((gid, type_rid))

                statusStr = '-> Establishing erroneous assignments for {} [ANI pairs: {:,}; {:,} of {:,} species].'.format(
                    ncbi_species, len(gid_pairs), idx + 1,
                    len(ncbi_sp_gids)).ljust(96)
                sys.stdout.write('{}\r'.format(statusStr))
                sys.stdout.flush()

                ani_af = self.fastani.pairs(gid_pairs,
                                            cur_genomes.genomic_files,
                                            report_progress=False,
                                            check_cache=True)

                for gid in gids_to_check:
                    ani, af = symmetric_ani(ani_af, type_rid, gid)
                    if ani < self.ani_ncbi_erroneous:
                        misclassified_gids.add(gid)
                        fout.write('{}\t{}\t{}\t{}\t{:.2f}\t{:.3f}\n'.format(
                            gid, ncbi_species, gid_to_rid[gid], type_rid, ani,
                            af))

        sys.stdout.write('\n')

        misclassified_species = set(
            [cur_genomes[gid].ncbi_taxa.species for gid in misclassified_gids])
        self.logger.info(
            ' - identified {:,} genomes from {:,} species as having misclassified NCBI species assignments.'
            .format(len(misclassified_gids), len(misclassified_species)))

        return misclassified_gids

    def identify_misclassified_genomes_cluster(self, cur_genomes,
                                               cur_clusters):
        """Identify genomes with erroneous NCBI species assignments, based on GTDB clustering of type strain genomes.

        A genome is flagged when it has the NCBI species of an effective type
        strain genome but resides in a different GTDB species cluster than
        that type strain genome.

        Args:
            cur_genomes: genome collection indexable by genome ID.
            cur_clusters: dict mapping representative ID to its clustered genome IDs.

        Returns:
            set of genome IDs with misclassified NCBI species assignments.
        """

        gid_to_rid = self._gid_to_rid(cur_clusters)
        ncbi_sp_gids = self._ncbi_species_gids(cur_genomes)

        # get NCBI species anchored by a type strain genome
        ncbi_type_anchored_species = {}
        for rid, cids in cur_clusters.items():
            for cid in cids:
                if not cur_genomes[cid].is_effective_type_strain():
                    continue

                # filter on the type strain's own species name
                ncbi_type_species = cur_genomes[cid].ncbi_taxa.species
                ncbi_specific = specific_epithet(ncbi_type_species)
                if ncbi_type_species == 's__' or ncbi_specific in self.FORBIDDEN_NAMES:
                    continue

                if (ncbi_type_species in ncbi_type_anchored_species
                        and rid !=
                        ncbi_type_anchored_species[ncbi_type_species]):
                    self.logger.error(
                        'NCBI species {} has multiple effective type strain genomes in different clusters.'
                        .format(ncbi_type_species))
                    sys.exit(-1)

                ncbi_type_anchored_species[ncbi_type_species] = rid
        self.logger.info(
            ' - identified {:,} NCBI species anchored by a type strain genome.'
            .format(len(ncbi_type_anchored_species)))

        # identify genomes with erroneous NCBI species assignments
        misclassified_gids = set()
        with open(
                os.path.join(self.output_dir,
                             'ncbi_misclassified_sp.gtdb_clustering.tsv'),
                'w') as fout:
            fout.write(
                'Genome ID\tNCBI species\tGenome cluster\tType species cluster\n')

            for ncbi_species, species_gids in ncbi_sp_gids.items():
                if ncbi_species not in ncbi_type_anchored_species:
                    continue

                # find genomes with NCBI species assignments that are in a
                # different cluster than the type strain genome
                type_rid = ncbi_type_anchored_species[ncbi_species]
                for gid in species_gids:
                    cur_rid = gid_to_rid[gid]
                    if type_rid != cur_rid:
                        misclassified_gids.add(gid)
                        fout.write('{}\t{}\t{}\t{}\n'.format(
                            gid, ncbi_species, cur_rid, type_rid))

        misclassified_species = set(
            [cur_genomes[gid].ncbi_taxa.species for gid in misclassified_gids])
        self.logger.info(
            ' - identified {:,} genomes from {:,} species as having misclassified NCBI species assignments.'
            .format(len(misclassified_gids), len(misclassified_species)))

        return misclassified_gids

    def run(self, gtdb_clusters_file, cur_gtdb_metadata_file,
            cur_genomic_path_file, uba_genome_paths, qc_passed_file,
            ncbi_genbank_assembly_file, untrustworthy_type_file,
            gtdb_type_strains_ledger, sp_priority_ledger,
            genus_priority_ledger, dsmz_bacnames_file):
        """Identify genomes with erroneous NCBI species assignments.

        Loads the current genome set and GTDB species clusters. It then
        reports genomes whose NCBI species conflicts with a type strain genome,
        both by ANI and by GTDB cluster membership.
        """

        # create current GTDB genome sets
        self.logger.info('Creating current GTDB genome set.')
        cur_genomes = Genomes()
        cur_genomes.load_from_metadata_file(
            cur_gtdb_metadata_file,
            gtdb_type_strains_ledger=gtdb_type_strains_ledger,
            create_sp_clusters=False,
            uba_genome_file=uba_genome_paths,
            qc_passed_file=qc_passed_file,
            ncbi_genbank_assembly_file=ncbi_genbank_assembly_file,
            untrustworthy_type_ledger=untrustworthy_type_file)
        self.logger.info(
            f' ... current genome set contains {len(cur_genomes):,} genomes.')

        # get path to previous and current genomic FASTA files
        self.logger.info('Reading path to current genomic FASTA files.')
        cur_genomes.load_genomic_file_paths(cur_genomic_path_file)
        cur_genomes.load_genomic_file_paths(uba_genome_paths)

        # read named GTDB species clusters
        self.logger.info(
            'Reading named and previous placeholder GTDB species clusters.')
        cur_clusters, rep_radius = read_clusters(gtdb_clusters_file)
        self.logger.info(
            ' ... identified {:,} clusters spanning {:,} genomes.'.format(
                len(cur_clusters),
                sum([len(gids) + 1 for gids in cur_clusters.values()])))

        # identify genomes with erroneous NCBI species assignments
        self.logger.info(
            'Identifying genomes with erroneous NCBI species assignments as established by ANI type strain genomes.'
        )
        self.identify_misclassified_genomes_ani(cur_genomes, cur_clusters)

        self.logger.info(
            'Identifying genomes with erroneous NCBI species assignments as established by GTDB cluster of type strain genomes.'
        )
        self.identify_misclassified_genomes_cluster(cur_genomes, cur_clusters)
Esempio n. 9
0
class ClusterUser(object):
    """Cluster User genomes to GTDB species clusters."""

    def __init__(self, ani_cache_file, cpus, output_dir):
        """Initialization.

        Args:
            ani_cache_file: ANI cache file passed through to FastANI.
            cpus: Number of CPUs used by Mash and FastANI.
            output_dir: Directory for Mash sketches/distances and the
                clustering report.
        """

        check_dependencies(['fastANI', 'mash'])

        self.cpus = cpus
        self.output_dir = output_dir

        self.logger = logging.getLogger('timestamp')

        # minimum Mash ANI estimate for a (User genome, representative) pair
        # to be passed on to FastANI
        self.min_mash_ani = 90.0

        # minimum alignment fraction for a User genome to join a species cluster
        self.af_sp = 0.65

        self.fastani = FastANI(ani_cache_file, cpus)

    def _mash_ani(self, genome_files, user_genomes, sp_clusters):
        """Calculate Mash ANI estimates between User genomes and species clusters.

        Sketches the User genomes and the species cluster genomes into
        `self.output_dir`, computes Mash distances up to the distance
        corresponding to `self.min_mash_ani`, and returns the parsed ANI
        estimates as produced by `Mash.read_ani`.
        """

        mash = Mash(self.cpus)

        # create Mash sketch for User genomes
        mash_user_sketch_file = os.path.join(self.output_dir,
                                             'gtdb_user_genomes.msh')
        genome_list_file = os.path.join(self.output_dir,
                                        'gtdb_user_genomes.lst')
        mash.sketch(user_genomes, genome_files, genome_list_file,
                    mash_user_sketch_file)

        # create Mash sketch for species clusters
        mash_sp_sketch_file = os.path.join(self.output_dir,
                                           'gtdb_sp_genomes.msh')
        genome_list_file = os.path.join(self.output_dir, 'gtdb_sp_genomes.lst')
        mash.sketch(sp_clusters, genome_files, genome_list_file,
                    mash_sp_sketch_file)

        # get Mash distances
        mash_dist_file = os.path.join(self.output_dir, 'gtdb_user_vs_sp.dst')
        mash.dist(
            float(100 - self.min_mash_ani) / 100, mash_sp_sketch_file,
            mash_user_sketch_file, mash_dist_file)

        # read Mash distances
        mash_ani = mash.read_ani(mash_dist_file)

        # count pairs above the Mash threshold; each pair is counted in both
        # directions, matching how pairs are later handed to FastANI
        num_pairs = 0
        for qid in mash_ani:
            for rid in mash_ani[qid]:
                if mash_ani[qid][rid] >= self.min_mash_ani and qid != rid:
                    num_pairs += 2

        self.logger.info(
            'Identified %d genome pairs with a Mash ANI >= %.1f%%.' %
            (num_pairs, self.min_mash_ani))

        return mash_ani

    def _closest_rep(self, cur_gid, genome_files, sp_clusters, rep_radius,
                     mash_anis):
        """Return the representative a User genome should join, or None.

        Candidates are representatives whose Mash ANI estimate to `cur_gid`
        is >= self.min_mash_ani. FastANI is run on those pairs. Among
        representatives with AF >= self.af_sp, the one with the highest ANI
        is chosen, with ties broken by AF. It is returned only if that ANI
        exceeds the representative's radius (`rep_radius[rid].ani`).
        """

        if cur_gid not in mash_anis:
            return None

        ani_pairs = []
        for rep_gid in sp_clusters:
            if mash_anis[cur_gid].get(rep_gid, 0) >= self.min_mash_ani:
                ani_pairs.append((cur_gid, rep_gid))
                ani_pairs.append((rep_gid, cur_gid))

        if not ani_pairs:
            return None

        ani_af = self.fastani.pairs(ani_pairs,
                                    genome_files,
                                    report_progress=False)

        closest_rep_gid = None
        closest_rep_ani = 0
        closest_rep_af = 0
        for rep_gid in sp_clusters:
            ani, af = symmetric_ani(ani_af, cur_gid, rep_gid)

            if af >= self.af_sp:
                if ani > closest_rep_ani or (ani == closest_rep_ani
                                             and af > closest_rep_af):
                    closest_rep_gid = rep_gid
                    closest_rep_ani = ani
                    closest_rep_af = af

        if closest_rep_gid and closest_rep_ani > rep_radius[
                closest_rep_gid].ani:
            return closest_rep_gid

        return None

    def _cluster(self, genome_files, sp_clusters, rep_radius, user_genomes,
                 mash_anis):
        """Cluster User genomes to existing species clusters.

        Each User genome is appended in place to the member list of the
        species cluster chosen by `_closest_rep`. A warning is logged for
        every genome that cannot be assigned, including genomes without any
        sufficiently close Mash hit.
        """

        num_user_genomes = len(user_genomes)
        for idx, cur_gid in enumerate(user_genomes):
            rep_gid = self._closest_rep(cur_gid, genome_files, sp_clusters,
                                        rep_radius, mash_anis)
            if rep_gid is None:
                self.logger.warning(
                    'Failed to assign genome %s to representative.' %
                    cur_gid)
            else:
                sp_clusters[rep_gid].append(cur_gid)

            # format first, then pad, so the carriage-return status line
            # always overwrites the previous one completely
            statusStr = ('-> Assigned %d of %d (%.2f%%) genomes.' % (
                idx + 1, num_user_genomes,
                float(idx + 1) * 100 / num_user_genomes)).ljust(86)
            sys.stdout.write('%s\r' % statusStr)
            sys.stdout.flush()
        sys.stdout.write('\n')

    def run(self, metadata_file, genome_path_file, final_cluster_file):
        """Cluster User genomes to GTDB species clusters.

        User genomes are genomes in `metadata_file` that are not already
        clustered and have CheckM completeness > 50 and contamination < 10.
        They are assigned to existing clusters and the resulting clusters are
        written to 'gtdb_user_clusters.tsv' in `self.output_dir`.
        """

        # get path to genome FASTA files
        self.logger.info('Reading path to genome FASTA files.')
        genome_files = read_genome_path(genome_path_file)

        # read existing cluster information
        self.logger.info('Reading already established species clusters.')
        # NOTE(review): read_clusters is unpacked into three values here but
        # into two values elsewhere in this file; confirm against its signature.
        sp_clusters, species, rep_radius = read_clusters(final_cluster_file)

        clustered_genomes = set()
        for rep_id in sp_clusters:
            clustered_genomes.add(rep_id)
            clustered_genomes.update(sp_clusters[rep_id])

        self.logger.info(
            'Identified %d species clusters spanning %d genomes.' %
            (len(sp_clusters), len(clustered_genomes)))

        # get User genomes to cluster
        self.logger.info('Parse quality statistics for all genomes.')
        quality_metadata = read_quality_metadata(metadata_file)

        user_genomes = set()
        for gid in quality_metadata:
            if gid in clustered_genomes:
                continue

            if (quality_metadata[gid].checkm_completeness > 50
                    and quality_metadata[gid].checkm_contamination < 10):
                user_genomes.add(gid)

        self.logger.info('Identified %d User genomes to cluster.' %
                         len(user_genomes))

        # calculate Mash ANI estimates between unclustered genomes
        self.logger.info(
            'Calculating Mash ANI estimates between User genomes and species clusters.'
        )
        mash_anis = self._mash_ani(genome_files, user_genomes, sp_clusters)

        # cluster User genomes to species clusters
        self.logger.info('Assigning User genomes to closest species cluster.')
        self._cluster(genome_files, sp_clusters, rep_radius, user_genomes,
                      mash_anis)

        # representatives plus their members, now including User genomes
        num_clustered = sum(1 + len(sp_clusters[rep_id])
                            for rep_id in sp_clusters)

        self.logger.info(
            'The %d species clusters span %d genomes, including User genomes.'
            % (len(sp_clusters), num_clustered))

        # report clustering
        user_cluster_file = os.path.join(self.output_dir,
                                         'gtdb_user_clusters.tsv')
        with open(user_cluster_file, 'w') as fout:
            fout.write(
                'Type genome\tNo. clustered genomes\tClustered genomes\n')
            for rep_id in sp_clusters:
                fout.write('%s\t%d\t%s\n' % (rep_id, len(
                    sp_clusters[rep_id]), ','.join(sp_clusters[rep_id])))
class RepActions(object):
    """Perform initial actions required for changed representatives."""
    def __init__(self, ani_cache_file, cpus, output_dir):
        """Set up the ANI calculator, rule thresholds and the action log.

        Args:
            ani_cache_file: ANI cache file passed through to FastANI.
            cpus: Number of CPUs passed through to FastANI.
            output_dir: Directory in which 'action_log.tsv' is created; the
                log stays open for writing on `self.action_log`.
        """

        self.output_dir = output_dir
        self.logger = logging.getLogger('timestamp')
        self.fastani = FastANI(ani_cache_file, cpus)

        # ANI / AF thresholds used by the representative update rules
        self.genomic_update_ani = 99.0
        self.genomic_update_af = 0.80
        self.new_rep_ani = 99.0
        self.new_rep_af = 0.80

        # minimum increase in ANI score required before a new representative
        # is selected (see action_improved_rep)
        self.new_rep_qs_threshold = 10

        log_path = os.path.join(self.output_dir, 'action_log.tsv')
        self.action_log = open(log_path, 'w')
        self.action_log.write(
            'Genome ID\tPrevious GTDB species\tAction\tParameters\n')

        # previous representative ID -> (new representative ID or None, action)
        self.new_reps = {}

    def rep_change_gids(self, rep_change_summary_file, field, value):
        """Get genomes with a specific change."""

        gids = {}
        with open(rep_change_summary_file) as f:
            header = f.readline().strip().split('\t')

            field_index = header.index(field)
            prev_sp_index = header.index('Previous GTDB species')

            for line in f:
                line_split = line.strip().split('\t')

                v = line_split[field_index]
                if v == value:
                    prev_sp = line_split[prev_sp_index]
                    gids[line_split[0]] = prev_sp

        return gids

    def top_ani_score_prev_rep(self, prev_rid, sp_cids, prev_genomes,
                               cur_genomes):
        """Identify genome in cluster with highest balanced ANI score to genomic file of representative in previous GTDB release."""

        max_score = -1e6
        max_rid = None
        max_ani = None
        max_af = None
        for cid in sp_cids:
            ani, af = self.fastani.symmetric_ani_cached(
                f'{prev_rid}-P', f'{cid}-C',
                prev_genomes[prev_rid].genomic_file,
                cur_genomes[cid].genomic_file)

            cur_score = cur_genomes[cid].score_ani(ani)
            if (cur_score > max_score
                    or (cur_score == max_score and ani > max_ani)):
                max_score = cur_score
                max_rid = cid
                max_ani = ani
                max_af = af

        return max_rid, max_score, max_ani, max_af

    def top_ani_score(self, prev_rid, sp_cids, cur_genomes):
        """Identify genome in cluster with highest balanced ANI score to representative genome.

        ANI/AF between `prev_rid` and every genome in `sp_cids` is computed
        in both directions by FastANI, consulting its cache, and reduced with
        `symmetric_ani`. Candidates are ranked by `score_ani(ani)`. Equal
        scores are resolved in favour of the higher ANI, as in
        `top_ani_score_prev_rep`. Without this tie-break, the winner among
        equal scores depends on the iteration order of `sp_cids`, which is
        arbitrary for a set.

        Returns:
            Tuple (rid, score, ani, af) of the best candidate, or
            (None, -1e6, None, None) when `sp_cids` is empty.
        """

        # calculate ANI between representative and genomes in species cluster
        gid_pairs = []
        for cid in sp_cids:
            gid_pairs.append((cid, prev_rid))
            gid_pairs.append((prev_rid, cid))

        ani_af = self.fastani.pairs(gid_pairs,
                                    cur_genomes.genomic_files,
                                    report_progress=False,
                                    check_cache=True)

        # find genome with top ANI score, breaking ties by ANI
        max_score = -1e6
        max_rid = None
        max_ani = None
        max_af = None
        for cid in sp_cids:
            ani, af = symmetric_ani(ani_af, prev_rid, cid)

            cur_score = cur_genomes[cid].score_ani(ani)
            if (cur_score > max_score
                    or (max_rid is not None and cur_score == max_score
                        and ani > max_ani)):
                max_score = cur_score
                max_rid = cid
                max_ani = ani
                max_af = af

        return max_rid, max_score, max_ani, max_af

    def get_updated_rid(self, prev_rid):
        """Get updated representative."""

        if prev_rid in self.new_reps:
            gid, action = self.new_reps[prev_rid]
            return gid

        return prev_rid

    def update_rep(self, prev_rid, new_rid, action):
        """Update representative genome for GTDB species cluster."""

        if prev_rid in self.new_reps and self.new_reps[prev_rid][0] != new_rid:
            self.logger.warning(
                'Representative {} was reassigned multiple times: {} {}.'.
                format(prev_rid, self.new_reps[prev_rid], (new_rid, action)))
            self.logger.warning(
                'Assuming last reassignment of {}: {} has priority.'.format(
                    new_rid, action))

        self.new_reps[prev_rid] = (new_rid, action)

    def genomes_in_current_sp_cluster(self, prev_rid, prev_genomes,
                                      new_updated_sp_clusters, cur_genomes):
        """Get genomes in current species cluster."""

        assert prev_rid in prev_genomes.sp_clusters

        sp_cids = prev_genomes.sp_clusters[prev_rid]
        if prev_rid in new_updated_sp_clusters:
            sp_cids = sp_cids.union(new_updated_sp_clusters[prev_rid])
        sp_cids = sp_cids.intersection(cur_genomes)

        return sp_cids

    def action_genomic_lost(self, rep_change_summary_file, prev_genomes,
                            cur_genomes, new_updated_sp_clusters):
        """Handle species with lost representative genome."""

        # get genomes with specific changes
        self.logger.info(
            'Identifying species with lost representative genome.')
        genomic_lost_rids = self.rep_change_gids(rep_change_summary_file,
                                                 'GENOMIC_CHANGE', 'LOST')
        self.logger.info(
            f' ... identified {len(genomic_lost_rids):,} genomes.')

        # calculate ANI between previous and current genomes
        for prev_rid, prev_gtdb_sp in genomic_lost_rids.items():
            sp_cids = self.genomes_in_current_sp_cluster(
                prev_rid, prev_genomes, new_updated_sp_clusters, cur_genomes)

            params = {}
            if sp_cids:
                action = 'GENOMIC_CHANGE:LOST:REPLACED'

                new_rid, top_score, ani, af = self.top_ani_score_prev_rep(
                    prev_rid, sp_cids, prev_genomes, cur_genomes)
                assert (new_rid != prev_rid)

                params['new_rid'] = new_rid
                params['ani'] = ani
                params['af'] = af
                params['new_assembly_quality'] = cur_genomes[
                    new_rid].score_assembly()
                params['prev_assembly_quality'] = prev_genomes[
                    prev_rid].score_assembly()

                self.update_rep(prev_rid, new_rid, action)
            else:
                action = 'GENOMIC_CHANGE:LOST:SPECIES_RETIRED'
                self.update_rep(prev_rid, None, action)

            self.action_log.write('{}\t{}\t{}\t{}\n'.format(
                prev_rid, prev_gtdb_sp, action, params))

    def action_genomic_update(self, rep_change_summary_file, prev_genomes,
                              cur_genomes, new_updated_sp_clusters):
        """Handle representatives with updated genomes.

        For every representative flagged with GENOMIC_CHANGE == 'UPDATED', the
        previous and current genomic files are compared:
        - if ANI >= self.genomic_update_ani and AF >= self.genomic_update_af
          the update is a minor change and the representative is kept;
        - otherwise the genome of the current species cluster with the best
          ANI score to the previous genomic file is found. The representative
          is retained if it is that genome, and replaced if another genome
          wins. The species is retired if the cluster is empty.
        Each decision is appended to the action log. The mean/std change in
        assembly score over minor changes is logged when there are any.
        """

        # get genomes with specific changes
        self.logger.info(
            'Identifying representatives with updated genomic files.')
        genomic_update_gids = self.rep_change_gids(rep_change_summary_file,
                                                   'GENOMIC_CHANGE', 'UPDATED')
        self.logger.info(
            f' ... identified {len(genomic_update_gids):,} genomes.')

        # calculate ANI between previous and current genomes
        assembly_score_change = []
        for prev_rid, prev_gtdb_sp in genomic_update_gids.items():
            # check that genome hasn't been lost which should
            # be handled differently
            assert prev_rid in cur_genomes

            ani, af = self.fastani.symmetric_ani_cached(
                f'{prev_rid}-P', f'{prev_rid}-C',
                prev_genomes[prev_rid].genomic_file,
                cur_genomes[prev_rid].genomic_file)

            params = {}
            params['ani'] = ani
            params['af'] = af
            params['prev_ncbi_accession'] = prev_genomes[prev_rid].ncbi_accn
            params['cur_ncbi_accession'] = cur_genomes[prev_rid].ncbi_accn
            assert prev_genomes[prev_rid].ncbi_accn != cur_genomes[
                prev_rid].ncbi_accn

            if ani >= self.genomic_update_ani and af >= self.genomic_update_af:
                params['prev_assembly_quality'] = prev_genomes[
                    prev_rid].score_assembly()
                params['new_assembly_quality'] = cur_genomes[
                    prev_rid].score_assembly()
                action = 'GENOMIC_CHANGE:UPDATED:MINOR_CHANGE'

                d = cur_genomes[prev_rid].score_assembly(
                ) - prev_genomes[prev_rid].score_assembly()
                assembly_score_change.append(d)
            else:
                sp_cids = self.genomes_in_current_sp_cluster(
                    prev_rid, prev_genomes, new_updated_sp_clusters,
                    cur_genomes)

                if sp_cids:
                    new_rid, top_score, ani, af = self.top_ani_score_prev_rep(
                        prev_rid, sp_cids, prev_genomes, cur_genomes)

                    if new_rid == prev_rid:
                        params['prev_assembly_quality'] = prev_genomes[
                            prev_rid].score_assembly()
                        params['new_assembly_quality'] = cur_genomes[
                            prev_rid].score_assembly()
                        action = 'GENOMIC_CHANGE:UPDATED:RETAINED'
                    else:
                        action = 'GENOMIC_CHANGE:UPDATED:REPLACED'
                        params['new_rid'] = new_rid
                        params['ani'] = ani
                        params['af'] = af
                        params['new_assembly_quality'] = cur_genomes[
                            new_rid].score_assembly()
                        params['prev_assembly_quality'] = prev_genomes[
                            prev_rid].score_assembly()

                        self.update_rep(prev_rid, new_rid, action)
                else:
                    action = 'GENOMIC_CHANGE:UPDATED:SPECIES_RETIRED'
                    self.update_rep(prev_rid, None, action)

            self.action_log.write('{}\t{}\t{}\t{}\n'.format(
                prev_rid, prev_gtdb_sp, action, params))

        # mean/std of an empty list would emit a RuntimeWarning and log 'nan'
        if assembly_score_change:
            self.logger.info(
                ' ... change in assembly score for updated genomes: {:.2f} +/- {:.2f}'
                .format(np_mean(assembly_score_change),
                        np_std(assembly_score_change)))
        else:
            self.logger.info(
                ' ... no updated genomes with only minor genomic changes.')

    def action_type_strain_lost(self, rep_change_summary_file, prev_genomes,
                                cur_genomes, new_updated_sp_clusters):
        """Handle representatives which have lost type strain genome status.

        For every representative flagged with TYPE_STRAIN_CHANGE == 'LOST', the
        genome of its current species cluster with the best ANI score to the
        representative is found. If that score exceeds the representative's
        own score (`score_ani(100)`), the representative is replaced;
        otherwise it is retained. Each decision is appended to the action log.
        """

        # get representatives that lost type strain status
        self.logger.info(
            'Identifying representative that lost type strain genome status.')
        ncbi_type_species_lost = self.rep_change_gids(rep_change_summary_file,
                                                      'TYPE_STRAIN_CHANGE',
                                                      'LOST')
        self.logger.info(
            f' ... identified {len(ncbi_type_species_lost):,} genomes.')

        for prev_rid, prev_gtdb_sp in ncbi_type_species_lost.items():
            # check that genome hasn't been lost which should
            # be handled differently
            assert prev_rid in cur_genomes

            sp_cids = self.genomes_in_current_sp_cluster(
                prev_rid, prev_genomes, new_updated_sp_clusters, cur_genomes)

            prev_rep_score = cur_genomes[prev_rid].score_ani(100)
            new_rid, top_score, ani, af = self.top_ani_score(
                prev_rid, sp_cids, cur_genomes)

            params = {}
            params['prev_rid_prev_strain_ids'] = prev_genomes[
                prev_rid].ncbi_strain_identifiers
            params['prev_rid_cur_strain_ids'] = cur_genomes[
                prev_rid].ncbi_strain_identifiers
            params['prev_rid_prev_gtdb_type_designation'] = prev_genomes[
                prev_rid].gtdb_type_designation
            params['prev_rid_cur_gtdb_type_designation'] = cur_genomes[
                prev_rid].gtdb_type_designation
            params[
                'prev_rid_prev_gtdb_type_designation_sources'] = prev_genomes[
                    prev_rid].gtdb_type_designation_sources
            params['prev_rid_cur_gtdb_type_designation_sources'] = cur_genomes[
                prev_rid].gtdb_type_designation_sources

            if top_score > prev_rep_score:
                action = 'TYPE_STRAIN_CHANGE:LOST:REPLACED'
                assert (prev_rid != new_rid)

                params['new_rid'] = new_rid
                params['ani'] = ani
                params['af'] = af
                params['new_assembly_quality'] = cur_genomes[
                    new_rid].score_assembly()
                params['prev_assembly_quality'] = prev_genomes[
                    prev_rid].score_assembly()

                # new_rid is drawn from the current species cluster and may be
                # new to this release, so its metadata must come from
                # cur_genomes (prev_genomes may not contain it at all)
                params['new_rid_strain_ids'] = cur_genomes[
                    new_rid].ncbi_strain_identifiers
                params['new_rid_gtdb_type_designation'] = cur_genomes[
                    new_rid].gtdb_type_designation
                params['new_rid_gtdb_type_designation_sources'] = cur_genomes[
                    new_rid].gtdb_type_designation_sources

                self.update_rep(prev_rid, new_rid, action)
            else:
                action = 'TYPE_STRAIN_CHANGE:LOST:RETAINED'

            self.action_log.write('{}\t{}\t{}\t{}\n'.format(
                prev_rid, prev_gtdb_sp, action, params))

    def action_domain_change(self, rep_change_summary_file, prev_genomes,
                             cur_genomes):
        """Handle representatives which have new domain assignments."""

        # get genomes with new NCBI species assignments
        self.logger.info(
            'Identifying representative with new domain assignments.')
        domain_changed = self.rep_change_gids(rep_change_summary_file,
                                              'DOMAIN_CHECK', 'REASSIGNED')
        self.logger.info(f' ... identified {len(domain_changed):,} genomes.')

        for prev_rid, prev_gtdb_sp in domain_changed.items():
            action = 'DOMAIN_CHECK:REASSIGNED'
            params = {}
            params['prev_gtdb_domain'] = prev_genomes[
                prev_rid].gtdb_taxa.domain
            params['cur_gtdb_domain'] = cur_genomes[prev_rid].gtdb_taxa.domain

            self.update_rep(prev_rid, None, action)
            self.action_log.write('{}\t{}\t{}\t{}\n'.format(
                prev_rid, prev_gtdb_sp, action, params))

    def action_improved_rep(self, prev_genomes, cur_genomes,
                            new_updated_sp_clusters):
        """Check if representative should be replaced with higher quality genome.

        For each GTDB species cluster with new or updated genomes, the
        candidate with the best ANI score to the cluster's current
        representative is found. The current representative may already have
        been changed by an earlier rule; see `get_updated_rid`. The candidate
        is proposed when its score exceeds the representative's own score
        (`score_ani(100)`) by more than `self.new_rep_qs_threshold`. The
        exception is when the representative is a GTDB type strain and the
        move would go to a genome with a different specific epithet, as
        decided by `self.sp_priority_mngr.has_priority`; that move is not
        made.

        Clusters whose representative is no longer in `cur_genomes`, or whose
        species was retired by an earlier rule (updated representative None),
        are skipped.

        Returns:
            dict mapping previous representative ID to (new_rid, action).
            These reassignments are written to the action log but are not
            recorded via `update_rep` here; presumably the caller applies
            them.
        """

        # NOTE(review): self.sp_priority_mngr is not set in __init__;
        # presumably assigned by the caller before this rule runs - confirm.
        self.logger.info(
            'Identifying improved representatives for GTDB species clusters.')
        num_gtdb_ncbi_type_sp = 0
        num_gtdb_type_sp = 0
        num_ncbi_type_sp = 0
        num_complete = 0
        num_isolate = 0
        anis = []
        afs = []
        improved_reps = {}
        for idx, (prev_rid,
                  cids) in enumerate(new_updated_sp_clusters.clusters()):
            if prev_rid not in cur_genomes:
                # indicates genome has been lost
                continue

            prev_gtdb_sp = new_updated_sp_clusters.get_species(prev_rid)
            statusStr = '-> Processing {:,} of {:,} ({:.2f}%) species [{}: {:,} new/updated genomes].'.format(
                idx + 1, len(new_updated_sp_clusters),
                float(idx + 1) * 100 / len(new_updated_sp_clusters),
                prev_gtdb_sp, len(cids)).ljust(86)
            sys.stdout.write('%s\r' % statusStr)
            sys.stdout.flush()

            # get latest representative of GTDB species clusters as it may
            # have been updated by a previous update rule
            prev_updated_rid = self.get_updated_rid(prev_rid)
            if prev_updated_rid is None:
                # species was retired by an earlier rule (e.g. domain
                # reassignment), so there is no representative to improve
                continue

            prev_rep_score = cur_genomes[prev_updated_rid].score_ani(100)
            new_rid, top_score, ani, af = self.top_ani_score(
                prev_updated_rid, cids, cur_genomes)

            params = {}
            action = None

            if top_score > prev_rep_score + self.new_rep_qs_threshold:
                assert (prev_updated_rid != new_rid)

                if (cur_genomes[prev_updated_rid].is_gtdb_type_strain(
                ) and cur_genomes[prev_updated_rid].ncbi_taxa.specific_epithet
                        != cur_genomes[new_rid].ncbi_taxa.specific_epithet
                        and self.sp_priority_mngr.has_priority(
                            cur_genomes, prev_updated_rid, new_rid)):
                    # GTDB species cluster should not be moved to a different type strain genome
                    # that has lower naming priority
                    self.logger.warning(
                        'Reassignments to type strain genome with lower naming priority is not allowed: {}/{}/{}, {}/{}/{}'
                        .format(
                            prev_updated_rid,
                            cur_genomes[prev_updated_rid].ncbi_taxa.species,
                            cur_genomes[prev_updated_rid].year_of_priority(),
                            new_rid, cur_genomes[new_rid].ncbi_taxa.species,
                            cur_genomes[new_rid].year_of_priority()))
                    continue

                action = 'IMPROVED_REP:REPLACED:HIGHER_QS'

                params['new_rid'] = new_rid
                params['ani'] = ani
                params['af'] = af
                params['new_assembly_quality'] = cur_genomes[
                    new_rid].score_assembly()
                params['prev_assembly_quality'] = cur_genomes[
                    prev_updated_rid].score_assembly()
                params['new_gtdb_type_strain'] = cur_genomes[
                    new_rid].is_gtdb_type_strain()
                params['prev_gtdb_type_strain'] = cur_genomes[
                    prev_updated_rid].is_gtdb_type_strain()
                params['new_ncbi_type_strain'] = cur_genomes[
                    new_rid].is_ncbi_type_strain()
                params['prev_ncbi_type_strain'] = cur_genomes[
                    prev_updated_rid].is_ncbi_type_strain()

                anis.append(ani)
                afs.append(af)

                # summarise why the new genome is an improvement
                improvement_list = []
                gtdb_type_improv = cur_genomes[new_rid].is_gtdb_type_strain(
                ) and not cur_genomes[prev_updated_rid].is_gtdb_type_strain()
                ncbi_type_improv = cur_genomes[new_rid].is_ncbi_type_strain(
                ) and not cur_genomes[prev_updated_rid].is_ncbi_type_strain()

                if gtdb_type_improv and ncbi_type_improv:
                    num_gtdb_ncbi_type_sp += 1
                    improvement_list.append(
                        'replaced with genome from type strain according to GTDB and NCBI'
                    )
                elif gtdb_type_improv:
                    num_gtdb_type_sp += 1
                    improvement_list.append(
                        'replaced with genome from type strain according to GTDB'
                    )
                elif ncbi_type_improv:
                    num_ncbi_type_sp += 1
                    improvement_list.append(
                        'replaced with genome from type strain according to NCBI'
                    )

                if cur_genomes[new_rid].is_isolate(
                ) and not cur_genomes[prev_updated_rid].is_isolate():
                    num_isolate += 1
                    improvement_list.append('MAG/SAG replaced with isolate')

                if cur_genomes[new_rid].is_complete_genome(
                ) and not cur_genomes[prev_updated_rid].is_complete_genome():
                    num_complete += 1
                    improvement_list.append('replaced with complete genome')

                if len(improvement_list) == 0:
                    improvement_list.append(
                        'replaced with higher quality genome')

                params['improvements'] = '; '.join(improvement_list)

                self.action_log.write('{}\t{}\t{}\t{}\n'.format(
                    prev_rid, prev_gtdb_sp, action, params))

                improved_reps[prev_rid] = (new_rid, action)

        sys.stdout.write('\n')
        self.logger.info(
            f' ... identified {len(improved_reps):,} species with improved representatives.'
        )
        self.logger.info(
            f'   ... {num_gtdb_ncbi_type_sp:,} replaced with GTDB/NCBI genome from type strain.'
        )
        self.logger.info(
            f'   ... {num_gtdb_type_sp:,} replaced with GTDB genome from type strain.'
        )
        self.logger.info(
            f'   ... {num_ncbi_type_sp:,} replaced with NCBI genome from type strain.'
        )
        self.logger.info(
            f'   ... {num_isolate:,} replaced MAG/SAG with isolate.')
        self.logger.info(
            f'   ... {num_complete:,} replaced with complete genome assembly.')

        # mean/std of empty lists would emit a RuntimeWarning and log 'nan'
        if anis:
            self.logger.info(
                f' ... ANI = {np_mean(anis):.2f} +/- {np_std(anis):.2f}%; AF = {np_mean(afs)*100:.2f} +/- {np_std(afs)*100:.2f}%.'
            )

        return improved_reps

    def action_naming_priority(self, prev_genomes, cur_genomes,
                               new_updated_sp_clusters):
        """Check if representative should be replace with genome with higher nomenclatural priority."""

        self.logger.info(
            'Identifying genomes with naming priority in GTDB species clusters.'
        )

        out_file = os.path.join(self.output_dir, 'update_priority.tsv')
        fout = open(out_file, 'w')
        fout.write(
            'NCBI species\tGTDB species\tRepresentative\tStrain IDs\tRepresentative type sources\tPriority year\tGTDB type species\tGTDB type strain\tNCBI assembly type'
        )
        fout.write(
            '\tNCBI synonym\tGTDB synonym\tSynonym genome\tSynonym strain IDs\tSynonym type sources\tPriority year\tGTDB type species\tGTDB type strain\tSynonym NCBI assembly type'
        )
        fout.write('\tANI\tAF\tPriority note\n')

        num_higher_priority = 0
        assembly_score_change = []
        anis = []
        afs = []
        for idx, prev_rid in enumerate(prev_genomes.sp_clusters):
            # get type strain genomes in GTDB species cluster, including genomes new to this release
            type_strain_gids = [
                gid for gid in prev_genomes.sp_clusters[prev_rid]
                if gid in cur_genomes
                and cur_genomes[gid].is_effective_type_strain()
            ]
            if prev_rid in new_updated_sp_clusters:
                new_type_strain_gids = [
                    gid for gid in new_updated_sp_clusters[prev_rid]
                    if cur_genomes[gid].is_effective_type_strain()
                ]
                type_strain_gids.extend(new_type_strain_gids)

            if len(type_strain_gids) == 0:
                continue

            # check if representative has already been updated
            updated_rid = self.get_updated_rid(prev_rid)

            type_strain_sp = set([
                cur_genomes[gid].ncbi_taxa.species for gid in type_strain_gids
            ])
            if len(type_strain_sp) == 1 and updated_rid in type_strain_gids:
                continue

            updated_sp = cur_genomes[updated_rid].ncbi_taxa.species
            highest_priority_gid = updated_rid

            if updated_rid not in type_strain_gids:
                highest_priority_gid = None
                if updated_sp in type_strain_sp:
                    sp_gids = [
                        gid for gid in type_strain_gids
                        if cur_genomes[gid].ncbi_taxa.species == updated_sp
                    ]
                    hq_gid = select_highest_quality(sp_gids, cur_genomes)
                    highest_priority_gid = hq_gid

                #self.logger.warning('Representative is a non-type strain genome even though type strain genomes exist in species cluster: {}: {}, {}: {}'.format(
                #                    prev_rid, cur_genomes[prev_rid].is_effective_type_strain(), updated_rid, cur_genomes[updated_rid].is_effective_type_strain()))
                #self.logger.warning('Type strain genomes: {}'.format(','.join(type_strain_gids)))

            # find highest priority genome
            for sp in type_strain_sp:
                if sp == updated_sp:
                    continue

                # get highest quality genome from species
                sp_gids = [
                    gid for gid in type_strain_gids
                    if cur_genomes[gid].ncbi_taxa.species == sp
                ]
                hq_gid = select_highest_quality(sp_gids, cur_genomes)

                if highest_priority_gid is None:
                    highest_priority_gid = hq_gid
                else:
                    highest_priority_gid, note = self.sp_priority_mngr.priority(
                        cur_genomes, highest_priority_gid, hq_gid)

            # check if representative should be updated
            if highest_priority_gid != updated_rid:
                num_higher_priority += 1

                ani, af = self.fastani.symmetric_ani_cached(
                    updated_rid, highest_priority_gid,
                    cur_genomes[updated_rid].genomic_file,
                    cur_genomes[highest_priority_gid].genomic_file)

                anis.append(ani)
                afs.append(af)

                d = cur_genomes[highest_priority_gid].score_assembly(
                ) - cur_genomes[updated_rid].score_assembly()
                assembly_score_change.append(d)

                action = 'NOMENCLATURE_PRIORITY:REPLACED'
                params = {}
                params['prev_ncbi_species'] = cur_genomes[
                    updated_rid].ncbi_taxa.species
                params['prev_year_of_priority'] = cur_genomes[
                    updated_rid].year_of_priority()
                params['new_ncbi_species'] = cur_genomes[
                    highest_priority_gid].ncbi_taxa.species
                params['new_year_of_priority'] = cur_genomes[
                    highest_priority_gid].year_of_priority()
                params['new_rid'] = highest_priority_gid
                params['ani'] = ani
                params['af'] = af
                params['priority_note'] = note

                self.update_rep(prev_rid, highest_priority_gid, action)
                self.action_log.write('{}\t{}\t{}\t{}\n'.format(
                    prev_rid, cur_genomes[updated_rid].gtdb_taxa.species,
                    action, params))

                fout.write('{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}'.format(
                    cur_genomes[highest_priority_gid].ncbi_taxa.species,
                    cur_genomes[highest_priority_gid].gtdb_taxa.species,
                    highest_priority_gid, ','.join(
                        sorted(
                            cur_genomes[highest_priority_gid].strain_ids())),
                    ','.join(
                        sorted(cur_genomes[highest_priority_gid].
                               gtdb_type_sources())).upper().replace(
                                   'STRAININFO', 'StrainInfo'),
                    cur_genomes[highest_priority_gid].year_of_priority(),
                    cur_genomes[highest_priority_gid].is_gtdb_type_species(),
                    cur_genomes[highest_priority_gid].is_gtdb_type_strain(),
                    cur_genomes[highest_priority_gid].ncbi_type_material))
                fout.write('\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}'.format(
                    cur_genomes[updated_rid].ncbi_taxa.species,
                    cur_genomes[updated_rid].gtdb_taxa.species, updated_rid,
                    ','.join(sorted(cur_genomes[updated_rid].strain_ids())),
                    ','.join(
                        sorted(cur_genomes[updated_rid].gtdb_type_sources())
                    ).upper().replace('STRAININFO', 'StrainInfo'),
                    cur_genomes[updated_rid].year_of_priority(),
                    cur_genomes[updated_rid].is_gtdb_type_species(),
                    cur_genomes[updated_rid].is_gtdb_type_strain(),
                    cur_genomes[updated_rid].ncbi_type_material))
                fout.write('\t{:.3f}\t{:.4f}\t{}\n'.format(ani, af, note))

        fout.close()

        self.logger.info(
            f' ... identified {num_higher_priority:,} species with representative changed to genome with higher nomenclatural priority.'
        )
        self.logger.info(
            ' ... change in assembly score for new representatives: {:.2f} +/- {:.2f}'
            .format(np_mean(assembly_score_change),
                    np_std(assembly_score_change)))
        self.logger.info(' ... ANI: {:.2f} +/- {:.2f}'.format(
            np_mean(anis), np_std(anis)))
        self.logger.info(' ... AF: {:.2f} +/- {:.2f}'.format(
            np_mean(afs), np_std(afs)))

    def write_updated_clusters(self, prev_genomes, cur_genomes, new_reps,
                               new_updated_sp_clusters, out_file):
        """Write out updated GTDB species clusters.

        One row is written for each previous species cluster whose
        representative was not retired. A representative is retired when its
        entry in `new_reps` maps to a new representative of None. Each row
        gives:
          - the (possibly replaced) representative;
          - the species name of the previous cluster;
          - the number of genomes in the cluster;
          - the comma-separated genome IDs returned by
            `genomes_in_current_sp_cluster()`.

        Parameters
        ----------
        prev_genomes : Genomes
            Previous GTDB genome set whose species clusters are iterated.
        cur_genomes
            Current genome set. It is converted to a set and passed to
            `genomes_in_current_sp_cluster()`; presumably iterating it yields
            genome IDs (TODO confirm).
        new_reps : dict
            Maps a previous representative ID to a (new representative ID,
            action) pair. Representatives absent from the mapping are kept
            unchanged.
        new_updated_sp_clusters : SpeciesClusters
            Expanded species clusters, passed through to
            `genomes_in_current_sp_cluster()`.
        out_file : str
            Path of the TSV file to create.
        """

        self.logger.info(
            'Writing updated GTDB species clusters to file: {}'.format(
                out_file))

        cur_genome_set = set(cur_genomes)

        num_clusters = 0
        # `with` guarantees the file is flushed and closed even if a helper raises
        with open(out_file, 'w') as fout:
            fout.write(
                'Representative genome\tGTDB species\tNo. clustered genomes\tClustered genomes\n'
            )

            for prev_rid in prev_genomes.sp_clusters:
                # representatives not present in new_reps are unchanged
                new_rid, _action = new_reps.get(prev_rid, (prev_rid, None))
                if new_rid is None:
                    # species has been retired; it gets no row
                    continue

                sp_cids = self.genomes_in_current_sp_cluster(
                    prev_rid, prev_genomes, new_updated_sp_clusters,
                    cur_genome_set)

                fout.write('{}\t{}\t{}\t{}\n'.format(
                    new_rid, prev_genomes.sp_clusters.get_species(prev_rid),
                    len(sp_cids), ','.join(sp_cids)))
                num_clusters += 1

        self.logger.info(f' ... wrote {num_clusters:,} clusters.')

    def run(self, rep_change_summary_file, prev_gtdb_metadata_file,
            prev_genomic_path_file, cur_gtdb_metadata_file,
            cur_genomic_path_file, uba_genome_paths, genomes_new_updated_file,
            qc_passed_file, gtdbtk_classify_file, ncbi_genbank_assembly_file,
            untrustworthy_type_file, gtdb_type_strains_ledger,
            sp_priority_ledger, use_cached_improved_reps=False):
        """Perform initial actions required for changed representatives.

        The method performs these steps in order:
          1. Load the previous and current GTDB genome sets.
          2. Build species clusters expanded with new/updated genomes.
          3. Apply, in order, the genomic-lost, genomic-update,
             type-strain-lost, domain-change, improved-representative and
             naming-priority actions.
          4. Write 'updated_species_reps.tsv' and 'updated_sp_clusters.tsv'
             to the output directory.

        Parameters
        ----------
        rep_change_summary_file : str
            Summary of representative changes, passed to the action_* methods.
        prev_gtdb_metadata_file, cur_gtdb_metadata_file : str
            Metadata files for the previous and current genome sets.
        prev_genomic_path_file, cur_genomic_path_file : str
            Files listing paths to genomic FASTA files.
        uba_genome_paths : str
            UBA genome file; used both as metadata input and as a genomic
            path file for both genome sets.
        genomes_new_updated_file, qc_passed_file, gtdbtk_classify_file : str
            Inputs used to create the expanded species clusters.
        ncbi_genbank_assembly_file, untrustworthy_type_file,
        gtdb_type_strains_ledger : str
            Passed through to Genomes.load_from_metadata_file().
        sp_priority_ledger : str
            Ledger used to initialize the SpeciesPriorityManager.
        use_cached_improved_reps : bool
            If True, load the improved representatives from 'improved_reps.pkl'
            in the output directory (written by an earlier run) instead of
            recomputing them. Intended for debugging only.
        """

        # create previous and current GTDB genome sets
        self.logger.info('Creating previous GTDB genome set.')
        prev_genomes = Genomes()
        prev_genomes.load_from_metadata_file(
            prev_gtdb_metadata_file,
            gtdb_type_strains_ledger=gtdb_type_strains_ledger,
            uba_genome_file=uba_genome_paths,
            ncbi_genbank_assembly_file=ncbi_genbank_assembly_file,
            untrustworthy_type_ledger=untrustworthy_type_file)
        self.logger.info(
            ' ... previous genome set has {:,} species clusters spanning {:,} genomes.'
            .format(len(prev_genomes.sp_clusters),
                    prev_genomes.sp_clusters.total_num_genomes()))

        self.logger.info('Creating current GTDB genome set.')
        cur_genomes = Genomes()
        cur_genomes.load_from_metadata_file(
            cur_gtdb_metadata_file,
            gtdb_type_strains_ledger=gtdb_type_strains_ledger,
            create_sp_clusters=False,
            uba_genome_file=uba_genome_paths,
            qc_passed_file=qc_passed_file,
            ncbi_genbank_assembly_file=ncbi_genbank_assembly_file,
            untrustworthy_type_ledger=untrustworthy_type_file)
        self.logger.info(
            f' ... current genome set contains {len(cur_genomes):,} genomes.')

        # get path to previous and current genomic FASTA files
        self.logger.info(
            'Reading path to previous and current genomic FASTA files.')
        prev_genomes.load_genomic_file_paths(prev_genomic_path_file)
        prev_genomes.load_genomic_file_paths(uba_genome_paths)
        cur_genomes.load_genomic_file_paths(cur_genomic_path_file)
        cur_genomes.load_genomic_file_paths(uba_genome_paths)

        # create expanded previous GTDB species clusters
        new_updated_sp_clusters = SpeciesClusters()

        self.logger.info(
            'Creating species clusters of new and updated genomes based on GTDB-Tk classifications.'
        )
        new_updated_sp_clusters.create_expanded_clusters(
            prev_genomes.sp_clusters, genomes_new_updated_file, qc_passed_file,
            gtdbtk_classify_file)

        self.logger.info(
            'Identified {:,} expanded species clusters spanning {:,} genomes.'.
            format(len(new_updated_sp_clusters),
                   new_updated_sp_clusters.total_num_genomes()))

        # species priority manager is consulted when resolving nomenclatural
        # priority between candidate representatives
        self.sp_priority_mngr = SpeciesPriorityManager(sp_priority_ledger)

        # take required action for each changed representative
        self.action_genomic_lost(rep_change_summary_file, prev_genomes,
                                 cur_genomes, new_updated_sp_clusters)

        self.action_genomic_update(rep_change_summary_file, prev_genomes,
                                   cur_genomes, new_updated_sp_clusters)

        self.action_type_strain_lost(rep_change_summary_file, prev_genomes,
                                     cur_genomes, new_updated_sp_clusters)

        self.action_domain_change(rep_change_summary_file, prev_genomes,
                                  cur_genomes)

        # improved representatives are expensive to compute, so they are
        # pickled to allow a debugging re-run to skip this step
        improved_reps_file = os.path.join(self.output_dir, 'improved_reps.pkl')
        if use_cached_improved_reps:
            self.logger.warning(
                'Reading improved_reps from pre-cached file. Generally used only for debugging.'
            )
            # NOTE: pickle must only be loaded from a trusted file; this one is
            # expected to have been written by an earlier run of this method.
            with open(improved_reps_file, 'rb') as f:
                improved_reps = pickle.load(f)
        else:
            improved_reps = self.action_improved_rep(prev_genomes, cur_genomes,
                                                     new_updated_sp_clusters)
            with open(improved_reps_file, 'wb') as f:
                pickle.dump(improved_reps, f)

        for prev_rid, (new_rid, action) in improved_reps.items():
            self.update_rep(prev_rid, new_rid, action)

        self.action_naming_priority(prev_genomes, cur_genomes,
                                    new_updated_sp_clusters)

        # report basic statistics; a new representative of None marks a
        # retired species
        num_retired_sp = sum(
            1 for v in self.new_reps.values() if v[0] is None)
        num_replaced_rids = len(self.new_reps) - num_retired_sp
        self.logger.info(f'Identified {num_retired_sp:,} retired species.')
        self.logger.info(
            f'Identified {num_replaced_rids:,} species with a modified representative genome.'
        )

        self.action_log.close()

        # write out representatives for existing species clusters
        reps_file = os.path.join(self.output_dir, 'updated_species_reps.tsv')
        with open(reps_file, 'w') as fout:
            fout.write(
                'Previous representative ID\tNew representative ID\tAction\tRepresentative status\n'
            )
            for rid in prev_genomes.sp_clusters:
                if rid not in self.new_reps:
                    fout.write(f'{rid}\t{rid}\tNONE\tUNCHANGED\n')
                    continue

                new_rid, action = self.new_reps[rid]
                status = 'LOST' if new_rid is None else 'REPLACED'
                fout.write(f'{rid}\t{new_rid}\t{action}\t{status}\n')

        # write out updated species clusters
        out_file = os.path.join(self.output_dir, 'updated_sp_clusters.tsv')
        self.write_updated_clusters(prev_genomes, cur_genomes, self.new_reps,
                                    new_updated_sp_clusters, out_file)
class RepGenomicSimilarity(object):
    """Calculate ANI/AF between GTDB representative genomes within the same genus."""

    def __init__(self, ani_cache_file, cpus, output_dir):
        """Initialization.

        Parameters
        ----------
        ani_cache_file : str
            ANI cache file passed to FastANI.
        cpus : int
            Number of CPUs passed to FastANI.
        output_dir : str
            Directory where results are written.
        """

        check_dependencies(['fastANI'])

        self.cpus = cpus
        self.output_dir = output_dir

        self.logger = logging.getLogger('timestamp')

        self.fastani = FastANI(ani_cache_file, cpus)

    def run(self, gtdb_metadata_file, genomic_path_file):
        """Calculate ANI/AF between all pairs of GTDB representatives in the same genus.

        Results are written to 'intragenus_ani_af_reps.tsv' in the output
        directory. There is one row per ordered pair of representatives that
        was compared, with the symmetric ANI and AF reported by
        FastANI.symmetric_ani().

        Parameters
        ----------
        gtdb_metadata_file : str
            GTDB metadata file used to build the genome set.
        genomic_path_file : str
            File listing paths to genomic FASTA files.
        """

        # create GTDB genome sets
        self.logger.info('Creating GTDB genome set.')
        genomes = Genomes()
        genomes.load_from_metadata_file(gtdb_metadata_file)
        genomes.load_genomic_file_paths(genomic_path_file)
        self.logger.info(
            ' - genome set has {:,} species clusters spanning {:,} genomes.'.
            format(len(genomes.sp_clusters),
                   genomes.sp_clusters.total_num_genomes()))

        # get GTDB representatives from same genus
        self.logger.info('Identifying GTDB representatives in the same genus.')
        genus_gids = defaultdict(list)
        num_reps = 0
        for gid in genomes:
            if not genomes[gid].gtdb_is_rep:
                continue

            genus_gids[genomes[gid].gtdb_taxa.genus].append(gid)
            num_reps += 1
        self.logger.info(
            f' - identified {len(genus_gids):,} genera spanning {num_reps:,} representatives'
        )

        # get all intragenus comparisons (both orderings of each pair);
        # genera with a single representative yield no permutations
        self.logger.info('Determining all intragenus comparisons.')
        gid_pairs = []
        for gids in genus_gids.values():
            gid_pairs.extend(permutations(gids, 2))
        self.logger.info(
            f' - identified {len(gid_pairs):,} intragenus comparisons')

        # calculate FastANI ANI/AF between target genomes
        self.logger.info('Calculating ANI between intragenus pairs.')
        ani_af = self.fastani.pairs(gid_pairs,
                                    genomes.genomic_files,
                                    report_progress=True,
                                    check_cache=True)
        self.fastani.write_cache(silence=True)

        # write out results; only iterate over the targets actually compared
        # with each query. Iterating over all queries would emit self-pairs and
        # cross-genus pairs that were never calculated.
        out_file = os.path.join(self.output_dir, 'intragenus_ani_af_reps.tsv')
        with open(out_file, 'w') as fout:
            fout.write(
                'Query ID\tQuery species\tTarget ID\tTarget species\tANI\tAF\n')
            for qid in ani_af:
                for rid in ani_af[qid]:
                    ani, af = FastANI.symmetric_ani(ani_af, qid, rid)

                    fout.write('{}\t{}\t{}\t{}\t{:.3f}\t{:.3f}\n'.format(
                        qid, genomes[qid].gtdb_taxa.species, rid,
                        genomes[rid].gtdb_taxa.species, ani, af))