def drawTree(MS_distDict, Methyl_distDict, filtered_samples, ratio, outgroup):
    """Merge MS and methylation distances and build a neighbor-joining tree.

    Args:
        MS_distDict: Nested mapping ``[sample1][sample2]`` of MS distances.
        Methyl_distDict: Nested mapping ``[sample1][sample2]`` of methylation
            distances.
        filtered_samples: Sample names to include; they are used in sorted
            order for both axes of the distance matrix.
        ratio: Weight of the MS distance; the methylation distance is
            weighted by ``1 - ratio``.
        outgroup: ``"NA"`` returns the tree as built, ``"Midpoint"`` roots it
            on the midpoint outgroup, any other value is passed to
            ``set_outgroup``.

    Returns:
        The ete tree built from the merged distance matrix.
    """
    samples = sorted(filtered_samples)

    # Merge MS and Methyl distance matrices.
    merged_distMatrix = []
    for sample1 in samples:
        sample1_dist = []
        for sample2 in samples:
            # Only the methylation term is divided by 100: PD is calculated
            # on a 0-100 scale while the MS distance is on a 0-1 scale.
            merged_dist = (MS_distDict[sample1][sample2] * ratio) + (
                Methyl_distDict[sample1][sample2] * (1 - ratio)) / 100
            sample1_dist.append(merged_dist)
        merged_distMatrix.append(sample1_dist)

    # Run neighbor-joining on the pairwise distances; skbio builds the tree
    # as a newick string, which is then loaded as an ete tree.
    distObj = DistanceMatrix(merged_distMatrix, samples)
    print(distObj.data)
    skbio_tree = nj(distObj, result_constructor=str)
    ete_tree = Tree(skbio_tree)

    # Compare by value: `is "NA"` only matched when the caller's string
    # happened to be the same interned object as the literal.
    if outgroup == "NA":
        return ete_tree
    if outgroup == "Midpoint":
        tree_midpoint = ete_tree.get_midpoint_outgroup()
        ete_tree.set_outgroup(tree_midpoint)
    else:
        ete_tree.set_outgroup(outgroup)
    return ete_tree
def generate_subfams_treestrat(fin_seqs, dio_tmp, threads):
    """Split a gene family into two subfamilies using a tree-based strategy.

    The sequences are aligned with mafft, a tree is inferred with iqtree
    (model LG), and the tree is cut at its midpoint outgroup.

    Args:
        fin_seqs (chr): Path to fasta file with sequences of genes.
        dio_tmp (chr): Path to folder to store temporary files in.
        threads (int): Number of threads to use.

    Returns:
        A list of two elements: the gene ids for subfamily one (the leaves
        under the midpoint outgroup) and the gene ids for subfamily two (the
        leaves left in the tree once that clade is detached).
    """
    alignment = f"{dio_tmp}/seqs.aln"
    tree_dir = f"{dio_tmp}/tree"

    run_mafft(fin_seqs, alignment, threads)
    # Infer the tree with iqtree.
    run_iqtree(alignment, tree_dir, threads, ["-m", "LG"])
    tree = Tree(f"{tree_dir}/tree.treefile")

    # NOTE(review): the midpoint outgroup is used without a None check; a
    # failed midpoint search would surface as an AttributeError below.
    outgroup_clade = tree.get_midpoint_outgroup()
    first_subfamily = outgroup_clade.get_leaf_names()
    # Detaching the clade leaves only the other subfamily in the tree.
    outgroup_clade.detach()
    second_subfamily = tree.get_leaf_names()
    return [first_subfamily, second_subfamily]
def suspicious_clades(tree):
    """Find supported clades that mix more than one higher taxonomy group.

    A clade is considered when it is an internal, non-root node of the
    midpoint-rooted tree with support >= 70, it holds fewer leaves than the
    rest of the tree, and it has more than one leaf. It is suspicious when
    its leaves map to more than one ``'Higher Taxonomy'`` value in the
    module-level ``metadata``.

    input: phylogenetic tree
    output: tuple of the ``tree`` argument and list of suspicious clades
    """
    rooted = Tree(tree)
    rooted.set_outgroup(rooted.get_midpoint_outgroup())

    candidates = []
    for node in rooted.traverse('preorder'):
        if node.is_root() or node.is_leaf():
            continue
        if node.support < 70:
            continue
        # Report only clades which encompass less than half of all organisms.
        if len(node) >= len(rooted) - len(node):
            continue
        leaf_names = node.get_leaf_names()
        if len(leaf_names) > 1:
            candidates.append(leaf_names)

    flagged = []
    for leaf_names in candidates:
        # The organism key is the part of the leaf name before '..' when
        # present, otherwise the part before the first '_'.
        taxonomy_groups = {
            metadata[
                name.split('..')[0] if '..' in name else name.split('_')[0]
            ]['Higher Taxonomy']
            for name in leaf_names
        }
        if len(taxonomy_groups) > 1:
            flagged.append(leaf_names)
    return tree, flagged
def collect_contaminants(tree_file, cont_dict):
    """Collect sequences whose tree position matches an expected contamination.

    Only leaves whose name contains exactly four underscores are examined.
    The organism is the first underscore-separated field and the last three
    fields form the quality suffix of the table name.

    input: tree file, contamination dict from parse_contaminants function
    result: set of proven contaminants (leaf names), set of proven
    contamination table names (same names as in csv result tables)
    """
    rooted = Tree(tree_file)
    rooted.set_outgroup(rooted.get_midpoint_outgroup())

    contaminants = set()
    cont_table_names = set()
    for node in rooted.traverse('preorder'):
        if not node.is_leaf():
            continue
        leaf_name = node.name
        if leaf_name.count('_') != 4:
            continue
        fields = leaf_name.split('_')
        org = fields[0]
        quality = '_'.join(fields[-3:])
        # Built before the cont_dict check, so `metadata` must know every
        # organism that appears in such a leaf name.
        table_name = f'{metadata[org]["full"]}_{quality}@{org}'
        if org not in cont_dict:
            continue
        if expected_neighborhood(node.up, cont_dict[org]) is True:
            contaminants.add(leaf_name)
            cont_table_names.add(table_name)
    return contaminants, cont_table_names
def RapidNJ(names, profiles, embeded, handle_missing='pair_delete', **params):
    """Build a neighbor-joining tree with the external RapidNJ program.

    Args:
        names: Sequence of leaf names, indexed by the row number written to
            the distance file.
        profiles: Profile matrix passed to ``distance_matrix.get_distance``.
        embeded: Not used by this function.
        handle_missing: Missing-data strategy forwarded to
            ``distance_matrix.get_distance``.
        **params: Must contain ``'tempfix'`` (prefix for temporary files) and
            ``'RapidNJ_<platform.system()>'`` (path of the executable).

    Returns:
        The ete tree read from RapidNJ's output, with leaves renamed from
        row indices to ``names``.
    """
    dist = distance_matrix.get_distance('symmetric', profiles, handle_missing)

    dist_file = params['tempfix'] + 'dist.list'
    tree_file = dist_file + '_rapidnj.nwk'
    with open(dist_file, 'w') as fout:
        fout.write(' {0}\n'.format(dist.shape[0]))
        for n, d in enumerate(dist):
            fout.write('{0!s:10} {1}\n'.format(
                n, ' '.join(['{:.6f}'.format(dd) for dd in d])))
    # Drop the matrix (and the last row reference) before the external run.
    del dist, d

    Popen([
        params['RapidNJ_{0}'.format(platform.system())], '-n', '-x',
        tree_file, '-i', 'pd', dist_file
    ], stdout=PIPE, stderr=PIPE).communicate()
    tree = Tree(tree_file)
    # Remove the distance file and everything derived from its name.
    for fname in glob(dist_file + '*'):
        os.unlink(fname)

    # Midpoint rooting followed by unrooting is best effort: a failure keeps
    # the tree as read. Only Exception is caught so KeyboardInterrupt and
    # SystemExit are no longer swallowed.
    try:
        tree.set_outgroup(tree.get_midpoint_outgroup())
        tree.unroot()
    except Exception:
        pass

    for leaf in tree.get_leaves():
        leaf.name = names[int(leaf.name.strip("'"))]
    return tree
def generateRootedTree(sequenceFileName):
    """Build a RAxML tree for a fasta file and write a midpoint-rooted copy.

    Writes ``data/OUTGROUP_<sequenceFileName>`` (the input sequences plus an
    all-'A' sequence named OUTGROUP), runs ``./raxml``, midpoint-roots the
    best tree, prunes it to the original entries and writes the result to
    ``trees_unrooted/RAxML_bestTree.R_<sequenceFileName>``.

    Args:
        sequenceFileName: Name of the fasta file (looked up under ``data/``
            by the RAxML command).
    """
    remove(sequenceFileName, True, True, True, False)
    fasta = fastaToDictionary(sequenceFileName)
    entries = list(fasta)

    # `with` guarantees the file is closed even if a write fails.
    length = 0
    with open("data/OUTGROUP_" + sequenceFileName, "w") as ofile:
        for entry in fasta:
            ofile.write(">" + entry + "\n" + fasta[entry] + "\n")
            length = len(fasta[entry])
        # The fictional outgroup is as long as the last sequence written.
        OUTGROUP = "A" * length
        ofile.write(">" + 'OUTGROUP' + "\n" + OUTGROUP + "\n")

    # Build the unrooted tree.
    # NOTE(review): the command reads data/<sequenceFileName>, not the
    # data/OUTGROUP_<sequenceFileName> file written above -- confirm intent.
    # NOTE(review): the file name is concatenated into a shell command; it
    # must come from a trusted source.
    outputFileName = sequenceFileName
    substitutionModel = "PROTGAMMABLOSUM62"
    outputDirectory = "/Users/williamlin/Desktop/IW/IW/phyloSim-master/trees_unrooted/"
    threads = "10"
    randomSeed = np.random.randint(0, 2000)
    command = "./raxml -s data/" + sequenceFileName + " -n " + outputFileName + " -m " + substitutionModel + " -w " + outputDirectory + " -T " + threads + " -p " + str(
        randomSeed) + " 1>>log_file"
    os.system(command)

    # Root based on midpoint, then keep only the original entries.
    current = Tree('trees_unrooted/RAxML_bestTree.' + sequenceFileName)
    current.set_outgroup(current.get_midpoint_outgroup())
    current.prune(entries)
    with open('trees_unrooted/RAxML_bestTree.R_' + sequenceFileName, "w") as ofile:
        ofile.write(current.write(format=1))
def get_root(prefix, tree_file):
    """Midpoint-root a newick tree and write it to ``<prefix>.rooted.nwk``.

    Args:
        prefix: Prefix of the output file name.
        tree_file: Newick input, read with ete format 1.

    Returns:
        The path of the written file, ``'<prefix>.rooted.nwk'``.
    """
    tree = Tree(tree_file, format=1)
    # Rooting is best effort: if it fails the tree is written as read. Only
    # Exception is caught so KeyboardInterrupt/SystemExit still propagate.
    try:
        tree.set_outgroup(tree.get_midpoint_outgroup())
    except Exception:
        pass
    rooted_file = '{0}.rooted.nwk'.format(prefix)
    tree.write(outfile=rooted_file, format=1)
    return rooted_file
