def alternative_update_gmm(line, unigram_tbl, bigram_tbl, link_tbl, subst_tbl,
                           xe_gap, max_iter, prev_xe=np.inf):
    """
    Repeat count/update steps, re-estimating link_tbl and the GMM each round.

    Every round collects counts with em_iter_count, feeds them together with
    line into em_iter_update (whose first result replaces link_tbl), logs the
    current statistics through eprint and then decides whether to go on.

    The loop ends when time_to_stop(...) returns a true value or when the
    cross entropy becomes NaN.

    :param line: features of one line; passed to em_iter_update and its
        length is handed to cross_entropy
    :param unigram_tbl: forwarded unchanged to em_iter_count
    :param bigram_tbl: forwarded unchanged to em_iter_count
    :param link_tbl: initial link table; overwritten on every round
    :param subst_tbl: forwarded unchanged to em_iter_count
    :param xe_gap: threshold forwarded to time_to_stop
    :param max_iter: iteration limit forwarded to time_to_stop
    :param prev_xe: cross entropy to compare the first round against
    :return: (link_tbl, gmm, weights, x_entropy, prb_cf) from the last round
    """
    report = ('iter-GMM {} cross entropy is {}, gap {},'
              ' logP(c) {}, logP_GMM(c) {}')
    iteration = 1
    while True:
        # Count step; the substitution counts are not used in this routine.
        cnt_tbl, _cnt_subst_tbl, prb_cf = em_iter_count(
            unigram_tbl, bigram_tbl, link_tbl, subst_tbl)
        # Update step: the freshly estimated link table feeds the next round.
        link_tbl, gmm, weights = em_iter_update(cnt_tbl, line)
        ll_gmm = gmm_log_likelihood(link_tbl, weights)
        x_entropy = cross_entropy([prb_cf], [len(line)])

        # Relative change of the cross entropy, reported for monitoring only.
        rel_gap = abs(1.0 - x_entropy / prev_xe)
        eprint(report.format(iteration, x_entropy, rel_gap, prb_cf, ll_gmm))

        if time_to_stop(iteration, max_iter, prev_xe, x_entropy, xe_gap):
            break
        if np.isnan(x_entropy):
            eprint('program end in iter {} caused by nan'.format(iteration))
            break

        prev_xe = x_entropy
        iteration += 1
    return link_tbl, gmm, weights, x_entropy, prb_cf
def em_iter_gmm_count(link_tbl, weights):
    r"""
    Compute fractional counts for the GMM.

    The log weights are added to every row of link_tbl (weights are broadcast
    along axis 0, one weight per row), and each column is then normalised by
    subtracting its logsumexp over axis 0. The result therefore stays in the
    log domain: exponentiating a column yields values that sum to one.

    :param link_tbl: \log p(g_i | z_j); rows are matched one-to-one with the
        entries of weights
    :param weights: p(z_j), a 1-D array with one entry per row of link_tbl
    :return: (cnt_tbl, ll_gmm) -- the log-domain count table and the value
        returned by gmm_log_likelihood(link_tbl, weights)
    """
    # The docstring is a raw string so that "\log" is not parsed as an
    # (invalid) escape sequence.
    ll_gmm = gmm_log_likelihood(link_tbl, weights)
    log_weights = np.log(weights)
    # Joint term per cell: log p(z_j) + log p(g_i | z_j).
    weighted_link_tbl = link_tbl + log_weights[:, np.newaxis]
    # Normalise every column over the rows (axis 0).
    cnt_tbl = weighted_link_tbl - logsumexp(weighted_link_tbl,
                                            axis=0)[np.newaxis, :]
    return cnt_tbl, ll_gmm
def em_decipher(line, unigram_tbl, bigram_tbl, link_tbl, subst_tbl,
                xe_gap=1e-8, max_iter=0):
    """
    EM on a line of features.

    Each round collects counts with em_iter_count, fits gmm and weights with
    em_iter_update, logs the statistics through eprint and then checks the
    stopping rules. The loop ends when time_to_stop(...) returns a true value
    (it receives the round number, max_iter, the previous and the current
    cross entropy, and xe_gap) or when the cross entropy becomes NaN.

    NOTE(review): the link table returned by em_iter_update is discarded and
    the substitution-table update is disabled, so link_tbl and subst_tbl are
    returned exactly as they were passed in -- confirm this is intended.

    :param line: features of one line; passed to em_iter_update and its
        length is handed to cross_entropy
    :param unigram_tbl: forwarded unchanged to em_iter_count
    :param bigram_tbl: forwarded unchanged to em_iter_count
    :param link_tbl: link table used for counting on every round
    :param subst_tbl: substitution table used for counting on every round
    :param xe_gap: threshold forwarded to time_to_stop
    :param max_iter: iteration limit forwarded to time_to_stop
    :return: (link_tbl, subst_tbl, gmm, weights, x_entropy, prb_cf)
    """
    report = ('iter {} cross entropy is {}, gap {},'
              ' logP(c) {}, logP_GMM(c) {}')
    # No earlier round exists yet, so start from an infinite cross entropy.
    prev_xe = np.inf
    iteration = 1
    while True:
        cnt_tbl, _cnt_subst_tbl, prb_cf = em_iter_count(
            unigram_tbl, bigram_tbl, link_tbl, subst_tbl)
        # Only the GMM and its weights are kept; the new link table is dropped.
        _unused_link_tbl, gmm, weights = em_iter_update(cnt_tbl, line)
        ll_gmm = gmm_log_likelihood(link_tbl, weights)
        x_entropy = cross_entropy([prb_cf], [len(line)])

        # Relative change of the cross entropy, reported for monitoring only.
        rel_gap = abs(1.0 - x_entropy / prev_xe)
        eprint(report.format(iteration, x_entropy, rel_gap, prb_cf, ll_gmm))

        if time_to_stop(iteration, max_iter, prev_xe, x_entropy, xe_gap):
            break
        if np.isnan(x_entropy):
            eprint('program end in iter {} caused by nan'.format(iteration))
            break

        prev_xe = x_entropy
        iteration += 1
    return link_tbl, subst_tbl, gmm, weights, x_entropy, prb_cf