示例#1
0
def _create_tree(tree, fasta, out, color):
    """Render a tree with each leaf's sequence drawn in the aligned column.

    Args:
        tree: input handed to ``Tree`` (presumably a Newick string or path).
        fasta: FASTA file read with ``SeqGroup``; record names are looked up
            by leaf name.
        out: output file passed to ``Tree.render``.
        color: colour file parsed by ``_parse_color_file``; leaves whose name
            is in the parsed mapping get that value as background colour.
    """
    seqs = SeqGroup(fasta, format="fasta")
    t = Tree(tree)
    colors = _parse_color_file(color)
    # get_leaf_names() repeats duplicated names; visit each name only once so
    # every leaf sharing that name receives exactly one sequence face.
    for name in dict.fromkeys(t.get_leaf_names()):
        seq_face = SeqMotifFace(seqs.get_seq(name), seq_format="()")
        # The colour depends only on the name, so look it up once per name.
        has_color = name in colors
        for leaf in t.get_leaves_by_name(name):
            if has_color:
                ns = NodeStyle()
                ns['bgcolor'] = colors[name]
                leaf.set_style(ns)
            leaf.add_face(seq_face, 0, 'aligned')
    t.render(out)
def open_tsv_population_size(tree_file, tsv_file):
    """Attach population sizes from a TSV to tree nodes and log-scale them.

    Each TSV row is ``leaf_1, leaf_2, <ignored>, ne, <ignored>`` (no header).
    When both leaf names are equal, the single leaf of that name gets ``ne``;
    otherwise their common ancestor does.

    Returns:
        tuple: ``({"LogPopulationSize": [...]}, tree)`` where the list holds
        ``log(node.pop_size / root.pop_size)`` for every node in
        ``tree.traverse()`` order. Every node is expected to have received a
        ``pop_size`` from the TSV.
    """
    t = Tree(tree_file, format=1)
    table = pd.read_csv(tsv_file, header=None, sep='\t')
    for _row_index, row in table.iterrows():
        leaf_1, leaf_2, _, ne, _ = row
        if leaf_1 != leaf_2:
            target = t.get_common_ancestor([leaf_1, leaf_2])
        else:
            matches = t.get_leaves_by_name(leaf_1)
            assert (len(matches) == 1)
            target = matches[0]
        target.pop_size = ne

    root_pop_size = float(t.pop_size)
    log_sizes = []
    for node in t.traverse():
        log_sizes.append(np.log(float(node.pop_size) / root_pop_size))
    pop_size_dict = {"LogPopulationSize": log_sizes}
    return pop_size_dict, t
def get_scorpios_aore_tree(gene_list, treefile, outgroups, outgr_gene):
    """
    Loads the AORe gene tree built by SCORPiOs.

    Leaf names in ``treefile`` end with a ``_<species>`` field, which is
    stripped before matching against ``gene_list``. Genes of ``gene_list``
    missing from the tree are grafted, together with a copy of ``outgr_gene``,
    under the ``outgr_gene`` leaf, provided they all belong to outgroup species.

    Args:
        gene_list (dict): dict of gene_names (key) : species_names (value) to keep in the tree
        treefile (str): name of the input tree file
        outgroups (list of str): list of outgroup species to keep/add in tree
        outgr_gene (str): name of the outgroup gene

    Returns:
        ete3.Tree : the loaded tree, or None when a gene of ``gene_list`` is
        missing from the tree and its species is not in ``outgroups``.
    """

    tree = Tree(treefile)
    tleaves = tree.get_leaves()

    # Remove the species name (last '_'-separated field) from every leaf.
    for leaf in tleaves:
        leaf.name = '_'.join(leaf.name.split('_')[:-1])

    tree.prune([i for i in tleaves if i.name in gene_list])
    leaves = {i.name for i in tree.get_leaves()}
    if leaves != set(gene_list.keys()):

        # sorted(): iterating a set gives a hash-dependent order, which would
        # make the grafted children (and the written tree) vary between runs.
        missing = sorted(set(gene_list.keys()).difference(leaves))

        # Decide before touching the tree: only missing outgroup genes can be
        # grafted back, so any other missing gene means no usable tree.
        if any(gene_list[gene] not in outgroups for gene in missing):
            return None  #TODO: print the kind of cases covered here?

        outgr_node = tree.get_leaves_by_name(outgr_gene)[0]
        outgr_t = Tree()
        for gene in missing:
            outgr_t.add_child(name=gene)
        outgr_t.add_child(name=outgr_gene)
        outgr_node.add_child(outgr_t)

    # Re-prune on the current leaves; presumably this tidies up the unary node
    # created by the graft above (TODO confirm against ete3 prune semantics).
    tree.prune(tree.get_leaves())

    return tree
示例#4
0
logger.info("Loaded {} bootstrap trees.".format(len(bootstrap_trees)))

# Calculate the new bootstrap scores.
# For each internal node of main_tree, the new support is the percentage of
# bootstrap trees in which the common ancestor of that node's leaves has
# exactly the same leaf set (i.e. the same clade exists).
if not bootstrap_trees:
    # Avoid dividing by zero; keep the existing supports untouched.
    logger.warning("No bootstrap trees loaded; support values left unchanged.")
else:
    for main_node in main_tree.traverse(strategy="levelorder"):
        if main_node.is_leaf():
            continue
        # Leaf names of *this* clade (not of the whole tree, which would make
        # every clade match trivially).
        clade_leaf_names = main_node.get_leaf_names()
        clade = set(clade_leaf_names)
        matches = 0
        for bs_tree in bootstrap_trees:
            # Get the node objects for the clade's leaves in this bs_tree
            clade_leaf_nodes_in_bs_tree = [
                bs_tree.get_leaves_by_name(leaf_name)[0]
                for leaf_name in clade_leaf_names
            ]
            common_ancestor_node = bs_tree.get_common_ancestor(
                clade_leaf_nodes_in_bs_tree)
            # Clades match when the ancestor has no extra leaves
            if set(common_ancestor_node.get_leaf_names()) == clade:
                matches += 1
        # 100.0 forces float division (also under Python 2)
        new_support = 100.0 * matches / len(bootstrap_trees)
        logger.debug("Support for internal node was {}, now is {}".format(
            main_node.support, new_support))
        main_node.support = new_support
