Пример #1
0
def logdet_symm_block_tridiag(H_diag, H_upper_diag):
    r"""
    Compute the log determinant of a positive definite, symmetric block
    tridiagonal matrix using the Kalman information filter.

    The information filter computes the Gaussian log normalizer

        log Z = 1/2 h^T J^{-1} h - 1/2 log |J| + n/2 log 2 \pi

    Setting h = 0 reduces this to -1/2 log |J| + n/2 log 2 \pi, from which
    log |J| is solved for.  Here n = T * D.

    Parameters
    ----------
    H_diag : ndarray, shape (T, D, D)
        Diagonal blocks of the matrix.
    H_upper_diag : ndarray, shape (T - 1, D, D)
        Upper off-diagonal blocks; block ``t`` is the (t, t + 1) block (its
        transpose is used as the (t + 1, t) block).

    Returns
    -------
    scalar
        ``log |J|`` of the full ``(T * D, T * D)`` matrix.
    """
    # Check the rank *before* unpacking the shape: otherwise a wrong-rank
    # input fails with an unpacking ValueError and this check is never reached.
    assert H_diag.ndim == 3 and H_diag.shape[1] == H_diag.shape[2]
    T, D, _ = H_diag.shape
    assert H_upper_diag.shape == (T - 1, D, D)

    # No initial-state or pair-diagonal contributions: all diagonal blocks are
    # supplied through the node potentials below.  Separate arrays are
    # allocated so that none of the filter's arguments alias one another.
    J_init = np.zeros((D, D))
    J_11 = np.zeros((D, D))
    J_22 = np.zeros((D, D))
    h_init = np.zeros((D, ))
    h_1 = np.zeros((D, ))
    h_2 = np.zeros((D, ))
    log_Z_init = 0

    # The matrix is symmetric, so the lower blocks are the transposed upper ones.
    J_21 = np.swapaxes(H_upper_diag, -1, -2)
    log_Z_pair = 0

    # h = 0 everywhere so the quadratic term of log Z vanishes.
    J_node = H_diag
    h_node = np.zeros((T, D))
    log_Z_node = 0

    logZ, _, _ = kalman_info_filter(J_init, h_init, log_Z_init, J_11, J_21,
                                    J_22, h_1, h_2, log_Z_pair, J_node, h_node,
                                    log_Z_node)

    # logZ = -1/2 log |J| + n/2 log 2 pi  =>  log |J| = -2 (logZ - n/2 log 2 pi)
    logdetJ = -2 * (logZ - (T * D) / 2 * np.log(2 * np.pi))
    return logdetJ
Пример #2
0
def check_filters(A, B, C, D, mu_init, sigma_init, data):
    """Cross-check the covariance-form and information-form Kalman filters.

    Runs ``kalman_filter`` and ``kalman_info_filter`` on the same model and
    asserts that the filtered means, filtered covariances and total
    log-likelihoods agree.  The info filter's partial log-likelihood is also
    compared against a dense log-normalizer built from ``dense_infoparams``.
    """
    def info_normalizer(J, h):
        # Dense Gaussian log-normalizer in information form:
        #   1/2 h^T J^{-1} h - 1/2 log|J| + n/2 log(2 pi)
        half_quad = 1 / 2. * h.dot(np.linalg.solve(J, h))
        half_logdet = 1 / 2. * np.linalg.slogdet(J)[1]
        const = h.shape[0] / 2. * np.log(2 * np.pi)
        return half_quad - half_logdet + const

    # Covariance-form reference run.
    ll, filtered_mus, filtered_sigmas = kalman_filter(
        mu_init, sigma_init, A, B.dot(B.T), C, D.dot(D.T), data)

    dense_params = dense_infoparams(A, B, C, D, mu_init, sigma_init, data)
    py_partial_ll = info_normalizer(*dense_params)

    # Information-form run under test.
    model_info = info_params(A, B, C, D, mu_init, sigma_init, data)
    partial_ll, filtered_Js, filtered_hs = kalman_info_filter(*model_info)

    extra = LDSStates._extra_loglike_terms(
        A, B.dot(B.T), C, D.dot(D.T), mu_init, sigma_init, data)
    ll_info = partial_ll + extra

    # Convert (J, h) to moments: mu = J^{-1} h, Sigma = J^{-1}.
    mus_info = [np.linalg.solve(J, h) for J, h in zip(filtered_Js, filtered_hs)]
    sigmas_info = [np.linalg.inv(J) for J in filtered_Js]

    for mu_cov, mu_inf in zip(filtered_mus, mus_info):
        assert np.allclose(mu_cov, mu_inf)
    for sig_cov, sig_inf in zip(filtered_sigmas, sigmas_info):
        assert np.allclose(sig_cov, sig_inf)

    assert np.isclose(partial_ll, py_partial_ll)
    assert np.isclose(ll, ll_info)