def tree_to_tsvg(tree_file, contaminants=None, backpropagation=None):
    """Render a midpoint-rooted tree to SVG and read or write its TSV table.

    Relies on the module-level ``args`` and ``output_folder`` and on the
    helpers ``get_build_len``, ``get_best_candidates``, ``format_nodes`` and
    ``format_leaves``.

    Args:
        tree_file: Path of the tree file; the second dot-separated field of
            its base name is used as the dataset name.
        contaminants: Set forwarded to ``format_leaves``; defaults to an
            empty set.
        backpropagation: When falsy the TSV table is opened for writing,
            otherwise the existing table is opened for reading.
    """
    if contaminants is None:
        contaminants = set()

    tree_base = str(os.path.basename(tree_file))
    # NOTE(review): assumes a RAxML-style file name with the dataset name as
    # the second dot-separated field.
    name_ = tree_base.split('.')[1]

    if os.path.isfile(f'{args.input}/{name_}.trimmed'):
        build_len, len_dict, trimmed_len = get_build_len(name_)
        len_info = f'Final Align Len: {build_len}, Trimmed Align Len: {trimmed_len}'
        len_dict = {k: round(v / trimmed_len, 2) for k, v in len_dict.items()}
    else:
        build_len, len_dict = get_build_len(name_)
        len_info = f'Final Align Len: {build_len}'
        len_dict = {k: round(v / build_len, 2) for k, v in len_dict.items()}

    table_path = f"{output_folder}/{name_.split('_')[0]}.tsv"
    table_mode = 'r' if backpropagation else 'w'
    # The table used to be closed only in write mode and only on success;
    # the context manager closes it in both modes and on errors.
    with open(table_path, table_mode) as table:
        top_ranked = get_best_candidates(tree_file)
        t = Tree(tree_file)
        ts = TreeStyle()
        R = t.get_midpoint_outgroup()
        t.set_outgroup(R)

        sus_clades = 0
        for node in t.traverse('preorder'):
            node_style = NodeStyle()
            node_style['vt_line_width'] = 3
            node_style['hz_line_width'] = 3
            node_style['vt_line_type'] = 0
            node_style['hz_line_type'] = 0
            # The root keeps its default style.
            if node.is_root():
                continue
            if not node.is_leaf():
                # All internal nodes.
                supp, sus_clades = format_nodes(node, node_style, sus_clades, t)
                node.add_face(supp, column=0, position="branch-bottom")
            else:
                # All leaves.
                format_leaves(backpropagation, contaminants, node, node_style,
                              table, top_ranked, len_dict)
            node.set_style(node_style)

        title_face = TextFace(
            f'<{name_} {len_info}, {sus_clades} suspicious clades>', bold=True)
        ts.title.add_face(title_face, column=1)
        t.render(
            f'{output_folder}/{name_}_tree.svg',
            tree_style=ts,
        )
def phylogenetic_tree_to_cluster_format(tree, pairwise_estimates):
    """
    Convert a phylogenetic tree to a 'cluster' data structure as in
    ``fastcluster``. The first two columns indicate the nodes that are joined
    by the relevant node, the third indicates the distance (calculated from
    branch lengths in the case of a phylogenetic tree) and the fourth the
    number of leaves underneath the node. Note that the trees are rooted using
    midpoint-rooting.

    Example of the data structure (output from ``fastcluster``)::

        [[   3.       7.       4.26269776    2.  ]
         [   0.       5.      26.75703595    2.  ]
         [   2.       8.      56.16007598    2.  ]
         [   9.      12.      78.91813609    3.  ]
         [   1.      11.      87.91756528    3.  ]
         [   4.       6.      93.04790855    2.  ]
         [  14.      15.     114.71302639    5.  ]
         [  13.      16.     137.94616373    8.  ]
         [  10.      17.     157.29055403   10.  ]]

    Leaves are renamed in place to their row position in
    ``pairwise_estimates.index``; internal nodes are numbered from
    ``len(pairwise_estimates)`` upwards in postorder.

    :param tree: newick tree file
    :param pairwise_estimates: pairwise Ks estimates data frame (pandas)
        (only the index is used)
    :return: clustering data structure, pairwise distances dictionary
        (leaf id -> {leaf id: distance})
    """
    # Map every label in the data frame index to its row position.
    id_map = {
        pairwise_estimates.index[i]: i for i in range(len(pairwise_estimates))}
    t = Tree(tree)

    # midpoint rooting
    midpoint = t.get_midpoint_outgroup()
    if not midpoint:  # midpoint = None when there are only two leaves
        midpoint = list(t.get_leaves())[0]
    t.set_outgroup(midpoint)
    logging.debug('Tree after rooting:\n{}'.format(t.get_ascii()))

    # algorithm for getting cluster data structure
    n = len(id_map)  # first id handed out to an internal node
    out = []
    pairwise_distances = {}
    for node in t.traverse('postorder'):
        if node.is_leaf():
            node.name = id_map[node.name]
            id_map[node.name] = node.name  # add identity map for renamed nodes
            # to id_map for line below
            # Leaves visited earlier already carry their integer id, later
            # ones still carry their label; id_map resolves both.
            pairwise_distances[node.name] = {
                id_map[x.name]: node.get_distance(x) for x in t.get_leaves()
            }
        else:
            node.name = n
            n += 1
            children = node.get_children()
            # Only the first two children are recorded.
            # NOTE(review): assumes a bifurcating tree -- confirm.
            out.append(
                [children[0].name, children[1].name,
                 children[0].get_distance(children[1]),
                 len(node.get_leaves())])
    return np.array(out), pairwise_distances
def parse_tree(tree_file):
    """Read a newick file and return it as a midpoint-rooted tree.

    The file content is joined into a single string with the line breaks
    removed before it is parsed.

    Args:
        tree_file: Path of the newick file.

    Returns:
        The tree, rooted on its midpoint outgroup.
    """
    # The "U" flag was removed in Python 3.11 (open() raises ValueError);
    # plain text mode already applies universal newlines.
    with open(tree_file, "r") as handle:
        newick = ''.join(line.rstrip("\n") for line in handle)
    tree = Tree(newick)
    outgroup = tree.get_midpoint_outgroup()
    tree.set_outgroup(outgroup)
    return tree
def parse_tree(tree_file):
    """Load a tree structure from a newick file and resolve its polytomies.

    The midpoint outgroup is computed but not applied: the tree is returned
    without being re-rooted.
    """
    tree = Tree(tree_file)
    # Needed, otherwise groups derived from the tree are inaccurate.
    tree.resolve_polytomy()
    # Forcing a root was meant to give consistency, but applying it is
    # disabled (the set_outgroup call is switched off); support for a
    # user-defined root is still to be added.
    _ = tree.get_midpoint_outgroup()
    return tree
def CreatePhyloGeneticTree(inputfile, outputfile, size):
    """Render the newick tree on the first line of a file as an image.

    Args:
        inputfile: Path of a file whose first line is a newick string.
        outputfile: Path of the image to render (converted with ``str``).
        size: Image width in pixels.
    """
    # Only the first line is parsed; the context manager closes the file
    # even when it is empty and the indexing raises.
    with open(inputfile, "r") as f:
        data = f.readlines()[0]

    tree = Tree(data)
    tree.set_outgroup(tree.get_midpoint_outgroup())

    ts = TreeStyle()
    ts.show_leaf_name = True
    ts.show_branch_length = False
    ts.show_branch_support = False
    ts.optimal_scale_level = "mid"
    # The style configured above was previously discarded because render was
    # called with tree_style=None; pass it so the settings take effect.
    tree.render(str(outputfile), w=size, units="px", tree_style=ts)
def midpoint_root(tree):
    '''
    Midpoint-root a tree produced by tree_build using ete3 and write the
    rooted tree next to the input as "<tree>.rooted" (newick format 1).
    '''
    rooted = Tree(tree, format=1)
    # The midpoint node becomes the outgroup of the tree.
    rooted.set_outgroup(rooted.get_midpoint_outgroup())
    rooted.write(format=1, outfile=tree + ".rooted")
def ninja(names, profiles, embeded, handle_missing='pair_delete', **params):
    """Build a neighbor-joining tree with the external NINJA jar.

    Distances are divided by the number of profile columns before they are
    written, and branch lengths are multiplied back afterwards.

    Args:
        names: Sequence of leaf names, indexed by the row number written to
            the distance file.
        profiles: Profile matrix passed to ``distance_matrix.get_distance``;
            ``profiles.shape[1]`` is the scaling factor.
        embeded: Not used by this function.
        handle_missing: Missing-data strategy forwarded to
            ``distance_matrix.get_distance``.
        **params: Must contain ``'tempfix'`` (prefix for temporary files) and
            ``'ninja_<platform.system()>'`` (path of the jar).

    Returns:
        The ete tree parsed from NINJA's standard output, with leaves
        renamed from row indices to ``names``.
    """
    dist = distance_matrix.get_distance('symmetric', profiles, handle_missing)
    dist = dist / profiles.shape[1]

    dist_file = params['tempfix'] + 'dist.list'
    with open(dist_file, 'w') as fout:
        fout.write(' {0}\n'.format(dist.shape[0]))
        for n, d in enumerate(dist):
            fout.write('{0!s:10} {1}\n'.format(
                n, ' '.join(['{:.6f}'.format(dd) for dd in d])))
    # Drop the matrix (and the last row reference) before the external run.
    del dist, d

    ninja_jar = params['ninja_{0}'.format(platform.system())]
    # Offer the JVM 90% of the total physical memory, in MiB.
    free_memory = int(0.9 * psutil.virtual_memory().total / (1024.**2))
    ninja_out = Popen([
        'java', '-d64', '-Xmx' + str(free_memory) + 'M', '-jar', ninja_jar,
        '--in_type', 'd', dist_file
    ], stdout=PIPE, stderr=PIPE, universal_newlines=True).communicate()
    # Retry with a small fixed heap when the JVM complains about 64-bit mode.
    if ninja_out[1].find('64-bit JVM') >= 0:
        ninja_out = Popen([
            'java', '-Xmx1200M', '-jar', ninja_jar, '--in_type', 'd',
            dist_file
        ], stdout=PIPE, stderr=PIPE, universal_newlines=True).communicate()

    with open(dist_file + '.nwk', 'wt') as fout:
        fout.write(ninja_out[0])
    tree = Tree(dist_file + '.nwk')
    # Remove the distance file and everything derived from its name.
    for fname in glob(dist_file + '*'):
        os.unlink(fname)

    # Undo the per-column scaling applied to the distances above.
    for node in tree.traverse():
        node.dist *= profiles.shape[1]

    # Midpoint rooting followed by unrooting is best effort: a failure keeps
    # the tree as read. Only Exception is caught so KeyboardInterrupt and
    # SystemExit are no longer swallowed.
    try:
        tree.set_outgroup(tree.get_midpoint_outgroup())
        tree.unroot()
    except Exception:
        pass

    for leaf in tree.get_leaves():
        leaf.name = names[int(leaf.name.strip("'"))]
    return tree
def get_root(prefix, tree_file):
    """Collapse zero-length internal branches, midpoint-root and write a tree.

    Args:
        prefix: Prefix of the output file name.
        tree_file: Newick input, read with ete format 1.

    Returns:
        The path of the written file, ``'<prefix>.rooted.nwk'``.
    """
    tree = Tree(tree_file, format=1)
    # Remove every internal, non-root node whose branch length is zero by
    # re-attaching its children to its parent.
    for node in tree.traverse():
        if node.dist == 0 and node.up and not node.is_leaf():
            for c in node.get_children():
                node.up.add_child(c)
                c.up = node.up
            node.up.remove_child(node)
    # Rooting is best effort: if it fails the tree is written unrooted. Only
    # Exception is caught so KeyboardInterrupt/SystemExit still propagate.
    try:
        tree.set_outgroup(tree.get_midpoint_outgroup())
    except Exception:
        pass
    rooted_file = '{0}.rooted.nwk'.format(prefix)
    tree.write(outfile=rooted_file, format=1)
    return rooted_file
def read_tree(file_name):
    """Read a tree, try to midpoint-root it, and ladderize it.

    A ``file_name`` of "-" reads from /dev/stdin. Errors raised while
    applying the midpoint outgroup are ignored.
    """
    source = "/dev/stdin" if file_name == "-" else file_name
    tree = Tree(source)
    midpoint = tree.get_midpoint_outgroup()
    # Ignore any errors with midpoint rooting.
    try:
        tree.set_outgroup(midpoint)
    except Exception:
        pass
    tree.ladderize(direction=1)
    return tree
