def _durs_withlabel(self,k):
    """Return the durations of every run of label ``k`` in ``self.stateseq``.

    ``k`` must not be the ``SAMPLING`` sentinel. An empty state sequence
    yields an empty list.
    """
    assert k != SAMPLING
    if not len(self.stateseq):
        return []
    labels, run_lengths = rle(self.stateseq)
    return run_lengths[labels == k]
Example #2
0
    def plot_states(self, ax_data, ax_states, estimated_states_seq):
        """Plot where ``estimated_states_seq`` agrees with ``self.true_states``.

        ``ax_states`` receives one band per run of agreement (coloured green)
        or disagreement (red), using the ``RdYlGn`` colormap. ``ax_data``
        receives the true powers and the powers implied by the estimated
        states (looked up in ``self.controlled_params['air']['mu']``).

        Returns the fraction of time steps where the estimate matches.
        """
        # 1 where the two sequences agree, 0 where they differ.
        agree = np.array(estimated_states_seq == self.true_states, dtype=int)
        palette = {0: 0, 1: 1}
        run_values, run_lengths = rle(agree)
        lo, hi = 0., 1.
        edges = np.hstack((0, run_lengths.cumsum()))
        band = np.array([lo, hi])
        colors = np.atleast_2d([palette[v] for v in run_values])

        ax_states.pcolormesh(edges, band, colors, vmin=0, vmax=1, alpha=0.3, cmap='RdYlGn')
        ax_states.set_ylim((lo, hi))
        ax_states.set_xlim((0, len(agree)))
        ax_states.set_yticks([])
        ax_states.set_title('differences')

        # Overlay the powers implied by the estimated states as red dashes.
        means = self.controlled_params['air']['mu']
        estimated_powers = np.array(
            [means[int(s)] for s in estimated_states_seq])

        ax_data.plot(self.true_powers, label='true powers')
        ax_data.plot(estimated_powers, 'r--', label='estimated powers')
        ax_data.set_xlim((0, len(self.true_powers)))
        ax_data.legend(loc='best')

        matches = float(np.sum(agree))
        return matches / len(agree)
    def __init__(self,model,beta,alpha_0,obs,dur,data=None,T=None,stateseq=None):
        """Store the model pieces and initialise the state sequence.

        Four modes, selected by which of ``data`` / ``stateseq`` are given:

        - neither: generate ``T`` steps with ``self._generate(T)`` (``T`` required);
        - only ``stateseq``: use it as-is, ``T`` is its length;
        - only ``data``: random initial labels in ``[0, 25)``, one per row of data;
        - both: use the pair (lengths must match) and run-length encode the
          sequence into ``stateseq_norep`` / ``durations``.
        """
        self.alpha_0 = alpha_0

        self.model = model
        self.beta = beta
        self.obs = obs
        self.dur = dur

        self.data = data

        # Identity checks on purpose: ``(data, stateseq) == (None, None)``
        # compares a numpy array to None elementwise, and taking the truth
        # value of that result raises for any multi-element array.
        if data is None and stateseq is None:
            # generating
            assert T is not None, 'must pass in T when generating'
            self._generate(T)
        elif data is None:
            self.T = stateseq.shape[0]
            self.stateseq = stateseq
        elif stateseq is None:
            # NOTE(review): 25 is a hard-coded number of initial labels;
            # confirm it matches the model's number of states.
            self.stateseq = np.random.randint(25,size=data.shape[0])
            self.T = data.shape[0]
        else:
            assert data.shape[0] == stateseq.shape[0]
            self.stateseq = stateseq
            self.stateseq_norep, self.durations = rle(stateseq)
            self.T = data.shape[0]
 def _counts_to(self,k):
     """Count transitions into label ``k`` between consecutive runs.

     Transitions whose source run is the ``SAMPLING`` sentinel are excluded.
     ``k`` itself must not be ``SAMPLING``.
     """
     assert k != SAMPLING
     stateseq_norep, _ = rle(self.stateseq)
     # Pair each run (except the first) with its predecessor. This is well
     # defined however many times SAMPLING occurs, including as the final run,
     # where the old ``np.where(...)+1`` correction was ambiguous or indexed
     # past the end.
     into_k = stateseq_norep[1:] == k
     from_sampling = stateseq_norep[:-1] == SAMPLING
     return np.sum(into_k & ~from_sampling)
 def _counts_fromto(self,k1,k2):
     """Count direct transitions from label ``k1`` to label ``k2``.

     Counting is done on the run-length-encoded sequence, so ``k1 == k2``
     always gives 0. A label absent from ``self.stateseq`` also gives 0.
     Neither label may be the ``SAMPLING`` sentinel.
     """
     assert k1 != SAMPLING and k2 != SAMPLING
     if k1 not in self.stateseq or k2 not in self.stateseq or k1 == k2:
         return 0
     runs, _ = rle(self.stateseq)
     # Every run except the last can be the source of a transition.
     (sources,) = np.where(runs[:-1] == k1)
     successors = runs[sources + 1]
     return np.sum(successors == k2)
