示例#1
0
def test_lds_sample(T=25, D=4):
    """
    Check lds_sample against a dense reference sampler.

    The block-tridiagonal precision J and linear term h of a random LDS are
    expanded into a dense (T*D, T*D) matrix. With J = L L^T (lower Cholesky
    factor L), L^{-T} z + J^{-1} h is a draw from N(J^{-1} h, J^{-1}) for
    standard-normal z. The same z is handed to lds_sample and the two
    (T, D) samples must agree.
    """
    As, bs, Qi_sqrts, ms, Ri_sqrts = make_lds_parameters(T, D)
    diag_blocks, lower_blocks, h = convert_lds_to_block_tridiag(
        As, bs, Qi_sqrts, ms, Ri_sqrts)

    def span(i):
        # Row/column range occupied by time step i in the flattened layout.
        return slice(i * D, (i + 1) * D)

    # Dense symmetric precision: diagonal blocks first, then each
    # sub-diagonal block together with its transpose above the diagonal.
    n = T * D
    J_full = np.zeros((n, n))
    for t in range(T):
        J_full[span(t), span(t)] = diag_blocks[t]
    for t in range(1, T):
        J_full[span(t), span(t - 1)] = lower_blocks[t - 1]
        J_full[span(t - 1), span(t)] = lower_blocks[t - 1].T

    z = npr.randn(n)

    # Dense reference sample: noise term plus mean term.
    chol = np.linalg.cholesky(J_full)
    noise_term = np.linalg.solve(chol.T, z)
    mean_term = np.linalg.solve(J_full, h.reshape(n))
    xtrue = (noise_term + mean_term).reshape(T, D)

    # Banded solver driven by the identical noise vector.
    xtest = lds_sample(As, bs, Qi_sqrts, ms, Ri_sqrts, z=z)

    assert np.allclose(xtrue, xtest)
示例#2
0
def test_lds_log_probability(T=25, D=4):
    """
    Check lds_log_probability against scipy's dense Gaussian log-density.

    The LDS's block-tridiagonal precision J and linear term h are expanded
    to dense form. The reference log-likelihood is the multivariate normal
    with covariance J^{-1} and mean J^{-1} h, evaluated at a random (T, D)
    point. It must match lds_log_probability at that same point.
    """
    from scipy.stats import multivariate_normal

    As, bs, Qi_sqrts, ms, Ri_sqrts = make_lds_parameters(T, D)
    diag_blocks, lower_blocks, h = convert_lds_to_block_tridiag(
        As, bs, Qi_sqrts, ms, Ri_sqrts)

    def span(i):
        # Row/column range occupied by time step i in the flattened layout.
        return slice(i * D, (i + 1) * D)

    # Dense symmetric precision built from its diagonal and sub-diagonal blocks.
    n = T * D
    J_full = np.zeros((n, n))
    for t in range(T):
        J_full[span(t), span(t)] = diag_blocks[t]
    for t in range(1, T):
        J_full[span(t), span(t - 1)] = lower_blocks[t - 1]
        J_full[span(t - 1), span(t)] = lower_blocks[t - 1].T

    cov = np.linalg.inv(J_full)
    mean = cov.dot(h.ravel()).reshape((T, D))
    x = npr.randn(T, D)

    ll_true = multivariate_normal.logpdf(x.ravel(), mean.ravel(), cov)
    ll_test = lds_log_probability(x, As, bs, Qi_sqrts, ms, Ri_sqrts)

    message = "True LL {} != Test LL {}".format(ll_true, ll_test)
    assert np.allclose(ll_true, ll_test), message