def drawTree(distDict, alleleDict, sample_list, outgroup, prefix, bootstrap):
    """Build a neighbor-joining tree from pairwise cell distances.

    Args:
        distDict: Mapping whose ``"sampleComp"`` entry maps a sorted
            ``(sample, sample)`` tuple to a dict with ``"dist"`` and
            ``"num_targets"``.
        alleleDict: Not used by this function.
        sample_list: Sample names; they are used in sorted order for both
            axes of the matrices.
        outgroup: ``"NA"`` returns the tree as built, ``"Midpoint"`` roots it
            on the midpoint outgroup, any other value is passed to
            ``set_outgroup``.
        prefix: Prefix of the statistics and pickle output files.
        bootstrap: When exactly ``False`` (original tree, not a bootstrap
            resample) the statistics file and the pickled ``distDict`` are
            written.

    Returns:
        The ete tree, or ``None`` when midpoint rooting was requested and no
        midpoint outgroup was found.
    """
    samples = sorted(sample_list)
    comparisons = distDict["sampleComp"]

    distMatrix = []
    targetMatrix = []
    pairwise_numTargets = []
    sample_numTargets = []
    for sample1 in samples:
        sample1_dist = []
        sample1_targets = []
        for sample2 in samples:
            pair_info = comparisons[tuple(sorted([sample1, sample2]))]
            sample1_dist.append(pair_info["dist"])
            sample1_targets.append(pair_info["num_targets"])
            # Diagonal entries count targets captured by a single cell.
            if sample1 != sample2:
                pairwise_numTargets.append(pair_info["num_targets"])
            else:
                sample_numTargets.append(pair_info["num_targets"])
        distMatrix.append(sample1_dist)
        targetMatrix.append(sample1_targets)

    # Only output statistics for distance and number of targets shared for
    # the original tree (don't output for bootstrap resampling).
    if bootstrap is False:
        with open(prefix + ".buildPhylo.stats.txt", 'w') as statsOutput:
            statsOutput.write("Number of Samples Analyzed:\t" + str(len(sample_list)) + "\n" + ','.join(sample_list) + "\n")
            statsOutput.write("Avg targets shared per pair of cells:\t" + str(float(sum(pairwise_numTargets) / len(pairwise_numTargets))) + "\t[" + str(min(pairwise_numTargets)) + "," + str(max(pairwise_numTargets)) + "]\n")
            statsOutput.write("Avg targets captured per single cell:\t" + str(float(sum(sample_numTargets) / len(sample_numTargets))) + "\t[" + str(min(sample_numTargets)) + "," + str(max(sample_numTargets)) + "]\n")
            # Matrix containing distances.
            for dist_indx, dist_list in enumerate(distMatrix):
                statsOutput.write(samples[dist_indx] + "," + ",".join(str(round(i, 3)) for i in dist_list) + "\n")
            # Matrix containing number of targets shared between each pair.
            for target_indx, target_list in enumerate(targetMatrix):
                statsOutput.write(samples[target_indx] + "," + ",".join(str(j) for j in target_list) + "\n")
        # Save the distance information for each single cell pair that was
        # used to build the tree (useful for downstream statistics).
        with open(prefix + ".buildPhylo.distDict.pkl", "wb") as pickleOutput:
            pickle.dump(distDict, pickleOutput)

    # skbio builds the tree as a newick string, which is then loaded as an
    # ete tree.
    distObj = DistanceMatrix(distMatrix, samples)
    skbio_tree = nj(distObj, result_constructor=str)
    ete_tree = Tree(skbio_tree)

    # Compare by value: `is "NA"` only matched when the caller's string
    # happened to be the same interned object as the literal.
    if outgroup == "NA":
        return ete_tree
    if outgroup == "Midpoint":
        tree_midpoint = ete_tree.get_midpoint_outgroup()
        if tree_midpoint is None:
            # Throw out the tree if the midpoint was not found.
            print(ete_tree.write(format = 0))
            return None
        ete_tree.set_outgroup(tree_midpoint)
    else:
        ete_tree.set_outgroup(outgroup)
    return ete_tree
def root_iqtree(self):
    """Root the iqtree output at its midpoint or at a user-defined outgroup.

    Reads ``self.outfiles["treefile"]``, roots it on the midpoint outgroup
    when ``self.root`` is ``'midpoint'`` (otherwise on ``self.root`` itself),
    ladderizes it and writes it to ``self.outfiles["rooted_treefile"]``.
    """
    from ete3 import Tree

    tree = Tree(self.outfiles["treefile"], format=0)
    requested_root = self.root
    outgroup = (tree.get_midpoint_outgroup()
                if requested_root == 'midpoint' else requested_root)
    tree.set_outgroup(outgroup)
    tree.ladderize(direction=1)
    # dist_formatter prevents scientific notation: with branch lengths in
    # scientific notation, ClusterPicker dies.
    tree.write(outfile=self.outfiles["rooted_treefile"],
               dist_formatter="%0.16f")
def reroot_trees(tree_computation, species_tree_polytomies):
    """Midpoint-root the newick tree of one tree computation.

    Args:
        tree_computation: Tuple ``(nog_id, tree_time, tree_nw)``.
        species_tree_polytomies: When truthy, polytomies of the gene tree are
            resolved recursively first (the reconciliation algorithm can only
            have one input with multifurcations/polytomies).

    Returns:
        ``(nog_id, tree_time, newick)`` with the re-rooted newick, or with
        the original newick (and a message on stderr) when no midpoint
        outgroup was found.
    """
    nog_id, tree_time, tree_nw = tree_computation
    assert tree_nw, "Tree newick is non existant for %d %d"%(nog_id,tree_time)

    gene_tree = Tree(tree_nw)
    if species_tree_polytomies:
        gene_tree.resolve_polytomy(recursive=True)

    outgroup = gene_tree.get_midpoint_outgroup()
    if not outgroup:
        sys.stderr.write('Problems in rerooting %s %s %s'%(nog_id,tree_time,tree_nw))
        return (nog_id, tree_time, tree_nw)

    gene_tree.set_outgroup(outgroup)
    return (nog_id, tree_time, gene_tree.write())
def njWithRoot(dis_matrix, muestraPmid):
    """Build a midpoint-rooted neighbor-joining tree from a distance matrix.

    No distance is computed here; the matrix is only paired with the ids
    (converted to strings) in the format the tree builder expects.

    Returns the rooted tree after a round trip through newick format 3,
    re-read with format 1.
    """
    labels = list(map(str, muestraPmid))
    distances = DistanceMatrix(dis_matrix.tolist(), labels)
    unrooted_newick = nj(distances, result_constructor=str)

    # Give the tree a root.
    rooted = TreeEte(unrooted_newick)
    rooted.set_outgroup(rooted.get_midpoint_outgroup())

    # Serialise to newick and parse it back.
    return TreeEte(rooted.write(format=3), format=1)
def analyze_tree(tree_filename, full_name_studied_gene, node_support):
    """Prune weakly supported nodes, midpoint-root a gene tree and analyse it.

    :param tree_filename: newick tree, read with ete format 0
    :param full_name_studied_gene: forwarded unchanged to ``analysis``
    :param node_support: minimum support; 0 disables pruning. It is divided
        by 100 when every support value in the tree is <= 1.
    :return: the string ``"only_TOI"`` when the leaf scan leaves ``only_TOI``
        set, otherwise the result of ``analysis(gene_tree, ...)``
    """
    # Declared global, but none of these names is assigned in this function.
    global result_nohgt
    global result_hgt
    global result_complex
    global result_unknown
    # Load a tree structure from a newick file
    gene_tree = Tree(tree_filename,format=0)
    if node_support != 0:
        node_supports = []
        for node in gene_tree.traverse("preorder"):
            node_supports.append(node.support)
        # Supports on a 0-1 scale: bring the threshold to the same scale.
        if all(i <= 1 for i in node_supports):
            node_support = node_support/100
        # Nodes are deleted while the tree is being traversed; nodes whose
        # name contains "@" are never deleted.
        for node in gene_tree.traverse("preorder"):
            if "@" not in node.name:
                if node.support < node_support:
                    node.delete()
    no_TOI = True  # NOTE(review): updated below but never read here
    only_TOI = True
    # Check if no_TOI or only_TOI to speed up calculations
    for node in gene_tree:
        if "@TOI" in str(node):
            no_TOI = False
        # NOTE(review): with `or`, a leaf lacking either substring clears
        # only_TOI -- confirm that `and` was not intended.
        elif "EGP" not in str(node) or "StudiedOrganism" not in str(node):
            only_TOI = False
    if only_TOI:
        return "only_TOI"
    # Root the tree using the midpoint
    R = gene_tree.get_midpoint_outgroup()
    # The tree is left as it is when no midpoint outgroup is returned.
    if(R != None):
        gene_tree.set_outgroup(R)
    return analysis(gene_tree, full_name_studied_gene)
def generateTestCases(n=500):
    """Generate host/guest test cases under ``60_examples/``.

    For each of ``n / 10`` host trees, ten guest trees are simulated. Every
    case gets its own folder ``60_examples/<index>/`` holding the simulated
    files, the guest tree (``guest.nwk``, format 1) and the midpoint-rooted
    RAxML tree (``RAxML_bestTree.nwk``).

    Args:
        n: Total number of test cases to generate.
    """
    hostCases = 0
    while hostCases < n / 10:
        # Retry until a host is produced. Catching Exception (not a bare
        # except) keeps Ctrl-C able to interrupt the retry loop.
        try:
            host = withHost(8, .3)[0]
        except Exception:
            continue

        guestCases = 0
        while guestCases < 10:
            printProgressBar(hostCases * 10 + guestCases, n)
            try:
                guest = withHost(8, .3, host)[1]
            except Exception:
                continue

            writeMapping(genMap(host, guest), 'guest.map')
            folder_name = '60_examples/' + str(hostCases*10 + guestCases) + '/'
            system('mkdir ' + folder_name)
            system('mv host.nwk guest.nwk sequences.fa guest.map ' + folder_name)
            # Keep the moved guest tree as guest_full.nwk, then write the
            # in-memory guest tree as guest.nwk.
            system('mv ' + folder_name + 'guest.nwk ' + folder_name + 'guest_full.nwk')
            guest.write(format=1, outfile=folder_name + 'guest.nwk')

            # Run RAxML (after clearing the outputs of the previous run).
            system('rm RAxML_*')
            raxml(folder_name + 'sequences.fa', 'nwk')
            rax = Tree('RAxML_bestTree.nwk')
            rax.set_outgroup(rax.get_midpoint_outgroup())
            name(rax)
            writeTree(rax, folder_name + 'RAxML_bestTree.nwk')
            guestCases += 1
        hostCases += 1
def get_tree(infile):
    """Load a tree, keep one node per strain and midpoint-root it.

    Leaf names are stripped of quotes and cut at the first '.'; a leaf named
    'genome' becomes 'NT12001_189'. The strain is the part of the name before
    the first '_'. For every strain only the first matching node (sorted by
    name) is kept, then leaves are renamed to the bare strain.
    """
    tree = Tree(infile)

    # Normalise the leaf names.
    for node in tree.traverse():
        if node.is_leaf():
            cleaned = node.name.replace("'", '').split('.')[0]
            node.name = 'NT12001_189' if cleaned == 'genome' else cleaned

    strains = {x.name.split('_')[0] for x in tree.traverse() if x.is_leaf()}

    # Delete every node but the first (by name) whose name starts with the
    # strain.
    for strain in strains:
        matches = [x for x in tree.traverse() if x.name.startswith(strain)]
        matches.sort(key=lambda x: x.name)
        for duplicate in matches[1:]:
            duplicate.delete()

    # Reduce the leaf names to the strain.
    for node in tree.traverse():
        if node.is_leaf():
            node.name = node.name.split('_')[0]

    tree.set_outgroup(tree.get_midpoint_outgroup())
    return tree