Пример #3
0
def check_filters(A, B, C, D, mu_init, sigma_init, data):
    """Assert that ``kalman_filter`` and ``kalman_info_filter`` agree.

    The filtered means and covariances of both filters must match, the info
    filter's partial log-likelihood must equal a dense computation from
    ``dense_infoparams``, and adding ``extra_loglike_terms`` to it must
    reproduce the covariance filter's log-likelihood.
    """
    def info_normalizer(J, h):
        # log-normalizer of N^{-1}(h, J):
        #   1/2 h^T J^{-1} h - 1/2 log|J| + n/2 log(2 pi)
        quad_term = 1/2. * h.dot(np.linalg.solve(J, h))
        logdet_term = 1/2. * np.linalg.slogdet(J)[1]
        const_term = h.shape[0]/2. * np.log(2*np.pi)
        return quad_term - logdet_term + const_term

    sigma_states = B.dot(B.T)
    sigma_obs = D.dot(D.T)
    ll, filtered_mus, filtered_sigmas = kalman_filter(
        mu_init, sigma_init, A, sigma_states, C, sigma_obs, data)

    py_partial_ll = info_normalizer(
        *dense_infoparams(A, B, C, D, mu_init, sigma_init, data))

    partial_ll, filtered_Js, filtered_hs = kalman_info_filter(
        *info_params(A, B, C, D, mu_init, sigma_init, data))
    ll2 = partial_ll + extra_loglike_terms(
        A, B, C, D, mu_init, sigma_init, data)

    mus_from_info = [np.linalg.solve(J, h)
                     for J, h in zip(filtered_Js, filtered_hs)]
    sigmas_from_info = [np.linalg.inv(J) for J in filtered_Js]

    for expected, actual in zip(filtered_mus, mus_from_info):
        assert np.allclose(expected, actual)
    for expected, actual in zip(filtered_sigmas, sigmas_from_info):
        assert np.allclose(expected, actual)

    assert np.isclose(partial_ll, py_partial_ll)
    assert np.isclose(ll, ll2)
Пример #4
0
    def log_likelihood(self):
        """Return the information-filter log normalizer, computed lazily.

        On the first call (while ``self._normalizer`` is ``None``),
        ``kalman_info_filter`` is run on ``self.info_params`` and its first
        return value is cached in ``self._normalizer``.  Subsequent calls
        return the cached value without re-running the filter.
        """
        if self._normalizer is not None:
            return self._normalizer

        normalizer, _, _ = kalman_info_filter(*self.info_params)
        self._normalizer = normalizer
        # NOTE: the original code had (commented out) an addition of
        # self._info_extra_loglike_terms(*self.extra_info_params,
        # isdiag=self.diagonal_noise); that step is disabled, so only the
        # info-filter normalizer is returned.
        return self._normalizer
Пример #5
0
def check_filters(A, B, sigma_states, C, D, sigma_obs, mu_init, sigma_init,
                  data, inputs):
    """Assert that covariance-form and information-form filtering agree.

    Both ``kalman_filter`` and ``kalman_info_filter`` are run on the same
    model (including ``inputs``).  Their filtered means, filtered
    covariances and log-likelihoods must match to ``np.allclose`` /
    ``np.isclose`` tolerance.
    """
    ll_cov, mus_cov, sigmas_cov = kalman_filter(
        mu_init, sigma_init, A, B, sigma_states, C, D, sigma_obs, inputs,
        data)

    model_info = info_params(A, B, sigma_states, C, D, sigma_obs, mu_init,
                             sigma_init, data, inputs)
    ll_info, Js, hs = kalman_info_filter(*model_info)

    # Convert information-form (J, h) pairs to moments:
    # mu = J^{-1} h and Sigma = J^{-1}.
    mus_info = [np.linalg.solve(J, h) for J, h in zip(Js, hs)]
    sigmas_info = [np.linalg.inv(J) for J in Js]

    for mu_a, mu_b in zip(mus_cov, mus_info):
        assert np.allclose(mu_a, mu_b)
    for sig_a, sig_b in zip(sigmas_cov, sigmas_info):
        assert np.allclose(sig_a, sig_b)
    assert np.isclose(ll_cov, ll_info)
