Example No. 1
File: hmm.py Project: tcsvn/pyadlml
    def assign_states(self, true_z, true_y):
        """
        assigns the unordered hidden states of the trained model (on true_y)
        to the most probable state labels in alignment of true_z
        :param true_z
            the true state sequence of a labeled dataset
        :param true_y
            the true corresp. observation sequence of a labeled dataset

        assign
           z = true state seq [1,2,1,....,]
           tmp3 = pred. state seq [3,4,1,2,...,]
           match each row to different column in such a way that corresp
           sum is minimized
           select n el of C, so that there is exactly one el.  in each row
           and one in each col. with min corresp. costs


        match states [1,2,...,] of of the
        :return:
            None
        """
        # Infer the most likely states and align them with the true labels
        tmp1 = self._hmm.most_likely_states(true_y)
        # todo: temporary cast to int64; remove for a space-saving solution,
        # todo: as the number of states is only in the range [0, 30]
        true_z = true_z.astype(np.int64)
        tmp2 = find_permutation(true_z, tmp1)
        self._hmm.permute(tmp2)
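
The matching described in the docstring above is a linear assignment problem. Below is a minimal sketch of how such a permutation could be recovered; the helper name `find_permutation_sketch` is hypothetical, and using `scipy.optimize.linear_sum_assignment` over an overlap matrix is an assumption for illustration, not necessarily how `ssm.util.find_permutation` is implemented.

import numpy as np
from scipy.optimize import linear_sum_assignment

def find_permutation_sketch(true_z, pred_z, K):
    # Hypothetical illustration, not ssm.util.find_permutation itself.
    # overlap[i, j] counts time steps where the true state is i and the
    # predicted state is j; maximizing total overlap (minimizing its
    # negation) pairs each predicted label with exactly one true label.
    overlap = np.zeros((K, K))
    for i in range(K):
        for j in range(K):
            overlap[i, j] = np.sum((true_z == i) & (pred_z == j))
    _, perm = linear_sum_assignment(-overlap)
    return perm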
Example No. 2
# Mask off some data
mask = npr.rand(T, N) < 0.75
y_masked = y * mask

# Fit an SLDS with mean field posterior
print("Fitting SLDS with SVI using structured variational posterior")
slds = SLDS(N, K, D, emissions="gaussian")
slds.initialize(y_masked, masks=mask)

q_mf = SLDSMeanFieldVariationalPosterior(slds, y_masked, masks=mask)
q_mf_elbos = slds.fit(q_mf, y_masked, masks=mask, num_iters=1000, initialize=False)
q_mf_x = q_mf.mean[0]
q_mf_y = slds.smooth(q_mf_x, y)

# Find the permutation that matches the true and inferred states
slds.permute(find_permutation(z, slds.most_likely_states(q_mf_x, y)))
q_mf_z = slds.most_likely_states(q_mf_x, y)

# Do the same with the structured posterior
print("Fitting SLDS with SVI using structured variational posterior")
slds = SLDS(N, K, D, emissions="gaussian")
slds.initialize(y_masked, masks=mask)

q_struct = SLDSTriDiagVariationalPosterior(slds, y_masked, masks=mask)
q_struct_elbos = slds.fit(q_struct, y_masked, masks=mask, num_iters=1000, initialize=False)
q_struct_x = q_struct.mean[0]
q_struct_y = slds.smooth(q_struct_x, y)

# Find the permutation that matches the true and inferred states
slds.permute(find_permutation(z, slds.most_likely_states(q_struct_x, y)))
q_struct_z = slds.most_likely_states(q_struct_x, y)
Example No. 3
]

# Fit with both SGD and EM
methods = ["sgd", "em"]

results = {}
for obs in observations:
    for method in methods:
        print("Fitting {} HMM with {}".format(obs, method))
        model = ssm.HMM(K, D, observations=obs)
        train_lls = model.fit(y, method=method)
        test_ll = model.log_likelihood(y_test)
        smoothed_y = model.smooth(y)

        # Permute to match the true states
        model.permute(find_permutation(z, model.most_likely_states(y)))
        smoothed_z = model.most_likely_states(y)
        results[(obs, method)] = (model, train_lls, test_ll, smoothed_z,
                                  smoothed_y)

# Plot the inferred states
fig, axs = plt.subplots(len(observations) + 1, 1, figsize=(12, 8))

# Plot the true states
plt.sca(axs[0])
plt.imshow(z[None, :], aspect="auto", cmap="jet")
plt.title("true")
plt.xticks([])

# Plot the inferred states
for i, obs in enumerate(observations):
Example No. 4
                 masks=mask,
                 num_iters=5000,
                 print_intvl=100,
                 initialize=False)
slds_x = q.mean[0]

# In[6]:

plt.plot(elbos)
plt.xlabel("SVI Iteration")
plt.ylabel("ELBO")

# In[7]:

# Find the permutation that matches the true and inferred states
slds.permute(find_permutation(z, slds.most_likely_states(slds_x, y)))
slds_z = slds.most_likely_states(slds_x, y)

# In[8]:

# Smooth the observations
slds_y = slds.smooth(slds_x, y)

# In[9]:

# Plot the true and inferred states
xlim = (0, 1000)

plt.figure(figsize=(8, 4))
plt.imshow(np.column_stack((z, slds_z)).T, aspect="auto")
plt.plot(xlim, [0.5, 0.5], '-k', lw=2)
Example No. 5
hsmm = HSMM(K, D, observations="gaussian")
hsmm_em_lls = hsmm.fit(y, method="em", num_em_iters=N_em_iters)

# Plot log likelihoods (fit model is typically better)
plt.figure()
plt.plot(hsmm_em_lls, ls=':', label="HSMM (EM)")
plt.plot(true_ll * np.ones(N_em_iters), ':', label="true")
plt.legend(loc="lower right")

# Print the test likelihoods (true model is typically better)
print("Test log likelihood")
print("True HSMM: ", true_hsmm.log_likelihood(y_test))
print("Fit HSMM:  ", hsmm.log_likelihood(y_test))

# Align the inferred states with the true states before plotting
hsmm.permute(find_permutation(z, hsmm.most_likely_states(y)))
hsmm_z = hsmm.most_likely_states(y)

# Plot the true and inferred discrete states
plt.figure(figsize=(8, 4))
plt.subplot(211)
plt.imshow(z[None, :1000], aspect="auto", cmap="cubehelix", vmin=0, vmax=K - 1)
plt.xlim(0, 1000)
plt.ylabel("$z_{\\mathrm{true}}$")
plt.yticks([])

plt.subplot(212)
plt.imshow(hsmm_z[None, :1000],
           aspect="auto",
           cmap="cubehelix",
           vmin=0,
Example No. 6
def fit_slds_and_return_errors(X,
                               A1,
                               A2,
                               Kmax=4,
                               r=6,
                               num_iters=200,
                               num_restarts=1,
                               laplace_em=True,
                               single_subspace=True,
                               use_ds=True):
    '''
    Fit an SLDS to test data and return errors.
    
    Parameters
    ==========
    
    X : array, T x N
    A1 : array, N x N
    A2 : array, N x N
    '''
    # hardcoded
    true_K = 2
    # params
    N = X.shape[1]
    T = X.shape[0]

    def _fit_once():
        # fit a model
        slds = ssm.SLDS(N,
                        Kmax,
                        r,
                        single_subspace=single_subspace,
                        emissions='gaussian')
        #slds.initialize(X)
        #q_mf = SLDSMeanFieldVariationalPosterior(slds, X)
        if laplace_em:
            elbos, posterior = slds.fit(
                X,
                num_iters=num_iters,
                initialize=True,
                method="laplace_em",
                variational_posterior="structured_meanfield")
            posterior_x = posterior.mean_continuous_states[0]
        else:
            # Use blackbox + meanfield
            elbos, posterior = slds.fit(X,
                                        num_iters=num_iters,
                                        initialize=True,
                                        method="bbvi",
                                        variational_posterior="mf")
            # predict states
        return slds, elbos, posterior

    # Fit num_restarts many models
    results = []
    for restart in range(num_restarts):
        print("restart ", restart + 1, " / ", num_restarts)
        results.append(_fit_once())
    sldss, elboss, posteriors = list(zip(*results))

    # Take the SLDS that achieved the highest training ELBO
    best = np.argmax([elbos[-1] for elbos in elboss])
    slds, elbos, posterior = sldss[best], elboss[best], posteriors[best]

    if laplace_em:
        posterior_x = posterior.mean_continuous_states[0]
    else:
        posterior_x = posterior.mean[0]

    # Align the labels between true and most likely
    true_states = np.array([0 if i < T / 2 else 1 for i in range(T)])
    slds.permute(
        find_permutation(true_states, slds.most_likely_states(posterior_x, X),
                         true_K, Kmax))
    pred_states = slds.most_likely_states(posterior_x, X)
    print("predicted states:")
    print(pred_states)
    # extract predicted A1, A2 matrices
    Ahats, bhats = convert_slds_to_tvart(slds)
    # A_r = slds.dynamics.As
    # b_r = slds.dynamics.bs
    # Cs = slds.emissions.Cs[0]
    # A1_pred = Cs @ A_r[0] @ np.linalg.pinv(Cs)
    # A2_pred = Cs @ A_r[1] @ np.linalg.pinv(Cs)
    # compare inferred and true
    #err_inf = 0.5 * (np.max(np.abs(A1_pred[:] - A1[:])) + \
    #                 np.max(np.abs(A2_pred[:] - A2[:])))
    #err_2 = 0.5 * (norm(A1_pred - A1, 2) + \
    #               norm(A2_pred - A2, 2))
    #err_fro = 0.5 * (norm(A1_pred - A1, 'fro') + \
    #                 norm(A2_pred - A2, 'fro'))
    err_mse, err_inf, err_2, err_fro = errors(Ahats, bhats, pred_states,
                                              true_states, A1, A2, X)
    return (err_inf, err_2, err_fro, err_mse, elbos)
Example No. 7
def fit_arhmm_and_return_errors(X,
                                A1,
                                A2,
                                Kmax=4,
                                num_restarts=1,
                                num_iters=100,
                                rank=None):
    '''
    Fit an ARHMM to test data and return errors.
    
    Parameters
    ==========
    
    X : array, T x N
    A1 : array, N x N
    A2 : array, N x N
    '''
    # hardcoded
    true_K = 2
    # params
    N = X.shape[1]
    T = X.shape[0]

    if rank is not None:
        # project data down
        u, s, vt = np.linalg.svd(X)
        Xp = u[:, 0:rank] * s[0:rank]  # T x rank matrix
    else:
        Xp = X

    def _fit_once():
        # fit a model
        if rank is not None:
            arhmm = ssm.HMM(Kmax, rank, observations="ar")
        else:
            arhmm = ssm.HMM(Kmax, N, observations="ar")
        lls = arhmm.fit(Xp, num_iters=num_iters)
        return arhmm, lls

    # Fit num_restarts many models
    results = []
    for restart in range(num_restarts):
        print("restart ", restart + 1, " / ", num_restarts)
        results.append(_fit_once())
    arhmms, llss = list(zip(*results))

    # Take the ARHMM that achieved the highest training log likelihood
    best = np.argmax([lls[-1] for lls in llss])
    arhmm, lls = arhmms[best], llss[best]

    # xhat = arhmm.smooth(X)
    pred_states = arhmm.most_likely_states(Xp)

    # Align the labels between true and most likely
    true_states = np.array([0 if i < T / 2 else 1 for i in range(T)])
    arhmm.permute(find_permutation(true_states, pred_states, true_K, Kmax))
    print("predicted states:")
    print(pred_states)
    # extract predicted A1, A2 matrices
    Ahats, bhats = arhmm.observations.As, arhmm.observations.bs
    if rank is not None:
        # project back up
        Ahats = [vt[0:rank, :].T @ Ahat @ vt[0:rank, :] for Ahat in Ahats]
        bhats = [vt[0:rank, :].T @ bhat for bhat in bhats]

    # compare inferred and true dynamics
    err_mse, err_inf, err_2, err_fro = errors(Ahats, bhats, pred_states,
                                              true_states, A1, A2, X)
    return (err_inf, err_2, err_fro, err_mse, lls)
Example No. 8
             emissions="gaussian_orthog",
             single_subspace=True)
rslds.initialize(y)

q = SLDSTriDiagVariationalPosterior(rslds, y)
elbos = rslds.fit(q, y, num_iters=1000, initialize=False)


# In[5]:


# Get the posterior mean of the continuous states
xhat = q.mean[0]

# Find the permutation that matches the true and inferred states
rslds.permute(find_permutation(z, rslds.most_likely_states(xhat, y)))
zhat = rslds.most_likely_states(xhat, y)


# In[6]:


# Plot some results
plt.figure()
plt.plot(elbos)
plt.xlabel("Iteration")
plt.ylabel("ELBO")


# In[7]:
Example No. 9
# Now create a new HMM and fit it to the data with EM
N_iters = 50
hmm = ssm.HMM(K,
              D,
              M,
              observations="categorical",
              observation_kwargs=dict(C=C),
              transitions="inputdriven")

# Fit
hmm_lps = hmm.fit(y, inputs=inpt, method="em", num_em_iters=N_iters)

# In[5]:

# Find a permutation of the states that best matches the true and inferred states
hmm.permute(find_permutation(z, hmm.most_likely_states(y, input=inpt)))
z_inf = hmm.most_likely_states(y, input=inpt)

# In[6]:

# Plot the log probabilities of the true and fit models
plt.plot(hmm_lps, label="EM")
plt.plot([0, N_iters], true_lp * np.ones(2), ':k', label="True")
plt.legend(loc="lower right")
plt.xlabel("EM Iteration")
plt.xlim(0, N_iters)
plt.ylabel("Log Probability")

# In[7]:

# Plot the true and inferred states
Example No. 10
                     observation_kwargs=dict(C=num_categories),
                     transitions="inputdriven")
#new_glmhmm.observations = GLM_PoissonObservations(num_states, obs_dim, input_dim) ##obs:"input_driven"
new_glmhmm.observations = InputVonMisesObservations(num_states, obs_dim,
                                                    input_dim)

N_iters = 100  # maximum number of EM iterations; fitting will stop earlier if the increase in LL is below the tolerance parameter
fit_ll = new_glmhmm.fit(true_choices,
                        inputs=inpts,
                        method="em",
                        num_iters=N_iters)  #, tolerance=10**-4)

# %%
new_glmhmm.permute(
    find_permutation(
        true_latents[0],
        new_glmhmm.most_likely_states(true_choices[0], input=inpts[0])))

# %%
true_obs_ws = true_glmhmm.observations.mus  #Wk
inferred_obs_ws = new_glmhmm.observations.mus

cols = ['r', 'g', 'b']
plt.figure()
for ii in range(num_states):
    plt.plot(true_obs_ws[ii][0], linewidth=5, label='true', color=cols[ii])
    plt.plot(inferred_obs_ws[ii][0],
             '--',
             linewidth=5,
             label='inferred',
             color=cols[ii])
Example No. 11
plt.figure()
plt.imshow(true_hmm.transitions.transition_matrix, vmin=0.0, vmax=1.0, aspect="auto")

# fit multiple restarts and keep the model with the best log likelihood
N = 10
max_ll = -np.inf
for n in range(N):
    test_hmm_temp = HMM(K, D, observations="poisson") 
    poiss_lls_temp = test_hmm_temp.fit(y, num_iters=20)
    if poiss_lls_temp[-1] > max_ll:
        max_ll = poiss_lls_temp[-1]
        poiss_lls = poiss_lls_temp 
        test_hmm = test_hmm_temp
# test_hmm = HMM(K, D, observations="poisson") 
# poiss_lls = test_hmm.fit(y, num_iters=20)
test_hmm.permute(find_permutation(z, test_hmm.most_likely_states(y)))
smoothed_z = test_hmm.most_likely_states(y)

plt.figure()
plt.subplot(211)
plt.imshow(np.row_stack((z, smoothed_z)), aspect="auto")
plt.xlim([0,T_plot])
plt.subplot(212)
# plt.plot(y)
for n in range(D):
    plt.eventplot(np.where(y[:,n]>0)[0]+1, linelengths=0.5, lineoffsets=D-n,color='k')
plt.xlim([0,T_plot])


As = np.clip(0.8 + 0.1 * npr.randn(D), 0.6, 0.95)
betas = 1.0 * np.ones(D)
Example No. 12
N_iters = 200 # maximum number of EM iterations; fitting will stop earlier if the increase in LL is below the tolerance parameter
fit_ll = new_glmhmm.fit(true_choices, inputs=inpt, method="em", num_iters=N_iters, tolerance=10**-4)
# Plot the log probabilities of the true and fit models. Fit model final LL should be greater
# than or equal to true LL.
fig = plt.figure(figsize=(4, 3), dpi=80, facecolor='w', edgecolor='k')
plt.plot(fit_ll, label="EM")
plt.plot([0, len(fit_ll)], true_ll * np.ones(2), ':k', label="True")
plt.legend(loc="lower right")
plt.xlabel("EM Iteration")
plt.xlim(0, len(fit_ll))
plt.ylabel("Log Probability")
plt.show()

new_glmhmm.permute(find_permutation(true_latents[0].ravel(),
                                    new_glmhmm.most_likely_states(true_choices[0],
                                    input=inpt[0])))
fig = plt.figure( dpi=80, facecolor='w', edgecolor='k')
cols = ['#ff7f00', '#4daf4a', '#377eb8']
recovered_weights = new_glmhmm.observations.params
covariates =[
    'choice*reward t-1',
    'choice*reward t-2',
    'choice*reward t-3',
    'choice*reward t-4',
    'choice*reward t-5',
    'choice*reward t-6',
    'choice*reward t-7',
    'choice*reward t-8',
    'choice*reward t-9',
    'choice*reward t-10',
Example No. 13
                emissions="poisson_orthog",
                emission_kwargs=dict(link="softplus"))
slds.initialize(y_masked, masks=mask)

q_svi_elbos, q_svi = slds.fit(y_masked,
                              masks=mask,
                              method="bbvi",
                              variational_posterior="tridiag",
                              initial_variance=1,
                              num_iters=1000,
                              print_intvl=100,
                              initialize=False)
q_svi_x = q_svi.mean[0]

# Find the permutation that matches the true and inferred states
slds.permute(find_permutation(z, slds.most_likely_states(q_svi_x, y)))
q_svi_z = slds.most_likely_states(q_svi_x, y)

# Smooth the observations
q_svi_y = slds.smooth(q_svi_x, y)

# In[6]:

print("Fitting SLDS with Laplace-EM")
slds = ssm.SLDS(N,
                K,
                D,
                emissions="poisson_orthog",
                emission_kwargs=dict(link="softplus"))
slds.initialize(y_masked, masks=mask)
Example No. 14
    def test_hsmm_example(self):
        import autograd.numpy as np
        import autograd.numpy.random as npr
        from scipy.stats import nbinom
        import matplotlib.pyplot as plt
        import ssm
        from ssm.util import rle, find_permutation

        npr.seed(0)

        # Set the parameters of the HMM
        T = 5000  # number of time bins
        K = 5  # number of discrete states
        D = 2  # number of observed dimensions

        # Make an HMM with the true parameters
        true_hsmm = ssm.HSMM(K, D, observations="gaussian")
        print(true_hsmm.transitions.rs)
        z, y = true_hsmm.sample(T)
        z_test, y_test = true_hsmm.sample(T)
        true_ll = true_hsmm.log_probability(y)

        # Fit an HSMM
        N_em_iters = 500

        print("Fitting Gaussian HSMM with EM")
        hsmm = ssm.HSMM(K, D, observations="gaussian")
        hsmm_em_lls = hsmm.fit(y, method="em", num_em_iters=N_em_iters)

        print("Fitting Gaussian HMM with EM")
        hmm = ssm.HMM(K, D, observations="gaussian")
        hmm_em_lls = hmm.fit(y, method="em", num_em_iters=N_em_iters)

        # Plot log likelihoods (fit model is typically better)
        plt.figure()
        plt.plot(hsmm_em_lls, ls='-', label="HSMM (EM)")
        plt.plot(hmm_em_lls, ls='-', label="HMM (EM)")
        plt.plot(true_ll * np.ones(N_em_iters), ':', label="true")
        plt.legend(loc="lower right")

        # Print the test likelihoods (true model is typically better)
        print("Test log likelihood")
        print("True HSMM: ", true_hsmm.log_likelihood(y_test))
        print("Fit HSMM:  ", hsmm.log_likelihood(y_test))
        print("Fit HMM: ", hmm.log_likelihood(y_test))

        # Plot the true and inferred states
        hsmm.permute(find_permutation(z, hsmm.most_likely_states(y)))
        hsmm_z = hsmm.most_likely_states(y)
        hmm.permute(find_permutation(z, hmm.most_likely_states(y)))
        hmm_z = hmm.most_likely_states(y)

        # Plot the true and inferred discrete states
        plt.figure(figsize=(8, 6))
        plt.subplot(311)
        plt.imshow(z[None, :1000],
                   aspect="auto",
                   cmap="cubehelix",
                   vmin=0,
                   vmax=K - 1)
        plt.xlim(0, 1000)
        plt.ylabel("True $z")
        plt.yticks([])

        plt.subplot(312)
        plt.imshow(hsmm_z[None, :1000],
                   aspect="auto",
                   cmap="cubehelix",
                   vmin=0,
                   vmax=K - 1)
        plt.xlim(0, 1000)
        plt.ylabel("HSMM Inferred $z$")
        plt.yticks([])

        plt.subplot(313)
        plt.imshow(hmm_z[None, :1000],
                   aspect="auto",
                   cmap="cubehelix",
                   vmin=0,
                   vmax=K - 1)
        plt.xlim(0, 1000)
        plt.ylabel("HMM Inferred $z$")
        plt.yticks([])
        plt.xlabel("time")

        plt.tight_layout()

        # Plot the true and inferred duration distributions
        states, durations = rle(z)
        inf_states, inf_durations = rle(hsmm_z)
        hmm_inf_states, hmm_inf_durations = rle(hmm_z)
        max_duration = max(np.max(durations), np.max(inf_durations),
                           np.max(hmm_inf_durations))
        dd = np.arange(max_duration, step=1)

        plt.figure(figsize=(3 * K, 9))
        for k in range(K):
            # Plot the durations of the true states
            plt.subplot(3, K, k + 1)
            plt.hist(durations[states == k] - 1, dd, density=True)
            plt.plot(dd,
                     nbinom.pmf(dd, true_hsmm.transitions.rs[k],
                                1 - true_hsmm.transitions.ps[k]),
                     '-k',
                     lw=2,
                     label='true')
            if k == K - 1:
                plt.legend(loc="lower right")
            plt.title("State {} (N={})".format(k + 1, np.sum(states == k)))

            # Plot the durations of the states inferred by the HSMM
            plt.subplot(3, K, K + k + 1)
            plt.hist(inf_durations[inf_states == k] - 1, dd, density=True)
            plt.plot(dd,
                     nbinom.pmf(dd, hsmm.transitions.rs[k],
                                1 - hsmm.transitions.ps[k]),
                     '-r',
                     lw=2,
                     label='hsmm inf.')
            if k == K - 1:
                plt.legend(loc="lower right")
            plt.title("State {} (N={})".format(k + 1, np.sum(inf_states == k)))

            # Plot the durations of the states inferred by the HMM
            plt.subplot(3, K, 2 * K + k + 1)
            plt.hist(hmm_inf_durations[hmm_inf_states == k] - 1,
                     dd,
                     density=True)
            plt.plot(dd,
                     nbinom.pmf(dd, 1,
                                1 - hmm.transitions.transition_matrix[k, k]),
                     '-r',
                     lw=2,
                     label='hmm inf.')
            if k == K - 1:
                plt.legend(loc="lower right")
            plt.title("State {} (N={})".format(k + 1,
                                               np.sum(hmm_inf_states == k)))
        plt.tight_layout()

        plt.show()
Example No. 15
    def test_own_hsmm_example(self):
        import autograd.numpy as np
        import autograd.numpy.random as npr
        from scipy.stats import nbinom
        import matplotlib.pyplot as plt
        import ssm
        from ssm.util import rle, find_permutation

        npr.seed(0)

        # Set the parameters of the HMM
        T = 1000  # number of time bins (todo: why can't this be set < 500?)
        K = 8  # number of discrete states
        D = 5  # number of observed dimensions

        # Make an HMM with the true parameters
        true_hsmm = ssm.HSMM(K, D, observations="categorical")
        z, y = true_hsmm.sample(T)
        z_test, y_test = true_hsmm.sample(T)
        true_ll = true_hsmm.log_probability(y)

        # Fit an HSMM
        N_em_iters = 100

        print("Fitting Categorical HSMM with EM")
        hsmm = ssm.HSMM(K, D, observations="categorical")
        hsmm_em_lls = hsmm.fit(y, method="em", num_em_iters=N_em_iters)

        print("Fitting Categorical HMM with EM")
        hmm = ssm.HMM(K, D, observations="categorical")
        hmm_em_lls = hmm.fit(y, method="em", num_em_iters=N_em_iters)

        # Plot log likelihoods (fit model is typically better)
        plt.figure()
        plt.plot(hsmm_em_lls, ls='-', label="HSMM (EM)")
        plt.plot(hmm_em_lls, ls='-', label="HMM (EM)")
        plt.plot(true_ll * np.ones(N_em_iters), ':', label="true")
        plt.legend(loc="lower right")

        # Print the test likelihoods (true model is typically better)
        print("Test log likelihood")
        print("True HSMM: ", true_hsmm.log_likelihood(y_test))
        print("Fit HSMM:  ", hsmm.log_likelihood(y_test))
        print("Fit HMM: ", hmm.log_likelihood(y_test))

        # Plot the true and inferred states
        tmp1 = hsmm.most_likely_states(y)
        tmp2 = find_permutation(z, tmp1)
        hsmm.permute(tmp2)
        hsmm_z = hsmm.most_likely_states(y)

        # calculate the Viterbi sequence of states
        tmp3 = hmm.most_likely_states(y)
        """
        z = true state seq [1,2,1,....,]
        tmp3 = pred. state seq [3,4,1,2,...,]
        match each row to different column in such a way that corresp
        sum is minimized
        select n el of C, so that there is exactly one el.  in each row 
        and one in each col. with min corresp. costs 
        
        
        match states [1,2,...,] of of the 
        """
        tmp4 = find_permutation(z, tmp3)
        hmm.permute(tmp4)
        hmm_z = hmm.most_likely_states(y)

        # Plot the true and inferred discrete states
        plt.figure(figsize=(8, 6))
        plt.subplot(311)
        plt.imshow(z[None, :1000],
                   aspect="auto",
                   cmap="cubehelix",
                   vmin=0,
                   vmax=K - 1)
        plt.xlim(0, 1000)
        plt.ylabel("True $z")
        plt.yticks([])

        plt.subplot(312)
        plt.imshow(hsmm_z[None, :1000],
                   aspect="auto",
                   cmap="cubehelix",
                   vmin=0,
                   vmax=K - 1)
        plt.xlim(0, 1000)
        plt.ylabel("HSMM Inferred $z$")
        plt.yticks([])

        plt.subplot(313)
        plt.imshow(hmm_z[None, :1000],
                   aspect="auto",
                   cmap="cubehelix",
                   vmin=0,
                   vmax=K - 1)
        plt.xlim(0, 1000)
        plt.ylabel("HMM Inferred $z$")
        plt.yticks([])
        plt.xlabel("time")

        plt.tight_layout()

        # Plot the true and inferred duration distributions
        """
        N = the number of infered states 
            how often the state was inferred 
            blue bar is how often when one was in that state it endured x long
        x = maximal duration in a state
        
        
        red binomial plot
            for the hmm it is 1 trial and the self transitioning probability
            for the hsmm it is
            
        """
        """
        Negativ binomial distribution for state durations
        
            NB(r,p)
                r int, r>0
                p = [0,1] always .5 wk des eintretens von erfolgreicher transition
                r = anzahl erflogreiche selbst transitionen  befor man etwas anderes (trans in anderen
                zustand sieht)
                
                
        
        
        """

        true_states, true_durations = rle(z)
        hmm_inf_states, hmm_inf_durations = rle(hmm_z)
        hsmm_inf_states, hsmm_inf_durations = rle(hsmm_z)
        max_duration = max(np.max(true_durations), np.max(hsmm_inf_durations),
                           np.max(hmm_inf_durations))
        max_duration = 100  # cap the plotted duration range
        dd = np.arange(max_duration, step=1)

        plt.figure(figsize=(3 * K, 9))
        for k in range(K):
            # Plot the durations of the true states
            plt.subplot(3, K, k + 1)
            """
            get the durations where it was gone into the state k =1
            state_seq: [0,1,2,3,1,1]
            dur_seq: [1,4,5,2,4,2]
                meaning one ts in state 0, than 4 in state 1, 5 in state 2, so on and so forth
            x = [4,4,2]
            """
            x = true_durations[true_states == k] - 1
            plt.hist(x, dd, density=True)
            n = true_hsmm.transitions.rs[k]
            p = 1 - true_hsmm.transitions.ps[k]
            plt.plot(dd, nbinom.pmf(dd, n, p), '-k', lw=2, label='true')
            if k == K - 1:
                plt.legend(loc="lower right")
            plt.title("State {} (N={})".format(k + 1,
                                               np.sum(true_states == k)))

            # Plot the durations of the inferred states of hmm
            plt.subplot(3, K, 2 * K + k + 1)
            plt.hist(hmm_inf_durations[hmm_inf_states == k] - 1,
                     dd,
                     density=True)
            plt.plot(dd,
                     nbinom.pmf(dd, 1,
                                1 - hmm.transitions.transition_matrix[k, k]),
                     '-r',
                     lw=2,
                     label='hmm inf.')
            if k == K - 1:
                plt.legend(loc="lower right")
            plt.title("State {} (N={})".format(k + 1,
                                               np.sum(hmm_inf_states == k)))

            # Plot the durations of the inferred states of hsmm
            plt.subplot(3, K, K + k + 1)
            plt.hist(hsmm_inf_durations[hsmm_inf_states == k] - 1,
                     dd,
                     density=True)
            plt.plot(dd,
                     nbinom.pmf(dd, hsmm.transitions.rs[k],
                                1 - hsmm.transitions.ps[k]),
                     '-r',
                     lw=2,
                     label='hsmm inf.')
            if k == K - 1:
                plt.legend(loc="lower right")
            plt.title("State {} (N={})".format(k + 1,
                                               np.sum(hsmm_inf_states == k)))

        plt.tight_layout()

        plt.show()
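
The comments above describe HMM dwell times as the one-trial special case of the negative binomial. A minimal numerical check of that claim, assuming only numpy/scipy; p_self here is an arbitrary illustrative value, not a parameter from the fitted models:

import numpy as np
from scipy.stats import nbinom, geom

p_self = 0.9        # illustrative HMM self-transition probability
d = np.arange(50)   # duration - 1, matching the histograms above

# HMM dwell times are geometric, i.e. negative binomial with r = 1
# trial and success probability 1 - p_self, exactly as plotted above.
assert np.allclose(nbinom.pmf(d, 1, 1 - p_self),
                   geom.pmf(d + 1, 1 - p_self))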