Пример #1
0
    def _resample_Z_gsl(self, data=None):
        """
        Resample the parent assignments ``self.Z`` from the current bias,
        weight and impulse models.

        For every target process ``k2`` this builds a ``Tk x (1 + K*B)`` matrix
        of unnormalized parent weights:

        * column 0 holds the background rate ``bias_model.lambda0[k2]``;
        * column ``1 + k*B + b`` holds
          ``W_effective[k, k2] * g[k, k2, b] * Fk[t, k*B + b]`` (after ``Fk``
          is reshaped to ``(Tk, K*B)``; presumably ``Fk`` is laid out as
          ``(Tk, K, B)`` -- TODO confirm).

        Each row is normalized to sum to one, and the counts ``Sk`` are split
        among the parents by ``multinomial_par``, which writes its draws into
        the corresponding array of ``self.Z``.

        :param data: unused; kept (with a non-mutable default) so existing
            callers that pass it keep working.
        :return: None. Results are written into the arrays of ``self.Z``.
        """
        bias_model = self.model.bias_model
        weight_model = self.model.weight_model
        impulse_model = self.model.impulse_model
        K, B = self.K, self.B
        KB = K * B

        for k2, (Sk, Fk, Tk, Zk) in enumerate(zip(self.Ss, self.Fs, self.Ts, self.Z)):
            # Unnormalized parent weights, shape (Tk, 1 + K*B): column 0 is the
            # background, the remaining columns are (source k, basis b) pairs.
            P = np.zeros((Tk, 1 + KB))
            P[:, 0] = bias_model.lambda0[k2]

            # Flatten (k, b) in C order so column k*B + b lines up across the
            # repeated weights, the impulse basis and the reshaped Fk.
            Wk2 = np.repeat(weight_model.W_effective[:, k2], B)
            Gk2 = impulse_model.g[:, k2, :].reshape((KB,), order="C")
            P[:, 1:] = Wk2 * Gk2 * Fk.reshape((Tk, KB))

            # Normalize each row into a probability vector.
            # NOTE(review): a row summing to zero produces NaNs here; confirm
            # lambda0 is strictly positive so that cannot happen.
            P /= P.sum(axis=1)[:, None]

            # Sample parents from P with counts Sk; draws are written into Zk.
            multinomial_par(self.pyrngs, Sk, P, Zk)
Пример #2
0
def test_parallel_multi_N_multi_p_with_out():
    """Exercise multinomial_par with per-row counts, per-row probabilities
    and a caller-supplied output array.

    Every draw is checked for shape and for row totals equal to the requested
    counts. The pooled empirical frequencies must then match the target
    probabilities to within 1e-2.
    """
    n_rows, n_cats = 10, 5
    counts = np.arange(n_rows, dtype=np.uint32) + 10

    # First five rows are uniform; the remaining rows use a skewed distribution.
    probs = np.zeros((n_rows, n_cats))
    probs[:5] = 1. / n_cats * np.ones(n_cats)
    probs[5:] = np.asarray([0.5, 0.25, 0.05, 0.1, 0.1])

    # One RNG per thread reported by get_omp_num_threads().
    rngs = [PyRNG() for _ in xrange(get_omp_num_threads())]

    n_iter = 1000000
    totals = np.zeros((n_rows, n_cats))
    for _ in xrange(n_iter):
        draw = np.zeros((n_rows, n_cats), dtype=np.uint32)
        multinomial_par(rngs, counts, probs, draw)
        assert draw.shape == (n_rows, n_cats)
        assert (draw.sum(axis=1) == counts).all()
        totals += draw

    freqs = totals / totals.sum(axis=1)[:, np.newaxis]
    print(freqs)
    assert (np.abs(freqs - probs) < 1e-2).all()
Пример #3
0
 def resample_z(self):
     """Redraw ``self.z`` and then refresh the derived counts.

     Probabilities come from ``self._get_topicprobs()``, and the counts to
     distribute come from ``self.data.data``. ``multinomial_par`` writes its
     draws into ``self.z``, after which ``self._update_counts()`` is called.
     (Presumably ``self.z`` holds per-token topic assignments -- confirm.)
     """
     probs = self._get_topicprobs()
     rngs, counts, assignments = self.pyrngs, self.data.data, self.z
     multinomial_par(rngs, counts, probs, assignments)
     self._update_counts()
Пример #4
0
 def resample_z(self):
     """Redraw ``self.z`` and then refresh the derived counts.

     Probabilities come from ``self.get_topicprobs(self.data)``, and the
     counts to distribute come from ``self.data.data``. ``multinomial_par``
     writes its draws into ``self.z``, after which ``self._update_counts()``
     is called. (Presumably ``self.z`` holds per-token topic assignments --
     confirm.)
     """
     probs = self.get_topicprobs(self.data)
     rngs, counts, assignments = self.pyrngs, self.data.data, self.z
     multinomial_par(rngs, counts, probs, assignments)
     self._update_counts()