Beispiel #1
0
    def nonrep_radius(self, unclustered_gids, rep_gids, ani_af_rep_vs_nonrep):
        """Calculate circumscription radius for unclustered, nontype genomes.

        Each genome in `unclustered_gids` starts with the default radius
        `self.ani_sp`. The radius is raised to the ANI of a representative in
        `rep_gids` whenever that ANI exceeds the current radius and the
        alignment fraction (AF) is at least `self.af_sp`.

        `unclustered_gids` is iterated twice, so it must be a re-iterable
        collection. `ani_af_rep_vs_nonrep` is the path to a pickled ANI/AF
        table indexed as ani_af[nonrep_gid][rep_gid].

        Returns a dict mapping each genome ID to its GenomeRadius; the dict is
        empty when there are no unclustered genomes.
        """

        # set radius for genomes to default values
        nonrep_radius = {}
        for gid in unclustered_gids:
            nonrep_radius[gid] = GenomeRadius(ani=self.ani_sp,
                                              af=None,
                                              neighbour_gid=None)

        # NOTE: pickle can execute arbitrary code on load; this file must come
        # from a trusted source. The context manager guarantees the handle is
        # closed even if unpickling fails.
        with open(ani_af_rep_vs_nonrep, 'rb') as f:
            ani_af = pickle.load(f)

        # determine closest representative ANI neighbour and restrict ANI
        # radius as necessary
        for nonrep_gid in unclustered_gids:
            if nonrep_gid not in ani_af:
                continue

            for rep_gid in rep_gids:
                if rep_gid not in ani_af[nonrep_gid]:
                    continue

                ani, af = FastANI.symmetric_ani(ani_af, nonrep_gid, rep_gid)

                if ani > nonrep_radius[nonrep_gid].ani and af >= self.af_sp:
                    nonrep_radius[nonrep_gid] = GenomeRadius(ani=ani,
                                                             af=af,
                                                             neighbour_gid=rep_gid)

        if nonrep_radius:
            radii = [d.ani for d in nonrep_radius.values()]
            self.logger.info('ANI circumscription radius: min={:.2f}, mean={:.2f}, max={:.2f}'.format(
                min(radii),
                np_mean(radii),
                max(radii)))
        else:
            # min()/max() raise ValueError on an empty sequence, so there is
            # nothing to summarise when no genomes were provided
            self.logger.info(
                'No unclustered genomes; no ANI circumscription radius to report.')

        return nonrep_radius
