def __init__(self, model, T=None, data=None, stateseq=None,
             generate=True, initialize_from_prior=False,
             initialize_to_noise=True):
    """Set up the state-sequence object and optionally initialize the states.

    Parameters
    ----------
    model : object
        Parent model; `self.p` and `self.n` are read from it indirectly.
    T : int, optional
        Number of time steps. When falsy (None or 0) it is taken from
        ``data.shape[0]``.
    data : array, optional
        Observations; only ``data.shape[0]`` is used here.
    stateseq : array, optional
        Pre-computed state sequence. When given it is stored as-is and no
        initialization is performed.
    generate : bool
        When True (and `stateseq` is None) a state sequence is created.
    initialize_from_prior : bool
        Initialize via ``self.generate_states()``.
    initialize_to_noise : bool
        Initialize with i.i.d. standard-normal noise of shape (T, n).

    Raises
    ------
    ValueError
        If neither `T` nor `data` is supplied.
    """
    if not T:
        # Fall back to the data length, and fail with a clear message instead
        # of an AttributeError on None when there is nothing to fall back to.
        if data is None:
            raise ValueError('either T or data must be provided')
        T = data.shape[0]

    self.model = model
    self.data = data
    self.T = T

    self._obs_scheme = ObservationScheme(p=self.p, T=self.T)
    self._normalizer = None

    if stateseq is not None:
        self.stateseq = stateseq
    elif generate:
        # Resampling from the data only happens when both initialization
        # flags are switched off; otherwise the prior takes precedence over
        # pure noise.
        if data is not None and not (initialize_from_prior
                                     or initialize_to_noise):
            self.resample()
        elif initialize_from_prior:
            self.generate_states()
        else:
            self.stateseq = np.random.normal(size=(self.T, self.n))
def __init__(self, model, T=None, data=None, stateseq=None,
             generate=True, initialize_from_prior=False,
             initialize_to_noise=True):
    """Set up the state-sequence object and optionally initialize the states.

    Parameters
    ----------
    model : object
        Parent model; `self.p` and `self.n` are read from it indirectly.
    T : int, optional
        Number of time steps. When falsy (None or 0) it is taken from
        ``data.shape[0]``.
    data : array, optional
        Observations; only ``data.shape[0]`` is used here.
    stateseq : array, optional
        Pre-computed state sequence. When given it is stored as-is and no
        initialization is performed.
    generate : bool
        When True (and `stateseq` is None) a state sequence is created.
    initialize_from_prior : bool
        Initialize via ``self.generate_states()``.
    initialize_to_noise : bool
        Initialize with i.i.d. standard-normal noise of shape (T, n).

    Raises
    ------
    ValueError
        If neither `T` nor `data` is supplied.
    """
    if not T:
        # Fall back to the data length, and fail with a clear message instead
        # of an AttributeError on None when there is nothing to fall back to.
        if data is None:
            raise ValueError('either T or data must be provided')
        T = data.shape[0]

    self.model = model
    self.data = data
    self.T = T

    self._obs_scheme = ObservationScheme(p=self.p, T=self.T)
    self._normalizer = None

    if stateseq is not None:
        self.stateseq = stateseq
    elif generate:
        # Resampling from the data only happens when both initialization
        # flags are switched off; otherwise the prior takes precedence over
        # pure noise.
        if data is not None and not (initialize_from_prior
                                     or initialize_to_noise):
            self.resample()
        elif initialize_from_prior:
            self.generate_states()
        else:
            self.stateseq = np.random.normal(size=(self.T, self.n))