class CoveringTree:
    """Approximate a planar solution set by recursively bisecting rectangles.

    The constraints are ``g1``..``g4`` (``phi`` is their maximum); presumably
    the solution set is ``{x : phi(x) <= 0}`` -- confirm against the paper.
    Starting from one initial rectangle, every rectangle is classified with
    the subclass-provided ``getMaxVal`` / ``getMinVal``:

    * ``getMaxVal < 0`` -> the whole rectangle is part of the solution;
    * ``getMinVal > 0`` -> the rectangle contains no solution;
    * otherwise it is bisected (along its width on even levels, along its
      height on odd levels) unless its diameter is already <= ``idelta``.

    Rectangles are stored as ``Rect`` features on the nodes of an ete3 tree
    whose node names are the tree level.
    """
    ############################################################################################
    # Constructor
    ############################################################################################
    # NOTE: the __metaclass__ attribute only takes effect on Python 2; on
    # Python 3 the abstract methods below are not enforced.
    __metaclass__ = abc.ABCMeta

    def __init__(self, l0, l1_bounds, l2_bounds, idelta=0):
        """Create the initial rectangle, the tree root and the plot.

        Args:
            l0: x-offset of the point ``(l0, 0)`` that ``g3``/``g4`` measure from.
            l1_bounds: ``(min, max)`` radii around the origin used by ``g1``/``g2``.
            l2_bounds: ``(min, max)`` radii around ``(l0, 0)`` used by ``g3``/``g4``.
            idelta: rectangles whose diameter is <= this value are not split.
        """
        # Initialize Base
        self.l0 = l0
        # Initialize the Bounds
        self.__l1_bounds = l1_bounds
        self.__l2_bounds = l2_bounds

        # Define Initial Rectangle P: x from -max(l1) to l0 + max(l2),
        # y from 0 to the smaller of the two maximal radii.
        left = -self.__l1_bounds[1]
        top = 0
        width = self.__l1_bounds[1] + self.l0 + self.__l2_bounds[1]
        height = min(self.__l1_bounds[1], self.__l2_bounds[1])

        # Initialize initial Space where the workspace lies
        self.__Xspace = Rect(left, top, width, height)
        # Initialize the Root
        self.__initTree(self.__Xspace)
        # Initialize the minimal size of the rectangle
        self.__delta = idelta

        # Initialize plotting facilities
        self.__fig = plt.figure()

        self.__ax = self.__fig.add_subplot(111)
        self.__ax.axis('scaled')
        self.__ax.axis([
            self.__Xspace.left, self.__Xspace.right, self.__Xspace.top,
            self.__Xspace.bottom
        ])

        self.__tleveltext = self.__ax.text(
            0.98, 0.9, 'Tree Level = {}'.format(0),
            verticalalignment='center',
            horizontalalignment='right',
            fontsize=13,
            transform=self.__ax.transAxes)

        self.__curdiam = self.__ax.text(
            0.02, 0.9,
            'd(Rectangle) = {}'.format(round(self.__d(self.__Xspace), 4)),
            verticalalignment='center',
            horizontalalignment='left',
            fontsize=13,
            transform=self.__ax.transAxes)

    @abc.abstractmethod
    def getMaxVal(self, xbounds, ybounds, diam):
        """Upper estimate over the box ``xbounds`` x ``ybounds`` of diameter ``diam``.

        A negative result marks the whole rectangle as part of the solution.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def getMinVal(self, xbounds, ybounds, diam):
        """Lower estimate over the box ``xbounds`` x ``ybounds`` of diameter ``diam``.

        A positive result marks the rectangle as containing no solution.
        """
        raise NotImplementedError
############################################################################################
# Private Members
############################################################################################

    def __vSplitter(self, iRect):
        """Split ``iRect`` into left and right halves of equal width."""
        halfWidth = iRect.width / 2.0
        Rleft = Rect(iRect.left, iRect.top, halfWidth, iRect.height)
        Rright = Rect(iRect.left + halfWidth, iRect.top, halfWidth,
                      iRect.height)
        return Rleft, Rright

    def __hSplitter(self, iRect):
        """Split ``iRect`` into two halves of equal height (returned top first)."""
        halfHeight = iRect.height / 2.0
        Rleft = Rect(iRect.left, iRect.top, iRect.width, halfHeight)
        Rright = Rect(iRect.left, iRect.top + halfHeight, iRect.width,
                      halfHeight)
        return Rleft, Rright

    def __d(self, iRect):
        """Diameter (diagonal length) of ``iRect``."""
        return sqrt(iRect.width**2.0 + iRect.height**2.0)

    def g1(self, x):
        """x[0]^2 + x[1]^2 - max(l1)^2."""
        return x[0]**2.0 + x[1]**2.0 - (self.__l1_bounds[1]**2.0)

    def g2(self, x):
        """min(l1)^2 - x[0]^2 - x[1]^2."""
        return self.__l1_bounds[0]**2.0 - (x[0]**2.0) - (x[1]**2.0)

    def g3(self, x):
        """(x[0] - l0)^2 + x[1]^2 - max(l2)^2."""
        return (
            (x[0] - self.l0)**2.0) + (x[1]**2.0) - (self.__l2_bounds[1]**2.0)

    def g4(self, x):
        """min(l2)^2 - (x[0] - l0)^2 - x[1]^2."""
        return self.__l2_bounds[0]**2.0 - ((x[0] - self.l0)**2.0) - (x[1]**2.0)

    def g3m(self, x):
        """Like ``g3`` but measured from the origin, wrapped in a 1-element array."""
        return np.array([(x[0]**2.0) + (x[1]**2.0) - (self.__l2_bounds[1]**2.0)
                         ])

    def g4m(self, x):
        """Like ``g4`` but measured from the origin, wrapped in a 1-element array."""
        return np.array([self.__l2_bounds[0]**2.0 - (x[0]**2.0) - (x[1]**2.0)])

    def phi(self, x):
        """Maximum of the four constraint functions at ``x``."""
        return max(self.g1(x), self.g2(x), self.g3(x), self.g4(x))

    def __analyseRect(self, iRect):
        """Classify ``iRect``.

        Returns:
            (continue_splitting, inrange): ``(False, True)`` when the whole
            rectangle is a solution, ``(False, False)`` when it contains no
            solution, ``(True, False)`` when it must be split further.
        """
        xmin = iRect.left
        xmax = iRect.left + iRect.width
        ymin = iRect.top
        ymax = iRect.top + iRect.height

        maxval = self.getMaxVal((xmin, xmax), (ymin, ymax), self.__d(iRect))
        # The whole rectangle is a part of the solution -> save it
        if (maxval < 0):
            inrange = True
            return False, inrange

        minval = self.getMinVal((xmin, xmax), (ymin, ymax), self.__d(iRect))
        # There is no solution for the rectangle -> get rid of it
        if (minval > 0):
            inrange = False
            return False, inrange

        # The rectangle should be processed further
        return True, False

    def __addToTree(self, motherNode, iRect1, iRect2, childNodeLevel):
        """Attach two child nodes (named after their level) carrying the rectangles."""
        oNode2 = motherNode.add_child(name='{}'.format(childNodeLevel))
        oNode1 = motherNode.add_child(name='{}'.format(childNodeLevel))
        oNode2.add_feature('Rect', iRect2)
        oNode1.add_feature('Rect', iRect1)

    def __getNewRect(self, iRect, level):
        """Bisect along the width on even levels, along the height on odd ones."""
        (oRleft, oRright) = self.__vSplitter(iRect) if (
            level % 2 == 0) else self.__hSplitter(iRect)
        return (oRleft, oRright)

    def __initTree(self, Xspace):
        """Create the tree with a single root (level '0') holding ``Xspace``."""
        self.__sTree = Tree('0;')  # name here is the level of the tree
        motherNode = self.__sTree.search_nodes(name='0')[0]
        motherNode.add_feature('Rect', Xspace)

    def __drawRect(self, iRect, fillIt, PlotEdges=True, inQI=False, inQE=True):
        """Add ``iRect`` to the axes.

        With ``PlotEdges`` the edge colour encodes the classification:
        black for inner rectangles (``inQI``), red for boundary rectangles
        (``inQE`` only), green for rectangles out of range. Without it, the
        rectangle is filled according to ``fillIt`` and drawn without edges.
        """
        if PlotEdges:
            LineStyle = 'solid'
            LineWidth = 1
            # Every flag combination gets a colour; previously inQI without
            # inQE left edgeColor unbound and raised UnboundLocalError.
            if inQI:
                # Internal
                edgeColor = 'black'
                Alpha = 0.3
            elif inQE:
                # External
                edgeColor = 'red'
                Alpha = None
            else:
                # Out of range
                edgeColor = 'green'
                Alpha = None

            self.__ax.add_patch(
                patches.Rectangle(
                    (iRect.left, iRect.top),  # (x,y)
                    iRect.width,  # width
                    iRect.height,  # height
                    fill=inQI,
                    alpha=Alpha,
                    linestyle=LineStyle,
                    edgecolor=edgeColor,
                    lw=LineWidth))
        else:
            self.__ax.add_patch(
                patches.Rectangle(
                    (iRect.left, iRect.top),  # (x,y)
                    iRect.width,  # width
                    iRect.height,  # height
                    fill=fillIt,
                    edgecolor='none'))
        plt.draw()


############################################################################################
# Public Members
############################################################################################

    def getCovering(self, maxLevels, saveasmovie=True):
        """Build the covering tree level by level, for at most ``maxLevels`` levels.

        Classified leaves receive ``Inrange``, ``inQI`` and ``inQE`` features.
        Stops early after a level in which some rectangle was already no
        larger than ``idelta``. ``saveasmovie`` is currently unused.
        """
        cdRect = self.__d(self.__Xspace)
        print('The diameter of the initial rectangle is {}\n'.format(cdRect))

        bExit = False
        for curLevel in range(0, maxLevels):
            print('Processing level {}'.format(curLevel))
            # Leaves are named after their level: fetch this level's rectangles
            curLevelNodes = self.__sTree.get_leaves_by_name(
                name='{}'.format(curLevel))
            for curLevelNode in curLevelNodes:
                oRect = curLevelNode.Rect
                # Track the smallest rectangle diameter seen so far
                if self.__d(oRect) < cdRect:
                    cdRect = self.__d(oRect)
                    print('Current level diameter of the rectangle is {}\n'.format(
                        cdRect))

                inQE = False
                inQI = False
                # see eq. 2.6: d(P^(i)) <= \delta
                if self.__d(oRect) <= self.__delta:
                    # Too small to decide upon -> save it as if it was in range
                    cont = False
                    inrange = True
                    inQE = True
                    inQI = False
                    # Return the result after this level
                    bExit = True
                else:
                    # see eq. 2.4 and 2.5
                    (cont, inrange) = self.__analyseRect(oRect)
                    if inrange:
                        inQI = True
                        inQE = True
                if cont and (curLevel < maxLevels - 1):
                    (oRleft, oRright) = self.__getNewRect(oRect, curLevel)
                    self.__addToTree(curLevelNode, oRleft, oRright,
                                     curLevel + 1)
                else:
                    # NOTE(review): an undecided rectangle on the last level
                    # is stored with Inrange/inQI/inQE all False, i.e. drawn
                    # as out of range -- confirm this is intended.
                    curLevelNode.add_feature('Inrange', inrange)
                    curLevelNode.add_feature('inQI', inQI)
                    curLevelNode.add_feature('inQE', inQE)

            # Rectangles produced on further levels would be too small: stop
            if bExit:
                print('The result is obtained for {} levels'.format(curLevel))
                break

    def saveCoveringAsImage(self, fileName=None, ResOnly=False, Grayscale=False):
        """Draw every leaf rectangle of the covering and save the figure.

        Args:
            fileName: output path. Defaults to
                ``./Images/<date>__HH_MM_SS_covering.jpeg`` using the time of
                the call.
            ResOnly: draw filled rectangles without classification edges.
            Grayscale: save through ``./Images/temp.png`` and convert it with
                ``Image.convert("L")`` before writing ``fileName``.
        """
        if fileName is None:
            # Built per call: a timestamp in the default-argument expression is
            # evaluated once, at class definition, so every call would reuse
            # (and overwrite) the same file.
            now = datetime.datetime.now()
            fileName = './Images/{0}__{1:02d}_{2:02d}_{3:02d}_covering.jpeg'.format(
                now.date(), now.hour, now.minute, now.second)

        plt.cla()

        for leaf in self.__sTree.iter_leaves():
            if ResOnly:
                # Draw the rectangle without edges
                self.__drawRect(leaf.Rect, leaf.Inrange, False)
            else:
                # Draw the rectangle with edges
                self.__drawRect(leaf.Rect, leaf.Inrange, True, leaf.inQI,
                                leaf.inQE)

        plt.draw()
        plt.pause(1)

        if Grayscale:
            self.__fig.savefig('./Images/temp.png', dpi=600)
            Image.open('./Images/temp.png').convert("L").save(fileName)
        else:
            self.__fig.savefig(fileName, dpi=600)
示例#6
0
class Species:
    def __init__(self,
                 path,
                 max_unknowns=200,
                 contigs=3.0,
                 assembly_size=3.0,
                 mash=3.0,
                 assembly_summary=None,
                 processes=1):
        """Represents a collection of genomes in `path`

        Results cached under ``<path>/qc`` by earlier runs (stats, tree,
        distance matrix, failed report and metadata) are loaded when present.

        :param path: Path to the directory of related genomes you wish to analyze.
        :param max_unknowns: Number of allowable unknown bases, i.e. not [ATCG]
        :param contigs: Acceptable deviations from median number of contigs
        :param assembly_size: Acceptable deviations from median assembly size
        :param mash: Acceptable deviations from median MASH distances
        :param assembly_summary: a pandas DataFrame with assembly summary information
        :param processes: Number of CPUs for the worker pools and ``mash dist``
        """
        self.max_unknowns = max_unknowns
        self.contigs = contigs
        self.assembly_size = assembly_size
        self.mash = mash
        self.assembly_summary = assembly_summary
        self.deviation_values = [max_unknowns, contigs, assembly_size, mash]
        self.ncpus = processes
        self.path = os.path.abspath(path)
        self.name = os.path.basename(os.path.normpath(path))
        self.log = logbook.Logger(self.name)
        self.qc_dir = os.path.join(self.path, "qc")
        # Results are kept per tolerance setting, e.g. "200-3.0-3.0-3.0"
        self.label = '-'.join(map(str, self.deviation_values))
        self.qc_results_dir = os.path.join(self.qc_dir, self.label)
        self.passed_dir = os.path.join(self.qc_results_dir, "passed")
        self.stats_path = os.path.join(self.qc_dir, 'stats.csv')
        self.nw_path = os.path.join(self.qc_dir, 'tree.nw')
        self.dmx_path = os.path.join(self.qc_dir, 'dmx.csv')
        self.failed_path = os.path.join(self.qc_results_dir, "failed.csv")
        self.tree_img = os.path.join(self.qc_results_dir, "tree.svg")
        self.summary_path = os.path.join(self.qc_results_dir, "summary.txt")
        self.allowed_path = os.path.join(self.qc_results_dir, "allowed.p")
        self.paste_file = os.path.join(self.qc_dir, 'all.msh')
        # Cached results default to None so they always exist as attributes,
        # whether or not a file from a previous run was found.
        self.tree = None
        self.stats = None
        self.dmx = None
        self.failed_report = None
        if os.path.isfile(self.stats_path):
            self.stats = pd.read_csv(self.stats_path, index_col=0)
        if os.path.isfile(self.nw_path):
            self.tree = Tree(self.nw_path, 1)
        if os.path.isfile(self.failed_path):
            self.failed_report = pd.read_csv(self.failed_path, index_col=0)
        if os.path.isfile(self.dmx_path):
            try:
                self.dmx = pd.read_csv(self.dmx_path, index_col=0, sep="\t")
                self.log.info("Distance matrix read succesfully")
            except pd.errors.EmptyDataError:
                self.log.exception()
        self.metadata_path = os.path.join(self.qc_dir,
                                          "{}_metadata.csv".format(self.name))
        try:
            self.metadata_df = pd.read_csv(self.metadata_path,
                                           index_col="accession")
        except FileNotFoundError:
            self.metadata_df = pd.DataFrame(columns=["accession"])
        self.criteria = ["unknowns", "contigs", "assembly_size", "distance"]
        self.tolerance = {
            "unknowns": max_unknowns,
            "contigs": contigs,
            "assembly_size": assembly_size,
            "distance": mash
        }
        self.passed = self.stats
        self.failed = {}
        self.med_abs_devs = {}
        self.dev_refs = {}
        self.allowed = {"unknowns": max_unknowns}
        self.colors = {
            "unknowns": "red",
            "contigs": "green",
            "distance": "purple",
            "assembly_size": "orange"
        }
        self.genomes = [
            Genome.Genome(genome, self.assembly_summary)
            for genome in self.genome_paths
        ]
        # Needs tree, stats and genomes set above
        self.assess_tree()

    def __str__(self):
        self.message = [
            "Species: {}".format(self.name),
            "Maximum Unknown Bases:  {}".format(self.max_unknowns),
            "Acceptable Deviations,", "Contigs, {}".format(self.contigs),
            "Assembly Size, {}".format(self.assembly_size),
            "MASH: {}".format(self.mash)
        ]
        return '\n'.join(self.message)

    def assess(f):
        # TODO: This can have a more general application if the pickling
        # functionality is implemented elsewhere
        @functools.wraps(f)
        def wrapper(self):
            try:
                assert self.stats is not None
                assert os.path.isfile(self.allowed_path)
                assert (sorted(self.genome_ids().tolist()) == sorted(
                    self.stats.index.tolist()))
                self.complete = True
                with open(self.allowed_path, 'rb') as p:
                    self.allowed = pickle.load(p)
                self.log.info('Already complete')
            except AssertionError:
                self.complete = False
                f(self)

        return wrapper

    def assess_tree(self):
        try:
            assert self.tree is not None
            assert self.stats is not None
            leaf_names = [
                re.sub(".fasta", "", i) for i in self.tree.get_leaf_names()
            ]
            assert (sorted(leaf_names) == sorted(self.stats.index.tolist()) ==
                    sorted(self.genome_ids().tolist()))
            self.tree_complete = True
            self.log.info("Tree already complete")
        except AssertionError:
            self.tree_complete = False

    @property
    def genome_paths(self, ext="fasta"):
        # Why doesn't this work when importing at top of file?
        """Returns a generator for every file ending with `ext`

        :param ext: File extension of genomes in species directory
        :returns: Generator of Genome objects for all genomes in species dir
        :rtype: generator
        """
        return [
            os.path.join(self.path, genome) for genome in os.listdir(self.path)
            if genome.endswith(ext)
        ]

    # @property
    # def genomes(self):
    #     """Returns a generator for every file ending with `ext`

    #     :param ext: File extension of genomes in species directory
    #     :returns: Generator of Genome objects for all genomes in species dir
    #     :rtype: generator
    #     """
    #     return (Genome.Genome(genome, self.assembly_summary) for genome in self.genome_paths)

    @property
    def total_genomes(self):
        return len(list(self.genomes))

    def sketches(self):
        return (i.msh for i in self.genomes)

    def genome_ids(self):
        ids = [i.name for i in self.genomes]
        return pd.Index(ids)

    # may be redundant. see genome_ids attrib
    @property
    def accession_ids(self):
        ids = [
            i.accession_id for i in self.genomes if i.accession_id is not None
        ]
        return ids

    def mash_paste(self):
        """Combine every ``*msh`` sketch in ``self.qc_dir`` into ``self.paste_file``.

        An existing paste file is removed first. If ``mash paste`` does not
        produce the file, an error is logged and ``self.paste_file`` is set
        to None.
        """
        import glob
        if os.path.isfile(self.paste_file):
            os.remove(self.paste_file)
        # Expand the glob in Python (sorted, as a shell would) and run mash
        # without a shell, so paths with spaces or metacharacters are safe.
        sketches = sorted(glob.glob(os.path.join(self.qc_dir, "*msh")))
        cmd = ["mash", "paste", self.paste_file] + sketches
        try:
            Popen(cmd, stderr=DEVNULL).wait()
        except OSError:
            # e.g. mash not installed; handled by the file check below
            self.log.error("Could not execute mash")
        self.log.info("MASH paste completed")
        if not os.path.isfile(self.paste_file):
            self.log.error("MASH paste failed")
            self.paste_file = None

    def mash_dist(self):
        """Compute all-vs-all MASH distances into ``self.dmx`` / ``self.dmx_path``.

        Runs ``mash dist`` on the pasted sketch file, reads the tab-separated
        matrix, replaces the file-path labels with bare genome names (path and
        extension removed) and rewrites ``self.dmx_path``.
        """
        cmd = ["mash", "dist", "-p", str(self.ncpus), "-t",
               str(self.paste_file), str(self.paste_file)]
        # Argument list plus stdout redirection instead of a shell string:
        # paths containing quotes or shell metacharacters are passed verbatim.
        with open(self.dmx_path, "w") as out:
            Popen(cmd, stdout=out, stderr=DEVNULL).wait()
        self.log.info("MASH distance completed")
        self.dmx = pd.read_csv(self.dmx_path, index_col=0, sep="\t")
        # Make distance matrix more readable
        names = [os.path.splitext(i)[0].split('/')[-1] for i in self.dmx.index]
        self.dmx.index = names
        self.dmx.columns = names
        self.dmx.to_csv(self.dmx_path, sep="\t")
        self.log.info("dmx.csv created")

    def sketch_genomes(self):
        """Run ``Genome.sketch_genome`` on every genome file in parallel.

        The pool has ``self.ncpus`` workers and maps over ``genome_paths``.
        """
        with Pool(ncpus=self.ncpus) as workers:
            pool_size_msg = "{} cpus in pool".format(workers.ncpus)
            self.log.info(pool_size_msg)
            workers.map(Genome.sketch_genome, self.genome_paths)
        self.log.info("All genomes sketched")

    def get_tree(self):
        """Build a tree from ``self.dmx`` unless the cached tree is current.

        Uses WPGMA clustering (``scipy.cluster.hierarchy.weighted``) on the
        MASH distances, converts the linkage to Newick via scikit-bio, roots
        the tree at its midpoint when possible and writes it to
        ``self.nw_path``. Leaves are named ``<genome>.fasta``.
        """
        # Use decorator instead of if statement
        if self.tree_complete is False:
            from ete3.coretype.tree import TreeError
            import numpy as np
            from skbio.tree import TreeNode
            from scipy.cluster.hierarchy import weighted
            from scipy.spatial.distance import squareform
            ids = ['{}.fasta'.format(i) for i in self.dmx.index.tolist()]
            # weighted() expects a *condensed* distance vector; a 2-D array is
            # interpreted as raw observation vectors instead. `.values` also
            # replaces DataFrame.as_matrix, removed in pandas 1.0.
            distances = np.asarray(self.dmx.values, dtype=float)
            condensed = squareform(distances, force='tovector', checks=False)
            hclust = weighted(condensed)
            t = TreeNode.from_linkage_matrix(hclust, ids)
            nw = str(t).replace("'", "")
            self.tree = Tree(nw)
            # midpoint root tree
            try:
                self.tree.set_outgroup(self.tree.get_midpoint_outgroup())
            except TreeError:
                self.log.exception()
            self.tree.write(outfile=self.nw_path)

    def get_stats(self):
        """Compute per-genome stats in parallel and cache them as CSV.

        Each worker gets one genome path and the column means of the distance
        matrix; the per-genome results are concatenated into ``self.stats``
        and written to ``self.stats_path``.
        """
        mean_distances = self.dmx.mean()
        paths = self.genome_paths
        with Pool(ncpus=self.ncpus) as pool:
            per_genome = pool.map(Genome.mp_stats, paths,
                                  [mean_distances] * len(paths))
        self.stats = pd.concat(per_genome)
        self.stats.to_csv(self.stats_path)
        self.log.info("Generated stats and wrote to disk")

    def MAD(self, df, col):
        """Get the median absolute deviation for col"""
        MAD = abs(df[col] - df[col].median()).mean()
        return MAD

    def MAD_ref(MAD, tolerance):
        """Get the reference value for median absolute deviation"""
        dev_ref = MAD * tolerance
        return dev_ref

    def bound(df, col, dev_ref):
        lower = df[col].median() - dev_ref
        upper = df[col].median() + dev_ref
        return lower, upper

    def filter_unknown_bases(self):
        """Filter out genomes with too many unknown bases."""
        self.failed["unknowns"] = self.stats.index[
            self.stats["unknowns"] > self.tolerance["unknowns"]]
        self.passed = self.stats.drop(self.failed["unknowns"])
        self.log.info("Analyzed unknowns")

    def check_passed_count(f):
        """
        Count the number of genomes in self.passed.
        Commence with filtering only if self.passed has more than five genomes.
        """
        @functools.wraps(f)
        def wrapper(self, *args):
            if len(self.passed) > 5:
                f(self, *args)
            else:
                self.allowed[args[0]] = ''
                self.failed[args[0]] = ''
                self.log.info("Stopped filtering after {}".format(f.__name__))

        return wrapper

    @check_passed_count
    def filter_contigs(self, criteria):
        """
        Only look at genomes with > 10 contigs to avoid throwing off the
        median absolute deviation.
        Median absolute deviation - Average absolute difference between
        number of contigs and the median for all genomes
        Extract genomes with < 10 contigs to add them back in later.
        Add genomes with < 10 contigs back in
        """
        eligible_contigs = self.passed.contigs[self.passed.contigs > 10]
        not_enough_contigs = self.passed.contigs[self.passed.contigs <= 10]
        # TODO Define separate function for this
        med_abs_dev = abs(eligible_contigs - eligible_contigs.median()).mean()
        self.med_abs_devs["contigs"] = med_abs_dev
        # Define separate function for this
        # The "deviation reference"
        dev_ref = med_abs_dev * self.contigs
        self.dev_refs["contigs"] = dev_ref
        self.allowed["contigs"] = eligible_contigs.median() + dev_ref
        self.failed["contigs"] = eligible_contigs[
            abs(eligible_contigs - eligible_contigs.median()) > dev_ref].index
        eligible_contigs = eligible_contigs[
            abs(eligible_contigs - eligible_contigs.median()) <= dev_ref]
        eligible_contigs = pd.concat([eligible_contigs, not_enough_contigs])
        eligible_contigs = eligible_contigs.index
        self.passed = self.passed.loc[eligible_contigs]
        self.log.info("Analyzed contigs")

    @check_passed_count
    def filter_MAD_range(self, criteria):
        """
        Filter based on median absolute deviation.
        Passing values fall within a lower and upper bound.
        """
        # Get the median absolute deviation
        med_abs_dev = abs(self.passed[criteria] -
                          self.passed[criteria].median()).mean()
        dev_ref = med_abs_dev * self.tolerance[criteria]
        lower = self.passed[criteria].median() - dev_ref
        upper = self.passed[criteria].median() + dev_ref
        allowed_range = (str(int(x)) for x in [lower, upper])
        allowed_range = '-'.join(allowed_range)
        self.allowed[criteria] = allowed_range
        self.failed[criteria] = self.passed[
            abs(self.passed[criteria] -
                self.passed[criteria].median()) > dev_ref].index
        self.passed = self.passed[abs(
            self.passed[criteria] - self.passed[criteria].median()) <= dev_ref]
        self.log.info("Filtered based on median absolute deviation range")

    @check_passed_count
    def filter_MAD_upper(self, criteria):
        """
        Filter based on median absolute deviation.
        Passing values fall under the upper bound.
        """
        # Get the median absolute deviation
        med_abs_dev = abs(self.passed[criteria] -
                          self.passed[criteria].median()).mean()
        dev_ref = med_abs_dev * self.tolerance[criteria]
        upper = self.passed[criteria].median() + dev_ref
        self.failed[criteria] = self.passed[
            self.passed[criteria] > upper].index
        self.passed = self.passed[self.passed[criteria] <= upper]
        upper = "{:.4f}".format(upper)
        self.allowed[criteria] = upper
        self.log.info("Filtered based on MAD upper bound")

    def base_node_style(self):
        """Apply one shared small black sphere style to every tree node.

        Nodes whose name matches ``.*fasta`` additionally get their name drawn
        as an 8pt ``AttrFace`` with a wide right margin.
        """
        from ete3 import NodeStyle, AttrFace
        shared_style = NodeStyle()
        shared_style["shape"] = "sphere"
        shared_style["size"] = 2
        shared_style["fgcolor"] = "black"
        for node in self.tree.traverse():
            node.set_style(shared_style)
            if not re.match('.*fasta', node.name):
                continue
            label = AttrFace('name', fsize=8)
            label.margin_right = 150
            label.margin_left = 3
            node.add_face(label, column=0)
        self.log.info("Applied base node style")

    # Might be better in a layout function
    def style_and_render_tree(self, file_types=None):
        """Add a title and a per-criterion legend, then render the tree.

        The legend has one column per criterion showing its title, allowed
        value, tolerance, number of filtered genomes (from
        ``self.failed_report``) and its colour.

        :param file_types: output extensions; one ``tree.<ext>`` file per entry
            is written to ``self.qc_results_dir``. Defaults to ``["svg"]``.
        """
        from ete3 import TreeStyle, TextFace, CircleFace
        # None instead of a shared mutable list default
        if file_types is None:
            file_types = ["svg"]
        ts = TreeStyle()
        title_face = TextFace(self.name.replace('_', ' '), fsize=20)
        title_face.margin_bottom = 10
        ts.title.add_face(title_face, column=0)
        ts.branch_vertical_margin = 10
        ts.show_leaf_name = False
        # Legend: row labels in column 1, one column per criterion from 2 on
        ts.legend.add_face(TextFace(""), column=1)
        for category in ["Allowed", "Tolerance", "Filtered", "Color"]:
            category = TextFace(category, fsize=8, bold=True)
            category.margin_bottom = 2
            category.margin_right = 40
            ts.legend.add_face(category, column=1)
        for i, criteria in enumerate(self.criteria, 2):
            title = criteria.replace("_", " ").title()
            title = TextFace(title, fsize=8, bold=True)
            title.margin_bottom = 2
            title.margin_right = 40
            cf = CircleFace(4, self.colors[criteria], style="sphere")
            cf.margin_bottom = 5
            # Number of failed genomes rejected by this criterion
            filtered_count = int(
                (self.failed_report.criteria == criteria).sum())
            filtered = TextFace(filtered_count, fsize=8)
            filtered.margin_bottom = 5
            allowed = TextFace(self.allowed[criteria], fsize=8)
            allowed.margin_bottom = 5
            allowed.margin_right = 25
            tolerance = TextFace(self.tolerance[criteria], fsize=8)
            tolerance.margin_bottom = 5
            ts.legend.add_face(title, column=i)
            ts.legend.add_face(allowed, column=i)
            ts.legend.add_face(tolerance, column=i)
            ts.legend.add_face(filtered, column=i)
            ts.legend.add_face(cf, column=i)
        for f in file_types:
            out_tree = os.path.join(self.qc_results_dir, 'tree.{}'.format(f))
            self.tree.render(out_tree, tree_style=ts)
            self.log.info("tree.{} generated".format(f))

    def color_tree(self):
        """Highlight each failed genome's leaf with its criterion colour and render.

        Failed leaves (``<genome>.fasta``) get a size-9 node in
        ``self.colors[criteria]``, on top of the base node style.
        """
        from ete3 import NodeStyle
        self.base_node_style()
        for genome in self.failed_report.index:
            leaf = self.tree.get_leaves_by_name(genome + ".fasta").pop()
            highlight = NodeStyle()
            failed_on = self.failed_report.loc[genome, 'criteria']
            highlight["fgcolor"] = self.colors[failed_on]
            highlight["size"] = 9
            leaf.set_style(highlight)
        self.style_and_render_tree()

    def filter(self):
        """Run every QC filter in order, then persist and report the results.

        The unknown-bases filter runs first, followed by the contig,
        assembly-size and MASH-distance filters. ``self.allowed`` is pickled to
        ``self.allowed_path`` before the summary and failed report are written.
        """
        self.filter_unknown_bases()
        staged_filters = (
            (self.filter_contigs, "contigs"),
            (self.filter_MAD_range, "assembly_size"),
            (self.filter_MAD_upper, "distance"),
        )
        for run_filter, criterion in staged_filters:
            run_filter(criterion)
        with open(self.allowed_path, 'wb') as handle:
            pickle.dump(self.allowed, handle)
            self.log.info("Pickled results of filtering")
        self.summary()
        self.write_failed_report()

    def write_failed_report(self):
        """Build ``self.failed_report`` and write it to ``self.failed_path``.

        The report is a DataFrame indexed by every genome listed in
        ``self.failed``. Its single ``criteria`` column names the filter that
        rejected each genome. Values of ``self.failed`` that are not a pandas
        Index contribute no labels. An example is the ``""`` placeholder stored
        when a filter was skipped.
        """
        from itertools import chain
        if os.path.isfile(self.failed_path):
            os.remove(self.failed_path)
        # Materialise the chain so pandas receives a concrete sequence.
        ixs = list(chain.from_iterable(self.failed.values()))
        self.failed_report = pd.DataFrame(index=ixs, columns=["criteria"])
        for criteria, failed_ix in self.failed.items():
            # isinstance (not ``type(...) ==``) so Index subclasses, e.g. the
            # Int64Index of older pandas, are recorded too.
            if isinstance(failed_ix, pd.Index):
                self.failed_report.loc[failed_ix, 'criteria'] = criteria
        self.failed_report.to_csv(self.failed_path)
        self.log.info("Wrote failed report")

    def summary(self):
        """Write a plain-text QC summary to ``self.summary_path`` and return it.

        Each criterion contributes a heading followed by its allowed value, its
        tolerance and the number of genomes it filtered out, then a ``"\\n"``
        spacer. All lines are joined with newlines.
        """
        sections = (
            ("Unknown Bases", "unknowns"),
            ("Contigs", "contigs"),
            ("Assembly Size", "assembly_size"),
            ("MASH", "distance"),
        )
        lines = [self.name]
        for heading, key in sections:
            lines += [
                heading,
                "Allowed: {}".format(self.allowed[key]),
                "Tolerance: {}".format(self.tolerance[key]),
                "Filtered: {}".format(len(self.failed[key])),
                "\n",
            ]
        text = '\n'.join(lines)
        with open(self.summary_path, "w") as out:
            out.write(text)
            self.log.info("Wrote QC summary")
        return text

    def link_genomes(self):
        """Hard-link every genome that passed QC into ``self.passed_dir``.

        The directory is created if it does not exist. Links that already exist
        are skipped, so re-running is harmless.
        """
        if not os.path.exists(self.passed_dir):
            os.mkdir(self.passed_dir)
        fasta_names = ("{}.fasta".format(g) for g in self.passed.index)
        for fname in fasta_names:
            try:
                os.link(os.path.join(self.path, fname),
                        os.path.join(self.passed_dir, fname))
            except FileExistsError:
                continue
        self.log.info("Links created for genomes that passed QC")

    @assess
    def qc(self):
        """Run the full QC pipeline for this species.

        Ensures the QC directories exist, then sketches and compares genomes
        with MASH, gathers stats, filters, links passing genomes, builds the
        tree and renders it with failures highlighted.
        """
        for directory in (self.qc_dir, self.qc_results_dir):
            if not os.path.isdir(directory):
                os.mkdir(directory)
        self.sketch_genomes()
        self.mash_paste()
        self.mash_dist()
        self.get_stats()
        self.filter()
        self.link_genomes()
        self.get_tree()
        self.color_tree()
        self.log.info("qc command completed")

    def metadata(self):
        """Fetch metadata for genomes missing from ``self.metadata_df``, then save.

        Genomes whose ``accession_id`` is already an index label of
        ``self.metadata_df`` are skipped. For the others, ``get_metadata()`` is
        called and their ``metadata`` records are appended, indexed by the
        ``accession`` field. The combined table is written to
        ``self.metadata_path``.
        """
        new_records = []
        for genome in self.genomes:
            if genome.accession_id in self.metadata_df.index:
                continue
            genome.get_metadata()
            new_records.append(genome.metadata)
        # With nothing new, pd.DataFrame([]) has no "accession" column and
        # set_index would raise KeyError; keep the existing table instead.
        if new_records:
            self.metadata_df = pd.concat(
                [self.metadata_df,
                 pd.DataFrame(new_records).set_index("accession")])
        self.metadata_df.to_csv(self.metadata_path)
        self.log.info("Completed metadata command")
