Ejemplo n.º 1
0
def merge_paralogs(G):
    """Merge paralog nodes that are linked through shared centroids.

    Every node whose ``'paralog'`` attribute is truthy is filed under each
    centroid listed in its ``'centroid'`` attribute. Centroid groups that
    share a node are joined transitively. Each resulting set of two or more
    nodes is collapsed into a single new node with ``merge_nodes(...,
    check_merge_mems=False)``.

    New node ids are allocated starting above ``max(G.nodes()) + 10``.

    Args:
        G: the graph to modify.

    Returns:
        The graph returned by the last ``merge_nodes`` call, or ``G``
        unchanged when nothing needs merging.
    """
    next_id = max(list(G.nodes())) + 10

    # every centroid -> the paralog nodes that list it
    by_centroid = defaultdict(list)
    for n in G.nodes():
        if not G.nodes[n]['paralog']:
            continue
        for c in G.nodes[n]['centroid']:
            by_centroid[c].append(n)

    # transitively join centroid groups that share at least one node;
    # repeat passes over the leftovers until a pass adds nothing new
    pending = list(by_centroid.values())
    components = []
    while pending:
        component = set(pending[0])
        remaining = pending[1:]
        grew = True
        while grew:
            size_before = len(component)
            unmatched = []
            for group in remaining:
                if component.isdisjoint(group):
                    unmatched.append(group)
                else:
                    component |= set(group)
            remaining = unmatched
            grew = len(component) > size_before
        components.append(component)
        pending = remaining

    # collapse each multi-node component into one node; every intermediate
    # merge gets a fresh id
    for component in components:
        if len(component) <= 1:
            continue
        next_id += 1
        G = merge_nodes(G,
                        component.pop(),
                        component.pop(),
                        next_id,
                        check_merge_mems=False)
        while component:
            G = merge_nodes(G,
                            next_id,
                            component.pop(),
                            next_id + 1,
                            check_merge_mems=False)
            next_id += 1

    return G
Ejemplo n.º 2
0
def merge_paralogs(G):
    """Merge paralog nodes that share the same first centroid.

    Every node whose ``'paralog'`` attribute is truthy is grouped under
    ``G.nodes[node]['centroid'][0]``. Each group of two or more nodes is
    collapsed into one new node with ``merge_nodes``. Groups with a single
    node are left alone.

    New node ids are allocated starting above ``max(G.nodes()) + 10``.

    Args:
        G: the graph to modify.

    Returns:
        The graph returned by the last ``merge_nodes`` call, or ``G``
        unchanged when nothing needs merging.
    """
    node_count = max(list(G.nodes())) + 10

    # group paralog nodes by (first) centroid
    paralog_centroid_dict = defaultdict(list)
    for node in G.nodes():
        if G.nodes[node]['paralog']:
            paralog_centroid_dict[G.nodes[node]['centroid'][0]].append(node)

    # merge paralog nodes that share the same centroid
    for temp_c in paralog_centroid_dict.values():
        # a lone node has nothing to merge with; popping twice from a
        # one-element list would raise IndexError
        if len(temp_c) < 2:
            continue
        node_count += 1
        G = merge_nodes(G, temp_c.pop(), temp_c.pop(), node_count)
        while temp_c:
            G = merge_nodes(G, node_count, temp_c.pop(), node_count + 1)
            node_count += 1

    return G
