def local_kl_z(z_nat_param, dir_stats, z_stats):
    """Compute the per-row KL term for ``z`` via ``exp_family_kl``.

    ``z_nat_param`` is first shifted by its log-sum-exp over the last axis,
    so that the exponentials of each row sum to one. ``dir_stats`` is then
    subtracted (broadcast over the leading axis) and the difference is
    passed to ``exp_family_kl`` together with ``z_stats``.

    Args:
        z_nat_param: annotated as shape [M, K].
        dir_stats: annotated as shape [K].
        z_stats: annotated as shape [M, K].

    Returns:
        The result of ``exp_family_kl``; annotated as shape [M].
    """
    # log_norm: [M, 1] -- kept as a column so it broadcasts across K.
    log_norm = tf.reduce_logsumexp(z_nat_param, axis=-1, keepdims=True)
    normalized = z_nat_param - log_norm
    # [M, K] - [K] -> [M, K]
    return exp_family_kl(normalized - dir_stats, z_stats)
def local_kl_x(x_nat_param, niw_stats, z_stats, x_stats, d):
    """Compute the per-row KL term for ``x`` via ``exp_family_kl``.

    ``z_stats`` is matrix-multiplied with ``niw_stats`` to produce one
    combined row per data point. All but the last two columns of that
    product are subtracted from ``x_nat_param``; the last two columns are
    summed and added to ``mvn.log_partition(x_nat_param, d)`` to form the
    ``log_z_diff`` argument.

    Args:
        x_nat_param: annotated as shape [M, d + d^2].
        niw_stats: annotated as shape [K, d + d^2 + 2].
        z_stats: annotated as shape [M, K].
        x_stats: annotated as shape [M, d + d^2].
        d: forwarded unchanged to ``mvn.log_partition``.

    Returns:
        The result of ``exp_family_kl``; annotated as shape [M].
    """
    # [M, K] @ [K, d + d^2 + 2] -> [M, d + d^2 + 2]
    mixed_prior = tf.matmul(z_stats, niw_stats)

    # Leading d + d^2 columns line up with x_nat_param: [M, d + d^2]
    nat_delta = x_nat_param - mixed_prior[:, :-2]

    # Trailing two columns are collapsed into one value per row: [M]
    x_log_partition = mvn.log_partition(x_nat_param, d)
    prior_tail_sum = tf.reduce_sum(mixed_prior[:, -2:], axis=-1)
    log_z_delta = x_log_partition + prior_tail_sum

    return exp_family_kl(nat_delta, x_stats, log_z_diff=log_z_delta)
 def _kl_helper(log_partition, param, prior_param, stats):
     """Call ``exp_family_kl`` on the difference between two parameters.

     Args:
         log_partition: callable applied separately to ``param`` and to
             ``prior_param``.
         param: first parameter; the minuend of both differences.
         prior_param: second parameter; the subtrahend of both differences.
         stats: forwarded unchanged to ``exp_family_kl``.

     Returns:
         ``exp_family_kl(param - prior_param, stats, log_z_diff=...)`` where
         ``log_z_diff`` is
         ``log_partition(param) - log_partition(prior_param)``.
     """
     param_delta = param - prior_param
     log_z_param = log_partition(param)
     log_z_prior = log_partition(prior_param)
     return exp_family_kl(
         param_delta,
         stats,
         log_z_diff=log_z_param - log_z_prior,
     )