def async_read(icomp, row0, row1, thread_read):
    '''Return rows [row0, row1) of component ``icomp`` and kick off a prefetch.

    ``reading_frame`` holds two buffers whose roles are swapped on every call:
    the buffer returned now is the one a previous call's ``prefetch`` thread
    was filling, and the other one is handed to a new ``prefetch`` thread.
    (What ``prefetch`` actually loads is defined elsewhere -- presumably the
    next chunk; confirm against its definition.)

    Args:
        icomp, row0, row1: component index and row range to return.
        thread_read: handle returned by the previous call, or None on the
            first call.

    Returns:
        (view of the current buffer trimmed to row1-row0 rows, new thread handle)
    '''
    active, standby = reading_frame
    # Swap roles so the next call returns what is being prefetched now.
    reading_frame[:] = [standby, active]
    if thread_read is not None:
        # The earlier prefetch targeted the buffer that is now ``active``.
        thread_read.join()
    else:
        # Nothing was prefetched yet: load the requested rows synchronously.
        _load_from_h5g(fswap['%d'%icomp], row0, row1, active)
    next_thread = lib.background_thread(prefetch, icomp, row0, row1, standby)
    return active[:row1-row0], next_thread
def get_jk(mydf, dm, hermi=1, kpt=numpy.zeros(3), kpts_band=None,
           with_j=True, with_k=True, exxdiv=None):
    '''JK for given k-point

    Builds the Coulomb (J) and exchange (K) matrices of ``dm`` at the single
    k-point ``kpt`` by streaming density-fitted 3-index integrals from
    ``mydf.sr_loop``.  The K contraction of each integral block runs in a
    background thread while the next block is loaded and J is accumulated.

    If ``kpts_band`` is given and differs from ``kpt``, the calculation is
    delegated to ``get_k_kpts`` / ``get_j_kpts`` and their results are
    returned as-is.

    Args:
        mydf: density-fitting object (provides ``_cderi``, ``sr_loop``,
            ``build``, ``blockdim``, ``max_memory``).
        dm: density matrix (or matrices) at ``kpt``.
        hermi: hermiticity flag forwarded to the multi-k-point routines.
        kpt: the k-point.
        kpts_band: optional band k-point(s).
        with_j, with_k: whether to compute J and/or K.
        exxdiv: None, or 'ewald' to apply the G=0 exchange correction.

    Returns:
        (vj, vk): on the main path each is reshaped to ``dm.shape``; an entry
        is None when the corresponding ``with_j`` / ``with_k`` is False.
    '''
    vj = vk = None
    if kpts_band is not None and abs(kpt - kpts_band).sum() > 1e-9:
        kpt = numpy.reshape(kpt, (1, 3))
        if with_k:
            vk = get_k_kpts(mydf, dm, hermi, kpt, kpts_band, exxdiv)
        if with_j:
            vj = get_j_kpts(mydf, dm, hermi, kpt, kpts_band)
        return vj, vk

    cell = mydf.cell
    log = logger.Logger(mydf.stdout, mydf.verbose)
    # time.clock() was removed in Python 3.8; process_time() is the CPU-time
    # replacement.  NOTE(review): assumes the logger's timers also use
    # process CPU time -- confirm.
    t1 = (time.process_time(), time.time())
    if mydf._cderi is None or not mydf.has_kpts(kpts_band):
        if mydf._cderi is not None:
            log.warn(
                'DF integrals for band k-points were not found %s. '
                'DF integrals will be rebuilt to include band k-points.',
                mydf._cderi)
        mydf.build(kpts_band=kpts_band)
        t1 = log.timer_debug1('Init get_jk', *t1)

    dm = numpy.asarray(dm, order='C')
    dms = _format_dms(dm, [kpt])
    nset, _, nao = dms.shape[:3]
    dms = dms.reshape(nset, nao, nao)
    j_real = gamma_point(kpt)
    k_real = gamma_point(kpt) and not numpy.iscomplexobj(dms)

    kptii = numpy.asarray((kpt, kpt))
    dmsR = dms.real.reshape(nset, nao, nao)
    dmsI = dms.imag.reshape(nset, nao, nao)
    mem_now = lib.current_memory()[0]
    max_memory = max(2000, (mydf.max_memory - mem_now))
    if with_j:
        vjR = numpy.zeros((nset, nao, nao))
        vjI = numpy.zeros((nset, nao, nao))
    if with_k:
        vkR = numpy.zeros((nset, nao, nao))
        vkI = numpy.zeros((nset, nao, nao))
        buf1R = numpy.empty((mydf.blockdim * nao**2))
        buf2R = numpy.empty((mydf.blockdim * nao**2))
        # zeros (not empty) so pLqI reads as zero when a block has no LpqI.
        buf1I = numpy.zeros((mydf.blockdim * nao**2))
        buf2I = numpy.empty((mydf.blockdim * nao**2))
        # presumably leaves room for the K work buffers -- confirm.
        max_memory *= .5
    log.debug1('max_memory = %d MB (%d in use)', max_memory, mem_now)

    def contract_k(pLqR, pLqI):
        # K ~ 'iLj,lLk*,li->kj' + 'lLk*,iLj,li->kj'
        #:pLq = (LpqR + LpqI.reshape(-1,nao,nao)*1j).transpose(1,0,2)
        #:tmp = numpy.dot(dm, pLq.reshape(nao,-1))
        #:vk += numpy.dot(pLq.reshape(-1,nao).conj().T, tmp.reshape(-1,nao))
        nrow = pLqR.shape[1]
        tmpR = numpy.ndarray((nao, nrow * nao), buffer=buf2R)
        if k_real:
            for i in range(nset):
                lib.ddot(dmsR[i], pLqR.reshape(nao, -1), 1, tmpR)
                lib.ddot(pLqR.reshape(-1, nao).T, tmpR.reshape(-1, nao), 1,
                         vkR[i], 1)
        else:
            tmpI = numpy.ndarray((nao, nrow * nao), buffer=buf2I)
            for i in range(nset):
                zdotNN(dmsR[i], dmsI[i], pLqR.reshape(nao, -1),
                       pLqI.reshape(nao, -1), 1, tmpR, tmpI, 0)
                zdotCN(pLqR.reshape(-1, nao).T, pLqI.reshape(-1, nao).T,
                       tmpR.reshape(-1, nao), tmpI.reshape(-1, nao),
                       1, vkR[i], vkI[i], 1)

    pLqI = None
    thread_k = None
    for LpqR, LpqI in mydf.sr_loop(kptii, max_memory, False):
        LpqR = LpqR.reshape(-1, nao, nao)
        t1 = log.timer_debug1(' load', *t1)
        # The previous K contraction must finish before buf1R/buf1I and the
        # vk accumulators are touched again.
        if thread_k is not None:
            thread_k.join()
        if with_j:
            #:rho_coeff = numpy.einsum('Lpq,xqp->xL', Lpq, dms)
            #:vj += numpy.dot(rho_coeff, Lpq.reshape(-1,nao**2))
            rhoR = numpy.einsum('Lpq,xpq->xL', LpqR, dmsR)
            if not j_real:
                LpqI = LpqI.reshape(-1, nao, nao)
                rhoR -= numpy.einsum('Lpq,xpq->xL', LpqI, dmsI)
                rhoI = numpy.einsum('Lpq,xpq->xL', LpqR, dmsI)
                rhoI += numpy.einsum('Lpq,xpq->xL', LpqI, dmsR)
            vjR += numpy.einsum('xL,Lpq->xpq', rhoR, LpqR)
            if not j_real:
                vjR -= numpy.einsum('xL,Lpq->xpq', rhoI, LpqI)
                vjI += numpy.einsum('xL,Lpq->xpq', rhoR, LpqI)
                vjI += numpy.einsum('xL,Lpq->xpq', rhoI, LpqR)
        t1 = log.timer_debug1(' with_j', *t1)

        if with_k:
            nrow = LpqR.shape[0]
            pLqR = numpy.ndarray((nao, nrow, nao), buffer=buf1R)
            pLqR[:] = LpqR.transpose(1, 0, 2)
            if not k_real:
                pLqI = numpy.ndarray((nao, nrow, nao), buffer=buf1I)
                if LpqI is not None:
                    pLqI[:] = LpqI.reshape(-1, nao, nao).transpose(1, 0, 2)
            thread_k = lib.background_thread(contract_k, pLqR, pLqI)
            t1 = log.timer_debug1(' with_k', *t1)
        LpqR = LpqI = pLqR = pLqI = None
    if thread_k is not None:
        thread_k.join()
    thread_k = None

    if with_j:
        if j_real:
            vj = vjR
        else:
            vj = vjR + vjI * 1j
        vj = vj.reshape(dm.shape)
    if with_k:
        if k_real:
            vk = vkR
        else:
            vk = vkR + vkI * 1j
        if exxdiv:
            assert (exxdiv.lower() == 'ewald')
            _ewald_exxdiv_for_G0(cell, kpt, dms, vk)
        vk = vk.reshape(dm.shape)

    t1 = log.timer('sr jk', *t1)
    return vj, vk
