def test_lds_sample(T=25, D=4):
    """
    Check lds_sample against a dense-matrix reference sampler.

    The banded precision (J_diag, J_lower_diag) and linear term h returned by
    convert_lds_to_block_tridiag are expanded into a full (T*D, T*D) matrix J.
    With L the lower Cholesky factor of J and a shared standard-normal vector
    z, the reference draw is L^{-T} z + J^{-1} h (reshaped to (T, D)). It must
    match lds_sample(..., z=z).
    """
    As, bs, Qi_sqrts, ms, Ri_sqrts = make_lds_parameters(T, D)
    J_diag, J_lower_diag, h = convert_lds_to_block_tridiag(
        As, bs, Qi_sqrts, ms, Ri_sqrts)

    # Row/column range of each time step's D x D block in the dense matrix.
    spans = [slice(t * D, (t + 1) * D) for t in range(T)]

    J_dense = np.zeros((T * D, T * D))
    for t, span in enumerate(spans):
        J_dense[span, span] = J_diag[t]
    # Sub-diagonal blocks sit below the diagonal; their transposes mirror
    # them above it, so the dense matrix is symmetric.
    for t, (cur, nxt) in enumerate(zip(spans[:-1], spans[1:])):
        J_dense[nxt, cur] = J_lower_diag[t]
        J_dense[cur, nxt] = J_lower_diag[t].T

    noise = npr.randn(T * D)

    # Dense reference: L^{-T} z has covariance (L L^T)^{-1} = J^{-1};
    # adding J^{-1} h shifts it to the correct mean.
    chol = np.linalg.cholesky(J_dense)
    zero_mean_draw = np.linalg.solve(chol.T, noise)
    mean = np.linalg.solve(J_dense, h.reshape(T * D))
    x_reference = (zero_mean_draw + mean).reshape(T, D)

    # Banded solver, fed the same noise vector.
    x_banded = lds_sample(As, bs, Qi_sqrts, ms, Ri_sqrts, z=noise)
    assert np.allclose(x_reference, x_banded)
def test_lds_log_probability(T=25, D=4):
    """
    Check lds_log_probability against scipy's dense multivariate normal.

    The block-tridiagonal precision is expanded into a full matrix J. The
    reference log density is that of N(J^{-1} h, J^{-1}), evaluated at a
    random (T, D) point flattened to length T*D.
    """
    from scipy.stats import multivariate_normal

    As, bs, Qi_sqrts, ms, Ri_sqrts = make_lds_parameters(T, D)
    J_diag, J_lower_diag, h = convert_lds_to_block_tridiag(
        As, bs, Qi_sqrts, ms, Ri_sqrts)

    # Fill the dense precision one time step at a time: the diagonal block,
    # plus (except at the last step) the coupling to the next step on both
    # sides of the diagonal.
    n = T * D
    J_dense = np.zeros((n, n))
    for t in range(T):
        lo, hi = t * D, (t + 1) * D
        J_dense[lo:hi, lo:hi] = J_diag[t]
        if t + 1 < T:
            J_dense[hi:hi + D, lo:hi] = J_lower_diag[t]
            J_dense[lo:hi, hi:hi + D] = J_lower_diag[t].T

    cov = np.linalg.inv(J_dense)
    mean = cov @ h.ravel()
    x = npr.randn(T, D)

    ll_true = multivariate_normal.logpdf(x.ravel(), mean, cov)
    # Banded solver
    ll_test = lds_log_probability(x, As, bs, Qi_sqrts, ms, Ri_sqrts)
    assert np.allclose(ll_true, ll_test), \
        "True LL {} != Test LL {}".format(ll_true, ll_test)
def test_lds_mean(T=25, D=4):
    """
    Check lds_mean against the dense posterior mean J^{-1} h.

    J is the full (T*D, T*D) precision assembled from the block-tridiagonal
    pieces returned by convert_lds_to_block_tridiag.
    """
    As, bs, Qi_sqrts, ms, Ri_sqrts = make_lds_parameters(T, D)
    J_diag, J_lower_diag, h = convert_lds_to_block_tridiag(
        As, bs, Qi_sqrts, ms, Ri_sqrts)

    size = T * D
    J_dense = np.zeros((size, size))
    for t in range(T):
        diag = slice(t * D, (t + 1) * D)
        J_dense[diag, diag] = J_diag[t]
    for t in range(T - 1):
        here = slice(t * D, (t + 1) * D)
        after = slice((t + 1) * D, (t + 2) * D)
        # Lower block below the diagonal, its transpose above it.
        J_dense[after, here] = J_lower_diag[t]
        J_dense[here, after] = J_lower_diag[t].T

    mu_true = np.linalg.inv(J_dense).dot(h.ravel()).reshape((T, D))
    # Banded solver
    mu_test = lds_mean(As, bs, Qi_sqrts, ms, Ri_sqrts)
    assert np.allclose(mu_true, mu_test)
