def resample_states(self,**kwargs):
        """Resample the hidden state sequences of all attached sequences at once.

        The per-state observation parameters and normalizers are expanded with
        ``repeat(rs, axis=0)`` (``rs`` read from the states objects) and passed,
        together with every sequence's initial distribution, transition matrix,
        un-strided data, an int32 output buffer and per-step uniform draws, to
        the compiled ``messages.resample_arhmm`` routine.

        Afterwards each states object receives its sampled ``stateseq``, has
        ``_map_states()`` called, and gets its log-likelihood stored in
        ``_normalizer``. The statistics returned by the sampler are grouped into
        consecutive runs of lengths ``rs`` and summed. This presumably folds the
        expanded rows back into one entry per original state. The result is
        cached in ``self._obs_stats``. With no sequences attached,
        ``self._obs_stats`` is set to ``None``.

        Parameters
        ----------
        **kwargs
            Ignored; accepted only to absorb arguments passed along by
            multiprocessing wrappers.
        """
        # TODO only use this when the number/size of sequences warrant it
        from messages import resample_arhmm
        assert self.obs_distns[0].D_out > 1
        if len(self.states_list) > 0:
            stateseqs = [np.empty(s.T,dtype='int32') for s in self.states_list]
            params, normalizers = map(np.array,zip(*[self._param_matrix(o) for o in self.obs_distns]))

            # Repeat counts used to expand the per-state parameters (presumably
            # the number of sub-states per state -- confirm). The original code
            # read ``s.rs`` through a list-comprehension variable leaking out of
            # its comprehension. That works only on Python 2 and is a NameError
            # on Python 3. On Python 2 the leaked name was bound to the LAST
            # states object, so that choice is made explicit here.
            rs = self.states_list[-1].rs
            params, normalizers = params.repeat(rs,axis=0), normalizers.repeat(rs,axis=0)

            stats, _, loglikes = resample_arhmm(
                    [s.hmm_pi_0.astype(self.dtype) for s in self.states_list],
                    [s.hmm_trans_matrix.astype(self.dtype) for s in self.states_list],
                    params.astype(self.dtype), normalizers.astype(self.dtype),
                    [undo_AR_striding(s.data,self.nlags) for s in self.states_list],
                    stateseqs,
                    [np.random.uniform(size=s.T).astype(self.dtype) for s in self.states_list],
                    self.alphans)
            for s, stateseq, loglike in zip(self.states_list,stateseqs,loglikes):
                s.stateseq = stateseq
                s._map_states()
                s._normalizer = loglike

            # ``cumsum`` is a project helper. Presumably ``strict=True`` yields
            # the exclusive prefix sums (group starts) and ``strict=False`` the
            # inclusive ones (group ends) -- confirm.
            starts, ends = cumsum(rs,strict=True), cumsum(rs,strict=False)
            # Build a real list: on Python 3 ``map`` returns an iterator, which
            # cannot be sliced below.
            stats = [np.array(stat) for stat in stats]
            stats = [sum(stats[start:end]) for start, end in zip(starts,ends)]

            self._obs_stats = stats
        else:
            self._obs_stats = None