Example #6
0
def plot_states_and_var(data,
                        hidden_states,
                        cmap=None,
                        columns=None,
                        by='Activity'):
    """
    Make a plot of the data and the states.

    Parameters
    ----------
    data : pandas DataFrame
        Data to plot. Its index is formatted with ``strftime`` and
        ``date2num``, so it should be datetime-like.
    hidden_states : iterable
        The hidden states corresponding to the timesteps.
    cmap : matplotlib Colormap or str, optional
        Colormap for the state bands. If None, one is built with
        ``get_color_map`` for ``max(hidden_states) + 1`` states. A name is
        resolved with ``plt.get_cmap``.
    columns : list, optional
        Which columns to plot (default: all).
    by : str or None
        Column whose values colour a scatter row at the top of the plot;
        None disables it.

    Returns
    -------
    fig, ax : the matplotlib figure and axes.
    """
    fig, ax = plt.subplots(figsize=(15, 5))
    if columns is None:
        columns = data.columns
    df = data[columns].copy()
    # Materialise once so hidden_states may be any iterable (even a generator).
    stateseq = np.array(hidden_states)
    stateseq_norep, durations = rle(stateseq)
    # DataFrame.as_matrix() was removed in pandas 1.0; .values is equivalent.
    values = df.values
    datamin, datamax = values.min(), values.max()
    y = np.array([datamin, datamax])
    maxstate = stateseq.max() + 1
    # Run boundaries; the final edge is the last sample index.
    x = np.hstack(([0], durations.cumsum()[:-1], [len(df.index) - 1]))
    C = np.array([[float(state) / maxstate]
                  for state in stateseq_norep]).transpose()
    ax.set_xlim((min(x), max(x)))

    if cmap is None:
        _, cmap = get_color_map(int(maxstate))
    elif isinstance(cmap, str):
        # The colorbar code below needs a Colormap object (uses cmap.N).
        cmap = plt.get_cmap(cmap)
    pc = ax.pcolorfast(x, y, C, vmin=0, vmax=1, alpha=0.3, cmap=cmap)
    plt.plot(values)
    locator = AutoDateLocator()
    locator.create_dummy_axis()
    num_index = pd.Index(df.index.map(date2num))
    ticks_num = locator.tick_values(min(df.index), max(df.index))
    ticks = [num_index.get_loc(t) for t in ticks_num]
    plt.xticks(ticks, df.index.strftime('%H:%M')[ticks], rotation='vertical')
    cb = plt.colorbar(pc)
    # One tick at the centre of each of the cmap.N colour bins.
    cb.set_ticks(np.arange(1. / (2 * cmap.N), 1, 1. / cmap.N))
    cb.set_ticklabels(np.arange(0, cmap.N))
    # Plot the activities
    if by is not None:
        actseq = np.array(data[by])
        ax.scatter(
            np.arange(len(stateseq)),
            np.ones_like(stateseq) * datamax,
            c=actseq,
            edgecolors='none')
    plt.show()
    return fig, ax
Example #7
0
def plot_states_and_var(data, hidden_states, cmap=None, columns=None, by='Activity'):
    """
    Make a plot of the data and the states.

    Parameters
    ----------
    data : pandas DataFrame
        Data to plot. Its index is formatted with ``strftime`` and
        ``date2num``, so it should be datetime-like.
    hidden_states : iterable
        The hidden states corresponding to the timesteps.
    cmap : matplotlib Colormap or str, optional
        Colormap for the state bands. If None, one is built with
        ``get_color_map`` for ``max(hidden_states) + 1`` states. A name is
        resolved with ``plt.get_cmap``.
    columns : list, optional
        Which columns to plot (default: all).
    by : str or None
        Column whose values colour a scatter row at the top of the plot;
        None disables it.

    Returns
    -------
    fig, ax : the matplotlib figure and axes.
    """
    fig, ax = plt.subplots(figsize=(15, 5))
    if columns is None:
        columns = data.columns
    df = data[columns].copy()
    # Materialise once so hidden_states may be any iterable (even a generator).
    stateseq = np.array(hidden_states)
    stateseq_norep, durations = rle(stateseq)
    # DataFrame.as_matrix() was removed in pandas 1.0; .values is equivalent.
    values = df.values
    datamin, datamax = values.min(), values.max()
    y = np.array([datamin, datamax])
    maxstate = stateseq.max() + 1
    # Run boundaries; the final edge is the last sample index.
    x = np.hstack(([0], durations.cumsum()[:-1], [len(df.index) - 1]))
    C = np.array(
        [[float(state) / maxstate] for state in stateseq_norep]).transpose()
    ax.set_xlim((min(x), max(x)))

    if cmap is None:
        _, cmap = get_color_map(int(maxstate))
    elif isinstance(cmap, str):
        # The colorbar code below needs a Colormap object (uses cmap.N).
        cmap = plt.get_cmap(cmap)
    pc = ax.pcolorfast(x, y, C, vmin=0, vmax=1, alpha=0.3, cmap=cmap)
    plt.plot(values)
    locator = AutoDateLocator()
    locator.create_dummy_axis()
    num_index = pd.Index(df.index.map(date2num))
    ticks_num = locator.tick_values(min(df.index), max(df.index))
    ticks = [num_index.get_loc(t) for t in ticks_num]
    plt.xticks(ticks, df.index.strftime('%H:%M')[ticks], rotation='vertical')
    cb = plt.colorbar(pc)
    # One tick at the centre of each of the cmap.N colour bins.
    cb.set_ticks(np.arange(1./(2*cmap.N), 1, 1./cmap.N))
    cb.set_ticklabels(np.arange(0, cmap.N))
    # Plot the activities
    if by is not None:
        actseq = np.array(data[by])
        ax.scatter(
            np.arange(len(stateseq)),
            np.ones_like(stateseq) * datamax,
            c=actseq,
            edgecolors='none'
        )
    plt.show()
    return fig, ax
