예제 #1
0
def loss_fn(params, batch, lens):
    '''
    Mean negative loglikelihood of a minibatch under a discrete-observation HMM.

    The three components of `params` are first passed through softmax, so the
    model that is scored always has normalised transition rows, emission rows
    and initial distribution.

    Parameters
    ----------
    params : HMMJax
        Hidden Markov Model holding the values that are fed to softmax

    batch: array(N, max_len)
        Minibatch consisting of observation sequences

    lens : array
        Valid length of each observation sequence in the minibatch
        (presumably one entry per sequence - confirm against the caller)

    Returns
    -------
    * float
        The mean negative loglikelihood of the minibatch
    '''
    # Row-wise softmax for the two matrices, plain softmax for the initial vector.
    trans_probs = softmax(params.trans_mat, axis=1)
    obs_probs = softmax(params.obs_mat, axis=1)
    init_probs = softmax(params.init_dist)
    normalized_hmm = HMMJax(trans_probs, obs_probs, init_probs)

    loglikelihoods = hmm_loglikelihood_jax(normalized_hmm, batch, lens)
    return -loglikelihoods.mean()
예제 #2
0
def init_random_params(sizes, rng_key):
    '''
    Draws every component of an HMM from a normal distribution.

    Parameters
    ----------
    sizes: List
      Number of hidden states followed by the number of observable events

    rng_key : array
        Random key of shape (2,) and dtype uint32

    Returns
    -------
    * HMMJax
      Model built, in this order, from an array(num_hidden, num_hidden) for
      the transitions, an array(num_hidden, num_obs) for the emissions and an
      array(num_hidden,) for the initial distribution. The values are raw
      normal samples, not probabilities.
    '''
    num_hidden, num_obs = sizes

    # Four subkeys are generated but only the last three are consumed; the
    # split count is kept at 4 so the drawn values stay the same.
    _, key_trans, key_obs, key_init = split(rng_key, 4)

    trans_mat = normal(key_trans, (num_hidden, num_hidden))
    obs_mat = normal(key_obs, (num_hidden, num_obs))
    init_dist = normal(key_init, (num_hidden, ))
    return HMMJax(trans_mat, obs_mat, init_dist)
예제 #3
0
n_hidden, n_obs = 100, 10

# state transition matrix, normalized so that every row sums to one.
# keepdims=True keeps the row sums as a column vector; dividing by the flat
# (n_hidden,) vector would broadcast along the columns instead of the rows.
# NOTE(review): key_A and key_B are not defined in this excerpt - they are
# expected to be created before this point.
A = uniform(key_A, (n_hidden, n_hidden))
A = A / jnp.sum(A, axis=1, keepdims=True)

# observation matrix, normalized row-wise in the same way
B = uniform(key_B, (n_hidden, n_obs))
B = B / jnp.sum(B, axis=1, keepdims=True)

n_samples = 1000
# uniform initial state distribution
init_state_dist = jnp.ones(n_hidden) / n_hidden

seed = 0
rng_key = PRNGKey(seed)

# The same parameters wrapped for each of the three implementations
params_numpy = HMMNumpy(A, B, init_state_dist)
params_jax = HMMJax(A, B, init_state_dist)
hmm_distrax = HMM(trans_dist=distrax.Categorical(probs=A),
                  obs_dist=distrax.Categorical(probs=B),
                  init_dist=distrax.Categorical(probs=init_state_dist))

# Sample one sequence of hidden states and observations to run inference on
z_hist, x_hist = hmm_sample_jax(params_jax, n_samples, rng_key)

start = time.time()
alphas_np, _, gammas_np, loglikelihood_np = hmm_forwards_backwards_numpy(
    params_numpy, x_hist, len(x_hist))
print(
    f'Time taken by numpy version of forwards backwards : {time.time()-start}s'
)

start = time.time()
# NOTE(review): `batches` and `lens` are not defined before this point in this
# excerpt, so this call looks truncated/spliced. Presumably it should mirror
# the numpy call above (params_jax, x_hist, len(x_hist)) - confirm against the
# signature of hmm_forwards_backwards_jax.
alphas_jax, _, gammas_jax, loglikelihood_jax = hmm_forwards_backwards_jax(
                                                             batches, lens)


# Transition probabilities between the two hidden states, one row per state
A = jnp.array([
    [0.95, 0.05],
    [0.10, 0.90],
])

# Emission probabilities of the six die faces, one row per hidden state
B = jnp.array([
    [1 / 6] * 6,  # fair die
    [1 / 10] * 5 + [5 / 10],  # loaded die
])

# Both hidden states are equally likely at the start
pi = jnp.array([1, 1]) / 2

# The numpy model receives numpy copies of the same three arrays
params_numpy = HMMNumpy(*(np.array(component) for component in (A, B, pi)))
params_jax = HMMJax(A, B, pi)

