Example #1
def packParamBagForPost(nu=None,
                        beta=None,
                        m=None,
                        kappa=None,
                        D=None,
                        Post=None,
                        **kwargs):
    ''' Parse provided array args and pack into parameter bag

    Returns
    -------
    Post : ParamBag, with K clusters
    '''
    m = as2D(m)
    beta = as2D(beta)

    if D is None:
        D = m.shape[1]

    if m.shape[1] != D:
        m = m.T.copy()
    if beta.shape[1] != D:
        beta = beta.T.copy()
    K, _ = m.shape
    if Post is None:
        Post = ParamBag(K=K, D=D)
    else:
        assert isinstance(Post, ParamBag)
        assert Post.K == K
        assert Post.D == D
    Post.setField('nu', as1D(nu), dims=('K'))
    Post.setField('beta', beta, dims=('K', 'D'))
    Post.setField('m', m, dims=('K', 'D'))
    Post.setField('kappa', as1D(kappa), dims=('K'))
    return Post
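
A minimal usage sketch for the function above. This is hypothetical: it assumes bnpy is installed, so that ParamBag (bnpy.suffstats) and the as1D/as2D helpers (bnpy.util) used inside packParamBagForPost resolve; the array values are illustrative only.

import numpy as np

K, D = 3, 2
Post = packParamBagForPost(
    nu=np.full(K, D + 2.0),      # 1D array, size K
    beta=np.ones((K, D)),        # 2D array, size K x D
    m=np.zeros((K, D)),          # 2D array, size K x D
    kappa=np.ones(K))            # 1D array, size K
print(Post.K, Post.D)            # -> 3 2
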
Example #2
def packParamBagForPost(pnu_K=None,
                        ptau_K=None,
                        w_KE=None,
                        P_KEE=None,
                        Post=None,
                        **kwargs):
    ''' Parse provided array args and pack into parameter bag

    Returns
    -------
    Post : ParamBag, with K clusters
    '''
    pnu_K = as1D(pnu_K)
    ptau_K = as1D(ptau_K)
    w_KE = as2D(w_KE)
    P_KEE = as3D(P_KEE)

    K = pnu_K.size
    E = w_KE.shape[1]
    if Post is None:
        Post = ParamBag(K=K, D=E - 1, E=E)
    elif not hasattr(Post, 'E'):
        Post.E = E
    assert Post.K == K
    assert Post.D == E - 1
    assert Post.E == E
    Post.setField('pnu_K', pnu_K, dims=('K'))
    Post.setField('ptau_K', ptau_K, dims=('K'))
    Post.setField('w_KE', w_KE, dims=('K', 'E'))
    Post.setField('P_KEE', P_KEE, dims=('K', 'E', 'E'))
    return Post
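
A shape-only sketch of the regression variant, under the same bnpy assumptions as above. Note E = D + 1 (weights plus bias), so the function infers D as E - 1 from w_KE:

import numpy as np

K, E = 2, 4                      # implies D = E - 1 = 3
Post = packParamBagForPost(
    pnu_K=np.ones(K),
    ptau_K=np.ones(K),
    w_KE=np.zeros((K, E)),
    P_KEE=np.tile(np.eye(E), (K, 1, 1)))
assert Post.D == E - 1
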
Example #3
def createParamBagForPrior(Data,
                           D=0,
                           nu=0,
                           beta=None,
                           m=None,
                           kappa=None,
                           MMat='zero',
                           ECovMat=None,
                           sF=1.0,
                           Prior=None,
                           **kwargs):
    ''' Initialize ParamBag of parameters which specify prior.

    Returns
    -------
    Prior : ParamBag
    '''
    if Data is None:
        D = int(D)
    else:
        D = int(Data.dim)
    nu = np.maximum(nu, D + 2)
    kappa = np.maximum(kappa or 0, 1e-8)  # kappa may be None; treat as 0 before flooring
    if beta is None:
        if ECovMat is None or isinstance(ECovMat, str):
            ECovMat = createECovMatFromUserInput(D, Data, ECovMat, sF)
        beta = np.diag(ECovMat) * (nu - 2)
    else:
        if beta.ndim == 0:
            beta = np.asarray([beta], dtype=np.float)
    if m is None:
        if MMat == 'data':
            m = np.sum(Data.X, axis=0)
        else:
            m = np.zeros(D)
    elif m.ndim < 1:
        m = np.asarray([m], dtype=np.float)
    if Prior is None:
        Prior = ParamBag(K=0, D=D)
    assert Prior.D == D
    Prior.setField('nu', nu, dims=None)
    Prior.setField('kappa', kappa, dims=None)
    Prior.setField('m', m, dims=('D'))
    Prior.setField('beta', beta, dims=('D'))
    return Prior
Example #4
def calcIterationParams(self, std, mask=None):
    D, K, GP = self.D, self.K, self.GP
    IP = ParamBag(K=K, D=D)
    if mask is None:
        mask = np.ones(D, dtype=bool)
    invSigma = 1.0 / std**2 * np.diag(mask) + GP.Lam
    Rc = np.zeros((K, D, D))
    Rlower = np.ones(K, dtype=bool)
    for k in xrange(K):
        Rc[k], Rlower[k] = cho_factor(invSigma[k], lower=True)
    try:
        IP.setField('Rc', np.tril(Rc), dims=('K', 'D', 'D'))
    except ValueError:
        for k in xrange(K):
            Rc[k] = np.tril(Rc[k])
        IP.setField('Rc', Rc, dims=('K', 'D', 'D'))
    IP.setField('Rlower', Rlower, dims='K')
    logdetSigma = -2 * np.sum(np.log(np.diagonal(Rc, axis1=1, axis2=2)),
                              axis=1)
    IP.setField('logdetSigma', logdetSigma, dims='K')
    return IP
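
calcIterationParams reads log|Sigma| off the Cholesky factor of the precision: if R is the lower Cholesky factor of inv(Sigma), then log|Sigma| = -2 * sum(log(diag(R))). A self-contained NumPy/SciPy sketch of that identity, with a made-up positive definite matrix:

import numpy as np
from scipy.linalg import cho_factor

rng = np.random.RandomState(0)
A = rng.randn(4, 4)
invSigma = np.dot(A, A.T) + 4 * np.eye(4)    # symmetric positive definite

R, lower = cho_factor(invSigma, lower=True)
logdetSigma = -2 * np.sum(np.log(np.diag(R)))

# Compare against direct computation on Sigma = inv(invSigma).
sign, logdet = np.linalg.slogdet(np.linalg.inv(invSigma))
assert np.allclose(logdetSigma, logdet)
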
Example #5
def createParamBagForPrior(Data=None,
                           D=0,
                           pnu=0,
                           ptau=None,
                           w_E=0,
                           P_EE=None,
                           P_diag_E=None,
                           P_diag_val=1.0,
                           Prior=None,
                           **kwargs):
    ''' Initialize Prior ParamBag attribute.

    Returns
    -------
    Prior : ParamBag
        with dimension attributes K, D, E
        with parameter attributes pnu, ptau, w_E, P_EE
    '''
    if Data is None:
        D = int(D)
    else:
        D = int(Data.dim)
    E = D + 1

    # Init parameters of 1D Wishart prior on delta
    pnu = np.maximum(pnu, 1e-9)
    ptau = np.maximum(ptau, 1e-9)

    # Initialize precision matrix of the weight vector
    if P_EE is not None:
        P_EE = np.asarray(P_EE)
    elif P_diag_E is not None:
        P_EE = np.diag(np.asarray(P_diag_E))
    else:
        P_EE = np.diag(P_diag_val * np.ones(E))
    assert P_EE.ndim == 2
    assert P_EE.shape == (E, E)

    # Initialize mean of the weight vector
    w_E = as1D(np.asarray(w_E))
    if w_E.size < E:
        w_E = np.tile(w_E, E)[:E]
    assert w_E.ndim == 1
    assert w_E.size == E

    if Prior is None:
        Prior = ParamBag(K=0, D=D, E=E)
    if not hasattr(Prior, 'E'):
        Prior.E = E
    assert Prior.D == D
    assert Prior.E == E
    Prior.setField('pnu', pnu, dims=None)
    Prior.setField('ptau', ptau, dims=None)
    Prior.setField('w_E', w_E, dims=('E'))
    Prior.setField('P_EE', P_EE, dims=('E', 'E'))

    Pw_E = np.dot(P_EE, w_E)
    wPw_1 = np.dot(w_E, Pw_E)
    Prior.setField('Pw_E', Pw_E, dims=('E'))
    Prior.setField('wPw_1', wPw_1, dims=None)
    return Prior
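
The prior precision P_EE can be specified three ways: a full E x E matrix, a length-E diagonal, or a scalar broadcast along the diagonal. Three hypothetical calls, assuming the same bnpy imports as the function above:

import numpy as np

D = 3                            # so E = D + 1 = 4
PriorA = createParamBagForPrior(None, D=D, pnu=1, ptau=1,
                                P_EE=np.eye(D + 1))
PriorB = createParamBagForPrior(None, D=D, pnu=1, ptau=1,
                                P_diag_E=np.arange(1.0, D + 2))
PriorC = createParamBagForPrior(None, D=D, pnu=1, ptau=1,
                                P_diag_val=0.5)
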
Example #6
class GaussObsModel(AbstractObsModel):

    ''' Full-covariance Gaussian data generation model for real vectors.

    Attributes for Prior (Normal-Wishart)
    --------
    nu : float
        degrees of freedom
    B : 2D array, size D x D
        scale matrix that sets mean of parameter Sigma
    m : 1D array, size D
        mean of the parameter mu
    kappa : float
        scalar precision on parameter mu

    Attributes for k-th component of EstParams (EM point estimates)
    ---------
    mu[k] : 1D array, size D
    Sigma[k] : 2D array, size DxD

    Attributes for k-th component of Post (VB parameter)
    ---------
    nu[k] : float
    B[k] : 2D array, size D x D
    m[k] : 1D array, size D
    kappa[k] : float

    '''

    def __init__(self, inferType='EM', D=0, min_covar=None,
                 Data=None,
                 **PriorArgs):
        ''' Initialize bare obsmodel with valid prior hyperparameters.

        Resulting object lacks either EstParams or Post,
        which must be created separately (see init_global_params).
        '''
        if Data is not None:
            self.D = Data.dim
        else:
            self.D = int(D)
        self.K = 0
        self.inferType = inferType
        self.min_covar = min_covar
        self.createPrior(Data, **PriorArgs)
        self.Cache = dict()

    def createPrior(self, Data, nu=0, B=None,
                    m=None, kappa=None,
                    MMat='zero',
                    ECovMat=None, sF=1.0, **kwargs):
        ''' Initialize Prior ParamBag attribute.

        Post Condition
        ------
        Prior expected covariance matrix set to match provided value.
        '''
        D = self.D
        nu = np.maximum(nu, D + 2)
        if B is None:
            if ECovMat is None or isinstance(ECovMat, str):
                ECovMat = createECovMatFromUserInput(D, Data, ECovMat, sF)
            B = ECovMat * (nu - D - 1)
        if B.ndim == 1:
            B = np.asarray([B], dtype=np.float)
        elif B.ndim == 0:
            B = np.asarray([[B]], dtype=np.float)
        if m is None:
            if MMat == 'data':
                m = np.mean(Data.X, axis=0)
            else:
                m = np.zeros(D)
        elif m.ndim < 1:
            m = np.asarray([m], dtype=np.float)
        kappa = np.maximum(kappa or 0, 1e-8)
        self.Prior = ParamBag(K=0, D=D)
        self.Prior.setField('nu', nu, dims=None)
        self.Prior.setField('kappa', kappa, dims=None)
        self.Prior.setField('m', m, dims=('D'))
        self.Prior.setField('B', B, dims=('D', 'D'))

    def get_mean_for_comp(self, k=None):
        if hasattr(self, 'EstParams'):
            return self.EstParams.mu[k]
        elif k is None or k == 'prior':
            return self.Prior.m
        else:
            return self.Post.m[k]

    def get_covar_mat_for_comp(self, k=None):
        if hasattr(self, 'EstParams'):
            return self.EstParams.Sigma[k]
        elif k is None or k == 'prior':
            return self._E_CovMat()
        else:
            return self._E_CovMat(k)

    def get_name(self):
        return 'Gauss'

    def get_info_string(self):
        return 'Gaussian with full covariance.'

    def get_info_string_prior(self):
        msg = 'Gauss-Wishart on mean and covar of each cluster\n'
        if self.D > 2:
            sfx = ' ...'
        else:
            sfx = ''
        S = self._E_CovMat()[:2, :2]
        msg += 'E[  mean[k] ] = \n %s %s\n' % (str(self.Prior.m[:2]), sfx)
        msg += 'E[ covar[k] ] = \n'
        msg += str(S) + sfx
        msg = msg.replace('\n', '\n  ')
        return msg

    def setEstParams(self, obsModel=None, SS=None, LP=None, Data=None,
                     mu=None, Sigma=None,
                     **kwargs):
        ''' Create EstParams ParamBag with fields mu, Sigma
        '''
        self.ClearCache()
        if obsModel is not None:
            self.EstParams = obsModel.EstParams.copy()
            self.K = self.EstParams.K
            return

        if LP is not None and Data is not None:
            SS = self.calcSummaryStats(Data, None, LP)

        if SS is not None:
            self.updateEstParams(SS)
        else:
            Sigma = as3D(Sigma)
            K, D, D2 = Sigma.shape
            mu = as2D(mu)
            if mu.shape[0] != K:
                mu = mu.T
            assert mu.shape[0] == K
            assert mu.shape[1] == D
            self.EstParams = ParamBag(K=K, D=D)
            self.EstParams.setField('mu', mu, dims=('K', 'D'))
            self.EstParams.setField('Sigma', Sigma, dims=('K', 'D', 'D'))
            self.K = self.EstParams.K

    def setEstParamsFromPost(self, Post):
        ''' Convert from Post (nu, B, m, kappa) to EstParams (mu, Sigma),
            setting each parameter to its posterior mean.
        '''
        self.EstParams = ParamBag(K=Post.K, D=Post.D)
        mu = Post.m.copy()
        Sigma = Post.B / (Post.nu[:, np.newaxis, np.newaxis] - Post.D - 1)
        self.EstParams.setField('mu', mu, dims=('K', 'D'))
        self.EstParams.setField('Sigma', Sigma, dims=('K', 'D', 'D'))
        self.K = self.EstParams.K

    def setPostFactors(self, obsModel=None, SS=None, LP=None, Data=None,
                       nu=0, B=0, m=0, kappa=0,
                       **kwargs):
        ''' Set attribute Post to provided values.
        '''
        self.ClearCache()
        if obsModel is not None:
            if hasattr(obsModel, 'Post'):
                self.Post = obsModel.Post.copy()
                self.K = self.Post.K
            else:
                self.setPostFromEstParams(obsModel.EstParams)
            return

        if LP is not None and Data is not None:
            SS = self.calcSummaryStats(Data, None, LP)

        if SS is not None:
            self.updatePost(SS)
        else:
            m = as2D(m)
            if m.shape[1] != self.D:
                m = m.T.copy()
            K, _ = m.shape
            self.Post = ParamBag(K=K, D=self.D)
            self.Post.setField('nu', as1D(nu), dims=('K'))
            self.Post.setField('B', B, dims=('K', 'D', 'D'))
            self.Post.setField('m', m, dims=('K', 'D'))
            self.Post.setField('kappa', as1D(kappa), dims=('K'))
        self.K = self.Post.K

    def setPostFromEstParams(self, EstParams, Data=None, N=None):
        ''' Set attribute Post based on values in EstParams.
        '''
        K = EstParams.K
        D = EstParams.D
        if Data is not None:
            N = Data.nObsTotal
        N = np.asarray(N, dtype=np.float)
        if N.ndim == 0:
            N = N / K * np.ones(K)

        nu = self.Prior.nu + N
        B = np.zeros((K, D, D))
        for k in xrange(K):
            B[k] = (nu[k] - D - 1) * EstParams.Sigma[k]
        m = EstParams.mu.copy()
        kappa = self.Prior.kappa + N

        self.Post = ParamBag(K=K, D=D)
        self.Post.setField('nu', nu, dims=('K'))
        self.Post.setField('B', B, dims=('K', 'D', 'D'))
        self.Post.setField('m', m, dims=('K', 'D'))
        self.Post.setField('kappa', kappa, dims=('K'))
        self.K = self.Post.K

    def calcSummaryStats(self, Data, SS, LP, **kwargs):
        ''' Calculate summary statistics for given dataset and local parameters

        Returns
        --------
        SS : SuffStatBag object, with K components.
        '''
        return calcSummaryStats(Data, SS, LP, **kwargs)

    def forceSSInBounds(self, SS):
        ''' Force count vector N to remain positive

            This avoids numerical problems due to incremental add/subtract ops
            which can cause computations like
                x = 10.
                x += 1e-15
                x -= 10
                x -= 1e-15
            to be slightly different than zero instead of exactly zero.

            Returns
            -------
            None. SS.N updated in-place.
        '''
        np.maximum(SS.N, 0, out=SS.N)

    def incrementSS(self, SS, k, x):
        SS.x[k] += x
        SS.xxT[k] += np.outer(x, x)

    def decrementSS(self, SS, k, x):
        SS.x[k] -= x
        SS.xxT[k] -= np.outer(x, x)

    def calcSummaryStatsForContigBlock(self, Data, SS=None, a=0, b=0):
        ''' Calculate sufficient stats for a single contiguous block of data
        '''
        if SS is None:
            SS = SuffStatBag(K=1, D=Data.dim)

        SS.setField('N', (b - a) * np.ones(1), dims='K')
        SS.setField(
            'x', np.sum(Data.X[a:b], axis=0)[np.newaxis, :], dims=('K', 'D'))
        SS.setField(
            'xxT', dotATA(Data.X[a:b])[np.newaxis, :, :], dims=('K', 'D', 'D'))
        return SS
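
In plain NumPy, the three statistics packed above are a count, a column sum, and a Gram matrix (dotATA(X) is just X.T @ X). A stand-alone sketch:

import numpy as np

X = np.random.randn(100, 2)
a, b = 10, 60                     # contiguous block [a, b)

N = float(b - a)                  # scalar count
x = np.sum(X[a:b], axis=0)        # 1D array, size D
xxT = np.dot(X[a:b].T, X[a:b])    # 2D array, size D x D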

    def calcLogSoftEvMatrix_FromEstParams(self, Data, **kwargs):
        ''' Compute log soft evidence matrix for Dataset under EstParams.

        Returns
        ---------
        L : 2D array, N x K
        '''
        K = self.EstParams.K
        L = np.empty((Data.nObs, K))
        for k in xrange(K):
            L[:, k] = - 0.5 * self.D * LOGTWOPI \
                - 0.5 * self._logdetSigma(k)  \
                - 0.5 * self._mahalDist_EstParam(Data.X, k)
        return L

    def _mahalDist_EstParam(self, X, k):
        ''' Calc Mahalanobis distance from comp k to every row of X

        Args
        ---------
        X : 2D array, size N x D
        k : integer ID of comp

        Returns
        ----------
        dist : 1D array, size N
        '''
        Q = np.linalg.solve(self.GetCached('cholSigma', k),
                            (X - self.EstParams.mu[k]).T)
        Q *= Q
        return np.sum(Q, axis=0)
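
The same Cholesky-and-solve trick in isolation: solving L Q = (X - mu)^T and summing squared columns of Q gives the Mahalanobis distances without ever forming inv(Sigma). A self-contained check:

import numpy as np
import scipy.linalg

Sigma = np.array([[2.0, 0.3], [0.3, 1.0]])
mu = np.zeros(2)
X = np.random.randn(5, 2)

L = scipy.linalg.cholesky(Sigma, lower=True)
Q = np.linalg.solve(L, (X - mu).T)
dist = np.sum(Q * Q, axis=0)      # (x - mu)^T inv(Sigma) (x - mu) per row

direct = np.array([np.dot(x - mu, np.linalg.solve(Sigma, x - mu)) for x in X])
assert np.allclose(dist, direct)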

    def _cholSigma(self, k):
        ''' Calculate lower cholesky decomposition of Sigma[k]

        Returns
        --------
        L : 2D array, size D x D, lower triangular
            Sigma = np.dot(L, L.T)
        '''
        return scipy.linalg.cholesky(self.EstParams.Sigma[k], lower=1)

    def _logdetSigma(self, k):
        ''' Calculate log determinant of EstParam.Sigma for comp k

        Returns
        ---------
        logdet : scalar real
        '''
        return 2 * np.sum(np.log(np.diag(self.GetCached('cholSigma', k))))

    def updateEstParams_MaxLik(self, SS):
        ''' Update attribute EstParams for all comps given suff stats.

        Update uses the maximum likelihood objective for point estimation.

        Post Condition
        ---------
        Attributes K and EstParams updated in-place.
        '''
        self.ClearCache()
        if not hasattr(self, 'EstParams') or self.EstParams.K != SS.K:
            self.EstParams = ParamBag(K=SS.K, D=SS.D)

        mu = SS.x / SS.N[:, np.newaxis]
        minCovMat = self.min_covar * np.eye(SS.D)
        covMat = np.tile(minCovMat, (SS.K, 1, 1))
        for k in xrange(SS.K):
            covMat[k] += SS.xxT[k] / SS.N[k] - np.outer(mu[k], mu[k])
        self.EstParams.setField('mu', mu, dims=('K', 'D'))
        self.EstParams.setField('Sigma', covMat, dims=('K', 'D', 'D'))
        self.K = SS.K

    def updateEstParams_MAP(self, SS):
        ''' Update attribute EstParams for all comps given suff stats.

        Update uses the MAP objective for point estimation.

        Post Condition
        ---------
        Attributes K and EstParams updated in-place.
        '''
        self.ClearCache()
        if not hasattr(self, 'EstParams') or self.EstParams.K != SS.K:
            self.EstParams = ParamBag(K=SS.K, D=SS.D)

        Prior = self.Prior
        nu = Prior.nu + SS.N
        kappa = Prior.kappa + SS.N
        PB = Prior.B + Prior.kappa * np.outer(Prior.m, Prior.m)

        m = np.empty((SS.K, SS.D))
        B = np.empty((SS.K, SS.D, SS.D))
        for k in xrange(SS.K):
            km_x = Prior.kappa * Prior.m + SS.x[k]
            m[k] = 1.0 / kappa[k] * km_x
            B[k] = PB + SS.xxT[k] - 1.0 / kappa[k] * np.outer(km_x, km_x)

        mu, Sigma = MAPEstParams_inplace(nu, B, m, kappa)
        self.EstParams.setField('mu', mu, dims=('K', 'D'))
        self.EstParams.setField('Sigma', Sigma, dims=('K', 'D', 'D'))
        self.K = SS.K

    def updatePost(self, SS):
        ''' Update attribute Post for all comps given suff stats.

        Update uses the variational objective.

        Post Condition
        ---------
        Attributes K and Post updated in-place.
        '''
        self.ClearCache()
        if not hasattr(self, 'Post') or self.Post.K != SS.K:
            self.Post = ParamBag(K=SS.K, D=SS.D)

        nu, B, m, kappa = self.calcPostParams(SS)
        self.Post.setField('nu', nu, dims=('K'))
        self.Post.setField('kappa', kappa, dims=('K'))
        self.Post.setField('m', m, dims=('K', 'D'))
        self.Post.setField('B', B, dims=('K', 'D', 'D'))
        self.K = SS.K

    def calcPostParams(self, SS):
        ''' Calc updated params (nu, B, m, kappa) for all comps given suff stats

            These params define the common-form of the exponential family
            Normal-Wishart posterior distribution over mu, Lambda

            Returns
            --------
            nu : 1D array, size K
            B : 3D array, size K x D x D, each B[k] is symmetric and pos. def.
            m : 2D array, size K x D
            kappa : 1D array, size K
        '''
        Prior = self.Prior
        nu = Prior.nu + SS.N
        kappa = Prior.kappa + SS.N
        m = (Prior.kappa * Prior.m + SS.x) / kappa[:, np.newaxis]
        Bmm = Prior.B + Prior.kappa * np.outer(Prior.m, Prior.m)
        B = SS.xxT + Bmm[np.newaxis, :]
        for k in xrange(B.shape[0]):
            B[k] -= kappa[k] * np.outer(m[k], m[k])
        return nu, B, m, kappa
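
The update above is the standard conjugate Normal-Wishart map from prior (nu0, B0, m0, kappa0) and suff stats (N, x, xxT) to posterior parameters. A single-component NumPy sketch of the same algebra:

import numpy as np

D = 2
nu0, kappa0 = D + 2.0, 1.0
m0, B0 = np.zeros(D), np.eye(D)

X = np.random.randn(50, D)        # data assigned to one component
N, x, xxT = X.shape[0], X.sum(axis=0), np.dot(X.T, X)

nu = nu0 + N
kappa = kappa0 + N
m = (kappa0 * m0 + x) / kappa
B = B0 + xxT + kappa0 * np.outer(m0, m0) - kappa * np.outer(m, m)

assert np.all(np.linalg.eigvalsh(B) > 0)   # B stays positive definite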

    def calcPostParamsForComp(self, SS, kA=None, kB=None):
        ''' Calc params (nu, B, m, kappa) for specific comp, given suff stats

            These params define the common-form of the exponential family
            Normal-Wishart posterior distribution over mu[k], Lambda[k]

            Returns
            --------
            nu : positive scalar
            B : 2D array, size D x D, symmetric and positive definite
            m : 1D array, size D
            kappa : positive scalar
        '''
        if kB is None:
            SN = SS.N[kA]
            Sx = SS.x[kA]
            SxxT = SS.xxT[kA]
        else:
            SN = SS.N[kA] + SS.N[kB]
            Sx = SS.x[kA] + SS.x[kB]
            SxxT = SS.xxT[kA] + SS.xxT[kB]
        Prior = self.Prior
        nu = Prior.nu + SN
        kappa = Prior.kappa + SN
        m = (Prior.kappa * Prior.m + Sx) / kappa
        B = Prior.B + SxxT \
            + Prior.kappa * np.outer(Prior.m, Prior.m) \
            - kappa * np.outer(m, m)
        return nu, B, m, kappa

    def updatePost_stochastic(self, SS, rho):
        ''' Update attribute Post for all comps given suff stats

        Update uses the stochastic variational formula.

        Post Condition
        ---------
        Attributes K and Post updated in-place.
        '''
        assert hasattr(self, 'Post')
        assert self.Post.K == SS.K
        self.ClearCache()

        self.convertPostToNatural()
        nu, Bnat, km, kappa = self.calcNaturalPostParams(SS)
        Post = self.Post
        Post.nu[:] = (1 - rho) * Post.nu + rho * nu
        Post.Bnat[:] = (1 - rho) * Post.Bnat + rho * Bnat
        Post.km[:] = (1 - rho) * Post.km + rho * km
        Post.kappa[:] = (1 - rho) * Post.kappa + rho * kappa
        self.convertPostToCommon()

    def calcNaturalPostParams(self, SS):
        ''' Calc  natural posterior parameters given suff stats SS.

        Returns
        --------
        nu : 1D array, size K
        Bnat : 3D array, size K x D x D
        km : 2D array, size K x D
        kappa : 1D array, size K
        '''
        Prior = self.Prior
        nu = Prior.nu + SS.N
        kappa = Prior.kappa + SS.N
        km = Prior.kappa * Prior.m + SS.x
        Bnat = (Prior.B + Prior.kappa * np.outer(Prior.m, Prior.m)) + SS.xxT
        return nu, Bnat, km, kappa

    def convertPostToNatural(self):
        ''' Convert current posterior params from common to natural form
        '''
        Post = self.Post
        assert hasattr(Post, 'nu')
        assert hasattr(Post, 'kappa')
        km = Post.m * Post.kappa[:, np.newaxis]
        Bnat = np.empty((self.K, self.D, self.D))
        for k in xrange(self.K):
            Bnat[k] = Post.B[k] + np.outer(km[k], km[k]) / Post.kappa[k]
        Post.setField('km', km, dims=('K', 'D'))
        Post.setField('Bnat', Bnat, dims=('K', 'D', 'D'))

    def convertPostToCommon(self):
        ''' Convert current posterior params from natural to common form
        '''
        Post = self.Post
        assert hasattr(Post, 'nu')
        assert hasattr(Post, 'kappa')
        if hasattr(Post, 'm'):
            Post.m[:] = Post.km / Post.kappa[:, np.newaxis]
        else:
            m = Post.km / Post.kappa[:, np.newaxis]
            Post.setField('m', m, dims=('K', 'D'))

        if hasattr(Post, 'B'):
            B = Post.B  # update in place, no reallocation!
        else:
            B = np.empty((self.K, self.D, self.D))
        for k in xrange(self.K):
            B[k] = Post.Bnat[k] - \
                np.outer(Post.km[k], Post.km[k]) / Post.kappa[k]
        Post.setField('B', B, dims=('K', 'D', 'D'))
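
convertPostToNatural and convertPostToCommon are exact inverses: km = kappa * m and Bnat = B + outer(km, km) / kappa, undone by m = km / kappa and B = Bnat - outer(km, km) / kappa. A quick round-trip check for one component:

import numpy as np

D = 3
kappa = 5.0
m = np.arange(1.0, D + 1)
B = np.eye(D) + 0.1               # any symmetric matrix works here

km = kappa * m                    # common -> natural
Bnat = B + np.outer(km, km) / kappa

m2 = km / kappa                   # natural -> common
B2 = Bnat - np.outer(km, km) / kappa
assert np.allclose(m, m2) and np.allclose(B, B2)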

    def calcLogSoftEvMatrix_FromPost(self, Data, **kwargs):
        ''' Calculate expected log soft ev matrix under Post.

        Returns
        ------
        L : 2D array, size N x K
        '''
        K = self.Post.K
        L = np.zeros((Data.nObs, K))
        for k in xrange(K):
            L[:, k] = - 0.5 * self.D * LOGTWOPI \
                + 0.5 * self.GetCached('E_logdetL', k)  \
                - 0.5 * self._mahalDist_Post(Data.X, k)
        return L

    def _mahalDist_Post(self, X, k):
        ''' Calc expected Mahalanobis distance from comp k to each data atom

            Returns
            --------
            distvec : 1D array, size nObs
                   distvec[n] gives E[ (x-\mu)^T \Lambda (x-\mu) ] for comp k
        '''
        Q = np.linalg.solve(self.GetCached('cholB', k),
                            (X - self.Post.m[k]).T)
        Q *= Q
        return self.Post.nu[k] * np.sum(Q, axis=0) \
            + self.D / self.Post.kappa[k]

    def calcELBO_Memoized(self, SS, returnVec=0, afterMStep=False, **kwargs):
        """ Calculate obsModel's objective using suff stats SS and Post.

        Args
        -------
        SS : bnpy SuffStatBag
        afterMStep : boolean flag
            if 1, elbo calculated assuming M-step just completed

        Returns
        -------
        obsELBO : scalar float
            Equal to E[ log p(x) + log p(phi) - log q(phi)]
        """
        elbo = np.zeros(SS.K)
        Post = self.Post
        Prior = self.Prior
        for k in xrange(SS.K):
            elbo[k] = c_Diff(Prior.nu,
                             self.GetCached('logdetB'),
                             Prior.m, Prior.kappa,
                             Post.nu[k],
                             self.GetCached('logdetB', k),
                             Post.m[k], Post.kappa[k],
                             )
            if not afterMStep:
                aDiff = SS.N[k] + Prior.nu - Post.nu[k]
                bDiff = SS.xxT[k] + Prior.B \
                                  + Prior.kappa * np.outer(Prior.m, Prior.m) \
                    - Post.B[k] \
                    - Post.kappa[k] * np.outer(Post.m[k], Post.m[k])
                cDiff = SS.x[k] + Prior.kappa * Prior.m \
                    - Post.kappa[k] * Post.m[k]
                dDiff = SS.N[k] + Prior.kappa - Post.kappa[k]
                elbo[k] += 0.5 * aDiff * self.GetCached('E_logdetL', k) \
                    - 0.5 * self._trace__E_L(bDiff, k) \
                    + np.inner(cDiff, self.GetCached('E_Lmu', k)) \
                    - 0.5 * dDiff * self.GetCached('E_muLmu', k)
        if returnVec:
            return elbo - (0.5 * SS.D * LOGTWOPI) * SS.N
        return elbo.sum() - 0.5 * np.sum(SS.N) * SS.D * LOGTWOPI

    def getDatasetScale(self, SS):
        ''' Get number of observed scalars in dataset from suff stats.

        Used for normalizing the ELBO so it has reasonable range.

        Returns
        ---------
        s : scalar positive integer
        '''
        return SS.N.sum() * SS.D

    def calcHardMergeGap(self, SS, kA, kB):
        ''' Calculate change in ELBO after a hard merge applied to this model

            Returns
            ---------
            gap : scalar real, indicates change in ELBO after merge of kA, kB
        '''
        Post = self.Post
        Prior = self.Prior
        cA = c_Func(Post.nu[kA], Post.B[kA], Post.m[kA], Post.kappa[kA])
        cB = c_Func(Post.nu[kB], Post.B[kB], Post.m[kB], Post.kappa[kB])
        cPrior = c_Func(Prior.nu, Prior.B, Prior.m, Prior.kappa)

        nu, B, m, kappa = self.calcPostParamsForComp(SS, kA, kB)
        cAB = c_Func(nu, B, m, kappa)
        return cA + cB - cPrior - cAB

    def calcHardMergeGap_AllPairs(self, SS):
        ''' Calculate change in ELBO for all candidate hard merge pairs

        Returns
        ---------
        Gap : 2D array, size K x K, upper-triangular entries non-zero
              Gap[j,k] : scalar change in ELBO after merge of k into j
        '''
        Post = self.Post
        Prior = self.Prior
        cPrior = c_Func(Prior.nu, Prior.B, Prior.m, Prior.kappa)
        c = np.zeros(SS.K)
        for k in xrange(SS.K):
            c[k] = c_Func(Post.nu[k], Post.B[k], Post.m[k], Post.kappa[k])

        Gap = np.zeros((SS.K, SS.K))
        for j in xrange(SS.K):
            for k in xrange(j + 1, SS.K):
                nu, B, m, kappa = self.calcPostParamsForComp(SS, j, k)
                cjk = c_Func(nu, B, m, kappa)
                Gap[j, k] = c[j] + c[k] - cPrior - cjk
        return Gap

    def calcHardMergeGap_SpecificPairs(self, SS, PairList):
        ''' Calc change in ELBO for specific list of candidate hard merge pairs

        Returns
        ---------
        Gaps : 1D array, size L
              Gap[j] : scalar change in ELBO after merge of pair in PairList[j]
        '''
        Gaps = np.zeros(len(PairList))
        for ii, (kA, kB) in enumerate(PairList):
            Gaps[ii] = self.calcHardMergeGap(SS, kA, kB)
        return Gaps

    def calcHardMergeGap_SpecificPairSS(self, SS1, SS2):
        ''' Calc change in ELBO for merge of two K=1 suff stat bags.

        Returns
        -------
        gap : scalar float
        '''
        assert SS1.K == 1
        assert SS2.K == 1

        Prior = self.Prior
        cPrior = c_Func(Prior.nu, Prior.B, Prior.m, Prior.kappa)

        # Compute cumulants of individual states 1 and 2
        nu1, B1, m1, kappa1 = self.calcPostParamsForComp(SS1, 0)
        nu2, B2, m2, kappa2 = self.calcPostParamsForComp(SS2, 0)
        c1 = c_Func(nu1, B1, m1, kappa1)
        c2 = c_Func(nu2, B2, m2, kappa2)

        # Compute cumulant of merged state 1&2
        SS12 = SS1 + SS2
        nu12, B12, m12, kappa12 = self.calcPostParamsForComp(SS12, 0)
        c12 = c_Func(nu12, B12, m12, kappa12)

        return c1 + c2 - cPrior - c12

    def calcLogMargLikForComp(self, SS, kA, kB=None, **kwargs):
        ''' Calc log marginal likelihood of data assigned to given component

        Args
        -------
        SS : bnpy suff stats object
        kA : integer ID of target component to compute likelihood for
        kB : (optional) integer ID of second component.
             If provided, we merge kA, kB into one component for calculation.
        Returns
        -------
        logM : scalar real
               logM = log p( data assigned to comp kA )
                      computed up to an additive constant
        '''
        nu, beta, m, kappa = self.calcPostParamsForComp(SS, kA, kB)
        return -1 * c_Func(nu, beta, m, kappa)

    def calcMargLik(self, SS):
        ''' Calc log marginal likelihood across all comps, given suff stats

            Returns
            --------
            logM : scalar real
                   logM = \sum_{k=1}^K log p( data assigned to comp k | Prior)
        '''
        return self.calcMargLik_CFuncForLoop(SS)

    def calcMargLik_CFuncForLoop(self, SS):
        Prior = self.Prior
        logp = np.zeros(SS.K)
        for k in xrange(SS.K):
            nu, B, m, kappa = self.calcPostParamsForComp(SS, k)
            logp[k] = c_Diff(Prior.nu, Prior.B, Prior.m, Prior.kappa,
                             nu, B, m, kappa)
        # constant term: (N * D / 2) * log(2 * pi)
        return np.sum(logp) - 0.5 * np.sum(SS.N) * SS.D * LOGTWOPI

    def calcPredProbVec_Unnorm(self, SS, x):
        ''' Calculate predictive probability that each comp assigns to vector x

            Returns
            --------
            p : 1D array, size K, all entries positive
                p[k] \propto p( x | SS for comp k)
        '''
        return self._calcPredProbVec_Fast(SS, x)

    def _calcPredProbVec_cFunc(self, SS, x):
        nu, B, m, kappa = self.calcPostParams(SS)
        pSS = SS.copy()
        pSS.N += 1
        pSS.x += x[np.newaxis, :]
        pSS.xxT += np.outer(x, x)[np.newaxis, :, :]
        pnu, pB, pm, pkappa = self.calcPostParams(pSS)
        logp = np.zeros(SS.K)
        for k in xrange(SS.K):
            logp[k] = c_Diff(nu[k], B[k], m[k], kappa[k],
                             pnu[k], pB[k], pm[k], pkappa[k])
        return np.exp(logp - np.max(logp))

    def _calcPredProbVec_Fast(self, SS, x):
        nu, B, m, kappa = self.calcPostParams(SS)
        kB = B
        kB *= ((kappa + 1) / kappa)[:, np.newaxis, np.newaxis]
        logp = np.zeros(SS.K)
        p = logp  # Rename so it's clear what we're returning
        for k in xrange(SS.K):
            cholKB = scipy.linalg.cholesky(kB[k], lower=1)
            logdetKB = 2 * np.sum(np.log(np.diag(cholKB)))
            mVec = np.linalg.solve(cholKB, x - m[k])
            mDist_k = np.inner(mVec, mVec)
            logp[k] = -0.5 * logdetKB - 0.5 * \
                (nu[k] + 1) * np.log(1.0 + mDist_k)
        logp += gammaln(0.5 * (nu + 1)) - gammaln(0.5 * (nu + 1 - self.D))
        logp -= np.max(logp)
        np.exp(logp, out=p)
        return p

    def _Verify_calcPredProbVec(self, SS, x):
        ''' Verify that the predictive prob vector is correct,
              by comparing very different implementations
        '''
        pA = self._calcPredProbVec_Fast(SS, x)
        pC = self._calcPredProbVec_cFunc(SS, x)
        pA /= np.sum(pA)
        pC /= np.sum(pC)
        assert np.allclose(pA, pC)

    def _E_CovMat(self, k=None):
        if k is None:
            B = self.Prior.B
            nu = self.Prior.nu
        else:
            B = self.Post.B[k]
            nu = self.Post.nu[k]
        return B / (nu - self.D - 1)

    def _cholB(self, k=None):
        if k == 'all':
            retArr = np.zeros((self.K, self.D, self.D))
            for kk in xrange(self.K):
                retArr[kk] = self.GetCached('cholB', kk)
            return retArr
        elif k is None:
            B = self.Prior.B
        else:
            # k is one of [0, 1, ... K-1]
            B = self.Post.B[k]
        return scipy.linalg.cholesky(B, lower=True)

    def _logdetB(self, k=None):
        cholB = self.GetCached('cholB', k)
        return 2 * np.sum(np.log(np.diag(cholB)))

    def _E_logdetL(self, k=None):
        dvec = np.arange(1, self.D + 1, dtype=np.float)
        if k == 'all':
            dvec = dvec[:, np.newaxis]
            retVec = self.D * LOGTWO * np.ones(self.K)
            for kk in xrange(self.K):
                retVec[kk] -= self.GetCached('logdetB', kk)
            nuT = self.Post.nu[np.newaxis, :]
            retVec += np.sum(digamma(0.5 * (nuT + 1 - dvec)), axis=0)
            return retVec
        elif k is None:
            nu = self.Prior.nu
        else:
            nu = self.Post.nu[k]
        return self.D * LOGTWO \
            - self.GetCached('logdetB', k) \
            + np.sum(digamma(0.5 * (nu + 1 - dvec)))
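
_E_logdetL evaluates the Wishart expectation E[log|Lambda|] = D log 2 - log|B| + sum_{d=1}^{D} psi((nu + 1 - d) / 2). The same quantity in stand-alone NumPy/SciPy, for one made-up component:

import numpy as np
from scipy.special import digamma

D, nu = 3, 7.0
B = np.diag([1.0, 2.0, 3.0])

dvec = np.arange(1, D + 1)
E_logdetL = (D * np.log(2.0)
             - np.linalg.slogdet(B)[1]
             + np.sum(digamma(0.5 * (nu + 1 - dvec))))
print(E_logdetL)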

    def _trace__E_L(self, Smat, k=None):
        if k is None:
            nu = self.Prior.nu
            B = self.Prior.B
        else:
            nu = self.Post.nu[k]
            B = self.Post.B[k]
        return nu * np.trace(np.linalg.solve(B, Smat))

    def _E_Lmu(self, k=None):
        if k is None:
            nu = self.Prior.nu
            B = self.Prior.B
            m = self.Prior.m
        else:
            nu = self.Post.nu[k]
            B = self.Post.B[k]
            m = self.Post.m[k]
        return nu * np.linalg.solve(B, m)

    def _E_muLmu(self, k=None):
        if k is None:
            nu = self.Prior.nu
            kappa = self.Prior.kappa
            m = self.Prior.m
            B = self.Prior.B
        else:
            nu = self.Post.nu[k]
            kappa = self.Post.kappa[k]
            m = self.Post.m[k]
            B = self.Post.B[k]
        Q = np.linalg.solve(self.GetCached('cholB', k), m.T)
        return self.D / kappa + nu * np.inner(Q, Q)

    def getSerializableParamsForLocalStep(self):
        """ Get compact dict of params for local step.

        Returns
        -------
        Info : dict
        """
        if self.inferType == 'EM':
            raise NotImplementedError('TODO')
        return dict(inferType=self.inferType,
                    K=self.K,
                    D=self.D,
                    )

    def fillSharedMemDictForLocalStep(self, ShMem=None):
        """ Get dict of shared mem arrays needed for parallel local step.

        Returns
        -------
        ShMem : dict of RawArray objects
        """
        if ShMem is None:
            ShMem = dict()
        if 'nu' in ShMem:
            fillSharedMemArray(ShMem['nu'], self.Post.nu)
            fillSharedMemArray(ShMem['kappa'], self.Post.kappa)
            fillSharedMemArray(ShMem['m'], self.Post.m)
            fillSharedMemArray(ShMem['cholB'], self._cholB('all'))
            fillSharedMemArray(ShMem['E_logdetL'], self._E_logdetL('all'))

        else:
            ShMem['nu'] = numpyToSharedMemArray(self.Post.nu)
            ShMem['kappa'] = numpyToSharedMemArray(self.Post.kappa)
            ShMem['m'] = numpyToSharedMemArray(self.Post.m.copy())
            ShMem['cholB'] = numpyToSharedMemArray(self._cholB('all'))
            ShMem['E_logdetL'] = numpyToSharedMemArray(self._E_logdetL('all'))

        return ShMem

    def getLocalAndSummaryFunctionHandles(self):
        """ Get function handles for local step and summary step

        Useful for parallelized algorithms.

        Returns
        -------
        calcLocalParams : f handle
        calcSummaryStats : f handle
        """
        return calcLocalParams, calcSummaryStats


    def calcSmoothedMu(self, X, W=None):
        ''' Compute smoothed estimates of the covariance and mean statistics.

        Args
        ----
        X : 2D array, size N x D

        Returns
        -------
        Mu_1 : 2D array, size D x D
            Expected value of Cov[ X[n] ]
        Mu_2 : 1D array, size D
            Expected value of Mean[ X[n] ]
        '''
        if X is None:
            Mu1 = self.Prior.B / self.Prior.nu
            Mu2 = self.Prior.m
            return Mu1, Mu2

        if X.ndim == 1:
            X = X[np.newaxis,:]
        N, D = X.shape
        # Compute suff stats
        if W is None:
            sum_wxxT = np.dot(X.T, X)
            sum_wx = np.sum(X, axis=0)
            sum_w = X.shape[0]
        else:
            W = as1D(W)
            sqrtWX = np.sqrt(W)[:,np.newaxis] * X
            sum_wxxT = np.dot(sqrtWX.T, sqrtWX)
            sum_wx = np.dot(W, X)
            sum_w = np.sum(W)

        kappa = self.Prior.kappa + sum_w
        m = (self.Prior.m * self.Prior.kappa + sum_wx) / kappa
        Mu_2 = m

        prior_kmmT = self.Prior.kappa * np.outer(self.Prior.m, self.Prior.m)
        post_kmmT = kappa * np.outer(m,m)
        B = sum_wxxT + self.Prior.B + prior_kmmT - post_kmmT
        Mu_1 = B / (self.Prior.nu + sum_w)

        assert Mu_1.ndim == 2
        assert Mu_1.shape == (D, D,)
        assert Mu_2.shape == (D,)
        return Mu_1, Mu_2
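
The weighted branch above uses a sqrt(W) trick so that sum_n W[n] * outer(X[n], X[n]) becomes a single matrix product. A NumPy sketch of just that step:

import numpy as np

X = np.random.randn(20, 4)
W = np.random.rand(20)

sqrtWX = np.sqrt(W)[:, np.newaxis] * X
sum_wxxT = np.dot(sqrtWX.T, sqrtWX)

direct = sum(W[n] * np.outer(X[n], X[n]) for n in range(20))
assert np.allclose(sum_wxxT, direct)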

    def calcSmoothedBregDiv(
            self, X, Mu, W=None,
            eps=1e-10,
            smoothFrac=0.0,
            includeOnlyFastTerms=False,
            DivDataVec=None,
            returnDivDataVec=False,
            return1D=False,
            **kwargs):
        ''' Compute Bregman divergence between data X and clusters Mu.

        Smooth the data via update with prior parameters.

        Args
        ----
        X : 2D array, size N x D
        Mu : list of size K, or tuple

        Returns
        -------
        Div : 2D array, N x K
            Div[n,k] = smoothed distance between X[n] and Mu[k]
        '''
        # Parse X
        if X.ndim < 2:
            X = X[np.newaxis,:]
        assert X.ndim == 2
        N = X.shape[0]
        D = X.shape[1]
        # Parse Mu
        if isinstance(Mu, tuple):
            Mu = [Mu]
        assert isinstance(Mu, list)
        K = len(Mu)
        assert Mu[0][0].shape[0] == D
        assert Mu[0][0].shape[1] == D
        assert Mu[0][1].size == D

        prior_x = self.Prior.m
        prior_covx = self.Prior.B / (self.Prior.nu)
        CovX = eps * prior_covx

        Div = np.zeros((N, K))
        for k in xrange(K):
            chol_CovMu_k = np.linalg.cholesky(Mu[k][0])
            logdet_CovMu_k = 2.0 * np.sum(np.log(np.diag(chol_CovMu_k)))
            tr_InvMu_CovX_k = np.trace(np.linalg.solve(
                Mu[k][0], CovX))
            XdiffMu_k = X - Mu[k][1]
            tr_InvMu_XdXdT_k = np.linalg.solve(chol_CovMu_k, XdiffMu_k.T)
            tr_InvMu_XdXdT_k *= tr_InvMu_XdXdT_k
            tr_InvMu_XdXdT_k = tr_InvMu_XdXdT_k.sum(axis=0)
            Div[:,k] = \
                + 0.5 * logdet_CovMu_k \
                + 0.5 * (tr_InvMu_CovX_k + tr_InvMu_XdXdT_k)

        if not includeOnlyFastTerms:
            if DivDataVec is None:
                # Compute DivDataVec : 1D array of size N
                # This is the per-row additive constant indep. of k. 
                DivDataVec = -0.5 * D * np.ones(N)
                s, logdet = np.linalg.slogdet(CovX)
                logdet_CovX = s * logdet
                DivDataVec -= 0.5 * logdet_CovX
        
            Div += DivDataVec[:,np.newaxis]

        # Apply per-atom weights to divergences.
        if W is not None:
            assert W.ndim == 1
            assert W.size == N
            Div *= W[:,np.newaxis]
        # Verify divergences are strictly non-negative 
        if not includeOnlyFastTerms:
            minDiv = Div.min()
            if minDiv < 0:
                if minDiv < -1e-6:
                    raise AssertionError(
                        "Expected Div.min() to be positive or" + \
                        " indistinguishable from zero. Instead " + \
                        " minDiv=% .3e" % (minDiv))
                np.maximum(Div, 0, out=Div)
                minDiv = Div.min()
            assert minDiv >= 0
        if return1D:
            Div = Div[:,0]
        if returnDivDataVec:
            return Div, DivDataVec
        return Div

    def calcBregDivFromPrior(self, Mu, smoothFrac=0.0):
        ''' Compute Bregman divergence between Mu and prior mean.

        Returns
        -------
        Div : 1D array, size K
            Div[k] = distance between Mu[k] and priorMu
        '''
        assert isinstance(Mu, list)
        K = len(Mu)
        assert K >= 1
        assert Mu[0][0].ndim == 2
        assert Mu[0][1].ndim == 1
        D = Mu[0][0].shape[0]
        assert D == Mu[0][0].shape[1]
        assert D == Mu[0][1].size

        priorCov = self.Prior.B / self.Prior.nu
        priorMu_2 = self.Prior.m

        priorN_ZMG = (1-smoothFrac) * self.Prior.nu
        priorN_FVG = (1-smoothFrac) * self.Prior.kappa

        Div_ZMG = np.zeros(K) # zero-mean gaussian
        Div_FVG = np.zeros(K) # fixed variance gaussian
 
        s, logdet = np.linalg.slogdet(priorCov)
        logdet_priorCov = s * logdet
        for k in xrange(K):
            Cov_k = Mu[k][0]
            s, logdet = np.linalg.slogdet(Cov_k)
            logdet_Cov_k = s * logdet
            Div_ZMG[k] = 0.5 * logdet_Cov_k + \
                - 0.5 * logdet_priorCov \
                + 0.5 * np.trace(np.linalg.solve(Cov_k, priorCov)) \
                - 0.5
            pmT = np.outer(priorMu_2 - Mu[k][1], priorMu_2 - Mu[k][1])
            Div_FVG[k] = 0.5 * np.trace(np.linalg.solve(Cov_k, pmT))
            
        return priorN_ZMG * Div_ZMG + priorN_FVG * Div_FVG
Example #7
class AutoRegGaussObsModel(AbstractObsModel):
    ''' First-order auto-regressive data generation model.

    Attributes for Prior (Matrix-Normal-Wishart)
    --------
    nu : float
        degrees of freedom
    B : 2D array, size D x D
        scale matrix that sets mean of parameter Sigma
    M : 2D array, size D x E
        sets mean of parameter A
    V : 2D array, size E x E
        scale matrix that sets covariance of parameter A

    Attributes for k-th component of EstParams (EM point estimates)
    ---------
    A[k] : 2D array, size D x D
        coefficient matrix for auto-regression.
    Sigma[k] : 2D array, size D x D
        covariance matrix.

    Attributes for k-th component of Post (VB parameter)
    ---------
    nu[k] : float
    B[k] : 2D array, size D x D
    M[k] : 2D array, size D x E
    V[k] : 2D array, size E x E
    '''
    def __init__(self,
                 inferType='EM',
                 D=None,
                 E=None,
                 min_covar=None,
                 Data=None,
                 **PriorArgs):
        ''' Initialize bare obsmodel with valid prior hyperparameters.

        Resulting object lacks either EstParams or Post,
        which must be created separately (see init_global_params).
        '''
        # Set dimension D
        if Data is not None:
            D = Data.X.shape[1]
        else:
            assert D is not None
            D = int(D)
        self.D = D

        # Set dimension E
        if Data is not None:
            E = Data.Xprev.shape[1]
        else:
            assert E is not None
            E = int(E)
        self.E = E

        self.K = 0
        self.inferType = inferType
        self.min_covar = min_covar
        self.createPrior(Data, D=D, E=E, **PriorArgs)
        self.Cache = dict()

    def createPrior(self,
                    Data,
                    D=None,
                    E=None,
                    nu=0,
                    B=None,
                    M=None,
                    V=None,
                    ECovMat=None,
                    sF=1.0,
                    VMat='eye',
                    sV=1.0,
                    MMat='zero',
                    sM=1.0,
                    **kwargs):
        ''' Initialize Prior ParamBag attribute.

        Post Condition
        ------
        Prior expected covariance matrix set to match provided value.
        '''
        if Data is None:
            if D is None:
                raise ValueError("Need to specify dimension D")
            if E is None:
                raise ValueError("Need to specify dimension E")
        if Data is not None:
            if D is None:
                D = Data.X.shape[1]
            else:
                assert D == Data.X.shape[1]
            if E is None:
                E = Data.Xprev.shape[1]
            else:
                assert E == Data.Xprev.shape[1]

        nu = np.maximum(nu, D + 2)
        if B is None:
            if ECovMat is None or isinstance(ECovMat, str):
                ECovMat = createECovMatFromUserInput(D, Data, ECovMat, sF)
            B = ECovMat * (nu - D - 1)
        B = as2D(B)

        if M is None:
            if MMat == 'zero':
                M = np.zeros((D, E))
            elif MMat == 'eye':
                assert D <= E
                M = sM * np.eye(D)
                M = np.hstack([M, np.zeros((D, E - D))])
                assert M.shape == (D, E)
            else:
                raise ValueError('Unrecognized MMat: %s' % (MMat))
        else:
            M = as2D(M)

        if V is None:
            if VMat == 'eye':
                V = sV * np.eye(E)
            elif VMat == 'same':
                assert D == E
                V = sV * ECovMat
            else:
                raise ValueError('Unrecognized VMat: %s' % (VMat))
        else:
            V = as2D(V)

        self.Prior = ParamBag(K=0, D=D, E=E)
        self.Prior.setField('nu', nu, dims=None)
        self.Prior.setField('B', B, dims=('D', 'D'))
        self.Prior.setField('V', V, dims=('E', 'E'))
        self.Prior.setField('M', M, dims=('D', 'E'))

    def get_mean_for_comp(self, k=None):
        if hasattr(self, 'EstParams'):
            return np.diag(self.EstParams.A[k])
        elif k is None or k == 'prior':
            return np.diag(self.Prior.M)
        else:
            return np.diag(self.Post.M[k])

    def get_covar_mat_for_comp(self, k=None):
        if hasattr(self, 'EstParams'):
            return self.EstParams.Sigma[k]
        elif k is None or k == 'prior':
            return self._E_CovMat()
        else:
            return self._E_CovMat(k)

    def get_name(self):
        return 'AutoRegGauss'

    def get_info_string(self):
        return 'Auto-Regressive Gaussian with full covariance.'

    def get_info_string_prior(self):
        msg = 'MatrixNormal-Wishart on each mean/prec matrix pair: A, Lam\n'
        if self.D > 2:
            sfx = ' ...'
        else:
            sfx = ''
        M = self.Prior.M[:2, :2]
        S = self._E_CovMat()[:2, :2]
        msg += 'E[ A ] = \n'
        msg += str(M) + sfx + '\n'
        msg += 'E[ Sigma ] = \n'
        msg += str(S) + sfx
        msg = msg.replace('\n', '\n  ')
        return msg

    def setEstParams(self,
                     obsModel=None,
                     SS=None,
                     LP=None,
                     Data=None,
                     A=None,
                     Sigma=None,
                     **kwargs):
        ''' Initialize EstParams attribute with fields A, Sigma.
        '''
        self.ClearCache()
        if obsModel is not None:
            self.EstParams = obsModel.EstParams.copy()
            self.K = self.EstParams.K
            return

        if LP is not None and Data is not None:
            SS = self.calcSummaryStats(Data, None, LP)

        if SS is not None:
            self.updateEstParams(SS)
        else:
            A = as3D(A)
            Sigma = as3D(Sigma)
            self.EstParams = ParamBag(K=A.shape[0], D=A.shape[1])
            self.EstParams.setField('A', A, dims=('K', 'D', 'D'))
            self.EstParams.setField('Sigma', Sigma, dims=('K', 'D', 'D'))

    def setEstParamsFromPost(self, Post):
        ''' Convert from Post to EstParams.
        '''
        D = Post.D
        self.EstParams = ParamBag(K=Post.K, D=D)
        A = Post.M.copy()
        Sigma = Post.B / (Post.nu - D - 1)[:, np.newaxis, np.newaxis]
        self.EstParams.setField('A', A, dims=('K', 'D', 'D'))
        self.EstParams.setField('Sigma', Sigma, dims=('K', 'D', 'D'))
        self.K = self.EstParams.K

    def setPostFactors(self,
                       obsModel=None,
                       SS=None,
                       LP=None,
                       Data=None,
                       nu=0,
                       B=0,
                       M=0,
                       V=0,
                       **kwargs):
        ''' Set Post attribute to provided values.
        '''
        self.ClearCache()
        if obsModel is not None:
            if hasattr(obsModel, 'Post'):
                self.Post = obsModel.Post.copy()
            else:
                self.setPostFromEstParams(obsModel.EstParams)
            self.K = self.Post.K
            return

        if LP is not None and Data is not None:
            SS = self.calcSummaryStats(Data, None, LP)

        if SS is not None:
            self.updatePost(SS)
        else:
            M = as3D(M)
            B = as3D(B)
            V = as3D(V)

            K, D, E = M.shape
            assert D == self.D
            assert E == self.E
            self.Post = ParamBag(K=K, D=self.D, E=self.E)
            self.Post.setField('nu', as1D(nu), dims=('K'))
            self.Post.setField('B', B, dims=('K', 'D', 'D'))
            self.Post.setField('M', M, dims=('K', 'D', 'E'))
            self.Post.setField('V', V, dims=('K', 'E', 'E'))
        self.K = self.Post.K

    def setPostFromEstParams(self, EstParams, Data=None, N=None):
        ''' Set Post attribute values based on provided EstParams.
        '''
        K = EstParams.K
        D = EstParams.D
        if Data is not None:
            N = Data.nObsTotal
        N = np.asarray(N, dtype=np.float)
        if N.ndim == 0:
            N = N / K * np.ones(K)

        nu = self.Prior.nu + N
        B = EstParams.Sigma * (nu - D - 1)[:, np.newaxis, np.newaxis]
        M = EstParams.A.copy()
        V = np.tile(self.Prior.V, (K, 1, 1))  # one copy of the prior V per comp

        self.Post = ParamBag(K=K, D=D)
        self.Post.setField('nu', nu, dims=('K'))
        self.Post.setField('B', B, dims=('K', 'D', 'D'))
        self.Post.setField('M', M, dims=('K', 'D', 'D'))
        self.Post.setField('V', V, dims=('K', 'D', 'D'))
        self.K = self.Post.K

    def calcSummaryStats(self, Data, SS, LP, **kwargs):
        """ Fill in relevant sufficient stats fields into provided SS.

        Returns
        -------
        SS : bnpy.suffstats.SuffStatBag
        """
        return calcSummaryStats(Data, SS, LP, **kwargs)

    def forceSSInBounds(self, SS):
        ''' Force count vector N to remain positive

        This avoids numerical problems due to incremental add/subtract ops
        which can cause computations like

            x = 10.
            x += 1e-15
            x -= 10
            x -= 1e-15

        to be slightly different than zero instead of exactly zero.

        Post Condition
        --------------
        Field N is guaranteed to be positive.
        '''
        np.maximum(SS.N, 0, out=SS.N)

    def incrementSS(self, SS, k, x):
        pass

    def decrementSS(self, SS, k, x):
        pass

    def calcSummaryStatsForContigBlock(self, Data, SS=None, a=0, b=0):
        ''' Calculate sufficient stats for a single contiguous block of data
        '''
        D = Data.X.shape[1]
        E = Data.Xprev.shape[1]

        if SS is None:
            SS = SuffStatBag(K=1, D=D, E=E)
        elif not hasattr(SS, 'E'):
            SS._Fields.E = E

        ppT = dotATA(Data.Xprev[a:b])[np.newaxis, :, :]
        xxT = dotATA(Data.X[a:b])[np.newaxis, :, :]
        pxT = dotATB(Data.Xprev[a:b], Data.X[a:b])[np.newaxis, :, :]

        SS.setField('N', (b - a) * np.ones(1), dims='K')
        SS.setField('xxT', xxT, dims=('K', 'D', 'D'))
        SS.setField('ppT', ppT, dims=('K', 'E', 'E'))
        SS.setField('pxT', pxT, dims=('K', 'E', 'D'))
        return SS

    def calcLogSoftEvMatrix_FromEstParams(self, Data, **kwargs):
        ''' Compute log soft evidence matrix for Dataset under EstParams.

        Returns
        -------
        L : 2D array, size N x K
            L[n,k] = log p( data n | EstParams for comp k )
        '''
        K = self.EstParams.K
        L = np.empty((Data.nObs, K))
        for k in xrange(K):
            L[:, k] = - 0.5 * self.D * LOGTWOPI \
                - 0.5 * self._logdetSigma(k)  \
                - 0.5 * self._mahalDist_EstParam(Data.X, Data.Xprev, k)
        return L

    def _mahalDist_EstParam(self, X, Xprev, k):
        ''' Calc Mahalanobis distance from comp k to every row of X.

        Args
        ----
        X : 2D array, size N x D
        k : integer ID of comp

        Returns
        -------
        dist : 1D array, size N
        '''
        deltaX = X - np.dot(Xprev, self.EstParams.A[k].T)
        Q = np.linalg.solve(self.GetCached('cholSigma', k), deltaX.T)
        Q *= Q
        return np.sum(Q, axis=0)

    def _cholSigma(self, k):
        ''' Calculate lower cholesky decomposition of Sigma[k]

        Returns
        -------
        L : 2D array, size D x D, lower triangular
            Sigma = np.dot(L, L.T)
        '''
        return scipy.linalg.cholesky(self.EstParams.Sigma[k], lower=1)

    def _logdetSigma(self, k):
        ''' Calculate log determinant of EstParam.Sigma for comp k

        Returns
        -------
        logdet : scalar real
        '''
        return 2 * np.sum(np.log(np.diag(self.GetCached('cholSigma', k))))

    def updateEstParams_MaxLik(self, SS):
        ''' Update attribute EstParams for all comps given suff stats.

        Update uses the maximum likelihood objective for point estimation.

        Post Condition
        ---------
        Attributes K and EstParams updated in-place.
        '''
        self.ClearCache()
        if not hasattr(self, 'EstParams') or self.EstParams.K != SS.K:
            self.EstParams = ParamBag(K=SS.K, D=SS.D)
        minCovMat = self.min_covar * np.eye(SS.D)
        A = np.zeros((SS.K, self.D, self.D))
        Sigma = np.zeros((SS.K, self.D, self.D))
        for k in xrange(SS.K):
            # Add small pos multiple of identity to make invertible
            # TODO: This is source of potential stability issues.
            A[k] = np.linalg.solve(SS.ppT[k] + minCovMat, SS.pxT[k]).T
            Sigma[k] = SS.xxT[k] \
                - 2 * np.dot(SS.pxT[k].T, A[k].T) \
                + np.dot(A[k], np.dot(SS.ppT[k], A[k].T))
            Sigma[k] /= SS.N[k]
            # Sigma[k] = 0.5 * (Sigma[k] + Sigma[k].T) # symmetry!
            Sigma[k] += minCovMat
        self.EstParams.setField('A', A, dims=('K', 'D', 'D'))
        self.EstParams.setField('Sigma', Sigma, dims=('K', 'D', 'D'))
        self.K = SS.K

    def updateEstParams_MAP(self, SS):
        ''' Update attribute EstParams for all comps given suff stats.

        Update uses the MAP objective for point estimation.

        Post Condition
        ---------
        Attributes K and EstParams updated in-place.
        '''
        self.ClearCache()
        if not hasattr(self, 'EstParams') or self.EstParams.K != SS.K:
            self.EstParams = ParamBag(K=SS.K, D=SS.D)
        raise NotImplementedError('TODO')

    def updatePost(self, SS):
        ''' Update attribute Post for all comps given suff stats.

        Update uses the variational objective.

        Post Condition
        ---------
        Attributes K and Post updated in-place.
        '''
        self.ClearCache()
        if not hasattr(self, 'Post') or self.Post.K != SS.K:
            self.Post = ParamBag(K=SS.K, D=SS.D, E=SS.E)
        elif not hasattr(self.Post, 'E'):
            self.Post.E = SS.E

        nu, B, M, V = self.calcPostParams(SS)
        self.Post.setField('nu', nu, dims=('K'))
        self.Post.setField('B', B, dims=('K', 'D', 'D'))
        self.Post.setField('M', M, dims=('K', 'D', 'E'))
        self.Post.setField('V', V, dims=('K', 'E', 'E'))
        self.K = SS.K

    def calcPostParams(self, SS):
        ''' Calc updated posterior params for all comps given suff stats

        These params define the common-form of the exponential family
        Matrix-Normal-Wishart posterior distribution over A, Lambda

        Returns
        --------
        nu : 1D array, size K
        B : 3D array, size K x D x D
            each B[k] symmetric and positive definite
        M : 3D array, size K x D x E
        V : 3D array, size K x E x E
        '''
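        # Conjugate update equations implemented below, for each comp k
        # (subscript 0 denotes the prior):
        #   nu_k = nu_0 + N_k
        #   V_k  = V_0 + ppT_k
        #   M_k  = (pxT_k + V_0 M_0^T)^T V_k^{-1}
        #   B_k  = B_0 + M_0 V_0 M_0^T + xxT_k - M_k V_k M_k^T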
        Prior = self.Prior
        nu = Prior.nu + SS.N

        B_MVM = Prior.B + np.dot(Prior.M, np.dot(Prior.V, Prior.M.T))
        B = SS.xxT + B_MVM[np.newaxis, :]
        V = SS.ppT + Prior.V[np.newaxis, :]
        M = np.zeros((SS.K, SS.D, SS.E))
        for k in xrange(B.shape[0]):
            M[k] = np.linalg.solve(V[k],
                                   SS.pxT[k] + np.dot(Prior.V, Prior.M.T)).T
            B[k] -= np.dot(M[k], np.dot(V[k], M[k].T))
        return nu, B, M, V

    def calcPostParamsForComp(self, SS, kA=None, kB=None):
        ''' Calc params (nu, B, M, V) for specific comp, given suff stats

        These params define the common-form of the exponential family
        matrix-normal-Wishart posterior distribution over A[k], Lambda[k]

        Returns
        --------
        nu : positive scalar
        B : 2D array, size D x D, symmetric and positive definite
        M : 2D array, size D x E
        V : 2D array, size E x E, symmetric and positive definite
        '''
        if kA is not None and kB is not None:
            N = SS.N[kA] + SS.N[kB]
            xxT = SS.xxT[kA] + SS.xxT[kB]
            ppT = SS.ppT[kA] + SS.ppT[kB]
            pxT = SS.pxT[kA] + SS.pxT[kB]
        elif kA is not None:
            N = SS.N[kA]
            xxT = SS.xxT[kA]
            ppT = SS.ppT[kA]
            pxT = SS.pxT[kA]
        else:
            raise ValueError('Need to specify specific component.')
        Prior = self.Prior
        nu = Prior.nu + N
        B_MVM = Prior.B + np.dot(Prior.M, np.dot(Prior.V, Prior.M.T))
        B = xxT + B_MVM
        V = ppT + Prior.V
        M = np.linalg.solve(V, pxT + np.dot(Prior.V, Prior.M.T)).T
        B -= np.dot(M, np.dot(V, M.T))
        return nu, B, M, V

    def updatePost_stochastic(self, SS, rho):
        ''' Update attribute Post for all comps given suff stats

        Update uses the stochastic variational formula.

        Post Condition
        ---------
        Attributes K and Post updated in-place.
        '''
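        # Stochastic variational update: convert to natural parameters, take
        # a convex combination with step size rho,
        #   eta_new = (1 - rho) * eta_old + rho * eta_batch,
        # then convert back to common form. The blend must happen in natural
        # form for the update to match the standard natural-gradient SVI step.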
        assert hasattr(self, 'Post')
        assert self.Post.K == SS.K
        self.ClearCache()

        self.convertPostToNatural()
        n_nu, n_V, n_VMT, n_B = self.calcNaturalPostParams(SS)
        Post = self.Post
        Post.nu[:] = (1 - rho) * Post.nu + rho * n_nu
        Post.V[:] = (1 - rho) * Post.V + rho * n_V

        Post.n_VMT[:] = (1 - rho) * Post.n_VMT + rho * n_VMT
        Post.n_B[:] = (1 - rho) * Post.n_B + rho * n_B
        self.convertPostToCommon()

    def calcNaturalPostParams(self, SS):
        ''' Calc updated natural params for all comps given suff stats

        These params define the natural-form of the exponential family
        matrix-normal-Wishart posterior distribution over A, Lambda

        Returns
        --------
        n_nu : 1D array, size K
        n_V : 3D array, size K x E x E
        n_VMT : 3D array, size K x E x D
        n_B : 3D array, size K x D x D
        '''
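        # Mapping from common form (nu, B, M, V) to natural form:
        #   n_nu = nu,   n_V = V,   n_VMT = V M^T,   n_B = B + M V M^T
        # In natural form each batch's suff stats add directly to each field.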
        Prior = self.Prior
        VMT = np.dot(Prior.V, Prior.M.T)
        MVMT = np.dot(Prior.M, VMT)

        n_nu = Prior.nu + SS.N
        n_V = Prior.V + SS.ppT
        n_VMT = VMT + SS.pxT
        n_B = Prior.B + MVMT + SS.xxT
        return n_nu, n_V, n_VMT, n_B

    def convertPostToNatural(self):
        ''' Convert current posterior params from common to natural form

        Post Condition
        --------
        Attribute Post has new fields n_VMT, n_B.
        '''
        Post = self.Post
        # These two are done implicitly
        # Post.setField('nu', Post.nu, dims=None)
        # Post.setField('V', Post.V, dims=None)
        VMT = np.zeros((self.K, self.D, self.D))
        for k in xrange(self.K):
            VMT[k] = np.dot(Post.V[k], Post.M[k].T)
        Post.setField('n_VMT', VMT, dims=('K', 'D', 'D'))

        Bnat = np.empty((self.K, self.D, self.D))
        for k in xrange(self.K):
            Bnat[k] = Post.B[k] + np.dot(Post.M[k], VMT[k])
        Post.setField('n_B', Bnat, dims=('K', 'D', 'D'))

    def convertPostToCommon(self):
        ''' Convert current posterior params from natural to common form

        Post Condition
        --------
        Attribute Post has fields M, B recomputed from natural fields n_VMT, n_B.
        '''
        Post = self.Post
        # These two are done implicitly
        # Post.setField('nu', Post.nu, dims=None)
        # Post.setField('V', Post.V, dims=None)

        M = np.zeros((self.K, self.D, self.D))
        for k in xrange(self.K):
            M[k] = np.linalg.solve(Post.V[k], Post.n_VMT[k]).T
        Post.setField('M', M, dims=('K', 'D', 'D'))

        B = np.empty((self.K, self.D, self.D))
        for k in xrange(self.K):
            B[k] = Post.n_B[k] - np.dot(Post.M[k], Post.n_VMT[k])
        Post.setField('B', B, dims=('K', 'D', 'D'))

    def calcLogSoftEvMatrix_FromPost(self, Data, **kwargs):
        ''' Calculate expected log soft ev matrix under Post.

        Returns
        ------
        L : 2D array, size N x K
        '''
        K = self.Post.K
        L = np.zeros((Data.nObs, K))
        for k in xrange(K):
            L[:, k] = - 0.5 * self.D * LOGTWOPI \
                + 0.5 * self.GetCached('E_logdetL', k)  \
                - 0.5 * self._mahalDist_Post(Data.X, Data.Xprev, k)
        return L

    def _mahalDist_Post(self, X, Xprev, k):
        ''' Calc expected Mahalanobis distance from comp k to each data atom

        Returns
        --------
        distvec : 1D array, size nObs
               distvec[n] gives E[ (x - A xprev)' Lambda (x - A xprev) ] for comp k
        '''
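        # A sketch of the expectation computed below: under the posterior,
        #   E[dist_n] = nu_k * dX_n' B_k^{-1} dX_n + D * xprev_n' V_k^{-1} xprev_n
        # with dX_n = x_n - M_k xprev_n; each quadratic form is evaluated via
        # a triangular solve against the cached Cholesky factor.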
        # Calc: (x-M*xprev)' * B * (x-M*xprev)
        deltaX = X - np.dot(Xprev, self.Post.M[k].T)
        Q = np.linalg.solve(self.GetCached('cholB', k), deltaX.T)
        Q *= Q

        # Calc: xprev' * V * xprev
        Qprev = np.linalg.solve(self.GetCached('cholV', k), Xprev.T)
        Qprev *= Qprev

        return self.Post.nu[k] * np.sum(Q, axis=0) \
            + self.D * np.sum(Qprev, axis=0)

    def calcELBO_Memoized(self, SS, afterMStep=False, **kwargs):
        ''' Calculate obsModel's objective using suff stats SS and Post.

        Args
        -------
        SS : bnpy SuffStatBag
        afterMStep : boolean flag
            if 1, elbo calculated assuming M-step just completed

        Returns
        -------
        obsELBO : scalar float
            Equal to E[ log p(x) + log p(phi) - log q(phi)]
        '''
        elbo = np.zeros(SS.K)
        Post = self.Post
        Prior = self.Prior
        for k in xrange(SS.K):
            elbo[k] = c_Diff(
                Prior.nu,
                self.GetCached('logdetB'),
                Prior.M,
                self.GetCached('logdetV'),
                Post.nu[k],
                self.GetCached('logdetB', k),
                Post.M[k],
                self.GetCached('logdetV', k),
            )
            if not afterMStep:
                aDiff = SS.N[k] + Prior.nu - Post.nu[k]
                bDiff = SS.xxT[k] + Prior.B + \
                    np.dot(Prior.M, np.dot(Prior.V, Prior.M.T)) - \
                    Post.B[k] - \
                    np.dot(Post.M[k], np.dot(Post.V[k], Post.M[k].T))
                cDiff = SS.pxT[k] + np.dot(Prior.V, Prior.M.T) - \
                    np.dot(Post.V[k], Post.M[k].T)
                dDiff = SS.ppT[k] + Prior.V - Post.V[k]
                elbo[k] += 0.5 * aDiff * self.GetCached('E_logdetL', k) \
                    - 0.5 * self._trace__E_L(bDiff, k) \
                    + self._trace__E_LA(cDiff, k) \
                    - 0.5 * self._trace__E_ALA(dDiff, k)
        return elbo.sum() - 0.5 * np.sum(SS.N) * SS.D * LOGTWOPI

    def getDatasetScale(self, SS):
        ''' Get number of observed scalars in dataset from suff stats.

        Used for normalizing the ELBO so it has reasonable range.

        Returns
        ---------
        s : scalar positive integer
        '''
        return SS.N.sum() * SS.D

    def calcHardMergeGap(self, SS, kA, kB):
        ''' Calculate change in ELBO after a hard merge applied to this model

        Returns
        ---------
        gap : scalar real, indicates change in ELBO after merge of kA, kB
        '''
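        # The gap is a difference of cumulant-function (log-normalizer) values:
        #   gap = c(kA) + c(kB) - c(prior) - c(merged kA+kB)
        # so a positive gap means the merge increases the ELBO.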
        Post = self.Post
        Prior = self.Prior
        cA = c_Func(Post.nu[kA], Post.B[kA], Post.M[kA], Post.V[kA])
        cB = c_Func(Post.nu[kB], Post.B[kB], Post.M[kB], Post.V[kB])

        cPrior = c_Func(Prior.nu, Prior.B, Prior.M, Prior.V)
        nu, B, M, V = self.calcPostParamsForComp(SS, kA, kB)
        cAB = c_Func(nu, B, M, V)
        return cA + cB - cPrior - cAB

    def calcHardMergeGap_AllPairs(self, SS):
        ''' Calculate change in ELBO for all candidate hard merge pairs

        Returns
        ---------
        Gap : 2D array, size K x K, upper-triangular entries non-zero
              Gap[j,k] : scalar change in ELBO after merge of k into j
        '''
        Post = self.Post
        Prior = self.Prior
        cPrior = c_Func(Prior.nu, Prior.B, Prior.M, Prior.V)
        c = np.zeros(SS.K)
        for k in xrange(SS.K):
            c[k] = c_Func(Post.nu[k], Post.B[k], Post.M[k], Post.V[k])

        Gap = np.zeros((SS.K, SS.K))
        for j in xrange(SS.K):
            for k in xrange(j + 1, SS.K):
                nu, B, M, V = self.calcPostParamsForComp(SS, j, k)
                cjk = c_Func(nu, B, M, V)
                Gap[j, k] = c[j] + c[k] - cPrior - cjk
        return Gap

    def calcHardMergeGap_SpecificPairs(self, SS, PairList):
        ''' Calc change in ELBO for specific list of candidate hard merge pairs

        Returns
        ---------
        Gaps : 1D array, size L
               Gaps[j] = scalar change in ELBO after merge of PairList[j]
        '''
        Gaps = np.zeros(len(PairList))
        for ii, (kA, kB) in enumerate(PairList):
            Gaps[ii] = self.calcHardMergeGap(SS, kA, kB)
        return Gaps

    def calcHardMergeGap_SpecificPairSS(self, SS1, SS2):
        ''' Calc change in ELBO for merger of two components.
        '''
        assert SS1.K == 1
        assert SS2.K == 1

        Prior = self.Prior
        cPrior = c_Func(Prior.nu, Prior.B, Prior.M, Prior.V)

        # Compute cumulants of individual states 1 and 2
        c1 = c_Func(*self.calcPostParamsForComp(SS1, 0))
        c2 = c_Func(*self.calcPostParamsForComp(SS2, 0))

        # Compute cumulant of merged state 1&2
        SS12 = SS1 + SS2
        c12 = c_Func(*self.calcPostParamsForComp(SS12, 0))
        return c1 + c2 - cPrior - c12

    def calcLogMargLikForComp(self, SS, kA, kB=None, **kwargs):
        ''' Calc log marginal likelihood of data assigned to component

        Up to an additive constant that depends on the prior.

        Requires Data pre-summarized into sufficient stats for each comp.
        If multiple comp IDs are provided,
        we combine into a "merged" component.

        Args
        -------
        SS : bnpy suff stats object
        kA : integer ID of target component to compute likelihood for
        kB : (optional) integer ID of second component.
             If provided, we merge kA, kB into one component for calculation.
        Returns
        -------
        logM : scalar real
               logM = log p( data assigned to comp kA )
                      computed up to an additive constant
        '''
        nu, B, M, V = self.calcPostParamsForComp(SS, kA, kB)
        return -1 * c_Func(nu, B, M, V)

    def calcMargLik(self, SS):
        ''' Calc log marginal likelihood across all comps, given suff stats

        Returns
        --------
        logM : scalar real
               logM = \sum_{k=1}^K log p( data assigned to comp k | Prior)
        '''
        return self.calcMargLik_CFuncForLoop(SS)

    def calcMargLik_CFuncForLoop(self, SS):
        Prior = self.Prior
        cPrior = c_Func(Prior.nu, Prior.B, Prior.M, Prior.V)
        logp = np.zeros(SS.K)
        for k in xrange(SS.K):
            nu, B, M, V = self.calcPostParamsForComp(SS, k)
            # log p(data in comp k) = c(prior) - c(posterior) + constant
            logp[k] = cPrior - c_Func(nu, B, M, V)
        return np.sum(logp) - 0.5 * np.sum(SS.N) * SS.D * LOGTWOPI

    def calcPredProbVec_Unnorm(self, SS, x):
        ''' Calculate predictive probability that each comp assigns to vector x

        Returns
        --------
        p : 1D array, size K, all entries positive
            p[k] \propto p( x | SS for comp k)
        '''
        return self._calcPredProbVec_Fast(SS, x)

    def _calcPredProbVec_cFunc(self, SS, x):
        raise NotImplementedError('TODO')

    def _calcPredProbVec_Fast(self, SS, x):
        raise NotImplementedError('TODO')

    def _Verify_calcPredProbVec(self, SS, x):
        raise NotImplementedError('TODO')

    def _E_CovMat(self, k=None):
        if k is None:
            B = self.Prior.B
            nu = self.Prior.nu
        else:
            B = self.Post.B[k]
            nu = self.Post.nu[k]
        return B / (nu - self.D - 1)

    def _cholB(self, k=None):
        if k == 'all':
            retMat = np.zeros((self.K, self.D, self.D))
            for k in xrange(self.K):
                retMat[k] = scipy.linalg.cholesky(self.Post.B[k], lower=True)
            return retMat
        elif k is None:
            B = self.Prior.B
        else:
            B = self.Post.B[k]
        return scipy.linalg.cholesky(B, lower=True)

    def _logdetB(self, k=None):
        cholB = self.GetCached('cholB', k)
        return 2 * np.sum(np.log(np.diag(cholB)))

    def _cholV(self, k=None):
        if k == 'all':
            retMat = np.zeros((self.K, self.D, self.D))
            for k in xrange(self.K):
                retMat[k] = scipy.linalg.cholesky(self.Post.V[k], lower=True)
            return retMat
        elif k is None:
            V = self.Prior.V
        else:
            V = self.Post.V[k]
        return scipy.linalg.cholesky(V, lower=True)

    def _logdetV(self, k=None):
        cholV = self.GetCached('cholV', k)
        return 2 * np.sum(np.log(np.diag(cholV)))

    def _E_logdetL(self, k=None):
        dvec = np.arange(1, self.D + 1, dtype=np.float)
        if k == 'all':
            dvec = dvec[:, np.newaxis]
            retVec = self.D * LOGTWO * np.ones(self.K)
            for kk in xrange(self.K):
                retVec[kk] -= self.GetCached('logdetB', kk)
            nuT = self.Post.nu[np.newaxis, :]
            retVec += np.sum(digamma(0.5 * (nuT + 1 - dvec)), axis=0)
            return retVec
        elif k is None:
            nu = self.Prior.nu
        else:
            nu = self.Post.nu[k]
        return self.D * LOGTWO \
            - self.GetCached('logdetB', k) \
            + np.sum(digamma(0.5 * (nu + 1 - dvec)))

    def _E_LA(self, k=None):
        if k is None:
            nu = self.Prior.nu
            B = self.Prior.B
            M = self.Prior.M
        else:
            nu = self.Post.nu[k]
            B = self.Post.B[k]
            M = self.Post.M[k]
        return nu * np.linalg.solve(B, M)

    def _E_ALA(self, k=None):
        if k is None:
            nu = self.Prior.nu
            M = self.Prior.M
            B = self.Prior.B
            V = self.Prior.V
        else:
            nu = self.Post.nu[k]
            M = self.Post.M[k]
            B = self.Post.B[k]
            V = self.Post.V[k]
        Q = np.linalg.solve(self.GetCached('cholB', k), M)
        return self.D * np.linalg.inv(V) \
            + nu * np.dot(Q.T, Q)

    def _trace__E_L(self, S, k=None):
        if k is None:
            nu = self.Prior.nu
            B = self.Prior.B
        else:
            nu = self.Post.nu[k]
            B = self.Post.B[k]
        return nu * np.trace(np.linalg.solve(B, S))

    def _trace__E_LA(self, S, k=None):
        E_LA = self._E_LA(k)
        return np.trace(np.dot(E_LA, S))

    def _trace__E_ALA(self, S, k=None):
        E_ALA = self._E_ALA(k)
        return np.trace(np.dot(E_ALA, S))

    def getSerializableParamsForLocalStep(self):
        """ Get compact dict of params for local step.

        Returns
        -------
        Info : dict
        """
        if self.inferType == 'EM':
            raise NotImplementedError('TODO')
        return dict(
            inferType=self.inferType,
            K=self.K,
            D=self.D,
        )

    def fillSharedMemDictForLocalStep(self, ShMem=None):
        """ Get dict of shared mem arrays needed for parallel local step.

        Returns
        -------
        ShMem : dict of RawArray objects
        """
        if ShMem is None:
            ShMem = dict()
        if 'nu' in ShMem:
            fillSharedMemArray(ShMem['nu'], self.Post.nu)
            fillSharedMemArray(ShMem['M'], self.Post.M)
            fillSharedMemArray(ShMem['cholV'], self._cholV('all'))
            fillSharedMemArray(ShMem['cholB'], self._cholB('all'))
            fillSharedMemArray(ShMem['E_logdetL'], self._E_logdetL('all'))

        else:
            ShMem['nu'] = numpyToSharedMemArray(self.Post.nu)
            ShMem['M'] = numpyToSharedMemArray(self.Post.M)
            ShMem['cholV'] = numpyToSharedMemArray(self._cholV('all'))
            ShMem['cholB'] = numpyToSharedMemArray(self._cholB('all'))
            ShMem['E_logdetL'] = numpyToSharedMemArray(self._E_logdetL('all'))
        return ShMem

    def getLocalAndSummaryFunctionHandles(self):
        """ Get function handles for local step and summary step

        Useful for parallelized algorithms.

        Returns
        -------
        calcLocalParams : f handle
        calcSummaryStats : f handle
        """
        return calcLocalParams, calcSummaryStats
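
# ----------------------------------------------------------------------
# A minimal standalone sketch (not part of the bnpy source) of the two
# Cholesky identities the cached helpers above rely on: the log-determinant
# read off the factor's diagonal, and Mahalanobis distances via a
# triangular solve instead of an explicit matrix inverse.
import numpy as np
import scipy.linalg

D, N = 3, 5
rng = np.random.RandomState(0)
A = rng.randn(D, D)
Sigma = np.dot(A, A.T) + D * np.eye(D)   # symmetric positive definite
X = rng.randn(N, D)                      # N data vectors, one per row

L = scipy.linalg.cholesky(Sigma, lower=True)

# Identity 1: log|Sigma| = 2 * sum(log(diag(L))), cf. _logdetSigma above.
logdet = 2 * np.sum(np.log(np.diag(L)))
assert np.allclose(logdet, np.linalg.slogdet(Sigma)[1])

# Identity 2: x' Sigma^{-1} x = ||L^{-1} x||^2, cf. the _mahalDist methods.
Q = np.linalg.solve(L, X.T)              # Q = L^{-1} X', shape D x N
dist = np.sum(Q * Q, axis=0)             # squared distance for each row of X
expected = np.einsum('nd,de,ne->n', X, np.linalg.inv(Sigma), X)
assert np.allclose(dist, expected)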
Example #8
0
File: BernObsModel.py  Project: Vimos/bnpy
class BernObsModel(AbstractObsModel):

    ''' Bernoulli data generation model for binary vectors.

    Attributes for Prior (Beta)
    --------
    lam1 : 1D array, size D
        pseudo-count of positive (binary value=1) observations
    lam0 : 1D array, size D
        pseudo-count of negative (binary value=0) observations

    Attributes for k-th component of EstParams (EM point estimates)
    ---------
    phi[k] : 1D array, size D
        phi[k] is a vector of positive numbers in range [0, 1]
        phi[k,d] is probability that dimension d has binary value 1.

    Attributes for k-th component of Post (VB parameter)
    ---------
    lam1[k] : 1D array, size D
    lam0[k] : 1D array, size D
    '''

    def __init__(self, inferType='EM', D=0,
                 Data=None, CompDims=('K',), **PriorArgs):
        ''' Initialize bare obsmodel with valid prior hyperparameters.

        Resulting object lacks either EstParams or Post,
        which must be created separately (see init_global_params).
        '''
        if Data is not None:
            self.D = Data.dim
        elif D > 0:
            self.D = int(D)
        self.K = 0
        self.inferType = inferType
        self.createPrior(Data, **PriorArgs)
        self.Cache = dict()
        if isinstance(CompDims, tuple):
            self.CompDims = CompDims
        elif isinstance(CompDims, str):
            self.CompDims = tuple(CompDims)
        assert isinstance(self.CompDims, tuple)

    def createPrior(
            self, Data, lam1=1.0, lam0=1.0, 
            priorMean=None, priorScale=None,
            eps_phi=1e-8, **kwargs):
        ''' Initialize Prior ParamBag attribute.
        '''
        D = self.D
        self.eps_phi = eps_phi
        self.Prior = ParamBag(K=0, D=D)
        if priorMean is None or priorMean.lower().count('none'):
            lam1 = np.asarray(lam1, dtype=np.float)
            lam0 = np.asarray(lam0, dtype=np.float)
        elif isinstance(priorMean, str) and priorMean.count("data"):
            assert priorScale is not None
            priorScale = float(priorScale)
            if hasattr(Data, 'word_id'):
                X = Data.getDocTypeBinaryMatrix()
                dataMean = np.mean(X, axis=0)
            else:
                dataMean = np.mean(Data.X, axis=0)
            dataMean = np.minimum(dataMean, 0.95)  # make the prior smoother
            dataMean = np.maximum(dataMean, 0.05)
            lam1 = priorScale * dataMean
            lam0 = priorScale * (1-dataMean)
        else:
            assert priorScale is not None
            priorScale = float(priorScale)
            priorMean = np.asarray(priorMean, dtype=np.float64)
            lam1 = priorScale * priorMean
            lam0 = priorScale * (1-priorMean)
        if lam1.ndim == 0:
            lam1 = lam1 * np.ones(D)
        if lam0.ndim == 0:
            lam0 = lam0 * np.ones(D)
        assert lam1.size == D
        assert lam0.size == D
        self.Prior.setField('lam1', lam1, dims=('D',))
        self.Prior.setField('lam0', lam0, dims=('D',))

    def get_name(self):
        return 'Bern'

    def get_info_string(self):
        return 'Bernoulli over %d binary attributes.' % (self.D)

    def get_info_string_prior(self):
        msg = 'Beta over %d attributes.\n' % (self.D)
        if self.D > 2:
            sfx = ' ...'
        else:
            sfx = ''
        msg += 'lam1 = %s%s\n' % (str(self.Prior.lam1[:2]), sfx)
        msg += 'lam0 = %s%s\n' % (str(self.Prior.lam0[:2]), sfx)
        msg = msg.replace('\n', '\n  ')
        return msg

    def setupWithAllocModel(self, allocModel):
        ''' Setup expected dimensions of components.

        Args
        ----
        allocModel : instance of bnpy.allocmodel.AllocModel
        '''
        self.CompDims = allocModel.getCompDims()
        assert isinstance(self.CompDims, tuple)

        allocModelName = str(type(allocModel)).lower()
        if allocModelName.count('hdp') or allocModelName.count('topic'):
            self.DataAtomType = 'word'
        else:
            self.DataAtomType = 'doc'

    def setEstParams(self, obsModel=None, SS=None, LP=None, Data=None,
                     phi=None,
                     **kwargs):
        ''' Set attribute EstParams to provided values.
        '''
        self.ClearCache()
        if obsModel is not None:
            self.EstParams = obsModel.EstParams.copy()
            self.K = self.EstParams.K
            return

        if LP is not None and Data is not None:
            SS = self.calcSummaryStats(Data, None, LP)

        if SS is not None:
            self.updateEstParams(SS)
        else:
            self.EstParams = ParamBag(K=phi.shape[0], D=phi.shape[1])
            self.EstParams.setField('phi', phi, dims=('K', 'D',))
        self.K = self.EstParams.K

    def setEstParamsFromPost(self, Post=None):
        ''' Set attribute EstParams based on values in Post.
        '''
        if Post is None:
            Post = self.Post
        self.EstParams = ParamBag(K=Post.K, D=Post.D)
        phi = Post.lam1 / (Post.lam1 + Post.lam0)
        self.EstParams.setField('phi', phi, dims=('K', 'D',))
        self.K = self.EstParams.K

    def setPostFactors(self, obsModel=None, SS=None, LP=None, Data=None,
                       lam1=None, lam0=None, **kwargs):
        ''' Set attribute Post to provided values.
        '''
        self.ClearCache()
        if obsModel is not None:
            if hasattr(obsModel, 'Post'):
                self.Post = obsModel.Post.copy()
                self.K = self.Post.K
            else:
                self.setPostFromEstParams(obsModel.EstParams)
            return

        if LP is not None and Data is not None:
            SS = self.calcSummaryStats(Data, None, LP)

        if SS is not None:
            self.updatePost(SS)
        else:
            lam1 = as2D(lam1)
            lam0 = as2D(lam0)
            D = lam1.shape[-1]
            if self.D != D:
                if not lam1.shape[0] == self.D:
                    raise ValueError("Bad dimension for lam1, lam0")
                lam1 = lam1.T.copy()
                lam0 = lam0.T.copy()

            K = lam1.shape[0]
            self.Post = ParamBag(K=K, D=self.D)
            self.Post.setField('lam1', lam1, dims=self.CompDims + ('D',))
            self.Post.setField('lam0', lam0, dims=self.CompDims + ('D',))
        self.K = self.Post.K

    def setPostFromEstParams(self, EstParams, Data=None, nTotalTokens=1,
                             **kwargs):
        ''' Set attribute Post based on values in EstParams.
        '''
        K = EstParams.K
        D = EstParams.D

        WordCounts = EstParams.phi * nTotalTokens
        lam1 = WordCounts + self.Prior.lam1
        # off-pseudo-counts: total tokens minus the expected on-counts
        lam0 = (nTotalTokens - WordCounts) + self.Prior.lam0

        self.Post = ParamBag(K=K, D=D)
        self.Post.setField('lam1', lam1, dims=('K', 'D'))
        self.Post.setField('lam0', lam0, dims=('K', 'D'))
        self.K = K

    def calcSummaryStats(self, Data, SS, LP, **kwargs):
        ''' Calculate summary statistics for given dataset and local parameters

        Returns
        --------
        SS : SuffStatBag object, with K components.
        '''
        return calcSummaryStats(Data, SS, LP,
            DataAtomType=self.DataAtomType, **kwargs)

    def calcSummaryStatsForContigBlock(self, Data, a=0, b=0, **kwargs):
        ''' Calculate summary stats for a contiguous block of the data.

        Returns
        --------
        SS : SuffStatBag object, with 1 component.
        '''
        Xab = Data.X[a:b]  # 2D array, Nab x D
        CountON = np.sum(Xab, axis=0)[np.newaxis, :]
        CountOFF = (b - a) - CountON

        SS = SuffStatBag(K=1, D=Data.dim)
        SS.setField('N', np.asarray([b - a], dtype=np.float64), dims='K')
        SS.setField('Count1', CountON, dims=('K', 'D'))
        SS.setField('Count0', CountOFF, dims=('K', 'D'))
        return SS

    def forceSSInBounds(self, SS):
        ''' Force count vectors to remain positive

        This avoids numerical problems due to incremental add/subtract ops
        which can cause computations like
            x = 10.
            x += 1e-15
            x -= 10
            x -= 1e-15
        to be slightly different than zero instead of exactly zero.

        Post Condition
        -------
        Fields Count1, Count0 guaranteed to be positive.
        '''
        np.maximum(SS.Count1, 0, out=SS.Count1)
        np.maximum(SS.Count0, 0, out=SS.Count0)

    def incrementSS(self, SS, k, Data, docID):
        raise NotImplementedError('TODO')

    def decrementSS(self, SS, k, Data, docID):
        raise NotImplementedError('TODO')

    def calcLogSoftEvMatrix_FromEstParams(self, Data, **kwargs):
        ''' Compute log soft evidence matrix for Dataset under EstParams.

        Returns
        ---------
        L : 2D array, N x K
        '''
        logphiT = np.log(self.EstParams.phi.T)  # D x K matrix
        log1mphiT = np.log(1.0 - self.EstParams.phi.T)  # D x K matrix
        return np.dot(Data.X, logphiT) + np.dot(1 - Data.X, log1mphiT)

    def updateEstParams_MaxLik(self, SS):
        ''' Update attribute EstParams for all comps given suff stats.

        Update uses the maximum likelihood objective for point estimation.

        Post Condition
        ---------
        Attributes K and EstParams updated in-place.
        '''
        self.ClearCache()
        self.K = SS.K
        if not hasattr(self, 'EstParams') or self.EstParams.K != SS.K:
            self.EstParams = ParamBag(K=SS.K, D=SS.D)
        phi = SS.Count1 / (SS.Count1 + SS.Count0)
        # prevent entries from reaching exactly 0
        np.maximum(phi, self.eps_phi, out=phi)
        np.minimum(phi, 1.0 - self.eps_phi, out=phi)
        self.EstParams.setField('phi', phi, dims=('K', 'D'))

    def updateEstParams_MAP(self, SS):
        ''' Update attribute EstParams for all comps given suff stats.

        Update uses the MAP objective for point estimation.

        Post Condition
        ---------
        Attributes K and EstParams updated in-place.
        '''
        self.ClearCache()
        if not hasattr(self, 'EstParams') or self.EstParams.K != SS.K:
            self.EstParams = ParamBag(K=SS.K, D=SS.D)
        phi_numer = SS.Count1 + self.Prior.lam1 - 1
        phi_denom = SS.Count1 + SS.Count0 + \
            self.Prior.lam1 + self.Prior.lam0 - 2
        phi = phi_numer / phi_denom
        self.EstParams.setField('phi', phi, dims=('K', 'D'))

    def updatePost(self, SS):
        ''' Update attribute Post for all comps given suff stats.

        Update uses the variational objective.

        Post Condition
        ---------
        Attributes K and Post updated in-place.
        '''
        self.ClearCache()
        if not hasattr(self, 'Post') or self.Post.K != SS.K:
            self.Post = ParamBag(K=SS.K, D=SS.D)

        lam1, lam0 = self.calcPostParams(SS)
        self.Post.setField('lam1', lam1, dims=self.CompDims + ('D',))
        self.Post.setField('lam0', lam0, dims=self.CompDims + ('D',))
        self.K = SS.K

    def calcPostParams(self, SS):
        ''' Calc posterior parameters for all comps given suff stats.

        Returns
        --------
        lam1 : 2D array, size K x D (or 3D array, K x K x D, if relational)
        lam0 : 2D array, size K x D (or 3D array, K x K x D, if relational)
        '''
        if SS.Count1.ndim == 2:
            lam1 = SS.Count1 + self.Prior.lam1[np.newaxis, :]
            lam0 = SS.Count0 + self.Prior.lam0[np.newaxis, :]
        elif SS.Count1.ndim == 3:
            lam1 = SS.Count1 + self.Prior.lam1[np.newaxis, np.newaxis, :]
            lam0 = SS.Count0 + self.Prior.lam0[np.newaxis, np.newaxis, :]
        return lam1, lam0

    def calcPostParamsForComp(self, SS, kA=None, kB=None):
        ''' Calc params (lam1, lam0) for specific comp, given suff stats

            These params define the common-form of the exponential family
            Beta posterior distribution over parameter vector phi

            Returns
            --------
            lam1_k : 1D array, size D
            lam0_k : 1D array, size D
        '''
        if kB is None:
            lam1_k = SS.Count1[kA].copy()
            lam0_k = SS.Count0[kA].copy()
        else:
            lam1_k = SS.Count1[kA] + SS.Count1[kB]
            lam0_k = SS.Count0[kA] + SS.Count0[kB]
        lam1_k += self.Prior.lam1
        lam0_k += self.Prior.lam0
        return lam1_k, lam0_k

    def updatePost_stochastic(self, SS, rho):
        ''' Update attribute Post for all comps given suff stats

        Update uses the stochastic variational formula.

        Post Condition
        ---------
        Attributes K and Post updated in-place.
        '''
        assert hasattr(self, 'Post')
        assert self.Post.K == SS.K
        self.ClearCache()

        lam1, lam0 = self.calcPostParams(SS)
        Post = self.Post
        Post.lam1[:] = (1 - rho) * Post.lam1 + rho * lam1
        Post.lam0[:] = (1 - rho) * Post.lam0 + rho * lam0

    def convertPostToNatural(self):
        ''' Convert current posterior params from common to natural form
        '''
        pass

    def convertPostToCommon(self):
        ''' Convert current posterior params from natural to common form
        '''
        pass

    def calcLogSoftEvMatrix_FromPost(self, Data, **kwargs):
        ''' Calculate expected log soft ev matrix under Post.

        Returns
        ------
        L : 2D array, size N x K
        '''
        # ElogphiT : vocab_size x K
        ElogphiT, Elog1mphiT = self.GetCached('E_logphiT_log1mphiT', 'all')
        return calcLogSoftEvMatrix_FromPost(
            Data,
            ElogphiT=ElogphiT,
            Elog1mphiT=Elog1mphiT,
            DataAtomType=self.DataAtomType, **kwargs)

    def calcELBO_Memoized(self, SS, returnVec=0, afterMStep=False, **kwargs):
        """ Calculate obsModel's objective using suff stats SS and Post.

        Args
        -------
        SS : bnpy SuffStatBag
        afterMStep : boolean flag
            if 1, elbo calculated assuming M-step just completed

        Returns
        -------
        obsELBO : scalar float
            Equal to E[ log p(x) + log p(phi) - log q(phi)]
        """
        Post = self.Post
        Prior = self.Prior
        if not afterMStep:
            ElogphiT = self.GetCached('E_logphiT', 'all')
            Elog1mphiT = self.GetCached('E_log1mphiT', 'all')
            # with relational/graph datasets, these have shape D x K x K
            # otherwise, these have shape D x K
        if self.CompDims == ('K',):
            # Typical case: K x D
            assert Post._FieldDims['lam1'] == ('K', 'D')
            L_perComp = np.zeros(SS.K)
            for k in xrange(SS.K):
                L_perComp[k] = c_Diff(Prior.lam1, Prior.lam0,
                                      Post.lam1[k], Post.lam0[k])
                if not afterMStep:
                    L_perComp[k] += np.inner(
                        SS.Count1[k] + Prior.lam1 - Post.lam1[k],
                        ElogphiT[:, k])
                    L_perComp[k] += np.inner(
                        SS.Count0[k] + Prior.lam0 - Post.lam0[k],
                        Elog1mphiT[:, k])
            if returnVec:
                return L_perComp
            return np.sum(L_perComp)
        elif self.CompDims == ('K','K',):
            # Relational case, K x K x D
            assert Post._FieldDims['lam1'] == ('K', 'K', 'D')

            cPrior = c_Func(Prior.lam1, Prior.lam0)
            Ldata = SS.K * SS.K * cPrior - c_Func(Post.lam1, Post.lam0)
            if not afterMStep:
                Ldata += np.sum(
                    (SS.Count1 + Prior.lam1[nx, nx, :] - Post.lam1) *
                    ElogphiT.T)
                Ldata += np.sum(
                    (SS.Count0 + Prior.lam0[nx, nx, :] - Post.lam0) *
                    Elog1mphiT.T)
            return Ldata
        else:
            raise ValueError("Unrecognized compdims: " + str(self.CompDims))
            
    def getDatasetScale(self, SS, extraSS=None):
        ''' Get number of observed scalars in dataset from suff stats.

        Used for normalizing the ELBO so it has reasonable range.

        Returns
        ---------
        s : scalar positive integer
        '''
        s = SS.Count1.sum() + SS.Count0.sum()
        if extraSS is None:
            return s
        else:
            sextra = extraSS.Count1.sum() + extraSS.Count0.sum()
            return s - sextra

    def calcHardMergeGap(self, SS, kA, kB):
        ''' Calculate change in ELBO after a hard merge applied to this model

        Returns
        ---------
        gap : scalar real, indicates change in ELBO after merge of kA, kB
        '''
        Prior = self.Prior
        cPrior = c_Func(Prior.lam1, Prior.lam0)

        Post = self.Post
        cA = c_Func(Post.lam1[kA], Post.lam0[kA])
        cB = c_Func(Post.lam1[kB], Post.lam0[kB])

        lam1, lam0 = self.calcPostParamsForComp(SS, kA, kB)
        cAB = c_Func(lam1, lam0)
        return cA + cB - cPrior - cAB

    def calcHardMergeGap_AllPairs(self, SS):
        ''' Calculate change in ELBO for all candidate hard merge pairs

        Returns
        ---------
        Gap : 2D array, size K x K, upper-triangular entries non-zero
              Gap[j,k] : scalar change in ELBO after merge of k into j
        '''
        Prior = self.Prior
        cPrior = c_Func(Prior.lam1, Prior.lam0)

        Post = self.Post
        c = np.zeros(SS.K)
        for k in xrange(SS.K):
            c[k] = c_Func(Post.lam1[k], Post.lam0[k])

        Gap = np.zeros((SS.K, SS.K))
        for j in xrange(SS.K):
            for k in xrange(j + 1, SS.K):
                cjk = c_Func(*self.calcPostParamsForComp(SS, j, k))
                Gap[j, k] = c[j] + c[k] - cPrior - cjk
        return Gap

    def calcHardMergeGap_SpecificPairs(self, SS, PairList):
        ''' Calc change in ELBO for specific list of candidate hard merge pairs

        Returns
        ---------
        Gaps : 1D array, size L
              Gaps[j] : scalar change in ELBO after merge of pair in PairList[j]
        '''
        Gaps = np.zeros(len(PairList))
        for ii, (kA, kB) in enumerate(PairList):
            Gaps[ii] = self.calcHardMergeGap(SS, kA, kB)
        return Gaps

    def calcHardMergeGap_SpecificPairSS(self, SS1, SS2):
        ''' Calc change in ELBO for merge of two K=1 suff stat bags.

        Returns
        -------
        gap : scalar float
        '''
        assert SS1.K == 1
        assert SS2.K == 1

        Prior = self.Prior
        cPrior = c_Func(Prior.lam1, Prior.lam0)

        # Compute cumulants of individual states 1 and 2
        lam11, lam10 = self.calcPostParamsForComp(SS1, 0)
        lam21, lam20 = self.calcPostParamsForComp(SS2, 0)
        c1 = c_Func(lam11, lam10)
        c2 = c_Func(lam21, lam20)

        # Compute cumulant of merged state 1&2
        SSM = SS1 + SS2
        lamM1, lamM0 = self.calcPostParamsForComp(SSM, 0)
        cM = c_Func(lamM1, lamM0)

        return c1 + c2 - cPrior - cM

    def calcLogMargLikForComp(self, SS, kA, kB=None, **kwargs):
        ''' Calc log marginal likelihood of data assigned to given component

        Args
        -------
        SS : bnpy suff stats object
        kA : integer ID of target component to compute likelihood for
        kB : (optional) integer ID of second component.
             If provided, we merge kA, kB into one component for calculation.
        Returns
        -------
        logM : scalar real
               logM = log p( data assigned to comp kA )
                      computed up to an additive constant
        '''
        return -1 * c_Func(*self.calcPostParamsForComp(SS, kA, kB))

    def calcMargLik(self, SS):
        ''' Calc log marginal likelihood combining all comps, given suff stats

        Returns
        --------
        logM : scalar real
               logM = \sum_{k=1}^K log p( data assigned to comp k | Prior)
        '''
        return self.calcMargLik_CFuncForLoop(SS)

    def calcMargLik_CFuncForLoop(self, SS):
        Prior = self.Prior
        logp = np.zeros(SS.K)
        for k in xrange(SS.K):
            lam1, lam0 = self.calcPostParamsForComp(SS, k)
            logp[k] = c_Diff(Prior.lam1, Prior.lam0,
                             lam1, lam0)
        return np.sum(logp)

    def _E_logphi(self, k=None):
        if k is None or k == 'prior':
            lam1 = self.Prior.lam1
            lam0 = self.Prior.lam0
        elif k == 'all':
            lam1 = self.Post.lam1
            lam0 = self.Post.lam0
        else:
            lam1 = self.Post.lam1[k]
            lam0 = self.Post.lam0[k]
        Elogphi = digamma(lam1) - digamma(lam1 + lam0)
        return Elogphi

    def _E_log1mphi(self, k=None):
        if k is None or k == 'prior':
            lam1 = self.Prior.lam1
            lam0 = self.Prior.lam0
        elif k == 'all':
            lam1 = self.Post.lam1
            lam0 = self.Post.lam0
        else:
            lam1 = self.Post.lam1[k]
            lam0 = self.Post.lam0[k]
        Elog1mphi = digamma(lam0) - digamma(lam1 + lam0)
        return Elog1mphi

    def _E_logphiT_log1mphiT(self, k=None):
        if k == 'all':
            lam1T = self.Post.lam1.T.copy()
            lam0T = self.Post.lam0.T.copy()
            digammaBoth = digamma(lam1T + lam0T)
            ElogphiT = digamma(lam1T) - digammaBoth
            Elog1mphiT = digamma(lam0T) - digammaBoth
        else:
            ElogphiT = self._E_logphiT(k)
            Elog1mphiT = self._E_log1mphiT(k)
        return ElogphiT, Elog1mphiT

    def _E_logphiT(self, k=None):
        ''' Calculate transpose of expected phi matrix

        Important to make a copy of the matrix so it is C-contiguous,
        which leads to much much faster matrix operations.

        Returns
        -------
        ElogphiT : 2D array, vocab_size x K
        '''
        if k == 'all':
            dlam1T = self.Post.lam1.T.copy()
            dlambothT = self.Post.lam0.T.copy()
            dlambothT += dlam1T
            digamma(dlam1T, out=dlam1T)
            digamma(dlambothT, out=dlambothT)
            return dlam1T - dlambothT
        ElogphiT = self._E_logphi(k).T.copy()
        return ElogphiT

    def _E_log1mphiT(self, k=None):
        ''' Calculate transpose of expected 1-minus-phi matrix

        Important to make a copy of the matrix so it is C-contiguous,
        which leads to much much faster matrix operations.

        Returns
        -------
        ElogphiT : 2D array, vocab_size x K
        '''
        if k == 'all':
            # Copy so lam1T/lam0T are C-contig and can be shared mem.
            lam1T = self.Post.lam1.T.copy()
            lam0T = self.Post.lam0.T.copy()
            return digamma(lam0T) - digamma(lam1T + lam0T)

        ElogphiT = self._E_log1mphi(k).T.copy()
        return ElogphiT

    def getSerializableParamsForLocalStep(self):
        """ Get compact dict of params for local step.

        Returns
        -------
        Info : dict
        """
        return dict(inferType=self.inferType,
                    DataAtomType=self.DataAtomType,
                    K=self.K)

    def fillSharedMemDictForLocalStep(self, ShMem=None):
        """ Get dict of shared mem arrays needed for parallel local step.

        Returns
        -------
        ShMem : dict of RawArray objects
        """
        ElogphiT, Elog1mphiT = self.GetCached('E_logphiT_log1mphiT', 'all')
        K = self.K
        if ShMem is None:
            ShMem = dict()
        if 'ElogphiT' not in ShMem:
            ShMem['ElogphiT'] = numpyToSharedMemArray(ElogphiT)
            ShMem['Elog1mphiT'] = numpyToSharedMemArray(Elog1mphiT)
        else:
            ElogphiT_shView = sharedMemToNumpyArray(ShMem['ElogphiT'])
            assert ElogphiT_shView.shape[1] >= K
            ElogphiT_shView[:, :K] = ElogphiT

            Elog1mphiT_shView = sharedMemToNumpyArray(ShMem['Elog1mphiT'])
            assert Elog1mphiT_shView.shape[1] >= K
            Elog1mphiT_shView[:, :K] = Elog1mphiT
        return ShMem

    def getLocalAndSummaryFunctionHandles(self):
        """ Get function handles for local step and summary step

        Useful for parallelized algorithms.

        Returns
        -------
        calcLocalParams : f handle
        calcSummaryStats : f handle
        """
        return calcLocalParams, calcSummaryStats


    def calcSmoothedMu(self, X, W=None):
        ''' Compute smoothed estimate of probability of each word.

        Returns
        -------
        Mu : 1D array, size D (aka vocab_size)
        '''
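        # The smoothed estimate is the Beta posterior mean:
        #   Mu[d] = (x[d] + lam1[d]) / (NX + lam1[d] + lam0[d])
        # With X=None this falls back to the prior mean lam1 / (lam1 + lam0).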
        if X is None:
            return self.Prior.lam1 / (self.Prior.lam1 + self.Prior.lam0)

        if X.ndim > 1:
            if W is None:
                NX = X.shape[0]
                X = np.sum(X, axis=0)
            else:
                NX = np.sum(W)
                X = np.dot(W, X)
        else:
            NX = 1

        assert X.ndim == 1
        assert X.size == self.D
        Mu = X + self.Prior.lam1
        Mu /= (NX + self.Prior.lam1 + self.Prior.lam0)
        return Mu

    def calcSmoothedBregDiv(self, 
            X, Mu, W=None,
            smoothFrac=0.0,
            includeOnlyFastTerms=False,
            DivDataVec=None,
            returnDivDataVec=False,
            return1D=False,
            **kwargs):
        ''' Compute Bregman divergence between data X and clusters Mu.

        Smooth the data via update with prior parameters.

        Returns
        -------
        Div : 2D array, N x K
            Div[n,k] = smoothed distance between X[n] and Mu[k]
        '''
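        # Bernoulli Bregman (KL) divergence between smoothed row x and mean mu:
        #   D(x, mu) = sum_d [ x log(x/mu) + (1-x) log((1-x)/(1-mu)) ]
        # splits into k-dependent "fast" terms, -x log(mu) - (1-x) log(1-mu),
        # filled into Div below, plus a data-only constant,
        # x log(x) + (1-x) log(1-x), accumulated once in DivDataVec.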
        if X.ndim < 2:
            X = X[np.newaxis,:]
        assert X.ndim == 2
        N = X.shape[0]
        D = X.shape[1]

        if W is not None:
            assert W.ndim == 1
            assert W.size == N

        if not isinstance(Mu, list):
            Mu = (Mu,)
        K = len(Mu)
        assert Mu[0].size == D

        # Smooth-ify the data matrix X
        if smoothFrac == 0:
            MuX = np.minimum(X, 1 - 1e-14)
            np.maximum(MuX, 1e-14, out=MuX)
        else:
            MuX = X + smoothFrac * self.Prior.lam1
            NX = 1.0 + smoothFrac * (self.Prior.lam1 + self.Prior.lam0)
            MuX /= NX

        # Compute Div array up to a per-row additive constant indep. of k
        Div = np.zeros((N, K))
        for k in xrange(K):
            Div[:,k] = -1 * np.sum(MuX * np.log(Mu[k]), axis=1) + \
                -1 * np.sum((1-MuX) * np.log(1-Mu[k]), axis=1)

        if not includeOnlyFastTerms:
            if DivDataVec is None:
                # Compute DivDataVec : 1D array of size N
                # This is the per-row additive constant indep. of k. 
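                # Note: the buffers below are deliberately reused in place
                # (logMuX and MuX are overwritten once their original values
                # are consumed), so the statement order matters.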

                # STEP 1: Compute MuX * log(MuX)
                logMuX = np.log(MuX)
                MuXlogMuX = logMuX
                MuXlogMuX *= MuX
                DivDataVec = np.sum(MuXlogMuX, axis=1)
                
                # STEP 2: Compute (1-MuX) * log(1-MuX)
                OneMinusMuX = MuX
                OneMinusMuX *= -1
                OneMinusMuX += 1
                logOneMinusMuX = logMuX
                np.log(OneMinusMuX, out=logOneMinusMuX)
                logOneMinusMuX *= OneMinusMuX
                DivDataVec += np.sum(logOneMinusMuX, axis=1)
            Div += DivDataVec[:,np.newaxis]

        assert np.all(np.isfinite(Div))
        # Apply per-atom weights to divergences.
        if W is not None:
            assert W.ndim == 1
            assert W.size == N
            Div *= W[:,np.newaxis]
        # Verify divergences are strictly non-negative 
        if not includeOnlyFastTerms:
            minDiv = Div.min()
            if minDiv < 0:
                if minDiv < -1e-6:
                    raise AssertionError(
                        "Expected Div.min() to be positive or" + \
                        " indistinguishable from zero. Instead " + \
                        " minDiv=% .3e" % (minDiv))
                np.maximum(Div, 0, out=Div)
                minDiv = Div.min()
            assert minDiv >= 0

        if return1D:
            Div = Div[:,0]
        if returnDivDataVec:
            return Div, DivDataVec
        return Div


    def calcBregDivFromPrior(self, Mu, smoothFrac=0.0):
        ''' Compute Bregman divergence between Mu and prior mean.

        Returns
        -------
        Div : 1D array, size K
            Div[k] = distance between Mu[k] and priorMu
        '''
        if not isinstance(Mu, list):
            Mu = (Mu,) # cheaper than a list
        K = len(Mu)

        priorMu = self.Prior.lam1 / (self.Prior.lam1 + self.Prior.lam0)
        priorN = (1-smoothFrac) * (self.Prior.lam1 + self.Prior.lam0)

        Div = np.zeros((K, self.D))
        for k in xrange(K):
            Div[k, :] = priorMu * np.log(priorMu / Mu[k]) + \
                (1-priorMu) * np.log((1-priorMu)/(1-Mu[k]))
        return np.dot(Div, priorN)
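
# ----------------------------------------------------------------------
# A minimal standalone sketch (independent of bnpy) of the Beta-Bernoulli
# conjugate update performed by BernObsModel.calcPostParams: posterior
# pseudo-counts are prior pseudo-counts plus observed on/off counts.
import numpy as np

def beta_bernoulli_update(X, lam1_prior, lam0_prior):
    ''' Return posterior (lam1, lam0) for one cluster given binary data X. '''
    X = np.asarray(X, dtype=np.float64)   # shape N x D, entries in {0, 1}
    count1 = X.sum(axis=0)                # on-counts per dimension
    count0 = X.shape[0] - count1          # off-counts per dimension
    return lam1_prior + count1, lam0_prior + count0

X_demo = np.array([[1, 0, 1],
                   [1, 1, 0],
                   [1, 0, 0]])
lam1, lam0 = beta_bernoulli_update(X_demo, 1.0, 1.0)  # uniform Beta(1,1) prior
print(lam1 / (lam1 + lam0))   # posterior mean of phi, cf. setEstParamsFromPost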
class MultObsModel(AbstractObsModel):
    """ Multinomial data generation model for count vectors.

    Attributes for Prior (Dirichlet)
    --------
    lam : 1D array, size vocab_size
        pseudo-count of observations of each symbol (word) type.

    Attributes for k-th component of EstParams (EM point estimates)
    ---------
    phi[k] : 1D array, size vocab_size
        phi[k] is a vector of positive numbers that sum to one.
        phi[k,v] is probability that vocab type v appears under k.

    Attributes for k-th component of Post (VB parameter)
    ---------
    lam[k] : 1D array, size vocab_size
    """
    def __init__(self,
                 inferType='EM',
                 D=0,
                 vocab_size=0,
                 Data=None,
                 **PriorArgs):
        ''' Initialize bare obsmodel with valid prior hyperparameters.

        Resulting object lacks either EstParams or Post,
        which must be created separately (see init_global_params).
        '''
        if Data is not None:
            self.D = Data.vocab_size
        elif vocab_size > 0:
            self.D = int(vocab_size)
        else:
            self.D = int(D)
        self.K = 0
        self.inferType = inferType
        self.createPrior(Data, **PriorArgs)
        self.Cache = dict()

    def createPrior(self, Data, lam=1.0, min_phi=1e-100, **kwargs):
        ''' Initialize Prior ParamBag attribute.
        '''
        D = self.D
        self.min_phi = min_phi
        self.Prior = ParamBag(K=0, D=D)
        lam = np.asarray(lam, dtype=np.float)
        if lam.ndim == 0:
            lam = lam * np.ones(D)
        assert lam.size == D
        self.Prior.setField('lam', lam, dims=('D'))
        self.prior_cFunc = c_Func(lam)

    def setupWithAllocModel(self, allocModel):
        ''' Using the allocation model, determine the modeling scenario.

        doc  : multinomial : each atom is vector of empirical counts in doc
        word : categorical : each atom is single word token (one of vocab_size)
        '''
        if not isinstance(allocModel, str):
            allocModel = str(type(allocModel))
        aModelName = allocModel.lower()
        if aModelName.count('hdp') or aModelName.count('topic'):
            self.DataAtomType = 'word'
        else:
            self.DataAtomType = 'doc'

    def getTopics(self):
        ''' Retrieve matrix of estimated topic-word probability vectors

        Returns
        --------
        topics : K x vocab_size
                 topics[k,:] is a non-negative vector that sums to one
        '''
        if hasattr(self, 'EstParams'):
            return self.EstParams.phi
        else:
            phi = self.Post.lam / np.sum(self.Post.lam, axis=1)[:, np.newaxis]
            return phi

    def get_name(self):
        return 'Mult'

    def get_info_string(self):
        return 'Multinomial over finite vocabulary.'

    def get_info_string_prior(self):
        msg = 'Dirichlet over finite vocabulary \n'
        if self.D > 2:
            sfx = ' ...'
        else:
            sfx = ''
        S = self.Prior.lam[:2]
        msg += 'lam = %s%s' % (str(S), sfx)
        msg = msg.replace('\n', '\n  ')
        return msg

    def setEstParams(self,
                     obsModel=None,
                     SS=None,
                     LP=None,
                     Data=None,
                     phi=None,
                     topics=None,
                     **kwargs):
        ''' Create EstParams ParamBag with fields phi
        '''
        if topics is not None:
            phi = topics

        self.ClearCache()
        if obsModel is not None:
            self.EstParams = obsModel.EstParams.copy()
            self.K = self.EstParams.K
            return

        if LP is not None and Data is not None:
            SS = self.calcSummaryStats(Data, None, LP)

        if SS is not None:
            self.updateEstParams(SS)
        else:
            self.EstParams = ParamBag(K=phi.shape[0], D=phi.shape[1])
            self.EstParams.setField('phi', phi, dims=('K', 'D'))
        self.K = self.EstParams.K

    def setEstParamsFromPost(self, Post=None, **kwargs):
        ''' Convert from Post (lam) to EstParams (phi),
             each EstParam is set to its posterior mean.
        '''
        if Post is None:
            Post = self.Post
        self.EstParams = ParamBag(K=Post.K, D=Post.D)
        phi = Post.lam / np.sum(Post.lam, axis=1)[:, np.newaxis]
        self.EstParams.setField('phi', phi, dims=('K', 'D'))
        self.K = self.EstParams.K

    def setPostFactors(self,
                       obsModel=None,
                       SS=None,
                       LP=None,
                       Data=None,
                       lam=None,
                       WordCounts=None,
                       **kwargs):
        ''' Set attribute Post to provided values.
        '''
        self.ClearCache()
        if obsModel is not None:
            if hasattr(obsModel, 'Post'):
                self.Post = obsModel.Post.copy()
                self.K = self.Post.K
            else:
                self.setPostFromEstParams(obsModel.EstParams)
            return

        if LP is not None and Data is not None:
            SS = self.calcSummaryStats(Data, None, LP)

        if SS is not None:
            self.updatePost(SS)
        else:
            if WordCounts is not None:
                lam = as2D(WordCounts) + lam
            else:
                lam = as2D(lam)
            K, D = lam.shape
            self.Post = ParamBag(K=K, D=D)
            self.Post.setField('lam', lam, dims=('K', 'D'))
        self.K = self.Post.K

    def setPostFromEstParams(self,
                             EstParams,
                             Data=None,
                             nTotalTokens=0,
                             **kwargs):
        ''' Set attribute Post based on values in EstParams.
        '''
        K = EstParams.K
        D = EstParams.D

        if Data is not None:
            nTotalTokens = Data.word_count.sum()
        if isinstance(nTotalTokens, int) or nTotalTokens.ndim == 0:
            nTotalTokens = float(nTotalTokens) / float(K) * np.ones(K)
        if np.any(nTotalTokens == 0):
            priorScale = self.Prior.lam.sum()
            warnings.warn("Enforcing minimum scale of %.3f for lam" %
                          (priorScale))
            nTotalTokens = np.maximum(nTotalTokens, priorScale)

        if 'lam' in kwargs and kwargs['lam'] is not None:
            lam = kwargs['lam']
        else:
            WordCounts = EstParams.phi * nTotalTokens[:, np.newaxis]
            assert WordCounts.max() > 0
            lam = WordCounts + self.Prior.lam

        self.Post = ParamBag(K=K, D=D)
        self.Post.setField('lam', lam, dims=('K', 'D'))
        self.K = K

    def calcSummaryStats(self, Data, SS, LP, cslice=(0, None), **kwargs):
        ''' Calculate summary statistics for given dataset and local parameters

        Returns
        --------
        SS : SuffStatBag object, with K components.
        '''
        return calcSummaryStats(Data,
                                SS,
                                LP,
                                DataAtomType=self.DataAtomType,
                                **kwargs)

    def forceSSInBounds(self, SS):
        ''' Force count vectors to remain positive
        '''
        np.maximum(SS.WordCounts, 0, out=SS.WordCounts)
        np.maximum(SS.SumWordCounts, 0, out=SS.SumWordCounts)
        if not np.allclose(SS.WordCounts.sum(axis=1), SS.SumWordCounts):
            raise ValueError('Bad Word Counts!')

    def incrementSS(self, SS, k, Data, docID):
        SS.WordCounts[k] += Data.getSparseDocTypeCountMatrix()[docID, :]

    def decrementSS(self, SS, k, Data, docID):
        SS.WordCounts[k] -= Data.getSparseDocTypeCountMatrix()[docID, :]

    def calcLogSoftEvMatrix_FromEstParams(self, Data, **kwargs):
        ''' Compute log soft evidence matrix for Dataset under EstParams.

        Returns
        ---------
        L : 2D array, N x K
        '''
        logphiT = np.log(self.EstParams.phi.T)
        if self.DataAtomType == 'doc':
            X = Data.getSparseDocTypeCountMatrix()
            return X * logphiT
        else:
            return logphiT[Data.word_id, :]

    def updateEstParams_MaxLik(self, SS):
        ''' Update attribute EstParams for all comps given suff stats.

        Update uses the maximum likelihood objective for point estimation.

        Post Condition
        ---------
        Attributes K and EstParams updated in-place.
        '''
        self.ClearCache()
        if not hasattr(self, 'EstParams') or self.EstParams.K != SS.K:
            self.EstParams = ParamBag(K=SS.K, D=SS.D)
        phi = SS.WordCounts / SS.SumWordCounts[:, np.newaxis]
        # prevent entries from reaching exactly 0
        np.maximum(phi, self.min_phi, out=phi)
        self.EstParams.setField('phi', phi, dims=('K', 'D'))

    def updateEstParams_MAP(self, SS):
        ''' Update attribute EstParams for all comps given suff stats.

        Update uses the MAP objective for point estimation.

        Post Condition
        ---------
        Attributes K and EstParams updated in-place.
        '''
        self.ClearCache()
        if not hasattr(self, 'EstParams') or self.EstParams.K != SS.K:
            self.EstParams = ParamBag(K=SS.K, D=SS.D)
        phi = SS.WordCounts + self.Prior.lam - 1
        phi /= phi.sum(axis=1)[:, np.newaxis]
        self.EstParams.setField('phi', phi, dims=('K', 'D'))

    def updatePost(self, SS):
        ''' Update attribute Post for all comps given suff stats.

        Update uses the variational objective.

        Post Condition
        ---------
        Attributes K and Post updated in-place.
        '''
        self.ClearCache()
        if not hasattr(self, 'Post') or self.Post.K != SS.K:
            self.Post = ParamBag(K=SS.K, D=SS.D)

        lam = self.calcPostParams(SS)
        self.Post.setField('lam', lam, dims=('K', 'D'))
        self.K = SS.K

    def calcPostParams(self, SS):
        ''' Calc updated params (lam) for all comps given suff stats

            These params define the common form of the exponential family
            Dirichlet posterior distribution over parameter vector phi

            Returns
            --------
            lam : 2D array, size K x D
        '''
        Prior = self.Prior
        lam = SS.WordCounts + Prior.lam[np.newaxis, :]
        return lam

    def calcPostParamsForComp(self, SS, kA=None, kB=None):
        ''' Calc params (lam) for specific comp, given suff stats

            These params define the common form of the exponential family
            Dirichlet posterior distribution over parameter vector phi

            Returns
            --------
            lam : 1D array, size D
        '''
        if kB is None:
            SM = SS.WordCounts[kA]
        else:
            SM = SS.WordCounts[kA] + SS.WordCounts[kB]
        return SM + self.Prior.lam

    def updatePost_stochastic(self, SS, rho):
        ''' Update attribute Post for all comps given suff stats

        Update uses the stochastic variational formula.

        Post Condition
        ---------
        Attributes K and Post updated in-place.
        '''
        assert hasattr(self, 'Post')
        assert self.Post.K == SS.K
        self.ClearCache()

        lam = self.calcPostParams(SS)
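        # Blend old and new params with step size rho. This convex combination
        # is exact for the Dirichlet because its common and natural
        # parameterizations coincide (see convertPostToNatural below).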
        Post = self.Post
        Post.lam[:] = (1 - rho) * Post.lam + rho * lam

    def convertPostToNatural(self):
        ''' Convert current posterior params from common to natural form
        '''
        # Dirichlet common equivalent to natural here.
        pass

    def convertPostToCommon(self):
        ''' Convert current posterior params from natural to common form
        '''
        # Dirichlet common equivalent to natural here.
        pass

    def calcLogSoftEvMatrix_FromPost(self, Data, **kwargs):
        ''' Calculate expected log soft ev matrix under Post.

        Returns
        ------
        L : 2D array, size N x K
        '''
        ElogphiT = self.GetCached('E_logphiT', 'all')  # V x K
        doSparse1 = 'activeonlyLP' in kwargs and kwargs['activeonlyLP'] == 2
        doSparse2 = 'nnzPerRowLP' in kwargs and \
            kwargs['nnzPerRowLP'] > 0 and kwargs['nnzPerRowLP'] < self.K
        if doSparse2 and doSparse1:
            return dict(ElogphiT=ElogphiT)
        else:
            E_log_soft_ev = calcLogSoftEvMatrix_FromPost_Static(
                Data,
                DataAtomType=self.DataAtomType,
                ElogphiT=ElogphiT,
                **kwargs)
            return dict(E_log_soft_ev=E_log_soft_ev, ElogphiT=ElogphiT)

    def calcELBO_Memoized(self,
                          SS,
                          returnVec=0,
                          afterGlobalStep=False,
                          **kwargs):
        """ Calculate obsModel's objective using suff stats SS and Post.

        Args
        -------
        SS : bnpy SuffStatBag
        afterGlobalStep : boolean flag
            if True, ELBO is calculated assuming the global step just completed

        Returns
        -------
        obsELBO : scalar float
            Equal to E[ log p(x) + log p(phi) - log q(phi)]
        """
        elbo = np.zeros(SS.K)
        Post = self.Post
        Prior = self.Prior
        if not afterGlobalStep:
            Elogphi = self.GetCached('E_logphi', 'all')  # K x V
        for k in range(SS.K):
            elbo[k] = self.prior_cFunc - self.GetCached('cFunc', k)
            #elbo[k] = c_Diff(Prior.lam, Post.lam[k])
            if not afterGlobalStep:
                elbo[k] += np.inner(SS.WordCounts[k] + Prior.lam - Post.lam[k],
                                    Elogphi[k])
        if returnVec:
            return elbo
        return np.sum(elbo)

    def logh(self, Data):
        ''' Calculate reference measure for the multinomial distribution

        Returns
        -------
        logh : scalar float, log h(Data) = \sum_{n=1}^N \log [ C_n! / \prod_d C_{nd}! ]
        '''
        raise NotImplementedError('TODO')

    def getDatasetScale(self, SS, extraSS=None):
        ''' Get number of observed scalars in dataset from suff stats.

        Used for normalizing the ELBO so it has reasonable range.

        Returns
        ---------
        s : scalar positive integer
        '''
        if extraSS is None:
            return SS.SumWordCounts.sum()
        else:
            return SS.SumWordCounts.sum() - extraSS.SumWordCounts.sum()

    def calcCFuncForMergeComp(self, SS, kA=None, kB=None, tmpvec=None):
        ''' Compute cumulant function value directly from suff stats

        Returns
        -------
        cval : c_Func evaluated on SS[kA] + SS[kB] + Prior.lam
        '''
        if tmpvec is None:
            tmpvec = SS.WordCounts[kA] + SS.WordCounts[kB]
        else:
            np.add(SS.WordCounts[kA], SS.WordCounts[kB], out=tmpvec)
        tmpvec += self.Prior.lam
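        # Dirichlet cumulant function: c(lam) = gammaln(sum(lam)) - sum(gammaln(lam))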
        gammalnsum = gammaln(np.sum(tmpvec))
        return gammalnsum - np.sum(gammaln(tmpvec))

    def calcHardMergeGap(self, SS, kA, kB):
        ''' Calculate change in ELBO after a hard merge applied to this model

        Returns
        ---------
        gap : scalar real, indicates change in ELBO after merge of kA, kB
        '''
        #Prior = self.Prior
        #cPrior = c_Func(Prior.lam)
        cPrior = self.prior_cFunc

        Post = self.Post
        cA = c_Func(Post.lam[kA])
        cB = c_Func(Post.lam[kB])

        cAB = self.calcCFuncForMergeComp(SS, kA, kB)
        #lam = self.calcPostParamsForComp(SS, kA, kB)
        #cAB = c_Func(lam)
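        # Each comp contributes (prior_cFunc - cFunc) to the ELBO, so the gap is
        # (cPrior - cAB) - (cPrior - cA) - (cPrior - cB) = cA + cB - cPrior - cAB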
        return cA + cB - cPrior - cAB

    def calcHardMergeGap_AllPairs(self, SS):
        ''' Calculate change in ELBO for all candidate hard merge pairs

        Returns
        ---------
        Gap : 2D array, size K x K, upper-triangular entries non-zero
              Gap[j,k] : scalar change in ELBO after merge of k into j
        '''
        cPrior = self.prior_cFunc

        Post = self.Post
        c = np.zeros(SS.K)
        for k in range(SS.K):
            c[k] = c_Func(Post.lam[k])

        tmpvec = np.zeros(Post.D)
        Gap = np.zeros((SS.K, SS.K))
        for j in range(SS.K):
            for k in range(j + 1, SS.K):
                cjk = self.calcCFuncForMergeComp(SS, j, k, tmpvec=tmpvec)
                #lam = self.calcPostParamsForComp(SS, j, k)
                #oldcjk = c_Func(lam)
                #assert np.allclose(cjk, oldcjk)
                Gap[j, k] = c[j] + c[k] - cPrior - cjk
        return Gap

    def calcHardMergeGap_SpecificPairs(self, SS, PairList):
        ''' Calc change in ELBO for specific list of candidate hard merge pairs

        Returns
        ---------
        Gaps : 1D array, size L
              Gap[j] : scalar change in ELBO after merge of pair in PairList[j]
        '''
        Gaps = np.zeros(len(PairList))
        for ii, (kA, kB) in enumerate(PairList):
            Gaps[ii] = self.calcHardMergeGap(SS, kA, kB)
        return Gaps

    def calcLogMargLikForComp(self, SS, kA, kB=None, **kwargs):
        ''' Calc log marginal likelihood of data assigned to given component

        Args
        -------
        SS : bnpy suff stats object
        kA : integer ID of target component to compute likelihood for
        kB : (optional) integer ID of second component.
             If provided, we merge kA, kB into one component for calculation.
        Returns
        -------
        logM : scalar real
               logM = log p( data assigned to comp kA )
                      computed up to an additive constant
        '''
        return -1 * c_Func(self.calcPostParamsForComp(SS, kA, kB))

    def calcMargLik(self, SS):
        ''' Calc log marginal likelihood combining all comps, given suff stats

            Returns
            --------
            logM : scalar real
                   logM = \sum_{k=1}^K log p( data assigned to comp k | Prior)
        '''
        return self.calcMargLik_CFuncForLoop(SS)

    def calcMargLik_CFuncForLoop(self, SS):
        Prior = self.Prior
        logp = np.zeros(SS.K)
        for k in range(SS.K):
            lam = self.calcPostParamsForComp(SS, k)
            logp[k] = c_Diff(Prior.lam, lam)
        return np.sum(logp)

    def _cFunc(self, k=None):
        ''' Compute cached value of cumulant function at desired cluster index.

        Args
        ----
        k : int or str or None
            None or 'prior' uses the prior parameter
            otherwise, uses integer cluster index

        Returns
        -------
        cval : scalar real
        '''
        if k is None or k == 'prior':
            return c_Func(self.Prior.lam)
        elif k == 'all':
            raise NotImplementedError("TODO")
        else:
            return c_Func(self.Post.lam[k])

    def _E_logphi(self, k=None):
        if k is None or k == 'prior':
            lam = self.Prior.lam
            Elogphi = digamma(lam) - digamma(np.sum(lam))
        elif k == 'all':
            AMat = self.Post.lam
            Elogphi = digamma(AMat) \
                - digamma(np.sum(AMat, axis=1))[:, np.newaxis]
        else:
            Elogphi = digamma(self.Post.lam[k]) - \
                digamma(self.Post.lam[k].sum())
        return Elogphi

    def _E_logphiT(self, k=None):
        ''' Calculate transpose of topic-word matrix

            Important to make a copy of the matrix so it is C-contiguous,
            which leads to much faster matrix operations.

            Returns
            -------
            ElogphiT : 2D array, vocab_size x K
        '''
        if k is None or k == 'prior':
            lam = self.Prior.lam
            ElogphiT = digamma(lam) - digamma(np.sum(lam))
        elif k == 'all':
            ElogphiT = self.Post.lam.T.copy()
            digamma(ElogphiT, out=ElogphiT)
            digammaColSumVec = digamma(np.sum(self.Post.lam, axis=1))
            ElogphiT -= digammaColSumVec[np.newaxis, :]
        else:
            ElogphiT = digamma(self.Post.lam[k]) - \
                digamma(self.Post.lam[k].sum())
        assert ElogphiT.flags.c_contiguous
        return ElogphiT

    def getSerializableParamsForLocalStep(self):
        """ Get compact dict of params for local step.

        Returns
        -------
        Info : dict
        """
        return dict(inferType=self.inferType,
                    K=self.K,
                    DataAtomType=self.DataAtomType)

    def fillSharedMemDictForLocalStep(self, ShMem=None):
        """ Get dict of shared mem arrays needed for parallel local step.

        Returns
        -------
        ShMem : dict of RawArray objects
        """
        ElogphiT = self.GetCached('E_logphiT', 'all')  # V x K
        K = self.K
        if ShMem is None:
            ShMem = dict()
        if 'ElogphiT' not in ShMem:
            ShMem['ElogphiT'] = numpyToSharedMemArray(ElogphiT)
        else:
            ShMemView = sharedMemToNumpyArray(ShMem['ElogphiT'])
            assert ShMemView.shape >= ElogphiT.shape
            ShMemView[:, :K] = ElogphiT
        return ShMem

    def getLocalAndSummaryFunctionHandles(self):
        """ Get function handles for local step and summary step

        Useful for parallelized algorithms.

        Returns
        -------
        calcLocalParams : f handle
        calcSummaryStats : f handle
        """
        return calcLocalParams, calcSummaryStats

    def calcSmoothedMu(self, X, W=None):
        ''' Compute smoothed estimate of probability of each word.

        Returns
        -------
        Mu : 1D array, size D (aka vocab_size)
            Each entry is non-negative, whole vector sums to one.
        '''
        if X is None:
            Mu = self.Prior.lam.copy()
            Mu /= Mu.sum()
            return Mu

        if X.ndim > 1:
            if W is None:
                X = np.sum(X, axis=0)
            else:
                X = np.dot(W, X)
        assert X.ndim == 1
        assert X.size == self.D
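        # Posterior-mean smoothing: add prior pseudo-counts lam, then normalize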
        Mu = X + self.Prior.lam
        Mu /= Mu.sum()
        return Mu

    def calcSmoothedBregDiv(self,
                            X,
                            Mu,
                            W=None,
                            smoothFrac=0.0,
                            includeOnlyFastTerms=False,
                            DivDataVec=None,
                            returnDivDataVec=False,
                            return1D=False,
                            **kwargs):
        ''' Compute Bregman divergence between data X and clusters Mu.

        Smooth the data via update with prior parameters.

        Keyword Args
        ------------
        includeOnlyFastTerms : boolean
            if False, includes all terms in divergence calculation.
                Returns Div[n,:] guaranteed to be non-negative.
            if True, includes only terms that vary with cluster index k
                Returns Div[n,:] equal to divergence up to additive constant

        Returns
        -------
        Div : 2D array, N x K
            Div[n,k] = smoothed distance between X[n] and Mu[k]
        '''
        if X.ndim < 2:
            X = X[np.newaxis, :]
        assert X.ndim == 2
        N = X.shape[0]
        if not isinstance(Mu, list):
            Mu = (Mu, )
        K = len(Mu)
        # Compute Div array up to a per-row additive constant indep. of k
        Div = np.zeros((N, K))
        for k in range(K):
            Div[:, k] = -1 * np.dot(X, np.log(Mu[k]))

        # Compute contribution of prior smoothing
        if smoothFrac > 0:
            smoothVec = smoothFrac * self.Prior.lam
            for k in range(K):
                Div[:, k] -= np.sum(smoothVec * np.log(Mu[k]))
            # Equivalent to -1 * np.dot(MuX, np.log(Mu[k])),
            # but without allocating a new matrix MuX

        if not includeOnlyFastTerms:
            if DivDataVec is None:
                # Compute DivDataVec : 1D array of size N
                # This is the per-row additive constant indep. of k.
                # We do lots of steps in-place, to save memory.
                if smoothFrac > 0:
                    MuX = X + smoothVec
                else:
                    # Add small pos constant so that we never compute np.log(0)
                    MuX = X + 1e-100
                NX = MuX.sum(axis=1)
                # First block equivalent to
                # DivDataVec = -1 * NX * np.log(NX)
                DivDataVec = np.log(NX)
                DivDataVec *= -1 * NX

                # This next block is equivalent to:
                # >>> DivDataVec += np.sum(MuX * np.log(MuX), axis=1)
                # but uses in-place operations with faster numexpr library.
                NumericUtil.inplaceLog(MuX)
                logMuX = MuX
                if smoothFrac > 0:
                    DivDataVec += np.dot(logMuX, smoothVec)
                logMuX *= X
                XlogMuX = logMuX
                DivDataVec += np.sum(XlogMuX, axis=1)

            Div += DivDataVec[:, np.newaxis]

        # Apply per-atom weights to divergences.
        if W is not None:
            assert W.ndim == 1
            assert W.size == N
            Div *= W[:, np.newaxis]
        # Verify divergences are strictly non-negative
        if not includeOnlyFastTerms:
            minDiv = Div.min()
            if minDiv < 0:
                if minDiv < -1e-6:
                    raise AssertionError(
                        "Expected Div.min() to be positive or" + \
                        " indistinguishable from zero. Instead " + \
                        " minDiv=% .3e" % (minDiv))
                np.maximum(Div, 0, out=Div)
                minDiv = Div.min()
            assert minDiv >= 0
        if return1D:
            Div = Div[:, 0]
        if returnDivDataVec:
            return Div, DivDataVec
        return Div

    def calcBregDivFromPrior(self, Mu, smoothFrac=0.0):
        ''' Compute Bregman divergence between Mu and prior mean.

        Returns
        -------
        Div : 1D array, size K
            Div[k] = distance between Mu[k] and priorMu
        '''
        if not isinstance(Mu, list):
            Mu = (Mu, )
        K = len(Mu)

        priorMu = self.Prior.lam / self.Prior.lam.sum()
        priorN = (1 - smoothFrac) * (self.Prior.lam[0] / priorMu[0])

        Div = np.zeros(K)
        for k in range(K):
            Div[k] = np.sum(priorMu * np.log(priorMu / Mu[k]))
        return priorN * Div
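
Before the next example, a minimal self-contained sketch (not part of the
original listing) of the Dirichlet conjugate updates the class above
implements; the sizes, prior scale, and toy counts below are invented purely
for illustration.

import numpy as np
from scipy.special import digamma, gammaln

K, D = 3, 5                                  # invented toy sizes
prior_lam = np.full(D, 0.1)                  # assumed symmetric Dirichlet prior
WordCounts = np.random.RandomState(0).poisson(
    2.0, size=(K, D)).astype(float)          # invented per-cluster counts

# Conjugate update, as in calcPostParams: posterior params = counts + prior
lam = WordCounts + prior_lam[np.newaxis, :]

# Expected log-probabilities under the posterior, as in _E_logphi
Elogphi = digamma(lam) - digamma(lam.sum(axis=1))[:, np.newaxis]

# Dirichlet cumulant function, as in calcCFuncForMergeComp
cvals = gammaln(lam.sum(axis=1)) - gammaln(lam).sum(axis=1)
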
Example #10
class DiagGaussObsModel(AbstractObsModel):
    ''' Diagonal Gaussian data generation model for real vectors.

    Attributes for Prior (Normal-Wishart)
    --------
    nu : float
        degrees of freedom
    beta : 1D array, size D
        scale parameters that set mean of parameter sigma
    m : 1D array, size D
        mean of the parameter mu
    kappa : float
        scalar precision on parameter mu

    Attributes for k-th component of EstParams (EM point estimates)
    ---------
    mu[k] : 1D array, size D
    sigma[k] : 1D array, size D

    Attributes for k-th component of Post (VB parameter)
    ---------
    nu[k] : float
    beta[k] : 1D array, size D
    m[k] : 1D array, size D
    kappa[k] : float

    '''
    def __init__(self,
                 inferType='EM',
                 D=0,
                 min_covar=None,
                 Data=None,
                 **PriorArgs):
        ''' Initialize bare obsmodel with valid prior hyperparameters.

        Resulting object lacks either EstParams or Post,
        which must be created separately (see init_global_params).
        '''
        if Data is not None:
            self.D = Data.dim
        else:
            self.D = int(D)
        self.K = 0
        self.inferType = inferType
        self.min_covar = min_covar
        self.Prior = createParamBagForPrior(Data, D=D, **PriorArgs)
        self.Cache = dict()

    def get_mean_for_comp(self, k=None):
        if hasattr(self, 'EstParams'):
            return self.EstParams.mu[k]
        elif k is None or k == 'prior':
            return self.Prior.m
        else:
            return self.Post.m[k]

    def get_covar_mat_for_comp(self, k=None):
        if hasattr(self, 'EstParams'):
            return np.diag(self.EstParams.sigma[k])
        elif k is None or k == 'prior':
            return self._E_CovMat()
        else:
            return self._E_CovMat(k)

    def get_name(self):
        return 'DiagGauss'

    def get_info_string(self):
        return 'Gaussian with diagonal covariance.'

    def get_info_string_prior(self):
        return getStringSummaryOfPrior(self.Prior)

    def setEstParams(self,
                     obsModel=None,
                     SS=None,
                     LP=None,
                     Data=None,
                     mu=None,
                     sigma=None,
                     Sigma=None,
                     **kwargs):
        ''' Create EstParams ParamBag with fields mu, Sigma
        '''
        self.ClearCache()
        if obsModel is not None:
            self.EstParams = obsModel.EstParams.copy()
            self.K = self.EstParams.K
            return

        if LP is not None and Data is not None:
            SS = self.calcSummaryStats(Data, None, LP)

        if SS is not None:
            self.updateEstParams(SS)
        else:
            K = mu.shape[0]
            if Sigma is not None:
                assert Sigma.ndim == 3
                sigma = np.empty((Sigma.shape[0], Sigma.shape[1]))
                for k in range(K):
                    sigma[k] = np.diag(Sigma[k])
            assert sigma.ndim == 2
            self.EstParams = ParamBag(K=K, D=mu.shape[1])
            self.EstParams.setField('mu', mu, dims=('K', 'D'))
            self.EstParams.setField('sigma', sigma, dims=('K', 'D'))
        self.K = self.EstParams.K

    def setEstParamsFromPost(self, Post):
        ''' Convert from Post (nu, beta, m, kappa) to EstParams (mu, Sigma),
             each EstParam is set to its posterior mean.
        '''
        self.EstParams = ParamBag(K=Post.K, D=self.D)
        mu = Post.m.copy()
        sigma = Post.beta / (Post.nu - 2)[:, np.newaxis]
        self.EstParams.setField('mu', mu, dims=('K', 'D'))
        self.EstParams.setField('sigma', sigma, dims=('K', 'D'))
        self.K = self.EstParams.K

    def setPostFactors(self,
                       obsModel=None,
                       SS=None,
                       LP=None,
                       Data=None,
                       **param_kwargs):
        ''' Set attribute Post to provided values.
        '''
        self.ClearCache()
        if obsModel is not None:
            if hasattr(obsModel, 'Post'):
                self.Post = obsModel.Post.copy()
                self.K = self.Post.K
            else:
                self.setPostFromEstParams(obsModel.EstParams)
            return

        if LP is not None and Data is not None:
            SS = self.calcSummaryStats(Data, None, LP)

        if SS is not None:
            self.updatePost(SS)
        else:
            if 'D' not in param_kwargs:
                param_kwargs['D'] = self.D
            self.Post = packParamBagForPost(**param_kwargs)
        self.K = self.Post.K

    def setPostFromEstParams(self, EstParams, Data=None, N=None):
        ''' Set attribute Post based on values in EstParams.
        '''
        K = EstParams.K
        D = EstParams.D
        if Data is not None:
            N = Data.nObsTotal

        N = np.asarray(N, dtype=float)
        if N.ndim == 0:
            N = float(N) / K * np.ones(K)

        nu = self.Prior.nu + N
        beta = (nu - 2)[:, np.newaxis] * EstParams.sigma
        m = EstParams.mu.copy()
        kappa = self.Prior.kappa + N

        self.Post = ParamBag(K=K, D=D)
        self.Post.setField('nu', nu, dims=('K'))
        self.Post.setField('beta', beta, dims=('K', 'D'))
        self.Post.setField('m', m, dims=('K', 'D'))
        self.Post.setField('kappa', kappa, dims=('K'))
        self.K = self.Post.K

    def calcSummaryStats(self, Data, SS, LP, **kwargs):
        ''' Calculate summary statistics for given dataset and local parameters

        Returns
        --------
        SS : SuffStatBag object, with K components.
        '''
        return calcSummaryStats(Data, SS, LP, **kwargs)

    def forceSSInBounds(self, SS):
        ''' Force count vector N to remain positive
        '''
        np.maximum(SS.N, 0, out=SS.N)

    def incrementSS(self, SS, k, x):
        SS.x[k] += x
        SS.xx[k] += np.square(x)

    def decrementSS(self, SS, k, x):
        SS.x[k] -= x
        SS.xx[k] -= np.square(x)

    def calcLogSoftEvMatrix_FromEstParams(self, Data, **kwargs):
        ''' Compute log soft evidence matrix for Dataset under EstParams.

        Returns
        ---------
        L : 2D array, N x K
        '''
        K = self.EstParams.K
        L = np.zeros((Data.nObs, K))
        for k in range(K):
            L[:, k] = - 0.5 * self.D * LOGTWOPI \
                - 0.5 * np.sum(np.log(self.EstParams.sigma[k])) \
                - 0.5 * self._mahalDist_EstParam(Data.X, k)
        return L

    def _mahalDist_EstParam(self, X, k):
        ''' Calculate distance to every row of matrix X

            Args
            -------
            X : 2D array, size N x D

            Returns
            ------
            dist : 1D array, size N
        '''
        Xdiff = X - self.EstParams.mu[k]
        np.square(Xdiff, out=Xdiff)
        dist = np.sum(Xdiff / self.EstParams.sigma[k], axis=1)
        return dist

    def updateEstParams_MaxLik(self, SS):
        ''' Update attribute EstParams for all comps given suff stats.

        Update uses the maximum likelihood objective for point estimation.

        Post Condition
        ---------
        Attributes K and EstParams updated in-place.
        '''
        self.ClearCache()
        if not hasattr(self, 'EstParams') or self.EstParams.K != SS.K:
            self.EstParams = ParamBag(K=SS.K, D=SS.D)

        mu = SS.x / SS.N[:, np.newaxis]
        sigma = self.min_covar \
            + SS.xx / SS.N[:, np.newaxis] \
            - np.square(mu)

        self.EstParams.setField('mu', mu, dims=('K', 'D'))
        self.EstParams.setField('sigma', sigma, dims=('K', 'D'))
        self.K = SS.K

    def updateEstParams_MAP(self, SS):
        ''' Update attribute EstParams for all comps given suff stats.

        Update uses the MAP objective for point estimation.

        Post Condition
        ---------
        Attributes K and EstParams updated in-place.
        '''
        self.ClearCache()
        if not hasattr(self, 'EstParams') or self.EstParams.K != SS.K:
            self.EstParams = ParamBag(K=SS.K, D=SS.D)

        Prior = self.Prior
        nu = Prior.nu + SS.N
        kappa = Prior.kappa + SS.N
        PB = Prior.beta + Prior.kappa * np.square(Prior.m)

        m = np.empty((SS.K, SS.D))
        beta = np.empty((SS.K, SS.D))
        for k in range(SS.K):
            km_x = Prior.kappa * Prior.m + SS.x[k]
            m[k] = 1.0 / kappa[k] * km_x
            beta[k] = PB + SS.xx[k] - 1.0 / kappa[k] * np.square(km_x)

        mu, sigma = MAPEstParams_inplace(nu, beta, m, kappa)
        self.EstParams.setField('mu', mu, dims=('K', 'D'))
        self.EstParams.setField('sigma', sigma, dims=('K', 'D'))
        self.K = SS.K

    def updatePost(self, SS):
        ''' Update attribute Post for all comps given suff stats.

        Update uses the variational objective.

        Post Condition
        ---------
        Attributes K and Post updated in-place.
        '''
        self.ClearCache()
        if not hasattr(self, 'Post') or self.Post.K != SS.K:
            self.Post = ParamBag(K=SS.K, D=SS.D)

        nu, beta, m, kappa = self.calcPostParams(SS)
        self.Post.setField('nu', nu, dims=('K'))
        self.Post.setField('kappa', kappa, dims=('K'))
        self.Post.setField('m', m, dims=('K', 'D'))
        self.Post.setField('beta', beta, dims=('K', 'D'))
        self.K = SS.K

    def calcPostParams(self, SS):
        ''' Calc posterior parameters for all comps given suff stats.

        Returns
        --------
        nu : 1D array, size K
        beta : 2D array, size K x D
        m : 2D array, size K x D
        kappa : 1D array, size K
        '''
        Prior = self.Prior
        nu = Prior.nu + SS.N
        kappa = Prior.kappa + SS.N
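        # m is a convex combination of the prior mean and the empirical mean,
        # weighted by the prior precision kappa and the observed counts N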
        m = (Prior.kappa * Prior.m + SS.x) / kappa[:, np.newaxis]
        beta = Prior.beta + SS.xx \
            + Prior.kappa * np.square(Prior.m) \
            - kappa[:, np.newaxis] * np.square(m)
        return nu, beta, m, kappa

    def calcPostParamsForComp(self, SS, kA, kB=None):
        ''' Calc posterior parameters for specific comp given suff stats.

        Returns
        --------
        nu : positive scalar
        beta : 1D array, size D
        m : 1D array, size D
        kappa : positive scalar
        '''
        return calcPostParamsFromSSForComp(SS, kA, kB)
        '''
        if kB is None:
            SN = SS.N[kA]
            Sx = SS.x[kA]
            Sxx = SS.xx[kA]
        else:
            SN = SS.N[kA] + SS.N[kB]
            Sx = SS.x[kA] + SS.x[kB]
            Sxx = SS.xx[kA] + SS.xx[kB]
        Prior = self.Prior
        nu = Prior.nu + SN
        kappa = Prior.kappa + SN
        m = (Prior.kappa * Prior.m + Sx) / kappa
        beta = Prior.beta + Sxx \
            + Prior.kappa * np.square(Prior.m) \
            - kappa * np.square(m)
        return nu, beta, m, kappa
        '''

    def updatePost_stochastic(self, SS, rho):
        ''' Update attribute Post for all comps given suff stats

        Update uses the stochastic variational formula.

        Post Condition
        ---------
        Attributes K and Post updated in-place.
        '''
        assert hasattr(self, 'Post')
        assert self.Post.K == SS.K
        self.ClearCache()

        self.convertPostToNatural()
        nu, b, km, kappa = self.calcNaturalPostParams(SS)
        Post = self.Post
        Post.nu[:] = (1 - rho) * Post.nu + rho * nu
        Post.b[:] = (1 - rho) * Post.b + rho * b
        Post.km[:] = (1 - rho) * Post.km + rho * km
        Post.kappa[:] = (1 - rho) * Post.kappa + rho * kappa
        self.convertPostToCommon()

    def calcNaturalPostParams(self, SS):
        ''' Calc natural posterior params for all comps given suff stats.


        Returns
        --------
        nu : 1D array, size K
        b : 2D array, size K x D
        km : 2D array, size K x D
        kappa : 1D array, size K
        '''
        Prior = self.Prior
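        # In natural form every parameter is simply (prior + suff stats), which
        # is what makes the convex blend in updatePost_stochastic well-defined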
        nu = Prior.nu + SS.N
        kappa = Prior.kappa + SS.N
        km = Prior.kappa * Prior.m + SS.x
        b = Prior.beta + Prior.kappa * np.square(Prior.m) + SS.xx
        return nu, b, km, kappa

    def convertPostToNatural(self):
        ''' Convert current posterior params from common to natural form
        '''
        Post = self.Post
        assert hasattr(Post, 'nu')
        assert hasattr(Post, 'kappa')
        km = Post.m * Post.kappa[:, np.newaxis]
        b = Post.beta + (np.square(km) / Post.kappa[:, np.newaxis])
        Post.setField('km', km, dims=('K', 'D'))
        Post.setField('b', b, dims=('K', 'D'))

    def convertPostToCommon(self):
        ''' Convert current posterior params from natural to common form
        '''
        Post = self.Post
        assert hasattr(Post, 'nu')
        assert hasattr(Post, 'kappa')
        if hasattr(Post, 'm'):
            Post.m[:] = Post.km / Post.kappa[:, np.newaxis]
        else:
            m = Post.km / Post.kappa[:, np.newaxis]
            Post.setField('m', m, dims=('K', 'D'))

        if hasattr(Post, 'beta'):
            Post.beta[:] = Post.b - \
                (np.square(Post.km) / Post.kappa[:, np.newaxis])
        else:
            beta = Post.b - (np.square(Post.km) / Post.kappa[:, np.newaxis])
            Post.setField('beta', beta, dims=('K', 'D'))

    def calcLogSoftEvMatrix_FromPost(self, Data, **kwargs):
        return calcLogSoftEvMatrix_FromPost(Data, Post=self.Post)

    def zzz_calcLogSoftEvMatrix_FromPost(self, Data, **kwargs):
        ''' Calculate expected log soft ev matrix under Post.

        Returns
        ------
        L : 2D array, size N x K
        '''
        K = self.Post.K
        L = np.zeros((Data.nObs, K))
        for k in range(K):
            L[:, k] = - 0.5 * self.D * LOGTWOPI \
                + 0.5 * np.sum(self.GetCached('E_logL', k)) \
                - 0.5 * self._mahalDist_Post(Data.X, k)
        return L

    def _mahalDist_Post(self, X, k):
        ''' Calc expected Mahalanobis distance from comp k to each data atom

            Returns
            --------
            distvec : 1D array, size nObs
                   distvec[n] gives E[ \Lam (x-\mu)^2 ] for comp k
        '''
        Xdiff = X - self.Post.m[k]
        np.square(Xdiff, out=Xdiff)
        dist = np.dot(Xdiff, self.Post.nu[k] / self.Post.beta[k])
        dist += self.D / self.Post.kappa[k]
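        # The D / kappa term is the extra expected distance due to posterior
        # uncertainty in the mean: each dim contributes E[lam / (kappa * lam)] = 1 / kappa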
        return dist

    def calcELBO_Memoized(self, SS, returnVec=0, afterMStep=False, **kwargs):
        return calcELBOFromSSAndPost(SS=SS, Post=self.Post, Prior=self.Prior)

    def _zzz_calcELBO_Memoized(self,
                               SS,
                               returnVec=0,
                               afterMStep=False,
                               **kwargs):
        """ Calculate obsModel's objective using suff stats SS and Post.

        Args
        -------
        SS : bnpy SuffStatBag
        afterMStep : boolean flag
            if 1, elbo calculated assuming M-step just completed

        Returns
        -------
        obsELBO : scalar float
            Equal to E[ log p(x) + log p(phi) - log q(phi)]
        """
        elbo = np.zeros(SS.K)
        Post = self.Post
        Prior = self.Prior
        for k in range(SS.K):
            elbo[k] = c_Diff(
                Prior.nu,
                Prior.beta,
                Prior.m,
                Prior.kappa,
                Post.nu[k],
                Post.beta[k],
                Post.m[k],
                Post.kappa[k],
            )
            if not afterMStep:
                aDiff = SS.N[k] + Prior.nu - Post.nu[k]
                bDiff = SS.xx[k] + Prior.beta \
                    + Prior.kappa * np.square(Prior.m) \
                    - Post.beta[k] \
                    - Post.kappa[k] * np.square(Post.m[k])
                cDiff = SS.x[k] + Prior.kappa * Prior.m \
                    - Post.kappa[k] * Post.m[k]
                dDiff = SS.N[k] + Prior.kappa - Post.kappa[k]
                elbo[k] += 0.5 * aDiff * np.sum(self._E_logL(k)) \
                    - 0.5 * np.inner(bDiff, self._E_L(k)) \
                    + np.inner(cDiff, self.GetCached('E_Lmu', k)) \
                    - 0.5 * dDiff * np.sum(self.GetCached('E_muLmu', k))

        elbo += -(0.5 * SS.D * LOGTWOPI) * SS.N
        if returnVec:
            return elbo
        return elbo.sum()

    def getDatasetScale(self, SS):
        ''' Get number of observed scalars in dataset from suff stats.

        Used for normalizing the ELBO so it has reasonable range.

        Returns
        ---------
        s : scalar positive integer
        '''
        return SS.N.sum() * SS.D

    def calcHardMergeGap(self, SS, kA, kB):
        ''' Calculate change in ELBO after a hard merge applied to this model

        Returns
        ---------
        gap : scalar real, indicates change in ELBO after merge of kA, kB
        '''
        gap, _, _ = calcHardMergeGapForPair(SS=SS,
                                            Prior=self.Prior,
                                            Post=self.Post,
                                            kA=kA,
                                            kB=kB)
        return gap

    def calcHardMergeGap_AllPairs(self, SS):
        ''' Calculate change in ELBO for all possible hard merge pairs

        Returns
        ---------
        Gap : 2D array, size K x K, upper-triangular entries non-zero
              Gap[j,k] : scalar change in ELBO after merge of k into j
        '''
        Gap2D = np.zeros((SS.K, SS.K))
        cPrior = None
        cPost_K = [None for k in range(SS.K)]
        for kA in range(SS.K):
            for kB in range(kA + 1, SS.K):
                Gap2D[kA, kB], cPost_K, cPrior = calcHardMergeGapForPair(
                    SS=SS,
                    Post=self.Post,
                    Prior=self.Prior,
                    kA=kA,
                    kB=kB,
                    cPrior=cPrior,
                    cPost_K=cPost_K)
        return Gap2D

    def calcHardMergeGap_SpecificPairs(self, SS, PairList):
        ''' Calc change in ELBO for specific list of candidate hard merge pairs

        Returns
        ---------
        Gaps : 1D array, size L
              Gap[j] : scalar change in ELBO after merge of pair in PairList[j]
        '''
        Gaps = np.zeros(len(PairList))
        cPrior = None
        cPost_K = [None for k in range(SS.K)]
        for ii, (kA, kB) in enumerate(PairList):
            Gaps[ii], cPost_K, cPrior = calcHardMergeGapForPair(
                SS=SS,
                Post=self.Post,
                Prior=self.Prior,
                kA=kA,
                kB=kB,
                cPrior=cPrior,
                cPost_K=cPost_K)
        return Gaps

    def calcLogMargLikForComp(self, SS, kA, kB=None, **kwargs):
        ''' Calc log marginal likelihood of data assigned to given component

        Args
        -------
        SS : bnpy suff stats object
        kA : integer ID of target component to compute likelihood for
        kB : (optional) integer ID of second component.
             If provided, we merge kA, kB into one component for calculation.
        Returns
        -------
        logM : scalar real
               logM = log p( data assigned to comp kA | Prior )
                      computed up to an additive constant
        '''
        nu, B, m, kappa = self.calcPostParamsForComp(SS, kA, kB)
        return -1 * c_Func(nu, B, m, kappa)

    def calcMargLik(self, SS):
        ''' Calc log marginal likelihood additively across all comps.

        Returns
        --------
        logM : scalar real
               logM = \sum_{k=1}^K log p( data assigned to comp k | Prior)
        '''
        return self.calcMargLik_CFuncForLoop(SS)

    def calcMargLik_CFuncForLoop(self, SS):
        Prior = self.Prior
        logp = np.zeros(SS.K)
        for k in range(SS.K):
            nu, beta, m, kappa = self.calcPostParamsForComp(SS, k)
            logp[k] = c_Diff(Prior.nu, Prior.beta, Prior.m, Prior.kappa, nu,
                             beta, m, kappa)
        return np.sum(logp) - 0.5 * np.sum(SS.N) * LOGTWOPI

    def calcPredProbVec_Unnorm(self, SS, x):
        ''' Calculate K-vector of positive entries \propto p( x | SS[k] )
        '''
        return self._calcPredProbVec_Fast(SS, x)

    def _calcPredProbVec_Naive(self, SS, x):
        nu, beta, m, kappa = self.calcPostParams(SS)
        pSS = SS.copy()
        pSS.N += 1
        pSS.x += x
        pSS.xx += np.square(x)
        pnu, pbeta, pm, pkappa = self.calcPostParams(pSS)
        logp = np.zeros(SS.K)
        for k in range(SS.K):
            logp[k] = c_Diff(nu[k], beta[k], m[k], kappa[k], pnu[k], pbeta[k],
                             pm[k], pkappa[k])
        return np.exp(logp - np.max(logp))

    def _calcPredProbVec_Fast(self, SS, x):
        p = np.zeros(SS.K)
        nu, beta, m, kappa = self.calcPostParams(SS)
        kbeta = beta
        kbeta *= ((kappa + 1) / kappa)[:, np.newaxis]
        base = np.square(x - m)
        base /= kbeta
        base += 1
        # logp : 2D array, size K x D
        logp = (-0.5 * (nu + 1))[:, np.newaxis] * np.log(base)
        logp += (gammaln(0.5 * (nu + 1)) - gammaln(0.5 * nu))[:, np.newaxis]
        logp -= 0.5 * np.log(kbeta)
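        # logp[k, d] is the Student-t posterior predictive log density for
        # dimension d, up to constants independent of k: nu[k] degrees of
        # freedom, location m[k, d], scale^2 = kbeta[k, d] / nu[k]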

        # p : 1D array, size K
        p = np.sum(logp, axis=1)
        p -= np.max(p)
        np.exp(p, out=p)
        return p

    def _calcPredProbVec_ForLoop(self, SS, x):
        ''' For-loop version
        '''
        p = np.zeros(SS.K)
        for k in range(SS.K):
            nu, beta, m, kappa = self.calcPostParamsForComp(SS, k)
            kbeta = (kappa + 1) / kappa * beta
            base = np.square(x - m)
            base /= kbeta
            base += 1
            p_k = np.exp(gammaln(0.5 * (nu + 1)) - gammaln(0.5 * nu)) \
                * 1.0 / np.sqrt(kbeta) \
                * base ** (-0.5 * (nu + 1))
            p[k] = np.prod(p_k)
        return p

    def _Verify_calcPredProbVec(self, SS, x):
        ''' Verify that the predictive prob vector is correct,
              by comparing 3 very different implementations
        '''
        pA = self._calcPredProbVec_Fast(SS, x)
        pB = self._calcPredProbVec_Naive(SS, x)
        pC = self._calcPredProbVec_ForLoop(SS, x)
        pA /= np.sum(pA)
        pB /= np.sum(pB)
        pC /= np.sum(pC)
        assert np.allclose(pA, pB)
        assert np.allclose(pA, pC)

    def _E_CovMat(self, k=None):
        ''' Get expected value of Sigma under specified distribution.

        Returns
        --------
        E[ Sigma ] : 2D array, size DxD
        '''
        return np.diag(self._E_Cov(k))

    def _E_Cov(self, k=None):
        ''' Get expected value of sigma vector under specified distribution.

        Returns
        --------
        E[ sigma^2 ] : 1D array, size D
        '''
        if k is None:
            nu = self.Prior.nu
            beta = self.Prior.beta
        else:
            nu = self.Post.nu[k]
            beta = self.Post.beta[k]
        return beta / (nu - 2)

    def _E_logL(self, k=None):
        '''
        Returns
        -------
        E_logL : 1D array, size D
        '''
        if k == 'all':
            # retVec : K x D
            retVec = LOGTWO - np.log(self.Post.beta.copy())  # no strided!
            retVec += digamma(0.5 * self.Post.nu)[:, np.newaxis]
            return retVec
        elif k is None:
            nu = self.Prior.nu
            beta = self.Prior.beta
        else:
            nu = self.Post.nu[k]
            beta = self.Post.beta[k]
        return LOGTWO - np.log(beta) + digamma(0.5 * nu)

    def _E_L(self, k=None):
        '''
        Returns
        --------
        EL : 1D array, size D
        '''
        if k is None:
            nu = self.Prior.nu
            beta = self.Prior.beta
        else:
            nu = self.Post.nu[k]
            beta = self.Post.beta[k]
        return nu / beta

    def _E_Lmu(self, k=None):
        '''
        Returns
        --------
        ELmu : 1D array, size D
        '''
        if k is None:
            nu = self.Prior.nu
            beta = self.Prior.beta
            m = self.Prior.m
        else:
            nu = self.Post.nu[k]
            beta = self.Post.beta[k]
            m = self.Post.m[k]
        return (nu / beta) * m

    def _E_muLmu(self, k=None):
        ''' Calc expectation E[lam * mu^2]

        Returns
        --------
        EmuLmu : 1D array, size D
        '''
        if k is None:
            nu = self.Prior.nu
            kappa = self.Prior.kappa
            m = self.Prior.m
            beta = self.Prior.beta
        else:
            nu = self.Post.nu[k]
            kappa = self.Post.kappa[k]
            m = self.Post.m[k]
            beta = self.Post.beta[k]
        return 1.0 / kappa + (nu / beta) * (m * m)

    def getSerializableParamsForLocalStep(self):
        """ Get compact dict of params for local step.

        Returns
        -------
        Info : dict
        """
        if self.inferType == 'EM':
            raise NotImplementedError('TODO')
        return dict(
            inferType=self.inferType,
            K=self.K,
            D=self.D,
        )

    def fillSharedMemDictForLocalStep(self, ShMem=None):
        """ Get dict of shared mem arrays needed for parallel local step.

        Returns
        -------
        ShMem : dict of RawArray objects
        """
        if ShMem is None:
            ShMem = dict()
        if 'nu' in ShMem:
            fillSharedMemArray(ShMem['nu'], self.Post.nu)
            fillSharedMemArray(ShMem['kappa'], self.Post.kappa)
            fillSharedMemArray(ShMem['m'], self.Post.m)
            fillSharedMemArray(ShMem['beta'], self.Post.beta)
            fillSharedMemArray(ShMem['E_logL'], self._E_logL('all'))

        else:
            ShMem['nu'] = numpyToSharedMemArray(self.Post.nu)
            ShMem['kappa'] = numpyToSharedMemArray(self.Post.kappa)
            # Post.m is strided, so we need to copy it to do shared mem.
            ShMem['m'] = numpyToSharedMemArray(self.Post.m.copy())
            ShMem['beta'] = numpyToSharedMemArray(self.Post.beta.copy())
            ShMem['E_logL'] = numpyToSharedMemArray(self._E_logL('all'))

        return ShMem

    def getLocalAndSummaryFunctionHandles(self):
        """ Get function handles for local step and summary step

        Useful for parallelized algorithms.

        Returns
        -------
        calcLocalParams : f handle
        calcSummaryStats : f handle
        """
        return calcLocalParams, calcSummaryStats

    def calcSmoothedMu(self, X, W=None):
        ''' Compute smoothed estimate of mean of statistic xxT.

        Args
        ----
        X : 2D array, size N x D

        Returns
        -------
        Mu_1 : 1D array, size D
            Expected value of Var[ X[n,d] ]
        Mu_2 : 1D array, size D
            Expected value of Mean[ X[n] ]
        '''
        if X is None:
            Mu1 = self.Prior.beta / self.Prior.nu
            Mu2 = self.Prior.m
            return Mu1, Mu2

        if X.ndim == 1:
            X = X[np.newaxis, :]
        N, D = X.shape
        # Compute suff stats
        if W is None:
            sum_wxx = np.sum(np.square(X), axis=0)
            sum_wx = np.sum(X, axis=0)
            sum_w = X.shape[0]
        else:
            W = as1D(W)
            sum_wxx = np.dot(W, np.square(X))
            sum_wx = np.dot(W, X)
            sum_w = np.sum(W)

        post_kappa = self.Prior.kappa + sum_w
        post_m = (self.Prior.m * self.Prior.kappa + sum_wx) / post_kappa
        Mu_2 = post_m

        prior_kmm = self.Prior.kappa * (self.Prior.m * self.Prior.m)
        post_kmm = post_kappa * (post_m * post_m)
        post_beta = sum_wxx + self.Prior.beta + prior_kmm - post_kmm
        Mu_1 = post_beta / (self.Prior.nu + sum_w)

        assert Mu_1.ndim == 1
        assert Mu_1.shape == (D, )
        assert Mu_2.shape == (D, )
        return Mu_1, Mu_2

    def calcSmoothedBregDiv(self,
                            X,
                            Mu,
                            W=None,
                            eps=1e-10,
                            smoothFrac=0.0,
                            includeOnlyFastTerms=False,
                            DivDataVec=None,
                            returnDivDataVec=False,
                            return1D=False,
                            **kwargs):
        ''' Compute Bregman divergence between data X and clusters Mu.

        Smooth the data via update with prior parameters.

        Args
        ----
        X : 2D array, size N x D
        Mu : list of size K, or tuple

        Returns
        -------
        Div : 2D array, N x K
            Div[n,k] = smoothed distance between X[n] and Mu[k]
        '''
        # Parse X
        if X.ndim < 2:
            X = X[np.newaxis, :]
        assert X.ndim == 2
        N = X.shape[0]
        D = X.shape[1]
        # Parse Mu
        if isinstance(Mu, tuple):
            Mu = [Mu]
        assert isinstance(Mu, list)
        K = len(Mu)
        assert Mu[0][0].size == D
        assert Mu[0][1].size == D

        prior_x = self.Prior.m
        prior_varx = self.Prior.beta / (self.Prior.nu)
        VarX = eps * prior_varx

        Div = np.zeros((N, K))
        for k in range(K):
            muVar_k = Mu[k][0]
            muMean_k = Mu[k][1]
            logdet_MuVar_k = np.sum(np.log(muVar_k))
            squareDiff_X_Mu_k = np.square(X - muMean_k)
            tr_k = np.sum((VarX + squareDiff_X_Mu_k) / \
                muVar_k[np.newaxis,:], axis=1)
            Div[:,k] = \
                + 0.5 * logdet_MuVar_k \
                + 0.5 * tr_k
        # Only enter here when computing Div exactly;
        # if Div is only needed up to an additive constant, skip this part.
        if not includeOnlyFastTerms:
            if DivDataVec is None:
                # Compute DivDataVec : 1D array of size N
                # This is the per-row additive constant indep. of k.
                DivDataVec = -0.5 * D * np.ones(N)
                logdet_VarX = np.sum(np.log(VarX))
                DivDataVec -= 0.5 * logdet_VarX

            Div += DivDataVec[:, np.newaxis]
        # Apply per-atom weights to divergences.
        if W is not None:
            assert W.ndim == 1
            assert W.size == N
            Div *= W[:, np.newaxis]
        # Verify divergences are strictly non-negative
        if not includeOnlyFastTerms:
            minDiv = Div.min()
            if minDiv < 0:
                if minDiv < -1e-6:
                    raise AssertionError(
                        "Expected Div.min() to be positive or" + \
                        " indistinguishable from zero. Instead " + \
                        " minDiv=% .3e" % (minDiv))
                np.maximum(Div, 0, out=Div)
                minDiv = Div.min()
            assert minDiv >= 0
        if return1D:
            Div = Div[:, 0]
        if returnDivDataVec:
            return Div, DivDataVec
        return Div

    def calcBregDivFromPrior(self, Mu, smoothFrac=0.0):
        ''' Compute Bregman divergence between Mu and prior mean.

        Returns
        -------
        Div : 1D array, size K
            Div[k] = distance between Mu[k] and priorMu
        '''
        assert isinstance(Mu, list)
        K = len(Mu)
        assert K >= 1
        assert Mu[0][0].ndim == 1
        assert Mu[0][1].ndim == 1
        D = Mu[0][0].size
        assert D == Mu[0][1].size

        priorMuVar = self.Prior.beta / self.Prior.nu
        priorMuMean = self.Prior.m

        priorN_ZMG = (1 - smoothFrac) * self.Prior.nu
        priorN_FVG = (1 - smoothFrac) * self.Prior.kappa

        Div_ZMG = np.zeros(K)  # zero-mean gaussian
        Div_FVG = np.zeros(K)  # fixed variance gaussian

        logdet_priorMuVar = np.sum(np.log(priorMuVar))
        for k in range(K):
            MuVar_k = Mu[k][0]
            MuMean_k = Mu[k][1]

            logdet_MuVar_k = np.sum(np.log(MuVar_k))

            Div_ZMG[k] = 0.5 * logdet_MuVar_k + \
                - 0.5 * logdet_priorMuVar + \
                + 0.5 * np.sum(priorMuVar / MuVar_k) + \
                - 0.5
            squareDiff = np.square(priorMuMean - MuMean_k)
            Div_FVG[k] = 0.5 * np.sum(squareDiff / MuVar_k)

        return priorN_ZMG * Div_ZMG + priorN_FVG * Div_FVG
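
Before the next example, another minimal sketch, separate from the listing: the
per-dimension conjugate update that DiagGaussObsModel.calcPostParams performs,
run on invented prior hyperparameters and sufficient statistics.

import numpy as np

rng = np.random.RandomState(1)
K, D = 2, 4                       # invented toy sizes

# Invented prior hyperparameters, mirroring the Prior fields above
prior_nu = D + 2.0
prior_kappa = 1.0
prior_m = np.zeros(D)
prior_beta = np.ones(D)

# Invented sufficient statistics: counts N, sums x, sums of squares xx
N = 10.0 * rng.rand(K)
x = rng.randn(K, D)
xx = np.square(x) + rng.rand(K, D)

# Same update formulas as calcPostParams
nu = prior_nu + N
kappa = prior_kappa + N
m = (prior_kappa * prior_m + x) / kappa[:, np.newaxis]
beta = prior_beta + xx \
    + prior_kappa * np.square(prior_m) \
    - kappa[:, np.newaxis] * np.square(m)
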
Example #11
class ZeroMeanGaussObsModel(AbstractObsModel):
    ''' Zero-mean, full-covariance Gaussian model for real vectors.

    Attributes for Prior (Wishart)
    --------
    nu : float
        degrees of freedom
    B : 2D array, size D x D
        scale parameters that set mean of parameter Sigma

    Attributes for k-th component of EstParams (EM point estimates)
    ---------
    Sigma[k] : 2D array, size DxD

    Attributes for k-th component of Post (VB parameter)
    ---------
    nu[k] : float
    B[k] : 2D array, size D x D

    '''
    def __init__(self,
                 inferType='EM',
                 D=0,
                 min_covar=None,
                 Data=None,
                 **PriorArgs):
        ''' Initialize bare obsmodel with valid prior hyperparameters.

        Resulting object lacks either EstParams or Post,
        which must be created separately (see init_global_params).
        '''
        if Data is not None:
            self.D = Data.dim
        else:
            self.D = int(D)
        self.K = 0
        self.inferType = inferType
        self.min_covar = min_covar
        self.createPrior(Data, **PriorArgs)
        self.Cache = dict()

    def createPrior(self, Data, nu=0, B=None, ECovMat=None, sF=1.0, **kwargs):
        ''' Initialize Prior ParamBag attribute.

        Post Condition
        ------
        Prior expected covariance matrix set to match provided value.
        '''
        D = self.D
        nu = np.maximum(nu, D + 2)
        if B is None:
            if ECovMat is None or isinstance(ECovMat, str):
                ECovMat = createECovMatFromUserInput(D, Data, ECovMat, sF)
            B = ECovMat * (nu - D - 1)
        else:
            B = as2D(B)
        self.Prior = ParamBag(K=0, D=D)
        self.Prior.setField('nu', nu, dims=None)
        self.Prior.setField('B', B, dims=('D', 'D'))

    def get_mean_for_comp(self, k):
        return np.zeros(self.D)

    def get_covar_mat_for_comp(self, k=None):
        if hasattr(self, 'EstParams'):
            return self.EstParams.Sigma[k]
        elif k is None or k == 'prior':
            return self._E_CovMat()
        else:
            return self._E_CovMat(k)

    def get_name(self):
        return 'ZeroMeanGauss'

    def get_info_string(self):
        return 'Gaussian with fixed zero means, full covariance.'

    def get_info_string_prior(self):
        msg = 'Wishart on prec matrix Lam\n'
        if self.D > 2:
            sfx = ' ...'
        else:
            sfx = ''
        S = self._E_CovMat()[:2, :2]
        msg += 'E[ CovMat[k] ] = \n'
        msg += str(S) + sfx
        msg = msg.replace('\n', '\n  ')
        return msg

    def setEstParams(self,
                     obsModel=None,
                     SS=None,
                     LP=None,
                     Data=None,
                     Sigma=None,
                     **kwargs):
        ''' Create EstParams ParamBag with fields Sigma
        '''
        self.ClearCache()
        if obsModel is not None:
            self.EstParams = obsModel.EstParams.copy()
            self.K = self.EstParams.K
            return

        if LP is not None and Data is not None:
            SS = self.calcSummaryStats(Data, None, LP)

        if SS is not None:
            self.updateEstParams(SS)
        else:
            K = Sigma.shape[0]
            self.EstParams = ParamBag(K=K, D=self.D)
            self.EstParams.setField('Sigma', Sigma, dims=('K', 'D', 'D'))
        self.K = self.EstParams.K

    def setEstParamsFromPost(self, Post):
        ''' Convert from Post (nu, B) to EstParams (Sigma),
             each EstParam is set to its posterior mean.
        '''
        D = Post.D
        self.EstParams = ParamBag(K=Post.K, D=D)
        Sigma = Post.B / (Post.nu - D - 1)[:, np.newaxis, np.newaxis]
        self.EstParams.setField('Sigma', Sigma, dims=('K', 'D', 'D'))
        self.K = self.EstParams.K

    def setPostFactors(self,
                       obsModel=None,
                       SS=None,
                       LP=None,
                       Data=None,
                       nu=0,
                       B=0,
                       **kwargs):
        ''' Set attribute Post to provided values.
        '''
        self.ClearCache()
        if obsModel is not None:
            if hasattr(obsModel, 'Post'):
                self.Post = obsModel.Post.copy()
                self.K = self.Post.K
            else:
                self.setPostFromEstParams(obsModel.EstParams)
            return

        if LP is not None and Data is not None:
            SS = self.calcSummaryStats(Data, None, LP)

        if SS is not None:
            self.updatePost(SS)
        else:
            K = B.shape[0]
            self.Post = ParamBag(K=K, D=self.D)
            self.Post.setField('nu', as1D(nu), dims=('K'))
            self.Post.setField('B', B, dims=('K', 'D', 'D'))
        self.K = self.Post.K

    def setPostFromEstParams(self, EstParams, Data=None, N=None):
        ''' Set attribute Post based on values in EstParams.
        '''
        K = EstParams.K
        D = EstParams.D
        if Data is not None:
            N = Data.nObsTotal
        N = np.asarray(N, dtype=float)
        if N.ndim == 0:
            N = float(N) / K * np.ones(K)

        nu = self.Prior.nu + N
        B = np.zeros((K, D, D))
        for k in range(K):
            B[k] = (nu[k] - D - 1) * EstParams.Sigma[k]
        self.Post = ParamBag(K=K, D=D)
        self.Post.setField('nu', nu, dims=('K'))
        self.Post.setField('B', B, dims=('K', 'D', 'D'))
        self.K = K

    def calcSummaryStats(self, Data, SS, LP, **kwargs):
        ''' Calculate summary statistics for given dataset and local parameters

        Returns
        --------
        SS : SuffStatBag object, with K components.
        '''
        return calcSummaryStats(Data, SS, LP, **kwargs)

    def calcSummaryStatsForContigBlock(self,
                                       Data,
                                       SS=None,
                                       a=None,
                                       b=None,
                                       **kwargs):
        ''' Calculate summary statistics for specific block of dataset

        Returns
        --------
        SS : SuffStatBag object, with K components.
        '''
        SS = SuffStatBag(K=1, D=Data.dim)

        # Expected count
        SS.setField('N', (b - a) * np.ones(1, dtype=np.float64), dims='K')

        # Expected outer-product
        xxT = dotATA(Data.X[a:b])[np.newaxis, :, :]
        SS.setField('xxT', xxT, dims=('K', 'D', 'D'))
        return SS

    def forceSSInBounds(self, SS):
        ''' Force count vector N to remain positive
        '''
        np.maximum(SS.N, 0, out=SS.N)

    def incrementSS(self, SS, k, x):
        SS.xxT[k] += np.outer(x, x)

    def decrementSS(self, SS, k, x):
        SS.xxT[k] -= np.outer(x, x)

    def calcLogSoftEvMatrix_FromEstParams(self, Data, **kwargs):
        ''' Compute log soft evidence matrix for Dataset under EstParams.

        Returns
        ---------
        L : 2D array, N x K
        '''
        K = self.EstParams.K
        L = np.empty((Data.nObs, K))
        for k in range(K):
            L[:, k] = - 0.5 * self.D * LOGTWOPI \
                - 0.5 * self._logdetSigma(k)  \
                - 0.5 * self._mahalDist_EstParam(Data.X, k)
        return L

    def _mahalDist_EstParam(self, X, k):
        ''' Calc Mahalanobis distance from comp k to every row of X

        Args
        ---------
        X : 2D array, size N x D
        k : integer ID of comp

        Returns
        ----------
        dist : 1D array, size N
        '''
        cholSigma_k = self.GetCached('cholSigma', k)
        Q = scipy.linalg.solve_triangular(cholSigma_k,
                                          X.T,
                                          lower=True,
                                          check_finite=False)
        Q *= Q
        return np.sum(Q, axis=0)

    def _cholSigma(self, k):
        ''' Calculate lower cholesky decomposition of Sigma for comp k

        Returns
        --------
        L : 2D array, size D x D, lower triangular
            Sigma = np.dot(L, L.T)
        '''
        return scipy.linalg.cholesky(self.EstParams.Sigma[k], lower=1)

    def _logdetSigma(self, k):
        ''' Calculate log determinant of EstParam.Sigma for comp k

        Returns
        ---------
        logdet : scalar real
        '''
        return 2 * np.sum(np.log(np.diag(self.GetCached('cholSigma', k))))
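
Both _mahalDist_EstParam and _logdetSigma rely on the same two Cholesky identities: with Sigma = L L^T, one triangular solve yields the quadratic form x^T Sigma^{-1} x, and the log-determinant is twice the log of L's diagonal. A quick standalone verification:

import numpy as np
import scipy.linalg

A = np.random.randn(4, 4)
Sigma = np.dot(A, A.T) + 4 * np.eye(4)          # random SPD matrix
L = scipy.linalg.cholesky(Sigma, lower=True)

# Quadratic form: solve L q = x, then q^T q == x^T Sigma^{-1} x.
x = np.random.randn(4)
q = scipy.linalg.solve_triangular(L, x, lower=True)
assert np.allclose(np.inner(q, q), np.dot(x, np.linalg.solve(Sigma, x)))

# Log-determinant from the Cholesky diagonal.
assert np.allclose(2 * np.sum(np.log(np.diag(L))),
                   np.linalg.slogdet(Sigma)[1])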

    def updateEstParams_MaxLik(self, SS):
        ''' Update attribute EstParams for all comps given suff stats.

        Update uses the maximum likelihood objective for point estimation.

        Post Condition
        ---------
        Attributes K and EstParams updated in-place.
        '''
        self.ClearCache()
        if not hasattr(self, 'EstParams') or self.EstParams.K != SS.K:
            self.EstParams = ParamBag(K=SS.K, D=SS.D)

        minCovMat = self.min_covar * np.eye(SS.D)
        covMat = np.tile(minCovMat, (SS.K, 1, 1))
        for k in xrange(SS.K):
            covMat[k] += SS.xxT[k] / SS.N[k]
        self.EstParams.setField('Sigma', covMat, dims=('K', 'D', 'D'))
        self.K = SS.K

    def updateEstParams_MAP(self, SS):
        ''' Update attribute EstParams for all comps given suff stats.

        Update uses the MAP objective for point estimation.

        Post Condition
        ---------
        Attributes K and EstParams updated in-place.
        '''
        self.ClearCache()
        if not hasattr(self, 'EstParams') or self.EstParams.K != SS.K:
            self.EstParams = ParamBag(K=SS.K, D=SS.D)
        Prior = self.Prior
        nu = Prior.nu + SS.N
        B = np.empty((SS.K, SS.D, SS.D))
        for k in xrange(SS.K):
            B[k] = Prior.B + SS.xxT[k]

        Sigma = MAPEstParams_inplace(nu, B)
        self.EstParams.setField('Sigma', Sigma, dims=('K', 'D', 'D'))
        self.K = SS.K

    def updatePost(self, SS):
        ''' Update attribute Post for all comps given suff stats.

        Update uses the variational objective.

        Post Condition
        ---------
        Attributes K and Post updated in-place.
        '''
        self.ClearCache()
        if not hasattr(self, 'Post') or self.Post.K != SS.K:
            self.Post = ParamBag(K=SS.K, D=SS.D)

        # Use local aliases 'Prior' and 'Post' for readability.
        Prior = self.Prior
        Post = self.Post

        Post.setField('nu', Prior.nu + SS.N, dims=('K'))
        B = np.empty((SS.K, SS.D, SS.D))
        for k in xrange(SS.K):
            B[k] = Prior.B + SS.xxT[k]
        Post.setField('B', B, dims=('K', 'D', 'D'))
        self.K = SS.K

    def calcPostParams(self, SS):
        ''' Calc posterior parameters for all comps given suff stats

        Returns
        --------
        nu : 1D array, size K
        B : 3D array, size K x D x D, each B[k] is symmetric and pos. def.
        '''
        Prior = self.Prior
        nu = Prior.nu + SS.N
        B = Prior.B + SS.xxT
        return nu, B
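
updatePost and calcPostParams implement the standard conjugate Wishart update: each component adds its expected count and scatter to the prior, nu_k = nu_0 + N_k and B_k = B_0 + sum_n r_nk x_n x_n^T. A toy standalone version with illustrative numbers:

import numpy as np

K, D = 2, 3
prior_nu, prior_B = 5.0, np.eye(D)
N = np.array([10.0, 4.0])                        # plays the role of SS.N
xxT = np.stack([8 * np.eye(D), 3 * np.eye(D)])   # plays the role of SS.xxT
post_nu = prior_nu + N                           # shape (K,)
post_B = prior_B + xxT                           # broadcasts to (K, D, D)
assert post_B.shape == (K, D, D)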

    def calcPostParamsForComp(self, SS, kA=None, kB=None):
        ''' Calc params (nu, B, m, kappa) for specific comp, given suff stats

        Returns
        --------
        nu : positive scalar
        B : 2D array, size D x D, symmetric and positive definite
        '''
        if kB is None:
            SN = SS.N[kA]
            SxxT = SS.xxT[kA]
        else:
            SN = SS.N[kA] + SS.N[kB]
            SxxT = SS.xxT[kA] + SS.xxT[kB]
        Prior = self.Prior
        nu = Prior.nu + SN
        B = Prior.B + SxxT
        return nu, B

    def updatePost_stochastic(self, SS, rho):
        ''' Update attribute Post for all comps given suff stats

        Update uses the stochastic variational formula.

        Post Condition
        ---------
        Attributes K and Post updated in-place.
        '''
        assert hasattr(self, 'Post')
        assert self.Post.K == SS.K
        self.ClearCache()

        nu, B = self.calcPostParams(SS)
        Post = self.Post
        Post.nu[:] = (1 - rho) * Post.nu + rho * nu
        Post.B[:] = (1 - rho) * Post.B + rho * B
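
The stochastic step is a convex blend of the current and freshly computed parameters with step size rho; since (nu, B) are (an affine reparameterization of) the Wishart natural parameters, the blend can be applied field-wise. The same arithmetic in isolation:

import numpy as np

rho = 0.1
old_nu = np.array([12.0, 7.0])
new_nu = np.array([15.0, 6.0])
blended_nu = (1 - rho) * old_nu + rho * new_nu
# Equivalent "step toward the target" form of the same update.
assert np.allclose(blended_nu, old_nu + rho * (new_nu - old_nu))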

    def convertPostToNatural(self):
        ''' Convert current posterior params from common to natural form

        Here, the Wishart common form is already equivalent to the natural form
        '''
        pass

    def convertPostToCommon(self):
        ''' Convert current posterior params from natural to common form

        Here, the Wishart common form is already equivalent to the natural form
        '''
        pass

    def calcLogSoftEvMatrix_FromPost(self, Data, **kwargs):
        ''' Calculate expected log soft ev matrix under Post.

        Returns
        ------
        L : 2D array, size N x K
        '''
        K = self.Post.K
        L = np.zeros((Data.nObs, K))
        for k in xrange(K):
            L[:, k] = - 0.5 * self.D * LOGTWOPI \
                + 0.5 * self.GetCached('E_logdetL', k)  \
                - 0.5 * self._mahalDist_Post(Data.X, k)
        return L

    def _mahalDist_Post(self, X, k):
        ''' Calc expected Mahalanobis distance from comp k to each data atom

            Returns
            --------
            distvec : 1D array, size nObs
                   distvec[n] gives E[ x_n^T \Lam x_n ] under comp k (zero-mean model)
        '''
        cholB_k = self.GetCached('cholB', k)
        Q = scipy.linalg.solve_triangular(cholB_k,
                                          X.T,
                                          lower=True,
                                          check_finite=False)
        Q *= Q
        return self.Post.nu[k] * np.sum(Q, axis=0)

    def calcELBO_Memoized(self, SS, returnVec=0, afterMStep=False, **kwargs):
        """ Calculate obsModel's objective using suff stats SS and Post.

        Args
        -------
        SS : bnpy SuffStatBag
        afterMStep : boolean flag
            if 1, elbo calculated assuming M-step just completed

        Returns
        -------
        obsELBO : scalar float
            Equal to E[ log p(x) + log p(phi) - log q(phi)]
        """
        elbo = np.zeros(SS.K)
        Post = self.Post
        Prior = self.Prior
        for k in xrange(SS.K):
            elbo[k] = c_Diff(
                Prior.nu,
                self.GetCached('logdetB'),
                self.D,
                Post.nu[k],
                self.GetCached('logdetB', k),
            )
            if not afterMStep:
                aDiff = SS.N[k] + Prior.nu - Post.nu[k]
                bDiff = SS.xxT[k] + Prior.B - Post.B[k]
                elbo[k] += 0.5 * aDiff * self.GetCached('E_logdetL', k) \
                    - 0.5 * self._trace__E_L(bDiff, k)
        if returnVec:
            return elbo - (0.5 * SS.D * LOGTWOPI) * SS.N
        return elbo.sum() - 0.5 * np.sum(SS.N) * SS.D * LOGTWOPI
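
c_Func and c_Diff are bnpy helpers not shown in this listing; they accept either a precomputed log-determinant plus the dimension D (as in the cached calls above) or the full B matrix. For this zero-mean model, c is presumably the Wishart cumulant, i.e. the negative log normalizer of a Wishart(nu, B^{-1}) density. A hedged sketch of the matrix form, using only standard scipy:

import numpy as np
from scipy.special import multigammaln

def wishart_c_func(nu, B):
    # Negative log normalizer of Wishart(nu, scale=B^{-1}):
    #   c(nu, B) = -(nu D / 2) log 2 + (nu / 2) log|B| - log Gamma_D(nu / 2)
    # Sketch of what c_Func presumably computes; not bnpy's actual helper.
    D = B.shape[0]
    logdetB = np.linalg.slogdet(B)[1]
    return (-0.5 * nu * D * np.log(2)
            + 0.5 * nu * logdetB
            - multigammaln(0.5 * nu, D))

# c_Diff(nu0, B0, D, nu1, B1) would then be the prior-minus-posterior gap:
print(wishart_c_func(5.0, np.eye(3)) - wishart_c_func(9.0, 2 * np.eye(3)))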

    def getDatasetScale(self, SS):
        ''' Get number of observed scalars in dataset from suff stats.

        Used for normalizing the ELBO so it has reasonable range.

        Returns
        ---------
        s : scalar positive integer
        '''
        return SS.N.sum() * SS.D

    def calcHardMergeGap(self, SS, kA, kB):
        ''' Calculate change in ELBO after a hard merge applied to this model

            Returns
            ---------
            gap : scalar real, indicates change in ELBO after merge of kA, kB
        '''
        Post = self.Post
        Prior = self.Prior
        cPrior = c_Func(Prior.nu, self.GetCached('logdetB'), self.D)

        cA = c_Func(Post.nu[kA], self.GetCached('logdetB', kA), self.D)
        cB = c_Func(Post.nu[kB], self.GetCached('logdetB', kB), self.D)

        nu, B = self.calcPostParamsForComp(SS, kA, kB)
        cAB = c_Func(nu, B)
        return cA + cB - cPrior - cAB

    def calcHardMergeGap_AllPairs(self, SS):
        ''' Calculate change in ELBO for all candidate hard merge pairs

            Returns
            ---------
            Gap : 2D array, size K x K, upper-triangular entries non-zero
                  Gap[j,k] : scalar change in ELBO after merge of k into j
        '''
        Post = self.Post
        Prior = self.Prior
        cPrior = c_Func(Prior.nu, self.GetCached('logdetB'), self.D)

        c = np.zeros(SS.K)
        for k in xrange(SS.K):
            c[k] = c_Func(Post.nu[k], self.GetCached('logdetB', k), self.D)

        Gap = np.zeros((SS.K, SS.K))
        for j in xrange(SS.K):
            for k in xrange(j + 1, SS.K):
                nu, B = self.calcPostParamsForComp(SS, j, k)
                cjk = c_Func(nu, B)
                Gap[j, k] = c[j] + c[k] - cPrior - cjk
        return Gap

    def calcHardMergeGap_SpecificPairs(self, SS, PairList):
        ''' Calc change in ELBO for specific list of hard merge pairs

        Returns
        ---------
        Gaps : 1D array, size L
              Gap[j] : scalar change in ELBO after merge of pair in PairList[j]
        '''
        Gaps = np.zeros(len(PairList))
        for ii, (kA, kB) in enumerate(PairList):
            Gaps[ii] = self.calcHardMergeGap(SS, kA, kB)
        return Gaps

    def calcLogMargLikForComp(self, SS, kA, kB=None, **kwargs):
        ''' Calc log marginal likelihood of data assigned to given component

        Args
        -------
        SS : bnpy suff stats object
        kA : integer ID of target component to compute likelihood for
        kB : (optional) integer ID of second component.
             If provided, we merge kA, kB into one component for calculation.
        Returns
        -------
        logM : scalar real
               logM = log p( data assigned to comp kA )
                      computed up to an additive constant
        '''
        nu, B = self.calcPostParamsForComp(SS, kA, kB)
        return -1 * c_Func(nu, B)

    def calcMargLik(self, SS):
        ''' Calc log marginal likelihood given suff stats

        Returns
        --------
        logM : scalar real
               logM = \sum_{k=1}^K log p( data assigned to comp k | Prior)
        '''
        return self.calcMargLik_CFuncForLoop(SS)

    def calcMargLik_CFuncForLoop(self, SS):
        ''' Calc log marginal likelihood via per-comp cumulant differences.
        '''
        Prior = self.Prior
        logp = np.zeros(SS.K)
        for k in xrange(SS.K):
            nu, B = self.calcPostParamsForComp(SS, k)
            logp[k] = c_Diff(Prior.nu, Prior.B, self.D, nu, B)
        return np.sum(logp) - 0.5 * np.sum(SS.N) * LOGTWOPI

    def calcPredProbVec_Unnorm(self, SS, x):
        ''' Calculate predictive probability that each comp assigns to vector x

        Returns
        --------
        p : 1D array, size K, all entries positive
            p[k] \propto p( x | SS for comp k)
        '''
        return self._calcPredProbVec_Fast(SS, x)

    def _calcPredProbVec_cFunc(self, SS, x):
        nu, B = self.calcPostParams(SS)
        pSS = SS.copy()
        pSS.N += 1
        pSS.xxT += np.outer(x, x)[np.newaxis, :, :]
        pnu, pB = self.calcPostParams(pSS)
        logp = np.zeros(SS.K)
        for k in xrange(SS.K):
            logp[k] = c_Diff(nu[k], B[k], self.D, pnu[k], pB[k])
        return np.exp(logp - np.max(logp))

    def _calcPredProbVec_Fast(self, SS, x):
        nu, B = self.calcPostParams(SS)
        logp = np.zeros(SS.K)
        p = logp  # Alias: logp is exponentiated in-place below and returned
        for k in xrange(SS.K):
            cholB_k = scipy.linalg.cholesky(B[k], lower=1)
            logdetB_k = 2 * np.sum(np.log(np.diag(cholB_k)))
            mVec = np.linalg.solve(cholB_k, x)
            mDist_k = np.inner(mVec, mVec)
            logp[k] = -0.5 * logdetB_k - 0.5 * \
                (nu[k] + 1) * np.log(1.0 + mDist_k)
        logp += gammaln(0.5 * (nu + 1)) - gammaln(0.5 * (nu + 1 - self.D))
        logp -= np.max(logp)
        np.exp(logp, out=p)
        return p
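
The fast path leans on the rank-one determinant lemma, |B + x x^T| = |B| (1 + x^T B^{-1} x): adding one observation shifts the log-determinant by exactly log(1 + mDist_k), which is why only that term and a ratio of gammaln factors appear above. A quick standalone check of the lemma:

import numpy as np
import scipy.linalg

A = np.random.randn(3, 3)
B = np.dot(A, A.T) + 3 * np.eye(3)
x = np.random.randn(3)

lhs = np.linalg.slogdet(B + np.outer(x, x))[1]
cholB = scipy.linalg.cholesky(B, lower=True)
m = scipy.linalg.solve_triangular(cholB, x, lower=True)
rhs = np.linalg.slogdet(B)[1] + np.log(1.0 + np.inner(m, m))
assert np.allclose(lhs, rhs)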

    def _Verify_calcPredProbVec(self, SS, x):
        ''' Verify that the predictive prob vector is correct,
              by comparing two independent implementations
        '''
        pA = self._calcPredProbVec_Fast(SS, x)
        pC = self._calcPredProbVec_cFunc(SS, x)
        pA /= np.sum(pA)
        pC /= np.sum(pC)
        assert np.allclose(pA, pC)

    def _E_CovMat(self, k=None):
        ''' Expected covariance E[Lam^{-1}] = B / (nu - D - 1),
        under the prior (k=None) or the posterior for comp k.
        '''
        if k is None:
            B = self.Prior.B
            nu = self.Prior.nu
        else:
            B = self.Post.B[k]
            nu = self.Post.nu[k]
        return B / (nu - self.D - 1)

    def _cholB(self, k=None):
        if k == 'all':
            retArr = np.zeros((self.K, self.D, self.D))
            for kk in xrange(self.K):
                retArr[kk] = self.GetCached('cholB', kk)
            return retArr
        elif k is None:
            B = self.Prior.B
        else:
            B = self.Post.B[k]
        return scipy.linalg.cholesky(B, lower=True)

    def _logdetB(self, k=None):
        cholB = self.GetCached('cholB', k)
        return 2 * np.sum(np.log(np.diag(cholB)))

    def _E_logdetL(self, k=None):
        dvec = np.arange(1, self.D + 1, dtype=np.float)
        if k == 'all':
            dvec = dvec[:, np.newaxis]
            retVec = self.D * LOGTWO * np.ones(self.K)
            for kk in xrange(self.K):
                retVec[kk] -= self.GetCached('logdetB', kk)
            nuT = self.Post.nu[np.newaxis, :]
            retVec += np.sum(digamma(0.5 * (nuT + 1 - dvec)), axis=0)
            return retVec
        elif k is None:
            nu = self.Prior.nu
        else:
            nu = self.Post.nu[k]
        return self.D * LOGTWO \
            - self.GetCached('logdetB', k) \
            + np.sum(digamma(0.5 * (nu + 1 - dvec)))
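
The closed form used here is E[log|Lam|] = D log 2 - log|B| + sum_{d=1}^{D} psi((nu + 1 - d) / 2) for Lam ~ Wishart(nu, B^{-1}). A Monte Carlo sanity check with scipy (illustrative sizes; the tolerance is loose on purpose):

import numpy as np
from scipy.special import digamma
from scipy.stats import wishart

D, nu = 3, 10.0
B = np.eye(D) + 0.25 * np.ones((D, D))
dvec = np.arange(1, D + 1)
closed_form = (D * np.log(2) - np.linalg.slogdet(B)[1]
               + np.sum(digamma(0.5 * (nu + 1 - dvec))))

samples = wishart.rvs(df=nu, scale=np.linalg.inv(B), size=5000,
                      random_state=0)
mc_estimate = np.mean([np.linalg.slogdet(S)[1] for S in samples])
assert abs(mc_estimate - closed_form) < 0.1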

    def _trace__E_L(self, Smat, k=None):
        ''' Calc trace( E[Lam] Smat ) = nu * trace( B^{-1} Smat ).
        '''
        if k is None:
            nu = self.Prior.nu
            B = self.Prior.B
        else:
            nu = self.Post.nu[k]
            B = self.Post.B[k]
        return nu * np.trace(np.linalg.solve(B, Smat))

    def getSmoothedMuForComp(self, k):
        ''' Compute smoothed estimate of expected outer-product matrix for cluster k

        Returns
        -------
        Mu_k : 2D array, size D x D
        '''
        #return self.Post.B[k] / self.Post.nu[k]
        return self.get_covar_mat_for_comp(k)

    def calcSmoothedMu(self, X, W=None):
        ''' Compute smoothed estimate of mean of statistic xxT.

        Args
        ----
        X : 2D array, size N x D

        Returns
        -------
        Mu : 2D array, size D x D
        '''
        Prior_nu = self.Prior.nu - self.D - 1
        # Prior_nu = self.Prior.nu

        if X is None:
            Mu = self.Prior.B / (Prior_nu)
            return Mu
        if X.ndim == 1:
            X = X[np.newaxis, :]
        N, D = X.shape
        # Compute suff stats
        if W is None:
            sum_wxxT = np.dot(X.T, X)
            sum_w = X.shape[0]
        else:
            W = as1D(W)
            wX = np.sqrt(W)[:, np.newaxis] * X
            sum_wxxT = np.dot(wX.T, wX)
            sum_w = np.sum(W)
        Mu = (self.Prior.B + sum_wxxT) / (Prior_nu + sum_w)
        assert Mu.ndim == 2
        assert Mu.shape == (
            D,
            D,
        )
        return Mu

    def calcSmoothedBregDiv(self,
                            X,
                            Mu,
                            W=None,
                            smoothFrac=0.0,
                            eps=1e-10,
                            includeOnlyFastTerms=False,
                            DivDataVec=None,
                            returnDivDataVec=False,
                            return1D=False,
                            **kwargs):
        ''' Compute Bregman divergence between data X and clusters Mu.

        Smooth the data via update with prior parameters.

        Returns
        -------
        Div : 2D array, N x K
            Div[n,k] = smoothed distance between X[n] and Mu[k]
        '''
        if X.ndim < 2:
            X = X[np.newaxis, :]
        assert X.ndim == 2
        N = X.shape[0]
        D = X.shape[1]
        if not isinstance(Mu, list):
            Mu = [Mu]
        K = len(Mu)
        assert Mu[0].ndim == 2
        assert Mu[0].shape[0] == D
        assert Mu[0].shape[1] == D

        if smoothFrac == 0:
            smoothMu = eps * self.Prior.B / (self.Prior.nu - self.D - 1)
            smoothNu = 1.0  # + eps ??
        else:
            smoothMu = self.Prior.B
            smoothNu = 1 + self.Prior.nu - self.D - 1
        Div = np.zeros((N, K))
        for k in xrange(K):
            chol_Mu_k = np.linalg.cholesky(Mu[k])
            logdet_Mu_k = 2.0 * np.sum(np.log(np.diag(chol_Mu_k)))
            xxTInvMu_k = scipy.linalg.solve_triangular(chol_Mu_k,
                                                       X.T,
                                                       lower=True,
                                                       check_finite=False)
            xxTInvMu_k *= xxTInvMu_k
            tr_xxTInvMu_k = np.sum(xxTInvMu_k, axis=0) / smoothNu
            Div[:,k] = 0.5 * logdet_Mu_k + \
                0.5 * tr_xxTInvMu_k
            if smoothFrac > 0:
                Div[:, k] += 0.5 * np.trace(np.linalg.solve(Mu[k], smoothMu))

        if not includeOnlyFastTerms:
            if DivDataVec is None:
                # Compute DivDataVec : 1D array of size N
                # This is the per-row additive constant indep. of k.
                # We do lots of steps in-place, to save memory.

                # FAST VERSION: Use the matrix determinant lemma
                chol_SM = np.linalg.cholesky(smoothMu / smoothNu)
                logdet_SM = 2.0 * np.sum(np.log(np.diag(chol_SM)))
                xxTInvSM = scipy.linalg.solve_triangular(chol_SM,
                                                         X.T,
                                                         lower=True)
                xxTInvSM *= xxTInvSM
                tr_xxTSM = np.sum(xxTInvSM, axis=0) / smoothNu
                assert tr_xxTSM.size == N
                DivDataVec = np.log(1.0 + tr_xxTSM)
                DivDataVec *= -0.5
                DivDataVec += -0.5 * D - 0.5 * logdet_SM
                # SLOW VERSION: use a naive for loop
                # DivDataVecS = -0.5 * D * np.ones(N)
                # for n in xrange(N):
                #    s, logdet_xxT_n = np.linalg.slogdet(
                #        (np.outer(X[n], X[n]) + smoothMu) / smoothNu)
                #    DivDataVecS[n] -= 0.5 * s * logdet_xxT_n

            Div += DivDataVec[:, np.newaxis]

        # Apply per-atom weights to divergences.
        if W is not None:
            assert W.ndim == 1
            assert W.size == N
            Div *= W[:, np.newaxis]
        # Verify divergences are strictly non-negative
        if not includeOnlyFastTerms:
            minDiv = Div.min()
            if minDiv < 0:
                if minDiv < -1e-6:
                    raise AssertionError(
                        "Expected Div.min() to be positive or"
                        " indistinguishable from zero."
                        " Instead minDiv=% .3e" % (minDiv))
                np.maximum(Div, 0, out=Div)
                minDiv = Div.min()
            assert minDiv >= 0
        if return1D:
            Div = Div[:, 0]
        if returnDivDataVec:
            return Div, DivDataVec
        return Div
        # Unreachable reference implementation, kept for documentation only:
        # the naive per-atom loop that the vectorized code above replaces.
        # logdet_xxT = np.zeros(N)
        # tr_xxTInvMu = np.zeros((N, K))
        # for n in xrange(N):
        #     if smoothFrac == 0:
        #         smooth_xxT = np.outer(X[n], X[n]) + eps * priorMu
        #     else:
        #         smooth_xxT = np.outer(X[n], X[n]) + self.Prior.B
        #         smooth_xxT /= (1.0 + self.Prior.nu)
        #     s, logdet = np.linalg.slogdet(smooth_xxT)
        #     logdet_xxT[n] = s * logdet
        #     for k in xrange(K):
        #         tr_xxTInvMu[n, k] = np.trace(
        #             np.linalg.solve(Mu[k], smooth_xxT))
        #
        # Div = np.zeros((N, K))
        # for k in xrange(K):
        #     chol_Mu_k = np.linalg.cholesky(Mu[k])
        #     logdet_Mu_k = 2.0 * np.sum(np.log(np.diag(chol_Mu_k)))
        #     Div[:, k] = -0.5 * D - 0.5 * logdet_xxT + \
        #         0.5 * logdet_Mu_k + \
        #         0.5 * tr_xxTInvMu[:, k]
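
Up to the per-row constant collected in DivDataVec, both the vectorized path and the commented reference loop evaluate the LogDet (Stein) divergence between the smoothed outer product of each row and Mu[k], the Bregman divergence generated by -log|.|. Its direct standalone form:

import numpy as np

def logdet_breg_div(S, Mu):
    # LogDet (Stein) divergence between SPD matrices:
    #   D(S, Mu) = 0.5 * (tr(Mu^{-1} S) - log|Mu^{-1} S| - dim)
    # Zero iff S == Mu; this is the quantity calcSmoothedBregDiv assembles.
    dim = S.shape[0]
    MuinvS = np.linalg.solve(Mu, S)
    return 0.5 * (np.trace(MuinvS)
                  - np.linalg.slogdet(MuinvS)[1]
                  - dim)

A = np.random.randn(3, 3)
S = np.dot(A, A.T) + np.eye(3)
assert abs(logdet_breg_div(S, S)) < 1e-10
assert logdet_breg_div(S, 2 * np.eye(3)) > 0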

    def calcBregDivFromPrior(self, Mu, smoothFrac=0.0):
        ''' Compute Bregman divergence between Mu and prior mean.

        Returns
        -------
        Div : 1D array, size K
            Div[k] = distance between Mu[k] and priorMu
        '''
        if not isinstance(Mu, list):
            Mu = [Mu]
        K = len(Mu)
        D = Mu[0].shape[0]
        assert D == Mu[0].shape[1]

        priorMu = self.Prior.B / self.Prior.nu
        priorN = (1 - smoothFrac) * (self.Prior.nu)

        Div = np.zeros(K)
        s, logdet = np.linalg.slogdet(priorMu)
        logdet_prior = s * logdet
        for k in xrange(K):
            chol_Mu_k = np.linalg.cholesky(Mu[k])
            logdet_Mu_k = 2.0 * np.sum(np.log(np.diag(chol_Mu_k)))
            tr_PriorInvMu_k = np.trace(np.linalg.solve(Mu[k], priorMu))
            Div[k] = -0.5 * logdet_prior + 0.5 * logdet_Mu_k + \
                0.5 * tr_PriorInvMu_k - 0.5 * D
        return priorN * Div

    def getSerializableParamsForLocalStep(self):
        """ Get compact dict of params for local step.

        Returns
        -------
        Info : dict
        """
        if self.inferType == 'EM':
            raise NotImplementedError('TODO')
        return dict(
            inferType=self.inferType,
            K=self.K,
            D=self.D,
        )

    def fillSharedMemDictForLocalStep(self, ShMem=None):
        """ Get dict of shared mem arrays needed for parallel local step.

        Returns
        -------
        ShMem : dict of RawArray objects
        """
        if ShMem is None:
            ShMem = dict()
        if 'nu' in ShMem:
            fillSharedMemArray(ShMem['nu'], self.Post.nu)
            fillSharedMemArray(ShMem['cholB'], self._cholB('all'))
            fillSharedMemArray(ShMem['E_logdetL'], self._E_logdetL('all'))

        else:
            ShMem['nu'] = numpyToSharedMemArray(self.Post.nu)
            ShMem['cholB'] = numpyToSharedMemArray(self._cholB('all'))
            ShMem['E_logdetL'] = numpyToSharedMemArray(self._E_logdetL('all'))

        return ShMem
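
numpyToSharedMemArray and fillSharedMemArray are bnpy utilities not shown here. The usual recipe they presumably follow is a multiprocessing.RawArray viewed through np.frombuffer, so worker processes can read the posterior parameters without pickling or copying. A hedged sketch of that recipe (hypothetical helper names):

import multiprocessing
import numpy as np

def numpy_to_shared(arr):
    # Copy arr into a RawArray that child processes can view in place.
    shmem = multiprocessing.RawArray('d', arr.size)
    view = np.frombuffer(shmem, dtype=np.float64).reshape(arr.shape)
    view[:] = arr
    return shmem, arr.shape

def shared_to_numpy(shmem, shape):
    return np.frombuffer(shmem, dtype=np.float64).reshape(shape)

nu = np.array([10.0, 7.5, 12.0])
shm, shape = numpy_to_shared(nu)
assert np.allclose(shared_to_numpy(shm, shape), nu)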

    def getLocalAndSummaryFunctionHandles(self):
        """ Get function handles for local step and summary step

        Useful for parallelized algorithms.

        Returns
        -------
        calcLocalParams : f handle
        calcSummaryStats : f handle
        """
        return calcLocalParams, calcSummaryStats
class DPGridModel(object):
    def __init__(self, fileName, **kwargs):
        self.patchModel = load_model_at_prefix(fileName)
        self.patchModelFileName = fileName
        self.calcGlobalParams(**kwargs)

    def calcGlobalParams(self, **kwargs):
        self.D = self.patchModel.obsModel.D
        self.K = self.patchModel.obsModel.K
        self.GP = ParamBag(K=self.K, D=self.D)
        self._calcAllocGP()
        self._calcObsGP()
        self._calcUGP(**kwargs)

    def _calcAllocGP(self):
        # Calculate DP parameters
        logPi = self.patchModel.allocModel.Elogbeta
        self.GP.setField('logPi', logPi, dims='K')

    def _calcObsGP(self):
        # Calculate zero-mean Gaussian parameters
        model = self.patchModel.obsModel
        logdetLam = model.GetCached('E_logdetL', 'all')
        self.GP.setField('logdetLam', logdetLam, dims='K')
        Lam = model.Post.nu[:, np.newaxis, np.newaxis] * inv(model.Post.B)
        self.GP.setField('Lam', Lam, dims=('K', 'D', 'D'))

    def _calcUGP(self, r=0.43, s2=0.21**2):
        self.GP.setField('r', r)
        self.GP.setField('s2', s2)

    def denoise(self, y, sigma, cleanI, T=8):
        self.print_denoising_info(y, cleanI)
        self.PgnPart = self.get_part_info(y)
        betas = self.get_annealing_schedule(sigma, T)
        x, u, uPart, logPi = self.init_x_u_logPi(y)
        for t in xrange(T):
            print('Iteration %d/%d' % (t + 1, T))
            beta = betas[t]
            print('updating z...')
            resp, respPart = self.update_z(beta, logPi, x, u, uPart)
            print('updating v...')
            v, vPart = self.update_v(beta, x, u, uPart, resp, respPart)
            print('updating u...')
            u, uPart = self.update_u(beta, x, v, vPart)
            print('updating x...')
            x = self.update_x(sigma, beta, y, v, vPart, u, uPart)
            print('PSNR: %.2f dB' % self.calcPSNR(x, cleanI))
        x = self.clip_pixel_intensity(x)
        finalPSNR = round(self.calcPSNR(x, cleanI), 2)
        print('Final PSNR: %.2f dB' % finalPSNR)
        return x, finalPSNR

    def print_denoising_info(self, y, cleanI):
        patchSz = int(np.sqrt(self.D))
        print('Pretrained %s: K = %d clusters' %
              (self.__class__.__name__, self.K))
        print('Patch size: D = %d x %d pixels' % (patchSz, patchSz))
        print('Image size: %d x %d pixels' % y.shape)
        print('PSNR of the noisy image: %.2f dB' % self.calcPSNR(y, cleanI))

    def get_part_info(self, image):
        # Gather information for partial patches: return a dict whose keys
        # are masks of observable pixels w.r.t. a patch and whose values
        # are indices of those pixels w.r.t. the image.
        patchSize = int(np.sqrt(self.D))
        H, W = image.shape
        HFull = H + (patchSize - 1) * 2
        WFull = W + (patchSize - 1) * 2
        imgFull = np.reshape(np.arange(HFull * WFull), (HFull, WFull))
        PgnFull = im2col(imgFull, patchSize).T
        NFull = PgnFull.shape[0]
        PgnPart = dict()
        for n in xrange(NFull):
            h, w = np.unravel_index(PgnFull[n], (HFull, WFull))
            hMask = np.logical_and(h >= patchSize - 1, h <= HFull - patchSize)
            wMask = np.logical_and(w >= patchSize - 1, w <= WFull - patchSize)
            mask = np.logical_and(hMask, wMask)
            if not np.all(mask):
                h = h[mask] - (patchSize - 1)
                w = w[mask] - (patchSize - 1)
                idx = np.ravel_multi_index(np.array([h, w]), (H, W))
                if tuple(mask) in PgnPart:
                    PgnPart[tuple(mask)] = np.vstack(
                        (PgnPart[tuple(mask)], idx))
                else:
                    PgnPart[tuple(mask)] = np.array([idx])
        return PgnPart

    def get_annealing_schedule(self, sigma, T):
        MINBETA = 0.5 / 255
        if sigma == MINBETA:
            betas = MINBETA * np.ones(T)
        elif sigma < MINBETA:
            raise ValueError('Noise std shouldn\'t be smaller than %f!' %
                             MINBETA)
        else:
            betaAnneal = np.array([sigma])
            tmp = sigma / 2.0
            if tmp > MINBETA and T - len(betaAnneal) > 0:
                betaAnneal = np.append(betaAnneal, np.array([tmp]))
            while tmp / np.sqrt(2) > MINBETA and T - len(betaAnneal) > 0:
                tmp /= np.sqrt(2.0)
                betaAnneal = np.append(betaAnneal, np.array([tmp]))
            if T - len(betaAnneal) > 0:
                betaFloor = MINBETA * np.ones(T - len(betaAnneal))
                betas = np.concatenate((betaAnneal, betaFloor))
            else:
                betas = betaAnneal
        return betas
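
The schedule starts at the noise level sigma, halves once, then decays by sqrt(2) per step until it reaches the floor MINBETA = 0.5 / 255, padding any remaining iterations at the floor. A standalone restatement of the same logic (the sigma < MINBETA error branch is omitted):

import numpy as np

def annealing_schedule(sigma, T, minbeta=0.5 / 255):
    betas = [sigma]
    tmp = sigma / 2.0
    if tmp > minbeta and len(betas) < T:
        betas.append(tmp)
    while tmp / np.sqrt(2) > minbeta and len(betas) < T:
        tmp /= np.sqrt(2.0)
        betas.append(tmp)
    betas += [minbeta] * (T - len(betas))   # pad at the floor
    return np.array(betas)

print(annealing_schedule(25.0 / 255, 8))    # e.g. 8-bit noise std of 25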

    def init_x_u_logPi(self, y):
        x = self._initX(y)
        u, uPart = self._initU(y)
        logPi = self._initLogPi()
        return x, u, uPart, logPi

    def _initX(self, y):
        return y.copy()

    def _initU(self, y):
        patchSize = int(np.sqrt(self.D))
        patches = im2col(y, patchSize)
        u = np.mean(patches, axis=0)
        uPart = dict()
        for mask, idx in self.PgnPart.items():
            uPart[mask] = np.mean(y.ravel()[idx], axis=1)
        return u, uPart

    def _initLogPi(self):
        return self.GP.logPi

    def update_z(self, beta, logPi, x, u, uPart, patchLst=None, IP=None):
        # fully observable patches
        if IP is None:
            IP = self.calcIterationParams(beta)
        D, K, GP, patchSize = self.D, self.K, self.GP, int(np.sqrt(self.D))
        Px_minus_u = im2col(x, patchSize) - u
        if patchLst is not None:
            Px_minus_u = Px_minus_u[:, patchLst]
        NFull = Px_minus_u.shape[1]
        resp = np.tile(logPi + 0.5 * (IP.logdetSigma + GP.logdetLam),
                       (NFull, 1))
        for k in xrange(K):
            tmp = solve_triangular(beta**2 * IP.Rc[k],
                                   Px_minus_u,
                                   lower=IP.Rlower[k],
                                   check_finite=False)
            resp[:, k] += .5 * np.einsum('dn,dn->n', tmp, tmp)
        resp = np.argmax(resp, axis=1)
        # partially observable patches
        respPart = dict()
        for mask, idx in self.PgnPart.items():
            maskLst = np.array(list(mask), dtype=bool)
            IPPart = self.calcIterationParams(beta, mask=maskLst)
            NPart = idx.shape[0]
            CT_Px_minus_u = np.zeros((D, NPart))
            CT_Px_minus_u[maskLst, :] = x.ravel()[idx].T - uPart[mask]
            this_resp = np.tile(
                logPi + 0.5 * (IPPart.logdetSigma + GP.logdetLam), (NPart, 1))
            for k in xrange(K):
                tmp = solve_triangular(beta**2 * IPPart.Rc[k],
                                       CT_Px_minus_u,
                                       lower=IPPart.Rlower[k],
                                       check_finite=False)
                this_resp[:, k] += .5 * np.einsum('dn,dn->n', tmp, tmp)
            respPart[mask] = np.argmax(this_resp, axis=1)
        return resp, respPart
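
The plus sign on the quadratic term is not a typo: by the Woodbury identity, the evidence log N(b | 0, beta^2 I + Lam^{-1}) splits into a k-independent constant plus 0.5 * (log|Sigma| + log|Lam|) + 0.5 * b^T Sigma b / beta^4, where Sigma = (I / beta^2 + Lam)^{-1} is exactly what calcIterationParams factorizes. A standalone check of that identity:

import numpy as np
from scipy.linalg import cho_factor, solve_triangular

D, beta = 4, 0.3
A = np.random.randn(D, D)
Lam = np.dot(A, A.T) + np.eye(D)
b = np.random.randn(D)

# Exact per-component evidence term (2 pi constants dropped on both sides).
cov = beta**2 * np.eye(D) + np.linalg.inv(Lam)
exact = (-0.5 * np.linalg.slogdet(cov)[1]
         - 0.5 * np.dot(b, np.linalg.solve(cov, b)))

# Score as assembled in update_z, plus the k-independent constant.
invSigma = np.eye(D) / beta**2 + Lam
Rc, low = cho_factor(invSigma, lower=True)
logdetSigma = -2 * np.sum(np.log(np.diag(Rc)))
tmp = solve_triangular(beta**2 * Rc, b, lower=low)
score = (0.5 * (logdetSigma + np.linalg.slogdet(Lam)[1])
         + 0.5 * np.dot(tmp, tmp))
const = -D * np.log(beta) - 0.5 * np.dot(b, b) / beta**2
assert np.allclose(exact, score + const)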

    def calcIterationParams(self, std, mask=None):
        D, K, GP = self.D, self.K, self.GP
        IP = ParamBag(K=K, D=D)
        if mask is None:
            mask = np.ones(D, dtype=bool)
        invSigma = 1.0 / std**2 * np.diag(mask) + GP.Lam
        Rc = np.zeros((K, D, D))
        Rlower = np.ones(K, dtype=bool)
        for k in xrange(K):
            Rc[k], Rlower[k] = cho_factor(invSigma[k], lower=True)
        try:
            # np.tril is applied to the whole (K, D, D) stack at once;
            # older numpy builds raise ValueError on 3-D input, hence the
            # per-matrix fallback below.
            IP.setField('Rc', np.tril(Rc), dims=('K', 'D', 'D'))
        except ValueError:
            for k in xrange(K):
                Rc[k] = np.tril(Rc[k])
            IP.setField('Rc', Rc, dims=('K', 'D', 'D'))
        IP.setField('Rlower', Rlower, dims='K')
        logdetSigma = -2 * np.sum(np.log(np.diagonal(Rc, axis1=1, axis2=2)),
                                  axis=1)
        IP.setField('logdetSigma', logdetSigma, dims='K')
        return IP
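
Each iteration factorizes, per cluster, the posterior precision of a clean patch given its noisy version, invSigma_k = I / std^2 + Lam_k (with the identity masked to observed pixels for partial patches); update_z and update_v then reuse the cached factors. A minimal cho_factor / cho_solve round trip with illustrative values:

import numpy as np
from scipy.linalg import cho_factor, cho_solve

D, std = 4, 0.1
A = np.random.randn(D, D)
Lam = np.dot(A, A.T) + np.eye(D)
invSigma = np.eye(D) / std**2 + Lam
c, low = cho_factor(invSigma, lower=True)
rhs = np.random.randn(D)
x = cho_solve((c, low), rhs)
assert np.allclose(np.dot(invSigma, x), rhs)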

    def update_v(self,
                 beta,
                 x,
                 u,
                 uPart,
                 resp,
                 respPart,
                 patchLst=None,
                 IP=None):
        # fully observable patches
        if IP is None:
            IP = self.calcIterationParams(beta)
        D, K, GP, patchSize = self.D, self.K, self.GP, int(np.sqrt(self.D))
        Px_minus_u = im2col(x, patchSize) - u
        if patchLst is not None:
            Px_minus_u = Px_minus_u[:, patchLst]
        NFull = Px_minus_u.shape[1]
        v = np.zeros((NFull, D))
        for k in xrange(K):
            idx_k = np.flatnonzero(resp == k)
            if len(idx_k) == 0:
                continue
            cho = (IP.Rc[k] * beta, bool(IP.Rlower[k]))
            v[idx_k] = cho_solve(cho,
                                 Px_minus_u[:, idx_k],
                                 overwrite_b=True,
                                 check_finite=False).T
        # partially observable patches
        vPart = dict()
        for mask, idx in self.PgnPart.items():
            maskLst = np.array(list(mask), dtype=bool)
            IPPart = self.calcIterationParams(beta, mask=maskLst)
            NPart = len(uPart[mask])
            CT_Px_minus_u = np.zeros((D, NPart))
            CT_Px_minus_u[maskLst, :] = x.ravel()[idx].T - uPart[mask]
            this_v = np.zeros((NPart, D))
            for k in xrange(K):
                idx_k = np.flatnonzero(respPart[mask] == k)
                if len(idx_k) == 0:
                    continue
                cho = (IPPart.Rc[k] * beta, bool(IPPart.Rlower[k]))
                this_v[idx_k] = cho_solve(cho,
                                          CT_Px_minus_u[:, idx_k],
                                          overwrite_b=True,
                                          check_finite=False).T
            vPart[mask] = this_v
        return v, vPart

    def update_u(self, beta, x, v, vPart, patchLst=None):
        # fully observable patches
        D, GP, patchSize = self.D, self.GP, int(np.sqrt(self.D))
        beta2inv = 1.0 / beta**2
        gamma2 = 1.0 / (1.0 / GP.s2 + D * beta2inv)
        patches = im2col(x, patchSize)
        if patchLst is not None:
            patches = patches[:, patchLst]
        Px_minus_v = patches.T - v
        u = gamma2 * (GP.r / GP.s2 + beta2inv * np.sum(Px_minus_v, axis=1))
        # partially observable patches
        uPart = dict()
        for mask, idx in self.PgnPart.items():
            maskLst = np.array(list(mask), dtype=bool)
            NPart, DPart = idx.shape
            gamma2 = 1.0 / (1.0 / GP.s2 + DPart * beta2inv)
            Px_minus_v = x.ravel()[idx] - vPart[mask][:, maskLst]
            uPart[mask] = gamma2 * (GP.r / GP.s2 +
                                    beta2inv * np.sum(Px_minus_v, axis=1))
        return u, uPart
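
The u-update is a scalar conjugate Gaussian update: with prior u ~ N(r, s2) and D pixel residuals observed at noise variance beta^2, the posterior has variance gamma2 = 1 / (1/s2 + D/beta^2) and the mean computed above. A standalone instance with illustrative numbers:

import numpy as np

r, s2, beta, D = 0.43, 0.21**2, 0.05, 64
residuals = 0.4 + 0.01 * np.random.randn(D)     # x - v for one patch
gamma2 = 1.0 / (1.0 / s2 + D / beta**2)
u = gamma2 * (r / s2 + np.sum(residuals) / beta**2)
# The posterior mean sits between the prior mean r and the data mean,
# and with D = 64 observations it is strongly data-dominated.
assert abs(u - np.mean(residuals)) < abs(np.mean(residuals) - r)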

    def update_x(self, sigma, beta, y, v, vPart, u, uPart):
        D, patchSize = self.D, int(np.sqrt(self.D))
        H, W = y.shape

        def piece_up_patches():
            result = col2im(v.T + u, patchSize, H, W, normalize=False).ravel()
            for mask, idx in self.PgnPart.items():
                maskLst = np.array(list(mask), dtype=bool)
                result += np.bincount(
                    idx.ravel(),
                    minlength=H * W,
                    weights=(uPart[mask][:, np.newaxis] +
                             vPart[mask][:, maskLst]).ravel())
            result /= D
            return result.reshape(y.shape)

        rec_from_patches = piece_up_patches()
        sigma2, beta2 = sigma**2, beta**2
        x = (beta2 * y + sigma2 * rec_from_patches) / (sigma2 + beta2)
        return x
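
The x-update is the closed-form posterior mean given two Gaussian views of each pixel: the noisy observation y with variance sigma^2 and the patch reconstruction with variance beta^2. The formula used above is the precision-weighted average, as this identity check shows:

import numpy as np

sigma, beta = 0.1, 0.05
y = np.random.rand(8, 8)
xbar = np.random.rand(8, 8)                     # patch reconstruction
s2, b2 = sigma**2, beta**2

direct = (b2 * y + s2 * xbar) / (s2 + b2)
precision_weighted = (y / s2 + xbar / b2) / (1.0 / s2 + 1.0 / b2)
assert np.allclose(direct, precision_weighted)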

    def clip_pixel_intensity(self, image):
        # Clamp intensities to [0, 1] in-place.
        np.clip(image, 0.0, 1.0, out=image)
        return image

    def calcPSNR(self, I, cleanI):
        return 20 * np.log10(1.0 / np.std(cleanI - I))
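
Note that this PSNR uses np.std of the error, which subtracts the mean residual; the conventional definition for images in [0, 1] uses the raw RMSE, 20 log10(1 / RMSE), and the two agree only when the residual is zero-mean. A conventional variant for comparison (not the author's method):

import numpy as np

def psnr(I, cleanI, peak=1.0):
    # Conventional PSNR via raw MSE, with no mean subtraction.
    mse = np.mean((cleanI - I)**2)
    return 20 * np.log10(peak / np.sqrt(mse))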