# NOTE(review): this physical line is several statements collapsed together
# (line breaks were lost). Only the leading
#   return vmap(hmm_loglikelihood_jax, in_axes=(None, 0, 0))(params_jax, batches, lens)
# is live code: it vmaps hmm_loglikelihood_jax over axis 0 of `batches` and
# `lens` while sharing the same `params_jax` (in_axes=None) across the batch.
# Everything after the first `#` is parsed as a comment and never executes.
# That includes the fair/loaded-die A, B, pi definitions, the HMMNumpy/HMMJax
# params, sampling, padding and minibatching. The trailing
# hmm_sample_minibatches(observations, lens, ...) call is also cut off here.
# TODO: restore the original one-statement-per-line layout.
return vmap(hmm_loglikelihood_jax, in_axes=(None, 0, 0))(params_jax, batches, lens) # state transition matrix A = jnp.array([[0.95, 0.05], [0.10, 0.90]]) # observation matrix B = jnp.array([ [1 / 6, 1 / 6, 1 / 6, 1 / 6, 1 / 6, 1 / 6], # fair die [1 / 10, 1 / 10, 1 / 10, 1 / 10, 1 / 10, 5 / 10] # loaded die ]) pi = jnp.array([1, 1]) / 2 params_numpy = HMMNumpy(np.array(A), np.array(B), np.array(pi)) params_jax = HMMJax(A, B, pi) seed = 0 rng_key = PRNGKey(seed) rng_key, rng_sample = split(rng_key) n_obs_seq, batch_size, max_len = 15, 5, 10 observations, lens = hmm_utils.hmm_sample_n(params_jax, hmm_sample_jax, n_obs_seq, max_len, rng_sample) observations, lens = hmm_utils.pad_sequences(observations, lens) rng_key, rng_batch = split(rng_key) batches, lens = hmm_utils.hmm_sample_minibatches(observations, lens,
# Benchmark setup: build a random HMM with n_hidden states and n_obs symbols,
# sample one sequence from it, then time the NumPy forwards-backwards pass.
# NOTE(review): key_A and key_B are PRNG keys created outside this chunk.

# state transition matrix
n_hidden, n_obs = 100, 10
A = uniform(key_A, (n_hidden, n_hidden))
# Normalise each ROW so that A[i, :] is a distribution over next states.
# keepdims=True keeps the row sums as an (n_hidden, 1) column. Without it the
# (n_hidden,) sums broadcast along columns, dividing A[i, j] by the sum of
# row j, and the rows of A no longer sum to 1.
A = A / jnp.sum(A, axis=1, keepdims=True)

# observation matrix (same row-wise normalisation: B[i, :] sums to 1)
B = uniform(key_B, (n_hidden, n_obs))
B = B / jnp.sum(B, axis=1, keepdims=True)

n_samples = 1000
init_state_dist = jnp.ones(n_hidden) / n_hidden  # uniform initial state

seed = 0
rng_key = PRNGKey(seed)

# The NumPy implementation gets plain NumPy copies, matching how the other
# benchmarks in this file build HMMNumpy, so its timing does not include
# JAX array dispatch.
params_numpy = HMMNumpy(np.array(A), np.array(B), np.array(init_state_dist))
params_jax = HMMJax(A, B, init_state_dist)
hmm_distrax = HMM(trans_dist=distrax.Categorical(probs=A),
                  obs_dist=distrax.Categorical(probs=B),
                  init_dist=distrax.Categorical(probs=init_state_dist))

z_hist, x_hist = hmm_sample_jax(params_jax, n_samples, rng_key)
x_hist_np = np.array(x_hist)  # convert before starting the clock

start = time.time()
alphas_np, _, gammas_np, loglikelihood_np = hmm_forwards_backwards_numpy(
    params_numpy, x_hist_np, len(x_hist_np))
print(
    f'Time taken by numpy version of forwards backwards : {time.time()-start}s'
)

start = time.time()
# NOTE(review): this physical line is a run of script statements collapsed
# together (line breaks were lost), so as written it is a SyntaxError
# (e.g. `seed = 100 rng_key = PRNGKey(seed) ...`).
# What it does logically:
#   - samples n_obs_seq=5 sequences of max_len=3000 from HMMJax(A, B, pi);
#   - pads them with hmm_utils.pad_sequences;
#   - draws random initial params via init_random_params_jax and copies
#     them into an HMMNumpy;
#   - times hmm_em_numpy and hmm_em_jax for num_epochs=20.
# `rng_batch` and `batch_size` are not used in the visible code.
# The final hmm_em_jax(...) call is cut off at the end of this line.
# TODO: restore the original one-statement-per-line layout.
seed = 100 rng_key = PRNGKey(seed) rng_key, rng_sample, rng_batch, rng_init = split(rng_key, 4) casino = HMMJax(A, B, pi) n_obs_seq, batch_size, max_len = 5, 5, 3000 observations, lens = hmm_utils.hmm_sample_n(casino, hmm_sample_jax, n_obs_seq, max_len, rng_sample) observations, lens = hmm_utils.pad_sequences(observations, lens) n_hidden, n_obs = B.shape params_jax = init_random_params_jax([n_hidden, n_obs], rng_key=rng_init) params_numpy = HMMNumpy(np.array(params_jax.trans_mat), np.array(params_jax.obs_mat), np.array(params_jax.init_dist)) num_epochs = 20 start = time.time() params_numpy, neg_ll_numpy = hmm_em_numpy(np.array(observations), np.array(lens), num_epochs=num_epochs, init_params=params_numpy) print(f'Time taken by numpy version of EM : {time.time()-start}s') start = time.time() params_jax, neg_ll_jax = hmm_em_jax(observations, lens, num_epochs=num_epochs,