Example #8
0
    def resample(self,stateseqs=None):
        """Resample ``beta``, ``fullA`` and ``A`` given state sequences.

        Parameters
        ----------
        stateseqs : sequence or list of sequences, optional
            State sequences to condition on. A single non-list sequence is
            wrapped in a list; ``None`` (the default) means no data.

        Each sequence is run-length encoded. If none has at least two runs
        (one transition), every parameter is drawn from the prior. Otherwise
        the run-to-run transition counts are augmented with sampled
        self-transitions. Auxiliary counts ``m`` are then sampled and stored
        on ``self.m``, and ``beta`` followed by ``fullA``/``A`` are redrawn.
        ``A`` is ``fullA`` with its diagonal zeroed and rows renormalised.
        """
        if stateseqs is None:
            stateseqs = []
        elif not isinstance(stateseqs, list):
            stateseqs = [stateseqs]

        # A real list, not ``map(...)``: on Python 3 map is a one-shot
        # iterator, so the ``any`` check below would consume it and the
        # counting loop would silently see no data.
        states_noreps = [rle(x)[0] for x in stateseqs]

        if not any(len(states_norep) >= 2 for states_norep in states_noreps):
            # if there is no data we just sample from the prior
            self.beta = stats.gamma.rvs(self.alpha / self.state_dim, size=self.state_dim)
            self.beta /= np.sum(self.beta)

            self.fullA = stats.gamma.rvs(self.beta * self.gamma * np.ones((self.state_dim,1)))
            self.A = (1.-np.eye(self.state_dim)) * self.fullA
            self.fullA /= np.sum(self.fullA,axis=1)[:,na]
            self.A /= np.sum(self.A,axis=1)[:,na]

            assert not np.isnan(self.beta).any()
            assert not np.isnan(self.fullA).any()
            assert (self.A.diagonal() == 0).all()
        else:
            # make 2d array of transition counts between consecutive runs
            data = np.zeros((self.state_dim,self.state_dim))
            for states_norep in states_noreps:
                for idx in range(len(states_norep)-1):
                    data[states_norep[idx],states_norep[idx+1]] += 1
            # we resample the children (A) then the mother (beta)
            # first, we complete the data using the current parameters
            # every time we transferred from a state, we had geometrically many
            # self-transitions thrown away that we want to sample
            assert (data.diagonal() == 0).all()
            froms = np.sum(data,axis=1)
            self_transitions = np.array([np.sum(stats.geom.rvs(1.-self.fullA.diagonal()[idx],size=from_num)) if from_num > 0 else 0 for idx, from_num in enumerate(froms)])
            self_transitions[froms == 0] = 0 # really emphasized here!
            assert (self_transitions < 1e7).all(), 'maybe alpha is too low... code is not happy about that at the moment'
            augmented_data = data + np.diag(self_transitions)
            # then, compute m's (auxiliary table counts) one Bernoulli draw
            # per augmented transition
            m = np.zeros((self.state_dim,self.state_dim))
            for rowidx in range(self.state_dim):
                for colidx in range(self.state_dim):
                    n = 0.
                    for i in range(int(augmented_data[rowidx,colidx])):
                        m[rowidx,colidx] += random() < self.alpha * self.beta[colidx] / (n + self.alpha * self.beta[colidx])
                        n += 1.
            self.m = m # save it for possible use in any child classes

            # resample mother (beta)
            self.beta = stats.gamma.rvs(self.alpha / self.state_dim  + np.sum(m,axis=0))
            self.beta /= np.sum(self.beta)
            assert not np.isnan(self.beta).any()
            # resample children (fullA and A)
            self.fullA = stats.gamma.rvs(self.gamma * self.beta + augmented_data)
            self.fullA /= np.sum(self.fullA,axis=1)[:,na]
            self.A = self.fullA * (1.-np.eye(self.state_dim))
            self.A /= np.sum(self.A,axis=1)[:,na]
            assert not np.isnan(self.A).any()