def get_species_tree(biodb):
    """Build a species-level tree from the reference phylogeny of ``biodb``.

    The reference tree is read from the database, midpoint-rooted, and its
    leaves are renamed from taxon id to species id. Clades whose leaves all
    share one name are then collapsed into a single leaf, and each leaf of
    the collapsed tree is labelled with its number of complete ("c") and
    draft ("d") genomes and, for drafts, a completeness value or range.

    :param biodb: biodatabase name; it is interpolated into the SQL queries
        and table names, so it must come from a trusted source
    :return: (complete_tree, species_tree)
    """
    from ete3 import Tree,TreeStyle
    server, db = manipulate_biosqldb.load_db(biodb)
    species2n_complete_genomes, species2n_draft_genomes, species2completeness = get_species_data(server, biodb)
    sql_tree = 'select tree from reference_phylogeny t1 inner join biodatabase t2 on t1.biodatabase_id=t2.biodatabase_id ' \
               ' where t2.name="%s";' % biodb
    server, db = manipulate_biosqldb.load_db(biodb)
    # First column of the first returned row holds the tree.
    complete_tree = Tree(server.adaptor.execute_and_fetchall(sql_tree,)[0][0])
    R = complete_tree.get_midpoint_outgroup()
    complete_tree.set_outgroup(R)
    sql = 'select distinct taxon_id,species from taxid2species_%s t1 ' \
          ' inner join species_curated_taxonomy_%s t2 on t1.species_id=t2.species_id;' % (biodb, biodb)
    taxon_id2species_id = manipulate_biosqldb.to_dict(server.adaptor.execute_and_fetchall(sql,))
    # changing taxon id to species id
    for leaf in complete_tree.iter_leaves():
        #print '%s --> %s' % (leaf.name, str(taxon_id2species_id[str(leaf.name)]))
        leaf.name = "%s" % str(taxon_id2species_id[str(leaf.name)])
    # attributing unique id to each node
    # if all node descendant have the same name, use that name as node name
    n = 0
    for node in complete_tree.traverse():
        if node.name=='':
            # Distinct non-empty names found below this node.
            desc_list = list(set([i.name for i in node.iter_descendants()]))
            try:
                desc_list.remove('')
            except ValueError:
                pass
            if len(desc_list) != 1:
                node.name = '%sbb' % n
            else:
                node.name = desc_list[0]
            n+=1
    # Collapsing nodes while traversing
    # http://etetoolkit.org/docs/latest/tutorial/tutorial_trees.html#collapsing-nodes-while-traversing-custom-is-leaf-definition
    node2labels = complete_tree.get_cached_content(store_attr="name")
    def collapsed_leaf(node):
        # A node whose cached content holds a single name is written as a leaf.
        if len(node2labels[node]) == 1:
            return True
        else:
            return False
    species_tree = Tree(complete_tree.write(is_leaf_fn=collapsed_leaf))
    for lf_count, lf in enumerate(species_tree.iter_leaves()):
        # A failed lookup is treated as "no genomes of that kind".
        try:
            n_complete_genomes = species2n_complete_genomes[lf.name]
        except:
            n_complete_genomes = False
        try:
            n_draft_genomes = species2n_draft_genomes[lf.name]
        except:
            n_draft_genomes = False
        if n_draft_genomes:
            # The two stored completeness values are shown as one value when
            # they round to the same number, otherwise as a range.
            c1 = round(species2completeness[lf.name][0])
            c2 = round(species2completeness[lf.name][1])
            if c1 == c2:
                completeness = "%s%%" % c1
            else:
                completeness = "%s-%s%%" % (c1, c2)
        if n_complete_genomes and n_draft_genomes:
            lf.name = "%s (%sc/%sd, %s)" % (lf.name, n_complete_genomes, n_draft_genomes, completeness)
        if n_complete_genomes and not n_draft_genomes:
            lf.name = "%s (%sc)" % (lf.name, n_complete_genomes)
        if not n_complete_genomes and n_draft_genomes:
            lf.name = "%s (%sd, %s)" % (lf.name, n_draft_genomes, completeness)
    return complete_tree, species_tree
def plot_phylo(nw_tree, out_name, parenthesis_classif=True, show_support=False, radial_mode=False, root=False):
    """Render a newick tree to ``out_name`` with coloured, italic leaf names.

    :param nw_tree: newick tree, read with ete format 0
    :param out_name: path of the rendered image
    :param parenthesis_classif: when true, a classification string is taken
        from each leaf name (second-to-last '_' field minus its last
        character) and each classification gets its own leaf colour
    :param show_support: forwarded to ``TreeStyle.show_branch_support``; it
        also selects how node styles are set
    :param radial_mode: when true, the tree is drawn in circular mode
    :param root: when true, the tree is first rooted on its midpoint outgroup
    """
    from ete3 import Tree, AttrFace, TreeStyle, NodeStyle, TextFace
    import orthogroup2phylogeny_best_refseq_uniprot_hity
    ete2_tree = Tree(nw_tree, format=0)
    if root:
        R = ete2_tree.get_midpoint_outgroup()
        # and set it as tree outgroup
        ete2_tree.set_outgroup(R)
    # NOTE(review): this hard-coded outgroup overrides the midpoint rooting
    # above; whether it belongs inside `if root:` cannot be told from the
    # flattened source -- confirm.
    ete2_tree.set_outgroup('Bacillus subtilis')
    ete2_tree.ladderize()
    if parenthesis_classif:
        print('parenthesis_classif!')
        name2classif = {}
        for lf in ete2_tree.iter_leaves():
            print(lf)
            # Leaves whose name cannot be split this way get no entry.
            try:
                classif = lf.name.split('_')[-2][0:-1]
                print('classif', classif)
                #lf.name = lf.name.split('(')[0]
                name2classif[lf.name] = classif
            except:
                pass
        classif_list = list(set(name2classif.values()))
        # One colour per distinct classification.
        classif2col = dict(
            zip(
                classif_list,
                orthogroup2phylogeny_best_refseq_uniprot_hity.
                get_spaced_colors(len(classif_list))))
    for lf in ete2_tree.iter_leaves():
        #try:
        if parenthesis_classif:
            # Leaves without a classification fall back to black.
            try:
                col = classif2col[name2classif[lf.name]]
            except:
                col = 'black'
        else:
            col = 'black'
        #print col
        #lf.name = '%s|%s-%s' % (lf.name, accession2name_and_phylum[lf.name][0],accession2name_and_phylum[lf.name][1])
        # Both branches currently build the same face.
        if radial_mode:
            ff = AttrFace("name", fsize=12, fstyle='italic')
        else:
            ff = AttrFace("name", fsize=12, fstyle='italic')
        #ff.background.color = 'red'
        ff.fgcolor = col
        lf.add_face(ff, column=0)
        # NOTE(review): the node styling below does not depend on `lf`; it
        # appears to sit inside the leaf loop (so it is repeated for every
        # leaf) -- confirm the nesting against the original file.
        if not show_support:
            print('support')
            for n in ete2_tree.traverse():
                print(n.support)
                nstyle = NodeStyle()
                # Nodes with support below 1 get a visible red dot.
                if float(n.support) < 1:
                    nstyle["fgcolor"] = "red"
                    nstyle["size"] = 4
                    n.set_style(nstyle)
                else:
                    nstyle["fgcolor"] = "red"
                    nstyle["size"] = 0
                    n.set_style(nstyle)
        else:
            for n in ete2_tree.traverse():
                nstyle = NodeStyle()
                nstyle["fgcolor"] = "red"
                nstyle["size"] = 0
                n.set_style(nstyle)
        #nameFace = AttrFace(lf.name, fsize=30, fgcolor=phylum2col[accession2name_and_phylum[lf.name][1]])
        #faces.add_face_to_node(nameFace, lf, 0, position="branch-right")
        #
        #nameFace.border.width = 1
        '''
        except:
            col = 'red'
            print col
            lf.name = '%s| %s' % (lf.name, locus2organism[lf.name])
            ff = AttrFace("name", fsize=12)
            #ff.background.color = 'red'
            ff.fgcolor = col
            lf.add_face(ff, column=0)
        '''
        #n = TextFace(lf.name, fgcolor = "black", fsize = 12, fstyle = 'italic')
        #lf.add_face(n, 0)
    '''
    for n in ete2_tree.traverse():
        nstyle = NodeStyle()
        if n.support < 90:
            nstyle["fgcolor"] = "black"
            nstyle["size"] = 4
            n.set_style(nstyle)
        else:
            nstyle["fgcolor"] = "red"
            nstyle["size"] = 0
            n.set_style(nstyle)
    '''
    ts = TreeStyle()
    # Leaf names are drawn by the faces added above, not by the tree style.
    ts.show_leaf_name = False
    #ts.scale=2000
    #ts.scale=20000
    ts.show_branch_support = show_support
    if radial_mode:
        ts.mode = "c"
        ts.arc_start = -90
        ts.arc_span = 360
    ts.tree_width = 370
    ts.complete_branch_lines_when_necessary = True
    ete2_tree.render(out_name, tree_style=ts, w=900)
