예제 #1
0
    def collapse_nonbranching_paths(self):
        """Contract every node lying inside a one-colour non-branching path.

        A node is contracted when all of the following hold:

        * it has exactly one incoming and one outgoing edge;
        * both edges carry the same ``self.color`` attribute;
        * they are not one and the same self-loop.

        Its two edges are replaced, through ``self._add_edge``, by a single
        edge from its predecessor to its successor. The new edge's string is
        the incoming string followed by the outgoing string minus its first
        ``len(label)`` symbols, and the new edge reuses the incoming edge's
        index. The node, its label mappings and the absorbed outgoing edge's
        ``edge_index2edge`` entry are then removed.

        Returns:
            list: ``edge_index`` values of the absorbed outgoing edges, in
            processing order.
        """
        def sole_edges(node):
            # (incoming, outgoing) edge keys of a node with in/out degree 1
            return (fst_iterable(self.nx_graph.in_edges(node, keys=True)),
                    fst_iterable(self.nx_graph.out_edges(node, keys=True)))

        def node_on_nonbranching_path(node):
            graph = self.nx_graph
            if not (graph.in_degree(node) == 1
                    and graph.out_degree(node) == 1):
                return False
            entering, leaving = sole_edges(node)
            entering_color = graph.get_edge_data(*entering)[self.color]
            leaving_color = graph.get_edge_data(*leaving)[self.color]
            # a self-loop is simultaneously the in- and the out-edge: keep it
            return entering != leaving and entering_color == leaving_color

        removed_edge_indexes = []
        # iterate over a snapshot because nodes are removed along the way
        for node in list(self.nx_graph):
            if not node_on_nonbranching_path(node):
                continue

            in_edge, out_edge = sole_edges(node)
            predecessor, successor = in_edge[0], out_edge[1]

            in_data = self.nx_graph.get_edge_data(*in_edge)
            out_data = self.nx_graph.get_edge_data(*out_edge)

            color, other_color = in_data[self.color], out_data[self.color]
            assert color == other_color

            head = in_data[self.string]
            tail = out_data[self.string]
            label_len = len(self.nodeindex2label[node])
            # skip the first len(label) symbols of the outgoing string
            # (presumably the node label, already at the end of `head`)
            merged_string = tuple(head + tail[label_len:])

            self._add_edge(color=color,
                           string=merged_string,
                           in_node=predecessor,
                           out_node=successor,
                           in_data=in_data,
                           out_data=out_data,
                           edge_len=len(merged_string) - label_len,
                           edge_index=in_data[self.edge_index])

            absorbed = out_data[self.edge_index]
            removed_edge_indexes.append(absorbed)

            self.nx_graph.remove_node(node)
            label = self.nodeindex2label[node]
            del self.nodeindex2label[node]
            del self.nodelabel2index[label]
            del self.edge_index2edge[absorbed]

        self._assert_nx_graph_validity()
        return removed_edge_indexes
예제 #2
0
        def node_on_nonbranching_path(node):
            """Tell whether ``node`` is an interior node of a one-colour path.

            True only when ``node`` has exactly one incoming and one outgoing
            edge, both edges share the same ``self.color`` attribute, and the
            two edges are not the same self-loop.
            """
            graph = self.nx_graph
            if not (graph.in_degree(node) == 1
                    and graph.out_degree(node) == 1):
                return False

            entering = fst_iterable(graph.in_edges(node, keys=True))
            leaving = fst_iterable(graph.out_edges(node, keys=True))

            entering_color = graph.get_edge_data(*entering)[self.color]
            leaving_color = graph.get_edge_data(*leaving)[self.color]

            # a self-loop is both the in- and the out-edge; never remove it
            is_self_loop = entering == leaving
            return not is_self_loop and entering_color == leaving_color
