Example #1
0
    def _msa_filter_by_taxa(self, concatenated_file, gtdb_taxonomy,
                            taxa_filter, outgroup_taxon):
        """Filter the GTDB MSA to the genomes assigned to the specified taxa.

        Parameters
        ----------
        concatenated_file
            The MSA file, loaded with ``read_fasta``.
        gtdb_taxonomy : dict
            Maps a genome id to an iterable of the taxa assigned to it.
        taxa_filter : str or None
            Comma separated list of taxa to retain. No filtering is done
            when this is None.
        outgroup_taxon : str or None
            Taxon that is always added to the taxa to retain when given.

        Returns
        -------
        dict
            The MSA returned by ``read_fasta`` with every genome removed
            whose taxonomy shares no taxon with the filter. Genomes without
            an entry in ``gtdb_taxonomy`` are never removed.
        """

        msa = read_fasta(concatenated_file)
        msa_len = len(msa)
        self.logger.info('Read concatenated alignment for %d GTDB genomes.' %
                         msa_len)

        if taxa_filter is None:
            return msa

        taxa_to_keep = set(taxa_filter.split(','))
        if outgroup_taxon is not None:
            taxa_to_keep.add(outgroup_taxon)

        filtered_genomes = 0
        # items() (rather than iteritems()) works on both Python 2 and 3.
        for genome_id, taxa in gtdb_taxonomy.items():
            if taxa_to_keep.isdisjoint(taxa) and genome_id in msa:
                del msa[genome_id]
                filtered_genomes += 1

        # An empty alignment would otherwise raise ZeroDivisionError here.
        if msa_len > 0:
            perc_filtered = 100.0 * filtered_genomes / msa_len
        else:
            perc_filtered = 0.0

        msg = 'Filtered %.2f%% (%d/%d) taxa based on assigned taxonomy, ' \
              '%d taxa remain.' % (
                  perc_filtered, filtered_genomes, msa_len,
                  msa_len - filtered_genomes)

        # Having nothing left after filtering deserves a warning.
        if len(msa) > 0:
            self.logger.info(msg)
        else:
            self.logger.warning(msg)

        return msa
Example #2
0
    def test_write_fasta_wrap(self):
        """Sequences written by write_fasta are wrapped at the given width."""
        expected = {
            'genome_1': 'CAGTTCAGTT',
            'genome_2': 'TTAGTCA',
            'genome_3': 'CAG'
        }
        path_fasta = os.path.join(self.dir_tmp, 'fasta.fna')
        write_fasta(expected, path_fasta, wrap=4)

        # Round trip: wrapping must not alter the sequences themselves.
        self.assertDictEqual(expected, read_fasta(path_fasta))

        # Collect the raw sequence lines under the header preceding them.
        lines_by_gid = defaultdict(list)
        current_gid = None
        with open(path_fasta, 'r') as fh:
            for raw_line in fh:
                text = raw_line.strip()
                if text.startswith('>'):
                    current_gid = text[1:]
                    continue
                lines_by_gid[current_gid].append(text)

        # No physical line of any record may exceed the wrap width.
        for seq_lines in lines_by_gid.values():
            for seq_line in seq_lines:
                self.assertLessEqual(len(seq_line), 4)
Example #3
0
    def add_genome(self, genome_id: str, path_faa: str, pfam_th: TopHitPfamFile, tigr_th: TopHitTigrFile):
        """Classify every expected marker for a genome and store the result.

        Each marker in ``self.marker_names`` is placed in exactly one bucket
        of ``self.genomes[genome_id]``:

        * ``'unq'`` - a single hit; stores the hit and its gene sequence.
        * ``'muq'`` - several hits whose gene sequences are all identical;
          stores the first hit of the descending sort and its sequence.
        * ``'mul'`` - several hits with differing gene sequences (None).
        * ``'mis'`` - no hit (None).

        An existing entry for ``genome_id`` is overwritten after a warning.

        Raises
        ------
        GTDBTkExit
            If the bucket sizes do not add up to the number of markers.
        """
        if genome_id in self.genomes:
            self.logger.warning(f'Genome already exists in copy number file: {genome_id}')

        # One dictionary per category, registered under the genome.
        unique, multiple, multiple_unique, missing = dict(), dict(), dict(), dict()
        self.genomes[genome_id] = {'unq': unique, 'mul': multiple, 'muq': multiple_unique, 'mis': missing}

        # Gene sequences from the faa file, with a trailing '*' removed.
        genes = read_fasta(path_faa, False)
        for gene_id in [gid for gid, seq in genes.items() if seq.endswith('*')]:
            genes[gene_id] = genes[gene_id][:-1]

        # Marker name -> hits, merged across both top hit files.
        hits_by_marker = self._merge_hit_files(pfam_th, tigr_th)

        for marker_id in self.marker_names:

            # No hit at all.
            if marker_id not in hits_by_marker:
                missing[marker_id] = None
                continue

            hits = hits_by_marker[marker_id]

            # Exactly one hit.
            if len(hits) <= 1:
                only_hit = hits[0]
                unique[marker_id] = {'hit': only_hit, 'seq': genes[only_hit.gene_id]}
                continue

            # Several hits: they only count as one when the genes are identical.
            distinct_seqs = {genes[hit.gene_id] for hit in hits}
            if len(distinct_seqs) != 1:
                multiple[marker_id] = None
                continue

            best_hit = sorted(hits, reverse=True)[0]
            multiple_unique[marker_id] = {'hit': best_hit, 'seq': genes[best_hit.gene_id]}

        # Sanity check - every marker must have landed in exactly one bucket.
        n_classified = len(unique) + len(multiple) + len(multiple_unique) + len(missing)
        if len(self.marker_names) != n_classified:
            raise GTDBTkExit('The marker set is inconsistent, please report this issue.')
Example #4
0
    def run(self, msa, mask, outf):
        """Trim every sequence of an MSA to the columns selected by a mask.

        Parameters
        ----------
        msa
            The MSA file, loaded with ``read_fasta``.
        mask : str
            Path to the mask file. Only its first line is used; the columns
            whose character is '1' are kept.
        outf : str
            Path of the FASTA file to write.

        Raises
        ------
        IndexError
            Raised by the column lookup when a sequence is shorter than a
            column kept by the mask.
        """
        # Read every input before opening the output: opening it first would
        # truncate it even when reading fails, or when it is also the input.
        dict_genomes = read_fasta(msa, False)
        with open(mask, 'r') as f:
            maskstr = f.readline()
        print(maskstr)
        print(len(maskstr))

        # The kept columns are the same for every sequence; find them once.
        keep_cols = [i for i, flag in enumerate(maskstr) if flag == '1']

        with open(outf, 'w') as outfwriter:
            for gid, seq in dict_genomes.items():
                aligned_seq = ''.join([seq[i] for i in keep_cols])
                outfwriter.write(">%s\n%s\n" % (gid, aligned_seq))