Ejemplo n.º 3
0
def collapse_families(G,
                      seqid_to_centroid,
                      outdir,
                      family_threshold=0.7,
                      dna_error_threshold=0.99,
                      correct_mistranslations=False,
                      n_cpu=1,
                      quiet=False,
                      distances_bwtn_centroids=None,
                      centroid_to_index=None):
    """Collapse nearby graph nodes whose centroid sequences are similar.

    For each search depth (1, 2, 3), every node of degree > 2 is visited. The
    nodes reached by ``nx.bfs_edges`` within that depth, plus the node itself,
    are clustered with ``single_linkage`` on ``distances_bwtn_centroids``.

    A cluster in which no genome member occurs twice is merged outright.
    Otherwise, a graph of pairwise-allowed merges is built. Its cliques,
    largest first, are re-clustered and merged with
    ``check_merge_mems=False``.

    Args:
        G: gene graph. Nodes carry ``'members'``, ``'seqIDs'``,
            ``'centroid'`` and ``'longCentroidID'`` attributes.
        seqid_to_centroid: sequence id -> centroid id. Used for every
            sequence id that does not contain ``"refound"``.
        outdir: working directory passed to ``iterative_cdhit``.
        family_threshold: threshold passed to ``pwdist_edlib`` when protein
            distances (``dna=False``) are computed.
        dna_error_threshold: threshold passed to ``pwdist_edlib`` when DNA
            distances are computed (``correct_mistranslations=True``).
        correct_mistranslations: if True, distances are always recomputed at
            DNA level and merges use ``multi_centroid=False``.
        n_cpu: passed through to the clustering/distance helpers.
        quiet: currently unused by this function.
        distances_bwtn_centroids: precomputed distance matrix. Only reused
            when ``correct_mistranslations`` is False.
        centroid_to_index: centroid id -> row/column of the distance matrix,
            paired with ``distances_bwtn_centroids``.

    Returns:
        Tuple ``(G, distances_bwtn_centroids, centroid_to_index)``.
    """
    node_count = max(list(G.nodes())) + 10

    depths = [1, 2, 3]
    if correct_mistranslations:
        threshold = [0.99, 0.98, 0.95, 0.9]
    else:
        threshold = [0.99, 0.95, 0.9, 0.8, 0.7, 0.6, 0.5]
    # merged nodes keep all their centroids unless fixing mistranslations
    multi_centroid = not correct_mistranslations

    # precluster for speed
    if correct_mistranslations:
        cdhit_clusters = iterative_cdhit(G,
                                         outdir,
                                         thresholds=threshold,
                                         n_cpu=n_cpu,
                                         quiet=True,
                                         dna=True,
                                         word_length=7,
                                         accurate=False)
        distances_bwtn_centroids, centroid_to_index = pwdist_edlib(
            G, cdhit_clusters, dna_error_threshold, dna=True, n_cpu=n_cpu)
    elif distances_bwtn_centroids is None:
        cdhit_clusters = iterative_cdhit(G,
                                         outdir,
                                         thresholds=threshold,
                                         n_cpu=n_cpu,
                                         quiet=True,
                                         dna=False)
        distances_bwtn_centroids, centroid_to_index = pwdist_edlib(
            G, cdhit_clusters, family_threshold, dna=False, n_cpu=n_cpu)

    # Snapshot the centroid index each sequence starts in. Nodes are merged
    # below, so this must be recorded up front to resolve clashes later.
    seqid_to_index = {}
    for node in G.nodes():
        for sid in G.nodes[node]['seqIDs']:
            if "refound" in sid:
                seqid_to_index[sid] = centroid_to_index[G.nodes[node]
                                                        ["longCentroidID"][1]]
            else:
                seqid_to_index[sid] = centroid_to_index[seqid_to_centroid[sid]]

    def merge_group(graph, group, count, **merge_kwargs):
        # Merge every node of `group` (left untouched) into one node. Each
        # intermediate merge gets a fresh id. Returns the graph and the id
        # of the final merged node.
        temp_c = group.copy()
        count += 1
        graph = merge_nodes(graph,
                            temp_c.pop(),
                            temp_c.pop(),
                            count,
                            multi_centroid=multi_centroid,
                            **merge_kwargs)
        while len(temp_c) > 0:
            graph = merge_nodes(graph,
                                count,
                                temp_c.pop(),
                                count + 1,
                                multi_centroid=multi_centroid,
                                **merge_kwargs)
            count += 1
        return graph, count

    def can_merge(nA, nB):
        # Decide whether two nodes of a conflicting cluster may be merged.
        # A genome present in both nodes is tolerated only when it looks like
        # one gene split in two (mistranslation/frame shift/premature stop).
        mem_inter = sorted(G.nodes[nA]['members'].intersection(
            G.nodes[nB]['members']))
        if len(mem_inter) == 0:
            return True
        # nodes that share a centroid are never merged this way (this test
        # does not depend on the genome, so it is checked once per pair)
        if set(G.nodes[nA]['centroid']).intersection(
                set(G.nodes[nB]['centroid'])):
            return False
        seqids = G.nodes[nA]['seqIDs'] | G.nodes[nB]['seqIDs']
        for imem in mem_inter:
            # starting centroids of this genome's sequences in either node;
            # dtype=int keeps an empty selection usable as an index
            index_ids = np.array([
                seqid_to_index[sid]
                for sid in seqids if sid.split("_")[0] == imem
            ],
                                 dtype=int)
            # any non-zero distance entry among them blocks the merge
            if np.sum(distances_bwtn_centroids[index_ids][:, index_ids]) > 0:
                return False
        return True

    for depth in depths:
        search_space = set(G.nodes())
        while len(search_space) > 0:
            # look for nodes to merge
            temp_node_list = list(search_space)
            removed_nodes = set()
            for node in temp_node_list:
                if node in removed_nodes:
                    continue

                if G.degree[node] <= 2:
                    search_space.remove(node)
                    removed_nodes.add(node)
                    continue

                # find neighbouring nodes and cluster their centroids
                neighbours = [
                    v
                    for u, v in nx.bfs_edges(G, source=node, depth_limit=depth)
                ] + [node]

                clusters = single_linkage(G, distances_bwtn_centroids,
                                          centroid_to_index, neighbours)

                for cluster in clusters:
                    # check if there are any to collapse
                    if len(cluster) <= 1:
                        continue

                    # a conflict means some genome member occurs in more
                    # than one node of the cluster
                    members = [
                        m for n in cluster for m in G.nodes[n]['members']
                    ]

                    if len(members) == len(set(members)):
                        # no conflicts so merge
                        for neig in cluster:
                            removed_nodes.add(neig)
                            search_space.discard(neig)
                        G, node_count = merge_group(G, cluster, node_count)
                        search_space.add(node_count)
                        continue

                    # build a mini graph of allowed pairwise merges
                    tempG = nx.Graph()
                    for nA, nB in itertools.combinations(cluster, 2):
                        if can_merge(nA, nB):
                            tempG.add_edge(nA, nB)

                    # presumably max_clique recurses deeply on large graphs;
                    # raise the interpreter limit, but never lower it
                    sys.setrecursionlimit(
                        max(len(tempG.nodes), 10000, sys.getrecursionlimit()))

                    # merge from largest clique to smallest
                    clique = max_clique(tempG)
                    while len(clique) > 1:
                        clique_clusters = single_linkage(
                            G, distances_bwtn_centroids, centroid_to_index,
                            clique)
                        for clust in clique_clusters:
                            if len(clust) <= 1:
                                continue
                            for neig in clust:
                                removed_nodes.add(neig)
                                search_space.discard(neig)
                            G, node_count = merge_group(
                                G, clust, node_count, check_merge_mems=False)
                            search_space.add(node_count)
                        tempG.remove_nodes_from(clique)
                        clique = max_clique(tempG)

                search_space.discard(node)

    return G, distances_bwtn_centroids, centroid_to_index