def plot_tree_barplot(tree_file, taxon2value_list_barplot, header_list, taxon2set2value_heatmap=False, header_list2=False, column_scale=True, general_max=False, barplot2percentage=False, taxon2mlst=False):
    '''
    Decorate a midpoint-rooted tree with one or more aligned barplot columns,
    an optional MLST column and optional heatmap columns.

    :param tree_file: ete Tree instance, or anything accepted by ``Tree()``
    :param taxon2value_list_barplot: leaf name -> list of values, one per
        entry of ``header_list``
    :param header_list: barplot column headers
    :param taxon2set2value_heatmap: heatmap column name -> {leaf name: value}
    :param header_list2: heatmap column headers
    :param column_scale: when true (and ``header_list2`` is given) one colour
        scale is built per heatmap column
    :param general_max: when set, used as the bar maximum of every column
        instead of the column's own maximum
    :param barplot2percentage: list of bool to indicates if the number are
        percentages and the range should be set to 0-100
    :param taxon2mlst: leaf name (as int) -> MLST label
    :return: (tree, tree style)
    '''
    import matplotlib.cm as cm
    from matplotlib.colors import rgb2hex
    import matplotlib as mpl
    if taxon2mlst:
        # One background colour per distinct MLST label; '-' stays white.
        mlst_list = list(set(taxon2mlst.values()))
        mlst2color = dict(zip(mlst_list, get_spaced_colors(len(mlst_list))))
        mlst2color['-'] = 'white'
    if isinstance(tree_file, Tree):
        t1 = tree_file
    else:
        t1 = Tree(tree_file)
    # Calculate the midpoint node
    R = t1.get_midpoint_outgroup()
    # and set it as tree outgroup
    t1.set_outgroup(R)
    tss = TreeStyle()
    value = 1
    tss.draw_guiding_lines = True
    tss.guiding_lines_color = "gray"
    tss.show_leaf_name = False
    if column_scale and header_list2:
        import matplotlib.cm as cm
        from matplotlib.colors import rgb2hex
        import matplotlib as mpl
        # One colour scale per heatmap column, spanning that column's
        # minimum and maximum.
        column2scale = {}
        for column in header_list2:
            values = taxon2set2value_heatmap[column].values()
            norm = mpl.colors.Normalize(vmin=min(values), vmax=max(values))
            cmap = cm.OrRd
            m = cm.ScalarMappable(norm=norm, cmap=cmap)
            column2scale[column] = m
    cmap = cm.YlGnBu #YlOrRd#OrRd
    values_lists = taxon2value_list_barplot.values()
    scale_list = []
    max_value_list = []
    # One colour scale and one bar maximum per barplot column.
    for n, header in enumerate(header_list):
        #print 'scale', n, header
        data = [float(i[n]) for i in values_lists]
        if barplot2percentage is False:
            max_value = max(data) #3424182#
            min_value = min(data) #48.23
        else:
            if barplot2percentage[n] is True:
                max_value = 100
                min_value = 0
            else:
                max_value = max(data) #3424182#
                min_value = min(data) #48.23
        norm = mpl.colors.Normalize(vmin=min_value, vmax=max_value)
        m1 = cm.ScalarMappable(norm=norm, cmap=cmap)
        scale_list.append(m1)
        if not general_max:
            max_value_list.append(float(max_value))
        else:
            max_value_list.append(general_max)
    for i, lf in enumerate(t1.iter_leaves()):
        #if taxon2description[lf.name] == 'Pirellula staleyi DSM 6068':
        #    lf.name = 'Pirellula staleyi DSM 6068'
        #    continue
        # The column headers are added once, while handling the first leaf.
        if i == 0:
            col_add = 0
            if taxon2mlst:
                # From here on header_list has one more entry than the
                # per-leaf value lists.
                header_list = ['MLST'] + header_list
            for col, header in enumerate(header_list):
                #lf.add_face(n, column, position="aligned")
                # Each header takes two header columns: the label and a
                # blank face next to it.
                n = TextFace(' ')
                n.margin_top = 1
                n.margin_right = 2
                n.margin_left = 2
                n.margin_bottom = 1
                n.rotation = 90
                n.inner_background.color = "white"
                n.opacity = 1.
                n.hz_align = 2
                n.vt_align = 2
                tss.aligned_header.add_face(n, col_add + 1)
                n = TextFace('%s' % header)
                n.margin_top = 1
                n.margin_right = 2
                n.margin_left = 2
                n.margin_bottom = 2
                n.rotation = 270
                n.inner_background.color = "white"
                n.opacity = 1.
                n.hz_align = 2
                n.vt_align = 1
                tss.aligned_header.add_face(n, col_add)
                col_add += 2
            if header_list2:
                # Heatmap headers follow the barplot headers.
                for col, header in enumerate(header_list2):
                    n = TextFace('%s' % header)
                    n.margin_top = 1
                    n.margin_right = 20
                    n.margin_left = 2
                    n.margin_bottom = 1
                    n.rotation = 270
                    n.hz_align = 2
                    n.vt_align = 2
                    n.inner_background.color = "white"
                    n.opacity = 1.
                    tss.aligned_header.add_face(n, col + col_add)
        if taxon2mlst:
            # Any failure in the lookup (non-integer leaf name, unknown
            # leaf or label) gives a grey ' na ' cell.
            try:
                #if lf.name in leaf2mlst or int(lf.name) in leaf2mlst:
                n = TextFace(' %s ' % taxon2mlst[int(lf.name)])
                n.inner_background.color = 'white'
                m = TextFace(' ')
                m.inner_background.color = mlst2color[taxon2mlst[int(lf.name)]]
            except:
                n = TextFace(' na ')
                n.inner_background.color = "grey"
                m = TextFace(' ')
                m.inner_background.color = "white"
            n.opacity = 1.
            n.margin_top = 2
            n.margin_right = 2
            n.margin_left = 0
            n.margin_bottom = 2
            m.margin_top = 2
            m.margin_right = 0
            m.margin_left = 2
            m.margin_bottom = 2
            lf.add_face(m, 0, position="aligned")
            lf.add_face(n, 1, position="aligned")
            # The MLST cells use aligned columns 0 and 1.
            col_add = 2
        else:
            col_add = 0
        # Leaves without values get one 'na' per barplot column (the MLST
        # header added above is not counted).
        try:
            val_list = taxon2value_list_barplot[lf.name]
        except:
            if not taxon2mlst:
                val_list = ['na'] * len(header_list)
            else:
                val_list = ['na'] * (len(header_list) - 1)
        for col, value in enumerate(val_list):
            # show value itself
            try:
                n = TextFace(' %s ' % str(value))
            except:
                n = TextFace(' %s ' % str(value))
            n.margin_top = 1
            n.margin_right = 5
            n.margin_left = 10
            n.margin_bottom = 1
            n.inner_background.color = "white"
            n.opacity = 1.
            lf.add_face(n, col_add, position="aligned")
            # show bar
            # Non-numeric values (e.g. 'na') give a white, empty bar.
            try:
                color = rgb2hex(scale_list[col].to_rgba(float(value)))
            except:
                color = 'white'
            try:
                percentage = (value / max_value_list[col]) * 100
                #percentage = value
            except:
                percentage = 0
            try:
                maximum_bar = (
                    (max_value_list[col] - value) / max_value_list[col]) * 100
            except:
                maximum_bar = 0
            #maximum_bar = 100-percentage
            b = StackedBarFace([percentage, maximum_bar],
                               width=100,
                               height=10,
                               colors=[color, "white"])
            b.rotation = 0
            b.inner_border.color = "grey"
            b.inner_border.width = 0
            b.margin_right = 15
            b.margin_left = 0
            lf.add_face(b, col_add + 1, position="aligned")
            col_add += 2
        if taxon2set2value_heatmap:
            shift = col + col_add + 1
            # Rebinds the leaf index; enumerate assigns it again on the
            # next leaf.
            i = 0
            for col, col_name in enumerate(header_list2):
                # Look the leaf up by name, then by integer name; default 0.
                try:
                    value = taxon2set2value_heatmap[col_name][lf.name]
                except:
                    try:
                        value = taxon2set2value_heatmap[col_name][int(lf.name)]
                    except:
                        value = 0
                if int(value) > 0:
                    if int(value) > 9:
                        n = TextFace(' %i ' % int(value))
                    else:
                        n = TextFace(' %i ' % int(value))
                    n.margin_top = 1
                    n.margin_right = 1
                    n.margin_left = 20
                    n.margin_bottom = 1
                    n.fgcolor = "white"
                    # Requires column2scale, which is only built when
                    # column_scale and header_list2 are both set.
                    n.inner_background.color = rgb2hex(
                        column2scale[col_name].to_rgba(
                            float(value))) #"orange"
                    n.opacity = 1.
                    lf.add_face(n, col + col_add, position="aligned")
                    i += 1
                else:
                    n = TextFace(' ') #% str(value))
                    n.margin_top = 1
                    n.margin_right = 1
                    n.margin_left = 20
                    n.margin_bottom = 1
                    n.inner_background.color = "white"
                    n.opacity = 1.
                    lf.add_face(n, col + col_add, position="aligned")
        # Leaf label, drawn next to the branch.
        n = TextFace(lf.name, fgcolor="black", fsize=12, fstyle='italic')
        lf.add_face(n, 0)
    for n in t1.traverse():
        nstyle = NodeStyle()
        # Nodes with support below 1 get a visible black dot.
        if n.support < 1:
            nstyle["fgcolor"] = "black"
            nstyle["size"] = 6
            n.set_style(nstyle)
        else:
            nstyle["fgcolor"] = "red"
            nstyle["size"] = 0
            n.set_style(nstyle)
    return t1, tss
def plot_tree_barplot(tree_file, taxon2mlst, header_list):
    '''
    Midpoint-root a tree and decorate every leaf with its MLST assignment.

    Each leaf receives two aligned faces: a colour swatch (column 0) and the
    MLST label (column 1), followed by an italic face holding the leaf name.
    Leaves whose name cannot be converted to ``int`` or is absent from
    ``taxon2mlst`` are labelled ``na`` on a grey background.

    :param tree_file: an ete3 ``Tree`` instance, or any value accepted by the
        ``Tree`` constructor (it is passed through unchanged)
    :param taxon2mlst: dict keyed by integer taxon id, giving the MLST label
    :param header_list: not used by this function; kept so existing callers
        keep working
    :return: tuple ``(tree, tree_style)``
    '''
    # One colour per distinct MLST value; '-' (no MLST) is drawn in white.
    mlst_list = list(set(taxon2mlst.values()))
    mlst2color = dict(zip(mlst_list, get_spaced_colors(len(mlst_list))))
    mlst2color['-'] = 'white'

    if isinstance(tree_file, Tree):
        t1 = tree_file
    else:
        t1 = Tree(tree_file)

    # Calculate the midpoint node and set it as tree outgroup.
    R = t1.get_midpoint_outgroup()
    t1.set_outgroup(R)

    tss = TreeStyle()
    tss.draw_guiding_lines = True
    tss.guiding_lines_color = "gray"
    tss.show_leaf_name = False

    for i, lf in enumerate(t1.iter_leaves()):
        if i == 0:
            # Column header, added once, above the MLST label column.
            col_add = 0
            n = TextFace('MLST')
            n.margin_top = 1
            n.margin_right = 2
            n.margin_left = 2
            n.margin_bottom = 1
            n.rotation = 90
            n.inner_background.color = "white"
            n.opacity = 1.
            n.hz_align = 2
            n.vt_align = 2
            tss.aligned_header.add_face(n, col_add + 1)

        try:
            mlst = taxon2mlst[int(lf.name)]
            n = TextFace(' %s ' % mlst)
            n.inner_background.color = 'white'
            m = TextFace(' ')
            m.inner_background.color = mlst2color[mlst]
        except (KeyError, ValueError):
            # ValueError: leaf name is not an integer id;
            # KeyError: no MLST recorded for this taxon.
            n = TextFace(' na ')
            n.inner_background.color = "grey"
            m = TextFace(' ')
            m.inner_background.color = "white"

        n.opacity = 1.
        n.margin_top = 2
        n.margin_right = 2
        n.margin_left = 0
        n.margin_bottom = 2
        m.margin_top = 2
        m.margin_right = 0
        m.margin_left = 2
        m.margin_bottom = 2
        lf.add_face(m, 0, position="aligned")
        lf.add_face(n, 1, position="aligned")

        # Leaf label (default leaf names are switched off in the tree style).
        n = TextFace(lf.name, fgcolor="black", fsize=12, fstyle='italic')
        lf.add_face(n, 0)

    # Draw a black dot on nodes with support below 1 and hide the others.
    for n in t1.traverse():
        nstyle = NodeStyle()
        if n.support < 1:
            nstyle["fgcolor"] = "black"
            nstyle["size"] = 6
        else:
            nstyle["fgcolor"] = "red"
            nstyle["size"] = 0
        n.set_style(nstyle)

    return t1, tss
# Positional command-line arguments: the two newick inputs, then the output path.
cmdln = sys.argv
pb_newick, pb_newick_boots_only, output_file_path = cmdln[1], cmdln[2], cmdln[3]

# Tree style used for rendering; leaf names are not drawn automatically.
ts = TreeStyle()
ts.show_leaf_name = False

# Load both trees (newick format 0).
pb_newick_tree = Tree(pb_newick, format=0)
pb_newick_boots_only_tree = Tree(pb_newick_boots_only, format=0)

# Midpoint-root each tree, main tree first.
for _tree in (pb_newick_tree, pb_newick_boots_only_tree):
    _tree.set_outgroup(_tree.get_midpoint_outgroup())

# Attach the combined support values as branch labels (modifies
# pb_newick_tree in place).
add_combined_support_to_nodes_as_faces(pb_newick_tree,
                                       pb_newick_boots_only_tree)

# Apply the general node styling.
customize_node_styles_for_visualization(pb_newick_tree)

#####################################################
# Write tree to pdf.

