Exemple #1
0
    def _fitch_anc(self, **kwargs):
        """
        Reconstruct ancestral states using Fitch's algorithm. The method requires
        sequences to be assigned to leaves. It first iterates from the leaves to
        the root, constructing the Fitch profiles for each character of the
        sequence, and then propagates from the root towards the leaves to
        reconstruct the sequences of the internal nodes.

        Side effects: assigns `sequence` to every internal node, assigns
        `profile` (via seq_utils.seq2prof) to every node, and removes the
        temporary `state` attribute from the internal nodes (leaves keep theirs).

        KWargs:
         - currently none are read by this method.

        Returns:
         - N_diff (int): number of characters that changed since the previous
         reconstruction, summed over the non-root internal nodes. The changes
         are determined from the pre-set sequence attributes of the nodes. An
         internal node without a `sequence` attribute (i.e., no reconstruction
         has been made before) contributes self.L. The root is not counted.

        """
        # set Fitch profiles on each terminal node: one single-element state per site
        for leaf in self.tree.get_terminals():
            leaf.state = [[k] for k in leaf.sequence]

        print ("Walking up the tree, creating the Fitch profiles")
        for node in self.tree.get_nonterminals(order='postorder'):
            node.state = [self._fitch_state(node, k) for k in range(self.L)]

        # Resolve the root sequence. Ambiguous positions are resolved by a
        # random draw; the character that was actually drawn is the one that
        # gets reported, so the log always agrees with the stored sequence.
        root = self.tree.root
        root_chars = []
        for pos, candidates in enumerate(root.state):
            if len(candidates) > 1:
                chosen = candidates[np.random.randint(len(candidates))]
                print ("Ambiguous state of the root sequence "
                       "in the position %d: %s, "
                       "choosing %s" % (pos, str(candidates), chosen))
            else:
                chosen = candidates[0]
            root_chars.append(chosen)
        root.sequence = np.array(root_chars)

        print ("Walking down the self.tree, generating sequences from the "
                         "Fitch profiles.")
        N_diff = 0
        for node in self.tree.get_nonterminals(order='preorder'):
            if node.up is not None: # not root
                parent_seq = node.up.sequence
                # keep the parent's character wherever the Fitch profile allows
                # it, otherwise fall back to the first state of the profile
                sequence = np.array([parent_seq[i]
                        if parent_seq[i] in node.state[i]
                        else node.state[i][0] for i in range(self.L)])
                if hasattr(node, 'sequence'):
                    N_diff += (sequence!=node.sequence).sum()
                else:
                    N_diff += self.L
                node.sequence = sequence

            node.profile = seq_utils.seq2prof(node.sequence)
            del node.state # no need to store Fitch states
        print ("Done ancestral state reconstruction")
        for node in self.tree.get_terminals():
            node.profile = seq_utils.seq2prof(node.sequence)
        return N_diff
Exemple #2
0
    def _fitch_anc(self, **kwargs):
        """
        Reconstruct ancestral states using Fitch's algorithm. The method requires
        sequences to be assigned to leaves. It first iterates from the leaves to
        the root, constructing the Fitch profiles for each character of the
        sequence, and then propagates from the root towards the leaves to
        reconstruct the sequences of the internal nodes.

        Side effects: assigns `sequence` to every internal node, assigns
        `profile` (via seq_utils.seq2prof with the GTR profile map) to every
        node, and removes the temporary `state` attribute from the internal
        nodes (leaves keep theirs).

        KWargs:
         - currently none are read by this method.

        Returns:
         - N_diff (int): number of characters that changed since the previous
         reconstruction, summed over the non-root internal nodes. The changes
         are determined from the pre-set sequence attributes of the nodes. An
         internal node without a `sequence` attribute (i.e., no reconstruction
         has been made before) contributes self.L. The root is not counted.

        """
        # set Fitch profiles on each terminal node: one single-element state per site
        for leaf in self.tree.get_terminals():
            leaf.state = [[k] for k in leaf.sequence]

        print ("Walking up the tree, creating the Fitch profiles")
        for node in self.tree.get_nonterminals(order='postorder'):
            node.state = [self._fitch_state(node, k) for k in range(self.L)]

        # Resolve the root sequence. Ambiguous positions are resolved by a
        # random draw; the character that was actually drawn is the one that
        # gets reported, so the log always agrees with the stored sequence.
        root = self.tree.root
        root_chars = []
        for pos, candidates in enumerate(root.state):
            if len(candidates) > 1:
                chosen = candidates[np.random.randint(len(candidates))]
                print ("Ambiguous state of the root sequence "
                       "in the position %d: %s, "
                       "choosing %s" % (pos, str(candidates), chosen))
            else:
                chosen = candidates[0]
            root_chars.append(chosen)
        root.sequence = np.array(root_chars)

        print ("Walking down the self.tree, generating sequences from the "
                         "Fitch profiles.")
        N_diff = 0
        profile_map = self.gtr.profile_map
        for node in self.tree.get_nonterminals(order='preorder'):
            if node.up is not None: # not root
                parent_seq = node.up.sequence
                # keep the parent's character wherever the Fitch profile allows
                # it, otherwise fall back to the first state of the profile
                sequence = np.array([parent_seq[i]
                        if parent_seq[i] in node.state[i]
                        else node.state[i][0] for i in range(self.L)])
                if hasattr(node, 'sequence'):
                    N_diff += (sequence!=node.sequence).sum()
                else:
                    N_diff += self.L
                node.sequence = sequence

            node.profile = seq_utils.seq2prof(node.sequence, profile_map)
            del node.state # no need to store Fitch states
        print ("Done ancestral state reconstruction")
        for node in self.tree.get_terminals():
            node.profile = seq_utils.seq2prof(node.sequence, profile_map)
        return N_diff