Beispiel #2
0
    def _cluster(self, ani_af, non_reps, rep_radius):
        """Assign non-representative genomes to their closest representative.

        For each genome in `non_reps`, the representative with the highest ANI
        (ties broken by higher AF) among those with AF >= `self.af_sp` is
        selected. The genome joins that representative's cluster only if the
        ANI exceeds the representative's own ANI radius in `rep_radius`.

        Returns a dict mapping every representative ID to a list of
        `self.ClusteredGenome` records (possibly empty).
        """

        progress_template = '==> Processed {:,} of {:,} genomes [no. clustered = {:,}].\r'
        total_non_reps = len(non_reps)

        # every representative gets a cluster, even if nothing joins it
        clusters = {rep_id: [] for rep_id in rep_radius}

        num_clustered = 0
        for idx, query_gid in enumerate(non_reps):
            # periodic progress report
            if idx % 100 == 0:
                sys.stdout.write(progress_template.format(idx+1,
                                                          total_non_reps,
                                                          num_clustered))
                sys.stdout.flush()

            if query_gid not in ani_af:
                continue

            # find the best representative satisfying the AF criterion
            best_rid = None
            best_ani = 0
            best_af = 0
            for rid in rep_radius:
                if rid not in ani_af[query_gid]:
                    continue

                ani, af = FastANI.symmetric_ani(ani_af, rid, query_gid)

                is_closer = ani > best_ani or (ani == best_ani and af > best_af)
                if af >= self.af_sp and is_closer:
                    best_rid, best_ani, best_af = rid, ani, af

            # cluster only when inside the representative's ANI radius
            if best_rid and best_ani > rep_radius[best_rid].ani:
                num_clustered += 1
                member = self.ClusteredGenome(gid=query_gid,
                                              ani=best_ani,
                                              af=best_af)
                clusters[best_rid].append(member)

        # final progress line
        sys.stdout.write(progress_template.format(total_non_reps,
                                                  total_non_reps,
                                                  num_clustered))
        sys.stdout.flush()
        sys.stdout.write('\n')

        num_unclustered = total_non_reps - num_clustered
        num_assigned = sum(len(members) for members in clusters.values())
        self.logger.info('Assigned {:,} genomes to {:,} representatives; {:,} genomes remain unclustered.'.format(
            num_assigned,
            len(clusters),
            num_unclustered))

        return clusters
    def run(self, target_genus, gtdb_metadata_file, genomic_path_file):
        """Calculate pairwise ANI/AF between GTDB representatives of a genus.

        Loads the GTDB genome set from `gtdb_metadata_file` and
        `genomic_path_file`, selects the species representatives whose GTDB
        genus equals `target_genus`, computes FastANI between all of them, and
        writes the symmetric ANI/AF of every ordered pair (including each
        genome against itself) to `<genus>_rep_ani.tsv` in `self.output_dir`,
        where `<genus>` is `target_genus` lower-cased with any 'g__' removed.
        """

        # create GTDB genome sets
        self.logger.info('Creating GTDB genome set.')
        genomes = Genomes()
        genomes.load_from_metadata_file(gtdb_metadata_file)
        genomes.load_genomic_file_paths(genomic_path_file)
        self.logger.info(
            ' - genome set has {:,} species clusters spanning {:,} genomes.'.
            format(len(genomes.sp_clusters),
                   genomes.sp_clusters.total_num_genomes()))

        # identify GTDB representatives from target genus
        self.logger.info('Identifying GTDB representatives from target genus.')
        target_gids = set()
        for gid in genomes:
            if genomes[gid].is_gtdb_sp_rep(
            ) and genomes[gid].gtdb_taxa.genus == target_genus:
                target_gids.add(gid)
        self.logger.info(' - identified {:,} genomes.'.format(
            len(target_gids)))

        # calculate FastANI ANI/AF between target genomes
        self.logger.info('Calculating pairwise ANI between target genomes.')
        ani_af = self.fastani.pairwise(target_gids,
                                       genomes.genomic_files,
                                       check_cache=True)
        self.fastani.write_cache(silence=True)

        # write out results; the context manager ensures the table is flushed
        # and closed even if a lookup fails part way through
        genus_label = target_genus.replace('g__', '').lower()
        out_file = os.path.join(self.output_dir,
                                '{}_rep_ani.tsv'.format(genus_label))
        with open(out_file, 'w') as fout:
            fout.write(
                'Query ID\tQuery species\tTarget ID\tTarget species\tANI\tAF\n')
            for qid in target_gids:
                for rid in target_gids:
                    ani, af = FastANI.symmetric_ani(ani_af, qid, rid)

                    fout.write('{}\t{}\t{}\t{}\t{:.3f}\t{:.3f}\n'.format(
                        qid, genomes[qid].gtdb_taxa.species, rid,
                        genomes[rid].gtdb_taxa.species, ani, af))
    def rep_genome_stats(self, clusters, genome_files):
        """Summarise ANI between each representative and its clustered genomes.

        For every representative in `clusters` the ANI between it and each
        clustered genome is calculated with FastANI, and the minimum, mean,
        standard deviation and median are stored in a `self.RepStats` record.
        Representatives without clustered genomes get -1 for every statistic.

        Returns a dict mapping representative ID to its `self.RepStats`.
        """

        self.logger.info('Calculating statistics to cluster representatives:')

        num_clusters = len(clusters)
        status_template = '-> Processing %d of %d (%.2f%%) clusters.'.ljust(86)

        stats = {}
        for idx, (rid, cids) in enumerate(clusters.items()):
            if len(cids) > 0:
                # request ANI in both directions between each member and the rep
                gid_pairs = []
                for cid in cids:
                    gid_pairs.extend([(cid, rid), (rid, cid)])

                # ANI values are always recalculated here rather than being
                # taken from self.fastani.ani_cache
                ani_af = self.fastani.pairs(gid_pairs,
                                            genome_files,
                                            report_progress=False)

                # collect ANI of each member to the representative
                anis = []
                for cid in cids:
                    ani_to_rep = FastANI.symmetric_ani(ani_af, cid, rid)[0]
                    anis.append(ani_to_rep)

                stats[rid] = self.RepStats(min_ani=min(anis),
                                           mean_ani=np_mean(anis),
                                           std_ani=np_std(anis),
                                           median_ani=np_median(anis))
            else:
                # singleton cluster: no statistics can be calculated
                stats[rid] = self.RepStats(min_ani=-1,
                                           mean_ani=-1,
                                           std_ani=-1,
                                           median_ani=-1)

            processed = idx+1
            status = status_template % (processed,
                                        num_clusters,
                                        float(processed*100)/num_clusters)
            sys.stdout.write(status + '\r')
            sys.stdout.flush()

        sys.stdout.write('\n')

        return stats
    def find_multiple_reps(self, clusters, cluster_radius):
        """Find representatives whose ANI radius contains each clustered genome.

        The ANI cache (`self.fastani.ani_cache`) is assumed to already hold
        every relevant ANI calculation between representative and
        non-representative genomes, which is the case once the de novo
        clustering has been performed.

        Returns a defaultdict mapping a clustered genome ID to the set of
        (representative ID, ANI) tuples for which AF >= `self.af_sp` and the
        ANI is at least that representative's radius in `cluster_radius`.
        """

        self.logger.info(
            'Determine number of non-rep genomes within ANI radius of multiple rep genomes.')

        # flatten all clusters into a single list of clustered genomes
        clustered_gids = []
        for members in clusters.values():
            clustered_gids.extend(members)

        num_clustered = len(clustered_gids)
        self.logger.info('Considering {:,} representatives and {:,} non-representative genomes.'.format(
            len(clusters),
            num_clustered))

        status_template = '-> Processing %d of %d (%.2f%%) clusters genomes.'.ljust(86)
        ani_cache = self.fastani.ani_cache if clustered_gids else None

        nonrep_rep_count = defaultdict(set)
        for idx, gid in enumerate(clustered_gids):
            hits_for_gid = ani_cache[gid]
            for rid in clusters:
                if rid in hits_for_gid:
                    ani, af = FastANI.symmetric_ani(ani_cache, gid, rid)
                    if af >= self.af_sp and ani >= cluster_radius[rid].ani:
                        nonrep_rep_count[gid].add((rid, ani))

            # report progress every 100 genomes and on the final genome
            processed = idx+1
            if processed % 100 == 0 or processed == num_clustered:
                status = status_template % (processed,
                                            num_clustered,
                                            float(processed*100)/num_clustered)
                sys.stdout.write(status + '\r')
                sys.stdout.flush()

        sys.stdout.write('\n')

        return nonrep_rep_count
    def top_hits(self, species, rid, ani_af, genomes):
        """Report top 5 hits to species.

        Logs the five genomes in `ani_af[rid]` with the highest symmetric
        (ANI, AF) to the representative `rid`, ordered from closest to most
        distant, along with their GTDB species from `genomes`. Fewer lines are
        logged if fewer than five comparisons are available.
        """

        num_hits = 5

        results = {}
        for qid in ani_af[rid]:
            ani, af = FastANI.symmetric_ani(ani_af, rid, qid)
            results[qid] = (ani, af)

        self.logger.info(f'Closest 5 species to {species} ({rid}):')

        # sort by ANI, then AF, in descending order; slicing caps the report
        # at exactly `num_hits` entries (previously a sixth hit was logged
        # before the loop terminated)
        ranked_hits = sorted(results.items(),
                             key=lambda x: x[1],
                             reverse=True)
        for qid, (ani, af) in ranked_hits[:num_hits]:
            q_species = genomes[qid].gtdb_species
            self.logger.info(
                f'{q_species} ({qid}): ANI={ani:.1f}%, AF={af:.2f}')
    def merge_ani_radius(self, species, rid, merged_sp_cluster, genomic_files):
        """Determine ANI radius if species were merged.

        Calculates FastANI in both directions between the representative `rid`
        and every genome in `merged_sp_cluster`, then logs the smallest
        symmetric ANI found (the radius needed to cover the merged cluster)
        together with the AF of that comparison. If no comparison has an ANI
        below 100 (including an empty cluster) the radius is reported as 100
        and the AF as 'n/a'.
        """

        self.logger.info(
            f'Calculating ANI from {species} to all genomes in merged species cluster.'
        )

        gid_pairs = []
        for gid in merged_sp_cluster:
            gid_pairs.append((rid, gid))
            gid_pairs.append((gid, rid))
        merged_ani_af1 = self.fastani.pairs(gid_pairs, genomic_files)

        ani_radius = 100
        # initialised so the report below cannot hit an unbound variable when
        # no comparison lowers the radius
        af_radius = None
        for gid in merged_sp_cluster:
            ani, af = FastANI.symmetric_ani(merged_ani_af1, rid, gid)
            if ani < ani_radius:
                ani_radius = ani
                af_radius = af

        af_str = 'n/a' if af_radius is None else f'{af_radius:.2f}'
        self.logger.info(
            f'Merged cluster with {species} rep: ANI radius={ani_radius:.1f}%, AF={af_str}'
        )