## Use this for running on personal computer:
class EteTool():
    '''
    Plot ete3 phylogenetic profiles.

    Columns are appended left to right; ``self.column_count`` tracks the next
    free aligned column.

    - self.add_simple_barplot: add a barplot face from a taxon2value dictionary
    - self.add_text_face: add text face
    - self.add_heatmap: add column with cells with value + colored background
    - self.rename_leaves: rename tree leaves from a dictionary (old_name2new_name)
    '''

    def __init__(self, tree_file):
        '''
        :param tree_file: a ``Tree``/``PhyloNode`` instance, or any value
            accepted by the ``Tree`` constructor (path or newick string)
        '''
        self.column_count = 0
        self.default_colors = ['#fc8d59', '#91bfdb', '#99d594',
                               '#c51b7d', '#f1a340', '#999999']
        self.color_index = 0
        self.rotate = False

        # If not a tree instance, consider it as a path or a newick string.
        if isinstance(tree_file, Tree):
            self.tree = tree_file
        elif isinstance(tree_file, ete3.phylo.phylotree.PhyloNode):
            self.tree = tree_file
        else:
            self.tree = Tree(tree_file)

        # Calculate the midpoint node and set it as tree outgroup.
        R = self.tree.get_midpoint_outgroup()
        try:
            self.tree.set_outgroup(R)
        except Exception:
            # Deliberate best effort: a tree that cannot be midpoint-rooted
            # is drawn with its existing root.
            pass
        self.tree.ladderize()

        self.tss = TreeStyle()
        self.tss.draw_guiding_lines = True
        self.tss.guiding_lines_color = "gray"
        self.tss.show_leaf_name = False

    def add_stacked_barplot(self, taxon2value_list, header_name, color_list=False):
        '''Placeholder; not implemented.'''
        pass

    def rename_leaves(self, taxon2new_taxon, keep_original=False, add_face=True):
        '''
        Relabel every leaf using ``taxon2new_taxon``.

        :param taxon2new_taxon: dict mapping current leaf name to new label
        :param keep_original: when true, the label is ``"<new> (<old>)"``
        :param add_face: when true, also add an italic text face with the label
        Leaves absent from the mapping are labelled ``n/a``.
        '''
        for lf in self.tree.iter_leaves():
            if lf.name not in taxon2new_taxon:
                label = 'n/a'
            elif keep_original:
                label = '%s (%s)' % (taxon2new_taxon[lf.name], lf.name)
            else:
                label = taxon2new_taxon[lf.name]

            if add_face:
                n = TextFace(label, fgcolor="black", fsize=12, fstyle='italic')
                lf.add_face(n, 0)
            lf.name = label

    def add_heatmap(self, taxon2value, header_name, continuous_scale=False, show_text=False):
        '''
        Add one aligned column of cells, one per leaf.

        :param taxon2value: dict mapping leaf name to value
        :param header_name: column header text
        :param continuous_scale: colour cell backgrounds with the scale built
            by ``get_continuous_scale`` from all values
        :param show_text: print the value inside the cell
        Leaves absent from ``taxon2value`` get an empty, unstyled face.
        '''
        from metagenlab_libs.colors import get_continuous_scale

        self._add_header(header_name)

        if continuous_scale:
            color_scale = get_continuous_scale(taxon2value.values())

        for lf in self.tree.iter_leaves():
            if lf.name not in taxon2value:
                n = TextFace('')
            else:
                value = taxon2value[lf.name]
                if show_text:
                    n = TextFace('%s' % value)
                else:
                    n = TextFace(' ')
                n.margin_top = 2
                n.margin_right = 3
                n.margin_left = 3
                n.margin_bottom = 2
                n.hz_align = 1
                n.vt_align = 1
                n.border.width = 3
                n.border.color = "#ffffff"
                if continuous_scale:
                    n.background.color = rgb2hex(color_scale[0].to_rgba(float(value)))
                n.opacity = 1.

            if self.rotate:
                n.rotation = 270
            lf.add_face(n, self.column_count, position="aligned")

        self.column_count += 1

    def _add_header(self, header_name, column_add=0):
        '''Add a rotated header face at ``column_count - 1 + column_add``.'''
        n = TextFace(f'{header_name}')
        n.margin_top = 1
        n.margin_right = 1
        n.margin_left = 20
        n.margin_bottom = 1
        n.hz_align = 2
        n.vt_align = 2
        n.rotation = 270
        n.inner_background.color = "white"
        n.opacity = 1.
        # add header
        self.tss.aligned_header.add_face(n, self.column_count - 1 + column_add)

    def _get_default_barplot_color(self,):
        '''Return the next default colour, cycling through the palette.'''
        col = self.default_colors[self.color_index]
        # Wrap on the palette length rather than a hard-coded last index.
        self.color_index = (self.color_index + 1) % len(self.default_colors)
        return col

    def add_simple_barplot(self, taxon2value, header_name, color=False, show_values=False,
                           substract_min=False, highlight_cutoff=False,
                           highlight_reverse=False, max_value=False):
        '''
        Add a horizontal bar per leaf, optionally preceded by a value column.

        :param taxon2value: dict mapping leaf name to a numeric value; it is
            not modified
        :param header_name: column header text
        :param color: bar colour; defaults to the next palette colour
        :param show_values: add a text column with the (unshifted) value
        :param substract_min: draw bars relative to the smallest value
        :param highlight_cutoff: grey out bars below this value
        :param highlight_reverse: grey out bars above the cutoff instead
        :param max_value: value corresponding to a full bar; defaults to the
            largest (possibly shifted) value
        Leaves absent from ``taxon2value`` are drawn with value 0.
        '''
        self._add_header(header_name, column_add=1 if show_values else 0)

        values_lists = [float(i) for i in taxon2value.values()]
        min_value = min(values_lists)
        if substract_min:
            values_lists = [i - min_value for i in values_lists]
            # Shift into a new dict so the caller's mapping is left untouched.
            taxon2value = {taxon: val - min_value
                           for taxon, val in taxon2value.items()}

        if not color:
            color = self._get_default_barplot_color()

        # Set before the loop so it is defined even for a tree without leaves.
        barplot_column = 1 if show_values else 0
        # Value drawn as a full bar; computed once instead of once per leaf.
        full_bar = max_value if max_value else max(values_lists)

        for lf in self.tree.iter_leaves():
            try:
                value = taxon2value[lf.name]
            except KeyError:
                value = 0

            # Undo the shift for display and for the cutoff comparison.
            real_value = value + min_value if substract_min else value

            if show_values:
                if isinstance(real_value, float):
                    a = TextFace(" %s " % str(round(real_value, 2)))
                else:
                    a = TextFace(" %s " % str(real_value))
                a.margin_top = 1
                a.margin_right = 2
                a.margin_left = 5
                a.margin_bottom = 1
                if self.rotate:
                    a.rotation = 270
                lf.add_face(a, self.column_count, position="aligned")

            # A zero denominator (e.g. all values equal with substract_min)
            # yields an empty bar instead of a ZeroDivisionError.
            if full_bar:
                fraction_biggest = (float(value) / full_bar) * 100
            else:
                fraction_biggest = 0
            fraction_rest = 100 - fraction_biggest

            lcolor = color
            if highlight_cutoff:
                if highlight_reverse:
                    if real_value > highlight_cutoff:
                        lcolor = "grey"
                elif real_value < highlight_cutoff:
                    lcolor = "grey"

            b = StackedBarFace([fraction_biggest, fraction_rest],
                               width=100, height=15, colors=[lcolor, 'white'])
            b.rotation = 0
            b.inner_border.color = "grey"
            b.inner_border.width = 0
            b.margin_right = 15
            b.margin_left = 0
            if self.rotate:
                b.rotation = 270
            lf.add_face(b, self.column_count + barplot_column, position="aligned")

        self.column_count += (1 + barplot_column)

    def add_barplot_counts(self,):
        '''Placeholder; not implemented.'''
        # todo
        pass

    def remove_dots(self,):
        '''Hide node markers by giving every node a zero-size style.'''
        nstyle = NodeStyle()
        nstyle["shape"] = "sphere"
        nstyle["size"] = 0
        nstyle["fgcolor"] = "darkred"
        # Applies the same static style to all nodes in the tree. Note that,
        # if "nstyle" is modified, changes will affect to all nodes
        for n in self.tree.traverse():
            n.set_style(nstyle)

    def add_text_face(self, taxon2text, header_name, color_scale=False):
        '''
        Add one aligned text column.

        :param taxon2text: dict mapping leaf name to the text to show
        :param header_name: column header text
        :param color_scale: colour each cell with the categorical scale built
            by ``get_categorical_color_scale`` from all texts
        Leaves absent from ``taxon2text`` show ``-`` and are reported on stdout.
        '''
        from metagenlab_libs.colors import get_categorical_color_scale

        if color_scale:
            value2color = get_categorical_color_scale(taxon2text.values())

        self._add_header(header_name)

        # add column
        for lf in self.tree.iter_leaves():
            if lf.name in taxon2text:
                n = TextFace('%s' % taxon2text[lf.name])
                if color_scale:
                    n.background.color = value2color[taxon2text[lf.name]]
            else:
                print(lf.name, "not in", taxon2text)
                n = TextFace('-')
            n.margin_top = 1
            n.margin_right = 10
            n.margin_left = 10
            n.margin_bottom = 1
            n.opacity = 1.
            if self.rotate:
                n.rotation = 270
            lf.add_face(n, self.column_count, position="aligned")

        self.column_count += 1