Example #5
0
    def _msa_filter_by_taxa(self, concatenated_file: str,
                            gtdb_taxonomy: Dict[str, Tuple[str, str, str, str,
                                                           str, str, str]],
                            taxa_filter: Optional[str],
                            outgroup_taxon: Optional[str]) -> Dict[str, str]:
        """Filter GTDB MSA to a subset of specified taxa.

        Parameters
        ----------
        concatenated_file
            The path to the MSA.
        gtdb_taxonomy
            A dictionary mapping the accession to the 7 rank taxonomy.
        taxa_filter
            A comma separated list of taxa to include. No filtering is done
            when this is None.
        outgroup_taxon
            If using an outgroup (de novo workflow), ensure this is retained.

        Returns
        -------
        Dict[str, str]
            The genome id to msa of those genomes specified in the filter.
            Genomes without an entry in ``gtdb_taxonomy`` are never removed.
        """

        msa = read_fasta(concatenated_file)
        msa_len = len(msa)
        self.logger.info(
            f'Read concatenated alignment for {msa_len:,} GTDB genomes.')

        if taxa_filter is None:
            return msa

        taxa_to_keep = set(taxa_filter.split(','))
        if outgroup_taxon is not None:
            taxa_to_keep.add(outgroup_taxon)

        filtered_genomes = 0
        for genome_id, taxa in gtdb_taxonomy.items():
            # Drop the genome when none of its ranks is in the filter.
            if taxa_to_keep.isdisjoint(taxa) and genome_id in msa:
                del msa[genome_id]
                filtered_genomes += 1

        # An empty alignment would otherwise raise ZeroDivisionError here.
        frac_filtered = filtered_genomes / msa_len if msa_len > 0 else 0.0

        msg = f'Filtered {frac_filtered:.2%} ({filtered_genomes:,}/{msa_len:,}) ' \
              f'taxa based on assigned taxonomy, {msa_len - filtered_genomes:,} taxa remain.'

        # Having nothing left after filtering deserves a warning.
        if len(msa) > 0:
            self.logger.info(msg)
        else:
            self.logger.warning(msg)

        return msa
Example #6
0
    def run(self, msa_file, marker_list):
        """Trim an MSA and write the retained sequences to disk.

        The MSA is loaded with ``read_fasta`` and handed, together with
        ``marker_list``, to ``self.trim``, which performs the column
        selection and returns the retained and the pruned sequences. The
        retained sequences are written in FASTA format to
        ``filtered_msa.faa`` inside ``self.output_dir``.
        """

        # Load the alignment.
        self.logger.info('Reading multiple sequence alignment.')
        msa = read_fasta(msa_file, False)
        self.logger.info('Read MSA for %d genomes.' % len(msa))

        filtered_seqs, pruned_seqs = self.trim(msa, marker_list)

        n_pruned = len(pruned_seqs)
        self.logger.info(
            'Removed %d taxa have amino acids in <%.1f%% of columns in filtered MSA.'
            % (n_pruned, self.min_perc_aa))

        # Write one FASTA record per retained sequence.
        out_path = os.path.join(self.output_dir, "filtered_msa.faa")
        with open(out_path, 'w') as filter_file:
            filter_file.writelines(">%s\n%s\n" % (gid, seq)
                                   for gid, seq in filtered_seqs.items())

        self.logger.info('Done.')
Example #7
0
    def trim_msa(self, untrimmed_msa, mask_type, maskid, output_file):
        """Trim the multiple sequence alignment using a mask.

        Parameters
        ----------
        untrimmed_msa : str
            The path to the untrimmed MSA.
        mask_type : str
            Which mask should be used: 'reference' or 'file' (user specified).
        maskid : str
            'bac' or 'arc' to select a reference mask when ``mask_type`` is
            'reference'; the path to the mask file when it is 'file'.
        output_file : str
            The path to the output trimmed MSA.

        Raises
        ------
        GTDBTkException
            If the ``mask_type`` / ``maskid`` combination is not recognised.
        IndexError
            Raised by the column lookup when a sequence is shorter than a
            column kept by the mask.
        """
        if maskid == 'bac' and mask_type == 'reference':
            mask = os.path.join(Config.MASK_DIR, Config.MASK_BAC120)
        elif maskid == 'arc' and mask_type == 'reference':
            mask = os.path.join(Config.MASK_DIR, Config.MASK_AR122)
        elif mask_type == 'file':
            mask = maskid
        else:
            self.logger.error('Command not understood.')
            raise GTDBTkException('Command not understood.')

        # Only the first line of the mask file is used.
        with open(mask, 'r') as f:
            maskstr = f.readline()

        # The kept columns are the same for every sequence; find them once.
        keep_cols = [i for i, flag in enumerate(maskstr) if flag == '1']

        # Read the MSA before opening the output: opening it first would
        # truncate it even when reading fails, or when it is also the input.
        dict_genomes = read_fasta(untrimmed_msa, False)

        with open(output_file, 'w') as outfwriter:
            for gid, seq in dict_genomes.items():
                aligned_seq = ''.join([seq[i] for i in keep_cols])
                outfwriter.write(">%s\n%s\n" % (gid, aligned_seq))
