Esempio n. 1
0
        def _step(t, loop_state, tas):
            """Run one iteration of the posterior-sampling `tf.while_loop`.

            Args:
              t: current time index; selects the slice of `inputs` and is the
                write position handed to `utils.write_updates_to_tas`.
              loop_state: state carried between iterations, exposing
                `latent_encoded` and `rnn_state` from the previous step.
              tas: accumulators that receive this step's results.

            Returns:
              A tuple `(t + 1, next_loop_state, updated_tas)`.
            """
            # Repeat the observation at step `t` so that multiple
            # trajectories are sampled for each example.
            tiled_obs = tf.tile(inputs[:, t, :], [num_samples, 1])

            # Shape: num_samples * BS, latent_dim+input_dim
            posterior_in = tf.concat(
                [tiled_obs, loop_state.latent_encoded], axis=-1)
            rnn_out, next_rnn_state = self.posterior_rnn(
                inputs=posterior_in, states=loop_state.rnn_state)

            # Draw the new latent first, then evaluate the statistics that
            # are recorded alongside it.
            posterior = self.posterior_dist(rnn_out)
            sampled_latent = posterior.sample(seed=random_seed)
            posterior_entropy = posterior.entropy()
            sample_log_prob = posterior.log_prob(sampled_latent)

            # `next_rnn_state` is a list of [batch_size, rnn_hidden_dim];
            # after TA.stack() the dimension will be
            # [num_steps, 1 for GRU/2 for LSTM, batch, rnn_dim].
            step_outputs = [
                next_rnn_state,
                sampled_latent,
                posterior_entropy,
                sample_log_prob,
            ]
            tas = utils.write_updates_to_tas(tas, t, step_outputs)

            next_loop_state = loopstate(rnn_state=next_rnn_state,
                                        latent_encoded=sampled_latent)
            return (t + 1, next_loop_state, tas)
Esempio n. 2
0
    def _steps(t, prev_prob, fwd_tas):
        """Advance the forward recursion by one time step.

        Args:
          t: index of the step being computed; also the write position.
          prev_prob: normalized forward log-probability from the previous
            step, indexed as [batch, state].
          fwd_tas: accumulators receiving this step's results.

        Returns:
          A tuple `(t + 1, normalized_forward_logprob, updated_tas)`.
        """
        emission_t = log_b[:, t, :]  # slice of `log_b` at step t
        transition_t = log_a[:, t, :, :]  # slice of `log_a` at step t

        # Broadcast to [batch, i, j]: emission on the destination state i,
        # previous probability on the source state j; then reduce over j.
        joint = (emission_t[:, :, tf.newaxis] + transition_t +
                 prev_prob[:, tf.newaxis, :])
        unnormalized = tf.math.reduce_logsumexp(joint, axis=-1)

        # Element 0 of the result is fed into the next iteration as the
        # new `prev_prob`; the whole result is written to the accumulators.
        normalized = utils.normalize_logprob(unnormalized, axis=-1)
        fwd_tas = utils.write_updates_to_tas(fwd_tas, t, normalized)

        return (t + 1, normalized[0], fwd_tas)
Esempio n. 3
0
    def _steps(t, next_prob, bwd_tas):
        """Move the backward recursion one time step towards the start.

        Args:
          t: index of the step being computed; also the write position.
          next_prob: normalized backward log-probability from step `t + 1`,
            indexed as [batch, state].
          bwd_tas: accumulators receiving this step's results.

        Returns:
          A tuple `(t - 1, normalized_backward_logprob, updated_tas)`.
        """
        emission_next = log_b[:, t + 1, :]  # log p(x[t+1] | s[t+1])
        transition_next = log_a[:, t + 1, :, :]  # log p(s[t+1] | s[t], x[t])

        # Broadcast to [batch, i, j] with both the next-step probability and
        # the emission attached to state i = s[t+1]; reduce over i (axis -2).
        joint = (next_prob[:, :, tf.newaxis] + transition_next +
                 emission_next[:, :, tf.newaxis])
        unnormalized = tf.math.reduce_logsumexp(joint, axis=-2)

        # Element 0 of the result is fed into the next iteration as the
        # new `next_prob`; the whole result is written to the accumulators.
        normalized = utils.normalize_logprob(unnormalized, axis=-1)
        bwd_tas = utils.write_updates_to_tas(bwd_tas, t, normalized)

        return (t - 1, normalized[0], bwd_tas)
