Ejemplo n.º 1
0
def get_var_lds_local_natparam(lds_global_natparams, hmm_expected_stats):
    """Mix per-state LDS local natural parameters with HMM state marginals.

    ``get_all_lds_local_natparams`` yields one (init, pair) parameter set per
    discrete state. Each parameter component is stacked along a new leading
    state axis and then contracted (``np.tensordot(..., axes=1)``) against the
    expected state occupancies.

    Args:
        lds_global_natparams: global LDS natural parameters, passed through to
            ``get_all_lds_local_natparams``.
        hmm_expected_stats: 3-tuple whose last entry is ``expected_states``.
            Row 0 weights the initial-state parameters and rows ``1:`` weight
            the pair parameters (presumably shape (T, K) -- TODO confirm).

    Returns:
        ``(init_param, pair_params)``, each a list with one entry per
        parameter component. They are materialized lists, not lazy ``map``
        objects, so callers can index them and iterate them more than once
        under Python 3 (matching the Python 2 behaviour of ``map``).
    """
    _, _, expected_states = hmm_expected_stats
    all_init_params, all_pair_params = get_all_lds_local_natparams(lds_global_natparams)

    # Regroup the per-state parameter tuples into per-component arrays that
    # carry a leading state axis.
    dense_init_params = [np.stack(component) for component in zip(*all_init_params)]
    dense_pair_params = [np.stack(component) for component in zip(*all_pair_params)]

    def weighted(weights, stacked_params):
        # Contract the leading (state) axis of each stacked component.
        return [np.tensordot(weights, param, axes=1) for param in stacked_params]

    init_param = weighted(expected_states[0], dense_init_params)
    pair_params = weighted(expected_states[1:], dense_pair_params)

    return init_param, pair_params
Ejemplo n.º 2
0
def test_compute_stats_grad():
    """Check ``_compute_stats_grad`` against autograd for both flag settings.

    For each setting, a random "dotter" shaped like the ``compute_stats``
    output is drawn. The gradient of ``contract(dotter, compute_stats(...))``
    is then computed with ``grad``, and its first three entries are compared
    against the hand-written gradient.
    """
    to_fortran = make_unop(lambda x: np.require(x, np.double, 'F'), tuple)

    # Pair the boolean flag used to build ``dotter`` with the numeric stand-in
    # passed as the fourth argument when differentiating.
    for flag, flag_value in ((True, 1.), (False, 0.)):
        dotter = to_fortran(randn_like(compute_stats(Ex, ExxT, ExnxT, flag)))
        autograd_result = grad(
            lambda args: contract(dotter, compute_stats(*args)))((Ex, ExxT, ExnxT, flag_value))
        manual_result = _compute_stats_grad(dotter)
        assert allclose(autograd_result[:3], manual_result)
Ejemplo n.º 3
0
def cython_run_inference(prior_natparam, global_natparam, nn_potentials, num_samples):
    """Run LDS inference through the Cython backend and return VLB terms.

    Returns:
        ``(samples, global_expected_stats, global_vlb, local_vlb)``.
        ``global_expected_stats`` is every expected statistic except the last.
        ``local_vlb`` is the local normalizer minus ``nn_potentials``
        contracted with the last (local) expected statistic.
    """
    local_natparam = lds_prior_expectedstats(global_natparam)
    inference_result = cython_natural_lds_inference_general(
        local_natparam, nn_potentials, num_samples)
    samples, all_stats, normalizer = inference_result

    global_expected_stats = all_stats[:-1]
    local_stats = all_stats[-1]

    local_vlb = normalizer - contract(nn_potentials, local_stats)
    global_vlb = lds_prior_vlb(global_natparam, prior_natparam, local_natparam)
    return samples, global_expected_stats, global_vlb, local_vlb
Ejemplo n.º 4
0
def cython_run_inference(prior_natparam, global_natparam, nn_potentials, num_samples):
    """Run LDS inference through the Cython backend and return KL terms.

    Returns:
        ``(samples, global_expected_stats, global_kl, local_kl)``.
        ``local_kl`` is ``nn_potentials`` contracted with the last (local)
        expected statistic, minus the local normalizer. ``global_kl`` comes
        from ``lds_prior_kl``.
    """
    local_natparam = lds_prior_expectedstats(global_natparam)
    outputs = cython_natural_lds_inference_general(local_natparam, nn_potentials, num_samples)
    samples, stats, log_normalizer = outputs

    global_expected_stats = stats[:-1]
    local_stats = stats[-1]

    local_kl = contract(nn_potentials, local_stats) - log_normalizer
    global_kl = lds_prior_kl(global_natparam, prior_natparam, local_natparam)
    return samples, global_expected_stats, global_kl, local_kl
