Ejemplo n.º 1
0
    def _ml_t_init(self,ancestral_inference=True, **kwarks):
        """
        Prepare every node of the tree for the ML inference of node positions.

        For each node this sets the default ``merger_rate`` (if missing), builds
        the branch length interpolator, sets the profiles rotated into the GTR
        eigenspace, and initializes ``abs_t`` and ``msg_to_parent``. Nodes
        with a usable ``numdate_given`` get a delta-function message located at
        their position in branch length units; all other nodes are reset to
        ``None`` for ``numdate_given``, ``abs_t`` and ``msg_to_parent``.

        Args:
         - ancestral_inference (bool): if True, ``optimize_seq_and_branch_len``
         is run first, receiving ``**kwarks``.
         - **kwarks: forwarded unchanged to ``optimize_seq_and_branch_len``.

        Returns:
         - None. If ``self.date2dist`` is None, an error is printed and the
         nodes are left untouched.
        """
        # grab the tree before the (optional) ancestral inference runs
        phylo = self.tree

        if ttconf.BRANCH_LEN_PENALTY is None:
            ttconf.BRANCH_LEN_PENALTY = 0.0

        if ancestral_inference:
            self.optimize_seq_and_branch_len(**kwarks)

        print('Initializing branch length interpolation objects')
        if self.date2dist is None:
            print ("error - no date to dist conversion set. "
                "Run init_date_constraints and try once more.")
            return

        for clade in phylo.find_clades():

            if not hasattr(clade, 'merger_rate'):
                clade.merger_rate = ttconf.BRANCH_LEN_PENALTY

            # interpolation object for the branch length of this node
            self._make_branch_len_interpolator(clade, n=ttconf.BRANCH_GRID_SIZE)
            # profiles in the eigenspace of the GTR matrix (only the left and
            # right profiles, prf_l and prf_r, are used further on)
            self._set_rotated_profiles(clade)

            has_date = getattr(clade, 'numdate_given', None) is not None

            if has_date and not (hasattr(clade, 'bad_branch') and clade.bad_branch == True):
                # constrained node: abs_t is zero today and grows into the past;
                # the slope converts years into branch length units
                elapsed_years = utils.numeric_date() - clade.numdate_given
                clade.abs_t = elapsed_years * abs(self.date2dist.slope)
                clade.msg_to_parent = utils.delta_fun(clade.abs_t, return_log=True, normalized=False)
                continue

            if has_date:
                # dated node on a bad branch: drop the constraint
                print ("Branch is marked as bad, excluding it from the optimization process"
                    " Will be optimized freely")

            # unconstrained node: log_prob will be set on-the-fly
            clade.numdate_given = None
            clade.abs_t = None
            clade.msg_to_parent = None
Ejemplo n.º 2
0
    def convert_dates(self):
        """
        Convert the ``time_before_present`` of every node into calendar dates.

        For each node of ``self.tree`` the time before present is converted to
        years with ``self.date2dist.to_years`` and stored as a numeric date in
        ``node.numdate`` (today's numeric date minus the years before present).
        A human-readable string is stored in ``node.date``.

        Nodes that end up later than today are reported through ``self.logger``
        as warnings; the message depends on whether the node is marked as a
        bad branch.

        Returns:
         - None, the tree is modified in place
        """
        from datetime import datetime, timedelta
        now = utils.numeric_date()
        for node in self.tree.find_clades():
            years_bp = self.date2dist.to_years(node.time_before_present)
            if years_bp < 0:
                if not hasattr(node, "bad_branch") or node.bad_branch == False:
                    # the trailing space after "not" keeps the concatenated
                    # message readable ("not marked" instead of "notmarked")
                    self.logger(
                        "ClockTree.convert_dates -- WARNING: The node is later than today, but it is not "
                        "marked as \"BAD\", which indicates the error in the "
                        "likelihood optimization.",
                        4,
                        warn=True)
                else:
                    self.logger(
                        "ClockTree.convert_dates -- WARNING: node which is marked as \"BAD\" optimized "
                        "later than present day",
                        4,
                        warn=True)

            node.numdate = now - years_bp

            # set the human-readable date: whole year plus fraction of a year
            year = int(node.numdate)
            days = 365.25 * (node.numdate - year)
            try:
                n_date = datetime(year, 1, 1) + timedelta(days=days)
                node.date = datetime.strftime(n_date, "%Y-%m-%d")
            except (ValueError, OverflowError):
                # the year is outside of what datetime/strftime can handle:
                # approximate month and day with a fixed reference year and
                # prepend the real year
                n_date = datetime(1900, 1, 1) + timedelta(days=days)
                node.date = str(year) + "-" + str(n_date.month) + "-" + str(
                    n_date.day)