Esempio n. 4
0
def backward_pass(log_a, log_b, logprob_s0):
    """Run the backward recursion of the Baum-Welch algorithm in log space.

  The recursion starts from all-zero log-probabilities at the last time step
  and walks back to step 0, normalizing the result at every step.

  Args:
    log_a: a float `Tensor` of size [batch, num_steps, num_categ, num_categ]
      holding time dependent transition matrices,
      `log p(s[t] | s[t-1], x[t-1])`. `A[:, t, i, j]` is the transition
      probability from `s[t-1]=j` to `s[t]=i`. Because `A[:, t, :, :]` uses
      information from `t-1`, `A[:, 0, :, :]` is meaningless and is never
      read by this function.
    log_b: a float `Tensor` of size [batch, num_steps, num_categ] holding
      time dependent emission matrices, `log p(x[t](, z[t]) | s[t])`.
    logprob_s0: a float `Tensor` of size [num_categ], `log p(s[0])`. Only its
      static leading dimension is read here, to obtain `num_categ`.

  Returns:
    backward_prob: a float `Tensor` of size [batch, num_steps, num_categ]
      with the normalized backward log-probability of every step.
    normalizer: a float `Tensor` of size [batch, num_steps] with the
      per-step normalizing term (the trailing singleton axis is squeezed).
  """
    num_categ = logprob_s0.shape[0]
    dynamic_shape = tf.shape(log_b)
    batch_size, num_steps = dynamic_shape[0], dynamic_shape[1]

    arrays = [
        tf.TensorArray(tf.float32, num_steps, name="backward_prob"),
        tf.TensorArray(tf.float32, num_steps, name="normalizer"),
    ]

    # Terminal condition: zero log-probability and zero normalizer at the
    # final time step.
    terminal_prob = tf.zeros([batch_size, num_categ])
    terminal_updates = [terminal_prob, tf.zeros([batch_size, 1])]
    arrays = utils.write_updates_to_tas(arrays, num_steps - 1,
                                        terminal_updates)

    def _not_done(t, *unused_args):
        """Keep iterating until step 0 has been processed."""
        return t >= 0

    def _backward_step(t, next_prob, bwd_tas):
        """Compute the backward log-probability for step `t`."""
        emission_next = log_b[:, t + 1, :]  # log p(x[t+1] | s[t+1])
        transition_next = log_a[:, t + 1, :, :]  # log p(s[t+1] | s[t], x[t])

        # Broadcast to [batch, i, j] with i = s[t+1]; reduce over i.
        joint = (next_prob[:, :, tf.newaxis] + transition_next +
                 emission_next[:, :, tf.newaxis])
        normalized = utils.normalize_logprob(
            tf.math.reduce_logsumexp(joint, axis=-2), axis=-1)

        bwd_tas = utils.write_updates_to_tas(bwd_tas, t, normalized)
        return (t - 1, normalized[0], bwd_tas)

    # The last step is already written, so the loop starts one step earlier.
    loop_vars = (num_steps - 2, terminal_prob, arrays)
    filled = tf.while_loop(_not_done, _backward_step, loop_vars)[2]

    # Stacked arrays are step-major; move the batch axis to the front.
    backward_prob = tf.transpose(filled[0].stack(), [1, 0, 2])
    stacked_normalizer = tf.squeeze(filled[1].stack(), axis=[-1])
    normalizer = tf.transpose(stacked_normalizer, [1, 0])

    return backward_prob, normalizer
