Example 1
0
def optimize_local_meanfield(global_natparam, node_potentials, tol=1e-3, max_iter=100):
    """Alternate Gaussian and label mean-field updates until the local VLB settles.

    Each sweep calls ``gaussian_meanfield`` and then ``label_meanfield``. The
    sweeps stop when the summed local VLB (``label_vlb + gaussian_vlb``) changes
    by less than ``tol`` from the previous sweep, or after ``max_iter`` sweeps.
    In the latter case a message is printed.

    The sweeps run on ``unbox(node_potentials)``. After they finish, both
    updates are recomputed once with the original ``node_potentials``, so the
    returned values are computed from the value the caller passed in.

    Args:
        global_natparam: pair ``(label_global, gaussian_global)``.
        node_potentials: potentials passed to ``initialize_local_meanfield`` and
            ``gaussian_meanfield``.
        tol: threshold on the absolute change of the summed local VLB.
        max_iter: maximum number of coordinate-ascent sweeps.

    Returns:
        ``(stats, local_natparams, vlbs)``, each a ``(label, gaussian)`` pair.
    """
    label_global, gaussian_global = global_natparam

    local_vlb = -np.inf
    label_stats = initialize_local_meanfield(label_global, node_potentials)

    # The fixed-point search only needs the raw value; unbox once, outside the
    # loop, since node_potentials does not change between sweeps.
    raw_potentials = unbox(node_potentials)
    for _ in range(max_iter):
        # Natural parameters from intermediate sweeps are discarded; they are
        # recomputed after the loop.
        _, gaussian_stats, gaussian_vlb = \
            gaussian_meanfield(gaussian_global, raw_potentials, label_stats)
        _, label_stats, label_vlb = \
            label_meanfield(label_global, gaussian_global, gaussian_stats)

        local_vlb, prev_local_vlb = label_vlb + gaussian_vlb, local_vlb
        if abs(local_vlb - prev_local_vlb) < tol:
            break
    else:
        # Only reached when the loop ran out of iterations without breaking.
        print('iteration limit reached')

    # recompute values that depend on node_potentials at optimum, using the
    # original (not unboxed) node_potentials
    gaussian_natparam, gaussian_stats, gaussian_vlb = \
        gaussian_meanfield(gaussian_global, node_potentials, label_stats)
    label_natparam, label_stats, label_vlb = \
        label_meanfield(label_global, gaussian_global, gaussian_stats)

    stats = label_stats, gaussian_stats
    local_natparams = label_natparam, gaussian_natparam
    vlbs = label_vlb, gaussian_vlb

    return stats, local_natparams, vlbs
Example 2
0
def run_inference(prior_natparam, global_natparam, nn_potentials, num_samples):
    """Run local inference and return samples, expected stats and VLB terms.

    Local mean field is optimized on the unboxed ``nn_potentials``. The LDS
    inference and the VLB terms are then recomputed with the original
    ``nn_potentials``.

    Returns:
        ``(samples, expected_stats, global_vlb, local_vlb)``.
    """
    _, lds_global_natparam = global_natparam

    # mean field on the unboxed potentials; keep only the HMM stats and the
    # two local natural parameters
    meanfield_stats, meanfield_natparams, _ = \
        optimize_local_meanfield(global_natparam, unbox(nn_potentials))
    hmm_stats, _ = meanfield_stats
    hmm_local_natparam, lds_local_natparam = meanfield_natparams

    # redo the nn_potentials-dependent pieces with the original value
    samples, lds_stats, lds_normalizer = natural_lds_inference_general(
        lds_local_natparam, nn_potentials, num_samples)
    hmm_vlb = get_hmm_vlb(lds_global_natparam, hmm_local_natparam, lds_stats)

    # all but the last entry of lds_stats go to the global statistics
    global_lds_stats = lds_stats[:-1]
    local_lds_stats = lds_stats[-1]
    expected_stats = get_global_stats(hmm_stats, global_lds_stats)

    global_vlb = slds_prior_vlb(global_natparam, prior_natparam)
    local_vlb = hmm_vlb + (lds_normalizer - contract(nn_potentials, local_lds_stats))

    return samples, expected_stats, global_vlb, local_vlb
Example 3
0
File: gmm.py Project: mattjj/svae
def run_inference(prior_natparam, global_natparam, nn_potentials, num_samples):
    """Run local mean field, draw samples, and return them with the KL terms.

    Returns:
        ``(samples, stats, global_kl, local_kl)``, where ``stats`` is the
        unboxed statistics returned by ``local_meanfield``.
    """
    meanfield_out = local_meanfield(global_natparam, nn_potentials)
    _, stats, local_natparam, local_kl = meanfield_out

    # the second entry of the local natural parameters is the one sampled from
    sampled_natparam = local_natparam[1]
    samples = gaussian.natural_sample(sampled_natparam, num_samples)

    global_kl = prior_kl(global_natparam, prior_natparam)
    unboxed_stats = unbox(stats)
    return samples, unboxed_stats, global_kl, local_kl
Example 4
0
def run_inference(prior_natparam, global_natparam, nn_potentials, num_samples):
    """Sample from the local Gaussian factor and report prior and local KLs.

    Returns:
        ``(samples, stats, global_kl, local_kl)`` with ``stats`` unboxed.
    """
    (_, stats, local_natparam, local_kl) = local_meanfield(
        global_natparam, nn_potentials)
    samples = gaussian.natural_sample(local_natparam[1], num_samples)
    global_kl = prior_kl(global_natparam, prior_natparam)
    result = (samples, unbox(stats), global_kl, local_kl)
    return result
Example 5
0
def run_inference(prior_natparam, global_natparam, nn_potentials):
    """Run local mean field and return unboxed stats with the KL terms.

    Returns:
        ``(stats, global_kl, local_kl)``. The local natural parameters returned
        by ``local_meanfield`` are not used.
    """
    meanfield_result = local_meanfield(global_natparam, nn_potentials)
    stats, _, local_kl = meanfield_result
    global_kl = prior_kl(global_natparam, prior_natparam)
    return unbox(stats), global_kl, local_kl