Beispiel #1
0
def test_active_state_indices(oom_msm_scenario):
    """Check per-state index arrays against the state histogram for every model of the scenario.

    For each model the full discrete trajectories are mapped onto the model's
    count-model submodel, index arrays are computed per state, and it is
    asserted that

    * there is one index array per model state,
    * each index array has as many rows as the histogram of the original
      trajectories reports for the corresponding state symbol, and
    * each index array has exactly two columns.
    """
    full_dtrajs = oom_msm_scenario.dtrajs
    for model in oom_msm_scenario.msms:
        count_model = model.count_model
        projected = count_model.transform_discrete_trajectories_to_submodel(full_dtrajs)
        state_indices = compute_index_states(projected)
        np.testing.assert_equal(len(state_indices), model.n_states)

        # histogram is taken over the unprojected trajectories, hence the
        # lookup through the state symbols below
        histogram = count_states(full_dtrajs)
        for state in range(model.n_states):
            frames = state_indices[state]
            symbol = count_model.state_symbols[state]
            np.testing.assert_equal(frames.shape[0], histogram[symbol])
            np.testing.assert_equal(frames.shape[1], 2)
Beispiel #2
0
def test_simulate_stats(msm):
    """Compare the empirical distribution of simulated starting states with the stationary distribution.

    ``msm.simulate(1, seed=s)`` is called for the seeds ``1 .. 5000``; the
    concatenated outputs are histogrammed, normalized by the number of
    simulations and compared (absolute tolerance 0.025) against the stationary
    distribution computed from ``msm.transition_matrix``.
    """
    n_samples = 5000
    simulated = []
    for seed in range(1, n_samples + 1):
        simulated.append(msm.simulate(1, seed=seed))
    start_states = np.concatenate(simulated).astype(int)

    expected = deeptime.markov.tools.analysis.stationary_distribution(msm.transition_matrix)
    empirical = count_states(start_states) / float(n_samples)
    np.testing.assert_allclose(empirical, expected, atol=0.025)
 def test_active_state_indices(self, setting):
     """Check that index arrays restricted to the active set agree with the trajectory histogram.

     A double-well scenario is built from ``setting``; index arrays are computed
     for the scenario's discrete trajectory restricted to the state symbols of
     the MSM's count model. There must be one index array per MSM state, each
     with two columns and with as many rows as the histogram reports for the
     corresponding symbol.
     """
     scenario = make_double_well(setting)
     from deeptime.markov.sample import compute_index_states
     from deeptime.markov.util import count_states

     active_symbols = scenario.msm.count_model.state_symbols
     index_sets = compute_index_states(scenario.data.dtraj, subset=active_symbols)
     assert len(index_sets) == scenario.msm.n_states

     # frame counts per symbol, taken over the whole trajectory
     histogram = count_states(scenario.data.dtraj)
     for position, symbol in enumerate(active_symbols):
         frames = index_sets[position]
         assert frames.shape[0] == histogram[symbol]
         assert frames.shape[1] == 2
Beispiel #4
0
    def test_observable_state_indices(self):
        """Check index arrays restricted to the HMM's observation symbols against the histogram.

        Uses ``self.hmm_lag10_largest`` and ``self.dtrajs``. There must be one
        index array per observation state, each with two columns and with as
        many rows as the histogram of ``self.dtrajs`` reports for the
        corresponding observation symbol.
        """
        from deeptime.markov.sample import compute_index_states

        model = self.hmm_lag10_largest
        observed_symbols = model.observation_symbols
        index_sets = compute_index_states(self.dtrajs, subset=observed_symbols)
        np.testing.assert_equal(len(index_sets), model.n_observation_states)

        # frame counts per symbol, taken over all trajectories
        histogram = count_states(self.dtrajs)
        for position, symbol in enumerate(observed_symbols):
            frames = index_sets[position]
            np.testing.assert_equal(frames.shape[0], histogram[symbol])
            np.testing.assert_equal(frames.shape[1], 2)
