Example #1
def iteration(V, D, N_DV, N_D, alpha, beta, M, phi_TV, z_D, inv_z_T, active_topics, inactive_topics, N_TV, N_T, D_T):
    """
    Performs a single iteration of Radford Neal's Algorithm 8.

    First redraws the word distribution phi_TV[t, :] of every active topic
    from Dirichlet(N_TV[t, :] + beta / V). Then, for each document d:
    removes d from its topic, opens M auxiliary topics whose parameters are
    drawn from the Dirichlet(beta / V) prior (if d was the only member of
    its old topic, that topic and its current parameters are reused as the
    first auxiliary topic), and samples a new topic with weight D_T[t] for
    occupied topics and alpha / M for each auxiliary topic, times the
    multinomial likelihood of d's words under phi_TV[t, :].

    z_D, inv_z_T (if not None), N_TV, N_T, D_T, phi_TV, active_topics and
    inactive_topics are all updated in place; nothing is returned.
    """

    for t in active_topics:
        phi_TV[t, :] = dirichlet(N_TV[t, :] + beta / V)

    # prior log weight of each auxiliary topic; constant for the whole sweep
    aux_log_weight = log(alpha) - log(M)

    for d in xrange(D):

        old_t = z_D[d]

        if inv_z_T is not None:
            inv_z_T[old_t].remove(d)

        # take document d out of its current topic's statistics
        N_TV[old_t, :] -= N_DV[d, :]
        N_T[old_t] -= N_D[d]
        D_T[old_t] -= 1

        # Empty topics get log(0) = -inf on purpose. Silence that warning,
        # then restore the caller's own settings rather than forcing 'warn'.
        old_settings = seterr(divide='ignore')
        try:
            log_dist = log(D_T)
        finally:
            seterr(**old_settings)

        idx = -1 * ones(M, dtype=int)
        idx[0] = old_t if D_T[old_t] == 0 else inactive_topics.pop()
        for m in xrange(1, M):
            idx[m] = inactive_topics.pop()
        active_topics |= set(idx)
        log_dist[idx] = aux_log_weight

        # A reused singleton keeps its current parameters; every other
        # auxiliary topic gets a fresh draw from the prior.
        if idx[0] == old_t:
            phi_TV[idx[1:], :] = dirichlet(beta * ones(V) / V, M - 1)
        else:
            phi_TV[idx, :] = dirichlet(beta * ones(V) / V, M)

        # Only words present in d contribute to its likelihood. Restricting
        # the sum to them avoids 0 * log(0) = nan whenever a Dirichlet draw
        # underflows to exactly zero for a word that d does not contain.
        present = N_DV[d, :] > 0
        counts = N_DV[d, present]
        for t in active_topics:
            log_dist[t] += (counts * log(phi_TV[t, present])).sum()

        [t] = log_sample(log_dist)

        z_D[d] = t

        if inv_z_T is not None:
            inv_z_T[t].add(d)

        N_TV[t, :] += N_DV[d, :]
        N_T[t] += N_D[d]
        D_T[t] += 1

        # auxiliary topics that were not chosen go back to the inactive pool
        idx = set(idx)
        idx.discard(t)
        active_topics -= idx
        inactive_topics |= idx
Example #2
def iteration(V, D, N_DV, N_D, alpha, beta, z_D, inv_z_T, active_topics,
              inactive_topics, N_TV, N_T, D_T):
    """
    Performs a single iteration of Radford Neal's Algorithm 3.

    Collapsed Gibbs sweep over all D documents. Each document d is removed
    from its topic and reassigned either to an occupied topic t, with
    weight D_T[t], or to a single new topic, with weight alpha. Each weight
    is multiplied by the Dirichlet-multinomial predictive probability of
    d's words given the words already assigned to t (symmetric prior
    beta / V per word).

    z_D, inv_z_T (if not None), N_TV, N_T, D_T, active_topics and
    inactive_topics are updated in place; nothing is returned.
    """

    # loop invariants
    log_alpha = log(alpha)
    beta_per_word = beta / V

    for d in xrange(D):

        old_t = z_D[d]

        if inv_z_T is not None:
            inv_z_T[old_t].remove(d)

        # take document d out of its current topic's statistics
        N_TV[old_t, :] -= N_DV[d, :]
        N_T[old_t] -= N_D[d]
        D_T[old_t] -= 1

        # Empty topics get log(0) = -inf on purpose. Suppress the warning,
        # then restore the caller's settings instead of forcing 'warn'.
        old_settings = seterr(divide='ignore')
        try:
            log_dist = log(D_T)
        finally:
            seterr(**old_settings)

        # If d was alone in old_t, that now-empty topic is the "new topic"
        # candidate; otherwise borrow one from the inactive pool.
        idx = old_t if D_T[old_t] == 0 else inactive_topics.pop()
        active_topics.add(idx)
        log_dist[idx] = log_alpha

        for t in active_topics:
            log_dist[t] += gammaln(N_T[t] + beta)
            log_dist[t] -= gammaln(N_D[d] + N_T[t] + beta)
            tmp = N_TV[t, :] + beta_per_word
            log_dist[t] += gammaln(N_DV[d, :] + tmp).sum()
            log_dist[t] -= gammaln(tmp).sum()

        [t] = log_sample(log_dist)

        z_D[d] = t

        if inv_z_T is not None:
            inv_z_T[t].add(d)

        N_TV[t, :] += N_DV[d, :]
        N_T[t] += N_D[d]
        D_T[t] += 1

        # the candidate new topic was not chosen: return it to the pool
        if t != idx:
            active_topics.remove(idx)
            inactive_topics.add(idx)
