Пример #1
0
    def E_step(self):
        """Compute expected statistics from log-space forward/backward messages.

        Forward messages use ``self.hmm_trans_matrix``, ``self.pi_0`` and
        ``self.aBl``; backward messages use the same transition matrix and
        ``self.aBl``.  The messages are combined by
        ``HMMStatesPython._expected_statistics_from_messages``, whose three
        results are stored as ``expected_states``, ``expected_transcounts``
        and ``_normalizer``.

        Two summaries of the transition counts are then recorded:

        * ``_expected_ns``: a copy of the diagonal of the count matrix;
        * ``_expected_tots``: per-row sums, diagonal included.

        Finally the diagonal of ``expected_transcounts`` is zeroed in place.
        """
        trans = self.hmm_trans_matrix
        log_lik = self.aBl

        fwd_msgs = HMMStatesEigen._messages_forwards_log(trans, self.pi_0, log_lik)
        bwd_msgs = HMMStatesEigen._messages_backwards_log(trans, log_lik)
        stats = HMMStatesPython._expected_statistics_from_messages(
            trans, log_lik, fwd_msgs, bwd_msgs)
        self.expected_states, self.expected_transcounts, self._normalizer = stats

        counts = self.expected_transcounts
        # The original author notes these two summaries are untested.  Both
        # are taken BEFORE the diagonal is cleared below.  The diagonal is
        # copied because numpy may hand back a view of ``counts``.
        self._expected_ns = counts.diagonal().copy()
        self._expected_tots = counts.sum(axis=1)

        # Zero the diagonal in place.  This assumes ``counts`` is square
        # (num_states x num_states).
        np.fill_diagonal(counts, 0.0)
Пример #2
0
    def meanfieldupdate_Estep(self):
        """Monte Carlo mean-field E step over sampled duration ``r`` values.

        The number of rounds is ``self.model.mf_num_samples`` (10 if the model
        lacks that attribute).  Each round:

        1. calls ``_resample_r_from_mf()`` on every duration distribution and
           then ``self.clear_caches()``;
        2. builds log likelihoods by repeating each column of ``self.mf_aBl``
           ``self.rs`` times along axis 1;
        3. runs ``HMMStatesEigen._messages_log`` with ``self.mf_bwd_trans_matrix``
           and ``self.hmm_mf_bwd_pi_0``;
        4. converts the resulting HMM statistics with
           ``self._hmm_stats_to_hsmm_stats`` and averages them into
           ``self.expected_states`` / ``self.expected_transcounts``;
        5. draws ``self.model.mf_num_stateseq_samples_per_r`` (default 1) state
           sequences via ``self._resample_from_mf``, and averages a
           per-state histogram of ``self.durations_censored`` into
           ``self.expected_durations``.

        Resulting shapes: ``expected_states`` (T, num_states),
        ``expected_transcounts`` (num_states, num_states),
        ``expected_durations`` (num_states, T).
        """
        # TODO bug in here? it's not as good as sampling
        model = self.model
        if hasattr(model, 'mf_num_samples'):
            num_r_samples = model.mf_num_samples
        else:
            num_r_samples = 10
        if hasattr(model, 'mf_num_stateseq_samples_per_r'):
            seqs_per_r = model.mf_num_stateseq_samples_per_r
        else:
            seqs_per_r = 1
        # Every sampled state sequence contributes with this (integer) weight
        # denominator to the duration histograms.
        total_seq_samples = num_r_samples * seqs_per_r

        num_states, T = self.num_states, self.T
        self.expected_states = np.zeros((T, num_states))
        self.expected_transcounts = np.zeros((num_states, num_states))
        self.expected_durations = np.zeros((num_states, T))
        expected_durations = self.expected_durations

        mf_aBl = self.mf_aBl

        for _ in xrange(num_r_samples):
            for dur_distn in self.dur_distns:
                dur_distn._resample_r_from_mf()
            # Called right after r is resampled.  Presumably this invalidates
            # quantities derived from the old r values (read just below).
            self.clear_caches()

            trans = self.mf_bwd_trans_matrix  # TODO check this
            init = self.hmm_mf_bwd_pi_0
            aBl = mf_aBl.repeat(self.rs, axis=1)

            alphal, betal = HMMStatesEigen._messages_log(self, trans, init, aBl)

            # State-occupancy / transition statistics straight from messages.
            hmm_states, hmm_transcounts, normalizer = \
                HMMStatesPython._expected_statistics_from_messages(
                    trans, aBl, alphal, betal)
            hsmm_states, hsmm_transcounts, _ = self._hmm_stats_to_hsmm_stats(
                hmm_states, hmm_transcounts, normalizer)
            self.expected_states += hsmm_states / num_r_samples
            self.expected_transcounts += hsmm_transcounts / num_r_samples

            # Duration statistics from sampled state sequences.
            for _ in xrange(seqs_per_r):
                self._resample_from_mf(trans, init, aBl, alphal, betal)
                durations = self.durations_censored
                segment_labels = self.stateseq_norep
                for state in xrange(num_states):
                    # NOTE(review): bin d counts duration d, so bin 0 is never
                    # used.  If a censored duration can equal T (e.g. a single
                    # segment spanning the whole sequence), the ``[:T]`` slice
                    # silently drops it.  Confirm this indexing matches how
                    # expected_durations is consumed.
                    hist = np.bincount(durations[segment_labels == state],
                                       minlength=T)
                    expected_durations[state] += \
                        hist[:T].astype(np.double) / total_seq_samples