Beispiel #5
0
    def fit(self, data, *args, **kw):
        r""" Counts transitions at given lag time according to configuration of the estimator.

        If ``self.n_states`` is set and exceeds the number of states found in the
        data, the state histogram and the count matrix are zero-padded up to
        ``self.n_states``.

        Parameters
        ----------
        data : array_like or list of array_like
            discretized trajectories
        *args, **kw
            Accepted but not used by this method.

        Returns
        -------
        self
            The estimator; the resulting :class:`TransitionCountModel` is stored in ``self._model``.
        """
        dtrajs = ensure_dtraj_list(data)

        # basic count statistics
        histogram = count_states(dtrajs, ignore_negative=True)

        # Compute count matrix
        count_mode = self.count_mode
        lagtime = self.lagtime
        count_matrix = TransitionCountEstimator.count(count_mode,
                                                      dtrajs,
                                                      lagtime,
                                                      sparse=self.sparse)

        n_states = self.n_states
        if n_states is not None and n_states > count_matrix.shape[0]:
            # number of empty states to append to histogram and count matrix
            n_pad = n_states - count_matrix.shape[0]
            histogram = np.pad(histogram, pad_width=[(0, n_pad)])
            if issparse(count_matrix):
                # A CSR matrix with n_states rows needs n_states + 1 index pointers.
                # Repeating the last pointer makes every appended row empty; reusing
                # the unpadded indptr with the larger shape is rejected by scipy.
                # NOTE(review): this relies on count_matrix being in CSR layout,
                # exactly as the data/indices/indptr access already did — confirm.
                indptr = np.pad(count_matrix.indptr, (0, n_pad), mode='edge')
                count_matrix = scipy.sparse.csr_matrix(
                    (count_matrix.data, count_matrix.indices, indptr),
                    shape=(n_states, n_states))
            else:
                count_matrix = np.pad(count_matrix,
                                      pad_width=[(0, n_pad), (0, n_pad)])

        # initially state symbols, full count matrix, and full histogram can be left None because they coincide
        # with the input arguments
        self._model = TransitionCountModel(count_matrix=count_matrix,
                                           counting_mode=count_mode,
                                           lagtime=lagtime,
                                           state_histogram=histogram)
        return self
Beispiel #6
0
    def nonempty_obs(self, dtrajs) -> np.ndarray:
        r"""
        Determines which observable states occur in the given discrete trajectories.

        The trajectories are first converted to their effective (lagged and
        strided) form using the lag time and full state count of the transition
        model together with ``self.stride``; states with a positive count in
        that effective data are reported.

        Parameters
        ----------
        dtrajs : array_like
            observable trajectory

        Returns
        -------
        symbols : np.ndarray
            The observation symbols which are visited.

        Raises
        ------
        ValueError
            If ``dtrajs`` is ``None``.
        """
        from deeptime.markov.util import compute_dtrajs_effective, count_states
        if dtrajs is None:
            raise ValueError("Needs nonempty dtrajs to evaluate nonempty obs.")

        trajectory_list = ensure_dtraj_list(dtrajs)
        transition_model = self.transition_model
        effective = compute_dtrajs_effective(
            trajectory_list,
            transition_model.lagtime,
            transition_model.count_model.n_states_full,
            self.stride,
        )
        visited_mask = count_states(effective) > 0
        return np.where(visited_mask)[0]
