예제 #1
0
    def _compute_e_log_p_gamma_posterior(self, data_type):
        """Return sum(E[log epsilon] * gamma data term) for ``data_type``.

        Both arrays are multiplied elementwise with ``safe_multiply`` and
        summed over every element. When ``self.use_position_specific_gamma``
        is set, the sum is accumulated one position slice (third axis) at a
        time over ``range(self.M[data_type])``.
        """
        log_eps = self.get_e_log_epsilon(data_type)
        gamma_term = self._get_gamma_data_term(data_type)

        if not self.use_position_specific_gamma:
            return np.sum(safe_multiply(log_eps, gamma_term))

        # Per-position partial sums, added left to right starting from 0.
        return sum(
            np.sum(safe_multiply(log_eps[:, :, m], gamma_term[:, :, m]))
            for m in range(self.M[data_type]))
예제 #2
0
    def _compute_e_log_p_kappa_posterior(self):
        """Return the kappa term summed over every sample in ``self.kappa``.

        For each sample, that sample's E[log pi] is multiplied elementwise
        (via ``safe_multiply``) with its kappa data term and summed. The
        per-sample totals are accumulated in iteration order; with no
        samples the result is 0.
        """
        return sum(
            np.sum(safe_multiply(self.get_e_log_pi(sample),
                                 self._get_kappa_data_term(sample)))
            for sample in self.kappa)
예제 #3
0
파일: dirichlet.py 프로젝트: Roth-Lab/scg
 def _compute_e_log_p_gamma_posterior(self, data_type):
     """Return sum(E[log mu] * gamma data term) for ``data_type``.

     The two arrays are multiplied elementwise with ``safe_multiply`` and
     the product is summed over every element.
     """
     e_log_mu = self.get_e_log_mu(data_type)
     gamma_term = self._get_gamma_data_term(data_type)
     return np.sum(safe_multiply(e_log_mu, gamma_term))
예제 #4
0
파일: dirichlet.py 프로젝트: Roth-Lab/scg
 def _compute_e_log_p_kappa_posterior(self):
     """Return sum(E[log pi] * kappa data term).

     ``self.e_log_pi`` and the kappa data term are multiplied elementwise
     with ``safe_multiply`` and the product is summed over every element.
     """
     e_log_pi = self.e_log_pi
     kappa_term = self._get_kappa_data_term()
     return np.sum(safe_multiply(e_log_pi, kappa_term))
예제 #5
0
    def _get_gamma_data_term(self, data_type):
        """Build the data term for the gamma (epsilon) parameters of ``data_type``.

        Three contributions are added together:

        * singlet term: exp(log_Z_0 + log_G) contracted against X;
        * "same" doublet term: for each state s, the sum of G[u] over every
          pair (u, u) listed in ``state_map[s]``, weighted by ``Z_1_k_k`` and
          contracted against X;
        * "diff" doublet term: ``Z_1`` contracted against X and weighted by
          ``_get_G_G_marginalised``, with the k == l diagonal removed.

        Returns:
            Array with axes (S, T), or (S, T, M) when
            ``self.use_position_specific_gamma`` is set.
        """
        log_G = self.log_G[data_type]
        G = self.get_G(data_type)
        M = self.M[data_type]
        S = self.S[data_type]
        X = self.X[data_type]
        state_map = self.state_map[data_type]

        # Every contraction reduces to these output labels; the position
        # axis m survives only for position-specific gamma.
        out_labels = 'stm' if self.use_position_specific_gamma else 'st'
        cell_spec = 'stnkm, stnkm -> ' + out_labels

        # --- singlet term ---------------------------------------------
        # S x N x K x M
        singlet = np.exp(self.log_Z_0[np.newaxis, :, :, np.newaxis] +
                         log_G[:, np.newaxis, :, :])
        singlet = np.einsum(cell_spec,
                            singlet[:, np.newaxis, :, :, :],
                            X[np.newaxis, :, :, np.newaxis, :])

        # --- "diff" doublet term ---------------------------------------
        # T x K x K x M
        pair_counts = np.einsum('tnklm, tnklm -> tklm',
                                self.Z_1[np.newaxis, :, :, :, np.newaxis],
                                X[:, :, np.newaxis, np.newaxis, :])

        # S x K x K x M
        G_G = self._get_G_G_marginalised(data_type)

        diff = np.einsum('stklm, stklm -> ' + out_labels,
                         pair_counts[np.newaxis, :, :, :],
                         G_G[:, np.newaxis, :, :, :])

        # The contraction above also counted k == l; take that back out.
        # np.diagonal appends the k == l axis last, hence the 'mk' order.
        on_diagonal = np.einsum(
            'stmk, stmk -> ' + out_labels,
            np.diagonal(pair_counts, axis1=1, axis2=2)[np.newaxis, :, :, :],
            np.diagonal(G_G, axis1=1, axis2=2)[:, np.newaxis, :, :])
        diff = diff - on_diagonal

        # --- "same" doublet term ---------------------------------------
        # S x K x M
        same = np.zeros((S, self.K, M))
        for s in state_map:
            for u, v in state_map[s]:
                if u == v:
                    same[s, :, :] += G[u, :, :]

        # S x N x K x M
        same = safe_multiply(same[:, np.newaxis, :, :],
                             self.Z_1_k_k[np.newaxis, :, :, np.newaxis])
        same = np.einsum(cell_spec,
                         same[:, np.newaxis, :, :, :],
                         X[np.newaxis, :, :, np.newaxis, :])

        # (S, T) or (S, T, M), matching out_labels.
        return singlet + same + diff
