def _extract_log_probs(num_states, dist):
    """Tabulate log probabilities from a batch of distributions."""

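    # Reshape the state indices [0, ..., num_states - 1] to shape
    # [num_states, 1, ..., 1] (one trailing 1 per batch dimension of `dist`)
    # so that `dist.log_prob` broadcasts every state against the whole batch;
    # the leading `num_states` axis is then moved to be the rightmost axis.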
    states = tf.reshape(
        tf.range(num_states),
        tf.concat([[num_states],
                   tf.ones_like(dist.batch_shape_tensor())],
                  axis=0))
    return distribution_util.move_dimension(dist.log_prob(states), 0, -1)
                    def _reduce_multiple_steps():
                        """Perform `reduce_max` operation when `num_steps` > 1."""
                        def forward_step(previous_step_pair,
                                         log_prob_observation):
                            log_prob_previous = previous_step_pair[0]
                            log_prob = (
                                log_prob_previous[..., tf.newaxis] +
                                self._log_trans +
                                log_prob_observation[..., tf.newaxis, :])
                            most_likely_given_successor = tf.argmax(log_prob,
                                                                    axis=-2)
                            max_log_p_given_successor = tf.reduce_max(
                                input_tensor=log_prob, axis=-2)
                            return (max_log_p_given_successor,
                                    most_likely_given_successor)

                        forward_log_probs, all_most_likely_given_successor = tf.scan(
                            forward_step,
                            observation_log_probs[1:],
                            initializer=(log_prob,
                                         tf.zeros(tf.shape(log_prob),
                                                  dtype=tf.int64)),
                            name="forward_log_probs")

                        most_likely_end = tf.argmax(forward_log_probs[-1],
                                                    axis=-1)
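                        # most_likely_end :: batch_shape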

                        # We require the operation that gives C from A and B where
                        # C[i...j] = A[i...j, B[i...j]]
                        # and A = most_likely_given_successor
                        #     B = most_likely_successor.
                        # tf.gather requires indices of known shape, so instead
                        # we use a reduction with tf.one_hot(B) to pick out the
                        # required elements of A.
                        def backward_step(most_likely_successor,
                                          most_likely_given_successor):
                            return tf.reduce_sum(
                                input_tensor=(most_likely_given_successor *
                                              tf.one_hot(most_likely_successor,
                                                         self._num_states,
                                                         dtype=tf.int64)),
                                axis=-1)

                        backward_scan = tf.scan(
                            backward_step,
                            all_most_likely_given_successor,
                            most_likely_end,
                            reverse=True)
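                        # backward_scan :: num_steps - 1 batch_shape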
                        most_likely_sequences = tf.concat(
                            [backward_scan, [most_likely_end]], axis=0)
                        return distribution_util.move_dimension(
                            most_likely_sequences, 0, -1)

    def _log_prob(self, value):
        with tf.control_dependencies(self._runtime_assertions):
            # The argument `value` is a tensor of sequences of observations.
            # `observation_batch_shape` is the shape of that tensor with the
            # sequence part removed.
            # `observation_batch_shape` is then broadcast to the full batch shape
            # to give the `batch_shape` that defines the shape of the result.

            observation_tensor_shape = tf.shape(value)
            observation_batch_shape = observation_tensor_shape[
                :-1 - self._underlying_event_rank]
            # value :: observation_batch_shape num_steps observation_event_shape
            batch_shape = tf.broadcast_dynamic_shape(observation_batch_shape,
                                                     self.batch_shape_tensor())
            log_init = tf.broadcast_to(
                self._log_init,
                tf.concat([batch_shape, [self._num_states]], axis=0))
            # log_init :: batch_shape num_states
            log_transition = self._log_trans

            # `observation_event_shape` is the shape of each sequence of observations
            # emitted by the model.
            observation_event_shape = observation_tensor_shape[
                -1 - self._underlying_event_rank:]
            working_obs = tf.broadcast_to(
                value, tf.concat([batch_shape, observation_event_shape],
                                 axis=0))
            # working_obs :: batch_shape observation_event_shape
            r = self._underlying_event_rank

            # Move index into sequence of observations to front so we can apply
            # tf.foldl
            working_obs = distribution_util.move_dimension(
                working_obs, -1 - r, 0)
            # working_obs :: num_steps batch_shape underlying_event_shape
            working_obs = tf.expand_dims(working_obs, -1 - r)
            # working_obs :: num_steps batch_shape 1 underlying_event_shape
            observation_probs = (
                self._observation_distribution.log_prob(working_obs))
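            # observation_probs :: num_steps batch_shape num_states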

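            # Forward pass of the forward algorithm: at each step, "multiply"
            # the accumulated log probs by the transition matrix in log space
            # (via logsumexp) and add the emission log probs. `tf.foldl`
            # keeps only the final accumulator, i.e. the joint log prob of
            # all observations and the final hidden state.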
            def forward_step(log_prev_step, log_prob_observation):
                return _log_vector_matrix(
                    log_prev_step, log_transition) + log_prob_observation

            fwd_prob = tf.foldl(forward_step,
                                observation_probs,
                                initializer=log_init)
            # fwd_prob :: batch_shape num_states

            log_prob = tf.reduce_logsumexp(fwd_prob, axis=-1)
            # log_prob :: batch_shape

            return log_prob

    def posterior_marginals(self, observations, mask=None, name=None):
        """Compute marginal posterior distribution for each state.

    This function computes, for each time step, the marginal
    conditional probability that the hidden Markov model was in
    each possible state given the observations that were made
    at each time step.
    So if the hidden states are `z[0],...,z[num_steps - 1]` and
    the observations are `x[0], ..., x[num_steps - 1]`, then
    this function computes `P(z[i] | x[0], ..., x[num_steps - 1])`
    for all `i` from `0` to `num_steps - 1`.

    This operation is sometimes called smoothing. It uses a form
    of the forward-backward algorithm.

    Note: the behavior of this function is undefined if the
    `observations` argument represents impossible observations
    from the model.

    Args:
      observations: A tensor representing a batch of observations
        made on the hidden Markov model.  The rightmost dimension of this tensor
        gives the steps in a sequence of observations from a single sample from
        the hidden Markov model. The size of this dimension should match the
        `num_steps` parameter of the hidden Markov model object. The other
        dimensions are the dimensions of the batch and these are broadcast with
        the hidden Markov model's parameters.
      mask: optional bool-type `tensor` with rightmost dimension matching
        `num_steps` indicating which observations the result of this
        function should be conditioned on. When the mask has value
        `True` the corresponding observations aren't used.
        If `mask` is `None` then all of the observations are used.
        The `mask` dimensions left of the last are broadcast with the
        HMM batch as well as with the observations.
      name: Python `str` name prefixed to Ops created by this function.
        Default value: "posterior_marginals".

    Returns:
      posterior_marginal: A `Categorical` distribution object representing the
        marginal probability of the hidden Markov model being in each state at
        each step. The rightmost dimension of the `Categorical` distribution's
        batch shape will equal the `num_steps` parameter, providing one
        marginal distribution for each step. The other dimensions are the
        dimensions corresponding to the batch of observations.

    Raises:
      ValueError: if the rightmost dimension of `observations` does not
        have size `num_steps`.
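
    Example:
      A minimal, illustrative sketch: the parameter values below are made up
      and the standard TensorFlow Probability import aliases are assumed.

      ```python
      import tensorflow as tf
      import tensorflow_probability as tfp
      tfd = tfp.distributions

      # A two-state HMM emitting normally distributed observations.
      model = tfd.HiddenMarkovModel(
          initial_distribution=tfd.Categorical(probs=[0.8, 0.2]),
          transition_distribution=tfd.Categorical(probs=[[0.7, 0.3],
                                                         [0.2, 0.8]]),
          observation_distribution=tfd.Normal(loc=[0., 15.],
                                              scale=[5., 10.]),
          num_steps=7)

      # Marginal posterior over the hidden state at each of the 7 steps,
      # given one observed sequence of length `num_steps`.
      posterior = model.posterior_marginals(
          tf.constant([-2., 0., 2., 4., 6., 8., 10.]))
      # The result is a Categorical over the 2 hidden states with batch
      # shape [7]; its (normalized) logits have shape [7, 2].
      marginal_logits = posterior.logits
      ```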
    """

        with tf.name_scope(name or "posterior_marginals"):
            with tf.control_dependencies(self._runtime_assertions):
                observation_tensor_shape = tf.shape(observations)
                mask_tensor_shape = (
                    tf.shape(mask) if mask is not None else None)

                with self._observation_mask_shape_preconditions(
                        observation_tensor_shape, mask_tensor_shape):
                    observation_log_probs = self._observation_log_probs(
                        observations, mask)
                    log_prob = self._log_init + observation_log_probs[0]
                    log_transition = self._log_trans
                    log_adjoint_prob = tf.zeros_like(log_prob)

                    def _scan_multiple_steps_forwards():
                        def forward_step(log_previous_step,
                                         log_prob_observation):
                            return _log_vector_matrix(
                                log_previous_step,
                                log_transition) + log_prob_observation

                        forward_log_probs = tf.scan(forward_step,
                                                    observation_log_probs[1:],
                                                    initializer=log_prob,
                                                    name="forward_log_probs")
                        return tf.concat([[log_prob], forward_log_probs],
                                         axis=0)

                    forward_log_probs = prefer_static.cond(
                        self._num_steps > 1, _scan_multiple_steps_forwards,
                        lambda: tf.convert_to_tensor([log_prob]))
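                    # forward_log_probs :: num_steps batch_shape num_states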

                    total_log_prob = tf.reduce_logsumexp(forward_log_probs[-1],
                                                         axis=-1)
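                    # total_log_prob :: batch_shape
                    # (the log marginal likelihood of the observations, used
                    # below to normalize the posterior marginals)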

                    def _scan_multiple_steps_backwards():
                        """Perform `scan` operation when `num_steps` > 1."""
                        def backward_step(log_previous_step,
                                          log_prob_observation):
                            return _log_matrix_vector(
                                log_transition,
                                log_prob_observation + log_previous_step)

                        backward_log_adjoint_probs = tf.scan(
                            backward_step,
                            observation_log_probs[1:],
                            initializer=log_adjoint_prob,
                            reverse=True,
                            name="backward_log_adjoint_probs")

                        return tf.concat(
                            [backward_log_adjoint_probs, [log_adjoint_prob]],
                            axis=0)

                    backward_log_adjoint_probs = prefer_static.cond(
                        self._num_steps > 1, _scan_multiple_steps_backwards,
                        lambda: tf.convert_to_tensor([log_adjoint_prob]))

                    log_likelihoods = (
                        forward_log_probs + backward_log_adjoint_probs)
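                    # log_likelihoods :: num_steps batch_shape num_states
                    # Adding the forward and backward messages gives the
                    # unnormalized posterior marginal log probabilities.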

                    marginal_log_probs = distribution_util.move_dimension(
                        log_likelihoods - total_log_prob[..., tf.newaxis], 0,
                        -2)

                    return categorical.Categorical(logits=marginal_log_probs)

    def _observation_log_probs(self, observations, mask):
        """Compute and shape tensor of log probs associated with observations.."""

        # Let E be the underlying event shape
        #     M the number of steps in the HMM
        #     N the number of states of the HMM
        #
        # Then the incoming observations have shape
        #
        # observations : batch_o [M] E
        #
        # and the mask (if present) has shape
        #
        # mask : batch_m [M]
        #
        # Let this HMM distribution have batch shape batch_d
        # We need to broadcast all three of these batch shapes together
        # into the shape batch.
        #
        # We need to move the step dimension to the first dimension to make
        # them suitable for folding or scanning over.
        #
        # When we call `log_prob` for our observations we need to
        # do this for each state the observation could correspond to.
        # We do this by expanding the dimensions by 1 so we end up with:
        #
        # observations : [M] batch [1] [E]
        #
        # After calling `log_prob` we get
        #
        # observation_log_probs : [M] batch [N]
        #
        # We wish to use `mask` to select from this so we also
        # reshape and broadcast it up to shape
        #
        # mask : [M] batch [N]

        observation_tensor_shape = tf.shape(observations)
        observation_batch_shape = observation_tensor_shape[
            :-1 - self._underlying_event_rank]
        observation_event_shape = observation_tensor_shape[
            -1 - self._underlying_event_rank:]

        if mask is not None:
            mask_tensor_shape = tf.shape(mask)
            mask_batch_shape = mask_tensor_shape[:-1]

        batch_shape = tf.broadcast_dynamic_shape(observation_batch_shape,
                                                 self.batch_shape_tensor())

        if mask is not None:
            batch_shape = tf.broadcast_dynamic_shape(batch_shape,
                                                     mask_batch_shape)
        observations = tf.broadcast_to(
            observations,
            tf.concat([batch_shape, observation_event_shape], axis=0))
        observation_rank = tf.rank(observations)
        underlying_event_rank = self._underlying_event_rank
        observations = distribution_util.move_dimension(
            observations, observation_rank - underlying_event_rank - 1, 0)
        observations = tf.expand_dims(observations,
                                      observation_rank - underlying_event_rank)
        observation_log_probs = self._observation_distribution.log_prob(
            observations)
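        # observation_log_probs :: num_steps batch_shape num_states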

        if mask is not None:
            mask = tf.broadcast_to(
                mask, tf.concat([batch_shape, [self._num_steps]], axis=0))
            mask = distribution_util.move_dimension(mask, -1, 0)
            observation_log_probs = tf.where(
                mask[..., tf.newaxis], tf.zeros_like(observation_log_probs),
                observation_log_probs)
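            # Masked steps get log prob 0 (probability 1) in every state, so
            # the corresponding observations contribute no evidence; they are
            # effectively marginalized out.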

        return observation_log_probs

    def _sample_n(self, n, seed=None):
        with tf.control_dependencies(self._runtime_assertions):
            strm = SeedStream(seed, salt="HiddenMarkovModel")

            num_states = self._num_states

            batch_shape = self.batch_shape_tensor()
            batch_size = tf.reduce_prod(batch_shape)

            # The batch sizes of the underlying initial distributions and
            # transition distributions might not match the batch size of
            # the HMM distribution.
            # As a result we need to ask for more samples from the
            # underlying distributions and then reshape the results into
            # the correct batch size for the HMM.
            init_repeat = (
                tf.reduce_prod(self.batch_shape_tensor()) // tf.reduce_prod(
                    self._initial_distribution.batch_shape_tensor()))
            init_state = self._initial_distribution.sample(n * init_repeat,
                                                           seed=strm())
            init_state = tf.reshape(init_state, [n, batch_size])
            # init_state :: n batch_size

            transition_repeat = (
                tf.reduce_prod(self.batch_shape_tensor()) // tf.reduce_prod(
                    self._transition_distribution.batch_shape_tensor()[:-1]))

            def generate_step(state, _):
                """Take a single step in Markov chain."""

                gen = self._transition_distribution.sample(
                    n * transition_repeat, seed=strm())
                # gen :: (n * transition_repeat) transition_batch

                new_states = tf.reshape(gen, [n, batch_size, num_states])

                # new_states :: n batch_size num_states

                old_states_one_hot = tf.one_hot(state,
                                                num_states,
                                                dtype=tf.int32)

                # old_states_one_hot :: n batch_size num_states

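                # Pick, for each (sample, batch) element, the newly sampled
                # next state corresponding to the current state: the one-hot
                # inner product acts as a batched gather over `num_states`.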
                return tf.reduce_sum(old_states_one_hot * new_states, axis=-1)

            def _scan_multiple_steps():
                """Take multiple steps with tf.scan."""
                dummy_index = tf.zeros(self._num_steps - 1, dtype=tf.float32)
                if seed is not None:
                    # Force parallel_iterations to 1 to ensure reproducibility
                    # b/139210489
                    hidden_states = tf.scan(generate_step,
                                            dummy_index,
                                            initializer=init_state,
                                            parallel_iterations=1)
                else:
                    # Invoke default parallel_iterations behavior
                    hidden_states = tf.scan(generate_step,
                                            dummy_index,
                                            initializer=init_state)

                # TODO(b/115618503): add/use prepend_initializer to tf.scan
                return tf.concat([[init_state], hidden_states], axis=0)

            hidden_states = prefer_static.cond(
                self._num_steps > 1, _scan_multiple_steps,
                lambda: init_state[tf.newaxis, ...])
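            # hidden_states :: num_steps n batch_size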

            hidden_one_hot = tf.one_hot(
                hidden_states,
                num_states,
                dtype=self._observation_distribution.dtype)
            # hidden_one_hot :: num_steps n batch_size num_states

            # The observation distribution batch size might not match
            # the required batch size so as with the initial and
            # transition distributions we generate more samples and
            # reshape.
            observation_repeat = (batch_size // tf.reduce_prod(
                self._observation_distribution.batch_shape_tensor()[:-1]))

            possible_observations = self._observation_distribution.sample(
                [self._num_steps, observation_repeat * n], seed=strm())

            inner_shape = self._observation_distribution.event_shape

            # possible_observations :: num_steps (observation_repeat * n)
            #                          observation_batch[:-1] num_states inner_shape

            possible_observations = tf.reshape(
                possible_observations,
                tf.concat([[self._num_steps, n], batch_shape, [num_states],
                           inner_shape],
                          axis=0))

            # possible_observations :: steps n batch_size num_states inner_shape

            hidden_one_hot = tf.reshape(
                hidden_one_hot,
                tf.concat([[self._num_steps, n], batch_shape, [num_states],
                           tf.ones_like(inner_shape)],
                          axis=0))

            # hidden_one_hot :: steps n batch_size num_states "inner_shape"

            observations = tf.reduce_sum(hidden_one_hot *
                                         possible_observations,
                                         axis=-1 - tf.size(inner_shape))

            # observations :: steps n batch_size inner_shape

            observations = distribution_util.move_dimension(
                observations, 0, 1 + tf.size(batch_shape))

            # returned :: n batch_shape steps inner_shape

            return observations