Пример #1
0
class ResolveTypes():
    """Resolve cases where a species has multiple genomes assembled from the type strain."""
    def __init__(self, ani_cache_file, cpus, output_dir):
        """Initialization.

        Args:
            ani_cache_file: Cache file passed to FastANI.
            cpus: Number of CPUs passed to FastANI.
            output_dir: Directory for output files; an `ani_pickles`
                subdirectory is created inside it if missing.
        """

        # LTP 16S rRNA BLAST results are read from
        # <dir of genomic file>/rna_ltp_132/ssu.taxonomy.tsv
        self.ltp_dir = 'rna_ltp_132'
        self.ltp_results_file = 'ssu.taxonomy.tsv'
        self.LTP_METADATA = namedtuple(
            'LTP_METADATA',
            'taxonomy taxa species ssu_len evalue bitscore aln_len perc_iden perc_aln'
        )

        # criteria a 16S rRNA hit to the LTP database must meet to be trusted
        self.ltp_pi_threshold = 99.0        # min. percent identity
        self.ltp_pa_threshold = 90.0        # min. percent of SSU length aligned
        self.ltp_ssu_len_threshold = 900    # min. SSU length
        self.ltp_evalue_threshold = 1e-10   # e-value must be strictly below this

        self.output_dir = output_dir
        self.logger = logging.getLogger('timestamp')
        self.cpus = cpus

        self.fastani = FastANI(ani_cache_file, cpus)

        # exist_ok avoids the race between checking for and creating the dir
        self.ani_pickle_dir = os.path.join(self.output_dir, 'ani_pickles')
        os.makedirs(self.ani_pickle_dir, exist_ok=True)

    def _parse_ltp_taxonomy_str(self, ltp_taxonomy_str):
        """Parse taxa and species from LTP taxonomy string.

        The string is a `;`-separated list of taxa, optionally followed by a
        `|`-delimited suffix (possibly preceded by `;type sp.` or `;`) that is
        discarded. A single trailing `;` is also ignored.

        Args:
            ltp_taxonomy_str: Taxonomy string from the LTP database.

        Returns:
            Tuple of (list of taxa, species name prefixed with `s__`). A
            ` subsp. ` designation is removed from the returned species name
            but retained in the returned taxa.

        Raises:
            ValueError: If the terminal taxon does not appear to be a valid
                binomial species name (including when it is empty).
        """

        if ';type sp.|' in ltp_taxonomy_str:
            taxa = ltp_taxonomy_str.split(';type sp.|')[0].split(';')
        elif ';|' in ltp_taxonomy_str:
            taxa = ltp_taxonomy_str.split(';|')[0].split(';')
        elif '|' in ltp_taxonomy_str:
            taxa = ltp_taxonomy_str.split('|')[0].split(';')
        elif ltp_taxonomy_str.endswith(';'):
            # endswith() rather than [-1] so an empty string does not raise IndexError
            taxa = ltp_taxonomy_str[0:-1].split(';')
        else:
            taxa = ltp_taxonomy_str.split(';')

        sp = taxa[-1]
        if ' subsp. ' in sp:
            # keep only `Genus species`
            sp = ' '.join(sp.split()[0:2])

        # validate that terminal taxon appears to be a valid binomial
        # species name; raise (rather than `assert`) so the check is not
        # stripped when Python runs with -O
        if (not sp or sp[0].islower() or any(c.isdigit() for c in sp)
                or any(c.isupper() for c in sp[1:])):
            raise ValueError(
                f'Invalid LTP species name {sp!r} parsed from taxonomy '
                f'string {ltp_taxonomy_str!r}: {taxa}')

        return taxa, 's__' + sp

    def parse_ltp_metadata(self, type_gids, cur_genomes):
        """Parse Living Tree Project 16S rRNA metadata.

        For each genome, LTP BLAST results are read from
        `<dir of genomic file>/<self.ltp_dir>/<self.ltp_results_file>`.
        Genomes without this file are skipped.

        Args:
            type_gids: Genome IDs to parse LTP results for.
            cur_genomes: Mapping of genome ID to objects exposing a
                `genomic_file` path.

        Returns:
            defaultdict(list) mapping genome ID to a list of LTP_METADATA
            hits, one per non-blank line of the results file.
        """

        metadata = defaultdict(list)
        for gid in type_gids:
            genome_path = os.path.dirname(
                os.path.abspath(cur_genomes[gid].genomic_file))
            ltp_file = os.path.join(genome_path, self.ltp_dir,
                                    self.ltp_results_file)
            if not os.path.exists(ltp_file):
                continue

            with open(ltp_file, encoding='utf-8') as f:
                header = f.readline().strip().split('\t')

                taxonomy_index = header.index('taxonomy')
                ssu_len_index = header.index('length')
                evalue_index = header.index('blast_evalue')
                bitscore_index = header.index('blast_bitscore')
                aln_len_index = header.index('blast_align_len')
                pi_index = header.index('blast_perc_identity')

                for line in f:
                    if not line.strip():
                        # tolerate blank lines (e.g. a trailing empty line)
                        continue

                    tokens = line.strip().split('\t')

                    taxonomy = tokens[taxonomy_index]
                    ssu_len = int(tokens[ssu_len_index])
                    evalue = float(tokens[evalue_index])
                    bitscore = float(tokens[bitscore_index])
                    aln_len = int(tokens[aln_len_index])
                    pi = float(tokens[pi_index])

                    # guard against a zero-length SSU (ZeroDivisionError)
                    perc_aln = aln_len * 100.0 / ssu_len if ssu_len > 0 else 0.0

                    taxa, sp = self._parse_ltp_taxonomy_str(taxonomy)

                    metadata[gid].append(
                        self.LTP_METADATA(taxonomy=taxonomy,
                                          taxa=taxa,
                                          species=sp,
                                          ssu_len=ssu_len,
                                          evalue=evalue,
                                          bitscore=bitscore,
                                          aln_len=aln_len,
                                          perc_iden=pi,
                                          perc_aln=perc_aln))

        return metadata

    def ltp_defined_species(self, ltp_taxonomy_file):
        """Get all species present in the LTP database.

        Args:
            ltp_taxonomy_file: Tab-separated file whose second column holds
                an LTP taxonomy string. Blank lines are ignored.

        Returns:
            Set of species names, each prefixed with `s__`.
        """

        ltp_species = set()
        with open(ltp_taxonomy_file, encoding='utf-8') as f:
            for line in f:
                if not line.strip():
                    # a blank line has no second column and would raise IndexError
                    continue

                tokens = line.strip().split('\t')

                taxonomy = tokens[1]
                _taxa, sp = self._parse_ltp_taxonomy_str(taxonomy)
                ltp_species.add(sp)

        return ltp_species

    def ltp_species(self, gid, ltp_metadata):
        """Return the set of species supported by trusted LTP hits for `gid`.

        A hit is trusted when its percent identity, percent alignment and SSU
        length meet the configured minimums and its e-value is strictly below
        the configured e-value threshold.
        """

        def is_trusted(hit):
            return (hit.perc_iden >= self.ltp_pi_threshold
                    and hit.perc_aln >= self.ltp_pa_threshold
                    and hit.ssu_len >= self.ltp_ssu_len_threshold
                    and hit.evalue < self.ltp_evalue_threshold)

        return {hit.species for hit in ltp_metadata[gid] if is_trusted(hit)}

    def check_strain_ani(self, gid_anis, untrustworthy_gids):
        """Check if genomes meet strain ANI criteria.

        Returns False as soon as a pair of genomes, neither of which is in
        `untrustworthy_gids`, has an ANI below 99%; otherwise True.
        """

        trusted_pairs = (
            (first, second)
            for first, second in combinations(gid_anis, 2)
            if first not in untrustworthy_gids
            and second not in untrustworthy_gids)

        return not any(gid_anis[first][second] < 99
                       for first, second in trusted_pairs)

    def resolve_by_intra_specific_ani(self, gid_anis):
        """Resolve by removing intra-specific genomes with divergent ANI values.

        Each iteration marks as untrustworthy the remaining genome with the
        lowest mean ANI to the other remaining genomes. It stops when the
        remaining genomes pass `check_strain_ani`.

        Args:
            gid_anis: Mapping gid -> {other gid: ANI}.

        Returns:
            (True, {gid: reason}) if resolved; otherwise (False, {}). Returns
            (False, {}) in these cases:
              - there are <= 2 genomes;
              - resolution would leave <= 2 genomes;
              - no remaining genome has a mean ANI below 100.
        """

        if len(gid_anis) <= 2:
            return False, {}

        # consider most divergent genome as untrustworthy
        untrustworthy_gids = {}
        while True:
            # find most divergent of the remaining genomes
            min_ani = 100
            untrustworthy_gid = None
            for gid in gid_anis:
                if gid in untrustworthy_gids:
                    continue

                mean_ani = np_mean([
                    ani for cur_gid, ani in gid_anis[gid].items()
                    if cur_gid not in untrustworthy_gids
                ])
                if mean_ani < min_ani:
                    min_ani = mean_ani
                    untrustworthy_gid = gid

            if untrustworthy_gid is None:
                # no genome is divergent (e.g. every mean ANI is 100); without
                # this guard `None` would be recorded as a genome ID
                return False, {}

            untrustworthy_gids[
                untrustworthy_gid] = f'{min_ani:.2f}% ANI to other type strain genomes'

            if self.check_strain_ani(gid_anis, untrustworthy_gids):
                return True, untrustworthy_gids

            remaining_genomes = len(gid_anis) - len(untrustworthy_gids)
            if remaining_genomes <= 2:
                return False, {}

    def resolve_by_ncbi_types(self, gid_anis, type_gids, cur_genomes):
        """Resolve by consulting NCBI type material metadata.

        Genomes not marked as type strains at NCBI are flagged as
        untrustworthy.

        Returns:
            (True, {gid: reason}) when all of these hold:
              - at least one genome was flagged;
              - at least one NCBI type strain genome remains;
              - the remaining genomes pass the strain ANI check.
            Otherwise (False, {}).
        """

        reason = 'Not classified as assembled from type material at NCBI'
        untrustworthy_gids = {}
        ncbi_type_gids = []
        for gid in type_gids:
            if cur_genomes[gid].is_ncbi_type_strain():
                ncbi_type_gids.append(gid)
            else:
                untrustworthy_gids[gid] = reason

        passes_ani = self.check_strain_ani(gid_anis, untrustworthy_gids)
        if passes_ani and untrustworthy_gids and ncbi_type_gids:
            return True, untrustworthy_gids

        return False, {}

    def resolve_by_ncbi_reps(self, gid_anis, type_gids, cur_genomes):
        """Resolve by considering genomes annotated as representative genomes at NCBI.

        Genomes not marked as NCBI representatives are flagged as
        untrustworthy.

        Returns:
            (True, {gid: reason}) when at least one NCBI representative
            genome exists and the remaining genomes pass the strain ANI
            check; otherwise (False, {}). Unlike the other resolvers, the
            returned dict may be empty on success.
        """

        reason = 'Excluded in favour of RefSeq representative or reference genome'
        untrustworthy_gids = {}
        ncbi_rep_gids = []
        for gid in type_gids:
            if cur_genomes[gid].is_ncbi_representative():
                ncbi_rep_gids.append(gid)
            else:
                untrustworthy_gids[gid] = reason

        passes_ani = self.check_strain_ani(gid_anis, untrustworthy_gids)
        if passes_ani and ncbi_rep_gids:
            return True, untrustworthy_gids

        return False, {}

    def resolve_gtdb_family(self, gid_anis, ncbi_sp, type_gids, cur_genomes):
        """Resolve by identifying genomes with a conflicting GTDB family assignment.

        The expected family is the GTDB family of the genome returned by
        `cur_genomes.gtdb_type_species_of_genus` for the genus of `ncbi_sp`.
        Genomes in any other GTDB family are flagged as untrustworthy.

        Returns:
            (True, {gid: reason}) when all of these hold:
              - at least one genome was flagged;
              - at least one genome is in the expected family;
              - the remaining genomes pass the strain ANI check.
            Otherwise (False, {}), including when no genus type species
            genome is found.
        """

        gtdb_genus_rep = cur_genomes.gtdb_type_species_of_genus(
            'g__' + generic_name(ncbi_sp))
        if not gtdb_genus_rep:
            return False, {}

        expected_family = cur_genomes[gtdb_genus_rep].gtdb_taxa.family

        untrustworthy_gids = {}
        num_matched = 0
        for gid in type_gids:
            family = cur_genomes[gid].gtdb_taxa.family
            if family == expected_family:
                num_matched += 1
            else:
                # genome is classified to a different GTDB family than
                # expected for this species
                untrustworthy_gids[gid] = (
                    f'Conflicting GTDB family assignment of {family}, '
                    f'expected {expected_family}')

        passes_ani = self.check_strain_ani(gid_anis, untrustworthy_gids)

        # conflict is resolved if the remaining genomes pass the ANI test
        if passes_ani and untrustworthy_gids and num_matched > 0:
            return True, untrustworthy_gids

        return False, {}

    def resolve_gtdb_genus(self, gid_anis, ncbi_sp, type_gids, cur_genomes):
        """Resolve by identifying genomes with a conflicting GTDB genus assignments.

        A genome is flagged as untrustworthy when `canonical_taxon` of its
        GTDB genus differs from `g__` + the generic name of `ncbi_sp`.

        Returns:
            (True, {gid: reason}) when all of these hold:
              - at least one genome was flagged;
              - at least one genome matched the NCBI genus;
              - the remaining genomes pass the strain ANI check.
            Otherwise (False, {}).
        """

        expected_genus = 'g__' + generic_name(ncbi_sp)

        untrustworthy_gids = {}
        num_matched = 0
        for gid in type_gids:
            gtdb_genus = cur_genomes[gid].gtdb_taxa.genus
            if expected_genus == canonical_taxon(gtdb_genus):
                num_matched += 1
            else:
                untrustworthy_gids[gid] = (
                    f'Conflicting GTDB genus assignment of {gtdb_genus}, '
                    f'expected {expected_genus}')

        passes_ani = self.check_strain_ani(gid_anis, untrustworthy_gids)
        if passes_ani and untrustworthy_gids and num_matched > 0:
            return True, untrustworthy_gids

        return False, {}

    def resolve_gtdb_species(self, gid_anis, ncbi_sp, type_gids, cur_genomes):
        """Resolve by identifying genomes with a conflicting GTDB species assignments to different type material.

        A genome whose GTDB specific epithet differs from that of `ncbi_sp`
        is flagged only when it belongs to a GTDB species (not `s__`) whose
        representative genome is an effective type strain. Such genomes have
        been assigned to another species defined by type material.

        Returns:
            (True, {gid: reason}) when all of these hold:
              - at least one genome was flagged;
              - at least one genome matched the expected epithet;
              - the remaining genomes pass the strain ANI check.
            Otherwise (False, {}).
        """

        expected_epithet = specific_epithet(ncbi_sp)

        untrustworthy_gids = {}
        num_matched = 0
        for gid in type_gids:
            if expected_epithet == cur_genomes[gid].gtdb_taxa.specific_epithet:
                num_matched += 1
                continue

            gtdb_sp = cur_genomes[gid].gtdb_taxa.species
            if gtdb_sp == 's__':
                continue

            rep_gid = cur_genomes.gtdb_sp_rep(gtdb_sp)
            if not cur_genomes[rep_gid].is_effective_type_strain():
                continue

            # genome has been assigned to another species defined by a
            # type strain genome
            ani, af = self.fastani.symmetric_ani_cached(
                gid, rep_gid, cur_genomes[gid].genomic_file,
                cur_genomes[rep_gid].genomic_file)
            untrustworthy_gids[gid] = (
                f'Conflicting GTDB species assignment of {gtdb_sp} '
                f'[ANI={ani:.2f}%; AF={af:.2f}%]')

        passes_ani = self.check_strain_ani(gid_anis, untrustworthy_gids)

        # conflict is resolved if the remaining genomes pass the ANI test
        if passes_ani and untrustworthy_gids and num_matched > 0:
            return True, untrustworthy_gids

        return False, {}

    def resolve_validated_untrustworthy_ncbi_genomes(self, gid_anis, ncbi_sp,
                                                     type_gids, ltp_metadata,
                                                     ltp_defined_species,
                                                     cur_genomes):
        """Resolve by identifying genomes marked as `untrustworthy as type` at NCBI and with conflicting LTP assignments.

        Only applies when `ncbi_sp` is present in the LTP database. A genome
        is flagged when both of these hold:
          - its RefSeq exclusion note contains `untrustworthy as type`
            (case-insensitive);
          - it has trusted LTP hits, none of which are to `ncbi_sp`.

        Returns:
            (True, {gid: reason}) when at least one genome was flagged and
            the remaining genomes pass the strain ANI check; otherwise
            (False, {}).
        """

        if ncbi_sp not in ltp_defined_species:
            return False, {}

        untrustworthy_gids = {}
        for gid in type_gids:
            refseq_note = cur_genomes[gid].excluded_from_refseq_note.lower()
            if 'untrustworthy as type' not in refseq_note:
                continue

            trusted_ltp_sp = self.ltp_species(gid, ltp_metadata)
            if ncbi_sp not in trusted_ltp_sp and trusted_ltp_sp:
                untrustworthy_gids[gid] = (
                    "Conflicting 16S rRNA hits to LTP database of "
                    f"{' / '.join(set(trusted_ltp_sp))}")

        passes_ani = self.check_strain_ani(gid_anis, untrustworthy_gids)

        # conflict is resolved if the remaining genomes pass the ANI test
        if passes_ani and untrustworthy_gids:
            return True, untrustworthy_gids

        return False, {}

    def resolve_ltp_conflict(self, gid_anis, ncbi_sp, type_gids, ltp_metadata,
                             require_conflict_sp):
        """Resolve by considering BLAST hits of 16S rRNA genes to LTP database.

        Only trusted LTP hits are considered (same thresholds as
        `ltp_species`). A genome is flagged when it has no trusted hit to
        `ncbi_sp` and at least `require_conflict_sp` trusted hits to other
        species. With `require_conflict_sp=0`, a genome with no trusted
        hits at all is flagged. A genome counts as matching when it has more
        trusted hits to `ncbi_sp` than to other species.

        Returns:
            (True, {gid: reason}) when all of these hold:
              - at least one genome was flagged;
              - at least one genome matched;
              - the remaining genomes pass the strain ANI check.
            Otherwise (False, {}).
        """

        untrustworthy_gids = {}
        num_matching_expected = 0
        for gid in type_gids:
            trusted_sp = [
                hit.species for hit in ltp_metadata[gid]
                if (hit.perc_iden >= self.ltp_pi_threshold
                    and hit.perc_aln >= self.ltp_pa_threshold
                    and hit.ssu_len >= self.ltp_ssu_len_threshold
                    and hit.evalue < self.ltp_evalue_threshold)
            ]
            num_expected = sum(1 for sp in trusted_sp if sp == ncbi_sp)
            unexpected_sp = [sp for sp in trusted_sp if not sp == ncbi_sp]

            if num_expected == 0 and len(unexpected_sp) >= require_conflict_sp:
                if unexpected_sp:
                    reason = f"Conflicting 16S rRNA hits to LTP database of {' / '.join(set(unexpected_sp))}"
                else:
                    reason = "Lack of 16S rRNA hits to LTP database"
                untrustworthy_gids[gid] = reason
            elif num_expected > len(unexpected_sp):
                num_matching_expected += 1

        passes_ani = self.check_strain_ani(gid_anis, untrustworthy_gids)
        if passes_ani and untrustworthy_gids and num_matching_expected > 0:
            return True, untrustworthy_gids

        return False, {}

    def parse_untrustworthy_type_ledger(self, untrustworthy_type_ledger):
        """Parse file indicating genomes considered to be untrustworthy as type material.

        Args:
            untrustworthy_type_ledger: Tab-separated file with a header line.
                The first column holds a genome ID. The columns
                `NCBI species` and `Reason for declaring untrustworthy` are
                located by header name. Blank lines are ignored.

        Returns:
            Dict mapping canonical genome ID (via `canonical_gid`) to a
            (NCBI species, reason) tuple.
        """

        manual_untrustworthy_types = {}
        with open(untrustworthy_type_ledger, encoding='utf-8') as f:
            header = f.readline().strip().split('\t')

            ncbi_sp_index = header.index('NCBI species')
            reason_index = header.index('Reason for declaring untrustworthy')

            for line in f:
                if not line.strip():
                    # tolerate blank lines (e.g. trailing newlines at EOF)
                    continue

                tokens = line.strip().split('\t')

                gid = canonical_gid(tokens[0])
                manual_untrustworthy_types[gid] = (tokens[ncbi_sp_index],
                                                   tokens[reason_index])

        return manual_untrustworthy_types

    def sp_with_mult_type_strains(self, cur_genomes):
        """Identify NCBI species with multiple type strain of species genomes.

        Returns:
            Dict mapping NCBI species name to the set of effective type
            strain genome IDs, restricted to species with more than one
            such genome. Genomes whose NCBI species is `s__` are ignored.
        """

        type_gids_by_sp = defaultdict(set)
        for gid in cur_genomes:
            genome = cur_genomes[gid]
            if not genome.is_effective_type_strain():
                continue

            ncbi_sp = genome.ncbi_taxa.species
            if ncbi_sp == 's__':
                # NCBI has genomes marked as assembled from type material
                # that do not actually have a binomial species name
                continue

            type_gids_by_sp[ncbi_sp].add(gid)

        return {
            sp: gids
            for sp, gids in type_gids_by_sp.items()
            if len(gids) > 1
        }

    def calculate_type_strain_ani(self, ncbi_sp, type_gids, cur_genomes,
                                  use_pickled_results):
        """Calculate pairwise ANI between type strain genomes.

        Args:
            ncbi_sp: NCBI species name. The first 3 characters (presumably
                `s__`) are dropped to name the pickle file caching the
                pairwise FastANI results.
            type_gids: Type strain genome IDs of the species.
            cur_genomes: Genome set providing `genomic_files`.
            use_pickled_results: If True, load previously pickled results
                instead of running FastANI (and writing the pickle).

        Returns:
            Tuple (all_similar, anis, afs, gid_anis, gid_afs):
              - all_similar: False if any pair has ANI < 99 or AF < 0.65;
              - anis, afs: the per-pair values;
              - gid_anis, gid_afs: symmetric gid -> {gid: value} lookups.
        """

        ncbi_sp_str = ncbi_sp[3:].lower().replace(' ', '_')
        pickle_file = os.path.join(self.ani_pickle_dir, f'{ncbi_sp_str}.pkl')
        if not use_pickled_results:  # ***
            ani_af = self.fastani.pairwise(type_gids,
                                           cur_genomes.genomic_files)
            # context manager guarantees the pickle is flushed and closed
            with open(pickle_file, 'wb') as f:
                pickle.dump(ani_af, f)
        else:
            # NOTE: pickle.load can execute arbitrary code; only safe for
            # pickles produced by this pipeline (e.g. a previous run)
            with open(pickle_file, 'rb') as f:
                ani_af = pickle.load(f)

        anis = []
        afs = []
        gid_anis = defaultdict(dict)
        gid_afs = defaultdict(dict)
        all_similar = True
        for gid1, gid2 in combinations(type_gids, 2):
            ani, af = FastANI.symmetric_ani(ani_af, gid1, gid2)
            if ani < 99 or af < 0.65:
                all_similar = False

            anis.append(ani)
            afs.append(af)

            gid_anis[gid1][gid2] = ani
            gid_anis[gid2][gid1] = ani

            gid_afs[gid1][gid2] = af
            gid_afs[gid2][gid1] = af

        return all_similar, anis, afs, gid_anis, gid_afs

    def run(self, cur_gtdb_metadata_file, cur_genomic_path_file,
            qc_passed_file, ncbi_genbank_assembly_file, ltp_taxonomy_file,
            gtdb_type_strains_ledger, untrustworthy_type_ledger,
            ncbi_env_bioproject_ledger):
        """Resolve cases where a species has multiple genomes assembled from the type strain."""

        # get species in LTP reference database
        self.logger.info(
            'Determining species defined in LTP reference database.')
        ltp_defined_species = self.ltp_defined_species(ltp_taxonomy_file)
        self.logger.info(
            f' - identified {len(ltp_defined_species):,} species.')

        # create current GTDB genome sets
        self.logger.info('Creating current GTDB genome set.')
        cur_genomes = Genomes()
        cur_genomes.load_from_metadata_file(
            cur_gtdb_metadata_file,
            gtdb_type_strains_ledger=gtdb_type_strains_ledger,
            create_sp_clusters=False,
            qc_passed_file=qc_passed_file,
            ncbi_genbank_assembly_file=ncbi_genbank_assembly_file,
            untrustworthy_type_ledger=untrustworthy_type_ledger,
            ncbi_env_bioproject_ledger=ncbi_env_bioproject_ledger)
        cur_genomes.load_genomic_file_paths(cur_genomic_path_file)

        # parsing genomes manually established to be untrustworthy as type
        self.logger.info(
            'Determining genomes manually annotated as untrustworthy as type.')
        manual_untrustworthy_types = self.parse_untrustworthy_type_ledger(
            untrustworthy_type_ledger)
        self.logger.info(
            f' - identified {len(manual_untrustworthy_types):,} genomes manually annotated as untrustworthy as type.'
        )

        # Identify NCBI species with multiple genomes assembled from type strain of species. This
        # is done using a series of heuristics that aim to ensure that the selected type strain
        # genome is reliable. More formal evaluation and a manuscript descirbing this selection
        # process is ultimately required. Ideally, the community will eventually adopt a
        # database that indicates a single `type genome assembly` for each species instead
        # of just indicating a type strain from which many (sometimes dissimilar) assemblies exist.
        self.logger.info(
            'Determining number of type strain genomes in each NCBI species.')
        multi_type_strains_sp = self.sp_with_mult_type_strains(cur_genomes)
        self.logger.info(
            f' - identified {len(multi_type_strains_sp):,} NCBI species with multiple assemblies indicated as being type strain genomes.'
        )

        # resolve species with multiple type strain genomes
        fout = open(
            os.path.join(self.output_dir, 'multi_type_strain_species.tsv'),
            'w')
        fout.write(
            'NCBI species\tNo. type strain genomes\t>=99% ANI\tMean ANI\tStd ANI\tMean AF\tStd AF\tResolution\tGenome IDs\n'
        )

        fout_genomes = open(
            os.path.join(self.output_dir, 'type_strain_genomes.tsv'), 'w')
        fout_genomes.write(
            'Genome ID\tUntrustworthy\tNCBI species\tGTDB genus\tGTDB species\tLTP species\tConflict with prior GTDB assignment'
        )
        fout_genomes.write(
            '\tMean ANI\tStd ANI\tMean AF\tStd AF\tExclude from RefSeq\tNCBI taxonomy\tGTDB taxonomy\tReason for GTDB untrustworthy as type\n'
        )

        fout_unresolved = open(
            os.path.join(self.output_dir,
                         'unresolved_type_strain_genomes.tsv'), 'w')
        fout_unresolved.write(
            'Genome ID\tNCBI species\tGTDB genus\tGTDB species\tLTP species')
        fout_unresolved.write(
            '\tMean ANI\tStd ANI\tMean AF\tStd AF\tExclude from RefSeq\tNCBI taxonomy\tGTDB taxonomy\n'
        )

        fout_high_divergence = open(
            os.path.join(self.output_dir,
                         'highly_divergent_type_strain_genomes.tsv'), 'w')
        fout_high_divergence.write(
            'Genome ID\tNCBI species\tGTDB genus\tGTDB species\tLTP species\tMean ANI\tStd ANI\tMean AF\tStd AF\tExclude from RefSeq\tNCBI taxonomy\tGTDB taxonomy\n'
        )

        fout_untrustworthy = open(
            os.path.join(self.output_dir, 'untrustworthy_type_material.tsv'),
            'w')
        fout_untrustworthy.write(
            'Genome ID\tNCBI species\tGTDB species\tLTP species\tReason for declaring untrustworthy\n'
        )

        for gid in manual_untrustworthy_types:
            ncbi_sp, reason = manual_untrustworthy_types[gid]
            fout_untrustworthy.write('{}\t{}\t{}\t{}\t{}\t{}\n'.format(
                gid, ncbi_sp, cur_genomes[gid].gtdb_taxa.species,
                '<not tested>', 'n/a', 'Manual curation: ' + reason))

        processed = 0
        num_divergent = 0
        unresolved_sp_count = 0

        ncbi_ltp_resolved = 0
        intra_ani_resolved = 0
        ncbi_type_resolved = 0
        ncbi_rep_resolved = 0
        gtdb_family_resolved = 0
        gtdb_genus_resolved = 0
        gtdb_sp_resolved = 0
        ltp_resolved = 0

        # *** Perhaps should be an external flag, but used right now to speed up debugging
        use_pickled_results = False
        if use_pickled_results:
            self.logger.warning(
                'Using previously calculated ANI results in: {}'.format(
                    self.ani_pickle_dir))

        prev_gtdb_sp_conflicts = 0

        self.logger.info(
            'Resolving species with multiple type strain genomes:')
        for ncbi_sp, type_gids in sorted(multi_type_strains_sp.items(),
                                         key=lambda kv: len(kv[1])):
            assert len(type_gids) > 1

            status_str = '-> Processing {} with {:,} type strain genomes [{:,} of {:,} ({:.2f}%)].'.format(
                ncbi_sp, len(type_gids), processed + 1,
                len(multi_type_strains_sp), (processed + 1) * 100.0 /
                len(multi_type_strains_sp)).ljust(128)
            sys.stdout.write('{}\r'.format(status_str))
            sys.stdout.flush()
            processed += 1

            # calculate ANI between type strain genomes
            all_similar, anis, afs, gid_anis, gid_afs = self.calculate_type_strain_ani(
                ncbi_sp, type_gids, cur_genomes, use_pickled_results)

            # read LTP metadata for genomes
            ltp_metadata = self.parse_ltp_metadata(type_gids, cur_genomes)

            untrustworthy_gids = {}
            gtdb_resolved_sp_conflict = False
            unresolved_species = False
            note = 'All type strain genomes have ANI >99% and AF >65%.'
            if not all_similar:
                note = ''

                # need to establish which genomes are untrustworthy as type
                num_divergent += 1
                unresolved_species = True

                # write out highly divergent cases for manual inspection;
                # these should be compared to the automated selection
                if np_mean(anis) < 95:
                    for gid in type_gids:
                        ltp_species = self.ltp_species(gid, ltp_metadata)

                        fout_high_divergence.write(
                            '{}\t{}\t{}\t{}\t{}\t{:.2f}\t{:.3f}\t{:.3f}\t{:.4f}\t{}\t{}\t{}\n'
                            .format(gid, ncbi_sp,
                                    cur_genomes[gid].gtdb_taxa.genus,
                                    cur_genomes[gid].gtdb_taxa.species,
                                    ' / '.join(ltp_species),
                                    np_mean(list(gid_anis[gid].values())),
                                    np_std(list(gid_anis[gid].values())),
                                    np_mean(list(gid_afs[gid].values())),
                                    np_std(list(gid_afs[gid].values())),
                                    cur_genomes[gid].excluded_from_refseq_note,
                                    cur_genomes[gid].ncbi_taxa,
                                    cur_genomes[gid].gtdb_taxa))

                # filter genomes marked as `untrustworthy as type` at NCBI and where the LTP
                # assignment also suggest the asserted type material is incorrect
                resolved, untrustworthy_gids = self.resolve_validated_untrustworthy_ncbi_genomes(
                    gid_anis, ncbi_sp, type_gids, ltp_metadata,
                    ltp_defined_species, cur_genomes)
                if resolved:
                    note = "Species resolved by removing genomes considered `untrustworthy as type` and with a LTP BLAST hit confirming the assembly is likely untrustworthy"
                    ncbi_ltp_resolved += 1

                # try to resolve by LTP 16S BLAST results
                if not resolved:
                    resolved, untrustworthy_gids = self.resolve_ltp_conflict(
                        gid_anis, ncbi_sp, type_gids, ltp_metadata, 0)
                    if resolved:
                        note = 'Species resolved by identifying conflicting or lack of LTP BLAST results'
                        ltp_resolved += 1

                # try to resolve species using intra-specific ANI test
                if not resolved:
                    resolved, untrustworthy_gids = self.resolve_by_intra_specific_ani(
                        gid_anis)
                    if resolved:
                        note = 'Species resolved by intra-specific ANI test'
                        intra_ani_resolved += 1

                # try to resolve by GTDB family assignment
                if not resolved:
                    resolved, untrustworthy_gids = self.resolve_gtdb_family(
                        gid_anis, ncbi_sp, type_gids, cur_genomes)
                    if resolved:
                        note = 'Species resolved by consulting GTDB family classifications'
                        gtdb_family_resolved += 1

                # try to resolve by GTDB genus assignment
                if not resolved:
                    resolved, untrustworthy_gids = self.resolve_gtdb_genus(
                        gid_anis, ncbi_sp, type_gids, cur_genomes)
                    if resolved:
                        note = 'Species resolved by consulting GTDB genus classifications'
                        gtdb_genus_resolved += 1

                # try to resolve by GTDB species assignment
                if not resolved:
                    resolved, untrustworthy_gids = self.resolve_gtdb_species(
                        gid_anis, ncbi_sp, type_gids, cur_genomes)
                    if resolved:
                        note = 'Species resolved by consulting GTDB species classifications'
                        gtdb_sp_resolved += 1

                # try to resolve by considering genomes annotated as type material at NCBI,
                # which includes considering if genomes are marked as untrustworthy as type
                if not resolved:
                    resolved, untrustworthy_gids = self.resolve_by_ncbi_types(
                        gid_anis, type_gids, cur_genomes)
                    if resolved:
                        note = 'Species resolved by consulting NCBI assembled from type metadata'
                        ncbi_type_resolved += 1

                # try to resovle by considering genomes annotated as representative genomes at NCBI
                if not resolved:
                    resolved, untrustworthy_gids = self.resolve_by_ncbi_reps(
                        gid_anis, type_gids, cur_genomes)
                    if resolved:
                        note = 'Species resolved by considering NCBI representative genomes'
                        ncbi_rep_resolved += 1

                if resolved:
                    unresolved_species = False

                    # check if type strain genomes marked as trusted or untrusted conflict
                    # with current GTDB species assignment
                    untrustworthy_gtdb_sp_match = False
                    trusted_gtdb_sp_match = False
                    for gid in type_gids:
                        gtdb_canonical_epithet = canonical_taxon(
                            specific_epithet(
                                cur_genomes[gid].gtdb_taxa.species))
                        if gtdb_canonical_epithet == specific_epithet(ncbi_sp):
                            if gid in untrustworthy_gids:
                                untrustworthy_gtdb_sp_match = True
                            else:
                                trusted_gtdb_sp_match = True

                    if untrustworthy_gtdb_sp_match and not trusted_gtdb_sp_match:
                        prev_gtdb_sp_conflicts += 1
                        gtdb_resolved_sp_conflict = True
                else:
                    note = 'Species is unresolved; manual curation is required!'
                    unresolved_sp_count += 1

                if unresolved_species:
                    for gid in type_gids:
                        ltp_species = self.ltp_species(gid, ltp_metadata)

                        fout_unresolved.write(
                            '{}\t{}\t{}\t{}\t{}\t{:.2f}\t{:.3f}\t{:.3f}\t{:.4f}\t{}\t{}\t{}\n'
                            .format(gid, ncbi_sp,
                                    cur_genomes[gid].gtdb_taxa.genus,
                                    cur_genomes[gid].gtdb_taxa.species,
                                    ' / '.join(ltp_species),
                                    np_mean(list(gid_anis[gid].values())),
                                    np_std(list(gid_anis[gid].values())),
                                    np_mean(list(gid_afs[gid].values())),
                                    np_std(list(gid_afs[gid].values())),
                                    cur_genomes[gid].excluded_from_refseq_note,
                                    cur_genomes[gid].ncbi_taxa,
                                    cur_genomes[gid].gtdb_taxa))

            # remove genomes marked as untrustworthy as type at NCBI if one or more potential type strain genomes remaining
            ncbi_untrustworthy_gids = set([
                gid for gid in type_gids if 'untrustworthy as type' in
                cur_genomes[gid].excluded_from_refseq_note
            ])
            if len(type_gids - set(untrustworthy_gids) -
                   ncbi_untrustworthy_gids) >= 1:
                for gid in ncbi_untrustworthy_gids:
                    untrustworthy_gids[
                        gid] = "Genome annotated as `untrustworthy as type` at NCBI and there are other potential type strain genomes available"

            # report cases where genomes marked as untrustworthy as type at NCBI are being retained as potential type strain genomes
            num_ncbi_untrustworthy = len(ncbi_untrustworthy_gids)
            for gid in type_gids:
                if (gid not in untrustworthy_gids and 'untrustworthy as type'
                        in cur_genomes[gid].excluded_from_refseq_note):
                    self.logger.warning(
                        "Retaining genome {} from {} despite being marked as `untrustworthy as type` at NCBI [{:,} of {:,} considered untrustworthy]."
                        .format(gid, ncbi_sp, num_ncbi_untrustworthy,
                                len(type_gids)))

            # write out genomes identified as being untrustworthy
            for gid, reason in untrustworthy_gids.items():
                ltp_species = self.ltp_species(gid, ltp_metadata)

                if 'untrustworthy as type' in cur_genomes[
                        gid].excluded_from_refseq_note:
                    reason += "; considered `untrustworthy as type` at NCBI"
                fout_untrustworthy.write('{}\t{}\t{}\t{}\t{}\n'.format(
                    gid, ncbi_sp, cur_genomes[gid].gtdb_taxa.species,
                    ' / '.join(ltp_species), reason))

                # Sanity check that if the untrustworthy genome has an LTP to only the
                # expected species, that all other genomes also have a hit to the
                # expected species (or potentially no hit). Otherwise, more consideration
                # should be given to the genome with the conflicting LTP hit.
                if len(ltp_species) == 1 and ncbi_sp in ltp_species:
                    other_sp = set()
                    for test_gid in type_gids:
                        ltp_species = self.ltp_species(test_gid, ltp_metadata)
                        if ltp_species and ncbi_sp not in ltp_species:
                            other_sp.update(ltp_species)

                    if other_sp:
                        self.logger.warning(
                            f'Genome {gid} marked as untrustworthy, but this conflicts with high confidence LTP 16S rRNA assignment.'
                        )

            # write out information about all type genomes
            for gid in type_gids:
                ltp_species = self.ltp_species(gid, ltp_metadata)

                fout_genomes.write(
                    '{}\t{}\t{}\t{}\t{}\t{}\t{}\t{:.2f}\t{:.3f}\t{:.3f}\t{:.4f}\t{}\t{}\t{}\t{}\n'
                    .format(gid, gid in untrustworthy_gids, ncbi_sp,
                            cur_genomes[gid].gtdb_taxa.genus,
                            cur_genomes[gid].gtdb_taxa.species,
                            ' / '.join(ltp_species), gtdb_resolved_sp_conflict,
                            np_mean(list(gid_anis[gid].values())),
                            np_std(list(gid_anis[gid].values())),
                            np_mean(list(gid_afs[gid].values())),
                            np_std(list(gid_afs[gid].values())),
                            cur_genomes[gid].excluded_from_refseq_note,
                            cur_genomes[gid].ncbi_taxa,
                            cur_genomes[gid].gtdb_taxa,
                            untrustworthy_gids.get(gid, '')))

            fout.write(
                '{}\t{}\t{}\t{:.2f}\t{:.3f}\t{:.3f}\t{:.4f}\t{}\t{}\n'.format(
                    ncbi_sp, len(type_gids), all_similar, np_mean(anis),
                    np_std(anis), np_mean(afs), np_std(afs), note,
                    ', '.join(type_gids)))

        sys.stdout.write('\n')
        fout.close()
        fout_unresolved.close()
        fout_high_divergence.close()
        fout_genomes.close()
        fout_untrustworthy.close()

        self.logger.info(
            f'Identified {num_divergent:,} species with 1 or more divergent type strain genomes.'
        )
        self.logger.info(
            f' - resolved {ncbi_ltp_resolved:,} species by removing NCBI `untrustworthy as type` genomes with a conflicting LTP 16S rRNA classifications.'
        )
        self.logger.info(
            f' - resolved {ltp_resolved:,} species by considering conflicting LTP 16S rRNA classifications.'
        )
        self.logger.info(
            f' - resolved {intra_ani_resolved:,} species by considering intra-specific ANI values.'
        )
        self.logger.info(
            f' - resolved {gtdb_family_resolved:,} species by considering conflicting GTDB family classifications.'
        )
        self.logger.info(
            f' - resolved {gtdb_genus_resolved:,} species by considering conflicting GTDB genus classifications.'
        )
        self.logger.info(
            f' - resolved {gtdb_sp_resolved:,} species by considering conflicting GTDB species classifications.'
        )
        self.logger.info(
            f' - resolved {ncbi_type_resolved:,} species by considering type material designations at NCBI.'
        )
        self.logger.info(
            f' - resolved {ncbi_rep_resolved:,} species by considering RefSeq reference and representative designations at NCBI.'
        )

        if unresolved_sp_count > 0:
            self.logger.warning(
                f'There are {unresolved_sp_count:,} unresolved species with multiple type strain genomes.'
            )
            self.logger.warning(
                'These should be handled before proceeding with the next step of GTDB species updating.'
            )
            self.logger.warning(
                "This can be done by manual curation and adding genomes to 'untrustworthy_type_ledger'."
            )

        self.logger.info(
            f'Identified {prev_gtdb_sp_conflicts:,} cases where resolved type strain conflicts with prior GTDB assignment.'
        )