Ejemplo n.º 5
0
    def get_lds_global_stats(hmm_stats, lds_stats):
        """Aggregate local LDS statistics into per-state global statistics.

        Args:
            hmm_stats: 3-tuple whose last entry is ``expected_states``. Row 0
                weights the initial statistics and rows ``1:`` weight the pair
                statistics. Iteration over ``expected_states[1:].T`` suggests
                columns index discrete states -- TODO confirm.
            lds_stats: ``(init_stats, pair_stats)``.

        Returns:
            A list of ``(global_init_stats, global_pair_stats)`` pairs, one
            per weight column. All levels are materialized (lists rather than
            lazy Python 3 ``map``/``zip`` objects), so the result can be
            indexed and reused, matching the Python 2 behaviour.
        """
        _, _, expected_states = hmm_stats
        init_stats, pair_stats = lds_stats

        global_init_stats = tuple(scale(w, init_stats) for w in expected_states[0])
        # Contract each pair statistic's leading axis against one weight column.
        global_pair_stats = tuple(
            [np.tensordot(w, p, axes=1) for p in pair_stats]
            for w in expected_states[1:].T)

        return list(zip(global_init_stats, global_pair_stats))
Ejemplo n.º 6
0
def run_inference(prior_natparam, global_natparam, nn_potentials, num_samples):
    """Run LDS inference and return samples, global statistics and VLB terms.

    Returns:
        ``(samples, global_expected_stats, global_vlb, local_vlb)``. The last
        expected statistic is treated as local and contracted with
        ``nn_potentials``; the remaining ones are returned as global
        statistics.
    """
    local_natparam = lds_prior_expectedstats(global_natparam)
    samples, expected_stats, local_normalizer = natural_lds_inference_general(
        local_natparam, nn_potentials, num_samples)

    global_expected_stats = expected_stats[:-1]
    local_expected_stats = expected_stats[-1]

    local_vlb = local_normalizer - contract(nn_potentials, local_expected_stats)
    global_vlb = lds_prior_vlb(global_natparam, prior_natparam, local_natparam)
    return samples, global_expected_stats, global_vlb, local_vlb
Ejemplo n.º 7
0
def gmm_global_vlb(global_natparam, prior_natparam):
    """Return the global variational lower-bound term for the GMM prior.

    Computes ``<prior - global, E_global[t]> - (logZ(prior) - logZ(global))``.
    The expected statistics are taken under ``global_natparam``, and each
    natparam is a ``(dir_natparam, niw_natparams)`` pair.

    Args:
        global_natparam: ``(dir_natparam, niw_natparams)`` of the variational factor.
        prior_natparam: ``(dir_natparam, niw_natparams)`` of the prior.
    """
    def gmm_prior_logZ(natparam):
        dir_natparam, niw_natparams = natparam
        return dirichlet.logZ(dir_natparam) + sum(map(niw.logZ, niw_natparams))

    def gmm_prior_expectedstats(natparam):
        dir_natparam, niw_natparams = natparam
        # Return a concrete list for the per-component NIW stats. Under
        # Python 3 a ``map`` object is lazy, single-use and not indexable.
        # Presumably ``contract`` expects a real sequence here, as it received
        # under Python 2 -- TODO confirm.
        return (dirichlet.expectedstats(dir_natparam),
                [niw.expectedstats(p) for p in niw_natparams])

    expected_stats = gmm_prior_expectedstats(global_natparam)
    return contract(sub(prior_natparam, global_natparam), expected_stats) \
        - (gmm_prior_logZ(prior_natparam) - gmm_prior_logZ(global_natparam))