예제 #6
0
    def _update_G_d(self, data_type):
        """Recompute ``self.log_G[data_type]``; returns nothing.

        For every (state s, cluster k, position m), the unnormalised log
        weight is log(G_prior[s]) plus a data term built from E[log epsilon]:

        * a singlet term from ``Z_0`` and X;
        * a "diff" doublet term from ``Z_1`` (excluding k == l), summed over
          doublet states w whose component pair contains s;
        * a "same" doublet term from ``Z_1_k_k``, using the doublet state
          ``inverse_state_map[(s, s)]``.

        The result is normalised over the state axis (axis 0) with
        ``log_sum_exp`` and stored in ``self.log_G[data_type]``.
        """
        X = self.X[data_type]
        G = self.get_G(data_type)
        state_map = self.state_map[data_type]
        inverse_state_map = self._inverse_state_map[data_type]
        G_prior = self.G_prior[data_type]
        M = self.M[data_type]
        S = self.S[data_type]
        T = self.T[data_type]

        log_eps = self.get_e_log_epsilon(data_type)

        # S x T x 1 x 1 x (M or 1): open singleton axes for cells (n) and
        # clusters (k). The last axis is the real position axis only for
        # position-specific gamma; otherwise it is a broadcast singleton.
        last_axis = (slice(None) if self.use_position_specific_gamma
                     else np.newaxis)
        log_eps = log_eps[:, :, np.newaxis, np.newaxis, last_axis]

        # S x K x M
        singlet = np.einsum('stnkm, stnkm, stnkm -> skm',
                            self.Z_0[np.newaxis, np.newaxis, :, :, np.newaxis],
                            X[np.newaxis, :, :, np.newaxis, :],
                            log_eps)

        # S x T x K x K x M: G[s, l, m] of partner cluster l times
        # Z_1[n, k, l], contracted with X over cells n.
        partner = np.einsum(
            'stnklm, stnklm, stnklm -> stklm',
            G[:, np.newaxis, np.newaxis, np.newaxis, :, :],
            self.Z_1[np.newaxis, np.newaxis, :, :, :, np.newaxis],
            X[np.newaxis, :, :, np.newaxis, np.newaxis, :])

        # S x T x K x M: sum over partner clusters l, then drop l == k.
        partner = partner.sum(axis=3) - np.swapaxes(
            np.diagonal(partner, axis1=2, axis2=3), -2, -1)

        doublet_diff = np.zeros(partner.shape)

        # S x T x 1 x (M or 1)
        log_eps = np.squeeze(log_eps, axis=2)

        # For state s, visit every doublet state w whose component pair
        # contains s and weight by the *other* component's partner term.
        # When both components equal s, only the first match (u == s) counts.
        for s in state_map:
            for w in state_map:
                for u, v in state_map[w]:
                    if u == s:
                        other = v
                    elif v == s:
                        other = u
                    else:
                        continue
                    doublet_diff[s, :, :, :] += safe_multiply(
                        log_eps[w], partner[other])

        # S x K x M
        doublet_diff = doublet_diff.sum(axis=1)

        # T x K x M
        same_counts = np.einsum('tnkm, tnkm -> tkm',
                                self.Z_1_k_k[np.newaxis, :, :, np.newaxis],
                                X[:, :, np.newaxis, :])

        # S x T x K x M: the (s, s) pair maps to doublet state
        # inverse_state_map[(s, s)].
        doublet_same = np.zeros((S, T, self.K, M))
        for s in state_map:
            doublet_same[s, :, :, :] = (
                log_eps[inverse_state_map[(s, s)]] *
                np.expand_dims(same_counts, axis=0))

        # S x K x M
        doublet_same = doublet_same.sum(axis=1)

        # S x K x M
        data_term = singlet + doublet_diff + doublet_same

        log_G = np.log(G_prior[:, np.newaxis, np.newaxis]) + data_term
        log_G = log_G - np.expand_dims(log_sum_exp(log_G, axis=0), axis=0)

        self.log_G[data_type] = log_G