def _extract_log_probs(num_states, dist):
  """Tabulate log probabilities from a batch of distributions."""

  states = tf.reshape(
      tf.range(num_states),
      tf.concat([[num_states], tf.ones_like(dist.batch_shape_tensor())],
                axis=0))
  return distribution_util.move_dimension(dist.log_prob(states), 0, -1)
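# Illustrative sketch (not part of the original module): given a batch of
# categorical distributions, `_extract_log_probs` evaluates `log_prob` at
# every state and moves the state index to the last axis. Hypothetical
# values, assuming `tfd = tfp.distributions`:
#
#   dist = tfd.Categorical(logits=tf.zeros([7, 3]))   # batch_shape == [7]
#   table = _extract_log_probs(3, dist)
#   # table.shape == [7, 3]; table[i, k] == dist.log_prob(k)[i]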
def _reduce_multiple_steps():
  """Perform `reduce_max` operation when `num_steps` > 1."""

  def forward_step(previous_step_pair, log_prob_observation):
    log_prob_previous = previous_step_pair[0]
    log_prob = (log_prob_previous[..., tf.newaxis] +
                self._log_trans +
                log_prob_observation[..., tf.newaxis, :])
    most_likely_given_successor = tf.argmax(log_prob, axis=-2)
    max_log_p_given_successor = tf.reduce_max(input_tensor=log_prob,
                                              axis=-2)
    return (max_log_p_given_successor, most_likely_given_successor)

  forward_log_probs, all_most_likely_given_successor = tf.scan(
      forward_step,
      observation_log_probs[1:],
      initializer=(log_prob,
                   tf.zeros(tf.shape(log_prob), dtype=tf.int64)),
      name="forward_log_probs")

  most_likely_end = tf.argmax(forward_log_probs[-1], axis=-1)

  # We require the operation that gives C from A and B where
  #   C[i...j] = A[i...j, B[i...j]]
  # and A = most_likely_given_successor
  #     B = most_likely_successor.
  # tf.gather requires indices of known shape, so instead we use
  # a reduction with tf.one_hot(B) to pick out elements from A.
  def backward_step(most_likely_successor, most_likely_given_successor):
    return tf.reduce_sum(
        input_tensor=(most_likely_given_successor *
                      tf.one_hot(most_likely_successor,
                                 self._num_states,
                                 dtype=tf.int64)),
        axis=-1)

  backward_scan = tf.scan(
      backward_step,
      all_most_likely_given_successor,
      most_likely_end,
      reverse=True)
  most_likely_sequences = tf.concat([backward_scan, [most_likely_end]],
                                    axis=0)
  return distribution_util.move_dimension(most_likely_sequences, 0, -1)
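# Illustrative sketch (not part of the original module): the one-hot
# reduction in `backward_step` above implements a batched lookup
# C[i, j] = A[i, j, B[i, j]] without `tf.gather`. With hypothetical values:
#
#   A = tf.constant([[[10, 20], [30, 40]]], dtype=tf.int64)  # shape [1, 2, 2]
#   B = tf.constant([[0, 1]], dtype=tf.int64)                # shape [1, 2]
#   C = tf.reduce_sum(A * tf.one_hot(B, 2, dtype=tf.int64), axis=-1)
#   # C == [[10, 40]], i.e. C[i, j] == A[i, j, B[i, j]]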
def _log_prob(self, value):
  with tf.control_dependencies(self._runtime_assertions):
    # The argument `value` is a tensor of sequences of observations.
    # `observation_batch_shape` is the shape of that tensor with the
    # sequence part removed.
    # `observation_batch_shape` is then broadcast to the full batch shape
    # to give the `batch_shape` that defines the shape of the result.
    observation_tensor_shape = tf.shape(value)
    observation_batch_shape = observation_tensor_shape[
        :-1 - self._underlying_event_rank]
    # value :: observation_batch_shape num_steps observation_event_shape
    batch_shape = tf.broadcast_dynamic_shape(observation_batch_shape,
                                             self.batch_shape_tensor())
    log_init = tf.broadcast_to(
        self._log_init,
        tf.concat([batch_shape, [self._num_states]], axis=0))
    # log_init :: batch_shape num_states
    log_transition = self._log_trans

    # `observation_event_shape` is the shape of each sequence of
    # observations emitted by the model.
    observation_event_shape = observation_tensor_shape[
        -1 - self._underlying_event_rank:]
    working_obs = tf.broadcast_to(
        value,
        tf.concat([batch_shape, observation_event_shape], axis=0))
    # working_obs :: batch_shape observation_event_shape
    r = self._underlying_event_rank

    # Move index into sequence of observations to front so we can apply
    # tf.foldl
    working_obs = distribution_util.move_dimension(
        working_obs, -1 - r, 0)[..., tf.newaxis]
    # working_obs :: num_steps batch_shape underlying_event_shape
    observation_probs = (
        self._observation_distribution.log_prob(working_obs))

    def forward_step(log_prev_step, log_prob_observation):
      return _log_vector_matrix(log_prev_step,
                                log_transition) + log_prob_observation

    fwd_prob = tf.foldl(forward_step, observation_probs,
                        initializer=log_init)
    # fwd_prob :: batch_shape num_states

    log_prob = tf.reduce_logsumexp(fwd_prob, axis=-1)
    # log_prob :: batch_shape

    return log_prob
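# Illustrative sketch (not part of the original module): `_log_prob` is the
# forward algorithm run in log space. Assuming the helper
# `_log_vector_matrix` computes a vector-matrix product under `logsumexp`,
# e.g. something like
#
#   def _log_vector_matrix(vs, ms):
#     """log(vector * matrix) with vector and matrix given in log space."""
#     return tf.reduce_logsumexp(vs[..., tf.newaxis] + ms, axis=-2)
#
# each `forward_step` propagates
#   alpha[t + 1][k] = logsumexp_j(alpha[t][j] + log_transition[j, k])
#                     + observation_log_prob[t + 1][k]
# and the final `reduce_logsumexp` marginalizes over the last hidden state.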
def posterior_marginals(self, observations, mask=None, name=None):
  """Compute marginal posterior distribution for each state.

  This function computes, for each time step, the marginal conditional
  probability that the hidden Markov model was in each possible state
  given the observations that were made at each time step.
  So if the hidden states are `z[0],...,z[num_steps - 1]` and the
  observations are `x[0], ..., x[num_steps - 1]`, then this function
  computes `P(z[i] | x[0], ..., x[num_steps - 1])` for all `i` from `0`
  to `num_steps - 1`.

  This operation is sometimes called smoothing. It uses a form of the
  forward-backward algorithm.

  Note: the behavior of this function is undefined if the `observations`
  argument represents impossible observations from the model.

  Args:
    observations: A tensor representing a batch of observations made on the
      hidden Markov model. The rightmost dimension of this tensor gives the
      steps in a sequence of observations from a single sample from the
      hidden Markov model. The size of this dimension should match the
      `num_steps` parameter of the hidden Markov model object. The other
      dimensions are the dimensions of the batch and these are broadcast
      with the hidden Markov model's parameters.
    mask: optional bool-type `tensor` with rightmost dimension matching
      `num_steps` indicating which observations the result of this function
      should be conditioned on. When the mask has value `True` the
      corresponding observations aren't used. If `mask` is `None` then all
      of the observations are used. The `mask` dimensions left of the last
      are broadcast with the hmm batch as well as with the observations.
    name: Python `str` name prefixed to Ops created by this class.
      Default value: "posterior_marginals".

  Returns:
    posterior_marginal: A `Categorical` distribution object representing
      the marginal probability of the hidden Markov model being in each
      state at each step. The rightmost dimension of the `Categorical`
      distribution's batch will equal the `num_steps` parameter, providing
      one marginal distribution for each step. The other dimensions are
      the dimensions corresponding to the batch of observations.

  Raises:
    ValueError: if rightmost dimension of `observations` does not have size
      `num_steps`.
""" with tf.name_scope(name or "posterior_marginals"): with tf.control_dependencies(self._runtime_assertions): observation_tensor_shape = tf.shape(observations) mask_tensor_shape = tf.shape( mask) if mask is not None else None with self._observation_mask_shape_preconditions( observation_tensor_shape, mask_tensor_shape): observation_log_probs = self._observation_log_probs( observations, mask) log_prob = self._log_init + observation_log_probs[0] log_transition = self._log_trans log_adjoint_prob = tf.zeros_like(log_prob) def _scan_multiple_steps_forwards(): def forward_step(log_previous_step, log_prob_observation): return _log_vector_matrix( log_previous_step, log_transition) + log_prob_observation forward_log_probs = tf.scan(forward_step, observation_log_probs[1:], initializer=log_prob, name="forward_log_probs") return tf.concat([[log_prob], forward_log_probs], axis=0) forward_log_probs = prefer_static.cond( self._num_steps > 1, _scan_multiple_steps_forwards, lambda: tf.convert_to_tensor([log_prob])) total_log_prob = tf.reduce_logsumexp(forward_log_probs[-1], axis=-1) def _scan_multiple_steps_backwards(): """Perform `scan` operation when `num_steps` > 1.""" def backward_step(log_previous_step, log_prob_observation): return _log_matrix_vector( log_transition, log_prob_observation + log_previous_step) backward_log_adjoint_probs = tf.scan( backward_step, observation_log_probs[1:], initializer=log_adjoint_prob, reverse=True, name="backward_log_adjoint_probs") return tf.concat( [backward_log_adjoint_probs, [log_adjoint_prob]], axis=0) backward_log_adjoint_probs = prefer_static.cond( self._num_steps > 1, _scan_multiple_steps_backwards, lambda: tf.convert_to_tensor([log_adjoint_prob])) log_likelihoods = forward_log_probs + backward_log_adjoint_probs marginal_log_probs = distribution_util.move_dimension( log_likelihoods - total_log_prob[..., tf.newaxis], 0, -2) return categorical.Categorical(logits=marginal_log_probs)
def _observation_log_probs(self, observations, mask):
  """Compute and shape tensor of log probs associated with observations."""

  # Let E be the underlying event shape
  #     M the number of steps in the HMM
  #     N the number of states of the HMM
  #
  # Then the incoming observations have shape
  #
  #   observations : batch_o [M] E
  #
  # and the mask (if present) has shape
  #
  #   mask : batch_m [M]
  #
  # Let this HMM distribution have batch shape batch_d.
  # We need to broadcast all three of these batch shapes together
  # into the shape batch.
  #
  # We need to move the step dimension to the first dimension to make
  # them suitable for folding or scanning over.
  #
  # When we call `log_prob` for our observations we need to
  # do this for each state the observation could correspond to.
  # We do this by expanding the dimensions by 1 so we end up with:
  #
  #   observations : [M] batch [1] [E]
  #
  # After calling `log_prob` we get
  #
  #   observation_log_probs : [M] batch [N]
  #
  # We wish to use `mask` to select from this so we also
  # reshape and broadcast it up to shape
  #
  #   mask : [M] batch [N]

  observation_tensor_shape = tf.shape(observations)
  observation_batch_shape = observation_tensor_shape[
      :-1 - self._underlying_event_rank]
  observation_event_shape = observation_tensor_shape[
      -1 - self._underlying_event_rank:]

  if mask is not None:
    mask_tensor_shape = tf.shape(mask)
    mask_batch_shape = mask_tensor_shape[:-1]

  batch_shape = tf.broadcast_dynamic_shape(observation_batch_shape,
                                           self.batch_shape_tensor())

  if mask is not None:
    batch_shape = tf.broadcast_dynamic_shape(batch_shape,
                                             mask_batch_shape)
  observations = tf.broadcast_to(
      observations,
      tf.concat([batch_shape, observation_event_shape], axis=0))
  observation_rank = tf.rank(observations)
  underlying_event_rank = self._underlying_event_rank
  observations = distribution_util.move_dimension(
      observations, observation_rank - underlying_event_rank - 1, 0)
  observations = tf.expand_dims(observations,
                                observation_rank - underlying_event_rank)
  observation_log_probs = self._observation_distribution.log_prob(
      observations)

  if mask is not None:
    mask = tf.broadcast_to(
        mask, tf.concat([batch_shape, [self._num_steps]], axis=0))
    mask = distribution_util.move_dimension(mask, -1, 0)
    observation_log_probs = tf.where(mask[..., tf.newaxis],
                                     tf.zeros_like(observation_log_probs),
                                     observation_log_probs)

  return observation_log_probs
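# Illustrative sketch (not part of the original module): when a step is
# masked, its log-probability is replaced by zero for every state, i.e. the
# observation contributes a factor of one and is effectively marginalized
# out. With hypothetical values:
#
#   log_probs = tf.constant([[-1., -2.], [-3., -4.]])  # [num_steps=2, N=2]
#   mask = tf.constant([True, False])
#   masked = tf.where(mask[..., tf.newaxis],
#                     tf.zeros_like(log_probs), log_probs)
#   # masked == [[0., 0.], [-3., -4.]]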
def _sample_n(self, n, seed=None):
  with tf.control_dependencies(self._runtime_assertions):
    strm = SeedStream(seed, salt="HiddenMarkovModel")

    num_states = self._num_states
    batch_shape = self.batch_shape_tensor()
    batch_size = tf.reduce_prod(batch_shape)

    # The batch sizes of the underlying initial distributions and
    # transition distributions might not match the batch size of
    # the HMM distribution.
    # As a result we need to ask for more samples from the
    # underlying distributions and then reshape the results into
    # the correct batch size for the HMM.
    init_repeat = (
        tf.reduce_prod(self.batch_shape_tensor()) //
        tf.reduce_prod(self._initial_distribution.batch_shape_tensor()))
    init_state = self._initial_distribution.sample(n * init_repeat,
                                                   seed=strm())
    init_state = tf.reshape(init_state, [n, batch_size])
    # init_state :: n batch_size

    transition_repeat = (
        tf.reduce_prod(self.batch_shape_tensor()) //
        tf.reduce_prod(
            self._transition_distribution.batch_shape_tensor()[:-1]))

    def generate_step(state, _):
      """Take a single step in Markov chain."""

      gen = self._transition_distribution.sample(n * transition_repeat,
                                                 seed=strm())
      # gen :: (n * transition_repeat) transition_batch

      new_states = tf.reshape(gen, [n, batch_size, num_states])
      # new_states :: n batch_size num_states

      old_states_one_hot = tf.one_hot(state, num_states, dtype=tf.int32)
      # old_states_one_hot :: n batch_size num_states

      return tf.reduce_sum(old_states_one_hot * new_states, axis=-1)

    def _scan_multiple_steps():
      """Take multiple steps with tf.scan."""
      dummy_index = tf.zeros(self._num_steps - 1, dtype=tf.float32)
      if seed is not None:
        # Force parallel_iterations to 1 to ensure reproducibility
        # b/139210489
        hidden_states = tf.scan(generate_step, dummy_index,
                                initializer=init_state,
                                parallel_iterations=1)
      else:
        # Invoke default parallel_iterations behavior
        hidden_states = tf.scan(generate_step, dummy_index,
                                initializer=init_state)

      # TODO(b/115618503): add/use prepend_initializer to tf.scan
      return tf.concat([[init_state], hidden_states], axis=0)

    hidden_states = prefer_static.cond(
        self._num_steps > 1,
        _scan_multiple_steps,
        lambda: init_state[tf.newaxis, ...])

    hidden_one_hot = tf.one_hot(
        hidden_states, num_states,
        dtype=self._observation_distribution.dtype)
    # hidden_one_hot :: num_steps n batch_size num_states

    # The observation distribution batch size might not match
    # the required batch size so as with the initial and
    # transition distributions we generate more samples and
    # reshape.
    observation_repeat = (
        batch_size //
        tf.reduce_prod(
            self._observation_distribution.batch_shape_tensor()[:-1]))

    possible_observations = self._observation_distribution.sample(
        [self._num_steps, observation_repeat * n], seed=strm())

    inner_shape = self._observation_distribution.event_shape

    # possible_observations :: num_steps (observation_repeat * n)
    #                          observation_batch[:-1] num_states inner_shape

    possible_observations = tf.reshape(
        possible_observations,
        tf.concat([[self._num_steps, n],
                   batch_shape,
                   [num_states],
                   inner_shape], axis=0))

    # possible_observations :: steps n batch_size num_states inner_shape

    hidden_one_hot = tf.reshape(
        hidden_one_hot,
        tf.concat([[self._num_steps, n],
                   batch_shape,
                   [num_states],
                   tf.ones_like(inner_shape)], axis=0))

    # hidden_one_hot :: steps n batch_size num_states "inner_shape"

    observations = tf.reduce_sum(hidden_one_hot * possible_observations,
                                 axis=-1 - tf.size(inner_shape))
    # observations :: steps n batch_size inner_shape

    observations = distribution_util.move_dimension(
        observations, 0, 1 + tf.size(batch_shape))
    # returned :: n batch_shape steps inner_shape

    return observations
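# Illustrative sketch (not part of the original module): both the
# hidden-state update in `generate_step` and the observation selection above
# use the same one-hot trick: sample a value for every possible current
# state, then multiply by a one-hot encoding of the actual state and sum
# that axis away. With hypothetical values:
#
#   candidates = tf.constant([[1.0, 2.0, 3.0]])  # one row per sample,
#                                                 # one column per state
#   state = tf.constant([2])                      # actual hidden state
#   chosen = tf.reduce_sum(
#       tf.one_hot(state, 3, dtype=candidates.dtype) * candidates, axis=-1)
#   # chosen == [3.0]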