Ejemplo n.º 8
0
def label_meanfield(label_global, gaussian_globals, gaussian_stats):
    """Mean-field update for the discrete label factor.

    Returns:
        ``(local_natparam, stats, vlb)``.
        - ``local_natparam`` is the Dirichlet expected statistics plus the
          per-label node potentials.
        - ``stats`` is ``normalize`` applied to the row-wise softmax of
          ``local_natparam``.
        - ``vlb`` is the summed row log-partition minus ``stats`` contracted
          with the node potentials.
    """
    def node_potential(natparam):
        # Contract each Gaussian statistic against the full rank of its
        # matching parameter and add up the per-component results.
        return sum(np.tensordot(stat, param, axes=np.ndim(param))
                   for stat, param in zip(gaussian_stats, natparam))

    expected_natparams = [niw.expectedstats(g) for g in gaussian_globals]
    node_params = np.array([node_potential(p) for p in expected_natparams]).T

    local_natparam = dirichlet.expectedstats(label_global) + node_params
    row_log_norm = logsumexp(local_natparam, axis=1, keepdims=True)
    stats = normalize(np.exp(local_natparam - row_log_norm))
    vlb = np.sum(logsumexp(local_natparam, axis=1)) - contract(stats, node_params)

    return local_natparam, stats, vlb
Ejemplo n.º 9
0
def grad_check(fun, gradfun, arg, eps=EPS, rtol=RTOL, atol=ATOL, rng=None):
    """Check ``gradfun`` against a finite-difference directional derivative.

    A random direction ``d = rand_dir_like(arg)`` is drawn. The directional
    derivative of ``fun`` at ``arg`` along ``d`` is then estimated with
    central differences at step ``eps`` and at the refined step ``eps / 2``.

    Two assertions follow:
    1. The two numerical estimates agree, i.e. the estimate is stable under
       step refinement.
    2. The estimate matches ``contract(gradfun(arg), d)``.

    Args:
        fun: scalar-valued function of ``arg``.
        gradfun: function returning the gradient of ``fun`` with the same
            structure as ``arg``.
        arg: point at which the gradient is checked.
        eps: finite-difference step.
        rtol, atol: tolerances forwarded to ``np.isclose``.
        rng: unused; kept for interface compatibility.

    Raises:
        AssertionError: if either comparison fails.
    """
    def scalar_nd(f, x, eps):
        # Central difference of a scalar function f at x with step eps.
        return (f(x + eps/2) - f(x - eps/2)) / eps

    random_dir = rand_dir_like(arg)
    scalar_fun = lambda x: fun(add(arg, scale(x, random_dir)))

    numeric_grad = scalar_nd(scalar_fun, 0.0, eps=eps)
    # Use a halved step so the first assertion compares two different
    # estimates. Repeating the same step would compare a value with itself.
    numeric_grad2 = scalar_nd(scalar_fun, 0.0, eps=eps / 2)
    analytic_grad = contract(gradfun(arg), random_dir)

    assert np.isclose(numeric_grad, numeric_grad2, rtol=rtol, atol=atol)
    assert np.isclose(numeric_grad, analytic_grad, rtol=rtol, atol=atol)
Ejemplo n.º 10
0
def get_arhmm_local_nodeparams(lds_global_natparam, lds_expected_stats):
    """Compute per-state node potentials for the discrete (AR-HMM) chain.

    For each discrete state's LDS local natural parameters (from
    ``get_all_lds_local_natparams``):
    - the initial-state parameters are contracted with the expected initial
      statistics;
    - the pair parameters are contracted with the expected pair statistics.

    Args:
        lds_global_natparam: global LDS natural parameters.
        lds_expected_stats: sequence whose first two entries are
            ``(init_stats, pair_stats)``.

    Returns:
        A 2-D array stacking the initial-state potentials as the first row on
        top of the transposed per-state pair potentials (presumably shape
        (T, K) -- TODO confirm).
    """
    init_stats, pair_stats = lds_expected_stats[:2]
    all_init_params, all_pair_params = get_all_lds_local_natparams(lds_global_natparam)

    def pair_potential(pair_param):
        # Contract each pair statistic against the full rank of its matching
        # parameter and sum over components.
        return sum(np.tensordot(x, y, axes=np.ndim(y)) for x, y in zip(pair_stats, pair_param))

    # Use lists, not lazy Python 3 ``map`` objects. ``np.array(map(...))``
    # produces a 0-d object array, and ``np.vstack`` expects a sequence of
    # arrays.
    init_node_potential = np.array([contract(init_stats, p) for p in all_init_params])
    remaining_node_potentials = np.vstack([pair_potential(p) for p in all_pair_params]).T

    return np.vstack((init_node_potential, remaining_node_potentials))