Example #8
0
    def _producer(self, genome_file):
        """Apply prodigal to genome with most suitable translation table.

        When ``self.called_genes`` is set the input file is copied verbatim
        to the amino acid gene file and Prodigal is not run. Otherwise
        Prodigal is run once per candidate translation table
        (``self.translation_table`` if set, else tables 4 and 11) inside a
        temporary directory, and the output of the selected table is copied
        to ``self.output_dir``. The temporary directory is always removed.

        Parameters
        ----------
        genome_file : str
            Fasta file for genome.

        Returns
        -------
        tuple or None
            ``(genome_id, aa_gene_file, nt_gene_file, gff_file,
            best_translation_table, coding_density_4, coding_density_11,
            prob_4, prob_11)``; values that were not computed are -1.
            None if the genome contains no sequences, Prodigal exits with a
            non-zero code, or the GFF parser reports no processed sequences.

        Raises
        ------
        OSError
            If the called genes file cannot be copied.
        """

        genome_id = remove_extension(genome_file)

        aa_gene_file = os.path.join(self.output_dir, genome_id + '_genes.faa')
        nt_gene_file = os.path.join(self.output_dir, genome_id + '_genes.fna')
        gff_file = os.path.join(self.output_dir, genome_id + '.gff')

        best_translation_table = -1
        table_coding_density = {4: -1, 11: -1}
        table_prob = {4: -1, 11: -1}
        if self.called_genes:
            # Copy without a shell so paths containing spaces or shell
            # metacharacters are handled, and a failed copy is not ignored.
            shutil.copyfile(os.path.abspath(genome_file), aa_gene_file)
        else:
            seqs = read_fasta(genome_file)

            if len(seqs) == 0:
                self.logger.warning(
                    'Cannot call Prodigal on an empty genome. Skipped: {}'.
                    format(genome_file))
                return None

            tmp_dir = tempfile.mkdtemp()
            try:
                # determine number of bases
                total_bases = sum(len(seq) for seq in seqs.values())

                # check if there are sufficient bases to calculate prodigal parameters
                if total_bases < 100000 or self.meta:
                    proc_str = 'meta'  # use best precalculated parameters
                else:
                    proc_str = 'single'  # estimate parameters from data

                # If this is a gzipped genome, re-write the uncompressed genome
                # file to disk (once; it is shared by every translation table).
                prodigal_input = genome_file
                if genome_file.endswith('.gz'):
                    prodigal_input = os.path.join(
                        tmp_dir,
                        os.path.basename(genome_file[0:-3]) + '.fna')
                    write_fasta(seqs, prodigal_input)

                # call genes under different translation tables
                if self.translation_table:
                    translation_tables = [self.translation_table]
                else:
                    translation_tables = [4, 11]

                translation_table_gffs = dict()
                tln_table_stats = dict()
                for translation_table in translation_tables:
                    table_dir = os.path.join(tmp_dir, str(translation_table))
                    os.makedirs(table_dir)
                    aa_gene_file_tmp = os.path.join(table_dir,
                                                    genome_id + '_genes.faa')
                    nt_gene_file_tmp = os.path.join(table_dir,
                                                    genome_id + '_genes.fna')

                    args = [
                        'prodigal', '-m', '-p', proc_str, '-q', '-f', 'gff',
                        '-g', str(translation_table), '-a', aa_gene_file_tmp,
                        '-d', nt_gene_file_tmp, '-i', prodigal_input
                    ]
                    if self.closed_ends:
                        args.append('-c')

                    self.logger.debug('{}: {}'.format(genome_id,
                                                      ' '.join(args)))

                    # stderr is merged into stdout, so proc_err is always None
                    # and the GFF output is taken from the combined stream.
                    # NOTE(review): without text mode proc_out is bytes on
                    # Python 3, while the GFF is written to a text-mode file
                    # below - confirm the intended Python version / mode.
                    proc = subprocess.Popen(args,
                                            stdout=subprocess.PIPE,
                                            stderr=subprocess.STDOUT)
                    proc_out, proc_err = proc.communicate()
                    gff_stdout = proc_out

                    translation_table_gffs[translation_table] = gff_stdout

                    if proc.returncode != 0:
                        self.logger.warning(
                            'Prodigal returned a non-zero exit code while processing: {}'
                            .format(genome_file))
                        return None

                    # determine coding density
                    prodigal_parser = ProdigalGeneFeatureParser(gff_stdout)

                    # Skip if no genes were called.
                    if prodigal_parser.n_sequences_processed() == 0:
                        self.logger.warning(
                            'No genes were called! Check the quality of your genome. Skipped: {}'
                            .format(genome_file))
                        return None

                    # Save the statistics for this translation table
                    prodigal_stats = prodigal_parser.generate_statistics()
                    tln_table_stats[translation_table] = prodigal_stats
                    table_coding_density[
                        translation_table] = prodigal_stats.coding_density

                # determine best translation table
                if not self.translation_table:

                    # Logistic classifier coefficients
                    b0 = 12.363017423768538
                    bi = np.array([
                        0.01212327382066545, -0.9250857181041326,
                        -0.10176647009345675, 0.7733711446656522,
                        0.6355731038236031, -0.1631355971443377,
                        -0.14713264317198863, -0.10320909026025472,
                        0.09621494439016824, 0.4992209080695785, 1.159933669041023,
                        -0.0507139271834123, 1.2619603455217179,
                        0.24392226222721214, -0.08567859197118802,
                        -0.18759562346413916, 0.13136209122186523,
                        -0.1399459561138417, 2.08086235029142, 0.6917662070950119
                    ])

                    # Scale x
                    scaler_mean = np.array([
                        0.0027036907781622732, -1.8082140490218692,
                        -8.511942254988097e-08, 19.413811775420918,
                        12.08719100126732, 249.89521467118365,
                        0.0011868456444391487, -0.0007358432829349235,
                        0.004750880986023392, -0.04096159411654551,
                        -0.12505492579693805, -0.03749033894554058,
                        0.13053986993752234, -0.15914556336256136,
                        -0.6075506034967058, 0.06704648371665446,
                        0.04316693333324335, 0.26905236546875266,
                        0.010326462563249823, 333.3320678912514
                    ])
                    scaler_scale = np.array([
                        0.08442772272873166, 2.043313786484819,
                        2.917510891467501e-05, 22.577812640992242,
                        12.246767248868036, 368.87834547339907,
                        0.0014166252200216657, 0.0014582164250905056,
                        0.025127203671053467, 0.5095427815162036,
                        0.2813128128116135, 0.2559877920464989, 1.274371529860827,
                        0.7314782174742842, 1.6885750374356985,
                        0.17019369029012987, 0.15376309021975043,
                        0.583965556283342, 0.025076680822882474, 544.3648797867784
                    ])

                    # The features are the differences between the statistics
                    # of table 11 and table 4, standardised with the scaler.
                    xi = np.array(tln_table_stats[11]) - np.array(
                        tln_table_stats[4])
                    xi -= scaler_mean
                    xi /= scaler_scale

                    # With xi all zero only the intercept remains, P(11) ~ 1.
                    prob_tbl_11 = 1 / (1 + np.exp(-1 * (b0 + (bi * xi).sum())))
                    best_translation_table = 11 if prob_tbl_11 >= 0.5 else 4
                    table_prob[4] = 1.0 - prob_tbl_11
                    table_prob[11] = prob_tbl_11

                else:
                    best_translation_table = self.translation_table

                # keep the output of the selected translation table
                best_dir = os.path.join(tmp_dir, str(best_translation_table))
                shutil.copyfile(
                    os.path.join(best_dir, genome_id + '_genes.faa'),
                    aa_gene_file)
                shutil.copyfile(
                    os.path.join(best_dir, genome_id + '_genes.fna'),
                    nt_gene_file)
                with open(gff_file, 'w') as f:
                    f.write(translation_table_gffs[best_translation_table])
            finally:
                # clean up temporary files on every exit path, including the
                # early returns above and any exception
                shutil.rmtree(tmp_dir, ignore_errors=True)

        return (genome_id, aa_gene_file, nt_gene_file, gff_file,
                best_translation_table, table_coding_density[4],
                table_coding_density[11], table_prob[4], table_prob[11])