Esempio n. 5
0
def forward_pass(log_a, log_b, logprob_s0):
    """Computing the forward pass of Baum-Welch Algorithm.

  By employing log-exp-sum trick, values are computed in log space, including
  the output. Notation is adopted from https://arxiv.org/abs/1910.09588.
  `log_a` is the likelihood of discrete states, `log p(s[t] | s[t-1], x[t-1])`,
  `log_b` is the likelihood of observations, `log p(x[t], z[t] | s[t])`,
  and `logprob_s0` is the likelihood of initial discrete states, `log p(s[0])`.
  Forward pass calculates the filtering likelihood of `log p(s_t | x_1:t)`.

  The number of steps is taken from the static shape of `log_a` when it is
  known; otherwise it is read dynamically with `tf.shape`, so inputs with an
  unknown time dimension are supported as well.

  Args:
    log_a: a float `Tensor` of size [batch, num_steps, num_categ, num_categ]
      stores time dependent transition matrices, `log p(s[t] | s[t-1], x[t-1])`.
      `A[i, j]` is the transition probability from `s[t-1]=j` to `s[t]=i`.
      `log_a[:, 0]` is never read.
    log_b: a float `Tensor` of size [batch, num_steps, num_categ] stores time
      dependent emission matrices, `log p(x[t](, z[t])| s[t])`.
    logprob_s0: a float `Tensor` of size [num_categ], initial discrete states
      probability, `log p(s[0])`.

  Returns:
    forward_pass: a float 3D `Tensor` of size [batch, num_steps, num_categ]
      stores the forward pass probability of `log p(s_t | x_1:t)`, which is
      normalized.
    normalizer: a float 2D `Tensor` of size [batch, num_steps] stores the
      normalizing probability, `log p(x_t | x_1:t-1)`.

  Raises:
    ValueError: if the static rank of `log_a` is known and smaller than 3.
  """
    static_shape = log_a.get_shape().with_rank_at_least(3)
    num_steps = None
    if static_shape.dims is not None:
        num_steps = static_shape.dims[1].value
    if num_steps is None:
        # The time dimension is not known statically; use the runtime value
        # so that the TensorArray size and the loop condition stay valid.
        num_steps = tf.shape(log_a)[1]

    tas = [
        tf.TensorArray(tf.float32, num_steps, name=n)
        for n in ["forward_prob", "normalizer"]
    ]

    # The function will return normalized forward probability and
    # normalizing constant as a list, [forward_logprob, normalizer].
    init_updates = utils.normalize_logprob(logprob_s0[tf.newaxis, :] +
                                           log_b[:, 0, :],
                                           axis=-1)

    tas = utils.write_updates_to_tas(tas, 0, init_updates)

    # Step 0 is already written, so the loop starts at t = 1.
    prev_prob = init_updates[0]
    init_state = (1, prev_prob, tas)

    def _cond(t, *unused_args):
        """Continue until every time step has been processed."""
        return t < num_steps

    def _steps(t, prev_prob, fwd_tas):
        """One step forward in iterations."""
        bi_t = log_b[:, t, :]  # log p(x[t] | s[t])
        aij_t = log_a[:, t, :, :]  # log p(s[t] | s[t-1], x[t-1])

        # Broadcast to [batch, i, j] with i = s[t] and j = s[t-1], then
        # marginalize the previous state j.
        current_updates = tf.math.reduce_logsumexp(
            bi_t[:, :, tf.newaxis] + aij_t + prev_prob[:, tf.newaxis, :],
            axis=-1)
        current_updates = utils.normalize_logprob(current_updates, axis=-1)

        prev_prob = current_updates[0]
        fwd_tas = utils.write_updates_to_tas(fwd_tas, t, current_updates)

        return (t + 1, prev_prob, fwd_tas)

    _, _, tas_final = tf.while_loop(_cond, _steps, init_state)

    # transpose to [batch, step, state]
    forward_prob = tf.transpose(tas_final[0].stack(), [1, 0, 2])
    # Drop the trailing singleton axis of the normalizer before transposing
    # to [batch, step].
    normalizer = tf.transpose(tf.squeeze(tas_final[1].stack(), axis=[-1]),
                              [1, 0])
    return forward_prob, normalizer