Example #3
def iteration(V, D, N_DV, N_D, alpha, beta, z_D, inv_z_T, active_topics, inactive_topics, N_TV, N_T, D_T):
    """
    Performs a single iteration of Radford Neal's Algorithm 3.

    Collapsed Gibbs sweep over all D documents. Each document d is removed
    from its topic and reassigned either to an occupied topic t, with
    weight D_T[t], or to a single new topic, with weight alpha. Each weight
    is multiplied by the Dirichlet-multinomial predictive probability of
    d's words given the words already assigned to t (symmetric prior
    beta / V per word).

    z_D, inv_z_T (if not None), N_TV, N_T, D_T, active_topics and
    inactive_topics are updated in place; nothing is returned.
    """

    # loop invariants
    log_alpha = log(alpha)
    beta_per_word = beta / V

    for d in xrange(D):

        old_t = z_D[d]

        if inv_z_T is not None:
            inv_z_T[old_t].remove(d)

        # take document d out of its current topic's statistics
        N_TV[old_t, :] -= N_DV[d, :]
        N_T[old_t] -= N_D[d]
        D_T[old_t] -= 1

        # Empty topics get log(0) = -inf on purpose. Suppress the warning,
        # then restore the caller's settings instead of forcing 'warn'.
        old_settings = seterr(divide='ignore')
        try:
            log_dist = log(D_T)
        finally:
            seterr(**old_settings)

        # If d was alone in old_t, that now-empty topic is the "new topic"
        # candidate; otherwise borrow one from the inactive pool.
        idx = old_t if D_T[old_t] == 0 else inactive_topics.pop()
        active_topics.add(idx)
        log_dist[idx] = log_alpha

        for t in active_topics:
            log_dist[t] += gammaln(N_T[t] + beta)
            log_dist[t] -= gammaln(N_D[d] + N_T[t] + beta)
            tmp = N_TV[t, :] + beta_per_word
            log_dist[t] += gammaln(N_DV[d, :] + tmp).sum()
            log_dist[t] -= gammaln(tmp).sum()

        [t] = log_sample(log_dist)

        z_D[d] = t

        if inv_z_T is not None:
            inv_z_T[t].add(d)

        N_TV[t, :] += N_DV[d, :]
        N_T[t] += N_D[d]
        D_T[t] += 1

        # the candidate new topic was not chosen: return it to the pool
        if t != idx:
            active_topics.remove(idx)
            inactive_topics.add(idx)
