示例#1
0
def drawTree(MS_distDict, Methyl_distDict, filtered_samples, ratio, outgroup):
    '''
    Merge MS and Methyl distance matrices and build a neighbor-joining tree.

    For every ordered pair of samples the merged distance is
    ``MS * ratio + (Methyl * (1 - ratio)) / 100``. Only the methylation term is
    divided by 100 because PD is calculated on a 0-100 scale while the MS
    distance is on a 0-1 scale.

    Args:
        MS_distDict: nested mapping, ``MS_distDict[s1][s2]`` -> MS distance.
        Methyl_distDict: nested mapping, ``Methyl_distDict[s1][s2]`` -> PD distance.
        filtered_samples: sample names to include (used in sorted order).
        ratio: weight of the MS distance; the methylation term gets ``1 - ratio``.
        outgroup: ``"NA"`` to return the tree as built, ``"Midpoint"`` for
            midpoint rooting, or a leaf name to use as outgroup.

    Returns:
        An ete ``Tree`` built from the merged distance matrix.
    '''
    samples = sorted(filtered_samples)
    merged_distMatrix = []
    for sample1 in samples:
        sample1_dist = []
        for sample2 in samples:
            # Only the methylation term is rescaled (0-100 -> 0-1).
            merged_dist = (MS_distDict[sample1][sample2] * ratio) + (
                Methyl_distDict[sample1][sample2] * (1 - ratio)
            ) / 100
            sample1_dist.append(merged_dist)
        merged_distMatrix.append(sample1_dist)

    # Run neighbor-joining on the merged pairwise cell distances.
    distObj = DistanceMatrix(merged_distMatrix, samples)
    print(distObj.data)
    skbio_tree = nj(distObj, result_constructor=str)
    # skbio builds the tree from the distance matrix; convert its Newick to ete.
    ete_tree = Tree(skbio_tree)
    # Compare by value: `is` only matched by accident of string interning and
    # fails for strings built at runtime (e.g. from argparse).
    if outgroup == "NA":
        return ete_tree
    if outgroup == "Midpoint":
        # NOTE(review): get_midpoint_outgroup() may return None; unlike the
        # other drawTree variant this is not handled here.
        tree_midpoint = ete_tree.get_midpoint_outgroup()
        ete_tree.set_outgroup(tree_midpoint)
    else:
        ete_tree.set_outgroup(outgroup)
    return ete_tree
示例#2
0
def generate_subfams_treestrat(fin_seqs, dio_tmp, threads):
    """Split a gene family into two subfamilies using a tree-based strategy.

    The sequences are aligned with ``run_mafft``, a tree is inferred with
    ``run_iqtree`` (``-m LG``), and the tree is split at its midpoint outgroup:
    the leaves under the outgroup node form subfamily one, the leaves left in
    the tree after detaching it form subfamily two.

    Args:
        fin_seqs (str): Path to fasta file with sequences of genes.
        dio_tmp (str): Path to folder to store temporary files in.
        threads (int): Number of threads to use.

    Returns:
        A list of two elements: the gene ids for subfamily one and the gene ids
        for subfamily two.
    """
    alignment_path = f"{dio_tmp}/seqs.aln"
    tree_dir = f"{dio_tmp}/tree"
    run_mafft(fin_seqs, alignment_path, threads)
    run_iqtree(alignment_path, tree_dir, threads, ["-m", "LG"])
    gene_tree = Tree(f"{tree_dir}/tree.treefile")

    # NOTE(review): get_midpoint_outgroup() may return None; that case is not
    # handled and would fail on the next line.
    split_node = gene_tree.get_midpoint_outgroup()
    first_subfam = split_node.get_leaf_names()
    split_node.detach()
    second_subfam = gene_tree.get_leaf_names()
    return [first_subfam, second_subfam]
示例#3
0
def suspicious_clades(tree):
    """
    Find suspicious clades: support >= 70 and more than one taxonomic group.

    The tree is midpoint-rooted first. Only non-root internal nodes with
    support >= 70 that hold fewer leaves than the rest of the tree are
    considered. Each leaf name is reduced to an organism code (the part before
    '..' if present, otherwise the part before the first '_') and looked up in
    the module-level ``metadata`` dict under 'Higher Taxonomy'.

    input: phylogenetic tree (anything accepted by ``Tree``)
    output: tuple of the ``tree`` argument as given and the list of suspicious
        clades, each clade being a list of leaf names
    """
    t = Tree(tree)
    # midpoint-root the tree
    t.set_outgroup(t.get_midpoint_outgroup())
    n_leaves = len(t)

    supported_clades = []
    for node in t.traverse('preorder'):
        if node.is_root() or node.is_leaf():
            continue
        # report only clades which encompass less than half of all organisms
        if not (node.support >= 70 and len(node) < n_leaves - len(node)):
            continue
        clade = node.get_leaf_names()
        if len(clade) > 1:
            supported_clades.append(clade)

    suspicious = []
    for clade in supported_clades:
        groups = {
            metadata[leaf.split('..')[0] if '..' in leaf else leaf.split('_')[0]]['Higher Taxonomy']
            for leaf in clade
        }
        if len(groups) > 1:
            suspicious.append(clade)
    return tree, suspicious
示例#4
0
def collect_contaminants(tree_file, cont_dict):
    """
    Collect sequences whose position on the (midpoint-rooted) tree corresponds
    to the expected place for a contamination.

    Only leaves whose name contains exactly four '_' are examined. The organism
    code is the first '_' field and the quality tag is the last three fields.

    input: tree file, contamination dict from the parse_contaminants function
    result: set of proven contaminants (leaf names), set of the same proven
        contaminants under their csv result-table names
    """
    t = Tree(tree_file)
    t.set_outgroup(t.get_midpoint_outgroup())
    cont_table_names = set()
    contaminants = set()
    for node in t.traverse('preorder'):
        if not node.is_leaf() or node.name.count('_') != 4:
            continue
        name = node.name
        fields = name.split('_')
        org = fields[0]
        quality = '_'.join(fields[-3:])
        # The metadata lookup happens for every candidate leaf, whether or not
        # the organism appears in cont_dict.
        table_name = f'{metadata[org]["full"]}_{quality}@{org}'
        if org in cont_dict and expected_neighborhood(node.up, cont_dict[org]) is True:
            contaminants.add(name)
            cont_table_names.add(table_name)
    return contaminants, cont_table_names
示例#5
0
    def RapidNJ(names,
                profiles,
                embeded,
                handle_missing='pair_delete',
                **params):
        """Build a neighbor-joining tree from ``profiles`` with the RapidNJ binary.

        The symmetric distance matrix is written to
        ``params['tempfix'] + 'dist.list'``. RapidNJ is run on it and the
        resulting Newick tree is loaded. All files matching ``dist.list*`` are
        then removed. RapidNJ leaves are row indices; they are renamed with
        ``names``.

        :param names: leaf names, indexed by matrix row.
        :param profiles: profiles passed to ``distance_matrix.get_distance``.
        :param embeded: not used by this function.
        :param handle_missing: missing-data mode for ``get_distance``.
        :param params: must contain ``tempfix`` (temp-file prefix) and
            ``RapidNJ_<platform.system()>`` (path to the executable).
        :return: ete ``Tree``; midpoint-rooted then unrooted when that succeeds.
        """
        dist = distance_matrix.get_distance('symmetric', profiles,
                                            handle_missing)

        dist_file = params['tempfix'] + 'dist.list'
        with open(dist_file, 'w') as fout:
            fout.write('    {0}\n'.format(dist.shape[0]))
            # Generator expression: the row variable does not leak into this
            # scope, so `del dist` really releases the matrix. (The previous
            # `del dist, d` also raised NameError when `dist` had no rows.)
            fout.writelines(
                '{0!s:10} {1}\n'.format(
                    n, ' '.join('{:.6f}'.format(dd) for dd in row))
                for n, row in enumerate(dist))
        del dist
        Popen([
            params['RapidNJ_{0}'.format(platform.system())], '-n', '-x',
            dist_file + '_rapidnj.nwk', '-i', 'pd', dist_file
        ],
              stdout=PIPE,
              stderr=PIPE).communicate()
        tree = Tree(dist_file + '_rapidnj.nwk')
        for fname in glob(dist_file + '*'):
            os.unlink(fname)

        # Best-effort rooting; `except Exception` (not bare) so that
        # KeyboardInterrupt / SystemExit still propagate.
        try:
            tree.set_outgroup(tree.get_midpoint_outgroup())
            tree.unroot()
        except Exception:
            pass

        for leaf in tree.get_leaves():
            leaf.name = names[int(leaf.name.strip("'"))]
        return tree
示例#6
0
def generateRootedTree(sequenceFileName):
    """Infer a RAxML tree for ``sequenceFileName`` and save a midpoint-rooted copy.

    Steps:
      1. ``remove(...)`` is called on the input and it is read with
         ``fastaToDictionary``.
      2. The sequences plus an all-'A' ``OUTGROUP`` sequence are written to
         ``data/OUTGROUP_<name>``. The outgroup has the length of the last
         sequence read.
      3. RAxML is run on ``data/<name>``; its stdout is appended to
         ``log_file``.
      4. The best tree is midpoint-rooted and pruned to the input entries. It
         is written (format=1) to ``trees_unrooted/RAxML_bestTree.R_<name>``.
    """
    remove(sequenceFileName, True, True, True, False)
    fasta = fastaToDictionary(sequenceFileName)
    entries = list(fasta)

    length = 0
    with open("data/OUTGROUP_" + sequenceFileName, "w") as ofile:
        for entry in fasta:
            ofile.write(">" + entry + "\n" + fasta[entry] + "\n")
            length = len(fasta[entry])
        OUTGROUP = "A" * length
        ofile.write(">" + 'OUTGROUP' + "\n" + OUTGROUP + "\n")

    # Build the unrooted tree.
    # NOTE(review): the intent was a tree containing the fictional outgroup,
    # but RAxML is given data/<name>, not the data/OUTGROUP_<name> file
    # written above - confirm which input is intended.
    outputFileName = sequenceFileName
    substitutionModel = "PROTGAMMABLOSUM62"
    # NOTE(review): machine-specific absolute path, while the result is read
    # back from the relative 'trees_unrooted/' directory below.
    outputDirectory = "/Users/williamlin/Desktop/IW/IW/phyloSim-master/trees_unrooted/"
    threads = "10"
    randomSeed = np.random.randint(0, 2000)
    command = ("./raxml -s data/" + sequenceFileName + " -n " + outputFileName
               + " -m " + substitutionModel + " -w " + outputDirectory
               + " -T " + threads + " -p " + str(randomSeed) + " 1>>log_file")
    os.system(command)

    # Root at the midpoint, then keep only the original fasta entries.
    current = Tree('trees_unrooted/RAxML_bestTree.' + sequenceFileName)
    current.set_outgroup(current.get_midpoint_outgroup())
    current.prune(entries)

    with open('trees_unrooted/RAxML_bestTree.R_' + sequenceFileName, "w") as ofile:
        ofile.write(current.write(format=1))
示例#7
0
def get_root(prefix, tree_file):
    """Midpoint-root ``tree_file`` and write it to ``<prefix>.rooted.nwk``.

    Rooting is best-effort: if it raises, the tree is written with its
    original root. The tree is read and written with ete format 1.

    :return: path of the written Newick file.
    """
    out_file = '{0}.rooted.nwk'.format(prefix)
    tree = Tree(tree_file, format=1)
    try:
        tree.set_outgroup(tree.get_midpoint_outgroup())
    except Exception:
        # Deliberately best-effort, but no longer swallows
        # KeyboardInterrupt / SystemExit like the bare `except:` did.
        pass
    tree.write(outfile=out_file, format=1)
    return out_file