Example #9
0
    def _report_identified_marker_genes(self, gene_dict, outdir, prefix):
        """Report statistics for identified marker genes.

        For every genome the PFAM and TIGRFAM top hit files are parsed and
        each bac120 / ar122 marker is reported as unique (hit once),
        multiple (hit more than once) or missing. One tab separated row per
        genome is written to the bacterial and the archaeal marker summary,
        and the genome's best translation table to the translation table
        summary. Finally a symlink to each summary is created in ``outdir``.

        Parameters
        ----------
        gene_dict : dict
            Maps a genome id to a dict from which the
            ``"best_translation_table"`` entry is reported.
        outdir : str
            The output directory that the summary paths are relative to.
        prefix : str
            The prefix substituted into the summary file names.
        """

        def _marker_ids(marker_config):
            """Flatten a marker configuration into ids without HMM suffix."""
            return [
                marker.replace(".HMM", "").replace(".hmm", "")
                for db_marker in marker_config.keys()
                for marker in marker_config[db_marker]
            ]

        def _record_hit(hits, markerid, genename):
            """Register a hit; a second hit flags the marker as multi-hit."""
            if markerid in hits:
                hits[markerid]["multihit"] = True
            else:
                hits[markerid] = {"gene": genename, "multihit": False}

        def _summary_row(genome_id, marker_ids, hits):
            """Build the summary line of one genome for one marker set."""
            unique_genes, multi_hits, missing_genes = [], [], []
            for mid in marker_ids:
                if mid not in hits:
                    missing_genes.append(mid)
                elif hits[mid]["multihit"]:
                    multi_hits.append(mid)
                else:
                    unique_genes.append(mid)
            return "{0}\t{1}\t{2}\t{3}\t{4}\t{5}\t{6}\n".format(
                genome_id, len(unique_genes), len(multi_hits),
                len(missing_genes), ','.join(unique_genes),
                ','.join(multi_hits), ','.join(missing_genes))

        header = "Name\tnumber_unique_genes\tnumber_multiple_genes\tnumber_missing_genes\tlist_unique_genes\tlist_multiple_genes\tlist_missing_genes\n"

        # top hit file suffix of each marker database
        marker_suffixes = (PFAM_TOP_HIT_SUFFIX, TIGRFAM_TOP_HIT_SUFFIX)

        # The lists keep the reporting order; the sets give constant time
        # membership tests while parsing the top hit files.
        marker_bac_list_original = _marker_ids(Config.BAC120_MARKERS)
        marker_arc_list_original = _marker_ids(Config.AR122_MARKERS)
        bac_markers = set(marker_bac_list_original)
        arc_markers = set(marker_arc_list_original)

        tln_path = os.path.join(outdir,
                                PATH_TLN_TABLE_SUMMARY.format(prefix=prefix))
        bac_path = os.path.join(
            outdir, PATH_BAC120_MARKER_SUMMARY.format(prefix=prefix))
        arc_path = os.path.join(
            outdir, PATH_AR122_MARKER_SUMMARY.format(prefix=prefix))

        # The context manager closes all three files even if parsing fails.
        with open(tln_path, "w") as translation_table_file, \
                open(bac_path, "w") as bac_outfile, \
                open(arc_path, "w") as arc_outfile:

            bac_outfile.write(header)
            arc_outfile.write(header)

            for db_genome_id, info in gene_dict.items():
                gene_bac_dict, gene_arc_dict = {}, {}

                for marker_suffix in marker_suffixes:
                    tophit_path = os.path.join(
                        outdir, DIR_MARKER_GENE, db_genome_id,
                        '{}{}'.format(db_genome_id, marker_suffix))

                    with open(tophit_path) as tp:
                        # first line is header line
                        tp.readline()

                        for line_tp in tp:
                            linelist = line_tp.split("\t")
                            genename = linelist[0]

                            # The second column holds ';' separated hits,
                            # each a ',' separated record led by the marker.
                            for each_mark in linelist[1].split(";"):
                                markerid = each_mark.split(",")[0]

                                if markerid in bac_markers:
                                    _record_hit(gene_bac_dict, markerid,
                                                genename)
                                if markerid in arc_markers:
                                    _record_hit(gene_arc_dict, markerid,
                                                genename)

                bac_outfile.write(
                    _summary_row(db_genome_id, marker_bac_list_original,
                                 gene_bac_dict))
                arc_outfile.write(
                    _summary_row(db_genome_id, marker_arc_list_original,
                                 gene_arc_dict))
                translation_table_file.write('{}\t{}\n'.format(
                    db_genome_id, info.get("best_translation_table")))

        # Create a symlink to store the summary files in the root.
        for summary in (PATH_BAC120_MARKER_SUMMARY, PATH_AR122_MARKER_SUMMARY,
                        PATH_TLN_TABLE_SUMMARY):
            summary_path = summary.format(prefix=prefix)
            symlink_f(summary_path,
                      os.path.join(outdir, os.path.basename(summary_path)))
