Пример #1
0
    def get_jk(self,
               dm,
               hermi=1,
               kpts=None,
               kpt_band=None,
               with_j=True,
               with_k=True,
               exxdiv='ewald'):
        '''Compute the Coulomb (J) and exchange (K) matrices.

        A single (3,) k-point is delegated to ``pwdf_jk.get_jk``; an array
        of k-points is handled with ``pwdf_jk.get_k_kpts`` and
        ``pwdf_jk.get_j_kpts``.

        Args:
            dm : density matrix/matrices, forwarded to ``pwdf_jk``.
            hermi : forwarded to ``pwdf_jk``.
            kpts : a single (3,) k-point or an array of k-points.  When
                None, the Gamma point is used if every entry of
                ``self.kpts`` is zero, otherwise ``self.kpts``.  Any
                array-like (e.g. a nested list) is accepted.
            kpt_band, exxdiv : forwarded to ``pwdf_jk``.
            with_j, with_k : whether to compute J and K.

        Returns:
            (vj, vk).  On the multi-k-point path an entry is None when it
            was not requested.
        '''
        if kpts is None:
            if numpy.all(self.kpts == 0):
                # Gamma-point calculation by default
                kpts = numpy.zeros(3)
            else:
                kpts = self.kpts
        # The shape check below needs an ndarray; accept lists/tuples too.
        kpts = numpy.asarray(kpts)

        if kpts.shape == (3, ):
            return pwdf_jk.get_jk(self, dm, hermi, kpts, kpt_band, with_j,
                                  with_k, exxdiv)

        vj = vk = None
        if with_k:
            vk = pwdf_jk.get_k_kpts(self, dm, hermi, kpts, kpt_band, exxdiv)
        if with_j:
            vj = pwdf_jk.get_j_kpts(self, dm, hermi, kpts, kpt_band)
        return vj, vk
Пример #2
0
    def get_jk(self, dm, hermi=1, kpts=None, kpt_band=None,
               with_j=True, with_k=True, exxdiv='ewald'):
        '''Compute the Coulomb (J) and exchange (K) matrices.

        A single (3,) k-point is delegated to ``pwdf_jk.get_jk``; an array
        of k-points is handled with ``pwdf_jk.get_k_kpts`` and
        ``pwdf_jk.get_j_kpts``.

        Args:
            dm : density matrix/matrices, forwarded to ``pwdf_jk``.
            hermi : forwarded to ``pwdf_jk``.
            kpts : a single (3,) k-point or an array of k-points.  When
                None, the Gamma point is used if every entry of
                ``self.kpts`` is zero, otherwise ``self.kpts``.  Any
                array-like (e.g. a nested list) is accepted.
            kpt_band, exxdiv : forwarded to ``pwdf_jk``.
            with_j, with_k : whether to compute J and K.

        Returns:
            (vj, vk).  On the multi-k-point path an entry is None when it
            was not requested.
        '''
        if kpts is None:
            if numpy.all(self.kpts == 0):
                # Gamma-point calculation by default
                kpts = numpy.zeros(3)
            else:
                kpts = self.kpts
        # The shape check below needs an ndarray; accept lists/tuples too.
        kpts = numpy.asarray(kpts)

        if kpts.shape == (3,):
            return pwdf_jk.get_jk(self, dm, hermi, kpts, kpt_band, with_j,
                                  with_k, exxdiv)

        vj = vk = None
        if with_k:
            vk = pwdf_jk.get_k_kpts(self, dm, hermi, kpts, kpt_band, exxdiv)
        if with_j:
            vj = pwdf_jk.get_j_kpts(self, dm, hermi, kpts, kpt_band)
        return vj, vk
