def get_jk(self, dm, hermi=1, kpts=None, kpt_band=None,
           with_j=True, with_k=True, exxdiv='ewald'):
    '''Dispatch the J/K build to the single-k-point or multi-k-point driver.

    Args:
        dm : density matrix (or matrices); forwarded unchanged to pwdf_jk.
        hermi : forwarded unchanged to pwdf_jk.
        kpts : k-point(s) to use.  When None, ``self.kpts`` is used, except
            that an all-zero ``self.kpts`` is replaced by a single zero
            vector of length 3.  Array-likes (list/tuple) are accepted.
        kpt_band : forwarded unchanged to pwdf_jk.
        with_j, with_k : select which of the two results are computed.
        exxdiv : forwarded to the K builders only.

    Returns:
        (vj, vk).  When ``kpts`` has shape ``(3,)`` this is whatever
        ``pwdf_jk.get_jk`` returns; otherwise an entry is None when the
        corresponding ``with_j`` / ``with_k`` flag is false.
    '''
    if kpts is None:
        if numpy.all(self.kpts == 0):
            # Gamma-point calculation by default
            kpts = numpy.zeros(3)
        else:
            kpts = self.kpts
    # Plain sequences have no .shape attribute; normalise so the dispatch
    # below also works for lists/tuples.  No-op for an ndarray input.
    kpts = numpy.asarray(kpts)

    if kpts.shape == (3,):
        # A single k-point: use the single-k-point driver.
        return pwdf_jk.get_jk(self, dm, hermi, kpts, kpt_band,
                              with_j, with_k, exxdiv)

    vj = vk = None
    if with_k:
        vk = pwdf_jk.get_k_kpts(self, dm, hermi, kpts, kpt_band, exxdiv)
    if with_j:
        vj = pwdf_jk.get_j_kpts(self, dm, hermi, kpts, kpt_band)
    return vj, vk
def get_jk(self, dm, hermi=1, kpts=None, kpt_band=None,
           with_j=True, with_k=True, exxdiv='ewald'):
    '''Build J and/or K, choosing the driver from the shape of ``kpts``.

    Args:
        dm : density matrix (or matrices); passed through to pwdf_jk.
        hermi : passed through to pwdf_jk.
        kpts : k-point(s).  None selects ``self.kpts``; if those are all
            zero a single length-3 zero vector is used instead.  Lists and
            tuples are converted to an ndarray.
        kpt_band : passed through to pwdf_jk.
        with_j, with_k : whether to compute vj / vk.
        exxdiv : passed to the K builders only.

    Returns:
        (vj, vk).  For ``kpts`` of shape ``(3,)`` the pair returned by
        ``pwdf_jk.get_jk``; otherwise each entry is None unless its flag
        is set.
    '''
    if kpts is None:
        if numpy.all(self.kpts == 0):
            # Gamma-point calculation by default
            kpts = numpy.zeros(3)
        else:
            kpts = self.kpts
    # The shape test below needs an ndarray; a list or tuple of k-points
    # would otherwise fail with AttributeError.  ndarrays pass through as-is.
    kpts = numpy.asarray(kpts)

    if kpts.shape == (3,):
        # One k-point only: delegate to the single-k-point routine.
        return pwdf_jk.get_jk(self, dm, hermi, kpts, kpt_band,
                              with_j, with_k, exxdiv)

    vj = vk = None
    if with_k:
        vk = pwdf_jk.get_k_kpts(self, dm, hermi, kpts, kpt_band, exxdiv)
    if with_j:
        vj = pwdf_jk.get_j_kpts(self, dm, hermi, kpts, kpt_band)
    return vj, vk
