def forward_backward(log_a, log_b, log_init):
  """Runs the forward-backward recursion in log space.

  Args:
    log_a: float `Tensor` of shape [batch, num_steps, num_categ, num_categ]
      holding `log p(s[t] | s[t-1], x[t-1])`; entry `[..., i, j]` is the
      transition from `s[t-1]=j` to `s[t]=i`.
    log_b: float `Tensor` of shape [batch, num_steps, num_categ] holding
      `log p(x[t], z[t] | s[t])`.
    log_init: float `Tensor` of shape [num_categ], `log p(s[0])`.

  Returns:
    fwd: forward log-probabilities returned by `forward_pass`.
    bwd: backward log-probabilities returned by `backward_pass`.
    posterior: `fwd + bwd` renormalized over the last (state) axis.
    gamma_ij: pairwise log-posteriors of shape
      [batch, num_steps, num_categ, num_categ], normalized jointly over the
      last two axes. Step 0 has no predecessor and is filled with zeros.
  """
  forward_logprob, _ = forward_pass(log_a, log_b, log_init)
  backward_logprob, _ = backward_pass(log_a, log_b, log_init)

  # Single-state posterior: combine both passes, then renormalize per step.
  posterior, _ = utils.normalize_logprob(
      forward_logprob + backward_logprob, axis=-1)

  # Pairwise term for steps 1..T-1. Axis -2 indexes s[t] (= i) and axis -1
  # indexes s[t-1] (= j), matching the layout of `log_a`.
  prev_fwd = forward_logprob[:, 0:-1, tf.newaxis, :]
  next_bwd = backward_logprob[:, 1:, :, tf.newaxis]
  trans = log_a[:, 1:, :, :]
  emit = log_b[:, 1:, :, tf.newaxis]
  gamma_ij, _ = utils.normalize_logprob(
      prev_fwd + trans + next_bwd + emit, axis=[-2, -1])

  # Prepend one step so the time axis matches the inputs.
  # NOTE(review): zeros in log space are not a normalized log-distribution;
  # step 0 of `gamma_ij` is only a placeholder.
  a_shape = tf.shape(log_a)
  first_step = tf.zeros([a_shape[0], 1, a_shape[2], a_shape[3]])
  gamma_ij = tf.concat([first_step, gamma_ij], axis=1, name="concat_f_b")

  return forward_logprob, backward_logprob, posterior, gamma_ij
def test_normalize_logprob(self):
  # Random strictly-positive probabilities: normalizing their logs over the
  # last axis must equal dividing by their sum in probability space.
  probs = np.random.uniform(low=1e-6, high=1., size=[2, 3, 6])
  expected = np.log(probs) - np.log(np.sum(probs, axis=-1, keepdims=True))
  normalized = self.evaluate(utils.normalize_logprob(np.log(probs))[0])
  self.assertAllClose(normalized, expected)

  # With a very large temperature the output is expected to be close to
  # uniform, i.e. log(1/3) for every entry.
  flattened = self.evaluate(
      utils.normalize_logprob(np.log([0.1, 0.3, 0.5]), temperature=1e5)[0])
  self.assertAllClose(
      flattened,
      np.log([1. / 3., 1. / 3., 1. / 3.]),
      rtol=1e-4,
      atol=1e-4,
  )
def _steps(t, prev_prob, fwd_tas):
  """Advances the forward recursion from step t - 1 to step t.

  Uses `log_a` and `log_b` from the enclosing scope (not visible here).

  Args:
    t: scalar step index being filled in.
    prev_prob: normalized forward log-probabilities at step t - 1, indexed
      as [batch, state].
    fwd_tas: TensorArrays receiving the outputs of `utils.normalize_logprob`.

  Returns:
    `(t + 1, new_prev_prob, fwd_tas)` for the next `tf.while_loop` iteration.
  """
  # log p(x[t], z[t] | s[t]), broadcast over the s[t-1] axis.
  emission = log_b[:, t, :][:, :, tf.newaxis]
  # log p(s[t] | s[t-1], x[t-1]); entry [i, j] is the transition j -> i.
  transition = log_a[:, t, :, :]
  joint = emission + transition + prev_prob[:, tf.newaxis, :]
  # Marginalize out s[t-1] (last axis), then renormalize over s[t].
  step_result = utils.normalize_logprob(
      tf.math.reduce_logsumexp(joint, axis=-1), axis=-1)
  return (t + 1, step_result[0],
          utils.write_updates_to_tas(fwd_tas, t, step_result))