def test_lds_log_probability_perf(T=1000, D=10, N_iter=10):
    """
    Compare performance of banded method vs message passing in pylds.

    Builds one random LDS (via make_lds_parameters) of length ``T`` with
    state dimension ``D``. It then times ``N_iter`` calls of each of:

    1. ``lds_log_probability`` (banded block-tridiagonal method),
    2. pylds ``kalman_filter`` (covariance form),
    3. pylds ``kalman_info_filter`` (information form),

    and prints the mean wall-clock time per call. This is a benchmark: it
    asserts nothing, and it requires ``pylds`` to be importable.
    """
    print("Comparing methods for T={} D={}".format(T, D))
    from pylds.lds_messages_interface import kalman_info_filter, kalman_filter

    def _time_per_iter(fn):
        # Call fn() N_iter times and print the mean duration. perf_counter is
        # monotonic and high-resolution, which time.time() does not guarantee.
        start = time.perf_counter()
        for _ in range(N_iter):
            fn()
        stop = time.perf_counter()
        print("Time per iter: {:.4f}".format((stop - start) / N_iter))

    # Convert LDS parameters into info form for pylds
    As, bs, Qi_sqrts, ms, Ri_sqrts = make_lds_parameters(T, D)
    # Qi = Qi_sqrt @ Qi_sqrt^T (likewise for Ri), applied over leading axes.
    Qis = np.matmul(Qi_sqrts, np.swapaxes(Qi_sqrts, -1, -2))
    Ris = np.matmul(Ri_sqrts, np.swapaxes(Ri_sqrts, -1, -2))
    x = npr.randn(T, D)

    print("Timing banded method")
    _time_per_iter(
        lambda: lds_log_probability(x, As, bs, Qi_sqrts, ms, Ri_sqrts))

    # Compare to Kalman Filter (covariance form)
    mu_init = np.zeros(D)
    sigma_init = np.eye(D)
    Bs = np.ones((D, 1))
    sigma_states = np.linalg.inv(Qis)
    Cs = np.eye(D)
    Ds = np.zeros((D, 1))
    sigma_obs = np.linalg.inv(Ris)
    inputs = bs
    data = ms
    # One identity block is appended to the dynamics matrices and state
    # covariances (presumably so pylds sees one entry per time step -- TODO
    # confirm). This padding is loop-invariant, so it is built once here
    # rather than inside the timed loop.
    eye_block = np.eye(D)[None, :, :]
    A_padded = np.concatenate([As, eye_block])
    sigma_states_padded = np.concatenate([sigma_states, eye_block])

    print("Timing PyLDS message passing (kalman_filter)")
    _time_per_iter(lambda: kalman_filter(
        mu_init, sigma_init, A_padded, Bs, sigma_states_padded,
        Cs, Ds, sigma_obs, inputs, data))

    # Info form comparison
    J_init = np.zeros((D, D))
    h_init = np.zeros(D)
    log_Z_init = 0
    J_diag, J_lower_diag, h = convert_lds_to_block_tridiag(
        As, bs, Qi_sqrts, ms, Ri_sqrts)
    # Split the block-tridiagonal precision into pairwise potentials. Pair t
    # takes J_diag[t+1] in its "22" slot, while only the first pair keeps
    # J_diag[0] in its "11" slot (the rest are zeroed), apparently so that
    # each diagonal block is counted once.
    # The .copy() calls are essential: J_diag[:-1] and J_diag[1:] are views
    # of the same buffer, so zeroing J_pair_11[1:] in place would also zero
    # most of J_pair_22 (and likewise h_pair_1 would clobber h_pair_2).
    J_pair_21 = J_lower_diag
    J_pair_22 = J_diag[1:]
    J_pair_11 = J_diag[:-1].copy()
    J_pair_11[1:] = 0
    h_pair_2 = h[1:]
    h_pair_1 = h[:-1].copy()
    h_pair_1[1:] = 0
    log_Z_pair = 0
    J_node = np.zeros((T, D, D))
    h_node = np.zeros((T, D))
    log_Z_node = 0

    print("Timing PyLDS message passing (kalman_info_filter)")
    _time_per_iter(lambda: kalman_info_filter(
        J_init, h_init, log_Z_init,
        J_pair_11, J_pair_21, J_pair_22, h_pair_1, h_pair_2, log_Z_pair,
        J_node, h_node, log_Z_node))