示例#7
0
                        newchar = tchar
            #            return(mindist, newchar) #nope, out!
            #        else:
            #            return([])    #nope, out!
                else:
                    out = myTraverse(ch, chdist, mindist, newchar, ind)
                    if out[0] < mindist and out[1] in "atgcATGC":
                        mindist = out[0]
                        newchar = out[1]
        # print (" ".join(["node in question: ", ch.name, "chdist", str(chdist), "mindist", str(mindist),"newchar", newchar, "ind", str(ind)]))
    return (mindist, newchar)


# For each leaf listed in ambiDict and each position recorded for it, search
# the sister subtrees of the leaf's lineage with myTraverse, which returns a
# (distance, character) pair, and keep the closest hit.
# Assumes ambiDict maps a leaf name to an iterable of sequence positions and
# that `tree` is an ete3 Tree; both are defined outside this snippet — confirm.
for seqid in ambiDict.keys():
    print("seqid " + seqid)
    node = tree.get_leaves_by_name(name=seqid)[0]
    for ind in ambiDict[seqid]:
        tnode = node
        # 1000.0 serves as the initial "no hit yet" distance and "X" as the
        # initial placeholder character.
        mindist = 1000.0
        newchar = "X"
        fixed = False
        # Continue while tnode is not the root and either nothing has been
        # fixed yet or tnode is still closer to the leaf than the best hit.
        # NOTE(review): in this snippet neither `tnode` nor `fixed` is updated
        # inside the while loop; the rest of the body appears truncated. As
        # shown, the loop would not terminate — confirm against the full
        # source.
        while (not tnode.is_root()
               and (not fixed or node.get_distance(tnode) < mindist)):
            siss = tnode.get_sisters()
            for s in siss:
                #  print ("sis "+s.name)
                curdist = node.get_distance(s)
                sismindist, sisnewchar = myTraverse(s, curdist, mindist,
                                                    newchar, ind)
                if sismindist < mindist:
                    mindist = sismindist
