Example #1
0
def forward_backward(log_a, log_b, log_init):
    """Runs the forward-backward algorithm in log space.

    Args:
      log_a: log transition matrices, indexed here as a rank-4 tensor
        [batch, num_steps, num_categ, num_categ]; forwarded unchanged to
        `forward_pass` and `backward_pass`.
      log_b: log emission scores, indexed here as [batch, num_steps, num_categ].
      log_init: log initial discrete-state distribution, forwarded unchanged to
        `forward_pass` and `backward_pass`.

    Returns:
      A tuple `(fwd, bwd, posterior, gamma_ij)`:
        fwd, bwd: the forward and backward log messages.
        posterior: `fwd + bwd` normalized over the last axis.
        gamma_ij: the pairwise score for steps 1.. normalized jointly over the
          last two axes, with a zero slice prepended at step 0 so it has as
          many time steps as `log_a`.
    """
    fwd, _ = forward_pass(log_a, log_b, log_init)
    bwd, _ = backward_pass(log_a, log_b, log_init)

    # Single-step posterior: combine both messages, normalize over states.
    posterior, _ = utils.normalize_logprob(fwd + bwd, axis=-1)

    # Pairwise score for t >= 1, laid out as [..., i, j] with the forward
    # message of step t-1 broadcast along i and everything from step t along j.
    pair_score = (fwd[:, :-1, tf.newaxis, :] +
                  log_a[:, 1:, :, :] +
                  bwd[:, 1:, :, tf.newaxis] +
                  log_b[:, 1:, :, tf.newaxis])
    gamma_ij, _ = utils.normalize_logprob(pair_score, axis=[-2, -1])

    # Step 0 has no predecessor; prepend zeros so the time axis matches log_a.
    a_shape = tf.shape(log_a)
    leading_pad = tf.zeros([a_shape[0], 1, a_shape[2], a_shape[3]])
    gamma_ij = tf.concat([leading_pad, gamma_ij], axis=1, name="concat_f_b")

    return fwd, bwd, posterior, gamma_ij
  def test_normalize_logprob(self):
    """Checks log-normalization and the effect of a huge temperature."""
    # Case 1: normalizing log-probabilities along the last axis must equal
    # log(p / sum(p)) computed directly in numpy.
    probs = np.random.uniform(low=1e-6, high=1., size=[2, 3, 6])
    expected = np.log(probs) - np.log(
        np.sum(probs, axis=-1, keepdims=True))
    normalized = utils.normalize_logprob(np.log(probs))[0]
    self.assertAllClose(self.evaluate(normalized), expected)

    # Case 2: with temperature 1e5 the result is expected to be (nearly)
    # uniform over the three entries.
    logits = np.log([0.1, 0.3, 0.5])
    uniform = np.log([1./3., 1./3., 1./3.])
    flattened = utils.normalize_logprob(logits, temperature=1e5)[0]
    self.assertAllClose(
        self.evaluate(flattened), uniform, rtol=1e-4, atol=1e-4)
Example #3
0
    def _steps(t, prev_prob, fwd_tas):
        """Advances the forward recursion by one time step.

        Args:
          t: the time index being filled in.
          prev_prob: normalized forward log-messages of step t-1; broadcast
            along the last (j) axis of the transition matrix.
          fwd_tas: TensorArrays receiving this step's outputs.

        Returns:
          `(t + 1, new_prev_prob, fwd_tas)` for the next loop iteration.
        """
        emission = log_b[:, t, :]  # log p(x[t] | s[t])
        # [i, j] entry: log p(s[t]=i | s[t-1]=j, x[t-1]).
        transition = log_a[:, t, :, :]

        # Marginalize s[t-1] (the last axis) out of the joint score.
        joint = (emission[:, :, tf.newaxis] + transition +
                 prev_prob[:, tf.newaxis, :])
        updates = utils.normalize_logprob(
            tf.math.reduce_logsumexp(joint, axis=-1), axis=-1)

        new_prev_prob = updates[0]
        fwd_tas = utils.write_updates_to_tas(fwd_tas, t, updates)
        return (t + 1, new_prev_prob, fwd_tas)