示例#8
0
def tree_to_tsvg(tree_file, contaminants=None, backpropagation=None):
    """Render a midpoint-rooted tree to ``<output_folder>/<name>_tree.svg``.

    Internal nodes are styled by ``format_nodes``, which also counts suspicious
    clades. Leaves are styled by ``format_leaves``. The per-organism table
    ``<output_folder>/<prefix>.tsv`` is opened for writing on a normal run and
    for reading when ``backpropagation`` is set. In both cases it is closed
    when the function finishes.

    :param tree_file: path of the tree file; its name is taken from the second
        '.'-separated field of the basename.
    :param contaminants: set of known contaminants (defaults to empty).
    :param backpropagation: truthy to read the table instead of writing it.
    """
    if contaminants is None:
        contaminants = set()
    tree_base = str(os.path.basename(tree_file))

    # NOTE(review): assumes RAxML-style names; other tools may name files differently.
    name_ = tree_base.split('.')[1]

    if os.path.isfile(f'{args.input}/{name_}.trimmed'):
        build_len, len_dict, trimmed_len = get_build_len(name_)
        len_info = f'Final Align Len: {build_len}, Trimmed Align Len: {trimmed_len}'
        len_dict = {k: round(v / trimmed_len, 2) for k, v in len_dict.items()}
    else:
        build_len, len_dict = get_build_len(name_)
        len_info = f'Final Align Len: {build_len}'
        len_dict = {k: round(v / build_len, 2) for k, v in len_dict.items()}

    # The table used to be leaked (never closed) in backpropagation mode.
    table_mode = 'r' if backpropagation else 'w'
    with open(f"{output_folder}/{name_.split('_')[0]}.tsv", table_mode) as table:
        top_ranked = get_best_candidates(tree_file)
        t = Tree(tree_file)
        ts = TreeStyle()
        R = t.get_midpoint_outgroup()
        t.set_outgroup(R)
        sus_clades = 0

        for node in t.traverse('preorder'):
            if node.is_root():
                continue
            node_style = NodeStyle()
            node_style['vt_line_width'] = 3
            node_style['hz_line_width'] = 3
            node_style['vt_line_type'] = 0
            node_style['hz_line_type'] = 0

            if not node.is_leaf():
                # All internal nodes
                supp, sus_clades = format_nodes(node, node_style, sus_clades,
                                                t)
                node.add_face(supp, column=0, position="branch-bottom")
            else:
                # All leaves
                format_leaves(backpropagation, contaminants, node, node_style,
                              table, top_ranked, len_dict)
            node.set_style(node_style)

        title_face = TextFace(
            f'<{name_}  {len_info}, {sus_clades} suspicious clades>', bold=True)
        ts.title.add_face(title_face, column=1)
        t.render(
            f'{output_folder}/{name_}_tree.svg',
            tree_style=ts,
        )
示例#9
0
文件: phy.py 项目: zhaokai2014/wgd
def phylogenetic_tree_to_cluster_format(tree, pairwise_estimates):
    """
    Convert a phylogenetic tree to a 'cluster' data structure as in
    ``fastcluster``. The first two columns indicate the nodes that are joined by
    the relevant node, the third indicates the distance (calculated from branch
    lengths in the case of a phylogenetic tree) and the fourth the number of
    leaves underneath the node. Note that the trees are rooted using
    midpoint-rooting.

    Example of the data structure (output from ``fastcluster``)::

        [[   3.            7.            4.26269776    2.        ]
         [   0.            5.           26.75703595    2.        ]
         [   2.            8.           56.16007598    2.        ]
         [   9.           12.           78.91813609    3.        ]
         [   1.           11.           87.91756528    3.        ]
         [   4.            6.           93.04790855    2.        ]
         [  14.           15.          114.71302639    5.        ]
         [  13.           16.          137.94616373    8.        ]
         [  10.           17.          157.29055403   10.        ]]

    :param tree: newick tree file
    :param pairwise_estimates: pairwise Ks estimates data frame (pandas)
        (only the index is used)
    :return: clustering data structure, pairwise distances dictionary
    """
    id_map = {gene: idx for idx, gene in enumerate(pairwise_estimates.index)}
    t = Tree(tree)

    # Midpoint rooting. get_midpoint_outgroup() gives None when there are only
    # two leaves; the first leaf is then used as the outgroup.
    outgroup = t.get_midpoint_outgroup()
    if not outgroup:
        all_leaves = list(t.get_leaves())
        outgroup = all_leaves[0]
    t.set_outgroup(outgroup)
    logging.debug('Tree after rooting:\n{}'.format(t.get_ascii()))

    # Internal nodes are numbered after the leaves, in postorder.
    next_cluster_id = len(id_map)
    clusters = []
    pairwise_distances = {}
    for node in t.traverse('postorder'):
        if not node.is_leaf():
            node.name = next_cluster_id
            next_cluster_id += 1
            children = node.get_children()
            left = children[0]
            right = children[1]
            clusters.append([left.name, right.name,
                             left.get_distance(right),
                             len(node.get_leaves())])
            continue
        # Leaf: replace its gene name by its integer index, and add an identity
        # entry so lookups of already-renamed leaves below still resolve.
        node.name = id_map[node.name]
        id_map[node.name] = node.name
        pairwise_distances[node.name] = {
            id_map[other.name]: node.get_distance(other)
            for other in t.get_leaves()
        }
    return np.array(clusters), pairwise_distances
示例#10
0
def parse_tree(tree_file):
    """Read a Newick tree from the file at ``tree_file`` and midpoint-root it.

    Newlines are stripped so a tree split over several lines is joined into a
    single Newick string before parsing.
    """
    # "rU" was removed in Python 3.11 (ValueError on open); universal newlines
    # are already the default for text mode.
    with open(tree_file, "r") as handle:
        newick = ''.join(line.rstrip("\n") for line in handle)
    tree = Tree(newick)
    tree.set_outgroup(tree.get_midpoint_outgroup())
    return tree
示例#11
0
def parse_tree(tree_file):
    """Load a tree structure from a Newick file and resolve its polytomies.

    Polytomies are resolved because groups derived from the tree are otherwise
    inaccurate. The midpoint outgroup is computed but not applied, so the tree
    keeps the root it was read with.
    """
    tree = Tree(tree_file)
    tree.resolve_polytomy()
    # TODO: rooting should be forced for consistency (ideally supporting a
    # user-defined root); applying the midpoint outgroup is currently disabled.
    _midpoint = tree.get_midpoint_outgroup()
    return tree
示例#12
0
def CreatePhyloGeneticTree(inputfile, outputfile, size):
    """Render the tree on the first line of ``inputfile`` to ``outputfile``.

    The tree is midpoint-rooted and drawn ``size`` pixels wide. Leaf names are
    shown and branch lengths and supports are hidden.
    """
    with open(inputfile, "r") as f:
        data = f.readlines()[0]
    tree = Tree(data)
    tree.set_outgroup(tree.get_midpoint_outgroup())
    ts = TreeStyle()
    ts.show_leaf_name = True
    ts.show_branch_length = False
    ts.show_branch_support = False
    ts.optimal_scale_level = "mid"
    # Use the style configured above; it was previously discarded via
    # tree_style=None.
    tree.render(str(outputfile), w=size, units="px", tree_style=ts)
示例#13
0
def midpoint_root(tree):
    '''
    Midpoint-root a tree file produced by tree_build, using ete3.

    The tree at path ``tree`` is read with format=1, rooted on its midpoint
    outgroup and written (format=1) next to the input as ``<tree>.rooted``.
    '''
    rooted = Tree(tree, format=1)
    rooted.set_outgroup(rooted.get_midpoint_outgroup())
    rooted.write(format=1, outfile=tree + ".rooted")
示例#14
0
    def ninja(names,
              profiles,
              embeded,
              handle_missing='pair_delete',
              **params):
        """Build a neighbor-joining tree from ``profiles`` with NINJA (Java).

        Distances are divided by ``profiles.shape[1]`` before being written to
        ``params['tempfix'] + 'dist.list'``. Branch lengths of the resulting
        tree are multiplied back by the same factor. All ``dist.list*`` temp
        files are removed. NINJA leaves are row indices; they are renamed with
        ``names``.

        :param names: leaf names, indexed by matrix row.
        :param profiles: profiles passed to ``distance_matrix.get_distance``.
        :param embeded: not used by this function.
        :param handle_missing: missing-data mode for ``get_distance``.
        :param params: must contain ``tempfix`` and ``ninja_<platform.system()>``
            (path to the NINJA jar).
        :return: ete ``Tree``; midpoint-rooted then unrooted when that succeeds.
        """
        dist = distance_matrix.get_distance('symmetric', profiles,
                                            handle_missing)
        dist = dist / profiles.shape[1]
        dist_file = params['tempfix'] + 'dist.list'
        with open(dist_file, 'w') as fout:
            fout.write('    {0}\n'.format(dist.shape[0]))
            # Generator expression: no row variable leaks, so `del dist` frees
            # the matrix (the previous `del dist, d` raised NameError when
            # `dist` had no rows).
            fout.writelines(
                '{0!s:10} {1}\n'.format(
                    n, ' '.join('{:.6f}'.format(dd) for dd in row))
                for n, row in enumerate(dist))
        del dist
        # 90% of *total* physical memory (not free memory), in MiB, as JVM heap.
        heap_mb = int(0.9 * psutil.virtual_memory().total / (1024.**2))
        ninja_jar = params['ninja_{0}'.format(platform.system())]
        ninja_out = Popen([
            'java', '-d64', '-Xmx' + str(heap_mb) + 'M', '-jar', ninja_jar,
            '--in_type', 'd', dist_file
        ],
                          stdout=PIPE,
                          stderr=PIPE,
                          universal_newlines=True).communicate()
        # If stderr mentions '64-bit JVM' (presumably -d64 was rejected),
        # retry without -d64 and with a 1200 MB heap.
        if ninja_out[1].find('64-bit JVM') >= 0:
            ninja_out = Popen([
                'java', '-Xmx1200M', '-jar', ninja_jar, '--in_type', 'd',
                dist_file
            ],
                              stdout=PIPE,
                              stderr=PIPE,
                              universal_newlines=True).communicate()
        with open(dist_file + '.nwk', 'wt') as fout:
            fout.write(ninja_out[0])
        tree = Tree(dist_file + '.nwk')
        for fname in glob(dist_file + '*'):
            os.unlink(fname)

        # Undo the normalisation applied to the distances above.
        for node in tree.traverse():
            node.dist *= profiles.shape[1]

        # Best-effort rooting; `except Exception` (not bare) so that
        # KeyboardInterrupt / SystemExit still propagate.
        try:
            tree.set_outgroup(tree.get_midpoint_outgroup())
            tree.unroot()
        except Exception:
            pass

        for leaf in tree.get_leaves():
            leaf.name = names[int(leaf.name.strip("'"))]
        return tree
示例#15
0
def get_root(prefix, tree_file):
    """Collapse zero-length internal branches, midpoint-root, and write the tree.

    Every non-root, non-leaf node whose branch length is 0 is removed and its
    children are re-attached to its parent. Rooting is best-effort: if it
    raises, the tree keeps its root. The tree is read and written with ete
    format 1.

    :return: path of the written file, ``<prefix>.rooted.nwk``.
    """
    out_file = '{0}.rooted.nwk'.format(prefix)
    tree = Tree(tree_file, format=1)
    # NOTE(review): the tree is modified while it is being traversed; this
    # relies on ete3's traversal still reaching the re-attached children.
    for node in tree.traverse():
        if node.dist == 0 and node.up and not node.is_leaf():
            parent = node.up
            for c in node.get_children():
                parent.add_child(c)
                c.up = parent
            parent.remove_child(node)
    try:
        tree.set_outgroup(tree.get_midpoint_outgroup())
    except Exception:
        # Deliberately best-effort, but no longer swallows
        # KeyboardInterrupt / SystemExit like the bare `except:` did.
        pass
    tree.write(outfile=out_file, format=1)
    return out_file
