def _update_group_vars(args): c, _, lu, zeta, counts_norm, (log_phi, log_mu, beta, lw) = args ## nfeatures, nreplicas = counts_norm.shape ## loglik = _compute_loglik(counts_norm[:, :, np.newaxis], log_phi.reshape(-1, 1, 1), log_mu.reshape(-1, 1, 1), beta) loglik = loglik.sum(1) # update d logw = loglik[:, c] + lu logw = ut.normalize_log_weights(logw.T) d = st.sample_categorical(np.exp(logw)).ravel() # update d: merge null d_ = np.zeros(nfeatures, dtype='int') ll = _compute_loglik(counts_norm, log_phi.reshape(-1, 1), log_mu.reshape(-1, 1), beta[c[d]].reshape(-1, 1)).sum(-1) ll_ = _compute_loglik(counts_norm, log_phi.reshape(-1, 1), log_mu.reshape(-1, 1), beta[c[d_]].reshape(-1, 1)).sum(-1) idxs = (ll_ >= ll) | (rn.rand(nfeatures) < np.exp(ll_ - ll)) d[idxs] = d_[idxs] ## occ = np.bincount(d, minlength=lu.size) iact = occ > 0 kact = np.sum(iact) # update c logw = np.vstack([loglik[d == k].sum(0) for k in np.nonzero(iact)[0]]) + lw logw = ut.normalize_log_weights(logw.T) c[iact] = st.sample_categorical(np.exp(logw)).ravel() c[~iact] = rn.choice(lw.size, c.size - kact, p=np.exp(lw)) c[0] = 0 # update zeta zeta = st.sample_eta_west(zeta, kact, nfeatures) # update lu lu[:], _ = st.sample_stick(occ, zeta) ## return c, d, c[d], lu, zeta
def update(self, data, pool): """Implements a single step of the blocked Gibbs sampler""" ## self.iter += 1 ## # self._update_phi_global(data) # self._update_phi_local(data) # self._update_mu(data) # self._update_beta_global(data) # self._update_beta_local(data) if rn.rand() < 0.5: _update_phi_global(self, data) _update_mu(self, data) _update_beta_global(self, data) else: _update_phi_local(self, data) _update_mu(self, data) _update_beta_local(self, data) # update group-specific variables counts_norm, _ = data common_args = it.repeat( (self.log_phi, self.log_mu, self.beta, self.lw)) args = zip(self.c[1:], self.d[1:], self.lu[1:], self.zeta[1:], counts_norm[1:], common_args) if pool is None: self.c[1:], self.d[1:], self.z[1:], self.lu[1:], self.zeta[ 1:] = zip(*map(_update_group_vars, args)) else: self.c[1:], self.d[1:], self.z[1:], self.lu[1:], self.zeta[ 1:] = zip(*pool.map(_update_group_vars, args)) # update occupancies self.occ[:] = np.bincount(self.z[1:].ravel(), minlength=self.lw.size) self.iact[:] = self.occ > 0 self.nact = np.sum(self.iact) # update eta self.eta = st.sample_eta_west(self.eta, self.nact, self.occ.sum()) # update weights self.lw[:], _ = st.sample_stick(self.occ, self.eta) # update hyper-parameters _update_hpars(self)
def update(self, data, pool):
    """Perform one sweep of the blocked Gibbs sampler.

    Parameters:
        data -- tuple whose first element is the per-group normalized
                count matrices; remaining structure consumed elsewhere —
                TODO confirm against caller
        pool -- optional pool with a ``map`` method; None means serial

    Mutates the sampler state in place: feature-level parameters, each
    group's cluster indicators and stick weights, the global occupancies,
    concentration ``eta``, log weights ``lw`` and the hyper-parameters.
    """
    self.iter += 1

    # Flip a coin between the "global" and "local" proposal moves for
    # phi/beta; mu is re-sampled on either branch.
    if rn.rand() < 0.5:
        samplers = (_update_phi_global, _update_mu, _update_beta_global)
    else:
        samplers = (_update_phi_local, _update_mu, _update_beta_local)
    for sampler in samplers:
        sampler(self, data)

    # Resample group-specific variables (group 0 is skipped).
    counts_norm, _ = data
    shared_state = it.repeat((self.log_phi, self.log_mu, self.beta, self.lw))
    group_args = zip(self.c[1:], self.d[1:], self.lu[1:], self.zeta[1:],
                     counts_norm[1:], shared_state)
    mapper = map if pool is None else pool.map
    updated = zip(*mapper(_update_group_vars, group_args))
    self.c[1:], self.d[1:], self.z[1:], self.lu[1:], self.zeta[1:] = updated

    # Global cluster occupancies and activity flags.
    self.occ[:] = np.bincount(self.z[1:].ravel(), minlength=self.lw.size)
    self.iact[:] = self.occ > 0
    self.nact = np.sum(self.iact)

    # Concentration parameter of the global DP.
    self.eta = st.sample_eta_west(self.eta, self.nact, self.occ.sum())

    # Global log weights (updated in place; auxiliary output discarded).
    self.lw[:], _ = st.sample_stick(self.occ, self.eta)

    _update_hpars(self)