Example #10
0
    def run(self, outf):

        # Check if all directories are here
        actual_dirs = os.listdir(self.pack_dir)
        if len(actual_dirs) != len(self.list_dirsinpackage):
            print('ERROR:')
        if len(set(actual_dirs) & set(self.list_dirsinpackage)) != len(
                self.list_dirsinpackage):
            print('ERROR:')

        with open(os.path.join(self.pack_dir, 'metadata',
                               'metadata.txt')) as metafile:
            for line in metafile:
                if line.startswith('VERSION_DATA'):
                    version = line.strip().split('=')[1]

        # List genomes in fastani folder
        list_genomes = [
            os.path.basename(x) for x in glob.glob(
                os.path.join(self.pack_dir, 'fastani', 'database/*.gz'))
        ]
        list_genomes = [
            x.replace('_genomic.fna.gz',
                      '').replace('GCA_',
                                  'GB_GCA_').replace('GCF_', 'RS_GCF_')
            for x in list_genomes
        ]

        # Archaeal genome MSA is untrimmed
        ar_msa_file = glob.glob(os.path.join(self.pack_dir,
                                             'msa/*ar53.faa'))[0]
        ar_msa = read_fasta(ar_msa_file)
        first_seq = ar_msa.get(list(ar_msa.keys())[0])
        if len(first_seq) != 32675:
            print('ERROR: len(first_seq) != 32675')

        # Bacterial genome MSA is untrimmed
        bac_msa_file = glob.glob(os.path.join(self.pack_dir,
                                              'msa/*bac120.faa'))[0]
        bac_msa = read_fasta(bac_msa_file)
        first_seq = bac_msa.get(list(bac_msa.keys())[0])
        if len(first_seq) != 41155:
            print('ERROR: len(first_seq) != 41155')

        # Bacterial MASK is same length as the untrimmed bacterial genomes
        bac_mask_file = glob.glob(
            os.path.join(self.pack_dir, 'masks/*bac120.mask'))[0]
        bac_mask = ''
        with open(bac_mask_file) as bmf:
            bac_mask = bmf.readline()
        if len(bac_mask) != 41155:
            print('ERROR: len(bac_mask) != 41155')

        # Archaeal MASK is same length as the untrimmed archaeal genomes
        ar_mask_file = glob.glob(
            os.path.join(self.pack_dir, 'masks/*ar53.mask'))[0]
        ar_mask = ''
        with open(ar_mask_file) as amf:
            ar_mask = amf.readline()
        if len(ar_mask) != 32675:
            print('ERROR: len(ar_mask) != 32675')

        # Archaeal Pplacer MSA should have the same number of genomes as the
        # Archaeal untrimmed MSA
        ar_pplacer_msa_file = glob.glob(
            os.path.join(self.pack_dir, 'pplacer',
                         'gtdb_' + version + '_ar53.refpkg',
                         'ar53_msa_r95.faa'))[0]
        ar_pplacer_msa = read_fasta(ar_pplacer_msa_file)
        if len(ar_pplacer_msa) != len(ar_msa):
            print('ERROR: len(ar_pplacer_msa) != len(ar_msa)')
            print('len(ar_pplacer_msa): {}'.format(len(ar_pplacer_msa)))
            print('len(ar_msa): {}'.format(len(ar_msa)))
            print('difference genomes: {}'.format(
                list(
                    set(ar_msa.keys()).difference(set(
                        ar_pplacer_msa.keys())))))
        first_seq = ar_pplacer_msa.get(list(ar_pplacer_msa.keys())[0])
        # Archaeal Pplacer MSA should have the same length as the Archaeal mask
        if len(first_seq) != len([a for a in ar_mask if a == '1']):
            print(
                'ERROR: len(first_seq) != len([a for a in ar_mask if a ==1])')
            print('len(first_seq): {}'.format(len(first_seq)))
            print('len([a for a in ar_mask if a ==1]): {}'.format(
                len([a for a in ar_mask if a == '1'])))

        # Bacterial Pplacer MSA should have the same number of genomes as the
        # Bacterial untrimmed MSA
        bac_pplacer_msa_file = os.path.join(
            self.pack_dir, 'pplacer', 'gtdb_' + version + '_bac120.refpkg',
            'bac120_msa_r95.faa')
        bac_pplacer_msa = read_fasta(bac_pplacer_msa_file)
        if len(bac_pplacer_msa) != len(bac_msa):
            print('ERROR: len(bac_pplacer_msa) != len(bac_msa)')
            print('len(bac_pplacer_msa): {}'.format(len(bac_pplacer_msa)))
            print('len(bac_msa): {}'.format(len(bac_msa)))
            print('difference genomes: {}'.format(
                list(
                    set(bac_msa.keys()).difference(set(
                        bac_pplacer_msa.keys())))))
        first_seq = bac_pplacer_msa.get(list(bac_pplacer_msa.keys())[0])
        # Bacterial Pplacer MSA should have the same length as the Bacterial
        # mask
        if len(first_seq) != len([a for a in bac_mask if a == '1']):
            print(
                'ERROR: len(first_seq) != len([a for a in bac_mask if a ==1])')
            print('len(first_seq): {}'.format(len(first_seq)))
            print('len([a for a in bac_mask if a ==1]): {}'.format(
                len([a for a in bac_mask if a == '1'])))

        # Archaeal Tree should have the same number of leaves than nomber of
        # genomes in the MSA
        arc_tree = dendropy.Tree.get_from_path(os.path.join(
            self.pack_dir, 'pplacer', 'gtdb_' + version + '_ar53.refpkg',
            'ar53_' + version + '_unroot.pplacer.tree'),
                                               schema='newick',
                                               rooting='force-rooted',
                                               preserve_underscores=True)
        list_leaves = arc_tree.leaf_nodes()
        if len(list_leaves) != len(ar_pplacer_msa):
            print('ERROR: len(list_leaves) != len(ar_pplacer_msa)')
            print('len(list_leaves): {}'.format(len(list_leaves)))
            print('len(ar_pplacer_msa): {}'.format(len(ar_pplacer_msa)))

        # Bacterial Tree should have the same number of leaves than nomber of
        # genomes in the MSA
        bac_tree = dendropy.Tree.get_from_path(os.path.join(
            self.pack_dir, 'pplacer', 'gtdb_' + version + '_bac120.refpkg',
            'bac120_' + version + '_unroot.pplacer.tree'),
                                               schema='newick',
                                               rooting='force-rooted',
                                               preserve_underscores=True)
        list_leaves = bac_tree.leaf_nodes()
        if len(list_leaves) != len(bac_pplacer_msa):
            print('ERROR: len(list_leaves) != len(bac_pplacer_msa)')
            print('len(list_leaves): {}'.format(len(list_leaves)))
            print('len(bac_pplacer_msa): {}'.format(len(bac_pplacer_msa)))

        # Taxonomy file should have as many genomes as bac120 and ar53 MSA
        # combined
        tax_file = os.path.join(self.pack_dir, 'taxonomy', 'gtdb_taxonomy.tsv')
        tax_dict = {}
        with open(tax_file) as tf:
            for line in tf:
                infos = line.strip().split('\t')
                tax_dict[infos[0]] = infos[1]
        if len(tax_dict) != (len(ar_msa) + len(bac_msa)):
            print('ERROR: len(tax_dict) != (len(ar_msa) + len(bac_msa))')
            print('len(tax_dict): {}'.format(len(tax_dict)))
            print('len(ar_msa) + len(bac_msa): {}'.format(
                len(ar_msa) + len(bac_msa)))

        # Radii file should have as many genomes as bac120 and ar53 MSA
        # combined
        radii_file = os.path.join(self.pack_dir, 'radii', 'gtdb_radii.tsv')
        radii_dict = {}
        with open(radii_file) as rf:
            for line in rf:
                infos = line.strip().split('\t')
                radii_dict[infos[1]] = infos[2]
        if len(radii_dict) != (len(ar_msa) + len(bac_msa)):
            print('ERROR: len(radii_dict) != (len(ar_msa) + len(bac_msa))')
            print('len(radii_dict): {}'.format(len(radii_dict)))
            print('len(ar_msa) + len(bac_msa): {}'.format(
                len(ar_msa) + len(bac_msa)))
        if len(
                set(radii_dict.keys()).symmetric_difference(
                    set(tax_dict.keys()))) != 0:
            print(
                'ERROR: len(set(radii_dict.keys()).symmetric_difference(tax_dict.keys()))'
            )
            print(
                'set(radii_dict.keys()).symmetric_difference(tax_dict.keys()): {}'
                .format(
                    set(radii_dict.keys()).symmetric_difference(
                        set(tax_dict.keys()))))

        if len(list_genomes) != len(radii_dict):
            print('ERROR: len(list_genomes) != len(radii_dict)')
            print('Missing genomes {}'.format(
                set(list_genomes) ^ set(radii_dict.keys())))
            print('len(list_genomes): {}'.format(len(list_genomes)))
            print('len(radii_dict): {}'.format(len(radii_dict)))

        print('\n\nVERSION: {}'.format(version))
        print('Length trimmed bac120 MSA: {}'.format(
            len(bac_pplacer_msa.get(list(bac_pplacer_msa.keys())[0]))))
        print('Length trimmed ar53 MSA: {}'.format(
            len(ar_pplacer_msa.get(list(ar_pplacer_msa.keys())[0]))))
        print('')
        print('Number of genomes in fastani/database: {}'.format(
            len(list_genomes)))
        print('Number of genomes in radii file: {}'.format(len(radii_dict)))
        print('Number of genomes in taxonomy file: {}'.format(len(tax_dict)))

        print('Would you like to archive the folder? ')
        # raw_input returns the empty string for "enter"

        yes = {'yes', 'y', 'yep', ''}
        no = {'no', 'n'}

        final_choice = False
        choice = input().lower()
        if choice in yes:
            with tarfile.open(outf, "w:gz") as tar:
                packdir = copy.copy(self.pack_dir)
                if packdir.endswith('/'):
                    packdir = packdir[:-1]
                tar.add(self.pack_dir, arcname=os.path.basename(packdir))
        elif choice in no:
            return False
        else:
            sys.stdout.write("Please respond with 'yes' or 'no'")
