def resample_states(self, **kwargs):
    """Resample all state sequences with one ``resample_arhmm`` call.

    For every entry of ``self.states_list`` an ``int32`` buffer of length
    ``s.T`` is allocated, handed to ``messages.resample_arhmm`` together
    with the per-sequence initial distributions, transition matrices, data
    and uniform random draws, and afterwards stored as ``s.stateseq``
    (presumably filled in place by ``resample_arhmm`` -- confirm against
    the ``messages`` module).  The per-sequence log likelihoods returned by
    the call are stored as ``s._normalizer``.

    The observation parameters and normalizers are repeated along axis 0
    according to ``rs`` of the last states object, and the statistics
    returned by ``resample_arhmm`` are summed back over the groups delimited
    by the cumulative sums of that same ``rs``; the result is stored in
    ``self._obs_stats``.  When ``self.states_list`` is empty,
    ``self._obs_stats`` is set to ``None``.

    Args:
        **kwargs: Ignored; only present to absorb any multiprocessing
            arguments.

    Raises:
        AssertionError: If ``self.obs_distns[0].D_out`` is not greater
            than 1.
    """
    # NOTE: kwargs is just to absorb any multiprocessing stuff
    # TODO only use this when the number/size of sequences warrant it
    from messages import resample_arhmm

    assert self.obs_distns[0].D_out > 1

    if len(self.states_list) == 0:
        self._obs_stats = None
        return

    # The original code read ``s.rs`` through the variable leaked by a list
    # comprehension, which only exists on Python 2 (NameError on Python 3).
    # The leaked value was the last element, so bind it explicitly.
    last_states = self.states_list[-1]

    stateseqs = [np.empty(s.T, dtype='int32') for s in self.states_list]
    params, normalizers = map(
        np.array, zip(*[self._param_matrix(o) for o in self.obs_distns]))
    params = params.repeat(last_states.rs, axis=0)
    normalizers = normalizers.repeat(last_states.rs, axis=0)

    # The second return value is not used by this variant.
    stats, _, loglikes = resample_arhmm(
        [s.hmm_pi_0.astype(self.dtype) for s in self.states_list],
        [s.hmm_trans_matrix.astype(self.dtype) for s in self.states_list],
        params.astype(self.dtype),
        normalizers.astype(self.dtype),
        [undo_AR_striding(s.data, self.nlags) for s in self.states_list],
        stateseqs,
        [np.random.uniform(size=s.T).astype(self.dtype)
         for s in self.states_list],
        self.alphans)

    for s, stateseq, loglike in zip(self.states_list, stateseqs, loglikes):
        s.stateseq = stateseq
        s._map_states()
        s._normalizer = loglike

    # ``rs`` is re-read here (after ``_map_states``), exactly where the
    # original read it.
    starts = cumsum(last_states.rs, strict=True)
    ends = cumsum(last_states.rs, strict=False)

    # Build a real list: on Python 3 ``map`` returns an iterator, which
    # cannot be sliced below.
    stats = [np.array(stat) for stat in stats]
    self._obs_stats = [
        sum(stats[start:end]) for start, end in zip(starts, ends)]
def resample_states(self, **kwargs):
    """Resample every state sequence through ``messages.resample_arhmm``.

    Allocates one ``int32`` buffer of length ``s.T`` per states object,
    passes the buffers to ``resample_arhmm`` along with each sequence's
    ``pi_0``, ``trans_matrix``, data and uniform random draws, then stores
    each buffer as ``s.stateseq`` and each returned log likelihood as
    ``s._normalizer``.  The returned statistics and transition counts are
    kept in ``self._obs_stats`` and ``self._transcounts``.

    With an empty ``self.states_list`` nothing is sampled:
    ``self._obs_stats`` becomes ``None`` and ``self._transcounts`` an empty
    list.

    Args:
        **kwargs: Accepted and ignored.
    """
    from messages import resample_arhmm

    # Guard clause: nothing to resample.
    if len(self.states_list) == 0:
        self._obs_stats = None
        self._transcounts = []
        return

    # Output buffers, one per sequence.
    stateseqs = [np.empty(s.T, dtype='int32') for s in self.states_list]

    # One (param, normalizer) pair per observation distribution, stacked.
    pairs = [self._param_matrix(distn) for distn in self.obs_distns]
    params, normalizers = map(np.array, zip(*pairs))

    # Arguments are built in the same order the call consumes them.
    initial_dists = [s.pi_0.astype(self.dtype) for s in self.states_list]
    trans_matrices = [
        s.trans_matrix.astype(self.dtype) for s in self.states_list]
    params = params.astype(self.dtype)
    normalizers = normalizers.astype(self.dtype)
    datas = [undo_AR_striding(s.data, self.nlags) for s in self.states_list]
    randseqs = [
        np.random.uniform(size=s.T).astype(self.dtype)
        for s in self.states_list]

    stats, transcounts, loglikes = resample_arhmm(
        initial_dists, trans_matrices, params, normalizers,
        datas, stateseqs, randseqs, self.alphans)

    for index, s in enumerate(self.states_list):
        s.stateseq = stateseqs[index]
        s._normalizer = loglikes[index]

    self._obs_stats = stats
    self._transcounts = transcounts
def resample_states(self, **kwargs):
    """Draw new state sequences for all sequences in one batched call.

    Each states object in ``self.states_list`` gets a fresh ``int32`` array
    of length ``T`` that is passed to ``messages.resample_arhmm`` and then
    assigned to its ``stateseq``; the matching log likelihood from the call
    is assigned to its ``_normalizer``.  The statistics and transition
    counts returned by the call are saved on ``self`` as ``_obs_stats`` and
    ``_transcounts``.

    If ``self.states_list`` is empty, ``_obs_stats`` is set to ``None`` and
    ``_transcounts`` to ``[]``.

    Args:
        **kwargs: Unused.
    """
    from messages import resample_arhmm

    if not len(self.states_list) > 0:
        self._obs_stats, self._transcounts = None, []
        return

    def cast(array):
        # Convert to the dtype used for the resample_arhmm call.
        return array.astype(self.dtype)

    sampled = [np.empty(st.T, dtype='int32') for st in self.states_list]

    param_rows = [self._param_matrix(obs) for obs in self.obs_distns]
    params, normalizers = map(np.array, zip(*param_rows))

    call_args = (
        [cast(st.pi_0) for st in self.states_list],
        [cast(st.trans_matrix) for st in self.states_list],
        cast(params),
        cast(normalizers),
        [undo_AR_striding(st.data, self.nlags) for st in self.states_list],
        sampled,
        [cast(np.random.uniform(size=st.T)) for st in self.states_list],
        self.alphans,
    )
    stats, transcounts, loglikes = resample_arhmm(*call_args)

    for st, seq, loglike in zip(self.states_list, sampled, loglikes):
        st.stateseq = seq
        st._normalizer = loglike

    self._obs_stats = stats
    self._transcounts = transcounts