示例#8
0
class FamilyOrthologies:
    """
    FamilyOrthologies object containing the outgroup gene, the genes in each orthogroup, all genes
    in the family in the orthologytable, and the corresponding constrained gene tree topology.

    Attributes:
        outgroup_gene (str): name of the outgroup gene
        orthogroup_a, orthogroup_b (list): gene names of the (at most two) orthogroups; the a/b
            labels are arbitrary
        genes_in_orthotable (list): genes of the family listed in the orthology table
        ctree (Tree): constrained topology filled in by `to_constrained_tree`
    """
    def __init__(self):
        """
        Class builder, initialized with empty objects
        """

        self.outgroup_gene = ''
        self.orthogroup_a, self.orthogroup_b = [], []
        self.genes_in_orthotable = []
        self.ctree = Tree()

    def update_orthologies(self, outgroup_gene, orthogroup):
        """
        Adds genes of one orthogroup. a's and b's are arbitrary.

        The first non-empty call fills `orthogroup_a`; later calls overwrite `orthogroup_b`.
        The list is stored as-is (not copied), so genes later appended by
        `update_constrained_tree` are also visible through the caller's list.

        Args:
            outgroup_gene (str): name of the outgroup gene
            orthogroup (list): list of names of genes in one orthogroup
        """

        self.outgroup_gene = outgroup_gene

        if not self.orthogroup_a:
            self.orthogroup_a = orthogroup
        else:
            self.orthogroup_b = orthogroup

    def to_constrained_tree(self):
        """
        Transforms the orthogroups + outgroup into a constrained topology, represented by an ete3
        Tree object.

        Layout: root -> (outgroup gene, dup_3r). With one orthogroup, its genes hang directly
        under dup_3r. With two, dup_3r has internal nodes orthoA and orthoB holding each group.
        """

        #put outgroup gene as outgroup
        self.ctree.add_child(name=self.outgroup_gene)

        #add 3R duplication node
        dup_3r = self.ctree.add_child(name="dup_3r")

        #if only one orthogroup
        if self.orthogroup_a and not self.orthogroup_b:
            for gene in self.orthogroup_a:
                dup_3r.add_child(name=gene)

        #if two orthogroups
        elif self.orthogroup_a and self.orthogroup_b:
            ortho_a = dup_3r.add_child(name="orthoA")
            for gene in self.orthogroup_a:
                ortho_a.add_child(name=gene)

            ortho_b = dup_3r.add_child(name="orthoB")
            for gene in self.orthogroup_b:
                ortho_b.add_child(name=gene)

    def update_constrained_tree(self, leaves_to_place, ensembl_tree):
        """
        Adds, to the constrained tree, leaves that are under the lca in the original subtree and
        were predicted to be in the family (orthotable). These can be, for instance, genes of
        lowcov species that were discarded from the synteny analysis. They will be placed in the
        same orthogroup as its closest neighbour in the original ensembl tree.

        Note: `leaves_to_place` is consumed (emptied) by this method.

        Args:
            leaves_to_place (list): list of the name of genes to add to the ctree.
            ensembl_tree (ete3 Tree): original gene tree.

        """

        place_in_tree = {}
        genes = [
            leaf.name for leaf in self.ctree.get_leaves()
            if leaf.name != self.outgroup_gene
        ]

        #find closest neighbours in original gene tree
        while leaves_to_place:
            gene_to_place = leaves_to_place.pop()
            node = ensembl_tree.get_leaves_by_name(name=gene_to_place)[0]
            place_in_tree[gene_to_place] = gt.closest_gene_in_tree(
                ensembl_tree, node, genes)

        # we place all genes once all neighbours are found, to not impact the position of others
        for gene_to_place, closest in place_in_tree.items():

            min_g = closest.name

            #add gene in the same orthogroup as the closest gene
            sis = self.ctree.get_leaves_by_name(name=min_g)[0]
            sis.add_sister(name=gene_to_place)

            #update orthogroups
            if min_g in self.orthogroup_a:
                self.orthogroup_a.append(gene_to_place)

            elif min_g in self.orthogroup_b:
                self.orthogroup_b.append(gene_to_place)

    def is_multigenic(self):
        """
        Filters multigenic subtrees, where more duplications than just the 3R duplication is
        involved. These families are often full of errors in original gene trees and difficult to
        solve.

        Species are taken as the last '_'-separated field of each gene name.

        Returns:
            bool: Is the subtree multigenic (True) or not (False)
        """

        for orthogroup in (self.orthogroup_a, self.orthogroup_b):
            if not orthogroup:
                continue

            genes_per_species = Counter(gene.split('_')[-1] for gene in orthogroup)

            #average gene per species (total genes == len(orthogroup))
            mean_gene_in_teleost = len(orthogroup) / float(len(genes_per_species))

            #do not correct if more than 1.5 gene per species and more than 2 species
            if mean_gene_in_teleost > 1.5 and len(genes_per_species) > 2:
                return True

        return False
