Example #1
    def test_connected_sets(self):
        """Directed"""
        cc = connected_sets(self.C)
        for i in range(len(cc)):
            self.assertTrue(np.all(self.cc_directed[i] == np.sort(cc[i])))
        """Undirected"""
        cc = connected_sets(self.C, directed=False)
        for i in range(len(cc)):
            self.assertTrue(np.all(self.cc_undirected[i] == np.sort(cc[i])))
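A minimal usage sketch of the connected_sets function this test exercises (assuming deeptime is installed; the tiny count matrix is made up for illustration):

import numpy as np
from deeptime.markov.tools.estimation import connected_sets

# states 0 and 1 exchange counts; state 2 only receives counts from state 1
C = np.array([[10, 1, 0],
              [2, 10, 3],
              [0, 0, 10]])

print(connected_sets(C))                  # directed: [array([0, 1]), array([2])]
print(connected_sets(C, directed=False))  # undirected: [array([0, 1, 2])]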
Example #2
def compute_connected_sets(count_matrix,
                           connectivity_threshold: float = 0,
                           directed=True):
    """ Computes the connected sets of a count matrix C.

    C : (N, N) np.ndarray
        count matrix
    mincount_connectivity : float
        Minimum count required to be included in the connected set computation.
    directed : boolean
        True: Seek connected sets in the directed graph. False: Seek connected sets in the undirected graph.
    Returns
    -------
    A list of arrays, each array representing a connected set by enumerating the respective states. The list is in
    descending order by size of connected set.
    """
    import numpy as np
    import scipy.sparse as scs
    import deeptime.markov.tools.estimation as msmest
    if connectivity_threshold > 0:
        if scs.issparse(count_matrix):
            Cconn = count_matrix.tocsr(copy=True)
            Cconn.data[Cconn.data < connectivity_threshold] = 0
            Cconn.eliminate_zeros()
        else:
            Cconn = count_matrix.copy()
            Cconn[np.where(Cconn < connectivity_threshold)] = 0
    else:
        Cconn = count_matrix
    # decompose the thresholded count matrix into connected sets
    S = msmest.connected_sets(Cconn, directed=directed)
    return S
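A hypothetical call of the function above, showing how connectivity_threshold prunes weak transitions before the decomposition (the matrix is made up for illustration):

import numpy as np

C = np.array([[5, 2, 0],
              [2, 5, 1],
              [0, 1, 5]])

# without a threshold all three states form one directed connected set;
# with connectivity_threshold=2 the single count linking states 1 and 2 is dropped
print(compute_connected_sets(C))                            # [array([0, 1, 2])]
print(compute_connected_sets(C, connectivity_threshold=2))  # [array([0, 1]), array([2])]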
Example #3
    def test_multiple_components(self):
        L = np.array([[0, 1, 1, 0, 0, 0, 0], [1, 0, 1, 0, 0, 0, 0],
                      [1, 1, 0, 0, 0, 0, 0], [0, 0, 0, 0, 1, 1, 1],
                      [0, 0, 0, 1, 0, 1, 0], [0, 0, 0, 1, 1, 0, 1],
                      [0, 0, 0, 1, 0, 1, 0]])

        transition_matrix = L / np.sum(L, 1).reshape(-1, 1)
        pi = np.zeros((transition_matrix.shape[0], ))
        for cs in connected_sets(transition_matrix):
            P_sub = transition_matrix[cs, :][:, cs]
            pi[cs] = stationary_distribution(P_sub)
        chi = pcca(transition_matrix, 2, pi=pi)
        expected = np.array([[0., 1.], [0., 1.], [0., 1.], [1., 0.], [1., 0.],
                             [1., 0.], [1., 0.]])
        np.testing.assert_equal(chi, expected)
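One convention worth noting in this test: pi is assembled blockwise, so it is normalized per connected set rather than globally. An added sanity check (not part of the original test) would be:

# the two components are {0, 1, 2} and {3, 4, 5, 6}, so each block sums to 1
np.testing.assert_allclose(pi[:3].sum(), 1.0)
np.testing.assert_allclose(pi[3:].sum(), 1.0)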
Example #4
    def _compute_connected_sets(C, mincount_connectivity, strong=True):
        """ Computes the connected sets of C.

        C : count matrix
        mincount_connectivity : float
            Minimum count which counts as a connection.
        strong : boolean
            True: Seek strongly connected sets. False: Seek weakly connected sets.
        Returns
        -------
        Cconn, S
        """
        import scipy.sparse as scs
        if scs.issparse(C):
            Cconn = C.tocsr(copy=True)
            Cconn.data[Cconn.data < mincount_connectivity] = 0
            Cconn.eliminate_zeros()
        else:
            Cconn = C.copy()
            Cconn[np.where(Cconn < mincount_connectivity)] = 0

        # decompose the thresholded count matrix into connected sets
        S = connected_sets(Cconn, directed=strong)
        return S
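A hypothetical illustration of the strong flag, calling the helper directly (this assumes it is accessible, e.g. as a static method; outputs reflect deeptime's size-descending ordering):

import numpy as np

C = np.array([[1, 5],
              [0, 1]])  # counts flow from state 0 to 1, never back

_compute_connected_sets(C, 0.5, strong=True)   # -> [array([0]), array([1])] (order may vary)
_compute_connected_sets(C, 0.5, strong=False)  # -> [array([0, 1])]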
Example #5
    def _estimate(self, dtrajs):
        if self.E is None or self.w is None or self.m is None:
            raise ValueError("E, w or m was not specified. Stopping.")

        # get trajectory counts. This sets _C_full and _nstates_full
        dtrajstats = self._get_dtraj_stats(dtrajs)
        self._C_full = dtrajstats.count_matrix()  # full count matrix
        self._nstates_full = self._C_full.shape[0]  # number of states

        # set active set. This is at the same time a mapping from active to full
        if self.connectivity == 'largest':
            # statdist not given - full connectivity on all states
            self.active_set = dtrajstats.largest_connected_set
        else:
            # for 'None' and 'all' all visited states are active
            self.active_set = dtrajstats.visited_set

        # FIXME: setting is_estimated before so that we can start using the parameters just set, but this is not clean!
        # is estimated
        self._is_estimated = True

        # if active set is empty, we can't do anything.
        if _np.size(self.active_set) == 0:
            raise RuntimeError('Active set is empty. Cannot estimate AMM.')

        # active count matrix and number of states
        self._C_active = dtrajstats.count_matrix(subset=self.active_set)
        self._nstates = self._C_active.shape[0]

        # compute derived quantities
        # back-mapping from full to lcs
        self._full2active = -1 * _np.ones(dtrajstats.nstates, dtype=int)
        self._full2active[self.active_set] = _np.arange(len(self.active_set))

        # slice out active states from E matrix
        _dset = list(set(_np.concatenate(self._dtrajs_full)))
        _rras = [_dset.index(s) for s in self.active_set]
        self.E_active = self.E[_rras]

        if not self.sparse:
            self._C_active = self._C_active.toarray()
            self._C_full = self._C_full.toarray()

        # reversibly counted
        self._C2 = 0.5 * (self._C_active + self._C_active.T)
        self._nz = _np.nonzero(self._C2)
        self._csum = _np.sum(self._C_active, axis=1)  # row sums C

        # get ranges of Markov model expectation values
        if self.support_ci == 1:
            self.E_min = _np.min(self.E_active, axis=0)
            self.E_max = _np.max(self.E_active, axis=0)
        else:
            # PyEMMA confidence interval calculation fails sometimes with conf=1.0
            self.E_min, self.E_max = _ci(self.E_active, conf=self.support_ci)

        # dimensions of E matrix
        self.n_mstates_active, self.n_exp_active = _np.shape(self.E_active)

        assert self.n_exp_active == len(self.w)
        assert self.n_exp_active == len(self.m)

        self.count_outside = []
        self.count_inside = []
        self._lls = []

        # determine which experimental values lie outside the support defined by the confidence interval
        for i, (emi, ema, mm, mw) in enumerate(
                zip(self.E_min, self.E_max, self.m, self.w)):
            if mm < emi or ema < mm:
                self.logger.info(
                    "Experimental value %f is outside the support (%f,%f)" %
                    (mm, emi, ema))
                self.count_outside.append(i)
            else:
                self.count_inside.append(i)

        self.logger.info(
            "Total experimental constraints outside support %d of %d" %
            (len(self.count_outside), len(self.E_min)))

        # A number of initializations
        self.P, self.pi = transition_matrix(self._C_active,
                                            reversible=True,
                                            return_statdist=True)
        self.lagrange = _np.zeros(self.m.shape)
        self._pihat = self.pi.copy()
        self._update_mhat()
        self._dmhat = 1e-1 * _np.ones(_np.shape(self.mhat))

        # Determine number of slices of R-tensors computable at once with the given cache size
        self._slicesz = _np.floor(self.max_cache /
                                  (self.P.nbytes / 1.e6)).astype(int)
        # compute first bundle of slices
        self._update_Rslices(0)

        self._ll_old = self._log_likelihood_biased(self._C_active, self.P,
                                                   self.m, self.mhat, self.w)

        self._lls = [self._ll_old]

        # make sure everything is initialized

        self._update_pihat()
        self._update_mhat()

        self._update_Q()
        self._update_X_and_pi()

        self._ll_old = self._log_likelihood_biased(self._C_active, self.P,
                                                   self.m, self.mhat, self.w)
        self._update_G()

        #
        # Main estimation algorithm
        # Two-step algorithm: the Lagrange multipliers and pihat have different convergence criteria.
        # Once the Lagrange multipliers have converged, pihat is updated until the log-likelihood has
        # converged (changes smaller than 1e-3). The two do not always converge together, but usually
        # do within a few steps of each other. A better heuristic for the latter may be necessary. For
        # realistic cases (the two ubiquitin examples in [1]) this yielded results very similar to those
        # obtained with more stringent convergence criteria (changes smaller than 1e-9), with convergence
        # times of seconds instead of tens of minutes.
        #

        converged = False  # Convergence flag for lagrange multipliers
        i = 0
        die = False
        while i <= self.maxiter:
            pihat_old = self._pihat.copy()
            self._update_pihat()
            if not _np.all(self._pihat > 0):
                self._pihat = pihat_old.copy()
                die = True
                self.logger.warning(
                    "pihat does not have a finite probability for all states, terminating"
                )
            self._update_mhat()
            self._update_Q()
            if i > 1:
                X_old = self.X.copy()
                self._update_X_and_pi()
                if _np.any(self.X[self._nz] < 0):
                    die = True
                    self.logger.warning(
                        "New X is not proportional to C... reverting to previous step and terminating"
                    )
                    self.X = X_old.copy()

            if not converged:
                self._newton_lagrange()
            else:  # once Lagrange multipliers are converged compute likelihood here
                P = self.X / self.pi[:, None]
                _ll_new = self._log_likelihood_biased(self._C_active, P,
                                                      self.m, self.mhat,
                                                      self.w)
                self._lls.append(_ll_new)

            # General case fixed-point iteration
            if len(self.count_outside) > 0:
                if i > 1 and _np.all(
                    (_np.abs(self._dmhat) /
                     self.sigmas) < self.eps) and not converged:
                    self.logger.info(
                        "Converged Lagrange multipliers after %i steps..." % i)
                    converged = True
            # Special case
            else:
                if _np.abs(self._lls[-2] - self._lls[-1]) < 1e-8:
                    self.logger.info(
                        "Converged Lagrange multipliers after %i steps..." % i)
                    converged = True
            # if Lagrange multipliers are converged, check whether log-likelihood has converged
            if converged and _np.abs(self._lls[-2] - self._lls[-1]) < 1e-8:
                self.logger.info("Converged pihat after %i steps..." % i)
                die = True
            if die:
                break
            if i == self.maxiter:
                self.logger.info("Failed to converge within %i iterations. "
                                 "Consider increasing maxiter (now=%i)" %
                                 (i, self.maxiter))
            i += 1

        _P = transition_matrix(self._C_active, reversible=True, mu=self._pihat)

        self._connected_sets = connected_sets(self._C_full)
        self.set_model_params(P=_P,
                              pi=self._pihat,
                              reversible=True,
                              dt_model=self.timestep_traj.get_scaled(self.lag))

        return self
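Driving this estimator from user code would look roughly like the following (a hypothetical sketch of PyEMMA's high-level AMM entry point; the exact function name and argument names may differ between versions, so check the installed documentation):

import pyemma

# dtrajs: list of discrete trajectories; ftrajs: per-frame feature trajectories
# m, sigmas: experimental averages and their uncertainties for each feature
amm = pyemma.msm.estimate_augmented_markov_model(dtrajs, ftrajs, lag=10,
                                                 m=m, sigmas=sigmas)
print(amm.pi)  # stationary distribution biased toward the experimental data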
Example #6
import numpy as np
def pcca(P: np.ndarray,
         m: int,
         pi: np.ndarray = None,
         transition_matrix_tol: float = 1e-12):
    """
    PCCA+ spectral clustering method with optimized memberships [1]_

    Clusters the first m eigenvectors of a transition matrix in order to cluster the states.
    This function does not assume that the transition matrix is fully connected. Disconnected sets
    will automatically define the first metastable states, with perfect membership assignments.

    Parameters
    ----------
    P : ndarray (n,n)
        Transition matrix.
    m : int
        Number of clusters to group to.
    pi : ndarray(n,), optional, default=None
        Stationary distribution if available. Should be defined piecewise over the connected sets.
    transition_matrix_tol : float, optional, default=1e-12
        Tolerance under which P is checked to be a transition matrix.

    Returns
    -------
    chi : ndarray (n x m)
        A matrix containing the probability or membership of each state to be assigned to each cluster.
        The rows sum to 1.

    References
    ----------
    [1] S. Roeblitz and M. Weber, Fuzzy spectral clustering by PCCA+:
        application to Markov state models and data classification.
        Adv Data Anal Classif 7, 147-179 (2013).
    [2] F. Noe, multiset PCCA and HMMs, in preparation.
    """
    # imports
    from deeptime.markov.tools.estimation import connected_sets
    from deeptime.markov.tools.analysis import eigenvalues, is_transition_matrix, hitting_probability

    # validate input
    n = np.shape(P)[0]
    if m > n:
        raise ValueError(
            f"Number of metastable states m={m} exceeds number of states of transition matrix n={n}"
        )
    if not is_transition_matrix(P, tol=transition_matrix_tol):
        raise ValueError("Input matrix is not a transition matrix.")
    if pi is not None and pi.ndim > 1:
        raise ValueError(
            "Stationary distribution must be given as one-dimensional array or left None."
        )
    if pi is not None and pi.shape[0] != n:
        raise ValueError(
            f"Stationary distribution must be defined on entire space, piecewise if the transition matrix "
            f"has multiple connected components. It covered {pi.shape[0]} != {n} states."
        )

    # prepare output
    chi = np.zeros((n, m))

    # test connectivity
    components = connected_sets(P)
    n_components = len(components)

    # store components as closed (with positive equilibrium distribution)
    # or as transition states (with vanishing equilibrium distribution)
    closed_components = []
    transition_states = []
    for i in range(n_components):
        component = components[i]
        rest = list(set(range(n)) - set(component))
        # is component closed?
        if np.sum(P[component, :][:, rest]) == 0:
            closed_components.append(component)
        else:
            transition_states.append(component)
    n_closed_components = len(closed_components)
    closed_states = np.concatenate(closed_components)
    if len(transition_states) == 0:
        transition_states = np.array([], dtype=int)
    else:
        transition_states = np.concatenate(transition_states)

    # check if we have enough clusters to support the disconnected sets
    if m < len(closed_components):
        raise ValueError(
            f"Number of metastable states m={m} is too small. Transition matrix "
            f"has {len(closed_components)} disconnected components.")

    # collect the eigenvalues of each closed component in order to decide how many clusters each one receives
    closed_components_Psub = []
    closed_components_ev = []
    closed_components_enum = []
    for i in range(n_closed_components):
        component = closed_components[i]

        # compute eigenvalues in submatrix
        Psub = P[component, :][:, component]
        closed_components_Psub.append(Psub)
        closed_components_ev.append(eigenvalues(Psub))
        closed_components_enum.append(i * np.ones((component.size,), dtype=int))

    # flatten
    closed_components_ev_flat = np.hstack(closed_components_ev)
    closed_components_enum_flat = np.hstack(closed_components_enum)
    # which components should be clustered?
    component_indexes = closed_components_enum_flat[np.argsort(
        closed_components_ev_flat)][0:m]
    # cluster each component
    ipcca = 0
    for i in range(n_closed_components):
        component = closed_components[i]
        # how many PCCA states in this component?
        m_by_component = np.shape(np.argwhere(component_indexes == i))[0]

        # if 1, then the result is trivial
        if m_by_component == 1:
            chi[component, ipcca] = 1.0
            ipcca += 1
        elif m_by_component > 1:
            # print "submatrix: ",closed_components_Psub[i]
            chi[component, ipcca:ipcca + m_by_component] = _pcca_connected(
                closed_components_Psub[i],
                m_by_component,
                pi=None if pi is None else pi[closed_components[i]])
            ipcca += m_by_component
        else:
            raise RuntimeError(
                f"Component {i} spuriously has {m_by_component} pcca sets")

    # finally assign all transition states
    if transition_states.size > 0:
        # make all closed states absorbing, so we can see which closed state we hit first
        Pabs = P.copy()
        Pabs[closed_states, :] = 0.0
        Pabs[closed_states, closed_states] = 1.0
        for i in range(closed_states.size):
            # hitting probability to each closed state
            h = hitting_probability(Pabs, closed_states[i])
            for j in range(transition_states.size):
                # transition states belong to closed states with the hitting probability, and inherit their chi
                chi[transition_states[j]] += h[transition_states[j]] * chi[
                    closed_states[i]]

    # check if we have m metastable sets. If less than m, we must raise
    nmeta = np.count_nonzero(chi.sum(axis=0))
    if nmeta < m:
        raise RuntimeError(
            f"{m} metastable states requested, but the transition matrix only has {nmeta}. "
            f"Consider using a prior or requesting fewer metastable states.")

    return chi
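A quick sanity check for the disconnected case handled above (a sketch that assumes the module's helpers are importable; compare Example #3):

import numpy as np

P = np.array([[0.9, 0.1, 0.0, 0.0],
              [0.1, 0.9, 0.0, 0.0],
              [0.0, 0.0, 0.8, 0.2],
              [0.0, 0.0, 0.2, 0.8]])

chi = pcca(P, 2)
# each closed component becomes one metastable set with membership 1,
# i.e. chi == [[1, 0], [1, 0], [0, 1], [0, 1]] up to column order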
Example #7
import math
import warnings
import numpy as np
def _pcca_connected(P, n, pi=None):
    r"""PCCA+ spectral clustering method with optimized memberships [1]_

    Clusters the first n eigenvectors of a transition matrix in order to cluster the states.
    This function assumes that the transition matrix is fully connected.

    Parameters
    ----------
    P : ndarray (n,n)
        Transition matrix.
    n : int
        Number of clusters to group to.
    pi: ndarray(n,), optional, default=None
        Stationary distribution if available.

    Returns
    -------
    chi : ndarray (n_states, n)
        A matrix containing the probability or membership of each state to be assigned to each cluster.
        The rows sum to 1.

    References
    ----------
    [1] S. Roeblitz and M. Weber, Fuzzy spectral clustering by PCCA+:
        application to Markov state models and data classification.
        Adv Data Anal Classif 7, 147-179 (2013).
    """

    # test connectivity
    from deeptime.markov.tools.estimation import connected_sets

    labels = connected_sets(P)
    n_components = len(labels)
    if n_components > 1:
        raise ValueError(
            "Transition matrix is disconnected. Cannot use pcca_connected.")

    if pi is None:
        from deeptime.markov.tools.analysis import stationary_distribution
        pi = stationary_distribution(P)
    else:
        if pi.shape[0] != P.shape[0]:
            raise ValueError(
                f"Stationary distribution must span entire state space but got {pi.shape[0]} states "
                f"instead of {P.shape[0]}.")
        pi /= pi.sum()  # make sure it is normalized

    from deeptime.markov.tools.analysis import is_reversible

    if not is_reversible(P, mu=pi):
        raise ValueError(
            "Transition matrix does not fulfill detailed balance. "
            "Make sure to call pcca with a reversible transition matrix estimate"
        )
    # TODO: Susanna mentioned that she has a potential fix for nonreversible matrices by replacing each complex conjugate
    #      pair by the real and imaginary components of one of the two vectors. We could use this but would then need to
    #      orthonormalize all eigenvectors e.g. using Gram-Schmidt orthonormalization. Currently there is no theoretical
    #      foundation for this, so I'll skip it for now.

    # right eigenvectors, ordered
    from deeptime.markov.tools.analysis import eigenvectors

    evecs = eigenvectors(P, n)

    # orthonormalize
    for i in range(n):
        evecs[:, i] /= math.sqrt(np.dot(evecs[:, i] * pi, evecs[:, i]))
    # make first eigenvector positive
    evecs[:, 0] = np.abs(evecs[:, 0])

    # Is there a significant complex component?
    if not np.all(np.isreal(evecs)):
        warnings.warn(
            "The given transition matrix has complex eigenvectors, so it doesn't exactly fulfill detailed balance. "
            "Forcing eigenvectors to be real and continuing. Be aware that this is not theoretically solid."
        )
    evecs = np.real(evecs)

    # create initial solution using PCCA+. This could have negative memberships
    chi, rot_matrix = _pcca_connected_isa(evecs, n)

    # optimize the rotation matrix with PCCA+.
    rot_matrix = _opt_soft(evecs, rot_matrix, n)

    # These memberships should be nonnegative
    memberships = np.dot(evecs, rot_matrix)

    # We might still have numerical errors. Force memberships to be in [0,1]
    memberships = np.clip(memberships, 0., 1.)

    # renormalize each row to sum to 1
    memberships /= memberships.sum(axis=1, keepdims=True)

    return memberships
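For a fully connected, reversible matrix the helper can be exercised directly (a hypothetical sketch; in normal use it is reached through pcca from Example #6):

import numpy as np

P = np.array([[0.9, 0.1],
              [0.1, 0.9]])

chi = _pcca_connected(P, 2)
# two states, two clusters: memberships are (close to) the identity matrix,
# and every row sums to 1 by construction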
Example #8
def _compute_csets(
    connectivity, state_counts, count_matrices, ttrajs, dtrajs, bias_trajs, nn,
    equilibrium_state_counts=None, factor=1.0, callback=None):
    n_therm_states, n_conf_states = state_counts.shape

    if equilibrium_state_counts is not None:
        all_state_counts = state_counts + equilibrium_state_counts
    else:
        all_state_counts = state_counts

    if connectivity is None:
        cset_projected = _np.where(all_state_counts.sum(axis=0) > 0)[0]
        csets = [_np.where(all_state_counts[k, :] > 0)[0] for k in range(n_therm_states)]
        return csets, cset_projected
    elif connectivity == 'summed_count_matrix':
        # assume that two thermodynamic states overlap when there are samples from both
        # ensembles in some Markov state
        C_sum = count_matrices.sum(axis=0)
        if equilibrium_state_counts is not None:
            eq_states = _np.where(equilibrium_state_counts.sum(axis=0) > 0)[0]
            C_sum[eq_states, eq_states[:, _np.newaxis]] = 1
        cset_projected = largest_connected_set(C_sum, directed=True)
        csets = []
        for k in range(n_therm_states):
            cset = _np.intersect1d(_np.where(all_state_counts[k, :] > 0), cset_projected)
            csets.append(cset)
        return csets, cset_projected
    elif connectivity == 'reversible_pathways' or connectivity == 'largest':
        C_proxy = _np.zeros((n_conf_states, n_conf_states), dtype=int)
        for C in count_matrices:
            for comp in connected_sets(C, directed=True):
                C_proxy[comp[0:-1], comp[1:]] = 1 # add chain of states
        if equilibrium_state_counts is not None:
            eq_states = _np.where(equilibrium_state_counts.sum(axis=0) > 0)[0]
            C_proxy[eq_states, eq_states[:, _np.newaxis]] = 1
        cset_projected = largest_connected_set(C_proxy, directed=False)
        csets = []
        for k in range(n_therm_states):
            cset = _np.intersect1d(_np.where(all_state_counts[k, :] > 0), cset_projected)
            csets.append(cset)
        return csets, cset_projected
    elif connectivity in ['neighbors', 'post_hoc_RE', 'BAR_variance']:
        dim = n_therm_states * n_conf_states
        if connectivity == 'post_hoc_RE' or connectivity == 'BAR_variance':
            if connectivity == 'post_hoc_RE':
                overlap = _util._overlap_post_hoc_RE
            else:
                overlap = _overlap_BAR_variance
            i_s = []
            j_s = []
            for i in range(n_conf_states):
                # can take a very long time, allow to report progress via callback
                if callback is not None:
                    callback(maxiter=n_conf_states, iteration_step=i)
                therm_states = _np.where(all_state_counts[:, i] > 0)[0] # therm states that have samples
                # prepare list of indices for all thermodynamic states
                traj_indices = {}
                frame_indices = {}
                for k in therm_states:
                    frame_indices[k] = [_np.where(
                        _np.logical_and(d == i, t == k))[0] for t, d in zip(ttrajs, dtrajs)]
                    traj_indices[k] = [j for j, fi in enumerate(frame_indices[k]) if len(fi) > 0]
                for k in therm_states:
                    for l in therm_states:
                        if k != l:
                            kl = _np.array([k, l])
                            a = _np.concatenate([
                                bias_trajs[j][:, kl][frame_indices[k][j], :] for j in traj_indices[k]])
                            b = _np.concatenate([
                                bias_trajs[j][:, kl][frame_indices[l][j], :] for j in traj_indices[l]])
                            if overlap(a, b, factor=factor):
                                x = i + k * n_conf_states
                                y = i + l * n_conf_states
                                i_s.append(x)
                                j_s.append(y)
        else: # assume overlap between nn neighboring umbrellas
            assert nn is not None, 'With connectivity="neighbors", nn cannot be None.'
            assert 1 <= nn <= n_therm_states - 1
            i_s = []
            j_s = []
            # connectivity between thermodynamic states
            for l in range(1, nn + 1):
                if callback is not None:
                    callback(maxiter=nn, iteration_step=l)
                for k in range(n_therm_states - l):
                    w = _np.where(_np.logical_and(
                        all_state_counts[k, :] > 0, all_state_counts[k + l, :] > 0))[0]
                    a = w + k * n_conf_states
                    b = w + (k + l) * n_conf_states
                    i_s += list(a)
                    j_s += list(b)

        # connectivity between conformational states:
        # just copy it from the count matrices
        for k in range(n_therm_states):
            for comp in connected_sets(count_matrices[k, :, :], directed=True):
                # add chain that links all states in the component
                i_s += list(comp[0:-1] + k * n_conf_states)
                j_s += list(comp[1:]   + k * n_conf_states)

        # If there is global equilibrium data, assume full connectivity
        # between all visited conformational states within the same thermodynamic state.
        if equilibrium_state_counts is not None:
            for k in range(n_therm_states):
                vertices = _np.where(equilibrium_state_counts[k, :]>0)[0]
                # add bidirectional chain that links all states
                chain = (vertices[0:-1], vertices[1:])
                i_s += list(chain[0] + k * n_conf_states)
                j_s += list(chain[1] + k * n_conf_states)

        data = _np.ones(len(i_s), dtype=int)
        A = _sp.sparse.coo_matrix((data, (i_s, j_s)), shape=(dim, dim))
        cset = largest_connected_set(A, directed=False)
        # group by thermodynamic state
        cset = _np.unravel_index(cset, (n_therm_states, n_conf_states), order='C')
        csets = [[] for _ in range(n_therm_states)]
        for k, i in zip(*cset):
            csets[k].append(i)
        csets = [_np.array(c, dtype=int) for c in csets]
        projected_cset = _np.unique(_np.concatenate(csets))
        return csets, projected_cset
    else:
        raise ValueError(
            'Unknown value "%s" of connectivity. Should be one of: None, '
            'summed_count_matrix, reversible_pathways, largest, neighbors, '
            'post_hoc_RE or BAR_variance.' % connectivity)
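The simplest branch above (connectivity=None) just returns the visited states per thermodynamic state. A hypothetical call with made-up inputs (the trajectory arguments are unused on this path):

import numpy as np

state_counts = np.array([[5, 3, 0],
                         [0, 2, 4]])  # 2 thermodynamic x 3 conformational states
count_matrices = np.zeros((2, 3, 3), dtype=int)

csets, projected = _compute_csets(None, state_counts, count_matrices,
                                  None, None, None, None)
# csets -> [array([0, 1]), array([1, 2])]; projected -> array([0, 1, 2])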
Example #9
    def _estimate(self, dtrajs):
        """ Estimate MSM """

        if self.core_set is not None:
            raise NotImplementedError(
                'Core set MSMs currently not compatible with {}.'.format(
                    self.__class__.__name__))

        # remove last lag steps from dtrajs:
        dtrajs_lag = [traj[:-self.lag] for traj in dtrajs]

        # get trajectory counts. This sets _C_full and _nstates_full
        dtrajstats = self._get_dtraj_stats(dtrajs_lag)
        self._C_full = dtrajstats.count_matrix()  # full count matrix
        self._nstates_full = self._C_full.shape[0]  # number of states

        # set active set. This is at the same time a mapping from active to full
        if self.connectivity == 'largest':
            self.active_set = dtrajstats.largest_connected_set
        else:
            raise NotImplementedError(
                'OOM based MSM estimation is only implemented for connectivity=\'largest\'.'
            )

        # FIXME: setting is_estimated before so that we can start using the parameters just set, but this is not clean!
        # is estimated
        self._is_estimated = True

        # if active set is empty, we can't do anything.
        if _np.size(self.active_set) == 0:
            raise RuntimeError('Active set is empty. Cannot estimate MSM.')

        # active count matrix and number of states
        self._C_active = dtrajstats.count_matrix(subset=self.active_set)
        self._nstates = self._C_active.shape[0]

        # compute derived quantities
        # back-mapping from full to lcs
        self._full2active = -1 * _np.ones(dtrajstats.nstates, dtype=int)
        self._full2active[self.active_set] = _np.arange(len(self.active_set))

        # Estimate transition matrix
        if self.connectivity == 'largest':
            # Re-sampling:
            if self.rank_Ct == 'bootstrap_counts':
                Ceff_full = effective_count_matrix(dtrajs_lag, self.lag)
                from pyemma.util.linalg import submatrix
                Ceff = submatrix(Ceff_full, self.active_set)
                smean, sdev = bootstrapping_count_matrix(Ceff, nbs=self.nbs)
            else:
                smean, sdev = bootstrapping_dtrajs(dtrajs_lag,
                                                   self.lag,
                                                   self._nstates_full,
                                                   nbs=self.nbs,
                                                   active_set=self._active_set)
            # Estimate two step count matrices:
            C2t = twostep_count_matrix(dtrajs, self.lag, self._nstates_full)
            # Rank decision:
            rank_ind = rank_decision(smean, sdev, tol=self.tol_rank)
            # Estimate OOM components:
            Xi, omega, sigma, l = oom_components(self._C_full.toarray(),
                                                 C2t,
                                                 rank_ind=rank_ind,
                                                 lcc=self.active_set)
            # Compute transition matrix:
            P, lcc_new = equilibrium_transition_matrix(
                Xi, omega, sigma, reversible=self.reversible)
        else:
            raise NotImplementedError(
                'OOM based MSM estimation is only implemented for connectivity=\'largest\'.'
            )

        # Update active set and derived quantities:
        if lcc_new.size < self._nstates:
            self._active_set = self._active_set[lcc_new]
            self._C_active = dtrajstats.count_matrix(subset=self.active_set)
            self._nstates = self._C_active.shape[0]
            self._full2active = -1 * _np.ones(dtrajstats.nstates, dtype=int)
            self._full2active[self.active_set] = _np.arange(
                len(self.active_set))
            warnings.warn(
                "Caution: Re-estimation of count matrix resulted in reduction of the active set."
            )

        # continue sparse or dense?
        if not self.sparse:
            # converting count matrices to arrays. As a result the
            # transition matrix and all subsequent properties will be
            # computed using dense arrays and dense matrix algebra.
            self._C_full = self._C_full.toarray()
            self._C_active = self._C_active.toarray()

        # Done. We set our own model parameters, so this estimator is
        # equal to the estimated model.
        self._dtrajs_full = dtrajs
        self._connected_sets = connected_sets(self._C_full)
        self._Xi = Xi
        self._omega = omega
        self._sigma = sigma
        self._eigenvalues_OOM = l
        self._rank_ind = rank_ind
        self._oom_rank = self._sigma.size
        self._C2t = C2t
        self.set_model_params(P=P,
                              pi=None,
                              reversible=self.reversible,
                              dt_model=self.timestep_traj.get_scaled(self.lag))

        return self
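In user code this estimator is typically not called directly; something like the following sketch applies (assuming PyEMMA's OOMReweightedMSM estimator class; the exact name and defaults may vary by version):

import pyemma

# dtrajs: list of discrete trajectories
msm = pyemma.msm.OOMReweightedMSM(lag=10).estimate(dtrajs)
print(msm.nstates, msm.timescales()[:3])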