def test_permutation():
    """Exercise ``_permutation_by_sort`` and ``_permute`` on small arrays.

    Nothing is asserted here; the intermediate values are returned so the
    caller can inspect them.

    Returns:
        tuple: ``(p, a, a_permuted, b, b_permuted)`` where ``p`` is the
        permutation computed from ``a``, ``a_permuted`` is ``a`` with ``p``
        applied, ``b`` is a 4-row array built from ``a`` and ``b_permuted``
        is ``b`` with ``p`` applied along axis 1.
    """
    a = numpy.array((3, 2, 1, 6, 5, 4, 10, 0))
    p = _permutation_by_sort(a)

    # a_permuted is expected to come out in decreasing order
    a_permuted = _permute(a, p)

    # Four rows, each a copy of a shifted by a multiple of 100.
    # range() is used instead of xrange() so this runs on Python 2 and 3.
    b = numpy.array([a + 100 * i for i in range(4)])

    # b_permuted's second axis is expected to be in decreasing order
    b_permuted = _permute(b, p, axis=1)

    return p, a, a_permuted, b, b_permuted
def _reorder_topic_labels(self):
    """Reorder the topic labels so that the largest topic is always first.

    The ordering is derived from ``self.counts.E_n_k``. If it is already the
    identity, nothing is modified. Otherwise the same permutation is applied
    to every entry of ``self.q_z``, to the ``a``/``b`` parameters of
    ``self.q_pi_bar`` and ``self.q_1_minus_pi_bar``, to the cached
    expectations and finally to ``self.counts``.
    """
    # Permutation that puts the topics in order of size.
    order = _permutation_by_sort(self.counts.E_n_k)

    # Bail out early when the permutation is the identity.
    needs_reordering = (numpy.arange(self.K) != order).any()
    if not needs_reordering:
        return

    # Each q_z entry is permuted along its axis 1.
    for d, q_zd in enumerate(self.q_z):
        self.q_z[d] = _permute(q_zd, order, axis=1)

    # q_1_minus_pi_bar receives the same parameters as q_pi_bar with the
    # roles of a and b swapped.
    new_a = _permute(self.q_pi_bar.a, order, axis=0)
    new_b = _permute(self.q_pi_bar.b, order, axis=0)
    self.q_pi_bar.a = new_a
    self.q_pi_bar.b = new_b
    self.q_1_minus_pi_bar.a = new_b
    self.q_1_minus_pi_bar.b = new_a

    # The cached values all depend on the topic ordering. These four are
    # permuted in place of being recalculated; the tuple order is the order
    # in which they are refreshed.
    permuted_caches = (
        ("E_s_dk", 1),
        ("E_t_kw", 0),
        ("E_log_xi", 0),
        ("G_pi", 0),
    )
    for name, axis in permuted_caches:
        getattr(self, "_calculate_" + name).cached_value = _permute(
            getattr(self, name), order, axis=axis
        )

    # E_log_eta is simply dropped from the cache rather than permuted.
    self._calculate_E_log_eta.clear_cached_value()

    self.counts.permute(order)

    # Sanity check: topic sizes must now be non-increasing.
    sizes = self.counts.E_n_k
    assert (sizes[:-1] >= self.counts.E_n_k[1:]).all()
def permute(self, permutation):
    """Permute the counts based on a permutation of the topics.

    Each family of statistics (``E_n``, ``E_plus_n``, ``V_n``, ``V_plus_n``,
    ``Z_n``, ``P_plus_n``) is stored in three attributes distinguished by
    suffix. The ``_dk`` attribute is permuted along axis 1; the ``_k`` and
    ``_kw`` attributes are permuted along axis 0. Every attribute is
    replaced by the value returned from ``_permute``.
    """
    families = ("E_n", "E_plus_n", "V_n", "V_plus_n", "Z_n", "P_plus_n")
    # (attribute suffix, axis that is permuted)
    suffix_axes = (("dk", 1), ("k", 0), ("kw", 0))

    for family in families:
        for suffix, axis in suffix_axes:
            attr = family + "_" + suffix
            setattr(
                self,
                attr,
                _permute(getattr(self, attr), permutation, axis=axis),
            )