def test_lds_sample(T=25, D=4):
    """
    Check ``lds_sample`` against a dense reference sampler.

    The random LDS is converted to block-tridiagonal information form,
    expanded into a full (T*D, T*D) precision matrix, and sampled directly
    through a Cholesky factor using the same standard-normal noise ``z``
    that is handed to ``lds_sample``. Both samples must agree.
    """
    As, bs, Qi_sqrts, ms, Ri_sqrts = make_lds_parameters(T, D)
    J_diag, J_lower_diag, h = convert_lds_to_block_tridiag(
        As, bs, Qi_sqrts, ms, Ri_sqrts)

    # Expand the block-tridiagonal precision into a dense symmetric matrix.
    dim = T * D
    J_full = np.zeros((dim, dim))
    for t in range(T):
        here = slice(t * D, (t + 1) * D)
        J_full[here, here] = J_diag[t]
    for t in range(T - 1):
        here = slice(t * D, (t + 1) * D)
        nxt = slice((t + 1) * D, (t + 2) * D)
        J_full[nxt, here] = J_lower_diag[t]
        J_full[here, nxt] = J_lower_diag[t].T

    noise = npr.randn(dim)

    # Reference sample: with J = L L^T, x = L^{-T} z has covariance J^{-1};
    # adding the mean J^{-1} h gives a draw from N(J^{-1} h, J^{-1}).
    chol = np.linalg.cholesky(J_full)
    centered = np.linalg.solve(chol.T, noise)
    mean = np.linalg.solve(J_full, h.reshape(dim))
    xtrue = (centered + mean).reshape(T, D)

    # Same draw via the structured (banded) sampler.
    xtest = lds_sample(As, bs, Qi_sqrts, ms, Ri_sqrts, z=noise)
    assert np.allclose(xtrue, xtest)
def _gibbs_sample_continuous_states(arhmm, discrete_states, indicators, mog, beta=1):
    r"""
    Gibbs sample the continuous states given the discrete states, the
    indicators, and the mixture of Gaussians (MoG) model.

    ``beta`` scales the natural parameters of the MoG potential:

        J = \Sigma^{-1}
        h = \Sigma^{-1} \mu = J \mu

    When beta = 0, the MoG potential is ignored (apart from the small 1e-4
    jitter added to its square-root precision below).
    When beta = 1, the MoG potential is a normalized probability on x.
    To get the mean parameters, set \mu = J^{-1} h; this is unchanged by beta.

    Parameters
    ----------
    arhmm : object
        Autoregressive HMM whose ``observations`` provide ``As``, ``bs`` and
        ``Sigmas`` indexed by discrete state, and which has a ``D`` attribute
        (used for the size of the jitter identity).
    discrete_states : integer array
        Discrete state per time step; ``discrete_states[1:]`` selects the
        dynamics parameters for each transition.
    indicators : integer array
        MoG component per time step; selects the observation potential.
    mog : object
        Mixture model whose ``observations`` provide ``mus`` and ``Sigmas``.
    beta : float or array_like, optional
        Non-negative scaling of the MoG potential. Defaults to 1.

    Returns
    -------
    The result of ``lds_sample`` -- presumably one sample of the continuous
    state trajectory (TODO confirm shape against ``lds_sample``).

    Raises
    ------
    ValueError
        If ``beta`` is negative; ``np.sqrt(beta)`` would otherwise be NaN.
    """
    # Validate up front: a negative beta would silently yield NaN potentials.
    if np.any(np.asarray(beta) < 0):
        raise ValueError("beta must be non-negative, got {}".format(beta))

    # Extract the dynamics parameters for each transition t -> t+1.
    As = arhmm.observations.As[discrete_states[1:]]
    bs = arhmm.observations.bs[discrete_states[1:]]
    # Factor each state's precision once, then gather per time step.
    Qi_sqrts = np.linalg.cholesky(np.linalg.inv(arhmm.observations.Sigmas))
    Qi_sqrts = Qi_sqrts[discrete_states[1:]]

    # Extract the observation potentials from the assigned MoG components.
    ms = mog.observations.mus[indicators]
    Ri_sqrts = np.linalg.cholesky(np.linalg.inv(mog.observations.Sigmas))
    Ri_sqrts = Ri_sqrts[indicators]

    # Scaling the square-root precision by sqrt(beta) scales the precision by
    # beta. The 1e-4 * I jitter presumably keeps the potential nonsingular
    # when beta == 0 -- NOTE(review): confirm lds_sample requires this.
    return lds_sample(As, bs, Qi_sqrts, ms,
                      1e-4 * np.eye(arhmm.D) + np.sqrt(beta) * Ri_sqrts)
def sample(self):
    """Return a list with one ``lds_sample`` draw per entry of ``self.params``.

    Each entry of ``self.params`` is unpacked as the positional arguments of
    ``lds_sample``.
    """
    draws = []
    for param_tuple in self.params:
        draws.append(lds_sample(*param_tuple))
    return draws