def _step(t, loop_state, tas):
  """Runs a single posterior-RNN iteration inside `tf.while_loop`.

  The observation at step `t` is tiled `num_samples` times along the batch
  axis so several trajectories are sampled at once. It is concatenated with
  the previously sampled latent, fed to `self.posterior_rnn`, and a new latent
  is drawn from `self.posterior_dist`. Results are written at index `t`.

  Args:
    t: loop counter / time index.
    loop_state: a `loopstate` holding `rnn_state` and `latent_encoded` from
      the previous iteration.
    tas: TensorArrays that receive, in order, the new RNN state, the sampled
      latent, the posterior entropy, and the log-probability of the sample.

  Returns:
    `(t + 1, loopstate(rnn_state=..., latent_encoded=...), tas)`.
  """
  # Replicate the step-t observation once per sampled trajectory.
  tiled_obs = tf.tile(inputs[:, t, :], [num_samples, 1])
  # [num_samples * BS, input_dim + latent_dim]: observation, then previous latent.
  rnn_input = tf.concat([tiled_obs, loop_state.latent_encoded], axis=-1)
  rnn_out, next_rnn_state = self.posterior_rnn(
      inputs=rnn_input, states=loop_state.rnn_state)
  posterior = self.posterior_dist(rnn_out)
  z_sample = posterior.sample(seed=random_seed)

  # rnn_state is a list of [batch_size, rnn_hidden_dim]; after TA.stack(),
  # the dimension will be [num_steps, 1 for GRU/2 for LSTM, batch, rnn_dim].
  per_step_outputs = [
      next_rnn_state,
      z_sample,
      posterior.entropy(),
      posterior.log_prob(z_sample),
  ]
  tas = utils.write_updates_to_tas(tas, t, per_step_outputs)
  next_state = loopstate(rnn_state=next_rnn_state, latent_encoded=z_sample)
  return t + 1, next_state, tas
def _steps(t, prev_prob, fwd_tas):
  """One step forward in iterations.

  Combines the previous normalized forward message with the step-`t`
  transition and emission terms, marginalizes out `s[t-1]`, and renormalizes.

  Args:
    t: current time index (the loop starts at 1; step 0 is the initial state).
    prev_prob: normalized forward log-probabilities of step `t-1`,
      indexed [batch, s[t-1]].
    fwd_tas: TensorArrays receiving `[normalized_logprob, normalizer]`.

  Returns:
    `(t + 1, new_prob, fwd_tas)`, where `new_prob` is the normalized message
    for step `t`.
  """
  bi_t = log_b[:, t, :]  # log p(x[t] | s[t])
  aij_t = log_a[:, t, :, :]  # log p(s[t] | s[t-1], x[t-1]); [batch, i=s[t], j=s[t-1]]
  # Broadcast to [batch, s[t], s[t-1]] and sum out s[t-1] (the last axis).
  current_updates = tf.math.reduce_logsumexp(
      bi_t[:, :, tf.newaxis] + aij_t + prev_prob[:, tf.newaxis, :],
      axis=-1)
  # Returns [normalized_logprob, normalizer].
  current_updates = utils.normalize_logprob(current_updates, axis=-1)

  prev_prob = current_updates[0]
  fwd_tas = utils.write_updates_to_tas(fwd_tas, t, current_updates)
  return (t + 1, prev_prob, fwd_tas)
def _steps(t, next_prob, bwd_tas):
  """Derives the normalized backward message at step `t` from step `t + 1`.

  Args:
    t: time index being filled in (the loop counts down towards 0).
    next_prob: normalized backward log-message of step `t + 1`,
      indexed [batch, s[t+1]].
    bwd_tas: TensorArrays receiving `[normalized_logprob, normalizer]`.

  Returns:
    `(t - 1, new_message, bwd_tas)`.
  """
  log_emission_next = log_b[:, t + 1, :]  # log p(x[t+1] | s[t+1])
  log_transition_next = log_a[:, t + 1, :, :]  # log p(s[t+1] | s[t], x[t])

  # Layout [batch, i=s[t+1], j=s[t]]; sum out s[t+1] (axis -2).
  joint = (next_prob[:, :, tf.newaxis] + log_transition_next
           + log_emission_next[:, :, tf.newaxis])
  updates = utils.normalize_logprob(
      tf.math.reduce_logsumexp(joint, axis=-2), axis=-1)

  new_message = updates[0]
  bwd_tas = utils.write_updates_to_tas(bwd_tas, t, updates)
  return t - 1, new_message, bwd_tas