Example #9
0
def slices_from_indicators(indseq):
    """Return one ``slice`` per run of truthy values in ``indseq``.

    An input with no truthy entries (including an empty one) yields ``[]``.
    Assumes ``cumsum(durs, strict=True)`` gives each run's start offset and
    ``strict=False`` its end offset (TODO confirm against ``cumsum``).
    """
    indseq = np.asarray(indseq)
    if not indseq.any():
        return []
    vals, durs = rle(indseq)
    starts = cumsum(durs, strict=True)
    ends = cumsum(durs, strict=False)
    result = []
    for val, start, end in zip(vals, starts, ends):
        if val:
            result.append(slice(start, end))
    return result
Example #10
0
    def plot(self,colors_dict=None):
        """Draw ``self.stateseq`` as a band of coloured runs spanning y in [0, 1].

        Parameters
        ----------
        colors_dict : dict, optional
            Maps each state to a colour value in [0, 1]. If None (now the
            default), the raw state labels are used as colour values.
        """
        from matplotlib import pyplot as plt
        stateseq_norep, durations = rle(self.stateseq)
        X,Y = np.meshgrid(np.hstack((0,durations.cumsum())),(0,1))

        if colors_dict is not None:
            C = np.array([[colors_dict[state] for state in stateseq_norep]])
        else:
            # NOTE(review): with vmin=0/vmax=1 below, every label >= 1 is
            # drawn in the same colour.
            C = stateseq_norep[na,:]

        plt.pcolor(X,Y,C,vmin=0,vmax=1)
        plt.ylim((0,1))
        plt.xlim((0,len(self.stateseq)))
        plt.yticks([])
Example #11
0
    def plot(self,colors_dict=None):
        """Draw ``self.stateseq`` as a band of coloured runs spanning y in [0, 1].

        ``colors_dict`` maps each state to a colour value in [0, 1]. Without
        it, the raw state labels are used as colour values.
        """
        from matplotlib import pyplot as plt
        from pyhsmm.util.general import rle
        runs, lengths = rle(self.stateseq)
        edges = np.hstack((0, lengths.cumsum()))
        X, Y = np.meshgrid(edges, (0, 1))

        if colors_dict is None:
            C = runs[na, :]
        else:
            C = np.array([[colors_dict[s] for s in runs]])

        plt.pcolor(X, Y, C, vmin=0, vmax=1)
        plt.ylim((0, 1))
        plt.xlim((0, lengths.sum()))
        plt.yticks([])
Example #12
0
    def sample_forwards(self,betal,betastarl):
        """Sample a state sequence forwards by running inline C++ via scipy.weave.

        ``betal`` / ``betastarl`` are presumably the backward messages (TODO
        confirm). The result is stored on ``self.stateseq``, together with its
        run-length encoding in ``self.stateseq_norep`` / ``self.durations``.
        """
        aBl = self.aBl
        # stateseq = np.array(self.stateseq,dtype=np.int32)
        # Output buffer of int32 labels, one per row of betal. It is passed to
        # the inline code by name and read back afterwards, so the C++ is
        # expected to fill it in place.
        stateseq = np.zeros(betal.shape[0],dtype=np.int32)
        A = self.transition_distn.A
        pi0 = self.initial_distn.pi_0

        # apmf[i, d-1] holds the pmf of state i's duration distribution at d,
        # for d = 1..T.
        apmf = np.zeros((self.state_dim,self.T))
        arg = np.arange(1,self.T+1)
        for state_idx, dur_distn in enumerate(self.dur_distns):
            apmf[state_idx] = dur_distn.pmf(arg)

        # NOTE(review): scipy.weave is not part of current SciPy releases; this
        # only runs on old installs with Eigen headers in /usr/local/include/eigen3.
        scipy.weave.inline(self.sample_forwards_codestr,['betal','betastarl','aBl','stateseq','A','pi0','apmf'],headers=['<Eigen/Core>'],include_dirs=['/usr/local/include/eigen3'],extra_compile_args=['-O3'])#,'-march=native'])

        self.stateseq_norep, self.durations = util.rle(stateseq)
        self.stateseq = stateseq
Example #13
0
    def plot_observations(self,colors=None,states_objs=None):
        """Plot each states object's observations, one line per run of a state.

        Parameters
        ----------
        colors : mapping, optional
            State -> position in the default colormap; defaults to
            ``self._get_colors()``.
        states_objs : list, optional
            States objects to plot; defaults to ``self.states_list``.
        """
        if colors is None:
            colors = self._get_colors()
        if states_objs is None:
            states_objs = self.states_list

        # matplotlib.cm.get_cmap was removed in Matplotlib 3.9; the pyplot
        # function returns the same default colormap.
        cmap = plt.get_cmap()

        for s in states_objs:
            data = undo_AR_striding(s.data,self.nlags)

            stateseq_norep, durs = rle(s.stateseq)
            starts = np.concatenate(((0,),durs.cumsum()))
            for state,start,dur in zip(stateseq_norep,starts,durs):
                segment = data[start:start+dur]
                # Use the segment's actual length so the x-range still matches
                # if data is shorter than the state sequence.
                plt.plot(
                        np.arange(start,start+segment.shape[0]),
                        segment,
                        color=cmap(colors[state]))
            plt.xlim(0,s.T-1)
