def collapse_nonbranching_paths(self):
    """Fuse the two edges around every vertex lying on a non-branching path.

    A vertex is collapsed when, in ``self.nx_graph``, it has exactly one
    incoming and exactly one outgoing edge, the two edges are distinct (an
    isolated self-loop is left alone) and both carry the same value under
    ``self.color``.  The fused edge spells the incoming string followed by
    the outgoing string with its first ``len(label)`` symbols removed
    (``label`` being the vertex's entry in ``self.nodeindex2label``).  It is
    added through ``self._add_edge`` and reuses the incoming edge's
    ``self.edge_index`` value.  The vertex is then removed from the graph and
    from ``nodeindex2label`` / ``nodelabel2index``, and the outgoing edge's
    index is removed from ``edge_index2edge``.

    Returns:
        list: the values stored under ``self.edge_index`` in the discarded
        outgoing edges, in the order the vertices were collapsed.
    """
    def sole_edges(vertex):
        # The (single) incoming and outgoing edge of ``vertex``, with keys.
        graph = self.nx_graph
        return (fst_iterable(graph.in_edges(vertex, keys=True)),
                fst_iterable(graph.out_edges(vertex, keys=True)))

    def is_path_interior(vertex):
        graph = self.nx_graph
        if not (graph.in_degree(vertex) == 1 and
                graph.out_degree(vertex) == 1):
            return False
        in_edge, out_edge = sole_edges(vertex)
        in_color = graph.get_edge_data(*in_edge)[self.color]
        out_color = graph.get_edge_data(*out_edge)[self.color]
        # Identical in- and out-edge means a self-loop, which must survive.
        return in_edge != out_edge and in_color == out_color

    collapsed_out_indexes = []
    # Iterate over a snapshot: vertices are deleted from the graph as we go.
    for vertex in list(self.nx_graph):
        if not is_path_interior(vertex):
            continue

        in_edge, out_edge = sole_edges(vertex)
        in_data = self.nx_graph.get_edge_data(*in_edge)
        out_data = self.nx_graph.get_edge_data(*out_edge)
        color = in_data[self.color]
        assert color == out_data[self.color]

        label_len = len(self.nodeindex2label[vertex])
        # Presumably the outgoing string starts with the vertex label; its
        # first label_len symbols are dropped so they are not duplicated.
        fused = tuple(in_data[self.string] +
                      out_data[self.string][label_len:])
        self._add_edge(color=color,
                       string=fused,
                       in_node=in_edge[0],
                       out_node=out_edge[1],
                       in_data=in_data,
                       out_data=out_data,
                       edge_len=len(fused) - label_len,
                       edge_index=in_data[self.edge_index])

        collapsed_out_indexes.append(out_data[self.edge_index])
        self.nx_graph.remove_node(vertex)
        del self.nodelabel2index[self.nodeindex2label.pop(vertex)]
        del self.edge_index2edge[out_data[self.edge_index]]

    self._assert_nx_graph_validity()
    return collapsed_out_indexes
def node_on_nonbranching_path(node):
    """Tell whether ``node`` is an interior vertex of a single-colour path.

    ``self`` is not a parameter: it must be supplied by an enclosing scope.
    Returns True iff ``node`` has exactly one incoming and one outgoing edge
    in ``self.nx_graph``, the two edges differ (an isolated self-loop is not
    an interior vertex), and both edges have the same ``self.color`` value.
    """
    graph = self.nx_graph
    if not (graph.in_degree(node) == 1 and graph.out_degree(node) == 1):
        return False
    sole_in = fst_iterable(graph.in_edges(node, keys=True))
    sole_out = fst_iterable(graph.out_edges(node, keys=True))
    in_color, out_color = (graph.get_edge_data(*edge)[self.color]
                           for edge in (sole_in, sole_out))
    # One edge serving as both in- and out-edge is a loop: keep it.
    return sole_in != sole_out and in_color == out_color
