コード例 #1
0
def alternative_update_subst(unigram_tbl,
                             bigram_tbl,
                             link_tbl,
                             subst_tbl,
                             xe_gap,
                             max_iter,
                             prev_xe=np.inf):
    """Re-estimate ``subst_tbl`` by EM while ``link_tbl`` is held fixed.

    Each pass calls ``em_iter_count`` with ``only_subst_tbl=True``, turns the
    returned substitution counts into a new table with
    ``em_iter_update_subst`` and logs the resulting cross entropy.  The loop
    ends when ``time_to_stop`` reports a stop, or when the cross entropy is NaN.

    :param xe_gap: threshold forwarded to ``time_to_stop``.
    :param max_iter: iteration limit forwarded to ``time_to_stop``.
    :param prev_xe: cross entropy from a previous stage; defaults to inf, so
        the first logged gap is 1.0 for a finite cross entropy.
    :return: ``(subst_tbl, x_entropy, prb_cf, normalization_factor)``, where
        ``normalization_factor`` is ``np.sum(logsumexp(link_tbl, axis=0))``
        computed once before the loop (``link_tbl`` never changes here).
    """
    # Only the column count is used; unpacking both dims still requires a
    # 2-D link_tbl.  The column count is the length passed to cross_entropy.
    _, num_cols = link_tbl.shape
    normalization_factor = np.sum(logsumexp(link_tbl, axis=0))

    iteration = 1
    while True:
        _, subst_counts, prb_cf = em_iter_count(unigram_tbl,
                                                bigram_tbl,
                                                link_tbl,
                                                subst_tbl,
                                                only_subst_tbl=True)
        subst_tbl = em_iter_update_subst(subst_counts)
        x_entropy = cross_entropy([prb_cf], [num_cols])
        gap = abs(1.0 - x_entropy / prev_xe)
        eprint('iter-subst {} cross entropy is {}, gap {},'
               ' logP(c) {}'.format(iteration, x_entropy, gap, prb_cf))

        should_stop = time_to_stop(iteration, max_iter, prev_xe, x_entropy,
                                   xe_gap)
        if should_stop:
            break
        if np.isnan(x_entropy):
            eprint('program end in iter {} caused by nan'.format(iteration))
            break

        prev_xe = x_entropy
        iteration += 1

    return subst_tbl, x_entropy, prb_cf, normalization_factor
コード例 #2
0
def alternative_update_gmm(line,
                           unigram_tbl,
                           bigram_tbl,
                           link_tbl,
                           subst_tbl,
                           xe_gap,
                           max_iter,
                           prev_xe=np.inf):
    """Re-fit the GMM emission model by EM while ``subst_tbl`` stays fixed.

    Each pass collects counts with ``em_iter_count``, re-fits the GMM from
    them with ``em_iter_update`` (which also yields a new ``link_tbl``) and
    logs the cross entropy of ``line`` together with ``gmm_log_likelihood``.
    The loop ends when ``time_to_stop`` reports a stop, or when the cross
    entropy is NaN.

    :param line: feature sequence; ``len(line)`` normalizes the cross entropy.
    :param xe_gap: threshold forwarded to ``time_to_stop``.
    :param max_iter: iteration limit forwarded to ``time_to_stop``.
    :param prev_xe: cross entropy from a previous stage (default inf).
    :return: ``(link_tbl, gmm, weights, x_entropy, prb_cf)`` from the last
        iteration.
    """
    iteration = 1
    while True:
        counts, _, prb_cf = em_iter_count(unigram_tbl, bigram_tbl,
                                          link_tbl, subst_tbl)
        link_tbl, gmm, weights = em_iter_update(counts, line)
        ll_gmm = gmm_log_likelihood(link_tbl, weights)
        x_entropy = cross_entropy([prb_cf], [len(line)])
        gap = abs(1.0 - x_entropy / prev_xe)
        eprint('iter-GMM {} cross entropy is {}, gap {},'
               ' logP(c) {}, logP_GMM(c) {}'.format(
                   iteration, x_entropy, gap, prb_cf, ll_gmm))

        should_stop = time_to_stop(iteration, max_iter, prev_xe, x_entropy,
                                   xe_gap)
        if should_stop:
            break
        if np.isnan(x_entropy):
            eprint('program end in iter {} caused by nan'.format(iteration))
            break

        prev_xe = x_entropy
        iteration += 1

    return link_tbl, gmm, weights, x_entropy, prb_cf
