def prior_kl(global_natparam, prior_natparam):
    """Return the KL divergence between two members of the prior's family.

    The value is ``<eta_q - eta_p, E_q[t]> - (logZ(eta_q) - logZ(eta_p))``,
    where ``eta_q`` is ``global_natparam`` and ``eta_p`` is
    ``prior_natparam``. This is the closed form of KL(q || p) for
    distributions in the same exponential family.

    NOTE(review): this relies on ``prior_expectedstats(eta)`` returning the
    expected sufficient statistics under ``eta``, and on ``prior_logZ``
    being the log-partition function. Both are defined elsewhere; confirm.

    Args:
        global_natparam: natural parameters of the distribution q. The
            expectation is taken under q.
        prior_natparam: natural parameters of the reference/prior
            distribution p. It must have the same structure as
            ``global_natparam`` so that both flatten to aligned vectors.

    Returns:
        The scalar KL divergence.
    """
    # Flatten the (possibly nested) parameter structures so the inner
    # product is taken over aligned 1-D vectors.
    q_stats = flat(prior_expectedstats(global_natparam))
    eta_gap = flat(global_natparam) - flat(prior_natparam)

    log_norm_gap = prior_logZ(global_natparam) - prior_logZ(prior_natparam)

    cross_term = np.dot(eta_gap, q_stats)
    return cross_term - log_norm_gap