Ejemplo n.º 11
0
def gaussian_meanfield(gaussian_globals, node_potentials, label_stats):
    """Mean-field update for the continuous (Gaussian) local factor.

    The local natural parameter has two parts:
    - the NIW expected statistics of every component, mixed by
      ``label_stats`` (``np.tensordot(..., axes=1)``);
    - the node potentials, expanded to full form.

    Returns:
        ``(local_natparam, stats, vlb)``, where ``vlb`` is the Gaussian
        log-partition minus ``node_potentials`` contracted with the node
        statistics.
    """
    def make_full_potentials(node_potentials):
        # Turn the diagonal precisions into diagonal matrices and pad the two
        # remaining terms with one zero per row of ``h``.
        Jdiag, h = node_potentials[:2]
        num_rows, dim = h.shape
        J_full = Jdiag[..., None] * np.eye(dim)[None, ...]
        return J_full, h, np.zeros(num_rows), np.zeros(num_rows)

    def get_local_natparam(gaussian_globals, node_potentials, label_stats):
        component_params = zip(*[niw.expectedstats(g) for g in gaussian_globals])
        mixed = [np.tensordot(label_stats, param, axes=1) for param in component_params]
        return add(mixed, make_full_potentials(node_potentials))

    def get_node_stats(gaussian_stats):
        # Keep the diagonal of ExxT, Ex, and the fourth statistic; the third
        # entry is discarded.
        ExxT, Ex, _, En = gaussian_stats
        return np.diagonal(ExxT, axis1=-1, axis2=-2), Ex, En

    local_natparam = get_local_natparam(gaussian_globals, node_potentials, label_stats)
    stats = gaussian.expectedstats(local_natparam)
    node_stats = get_node_stats(stats)
    vlb = gaussian.logZ(local_natparam) - contract(node_potentials, node_stats)

    return local_natparam, stats, vlb
Ejemplo n.º 12
0
def run_inference(prior_natparam, global_natparam, nn_potentials, num_samples):
    """Run SLDS inference: local mean field, LDS inference and VLB terms.

    Returns:
        ``(samples, expected_stats, global_vlb, local_vlb)``, where
        ``local_vlb`` is the HMM VLB term plus the LDS term
        ``lds_normalizer - <nn_potentials, local LDS stats>``.
    """
    _, lds_global_natparam = global_natparam

    # Local mean-field optimization on the unboxed potentials (fine for the
    # low-level code).
    meanfield_result = optimize_local_meanfield(global_natparam, unbox(nn_potentials))
    (hmm_stats, _), (hmm_local_natparam, lds_local_natparam), _ = meanfield_result

    # At the optimum, recompute the nn_potentials-dependent terms with the
    # boxed value.
    samples, lds_stats, lds_normalizer = natural_lds_inference_general(
        lds_local_natparam, nn_potentials, num_samples)
    hmm_vlb = get_hmm_vlb(lds_global_natparam, hmm_local_natparam, lds_stats)

    # All but the last LDS statistic feed the global statistics.
    global_lds_stats = lds_stats[:-1]
    local_lds_stats = lds_stats[-1]
    expected_stats = get_global_stats(hmm_stats, global_lds_stats)

    global_vlb = slds_prior_vlb(global_natparam, prior_natparam)
    local_vlb = hmm_vlb + (lds_normalizer - contract(nn_potentials, local_lds_stats))

    return samples, expected_stats, global_vlb, local_vlb
Ejemplo n.º 13
0
 def cy_fun(messages):
     """Contract the trimmed ``_natural_smoother_general`` output with ``dotter``.

     Only the first three parts of the first output entry are kept; the
     output must match ``dotter``'s shape.
     """
     smoothed = _natural_smoother_general(messages, pair_params)
     trimmed = (smoothed[0][:3], smoothed[1], smoothed[2])
     assert shape(trimmed) == shape(dotter)
     return contract(dotter, trimmed)