Beispiel #8
0
    def calculate_type_strain_ani(self, ncbi_sp, type_gids, cur_genomes,
                                  use_pickled_results):
        """Calculate pairwise ANI between type strain genomes.

        Results are stored in `self.ani_pickle_dir` in a file named after
        `ncbi_sp` (first three characters dropped, lower-cased, spaces replaced
        by underscores). When `use_pickled_results` is true the ANI/AF table is
        read from that file instead of being recalculated with FastANI.

        Returns a tuple (all_similar, anis, afs, gid_anis, gid_afs):
          all_similar - True unless some pair has ANI < 99 or AF < 0.65
          anis, afs   - symmetric ANI and AF of every unordered pair
          gid_anis, gid_afs - nested dicts giving ANI / AF for a pair of
                              genome IDs in either order
        """

        ncbi_sp_str = ncbi_sp[3:].lower().replace(' ', '_')
        pickle_file = os.path.join(self.ani_pickle_dir, f'{ncbi_sp_str}.pkl')

        # context managers ensure the pickle file is closed (and fully flushed
        # when writing) rather than relying on garbage collection
        if not use_pickled_results:  # ***
            ani_af = self.fastani.pairwise(type_gids,
                                           cur_genomes.genomic_files)
            with open(pickle_file, 'wb') as f:
                pickle.dump(ani_af, f)
        else:
            # NOTE: pickle can execute arbitrary code on load; only files
            # written by this method should be placed in the pickle directory
            with open(pickle_file, 'rb') as f:
                ani_af = pickle.load(f)

        anis = []
        afs = []
        gid_anis = defaultdict(dict)
        gid_afs = defaultdict(dict)
        all_similar = True
        for gid1, gid2 in combinations(type_gids, 2):
            ani, af = FastANI.symmetric_ani(ani_af, gid1, gid2)
            if ani < 99 or af < 0.65:
                all_similar = False

            anis.append(ani)
            afs.append(af)

            # record values symmetrically so lookups work in either order
            gid_anis[gid1][gid2] = ani
            gid_anis[gid2][gid1] = ani

            gid_afs[gid1][gid2] = af
            gid_afs[gid2][gid1] = af

        return all_similar, anis, afs, gid_anis, gid_afs
    def run(self, gtdb_metadata_file, genome_path_file, species1, species2):
        """Report ANI/AF information relevant to merging two sister species.

        Locates the representative genome of `species1` and `species2`, checks
        that both belong to the same GTDB genus, calculates ANI from each to
        all species representatives in that genus, and logs the ANI/AF between
        the two species, their closest hits, and the ANI radius each
        representative would need to cover the merged species cluster. Exits
        with status -1 if a representative cannot be found or the genera
        differ.
        """

        # load genome metadata and species clusters
        self.logger.info('Reading GTDB species clusters.')
        genomes = Genomes()
        genomes.load_from_metadata_file(gtdb_metadata_file)
        genomes.load_genomic_file_paths(genome_path_file)
        num_sp_clusters = len(genomes.sp_clusters)
        num_genomes = genomes.sp_clusters.total_num_genomes()
        self.logger.info(
            ' - identified {:,} species clusters spanning {:,} genomes.'.
            format(num_sp_clusters, num_genomes))

        # locate the representative genome of each species of interest
        gid1 = None
        gid2 = None
        for cur_gid, cur_species in genomes.sp_clusters.species():
            if cur_species == species1:
                gid1 = cur_gid
            elif cur_species == species2:
                gid2 = cur_gid

        species_reps = ((gid1, species1), (gid2, species2))
        for rep_gid, sp_name in species_reps:
            if rep_gid is None:
                self.logger.error(
                    f'Unable to find representative genome for {sp_name}.')
                sys.exit(-1)

        for rep_gid, sp_name in species_reps:
            self.logger.info(' - identified {:,} genomes in {}.'.format(
                len(genomes.sp_clusters[rep_gid]), sp_name))

        # both species must be from the same genus
        genus1 = genomes[gid1].gtdb_genus
        genus2 = genomes[gid2].gtdb_genus
        if genus1 != genus2:
            self.logger.error(
                f'Genomes must be from same genus: {genus1} {genus2}')
            sys.exit(-1)

        # gather all species representatives in the genus
        self.logger.info(f'Identifying {genus1} species representatives.')
        reps_in_genera = {rid for rid in genomes.sp_clusters
                          if genomes[rid].gtdb_genus == genus1}

        self.logger.info(
            f' - identified {len(reps_in_genera):,} representatives.')

        # calculate ANI, in both directions, from each species of interest
        # to every other representative in the genus
        ani_af_tables = []
        for rep_gid, sp_name in species_reps:
            self.logger.info(f'Calculating ANI to {sp_name}.')
            gid_pairs = []
            for other_gid in reps_in_genera:
                if other_gid != rep_gid:
                    gid_pairs.extend([(rep_gid, other_gid),
                                      (other_gid, rep_gid)])
            ani_af_tables.append(
                self.fastani.pairs(gid_pairs, genomes.genomic_files))

        ani_af1, ani_af2 = ani_af_tables

        # report directional and symmetric results between the two species
        ani12, af12 = ani_af1[gid1][gid2]
        ani21, af21 = ani_af2[gid2][gid1]
        ani, af = FastANI.symmetric_ani(ani_af1, gid1, gid2)

        self.logger.info(
            f'{species1} ({gid1}) -> {species2} ({gid2}): ANI={ani12:.1f}%, AF={af12:.2f}'
        )
        self.logger.info(
            f'{species2} ({gid2}) -> {species1} ({gid1}): ANI={ani21:.1f}%, AF={af21:.2f}'
        )
        self.logger.info(f'Max. ANI={ani:.1f}%, Max. AF={af:.2f}')

        # report top hits for each species
        self.top_hits(species1, gid1, ani_af1, genomes)
        self.top_hits(species2, gid2, ani_af2, genomes)

        # calculate ANI from each representative to all genomes in the
        # merged species cluster
        merged_sp_cluster = genomes.sp_clusters[gid1].union(
            genomes.sp_clusters[gid2])
        for rep_gid, sp_name in species_reps:
            self.merge_ani_radius(sp_name, rep_gid, merged_sp_cluster,
                                  genomes.genomic_files)
    def pairwise_stats(self, clusters, genome_files):
        """Calculate statistics for all pairwise comparisons in a species cluster.

        For each representative in `clusters`, FastANI is run between all pairs
        of genomes in the cluster (representative included). Clusters with more
        than `self.max_genomes_for_stats` members are randomly subsampled to
        that many members first. The medoid is the genome with the smallest
        summed (100 - ANI) distance to the others; with only two genomes the
        representative is used as the medoid.

        Returns a dict mapping representative ID to a `self.PairwiseStats`
        record; every field is -1 for representatives with no clustered
        genomes. Exits with status -1 if the mean ANI to the medoid is lower
        than the mean ANI to the representative.
        """

        self.logger.info(
            f'Restricting pairwise comparisons to {self.max_genomes_for_stats:,} randomly selected genomes.')
        self.logger.info(
            'Calculating statistics for all pairwise comparisons in a species cluster:')

        stats = {}
        for idx, (rid, cids) in enumerate(clusters.items()):
            # format first, then pad, so the line has a fixed minimum width and
            # fully overwrites the previous status; '{:.2f}' gives a precision
            # of 2 ('{:2f}' is a width of 2 and prints six decimals)
            statusStr = '-> Processing {:,} of {:,} ({:.2f}%) clusters (size = {:,}).'.format(
                idx+1,
                len(clusters),
                float((idx+1)*100)/len(clusters),
                len(cids)).ljust(86)
            sys.stdout.write('{}\r'.format(statusStr))
            sys.stdout.flush()

            if len(cids) == 0:
                stats[rid] = self.PairwiseStats(min_ani=-1,
                                                mean_ani=-1,
                                                std_ani=-1,
                                                median_ani=-1,
                                                ani_to_medoid=-1,
                                                mean_ani_to_medoid=-1,
                                                mean_ani_to_rep=-1,
                                                ani_below_95=-1)
                continue

            if len(cids) > self.max_genomes_for_stats:
                # random.sample() requires a sequence; passing a set raises
                # TypeError on Python >= 3.11
                cids = set(random.sample(tuple(cids),
                                         self.max_genomes_for_stats))

            # calculate ANI between all genomes in the cluster, in both
            # directions
            gids = list(set(cids).union([rid]))
            gid_pairs = []
            for gid1, gid2 in combinations(gids, 2):
                gid_pairs.append((gid1, gid2))
                gid_pairs.append((gid2, gid1))

            # ANI values are always recalculated here rather than being taken
            # from self.fastani.ani_cache
            ani_af = self.fastani.pairs(gid_pairs,
                                        genome_files,
                                        report_progress=False)

            # calculate medoid point
            if len(gids) > 2:
                dist_mat = np_zeros((len(gids), len(gids)))
                for i, gid1 in enumerate(gids):
                    for j, gid2 in enumerate(gids):
                        if i < j:
                            ani, _af = FastANI.symmetric_ani(
                                ani_af, gid1, gid2)
                            dist_mat[i, j] = 100 - ani
                            dist_mat[j, i] = 100 - ani

                medoid_idx = np_argmin(dist_mat.sum(axis=0))
                medoid_gid = gids[medoid_idx]
            else:
                # with only 2 genomes in a cluster, the representative is the
                # natural medoid at least for reporting statistics for the
                # individual species cluster
                medoid_gid = rid

            mean_ani_to_medoid = np_mean([FastANI.symmetric_ani(ani_af, gid, medoid_gid)[0]
                                          for gid in gids if gid != medoid_gid])

            mean_ani_to_rep = np_mean([FastANI.symmetric_ani(ani_af, gid, rid)[0]
                                       for gid in gids if gid != rid])

            # sanity check: the medoid should be at least as close to the
            # other genomes as the representative is
            if mean_ani_to_medoid < mean_ani_to_rep:
                self.logger.error('mean_ani_to_medoid < mean_ani_to_rep')
                sys.exit(-1)

            # calculate statistics over all unordered pairs
            anis = []
            for gid1, gid2 in combinations(gids, 2):
                ani, _af = FastANI.symmetric_ani(ani_af, gid1, gid2)
                anis.append(ani)

            stats[rid] = self.PairwiseStats(
                min_ani=min(anis),
                mean_ani=np_mean(anis),
                std_ani=np_std(anis),
                median_ani=np_median(anis),
                ani_to_medoid=FastANI.symmetric_ani(
                    ani_af, rid, medoid_gid)[0],
                mean_ani_to_medoid=mean_ani_to_medoid,
                mean_ani_to_rep=mean_ani_to_rep,
                ani_below_95=sum(1 for ani in anis if ani < 95))

        sys.stdout.write('\n')

        return stats
    def intragenus_pairwise_ani(self, clusters, species, genome_files, gtdb_taxonomy):
        """Determine pairwise intra-genus ANI between representative genomes.

        FastANI is run between every ordered pair of distinct representatives
        in `clusters` that share a genus, where the genus is taken from the
        first word of the species name in `species`. Three files are written
        to `self.output_dir`:
          type_genomes_ani_af.pkl      - pickled ANI/AF results
          intra_genus_pairwise_ani.tsv - ANI/AF of every intra-genus pair
          closest_intragenus_rep.tsv   - closest intra-genus neighbour of each
                                         representative having one
        """

        self.logger.info(
            'Calculating pairwise intra-genus ANI values between GTDB representatives.')

        # get genus for each representative
        genus = {}
        for rid, sp in species.items():
            genus[rid] = sp.split()[0].replace('s__', '')
            assert genus[rid] == gtdb_taxonomy[rid][5].replace('g__', '')

        # get all ordered pairs of distinct representatives from the same
        # genus; the nested loops visit both (A, B) and (B, A), so each
        # direction is queued exactly once
        self.logger.info('Determining intra-genus genome pairs.')
        ani_pairs = []
        for qid in clusters:
            for rid in clusters:
                if qid == rid:
                    continue

                if genus[qid] != genus[rid]:
                    continue

                ani_pairs.append((qid, rid))

        self.logger.info(
            'Identified {:,} intra-genus genome pairs.'.format(len(ani_pairs)))

        # calculate ANI between pairs and save the results so they can be
        # reloaded with pickle.load() without rerunning FastANI
        self.logger.info(
            'Calculating ANI between {:,} genome pairs:'.format(len(ani_pairs)))
        ani_af = self.fastani.pairs(ani_pairs, genome_files)
        pickle_file = os.path.join(self.output_dir, 'type_genomes_ani_af.pkl')
        with open(pickle_file, 'wb') as f:
            pickle.dump(ani_af, f)

        # find closest intra-genus pair for each rep
        closest_intragenus_rep = {}
        pairwise_file = os.path.join(self.output_dir,
                                     'intra_genus_pairwise_ani.tsv')
        with open(pairwise_file, 'w') as fout:
            fout.write(
                'Genus\tSpecies 1\tGenome ID 1\tSpecies 2\tGenome ID2\tANI\tAF\n')
            for qid in clusters:
                genusA = genus[qid]

                closest_ani = 0
                closest_af = 0
                closest_gid = None
                for rid in clusters:
                    if qid == rid:
                        continue

                    genusB = genus[rid]
                    if genusA != genusB:
                        continue

                    # pairs without a FastANI result are reported as 'n/a'
                    ani, af = ('n/a', 'n/a')
                    if qid in ani_af and rid in ani_af[qid]:
                        ani, af = FastANI.symmetric_ani(ani_af, qid, rid)

                        if ani > closest_ani:
                            closest_ani = ani
                            closest_af = af
                            closest_gid = rid

                    fout.write('{}\t{}\t{}\t{}\t{}\t{}\t{}\n'.format(
                        genusA,
                        species[qid],
                        qid,
                        species[rid],
                        rid,
                        ani,
                        af))

                if closest_gid:
                    closest_intragenus_rep[qid] = (
                        closest_gid, closest_ani, closest_af)

        # write out closest intra-genus species to each representative
        closest_file = os.path.join(self.output_dir,
                                    'closest_intragenus_rep.tsv')
        with open(closest_file, 'w') as fout:
            fout.write(
                'Genome ID\tSpecies\tIntra-genus neighbour\tIntra-genus species\tANI\tAF\n')
            for qid, (rid, ani, af) in closest_intragenus_rep.items():
                fout.write('%s\t%s\t%s\t%s\t%.2f\t%.3f\n' % (
                    qid,
                    species[qid],
                    rid,
                    species[rid],
                    ani,
                    af))
    def dereplicate_species(self, species, rid, cids, genomes, mash_out_dir):
        """Dereplicate genomes within a GTDB species.

        Genomes are ordered with the species representative `rid` first,
        followed by the remaining genomes in `cids` as ordered by
        `self.order_genomes_by_priority`. If there are more than
        `self.max_genomes_per_sp` genomes, the set is first reduced using Mash
        dereplication at progressively lower ANI thresholds. FastANI is then
        run on pairs with Mash ANI >= `self.min_mash_intra_sp_ani`, genomes
        are greedily dereplicated using `self.derep_ani` / `self.derep_af`,
        and each remaining genome is assigned to its most similar
        representative.

        Returns a dict mapping each selected representative ID to a list of
        genome IDs in its cluster; the representative is the first entry.
        """

        # greedily dereplicate genomes based on genome priority
        sorted_gids = self.order_genomes_by_priority(cids.difference([rid]),
                                                     genomes)
        sorted_gids = [rid] + sorted_gids

        # calculate Mash ANI between genomes
        mash_ani = []
        if len(sorted_gids) > 1:
            # calculate MASH distances between genomes
            out_prefix = os.path.join(mash_out_dir,
                                      species[3:].lower().replace(' ', '_'))
            mash_ani = self.mash_sp_ani(sorted_gids, genomes, out_prefix)

        # perform initial dereplication using Mash for species with excessive
        # numbers of genomes
        if len(sorted_gids) > self.max_genomes_per_sp:
            self.logger.info(
                ' - limiting species to <={:,} genomes based on priority and Mash dereplication.'
                .format(self.max_genomes_per_sp))

            prev_mash_rep_gids = None
            for ani_threshold in [
                    99.75, 99.5, 99.25, 99.0, 98.75, 98.5, 98.25, 98.0, 97.75,
                    97.5, 97.0, 96.5, 96.0, 95.0, None
            ]:
                if ani_threshold is None:
                    # no threshold reduced the set sufficiently, so truncate
                    # the result of the final (lowest) threshold; this message
                    # must use str.format() since the template uses '{:,}'
                    self.logger.warning(
                        ' - selected {:,} highest priority genomes from final Mash dereplication.'
                        .format(self.max_genomes_per_sp))
                    sorted_gids = mash_rep_gids[0:self.max_genomes_per_sp]
                    break

                mash_rep_gids = self.mash_sp_dereplicate(
                    mash_ani, sorted_gids, ani_threshold)

                self.logger.info(
                    ' - dereplicated {} from {:,} to {:,} genomes at {:.2f}% ANI using Mash.'
                    .format(species, len(cids), len(mash_rep_gids),
                            ani_threshold))

                if len(mash_rep_gids) <= self.max_genomes_per_sp:
                    if not prev_mash_rep_gids:
                        # corner case where dereplication is occurring at 99.75%
                        prev_mash_rep_gids = sorted_gids

                    # select maximum allowed number of genomes by taking all genomes in the
                    # current Mash dereplicated set and then the highest priority genomes in the
                    # previous Mash dereplicated set which have not been selected
                    cur_sel_gids = set(mash_rep_gids)
                    prev_sel_gids = set(prev_mash_rep_gids)
                    num_prev_to_sel = self.max_genomes_per_sp - len(
                        cur_sel_gids)
                    num_prev_selected = 0
                    sel_sorted_gids = []
                    for gid in sorted_gids:
                        if gid in cur_sel_gids:
                            sel_sorted_gids.append(gid)
                        elif (gid in prev_sel_gids
                              and num_prev_selected < num_prev_to_sel):
                            num_prev_selected += 1
                            sel_sorted_gids.append(gid)

                        if len(sel_sorted_gids) == self.max_genomes_per_sp:
                            break

                    assert len(cur_sel_gids - set(sel_sorted_gids)) == 0
                    assert num_prev_to_sel == num_prev_selected
                    assert len(sel_sorted_gids) == self.max_genomes_per_sp

                    sorted_gids = sel_sorted_gids
                    self.logger.info(
                        ' - selected {:,} highest priority genomes from Mash dereplication at an ANI = {:.2f}%.'
                        .format(len(sorted_gids), ani_threshold))
                    break

                prev_mash_rep_gids = mash_rep_gids

        # calculate FastANI ANI/AF between genomes passing Mash filtering
        ani_pairs = set()
        for gid1, gid2 in permutations(sorted_gids, 2):
            if gid1 in mash_ani and gid2 in mash_ani[gid1]:
                if mash_ani[gid1][gid2] >= self.min_mash_intra_sp_ani:
                    ani_pairs.add((gid1, gid2))
                    ani_pairs.add((gid2, gid1))

        self.logger.info(
            ' - calculating FastANI between {:,} pairs with Mash ANI >= {:.1f}%.'
            .format(len(ani_pairs), self.min_mash_intra_sp_ani))
        ani_af = self.fastani.pairs(ani_pairs,
                                    genomes.genomic_files,
                                    report_progress=False,
                                    check_cache=True)
        self.fastani.write_cache(silence=True)

        # perform greedy dereplication: genomes are visited in priority order
        # and become a representative unless they cluster with an existing one
        sp_reps = []
        for gid in sorted_gids:
            # determine if genome clusters with existing representative
            clustered = False
            for rep_gid in sp_reps:
                ani, af = FastANI.symmetric_ani(ani_af, gid, rep_gid)

                if ani >= self.derep_ani and af >= self.derep_af:
                    clustered = True
                    break

            if not clustered:
                sp_reps.append(gid)

        self.logger.info(
            ' - dereplicated {} from {:,} to {:,} genomes.'.format(
                species, len(sorted_gids), len(sp_reps)))

        # assign clustered genomes to most similar representative
        subsp_clusters = {}
        for rep_gid in sp_reps:
            subsp_clusters[rep_gid] = [rep_gid]

        non_rep_gids = set(sorted_gids) - set(sp_reps)
        for gid in non_rep_gids:
            closest_rid = None
            max_ani = 0
            max_af = 0
            for rep_gid in sp_reps:
                ani, af = FastANI.symmetric_ani(ani_af, gid, rep_gid)
                # highest ANI wins, ties are resolved by AF; the AF criterion
                # must always be satisfied
                if ((ani > max_ani and af >= self.derep_af) or
                    (ani == max_ani and af >= max_af and af >= self.derep_af)):
                    max_ani = ani
                    max_af = af
                    closest_rid = rep_gid

            assert closest_rid is not None
            subsp_clusters[closest_rid].append(gid)

        return subsp_clusters
    def write_synonym_table(self, type_strain_synonyms, consensus_synonyms,
                            ani_af, sp_priority_ledger, genus_priority_ledger,
                            lpsn_gss_file):
        """Create table indicating species names that should be considered synonyms.

        Writes `synonyms.tsv` into `self.output_dir` with one row per
        (representative, synonym) genome pair.

        Parameters
        ----------
        type_strain_synonyms : dict
            Representative genome ID -> iterable of synonym genome IDs;
            rows are labelled 'TYPE_STRAIN_SYNONYM'.
        consensus_synonyms : dict
            Representative genome ID -> iterable of synonym genome IDs;
            rows are labelled 'MAJORITY_VOTE_SYNONYM'.
        ani_af
            ANI/AF results handed to FastANI.symmetric_ani().
        sp_priority_ledger, genus_priority_ledger, lpsn_gss_file
            Passed through unchanged to SpeciesPriorityManager.
        """

        sp_priority_mngr = SpeciesPriorityManager(sp_priority_ledger,
                                                  genus_priority_ledger,
                                                  lpsn_gss_file,
                                                  self.output_dir)

        header = (
            'Synonym type\tNCBI species\tGTDB representative\tStrain IDs\tType sources\tPriority year'
            '\tGTDB type species\tGTDB type strain\tNCBI assembly type'
            '\tNCBI synonym\tHighest-quality synonym genome\tSynonym strain IDs\tSynonym type sources\tSynonym priority year'
            '\tSynonym GTDB type species\tSynonym GTDB type strain\tSynonym NCBI assembly type'
            '\tANI\tAF\tWarnings\n')

        def genome_columns(genome_id, priority_year):
            """Format the eight tab-prefixed columns describing one genome."""
            genome = self.cur_genomes[genome_id]
            type_sources = ','.join(sorted(genome.gtdb_type_sources()))
            return '\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}'.format(
                genome.ncbi_taxa.species,
                genome_id,
                ','.join(sorted(genome.strain_ids())),
                type_sources.upper().replace('STRAININFO', 'StrainInfo'),
                priority_year,
                genome.is_gtdb_type_species(),
                genome.is_gtdb_type_strain(),
                genome.ncbi_type_material)

        incorrect_priority = 0
        failed_type_strain_priority = 0

        labelled_synonyms = (
            (type_strain_synonyms, 'TYPE_STRAIN_SYNONYM'),
            (consensus_synonyms, 'MAJORITY_VOTE_SYNONYM'),
        )

        out_file = os.path.join(self.output_dir, 'synonyms.tsv')

        # the context manager guarantees the table is flushed and the handle
        # released, including when an exception interrupts the loop
        with open(out_file, 'w') as fout:
            fout.write(header)

            for synonyms, synonym_type in labelled_synonyms:
                for rid, synonym_ids in synonyms.items():
                    for gid in synonym_ids:
                        ani, af = FastANI.symmetric_ani(ani_af, rid, gid)

                        # the representative's priority year is reported as
                        # returned; only the synonym's year is mapped to 'n/a'
                        rep_priority_year = sp_priority_mngr.species_priority_year(
                            self.cur_genomes, rid)

                        synonym_priority_year = sp_priority_mngr.species_priority_year(
                            self.cur_genomes, gid)
                        if synonym_priority_year == Genome.NO_PRIORITY_YEAR:
                            synonym_priority_year = 'n/a'

                        row = [
                            synonym_type,
                            genome_columns(rid, rep_priority_year),
                            genome_columns(gid, synonym_priority_year),
                            '\t{:.3f}\t{:.4f}'.format(ani, af),
                        ]

                        rep_genome = self.cur_genomes[rid]
                        synonym_genome = self.cur_genomes[gid]
                        if (rep_genome.is_effective_type_strain()
                                and synonym_genome.is_effective_type_strain()):
                            # both are type strains, so the representative
                            # should be the one holding naming priority
                            priority_gid, note = sp_priority_mngr.species_priority(
                                self.cur_genomes, rid, gid)
                            if priority_gid != rid:
                                incorrect_priority += 1
                                row.append(
                                    '\tIncorrect priority: {}'.format(note))
                        elif (not rep_genome.is_gtdb_type_strain()
                                and synonym_genome.is_gtdb_type_strain()):
                            failed_type_strain_priority += 1
                            row.append(
                                '\tFailed to prioritize type strain of species')

                        row.append('\n')
                        fout.write(''.join(row))

        if incorrect_priority:
            self.logger.warning(
                f' - identified {incorrect_priority:,} synonyms with incorrect priority.'
            )

        if failed_type_strain_priority:
            self.logger.warning(
                f' - identified {failed_type_strain_priority:,} synonyms that failed to priotize the type strain of the species.'
            )
