# Example #1
def metastable_from_msm(msm,
                        n_hidden_states: int,
                        reversible: bool = True,
                        stationary: bool = False,
                        separate_symbols=None,
                        regularize: bool = True):
    r""" Makes an initial guess for an :class:`HMM <deeptime.markov.hmm.HiddenMarkovModel>` with
    discrete output model from an already existing MSM over observable states. The procedure is described in
    :footcite:`noe2013projected` and uses PCCA+ :footcite:`roblitz2013fuzzy` for
    coarse-graining the transition matrix and obtaining membership assignments.

    Parameters
    ----------
    msm : MarkovStateModel
        The markov state model over observable state space.
    n_hidden_states : int
        The desired number of hidden states.
    reversible : bool, optional, default=True
        Whether the HMM transition matrix is estimated so that it is reversible.
    stationary : bool, optional, default=False
        If True, the initial distribution of hidden states is self-consistently computed as the stationary
        distribution of the transition matrix. If False, it will be estimated from the starting states.
        Only set this to true if you're sure that the observation trajectories are initiated from a global
        equilibrium distribution.
    separate_symbols : array_like, optional, default=None
        Force the given set of observed states to stay in a separate hidden state.
        The remaining nstates-1 states will be assigned by a metastable decomposition.
    regularize : bool, optional, default=True
        If set to True, makes sure that the hidden initial distribution and transition matrix have nonzero probabilities
        by setting them to eps and then renormalizing. Avoids zeros that would cause estimation algorithms to crash or
        get stuck in suboptimal states.

    Returns
    -------
    hmm_init : HiddenMarkovModel
        An initial guess for the HMM

    See Also
    --------
    deeptime.markov.hmm.DiscreteOutputModel
        The type of output model this heuristic uses.

    :func:`metastable_from_data`
        Initial guess from data if no MSM is available yet.

    :func:`deeptime.markov.hmm.init.gaussian.from_data`
        Initial guess with :class:`Gaussian output model <deeptime.markov.hmm.GaussianOutputModel>`.

    References
    ----------
    .. footbibliography::
    """
    # Local imports keep module import time low and avoid circular imports.
    from deeptime.markov._transition_matrix import stationary_distribution
    from deeptime.markov._transition_matrix import estimate_P
    from deeptime.markov.msm import MarkovStateModel
    from deeptime.markov import PCCAModel

    count_matrix = msm.count_model.count_matrix
    # Start by assuming every observed symbol belongs to the "non-separate" pool;
    # this gets narrowed down below if separate_symbols was provided.
    nonseparate_symbols = np.arange(msm.count_model.n_states_full)
    nonseparate_states = msm.count_model.symbols_to_states(nonseparate_symbols)
    nonseparate_msm = msm
    if separate_symbols is not None:
        separate_symbols = np.asanyarray(separate_symbols)
        if np.max(separate_symbols) >= msm.count_model.n_states_full:
            raise ValueError(f'Separate set has indices that do not exist in '
                             f'full state space: {np.max(separate_symbols)}')
        # Remove the user-designated separate symbols from the pool and re-estimate
        # an MSM restricted to the remaining (non-separate) states.
        nonseparate_symbols = np.setdiff1d(nonseparate_symbols,
                                           separate_symbols)
        nonseparate_states = msm.count_model.symbols_to_states(
            nonseparate_symbols)
        nonseparate_count_model = msm.count_model.submodel(nonseparate_states)
        # make reversible
        nonseparate_count_matrix = nonseparate_count_model.count_matrix
        if issparse(nonseparate_count_matrix):
            nonseparate_count_matrix = nonseparate_count_matrix.toarray()
        # PCCA+ below requires a reversible transition matrix, so enforce
        # reversibility on the restricted count matrix regardless of `reversible`.
        P_nonseparate = estimate_P(nonseparate_count_matrix, reversible=True)
        pi = stationary_distribution(P_nonseparate, C=nonseparate_count_matrix)
        nonseparate_msm = MarkovStateModel(P_nonseparate,
                                           stationary_distribution=pi)
    if issparse(count_matrix):
        count_matrix = count_matrix.toarray()

    # if #metastable sets == #states, we can stop here
    n_meta = n_hidden_states if separate_symbols is None else n_hidden_states - 1
    if n_meta == nonseparate_msm.n_states:
        # Trivial decomposition: identity memberships/distributions, no PCCA+ needed.
        pcca = PCCAModel(nonseparate_msm.transition_matrix,
                         nonseparate_msm.stationary_distribution,
                         np.eye(n_meta), np.eye(n_meta))
    else:
        pcca = nonseparate_msm.pcca(n_meta)
    if separate_symbols is not None:
        # Stitch memberships back together: PCCA+ memberships fill the first
        # n_hidden_states-1 columns, the separate states get the last column.
        separate_states = msm.count_model.symbols_to_states(separate_symbols)
        memberships = np.zeros((msm.n_states, n_hidden_states))
        memberships[nonseparate_states, :n_hidden_states -
                    1] = pcca.memberships
        memberships[separate_states, -1] = 1
    else:
        memberships = pcca.memberships
        separate_states = None

    # Project the observable-space transition matrix onto the hidden states.
    hidden_transition_matrix = _coarse_grain_transition_matrix(
        msm.transition_matrix, memberships)
    if reversible:
        from deeptime.markov._transition_matrix import enforce_reversible_on_closed
        hidden_transition_matrix = enforce_reversible_on_closed(
            hidden_transition_matrix)

    # Coarse-grained counts, used to weight the stationary distribution estimate.
    hidden_counts = memberships.T.dot(count_matrix).dot(memberships)
    hidden_pi = stationary_distribution(hidden_transition_matrix,
                                        C=hidden_counts)

    output_probabilities = np.zeros(
        (n_hidden_states, msm.count_model.n_states_full))
    # we might have lost a few symbols, reduce nonsep symbols to the ones actually represented
    nonseparate_symbols = msm.count_model.state_symbols[nonseparate_states]
    if separate_symbols is not None:
        separate_symbols = msm.count_model.state_symbols[separate_states]
        output_probabilities[:n_hidden_states - 1,
                             nonseparate_symbols] = pcca.metastable_distributions
        # The separate hidden state emits its symbols proportionally to their
        # stationary weight in the original MSM.
        output_probabilities[
            -1,
            separate_symbols] = msm.stationary_distribution[separate_states]
    else:
        output_probabilities[:,
                             nonseparate_symbols] = pcca.metastable_distributions

    # regularize
    eps_a = 0.01 / n_hidden_states if regularize else 0.
    hidden_pi, hidden_transition_matrix = _regularize_hidden(
        hidden_pi,
        hidden_transition_matrix,
        reversible=reversible,
        stationary=stationary,
        count_matrix=hidden_counts,
        eps=eps_a)
    eps_b = 0.01 / msm.n_states if regularize else 0.
    output_probabilities = _regularize_pobs(output_probabilities,
                                            nonempty=None,
                                            separate=separate_symbols,
                                            eps=eps_b)
    from deeptime.markov.hmm import HiddenMarkovModel
    return HiddenMarkovModel(transition_model=hidden_transition_matrix,
                             output_model=output_probabilities,
                             initial_distribution=hidden_pi)
