def cc_Wvvvv(t1T, t2T, eris, tauT=None, vlocs=None):
    """
    Build the Wvvvv intermediate for the virtual segment owned by this rank.

    Args:
        t1T: singles amplitudes in (a, i) layout.
        t2T: doubles amplitudes [a]bij; the first virtual index is the
            segment owned by this rank.
        eris: integral container; its xvoo, xvvv and xovv blocks are read.
        tauT: optional effective doubles [a]bij; built with make_tauT when
            omitted.
        vlocs: optional list of per-rank (start, stop) virtual ranges;
            derived from _task_location when omitted.

    Returns:
        Wabef, shape (vloc1 - vloc0, nvir, nvir, nvir), i.e. [a]bef.
    """
    # ZHC TODO make Wvvvv outcore
    seg_len, nvir, _nocc, _ = t2T.shape
    if vlocs is None:
        vlocs = [_task_location(nvir, tid) for tid in range(mpi.pool.size)]
    lo, hi = vlocs[rank]
    if tauT is None:
        tauT = make_tauT(t1T, t2T, vlocs=vlocs)

    W = np.empty((hi - lo, nvir, nvir, nvir))

    # tau-dependent term: the xvoo blocks are rotated through all ranks and
    # the block owned by rank `tid` fills the e-range [e0:e1] of W.
    scaled_tau = tauT * 0.25
    tauT = None
    for _tid, vvoo_blk, e0, e1 in _rotate_vir_block(eris.xvoo, vlocs=vlocs):
        W[:, :, e0:e1] = einsum('abmn, efmn -> abef', scaled_tau, vvoo_blk)
        vvoo_blk = None
    scaled_tau = None

    W += np.asarray(eris.xvvv)

    # t1-dependent term, antisymmetrized in (a, b).  The transposed (b, a)
    # blocks are owned by other ranks, so they are exchanged with alltoall.
    t1_term = einsum('bm, amef -> abfe', t1T, eris.xovv)
    W += t1_term
    recv = mpi.alltoall_new([t1_term[:, e0:e1] for e0, e1 in vlocs],
                            split_recvbuf=True)
    for tid, (e0, e1) in enumerate(vlocs):
        partner = recv[tid].reshape(e1 - e0, seg_len, nvir, nvir)
        W[:, e0:e1] -= partner.transpose(1, 0, 2, 3)
    return W
def make_tauT(t1T, t2T, fac=1, vlocs=None):
    """
    Build the effective doubles tauT from t2T and t1T.

    For a in this rank's virtual segment the result is
        tauT[a, b, i, j] = t2T[a, b, i, j]
                           + fac * (t1T[a, i] * t1T[b, j] - t1T[a, j] * t1T[b, i])

    Args:
        t1T: singles amplitudes in (a, i) layout.
        t2T: doubles amplitudes [a]bij, first virtual index segmented.
        fac: scale factor applied to the t1*t1 contribution.
        vlocs: optional list of per-rank (start, stop) virtual ranges.

    Returns:
        tauT: [a]bij, segmented in the same way as t2T.
    """
    nvir_seg, nvir, nocc, _ = t2T.shape
    if vlocs is None:
        vlocs = [_task_location(nvir, tid) for tid in range(mpi.pool.size)]
    lo, hi = vlocs[rank]

    half_fac = fac * 0.5
    tau = np.einsum('ai, bj -> abij', t1T[lo:hi] * half_fac, t1T,
                    optimize=True)
    # Antisymmetrize in (i, j); this only needs local data.
    tau = tau - tau.transpose(0, 1, 3, 2)
    # Antisymmetrize in (a, b), i.e. tau -= tau.transpose(1, 0, 2, 3).  The
    # transposed blocks belong to other ranks, so exchange them first.
    incoming = mpi.alltoall_new([tau[:, q0:q1] for q0, q1 in vlocs],
                                split_recvbuf=True)
    for tid, (q0, q1) in enumerate(vlocs):
        block = incoming[tid].reshape(q1 - q0, nvir_seg, nocc, nocc)
        tau[:, q0:q1] -= block.transpose(1, 0, 2, 3)
    tau += t2T
    return tau
def vector_to_amplitudes(vector, nmo, nocc):
    """
    Unpack this rank's slice of the amplitude vector into (t1, t2).

    On rank 0 the vector starts with the nvir * nocc singles, which are
    broadcast to every rank.  The remainder (the whole vector on other
    ranks) is the local doubles block laid out as
    (nvir_seg, nvir, nocc * (nocc + 1) // 2): lower-triangular ij pairs,
    diagonal included.  Upper-triangular ij entries (and the diagonal) are
    then set from t2T[a, b, j, i] = t2T[b, a, i, j], using blocks received
    from the ranks that own b.

    Returns:
        t1 with shape (nocc, nvir) and t2 with shape
        (nocc, nocc, nvir_seg, nvir).
    """
    nvir = nmo - nocc
    nov = nocc * nvir
    npair = nocc * (nocc + 1) // 2
    vlocs = [_task_location(nvir, tid) for tid in range(mpi.pool.size)]
    lo, hi = vlocs[rank]
    seg = hi - lo

    if rank != 0:
        t1T = mpi.bcast(None)
        t2tril = vector.reshape(seg, nvir, npair)
    else:
        t1T = vector[:nov].copy().reshape((nvir, nocc))
        mpi.bcast(t1T)
        t2tril = vector[nov:].reshape(seg, nvir, npair)

    # Lower triangle from the packed pairs; the rest is filled below.
    t2T = lib.unpack_tril(t2tril.reshape(seg * nvir, npair),
                          filltriu=lib.PLAIN)
    t2T = t2T.reshape(seg, nvir, nocc, nocc)

    received = mpi.alltoall_new([t2tril[:, q0:q1] for q0, q1 in vlocs],
                                split_recvbuf=True)
    rows, cols = numpy.tril_indices(nocc)
    for tid, (q0, q1) in enumerate(vlocs):
        partner = received[tid].reshape(q1 - q0, seg, npair)
        t2T[:, q0:q1, cols, rows] = partner.transpose(1, 0, 2)
    return t1T.T, t2T.transpose(2, 3, 0, 1)