Beispiel #14
0
    def cluster_genomes(self,
                        cur_genomes,
                        de_novo_rep_gids,
                        named_rep_gids,
                        final_cluster_radius):
        """Cluster non-representative genomes to named and de novo representatives.

        Mash is used to pre-screen genome pairs, FastANI is run on the pairs
        that pass `self.min_mash_ani`, and each non-representative genome is
        assigned to the representative with the highest ANI (ties broken by
        AF) among those within the representative's ANI radius and meeting
        `self.af_sp`.

        Parameters
        ----------
        cur_genomes
            Genome set; `cur_genomes.genomes` supplies the genome IDs and
            `cur_genomes.genomic_files` the sequence files.
        de_novo_rep_gids : set
            IDs of de novo representative genomes.
        named_rep_gids : set
            IDs of representatives of named species clusters.
        final_cluster_radius : dict
            Representative ID -> object exposing an `ani` radius.

        Returns
        -------
        tuple
            `(clusters, ani_af)` where `clusters` maps each representative
            ID to a list of ClusteredGenome entries and `ani_af` is the
            FastANI result for the screened pairs.
        """

        all_reps = de_novo_rep_gids.union(named_rep_gids)
        nonrep_gids = set(cur_genomes.genomes.keys()) - all_reps
        self.logger.info('Clustering {:,} genomes to {:,} named and de novo representatives.'.format(
            len(nonrep_gids), len(all_reps)))

        clusters_file = os.path.join(self.output_dir, 'clusters.pkl')
        ani_af_file = os.path.join(
            self.output_dir, 'ani_af_rep_vs_nonrep.de_novo.pkl')

        # developer toggle: flip to False to reuse the pickled results of a
        # previous run instead of recomputing them
        if True:  # ***
            # calculate MASH distance between non-representatives and representatives genomes
            mash = Mash(self.cpus)

            mash_rep_sketch_file = os.path.join(
                self.output_dir, 'gtdb_rep_genomes.msh')
            rep_genome_list_file = os.path.join(
                self.output_dir, 'gtdb_rep_genomes.lst')
            mash.sketch(all_reps, cur_genomes.genomic_files,
                        rep_genome_list_file, mash_rep_sketch_file)

            mash_none_rep_sketch_file = os.path.join(
                self.output_dir, 'gtdb_nonrep_genomes.msh')
            non_rep_file = os.path.join(
                self.output_dir, 'gtdb_nonrep_genomes.lst')
            mash.sketch(nonrep_gids, cur_genomes.genomic_files,
                        non_rep_file, mash_none_rep_sketch_file)

            # get Mash distances
            mash_dist_file = os.path.join(
                self.output_dir, 'gtdb_rep_vs_nonrep_genomes.dst')
            mash.dist(float(100 - self.min_mash_ani)/100,
                      mash_rep_sketch_file,
                      mash_none_rep_sketch_file,
                      mash_dist_file)

            # read Mash distances
            mash_ani = mash.read_ani(mash_dist_file)

            # calculate ANI between non-representatives and representatives genomes
            clusters = {gid: [] for gid in all_reps}

            mash_ani_pairs = []
            for qid in mash_ani:
                assert qid in nonrep_gids

                for rid in mash_ani[qid]:
                    assert rid in all_reps

                    if (mash_ani[qid][rid] >= self.min_mash_ani
                            and qid != rid):
                        # both directions are queued for FastANI
                        mash_ani_pairs.append((qid, rid))
                        mash_ani_pairs.append((rid, qid))

            self.logger.info('Calculating ANI between {:,} species clusters and {:,} unclustered genomes ({:,} pairs):'.format(
                len(clusters),
                len(nonrep_gids),
                len(mash_ani_pairs)))
            ani_af = self.fastani.pairs(
                mash_ani_pairs, cur_genomes.genomic_files)

            # assign genomes to closest representatives
            # that is within the representatives ANI radius
            self.logger.info('Assigning genomes to closest representative.')

            # the isclose_abs_tol factor is used in order to avoid missing genomes due to
            # small rounding errors when comparing floating point values. In particular,
            # the ANI radius for named GTDB representatives is read from file so small
            # rounding errors could occur. This has only been observed once, but seems
            # like good practice to use isclose here.
            isclose_abs_tol = 1e-4

            for idx, cur_gid in enumerate(nonrep_gids):
                closest_rep_gid = None
                closest_rep_ani = 0
                closest_rep_af = 0

                # nearest representative satisfying the AF criterion regardless
                # of its ANI radius; only used to explain a failed assignment
                nearest_gid = None
                nearest_ani = 0
                nearest_af = 0

                for rep_gid in clusters:
                    ani, af = FastANI.symmetric_ani(ani_af, cur_gid, rep_gid)

                    within_radius = (
                        ani >= final_cluster_radius[rep_gid].ani - isclose_abs_tol)
                    if af < self.af_sp - isclose_abs_tol:
                        continue

                    if ani > nearest_ani or (ani == nearest_ani and af > nearest_af):
                        nearest_gid = rep_gid
                        nearest_ani = ani
                        nearest_af = af

                    if within_radius:
                        if ani > closest_rep_ani or (ani == closest_rep_ani and af > closest_rep_af):
                            closest_rep_gid = rep_gid
                            closest_rep_ani = ani
                            closest_rep_af = af

                if closest_rep_gid:
                    clusters[closest_rep_gid].append(ClusteredGenome(gid=cur_gid,
                                                                     ani=closest_rep_ani,
                                                                     af=closest_rep_af))
                else:
                    self.logger.warning(
                        'Failed to assign genome {} to representative.'.format(cur_gid))
                    if nearest_gid:
                        # a representative met the AF criterion but the genome
                        # falls outside its ANI radius
                        self.logger.warning(
                            ' - closest_rep_gid = {}'.format(nearest_gid))
                        self.logger.warning(
                            ' - closest_rep_ani = {:.2f}'.format(nearest_ani))
                        self.logger.warning(
                            ' - closest_rep_af = {:.2f}'.format(nearest_af))
                        self.logger.warning(
                            ' - closest rep radius = {:.2f}'.format(final_cluster_radius[nearest_gid].ani))
                    else:
                        self.logger.warning(
                            ' - no representative with an AF >{:.2f} identified.'.format(self.af_sp))

                statusStr = '-> Assigned {:,} of {:,} ({:.2f}%) genomes.'.format(idx+1,
                                                                                 len(nonrep_gids),
                                                                                 float(idx+1)*100/len(nonrep_gids)).ljust(86)
                sys.stdout.write('{}\r'.format(statusStr))
                sys.stdout.flush()
            sys.stdout.write('\n')

            # context managers ensure the pickles are fully flushed and closed
            with open(clusters_file, 'wb') as fout:
                pickle.dump(clusters, fout)
            with open(ani_af_file, 'wb') as fout:
                pickle.dump(ani_af, fout)
        else:
            # NOTE: pickle files are only safe to load when they were produced
            # by a trusted run of this pipeline
            self.logger.warning(
                'Using previously calculated results in: {}'.format('clusters.pkl'))
            with open(clusters_file, 'rb') as fin:
                clusters = pickle.load(fin)

            self.logger.warning('Using previously calculated results in: {}'.format(
                'ani_af_rep_vs_nonrep.de_novo.pkl'))
            with open(ani_af_file, 'rb') as fin:
                ani_af = pickle.load(fin)

        return clusters, ani_af