Exemple #3
0
    def ancestral_likelihood(self):
        """
        Compute the log-likelihood of the sequences that are currently
        assigned to the nodes of the tree.

        Returns:
         - log_lh (numpy array): one log-likelihood value per position of the
         root sequence; the contributions of the root and of every branch
         are accumulated into it.
        """
        n_sites = self.tree.root.sequence.shape[0]
        log_lh = np.zeros(n_sites)

        for node in self.tree.find_clades(order='postorder'):
            if node.up is not None:
                # branch contribution: log transition probability between the
                # parent character and the character of this node
                branch_len = node.branch_length
                alphabet_pairs = []
                for parent_char, child_char in izip(node.up.sequence,
                                                    node.sequence):
                    # position of each character in the GTR alphabet
                    parent_pos = np.argmax(self.gtr.alphabet == parent_char)
                    child_pos = np.argmax(self.gtr.alphabet == child_char)
                    alphabet_pairs.append((parent_pos, child_pos))
                pair_idx = np.array(alphabet_pairs)

                log_transitions = np.log(self.gtr.expQt(branch_len))
                # indexed as [state of this node, state of the parent]
                log_lh += log_transitions[pair_idx[:, 1], pair_idx[:, 0]]
            else:
                # root contribution: weight the profile of the root sequence
                # with self.gtr.Pi and sum over the states of every position
                root_profile = seq_utils.seq2prof(node.sequence,
                                                  self.gtr.profile_map)
                root_profile *= self.gtr.Pi
                site_prob = root_profile.sum(axis=1)
                log_lh += np.log(site_prob)

        return log_lh