def def_get_frequent_kmers(kmer_index, string_set, min_mult, min_mult_rescue=4):
    """Select k-mers that are frequent or fall inside a frequent region.

    A k-mer is *frequent* when it has at least ``min_mult`` occurrences.
    For every string id, the closed interval between the leftmost and the
    rightmost position of any frequent k-mer in that string is computed.
    A k-mer is returned when it has at least
    ``min(min_mult, min_mult_rescue)`` occurrences and at least one of them
    lies inside its string's interval.  Every frequent k-mer therefore
    qualifies; less frequent ones are "rescued" only inside frequent regions.

    Args:
        kmer_index: non-empty mapping ``kmer -> sized collection of
            (string_id, position)`` occurrences.
        string_set: not used by this function; kept for call compatibility.
        min_mult: occurrence threshold for a frequent k-mer.
        min_mult_rescue: occurrence threshold for a rescued k-mer; capped at
            ``min_mult``.

    Returns:
        dict: ``kmer -> number of occurrences`` for every selected k-mer, in
        ``kmer_index`` iteration order.

    Raises:
        AssertionError: if ``kmer_index`` is empty, or if ``min_mult > 1``
            while the effective rescue threshold is 1.
    """
    min_mult_rescue = min(min_mult, min_mult_rescue)
    if min_mult > 1:
        # otherwise need to handle '?' symbs
        assert min_mult_rescue > 1
    assert len(kmer_index) > 0

    # s_id -> leftmost / rightmost position of a frequent k-mer in that
    # string.  The defaults make the interval empty (2**32 > 0) for strings
    # that contain no frequent k-mer.
    left_freq = defaultdict(lambda: 2**32)
    right_freq = defaultdict(int)
    for pos in kmer_index.values():
        if len(pos) < min_mult:
            continue
        for s_id, p in pos:
            left_freq[s_id] = min(left_freq[s_id], p)
            right_freq[s_id] = max(right_freq[s_id], p)

    ext_frequent_kmers = {}
    for kmer, pos in kmer_index.items():
        if len(pos) < min_mult_rescue:
            continue
        if any(left_freq[s_id] <= p <= right_freq[s_id] for s_id, p in pos):
            ext_frequent_kmers[kmer] = len(pos)
    return ext_frequent_kmers
def from_mono_db(cls, db, monostring_set, mappings=None):
    """Build an instance from ``db`` over a collection of monostrings.

    The gap symbol of the monostring that ``fst_iterable`` yields first from
    ``monostring_set.values()`` becomes the sole neutral symbol (presumably
    all monostrings share one gap symbol — confirm).  Everything else is
    delegated to ``cls.fromDB``, with ``mappings`` passed as
    ``raw_mappings``.
    """
    sample = fst_iterable(monostring_set.values())
    return cls.fromDB(db=db,
                      string_set=monostring_set,
                      neutral_symbs={sample.gap_symb},
                      raw_mappings=mappings)
def correct_kmers(kmer_index_w_pos, string_set, max_ident_diff=0.02):
    """Find the central-monomer substitution with the largest positive gain.

    k-mers (k must be odd) are grouped by their flanks, i.e. the (k-1)-mer
    obtained by deleting the middle symbol.  Only ambiguous occurrences are
    counted: those whose monomer instance has
    ``|identity - sec_identity| < max_ident_diff``.  For every flank pair
    whose ambiguous occurrences show exactly two distinct central monomers,
    ``get_gain`` scores replacing each central monomer by the other.

    Args:
        kmer_index_w_pos: mapping ``kmer (tuple) -> iterable of
            (string_id, position)``.
        string_set: mapping ``string_id -> object`` with ``monoinstances``
            indexable by position.
        max_ident_diff: ambiguity threshold on identity difference.

    Returns:
        ``(kmer, new_central_monomer)`` with the highest gain, or None when
        no candidate has a gain above 0.  The top gain is printed.
    """
    k = len(fst_iterable(kmer_index_w_pos))
    assert k % 2 == 1
    mid = k // 2

    # flanks -> Counter(central monomer -> number of ambiguous occurrences)
    flank2center = defaultdict(Counter)
    for kmer, occurrences in kmer_index_w_pos.items():
        flanks = kmer[:mid] + kmer[mid + 1:]
        for s_id, p in occurrences:
            inst = string_set[s_id].monoinstances[p]
            if abs(inst.identity - inst.sec_identity) < max_ident_diff:
                flank2center[flanks][kmer[mid]] += 1

    best_gain, best_subst = 0, None
    for flanks, centers in flank2center.items():
        if len(centers) != 2:
            continue
        first, second = centers
        with_first = flanks[:mid] + (first,) + flanks[mid:]
        with_second = flanks[:mid] + (second,) + flanks[mid:]
        assert with_first in kmer_index_w_pos
        assert with_second in kmer_index_w_pos
        # Try both directions: first -> second, then second -> first.
        candidates = ((with_first, second), (with_second, first))
        for source_kmer, new_center in candidates:
            gain, _ = get_gain(kmer=source_kmer,
                               subst_monomer=new_center,
                               kmer_index=kmer_index_w_pos,
                               string_set=string_set)
            if gain > best_gain:
                best_gain, best_subst = gain, (source_kmer, new_center)
    print(f'Top gain = {best_gain}')
    return best_subst