예제 #3
0
def def_get_frequent_kmers(kmer_index, string_set,
                           min_mult, min_mult_rescue=4):
    """Select frequent k-mers and rescue less frequent ones between them.

    A k-mer is *frequent* when it has at least ``min_mult`` occurrences.
    For every string id, the frequent k-mers span from their leftmost to
    their rightmost occurrence position. A k-mer with at least
    ``min(min_mult, min_mult_rescue)`` occurrences is kept when at least one
    of its occurrences falls inside the span of the same string id (this
    always includes every frequent k-mer).

    Args:
        kmer_index: non-empty mapping ``kmer -> list of (s_id, pos)``.
        string_set: unused; kept so existing callers keep working.
        min_mult: occurrence threshold for frequent k-mers.
        min_mult_rescue: occurrence threshold for rescued k-mers; it is
            capped at ``min_mult``.

    Returns:
        dict mapping each selected k-mer to its number of occurrences,
        in ``kmer_index`` iteration order.

    Raises:
        AssertionError: if ``kmer_index`` is empty, or if ``min_mult > 1``
            while the effective rescue threshold is 1.
    """
    min_mult_rescue = min(min_mult, min_mult_rescue)
    if min_mult > 1:
        assert min_mult_rescue > 1  # otherwise need to handle '?' symbs
    assert len(kmer_index) > 0

    # s_id -> leftmost / rightmost position of any frequent k-mer
    left_freq, right_freq = {}, {}
    for pos in kmer_index.values():
        if len(pos) < min_mult:
            continue
        for s_id, p in pos:
            left_freq[s_id] = min(left_freq.get(s_id, p), p)
            right_freq[s_id] = max(right_freq.get(s_id, p), p)

    ext_frequent_kmers = {}
    for kmer, pos in kmer_index.items():
        if len(pos) < min_mult_rescue:
            continue
        # strings without any frequent k-mer have no span to fall into
        if any(s_id in left_freq and left_freq[s_id] <= p <= right_freq[s_id]
               for s_id, p in pos):
            ext_frequent_kmers[kmer] = len(pos)
    return ext_frequent_kmers
예제 #4
0
 def from_mono_db(cls, db, monostring_set, mappings=None):
     """Alternate constructor from a DB and a set of monomer strings.

     The ``gap_symb`` of the monostring returned by ``fst_iterable`` is used
     as the only neutral symbol; everything is forwarded to ``cls.fromDB``
     (``mappings`` is passed on as ``raw_mappings``).
     """
     first_monostring = fst_iterable(monostring_set.values())
     gap = first_monostring.gap_symb
     return cls.fromDB(db=db,
                       string_set=monostring_set,
                       neutral_symbs={gap},
                       raw_mappings=mappings)