Ejemplo n.º 14
0
def get_global_stats(label_stats, gaussian_stats):
    """Aggregate local label and Gaussian statistics into global statistics.

    Args:
        label_stats: 2-D array of label weights. Its rows are summed
            (axis 0) and each of its columns weights the Gaussian statistics
            (presumably shape (T, K) -- TODO confirm).
        gaussian_stats: sequence of arrays whose leading axis has the same
            length as the first axis of ``label_stats``.

    Returns:
        ``(global_label_stats, global_gaussian_stats)``.
        - ``global_label_stats`` is ``label_stats`` summed over axis 0.
        - ``global_gaussian_stats`` is a tuple with one list per column ``w``
          of ``label_stats``. Each list holds ``np.tensordot(w, p, axes=1)``
          for every ``p`` in ``gaussian_stats``.

        The inner lists are concrete rather than lazy Python 3 ``map``
        objects, so they can be indexed and traversed more than once.
    """
    global_label_stats = np.sum(label_stats, axis=0)
    global_gaussian_stats = tuple(
        [np.tensordot(weights, stat, axes=1) for stat in gaussian_stats]
        for weights in label_stats.T)
    return global_label_stats, global_gaussian_stats
Ejemplo n.º 15
0
def slds_prior_vlb(global_natparam, prior_natparam):
    """Return the global VLB term for the SLDS prior.

    Computes ``<prior - global, E_global[t]> - (logZ(prior) - logZ(global))``,
    with the expected statistics from ``slds_prior_expectedstats``.
    """
    stats = slds_prior_expectedstats(global_natparam)
    cross_term = contract(sub(prior_natparam, global_natparam), stats)
    return cross_term - (slds_prior_logZ(prior_natparam) - slds_prior_logZ(global_natparam))
Ejemplo n.º 16
0
def lds_prior_vlb(global_natparam, prior_natparam, expected_stats=None):
    """Return the global VLB term for the LDS prior.

    Computes ``<prior - global, E[t]> - (logZ(prior) - logZ(global))``. If
    ``expected_stats`` is omitted, it is computed from ``global_natparam``
    with ``lds_prior_expectedstats``.
    """
    stats = lds_prior_expectedstats(global_natparam) if expected_stats is None else expected_stats
    cross_term = contract(sub(prior_natparam, global_natparam), stats)
    return cross_term - (lds_prior_logZ(prior_natparam) - lds_prior_logZ(global_natparam))
Ejemplo n.º 17
0
 def messages_to_scalar(messages):
     """Return ``messages`` contracted with the enclosing-scope ``dotter``."""
     reduced = contract(dotter, messages)
     return reduced
Ejemplo n.º 18
0
def lds_prior_vlb(global_natparam, prior_natparam, expected_stats=None):
    """Return the global VLB term for the LDS prior.

    The result is ``<prior - global, E[t]>`` minus the log-partition gap
    ``logZ(prior) - logZ(global)``. Missing ``expected_stats`` are computed
    via ``lds_prior_expectedstats(global_natparam)``.
    """
    stats = expected_stats
    if stats is None:
        stats = lds_prior_expectedstats(global_natparam)
    natparam_diff = sub(prior_natparam, global_natparam)
    inner = contract(natparam_diff, stats)
    log_partition_gap = lds_prior_logZ(prior_natparam) - lds_prior_logZ(global_natparam)
    return inner - log_partition_gap
Ejemplo n.º 19
0
 def gfun(next_smooth, next_pred, filtered):
     """Contract ``fun(..., pair_param)`` with ``dotter`` after a shape check."""
     values = fun(next_smooth, next_pred, filtered, pair_param)
     assert shape(values) == shape(dotter)
     return contract(dotter, values)
Ejemplo n.º 20
0
 def py_fun(messages):
     """Contract ``natural_smoother_general`` output with ``dotter``.

     The smoother output must have the same shape as ``dotter``.
     """
     smoothed = natural_smoother_general(messages, *lds)
     assert shape(smoothed) == shape(dotter)
     return contract(dotter, smoothed)
Ejemplo n.º 21
0
 def messages_to_scalar(messages):
     """Contract ``messages`` with the enclosing-scope ``dotter``."""
     value = contract(dotter, messages)
     return value
Ejemplo n.º 22
0
 def cy_fun(messages):
     """Contract ``_natural_smoother_general`` output with ``dotter``.

     The first output entry is cut to its first three parts before the shape
     check against ``dotter``.
     """
     raw = _natural_smoother_general(messages, pair_params)
     first, second, third = raw[0][:3], raw[1], raw[2]
     packed = (first, second, third)
     assert shape(packed) == shape(dotter)
     return contract(dotter, packed)
Ejemplo n.º 23
0
 def py_fun(messages):
     """Shape-check ``natural_smoother_general`` output, then contract it with ``dotter``."""
     output = natural_smoother_general(messages, *lds)
     assert shape(output) == shape(dotter)
     return contract(dotter, output)