def _mash_ani_unclustered(self, cur_genomes, gids):
        """Estimate all-vs-all Mash ANI for a set of genomes.

        The genomes in `gids` are sketched, compared pairwise with Mash using
        a distance cutoff derived from `self.min_mash_ani`, and the number of
        directed genome pairs meeting that ANI threshold is logged.

        Returns the ANI table produced by `Mash.read_ani`, indexed as
        mash_ani[query_id][reference_id].
        """

        mash = Mash(self.cpus)

        # sketch every genome under consideration
        sketch_file = os.path.join(self.output_dir, 'gtdb_unclustered_genomes.msh')
        list_file = os.path.join(self.output_dir, 'gtdb_unclustered_genomes.lst')
        mash.sketch(gids, cur_genomes.genomic_files, list_file, sketch_file)

        # all-vs-all distances; the ANI threshold is expressed as a Mash distance
        dist_file = os.path.join(self.output_dir, 'gtdb_unclustered_genomes.dst')
        max_dist = float(100 - self.min_mash_ani)/100
        mash.dist_pairwise(max_dist, sketch_file, dist_file)

        mash_ani = mash.read_ani(dist_file)

        # each qualifying hit is counted in both directions; hits that map to
        # the same genome after resolving User/UBA identifiers are ignored
        num_hits = sum(
            1
            for qid in mash_ani
            for rid in mash_ani[qid]
            if mash_ani[qid][rid] >= self.min_mash_ani
            and (cur_genomes.user_uba_id_map.get(qid, qid)
                 != cur_genomes.user_uba_id_map.get(rid, rid)))
        num_pairs = 2 * num_hits

        self.logger.info('Identified {:,} genome pairs with a Mash ANI >= {:.1f}%.'.format(
                            num_pairs,
                            self.min_mash_ani))

        return mash_ani
    def __init__(self, derep_ani, derep_af, max_genomes_per_sp, ani_cache_file,
                 cpus, output_dir):
        """Store dereplication settings and create the Mash/FastANI wrappers.

        derep_ani and derep_af are the ANI/AF dereplication criteria,
        max_genomes_per_sp caps the genomes considered per species,
        ani_cache_file is handed to the FastANI wrapper, cpus is handed to
        both wrappers, and output_dir is where results are written.
        """

        # the external programs are checked before anything else is set up
        check_dependencies(['fastANI', 'mash'])

        self.logger = logging.getLogger('timestamp')

        # run-time resources
        self.output_dir = output_dir
        self.cpus = cpus

        # dereplication criteria
        self.derep_af = derep_af
        self.derep_ani = derep_ani
        self.max_genomes_per_sp = max_genomes_per_sp

        # Mash pre-filter: one ANI point below the dereplication ANI
        self.min_mash_intra_sp_ani = derep_ani - 1.0

        # wrappers around the external programs
        self.mash = Mash(cpus)
        self.fastani = FastANI(ani_cache_file, cpus)
Example #3
0
    def _mash_ani(self, genome_files, user_genomes, sp_clusters):
        """Estimate Mash ANI between User genomes and species clusters.

        Both genome sets are sketched separately, the species sketch is
        compared against the User sketch with a distance cutoff derived from
        `self.min_mash_ani`, and the number of directed pairs meeting the
        threshold is logged.

        Returns the ANI table produced by `Mash.read_ani`.
        """

        mash = Mash(self.cpus)

        # sketch the User genomes
        user_sketch = os.path.join(self.output_dir, 'gtdb_user_genomes.msh')
        user_list = os.path.join(self.output_dir, 'gtdb_user_genomes.lst')
        mash.sketch(user_genomes, genome_files, user_list, user_sketch)

        # sketch the species clusters
        sp_sketch = os.path.join(self.output_dir, 'gtdb_sp_genomes.msh')
        sp_list = os.path.join(self.output_dir, 'gtdb_sp_genomes.lst')
        mash.sketch(sp_clusters, genome_files, sp_list, sp_sketch)

        # species vs. User distances, limited by the ANI threshold
        dist_file = os.path.join(self.output_dir, 'gtdb_user_vs_sp.dst')
        max_dist = float(100 - self.min_mash_ani) / 100
        mash.dist(max_dist, sp_sketch, user_sketch, dist_file)

        mash_ani = mash.read_ani(dist_file)

        # tally hits above the threshold; every hit counts in both directions
        num_pairs = 0
        for qid in mash_ani:
            for rid in mash_ani[qid]:
                if mash_ani[qid][rid] >= self.min_mash_ani and qid != rid:
                    num_pairs += 2

        self.logger.info(
            'Identified %d genome pairs with a Mash ANI >= %.1f%%.' %
            (num_pairs, self.min_mash_ani))

        return mash_ani