class EteToolCompact():
    '''
    Plot ete3 phylogenetic profiles.

    Face and legend sizes are derived from ``self.text_scale`` (number of
    leaves * 0.01).

    - self.add_simple_barplot: add a barplot face from a taxon2value dictionary
    - self.add_heatmap: add column with cells with value + colored background
    - self.rename_leaves: rename tree leaves from a dictionary (old_name2new_name)
    - self.add_categorical_colorscale_legend: add legend
    - self.add_continuous_colorscale_legend: add legend
    '''

    def __init__(self, tree_file):
        '''
        :param tree_file: any value accepted by the ``Tree`` constructor
        '''
        self.column_count = 0
        self.rotate = False
        self.tree = Tree(tree_file)
        self.tree_length = sum(1 for _ in self.tree.iter_leaves())
        self.text_scale = (self.tree_length) * 0.01
        self.default_colors = ['#fc8d59', '#91bfdb', '#99d594',
                               '#c51b7d', '#f1a340', '#999999']
        self.color_index = 0

        # Calculate the midpoint node and set it as tree outgroup.
        R = self.tree.get_midpoint_outgroup()
        self.tree.set_outgroup(R)

        self.tss = TreeStyle()
        self.tss.draw_guiding_lines = True
        self.tss.guiding_lines_color = "gray"
        self.tss.show_leaf_name = False
        self.tss.branch_vertical_margin = 0

    def _get_default_barplot_color(self,):
        '''Return the next default colour, cycling through the palette.'''
        col = self.default_colors[self.color_index]
        # Wrap on the palette length rather than a hard-coded last index.
        self.color_index = (self.color_index + 1) % len(self.default_colors)
        return col

    def _add_header(self, header_name, column_add=0):
        '''Add a rotated header face at ``column_count - 1 + column_add``.'''
        n = TextFace(f'{header_name}')
        n.margin_top = 1
        n.margin_right = 1
        n.margin_left = 20
        n.margin_bottom = 1
        n.hz_align = 2
        n.vt_align = 2
        n.rotation = 270
        n.inner_background.color = "white"
        n.opacity = 1.
        # add header
        self.tss.aligned_header.add_face(n, self.column_count - 1 + column_add)

    def _legend_swatch(self, color):
        '''Build the blank, colour-filled face used as a legend swatch.'''
        swatch = TextFace(" " * int(self.text_scale), fsize=4 * self.text_scale)
        swatch.margin_top = 1
        swatch.margin_right = 1
        swatch.margin_left = 10
        swatch.margin_bottom = 1
        swatch.inner_background.color = color
        return swatch

    def _legend_text(self, text):
        '''Build a legend text face at the legend font size.'''
        return TextFace(text, fsize=4 * self.text_scale)

    def rename_leaves(self, taxon2new_taxon):
        '''
        Add an italic face with the new label to every leaf.

        :param taxon2new_taxon: dict mapping leaf name to label; every leaf
            name must be present (a missing name raises ``KeyError``)
        '''
        for lf in self.tree.iter_leaves():
            n = TextFace(taxon2new_taxon[lf.name], fgcolor="black", fsize=12, fstyle='italic')
            lf.add_face(n, 0)

    def add_continuous_colorscale_legend(self, title, min_val, max_val, scale):
        '''
        Add a legend for a continuous colour scale.

        :param title: legend title
        :param min_val: smallest plotted value
        :param max_val: largest plotted value
        :param scale: sequence whose first item exposes ``to_rgba``
        When both bounds are equal a single swatch is drawn.
        '''
        legend = self.tss.legend
        legend.add_face(self._legend_text(f"{title}"), column=0)

        if min_val != max_val:
            max_swatch = self._legend_swatch(rgb2hex(scale[0].to_rgba(float(max_val))))
            min_swatch = self._legend_swatch(rgb2hex(scale[0].to_rgba(float(min_val))))
            legend.add_face(max_swatch, column=1)
            legend.add_face(self._legend_text(f"{max_val} % (max)"), column=2)
            legend.add_face(min_swatch, column=1)
            legend.add_face(self._legend_text(f"{min_val} % (min)"), column=2)
        else:
            only_swatch = self._legend_swatch(rgb2hex(scale[0].to_rgba(float(min_val))))
            legend.add_face(only_swatch, column=0)
            legend.add_face(self._legend_text(f"{max_val} % Id"), column=1)

    def add_categorical_colorscale_legend(self, title, scale):
        '''
        Add a legend for a categorical colour scale.

        :param title: legend title
        :param scale: dict mapping category to colour
        Entries are laid out as swatch/label column pairs; a new legend row is
        started once the column index passes 16.
        '''
        legend = self.tss.legend
        legend.add_face(self._legend_text(f"{title}"), column=0)
        col = 1
        for value in scale:
            legend.add_face(self._legend_swatch(scale[value]), column=col)
            legend.add_face(self._legend_text(f"{value}"), column=col + 1)
            col += 2
            if col > 16:
                # Blank cell under the title keeps the next row aligned.
                legend.add_face(self._legend_text(f" "), column=0)
                col = 1

    def add_simple_barplot(self, taxon2value, header_name, color=False, show_values=False,
                           substract_min=False, max_value=False):
        '''
        Add a horizontal bar per leaf, optionally preceded by a value column.

        :param taxon2value: dict mapping leaf name to a numeric value; it is
            not modified
        :param header_name: column header text
        :param color: bar colour; defaults to the next palette colour
        :param show_values: add a text column with the plotted value (the
            shifted value when ``substract_min`` is set)
        :param substract_min: draw bars relative to the smallest value
        :param max_value: value corresponding to a full bar; defaults to the
            largest (possibly shifted) value
        Leaves absent from ``taxon2value`` are drawn with value 0.
        '''
        self._add_header(header_name, column_add=1 if show_values else 0)

        values_lists = [float(i) for i in taxon2value.values()]
        min_value = min(values_lists)
        if substract_min:
            values_lists = [i - min_value for i in values_lists]
            # Shift into a new dict so the caller's mapping is left untouched.
            taxon2value = {taxon: val - min_value
                           for taxon, val in taxon2value.items()}

        if not color:
            color = self._get_default_barplot_color()

        # Set before the loop so it is defined even for a tree without leaves.
        barplot_column = 1 if show_values else 0
        # Value drawn as a full bar; computed once instead of once per leaf.
        full_bar = max_value if max_value else max(values_lists)

        for lf in self.tree.iter_leaves():
            try:
                value = taxon2value[lf.name]
            except KeyError:
                value = 0

            if show_values:
                if isinstance(value, float):
                    a = TextFace(" %s " % str(round(value, 2)))
                else:
                    a = TextFace(" %s " % str(value))
                a.margin_top = 1
                a.margin_right = 2
                a.margin_left = 5
                a.margin_bottom = 1
                if self.rotate:
                    a.rotation = 270
                lf.add_face(a, self.column_count, position="aligned")

            # A zero denominator (e.g. all values equal with substract_min)
            # yields an empty bar instead of a ZeroDivisionError.
            if full_bar:
                fraction_biggest = (float(value) / full_bar) * 100
            else:
                fraction_biggest = 0
            fraction_rest = 100 - fraction_biggest

            b = StackedBarFace([fraction_biggest, fraction_rest],
                               width=100 * (self.text_scale / 3), height=18,
                               colors=[color, 'white'])
            b.rotation = 0
            b.margin_right = 10
            b.margin_left = 10
            b.hz_align = 2
            b.vt_align = 2
            b.rotable = False
            if self.rotate:
                b.rotation = 270
            lf.add_face(b, self.column_count + barplot_column, position="aligned")

        self.column_count += (1 + barplot_column)

    def add_heatmap(self, taxon2value, header_name, scale_type="continuous", palette=False):
        '''
        Add one aligned column of coloured cells plus the matching legend.

        :param taxon2value: dict mapping leaf name to value
        :param header_name: currently unused (no header face is added)
        :param scale_type: ``"continuous"`` or ``"categorical"``
        :param palette: currently unused
        :raises IOError: for any other ``scale_type``
        Leaves absent from ``taxon2value`` get an uncoloured cell.
        '''
        from metagenlab_libs.colors import get_categorical_color_scale
        from metagenlab_libs.colors import get_continuous_scale

        if scale_type == "continuous":
            scale = get_continuous_scale(taxon2value.values())
            self.add_continuous_colorscale_legend("Closest hit identity",
                                                  min(taxon2value.values()),
                                                  max(taxon2value.values()),
                                                  scale)
        elif scale_type == "categorical":
            scale = get_categorical_color_scale(taxon2value.values())
            self.add_categorical_colorscale_legend("MLST", scale)
        else:
            raise IOError("unknown type")

        for lf in self.tree.iter_leaves():
            n = TextFace(" " * int(self.text_scale))
            if lf.name in taxon2value:
                value = taxon2value[lf.name]
                if scale_type == "categorical":
                    n.inner_background.color = scale[value]
                if scale_type == "continuous":
                    n.inner_background.color = rgb2hex(scale[0].to_rgba(float(value)))
            n.margin_top = 0
            n.margin_right = 0
            n.margin_left = 10
            n.margin_bottom = 0
            n.opacity = 1.
            if self.rotate:
                n.rotation = 270
            lf.add_face(n, self.column_count, position="aligned")

        self.column_count += 1

    def remove_labels(self,):
        '''Add an empty text face in column 0 of every leaf.'''
        for lf in self.tree.iter_leaves():
            n = TextFace("")
            lf.add_face(n, 0)
parser.add_argument(
    '--verbose', action='store_true',
    help=('Print information about the outgroup (if any) taxa to standard '
          'error'))

args = parser.parse_args()

tree = Tree(args.treeFile.read())

if args.outgroupRegex:
    # Imported as a module so the builtin compile() is not shadowed.
    import re

    regex = re.compile(args.outgroupRegex)
    # Leaves whose name matches the regex from its start form the outgroup.
    taxa = [leaf.name for leaf in tree.iter_leaves() if regex.match(leaf.name)]

    if taxa:
        ca = tree.get_common_ancestor(taxa)

        if args.verbose:
            print('Taxa for outgroup:', taxa, file=sys.stderr)
            print('Common ancestor:', ca.name, file=sys.stderr)
            print('Common ancestor is tree:', tree == ca, file=sys.stderr)

        if len(taxa) == 1:
            # A single taxon is used directly as the outgroup.
            tree.set_outgroup(tree & taxa[0])
        elif ca == tree:
            # The matched taxa span the whole tree, so their ancestor cannot
            # serve as an outgroup; fall back to midpoint rooting.
            tree.set_outgroup(tree.get_midpoint_outgroup())
        else:
            # Reuse the ancestor computed above (the tree is unchanged since).
            tree.set_outgroup(ca)

