Example 1
import numpy as np
from scipy.sparse.csgraph import connected_components

# del_dups (order-preserving de-duplication) is a helper from the same module.


def single_linkage(G, distances_bwtn_centroids, centroid_to_index, neighbours):
    # Single-linkage clustering of the given neighbour nodes: nodes whose
    # centroids fall in the same connected component of the thresholded
    # distance matrix end up in the same cluster.
    index = []
    neigh_array = []
    for neigh in neighbours:
        for sid in G.nodes[neigh]['centroid']:
            index.append(centroid_to_index[sid])
            neigh_array.append(neigh)
    index = np.array(index, dtype=int)
    neigh_array = np.array(neigh_array)

    n_components, labels = connected_components(
        csgraph=distances_bwtn_centroids[index][:, index],
        directed=False,
        return_labels=True)
    # labels = labels[index]
    for neigh in neighbours:
        l = list(set(labels[neigh_array == neigh]))
        if len(l) > 1:
            for i in l[1:]:
                labels[labels == i] = l[0]

    clusters = [
        del_dups(list(neigh_array[labels == i])) for i in np.unique(labels)
    ]

    return clusters
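
A minimal usage sketch on hypothetical toy data, assuming single_linkage above is in scope; the del_dups stub mimics the module's order-preserving de-duplication helper:

import networkx as nx
import numpy as np
from scipy.sparse import csr_matrix

def del_dups(seq):
    # stand-in for the module helper: order-preserving de-duplication
    seen = set()
    return [x for x in seq if not (x in seen or seen.add(x))]

G = nx.Graph()
G.add_node(1, centroid=["c1"])
G.add_node(2, centroid=["c2"])
G.add_node(3, centroid=["c3"])

centroid_to_index = {"c1": 0, "c2": 1, "c3": 2}
# a nonzero entry marks a centroid pair within the clustering threshold
distances_bwtn_centroids = csr_matrix(np.array([[0, 1, 0],
                                                [1, 0, 0],
                                                [0, 0, 0]]))

clusters = single_linkage(G, distances_bwtn_centroids, centroid_to_index,
                          neighbours=[1, 2, 3])
print(clusters)  # nodes 1 and 2 cluster together; node 3 stays alone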
Example 2
# del_dups (order-preserving de-duplication) is a helper from the same module.
def merge_nodes(G,
                nodeA,
                nodeB,
                newNode,
                multi_centroid=True,
                check_merge_mems=True):

    if check_merge_mems:
        if len(G.nodes[nodeA]['members'] & G.nodes[nodeB]['members']) > 0:
            raise ValueError("merging nodes with the same genome IDs!")

    # take node with most support as the 'consensus'
    if G.nodes[nodeA]['size'] < G.nodes[nodeB]['size']:
        nodeB, nodeA = nodeA, nodeB

    # First create a new node and combine the attributes
    dna = del_dups(G.nodes[nodeA]['dna'] + G.nodes[nodeB]['dna'])
    maxLenId = 0
    max_l = 0
    for i, s in enumerate(dna):
        if len(s) >= max_l:
            max_l = len(s)
            maxLenId = i

    if multi_centroid:
        G.add_node(
            newNode,
            size=len(G.nodes[nodeA]['members'] | G.nodes[nodeB]['members']),
            centroid=del_dups(G.nodes[nodeA]['centroid'] +
                              G.nodes[nodeB]['centroid']),
            maxLenId=maxLenId,
            members=G.nodes[nodeA]['members'] | G.nodes[nodeB]['members'],
            seqIDs=G.nodes[nodeA]['seqIDs'] | G.nodes[nodeB]['seqIDs'],
            hasEnd=(G.nodes[nodeA]['hasEnd'] or G.nodes[nodeB]['hasEnd']),
            protein=del_dups(G.nodes[nodeA]['protein'] +
                             G.nodes[nodeB]['protein']),
            dna=dna,
            annotation=";".join(
                del_dups(G.nodes[nodeA]['annotation'].split(";") +
                         G.nodes[nodeB]['annotation'].split(";"))),
            description=";".join(
                del_dups(G.nodes[nodeA]['description'].split(";") +
                         G.nodes[nodeB]['description'].split(";"))),
            lengths=G.nodes[nodeA]['lengths'] + G.nodes[nodeB]['lengths'],
            longCentroidID=max(G.nodes[nodeA]['longCentroidID'],
                               G.nodes[nodeB]['longCentroidID']),
            paralog=(G.nodes[nodeA]['paralog'] or G.nodes[nodeB]['paralog']),
            mergedDNA=(G.nodes[nodeA]['mergedDNA']
                       or G.nodes[nodeB]['mergedDNA']))
        if "prevCentroids" in G.nodes[nodeA]:
            G.nodes[newNode]['prevCentroids'] = ";".join(
                set(G.nodes[nodeA]['prevCentroids'].split(";") +
                    G.nodes[nodeB]['prevCentroids'].split(";")))
    else:
        G.add_node(
            newNode,
            size=len(G.nodes[nodeA]['members'] | G.nodes[nodeB]['members']),
            centroid=del_dups(G.nodes[nodeA]['centroid'] +
                              G.nodes[nodeB]['centroid']),
            maxLenId=maxLenId,
            members=G.nodes[nodeA]['members'] | G.nodes[nodeB]['members'],
            seqIDs=G.nodes[nodeA]['seqIDs'] | G.nodes[nodeB]['seqIDs'],
            hasEnd=(G.nodes[nodeA]['hasEnd'] or G.nodes[nodeB]['hasEnd']),
            protein=del_dups(G.nodes[nodeA]['protein'] +
                             G.nodes[nodeB]['protein']),
            dna=dna,
            annotation=G.nodes[nodeA]['annotation'],
            description=G.nodes[nodeA]['description'],
            paralog=(G.nodes[nodeA]['paralog'] or G.nodes[nodeB]['paralog']),
            lengths=G.nodes[nodeA]['lengths'] + G.nodes[nodeB]['lengths'],
            longCentroidID=max(G.nodes[nodeA]['longCentroidID'],
                               G.nodes[nodeB]['longCentroidID']),
            mergedDNA=True)
        if "prevCentroids" in G.nodes[nodeA]:
            G.nodes[newNode]['prevCentroids'] = ";".join(
                set(G.nodes[nodeA]['prevCentroids'].split(";") +
                    G.nodes[nodeB]['prevCentroids'].split(";")))

    # Now iterate through neighbours of each node and add them to the new node
    neighboursB = list(G.neighbors(nodeB))
    neighboursA = list(G.neighbors(nodeA))
    for neighbor in neighboursA:
        if neighbor in neighboursB:
            # shared neighbour: sum the edge weights and union the edge members
            G.add_edge(newNode,
                       neighbor,
                       weight=G[nodeA][neighbor]['weight'] +
                       G[nodeB][neighbor]['weight'],
                       members=G[nodeA][neighbor]['members']
                       | G[nodeB][neighbor]['members'])
            neighboursB.remove(neighbor)
        else:
            G.add_edge(newNode,
                       neighbor,
                       weight=G[nodeA][neighbor]['weight'],
                       members=G[nodeA][neighbor]['members'])

    for neighbor in neighboursB:
        G.add_edge(newNode,
                   neighbor,
                   weight=G[nodeB][neighbor]['weight'],
                   members=G[nodeB][neighbor]['members'])

    # remove old nodes from Graph
    G.remove_nodes_from([nodeA, nodeB])

    if len(max(G.nodes[newNode]["dna"], key=len)) <= 0:
        print(G.nodes[newNode]["dna"])
        raise ValueError("merged node %s has an empty DNA sequence!" % newNode)

    return G
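
A hypothetical end-to-end sketch for merge_nodes: two single-member gene nodes, built with the attribute schema the function reads (members and seqIDs as sets; dna, protein and lengths as lists), are merged into node 3. The constructor and its values below are made up for illustration, and del_dups is again a stub for the module helper:

import networkx as nx

def del_dups(seq):
    # stand-in for the module helper: order-preserving de-duplication
    seen = set()
    return [x for x in seq if not (x in seen or seen.add(x))]

def add_gene_node(G, n, member, dna, protein):
    # hypothetical constructor covering every attribute merge_nodes reads
    G.add_node(n,
               size=1,
               centroid=["c%d" % n],
               maxLenId=0,
               members={member},
               seqIDs={"%s_0_%d" % (member, n)},
               hasEnd=False,
               protein=[protein],
               dna=[dna],
               annotation="hypothetical protein",
               description="",
               lengths=[len(dna)],
               longCentroidID=(len(dna), "c%d" % n),
               paralog=False,
               mergedDNA=False)

G = nx.Graph()
add_gene_node(G, 1, "0", "ATGAAATAA", "MK*")
add_gene_node(G, 2, "1", "ATGAAGTAA", "MK*")
G = merge_nodes(G, 1, 2, newNode=3)
print(G.nodes[3]['members'])   # {'0', '1'}
print(G.nodes[3]['centroid'])  # ['c1', 'c2']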
Example 3
import os
from collections import Counter, defaultdict

import numpy as np
from joblib import Parallel, delayed
from tqdm import tqdm

# search_gff, translate_to_match, remove_member_from_node, delete_node and
# del_dups are helpers from the same module.


def find_missing(G,
                 gff_file_handles,
                 dna_seq_file,
                 prot_seq_file,
                 gene_data_file,
                 merge_id_thresh,
                 search_radius,
                 prop_match,
                 pairwise_id_thresh,
                 n_cpu,
                 remove_by_consensus=False,
                 verbose=True):

    # Iterate over each genome file checking to see if any missing accessory genes
    #  can be found.

    # generate mapping between internal nodes and gff ids
    id_to_gff = {}
    with open(gene_data_file, 'r') as infile:
        next(infile)
        for line in infile:
            line = line.split(",")
            if line[2] in id_to_gff:
                raise ValueError("Duplicate internal ids!")
            id_to_gff[line[2]] = line[3]

    # identify nodes that have been merged at the protein level
    merged_ids = {}
    for node in G.nodes():
        if (len(G.nodes[node]['centroid']) >
                1) or (G.nodes[node]['mergedDNA']):
            for sid in sorted(G.nodes[node]['seqIDs']):
                merged_ids[sid] = node

    merged_nodes = defaultdict(dict)
    with open(gene_data_file, 'r') as infile:
        next(infile)
        for line in infile:
            line = line.split(",")
            if line[2] in merged_ids:
                mem = int(line[2].split("_")[0])  # fix: 'sid' leaked from the loop above
                if merged_ids[line[2]] in merged_nodes[mem]:
                    merged_nodes[mem][merged_ids[line[2]]] = G.nodes[
                        merged_ids[line[2]]]["dna"][G.nodes[merged_ids[
                            line[2]]]['maxLenId']]
                else:
                    merged_nodes[mem][merged_ids[line[2]]] = line[5]

    # iterate through nodes to identify accessory genes for searching
    # these are nodes missing a member with at least one neighbour that has that member
    n_searches = 0
    search_list = defaultdict(lambda: defaultdict(set))
    conflicts = defaultdict(set)
    for node in G.nodes():
        for neigh in G.neighbors(node):
            # seen_mems = set()
            for sid in sorted(G.nodes[neigh]['seqIDs']):
                member = sid.split("_")[0]
                # if member in seen_mems: continue
                # seen_mems.add(member)
                conflicts[int(member)].add((neigh, id_to_gff[sid]))
                if member not in G.nodes[node]['members']:
                    if len(G.nodes[node]["dna"][G.nodes[node]
                                                ['maxLenId']]) <= 0:
                        print(G.nodes[node]["dna"])
                        raise ValueError("node %s has an empty DNA sequence!" % node)
                    search_list[int(member)][node].add(
                        (G.nodes[node]["dna"][G.nodes[node]['maxLenId']],
                         id_to_gff[sid]))

                    n_searches += 1

    if verbose:
        print("Number of searches to perform: ", n_searches)
        print("Searching...")

    all_hits, all_node_locs, max_seq_lengths = zip(*Parallel(n_jobs=n_cpu)(
        delayed(search_gff)(search_list[member],
                            conflicts[member],
                            gff_handle,
                            merged_nodes=merged_nodes[member],
                            search_radius=search_radius,
                            prop_match=prop_match,
                            pairwise_id_thresh=pairwise_id_thresh,
                            merge_id_thresh=merge_id_thresh)
        for member, gff_handle in tqdm(enumerate(gff_file_handles),
                                       disable=(not verbose))))

    if verbose:
        print("translating hits...")

    hits_trans_dict = {}
    for member, hits in enumerate(all_hits):
        hits_trans_dict[member] = Parallel(n_jobs=n_cpu)(
            delayed(translate_to_match)(hit[1], G.nodes[hit[0]]["protein"][0])
            for hit in hits)

    # remove nodes that conflict (overlap)
    nodes_by_size = sorted([(G.nodes[node]['size'], node)
                            for node in G.nodes()],
                           reverse=True)
    nodes_by_size = [n[1] for n in nodes_by_size]
    member = 0
    bad_node_mem_pairs = set()
    bad_nodes = set()
    for node_locs, max_seq_length in zip(all_node_locs, max_seq_lengths):
        seq_coverage = defaultdict(
            lambda: np.zeros(max_seq_length + 2, dtype=bool))

        for node in nodes_by_size:
            if node in bad_nodes: continue
            if node not in node_locs: continue
            contig_id = node_locs[node][0]
            loc = node_locs[node][1]

            if np.sum(seq_coverage[contig_id][loc[0]:loc[1]]) >= (
                    0.5 * (max(G.nodes[node]['lengths']))):
                if str(member) in G.nodes[node]['members']:
                    remove_member_from_node(G, node, member)
                # G.nodes[node]['members'].remove(str(member))
                # G.nodes[node]['size'] -= 1
                bad_node_mem_pairs.add((node, member))
            else:
                seq_coverage[contig_id][loc[0]:loc[1]] = True
        member += 1

    for node in G.nodes():
        if len(G.nodes[node]['members']) <= 0:
            bad_nodes.add(node)
    for node in bad_nodes:
        if node in G.nodes():
            delete_node(G, node)

    # remove by consensus
    if remove_by_consensus:
        if verbose:
            print("removing by consensus...")
        node_hit_counter = Counter()
        for member, hits in enumerate(all_hits):
            for node, dna_hit in hits:
                if dna_hit == "": continue
                if node in bad_nodes: continue
                if (node, member) in bad_node_mem_pairs: continue
                node_hit_counter[node] += 1
        for node in G:
            if node_hit_counter[node] > G.nodes[node]['size']:
                bad_nodes.add(node)
        for node in bad_nodes:
            if node in G.nodes():
                delete_node(G, node)

    if verbose:
        print("Updating output...")

    n_found = 0
    with open(dna_seq_file, 'a') as dna_out:
        with open(prot_seq_file, 'a') as prot_out:
            with open(gene_data_file, 'a') as data_out:
                for member, hits in enumerate(all_hits):
                    for i, (node, dna_hit) in enumerate(hits):
                        if dna_hit == "": continue
                        if node in bad_nodes: continue
                        if (node, member) in bad_node_mem_pairs: continue
                        hit_protein = hits_trans_dict[member][i]
                        G.nodes[node]['members'].add(str(member))
                        G.nodes[node]['size'] += 1
                        G.nodes[node]['dna'] = del_dups(G.nodes[node]['dna'] +
                                                        [dna_hit])
                        dna_out.write(">" + str(member) + "_refound_" +
                                      str(n_found) + "\n" + dna_hit + "\n")
                        G.nodes[node]['protein'] = del_dups(
                            G.nodes[node]['protein'] + [hit_protein])
                        prot_out.write(">" + str(member) + "_refound_" +
                                       str(n_found) + "\n" + hit_protein +
                                       "\n")
                        data_out.write(",".join([
                            os.path.splitext(
                                os.path.basename(
                                    gff_file_handles[member]))[0], "",
                            str(member) + "_refound_" + str(n_found),
                            str(member) + "_refound_" +
                            str(n_found), hit_protein, dna_hit, "", ""
                        ]) + "\n")
                        G.nodes[node]['seqIDs'] |= set(
                            [str(member) + "_refound_" + str(n_found)])
                        n_found += 1

    if verbose:
        print("Number of refound genes: ", n_found)

    return G
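
The gene_data CSV drives both passes over gene_data_file above. A self-contained sketch of the layout the function assumes, with the column order inferred from the indices it reads (line[2] internal id, line[3] gff id, line[5] dna) and from the rows it appends; the header names are hypothetical:

import io

gene_data = io.StringIO(
    "gff_file,scaffold,internal_id,gff_id,protein,dna,x,y\n"
    "genome0,,0_0_1,gene_001,MK,ATGAAA,,\n"
    "genome0,,0_0_2,gene_002,ML,ATGCTG,,\n")

id_to_gff = {}
next(gene_data)  # skip the header, as in the function above
for line in gene_data:
    line = line.split(",")
    if line[2] in id_to_gff:
        raise ValueError("Duplicate internal ids!")
    id_to_gff[line[2]] = line[3]

print(id_to_gff)  # {'0_0_1': 'gene_001', '0_0_2': 'gene_002'}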
Example 4
import itertools

import networkx as nx
import numpy as np
from scipy.sparse.csgraph import connected_components

# iterative_cdhit, pwdist_edlib, max_clique, del_dups and merge_nodes
# (Example 2) are helpers from the same module.


def collapse_families(G,
                      outdir,
                      family_threshold=0.7,
                      dna_error_threshold=0.99,
                      correct_mistranslations=False,
                      n_cpu=1,
                      quiet=False,
                      distances_bwtn_centroids=None,
                      centroid_to_index=None):

    node_count = max(list(G.nodes())) + 10

    if correct_mistranslations:
        depths = [1, 2, 3]
        threshold = [0.99, 0.98, 0.95, 0.9]
    else:
        depths = [1, 2, 3]
        threshold = [0.99, 0.95, 0.9, 0.8, 0.7, 0.6, 0.5]

    # precluster for speed
    if correct_mistranslations:
        cdhit_clusters = iterative_cdhit(G,
                                         outdir,
                                         thresholds=threshold,
                                         n_cpu=n_cpu,
                                         quiet=True,
                                         dna=True,
                                         word_length=7,
                                         accurate=False)
        distances_bwtn_centroids, centroid_to_index = pwdist_edlib(
            G, cdhit_clusters, dna_error_threshold, dna=True, n_cpu=n_cpu)

        # keep track of centroids for each sequence. Need this to resolve clashes
        seqid_to_index = {}
        for node in G.nodes():
            for sid in G.nodes[node]['seqIDs']:
                seqid_to_index[sid] = centroid_to_index[
                    G.nodes[node]["longCentroidID"][1]]

    elif distances_bwtn_centroids is None:
        cdhit_clusters = iterative_cdhit(G,
                                         outdir,
                                         thresholds=threshold,
                                         n_cpu=n_cpu,
                                         quiet=True,
                                         dna=False)
        distances_bwtn_centroids, centroid_to_index = pwdist_edlib(
            G, cdhit_clusters, family_threshold, dna=False, n_cpu=n_cpu)
    for depth in depths:
        search_space = set(G.nodes())
        while len(search_space) > 0:
            # look for nodes to merge
            temp_node_list = list(search_space)
            removed_nodes = set()
            for node in temp_node_list:
                if node in removed_nodes: continue

                if G.degree[node] <= 2:
                    search_space.remove(node)
                    removed_nodes.add(node)
                    continue

                # find neighbouring nodes and cluster their centroid with cdhit
                neighbours = [
                    v
                    for u, v in nx.bfs_edges(G, source=node, depth_limit=depth)
                ]
                if correct_mistranslations:
                    neighbours += [node]

                # find clusters
                index = []
                neigh_array = []
                for neigh in neighbours:
                    # centroid is stored as a list (cf. Examples 1 and 2)
                    for sid in G.nodes[neigh]['centroid']:
                        index.append(centroid_to_index[sid])
                        neigh_array.append(neigh)
                index = np.array(index, dtype=int)
                neigh_array = np.array(neigh_array)

                n_components, labels = connected_components(
                    csgraph=distances_bwtn_centroids[index][:, index],
                    directed=False,
                    return_labels=True)
                # labels = labels[index]
                for neigh in neighbours:
                    l = list(set(labels[neigh_array == neigh]))
                    if len(l) > 1:
                        for i in l[1:]:
                            labels[labels == i] = l[0]

                clusters = [
                    del_dups(list(neigh_array[labels == i]))
                    for i in np.unique(labels)
                ]

                for cluster in clusters:

                    # check if there are any to collapse
                    if len(cluster) <= 1: continue

                    # check for conflicts
                    members = []
                    for n in cluster:
                        for m in set(G.nodes[n]['members']):
                            members.append(m)

                    if (len(members) == len(set(members))):
                        # no conflicts so merge
                        node_count += 1
                        for neig in cluster:
                            removed_nodes.add(neig)
                            if neig in search_space: search_space.remove(neig)
                        temp_c = cluster.copy()
                        G = merge_nodes(
                            G,
                            temp_c.pop(),
                            temp_c.pop(),
                            node_count,
                            multi_centroid=(not correct_mistranslations))
                        while (len(temp_c) > 0):
                            G = merge_nodes(
                                G,
                                node_count,
                                temp_c.pop(),
                                node_count + 1,
                                multi_centroid=(not correct_mistranslations))
                            node_count += 1
                        search_space.add(node_count)
                    else:
                        # if correct_mistranslations:

                        # merge if the centroids don't conflict and the nodes are adjacent in the conflicting genome
                        # this corresponds to a mistranslation/frame shift/premature stop where one gene has been split
                        # into two in a subset of genomes

                        # build a mini graph of allowed pairwise merges
                        tempG = nx.Graph()
                        for nA, nB in itertools.combinations(cluster, 2):
                            mem_inter = set(
                                G.nodes[nA]['members']).intersection(
                                    G.nodes[nB]['members'])
                            if len(mem_inter) > 0:
                                idxA = centroid_to_index[
                                    G.nodes[nA]["longCentroidID"][1]]
                                idxB = centroid_to_index[
                                    G.nodes[nB]["longCentroidID"][1]]
                                if distances_bwtn_centroids[idxA, idxB] == 0:
                                    tempG.add_edge(nA, nB)
                                else:
                                    for imem in mem_inter:
                                        tempids = []
                                        # seqIDs are sets, so chain() rather
                                        # than + is used to iterate both
                                        for sid in itertools.chain(
                                                G.nodes[nA]['seqIDs'],
                                                G.nodes[nB]['seqIDs']):
                                            if int(sid.split("_")[0]) == imem:
                                                tempids.append(sid)
                                        shouldmerge = True
                                        for sidA, sidB in itertools.combinations(
                                                tempids, 2):
                                            if abs(
                                                    int(sidA.split("_")[2]) -
                                                    int(sidB.split("_")[2])
                                            ) >= len(tempids):
                                                shouldmerge = False
                                            # seqid_to_index is only built in
                                            # the correct_mistranslations
                                            # branch above
                                            if distances_bwtn_centroids[
                                                    seqid_to_index[sidA],
                                                    seqid_to_index[sidB]] == 1:
                                                shouldmerge = False
                                        if shouldmerge:
                                            tempG.add_edge(nA, nB)
                            else:
                                tempG.add_edge(nA, nB)

                        # merge from largest clique to smallest
                        clique = max_clique(tempG)
                        while len(clique) > 1:
                            node_count += 1
                            for neig in clique:
                                removed_nodes.add(neig)
                                if neig in search_space:
                                    search_space.remove(neig)

                            temp_c = clique.copy()
                            G = merge_nodes(
                                G,
                                temp_c.pop(),
                                temp_c.pop(),
                                node_count,
                                multi_centroid=(not correct_mistranslations),
                                check_merge_mems=False)
                            while (len(temp_c) > 0):
                                G = merge_nodes(
                                    G,
                                    node_count,
                                    temp_c.pop(),
                                    node_count + 1,
                                    multi_centroid=(
                                        not correct_mistranslations),
                                    check_merge_mems=False)
                                node_count += 1
                            search_space.add(node_count)
                            tempG.remove_nodes_from(clique)
                            clique = max_clique(tempG)

                if node in search_space:
                    search_space.remove(node)

    return G, distances_bwtn_centroids, centroid_to_index
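
The candidate set for each merge round comes from a depth-limited breadth-first search; below is a self-contained illustration of that mechanic. Note also that the returned distances_bwtn_centroids and centroid_to_index can be passed back in on later calls to skip the cd-hit preclustering.

import networkx as nx

H = nx.path_graph(6)  # 0-1-2-3-4-5
node, depth = 2, 2
# every node within `depth` hops of `node`, as gathered in the loop above
neighbours = [v for u, v in nx.bfs_edges(H, source=node, depth_limit=depth)]
print(neighbours)  # [1, 3, 0, 4]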