def plot_stateseq(s, ax, Interval):
    """Draw the state sequence of ``s`` as coloured bands on ``ax``.

    Used states are ordered by their observation distribution's ``mu``. They
    are assigned evenly spaced values in [0, 1] and drawn with the 'summer'
    colormap. The bands span the data range vertically, and the x-axis is
    limited to ``Interval``.

    Parameters
    ----------
    s : states object exposing ``state_dim``, ``data``, ``stateseq`` and
        ``obs_distns``.
    ax : matplotlib Axes to draw on.
    Interval : (start, end) x-limits.
    """
    from matplotlib import pyplot as plt
    from pyhsmm.util.general import rle

    num_states = s.state_dim
    data = s.data
    stateseq = s.stateseq

    # Colors
    used_states = sorted(set(stateseq), key=lambda x: s.obs_distns[x].mu)
    colorseq = np.linspace(0, 1, num_states)
    state_colors = dict(zip(used_states, colorseq))
    # Unused states never appear in the run-length encoding below. Give them a
    # scalar value (not an RGBA tuple) so the colour array would stay numeric.
    for state in range(num_states):
        if state not in state_colors:
            state_colors[state] = 1.

    # State sequence colors
    stateseq_norep, durations = rle(stateseq)
    datamin, datamax = data.min(), data.max()

    x, y = np.hstack((0, durations.cumsum())), np.array([datamin, datamax])
    C = np.atleast_2d([state_colors[state] for state in stateseq_norep])

    # matplotlib.cm.get_cmap was removed in Matplotlib 3.9.
    ax.pcolormesh(x,
                  y,
                  C,
                  cmap=plt.get_cmap('summer'),
                  vmin=0,
                  vmax=1,
                  alpha=0.8)
    ax.set_ylim((datamin, datamax))
    ax.set_xlim((Interval[0], Interval[1]))
    ax.set_yticks([])
Example #15
0
    def __init__(self,T,state_dim,obs_distns,dur_distns,transition_distn,initial_distn,
            stateseq=None,trunc=None,data=None,**kwargs):
        """Store the model components and initialise the state sequence.

        A given ``stateseq`` is used as-is and run-length encoded into
        ``stateseq_norep`` / ``durations``. Without one, the sequence is
        resampled when ``data`` is present and generated otherwise.
        ``trunc`` defaults to ``T``; extra ``kwargs`` are accepted and ignored.
        """
        # TODO T parameter only makes sense with censoring. it should be
        # removed.
        self.T = T
        self.state_dim = state_dim
        self.obs_distns = obs_distns
        self.dur_distns = dur_distns
        self.transition_distn = transition_distn
        self.initial_distn = initial_distn
        self.trunc = trunc if trunc is not None else T
        self.data = data

        # Initialisation heuristics may hand us a precomputed state sequence.
        if stateseq is not None:
            self.stateseq = stateseq
            self.stateseq_norep, self.durations = util.rle(stateseq)
        elif data is not None:
            self.resample()
        else:
            self.generate_states()
Example #16
0
    def __init__(self,T,state_dim,obs_distns,dur_distns,transition_distn,initial_distn,stateseq=None,trunc=None,data=None,**kwargs):
        """Store the model components, build the sampler source, set up states.

        ``sample_forwards_codestr`` is ``hsmm_sample_forwards_codestr`` with
        ``M`` (state count) and ``T`` substituted. A given ``stateseq`` is
        used as-is and run-length encoded. Without one, the sequence is
        resampled when ``data`` is present and generated otherwise.
        ``trunc`` defaults to ``T``; extra ``kwargs`` are ignored.
        """
        self.T = T
        self.state_dim = state_dim
        self.obs_distns = obs_distns
        self.dur_distns = dur_distns
        self.transition_distn = transition_distn
        self.initial_distn = initial_distn
        self.trunc = trunc if trunc is not None else T
        self.data = data

        self.sample_forwards_codestr = hsmm_sample_forwards_codestr % dict(M=state_dim, T=T)

        # Initialisation heuristics may hand us a precomputed state sequence.
        if stateseq is not None:
            self.stateseq = stateseq
            self.stateseq_norep, self.durations = util.rle(stateseq)
        elif data is not None:
            self.resample()
        else:
            self.generate_states()
Example #17
0
 def durations(self):
     """Return the run lengths of ``self.stateseq``."""
     _, run_lengths = rle(self.stateseq)
     return run_lengths
