def test_lds_mean(T=25, D=4):
    """Test lds_mean correctness against a dense reference solve.

    Builds LDS parameters with ``make_lds_parameters`` and converts them to
    block-tridiagonal form with ``convert_lds_to_block_tridiag``. It then
    assembles the equivalent dense (T*D, T*D) matrix ``J_full`` and checks
    that ``lds_mean`` agrees with the solution of ``J_full @ mu = h``,
    reshaped to (T, D).

    Args:
        T: number of time steps, i.e. number of diagonal blocks.
        D: size of each square block.
    """
    As, bs, Qi_sqrts, ms, Ri_sqrts = make_lds_parameters(T, D)
    J_diag, J_lower_diag, h = convert_lds_to_block_tridiag(
        As, bs, Qi_sqrts, ms, Ri_sqrts)

    # Convert to dense matrix. J_diag[t] fills the t-th diagonal block and
    # J_lower_diag[t] fills the block just below it. Its transpose is mirrored
    # into the block just above, so the off-diagonal part is symmetric.
    J_full = np.zeros((T * D, T * D))
    for t in range(T):
        J_full[t*D:(t+1)*D, t*D:(t+1)*D] = J_diag[t]
    for t in range(T - 1):
        J_full[t*D:(t+1)*D, (t+1)*D:(t+2)*D] = J_lower_diag[t].T
        J_full[(t+1)*D:(t+2)*D, t*D:(t+1)*D] = J_lower_diag[t]

    # Reference mean: solve J_full @ mu = h directly. Forming the explicit
    # inverse and multiplying is slower and less numerically accurate.
    mu_true = np.linalg.solve(J_full, h.ravel()).reshape((T, D))

    # Solve with the banded solver
    mu_test = lds_mean(As, bs, Qi_sqrts, ms, Ri_sqrts)
    assert np.allclose(mu_true, mu_test)
def mean(self):
    """Return one ``lds_mean`` result per entry of ``self.params``.

    Each entry of ``self.params`` is unpacked positionally into
    ``lds_mean``. The results are returned as a list, in the same order
    as ``self.params``.
    """
    means = []
    for param_tuple in self.params:
        means.append(lds_mean(*param_tuple))
    return means