Exemple #4
0
    def _ml_anc(self, marginal=False, verbose=0, **kwargs):
        """
        Perform ML reconstruction of the ancestral states.

        The tree is traversed twice: leaves -> root to accumulate the messages
        of the subtrees, then root -> leaves to combine them with the message
        coming from the parent and to pick a sequence for every node.

        Side effects: sets `profile`, `lh_prefactor`, `seq_msg_to_parent`,
        `seq_msg_from_parent`, `mutations` and `sequence` on the nodes, and
        `anc_LH` / `sequence_LH` on the tree.

        Args:
         - marginal (bool): if True, the message to a node is built from the
           parent's incoming message and the messages of the sibling subtrees;
           otherwise the parent's profile is propagated directly.
         - verbose (int): how verbose the output should be (progress messages
           are printed when > 2).
        KWargs:
         - store_lh (bool): accepted for interface compatibility; it is not
           used by this implementation.

        Returns:
         - N_diff (int): number of characters that differ from the previous
           reconstruction, summed over all non-root nodes. A node without a
           previous sequence contributes the full sequence length.
        """
        tree = self.tree
        # number of nucleotides changed from prev reconstruction
        N_diff = 0

        L = tree.get_terminals()[0].sequence.shape[0]
        n_states = self.gtr.alphabet.shape[0]
        if verbose > 2:
            print ("Walking up the tree, computing likelihoods... type of reconstruction:", 'marginal' if marginal else "joint")
        for leaf in tree.get_terminals():
            # in any case, set the profile
            leaf.profile = seq_utils.seq2prof(leaf.sequence, self.gtr.profile_map)
            leaf.lh_prefactor = np.zeros(L)
        for node in tree.get_nonterminals(order='postorder'): #leaves -> root
            # regardless of what was before, set the profile to ones
            node.lh_prefactor = np.zeros(L)
            node.profile = np.ones((L, n_states)) # we will multiply it
            for ch in node.clades:
                ch.seq_msg_to_parent = self.gtr.propagate_profile(ch.profile,
                    ch.branch_length,
                    rotated=False, # use unrotated
                    return_log=False) # raw prob to transfer prob up
                node.profile *= ch.seq_msg_to_parent
                node.lh_prefactor += ch.lh_prefactor
            pre = node.profile.sum(axis=1) #sum over nucleotide states

            node.profile = (node.profile.T/pre).T # normalize so that the sum is 1
            node.lh_prefactor += np.log(pre) # and store log-prefactor
        if (verbose > 2):
            print ("Walking down the tree, computing maximum likelihood sequences...")

        # extract the likelihood from the profile
        tree.root.lh_prefactor += np.log(tree.root.profile.max(axis=1))
        tree.anc_LH = tree.root.lh_prefactor.sum()
        tree.sequence_LH = 0
        # reset profile to 0-1 and set the sequence
        # NOTE(review): np.diag / .diagonal() below suggest self.gtr.Pi is
        # stored as a diagonal matrix in this version -- confirm.
        tree.root.profile *= np.diag(self.gtr.Pi) # Msg to the root from the distant part (equ frequencies)
        tree.root.sequence, tree.root.profile = \
            seq_utils.prof2seq(tree.root.profile, self.gtr, correct_prof=not marginal)
        tree.root.seq_msg_from_parent = np.repeat([self.gtr.Pi.diagonal()], len(tree.root.sequence), axis=0)

        for node in tree.find_clades(order='preorder'):
            if node.up is None: # skip if node is root
                continue
            # integrate the information coming from parents with the information
            # of all children by multiplying it to the prev computed profile
            if marginal:
                # message from the parent side = parent's own incoming message
                # times the messages of all sibling subtrees
                msg_source = np.copy(node.up.seq_msg_from_parent)
                for c in node.up.clades:
                    if c != node:
                        msg_source *= c.seq_msg_to_parent
            else:
                msg_source = node.up.profile
            node.seq_msg_from_parent = self.gtr.propagate_profile(msg_source,
                        node.branch_length,
                        rotated=False, # use unrotated
                        return_log=False)
            node.profile *= node.seq_msg_from_parent

            # reset the profile to 0-1 and  set the sequence
            sequence, profile = seq_utils.prof2seq(node.profile, self.gtr, correct_prof=not marginal)
            node.mutations = [(anc, pos, der) for pos, (anc, der) in
                            enumerate(izip(node.up.sequence, sequence)) if anc!=der]

            # this needs fixing for marginal reconstruction
            if not marginal:
                tree.sequence_LH += np.sum(np.log(node.seq_msg_from_parent[profile>0.9]))
            if hasattr(node, 'sequence') and node.sequence is not None:
                # a failure here (e.g. incompatible sequence lengths) now
                # propagates to the caller instead of opening a debugger
                N_diff += (sequence!=node.sequence).sum()
            else:
                N_diff += L
            node.sequence = sequence
            node.profile = profile
        return N_diff