Example #4
def iteration(V, D, N_DV, N_D, alpha, beta, z_D, inv_z_T, active_topics, inactive_topics, N_TV, N_T, D_T, num_inner_itns):
    """
    Performs a single iteration of Metropolis-Hastings (split-merge).

    Two distinct documents d and e are drawn. If they share a topic t, a
    split of t into (s, t) is proposed, with s taken from inactive_topics;
    otherwise a merge of topic s = z_D[d] into t = z_D[e] is proposed. The
    other documents of the topic(s) start from a random launch state and
    are refined by num_inner_itns restricted Gibbs scans that use the
    collapsed (Dirichlet-multinomial) likelihood. The proposal is accepted
    with the Metropolis-Hastings rule.

    On acceptance z_D, inv_z_T, N_TV, N_T, D_T and the active/inactive topic
    sets are updated in place; nothing is returned.

    NOTE(review): the merge branch appears to rely on num_inner_itns >= 1.
    The final forced scan resets inv_z_t / N_t_V to t's original members
    before s's members are added. With 0 scans, documents already placed in
    t by the launch state would be counted twice. Confirm callers never
    pass 0.
    """

    # scratch word counts for the two candidate topics s and t
    N_s_V = empty(V, dtype=int)
    N_t_V = empty(V, dtype=int)

    log_dist = empty(2)

    d, e = choice(D, 2, replace=False) # choose 2 documents

    # split: open a fresh topic s; merge: s is d's current topic
    if z_D[d] == z_D[e]:
        s = inactive_topics.pop()
        active_topics.add(s)
    else:
        s = z_D[d]

    # d anchors s and e anchors t; neither anchor is reassigned below
    inv_z_s = set([d])
    N_s_V[:] = N_DV[d, :]
    N_s = N_D[d]
    D_s = 1

    t = z_D[e]
    inv_z_t = set([e])
    N_t_V[:] = N_DV[e, :]
    N_t = N_D[e]
    D_t = 1

    # the non-anchor documents whose assignment the proposal may change
    if z_D[d] == z_D[e]:
        idx = inv_z_T[t] - set([d, e])
    else:
        idx = (inv_z_T[s] | inv_z_T[t]) - set([d, e])

    # random launch state: each document goes to s or t with probability 0.5
    for f in idx:
        if uniform() < 0.5:
            inv_z_s.add(f)
            N_s_V += N_DV[f, :]
            N_s += N_D[f]
            D_s += 1
        else:
            inv_z_t.add(f)
            N_t_V += N_DV[f, :]
            N_t += N_D[f]
            D_t += 1

    # log Metropolis-Hastings acceptance ratio, built up below
    acc = 0.0

    for inner_itn in xrange(num_inner_itns):
        for f in idx:

            # (fake) restricted Gibbs sampling scan

            # take f out of whichever candidate topic currently holds it
            if f in inv_z_s:
                inv_z_s.remove(f)
                N_s_V -= N_DV[f, :]
                N_s -= N_D[f]
                D_s -= 1
            else:
                inv_z_t.remove(f)
                N_t_V -= N_DV[f, :]
                N_t -= N_D[f]
                D_t -= 1

            # log conditional of f joining s (index 0): CRP weight times
            # Dirichlet-multinomial predictive likelihood of f's words
            log_dist[0] = log(D_s)
            log_dist[0] += gammaln(N_s + beta)
            log_dist[0] -= gammaln(N_D[f] + N_s + beta)
            tmp = N_s_V + beta / V
            log_dist[0] += gammaln(N_DV[f, :] + tmp).sum()
            log_dist[0] -= gammaln(tmp).sum()

            # same for joining t (index 1)
            log_dist[1] = log(D_t)
            log_dist[1] += gammaln(N_t + beta)
            log_dist[1] -= gammaln(N_D[f] + N_t + beta)
            tmp = N_t_V + beta / V
            log_dist[1] += gammaln(N_DV[f, :] + tmp).sum()
            log_dist[1] -= gammaln(tmp).sum()

            # presumably normalizes log_dist to log-probabilities
            log_dist -= log_sum_exp(log_dist)

            # On the last scan of a merge proposal, f is forced back to its
            # current topic, so acc records the probability of reproducing
            # the existing split (needed for the reverse move).
            if inner_itn == num_inner_itns - 1 and z_D[d] != z_D[e]:
                u = 0 if z_D[f] == s else 1
            else:
                [u] = log_sample(log_dist)

            if u == 0:
                inv_z_s.add(f)
                N_s_V += N_DV[f, :]
                N_s += N_D[f]
                D_s += 1
            else:
                inv_z_t.add(f)
                N_t_V += N_DV[f, :]
                N_t += N_D[f]
                D_t += 1

            # only the final scan contributes to the proposal probability
            if inner_itn == num_inner_itns - 1:
                acc += log_dist[u]

    if z_D[d] == z_D[e]:

        # split: acc holds log q(split); the MH ratio needs its negative
        acc *= -1.0

        # CRP prior ratio for splitting t into s and t
        acc += log(alpha)
        acc += gammaln(D_s) + gammaln(D_t) - gammaln(D_T[t])

        # Dirichlet-multinomial marginal likelihood ratio
        acc += gammaln(beta) + gammaln(N_T[t] + beta)
        acc -= gammaln(N_s + beta) + gammaln(N_t + beta)
        tmp = beta / V
        acc += gammaln(N_s_V + tmp).sum() + gammaln(N_t_V + tmp).sum()
        acc -= V * gammaln(tmp) + gammaln(N_TV[t, :] + tmp).sum()

        if log(uniform()) < min(0.0, acc):
            z_D[list(inv_z_s)] = s
            z_D[list(inv_z_t)] = t
            inv_z_T[s] = inv_z_s
            inv_z_T[t] = inv_z_t
            N_TV[s, :] = N_s_V
            N_TV[t, :] = N_t_V
            N_T[s] = N_s
            N_T[t] = N_t
            D_T[s] = D_s
            D_T[t] = D_t
        else:
            # rejected: return the unused topic s to the inactive pool
            active_topics.remove(s)
            inactive_topics.add(s)

    else:

        # merge: fold all of s's documents into the candidate t statistics
        for f in inv_z_T[s]:
            inv_z_t.add(f)
            N_t_V += N_DV[f, :]
            N_t += N_D[f]
            D_t += 1

        # acc holds log q(reverse split); add inverse CRP prior ratio ...
        acc -= log(alpha)
        acc += gammaln(D_t) - gammaln(D_T[s]) - gammaln(D_T[t])

        # ... and inverse marginal likelihood ratio
        acc += gammaln(N_T[s] + beta) + gammaln(N_T[t] + beta)
        acc -= gammaln(beta) + gammaln(N_t + beta)
        tmp = beta / V
        acc += V * gammaln(tmp) + gammaln(N_t_V + tmp).sum()
        acc -= (gammaln(N_TV[s, :] + tmp).sum() +
                gammaln(N_TV[t, :] + tmp).sum())

        if log(uniform()) < min(0.0, acc):
            active_topics.remove(s)
            inactive_topics.add(s)
            z_D[list(inv_z_t)] = t
            inv_z_T[s].clear()
            inv_z_T[t] = inv_z_t
            N_TV[s, :] = zeros(V, dtype=int)
            N_TV[t, :] = N_t_V
            N_T[s] = 0
            N_T[t] = N_t
            D_T[s] = 0
            D_T[t] = D_t
