Example #1
0
def test_sufficient_statistics():
    """Check every E-step sufficient statistic against the reference E-step.

    Builds a ``MetastableSwitchingLDS`` whose parameters are copied from
    ``refmodel``. It runs the model's E-step on ``data`` and compares each
    entry of the returned statistics dict with the matching entry produced
    by ``_sklearn_estep()``, to 3 decimal places.

    This is a plain test function rather than a yield-based generator test.
    Generator tests were removed in pytest 4.0, so under a modern runner
    none of the comparisons would run. Every key is still checked before
    failing, so one run reports all mismatching statistics instead of
    stopping at the first.

    Raises:
        AssertionError: if any statistic differs from the reference. The
            message has one line per mismatching key.
    """
    model = MetastableSwitchingLDS(n_states=N_STATES,
                                   n_features=refmodel.n_features)
    model._impl._sequences = data
    model.means_ = refmodel.means_
    model.covars_ = refmodel.covars_
    model.transmat_ = refmodel.transmat_
    # The reference model exposes this as ``startprob_``; this model
    # stores it as ``populations_`` (presumably the initial-state
    # distribution -- confirm against the model docs).
    model.populations_ = refmodel.startprob_

    # NOTE(review): the log-probabilities returned alongside the
    # statistics are not compared here; add a check if they are expected
    # to agree between the two implementations.
    _, stats = model._impl.do_estep()
    _, rstats = _sklearn_estep()

    keys = (
        'post', 'post[1:]', 'post[:-1]',
        'obs', 'obs[1:]', 'obs[:-1]',
        'obs*obs.T', 'obs*obs[t-1].T',
        'obs[1:]*obs[1:].T', 'obs[:-1]*obs[:-1].T',
        'trans',
    )

    failures = []
    for key in keys:
        try:
            np.testing.assert_array_almost_equal(
                stats[key], rstats[key], decimal=3)
        except AssertionError as err:
            # Keep going so that every mismatching statistic is reported.
            failures.append('{}: {}'.format(key, err))

    if failures:
        raise AssertionError(
            'sufficient statistics differ from reference:\n'
            + '\n'.join(failures))