Beispiel #2
0
    def resample_states(self, **kwargs):
        """Resample the hidden state sequence of every attached sequence.

        All sequences are handed to the compiled ``messages.resample_arhmm``
        routine in a single batch. On return:

        * each states object receives its new ``stateseq`` and its
          log-likelihood in ``_normalizer``;
        * the sampler's aggregate outputs are cached on ``self`` as
          ``_obs_stats`` and ``_transcounts``.

        With no sequences attached, ``_obs_stats`` is reset to ``None`` and
        ``_transcounts`` to an empty list.

        ``**kwargs`` is ignored.
        """
        from messages import resample_arhmm

        seqs = self.states_list
        if len(seqs) == 0:
            self._obs_stats = None
            self._transcounts = []
            return

        # Output buffers: the sampler fills these, and they become the new
        # state sequences below.
        new_stateseqs = [np.empty(seq.T, dtype='int32') for seq in seqs]

        per_state = [self._param_matrix(distn) for distn in self.obs_distns]
        params, normalizers = (np.array(column) for column in zip(*per_state))

        init_distns = [seq.pi_0.astype(self.dtype) for seq in seqs]
        trans_matrices = [seq.trans_matrix.astype(self.dtype) for seq in seqs]
        params = params.astype(self.dtype)
        normalizers = normalizers.astype(self.dtype)
        raw_datas = [undo_AR_striding(seq.data, self.nlags) for seq in seqs]
        uniforms = [
            np.random.uniform(size=seq.T).astype(self.dtype) for seq in seqs
        ]

        stats, transcounts, loglikes = resample_arhmm(
            init_distns, trans_matrices, params, normalizers,
            raw_datas, new_stateseqs, uniforms, self.alphans)

        for seq, stateseq, loglike in zip(seqs, new_stateseqs, loglikes):
            seq.stateseq = stateseq
            seq._normalizer = loglike

        self._obs_stats = stats
        self._transcounts = transcounts
    def _generate_obs(self,s):
        """Sample observations for the states object ``s`` from its ``stateseq``.

        Two modes:

        * ``s.data is None``: draw a brand-new sequence.
          - The first ``self.nlags`` rows come from ``self.prefix`` when that
            attribute exists, otherwise from one draw of
            ``self.init_emission_distn`` reshaped to fit.
          - Each later row is drawn from ``self.obs_distns[state].rvs``,
            conditioned on the preceding ``nlags`` rows.
          - The result is stored on ``s.data`` via ``AR_striding``.
        * Otherwise: redraw only the rows (past the first ``nlags``) that
          contain a NaN. This happens in increasing row order, so later draws
          condition on earlier fills. The first ``nlags`` rows must be NaN-free.

        Returns the un-strided data array. For a new sequence its shape is
        ``(s.T + self.nlags, self.D)``.
        """
        lags = self.nlags

        if s.data is not None:
            # Filling in missing data. NOTE(review): the writes below reach
            # ``s.data`` only if ``undo_AR_striding`` returns a view -- confirm.
            data = undo_AR_striding(s.data, lags)

            # TODO should sample from init_emission_distn if there are nans in
            # the first ``nlags`` rows
            assert not np.isnan(data[:lags]).any(), "can't have missing data (nans) in prefix"

            missing = np.flatnonzero(np.isnan(data[lags:]).any(1))
            for t, state in zip(missing, s.stateseq[missing]):
                data[t + lags] = self.obs_distns[state].rvs(lagged_data=data[t:t + lags])
            return data

        # Generating a brand-new data sequence.
        data = np.zeros((s.T + lags, self.D))
        if hasattr(self, 'prefix'):
            data[:lags] = self.prefix
        else:
            data[:lags] = self.init_emission_distn.rvs().reshape(data[:lags].shape)

        for t, state in enumerate(s.stateseq):
            data[t + lags] = self.obs_distns[state].rvs(lagged_data=data[t:t + lags])

        s.data = AR_striding(data, lags)
        return data