コード例 #3
0
    def fit(self):
        """Initialise the GMM from gold targets, then optionally refine it.

        Builds a one-hot ``(self.k, self.n)`` count table from
        ``self.targets`` and fits ``self.gmm``/``self.weights`` with
        ``gmm_update``.  It then derives ``self.link_tbl`` with ``gmm_assign``.
        When ``self.use_em`` is set, it continues training with
        ``em_decipher``.  If both ``self.unigram_tbl`` and ``self.bigram_tbl``
        are present, ``self.ll`` and ``self.xe`` are (re)computed from
        ``em_forward_backward``.  Otherwise, and when EM is off, this method
        does not set them.
        """
        # Column i holds a single 1 in the row of feature i's gold target.
        gold_counts = np.zeros((self.k, self.n))
        for col, row in enumerate(self.targets):
            gold_counts[row, col] = 1

        eprint('initialize use gold')
        self.gmm, self.weights = gmm_update(self.features,
                                            gold_counts,
                                            cov_type='fix',
                                            scaling_fix_cov=0.1)
        self.link_tbl = gmm_assign(self.gmm, self.features)

        if self.use_em:
            eprint('continue training with LM-GMM')
            # NOTE(review): the em_decipher variants shown next to this method
            # take different arguments / return a different number of values
            # than this 4-arg, 5-result call -- confirm which one is bound here.
            (self.link_tbl, self.gmm, self.weights,
             self.xe, self.ll) = em_decipher(self.features, self.unigram_tbl,
                                             self.bigram_tbl, self.link_tbl)

        if self.unigram_tbl is None or self.bigram_tbl is None:
            return

        _, _, prb_cf = em_forward_backward(self.features, self.unigram_tbl,
                                           self.bigram_tbl, self.link_tbl)
        eprint('log likelihood of LM-GMM is {}'.format(prb_cf))
        self.ll = prb_cf
        self.xe = cross_entropy([prb_cf], [self.n])
コード例 #4
0
def em_decipher(line, unigram_tbl, bigram_tbl, link_tbl_size=None,
                xe_gap=0.99999, max_iter=300):
    """Decipher ``line`` with EM, starting from a random link table.

    Initialisation draws ``np.random.random(link_tbl_size)`` (default shape:
    ``(len(unigram_tbl), len(unigram_tbl))``, meant for simple substitution
    ciphers).  It divides every row with a positive sum by that sum, then
    takes the log.  Rows summing to zero are left as zeros and become -inf.

    Iterations stop when any of the following holds:
      1) ``c >= max_iter``;
      2) ``current cross entropy / previous cross entropy >= xe_gap``.
         This catches improvements that have become negligible, and any
         increase in cross entropy.
      3) The cross entropy is NaN.  A NaN ratio never satisfies (2), so
         without this check a NaN run would silently continue to
         ``max_iter``.

    :param line: cipher sequence; ``len(line)`` normalizes the cross entropy.
    :param unigram_tbl: language-model unigram table, passed to
        ``em_iter_count``.
    :param bigram_tbl: language-model bigram table, passed to
        ``em_iter_count``.
    :param link_tbl_size: shape of the random initial link table.
    :param xe_gap: ratio threshold for stop condition (2); default 0.99999.
    :param max_iter: iteration cap for stop condition (1); default 300.
    :return: ``(link_tbl, x_entropy, log P(c))`` from the last iteration.
    """
    # prepare initial data and parameters
    if link_tbl_size is None:  # only for simple substitution ciphers
        link_tbl_size = (len(unigram_tbl), len(unigram_tbl))
    link_tbl = np.random.random(link_tbl_size)
    row_sums = np.sum(link_tbl, axis=1)
    nonzero = row_sums > 0
    link_tbl[nonzero] /= row_sums[nonzero, None]
    link_tbl = np.log(link_tbl)

    # start training
    prev_xe = np.inf
    c = 1
    while True:
        cnt_tbl, pcf = em_iter_count(line, unigram_tbl, bigram_tbl, link_tbl)
        link_tbl = em_iter_update(cnt_tbl)
        x_entropy = cross_entropy([pcf], [len(line)])
        ratio = x_entropy / prev_xe
        eprint('iter {} cross entropy is {}, gap {}, logP(c) {}'.format(
            c, x_entropy, ratio, pcf))
        if c >= max_iter or ratio >= xe_gap:
            break
        if np.isnan(x_entropy):
            eprint('program end in iter {} caused by nan'.format(c))
            break
        prev_xe = x_entropy
        c += 1

    return link_tbl, x_entropy, pcf