예제 #5
0
def correct_kmers(kmer_index_w_pos, string_set,
                  max_ident_diff=0.02):
    """Find the most profitable substitution of a k-mer's central symbol.

    The central symbol of every k-mer (``k`` must be odd) is removed to form
    a (k-1)-mer context.

    Each occurrence whose monomer instance is ambiguous votes for the central
    symbol seen in that context. An instance is ambiguous when
    ``abs(identity - sec_identity) < max_ident_diff``.

    Contexts that received votes for exactly two central symbols are
    candidates. Both substitution directions are scored with ``get_gain``,
    and the highest strictly positive gain wins.

    Args:
        kmer_index_w_pos: mapping ``kmer (tuple) -> list of (s_id, pos)``.
        string_set: mapping ``s_id -> string``; each string exposes a
            ``monoinstances`` sequence whose items have ``identity`` and
            ``sec_identity`` attributes.
        max_ident_diff: ambiguity threshold on the identity difference.

    Returns:
        ``(kmer, new_central_symbol)`` for the best substitution, or
        ``None`` if no candidate has a positive gain. The best gain is
        printed as a side effect.
    """
    k = len(fst_iterable(kmer_index_w_pos))
    assert k % 2 == 1
    k2 = k//2
    km1mer2central = defaultdict(Counter)
    for kmer, pos in kmer_index_w_pos.items():
        km1mer = kmer[:k2] + kmer[k2+1:]
        central = kmer[k2]
        for s_id, p in pos:
            string = string_set[s_id]
            # NOTE(review): if `p` is the k-mer start, this inspects the
            # first monomer, not the central one (which would be
            # monoinstances[p + k2]) -- confirm the intent.
            mi = string.monoinstances[p]
            if abs(mi.identity - mi.sec_identity) < max_ident_diff:
                km1mer2central[km1mer][central] += 1

    top_gain, top_subst = 0, None
    for km1mer, central_counter in km1mer2central.items():
        if len(central_counter) != 2:
            continue
        # only the two central symbols matter here, not their vote counts
        mono_ind1, mono_ind2 = central_counter
        kmer1 = km1mer[:k2] + (mono_ind1,) + km1mer[k2:]
        kmer2 = km1mer[:k2] + (mono_ind2,) + km1mer[k2:]
        assert kmer1 in kmer_index_w_pos
        assert kmer2 in kmer_index_w_pos

        gain12, _ = get_gain(kmer=kmer1,
                             subst_monomer=mono_ind2,
                             kmer_index=kmer_index_w_pos,
                             string_set=string_set)

        gain21, _ = get_gain(kmer=kmer2,
                             subst_monomer=mono_ind1,
                             kmer_index=kmer_index_w_pos,
                             string_set=string_set)
        if gain12 > top_gain:
            top_gain = gain12
            top_subst = (kmer1, mono_ind2)
        if gain21 > top_gain:
            top_gain = gain21
            top_subst = (kmer2, mono_ind1)
    print(f'Top gain = {top_gain}')
    return top_subst
