def loss_fn(params, batch, lens):
    '''
    Objective function of hidden Markov models for discrete observations.
    It returns the mean of the negative loglikelihood of the sequence of
    observations.

    Parameters
    ----------
    params : HMMJax
        Hidden Markov Model
    batch : array(N, max_len)
        Minibatch consisting of observation sequences
    lens : array(N,)
        Consists of the valid length of each observation sequence in
        the minibatch

    Returns
    -------
    * float
        The mean negative loglikelihood of the minibatch
    '''
    params_soft = HMMJax(softmax(params.trans_mat, axis=1),
                         softmax(params.obs_mat, axis=1),
                         softmax(params.init_dist))
    return -hmm_loglikelihood_jax(params_soft, batch, lens).mean()
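# A minimal usage sketch (the shapes and toy values below are illustrative
# assumptions, not part of the library): because the raw parameters are
# unconstrained and only mapped to probabilities via softmax inside loss_fn,
# the loss is differentiable end-to-end and plain JAX autodiff applies.
import jax
import jax.numpy as jnp
from jax.random import PRNGKey

demo_params = init_random_params([2, 6], PRNGKey(0))  # unconstrained params
demo_batch = jnp.zeros((4, 10), dtype=jnp.int32)      # 4 padded sequences
demo_lens = jnp.array([10, 8, 5, 10])                 # valid lengths
loss, grads = jax.value_and_grad(loss_fn)(demo_params, demo_batch, demo_lens)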
def init_random_params(sizes, rng_key):
    '''
    Initializes the components of the HMM from a normal distribution.

    Parameters
    ----------
    sizes : list
        Consists of the number of hidden states and observable events,
        respectively
    rng_key : array
        Random key of shape (2,) and dtype uint32

    Returns
    -------
    * array(num_hidden, num_hidden)
        Unnormalized transition matrix (becomes a probability matrix
        after softmax)
    * array(num_hidden, num_obs)
        Unnormalized emission matrix (becomes a probability matrix
        after softmax)
    * array(num_hidden,)
        Unnormalized initial distribution
    '''
    num_hidden, num_obs = sizes
    rng_key, rng_a, rng_b, rng_pi = split(rng_key, 4)
    return HMMJax(normal(rng_a, (num_hidden, num_hidden)),
                  normal(rng_b, (num_hidden, num_obs)),
                  normal(rng_pi, (num_hidden,)))
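# Quick usage check (illustrative sizes): the returned parameters are raw
# draws from N(0, 1); they only become valid probabilities after the softmax
# applied inside loss_fn and at the end of fit.
from jax.random import PRNGKey

raw = init_random_params([3, 4], PRNGKey(42))
assert raw.trans_mat.shape == (3, 3)
assert raw.obs_mat.shape == (3, 4)
assert raw.init_dist.shape == (3,)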
n_hidden, n_obs = 100, 10

# state transition matrix
A = uniform(key_A, (n_hidden, n_hidden))
A = A / jnp.sum(A, axis=1).reshape((-1, 1))

# observation matrix
B = uniform(key_B, (n_hidden, n_obs))
B = B / jnp.sum(B, axis=1).reshape((-1, 1))

n_samples = 1000
init_state_dist = jnp.ones(n_hidden) / n_hidden

seed = 0
rng_key = PRNGKey(seed)

params_numpy = HMMNumpy(A, B, init_state_dist)
params_jax = HMMJax(A, B, init_state_dist)
hmm_distrax = HMM(trans_dist=distrax.Categorical(probs=A),
                  obs_dist=distrax.Categorical(probs=B),
                  init_dist=distrax.Categorical(probs=init_state_dist))

z_hist, x_hist = hmm_sample_jax(params_jax, n_samples, rng_key)

start = time.time()
alphas_np, _, gammas_np, loglikelihood_np = hmm_forwards_backwards_numpy(
    params_numpy, x_hist, len(x_hist))
print(f'Time taken by numpy version of forwards backwards: {time.time()-start}s')

start = time.time()
alphas_jax, _, gammas_jax, loglikelihood_jax = hmm_forwards_backwards_jax(
    params_jax, x_hist, len(x_hist))
print(f'Time taken by JAX version of forwards backwards: {time.time()-start}s')
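# Sanity check (an addition, not part of the original benchmark): beyond
# timing, the numpy and JAX implementations should agree numerically on the
# same inputs; the tolerance is an assumption.
assert jnp.allclose(loglikelihood_np, loglikelihood_jax, atol=1e-4)
assert jnp.allclose(gammas_np, gammas_jax, atol=1e-4)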
# state transition matrix
A = jnp.array([[0.95, 0.05],
               [0.10, 0.90]])

# observation matrix
B = jnp.array([
    [1 / 6, 1 / 6, 1 / 6, 1 / 6, 1 / 6, 1 / 6],       # fair die
    [1 / 10, 1 / 10, 1 / 10, 1 / 10, 1 / 10, 5 / 10]  # loaded die
])

pi = jnp.array([1, 1]) / 2

params_numpy = HMMNumpy(np.array(A), np.array(B), np.array(pi))
params_jax = HMMJax(A, B, pi)

seed = 0
rng_key = PRNGKey(seed)
rng_key, rng_sample = split(rng_key)

n_obs_seq, batch_size, max_len = 15, 5, 10

observations, lens = hmm_utils.hmm_sample_n(params_jax, hmm_sample_jax,
                                            n_obs_seq, max_len, rng_sample)
observations, lens = hmm_utils.pad_sequences(observations, lens)

rng_key, rng_batch = split(rng_key)
batches, lens = hmm_utils.hmm_sample_minibatches(observations, lens,
                                                 batch_size, rng_batch)
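# Illustrative shape check: hmm_sample_minibatches is expected to regroup the
# 15 padded sequences into 3 minibatches of 5, giving a leading minibatch
# axis that fit's lax.scan iterates over (the exact shapes are an assumption
# about hmm_utils, not guaranteed here).
print(batches.shape)  # e.g. (3, 5, 10) -> (num_batches, batch_size, max_len)
print(lens.shape)     # e.g. (3, 5)     -> valid length of each sequence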
from hmm_sgd_lib import fit
from hmm_utils import pad_sequences, hmm_sample_n
from jax.experimental import optimizers

# state transition matrix
A = jnp.array([[0.95, 0.05],
               [0.10, 0.90]])

# observation matrix
B = jnp.array([
    [1 / 6, 1 / 6, 1 / 6, 1 / 6, 1 / 6, 1 / 6],       # fair die
    [1 / 10, 1 / 10, 1 / 10, 1 / 10, 1 / 10, 5 / 10]  # loaded die
])

pi = jnp.array([1, 1]) / 2
casino = HMMJax(A, B, pi)
num_hidden, num_obs = 2, 6

seed = 0
rng_key = PRNGKey(seed)
rng_key, rng_sample = split(rng_key)

n_obs_seq, max_len = 4, 5000
num_epochs = 400

observations, lens = pad_sequences(
    *hmm_sample_n(casino, hmm_sample_jax, n_obs_seq, max_len, rng_sample))

optimizer = optimizers.momentum(step_size=1e-3, mass=0.95)

# Mini Batch Gradient Descent
batch_size = 2
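# Hypothetical completion of this cell: everything above sets up the inputs
# that fit (imported from hmm_sgd_lib) expects, so training would be launched
# as sketched below; the result variable names are assumptions.
params_sgd, losses_sgd = fit(observations, lens, num_hidden, num_obs,
                             batch_size, optimizer, rng_key, num_epochs)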
def fit(observations, lens, num_hidden, num_obs, batch_size, optimizer,
        rng_key=None, num_epochs=1):
    '''
    Trains the HMM model with the given number of hidden states and
    observations via any optimizer.

    Parameters
    ----------
    observations : array(N, max_len)
        All observation sequences
    lens : array(N,)
        Consists of the valid length of each observation sequence
    num_hidden : int
        The number of hidden states
    num_obs : int
        The number of observable events
    batch_size : int
        The number of observation sequences that will be included in
        each minibatch
    optimizer : jax.experimental.optimizers.Optimizer
        Optimizer that is used during training
    rng_key : array, optional
        Random key of shape (2,) and dtype uint32; defaults to PRNGKey(0)
    num_epochs : int
        The total number of epochs

    Returns
    -------
    * HMMJax
        Hidden Markov Model
    * array
        Consists of training losses
    '''
    global opt_init, opt_update, get_params

    if rng_key is None:
        rng_key = PRNGKey(0)

    rng_init, rng_iter = split(rng_key)
    params = init_random_params([num_hidden, num_obs], rng_init)
    opt_init, opt_update, get_params = optimizer
    opt_state = opt_init(params)
    itercount = itertools.count()

    def epoch_step(opt_state, key):

        def train_step(opt_state, params):
            batch, length = params
            # NB: lax.scan traces train_step once, so next(itercount) yields
            # a single trace-time step index rather than one per iteration.
            opt_state, loss = update(next(itercount), opt_state, batch, length)
            return opt_state, loss

        batches, valid_lens = hmm_sample_minibatches(observations, lens,
                                                     batch_size, key)
        params = (batches, valid_lens)
        opt_state, losses = jax.lax.scan(train_step, opt_state, params)
        return opt_state, losses.mean()

    epochs = split(rng_iter, num_epochs)
    opt_state, losses = jax.lax.scan(epoch_step, opt_state, epochs)
    losses = losses.flatten()

    params = get_params(opt_state)
    params = HMMJax(softmax(params.trans_mat, axis=1),
                    softmax(params.obs_mat, axis=1),
                    softmax(params.init_dist))
    return params, losses
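# fit depends on an update step defined elsewhere in hmm_sgd_lib. A minimal
# sketch consistent with the calls above (the body is an assumption; only
# the signature update(i, opt_state, batch, lens) is fixed by fit):
from jax import jit, value_and_grad

@jit
def update(i, opt_state, batch, lens):
    params = get_params(opt_state)
    loss, grads = value_and_grad(loss_fn)(params, batch, lens)
    # opt_update consumes the step index and the gradients of loss_fn
    return opt_update(i, grads, opt_state), loss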