def start(self, interval=0.02):
    """
    Prepare this rank's vvop buffer, then start the data-serving thread.

    Step 1 assembles the local vvop tensor in a temporary HDF5 file
    (``self.eri_tmp['vvop']``, shape (nvir_seg, nvir, nocc, nmo)).  The
    first nocc entries of the last axis come from ``eris.ovov``; the
    remaining nvir entries come from ``eris.vvvo`` blocks collected from
    all ranks with ``mpi.alltoall_new``.

    Step 2 starts a thread that answers tensor requests sent by other
    ranks until a 'Done' task is received.

    Args:
        interval: seconds to sleep between polls for incoming requests.

    Returns:
        The started ``threading.Thread`` running the request loop.
    """
    mycc = self._cc
    log = logger.new_logger(mycc)
    cpu1 = (logger.process_clock(), logger.perf_counter())

    eris = mycc._eris
    t2T = mycc.t2.transpose(2, 3, 0, 1)

    nocc, nvir = mycc.t1.shape
    nmo = nocc + nvir
    vloc0, vloc1 = self.vranges[rank]
    nvir_seg = vloc1 - vloc0

    # Usable memory is capped at 24000 (presumably MB, the unit of
    # mycc.max_memory).  The block size over the local virtual range is at
    # most nvir_seg // 4 + 1 and otherwise set by the memory estimate, with
    # a floor of 16.
    max_memory = min(24000, mycc.max_memory - lib.current_memory()[0])
    blksize = min(
        nvir_seg // 4 + 1,
        max(16, int(max_memory * .3e6 / 8 / (nvir * nocc * nmo))))
    self.eri_tmp = lib.H5TmpFile()
    vvop = self.eri_tmp.create_dataset('vvop', (nvir_seg, nvir, nocc, nmo), 'f8')

    def save_vvop(j0, j1, vvvo):
        # Writes vvop[j0:j1] (local virtual indices).
        buf = numpy.empty((j1 - j0, nvir, nocc, nmo), dtype=t2T.dtype)
        # buf[a, b, i, j] = conj(ovov[i, a, j, b]) for the occupied part.
        buf[:, :, :, :nocc] = eris.ovov[:, j0:j1].conj().transpose(
            1, 3, 0, 2)
        # vvvo[k] is the block received from rank k; it covers rank k's
        # virtual range (q0, q1) along the second axis of vvop.
        for k, (q0, q1) in enumerate(self.vranges):
            blk = vvvo[k].reshape(q1 - q0, nvir, j1 - j0, nocc)
            buf[:, q0:q1, :, nocc:] = blk.transpose(2, 0, 3, 1)
        vvop[j0:j1] = buf

    # save_vvop runs through lib.call_in_background, presumably so the HDF5
    # write of one block overlaps with the communication for the next.
    with lib.call_in_background(save_vvop) as save_vvop:
        # NOTE(review): every iteration calls comm.allgather and
        # mpi.alltoall_new, so mpi.prange is assumed to yield the same
        # number of iterations on every rank -- confirm.
        for p0, p1 in mpi.prange(vloc0, vloc1, blksize):
            j0, j1 = p0 - vloc0, p1 - vloc0
            # sub_locs[k] is the (p0, p1) range rank k handles this round.
            sub_locs = comm.allgather((p0, p1))
            vvvo = mpi.alltoall_new(
                [eris.vvvo[:, :, q0:q1] for q0, q1 in sub_locs],
                split_recvbuf=True)
            save_vvop(j0, j1, vvvo)
            cpu1 = log.timer_debug1('transpose %d:%d' % (p0, p1), *cpu1)

    def send_data():
        # Poll for INQUIRY messages of the form (tensors, dest), where
        # tensors is a list of (task, slices).  Each requested tensor is
        # sent back to `dest` with tag TRANSFER_DATA.  A 'Done' task ends
        # the loop; any tasks after it in the same message are not served.
        while True:
            while comm.Iprobe(source=MPI.ANY_SOURCE, tag=INQUIRY):
                tensors, dest = comm.recv(source=MPI.ANY_SOURCE, tag=INQUIRY)
                for task, slices in tensors:
                    if task == 'Done':
                        return
                    else:
                        mpi.send(self._get_tensor(task, slices), dest,
                                 tag=TRANSFER_DATA)
            time.sleep(interval)

    daemon = threading.Thread(target=send_data)
    daemon.start()
    return daemon
def vector_to_amplitudes(vector, nmo, nocc):
    """
    Convert this rank's slice of the amplitude vector to (t1, t2), with the
    same behavior as pyscf gccsd.

    On rank 0 the vector starts with the nvir * nocc singles, which are
    broadcast to every rank.  The remainder (the whole vector on other
    ranks) holds the doubles for the virtual pairs returned by
    get_vtril(nvir, vlocs[rank]) and the strictly-lower occupied pairs
    (i > j).  The full antisymmetric t2 block is rebuilt from these.

    Returns:
        t1 with shape (nocc, nvir) and t2 with shape
        (nocc, nocc, nvir_seg, nvir).
    """
    nvir = nmo - nocc
    nov = nocc * nvir
    nocc2 = nocc * (nocc - 1) // 2
    otril = np.tril_indices(nocc, k=-1)

    vlocs = [_task_location(nvir, task_id) for task_id in range(mpi.pool.size)]
    vloc0, vloc1 = vlocs[rank]
    nvir_seg = vloc1 - vloc0

    # NOTE(review): the index convention of get_vtril is defined elsewhere;
    # vtril[0] is used as a row index into the local (nvir_seg * nvir)
    # layout and vtril[1] as a column index.
    vtril = get_vtril(nvir, vlocs[rank])
    nvir2 = len(vtril[0])
    if rank == 0:
        t1T = vector[:nov].copy().reshape(nvir, nocc)
        mpi.bcast(t1T)
        t2tril = vector[nov:].reshape(nvir2, nocc2)
    else:
        t1T = mpi.bcast(None)
        t2tril = vector.reshape(nvir2, nocc2)

    t2T = np.zeros((nvir_seg * nvir, nocc**2), dtype=t2tril.dtype)
    lib.takebak_2d(t2T, t2tril, vtril[0]*nvir+vtril[1], otril[0]*nocc+otril[1])
    # anti-symmetry when exchanging the two occupied (particle) indices
    lib.takebak_2d(t2T, -t2tril, vtril[0]*nvir+vtril[1], otril[1]*nocc+otril[0])
    t2T = t2T.reshape(nvir_seg, nvir, nocc, nocc)

    # Exchange blocks so each rank can fill t2T[a, b] from t2T[b, a] owned
    # by the rank holding b: t2[a, b, i, j] = t2[b, a, j, i].
    t2tmp = mpi.alltoall_new([t2T[:, p0:p1] for p0,p1 in vlocs], split_recvbuf=True)
    for task_id, (p0, p1) in enumerate(vlocs):
        if task_id < rank:
            # do not need this part since it is already filled
            # (presumably covered by vtril for this rank -- confirm).
            continue
        elif task_id == rank:
            # Diagonal block: fill the upper triangle (triu) of the virtual
            # pair from the lower triangle with i, j swapped, which equals
            # -tril because t2 is already antisymmetric in (i, j).
            v_idx = get_vtril(nvir, vlocs[task_id], p0=p0, p1=p1)
            tmp = t2tmp[task_id].reshape(p1-p0, nvir_seg, nocc, nocc)
            t2T[v_idx[1]-p0, v_idx[0]+p0] = tmp[v_idx[0], v_idx[1]-p0].transpose(0, 2, 1)
        else:
            tmp = t2tmp[task_id].reshape(p1-p0, nvir_seg, nocc, nocc)
            t2T[:, p0:p1] = tmp.transpose(1, 0, 3, 2)
    return t1T.T, t2T.transpose(2, 3, 0, 1)