Example #18
0
 def stateseq_norep(self):
     """Return ``self.stateseq`` with consecutive repeats collapsed (one label per run)."""
     labels, _ = rle(self.stateseq)
     return labels
Example #19
0
 def stateseqs_norep(self):
     """Return ``self.z`` with consecutive repeats collapsed (one label per run)."""
     labels, _ = rle(self.z)
     return labels
Example #20
0
 def durations(self):
     """Return the run lengths of ``self.z``."""
     _, run_lengths = rle(self.z)
     return run_lengths
Example #21
0
 def _count_transitions(self,stateseqs):
     """Delegate transition counting to the parent on run-collapsed sequences.

     Collapsing removes consecutive repeats, so the parent never sees
     self-transitions.
     """
     collapsed = []
     for seq in stateseqs:
         labels, _ = rle(seq)
         collapsed.append(labels)
     return super(_HSMMTransitionsBase,self)._count_transitions(collapsed)
Example #22
0
 def stateseq_norep(self):
     """Return ``self.stateseq`` with consecutive repeats collapsed (one label per run)."""
     labels, _ = rle(self.stateseq)
     return labels
Example #23
0
          aspect="auto")
# Finish the "Inferred Discrete States" panel: hide its tick labels.
ax.set_xticklabels([])
ax.set_yticks([])
ax.set_title("Inferred Discrete States")

# Bottom panel: first observed dimension against its smoothed estimate,
# limited to the first min(T, 500) time steps.
ax = fig.add_subplot(gs[2, 0])
plt.plot(y[:, 0], color='k', lw=2, label="observed")
plt.plot(smoothed_data[:, 0], color=colors[0], lw=1, label="smoothed")
plt.xlabel("Time")
plt.xlim(0, min(T, 500))
plt.ylabel("Observations")
plt.legend(loc="upper center", ncol=2)
plt.tight_layout()
# NOTE(review): "aux" is a reserved device name on Windows, so this directory
# cannot exist there; confirm the target platforms.
plt.savefig("aux/demo_smooth.png")

# Second figure: trajectory of the first two dimensions of x, drawn one
# segment per run of the discrete sequence z and coloured by that run's label.
plt.figure()
from pyhsmm.util.general import rle

z_rle = rle(z)
# offset advances by each run's length, so each slice covers exactly one run.
offset = 0
for k, dur in zip(*z_rle):
    plt.plot(x[offset:offset + dur, 0],
             x[offset:offset + dur, 1],
             color=colors[k])
    offset += dur

plt.xlabel("$x_1$")
plt.ylabel("$x_2$")
plt.title("Continuous Latent States")
plt.show()
Example #24
0
    def durations(self):
        """Return the run lengths of ``self.letterseq``, computing them once.

        The first call also caches the run labels in ``self._letterseq_norep``.
        """
        if self._durations is not None:
            return self._durations
        self._letterseq_norep, self._durations = rle(self.letterseq)
        return self._durations
Example #25
0
 def generate_obs(self):
     """Sample observations run by run along ``self.stateseq``.

     Each run of a state draws as many samples as the run is long from that
     state's observation distribution. The per-run samples are concatenated
     in order.
     """
     labels, lengths = rle(self.stateseq)
     return np.concatenate([
         self.obs_distns[label].rvs(int(length))
         for label, length in zip(labels, lengths)
     ])
Example #26
0
 def durations(self):
     """Return the run lengths of ``self.z``."""
     _, run_lengths = rle(self.z)
     return run_lengths
        ) > bestposteriormodel.log_likelihood():
            print("We have a winner")
            bestposteriormodel = posteriormodel

    plot(bestposteriormodel)
    plt.title(mode)

    plt.show()

    thestateseq = bestposteriormodel.states_list[-1].stateseq
    # NOTE: here we fix the state-ambiguity by flipping the inferred state seq if it's become inverted
    if np.mean(thestateseq) > 0.5:
        thestateseq = 1 - thestateseq
    resultstore[mode] = {
        'stateseq': thestateseq,
        'rle': rle(thestateseq),
    }

# Register the ground-truth sequence in resultstore under 'true', in the same
# {'stateseq', 'rle'} shape as the inferred results.
resultstore['true'] = dict(
    stateseq=trueseq,
    rle=rle(trueseq),
)

##############################################################################