def kernel(mycc, eris, t1=None, t2=None, verbose=logger.NOTE):
    '''CCSD(T) correction energy.

    The sorted vvop integrals are staged in a temporary HDF5 file and
    processed in blocks of virtual indices (a, b), with each block's C
    contraction (``_ccsd.libcc.CCsd_t_contract``) run in a background thread
    while the next b block is read.

    Args:
        mycc: CCSD object; supplies default ``t1``/``t2``, ``stdout`` and
            ``max_memory``.
        eris: integrals forwarded to ``_sort_eri`` and ``_sort_t2_vooo``.
        t1, t2: amplitudes; default to ``mycc.t1`` / ``mycc.t2``.  Per the
            inline note, t2's memory is reused as a cache for t2T; its
            original contents are restored from the temporary file before a
            normal return.
        verbose: a ``logger.Logger`` instance or a verbosity level.

    Returns:
        The (T) correction: twice the sum returned by the block contractions.
    '''
    if isinstance(verbose, logger.Logger):
        log = verbose
    else:
        log = logger.Logger(mycc.stdout, verbose)

    # time.clock() was removed in Python 3.8; process_time() replaces it.
    cpu0 = (time.process_time(), time.time())
    if t1 is None:
        t1 = mycc.t1
    if t2 is None:
        t2 = mycc.t2

    nocc, nvir = t1.shape
    nmo = nocc + nvir

    _tmpfile = tempfile.NamedTemporaryFile(dir=lib.param.TMPDIR)
    # Explicit 'w': h5py >= 3 opens files read-only by default, which fails
    # on the freshly created (empty) temporary file.
    ftmp = h5py.File(_tmpfile.name, 'w')
    try:
        eris_vvop = ftmp.create_dataset('vvop', (nvir,nvir,nocc,nmo), 'f8')
        orbsym = _sort_eri(mycc, eris, nocc, nvir, eris_vvop, log)

        ftmp['t2'] = t2  # read back late. Cache t2T in t2 to reduce memory footprint
        mo_energy, t1T, t2T, vooo = _sort_t2_vooo(mycc, orbsym, t1, t2, eris)
        cpu2 = [time.process_time(), time.time()]
        orbsym = numpy.hstack((numpy.sort(orbsym[:nocc]),
                               numpy.sort(orbsym[nocc:])))
        o_ir_loc = numpy.append(0, numpy.cumsum(numpy.bincount(orbsym[:nocc], minlength=8)))
        v_ir_loc = numpy.append(0, numpy.cumsum(numpy.bincount(orbsym[nocc:], minlength=8)))
        o_sym = orbsym[:nocc]
        oo_sym = (o_sym[:,None] ^ o_sym).ravel()
        oo_ir_loc = numpy.append(0, numpy.cumsum(numpy.bincount(oo_sym, minlength=8)))
        nirrep = max(oo_sym) + 1

        # The C driver receives raw pointers, so the dtypes must be int32.
        orbsym = orbsym.astype(numpy.int32)
        o_ir_loc = o_ir_loc.astype(numpy.int32)
        v_ir_loc = v_ir_loc.astype(numpy.int32)
        oo_ir_loc = oo_ir_loc.astype(numpy.int32)

        def contract(a0, a1, b0, b1, cache):
            # Contribution of virtual blocks [a0,a1) x [b0,b1) from the C kernel.
            cache_row_a, cache_col_a, cache_row_b, cache_col_b = cache
            drv = _ccsd.libcc.CCsd_t_contract
            drv.restype = ctypes.c_double
            et = drv(mo_energy.ctypes.data_as(ctypes.c_void_p),
                     t1T.ctypes.data_as(ctypes.c_void_p),
                     t2T.ctypes.data_as(ctypes.c_void_p),
                     vooo.ctypes.data_as(ctypes.c_void_p),
                     ctypes.c_int(nocc), ctypes.c_int(nvir),
                     ctypes.c_int(a0), ctypes.c_int(a1),
                     ctypes.c_int(b0), ctypes.c_int(b1),
                     ctypes.c_int(nirrep),
                     o_ir_loc.ctypes.data_as(ctypes.c_void_p),
                     v_ir_loc.ctypes.data_as(ctypes.c_void_p),
                     oo_ir_loc.ctypes.data_as(ctypes.c_void_p),
                     orbsym.ctypes.data_as(ctypes.c_void_p),
                     cache_row_a.ctypes.data_as(ctypes.c_void_p),
                     cache_col_a.ctypes.data_as(ctypes.c_void_p),
                     cache_row_b.ctypes.data_as(ctypes.c_void_p),
                     cache_col_b.ctypes.data_as(ctypes.c_void_p))
            cpu2[:] = log.timer_debug1('contract %d:%d,%d:%d'%(a0,a1,b0,b1), *cpu2)
            return et

        def tril_prange(start, stop, step):
            # Partition [start, stop) with a cost model growing as index**2.
            cum_costs = numpy.arange(stop+1)**2
            tasks = balance_partition(cum_costs, step, start, stop)
            return tasks

        # The rest 20% memory for cache b
        mem_now = lib.current_memory()[0]
        max_memory = max(2000, mycc.max_memory - mem_now)
        bufsize = max(1, (max_memory*1e6/8-nocc**3*100)*.7/(nocc*nmo))
        log.debug('max_memory %d MB (%d MB in use)', max_memory, mem_now)
        et = 0
        handler = None
        for a0, a1, na in reversed(tril_prange(0, nvir, bufsize)):
            if handler is not None:
                et += handler.get()
                handler = None
                gc.collect()
            # DO NOT prefetch here to reserve more memory for cache_a
            cache_row_a = numpy.asarray(eris_vvop[a0:a1,:a1])
            cache_col_a = numpy.asarray(eris_vvop[:a0,a0:a1])
            handler = lib.background_thread(contract, a0, a1, a0, a1,
                                            (cache_row_a,cache_col_a,
                                             cache_row_a,cache_col_a))

            for b0, b1, nb in tril_prange(0, a0, bufsize/10):
                # Read the b block while the previous contraction still runs.
                cache_row_b = numpy.asarray(eris_vvop[b0:b1,:b1])
                cache_col_b = numpy.asarray(eris_vvop[:b0,b0:b1])
                if handler is not None:
                    et += handler.get()
                    handler = None
                    gc.collect()
                handler = lib.background_thread(contract, a0, a1, b0, b1,
                                                (cache_row_a,cache_col_a,
                                                 cache_row_b,cache_col_b))
                cache_row_b = cache_col_b = None
            cache_row_a = cache_col_a = None
        if handler is not None:
            et += handler.get()
            handler = None

        # Undo the t2T caching done in t2's memory.
        t2[:] = ftmp['t2']
    finally:
        # Always release the HDF5 handle and the temporary file.
        ftmp.close()
        _tmpfile = None
    et *= 2
    log.timer('CCSD(T)', *cpu0)
    log.note('CCSD(T) correction = %.15g', et)
    return et
def get_jk(mydf, dm, hermi=1, kpt=numpy.zeros(3), kpt_band=None,
           with_j=True, with_k=True, exxdiv=None):
    """JK for given k-point

    Builds the Coulomb (J) and exchange (K) matrices of ``dm`` at the single
    k-point ``kpt`` by streaming density-fitted 3-index integrals from
    ``mydf.sr_loop``.  The K contraction of each integral block runs in a
    background thread while the next block is loaded and J is accumulated.

    If ``kpt_band`` is given and differs from ``kpt``, ``[dm]`` is passed to
    ``get_k_kpts`` / ``get_j_kpts`` and their results are returned as-is.

    Args:
        mydf: density-fitting object (provides ``_cderi``, ``sr_loop``,
            ``build``, ``blockdim``, ``max_memory``).
        dm: density matrix (or matrices) at ``kpt``.
        hermi: hermiticity flag forwarded to the multi-k-point routines.
        kpt: the k-point.
        kpt_band: optional band k-point.
        with_j, with_k: whether to compute J and/or K.
        exxdiv: None, or "ewald" to apply the G=0 exchange correction.

    Returns:
        (vj, vk): on the main path each is reshaped to ``dm.shape``; an entry
        is None when the corresponding ``with_j`` / ``with_k`` is False.
    """
    vj = vk = None
    if kpt_band is not None and abs(kpt - kpt_band).sum() > 1e-9:
        kpt = numpy.reshape(kpt, (1, 3))
        if with_k:
            vk = get_k_kpts(mydf, [dm], hermi, kpt, kpt_band, exxdiv)
        if with_j:
            vj = get_j_kpts(mydf, [dm], hermi, kpt, kpt_band)
        return vj, vk

    cell = mydf.cell
    log = logger.Logger(mydf.stdout, mydf.verbose)
    # time.clock() was removed in Python 3.8; process_time() is the CPU-time
    # replacement.  NOTE(review): assumes the logger's timers also use
    # process CPU time -- confirm.
    t1 = (time.process_time(), time.time())
    if mydf._cderi is None:
        mydf.build()
        t1 = log.timer_debug1("Init get_jk", *t1)

    dm = numpy.asarray(dm, order="C")
    dms = _format_dms(dm, [kpt])
    nset, _, nao = dms.shape[:3]
    dms = dms.reshape(nset, nao, nao)
    j_real = gamma_point(kpt)
    k_real = gamma_point(kpt) and not numpy.iscomplexobj(dms)

    kptii = numpy.asarray((kpt, kpt))
    dmsR = dms.real.reshape(nset, nao, nao)
    dmsI = dms.imag.reshape(nset, nao, nao)
    mem_now = lib.current_memory()[0]
    max_memory = max(2000, (mydf.max_memory - mem_now))
    if with_j:
        vjR = numpy.zeros((nset, nao, nao))
        vjI = numpy.zeros((nset, nao, nao))
    if with_k:
        vkR = numpy.zeros((nset, nao, nao))
        vkI = numpy.zeros((nset, nao, nao))
        buf1R = numpy.empty((mydf.blockdim * nao ** 2))
        buf2R = numpy.empty((mydf.blockdim * nao ** 2))
        # zeros (not empty) so pLqI reads as zero when a block has no LpqI.
        buf1I = numpy.zeros((mydf.blockdim * nao ** 2))
        buf2I = numpy.empty((mydf.blockdim * nao ** 2))
        # presumably leaves room for the K work buffers -- confirm.
        max_memory *= 0.5
    log.debug1("max_memory = %d MB (%d in use)", max_memory, mem_now)

    def contract_k(pLqR, pLqI):
        # K ~ 'iLj,lLk*,li->kj' + 'lLk*,iLj,li->kj'
        #:pLq = (LpqR + LpqI.reshape(-1,nao,nao)*1j).transpose(1,0,2)
        #:tmp = numpy.dot(dm, pLq.reshape(nao,-1))
        #:vk += numpy.dot(pLq.reshape(-1,nao).conj().T, tmp.reshape(-1,nao))
        nrow = pLqR.shape[1]
        tmpR = numpy.ndarray((nao, nrow * nao), buffer=buf2R)
        if k_real:
            for i in range(nset):
                lib.ddot(dmsR[i], pLqR.reshape(nao, -1), 1, tmpR)
                lib.ddot(pLqR.reshape(-1, nao).T, tmpR.reshape(-1, nao), 1,
                         vkR[i], 1)
        else:
            tmpI = numpy.ndarray((nao, nrow * nao), buffer=buf2I)
            for i in range(nset):
                zdotNN(
                    dmsR[i], dmsI[i],
                    pLqR.reshape(nao, -1), pLqI.reshape(nao, -1),
                    1, tmpR, tmpI, 0,
                )
                zdotCN(
                    pLqR.reshape(-1, nao).T,
                    pLqI.reshape(-1, nao).T,
                    tmpR.reshape(-1, nao),
                    tmpI.reshape(-1, nao),
                    1,
                    vkR[i],
                    vkI[i],
                    1,
                )

    pLqI = None
    thread_k = None
    for LpqR, LpqI in mydf.sr_loop(kptii, max_memory, False):
        LpqR = LpqR.reshape(-1, nao, nao)
        t1 = log.timer_debug1(" load", *t1)
        # The previous K contraction must finish before buf1R/buf1I and the
        # vk accumulators are touched again.
        if thread_k is not None:
            thread_k.join()
        if with_j:
            #:rho_coeff = numpy.einsum('Lpq,xqp->xL', Lpq, dms)
            #:vj += numpy.dot(rho_coeff, Lpq.reshape(-1,nao**2))
            rhoR = numpy.einsum("Lpq,xpq->xL", LpqR, dmsR)
            if not j_real:
                LpqI = LpqI.reshape(-1, nao, nao)
                rhoR -= numpy.einsum("Lpq,xpq->xL", LpqI, dmsI)
                rhoI = numpy.einsum("Lpq,xpq->xL", LpqR, dmsI)
                rhoI += numpy.einsum("Lpq,xpq->xL", LpqI, dmsR)
            vjR += numpy.einsum("xL,Lpq->xpq", rhoR, LpqR)
            if not j_real:
                vjR -= numpy.einsum("xL,Lpq->xpq", rhoI, LpqI)
                vjI += numpy.einsum("xL,Lpq->xpq", rhoR, LpqI)
                vjI += numpy.einsum("xL,Lpq->xpq", rhoI, LpqR)
        t1 = log.timer_debug1(" with_j", *t1)

        if with_k:
            nrow = LpqR.shape[0]
            pLqR = numpy.ndarray((nao, nrow, nao), buffer=buf1R)
            pLqR[:] = LpqR.transpose(1, 0, 2)
            if not k_real:
                pLqI = numpy.ndarray((nao, nrow, nao), buffer=buf1I)
                if LpqI is not None:
                    pLqI[:] = LpqI.reshape(-1, nao, nao).transpose(1, 0, 2)
            thread_k = lib.background_thread(contract_k, pLqR, pLqI)
            t1 = log.timer_debug1(" with_k", *t1)
        LpqR = LpqI = pLqR = pLqI = None
    if thread_k is not None:
        thread_k.join()
    thread_k = None

    if with_j:
        if j_real:
            vj = vjR
        else:
            vj = vjR + vjI * 1j
        vj = vj.reshape(dm.shape)
    if with_k:
        if k_real:
            vk = vkR
        else:
            vk = vkR + vkI * 1j
        if exxdiv is not None:
            assert exxdiv.lower() == "ewald"
            _ewald_exxdiv_for_G0(cell, kpt, dms, vk)
        vk = vk.reshape(dm.shape)

    t1 = log.timer("sr jk", *t1)
    return vj, vk
