def hmm_expected_states(log_pi0, log_Ps, ll):
    """Return posterior state expectations computed from forward/backward messages.

    Parameters
    ----------
    log_pi0 : array
        Handed to ``forward_pass`` as its first argument (presumably the log
        initial-state distribution -- confirm against that helper).
    log_Ps : array
        Must broadcast against a ``(T-1, K, K)`` array; added to the pairwise
        log terms below (presumably log transition matrices).
    ll : array, shape (T, K)
        Its shape defines ``T`` and ``K`` (presumably per-step, per-state log
        likelihoods).

    Returns
    -------
    expected_states : array, shape (T, K)
        ``exp(alphas + betas)`` normalized so every row sums to one.
    expected_joints : array, shape (T-1, K, K)
        Pairwise terms normalized so every time slice sums to one.
    normalizer : scalar
        ``logsumexp`` of the last row of the forward messages.
    """
    def _c_ordered(arr):
        # Copy only when the array is not already laid out in C order.
        if arr.flags['C_CONTIGUOUS']:
            return arr
        return np.copy(arr, 'C')

    T, K = ll.shape

    # Unwrap with getval first, then make sure everything is C contiguous.
    log_pi0 = _c_ordered(getval(log_pi0))
    log_Ps = _c_ordered(getval(log_Ps))
    ll = _c_ordered(getval(ll))

    # forward_pass / backward_pass write their results into the buffers
    # supplied as the last argument.
    alphas = np.zeros((T, K))
    forward_pass(log_pi0, log_Ps, ll, alphas)
    normalizer = logsumexp(alphas[-1])

    betas = np.zeros((T, K))
    backward_pass(log_Ps, ll, betas)

    # Single-step posteriors: normalize each row in log space, then exponentiate.
    log_marginals = alphas + betas
    log_marginals = log_marginals - logsumexp(log_marginals, axis=1, keepdims=True)
    expected_states = np.exp(log_marginals)

    # Pairwise posteriors over consecutive steps (t, t+1).
    log_pairs = (alphas[:-1, :, None]
                 + betas[1:, None, :]
                 + ll[1:, None, :]
                 + log_Ps)
    # Shift by the per-slice maximum before exponentiating for stability.
    log_pairs -= log_pairs.max(axis=(1, 2), keepdims=True)
    expected_joints = np.exp(log_pairs)
    expected_joints /= expected_joints.sum(axis=(1, 2), keepdims=True)

    return expected_states, expected_joints, normalizer
def hmm_expected_states(log_pi0, log_Ps, ll):
    """Return posterior state expectations computed from forward/backward messages.

    Parameters
    ----------
    log_pi0 : array
        Handed to ``forward_pass`` as its first argument (presumably the log
        initial-state distribution -- confirm against that helper).
    log_Ps : array
        Must broadcast against a ``(T-1, K, K)`` array; added to the pairwise
        log terms below (presumably log transition matrices).
    ll : array, shape (T, K)
        Its shape defines ``T`` and ``K`` (presumably per-step, per-state log
        likelihoods).

    Returns
    -------
    expected_states : array, shape (T, K)
        ``exp(alphas + betas)`` normalized so every row sums to one.
    expected_joints : array, shape (T-1, K, K)
        Pairwise terms normalized so every time slice sums to one.
    normalizer : scalar
        ``logsumexp`` of the last row of the forward messages.
    """
    T, K = ll.shape

    # Make sure everything is C contiguous
    log_pi0, log_Ps, ll = to_c(log_pi0), to_c(log_Ps), to_c(ll)

    # Both passes write into the preallocated (T, K) buffers.
    alphas = np.zeros((T, K))
    forward_pass(log_pi0, log_Ps, ll, alphas)
    normalizer = logsumexp(alphas[-1])

    betas = np.zeros((T, K))
    backward_pass(log_Ps, ll, betas)

    # Single-step posteriors, normalized per row in log space.
    log_marginals = alphas + betas
    log_marginals = log_marginals - logsumexp(log_marginals, axis=1, keepdims=True)
    expected_states = np.exp(log_marginals)

    # Pairwise posteriors over consecutive steps (t, t+1).
    log_pairs = (alphas[:-1, :, None]
                 + betas[1:, None, :]
                 + ll[1:, None, :]
                 + log_Ps)
    # Shift by the per-slice maximum before exponentiating for stability.
    log_pairs -= log_pairs.max(axis=(1, 2), keepdims=True)
    expected_joints = np.exp(log_pairs)
    expected_joints /= expected_joints.sum(axis=(1, 2), keepdims=True)

    return expected_states, expected_joints, normalizer