def _steps(t, next_prob, bwd_tas):
  """Advances the backward recursion from step t + 1 to step t.

  Uses `log_a` and `log_b` from the enclosing scope (not visible here).

  Args:
    t: scalar step index being filled in.
    next_prob: backward log-probabilities at step t + 1, indexed as
      [batch, state].
    bwd_tas: TensorArrays receiving the outputs of `utils.normalize_logprob`.

  Returns:
    `(t - 1, new_next_prob, bwd_tas)` for the next `tf.while_loop` iteration.
  """
  # log p(x[t+1] | s[t+1]), broadcast over the s[t] axis.
  emission_next = log_b[:, t + 1, :][:, :, tf.newaxis]
  # log p(s[t+1] | s[t], x[t]); entry [i, j] is the transition j -> i.
  transition_next = log_a[:, t + 1, :, :]
  joint = next_prob[:, :, tf.newaxis] + transition_next + emission_next
  # Marginalize out s[t+1] (axis -2), leaving a distribution over s[t].
  step_result = utils.normalize_logprob(
      tf.math.reduce_logsumexp(joint, axis=-2), axis=-1)
  return (t - 1, step_result[0],
          utils.write_updates_to_tas(bwd_tas, t, step_result))
def __init__(self,
             continuous_transition_network,
             discrete_transition_network,
             emission_network,
             inference_network,
             initial_distribution,
             continuous_state_dim=None,
             num_categories=None,
             discrete_state_prior=None):
  """Builds a Switching Non-Linear Dynamical System (SNLDS).

  Follows the model framework of Dong et al. (2019) [1].

  Args:
    continuous_transition_network: a `callable` whose `call` takes batched
      sequences of continuous hidden states `z[t-1]`, shaped
      [batch_size, num_steps, hidden_states], and returns a distribution
      whose `log_prob` evaluates `p(z[t] | z[t-1], s[t])` on `z[t]`.
    discrete_transition_network: a `callable` whose `call` takes batched
      conditional inputs `x[t-1]` and returns the discrete state transition
      matrices `log p(s[t] | s[t-1], x[t-1])`.
    emission_network: a `callable` whose `call` maps continuous hidden
      states `z[t]` to a distribution `p(x[t] | z[t])` that provides `mean`
      and `sample`, like the classes in `tfp.distributions`.
    inference_network: an object with a `sample` method that takes the
      observations `x[1:T]` and returns a sampled hidden-state sequence from
      `q(z[1:T] | x[1:T])` together with that distribution's entropy.
    initial_distribution: an initial distribution for the continuous
      state, `p(z[0])`.
    continuous_state_dim: number of continuous hidden units in `z[t]`. When
      `None`, taken from `continuous_transition_network.output_event_dims`.
    num_categories: number of discrete hidden states `s[t]`. When `None`,
      taken from `discrete_transition_network.output_event_dims`.
    discrete_state_prior: a float `Tensor` giving the prior over discrete
      states, `p[k] = p(s[t]=k)`. It is used by the cross-entropy
      regularizer, which tries to minimize the difference between this prior
      and the smoothed discrete-state likelihood
      `p(s[t] | x[1:T], z[1:T])`. When `None`, a uniform prior is used.

  Reference:
    [1] Dong, Zhe and Seybold, Bryan A. and Murphy, Kevin P., and Bui,
        Hung H.. Collapsed Amortized Variational Inference for Switching
        Nonlinear Dynamical Systems. 2019.
        https://arxiv.org/abs/1910.09588.
  """
  super(SwitchingNLDS, self).__init__()

  # Sub-networks of the generative and inference models.
  self.z_tran = continuous_transition_network
  self.s_tran = discrete_transition_network
  self.x_emit = emission_network
  self.inference_network = inference_network
  self.z0_dist = initial_distribution

  # Fall back to the networks' declared output sizes when not given.
  self.num_categ = (self.s_tran.output_event_dims
                    if num_categories is None else num_categories)
  self.z_dim = (self.z_tran.output_event_dims
                if continuous_state_dim is None else continuous_state_dim)

  if discrete_state_prior is not None:
    self.discrete_prior = discrete_state_prior
  else:
    # Uniform prior over the discrete states.
    self.discrete_prior = tf.ones(
        shape=[self.num_categ], dtype=tf.float32) / self.num_categ

  # Learnable initial discrete-state log-distribution, log p(s[0]); it starts
  # from a vector of equal logits normalized over the state axis.
  equal_logits = tf.ones(shape=[self.num_categ], dtype=tf.float32)
  self.log_init = tf.Variable(
      utils.normalize_logprob(equal_logits, axis=-1)[0],
      name="snlds_logprob_s0")
