コード例 #1
0
    return vmap(hmm_loglikelihood_jax, in_axes=(None, 0, 0))(params_jax,
                                                             batches, lens)


# Transition matrix between the two hidden states.
A = jnp.array([
    [0.95, 0.05],
    [0.10, 0.90],
])

# Emission matrix: one row per die, one column per face.
B = jnp.array([
    [1 / 6] * 6,  # fair die: every face equally likely
    [1 / 10] * 5 + [5 / 10],  # loaded die: last face has probability 1/2
])

# Both hidden states are equally likely at the start.
pi = jnp.array([0.5, 0.5])

# The same parameters are packaged twice: NumPy copies for the NumPy
# container, and the JAX arrays for the JAX container.
params_numpy = HMMNumpy(*map(np.array, (A, B, pi)))
params_jax = HMMJax(A, B, pi)

seed = 0
# Split off a dedicated key for sampling; rng_key is carried forward.
rng_key, rng_sample = split(PRNGKey(seed))

n_obs_seq = 15
batch_size = 5
max_len = 10

# Sample the observation sequences, then pad them.
observations, lens = hmm_utils.hmm_sample_n(
    params_jax, hmm_sample_jax, n_obs_seq, max_len, rng_sample)
observations, lens = hmm_utils.pad_sequences(observations, lens)

# Fresh key for the minibatch sampling that follows.
rng_key, rng_batch = split(rng_key)
batches, lens = hmm_utils.hmm_sample_minibatches(observations, lens,
コード例 #2
0
# state transition matrix
n_hidden, n_obs = 100, 10
# key_A and key_B are PRNG keys that must already be defined at this point.
A = uniform(key_A, (n_hidden, n_hidden))
# Normalise every row to sum to 1. keepdims=True keeps the row sums as an
# (n_hidden, 1) column so the division is applied row by row; dividing by
# the flat (n_hidden,) vector would broadcast along the last axis and divide
# column j by the sum of row j, leaving the rows unnormalised.
A = A / jnp.sum(A, axis=1, keepdims=True)

# observation matrix, normalised row-wise in the same way
B = uniform(key_B, (n_hidden, n_obs))
B = B / jnp.sum(B, axis=1, keepdims=True)

n_samples = 1000
# Uniform distribution over the hidden states.
init_state_dist = jnp.ones(n_hidden) / n_hidden

seed = 0
rng_key = PRNGKey(seed)

# The same parameters are wrapped for each of the three implementations.
params_numpy = HMMNumpy(A, B, init_state_dist)
params_jax = HMMJax(A, B, init_state_dist)
hmm_distrax = HMM(trans_dist=distrax.Categorical(probs=A),
                  obs_dist=distrax.Categorical(probs=B),
                  init_dist=distrax.Categorical(probs=init_state_dist))

# Draw one sequence of n_samples steps; presumably z_hist holds the hidden
# states and x_hist the observations (confirm against hmm_sample_jax).
z_hist, x_hist = hmm_sample_jax(params_jax, n_samples, rng_key)

# Time the NumPy forwards-backwards pass over the whole sampled sequence.
# The second returned value is not used here.
start = time.time()
alphas_np, _, gammas_np, loglikelihood_np = hmm_forwards_backwards_numpy(
    params_numpy, x_hist, len(x_hist))
print(
    f'Time taken by numpy version of forwards backwards : {time.time()-start}s'
)

# Restart the timer for the next measurement.
start = time.time()
コード例 #3
0
seed = 100
# One seed yields four keys: a carried-forward key plus dedicated keys for
# sampling, batching and parameter initialisation.
rng_key, rng_sample, rng_batch, rng_init = split(PRNGKey(seed), 4)

# HMM built from the A, B and pi arrays; data is sampled from it below.
casino = HMMJax(A, B, pi)

n_obs_seq = batch_size = 5
max_len = 3000

# Sample the observation sequences, then pad them.
observations, lens = hmm_utils.hmm_sample_n(
    casino, hmm_sample_jax, n_obs_seq, max_len, rng_sample)
observations, lens = hmm_utils.pad_sequences(observations, lens)

# Random starting parameters with the same dimensions as B.
n_hidden, n_obs = B.shape
params_jax = init_random_params_jax([n_hidden, n_obs], rng_key=rng_init)
# NumPy copy of the same starting point, so both EM versions begin from
# identical parameters.
params_numpy = HMMNumpy(*map(
    np.array,
    (params_jax.trans_mat, params_jax.obs_mat, params_jax.init_dist),
))

num_epochs = 20

# Time the NumPy EM run.
start = time.time()
params_numpy, neg_ll_numpy = hmm_em_numpy(
    np.array(observations),
    np.array(lens),
    num_epochs=num_epochs,
    init_params=params_numpy,
)
print(f'Time taken by numpy version of EM : {time.time()-start}s')

# Restart the timer for the JAX EM run that follows.
start = time.time()
params_jax, neg_ll_jax = hmm_em_jax(observations,
                                    lens,
                                    num_epochs=num_epochs,