예제 #6
0
        def process_complex():
            """Resolve the complex vertex ``u`` by rewiring its edges.

            Closure helper: it reads ``u``, ``nlen``, ``in_indexes`` and
            ``out_indexes`` from the enclosing scope (not shown here).

            Steps:
            1. Move every incident edge of ``u`` onto a vertex obtained from
               ``self.get_new_vertex_index()``.
            2. Build the connection maps ``ac_s2e`` (in-edge index ->
               out-edge indexes) and ``ac_e2s`` (reverse). They come from
               ``idb_mappings.pairindex2pos`` plus two fallback heuristics.
            3. For every connected pair, merge, extend or bridge the edges.
            4. Assert that ``u`` is isolated, then delete it from
               ``nx_graph`` and ``node2len``.
            """
            # complex vertex
            # detach incoming edges: keep the source, give a fresh target
            for i in in_indexes:
                old_edge = self.index2edge[i]
                new_edge = (old_edge[0], self.get_new_vertex_index(), 0)
                self.move_edge(*old_edge, *new_edge)

            # detach outgoing edges: fresh source, keep the target
            for j in out_indexes:
                old_edge = self.index2edge[j]
                new_edge = (self.get_new_vertex_index(), old_edge[1], 0)
                self.move_edge(*old_edge, *new_edge)

            # ac_s2e: in-edge index -> connected out-edge indexes
            # ac_e2s: out-edge index -> connected in-edge indexes
            ac_s2e = defaultdict(set)
            ac_e2s = defaultdict(set)
            paired_in = set()
            paired_out = set()
            # connections supported by idb_mappings pairs
            for e_in in in_indexes:
                for e_out in out_indexes:
                    if (e_in, e_out) in self.idb_mappings.pairindex2pos:
                        ac_s2e[e_in].add(e_out)
                        ac_e2s[e_out].add(e_in)
                        paired_in.add(e_in)
                        paired_out.add(e_out)

            # heuristic: a single unique edge that is both in and out
            # (a loop) is connected to the only other in-edge and/or the
            # only other out-edge
            loops = set(in_indexes) & set(out_indexes)
            if len(loops) == 1:
                loop = fst_iterable(loops)
                if loop in self.unique_edges:
                    rest_in = set(in_indexes) - loops
                    rest_out = set(out_indexes) - loops
                    if len(rest_in) == 1:
                        in_index = fst_iterable(rest_in)
                        ac_s2e[in_index].add(loop)
                        ac_e2s[loop].add(in_index)
                    if len(rest_out) == 1:
                        out_index = fst_iterable(rest_out)
                        ac_s2e[loop].add(out_index)
                        ac_e2s[out_index].add(loop)

            # heuristic: if exactly one in-edge and one out-edge lack an
            # idb_mappings pair, and all in-edges or all out-edges are
            # unique, connect those two
            unpaired_in = set(in_indexes) - paired_in
            unpaired_out = set(out_indexes) - paired_out
            if len(unpaired_in) == 1 and len(unpaired_out) == 1:
                if len(set(in_indexes) - self.unique_edges) == 0 or \
                        len(set(out_indexes) - self.unique_edges) == 0:
                    unpaired_in_single = list(unpaired_in)[0]
                    unpaired_out_single = list(unpaired_out)[0]
                    ac_s2e[unpaired_in_single].add(unpaired_out_single)
                    ac_e2s[unpaired_out_single].add(unpaired_in_single)

            # print(u, ac_s2e, ac_e2s)
            # merged: edge index absorbed by merge_edges -> index it went
            # into. NOTE(review): `i`/`j` are remapped through `merged` only
            # one level, and the outer `i` stays remapped for the rest of the
            # inner loop -- confirm that chains of merges cannot occur.
            merged = {}
            for i in ac_s2e:
                for j in ac_s2e[i]:
                    # print(u, i, j, ac_s2e[i], ac_e2s[j])
                    if i in merged:
                        i = merged[i]
                    if j in merged:
                        j = merged[j]
                    e_i = self.index2edge[i]
                    e_j = self.index2edge[j]
                    in_seq = self.edge2seq[i]
                    out_seq = self.edge2seq[j]
                    # the two sequences must overlap on nlen symbols
                    assert in_seq[-nlen:] == out_seq[:nlen]
                    if len(ac_s2e[i]) == len(ac_e2s[j]) == 1:
                        # one-to-one connection
                        if e_i != e_j:
                            self.merge_edges(e_i, e_j)
                            merged[j] = i
                        else:
                            # isolated loop
                            self.move_edge(*e_i, e_i[0], e_i[0])
                            # in_seq is self.edge2seq[i], so this appends
                            # to it in place
                            if in_seq[-nlen-1:] != in_seq[:nlen+1]:
                                self.edge2seq[i].append(in_seq[nlen])
                    elif len(ac_s2e[i]) >= 2 and len(ac_e2s[j]) >= 2:
                        # both sides branch: add a short bridging edge made
                        # of the last nlen+1 symbols of i plus one of j
                        seq = in_seq[-nlen-1:] + [out_seq[nlen]]
                        assert len(seq) == nlen + 2
                        self.add_edge(i, j, seq)
                    elif len(ac_s2e[i]) == 1 and len(ac_e2s[j]) >= 2:
                        # extend left edge to the right
                        self.move_edge(*e_i, e_i[0], e_j[0])
                        seq = in_seq + [out_seq[nlen]]
                        self.edge2seq[i] = seq
                    elif len(ac_e2s[j]) == 1 and len(ac_s2e[i]) >= 2:
                        # extend right edge to the left
                        self.move_edge(*e_j, e_i[1], e_j[1])
                        seq = [in_seq[-nlen-1]] + out_seq
                        self.edge2seq[j] = seq
                    else:
                        assert False

            # all incident edges were moved away in the loops above
            assert self.nx_graph.in_degree(u) == 0
            assert self.nx_graph.out_degree(u) == 0
            self.nx_graph.remove_node(u)
            del self.node2len[u]