def calculate_likelihoods(self,
                          inputs,
                          sampled_z,
                          switching_conditional_inputs=None,
                          temperature=1.0):
  """Calculate the probability by p network, `p_theta(x,z,s)`.

  Args:
    inputs: a float 3-D `Tensor` of shape [batch_size, num_steps, obs_dim],
      containing the observation time series of the model. Its time
      dimension is read from the static shape.
    sampled_z: a float 3-D `Tensor` of shape
      [batch_size, num_steps, latent_dim] for continuous hidden states,
      which are sampled from inference networks, `q(z[1:T] | x[1:T])`.
    switching_conditional_inputs: a float 3-D `Tensor` of shape
      [batch_size, num_steps, encoded_dim], which is the conditional input
      for the discrete state transition probability,
      `p(s[t] | s[t-1], x[t-1])`. Defaults to `None`, in which case `inputs`
      is used.
    temperature: a float scalar `Tensor`, the temperature applied when
      normalizing the transition probability, `p(s[t] | s[t-1], x[t-1])`.

  Returns:
    log_xt_zt: a float `Tensor` of size [batch_size, num_steps, num_categ]
      holding `log(p(x_t | z_t) p(z_t | z_t-1, s_t))`.
    log_prob_st_stm1: a float `Tensor` of size
      [batch_size, num_steps, num_categ, num_categ] holding the log
      transition probability `log p(s_t | s_t-1, x_t-1)`, normalized over
      axis -2 so that `[..., i, j]` is the transition `s[t-1]=j -> s[t]=i`.
      Step 0 has no predecessor and holds an identity-matrix placeholder.
  """
  batch_size = tf.shape(inputs)[0]
  # Use the static (Python int) sequence length for the reshapes below.
  num_steps = inputs.get_shape().with_rank_at_least(3).dims[1].value

  ########################################
  ## getting log p(z[t] | z[t-1], s[t])
  ########################################

  # Broadcasting rules of TFP dictate that: if sampled_z0 is of shape
  # [batch_size, 1, event_size] and z0_dist is of [num_categ, event_size],
  # then z0_dist.log_prob(sampled_z0[:, None, :]) is of
  # [batch_size, num_categ].
  sampled_z0 = sampled_z[:, 0, :]
  log_prob_z0 = self.z0_dist.log_prob(sampled_z0[:, tf.newaxis, :])
  log_prob_z0 = log_prob_z0[:, tf.newaxis, :]

  # `log_prob_zt` is added to a [batch_size, num_steps, 1] term below, so it
  # is presumably [batch_size, num_steps, num_categ] to give `log_xt_zt` its
  # documented shape.
  log_prob_zt = self.get_z_prior(sampled_z, log_prob_z0)

  ########################################
  ## getting log p(x[t] | z[t])
  ########################################

  emission_dist = self.x_emit(sampled_z)

  # `emission_dist` should have the same event shape as `inputs`; by
  # broadcasting, `log_prob_xt` should be of shape [batch_size, num_steps].
  log_prob_xt = emission_dist.log_prob(
      tf.reshape(inputs, [batch_size, num_steps, -1]))

  ########################################
  ## getting log p(s[t] |s[t-1], x[t-1])
  ########################################

  if switching_conditional_inputs is None:
    switching_conditional_inputs = inputs
  # Transitions exist only for steps 1..T-1, conditioned on x[0..T-2].
  log_prob_st_stm1 = tf.reshape(
      self.s_tran(switching_conditional_inputs[:, :-1, :]),
      [batch_size, num_steps - 1, self.num_categ, self.num_categ])
  # By normalizing the 3rd axis (axis=-2), we restrict A[:,:,i,j] to be
  # transiting from s[t-1]=j -> s[t]=i.
  log_prob_st_stm1 = utils.normalize_logprob(
      log_prob_st_stm1, axis=-2, temperature=temperature)[0]
  # Prepend a step-0 placeholder so the time axis matches `inputs`.
  # NOTE(review): `tf.eye` values are not a normalized log-distribution; the
  # step-0 slice must not be consumed as a real transition matrix.
  log_prob_st_stm1 = tf.concat([
      tf.eye(self.num_categ, self.num_categ, batch_shape=[batch_size, 1],
             dtype=tf.float32, name="concat_likelihoods"),
      log_prob_st_stm1
  ], axis=1)

  # log ( p(x_t | z_t) p(z_t | z_t-1, s_t) )
  log_xt_zt = log_prob_xt[:, :, tf.newaxis] + log_prob_zt
  return log_xt_zt, log_prob_st_stm1