Ejemplo n.º 3
0
def set_node_dates_from_dic(tree, dates_dic):
    """
    Match the names of the tree nodes with the provided dictionary and set
    the numdate_given attribute of the nodes.

    Every node returned by ``tree.tree.find_clades()`` is visited. A node gets
    ``numdate_given = None`` when it has no name, when its name is not in the
    dictionary, when the dictionary value is not a float or int, or when the
    date is later than today. Otherwise ``numdate_given`` is set to the value
    from the dictionary. A summary of assigned dates and errors is printed.

    Args:
     - tree (TreeTime): instance of the tree time object with phylogenetic tree
     loaded.
     - dates_dic (dic): dictionary mapping node names to numeric dates
     (float or int), compared against ``utils.numeric_date`` of the current time.
    Returns:
     - None, tree is being modified in-place
    """

    err_ = 0
    num_ = 0
    now = utils.numeric_date(datetime.datetime.now())
    for node in tree.tree.find_clades():

        if node.name is None or node.name not in dates_dic:
            node.numdate_given = None
            continue

        n_date = dates_dic[node.name]  # the dictionary is expected to hold numdates
        if not isinstance(n_date, (float, int)):  # sanity check
            print(
                "Cannot set the numeric date tot the node. Float or int expected"
            )
            # treat a wrong type like every other failure: make sure the
            # attribute exists and count the error
            node.numdate_given = None
            err_ += 1
            continue

        try:
            in_future = n_date > now
        except Exception:
            # best effort: a date that cannot be compared is skipped
            print("Cannot assign date to the node: exception caught")
            node.numdate_given = None
            err_ += 1
            continue

        if in_future:
            print(
                "Cannot set the date! the specified date is later "
                " than today! cannot assign node date, skipping")
            node.numdate_given = None
            err_ += 1
        else:
            node.numdate_given = n_date
            num_ += 1

    print("Assigned dates to {0} nodes, {1} errors".format(num_, err_))
Ejemplo n.º 4
0
def set_node_dates_from_dic(tree, dates_dic):
    """
    Match the names of the tree nodes with the provided dictionary and set
    the numdate_given attribute of the nodes.

    Every node returned by ``tree.tree.find_clades()`` is visited. A node gets
    ``numdate_given = None`` when it has no name, when its name is not in the
    dictionary, when the dictionary value is not a float or int, or when the
    date is later than today. Otherwise ``numdate_given`` is set to the value
    from the dictionary. A summary of assigned dates and errors is printed.

    Args:
     - tree (TreeTime): instance of the tree time object with phylogenetic tree
     loaded.
     - dates_dic (dic): dictionary mapping node names to numeric dates
     (float or int), compared against ``utils.numeric_date`` of the current time.
    Returns:
     - None, tree is being modified in-place
    """

    err_ = 0
    num_ = 0
    now = utils.numeric_date(datetime.datetime.now())
    for node in tree.tree.find_clades():

        if node.name is None or node.name not in dates_dic:
            node.numdate_given = None
            continue

        n_date = dates_dic[node.name] # the dictionary is expected to hold numdates
        if not isinstance(n_date, (float, int)): #  sanity check
            print ("Cannot set the numeric date tot the node. Float or int expected")
            # treat a wrong type like every other failure: make sure the
            # attribute exists and count the error
            node.numdate_given = None
            err_ += 1
            continue

        try:
            in_future = n_date > now
        except Exception:
            # best effort: a date that cannot be compared is skipped
            print ("Cannot assign date to the node: exception caught")
            node.numdate_given = None
            err_ += 1
            continue

        if in_future:
            print ("Cannot set the date! the specified date is later "
                " than today! cannot assign node date, skipping")
            node.numdate_given = None
            err_ += 1
        else:
            node.numdate_given = n_date
            num_ += 1

    print ("Assigned dates to {0} nodes, {1} errors".format(num_, err_))
Ejemplo n.º 5
0
    def convert_dates(self):
        '''
        This function converts the estimated "time_before_present" properties of all nodes
        to numerical dates stored in the "numdate" attribute. This date is further converted
        into a human-readable date string stored in the "date" attribute, in format
        %Y-%m-%d when the year can be handled by datetime, and as an approximate
        year-month-day string otherwise.

        Nodes which end up later than today are reported through the logger as
        warnings, but only if self.real_dates is set.

        Returns
        -------
         None
            All manipulations are done in place on the tree

        '''
        from datetime import datetime, timedelta
        now = utils.numeric_date()
        for node in self.tree.find_clades():
            years_bp = self.date2dist.to_years(node.time_before_present)
            if years_bp < 0 and self.real_dates:
                if not hasattr(node, "bad_branch") or node.bad_branch == False:
                    self.logger(
                        "ClockTree.convert_dates -- WARNING: The node is later than today, but it is not "
                        "marked as \"BAD\", which indicates the error in the "
                        "likelihood optimization.",
                        4,
                        warn=True)
                else:
                    self.logger(
                        "ClockTree.convert_dates -- WARNING: node which is marked as \"BAD\" optimized "
                        "later than present day",
                        4,
                        warn=True)

            node.numdate = now - years_bp

            # set the human-readable date: whole year plus fraction of a year
            year = int(node.numdate)
            days = 365.25 * (node.numdate - year)
            try:
                n_date = datetime(year, 1, 1) + timedelta(days=days)
                node.date = datetime.strftime(n_date, "%Y-%m-%d")
            except (ValueError, OverflowError):
                # the year is outside of what datetime/strftime can handle.
                # This is an approximation not accounting for leap years etc:
                # month and day are taken from a fixed reference year
                n_date = datetime(1900, 1, 1) + timedelta(days=days)
                node.date = str(year) + "-" + str(n_date.month) + "-" + str(
                    n_date.day)