Ejemplo n.º 4
0
def collapse_paralogs(G, centroid_contexts, max_context=5, quiet=False):
    """Assign the paralogous nodes of each centroid to clusters and merge them.

    For every centroid in ``centroid_contexts``, the genome member with the
    most entries supplies the reference paralogs; each one seeds a cluster.
    Every entry of another member joins the reference cluster at the
    shortest ``nx.shortest_path_length`` distance. Clusters that already
    contain that member are skipped.

    If no eligible reference is reachable, the entry instead joins the
    reference whose neighbourhood scores highest. A neighbourhood is a
    vector of centroid -> depth built from ``mod_bfs_edges`` up to
    ``max_context``. If no reference is eligible at all, cluster 0 is used.
    Finally, every cluster with two or more nodes is merged via
    ``merge_nodes``.

    Args:
        G: the graph to modify.
        centroid_contexts: centroid -> list of entries with ``entry[0]`` a
            node id and ``entry[1]`` a genome member. Each list is sorted
            in place.
        max_context: depth limit passed to ``mod_bfs_edges``.
        quiet: when True the tqdm progress bar is disabled.

    Returns:
        The updated graph.
    """
    # contexts [centroid] = [[node, member, contig, context], ...]
    node_count = max(list(G.nodes())) + 10

    # first sort by context length, context dist to ensure ties
    #  are broken the same way
    for centroid in centroid_contexts:
        centroid_contexts[centroid] = sorted(centroid_contexts[centroid])

    # give every distinct (first) centroid a slot in the context vectors
    centroid_to_index = {}
    for node in G.nodes():
        centroid = G.nodes[node]['centroid'][0]
        if centroid not in centroid_to_index:
            centroid_to_index[centroid] = len(centroid_to_index)
    ncentroids = len(centroid_to_index)

    def context_vector(start):
        # depth at which each centroid is reached from `start`, 0 if unseen
        vec = np.zeros(ncentroids)
        for _u, node, depth in mod_bfs_edges(G, start, max_context):
            vec[centroid_to_index[G.nodes[node]['centroid'][0]]] = depth
        return vec

    for centroid in tqdm(centroid_contexts, disable=quiet):
        contexts = centroid_contexts[centroid]
        if not contexts:
            # nothing to cluster (max() below would fail on an empty list)
            continue

        member_paralogs = defaultdict(list)
        for para in contexts:
            member_paralogs[para[1]].append(para)

        # the member with the most paralogs provides the reference clusters
        ref_paralogs = max(member_paralogs.items(), key=lambda x: len(x[1]))[1]
        ref_member = ref_paralogs[0][1]
        cluster_dict = defaultdict(set)
        cluster_mems = defaultdict(set)
        for c, ref in enumerate(ref_paralogs):
            cluster_dict[c].add(ref[0])
            cluster_mems[c].add(ref[1])

        # G is not modified until this centroid's merge step, so each
        # reference context only needs computing once
        ref_contexts = {}

        for para in contexts:
            if para[1] == ref_member:
                # this is the reference so skip
                continue

            # first attempt by shortest path
            d_max = np.inf
            best_cluster = None
            for c, ref in enumerate(ref_paralogs):
                if para[1] in cluster_mems[c]:
                    # don't match paralogs of the same isolate
                    continue
                try:
                    d = nx.shortest_path_length(G, ref[0], para[0])
                except nx.NetworkXNoPath:
                    continue
                if d < d_max:
                    d_max = d
                    best_cluster = c

            # if this fails use context
            if d_max == np.inf:
                best_cluster = 0
                s_max = -np.inf
                para_context = context_vector(para[0])
                for c, ref in enumerate(ref_paralogs):
                    if para[1] in cluster_mems[c]:
                        # don't match paralogs of the same isolate
                        continue
                    if c not in ref_contexts:
                        ref_contexts[c] = context_vector(ref[0])
                    ref_context = ref_contexts[c]
                    # score only centroids seen from both nodes
                    shared = (para_context * ref_context) != 0
                    s = np.sum(1 / (1 + np.abs(
                        (para_context - ref_context)[shared])))
                    if s > s_max:
                        s_max = s
                        best_cluster = c

            cluster_dict[best_cluster].add(para[0])
            cluster_mems[best_cluster].add(para[1])

        # merge
        for nodes in cluster_dict.values():
            if len(nodes) < 2:
                continue
            temp_c = list(nodes.copy())
            node_count += 1
            G = merge_nodes(G, temp_c.pop(), temp_c.pop(), node_count)
            while temp_c:
                G = merge_nodes(G, node_count, temp_c.pop(), node_count + 1)
                node_count += 1

    return G