def save_vir_frac(p0, p1, eri):
    """
    Store one transformed integral block into the distributed eris arrays.

    ``eri`` is reshaped to (p1 - p0, nocc, nmo, nmo).  p0:p1 is a range of
    this rank's local virtual indices (the code converts it to global
    indices by adding vlocs[rank][0]).

    NOTE(review): this is a closure; log, nocc, nmo, nvir, eris, vlocs,
    vseg and comm come from the enclosing scope, which is not visible here.
    """
    log.alldebug1('save_vir_frac %d %d %s', p0, p1, eri.shape)
    eri = eri.reshape(p1 - p0, nocc, nmo, nmo)
    # oo, vo and ov sub-blocks, with the local virtual index moved to axis 1.
    eris.ovoo[:, p0:p1] = eri[:, :, :nocc, :nocc].transpose(1, 0, 2, 3)
    eris.ovvo[:, p0:p1] = eri[:, :, nocc:, :nocc].transpose(1, 0, 2, 3)
    eris.ovov[:, p0:p1] = eri[:, :, :nocc, nocc:].transpose(1, 0, 2, 3)
    # vvv = lib.pack_tril(eri[:,:,nocc:,nocc:].reshape((p1-p0)*nocc,nvir,nvir))
    # eris.ovvv[:,p0:p1] = vvv.reshape(p1-p0,nocc,nvpair).transpose(1,0,2)

    # The vv part is redistributed: each rank receives the rows of its own
    # virtual segment (vlocs[k]) from every rank and stores them in vvvo.
    cput2 = logger.process_clock(), logger.perf_counter()
    ovvv_segs = [
        eri[:, :, nocc + q0:nocc + q1, nocc:].transpose(2, 3, 0, 1)
        for q0, q1 in vlocs
    ]
    ovvv_segs = mpi.alltoall_new(ovvv_segs, split_recvbuf=True)
    cput2 = log.timer_debug1('vvvo alltoall', *cput2)
    # Every rank's (p0, p1) is needed to place the received blocks; local
    # ranges are turned into global virtual indices ip0:ip1.
    for task_id, (q0, q1) in enumerate(comm.allgather((p0, p1))):
        ip0 = q0 + vlocs[task_id][0]
        ip1 = q1 + vlocs[task_id][0]
        eris.vvvo[:, :, ip0:ip1] = ovvv_segs[task_id].reshape(
            vseg, nvir, q1 - q0, nocc)