Ejemplo n.º 6
0
def root_lh_to_json(tt, outf):
    """
    Save the likelihood distribution of the root node position as JSON.

    The distribution is read from ``tt.tree.root.msg_to_parent`` (an object
    with grid points ``x``, values ``y`` holding -log(LH), and which can be
    called to interpolate). The values are normalized so that the peak is 1,
    a window symmetric around the peak is chosen so that it covers the region
    where the normalized likelihood exceeds the cutoff, and the (date, LH)
    points inside this window are written to ``outf`` as a list of
    ``{"x": ..., "y": ...}`` records. The points are also printed.

    Args:
     - tt (TreeTime): tree time object; ``tt.date2dist.get_date`` is used to
     convert the grid positions, the result is subtracted from today's
     numeric date.
     - outf (str): name of the output JSON file.
    Returns:
     - None
    """

    cutoff = 1e-3

    mtp = tt.tree.root.msg_to_parent
    mtp_min = mtp.y.min()

    # normalized likelihood (peak value is 1)
    mtpy = np.exp(-np.asarray(mtp.y) + mtp_min)
    mtpx = mtp.x

    # cut and center
    maxy_idx = mtpy.argmax()
    val_right = (mtpy[maxy_idx:] > cutoff)
    if not val_right.any():
        right_dist = 0
    elif val_right.all():
        # the likelihood never drops below the cutoff right of the peak:
        # argmin() would return 0 and collapse the extent to zero, so use the
        # last grid point instead
        right_dist = mtpx[maxy_idx + len(val_right) - 1] - mtpx[maxy_idx]
    else:
        # left, actually (time is in the opposite direction)
        right_dist = mtpx[maxy_idx + val_right.argmin()] - mtpx[maxy_idx]

    val_left = mtpy[:maxy_idx] > cutoff
    if not val_left.any():
        left_dist = 0.0
    else:
        # first grid point left of the peak which is above the cutoff
        left_dist = mtpx[maxy_idx] - mtpx[val_left.argmax()]

    dist = np.max((left_dist, right_dist))
    center = mtpx[maxy_idx]

    # final x-y scatter: window borders, the peak and all grid points between
    inside = (mtpx < dist + center) & (mtpx > center - dist)
    raw_x = np.unique(
        np.concatenate(
            ([center - dist], [center], [center + dist], mtpx[inside])))

    # a list is built explicitly: in Python 3 map() returns an iterator,
    # which np.array would wrap as a single object instead of converting
    x = utils.numeric_date() - np.array(
        [tt.date2dist.get_date(pos) for pos in raw_x])
    y = np.exp(-(mtp(raw_x) - mtp_min))
    arr = [{"x": f, "y": b} for f, b in zip(x, y)]

    with open(outf, 'w') as of:
        json.dump(arr, of, indent=True)

    print(', '.join([str(k) for k in x]))
    print(', '.join([str(k) for k in y]))
Ejemplo n.º 7
0
def set_node_dates_from_names(tree, date_func):
    """
    Read names of the leaves of the tree, extract the dates of the leaves from the
    names and assign the dates to the nodes.

    After this function call, each terminal node of the tree has the
    numdate_given attribute. If the date was extracted from the name
    successfully and is not later than today, 'numdate_given' is the value
    returned by date_func (expected to be a numeric date, YYYY.F).
    Otherwise (date_func raised, returned None, or returned a date later than
    today), 'numdate_given' is set to None.

    Args:
     - tree (TreeTime): instance of the tree time object with phylogenetic tree
     loaded.
     - date_func (callable): function to extract the date from the node name,
     should return float, or None if the name holds no date
    Returns:
     - None, tree is being modified in-place
    """
    ## NOTE we do not rely on the datetime objects, dates are compared as numbers
    now = utils.numeric_date(datetime.datetime.now())
    for node in tree.tree.get_terminals():
        try:
            node_date = date_func(node.name)
        except Exception:
            # best effort: a name which cannot be parsed gives no date, but
            # KeyboardInterrupt/SystemExit are no longer swallowed
            print(
                "Cannot extract numdate from the node name. Exception caugth.")
            node.numdate_given = None
            continue

        if node_date is None:
            # cannot extract the date from name - set None
            node.numdate_given = None
        elif node_date > now:
            print(
                "Cannot set the date! the specified date is later "
                " than today")
            node.numdate_given = None
        else:
            node.numdate_given = node_date
Ejemplo n.º 8
0
    def _set_final_date(self, node):
        """
        Given the location of the node in branch length units, convert it to the
        date-time information.

        The node position ``abs_t`` is taken from ``utils.min_interp`` applied
        to ``node.total_prob``. From it the method sets ``branch_length`` and
        ``dist2root`` (relative to ``node.up``, which must already have been
        processed), ``years_bp``, ``numdate`` and the human-readable ``date``.

        Args:
         - node(Phylo.Clade): tree node. NOTE the node should have the
         total_prob attribute, and its parent (node.up), if any, should have
         valid abs_t and dist2root values.

        Raises:
         - ArithmeticError: if the node is placed later than today and it is
         not marked as a bad branch.
        """
        node.abs_t = utils.min_interp(node.total_prob)
        if node.up is not None:
            node.branch_length = node.up.abs_t - node.abs_t
            node.dist2root = node.up.dist2root + node.branch_length
        else:
            # the root
            node.branch_length = self.one_mutation
            node.dist2root = 0.0

        node.years_bp = self.date2dist.get_date(node.abs_t)
        if node.years_bp < 0:
            if not hasattr(node, "bad_branch") or node.bad_branch == False:
                # the trailing space after "not" keeps the concatenated
                # message readable ("not marked" instead of "notmarked")
                raise ArithmeticError(
                    "The node is later than today, but it is not "
                    "marked as \"BAD\", which indicates the error in the "
                    "likelihood optimization.")
            else:
                print("Warning! node, which is marked as \"BAD\" optimized "
                      "later than present day")

        now = utils.numeric_date()
        node.numdate = now - node.years_bp

        # set the human-readable date: whole year plus fraction of a year
        year = int(node.numdate)
        days = 365.25 * (node.numdate - year)
        try:
            n_date = datetime.datetime(year, 1,
                                       1) + datetime.timedelta(days=days)
            node.date = datetime.datetime.strftime(n_date, "%Y-%m-%d")
        except (ValueError, OverflowError):
            # the year is outside of what datetime/strftime can handle.
            # Approximate month and day with a fixed reference year, so that
            # they are always valid (the plain days/30 arithmetic could give
            # month 0 or day 0)
            n_date = datetime.datetime(1900, 1,
                                       1) + datetime.timedelta(days=days)
            node.date = str(year) + "-" + str(n_date.month) + "-" + str(
                n_date.day)
