示例#1
0
def test_lds_sample(T=25, D=4):
    """
    Check lds_sample against a dense linear-algebra reference.

    The block-tridiagonal pieces returned by convert_lds_to_block_tridiag
    are assembled into a single (T*D, T*D) matrix. A reference value is
    then computed from that dense matrix with a Cholesky factorization and
    a linear solve, using the same noise vector z that is handed to
    lds_sample. The two results must agree under np.allclose.
    """
    As, bs, Qi_sqrts, ms, Ri_sqrts = make_lds_parameters(T, D)
    J_diag, J_lower_diag, h = convert_lds_to_block_tridiag(As, bs, Qi_sqrts, ms, Ri_sqrts)

    # Assemble the dense matrix block by block
    n = T * D
    blocks = [slice(t * D, (t + 1) * D) for t in range(T)]
    J_full = np.zeros((n, n))
    for t, blk in enumerate(blocks):
        J_full[blk, blk] = J_diag[t]

    # Each lower block sits at block position (t+1, t); its transpose
    # mirrors it at (t, t+1).
    for t, (cur, nxt) in enumerate(zip(blocks[:-1], blocks[1:])):
        J_full[cur, nxt] = J_lower_diag[t].T
        J_full[nxt, cur] = J_lower_diag[t]

    z = npr.randn(n,)

    # Dense reference: noise term from the Cholesky factor plus the
    # solution of J_full x = h
    chol = np.linalg.cholesky(J_full)
    noise_term = np.linalg.solve(chol.T, z).reshape(T, D)
    mean_term = np.linalg.solve(J_full, h.reshape(n)).reshape(T, D)
    xtrue = noise_term + mean_term

    # Same quantity from the banded solver, driven by the same z
    xtest = lds_sample(As, bs, Qi_sqrts, ms, Ri_sqrts, z=z)

    assert np.allclose(xtrue, xtest)
示例#2
0
def _gibbs_sample_continuous_states(arhmm,
                                    discrete_states,
                                    indicators,
                                    mog,
                                    beta=1):
    r"""
    Gibbs sample the continuous states given the discrete states,
    the indicators, and the mixture of Gaussian model.

    Beta scales the natural parameters of the MoG potential
    J = Sigma^{-1}
    h = Sigma^{-1} \mu = J \mu

    When beta = 0, the MoG potential is ignored.
    When beta = 1, the MoG potential is a normalized probability on x

    To get the mean parameters, set \mu = J^{-1} h. This is unchanged by beta.

    Parameters
    ----------
    arhmm : object exposing ``observations.As``, ``observations.bs``,
        ``observations.Sigmas`` and ``D``.
    discrete_states : index array; every entry except the first is used
        to pick the dynamics parameters.
    indicators : index array used to pick the MoG mean and covariance
        factor.
    mog : object exposing ``observations.mus`` and ``observations.Sigmas``.
    beta : scalar; ``sqrt(beta)`` multiplies the selected MoG
        inverse-covariance Cholesky factors.

    Returns
    -------
    Whatever ``lds_sample`` returns for the assembled parameters.

    Notes
    -----
    This docstring is a raw string so that ``\mu`` is not parsed as an
    (invalid) escape sequence.

    ``1e-4 * np.eye(arhmm.D)`` is added to the scaled factor, so at
    ``beta = 0`` the factor handed to ``lds_sample`` is ``1e-4 * I`` rather
    than exactly zero (presumably to keep the potential non-degenerate;
    confirm against lds_sample).
    """
    # The first discrete state is skipped when selecting dynamics parameters.
    next_states = discrete_states[1:]

    # Extract the dynamics parameters
    As = arhmm.observations.As[next_states]
    bs = arhmm.observations.bs[next_states]
    # Factor every inverse covariance once, then index by state
    Qi_sqrts = np.linalg.cholesky(np.linalg.inv(arhmm.observations.Sigmas))
    Qi_sqrts = Qi_sqrts[next_states]

    # Extract the observation potentials
    ms = mog.observations.mus[indicators]
    Ri_sqrts = np.linalg.cholesky(np.linalg.inv(mog.observations.Sigmas))
    Ri_sqrts = Ri_sqrts[indicators]

    # Call the forward filter backward sample code
    return lds_sample(As, bs, Qi_sqrts, ms,
                      1e-4 * np.eye(arhmm.D) + np.sqrt(beta) * Ri_sqrts)
示例#3
0
 def sample(self):
     """Return a list with one lds_sample result per entry of self.params.

     Each entry of self.params is unpacked as the positional arguments
     of lds_sample.
     """
     return [lds_sample(*args) for args in self.params]