Beispiel #15
0
    def selected_rep_genomes(self,
                             cur_genomes,
                             nonrep_radius,
                             unclustered_qc_gids,
                             mash_ani):
        """Select de novo representatives for species clusters in a greedy fashion using species-specific ANI thresholds.

        Genomes are visited from highest to lowest `score_type_strain()`
        (ties broken by genome ID). A genome becomes a new representative
        unless its closest already-selected representative (AF >= `self.af_sp`)
        has an ANI above that representative's radius.

        If `cluster_reps.tsv` already exists in `self.output_dir` the
        representatives are read from it instead of being recomputed;
        otherwise the selected representatives are written to that file.

        Parameters
        ----------
        cur_genomes
            Genome set indexable by genome ID; `genomic_files` is passed to
            FastANI.
        nonrep_radius : dict
            Genome ID -> GenomeRadius. Updated in place when a closer
            representative is found for a genome.
        unclustered_qc_gids : iterable
            IDs of the genomes to process.
        mash_ani : dict
            Query ID -> {reference ID -> Mash ANI} used to pre-screen pairs.

        Returns
        -------
        set
            IDs of the selected representative genomes.
        """

        # sort genomes by quality score
        self.logger.info(
            'Selecting de novo representatives in a greedy manner based on quality.')
        q = {gid: cur_genomes[gid].score_type_strain()
             for gid in unclustered_qc_gids}
        q_sorted = sorted(q.items(), key=lambda kv: (
            kv[1], kv[0]), reverse=True)

        # greedily determine representatives for new species clusters
        cluster_rep_file = os.path.join(self.output_dir, 'cluster_reps.tsv')
        clusters = set()
        if not os.path.exists(cluster_rep_file):
            clustered_genomes = 0
            max_ani_pairs = 0
            for idx, (cur_gid, _score) in enumerate(q_sorted):

                # determine reference genomes to calculate ANI between
                ani_pairs = []
                if cur_gid in mash_ani:
                    for rep_gid in clusters:
                        if mash_ani[cur_gid].get(rep_gid, 0) >= self.min_mash_ani:
                            ani_pairs.append((cur_gid, rep_gid))
                            ani_pairs.append((rep_gid, cur_gid))

                # determine if genome clusters with representative
                clustered = False
                if ani_pairs:
                    max_ani_pairs = max(max_ani_pairs, len(ani_pairs))

                    ani_af = self.fastani.pairs(
                        ani_pairs, cur_genomes.genomic_files, report_progress=False)

                    closest_rep_gid = None
                    closest_rep_ani = 0
                    closest_rep_af = 0
                    for rep_gid in clusters:
                        ani, af = FastANI.symmetric_ani(
                            ani_af, cur_gid, rep_gid)

                        if af < self.af_sp:
                            continue

                        if ani > closest_rep_ani or (ani == closest_rep_ani and af > closest_rep_af):
                            closest_rep_gid = rep_gid
                            closest_rep_ani = ani
                            closest_rep_af = af

                        # tighten this genome's own radius to its nearest
                        # representative
                        if ani > nonrep_radius[cur_gid].ani:
                            nonrep_radius[cur_gid] = GenomeRadius(ani=ani,
                                                                  af=af,
                                                                  neighbour_gid=rep_gid)

                    if closest_rep_gid and closest_rep_ani > nonrep_radius[closest_rep_gid].ani:
                        clustered = True

                if not clustered:
                    # genome is a new species cluster representative
                    clusters.add(cur_gid)
                else:
                    clustered_genomes += 1

                if (idx+1) % 10 == 0 or idx+1 == len(q_sorted):
                    statusStr = '-> Clustered {:,} of {:,} ({:.2f}%) genomes [ANI pairs: {:,}; clustered genomes: {:,}; clusters: {:,}].'.format(
                        idx+1,
                        len(q_sorted),
                        float(idx+1)*100/len(q_sorted),
                        max_ani_pairs,
                        clustered_genomes,
                        len(clusters)).ljust(96)
                    sys.stdout.write('{}\r'.format(statusStr))
                    sys.stdout.flush()
                    # the reported maximum covers only the genomes processed
                    # since the previous status line
                    max_ani_pairs = 0
            sys.stdout.write('\n')

            # write out selected cluster representative
            with open(cluster_rep_file, 'w') as fout:
                for gid in clusters:
                    fout.write('{}\n'.format(gid))
        else:
            # read cluster reps from file
            self.logger.warning(
                'Using previously determined cluster representatives.')
            with open(cluster_rep_file) as fin:
                for line in fin:
                    gid = line.strip()
                    # skip blank lines so an empty string is never treated
                    # as a genome ID
                    if gid:
                        clusters.add(gid)

        self.logger.info(
            'Selected {:,} representative genomes for de novo species clusters.'.format(len(clusters)))

        return clusters
    def run(self, gtdb_metadata_file, genomic_path_file):
        """Dereplicate GTDB species clusters using ANI/AF criteria.

        Loads the GTDB genome set, groups the GTDB representative genomes by
        GTDB genus, runs FastANI on every ordered pair of representatives
        within a genus and writes the symmetric ANI/AF values to
        `intragenus_ani_af_reps.tsv` in `self.output_dir`.

        Parameters
        ----------
        gtdb_metadata_file
            Metadata file passed to Genomes.load_from_metadata_file().
        genomic_path_file
            File passed to Genomes.load_genomic_file_paths().
        """

        # create GTDB genome sets
        self.logger.info('Creating GTDB genome set.')
        genomes = Genomes()
        genomes.load_from_metadata_file(gtdb_metadata_file)
        genomes.load_genomic_file_paths(genomic_path_file)
        self.logger.info(
            ' - genome set has {:,} species clusters spanning {:,} genomes.'.
            format(len(genomes.sp_clusters),
                   genomes.sp_clusters.total_num_genomes()))

        # get GTDB representatives from same genus
        self.logger.info('Identifying GTDB representatives in the same genus.')
        genus_gids = defaultdict(list)
        num_reps = 0
        for gid in genomes:
            if not genomes[gid].gtdb_is_rep:
                continue

            genus_gids[genomes[gid].gtdb_taxa.genus].append(gid)
            num_reps += 1
        self.logger.info(
            f' - identified {len(genus_gids):,} genera spanning {num_reps:,} representatives'
        )

        # get all intragenus comparisons
        self.logger.info('Determining all intragenus comparisons.')
        gid_pairs = []
        for gids in genus_gids.values():
            # permutations() yields nothing for genera with fewer than two
            # representatives, and both orderings of each pair otherwise
            gid_pairs.extend(permutations(gids, 2))
        self.logger.info(
            f' - identified {len(gid_pairs):,} intragenus comparisons')

        # calculate FastANI ANI/AF between target genomes
        self.logger.info('Calculating ANI between intragenus pairs.')
        ani_af = self.fastani.pairs(gid_pairs,
                                    genomes.genomic_files,
                                    report_progress=True,
                                    check_cache=True)
        self.fastani.write_cache(silence=True)

        # write out results; the context manager closes the table even if
        # formatting a row raises
        out_file = os.path.join(self.output_dir, 'intragenus_ani_af_reps.tsv')
        with open(out_file, 'w') as fout:
            fout.write(
                'Query ID\tQuery species\tTarget ID\tTarget species\tANI\tAF\n')
            for qid in ani_af:
                # NOTE(review): the inner loop walks every query ID in ani_af
                # rather than ani_af[qid], so a row is written for every
                # query x query combination - confirm this is intended
                for rid in ani_af:
                    ani, af = FastANI.symmetric_ani(ani_af, qid, rid)

                    fout.write('{}\t{}\t{}\t{}\t{:.3f}\t{:.3f}\n'.format(
                        qid, genomes[qid].gtdb_taxa.species, rid,
                        genomes[rid].gtdb_taxa.species, ani, af))