示例#9
0
class Species:
    """A collection of related genome assemblies stored in one directory.

    Provides the QC pipeline: MASH sketching and distances, per-genome stats,
    MAD-based filtering, tree building and rendering, and reporting helpers.
    """

    def __init__(
        self,
        path,
        max_unknowns=200,
        contigs=3.0,
        assembly_size=3.0,
        mash=3.0,
        assembly_summary=None,
        metadata=None,
    ):
        """Represents a collection of genomes in `path`

        :param path: Path to the directory of related genomes you wish to analyze.
        :param max_unknowns: Number of allowable unknown bases, i.e. not [ATCG]
        :param contigs: Acceptable deviations from median number of contigs
        :param assembly_size: Acceptable deviations from median assembly size
        :param mash: Acceptable deviations from median MASH distances
        :param assembly_summary: assembly summary object exposing a ``df``
            DataFrame (see ``biosample_ids``)
        :param metadata: accepted for interface compatibility; not used by the
            constructor (see ``select_metadata``)
        """
        self.path = os.path.abspath(path)
        self.deviation_values = [max_unknowns, contigs, assembly_size, mash]
        self.label = "-".join(map(str, self.deviation_values))
        self.paths = config.Paths(root=Path(self.path),
                                  subdirs=["metadata", ".logs", "qc"])
        self.qc_results_dir = os.path.join(self.paths.qc, self.label)
        if not os.path.isdir(self.qc_results_dir):
            os.mkdir(self.qc_results_dir)
        self.name = os.path.basename(os.path.normpath(path))
        self.log = logbook.Logger(self.name)
        self.max_unknowns = max_unknowns
        self.contigs = contigs
        self.assembly_size = assembly_size
        self.mash = mash
        self.assembly_summary = assembly_summary
        self.qc_dir = os.path.join(self.path, "qc")
        self.passed_dir = os.path.join(self.qc_results_dir, "passed")
        self.stats_path = os.path.join(self.qc_dir, "stats.csv")
        self.nw_path = os.path.join(self.qc_dir, "tree.nw")
        self.dmx_path = os.path.join(self.qc_dir, "dmx.csv")
        self.failed_path = os.path.join(self.qc_results_dir, "failed.csv")
        self.tree_img = os.path.join(self.qc_results_dir, "tree.svg")
        self.summary_path = os.path.join(self.qc_results_dir, "qc_summary.txt")
        self.allowed_path = os.path.join(self.qc_results_dir, "allowed.p")
        self.paste_file = os.path.join(self.qc_dir, "all.msh")
        # Figure out if defining these as None is necessary
        self.tree = None
        self.stats = None
        self.dmx = None
        if os.path.isfile(self.stats_path):
            self.stats = pd.read_csv(self.stats_path, index_col=0)
        if os.path.isfile(self.nw_path):
            self.tree = Tree(self.nw_path, 1)
        if os.path.isfile(self.failed_path):
            self.failed_report = pd.read_csv(self.failed_path, index_col=0)
        if os.path.isfile(self.dmx_path):
            try:
                self.dmx = pd.read_csv(self.dmx_path, index_col=0, sep="\t")
            except pd.errors.EmptyDataError:
                self.log.exception("Failed to read distance matrix")
        self.metadata_path = os.path.join(self.qc_dir,
                                          "{}_metadata.csv".format(self.name))
        self.criteria = ["unknowns", "contigs", "assembly_size", "distance"]
        self.tolerance = {
            "unknowns": max_unknowns,
            "contigs": contigs,
            "assembly_size": assembly_size,
            "distance": mash,
        }
        self.passed = self.stats
        self.failed = {}
        self.med_abs_devs = {}
        self.dev_refs = {}
        self.allowed = {"unknowns": max_unknowns}
        self.colors = {
            "unknowns": "red",
            "contigs": "green",
            "distance": "purple",
            "assembly_size": "orange",
        }
        self.genomes = [
            genome.Genome(path, self.assembly_summary)
            for path in self.genome_paths
        ]

    def __str__(self):
        self.message = [
            "Species: {}".format(self.name),
            "Maximum Unknown Bases:  {}".format(self.max_unknowns),
            "Acceptable Deviations:",
            "Contigs, {}".format(self.contigs),
            "Assembly Size, {}".format(self.assembly_size),
            "MASH: {}".format(self.mash),
        ]
        return "\n".join(self.message)

    def assess(f):
        """Decorator: skip ``f`` when QC results already exist.

        QC counts as complete when the genome names match the index of
        ``self.stats`` and the pickled allowed-values file exists. Explicit
        comparisons replace ``assert`` so the check still works under
        ``python -O``, which strips assertions.
        """
        @functools.wraps(f)
        def wrapper(self):
            try:
                complete = (
                    sorted(self.genome_names.tolist()) == sorted(
                        self.stats.index.tolist())
                    and os.path.isfile(self.allowed_path))
            except AttributeError:
                # e.g. self.stats is still None
                complete = False
            if complete:
                self.log.info("Already complete")
            else:
                f(self)

        return wrapper

    def tree_complete(self):
        """Return True when tree leaves, stats index and genome names agree."""
        try:
            leaf_names = [
                # Anchored, escaped pattern: strip only a trailing ".fasta".
                re.sub(r"\.fasta$", "", i) for i in self.tree.get_leaf_names()
            ]
            return (sorted(leaf_names) == sorted(self.stats.index.tolist()) ==
                    sorted(self.genome_names.tolist()))
        except AttributeError:
            # self.tree or self.stats not loaded yet
            return False

    @property
    def genome_paths(self, ext="fasta"):
        """Return the paths of every file in the species dir ending with `ext`

        :param ext: File extension of genomes in species directory
        :returns: Absolute-or-joined paths of the matching files
        :rtype: list
        """
        return [
            os.path.join(self.path, genome) for genome in os.listdir(self.path)
            if genome.endswith(ext)
        ]

    @property
    def total_genomes(self):
        return len(list(self.genomes))

    @property
    def sketches(self):
        return Path(self.qc_dir).glob("GCA*msh")

    @property
    def total_sketches(self):
        return len(list(self.sketches))

    @property
    def genome_names(self):
        ids = [i.name for i in self.genomes]
        return pd.Index(ids)

    @property
    def biosample_ids(self):
        ids = self.assembly_summary.df.loc[
            self.accession_ids].biosample.tolist()
        return ids

    # may be redundant. see genome_names attrib
    @property
    def accession_ids(self):
        ids = [
            i.accession_id for i in self.genomes if i.accession_id is not None
        ]
        return ids

    def mash_paste(self):
        """Combine all sketches in the QC dir into ``self.paste_file``."""
        if os.path.isfile(self.paste_file):
            os.remove(self.paste_file)
        sketches = os.path.join(self.qc_dir, "*msh")
        cmd = "mash paste {} {}".format(self.paste_file, sketches)
        # shell=True is needed for the "*msh" glob expansion.
        Popen(cmd, shell=True, stderr=DEVNULL).wait()
        if not os.path.isfile(self.paste_file):
            self.log.error("MASH paste failed")
            self.paste_file = None

    def mash_dist(self):
        """Compute the all-vs-all MASH distance matrix into ``self.dmx``."""
        from multiprocessing import cpu_count

        # Leave two cores free, but never ask mash for fewer than one thread.
        ncpus = max(1, cpu_count() - 2)
        cmd = "mash dist -p {} -t '{}' '{}' > '{}'".format(
            ncpus, self.paste_file, self.paste_file, self.dmx_path)
        Popen(cmd, shell=True, stderr=DEVNULL).wait()
        self.dmx = pd.read_csv(self.dmx_path, index_col=0, sep="\t")
        # Make distance matrix more readable
        names = [os.path.splitext(i)[0].split("/")[-1] for i in self.dmx.index]
        self.dmx.index = names
        self.dmx.columns = names
        self.dmx.to_csv(self.dmx_path, sep="\t")

    def mash_sketch(self):
        """Sketch all genomes"""
        with ProcessingPool() as pool:
            pool.map(genome.sketch_genome, self.genome_paths)

    def run_mash(self):
        try:
            self.mash_sketch()
        except Exception:
            self.log.exception("mash sketch failed")
        try:
            self.mash_paste()
        except Exception:
            self.log.exception("mash paste failed")
        try:
            self.mash_dist()
        except Exception:
            self.log.exception("mash dist failed")

    def get_tree(self):
        """Build a midpoint-rooted clustering tree from ``self.dmx``.

        Skipped when ``tree_complete()`` reports the existing tree is current.
        The Newick result is written to ``self.nw_path``.
        """
        if not self.tree_complete():
            from ete3.coretype.tree import TreeError
            import numpy as np
            from skbio.tree import TreeNode
            from scipy.cluster.hierarchy import weighted

            ids = ["{}.fasta".format(i) for i in self.dmx.index.tolist()]
            # DataFrame.as_matrix() was removed in pandas 1.0; .values is the
            # long-standing equivalent.
            # NOTE(review): scipy's weighted() treats a 2-D array as
            # observation vectors rather than a distance matrix; a condensed
            # matrix (scipy.spatial.distance.squareform) may be what was
            # intended. Kept as-is to preserve existing trees — confirm.
            triu = np.triu(self.dmx.values)
            hclust = weighted(triu)
            t = TreeNode.from_linkage_matrix(hclust, ids)
            nw = t.__str__().replace("'", "")
            self.tree = Tree(nw)
            try:
                # midpoint root tree
                self.tree.set_outgroup(self.tree.get_midpoint_outgroup())
            except TreeError:
                self.log.error("Unable to midpoint root tree")
            self.tree.write(outfile=self.nw_path)

    @property
    def stats_files(self):
        return Path(self.qc_dir).glob("GCA*csv")

    def get_stats(self):
        """Get stats for all genomes. Concat the results into a DataFrame"""
        # List the directory once so both pool.map arguments line up.
        paths = self.genome_paths
        # pool.map needs an arg for each function that will be run
        dmx_mean = [self.dmx.mean()] * len(paths)
        with ProcessingPool() as pool:
            results = pool.map(genome.mp_stats, paths, dmx_mean)
        self.stats = pd.concat(results)
        self.stats.to_csv(self.stats_path)

    def MAD(self, df, col):
        """Mean absolute deviation of ``df[col]`` around its median."""
        MAD = abs(df[col] - df[col].median()).mean()
        return MAD

    @staticmethod
    def MAD_ref(MAD, tolerance):
        """Get the reference value for median absolute deviation"""
        dev_ref = MAD * tolerance
        return dev_ref

    @staticmethod
    def bound(df, col, dev_ref):
        """Return (lower, upper) = median of ``df[col]`` -/+ ``dev_ref``."""
        lower = df[col].median() - dev_ref
        upper = df[col].median() + dev_ref
        return lower, upper

    def filter_unknown_bases(self):
        """Filter out genomes with too many unknown bases."""
        self.failed["unknowns"] = self.stats.index[
            self.stats["unknowns"] > self.tolerance["unknowns"]]
        self.passed = self.stats.drop(self.failed["unknowns"])

    # Perform this logic in self.filter
    # Don't use decorator
    def check_passed_count(f):
        """
        Count the number of genomes in self.passed.
        Commence with filtering only if self.passed has more than five genomes.
        Otherwise record "" as both the allowed value and the failed set for
        the criterion passed as the first positional argument.
        """
        @functools.wraps(f)
        def wrapper(self, *args):
            if len(self.passed) > 5:
                f(self, *args)
            else:
                self.allowed[args[0]] = ""
                self.failed[args[0]] = ""
                self.log.info("Not filtering based on {}".format(f.__name__))

        return wrapper

    @check_passed_count
    def filter_contigs(self, criteria):
        """Filter genomes whose contig count deviates too far from the median.

        Only genomes with more than 10 contigs are scored, so that very
        contiguous assemblies do not drag down the deviation. Genomes with 10
        or fewer contigs always pass. The deviation is the mean absolute
        difference from the median, scaled by ``self.contigs``.
        """
        eligible_contigs = self.passed.contigs[self.passed.contigs > 10]
        not_enough_contigs = self.passed.contigs[self.passed.contigs <= 10]
        median = eligible_contigs.median()
        med_abs_dev = abs(eligible_contigs - median).mean()
        self.med_abs_devs["contigs"] = med_abs_dev
        # The "deviation reference"
        dev_ref = med_abs_dev * self.contigs
        self.dev_refs["contigs"] = dev_ref
        self.allowed["contigs"] = median + dev_ref
        deviation = abs(eligible_contigs - median)
        self.failed["contigs"] = eligible_contigs[deviation > dev_ref].index
        eligible_contigs = eligible_contigs[deviation <= dev_ref]
        eligible_contigs = pd.concat([eligible_contigs, not_enough_contigs])
        self.passed = self.passed.loc[eligible_contigs.index]

    @check_passed_count
    def filter_MAD_range(self, criteria):
        """
        Filter based on median absolute deviation.
        Passing values fall within a lower and upper bound.
        """
        median = self.passed[criteria].median()
        deviation = abs(self.passed[criteria] - median)
        dev_ref = deviation.mean() * self.tolerance[criteria]
        lower = median - dev_ref
        upper = median + dev_ref
        self.allowed[criteria] = "-".join(str(int(x)) for x in [lower, upper])
        self.failed[criteria] = self.passed[deviation > dev_ref].index
        self.passed = self.passed[deviation <= dev_ref]

    @check_passed_count
    def filter_MAD_upper(self, criteria):
        """
        Filter based on median absolute deviation.
        Passing values fall under the upper bound.
        """
        # Get the median absolute deviation
        med_abs_dev = abs(self.passed[criteria] -
                          self.passed[criteria].median()).mean()
        dev_ref = med_abs_dev * self.tolerance[criteria]
        upper = self.passed[criteria].median() + dev_ref
        self.failed[criteria] = self.passed[
            self.passed[criteria] > upper].index
        self.passed = self.passed[self.passed[criteria] <= upper]
        upper = "{:.4f}".format(upper)
        self.allowed[criteria] = upper

    def base_node_style(self):
        """Give every node a small black sphere; label the .fasta leaves."""
        from ete3 import NodeStyle, AttrFace

        nstyle = NodeStyle()
        nstyle["shape"] = "sphere"
        nstyle["size"] = 2
        nstyle["fgcolor"] = "black"
        for n in self.tree.traverse():
            n.set_style(nstyle)
            if re.match(".*fasta", n.name):
                nf = AttrFace("name", fsize=8)
                nf.margin_right = 150
                nf.margin_left = 3
                n.add_face(nf, column=0)

    # Might be better in a layout function
    def style_and_render_tree(self, file_types=("svg",)):
        """Render ``self.tree`` with a title and a per-criterion legend.

        :param file_types: iterable of extensions; one ``tree.<ext>`` file is
            written to ``self.qc_results_dir`` per entry.
        """
        from ete3 import TreeStyle, TextFace, CircleFace

        ts = TreeStyle()
        title_face = TextFace(self.name.replace("_", " "), fsize=20)
        title_face.margin_bottom = 10
        ts.title.add_face(title_face, column=0)
        ts.branch_vertical_margin = 10
        ts.show_leaf_name = False
        # Legend
        ts.legend.add_face(TextFace(""), column=1)
        for category in ["Allowed", "Tolerance", "Filtered", "Color"]:
            category = TextFace(category, fsize=8, bold=True)
            category.margin_bottom = 2
            category.margin_right = 40
            ts.legend.add_face(category, column=1)
        for i, criteria in enumerate(self.criteria, 2):
            title = criteria.replace("_", " ").title()
            title = TextFace(title, fsize=8, bold=True)
            title.margin_bottom = 2
            title.margin_right = 40
            cf = CircleFace(4, self.colors[criteria], style="sphere")
            cf.margin_bottom = 5
            filtered_count = int(
                (self.failed_report.criteria == criteria).sum())
            filtered = TextFace(filtered_count, fsize=8)
            filtered.margin_bottom = 5
            allowed = TextFace(self.allowed[criteria], fsize=8)
            allowed.margin_bottom = 5
            allowed.margin_right = 25
            tolerance = TextFace(self.tolerance[criteria], fsize=8)
            tolerance.margin_bottom = 5
            ts.legend.add_face(title, column=i)
            ts.legend.add_face(allowed, column=i)
            ts.legend.add_face(tolerance, column=i)
            ts.legend.add_face(filtered, column=i)
            ts.legend.add_face(cf, column=i)
        for f in file_types:
            out_tree = os.path.join(self.qc_results_dir, "tree.{}".format(f))
            self.tree.render(out_tree, tree_style=ts)

    def color_tree(self):
        """Highlight failed genomes in the criterion's colour and render."""
        from ete3 import NodeStyle

        self.base_node_style()
        for failed_genome in self.failed_report.index:
            n = self.tree.get_leaves_by_name(failed_genome + ".fasta").pop()
            nstyle = NodeStyle()
            nstyle["fgcolor"] = self.colors[self.failed_report.loc[
                failed_genome, "criteria"]]
            nstyle["size"] = 9
            n.set_style(nstyle)
        self.style_and_render_tree()

    def filter(self):
        """Run all filters in order, pickle allowed values, write reports."""
        self.filter_unknown_bases()
        self.filter_contigs("contigs")
        self.filter_MAD_range("assembly_size")
        self.filter_MAD_upper("distance")
        with open(self.allowed_path, "wb") as p:
            pickle.dump(self.allowed, p)
        self.summary()
        self.write_failed_report()

    def write_failed_report(self):
        """Write a genome -> failed-criterion table to ``self.failed_path``."""
        from itertools import chain

        if os.path.isfile(self.failed_path):
            os.remove(self.failed_path)
        ixs = list(chain.from_iterable(self.failed.values()))
        self.failed_report = pd.DataFrame(index=ixs, columns=["criteria"])
        for criteria, failed_ix in self.failed.items():
            # Skips the "" placeholder from check_passed_count; isinstance
            # also accepts Index subclasses.
            if isinstance(failed_ix, pd.Index):
                self.failed_report.loc[failed_ix, "criteria"] = criteria
        self.failed_report.to_csv(self.failed_path)

    def summary(self):
        """Write the plain-text QC summary and return it."""
        summary = [
            self.name,
            "Unknown Bases",
            "Allowed: {}".format(self.allowed["unknowns"]),
            "Tolerance: {}".format(self.tolerance["unknowns"]),
            "Filtered: {}".format(len(self.failed["unknowns"])),
            "\n",
            "Contigs",
            "Allowed: {}".format(self.allowed["contigs"]),
            "Tolerance: {}".format(self.tolerance["contigs"]),
            "Filtered: {}".format(len(self.failed["contigs"])),
            "\n",
            "Assembly Size",
            "Allowed: {}".format(self.allowed["assembly_size"]),
            "Tolerance: {}".format(self.tolerance["assembly_size"]),
            "Filtered: {}".format(len(self.failed["assembly_size"])),
            "\n",
            "MASH",
            "Allowed: {}".format(self.allowed["distance"]),
            "Tolerance: {}".format(self.tolerance["distance"]),
            "Filtered: {}".format(len(self.failed["distance"])),
            "\n",
        ]
        summary = "\n".join(summary)
        with open(self.summary_path, "w") as f:
            f.write(summary)
        return summary

    def link_genomes(self):
        """Hard-link passing genomes into ``self.passed_dir``."""
        if not os.path.exists(self.passed_dir):
            os.mkdir(self.passed_dir)
        for passed_genome in self.passed.index:
            fname = "{}.fasta".format(passed_genome)
            src = os.path.join(self.path, fname)
            dst = os.path.join(self.passed_dir, fname)
            try:
                os.link(src, dst)
            except FileExistsError:
                continue

    @assess
    def qc(self):
        """Run the QC pipeline when there are more than 10 genomes."""
        if self.total_genomes > 10:
            self.run_mash()
            self.get_stats()
            self.filter()
            self.link_genomes()
            self.get_tree()
            self.color_tree()
            self.log.info("QC finished")
            self.report()

    def report(self):
        """Log inconsistencies between genomes, sketches, stats and outputs.

        Checks three things:
        - the counts of .fasta files, MASH sketches and per-genome stats files
          agree (if not, the IDs missing from each pair are logged);
        - the distance matrix exists and is non-empty;
        - the passed directory has at least one entry.

        Problems are logged, never raised. Explicit conditionals replace
        ``assert`` so the checks still run under ``python -O``.
        """
        stats_count = len(list(self.stats_files))
        if not (self.total_genomes == self.total_sketches == stats_count):
            from itertools import combinations

            self.log.error("File counts do not match up.")
            self.log.error(f"{self.total_genomes} total .fasta files")
            self.log.error(f"{self.total_sketches} total sketch .msh files")
            self.log.error(f"{stats_count} total stats .csv files")
            sketches = [genome.Genome.id_(i.as_posix()) for i in self.sketches]
            stats = [genome.Genome.id_(i.as_posix()) for i in self.stats_files]
            genome_ids = [i.accession_id for i in self.genomes]
            ids = [genome_ids, sketches, stats]
            for a, b in combinations(ids, 2):
                for i in set(a) - set(b):
                    self.log.error(i)
        dmx = Path(self.dmx_path)
        if not dmx.is_file():
            self.log.error("Distance matrix not found")
        elif not dmx.stat().st_size:
            self.log.error("Distance matrix is empty")
        passed = Path(self.passed_dir)
        # iterdir() returns a lazy iterator that is always truthy; any()
        # actually checks whether the directory has an entry.
        if not passed.is_dir() or not any(passed.iterdir()):
            self.log.error("Passed directory is empty")

    def select_metadata(self, metadata):
        """Store rows of ``metadata.joined`` for this species' biosamples."""
        try:
            self.metadata = metadata.joined.loc[self.biosample_ids]
            self.metadata.to_csv(self.metadata_path)
        except KeyError:
            self.log.exception("Metadata failed")
