def single_linkage(G, distances_bwtn_centroids, centroid_to_index, neighbours):
    index = []
    neigh_array = []
    for neigh in neighbours:
        for sid in G.nodes[neigh]['centroid']:
            index.append(centroid_to_index[sid])
            neigh_array.append(neigh)
    index = np.array(index, dtype=int)
    neigh_array = np.array(neigh_array)

    n_components, labels = connected_components(
        csgraph=distances_bwtn_centroids[index][:, index],
        directed=False,
        return_labels=True)
    # labels = labels[index]
    for neigh in neighbours:
        l = list(set(labels[neigh_array == neigh]))
        if len(l) > 1:
            for i in l[1:]:
                labels[labels == i] = l[0]

    clusters = [
        del_dups(list(neigh_array[labels == i])) for i in np.unique(labels)
    ]

    return clusters
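
# Illustrative sketch (not part of the original pipeline): a minimal example of
# how single_linkage groups neighbouring nodes whose centroids fall in the same
# connected component of the thresholded centroid-distance matrix. The toy
# graph, centroid ids and connectivity matrix are invented for demonstration;
# it relies on the module-level helpers used above (del_dups, np,
# connected_components).
def _example_single_linkage():
    import networkx as nx
    import numpy as np
    from scipy.sparse import csr_matrix

    G = nx.Graph()
    G.add_node(1, centroid=["c1"])
    G.add_node(2, centroid=["c2"])
    G.add_node(3, centroid=["c3"])

    centroid_to_index = {"c1": 0, "c2": 1, "c3": 2}
    # a 1 marks centroid pairs within the clustering threshold:
    # c1 and c2 are linked, c3 is on its own
    links = csr_matrix(np.array([[0, 1, 0],
                                 [1, 0, 0],
                                 [0, 0, 0]]))

    clusters = single_linkage(G, links, centroid_to_index, [1, 2, 3])
    # expected grouping: nodes 1 and 2 collapse together, node 3 stays alone,
    # i.e. something like [[1, 2], [3]] (cluster order may vary)
    print(clusters)
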
def merge_nodes(G,
                nodeA,
                nodeB,
                newNode,
                multi_centroid=True,
                check_merge_mems=True):

    if check_merge_mems:
        if len(G.nodes[nodeA]['members'] & G.nodes[nodeB]['members']) > 0:
            raise ValueError("merging nodes with the same genome IDs!")

    # take the node with the most support as the 'consensus'
    if G.nodes[nodeA]['size'] < G.nodes[nodeB]['size']:
        nodeB, nodeA = nodeA, nodeB

    # first create a new node and combine the attributes
    dna = del_dups(G.nodes[nodeA]['dna'] + G.nodes[nodeB]['dna'])
    maxLenId = 0
    max_l = 0
    for i, s in enumerate(dna):
        if len(s) >= max_l:
            max_l = len(s)
            maxLenId = i

    if multi_centroid:
        G.add_node(
            newNode,
            size=len(G.nodes[nodeA]['members'] | G.nodes[nodeB]['members']),
            centroid=del_dups(G.nodes[nodeA]['centroid'] +
                              G.nodes[nodeB]['centroid']),
            maxLenId=maxLenId,
            members=G.nodes[nodeA]['members'] | G.nodes[nodeB]['members'],
            seqIDs=G.nodes[nodeA]['seqIDs'] | G.nodes[nodeB]['seqIDs'],
            hasEnd=(G.nodes[nodeA]['hasEnd'] or G.nodes[nodeB]['hasEnd']),
            protein=del_dups(G.nodes[nodeA]['protein'] +
                             G.nodes[nodeB]['protein']),
            dna=dna,
            annotation=";".join(
                del_dups(G.nodes[nodeA]['annotation'].split(";") +
                         G.nodes[nodeB]['annotation'].split(";"))),
            description=";".join(
                del_dups(G.nodes[nodeA]['description'].split(";") +
                         G.nodes[nodeB]['description'].split(";"))),
            lengths=G.nodes[nodeA]['lengths'] + G.nodes[nodeB]['lengths'],
            longCentroidID=max(G.nodes[nodeA]['longCentroidID'],
                               G.nodes[nodeB]['longCentroidID']),
            paralog=(G.nodes[nodeA]['paralog'] or G.nodes[nodeB]['paralog']),
            mergedDNA=(G.nodes[nodeA]['mergedDNA']
                       or G.nodes[nodeB]['mergedDNA']))
        if "prevCentroids" in G.nodes[nodeA]:
            G.nodes[newNode]['prevCentroids'] = ";".join(
                set(G.nodes[nodeA]['prevCentroids'].split(";") +
                    G.nodes[nodeB]['prevCentroids'].split(";")))
    else:
        G.add_node(
            newNode,
            size=len(G.nodes[nodeA]['members'] | G.nodes[nodeB]['members']),
            centroid=del_dups(G.nodes[nodeA]['centroid'] +
                              G.nodes[nodeB]['centroid']),
            maxLenId=maxLenId,
            members=G.nodes[nodeA]['members'] | G.nodes[nodeB]['members'],
            seqIDs=G.nodes[nodeA]['seqIDs'] | G.nodes[nodeB]['seqIDs'],
            hasEnd=(G.nodes[nodeA]['hasEnd'] or G.nodes[nodeB]['hasEnd']),
            protein=del_dups(G.nodes[nodeA]['protein'] +
                             G.nodes[nodeB]['protein']),
            dna=dna,
            annotation=G.nodes[nodeA]['annotation'],
            description=G.nodes[nodeA]['description'],
            paralog=(G.nodes[nodeA]['paralog'] or G.nodes[nodeB]['paralog']),
            lengths=G.nodes[nodeA]['lengths'] + G.nodes[nodeB]['lengths'],
            longCentroidID=max(G.nodes[nodeA]['longCentroidID'],
                               G.nodes[nodeB]['longCentroidID']),
            mergedDNA=True)
        if "prevCentroids" in G.nodes[nodeA]:
            G.nodes[newNode]['prevCentroids'] = ";".join(
                set(G.nodes[nodeA]['prevCentroids'].split(";") +
                    G.nodes[nodeB]['prevCentroids'].split(";")))

    # now iterate through the neighbours of each node and add them to the new node
    neighboursB = list(G.neighbors(nodeB))
    neighboursA = list(G.neighbors(nodeA))
    for neighbor in neighboursA:
        if neighbor in neighboursB:
            # shared neighbour: sum edge weights and union edge members
            G.add_edge(newNode,
                       neighbor,
                       weight=G[nodeA][neighbor]['weight'] +
                       G[nodeB][neighbor]['weight'],
                       members=G[nodeA][neighbor]['members']
                       | G[nodeB][neighbor]['members'])
            neighboursB.remove(neighbor)
        else:
            G.add_edge(newNode,
                       neighbor,
                       weight=G[nodeA][neighbor]['weight'],
                       members=G[nodeA][neighbor]['members'])
    for neighbor in neighboursB:
        G.add_edge(newNode,
                   neighbor,
                   weight=G[nodeB][neighbor]['weight'],
                   members=G[nodeB][neighbor]['members'])

    # remove the old nodes from the graph
    G.remove_nodes_from([nodeA, nodeB])

    if len(max(G.nodes[newNode]["dna"], key=len)) <= 0:
        print(G.nodes[newNode]["dna"])
        raise NameError("merged node has an empty DNA sequence!")

    return G
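
# Illustrative sketch (not part of the original pipeline): merges two toy
# nodes that share a neighbour but no genomes, showing how node attributes are
# unioned and how edge weights/members are combined on the new node. All
# attribute values below are invented for demonstration.
def _example_merge_nodes():
    import networkx as nx

    G = nx.Graph()
    common = dict(hasEnd=False, paralog=False, mergedDNA=False)
    G.add_node(1, size=1, centroid=["c1"], maxLenId=0, members={"0"},
               seqIDs={"0_0_1"}, protein=["MKT"], dna=["ATGAAAACT"],
               annotation="geneA", description="hypothetical",
               lengths=[9], longCentroidID=(9, "c1"), **common)
    G.add_node(2, size=1, centroid=["c2"], maxLenId=0, members={"1"},
               seqIDs={"1_0_1"}, protein=["MKS"], dna=["ATGAAATCT"],
               annotation="geneA", description="hypothetical",
               lengths=[9], longCentroidID=(9, "c2"), **common)
    G.add_node(3, size=2, centroid=["c3"], maxLenId=0, members={"0", "1"},
               seqIDs={"0_0_2", "1_0_2"}, protein=["MR"], dna=["ATGCGT"],
               annotation="geneB", description="hypothetical",
               lengths=[6], longCentroidID=(6, "c3"), **common)
    G.add_edge(1, 3, weight=1, members={"0"})
    G.add_edge(2, 3, weight=1, members={"1"})

    # collapse nodes 1 and 2 into a new node 4
    G = merge_nodes(G, 1, 2, 4)
    # the merged node inherits the union of members/seqIDs and a single edge
    # to node 3 with weight 2 and members {"0", "1"}
    print(G.nodes[4]['members'], G[4][3])
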
def find_missing(G,
                 gff_file_handles,
                 dna_seq_file,
                 prot_seq_file,
                 gene_data_file,
                 merge_id_thresh,
                 search_radius,
                 prop_match,
                 pairwise_id_thresh,
                 n_cpu,
                 remove_by_consensus=False,
                 verbose=True):
    # iterate over each genome file, checking whether any missing accessory
    # genes can be found

    # generate a mapping between internal node ids and gff ids
    id_to_gff = {}
    with open(gene_data_file, 'r') as infile:
        next(infile)
        for line in infile:
            line = line.split(",")
            if line[2] in id_to_gff:
                raise NameError("Duplicate internal ids!")
            id_to_gff[line[2]] = line[3]

    # identify nodes that have been merged at the protein level
    merged_ids = {}
    for node in G.nodes():
        if (len(G.nodes[node]['centroid']) > 1) or (G.nodes[node]['mergedDNA']):
            for sid in sorted(G.nodes[node]['seqIDs']):
                merged_ids[sid] = node

    merged_nodes = defaultdict(dict)
    with open(gene_data_file, 'r') as infile:
        next(infile)
        for line in infile:
            line = line.split(",")
            if line[2] in merged_ids:
                mem = int(line[2].split("_")[0])
                if merged_ids[line[2]] in merged_nodes[mem]:
                    merged_nodes[mem][merged_ids[line[2]]] = G.nodes[
                        merged_ids[line[2]]]["dna"][G.nodes[merged_ids[
                            line[2]]]['maxLenId']]
                else:
                    merged_nodes[mem][merged_ids[line[2]]] = line[5]

    # iterate through nodes to identify accessory genes to search for. These
    # are nodes missing a member that have at least one neighbour containing
    # that member.
    n_searches = 0
    search_list = defaultdict(lambda: defaultdict(set))
    conflicts = defaultdict(set)
    for node in G.nodes():
        for neigh in G.neighbors(node):
            # seen_mems = set()
            for sid in sorted(G.nodes[neigh]['seqIDs']):
                member = sid.split("_")[0]
                # if member in seen_mems: continue
                # seen_mems.add(member)
                conflicts[int(member)].add((neigh, id_to_gff[sid]))
                if member not in G.nodes[node]['members']:
                    if len(G.nodes[node]["dna"][G.nodes[node]['maxLenId']]) <= 0:
                        print(G.nodes[node]["dna"])
                        raise NameError("node has an empty DNA sequence!")
                    search_list[int(member)][node].add(
                        (G.nodes[node]["dna"][G.nodes[node]['maxLenId']],
                         id_to_gff[sid]))
                    n_searches += 1

    if verbose:
        print("Number of searches to perform: ", n_searches)
        print("Searching...")

    all_hits, all_node_locs, max_seq_lengths = zip(*Parallel(n_jobs=n_cpu)(
        delayed(search_gff)(search_list[member],
                            conflicts[member],
                            gff_handle,
                            merged_nodes=merged_nodes[member],
                            search_radius=search_radius,
                            prop_match=prop_match,
                            pairwise_id_thresh=pairwise_id_thresh,
                            merge_id_thresh=merge_id_thresh)
        for member, gff_handle in tqdm(enumerate(gff_file_handles),
                                       disable=(not verbose))))

    if verbose:
        print("translating hits...")

    hits_trans_dict = {}
    for member, hits in enumerate(all_hits):
        hits_trans_dict[member] = Parallel(n_jobs=n_cpu)(
            delayed(translate_to_match)(hit[1],
                                        G.nodes[hit[0]]["protein"][0])
            for hit in hits)

    # remove nodes that conflict (overlap)
    nodes_by_size = sorted([(G.nodes[node]['size'], node)
                            for node in G.nodes()],
                           reverse=True)
    nodes_by_size = [n[1] for n in nodes_by_size]
    member = 0
    bad_node_mem_pairs = set()
    bad_nodes = set()
    for node_locs, max_seq_length in zip(all_node_locs, max_seq_lengths):
        # per-contig boolean mask tracking which bases are already claimed
        seq_coverage = defaultdict(
            lambda: np.zeros(max_seq_length + 2, dtype=bool))
        for node in nodes_by_size:
            if node in bad_nodes: continue
            if node not in node_locs: continue
            contig_id = node_locs[node][0]
            loc = node_locs[node][1]
            if np.sum(seq_coverage[contig_id][loc[0]:loc[1]]) >= (
                    0.5 * (max(G.nodes[node]['lengths']))):
                if str(member) in G.nodes[node]['members']:
                    remove_member_from_node(G, node, member)
                    # G.nodes[node]['members'].remove(str(member))
                    # G.nodes[node]['size'] -= 1
                bad_node_mem_pairs.add((node, member))
            else:
                seq_coverage[contig_id][loc[0]:loc[1]] = True
        member += 1

    for node in G.nodes():
        if len(G.nodes[node]['members']) <= 0:
            bad_nodes.add(node)
    for node in bad_nodes:
        if node in G.nodes():
            delete_node(G, node)

    # remove by consensus
    if remove_by_consensus:
        if verbose:
            print("removing by consensus...")
        node_hit_counter = Counter()
        for member, hits in enumerate(all_hits):
            for node, dna_hit in hits:
                if dna_hit == "": continue
                if node in bad_nodes: continue
                if (node, member) in bad_node_mem_pairs: continue
                node_hit_counter[node] += 1
        for node in G:
            if node_hit_counter[node] > G.nodes[node]['size']:
                bad_nodes.add(node)
        for node in bad_nodes:
            if node in G.nodes():
                delete_node(G, node)

    if verbose:
        print("Updating output...")

    n_found = 0
    with open(dna_seq_file, 'a') as dna_out:
        with open(prot_seq_file, 'a') as prot_out:
            with open(gene_data_file, 'a') as data_out:
                for member, hits in enumerate(all_hits):
                    i = -1
                    for node, dna_hit in hits:
                        i += 1
                        if dna_hit == "": continue
                        if node in bad_nodes: continue
                        if (node, member) in bad_node_mem_pairs: continue
                        hit_protein = hits_trans_dict[member][i]
                        G.nodes[node]['members'].add(str(member))
                        G.nodes[node]['size'] += 1
                        G.nodes[node]['dna'] = del_dups(
                            G.nodes[node]['dna'] + [dna_hit])
                        dna_out.write(">" + str(member) + "_refound_" +
                                      str(n_found) + "\n" + dna_hit + "\n")
                        G.nodes[node]['protein'] = del_dups(
                            G.nodes[node]['protein'] + [hit_protein])
                        prot_out.write(">" + str(member) + "_refound_" +
                                       str(n_found) + "\n" + hit_protein +
                                       "\n")
                        data_out.write(",".join([
                            os.path.splitext(
                                os.path.basename(
                                    gff_file_handles[member]))[0], "",
                            str(member) + "_refound_" + str(n_found),
                            str(member) + "_refound_" + str(n_found),
                            hit_protein, dna_hit, "", ""
                        ]) + "\n")
                        G.nodes[node]['seqIDs'] |= set(
                            [str(member) + "_refound_" + str(n_found)])
                        n_found += 1

    if verbose:
        print("Number of refound genes: ", n_found)

    return G
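
# Illustrative sketch (not part of the original pipeline): demonstrates the
# per-contig coverage mask used in find_missing to reject re-found hits that
# overlap an already accepted, higher-ranked node by at least half of the
# node's typical gene length. Coordinates and lengths are invented.
def _example_overlap_filter():
    import numpy as np

    max_seq_length = 100
    seq_coverage = np.zeros(max_seq_length + 2, dtype=bool)

    # the first (larger) node claims bases 10-40 on the contig
    seq_coverage[10:40] = True

    # a candidate hit for a second node spans 20-50 with a typical gene
    # length of 30; 20 of its bases are already covered, exceeding 0.5 * 30
    loc = (20, 50)
    typical_length = 30
    overlap = np.sum(seq_coverage[loc[0]:loc[1]])
    rejected = overlap >= 0.5 * typical_length
    print(overlap, rejected)  # 20 True -> the hit would be discarded
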
def collapse_families(G,
                      outdir,
                      family_threshold=0.7,
                      dna_error_threshold=0.99,
                      correct_mistranslations=False,
                      n_cpu=1,
                      quiet=False,
                      distances_bwtn_centroids=None,
                      centroid_to_index=None):

    node_count = max(list(G.nodes())) + 10

    if correct_mistranslations:
        depths = [1, 2, 3]
        threshold = [0.99, 0.98, 0.95, 0.9]
    else:
        depths = [1, 2, 3]
        threshold = [0.99, 0.95, 0.9, 0.8, 0.7, 0.6, 0.5]

    # precluster for speed
    if correct_mistranslations:
        cdhit_clusters = iterative_cdhit(G,
                                         outdir,
                                         thresholds=threshold,
                                         n_cpu=n_cpu,
                                         quiet=True,
                                         dna=True,
                                         word_length=7,
                                         accurate=False)
        distances_bwtn_centroids, centroid_to_index = pwdist_edlib(
            G, cdhit_clusters, dna_error_threshold, dna=True, n_cpu=n_cpu)

        # keep track of the centroid of each sequence; needed to resolve clashes
        seqid_to_index = {}
        for node in G.nodes():
            for sid in G.nodes[node]['seqIDs']:
                seqid_to_index[sid] = centroid_to_index[
                    G.nodes[node]["longCentroidID"][1]]
    elif distances_bwtn_centroids is None:
        cdhit_clusters = iterative_cdhit(G,
                                         outdir,
                                         thresholds=threshold,
                                         n_cpu=n_cpu,
                                         quiet=True,
                                         dna=False)
        distances_bwtn_centroids, centroid_to_index = pwdist_edlib(
            G, cdhit_clusters, family_threshold, dna=False, n_cpu=n_cpu)

    for depth in depths:
        search_space = set(G.nodes())

        while len(search_space) > 0:
            # look for nodes to merge
            temp_node_list = list(search_space)
            removed_nodes = set()
            for node in temp_node_list:
                if node in removed_nodes: continue

                if G.degree[node] <= 2:
                    search_space.remove(node)
                    removed_nodes.add(node)
                    continue

                # find neighbouring nodes within the search depth
                neighbours = [
                    v
                    for u, v in nx.bfs_edges(G, source=node, depth_limit=depth)
                ]
                if correct_mistranslations:
                    neighbours += [node]

                # cluster the neighbours by single linkage over the
                # precomputed centroid distances
                clusters = single_linkage(G, distances_bwtn_centroids,
                                          centroid_to_index, neighbours)

                for cluster in clusters:
                    # check if there are any to collapse
                    if len(cluster) <= 1: continue

                    # check for conflicts
                    members = []
                    for n in cluster:
                        for m in set(G.nodes[n]['members']):
                            members.append(m)

                    if (len(members) == len(set(members))):
                        # no conflicts so merge
                        node_count += 1
                        for neig in cluster:
                            removed_nodes.add(neig)
                            if neig in search_space:
                                search_space.remove(neig)

                        temp_c = cluster.copy()
                        G = merge_nodes(
                            G,
                            temp_c.pop(),
                            temp_c.pop(),
                            node_count,
                            multi_centroid=(not correct_mistranslations))
                        while (len(temp_c) > 0):
                            G = merge_nodes(
                                G,
                                node_count,
                                temp_c.pop(),
                                node_count + 1,
                                multi_centroid=(not correct_mistranslations))
                            node_count += 1
                        search_space.add(node_count)
                    else:
                        # if correct_mistranslations:
                        # merge if the centroids don't conflict and the nodes
                        # are adjacent in the conflicting genome. This
                        # corresponds to a mistranslation/frame shift/premature
                        # stop where one gene has been split into two in a
                        # subset of genomes.

                        # build a mini graph of allowed pairwise merges
                        tempG = nx.Graph()
                        for nA, nB in itertools.combinations(cluster, 2):
                            mem_inter = set(
                                G.nodes[nA]['members']).intersection(
                                    G.nodes[nB]['members'])
                            if len(mem_inter) > 0:
                                if distances_bwtn_centroids[
                                        centroid_to_index[G.nodes[nA][
                                            "longCentroidID"][1]],
                                        centroid_to_index[G.nodes[nB][
                                            "longCentroidID"][1]]] == 0:
                                    tempG.add_edge(nA, nB)
                                else:
                                    for imem in mem_inter:
                                        tempids = []
                                        for sid in (G.nodes[nA]['seqIDs']
                                                    | G.nodes[nB]['seqIDs']):
                                            if sid.split("_")[0] == imem:
                                                tempids.append(sid)
                                        shouldmerge = True
                                        for sidA, sidB in itertools.combinations(
                                                tempids, 2):
                                            if abs(
                                                    int(sidA.split("_")[2]) -
                                                    int(sidB.split("_")[2])
                                            ) >= len(tempids):
                                                shouldmerge = False
                                            if distances_bwtn_centroids[
                                                    seqid_to_index[sidA],
                                                    seqid_to_index[
                                                        sidB]] == 1:
                                                shouldmerge = False
                                        if shouldmerge:
                                            tempG.add_edge(nA, nB)
                            else:
                                tempG.add_edge(nA, nB)

                        # merge from the largest clique to the smallest
                        clique = max_clique(tempG)
                        while len(clique) > 1:
                            node_count += 1
                            for neig in clique:
                                removed_nodes.add(neig)
                                if neig in search_space:
                                    search_space.remove(neig)

                            temp_c = clique.copy()
                            G = merge_nodes(
                                G,
                                temp_c.pop(),
                                temp_c.pop(),
                                node_count,
                                multi_centroid=(not correct_mistranslations),
                                check_merge_mems=False)
                            while (len(temp_c) > 0):
                                G = merge_nodes(
                                    G,
                                    node_count,
                                    temp_c.pop(),
                                    node_count + 1,
                                    multi_centroid=(
                                        not correct_mistranslations),
                                    check_merge_mems=False)
                                node_count += 1
                            search_space.add(node_count)

                            tempG.remove_nodes_from(clique)
                            clique = max_clique(tempG)

                if node in search_space:
                    search_space.remove(node)

    return G, distances_bwtn_centroids, centroid_to_index
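
# Illustrative sketch (not part of the original pipeline): shows the
# clique-based merge order used in collapse_families. Given a mini graph of
# allowed pairwise merges, nodes are collapsed one approximate maximum clique
# at a time until only singletons remain. The toy edges are invented, and
# max_clique is the networkx approximation also used above.
def _example_clique_merge_order():
    import networkx as nx
    from networkx.algorithms.approximation import max_clique

    tempG = nx.Graph()
    tempG.add_edges_from([(1, 2), (2, 3), (1, 3), (3, 4)])

    merge_groups = []
    clique = max_clique(tempG)
    while len(clique) > 1:
        merge_groups.append(sorted(clique))
        tempG.remove_nodes_from(clique)
        clique = max_clique(tempG) if len(tempG) > 0 else set()
    # e.g. [[1, 2, 3]] with node 4 left unmerged (the approximation may pick
    # cliques in a different order)
    print(merge_groups)
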