Ejemplo n.º 9
0
def root_lh_to_json(tt, outf):
    """
    Save the likelihood distribution of the root node position as JSON.

    The distribution is read from ``tt.tree.root.msg_to_parent`` (an object
    with grid points ``x``, values ``y`` holding -log(LH), and which can be
    called to interpolate). The values are normalized so that the peak is 1,
    a window symmetric around the peak is chosen so that it covers the region
    where the normalized likelihood exceeds the cutoff, and the (date, LH)
    points inside this window are written to ``outf`` as a list of
    ``{"x": ..., "y": ...}`` records. The points are also printed.

    Args:
     - tt (TreeTime): tree time object; ``tt.date2dist.get_date`` is used to
     convert the grid positions, the result is subtracted from today's
     numeric date.
     - outf (str): name of the output JSON file.
    Returns:
     - None
    """

    cutoff = 1e-3

    mtp = tt.tree.root.msg_to_parent
    mtp_min = mtp.y.min()

    # normalized likelihood (peak value is 1)
    mtpy = np.exp(-np.asarray(mtp.y) + mtp_min)
    mtpx = mtp.x

    # cut and center
    maxy_idx = mtpy.argmax()
    val_right = (mtpy[maxy_idx:] > cutoff)
    if not val_right.any():
        right_dist = 0
    elif val_right.all():
        # the likelihood never drops below the cutoff right of the peak:
        # argmin() would return 0 and collapse the extent to zero, so use the
        # last grid point instead
        right_dist = mtpx[maxy_idx + len(val_right) - 1] - mtpx[maxy_idx]
    else:
        # left, actually (time is in the opposite direction)
        right_dist = mtpx[maxy_idx + val_right.argmin()] - mtpx[maxy_idx]

    val_left = mtpy[:maxy_idx] > cutoff
    if not val_left.any():
        left_dist = 0.0
    else:
        # first grid point left of the peak which is above the cutoff
        left_dist = mtpx[maxy_idx] - mtpx[val_left.argmax()]

    dist = np.max((left_dist, right_dist))
    center = mtpx[maxy_idx]

    # final x-y scatter: window borders, the peak and all grid points between
    inside = (mtpx < dist + center) & (mtpx > center-dist)
    raw_x = np.unique(np.concatenate(([center-dist], [center], [center+dist], mtpx[inside])))

    # a list is built explicitly: in Python 3 map() returns an iterator,
    # which np.array would wrap as a single object instead of converting
    x = utils.numeric_date() - np.array([tt.date2dist.get_date(pos) for pos in raw_x])
    y = np.exp(-(mtp(raw_x) - mtp_min))
    arr = [{"x":f, "y":b} for f, b in zip(x, y)]

    with open (outf,'w') as of:
        json.dump(arr, of, indent=True)

    print (', '.join([str(k) for k in x]))
    print (', '.join([str(k) for k in y]))
Ejemplo n.º 10
0
def set_node_dates_from_names(tree, date_func):
    """
    Read names of the leaves of the tree, extract the dates of the leaves from the
    names and assign the dates to the nodes.

    After this function call, each terminal node of the tree has the
    numdate_given attribute. If the date was extracted from the name
    successfully and is not later than today, 'numdate_given' is the value
    returned by date_func (expected to be a numeric date, YYYY.F).
    Otherwise (date_func raised, returned None, or returned a date later than
    today), 'numdate_given' is set to None.

    Args:
     - tree (TreeTime): instance of the tree time object with phylogenetic tree
     loaded.
     - date_func (callable): function to extract the date from the node name,
     should return float, or None if the name holds no date
    Returns:
     - None, tree is being modified in-place
    """
    ## NOTE we do not rely on the datetime objects, dates are compared as numbers
    now = utils.numeric_date(datetime.datetime.now())
    for node in tree.tree.get_terminals():
        try:
            node_date = date_func(node.name)
        except Exception:
            # best effort: a name which cannot be parsed gives no date, but
            # KeyboardInterrupt/SystemExit are no longer swallowed
            print ("Cannot extract numdate from the node name. Exception caugth.")
            node.numdate_given = None
            continue

        if node_date is None:
            # cannot extract the date from name - set None
            node.numdate_given = None
        elif node_date > now:
            print ("Cannot set the date! the specified date is later "
                " than today")
            node.numdate_given = None
        else:
            node.numdate_given = node_date