def iteration(V, D, N_DV, N_D, alpha, beta, z_D, inv_z_T, active_topics,
              inactive_topics, N_TV, N_T, D_T, num_inner_itns):
    """
    Performs a single iteration of Metropolis-Hastings (split-merge).

    Two distinct documents d and e are drawn. If they share a topic t, a
    split of t into (s, t) is proposed, with s taken from inactive_topics;
    otherwise a merge of topic s = z_D[d] into t = z_D[e] is proposed. The
    other documents of the topic(s) start from a random launch state and
    are refined by num_inner_itns restricted Gibbs scans that use the
    collapsed (Dirichlet-multinomial) likelihood. The proposal is accepted
    with the Metropolis-Hastings rule.

    On acceptance z_D, inv_z_T, N_TV, N_T, D_T and the active/inactive topic
    sets are updated in place; nothing is returned.

    NOTE(review): the merge branch appears to rely on num_inner_itns >= 1.
    The final forced scan resets inv_z_t / N_t_V to t's original members
    before s's members are added. With 0 scans, documents already placed in
    t by the launch state would be counted twice. Confirm callers never
    pass 0.
    """

    # scratch word counts for the two candidate topics s and t
    N_s_V = empty(V, dtype=int)
    N_t_V = empty(V, dtype=int)

    log_dist = empty(2)

    d, e = choice(D, 2, replace=False)  # choose 2 documents

    # split: open a fresh topic s; merge: s is d's current topic
    if z_D[d] == z_D[e]:
        s = inactive_topics.pop()
        active_topics.add(s)
    else:
        s = z_D[d]

    # d anchors s and e anchors t; neither anchor is reassigned below
    inv_z_s = set([d])
    N_s_V[:] = N_DV[d, :]
    N_s = N_D[d]
    D_s = 1

    t = z_D[e]
    inv_z_t = set([e])
    N_t_V[:] = N_DV[e, :]
    N_t = N_D[e]
    D_t = 1

    # the non-anchor documents whose assignment the proposal may change
    if z_D[d] == z_D[e]:
        idx = inv_z_T[t] - set([d, e])
    else:
        idx = (inv_z_T[s] | inv_z_T[t]) - set([d, e])

    # random launch state: each document goes to s or t with probability 0.5
    for f in idx:
        if uniform() < 0.5:
            inv_z_s.add(f)
            N_s_V += N_DV[f, :]
            N_s += N_D[f]
            D_s += 1
        else:
            inv_z_t.add(f)
            N_t_V += N_DV[f, :]
            N_t += N_D[f]
            D_t += 1

    # log Metropolis-Hastings acceptance ratio, built up below
    acc = 0.0

    for inner_itn in xrange(num_inner_itns):
        for f in idx:

            # (fake) restricted Gibbs sampling scan

            # take f out of whichever candidate topic currently holds it
            if f in inv_z_s:
                inv_z_s.remove(f)
                N_s_V -= N_DV[f, :]
                N_s -= N_D[f]
                D_s -= 1
            else:
                inv_z_t.remove(f)
                N_t_V -= N_DV[f, :]
                N_t -= N_D[f]
                D_t -= 1

            # log conditional of f joining s (index 0): CRP weight times
            # Dirichlet-multinomial predictive likelihood of f's words
            log_dist[0] = log(D_s)
            log_dist[0] += gammaln(N_s + beta)
            log_dist[0] -= gammaln(N_D[f] + N_s + beta)
            tmp = N_s_V + beta / V
            log_dist[0] += gammaln(N_DV[f, :] + tmp).sum()
            log_dist[0] -= gammaln(tmp).sum()

            # same for joining t (index 1)
            log_dist[1] = log(D_t)
            log_dist[1] += gammaln(N_t + beta)
            log_dist[1] -= gammaln(N_D[f] + N_t + beta)
            tmp = N_t_V + beta / V
            log_dist[1] += gammaln(N_DV[f, :] + tmp).sum()
            log_dist[1] -= gammaln(tmp).sum()

            # presumably normalizes log_dist to log-probabilities
            log_dist -= log_sum_exp(log_dist)

            # On the last scan of a merge proposal, f is forced back to its
            # current topic, so acc records the probability of reproducing
            # the existing split (needed for the reverse move).
            if inner_itn == num_inner_itns - 1 and z_D[d] != z_D[e]:
                u = 0 if z_D[f] == s else 1
            else:
                [u] = log_sample(log_dist)

            if u == 0:
                inv_z_s.add(f)
                N_s_V += N_DV[f, :]
                N_s += N_D[f]
                D_s += 1
            else:
                inv_z_t.add(f)
                N_t_V += N_DV[f, :]
                N_t += N_D[f]
                D_t += 1

            # only the final scan contributes to the proposal probability
            if inner_itn == num_inner_itns - 1:
                acc += log_dist[u]

    if z_D[d] == z_D[e]:

        # split: acc holds log q(split); the MH ratio needs its negative
        acc *= -1.0

        # CRP prior ratio for splitting t into s and t
        acc += log(alpha)
        acc += gammaln(D_s) + gammaln(D_t) - gammaln(D_T[t])

        # Dirichlet-multinomial marginal likelihood ratio
        acc += gammaln(beta) + gammaln(N_T[t] + beta)
        acc -= gammaln(N_s + beta) + gammaln(N_t + beta)
        tmp = beta / V
        acc += gammaln(N_s_V + tmp).sum() + gammaln(N_t_V + tmp).sum()
        acc -= V * gammaln(tmp) + gammaln(N_TV[t, :] + tmp).sum()

        if log(uniform()) < min(0.0, acc):
            z_D[list(inv_z_s)] = s
            z_D[list(inv_z_t)] = t
            inv_z_T[s] = inv_z_s
            inv_z_T[t] = inv_z_t
            N_TV[s, :] = N_s_V
            N_TV[t, :] = N_t_V
            N_T[s] = N_s
            N_T[t] = N_t
            D_T[s] = D_s
            D_T[t] = D_t
        else:
            # rejected: return the unused topic s to the inactive pool
            active_topics.remove(s)
            inactive_topics.add(s)

    else:

        # merge: fold all of s's documents into the candidate t statistics
        for f in inv_z_T[s]:
            inv_z_t.add(f)
            N_t_V += N_DV[f, :]
            N_t += N_D[f]
            D_t += 1

        # acc holds log q(reverse split); add inverse CRP prior ratio ...
        acc -= log(alpha)
        acc += gammaln(D_t) - gammaln(D_T[s]) - gammaln(D_T[t])

        # ... and inverse marginal likelihood ratio
        acc += gammaln(N_T[s] + beta) + gammaln(N_T[t] + beta)
        acc -= gammaln(beta) + gammaln(N_t + beta)
        tmp = beta / V
        acc += V * gammaln(tmp) + gammaln(N_t_V + tmp).sum()
        acc -= (gammaln(N_TV[s, :] + tmp).sum() +
                gammaln(N_TV[t, :] + tmp).sum())

        if log(uniform()) < min(0.0, acc):
            active_topics.remove(s)
            inactive_topics.add(s)
            z_D[list(inv_z_t)] = t
            inv_z_T[s].clear()
            inv_z_T[t] = inv_z_t
            N_TV[s, :] = zeros(V, dtype=int)
            N_TV[t, :] = N_t_V
            N_T[s] = 0
            N_T[t] = N_t
            D_T[s] = 0
            D_T[t] = D_t