Example #4
0
    def _steps(t, next_prob, bwd_tas):
        """Moves the backward recursion one time step earlier.

        Args:
          t: the time index whose backward message is produced.
          next_prob: normalized backward log-messages of step t+1; broadcast
            along axis -2 (the s[t+1] axis) of the transition matrix.
          bwd_tas: TensorArrays receiving this step's outputs.

        Returns:
          `(t - 1, new_next_prob, bwd_tas)` for the next loop iteration.
        """
        emission_next = log_b[:, t + 1, :]  # log p(x[t+1] | s[t+1])
        transition_next = log_a[:, t + 1, :, :]  # log p(s[t+1] | s[t], x[t])

        # Sum s[t+1] (axis -2 of the [batch, i, j] score) out of the score.
        score = (next_prob[:, :, tf.newaxis] + transition_next +
                 emission_next[:, :, tf.newaxis])
        updates = utils.normalize_logprob(
            tf.math.reduce_logsumexp(score, axis=-2), axis=-1)

        new_next_prob = updates[0]
        bwd_tas = utils.write_updates_to_tas(bwd_tas, t, updates)
        return (t - 1, new_next_prob, bwd_tas)
    def __init__(self,
                 continuous_transition_network,
                 discrete_transition_network,
                 emission_network,
                 inference_network,
                 initial_distribution,
                 continuous_state_dim=None,
                 num_categories=None,
                 discrete_state_prior=None):
        """Builds a Switching Non-Linear Dynamical System (SNLDS).

        Follows the model framework of Dong et al. (2019) [1].

        Args:
          continuous_transition_network: a `callable` whose `call` takes
            batched sequences of continuous hidden states, `z[t-1]`, of shape
            [batch_size, num_steps, hidden_states], and returns a distribution
            implementing `log_prob`; that `log_prob` takes continuous hidden
            states, `z[t]`, and returns `p(z[t] | z[t-1], s[t])`.
          discrete_transition_network: a `callable` whose `call` takes batched
            conditional inputs, `x[t-1]`, and returns the discrete state
            transition matrices, `log p(s[t] | s[t-1], x[t-1])`.
          emission_network: a `callable` whose `call` takes continuous hidden
            states, `z[t]`, and returns a distribution, `p(x[t] | z[t])`, that
            provides `mean` and `sample` (like `tfp.distributions` classes).
          inference_network: an object with a `sample` function that takes
            observations, `x[1:T]`, and returns sampled hidden state sequences
            of `q(z[1:T] | x[1:T])` together with that distribution's entropy.
          initial_distribution: an initial state distribution for the
            continuous variables, `p(z[0])`.
          continuous_state_dim: number of continuous hidden units, `z[t]`.
            When `None`, read from
            `continuous_transition_network.output_event_dims`.
          num_categories: number of discrete hidden states, `s[t]`. When
            `None`, read from `discrete_transition_network.output_event_dims`.
          discrete_state_prior: a `float` Tensor giving the prior of the
            discrete state distribution, `p[k] = p(s[t]=k)`, used by a cross
            entropy regularizer that tries to minimize the difference between
            this prior and the smoothed likelihood of the discrete states,
            `p(s[t] | x[1:T], z[1:T])`. When `None`, a uniform prior is used.

        Reference:
          [1] Dong, Zhe and Seybold, Bryan A. and Murphy, Kevin P., and Bui,
              Hung H.. Collapsed Amortized Variational Inference for Switching
              Nonlinear Dynamical Systems. 2019.
              https://arxiv.org/abs/1910.09588.
        """
        super(SwitchingNLDS, self).__init__()

        # Sub-networks and the initial continuous-state distribution.
        self.z_tran = continuous_transition_network
        self.s_tran = discrete_transition_network
        self.x_emit = emission_network
        self.inference_network = inference_network
        self.z0_dist = initial_distribution

        # Sizes fall back to the networks' declared output dimensions.
        self.num_categ = (self.s_tran.output_event_dims
                          if num_categories is None else num_categories)
        self.z_dim = (self.z_tran.output_event_dims
                      if continuous_state_dim is None
                      else continuous_state_dim)

        # Uniform discrete-state prior unless one is supplied.
        if discrete_state_prior is not None:
            self.discrete_prior = discrete_state_prior
        else:
            self.discrete_prior = tf.ones(shape=[self.num_categ],
                                          dtype=tf.float32) / self.num_categ

        # Variable holding log p(s[0]); initialized by normalizing a vector of
        # ones over its only axis.
        ones_logits = tf.ones(shape=[self.num_categ], dtype=tf.float32)
        self.log_init = tf.Variable(
            utils.normalize_logprob(ones_logits, axis=-1)[0],
            name="snlds_logprob_s0")
    def calculate_likelihoods(self,
                              inputs,
                              sampled_z,
                              switching_conditional_inputs=None,
                              temperature=1.0):
        """Calculate the probability by p network, `p_theta(x,z,s)`.

        Args:
          inputs: a float 3-D `Tensor` of shape [batch_size, num_steps,
            obs_dim], containing the observation time series of the model.
          sampled_z: a float 3-D `Tensor` of shape [batch_size, num_steps,
            latent_dim] for continuous hidden states, which are sampled from
            inference networks, `q(z[1:T] | x[1:T])`.
          switching_conditional_inputs: a float 3-D `Tensor` of shape
            [batch_size, num_steps, encoded_dim], the conditional input for
            the discrete state transition probability,
            `p(s[t] | s[t-1], x[t-1])`. Defaults to `None`, in which case
            `inputs` is used.
          temperature: a float scalar `Tensor`, the temperature applied when
            normalizing the transition probability,
            `p(s[t] | s[t-1], x[t-1])`.

        Returns:
          A tuple `(log_xt_zt, log_prob_st_stm1)`:
          log_xt_zt: a float `Tensor` of size [batch_size, num_steps,
            num_categ] holding `log(p(x_t | z_t) p(z_t | z_t-1, s_t))`.
          log_prob_st_stm1: a float `Tensor` of size [batch_size, num_steps,
            num_categ, num_categ] holding the log transition probability,
            `log p(s_t | s_t-1, x_t-1)`. Steps 1.. are normalized over axis -2
            (`[:, t, i, j]` is the transition s[t-1]=j -> s[t]=i); step 0 is a
            `tf.eye` placeholder because no transition leads into it.
        """
        batch_size, dynamic_num_steps = tf.unstack(tf.shape(inputs)[:2])
        # Prefer the static time length; fall back to the runtime length when
        # the static shape is unknown so `num_steps - 1` never sees `None`.
        num_steps = inputs.get_shape().with_rank_at_least(3).dims[1].value
        if num_steps is None:
            num_steps = dynamic_num_steps

        ########################################
        ## getting log p(z[t] | z[t-1], s[t])
        ########################################

        # Broadcasting rules of TFP dictate that: if sampled_z0 is of dimension
        # [batch_size, 1, event_size] and z0_dist is of [num_categ, event_size],
        # z0_dist.log_prob(sampled_z0[:, None, :]) is of [batch_size, num_categ].
        sampled_z0 = sampled_z[:, 0, :]
        log_prob_z0 = self.z0_dist.log_prob(sampled_z0[:, tf.newaxis, :])
        log_prob_z0 = log_prob_z0[:, tf.newaxis, :]

        # `log_prob_zt` is presumably [batch_size, num_steps, num_categ] (one
        # log-density per discrete state, like `log_prob_z0`) -- TODO confirm
        # against `get_z_prior`.
        log_prob_zt = self.get_z_prior(sampled_z, log_prob_z0)

        ########################################
        ## getting log p(x[t] | z[t])
        ########################################

        emission_dist = self.x_emit(sampled_z)

        # `emission_dist` should have the same event shape as `inputs`; by
        # broadcasting rules, `log_prob_xt` should be [batch_size, num_steps].
        log_prob_xt = emission_dist.log_prob(
            tf.reshape(inputs, [batch_size, num_steps, -1]))

        ########################################
        ## getting log p(s[t] |s[t-1], x[t-1])
        ########################################

        if switching_conditional_inputs is None:
            switching_conditional_inputs = inputs
        log_prob_st_stm1 = tf.reshape(
            self.s_tran(switching_conditional_inputs[:, :-1, :]),
            [batch_size, num_steps - 1, self.num_categ, self.num_categ])
        # by normalizing the 3rd axis (axis=-2), we restrict A[:,:,i,j] to be
        # transiting from s[t-1]=j -> s[t]=i
        log_prob_st_stm1 = utils.normalize_logprob(log_prob_st_stm1,
                                                   axis=-2,
                                                   temperature=temperature)[0]

        # Prepend a placeholder for step 0 so the time axis has num_steps
        # entries. NOTE(review): the values here look unused downstream
        # (forward_pass starts its recursion at t=1) -- confirm before relying
        # on them.
        log_prob_st_stm1 = tf.concat([
            tf.eye(self.num_categ,
                   self.num_categ,
                   batch_shape=[batch_size, 1],
                   dtype=tf.float32,
                   name="concat_likelihoods"), log_prob_st_stm1
        ],
                                     axis=1)

        # log ( p(x_t | z_t) p(z_t | z_t-1, s_t) )
        log_xt_zt = log_prob_xt[:, :, tf.newaxis] + log_prob_zt
        return log_xt_zt, log_prob_st_stm1