Exemple #5
0
    def _ml_anc(self, **kwargs):
        """
        Perform ML reconstruction of the ancestral states.

        The tree is traversed twice: leaves -> root to accumulate the messages
        of the subtrees, then root -> leaves to combine them with the profile
        propagated from the parent and to pick a sequence for every node.

        Side effects: sets `profile`, `lh_prefactor`, `seq_msg_to_parent`,
        `seq_msg_from_parent`, `mutations` and `sequence` on the nodes, and
        `anc_LH` / `sequence_LH` on the tree.

        KWargs:
         - store_lh (bool): accepted for interface compatibility; it is not
           used by this implementation.
         - verbose (int): how verbose the output should be (progress messages
           are printed when > 2). A value that cannot be converted to int is
           reported and the default of 0 is kept.

        Returns:
         - N_diff (int): number of characters that differ from the previous
           reconstruction, summed over all non-root nodes. A node without a
           `sequence` attribute contributes self.L.
        """
        tree = self.tree
        # number of nucleotides changed from prev reconstruction
        N_diff = 0
        verbose = 0 # how verbose to be at the output
        if 'verbose' in kwargs:
            try:
                verbose = int(kwargs['verbose'])
            except (TypeError, ValueError):
                # only conversion failures are tolerated; anything else
                # (e.g. KeyboardInterrupt) must not be swallowed here
                print ("ML ERROR in input: verbose param must be int")
        L = tree.get_terminals()[0].sequence.shape[0]
        n_states = self.gtr.alphabet.shape[0]
        if verbose > 2:
            print ("Walking up the tree, computing joint likelihoods...")
        for leaf in tree.get_terminals():
            # in any case, set the profile
            leaf.profile = seq_utils.seq2prof(leaf.sequence, self.gtr.alphabet_type)
            leaf.lh_prefactor = np.zeros(L)
        for node in tree.get_nonterminals(order='postorder'): #leaves -> root
            # regardless of what was before, set the profile to ones
            node.lh_prefactor = np.zeros(L)
            node.profile = np.ones((L, n_states)) # we will multiply it
            for ch in node.clades:
                ch.seq_msg_to_parent = self.gtr.propagate_profile(ch.profile,
                    ch.branch_length,
                    rotated=False, # use unrotated
                    return_log=False) # raw prob to transfer prob up
                node.profile *= ch.seq_msg_to_parent
                node.lh_prefactor += ch.lh_prefactor
            pre = node.profile.sum(axis=1) #sum over nucleotide states

            node.profile = (node.profile.T/pre).T # normalize so that the sum is 1
            node.lh_prefactor += np.log(pre) # and store log-prefactor
        if (verbose > 2):
            print ("Walking down the tree, computing maximum likelihood sequences...")
        tree.root.profile *= np.diag(self.gtr.Pi) # Msg to the root from the distant part (equ frequencies)

        # extract the likelihood from the profile
        tree.root.lh_prefactor += np.log(tree.root.profile.max(axis=1))
        tree.anc_LH = tree.root.lh_prefactor.sum()
        tree.sequence_LH = 0
        # reset profile to 0-1 and set the sequence
        tree.root.sequence, tree.root.profile = \
            seq_utils.prof2seq(tree.root.profile, self.gtr, True)


        for node in tree.find_clades(order='preorder'):
            if node.up is None: # skip if node is root
                continue
            # integrate the information coming from parents with the information
            # of all children by multiplying it to the prev computed profile
            node.seq_msg_from_parent = self.gtr.propagate_profile(node.up.profile,
                            node.branch_length,
                            rotated=False, # use unrotated
                            return_log=False)
            node.profile *= node.seq_msg_from_parent

            # reset the profile to 0-1 and  set the sequence
            sequence, profile = seq_utils.prof2seq(node.profile, self.gtr, True)
            node.mutations = [(anc, pos, der) for pos, (anc, der) in
                            enumerate(izip(node.up.sequence, sequence)) if anc!=der]

            tree.sequence_LH += np.sum(np.log(node.seq_msg_from_parent[profile>0.9]))

            if hasattr(node, 'sequence'):
                N_diff += (sequence!=node.sequence).sum()
            else:
                # NOTE(review): uses self.L while the rest of the method uses
                # the local L -- presumably equal; confirm.
                N_diff += self.L
            node.sequence = sequence
            node.profile = profile
        return N_diff
