def logdet_symm_block_tridiag(H_diag, H_upper_diag):
    r"""
    Compute the log determinant of a positive definite, symmetric, block
    tridiagonal matrix.

    The Kalman information filter is used to do so.  Specifically, the
    filter computes the normalizer

        log Z = 1/2 h^T J^{-1} h - 1/2 log |J| + n/2 log 2 \pi

    Setting h = 0 leaves -1/2 log |J| + n/2 log 2 \pi, from which log |J|
    is solved for.

    Parameters
    ----------
    H_diag : ndarray, shape (T, D, D)
        Diagonal blocks of the matrix.
    H_upper_diag : ndarray, shape (T - 1, D, D)
        Upper off-diagonal blocks.  They are transposed (last two axes) to
        form the lower blocks passed to the filter as ``J_21``.

    Returns
    -------
    logdetJ : scalar
        log |J| of the full (T * D) x (T * D) matrix.
    """
    # Validate rank before unpacking the shape, so a wrong-rank input fails
    # this assertion rather than raising an unpacking ValueError.
    assert H_diag.ndim == 3 and H_diag.shape[1] == H_diag.shape[2]
    T, D, _ = H_diag.shape
    assert H_upper_diag.shape == (T - 1, D, D)

    # Initial-state and pair "node" terms are all zero: the diagonal blocks
    # enter only through the node potentials, and the off-diagonal blocks
    # only through J_21.  Separate arrays are used so no argument aliases
    # another.
    J_init = np.zeros((D, D))
    h_init = np.zeros((D,))
    log_Z_init = 0

    J_11 = np.zeros((D, D))
    J_22 = np.zeros((D, D))
    J_21 = np.swapaxes(H_upper_diag, -1, -2)
    h_1 = np.zeros((D,))
    h_2 = np.zeros((D,))
    log_Z_pair = 0

    J_node = H_diag
    h_node = np.zeros((T, D))  # h = 0 removes the quadratic term from log Z
    log_Z_node = 0

    logZ, _, _ = kalman_info_filter(J_init, h_init, log_Z_init,
                                    J_11, J_21, J_22, h_1, h_2, log_Z_pair,
                                    J_node, h_node, log_Z_node)

    # logZ = -1/2 log |J| + (T * D)/2 log(2 pi)  =>  solve for log |J|
    logdetJ = -2 * (logZ - (T * D) / 2 * np.log(2 * np.pi))
    return logdetJ
def check_filters(A, B, C, D, mu_init, sigma_init, data):
    """Verify the information-form Kalman filter against the covariance form.

    The model uses state noise covariance ``B B^T`` and observation noise
    covariance ``D D^T``.  Asserts, in order, that:

    * per-step filtered means agree (info form: ``solve(J, h)``);
    * per-step filtered covariances agree (info form: ``inv(J)``);
    * the info filter's partial log likelihood equals the normalizer of the
      dense information parameters from ``dense_infoparams``;
    * partial log likelihood + ``LDSStates._extra_loglike_terms`` equals the
      covariance-form log likelihood.
    """
    def gram(M):
        # A fresh M M^T per call, so no product array is shared between calls.
        return M.dot(M.T)

    def info_normalizer(J, h):
        # 1/2 h^T J^{-1} h - 1/2 log|J| + n/2 log(2 pi)
        quadratic = 0.5 * h.dot(np.linalg.solve(J, h))
        log_det = np.linalg.slogdet(J)[1]
        dim = h.shape[0]
        return quadratic - 0.5 * log_det + dim / 2. * np.log(2 * np.pi)

    ll_cov, means_cov, covs_cov = kalman_filter(
        mu_init, sigma_init, A, gram(B), C, gram(D), data)
    normalizer_dense = info_normalizer(
        *dense_infoparams(A, B, C, D, mu_init, sigma_init, data))
    normalizer_info, Js, hs = kalman_info_filter(
        *info_params(A, B, C, D, mu_init, sigma_init, data))
    ll_info = normalizer_info + LDSStates._extra_loglike_terms(
        A, gram(B), C, gram(D), mu_init, sigma_init, data)

    # Convert the info-form filtered estimates to mean/covariance form.
    means_info = [np.linalg.solve(J, h) for J, h in zip(Js, hs)]
    covs_info = [np.linalg.inv(J) for J in Js]

    assert all(np.allclose(m_c, m_i)
               for m_c, m_i in zip(means_cov, means_info))
    assert all(np.allclose(S_c, S_i)
               for S_c, S_i in zip(covs_cov, covs_info))
    assert np.isclose(normalizer_info, normalizer_dense)
    assert np.isclose(ll_cov, ll_info)