seed = 0
rng_key, rng_sample = split(PRNGKey(seed))

n_obs_seq = 15
batch_size = 5
max_len = 10

# Sample the observation sequences and pad them in a single step
observations, lens = hmm_utils.pad_sequences(
    *hmm_utils.hmm_sample_n(params_jax, hmm_sample_jax, n_obs_seq, max_len,
                            rng_sample))

# A fresh subkey drives the minibatch sampling
rng_key, rng_batch = split(rng_key)
batches, lens = hmm_utils.hmm_sample_minibatches(observations, lens,
                                                 batch_size, rng_batch)
예제 #5
0
from hmm_sgd_lib import fit
from hmm_utils import pad_sequences, hmm_sample_n
from jax.experimental import optimizers

# Transition probabilities between the two hidden states, one row per state
A = jnp.array([
    [0.95, 0.05],
    [0.10, 0.90],
])

# Emission probabilities of the six die faces, one row per hidden state
B = jnp.array([
    [1 / 6] * 6,  # fair die
    [1 / 10] * 5 + [5 / 10],  # loaded die
])

# Both hidden states are equally likely at the start
pi = jnp.array([1, 1]) / 2

# Model the training data is sampled from
casino = HMMJax(A, B, pi)
num_hidden = 2
num_obs = 6

seed = 0
rng_key, rng_sample = split(PRNGKey(seed))

n_obs_seq = 4
max_len = 5000
num_epochs = 400

# Sample the observation sequences first, then pad them
observations, lens = hmm_sample_n(casino, hmm_sample_jax, n_obs_seq, max_len,
                                  rng_sample)
observations, lens = pad_sequences(observations, lens)

optimizer = optimizers.momentum(mass=0.95, step_size=1e-3)

# Mini Batch Gradient Descent
batch_size = 2
예제 #6
0
def fit(observations,
        lens,
        num_hidden,
        num_obs,
        batch_size,
        optimizer,
        rng_key=None,
        num_epochs=1):
    '''
    Trains the HMM model with the given number of hidden states and observations via any optimizer.

    The parameters are initialised randomly. For every epoch a new set of minibatches is
    drawn with hmm_sample_minibatches and one optimizer update is applied per minibatch.

    Parameters
    ----------
    observations: array(N, seq_len)
        All observation sequences

    lens : array
        Consists of the valid length of each observation sequence
        (presumably one entry per sequence - confirm against hmm_sample_minibatches)

    num_hidden : int
        The number of hidden state

    num_obs : int
        The number of observable events

    batch_size : int
        The number of observation sequences that will be included in each minibatch

    optimizer : jax.experimental.optimizers.Optimizer
        Optimizer that is used during training; unpacked as the triple
        (opt_init, opt_update, get_params)

    rng_key : array, optional
        Random key used for the parameter initialisation and the minibatch sampling.
        PRNGKey(0) is used when it is None.

    num_epochs : int
        The number of epochs

    Returns
    -------
    * HMMJax
        Hidden Markov Model whose three components have been passed through softmax

    * array(num_epochs,)
      Mean minibatch training loss of each epoch
    '''
    # These names are published at module level; `update` is not handed the
    # optimizer functions explicitly, so it presumably reads them from there.
    global opt_init, opt_update, get_params

    if rng_key is None:
        rng_key = PRNGKey(0)

    rng_init, rng_iter = split(rng_key)
    params = init_random_params([num_hidden, num_obs], rng_init)
    opt_init, opt_update, get_params = optimizer
    opt_state = opt_init(params)

    # The step index is threaded through the scan carry. A Python-side counter
    # would only advance while the scan bodies are traced, so every update
    # would see the same constant step, which breaks step-size schedules.
    def epoch_step(carry, key):
        def train_step(carry, minibatch):
            step, opt_state = carry
            batch, length = minibatch
            opt_state, loss = update(step, opt_state, batch, length)
            return (step + 1, opt_state), loss

        batches, valid_lens = hmm_sample_minibatches(observations, lens,
                                                     batch_size, key)
        carry, losses = jax.lax.scan(train_step, carry, (batches, valid_lens))
        return carry, losses.mean()

    epoch_keys = split(rng_iter, num_epochs)
    # Explicit int32 scalar so the carry keeps one dtype across iterations.
    first_step = jax.numpy.zeros((), dtype=jax.numpy.int32)
    (_, opt_state), losses = jax.lax.scan(epoch_step, (first_step, opt_state),
                                          epoch_keys)

    losses = losses.flatten()

    # Map the trained values to probabilities before returning the model.
    params = get_params(opt_state)
    params = HMMJax(softmax(params.trans_mat, axis=1),
                    softmax(params.obs_mat, axis=1), softmax(params.init_dist))
    return params, losses