class IntraSpeciesDereplication(object):
    """Dereplicate GTDB species clusters using ANI/AF criteria.

    For each species cluster the genomes are ordered by priority, thinned
    with Mash when the species holds more than `max_genomes_per_sp` genomes,
    and then greedily dereplicated with FastANI. Every surviving
    representative heads one `subspecies` cluster.
    """

    def __init__(self, derep_ani, derep_af, max_genomes_per_sp, ani_cache_file,
                 cpus, output_dir):
        """Initialization.

        derep_ani / derep_af: ANI and AF a genome must reach against a
            representative to be clustered with it.
        max_genomes_per_sp: maximum number of genomes per species passed on
            to the FastANI stage.
        ani_cache_file: cache file handed to the FastANI wrapper.
        cpus: CPU count handed to the Mash and FastANI wrappers.
        output_dir: directory receiving all output files.
        """

        check_dependencies(['fastANI', 'mash'])

        self.cpus = cpus
        self.output_dir = output_dir

        self.logger = logging.getLogger('timestamp')

        self.max_genomes_per_sp = max_genomes_per_sp
        self.derep_ani = derep_ani
        self.derep_af = derep_af

        # minimum MASH ANI value for dereplicating within a species
        self.min_mash_intra_sp_ani = derep_ani - 1.0

        self.mash = Mash(self.cpus)
        self.fastani = FastANI(ani_cache_file, cpus)

    def mash_sp_ani(self, gids, genomes, output_prefix):
        """Calculate pairwise Mash ANI estimates between genomes.

        Writes `<output_prefix>.msh`, `.lst` and `.dst` and returns the ANI
        table read back from the distance file, indexed as
        mash_ani[query_id][reference_id].
        """

        INIT_MASH_ANI_FILTER = 95.0

        # create Mash sketch for all genomes
        mash_sketch_file = f'{output_prefix}.msh'
        genome_list_file = f'{output_prefix}.lst'
        self.mash.sketch(gids,
                         genomes.genomic_files,
                         genome_list_file,
                         mash_sketch_file,
                         silence=True)

        # get Mash distances
        mash_dist_file = f'{output_prefix}.dst'
        self.mash.dist_pairwise(float(100 - INIT_MASH_ANI_FILTER) / 100,
                                mash_sketch_file,
                                mash_dist_file,
                                silence=True)

        # read Mash distances
        mash_ani = self.mash.read_ani(mash_dist_file)

        # number of reported hits, ignoring self comparisons
        count = sum(1 for qid in mash_ani for rid in mash_ani[qid]
                    if qid != rid)

        self.logger.info(
            ' - identified {:,} pairs passing Mash filtering of ANI >= {:.1f}%.'
            .format(count, INIT_MASH_ANI_FILTER))

        return mash_ani

    def priority_score(self, gid, genomes):
        """Get priority score of genome.

        The score is the genome's `score_assembly()` value plus 1e4 when
        `is_gtdb_type_subspecies()` is true for the genome.
        """

        score = genomes[gid].score_assembly()
        if genomes[gid].is_gtdb_type_subspecies():
            score += 1e4

        return score

    def order_genomes_by_priority(self, gids, genomes):
        """Order genomes by overall priority, highest score first."""

        genome_priority = {
            gid: self.priority_score(gid, genomes)
            for gid in gids
        }

        sorted_by_priority = sorted(genome_priority.items(),
                                    key=operator.itemgetter(1),
                                    reverse=True)

        return [gid for gid, _score in sorted_by_priority]

    def mash_sp_dereplicate(self, mash_ani, sorted_gids, ani_threshold):
        """Dereplicate genomes in species using Mash distances.

        Genomes are visited in the given order; a genome becomes a new
        representative unless its Mash ANI to an already selected
        representative is >= `ani_threshold`. Pairs missing from `mash_ani`
        are treated as having an ANI of 0.

        Returns the list of representative genome IDs in selection order.
        """

        # perform greedy selection of new representatives
        sp_reps = []
        for gid in sorted_gids:
            clustered = False
            for rep_id in sp_reps:
                if gid in mash_ani:
                    ani = mash_ani[gid].get(rep_id, 0)
                else:
                    ani = 0

                if ani >= ani_threshold:
                    clustered = True
                    break

            if not clustered:
                # genome was not assigned to an existing representative,
                # so make it a new representative genome
                sp_reps.append(gid)

        return sp_reps

    def _limit_genomes_with_mash(self, species, num_sp_genomes, sorted_gids,
                                 mash_ani):
        """Reduce a species to at most `max_genomes_per_sp` genomes using Mash.

        Mash dereplication is repeated at progressively lower ANI thresholds
        until the number of representatives fits the limit. Remaining slots
        are then filled with the highest priority genomes of the previous
        (more stringent) round. If no threshold suffices, the first
        `max_genomes_per_sp` representatives of the last round are used.

        Returns the selected genome IDs in priority order.
        """

        self.logger.info(
            ' - limiting species to <={:,} genomes based on priority and Mash dereplication.'
            .format(self.max_genomes_per_sp))

        prev_mash_rep_gids = None
        for ani_threshold in [
                99.75, 99.5, 99.25, 99.0, 98.75, 98.5, 98.25, 98.0, 97.75,
                97.5, 97.0, 96.5, 96.0, 95.0
        ]:
            mash_rep_gids = self.mash_sp_dereplicate(mash_ani, sorted_gids,
                                                     ani_threshold)

            self.logger.info(
                ' - dereplicated {} from {:,} to {:,} genomes at {:.2f}% ANI using Mash.'
                .format(species, num_sp_genomes, len(mash_rep_gids),
                        ani_threshold))

            if len(mash_rep_gids) > self.max_genomes_per_sp:
                # still too many genomes; try the next, lower threshold
                prev_mash_rep_gids = mash_rep_gids
                continue

            if not prev_mash_rep_gids:
                # corner case where dereplication is occurring at 99.75%
                prev_mash_rep_gids = sorted_gids

            # select maximum allowed number of genomes by taking all genomes in the
            # current Mash dereplicated set and then the highest priority genomes in the
            # previous Mash dereplicated set which have not been selected
            cur_sel_gids = set(mash_rep_gids)
            prev_sel_gids = set(prev_mash_rep_gids)
            num_prev_to_sel = self.max_genomes_per_sp - len(cur_sel_gids)
            num_prev_selected = 0
            sel_sorted_gids = []
            for gid in sorted_gids:
                if gid in cur_sel_gids:
                    sel_sorted_gids.append(gid)
                elif (gid in prev_sel_gids
                      and num_prev_selected < num_prev_to_sel):
                    num_prev_selected += 1
                    sel_sorted_gids.append(gid)

                if len(sel_sorted_gids) == self.max_genomes_per_sp:
                    break

            assert len(cur_sel_gids - set(sel_sorted_gids)) == 0
            assert num_prev_to_sel == num_prev_selected
            assert len(sel_sorted_gids) == self.max_genomes_per_sp

            self.logger.info(
                ' - selected {:,} highest priority genomes from Mash dereplication at an ANI = {:.2f}%.'
                .format(len(sel_sorted_gids), ani_threshold))
            return sel_sorted_gids

        # even the lowest threshold left too many representatives, so keep
        # only the highest priority ones from the final round
        self.logger.warning(
            ' - selected {:,} highest priority genomes from final Mash dereplication.'
            .format(self.max_genomes_per_sp))
        return mash_rep_gids[0:self.max_genomes_per_sp]

    def dereplicate_species(self, species, rid, cids, genomes, mash_out_dir):
        """Dereplicate genomes within a GTDB species.

        species: species name; characters after the first three are used to
            name the Mash output files.
        rid: representative genome of the species; always considered first.
        cids: set of genome IDs in the species cluster.
        genomes: genome collection indexable by genome ID and exposing
            `genomic_files`.
        mash_out_dir: directory receiving the per-species Mash files.

        Returns a dict mapping each selected representative to the list of
        genomes in its cluster (the representative itself comes first).
        """

        # greedily dereplicate genomes based on genome priority
        sorted_gids = self.order_genomes_by_priority(cids.difference([rid]),
                                                     genomes)
        sorted_gids = [rid] + sorted_gids

        # calculate Mash ANI between genomes
        mash_ani = {}
        if len(sorted_gids) > 1:
            # calculate MASH distances between genomes
            out_prefix = os.path.join(mash_out_dir,
                                      species[3:].lower().replace(' ', '_'))
            mash_ani = self.mash_sp_ani(sorted_gids, genomes, out_prefix)

        # perform initial dereplication using Mash for species with excessive
        # numbers of genomes
        if len(sorted_gids) > self.max_genomes_per_sp:
            sorted_gids = self._limit_genomes_with_mash(
                species, len(cids), sorted_gids, mash_ani)

        # calculate FastANI ANI/AF between genomes passing Mash filtering
        ani_pairs = set()
        for gid1, gid2 in permutations(sorted_gids, 2):
            if gid1 in mash_ani and gid2 in mash_ani[gid1]:
                if mash_ani[gid1][gid2] >= self.min_mash_intra_sp_ani:
                    ani_pairs.add((gid1, gid2))
                    ani_pairs.add((gid2, gid1))

        self.logger.info(
            ' - calculating FastANI between {:,} pairs with Mash ANI >= {:.1f}%.'
            .format(len(ani_pairs), self.min_mash_intra_sp_ani))
        ani_af = self.fastani.pairs(ani_pairs,
                                    genomes.genomic_files,
                                    report_progress=False,
                                    check_cache=True)
        self.fastani.write_cache(silence=True)

        # perform greedy dereplication
        sp_reps = []
        for gid in sorted_gids:
            # determine if genome clusters with existing representative
            clustered = False
            for rep_id in sp_reps:
                ani, af = FastANI.symmetric_ani(ani_af, gid, rep_id)

                if ani >= self.derep_ani and af >= self.derep_af:
                    clustered = True
                    break

            if not clustered:
                sp_reps.append(gid)

        self.logger.info(
            ' - dereplicated {} from {:,} to {:,} genomes.'.format(
                species, len(sorted_gids), len(sp_reps)))

        # assign clustered genomes to most similar representative
        subsp_clusters = {rep_id: [rep_id] for rep_id in sp_reps}

        non_rep_gids = set(sorted_gids) - set(sp_reps)
        for gid in non_rep_gids:
            closest_rid = None
            max_ani = 0
            max_af = 0
            for rep_id in sp_reps:
                ani, af = FastANI.symmetric_ani(ani_af, gid, rep_id)
                # prefer the highest ANI and break ties by AF; the AF
                # criterion must always be satisfied
                if ((ani > max_ani and af >= self.derep_af) or
                    (ani == max_ani and af >= max_af and af >= self.derep_af)):
                    max_ani = ani
                    max_af = af
                    closest_rid = rep_id

            assert closest_rid is not None
            subsp_clusters[closest_rid].append(gid)

        return subsp_clusters

    def derep_sp_clusters(self, genomes):
        """Dereplicate each GTDB species cluster.

        Returns a dict mapping species name to the clusters returned by
        `dereplicate_species` for that species.
        """

        mash_out_dir = os.path.join(self.output_dir, 'mash')
        os.makedirs(mash_out_dir, exist_ok=True)

        derep_genomes = {}
        for rid, cids in genomes.sp_clusters.items():
            species = genomes[rid].gtdb_taxa.species

            self.logger.info(
                'Dereplicating {} with {:,} genomes [{:,} of {:,} ({:.2f}%) species].'
                .format(species, len(cids), len(derep_genomes),
                        len(genomes.sp_clusters),
                        len(derep_genomes) * 100.0 / len(genomes.sp_clusters)))

            subsp_clusters = self.dereplicate_species(species, rid, cids,
                                                      genomes, mash_out_dir)

            derep_genomes[species] = subsp_clusters

        return derep_genomes

    def run(self, gtdb_metadata_file, genomic_path_file):
        """Dereplicate GTDB species clusters using ANI/AF criteria.

        Loads the genome set from the two input files, dereplicates every
        species cluster and writes the resulting clusters to
        `subsp_clusters.tsv` in the output directory.
        """

        # create GTDB genome sets
        self.logger.info('Creating GTDB genome set.')
        genomes = Genomes()
        genomes.load_from_metadata_file(gtdb_metadata_file)
        genomes.load_genomic_file_paths(genomic_path_file)
        self.logger.info(
            ' - genome set has {:,} species clusters spanning {:,} genomes.'.
            format(len(genomes.sp_clusters),
                   genomes.sp_clusters.total_num_genomes()))

        # dereplicate each species cluster
        self.logger.info(
            'Performing dereplication with ANI={:.1f}, AF={:.2f}, Mash ANI={:.2f}, max genomes={:,}.'
            .format(self.derep_ani, self.derep_af, self.min_mash_intra_sp_ani,
                    self.max_genomes_per_sp))
        derep_genomes = self.derep_sp_clusters(genomes)

        # write out `subspecies` clusters; the header lists exactly the six
        # fields written for each row
        out_file = os.path.join(self.output_dir, 'subsp_clusters.tsv')
        with open(out_file, 'w') as fout:
            fout.write(
                'Genome ID\tGTDB Species\tGTDB Taxonomy\tPriority score\tNo. clustered genomes\tClustered genomes\n'
            )
            for species, subsp_clusters in derep_genomes.items():
                for rid, cids in subsp_clusters.items():
                    assert species == genomes[rid].gtdb_taxa.species
                    fout.write('{}\t{}\t{}\t{:.3f}\t{}\t{}\n'.format(
                        rid, genomes[rid].gtdb_taxa.species,
                        genomes[rid].gtdb_taxa,
                        self.priority_score(rid, genomes), len(cids),
                        ','.join(cids)))
    def _calculate_ani(self, cur_genomes, rep_gids, rep_mash_sketch_file):
        """Calculate ANI between representative and non-representative genomes.

        cur_genomes: genome collection; iterating it yields genome IDs and
            it exposes `genomic_files` and `user_uba_id_map`.
        rep_gids: IDs of the representative genomes.
        rep_mash_sketch_file: existing Mash sketch of the representatives;
            a new sketch is created when this is empty or does not exist.

        Pairs whose Mash ANI is >= `self.min_mash_ani` are passed to FastANI
        and the result is pickled to `ani_af_rep_vs_nonrep.pkl` in the output
        directory.

        Returns the ANI/AF results produced by `self.fastani.pairs`.
        """

        ani_af_file = os.path.join(self.output_dir, 'ani_af_rep_vs_nonrep.pkl')

        # development toggle: the else branch reloads previously pickled results
        if True:  #***
            mash = Mash(self.cpus)

            # create Mash sketch for representative genomes
            if not rep_mash_sketch_file or not os.path.exists(
                    rep_mash_sketch_file):
                rep_genome_list_file = os.path.join(self.output_dir,
                                                    'gtdb_reps.lst')
                rep_mash_sketch_file = os.path.join(self.output_dir,
                                                    'gtdb_reps.msh')
                mash.sketch(rep_gids, cur_genomes.genomic_files,
                            rep_genome_list_file, rep_mash_sketch_file)

            # create Mash sketch for non-representative genomes
            nonrep_gids = {gid for gid in cur_genomes if gid not in rep_gids}

            nonrep_genome_list_file = os.path.join(self.output_dir,
                                                   'gtdb_nonreps.lst')
            nonrep_genome_sketch_file = os.path.join(self.output_dir,
                                                     'gtdb_nonreps.msh')
            mash.sketch(nonrep_gids, cur_genomes.genomic_files,
                        nonrep_genome_list_file, nonrep_genome_sketch_file)

            # get Mash distances
            mash_dist_file = os.path.join(self.output_dir,
                                          'gtdb_reps_vs_nonreps.dst')
            mash.dist(
                float(100 - self.min_mash_ani) / 100, rep_mash_sketch_file,
                nonrep_genome_sketch_file, mash_dist_file)

            # read Mash distances
            mash_ani = mash.read_ani(mash_dist_file)

            # get pairs above Mash threshold; both directions are requested
            # and identifiers are resolved through the User/UBA ID map
            mash_ani_pairs = []
            for qid in mash_ani:
                for rid in mash_ani[qid]:
                    if mash_ani[qid][rid] >= self.min_mash_ani:
                        n_qid = cur_genomes.user_uba_id_map.get(qid, qid)
                        n_rid = cur_genomes.user_uba_id_map.get(rid, rid)
                        if n_qid != n_rid:
                            mash_ani_pairs.append((n_qid, n_rid))
                            mash_ani_pairs.append((n_rid, n_qid))

            self.logger.info(
                'Identified {:,} genome pairs with a Mash ANI >= {:.1f}%.'.
                format(len(mash_ani_pairs), self.min_mash_ani))

            # calculate ANI between pairs
            self.logger.info(
                'Calculating ANI between {:,} genome pairs:'.format(
                    len(mash_ani_pairs)))
            ani_af = self.fastani.pairs(mash_ani_pairs,
                                        cur_genomes.genomic_files)

            # close the file explicitly so the pickle is fully flushed to disk
            with open(ani_af_file, 'wb') as fout:
                pickle.dump(ani_af, fout)
        else:
            self.logger.warning(
                'Using previously calculated results in: {}'.format(
                    'ani_af_rep_vs_nonrep.pkl'))
            # NOTE: unpickling is only appropriate because this file is
            # written by the branch above; never point it at untrusted data
            with open(ani_af_file, 'rb') as fin:
                ani_af = pickle.load(fin)

        return ani_af
    def _cluster_genomes(self,
                            cur_genomes,
                            de_novo_rep_gids,
                            named_rep_gids, 
                            final_cluster_radius):
        """Cluster new representatives to representatives of named GTDB species clusters.

        cur_genomes: genome collection exposing `genomes`, `genomic_files`
            and `user_uba_id_map`.
        de_novo_rep_gids / named_rep_gids: sets of representative genome IDs;
            their union forms the cluster representatives.
        final_cluster_radius: mapping from representative ID to an object
            whose `ani` attribute is the ANI radius of that representative.

        Every non-representative genome is assigned to the representative
        with the highest ANI (ties broken by AF) among those for which the
        ANI is within the representative's radius and the AF is >=
        `self.af_sp`. Genomes without such a representative are reported
        with a warning and left unassigned.

        Returns (clusters, ani_af): the representative -> clustered genomes
        mapping and the FastANI results. Both are also pickled to the
        output directory.
        """
        
        all_reps = de_novo_rep_gids.union(named_rep_gids)
        nonrep_gids = set(cur_genomes.genomes.keys()) - all_reps
        self.logger.info('Clustering {:,} genomes to {:,} named and de novo representatives.'.format(
                            len(nonrep_gids), len(all_reps)))

        clusters_file = os.path.join(self.output_dir, 'clusters.pkl')
        ani_af_file = os.path.join(self.output_dir, 'ani_af_rep_vs_nonrep.de_novo.pkl')

        # development toggle: the else branch reloads previously pickled results
        if True: #***
            # calculate MASH distance between non-representatives and representatives genomes
            mash = Mash(self.cpus)
            
            mash_rep_sketch_file = os.path.join(self.output_dir, 'gtdb_rep_genomes.msh')
            rep_genome_list_file = os.path.join(self.output_dir, 'gtdb_rep_genomes.lst')
            mash.sketch(all_reps, cur_genomes.genomic_files, rep_genome_list_file, mash_rep_sketch_file)

            mash_none_rep_sketch_file = os.path.join(self.output_dir, 'gtdb_nonrep_genomes.msh')
            non_rep_file = os.path.join(self.output_dir, 'gtdb_nonrep_genomes.lst')
            mash.sketch(nonrep_gids, cur_genomes.genomic_files, non_rep_file, mash_none_rep_sketch_file)

            # get Mash distances
            mash_dist_file = os.path.join(self.output_dir, 'gtdb_rep_vs_nonrep_genomes.dst')
            mash.dist(float(100 - self.min_mash_ani)/100, 
                        mash_rep_sketch_file, 
                        mash_none_rep_sketch_file, 
                        mash_dist_file)

            # read Mash distances
            mash_ani = mash.read_ani(mash_dist_file)
            
            # every representative starts with an empty cluster
            clusters = {gid: [] for gid in all_reps}

            # collect pairs above the Mash threshold in both directions;
            # queries must be non-representatives and hits representatives
            mash_ani_pairs = []
            for qid in mash_ani:
                n_qid = cur_genomes.user_uba_id_map.get(qid, qid)
                assert n_qid in nonrep_gids
                
                for rid in mash_ani[qid]:
                    n_rid = cur_genomes.user_uba_id_map.get(rid, rid)
                    assert n_rid in all_reps
                    
                    if (mash_ani[qid][rid] >= self.min_mash_ani
                        and n_qid != n_rid):
                        mash_ani_pairs.append((n_qid, n_rid))
                        mash_ani_pairs.append((n_rid, n_qid))
                            
            self.logger.info('Calculating ANI between {:,} species clusters and {:,} unclustered genomes ({:,} pairs):'.format(
                                len(clusters), 
                                len(nonrep_gids),
                                len(mash_ani_pairs)))
            ani_af = self.fastani.pairs(mash_ani_pairs, cur_genomes.genomic_files)

            # assign genomes to closest representatives 
            # that is within the representatives ANI radius
            self.logger.info('Assigning genomes to closest representative.')
            for idx, cur_gid in enumerate(nonrep_gids):
                closest_rep_gid = None
                closest_rep_ani = 0
                closest_rep_af = 0
                for rep_gid in clusters:
                    ani, af = symmetric_ani(ani_af, cur_gid, rep_gid)
                    
                    if ani >= final_cluster_radius[rep_gid].ani and af >= self.af_sp:
                        if ani > closest_rep_ani or (ani == closest_rep_ani and af > closest_rep_af):
                            closest_rep_gid = rep_gid
                            closest_rep_ani = ani
                            closest_rep_af = af
                    
                if closest_rep_gid:
                    clusters[closest_rep_gid].append(ClusteredGenome(gid=cur_gid, 
                                                                            ani=closest_rep_ani, 
                                                                            af=closest_rep_af))
                else:
                    # no representative satisfied both the ANI radius and AF criteria
                    self.logger.warning('Failed to assign genome {} to representative.'.format(cur_gid))
                    self.logger.warning(' ...no representative with an AF >{:.2f} identified.'.format(self.af_sp))
                 
                # single-line progress indicator, rewritten in place via '\r'
                statusStr = '-> Assigned {:,} of {:,} ({:.2f}%) genomes.'.format(idx+1, 
                                                                                    len(nonrep_gids), 
                                                                                    float(idx+1)*100/len(nonrep_gids)).ljust(86)
                sys.stdout.write('{}\r'.format(statusStr))
                sys.stdout.flush()
            sys.stdout.write('\n')
            
            # close the files explicitly so the pickles are fully flushed to disk
            with open(clusters_file, 'wb') as fout:
                pickle.dump(clusters, fout)
            with open(ani_af_file, 'wb') as fout:
                pickle.dump(ani_af, fout)
        else:
            # NOTE: unpickling is only appropriate because these files are
            # written by the branch above; never point it at untrusted data
            self.logger.warning('Using previously calculated results in: {}'.format('clusters.pkl'))
            with open(clusters_file, 'rb') as fin:
                clusters = pickle.load(fin)
            
            self.logger.warning('Using previously calculated results in: {}'.format('ani_af_rep_vs_nonrep.de_novo.pkl'))
            with open(ani_af_file, 'rb') as fin:
                ani_af = pickle.load(fin)

        return clusters, ani_af