def _update_group_vars(args):
    """Resample one group's clustering variables (one blocked-Gibbs step).

    ``args`` packs, in order: the local-to-global cluster map ``c``, the
    previous feature indicators (ignored; re-drawn from scratch), the log
    local stick weights ``lu`` (mutated in place), the concentration
    ``zeta``, the group's normalized counts (assumed shape
    (nfeatures, nreplicas) — TODO confirm), and the shared global state
    ``(log_phi, log_mu, beta, lw)``.

    Returns ``(c, d, c[d], lu, zeta)``; ``c[d]`` gives each feature's
    global cluster label.
    """
    c, _, lu, zeta, counts_norm, (log_phi, log_mu, beta, lw) = args
    nfeatures, _ = counts_norm.shape

    # Per-feature, per-global-cluster log-likelihood, summed over replicas.
    loglik = _compute_loglik(counts_norm[:, :, np.newaxis],
                             log_phi.reshape(-1, 1, 1),
                             log_mu.reshape(-1, 1, 1),
                             beta).sum(1)

    # Draw d: assign each feature to a local cluster.
    weights = ut.normalize_log_weights((loglik[:, c] + lu).T)
    d = st.sample_categorical(np.exp(weights)).ravel()

    # Metropolis "merge with null" move: propose local cluster 0 for every
    # feature and accept per-feature in log space.
    d_null = np.zeros(nfeatures, dtype='int')
    ll_cur = _compute_loglik(counts_norm, log_phi.reshape(-1, 1),
                             log_mu.reshape(-1, 1),
                             beta[c[d]].reshape(-1, 1)).sum(-1)
    ll_null = _compute_loglik(counts_norm, log_phi.reshape(-1, 1),
                              log_mu.reshape(-1, 1),
                              beta[c[d_null]].reshape(-1, 1)).sum(-1)
    accept = np.log(rn.rand(nfeatures)) < ll_null - ll_cur
    d[accept] = d_null[accept]

    # Local cluster occupancies.
    occ = np.bincount(d, minlength=lu.size)
    iact = occ > 0
    kact = np.sum(iact)

    # Draw c: posterior for occupied local clusters, prior (global weights
    # lw) for empty ones; local cluster 0 stays pinned to global cluster 0.
    active = np.nonzero(iact)[0]
    weights = np.vstack([loglik[d == k].sum(0) for k in active]) + lw
    weights = ut.normalize_log_weights(weights.T)
    c[iact] = st.sample_categorical(np.exp(weights)).ravel()
    c[~iact] = rn.choice(lw.size, c.size - kact, p=np.exp(lw))
    c[0] = 0

    # Concentration and stick weights of the local DP (lu updated in place).
    zeta = st.sample_eta_west(zeta, kact, nfeatures)
    lu[:], _ = st.sample_stick(occ, zeta)

    return c, d, c[d], lu, zeta