Exemple #6
0
    def _ml_anc(self,
                marginal=False,
                verbose=0,
                store_compressed=True,
                sample_from_profile=False,
                **kwargs):
        """
        Perform ML reconstruction of the ancestral states.

        Two passes over the tree are made: an upward pass (leaves -> root)
        that collects the messages of the subtrees, and a downward pass
        (root -> leaves) that merges them with the message from the parent
        side and converts every profile into a sequence.

        Args:
         - marginal (bool): if True the message to a node is assembled from
           the parent's incoming message and the sibling subtrees, otherwise
           the parent's profile is propagated; also passed (negated) as
           `collapse_prof` to seq_utils.prof2seq.
         - verbose (int): not read in this method; output goes via self.logger.
         - store_compressed (bool): if True,
           self.store_compressed_sequence_pairs() is called at the end.
         - sample_from_profile (bool or 'root'): forwarded to prof2seq as
           `sample_from_prof`; the value 'root' means True for the root and
           False for all other nodes.
        KWargs:
         - store_lh (bool): parsed, but not used further in this method.

        Returns:
         - N_diff (int): number of characters that differ from the previous
           reconstruction, summed over all non-root nodes; a node without a
           previous sequence contributes the full sequence length.
        """

        tree = self.tree
        gtr = self.gtr
        # running count of characters changed w.r.t. the previous reconstruction
        N_diff = 0
        if 'store_lh' in kwargs:
            store_lh = kwargs['store_lh'] == True  # parsed only, unused below

        L = tree.get_terminals()[0].sequence.shape[0]
        n_states = gtr.alphabet.shape[0]
        self.logger(
            "TreeAnc._ml_anc: type of reconstruction:" +
            ('marginal' if marginal else "joint"), 2)
        self.logger("Walking up the tree, computing likelihoods... ", 3)

        # leaves: the profile is taken directly from the sequence
        for leaf in tree.get_terminals():
            leaf.profile = seq_utils.seq2prof(leaf.sequence, gtr.profile_map)
            leaf.lh_prefactor = np.zeros(L)

        # upward pass: leaves -> root
        for node in tree.get_nonterminals(order='postorder'):
            # start from a clean slate and multiply the child messages in
            node.lh_prefactor = np.zeros(L)
            node.profile = np.ones((L, n_states))
            for child in node.clades:
                child.seq_msg_to_parent = gtr.propagate_profile(
                    child.profile,
                    self._branch_length_to_gtr(child),
                    return_log=False)  # raw probabilities are passed upwards
                node.profile *= child.seq_msg_to_parent
                node.lh_prefactor += child.lh_prefactor

            # rescale every position by its largest entry and keep the
            # logarithm of the scale in the prefactor
            scale = node.profile.max(axis=1)
            node.profile = node.profile / scale[:, np.newaxis]
            node.lh_prefactor += np.log(scale)

        self.logger(
            "Walking down the tree, computing maximum likelihood sequences...",
            3)

        # root: multiply with the equilibrium frequencies, normalize every
        # position to sum 1 and move the normalization into the prefactor
        root = tree.root
        root.profile *= gtr.Pi
        root_norm = root.profile.sum(axis=1)
        root.profile = root.profile / root_norm[:, np.newaxis]
        root.lh_prefactor += np.log(root_norm)

        tree.anc_LH = root.lh_prefactor.sum()
        tree.sequence_LH = 0

        # 'root' requests sampling for the root only
        sample_root_only = sample_from_profile == 'root'
        root_sampling = True if sample_root_only else sample_from_profile
        other_sampling = False if sample_root_only else sample_from_profile

        # turn the root profile into a sequence
        root.sequence, root.profile = \
            seq_utils.prof2seq(root.profile, gtr, sample_from_prof=root_sampling,
                               collapse_prof=not marginal)
        root.seq_msg_from_parent = np.repeat([gtr.Pi],
                                             len(root.sequence),
                                             axis=0)

        # downward pass: root -> leaves
        for node in tree.find_clades(order='preorder'):
            parent = node.up
            if parent is None:  # the root has been handled above
                continue

            # assemble the message arriving from the parent side
            if marginal:
                incoming = np.copy(parent.seq_msg_from_parent)
                for sibling in parent.clades:
                    if sibling != node:
                        incoming *= sibling.seq_msg_to_parent
            else:
                incoming = parent.profile
            node.seq_msg_from_parent = gtr.propagate_profile(
                incoming,
                self._branch_length_to_gtr(node),
                return_log=False)
            node.profile *= node.seq_msg_from_parent

            # convert the combined profile into a sequence
            sequence, profile = seq_utils.prof2seq(node.profile,
                                                   gtr,
                                                   sample_from_prof=other_sampling,
                                                   collapse_prof=not marginal)

            # record the differences to the parent as (ancestral, position, derived)
            mutations = []
            for pos, (anc, der) in enumerate(izip(parent.sequence, sequence)):
                if anc != der:
                    mutations.append((anc, pos, der))
            node.mutations = mutations

            # this needs fixing for marginal reconstruction
            if not marginal:
                tree.sequence_LH += np.sum(
                    np.log(node.seq_msg_from_parent[profile > 0.9]))

            if not hasattr(node, 'sequence') or node.sequence is None:
                N_diff += L
            else:
                N_diff += (sequence != node.sequence).sum()

            node.sequence = sequence
            node.profile = profile

        # note that the root doesn't contribute to N_diff (intended, since root sequence is often ambiguous)
        self.logger("TreeAnc._ml_anc: ...done", 3)
        if store_compressed:
            self.store_compressed_sequence_pairs()
        return N_diff