def backward_pass(log_a, log_b, logprob_s0):
  """Computing the backward pass of Baum-Welch Algorithm.

  Values are computed in log space with the log-sum-exp trick. The recursion
  starts at the last step with log 1 = 0 for every state. It then runs
  backwards to step 0, renormalizing the message over the discrete states at
  every step it computes. The last step therefore keeps its zero
  initialization and is not renormalized.

  Args:
    log_a: a float `Tensor` of size [batch, num_steps, num_categ, num_categ]
      stores time dependent transition matrices, `log p(s[t] | s[t-1], x[t-1])`.
      `A[:, t, i, j]` is the transition probability from `s[t-1]=j` to `s[t]=i`.
      Since `A[:, t, :, :]` is using the information from `t-1`,
      `A[:, 0, :, :]` is meaningless and is never read by this function.
    log_b: a float `Tensor` of size [batch, num_steps, num_categ]
      stores time dependent emission matrices, `log p(x[t](, z[t])| s[t])`.
    logprob_s0: a float `Tensor` of size [num_categ], initial discrete states
      probability, `log p(s[0])`. Only its static length (`num_categ`) is used
      here, to size the initial message; its values do not enter the pass.

  Returns:
    backward_prob: a float `Tensor` of size [batch, num_steps, num_categ]
      stores the backward-pass log probability,
      `log p(s_t | x_t+1:T(, z_t+1:T))`, normalized over states at every step
      except the last one, which is all zeros.
    normalizer: a float `Tensor` of size [batch, num_steps] stores the log
      normalizing constant applied at each step (zero at the last step).
  """
  batch_size = tf.shape(log_b)[0]
  num_steps = tf.shape(log_b)[1]
  num_categ = logprob_s0.shape[0]

  tas = [
      tf.TensorArray(tf.float32, num_steps, name=n)
      for n in ["backward_prob", "normalizer"]
  ]

  # Initialize the last step: log 1 = 0 for every state, zero log-normalizer.
  init_updates = [
      tf.zeros([batch_size, num_categ]),
      tf.zeros([batch_size, 1])
  ]
  tas = utils.write_updates_to_tas(tas, num_steps - 1, init_updates)

  next_prob = init_updates[0]
  init_state = (num_steps - 2, next_prob, tas)

  def _cond(t, *unused_args):
    return t > -1

  def _steps(t, next_prob, bwd_tas):
    """One step backward: derives the message at `t` from the one at `t+1`."""
    bi_tp1 = log_b[:, t + 1, :]  # log p(x[t+1] | s[t+1])
    aij_tp1 = log_a[:, t + 1, :, :]  # log p(s[t+1] | s[t], x[t])
    # Terms are laid out as [batch, i=s[t+1], j=s[t]]; sum out s[t+1].
    current_updates = tf.math.reduce_logsumexp(
        next_prob[:, :, tf.newaxis] + aij_tp1 + bi_tp1[:, :, tf.newaxis],
        axis=-2)
    # Returns [normalized_logprob, normalizer].
    current_updates = utils.normalize_logprob(current_updates, axis=-1)

    next_prob = current_updates[0]
    bwd_tas = utils.write_updates_to_tas(bwd_tas, t, current_updates)
    return (t - 1, next_prob, bwd_tas)

  _, _, tas_final = tf.while_loop(_cond, _steps, init_state)

  # TensorArrays stack as [step, batch, ...]; transpose to [batch, step, ...].
  backward_prob = tf.transpose(tas_final[0].stack(), [1, 0, 2])
  normalizer = tf.transpose(
      tf.squeeze(tas_final[1].stack(), axis=[-1]), [1, 0])
  return backward_prob, normalizer
