def __init__(self, aligner, knotter, disjointigs, dot_plot):
    # type: (Aligner, LineMerger, DisjointigCollection, LineDotPlot) -> None
    """Store the collaborators and build the helper objects.

    :param aligner: aligner; its ``dir_distributor`` is also handed to the Polisher.
    :param knotter: line merger stored as ``self.knotter``.
    :param disjointigs: disjointig collection stored as ``self.disjointigs``.
    :param dot_plot: line dot plot stored as ``self.dot_plot``.
    """
    # Collaborators supplied by the caller.
    self.aligner, self.knotter = aligner, knotter
    self.disjointigs, self.dot_plot = disjointigs, dot_plot
    # Helpers owned by this object.
    work_dirs = aligner.dir_distributor
    self.polisher = Polisher(aligner, work_dirs)
    self.scorer = Scorer()
def __init__(self, aligner, knotter, disjointigs, dot_plot, reads=None, recruiter=None):
    # type: (Aligner, LineMerger, DisjointigCollection, LineDotPlot, Optional[ReadCollection], Optional[PairwiseReadRecruiter]) -> None
    """Store the collaborators and build the helper objects.

    ``reads`` and ``recruiter`` are optional so that four-argument call
    sites keep working; passing both behaves exactly as before.

    :param aligner: aligner; its ``dir_distributor`` is also handed to the Polisher.
    :param knotter: line merger stored as ``self.knotter``.
    :param disjointigs: disjointig collection stored as ``self.disjointigs``.
    :param dot_plot: line dot plot stored as ``self.dot_plot``.
    :param reads: read collection stored as ``self.reads`` (may be None).
    :param recruiter: read recruiter stored as ``self.recruiter`` (may be None).
    """
    self.aligner = aligner
    self.polisher = Polisher(aligner, aligner.dir_distributor)
    self.knotter = knotter
    self.disjointigs = disjointigs
    self.dot_plot = dot_plot
    self.scorer = Scorer()
    self.reads = reads
    self.recruiter = recruiter
def splitSegKmeans(aligner, seg, mult, all_reads_list):
    """Try to split the reads of ``seg`` into ``mult`` copies with k-means.

    Reads are turned into vectors by ``readsToVectors``, clustered, and every
    cluster is polished into its own consensus contig.  The split is accepted
    only when the most similar pair of consensus contigs has percent identity
    below 0.985.

    :param aligner: aligner used for read-to-base and contig-to-contig alignment.
    :param seg: segment whose contig view is used as the base sequence.
    :param mult: number of clusters requested from k-means.
    :param all_reads_list: reads to cluster.
    :return: list of ``(consensus contig, reads of the cluster)`` pairs, or
        None when the clusters are not distinct enough (or fewer than two
        clusters were produced).
    """
    polisher = Polisher(aligner, aligner.dir_distributor)
    base = seg.asContig()
    rtv = readsToVectors(aligner, all_reads_list, base)
    kmeans = KMeans(n_clusters=mult, precompute_distances=True)
    recs = list(rtv.values())
    result = kmeans.fit_predict(X=[rec.v for rec in recs])
    print(str(result))
    clusters = dict()
    for i, c in enumerate(result):
        clusters.setdefault(c, []).append(recs[i].al)
    for c in clusters.values():
        print(str(c) + " : " + str(len(c)))
    split_contigs = []
    split_reads = []
    for c in clusters.values():
        consensus = polisher.polishSmallSegment(base.asSegment(), c).seg_from.Seq()
        split_contigs.append(Contig(consensus, str(len(split_contigs))))
        split_reads.append([al.seg_from.contig for al in c])
    # k-means may yield fewer non-empty clusters than requested, so iterate
    # over what was actually built instead of over ``mult``.
    n = len(split_contigs)
    # Start from 0 so that the maximum reflects the real pairwise identities;
    # starting from 1 made the acceptance test below impossible to pass.
    maxpi = 0.0
    for i in range(n):
        for j in range(n):
            if i == j:
                sys.stdout.write("1.0 ")
                continue
            al = next(aligner.overlapAlign([split_contigs[i]],
                                           ContigStorage([split_contigs[j]])))
            pi = al.percentIdentity()
            sys.stdout.write(str(pi) + " ")
            maxpi = max(maxpi, pi)
        print("")
    print("Maxpi: " + str(maxpi))
    # A single cluster is not a split: keep returning None in that case.
    if n > 1 and maxpi < 0.985:
        return list(zip(split_contigs, split_reads))
    return None
def readsToVectors(aligner, reads_list, base):
    """Build a per-read record of vectors against ``base`` and alternative bases.

    Reads whose alignment covers ``base`` up to a 100 bp slack get a
    ``ReadRecord`` seeded with ``toVector`` of that alignment.  Consecutive
    triples of those alignments are then polished into candidate bases; a
    candidate is kept only if every retained read also aligns to it (same
    100 bp slack), in which case each record is extended with the vector of
    the new alignment.  The search stops once more than 10 bases are held.

    :return: dict mapping read id to its record.
    """
    polisher = Polisher(aligner, aligner.dir_distributor)
    kept_als = []
    records = dict()
    initial_als = fixAlDir(aligner.overlapAlign(reads_list, ContigStorage([base])), base)
    for al in initial_als:
        if len(al.seg_to) >= len(base) - 100:
            kept_als.append(al)
            records[al.seg_from.contig.id] = ReadRecord(al).extend(toVector(al))
    reads_list = [al.seg_from.contig for al in kept_als]
    bases = [base]
    for triple in zip(kept_als[0::3], kept_als[1::3], kept_als[2::3]):
        polished = polisher.polishSmallSegment(base.asSegment(), list(triple))
        base_candidate = Contig(polished.seg_from.Seq(), str(len(bases)))
        candidate_als = []
        seen_ids = set()
        realigned = fixAlDir(
            aligner.overlapAlign(reads_list, ContigStorage([base_candidate])),
            base_candidate)
        for al in realigned:
            if len(al.seg_to) >= len(base_candidate) - 100:
                candidate_als.append(al)
                seen_ids.add(al.seg_from.contig.id)
        # Accept the candidate only when every retained read aligned to it.
        if len(seen_ids) == len(kept_als):
            bases.append(base_candidate)
            for al in candidate_als:
                records[al.seg_from.contig.id].extend(toVector(al))
        if len(bases) > 10:
            break
    for rec in records.values():
        print(" ".join([str(rec.read.id), str(len(rec.v)), str(rec.v)]))
    return records
def __init__(self, aligner):
    # type: (Aligner) -> None
    """Set up the test runner: helpers, test registry and global parameters.

    Every class of this module whose name ends with "Test" is registered in
    ``self.tests`` under its ``__name__``.  Global ``params`` and the log
    level of ``sys.stdout`` are overwritten for the test run.
    """
    # params.scores = ComplexScores()
    # params.scores.load(open("flye/config/bin_cfg/pacbio_substitutions.mat", "r"))
    self.aligner = aligner
    self.polisher = Polisher(aligner, aligner.dir_distributor)
    this_module = sys.modules[__name__]
    test_classes = [
        member for member_name, member in inspect.getmembers(this_module)
        if inspect.isclass(member) and member_name.endswith("Test")
    ]
    self.tests = dict((cls.__name__, cls) for cls in test_classes)
    # Parameters used for the tests.
    params.redo_alignments = True
    params.k = 500
    params.l = 1500
    params.min_k_mer_cov = 5
    sys.stdout.level = common.log_params.LogPriority.alignment_files - 1