def get_jk(mydf, dm, hermi=1, kpt=numpy.zeros(3),
           kpt_band=None, with_j=True, with_k=True, exxdiv=None):
    '''JK for given k-point

    The result is the sum of two parts: a correction accumulated here from
    the blocks yielded by ``mydf.sr_loop`` and the (vj, vk) pair returned by
    ``pwdf_jk.get_jk``.

    Args:
        mydf : density-fitting object; this function reads ``stdout``,
            ``verbose``, ``_cderi``, ``max_memory`` and ``blockdim`` and
            calls ``build()`` (only if ``_cderi`` is None) and ``sr_loop()``.
        dm : density matrix or stack of density matrices; the results are
            reshaped to ``dm.shape``.
        hermi : if true, the second K term is obtained by adding the
            (conjugate) transpose of the first instead of a second
            contraction.
        kpt : the k-point.  The default array is never modified in place.
        kpt_band : if given and different from ``kpt`` (absolute difference
            summed > 1e-9), the work is delegated to ``get_k_kpts`` /
            ``get_j_kpts`` and their results are returned directly.
        with_j, with_k : select which of vj / vk are computed.
        exxdiv : forwarded to the K builders.

    Returns:
        (vj, vk); an entry is None when the corresponding flag is false.
    '''
    vj = vk = None
    if kpt_band is not None and abs(kpt-kpt_band).sum() > 1e-9:
        kpt = numpy.reshape(kpt, (1,3))
        if with_k:
            vk = get_k_kpts(mydf, [dm], hermi, kpt, kpt_band, exxdiv)
        if with_j:
            vj = get_j_kpts(mydf, [dm], hermi, kpt, kpt_band)
        return vj, vk

    log = logger.Logger(mydf.stdout, mydf.verbose)
    # time.clock was removed in Python 3.8.  Keep using it where it still
    # exists (unchanged behaviour on older interpreters) and fall back to
    # time.process_time otherwise.
    cpu_clock = getattr(time, 'clock', None) or time.process_time
    t2 = t1 = (cpu_clock(), time.time())
    if mydf._cderi is None:
        mydf.build()
        t1 = log.timer_debug1('Init get_jk', *t1)

    dm = numpy.asarray(dm, order='C')
    dms = _format_dms(dm, [kpt])
    nset, _, nao = dms.shape[:3]
    dms = dms.reshape(nset,nao,nao)
    j_real = gamma_point(kpt)
    k_real = gamma_point(kpt) and not numpy.iscomplexobj(dms)
    kptii = numpy.asarray((kpt,kpt))

    # Real and imaginary parts are kept in separate C-contiguous arrays.
    dmsR = numpy.asarray(dms.real.reshape(nset,nao,nao), order='C')
    dmsI = numpy.asarray(dms.imag.reshape(nset,nao,nao), order='C')
    if with_j:
        vjR = numpy.zeros((nset,nao,nao))
        vjI = numpy.zeros((nset,nao,nao))
    if with_k:
        vkR = numpy.zeros((nset,nao,nao))
        vkI = numpy.zeros((nset,nao,nao))
    # .45 is estimation for the memory usage ratio sr_loop / (sr_loop+bufR+bufI)
    max_memory = max(2000, (mydf.max_memory - lib.current_memory()[0])) * .45
    if with_k:
        # Work buffers reused for every block yielded by sr_loop.
        buf1R = numpy.empty((mydf.blockdim*nao**2))
        buf2R = numpy.empty((mydf.blockdim*nao**2))
        buf3R = numpy.empty((mydf.blockdim*nao**2))
        if not k_real:
            buf1I = numpy.empty((mydf.blockdim*nao**2))
            buf2I = numpy.empty((mydf.blockdim*nao**2))
            buf3I = numpy.empty((mydf.blockdim*nao**2))

    def contract_k(pLqR, pLqI, pjqR, pjqI):
        # Accumulate one block's contribution into vkR / vkI in place.
        # K ~ 'iLj,lLk*,li->kj' + 'lLk*,iLj,li->kj'
        #:Lpq = LpqR + LpqI*1j
        #:j3c = j3cR + j3cI*1j
        #:for i in range(nset):
        #:    tmp = numpy.dot(dms[i], j3c.reshape(nao,-1))
        #:    vk1 = numpy.dot(Lpq.reshape(-1,nao).conj().T, tmp.reshape(-1,nao))
        #:    tmp = numpy.dot(dms[i], Lpq.reshape(nao,-1))
        #:    vk1+= numpy.dot(j3c.reshape(-1,nao).conj().T, tmp.reshape(-1,nao))
        #:    vkR[i] += vk1.real
        #:    vkI[i] += vk1.imag
        nrow = pLqR.shape[1]
        tmpR = numpy.ndarray((nao,nrow*nao), buffer=buf3R)
        if k_real:
            for i in range(nset):
                tmpR = lib.ddot(dmsR[i], pjqR.reshape(nao,-1), 1, tmpR)
                vk1R = lib.ddot(pLqR.reshape(-1,nao).T, tmpR.reshape(-1,nao))
                vkR[i] += vk1R
                if hermi:
                    # Second term is the transpose of the first.
                    vkR[i] += vk1R.T
                else:
                    tmpR = lib.ddot(dmsR[i], pLqR.reshape(nao,-1), 1, tmpR)
                    lib.ddot(pjqR.reshape(-1,nao).T, tmpR.reshape(-1,nao),
                             1, vkR[i], 1)
        else:
            tmpI = numpy.ndarray((nao,nrow*nao), buffer=buf3I)
            for i in range(nset):
                tmpR, tmpI = zdotNN(dmsR[i], dmsI[i], pjqR.reshape(nao,-1),
                                    pjqI.reshape(nao,-1), 1, tmpR, tmpI, 0)
                vk1R, vk1I = zdotCN(pLqR.reshape(-1,nao).T,
                                    pLqI.reshape(-1,nao).T,
                                    tmpR.reshape(-1,nao), tmpI.reshape(-1,nao))
                vkR[i] += vk1R
                vkI[i] += vk1I
                if hermi:
                    # Second term is the conjugate transpose of the first.
                    vkR[i] += vk1R.T
                    vkI[i] -= vk1I.T
                else:
                    tmpR, tmpI = zdotNN(dmsR[i], dmsI[i], pLqR.reshape(nao,-1),
                                        pLqI.reshape(nao,-1), 1, tmpR, tmpI, 0)
                    zdotCN(pjqR.reshape(-1,nao).T, pjqI.reshape(-1,nao).T,
                           tmpR.reshape(-1,nao), tmpI.reshape(-1,nao),
                           1, vkR[i], vkI[i], 1)

    pLqI = pjqI = None
    thread_k = None
    for LpqR, LpqI, j3cR, j3cI in mydf.sr_loop(kptii, max_memory, False):
        LpqR = LpqR.reshape(-1,nao,nao)
        LpqI = LpqI.reshape(-1,nao,nao)
        j3cR = j3cR.reshape(-1,nao,nao)
        j3cI = j3cI.reshape(-1,nao,nao)
        t2 = log.timer_debug1(' load', *t2)
        # The previous block's K contraction reads the shared buffers; it
        # must finish before they are overwritten below.
        if thread_k is not None:
            thread_k.join()
        if with_j:
            #:rho_coeff = numpy.einsum('Lpq,xqp->xL', Lpq, dms)
            #:jaux = numpy.einsum('Lpq,xqp->xL', j3c, dms)
            #:vj += numpy.dot(jaux, Lpq.reshape(-1,nao**2))
            #:vj += numpy.dot(rho_coeff, j3c.reshape(-1,nao**2))
            rhoR = numpy.einsum('Lpq,xqp->xL', LpqR, dmsR)
            jauxR = numpy.einsum('Lpq,xqp->xL', j3cR, dmsR)
            if not j_real:
                rhoR -= numpy.einsum('Lpq,xqp->xL', LpqI, dmsI)
                rhoI = numpy.einsum('Lpq,xqp->xL', LpqR, dmsI)
                rhoI += numpy.einsum('Lpq,xqp->xL', LpqI, dmsR)
                jauxR -= numpy.einsum('Lpq,xqp->xL', j3cI, dmsI)
                jauxI = numpy.einsum('Lpq,xqp->xL', j3cR, dmsI)
                jauxI += numpy.einsum('Lpq,xqp->xL', j3cI, dmsR)
            vjR += numpy.einsum('xL,Lpq->xpq', jauxR, LpqR)
            vjR += numpy.einsum('xL,Lpq->xpq', rhoR, j3cR)
            if not j_real:
                vjR -= numpy.einsum('xL,Lpq->xpq', jauxI, LpqI)
                vjR -= numpy.einsum('xL,Lpq->xpq', rhoI, j3cI)
                vjI += numpy.einsum('xL,Lpq->xpq', jauxR, LpqI)
                vjI += numpy.einsum('xL,Lpq->xpq', jauxI, LpqR)
                vjI += numpy.einsum('xL,Lpq->xpq', rhoR, j3cI)
                vjI += numpy.einsum('xL,Lpq->xpq', rhoI, j3cR)
        t2 = log.timer_debug1(' with_j', *t2)
        if with_k:
            nrow = LpqR.shape[0]
            # Copy the block into the reusable buffers with the first two
            # axes swapped, which is the layout contract_k reshapes from.
            pLqR = numpy.ndarray((nao,nrow,nao), buffer=buf1R)
            pjqR = numpy.ndarray((nao,nrow,nao), buffer=buf2R)
            pLqR[:] = LpqR.transpose(1,0,2)
            pjqR[:] = j3cR.transpose(1,0,2)
            if not k_real:
                pLqI = numpy.ndarray((nao,nrow,nao), buffer=buf1I)
                pjqI = numpy.ndarray((nao,nrow,nao), buffer=buf2I)
                pLqI[:] = LpqI.transpose(1,0,2)
                pjqI[:] = j3cI.transpose(1,0,2)
            # Run the K contraction in the background while the next block
            # is being produced by sr_loop.
            thread_k = lib.background_thread(contract_k, pLqR, pLqI, pjqR, pjqI)
            t2 = log.timer_debug1(' with_k', *t2)
        LpqR = LpqI = j3cR = j3cI = None
    if thread_k is not None:
        thread_k.join()
    thread_k = None
    t1 = log.timer_debug1('mdf_jk.get_jk pass 1', *t1)

    # Add the accumulated correction to the result of pwdf_jk.get_jk.
    vj, vk = pwdf_jk.get_jk(mydf, dm, hermi, kpt, kpt_band, with_j, with_k, exxdiv)
    if with_j:
        if j_real:
            vj += vjR.reshape(dm.shape)
        else:
            vj += (vjR+vjI*1j).reshape(dm.shape)
    if with_k:
        if k_real:
            vk += vkR.reshape(dm.shape)
        else:
            vk += (vkR+vkI*1j).reshape(dm.shape)
    return vj, vk