Ejemplo n.º 11
0
def save_timetree_results(tree, outfile_prefix):
    """
    Save the results of the tree time inference to a set of files and zip them.

    The following files are written, all starting with ``outfile_prefix``:
     - ``.meta.csv``: metadata table (currently only the header columns)
     - ``.tree.nwk``: the tree in newick format
     - ``.aln.fasta``: alignment (currently empty)
     - ``.root_dist.csv``: root date vs -log(LH), restricted to the grid
     points whose value is within 1000 of the minimum
     - ``.zip``: archive holding the four files above

    Args:
     - tree (TreeTime): tree time object; the root node must have the
     msg_to_parent attribute with ``x`` and ``y`` arrays.
     - outfile_prefix (str): common prefix (path) of the output files.
    Returns:
     - None
    """
    import pandas
    import zipfile

    meta_file = outfile_prefix + ".meta.csv"
    aln_file = outfile_prefix + ".aln.fasta"
    tree_file = outfile_prefix + ".tree.nwk"
    root_file = outfile_prefix + ".root_dist.csv"

    df = pandas.DataFrame(
        columns=["Given_date", "Initial_root_dist", "Inferred_date"])
    aln = Align.MultipleSeqAlignment([])

    # save everything
    df.to_csv(meta_file)
    #  TODO save variance to the metadata
    Phylo.write(tree.tree, tree_file, "newick")
    AlignIO.write(aln, aln_file, "fasta")

    # save root distribution
    mtp = tree.tree.root.msg_to_parent
    threshold = mtp.y.min() + 1000
    # plain boolean mask: wrapping the mask in a list is not a valid index in
    # current NumPy versions
    keep = mtp.y < threshold
    mtpy = mtp.y[keep]  # boolean indexing returns a copy, mtp.y stays intact
    # a list is built explicitly: in Python 3 map() returns an iterator,
    # which np.array would wrap as a single object instead of converting
    mtpx = utils.numeric_date() - np.array(
        [tree.date2dist.get_date(pos) for pos in mtp.x[keep]])
    # pin both ends of the saved curve to the threshold value
    mtpy[0] = threshold
    mtpy[-1] = threshold

    np.savetxt(root_file,
               np.hstack((mtpx[:, None], mtpy[:, None])),
               header="Root date,-log(LH)",
               delimiter=',')

    # zip results to one file; the context manager closes the archive, which
    # is required to get a complete zip file on disk
    with zipfile.ZipFile(outfile_prefix + ".zip", 'w') as zipf:
        for fname in (meta_file, aln_file, tree_file, root_file):
            zipf.write(fname)
Ejemplo n.º 12
0
    def _set_final_date(self, node):
        """
        Given the location of the node in branch length units, convert it to the
        date-time information.

        The node position ``abs_t`` is taken from ``utils.min_interp`` applied
        to ``node.total_prob``. From it the method sets ``branch_length`` and
        ``dist2root`` (relative to ``node.up``, which must already have been
        processed), ``years_bp``, ``numdate`` and the human-readable ``date``.

        Args:
         - node(Phylo.Clade): tree node. NOTE the node should have the
         total_prob attribute, and its parent (node.up), if any, should have
         valid abs_t and dist2root values.

        Raises:
         - ArithmeticError: if the node is placed later than today and it is
         not marked as a bad branch.
        """
        node.abs_t = utils.min_interp(node.total_prob)
        if node.up is not None:
            node.branch_length = node.up.abs_t - node.abs_t
            node.dist2root = node.up.dist2root + node.branch_length
        else:
            # the root
            node.branch_length = self.one_mutation
            node.dist2root = 0.0

        node.years_bp = self.date2dist.get_date(node.abs_t)
        if node.years_bp < 0:
            if not hasattr(node, "bad_branch") or node.bad_branch==False:
                # the trailing space after "not" keeps the concatenated
                # message readable ("not marked" instead of "notmarked")
                raise ArithmeticError("The node is later than today, but it is not "
                    "marked as \"BAD\", which indicates the error in the "
                    "likelihood optimization.")
            else:
                print ("Warning! node, which is marked as \"BAD\" optimized "
                    "later than present day")

        now = utils.numeric_date()
        node.numdate = now - node.years_bp

        # set the human-readable date: whole year plus fraction of a year
        year = int(node.numdate)
        days = 365.25 * (node.numdate - year)
        try:
            n_date = datetime.datetime(year, 1, 1) + datetime.timedelta(days=days)
            node.date = datetime.datetime.strftime(n_date, "%Y-%m-%d")
        except (ValueError, OverflowError):
            # the year is outside of what datetime/strftime can handle.
            # Approximate month and day with a fixed reference year, so that
            # they are always valid (the plain days/30 arithmetic could give
            # month 0 or day 0)
            n_date = datetime.datetime(1900, 1, 1) + datetime.timedelta(days=days)
            node.date = str(year) + "-" + str(n_date.month) + "-" + str(n_date.day)