Example #7
0
    def _calculate_ani(self, type_gids, genome_files, ncbi_taxonomy,
                       type_genome_sketch_file):
        """Calculate ANI between type and non-type genomes.

        type_gids: IDs of the type genomes.
        genome_files: mapping keyed by genome ID; every key not in
            `type_gids` is treated as a non-type genome.
        ncbi_taxonomy: not used by this method.
        type_genome_sketch_file: existing Mash sketch of the type genomes;
            a new sketch is created when this is empty or does not exist.

        Pairs whose Mash ANI is >= `self.min_mash_ani` are passed to FastANI
        and the result is pickled to `ani_af_type_vs_nontype.pkl` in the
        output directory.

        Returns the ANI/AF results produced by `self.fastani.pairs`.
        """

        mash = Mash(self.cpus)

        # create Mash sketch for type genomes
        if not type_genome_sketch_file or not os.path.exists(
                type_genome_sketch_file):
            type_genome_list_file = os.path.join(self.output_dir,
                                                 'gtdb_type_genomes.lst')
            type_genome_sketch_file = os.path.join(self.output_dir,
                                                   'gtdb_type_genomes.msh')
            mash.sketch(type_gids, genome_files, type_genome_list_file,
                        type_genome_sketch_file)

        # create Mash sketch for non-type genomes
        nontype_gids = {gid for gid in genome_files if gid not in type_gids}

        nontype_genome_list_file = os.path.join(self.output_dir,
                                                'gtdb_nontype_genomes.lst')
        nontype_genome_sketch_file = os.path.join(self.output_dir,
                                                  'gtdb_nontype_genomes.msh')
        mash.sketch(nontype_gids, genome_files, nontype_genome_list_file,
                    nontype_genome_sketch_file)

        # get Mash distances
        mash_dist_file = os.path.join(self.output_dir,
                                      'gtdb_type_vs_nontype_genomes.dst')
        mash.dist(
            float(100 - self.min_mash_ani) / 100, type_genome_sketch_file,
            nontype_genome_sketch_file, mash_dist_file)

        # read Mash distances
        mash_ani = mash.read_ani(mash_dist_file)

        # get pairs above Mash threshold; both directions are requested
        mash_ani_pairs = []
        for qid in mash_ani:
            for rid in mash_ani[qid]:
                if mash_ani[qid][rid] >= self.min_mash_ani:
                    if qid != rid:
                        mash_ani_pairs.append((qid, rid))
                        mash_ani_pairs.append((rid, qid))

        self.logger.info(
            'Identified %d genome pairs with a Mash ANI >= %.1f%%.' %
            (len(mash_ani_pairs), self.min_mash_ani))

        # calculate ANI between pairs
        self.logger.info('Calculating ANI between %d genome pairs:' %
                         len(mash_ani_pairs))

        ani_af_file = os.path.join(self.output_dir,
                                   'ani_af_type_vs_nontype.pkl')

        # development toggle: the else branch reloads previously pickled results
        if True:  #***
            ani_af = self.fastani.pairs(mash_ani_pairs, genome_files)
            # close the file explicitly so the pickle is fully flushed to disk
            with open(ani_af_file, 'wb') as fout:
                pickle.dump(ani_af, fout)
        else:
            # NOTE: unpickling is only appropriate because this file is
            # written by the branch above; never point it at untrusted data
            with open(ani_af_file, 'rb') as fin:
                ani_af = pickle.load(fin)

        return ani_af
