import os
from collections import defaultdict

# NOTE: these snippets assume an `m8` helper module (providing parse_tsv and
# the *_SCHEMA constants) imported from the original pipeline codebase; the
# later snippets likewise assume parse_tsv, TAB_SCHEMA_MERGED,
# BLAST_OUTPUT_SCHEMA, and SpeciesAlignmentResults from that codebase.


def optimal_hit_for_each_query_nr(blast_output, max_evalue):
    contigs_to_best_alignments = defaultdict(list)
    accession_counts = defaultdict(int)

    # For each contig, collect the alignments that have the best bit score
    # (there may be several if there are ties).
    for alignment in m8.parse_tsv(blast_output, m8.BLAST_OUTPUT_SCHEMA):
        if alignment["evalue"] > max_evalue:
            continue
        query = alignment["qseqid"]
        best_alignments = contigs_to_best_alignments[query]

        if not best_alignments or best_alignments[0]["bitscore"] < alignment["bitscore"]:
            # New best score: replace any previously stored alignments.
            contigs_to_best_alignments[query] = [alignment]
        elif best_alignments[0]["bitscore"] == alignment["bitscore"]:
            # Tie with the current best score: keep this alignment too.
            best_alignments.append(alignment)

    # Map each accession to the number of best alignments it appears in.
    for alignments in contigs_to_best_alignments.values():
        for alignment in alignments:
            accession_counts[alignment["sseqid"]] += 1

    # For each contig, pick the optimal alignment based on the accession that
    # has the most best alignments. If there is still a tie, arbitrarily pick
    # the first one (later we could factor in which taxid has the most BLAST
    # candidates).
    for alignments in contigs_to_best_alignments.values():
        optimal_alignment = None
        for alignment in alignments:
            if (optimal_alignment is None
                    or accession_counts[optimal_alignment["sseqid"]]
                    < accession_counts[alignment["sseqid"]]):
                optimal_alignment = alignment

        yield optimal_alignment
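
A minimal driver sketch for the generator above, assuming a tab-separated BLAST output file on disk; the file name and e-value cutoff are illustrative, not part of the original code:

# Hypothetical usage: stream one winning alignment per contig and record
# which accession each contig was assigned to (names are made up).
best_per_contig = {}
for hit in optimal_hit_for_each_query_nr("contigs_vs_nr.m8", max_evalue=1e-10):
    best_per_contig[hit["qseqid"]] = hit["sseqid"]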
def filter_and_group_hits_by_query(blast_output, min_alignment_length,
                                   min_pident, max_evalue):
    # Filter and group results by query, yielding one result group at a time.
    # A result group consists of all hits for a query, grouped by subject.
    # Please see the comment in get_top_m8_nt for more context.
    current_query = None
    current_query_hits = None
    previously_seen_queries = set()
    # Please see the comments explaining the definition of "hsp" elsewhere in this file.
    for hsp in m8.parse_tsv(blast_output, m8.BLAST_OUTPUT_NT_SCHEMA):
        # Filter local alignment HSPs on minimum length, sequence similarity,
        # and e-value.
        if hsp["length"] < min_alignment_length:
            continue
        if hsp["pident"] < min_pident:
            continue
        if hsp["evalue"] > max_evalue:
            continue
        query, subject = hsp["qseqid"], hsp["sseqid"]
        if query != current_query:
            assert query not in previously_seen_queries, \
                "blastn output appears out of order; please re-sort by (qseqid, sseqid, score)"
            previously_seen_queries.add(query)
            if current_query_hits:
                yield current_query_hits
            current_query = query
            current_query_hits = defaultdict(list)
        current_query_hits[subject].append(hsp)
    if current_query_hits:
        yield current_query_hits
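
Each yielded group is a defaultdict mapping a subject id to the list of its surviving HSPs for one query. A consumption sketch, with illustrative thresholds:

# Hypothetical usage: for each query, find the subject with the most
# surviving HSPs (file name and thresholds are made up).
for hits_by_subject in filter_and_group_hits_by_query(
        "contigs_vs_nt.m8", min_alignment_length=36,
        min_pident=90.0, max_evalue=1e-10):
    top_subject = max(hits_by_subject, key=lambda s: len(hits_by_subject[s]))
    print(top_subject, len(hits_by_subject[top_subject]))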
def update_read_dict(read2contig, blast_top_m8, read_dict, accession_dict,
                     db_type):
    consolidated_dict = read_dict
    read2blastm8 = {}
    contig2accession = {}
    contig2lineage = {}
    added_reads = {}

    # Map each contig to its top accession (plus the raw m8 line) and to the
    # lineage of that accession.
    for row, raw_line in m8.parse_tsv(
            blast_top_m8,
            m8.RERANKED_BLAST_OUTPUT_SCHEMA[db_type]['contig_level'],
            raw_lines=True):
        contig_id = row["qseqid"]
        accession_id = row["sseqid"]
        contig2accession[contig_id] = (accession_id, raw_line)
        contig2lineage[contig_id] = accession_dict[accession_id]

    for read_id, contig_id in read2contig.items():
        accession, m8_line = contig2accession.get(contig_id, (None, None))
        if accession:
            species_taxid, genus_taxid, family_taxid = accession_dict[accession]
            if consolidated_dict.get(read_id):
                # Append the contig-based assignment to the existing record...
                consolidated_dict[read_id] += [
                    contig_id, accession, species_taxid, genus_taxid,
                    family_taxid
                ]
                # ...and overwrite the read-level species taxid with it.
                consolidated_dict[read_id][2] = species_taxid
            else:
                added_reads[read_id] = [
                    read_id, "1", species_taxid, accession, species_taxid,
                    genus_taxid, family_taxid, contig_id, accession,
                    species_taxid, genus_taxid, family_taxid,
                    'from_assembly'
                ]
        if m8_line:
            read2blastm8[read_id] = m8_line
    return (consolidated_dict, read2blastm8, contig2lineage, added_reads)
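
A sketch of how the inputs are shaped, inferred from how the function unpacks them; the file name, taxids, and the 'nr' db_type key are hypothetical:

# Hypothetical inputs: read2contig maps read id -> contig id, and
# accession_dict maps accession -> (species_taxid, genus_taxid, family_taxid).
read2contig = {"read_1": "contig_7"}
accession_dict = {"NC_000001.1": ("9606", "9605", "9604")}
consolidated, read2m8, contig2lineage, added = update_read_dict(
    read2contig, "top_hits_nr.m8", {}, accession_dict, db_type="nr")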
def generate_hit_data_from_m8(m8_file, valid_hits, assembly_level):
    """
    Generate hit data from an m8 file.
    Only include hits whose name appears in the valid_hits collection.
    """
    # An m8 file should contain at least one full line; anything smaller
    # is treated as an empty file.
    MIN_M8_FILE_SIZE = 25
    hits = {}

    if os.path.getsize(m8_file) < MIN_M8_FILE_SIZE:
        return hits

    # See m8.BLAST_OUTPUT_SCHEMA for the m8_file format.
    # This only runs for NT.
    m8_schema = m8.RERANKED_BLAST_OUTPUT_SCHEMA['nt'][assembly_level]
    for hit in m8.parse_tsv(m8_file, m8_schema):
        if hit["qseqid"] in valid_hits:
            # BLAST output is per HSP, yet the hit represents a set of HSPs,
            # so these fields have been aggregated across that set by
            # function summary_row() in class CandidateHit.
            hits[hit["qseqid"]] = {
                "accession": hit["sseqid"],
                "percent_id": hit["pident"],
                "alignment_length": hit["length"],
                "num_mismatches": hit["mismatch"],
                "num_gaps": hit["gapopen"],
                "query_start": hit["qstart"],
                "query_end": hit["qend"],
                "subject_start": hit["sstart"],
                "subject_end": hit["send"],
                "prop_mismatch": hit["mismatch"] / max(1, hit["length"]),
            }

    return hits
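
A usage sketch; the file name and the 'contig_level' assembly_level value are assumptions (the latter mirrors the schema key used in update_read_dict above):

# Hypothetical usage: keep per-query summary rows only for queries of interest.
valid = {"contig_7", "contig_9"}
hit_data = generate_hit_data_from_m8("reranked_nt.m8", valid, "contig_level")
for qseqid, fields in hit_data.items():
    print(qseqid, fields["accession"], fields["percent_id"])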
    def merge_taxon_counts(self):
        # Create new merged m8 and hit summary files.
        nr_alignment_per_read = {}
        # If this becomes a bottleneck, consider:
        # (1) if processing time is the bottleneck, loading all the data into memory;
        # (2) if memory is the bottleneck, going through NT first, since that saves
        #     storing results in memory for all the reads that get their hit
        #     from NT contigs.
        for nr_hit_dict in parse_tsv(self.inputs.nr_hitsummary2_tab,
                                     TAB_SCHEMA_MERGED,
                                     strict=False):
            nr_alignment_per_read[nr_hit_dict["read_id"]] = SpeciesAlignmentResults(
                contig=nr_hit_dict.get("contig_species_taxid"),
                read=nr_hit_dict.get("species_taxid"),
            )

        with open(self.outputs.merged_m8_filename, 'w') as output_m8, \
                open(self.outputs.merged_hit_filename, 'w') as output_hit:
            # First pass over NT: write to the merged files whenever the
            # assignment should come from NT.
            for nt_hit_dict, [nt_m8_dict, nt_m8_row] in zip(
                    parse_tsv(self.inputs.nt_hitsummary2_tab,
                              TAB_SCHEMA_MERGED,
                              strict=False),
                    parse_tsv(self.inputs.nt_m8,
                              BLAST_OUTPUT_SCHEMA,
                              raw_lines=True,
                              strict=False)):
                # Assert that the m8 and hit summary files line up.
                assert nt_hit_dict["read_id"] == nt_m8_dict["qseqid"], \
                    f"Mismatched m8 and hit summary files for NT [{nt_hit_dict['read_id']} != {nt_m8_dict['qseqid']}]"

                nr_alignment = nr_alignment_per_read.get(nt_hit_dict["read_id"])
                nt_alignment = SpeciesAlignmentResults(
                    contig=nt_hit_dict.get("contig_species_taxid"),
                    read=nt_hit_dict.get("species_taxid"))
                has_nt_contig_hit = nt_alignment.contig
                has_nr_contig_hit = nr_alignment and nr_alignment.contig
                has_nt_read_hit = nt_alignment.read
                has_nr_read_hit = nr_alignment and nr_alignment.read
                # NT wins if it has a contig-level hit, or a read-level hit
                # with no NR contig-level hit to outrank it.
                if has_nt_contig_hit or (not has_nr_contig_hit and has_nt_read_hit):
                    output_m8.write(nt_m8_row)
                    nt_hit_dict["source_count_type"] = "NT"
                    self._write_tsv_row(nt_hit_dict, TAB_SCHEMA_MERGED, output_hit)
                    if nr_alignment:
                        del nr_alignment_per_read[nt_hit_dict["read_id"]]
                elif has_nr_contig_hit or has_nr_read_hit:
                    continue  # NR wins; handled in the second pass below.
                else:
                    raise Exception("NO ALIGNMENTS FOUND - should not be here")

            # Second pass: dump the remaining reads from NR.
            for nr_hit_dict, [nr_m8_dict, nr_m8_row] in zip(
                    parse_tsv(self.inputs.nr_hitsummary2_tab,
                              TAB_SCHEMA_MERGED,
                              strict=False),
                    parse_tsv(self.inputs.nr_m8,
                              BLAST_OUTPUT_SCHEMA,
                              raw_lines=True,
                              strict=False)):
                # Assert that the m8 and hit summary files line up.
                assert nr_hit_dict["read_id"] == nr_m8_dict["qseqid"], \
                    f"Mismatched m8 and hit summary files for NR [{nr_hit_dict['read_id']} != {nr_m8_dict['qseqid']}]."

                nr_alignment = nr_alignment_per_read.get(nr_hit_dict["read_id"])
                if nr_alignment:
                    output_m8.write(nr_m8_row)
                    nr_hit_dict["source_count_type"] = "NR"
                    self._write_tsv_row(nr_hit_dict, TAB_SCHEMA_MERGED, output_hit)

        # Generate taxon counts from the merged files.
        self.create_taxon_count_file()
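
The method only needs `SpeciesAlignmentResults` to carry two optional fields; a record along these lines would satisfy the usage above, though the real definition in the source tree may differ:

# Minimal sketch of the result type used above, assuming a simple
# two-field record; the actual class may differ.
from typing import NamedTuple, Optional

class SpeciesAlignmentResults(NamedTuple):
    contig: Optional[str] = None  # species taxid assigned via a contig hit
    read: Optional[str] = None    # species taxid assigned via a direct read hit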