Ejemplo n.º 5
0
def collapse_families(G,
                      outdir,
                      family_threshold=0.7,
                      dna_error_threshold=0.99,
                      correct_mistranslations=False,
                      n_cpu=1,
                      quiet=False,
                      distances_bwtn_centroids=None,
                      centroid_to_index=None):
    """Collapse nearby graph nodes whose centroids are linked in a distance matrix.

    For each search depth (1, 2, 3), every node of degree > 2 is visited.
    The nodes reached by ``nx.bfs_edges`` within that depth (plus the node
    itself when ``correct_mistranslations``) are split into groups. Grouping
    uses ``connected_components`` on the sub-matrix of
    ``distances_bwtn_centroids`` for their ';'-separated centroids. If one
    node's centroids fall into several components, those components are
    fused.

    A group in which no genome member occurs twice is merged outright.
    Otherwise, a graph of pairwise-allowed merges is built and merged clique
    by clique, largest first, with ``check_merge_mems=False``.

    Args:
        G: gene graph. Nodes carry ``'members'``, ``'seqIDs'``,
            ``'centroid'`` (a ';'-separated string) and ``'longCentroidID'``
            attributes.
        outdir: working directory passed to ``iterative_cdhit``.
        family_threshold: threshold passed to ``pwdist_edlib`` when protein
            distances (``dna=False``) are computed.
        dna_error_threshold: threshold passed to ``pwdist_edlib`` when DNA
            distances are computed (``correct_mistranslations=True``).
        correct_mistranslations: if True, distances are always recomputed at
            DNA level and merges use ``multi_centroid=False``.
        n_cpu: passed through to the clustering/distance helpers.
        quiet: currently unused by this function.
        distances_bwtn_centroids: precomputed distance matrix. Only reused
            when ``correct_mistranslations`` is False.
        centroid_to_index: centroid id -> row/column of the distance matrix,
            paired with ``distances_bwtn_centroids``.

    Returns:
        Tuple ``(G, distances_bwtn_centroids, centroid_to_index)``.
    """
    node_count = max(list(G.nodes())) + 10

    depths = [1, 2, 3]
    if correct_mistranslations:
        threshold = [0.99, 0.98, 0.95, 0.9]
    else:
        threshold = [0.99, 0.95, 0.9, 0.8, 0.7, 0.6, 0.5]
    # merged nodes keep all their centroids unless fixing mistranslations
    multi_centroid = not correct_mistranslations

    # precluster for speed
    if correct_mistranslations:
        cdhit_clusters = iterative_cdhit(G,
                                         outdir,
                                         thresholds=threshold,
                                         n_cpu=n_cpu,
                                         quiet=True,
                                         dna=True,
                                         word_length=7,
                                         accurate=False)
        distances_bwtn_centroids, centroid_to_index = pwdist_edlib(
            G, cdhit_clusters, dna_error_threshold, dna=True, n_cpu=n_cpu)
    elif distances_bwtn_centroids is None:
        cdhit_clusters = iterative_cdhit(G,
                                         outdir,
                                         thresholds=threshold,
                                         n_cpu=n_cpu,
                                         quiet=True,
                                         dna=False)
        distances_bwtn_centroids, centroid_to_index = pwdist_edlib(
            G, cdhit_clusters, family_threshold, dna=False, n_cpu=n_cpu)

    # Snapshot the starting centroid index of every sequence, used to
    # resolve clashes after merges. It is built in every mode because the
    # conflict check below uses it regardless of correct_mistranslations.
    # NOTE(review): assumes every node's longCentroidID[1] is a key of
    # centroid_to_index (the conflict check below relies on this too).
    seqid_to_index = {}
    for node in G.nodes():
        for sid in G.nodes[node]['seqIDs']:
            seqid_to_index[sid] = centroid_to_index[G.nodes[node]
                                                    ["longCentroidID"][1]]

    def merge_group(graph, group, count, **merge_kwargs):
        # Merge every node of `group` (left untouched) into one node. Each
        # intermediate merge gets a fresh id. Returns the graph and the id
        # of the final merged node.
        temp_c = group.copy()
        count += 1
        graph = merge_nodes(graph,
                            temp_c.pop(),
                            temp_c.pop(),
                            count,
                            multi_centroid=multi_centroid,
                            **merge_kwargs)
        while len(temp_c) > 0:
            graph = merge_nodes(graph,
                                count,
                                temp_c.pop(),
                                count + 1,
                                multi_centroid=multi_centroid,
                                **merge_kwargs)
            count += 1
        return graph, count

    def can_merge(nA, nB):
        # Decide whether two nodes of a conflicting group may be merged.
        # Every genome present in both nodes must look like one gene split
        # in two (mistranslation/frame shift/premature stop).
        mem_inter = set(G.nodes[nA]['members']).intersection(
            G.nodes[nB]['members'])
        if len(mem_inter) == 0:
            return True
        if distances_bwtn_centroids[centroid_to_index[
                G.nodes[nA]["longCentroidID"][1]], centroid_to_index[
                    G.nodes[nB]["longCentroidID"][1]]] == 0:
            return True
        seqids = G.nodes[nA]['seqIDs'] + G.nodes[nB]['seqIDs']
        for imem in mem_inter:
            tempids = [
                sid for sid in seqids if int(sid.split("_")[0]) == imem
            ]
            for sidA, sidB in itertools.combinations(tempids, 2):
                # pieces too far apart (third '_' field) to be one split gene
                if abs(int(sidA.split("_")[2]) -
                       int(sidB.split("_")[2])) >= len(tempids):
                    return False
                if distances_bwtn_centroids[seqid_to_index[sidA],
                                            seqid_to_index[sidB]] == 1:
                    return False
        return True

    for depth in depths:
        search_space = set(G.nodes())
        while len(search_space) > 0:
            # look for nodes to merge
            temp_node_list = list(search_space)
            removed_nodes = set()
            for node in temp_node_list:
                if node in removed_nodes:
                    continue

                if G.degree[node] <= 2:
                    search_space.remove(node)
                    removed_nodes.add(node)
                    continue

                # find neighbouring nodes
                neighbours = [
                    v
                    for u, v in nx.bfs_edges(G, source=node, depth_limit=depth)
                ]
                if correct_mistranslations:
                    neighbours += [node]

                # one entry per (neighbour, centroid) pair
                index = []
                neigh_array = []
                for neigh in neighbours:
                    for sid in G.nodes[neigh]['centroid'].split(";"):
                        index.append(centroid_to_index[sid])
                        neigh_array.append(neigh)
                index = np.array(index, dtype=int)
                neigh_array = np.array(neigh_array)

                _, labels = connected_components(
                    csgraph=distances_bwtn_centroids[index][:, index],
                    directed=False,
                    return_labels=True)
                # fuse components that contain centroids of the same node
                for neigh in neighbours:
                    l = list(set(labels[neigh_array == neigh]))
                    if len(l) > 1:
                        for i in l[1:]:
                            labels[labels == i] = l[0]

                clusters = [
                    del_dups(list(neigh_array[labels == i]))
                    for i in np.unique(labels)
                ]

                for cluster in clusters:
                    # check if there are any to collapse
                    if len(cluster) <= 1:
                        continue

                    # a conflict means some genome member occurs in more
                    # than one node of the cluster
                    members = [
                        m for n in cluster for m in set(G.nodes[n]['members'])
                    ]

                    if len(members) == len(set(members)):
                        # no conflicts so merge
                        for neig in cluster:
                            removed_nodes.add(neig)
                            search_space.discard(neig)
                        G, node_count = merge_group(G, cluster, node_count)
                        search_space.add(node_count)
                        continue

                    # build a mini graph of allowed pairwise merges
                    tempG = nx.Graph()
                    for nA, nB in itertools.combinations(cluster, 2):
                        if can_merge(nA, nB):
                            tempG.add_edge(nA, nB)

                    # merge from largest clique to smallest
                    clique = max_clique(tempG)
                    while len(clique) > 1:
                        for neig in clique:
                            removed_nodes.add(neig)
                            search_space.discard(neig)
                        G, node_count = merge_group(G,
                                                    clique,
                                                    node_count,
                                                    check_merge_mems=False)
                        search_space.add(node_count)
                        tempG.remove_nodes_from(clique)
                        clique = max_clique(tempG)

                search_space.discard(node)

    return G, distances_bwtn_centroids, centroid_to_index