Example #6
def iteration(
    V, D, N_DV, N_D, alpha, beta, phi_TV, z_D, inv_z_T, active_topics, inactive_topics, N_TV, N_T, D_T, num_inner_itns
):
    """
    Performs a single iteration of Metropolis-Hastings (split-merge).

    Non-collapsed variant: topic-word parameters are kept explicitly in
    phi_TV. Two distinct documents d and e are drawn. If they share a topic
    t, a split of t into (s, t) is proposed, with s taken from
    inactive_topics; otherwise a merge of s = z_D[d] into t = z_D[e] is
    proposed. The other documents of the topic(s) start from a random
    launch state and are refined by num_inner_itns restricted Gibbs scans.
    Each scan is preceded by draws of phi_s_V / phi_t_V from
    Dirichlet(N + beta / V). The proposal is accepted with the
    Metropolis-Hastings rule.

    On acceptance phi_TV, z_D, inv_z_T, N_TV, N_T, D_T and the
    active/inactive topic sets are updated in place; nothing is returned.

    NOTE(review): the split branch appears to assume num_inner_itns >= 1.
    With 0 scans, phi_s_V / phi_t_V keep their uninitialized empty()
    contents. Confirm callers never pass 0.
    """

    phi_s_V = empty(V)
    phi_t_V = empty(V)
    phi_merge_t_V = empty(V)

    N_s_V = empty(V, dtype=int)
    N_t_V = empty(V, dtype=int)
    N_merge_t_V = empty(V, dtype=int)

    log_dist = empty(2)

    d, e = choice(D, 2, replace=False)  # choose 2 documents

    # split: open a fresh topic s; merge: s is d's current topic
    if z_D[d] == z_D[e]:
        s = inactive_topics.pop()
        active_topics.add(s)
    else:
        s = z_D[d]

    # d anchors s and e anchors t; neither anchor is reassigned below
    inv_z_s = set([d])
    N_s_V[:] = N_DV[d, :]
    N_s = N_D[d]
    D_s = 1

    t = z_D[e]
    inv_z_t = set([e])
    N_t_V[:] = N_DV[e, :]
    N_t = N_D[e]
    D_t = 1

    # statistics of the single merged topic (all documents involved)
    inv_z_merge_t = set([d, e])
    N_merge_t_V[:] = N_DV[d, :] + N_DV[e, :]
    N_merge_t = N_D[d] + N_D[e]
    D_merge_t = 2

    # the non-anchor documents whose assignment the proposal may change
    if z_D[d] == z_D[e]:
        idx = inv_z_T[t] - set([d, e])
    else:
        idx = (inv_z_T[s] | inv_z_T[t]) - set([d, e])

    # random launch state: each document goes to s or t with probability
    # 0.5, and always to the merged topic
    for f in idx:
        if uniform() < 0.5:
            inv_z_s.add(f)
            N_s_V += N_DV[f, :]
            N_s += N_D[f]
            D_s += 1
        else:
            inv_z_t.add(f)
            N_t_V += N_DV[f, :]
            N_t += N_D[f]
            D_t += 1

        inv_z_merge_t.add(f)
        N_merge_t_V += N_DV[f, :]
        N_merge_t += N_D[f]
        D_merge_t += 1

    # merged-topic parameters: for a split, t's current phi (the state being
    # split); for a merge, a fresh draw from the merged posterior
    if z_D[d] == z_D[e]:
        phi_merge_t_V[:] = phi_TV[t, :]
    else:
        phi_merge_t_V = dirichlet(N_merge_t_V + beta / V)

    # log Metropolis-Hastings acceptance ratio, built up below
    acc = 0.0

    for inner_itn in xrange(num_inner_itns):

        # sample new parameters for topics s and t ... but if it's the
        # last iteration and we're doing a merge, then just set the
        # parameters back to phi_TV[s, :] and phi_TV[t, :]

        if inner_itn == num_inner_itns - 1 and z_D[d] != z_D[e]:
            phi_s_V[:] = phi_TV[s, :]
            phi_t_V[:] = phi_TV[t, :]
        else:
            phi_s_V = dirichlet(N_s_V + beta / V)
            phi_t_V = dirichlet(N_t_V + beta / V)

        # On the last scan, add the log Dirichlet(N + beta / V) densities of
        # phi_s_V and phi_t_V and subtract that of phi_merge_t_V.
        if inner_itn == num_inner_itns - 1:

            acc += gammaln(N_s + beta)
            acc -= gammaln(N_s_V + beta / V).sum()
            acc += ((N_s_V + beta / V - 1) * log(phi_s_V)).sum()

            acc += gammaln(N_t + beta)
            acc -= gammaln(N_t_V + beta / V).sum()
            acc += ((N_t_V + beta / V - 1) * log(phi_t_V)).sum()

            acc -= gammaln(N_merge_t + beta)
            acc += gammaln(N_merge_t_V + beta / V).sum()
            acc -= ((N_merge_t_V + beta / V - 1) * log(phi_merge_t_V)).sum()

        for f in idx:

            # (fake) restricted Gibbs sampling scan

            # take f out of whichever candidate topic currently holds it
            if f in inv_z_s:
                inv_z_s.remove(f)
                N_s_V -= N_DV[f, :]
                N_s -= N_D[f]
                D_s -= 1
            else:
                inv_z_t.remove(f)
                N_t_V -= N_DV[f, :]
                N_t -= N_D[f]
                D_t -= 1

            # log conditional of f joining s (0) or t (1): CRP weight times
            # multinomial likelihood under the current parameters.
            # NOTE(review): 0 * log(0) gives nan if a phi entry is exactly 0
            # where N_DV[f, :] is 0; consider summing only over f's words.
            log_dist[0] = log(D_s)
            log_dist[0] += (N_DV[f, :] * log(phi_s_V)).sum()

            log_dist[1] = log(D_t)
            log_dist[1] += (N_DV[f, :] * log(phi_t_V)).sum()

            # presumably normalizes log_dist to log-probabilities
            log_dist -= log_sum_exp(log_dist)

            # on the last scan of a merge, force f back to its current topic
            if inner_itn == num_inner_itns - 1 and z_D[d] != z_D[e]:
                u = 0 if z_D[f] == s else 1
            else:
                [u] = log_sample(log_dist)

            if u == 0:
                inv_z_s.add(f)
                N_s_V += N_DV[f, :]
                N_s += N_D[f]
                D_s += 1
            else:
                inv_z_t.add(f)
                N_t_V += N_DV[f, :]
                N_t += N_D[f]
                D_t += 1

            # only the final scan contributes to the proposal probability
            if inner_itn == num_inner_itns - 1:
                acc += log_dist[u]

    if z_D[d] == z_D[e]:

        # split: acc holds log q(forward) - log q(reverse); negate it
        acc *= -1.0

        # CRP prior ratio plus Dirichlet(beta / V) prior on the phis
        acc += log(alpha)
        acc += gammaln(D_s) + gammaln(D_t) - gammaln(D_T[t])
        tmp = beta / V
        acc += gammaln(beta) - V * gammaln(tmp)
        acc += (tmp - 1) * (log(phi_s_V).sum() + log(phi_t_V).sum())
        acc -= (tmp - 1) * log(phi_TV[t, :]).sum()

        # likelihood ratio
        acc += (N_s_V * log(phi_s_V)).sum() + (N_t_V * log(phi_t_V)).sum()
        acc -= (N_TV[t, :] * log(phi_TV[t, :])).sum()

        if log(uniform()) < min(0.0, acc):
            phi_TV[s, :] = phi_s_V
            phi_TV[t, :] = phi_t_V
            z_D[list(inv_z_s)] = s
            z_D[list(inv_z_t)] = t
            inv_z_T[s] = inv_z_s
            inv_z_T[t] = inv_z_t
            N_TV[s, :] = N_s_V
            N_TV[t, :] = N_t_V
            N_T[s] = N_s
            N_T[t] = N_t
            D_T[s] = D_s
            D_T[t] = D_t
        else:
            # rejected: return the unused topic s to the inactive pool
            active_topics.remove(s)
            inactive_topics.add(s)

    else:

        # merge: inverse CRP prior ratio and Dirichlet prior on the phis
        acc -= log(alpha)
        acc += gammaln(D_merge_t) - gammaln(D_T[s]) - gammaln(D_T[t])
        tmp = beta / V
        acc += V * gammaln(tmp) - gammaln(beta)
        acc += (tmp - 1) * log(phi_merge_t_V).sum()
        acc -= (tmp - 1) * (log(phi_TV[s, :]).sum() + log(phi_TV[t, :]).sum())

        # likelihood ratio
        acc += (N_merge_t_V * log(phi_merge_t_V)).sum()
        acc -= (N_TV[s, :] * log(phi_TV[s, :])).sum() + (N_TV[t, :] * log(phi_TV[t, :])).sum()

        if log(uniform()) < min(0.0, acc):
            phi_TV[s, :] = zeros(V)
            phi_TV[t, :] = phi_merge_t_V
            active_topics.remove(s)
            inactive_topics.add(s)
            z_D[list(inv_z_merge_t)] = t
            inv_z_T[s].clear()
            inv_z_T[t] = inv_z_merge_t
            N_TV[s, :] = zeros(V, dtype=int)
            N_TV[t, :] = N_merge_t_V
            N_T[s] = 0
            N_T[t] = N_merge_t
            D_T[s] = 0
            D_T[t] = D_merge_t