Example #8
0
    def _cluster_genomes(self, 
                            genome_files,
                            rep_genomes,
                            type_gids, 
                            passed_qc,
                            final_cluster_radius):
        """Cluster all non-type/representative genomes to selected type/representatives genomes.

        genome_files: genome ID -> genomic file mapping handed to Mash and
            FastANI.
        rep_genomes / type_gids: sets of genome IDs; their union forms the
            cluster representatives.
        passed_qc: set of genome IDs eligible for clustering.
        final_cluster_radius: mapping from representative ID to an object
            whose `ani` attribute is the ANI radius of that representative.

        Every genome in `passed_qc` that is not a representative is assigned
        to the representative with the highest ANI (ties broken by AF) among
        those for which the ANI is within the representative's radius and
        the AF is >= `self.af_sp`. Genomes without such a representative are
        reported with a warning and left unassigned.

        Returns (clusters, ani_af): the representative -> clustered genomes
        mapping and the FastANI results.
        """

        all_reps = rep_genomes.union(type_gids)
        
        # calculate MASH distance between non-type/representative genomes and selected type/representatives genomes
        mash = Mash(self.cpus)
        
        mash_type_rep_sketch_file = os.path.join(self.output_dir, 'gtdb_rep_genomes.msh')
        type_rep_genome_list_file = os.path.join(self.output_dir, 'gtdb_rep_genomes.lst')
        mash.sketch(all_reps, genome_files, type_rep_genome_list_file, mash_type_rep_sketch_file)
        
        mash_none_rep_sketch_file = os.path.join(self.output_dir, 'gtdb_nonrep_genomes.msh')
        type_none_rep_file = os.path.join(self.output_dir, 'gtdb_nonrep_genomes.lst')
        mash.sketch(passed_qc - all_reps, genome_files, type_none_rep_file, mash_none_rep_sketch_file)

        # get Mash distances
        mash_dist_file = os.path.join(self.output_dir, 'gtdb_rep_vs_nonrep_genomes.dst')
        mash.dist(float(100 - self.min_mash_ani)/100, mash_type_rep_sketch_file, mash_none_rep_sketch_file, mash_dist_file)

        # read Mash distances
        mash_ani = mash.read_ani(mash_dist_file)
        
        # every representative starts with an empty cluster
        clusters = {gid: [] for gid in all_reps}
        
        # calculate ANI between non-type/representative genomes and selected type/representatives genomes
        genomes_to_cluster = passed_qc - set(clusters)
        ani_pairs = []
        for gid in genomes_to_cluster:
            if gid in mash_ani:
                for rep_gid in clusters:
                    if mash_ani[gid].get(rep_gid, 0) >= self.min_mash_ani:
                        ani_pairs.append((gid, rep_gid))
                        ani_pairs.append((rep_gid, gid))
                        
        self.logger.info('Calculating ANI between %d species clusters and %d unclustered genomes (%d pairs):' % (
                            len(clusters), 
                            len(genomes_to_cluster),
                            len(ani_pairs)))
        ani_af = self.fastani.pairs(ani_pairs, genome_files)

        # assign genomes to closest representatives 
        # that is within the representatives ANI radius
        self.logger.info('Assigning genomes to closest representative.')
        for idx, cur_gid in enumerate(genomes_to_cluster):
            closest_rep_gid = None
            closest_rep_ani = 0
            closest_rep_af = 0
            for rep_gid in clusters:
                ani, af = symmetric_ani(ani_af, cur_gid, rep_gid)
                
                if ani >= final_cluster_radius[rep_gid].ani and af >= self.af_sp:
                    if ani > closest_rep_ani or (ani == closest_rep_ani and af > closest_rep_af):
                        closest_rep_gid = rep_gid
                        closest_rep_ani = ani
                        closest_rep_af = af
                
            if closest_rep_gid:
                clusters[closest_rep_gid].append(self.ClusteredGenome(gid=cur_gid, 
                                                                        ani=closest_rep_ani, 
                                                                        af=closest_rep_af))
            else:
                # no representative satisfied both the ANI radius and AF criteria
                self.logger.warning('Failed to assign genome %s to representative.' % cur_gid)
                self.logger.warning(' ...no representative with an AF >%.2f identified.' % self.af_sp)
             
            # single-line progress indicator, rewritten in place via '\r';
            # pad the formatted text (not the template) to a fixed width
            statusStr = ('-> Assigned %d of %d (%.2f%%) genomes.' % (idx+1, 
                                                                    len(genomes_to_cluster), 
                                                                    float(idx+1)*100/len(genomes_to_cluster))).ljust(86)
            sys.stdout.write('%s\r' % statusStr)
            sys.stdout.flush()
        sys.stdout.write('\n')

        return clusters, ani_af
