def local_kl_z(z_nat_param, dir_stats, z_stats):
    """Compute the local KL term for z by delegating to `exp_family_kl`.

    `z_nat_param` is normalized along its last axis by subtracting its
    logsumexp. `dir_stats` is then subtracted from the result. That
    difference is passed to `exp_family_kl` together with `z_stats`. No
    `log_z_diff` argument is supplied here.

    Args:
        z_nat_param: [M, K] unnormalized natural parameters.
        dir_stats: [K] statistics subtracted (broadcast over M) from the
            normalized parameters. The name suggests expected statistics of
            a Dirichlet factor -- TODO confirm against the caller.
        z_stats: [M, K] statistics forwarded to `exp_family_kl`.

    Returns:
        [M] tensor, per the shape annotation carried over from the
        original code.
    """
    # Log-normalizer along the K axis; keepdims so it broadcasts back to [M, K].
    log_normalizer = tf.reduce_logsumexp(z_nat_param, axis=-1, keepdims=True)
    normalized_param = z_nat_param - log_normalizer
    # [M, K] - [K] -> [M, K]
    param_gap = normalized_param - dir_stats
    return exp_family_kl(param_gap, z_stats)
def local_kl_x(x_nat_param, niw_stats, z_stats, x_stats, d):
    """Compute the local KL term for x by delegating to `exp_family_kl`.

    `z_stats` and `niw_stats` are combined with a matrix product, giving one
    row of prior terms per example. The leading d + d^2 columns of that
    product are subtracted from `x_nat_param`. The trailing two columns are
    summed and added to `mvn.log_partition(x_nat_param, d)`; the result is
    passed to `exp_family_kl` as `log_z_diff`.

    Args:
        x_nat_param: [M, d + d^2] natural parameters.
        niw_stats: [K, d + d^2 + 2] per-component statistics. The name
            suggests normal-inverse-Wishart expected stats -- TODO confirm.
        z_stats: [M, K] left operand of the matmul with `niw_stats`.
        x_stats: [M, d + d^2] statistics forwarded to `exp_family_kl`.
        d: dimensionality argument forwarded to `mvn.log_partition`.

    Returns:
        [M] tensor, per the shape annotation carried over from the
        original code.
    """
    # [M, K] @ [K, d + d^2 + 2] -> [M, d + d^2 + 2]
    prior_rows = tf.matmul(z_stats, niw_stats)
    prior_nat_part = prior_rows[:, :-2]    # [M, d + d^2]
    prior_extra_part = prior_rows[:, -2:]  # [M, 2]

    param_gap = x_nat_param - prior_nat_part

    # [M]: log partition of x_nat_param plus the summed trailing prior columns.
    log_z_gap = mvn.log_partition(x_nat_param, d) + tf.reduce_sum(
        prior_extra_part, axis=-1)

    return exp_family_kl(param_gap, x_stats, log_z_diff=log_z_gap)
def _kl_helper(log_partition, param, prior_param, stats):
    """Call `exp_family_kl` with parameter and log-partition differences.

    Args:
        log_partition: callable applied to both `param` and `prior_param`.
            Its two outputs are differenced to form `log_z_diff`.
        param: natural parameters (minuend of both differences).
        prior_param: natural parameters subtracted from `param`.
        stats: statistics forwarded to `exp_family_kl`.

    Returns:
        Whatever `exp_family_kl` returns for these arguments.
    """
    # Arguments evaluate left to right: parameter difference first, then
    # log_partition(param), then log_partition(prior_param).
    return exp_family_kl(
        param - prior_param,
        stats,
        log_z_diff=log_partition(param) - log_partition(prior_param),
    )