Example #1
    def run(self):
        super(JobTrestle, self).run()

        if not os.path.isdir(self.work_dir):
            os.mkdir(self.work_dir)

        summary_file = os.path.join(self.work_dir, "trestle_summary.txt")
        resolved_repeats_seqs = os.path.join(self.work_dir,
                                             "resolved_copies.fasta")
        repeat_graph = RepeatGraph(fp.read_sequence_dict(self.graph_edges))
        repeat_graph.load_from_file(self.repeat_graph)

        try:
            repeats_info = tres_graph \
                .get_simple_repeats(repeat_graph, self.reads_alignment_file,
                                    fp.read_sequence_dict(self.graph_edges))
            tres_graph.dump_repeats(
                repeats_info, os.path.join(self.work_dir, "repeats_dump"))

            tres.resolve_repeats(self.args, self.work_dir, repeats_info,
                                 summary_file, resolved_repeats_seqs)
            tres_graph.apply_changes(
                repeat_graph, summary_file,
                fp.read_sequence_dict(resolved_repeats_seqs))
        except KeyboardInterrupt as e:
            raise
        #except Exception as e:
        #    logger.warning("Caught unhandled exception: " + str(e))
        #    logger.warning("Continuing to the next pipeline stage. "
        #                   "Please submit a bug report along with the full log file")

        repeat_graph.dump_to_file(self.out_files["repeat_graph"])
        fp.write_fasta_dict(repeat_graph.edges_fasta,
                            self.out_files["repeat_graph_edges"])
Example #2
    def run(self):
        super(JobConsensus, self).run()
        if not os.path.isdir(self.consensus_dir):
            os.mkdir(self.consensus_dir)

        #split into 1Mb chunks to reduce RAM usage
        CHUNK_SIZE = 1000000
        chunks_file = os.path.join(self.consensus_dir, "chunks.fasta")
        chunks = aln.split_into_chunks(fp.read_sequence_dict(self.in_contigs),
                                       CHUNK_SIZE)
        fp.write_fasta_dict(chunks, chunks_file)

        logger.info("Running Minimap2")
        out_alignment = os.path.join(self.consensus_dir, "minimap.bam")
        aln.make_alignment(chunks_file,
                           self.args.reads,
                           self.args.threads,
                           self.consensus_dir,
                           self.args.platform,
                           out_alignment,
                           reference_mode=True,
                           sam_output=True)

        contigs_info = aln.get_contigs_info(chunks_file)
        logger.info("Computing consensus")
        consensus_fasta = cons.get_consensus(out_alignment, chunks_file,
                                             contigs_info, self.args.threads,
                                             self.args.platform)

        #merge chunks back into single sequences
        merged_fasta = aln.merge_chunks(consensus_fasta)
        fp.write_fasta_dict(merged_fasta, self.out_consensus)
        os.remove(chunks_file)
        os.remove(out_alignment)
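
The chunking above keeps memory bounded: contigs are cut into 1 Mb pieces before alignment, and the consensus is stitched back afterwards. A minimal, self-contained sketch of that split/merge idea follows; the chunk-naming scheme with a "$" separator is an assumption for illustration, not necessarily what aln.split_into_chunks / aln.merge_chunks actually do (the real merge_chunks also accepts a fold_function, as Example #14 shows; this sketch only covers the sequence case).

def split_into_chunks(fasta_dict, chunk_size):
    #illustrative only: cut each sequence into fixed-size pieces and encode
    #the chunk index in the header so the original order can be restored
    chunks = {}
    for ctg_id, seq in fasta_dict.items():
        for i in range(0, len(seq), chunk_size):
            chunks["{0}$chunk${1}".format(ctg_id, i // chunk_size)] = \
                seq[i:i + chunk_size]
    return chunks

def merge_chunks(chunks_dict):
    #group chunks by original id and concatenate them in index order
    grouped = {}
    for chunk_id, seq in chunks_dict.items():
        ctg_id, _, index = chunk_id.rsplit("$", 2)
        grouped.setdefault(ctg_id, []).append((int(index), seq))
    return {ctg_id: "".join(seq for _, seq in sorted(parts))
            for ctg_id, parts in grouped.items()}
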
Example #3
def generate_scaffolds(contigs_file, links_file, out_scaffolds):

    contigs_fasta = fp.read_sequence_dict(contigs_file)
    scaffolds_fasta = {}
    used_contigs = set()

    connections = {}
    with open(links_file, "r") as f:
        for line in f:
            line = line.strip()
            if not line: continue
            ctg_1, sign_1, ctg_2, sign_2 = line.split("\t")
            if ctg_1 in contigs_fasta and ctg_2 in contigs_fasta:
                connections[sign_1 + ctg_1] = sign_2 + ctg_2
                connections[rc(sign_2) + ctg_2] = rc(sign_1) + ctg_1

    scaffolds_fasta = {}
    scaffolds_seq = {}
    for ctg in contigs_fasta:
        if ctg in used_contigs: continue

        used_contigs.add(ctg)
        scf = ["-" + ctg]
        #extending right
        while (scf[-1] in connections
               and unsigned(connections[scf[-1]]) not in used_contigs):
            scf.append(connections[scf[-1]])
            used_contigs.add(unsigned(scf[-1]))

        for i, ctg in enumerate(scf):
            scf[i] = rc(ctg[0]) + unsigned(ctg)
        scf = scf[::-1]

        #extending left
        while (scf[-1] in connections
               and unsigned(connections[scf[-1]]) not in used_contigs):
            scf.append(connections[scf[-1]])
            used_contigs.add(unsigned(scf[-1]))

        #generating sequence interleaved by Ns
        if len(scf) == 1:
            scaffolds_fasta[unsigned(ctg)] = contigs_fasta[unsigned(ctg)]
            scaffolds_seq[unsigned(ctg)] = scf
        else:
            scf_name = "scaffold_" + unsigned(scf[0]).strip("contig_")
            scaffolds_seq[scf_name] = scf
            scf_seq = []
            for scf_ctg in scf:
                if scf_ctg[0] == "+":
                    scf_seq.append(contigs_fasta[unsigned(scf_ctg)])
                else:
                    scf_seq.append(
                        fp.reverse_complement(
                            contigs_fasta[unsigned(scf_ctg)]))
            gap = "N" * cfg.vals["scaffold_gap"]
            scaffolds_fasta[scf_name] = gap.join(scf_seq)

    fp.write_fasta_dict(scaffolds_fasta, out_scaffolds)
    return scaffolds_seq
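
The scaffolder works on signed contig names such as "+contig_1" and "-contig_1". The rc and unsigned helpers are used but not defined in this snippet; a plausible minimal version, consistent with how they are called above (the real definitions live elsewhere in the codebase):

def rc(sign):
    #flip the orientation sign: "+" <-> "-"
    return "-" if sign == "+" else "+"

def unsigned(signed_ctg):
    #drop the leading "+"/"-" from a signed contig name
    return signed_ctg[1:]
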
Example #4
def get_contigs_info(contigs_file):
    contigs_info = {}
    contigs_fasta = fp.read_sequence_dict(contigs_file)
    for ctg_id, ctg_seq in contigs_fasta.iteritems():
        contig_type = ctg_id.split("_")[0]
        contigs_info[ctg_id] = ContigInfo(ctg_id, len(ctg_seq), contig_type)

    return contigs_info
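
ContigInfo here is just a lightweight per-contig record. Assuming the usual three fields (id, length, type), a namedtuple stand-in and a hypothetical call would look like this; the real class may carry more fields:

import collections

#stand-in with assumed field names
ContigInfo = collections.namedtuple("ContigInfo", ["id", "length", "type"])

#hypothetical usage with a FASTA whose headers look like "contig_1"
#contigs_info = get_contigs_info("contigs.fasta")
#print(contigs_info["contig_1"].length)
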
Example #5
def make_bubbles(alignment_path, contigs_info, contigs_path, err_mode,
                 num_proc, bubbles_out):
    """
    The main function: takes an alignment and returns bubbles
    """
    CHUNK_SIZE = 1000000

    contigs_fasta = fp.read_sequence_dict(contigs_path)
    aln_reader = SynchronizedSamReader(alignment_path,
                                       contigs_fasta,
                                       cfg.vals["max_read_coverage"],
                                       use_secondary=True)
    chunk_feeder = SynchonizedChunkManager(contigs_fasta,
                                           chunk_size=CHUNK_SIZE)

    manager = multiprocessing.Manager()
    results_queue = manager.Queue()
    error_queue = manager.Queue()
    bubbles_out_lock = multiprocessing.Lock()
    bubbles_out_handle = open(bubbles_out, "w")

    process_in_parallel(
        _thread_worker,
        (aln_reader, chunk_feeder, contigs_info, err_mode, results_queue,
         error_queue, bubbles_out_handle, bubbles_out_lock), num_proc)
    if not error_queue.empty():
        raise error_queue.get()

    #logging
    total_bubbles = 0
    total_long_bubbles = 0
    total_long_branches = 0
    total_empty = 0
    total_aln_errors = []
    coverage_stats = defaultdict(list)

    while not results_queue.empty():
        (ctg_id, num_bubbles, num_long_bubbles, num_empty, num_long_branch,
         aln_errors, mean_coverage) = results_queue.get()
        total_long_bubbles += num_long_bubbles
        total_long_branches += num_long_branch
        total_empty += num_empty
        total_aln_errors.extend(aln_errors)
        total_bubbles += num_bubbles
        coverage_stats[ctg_id].append(mean_coverage)

    for ctg in coverage_stats:
        coverage_stats[ctg] = int(
            sum(coverage_stats[ctg]) / len(coverage_stats[ctg]))

    mean_aln_error = sum(total_aln_errors) / (len(total_aln_errors) + 1)
    logger.debug("Generated %d bubbles", total_bubbles)
    logger.debug("Split %d long bubbles", total_long_bubbles)
    logger.debug("Skipped %d empty bubbles", total_empty)
    logger.debug("Skipped %d bubbles with long branches", total_long_branches)
    ###

    return coverage_stats, mean_aln_error
Example #6
def get_consensus(alignment_path, contigs_path, contigs_info, num_proc,
                  platform):
    """
    Main function
    """
    aln_reader = SynchronizedSamReader(
        alignment_path,
        fp.read_sequence_dict(contigs_path),
        max_coverage=cfg.vals["max_read_coverage"],
        use_secondary=True)
    manager = multiprocessing.Manager()
    results_queue = manager.Queue()
    error_queue = manager.Queue()

    #making sure the main process catches SIGINT
    orig_sigint = signal.signal(signal.SIGINT, signal.SIG_IGN)
    threads = []
    for _ in range(num_proc):
        threads.append(
            multiprocessing.Process(target=_thread_worker,
                                    args=(aln_reader, contigs_info, platform,
                                          results_queue, error_queue)))
    signal.signal(signal.SIGINT, orig_sigint)

    for t in threads:
        t.start()
    try:
        for t in threads:
            t.join()
            if t.exitcode == -9:
                logger.error("Looks like the system ran out of memory")
            if t.exitcode != 0:
                raise Exception(
                    "One of the processes exited with code: {0}".format(
                        t.exitcode))
    except KeyboardInterrupt:
        for t in threads:
            t.terminate()
        raise

    if not error_queue.empty():
        raise error_queue.get()
    aln_reader.close()

    out_fasta = {}
    total_aln_errors = []
    while not results_queue.empty():
        ctg_id, ctg_seq, aln_errors = results_queue.get()
        total_aln_errors.extend(aln_errors)
        if len(ctg_seq) > 0:
            out_fasta[ctg_id] = ctg_seq

    mean_aln_error = sum(total_aln_errors) / (len(total_aln_errors) + 1)
    logger.info("Alignment error rate: %f", mean_aln_error)

    return out_fasta
Example #7
def get_consensus(alignment_path, contigs_path, contigs_info, num_proc,
                  platform):
    """
    Main function
    """

    CHUNK_SIZE = 1000000
    contigs_fasta = fp.read_sequence_dict(contigs_path)
    mp_manager = multiprocessing.Manager()
    aln_reader = SynchronizedSamReader(
        alignment_path,
        contigs_fasta,
        mp_manager,
        max_coverage=cfg.vals["max_read_coverage"],
        use_secondary=True)
    chunk_feeder = SynchonizedChunkManager(contigs_fasta, mp_manager,
                                           CHUNK_SIZE)

    #manager = multiprocessing.Manager()
    results_queue = mp_manager.Queue()
    error_queue = mp_manager.Queue()

    process_in_parallel(
        _thread_worker,
        (aln_reader, chunk_feeder, platform, results_queue, error_queue),
        num_proc)

    if not error_queue.empty():
        raise error_queue.get()

    chunk_consensus = defaultdict(list)
    total_aln_errors = []
    while not results_queue.empty():
        ctg_id, region_start, ctg_seq, aln_errors = results_queue.get()
        total_aln_errors.extend(aln_errors)
        if len(ctg_seq) > 0:
            chunk_consensus[ctg_id].append((region_start, ctg_seq))

    out_fasta = {}
    for ctg in chunk_consensus:
        sorted_chunks = [
            x[1] for x in sorted(chunk_consensus[ctg], key=lambda p: p[0])
        ]
        out_fasta[ctg] = "".join(sorted_chunks)

    mean_aln_error = sum(total_aln_errors) / (len(total_aln_errors) + 1)
    logger.info("Alignment error rate: %f", mean_aln_error)

    return out_fasta
Example #8
    def run(self):
        super(JobShortPlasmidsAssembly, self).run()
        logger.info("Recovering short unassembled sequences")
        if not os.path.isdir(self.work_dir):
            os.mkdir(self.work_dir)
        plasmids = plas.assemble_short_plasmids(self.args, self.work_dir,
                                                self.contigs_path)

        #updating repeat graph
        repeat_graph = RepeatGraph(fp.read_sequence_dict(self.graph_edges))
        repeat_graph.load_from_file(self.repeat_graph)
        plas.update_graph(repeat_graph, plasmids)
        repeat_graph.dump_to_file(self.out_files["repeat_graph"])
        fp.write_fasta_dict(repeat_graph.edges_fasta,
                            self.out_files["repeat_graph_edges"])
Example #9
def get_consensus(alignment_path, contigs_path, contigs_info, num_proc,
                  platform):
    """
    Main function
    """
    aln_reader = SynchronizedSamReader(alignment_path,
                                       fp.read_sequence_dict(contigs_path),
                                       cfg.vals["max_read_coverage"])
    manager = multiprocessing.Manager()
    results_queue = manager.Queue()
    error_queue = manager.Queue()

    #making sure the main process catches SIGINT
    orig_sigint = signal.signal(signal.SIGINT, signal.SIG_IGN)
    threads = []
    for _ in xrange(num_proc):
        threads.append(
            multiprocessing.Process(target=_thread_worker,
                                    args=(aln_reader, contigs_info, platform,
                                          results_queue, error_queue)))
    signal.signal(signal.SIGINT, orig_sigint)

    for t in threads:
        t.start()
    try:
        for t in threads:
            t.join()
    except KeyboardInterrupt:
        for t in threads:
            t.terminate()

    if not error_queue.empty():
        raise error_queue.get()

    out_fasta = {}
    total_aln_errors = []
    while not results_queue.empty():
        ctg_id, ctg_seq, aln_errors = results_queue.get()
        total_aln_errors.extend(aln_errors)
        if len(ctg_seq) > 0:
            out_fasta[ctg_id] = ctg_seq

    mean_aln_error = float(sum(total_aln_errors)) / (len(total_aln_errors) + 1)
    logger.info("Alignment error rate: {0}".format(mean_aln_error))

    return out_fasta
Example #10
def extract_unmapped_reads(args, reads2contigs_mapping,
                           mapping_rate_threshold):
    mapping_rates = calc_mapping_rates(reads2contigs_mapping)
    unmapped_reads = dict()
    n_processed_reads = 0

    for file in args.reads:
        fasta_dict = fp.read_sequence_dict(file)
        for read, sequence in fasta_dict.items():
            contigs = mapping_rates.get(read)
            if contigs is None:
                unmapped_reads[read] = sequence
            else:
                is_unmapped = True
                for contig, mapping_rate in contigs.items():
                    if mapping_rate >= mapping_rate_threshold:
                        is_unmapped = False

                if is_unmapped:
                    unmapped_reads[read] = sequence

        n_processed_reads += len(fasta_dict)

    return unmapped_reads, n_processed_reads
Example #11
def extract_unique_plasmids(trimmed_reads_mapping,
                            trimmed_reads_path,
                            mapping_rate_threshold=0.8,
                            max_length_difference=500,
                            min_sequence_length=1000):
    hits = read_paf(trimmed_reads_mapping)
    trimmed_reads = set()

    for hit in hits:
        trimmed_reads.add(hit.query)
        trimmed_reads.add(hit.target)

    trimmed_reads = list(trimmed_reads)
    n_trimmed_reads = len(trimmed_reads)
    read2int = dict()
    int2read = dict()

    for i in xrange(n_trimmed_reads):
        read2int[trimmed_reads[i]] = i
        int2read[i] = trimmed_reads[i]

    similarity_graph = [[] for _ in xrange(n_trimmed_reads)]
    hits.sort(key=lambda hit: (hit.query, hit.target))

    current_hit = None
    query_mapping_segments = []
    target_mapping_segments = []
    seq_lengths = {}

    for hit in hits:
        seq_lengths[hit.query] = hit.query_length
        seq_lengths[hit.target] = hit.target_length

        if hit.query == hit.target:
            continue

        if (current_hit is None or hit.query != current_hit.query
                or hit.target != current_hit.target):
            if current_hit is not None:
                query_length = current_hit.query_length
                target_length = current_hit.target_length
                query_mapping_rate = \
                    unmapped.calc_mapping_rate(query_length,
                                               query_mapping_segments)
                target_mapping_rate = \
                    unmapped.calc_mapping_rate(target_length,
                                               target_mapping_segments)

                if (query_mapping_rate > mapping_rate_threshold
                        or target_mapping_rate > mapping_rate_threshold):
                    #abs(query_length - target_length) < max_length_difference:
                    vertex1 = read2int[current_hit.query]
                    vertex2 = read2int[current_hit.target]
                    similarity_graph[vertex1].append(vertex2)
                    similarity_graph[vertex2].append(vertex1)

            query_mapping_segments = []
            target_mapping_segments = []
            current_hit = hit

        query_mapping_segments.append(
            unmapped.MappingSegment(hit.query_start, hit.query_end))
        target_mapping_segments.append(
            unmapped.MappingSegment(hit.target_start, hit.target_end))

    connected_components, n_components = \
        utils.find_connected_components(similarity_graph)

    groups = [[] for _ in xrange(n_components)]
    for i in xrange(len(connected_components)):
        groups[connected_components[i]].append(int2read[i])

    #for g in groups:
    #    logger.debug("Group {0}".format(len(g)))
    #    for s in g:
    #        logger.debug("\t{0}".format(seq_lengths[s]))

    groups = [group for group in groups if len(group) > 1]
    trimmed_reads_dict = fp.read_sequence_dict(trimmed_reads_path)
    unique_plasmids = dict()

    for group in groups:
        sequence = trimmed_reads_dict[group[0]]
        if len(sequence) >= min_sequence_length:
            unique_plasmids[group[0]] = sequence

    return unique_plasmids
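
unmapped.MappingSegment and unmapped.calc_mapping_rate are used above but not shown in these snippets. A self-contained sketch, assuming the rate is the fraction of the sequence covered by its alignment segments with overlaps merged (the real implementation may differ):

import collections

MappingSegment = collections.namedtuple("MappingSegment", ["start", "end"])

def calc_mapping_rate(seq_length, segments):
    #merge overlapping segments, then divide covered bases by sequence length
    covered = 0
    last_end = 0
    for seg in sorted(segments, key=lambda s: s.start):
        start = max(seg.start, last_end)
        if seg.end > start:
            covered += seg.end - start
            last_end = seg.end
    return float(covered) / seq_length
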
Example #12
def find_divergence(alignment_path, contigs_path, contigs_info, frequency_path,
                    positions_path, div_sum_path, min_aln_rate, platform,
                    num_proc, sub_thresh, del_thresh, ins_thresh):
    """
    Main function: takes in an alignment and finds the divergent positions
    """
    if not os.path.isfile(alignment_path) or not os.path.isfile(contigs_path):
        ctg_profile = []
        positions = _write_frequency_path(frequency_path, ctg_profile,
                                          sub_thresh, del_thresh, ins_thresh)
        total_header = "".join([
            "Total_positions_{0}_".format(len(positions["total"])),
            "with_thresholds_sub_{0}".format(sub_thresh),
            "_del_{0}_ins_{1}".format(del_thresh, ins_thresh)
        ])
        sub_header = "".join([
            "Sub_positions_{0}_".format(len(positions["sub"])),
            "with_threshold_sub_{0}".format(sub_thresh)
        ])
        del_header = "".join([
            "Del_positions_{0}_".format(len(positions["del"])),
            "with_threshold_del_{0}".format(del_thresh)
        ])
        ins_header = "".join([
            "Ins_positions_{0}_".format(len(positions["ins"])),
            "with_threshold_ins_{0}".format(ins_thresh)
        ])
        _write_positions(positions_path, positions, total_header, sub_header,
                         del_header, ins_header)

        window_len = 1000
        sum_header = "Tentative Divergent Position Summary"
        _write_div_summary(div_sum_path, sum_header, positions,
                           len(ctg_profile), window_len)
        return

    contigs_fasta = fp.read_sequence_dict(contigs_path)
    aln_reader = SynchronizedSamReader(alignment_path, contigs_fasta,
                                       config.vals["max_read_coverage"])
    chunk_feeder = SynchonizedChunkManager(contigs_fasta)

    manager = multiprocessing.Manager()
    results_queue = manager.Queue()
    error_queue = manager.Queue()

    process_in_parallel(_thread_worker,
                        (aln_reader, chunk_feeder, contigs_info, platform,
                         results_queue, error_queue), num_proc)

    if not error_queue.empty():
        raise error_queue.get()

    total_aln_errors = []
    while not results_queue.empty():
        _, ctg_profile, aln_errors = results_queue.get()

        positions = _write_frequency_path(frequency_path, ctg_profile,
                                          sub_thresh, del_thresh, ins_thresh)
        total_header = "".join([
            "Total_positions_{0}_".format(len(positions["total"])),
            "with_thresholds_sub_{0}".format(sub_thresh),
            "_del_{0}_ins_{1}".format(del_thresh, ins_thresh)
        ])
        sub_header = "".join([
            "Sub_positions_{0}_".format(len(positions["sub"])),
            "with_threshold_sub_{0}".format(sub_thresh)
        ])
        del_header = "".join([
            "Del_positions_{0}_".format(len(positions["del"])),
            "with_threshold_del_{0}".format(del_thresh)
        ])
        ins_header = "".join([
            "Ins_positions_{0}_".format(len(positions["ins"])),
            "with_threshold_ins_{0}".format(ins_thresh)
        ])
        _write_positions(positions_path, positions, total_header, sub_header,
                         del_header, ins_header)

        window_len = 1000
        sum_header = "Tentative Divergent Position Summary"
        _write_div_summary(div_sum_path, sum_header, positions,
                           len(ctg_profile), window_len)

        logger.debug("Total positions: %d", len(positions["total"]))
        total_aln_errors.extend(aln_errors)

    mean_aln_error = sum(total_aln_errors) / (len(total_aln_errors) + 1)
    logger.debug("Alignment error rate: %f", mean_aln_error)
Example #13
def extract_unique_plasmids(trimmed_reads_mapping,
                            trimmed_reads_path,
                            mapping_rate_threshold=0.8,
                            max_length_difference=500,
                            min_sequence_length=1000):
    trimmed_reads = set()
    for hit in read_paf(trimmed_reads_mapping):
        trimmed_reads.add(hit.query)
        trimmed_reads.add(hit.target)

    trimmed_reads = list(trimmed_reads)
    n_trimmed_reads = len(trimmed_reads)
    read2int = dict()
    int2read = dict()

    for i in xrange(n_trimmed_reads):
        read2int[trimmed_reads[i]] = i
        int2read[i] = trimmed_reads[i]

    similarity_graph = [[] for _ in xrange(n_trimmed_reads)]

    #each hit group stores alignments for each (query, target) pair
    for hit_group in read_paf_grouped(trimmed_reads_mapping):
        if hit_group[0].query == hit_group[0].target:
            continue

        query_mapping_segments = []
        target_mapping_segments = []
        for hit in hit_group:
            query_mapping_segments.append(
                unmapped.MappingSegment(hit.query_start, hit.query_end))
            target_mapping_segments.append(
                unmapped.MappingSegment(hit.target_start, hit.target_end))

        query_length = hit_group[0].query_length
        target_length = hit_group[0].target_length
        query_mapping_rate = unmapped.calc_mapping_rate(
            query_length, query_mapping_segments)
        target_mapping_rate = unmapped.calc_mapping_rate(
            target_length, target_mapping_segments)

        if (query_mapping_rate > mapping_rate_threshold
                or target_mapping_rate > mapping_rate_threshold):
            #abs(query_length - target_length) < max_length_difference:
            vertex1 = read2int[hit_group[0].query]
            vertex2 = read2int[hit_group[0].target]
            similarity_graph[vertex1].append(vertex2)
            similarity_graph[vertex2].append(vertex1)

    connected_components, n_components = \
        utils.find_connected_components(similarity_graph)

    groups = [[] for _ in xrange(n_components)]
    for i in xrange(len(connected_components)):
        groups[connected_components[i]].append(int2read[i])

    #for g in groups:
    #    logger.debug("Group {0}".format(len(g)))
    #    for s in g:
    #        logger.debug("\t{0}".format(seq_lengths[s]))

    groups = [group for group in groups if len(group) > 1]
    trimmed_reads_dict = fp.read_sequence_dict(trimmed_reads_path)
    unique_plasmids = dict()

    for group in groups:
        sequence = trimmed_reads_dict[group[0]]
        if len(sequence) >= min_sequence_length:
            unique_plasmids[group[0]] = sequence

    return unique_plasmids
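
utils.find_connected_components, used here and in Example #11, is not shown in these snippets. A standard BFS version that is consistent with how its result is consumed (a component label per vertex plus the number of components):

import collections

def find_connected_components(adjacency_list):
    #label every vertex with the index of its connected component
    labels = [-1] * len(adjacency_list)
    n_components = 0
    for start in range(len(adjacency_list)):
        if labels[start] != -1:
            continue
        queue = collections.deque([start])
        labels[start] = n_components
        while queue:
            v = queue.popleft()
            for u in adjacency_list[v]:
                if labels[u] == -1:
                    labels[u] = n_components
                    queue.append(u)
        n_components += 1
    return labels, n_components
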
Example #14
def polish(contig_seqs, read_seqs, work_dir, num_iters, num_threads,
           error_mode, output_progress):
    """
    High-level polisher interface
    """
    logger_state = logger.disabled
    if not output_progress:
        logger.disabled = True

    subs_matrix = os.path.join(
        cfg.vals["pkg_root"], cfg.vals["err_modes"][error_mode]["subs_matrix"])
    hopo_matrix = os.path.join(
        cfg.vals["pkg_root"], cfg.vals["err_modes"][error_mode]["hopo_matrix"])
    stats_file = os.path.join(work_dir, "contigs_stats.txt")

    prev_assembly = contig_seqs
    contig_lengths = None
    coverage_stats = None
    for i in xrange(num_iters):
        logger.info("Polishing genome ({0}/{1})".format(i + 1, num_iters))

        #split into 1Mb chunks to reduce RAM usage
        #slightly vary chunk size between iterations
        CHUNK_SIZE = 1000000 - (i % 2) * 100000
        chunks_file = os.path.join(work_dir, "chunks_{0}.fasta".format(i + 1))
        chunks = split_into_chunks(fp.read_sequence_dict(prev_assembly),
                                   CHUNK_SIZE)
        fp.write_fasta_dict(chunks, chunks_file)

        ####
        logger.info("Running minimap2")
        alignment_file = os.path.join(work_dir,
                                      "minimap_{0}.sam".format(i + 1))
        make_alignment(chunks_file,
                       read_seqs,
                       num_threads,
                       work_dir,
                       error_mode,
                       alignment_file,
                       reference_mode=True,
                       sam_output=True)

        #####
        logger.info("Separating alignment into bubbles")
        contigs_info = get_contigs_info(chunks_file)
        bubbles_file = os.path.join(work_dir,
                                    "bubbles_{0}.fasta".format(i + 1))
        coverage_stats, mean_aln_error = \
            make_bubbles(alignment_file, contigs_info, chunks_file,
                         error_mode, num_threads,
                         bubbles_file)

        logger.info("Alignment error rate: {0}".format(mean_aln_error))
        consensus_out = os.path.join(work_dir,
                                     "consensus_{0}.fasta".format(i + 1))
        polished_file = os.path.join(work_dir,
                                     "polished_{0}.fasta".format(i + 1))
        if os.path.getsize(bubbles_file) == 0:
            logger.info("No reads were aligned during polishing")
            if not output_progress:
                logger.disabled = logger_state
            open(stats_file, "w").write("#seq_name\tlength\tcoverage\n")
            open(polished_file, "w")
            return polished_file, stats_file

        #####
        logger.info("Correcting bubbles")
        _run_polish_bin(bubbles_file, subs_matrix, hopo_matrix, consensus_out,
                        num_threads, output_progress)
        polished_fasta, polished_lengths = _compose_sequence(consensus_out)
        merged_chunks = merge_chunks(polished_fasta)
        fp.write_fasta_dict(merged_chunks, polished_file)

        #Cleanup
        os.remove(chunks_file)
        os.remove(bubbles_file)
        os.remove(consensus_out)
        os.remove(alignment_file)

        contig_lengths = polished_lengths
        prev_assembly = polished_file

    #merge information from chunks
    contig_lengths = merge_chunks(contig_lengths, fold_function=sum)
    coverage_stats = merge_chunks(coverage_stats,
                                  fold_function=lambda l: sum(l) / len(l))

    with open(stats_file, "w") as f:
        f.write("#seq_name\tlength\tcoverage\n")
        for ctg_id in contig_lengths:
            f.write("{0}\t{1}\t{2}\n".format(ctg_id, contig_lengths[ctg_id],
                                             coverage_stats[ctg_id]))

    if not output_progress:
        logger.disabled = logger_state

    return prev_assembly, stats_file
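
polish is the entry point that ties the other pieces together: each iteration chunks the current assembly, aligns reads with minimap2, separates the alignment into bubbles, and runs the consensus step, returning the path of the last polished FASTA plus a stats file. A hypothetical invocation; all paths, the thread count, and the error-mode key are placeholders, not values taken from this code:

#hypothetical call; paths and settings below are placeholders
polished_fasta, stats_file = polish(contig_seqs="assembly/contigs.fasta",
                                    read_seqs=["reads/long_reads.fastq.gz"],
                                    work_dir="assembly/polishing",
                                    num_iters=2,
                                    num_threads=8,
                                    error_mode="nano",
                                    output_progress=True)
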
Example #15
def generate_polished_edges(edges_file, gfa_file, polished_contigs, work_dir,
                            error_mode, num_threads):
    """
    Generate polished graph edges sequences by extracting them from
    polished contigs
    """
    logger.debug("Generating polished GFA")

    alignment_file = os.path.join(work_dir, "edges_aln.sam")
    polished_dict = fp.read_sequence_dict(polished_contigs)
    make_alignment(polished_contigs, [edges_file],
                   num_threads,
                   work_dir,
                   error_mode,
                   alignment_file,
                   reference_mode=True,
                   sam_output=True)
    aln_reader = SynchronizedSamReader(alignment_file, polished_dict,
                                       cfg.vals["max_read_coverage"])
    aln_reader.init_reading()
    aln_by_edge = defaultdict(list)

    #getting one best alignment for each contig
    while not aln_reader.is_eof():
        _, ctg_aln = aln_reader.get_chunk()
        for aln in ctg_aln:
            aln_by_edge[aln.qry_id].append(aln)
    aln_reader.stop_reading()

    MIN_CONTAINMENT = 0.9
    updated_seqs = 0
    edges_dict = fp.read_sequence_dict(edges_file)
    for edge in edges_dict:
        if edge in aln_by_edge:
            main_aln = aln_by_edge[edge][0]
            map_start = main_aln.trg_start
            map_end = main_aln.trg_end
            for aln in aln_by_edge[edge]:
                if aln.trg_id == main_aln.trg_id and aln.trg_sign == main_aln.trg_sign:
                    map_start = min(map_start, aln.trg_start)
                    map_end = max(map_end, aln.trg_end)

            new_seq = polished_dict[main_aln.trg_id][map_start:map_end]
            if main_aln.qry_sign == "-":
                new_seq = fp.reverse_complement(new_seq)

            #print edge, main_aln.qry_len, len(new_seq), main_aln.qry_start, main_aln.qry_end
            if float(len(new_seq)) / aln.qry_len > MIN_CONTAINMENT:
                edges_dict[edge] = new_seq
                updated_seqs += 1

    #writes fasta file with polished edges
    #edges_polished = os.path.join(work_dir, "polished_edges.fasta")
    #fp.write_fasta_dict(edges_dict, edges_polished)

    #writes gfa file with polished edges
    with open(os.path.join(work_dir, "polished_edges.gfa"), "w") as gfa_polished, \
         open(gfa_file, "r") as gfa_in:
        for line in gfa_in:
            if line.startswith("S"):
                seq_id = line.split()[1]
                coverage_tag = line.split()[3]
                gfa_polished.write("S\t{0}\t{1}\t{2}\n".format(
                    seq_id, edges_dict[seq_id], coverage_tag))
            else:
                gfa_polished.write(line)

    logger.debug("{0} sequences remained unpolished".format(
        len(edges_dict) - updated_seqs))
    os.remove(alignment_file)
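
The GFA rewrite above only touches S-lines and assumes the sequence sits in the third column with a coverage tag in the fourth, which is what the line.split() indexing implies. A tiny illustration of that parsing; the record values are made up:

#illustrative S-line handling, mirroring the parsing above; values are made up
line = "S\tedge_1\tACGTACGTACGT\tdp:i:25\n"
seq_id = line.split()[1]        # "edge_1"
coverage_tag = line.split()[3]  # "dp:i:25"
new_line = "S\t{0}\t{1}\t{2}\n".format(seq_id, "TTTTACGT", coverage_tag)
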
Example #16
def generate_polished_edges(edges_file, gfa_file, polished_contigs, work_dir,
                            error_mode, polished_stats, num_threads):
    """
    Generate polished graph edges sequences by extracting them from
    polished contigs
    """
    logger.debug("Generating polished GFA")

    edges_new_coverage = {}
    with open(polished_stats, "r") as f:
        for line in f:
            if line.startswith("#"):
                continue
            ctg, _len, coverage = line.strip().split()
            ctg_id = ctg.split("_")[1]
            edges_new_coverage[ctg_id] = int(coverage)

    alignment_file = os.path.join(work_dir, "edges_aln.bam")
    polished_dict = fp.read_sequence_dict(polished_contigs)
    make_alignment(polished_contigs, [edges_file],
                   num_threads,
                   work_dir,
                   error_mode,
                   alignment_file,
                   reference_mode=True,
                   sam_output=True)
    aln_reader = SynchronizedSamReader(alignment_file, polished_dict,
                                       multiprocessing.Manager(),
                                       cfg.vals["max_read_coverage"])
    aln_by_edge = defaultdict(list)

    #getting one best alignment for each contig
    #for ctg in polished_dict:
    #    ctg_aln = aln_reader.get_alignments(ctg)
    for aln in aln_reader.get_all_alignments():
        aln_by_edge[aln.qry_id].append(aln)
    #logger.debug("Bam parsing done")

    MIN_CONTAINMENT = 0.9
    updated_seqs = 0
    edges_dict = fp.read_sequence_dict(edges_file)
    for edge in edges_dict:
        if edge in aln_by_edge:
            aln_by_edge[edge].sort(key=lambda a: a.qry_end - a.qry_start,
                                   reverse=True)
            main_aln = aln_by_edge[edge][0]
            map_start = main_aln.trg_start
            map_end = main_aln.trg_end
            for aln in aln_by_edge[edge]:
                if aln.trg_id == main_aln.trg_id and aln.trg_sign == main_aln.trg_sign:
                    map_start = min(map_start, aln.trg_start)
                    map_end = max(map_end, aln.trg_end)

            new_seq = polished_dict[main_aln.trg_id][map_start:map_end]
            if main_aln.qry_sign == "-":
                new_seq = fp.reverse_complement(new_seq)

            #print(edge, main_aln.qry_len, len(new_seq), main_aln.qry_start, main_aln.qry_end)
            if len(new_seq) / aln.qry_len > MIN_CONTAINMENT:
                edges_dict[edge] = new_seq
                updated_seqs += 1

    #writes fasta file with polished edges
    #edges_polished = os.path.join(work_dir, "polished_edges.fasta")
    #fp.write_fasta_dict(edges_dict, edges_polished)

    #writes gfa file with polished edges
    with open(os.path.join(work_dir, "polished_edges.gfa"), "w") as gfa_polished, \
         open(gfa_file, "r") as gfa_in:
        for line in gfa_in:
            if line.startswith("S"):
                seq_id = line.split()[1]
                coverage_tag = line.split()[3]
                seq_num = seq_id.split("_")[1]
                if seq_num in edges_new_coverage:
                    #logger.info("from {0} to {1}".format(coverage_tag, edges_new_coverage[seq_num]))
                    coverage_tag = "dp:i:{0}".format(
                        edges_new_coverage[seq_num])
                gfa_polished.write("S\t{0}\t{1}\t{2}\n".format(
                    seq_id, edges_dict[seq_id], coverage_tag))
            else:
                gfa_polished.write(line)

    logger.debug("%d sequences remained unpolished",
                 len(edges_dict) - updated_seqs)
    os.remove(alignment_file)
Example #17
def find_divergence(alignment_path, contigs_path, contigs_info, frequency_path,
                    positions_path, div_sum_path, min_aln_rate, platform,
                    num_proc, sub_thresh, del_thresh, ins_thresh):
    """
    Main function: takes in an alignment and finds the divergent positions
    """
    if not os.path.isfile(alignment_path) or not os.path.isfile(contigs_path):
        ctg_profile = []
        positions = _write_frequency_path(frequency_path, ctg_profile,
                                          sub_thresh, del_thresh, ins_thresh)
        total_header = "".join([
            "Total_positions_{0}_".format(len(positions["total"])),
            "with_thresholds_sub_{0}".format(sub_thresh),
            "_del_{0}_ins_{1}".format(del_thresh, ins_thresh)
        ])
        sub_header = "".join([
            "Sub_positions_{0}_".format(len(positions["sub"])),
            "with_threshold_sub_{0}".format(sub_thresh)
        ])
        del_header = "".join([
            "Del_positions_{0}_".format(len(positions["del"])),
            "with_threshold_del_{0}".format(del_thresh)
        ])
        ins_header = "".join([
            "Ins_positions_{0}_".format(len(positions["ins"])),
            "with_threshold_ins_{0}".format(ins_thresh)
        ])
        _write_positions(positions_path, positions, total_header, sub_header,
                         del_header, ins_header)

        window_len = 1000
        sum_header = "Tentative Divergent Position Summary"
        _write_div_summary(div_sum_path, sum_header, positions,
                           len(ctg_profile), window_len)
        return

    aln_reader = SynchronizedSamReader(alignment_path,
                                       fp.read_sequence_dict(contigs_path),
                                       config.vals["max_read_coverage"])
    manager = multiprocessing.Manager()
    results_queue = manager.Queue()
    error_queue = manager.Queue()

    #making sure the main process catches SIGINT
    orig_sigint = signal.signal(signal.SIGINT, signal.SIG_IGN)
    threads = []
    for _ in xrange(num_proc):
        threads.append(
            multiprocessing.Process(target=_thread_worker,
                                    args=(aln_reader, contigs_info, platform,
                                          results_queue, error_queue)))
    signal.signal(signal.SIGINT, orig_sigint)

    for t in threads:
        t.start()
    try:
        for t in threads:
            t.join()
    except KeyboardInterrupt:
        for t in threads:
            t.terminate()

    if not error_queue.empty():
        raise error_queue.get()

    total_aln_errors = []
    while not results_queue.empty():
        ctg_id, ctg_profile, aln_errors = results_queue.get()

        positions = _write_frequency_path(frequency_path, ctg_profile,
                                          sub_thresh, del_thresh, ins_thresh)
        total_header = "".join([
            "Total_positions_{0}_".format(len(positions["total"])),
            "with_thresholds_sub_{0}".format(sub_thresh),
            "_del_{0}_ins_{1}".format(del_thresh, ins_thresh)
        ])
        sub_header = "".join([
            "Sub_positions_{0}_".format(len(positions["sub"])),
            "with_threshold_sub_{0}".format(sub_thresh)
        ])
        del_header = "".join([
            "Del_positions_{0}_".format(len(positions["del"])),
            "with_threshold_del_{0}".format(del_thresh)
        ])
        ins_header = "".join([
            "Ins_positions_{0}_".format(len(positions["ins"])),
            "with_threshold_ins_{0}".format(ins_thresh)
        ])
        _write_positions(positions_path, positions, total_header, sub_header,
                         del_header, ins_header)

        window_len = 1000
        sum_header = "Tentative Divergent Position Summary"
        _write_div_summary(div_sum_path, sum_header, positions,
                           len(ctg_profile), window_len)

        logger.debug("Total positions: {0}".format(len(positions["total"])))
        total_aln_errors.extend(aln_errors)

    mean_aln_error = float(sum(total_aln_errors)) / (len(total_aln_errors) + 1)
    logger.debug("Alignment error rate: {0}".format(mean_aln_error))
Example #18
def make_bubbles(alignment_path, contigs_info, contigs_path, err_mode,
                 num_proc, bubbles_out):
    """
    The main function: takes an alignment and returns bubbles
    """
    aln_reader = SynchronizedSamReader(alignment_path,
                                       fp.read_sequence_dict(contigs_path),
                                       cfg.vals["max_read_coverage"],
                                       use_secondary=True)
    manager = multiprocessing.Manager()
    results_queue = manager.Queue()
    error_queue = manager.Queue()

    #making sure the main process catches SIGINT
    orig_sigint = signal.signal(signal.SIGINT, signal.SIG_IGN)
    threads = []
    bubbles_out_lock = multiprocessing.Lock()
    bubbles_out_handle = open(bubbles_out, "w")
    for _ in range(num_proc):
        threads.append(
            multiprocessing.Process(
                target=_thread_worker,
                args=(aln_reader, contigs_info, err_mode, results_queue,
                      error_queue, bubbles_out_handle, bubbles_out_lock)))
    signal.signal(signal.SIGINT, orig_sigint)

    for t in threads:
        t.start()
    try:
        for t in threads:
            t.join()
            if t.exitcode == -9:
                logger.error("Looks like the system ran out of memory")
            if t.exitcode != 0:
                raise Exception(
                    "One of the processes exited with code: {0}".format(
                        t.exitcode))
    except KeyboardInterrupt:
        for t in threads:
            t.terminate()
        raise

    if not error_queue.empty():
        raise error_queue.get()
    aln_reader.close()

    total_bubbles = 0
    total_long_bubbles = 0
    total_long_branches = 0
    total_empty = 0
    total_aln_errors = []
    coverage_stats = {}

    while not results_queue.empty():
        (ctg_id, num_bubbles, num_long_bubbles, num_empty, num_long_branch,
         aln_errors, mean_coverage) = results_queue.get()
        total_long_bubbles += num_long_bubbles
        total_long_branches += num_long_branch
        total_empty += num_empty
        total_aln_errors.extend(aln_errors)
        total_bubbles += num_bubbles
        coverage_stats[ctg_id] = mean_coverage

    mean_aln_error = sum(total_aln_errors) / (len(total_aln_errors) + 1)
    logger.debug("Generated %d bubbles", total_bubbles)
    logger.debug("Split %d long bubbles", total_long_bubbles)
    logger.debug("Skipped %d empty bubbles", total_empty)
    logger.debug("Skipped %d bubbles with long branches", total_long_branches)

    return coverage_stats, mean_aln_error