def async_do(handler, fn, *args):
    '''Join ``handler`` if it is not None, then run ``fn(*args)`` in a background thread.

    Returns the new thread handle.  Passing it back on the next call means
    each new task starts only after the previous one has finished.
    '''
    previous = handler
    if previous is not None:
        previous.join()
    return lib.background_thread(fn, *args)
def async_write(istep, iobuf, thread_io):
    '''Run ``save(istep, iobuf)`` in a background thread.

    An earlier write handle ``thread_io`` (if not None) is joined first, so
    when callers pass back the handle returned by the previous call, at most
    one write is in flight.  Returns the new thread handle.
    '''
    pending = thread_io
    if pending is not None:
        pending.join()
    return lib.background_thread(save, istep, iobuf)
def async_write(icomp, row0, row1, buf, thread_io):
    '''Run ``save(icomp, row0, row1, buf)`` in a background thread.

    An earlier write handle ``thread_io`` (if not None) is joined first, so
    when callers pass back the handle returned by the previous call, at most
    one write is in flight.  Returns the new thread handle.
    '''
    pending = thread_io
    if pending is not None:
        pending.join()
    return lib.background_thread(save, icomp, row0, row1, buf)
def update_amps(mycc, t1, t2, eris):
    '''One CCSD amplitude update; returns ``(t1new, t2new)``.

    The occupied index is processed in blocks of ``blksize`` and the packed
    ovvv integrals additionally in virtual blocks of ``blknvir`` to bound
    memory.  Integral blocks are prefetched in background threads while the
    current block is contracted.  The vvvv contribution comes from
    ``mycc.add_wvvVV_``.  When ``mycc.direct`` is False, the tau*ovvv term is
    also computed here.  Both amplitudes are finally divided by orbital-energy
    denominators built from ``eris.fock.diagonal()``.
    '''
    # time.clock() was removed in Python 3.8; process_time() replaces it.
    time0 = time.process_time(), time.time()
    log = logger.Logger(mycc.stdout, mycc.verbose)
    nocc, nvir = t1.shape
    nov = nocc*nvir
    fock = eris.fock

    t1new = numpy.zeros_like(t1)
    t2new = numpy.zeros_like(t2)
    # t2new_tril stores the i>=j occupied pairs at packed index i*(i+1)//2+j.
    t2new_tril = numpy.zeros((nocc*(nocc+1)//2,nvir,nvir))
    mycc.add_wvvVV_(t1, t2, eris, t2new_tril)
    for i in range(nocc):
        for j in range(i+1):
            t2new[i,j] = t2new_tril[i*(i+1)//2+j]
        # The final symmetrization adds t2new[i,i].T onto t2new[i,i].
        t2new[i,i] *= .5
    t2new_tril = None
    time1 = log.timer_debug1('vvvv', *time0)

    #** make_inter_F
    fov = fock[:nocc,nocc:].copy()
    t1new += fov

    foo = fock[:nocc,:nocc].copy()
    foo[range(nocc),range(nocc)] = 0
    foo += .5 * numpy.einsum('ia,ja->ij', fock[:nocc,nocc:], t1)

    fvv = fock[nocc:,nocc:].copy()
    fvv[range(nvir),range(nvir)] = 0
    fvv -= .5 * numpy.einsum('ia,ib->ab', t1, fock[:nocc,nocc:])

    #: woooo = numpy.einsum('la,ikja->ikjl', t1, eris.ooov)
    eris_ooov = _cp(eris.ooov)
    foo += numpy.einsum('kc,jikc->ij', 2*t1, eris_ooov)
    foo += numpy.einsum('kc,jkic->ij', -t1, eris_ooov)
    woooo = lib.ddot(eris_ooov.reshape(-1,nvir), t1.T).reshape((nocc,)*4)
    woooo = lib.transpose_sum(woooo.reshape(nocc*nocc,-1), inplace=True)
    woooo += _cp(eris.oooo).reshape(nocc**2,-1)
    woooo = _cp(woooo.reshape(nocc,nocc,nocc,nocc).transpose(0,2,1,3))
    eris_ooov = None
    time1 = log.timer_debug1('woooo', *time1)

    unit = _memory_usage_inloop(nocc, nvir)
    max_memory = max(2000, mycc.max_memory - lib.current_memory()[0])
    blksize = min(nocc, max(BLKMIN, int(max_memory/unit)))
    blknvir = int((max_memory*.9e6/8-blksize*nocc*nvir**2*6)/(blksize*nvir**2*2))
    blknvir = min(nvir, max(BLKMIN, blknvir))
    log.debug1('max_memory %d MB, nocc,nvir = %d,%d blksize = %d,%d',
               max_memory, nocc, nvir, blksize, blknvir)
    nvir_pair = nvir * (nvir+1) // 2

    def prefect_ovvv(p0, p1, q0, q1, prefetch):
        # Load the virtual block that follows [q0,q1) into ``prefetch``.
        # After the last block there is nothing left to read (the original
        # re-read the current block here, which was never used).
        if q1 == nvir:
            return
        q0, q1 = q1, min(nvir, q1+blknvir)
        readbuf = numpy.ndarray((p1-p0,q1-q0,nvir_pair), buffer=prefetch)
        readbuf[:] = eris.ovvv[p0:p1,q0:q1]

    def prefect_ovov(p0, p1, buf):
        buf[:] = eris.ovov[p0:p1]

    def prefect_oovv(p0, p1, buf):
        buf[:] = eris.oovv[p0:p1]

    buflen = max(nocc*nvir**2, nocc**3)
    bufs = numpy.empty((5,blksize*buflen))
    buf1, buf2, buf3, buf4, buf5 = bufs
    for p0, p1 in prange(0, nocc, blksize):
        #: wOoVv += numpy.einsum('iabc,jc->ijab', eris.ovvv, t1)
        #: wOoVv -= numpy.einsum('jbik,ka->jiba', eris.ovoo, t1)
        wOoVv = numpy.ndarray((nocc,p1-p0,nvir,nvir), buffer=buf3)
        wooVV = numpy.ndarray((p1-p0,nocc,nvir,nvir), buffer=buf4)
        handler = None
        readbuf = numpy.empty((p1-p0,blknvir,nvir_pair))
        prefetchbuf = numpy.empty((p1-p0,blknvir,nvir_pair))
        ovvvbuf = numpy.empty((p1-p0,blknvir,nvir,nvir))
        for q0, q1 in lib.prange(0, nvir, blknvir):
            if q0 == 0:
                readbuf[:] = eris.ovvv[p0:p1,q0:q1]
            else:
                # The previous iteration's prefetch holds this block.
                readbuf, prefetchbuf = prefetchbuf, readbuf
            handler = async_do(handler, prefect_ovvv, p0, p1, q0, q1, prefetchbuf)
            eris_ovvv = numpy.ndarray(((p1-p0)*(q1-q0),nvir_pair), buffer=readbuf)
            #:eris_ovvv = _cp(eris.ovvv[p0:p1,q0:q1])
            eris_ovvv = lib.unpack_tril(eris_ovvv, out=ovvvbuf)
            eris_ovvv = eris_ovvv.reshape(p1-p0,q1-q0,nvir,nvir)

            #: tau = t2 + numpy.einsum('ia,jb->ijab', t1, t1)
            #: tmp = numpy.einsum('ijcd,kcdb->ijbk', tau, eris.ovvv)
            #: t2new += numpy.einsum('ka,ijbk->ijab', -t1, tmp)
            if not mycc.direct:
                eris_vovv = lib.transpose(eris_ovvv.reshape(-1,nvir))
                eris_vovv = eris_vovv.reshape(nvir*(p1-p0),-1)
                tmp = numpy.ndarray((nocc,nocc,nvir,p1-p0), buffer=buf1)
                for j0, j1 in prange(0, nocc, blksize):
                    tau = numpy.ndarray((j1-j0,nocc,q1-q0,nvir), buffer=buf2)
                    tau = numpy.einsum('ia,jb->ijab', t1[j0:j1,q0:q1], t1, out=tau)
                    tau += t2[j0:j1,:,q0:q1]
                    lib.ddot(tau.reshape((j1-j0)*nocc,-1), eris_vovv.T, 1,
                             tmp[j0:j1].reshape((j1-j0)*nocc,-1), 0)
                tmp1 = numpy.ndarray((nocc,nocc,nvir,p1-p0), buffer=buf2)
                tmp1[:] = tmp.transpose(1,0,2,3)
                lib.ddot(tmp1.reshape(-1,p1-p0), t1[p0:p1], -1,
                         t2new.reshape(-1,nvir), 1)
                eris_vovv = tau = tmp1 = tmp = None

            fvv += numpy.einsum('kc,kcba->ab', 2*t1[p0:p1,q0:q1], eris_ovvv)
            fvv[:,q0:q1] += numpy.einsum('kc,kbca->ab', -t1[p0:p1], eris_ovvv)

            #: wooVV -= numpy.einsum('jc,icba->ijba', t1, eris_ovvv)
            # The q-block is the summed index c.  Overwrite the reused wooVV
            # buffer on the first block and accumulate on later blocks.
            # With beta=0 on every block, only the last block's partial sum
            # would survive whenever blknvir < nvir.
            beta = 0 if q0 == 0 else 1
            tmp = t1[:,q0:q1].copy()
            for i in range(eris_ovvv.shape[0]):
                lib.ddot(tmp, eris_ovvv[i].reshape(q1-q0,-1), -1,
                         wooVV[i].reshape(nocc,-1), beta)

            #: wOoVv += numpy.einsum('ibac,jc->jiba', eris_ovvv, t1)
            tmp = numpy.ndarray((nocc,p1-p0,q1-q0,nvir), buffer=buf1)
            lib.ddot(t1, eris_ovvv.reshape(-1,nvir).T, 1, tmp.reshape(nocc,-1))
            wOoVv[:,:,q0:q1] = tmp

            #: theta = t2.transpose(1,0,2,3) * 2 - t2
            #: t1new += numpy.einsum('ijcb,jcba->ia', theta, eris.ovvv)
            theta = tmp
            theta[:] = t2[p0:p1,:,q0:q1,:].transpose(1,0,2,3)
            theta *= 2
            theta -= t2[:,p0:p1,q0:q1,:]
            lib.ddot(theta.reshape(nocc,-1), eris_ovvv.reshape(-1,nvir), 1, t1new, 1)
            theta = tmp = None
        handler.join()
        readbuf = prefetchbuf = ovvvbuf = eris_ovvv = None
        time2 = log.timer_debug1('ovvv [%d:%d]'%(p0, p1), *time1)

        tmp = numpy.ndarray((nocc,p1-p0,nvir,nocc), buffer=buf1)
        tmp[:] = _cp(eris.ovoo[p0:p1]).transpose(2,0,1,3)
        lib.ddot(tmp.reshape(-1,nocc), t1, -1, wOoVv.reshape(-1,nvir), 1)

        eris_ooov = _cp(eris.ooov[p0:p1])
        eris_oovv = numpy.empty((p1-p0,nocc,nvir,nvir))
        handler = lib.background_thread(prefect_oovv, p0, p1, eris_oovv)
        tmp = numpy.ndarray((p1-p0,nocc,nvir,nocc), buffer=buf1)
        tmp[:] = eris_ooov.transpose(0,1,3,2)
        #: wooVV = numpy.einsum('ka,ijkb->ijba', t1, eris.ooov[p0:p1])
        lib.ddot(tmp.reshape(-1,nocc), t1, 1, wooVV.reshape(-1,nvir), 1)
        t2new[p0:p1] += wOoVv.transpose(1,0,2,3)

        #:eris_oovv = _cp(eris.oovv[p0:p1])
        handler.join()
        eris_ovov = numpy.empty((p1-p0,nvir,nocc,nvir))
        handler = lib.background_thread(prefect_ovov, p0, p1, eris_ovov)
        #: g2 = 2 * eris.oOVv - eris.oovv
        #: t1new += numpy.einsum('jb,ijba->ia', t1, g2)
        t1new[p0:p1] += numpy.einsum('jb,ijba->ia', -t1, eris_oovv)
        wooVV -= eris_oovv

        #tmp = numpy.einsum('ic,jkbc->jikb', t1, eris_oovv)
        #t2new[p0:p1] += numpy.einsum('ka,jikb->ijba', -t1, tmp)
        tmp1 = numpy.ndarray((nocc,nocc*nvir), buffer=buf1)
        tmp2 = numpy.ndarray((nocc*nvir,nocc), buffer=buf2)
        for j in range(p1-p0):
            tmp = lib.ddot(t1, eris_oovv[j].reshape(-1,nvir).T, 1, tmp1)
            lib.transpose(_cp(tmp).reshape(nocc,nocc,nvir), axes=(0,2,1), out=tmp2)
            t2new[:,p0+j] -= lib.ddot(tmp2, t1).reshape(nocc,nvir,nvir)
        eris_oovv = None

        #:eris_ovov = _cp(eris.ovov[p0:p1])
        handler.join()
        for i in range(p1-p0):
            t2new[p0+i] += eris_ovov[i].transpose(1,0,2) * .5
        t1new[p0:p1] += numpy.einsum('jb,iajb->ia', 2*t1, eris_ovov)
        #:tmp = numpy.einsum('ic,jbkc->jibk', t1, eris_ovov)
        #:t2new[p0:p1] += numpy.einsum('ka,jibk->jiba', -t1, tmp)
        for j in range(p1-p0):
            lib.ddot(t1, eris_ovov[j].reshape(-1,nvir).T, 1, tmp1)
            lib.ddot(tmp1.reshape(-1,nocc), t1, -1, t2new[p0+j].reshape(-1,nvir), 1)
        tmp1 = tmp2 = tmp = None

        fov[p0:p1] += numpy.einsum('kc,iakc->ia', t1, eris_ovov) * 2
        fov[p0:p1] -= numpy.einsum('kc,icka->ia', t1, eris_ovov)

        #: fvv -= numpy.einsum('ijca,ibjc->ab', theta, eris.ovov)
        #: foo += numpy.einsum('iakb,jkba->ij', eris.ovov, theta)
        tau = numpy.ndarray((nocc,nvir,nvir), buffer=buf1)
        theta = numpy.ndarray((nocc,nvir,nvir), buffer=buf2)
        for i in range(p1-p0):
            tau = numpy.einsum('a,jb->jab', t1[p0+i]*.5, t1, out=tau)
            tau += t2[p0+i]
            theta = lib.transpose(tau, axes=(0,2,1), out=theta)
            theta *= 2
            theta -= tau
            vov = lib.transpose(eris_ovov[i].reshape(nvir,-1), out=tau)
            lib.ddot(vov.reshape(nocc,-1), theta.reshape(nocc,-1).T, 1, foo, 1)
            lib.ddot(theta.reshape(-1,nvir).T, eris_ovov[i].reshape(nvir,-1).T,
                     -1, fvv, 1)
        tau = theta = vov = None

        #: theta = t2.transpose(0,2,1,3) * 2 - t2.transpose(0,3,2,1)
        #: t1new += numpy.einsum('jb,ijba->ia', fov, theta)
        #: t1new -= numpy.einsum('kijb,kjba->ia', eris_ooov, theta)
        theta = numpy.ndarray((p1-p0,nvir,nocc,nvir), buffer=buf1)
        for i in range(p1-p0):
            tmp = t2[p0+i].transpose(0,2,1) * 2
            tmp -= t2[p0+i]
            lib.ddot(eris_ooov[i].reshape(nocc,-1),
                     tmp.reshape(-1,nvir), -1, t1new, 1)
            lib.transpose(_cp(tmp).reshape(-1,nvir), out=theta[i])  # theta[i] = tmp.transpose(2,0,1)
        t1new += numpy.einsum('jb,jbia->ia', fov[p0:p1], theta)
        eris_ooov = None

        #: wOVov += eris.ovov
        #: tau = theta - numpy.einsum('ic,kb->ikcb', t1, t1*2)
        #: wOVov += .5 * numpy.einsum('jakc,ikcb->jiba', eris.ovov, tau)
        #: wOVov -= .5 * numpy.einsum('jcka,ikcb->jiba', eris.ovov, t2)
        #: t2new += numpy.einsum('ikca,kjbc->ijba', theta, wOVov)
        for i in range(p1-p0):
            wOoVv[:,i] += wooVV[i]*.5  #: jiba + ijba*.5
        wOVov = lib.transpose(wOoVv.reshape(nocc,-1,nvir), axes=(0,2,1), out=buf5)
        wOVov = wOVov.reshape(nocc,nvir,-1,nvir)
        # buf3 (wOoVv's storage) is reused here; wOoVv was consumed above.
        eris_OVov = lib.transpose(eris_ovov.reshape(-1,nov), out=buf3)
        eris_OVov = eris_OVov.reshape(nocc,nvir,-1,nvir)
        wOVov += eris_OVov
        theta = theta.reshape(-1,nov)
        for i in range(nocc):  # OVov-OVov.transpose(0,3,2,1)*.5
            eris_OVov[i] -= eris_OVov[i].transpose(2,1,0)*.5
        for j0, j1 in prange(0, nocc, blksize):
            tau = numpy.ndarray((j1-j0,nvir,nocc,nvir), buffer=buf2)
            for i in range(j1-j0):
                tau[i] = t2[j0+i].transpose(1,0,2) * 2
                tau[i] -= t2[j0+i].transpose(2,0,1)
                tau[i] -= numpy.einsum('a,jb->bja', t1[j0+i]*2, t1)
            #: wOVov[j0:j1] += .5 * numpy.einsum('iakc,jbkc->jbai', eris_ovov, tau)
            lib.ddot(tau.reshape(-1,nov), eris_OVov.reshape(nov,-1),
                     .5, wOVov[j0:j1].reshape((j1-j0)*nvir,-1), 1)

            #theta = t2[p0:p1] * 2 - t2[p0:p1].transpose(0,1,3,2)
            #: t2new[j0:j1] += numpy.einsum('iack,jbck->jiba', theta, wOVov[j0:j1])
            tmp = lib.ddot(wOVov[j0:j1].reshape((j1-j0)*nvir,-1), theta, 1,
                           tau.reshape(-1,nov)).reshape(-1,nvir,nocc,nvir)
            for i in range(j1-j0):
                t2new[j0+i] += tmp[i].transpose(1,0,2)
        theta = wOoVv = wOVov = eris_OVov = tmp = tau = None
        time2 = log.timer_debug1('wOVov [%d:%d]'%(p0, p1), *time2)

        #: tau = t2 + numpy.einsum('ia,jb->ijab', t1, t1)
        #: woooo += numpy.einsum('ijba,klab->ijkl', eris.oOVv, tau)
        #: tau = .5*t2 + numpy.einsum('ia,jb->ijab', t1, t1)
        #: woVoV += numpy.einsum('jkca,ikbc->ijba', tau, eris.oOVv)
        tmp = numpy.ndarray((p1-p0,nvir,nocc,nvir), buffer=buf1)
        tmp[:] = wooVV.transpose(0,2,1,3)
        woVoV = lib.transpose(_cp(tmp).reshape(-1,nov),
                              out=buf4).reshape(nocc,nvir,p1-p0,nvir)
        eris_oOvV = numpy.ndarray((p1-p0,nocc,nvir,nvir), buffer=buf3)
        eris_oOvV[:] = eris_ovov.transpose(0,2,1,3)
        eris_oVOv = lib.transpose(eris_oOvV.reshape(-1,nov,nvir), axes=(0,2,1), out=buf5)
        eris_oVOv = eris_oVOv.reshape(-1,nvir,nocc,nvir)

        for j0, j1 in prange(0, nocc, blksize):
            tau = make_tau(t2[j0:j1], t1[j0:j1], t1, 1, out=buf2)
            #: woooo[p0:p1,:,j0:j1] += numpy.einsum('ijab,klab->ijkl', eris_oOvV, tau)
            _dgemm('N', 'T', (p1-p0)*nocc, (j1-j0)*nocc, nvir*nvir,
                   eris_oOvV.reshape(-1,nvir*nvir), tau.reshape(-1,nvir*nvir),
                   woooo[p0:p1].reshape(-1,nocc*nocc), 1, 1, 0, 0, j0*nocc)
            for i in range(j1-j0):
                tau[i] -= t2[j0+i] * .5
            #: woVoV[j0:j1] += numpy.einsum('jkca,ickb->jiab', tau, eris_ovov)
            lib.ddot(lib.transpose(tau.reshape(-1,nov,nvir), axes=(0,2,1)).reshape(-1,nov),
                     eris_oVOv.reshape(-1,nov).T,
                     1, woVoV[j0:j1].reshape((j1-j0)*nvir,-1), 1)
        time2 = log.timer_debug1('woVoV [%d:%d]'%(p0, p1), *time2)

        tau = make_tau(t2[p0:p1], t1[p0:p1], t1, 1, out=buf2)
        #: t2new += .5 * numpy.einsum('klij,klab->ijab', woooo[p0:p1], tau)
        lib.ddot(woooo[p0:p1].reshape(-1,nocc*nocc).T, tau.reshape(-1,nvir*nvir),
                 .5, t2new.reshape(nocc*nocc,-1), 1)
        eris_ovov = eris_oVOv = eris_oOvV = wooVV = tau = tmp = None

        t2ibja = lib.transpose(_cp(t2[p0:p1]).reshape(-1,nov,nvir), axes=(0,2,1),
                               out=buf1).reshape(-1,nvir,nocc,nvir)
        tmp = numpy.ndarray((blksize,nvir,nocc,nvir), buffer=buf2)
        for j0, j1 in prange(0, nocc, blksize):
            #: t2new[j0:j1] += numpy.einsum('ibkc,kcja->ijab', woVoV[j0:j1], t2ibja)
            lib.ddot(woVoV[j0:j1].reshape((j1-j0)*nvir,-1),
                     t2ibja.reshape(-1,nov), 1, tmp[:j1-j0].reshape(-1,nov))
            for i in range(j1-j0):
                t2new[j0+i] += tmp[i].transpose(1,2,0)
                t2new[j0+i] += tmp[i].transpose(1,0,2) * .5
        woVoV = t2ibja = tmp = None
        time1 = log.timer_debug1('contract occ [%d:%d]'%(p0, p1), *time1)
    buf1 = buf2 = buf3 = buf4 = buf5 = bufs = None
    time1 = log.timer_debug1('contract loop', *time0)

    woooo = None
    ft_ij = foo + numpy.einsum('ja,ia->ij', .5*t1, fov)
    ft_ab = fvv - numpy.einsum('ia,ib->ab', .5*t1, fov)
    #: t2new += numpy.einsum('ijac,bc->ijab', t2, ft_ab)
    #: t2new -= numpy.einsum('ki,kjab->ijab', ft_ij, t2)
    lib.ddot(t2.reshape(-1,nvir), ft_ab.T, 1, t2new.reshape(-1,nvir), 1)
    lib.ddot(ft_ij.T, t2.reshape(nocc,-1), -1, t2new.reshape(nocc,-1), 1)

    mo_e = fock.diagonal()
    eia = mo_e[:nocc,None] - mo_e[None,nocc:]
    t1new += numpy.einsum('ib,ab->ia', t1, fvv)
    t1new -= numpy.einsum('ja,ji->ia', t1, foo)
    t1new /= eia

    #: t2new = t2new + t2new.transpose(1,0,3,2)
    for i in range(nocc):
        for j in range(i+1):
            t2new[i,j] += t2new[j,i].T
            t2new[i,j] /= lib.direct_sum('a,b->ab', eia[i], eia[j])
            t2new[j,i] = t2new[i,j].T
    time0 = log.timer_debug1('update t1 t2', *time0)
    return t1new, t2new
def update_amps(mycc, t1, t2, eris): time0 = time.clock(), time.time() log = logger.Logger(mycc.stdout, mycc.verbose) nocc, nvir = t1.shape nov = nocc*nvir fock = eris.fock t1t2new = numpy.zeros((nov+nov**2)) t1new = t1t2new[:nov].reshape(t1.shape) t2new = t1t2new[nov:].reshape(t2.shape) t2new_tril = numpy.zeros((nocc*(nocc+1)//2,nvir,nvir)) mycc.add_wvvVV_(t1, t2, eris, t2new_tril) idxo = numpy.tril_indices(nocc) lib.takebak_2d(t2new.reshape(nocc**2,nvir**2), t2new_tril.reshape(-1,nvir**2), idxo[0]*nocc+idxo[1], numpy.arange(nvir**2)) idxo = numpy.arange(nocc) t2new[idxo,idxo] *= .5 t2new_tril = None time1 = log.timer_debug1('vvvv', *time0) #** make_inter_F fov = fock[:nocc,nocc:].copy() t1new += fov foo = fock[:nocc,:nocc].copy() foo[range(nocc),range(nocc)] = 0 foo += .5 * numpy.einsum('ia,ja->ij', fock[:nocc,nocc:], t1) fvv = fock[nocc:,nocc:].copy() fvv[range(nvir),range(nvir)] = 0 fvv -= .5 * numpy.einsum('ia,ib->ab', t1, fock[:nocc,nocc:]) #: woooo = numpy.einsum('la,ikja->ikjl', t1, eris.ooov) eris_ooov = _cp(eris.ooov) foo += numpy.einsum('kc,jikc->ij', 2*t1, eris_ooov) foo += numpy.einsum('kc,jkic->ij', -t1, eris_ooov) woooo = lib.ddot(eris_ooov.reshape(-1,nvir), t1.T).reshape((nocc,)*4) woooo = lib.transpose_sum(woooo.reshape(nocc**2,nocc**2), inplace=True) woooo += _cp(eris.oooo).reshape(nocc**2,nocc**2) woooo = _cp(woooo.reshape(nocc,nocc,nocc,nocc).transpose(0,2,1,3)) eris_ooov = None time1 = log.timer_debug1('woooo', *time1) unit = _memory_usage_inloop(nocc, nvir) max_memory = max(2000, mycc.max_memory - lib.current_memory()[0]) blksize = min(nocc, max(BLKMIN, int(max_memory/unit))) blknvir = int((max_memory*.9e6/8-blksize*nocc*nvir**2*6)/(blksize*nvir**2*2)) blknvir = min(nvir, max(BLKMIN, blknvir)) log.debug1('max_memory %d MB, nocc,nvir = %d,%d blksize = %d,%d', max_memory, nocc, nvir, blksize, blknvir) nvir_pair = nvir * (nvir+1) // 2 def prefect_ovvv(p0, p1, q0, q1, prefetch): if q1 != nvir: q0, q1 = q1, min(nvir, q1+blknvir) readbuf = 
numpy.ndarray((p1-p0,q1-q0,nvir_pair), buffer=prefetch) readbuf[:] = eris.ovvv[p0:p1,q0:q1] def prefect_ovov(p0, p1, buf): buf[:] = eris.ovov[p0:p1] def prefect_oovv(p0, p1, buf): buf[:] = eris.oovv[p0:p1] buflen = max(nocc*nvir**2, nocc**3) bufs = numpy.empty((5,blksize*buflen)) buf1, buf2, buf3, buf4, buf5 = bufs for p0, p1 in prange(0, nocc, blksize): #: wOoVv += numpy.einsum('iabc,jc->ijab', eris.ovvv, t1) #: wOoVv -= numpy.einsum('jbik,ka->jiba', eris.ovoo, t1) wOoVv = numpy.ndarray((nocc,p1-p0,nvir,nvir), buffer=buf3) wooVV = numpy.ndarray((p1-p0,nocc,nvir,nvir), buffer=buf4) handler = None readbuf = numpy.empty((p1-p0,blknvir,nvir_pair)) prefetchbuf = numpy.empty((p1-p0,blknvir,nvir_pair)) ovvvbuf = numpy.empty((p1-p0,blknvir,nvir,nvir)) for q0, q1 in lib.prange(0, nvir, blknvir): if q0 == 0: readbuf[:] = eris.ovvv[p0:p1,q0:q1] else: readbuf, prefetchbuf = prefetchbuf, readbuf handler = async_do(handler, prefect_ovvv, p0, p1, q0, q1, prefetchbuf) eris_ovvv = numpy.ndarray(((p1-p0)*(q1-q0),nvir_pair), buffer=readbuf) #:eris_ovvv = _cp(eris.ovvv[p0:p1,q0:q1]) eris_ovvv = lib.unpack_tril(eris_ovvv, out=ovvvbuf) eris_ovvv = eris_ovvv.reshape(p1-p0,q1-q0,nvir,nvir) #: tau = t2 + numpy.einsum('ia,jb->ijab', t1, t1) #: tmp = numpy.einsum('ijcd,kcdb->ijbk', tau, eris.ovvv) #: t2new += numpy.einsum('ka,ijbk->ijab', -t1, tmp) if not mycc.direct: eris_vovv = lib.transpose(eris_ovvv.reshape(-1,nvir)) eris_vovv = eris_vovv.reshape(nvir*(p1-p0),-1) tmp = numpy.ndarray((nocc,nocc,nvir,p1-p0), buffer=buf1) for j0, j1 in prange(0, nocc, blksize): tau = numpy.ndarray((j1-j0,nocc,q1-q0,nvir), buffer=buf2) tau = numpy.einsum('ia,jb->ijab', t1[j0:j1,q0:q1], t1, out=tau) tau += t2[j0:j1,:,q0:q1] lib.ddot(tau.reshape((j1-j0)*nocc,-1), eris_vovv.T, 1, tmp[j0:j1].reshape((j1-j0)*nocc,-1), 0) tmp1 = numpy.ndarray((nocc,nocc,nvir,p1-p0), buffer=buf2) tmp1[:] = tmp.transpose(1,0,2,3) lib.ddot(tmp1.reshape(-1,p1-p0), t1[p0:p1], -1, t2new.reshape(-1,nvir), 1) eris_vovv = tau = tmp1 = tmp 
= None fvv += numpy.einsum('kc,kcba->ab', 2*t1[p0:p1,q0:q1], eris_ovvv) fvv[:,q0:q1] += numpy.einsum('kc,kbca->ab', -t1[p0:p1], eris_ovvv) #: wooVV -= numpy.einsum('jc,icba->ijba', t1, eris_ovvv) tmp = t1[:,q0:q1].copy() for i in range(eris_ovvv.shape[0]): lib.ddot(tmp, eris_ovvv[i].reshape(q1-q0,-1), -1, wooVV[i].reshape(nocc,-1)) #: wOoVv += numpy.einsum('ibac,jc->jiba', eris_ovvv, t1) tmp = numpy.ndarray((nocc,p1-p0,q1-q0,nvir), buffer=buf1) lib.ddot(t1, eris_ovvv.reshape(-1,nvir).T, 1, tmp.reshape(nocc,-1)) wOoVv[:,:,q0:q1] = tmp #: theta = t2.transpose(1,0,2,3) * 2 - t2 #: t1new += numpy.einsum('ijcb,jcba->ia', theta, eris.ovvv) theta = tmp theta[:] = t2[p0:p1,:,q0:q1,:].transpose(1,0,2,3) theta *= 2 theta -= t2[:,p0:p1,q0:q1,:] lib.ddot(theta.reshape(nocc,-1), eris_ovvv.reshape(-1,nvir), 1, t1new, 1) theta = tmp = None handler.join() readbuf = prefetchbuf = ovvvbuf = eris_ovvv = None time2 = log.timer_debug1('ovvv [%d:%d]'%(p0, p1), *time1) tmp = numpy.ndarray((nocc,p1-p0,nvir,nocc), buffer=buf1) tmp[:] = _cp(eris.ovoo[p0:p1]).transpose(2,0,1,3) lib.ddot(tmp.reshape(-1,nocc), t1, -1, wOoVv.reshape(-1,nvir), 1) eris_ooov = _cp(eris.ooov[p0:p1]) eris_oovv = numpy.empty((p1-p0,nocc,nvir,nvir)) handler = lib.background_thread(prefect_oovv, p0, p1, eris_oovv) tmp = numpy.ndarray((p1-p0,nocc,nvir,nocc), buffer=buf1) tmp[:] = eris_ooov.transpose(0,1,3,2) #: wooVV = numpy.einsum('ka,ijkb->ijba', t1, eris.ooov[p0:p1]) lib.ddot(tmp.reshape(-1,nocc), t1, 1, wooVV.reshape(-1,nvir), 1) t2new[p0:p1] += wOoVv.transpose(1,0,2,3) #:eris_oovv = _cp(eris.oovv[p0:p1]) handler.join() eris_ovov = numpy.empty((p1-p0,nvir,nocc,nvir)) handler = lib.background_thread(prefect_ovov, p0, p1, eris_ovov) #: g2 = 2 * eris.oOVv - eris.oovv #: t1new += numpy.einsum('jb,ijba->ia', t1, g2) t1new[p0:p1] += numpy.einsum('jb,ijba->ia', -t1, eris_oovv) wooVV -= eris_oovv #tmp = numpy.einsum('ic,jkbc->jikb', t1, eris_oovv) #t2new[p0:p1] += numpy.einsum('ka,jikb->ijba', -t1, tmp) tmp1 = 
numpy.ndarray((nocc,nocc*nvir), buffer=buf1) tmp2 = numpy.ndarray((nocc*nvir,nocc), buffer=buf2) for j in range(p1-p0): tmp = lib.ddot(t1, eris_oovv[j].reshape(-1,nvir).T, 1, tmp1) lib.transpose(_cp(tmp).reshape(nocc,nocc,nvir), axes=(0,2,1), out=tmp2) t2new[:,p0+j] -= lib.ddot(tmp2, t1).reshape(nocc,nvir,nvir) eris_oovv = None #:eris_ovov = _cp(eris.ovov[p0:p1]) handler.join() for i in range(p1-p0): t2new[p0+i] += eris_ovov[i].transpose(1,0,2) * .5 t1new[p0:p1] += numpy.einsum('jb,iajb->ia', 2*t1, eris_ovov) #:tmp = numpy.einsum('ic,jbkc->jibk', t1, eris_ovov) #:t2new[p0:p1] += numpy.einsum('ka,jibk->jiba', -t1, tmp) for j in range(p1-p0): lib.ddot(t1, eris_ovov[j].reshape(-1,nvir).T, 1, tmp1) lib.ddot(tmp1.reshape(-1,nocc), t1, -1, t2new[p0+j].reshape(-1,nvir), 1) tmp1 = tmp2 = tmp = None fov[p0:p1] += numpy.einsum('kc,iakc->ia', t1, eris_ovov) * 2 fov[p0:p1] -= numpy.einsum('kc,icka->ia', t1, eris_ovov) #: fvv -= numpy.einsum('ijca,ibjc->ab', theta, eris.ovov) #: foo += numpy.einsum('iakb,jkba->ij', eris.ovov, theta) tau = numpy.ndarray((nocc,nvir,nvir), buffer=buf1) theta = numpy.ndarray((nocc,nvir,nvir), buffer=buf2) for i in range(p1-p0): tau = numpy.einsum('a,jb->jab', t1[p0+i]*.5, t1, out=tau) tau += t2[p0+i] theta = lib.transpose(tau, axes=(0,2,1), out=theta) theta *= 2 theta -= tau vov = lib.transpose(eris_ovov[i].reshape(nvir,-1), out=tau) lib.ddot(vov.reshape(nocc,-1), theta.reshape(nocc,-1).T, 1, foo, 1) lib.ddot(theta.reshape(-1,nvir).T, eris_ovov[i].reshape(nvir,-1).T, -1, fvv, 1) tau = theta = vov = None #: theta = t2.transpose(0,2,1,3) * 2 - t2.transpose(0,3,2,1) #: t1new += numpy.einsum('jb,ijba->ia', fov, theta) #: t1new -= numpy.einsum('kijb,kjba->ia', eris_ooov, theta) theta = numpy.ndarray((p1-p0,nvir,nocc,nvir), buffer=buf1) for i in range(p1-p0): tmp = t2[p0+i].transpose(0,2,1) * 2 tmp-= t2[p0+i] lib.ddot(eris_ooov[i].reshape(nocc,-1), tmp.reshape(-1,nvir), -1, t1new, 1) lib.transpose(_cp(tmp).reshape(-1,nvir), out=theta[i]) # theta[i] = 
tmp.transpose(2,0,1) t1new += numpy.einsum('jb,jbia->ia', fov[p0:p1], theta) eris_ooov = None #: wOVov += eris.ovov #: tau = theta - numpy.einsum('ic,kb->ikcb', t1, t1*2) #: wOVov += .5 * numpy.einsum('jakc,ikcb->jiba', eris.ovov, tau) #: wOVov -= .5 * numpy.einsum('jcka,ikcb->jiba', eris.ovov, t2) #: t2new += numpy.einsum('ikca,kjbc->ijba', theta, wOVov) for i in range(p1-p0): wOoVv[:,i] += wooVV[i]*.5 #: jiba + ijba*.5 wOVov = lib.transpose(wOoVv.reshape(nocc,-1,nvir), axes=(0,2,1), out=buf5) wOVov = wOVov.reshape(nocc,nvir,-1,nvir) eris_OVov = lib.transpose(eris_ovov.reshape(-1,nov), out=buf3) eris_OVov = eris_OVov.reshape(nocc,nvir,-1,nvir) wOVov += eris_OVov theta = theta.reshape(-1,nov) for i in range(nocc): # OVov-OVov.transpose(0,3,2,1)*.5 eris_OVov[i] -= eris_OVov[i].transpose(2,1,0)*.5 for j0, j1 in prange(0, nocc, blksize): tau = numpy.ndarray((j1-j0,nvir,nocc,nvir), buffer=buf2) for i in range(j1-j0): tau[i] = t2[j0+i].transpose(1,0,2) * 2 tau[i] -= t2[j0+i].transpose(2,0,1) tau[i] -= numpy.einsum('a,jb->bja', t1[j0+i]*2, t1) #: wOVov[j0:j1] += .5 * numpy.einsum('iakc,jbkc->jbai', eris_ovov, tau) lib.ddot(tau.reshape(-1,nov), eris_OVov.reshape(nov,-1), .5, wOVov[j0:j1].reshape((j1-j0)*nvir,-1), 1) #theta = t2[p0:p1] * 2 - t2[p0:p1].transpose(0,1,3,2) #: t2new[j0:j1] += numpy.einsum('iack,jbck->jiba', theta, wOVov[j0:j1]) tmp = lib.ddot(wOVov[j0:j1].reshape((j1-j0)*nvir,-1), theta, 1, tau.reshape(-1,nov)).reshape(-1,nvir,nocc,nvir) for i in range(j1-j0): t2new[j0+i] += tmp[i].transpose(1,0,2) theta = wOoVv = wOVov = eris_OVov = tmp = tau = None time2 = log.timer_debug1('wOVov [%d:%d]'%(p0, p1), *time2) #: tau = t2 + numpy.einsum('ia,jb->ijab', t1, t1) #: woooo += numpy.einsum('ijba,klab->ijkl', eris.oOVv, tau) #: tau = .5*t2 + numpy.einsum('ia,jb->ijab', t1, t1) #: woVoV += numpy.einsum('jkca,ikbc->ijba', tau, eris.oOVv) tmp = numpy.ndarray((p1-p0,nvir,nocc,nvir), buffer=buf1) tmp[:] = wooVV.transpose(0,2,1,3) woVoV = 
lib.transpose(_cp(tmp).reshape(-1,nov), out=buf4).reshape(nocc,nvir,p1-p0,nvir) eris_oOvV = numpy.ndarray((p1-p0,nocc,nvir,nvir), buffer=buf3) eris_oOvV[:] = eris_ovov.transpose(0,2,1,3) eris_oVOv = lib.transpose(eris_oOvV.reshape(-1,nov,nvir), axes=(0,2,1), out=buf5) eris_oVOv = eris_oVOv.reshape(-1,nvir,nocc,nvir) for j0, j1 in prange(0, nocc, blksize): tau = make_tau(t2[j0:j1], t1[j0:j1], t1, 1, out=buf2) #: woooo[p0:p1,:,j0:j1] += numpy.einsum('ijab,klab->ijkl', eris_oOvV, tau) _dgemm('N', 'T', (p1-p0)*nocc, (j1-j0)*nocc, nvir*nvir, eris_oOvV.reshape(-1,nvir*nvir), tau.reshape(-1,nvir*nvir), woooo[p0:p1].reshape(-1,nocc*nocc), 1, 1, 0, 0, j0*nocc) for i in range(j1-j0): tau[i] -= t2[j0+i] * .5 #: woVoV[j0:j1] += numpy.einsum('jkca,ickb->jiab', tau, eris_ovov) lib.ddot(lib.transpose(tau.reshape(-1,nov,nvir), axes=(0,2,1)).reshape(-1,nov), eris_oVOv.reshape(-1,nov).T, 1, woVoV[j0:j1].reshape((j1-j0)*nvir,-1), 1) time2 = log.timer_debug1('woVoV [%d:%d]'%(p0, p1), *time2) tau = make_tau(t2[p0:p1], t1[p0:p1], t1, 1, out=buf2) #: t2new += .5 * numpy.einsum('klij,klab->ijab', woooo[p0:p1], tau) lib.ddot(woooo[p0:p1].reshape(-1,nocc*nocc).T, tau.reshape(-1,nvir*nvir), .5, t2new.reshape(nocc*nocc,-1), 1) eris_ovov = eris_oVOv = eris_oOvV = wooVV = tau = tmp = None t2ibja = lib.transpose(_cp(t2[p0:p1]).reshape(-1,nov,nvir), axes=(0,2,1), out=buf1).reshape(-1,nvir,nocc,nvir) tmp = numpy.ndarray((blksize,nvir,nocc,nvir), buffer=buf2) for j0, j1 in prange(0, nocc, blksize): #: t2new[j0:j1] += numpy.einsum('ibkc,kcja->ijab', woVoV[j0:j1], t2ibja) lib.ddot(woVoV[j0:j1].reshape((j1-j0)*nvir,-1), t2ibja.reshape(-1,nov), 1, tmp[:j1-j0].reshape(-1,nov)) for i in range(j1-j0): t2new[j0+i] += tmp[i].transpose(1,2,0) t2new[j0+i] += tmp[i].transpose(1,0,2) * .5 woVoV = t2ibja = tmp = None time1 = log.timer_debug1('contract occ [%d:%d]'%(p0, p1), *time1) buf1 = buf2 = buf3 = buf4 = buf5 = bufs = None time1 = log.timer_debug1('contract loop', *time0) woooo = None ft_ij = foo + 
numpy.einsum('ja,ia->ij', .5*t1, fov) ft_ab = fvv - numpy.einsum('ia,ib->ab', .5*t1, fov) #: t2new += numpy.einsum('ijac,bc->ijab', t2, ft_ab) #: t2new -= numpy.einsum('ki,kjab->ijab', ft_ij, t2) lib.ddot(t2.reshape(-1,nvir), ft_ab.T, 1, t2new.reshape(-1,nvir), 1) lib.ddot(ft_ij.T, t2.reshape(nocc,nocc*nvir**2),-1, t2new.reshape(nocc,nocc*nvir**2), 1) mo_e = fock.diagonal() eia = mo_e[:nocc,None] - mo_e[None,nocc:] t1new += numpy.einsum('ib,ab->ia', t1, fvv) t1new -= numpy.einsum('ja,ji->ia', t1, foo) t1new /= eia #: t2new = t2new + t2new.transpose(1,0,3,2) ij = 0 for i in range(nocc): for j in range(i+1): t2new[i,j] += t2new[j,i].T t2new[i,j] /= lib.direct_sum('a,b->ab', eia[i], eia[j]) t2new[j,i] = t2new[i,j].T ij += 1 time0 = log.timer_debug1('update t1 t2', *time0) return t1new, t2new
def get_jk(mydf, dm, hermi=1, kpt=numpy.zeros(3),
           kpt_band=None, with_j=True, with_k=True, exxdiv=None):
    '''JK for given k-point

    Builds the contribution coming from the three-index tensors yielded by
    ``mydf.sr_loop`` and adds it on top of the matrices returned by
    ``pwdf_jk.get_jk``.

    Args:
        mydf : DF object. This function reads ``_cderi``, ``build``,
            ``sr_loop``, ``blockdim``, ``max_memory``, ``stdout`` and
            ``verbose`` from it.
        dm : density matrix (or a stack of them) at ``kpt``; it is normalized
            to shape ``(nset, nao, nao)`` through ``_format_dms``.
        hermi : int
            If nonzero, the second half of the K contraction is taken as the
            (conjugate) transpose of the first half instead of being computed
            explicitly. This is only valid when ``dm`` is hermitian.
        kpt : (3,) ndarray
            The k-point of ``dm``.
        kpt_band : (3,) ndarray or None
            If given and different from ``kpt``, the whole computation is
            delegated to ``get_j_kpts`` / ``get_k_kpts``.
        with_j, with_k : bool
            Whether to compute the Coulomb (J) / exchange (K) matrices.
        exxdiv :
            Forwarded unchanged to the K builders.

    Returns:
        (vj, vk). On the ``kpt_band`` path, a matrix that was not requested
        is None. On the main path, the locally accumulated parts are reshaped
        to ``dm.shape`` and added to what ``pwdf_jk.get_jk`` returned.
    '''
    vj = vk = None
    if kpt_band is not None and abs(kpt-kpt_band).sum() > 1e-9:
        kpt = numpy.reshape(kpt, (1,3))
        if with_k:
            vk = get_k_kpts(mydf, [dm], hermi, kpt, kpt_band, exxdiv)
        if with_j:
            vj = get_j_kpts(mydf, [dm], hermi, kpt, kpt_band)
        return vj, vk

    log = logger.Logger(mydf.stdout, mydf.verbose)
    # time.clock() was removed in Python 3.8.  Prefer process_time when the
    # interpreter has it and fall back to clock() on interpreters predating it.
    cpu_clock = getattr(time, 'process_time', None) or time.clock
    t2 = t1 = (cpu_clock(), time.time())
    if mydf._cderi is None:
        mydf.build()
        t1 = log.timer_debug1('Init get_jk', *t1)

    dm = numpy.asarray(dm, order='C')
    dms = _format_dms(dm, [kpt])
    nset, _, nao = dms.shape[:3]
    dms = dms.reshape(nset,nao,nao)
    # Imaginary parts can be skipped for J at gamma.  For K they can be
    # skipped only when the density matrices are real as well.
    j_real = gamma_point(kpt)
    k_real = gamma_point(kpt) and not numpy.iscomplexobj(dms)
    kptii = numpy.asarray((kpt,kpt))
    # Split dms into C-contiguous real/imaginary parts so that real BLAS
    # kernels (lib.ddot / zdotNN / zdotCN) can be used throughout.
    dmsR = numpy.asarray(dms.real.reshape(nset,nao,nao), order='C')
    dmsI = numpy.asarray(dms.imag.reshape(nset,nao,nao), order='C')

    if with_j:
        vjR = numpy.zeros((nset,nao,nao))
        vjI = numpy.zeros((nset,nao,nao))
    if with_k:
        vkR = numpy.zeros((nset,nao,nao))
        vkI = numpy.zeros((nset,nao,nao))
    # .45 is estimation for the memory usage ratio  sr_loop / (sr_loop+bufR+bufI)
    max_memory = max(2000, (mydf.max_memory - lib.current_memory()[0])) * .45

    if with_k:
        # Scratch buffers of blockdim*nao**2 elements.  They are reused for
        # every block yielded by sr_loop.
        buf1R = numpy.empty((mydf.blockdim*nao**2))
        buf2R = numpy.empty((mydf.blockdim*nao**2))
        buf3R = numpy.empty((mydf.blockdim*nao**2))
        if not k_real:
            buf1I = numpy.empty((mydf.blockdim*nao**2))
            buf2I = numpy.empty((mydf.blockdim*nao**2))
            buf3I = numpy.empty((mydf.blockdim*nao**2))

    def contract_k(pLqR, pLqI, pjqR, pjqI):
        # K ~ 'iLj,lLk*,li->kj' + 'lLk*,iLj,li->kj'
        #:Lpq = LpqR + LpqI*1j
        #:j3c = j3cR + j3cI*1j
        #:for i in range(nset):
        #:    tmp = numpy.dot(dms[i], j3c.reshape(nao,-1))
        #:    vk1 = numpy.dot(Lpq.reshape(-1,nao).conj().T, tmp.reshape(-1,nao))
        #:    tmp = numpy.dot(dms[i], Lpq.reshape(nao,-1))
        #:    vk1+= numpy.dot(j3c.reshape(-1,nao).conj().T, tmp.reshape(-1,nao))
        #:    vkR[i] += vk1.real
        #:    vkI[i] += vk1.imag
        nrow = pLqR.shape[1]
        tmpR = numpy.ndarray((nao,nrow*nao), buffer=buf3R)
        if k_real:
            for i in range(nset):
                tmpR = lib.ddot(dmsR[i], pjqR.reshape(nao,-1), 1, tmpR)
                vk1R = lib.ddot(pLqR.reshape(-1,nao).T, tmpR.reshape(-1,nao))
                vkR[i] += vk1R
                if hermi:
                    # Hermitian dm: the second term is the transpose of the first.
                    vkR[i] += vk1R.T
                else:
                    tmpR = lib.ddot(dmsR[i], pLqR.reshape(nao,-1), 1, tmpR)
                    lib.ddot(pjqR.reshape(-1,nao).T, tmpR.reshape(-1,nao),
                             1, vkR[i], 1)
        else:
            tmpI = numpy.ndarray((nao,nrow*nao), buffer=buf3I)
            for i in range(nset):
                tmpR, tmpI = zdotNN(dmsR[i], dmsI[i], pjqR.reshape(nao,-1),
                                    pjqI.reshape(nao,-1), 1, tmpR, tmpI, 0)
                vk1R, vk1I = zdotCN(pLqR.reshape(-1,nao).T, pLqI.reshape(-1,nao).T,
                                    tmpR.reshape(-1,nao), tmpI.reshape(-1,nao))
                vkR[i] += vk1R
                vkI[i] += vk1I
                if hermi:
                    # Hermitian dm: the second term is the conjugate transpose.
                    vkR[i] += vk1R.T
                    vkI[i] -= vk1I.T
                else:
                    tmpR, tmpI = zdotNN(dmsR[i], dmsI[i], pLqR.reshape(nao,-1),
                                        pLqI.reshape(nao,-1), 1, tmpR, tmpI, 0)
                    zdotCN(pjqR.reshape(-1,nao).T, pjqI.reshape(-1,nao).T,
                           tmpR.reshape(-1,nao), tmpI.reshape(-1,nao),
                           1, vkR[i], vkI[i], 1)

    # On the real-K path the imaginary blocks stay None and are passed as such.
    pLqI = pjqI = None
    thread_k = None
    for LpqR, LpqI, j3cR, j3cI in mydf.sr_loop(kptii, max_memory, False):
        LpqR = LpqR.reshape(-1,nao,nao)
        LpqI = LpqI.reshape(-1,nao,nao)
        j3cR = j3cR.reshape(-1,nao,nao)
        j3cI = j3cI.reshape(-1,nao,nao)
        t2 = log.timer_debug1(' load', *t2)
        # The previous K contraction reads buf1*/buf2*.  Wait for it to finish
        # before those buffers are overwritten below.
        if thread_k is not None:
            thread_k.join()
        if with_j:
            #:rho_coeff = numpy.einsum('Lpq,xqp->xL', Lpq, dms)
            #:jaux = numpy.einsum('Lpq,xqp->xL', j3c, dms)
            #:vj += numpy.dot(jaux, Lpq.reshape(-1,nao**2))
            #:vj += numpy.dot(rho_coeff, j3c.reshape(-1,nao**2))
            rhoR  = numpy.einsum('Lpq,xqp->xL', LpqR, dmsR)
            jauxR = numpy.einsum('Lpq,xqp->xL', j3cR, dmsR)
            if not j_real:
                rhoR -= numpy.einsum('Lpq,xqp->xL', LpqI, dmsI)
                rhoI  = numpy.einsum('Lpq,xqp->xL', LpqR, dmsI)
                rhoI += numpy.einsum('Lpq,xqp->xL', LpqI, dmsR)
                jauxR-= numpy.einsum('Lpq,xqp->xL', j3cI, dmsI)
                jauxI = numpy.einsum('Lpq,xqp->xL', j3cR, dmsI)
                jauxI+= numpy.einsum('Lpq,xqp->xL', j3cI, dmsR)
            vjR += numpy.einsum('xL,Lpq->xpq', jauxR, LpqR)
            vjR += numpy.einsum('xL,Lpq->xpq', rhoR, j3cR)
            if not j_real:
                vjR -= numpy.einsum('xL,Lpq->xpq', jauxI, LpqI)
                vjR -= numpy.einsum('xL,Lpq->xpq', rhoI, j3cI)
                vjI += numpy.einsum('xL,Lpq->xpq', jauxR, LpqI)
                vjI += numpy.einsum('xL,Lpq->xpq', jauxI, LpqR)
                vjI += numpy.einsum('xL,Lpq->xpq', rhoR, j3cI)
                vjI += numpy.einsum('xL,Lpq->xpq', rhoI, j3cR)
            t2 = log.timer_debug1(' with_j', *t2)
        if with_k:
            # Reorder L,p,q -> p,L,q into the reusable buffers, then run the
            # contraction in a background thread while the next block loads.
            nrow = LpqR.shape[0]
            pLqR = numpy.ndarray((nao,nrow,nao), buffer=buf1R)
            pjqR = numpy.ndarray((nao,nrow,nao), buffer=buf2R)
            pLqR[:] = LpqR.transpose(1,0,2)
            pjqR[:] = j3cR.transpose(1,0,2)
            if not k_real:
                pLqI = numpy.ndarray((nao,nrow,nao), buffer=buf1I)
                pjqI = numpy.ndarray((nao,nrow,nao), buffer=buf2I)
                pLqI[:] = LpqI.transpose(1,0,2)
                pjqI[:] = j3cI.transpose(1,0,2)
            thread_k = lib.background_thread(contract_k, pLqR, pLqI, pjqR, pjqI)
            t2 = log.timer_debug1(' with_k', *t2)
        LpqR = LpqI = j3cR = j3cI = None
    if thread_k is not None:
        thread_k.join()
    thread_k = None
    t1 = log.timer_debug1('mdf_jk.get_jk pass 1', *t1)

    vj, vk = pwdf_jk.get_jk(mydf, dm, hermi, kpt, kpt_band, with_j, with_k, exxdiv)
    if with_j:
        if j_real:
            vj += vjR.reshape(dm.shape)
        else:
            vj += (vjR+vjI*1j).reshape(dm.shape)
    if with_k:
        if k_real:
            vk += vkR.reshape(dm.shape)
        else:
            vk += (vkR+vkI*1j).reshape(dm.shape)
    return vj, vk