def test_get_posterior_crossentropy(self):
    # Random log-probabilities with shape [batch_size=2, num_steps=3,
    # num_categories=6], and a random prior over the 6 discrete states.
    input_logprob = np.log(np.random.uniform(low=1e-6, high=1., size=[2, 3, 6]))
    prior_prob = np.random.uniform(size=[6])

    result_entropy = utils.get_posterior_crossentropy(input_logprob, prior_prob)
    # NumPy reference: sum the prior-weighted log-probabilities over the
    # time and category axes, leaving one value per batch entry.
    numpy_result = np.sum(input_logprob * prior_prob[None, None, :],
                          axis=(1, 2))

    self.assertAllClose(
        self.evaluate(result_entropy), numpy_result)
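
# A minimal NumPy sketch of what `utils.get_posterior_crossentropy` is assumed
# to compute, mirroring the reference computation in the test above. The
# function name is hypothetical and this is an illustration, not the library
# implementation.
def _posterior_crossentropy_sketch(log_posterior, prior_probs):
    """Sums prior-weighted posterior log-probabilities per batch entry.

    Args:
      log_posterior: array of shape [batch_size, num_steps, num_categories].
      prior_probs: array of shape [num_categories].

    Returns:
      Array of shape [batch_size], where entry b is
      sum_{t, k} prior_probs[k] * log_posterior[b, t, k].
    """
    return np.sum(log_posterior * prior_probs[None, None, :], axis=(1, 2))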
    def call(self, inputs, temperature=1.0, num_samples=1, dtype=tf.float32):
        """Inference call of SNLDS.

    Args:
      inputs: a `float` Tensor of shape `[batch_size, num_steps, event_size]`,
        containing the observation time series of the model.
      temperature: a `float` Scalar for temperature used to estimate discrete
        state transition `p(s[t] | s[t-1], x[t-1])` as described in Dong et al.
        (2019). Increasing temperature increase the uncertainty about each
        discrete states.
        Default to 1. For ''temperature annealing'', the temperature is set
        to large value initially, and decay to a smaller one. A temperature
        should be positive, but could be smaller than `1.`.
      num_samples: an `int` scalar for number of samples per time-step, for
        posterior inference networks, `z[i] ~ q(z[1:T] | x[1:T])`.
      dtype: data type for calculation. Default to `tf.float32`.

    Returns:
      return_dict: a python `dict` contains all the `Tensor`s for inference
        results. Including the following keys:
        elbo: Evidence Lower Bound, returned by `get_objective_values` function.
        iwae: IWAE Bound, returned by `get_objective_values` function.
        initial_likelihood: the likelihood of `p(s[0], z[0], x[0])`, returned
          by `get_objective_values` function.
        sequence_likelihood: the likelihood of `p(s[1:T], z[1:T], x[0:T])`,
          returned by `get_objective_values` function.
        zt_entropy: the entropy of posterior distribution `H(q(z[t] | x[1:T])`,
          returned by `get_objective_values` function.
        reconstruction: the reconstructed inputs, returned by
          `get_reconstruction` function.
        posterior_llk: the posterior likelihood, `p(s[t] | x[1:T], z[1:T])`,
          returned by `forward_backward_algo.forward_backward` function.
        sampled_z: the sampled z[1:T] from the approximate posterior.
        cross_entropy: batched cross entropy between discrete state posterior
          likelihood and its prior distribution.
    """
        inputs = tf.convert_to_tensor(inputs,
                                      dtype_hint=dtype,
                                      name="SNLDS_Input_Tensor")
        # Sample the continuous hidden variables from `q(z[1:T] | x[1:T])'.
        z_sampled, z_entropy, log_prob_q = self.inference_network(
            inputs, num_samples=num_samples)

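        # `z_sampled' has shape [num_samples, batch_size, num_steps, z_dim];
        # recover the dynamic dimensions for the reshapes below.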
        _, batch_size, num_steps, z_dim = tf.unstack(tf.shape(z_sampled))

        # Merge batch_size and num_samples dimensions.
        z_sampled = tf.reshape(z_sampled,
                               [num_samples * batch_size, num_steps, z_dim])
        z_entropy = tf.reshape(z_entropy,
                               [num_samples * batch_size, num_steps])
        log_prob_q = tf.reshape(log_prob_q,
                                [num_samples * batch_size, num_steps])

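        # Tile the observations along the batch axis so that each of the
        # `num_samples' posterior samples is paired with its own copy of the
        # inputs.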
        inputs = tf.tile(inputs, [num_samples, 1, 1])

        # Based on the observation inputs `x' and the sampled continuous
        # dynamical states `z_sampled', get `log_a(j, k) = p(s[t]=j | s[t-1]=k,
        # x[t-1])' and `log_b(k) = p(x[t] | z[t]) p(z[t] | z[t-1], s[t]=k)'.
        log_b, log_a = self.calculate_likelihoods(inputs,
                                                  z_sampled,
                                                  temperature=temperature)

        # The forward-backward algorithm returns the posterior marginals of
        # the discrete states, `log_gamma2 = p(s[t]=k, s[t-1]=j | x[1:T],
        # z[1:T])' and `log_gamma1 = p(s[t]=k | x[1:T], z[1:T])'.
        _, _, log_gamma1, log_gamma2 = forward_backward_algo.forward_backward(
            log_a, log_b, self.log_init)
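        # `log_gamma1' is shaped [num_samples * batch_size, num_steps,
        # num_categories]; the merged sample/batch dimension is split again
        # below.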

        recon_inputs = self.get_reconstruction(
            z_sampled,
            observation_shape=tf.shape(inputs),
            sample_for_reconstruction=False)
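        # With `sample_for_reconstruction=False', the reconstruction is
        # deterministic rather than a sample from the emission distribution.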

        # Calculate the Evidence Lower Bound and its components.
        # The return_dict currently supports the following items:
        #   elbo: Evidence Lower Bound.
        #   iwae: IWAE Lower Bound.
        #   initial_likelihood: likelihood of p(s[0], z[0], x[0]).
        #   sequence_likelihood: likelihood of p(s[1:T], z[1:T], x[0:T]).
        #   zt_entropy: the entropy of the posterior distribution.
        return_dict = self.get_objective_values(log_a, log_b, self.log_init,
                                                log_gamma1, log_gamma2,
                                                log_prob_q, z_entropy,
                                                num_samples)

        # Estimate the cross entropy between the discrete state prior and
        # posterior likelihoods.
        state_crossentropy = utils.get_posterior_crossentropy(
            log_gamma1, prior_probs=self.discrete_prior)
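        # Average over the merged sample/batch dimension to obtain a scalar.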
        state_crossentropy = tf.reduce_mean(state_crossentropy, axis=0)

        recon_inputs = tf.reshape(recon_inputs,
                                  [num_samples, batch_size, num_steps, -1])
        log_gamma1 = tf.reshape(log_gamma1,
                                [num_samples, batch_size, num_steps, -1])
        z_sampled = tf.reshape(z_sampled,
                               [num_samples, batch_size, num_steps, z_dim])

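        # Keep only the first posterior sample for the returned
        # reconstructions, posterior likelihoods and sampled states.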
        return_dict["inputs"] = inputs
        return_dict["reconstructions"] = recon_inputs[0]
        return_dict["posterior_llk"] = log_gamma1[0]
        return_dict["sampled_z"] = z_sampled[0]
        return_dict["cross_entropy"] = state_crossentropy

        return return_dict
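

# A minimal usage sketch, not part of the model: how the inference `call` above
# might be driven during training, including the "temperature annealing"
# mentioned in the docstring. `snlds_model', `initial_temperature',
# `decay_rate', and `decay_steps' are hypothetical names, and the exponential
# schedule is illustrative, not the one used by Dong et al. (2019).
def _training_step_sketch(snlds_model, batch_inputs, step,
                          initial_temperature=10.0,
                          decay_rate=0.975,
                          decay_steps=100):
    # Anneal the temperature from a large initial value toward a smaller one;
    # it must stay positive.
    temperature = initial_temperature * decay_rate**(step / decay_steps)
    # `batch_inputs' is a float Tensor of shape
    # [batch_size, num_steps, event_size].
    return_dict = snlds_model(batch_inputs,
                              temperature=temperature,
                              num_samples=1)
    # The negative ELBO serves as the training loss to minimize.
    return -tf.reduce_mean(return_dict["elbo"])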