Beispiel #4
0
    def _generate_obs(self, s):
        """Draw a new observation sequence for ``s`` or fill its NaN rows.

        When ``s.data`` is ``None``:

        * A ``(s.T + nlags, self.D)`` array is built. Its first ``nlags`` rows
          come from ``self.prefix`` if that attribute is defined, otherwise
          from ``self.init_emission_distn.rvs()``.
        * Each remaining row is sampled from the emission distribution of the
          corresponding state, given the previous ``nlags`` rows.
        * The array is stored AR-strided on ``s.data``.

        Otherwise the existing data are un-strided and every post-prefix row
        containing a NaN is resampled the same way, in increasing row order.
        NaNs in the prefix are rejected by an assertion.

        Returns the un-strided data array.
        """
        lag = self.nlags

        if s.data is None:
            full = np.zeros((s.T + lag, self.D))
            full[:lag] = (self.prefix if hasattr(self, 'prefix')
                          else self.init_emission_distn.rvs().reshape(full[:lag].shape))

            for pos, state in enumerate(s.stateseq):
                window = full[pos:pos + lag]
                full[pos + lag] = self.obs_distns[state].rvs(lagged_data=window)

            s.data = AR_striding(full, lag)
            return full

        # Existing sequence: resample only rows that contain NaNs.
        # NOTE(review): in-place writes reach ``s.data`` only if
        # ``undo_AR_striding`` returns a view -- confirm.
        full = undo_AR_striding(s.data, lag)

        # TODO should sample from init_emission_distn if there are nans in
        # the prefix rows
        assert not np.isnan(full[:lag]).any(), "can't have missing data (nans) in prefix"

        rows_with_nan = np.nonzero(np.isnan(full[lag:]).any(1))[0]
        for pos, state in zip(rows_with_nan, s.stateseq[rows_with_nan]):
            window = full[pos:pos + lag]
            full[pos + lag] = self.obs_distns[state].rvs(lagged_data=window)

        return full
Beispiel #5
0
    def _plot_stateseq_data_values(self,
                                   s,
                                   ax,
                                   state_colors,
                                   plot_slice,
                                   update,
                                   data=None):
        """Draw, or recolor, the observation traces of ``s`` colored by state.

        Each column of the un-strided data (restricted to ``plot_slice``)
        becomes a piecewise line. Its pieces come from ``AR_striding(..., 1)``
        (presumably consecutive point pairs), and all pieces live in a single
        ``LineCollection``.

        The color array is ``state_colors`` looked up on the state sequence.
        That sequence is left-padded with ``self.nlags`` copies of its first
        state, has its last entry dropped, and is tiled once per data column.

        When ``update`` is true and ``s`` already caches a ``_data_lc``, only
        that collection's color array is refreshed. Otherwise a new collection
        is built, cached on ``s._data_lc`` and added to ``ax``.

        The ``data`` argument is ignored. Returns ``s._data_lc``.
        """
        lag = self.nlags
        values = undo_AR_striding(s.data, lag)[plot_slice]
        padded_states = np.concatenate(
            (np.repeat(s.stateseq[0], lag), s.stateseq[:-1]))[plot_slice]
        step_colors = np.array([state_colors[k] for k in padded_states])
        colorseq = np.tile(step_colors, values.shape[1])

        if not (update and hasattr(s, '_data_lc')):
            times = np.arange(values.shape[0])
            pieces = []
            for column in values.T:
                xy = np.hstack((times.reshape(-1, 1), column.reshape(-1, 1)))
                pieces.append(AR_striding(xy, 1).reshape(-1, 2, 2))
            lc = s._data_lc = LineCollection(np.vstack(pieces))
            lc.set_array(colorseq)
            lc.set_linewidth(0.5)
            ax.add_collection(lc)
        else:
            s._data_lc.set_array(colorseq)

        return s._data_lc
Beispiel #6
0
 def _reshape_data(self,data):
     assert isinstance(data,np.ndarray)
     if len(data) > 0:
         data = AR_striding(
                 scipy.linalg.solve_triangular(
                     self.Sigma_chol,
                     undo_AR_striding(data,self.nlags).T,
                     lower=True).T,
                 nlags=self.nlags)
     return data
