def metastable_from_msm(msm, n_hidden_states: int, reversible: bool = True, stationary: bool = False,
                        separate_symbols=None, regularize: bool = True):
    r""" Makes an initial guess for an :class:`HMM <deeptime.markov.hmm.HiddenMarkovModel>` with
    discrete output model from an already existing MSM over observable states. The procedure is
    described in :footcite:`noe2013projected` and uses PCCA+ :footcite:`roblitz2013fuzzy` for
    coarse-graining the transition matrix and obtaining membership assignments.

    Parameters
    ----------
    msm : MarkovStateModel
        The markov state model over observable state space.
    n_hidden_states : int
        The desired number of hidden states.
    reversible : bool, optional, default=True
        Whether the HMM transition matrix is estimated so that it is reversible.
    stationary : bool, optional, default=False
        If True, the initial distribution of hidden states is self-consistently computed as the
        stationary distribution of the transition matrix. If False, it will be estimated from the
        starting states. Only set this to true if you're sure that the observation trajectories are
        initiated from a global equilibrium distribution.
    separate_symbols : array_like, optional, default=None
        Force the given set of observed states to stay in a separate hidden state.
        The remaining nstates-1 states will be assigned by a metastable decomposition.
    regularize : bool, optional, default=True
        If set to True, makes sure that the hidden initial distribution and transition matrix have
        nonzero probabilities by setting them to eps and then renormalizing. Avoids zeros that would
        cause estimation algorithms to crash or get stuck in suboptimal states.

    Returns
    -------
    hmm_init : HiddenMarkovModel
        An initial guess for the HMM

    See Also
    --------
    deeptime.markov.hmm.DiscreteOutputModel
        The type of output model this heuristic uses.
    :func:`metastable_from_data`
        Initial guess from data if no MSM is available yet.
    :func:`deeptime.markov.hmm.init.gaussian.from_data`
        Initial guess with :class:`Gaussian output model <deeptime.markov.hmm.GaussianOutputModel>`.

    References
    ----------
    .. footbibliography::
    """
    # Local imports keep deeptime's internal modules out of the public module namespace.
    from deeptime.markov._transition_matrix import stationary_distribution
    from deeptime.markov._transition_matrix import estimate_P
    from deeptime.markov.msm import MarkovStateModel
    from deeptime.markov import PCCAModel

    count_matrix = msm.count_model.count_matrix
    # Start from the full observable symbol set; if separate states are requested, the
    # separate symbols are removed from this set below.
    nonseparate_symbols = np.arange(msm.count_model.n_states_full)
    nonseparate_states = msm.count_model.symbols_to_states(nonseparate_symbols)
    nonseparate_msm = msm
    if separate_symbols is not None:
        separate_symbols = np.asanyarray(separate_symbols)
        if np.max(separate_symbols) >= msm.count_model.n_states_full:
            raise ValueError(f'Separate set has indices that do not exist in '
                             f'full state space: {np.max(separate_symbols)}')
        nonseparate_symbols = np.setdiff1d(nonseparate_symbols, separate_symbols)
        nonseparate_states = msm.count_model.symbols_to_states(nonseparate_symbols)
        nonseparate_count_model = msm.count_model.submodel(nonseparate_states)
        # make reversible
        nonseparate_count_matrix = nonseparate_count_model.count_matrix
        if issparse(nonseparate_count_matrix):
            nonseparate_count_matrix = nonseparate_count_matrix.toarray()
        # Re-estimate a reversible transition matrix restricted to the non-separate states,
        # since PCCA+ below requires a reversible MSM.
        P_nonseparate = estimate_P(nonseparate_count_matrix, reversible=True)
        pi = stationary_distribution(P_nonseparate, C=nonseparate_count_matrix)
        nonseparate_msm = MarkovStateModel(P_nonseparate, stationary_distribution=pi)
    if issparse(count_matrix):
        count_matrix = count_matrix.toarray()

    # if #metastable sets == #states, we can stop here
    # (one metastable set per state: identity memberships, no PCCA+ needed)
    n_meta = n_hidden_states if separate_symbols is None else n_hidden_states - 1
    if n_meta == nonseparate_msm.n_states:
        pcca = PCCAModel(nonseparate_msm.transition_matrix, nonseparate_msm.stationary_distribution,
                         np.eye(n_meta), np.eye(n_meta))
    else:
        pcca = nonseparate_msm.pcca(n_meta)

    if separate_symbols is not None:
        # Reserve the last hidden state exclusively for the separate set; the PCCA+ memberships
        # fill the first n_hidden_states - 1 columns for the non-separate states.
        separate_states = msm.count_model.symbols_to_states(separate_symbols)
        memberships = np.zeros((msm.n_states, n_hidden_states))
        memberships[nonseparate_states, :n_hidden_states - 1] = pcca.memberships
        memberships[separate_states, -1] = 1
    else:
        memberships = pcca.memberships
        separate_states = None

    hidden_transition_matrix = _coarse_grain_transition_matrix(msm.transition_matrix, memberships)
    if reversible:
        from deeptime.markov._transition_matrix import enforce_reversible_on_closed
        hidden_transition_matrix = enforce_reversible_on_closed(hidden_transition_matrix)
    # Coarse-grained counts serve as weights for the hidden stationary distribution.
    hidden_counts = memberships.T.dot(count_matrix).dot(memberships)
    hidden_pi = stationary_distribution(hidden_transition_matrix, C=hidden_counts)

    output_probabilities = np.zeros((n_hidden_states, msm.count_model.n_states_full))
    # we might have lost a few symbols, reduce nonsep symbols to the ones actually represented
    nonseparate_symbols = msm.count_model.state_symbols[nonseparate_states]
    if separate_symbols is not None:
        separate_symbols = msm.count_model.state_symbols[separate_states]
        output_probabilities[:n_hidden_states - 1, nonseparate_symbols] = pcca.metastable_distributions
        # The separate hidden state emits the separate symbols weighted by their MSM
        # stationary probabilities (renormalized by _regularize_pobs below).
        output_probabilities[-1, separate_symbols] = msm.stationary_distribution[separate_states]
    else:
        output_probabilities[:, nonseparate_symbols] = pcca.metastable_distributions

    # regularize
    eps_a = 0.01 / n_hidden_states if regularize else 0.
    hidden_pi, hidden_transition_matrix = _regularize_hidden(hidden_pi, hidden_transition_matrix,
                                                            reversible=reversible, stationary=stationary,
                                                            count_matrix=hidden_counts, eps=eps_a)
    eps_b = 0.01 / msm.n_states if regularize else 0.
    output_probabilities = _regularize_pobs(output_probabilities, nonempty=None,
                                            separate=separate_symbols, eps=eps_b)
    from deeptime.markov.hmm import HiddenMarkovModel
    return HiddenMarkovModel(transition_model=hidden_transition_matrix, output_model=output_probabilities,
                             initial_distribution=hidden_pi)