Example #7
File: crp.py Project: aschein/dpmm
def inference_algorithm_3(N_DV, alpha, beta, num_itns=250, true_z_D=None):
    """
    Algorithm 3.
    """

    D, V = N_DV.shape

    T = D # maximum number of topics

    N_D = N_DV.sum(1) # document lengths

    N_TV = zeros((T, V), dtype=int)
    N_T = zeros(T, dtype=int)

    z_D = range(D) # intialize every document to its own topic

    active_topics = set(unique(z_D))
    inactive_topics = set(xrange(T)) - active_topics

    for d in xrange(D):
        N_TV[z_D[d], :] += N_DV[d, :]
        N_T[z_D[d]] += N_D[d]

    D_T = bincount(z_D, minlength=T)

    for itn in xrange(num_itns):
        for d in xrange(D):

            old_t = z_D[d]

            D_T[old_t] -= 1
            N_TV[old_t, :] -= N_DV[d, :]
            N_T[old_t] -= N_D[d]

            log_dist = log(D_T)

            idx = old_t if D_T[old_t] == 0 else inactive_topics.pop()
            active_topics.add(idx)
            log_dist[idx] = log(alpha)

            for t in active_topics:
                log_dist[t] += gammaln(N_T[t] + beta)
                log_dist[t] -= gammaln(N_T[t] + N_D[d] + beta)
                tmp = N_TV[t, :] + beta / V
                log_dist[t] += gammaln(tmp + N_DV[d, :]).sum()
                log_dist[t] -= gammaln(tmp).sum()

            [t] = log_sample(log_dist)

            D_T[t] += 1
            N_TV[t, :] += N_DV[d, :]
            N_T[t] += N_D[d]

            z_D[d] = t

            if t != idx:
                active_topics.remove(idx)
                inactive_topics.add(idx)

        if true_z_D is not None:
            print 'VI: %f bits (%f bits max.)' % (vi(true_z_D, z_D), log2(D))

        for t in active_topics:
            print D_T[t], (N_TV[t, :] + beta / V) / (N_TV[t, :].sum() + beta)

        print len(active_topics)

    return z_D
