Example #1
0
def test_lds_mean(T=25, D=4):
    """
    Test lds_mean correctness against a dense reference solution.

    LDS parameters are generated and converted to block-tridiagonal form
    (J_diag, J_lower_diag, h). The blocks are assembled into a dense
    (T*D, T*D) matrix J_full, and the reference mean is the solution of
    J_full @ mu = h. That reference must match what lds_mean returns.

    Parameters
    ----------
    T : int
        Number of blocks (time steps) along the diagonal.
    D : int
        Size of each square diagonal block.
    """
    As, bs, Qi_sqrts, ms, Ri_sqrts = make_lds_parameters(T, D)
    J_diag, J_lower_diag, h = convert_lds_to_block_tridiag(As, bs, Qi_sqrts, ms, Ri_sqrts)

    # Convert to dense matrix: J_diag[t] on the block diagonal, J_lower_diag[t]
    # just below it, and its transpose just above it.
    J_full = np.zeros((T*D, T*D))
    for t in range(T):
        J_full[t*D:(t+1)*D, t*D:(t+1)*D] = J_diag[t]

    for t in range(T-1):
        J_full[t*D:(t+1)*D, (t+1)*D:(t+2)*D] = J_lower_diag[t].T
        J_full[(t+1)*D:(t+2)*D, t*D:(t+1)*D] = J_lower_diag[t]

    # Reference mean mu = J_full^{-1} h. Solve the linear system directly
    # rather than forming the explicit inverse: it is cheaper and more
    # numerically accurate, which matters for a tight tolerance comparison.
    mu_true = np.linalg.solve(J_full, h.ravel()).reshape((T, D))

    # Solve with the banded solver
    mu_test = lds_mean(As, bs, Qi_sqrts, ms, Ri_sqrts)

    # Same tolerances and NaN handling as np.allclose's defaults, but
    # assert_allclose reports which entries disagree when the check fails.
    np.testing.assert_allclose(mu_true, mu_test, rtol=1e-5, atol=1e-8, equal_nan=False)
Example #2
0
 def mean(self):
     """
     Apply lds_mean to every parameter tuple in ``self.params``.

     Each element of ``self.params`` is unpacked as positional arguments to
     ``lds_mean``. The results are returned as a list, in the same order as
     ``self.params``.
     """
     means = []
     for param_set in self.params:
         means.append(lds_mean(*param_set))
     return means