Beispiel #7
0
    def __init__(self, complete: bool = True):
        """Load the OOM/MSM test data and precompute estimators and reference quantities.

        Parameters
        ----------
        complete : bool, default=True
            If True, all 1000 trajectories stored in the resource file are used
            together with the reference array ``C2t`` and rank 3. If False, a
            fixed list of trajectory indices is left out, the reference array
            ``C2t_s`` is used with rank 2, and the reference quantities are
            restricted to the largest connected set.
        """
        self.complete = complete
        resource = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'resources', 'TestData_OOM_MSM.npz')
        # Everything needed from the archive is read inside the with-block so that
        # the underlying file handle is closed again afterwards.
        with np.load(resource) as data:
            if complete:
                self.dtrajs = [data['arr_%d' % k] for k in range(1000)]
            else:
                excluded = [
                    21, 25, 30, 40, 66, 72, 74, 91, 116, 158, 171, 175, 201, 239, 246, 280, 300, 301, 310, 318,
                    322, 323, 339, 352, 365, 368, 407, 412, 444, 475, 486, 494, 510, 529, 560, 617, 623, 637,
                    676, 689, 728, 731, 778, 780, 811, 828, 838, 845, 851, 859, 868, 874, 895, 933, 935, 938,
                    958, 961, 968, 974, 984, 990, 999
                ]
                self.dtrajs = [data['arr_%d' % k] for k in np.setdiff1d(np.arange(1000), excluded)]
            # Reference count matrices at lag time tau and 2*tau:
            self.C2t = data['C2t'] if complete else data['C2t_s']
        self.Ct = np.sum(self.C2t, axis=1)

        # Number of states:
        self.N = 5
        # Lag time:
        self.tau = 5
        self.dtrajs_lag = [traj[:-self.tau] for traj in self.dtrajs]
        # Rank:
        self.rank = 3 if complete else 2

        # Build models:
        self.msmrev = OOMReweightedMSM(lagtime=self.tau, rank_mode='bootstrap_trajs').fit(self.dtrajs)
        self.msmrev_sparse = OOMReweightedMSM(lagtime=self.tau, sparse=True, rank_mode='bootstrap_trajs') \
            .fit(self.dtrajs)
        self.msm = OOMReweightedMSM(lagtime=self.tau, reversible=False, rank_mode='bootstrap_trajs').fit(self.dtrajs)
        self.msm_sparse = OOMReweightedMSM(lagtime=self.tau, reversible=False, sparse=True,
                                           rank_mode='bootstrap_trajs').fit(self.dtrajs)
        self.estimators = [self.msmrev, self.msm, self.msmrev_sparse, self.msm_sparse]
        self.msms = [est.fetch_model() for est in self.estimators]

        if complete:
            self.Ct_active = self.Ct
            self.C2t_active = self.C2t
            self.active_fraction = 1.
            # Historical misspelling; it used to be the only attribute set on this
            # branch and is kept so that any existing reader keeps working.
            self.active_faction = self.active_fraction
        else:
            lcc = msmest.largest_connected_set(self.Ct)
            self.Ct_active = msmest.largest_connected_submatrix(self.Ct, lcc=lcc)
            self.C2t_active = self.C2t[:4, :4, :4]
            self.active_fraction = np.sum(self.Ct_active) / np.sum(self.Ct)

        # Compute OOM-components:
        self.Xi, self.omega, self.sigma, self.l = oom_transformations(self.Ct_active, self.C2t_active, self.rank)

        # Compute corrected transition matrix:
        Tt_rev = compute_transition_matrix(self.Xi, self.omega, self.sigma, reversible=True)
        Tt = compute_transition_matrix(self.Xi, self.omega, self.sigma, reversible=False)

        # Build reference models:
        self.rmsmrev = MarkovStateModel(Tt_rev)
        self.rmsm = MarkovStateModel(Tt)

        # Active count fraction:
        self.hist = count_states(self.dtrajs)
        # in the incomplete case the last histogram entry is dropped from the active histogram
        self.active_hist = self.hist[:-1] if not complete else self.hist

        self.active_count_frac = float(np.sum(self.active_hist)) / np.sum(self.hist) if not complete else 1.
        self.active_state_frac = 0.8 if not complete else 1.

        # Commitor and MFPT:
        a = np.array([0, 1])
        b = np.array([4]) if complete else np.array([3])
        self.comm_forward = self.rmsm.committor_forward(a, b)
        self.comm_forward_rev = self.rmsmrev.committor_forward(a, b)
        self.comm_backward = self.rmsm.committor_backward(a, b)
        self.comm_backward_rev = self.rmsmrev.committor_backward(a, b)
        self.mfpt = self.tau * self.rmsm.mfpt(a, b)
        self.mfpt_rev = self.tau * self.rmsmrev.mfpt(a, b)
        # PCCA:
        pcca = self.rmsmrev.pcca(3 if complete else 2)
        self.pcca_ass = pcca.assignments
        self.pcca_dist = pcca.metastable_distributions
        self.pcca_mem = pcca.memberships
        self.pcca_sets = pcca.sets
        # Experimental quantities (a, b are reused here as observable vectors, p0 as initial distribution):
        a = np.array([1, 2, 3, 4, 5])
        b = np.array([1, -1, 0, -2, 4])
        p0 = np.array([0.5, 0.2, 0.2, 0.1, 0.0])
        if not complete:
            # the reference models only have four states in this case
            a = a[:-1]
            b = b[:-1]
            p0 = p0[:-1]
        pi = self.rmsm.stationary_distribution
        pi_rev = self.rmsmrev.stationary_distribution
        _, _, L_rev = ma.rdl_decomposition(Tt_rev)
        self.exp = np.dot(pi, a)
        self.exp_rev = np.dot(pi_rev, a)
        self.corr_rev = np.zeros(10)
        self.rel = np.zeros(10)
        self.rel_rev = np.zeros(10)
        for k in range(10):
            Ck_rev = np.dot(np.diag(pi_rev), np.linalg.matrix_power(Tt_rev, k))
            self.corr_rev[k] = np.dot(a.T, np.dot(Ck_rev, b))
            self.rel[k] = np.dot(p0.T, np.dot(np.linalg.matrix_power(Tt, k), a))
            self.rel_rev[k] = np.dot(p0.T, np.dot(np.linalg.matrix_power(Tt_rev, k), a))

        self.fing_cor = np.dot(a.T, L_rev.T) * np.dot(b.T, L_rev.T)
        self.fing_rel = np.dot(a.T, L_rev.T) * np.dot((p0 / pi_rev).T, L_rev.T)