# Example #2
class FiveStateSetup(object):
    """Test fixture for OOM-reweighted MSM estimation on a 5-state system.

    Loads pre-generated discrete trajectories from ``TestData_OOM_MSM.npz``,
    fits :class:`OOMReweightedMSM` estimators (reversible/non-reversible,
    dense/sparse), and computes reference quantities (count matrices, OOM
    components, committors, MFPTs, PCCA+ decompositions, correlation and
    relaxation fingerprints) that tests can compare against.

    Parameters
    ----------
    complete : bool, optional, default=True
        If True, use all 1000 trajectories (fully connected state space).
        If False, exclude a fixed list of trajectories so that the state
        space is not fully connected, exercising the active-set code paths.
    """

    def __init__(self, complete: bool = True):
        self.complete = complete
        data = np.load(os.path.join(os.path.dirname(os.path.realpath(__file__)), 'resources', 'TestData_OOM_MSM.npz'))
        if complete:
            self.dtrajs = [data['arr_%d' % k] for k in range(1000)]
        else:
            # Trajectories excluded to break full connectivity of the count matrix.
            excluded = [
                21, 25, 30, 40, 66, 72, 74, 91, 116, 158, 171, 175, 201, 239, 246, 280, 300, 301, 310, 318,
                322, 323, 339, 352, 365, 368, 407, 412, 444, 475, 486, 494, 510, 529, 560, 617, 623, 637,
                676, 689, 728, 731, 778, 780, 811, 828, 838, 845, 851, 859, 868, 874, 895, 933, 935, 938,
                958, 961, 968, 974, 984, 990, 999
            ]
            self.dtrajs = [data['arr_%d' % k] for k in np.setdiff1d(np.arange(1000), excluded)]
        # Number of states:
        self.N = 5
        # Lag time:
        self.tau = 5
        # Trajectories truncated by one lag time (used for lagged statistics).
        self.dtrajs_lag = [traj[:-self.tau] for traj in self.dtrajs]
        # Expected OOM rank (smaller for the incomplete data set):
        self.rank = 3 if complete else 2

        # Build models: reversible/non-reversible x dense/sparse.
        self.msmrev = OOMReweightedMSM(lagtime=self.tau, rank_mode='bootstrap_trajs').fit(self.dtrajs)
        self.msmrev_sparse = OOMReweightedMSM(lagtime=self.tau, sparse=True, rank_mode='bootstrap_trajs') \
            .fit(self.dtrajs)
        self.msm = OOMReweightedMSM(lagtime=self.tau, reversible=False, rank_mode='bootstrap_trajs').fit(self.dtrajs)
        self.msm_sparse = OOMReweightedMSM(lagtime=self.tau, reversible=False, sparse=True,
                                           rank_mode='bootstrap_trajs').fit(self.dtrajs)
        self.estimators = [self.msmrev, self.msm, self.msmrev_sparse, self.msm_sparse]
        self.msms = [est.fetch_model() for est in self.estimators]

        # Reference count matrices at lag time tau and 2*tau:
        self.C2t = data['C2t'] if complete else data['C2t_s']
        self.Ct = np.sum(self.C2t, axis=1)

        if complete:
            self.Ct_active = self.Ct
            self.C2t_active = self.C2t
            # BUG FIX: was misspelled `active_faction`, inconsistent with the
            # `else` branch below which sets `active_fraction`.
            self.active_fraction = 1.
            # Keep the historical misspelled attribute for backward compatibility.
            self.active_faction = 1.
        else:
            # Restrict to the largest connected set (states 0..3 for this data).
            lcc = msmest.largest_connected_set(self.Ct)
            self.Ct_active = msmest.largest_connected_submatrix(self.Ct, lcc=lcc)
            self.C2t_active = self.C2t[:4, :4, :4]
            self.active_fraction = np.sum(self.Ct_active) / np.sum(self.Ct)

        # Compute OOM-components:
        self.Xi, self.omega, self.sigma, self.l = oom_transformations(self.Ct_active, self.C2t_active, self.rank)

        # Compute corrected transition matrices (reversible and non-reversible):
        Tt_rev = compute_transition_matrix(self.Xi, self.omega, self.sigma, reversible=True)
        Tt = compute_transition_matrix(self.Xi, self.omega, self.sigma, reversible=False)

        # Build reference models:
        self.rmsmrev = MarkovStateModel(Tt_rev)
        self.rmsm = MarkovStateModel(Tt)

        # Active count fraction: share of counts/states inside the active set.
        self.hist = count_states(self.dtrajs)
        self.active_hist = self.hist[:-1] if not complete else self.hist

        self.active_count_frac = float(np.sum(self.active_hist)) / np.sum(self.hist) if not complete else 1.
        self.active_state_frac = 0.8 if not complete else 1.

        # Committor and MFPT between source set a and sink set b:
        a = np.array([0, 1])
        b = np.array([4]) if complete else np.array([3])
        self.comm_forward = self.rmsm.committor_forward(a, b)
        self.comm_forward_rev = self.rmsmrev.committor_forward(a, b)
        self.comm_backward = self.rmsm.committor_backward(a, b)
        self.comm_backward_rev = self.rmsmrev.committor_backward(a, b)
        # MFPTs are in trajectory time units, hence the factor tau.
        self.mfpt = self.tau * self.rmsm.mfpt(a, b)
        self.mfpt_rev = self.tau * self.rmsmrev.mfpt(a, b)
        # PCCA+ metastable decomposition of the reversible reference model:
        pcca = self.rmsmrev.pcca(3 if complete else 2)
        self.pcca_ass = pcca.assignments
        self.pcca_dist = pcca.metastable_distributions
        self.pcca_mem = pcca.memberships
        self.pcca_sets = pcca.sets
        # Experimental quantities: observables a, b and initial distribution p0.
        a = np.array([1, 2, 3, 4, 5])
        b = np.array([1, -1, 0, -2, 4])
        p0 = np.array([0.5, 0.2, 0.2, 0.1, 0.0])
        if not complete:
            # Incomplete data has only 4 active states; drop the last entry.
            a = a[:-1]
            b = b[:-1]
            p0 = p0[:-1]
        pi = self.rmsm.stationary_distribution
        pi_rev = self.rmsmrev.stationary_distribution
        _, _, L_rev = ma.rdl_decomposition(Tt_rev)
        # Stationary expectation values of observable a:
        self.exp = np.dot(self.rmsm.stationary_distribution, a)
        self.exp_rev = np.dot(self.rmsmrev.stationary_distribution, a)
        # Time-lagged correlation (reversible) and relaxation curves over 10 steps:
        self.corr_rev = np.zeros(10)
        self.rel = np.zeros(10)
        self.rel_rev = np.zeros(10)
        for k in range(10):
            Ck_rev = np.dot(np.diag(pi_rev), np.linalg.matrix_power(Tt_rev, k))
            self.corr_rev[k] = np.dot(a.T, np.dot(Ck_rev, b))
            self.rel[k] = np.dot(p0.T, np.dot(np.linalg.matrix_power(Tt, k), a))
            self.rel_rev[k] = np.dot(p0.T, np.dot(np.linalg.matrix_power(Tt_rev, k), a))

        # Fingerprints via projection onto the left eigenvectors of Tt_rev:
        self.fing_cor = np.dot(a.T, L_rev.T) * np.dot(b.T, L_rev.T)
        self.fing_rel = np.dot(a.T, L_rev.T) * np.dot((p0 / pi_rev).T, L_rev.T)