def update_lambda(mycc, t1, t2, l1, l2, eris, imds):
    """
    Update GCCD lambda amplitudes with the virtual-segmented MPI layout.

    Only the l2 equation is active.  Every l1- and t1-dependent term of
    the GCCSD lambda equations is commented out, and the returned l1 is
    all zeros.

    Args:
        mycc: CC object (stdout, verbose, level_shift).
        t1, t2: amplitudes; t2 has shape (nocc, nocc, nvir_seg, nvir) with
            the first virtual index local to this rank.
        l1, l2: lambda amplitudes in the same layouts as t1, t2.
        eris: integral object (fock, mo_energy, xvoo, xvvv).
        imds: intermediates (v1, v2, woooo, wovvo).

    Returns:
        (l1new, l2new) in the same layouts as (l1, l2).
    """
    time0 = logger.process_clock(), logger.perf_counter()
    log = logger.Logger(mycc.stdout, mycc.verbose)

    # Work with transposed amplitudes: t1T is (a, i); t2T / l2T are [a]bij.
    t1T = t1.T
    t2T = np.asarray(t2.transpose(2, 3, 0, 1), order='C')
    t1 = t2 = None
    nvir_seg, nvir, nocc = t2T.shape[:3]
    l1T = l1.T
    l2T = np.asarray(l2.transpose(2, 3, 0, 1), order='C')
    l1 = l2 = None

    ntasks = mpi.pool.size
    vlocs = [_task_location(nvir, task_id) for task_id in range(ntasks)]
    vloc0, vloc1 = vlocs[rank]
    log.debug2('vlocs %s', vlocs)
    assert vloc1 - vloc0 == nvir_seg

    # NOTE(review): fvo is only referenced by the commented-out l1 terms.
    fvo = eris.fock[nocc:, :nocc]
    mo_e_o = eris.mo_energy[:nocc]
    mo_e_v = eris.mo_energy[nocc:] + mycc.level_shift
    # Orbital-energy diagonals are removed here; the matching
    # denominators are applied at the end.
    v1 = imds.v1 - np.diag(mo_e_v)
    v2 = imds.v2 - np.diag(mo_e_o)

    # mba and mij contract over the local segment of the first virtual
    # index and are then combined across ranks with allreduce_inplace.
    mba = einsum('cakl, cbkl -> ba', l2T, t2T) * 0.5
    mba = mpi.allreduce_inplace(mba)
    mij = einsum('cdki, cdkj -> ij', l2T, t2T) * 0.5
    mij = mpi.allreduce_inplace(mij)

    # m3 [a]bij: accumulates the woooo, oooo-ladder and vvvv-ladder terms.
    m3 = einsum('abkl, ijkl -> abij', l2T, np.asarray(imds.woooo))

    tauT = t2T #+ np.einsum('ai, bj -> abij', t1T[vloc0:vloc1] * 2.0, t1T, optimize=True)
    tmp = einsum('cdij, cdkl -> ijkl', l2T, tauT)
    tmp = mpi.allreduce_inplace(tmp)
    tauT = None

    vvoo = np.asarray(eris.xvoo)
    tmp = einsum('abkl, ijkl -> abij', vvoo, tmp)
    tmp *= 0.25
    m3 += tmp
    tmp = None

    #tmp = einsum('cdij, dk -> ckij', l2T, t1T)
    #for task_id, tmp, p0, p1 in _rotate_vir_block(tmp, vlocs=vlocs):
    #    m3 -= einsum('kcba, ckij -> abij', eris.ovvx[:, p0:p1], tmp)
    #    tmp = None

    # l2T blocks from every rank are rotated through this rank.  Each block
    # feeds the vvvv ladder term of m3 and the v1 term stored in tmp_2.
    eris_vvvv = eris.xvvv.transpose(2, 3, 0, 1)
    tmp_2 = np.empty_like(l2T) # filled block by block below; consumed later as `tmp = -tmp_2` in the v1/mba term
    for task_id, l2T_tmp, p0, p1 in _rotate_vir_block(l2T, vlocs=vlocs):
        tmp = einsum('cdij, cdab -> abij', l2T_tmp, eris_vvvv[p0:p1])
        tmp *= 0.5
        m3 += tmp
        # NOTE(review): assumes _rotate_vir_block visits every task once, so
        # every column range p0:p1 of tmp_2 gets written.
        tmp_2[:, p0:p1] = einsum('acij, cb -> baij', l2T_tmp, v1[:, vloc0:vloc1])
        tmp = l2T_tmp = None
    eris_vvvv = None

    #l1Tnew = einsum('abij, bj -> ai', m3, t1T)
    #l1Tnew = mpi.allgather(l1Tnew)
    l1Tnew = np.zeros_like(l1T)

    # l2Tnew aliases m3; the in-place updates below also modify m3.
    l2Tnew = m3
    l2Tnew += vvoo

    #fvo1 = fvo #+ mpi.allreduce(einsum('cbkj, ck -> bj', vvoo, t1T[vloc0:vloc1]))
    #tmp = np.einsum('ai, bj -> abij', l1T[vloc0:vloc1], fvo1, optimize=True)
    # wovvo term with P(ij) applied locally and P(ab) applied via alltoall
    # (each received block is subtracted transposed in a <-> b).
    tmp = 0.0
    wvovo = np.asarray(imds.wovvo).transpose(1, 0, 2, 3)
    for task_id, w_tmp, p0, p1 in _rotate_vir_block(wvovo, vlocs=vlocs):
        tmp -= einsum('acki, cjbk -> abij', l2T[:, p0:p1], w_tmp)
        w_tmp = None
    wvovo = None
    tmp = tmp - tmp.transpose(0, 1, 3, 2)
    l2Tnew += tmp
    tmpT = mpi.alltoall_new([tmp[:, p0:p1] for p0, p1 in vlocs], split_recvbuf=True)
    for task_id, (p0, p1) in enumerate(vlocs):
        tmp = tmpT[task_id].reshape(p1 - p0, nvir_seg, nocc, nocc)
        l2Tnew[:, p0:p1] -= tmp.transpose(1, 0, 2, 3)
        tmp = None

    # v1 and mba (tmp1vv) terms, antisymmetrized in (a, b) via alltoall.
    #tmp = einsum('ak, ijkb -> baij', l1T, eris.ooox)
    #tmp -= tmp_2
    tmp = -tmp_2
    tmp1vv = mba #+ np.dot(t1T, l1T.T) # ba
    tmp -= einsum('ca, bcij -> baij', tmp1vv, vvoo)
    l2Tnew += tmp
    tmpT = mpi.alltoall_new([tmp[:, p0:p1] for p0, p1 in vlocs], split_recvbuf=True)
    for task_id, (p0, p1) in enumerate(vlocs):
        tmp = tmpT[task_id].reshape(p1 - p0, nvir_seg, nocc, nocc)
        l2Tnew[:, p0:p1] -= tmp.transpose(1, 0, 2, 3)
        tmp = None

    # v2 and mij (tmp1oo) terms, antisymmetrized in (i, j) locally.
    #tmp = einsum('jcab, ci -> baji', eris.ovvx, -l1T)
    tmp = einsum('abki, jk -> abij', l2T, v2)
    tmp1oo = mij #+ np.dot(l1T.T, t1T) # ik
    tmp -= einsum('ik, abkj -> abij', tmp1oo, vvoo)
    vvoo = None
    l2Tnew += tmp
    l2Tnew -= tmp.transpose(0, 1, 3, 2)
    tmp = None

    # l1 equation (GCCSD only), kept for reference:
    #l1Tnew += fvo
    #tmp = einsum('bj, ibja -> ai', -l1T[vloc0:vloc1], eris.oxov)
    #l1Tnew += np.dot(v1.T, l1T)
    #l1Tnew -= np.dot(l1T, v2.T)
    #tmp -= einsum('cakj, icjk -> ai', l2T, imds.wovoo)
    #tmp -= einsum('bcak, bcik -> ai', imds.wvvvo, l2T)
    #tmp += einsum('baji, bj -> ai', l2T, imds.w3[vloc0:vloc1])
    #tmp_2 = t1T[vloc0:vloc1] - np.dot(tmp1vv[vloc0:vloc1], t1T)
    #tmp_2 -= np.dot(t1T[vloc0:vloc1], mij)
    #tmp_2 += einsum('bcjk, ck -> bj', t2T, l1T)
    #tmp += einsum('baji, bj -> ai', vvoo, tmp_2)
    #tmp_2 = None
    #tmp += einsum('icab, bc -> ai', eris.oxvv, tmp1vv[:, vloc0:vloc1])
    #l1Tnew += mpi.allreduce(tmp)
    #l1Tnew -= mpi.allgather(einsum('jika, kj -> ai', eris.ooox, tmp1oo))
    #tmp = fvo - mpi.allreduce(einsum('bakj, bj -> ak', vvoo, t1T[vloc0:vloc1]))
    #vvoo = None
    #l1Tnew -= np.dot(tmp, mij.T)
    #l1Tnew -= np.dot(mba.T, tmp)

    # Divide by the orbital-energy denominators for the local a-segment.
    eia = mo_e_o[:, None] - mo_e_v
    #l1Tnew /= eia.T
    for i in range(vloc0, vloc1):
        l2Tnew[i - vloc0] /= lib.direct_sum('i + jb -> bij', eia[:, i], eia)

    time0 = log.timer_debug1('update l1 l2', *time0)
    return l1Tnew.T, l2Tnew.transpose(2, 3, 0, 1)