Ejemplo n.º 13
0
def save_timetree_results(tree, outfile_prefix):
    """
    Save the results of the tree time inference to a set of files and zip them.

    The following files are written, all starting with ``outfile_prefix``:
     - ``.meta.csv``: metadata table (currently only the header columns)
     - ``.tree.nwk``: the tree in newick format
     - ``.aln.fasta``: alignment (currently empty)
     - ``.root_dist.csv``: root date vs -log(LH), restricted to the grid
     points whose value is within 1000 of the minimum
     - ``.zip``: archive holding the four files above

    Args:
     - tree (TreeTime): tree time object; the root node must have the
     msg_to_parent attribute with ``x`` and ``y`` arrays.
     - outfile_prefix (str): common prefix (path) of the output files.
    Returns:
     - None
    """
    import pandas
    import zipfile

    meta_file = outfile_prefix + ".meta.csv"
    aln_file = outfile_prefix + ".aln.fasta"
    tree_file = outfile_prefix + ".tree.nwk"
    root_file = outfile_prefix + ".root_dist.csv"

    df = pandas.DataFrame(columns=["Given_date", "Initial_root_dist", "Inferred_date"])
    aln = Align.MultipleSeqAlignment([])

    # save everything
    df.to_csv(meta_file)
    #  TODO save variance to the metadata
    Phylo.write(tree.tree, tree_file, "newick")
    AlignIO.write(aln, aln_file, "fasta")

    # save root distribution
    mtp = tree.tree.root.msg_to_parent
    threshold = mtp.y.min() + 1000
    # plain boolean mask: wrapping the mask in a list is not a valid index in
    # current NumPy versions
    keep = mtp.y < threshold
    mtpy = mtp.y[keep]  # boolean indexing returns a copy, mtp.y stays intact
    # a list is built explicitly: in Python 3 map() returns an iterator,
    # which np.array would wrap as a single object instead of converting
    mtpx = utils.numeric_date() - np.array([tree.date2dist.get_date(pos) for pos in mtp.x[keep]])
    # pin both ends of the saved curve to the threshold value
    mtpy[0] = threshold
    mtpy[-1] = threshold

    np.savetxt(root_file,
        np.hstack((mtpx[:, None], mtpy[:, None])),
        header="Root date,-log(LH)", delimiter=',')

    # zip results to one file; the context manager closes the archive, which
    # is required to get a complete zip file on disk
    with zipfile.ZipFile(outfile_prefix + ".zip", 'w') as zipf:
        for fname in (meta_file, aln_file, tree_file, root_file):
            zipf.write(fname)
Ejemplo n.º 14
0
def root_pos_lh_to_human_readable(tt, node, cutoff=1e-4):
    """
    Convert the positional likelihood of a node into (date, likelihood) points.

    The distribution is read from ``node.marginal_lh`` (an object with grid
    points ``x``, values ``y`` holding -log(LH), and which can be called to
    interpolate). The values are normalized so that the peak is 1 and a window
    symmetric around the peak is chosen so that it covers the (dilated) region
    where the normalized likelihood exceeds the cutoff.

    Args:
     - tt (TreeTime): tree time object; ``tt.date2dist.get_date`` is used to
     convert the grid positions, the result is subtracted from today's
     numeric date.
     - node (Phylo.Clade): node with the marginal_lh attribute.
     - cutoff (float): normalized likelihood below which the distribution
     is cut.
    Returns:
     - x (numpy.array): dates of the points inside the window
     - y (numpy.array): normalized likelihood at these points
    """

    mtp = node.marginal_lh
    mtp_min = mtp.y.min()

    # normalized likelihood (peak value is 1)
    mtpy = np.exp(-np.asarray(mtp.y) + mtp_min)
    mtpx = mtp.x

    # cut and center
    maxy_idx = mtpy.argmax()
    val_right = binary_dilation(mtpy[maxy_idx:] > cutoff)
    if not val_right.any():
        right_dist = 0
    elif val_right.all():
        # the likelihood never drops below the cutoff right of the peak:
        # argmin() would return 0 and collapse the extent to zero, so use the
        # last grid point instead
        right_dist = mtpx[maxy_idx + len(val_right) - 1] - mtpx[maxy_idx]
    else:
        # left, actually (time is in the opposite direction)
        right_dist = mtpx[maxy_idx + val_right.argmin()] - mtpx[maxy_idx]

    val_left = binary_dilation(mtpy[:maxy_idx] > cutoff)
    if not val_left.any():
        left_dist = 0.0
    else:
        # first grid point left of the peak which is inside the dilated mask
        left_dist = mtpx[maxy_idx] - mtpx[val_left.argmax()]

    dist = np.max((left_dist, right_dist))
    center = mtpx[maxy_idx]

    # final x-y scatter: window borders, the peak and all grid points between
    inside = (mtpx < dist + center) & (mtpx > center - dist)
    raw_x = np.unique(
        np.concatenate(
            ([center - dist], [center], [center + dist], mtpx[inside])))

    # a list is built explicitly: in Python 3 map() returns an iterator,
    # which np.array would wrap as a single object instead of converting
    x = utils.numeric_date() - np.array(
        [tt.date2dist.get_date(pos) for pos in raw_x])
    y = np.exp(-(mtp(raw_x) - mtp_min))
    return x, y