Exemple #7
0
    def _ml_anc_joint(self,
                      verbose=0,
                      store_compressed=True,
                      sample_from_profile=False,
                      debug=False,
                      **kwargs):
        """
        Perform joint ML reconstruction of the ancestral states. In contrast to
        marginal reconstructions, this only needs to compare and multiply LH and
        can hence operate in log space.

        Args:
         - verbose (int): not read in this method; output goes via self.logger.
         - store_compressed (bool): if True,
           self.store_compressed_sequence_pairs() is called at the end.
         - sample_from_profile (bool or 'root'): whether the root sequence is
           sampled from its profile (forwarded to seq_utils.prof2seq as
           `sample_from_prof`). 'root' is treated as True; any other value is
           interpreted by its truthiness.
         - debug (bool): if False, the temporary node attributes `joint_Lx`,
           `joint_Cx` and `seq_idx` are deleted before returning.
        KWargs:
         - not read by this method.

        Returns:
         - N_diff (int): number of sites that differ from the previous
           reconstruction, summed over all non-root nodes; a node without a
           previous sequence contributes the full sequence length.
        """
        N_diff = 0  # number of sites that differ from the previous reconstruction
        L = self.tree.get_terminals()[0].sequence.shape[0]
        n_states = self.gtr.alphabet.shape[0]

        self.logger("TreeAnc._ml_anc_joint: type of reconstruction: Joint", 2)

        self.logger(
            "TreeAnc._ml_anc_joint: Walking up the tree, computing likelihoods... ",
            3)
        # for the internal nodes, scan over all states j of this node, maximize the likelihood
        for node in self.tree.find_clades(order='postorder'):
            if node.up is None:
                # The root has no parent to condition on: its likelihood is
                # assembled after the loop and no transition matrix for a
                # (possibly undefined) root branch length is needed.
                node.joint_Cx = None  # not needed for root
                continue

            # preallocate storage
            node.joint_Lx = np.zeros((L, n_states))  # likelihood array
            node.joint_Cx = np.zeros((L, n_states),
                                     dtype=int)  # max LH indices
            branch_len = self._branch_length_to_gtr(node)
            # transition matrix from parent states to the current node states.
            # denoted as Pij(i), where j - parent state, i - node state
            log_transitions = np.log(self.gtr.expQt(branch_len))

            if node.is_terminal():
                # clamp the 0-1 profile from below before taking the log
                msg_from_children = np.log(
                    np.maximum(
                        seq_utils.seq2prof(node.sequence,
                                           self.gtr.profile_map),
                        ttconf.TINY_NUMBER))
                invalid = np.isnan(msg_from_children) | np.isinf(msg_from_children)
                msg_from_children[invalid] = -ttconf.BIG_NUMBER
            else:
                # Product (sum-Log) over all child subtree likelihoods.
                # this is prod_ch L_x(i)
                msg_from_children = np.sum(np.stack(
                    [c.joint_Lx for c in node.clades], axis=0),
                                           axis=0)

            # for every possible state of the parent node,
            # get the best state of the current node
            # and compute the likelihood of this state
            for char_i, char in enumerate(self.gtr.alphabet):
                # Pij(i) * L_ch(i) for given parent state j
                msg_to_parent = (log_transitions.T[char_i, :] +
                                 msg_from_children)
                # For this parent state, choose the best state of the current node:
                node.joint_Cx[:, char_i] = msg_to_parent.argmax(axis=1)
                # compute the likelihood of the best state of the current node
                # given the state of the parent (char_i)
                node.joint_Lx[:, char_i] = msg_to_parent.max(axis=1)

        # root node profile = likelihood of the total tree
        msg_from_children = np.sum(np.stack(
            [c.joint_Lx for c in self.tree.root.clades], axis=0),
                                   axis=0)
        # Pi(i) * Prod_ch Lch(i)
        self.tree.root.joint_Lx = msg_from_children + np.log(self.gtr.Pi)
        # shift every position so that its maximum is 0 before exponentiating
        normalized_profile = (self.tree.root.joint_Lx.T -
                              self.tree.root.joint_Lx.max(axis=1)).T

        # choose sequence characters from this profile.
        # treat root node differently to avoid piling up mutations on the longer branch
        if sample_from_profile == 'root':
            root_sample_from_profile = True
        else:
            # covers plain bools as before, and additionally values such as
            # 0/1 or numpy booleans that used to leave the flag undefined
            root_sample_from_profile = bool(sample_from_profile)

        seq, anc_lh_vals, idxs = seq_utils.prof2seq(
            np.exp(normalized_profile),
            self.gtr,
            sample_from_prof=root_sample_from_profile)

        # compute the likelihood of the most probable root sequence
        self.tree.sequence_LH = np.choose(idxs, self.tree.root.joint_Lx.T)
        self.tree.sequence_joint_LH = self.tree.sequence_LH.sum()
        self.tree.root.sequence = seq
        self.tree.root.seq_idx = idxs

        self.logger(
            "TreeAnc._ml_anc_joint: Walking down the tree, computing maximum likelihood sequences...",
            3)
        # for each node, resolve the conditioning on the parent node
        for node in self.tree.find_clades(order='preorder'):

            # root node has no mutations, everything else has been already set
            if node.up is None:
                node.mutations = []
                continue

            # choose the value of the Cx(i), corresponding to the state of the
            # parent node i. This is the state of the current node
            node.seq_idx = np.choose(node.up.seq_idx, node.joint_Cx.T)
            # reconstruct seq, etc
            tmp_sequence = np.choose(node.seq_idx, self.gtr.alphabet)
            if hasattr(node, 'sequence') and node.sequence is not None:
                N_diff += (tmp_sequence != node.sequence).sum()
            else:
                N_diff += L

            node.sequence = tmp_sequence
            node.mutations = [(anc, pos, der) for pos, (
                anc, der) in enumerate(izip(node.up.sequence, node.sequence))
                              if anc != der]

        self.logger("TreeAnc._ml_anc_joint: ...done", 3)
        if store_compressed:
            self.store_compressed_sequence_pairs()

        # do clean-up
        if not debug:
            for node in self.tree.find_clades(order='preorder'):
                del node.joint_Lx
                del node.joint_Cx
                del node.seq_idx

        return N_diff