def process_complex():
    """Resolve the complex vertex ``u`` by detaching and re-joining its edges.

    This is a closure: ``self``, ``u``, ``in_indexes``, ``out_indexes`` and
    ``nlen`` come from the enclosing scope, which is not visible here.
    Presumably ``in_indexes`` / ``out_indexes`` are the indexes of the edges
    entering / leaving ``u``, and ``nlen`` is the overlap length shared by
    consecutive edge sequences (see the ``in_seq[-nlen:] == out_seq[:nlen]``
    assert) — TODO confirm against the enclosing method.

    Steps:
      1. Move every incoming edge so it ends at a new vertex index, and
         every outgoing edge so it starts at one (``get_new_vertex_index``).
      2. Build connection maps ``ac_s2e`` (in index -> set of out indexes)
         and ``ac_e2s`` (out index -> set of in indexes) from
         ``self.idb_mappings.pairindex2pos``, the single-unique-loop rule and
         the one-unpaired-in / one-unpaired-out rule.
      3. For every connection (i, j), act by partner counts:
         * both sides have exactly one partner: merge the two edges, or turn
           an edge that connects to itself into a loop;
         * both sides have >= 2 partners: add a bridging edge ``i -> j``
           with sequence ``in_seq[-nlen-1:] + [out_seq[nlen]]``;
         * only one side has a single partner: extend that edge by one
           symbol towards the other one.
      4. Assert that ``u`` is now isolated, then delete it from the graph and
         from ``self.node2len``.
    """
    # complex vertex
    # Step 1: detach all incident edges from u.
    for i in in_indexes:
        old_edge = self.index2edge[i]
        new_edge = (old_edge[0], self.get_new_vertex_index(), 0)
        self.move_edge(*old_edge, *new_edge)

    for j in out_indexes:
        old_edge = self.index2edge[j]
        new_edge = (self.get_new_vertex_index(), old_edge[1], 0)
        self.move_edge(*old_edge, *new_edge)

    # Step 2: decide which in-edge continues into which out-edge.
    ac_s2e = defaultdict(set)
    ac_e2s = defaultdict(set)
    paired_in = set()
    paired_out = set()
    # Rule A: pairs recorded in the mappings.
    for e_in in in_indexes:
        for e_out in out_indexes:
            if (e_in, e_out) in self.idb_mappings.pairindex2pos:
                ac_s2e[e_in].add(e_out)
                ac_e2s[e_out].add(e_in)
                paired_in.add(e_in)
                paired_out.add(e_out)

    # Rule B: exactly one loop edge that is in self.unique_edges is connected
    # to the only other in-edge and/or the only other out-edge.
    # NOTE(review): these connections are not added to paired_in/paired_out,
    # so Rule C below still treats those edges as unpaired, whereas
    # _update_unresolved_vertices adds loop pairings to its paired sets —
    # confirm which behaviour is intended.
    loops = set(in_indexes) & set(out_indexes)
    if len(loops) == 1:
        loop = fst_iterable(loops)
        if loop in self.unique_edges:
            rest_in = set(in_indexes) - loops
            rest_out = set(out_indexes) - loops
            if len(rest_in) == 1:
                in_index = fst_iterable(rest_in)
                ac_s2e[in_index].add(loop)
                ac_e2s[loop].add(in_index)
            if len(rest_out) == 1:
                out_index = fst_iterable(rest_out)
                ac_s2e[loop].add(out_index)
                ac_e2s[out_index].add(loop)

    # Rule C: exactly one in-edge and one out-edge left unpaired by Rule A,
    # and all in-edges or all out-edges are unique: connect the two.
    unpaired_in = set(in_indexes) - paired_in
    unpaired_out = set(out_indexes) - paired_out
    if len(unpaired_in) == 1 and len(unpaired_out) == 1:
        if len(set(in_indexes) - self.unique_edges) == 0 or \
                len(set(out_indexes) - self.unique_edges) == 0:
            unpaired_in_single = list(unpaired_in)[0]
            unpaired_out_single = list(unpaired_out)[0]
            ac_s2e[unpaired_in_single].add(unpaired_out_single)
            ac_e2s[unpaired_out_single].add(unpaired_in_single)

    # print(u, ac_s2e, ac_e2s)
    # Step 3: realise every connection.  ``merged`` maps an edge index that
    # was merged away to the index that absorbed it.
    merged = {}
    for i in ac_s2e:
        # The inner iterable is ac_s2e[i] for the original i; i may be
        # rebound to its merge representative below, after which the
        # len(ac_s2e[i]) checks use the representative.
        for j in ac_s2e[i]:
            # print(u, i, j, ac_s2e[i], ac_e2s[j])
            if i in merged:
                i = merged[i]
            if j in merged:
                j = merged[j]
            e_i = self.index2edge[i]
            e_j = self.index2edge[j]
            in_seq = self.edge2seq[i]
            out_seq = self.edge2seq[j]
            # consecutive edges must agree on their nlen-long overlap
            assert in_seq[-nlen:] == out_seq[:nlen]
            if len(ac_s2e[i]) == len(ac_e2s[j]) == 1:
                if e_i != e_j:
                    self.merge_edges(e_i, e_j)
                    merged[j] = i
                else:
                    # isolated loop
                    # the edge connects to itself: make it a self-loop on its
                    # source and append in_seq[nlen] when its ends differ
                    self.move_edge(*e_i, e_i[0], e_i[0])
                    if in_seq[-nlen-1:] != in_seq[:nlen+1]:
                        self.edge2seq[i].append(in_seq[nlen])
            elif len(ac_s2e[i]) >= 2 and len(ac_e2s[j]) >= 2:
                # both sides branch: add a short bridging edge i -> j
                seq = in_seq[-nlen-1:] + [out_seq[nlen]]
                assert len(seq) == nlen + 2
                self.add_edge(i, j, seq)
            elif len(ac_s2e[i]) == 1 and len(ac_e2s[j]) >= 2:
                # extend left edge to the right
                self.move_edge(*e_i, e_i[0], e_j[0])
                seq = in_seq + [out_seq[nlen]]
                self.edge2seq[i] = seq
            elif len(ac_e2s[j]) == 1 and len(ac_s2e[i]) >= 2:
                # extend right edge to the left
                self.move_edge(*e_j, e_i[1], e_j[1])
                seq = [in_seq[-nlen-1]] + out_seq
                self.edge2seq[j] = seq
            else:
                # e.g. a representative with no recorded partners
                assert False

    # Step 4: u must have no edges left; drop it.
    assert self.nx_graph.in_degree(u) == 0
    assert self.nx_graph.out_degree(u) == 0
    self.nx_graph.remove_node(u)
    del self.node2len[u]
