Esempio n. 1
0
def test_reversible_disconnected(disconnected_states, lag, count_mode):
    r"""Check reversible ML-MSM estimation on data with disconnected state sets.

    Topology encoded by the fixture: 2 <- 0 <-> 1 <-> 3 | 7 -> 4 <-> 5 | 6

    First the full count model is estimated and the resulting sub-MSMs are
    compared against ``disconnected_states.connected_sets``. Afterwards two
    hand-picked state selections are turned into count submodels and the
    state symbols of each resulting sub-MSM are checked.
    """
    counting = TransitionCountEstimator(lagtime=lag, count_mode=count_mode)
    full_counts = counting.fit(disconnected_states.dtrajs).fetch_model()

    def estimate(counts):
        # Every estimation in this test uses a fresh reversible estimator.
        return MaximumLikelihoodMSM(reversible=True).fit(counts).fetch_model()

    full_msm = estimate(full_counts)
    expected_sets = disconnected_states.connected_sets
    assert_equal(full_msm.n_connected_msms, len(expected_sets))
    # Index-wise comparison is valid because the subsets are ordered in
    # decreasing cardinality.
    for index, expected in enumerate(expected_sets):
        assert_equal(full_msm.state_symbols(index), expected)

    # (selected states, expected state symbols per sub-MSM index)
    restricted_cases = (
        # States 0..3: state 2 is only reached one-way (2 <- 0), so it is
        # expected to end up in its own sub-MSM.
        ([0, 1, 2, 3], ([0, 1, 3], [2])),
        # States 6 and 2 share no transitions in the diagram above.
        ([6, 2], ([6], [2])),
    )
    for selection, expected_symbols in restricted_cases:
        sub_msm = estimate(full_counts.submodel(selection))
        assert_equal(sub_msm.n_connected_msms, 2)
        for index, symbols in enumerate(expected_symbols):
            assert_equal(sub_msm.state_symbols(index), symbols)
Esempio n. 2
0
def test_nonreversible_disconnected():
    """Estimate an MSM from two concatenated, separately simulated chains.

    A 2-state and a 3-state model are simulated independently; the states of
    the second trajectory are shifted by 2 so that the symbols do not overlap.
    The test then checks the shapes and state symbols of the selected sub-MSM
    before and after ``select(1)``, and that selecting index 2 raises
    ``IndexError``.

    NOTE(review): the test name says "nonreversible" but the estimator is
    constructed with ``reversible=True`` -- confirm that this is intended.
    """
    two_state = MarkovStateModel([[.7, .3], [.3, .7]])
    three_state = MarkovStateModel([[.9, .05, .05], [.3, .6, .1], [.1, .1, .8]])

    n_steps = 1000000
    first_part = two_state.simulate(n_steps)
    second_part = 2 + three_state.simulate(n_steps)  # symbols become 2, 3, 4
    traj = np.concatenate([first_part, second_part])

    count_estimator = TransitionCountEstimator(lagtime=1, count_mode="sliding")
    counts = count_estimator.fit(traj)
    msm = MaximumLikelihoodMSM(reversible=True).fit(counts).fetch_model()

    def check_current(n_states, symbols):
        # Assertions on whichever sub-MSM is currently selected.
        assert_equal(msm.transition_matrix.shape, (n_states, n_states))
        assert_equal(msm.stationary_distribution.shape, (n_states,))
        assert_equal(msm.state_symbols(), symbols)

    # Default selection: the 3-state component; index 1 is the 2-state one.
    check_current(3, [2, 3, 4])
    assert_equal(msm.state_symbols(1), [0, 1])

    # After switching, the default view is the 2-state component while the
    # 3-state one stays reachable through its explicit index.
    msm.select(1)
    check_current(2, [0, 1])
    assert_equal(msm.state_symbols(0), [2, 3, 4])

    # Only two sub-MSMs exist, so a third index must be rejected.
    with assert_raises(IndexError):
        msm.select(2)