def forward_pass(log_a, log_b, logprob_s0):
  """Computing the forward pass of Baum-Welch Algorithm.

  By employing log-exp-sum trick, values are computed in log space, including
  the output. Notation is adopted from https://arxiv.org/abs/1910.09588.
  `log_a` is the likelihood of discrete states, `log p(s[t] | s[t-1], x[t-1])`,
  `log_b` is the likelihood of observations, `log p(x[t], z[t] | s[t])`,
  and `logprob_s0` is the likelihood of initial discrete states, `log p(s[0])`.
  Forward pass calculates the filtering likelihood of `log p(s_t | x_1:t)`.

  Args:
    log_a: a float `Tensor` of size [batch, num_steps, num_categ, num_categ]
      stores time dependent transition matrices, `log p(s[t] | s[t-1], x[t-1])`.
      `A[i, j]` is the transition probability from `s[t-1]=j` to `s[t]=i`.
      `log_a[:, 0]` is never read; step 0 uses `logprob_s0` instead.
    log_b: a float `Tensor` of size [batch, num_steps, num_categ]
      stores time dependent emission matrices, `log p(x[t](, z[t])| s[t])`.
    logprob_s0: a float `Tensor` of size [num_categ], initial discrete states
      probability, `log p(s[0])`.

  Returns:
    forward_prob: a float 3D `Tensor` of size [batch, num_steps, num_categ]
      stores the forward pass probability of `log p(s_t | x_1:t)`, which is
      normalized.
    normalizer: a float 2D `Tensor` of size [batch, num_steps] stores the
      normalizing probability, `log p(x_t | x_1:t-1)`.

  Raises:
    ValueError: if `log_a` has statically known rank lower than 3.
  """
  # Use the static number of steps when it is known. Otherwise fall back to
  # the run-time length, as backward_pass does. A `None` length would break
  # both the TensorArray size and the `t < num_steps` loop bound.
  num_steps = tf.compat.dimension_value(
      log_a.get_shape().with_rank_at_least(3)[1])
  if num_steps is None:
    num_steps = tf.shape(log_a)[1]

  tas = [
      tf.TensorArray(tf.float32, num_steps, name=n)
      for n in ["forward_prob", "normalizer"]
  ]

  # The function will return normalized forward probability and
  # normalizing constant as a list, [forward_logprob, normalizer].
  init_updates = utils.normalize_logprob(
      logprob_s0[tf.newaxis, :] + log_b[:, 0, :], axis=-1)

  tas = utils.write_updates_to_tas(tas, 0, init_updates)

  prev_prob = init_updates[0]
  init_state = (1, prev_prob, tas)

  def _cond(t, *unused_args):
    return t < num_steps

  def _steps(t, prev_prob, fwd_tas):
    """One step forward in iterations."""
    bi_t = log_b[:, t, :]  # log p(x[t] | s[t])
    aij_t = log_a[:, t, :, :]  # log p(s[t] | s[t-1], x[t-1])
    # Broadcast to [batch, i=s[t], j=s[t-1]] and sum out s[t-1].
    current_updates = tf.math.reduce_logsumexp(
        bi_t[:, :, tf.newaxis] + aij_t + prev_prob[:, tf.newaxis, :],
        axis=-1)
    current_updates = utils.normalize_logprob(current_updates, axis=-1)

    prev_prob = current_updates[0]
    fwd_tas = utils.write_updates_to_tas(fwd_tas, t, current_updates)
    return (t + 1, prev_prob, fwd_tas)

  _, _, tas_final = tf.while_loop(_cond, _steps, init_state)

  # transpose to [batch, step, state]
  forward_prob = tf.transpose(tas_final[0].stack(), [1, 0, 2])
  normalizer = tf.transpose(
      tf.squeeze(tas_final[1].stack(), axis=[-1]), [1, 0])
  return forward_prob, normalizer