def optimize_local_meanfield(global_natparam, node_potentials, tol=1e-3, max_iter=100):
    """Alternate Gaussian and label mean-field updates until the local VLB settles.

    Each sweep calls ``gaussian_meanfield`` with the current label statistics,
    then ``label_meanfield`` with the Gaussian statistics it just produced.
    Sweeping stops once the summed VLB (label term + Gaussian term) changes by
    less than ``tol`` between consecutive sweeps. Otherwise it stops after
    ``max_iter`` sweeps, and 'iteration limit reached' is printed.

    The sweeps inside the loop use ``unbox(node_potentials)``. One last sweep
    with the original ``node_potentials`` then recomputes every returned
    quantity, so the outputs depend on ``node_potentials`` at the optimum.
    (Presumably ``unbox`` strips autodiff tracing; confirm against its
    definition.)

    Args:
        global_natparam: pair ``(label_global, gaussian_global)``.
        node_potentials: potentials given to ``initialize_local_meanfield``
            and ``gaussian_meanfield``.
        tol: absolute change in the summed VLB treated as convergence.
        max_iter: maximum number of sweeps inside the loop.

    Returns:
        ``(stats, local_natparams, vlbs)``, each a ``(label, gaussian)`` pair
        taken from the final sweep.
    """
    label_global, gaussian_global = global_natparam

    def sweep(potentials, label_stats):
        # One alternating pass: update the Gaussian factor from the current
        # label stats, then the label factor from the fresh Gaussian stats.
        gauss_natparam, gauss_stats, gauss_vlb = gaussian_meanfield(
            gaussian_global, potentials, label_stats)
        lbl_natparam, lbl_stats, lbl_vlb = label_meanfield(
            label_global, gaussian_global, gauss_stats)
        return ((lbl_natparam, gauss_natparam),
                (lbl_stats, gauss_stats),
                (lbl_vlb, gauss_vlb))

    vlb = -np.inf
    label_stats = initialize_local_meanfield(label_global, node_potentials)
    converged = False
    for _ in range(max_iter):
        _, (label_stats, _), (label_vlb, gaussian_vlb) = sweep(
            unbox(node_potentials), label_stats)
        prev_vlb, vlb = vlb, label_vlb + gaussian_vlb
        if abs(vlb - prev_vlb) < tol:
            converged = True
            break
    if not converged:
        print('iteration limit reached')

    # Final sweep with the original (not unboxed) potentials so the returned
    # values depend on node_potentials at the optimum.
    local_natparams, stats, vlbs = sweep(node_potentials, label_stats)
    return stats, local_natparams, vlbs
def run_inference(prior_natparam, global_natparam, nn_potentials, num_samples):
    """Optimize local mean field, draw LDS samples, and assemble the VLB terms.

    The mean-field fixed point is found on ``unbox(nn_potentials)``, since the
    low-level code can work on the unboxed value. Everything that must depend
    on ``nn_potentials`` is then recomputed with the boxed value.

    Args:
        prior_natparam: prior natural parameters, passed to ``slds_prior_vlb``.
        global_natparam: pair ``(hmm_global_natparam, lds_global_natparam)``.
        nn_potentials: node potentials (presumably produced by a network;
            confirm with the caller).
        num_samples: forwarded to ``natural_lds_inference_general``.

    Returns:
        ``(samples, expected_stats, global_vlb, local_vlb)``, where
        ``local_vlb`` is the HMM VLB term plus the LDS VLB term.
    """
    _, lds_global = global_natparam

    # Optimize on the unboxed potentials. Only the HMM stats and the two local
    # natural parameters are kept from the result.
    (hmm_stats, _), (hmm_natparam, lds_natparam), _ = \
        optimize_local_meanfield(global_natparam, unbox(nn_potentials))

    # Redo the terms that depend on nn_potentials using the boxed value.
    samples, lds_stats, lds_normalizer = natural_lds_inference_general(
        lds_natparam, nn_potentials, num_samples)
    hmm_vlb = get_hmm_vlb(lds_global, hmm_natparam, lds_stats)

    # The final entry of lds_stats is the local part. The leading entries
    # feed the global expected statistics.
    global_lds_stats = lds_stats[:-1]
    local_lds_stats = lds_stats[-1]
    expected_stats = get_global_stats(hmm_stats, global_lds_stats)

    global_vlb = slds_prior_vlb(global_natparam, prior_natparam)
    lds_vlb = lds_normalizer - contract(nn_potentials, local_lds_stats)
    return samples, expected_stats, global_vlb, hmm_vlb + lds_vlb
def run_inference(prior_natparam, global_natparam, nn_potentials, num_samples):
    """Run local mean field, sample from the Gaussian factor, and compute KLs.

    Args:
        prior_natparam: prior natural parameters, passed to ``prior_kl``.
        global_natparam: global natural parameters.
        nn_potentials: node potentials forwarded to ``local_meanfield``.
        num_samples: number of draws requested from
            ``gaussian.natural_sample``.

    Returns:
        ``(samples, stats, global_kl, local_kl)``. ``stats`` is passed through
        ``unbox`` before it is returned.
    """
    _, expected_stats, local_natparams, local_kl = local_meanfield(
        global_natparam, nn_potentials)

    # The second local natural parameter is the one handed to the Gaussian
    # sampler.
    gaussian_natparam = local_natparams[1]
    samples = gaussian.natural_sample(gaussian_natparam, num_samples)

    global_kl = prior_kl(global_natparam, prior_natparam)
    return samples, unbox(expected_stats), global_kl, local_kl
def run_inference(prior_natparam, global_natparam, nn_potentials):
    """Run local mean field and return its statistics with the KL terms.

    Args:
        prior_natparam: prior natural parameters, passed to ``prior_kl``.
        global_natparam: global natural parameters.
        nn_potentials: node potentials forwarded to ``local_meanfield``.

    Returns:
        ``(stats, global_kl, local_kl)``. ``stats`` is passed through
        ``unbox`` before it is returned. The local natural parameters from
        ``local_meanfield`` are discarded.
    """
    expected_stats, _, local_kl = local_meanfield(global_natparam, nn_potentials)
    global_kl = prior_kl(global_natparam, prior_natparam)
    return unbox(expected_stats), global_kl, local_kl