Ejemplo n.º 15
0
    def _ml_t_marginal(self, assign_dates=False):
        """
        Compute the marginal probability distribution of the internal nodes positions by
        propagating from the tree leaves towards the root, and then back from the
        root towards the leaves. The result of this operation are the probability
        distributions of each internal node, conditional on the constraints on all
        leaves of the tree, which have sampling dates.
        The probability distributions are set as marginal_pos_LH attributes to the nodes.
        In addition, every node gets the marginal_inverse_cdf attribute (and the
        marginal_cdf attribute, if its distribution is not a delta function).

        Parameters
        ----------

         assign_dates : bool, default False
            If True, the inferred dates will be assigned to the nodes as
            :code:`time_before_present` attributes, and their branch lengths
            will be corrected accordingly.
            .. Note::
                Normally, the dates are assigned by running joint reconstruction.

        Returns
        -------

         None
            Every internal node is assigned the probability distribution in form
            of an interpolation object and sends this distribution further towards the
            root.

        """

        def _cleanup():
            # Remove the temporary attributes created during the propagation.
            # All three `del` statements share one try block, so the first
            # attribute missing on a node skips the remaining deletions for
            # that node.
            for node in self.tree.find_clades():
                try:
                    del node.marginal_pos_Lx
                    del node.subtree_distribution
                    del node.msg_from_parent
                    #del node.marginal_pos_LH
                except:
                    pass


        self.logger("ClockTree - Marginal reconstruction:  Propagating leaves -> root...", 2)
        # go through the nodes from leaves towards the root:
        for node in self.tree.find_clades(order='postorder'):  # children first, msg to parents
            if node.bad_branch:
                # no information
                node.marginal_pos_Lx = None
            else: # all other nodes
                if node.date_constraint is not None and node.date_constraint.is_delta: # there is a time constraint
                    # initialize the Lx for nodes with precise date constraint:
                    # subtree probability given the position of the parent node
                    # position of the parent node is given by the branch length
                    # distribution attached to the child node position
                    node.subtree_distribution = node.date_constraint
                    bl = node.branch_length_interpolator.x
                    # shift the branch length grid by the constrained position of the node
                    x = bl + node.date_constraint.peak_pos
                    node.marginal_pos_Lx = Distribution(x, node.branch_length_interpolator(bl),
                                                        min_width=self.min_width, is_log=True)

                else: # all nodes without precise constraint but positional information
                      # subtree likelihood given the node's constraint and child msg:
                    msgs_to_multiply = [node.date_constraint] if node.date_constraint is not None else []
                    msgs_to_multiply.extend([child.marginal_pos_Lx for child in node.clades
                                             if child.marginal_pos_Lx is not None])

                    # combine the different msgs and constraints
                    if len(msgs_to_multiply)==0:
                        # no information
                        # NOTE(review): subtree_distribution is not set on this path
                        node.marginal_pos_Lx = None
                        continue
                    elif len(msgs_to_multiply)==1:
                        node.subtree_distribution = msgs_to_multiply[0]
                    else: # combine the different msgs and constraints
                        node.subtree_distribution = Distribution.multiply(msgs_to_multiply)

                    if node.up is None: # this is the root, set dates
                        node.subtree_distribution._adjust_grid(rel_tol=self.rel_tol_prune)
                        node.marginal_pos_Lx = node.subtree_distribution
                        node.marginal_pos_LH = node.subtree_distribution
                        self.tree.positional_marginal_LH = -node.subtree_distribution.peak_val
                    else: # otherwise propagate to parent
                        # res_t is not used in this method
                        res, res_t = NodeInterpolator.convolve(node.subtree_distribution,
                                        node.branch_length_interpolator,
                                        max_or_integral='integral',
                                        n_grid_points = self.node_grid_points,
                                        n_integral=self.n_integral,
                                        rel_tol=self.rel_tol_refine)
                        res._adjust_grid(rel_tol=self.rel_tol_prune)
                        node.marginal_pos_Lx = res

        self.logger("ClockTree - Marginal reconstruction:  Propagating root -> leaves...", 2)
        from scipy.interpolate import interp1d
        for node in self.tree.find_clades(order='preorder'):

            ## The root node
            if node.up is None:
                node.msg_from_parent = None # nothing beyond the root
            # all other cases (All internal nodes + unconstrained terminals)
            else:
                parent = node.up
                # messages from the complementary subtree (iterate over all sister nodes)
                complementary_msgs = [sister.marginal_pos_Lx for sister in parent.clades
                                            if (sister != node) and (sister.marginal_pos_Lx is not None)]

                # if parent itself got smth from the root node, include it
                if parent.msg_from_parent is not None:
                    complementary_msgs.append(parent.msg_from_parent)
                elif parent.marginal_pos_Lx is not None:
                    # the parent has no message from above (msg_from_parent is
                    # set to None for the root): use its own distribution
                    complementary_msgs.append(parent.marginal_pos_LH)

                if len(complementary_msgs):
                    msg_parent_to_node = NodeInterpolator.multiply(complementary_msgs)
                    msg_parent_to_node._adjust_grid(rel_tol=self.rel_tol_prune)
                else:
                    # no information from the complementary subtree: use a constant
                    # message between parent.numdate and the current numeric date
                    # NOTE(review): `from utils import ...` is an implicit relative
                    # import - confirm it resolves on the target Python version
                    from utils import numeric_date
                    x = [parent.numdate, numeric_date()]
                    msg_parent_to_node = NodeInterpolator(x, [1.0, 1.0],min_width=self.min_width)

                # integral message, which delivers to the node the positional information
                # from the complementary subtree
                res, res_t = NodeInterpolator.convolve(msg_parent_to_node, node.branch_length_interpolator,
                                                    max_or_integral='integral',
                                                    inverse_time=False,
                                                    n_grid_points = self.node_grid_points,
                                                    n_integral=self.n_integral,
                                                    rel_tol=self.rel_tol_refine)

                node.msg_from_parent = res
                if node.marginal_pos_Lx is None:
                    # the subtree of the node carries no information
                    node.marginal_pos_LH = node.msg_from_parent
                else:
                    node.marginal_pos_LH = NodeInterpolator.multiply((node.msg_from_parent, node.subtree_distribution))

                self.logger('ClockTree._ml_t_root_to_leaves: computed convolution'
                                ' with %d points at node %s'%(len(res.x),node.name),4)

                if self.debug:
                    # count the sign changes of the slope of the convolution result
                    # (only where it is within 500 of the peak value); more than one
                    # change triggers the plot and the interactive debugger
                    tmp = np.diff(res.y-res.peak_val)
                    nsign_changed = np.sum((tmp[1:]*tmp[:-1]<0)&(res.y[1:-1]-res.peak_val<500))
                    if nsign_changed>1:
                        import matplotlib.pyplot as plt
                        plt.ion()
                        plt.plot(res.x, res.y-res.peak_val, '-o')
                        plt.plot(res.peak_pos - node.branch_length_interpolator.x,
                                 node.branch_length_interpolator(node.branch_length_interpolator.x)-node.branch_length_interpolator.peak_val, '-o')
                        plt.plot(msg_parent_to_node.x,msg_parent_to_node.y-msg_parent_to_node.peak_val, '-o')
                        plt.ylim(0,100)
                        plt.xlim(-0.05, 0.05)
                        # NOTE: execution stops here in the interactive debugger
                        import ipdb; ipdb.set_trace()

            # assign positions of nodes and branch length only when desired
            # since marginal reconstruction can result in negative branch length
            if assign_dates:
                node.time_before_present = node.marginal_pos_LH.peak_pos
                if node.up:
                    node.clock_length = node.up.time_before_present - node.time_before_present
                    node.branch_length = node.clock_length

            # construct the inverse cumulant distribution to evaluate confidence intervals
            if node.marginal_pos_LH.is_delta:
                # delta function: the inverse cdf is constant at the peak position
                # (marginal_cdf is not set in this case)
                node.marginal_inverse_cdf=interp1d([0,1], node.marginal_pos_LH.peak_pos*np.ones(2), kind="linear")
            else:
                # cumulative trapezoidal integral of the relative probability
                # on the grid of the distribution, normalized to end at 1
                dt = np.diff(node.marginal_pos_LH.x)
                y = node.marginal_pos_LH.prob_relative(node.marginal_pos_LH.x)
                int_y = np.concatenate(([0], np.cumsum(dt*(y[1:]+y[:-1])/2.0)))
                int_y/=int_y[-1]
                node.marginal_inverse_cdf = interp1d(int_y, node.marginal_pos_LH.x, kind="linear")
                node.marginal_cdf = interp1d(node.marginal_pos_LH.x, int_y, kind="linear")

        # keep the intermediate messages on the nodes when debugging
        if not self.debug:
            _cleanup()

        return