# Order in which the results are reported.
modelist = ['true', 'geom', 'pois']
modedata = {
Example #28
0
    def stateseq_norep(self):
        """Return the run labels of ``self.stateseq``, computing them once."""
        if self._stateseq_norep is None:
            self._stateseq_norep = rle(self.stateseq)[0]
        return self._stateseq_norep
Example #29
0
Nmaxsuper = 2*Nsuper
Nmaxsub = 2*Nsub

obs_distnss = \
        [[pyhsmm.distributions.Gaussian(**obs_hypparams)
            for substate in range(Nmaxsub)] for superstate in range(Nmaxsuper)]

dur_distns = \
        [pyhsmm.distributions.NegativeBinomialIntegerR2Duration(
            **dur_hypparams) for superstate in range(Nmaxsuper)]

# !!! cheat to get the changepoints !!! #
# Turn each true label sequence into consecutive (start, end) run segments.
changepointss = []
for data, labels in zip(datas,labelss):
    _, durations = rle(labels)
    temp = np.concatenate(((0,),durations.cumsum()))
    # list() so the item assignment below also works on Python 3, where zip
    # returns an iterator.
    changepoints = list(zip(temp[:-1],temp[1:]))
    changepoints[-1] = (changepoints[-1][0],T) # because last duration might be censored
    changepointss.append(changepoints)

# optionally split changepoints
# changepoints = [pair for (a,b) in changepoints for pair in [(a,a+(b-a)//2), (a+(b-a)//2,b)]]
# Call form prints the same on Python 2 and 3. Note it reports only the last
# sequence's segment count.
print(len(changepoints))


# # plot things to check!!
# plt.figure()
# for idx, (labels, changepoints) in enumerate(zip(labelss,changepointss)):
#     plt.subplot(len(changepointss),1,idx)
#     plt.plot(labels)
Example #30
0
 def durations(self):
     """Return the run lengths of ``self.stateseq``."""
     _, run_lengths = rle(self.stateseq)
     return run_lengths
Example #31
0
Nmaxsuper = 2 * Nsuper
Nmaxsub = 2 * Nsub

obs_distnss = \
        [[pyhsmm.distributions.Gaussian(**obs_hypparams)
            for substate in range(Nmaxsub)] for superstate in range(Nmaxsuper)]

dur_distns = \
        [pyhsmm.distributions.NegativeBinomialIntegerR2Duration(
            **dur_hypparams) for superstate in range(Nmaxsuper)]

# !!! cheat to get the changepoints !!! #
# Turn each true label sequence into consecutive (start, end) run segments.
changepointss = []
for data, labels in zip(datas, labelss):
    _, durations = rle(labels)
    temp = np.concatenate(((0, ), durations.cumsum()))
    # list() so the item assignment below also works on Python 3, where zip
    # returns an iterator.
    changepoints = list(zip(temp[:-1], temp[1:]))
    # because last duration might be censored
    changepoints[-1] = (changepoints[-1][0], T)
    changepointss.append(changepoints)

# optionally split changepoints
# changepoints = [pair for (a,b) in changepoints for pair in [(a,a+(b-a)//2), (a+(b-a)//2,b)]]
# Call form prints the same on Python 2 and 3. Note it reports only the last
# sequence's segment count.
print(len(changepoints))

# # plot things to check!!
# plt.figure()
# for idx, (labels, changepoints) in enumerate(zip(labelss,changepointss)):
#     plt.subplot(len(changepointss),1,idx)
#     plt.plot(labels)
Example #32
0
 def _count_transitions(self,stateseqs):
     """Delegate transition counting to the parent on run-collapsed sequences.

     Collapsing removes consecutive repeats, so the parent never sees
     self-transitions.
     """
     collapsed = []
     for seq in stateseqs:
         labels, _ = rle(seq)
         collapsed.append(labels)
     return super(_HSMMTransitionsBase,self)._count_transitions(collapsed)
Example #33
0
# Middle panel: the first states object's inferred discrete state sequence as
# a one-row image; the colour scale spans labels 0..max(len(colors), num_states)-1.
ax = fig.add_subplot(gs[1,0])
ax.imshow(test_model.states_list[0].stateseq[None,:], vmin=0, vmax=max(len(colors), test_model.num_states)-1,
          cmap=cmap, interpolation="nearest", aspect="auto")
ax.set_xticklabels([])
ax.set_yticks([])
ax.set_title("Inferred Discrete States")

# Bottom panel: first observed dimension against its smoothed estimate,
# limited to the first min(T, 500) time steps.
ax = fig.add_subplot(gs[2,0])
plt.plot(y[:,0], color='k', lw=2, label="observed")
plt.plot(smoothed_data[:,0], color=colors[0], lw=1, label="smoothed")
plt.xlabel("Time")
plt.xlim(0, min(T, 500))
plt.ylabel("Observations")
plt.legend(loc="upper center", ncol=2)
plt.tight_layout()
# NOTE(review): "aux" is a reserved device name on Windows, so this directory
# cannot exist there; confirm the target platforms.
plt.savefig("aux/demo_smooth.png")

# Second figure: trajectory of the first two dimensions of x, drawn one
# segment per run of the discrete sequence z and coloured by that run's label.
plt.figure()
from pyhsmm.util.general import rle
z_rle = rle(z)
# offset advances by each run's length, so each slice covers exactly one run.
offset = 0
for k, dur in zip(*z_rle):
    plt.plot(x[offset:offset+dur,0], x[offset:offset+dur,1], color=colors[k])
    offset += dur

plt.xlabel("$x_1$")
plt.ylabel("$x_2$")
plt.title("Continuous Latent States")
plt.show()
Example #34
0
 def stateseqs_norep(self):
     """Return ``self.z`` with consecutive repeats collapsed (one label per run)."""
     labels, _ = rle(self.z)
     return labels
Example #35
0
 def durations_censored(self):
     """Return the run lengths of ``self.stateseq``, computing them once.

     The first call also caches the run labels in ``self._stateseq_norep``.
     """
     if self._durations_censored is not None:
         return self._durations_censored
     self._stateseq_norep, self._durations_censored = rle(self.stateseq)
     return self._durations_censored
Example #36
0
 def durations_censored(self):
     """Return the run lengths of ``self.stateseq``, computing them once.

     The first call also caches the run labels in ``self._stateseq_norep``.
     """
     if self._durations_censored is not None:
         return self._durations_censored
     self._stateseq_norep, self._durations_censored = rle(self.stateseq)
     return self._durations_censored
Example #37
0
def plot_states(posteriormodel,data=None,powers_hat=None,cmap='summer'):
    """Plot data, estimated powers and inferred states in 1000-step windows.

    For every window of 1000 observations, two grid rows are drawn:

    - the data (red) with ``powers_hat`` (blue dashes);
    - the first data set's state sequence as coloured bands, plus a narrow
      strip labelling the colour of each state.

    Parameters
    ----------
    posteriormodel : fitted model exposing ``datas``, ``stateseqs`` and
        ``obs_distns`` (each distribution with a ``mu`` attribute).
    data, powers_hat : array-like, optional
        Series for the data rows; a ``None`` series is skipped.
    cmap : str
        Colormap name. It is also written to ``plt.rcParams['image.cmap']``.

    Returns
    -------
    matplotlib Figure
    """
    plt.rcParams['image.cmap'] = cmap
    Nb_obs = posteriormodel.datas[0].shape[0]
    Nb_segment = int(np.ceil(Nb_obs / 1000.))

    fig = plt.figure(figsize=(16,(5+2)*Nb_segment))
    height_ratios = ([5]+[2])*Nb_segment
    gs = GridSpec(2*Nb_segment,2,width_ratios=[15, 1],height_ratios=height_ratios,wspace=0.05)

    for t in np.arange(0,Nb_obs,1000):
        # Data in red, estimated powers in blue dashes.
        data_ax = plt.subplot(gs[(t//1000)*2,0])
        if data is not None:
            data_ax.plot(data,color='red',label='true powers')
        if powers_hat is not None:
            data_ax.plot(powers_hat,'b--',label='estimated powers')
        data_ax.set_xlim((t,t+1000))
        data_ax.legend(loc='best')

    stateseq = posteriormodel.stateseqs[0]
    unique_states = np.sort(list(set(stateseq)))
    n = len(unique_states)

    # Colors: states ordered by emission mean get evenly spaced values in [0, 1].
    used_states = sorted(set(stateseq), key=lambda x: posteriormodel.obs_distns[x].mu)
    state_colors = dict(zip(used_states, np.linspace(0,1,n)))

    # Colorbar: one cell per state, in increasing label order. float() keeps
    # these true divisions under Python 2 as well.
    C_bar = np.atleast_2d([state_colors[state] for state in unique_states])
    x_bar, y_bar = np.array([0.,1.]), np.arange(n+1)/float(n)

    # State sequence
    stateseq_norep, durations = rle(stateseq)
    x, y = np.hstack((0,durations.cumsum())), np.array([0.,1.])
    C = np.atleast_2d([state_colors[state] for state in stateseq_norep])

    for t in np.arange(0,Nb_obs,1000):

        stateseq_ax = plt.subplot(gs[(t//1000)*2 + 1,0])

        colorbar_ax = plt.subplot(gs[(t//1000)*2 + 1,1])
        colorbar_ax.set_ylim((0.,1.))
        colorbar_ax.set_xlim((0.,1.))
        colorbar_ax.set_xticks([])
        colorbar_ax.yaxis.set_label_position("right")
        colorbar_ax.get_yaxis().set_ticks([])

        # State sequence
        stateseq_ax.pcolormesh(x,y,C,cmap=cmap,vmin=0,vmax=1,alpha=0.5)
        stateseq_ax.set_ylim((0.,1.))
        stateseq_ax.set_xlim((t,t+1000))
        stateseq_ax.set_yticks([])
        stateseq_ax.set_title('Hidden states sequence')

        # We plot a colorbar to indicate the color of each state
        colorbar_ax.pcolormesh(x_bar,y_bar,C_bar.T,vmin=0,vmax=1,alpha=0.5,cmap=cmap)
        for j in range(n):
            # Label at the vertical centre of cell j.
            colorbar_ax.text(.5, (1+j*2)/(2.*n), str(unique_states[j]), ha='center', va='center')
        colorbar_ax.get_yaxis().labelpad = 15
        colorbar_ax.set_ylabel('States', rotation=270)

    gs.tight_layout(fig, rect=[0, 0.03, 1, 0.97])
    return fig