示例#16
0
def read_tree(file_name):
    """Load a Newick tree, midpoint-root it when possible, and ladderize it.

    A ``file_name`` of ``"-"`` reads from standard input via ``/dev/stdin``.
    Any error while finding or applying the midpoint root is ignored, and the
    tree then keeps its original root. Ladderizing uses ``direction=1``.
    """
    if file_name == "-":
        file_name = "/dev/stdin"

    t = Tree(file_name)

    # ignore any errors with midpoint rooting - including finding the midpoint,
    # which previously ran outside the try block and could still raise
    try:
        t.set_outgroup(t.get_midpoint_outgroup())
    except Exception:
        pass

    t.ladderize(direction=1)
    return t
示例#17
0
def drawTree(distDict, alleleDict, sample_list, outgroup, prefix, bootstrap):
    '''
    Run neighbor-joining on pairwise cell distances (saved in distDict).

    ``distDict["sampleComp"][(s1, s2)]`` (key = sorted sample pair) must hold
    ``"dist"`` and ``"num_targets"``. When ``bootstrap is False``, two files
    are written: ``<prefix>.buildPhylo.stats.txt`` (sample count, target
    statistics, distance and shared-target matrices) and
    ``<prefix>.buildPhylo.distDict.pkl`` (pickled ``distDict``).
    ``alleleDict`` is not used by this function.

    ``outgroup``: ``"NA"`` returns the NJ tree as built, ``"Midpoint"``
    midpoint-roots it (returns None if no midpoint is found), and any other
    value is used as the outgroup leaf name.
    '''
    samples = sorted(sample_list)
    distMatrix = []
    targetMatrix = []
    pairwise_numTargets = []
    sample_numTargets = []
    for sample1 in samples:
        sample1_dist = []
        sample1_targets = []
        for sample2 in samples:
            comp = distDict["sampleComp"][tuple(sorted([sample1, sample2]))]
            sample1_dist.append(comp["dist"])
            sample1_targets.append(comp["num_targets"])
            if sample1 != sample2:
                pairwise_numTargets.append(comp["num_targets"])
            else:
                sample_numTargets.append(comp["num_targets"])
        distMatrix.append(sample1_dist)
        targetMatrix.append(sample1_targets)
    if bootstrap is False:  # only output statistics for the original tree, not for bootstrap resampling
        with open(prefix + ".buildPhylo.stats.txt", 'w') as statsOutput:
            statsOutput.write("Number of Samples Analyzed:\t" + str(len(sample_list)) + "\n" + ','.join(sample_list) + "\n")
            statsOutput.write("Avg targets shared per pair of cells:\t" + str(float(sum(pairwise_numTargets) / len(pairwise_numTargets))) + "\t[" + str(min(pairwise_numTargets)) + "," + str(max(pairwise_numTargets)) + "]\n")
            statsOutput.write("Avg targets captured per single cell:\t" + str(float(sum(sample_numTargets) / len(sample_numTargets))) + "\t[" + str(min(sample_numTargets)) + "," + str(max(sample_numTargets)) + "]\n")
            for sample, dist_list in zip(samples, distMatrix):  # distance matrix
                statsOutput.write(sample + "," + ",".join(str(round(i, 3)) for i in dist_list) + "\n")
            for sample, target_list in zip(samples, targetMatrix):  # shared-target matrix
                statsOutput.write(sample + "," + ",".join(str(j) for j in target_list) + "\n")
        # Per-pair distance information, kept for downstream statistics. The
        # handle is now closed explicitly instead of being left to the GC.
        with open(prefix + ".buildPhylo.distDict.pkl", "wb") as pkl_out:
            pickle.dump(distDict, pkl_out)
    distObj = DistanceMatrix(distMatrix, samples)
    skbio_tree = nj(distObj, result_constructor=str)
    ete_tree = Tree(skbio_tree)  # skbio builds the tree from the distance matrix, then convert to ete
    # Compare by value: `is` only matched by accident of string interning.
    if outgroup == "NA":
        return ete_tree
    if outgroup == "Midpoint":
        tree_midpoint = ete_tree.get_midpoint_outgroup()
        if tree_midpoint is None:
            print(ete_tree.write(format=0))
            return None  # throw out the tree if no midpoint was found
        ete_tree.set_outgroup(tree_midpoint)
    else:
        ete_tree.set_outgroup(outgroup)
    return ete_tree
示例#18
0
 def root_iqtree(self):
     """Root the IQ-TREE tree and write it to the rooted tree file.

     When ``self.root`` is ``'midpoint'`` the midpoint outgroup is used;
     any other value is passed to ``set_outgroup`` unchanged. The tree is
     then ladderized (direction=1) and written to
     ``self.outfiles["rooted_treefile"]``.
     """
     from ete3 import Tree
     tree = Tree(self.outfiles["treefile"], format=0)
     requested_root = self.root
     outgroup = (tree.get_midpoint_outgroup() if requested_root == 'midpoint'
                 else requested_root)
     tree.set_outgroup(outgroup)
     tree.ladderize(direction=1)
     # Fixed-point branch lengths: ClusterPicker dies on scientific notation.
     tree.write(outfile=self.outfiles["rooted_treefile"],
                dist_formatter="%0.16f")
示例#19
0
def reroot_trees(tree_computation, species_tree_polytomies):
    """Midpoint-root one computed gene tree.

    :param tree_computation: tuple ``(nog_id, tree_time, tree_nw)`` where
        ``tree_nw`` is a Newick string.
    :param species_tree_polytomies: if true, polytomies are resolved
        recursively first, because the reconciliation algorithm accepts
        multifurcations in only one of its inputs.
    :return: ``(nog_id, tree_time, newick)``. The newick is the rerooted tree,
        or the original ``tree_nw`` when no midpoint outgroup was found.
    """
    nog_id, tree_time, tree_nw = tree_computation
    # %s instead of %d: building the message must not itself raise TypeError
    # when the ids are not integers.
    assert tree_nw, "Tree newick is non existant for %s %s" % (nog_id, tree_time)
    t = Tree(tree_nw)

    if species_tree_polytomies:
        t.resolve_polytomy(recursive=True)

    node = t.get_midpoint_outgroup()
    if node is not None:
        t.set_outgroup(node)
        return (nog_id, tree_time, t.write())
    # Trailing newline so consecutive warnings do not run together.
    sys.stderr.write('Problems in rerooting %s %s %s\n' % (nog_id, tree_time, tree_nw))
    return (nog_id, tree_time, tree_nw)
示例#20
0
def njWithRoot(dis_matrix, muestraPmid):
    """Build a midpoint-rooted neighbor-joining tree from a distance matrix.

    No distances are computed here. ``dis_matrix`` is only converted to a list
    and labelled with the ids in ``muestraPmid`` (converted to strings).

    Returns:
        The rooted tree, re-parsed (format=1) from its format-3 Newick text.
    """
    labels = list(map(str, muestraPmid))
    dm = DistanceMatrix(dis_matrix.tolist(), labels)
    newick = nj(dm, result_constructor=str)
    # Root the tree at its midpoint.
    rooted = TreeEte(newick)
    rooted.set_outgroup(rooted.get_midpoint_outgroup())
    # Round-trip through Newick: written with format 3, read back with format 1.
    return TreeEte(rooted.write(format=3), format=1)
示例#21
0
def analyze_tree(tree_filename, full_name_studied_gene, node_support):
    """Filter, midpoint-root and analyse a gene tree.

    Steps:
      * If ``node_support`` is non-zero, nodes whose name contains no '@' and
        whose support is below the threshold are deleted. When every support
        in the tree is <= 1, the threshold is divided by 100 to match that
        scale.
      * If no leaf breaks the "only TOI" condition, ``"only_TOI"`` is returned.
      * Otherwise the tree is midpoint-rooted (when a midpoint outgroup is
        found) and ``analysis(gene_tree, full_name_studied_gene)`` is returned.
    """
    gene_tree = Tree(tree_filename, format=0)

    if node_support != 0:
        if all(n.support <= 1 for n in gene_tree.traverse("preorder")):
            node_support = node_support / 100
        # Nodes are deleted while the traversal is running.
        for n in gene_tree.traverse("preorder"):
            if "@" not in n.name and n.support < node_support:
                n.delete()

    # A leaf breaks "only TOI" when it is not tagged @TOI and lacks "EGP" or
    # "StudiedOrganism".
    # NOTE(review): with `or`, a non-TOI leaf must contain both markers to be
    # tolerated - confirm this is intended.
    only_TOI = not any(
        "@TOI" not in label
        and ("EGP" not in label or "StudiedOrganism" not in label)
        for label in (str(leaf) for leaf in gene_tree))
    if only_TOI:
        return "only_TOI"

    # Root the tree at its midpoint when one is found.
    midpoint = gene_tree.get_midpoint_outgroup()
    if midpoint is not None:
        gene_tree.set_outgroup(midpoint)

    return analysis(gene_tree, full_name_studied_gene)
示例#22
0
def generateTestCases(n=500):
    """Generate host/guest test cases under ``60_examples/<index>/``.

    Hosts are generated until ``n / 10`` is reached, with 10 guests each. For
    every case:
      * the mapping and the host/guest/sequence files are moved into the case
        folder, and the moved ``guest.nwk`` is renamed ``guest_full.nwk``;
      * the guest returned by ``withHost`` is written as ``guest.nwk``;
      * RAxML is run on the sequences and its midpoint-rooted best tree is
        saved as ``RAxML_bestTree.nwk``.
    A ``withHost`` call that raises is simply retried.
    """
    hostCases = 0
    while hostCases < n / 10:

        try:
            host = withHost(8, .3)[0]
        except Exception:
            # Retry. A bare `except:` here also swallowed KeyboardInterrupt,
            # making this retry loop impossible to interrupt.
            continue

        guestCases = 0
        while guestCases < 10:
            case_index = hostCases * 10 + guestCases
            printProgressBar(case_index, n)

            try:
                guest = withHost(8, .3, host)[1]
            except Exception:
                continue

            writeMapping(genMap(host, guest), 'guest.map')
            folder_name = '60_examples/' + str(case_index) + '/'
            system('mkdir ' + folder_name)
            system('mv host.nwk guest.nwk sequences.fa guest.map ' + folder_name)
            system('mv ' + folder_name + 'guest.nwk ' + folder_name + 'guest_full.nwk')
            guest.write(format=1, outfile=folder_name + 'guest.nwk')

            # Run RAxML
            system('rm RAxML_*')
            raxml(folder_name + 'sequences.fa', 'nwk')
            rax = Tree('RAxML_bestTree.nwk')
            rax.set_outgroup(rax.get_midpoint_outgroup())
            name(rax)
            writeTree(rax, folder_name + 'RAxML_bestTree.nwk')

            guestCases += 1

        hostCases += 1
def get_tree(infile):
    """Load a tree and keep one leaf per strain, renamed to the strain id.

    Leaf names are cleaned first: quotes are removed, everything from the first
    '.' is dropped, and a leaf called 'genome' becomes 'NT12001_189'. A leaf's
    strain is the part of its name before the first '_'. For each strain only
    the leaf with the smallest name is kept and the others are deleted. The
    remaining leaves are renamed to their strain, and the tree is
    midpoint-rooted.
    """
    tree = Tree(infile)

    for leaf in tree.iter_leaves():
        leaf.name = leaf.name.replace("'", '').split('.')[0]
        if leaf.name == 'genome':
            leaf.name = 'NT12001_189'

    # Group leaves by their exact strain id. The previous startswith() match
    # let one strain claim another sharing its prefix (e.g. 'NT1' vs 'NT12'),
    # deleting leaves of the wrong strain.
    leaves_by_strain = {}
    for leaf in tree.iter_leaves():
        leaves_by_strain.setdefault(leaf.name.split('_')[0], []).append(leaf)
    for strain_leaves in leaves_by_strain.values():
        for extra in sorted(strain_leaves, key=lambda x: x.name)[1:]:
            extra.delete()

    for leaf in tree.iter_leaves():
        leaf.name = leaf.name.split('_')[0]
    tree.set_outgroup(tree.get_midpoint_outgroup())

    return tree