Ejemplo n.º 16
0
    def _ml_t_init(self, ancestral_inference=True, **kwarks):
        """
        Prepare every node of the tree for the ML inference of node positions.

        For each clade returned by ``self.tree.find_clades()`` this method

        * assigns ``merger_rate`` (from ``ttconf.BRANCH_LEN_PENALTY``) if the
          node does not carry one yet,
        * builds the branch length interpolation object,
        * sets the sequence profiles in the eigenspace of the GTR matrix,
        * initializes ``numdate_given``, ``abs_t`` and ``msg_to_parent``.

        Nodes with a usable date constraint get ``abs_t`` (distance from the
        node position to the present, in branch length units) and a
        delta-function ``msg_to_parent``. Nodes without a date, or flagged
        as ``bad_branch``, have all three attributes reset to None.

        Parameters
        ----------
        ancestral_inference : bool
            When true, ``self.optimize_seq_and_branch_len(**kwarks)`` is run
            before the nodes are initialized.
        **kwarks
            Passed through unchanged to ``optimize_seq_and_branch_len``.

        Returns
        -------
        None
            If ``self.date2dist`` is None an error message is printed and the
            method returns before any node is touched.
        """
        tree = self.tree

        # a missing penalty in the configuration is treated as "no penalty"
        if ttconf.BRANCH_LEN_PENALTY is None:
            ttconf.BRANCH_LEN_PENALTY = 0.0

        if ancestral_inference:
            self.optimize_seq_and_branch_len(**kwarks)

        print('Initializing branch length interpolation objects')
        if self.date2dist is None:
            print("error - no date to dist conversion set. "
                  "Run init_date_constraints and try once more.")
            return

        for clade in tree.find_clades():

            if not hasattr(clade, 'merger_rate'):
                clade.merger_rate = ttconf.BRANCH_LEN_PENALTY

            # interpolation object for the branch length of this node
            self._make_branch_len_interpolator(clade,
                                               n=ttconf.BRANCH_GRID_SIZE)
            # profiles in the eigenspace of the GTR matrix; only the left and
            # right profiles (prf_l, prf_r) are used afterwards
            self._set_rotated_profiles(clade)

            given_date = getattr(clade, 'numdate_given', None)
            # the bad-branch flag only matters for nodes that carry a date
            flagged_bad = (given_date is not None
                           and getattr(clade, 'bad_branch', False) == True)
            if flagged_bad:
                print("Branch is marked as bad, excluding it from the optimization process"
                      " Will be optimized freely")

            if given_date is None or flagged_bad:
                # unconstrained node: without constraints the log_prob
                # will be set on-the-fly
                clade.numdate_given = None
                clade.abs_t = None
                clade.msg_to_parent = None
            else:
                # constrained node: absolute time in branch length units;
                # zero is today and the direction is to the past. The slope
                # of date2dist converts between years and branch length.
                elapsed = utils.numeric_date() - given_date
                clade.abs_t = elapsed * abs(self.date2dist.slope)
                clade.msg_to_parent = utils.delta_fun(clade.abs_t,
                                                      return_log=True,
                                                      normalized=False)