Example 1
0
    def step(self):
        """Resample the scale variables Z of the GSM node, then its bias.

        Model::

            S ~ N(0, exp(Z / 2))

            Z = signal_node + noise_node + bias
              = signal_node +   gaussian_term
              =       scale_node         + bias

        ``signal`` is everything in ``scale_node`` except its last child
        (zero when ``scale_node`` is a leaf). The noise node is that last
        child, or ``scale_node`` itself when it is a leaf.

        Steps:
          1. For every entry, update Z[i, k] given S[i, k] under the Gaussian
             prior N(bias + signal, sigma_sq). With ``self.maximize`` set this
             minimizes the negative log posterior with ``scipy.optimize.fmin``.
             Otherwise it takes one ``slice_sample_gauss`` step.
          2. Unless maximizing, draw the bias from its conditional posterior
             given ``gaussian_term = Z - signal``. The formulas used correspond
             to a flat prior on the bias. ``gsm_node.bias_type`` selects one
             scalar ('scalar'), one value per row ('row', shape (N, 1)) or one
             value per column ('col', shape (1, K)).
          3. Write ``gaussian_term - bias`` back into the noise node.
        """
        scale_node = self.gsm_node.scale_node
        S = self.gsm_node.value()
        N, K = S.shape

        # Current Z, and the Gaussian prior on Z implied by the rest of the
        # model (everything except the noise term being resampled).
        Z = scale_node.value() + self.gsm_node.bias
        if scale_node.isleaf():
            signal = 0.
            sigma_sq = scale_node.variance()
        else:
            assert scale_node.issum()
            # scale_node is a sum whose last child is the noise term.
            signal = scale_node.value() - scale_node.children[-1].value()
            sigma_sq = scale_node.children[-1].variance()
        # Broadcast to (N, K) so that mu[i, k] is valid for every bias_type.
        mu = (self.gsm_node.bias + signal) * np.ones((N, K))

        # Resample (or maximize) each Z[i, k] independently.
        for i in range(N):
            for k in range(K):
                log_f = sparse_coding.LogFUncollapsed(S[i, k])
                if self.maximize:
                    # Bind the per-entry values explicitly instead of closing
                    # over the loop variables.
                    def neg_log_post(z, log_f=log_f, m=mu[i, k], v=sigma_sq[i, k]):
                        return -log_f(z) - distributions.gauss_loglik(z, m, v)
                    # fmin returns a length-1 ndarray for a scalar start point.
                    Z[i, k] = scipy.optimize.fmin(neg_log_post, Z[i, k], disp=False)[0]
                else:
                    Z[i, k] = slice_sampling.slice_sample_gauss(log_f, mu[i, k], sigma_sq[i, k], Z[i, k])

        # gaussian_term = bias + noise
        gaussian_term = Z - signal

        # Resample the bias.
        if not self.maximize:
            # gaussian_term[i, k] ~ N(bias, sigma_sq[i, k]). With a flat prior,
            # each bias value has a Gaussian posterior whose precision is
            # lam = sum(1 / sigma_sq) and whose mean is precision-weighted.
            # The weighted mean equals the plain mean when sigma_sq is constant.
            prec = 1. / sigma_sq
            weighted = prec * gaussian_term
            bias_type = self.gsm_node.bias_type
            # np.random.normal takes a standard deviation, hence np.sqrt(1 / lam).
            if bias_type == 'scalar':
                lam = prec.sum()
                post_mean = weighted.sum() / lam
                self.gsm_node.bias = np.random.normal(post_mean, np.sqrt(1. / lam))
            elif bias_type == 'row':
                lam = prec.sum(1)
                post_mean = weighted.sum(1) / lam
                self.gsm_node.bias = np.random.normal(post_mean, np.sqrt(1. / lam))[:, nax]
            elif bias_type == 'col':
                lam = prec.sum(0)
                post_mean = weighted.sum(0) / lam
                self.gsm_node.bias = np.random.normal(post_mean, np.sqrt(1. / lam))[nax, :]

        # Store the noise term in the noise node.
        noise_term = gaussian_term - self.gsm_node.bias
        if scale_node.isleaf():
            scale_node.set_value(noise_term)
        else:
            scale_node.children[-1].set_value(noise_term)
def sample_Z(state):
    """Resample every entry of ``state.Z`` in place, one entry at a time.

    Rows are iterated over ``state.S`` and columns over ``state.Z``. Entry
    ``Z[i, k]`` is replaced with the result of
    ``slice_sampling.slice_sample_gauss``. That call receives
    ``LogFUncollapsed(S[i, k])``, a mean and ``state.sigma_sq_Z``
    (presumably the prior variance; confirm against slice_sampling), plus the
    current value as the starting point. The mean is ``state.mu_Z`` itself
    when it is a scalar; otherwise it is ``state.mu_Z[k]``.

    Afterwards, if ``debugger`` defines ``after_sample_Z``, that hook is called
    with this function's local variables.
    """
    # NOTE: the hook receives vars(), so the local names used here are part of
    # what the debugger sees; keep them stable.
    N = state.S.shape[0]
    K = state.Z.shape[1]
    for i in range(N):
        for k in range(K):
            log_f = LogFUncollapsed(state.S[i, k])
            mu_Z = state.mu_Z if np.isscalar(state.mu_Z) else state.mu_Z[k]
            state.Z[i, k] = slice_sampling.slice_sample_gauss(
                log_f, mu_Z, state.sigma_sq_Z, state.Z[i, k])

    if hasattr(debugger, 'after_sample_Z'):
        debugger.after_sample_Z(vars())