Пример #3
0
def get_jk(mydf, dm, hermi=1, kpt=None,
           kpt_band=None, with_j=True, with_k=True, exxdiv=None):
    '''JK for given k-point

    The contribution built here from the three-center integrals streamed by
    ``mydf.sr_loop`` is added in place to the result of ``pwdf_jk.get_jk``.

    Args:
        mydf : density-fitting object; ``mydf.build()`` is called first when
            ``mydf._cderi`` is None.
        dm : density matrix (or set of density matrices) at ``kpt``.
        hermi : when nonzero, the second K term is taken as the transpose
            (conjugate-transpose for complex data) of the first instead of
            being recomputed.
        kpt : (3,) k-point; array-like accepted.  Defaults to the Gamma
            point (None is equivalent to the former ``numpy.zeros(3)``).
        kpt_band : optional k-point.  If it differs from ``kpt`` by more than
            1e-9 the work is delegated to ``get_k_kpts``/``get_j_kpts``.
        with_j, with_k : whether to compute J and K.
        exxdiv : forwarded to the K builders.

    Returns:
        (vj, vk); an entry is None when it was not requested.
    '''
    if kpt is None:
        kpt = numpy.zeros(3)
    kpt = numpy.asarray(kpt)

    vj = vk = None
    if kpt_band is not None and abs(kpt-kpt_band).sum() > 1e-9:
        kpt = numpy.reshape(kpt, (1,3))
        if with_k:
            vk = get_k_kpts(mydf, [dm], hermi, kpt, kpt_band, exxdiv)
        if with_j:
            vj = get_j_kpts(mydf, [dm], hermi, kpt, kpt_band)
        return vj, vk

    # time.clock() was removed in Python 3.8; only fall back to it where
    # time.process_time does not exist.
    try:
        cpu_clock = time.process_time
    except AttributeError:
        cpu_clock = time.clock

    log = logger.Logger(mydf.stdout, mydf.verbose)
    t2 = t1 = (cpu_clock(), time.time())
    if mydf._cderi is None:
        mydf.build()
        t1 = log.timer_debug1('Init get_jk', *t1)

    dm = numpy.asarray(dm, order='C')
    dms = _format_dms(dm, [kpt])
    nset, _, nao = dms.shape[:3]
    dms = dms.reshape(nset,nao,nao)
    # At the Gamma point the imaginary part of J is skipped entirely; K is
    # treated as real only when the density matrices are real as well.
    j_real = gamma_point(kpt)
    k_real = gamma_point(kpt) and not numpy.iscomplexobj(dms)
    kptii = numpy.asarray((kpt,kpt))

    dmsR = numpy.asarray(dms.real.reshape(nset,nao,nao), order='C')
    dmsI = numpy.asarray(dms.imag.reshape(nset,nao,nao), order='C')
    # The imaginary accumulators are only read when the corresponding
    # *_real flag is False, so skip allocating them otherwise.
    if with_j:
        vjR = numpy.zeros((nset,nao,nao))
        vjI = None if j_real else numpy.zeros((nset,nao,nao))
    if with_k:
        vkR = numpy.zeros((nset,nao,nao))
        vkI = None if k_real else numpy.zeros((nset,nao,nao))
    # .45 is estimation for the memory usage ratio  sr_loop / (sr_loop+bufR+bufI)
    max_memory = max(2000, (mydf.max_memory - lib.current_memory()[0])) * .45
    if with_k:
        buf1R = numpy.empty((mydf.blockdim*nao**2))
        buf2R = numpy.empty((mydf.blockdim*nao**2))
        buf3R = numpy.empty((mydf.blockdim*nao**2))
        if not k_real:
            buf1I = numpy.empty((mydf.blockdim*nao**2))
            buf2I = numpy.empty((mydf.blockdim*nao**2))
            buf3I = numpy.empty((mydf.blockdim*nao**2))
    def contract_k(pLqR, pLqI, pjqR, pjqI):
        '''Accumulate one integral block into vkR (and vkI if complex).

        Writes scratch results into buf3R/buf3I; run in a background thread,
        which the main loop joins before it overwrites buf1*/buf2*.
        '''
        # K ~ 'iLj,lLk*,li->kj' + 'lLk*,iLj,li->kj'
        #:Lpq = LpqR + LpqI*1j
        #:j3c = j3cR + j3cI*1j
        #:for i in range(nset):
        #:    tmp = numpy.dot(dms[i], j3c.reshape(nao,-1))
        #:    vk1 = numpy.dot(Lpq.reshape(-1,nao).conj().T, tmp.reshape(-1,nao))
        #:    tmp = numpy.dot(dms[i], Lpq.reshape(nao,-1))
        #:    vk1+= numpy.dot(j3c.reshape(-1,nao).conj().T, tmp.reshape(-1,nao))
        #:    vkR[i] += vk1.real
        #:    vkI[i] += vk1.imag
        nrow = pLqR.shape[1]
        tmpR = numpy.ndarray((nao,nrow*nao), buffer=buf3R)
        if k_real:
            for i in range(nset):
                tmpR = lib.ddot(dmsR[i], pjqR.reshape(nao,-1), 1, tmpR)
                vk1R = lib.ddot(pLqR.reshape(-1,nao).T, tmpR.reshape(-1,nao))
                vkR[i] += vk1R
                if hermi:
                    vkR[i] += vk1R.T
                else:
                    tmpR = lib.ddot(dmsR[i], pLqR.reshape(nao,-1), 1, tmpR)
                    lib.ddot(pjqR.reshape(-1,nao).T, tmpR.reshape(-1,nao),
                             1, vkR[i], 1)
        else:
            tmpI = numpy.ndarray((nao,nrow*nao), buffer=buf3I)
            for i in range(nset):
                tmpR, tmpI = zdotNN(dmsR[i], dmsI[i], pjqR.reshape(nao,-1),
                                    pjqI.reshape(nao,-1), 1, tmpR, tmpI, 0)
                vk1R, vk1I = zdotCN(pLqR.reshape(-1,nao).T, pLqI.reshape(-1,nao).T,
                                    tmpR.reshape(-1,nao), tmpI.reshape(-1,nao))
                vkR[i] += vk1R
                vkI[i] += vk1I
                if hermi:
                    vkR[i] += vk1R.T
                    vkI[i] -= vk1I.T
                else:
                    tmpR, tmpI = zdotNN(dmsR[i], dmsI[i], pLqR.reshape(nao,-1),
                                        pLqI.reshape(nao,-1), 1, tmpR, tmpI, 0)
                    zdotCN(pjqR.reshape(-1,nao).T, pjqI.reshape(-1,nao).T,
                           tmpR.reshape(-1,nao), tmpI.reshape(-1,nao),
                           1, vkR[i], vkI[i], 1)

    pLqI = pjqI = None
    thread_k = None
    for LpqR, LpqI, j3cR, j3cI in mydf.sr_loop(kptii, max_memory, False):
        LpqR = LpqR.reshape(-1,nao,nao)
        LpqI = LpqI.reshape(-1,nao,nao)
        j3cR = j3cR.reshape(-1,nao,nao)
        j3cI = j3cI.reshape(-1,nao,nao)
        t2 = log.timer_debug1('        load', *t2)
        # The previous K contraction still reads buf1*/buf2*; wait for it
        # before those buffers are refilled below.
        if thread_k is not None:
            thread_k.join()
        if with_j:
            #:rho_coeff = numpy.einsum('Lpq,xqp->xL', Lpq, dms)
            #:jaux = numpy.einsum('Lpq,xqp->xL', j3c, dms)
            #:vj += numpy.dot(jaux, Lpq.reshape(-1,nao**2))
            #:vj += numpy.dot(rho_coeff, j3c.reshape(-1,nao**2))
            rhoR  = numpy.einsum('Lpq,xqp->xL', LpqR, dmsR)
            jauxR = numpy.einsum('Lpq,xqp->xL', j3cR, dmsR)
            if not j_real:
                rhoR -= numpy.einsum('Lpq,xqp->xL', LpqI, dmsI)
                rhoI  = numpy.einsum('Lpq,xqp->xL', LpqR, dmsI)
                rhoI += numpy.einsum('Lpq,xqp->xL', LpqI, dmsR)
                jauxR-= numpy.einsum('Lpq,xqp->xL', j3cI, dmsI)
                jauxI = numpy.einsum('Lpq,xqp->xL', j3cR, dmsI)
                jauxI+= numpy.einsum('Lpq,xqp->xL', j3cI, dmsR)
            vjR += numpy.einsum('xL,Lpq->xpq', jauxR, LpqR)
            vjR += numpy.einsum('xL,Lpq->xpq', rhoR, j3cR)
            if not j_real:
                vjR -= numpy.einsum('xL,Lpq->xpq', jauxI, LpqI)
                vjR -= numpy.einsum('xL,Lpq->xpq', rhoI, j3cI)
                vjI += numpy.einsum('xL,Lpq->xpq', jauxR, LpqI)
                vjI += numpy.einsum('xL,Lpq->xpq', jauxI, LpqR)
                vjI += numpy.einsum('xL,Lpq->xpq', rhoR, j3cI)
                vjI += numpy.einsum('xL,Lpq->xpq', rhoI, j3cR)
        t2 = log.timer_debug1('        with_j', *t2)
        if with_k:
            nrow = LpqR.shape[0]
            pLqR = numpy.ndarray((nao,nrow,nao), buffer=buf1R)
            pjqR = numpy.ndarray((nao,nrow,nao), buffer=buf2R)
            pLqR[:] = LpqR.transpose(1,0,2)
            pjqR[:] = j3cR.transpose(1,0,2)
            if not k_real:
                pLqI = numpy.ndarray((nao,nrow,nao), buffer=buf1I)
                pjqI = numpy.ndarray((nao,nrow,nao), buffer=buf2I)
                pLqI[:] = LpqI.transpose(1,0,2)
                pjqI[:] = j3cI.transpose(1,0,2)

            thread_k = lib.background_thread(contract_k, pLqR, pLqI, pjqR, pjqI)
            t2 = log.timer_debug1('        with_k', *t2)
        LpqR = LpqI = j3cR = j3cI = None
    if thread_k is not None:
        thread_k.join()
    thread_k = None
    t1 = log.timer_debug1('mdf_jk.get_jk pass 1', *t1)

    vj, vk = pwdf_jk.get_jk(mydf, dm, hermi, kpt, kpt_band, with_j, with_k, exxdiv)
    if with_j:
        if j_real:
            vj += vjR.reshape(dm.shape)
        else:
            vj += (vjR+vjI*1j).reshape(dm.shape)
    if with_k:
        if k_real:
            vk += vkR.reshape(dm.shape)
        else:
            vk += (vkR+vkI*1j).reshape(dm.shape)
    return vj, vk