Exemple #8
0
    def _ml_anc_marginal(self,
                         verbose=0,
                         store_compressed=True,
                         sample_from_profile=False,
                         debug=False,
                         **kwargs):
        """
        Perform marginal ML reconstruction of the ancestral states. In contrast to
        joint reconstructions, this needs to access the probabilities rather than only
        log probabilities and is hence handled by a separate function.

        Args:
         - verbose (int): not read in this method; output goes via self.logger.
         - store_compressed (bool): if True,
           self.store_compressed_sequence_pairs() is called at the end.
         - sample_from_profile (bool or 'root'): forwarded to
           seq_utils.prof2seq as `sample_from_prof`. 'root' means True for the
           root and False for all other nodes; any other value is interpreted
           by its truthiness and applied to all nodes.
         - debug (bool): if False, the temporary node attributes
           `marginal_subtree_LH`, `marginal_subtree_LH_prefactor` and
           `seq_msg_from_parent` are deleted before returning.
        KWargs:
         - not read by this method.

        Returns:
         - N_diff (int): number of characters that differ from the previous
           reconstruction, summed over all non-root nodes; a node without a
           previous sequence contributes the full sequence length.
        """

        tree = self.tree
        # number of nucleotides changed from prev reconstruction
        N_diff = 0

        L = self.tree.get_terminals()[0].sequence.shape[0]
        n_states = self.gtr.alphabet.shape[0]
        self.logger(
            "TreeAnc._ml_anc_marginal: type of reconstruction: Marginal", 2)

        self.logger("Walking up the tree, computing likelihoods... ", 3)
        #  set the leaves profiles
        for leaf in tree.get_terminals():
            # in any case, set the profile
            leaf.marginal_subtree_LH = seq_utils.seq2prof(
                leaf.sequence, self.gtr.profile_map)
            leaf.marginal_subtree_LH_prefactor = np.zeros(L)

        # propagate leaves -->> root, set the marginal-likelihood messages
        for node in tree.get_nonterminals(order='postorder'):  #leaves -> root
            # regardless of what was before, set the profile to ones
            node.marginal_subtree_LH_prefactor = np.zeros(L)
            node.marginal_subtree_LH = np.ones(
                (L, n_states))  # we will multiply it
            for ch in node.clades:
                ch.marginal_Lx = self.gtr.propagate_profile(
                    ch.marginal_subtree_LH,
                    self._branch_length_to_gtr(ch),
                    return_log=False)  # raw prob to transfer prob up
                node.marginal_subtree_LH *= ch.marginal_Lx
                node.marginal_subtree_LH_prefactor += ch.marginal_subtree_LH_prefactor

            pre = node.marginal_subtree_LH.sum(
                axis=1)  #sum over nucleotide states
            node.marginal_subtree_LH = (
                node.marginal_subtree_LH.T /
                pre).T  # normalize so that the sum is 1
            node.marginal_subtree_LH_prefactor += np.log(
                pre)  # and store log-prefactor

        self.logger(
            "Computing root node sequence and total tree likelihood...", 3)
        # reconstruct the root node sequence
        tree.root.marginal_subtree_LH *= self.gtr.Pi  # Msg to the root from the distant part (equ frequencies)
        pre = tree.root.marginal_subtree_LH.sum(axis=1)
        tree.root.marginal_profile = (tree.root.marginal_subtree_LH.T / pre).T
        tree.root.marginal_subtree_LH_prefactor += np.log(pre)

        # choose sequence characters from this profile.
        # treat root node differently to avoid piling up mutations on the longer branch
        if sample_from_profile == 'root':
            root_sample_from_profile = True
            other_sample_from_profile = False
        else:
            # covers plain bools as before, and additionally values such as
            # 0/1 or numpy booleans that used to leave both flags undefined
            root_sample_from_profile = bool(sample_from_profile)
            other_sample_from_profile = root_sample_from_profile

        seq, prof_vals, idxs = seq_utils.prof2seq(
            tree.root.marginal_profile,
            self.gtr,
            sample_from_prof=root_sample_from_profile)

        self.tree.sequence_LH = np.log(
            prof_vals) + tree.root.marginal_subtree_LH_prefactor
        self.tree.sequence_marginal_LH = self.tree.sequence_LH.sum()
        self.tree.root.sequence = seq

        # need this fake msg to account for the complementary subtree when traversing tree back
        tree.root.seq_msg_from_parent = np.repeat([self.gtr.Pi],
                                                  len(tree.root.sequence),
                                                  axis=0)

        self.logger(
            "Walking down the tree, computing maximum likelihood sequences...",
            3)
        # propagate root -->> leaves, reconstruct the internal node sequences
        # provided the upstream message + the message from the complementary subtree
        for node in tree.find_clades(order='preorder'):
            if node.up is None:  # skip if node is root
                continue

            # integrate the information coming from parents with the information
            # of all children by multiplying it to the prev computed profile
            tmp_msg = np.copy(node.up.seq_msg_from_parent)
            for c in node.up.clades:
                if c != node:
                    tmp_msg *= c.marginal_Lx
            node.seq_msg_from_parent = self.gtr.propagate_profile(
                tmp_msg, self._branch_length_to_gtr(node), return_log=False)
            node.marginal_profile = node.marginal_subtree_LH * node.seq_msg_from_parent

            # choose sequence based on maximal marginal LH.
            # NOTE(review): per the original author's remark, prof2seq
            # normalizes marginal_profile in place -- confirm in seq_utils.
            seq, prof_vals, idxs = seq_utils.prof2seq(
                node.marginal_profile,
                self.gtr,
                sample_from_prof=other_sample_from_profile)
            node.mutations = [
                (anc, pos, der)
                for pos, (anc, der) in enumerate(izip(node.up.sequence, seq))
                if anc != der
            ]

            if hasattr(node, 'sequence') and node.sequence is not None:
                N_diff += (seq != node.sequence).sum()
            else:
                N_diff += L
            #assign new sequence
            node.sequence = seq

        # note that the root doesn't contribute to N_diff (intended, since root sequence is often ambiguous)
        self.logger("TreeAnc._ml_anc_marginal: ...done", 3)
        if store_compressed:
            self.store_compressed_sequence_pairs()

        # do clean-up:
        if not debug:
            for node in self.tree.find_clades():
                del node.marginal_subtree_LH
                del node.marginal_subtree_LH_prefactor
                del node.seq_msg_from_parent

        return N_diff