def testManual(self):
    """Hand-built case: merging the line of contig "abcde" to the right.

    Builds a small dataset, marks unique regions, merges the first line to
    the right and compares the dot-plot alignments of the merged line with a
    fixed expected string.
    """
    dataset = TestDataset("abcdefgabhiDEFjkl")
    dataset.addDisjointig("abcdefgabhiCDEjklabcdefgabhiCDEjkl".upper())
    dataset.generateReads(5, 15, True)
    read_id1 = dataset.addRead("cdefg")
    read_id2 = dataset.addRead("cdefg")
    contig_id1 = dataset.addContig("abcde")
    dataset.addContig("efgabhi")
    lines, dp, reads = dataset.genAll(self.aligner)
    # Look-ups kept from the original; their results are not used below.
    reads[read_id1]
    reads[read_id2]
    start_line = lines[contig_id1]
    UniqueMarker(self.aligner).markAllUnique(lines, reads)
    polisher = Polisher(self.aligner, self.aligner.dir_distributor)
    knotter = LineMerger(lines, polisher, dp)
    dp.printAll(sys.stdout)
    merged = knotter.tryMergeRight(start_line)
    assert merged is not None
    expected = "[((C0_abcde,C1_efgabhi)[0:1100]->(C0_abcde,C1_efgabhi)[3850:4950]:1.000!!!), ((C0_abcde,C1_efgabhi)[3850:4950]->(C0_abcde,C1_efgabhi)[0:1100]:1.000!!!), ((C0_abcde,C1_efgabhi)[0:6050-0]->(C0_abcde,C1_efgabhi)[0:6050-0]:1.000)]"
    assert str(list(dp.allInter(merged.asSegment()))) == expected, \
        str(list(dp.allInter(merged.asSegment())))
def testCase(self, instance):
    # type: (list[str]) -> None
    """Run one extension scenario and check the number of non-circular lines.

    ``instance`` is ``[genome, read count, expected breaks, contig, ...]``:
    item 0 is the genome string, item 1 the number passed to
    ``generateReads``, item 2 the expected count of non-circular lines and
    the remaining items are contigs added to the dataset.
    """
    genome = instance[0]
    dataset = TestDataset(genome, mutation_rate=0.01)
    dataset.addDisjointig(genome + genome.upper())
    dataset.generateReads(int(instance[1]), 25, True)
    ethalon = int(instance[2])
    for contig_str in instance[3:]:
        dataset.addContig(contig_str)
    lines, dp, reads = dataset.genAll(self.aligner)
    UniqueMarker(self.aligner).markAllUnique(lines, reads)
    polisher = Polisher(self.aligner, self.aligner.dir_distributor)
    knotter = LineMerger(lines, polisher, dp)
    extender = LineExtender(self.aligner, knotter, lines.disjointigs, dp)
    extender.updateAllStructures(
        itertools.chain.from_iterable(line.completely_resolved for line in lines))
    # Keep sweeping over the lines until a full sweep extends nothing.
    progressed = True
    while progressed:
        progressed = False
        for line_id in list(lines.items.keys()):
            # A line may have been removed by an earlier merge in this sweep.
            if line_id in lines.items:
                line = lines[line_id]
                dp.printAll(sys.stdout)
                if extender.processLine(line):
                    progressed = True
    print(" ".join([
        str(dataset.translateBack(line, self.aligner)) for line in lines.unique()
    ]))
    print([line.circular for line in lines.unique()])
    breaks = len([line for line in lines.unique() if not line.circular])
    assert breaks == ethalon, str(breaks) + " " + str(ethalon)
def test1(self):
    """Fixed scenario with two contigs; extends all lines and prints the result.

    No assertion is made: the final lines and their circularity flags are
    only printed.
    """
    dataset = TestDataset("abcdefghijklmCDEFGHInopqr", mutation_rate=0.01)
    dataset.addDisjointig("abcdefghijklmCDEFGHInopqrabcd".upper())
    contig_id1 = dataset.addContig("abcde")
    contig_id2 = dataset.addContig("klmCDE")
    dataset.generateReads(5, 25, True)
    lines, dp, reads = dataset.genAll(self.aligner)
    UniqueMarker(self.aligner).markAllUnique(lines, reads)
    first_line = lines[contig_id1]
    second_line = lines[contig_id2]
    polisher = Polisher(self.aligner, self.aligner.dir_distributor)
    knotter = LineMerger(lines, polisher, dp)
    extender = LineExtender(self.aligner, knotter, lines.disjointigs, dp)
    print("New iteration results")
    print(str(dataset.translateBack(first_line, self.aligner)) + " " +
          str(dataset.translateBack(second_line, self.aligner)))
    extender.updateAllStructures(
        itertools.chain.from_iterable(line.completely_resolved for line in lines))
    # Keep sweeping over the lines until a full sweep extends nothing.
    progressed = True
    while progressed:
        progressed = False
        for line_id in list(lines.items.keys()):
            # A line may have been removed by an earlier merge in this sweep.
            if line_id in lines.items:
                line = lines[line_id]
                dp.printAll(sys.stdout)
                if extender.processLine(line):
                    progressed = True
    print(" ".join([
        str(dataset.translateBack(line, self.aligner)) for line in lines.unique()
    ]))
    print([line.circular for line in lines.unique()])
