def backward(self):
    """Backward (smoothing) recursion.

    Once this returns, ``self.smth`` holds one entry per element of
    ``self.filt``: the marginal smoothing probabilities.

    Note
    ----
    If the forward pass has not been run yet (``self.filt`` is empty),
    ``self.forward()`` is called first.
    """
    if not self.filt:
        self.forward()
    # The last smoothing marginal is the last filtering distribution itself.
    smoothed = self.smth = [self.filt[-1]]
    log_trans = np.log(self.hmm.trans_mat)
    n_states = self.hmm.dim
    # Cost to go: log-likelihood of y_{t+1:T} given x_t = k (zero at time T).
    cost_to_go = np.zeros(n_states)
    # Pair filter at t with the log emission densities at t+1, walk backwards.
    steps = list(zip(self.filt[:-1], self.logft[1:]))
    for filt_t, logft_next in steps[::-1]:
        cost_to_go = np.array([
            rs.log_sum_exp(log_trans[k, :] + logft_next + cost_to_go)
            for k in range(n_states)
        ])
        smoothed.append(rs.exp_and_normalise(np.log(filt_t) + cost_to_go))
    # Entries were appended from time T down to 0; restore time order.
    smoothed.reverse()
def filt_step(self, t, yt):
    """Perform one filtering update at time ``t`` given observation ``yt``.

    Appends one entry to each of:

    * ``self.logft``: log emission densities ``log f_t(yt | x_t = k)`` for
      every state ``k``;
    * ``self.logpyt``: the scalar log normalising constant of this update,
      i.e. log-sum-exp of the unnormalised log filter (presumably
      ``log p(y_t | y_{0:t-1})`` when ``self.pred[-1]`` is the predictive
      distribution of ``x_t`` -- confirm against ``forward``);
    * ``self.filt``: the normalised filtering probabilities.
    """
    emis = self.hmm.PY(t, None, np.arange(self.hmm.dim)).logpdf(yt)
    lp = np.log(self.pred[-1]) + emis  # unnormalised log filter
    logpyt = rs.log_sum_exp(lp)
    self.logft.append(emis)
    # Store the scalar normalising constant, not the vector ``lp``.
    self.logpyt.append(logpyt)
    self.filt.append(np.exp(lp - logpyt))
def update_M(self, new_weights: rs.Weights, normalize: bool,
             normalize_info: typing.Any = None):
    """Record a new set of weights in the current log-weight state.

    Parameters
    ----------
    new_weights:
        Weights object; only its ``W`` attribute is read (presumably the
        normalised weights -- confirm against ``rs.Weights``).
    normalize:
        If True, stash the current log-weights and ``normalize_info`` in
        ``self.pre_normalize_weights`` / ``self.normalize_information`` and
        restart from ``log(new_weights.W)``.  If False, adopt
        ``log(new_weights.W)`` shifted by the log-sum-exp of the current
        log-weights, so the total log-mass carried so far is kept.
    normalize_info:
        Arbitrary payload recorded alongside the stashed weights.
    """
    if not normalize:
        # NOTE(review): this branch uses ``ut.log`` while the other uses
        # ``np.log``; confirm the difference is intentional.
        log_w = ut.log(new_weights.W)
        log_mass = rs.log_sum_exp(self.current_weights)
        self.current_weights = log_w + log_mass
        return
    self.pre_normalize_weights.append(self.current_weights)
    self.normalize_information.append(normalize_info)
    self.current_weights = np.log(new_weights.W)
def run(self):
    """Run nested sampling until ``self.stopping_time()`` returns True.

    Each call to ``self.step()`` is expected to append one point (with an
    ``llik`` attribute) to ``self.points``.  The i-th point (0-based) gets
    log-weight ``log(1 - exp(-1/N)) - i/N``, the log of the prior-mass
    shrinkage ``exp(-i/N) - exp(-(i+1)/N)``.

    Populates:

    * ``self.points``: the points produced by ``self.step()``;
    * ``self.log_weights``: one log-weight per point (same length);
    * ``self.lZhats``: running estimates ``log sum_{j<=i} w_j L_j`` of the
      log evidence.
    """
    self.setup()
    self.points = []
    # log(1 - exp(-1/N)) computed via expm1 to avoid cancellation for large N.
    self.log_weights = [np.log(-np.expm1(-1. / self.N))]
    self.step()
    self.lZhats = [self.log_weights[0] + self.points[0].llik]
    while True:
        self.step()
        # Shrink the weight *before* pairing it with the new point, so that
        # point i uses log_weights[i] (previously it reused the prior weight).
        self.log_weights.append(self.log_weights[-1] - 1. / self.N)
        b = self.log_weights[-1] + self.points[-1].llik
        # log(exp(lZ) + exp(b)), computed stably.
        self.lZhats.append(np.logaddexp(self.lZhats[-1], b))
        if self.stopping_time():
            break
        if len(self.log_weights) % self.N == 0:
            print('iteration %i: log(Z_hat) = %f' % (len(self.log_weights),
                                                    self.lZhats[-1]))
def logLT(self) -> float:
    """Return the accumulated log-mass of all weight epochs.

    Adds up the log-sum-exp of every stashed pre-normalisation log-weight
    array and of the current log-weights (presumably the log-likelihood
    estimate up to the current time -- confirm with callers).
    """
    epochs = self.pre_normalize_weights + [self.current_weights]
    return sum(map(rs.log_sum_exp, epochs))