Beispiel #7
0
    def plot_observations(self,colors=None,states_objs=None):
        """Plot each sequence's observations, one ``plt.plot`` call per state run.

        Parameters
        ----------
        colors : indexable, optional
            Indexed by state to give the value passed to the default colormap.
            Presumably scalars in [0, 1] -- confirm against ``_get_colors``.
            Defaults to ``self._get_colors()``.
        states_objs : iterable, optional
            States objects to plot; defaults to ``self.states_list``. Each must
            provide ``data``, ``stateseq_norep`` and ``durations``.
        """
        # TODO make this a pcolor background to keep track of colors
        if colors is None:
            colors = self._get_colors()
        if states_objs is None:
            states_objs = self.states_list

        # ``matplotlib.cm.get_cmap`` was deprecated in Matplotlib 3.7 and removed
        # in 3.9. ``pyplot.get_cmap()`` with no name returns the same default
        # colormap (rcParams['image.cmap']) on every version.
        cmap = plt.get_cmap()

        for s in states_objs:
            # NOTE(review): ``data`` presumably still includes the ``self.nlags``
            # prefix rows, while ``durations`` presumably count only the
            # post-prefix observations, so runs may be offset by ``nlags`` rows
            # -- confirm.
            data = undo_AR_striding(s.data,self.nlags)

            starts = np.concatenate(((0,), s.durations.cumsum()))[:-1]
            for state, start, dur in zip(s.stateseq_norep, starts, s.durations):
                segment = data[start:start+dur]
                # Length from the slice, so a run past the end of ``data`` is
                # clipped instead of mismatching x and y.
                plt.plot(np.arange(start, start+segment.shape[0]), segment,
                         color=cmap(colors[state]))
    def _plot_stateseq_data_values(self,s,ax,state_colors,plot_slice,update,data=None):
        """Build or refresh a state-colored ``LineCollection`` of ``s``'s data.

        The observations are un-strided and cut to ``plot_slice``. The color
        array is ``state_colors`` evaluated on the state sequence, which is
        left-padded with ``self.nlags`` copies of its first state and has its
        last entry dropped. That array is repeated for every data column.

        If ``update`` is set and ``s._data_lc`` already exists, only its color
        array is replaced. Otherwise a new collection is made from per-column
        pieces produced by ``AR_striding(..., 1)``. It gets line width 0.5, is
        cached on ``s._data_lc`` and is added to ``ax``.

        ``data`` is accepted but unused. Returns ``s._data_lc``.
        """
        observed = undo_AR_striding(s.data,self.nlags)[plot_slice]
        lagged_states = np.concatenate((np.repeat(s.stateseq[0],self.nlags),s.stateseq[:-1]))
        lagged_states = lagged_states[plot_slice]
        num_columns = observed.shape[1]
        colorseq = np.tile(np.array([state_colors[st] for st in lagged_states]),num_columns)

        if update and hasattr(s,'_data_lc'):
            s._data_lc.set_array(colorseq)
            return s._data_lc

        ts = np.arange(observed.shape[0])
        segments = np.vstack([
            AR_striding(np.hstack((ts[:,None], dim_values[:,None])),1).reshape(-1,2,2)
            for dim_values in observed.T])
        collection = s._data_lc = LineCollection(segments)
        collection.set_array(colorseq)
        collection.set_linewidth(0.5)
        ax.add_collection(collection)
        return s._data_lc
Beispiel #9
0
    def plot_observations(self,colors=None,states_objs=None):
        """Plot each sequence's observations, one ``plt.plot`` call per state run.

        Runs are found by run-length encoding ``s.stateseq`` with ``rle``,
        which presumably returns ``(values, lengths)`` -- confirm. After each
        sequence, the x-axis is limited to ``[0, s.T - 1]``.

        Parameters
        ----------
        colors : indexable, optional
            Indexed by state to give the value passed to the default colormap.
            Presumably scalars in [0, 1] -- confirm against ``_get_colors``.
            Defaults to ``self._get_colors()``.
        states_objs : iterable, optional
            States objects to plot; defaults to ``self.states_list``.
        """
        if colors is None:
            colors = self._get_colors()
        if states_objs is None:
            states_objs = self.states_list

        # ``matplotlib.cm.get_cmap`` was deprecated in Matplotlib 3.7 and removed
        # in 3.9. ``pyplot.get_cmap()`` with no name returns the same default
        # colormap (rcParams['image.cmap']) on every version.
        cmap = plt.get_cmap()

        for s in states_objs:
            # NOTE(review): ``data`` presumably still includes the ``self.nlags``
            # prefix rows, while the run positions below index ``s.stateseq``,
            # so runs may be offset by ``nlags`` rows -- confirm.
            data = undo_AR_striding(s.data,self.nlags)

            stateseq_norep, durs = rle(s.stateseq)
            starts = np.concatenate(((0,),durs.cumsum()))
            for state,start,dur in zip(stateseq_norep,starts,durs):
                segment = data[start:start+dur]
                # Length from the slice, so a run past the end of ``data`` is
                # clipped instead of mismatching x and y.
                plt.plot(
                        np.arange(start,start+segment.shape[0]),
                        segment,
                        color=cmap(colors[state]))
            plt.xlim(0,s.T-1)