示例#10
0
@attr.s
class Species(object):
    """Represents a collection of genomes in `path`

    :param path: Path to the directory of related genomes you wish to analyze.
    :param max_unknowns: Number of allowable unknown bases, i.e. not [ATCG]
    :param contigs: Acceptable deviations from median number of contigs
    :param assembly_size: Acceptable deviations from median assembly size
    :param mash: Acceptable deviations from median MASH distances
    :param assembly_summary: object exposing assembly summary information as a
        pandas DataFrame through its ``df`` attribute
    :param metadata: metadata table; overwritten by :meth:`select_metadata`
    """

    path = attr.ib(default=Path(), converter=Path)
    max_unknowns = attr.ib(default=200)
    # TODO These are really about attrib names
    contigs = attr.ib(default=3.0)
    assembly_size = attr.ib(default=3.0)
    mash = attr.ib(default=3.0)
    assembly_summary = attr.ib(default=None)
    metadata = attr.ib(default=None)

    def __attrs_post_init__(self):
        """Derive output paths, load cached QC results and set up filter state."""
        self.log = logbook.Logger(self.path.name)
        # Identifies this parameter combination, e.g. "200-3.0-3.0-3.0".
        self.label = "-".join(
            map(str, [self.max_unknowns, self.contigs, self.assembly_size, self.mash])
        )
        self.paths = config.Paths(
            root=self.path,
            subdirs=[
                "qc",
                ".logs",
            ],
        )
        self.stats_path = os.path.join(self.paths.qc, "stats.csv")
        self.nw_path = os.path.join(self.paths.qc, "tree.nw")
        self.dmx_path = os.path.join(self.paths.qc, "dmx.csv")
        self.failed_path = os.path.join(self.paths.qc, "failed.csv")
        self.summary_path = os.path.join(self.paths.qc, "qc_summary.txt")
        self.paste_file = os.path.join(self.paths.qc, "all.msh")
        # Figure out if defining these as None is necessary
        self.tree = None
        self.stats = None
        # Reuse the results of a previous run when they exist on disk.
        if os.path.isfile(self.stats_path):
            self.stats = pd.read_csv(self.stats_path, index_col=0)
        if os.path.isfile(self.nw_path):
            self.tree = Tree(self.nw_path, 1)
        if os.path.isfile(self.failed_path):
            self.failed_report = pd.read_csv(self.failed_path, index_col=0)
        # "unknowns" is an absolute limit; the other tolerances multiply the
        # mean absolute deviation computed by the filter_* methods.
        self.tolerance = {
            "unknowns": self.max_unknowns,
            "contigs": self.contigs,
            "assembly_size": self.assembly_size,
            "distance": self.mash,
        }
        # Filtering starts from every genome with stats and narrows it down.
        self.passed = self.stats
        self.failed = {}
        self.med_abs_devs = {}
        self.dev_refs = {}
        self.allowed = {"unknowns": self.max_unknowns}

    def __str__(self):
        """Human-readable description of the QC parameters (also kept in ``self.message``)."""
        self.message = [
            "Species: {}".format(self.path.name),
            "Maximum Unknown Bases:  {}".format(self.max_unknowns),
            "Acceptable Deviations:",
            "Contigs, {}".format(self.contigs),
            "Assembly Size, {}".format(self.assembly_size),
            "MASH: {}".format(self.mash),
        ]
        return "\n".join(self.message)

    @property
    def genome_paths(self, ext="fasta"):
        """Return the paths of the files in `path` whose names end with `ext`.

        Because this is a property, callers cannot pass `ext`; it is always
        "fasta".

        :returns: paths joined onto `path`
        :rtype: list of str
        """
        return [
            os.path.join(self.path, genome)
            for genome in os.listdir(self.path)
            if genome.endswith(ext)
        ]

    @property
    def sketches(self):
        """Lazily glob the per-genome MASH sketches (``GCA*msh``) in the qc dir."""
        return Path(self.paths.qc).glob("GCA*msh")

    @property
    def total_sketches(self):
        """Number of per-genome MASH sketch files currently in the qc dir."""
        return len(list(self.sketches))

    # NOTE(review): ``self.genomes`` is not assigned anywhere in this class;
    # genome_names and accession_ids depend on it being provided elsewhere.
    @property
    def genome_names(self):
        """Index of the ``name`` attribute of every genome."""
        ids = [i.name for i in self.genomes]
        return pd.Index(ids)

    @property
    def biosample_ids(self):
        """BioSample ids for ``accession_ids``, looked up in the assembly summary."""
        ids = self.assembly_summary.df.loc[self.accession_ids].biosample.tolist()
        return ids

    # may be redundant. see genome_names attrib
    @property
    def accession_ids(self):
        """Accession ids of all genomes that have one."""
        ids = [i.accession_id for i in self.genomes if i.accession_id is not None]
        return ids

    def get_tree(self):
        """Build a WPGMA tree from ``self.dmx``, midpoint-root it and save it.

        The tree is stored on ``self.tree`` and written to ``self.nw_path``.
        Failure to midpoint root is logged, not raised.
        """
        from ete3.coretype.tree import TreeError
        import numpy as np
        from skbio.tree import TreeNode
        from scipy.cluster.hierarchy import weighted

        ids = self.dmx.index.tolist()
        # ``DataFrame.as_matrix`` was removed in pandas 1.0; ``.values`` is the
        # equivalent that works on old and new pandas alike.
        # NOTE(review): scipy linkage functions treat a 2-D array as
        # observation vectors, not as a square distance matrix; if dmx holds
        # pairwise distances, ``squareform`` may be what was intended - confirm.
        triu = np.triu(self.dmx.values)
        hclust = weighted(triu)
        t = TreeNode.from_linkage_matrix(hclust, ids)
        nw = t.__str__().replace("'", "")
        self.tree = Tree(nw)
        try:
            # midpoint root tree
            self.tree.set_outgroup(self.tree.get_midpoint_outgroup())
        except TreeError:
            self.log.error("Unable to midpoint root tree")
        self.tree.write(outfile=self.nw_path)

    @property
    def stats_files(self):
        """Lazily glob the per-genome stats files (``GCA*csv``) in the qc dir."""
        return Path(self.paths.qc).glob("GCA*csv")

    def MAD(self, df, col):
        """Return the mean absolute deviation of ``df[col]`` from its median.

        Despite the name, this is not the classical median absolute deviation,
        which takes the median (not the mean) of the absolute deviations.
        """
        return abs(df[col] - df[col].median()).mean()

    @staticmethod
    def MAD_ref(MAD, tolerance):
        """Get the reference value for median absolute deviation (``MAD * tolerance``)."""
        return MAD * tolerance

    @staticmethod
    def bound(df, col, dev_ref):
        """Return ``(lower, upper)``: the median of ``df[col]`` minus/plus ``dev_ref``."""
        median = df[col].median()
        return median - dev_ref, median + dev_ref

    def filter_unknown_bases(self):
        """Filter out genomes with too many unknown bases."""
        self.failed["unknowns"] = self.stats.index[
            self.stats["unknowns"] > self.tolerance["unknowns"]
        ]
        self.passed = self.stats.drop(self.failed["unknowns"])

    # TODO Don't use decorator; perform this logic in self.filter
    def check_passed_count(f):
        """
        Count the number of genomes in self.passed.
        Commence with filtering only if self.passed has more than five genomes.

        Otherwise the criterion (the first positional argument) is recorded as
        skipped: its ``allowed`` and ``failed`` entries are set to "".
        """

        @functools.wraps(f)
        def wrapper(self, *args):
            if len(self.passed) > 5:
                f(self, *args)
            else:
                self.allowed[args[0]] = ""
                self.failed[args[0]] = ""
                self.log.info("Not filtering based on {}".format(f.__name__))

        return wrapper

    # todo remove unnecessary criteria parameter
    @check_passed_count
    def filter_contigs(self, criteria):
        """
        Only look at genomes with > 10 contigs to avoid throwing off the deviation.
        Mean absolute deviation from the median: average absolute difference
        between each genome's number of contigs and the median over all
        eligible genomes. Genomes with <= 10 contigs are set aside and always
        added back to the passed set.
        """
        eligible_contigs = self.passed.contigs[self.passed.contigs > 10]
        not_enough_contigs = self.passed.contigs[self.passed.contigs <= 10]
        median = eligible_contigs.median()
        deviation = abs(eligible_contigs - median)
        med_abs_dev = deviation.mean()
        self.med_abs_devs["contigs"] = med_abs_dev
        # The "deviation reference"
        dev_ref = med_abs_dev * self.contigs
        self.dev_refs["contigs"] = dev_ref
        self.allowed["contigs"] = median + dev_ref
        self.failed["contigs"] = eligible_contigs[deviation > dev_ref].index
        eligible_contigs = eligible_contigs[deviation <= dev_ref]
        eligible_contigs = pd.concat([eligible_contigs, not_enough_contigs])
        self.passed = self.passed.loc[eligible_contigs.index]

    @check_passed_count
    def filter_MAD_range(self, criteria):
        """
        Filter based on mean absolute deviation from the median.
        Passing values fall within a lower and upper bound.
        """
        values = self.passed[criteria]
        median = values.median()
        deviation = abs(values - median)
        med_abs_dev = deviation.mean()
        dev_ref = med_abs_dev * self.tolerance[criteria]
        self.med_abs_devs[criteria] = med_abs_dev
        self.dev_refs[criteria] = dev_ref
        lower = median - dev_ref
        upper = median + dev_ref
        # Stored as "lower-upper", truncated to integers, for reporting.
        self.allowed[criteria] = "-".join(str(int(x)) for x in [lower, upper])
        self.failed[criteria] = self.passed[deviation > dev_ref].index
        self.passed = self.passed[deviation <= dev_ref]

    @check_passed_count
    def filter_MAD_upper(self, criteria):
        """
        Filter based on mean absolute deviation from the median.
        Passing values fall under the upper bound.
        """
        values = self.passed[criteria]
        med_abs_dev = abs(values - values.median()).mean()
        dev_ref = med_abs_dev * self.tolerance[criteria]
        self.med_abs_devs[criteria] = med_abs_dev
        self.dev_refs[criteria] = dev_ref
        upper = values.median() + dev_ref
        self.failed[criteria] = self.passed[values > upper].index
        self.passed = self.passed[values <= upper]
        # Stored as a 4-decimal string for reporting.
        self.allowed[criteria] = "{:.4f}".format(upper)

    def base_node_style(self):
        """Give every node a small black sphere and label nodes whose name matches ``.*fasta``."""
        from ete3 import NodeStyle, AttrFace

        nstyle = NodeStyle()
        nstyle["shape"] = "sphere"
        nstyle["size"] = 2
        nstyle["fgcolor"] = "black"
        for n in self.tree.traverse():
            n.set_style(nstyle)
            if re.match(r".*fasta", n.name):
                nf = AttrFace("name", fsize=8)
                nf.margin_right = 150
                nf.margin_left = 3
                n.add_face(nf, column=0)

    # Might be better in a layout function
    def style_and_render_tree(self, file_types=None):
        """Render the tree with a title and a per-criterion legend.

        :param file_types: output extensions; one ``tree.<ext>`` file is
            written to the qc directory per entry (default: ``("svg",)``)
        """
        from ete3 import TreeStyle, TextFace, CircleFace

        if file_types is None:
            file_types = ("svg",)
        ts = TreeStyle()
        # Title is the species directory name, as used by __str__ and summary.
        title_face = TextFace(self.path.name, fsize=20)
        title_face.margin_bottom = 10
        ts.title.add_face(title_face, column=0)
        ts.branch_vertical_margin = 10
        ts.show_leaf_name = True
        # Legend
        ts.legend.add_face(TextFace(""), column=1)
        for category in ["Allowed", "Deviations", "Filtered", "Color"]:
            category = TextFace(category, fsize=8, bold=True)
            category.margin_bottom = 2
            category.margin_right = 40
            ts.legend.add_face(category, column=1)
        for i, criteria in enumerate(CRITERIA, 2):
            title = criteria.replace("_", " ").title()
            title = TextFace(title, fsize=8, bold=True)
            title.margin_bottom = 2
            title.margin_right = 40
            cf = CircleFace(4, COLORS[criteria], style="sphere")
            cf.margin_bottom = 5
            # Number of genomes that failed on this criterion.
            filtered_count = int((self.failed_report.criteria == criteria).sum())
            filtered = TextFace(filtered_count, fsize=8)
            filtered.margin_bottom = 5
            allowed = TextFace(self.allowed[criteria], fsize=8)
            allowed.margin_bottom = 5
            allowed.margin_right = 25
            # TODO Prevent tolerance from rendering as a float
            tolerance = TextFace(self.tolerance[criteria], fsize=8)
            tolerance.margin_bottom = 5
            ts.legend.add_face(title, column=i)
            ts.legend.add_face(allowed, column=i)
            ts.legend.add_face(tolerance, column=i)
            ts.legend.add_face(filtered, column=i)
            ts.legend.add_face(cf, column=i)
        for f in file_types:
            out_tree = os.path.join(self.paths.qc, "tree.{}".format(f))
            self.tree.render(out_tree, tree_style=ts)

    def color_tree(self):
        """Highlight each failed genome's leaf in its criterion's color, then render."""
        from ete3 import NodeStyle

        self.base_node_style()
        for failed_genome in self.failed_report.index:
            n = self.tree.get_leaves_by_name(failed_genome).pop()
            nstyle = NodeStyle()
            nstyle["fgcolor"] = COLORS[
                self.failed_report.loc[failed_genome, "criteria"]
            ]
            nstyle["size"] = 9
            n.set_style(nstyle)
        self.style_and_render_tree()

    def filter(self):
        """Apply every filter in order, then write the summary and the failed report."""
        self.filter_unknown_bases()
        self.filter_contigs("contigs")
        self.filter_MAD_range("assembly_size")
        self.filter_MAD_upper("distance")
        self.summary()
        self.write_failed_report()

    def write_failed_report(self):
        """Write ``failed.csv`` mapping each filtered genome to the criterion it failed."""
        if os.path.isfile(self.failed_path):
            os.remove(self.failed_path)
        # Skipped criteria hold "" (see check_passed_count) and add no ids.
        ixs = list(chain.from_iterable(self.failed.values()))
        self.failed_report = pd.DataFrame(index=ixs, columns=["criteria"])
        for criteria, failed_ids in self.failed.items():
            # isinstance also accepts Index subclasses an exact type check skips.
            if isinstance(failed_ids, pd.Index):
                self.failed_report.loc[failed_ids, "criteria"] = criteria
        self.failed_report.to_csv(self.failed_path)

    def summary(self):
        """Write a plain-text filtering summary to ``summary_path`` and return it."""
        sections = [
            ("Unknown Bases", "unknowns"),
            ("Contigs", "contigs"),
            ("Assembly Size", "assembly_size"),
            ("MASH", "distance"),
        ]
        summary = [self.path.name]
        for title, key in sections:
            summary += [
                title,
                f"Allowed: {self.allowed[key]}",
                f"Tolerance: {self.tolerance[key]}",
                f"Filtered: {len(self.failed[key])}",
                "\n",
            ]
        summary = "\n".join(summary)
        with open(self.summary_path, "w") as f:
            f.write(summary)
        return summary

    #TODO Should probably use relative paths here
    def link_genomes(self):
        """Symlink every genome that passed QC into the qc directory.

        Links that already exist are left untouched.
        """
        # NOTE(review): ``root``, ``section``, ``group``, ``summary``,
        # ``parse_genome_id`` and ``rename_genome`` are not defined in this
        # class and must be module-level names - confirm they exist, otherwise
        # this method raises NameError.
        qc_dir = Path(self.paths.qc)
        for passed_genome in self.passed.index:
            genome_id = parse_genome_id(passed_genome)
            src = (root / section / group / f"{genome_id}/{passed_genome}").absolute()
            name = rename_genome(passed_genome, summary)
            dst = (qc_dir / name).absolute()
            try:
                dst.symlink_to(src)
            except FileExistsError:
                continue

    def qc(self):
        """Filter, link passing genomes, build the tree and color it."""
        self.filter()
        self.link_genomes()
        self.get_tree()
        self.color_tree()
        self.log.info("QC finished")

    def select_metadata(self, metadata):
        """Store and save the rows of ``metadata.joined`` for ``biosample_ids``.

        A ``KeyError`` is logged with its traceback instead of raised.
        """
        # NOTE(review): ``self.metadata_path`` is never set in this class;
        # confirm it is assigned elsewhere before this is called.
        try:
            self.metadata = metadata.joined.loc[self.biosample_ids]
            self.metadata.to_csv(self.metadata_path)
        except KeyError:
            self.log.exception("Metadata failed")