class LineExtender:
    """Extends lines by recruiting reads, polishing sequence and merging lines.

    The extender keeps three per-line structures in sync: the line sequence
    (polished through ``self.polisher``), its ``correct_segments`` and its
    ``completely_resolved`` segments.
    """

    def __init__(self, aligner, knotter, disjointigs, dot_plot, reads=None, recruiter=None):
        # type: (Aligner, LineMerger, DisjointigCollection, LineDotPlot, Optional[ReadCollection], Optional[PairwiseReadRecruiter]) -> None
        """Store the collaborators and build the Polisher and Scorer helpers.

        ``reads`` and ``recruiter`` are optional so that four-argument call
        sites keep working.  Without a recruiter, relevant alignments are
        taken from the line itself; without reads, ``checkAlignments`` is a
        no-op.
        """
        self.aligner = aligner
        self.polisher = Polisher(aligner, aligner.dir_distributor)
        self.knotter = knotter
        self.disjointigs = disjointigs
        self.dot_plot = dot_plot
        self.scorer = Scorer()
        self.reads = reads
        self.recruiter = recruiter

    def checkAlignments(self, seg, als):
        # type: (Segment,List[AlignmentPiece]) -> None
        """Debug check: print read alignments to ``seg`` that are absent from ``als``.

        All reads are realigned to the contig of ``seg``; an alignment that
        overlaps ``seg`` by more than ``params.k`` and whose read id is not
        among ``als`` is reported.
        """
        if self.reads is None:
            # No read collection was supplied, so there is nothing to realign.
            return
        rids = set([al.seg_from.contig.id for al in als])
        for al in self.aligner.localAlign(self.reads, ContigStorage([seg.contig])):
            if al.seg_to.interSize(seg) > params.k and al.seg_from.contig.id not in rids:
                print("Missing alignment " + str(al))

    def processLine(self, line):
        # type: (NewLine) -> int
        """Try to extend ``line`` to the right by merging and read recruitment.

        :return: number of newly recruited reads, plus one if the line was
            merged with another line; 0 means nothing changed.
        """
        line.completely_resolved.mergeSegments(params.k)
        bound = LinePosition(line, line.left())
        new_recruits = 0
        new_line = self.knotter.tryMergeRight(line)
        if new_line is not None:
            self.updateAllStructures(list(new_line.completely_resolved))
            return 1
        self.updateAllStructures(line.completely_resolved)
        while True:
            seg_to_resolve = line.completely_resolved.find(bound.suffix(), params.k)
            if seg_to_resolve is None:
                break
            if line.knot is not None and seg_to_resolve.right == len(line):
                break
            if seg_to_resolve.right <= line.initial[0].seg_to.left + params.k:
                # Segment ends too early to be useful: move the search bound past it.
                bound = LinePosition(line, seg_to_resolve.right - params.k + 1)
                continue
            result = self.attemptCleanResolution(seg_to_resolve)
            total = sum([len(arr) for seg, arr in result])
            new_recruits += total
            if total == 0:
                bound = LinePosition(line, seg_to_resolve.right - params.k + 1)
                continue
            self.updateAllStructures([seg for seg, arr in result])
            new_line = self.knotter.tryMergeRight(line)
            if debugger.debugger is not None:
                debugger.debugger.dump()
            if new_line is not None:
                self.updateAllStructures(list(new_line.completely_resolved))
                return new_recruits + 1
        return new_recruits

    def updateAllStructures(self, interesting_segments):
        # type: (Iterable[Segment]) -> None
        """Re-polish and re-resolve around segments that had reads recruited to.

        :param interesting_segments: a collection of segments that had reads
            recruited to.
        """
        interesting_segments = list(interesting_segments)
        sys.stdout.trace("Updating structures:", interesting_segments)
        # Correct contig sequences, update correct segment storages. Return segments that were corrected.
        corrected = self.correctSequences(interesting_segments)
        # Collect all relevant contig segments, collect all reads that align to relevant segments.
        # Mark resolved bound for each read.
        sys.stdout.trace("Expanding resolved segments:")
        records = self.collectRecords(corrected)  # type: List[LineExtender.Record]
        for rec in records:
            sys.stdout.trace("Record:", rec.line, rec.correct, rec.resolved)
            sys.stdout.trace("Reads from record:")
            for al in rec:
                sys.stdout.trace(al, al.seg_from.contig.alignments)
            sys.stdout.trace(rec.reads)
            sys.stdout.trace(rec.potential_good)
        # Update resolved segments on all relevant contig positions
        self.updateResolved(records)

    def updateResolved(self, records):
        # type: (List[LineExtender.Record]) -> None
        """Prolong resolved segments of all records until nothing changes.

        Afterwards the current and the superseded resolved segments of every
        record are added to ``completely_resolved`` of its line.
        """
        if not records:
            # Nothing to update; also avoids indexing an empty list below.
            return
        ok = True
        while ok:
            sys.stdout.trace("Good reads:")
            # All records created by collectRecords share one good_reads set
            # and one read_bounds dict, so the first record is representative.
            rec = records[0]  # type: LineExtender.Record
            for read_name in rec.good_reads:
                sys.stdout.trace(read_name, rec.read_bounds[read_name])
            ok = False
            for rec in records:
                if self.attemptProlongResolved(rec):
                    sys.stdout.trace("Successfully prolonged resolved:", rec.line,
                                     rec.line.initial, rec.resolved,
                                     rec.line.completely_resolved)
                    ok = True
        for rec in records:
            line = rec.resolved.contig  # type: NewLine
            line.completely_resolved.add(rec.resolved)
            for seg in rec.old_resolved:
                line.completely_resolved.add(seg)
            line.completely_resolved.mergeSegments(params.k - 1)

    def collectRecords(self, corrected):
        # type: (List[Segment]) -> List[LineExtender.Record]
        """Create one Record per resolved segment touched by ``corrected``.

        For every corrected segment all dot-plot copies are visited; for each
        resolved segment found there the relevant reads are realigned to the
        line and stored in a Record.  All records share one ``good_reads``
        set and one ``read_bounds`` dict.
        """
        sys.stdout.trace("Collecting records", corrected)
        read_bounds = dict()
        records = dict()  # type: Dict[Segment, LineExtender.Record]
        good_reads = set()
        for seg in corrected:
            sys.stdout.trace("Oppa initial:", seg)
            seg = seg.expandLeft(params.k)
            sys.stdout.trace("Alignments relevant for", seg, list(self.dot_plot.allInter(seg)))
            for al in self.dot_plot.allInter(seg):
                seg1 = al.matchingSequence().mapSegUp(al.seg_from.contig, seg)
                line = al.seg_from.contig  # type:NewLine
                for seg_correct in line.correct_segments.allInter(al.seg_from):
                    for seg_resolved in line.completely_resolved.allInter(seg_correct):
                        if seg_resolved in records:
                            continue
                        # Start of the next resolved segment (or the line end).
                        if seg_resolved.right == len(line):
                            next_start = len(line)
                        else:
                            next_resolved = line.completely_resolved.find(
                                line.asSegment().suffix(pos=seg_resolved.right), 1)
                            if next_resolved is None:
                                next_start = len(line)
                            else:
                                next_start = next_resolved.left
                        next_start = min(next_start, len(line) - 200)
                        focus = line.segment(
                            max(seg_resolved.left,
                                min(seg_resolved.right - params.k, seg1.left)),
                            min(seg_correct.right, next_start + params.k))
                        if self.recruiter is None:
                            als = list(line.getRelevantAlignmentsFor(focus))
                        else:
                            als = list(self.recruiter.getRelevantAlignments(focus, params.k))
                        if params.check_alignments:
                            self.checkAlignments(focus, als)
                        reads = ContigStorage()
                        for read_al in als:
                            reads.add(read_al.seg_from.contig)
                        als = list(self.aligner.localAlign(reads.unique(), ContigStorage([line])))
                        final_als = []
                        sys.stdout.trace("Focus:", focus, seg_resolved)
                        sys.stdout.trace(als)
                        for read_al in als:
                            # Bring alignments to the reverse complement onto this strand.
                            if read_al.seg_to.contig == line.rc:
                                read_al = read_al.rc
                            if read_al.seg_to.interSize(focus) >= params.k - 100:
                                final_als.append(read_al)
                        sys.stdout.trace(final_als)
                        sys.stdout.trace("Finished realignment of reads")
                        records[seg_resolved] = self.createRecord(
                            seg_resolved, next_start, seg_correct, final_als,
                            good_reads, read_bounds)
        records = list(records.values())  # type: List[LineExtender.Record]
        return records

    def correctSequences(self, interesting_segments):
        # type: (Iterable[Segment]) -> List[Segment]
        """Extend and polish lines to the right of the given segments.

        For each segment the region from ``params.k / 2`` before the end of
        its correct segment up to the next correct segment (or the line end)
        is scheduled for polishing.  Segments are grouped per line (both
        strands together), the line is extended at either end when the region
        reaches within 200 of it, and the merged regions are polished.

        :return: the scheduled segments (forward and reverse-complement),
            kept up to date through line listeners while the line changed.
        """
        interesting_segments = list(interesting_segments)
        to_correct = []
        for seg in interesting_segments:
            line = seg.contig  # type: NewLine
            correct = line.correct_segments.find(seg)
            next_correct = line.correct_segments.find(line.suffix(correct.right), 1)
            if next_correct is None:
                right = len(line)
            else:
                right = min(len(line), next_correct.left + params.k / 2)
            to_correct.append(line.segment(correct.right - params.k / 2, right))
        # groupby needs its input sorted by the grouping key.
        to_correct = sorted(to_correct,
                            key=lambda seg: (basic.Normalize(seg.contig.id), seg.left))
        corrected = []
        for line_id, it in itertools.groupby(
                to_correct,
                key=lambda seg: basic.Normalize(seg.contig.id)):  # type: NewLine, Iterable[Segment]
            it = list(it)
            line = None  # type: NewLine
            forward = SegmentStorage()
            backward = SegmentStorage()
            for seg in it:
                if seg.contig.id != line_id:
                    backward.add(seg)
                    line = seg.contig.rc
                else:
                    forward.add(seg)
                    line = seg.contig
            to_polysh = SegmentStorage()
            to_polysh.addAll(forward).addAll(backward.rc)
            to_polysh.mergeSegments()
            # Listeners keep the storages valid while the line is extended/corrected.
            line.addListener(to_polysh)
            line.addListener(forward)
            line.rc.addListener(backward)
            sys.stdout.trace("Polishing:", to_polysh)
            if (not line.max_extension) and to_polysh[-1].RC().left < 200:
                l = to_polysh[-1].right
                if self.attemptExtend(line):
                    to_polysh.add(line.asSegment().suffix(pos=l))
                    forward.add(line.asSegment().suffix(pos=l))
            if (not line.rc.max_extension) and to_polysh[0].left < 200:
                l = to_polysh[0].RC().right
                if self.attemptExtend(line.rc):
                    to_polysh.rc.add(line.rc.asSegment().suffix(pos=l))
                    backward.add(line.rc.asSegment().suffix(pos=l))
            to_polysh.mergeSegments()
            forward.mergeSegments()
            backward.mergeSegments()
            line.removeListener(to_polysh)
            self.polyshSegments(line, to_polysh)
            line.removeListener(forward)
            line.rc.removeListener(backward)
            corrected.extend(forward)
            corrected.extend(backward)
            line.updateCorrectSegments(line.asSegment())
        return corrected

    def attemptCleanResolution(self, resolved):
        # type: (Segment) -> List[Tuple[Segment, List[AlignmentPiece]]]
        """Recruit reads to the copies of the end of a resolved segment.

        Every read relevant to some copy is assigned, through ``tournament``,
        to the copy it fits best; reads without a clear winner are skipped.

        :return: list of ``(correct segment, recruited alignments)`` pairs;
            empty when some copy has no correct segment.
        """
        # Find all lines that align to at least k nucleotides of the resolved segment.
        sys.stdout.trace("Attempting recruitment:", resolved, resolved.contig,
                         resolved.contig.correct_segments)
        resolved = resolved.suffix(length=min(len(resolved), params.k * 2))
        sys.stdout.trace("Considering resolved subsegment:", resolved)
        line_alignments = filter(
            lambda al: len(al.seg_to) >= params.k and
            resolved.interSize(al.seg_to) > params.k - 30,
            self.dot_plot.allInter(resolved))  # type: List[AlignmentPiece]
        line_alignments = [
            al for al in line_alignments
            if (al.seg_from.right >= al.seg_from.contig.initial[0].seg_to.right + params.k + 20
                and al.seg_to.right >= al.seg_to.contig.initial[0].seg_to.right + params.k + 20)
            or al.isIdentical()
        ]
        sys.stdout.trace("Alternative lines:", map(str, line_alignments))
        for al in line_alignments:
            if not al.isIdentical():
                sys.stdout.trace(al)
                sys.stdout.trace("\n".join(al.asMatchingStrings()))
        line_alignments = [al.reduce(target=resolved) for al in line_alignments]
        read_alignments = []  # type: List[Tuple[AlignmentPiece, Segment]]
        correct_segments = []
        active_segments = set()
        for ltl in line_alignments:
            line = ltl.seg_from.contig  # type: NewLine
            new_copy = line.correct_segments.find(ltl.seg_from)
            # assert new_copy is not None and new_copy.interSize(ltl.seg_from) >= max(len(ltl.seg_from) - 20, params.k), str([ltl, new_copy, str(line.correct_segments)])
            # assert new_copy is not None, str([ltl, line.correct_segments])
            if new_copy is None:
                return []
            if not new_copy.contains(ltl.seg_from):
                sys.stdout.trace("Warning: alignment of resolved segment to uncorrected segment")
                sys.stdout.trace(ltl, new_copy, line.correct_segments)
            correct_segments.append(new_copy)
            # Only copies close enough to the investigated segment may receive reads.
            if ltl.percentIdentity() > 0.95:
                active_segments.add(new_copy)
            if self.recruiter is None:
                relevant_alignments = list(line.getRelevantAlignmentsFor(ltl.seg_from))
            else:
                relevant_alignments = list(
                    self.recruiter.getRelevantAlignments(ltl.seg_from, params.k))
            if params.check_alignments:
                self.checkAlignments(ltl.seg_from, relevant_alignments)
            read_alignments.extend(
                zip(relevant_alignments, itertools.cycle([correct_segments[-1]])))
        # groupby needs its input sorted by the grouping key.
        read_alignments = sorted(read_alignments, key=lambda al: al[0].seg_from.contig.id)
        alignments_by_read = itertools.groupby(read_alignments,
                                               lambda al: al[0].seg_from.contig.id)
        new_recruits = []
        sys.stdout.trace("Starting read recruitment to", map(str, line_alignments))
        for name, it in alignments_by_read:
            als = list(it)  # type: List[Tuple[AlignmentPiece, Segment]]
            read = als[0][0].seg_from.contig  # type: AlignedRead
            sys.stdout.trace("Recruiting read:", read, als)
            ok = False
            for al in als:
                if al[0].seg_to.interSize(resolved) >= params.k:
                    ok = True
                    break
            if not ok:
                sys.stdout.trace("Read does not overlap with resolved", resolved)
                continue
            skip = False
            for al1 in als:
                for al2 in read.alignments:
                    if al1[0].seg_to.inter(al2.seg_to):
                        sys.stdout.trace("Read already recruited", al1, al2)
                        skip = True
                        break
                if skip:
                    break
            if skip:
                continue
            new_als = []
            for al in als:
                if not al[0].contradictingRTC(tail_size=params.bad_end_length):
                    new_als.append(
                        (self.scorer.polyshAlignment(al[0], params.alignment_correction_radius),
                         al[1]))
            if len(new_als) == 0:
                sys.stdout.warn("No noncontradicting alignments of a read")
                winner = None
                seg = None
            else:
                winner, seg = self.tournament(new_als)  # type: AlignmentPiece, Segment
                if winner is None:
                    sys.stdout.trace("No winner")
                else:
                    sys.stdout.trace("Winner for", winner.seg_from.contig.id, ":", winner, seg)
            if winner is not None:
                if seg not in active_segments:
                    sys.stdout.trace(
                        "Winner ignored since winning segment is too different from investigated segment")
                elif winner.percentIdentity() < 0.85:
                    sys.stdout.trace(
                        "Winner ignored since it is too different from winning line")
                else:
                    line = winner.seg_to.contig  # type: NewLine
                    line.addReadAlignment(winner)
                    new_recruits.append((seg, winner))
        new_recruits = sorted(new_recruits,
                              key=lambda rec: (rec[0].contig.id, rec[0].left, rec[0].right))
        sys.stdout.info("Recruited " + str(len(new_recruits)) + " new reads")
        return [(seg, [al for _, al in it])
                for seg, it in itertools.groupby(new_recruits, key=lambda rec: rec[0])]

    def fight(self, c1, c2):
        # type: (Tuple[AlignmentPiece, Segment], Tuple[AlignmentPiece, Segment]) -> Optional[Tuple[AlignmentPiece, Segment]]
        """Compare two candidate placements of the same read.

        Scores come from ``scorer.scoreInCorrectSegments``.  The lower of
        ``s1``/``s2`` wins unless the difference is too small relative to
        ``s12``.

        :return: the winning candidate, or None when there is no clear winner.
        """
        assert c1[0].seg_from.contig == c2[0].seg_from.contig
        s1, s2, s12 = self.scorer.scoreInCorrectSegments(c1[0], c1[1], c2[0], c2[1])
        if s1 is not None and s2 is not None:
            diff = abs(s1 - s2)
        else:
            diff = None
        if s12 is None:
            if s1 is None:
                winner = c2
            else:
                winner = c1
        else:
            if s12 < 25 or (s12 < 40 and abs(s1 - s2) < s12 * 0.8) or (
                    s12 < 100 and abs(s1 - s2) < s12 * 0.5) or abs(s1 - s2) < s12 * 0.3:
                winner = None
            elif s1 > s2:
                winner = c2
            else:
                winner = c1
        if winner is None:
            sys.stdout.trace("Fight:", c1, c2, "Comparison results:", diff, s12, s1, s2,
                             "No winner")
        else:
            sys.stdout.trace("Fight:", c1, c2, "Comparison results:", diff, s12, s1, s2,
                             "Winner:", winner)
        return winner

    def tournament(self, candidates):
        # type: (List[Tuple[AlignmentPiece, Segment]]) -> Tuple[Optional[AlignmentPiece], Optional[Segment]]
        """Pick the candidate that beats every other one.

        A first pass keeps a running champion; with more than two candidates
        the champion must then also beat every candidate it has not already
        beaten.

        :return: the winning ``(alignment, segment)`` candidate, or
            ``(None, None)`` when there is no unambiguous winner.
        """
        best = None
        best_id = None
        wins = []
        for i, candidate in enumerate(candidates):
            if best is None:
                best = candidate
                best_id = i
            else:
                best = self.fight(candidate, best)
                if best is None:
                    best_id = None
                    wins = []
                elif best == candidates[best_id]:
                    wins.append(i)
                else:
                    best_id = i
                    wins = []
        if best is None:
            return None, None
        if len(candidates) > 2:
            for i, candidate in enumerate(candidates):
                if i == best_id or i in wins:
                    continue
                fight_results = self.fight(candidate, best)
                if fight_results is None or fight_results != best:
                    return None, None
        return best

    def attemptExtend(self, line):
        # type: (NewLine) -> bool
        """Extend ``line`` to the right using the reads aligned to its end.

        :return: True if the line became longer, False if it is blocked by a
            knot, has no reads at its end, or polishing added nothing.
        """
        sys.stdout.trace("Attempting to extend:", line)
        if line.knot is not None:
            sys.stdout.trace("Blocked by knot")
            return False
        relevant_reads = list(
            line.read_alignments.allInter(
                line.asSegment().suffix(length=min(params.k, len(line) - 20))))
        sys.stdout.trace("Relevant reads for extending", relevant_reads)
        if len(relevant_reads) == 0:
            return False
        new_contig, relevant_als = self.polisher.polishEnd(relevant_reads)
        if len(new_contig) == len(line):
            return False
        # The polished contig must only append sequence, never rewrite the line.
        assert line.seq == new_contig.prefix(len=len(line)).Seq()
        tmp = len(new_contig) - len(line)
        sys.stdout.trace("Extending", line, "for", tmp)
        line.extendRight(new_contig.suffix(pos=len(line)).Seq(), relevant_als)
        sys.stdout.info("Extended contig", line, "for", tmp)
        sys.stdout.trace("Correct:", line.correct_segments)
        sys.stdout.trace("Reads:")
        sys.stdout.trace(
            list(line.read_alignments.allInter(
                line.asSegment().suffix(length=min(len(line), 2000)))))
        sys.stdout.trace("Sequence:")
        sys.stdout.trace(line.seq)
        return True

    def polyshSegments(self, line, to_polysh):
        # type: (NewLine, Iterable[Segment]) -> List[Segment]
        """Polish the given segments of ``line`` and apply the corrections.

        :return: the (merged, sorted) segments after the line was corrected;
            they are registered as a line listener while it is corrected.
        """
        segs = SegmentStorage()
        corrections = AlignmentStorage()
        line.addListener(segs)
        segs.addAll(to_polysh)
        segs.mergeSegments()
        segs.sort()
        for seg in segs:
            corrections.add(
                self.polisher.polishSegment(seg, list(line.read_alignments.allInter(seg))))
        line.correctSequence(list(corrections))
        line.removeListener(segs)
        return list(segs)

    def updateCorrectSegments(self, line):
        # type: (NewLine) -> None
        """Recompute the correct segments over the whole ``line``."""
        line.updateCorrectSegments(line.asSegment())

    class Record:
        """Read alignments around one resolved segment of a line.

        ``reads`` holds alignments starting at least ``params.bad_end_length``
        into the read; the others go to ``potential_good``.  ``good_reads``
        and ``read_bounds`` are passed in by the creator and may be shared
        between records.
        """

        def __init__(self, resolved, next, correct, good_reads, read_bounds):
            # type: (Segment, int, Segment, Set[str], Dict[str, int]) -> None
            self.line = resolved.contig  # type: NewLine
            self.resolved = resolved
            self.old_resolved = []
            self.next_resolved_start = next
            self.correct = correct
            self.good_reads = good_reads
            self.read_bounds = read_bounds
            self.reads = []  # type: List[AlignmentPiece]
            self.sorted = True
            self.potential_good = []

        def setResolved(self, seg):
            # type: (Segment) -> None
            """Replace or grow the resolved segment and refresh good reads.

            If ``seg`` overlaps the current resolved segment by at least
            ``params.k - 1`` the two are united; otherwise the current one is
            archived in ``old_resolved`` and ``seg`` takes its place.
            """
            if seg.interSize(self.resolved) >= params.k - 1:
                self.resolved = self.resolved.cup(seg)
            else:
                self.old_resolved.append(self.resolved)
                self.resolved = seg
            self.updateGood()

        def add(self, al):
            # type: (AlignmentPiece) -> None
            """Add an alignment, splitting it when ``al.split(100)`` yields several pieces."""
            tmp = list(al.split(100))
            if len(tmp) > 1:
                al = Scorer().polyshAlignment(al, params.alignment_correction_radius)
                for al1 in al.split(100):
                    # Add each piece; adding ``al`` here would store the
                    # unsplit alignment once per piece.
                    self.innerAdd(al1)
            else:
                self.innerAdd(al)

        def innerAdd(self, al):
            # type: (AlignmentPiece) -> None
            """Store one alignment and update the bound of its read."""
            if al.seg_from.left < params.bad_end_length:
                self.potential_good.append(al)
            else:
                self.reads.append(al)
            read = al.seg_from.contig  # type: AlignedRead
            if read.id not in self.read_bounds:
                self.read_bounds[read.id] = len(read)
            if al.rc.seg_to.left < 50:
                self.read_bounds[read.id] = min(self.read_bounds[read.id], al.seg_from.right)
            self.sorted = False

        def addAll(self, als):
            # type: (Iterator[AlignmentPiece]) -> None
            """Add every alignment of ``als`` through ``add``."""
            for al in als:
                self.add(al)

        def sort(self):
            """Sort both alignment lists by descending target start, if needed.

            With this order the leftmost alignment is at the end of the list
            and can be taken with ``pop()``.
            """
            if not self.sorted:
                self.reads = sorted(self.reads, key=lambda al: -al.seg_to.left)
                self.potential_good = sorted(self.potential_good,
                                             key=lambda al: -al.seg_to.left)
                self.sorted = True

        def get(self, num=None, right=None, min_inter=0):
            # type: (int, Segment, int) -> List[AlignmentPiece]
            """Return up to ``num`` unsupported alignments starting left of ``right``.

            Alignments of good reads whose read bound covers the needed
            support are dropped from ``reads``; the others are put back.
            ``num`` defaults to all reads, ``right`` to the end of the
            resolved segment.
            """
            self.sort()
            if num is None:
                num = len(self.reads)
            if right is None:
                right = self.resolved.right
            popped = []
            res = []
            while len(res) < num and len(self.reads) > 0 and self.reads[-1].seg_to.left < right:
                al = self.reads.pop()
                necessary_contig_support = min(len(al.seg_from.contig),
                                               al.seg_from.left + params.k + 100)
                if al.seg_from.contig.id not in self.good_reads or \
                        necessary_contig_support > self.read_bounds[al.seg_from.contig.id]:
                    popped.append(al)
                    if len(al.seg_to) >= min_inter:
                        res.append(al)
            # Restore the popped alignments in their original (descending) order.
            self.reads.extend(popped[::-1])
            return res

        def __iter__(self):
            """Yield unsupported alignments from ``reads``, last list item first."""
            for al in self.reads[::-1]:
                necessary_contig_support = min(len(al.seg_from.contig),
                                               al.seg_from.left + params.k + 100)
                if al.seg_from.contig.id not in self.good_reads or \
                        necessary_contig_support > self.read_bounds[al.seg_from.contig.id]:
                    yield al

        def unsupportedAlignments(self, inter_size):
            """Same as iteration, but with ``inter_size`` in place of ``params.k``."""
            for al in self.reads[::-1]:
                necessary_contig_support = min(len(al.seg_from.contig),
                                               al.seg_from.left + inter_size + 100)
                if al.seg_from.contig.id not in self.good_reads or \
                        necessary_contig_support > self.read_bounds[al.seg_from.contig.id]:
                    yield al

        def updateGood(self):
            """Consume alignments starting inside the resolved segment.

            Alignments from both lists that start at or before
            ``resolved.right - params.k`` are removed; those overlapping the
            resolved segment by at least ``params.k`` mark their read as good.
            """
            self.sort()
            while len(self.reads) > 0 and \
                    self.reads[-1].seg_to.left <= self.resolved.right - params.k:
                al = self.reads.pop()
                if al.seg_to.interSize(self.resolved) >= params.k:
                    if al.seg_from.contig.id not in self.good_reads:
                        sys.stdout.trace("New good read:", al)
                        self.good_reads.add(al.seg_from.contig.id)
                else:
                    sys.stdout.trace("Read does not overlap resolved", al, self.resolved)
            while len(self.potential_good) > 0 and \
                    self.potential_good[-1].seg_to.left <= self.resolved.right - params.k:
                al = self.potential_good.pop()
                if al.seg_to.interSize(self.resolved) >= params.k:
                    if al.seg_from.contig.id not in self.good_reads:
                        sys.stdout.trace("New good read from potential:", al)
                        self.good_reads.add(al.seg_from.contig.id)
                else:
                    sys.stdout.trace("Read does not overlap resolved", al, self.resolved)

        def pop(self):
            """Remove and return the last alignment of ``reads``."""
            return self.reads.pop()

        def __str__(self):
            return str([
                self.resolved, self.correct, self.next_resolved_start, self.reads
            ])

    def createRecord(self, resolved, next_start, correct, als, good_reads, read_bounds):
        # type: (Segment, int, Segment, List[AlignmentPiece], Set[str], Dict[str, int]) -> Record
        """Build a Record for ``resolved`` filled with ``als``."""
        line = resolved.contig  # type: NewLine
        # NOTE(review): ``focus`` is never used; kept because line.segment()
        # may validate its bounds - confirm before removing.
        focus = line.segment(resolved.right - params.k,
                             min(correct.right, next_start + params.k))
        res = self.Record(resolved, next_start, correct, good_reads, read_bounds)
        res.addAll(als)
        res.updateGood()
        return res

    def findResolvedBound(self, rec, inter_size):
        # type: (Record, int) -> int
        """Find where enough unsupported alignments start close together.

        A window of ``params.min_contra_for_break`` unsupported alignments of
        length at least ``inter_size`` is slid along the record until its
        starts are within 50 of each other.

        :return: start of the first alignment of that window, or the line
            length when no such window exists.
        """
        bad_reads = []
        for read in rec.unsupportedAlignments(inter_size):
            if len(read.seg_to) >= inter_size:
                bad_reads.append(read)
            if len(bad_reads) >= params.min_contra_for_break:
                if bad_reads[-1].seg_to.left - bad_reads[0].seg_to.left > 50:
                    bad_reads = bad_reads[1:]
                else:
                    break
        if len(bad_reads) < params.min_contra_for_break:
            sys.stdout.trace("No resolved bound for", rec.resolved)
            return len(rec.line)
        else:
            sys.stdout.trace("Resolved bound for", rec.resolved, ":", bad_reads[0].seg_to.left)
            sys.stdout.trace("Bound caused by read alignments:", map(str, bad_reads))
            return bad_reads[0].seg_to.left

    def attemptProlongResolved(self, rec):
        # type: (Record) -> bool
        """Prolong the resolved segment of ``rec`` to the right if possible.

        :return: True if the resolved segment grew.
        """
        sys.stdout.trace("Working on prolonging", rec.resolved)
        res = self.findAndFilterResolvedBound(rec, params.k)
        if res <= rec.resolved.right:
            sys.stdout.trace("No luck with", rec.resolved, rec.line.correct_segments)
            return False
        sys.stdout.trace("Prolonged", rec.resolved, "to", res)
        rec.setResolved(rec.resolved.contig.segment(rec.resolved.left, res))
        return True

    def findAndFilterResolvedBound(self, rec, sz):
        """Compute the right bound up to which ``rec`` can be resolved.

        The read-based bound is capped by the correct segment and by the next
        resolved segment, then checked against copies of the region through
        ``segmentsWithGoodCopies``.

        :return: the new right end; ``rec.resolved.right`` when no progress
            is possible.
        """
        bound0 = self.findResolvedBound(rec, sz) + params.k * 9 / 10
        bound = min(rec.correct.right, rec.next_resolved_start + sz - 1, bound0)
        res = rec.resolved.right
        if bound > rec.resolved.right:
            sys.stdout.trace("Checking resolved bound against known copies")
            candidates = self.segmentsWithGoodCopies(
                rec.resolved,
                rec.line.segment(max(0, rec.resolved.right - sz), bound), sz)
            sys.stdout.trace("Candidates:", candidates)
            for candidate in candidates:
                if candidate.left == max(0, rec.resolved.right - sz) and \
                        candidate.right > rec.resolved.right:
                    res = candidate.right
        sys.stdout.trace("Final resolved bound for", rec.resolved, " and k =", sz, ":", res)
        return res

    def attemptJump(self, rec):
        # type: (Record) -> bool
        """Try to move the resolved segment of ``rec`` past a bad region.

        :return: True if a new resolved segment ending further right was set.
        """
        bound = self.findAndFilterResolvedBound(rec, params.l)
        bad_segments = SegmentStorage()
        end_tolerance = min(params.bad_end_length, params.k / 2)
        for al in rec:
            if al.seg_to.left > bound:
                break
            if al.seg_from.left > end_tolerance and al.rc.seg_from.left > end_tolerance:
                bad_segments.add(al.seg_to)
        for al in self.dot_plot.allInter(
                rec.line.segment(rec.resolved.right - params.k, bound)):
            if al.seg_from.left > end_tolerance:
                if al.rc.seg_from.left > end_tolerance:
                    bad_segments.add(al.seg_to)
        bad_segments.mergeSegments(params.k - 200)
        sys.stdout.trace("Bad segments:", bad_segments)
        good_segments = bad_segments.reverse(rec.line, params.k - 100).reduce(
            rec.line.segment(rec.resolved.right - params.k, bound))
        for seg in good_segments:
            seg = Segment(seg.contig, max(0, seg.left), seg.right)
            for seg1 in self.segmentsWithGoodCopies(rec.resolved, seg, params.k):
                if len(seg1) >= params.k and seg1.right > rec.resolved.right:
                    rec.setResolved(seg1)
                    return True
        return False

    def segmentsWithGoodCopies(self, resolved, seg, inter_size):
        # type: (Segment, Segment, int) -> List[Segment]
        """Return the parts of ``seg`` not blocked by its copies on other lines.

        Parts of ``seg`` that map to uncorrected regions of a copy, or that
        follow the end of an incoming line, are collected as blocked; the
        complement within ``seg`` is returned.
        """
        als = [
            al for al in self.dot_plot.allInter(seg)
            if al.seg_from.left > 20 or al.rc.seg_to.left > 20 or al.isIdentical()
        ]
        segs = SegmentStorage()
        for al in als:
            line = al.seg_from.contig  # type: NewLine
            if len(al.seg_to) >= inter_size and \
                    al.seg_from.right > line.initial[0].seg_to.left:
                cap = al.seg_from.cap(line.suffix(pos=line.initial[0].seg_to.left))
                incorrect = line.correct_segments.reverse(line, inter_size - 1).reduce(cap)
                matching = al.matchingSequence()
                sys.stdout.trace("Incorrect: ", line, cap, incorrect)
                for seg1 in incorrect:
                    seg2 = matching.mapSegDown(seg.contig, seg1, mapIn=False)
                    sys.stdout.trace("Relevant unpolished k-mer segment alignment:",
                                     seg1, seg2)
                    segs.add(seg2)
            if al.rc.seg_from.left < 50 and al.seg_to.right >= resolved.right - 100:
                segs.add(
                    al.seg_to.contig.suffix(pos=al.seg_to.right).expand(inter_size + 100))
                sys.stdout.trace(
                    "Incoming line:", resolved, seg, al,
                    al.seg_to.contig.suffix(pos=al.seg_to.right).expand(inter_size + 100))
        segs.mergeSegments(inter_size - 1)
        return list(
            segs.reverse(seg.contig,
                         inter_size - 1 - min(100, inter_size / 10)).reduce(seg))