class LDSStates(object):
    """Latent state sequence of a linear dynamical system for one data series.

    Model parameters (``A``, ``C``, ``d``, the covariances, ...) are never
    stored here; they are read through to ``self.model`` by the properties at
    the bottom of the class. The instance itself holds the data, the current
    state sequence / smoothed moments, the cached normalizer returned by the
    message-passing routines, and the expected statistics written by
    `_set_expected_stats`.
    """

    def __init__(self, model, T=None, data=None, stateseq=None,
                 generate=True, initialize_from_prior=False,
                 initialize_to_noise=True):
        """Store model/data and optionally initialize the state sequence.

        Parameters
        ----------
        model : object
            Parent model exposing the attributes read by the properties below.
        T : int, optional
            Number of time steps. When falsy (None or 0) it is taken from
            ``data.shape[0]``.
        data : array, optional
            Observations, indexed as (time, observed dimension) elsewhere in
            this class.
        stateseq : array, optional
            Pre-computed state sequence; stored as-is when given.
        generate : bool
            When True (and `stateseq` is None) a state sequence is created.
        initialize_from_prior : bool
            Initialize via `generate_states`.
        initialize_to_noise : bool
            Initialize with i.i.d. standard-normal noise of shape (T, n).

        Raises
        ------
        ValueError
            If neither `T` nor `data` is supplied.
        """
        if not T:
            # Fall back to the data length, and fail with a clear message
            # instead of an AttributeError on None.
            if data is None:
                raise ValueError('either T or data must be provided')
            T = data.shape[0]

        self.model = model
        self.data = data
        self.T = T

        self._obs_scheme = ObservationScheme(p=self.p, T=self.T)
        self._normalizer = None

        if stateseq is not None:
            self.stateseq = stateseq
        elif generate:
            # Resampling from the data only happens when both initialization
            # flags are off; otherwise the prior wins over pure noise.
            if data is not None and not (initialize_from_prior
                                         or initialize_to_noise):
                self.resample()
            elif initialize_from_prior:
                self.generate_states()
            else:
                self.stateseq = np.random.normal(size=(self.T, self.n))

    ### small internal helpers

    @staticmethod
    def _segment_times(obs_time, i):
        """Return the time indices of the i-th observation segment.

        Segment 0 covers ``[0, obs_time[0])``; segment i > 0 covers
        ``[obs_time[i-1], obs_time[i])``.
        """
        if i == 0:
            return range(obs_time[0])
        return range(obs_time[i - 1], obs_time[i])

    @staticmethod
    def _pack_stats(*stats):
        """Bundle differently-shaped statistics into a 1-D object array.

        ``np.array([...])`` on ragged input raises on recent NumPy releases,
        so the object array is allocated explicitly and filled item by item.
        """
        packed = np.empty(len(stats), dtype=object)
        for k, stat in enumerate(stats):
            packed[k] = stat
        return packed

    ### basics

    def log_likelihood(self):
        """Return the (cached) normalizer from the matching Kalman filter."""
        return self._ll_diag() if self.diag_sigma_obs else self._ll()

    def _ll(self):
        """Normalizer from the dense-covariance filter; computed once, cached."""
        if self._normalizer is None:
            # Dense path uses `kalman_filter`, consistent with `_filter`.
            self._normalizer, _, _ = kalman_filter(
                self.mu_init, self.sigma_init,
                self.A, self.sigma_states, self.C, self.d, self.sigma_obs,
                self.data)
        return self._normalizer

    def _ll_diag(self):
        """Normalizer from the diagonal-covariance filter; computed once, cached."""
        if self._normalizer is None:
            # Diagonal path uses `kalman_filter_diagonal`, consistent with
            # `_filter_diag`.
            # NOTE(review): `_E_step_diag` passes `dsigma_obs` to its diagonal
            # routine while the filter/resample paths pass `sigma_obs` --
            # confirm which one `kalman_filter_diagonal` expects.
            self._normalizer, _, _ = kalman_filter_diagonal(
                self.mu_init, self.sigma_init,
                self.A, self.sigma_states, self.C, self.d, self.sigma_obs,
                self.data)
        return self._normalizer

    ### generation

    def generate_states(self):
        """Sample a state sequence from the dynamics and store it.

        The first state is drawn from N(mu_init, sigma_init); each following
        state is ``A.dot(previous) + noise`` with noise covariance
        `sigma_states`. The (T, n) array is stored in `self.stateseq` and
        returned.
        """
        T, n = self.T, self.n
        stateseq = self.stateseq = np.empty((T, n), dtype='double')
        stateseq[0] = np.random.multivariate_normal(
            self.mu_init, self.sigma_init)

        # Colour white noise with the Cholesky factor of the state covariance.
        chol = np.linalg.cholesky(self.sigma_states)
        randseq = np.random.randn(T - 1, n).dot(chol.T)

        # `range` (not `xrange`) so the loop runs on Python 2 and 3 alike.
        for t in range(1, T):
            stateseq[t] = self.A.dot(stateseq[t - 1]) + randseq[t - 1]

        return stateseq

    def sample_predictions(self, Tpred, states_noise, obs_noise):
        """Sample `Tpred` future observations following the stored data.

        Parameters
        ----------
        Tpred : int
            Number of time steps to predict.
        states_noise : bool
            Add dynamics noise (covariance `sigma_states`) to the transitions.
        obs_noise : bool
            Add observation noise (covariance `sigma_obs`) to the output.

        Returns
        -------
        array of shape (Tpred, p)
        """
        _, filtered_mus, filtered_sigmas = kalman_filter(
            self.mu_init, self.sigma_init,
            self.A, self.sigma_states, self.C, self.d, self.sigma_obs,
            self.data)

        # One-step-ahead predictive distribution from the last filtered state.
        init_mu = self.A.dot(filtered_mus[-1])
        init_sigma = self.sigma_states + self.A.dot(
            filtered_sigmas[-1]).dot(self.A.T)

        randseq = np.zeros((Tpred - 1, self.n))
        if states_noise:
            L = np.linalg.cholesky(self.sigma_states)
            randseq += np.random.randn(Tpred - 1, self.n).dot(L.T)

        states = np.empty((Tpred, self.n))
        states[0] = np.random.multivariate_normal(init_mu, init_sigma)
        for t in range(1, Tpred):
            states[t] = self.A.dot(states[t - 1]) + randseq[t - 1]

        # NOTE(review): `d` is handed to the filter above but is not added to
        # the predicted observations here -- confirm whether that is intended.
        obs = states.dot(self.C.T)
        if obs_noise:
            L = np.linalg.cholesky(self.sigma_obs)
            obs += np.random.randn(Tpred, self.p).dot(L.T)

        return obs

    ### filtering

    def filter(self):
        """Run the Kalman filter variant matching `diag_sigma_obs`."""
        if self.diag_sigma_obs:
            self._filter_diag()
        else:
            self._filter()

    def _filter(self):
        """Dense filter; stores normalizer, filtered means and covariances."""
        self._normalizer, self.filtered_mus, self.filtered_sigmas = \
            kalman_filter(
                self.mu_init, self.sigma_init,
                self.A, self.sigma_states, self.C, self.d, self.sigma_obs,
                self.data)

    def _filter_diag(self):
        """Diagonal filter; stores normalizer, filtered means and covariances."""
        self._normalizer, self.filtered_mus, self.filtered_sigmas = \
            kalman_filter_diagonal(
                self.mu_init, self.sigma_init,
                self.A, self.sigma_states, self.C, self.d, self.sigma_obs,
                self.data)

    ### resampling

    def resample(self):
        """Resample the state sequence with the variant matching `diag_sigma_obs`."""
        if self.diag_sigma_obs:
            self._resample_diag()
        else:
            self._resample()

    def _resample(self):
        """Dense filter-and-sample; stores normalizer and `stateseq`."""
        self._normalizer, self.stateseq = filter_and_sample(
            self.mu_init, self.sigma_init,
            self.A, self.sigma_states, self.C, self.d, self.sigma_obs,
            self.data)

    def _resample_diag(self):
        """Diagonal filter-and-sample; stores normalizer and `stateseq`."""
        self._normalizer, self.stateseq = filter_and_sample_diagonal(
            self.mu_init, self.sigma_init,
            self.A, self.sigma_states, self.C, self.d, self.sigma_obs,
            self.data)

    ### EM

    def E_step(self):
        """Run the E-step and refresh the expected statistics.

        With `diag_sigma_obs` set, the segment-wise ("stitched") smoother is
        used; otherwise the dense one. `_E_step_diag` is not called from here.
        """
        if self.diag_sigma_obs:
            E_xtp1_xtT = self._E_step_stitch()
        else:
            E_xtp1_xtT = self._E_step()

        self._set_expected_stats(
            self.smoothed_mus, self.smoothed_sigmas, E_xtp1_xtT)

    def _E_step(self):
        """Dense E-step; stores smoothed moments, returns the cross moments."""
        self._normalizer, self.smoothed_mus, self.smoothed_sigmas, \
            E_xtp1_xtT = E_step(
                self.mu_init, self.sigma_init,
                self.A, self.sigma_states, self.C, self.d, self.sigma_obs,
                self.data)
        return E_xtp1_xtT

    def _E_step_diag(self):
        """Diagonal E-step; stores smoothed moments, returns the cross moments."""
        self._normalizer, self.smoothed_mus, self.smoothed_sigmas, \
            E_xtp1_xtT = E_step_diagonal(
                self.mu_init, self.sigma_init,
                self.A, self.sigma_states, self.C, self.d, self.dsigma_obs,
                self.data)
        return E_xtp1_xtT

    def _E_step_stitch(self):
        """E-step over consecutive observation segments.

        A forward pass is run per segment using only the observed dimensions
        of that segment; the state mean/covariance returned by one segment
        seeds the next. A single backward pass over the full sequence then
        yields the smoothed moments. The per-segment normalizers are summed
        into `self._normalizer`.

        Returns
        -------
        E_xtp1_xtT : as returned by `E_step_backward`.
        """
        scheme = self.obs_scheme
        sub_pops = scheme.sub_pops
        obs_time = scheme.obs_time
        obs_pops = scheme.obs_pops

        T, n = self.T, self.n
        smoothed_mus = np.zeros((T, n))
        smoothed_sigmas = np.zeros((T, n, n))
        mu_predicts = np.zeros((T, n))
        sigma_predicts = np.zeros((T, n, n))

        self._normalizer = 0
        mu_init = self.mu_init
        sigma_init = self.sigma_init

        for i in range(scheme.num_obstime):
            ts = self._segment_times(obs_time, i)
            idx = sub_pops[obs_pops[i]]

            # Copies hand freshly-allocated arrays to the forward routine
            # rather than fancy-indexed temporaries of the model parameters.
            normalizer, mu_predicts[ts, :], sigma_predicts[ts, :, :], \
                smoothed_mus[ts, :], smoothed_sigmas[ts, :, :], \
                mu_init, sigma_init = E_step_forward(
                    mu_init, sigma_init, self.A, self.sigma_states,
                    self.C[idx, :].copy(), self.d[idx].copy(),
                    self.dsigma_obs[idx].copy(),
                    self.data[np.ix_(ts, idx)].copy())

            # Debug aid: expose and dump the intermediate results on NaNs.
            if np.any(np.isnan(smoothed_mus)):
                self.smoothed_mus = smoothed_mus
                self.mu_predicts = mu_predicts
                print(smoothed_mus)

            self._normalizer += normalizer

        self.smoothed_mus, self.smoothed_sigmas, E_xtp1_xtT = \
            E_step_backward(self.A, self.sigma_states,
                            mu_predicts, sigma_predicts,
                            smoothed_mus, smoothed_sigmas)

        return E_xtp1_xtT

    def _set_expected_stats(self, smoothed_mus, smoothed_sigmas, E_xtp1_xtT):
        """Compute and store the expected statistics from smoothed moments.

        Sets `E_emission_stats`, `E_dynamics_stats`, `E_initial_stats` (each
        a 1-D object array) and `E_addition_stats` (a copy of the per-step
        cross moments, kept for debugging).
        """
        assert not np.isnan(E_xtp1_xtT).any()
        assert not np.isnan(smoothed_mus).any()
        assert not np.isnan(smoothed_sigmas).any()

        T, EyyT, EyxT = self._set_expected_stats_data(smoothed_mus)
        ExxT, E_xt_xtT, E_xtp1_xtp1T, ExxTe = \
            self._set_expected_stats_latents(smoothed_mus, smoothed_sigmas)

        # MN: make copy for debugging purposes
        self.E_addition_stats = E_xtp1_xtT.copy()
        E_xtp1_xtT = E_xtp1_xtT.sum(0)

        def is_symmetric(A):
            return np.allclose(A, A.T)

        assert is_symmetric(ExxT)
        assert is_symmetric(E_xt_xtT)
        assert is_symmetric(E_xtp1_xtp1T)

        Ex0, ExxT0 = self._set_expected_stats_initial(
            smoothed_mus, smoothed_sigmas)
        assert is_symmetric(ExxT0)

        # The entries have different shapes, so they are packed explicitly
        # into object arrays (see `_pack_stats`).
        self.E_emission_stats = self._pack_stats(EyyT, EyxT, ExxTe, T)
        self.E_dynamics_stats = self._pack_stats(
            E_xtp1_xtp1T, E_xtp1_xtT, E_xt_xtT, self.T - 1)
        self.E_initial_stats = self._pack_stats(Ex0, ExxT0, 1)

    def _set_expected_stats_data(self, smoothed_mus):
        """Data-dependent statistics.

        Returns
        -------
        T : int, or array with one entry per index group when stitching
        EyyT : second moment of the data (per-dimension sums of squares when
            stitching or when `diag_sigma_obs` is set, else ``data.T.dot(data)``)
        EyxT : data/state cross moment; with `diag_sigma_obs` the column of
            data sums is appended.
        """
        scheme = self.obs_scheme
        sub_pops = scheme.sub_pops
        obs_pops = scheme.obs_pops
        obs_time = scheme.obs_time
        idx_grp = scheme.idx_grp
        obs_idx = scheme.obs_idx
        data = self.data

        if obs_time.size > 1:
            T = np.zeros(len(idx_grp))
            Ey = np.zeros(self.p)
            EyyT = np.zeros(self.p)
            EyxT = np.zeros((self.p, self.n))

            for i in range(obs_time.size):
                # Segment length, credited to every index group seen in it.
                if i == 0:
                    T[obs_idx[i]] += obs_time[0]
                else:
                    T[obs_idx[i]] += obs_time[i] - obs_time[i - 1]

                idx = sub_pops[obs_pops[i]]
                if idx.size > 0:
                    ts = self._segment_times(obs_time, i)
                    ytmp = data[np.ix_(ts, idx)]
                    Ey[idx] += np.sum(ytmp, 0)
                    EyyT[idx] += np.sum(ytmp * ytmp, 0)
                    EyxT[idx, :] += np.einsum(
                        'ni,nj->ij', ytmp, smoothed_mus[ts, :])
        else:
            T = self.T
            Ey = data.sum(0)
            if self.diag_sigma_obs:
                EyyT = np.sum(data * data, 0)
            else:
                EyyT = data.T.dot(data)
            EyxT = data.T.dot(smoothed_mus)

        if self.diag_sigma_obs:
            EyxT = np.hstack((EyxT, np.atleast_2d(Ey).T))

        return T, EyyT, EyxT

    def _set_expected_stats_latents(self, smoothed_mus, smoothed_sigmas):
        """Latent-state statistics.

        Returns
        -------
        ExxT : summed second moment of the states (augmented with the state
            sums and T when the emission distribution is affine)
        E_xt_xtT : second moment summed over all but the last time step
        E_xtp1_xtp1T : second moment summed over all but the first time step
        ExxTe : per-index-group augmented second moments when stitching,
            otherwise the same object as `ExxT`.
        """
        scheme = self.obs_scheme
        obs_time = scheme.obs_time
        idx_grp = scheme.idx_grp
        obs_idx = scheme.obs_idx
        stitched = obs_time.size > 1

        if stitched:
            n, n_grp = self.n, len(idx_grp)
            T = np.zeros(n_grp)
            Ex = np.zeros(n)
            ExxT = np.zeros((n, n))
            Exe = np.zeros((n, n_grp))
            ExxTj = np.zeros((n, n, n_grp))
            ExxTe = np.zeros((n + 1, n + 1, n_grp))

            for i in range(obs_time.size):
                ts = self._segment_times(obs_time, i)
                x = smoothed_mus[ts, :]
                sx = np.sum(x, 0)
                sxxT = smoothed_sigmas[ts, :, :].sum(0) + x.T.dot(x)

                for j in obs_idx[i]:
                    Exe[:, j] += sx
                    ExxTj[:, :, j] += sxxT
                    T[j] += len(ts)

                Ex += sx
                ExxT += sxxT

            for j in range(n_grp):
                row = np.atleast_2d(Exe[:, j])
                ExxTe[:, :, j] = blockarray(
                    [[ExxTj[:, :, j], row.T],
                     [row, np.atleast_2d(T[j])]])
        else:
            Ex = smoothed_mus.sum(0)
            ExxT = smoothed_sigmas.sum(0) + smoothed_mus.T.dot(smoothed_mus)

        # Drop the last / first time step from the full sum, respectively.
        E_xt_xtT = ExxT - (
            smoothed_sigmas[-1]
            + np.outer(smoothed_mus[-1], smoothed_mus[-1]))
        E_xtp1_xtp1T = ExxT - (
            smoothed_sigmas[0]
            + np.outer(smoothed_mus[0], smoothed_mus[0]))

        if self.model.emission_distn.affine:
            ExxT = blockarray([[ExxT, np.atleast_2d(Ex).T],
                               [np.atleast_2d(Ex), np.atleast_2d(self.T)]])

        if not stitched:
            ExxTe = ExxT

        return ExxT, E_xt_xtT, E_xtp1_xtp1T, ExxTe

    def _set_expected_stats_initial(self, smoothed_mus, smoothed_sigmas):
        """Return the first smoothed mean and its second moment."""
        mu0 = smoothed_mus[0]
        return mu0, smoothed_sigmas[0] + np.outer(mu0, mu0)

    # next two methods are for testing

    def info_E_step(self):
        """E-step in information form, built from the standard parameters."""
        data = self.data
        A, sigma_states, C, sigma_obs = \
            self.A, self.sigma_states, self.C, self.sigma_obs

        J_init = np.linalg.inv(self.sigma_init)
        h_init = np.linalg.solve(self.sigma_init, self.mu_init)

        J_pair_11 = A.T.dot(np.linalg.solve(sigma_states, A))
        J_pair_21 = -np.linalg.solve(sigma_states, A)
        J_pair_22 = np.linalg.inv(sigma_states)

        J_node = C.T.dot(np.linalg.solve(sigma_obs, C))
        h_node = np.einsum('ik,ij,tj->tk', C, np.linalg.inv(sigma_obs), data)

        self._normalizer, self.smoothed_mus, self.smoothed_sigmas, \
            E_xtp1_xtT = info_E_step(
                J_init, h_init, J_pair_11, J_pair_21, J_pair_22,
                J_node, h_node)
        self._normalizer += self._extra_loglike_terms(
            self.A, self.sigma_states, self.C, self.sigma_obs,
            self.mu_init, self.sigma_init, self.data)

        self._set_expected_stats(
            self.smoothed_mus, self.smoothed_sigmas, E_xtp1_xtT)

    @staticmethod
    def _extra_loglike_terms(A, BBT, C, DDT, mu_init, sigma_init, data):
        """Constant terms added to the information-form normalizer.

        `BBT` and `DDT` are used as the state and observation covariances;
        `A` is accepted but not used. `C` only supplies the dimensions (p, n).
        """
        p, n = C.shape
        T = data.shape[0]
        log2pi = np.log(2 * np.pi)

        out = 0.

        # initial-state term
        out -= 1. / 2 * mu_init.dot(np.linalg.solve(sigma_init, mu_init))
        out -= 1. / 2 * np.linalg.slogdet(sigma_init)[1]
        out -= n / 2. * log2pi

        # dynamics term
        out -= (T - 1) / 2. * np.linalg.slogdet(BBT)[1]
        out -= (T - 1) * n / 2. * log2pi

        # observation term
        out -= 1. / 2 * np.einsum(
            'ij,ti,tj->', np.linalg.inv(DDT), data, data)
        out -= T / 2. * np.linalg.slogdet(DDT)[1]
        # Float literal keeps this a true division on Python 2 as well.
        out -= T * p / 2. * log2pi

        return out

    ### mean field

    def meanfieldupdate(self):
        """Mean-field update of the states in information form.

        Expected natural parameters are obtained from the dynamics and
        emission distributions; the normalizer, smoothed moments and expected
        statistics are stored on the instance.
        """
        J_init = np.linalg.inv(self.sigma_init)
        h_init = np.linalg.solve(self.sigma_init, self.mu_init)

        def get_params(distn):
            return mniw_expectedstats(
                *distn._natural_to_standard(distn.mf_natural_hypparam))

        J_pair_22, J_pair_21, J_pair_11, logdet_pair = \
            get_params(self.dynamics_distn)
        J_yy, J_yx, J_node, logdet_node = get_params(self.emission_distn)
        h_node = self.data.dot(J_yx)

        self._normalizer, self.smoothed_mus, self.smoothed_sigmas, \
            E_xtp1_xtT = info_E_step(
                J_init, h_init, J_pair_11, -J_pair_21, J_pair_22,
                J_node, h_node)
        self._normalizer += self._info_extra_loglike_terms(
            J_init, h_init, logdet_pair, J_yy, logdet_node, self.data)

        self._set_expected_stats(
            self.smoothed_mus, self.smoothed_sigmas, E_xtp1_xtT)

    def get_vlb(self):
        """Return the stored normalizer, running a mean-field update if unset.

        `__init__` always creates `_normalizer` (as None), so the check is on
        the value rather than on the attribute's existence.
        """
        if getattr(self, '_normalizer', None) is None:
            self.meanfieldupdate()  # NOTE: sets self._normalizer
        return self._normalizer

    @staticmethod
    def _info_extra_loglike_terms(
            J_init, h_init, logdet_pair, J_yy, logdet_node, data):
        """Constant terms added to the mean-field normalizer.

        `J_yy` may be 2-D (shared across time) or 3-D (one matrix per time
        step); `logdet_pair` / `logdet_node` may be scalars or per-step arrays.
        """
        # The observation dimension comes from the data: for a 3-D `J_yy`
        # its leading axis is time, not the observation dimension.
        T, p = data.shape[0], data.shape[1]
        n = h_init.shape[0]
        log2pi = np.log(2 * np.pi)

        out = 0.

        # initial-state term
        out -= 1. / 2 * h_init.dot(np.linalg.solve(J_init, h_init))
        out += 1. / 2 * np.linalg.slogdet(J_init)[1]
        out -= n / 2. * log2pi

        # dynamics term
        if isinstance(logdet_pair, np.ndarray):
            out += 1. / 2 * logdet_pair.sum()
        else:
            out += (T - 1) / 2. * logdet_pair
        out -= (T - 1) * n / 2. * log2pi

        # observation term
        contract = 'ij,ti,tj->' if J_yy.ndim == 2 else 'tij,ti,tj->'
        out -= 1. / 2 * np.einsum(contract, J_yy, data, data)
        if isinstance(logdet_node, np.ndarray):
            out += 1. / 2 * logdet_node.sum()
        else:
            out += T / 2. * logdet_node
        out -= T * p / 2. * log2pi

        return out

    # model properties

    @property
    def emission_distn(self):
        return self.model.emission_distn

    @property
    def dynamics_distn(self):
        return self.model.dynamics_distn

    @property
    def mu_init(self):
        return self.model.mu_init

    @property
    def sigma_init(self):
        return self.model.sigma_init

    @property
    def n(self):
        return self.model.n

    @property
    def p(self):
        return self.model.p

    @property
    def A(self):
        return self.model.A

    @property
    def sigma_states(self):
        return self.model.sigma_states

    @property
    def C(self):
        return self.model.C

    @property
    def d(self):
        return self.model.d

    @property
    def sigma_obs(self):
        return self.model.sigma_obs

    @property
    def diag_sigma_obs(self):
        return self.model.diag_sigma_obs

    @property
    def dsigma_obs(self):
        return self.model.dsigma_obs

    @property
    def strided_stateseq(self):
        return AR_striding(self.stateseq, 1)

    @property
    def obs_scheme(self):
        return self._obs_scheme

    @obs_scheme.setter
    def obs_scheme(self, obs_scheme):
        """Install a new observation scheme after validating it.

        The scheme is checked before it is stored, so a rejected scheme does
        not replace the current one.

        Raises
        ------
        TypeError
            If ``obs_scheme.check_obs_scheme()`` raises.
        """
        try:
            obs_scheme.check_obs_scheme()
        except Exception:
            raise TypeError('observation scheme does not meet requirements')
        self._obs_scheme = obs_scheme