示例#11
0
def deepbiome_draw_phylogenetic_tree(
        log,
        network_info,
        path_info,
        num_classes,
        file_name="%%inline",
        img_w=500,
        branch_vertical_margin=20,
        arc_start=0,
        arc_span=360,
        node_name_on=True,
        name_fsize=10,
        tree_weight_on=True,
        tree_weight=None,
        tree_level_list=None,
        weight_opacity=0.4,
        weight_max_radios=10,
        phylum_background_color_on=True,
        phylum_color=None,
        phylum_color_legend=False,
        show_covariates=True,
        verbose=True):
    """
    Draw phylogenetic tree

    Parameters
    ----------
    log (logging instance) :
        python logging instance for logging
    network_info (dictionary) :
        python dictionary with network_information
    path_info (dictionary):
        python dictionary with path_information
    num_classes (int):
        number of classes for the network. 0 for regression, 1 for binary classification.
    file_name (str):
        name of the figure for save.
        - "*.png", "*.jpg"
        - "%%inline" for notebook inline output.
        default="%%inline"
    img_w (int):
        image width (pt)
        default=500
    branch_vertical_margin (int):
        vertical margin for branch
        default=20
    arc_start (int):
        angle that arc start
        default=0
    arc_span (int):
        total amount of angle for the arc span
        default=360
    node_name_on (boolean):
        show the name of the last leaf node if True
        default=True
    name_fsize (int):
        font size for the name of the last leaf node
        default=10
    tree_weight_on (boolean):
        show the amount and the direction of the weight for each edge in the tree by circle size and color.
        default=True
    tree_weight (sequence of DataFrame-like weight tables):
        reference tree weights, one table per level; their `.columns` / `.index`
        provide the node names. Required despite the None default.
        default=None
    tree_level_list (list):
        name of each level of the given reference tree weights
        default=None, meaning ['Genus', 'Family', 'Order', 'Class', 'Phylum']
    weight_opacity  (float):
        opacity for weight circle
        default= 0.4
    weight_max_radios (int):
        maximum radius for weight circle
        default= 10
    phylum_background_color_on (boolean):
        show the background color for each phylum based on `phylum_color`.
        default= True
    phylum_color (list):
        specify the list of background colors for phylum level. If `phylum_color` is empty or None,
        distinct CSS4 colors are picked at random for each phylum.
        default= None
    phylum_color_legend (boolean):
        show the legend for the background colors for phylum level
        default= False
    show_covariates (boolean):
        show the effect of the covariates
        default= True
    verbose (boolean):
        show the log if True
        default=True

    Returns
    -------
    The value returned by ete3 `Tree.render` for `file_name`.

    Raises
    ------
    TypeError
        if `tree_weight` is None.

    Examples
    --------
    Draw phylogenetic tree

    deepbiome_draw_phylogenetic_tree(log, network_info, path_info, num_classes, file_name = "%%inline")
    """
    # None sentinels instead of shared mutable list defaults.
    if tree_level_list is None:
        tree_level_list = ['Genus', 'Family', 'Order', 'Class', 'Phylum']
    if phylum_color is None:
        phylum_color = []
    if tree_weight is None:
        # The tree layout is read from tree_weight below, so fail early with a
        # clear message instead of deep inside the construction loop.
        raise TypeError("tree_weight is required to draw the phylogenetic tree")

    # for tree figure (https://github.com/etetoolkit/ete/issues/381)
    os.environ['QT_QPA_PLATFORM'] = 'offscreen'
    reader_class = getattr(readers,
                           network_info['model_info']['reader_class'].strip())
    reader = reader_class(log, path_info, verbose=verbose)
    data_path = path_info['data_info']['data_path']
    try:
        count_path = path_info['data_info']['count_path']
        x_list = np.array(
            pd.read_csv(path_info['data_info']['count_list_path'],
                        header=None).iloc[:, 0])
        x_path = np.array([
            '%s/%s' % (count_path, x_name)
            for x_name in x_list if '.csv' in x_name
        ])
    except Exception:
        # No usable count-file listing: fall back to the single x_path file.
        x_path = np.array(
            ['%s/%s' % (data_path, path_info['data_info']['x_path'])])

    reader.read_dataset(x_path[0], None, 0)

    network_class = getattr(
        build_network, network_info['model_info']['network_class'].strip())
    network = network_class(network_info,
                            path_info,
                            log,
                            fold=0,
                            num_classes=num_classes,
                            tree_level_list=tree_level_list,
                            is_covariates=reader.is_covariates,
                            covariate_names=reader.covariate_names,
                            verbose=False)

    if len(phylum_color) == 0:
        colors_name = list(mcolors.CSS4_COLORS.keys())
        if reader.is_covariates and show_covariates:
            n_phylum = network.phylogenetic_tree_info[
                'Phylum_with_covariates'].unique().shape[0]
        else:
            n_phylum = network.phylogenetic_tree_info['Phylum'].unique().shape[0]
        # Sample without replacement when possible so every phylum gets a
        # distinct background color.
        phylum_color = np.random.choice(colors_name,
                                        n_phylum,
                                        replace=n_phylum > len(colors_name))

    basic_st = NodeStyle()
    basic_st['size'] = weight_max_radios * 0.5
    basic_st['shape'] = 'circle'
    basic_st['fgcolor'] = 'black'

    t = Tree()
    root_st = NodeStyle()
    root_st["size"] = 0
    t.set_style(root_st)

    # level name -> {node name -> tree node}
    tree_node_dict = {}
    tree_node_dict['root'] = t

    # The last level in tree_level_list hangs directly off the root.
    lower_class = tree_level_list[-1]
    lower_layer_names = tree_weight[-1].columns.to_list()

    layer_tree_node_dict = {}
    phylum_color_dict = {}
    for j, val in enumerate(lower_layer_names):
        # add_child returns the new node; no need to search the tree for it.
        leaf_t = t.add_child(name=val)
        leaf_t.set_style(basic_st)
        layer_tree_node_dict[val] = leaf_t
        if lower_class == 'Phylum' and phylum_background_color_on:
            phylum_st = copy.deepcopy(basic_st)
            phylum_st["bgcolor"] = phylum_color[j]
            phylum_color_dict[val] = phylum_color[j]
            leaf_t.set_style(phylum_st)
    tree_node_dict[lower_class] = layer_tree_node_dict
    upper_class = lower_class
    upper_layer_names = lower_layer_names

    # Walk the remaining levels from coarse to fine, attaching children.
    for i in range(len(tree_level_list) - 1):
        lower_class = tree_level_list[-2 - i]
        if upper_class == 'Disease' and not show_covariates:
            lower_layer_names = network.phylogenetic_tree_info[
                lower_class].unique()
        else:
            lower_layer_names = tree_weight[-i - 1].index.to_list()

        if tree_weight_on:
            # Depends only on the level, so scale once per level instead of
            # once per child node. np.array copies, leaving tree_weight intact.
            edge_weights = np.array(tree_weight[-1 - i])
            edge_weights *= (weight_max_radios / np.max(edge_weights))

        layer_tree_node_dict = {}
        for val in upper_layer_names:
            parent_t = tree_node_dict[upper_class][val]
            if upper_class == 'Disease':
                child_class = lower_layer_names
            else:
                child_class = network.phylogenetic_tree_info[lower_class][
                    network.phylogenetic_tree_info[upper_class] ==
                    val].unique()

            for k, child_val in enumerate(child_class):
                leaf_t = parent_t.add_child(name=child_val)
                if lower_class == 'Phylum' and phylum_background_color_on:
                    phylum_st = copy.deepcopy(basic_st)
                    phylum_st["bgcolor"] = phylum_color[k]
                    phylum_color_dict[child_val] = phylum_color[k]
                    leaf_t.set_style(phylum_st)
                else:
                    leaf_t.set_style(basic_st)
                if tree_weight_on:
                    if upper_class == 'Disease':
                        upper_num = 0
                    else:
                        upper_num = network.phylogenetic_tree_dict[
                            upper_class][val]
                    if upper_class == 'Disease' and reader.is_covariates and show_covariates:
                        lower_num = network.phylogenetic_tree_dict[
                            '%s_with_covariates' % lower_class][child_val]
                    else:
                        lower_num = network.phylogenetic_tree_dict[
                            lower_class][child_val]
                    leaf_t.add_features(weight=edge_weights[lower_num,
                                                            upper_num])
                layer_tree_node_dict[child_val] = leaf_t
        tree_node_dict[lower_class] = layer_tree_node_dict
        upper_class = lower_class
        upper_layer_names = lower_layer_names

    def layout(node):
        """ete3 layout function: weight circles on edges, names on leaves."""
        if "weight" in node.features:
            # Circle radius is the node's scaled "weight" feature; blue for
            # positive weights, red otherwise.
            color = "RoyalBlue" if node.weight > 0 else "Red"
            C = CircleFace(radius=node.weight, color=color, style="circle")
            # Let's make the sphere transparent
            C.opacity = weight_opacity
            # And place as a float face over the tree
            faces.add_face_to_node(C, node, 0, position="float")

        if node_name_on and node.is_leaf():
            # Add node name to leaf nodes
            N = AttrFace("name", fsize=name_fsize, fgcolor="black")
            faces.add_face_to_node(N, node, 0)

    ts = TreeStyle()

    ts.show_leaf_name = False
    ts.mode = "c"
    ts.arc_start = arc_start
    ts.arc_span = arc_span
    ts.layout_fn = layout
    ts.branch_vertical_margin = branch_vertical_margin
    ts.show_scale = False

    if phylum_color_legend:
        for phylum_name in np.sort(list(phylum_color_dict.keys())):
            color_name = phylum_color_dict[phylum_name]
            ts.legend.add_face(CircleFace(weight_max_radios * 1, color_name),
                               column=0)
            ts.legend.add_face(TextFace(" %s" % phylum_name, fsize=name_fsize),
                               column=1)

    return t.render(file_name=file_name, w=img_w, tree_style=ts)