def _update_unresolved_vertices(self):
    """Grow ``self.unresolved`` with vertices that cannot be resolved yet.

    Pass 1 visits every vertex ``u`` of ``self.nx_graph`` that is not
    already in ``self.unresolved``:
      * in-degree 1 and out-degree 1: ``u`` is marked unresolved;
      * in-degree >= 2 and out-degree >= 2: the heuristic tries to pair
        every incident in-edge index with an out-edge index.  If any edge
        stays unpaired (a "tip"), ``u`` is marked unresolved;
      * any other degree combination is left as is.

    Pass 2 propagates.  A neighbour ``v`` of an unresolved vertex becomes
    unresolved when ``len(seq) == self.node2len[v] + 1`` for the connecting
    edge's sequence.  The neighbour is the source of an in-edge or the target
    of an out-edge.  Repeated until a round adds nothing.

    Mutates ``self.unresolved`` in place; returns None.
    """
    # Pass 1: classify vertices.
    for u in self.nx_graph.nodes:
        if u in self.unresolved:
            continue
        in_indexes = set(
            [self.edge2index[e_in]
             for e_in in self.nx_graph.in_edges(u, keys=True)])
        out_indexes = set(
            [self.edge2index[e_out]
             for e_out in self.nx_graph.out_edges(u, keys=True)])
        indegree = self.nx_graph.in_degree(u)
        outdegree = self.nx_graph.out_degree(u)
        if indegree == 1 and outdegree == 1:
            # NOTE(review): self_loop is computed but unused because the
            # assert below is commented out.
            self_loop = in_indexes == out_indexes
            # assert self_loop
            self.unresolved.add(u)
        elif indegree >= 2 and outdegree >= 2:
            # do not process anything at all
            # self.unresolved.add(u)

            # process only fully resolved vertices
            # all_ac = self.idb_mappings.get_active_connections()
            # pairs = set()
            # for e_in in in_indexes:
            #     for e_out in out_indexes:
            #         if (e_in, e_out) in all_ac:
            #             pairs.add((e_in, e_out))
            # all_pairs = set((e_in, e_out)
            #                 for e_in in in_indexes
            #                 for e_out in out_indexes)
            # if len(all_pairs - pairs):
            #     self.unresolved.add(u)

            # initial heuristic
            paired_in = set()
            paired_out = set()
            # Rule 1: exactly one loop edge that is in self.unique_edges
            # pairs with the only other in-edge and/or the only other
            # out-edge.
            loops = in_indexes & out_indexes
            if len(loops) == 1:
                loop = fst_iterable(loops)
                if loop in self.unique_edges:
                    rest_in = in_indexes - loops
                    rest_out = out_indexes - loops
                    if len(rest_in) == 1:
                        in_index = fst_iterable(rest_in)
                        paired_in.add(in_index)
                        paired_out.add(loop)
                    if len(rest_out) == 1:
                        out_index = fst_iterable(rest_out)
                        paired_in.add(loop)
                        paired_out.add(out_index)

            # Rule 2: pairs recorded in the mappings.
            for e_in in in_indexes:
                for e_out in out_indexes:
                    if (e_in, e_out) in self.idb_mappings.pairindex2pos:
                        paired_in.add(e_in)
                        paired_out.add(e_out)

            # Rule 3: one in-edge and one out-edge left over, and all
            # in-edges or all out-edges are unique: pair the leftovers.
            unpaired_in = set(in_indexes) - paired_in
            unpaired_out = set(out_indexes) - paired_out
            if len(unpaired_in) == 1 and len(unpaired_out) == 1:
                if len(set(in_indexes) - self.unique_edges) == 0 or \
                        len(set(out_indexes) - self.unique_edges) == 0:
                    unpaired_in_single = list(unpaired_in)[0]
                    unpaired_out_single = list(unpaired_out)[0]
                    paired_in.add(unpaired_in_single)
                    paired_out.add(unpaired_out_single)

            # Any edge still unpaired means u cannot be resolved yet.
            tips = (in_indexes - paired_in) | (out_indexes - paired_out)
            if len(tips):
                self.unresolved.add(u)

    # Pass 2: propagate until a fixpoint.  On the first round ``prev`` is
    # the very same set object as self.unresolved; it is only iterated here
    # (discoveries go into ``new``) and is merged in after the round.
    prev, new = self.unresolved, set()
    while len(prev):
        for u in prev:
            for edge in self.nx_graph.in_edges(u, keys=True):
                index = self.edge2index[edge]
                seq = self.edge2seq[index]
                v = edge[0]
                if v in self.unresolved:
                    continue
                if self.node2len[v] + 1 == len(seq):
                    new.add(v)
            for edge in self.nx_graph.out_edges(u, keys=True):
                index = self.edge2index[edge]
                seq = self.edge2seq[index]
                v = edge[1]
                if v in self.unresolved:
                    continue
                if self.node2len[v] + 1 == len(seq):
                    new.add(v)
        self.unresolved |= new
        prev, new = new, set()