Пример #6
0
def _report_time_per_iter(fn, N_iter):
    """Call ``fn()`` ``N_iter`` times and print the mean wall-clock time per call."""
    start = time.perf_counter()
    for _ in range(N_iter):
        fn()
    stop = time.perf_counter()
    print("Time per iter: {:.4f}".format((stop - start) / N_iter))


def test_lds_log_probability_perf(T=1000, D=10, N_iter=10):
    """
    Compare performance of banded method vs message passing in pylds.

    Times ``lds_log_probability`` against pylds' ``kalman_filter`` and
    ``kalman_info_filter`` on the same randomly generated LDS parameters,
    printing the mean time per call of each over ``N_iter`` runs.
    """
    print("Comparing methods for T={} D={}".format(T, D))

    from pylds.lds_messages_interface import kalman_info_filter, kalman_filter

    # Convert LDS parameters into info form for pylds
    As, bs, Qi_sqrts, ms, Ri_sqrts = make_lds_parameters(T, D)
    Qis = np.matmul(Qi_sqrts, np.swapaxes(Qi_sqrts, -1, -2))
    Ris = np.matmul(Ri_sqrts, np.swapaxes(Ri_sqrts, -1, -2))
    x = npr.randn(T, D)

    print("Timing banded method")
    _report_time_per_iter(
        lambda: lds_log_probability(x, As, bs, Qi_sqrts, ms, Ri_sqrts), N_iter)

    # Compare to Kalman Filter
    mu_init = np.zeros(D)
    sigma_init = np.eye(D)
    Bs = np.ones((D, 1))
    sigma_states = np.linalg.inv(Qis)
    Cs = np.eye(D)
    Ds = np.zeros((D, 1))
    sigma_obs = np.linalg.inv(Ris)
    inputs = bs
    data = ms
    # Append one identity block to the dynamics parameters (presumably As has
    # T - 1 entries and kalman_filter wants one per step -- confirm).  Built
    # once, outside the timed loop, so only the filter itself is measured.
    eye_block = np.eye(D)[None, :, :]
    As_padded = np.concatenate([As, eye_block])
    sigma_states_padded = np.concatenate([sigma_states, eye_block])

    print("Timing PyLDS message passing (kalman_filter)")
    _report_time_per_iter(
        lambda: kalman_filter(mu_init, sigma_init, As_padded, Bs,
                              sigma_states_padded, Cs, Ds, sigma_obs,
                              inputs, data),
        N_iter)

    # Info form comparison
    J_init = np.zeros((D, D))
    h_init = np.zeros(D)
    log_Z_init = 0

    J_diag, J_lower_diag, h = convert_lds_to_block_tridiag(As, bs, Qi_sqrts, ms, Ri_sqrts)
    # Each diagonal block goes into exactly one pair potential: block 0 in the
    # first pair's "11" slot, block t + 1 in pair t's "22" slot.  The "11"/"1"
    # slices are copied before zeroing: J_diag[:-1] and J_diag[1:] are
    # overlapping views (assuming plain ndarrays), so zeroing J_diag[:-1][1:]
    # in place would also wipe out J_pair_22[:-1] -- and likewise for h.
    J_pair_21 = J_lower_diag
    J_pair_22 = J_diag[1:]
    J_pair_11 = J_diag[:-1].copy()
    J_pair_11[1:] = 0
    h_pair_2 = h[1:]
    h_pair_1 = h[:-1].copy()
    h_pair_1[1:] = 0
    log_Z_pair = 0

    # All potentials are carried by the pairs, so node potentials are zero.
    J_node = np.zeros((T, D, D))
    h_node = np.zeros((T, D))
    log_Z_node = 0

    print("Timing PyLDS message passing (kalman_info_filter)")
    _report_time_per_iter(
        lambda: kalman_info_filter(J_init, h_init, log_Z_init,
                                   J_pair_11, J_pair_21, J_pair_22,
                                   h_pair_1, h_pair_2, log_Z_pair,
                                   J_node, h_node, log_Z_node),
        N_iter)
Пример #7
0
    def info_filter(self):
        """Run the Kalman information filter on ``self.info_params``.

        Side effect: stores the filter's log normalizer (its first return
        value) in ``self._normalizer``.

        Returns
        -------
        filtered_Js, filtered_hs
            The filtered information-form parameters returned by
            ``kalman_info_filter``.
        """
        normalizer, filtered_Js, filtered_hs = kalman_info_filter(
            *self.info_params)
        self._normalizer = normalizer
        return filtered_Js, filtered_hs