class LDSStates(object):
    """Latent state sequence of a linear dynamical system for one data series.

    Model parameters (``A``, ``C``, ``d``, the covariances, ...) are never
    stored here; they are read through to ``self.model`` by the properties at
    the bottom of the class. The instance itself holds the data, the current
    state sequence / smoothed moments, the cached normalizer returned by the
    message-passing routines, and the expected statistics written by
    `_set_expected_stats`.
    """

    def __init__(self, model, T=None, data=None, stateseq=None,
                 generate=True, initialize_from_prior=False,
                 initialize_to_noise=True):
        """Store model/data and optionally initialize the state sequence.

        Parameters
        ----------
        model : object
            Parent model exposing the attributes read by the properties below.
        T : int, optional
            Number of time steps. When falsy (None or 0) it is taken from
            ``data.shape[0]``.
        data : array, optional
            Observations, indexed as (time, observed dimension) elsewhere in
            this class.
        stateseq : array, optional
            Pre-computed state sequence; stored as-is when given.
        generate : bool
            When True (and `stateseq` is None) a state sequence is created.
        initialize_from_prior : bool
            Initialize via `generate_states`.
        initialize_to_noise : bool
            Initialize with i.i.d. standard-normal noise of shape (T, n).

        Raises
        ------
        ValueError
            If neither `T` nor `data` is supplied.
        """
        if not T:
            # Fall back to the data length, and fail with a clear message
            # instead of an AttributeError on None.
            if data is None:
                raise ValueError('either T or data must be provided')
            T = data.shape[0]

        self.model = model
        self.data = data
        self.T = T

        self._obs_scheme = ObservationScheme(p=self.p, T=self.T)
        self._normalizer = None

        if stateseq is not None:
            self.stateseq = stateseq
        elif generate:
            # Resampling from the data only happens when both initialization
            # flags are off; otherwise the prior wins over pure noise.
            if data is not None and not (initialize_from_prior
                                         or initialize_to_noise):
                self.resample()
            elif initialize_from_prior:
                self.generate_states()
            else:
                self.stateseq = np.random.normal(size=(self.T, self.n))

    ### small internal helpers

    @staticmethod
    def _segment_times(obs_time, i):
        """Return the time indices of the i-th observation segment.

        Segment 0 covers ``[0, obs_time[0])``; segment i > 0 covers
        ``[obs_time[i-1], obs_time[i])``.
        """
        if i == 0:
            return range(obs_time[0])
        return range(obs_time[i - 1], obs_time[i])

    @staticmethod
    def _pack_stats(*stats):
        """Bundle differently-shaped statistics into a 1-D object array.

        ``np.array([...])`` on ragged input raises on recent NumPy releases,
        so the object array is allocated explicitly and filled item by item.
        """
        packed = np.empty(len(stats), dtype=object)
        for k, stat in enumerate(stats):
            packed[k] = stat
        return packed

    ### basics

    def log_likelihood(self):
        """Return the (cached) normalizer from the matching Kalman filter."""
        return self._ll_diag() if self.diag_sigma_obs else self._ll()

    def _ll(self):
        """Normalizer from the dense-covariance filter; computed once, cached."""
        if self._normalizer is None:
            # Dense path uses `kalman_filter`, consistent with `_filter`.
            self._normalizer, _, _ = kalman_filter(
                self.mu_init, self.sigma_init,
                self.A, self.sigma_states, self.C, self.d, self.sigma_obs,
                self.data)
        return self._normalizer

    def _ll_diag(self):
        """Normalizer from the diagonal-covariance filter; computed once, cached."""
        if self._normalizer is None:
            # Diagonal path uses `kalman_filter_diagonal`, consistent with
            # `_filter_diag`.
            # NOTE(review): `_E_step_diag` passes `dsigma_obs` to its diagonal
            # routine while the filter/resample paths pass `sigma_obs` --
            # confirm which one `kalman_filter_diagonal` expects.
            self._normalizer, _, _ = kalman_filter_diagonal(
                self.mu_init, self.sigma_init,
                self.A, self.sigma_states, self.C, self.d, self.sigma_obs,
                self.data)
        return self._normalizer

    ### generation

    def generate_states(self):
        """Sample a state sequence from the dynamics and store it.

        The first state is drawn from N(mu_init, sigma_init); each following
        state is ``A.dot(previous) + noise`` with noise covariance
        `sigma_states`. The (T, n) array is stored in `self.stateseq` and
        returned.
        """
        T, n = self.T, self.n
        stateseq = self.stateseq = np.empty((T, n), dtype='double')
        stateseq[0] = np.random.multivariate_normal(
            self.mu_init, self.sigma_init)

        # Colour white noise with the Cholesky factor of the state covariance.
        chol = np.linalg.cholesky(self.sigma_states)
        randseq = np.random.randn(T - 1, n).dot(chol.T)

        # `range` (not `xrange`) so the loop runs on Python 2 and 3 alike.
        for t in range(1, T):
            stateseq[t] = self.A.dot(stateseq[t - 1]) + randseq[t - 1]

        return stateseq

    def sample_predictions(self, Tpred, states_noise, obs_noise):
        """Sample `Tpred` future observations following the stored data.

        Parameters
        ----------
        Tpred : int
            Number of time steps to predict.
        states_noise : bool
            Add dynamics noise (covariance `sigma_states`) to the transitions.
        obs_noise : bool
            Add observation noise (covariance `sigma_obs`) to the output.

        Returns
        -------
        array of shape (Tpred, p)
        """
        _, filtered_mus, filtered_sigmas = kalman_filter(
            self.mu_init, self.sigma_init,
            self.A, self.sigma_states, self.C, self.d, self.sigma_obs,
            self.data)

        # One-step-ahead predictive distribution from the last filtered state.
        init_mu = self.A.dot(filtered_mus[-1])
        init_sigma = self.sigma_states + self.A.dot(
            filtered_sigmas[-1]).dot(self.A.T)

        randseq = np.zeros((Tpred - 1, self.n))
        if states_noise:
            L = np.linalg.cholesky(self.sigma_states)
            randseq += np.random.randn(Tpred - 1, self.n).dot(L.T)

        states = np.empty((Tpred, self.n))
        states[0] = np.random.multivariate_normal(init_mu, init_sigma)
        for t in range(1, Tpred):
            states[t] = self.A.dot(states[t - 1]) + randseq[t - 1]

        # NOTE(review): `d` is handed to the filter above but is not added to
        # the predicted observations here -- confirm whether that is intended.
        obs = states.dot(self.C.T)
        if obs_noise:
            L = np.linalg.cholesky(self.sigma_obs)
            obs += np.random.randn(Tpred, self.p).dot(L.T)

        return obs

    ### filtering

    def filter(self):
        """Run the Kalman filter variant matching `diag_sigma_obs`."""
        if self.diag_sigma_obs:
            self._filter_diag()
        else:
            self._filter()

    def _filter(self):
        """Dense filter; stores normalizer, filtered means and covariances."""
        self._normalizer, self.filtered_mus, self.filtered_sigmas = \
            kalman_filter(
                self.mu_init, self.sigma_init,
                self.A, self.sigma_states, self.C, self.d, self.sigma_obs,
                self.data)

    def _filter_diag(self):
        """Diagonal filter; stores normalizer, filtered means and covariances."""
        self._normalizer, self.filtered_mus, self.filtered_sigmas = \
            kalman_filter_diagonal(
                self.mu_init, self.sigma_init,
                self.A, self.sigma_states, self.C, self.d, self.sigma_obs,
                self.data)

    ### resampling

    def resample(self):
        """Resample the state sequence with the variant matching `diag_sigma_obs`."""
        if self.diag_sigma_obs:
            self._resample_diag()
        else:
            self._resample()

    def _resample(self):
        """Dense filter-and-sample; stores normalizer and `stateseq`."""
        self._normalizer, self.stateseq = filter_and_sample(
            self.mu_init, self.sigma_init,
            self.A, self.sigma_states, self.C, self.d, self.sigma_obs,
            self.data)

    def _resample_diag(self):
        """Diagonal filter-and-sample; stores normalizer and `stateseq`."""
        self._normalizer, self.stateseq = filter_and_sample_diagonal(
            self.mu_init, self.sigma_init,
            self.A, self.sigma_states, self.C, self.d, self.sigma_obs,
            self.data)

    ### EM

    def E_step(self):
        """Run the E-step and refresh the expected statistics.

        With `diag_sigma_obs` set, the segment-wise ("stitched") smoother is
        used; otherwise the dense one. `_E_step_diag` is not called from here.
        """
        if self.diag_sigma_obs:
            E_xtp1_xtT = self._E_step_stitch()
        else:
            E_xtp1_xtT = self._E_step()

        self._set_expected_stats(
            self.smoothed_mus, self.smoothed_sigmas, E_xtp1_xtT)

    def _E_step(self):
        """Dense E-step; stores smoothed moments, returns the cross moments."""
        self._normalizer, self.smoothed_mus, self.smoothed_sigmas, \
            E_xtp1_xtT = E_step(
                self.mu_init, self.sigma_init,
                self.A, self.sigma_states, self.C, self.d, self.sigma_obs,
                self.data)
        return E_xtp1_xtT

    def _E_step_diag(self):
        """Diagonal E-step; stores smoothed moments, returns the cross moments."""
        self._normalizer, self.smoothed_mus, self.smoothed_sigmas, \
            E_xtp1_xtT = E_step_diagonal(
                self.mu_init, self.sigma_init,
                self.A, self.sigma_states, self.C, self.d, self.dsigma_obs,
                self.data)
        return E_xtp1_xtT

    def _E_step_stitch(self):
        """E-step over consecutive observation segments.

        A forward pass is run per segment using only the observed dimensions
        of that segment; the state mean/covariance returned by one segment
        seeds the next. A single backward pass over the full sequence then
        yields the smoothed moments. The per-segment normalizers are summed
        into `self._normalizer`.

        Returns
        -------
        E_xtp1_xtT : as returned by `E_step_backward`.
        """
        scheme = self.obs_scheme
        sub_pops = scheme.sub_pops
        obs_time = scheme.obs_time
        obs_pops = scheme.obs_pops

        T, n = self.T, self.n
        smoothed_mus = np.zeros((T, n))
        smoothed_sigmas = np.zeros((T, n, n))
        mu_predicts = np.zeros((T, n))
        sigma_predicts = np.zeros((T, n, n))

        self._normalizer = 0
        mu_init = self.mu_init
        sigma_init = self.sigma_init

        for i in range(scheme.num_obstime):
            ts = self._segment_times(obs_time, i)
            idx = sub_pops[obs_pops[i]]

            # Copies hand freshly-allocated arrays to the forward routine
            # rather than fancy-indexed temporaries of the model parameters.
            normalizer, mu_predicts[ts, :], sigma_predicts[ts, :, :], \
                smoothed_mus[ts, :], smoothed_sigmas[ts, :, :], \
                mu_init, sigma_init = E_step_forward(
                    mu_init, sigma_init, self.A, self.sigma_states,
                    self.C[idx, :].copy(), self.d[idx].copy(),
                    self.dsigma_obs[idx].copy(),
                    self.data[np.ix_(ts, idx)].copy())

            # Debug aid: expose and dump the intermediate results on NaNs.
            if np.any(np.isnan(smoothed_mus)):
                self.smoothed_mus = smoothed_mus
                self.mu_predicts = mu_predicts
                print(smoothed_mus)

            self._normalizer += normalizer

        self.smoothed_mus, self.smoothed_sigmas, E_xtp1_xtT = \
            E_step_backward(self.A, self.sigma_states,
                            mu_predicts, sigma_predicts,
                            smoothed_mus, smoothed_sigmas)

        return E_xtp1_xtT

    def _set_expected_stats(self, smoothed_mus, smoothed_sigmas, E_xtp1_xtT):
        """Compute and store the expected statistics from smoothed moments.

        Sets `E_emission_stats`, `E_dynamics_stats`, `E_initial_stats` (each
        a 1-D object array) and `E_addition_stats` (a copy of the per-step
        cross moments, kept for debugging).
        """
        assert not np.isnan(E_xtp1_xtT).any()
        assert not np.isnan(smoothed_mus).any()
        assert not np.isnan(smoothed_sigmas).any()

        T, EyyT, EyxT = self._set_expected_stats_data(smoothed_mus)
        ExxT, E_xt_xtT, E_xtp1_xtp1T, ExxTe = \
            self._set_expected_stats_latents(smoothed_mus, smoothed_sigmas)

        # MN: make copy for debugging purposes
        self.E_addition_stats = E_xtp1_xtT.copy()
        E_xtp1_xtT = E_xtp1_xtT.sum(0)

        def is_symmetric(A):
            return np.allclose(A, A.T)

        assert is_symmetric(ExxT)
        assert is_symmetric(E_xt_xtT)
        assert is_symmetric(E_xtp1_xtp1T)

        Ex0, ExxT0 = self._set_expected_stats_initial(
            smoothed_mus, smoothed_sigmas)
        assert is_symmetric(ExxT0)

        # The entries have different shapes, so they are packed explicitly
        # into object arrays (see `_pack_stats`).
        self.E_emission_stats = self._pack_stats(EyyT, EyxT, ExxTe, T)
        self.E_dynamics_stats = self._pack_stats(
            E_xtp1_xtp1T, E_xtp1_xtT, E_xt_xtT, self.T - 1)
        self.E_initial_stats = self._pack_stats(Ex0, ExxT0, 1)

    def _set_expected_stats_data(self, smoothed_mus):
        """Data-dependent statistics.

        Returns
        -------
        T : int, or array with one entry per index group when stitching
        EyyT : second moment of the data (per-dimension sums of squares when
            stitching or when `diag_sigma_obs` is set, else ``data.T.dot(data)``)
        EyxT : data/state cross moment; with `diag_sigma_obs` the column of
            data sums is appended.
        """
        scheme = self.obs_scheme
        sub_pops = scheme.sub_pops
        obs_pops = scheme.obs_pops
        obs_time = scheme.obs_time
        idx_grp = scheme.idx_grp
        obs_idx = scheme.obs_idx
        data = self.data

        if obs_time.size > 1:
            T = np.zeros(len(idx_grp))
            Ey = np.zeros(self.p)
            EyyT = np.zeros(self.p)
            EyxT = np.zeros((self.p, self.n))

            for i in range(obs_time.size):
                # Segment length, credited to every index group seen in it.
                if i == 0:
                    T[obs_idx[i]] += obs_time[0]
                else:
                    T[obs_idx[i]] += obs_time[i] - obs_time[i - 1]

                idx = sub_pops[obs_pops[i]]
                if idx.size > 0:
                    ts = self._segment_times(obs_time, i)
                    ytmp = data[np.ix_(ts, idx)]
                    Ey[idx] += np.sum(ytmp, 0)
                    EyyT[idx] += np.sum(ytmp * ytmp, 0)
                    EyxT[idx, :] += np.einsum(
                        'ni,nj->ij', ytmp, smoothed_mus[ts, :])
        else:
            T = self.T
            Ey = data.sum(0)
            if self.diag_sigma_obs:
                EyyT = np.sum(data * data, 0)
            else:
                EyyT = data.T.dot(data)
            EyxT = data.T.dot(smoothed_mus)

        if self.diag_sigma_obs:
            EyxT = np.hstack((EyxT, np.atleast_2d(Ey).T))

        return T, EyyT, EyxT

    def _set_expected_stats_latents(self, smoothed_mus, smoothed_sigmas):
        """Latent-state statistics.

        Returns
        -------
        ExxT : summed second moment of the states (augmented with the state
            sums and T when the emission distribution is affine)
        E_xt_xtT : second moment summed over all but the last time step
        E_xtp1_xtp1T : second moment summed over all but the first time step
        ExxTe : per-index-group augmented second moments when stitching,
            otherwise the same object as `ExxT`.
        """
        scheme = self.obs_scheme
        obs_time = scheme.obs_time
        idx_grp = scheme.idx_grp
        obs_idx = scheme.obs_idx
        stitched = obs_time.size > 1

        if stitched:
            n, n_grp = self.n, len(idx_grp)
            T = np.zeros(n_grp)
            Ex = np.zeros(n)
            ExxT = np.zeros((n, n))
            Exe = np.zeros((n, n_grp))
            ExxTj = np.zeros((n, n, n_grp))
            ExxTe = np.zeros((n + 1, n + 1, n_grp))

            for i in range(obs_time.size):
                ts = self._segment_times(obs_time, i)
                x = smoothed_mus[ts, :]
                sx = np.sum(x, 0)
                sxxT = smoothed_sigmas[ts, :, :].sum(0) + x.T.dot(x)

                for j in obs_idx[i]:
                    Exe[:, j] += sx
                    ExxTj[:, :, j] += sxxT
                    T[j] += len(ts)

                Ex += sx
                ExxT += sxxT

            for j in range(n_grp):
                row = np.atleast_2d(Exe[:, j])
                ExxTe[:, :, j] = blockarray(
                    [[ExxTj[:, :, j], row.T],
                     [row, np.atleast_2d(T[j])]])
        else:
            Ex = smoothed_mus.sum(0)
            ExxT = smoothed_sigmas.sum(0) + smoothed_mus.T.dot(smoothed_mus)

        # Drop the last / first time step from the full sum, respectively.
        E_xt_xtT = ExxT - (
            smoothed_sigmas[-1]
            + np.outer(smoothed_mus[-1], smoothed_mus[-1]))
        E_xtp1_xtp1T = ExxT - (
            smoothed_sigmas[0]
            + np.outer(smoothed_mus[0], smoothed_mus[0]))

        if self.model.emission_distn.affine:
            ExxT = blockarray([[ExxT, np.atleast_2d(Ex).T],
                               [np.atleast_2d(Ex), np.atleast_2d(self.T)]])

        if not stitched:
            ExxTe = ExxT

        return ExxT, E_xt_xtT, E_xtp1_xtp1T, ExxTe

    def _set_expected_stats_initial(self, smoothed_mus, smoothed_sigmas):
        """Return the first smoothed mean and its second moment."""
        mu0 = smoothed_mus[0]
        return mu0, smoothed_sigmas[0] + np.outer(mu0, mu0)

    # next two methods are for testing

    def info_E_step(self):
        """E-step in information form, built from the standard parameters."""
        data = self.data
        A, sigma_states, C, sigma_obs = \
            self.A, self.sigma_states, self.C, self.sigma_obs

        J_init = np.linalg.inv(self.sigma_init)
        h_init = np.linalg.solve(self.sigma_init, self.mu_init)

        J_pair_11 = A.T.dot(np.linalg.solve(sigma_states, A))
        J_pair_21 = -np.linalg.solve(sigma_states, A)
        J_pair_22 = np.linalg.inv(sigma_states)

        J_node = C.T.dot(np.linalg.solve(sigma_obs, C))
        h_node = np.einsum('ik,ij,tj->tk', C, np.linalg.inv(sigma_obs), data)

        self._normalizer, self.smoothed_mus, self.smoothed_sigmas, \
            E_xtp1_xtT = info_E_step(
                J_init, h_init, J_pair_11, J_pair_21, J_pair_22,
                J_node, h_node)
        self._normalizer += self._extra_loglike_terms(
            self.A, self.sigma_states, self.C, self.sigma_obs,
            self.mu_init, self.sigma_init, self.data)

        self._set_expected_stats(
            self.smoothed_mus, self.smoothed_sigmas, E_xtp1_xtT)

    @staticmethod
    def _extra_loglike_terms(A, BBT, C, DDT, mu_init, sigma_init, data):
        """Constant terms added to the information-form normalizer.

        `BBT` and `DDT` are used as the state and observation covariances;
        `A` is accepted but not used. `C` only supplies the dimensions (p, n).
        """
        p, n = C.shape
        T = data.shape[0]
        log2pi = np.log(2 * np.pi)

        out = 0.

        # initial-state term
        out -= 1. / 2 * mu_init.dot(np.linalg.solve(sigma_init, mu_init))
        out -= 1. / 2 * np.linalg.slogdet(sigma_init)[1]
        out -= n / 2. * log2pi

        # dynamics term
        out -= (T - 1) / 2. * np.linalg.slogdet(BBT)[1]
        out -= (T - 1) * n / 2. * log2pi

        # observation term
        out -= 1. / 2 * np.einsum(
            'ij,ti,tj->', np.linalg.inv(DDT), data, data)
        out -= T / 2. * np.linalg.slogdet(DDT)[1]
        # Float literal keeps this a true division on Python 2 as well.
        out -= T * p / 2. * log2pi

        return out

    ### mean field

    def meanfieldupdate(self):
        """Mean-field update of the states in information form.

        Expected natural parameters are obtained from the dynamics and
        emission distributions; the normalizer, smoothed moments and expected
        statistics are stored on the instance.
        """
        J_init = np.linalg.inv(self.sigma_init)
        h_init = np.linalg.solve(self.sigma_init, self.mu_init)

        def get_params(distn):
            return mniw_expectedstats(
                *distn._natural_to_standard(distn.mf_natural_hypparam))

        J_pair_22, J_pair_21, J_pair_11, logdet_pair = \
            get_params(self.dynamics_distn)
        J_yy, J_yx, J_node, logdet_node = get_params(self.emission_distn)
        h_node = self.data.dot(J_yx)

        self._normalizer, self.smoothed_mus, self.smoothed_sigmas, \
            E_xtp1_xtT = info_E_step(
                J_init, h_init, J_pair_11, -J_pair_21, J_pair_22,
                J_node, h_node)
        self._normalizer += self._info_extra_loglike_terms(
            J_init, h_init, logdet_pair, J_yy, logdet_node, self.data)

        self._set_expected_stats(
            self.smoothed_mus, self.smoothed_sigmas, E_xtp1_xtT)

    def get_vlb(self):
        """Return the stored normalizer, running a mean-field update if unset.

        `__init__` always creates `_normalizer` (as None), so the check is on
        the value rather than on the attribute's existence.
        """
        if getattr(self, '_normalizer', None) is None:
            self.meanfieldupdate()  # NOTE: sets self._normalizer
        return self._normalizer

    @staticmethod
    def _info_extra_loglike_terms(
            J_init, h_init, logdet_pair, J_yy, logdet_node, data):
        """Constant terms added to the mean-field normalizer.

        `J_yy` may be 2-D (shared across time) or 3-D (one matrix per time
        step); `logdet_pair` / `logdet_node` may be scalars or per-step arrays.
        """
        # The observation dimension comes from the data: for a 3-D `J_yy`
        # its leading axis is time, not the observation dimension.
        T, p = data.shape[0], data.shape[1]
        n = h_init.shape[0]
        log2pi = np.log(2 * np.pi)

        out = 0.

        # initial-state term
        out -= 1. / 2 * h_init.dot(np.linalg.solve(J_init, h_init))
        out += 1. / 2 * np.linalg.slogdet(J_init)[1]
        out -= n / 2. * log2pi

        # dynamics term
        if isinstance(logdet_pair, np.ndarray):
            out += 1. / 2 * logdet_pair.sum()
        else:
            out += (T - 1) / 2. * logdet_pair
        out -= (T - 1) * n / 2. * log2pi

        # observation term
        contract = 'ij,ti,tj->' if J_yy.ndim == 2 else 'tij,ti,tj->'
        out -= 1. / 2 * np.einsum(contract, J_yy, data, data)
        if isinstance(logdet_node, np.ndarray):
            out += 1. / 2 * logdet_node.sum()
        else:
            out += T / 2. * logdet_node
        out -= T * p / 2. * log2pi

        return out

    # model properties

    @property
    def emission_distn(self):
        return self.model.emission_distn

    @property
    def dynamics_distn(self):
        return self.model.dynamics_distn

    @property
    def mu_init(self):
        return self.model.mu_init

    @property
    def sigma_init(self):
        return self.model.sigma_init

    @property
    def n(self):
        return self.model.n

    @property
    def p(self):
        return self.model.p

    @property
    def A(self):
        return self.model.A

    @property
    def sigma_states(self):
        return self.model.sigma_states

    @property
    def C(self):
        return self.model.C

    @property
    def d(self):
        return self.model.d

    @property
    def sigma_obs(self):
        return self.model.sigma_obs

    @property
    def diag_sigma_obs(self):
        return self.model.diag_sigma_obs

    @property
    def dsigma_obs(self):
        return self.model.dsigma_obs

    @property
    def strided_stateseq(self):
        return AR_striding(self.stateseq, 1)

    @property
    def obs_scheme(self):
        return self._obs_scheme

    @obs_scheme.setter
    def obs_scheme(self, obs_scheme):
        """Install a new observation scheme after validating it.

        The scheme is checked before it is stored, so a rejected scheme does
        not replace the current one.

        Raises
        ------
        TypeError
            If ``obs_scheme.check_obs_scheme()`` raises.
        """
        try:
            obs_scheme.check_obs_scheme()
        except Exception:
            raise TypeError('observation scheme does not meet requirements')
        self._obs_scheme = obs_scheme