示例#24
0
def get_species_tree(biodb):
    """Build the reference phylogeny of ``biodb`` and its species-level version.

    The stored reference tree is midpoint-rooted and its leaves are renamed
    from taxon id to species id. Each node with an empty name is then named.
    If its descendants carry exactly one distinct non-empty name, that name is
    used; otherwise the node gets a unique '<n>bb' label. The species tree is
    written treating every node whose cached name set (``get_cached_content``)
    has a single entry as a leaf. Its leaves are annotated with the number of
    complete ('c') and draft ('d') genomes. When draft genomes exist, the
    completeness range is added too.

    :return: ``(complete_tree, species_tree)``
    """
    from ete3 import Tree

    server, db = manipulate_biosqldb.load_db(biodb)

    species2n_complete_genomes, species2n_draft_genomes, species2completeness = get_species_data(server,
                                                                                                 biodb)

    sql_tree = 'select tree from reference_phylogeny t1 inner join biodatabase t2 on t1.biodatabase_id=t2.biodatabase_id ' \
               ' where t2.name="%s";' % biodb

    # NOTE(review): second connection kept as in the original; it may be
    # redundant unless get_species_data invalidates the first one.
    server, db = manipulate_biosqldb.load_db(biodb)
    complete_tree = Tree(server.adaptor.execute_and_fetchall(sql_tree,)[0][0])
    complete_tree.set_outgroup(complete_tree.get_midpoint_outgroup())

    sql = 'select distinct taxon_id,species from taxid2species_%s t1 ' \
          ' inner join species_curated_taxonomy_%s t2 on t1.species_id=t2.species_id;' % (biodb,
                                                                                          biodb)

    taxon_id2species_id = manipulate_biosqldb.to_dict(server.adaptor.execute_and_fetchall(sql,))

    # changing taxon id to species id
    for leaf in complete_tree.iter_leaves():
        leaf.name = str(taxon_id2species_id[str(leaf.name)])

    # attributing a unique id to each unnamed node; if all named descendants
    # share the same name, use that name instead
    n = 0
    for node in complete_tree.traverse():
        if node.name == '':
            desc_names = {i.name for i in node.iter_descendants()} - {''}
            if len(desc_names) != 1:
                node.name = '%sbb' % n
            else:
                node.name = next(iter(desc_names))
            n += 1

    # Collapsing nodes while traversing
    # http://etetoolkit.org/docs/latest/tutorial/tutorial_trees.html#collapsing-nodes-while-traversing-custom-is-leaf-definition
    node2labels = complete_tree.get_cached_content(store_attr="name")

    def collapsed_leaf(node):
        return len(node2labels[node]) == 1

    species_tree = Tree(complete_tree.write(is_leaf_fn=collapsed_leaf))

    for lf in species_tree.iter_leaves():
        # Missing species count as "no genomes"; only KeyError is expected
        # here (the bare excepts also hid unrelated errors).
        try:
            n_complete_genomes = species2n_complete_genomes[lf.name]
        except KeyError:
            n_complete_genomes = False
        try:
            n_draft_genomes = species2n_draft_genomes[lf.name]
        except KeyError:
            n_draft_genomes = False

        if n_draft_genomes:
            c1 = round(species2completeness[lf.name][0])
            c2 = round(species2completeness[lf.name][1])
            completeness = "%s%%" % c1 if c1 == c2 else "%s-%s%%" % (c1, c2)

        if n_complete_genomes and n_draft_genomes:
            lf.name = "%s (%sc/%sd, %s)" % (lf.name,
                                           n_complete_genomes,
                                           n_draft_genomes,
                                           completeness)
        elif n_complete_genomes:
            lf.name = "%s (%sc)" % (lf.name,
                                    n_complete_genomes)
        elif n_draft_genomes:
            lf.name = "%s (%sd, %s)" % (lf.name,
                                        n_draft_genomes,
                                        completeness)

    return complete_tree, species_tree
示例#25
0
def plot_phylo(nw_tree,
               out_name,
               parenthesis_classif=True,
               show_support=False,
               radial_mode=False,
               root=False):
    """Render a Newick tree to ``out_name`` with coloured italic leaf labels.

    :param nw_tree: tree accepted by ete3 ``Tree`` (read with format 0).
    :param out_name: output image path passed to ``render``.
    :param parenthesis_classif: colour leaves by a classification taken from
        the second-to-last '_' field of the leaf name (minus its last char).
    :param show_support: if True, supports are shown by the TreeStyle and node
        dots are hidden. Otherwise nodes with support < 1 get a size-4 red dot.
    :param radial_mode: draw a circular tree.
    :param root: midpoint-root before the fixed outgroup is applied.
    """
    from ete3 import Tree, AttrFace, TreeStyle, NodeStyle, TextFace
    import orthogroup2phylogeny_best_refseq_uniprot_hity

    ete2_tree = Tree(nw_tree, format=0)
    if root:
        ete2_tree.set_outgroup(ete2_tree.get_midpoint_outgroup())
    # NOTE(review): hard-coded outgroup; it overrides the midpoint root above
    # and presumably fails when no leaf is named 'Bacillus subtilis'.
    ete2_tree.set_outgroup('Bacillus subtilis')
    ete2_tree.ladderize()

    name2classif = {}
    classif2col = {}
    if parenthesis_classif:
        print('parenthesis_classif!')
        for lf in ete2_tree.iter_leaves():
            print(lf)
            try:
                classif = lf.name.split('_')[-2][0:-1]
            except IndexError:
                # names with fewer than two '_' fields get no classification
                continue
            print('classif', classif)
            name2classif[lf.name] = classif
        classif_list = list(set(name2classif.values()))
        classif2col = dict(
            zip(
                classif_list,
                orthogroup2phylogeny_best_refseq_uniprot_hity.
                get_spaced_colors(len(classif_list))))

    for lf in ete2_tree.iter_leaves():
        col = 'black'
        if parenthesis_classif:
            col = classif2col.get(name2classif.get(lf.name), 'black')
        # the same label face is used in radial and rectangular mode
        ff = AttrFace("name", fsize=12, fstyle='italic')
        ff.fgcolor = col
        lf.add_face(ff, column=0)

    # Node styles do not depend on the leaf: apply them once. They used to be
    # re-applied to the whole tree for every leaf (O(leaves x nodes)).
    if not show_support:
        print('support')
        for n in ete2_tree.traverse():
            print(n.support)
            nstyle = NodeStyle()
            nstyle["fgcolor"] = "red"
            nstyle["size"] = 4 if float(n.support) < 1 else 0
            n.set_style(nstyle)
    else:
        for n in ete2_tree.traverse():
            nstyle = NodeStyle()
            nstyle["fgcolor"] = "red"
            nstyle["size"] = 0
            n.set_style(nstyle)

    ts = TreeStyle()
    ts.show_leaf_name = False
    ts.show_branch_support = show_support

    if radial_mode:
        ts.mode = "c"
        ts.arc_start = -90
        ts.arc_span = 360
    ts.tree_width = 370
    ts.complete_branch_lines_when_necessary = True
    ete2_tree.render(out_name, tree_style=ts, w=900)