Example #8
File: crp.py Project: aschein/dpmm
def inference_algorithm_8(N_DV, alpha, beta, num_itns=250, true_z_D=None):
    """
    Algorithm 8.
    """

    M = 10

    D, V = N_DV.shape

    T = D + M - 1 # maximum number of topics

    N_D = N_DV.sum(1) # document lengths

    N_TV = zeros((T, V), dtype=int)
    N_T = zeros(T, dtype=int)

    z_D = range(D) # intialize every document to its own topic

    phi_TV = zeros((T, V))

    active_topics = set(unique(z_D))
    inactive_topics = set(xrange(T)) - active_topics

    for d in xrange(D):
        N_TV[z_D[d], :] += N_DV[d, :]
        N_T[z_D[d]] += N_D[d]

    D_T = bincount(z_D, minlength=T)

    for itn in xrange(num_itns):

        for t in active_topics:
            phi_TV[t, :] = dirichlet(N_TV[t, :] + beta / V, 1)

        for d in xrange(D):

            old_t = z_D[d]

            D_T[old_t] -= 1
            N_TV[old_t, :] -= N_DV[d, :]
            N_T[old_t] -= N_D[d]

            log_dist = log(D_T)

            idx = -1 * ones(M, dtype=int)
            idx[0] = old_t if D_T[old_t] == 0 else inactive_topics.pop()
            for m in xrange(1, M):
                idx[m] = inactive_topics.pop()
            active_topics |= set(idx)
            log_dist[idx] = log(alpha) - log(M)

            if idx[0] == old_t:
                phi_TV[idx[1:], :] = dirichlet(beta * ones(V) / V, M - 1)
            else:
                phi_TV[idx, :] = dirichlet(beta * ones(V) / V, M)

            for t in active_topics:
                log_dist[t] += (N_DV[d, :] * log(phi_TV[t, :])).sum()

            [t] = log_sample(log_dist)

            D_T[t] += 1
            N_TV[t, :] += N_DV[d, :]
            N_T[t] += N_D[d]

            z_D[d] = t

            idx = set(idx)
            idx.discard(t)
            active_topics -= idx
            inactive_topics |= idx

        if true_z_D is not None:
            print 'VI: %f bits (%f bits max.)' % (vi(true_z_D, z_D), log2(D))

        for t in active_topics:
            print D_T[t], (N_TV[t, :] + beta / V) / (N_TV[t, :].sum() + beta)

        print len(active_topics)

    return z_D