def forward_pass(log_a, log_b, logprob_s0):
  """Computes the forward pass of the Baum-Welch algorithm in log space.

  All quantities, including the outputs, stay in log space and are combined
  with the log-sum-exp trick. Notation follows
  https://arxiv.org/abs/1910.09588:
    * `log_a` is `log p(s[t] | s[t-1], x[t-1])` (discrete transitions),
    * `log_b` is `log p(x[t], z[t] | s[t])` (observation likelihoods),
    * `logprob_s0` is `log p(s[0])` (initial discrete states).
  The forward pass yields the filtering distribution `log p(s_t | x_1:t)`.

  Args:
    log_a: a float `Tensor` of size [batch, num_steps, num_categ, num_categ]
      of time-dependent transition matrices; `A[i, j]` is the transition
      probability from `s[t-1]=j` to `s[t]=i`. `num_steps` is read from the
      static shape.
    log_b: a float `Tensor` of size [batch, num_steps, num_categ] of
      time-dependent emissions, `log p(x[t](, z[t]) | s[t])`.
    logprob_s0: a float `Tensor` of size [num_categ], the initial discrete
      state log-probability `log p(s[0])`.

  Returns:
    forward_pass: a float 3-D `Tensor` of size [batch, num_steps, num_categ]
      with the normalized forward log-probabilities `log p(s_t | x_1:t)`.
    normalizer: a float 2-D `Tensor` of size [batch, num_steps] with the
      per-step normalizers `log p(x_t | x_1:t-1)`.
  """
  num_steps = log_a.get_shape().with_rank_at_least(3).dims[1].value
  # One TensorArray per output of `utils.normalize_logprob`:
  # [normalized forward log-prob, log normalizer].
  arrays = [
      tf.TensorArray(tf.float32, num_steps, name=array_name)
      for array_name in ["forward_prob", "normalizer"]
  ]

  # Step 0: combine the prior over s[0] with the first emission.
  first = utils.normalize_logprob(
      logprob_s0[tf.newaxis, :] + log_b[:, 0, :], axis=-1)
  arrays = utils.write_updates_to_tas(arrays, 0, first)

  def _cond(t, *unused_args):
    return t < num_steps

  def _steps(t, prev_prob, fwd_tas):
    """Advances the forward recursion from step t - 1 to step t."""
    # log p(x[t], z[t] | s[t]), broadcast over the s[t-1] axis.
    emission = log_b[:, t, :][:, :, tf.newaxis]
    # log p(s[t] | s[t-1], x[t-1]); entry [i, j] is the transition j -> i.
    transition = log_a[:, t, :, :]
    joint = emission + transition + prev_prob[:, tf.newaxis, :]
    # Marginalize out s[t-1] (last axis), then renormalize over s[t].
    step_result = utils.normalize_logprob(
        tf.math.reduce_logsumexp(joint, axis=-1), axis=-1)
    return (t + 1, step_result[0],
            utils.write_updates_to_tas(fwd_tas, t, step_result))

  _, _, final_arrays = tf.while_loop(_cond, _steps, (1, first[0], arrays))

  # TensorArrays stack time-major; transpose to batch-major.
  forward_prob = tf.transpose(final_arrays[0].stack(), [1, 0, 2])
  normalizer = tf.transpose(
      tf.squeeze(final_arrays[1].stack(), axis=[-1]), [1, 0])
  return forward_prob, normalizer