예제 #7
0
    def _update_unresolved_vertices(self):
        """Extend ``self.unresolved`` with vertices that cannot be resolved.

        Pass 1 looks at every vertex not yet in ``self.unresolved``:

        * in-degree 1 and out-degree 1 -> marked unresolved;
        * in-degree >= 2 and out-degree >= 2 -> marked unresolved if, after
          pairing incoming with outgoing edge indexes (see below), some
          incident edge index is left without a partner;
        * any other degree combination is left alone.

        Pairing rules for the second case:

        * an ``(in, out)`` pair is paired when it is contained in
          ``self.idb_mappings.pairindex2pos``;
        * a single loop edge index (present among both in- and out-edges)
          that is in ``self.unique_edges`` is paired with the only other
          in-edge and/or the only other out-edge;
        * if exactly one in-edge and one out-edge remain unpaired, and all
          in-edges or all out-edges are unique, those two are paired.

        Pass 2 propagates the mark. Repeatedly, a neighbour ``v`` of a
        newly marked vertex, reached over an in- or out-edge, is marked too
        when that edge's sequence is exactly one symbol longer than
        ``self.node2len[v]``.
        """
        graph = self.nx_graph

        for u in graph.nodes:
            if u in self.unresolved:
                continue
            in_indexes = {self.edge2index[e_in]
                          for e_in in graph.in_edges(u, keys=True)}
            out_indexes = {self.edge2index[e_out]
                           for e_out in graph.out_edges(u, keys=True)}
            indegree = graph.in_degree(u)
            outdegree = graph.out_degree(u)
            if indegree == 1 and outdegree == 1:
                self.unresolved.add(u)
            elif indegree >= 2 and outdegree >= 2:
                # Heuristic pairing. Earlier alternatives were to mark every
                # such vertex unresolved, or to require every in/out pair to
                # be an active connection.
                paired_in, paired_out = set(), set()

                loops = in_indexes & out_indexes
                if len(loops) == 1:
                    loop = fst_iterable(loops)
                    if loop in self.unique_edges:
                        rest_in = in_indexes - loops
                        rest_out = out_indexes - loops
                        if len(rest_in) == 1:
                            paired_in.add(fst_iterable(rest_in))
                            paired_out.add(loop)
                        if len(rest_out) == 1:
                            paired_in.add(loop)
                            paired_out.add(fst_iterable(rest_out))

                for e_in in in_indexes:
                    for e_out in out_indexes:
                        if (e_in, e_out) in self.idb_mappings.pairindex2pos:
                            paired_in.add(e_in)
                            paired_out.add(e_out)

                unpaired_in = in_indexes - paired_in
                unpaired_out = out_indexes - paired_out
                if len(unpaired_in) == 1 and len(unpaired_out) == 1 and (
                        in_indexes.issubset(self.unique_edges) or
                        out_indexes.issubset(self.unique_edges)):
                    paired_in |= unpaired_in
                    paired_out |= unpaired_out

                tips = (in_indexes - paired_in) | (out_indexes - paired_out)
                if tips:
                    self.unresolved.add(u)

        def one_symbol_longer(edge, v):
            # True when the edge's sequence is len(node v) + 1 symbols long
            seq = self.edge2seq[self.edge2index[edge]]
            return self.node2len[v] + 1 == len(seq)

        frontier = set(self.unresolved)
        while frontier:
            reached = set()
            for u in frontier:
                for edge in graph.in_edges(u, keys=True):
                    v = edge[0]
                    if v not in self.unresolved and one_symbol_longer(edge, v):
                        reached.add(v)
                for edge in graph.out_edges(u, keys=True):
                    v = edge[1]
                    if v not in self.unresolved and one_symbol_longer(edge, v):
                        reached.add(v)
            self.unresolved |= reached
            frontier = reached