class FiveStateSetup(object):
    """Test fixture for a five-state OOM-reweighted MSM scenario.

    Loads precomputed discrete trajectories and count matrices from
    ``resources/TestData_OOM_MSM.npz``, fits several :class:`OOMReweightedMSM`
    estimator variants (reversible/non-reversible x dense/sparse), and computes
    reference quantities (transition matrices, committors, MFPTs, PCCA+
    decompositions, correlation/relaxation fingerprints) for the tests to
    compare against.

    Parameters
    ----------
    complete : bool, optional, default=True
        If True, use all 1000 trajectories (fully connected, rank 3).
        If False, exclude a fixed subset of trajectories so the count matrix
        is not fully connected (rank 2), exercising the restricted /
        active-set code paths.
    """

    def __init__(self, complete: bool = True):
        self.complete = complete
        data = np.load(os.path.join(os.path.dirname(os.path.realpath(__file__)), 'resources',
                                    'TestData_OOM_MSM.npz'))
        if complete:
            self.dtrajs = [data['arr_%d' % k] for k in range(1000)]
        else:
            # Fixed exclusion list that disconnects part of the state space.
            excluded = [
                21, 25, 30, 40, 66, 72, 74, 91, 116, 158, 171, 175, 201, 239, 246, 280, 300, 301,
                310, 318, 322, 323, 339, 352, 365, 368, 407, 412, 444, 475, 486, 494, 510, 529,
                560, 617, 623, 637, 676, 689, 728, 731, 778, 780, 811, 828, 838, 845, 851, 859,
                868, 874, 895, 933, 935, 938, 958, 961, 968, 974, 984, 990, 999
            ]
            self.dtrajs = [data['arr_%d' % k] for k in np.setdiff1d(np.arange(1000), excluded)]
        # Number of states:
        self.N = 5
        # Lag time:
        self.tau = 5
        self.dtrajs_lag = [traj[:-self.tau] for traj in self.dtrajs]
        # Rank:
        if complete:
            self.rank = 3
        else:
            self.rank = 2

        # Build models: all four estimator variants on the same data.
        self.msmrev = OOMReweightedMSM(lagtime=self.tau, rank_mode='bootstrap_trajs').fit(self.dtrajs)
        self.msmrev_sparse = OOMReweightedMSM(lagtime=self.tau, sparse=True, rank_mode='bootstrap_trajs') \
            .fit(self.dtrajs)
        self.msm = OOMReweightedMSM(lagtime=self.tau, reversible=False,
                                    rank_mode='bootstrap_trajs').fit(self.dtrajs)
        self.msm_sparse = OOMReweightedMSM(lagtime=self.tau, reversible=False, sparse=True,
                                           rank_mode='bootstrap_trajs').fit(self.dtrajs)
        self.estimators = [self.msmrev, self.msm, self.msmrev_sparse, self.msm_sparse]
        self.msms = [est.fetch_model() for est in self.estimators]

        # Reference count matrices at lag time tau and 2*tau:
        if complete:
            self.C2t = data['C2t']
        else:
            self.C2t = data['C2t_s']
        self.Ct = np.sum(self.C2t, axis=1)
        if complete:
            self.Ct_active = self.Ct
            self.C2t_active = self.C2t
            # FIX: attribute was misspelled 'active_faction', so in the complete case
            # 'active_fraction' (set in the else branch) did not exist on the instance.
            self.active_fraction = 1.
        else:
            lcc = msmest.largest_connected_set(self.Ct)
            self.Ct_active = msmest.largest_connected_submatrix(self.Ct, lcc=lcc)
            self.C2t_active = self.C2t[:4, :4, :4]
            self.active_fraction = np.sum(self.Ct_active) / np.sum(self.Ct)

        # Compute OOM-components:
        self.Xi, self.omega, self.sigma, self.l = oom_transformations(self.Ct_active, self.C2t_active,
                                                                      self.rank)
        # Compute corrected transition matrix:
        Tt_rev = compute_transition_matrix(self.Xi, self.omega, self.sigma, reversible=True)
        Tt = compute_transition_matrix(self.Xi, self.omega, self.sigma, reversible=False)
        # Build reference models:
        self.rmsmrev = MarkovStateModel(Tt_rev)
        self.rmsm = MarkovStateModel(Tt)

        # Active count fraction:
        self.hist = count_states(self.dtrajs)
        self.active_hist = self.hist[:-1] if not complete else self.hist
        self.active_count_frac = float(np.sum(self.active_hist)) / np.sum(self.hist) if not complete else 1.
        self.active_state_frac = 0.8 if not complete else 1.

        # Committor and MFPT between reference sets a and b:
        a = np.array([0, 1])
        b = np.array([4]) if complete else np.array([3])
        self.comm_forward = self.rmsm.committor_forward(a, b)
        self.comm_forward_rev = self.rmsmrev.committor_forward(a, b)
        self.comm_backward = self.rmsm.committor_backward(a, b)
        self.comm_backward_rev = self.rmsmrev.committor_backward(a, b)
        # MFPTs are reported in physical time units, hence scaled by the lag time.
        self.mfpt = self.tau * self.rmsm.mfpt(a, b)
        self.mfpt_rev = self.tau * self.rmsmrev.mfpt(a, b)

        # PCCA+ coarse-graining of the reversible reference model:
        pcca = self.rmsmrev.pcca(3 if complete else 2)
        self.pcca_ass = pcca.assignments
        self.pcca_dist = pcca.metastable_distributions
        self.pcca_mem = pcca.memberships
        self.pcca_sets = pcca.sets

        # Experimental quantities (observables a, b and initial distribution p0):
        a = np.array([1, 2, 3, 4, 5])
        b = np.array([1, -1, 0, -2, 4])
        p0 = np.array([0.5, 0.2, 0.2, 0.1, 0.0])
        if not complete:
            a = a[:-1]
            b = b[:-1]
            p0 = p0[:-1]
        pi = self.rmsm.stationary_distribution
        pi_rev = self.rmsmrev.stationary_distribution
        _, _, L_rev = ma.rdl_decomposition(Tt_rev)
        self.exp = np.dot(self.rmsm.stationary_distribution, a)
        self.exp_rev = np.dot(self.rmsmrev.stationary_distribution, a)
        # Time series of equilibrium correlation and relaxation values for k = 0..9.
        self.corr_rev = np.zeros(10)
        self.rel = np.zeros(10)
        self.rel_rev = np.zeros(10)
        for k in range(10):
            Ck_rev = np.dot(np.diag(pi_rev), np.linalg.matrix_power(Tt_rev, k))
            self.corr_rev[k] = np.dot(a.T, np.dot(Ck_rev, b))
            self.rel[k] = np.dot(p0.T, np.dot(np.linalg.matrix_power(Tt, k), a))
            self.rel_rev[k] = np.dot(p0.T, np.dot(np.linalg.matrix_power(Tt_rev, k), a))
        # Fingerprint amplitudes from the left eigenvectors of the reversible model.
        self.fing_cor = np.dot(a.T, L_rev.T) * np.dot(b.T, L_rev.T)
        self.fing_rel = np.dot(a.T, L_rev.T) * np.dot((p0 / pi_rev).T, L_rev.T)