Beispiel #10
0
    def resample_states(self,**kwargs):
        """Draw new state sequences for every attached sequence in one batch.

        The inputs for the compiled ``messages.resample_arhmm`` sampler are,
        in order:

        * initial distributions and transition matrices;
        * stacked observation parameters and normalizers from
          ``_param_matrix``;
        * un-strided data;
        * int32 output buffers, which become the new state sequences;
        * ``s.T`` uniform draws per sequence;
        * ``self.alphans``.

        Floating inputs are cast to ``self.dtype``. After the sampler runs,
        ``stateseq`` and ``_normalizer`` are stored on each states object, and
        ``_obs_stats`` and ``_transcounts`` on ``self``. With no sequences,
        ``_obs_stats`` becomes ``None`` and ``_transcounts`` an empty list.

        ``**kwargs`` is unused.
        """
        from messages import resample_arhmm
        if len(self.states_list) == 0:
            self._obs_stats, self._transcounts = None, []
            return

        states = self.states_list
        buffers = [np.empty(st.T,dtype='int32') for st in states]
        matrices = [self._param_matrix(distn) for distn in self.obs_distns]
        params, normalizers = (np.array(group) for group in zip(*matrices))

        sampler_args = (
            [st.pi_0.astype(self.dtype) for st in states],
            [st.trans_matrix.astype(self.dtype) for st in states],
            params.astype(self.dtype),
            normalizers.astype(self.dtype),
            [undo_AR_striding(st.data,self.nlags) for st in states],
            buffers,
            [np.random.uniform(size=st.T).astype(self.dtype) for st in states],
            self.alphans,
        )
        stats, transcounts, loglikes = resample_arhmm(*sampler_args)

        for st, seq, ll in zip(states, buffers, loglikes):
            st.stateseq, st._normalizer = seq, ll

        self._obs_stats, self._transcounts = stats, transcounts
Beispiel #11
0
 def datas(self):
     """Return a list with the un-strided data of every states object, in order."""
     unstrided = []
     for states in self.states_list:
         unstrided.append(undo_AR_striding(states.data, self.nlags))
     return unstrided
Beispiel #12
0
 def _get_joblib_pair(self,s):
     """Return ``(un-strided data of s, s._kwargs)``.

     Presumably used to ship ``s`` to a joblib worker, per the name -- confirm.
     """
     raw = undo_AR_striding(s.data, self.nlags)
     return raw, s._kwargs
Beispiel #13
0
 def datas(self):
     """Un-stride each states object's ``data`` with ``self.nlags`` lags."""
     states_objs = self.states_list
     return [undo_AR_striding(obj.data, self.nlags) for obj in states_objs]
Beispiel #14
0
 def _get_joblib_pair(self, s):
     """Pair ``s``'s un-strided data with its ``_kwargs`` as a 2-tuple.

     Presumably the payload for a joblib worker, per the name -- confirm.
     """
     unstrided = undo_AR_striding(s.data, self.nlags)
     extra_kwargs = s._kwargs
     return (unstrided, extra_kwargs)