def step(self):
    """Update the GSM scale variables, the bias, and the noise node in place.

    Model notes (carried over from the original comments)::

        S ~ N(0, exp(Z / 2))

        Z = signal_node + noise_node + bias
          = signal_node + gaussian_term
          = scale_node + bias

    The update proceeds in three stages:

    1. Every entry ``Z[i, k]`` is updated given ``S[i, k]`` and a Gaussian
       term with mean ``mu[i, k]`` and variance ``sigma_sq[i, k]``. When
       ``self.maximize`` is true the entry is optimized with
       ``scipy.optimize.fmin``; otherwise it is drawn with
       ``slice_sampling.slice_sample_gauss``.
    2. Unless ``self.maximize`` is true, ``self.gsm_node.bias`` is redrawn
       according to ``self.gsm_node.bias_type`` (``'scalar'``, ``'row'`` or
       ``'col'``). Any other ``bias_type`` leaves the bias untouched.
    3. The noise node (``scale_node`` itself when it is a leaf, otherwise
       its last child) is set to ``gaussian_term - bias``.

    Returns:
        None. ``self.gsm_node.bias`` and the noise node are mutated.
    """
    gsm_node = self.gsm_node
    scale_node = gsm_node.scale_node
    S = gsm_node.value()
    N, K = S.shape

    # --- resample Z, conditioned on the signal part of scale_node ---
    Z = scale_node.value() + gsm_node.bias
    if scale_node.isleaf():
        mu = gsm_node.bias * np.ones((N, K))
        sigma_sq = scale_node.variance()
    else:
        assert scale_node.issum()
        # The last child is treated as the noise node; everything else in
        # the sum is the signal.
        noise_node = scale_node.children[-1]
        mu = gsm_node.bias + scale_node.value() - noise_node.value()
        sigma_sq = noise_node.variance()

    # sigma_sq is indexed as [i, k] below, so it is presumably an (N, K)
    # array -- TODO confirm against the node's variance() implementation.
    for i in range(N):
        for k in range(K):
            log_f = sparse_coding.LogFUncollapsed(S[i, k])
            if self.maximize:
                mu_ik = mu[i, k]
                sigma_sq_ik = sigma_sq[i, k]

                def neg_log_posterior(z):
                    return -log_f(z) - distributions.gauss_loglik(z, mu_ik, sigma_sq_ik)

                # fmin returns a length-1 array for a scalar start point;
                # take the element explicitly instead of relying on implicit
                # array-to-scalar conversion (deprecated in recent NumPy).
                Z[i, k] = scipy.optimize.fmin(neg_log_posterior, Z[i, k], disp=False)[0]
            else:
                Z[i, k] = slice_sampling.slice_sample_gauss(
                    log_f, mu[i, k], sigma_sq[i, k], Z[i, k])

    # --- resample bias ---
    if scale_node.isleaf():
        gaussian_term = Z
    else:
        signal = scale_node.value() - scale_node.children[-1].value()
        gaussian_term = Z - signal

    if not self.maximize:
        # lam is a summed precision, so 1 / lam is a *variance*.
        # np.random.normal expects a standard deviation as its second
        # argument, hence the square root.
        # NOTE(review): the mean used here is the unweighted average of
        # gaussian_term. If sigma_sq is not constant over the entries being
        # pooled, the exact conditional mean would be precision-weighted --
        # confirm the intended model before changing it.
        bias_type = gsm_node.bias_type
        if bias_type == 'scalar':
            mu = gaussian_term.mean()
            lam = (1. / sigma_sq).sum()
            gsm_node.bias = np.random.normal(mu, np.sqrt(1. / lam))
        elif bias_type == 'row':
            mu = gaussian_term.mean(1)
            lam = (1. / sigma_sq).sum(1)
            gsm_node.bias = np.random.normal(mu, np.sqrt(1. / lam))[:, nax]
        elif bias_type == 'col':
            mu = gaussian_term.mean(0)
            lam = (1. / sigma_sq).sum(0)
            gsm_node.bias = np.random.normal(mu, np.sqrt(1. / lam))[nax, :]

    # --- set noise node ---
    noise_term = gaussian_term - gsm_node.bias
    if scale_node.isleaf():
        scale_node.set_value(noise_term)
    else:
        scale_node.children[-1].set_value(noise_term)
def sample_Z(state):
    """Resample every entry of ``state.Z`` in place by slice sampling.

    For each ``i`` in ``range(state.S.shape[0])`` and each ``k`` in
    ``range(state.Z.shape[1])``, ``state.Z[i, k]`` is replaced by the result
    of ``slice_sampling.slice_sample_gauss`` called with
    ``LogFUncollapsed(state.S[i, k])``, a mean, the variance
    ``state.sigma_sq_Z``, and the current value of ``state.Z[i, k]``.
    The mean is ``state.mu_Z`` itself when that is a scalar, otherwise
    ``state.mu_Z[k]``.

    After the sweep, ``debugger.after_sample_Z`` is called with this
    function's local variables if the debugger defines such a hook.

    Returns:
        None. ``state.Z`` is mutated.
    """
    # The local names below are kept stable on purpose: the debugger hook
    # receives vars(), so the names of locals are visible to it.
    N = state.S.shape[0]
    K = state.Z.shape[1]

    for i in range(N):
        for k in range(K):
            log_f = LogFUncollapsed(state.S[i, k])
            mu_Z = state.mu_Z if np.isscalar(state.mu_Z) else state.mu_Z[k]
            state.Z[i, k] = slice_sampling.slice_sample_gauss(
                log_f, mu_Z, state.sigma_sq_Z, state.Z[i, k])

    if not hasattr(debugger, 'after_sample_Z'):
        return
    debugger.after_sample_Z(vars())