示例#3
0
def test_lds_mean(T=25, D=4):
    """
    Check lds_mean against the dense posterior mean J^{-1} h.

    The block-tridiagonal precision of a random LDS is expanded into a dense
    (T*D, T*D) matrix and inverted. The inverse is applied to the flattened
    linear term h, and the (T, D) result must match lds_mean for the same
    parameters.
    """
    As, bs, Qi_sqrts, ms, Ri_sqrts = make_lds_parameters(T, D)
    diag_blocks, lower_blocks, h = convert_lds_to_block_tridiag(
        As, bs, Qi_sqrts, ms, Ri_sqrts)

    # Walk the time steps once, writing each diagonal block and (except at
    # the final step) the coupling to the next step on both sides.
    n = T * D
    J_full = np.zeros((n, n))
    for t in range(T):
        lo, hi = t * D, (t + 1) * D
        J_full[lo:hi, lo:hi] = diag_blocks[t]
        if t + 1 < T:
            nxt = slice(hi, hi + D)
            J_full[lo:hi, nxt] = lower_blocks[t].T
            J_full[nxt, lo:hi] = lower_blocks[t]

    mu_true = np.linalg.inv(J_full).dot(h.ravel()).reshape((T, D))

    # Banded solver result for the same parameters.
    mu_test = lds_mean(As, bs, Qi_sqrts, ms, Ri_sqrts)

    assert np.allclose(mu_true, mu_test)
示例#4
0
def test_lds_log_probability_perf(T=1000, D=10, N_iter=10):
    """
    Compare performance of banded method vs message passing in pylds.

    Times three routines on the same randomly generated LDS and prints the
    mean wall-clock seconds per call for each:

    1. ``lds_log_probability`` (banded solver),
    2. pylds ``kalman_filter`` (covariance-form message passing),
    3. pylds ``kalman_info_filter`` (information-form message passing).

    Nothing is asserted; this is a benchmark, not a correctness test.

    Parameters
    ----------
    T : int
        Number of time steps.
    D : int
        Latent state dimension.
    N_iter : int
        Number of calls averaged per timing; must be >= 1.

    Raises
    ------
    ValueError
        If ``N_iter`` is less than 1 (the per-call average is undefined).
    """
    if N_iter < 1:
        raise ValueError("N_iter must be >= 1, got {}".format(N_iter))

    print("Comparing methods for T={} D={}".format(T, D))

    from pylds.lds_messages_interface import kalman_info_filter, kalman_filter

    def _report_time_per_iter(fn):
        """Call ``fn`` N_iter times and print the mean seconds per call."""
        # perf_counter is monotonic and high-resolution. time.time is
        # neither, and it can jump if the system clock is adjusted.
        start = time.perf_counter()
        for _ in range(N_iter):
            fn()
        stop = time.perf_counter()
        print("Time per iter: {:.4f}".format((stop - start) / N_iter))

    # Convert LDS parameters into info form for pylds
    As, bs, Qi_sqrts, ms, Ri_sqrts = make_lds_parameters(T, D)
    Qis = np.matmul(Qi_sqrts, np.swapaxes(Qi_sqrts, -1, -2))
    Ris = np.matmul(Ri_sqrts, np.swapaxes(Ri_sqrts, -1, -2))
    x = npr.randn(T, D)

    print("Timing banded method")
    _report_time_per_iter(
        lambda: lds_log_probability(x, As, bs, Qi_sqrts, ms, Ri_sqrts))

    # Compare to Kalman Filter (covariance form).
    mu_init = np.zeros(D)
    sigma_init = np.eye(D)
    Bs = np.ones((D, 1))
    sigma_states = np.linalg.inv(Qis)
    Cs = np.eye(D)
    Ds = np.zeros((D, 1))
    sigma_obs = np.linalg.inv(Ris)
    inputs = bs
    data = ms
    # An identity block is appended to the dynamics and to the state-noise
    # covariances. Presumably As/Qis hold one entry per transition (T-1) while
    # pylds wants one per time step (T); TODO confirm against pylds.
    # Built once here so the concatenation is not part of the timed loop.
    A_padded = np.concatenate([As, np.eye(D)[None, :, :]])
    sigma_states_padded = np.concatenate(
        [sigma_states, np.eye(D)[None, :, :]])

    print("Timing PyLDS message passing (kalman_filter)")
    _report_time_per_iter(
        lambda: kalman_filter(mu_init, sigma_init, A_padded, Bs,
                              sigma_states_padded, Cs, Ds, sigma_obs,
                              inputs, data))

    # Info form comparison
    J_init = np.zeros((D, D))
    h_init = np.zeros(D)
    log_Z_init = 0

    J_diag, J_lower_diag, h = convert_lds_to_block_tridiag(
        As, bs, Qi_sqrts, ms, Ri_sqrts)
    # Diagonal terms for steps 1..T-1 go in J_pair_22 / h_pair_2. Only step 0's
    # diagonal term is kept in J_pair_11 / h_pair_1, so no node block is
    # counted twice.
    # .copy() is required: basic slicing returns a view, and zeroing that view
    # in place would also zero J_diag / h, and with them the overlapping
    # entries of J_pair_22 / h_pair_2.
    J_pair_21 = J_lower_diag
    J_pair_22 = J_diag[1:]
    J_pair_11 = J_diag[:-1].copy()
    J_pair_11[1:] = 0
    h_pair_2 = h[1:]
    h_pair_1 = h[:-1].copy()
    h_pair_1[1:] = 0
    log_Z_pair = 0

    # All node potentials are zero; the information lives in the pair terms.
    J_node = np.zeros((T, D, D))
    h_node = np.zeros((T, D))
    log_Z_node = 0

    print("Timing PyLDS message passing (kalman_info_filter)")
    _report_time_per_iter(
        lambda: kalman_info_filter(J_init, h_init, log_Z_init,
                                   J_pair_11, J_pair_21, J_pair_22,
                                   h_pair_1, h_pair_2, log_Z_pair,
                                   J_node, h_node, log_Z_node))