Example #11
0
    def _run_multi_align(self, db_genome_id, path, marker_set_id):
        """
        Returns the concatenated marker sequence for a specific genome.

        For every marker database (PFAM, TIGRFAM) the top-hit file derived
        from ``path`` is parsed and, for each marker of the requested marker
        set, the gene with the highest bitscore is kept and handed to
        ``self._run_align``. Markers without a hit in any database are
        represented by a run of "-" whose length is given by
        ``self._get_hmm_size`` for the marker HMM.

        :param db_genome_id: Selected genome
        :param path: Path to the fasta file for the genome; it is expected to
            contain ``self.protein_file_suffix``, which is swapped for the
            top-hit suffixes to locate the top-hit files
        :param marker_set_id: Unique ID of marker set to use for alignment
            ("bac120", "ar122" or "rps23"); any other value selects no marker
        :return: the per-marker sequences concatenated in sorted marker-id
            order
        """

        # directory holding the individual HMM files of each marker database
        marker_paths = {
            "PFAM":
            os.path.join(self.pfam_hmm_dir, 'individual_hmms'),
            "TIGRFAM":
            os.path.join(os.path.dirname(self.tigrfam_hmm_dir),
                         'individual_hmms')
        }

        # only the attribute of the requested marker set is touched
        if marker_set_id == "bac120":
            marker_set = self.bac120_markers
        elif marker_set_id == "ar122":
            marker_set = self.ar122_markers
        elif marker_set_id == "rps23":
            marker_set = self.rps23_markers
        else:
            marker_set = {}

        # marker id (HMM file name without its extension) -> HMM file path
        marker_dict_original = {}
        for db_marker in sorted(marker_set):
            for marker in marker_set[db_marker]:
                marker_id = marker.replace(".HMM", "").replace(".hmm", "")
                marker_dict_original[marker_id] = os.path.join(
                    marker_paths[db_marker], marker)

        # marker id -> aligned sequence (or gaps) for this genome
        aligned_markers = {}

        marker_dbs = {
            "PFAM": self.pfam_top_hit_suffix,
            "TIGRFAM": self.tigrfam_top_hit_suffix
        }
        for marker_suffix in marker_dbs.values():
            # get all gene sequences
            genome_path = str(path)
            tophit_path = genome_path.replace(self.protein_file_suffix,
                                              marker_suffix)

            # we load the list of all the genes detected in the genome
            protein_file = tophit_path.replace(marker_suffix,
                                               self.protein_file_suffix)
            all_genes_dict = read_fasta(protein_file, False)

            # Prodigal adds an asterisk at the end of each called gene.
            # These asterisks sometimes appear in the MSA, which can be an
            # issue for some software downstream. endswith() also copes with
            # empty sequences, which seq[-1] would not.
            for seq_id, seq in list(all_genes_dict.items()):
                if seq.endswith('*'):
                    all_genes_dict[seq_id] = seq[:-1]

            # we read the tophit file line by line and keep, for each marker
            # of interest, the gene with the best bitscore
            gene_dict = {}
            with open(tophit_path) as tp:
                # first line is header line
                tp.readline()
                for line_tp in tp:
                    linelist = line_tp.split("\t")
                    genename = linelist[0]

                    # hits are ';'-separated, each hit is
                    # "marker,evalue,bitscore"
                    for each_gene in linelist[1].split(";"):
                        hit_fields = each_gene.split(",")
                        markerid = hit_fields[0]
                        if markerid not in marker_dict_original:
                            continue
                        bitscore = hit_fields[2].strip()

                        previous_hit = gene_dict.get(markerid)
                        # bitscores are read as text: compare them as numbers,
                        # a plain string comparison would rank "9.5" above
                        # "100.2"
                        if previous_hit is None or float(
                                previous_hit["bitscore"]) < float(bitscore):
                            gene_dict[markerid] = {
                                "marker_path":
                                marker_dict_original.get(markerid),
                                "gene": genename,
                                "gene_seq": all_genes_dict.get(genename),
                                "bitscore": bitscore
                            }

            # markers with no hit here and no sequence from a previous
            # database are filled with gaps; a later database may still
            # overwrite them through the update below
            for mid, mpath in marker_dict_original.items():
                if mid not in gene_dict and mid not in aligned_markers:
                    aligned_markers[mid] = "-" * self._get_hmm_size(mpath)

            aligned_markers.update(self._run_align(gene_dict, db_genome_id))

        # we concatenate the aligned markers together, in sorted marker order
        return "".join(aligned_markers[markid]
                       for markid in sorted(aligned_markers))
    def run(self, dirin, dirout, gtr, release):
        """Rename genome files for FastANI and relabel the related data.

        Genomes whose identifier starts with "U_" are renamed to their NCBI
        accession ("GB_" + accession) when one is listed in
        ``uba_ncbi_accessions.tsv`` (read from the directory of this module),
        otherwise to their UBA name. The taxonomy file, the FastANI genome
        files, the concatenated MSAs, and the pplacer tree, trimmed MSA and
        log file of each domain are rewritten under ``dirout``.

        :param dirin: input directory containing ``gtdb_taxonomy.tsv``,
            ``fastani/`` and, per domain, ``gtdb_concatenated.faa`` and a
            ``pplacer/`` directory with exactly one ``*.tree``, ``*.fa`` and
            ``*.log`` file
        :param dirout: output directory
        :param gtr: tab-separated file whose first column lists the genomes
            to retain
        :param release: release label (string) used in output file names
        :raises SystemExit: if a pplacer directory does not contain exactly
            one tree, fasta or log file
        :raises KeyError: if a "U_" genome has no entry in the UBA mapping
        """

        # get list of genomes to retain (based on genome list 1014)
        with open(gtr) as f:
            genomes_to_retain = {line.strip().split('\t')[0] for line in f}

        print('Genome to retain: %d' % len(genomes_to_retain))

        # get mapping from published UBA genomes to NCBI accessions
        script_dir = os.path.realpath(os.path.join(
            os.getcwd(), os.path.dirname(__file__)))

        uba_acc = {}
        with open(os.path.join(script_dir, 'uba_ncbi_accessions.tsv')) as ub:
            for line in ub:
                line_split = line.strip().split('\t')
                names = {"uba": line_split[1]}
                if line_split[2] != "None":
                    names["gca"] = 'GB_' + line_split[2]
                uba_acc[line_split[0]] = names

        def published_name(gid, strip_prefix=False):
            """Return the public name of a 'U_' genome (accession or UBA)."""
            names = uba_acc.get(gid)
            if names is None:
                raise KeyError(
                    'No UBA/NCBI mapping found for genome {}'.format(gid))
            if "gca" in names:
                accession = names["gca"]
                # FastANI file names carry the accession without 'GB_'
                return accession[len('GB_'):] if strip_prefix else accession
            return names["uba"]

        def write_renamed_fasta(seqs, out_path):
            """Write the retained sequences to out_path under public names."""
            with open(out_path, 'w') as fout:
                for gid, seq in seqs.items():
                    if gid not in genomes_to_retain:
                        continue
                    name = published_name(gid) if gid.startswith(
                        "U_") else gid
                    fout.write(">{0}\n{1}\n".format(name, seq))

        def single_file(pattern):
            """Return the only file matching pattern, exit with an error otherwise."""
            matches = glob.glob(pattern)
            if len(matches) != 1:
                # a message makes sys.exit use a non-zero exit status
                sys.exit('Error: expected exactly one file matching {}, '
                         'found {}.'.format(pattern, len(matches)))
            return matches[0]

        # renaming taxonomy:
        with open(os.path.join(dirout, 'gtdb_taxonomy.tsv'), 'w') as taxout, \
                open(os.path.join(dirin, 'gtdb_taxonomy.tsv')) as gt:
            for line in gt:
                info = line.strip().split("\t")
                if info[0] not in genomes_to_retain:
                    continue
                if info[0].startswith("U_"):
                    taxout.write("{0}\t{1}\n".format(
                        published_name(info[0]), info[1]))
                else:
                    taxout.write(line)

        # renaming genome files for fastani
        fastanis = glob.glob(os.path.join(dirin, 'fastani', "*"))
        fastani_dir = os.path.join(dirout, 'fastani')
        if not os.path.exists(fastani_dir):
            os.makedirs(fastani_dir)
        for genome in fastanis:
            filenamef = os.path.basename(genome).replace("_genomic.fna", "")
            if filenamef.startswith("U_"):
                filenamef = published_name(filenamef, strip_prefix=True)
            copyfile(genome, os.path.join(
                fastani_dir, filenamef + "_genomic.fna"))

        # text replacements applied to every line of the pplacer log files;
        # they do not depend on the domain, so they are built only once
        log_replacements = [(gid + ":", published_name(gid) + ":")
                            for gid in uba_acc]

        for dom in ['bac120', 'ar122']:
            # MSA renaming
            msadir = os.path.join(dirout, dom, 'msa')
            if not os.path.exists(msadir):
                os.makedirs(msadir)
            msa_dict = read_fasta(os.path.join(
                dirin, dom, 'gtdb_concatenated.faa'))
            write_renamed_fasta(msa_dict, os.path.join(
                msadir, 'gtdb_r' + release + '_' + dom + '.faa'))

            # PPLACER renaming
            pplacerdir = os.path.join(dirout, dom, 'pplacer')
            if not os.path.exists(pplacerdir):
                os.makedirs(pplacerdir)

            pplacer_in = os.path.join(dirin, dom, 'pplacer')
            treef = single_file(os.path.join(pplacer_in, "*.tree"))
            seqfile = single_file(os.path.join(pplacer_in, "*.fa"))
            logfile = single_file(os.path.join(pplacer_in, "*.log"))

            # produce corrected tree
            tree = dendropy.Tree.get_from_path(treef,
                                               schema='newick',
                                               rooting='force-rooted',
                                               preserve_underscores=True)
            for n in tree.leaf_node_iter():
                if n.taxon.label.startswith("U_"):
                    n.taxon.label = published_name(n.taxon.label)
            tree.write_to_path(os.path.join(pplacerdir,
                                            dom + "_r" + release + ".tree"),
                               schema='newick',
                               suppress_rooting=True,
                               unquoted_underscores=True)

            # trimmed MSA used by pplacer
            write_renamed_fasta(read_fasta(seqfile), os.path.join(
                pplacerdir, 'trimmed_msa_' + dom + '.faa'))

            # fitting log: genome ids are followed by ':' in this file
            with open(os.path.join(pplacerdir,
                                   'fitting_' + dom + '.log'), 'w') as logoutf, \
                    open(logfile) as logfin:
                for line in logfin:
                    for old, new in log_replacements:
                        line = line.replace(old, new)
                    logoutf.write(line)