Пример #3
0
    def E_step(self):
        """Run log-space message passing and store expected statistics.

        Side effects on ``self``:

        - ``expected_states``, ``expected_transcounts``, ``_normalizer``: the
          three outputs of ``HMMStatesPython._expected_statistics_from_messages``.
          That call is fed forward messages (from ``hmm_trans_matrix``,
          ``pi_0``, ``aBl``) and backward messages (from ``hmm_trans_matrix``,
          ``aBl``).
        - ``_expected_ns``: copy of the count matrix's diagonal.
        - ``_expected_tots``: row sums of the count matrix, diagonal included.
        - the diagonal of ``expected_transcounts`` is then set to 0.
        """
        A = self.hmm_trans_matrix
        aBl = self.aBl
        # Forward first, then backward (same call order as before).
        messages = (
            HMMStatesEigen._messages_forwards_log(A, self.pi_0, aBl),
            HMMStatesEigen._messages_backwards_log(A, aBl),
        )
        (self.expected_states,
         self.expected_transcounts,
         self._normalizer) = HMMStatesPython._expected_statistics_from_messages(
            A, aBl, *messages)

        tc = self.expected_transcounts
        # Presumably square (num_states x num_states) -- index its main diagonal.
        diag = np.arange(tc.shape[0])

        # Untested downstream (per the original author).  Capture these before
        # clearing the diagonal; fancy indexing already returns a copy.
        self._expected_ns = tc[diag, diag]
        self._expected_tots = tc.sum(1)

        tc[diag, diag] = 0.
Пример #4
0
    def meanfieldupdate_Estep(self):
        """Mean-field E step averaged over sampled duration ``r`` values.

        Settings are read from ``self.model`` when present:
        ``mf_num_samples`` (default 10) rounds of r resampling, and
        ``mf_num_stateseq_samples_per_r`` (default 1) state-sequence draws per
        round.

        In every round the duration distributions resample ``r`` from their
        mean-field state and caches are cleared.  Log messages are then run on
        ``self.mf_aBl`` with each column repeated ``self.rs`` times.  HSMM
        state and transition statistics are averaged into
        ``self.expected_states`` and ``self.expected_transcounts``.  Sampled
        state sequences feed per-state duration histograms that are averaged
        into ``self.expected_durations``.
        """
        # TODO bug in here? it's not as good as sampling

        def model_setting(name, default):
            # Same lookup as ``self.model.<name> if hasattr(...) else default``.
            model = self.model
            return getattr(model, name) if hasattr(model, name) else default

        num_r_samples = model_setting('mf_num_samples', 10)
        num_stateseq_samples_per_r = model_setting(
            'mf_num_stateseq_samples_per_r', 1)
        duration_denom = num_r_samples * num_stateseq_samples_per_r

        T, K = self.T, self.num_states
        self.expected_states = np.zeros((T, K))
        self.expected_transcounts = np.zeros((K, K))
        self.expected_durations = np.zeros((K, T))

        def accumulate_durations():
            # Add one sampled sequence's per-state histogram of censored
            # segment durations, weighted by 1/duration_denom.
            # NOTE(review): bin d holds duration d (bin 0 unused).  The [:T]
            # slice drops a duration equal to T, if one can occur -- verify
            # against consumers of expected_durations.
            durs = self.durations_censored
            seg_states = self.stateseq_norep
            for k in xrange(K):
                hist = np.bincount(durs[seg_states == k], minlength=T)
                self.expected_durations[k] += \
                    hist[:T].astype(np.double) / duration_denom

        mf_aBl = self.mf_aBl

        for _ in xrange(num_r_samples):
            for distn in self.dur_distns:
                distn._resample_r_from_mf()
            self.clear_caches()

            trans = self.mf_bwd_trans_matrix  # TODO check this
            init = self.hmm_mf_bwd_pi_0
            aBl = mf_aBl.repeat(self.rs, axis=1)

            fwd, bwd = HMMStatesEigen._messages_log(self, trans, init, aBl)

            # Occupancy and transition statistics from the messages.
            stats = HMMStatesPython._expected_statistics_from_messages(
                trans, aBl, fwd, bwd)
            states_k, transcounts_k, _ = self._hmm_stats_to_hsmm_stats(*stats)
            self.expected_states += states_k / num_r_samples
            self.expected_transcounts += transcounts_k / num_r_samples

            # Duration statistics by sampling state sequences from messages.
            for _ in xrange(num_stateseq_samples_per_r):
                self._resample_from_mf(trans, init, aBl, fwd, bwd)
                accumulate_durations()
Пример #5
0
 def messages_backwards_log_hmm(self):
     """Return the backward messages (log variant) for the HMM embedding.

     Delegates to ``HMMStatesPython._messages_backwards_log`` with
     ``self.hmm_backwards_trans_matrix`` and ``self.hmm_aBl``.  The result is
     returned unchanged.
     """
     trans_matrix = self.hmm_backwards_trans_matrix
     log_likelihoods = self.hmm_aBl
     return HMMStatesPython._messages_backwards_log(trans_matrix, log_likelihoods)
Пример #6
0
 def messages_backwards_normalized_hmm(self):
     """Return the normalized backward messages for the HMM embedding.

     Thin wrapper: passes ``self.hmm_backwards_trans_matrix``,
     ``self.hmm_backwards_pi_0`` and ``self.hmm_aBl``, in that order, to
     ``HMMStatesPython._messages_backwards_normalized``.  Whatever it
     returns is passed through untouched.
     """
     message_args = (
         self.hmm_backwards_trans_matrix,
         self.hmm_backwards_pi_0,
         self.hmm_aBl,
     )
     return HMMStatesPython._messages_backwards_normalized(*message_args)