print(tree.get_ascii())
def main():
    '''
    Run the QC-metadata / andi / roary workflow selected on the command line.

    Read in the MDU-IDs from file.  For each ID, instantiate an object of
    class Isolate.  This class associates QC data with the ID tag.
    Move the contigs for all isolates into a tempdir, with a temp 9-character
    filename.  Run andi phylogenomics on all the contig sets.  Infer an NJ tree
    using Bio Phylo from the andi-calculated distance matrix.  Correct the
    negative branch lengths in the NJ tree using ETE3.  Export the tree to
    file.  Gather and combine the metadata for each ID as a super-matrix.
    Optionally, add LIMS metadata to the super-matrix from a LIMS excel
    spreadsheet option (adds MALDI-ToF, Submitting Lab ID, Submitting Lab
    species guess) and/or use the flag-if-new to highlight 'new' isolates.
    Export the tree and metadata to .csv, .tsv/.tab file.  Export the
    'isolates not found' to text file too.

    Reads the module-level ``ARGS``/``PARSER`` and sets the module-level
    ``YIELD_FILE``, ``MLST_FILE`` and ``FORCE_MLST_SCHEME``.  Exits via
    ``sys.exit`` on usage errors.
    '''
    global YIELD_FILE
    global MLST_FILE
    global FORCE_MLST_SCHEME

    #Set up the file names for Nullarbor folder structure
    YIELD_FILE = 'yield.tab'
    MLST_FILE = 'mlst.tab'
    #Add MLST schemes to force their usage if that species is encountered
    #Only force schemes if there are two (e.g., A baumannii and E coli)
    FORCE_MLST_SCHEME = {"Acinetobacter baumannii": "abaumannii_2",
                         "Campylobacter jejuni": "campylobacter",
                         #"Citrobacter freundii": "cfreundii",
                         #"Cronobacter": "cronobacter",
                         "Enterobacter cloacae": "ecloacae",
                         "Escherichia coli": "ecoli",
                         #"Klebsiella oxytoca": "koxytoca",
                         #"Klebsiella pneumoniae": "kpneumoniae",
                         #"Pseudomonas aeruginosa": "paeruginosa"
                         "Shigella sonnei": "ecoli",
                         "Salmonella enterica": "senterica",
                         "Vibrio cholerae": "vcholerae"
                        }

    if not ARGS.subparser_name:
        PARSER.print_help()
        sys.exit()
    if ARGS.subparser_name == 'version':
        from .utils.version import Version
        Version()
        sys.exit()

    # Any other sub-command is treated as "run".
    if ARGS.Nullarbor_folders:
        print('Nullarbor folder structure selected.')
        YIELD_FILE = 'yield.clean.tab'
        MLST_FILE = 'mlst2.tab'

    EXCEL_OUT = (f"{os.path.splitext(os.path.basename(ARGS.LIMS_request_sheet))[0]}" \
                 f"_results.xlsx")

    if ARGS.threads > cpu_count():
        sys.exit(f'Number of requested threads must be less than {cpu_count()}.')
    print(str(ARGS.threads) +' CPU processors requested.')

    #Check if final slash in manually specified wgs_qc path
    if ARGS.wgs_qc[-1] != '/':
        print('\n-wgs_qc path is entered as '+ARGS.wgs_qc)
        print('You are missing a final \'/\' on this path.')
        print('Exiting now.\n')
        sys.exit()

    #i) read in the IDs from file
    xls_table = get_isolate_request_IDs(ARGS.LIMS_request_sheet)
    IDs = list(set(xls_table.index.values))

    #base should be a global, given that it is used in other functions too.
    base = os.path.splitext(ARGS.LIMS_request_sheet)[0]

    #ii) Return a folder path to the QC data for each available ID
    #    using a wildcard search of the ID in IDs in ARGS.wgs_qc path.
    iso_paths = isolates_available(IDs)
    #Drop the path and keep the folder name
    isos = [i.split('/')[-1] for i in iso_paths]

    #iii) make tempdir to store the temp_contigs there for 'andi' analysis.
    assembly_tempdir = make_tempdir()

    #vi) Copy contigs to become temp_contigs into tempdir, only if andi
    #requested.
    #Translation dict to store {original filename: random 9-character filename}
    iso_ID_trans = {}
    #Dict to store each isolate under each consensus species
    from collections import defaultdict
    isos_grouped_by_cons_spp = defaultdict(list)

    for iso in isos:
        #Instantiate an Isolate class for each isolate in isolates
        sample = Isolate(iso)
        #Next, we could just use iso_path+/contigs.fa, but that would skip
        #the if os.path.exists() test in sample.assembly(iso).
        assembly_path = sample.assembly()
        short_id = shortened_ID()
        #Store key,value as original_name,short_id for later retrieval.
        iso_ID_trans[iso] = short_id
        if ARGS.andi_run:
            cmd = 'ln -s '+assembly_path+' '+assembly_tempdir+'/'+short_id+\
                  '_contigs.fa'
            os.system(cmd)
            print('Creating symlink:', cmd)

    if iso_ID_trans:
        with open(base+'_temp_names.txt', 'w') as tmp_names:
            print('\nTranslated isolate IDs:\nShort\tOriginal')
            for key, value in iso_ID_trans.items():
                print(value+'\t'+key)
                tmp_names.write(value+'\t'+key+'\n')

    if ARGS.metadata_run:
        #summary_frames will store all of the metaDataFrames herein
        summary_frames = []
        n_isos = len(isos)
        if n_isos == 0:
            print('\nNo isolates detected in the path '+ARGS.wgs_qc+'.')
            print('Exiting now.\n')
            sys.exit()

        #Kraken set at 2 threads, so 36 processes can run on 72 CPUs
        #Create a pool of size based on number of isolates (n_isos); at least
        #one worker so a single requested thread does not give Pool(0).
        with Pool(max(1, min(n_isos, ARGS.threads//2))) as p:
            print(f'\nRunning kraken on the assemblies ({ARGS.assembly_name} files):')
            results_k_cntgs = p.map(kraken_contigs_multiprocessing, isos)
        print(results_k_cntgs)
        #concat the dataframe objects
        res_k_cntgs = pd.concat(results_k_cntgs, axis=0, sort=False)
        print('\nKraken_contigs results gathered from kraken on contigs...')

        #The remaining retrievals run one process per job and share a pool.
        with Pool(max(1, min(n_isos, ARGS.threads))) as p:
            #Multiprocessor retrieval of kraken results on reads.
            results_k_reads = p.map(kraken_reads_multiprocessing, isos)
            #concat the dataframe objects
            res_k_reads = pd.concat(results_k_reads, axis=0)
            print('Kraken_reads results gathered from kraken.tab files...')

            #Multiprocessor retrieval of contig metrics.
            results_metrics_contigs = p.map(metricsContigs_multiprocessing, isos)
            res_m_cntgs = pd.concat(results_metrics_contigs, axis=0)
            print('Contig metrics gathered using \'fa -t\'...')

            #Multiprocessor retrieval of read metrics.
            results_metrics_reads = p.map(metricsReads_multiprocessing, isos)
            res_m_reads = pd.concat(results_metrics_reads, axis=0)
            print('Read metrics gathered from '+YIELD_FILE+' files...')

            #Multiprocessor retrieval of abricate results.
            results_abricate = p.map(abricate_multiprocessing, isos)
            res_all_abricate = pd.concat(results_abricate, axis=0, sort=False)
            res_all_abricate.fillna('', inplace=True)
            print('Resistome hits gathered from abricate.tab files...')

        #append the dfs to the summary list of dfs
        summary_frames.extend([res_k_cntgs, res_k_reads, res_m_cntgs,
                               res_m_reads, res_all_abricate])

        #These next steps build up the metadata not yet obtained
        #(via multiprocesses above).
        #Let's store the metadata for each isolate in summary_isos
        summary_isos = []
        #Let's populate summary_isos above, isolate by isolate (in series)
        for iso in isos:
            sample = Isolate(iso)
            species_cntgs = res_k_cntgs.loc[iso, 'sp_krkn1_cntgs']
            species_reads = res_k_reads.loc[iso, 'sp_krkn1_reads']
            #Only trust the species call when contigs and reads agree.
            if species_cntgs == species_reads:
                species = species_cntgs
            else:
                species = 'indet'
            mlst_df = sample.mlst(species, sample.assembly())
            species_consensus = {'sp_krkn_ReadAndContigConsensus':species}
            species_cons_df = pd.DataFrame([species_consensus], index=[iso])
            summary_isos.append(pd.concat([mlst_df, species_cons_df], axis=1))

        #Glue the isolate by isolate metadata into a single df
        summary_isos_df = pd.concat(summary_isos)
        #Glue the dataframes built during multiprocessing processes
        summary_frames_df = pd.concat(summary_frames, axis=1)
        #Finish up with everything in one table!
        metadata_overall = pd.concat([xls_table, summary_isos_df,
                                      summary_frames_df], axis=1, sort=False)
        metadata_overall.fillna('', inplace=True)
        metadata_overall.index.name = 'ISOLATE'
        print('\nMetadata super-matrix:')
        #Write this supermatrix (metadata_overall) to stdout and to Excel.
        metadata_overall.to_csv(sys.stdout)
        #The context manager saves and closes the workbook; ExcelWriter.save()
        #no longer exists in current pandas.
        with pd.ExcelWriter(EXCEL_OUT) as writer:
            metadata_overall.to_excel(writer, sheet_name='Sheet 1',
                                      freeze_panes=(1, 1))
        print(f"\nResults written to {os.path.abspath(EXCEL_OUT)}")
        for k, v in zip(metadata_overall['sp_krkn_ReadAndContigConsensus'],
                        metadata_overall.index):
            isos_grouped_by_cons_spp[k.replace(' ', '_')].append(v)

    #Run andi?
    if ARGS.andi_run:
        #Run andi
        andi_mat = 'andi_'+ARGS.model_andi_distance+'dist_'+base+'.mat'
        andi_c = 'nice andi -j -m '+ARGS.model_andi_distance+' -t '+\
                 str(ARGS.threads)+' '+assembly_tempdir+'/*_contigs.fa > '+\
                 andi_mat
        print('\nRunning andi with: \''+andi_c+'\'')
        #NOTE: relies on the shell for globbing and output redirection.
        os.system(andi_c)

        #Read in the andi dist matrix, convert to lower triangle
        dm = read_file_lines(andi_mat)[1:]
        dm = lower_tri(dm)

        #Correct the names in the matrix: replace each short id with the
        #original isolate name.
        for iso in isos:
            for i, name in enumerate(dm.names):
                #iso_ID_trans[iso] is the short_id
                if name == iso_ID_trans[iso]:
                    dm.names[i] = iso

        #From the distance matrix in dm, infer the NJ tree
        from Bio.Phylo.TreeConstruction import DistanceTreeConstructor
        constructor = DistanceTreeConstructor()
        njtree = constructor.nj(dm)
        njtree.rooted = True
        from Bio import Phylo
        Phylo.write(njtree, 'temp.tre', 'newick')
        from ete3 import Tree
        t = Tree('temp.tre', format=1)
        #Get rid of negative branch lengths (an artefact, not an error, of NJ)
        for node in t.traverse():
            node.dist = abs(node.dist)
        t.set_outgroup(t.get_midpoint_outgroup())
        t_out = base+'_andi_NJ_'+ARGS.model_andi_distance+'dist.nwk.tre'
        t.write(format=1, outfile=t_out)
        print('Final tree (midpoint-rooted, NJ under '+\
              ARGS.model_andi_distance+' distance) looks like this:')
        #Print the ascii tree
        print(t)
        #Remove the temp.tre
        os.remove('temp.tre')
        print('Tree (NJ under '+ARGS.model_andi_distance+\
              ' distance, midpoint-rooted) written to '+t_out+'.')

    #Run roary?
    if ARGS.roary_run:
        #Files/folders kept in each roary output folder; the rest is deleted.
        roary_keepers = [
            "accessory.header.embl",
            "accessory.tab",
            "accessory_binary_genes.fa",
            "accessory_binary_genes.fa.newick",
            "accessory_binary_genes_midpoint.nwk.tre",
            "accessory_graph.dot",
            "blast_identity_frequency.Rtab",
            "clustered_proteins",
            "core_accessory.header.embl",
            "core_accessory.tab",
            "core_accessory_graph.dot",
            "core_gene_alignment.aln",
            "gene_presence_absence.Ltab.csv",
            "gene_presence_absence.Rtab",
            "gene_presence_absence.csv",
            "number_of_conserved_genes.Rtab",
            "number_of_genes_in_pan_genome.Rtab",
            "number_of_new_genes.Rtab",
            "number_of_unique_genes.Rtab",
            "pan_genome_reference.fa",
            "pan_genome_sequences",
            "summary_statistics.txt"
            ]

        #Annotate only the isolates without an existing prokka folder.
        params = [(i, 'prokka') for i in isos if not
                  os.path.exists('prokka/'+i)]
        if params:
            print('\nRunning prokka:')
            with Pool(max(1, min(len(params), ARGS.threads//2))) as p:
                p.map(prokka, params)
        else:
            print('\nProkka files already exist. Let\'s move on to '+\
                  'the roary analysis...')

        #Run Roary on the species_consensus subsets.
        print('Now, let\'s run roary!')
        for k, v in list(isos_grouped_by_cons_spp.items()):
            print(k, v)
            n_isos = len(v)
            if n_isos <= 1:
                print('Only one isolate in '+k+'. Need at least 2 isolates '+\
                      'to run roary.  Moving on...')
                continue

            roary_dir = base+'_'+k+'_roary'
            shutil.rmtree(roary_dir, ignore_errors=True)
            roary(base, k, ' '.join(['prokka/'+iso+'/*.gff' for iso in v]))
            roary_genes = pd.read_table(base+'_'+k+
                                        '_roary/gene_presence_absence.' +\
                                        'Rtab',
                                        index_col=0, header=0)
            roary_genes = roary_genes.transpose()
            roary_genes.to_csv(base+'_'+k+
                               '_roary/gene_presence_absence.Ltab.csv',
                               mode='w', index=True, index_label='name')

            if n_isos > 2:
                from ete3 import Tree
                t = Tree(base+'_'+k+
                         '_roary/accessory_binary_genes.fa.newick', format=1)
                #Get rid of negative branch lengths (an artefact,
                #not an error, of NJ)
                for node in t.traverse():
                    node.dist = abs(node.dist)
                t.set_outgroup(t.get_midpoint_outgroup())
                t_out = base+'_'+k+\
                        '_roary/accessory_binary_genes_midpoint.nwk.tre'
                t.write(format=1, outfile=t_out)
                print('\nWritten midpoint-rooted roary tree.\n')

            #Prune the roary folder down to the keepers.
            wd = os.getcwd()
            os.chdir(roary_dir)
            for f_name in glob.glob('*'):
                if f_name in roary_keepers:
                    continue
                #Pick the removal call by type: os.remove() after a successful
                #rmtree() would raise because the path is already gone.
                if os.path.isdir(f_name) and not os.path.islink(f_name):
                    shutil.rmtree(f_name, ignore_errors=True)
                else:
                    os.remove(f_name)
            os.chdir(wd)

            if n_isos <= 2:
                print('Need more than two isolates to have a meaningful '+\
                      'pangenome tree.  No mid-point rooting of the ' +\
                      'pangenome tree performed.')

            wd = os.getcwd()
            os.chdir(roary_dir)
            os.system('python ../collapseSites.py -f core_gene_alignment.aln -i fasta -t '+str(ARGS.threads))
            if os.path.exists('core_gene_alignment_collapsed.fasta'):
                os.system('FastTree -nt -gtr < core_gene_alignment_collapsed.fasta > core_gene_FastTree_SNVs.tre')
                #calc pairwise snp dist and write to file
                with open('core_gene_alignment_collapsed.fasta', 'r') as inf:
                    from Bio import AlignIO
                    aln = AlignIO.read(inf, 'fasta')
                #One job per alignment row: that row against rows 0..i.
                pairs = [[(aln, i, j) for j in range(0, i+1)]
                         for i in range(0, len(aln))]
                with Pool(max(1, min(len(pairs), ARGS.threads))) as p:
                    print('Running pw comparisons in parallel...')
                    result = p.map(pw_calc, pairs)
                summary = pd.concat(result, axis=0, sort=False)
                summary.fillna('', inplace=True)
                with open('core_gene_alignment_SNV_distances.tab', 'w') as distmat:
                    summary.to_csv(distmat, mode='w', sep='\t', index=True, index_label='name')

            #convert roary output to fripan compatible
            os.system('python ../roary2fripan.py '+base+'_'+k)
            roary2fripan_strains_file = pd.read_table(base+'_'+k+
                                                      '.strains',
                                                      index_col=0, header=0)
            #Append this group's metadata rows to the .strains table.
            strains_info_out = pd.concat([roary2fripan_strains_file,
                                          metadata_overall.loc[v, :]],
                                         axis=1, sort=False)
            strains_info_out.to_csv(base+'_'+k+'.strains', mode='w',
                                    sep='\t', index=True, index_label='ID')
            print('Updated '+base+'_'+k+'.strains with all metadata.')
            os.system('cp '+base+'_'+k+'* ~/public_html/fripan')
            os.chdir(wd)

    #Keep the tempdirs created during the run
    if not ARGS.keep_tempdirs:
        shutil.rmtree(assembly_tempdir, ignore_errors=True)
        print('\nDeleted tempdir '+assembly_tempdir+'.')
    else:
        print('\nTempdir '+assembly_tempdir+' not deleted.')
    print('\nRun finished.')