class RepActions(object):
    """Perform initial actions required for changed representatives."""
    def __init__(self, ani_cache_file, cpus, output_dir, sp_priority_mngr=None):
        """Initialization.

        Parameters
        ----------
        ani_cache_file : str
            ANI cache file passed through to FastANI.
        cpus : int
            Number of CPUs passed through to FastANI.
        output_dir : str
            Directory in which `action_log.tsv` is created.
        sp_priority_mngr : object, optional
            Object used to resolve nomenclatural priority between genomes;
            must provide `has_priority()` and `priority()`. It is required by
            `action_improved_rep` and `action_naming_priority` and may also be
            assigned to `self.sp_priority_mngr` after construction.
        """

        self.output_dir = output_dir
        self.logger = logging.getLogger('timestamp')

        self.fastani = FastANI(ani_cache_file, cpus)

        # used by action_improved_rep() and action_naming_priority(); without
        # this assignment those methods fail with an AttributeError unless the
        # caller sets the attribute manually
        self.sp_priority_mngr = sp_priority_mngr

        # action parameters
        self.genomic_update_ani = 99.0
        self.genomic_update_af = 0.80

        self.new_rep_ani = 99.0
        self.new_rep_af = 0.80
        # increase in score (as returned by score_ani()) required to select a
        # new representative
        self.new_rep_qs_threshold = 10

        # NOTE(review): this handle is not closed by anything visible in this
        # class; callers should close `self.action_log` once all actions have
        # been applied so the log is flushed
        self.action_log = open(os.path.join(self.output_dir, 'action_log.tsv'),
                               'w')
        self.action_log.write(
            'Genome ID\tPrevious GTDB species\tAction\tParameters\n')

        # previous representative ID -> (new representative ID or None, action)
        self.new_reps = {}

    def rep_change_gids(self, rep_change_summary_file, field, value):
        """Get genomes with a specific change."""

        gids = {}
        with open(rep_change_summary_file) as f:
            header = f.readline().strip().split('\t')

            field_index = header.index(field)
            prev_sp_index = header.index('Previous GTDB species')

            for line in f:
                line_split = line.strip().split('\t')

                v = line_split[field_index]
                if v == value:
                    prev_sp = line_split[prev_sp_index]
                    gids[line_split[0]] = prev_sp

        return gids

    def top_ani_score_prev_rep(self, prev_rid, sp_cids, prev_genomes,
                               cur_genomes):
        """Identify genome in cluster with highest balanced ANI score to genomic file of representative in previous GTDB release."""

        max_score = -1e6
        max_rid = None
        max_ani = None
        max_af = None
        for cid in sp_cids:
            ani, af = self.fastani.symmetric_ani_cached(
                f'{prev_rid}-P', f'{cid}-C',
                prev_genomes[prev_rid].genomic_file,
                cur_genomes[cid].genomic_file)

            cur_score = cur_genomes[cid].score_ani(ani)
            if (cur_score > max_score
                    or (cur_score == max_score and ani > max_ani)):
                max_score = cur_score
                max_rid = cid
                max_ani = ani
                max_af = af

        return max_rid, max_score, max_ani, max_af

    def top_ani_score(self, prev_rid, sp_cids, cur_genomes):
        """Identify genome in cluster with highest balanced ANI score to representative genome.

        ANI/AF is calculated in both directions between `prev_rid` and each
        genome in `sp_cids` and combined with `symmetric_ani`. Candidates are
        ranked by their `score_ani()`. Ties are broken by higher ANI, matching
        `top_ani_score_prev_rep`, and then by the lexicographically smallest
        genome ID, so the result does not depend on set iteration order.

        Returns
        -------
        tuple
            (best genome ID, its score, its ANI, its AF), or
            (None, -1e6, None, None) if `sp_cids` is empty.
        """

        # sorted order makes tie resolution deterministic across runs
        cids = sorted(sp_cids)

        # calculate ANI between representative and genomes in species cluster
        # in both directions so that a symmetric value can be derived
        gid_pairs = []
        for cid in cids:
            gid_pairs.append((cid, prev_rid))
            gid_pairs.append((prev_rid, cid))

        ani_af = self.fastani.pairs(gid_pairs,
                                    cur_genomes.genomic_files,
                                    report_progress=False,
                                    check_cache=True)

        # find genome with top ANI score
        max_score = -1e6
        max_rid = None
        max_ani = None
        max_af = None
        for cid in cids:
            ani, af = symmetric_ani(ani_af, prev_rid, cid)

            cur_score = cur_genomes[cid].score_ani(ani)
            if (cur_score > max_score
                    or (max_rid is not None and cur_score == max_score
                        and ani > max_ani)):
                max_score = cur_score
                max_rid = cid
                max_ani = ani
                max_af = af

        return max_rid, max_score, max_ani, max_af

    def get_updated_rid(self, prev_rid):
        """Get updated representative."""

        if prev_rid in self.new_reps:
            gid, action = self.new_reps[prev_rid]
            return gid

        return prev_rid

    def update_rep(self, prev_rid, new_rid, action):
        """Update representative genome for GTDB species cluster."""

        if prev_rid in self.new_reps and self.new_reps[prev_rid][0] != new_rid:
            self.logger.warning(
                'Representative {} was reassigned multiple times: {} {}.'.
                format(prev_rid, self.new_reps[prev_rid], (new_rid, action)))
            self.logger.warning(
                'Assuming last reassignment of {}: {} has priority.'.format(
                    new_rid, action))

        self.new_reps[prev_rid] = (new_rid, action)

    def genomes_in_current_sp_cluster(self, prev_rid, prev_genomes,
                                      new_updated_sp_clusters, cur_genomes):
        """Get genomes in current species cluster."""

        assert prev_rid in prev_genomes.sp_clusters

        sp_cids = prev_genomes.sp_clusters[prev_rid]
        if prev_rid in new_updated_sp_clusters:
            sp_cids = sp_cids.union(new_updated_sp_clusters[prev_rid])
        sp_cids = sp_cids.intersection(cur_genomes)

        return sp_cids

    def action_genomic_lost(self, rep_change_summary_file, prev_genomes,
                            cur_genomes, new_updated_sp_clusters):
        """Handle species with lost representative genome."""

        # get genomes with specific changes
        self.logger.info(
            'Identifying species with lost representative genome.')
        genomic_lost_rids = self.rep_change_gids(rep_change_summary_file,
                                                 'GENOMIC_CHANGE', 'LOST')
        self.logger.info(
            f' ... identified {len(genomic_lost_rids):,} genomes.')

        # calculate ANI between previous and current genomes
        for prev_rid, prev_gtdb_sp in genomic_lost_rids.items():
            sp_cids = self.genomes_in_current_sp_cluster(
                prev_rid, prev_genomes, new_updated_sp_clusters, cur_genomes)

            params = {}
            if sp_cids:
                action = 'GENOMIC_CHANGE:LOST:REPLACED'

                new_rid, top_score, ani, af = self.top_ani_score_prev_rep(
                    prev_rid, sp_cids, prev_genomes, cur_genomes)
                assert (new_rid != prev_rid)

                params['new_rid'] = new_rid
                params['ani'] = ani
                params['af'] = af
                params['new_assembly_quality'] = cur_genomes[
                    new_rid].score_assembly()
                params['prev_assembly_quality'] = prev_genomes[
                    prev_rid].score_assembly()

                self.update_rep(prev_rid, new_rid, action)
            else:
                action = 'GENOMIC_CHANGE:LOST:SPECIES_RETIRED'
                self.update_rep(prev_rid, None, action)

            self.action_log.write('{}\t{}\t{}\t{}\n'.format(
                prev_rid, prev_gtdb_sp, action, params))

    def action_genomic_update(self, rep_change_summary_file, prev_genomes,
                              cur_genomes, new_updated_sp_clusters):
        """Handle representatives with updated genomes."""

        # get genomes with specific changes
        self.logger.info(
            'Identifying representatives with updated genomic files.')
        genomic_update_gids = self.rep_change_gids(rep_change_summary_file,
                                                   'GENOMIC_CHANGE', 'UPDATED')
        self.logger.info(
            f' ... identified {len(genomic_update_gids):,} genomes.')

        # calculate ANI between previous and current genomes
        assembly_score_change = []
        for prev_rid, prev_gtdb_sp in genomic_update_gids.items():
            # check that genome hasn't been lost which should
            # be handled differently
            assert prev_rid in cur_genomes

            ani, af = self.fastani.symmetric_ani_cached(
                f'{prev_rid}-P', f'{prev_rid}-C',
                prev_genomes[prev_rid].genomic_file,
                cur_genomes[prev_rid].genomic_file)

            params = {}
            params['ani'] = ani
            params['af'] = af
            params['prev_ncbi_accession'] = prev_genomes[prev_rid].ncbi_accn
            params['cur_ncbi_accession'] = cur_genomes[prev_rid].ncbi_accn
            assert prev_genomes[prev_rid].ncbi_accn != cur_genomes[
                prev_rid].ncbi_accn

            if ani >= self.genomic_update_ani and af >= self.genomic_update_af:
                params['prev_assembly_quality'] = prev_genomes[
                    prev_rid].score_assembly()
                params['new_assembly_quality'] = cur_genomes[
                    prev_rid].score_assembly()
                action = 'GENOMIC_CHANGE:UPDATED:MINOR_CHANGE'

                d = cur_genomes[prev_rid].score_assembly(
                ) - prev_genomes[prev_rid].score_assembly()
                assembly_score_change.append(d)
            else:
                sp_cids = self.genomes_in_current_sp_cluster(
                    prev_rid, prev_genomes, new_updated_sp_clusters,
                    cur_genomes)

                if sp_cids:
                    new_rid, top_score, ani, af = self.top_ani_score_prev_rep(
                        prev_rid, sp_cids, prev_genomes, cur_genomes)

                    if new_rid == prev_rid:
                        params['prev_assembly_quality'] = prev_genomes[
                            prev_rid].score_assembly()
                        params['new_assembly_quality'] = cur_genomes[
                            prev_rid].score_assembly()
                        action = 'GENOMIC_CHANGE:UPDATED:RETAINED'
                    else:
                        action = 'GENOMIC_CHANGE:UPDATED:REPLACED'
                        params['new_rid'] = new_rid
                        params['ani'] = ani
                        params['af'] = af
                        params['new_assembly_quality'] = cur_genomes[
                            new_rid].score_assembly()
                        params['prev_assembly_quality'] = prev_genomes[
                            prev_rid].score_assembly()

                        self.update_rep(prev_rid, new_rid, action)
                else:
                    action = 'GENOMIC_CHANGE:UPDATED:SPECIES_RETIRED'
                    self.update_rep(prev_rid, None, action)

            self.action_log.write('{}\t{}\t{}\t{}\n'.format(
                prev_rid, prev_gtdb_sp, action, params))

        self.logger.info(
            ' ... change in assembly score for updated genomes: {:.2f} +/- {:.2f}'
            .format(np_mean(assembly_score_change),
                    np_std(assembly_score_change)))

    def action_type_strain_lost(self, rep_change_summary_file, prev_genomes,
                                cur_genomes, new_updated_sp_clusters):
        """Handle representatives which have lost type strain genome status."""

        # get genomes with new NCBI species assignments
        self.logger.info(
            'Identifying representative that lost type strain genome status.')
        ncbi_type_species_lost = self.rep_change_gids(rep_change_summary_file,
                                                      'TYPE_STRAIN_CHANGE',
                                                      'LOST')
        self.logger.info(
            f' ... identified {len(ncbi_type_species_lost):,} genomes.')

        for prev_rid, prev_gtdb_sp in ncbi_type_species_lost.items():
            # check that genome hasn't been lost which should
            # be handled differently
            assert prev_rid in cur_genomes

            sp_cids = self.genomes_in_current_sp_cluster(
                prev_rid, prev_genomes, new_updated_sp_clusters, cur_genomes)

            prev_rep_score = cur_genomes[prev_rid].score_ani(100)
            new_rid, top_score, ani, af = self.top_ani_score(
                prev_rid, sp_cids, cur_genomes)

            params = {}
            params['prev_rid_prev_strain_ids'] = prev_genomes[
                prev_rid].ncbi_strain_identifiers
            params['prev_rid_cur_strain_ids'] = cur_genomes[
                prev_rid].ncbi_strain_identifiers
            params['prev_rid_prev_gtdb_type_designation'] = prev_genomes[
                prev_rid].gtdb_type_designation
            params['prev_rid_cur_gtdb_type_designation'] = cur_genomes[
                prev_rid].gtdb_type_designation
            params[
                'prev_rid_prev_gtdb_type_designation_sources'] = prev_genomes[
                    prev_rid].gtdb_type_designation_sources
            params['prev_rid_cur_gtdb_type_designation_sources'] = cur_genomes[
                prev_rid].gtdb_type_designation_sources

            if top_score > prev_rep_score:
                action = 'TYPE_STRAIN_CHANGE:LOST:REPLACED'
                assert (prev_rid != new_rid)

                params['new_rid'] = new_rid
                params['ani'] = ani
                params['af'] = af
                params['new_assembly_quality'] = cur_genomes[
                    new_rid].score_assembly()
                params['prev_assembly_quality'] = prev_genomes[
                    prev_rid].score_assembly()

                params['new_rid_strain_ids'] = prev_genomes[
                    new_rid].ncbi_strain_identifiers
                params['new_rid_gtdb_type_designation'] = prev_genomes[
                    new_rid].gtdb_type_designation
                params['new_rid_gtdb_type_designation_sources'] = prev_genomes[
                    new_rid].gtdb_type_designation_sources

                self.update_rep(prev_rid, new_rid, action)
            else:
                action = 'TYPE_STRAIN_CHANGE:LOST:RETAINED'

            self.action_log.write('{}\t{}\t{}\t{}\n'.format(
                prev_rid, prev_gtdb_sp, action, params))

    def action_domain_change(self, rep_change_summary_file, prev_genomes,
                             cur_genomes):
        """Handle representatives which have new domain assignments."""

        # get genomes with new NCBI species assignments
        self.logger.info(
            'Identifying representative with new domain assignments.')
        domain_changed = self.rep_change_gids(rep_change_summary_file,
                                              'DOMAIN_CHECK', 'REASSIGNED')
        self.logger.info(f' ... identified {len(domain_changed):,} genomes.')

        for prev_rid, prev_gtdb_sp in domain_changed.items():
            action = 'DOMAIN_CHECK:REASSIGNED'
            params = {}
            params['prev_gtdb_domain'] = prev_genomes[
                prev_rid].gtdb_taxa.domain
            params['cur_gtdb_domain'] = cur_genomes[prev_rid].gtdb_taxa.domain

            self.update_rep(prev_rid, None, action)
            self.action_log.write('{}\t{}\t{}\t{}\n'.format(
                prev_rid, prev_gtdb_sp, action, params))

    def action_improved_rep(self, prev_genomes, cur_genomes,
                            new_updated_sp_clusters):
        """Check if representative should be replace with higher quality genome."""

        self.logger.info(
            'Identifying improved representatives for GTDB species clusters.')
        num_gtdb_ncbi_type_sp = 0
        num_gtdb_type_sp = 0
        num_ncbi_type_sp = 0
        num_complete = 0
        num_isolate = 0
        anis = []
        afs = []
        improved_reps = {}
        for idx, (prev_rid,
                  cids) in enumerate(new_updated_sp_clusters.clusters()):
            if prev_rid not in cur_genomes:
                # indicates genome has been lost
                continue

            prev_gtdb_sp = new_updated_sp_clusters.get_species(prev_rid)
            statusStr = '-> Processing {:,} of {:,} ({:.2f}%) species [{}: {:,} new/updated genomes].'.format(
                idx + 1, len(new_updated_sp_clusters),
                float(idx + 1) * 100 / len(new_updated_sp_clusters),
                prev_gtdb_sp, len(cids)).ljust(86)
            sys.stdout.write('%s\r' % statusStr)
            sys.stdout.flush()

            # get latest representative of GTDB species clusters as it may
            # have been updated by a previous update rule
            prev_updated_rid = self.get_updated_rid(prev_rid)

            prev_rep_score = cur_genomes[prev_updated_rid].score_ani(100)
            new_rid, top_score, ani, af = self.top_ani_score(
                prev_updated_rid, cids, cur_genomes)

            params = {}
            action = None

            if top_score > prev_rep_score + self.new_rep_qs_threshold:
                assert (prev_updated_rid != new_rid)

                if (cur_genomes[prev_updated_rid].is_gtdb_type_strain(
                ) and cur_genomes[prev_updated_rid].ncbi_taxa.specific_epithet
                        != cur_genomes[new_rid].ncbi_taxa.specific_epithet
                        and self.sp_priority_mngr.has_priority(
                            cur_genomes, prev_updated_rid, new_rid)):
                    # GTDB species cluster should not be moved to a different type strain genome
                    # that has lower naming priority
                    self.logger.warning(
                        'Reassignments to type strain genome with lower naming priority is not allowed: {}/{}/{}, {}/{}/{}'
                        .format(
                            prev_updated_rid,
                            cur_genomes[prev_updated_rid].ncbi_taxa.species,
                            cur_genomes[prev_updated_rid].year_of_priority(),
                            new_rid, cur_genomes[new_rid].ncbi_taxa.species,
                            cur_genomes[new_rid].year_of_priority()))
                    continue

                action = 'IMPROVED_REP:REPLACED:HIGHER_QS'

                params['new_rid'] = new_rid
                params['ani'] = ani
                params['af'] = af
                params['new_assembly_quality'] = cur_genomes[
                    new_rid].score_assembly()
                params['prev_assembly_quality'] = cur_genomes[
                    prev_updated_rid].score_assembly()
                params['new_gtdb_type_strain'] = cur_genomes[
                    new_rid].is_gtdb_type_strain()
                params['prev_gtdb_type_strain'] = cur_genomes[
                    prev_updated_rid].is_gtdb_type_strain()
                params['new_ncbi_type_strain'] = cur_genomes[
                    new_rid].is_ncbi_type_strain()
                params['prev_ncbi_type_strain'] = cur_genomes[
                    prev_updated_rid].is_ncbi_type_strain()

                anis.append(ani)
                afs.append(af)

                improvement_list = []
                gtdb_type_improv = cur_genomes[new_rid].is_gtdb_type_strain(
                ) and not cur_genomes[prev_updated_rid].is_gtdb_type_strain()
                ncbi_type_improv = cur_genomes[new_rid].is_ncbi_type_strain(
                ) and not cur_genomes[prev_updated_rid].is_ncbi_type_strain()

                if gtdb_type_improv and ncbi_type_improv:
                    num_gtdb_ncbi_type_sp += 1
                    improvement_list.append(
                        'replaced with genome from type strain according to GTDB and NCBI'
                    )
                elif gtdb_type_improv:
                    num_gtdb_type_sp += 1
                    improvement_list.append(
                        'replaced with genome from type strain according to GTDB'
                    )
                elif ncbi_type_improv:
                    num_ncbi_type_sp += 1
                    improvement_list.append(
                        'replaced with genome from type strain according to NCBI'
                    )

                if cur_genomes[new_rid].is_isolate(
                ) and not cur_genomes[prev_updated_rid].is_isolate():
                    num_isolate += 1
                    improvement_list.append('MAG/SAG replaced with isolate')

                if cur_genomes[new_rid].is_complete_genome(
                ) and not cur_genomes[prev_updated_rid].is_complete_genome():
                    num_complete += 1
                    improvement_list.append('replaced with complete genome')

                if len(improvement_list) == 0:
                    improvement_list.append(
                        'replaced with higher quality genome')

                params['improvements'] = '; '.join(improvement_list)

                self.action_log.write('{}\t{}\t{}\t{}\n'.format(
                    prev_rid, prev_gtdb_sp, action, params))

                improved_reps[prev_rid] = (new_rid, action)

        sys.stdout.write('\n')
        self.logger.info(
            f' ... identified {len(improved_reps):,} species with improved representatives.'
        )
        self.logger.info(
            f'   ... {num_gtdb_ncbi_type_sp:,} replaced with GTDB/NCBI genome from type strain.'
        )
        self.logger.info(
            f'   ... {num_gtdb_type_sp:,} replaced with GTDB genome from type strain.'
        )
        self.logger.info(
            f'   ... {num_ncbi_type_sp:,} replaced with NCBI genome from type strain.'
        )
        self.logger.info(
            f'   ... {num_isolate:,} replaced MAG/SAG with isolate.')
        self.logger.info(
            f'   ... {num_complete:,} replaced with complete genome assembly.')
        self.logger.info(
            f' ... ANI = {np_mean(anis):.2f} +/- {np_std(anis):.2f}%; AF = {np_mean(afs)*100:.2f} +/- {np_std(afs)*100:.2f}%.'
        )

        return improved_reps

    def action_naming_priority(self, prev_genomes, cur_genomes,
                               new_updated_sp_clusters):
        """Check if representative should be replace with genome with higher nomenclatural priority."""

        self.logger.info(
            'Identifying genomes with naming priority in GTDB species clusters.'
        )

        out_file = os.path.join(self.output_dir, 'update_priority.tsv')
        fout = open(out_file, 'w')
        fout.write(
            'NCBI species\tGTDB species\tRepresentative\tStrain IDs\tRepresentative type sources\tPriority year\tGTDB type species\tGTDB type strain\tNCBI assembly type'
        )
        fout.write(
            '\tNCBI synonym\tGTDB synonym\tSynonym genome\tSynonym strain IDs\tSynonym type sources\tPriority year\tGTDB type species\tGTDB type strain\tSynonym NCBI assembly type'
        )
        fout.write('\tANI\tAF\tPriority note\n')

        num_higher_priority = 0
        assembly_score_change = []
        anis = []
        afs = []
        for idx, prev_rid in enumerate(prev_genomes.sp_clusters):
            # get type strain genomes in GTDB species cluster, including genomes new to this release
            type_strain_gids = [
                gid for gid in prev_genomes.sp_clusters[prev_rid]
                if gid in cur_genomes
                and cur_genomes[gid].is_effective_type_strain()
            ]
            if prev_rid in new_updated_sp_clusters:
                new_type_strain_gids = [
                    gid for gid in new_updated_sp_clusters[prev_rid]
                    if cur_genomes[gid].is_effective_type_strain()
                ]
                type_strain_gids.extend(new_type_strain_gids)

            if len(type_strain_gids) == 0:
                continue

            # check if representative has already been updated
            updated_rid = self.get_updated_rid(prev_rid)

            type_strain_sp = set([
                cur_genomes[gid].ncbi_taxa.species for gid in type_strain_gids
            ])
            if len(type_strain_sp) == 1 and updated_rid in type_strain_gids:
                continue

            updated_sp = cur_genomes[updated_rid].ncbi_taxa.species
            highest_priority_gid = updated_rid

            if updated_rid not in type_strain_gids:
                highest_priority_gid = None
                if updated_sp in type_strain_sp:
                    sp_gids = [
                        gid for gid in type_strain_gids
                        if cur_genomes[gid].ncbi_taxa.species == updated_sp
                    ]
                    hq_gid = select_highest_quality(sp_gids, cur_genomes)
                    highest_priority_gid = hq_gid

                #self.logger.warning('Representative is a non-type strain genome even though type strain genomes exist in species cluster: {}: {}, {}: {}'.format(
                #                    prev_rid, cur_genomes[prev_rid].is_effective_type_strain(), updated_rid, cur_genomes[updated_rid].is_effective_type_strain()))
                #self.logger.warning('Type strain genomes: {}'.format(','.join(type_strain_gids)))

            # find highest priority genome
            for sp in type_strain_sp:
                if sp == updated_sp:
                    continue

                # get highest quality genome from species
                sp_gids = [
                    gid for gid in type_strain_gids
                    if cur_genomes[gid].ncbi_taxa.species == sp
                ]
                hq_gid = select_highest_quality(sp_gids, cur_genomes)

                if highest_priority_gid is None:
                    highest_priority_gid = hq_gid
                else:
                    highest_priority_gid, note = self.sp_priority_mngr.priority(
                        cur_genomes, highest_priority_gid, hq_gid)

            # check if representative should be updated
            if highest_priority_gid != updated_rid:
                num_higher_priority += 1

                ani, af = self.fastani.symmetric_ani_cached(
                    updated_rid, highest_priority_gid,
                    cur_genomes[updated_rid].genomic_file,
                    cur_genomes[highest_priority_gid].genomic_file)

                anis.append(ani)
                afs.append(af)

                d = cur_genomes[highest_priority_gid].score_assembly(
                ) - cur_genomes[updated_rid].score_assembly()
                assembly_score_change.append(d)

                action = 'NOMENCLATURE_PRIORITY:REPLACED'
                params = {}
                params['prev_ncbi_species'] = cur_genomes[
                    updated_rid].ncbi_taxa.species
                params['prev_year_of_priority'] = cur_genomes[
                    updated_rid].year_of_priority()
                params['new_ncbi_species'] = cur_genomes[
                    highest_priority_gid].ncbi_taxa.species
                params['new_year_of_priority'] = cur_genomes[
                    highest_priority_gid].year_of_priority()
                params['new_rid'] = highest_priority_gid
                params['ani'] = ani
                params['af'] = af
                params['priority_note'] = note

                self.update_rep(prev_rid, highest_priority_gid, action)
                self.action_log.write('{}\t{}\t{}\t{}\n'.format(
                    prev_rid, cur_genomes[updated_rid].gtdb_taxa.species,
                    action, params))

                fout.write('{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}'.format(
                    cur_genomes[highest_priority_gid].ncbi_taxa.species,
                    cur_genomes[highest_priority_gid].gtdb_taxa.species,
                    highest_priority_gid, ','.join(
                        sorted(
                            cur_genomes[highest_priority_gid].strain_ids())),
                    ','.join(
                        sorted(cur_genomes[highest_priority_gid].
                               gtdb_type_sources())).upper().replace(
                                   'STRAININFO', 'StrainInfo'),
                    cur_genomes[highest_priority_gid].year_of_priority(),
                    cur_genomes[highest_priority_gid].is_gtdb_type_species(),
                    cur_genomes[highest_priority_gid].is_gtdb_type_strain(),
                    cur_genomes[highest_priority_gid].ncbi_type_material))
                fout.write('\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}'.format(
                    cur_genomes[updated_rid].ncbi_taxa.species,
                    cur_genomes[updated_rid].gtdb_taxa.species, updated_rid,
                    ','.join(sorted(cur_genomes[updated_rid].strain_ids())),
                    ','.join(
                        sorted(cur_genomes[updated_rid].gtdb_type_sources())
                    ).upper().replace('STRAININFO', 'StrainInfo'),
                    cur_genomes[updated_rid].year_of_priority(),
                    cur_genomes[updated_rid].is_gtdb_type_species(),
                    cur_genomes[updated_rid].is_gtdb_type_strain(),
                    cur_genomes[updated_rid].ncbi_type_material))
                fout.write('\t{:.3f}\t{:.4f}\t{}\n'.format(ani, af, note))

        fout.close()

        self.logger.info(
            f' ... identified {num_higher_priority:,} species with representative changed to genome with higher nomenclatural priority.'
        )
        self.logger.info(
            ' ... change in assembly score for new representatives: {:.2f} +/- {:.2f}'
            .format(np_mean(assembly_score_change),
                    np_std(assembly_score_change)))
        self.logger.info(' ... ANI: {:.2f} +/- {:.2f}'.format(
            np_mean(anis), np_std(anis)))
        self.logger.info(' ... AF: {:.2f} +/- {:.2f}'.format(
            np_mean(afs), np_std(afs)))

    def write_updated_clusters(self, prev_genomes, cur_genomes, new_reps,
                               new_updated_sp_clusters, out_file):
        """Write out updated GTDB species clusters.

        One TSV row is written for each previous species cluster whose
        representative was not retired, i.e. whose entry in `new_reps` does
        not map to a new representative of None.

        Args:
            prev_genomes: previous genome set; its `sp_clusters` provides the
                cluster representatives (iteration) and their GTDB species
                names (`get_species`).
            cur_genomes: current genome set; its members are collected into a
                set and passed to `genomes_in_current_sp_cluster`.
            new_reps: maps a previous representative ID to a
                (new representative ID, action) pair. A new ID of None marks a
                retired cluster, which is skipped. Representatives absent from
                the map keep their previous ID.
            new_updated_sp_clusters: expanded species clusters, passed through
                to `genomes_in_current_sp_cluster`.
            out_file: path of the TSV file to write (overwritten).
        """

        self.logger.info(
            'Writing updated GTDB species clusters to file: {}'.format(
                out_file))

        cur_genome_set = set(cur_genomes)

        num_clusters = 0
        # context manager guarantees the handle is closed even if a lookup
        # below raises part-way through the file
        with open(out_file, 'w') as fout:
            fout.write(
                'Representative genome\tGTDB species\tNo. clustered genomes\tClustered genomes\n'
            )

            for prev_rid in prev_genomes.sp_clusters:
                # clusters without an entry keep their previous representative
                new_rid, _action = new_reps.get(prev_rid, [prev_rid, None])
                if new_rid is None:
                    # cluster was retired; nothing to write
                    continue

                sp_cids = self.genomes_in_current_sp_cluster(
                    prev_rid, prev_genomes, new_updated_sp_clusters,
                    cur_genome_set)

                fout.write('{}\t{}\t{}\t{}\n'.format(
                    new_rid, prev_genomes.sp_clusters.get_species(prev_rid),
                    len(sp_cids), ','.join(sp_cids)))
                num_clusters += 1

        self.logger.info(f' ... wrote {num_clusters:,} clusters.')

    def run(self, rep_change_summary_file, prev_gtdb_metadata_file,
            prev_genomic_path_file, cur_gtdb_metadata_file,
            cur_genomic_path_file, uba_genome_paths, genomes_new_updated_file,
            qc_passed_file, gtdbtk_classify_file, ncbi_genbank_assembly_file,
            untrustworthy_type_file, gtdb_type_strains_ledger,
            sp_priority_ledger, *, load_cached_improved_reps=False):
        """Perform initial actions required for changed representatives.

        Loads the previous and current GTDB genome sets and builds expanded
        species clusters for new and updated genomes. It then applies each
        representative-change action in turn:

          - genomic lost
          - genomic update
          - type strain lost
          - domain change
          - improved representative
          - naming priority

        The following files are written to `self.output_dir`:

          - improved_reps.pkl: result of `action_improved_rep`
          - updated_species_reps.tsv: previous vs. new representative for
            every previous species cluster
          - updated_sp_clusters.tsv: the updated species clusters

        Args:
            rep_change_summary_file: passed to the action_genomic_lost,
                action_genomic_update, action_type_strain_lost and
                action_domain_change steps.
            prev_gtdb_metadata_file, cur_gtdb_metadata_file: metadata used to
                load the previous and current genome sets.
            prev_genomic_path_file, cur_genomic_path_file: genomic FASTA file
                paths for the previous and current genome sets.
            uba_genome_paths: UBA genome file, used both as metadata input
                and as an additional genomic path file for both genome sets.
            genomes_new_updated_file, gtdbtk_classify_file: used to build the
                expanded species clusters.
            qc_passed_file: used when loading the current genome set and when
                building the expanded species clusters.
            ncbi_genbank_assembly_file, untrustworthy_type_file,
            gtdb_type_strains_ledger: passed to both genome-set loaders.
            sp_priority_ledger: used to create the SpeciesPriorityManager.
            load_cached_improved_reps: if True, skip `action_improved_rep` and
                instead load its result from an improved_reps.pkl previously
                written to `self.output_dir` (debugging aid). The default
                (False) always recomputes it and refreshes the cache.
        """

        # create previous and current GTDB genome sets
        self.logger.info('Creating previous GTDB genome set.')
        prev_genomes = Genomes()
        prev_genomes.load_from_metadata_file(
            prev_gtdb_metadata_file,
            gtdb_type_strains_ledger=gtdb_type_strains_ledger,
            uba_genome_file=uba_genome_paths,
            ncbi_genbank_assembly_file=ncbi_genbank_assembly_file,
            untrustworthy_type_ledger=untrustworthy_type_file)
        self.logger.info(
            ' ... previous genome set has {:,} species clusters spanning {:,} genomes.'
            .format(len(prev_genomes.sp_clusters),
                    prev_genomes.sp_clusters.total_num_genomes()))

        self.logger.info('Creating current GTDB genome set.')
        cur_genomes = Genomes()
        cur_genomes.load_from_metadata_file(
            cur_gtdb_metadata_file,
            gtdb_type_strains_ledger=gtdb_type_strains_ledger,
            create_sp_clusters=False,
            uba_genome_file=uba_genome_paths,
            qc_passed_file=qc_passed_file,
            ncbi_genbank_assembly_file=ncbi_genbank_assembly_file,
            untrustworthy_type_ledger=untrustworthy_type_file)
        self.logger.info(
            f' ... current genome set contains {len(cur_genomes):,} genomes.')

        # get path to previous and current genomic FASTA files
        self.logger.info(
            'Reading path to previous and current genomic FASTA files.')
        prev_genomes.load_genomic_file_paths(prev_genomic_path_file)
        prev_genomes.load_genomic_file_paths(uba_genome_paths)
        cur_genomes.load_genomic_file_paths(cur_genomic_path_file)
        cur_genomes.load_genomic_file_paths(uba_genome_paths)

        # create expanded previous GTDB species clusters
        new_updated_sp_clusters = SpeciesClusters()

        self.logger.info(
            'Creating species clusters of new and updated genomes based on GTDB-Tk classifications.'
        )
        new_updated_sp_clusters.create_expanded_clusters(
            prev_genomes.sp_clusters, genomes_new_updated_file, qc_passed_file,
            gtdbtk_classify_file)

        self.logger.info(
            'Identified {:,} expanded species clusters spanning {:,} genomes.'.
            format(len(new_updated_sp_clusters),
                   new_updated_sp_clusters.total_num_genomes()))

        # initialize species priority manager
        self.sp_priority_mngr = SpeciesPriorityManager(sp_priority_ledger)

        # take required action for each changed representative
        self.action_genomic_lost(rep_change_summary_file, prev_genomes,
                                 cur_genomes, new_updated_sp_clusters)

        self.action_genomic_update(rep_change_summary_file, prev_genomes,
                                   cur_genomes, new_updated_sp_clusters)

        self.action_type_strain_lost(rep_change_summary_file, prev_genomes,
                                     cur_genomes, new_updated_sp_clusters)

        self.action_domain_change(rep_change_summary_file, prev_genomes,
                                  cur_genomes)

        # identifying improved representatives is the expensive step, so its
        # result is cached on disk and can optionally be reloaded
        improved_reps_file = os.path.join(self.output_dir, 'improved_reps.pkl')
        if load_cached_improved_reps:
            self.logger.warning(
                'Reading improved_reps for pre-cached file. Generally used only for debugging.'
            )
            # NOTE: pickle.load can execute arbitrary code; only load a cache
            # file written by this pipeline.
            with open(improved_reps_file, 'rb') as f:
                improved_reps = pickle.load(f)
        else:
            improved_reps = self.action_improved_rep(prev_genomes, cur_genomes,
                                                     new_updated_sp_clusters)
            with open(improved_reps_file, 'wb') as f:
                pickle.dump(improved_reps, f)

        for prev_rid, (new_rid, action) in improved_reps.items():
            self.update_rep(prev_rid, new_rid, action)

        self.action_naming_priority(prev_genomes, cur_genomes,
                                    new_updated_sp_clusters)

        # report basic statistics; a new representative of None marks a
        # retired species
        num_retired_sp = sum(1 for v in self.new_reps.values() if v[0] is None)
        num_replaced_rids = sum(
            1 for v in self.new_reps.values() if v[0] is not None)
        self.logger.info(f'Identified {num_retired_sp:,} retired species.')
        self.logger.info(
            f'Identified {num_replaced_rids:,} species with a modified representative genome.'
        )

        self.action_log.close()

        # write out representatives for existing species clusters
        with open(os.path.join(self.output_dir, 'updated_species_reps.tsv'),
                  'w') as fout:
            fout.write(
                'Previous representative ID\tNew representative ID\tAction\tRepresentative status\n'
            )
            for rid in prev_genomes.sp_clusters:
                if rid not in self.new_reps:
                    fout.write(f'{rid}\t{rid}\tNONE\tUNCHANGED\n')
                    continue

                new_rid, action = self.new_reps[rid]
                status = 'LOST' if new_rid is None else 'REPLACED'
                fout.write(f'{rid}\t{new_rid}\t{action}\t{status}\n')

        # write out updated species clusters
        out_file = os.path.join(self.output_dir, 'updated_sp_clusters.tsv')
        self.write_updated_clusters(prev_genomes, cur_genomes, self.new_reps,
                                    new_updated_sp_clusters, out_file)