# Example #9
# 0
    def cluster_genomes(self,
                        cur_genomes,
                        de_novo_rep_gids,
                        named_rep_gids,
                        final_cluster_radius):
        """Assign non-representative genomes to their closest representative.

        Every genome in `cur_genomes.genomes` that is not a de novo or named
        representative is compared against the representatives. Mash is used
        as a pre-filter: only genome pairs with a Mash ANI >= `self.min_mash_ani`
        are passed to FastANI. A genome is then assigned to the representative
        with the highest ANI (ties broken by highest AF) among those where the
        ANI is within the representative's cluster radius and the AF is at
        least `self.af_sp` (both compared with a small absolute tolerance).

        The resulting clusters and ANI/AF values are also pickled to
        'clusters.pkl' and 'ani_af_rep_vs_nonrep.de_novo.pkl' in
        `self.output_dir`.

        Parameters
        ----------
        cur_genomes
            Object exposing `genomes` (mapping keyed by genome ID) and
            `genomic_files` (handed to Mash sketching and FastANI).
        de_novo_rep_gids : set
            Genome IDs of de novo representatives.
        named_rep_gids : set
            Genome IDs of representatives of named species clusters.
        final_cluster_radius
            Mapping from representative genome ID to an object with an
            `ani` attribute giving the ANI radius of that cluster.

        Returns
        -------
        tuple
            `(clusters, ani_af)` where `clusters` maps each representative
            genome ID to a list of `ClusteredGenome(gid, ani, af)` entries and
            `ani_af` is the result returned by `self.fastani.pairs`.
        """

        all_reps = de_novo_rep_gids.union(named_rep_gids)
        nonrep_gids = set(cur_genomes.genomes.keys()) - all_reps
        self.logger.info('Clustering {:,} genomes to {:,} named and de novo representatives.'.format(
            len(nonrep_gids), len(all_reps)))

        clusters_file = os.path.join(self.output_dir, 'clusters.pkl')
        ani_af_file = os.path.join(
            self.output_dir, 'ani_af_rep_vs_nonrep.de_novo.pkl')

        # Developer toggle: set to False to reuse the pickled results of a
        # previous run instead of recomputing them.
        recompute_results = True  # ***
        if recompute_results:
            # calculate MASH distance between non-representatives and representatives genomes
            mash = Mash(self.cpus)

            mash_rep_sketch_file = os.path.join(
                self.output_dir, 'gtdb_rep_genomes.msh')
            rep_genome_list_file = os.path.join(
                self.output_dir, 'gtdb_rep_genomes.lst')
            mash.sketch(all_reps, cur_genomes.genomic_files,
                        rep_genome_list_file, mash_rep_sketch_file)

            mash_none_rep_sketch_file = os.path.join(
                self.output_dir, 'gtdb_nonrep_genomes.msh')
            non_rep_file = os.path.join(
                self.output_dir, 'gtdb_nonrep_genomes.lst')
            mash.sketch(nonrep_gids, cur_genomes.genomic_files,
                        non_rep_file, mash_none_rep_sketch_file)

            # get Mash distances; the ANI threshold (a percentage) is converted
            # to the equivalent distance, presumably used as the maximum
            # distance reported by Mash
            mash_dist_file = os.path.join(
                self.output_dir, 'gtdb_rep_vs_nonrep_genomes.dst')
            mash.dist(float(100 - self.min_mash_ani)/100,
                      mash_rep_sketch_file,
                      mash_none_rep_sketch_file,
                      mash_dist_file)

            # read Mash distances
            mash_ani = mash.read_ani(mash_dist_file)

            # every representative starts with an empty cluster
            clusters = {gid: [] for gid in all_reps}

            # collect genome pairs passing the Mash pre-filter; both
            # directions are queued so a symmetric ANI can be derived
            mash_ani_pairs = []
            for qid, rep_hits in mash_ani.items():
                assert qid in nonrep_gids

                for rid, pair_mash_ani in rep_hits.items():
                    assert rid in all_reps

                    if pair_mash_ani >= self.min_mash_ani and qid != rid:
                        mash_ani_pairs.append((qid, rid))
                        mash_ani_pairs.append((rid, qid))

            self.logger.info('Calculating ANI between {:,} species clusters and {:,} unclustered genomes ({:,} pairs):'.format(
                len(clusters),
                len(nonrep_gids),
                len(mash_ani_pairs)))
            ani_af = self.fastani.pairs(
                mash_ani_pairs, cur_genomes.genomic_files)

            # assign genomes to closest representatives
            # that is within the representatives ANI radius
            self.logger.info('Assigning genomes to closest representative.')

            # the isclose_abs_tol factor is used in order to avoid missing genomes due to
            # small rounding errors when comparing floating point values. In particular,
            # the ANI radius for named GTDB representatives is read from file so small
            # rounding errors could occur. This has only been observed once, but seems
            # like good practice to use isclose here.
            isclose_abs_tol = 1e-4
            min_af = self.af_sp - isclose_abs_tol
            num_nonreps = len(nonrep_gids)

            for idx, cur_gid in enumerate(nonrep_gids):
                # best representative satisfying both the ANI radius and AF criteria
                closest_rep_gid = None
                closest_rep_ani = 0
                closest_rep_af = 0

                # best representative satisfying the AF criterion only; kept
                # purely so a failed assignment can report how close it came
                nearest_rep_gid = None
                nearest_rep_ani = 0
                nearest_rep_af = 0

                for rep_gid in clusters:
                    ani, af = FastANI.symmetric_ani(ani_af, cur_gid, rep_gid)
                    rep_radius = final_cluster_radius[rep_gid].ani

                    if af >= min_af:
                        if (ani > nearest_rep_ani
                                or (ani == nearest_rep_ani and af > nearest_rep_af)):
                            nearest_rep_gid = rep_gid
                            nearest_rep_ani = ani
                            nearest_rep_af = af

                        if (ani >= rep_radius - isclose_abs_tol
                                and (ani > closest_rep_ani
                                     or (ani == closest_rep_ani and af > closest_rep_af))):
                            closest_rep_gid = rep_gid
                            closest_rep_ani = ani
                            closest_rep_af = af

                if closest_rep_gid is not None:
                    clusters[closest_rep_gid].append(ClusteredGenome(gid=cur_gid,
                                                                     ani=closest_rep_ani,
                                                                     af=closest_rep_af))
                else:
                    self.logger.warning(
                        'Failed to assign genome {} to representative.'.format(cur_gid))
                    if nearest_rep_gid is not None:
                        # a representative met the AF criterion, but the
                        # genome falls outside its ANI radius
                        self.logger.warning(
                            ' - closest_rep_gid = {}'.format(nearest_rep_gid))
                        self.logger.warning(
                            ' - closest_rep_ani = {:.2f}'.format(nearest_rep_ani))
                        self.logger.warning(
                            ' - closest_rep_af = {:.2f}'.format(nearest_rep_af))
                        self.logger.warning(
                            ' - closest rep radius = {:.2f}'.format(final_cluster_radius[nearest_rep_gid].ani))
                    else:
                        self.logger.warning(
                            ' - no representative with an AF >{:.2f} identified.'.format(self.af_sp))

                # single-line progress indicator, overwritten via carriage return
                statusStr = '-> Assigned {:,} of {:,} ({:.2f}%) genomes.'.format(idx+1,
                                                                                 num_nonreps,
                                                                                 float(idx+1)*100/num_nonreps).ljust(86)
                sys.stdout.write('{}\r'.format(statusStr))
                sys.stdout.flush()
            sys.stdout.write('\n')

            # context managers guarantee the pickles are flushed and closed
            with open(clusters_file, 'wb') as f:
                pickle.dump(clusters, f)
            with open(ani_af_file, 'wb') as f:
                pickle.dump(ani_af, f)
        else:
            # NOTE: pickle.load executes arbitrary code from the file; only
            # use this path with files written by a previous run of this method.
            self.logger.warning(
                'Using previously calculated results in: {}'.format('clusters.pkl'))
            with open(clusters_file, 'rb') as f:
                clusters = pickle.load(f)

            self.logger.warning('Using previously calculated results in: {}'.format(
                'ani_af_rep_vs_nonrep.de_novo.pkl'))
            with open(ani_af_file, 'rb') as f:
                ani_af = pickle.load(f)

        return clusters, ani_af