示例#26
0
def plot_tree_barplot(tree_file,
                      taxon2value_list_barplot,
                      header_list,
                      taxon2set2value_heatmap=False,
                      header_list2=False,
                      column_scale=True,
                      general_max=False,
                      barplot2percentage=False,
                      taxon2mlst=False):
    '''
    Display one or more barplot columns (plus optional MLST and heatmap
    columns) aligned to the leaves of a midpoint-rooted tree.

    :param tree_file: ete ``Tree`` instance, or anything ``Tree()`` accepts
        (newick string or path).
    :param taxon2value_list_barplot: leaf name -> list of values, one per
        entry of ``header_list``. Values may be numbers or numeric strings;
        non-numeric values (e.g. 'na') are shown as text with an empty bar.
    :param header_list: headers of the barplot columns.
    :param taxon2set2value_heatmap: heatmap column name -> {leaf name (str
        or int) -> value}; False disables the heatmap columns.
    :param header_list2: heatmap column names, in display order.
    :param column_scale: if True, each heatmap column gets its own OrRd
        color scale; otherwise non-zero cells use a fixed "orange" background.
    :param general_max: if set, used as the bar maximum for every barplot
        column instead of the per-column maximum.
    :param barplot2percentage: list of bool (one per header) indicating if
        the numbers are percentages and the range should be set to 0-100.
    :param taxon2mlst: leaf name (as int) -> MLST type; adds an MLST column.

    :return: (tree, TreeStyle) tuple.
    '''

    import matplotlib.cm as cm
    from matplotlib.colors import rgb2hex
    import matplotlib as mpl

    if taxon2mlst:
        mlst_list = list(set(taxon2mlst.values()))
        mlst2color = dict(zip(mlst_list, get_spaced_colors(len(mlst_list))))
        mlst2color['-'] = 'white'

    if isinstance(tree_file, Tree):
        t1 = tree_file
    else:
        t1 = Tree(tree_file)

    # Root the tree on its midpoint.
    t1.set_outgroup(t1.get_midpoint_outgroup())

    tss = TreeStyle()
    tss.draw_guiding_lines = True
    tss.guiding_lines_color = "gray"
    tss.show_leaf_name = False

    # One color scale per heatmap column when column_scale is requested;
    # columns without a scale fall back to a fixed color further down
    # (previously this raised NameError when column_scale was False).
    column2scale = {}
    if column_scale and header_list2 and taxon2set2value_heatmap:
        for column in header_list2:
            values = taxon2set2value_heatmap[column].values()
            norm = mpl.colors.Normalize(vmin=min(values), vmax=max(values))
            column2scale[column] = cm.ScalarMappable(norm=norm, cmap=cm.OrRd)

    cmap = cm.YlGnBu  #YlOrRd#OrRd

    values_lists = taxon2value_list_barplot.values()

    # Per barplot column: a color scale and the value mapped to a full bar.
    scale_list = []
    max_value_list = []

    for idx, header in enumerate(header_list):
        data = [float(i[idx]) for i in values_lists]

        if barplot2percentage is not False and barplot2percentage[idx] is True:
            min_value, max_value = 0, 100
        else:
            min_value, max_value = min(data), max(data)
        norm = mpl.colors.Normalize(vmin=min_value, vmax=max_value)
        scale_list.append(cm.ScalarMappable(norm=norm, cmap=cmap))
        if not general_max:
            max_value_list.append(float(max_value))
        else:
            max_value_list.append(general_max)

    for i, lf in enumerate(t1.iter_leaves()):

        if i == 0:
            # Build the aligned header once, while handling the first leaf.
            col_add = 0

            if taxon2mlst:
                header_list = ['MLST'] + header_list

            for header in header_list:
                # spacer above the bar column
                n = TextFace(' ')
                n.margin_top = 1
                n.margin_right = 2
                n.margin_left = 2
                n.margin_bottom = 1
                n.rotation = 90
                n.inner_background.color = "white"
                n.opacity = 1.
                n.hz_align = 2
                n.vt_align = 2

                tss.aligned_header.add_face(n, col_add + 1)

                # header text above the value column
                n = TextFace('%s' % header)
                n.margin_top = 1
                n.margin_right = 2
                n.margin_left = 2
                n.margin_bottom = 2
                n.rotation = 270
                n.inner_background.color = "white"
                n.opacity = 1.
                n.hz_align = 2
                n.vt_align = 1
                tss.aligned_header.add_face(n, col_add)
                col_add += 2

            if header_list2:
                for col, header in enumerate(header_list2):
                    n = TextFace('%s' % header)
                    n.margin_top = 1
                    n.margin_right = 20
                    n.margin_left = 2
                    n.margin_bottom = 1
                    n.rotation = 270
                    n.hz_align = 2
                    n.vt_align = 2
                    n.inner_background.color = "white"
                    n.opacity = 1.
                    tss.aligned_header.add_face(n, col + col_add)

        if taxon2mlst:

            try:
                mlst = taxon2mlst[int(lf.name)]
                n = TextFace(' %s ' % mlst)
                n.inner_background.color = 'white'
                m = TextFace('  ')
                m.inner_background.color = mlst2color[mlst]
            except (KeyError, ValueError, TypeError):
                # leaf name not an int, or no MLST type for this leaf
                n = TextFace(' na ')
                n.inner_background.color = "grey"
                m = TextFace('    ')
                m.inner_background.color = "white"

            n.opacity = 1.
            n.margin_top = 2
            n.margin_right = 2
            n.margin_left = 0
            n.margin_bottom = 2

            m.margin_top = 2
            m.margin_right = 0
            m.margin_left = 2
            m.margin_bottom = 2

            lf.add_face(m, 0, position="aligned")
            lf.add_face(n, 1, position="aligned")
            col_add = 2
        else:
            col_add = 0

        try:
            val_list = taxon2value_list_barplot[lf.name]
        except KeyError:
            # One 'na' per barplot column; header_list contains the extra
            # 'MLST' entry when taxon2mlst is set.
            if not taxon2mlst:
                val_list = ['na'] * len(header_list)
            else:
                val_list = ['na'] * (len(header_list) - 1)

        for col, value in enumerate(val_list):

            # show value itself
            n = TextFace('  %s  ' % str(value))
            n.margin_top = 1
            n.margin_right = 5
            n.margin_left = 10
            n.margin_bottom = 1
            n.inner_background.color = "white"
            n.opacity = 1.

            lf.add_face(n, col_add, position="aligned")
            # show bar
            try:
                color = rgb2hex(scale_list[col].to_rgba(float(value)))
            except (TypeError, ValueError, IndexError):
                color = 'white'
            # Use float(value) like the color scale does, so numeric strings
            # get a proper bar instead of an empty one.
            try:
                fvalue = float(value)
                percentage = (fvalue / max_value_list[col]) * 100
                maximum_bar = (
                    (max_value_list[col] - fvalue) / max_value_list[col]) * 100
            except (TypeError, ValueError, IndexError, ZeroDivisionError):
                percentage = 0
                maximum_bar = 0
            b = StackedBarFace([percentage, maximum_bar],
                               width=100,
                               height=10,
                               colors=[color, "white"])
            b.rotation = 0
            b.inner_border.color = "grey"
            b.inner_border.width = 0
            b.margin_right = 15
            b.margin_left = 0
            lf.add_face(b, col_add + 1, position="aligned")
            col_add += 2

        if taxon2set2value_heatmap:
            for col, col_name in enumerate(header_list2):
                # leaf names may be stored as str or as int keys
                try:
                    value = taxon2set2value_heatmap[col_name][lf.name]
                except KeyError:
                    try:
                        value = taxon2set2value_heatmap[col_name][int(lf.name)]
                    except (KeyError, ValueError):
                        value = 0

                if int(value) > 0:
                    if int(value) > 9:
                        n = TextFace(' %i ' % int(value))
                    else:
                        n = TextFace(' %i   ' % int(value))
                    n.margin_top = 1
                    n.margin_right = 1
                    n.margin_left = 20
                    n.margin_bottom = 1
                    n.fgcolor = "white"
                    if col_name in column2scale:
                        n.inner_background.color = rgb2hex(
                            column2scale[col_name].to_rgba(float(value)))
                    else:
                        n.inner_background.color = "orange"
                    n.opacity = 1.
                    lf.add_face(n, col + col_add, position="aligned")
                else:
                    n = TextFace('  ')
                    n.margin_top = 1
                    n.margin_right = 1
                    n.margin_left = 20
                    n.margin_bottom = 1
                    n.inner_background.color = "white"
                    n.opacity = 1.

                    lf.add_face(n, col + col_add, position="aligned")

        n = TextFace(lf.name, fgcolor="black", fsize=12, fstyle='italic')
        lf.add_face(n, 0)

    # Mark nodes with support < 1 with a black dot; hide the others.
    for n in t1.traverse():
        nstyle = NodeStyle()
        if n.support < 1:
            nstyle["fgcolor"] = "black"
            nstyle["size"] = 6
            n.set_style(nstyle)
        else:
            nstyle["fgcolor"] = "red"
            nstyle["size"] = 0
            n.set_style(nstyle)

    return t1, tss
示例#27
0
def plot_tree_barplot(tree_file, taxon2mlst, header_list):
    '''
    Display an MLST column (colored swatch + type label) aligned to the
    leaves of a midpoint-rooted tree.

    :param tree_file: ete ``Tree`` instance, or anything ``Tree()`` accepts
        (newick string or path).
    :param taxon2mlst: leaf name (as int) -> MLST type.
    :param header_list: not used by this function (kept for signature
        compatibility).

    :return: (tree, TreeStyle) tuple.
    '''

    mlst_list = list(set(taxon2mlst.values()))
    mlst2color = dict(zip(mlst_list, get_spaced_colors(len(mlst_list))))
    mlst2color['-'] = 'white'

    t1 = tree_file if isinstance(tree_file, Tree) else Tree(tree_file)

    # Root the tree on its midpoint.
    t1.set_outgroup(t1.get_midpoint_outgroup())

    tss = TreeStyle()
    tss.draw_guiding_lines = True
    tss.guiding_lines_color = "gray"
    tss.show_leaf_name = False

    for i, lf in enumerate(t1.iter_leaves()):

        if i == 0:
            # header, added once while handling the first leaf
            n = TextFace('MLST')
            n.margin_top = 1
            n.margin_right = 2
            n.margin_left = 2
            n.margin_bottom = 1
            n.rotation = 90
            n.inner_background.color = "white"
            n.opacity = 1.
            n.hz_align = 2
            n.vt_align = 2

            tss.aligned_header.add_face(n, 1)

        try:
            mlst = taxon2mlst[int(lf.name)]
            n = TextFace(' %s ' % mlst)
            n.inner_background.color = 'white'
            m = TextFace('  ')
            m.inner_background.color = mlst2color[mlst]
        except (KeyError, ValueError, TypeError):
            # leaf name not an int, or no MLST type for this leaf
            n = TextFace(' na ')
            n.inner_background.color = "grey"
            m = TextFace('    ')
            m.inner_background.color = "white"

        n.opacity = 1.
        n.margin_top = 2
        n.margin_right = 2
        n.margin_left = 0
        n.margin_bottom = 2

        m.margin_top = 2
        m.margin_right = 0
        m.margin_left = 2
        m.margin_bottom = 2

        lf.add_face(m, 0, position="aligned")
        lf.add_face(n, 1, position="aligned")

        n = TextFace(lf.name, fgcolor="black", fsize=12, fstyle='italic')
        lf.add_face(n, 0)

    # Mark nodes with support < 1 with a black dot; hide the others.
    for n in t1.traverse():
        nstyle = NodeStyle()
        if n.support < 1:
            nstyle["fgcolor"] = "black"
            nstyle["size"] = 6
            n.set_style(nstyle)
        else:
            nstyle["fgcolor"] = "red"
            nstyle["size"] = 0
            n.set_style(nstyle)

    return t1, tss
示例#28
0
    # Parse command line arguments.
    cmdln = sys.argv
    pb_newick = cmdln[1]
    pb_newick_boots_only = cmdln[2]
    output_file_path = cmdln[3]

    # Initiate a tree style.
    ts = TreeStyle()
    ts.show_leaf_name = False

    # Parse trees.
    pb_newick_tree = Tree(pb_newick, format=0)
    pb_newick_boots_only_tree = Tree(pb_newick_boots_only, format=0)

    # Root trees on midpoint.
    pb_newick_tree.set_outgroup(pb_newick_tree.get_midpoint_outgroup())
    pb_newick_boots_only_tree.set_outgroup(
        pb_newick_boots_only_tree.get_midpoint_outgroup())

    # Add node support values as branch labels (modifies pb_newick_tree).
    add_combined_support_to_nodes_as_faces(pb_newick_tree,
                                           pb_newick_boots_only_tree)

    # Customize the node styles generally.
    customize_node_styles_for_visualization(pb_newick_tree)

    #####################################################

    # Write tree to pdf.

    ## Use this for running on personal computer:
示例#29
0
class EteTool():

    '''
    Plot ete3 phylogenetic profiles.

    Aligned columns are appended left to right; ``self.column_count`` is the
    index of the next free aligned column.

    - self.add_simple_barplot: add a barplot face from taxon2value dictionary
    - self.add_text_face: add text face
    - self.add_heatmap: add column with cells with value + colored background
    - self.rename_leaves: rename tree leaves from a dictionary (old_name2new_name)
    - self.remove_dots: hide the node markers
    '''

    def __init__(self,
                 tree_file):
        """
        :param tree_file: ete Tree / PhyloNode instance, or a newick string
            or path accepted by ``Tree()``. The tree is midpoint-rooted
            (best effort) and ladderized.
        """
        self.column_count = 0

        self.default_colors = ['#fc8d59', '#91bfdb', '#99d594', '#c51b7d', '#f1a340', '#999999']

        self.color_index = 0

        self.rotate = False

        # if not a tree instance, consider it as a path or a newick string
        if isinstance(tree_file, Tree):
            self.tree = tree_file
        elif isinstance(tree_file, ete3.phylo.phylotree.PhyloNode):
            self.tree = tree_file
        else:
            self.tree = Tree(tree_file)

        # Root on the midpoint; failures of set_outgroup are ignored on
        # purpose (best effort) and the tree is kept as is.
        R = self.tree.get_midpoint_outgroup()
        try:
            self.tree.set_outgroup(R)
        except Exception:
            pass

        self.tree.ladderize()

        self.tss = TreeStyle()
        self.tss.draw_guiding_lines = True
        self.tss.guiding_lines_color = "gray"
        self.tss.show_leaf_name = False

    def add_stacked_barplot(self,
                            taxon2value_list,
                            header_name,
                            color_list=False):
        """Not implemented yet (no-op)."""
        pass

    def rename_leaves(self,
                      taxon2new_taxon,
                      keep_original=False,
                      add_face=True):
        """
        Rename every leaf using ``taxon2new_taxon`` (old name -> new name).

        Leaves missing from the mapping are renamed 'n/a'. With
        ``keep_original`` the label becomes 'new (old)'. With ``add_face``
        the label is also drawn as an italic text face in column 0.
        """
        for lf in self.tree.iter_leaves():
            if lf.name in taxon2new_taxon:
                label = taxon2new_taxon[lf.name]
                if keep_original:
                    label = '%s (%s)' % (label, lf.name)
            else:
                label = 'n/a'
            if add_face:
                n = TextFace(label, fgcolor="black", fsize=12, fstyle='italic')
                lf.add_face(n, 0)
            lf.name = label

    def add_heatmap(self,
                    taxon2value,
                    header_name,
                    continuous_scale=False,
                    show_text=False):
        """
        Add one aligned column of cells, one per leaf.

        :param taxon2value: leaf name -> value.
        :param header_name: column header text.
        :param continuous_scale: color cell backgrounds with the scale from
            ``get_continuous_scale`` (values must be convertible to float).
        :param show_text: print the value inside the cell.
        """
        from metagenlab_libs.colors import get_continuous_scale

        self._add_header(header_name)

        if continuous_scale:
            from matplotlib.colors import rgb2hex
            color_scale = get_continuous_scale(taxon2value.values())

        for lf in self.tree.iter_leaves():

            if lf.name not in taxon2value:
                n = TextFace('')
            else:
                value = taxon2value[lf.name]

                if show_text:
                    n = TextFace('%s' % value)
                else:
                    n = TextFace('    ')

                n.margin_top = 2
                n.margin_right = 3
                n.margin_left = 3
                n.margin_bottom = 2
                n.hz_align = 1
                n.vt_align = 1
                n.border.width = 3
                n.border.color = "#ffffff"
                if continuous_scale:
                    n.background.color = rgb2hex(color_scale[0].to_rgba(float(value)))
                n.opacity = 1.

            if self.rotate:
                n.rotation = 270
            lf.add_face(n, self.column_count, position="aligned")

        self.column_count += 1

    def _add_header(self,
                    header_name,
                    column_add=0):
        """Add a rotated header face at column ``column_count - 1 + column_add``."""
        n = TextFace(f'{header_name}')
        n.margin_top = 1
        n.margin_right = 1
        n.margin_left = 20
        n.margin_bottom = 1
        n.hz_align = 2
        n.vt_align = 2
        n.rotation = 270
        n.inner_background.color = "white"
        n.opacity = 1.
        # add header
        self.tss.aligned_header.add_face(n, self.column_count-1+column_add)

    def _get_default_barplot_color(self,):
        """Return the next default color, cycling through ``self.default_colors``."""
        n_colors = len(self.default_colors)
        col = self.default_colors[self.color_index % n_colors]
        self.color_index = (self.color_index + 1) % n_colors
        return col

    def add_simple_barplot(self,
                           taxon2value,
                           header_name,
                           color=False,
                           show_values=False,
                           substract_min=False,
                           highlight_cutoff=False,
                           highlight_reverse=False,
                           max_value=False):
        """
        Add a horizontal barplot column (optionally preceded by a value column).

        :param taxon2value: leaf name -> numeric value; missing leaves get 0.
            The dictionary is not modified.
        :param header_name: column header text.
        :param color: bar color; defaults to the next default color.
        :param show_values: also add a text column with the (real) value.
        :param substract_min: bars show ``value - min(values)``.
        :param highlight_cutoff: bars whose real value is below the cutoff
            (above it with ``highlight_reverse``) are drawn in grey.
        :param max_value: value mapped to a full bar; defaults to the
            largest (possibly shifted) value.
        """
        if not show_values:
            self._add_header(header_name, column_add=0)
        else:
            self._add_header(header_name, column_add=1)

        values_lists = [float(i) for i in taxon2value.values()]

        min_value = min(values_lists, default=0)

        if substract_min:
            values_lists = [i-min_value for i in values_lists]
            # Work on a shifted copy so the caller's dictionary is untouched.
            taxon2value = {taxon: value - min_value
                           for taxon, value in taxon2value.items()}

        if not color:
            color = self._get_default_barplot_color()

        barplot_column = 1 if show_values else 0
        # Value mapped to a full bar; computed once instead of per leaf.
        bar_max = max_value if max_value else max(values_lists, default=0)

        for lf in self.tree.iter_leaves():

            value = taxon2value.get(lf.name, 0)
            real_value = value + min_value if substract_min else value

            if show_values:
                if isinstance(real_value, float):
                    a = TextFace(" %s " % str(round(real_value,2)))
                else:
                    a = TextFace(" %s " % str(real_value))
                a.margin_top = 1
                a.margin_right = 2
                a.margin_left = 5
                a.margin_bottom = 1
                if self.rotate:
                    a.rotation = 270
                lf.add_face(a, self.column_count, position="aligned")

            # An all-zero column (or all-equal with substract_min) gets empty
            # bars instead of raising ZeroDivisionError.
            fraction_biggest = (float(value)/bar_max)*100 if bar_max else 0
            fraction_rest = 100-fraction_biggest

            lcolor = color
            if highlight_cutoff:
                if highlight_reverse:
                    if real_value > highlight_cutoff:
                        lcolor = "grey"
                elif real_value < highlight_cutoff:
                    lcolor = "grey"

            b = StackedBarFace([fraction_biggest, fraction_rest], width=100, height=15,colors=[lcolor, 'white'])
            b.rotation= 0
            b.inner_border.color = "grey"
            b.inner_border.width = 0
            b.margin_right = 15
            b.margin_left = 0
            if self.rotate:
                b.rotation = 270
            lf.add_face(b, self.column_count + barplot_column, position="aligned")

        self.column_count += (1 + barplot_column)

    def add_barplot_counts(self,):
        """Not implemented yet (no-op)."""
        # todo
        pass

    def remove_dots(self,):
        """Hide the node markers by giving every node a size-0 style."""
        nstyle = NodeStyle()
        nstyle["shape"] = "sphere"
        nstyle["size"] = 0
        nstyle["fgcolor"] = "darkred"

        # Applies the same static style to all nodes in the tree. Note that,
        # if "nstyle" is modified, changes will affect to all nodes
        for n in self.tree.traverse():
            n.set_style(nstyle)

    def add_text_face(self,
                      taxon2text,
                      header_name,
                      color_scale=False):
        """
        Add a text column; leaves missing from ``taxon2text`` show '-'.

        :param color_scale: color cell backgrounds with a categorical scale
            built from the distinct texts.
        """
        from metagenlab_libs.colors import get_categorical_color_scale

        if color_scale:
            value2color = get_categorical_color_scale(taxon2text.values())

        self._add_header(header_name)

        # add column
        for lf in self.tree.iter_leaves():
            if lf.name in taxon2text:
                n = TextFace('%s' % taxon2text[lf.name])
                if color_scale:
                    n.background.color = value2color[taxon2text[lf.name]]
            else:
                # report only the missing leaf, not the whole mapping
                print(lf.name, "not in taxon2text")
                n = TextFace('-')
            n.margin_top = 1
            n.margin_right = 10
            n.margin_left = 10
            n.margin_bottom = 1
            n.opacity = 1.
            if self.rotate:
                n.rotation= 270
            lf.add_face(n, self.column_count, position="aligned")

        self.column_count += 1
示例#30
0
class EteToolCompact():

    '''
    Plot ete3 phylogenetic profiles (compact layout, sizes scaled with the
    number of leaves).

    - self.add_simple_barplot: add a barplot face from taxon2value dictionary
    - self.add_heatmap: add column with cells with value + colored background
    - self.rename_leaves: rename tree leaves from a dictionary (old_name2new_name)
    - self.add_categorical_colorscale_legend: add legend
    - self.add_continuous_colorscale_legend: add legend
    '''

    def __init__(self,
                 tree_file):
        """
        :param tree_file: ete Tree instance, or a newick string / path
            accepted by ``Tree()``. The tree is midpoint-rooted.
        """
        self.column_count = 0

        self.rotate = False

        # accept an already-built tree as well as a newick string / path
        self.tree = tree_file if isinstance(tree_file, Tree) else Tree(tree_file)

        self.tree_length = sum(1 for _ in self.tree.iter_leaves())

        # Size factor for legend text / cells: 1 per 100 leaves.
        self.text_scale = (self.tree_length)*0.01 # math.log2

        self.default_colors = ['#fc8d59', '#91bfdb', '#99d594', '#c51b7d', '#f1a340', '#999999']

        self.color_index = 0

        # Root the tree on its midpoint.
        self.tree.set_outgroup(self.tree.get_midpoint_outgroup())

        self.tss = TreeStyle()
        self.tss.draw_guiding_lines = True
        self.tss.guiding_lines_color = "gray"
        self.tss.show_leaf_name = False
        self.tss.branch_vertical_margin = 0

    def _get_default_barplot_color(self,):
        """Return the next default color, cycling through ``self.default_colors``."""
        n_colors = len(self.default_colors)
        col = self.default_colors[self.color_index % n_colors]
        self.color_index = (self.color_index + 1) % n_colors
        return col

    def _add_header(self,
                    header_name,
                    column_add=0):
        """Add a rotated header face at column ``column_count - 1 + column_add``."""
        n = TextFace(f'{header_name}')
        n.margin_top = 1
        n.margin_right = 1
        n.margin_left = 20
        n.margin_bottom = 1
        n.hz_align = 2
        n.vt_align = 2
        n.rotation = 270
        n.inner_background.color = "white"
        n.opacity = 1.
        # add header
        self.tss.aligned_header.add_face(n, self.column_count-1+column_add)

    def rename_leaves(self,
                      taxon2new_taxon):
        """
        Draw a display label (italic, column 0) for every leaf.

        Leaves missing from ``taxon2new_taxon`` keep their own name as label
        instead of raising KeyError. ``lf.name`` itself is not changed.
        """
        for lf in self.tree.iter_leaves():
            label = taxon2new_taxon.get(lf.name, lf.name)
            n = TextFace(label, fgcolor = "black", fsize = 12, fstyle = 'italic')
            lf.add_face(n, 0)

    def add_continuous_colorscale_legend(self,
                                         title,
                                         min_val,
                                         max_val,
                                         scale):
        """
        Add a min/max color legend.

        :param scale: object returned by ``get_continuous_scale``; its first
            element is used through ``to_rgba`` (presumably a matplotlib
            ScalarMappable — confirm against metagenlab_libs.colors).
        """
        from matplotlib.colors import rgb2hex

        fsize = 4 * self.text_scale
        self.tss.legend.add_face(TextFace(f"{title}", fsize = fsize), column=0)

        if min_val != max_val:
            n = TextFace(" " * int(self.text_scale), fsize = fsize)
            n.margin_top = 1
            n.margin_right = 1
            n.margin_left = 10
            n.margin_bottom = 1
            n.inner_background.color = rgb2hex(scale[0].to_rgba(float(max_val)))

            n2 = TextFace(" " * int(self.text_scale), fsize = fsize)
            n2.margin_top = 1
            n2.margin_right = 1
            n2.margin_left = 10
            n2.margin_bottom = 1
            n2.inner_background.color = rgb2hex(scale[0].to_rgba(float(min_val)))

            self.tss.legend.add_face(n, column=1)
            self.tss.legend.add_face(TextFace(f"{max_val} % (max)", fsize = fsize), column=2)
            self.tss.legend.add_face(n2, column=1)
            self.tss.legend.add_face(TextFace(f"{min_val} % (min)", fsize = fsize), column=2)
        else:
            # single value: one swatch only
            n2 = TextFace(" " * int(self.text_scale), fsize = fsize)
            n2.margin_top = 1
            n2.margin_right = 1
            n2.margin_left = 10
            n2.margin_bottom = 1
            n2.inner_background.color = rgb2hex(scale[0].to_rgba(float(min_val)))

            self.tss.legend.add_face(n2, column=0)
            self.tss.legend.add_face(TextFace(f"{max_val} % Id", fsize = fsize), column=1)

    def add_categorical_colorscale_legend(self,
                                          title,
                                          scale):
        """
        Add a legend with one swatch + label per entry of ``scale``
        (value -> color); wraps after 8 entries per row.
        """
        fsize = 4 * self.text_scale
        self.tss.legend.add_face(TextFace(f"{title}", fsize = fsize), column=0)

        col = 1
        for value in scale:

            n2 = TextFace(" " * int(self.text_scale), fsize = fsize)
            n2.margin_top = 1
            n2.margin_right = 1
            n2.margin_left = 10
            n2.margin_bottom = 1
            n2.inner_background.color = scale[value]

            self.tss.legend.add_face(n2, column=col)
            self.tss.legend.add_face(TextFace(f"{value}", fsize = fsize), column=col+1)

            col+=2
            if col>16:
                # start a new legend row
                self.tss.legend.add_face(TextFace(f"    ", fsize = fsize), column=0)
                col = 1

    def add_simple_barplot(self,
                           taxon2value,
                           header_name,
                           color=False,
                           show_values=False,
                           substract_min=False,
                           max_value=False):
        """
        Add a horizontal barplot column (optionally preceded by a value column).

        :param taxon2value: leaf name -> numeric value; missing leaves get 0.
            The dictionary is not modified.
        :param substract_min: bars (and displayed values) use
            ``value - min(values)``.
        :param max_value: value mapped to a full bar; defaults to the
            largest (possibly shifted) value.
        """
        if not show_values:
            self._add_header(header_name, column_add=0)
        else:
            self._add_header(header_name, column_add=1)

        values_lists = [float(i) for i in taxon2value.values()]

        min_value = min(values_lists, default=0)

        if substract_min:
            values_lists = [i-min_value for i in values_lists]
            # Work on a shifted copy so the caller's dictionary is untouched.
            taxon2value = {taxon: value - min_value
                           for taxon, value in taxon2value.items()}

        if not color:
            color = self._get_default_barplot_color()

        barplot_column = 1 if show_values else 0
        # Value mapped to a full bar; computed once instead of per leaf.
        bar_max = max_value if max_value else max(values_lists, default=0)

        for lf in self.tree.iter_leaves():

            value = taxon2value.get(lf.name, 0)

            if show_values:
                if isinstance(value, float):
                    a = TextFace(" %s " % str(round(value,2)))
                else:
                    a = TextFace(" %s " % str(value))
                a.margin_top = 1
                a.margin_right = 2
                a.margin_left = 5
                a.margin_bottom = 1
                if self.rotate:
                    a.rotation = 270
                lf.add_face(a, self.column_count, position="aligned")

            # An all-zero column gets empty bars instead of ZeroDivisionError.
            fraction_biggest = (float(value)/bar_max)*100 if bar_max else 0
            fraction_rest = 100-fraction_biggest

            b = StackedBarFace([fraction_biggest, fraction_rest],
                               width=100 * (self.text_scale/3),
                               height=18,
                               colors=[color, 'white'])
            b.rotation= 0
            b.margin_right = 10
            b.margin_left = 10
            b.hz_align = 2
            b.vt_align = 2
            b.rotable = False
            if self.rotate:
                b.rotation = 270
            lf.add_face(b, self.column_count + barplot_column, position="aligned")

        self.column_count += (1 + barplot_column)

    def add_heatmap(self,
                    taxon2value,
                    header_name,
                    scale_type="continuous",
                    palette=False):
        """
        Add a column of colored cells (no text) and a matching legend.

        :param taxon2value: leaf name -> value (numeric for "continuous").
        :param header_name: not used by this method (no header is drawn).
        :param scale_type: "continuous" or "categorical".
        :param palette: not used by this method.
        :raises IOError: for an unknown ``scale_type`` (kept for callers).
        """
        from matplotlib.colors import rgb2hex
        from metagenlab_libs.colors import get_categorical_color_scale
        from metagenlab_libs.colors import get_continuous_scale

        if scale_type == "continuous":
            scale = get_continuous_scale(taxon2value.values())
            self.add_continuous_colorscale_legend("Closest hit identity",
                                                  min(taxon2value.values()),
                                                  max(taxon2value.values()),
                                                  scale)
        elif scale_type == "categorical":
            scale = get_categorical_color_scale(taxon2value.values())
            self.add_categorical_colorscale_legend("MLST",
                                                   scale)
        else:
            raise IOError("unknown type")

        cell_text = "   " * int(self.text_scale)
        for lf in self.tree.iter_leaves():
            n = TextFace(cell_text)
            if lf.name in taxon2value:
                value = taxon2value[lf.name]
                if scale_type == "categorical":
                    n.inner_background.color = scale[value]
                else:
                    n.inner_background.color = rgb2hex(scale[0].to_rgba(float(value)))

            n.margin_top = 0
            n.margin_right = 0
            n.margin_left = 10
            n.margin_bottom = 0
            n.opacity = 1.
            if self.rotate:
                n.rotation= 270
            lf.add_face(n, self.column_count, position="aligned")

        self.column_count += 1

    def remove_labels(self,):
        """Add an empty text face in column 0 of every leaf."""
        for lf in self.tree.iter_leaves():
            n = TextFace("")
            lf.add_face(n, 0)