Example #7
0
def forward_pass(log_a, log_b, logprob_s0):
    """Computes the forward pass of the Baum-Welch algorithm.

    Using the log-sum-exp trick, all values, including the output, are
    computed in log space. Notation follows https://arxiv.org/abs/1910.09588.
    `log_a` is the likelihood of discrete states, `log p(s[t] | s[t-1], x[t-1])`,
    `log_b` is the likelihood of observations, `log p(x[t], z[t] | s[t])`,
    and `logprob_s0` is the likelihood of initial discrete states, `log p(s[0])`.
    The forward pass computes the filtering likelihood `log p(s_t | x_1:t)`.

    Args:
      log_a: a float `Tensor` of size [batch, num_steps, num_categ, num_categ]
        storing time-dependent transition matrices,
        `log p(s[t] | s[t-1], x[t-1])`. `A[i, j]` is the transition probability
        from `s[t-1]=j` to `s[t]=i`. Step 0 is not read.
      log_b: a float `Tensor` of size [batch, num_steps, num_categ] storing
        time-dependent emission matrices, `log p(x[t](, z[t]) | s[t])`.
      logprob_s0: a float `Tensor` of size [num_categ], the initial discrete
        state probability, `log p(s[0])`.

    Returns:
      forward_pass: a float 3D `Tensor` of size [batch, num_steps, num_categ]
        storing the normalized forward log-probability `log p(s_t | x_1:t)`.
      normalizer: a float 2D `Tensor` of size [batch, num_steps] storing the
        normalizing log-probability, `log p(x_t | x_1:t-1)`.
    """
    num_steps = log_a.get_shape().with_rank_at_least(3).dims[1].value
    if num_steps is None:
        # Static length unknown (e.g. a `None` time axis); use the runtime
        # length so the TensorArray sizes and the loop bound are defined.
        num_steps = tf.shape(log_a)[1]

    tas = [
        tf.TensorArray(tf.float32, num_steps, name=n)
        for n in ["forward_prob", "normalizer"]
    ]

    # The function will return normalized forward probability and
    # normalizing constant as a list, [forward_logprob, normalizer].
    init_updates = utils.normalize_logprob(logprob_s0[tf.newaxis, :] +
                                           log_b[:, 0, :],
                                           axis=-1)

    tas = utils.write_updates_to_tas(tas, 0, init_updates)

    prev_prob = init_updates[0]
    init_state = (1, prev_prob, tas)

    def _cond(t, *unused_args):
        return t < num_steps

    def _steps(t, prev_prob, fwd_tas):
        """One step forward: folds step t into the filtering distribution."""
        bi_t = log_b[:, t, :]  # log p(x[t] | s[t])
        # [i, j] entry: log p(s[t]=i | s[t-1]=j, x[t-1]).
        aij_t = log_a[:, t, :, :]

        # Marginalize s[t-1] (the last axis) out of the joint score.
        current_updates = tf.math.reduce_logsumexp(
            bi_t[:, :, tf.newaxis] + aij_t + prev_prob[:, tf.newaxis, :],
            axis=-1)
        current_updates = utils.normalize_logprob(current_updates, axis=-1)

        prev_prob = current_updates[0]
        fwd_tas = utils.write_updates_to_tas(fwd_tas, t, current_updates)

        return (t + 1, prev_prob, fwd_tas)

    _, _, tas_final = tf.while_loop(_cond, _steps, init_state)

    # Stacked along time first; transpose to [batch, step, state].
    forward_prob = tf.transpose(tas_final[0].stack(), [1, 0, 2])
    # The normalizer carries a trailing singleton axis; drop it, then
    # transpose to [batch, step].
    normalizer = tf.transpose(tf.squeeze(tas_final[1].stack(), axis=[-1]),
                              [1, 0])
    return forward_prob, normalizer