def update_amps(mycc, t1, t2, eris):
    """
    Update GCCD amplitudes with the virtual-segmented MPI layout.

    Only the doubles equation is solved, and the returned t1 is always
    zeros.  The incoming t1 is still passed to the intermediate builders
    (make_tauT, cc_Fvv, cc_Foo, cc_Woooo, cc_Wvvvv, cc_Wovvo).

    The GCCSD singles residual is not evaluated: in this CCD update its
    result was never added to t1new.  Skipping it (on every rank, so the
    sequence of MPI collectives stays identical) avoids the cc_Fov build,
    the xooo/xvvo contractions and one allgather plus one allreduce per
    iteration.

    Args:
        mycc: CC object (stdout, verbose, level_shift).
        t1: (nocc, nvir) singles amplitudes.
        t2: (nocc, nocc, nvir_seg, nvir) doubles; the first virtual index
            is the segment owned by this rank.
        eris: integral object (mo_energy, xvoo and the blocks read by the
            cc_* intermediate builders).

    Returns:
        (t1new, t2new) in the same layouts as (t1, t2).
    """
    time0 = logger.process_clock(), logger.perf_counter()
    log = logger.Logger(mycc.stdout, mycc.verbose)

    # Transposed layouts: t1T is (a, i), t2T is [a]bij.
    t1T = t1.T
    t2T = np.asarray(t2.transpose(2, 3, 0, 1), order='C')
    nvir_seg, nvir, nocc = t2T.shape[:3]
    t2 = None
    ntasks = mpi.pool.size
    vlocs = [_task_location(nvir, task_id) for task_id in range(ntasks)]
    vloc0, vloc1 = vlocs[rank]
    log.debug2('vlocs %s', vlocs)
    assert vloc1 - vloc0 == nvir_seg

    mo_e_o = eris.mo_energy[:nocc]
    mo_e_v = eris.mo_energy[nocc:] + mycc.level_shift

    tauT_tilde = make_tauT(t1T, t2T, fac=0.5, vlocs=vlocs)
    Fvv = cc_Fvv(t1T, t2T, eris, tauT_tilde=tauT_tilde, vlocs=vlocs)
    Foo = cc_Foo(t1T, t2T, eris, tauT_tilde=tauT_tilde, vlocs=vlocs)
    tauT_tilde = None

    # Move energy terms to the other side; they come back in as the
    # denominators at the end.
    Fvv[np.diag_indices(nvir)] -= mo_e_v
    Foo[np.diag_indices(nocc)] -= mo_e_o

    # T1 equation: CCD keeps the singles at zero.
    t1Tnew = np.zeros_like(t1T)

    # T2 equation.
    # Fvv term (GCCSD would use Fvv - 0.5 * t1T @ Fov; the t1 correction is
    # omitted in CCD), antisymmetrized in (a, b).  The (b, a) partner
    # blocks are owned by other ranks, hence the alltoall.
    t2Tnew = einsum('aeij, be -> abij', t2T, Fvv)
    t2T_tmp = mpi.alltoall_new([t2Tnew[:, p0:p1] for p0, p1 in vlocs],
                               split_recvbuf=True)
    for task_id, (p0, p1) in enumerate(vlocs):
        tmp = t2T_tmp[task_id].reshape(p1 - p0, nvir_seg, nocc, nocc)
        t2Tnew[:, p0:p1] -= tmp.transpose(1, 0, 2, 3)
        tmp = None
    t2T_tmp = None

    # Foo term (t1 correction omitted), antisymmetrized in (i, j) locally.
    tmp = einsum('abim, mj -> abij', t2T, Foo)
    t2Tnew -= tmp
    t2Tnew += tmp.transpose(0, 1, 3, 2)
    tmp = None

    t2Tnew += np.asarray(eris.xvoo)

    # oooo ladder.
    tauT = make_tauT(t1T, t2T, vlocs=vlocs)
    Woooo = cc_Woooo(t1T, t2T, eris, tauT=tauT, vlocs=vlocs)
    Woooo *= 0.5
    t2Tnew += einsum('abmn, mnij -> abij', tauT, Woooo)
    Woooo = None

    # vvvv ladder: tauT blocks from every rank are rotated through this rank.
    Wvvvv = cc_Wvvvv(t1T, t2T, eris, tauT=tauT, vlocs=vlocs)
    for task_id, tauT_tmp, p0, p1 in _rotate_vir_block(tauT, vlocs=vlocs):
        tmp = einsum('abef, efij -> abij', Wvvvv[:, :, p0:p1], tauT_tmp)
        tmp *= 0.5
        t2Tnew += tmp
        tmp = tauT_tmp = None
    Wvvvv = None
    tauT = None

    # Wovvo ring term with P(ij) applied locally and P(ab) via alltoall.
    # (The GCCSD t1*t1*oxov contribution is omitted for CCD.)
    tmp = 0.0
    Wvovo = cc_Wovvo(t1T, t2T, eris, vlocs=vlocs).transpose(2, 0, 1, 3)
    for task_id, w_tmp, p0, p1 in _rotate_vir_block(Wvovo, vlocs=vlocs):
        tmp += einsum('aeim, embj -> abij', t2T[:, p0:p1], w_tmp)
        w_tmp = None
    Wvovo = None
    tmp = tmp - tmp.transpose(0, 1, 3, 2)
    t2Tnew += tmp
    tmpT = mpi.alltoall_new([tmp[:, p0:p1] for p0, p1 in vlocs],
                            split_recvbuf=True)
    for task_id, (p0, p1) in enumerate(vlocs):
        tmp = tmpT[task_id].reshape(p1 - p0, nvir_seg, nocc, nocc)
        t2Tnew[:, p0:p1] -= tmp.transpose(1, 0, 2, 3)
        tmp = None
    tmpT = None

    # The GCCSD t1 * ovvx and t1 * ooox doubles terms are omitted for CCD.

    # Divide by the orbital-energy denominators for the local a-segment.
    eia = mo_e_o[:, None] - mo_e_v
    for i in range(vloc0, vloc1):
        t2Tnew[i - vloc0] /= lib.direct_sum('i + jb -> bij', eia[:, i], eia)

    time0 = log.timer_debug1('update t1 t2', *time0)
    return t1Tnew.T, t2Tnew.transpose(2, 3, 0, 1)