# #########################################################################################################################
# if __name__ == "__main__":
#     argdict = argv_parse(sys.argv)
#     try: gpu_memory_fraction = float(argdict['gpu_memory_fraction'][0])
#     except: gpu_memory_fraction = None
#     try: max_queue_size=int(argdict['max_queue_size'][0])
#     except: max_queue_size=10
#     try: workers=int(argdict['workers'][0])
#     except: workers=1
#     try: use_multiprocessing=argdict['use_multiprocessing'][0]=='True'
#     except: use_multiprocessing=False

#     ### Logger ############################################################################################
#     logger = logging_daily.logging_daily(argdict['log_info'][0])
#     logger.reset_logging()
#     log = logger.get_logging()
#     log.setLevel(logging_daily.logging.INFO)

#     log.info('Argument input')
#     for argname, arg in argdict.items():
#         log.info('    {}:{}'.format(argname,arg))

#     ### Configuration #####################################################################################
#     config_data = configuration.Configurator(argdict['path_info'][0], log)
#     config_data.set_config_map(config_data.get_section_map())
#     config_data.print_config_map()

#     config_network = configuration.Configurator(argdict['network_info'][0], log)
#     config_network.set_config_map(config_network.get_section_map())
#     config_network.print_config_map()

#     path_info = config_data.get_config_map()
#     network_info = config_network.get_config_map()
#     test_evaluation, train_evaluation, network = deepbiome_train(log, network_info, path_info, number_of_fold=20)
示例#12
0
    parser.add_argument('--nosupport', action='store_true', default="", help="Hide branch support")
    parser.add_argument('-c',dest='circular', action='store_true', default="", help="Draw a circular tree ")
    
    args, unknown = parser.parse_known_args()
    
    t = Tree(args.i)
    ts = TreeStyle()
    
    if args.colorleaf:
        
        with open(args.colorleaf,'rU') as file_map:
            for line in file_map:
                if line:
                    leaf_color = list(map(str.strip,line.split('\t')))
                    print leaf_color
                    for leaf in t.get_leaves_by_name(leaf_color[0]):
                        leaf.set_style(NodeStyle())
                        
                        if leaf.name == leaf_color[0]:
                            leaf.img_style["bgcolor"] = leaf_color[1]
                        
    ts.show_leaf_name = not args.noleaf
    ts.show_branch_length = not args.nolength
    ts.show_branch_support = not args.nosupport

    if args.circular:
        ts.mode = "c"
        
    ext="svg"
    if args.ext:
       ext = args.ext