def assemble(args, bin_path): params.bin_path = bin_path start = time.time() cl_params = Params().parse(args) ref = ContigStorage() if cl_params.test: cl_params.reads_file = os.path.dirname(__file__) + "/../../test_dataset/reads.fasta" cl_params.genome_size = 30000 cl_params.dir = os.path.dirname(__file__) + "/../../test_results" ref.loadFromFile(os.path.dirname(__file__) + "/../../test_dataset/axbctbdy.fasta", False) if cl_params.debug: params.save_alignments = True cl_params.check() CreateLog(cl_params.dir) sys.stdout.info("Command line:", " ".join(cl_params.args)) sys.stdout.info("Started") if cl_params.debug: sys.stdout.info("Version:", subprocess.check_output(["git", "rev-parse", "HEAD"])) sys.stdout.info("Modifications:") print subprocess.check_output(["git", "diff"]) sys.stdout.info("Preparing initial state") if cl_params.debug: save_handler = SaveHandler(os.path.join(cl_params.dir, "saves")) else: save_handler = None if cl_params.load_from is not None: # tmp = cl_params.focus sys.stdout.info("Loading initial state from saves") cl_params, aligner, contigs, reads, disjointigs, lines, dot_plot = loadAll(TokenReader(open(cl_params.load_from, "r"))) cl_params.parse(args) # cl_params.focus = tmp knotter = LineMerger(lines, Polisher(aligner, aligner.dir_distributor), dot_plot) extender = LineExtender(aligner, knotter, disjointigs, dot_plot) dot_plot.printAll(sys.stdout) printState(lines) else: aligner = Aligner(DirDistributor(cl_params.alignmentDir())) polisher = Polisher(aligner, aligner.dir_distributor) reads = CreateReadCollection(cl_params.reads_file, cl_params.cut_reads, cl_params.downsample) if cl_params.contigs_file is None: sys.stdout.info("Running Flye") assembly_dir = os.path.join(cl_params.dir, "assembly_initial") reads_file = os.path.join(cl_params.dir, "actual_reads.fasta") reads.print_fasta(open(reads_file, "w")) subprocess.check_call([os.path.join(params.bin_path, "flye"), "--meta", "-o", assembly_dir, "-t", str(cl_params.threads), "--" + 
params.technology + "-raw", reads_file, "--genome-size", str(cl_params.genome_size), "--min-overlap", str(params.k)]) cl_params.set_flye_dir(assembly_dir, cl_params.mode) elif len(cl_params.disjointigs_file_list) == 0: assembly_dir = os.path.join(cl_params.dir, "assembly_initial") reads_file = os.path.join(cl_params.dir, "actual_reads.fasta") reads.print_fasta(open(reads_file, "w")) disjointigs_file = constructDisjointigs(reads, params.expected_size, assembly_dir) # graph_file, contigs_file, disjointigs_file, rep_dir, graph_file_after, contigs_file_after = parseFlyeDir(assembly_dir) cl_params.disjointigs_file_list.append(disjointigs_file) params.min_contra_for_break = 8 disjointigs = CreateDisjointigCollection(cl_params.disjointigs_file_list, cl_params.dir, aligner, reads) all_unique = cl_params.init_file is not None contigs = CreateContigCollection(cl_params.graph_file, cl_params.contigs_file, cl_params.min_cov, aligner, polisher, reads, cl_params.force_unique, all_unique) if cl_params.autoKL: adjustKL(aligner, reads, contigs) if cl_params.init_file is None: ExtendShortContigs(contigs, reads, aligner, polisher, cl_params.read_dump) lines = CreateLineCollection(cl_params.dir, aligner, contigs, disjointigs, reads, cl_params.split) else: lines = LoadLineCollection(cl_params.dir, cl_params.init_file, aligner, contigs, disjointigs, reads, polisher) sys.stdout.info("Constructing dot plot") dot_plot = LineDotPlot(lines, aligner) dot_plot.construct(aligner) # dot_plot.printAll(sys.stdout) sys.stdout.info("Updating sequences and resolved segments.") knotter = LineMerger(lines, Polisher(aligner, aligner.dir_distributor), dot_plot) extender = LineExtender(aligner, knotter, disjointigs, dot_plot) extender.updateAllStructures(itertools.chain.from_iterable(line.completely_resolved for line in lines)) for line in list(lines.unique()): # type: NewLine line.completely_resolved.mergeSegments() if len(line.completely_resolved) == 0: lines.removeLine(line) if cl_params.debug: 
sys.stdout.info( "Saving initial state") try: writer = save_handler.getWriter() sys.stdout.info("Save details:", writer.info) saveAll(writer, cl_params, aligner, contigs, reads, disjointigs, lines, dot_plot) except Exception as e: _, _, tb = sys.exc_info() sys.stdout.warn("Could not write save") traceback.print_tb(tb) sys.stdout.INFO( "Message:", e.message) sys.stdout.trace( "Disjointig alignments") for line in lines: sys.stdout.trace( line.disjointig_alignments) sys.stdout.info("Starting expanding alignment-consensus loop") EACL(aligner, cl_params, contigs, disjointigs, dot_plot, extender, lines, reads, save_handler) dot_plot.printAll(sys.stdout) sys.stdout.trace( "Final result:") lines.printToFasta(open(os.path.join(cl_params.dir, "lines.fasta"), "w")) lines.printKnottedToFasta(open(os.path.join(cl_params.dir, "assembly.fasta"), "w")) printState(lines) sys.stdout.info("Finished") secs = int(time.time() - start) days = secs / 60 / 60 / 24 hours = secs / 60 / 60 % 24 mins = secs / 60 % 60 sys.stdout.info("Finished in %d days, %d hours, %d minutes" % (days, hours, mins)) if cl_params.test: passed = False for al in aligner.dotplotAlign(lines, ref): if len(al) > len(al.seg_to.contig) - 3000: passed = True break if passed: sys.stdout.info("Test passed") else: sys.stdout.info("Test failed")
def main(contigs_file, contig_name, reads_file, dir, k, initial_reads1, initial_reads2): basic.ensure_dir_existance(dir) basic.CreateLog(dir) dd = DirDistributor(os.path.join(dir, "alignments")) aligner = Aligner(dd) contigs = ContigStorage().loadFromFasta(open(contigs_file, "r"), False) # contig = contigs[contig_name].asSegment().prefix(length=2000).asContig() contig = contigs[contig_name] reads = ContigStorage().loadFromFasta(open(reads_file, "r"), False) reads1 = ContigStorage() reads2 = ContigStorage() cnt = 0 for read in reads.unique(): cnt += 1 # if cnt % 2 == 0: if read.id in initial_reads1: reads1.add(read) elif read.id in initial_reads2: reads2.add(read) polisher = Polisher(aligner, dd) contig1 = contig contig2 = contig scorer = Scorer() for i in range(3): diff = 0 print "Iteration", i als1 = fixAlDir(aligner.overlapAlign(reads1.unique(), ContigStorage([contig])), contig) als2 = fixAlDir(aligner.overlapAlign(reads2.unique(), ContigStorage([contig])), contig) contig1 = Contig(polisher.polishSmallSegment(contig.asSegment(), als1).seg_from.Seq(), "1") contig2 = Contig(polisher.polishSmallSegment(contig.asSegment(), als2).seg_from.Seq(), "2") al = aligner.overlapAlign([contig1], ContigStorage([contig2])).next() als1 = fixAlDir(aligner.overlapAlign(reads.unique(), ContigStorage([contig1])), contig1) als1 = filter(lambda al: len(al.seg_to) > len(al.seg_to.contig) - 100, als1) als2 = fixAlDir(aligner.overlapAlign(reads.unique(), ContigStorage([contig2])), contig2) als2 = filter(lambda al: len(al.seg_to) > len(al.seg_to.contig) - 100, als2) als1 = sorted(als1, key = lambda al: al.seg_from.contig.id) als2 = sorted(als2, key = lambda al: al.seg_from.contig.id) reads1 = ContigStorage() reads2 = ContigStorage() dp = scorer.accurateScore(al.matchingSequence(), 10) #1 - al.percentIdentity() als_map = dict() for al in als1: als_map[al.seg_from.contig.id] = [al] for al in als2: if al.seg_from.contig.id in als_map: als_map[al.seg_from.contig.id].append(al) com_res = [] 
diffs = [] for tmp_als in als_map.values(): if len(tmp_als) != 2: continue al1 = tmp_als[0] al2 = tmp_als[1] print al1, al2 assert al1.seg_from.contig == al2.seg_from.contig pi1 = scorer.accurateScore(al1.matchingSequence(), 10) # al1.percentIdentity() pi2 = scorer.accurateScore(al2.matchingSequence(), 10) # al2.percentIdentity() com_res.append((al1, al2, pi1 - pi2)) diffs.append(pi1 - pi2) diffs = sorted(diffs) th1 = diffs[len(diffs) / 4] th2 = diffs[len(diffs) * 3 / 4] print "Thresholds:", th1, th2 for al1, al2, diff in com_res: if diff < th1: reads1.add(al1.seg_from.contig) elif diff > th2: reads2.add(al2.seg_from.contig) # if pi1 > pi2 + dp / 4: # reads1.add(al1.seg_from.contig) # elif pi2 > pi1 + dp / 4: # reads2.add(al2.seg_from.contig) # diff += abs(pi1 - pi2) print float(diff) / len(als1), len(reads1) / 2, len(reads2) / 2 al = aligner.overlapAlign([contig1], ContigStorage([contig2])).next() print al print "\n".join(al.asMatchingStrings2()) for read in reads1: if read.id in initial_reads1: sys.stdout.write(read.id + " ") print "" for read in reads2: if read.id in initial_reads2: sys.stdout.write(read.id + " ") print "" contig1 = prolong(aligner, polisher, contig1, reads1) contig2 = prolong(aligner, polisher, contig2, reads2) contig1.id = "1" contig2.id = "2" out = open(os.path.join(dir, "copies.fasta"), "w") SeqIO.write(contig1, out, "fasta") SeqIO.write(contig2, out, "fasta") out.close() out = open(os.path.join(dir, "reads1.fasta"), "w") for read in reads1.unique(): SeqIO.write(read, out, "fasta") out.close() out = open(os.path.join(dir, "reads2.fasta"), "w") for read in reads2.unique(): SeqIO.write(read, out, "fasta") out.close() print "Finished"
def splitSeg(aligner, seg, mult, all_reads_list): all_reads = ContigStorage() base = seg.asContig() tmp = [] for al in fixAlDir( aligner.overlapAlign(all_reads_list, ContigStorage([base])), base): if len(al.seg_to) < len(base) - 100: continue all_reads.add(al.seg_from.contig) tmp.append(al.seg_from.contig) all_reads_list = tmp split_reads = [] split_contigs = [] for i in range(mult): split_reads.append([]) split_contigs.append(base) cnt = 0 for read in all_reads_list: split_reads[cnt % mult].append(read) polisher = Polisher(aligner, aligner.dir_distributor) for i in range(10): print "Iteration", i split_contigs = [] for reads in split_reads: tmp_als = fixAlDir( aligner.overlapAlign(reads, ContigStorage([base])), base) split_contigs.append( Contig( polisher.polishSmallSegment(base.asSegment(), tmp_als).seg_from.Seq(), str(len(split_contigs)))) bestals = dict() for read in all_reads_list: bestals[read.id] = None for contig in split_contigs: for al in fixAlDir( aligner.overlapAlign(all_reads_list, ContigStorage([contig])), contig): if len(al.seg_to) < len(base) - 100: continue if al.seg_from.contig.id not in bestals: print bestals.keys() print al if bestals[al.seg_from.contig. 
id] is None or al.percentIdentity() > bestals[ al.seg_from.contig.id].percentIdentity(): bestals[al.seg_from.contig.id] = al # als.append(fixAlDir(aligner.overlapAlign(all_reads_list, ContigStorage([contig])), contig)) # als[-1] = sorted(als[-1], key = lambda al: al.seg_from.contig.id) for i in range(mult): split_reads[i] = [] for rid in bestals: al = bestals[rid] if al is None: print "Warning: no alignment for read", rid else: split_reads[int(al.seg_to.contig.id)].append( al.seg_from.contig) print " ".join(map(str, map(len, split_reads))) maxpi = 0 print "pi matrix:" for i in range(mult): for j in range(mult): al = aligner.overlapAlign([split_contigs[i]], ContigStorage([split_contigs[j] ])).next() sys.stdout.write(str(al.percentIdentity()) + " ") maxpi = max(maxpi, al.percentIdentity()) print "" print "Maxpi:", maxpi if maxpi < 0.985: return zip(split_contigs, split_reads) else: return None