Example #13
0
    def get_high_pplacer_taxonomy(self, out_dir, marker_set_id, prefix,
                                  user_msa_file, tree):
        """Parse the pplacer tree and write the partial taxonomy for each user genome based on their placements

        Parameters
        ----------
        out_dir : output directory
        prefix : desired prefix for output files
        marker_set_id : marker set id; only 'bac120' is accepted by this method
        user_msa_file : msa file listing all user genomes for a certain domain
        tree : pplacer tree including the user genomes

        Returns
        -------
        dictionary[genome_label] = dictionary with the keys 'tk_tax_red',
        'tk_tax_terminal', 'pplacer_tax' and 'rel_dist'

        Raises
        ------
        GenomeMarkerSetUnknown
            If marker_set_id is not 'bac120'.

        """
        results = {}
        # The directory is created here; the output file object below is built
        # from out_dir and prefix.
        out_root = os.path.join(out_dir, 'classify', 'intermediate_results')
        make_sure_path_exists(out_root)

        if marker_set_id == 'bac120':
            out_pplacer = PplacerHighClassifyFile(out_dir, prefix)
        else:
            self.logger.error('There was an error determining the marker set.')
            raise GenomeMarkerSetUnknown

        # Passed unchanged to the two classification helpers below.
        red_bac_dict = Config.RED_DIST_BAC_DICT

        # We get the pplacer taxonomy for comparison
        user_genome_ids = set(read_fasta(user_msa_file).keys())
        for leaf in tree.leaf_node_iter():

            # Per-leaf state of the terminal-branch test (see the walk below).
            is_on_terminal_branch = False
            terminal_branch_test = False
            term_branch_taxonomy = ''
            # Only leaves that are user genomes are classified.
            if leaf.taxon.label in user_genome_ids:
                pplacer_row = PplacerHighClassifyRow()
                taxa = []
                cur_node = leaf
                current_rel_dist = 1.0
                # every user genomes has a RED value of one assigned to it
                # The walk stops once cur_node has no parent, so the label of
                # the root node itself is never read.
                while cur_node.parent_node:
                    # we go up the tree from the user genome
                    if hasattr(
                            cur_node, 'rel_dist'
                    ) and current_rel_dist == 1.0 and cur_node.rel_dist < 1.0:
                        # if the parent node of the current genome has a red distance,
                        # it means it is part of the reference tree
                        # we store the first RED value encountered in the
                        # tree
                        current_rel_dist = cur_node.rel_dist
                    if cur_node.is_internal():
                        # We check if the genome is placed on a terminal
                        # branch

                        # The test is settled at the first internal ancestor
                        # with at least one non-user leaf below it: exactly
                        # one such leaf means a terminal branch, more than one
                        # means an internal branch. With none, the test is
                        # repeated on the next ancestor.
                        if not terminal_branch_test:
                            child_genomes = [
                                nd.taxon.label for nd in cur_node.leaf_nodes()
                                if nd.taxon.label not in user_genome_ids
                            ]
                            if len(child_genomes) == 1:
                                is_on_terminal_branch = True
                                term_branch_taxonomy = self.gtdb_taxonomy.get(
                                    child_genomes[0])
                                terminal_branch_test = True
                            if len(child_genomes) > 1:
                                terminal_branch_test = True
                    # While going up the tree we store the taxonomy
                    # information
                    _support, taxon, _aux_info = parse_label(cur_node.label)
                    if taxon:
                        # Ranks are appended lowest first; the list is
                        # reversed again when taxa_str is built.
                        for t in taxon.split(';')[::-1]:
                            taxa.append(t.strip())
                    cur_node = cur_node.parent_node

                taxa_str = ';'.join(taxa[::-1])

                pplacer_tax = str(taxa_str)

                taxa_str_terminal, taxa_str_red = '', ''

                if is_on_terminal_branch:
                    # some rank may be missing from going up the tree.
                    # if the genome is on a terminal branch,
                    # we can select the taxonomy from the reference leaf to get the low level of the taxonomy
                    # we select down to genus
                    # NOTE(review): the slices below drop the last entry of
                    # term_branch_taxonomy; presumably that entry is the
                    # species rank - confirm against self.gtdb_taxonomy.
                    if len(taxa) > 1:
                        tax_of_leaf = term_branch_taxonomy[
                            term_branch_taxonomy.
                            index(taxa_str.split(';')[-1]) + 1:-1]
                    else:
                        tax_of_leaf = term_branch_taxonomy[1:-1]
                        taxa_str = 'd__Bacteria'

                    taxa_str_terminal = self._classify_on_terminal_branch(
                        tax_of_leaf, current_rel_dist,
                        taxa_str.split(';')[-1][0:3], term_branch_taxonomy,
                        red_bac_dict)

                # Second walk: find the closest ancestor whose label carries a
                # taxon.
                cur_node = leaf
                parent_taxon_node = cur_node.parent_node
                _support, parent_taxon, _aux_info = parse_label(
                    parent_taxon_node.label)

                # NOTE(review): if no ancestor carries a taxon,
                # parent_taxon_node becomes None inside the loop and
                # `.label` raises AttributeError before the loop condition is
                # checked again.
                while parent_taxon_node is not None and not parent_taxon:
                    parent_taxon_node = parent_taxon_node.parent_node
                    _support, parent_taxon, _aux_info = parse_label(
                        parent_taxon_node.label)

                # if the node represents multiple ranks, we select the lowest one
                # i.e. if node is p__A;c__B;o__C we pick o__
                parent_rank = parent_taxon.split(";")[-1]

                # When the closest labelled ancestor is a genus, this block is
                # skipped and taxa_str_red keeps its initial empty value.
                if parent_rank[0:3] != 'g__':
                    node_in_ref_tree = cur_node
                    # Climb until the subtree contains at least one reference
                    # genome (the same list is rebuilt below as list_subnode).
                    while len([
                            childnd.taxon.label.replace("'", '')
                            for childnd in node_in_ref_tree.leaf_iter()
                            if childnd.taxon.label in self.reference_ids
                    ]) == 0:
                        node_in_ref_tree = node_in_ref_tree.parent_node
                    # we select a node of the reference tree

                    # we select the child rank (if parent_rank = 'c__'
                    # child rank will be 'o__)'
                    child_rk = self.order_rank[
                        self.order_rank.index(parent_rank[0:3]) + 1]

                    # get all reference genomes under the current node
                    list_subnode = [
                        childnd.taxon.label.replace("'", '')
                        for childnd in node_in_ref_tree.leaf_iter()
                        if childnd.taxon.label in self.reference_ids
                    ]

                    # get all names for the child rank
                    list_ranks = [
                        self.gtdb_taxonomy.get(name)[self.order_rank.index(
                            child_rk)] for name in list_subnode
                    ]

                    # if there is just one rank name
                    if len(set(list_ranks)) == 1:
                        child_taxons = []
                        child_rel_dist = None
                        # Take the first internal node (preorder) whose taxon
                        # label starts with the child rank prefix; if none is
                        # found, child_taxons stays empty and child_rel_dist
                        # stays None.
                        for subranknd in node_in_ref_tree.preorder_iter():
                            _support, subranknd_taxon, _aux_info = parse_label(
                                subranknd.label)
                            if subranknd.is_internal(
                            ) and subranknd_taxon is not None and subranknd_taxon.startswith(
                                    child_rk):
                                child_taxons = subranknd_taxon.split(";")
                                child_taxon_node = subranknd
                                child_rel_dist = child_taxon_node.rel_dist
                                break

                        taxa_str_red, taxa_str_terminal = self._classify_on_internal_branch(
                            leaf.taxon.label, child_taxons, current_rel_dist,
                            child_rel_dist, node_in_ref_tree, parent_rank,
                            child_rk, taxa_str, taxa_str_terminal,
                            is_on_terminal_branch, red_bac_dict)
                    else:
                        # several names exist for the child rank: the
                        # taxonomy collected while walking up the tree is kept
                        taxa_str_red = taxa_str

                # The same standardised values are stored in the returned
                # dictionary and in the row of the output file.
                results[leaf.taxon.label] = {
                    "tk_tax_red":
                    standardise_taxonomy(taxa_str_red, 'bac120'),
                    "tk_tax_terminal":
                    standardise_taxonomy(taxa_str_terminal, 'bac120'),
                    "pplacer_tax":
                    standardise_taxonomy(pplacer_tax, 'bac120'),
                    'rel_dist':
                    current_rel_dist
                }

                pplacer_row.gid = leaf.taxon.label
                pplacer_row.gtdb_taxonomy_red = standardise_taxonomy(
                    taxa_str_red, 'bac120')
                pplacer_row.gtdb_taxonomy_terminal = standardise_taxonomy(
                    taxa_str_terminal, 'bac120')
                pplacer_row.pplacer_taxonomy = standardise_taxonomy(
                    pplacer_tax, 'bac120')
                pplacer_row.is_terminal = is_on_terminal_branch
                pplacer_row.red = current_rel_dist

                out_pplacer.add_row(pplacer_row)

        out_pplacer.write()
        return results