def check_filters(A, B, C, D, mu_init, sigma_init, data):
    """Assert that the covariance- and information-form Kalman filters agree.

    Checks, in order:
      1. filtered means from ``kalman_filter`` match J^{-1} h from
         ``kalman_info_filter``;
      2. filtered covariances match J^{-1};
      3. the info filter's partial log likelihood equals the Gaussian
         normalizer of the parameters from ``dense_infoparams``;
      4. partial log likelihood plus ``extra_loglike_terms`` equals the
         covariance-form log likelihood.
    """
    def info_normalizer(J, h):
        # 1/2 h^T J^{-1} h - 1/2 log|J| + n/2 log(2 pi)
        mahalanobis = h.dot(np.linalg.solve(J, h))
        _, logdet = np.linalg.slogdet(J)
        return (mahalanobis / 2.
                - logdet / 2.
                + h.shape[0] / 2. * np.log(2 * np.pi))

    # Covariance-form filter; noise covariances are B B^T and D D^T.
    ll_cov, means_cov, covs_cov = kalman_filter(
        mu_init, sigma_init, A, B.dot(B.T), C, D.dot(D.T), data)
    # Reference partial log likelihood from the dense info parameters.
    partial_ll_dense = info_normalizer(
        *dense_infoparams(A, B, C, D, mu_init, sigma_init, data))
    # Information-form filter.
    partial_ll_info, Js, hs = kalman_info_filter(
        *info_params(A, B, C, D, mu_init, sigma_init, data))
    ll_info = partial_ll_info + extra_loglike_terms(
        A, B, C, D, mu_init, sigma_init, data)

    means_info = [np.linalg.solve(J, h) for J, h in zip(Js, hs)]
    covs_info = [np.linalg.inv(J) for J in Js]

    assert all(np.allclose(a, b) for a, b in zip(means_cov, means_info))
    assert all(np.allclose(a, b) for a, b in zip(covs_cov, covs_info))
    assert np.isclose(partial_ll_info, partial_ll_dense)
    assert np.isclose(ll_cov, ll_info)
def log_likelihood(self):
    """Return the cached log normalizer from the Kalman information filter.

    The first call (while ``self._normalizer`` is ``None``) runs
    ``kalman_info_filter(*self.info_params)`` and stores the first returned
    value in ``self._normalizer``.  Later calls return the stored value
    without re-running the filter.

    NOTE(review): adding
    ``self._info_extra_loglike_terms(*self.extra_info_params,
    isdiag=self.diagonal_noise)`` to the normalizer is currently disabled.
    This therefore returns only the information-filter normalizer; confirm
    that this is intended.
    """
    if self._normalizer is not None:
        return self._normalizer
    normalizer, _filtered_Js, _filtered_hs = kalman_info_filter(
        *self.info_params)
    self._normalizer = normalizer
    return self._normalizer
def check_filters(A, B, sigma_states, C, D, sigma_obs, mu_init, sigma_init,
                  data, inputs):
    """Assert that the covariance- and information-form Kalman filters agree
    for a model with inputs.

    Runs ``kalman_filter`` and ``kalman_info_filter`` (on ``info_params``)
    and compares:
      * filtered means (info form: J^{-1} h),
      * filtered covariances (info form: J^{-1}),
      * the log likelihoods returned by the two filters.
    """
    # Note the argument order: kalman_filter takes (..., inputs, data) while
    # info_params takes (..., data, inputs).
    ll_cov, means_cov, covs_cov = kalman_filter(
        mu_init, sigma_init, A, B, sigma_states, C, D, sigma_obs,
        inputs, data)
    info_args = info_params(A, B, sigma_states, C, D, sigma_obs,
                            mu_init, sigma_init, data, inputs)
    ll_info, Js, hs = kalman_info_filter(*info_args)

    means_info = [np.linalg.solve(J, h) for J, h in zip(Js, hs)]
    covs_info = [np.linalg.inv(J) for J in Js]

    means_match = all(np.allclose(m_cov, m_info)
                      for m_cov, m_info in zip(means_cov, means_info))
    assert means_match
    covs_match = all(np.allclose(S_cov, S_info)
                     for S_cov, S_info in zip(covs_cov, covs_info))
    assert covs_match
    assert np.isclose(ll_cov, ll_info)