def plot_trial(model, q, posterior, tr=0, legend=False):
    """
    Plot the true and inferred latent trajectories for a single trial.

    Produces a figure with four stacked panels:

    * a0 -- true (``zs[tr]``) vs. most-likely (inferred) discrete states;
    * a1 -- the first two input channels of ``us[tr]``, scaled by 0.2;
    * a2 -- the first two dimensions of the true continuous states
      ``xs[tr]``, the posterior mean, and a +/- 2 standard-deviation band;
    * a3 -- an event raster of the time steps where each of the first 10
      columns of ``ys[tr]`` is positive.

    Note: this reads the module-level globals ``xs``, ``ys``, ``us`` and
    ``zs``, along with ``plt``, ``sns``, ``blocks_to_full`` and
    ``convert_lds_to_block_tridiag``, all defined elsewhere in this file.

    Parameters
    ----------
    model
        Object exposing ``most_likely_states(x, y, input=...)``.
    q
        Variational posterior. For ``"laplace_em"`` it must provide
        ``mean_continuous_states`` and ``_params[tr]["J_diag"]`` /
        ``_params[tr]["J_lower_diag"]``. For ``"mf"`` and ``"lds"`` it must
        provide ``mean`` and ``params``.
    posterior : str
        Posterior family: ``"laplace_em"``, ``"mf"`` or ``"lds"``.
    tr : int
        Index of the trial to plot.
    legend : bool
        If True, draw a legend on the continuous-state panel.

    Raises
    ------
    ValueError
        If ``posterior`` is not one of the supported names.
    """
    q_x, q_std = _posterior_mean_and_std(q, posterior, tr)
    q_std_1 = q_std[:, 0]
    q_std_2 = q_std[:, 1]

    zhat = model.most_likely_states(q_x, ys[tr], input=us[tr])

    f, (a0, a1, a2, a3) = plt.subplots(
        4, 1, gridspec_kw={'height_ratios': [0.5, 0.5, 3, 1]})

    # Discrete states: row 0 is the ground truth, row 1 the inferred sequence.
    a0.imshow(np.vstack((zs[tr], zhat)), aspect="auto", vmin=0, vmax=2)
    a0.set_xticks([])
    a0.set_yticks([0, 1], ["$z_{\\mathrm{true}}$", "$z_{\\mathrm{inf}}$"])
    a0.axis("off")

    # Continuous states: truth (solid), posterior mean (dashed), +/- 2 std band.
    time_axis = np.arange(len(q_x))
    a2.plot(xs[tr][:, 0], color=[1.0, 0.0, 0.0], label="$x_1$", alpha=0.9)
    a2.plot(xs[tr][:, 1], color=[0.0, 0.0, 1.0], label="$x_2$", alpha=0.9)
    a2.plot(q_x[:, 0], color=[1.0, 0.3, 0.3], linestyle='--',
            label=r"$\hat{x}_1$", alpha=0.9)
    a2.plot(q_x[:, 1], color=[0.3, 0.3, 1.0], linestyle='--',
            label=r"$\hat{x}_2$", alpha=0.9)
    a2.fill_between(time_axis,
                    q_x[:, 0] - q_std_1 * 2.0,
                    q_x[:, 0] + q_std_1 * 2.0,
                    facecolor='r', alpha=0.3)
    a2.fill_between(time_axis,
                    q_x[:, 1] - q_std_2 * 2.0,
                    q_x[:, 1] + q_std_2 * 2.0,
                    facecolor='b', alpha=0.3)
    a2.plot(np.array([0, 100]), np.array([1, 1]), 'k--', linewidth=1.0,
            label=None)
    a2.set_ylim([-0.4, 1.4])
    a2.set_xlim([-1, 101])
    a2.set_xticks([])
    a2.set_yticks([0, 1])
    a2.set_ylabel("x")
    if legend:
        a2.legend()
    sns.despine()

    # Raster of positive observations for the first 10 channels.
    for n in range(10):
        a3.eventplot(np.where(ys[tr][:, n] > 0)[0],
                     linelengths=0.5, lineoffsets=1 + n, color='k')
    sns.despine()
    a3.set_yticks([])
    a3.set_xlim([-1, 101])

    # Inputs, scaled down by 0.2.
    a1.plot(0.2 * us[tr][:, 0], color=[1.0, 0.5, 0.5], label=None, alpha=0.9)
    a1.plot(0.2 * us[tr][:, 1], color=[0.5, 0.5, 1.0], label=None, alpha=0.9)
    a1.set_yticks([])
    a1.set_xticks([])
    a1.axes.get_yaxis().set_visible(False)
    plt.tight_layout()


def _marginal_stds_from_blocks(J_diag, J_lower_diag, shape):
    """Return per-element marginal std devs of a block-tridiagonal-precision Gaussian, reshaped to ``shape``."""
    J = blocks_to_full(J_diag, J_lower_diag)
    return np.sqrt(np.diag(np.linalg.inv(J))).reshape(shape)


def _posterior_mean_and_std(q, posterior, tr):
    """Return ``(mean, std)`` arrays of trial ``tr`` for the given posterior family."""
    if posterior == "laplace_em":
        q_x = q.mean_continuous_states[tr]
        q_std = _marginal_stds_from_blocks(
            q._params[tr]["J_diag"], q._params[tr]["J_lower_diag"], q_x.shape)
    elif posterior == "mf":
        q_x = q.mean[tr]
        # params[tr][1] is exponentiated to obtain variances (presumably it
        # stores log-variances -- confirm against the posterior class).
        q_std = np.sqrt(np.exp(q.params[tr][1]))
    elif posterior == "lds":
        q_x = q.mean[tr]
        # Use the posterior passed in (not a global) so mean and std agree.
        J_diag, J_lower_diag, _ = convert_lds_to_block_tridiag(*q.params[tr])
        q_std = _marginal_stds_from_blocks(J_diag, J_lower_diag, q_x.shape)
    else:
        raise ValueError(
            "posterior must be 'laplace_em', 'mf' or 'lds', got {!r}".format(
                posterior))
    return q_x, q_std