def update_amps(mycc, t1, t2, eris):
    """
    One CCSD amplitude update with the virtual-index-segmented MPI layout.

    Works internally with transposed amplitudes t1T (ai) and t2T ([a]bij,
    first virtual index restricted to this rank's segment vlocs[rank]).
    Intermediates are staged in a temporary HDF5 file (fswap).  The
    2 * t2 - t2.swap "theta" combinations used below are those of the
    closed-shell (restricted) equations.

    Args:
        mycc: CC object (stdout, verbose, level_shift, max_memory, direct,
            _add_vvvv).
        t1: (nocc, nvir) singles amplitudes.
        t2: (nocc, nocc, nvir_seg, nvir) doubles amplitudes of this rank.
        eris: integral object (fock, mo_energy, ovvo, vvvo, ovoo, oovv,
            oooo).

    Returns:
        (t1new, t2new) in the same layouts as (t1, t2).

    Raises:
        NotImplementedError: if ``mycc.direct`` is false (raised inside the
            first vvvo block loop).
    """
    time1 = time0 = logger.process_clock(), logger.perf_counter()
    log = logger.Logger(mycc.stdout, mycc.verbose)
    cpu1 = time0  # NOTE(review): not used below.

    t1T = t1.T
    t2T = numpy.asarray(t2.transpose(2, 3, 0, 1), order='C')
    nvir_seg, nvir, nocc = t2T.shape[:3]
    t1 = t2 = None
    ntasks = mpi.pool.size
    vlocs = [_task_location(nvir, task_id) for task_id in range(ntasks)]
    vloc0, vloc1 = vlocs[rank]
    log.debug2('vlocs %s', vlocs)
    assert (vloc1 - vloc0 == nvir_seg)

    fock = eris.fock
    mo_e_o = eris.mo_energy[:nocc]
    mo_e_v = eris.mo_energy[nocc:] + mycc.level_shift

    # Local wrapper around _rotate_tensor_block that also yields the
    # virtual range (loc0, loc1) of the rank the block came from.
    def _rotate_vir_block(buf):
        for task_id, buf in _rotate_tensor_block(buf):
            loc0, loc1 = vlocs[task_id]
            yield task_id, buf, loc0, loc1

    fswap = lib.H5TmpFile()
    # wVooV1: amplitude contribution to wVooV, accumulated by rotating tau
    # blocks through all ranks; stored in fswap for pass 2.
    wVooV = numpy.zeros((nvir_seg, nocc, nocc, nvir))
    eris_voov = _cp(eris.ovvo).transpose(1, 0, 3, 2)
    tau = t2T * .5
    tau += numpy.einsum('ai,bj->abij', t1T[vloc0:vloc1], t1T)
    for task_id, tau, p0, p1 in _rotate_vir_block(tau):
        wVooV += lib.einsum('bkic,cajk->bija', eris_voov[:, :, :, p0:p1], tau)
    fswap['wVooV1'] = wVooV
    wVooV = tau = None
    time1 = log.timer_debug1('wVooV', *time1)

    # wVOov aliases eris_voov (a private copy, presumably, since it comes
    # from _cp), so the += below also updates eris_voov.  eris_VOov was
    # computed before that.
    wVOov = eris_voov
    eris_VOov = eris_voov - eris_voov.transpose(0, 2, 1, 3) * .5
    tau = t2T.transpose(2, 0, 3, 1) - t2T.transpose(3, 0, 2, 1) * .5
    tau -= numpy.einsum('ai,bj->jaib', t1T[vloc0:vloc1], t1T)
    for task_id, tau, p0, p1 in _rotate_vir_block(tau):
        wVOov += lib.einsum('dlkc,kcjb->dljb', eris_VOov[:, :, :, p0:p1], tau)
    fswap['wVOov1'] = wVOov
    wVOov = tau = eris_VOov = eris_voov = None
    time1 = log.timer_debug1('wVOov', *time1)

    t1Tnew = numpy.zeros_like(t1T)
    t2Tnew = mycc._add_vvvv(t1T, t2T, eris, t2sym='jiba')
    time1 = log.timer_debug1('vvvv', *time1)

    #** make_inter_F
    # Fock-like intermediates.  The *_priv arrays collect this rank's
    # partial contributions and are summed over ranks (allreduce) later.
    fov = fock[:nocc, nocc:].copy()
    t1Tnew += fock[nocc:, :nocc]

    foo = fock[:nocc, :nocc] - numpy.diag(mo_e_o)
    foo += .5 * numpy.einsum('ia,aj->ij', fock[:nocc, nocc:], t1T)

    fvv = fock[nocc:, nocc:] - numpy.diag(mo_e_v)
    fvv -= .5 * numpy.einsum('ai,ib->ab', t1T, fock[:nocc, nocc:])

    foo_priv = numpy.zeros_like(foo)
    fov_priv = numpy.zeros_like(fov)
    fvv_priv = numpy.zeros_like(fvv)
    t1T_priv = numpy.zeros_like(t1T)

    # Pass 1: stream eris.vvvo in blocks of the local virtual segment.
    max_memory = mycc.max_memory - lib.current_memory()[0]
    unit = nocc * nvir**2 * 3 + nocc**2 * nvir + 1
    blksize = min(nvir, max(BLKMIN, int((max_memory * .9e6 / 8 - t2T.size) / unit)))
    log.debug1('pass 1, max_memory %d MB, nocc,nvir = %d,%d blksize = %d',
               max_memory, nocc, nvir, blksize)

    buf = numpy.empty((blksize, nvir, nvir, nocc))

    # Fill `buf` (looked up in the enclosing scope when called) with the
    # local vvvo block starting at local index p0.
    def load_vvvo(p0):
        p1 = min(nvir_seg, p0 + blksize)
        if p0 < p1:
            buf[:p1 - p0] = eris.vvvo[p0:p1]

    fswap.create_dataset('wVooV', (nvir_seg, nocc, nocc, nvir), 'f8')
    wVOov = []

    with lib.call_in_background(load_vvvo) as prefetch:
        load_vvvo(0)
        for p0, p1 in lib.prange(vloc0, vloc1, blksize):
            i0, i1 = p0 - vloc0, p1 - vloc0
            # Take the buffer filled by the previous prefetch and hand a
            # fresh one to the next.  prefetch(i1) presumably joins the
            # previous background fill before starting the next one
            # (pyscf call_in_background), so eris_vvvo is complete before
            # it is read below.
            eris_vvvo, buf = buf[:p1 - p0], numpy.empty_like(buf)
            prefetch(i1)

            fvv_priv[p0:p1] += 2 * numpy.einsum('ck,abck->ab', t1T, eris_vvvo)
            fvv_priv -= numpy.einsum('ck,cabk->ab', t1T[p0:p1], eris_vvvo)

            if not mycc.direct:
                raise NotImplementedError
                # Unreachable: non-direct vvvo * tau * t1 contribution, kept
                # for reference.
                # NOTE(review): the original line layout was lost; confirm
                # these statements belong to this branch.
                tau = t2T[i0:i1] + numpy.einsum('ai,bj->abij', t1T[p0:p1], t1T)
                for task_id, tau, q0, q1 in _rotate_vir_block(tau):
                    tmp = lib.einsum('bdck,cdij->bkij', eris_vvvo[:, :, q0:q1], tau)
                    t2Tnew -= lib.einsum('ak,bkij->baji', t1T, tmp)
                tau = tmp = None

            fswap['wVooV'][i0:i1] = lib.einsum('cj,baci->bija', -t1T, eris_vvvo)

            theta = t2T[i0:i1].transpose(0, 2, 1, 3) * 2
            theta -= t2T[i0:i1].transpose(0, 3, 1, 2)
            t1T_priv += lib.einsum('bicj,bacj->ai', theta, eris_vvvo)
            wVOov.append(lib.einsum('acbi,cj->abij', eris_vvvo, t1T))
            theta = eris_vvvo = None
            time1 = log.timer_debug1('vvvo [%d:%d]' % (p0, p1), *time1)

    # Redistribute the collected wVOov blocks so each rank holds its own
    # b-segment, then store as (nvir_seg, nocc, nocc, nvir).
    wVOov = numpy.vstack(wVOov)
    wVOov = mpi.alltoall_new([wVOov[:, q0:q1] for q0, q1 in vlocs], split_recvbuf=True)
    wVOov = numpy.vstack([x.reshape(-1, nvir_seg, nocc, nocc) for x in wVOov])
    fswap['wVOov'] = wVOov.transpose(1, 2, 3, 0)
    wVooV = None

    # Pass 2: ovoo / oovv / ovvo blocks, with prefetching in background
    # threads.
    unit = nocc**2 * nvir * 7 + nocc**3 + nocc * nvir**2
    max_memory = max(0, mycc.max_memory - lib.current_memory()[0])
    blksize = min(nvir, max(BLKMIN, int((max_memory * .9e6 / 8 - nocc**4) / unit)))
    log.debug1('pass 2, max_memory %d MB, nocc,nvir = %d,%d blksize = %d',
               max_memory, nocc, nvir, blksize)

    woooo = numpy.zeros((nocc, nocc, nocc, nocc))

    for p0, p1 in lib.prange(vloc0, vloc1, blksize):
        i0, i1 = p0 - vloc0, p1 - vloc0
        wVOov = fswap['wVOov'][i0:i1]
        wVooV = fswap['wVooV'][i0:i1]
        eris_ovoo = eris.ovoo[:, i0:i1]
        eris_oovv = numpy.empty((nocc, nocc, i1 - i0, nvir))

        def load_oovv(p0, p1):
            eris_oovv[:] = eris.oovv[:, :, p0:p1]

        # oovv is loaded in the background while the ovoo terms run.
        with lib.call_in_background(load_oovv) as prefetch_oovv:
            #:eris_oovv = eris.oovv[:,:,i0:i1]
            prefetch_oovv(i0, i1)
            foo_priv += numpy.einsum('ck,kcji->ij', 2 * t1T[p0:p1], eris_ovoo)
            foo_priv += numpy.einsum('ck,icjk->ij', -t1T[p0:p1], eris_ovoo)
            tmp = lib.einsum('al,jaik->lkji', t1T[p0:p1], eris_ovoo)
            woooo += tmp + tmp.transpose(1, 0, 3, 2)
            tmp = None

            wVOov -= lib.einsum('jbik,ak->bjia', eris_ovoo, t1T)
            t2Tnew[i0:i1] += wVOov.transpose(0, 3, 1, 2)

            wVooV += lib.einsum('kbij,ak->bija', eris_ovoo, t1T)
            eris_ovoo = None
        load_oovv = prefetch_oovv = None

        eris_ovvo = numpy.empty((nocc, i1 - i0, nvir, nocc))

        def load_ovvo(p0, p1):
            eris_ovvo[:] = eris.ovvo[:, p0:p1]

        # ovvo is loaded in the background while the oovv terms run.
        with lib.call_in_background(load_ovvo) as prefetch_ovvo:
            #:eris_ovvo = eris.ovvo[:,i0:i1]
            prefetch_ovvo(i0, i1)
            t1T_priv[p0:p1] -= numpy.einsum('bj,jiab->ai', t1T, eris_oovv)
            wVooV -= eris_oovv.transpose(2, 0, 1, 3)
            wVOov += wVooV * .5  #: bjia + bija*.5
        eris_voov = eris_ovvo.transpose(1, 0, 3, 2)
        eris_ovvo = None
        load_ovvo = prefetch_ovvo = None

        # Add the pass-1 partials (wVooV1 / wVOov1) into the current block
        # and write the sums back to fswap, in a background thread.
        def update_wVooV(i0, i1):
            wVooV[:] += fswap['wVooV1'][i0:i1]
            fswap['wVooV1'][i0:i1] = wVooV
            wVOov[:] += fswap['wVOov1'][i0:i1]
            fswap['wVOov1'][i0:i1] = wVOov

        with lib.call_in_background(update_wVooV) as update_wVooV:
            update_wVooV(i0, i1)
            t2Tnew[i0:i1] += eris_voov.transpose(0, 3, 1, 2) * .5
            t1T_priv[p0:p1] += 2 * numpy.einsum('bj,aijb->ai', t1T, eris_voov)

            tmp = lib.einsum('ci,kjbc->bijk', t1T, eris_oovv)
            tmp += lib.einsum('bjkc,ci->bjik', eris_voov, t1T)
            t2Tnew[i0:i1] -= lib.einsum('bjik,ak->baji', tmp, t1T)
            eris_oovv = tmp = None

            fov_priv[:, p0:p1] += numpy.einsum('ck,aikc->ia', t1T, eris_voov) * 2
            fov_priv[:, p0:p1] -= numpy.einsum('ck,akic->ia', t1T, eris_voov)

            tau = numpy.einsum('ai,bj->abij', t1T[p0:p1] * .5, t1T)
            tau += t2T[i0:i1]
            theta = tau.transpose(0, 1, 3, 2) * 2
            theta -= tau
            fvv_priv -= lib.einsum('caij,cjib->ab', theta, eris_voov)
            foo_priv += lib.einsum('aikb,abkj->ij', eris_voov, theta)
            tau = theta = None

            tau = t2T[i0:i1] + numpy.einsum('ai,bj->abij', t1T[p0:p1], t1T)
            woooo += lib.einsum('abij,aklb->ijkl', tau, eris_voov)
            tau = None
        eris_VOov = wVOov = wVooV = update_wVooV = None
        time1 = log.timer_debug1('voov [%d:%d]' % (p0, p1), *time1)

    # Contract the completed wVooV / wVOov with t2, rotating blocks through
    # all ranks.
    wVooV = _cp(fswap['wVooV1'])
    for task_id, wVooV, p0, p1 in _rotate_vir_block(wVooV):
        tmp = lib.einsum('ackj,ckib->ajbi', t2T[:, p0:p1], wVooV)
        t2Tnew += tmp.transpose(0, 2, 3, 1)
        t2Tnew += tmp.transpose(0, 2, 1, 3) * .5
    wVooV = tmp = None
    time1 = log.timer_debug1('contracting wVooV', *time1)

    wVOov = _cp(fswap['wVOov1'])
    theta = t2T * 2
    theta -= t2T.transpose(0, 1, 3, 2)
    for task_id, wVOov, p0, p1 in _rotate_vir_block(wVOov):
        t2Tnew += lib.einsum('acik,ckjb->abij', theta[:, p0:p1], wVOov)
    wVOov = theta = None
    fswap = None
    time1 = log.timer_debug1('contracting wVOov', *time1)

    # Combine the per-rank Fock-like partials.
    foo += mpi.allreduce(foo_priv)
    fov += mpi.allreduce(fov_priv)
    fvv += mpi.allreduce(fvv_priv)

    theta = t2T.transpose(0, 1, 3, 2) * 2 - t2T
    t1T_priv[vloc0:vloc1] += numpy.einsum('jb,abji->ai', fov, theta)
    ovoo = _cp(eris.ovoo)
    for task_id, ovoo, p0, p1 in _rotate_vir_block(ovoo):
        t1T_priv[vloc0:vloc1] -= lib.einsum('jbki,abjk->ai', ovoo, theta[:, p0:p1])
    theta = ovoo = None

    woooo = mpi.allreduce(woooo)
    woooo += _cp(eris.oooo).transpose(0, 2, 1, 3)
    tau = t2T + numpy.einsum('ai,bj->abij', t1T[vloc0:vloc1], t1T)
    t2Tnew += .5 * lib.einsum('abkl,ijkl->abij', tau, woooo)
    tau = woooo = None

    t1Tnew += mpi.allreduce(t1T_priv)

    ft_ij = foo + numpy.einsum('aj,ia->ij', .5 * t1T, fov)
    ft_ab = fvv - numpy.einsum('ai,ib->ab', .5 * t1T, fov)
    t2Tnew += lib.einsum('acij,bc->abij', t2T, ft_ab)
    t2Tnew -= lib.einsum('ki,abkj->abij', ft_ij, t2T)

    eia = mo_e_o[:, None] - mo_e_v
    t1Tnew += numpy.einsum('bi,ab->ai', t1T, fvv)
    t1Tnew -= numpy.einsum('aj,ji->ai', t1T, foo)
    t1Tnew /= eia.T

    # Symmetrize: t2Tnew[a, b, i, j] += t2Tnew[b, a, j, i]; the partner
    # blocks come from the ranks owning b.
    t2tmp = mpi.alltoall_new([t2Tnew[:, p0:p1] for p0, p1 in vlocs], split_recvbuf=True)
    for task_id, (p0, p1) in enumerate(vlocs):
        tmp = t2tmp[task_id].reshape(p1 - p0, nvir_seg, nocc, nocc)
        t2Tnew[:, p0:p1] += tmp.transpose(1, 0, 3, 2)

    # Divide by the orbital-energy denominators for the local a-segment.
    for i in range(vloc0, vloc1):
        t2Tnew[i - vloc0] /= lib.direct_sum('i+jb->bij', eia[:, i], eia)

    time0 = log.timer_debug1('update t1 t2', *time0)
    return t1Tnew.T, t2Tnew.transpose(2, 3, 0, 1)