Example #1
    def write_hypos(self, all_hypos, sen_indices):
        """Writes an ngram file for each sentence in ``all_hypos``.

        Args:
            all_hypos (list): List of n-best lists of hypotheses.
            sen_indices (list): List of sentence indices (0-indexed).

        Raises:
            OSError: If the directory could not be created.
            IOError: If something goes wrong while writing to disk.
        """
        _mkdir(self.path, "ngram")
        for sen_idx, hypos in zip(sen_indices, all_hypos):
            sen_idx += 1
            total = utils.log_sum([hypo.total_score for hypo in hypos])
            normed_scores = [hypo.total_score - total for hypo in hypos]
            ngrams = defaultdict(dict)
            # Collect all ngrams of orders min_order..max_order that occur
            # in any hypothesis.
            for hypo_idx, hypo in enumerate(hypos):
                sen_eos = [utils.GO_ID] + hypo.trgt_sentence + [utils.EOS_ID]
                for pos in range(1, len(sen_eos) + 1):
                    hist = sen_eos[:pos]
                    for order in range(self.min_order, self.max_order + 1):
                        ngram = ' '.join(map(str, hist[-order:]))
                        ngrams[ngram][hypo_idx] = True
            with open(self.file_pattern % sen_idx, "w") as f:
                for ngram, hypo_indices in ngrams.items():
                    # Posterior probability of the ngram: sum of the
                    # normalized probabilities of the hypotheses containing it.
                    ngram_score = np.exp(
                        utils.log_sum([
                            normed_scores[hypo_idx]
                            for hypo_idx in hypo_indices
                        ]))
                    f.write("%s : %f\n" % (ngram, min(1.0, ngram_score)))
Example #2
def ic_log_pvalue(N, L, des_ic, verbose=False, trials=100, method="ub"):
    print(des_ic)
    correction_per_col = 3 / (2 * log(2) * N)
    K = L * correction_per_col  # correction per motif
    ic_for_beta = des_ic + K
    tolerance = 10**-10
    # Correct value of beta for the (corrected) desired information content.
    beta = find_beta_for_mean_motif_ic(N, L, ic_for_beta,
                                       tolerance=tolerance, verbose=verbose)
    countses = enumerate_counts(N)
    entropies = np.array([entropy_from_counts(c) for c in countses])
    iterator = tqdm(countses) if verbose else countses
    log_cols = np.array([log_counts_to_cols(c) for c in iterator])
    log_Zq = log_sum(log_cols + -beta * entropies) * L
    log_Zp = N * L * log(4)
    # log_prefactor = log_Zq - log_Zp + beta*2*L
    log_prefactor = log_Zq - log_Zp + beta * (2 * L - K)
    if method.lower() == "ub":
        log_expectation_ub = -beta * des_ic
        log_pval_ub = log_prefactor + log_expectation_ub
        return log_pval_ub - log(2)
    elif method == "analytic":
        mu, sigma = calc_params(N, L, beta)
        log_expectation = log(compute_expectation_spec(beta, mu, sigma))
        log_pval = log_prefactor + log_expectation
        return log_pval
    else:
        ms = maxent_motifs(N, L, des_ic, trials, beta=beta)
        ics = [motif_ic(m) for m in ms]
        print("des_ic, mean ics:", des_ic, mean(ics))
        # XXX: possible loss of precision
        log_expectation = log_sum([-beta * ic for ic in ics
                                   if ic > des_ic]) - log(trials)
        log_pval = log_prefactor + log_expectation
        return log_pval
Example #3
    def pval_estimate(v):
        limit = len(v)
        low = max(s0 - Q1, 0)
        mid = min(s0, limit)

        # Undo the exponential shift (theta) and mgf normalization, if any.
        if theta != 0:
            v += -np.arange(limit) * theta
        if log_mgf != 0:
            v += (L - 1) * log_mgf

        # Combine the cross terms against the precomputed tail sums with the
        # mass of v from s0 upward.
        q1 = utils.log_sum(v[low:mid] + tail_sums[s0 - low:s0 - mid:-1])
        q2 = utils.log_sum(v[mid:limit])
        q = np.logaddexp(q1, q2)
        return q
Example #4
def sfft_pvalue(log_pmf, s0, L):
    theta = sisfft._compute_theta(log_pmf, s0, L)
    shifted, mgf = utils.shift(log_pmf, theta)
    sfft_vector, fft_len = naive.power_fft(shifted, L)
    error_estimate = utils.sfft_error_threshold_factor(fft_len, L)
    sfft_vector[sfft_vector < np.log(error_estimate)] = utils.NEG_INF
    return utils.log_sum(utils.unshift(sfft_vector, theta, (mgf, L))[s0:])
Example #5
    def log_alpha(self, observation):
        # 1. Initialization
        T = len(observation)
        ln_alpha = np.zeros(shape=(T, self.state_dim))
        ln_alpha = self.__log_alpha(observation, ln_alpha)

        # 2. Induction
        loglikelihood = float('-inf')
        for i in range(self.state_dim):
            loglikelihood = log_sum(loglikelihood, ln_alpha[T - 1][i])

        return ln_alpha, loglikelihood
Example #6
    def log_beta(self, observation):
        # 1. Initialization
        T = len(observation)
        ln_beta = np.zeros(shape=(T, self.state_dim))
        ln_beta = self.__log_beta(observation, ln_beta)

        # 2. Induction
        loglikelihood = float('-inf')
        for i in range(self.state_dim):
            loglikelihood = log_sum(loglikelihood, ln_beta[0][i] + self.log_pi[i] + self.log_emissions[i][observation[0]])

        return ln_beta, loglikelihood
Example #7
def convolve_naive_into(log_c, locations, log_u, log_v):
    nu = len(log_u)
    nv = len(log_v)
    nc = nu + nv - 1
    assert len(log_c) == nc

    for k in locations:
        low_j = max(0, k - nv + 1)
        hi_j = min(k + 1, nu)
        slice_u = log_u[low_j:hi_j]
        # Reversed slice of log_v; the special case avoids an empty slice when
        # the stop index would be -1 (which Python reads as "the last element").
        slice_v = log_v[k - low_j:k - hi_j:-1] if k - hi_j != -1 else log_v[k - low_j::-1]
        log_c[k] = utils.log_sum(slice_u + slice_v)
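A quick way to sanity-check this log-domain convolution is against np.convolve on the linear-scale values. The utils stand-in below simply reuses the log_sum sketch from earlier (the real utils module belongs to the project):

import numpy as np

class utils:  # minimal stand-in for the project's utils module
    log_sum = staticmethod(log_sum)  # reuse the sketch above

u = np.array([0.2, 0.5, 0.3])
v = np.array([0.6, 0.4])
log_c = np.empty(len(u) + len(v) - 1)

convolve_naive_into(log_c, range(len(log_c)), np.log(u), np.log(v))

print(np.allclose(np.exp(log_c), np.convolve(u, v)))  # True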
Example #8
    def __log_beta(self, observation, ln_beta):
        # 1. Initialization
        T = len(observation)
        for i in range(self.state_dim):
            ln_beta[T-1][i] = 0

        # 2. Induction
        for t in range(T-2, -1, -1):
            for i in range(self.state_dim):
                tmp_sum = float('-inf')
                for j in range(self.state_dim):
                    tmp_sum = log_sum(tmp_sum, ln_beta[t+1][j] + self.log_transitions[i][j] + self.log_emissions[j][observation[t+1]])
                ln_beta[t][i] += tmp_sum

        return ln_beta
Example #9
    def __log_alpha(self, observation, ln_alpha):
        # 1. Initialization
        # alpha representation: alpha[t][i]
        T = len(observation)
        for i in range(self.state_dim):
            ln_alpha[0][i] = self.log_pi[i] + self.log_emissions[i][observation[0]]

        # 2. Induction
        for t in range(1, T):
            obs = observation[t]
            for i in range(self.state_dim):
                tmp_sum = float('-inf')
                for j in range(self.state_dim):
                    tmp_sum = log_sum(tmp_sum, ln_alpha[t-1][j] + self.log_transitions[j][i])
                ln_alpha[t][i] = tmp_sum + self.log_emissions[i][obs]

        return ln_alpha
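__log_alpha is the standard forward recursion in log space. For intuition, here is a self-contained run on a hypothetical two-state HMM (all probabilities below are made up for illustration), with np.logaddexp standing in for the class's two-argument log_sum:

import numpy as np

# Hypothetical 2-state HMM over a binary alphabet (values illustrative).
log_pi = np.log([0.6, 0.4])
log_A = np.log([[0.7, 0.3],   # transition probabilities
                [0.4, 0.6]])
log_B = np.log([[0.9, 0.1],   # emission probabilities P(symbol | state)
                [0.2, 0.8]])
obs = [0, 1, 1]

T, S = len(obs), 2
ln_alpha = np.zeros((T, S))
ln_alpha[0] = log_pi + log_B[:, obs[0]]
for t in range(1, T):
    for i in range(S):
        tmp = float('-inf')
        for j in range(S):
            tmp = np.logaddexp(tmp, ln_alpha[t - 1, j] + log_A[j, i])
        ln_alpha[t, i] = tmp + log_B[i, obs[t]]

# Log-likelihood of the observation sequence.
print(np.logaddexp.reduce(ln_alpha[-1]))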
Example #10
    def ksi(self, index, observation, ln_alpha, ln_beta, log_ksi):
        # log_ksi[index][t][i][j]: log posterior probability of transitioning
        # from state i to state j at time t, normalized per time step.
        T = len(observation)

        for t in range(T-1):
            ln_sum = float('-inf')
            x = observation[t+1]

            for i in range(self.state_dim):
                for j in range(self.state_dim):
                    log_ksi[index][t][i][j] = (ln_alpha[t][i]
                                               + self.log_transitions[i][j]
                                               + self.log_emissions[j][x]
                                               + ln_beta[t+1][j])
                    ln_sum = log_sum(ln_sum, log_ksi[index][t][i][j])

            for i in range(self.state_dim):
                for j in range(self.state_dim):
                    log_ksi[index][t][i][j] -= ln_sum

        return log_ksi
Example #11
File: core.py Project: rycolab/bfbs
    def finalize_posterior(self, scores, use_weights, normalize_scores):
        """This method can be used to enforce the parameters ``use_weights``
        and ``normalize_scores`` in predictors with dict posteriors.

        Args:
            scores (dict): Unnormalized log-valued scores.
            use_weights (bool): Set to False to replace all values in
                                ``scores`` with 0 (= log 1).
            normalize_scores (bool): Set to True to make the exp of elements
                                     in ``scores`` sum up to 1.
        """
        if not scores:  # empty scores -> pass through
            return scores
        if not use_weights:
            scores = dict.fromkeys(scores, 0.0)
        if normalize_scores:
            log_sum = utils.log_sum(scores.values())
            ret = {k: v - log_sum for k, v in scores.items()}
            return ret
        return scores
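A quick illustration of the normalization branch with made-up scores, using the log_sum sketch from earlier in place of utils.log_sum:

import math

scores = {'a': -1.0, 'b': -2.0, 'c': -3.0}  # hypothetical log scores

total = log_sum(scores.values())  # stand-in for utils.log_sum
normed = {k: v - total for k, v in scores.items()}

print(sum(math.exp(v) for v in normed.values()))  # ~= 1.0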
Example #12
def pvalue(log_pmf, s0, L, desired_beta):
    """Compute $log((exp(log_pmf)**L)[s0:])$, such that the relative error
       to the exact answer is less than or equal to $desired_beta$."""
    total_len, _ = utils.iterated_convolution_lengths(len(log_pmf), L)
    if s0 >= total_len:
        return NEG_INF

    _, p_lower_preshift, p_upper_preshift = _bounds(log_pmf, log_pmf, 0, 0.0,
                                                    s0, L, desired_beta)
    sfft_good_preshift, sfft_pval_preshift = _check_sfft_pvalue(p_lower_preshift,
                                                                p_upper_preshift,
                                                                desired_beta)
    if sfft_good_preshift:
        logging.debug(' pre-shift sfft worked %.20f', sfft_pval_preshift)
        return sfft_pval_preshift

    with timer('computing theta'):
        theta = _compute_theta(log_pmf, s0, L)
    logging.debug('raw theta %s', theta)

    # TODO: too-large or negative theta causes numerical instability,
    # so this is a huge hack
    theta = utils.clamp(theta, 0, THETA_LIMIT)
    shifted_pmf, log_mgf = utils.shift(log_pmf, theta)

    beta = desired_beta / 2.0
    with timer('bounds'):
        log_delta, p_lower, p_upper = _bounds(log_pmf, shifted_pmf, theta, log_mgf,
                                              s0, L, desired_beta)

    sfft_good, sfft_pval = _check_sfft_pvalue(p_lower, p_upper, desired_beta)

    logging.debug('theta %s, log_mgf %s, beta %s, log delta %s', theta, log_mgf, beta, log_delta)
    if sfft_good:
        logging.debug(' sfft worked %.20f', sfft_pval)
        return sfft_pval
    delta = np.exp(log_delta)

    conv = conv_power(shifted_pmf, L, beta, delta)

    pval = utils.log_sum(utils.unshift(conv, theta, (log_mgf, L))[s0:])
    logging.debug(' sis pvalue %.20f', pval)
    return pval
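The quantity the docstring describes, the log upper tail of the L-fold convolution power of exp(log_pmf), can be cross-checked on small inputs by direct linear-space convolution. A brute-force reference sketch (illustrative only, with none of the shifting or error control above):

import numpy as np

def pvalue_bruteforce(log_pmf, s0, L):
    """log(P(X1 + ... + XL >= s0)) by repeated convolution; O(L * n^2)."""
    pmf = np.exp(log_pmf)
    acc = pmf.copy()
    for _ in range(L - 1):
        acc = np.convolve(acc, pmf)
    tail = acc[s0:].sum()
    return np.log(tail) if tail > 0 else float('-inf')

# Hypothetical pmf over support {0, 1, 2}.
log_pmf = np.log([0.5, 0.3, 0.2])
print(pvalue_bruteforce(log_pmf, s0=4, L=3))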
Example #13
    def baum_welch(self,
                   observation_sequences=None,
                   encoded_sequences=None,
                   max_iter=15,
                   tolerance=0.001):
        # ???? testing only
        if observation_sequences:
            observation_sequences = self.map_binary_to_decimal(
                observation_sequences)
        if encoded_sequences:
            observation_sequences = encoded_sequences

        convergence = AbsoluteConvergence(max_iter, tolerance)

        # 1. Initialization
        N = len(observation_sequences)
        logN = math.log(N)
        log_ksi = []
        log_gamma = []

        for i in range(N):
            T = len(observation_sequences[i])
            log_ksi.append([])
            log_gamma.append(np.zeros(shape=(T, self.state_dim)))
            for t in range(T):
                log_ksi[i].append(
                    np.zeros(shape=(self.state_dim, self.state_dim)))

        stop = False

        TMax = max([len(obs) for obs in observation_sequences])
        # ln_alpha = np.zeros(shape=(TMax, self.state_dim))
        # ln_beta = np.zeros(shape=(TMax, self.state_dim))

        new_ll = float('-inf')
        old_ll = float('-inf')

        # 2. Iterate until convergence or max iterations is reached
        while not stop:
            ll_total = 0.0
            # for each sequence in the observation_sequences
            for i in range(N):
                observation = observation_sequences[i]
                T = len(observation)
                tmp_log_gamma = log_gamma[i]
                # w = log_weights[i]

                # 1st step: calculating the forward and backward prob for each HMM state
                ln_alpha, ln_beta, ln_alpha_ll, ln_beta_ll = self.forward_backward(
                    observation)

                # 2nd step: determining the freq of the transition-emission pair values, and dividing it by the prob of the entire string
                # compute the gamma values for next computations
                for t in range(T):
                    ln_sum = float('-inf')
                    for k in range(self.state_dim):
                        tmp_log_gamma[t][k] = ln_alpha[t][k] + ln_beta[t][k]  # + w
                        ln_sum = log_sum(ln_sum, tmp_log_gamma[t][k])

                    # Normalize if different from zero
                    if ln_sum != float('-inf'):
                        for k in range(self.state_dim):
                            tmp_log_gamma[t][k] = tmp_log_gamma[t][k] - ln_sum

                # Calculate ksi values for next computations
                log_ksi = self.ksi(i, observation, ln_alpha, ln_beta, log_ksi)

                # Compute loglikelihood for the given sequence
                seq_ll = float('-inf')
                for j in range(self.state_dim):
                    seq_ll = log_sum(seq_ll, ln_alpha[T - 1][j])
                ll_total += seq_ll

            # Average the loglikelihood over all sequences
            new_ll = ll_total / N
            convergence.set_new_value(new_ll)

            # Check for convergence
            if not convergence.has_converged():
                print('not converged!')
                # 3. Continue with the param re-estimation
                old_ll = new_ll
                new_ll = float('-inf')

                # 3.1 Re-estimate of initial state prob
                for i in range(len(self.log_pi)):
                    ln_sum = float('-inf')
                    for k in range(N):
                        ln_sum = log_sum(ln_sum, log_gamma[k][0][i])

                    print('>>>>>>> should update pi, ln_sum:', ln_sum, ', logN:', logN, ', i:', i)
                    self.log_pi[i] = ln_sum - logN

                # 3.2 Re-estimate of transition probabilities
                for i in range(self.state_dim):
                    for j in range(self.state_dim):
                        ln_num = float('-inf')
                        ln_den = float('-inf')
                        for k in range(N):
                            T = len(observation_sequences[k])
                            for t in range(T - 1):
                                ln_num = log_sum(ln_num, log_ksi[k][t][i][j])
                                ln_den = log_sum(ln_den, log_gamma[k][t][i])

                        print('>>>>>>> should update trans')
                        if ln_num == ln_den:
                            self.log_transitions[i][j] = 0
                        else:
                            self.log_transitions[i][j] = ln_num - ln_den

                # Update the emission prob matrix
                for i in range(self.state_dim):
                    for j in range(self.encoded_observation_dim):
                        ln_num = float('-inf')
                        ln_den = float('-inf')
                        for k in range(N):
                            T = len(observation_sequences[k])
                            gamma_k = log_gamma[k]

                            for t in range(T):
                                if observation_sequences[k][t] == j:
                                    ln_num = log_sum(ln_num, gamma_k[t][i])
                                ln_den = log_sum(ln_den, gamma_k[t][i])
                        print('>>>>>>> should update emit')
                        self.log_emissions[i][j] = ln_num - ln_den

            else:
                print('converged?')
                stop = True

        # Update the non_log params
        for i in range(len(self.log_pi)):
            self.pi[i] = math.exp(self.log_pi[i])
        for i in range(len(self.log_transitions)):
            for j in range(len(self.log_transitions)):
                self.transitions[i][j] = math.exp(self.log_transitions[i][j])
        for i in range(len(self.log_emissions)):
            for j in range(len(self.log_emissions[0])):
                self.emissions[i][j] = math.exp(self.log_emissions[i][j])

        return new_ll