예제 #8
0
def _find_string_overlaps(string,
                          neutral_symbs,
                          overlap_penalty,
                          edges,
                          addEq):
    """Find zero-edit placements of ``string`` on (padded) edge strings.

    Each edge string is padded on both sides with
    ``max(0, len(string) - overlap_penalty)`` copies of one neutral symbol,
    so ``string`` may hang off either end of the edge. Every hit reported by
    ``edlib.align`` (mode ``'HW'``, ``k=0``, ``additionalEqualities=addEq``)
    is converted back to coordinates on ``string`` and on the unpadded edge
    string.

    Args:
        string: query sequence of symbols.
        neutral_symbs: neutral symbols; the one returned by ``fst_iterable``
            is used for padding. Mismatches inside a reported overlap must
            involve one of them (asserted).
        overlap_penalty: controls the padding length. NOTE(review): with
            ``overlap_penalty <= 0`` a hit could lie entirely in the padding
            and trip the range asserts -- callers presumably pass >= 1.
        edges: mapping ``edge -> edge_string``. Edge strings are
            concatenated with tuples, so presumably tuples -- confirm.
        addEq: extra equalities passed to edlib (presumably making neutral
            symbols match anything -- confirm at the call site).

    Returns:
        list of ``Overlap`` sorted by ``s_en``. ``[s_st, s_en)`` is the
        half-open range on ``string`` and ``[e_st, e_en)`` the equally long
        range on the edge string.
    """
    monolen = len(string)
    neutral_symb = fst_iterable(neutral_symbs)
    neutral_run_len = max(0, monolen - overlap_penalty)
    # Repeat the symbol as one tuple element: `neutral_symb * n` only gives
    # the intended padding when the symbol is a one-character string.
    neutral_run = (neutral_symb,) * neutral_run_len
    overlaps = []
    for edge, edge_string in edges.items():
        ext_edge_string = neutral_run + edge_string + neutral_run
        align = edlib.align(string,
                            ext_edge_string,
                            mode='HW',
                            k=0,
                            task='locations',
                            additionalEqualities=addEq)
        for hit_st, hit_en in align['locations']:
            hit_en += 1  # edlib end is inclusive; make it half-open
            assert hit_en - hit_st == monolen

            # Worked examples: neutral_run_len = 3, len(edge) = len(read) = 4
            #
            # 0123456789
            # --read----
            # ???edge???  hit [2, 6)
            #   left_hanging  = max(0, 3-2)     = 1
            #   right_hanging = max(0, 6-3-4)   = 0
            #   s: [1, 4)     e: [max(0, 2-3), 6-3-0) = [0, 3)
            #
            # 0123456789
            # ----read--
            # ???edge???  hit [4, 8)
            #   left_hanging  = max(0, 3-4)     = 0
            #   right_hanging = max(0, 8-3-4)   = 1
            #   s: [0, 3)     e: [max(0, 4-3), 8-3-1) = [1, 4)
            left_hanging = max(0, neutral_run_len - hit_st)
            right_hanging = \
                max(0, hit_en - neutral_run_len - len(edge_string))
            s_st, s_en = left_hanging, monolen - right_hanging
            e_st = max(0, hit_st - neutral_run_len)
            e_en = hit_en - neutral_run_len - right_hanging
            assert e_en == e_st + s_en - s_st
            assert 0 <= e_st < e_en <= len(edge_string)
            assert 0 <= s_st < s_en <= monolen
            # inside the overlap, differences are only allowed on neutrals
            for c1, c2 in zip(string[s_st:s_en],
                              edge_string[e_st:e_en]):
                if c1 != c2:
                    assert c1 in neutral_symbs or \
                        c2 in neutral_symbs
            overlaps.append(Overlap(edge=edge,
                                    s_st=s_st, s_en=s_en,
                                    e_st=e_st, e_en=e_en))
    overlaps.sort(key=lambda overlap: overlap.s_en)
    return overlaps