示例#5
0
def _marginal_std_from_blocks(J_diag, J_lower_diag, shape):
    """Return per-entry marginal std devs, reshaped to ``shape``, from a block-tridiagonal precision."""
    J = blocks_to_full(J_diag, J_lower_diag)
    return np.sqrt(np.diag(np.linalg.inv(J))).reshape(shape)


def plot_trial(model, q, posterior, tr=0, legend=False):
    """
    Plot true vs. inferred states, inputs and observations for one trial.

    Creates a 4-row figure:

    * row 0: ``zs[tr]`` stacked above ``model.most_likely_states(...)``
      (labelled z_true / z_inf; the axis is hidden),
    * row 1: the first two input channels ``us[tr]`` scaled by 0.2,
    * row 2: the first two dimensions of ``xs[tr]`` together with the posterior
      mean ``q_x`` and a +/- 2 std band,
    * row 3: event raster of where ``ys[tr][:, n] > 0`` for channels 0..9.

    The function reads the module-level globals ``xs``, ``ys``, ``us``, ``zs``
    (per-trial arrays, indexed by ``tr``), and also ``plt``, ``sns``,
    ``blocks_to_full`` and ``convert_lds_to_block_tridiag``.

    Parameters
    ----------
    model : object
        Model exposing ``smooth`` and ``most_likely_states``.
    q : object
        Variational posterior. The attributes read depend on ``posterior``.
    posterior : str
        One of ``"laplace_em"``, ``"mf"`` or ``"lds"``.
    tr : int
        Trial index.
    legend : bool
        Whether to draw a legend on the latent-state panel.

    Raises
    ------
    ValueError
        If ``posterior`` is not one of the supported names.
    """
    # Strings must be compared with ==, not `is` (identity of str literals is
    # an implementation detail).
    if posterior == "laplace_em":
        q_x = q.mean_continuous_states[tr]
        q_std = _marginal_std_from_blocks(q._params[tr]["J_diag"],
                                          q._params[tr]["J_lower_diag"],
                                          np.shape(q_x))
    elif posterior == "mf":
        q_x = q.mean[tr]
        # params[tr][1] is exponentiated before the sqrt, i.e. it is treated
        # as a log-variance.
        q_std = np.sqrt(np.exp(q.params[tr][1]))
    elif posterior == "lds":
        q_x = q.mean[tr]
        J_diag, J_lower_diag, _ = convert_lds_to_block_tridiag(*q.params[tr])
        q_std = _marginal_std_from_blocks(J_diag, J_lower_diag, np.shape(q_x))
    else:
        raise ValueError(
            "posterior must be 'laplace_em', 'mf' or 'lds', got {!r}".format(
                posterior))
    q_std_1 = q_std[:, 0]
    q_std_2 = q_std[:, 1]
    # Number of time steps, taken from the posterior mean, not a global.
    T = np.shape(q_x)[0]

    yhat = model.smooth(q_x, ys[tr], input=us[tr])
    zhat = model.most_likely_states(q_x, ys[tr], input=us[tr])

    f, (a0, a1, a2,
        a3) = plt.subplots(4,
                           1,
                           gridspec_kw={'height_ratios': [0.5, 0.5, 3, 1]})
    # np.row_stack is a deprecated alias of np.vstack (NumPy 2.0).
    a0.imshow(np.vstack((zs[tr], zhat)), aspect="auto", vmin=0, vmax=2)
    a0.set_xticks([])
    a0.set_yticks([0, 1], ["$z_{\\mathrm{true}}$", "$z_{\\mathrm{inf}}$"])
    a0.axis("off")
    a2.plot(xs[tr][:, 0], color=[1.0, 0.0, 0.0], label="$x_1$", alpha=0.9)
    a2.plot(xs[tr][:, 1], color=[0.0, 0.0, 1.0], label="$x_2$", alpha=0.9)
    # Raw strings: "\h" is not a valid escape sequence.
    a2.plot(q_x[:, 0],
            color=[1.0, 0.3, 0.3],
            linestyle='--',
            label=r"$\hat{x}_1$",
            alpha=0.9)
    a2.plot(q_x[:, 1],
            color=[0.3, 0.3, 1.0],
            linestyle='--',
            label=r"$\hat{x}_2$",
            alpha=0.9)

    # +/- 2 posterior standard deviations around the posterior mean.
    a2.fill_between(np.arange(T),
                    q_x[:, 0] - q_std_1 * 2.0,
                    q_x[:, 0] + q_std_1 * 2.0,
                    facecolor='r',
                    alpha=0.3)
    a2.fill_between(np.arange(T),
                    q_x[:, 1] - q_std_2 * 2.0,
                    q_x[:, 1] + q_std_2 * 2.0,
                    facecolor='b',
                    alpha=0.3)
    a2.plot(np.array([0, 100]),
            np.array([1, 1]),
            'k--',
            linewidth=1.0,
            label=None)
    a2.set_ylim([-0.4, 1.4])
    a2.set_xlim([-1, 101])
    a2.set_xticks([])
    a2.set_yticks([0, 1])
    a2.set_ylabel("x")
    if legend:
        a2.legend()
    sns.despine()
    for n in range(10):
        a3.eventplot(np.where(ys[tr][:, n] > 0)[0],
                     linelengths=0.5,
                     lineoffsets=1 + n,
                     color='k')
    sns.despine()
    a3.set_yticks([])
    a3.set_xlim([-1, 101])

    a1.plot(0.2 * us[tr][:, 0], color=[1.0, 0.5, 0.5], label=None, alpha=0.9)
    a1.plot(0.2 * us[tr][:, 1], color=[0.5, 0.5, 1.0], label=None, alpha=0.9)
    a1.set_yticks([])
    a1.set_xticks([])
    a1.axes.get_yaxis().set_visible(False)
    plt.tight_layout()