def hmm_expected_states(log_pi0, log_Ps, ll, memlimit=2**31):
    """Compute posterior state expectations from forward/backward messages.

    Parameters
    ----------
    log_pi0 : array
        Handed to ``forward_pass`` as its first argument (presumably the log
        initial-state distribution -- confirm against that helper).
    log_Ps : array, shape (1, K, K) or (T-1, K, K)
        Log transition matrices. A leading dimension of 1 is treated as a
        stationary (time-invariant) transition matrix.
    ll : array, shape (T, K)
        Its shape defines ``T`` and ``K`` (presumably per-step, per-state log
        likelihoods).
    memlimit : int, optional
        Approximate byte budget for one batch of pairwise expectations in the
        stationary case, assuming float64 entries.  The default ``2**31`` is
        2 GiB.  If the budget is smaller than a single ``K x K`` slice, one
        time step is processed per batch.

    Returns
    -------
    expected_states : array, shape (T, K)
        ``exp(alphas + betas)`` normalized so every row sums to one.
    expected_joints : array
        Shape ``(T-1, K, K)`` with every time slice normalized to sum to one
        when ``log_Ps`` is not stationary; otherwise shape ``(1, K, K)``
        holding the sum over time of those normalized slices.
    normalizer : scalar
        ``logsumexp`` of the last row of the forward messages.
    """
    T, K = ll.shape

    # Make sure everything is C contiguous
    log_pi0 = to_c(log_pi0)
    log_Ps = to_c(log_Ps)
    ll = to_c(ll)

    alphas = np.zeros((T, K))
    forward_pass(log_pi0, log_Ps, ll, alphas)
    normalizer = logsumexp(alphas[-1])

    betas = np.zeros((T, K))
    backward_pass(log_Ps, ll, betas)

    # Compute E[z_t] for t = 1, ..., T
    expected_states = alphas + betas
    expected_states -= logsumexp(expected_states, axis=1, keepdims=True)
    expected_states = np.exp(expected_states)

    # Compute E[z_t, z_{t+1}] for t = 1, ..., T-1
    # Note that this is an array of size (T-1)*K*K, which can be quite large.
    # To be a bit more frugal with memory, first check if the given log_Ps
    # has one matrix per transition. If so, instantiate the full expected
    # joints as well, since we will need them for the M-step. However, if
    # log_Ps is 1xKxK then we know that the transition matrix is stationary,
    # and all we need for the M-step is the sum of the expected joints.
    stationary = (log_Ps.shape[0] == 1)
    if not stationary:
        expected_joints = (alphas[:-1, :, None]
                           + betas[1:, None, :]
                           + ll[1:, None, :]
                           + log_Ps)
        expected_joints -= expected_joints.max((1, 2))[:, None, None]
        expected_joints = np.exp(expected_joints)
        expected_joints /= expected_joints.sum((1, 2))[:, None, None]
    else:
        # Compute the sum over the time axis of the expected joints in
        # batches of roughly `memlimit` bytes, assuming the entries are
        # float64's (8 bytes each).
        bytes_per_step = 8 * K * K
        # Clamp to one step per batch instead of asserting: an assert is
        # stripped under `python -O`, after which a zero step would make
        # range() fail with an unrelated error.
        batch_size = max(1, int(memlimit // bytes_per_step))

        expected_joints = np.zeros((1, K, K))
        for start in range(0, T - 1, batch_size):
            stop = min(T - 1, start + batch_size)

            # Compute expectations in this batch
            tmp = (alphas[start:stop, :, None]
                   + betas[start + 1:stop + 1, None, :]
                   + ll[start + 1:stop + 1, None, :]
                   + log_Ps)
            tmp -= tmp.max((1, 2))[:, None, None]
            tmp = np.exp(tmp)
            tmp /= tmp.sum((1, 2))[:, None, None]
            expected_joints += tmp.sum(axis=0)

    return expected_states, expected_joints, normalizer
def test_backward_pass(T=1000, K=5, D=2):
    """Check that ``backward_pass`` matches pyhsmm's ``messages_backwards_log``.

    A random row-normalized ``K x K`` matrix and a random ``T x K`` array are
    fed to both implementations and the filled-in outputs are compared with
    ``np.allclose``.  ``D`` is accepted but not used.
    """
    from pyhsmm.internals.hmm_messages_interface import messages_backwards_log

    # Random parameters: each row of the matrix is scaled to sum to one.
    trans = npr.rand(K, K)
    trans /= trans.sum(axis=-1, keepdims=True)
    log_likes = npr.randn(T, K)

    # Reference result from pyhsmm, written into the output buffer.
    expected = np.zeros((T, K))
    messages_backwards_log(trans, log_likes, expected)

    # Result from ssm; it is given the matrix with a leading axis of length 1.
    actual = np.zeros((T, K))
    backward_pass(trans[None, :, :], log_likes, actual)

    assert np.allclose(expected, actual)
def hmm_expected_states(log_pi0, log_Ps, ll, memlimit=2**31):
    """Return posterior state expectations computed from forward/backward messages.

    Parameters
    ----------
    log_pi0 : array
        Handed to ``forward_pass`` as its first argument (presumably the log
        initial-state distribution -- confirm against that helper).
    log_Ps : array, shape (1, K, K) or (T-1, K, K)
        Log transition matrices. A leading dimension of 1 is treated as a
        stationary (time-invariant) transition matrix.
    ll : array, shape (T, K)
        Its shape defines ``T`` and ``K`` (presumably per-step, per-state log
        likelihoods).
    memlimit : int, optional
        Accepted for interface compatibility; not read by this function.

    Returns
    -------
    expected_states : array, shape (T, K)
        ``exp(alphas + betas)`` normalized so every row sums to one.
    expected_joints : array
        Shape ``(T-1, K, K)`` with every time slice normalized to sum to one
        when ``log_Ps`` is not stationary; otherwise shape ``(1, K, K)``
        wrapping the ``K x K`` buffer filled by
        ``compute_stationary_expected_joints``.
    normalizer : scalar
        ``logsumexp`` of the last row of the forward messages.
    """
    T, K = ll.shape

    # Make sure everything is C contiguous
    log_pi0, log_Ps, ll = to_c(log_pi0), to_c(log_Ps), to_c(ll)

    # Both passes write into the preallocated (T, K) buffers.
    alphas = np.zeros((T, K))
    forward_pass(log_pi0, log_Ps, ll, alphas)
    normalizer = logsumexp(alphas[-1])

    betas = np.zeros((T, K))
    backward_pass(log_Ps, ll, betas)

    # E[z_t] for t = 1, ..., T: normalize each row in log space.
    log_marginals = alphas + betas
    log_marginals = log_marginals - logsumexp(log_marginals, axis=1, keepdims=True)
    expected_states = np.exp(log_marginals)

    # E[z_t, z_{t+1}] for t = 1, ..., T-1.  The full array has (T-1)*K*K
    # entries, which can be large, so it is only materialized when log_Ps
    # carries one matrix per transition (needed for the M-step).  With a
    # single 1xKxK matrix the transitions are stationary and only the sum
    # over time of the expected joints is required.
    if log_Ps.shape[0] == 1:
        # The helper writes its result into the supplied K x K buffer.
        joint_sum = np.zeros((K, K))
        compute_stationary_expected_joints(alphas, betas, ll, log_Ps[0], joint_sum)
        # Restore the leading axis so the result is 1 x K x K.
        expected_joints = joint_sum[None, :, :]
    else:
        log_pairs = (alphas[:-1, :, None]
                     + betas[1:, None, :]
                     + ll[1:, None, :]
                     + log_Ps)
        # Shift by the per-slice maximum before exponentiating for stability.
        log_pairs -= log_pairs.max(axis=(1, 2), keepdims=True)
        expected_joints = np.exp(log_pairs)
        expected_joints /= expected_joints.sum(axis=(1, 2), keepdims=True)

    return expected_states, expected_joints, normalizer
def test_hmm_mp_perf(T=10000, K=100, D=20):
    """Time PyHSMM and SSM message passing on random inputs and compare outputs.

    The forward and the backward pass are each run with both libraries on the
    same arrays; wall-clock durations are printed and the filled-in output
    buffers are compared with ``np.allclose``.  ``D`` is accepted but not used.
    """
    # Random parameters: uniform pi0 and T-1 row-normalized K x K matrices.
    pi0 = np.ones(K) / K
    Ps = npr.rand(T - 1, K, K)
    Ps /= Ps.sum(axis=2, keepdims=True)
    ll = npr.randn(T, K)
    pyhsmm_out = np.zeros((T, K))
    ssm_out = np.zeros((T, K))

    # ---- Forward pass: PyHSMM ----
    from pyhsmm.internals.hmm_messages_interface import messages_forwards_log, messages_backwards_log
    started = time()
    messages_forwards_log(Ps, ll, pi0, pyhsmm_out)
    elapsed_pyhsmm = time() - started
    print("PyHSMM Fwd: ", elapsed_pyhsmm, "sec")

    # ---- Forward pass: SSM ----
    from ssm.messages import forward_pass, backward_pass
    # Call once to compile, then time the second call.
    forward_pass(pi0, Ps, ll, ssm_out)
    started = time()
    forward_pass(pi0, Ps, ll, ssm_out)
    elapsed_ssm = time() - started
    print("SMM Fwd: ", elapsed_ssm, "sec")
    assert np.allclose(pyhsmm_out, ssm_out)

    # ---- Backward pass: PyHSMM ----
    started = time()
    messages_backwards_log(Ps, ll, pyhsmm_out)
    elapsed_pyhsmm = time() - started
    print("PyHSMM Bwd: ", elapsed_pyhsmm, "sec")

    # ---- Backward pass: SSM ----
    # Call once to compile, then time the second call.
    backward_pass(Ps, ll, ssm_out)
    started = time()
    backward_pass(Ps, ll, ssm_out)
    elapsed_ssm = time() - started
    print("SMM (Numba) Bwd: ", elapsed_ssm, "sec")
    assert np.allclose(pyhsmm_out, ssm_out)