コード例 #5
0
def em_decipher(line,
                unigram_tbl,
                bigram_tbl,
                link_tbl,
                subst_tbl,
                xe_gap=1e-8,
                max_iter=0):
    """
    EM on a line of features with a GMM emission model.

    Each iteration does the following:
      * collects counts with ``em_iter_count`` under the current
        ``link_tbl`` and ``subst_tbl``;
      * re-fits the GMM from those counts with ``em_iter_update``;
      * adopts the link table that ``em_iter_update`` returns, so the next
        iteration counts with it.
    ``subst_tbl`` is not re-estimated here; it is returned unchanged.

    EM iterations stop if one of the following holds:
      1) ``time_to_stop(c, max_iter, prev_xe, x_entropy, xe_gap)`` is true.
         NOTE(review): its exact rule lives elsewhere.  Presumably it checks
         max_iter and the relative change ``abs(1 - xe / prev_xe)`` (the
         logged "gap") against xe_gap.  Confirm, including what the default
         ``max_iter=0`` means there.
      2) the cross entropy is NaN.

    :param line: feature sequence; ``len(line)`` normalizes the cross entropy.
    :return: ``(link_tbl, subst_tbl, gmm, weights, x_entropy, prb_cf)`` from
        the last iteration.
    """
    # prepare hyper-parameters
    prev_xe = np.inf

    # start training
    c = 1
    while True:
        cnt_tbl, _, prb_cf = em_iter_count(unigram_tbl, bigram_tbl,
                                           link_tbl, subst_tbl)
        # Keep the link table of the re-fitted GMM.  The next E-step then uses
        # it, and gmm_log_likelihood pairs it with the matching weights.
        # Discarding it would leave every iteration counting with the
        # initial link_tbl, so EM could never make progress.
        link_tbl, gmm, weights = em_iter_update(cnt_tbl, line)
        ll_gmm = gmm_log_likelihood(link_tbl, weights)
        x_entropy = cross_entropy([prb_cf], [len(line)])
        eprint('iter {} cross entropy is {}, gap {},'
               ' logP(c) {}, logP_GMM(c) {}'.format(
                   c, x_entropy, abs(1.0 - x_entropy / prev_xe), prb_cf,
                   ll_gmm))

        if time_to_stop(c, max_iter, prev_xe, x_entropy, xe_gap):
            break
        elif np.isnan(x_entropy):
            eprint('program end in iter {} caused by nan'.format(c))
            break
        else:
            prev_xe = x_entropy
            c += 1

    return link_tbl, subst_tbl, gmm, weights, x_entropy, prb_cf
コード例 #6
0
ファイル: gmm.py プロジェクト: yinxusen/decipherment-images
def em_gmm(line,
           link_tbl,
           weights,
           xe_gap=1e-8,
           max_iter=300,
           cov_type='fix',
           scaling_factor=0.1):
    """
    EM on a line of features using only the GMM (no language model).

    Each pass gets counts and log P(c) from ``em_iter_gmm_count(link_tbl,
    weights)``.  It then re-fits with ``em_iter_update``, forwarding
    ``cov_type`` and ``scaling_factor``; this yields a new link table, GMM
    and weights.

    EM iterations stop if one of the following holds:
      1) ``time_to_stop(c, max_iter, prev_xe, x_entropy, xe_gap)`` is true.
         Its exact rule is defined elsewhere.
      2) the cross entropy is NaN.
    The "gap" written to the log here is the raw ratio
    ``x_entropy / prev_xe``.

    :param line: feature sequence; ``len(line)`` normalizes the cross entropy.
    :return: ``(link_tbl, gmm, weights, x_entropy)`` from the last iteration.
    """
    previous = np.inf

    iteration = 1
    while True:
        counts, prb_cf = em_iter_gmm_count(link_tbl, weights)
        link_tbl, gmm, weights = em_iter_update(counts,
                                                line,
                                                cov_type=cov_type,
                                                scaling_factor=scaling_factor)
        x_entropy = cross_entropy([prb_cf], [len(line)])
        ratio = x_entropy / previous
        eprint('iter {} cross entropy is {}, gap {}, logP(c) {}'.format(
            iteration, x_entropy, ratio, prb_cf))

        should_stop = time_to_stop(iteration, max_iter, previous, x_entropy,
                                   xe_gap)
        if should_stop:
            break
        if np.isnan(x_entropy):
            eprint('program end in iter {} caused by nan'.format(iteration))
            break

        previous = x_entropy
        iteration += 1

    return link_tbl, gmm, weights, x_entropy