def test_lds_log_probability_perf(T=1000, D=10, N_iter=10):
    """
    Compare performance of banded method vs message passing in pylds.

    Times three evaluations of the same LDS model and prints the mean
    seconds per call for each:

    1. ``lds_log_probability`` (banded method);
    2. pylds ``kalman_filter`` (covariance form);
    3. pylds ``kalman_info_filter`` (information form).

    Parameters
    ----------
    T : int
        Number of time steps passed to ``make_lds_parameters``.
    D : int
        Latent state dimension passed to ``make_lds_parameters``.
    N_iter : int
        Number of timed calls per method.
    """
    print("Comparing methods for T={} D={}".format(T, D))
    from pylds.lds_messages_interface import kalman_info_filter, kalman_filter

    def time_per_iter(fn):
        # Mean seconds per call over N_iter calls.  perf_counter is the
        # high-resolution clock intended for benchmarking.
        start = time.perf_counter()
        for _ in range(N_iter):
            fn()
        return (time.perf_counter() - start) / N_iter

    def report(seconds):
        print("Time per iter: {:.4f}".format(seconds))

    # Convert LDS parameters into info form for pylds
    As, bs, Qi_sqrts, ms, Ri_sqrts = make_lds_parameters(T, D)
    Qis = np.matmul(Qi_sqrts, np.swapaxes(Qi_sqrts, -1, -2))
    Ris = np.matmul(Ri_sqrts, np.swapaxes(Ri_sqrts, -1, -2))
    x = npr.randn(T, D)

    print("Timing banded method")
    report(time_per_iter(
        lambda: lds_log_probability(x, As, bs, Qi_sqrts, ms, Ri_sqrts)))

    # Compare to Kalman Filter
    mu_init = np.zeros(D)
    sigma_init = np.eye(D)
    Bs = np.ones((D, 1))
    sigma_states = np.linalg.inv(Qis)
    Cs = np.eye(D)
    Ds = np.zeros((D, 1))
    sigma_obs = np.linalg.inv(Ris)
    inputs = bs
    data = ms

    # As and sigma_states are each padded with one trailing identity block
    # (presumably so pylds gets one per time step -- confirm).  The padding
    # is built once, outside the timed loop, so only the filter is timed.
    eye_block = np.eye(D)[None, :, :]
    As_padded = np.concatenate([As, eye_block])
    sigma_states_padded = np.concatenate([sigma_states, eye_block])

    print("Timing PyLDS message passing (kalman_filter)")
    report(time_per_iter(
        lambda: kalman_filter(mu_init, sigma_init, As_padded, Bs,
                              sigma_states_padded, Cs, Ds, sigma_obs,
                              inputs, data)))

    # Info form comparison
    J_init = np.zeros((D, D))
    h_init = np.zeros(D)
    log_Z_init = 0
    J_diag, J_lower_diag, h = convert_lds_to_block_tridiag(
        As, bs, Qi_sqrts, ms, Ri_sqrts)

    # Diagonal block 0 goes in the first pair's "11" slot, and blocks 1..T-1
    # go in the "22" slots, so each diagonal block is used exactly once.
    # J_diag[:-1] and J_diag[1:] are overlapping *views*.  The "11" slices
    # must be copied before zeroing; otherwise the in-place write also wipes
    # J_diag[1:-1] (most of J_pair_22), and likewise h[1:-1] (h_pair_2).
    J_pair_21 = J_lower_diag
    J_pair_22 = J_diag[1:]
    J_pair_11 = J_diag[:-1].copy()
    J_pair_11[1:] = 0
    h_pair_2 = h[1:]
    h_pair_1 = h[:-1].copy()
    h_pair_1[1:] = 0
    log_Z_pair = 0

    J_node = np.zeros((T, D, D))
    h_node = np.zeros((T, D))
    log_Z_node = 0

    print("Timing PyLDS message passing (kalman_info_filter)")
    report(time_per_iter(
        lambda: kalman_info_filter(J_init, h_init, log_Z_init,
                                   J_pair_11, J_pair_21, J_pair_22,
                                   h_pair_1, h_pair_2, log_Z_pair,
                                   J_node, h_node, log_Z_node)))
def info_filter(self):
    """Run the Kalman information filter on ``self.info_params``.

    Side effect: the filter's first return value (the log normalizer) is
    stored in ``self._normalizer``.

    Returns
    -------
    tuple
        ``(filtered_Js, filtered_hs)``: the filtered information matrices
        and vectors returned by ``kalman_info_filter``.
    """
    normalizer, Js, hs = kalman_info_filter(*self.info_params)
    self._normalizer = normalizer
    return Js, hs