示例#31
0
parser.add_argument(
    '--verbose', action='store_true',
    help=('Print information about the outgroup (if any) taxa to standard '
          'error'))

args = parser.parse_args()

tree = Tree(args.treeFile.read())

if args.outgroupRegex:
    import re
    # re.compile instead of `from re import compile`, which would shadow
    # the builtin compile() for the rest of the module.
    regex = re.compile(args.outgroupRegex)
    taxa = [leaf.name for leaf in tree.iter_leaves() if regex.match(leaf.name)]

    if taxa:
        ca = tree.get_common_ancestor(taxa)
        if args.verbose:
            print('Taxa for outgroup:', taxa, file=sys.stderr)
            print('Common ancestor:', ca.name, file=sys.stderr)
            print('Common ancestor is tree:', tree == ca, file=sys.stderr)

        if len(taxa) == 1:
            tree.set_outgroup(tree & taxa[0])
        elif ca == tree:
            # The matching taxa's common ancestor is the root itself, so it
            # cannot serve as outgroup: fall back to midpoint rooting.
            tree.set_outgroup(tree.get_midpoint_outgroup())
        else:
            # The tree is unchanged since `ca` was computed, so reuse it.
            tree.set_outgroup(ca)

print(tree.get_ascii())
示例#32
0
文件: __main__.py 项目: MDU-PHL/pando
def _pool_size(n_jobs, max_procs):
    """Return the number of worker processes to give a multiprocessing.Pool.

    One worker per job, capped at ``max_procs`` (the rule the original
    if/else pairs implemented), but never fewer than 1: ``Pool(0)`` raises
    ValueError, which happened whenever ARGS.threads was 1 and the cap was
    ``ARGS.threads // 2``.
    """
    return max(1, min(n_jobs, max_procs))