Example #9
def iteration(V, D, N_DV, N_D, alpha, beta, phi_TV, z_D, inv_z_T, active_topics, inactive_topics, N_TV, N_T, D_T, num_inner_itns):
    """
    Performs a single iteration of Metropolis-Hastings (split-merge).

    Non-collapsed variant: topic-word parameters are kept explicitly in
    phi_TV. Two distinct documents d and e are drawn. If they share a topic
    t, a split of t into (s, t) is proposed, with s taken from
    inactive_topics; otherwise a merge of s = z_D[d] into t = z_D[e] is
    proposed. The other documents of the topic(s) start from a random
    launch state and are refined by num_inner_itns restricted Gibbs scans.
    Each scan is preceded by draws of phi_s_V / phi_t_V from
    Dirichlet(N + beta / V). The proposal is accepted with the
    Metropolis-Hastings rule.

    On acceptance phi_TV, z_D, inv_z_T, N_TV, N_T, D_T and the
    active/inactive topic sets are updated in place; nothing is returned.

    NOTE(review): the split branch appears to assume num_inner_itns >= 1.
    With 0 scans, phi_s_V / phi_t_V keep their uninitialized empty()
    contents. Confirm callers never pass 0.
    """

    phi_s_V = empty(V)
    phi_t_V = empty(V)
    phi_merge_t_V = empty(V)

    N_s_V = empty(V, dtype=int)
    N_t_V = empty(V, dtype=int)
    N_merge_t_V = empty(V, dtype=int)

    log_dist = empty(2)

    d, e = choice(D, 2, replace=False) # choose 2 documents

    # split: open a fresh topic s; merge: s is d's current topic
    if z_D[d] == z_D[e]:
        s = inactive_topics.pop()
        active_topics.add(s)
    else:
        s = z_D[d]

    # d anchors s and e anchors t; neither anchor is reassigned below
    inv_z_s = set([d])
    N_s_V[:] = N_DV[d, :]
    N_s = N_D[d]
    D_s = 1

    t = z_D[e]
    inv_z_t = set([e])
    N_t_V[:] = N_DV[e, :]
    N_t = N_D[e]
    D_t = 1

    # statistics of the single merged topic (all documents involved)
    inv_z_merge_t = set([d, e])
    N_merge_t_V[:] = N_DV[d, :] + N_DV[e, :]
    N_merge_t = N_D[d] + N_D[e]
    D_merge_t = 2

    # the non-anchor documents whose assignment the proposal may change
    if z_D[d] == z_D[e]:
        idx = inv_z_T[t] - set([d, e])
    else:
        idx = (inv_z_T[s] | inv_z_T[t]) - set([d, e])

    # random launch state: each document goes to s or t with probability
    # 0.5, and always to the merged topic
    for f in idx:
        if uniform() < 0.5:
            inv_z_s.add(f)
            N_s_V += N_DV[f, :]
            N_s += N_D[f]
            D_s += 1
        else:
            inv_z_t.add(f)
            N_t_V += N_DV[f, :]
            N_t += N_D[f]
            D_t += 1

        inv_z_merge_t.add(f)
        N_merge_t_V += N_DV[f, :]
        N_merge_t += N_D[f]
        D_merge_t += 1

    # merged-topic parameters: for a split, t's current phi (the state being
    # split); for a merge, a fresh draw from the merged posterior
    if z_D[d] == z_D[e]:
        phi_merge_t_V[:] = phi_TV[t, :]
    else:
        phi_merge_t_V = dirichlet(N_merge_t_V + beta / V)

    # log Metropolis-Hastings acceptance ratio, built up below
    acc = 0.0

    for inner_itn in xrange(num_inner_itns):

        # sample new parameters for topics s and t ... but if it's the
        # last iteration and we're doing a merge, then just set the
        # parameters back to phi_TV[s, :] and phi_TV[t, :]

        if inner_itn == num_inner_itns - 1 and z_D[d] != z_D[e]:
            phi_s_V[:] = phi_TV[s, :]
            phi_t_V[:] = phi_TV[t, :]
        else:
            phi_s_V = dirichlet(N_s_V + beta / V)
            phi_t_V = dirichlet(N_t_V + beta / V)

        # On the last scan, add the log Dirichlet(N + beta / V) densities of
        # phi_s_V and phi_t_V and subtract that of phi_merge_t_V.
        if inner_itn == num_inner_itns - 1:

            acc += gammaln(N_s + beta)
            acc -= gammaln(N_s_V + beta / V).sum()
            acc += ((N_s_V + beta / V - 1) * log(phi_s_V)).sum()

            acc += gammaln(N_t + beta)
            acc -= gammaln(N_t_V + beta / V).sum()
            acc += ((N_t_V + beta / V - 1) * log(phi_t_V)).sum()

            acc -= gammaln(N_merge_t + beta)
            acc += gammaln(N_merge_t_V + beta / V).sum()
            acc -= ((N_merge_t_V + beta / V - 1) *
                    log(phi_merge_t_V)).sum()

        for f in idx:

            # (fake) restricted Gibbs sampling scan

            # take f out of whichever candidate topic currently holds it
            if f in inv_z_s:
                inv_z_s.remove(f)
                N_s_V -= N_DV[f, :]
                N_s -= N_D[f]
                D_s -= 1
            else:
                inv_z_t.remove(f)
                N_t_V -= N_DV[f, :]
                N_t -= N_D[f]
                D_t -= 1

            # log conditional of f joining s (0) or t (1): CRP weight times
            # multinomial likelihood under the current parameters.
            # NOTE(review): 0 * log(0) gives nan if a phi entry is exactly 0
            # where N_DV[f, :] is 0; consider summing only over f's words.
            log_dist[0] = log(D_s)
            log_dist[0] += (N_DV[f, :] * log(phi_s_V)).sum()

            log_dist[1] = log(D_t)
            log_dist[1] += (N_DV[f, :] * log(phi_t_V)).sum()

            # presumably normalizes log_dist to log-probabilities
            log_dist -= log_sum_exp(log_dist)

            # on the last scan of a merge, force f back to its current topic
            if inner_itn == num_inner_itns - 1 and z_D[d] != z_D[e]:
                u = 0 if z_D[f] == s else 1
            else:
                [u] = log_sample(log_dist)

            if u == 0:
                inv_z_s.add(f)
                N_s_V += N_DV[f, :]
                N_s += N_D[f]
                D_s += 1
            else:
                inv_z_t.add(f)
                N_t_V += N_DV[f, :]
                N_t += N_D[f]
                D_t += 1

            # only the final scan contributes to the proposal probability
            if inner_itn == num_inner_itns - 1:
                acc += log_dist[u]

    if z_D[d] == z_D[e]:

        # split: acc holds log q(forward) - log q(reverse); negate it
        acc *= -1.0

        # CRP prior ratio plus Dirichlet(beta / V) prior on the phis
        acc += log(alpha)
        acc += gammaln(D_s) + gammaln(D_t) - gammaln(D_T[t])
        tmp = beta / V
        acc += gammaln(beta) - V * gammaln(tmp)
        acc += (tmp - 1) * (log(phi_s_V).sum() + log(phi_t_V).sum())
        acc -= (tmp - 1) * log(phi_TV[t, :]).sum()

        # likelihood ratio
        acc += (N_s_V * log(phi_s_V)).sum() + (N_t_V * log(phi_t_V)).sum()
        acc -= (N_TV[t, :] * log(phi_TV[t, :])).sum()

        if log(uniform()) < min(0.0, acc):
            phi_TV[s, :] = phi_s_V
            phi_TV[t, :] = phi_t_V
            z_D[list(inv_z_s)] = s
            z_D[list(inv_z_t)] = t
            inv_z_T[s] = inv_z_s
            inv_z_T[t] = inv_z_t
            N_TV[s, :] = N_s_V
            N_TV[t, :] = N_t_V
            N_T[s] = N_s
            N_T[t] = N_t
            D_T[s] = D_s
            D_T[t] = D_t
        else:
            # rejected: return the unused topic s to the inactive pool
            active_topics.remove(s)
            inactive_topics.add(s)

    else:

        # merge: inverse CRP prior ratio and Dirichlet prior on the phis
        acc -= log(alpha)
        acc += gammaln(D_merge_t) - gammaln(D_T[s]) - gammaln(D_T[t])
        tmp = beta / V
        acc += V * gammaln(tmp) - gammaln(beta)
        acc += (tmp - 1) * log(phi_merge_t_V).sum()
        acc -= (tmp - 1) * (log(phi_TV[s, :]).sum() + log(phi_TV[t, :]).sum())

        # likelihood ratio
        acc += (N_merge_t_V * log(phi_merge_t_V)).sum()
        acc -= ((N_TV[s, :] * log(phi_TV[s, :])).sum() +
                (N_TV[t, :] * log(phi_TV[t, :])).sum())

        if log(uniform()) < min(0.0, acc):
            phi_TV[s, :] = zeros(V)
            phi_TV[t, :] = phi_merge_t_V
            active_topics.remove(s)
            inactive_topics.add(s)
            z_D[list(inv_z_merge_t)] = t
            inv_z_T[s].clear()
            inv_z_T[t] = inv_z_merge_t
            N_TV[s, :] = zeros(V, dtype=int)
            N_TV[t, :] = N_merge_t_V
            N_T[s] = 0
            N_T[t] = N_merge_t
            D_T[s] = 0
            D_T[t] = D_merge_t