def _find_string_overlaps(string, neutral_symbs, overlap_penalty, edges, addEq):
    """Locate every exact placement of ``string`` on each edge string.

    Each edge string is padded on both sides with
    ``max(0, len(string) - overlap_penalty)`` copies of one neutral symbol.
    This lets ``string`` hang off either end of an edge by up to that many
    positions.  Placements are found with ``edlib.align`` in infix ('HW')
    mode with ``k=0``; ``addEq`` is passed as ``additionalEqualities``.
    Each placement is mapped back to coordinates on the unpadded sequences.

    Args:
        string: query sequence (sized and sliceable).
        neutral_symbs: collection of neutral symbols; ``fst_iterable``
            picks the one used for padding.  Inside an overlap, a mismatch
            is only allowed when one side is a neutral symbol (asserted).
        overlap_penalty: shortens the padding, i.e. how far ``string`` may
            stick out past an edge end is ``len(string) - overlap_penalty``.
        edges: mapping ``edge -> edge_string``; edge strings must be tuples
            because they are concatenated with the tuple padding.
        addEq: additional equalities forwarded to edlib.

    Returns:
        list of ``Overlap`` objects sorted by ``s_en``.  Each has half-open
        intervals ``[s_st, s_en)`` on ``string`` and ``[e_st, e_en)`` on the
        edge string.
    """
    monolen = len(string)
    neutral_symb = fst_iterable(neutral_symbs)
    neutral_run_len = max(0, monolen - overlap_penalty)
    # Repeat the symbol as a tuple element rather than via str repetition.
    # The padding then has exactly neutral_run_len entries for any symbol type
    # (multi-character strings, ints, ...), as the offset arithmetic below
    # requires.
    neutral_run = (neutral_symb,) * neutral_run_len

    overlaps = []
    for edge, edge_string in edges.items():
        edge_len = len(edge_string)
        ext_edge_string = neutral_run + edge_string + neutral_run
        align = edlib.align(string, ext_edge_string,
                            mode='HW', k=0, task='locations',
                            additionalEqualities=addEq)
        for e_st, e_en in align['locations']:
            e_en += 1  # make the end exclusive: [e_st; e_en)
            assert e_en - e_st == monolen
            # Worked example (monolen = 4, edge_len = 4, neutral_run_len = 3):
            # 0123456789
            # --read----
            # ???edge??? => e_s, e_e = [2, 6)
            # left_hanging = max(0, 3-2) = 1
            # right_hanging = max(0, 6-3-4) = 0
            # s_s = left_hanging = 1
            # s_e = 4-right_hanging = 4-0 = 4
            # e_s = max(0, 2-3) = 0
            # e_e = 6-3-0=3
            # 0123456789
            # ----read--
            # ???edge??? => e_s, e_e = [4, 8)
            # left_hanging = max(0, 3-4) = 0
            # right_hanging = max(0, 8-3-4) = 1
            # s_s = 0
            # s_e = 4-1 = 3
            # e_s = max(0, 4-3) = 1
            # e_e = 8-3-1=4
            left_hanging = max(0, neutral_run_len - e_st)
            right_hanging = max(0, e_en - neutral_run_len - edge_len)
            s_st, s_en = left_hanging, monolen - right_hanging
            e_st = max(0, e_st - neutral_run_len)
            e_en = e_en - neutral_run_len - right_hanging
            assert e_en == e_st + s_en - s_st
            assert 0 <= e_st < e_en <= edge_len
            assert 0 <= s_st < s_en <= monolen
            for c1, c2 in zip(string[s_st:s_en], edge_string[e_st:e_en]):
                if c1 != c2:
                    assert c1 in neutral_symbs or c2 in neutral_symbs
            overlaps.append(Overlap(edge=edge, s_st=s_st, s_en=s_en,
                                    e_st=e_st, e_en=e_en))
    overlaps.sort(key=lambda ovl: ovl.s_en)
    return overlaps