def main():
    '''
    Entry point: dispatch on the chosen subcommand and run the pipeline.

    With no subcommand, print the help and exit; with 'version', print the
    version and exit. Otherwise:

    1. Read the isolate IDs from the LIMS request sheet and find the QC
       folder for each available ID under ARGS.wgs_qc.
    2. Give every isolate a short temporary ID (recorded in
       <base>_temp_names.txt); if andi was requested, symlink each assembly
       into a tempdir under that short name.
    3. If ARGS.metadata_run: gather kraken (contigs and reads), contig and
       read metrics, abricate and MLST results, merge them with the request
       sheet into one super-matrix, print it to stdout as CSV, save it as
       <sheet name>_results.xlsx, and group the isolates by consensus
       species.
    4. If ARGS.andi_run: compute an andi distance matrix, infer an NJ tree
       with Bio.Phylo, make its branch lengths non-negative with ETE3,
       midpoint-root it and write it as Newick.
    5. If ARGS.roary_run: run prokka where its output is missing, then roary
       per consensus-species group (groups are only filled in by step 3),
       midpoint-root the accessory-genome tree, prune roary's output, build a
       FastTree SNV tree and pairwise SNV distances when a collapsed core
       alignment is produced, and export FriPan input files.
    6. Remove the tempdir unless ARGS.keep_tempdirs is set.

    Relies on the module-level ARGS/PARSER and on helper functions defined
    elsewhere in this module. Sets the globals YIELD_FILE, MLST_FILE and
    FORCE_MLST_SCHEME (presumably read by those helpers).
    '''
    global YIELD_FILE
    global MLST_FILE
    global FORCE_MLST_SCHEME
    #Default (non-Nullarbor) QC file names; switched below when
    #ARGS.Nullarbor_folders is set.
    YIELD_FILE = 'yield.tab'
    MLST_FILE = 'mlst.tab'

    #Add MLST schemes to force their usage if that species is encountered
    #Only force schemes if there are two (e.g., A baumannii and E coli)
    FORCE_MLST_SCHEME = {"Acinetobacter baumannii": "abaumannii_2",
                         "Campylobacter jejuni": "campylobacter",
                         #"Citrobacter freundii": "cfreundii",
                         #"Cronobacter": "cronobacter",
                         "Enterobacter cloacae": "ecloacae",
                         "Escherichia coli": "ecoli",
                         #"Klebsiella oxytoca": "koxytoca",
                         #"Klebsiella pneumoniae": "kpneumoniae",
                         #"Pseudomonas aeruginosa": "paeruginosa"
                         "Shigella sonnei": "ecoli",
                         "Salmonella enterica": "senterica",
                         "Vibrio cholerae": "vcholerae"
                        }

    if not ARGS.subparser_name:
        PARSER.print_help()
        sys.exit()

    elif ARGS.subparser_name == 'version':
        from .utils.version import Version
        Version()
        sys.exit()

    else:  # any other subcommand is treated as "run"
        if ARGS.Nullarbor_folders:
            print('Nullarbor folder structure selected.')
            YIELD_FILE = 'yield.clean.tab'
            MLST_FILE = 'mlst2.tab'

        EXCEL_OUT = (f"{os.path.splitext(os.path.basename(ARGS.LIMS_request_sheet))[0]}"
                     f"_results.xlsx")

        n_cpus = cpu_count()
        if ARGS.threads > n_cpus:
            sys.exit(f'Number of requested threads must not exceed {n_cpus}.')

        print(str(ARGS.threads) + ' CPU processors requested.')

        #Check if final slash in manually specified wgs_qc path
        if not ARGS.wgs_qc.endswith('/'):
            print('\n-wgs_qc path is entered as '+ARGS.wgs_qc)
            print('You are missing a final \'/\' on this path.')
            print('Exiting now.\n')
            sys.exit()

        #i) read in the IDs from file
        xls_table = get_isolate_request_IDs(ARGS.LIMS_request_sheet)
        IDs = list(set(xls_table.index.values))

        #TODO (original note): base "should be a global" since other functions
        #use it too; for now it is passed explicitly (e.g. to roary()).
        base = os.path.splitext(ARGS.LIMS_request_sheet)[0]

        #ii) Return a folder path to the QC data for each available ID
        #    using a wildcard search of the ID in IDs in ARGS.wgs_qc path.
        iso_paths = isolates_available(IDs)
        #Drop the path and keep the folder name
        isos = [i.split('/')[-1] for i in iso_paths]

        #iii) make tempdir to store the temp_contigs there for 'andi' analysis.
        assembly_tempdir = make_tempdir()

        #iv) Assign each isolate a short temporary ID; symlink its contigs into
        #the tempdir under that ID only if andi was requested.
        #Translation dict: {original isolate name: short ID from shortened_ID()}
        iso_ID_trans = {}
        #Isolate names grouped by consensus species (filled by the metadata step)
        from collections import defaultdict
        isos_grouped_by_cons_spp = defaultdict(list)
        for iso in isos:
            #Instantiate an Isolate class for each isolate in isolates
            sample = Isolate(iso)
            #Next, we could just use iso_path+/contigs.fa, but that would skip
            #the if os.path.exists() test in sample.assembly(iso).
            assembly_path = sample.assembly()
            short_id = shortened_ID()
            iso_ID_trans[iso] = short_id
            if ARGS.andi_run:
                #NOTE(review): shell command built by string concatenation;
                #paths containing spaces or shell metacharacters will break it.
                cmd = 'ln -s '+assembly_path+' '+assembly_tempdir+'/'+short_id+\
                      '_contigs.fa'
                os.system(cmd)
                print('Creating symlink:', cmd)
        if iso_ID_trans:
            with open(base+'_temp_names.txt', 'w') as tmp_names:
                print('\nTranslated isolate IDs:\nShort\tOriginal')
                for key, value in iso_ID_trans.items():
                    print(value+'\t'+key)
                    tmp_names.write(value+'\t'+key+'\n')

        if ARGS.metadata_run:
            #summary_frames will store all of the metaDataFrames herein
            summary_frames = []
            n_isos = len(isos)
            if n_isos == 0:
                print('\nNo isolates detected in the path '+ARGS.wgs_qc+'.')
                print('Exiting now.\n')
                sys.exit()
            #Kraken set at 2 threads, so 36 processes can run on 72 CPUs.
            #Each pool is used as a context manager so its workers are shut
            #down once the results are in (they previously leaked).
            with Pool(_pool_size(n_isos, ARGS.threads//2)) as p:
                print(f'\nRunning kraken on the assemblies ({ARGS.assembly_name} files):')
                results_k_cntgs = p.map(kraken_contigs_multiprocessing, isos)
            print(results_k_cntgs)
            #concat the dataframe objects
            res_k_cntgs = pd.concat(results_k_cntgs, axis=0, sort=False)
            print('\nKraken_contigs results gathered from kraken on contigs...')

            #The remaining per-isolate jobs use a single process each, so one
            #pool of up to ARGS.threads workers is shared by all of them.
            with Pool(_pool_size(n_isos, ARGS.threads)) as p:
                results_k_reads = p.map(kraken_reads_multiprocessing, isos)
                res_k_reads = pd.concat(results_k_reads, axis=0)
                print('Kraken_reads results gathered from kraken.tab files...')

                results_metrics_contigs = p.map(metricsContigs_multiprocessing, isos)
                res_m_cntgs = pd.concat(results_metrics_contigs, axis=0)
                print('Contig metrics gathered using \'fa -t\'...')

                results_metrics_reads = p.map(metricsReads_multiprocessing, isos)
                res_m_reads = pd.concat(results_metrics_reads, axis=0)
                print('Read metrics gathered from '+YIELD_FILE+' files...')

                results_abricate = p.map(abricate_multiprocessing, isos)
                res_all_abricate = pd.concat(results_abricate, axis=0, sort=False)
                res_all_abricate.fillna('', inplace=True)
                print('Resistome hits gathered from abricate.tab files...')

            summary_frames.extend([res_k_cntgs, res_k_reads, res_m_cntgs,
                                   res_m_reads, res_all_abricate])

            #Build the per-isolate metadata not obtained above (MLST and the
            #kraken read/contig species consensus), isolate by isolate.
            summary_isos = []
            for iso in isos:
                sample = Isolate(iso)
                species_cntgs = res_k_cntgs.loc[iso, 'sp_krkn1_cntgs']
                species_reads = res_k_reads.loc[iso, 'sp_krkn1_reads']
                #Only trust the species call when contigs and reads agree
                species = species_cntgs if species_cntgs == species_reads else 'indet'
                mlst_df = sample.mlst(species, sample.assembly())
                species_consensus = {'sp_krkn_ReadAndContigConsensus':species}
                species_cons_df = pd.DataFrame([species_consensus], index=[iso])
                summary_isos.append(pd.concat([mlst_df, species_cons_df], axis=1))

            #Glue the isolate by isolate metadata into a single df
            summary_isos_df = pd.concat(summary_isos)
            #Glue the dataframes built during multiprocessing processes
            summary_frames_df = pd.concat(summary_frames, axis=1)
            #Finish up with everything in one table!
            metadata_overall = pd.concat([xls_table, summary_isos_df, summary_frames_df],
                                         axis=1, sort=False)

            metadata_overall.fillna('', inplace=True)
            metadata_overall.index.name = 'ISOLATE'
            print('\nMetadata super-matrix:')
            #Print the super-matrix as CSV and save it to an Excel workbook.
            #ExcelWriter as a context manager replaces writer.save(), which
            #was removed in pandas 2.0.
            metadata_overall.to_csv(sys.stdout)
            with pd.ExcelWriter(EXCEL_OUT) as writer:
                metadata_overall.to_excel(writer, sheet_name='Sheet 1',
                                          freeze_panes=(1, 1))
            print(f"\nResults written to {os.path.abspath(EXCEL_OUT)}")

            for k, v in zip(metadata_overall['sp_krkn_ReadAndContigConsensus'],
                            metadata_overall.index):
                isos_grouped_by_cons_spp[k.replace(' ', '_')].append(v)

        #Run andi?
        if ARGS.andi_run:
            andi_mat = 'andi_'+ARGS.model_andi_distance+'dist_'+base+'.mat'
            andi_c = 'nice andi -j -m '+ARGS.model_andi_distance+' -t '+\
                      str(ARGS.threads)+' '+assembly_tempdir+'/*_contigs.fa > '+\
                      andi_mat
            print('\nRunning andi with: \''+andi_c+'\'')
            os.system(andi_c)

            #Read in the andi dist matrix, convert to lower triangle
            dm = read_file_lines(andi_mat)[1:]
            dm = lower_tri(dm)
            #Swap the short IDs in the matrix back to the original isolate
            #names: one dict lookup per name instead of a nested loop.
            short_to_iso = {short_id: iso for iso, short_id in iso_ID_trans.items()}
            for idx, name in enumerate(dm.names):
                dm.names[idx] = short_to_iso.get(name, name)

            #From the distance matrix in dm, infer the NJ tree
            from Bio.Phylo.TreeConstruction import DistanceTreeConstructor
            constructor = DistanceTreeConstructor()
            njtree = constructor.nj(dm)
            njtree.rooted = True
            from Bio import Phylo
            Phylo.write(njtree, 'temp.tre', 'newick')
            from ete3 import Tree
            t = Tree('temp.tre', format=1)
            #Get rid of negative branch lengths (an artefact, not an error, of NJ)
            for node in t.traverse():
                node.dist = abs(node.dist)
            t.set_outgroup(t.get_midpoint_outgroup())
            t_out = base+'_andi_NJ_'+ARGS.model_andi_distance+'dist.nwk.tre'
            t.write(format=1, outfile=t_out)
            print('Final tree (midpoint-rooted, NJ under '+\
                   ARGS.model_andi_distance+' distance) looks like this:')
            #Print the ascii tree
            print(t)
            #Remove the temp.tre
            os.remove('temp.tre')
            print('Tree (NJ under '+ARGS.model_andi_distance+\
                  ' distance, midpoint-rooted) written to '+t_out+'.')

        #Run roary?
        if ARGS.roary_run:
            #Entries of each <base>_<species>_roary folder that survive the
            #clean-up performed after the midpoint-rooted tree is written.
            roary_keepers = [
                            "accessory.header.embl",
                            "accessory.tab",
                            "accessory_binary_genes.fa",
                            "accessory_binary_genes.fa.newick",
                            "accessory_binary_genes_midpoint.nwk.tre",
                            "accessory_graph.dot",
                            "blast_identity_frequency.Rtab",
                            "clustered_proteins",
                            "core_accessory.header.embl",
                            "core_accessory.tab",
                            "core_accessory_graph.dot",
                            "core_gene_alignment.aln",
                            "gene_presence_absence.Ltab.csv",
                            "gene_presence_absence.Rtab",
                            "gene_presence_absence.csv",
                            "number_of_conserved_genes.Rtab",
                            "number_of_genes_in_pan_genome.Rtab",
                            "number_of_new_genes.Rtab",
                            "number_of_unique_genes.Rtab",
                            "pan_genome_reference.fa",
                            "pan_genome_sequences",
                            "summary_statistics.txt"
                            ]
            #Only annotate isolates that have no prokka output folder yet
            params = [(i, 'prokka') for i in isos if not
                      os.path.exists('prokka/'+i)]
            if params:
                print('\nRunning prokka:')
                with Pool(_pool_size(len(params), ARGS.threads//2)) as p:
                    p.map(prokka, params)
            else:
                print('\nProkka files already exist. Let\'s move on to '+\
                      'the roary analysis...')

            #Run Roary on the species_consensus subsets.
            print('Now, let\'s run roary!')
            for k, v in isos_grouped_by_cons_spp.items():
                print(k, v)
                n_isos = len(v)
                if n_isos > 1:
                    roary_dir = base+'_'+k+'_roary'
                    shutil.rmtree(roary_dir, ignore_errors=True)
                    roary(base, k,
                          ' '.join(['prokka/'+iso+'/*.gff' for iso in v]))
                    roary_genes = pd.read_table(roary_dir+'/gene_presence_absence.Rtab',
                                                index_col=0, header=0)
                    roary_genes = roary_genes.transpose()
                    roary_genes.to_csv(roary_dir+'/gene_presence_absence.Ltab.csv',
                                       mode='w', index=True, index_label='name')
                    if n_isos > 2:
                        from ete3 import Tree
                        t = Tree(roary_dir+'/accessory_binary_genes.fa.newick',
                                 format=1)
                        #Get rid of negative branch lengths (an artefact,
                        #not an error, of NJ)
                        for node in t.traverse():
                            node.dist = abs(node.dist)
                        t.set_outgroup(t.get_midpoint_outgroup())
                        t_out = roary_dir+'/accessory_binary_genes_midpoint.nwk.tre'
                        t.write(format=1, outfile=t_out)
                        print('\nWritten midpoint-rooted roary tree.\n')
                        wd = os.getcwd()
                        os.chdir(roary_dir)
                        for f_name in glob.glob('*'):
                            if f_name in roary_keepers:
                                continue
                            #rmtree handles real directories, os.remove handles
                            #files and symlinks. Calling both on every entry
                            #raised FileNotFoundError for each directory.
                            if os.path.isdir(f_name) and not os.path.islink(f_name):
                                shutil.rmtree(f_name, ignore_errors=True)
                            else:
                                os.remove(f_name)
                        os.chdir(wd)
                    else:
                        print('Need more than two isolates to have a meaningful '+\
                              'pangenome tree. No mid-point rooting of the ' +\
                              'pangenome tree performed.')
                    wd = os.getcwd()
                    os.chdir(roary_dir)
                    os.system('python ../collapseSites.py -f core_gene_alignment.aln -i fasta -t '+str(ARGS.threads))
                    if os.path.exists('core_gene_alignment_collapsed.fasta'):
                        os.system('FastTree -nt -gtr < core_gene_alignment_collapsed.fasta > core_gene_FastTree_SNVs.tre')

                        #calc pairwise snp dist and write to file
                        with open('core_gene_alignment_collapsed.fasta', 'r') as inf:
                            from Bio import AlignIO
                            aln = AlignIO.read(inf, 'fasta')
                        #One job per alignment row i, covering columns j <= i
                        #(lower triangle including the diagonal).
                        pairs = [[(aln, i, j) for j in range(i+1)]
                                 for i in range(len(aln))]
                        with Pool(_pool_size(len(pairs), ARGS.threads)) as p:
                            print('Running pw comparisons in parallel...')
                            result = p.map(pw_calc, pairs)
                        summary = pd.concat(result, axis=0, sort=False)
                        summary.fillna('', inplace=True)
                        with open('core_gene_alignment_SNV_distances.tab', 'w') as distmat:
                            summary.to_csv(distmat, mode='w', sep='\t', index=True, index_label='name')

                    #convert roary output to fripan compatible
                    os.system('python ../roary2fripan.py '+base+'_'+k)
                    roary2fripan_strains_file = pd.read_table(base+'_'+k+
                                                              '.strains',
                                                              index_col=0,
                                                              header=0)
                    info_list = [roary2fripan_strains_file,
                                 metadata_overall.loc[v, :]]
                    strains_info_out = pd.concat(info_list, axis=1, sort=False)
                    strains_info_out.to_csv(base+'_'+k+'.strains', mode='w',
                                            sep='\t', index=True,
                                            index_label='ID')
                    print('Updated '+base+'_'+k+'.strains with all metadata.')
                    os.system('cp '+base+'_'+k+'* ~/public_html/fripan')
                    os.chdir(wd)
                else:
                    print('Only one isolate in '+k+'. Need at least 2 isolates '+\
                          'to run roary.  Moving on...')

        #Delete the tempdir unless the user asked to keep it
        if not ARGS.keep_tempdirs:
            shutil.rmtree(assembly_tempdir, ignore_errors=True)
            print('\nDeleted tempdir '+assembly_tempdir+'.')
        else:
            print('\nTempdir '+assembly_tempdir+' not deleted.')

        print('\nRun finished.')