def save(count, tmp_xo):
    """Reduce-scatter a partial 4-index block and store it in two layouts.

    ``tmp_xo`` is cut along axis 0 by ``olocs`` and exchanged with
    ``mpi.alltoall``.  The pieces received from all ranks are summed into a
    ``(nocc_seg, nocc, di, dj)`` block, stored as ``ftmp['<count>b']``.
    That block is then redistributed along axis 1, and the received pieces
    are stacked into ``ftmp['<count>a']``.
    NOTE(review): input presumably has shape (nocc, nocc, di, dj) — confirm.
    """
    di, dj = tmp_xo.shape[2:4]
    key = str(count)

    # Send every rank the axis-0 rows it owns; summing what arrives completes
    # the reduction for this rank's segment.
    outgoing = [tmp_xo[lo:hi] for lo, hi in olocs]
    received = mpi.alltoall(outgoing, split_recvbuf=True)
    local_xo = sum(received).reshape(nocc_seg, nocc, di, dj)
    ftmp[key + 'b'] = local_xo

    # Swap which axis is distributed: ship column blocks, rebuild row blocks.
    received = mpi.alltoall([local_xo[:, lo:hi] for lo, hi in olocs],
                            split_recvbuf=True)
    pieces = []
    for src, (lo, hi) in enumerate(olocs):
        pieces.append(received[src].reshape(hi - lo, nocc_seg, di, dj))
    ftmp[key + 'a'] = numpy.vstack(pieces)
def vector_to_amplitudes(vector, nmo, nocc):
    """Unpack this rank's slice of the flattened amplitude vector.

    On rank 0, ``vector`` starts with the ``nocc*nvir`` singles, which are
    broadcast to the other ranks.  Those are followed by the local doubles
    segment.  On every other rank, ``vector`` is only the doubles segment.

    The doubles segment holds the lower triangle of the occupied pair index,
    shaped ``(nvir_seg, nvir, nocc*(nocc+1)//2)``.  The missing triangle is
    filled from the blocks other ranks send through ``mpi.alltoall``.

    Returns:
        ``(t1, t2)`` with t1 of shape (nocc, nvir) and t2 of shape
        (nocc, nocc, nvir_seg, nvir).
    """
    nvir = nmo - nocc
    nov = nocc * nvir
    nocc2 = nocc * (nocc + 1) // 2
    vlocs = [_task_location(nvir, tid) for tid in range(mpi.pool.size)]
    vloc0, vloc1 = vlocs[rank]
    nvir_seg = vloc1 - vloc0

    if rank != 0:
        t1T = mpi.bcast(None)
        t2tril = vector.reshape(nvir_seg, nvir, nocc2)
    else:
        t1T = vector[:nov].copy().reshape((nvir, nocc))
        mpi.bcast(t1T)
        t2tril = vector[nov:].reshape(nvir_seg, nvir, nocc2)

    # Local lower-triangle data first.
    t2T = lib.unpack_tril(t2tril.reshape(nvir_seg * nvir, nocc2),
                          filltriu=lib.PLAIN)
    t2T = t2T.reshape(nvir_seg, nvir, nocc, nocc)

    # t2T[a, b, j, i] is taken from t2T[b, a, i, j] held by the rank owning b.
    blocks = mpi.alltoall([t2tril[:, lo:hi] for lo, hi in vlocs],
                          split_recvbuf=True)
    tril_i, tril_j = numpy.tril_indices(nocc)
    for tid, (lo, hi) in enumerate(vlocs):
        remote = blocks[tid].reshape(hi - lo, nvir_seg, nocc2)
        t2T[:, lo:hi, tril_j, tril_i] = remote.transpose(1, 0, 2)
    return t1T.T, t2T.transpose(2, 3, 0, 1)
def alltoall(n, m):
    """Exercise ``mpi.alltoall`` with ragged per-rank buffers and print shapes.

    ``mpi.INT_MAX`` is lowered to 7 while the exchanges run (presumably to
    force the chunked-message path).  It is restored afterwards, even on
    error, so later alltoall calls in the same process are unaffected.

    Args:
        n, m: base dimensions of the 2-D arrays sent to each rank.
    """
    import numpy
    from mpi4pyscf.tools import mpi

    saved_int_max = mpi.INT_MAX
    mpi.INT_MAX = 7
    try:
        arrs = [numpy.ones((n + i - mpi.rank, m - i + mpi.rank))
                for i in range(mpi.pool.size)]
        res = mpi.alltoall(arrs)
        print(res.shape)
        res = mpi.alltoall(arrs, split_recvbuf=True)
        print([x.shape for x in res])

        # Ranks 0-2 send length-3 vectors, the rest send length-1 vectors.
        d1 = 3 if mpi.rank < 3 else 1
        arrs = [numpy.zeros(d1) for _ in range(mpi.pool.size)]
        res = mpi.alltoall(arrs, split_recvbuf=True)
        print([x.shape for x in res])
    finally:
        mpi.INT_MAX = saved_int_max
def save(k, p0, p1, segs):
    """Exchange segments across ranks and write rows of ``feri['j3c/<k>']``.

    ``mpi.alltoall`` always runs, even when this rank ends up writing
    nothing.  The row range [p0, p1) is clipped to ``naux1 - naux0``.  When
    the clipped range is non-empty, each received segment slice is reshaped
    to ``nL`` rows, and the slices are concatenated column-wise.  The column
    layout comes from ``segs_loc_s2`` or ``segs_loc_s1``, selected by
    ``aosym_s2[k]``.
    """
    segs = mpi.alltoall(segs)
    naux1 = nauxs[uniq_inverse[k]]
    limit = naux1 - naux0
    loc0 = min(p0, limit)
    loc1 = min(p1, limit)
    nL = loc1 - loc0
    if nL <= 0:
        return
    seg_locs = segs_loc_s2 if aosym_s2[k] else segs_loc_s1
    columns = [segs[i0 * nL:i1 * nL].reshape(nL, -1) for i0, i1 in seg_locs]
    feri['j3c/%d' % k][loc0:loc1] = numpy.hstack(columns)
def start(self, interval=0.02):
    """Build the local 'vvop' dataset, then start the request-serving thread.

    The dataset ``self.eri_tmp['vvop']`` has shape (nvir_seg, nvir, nocc, nmo).
    Its ``[..., :nocc]`` part comes from ``eris.ovov``; its ``[..., nocc:]``
    part comes from ``eris.vvvo`` blocks gathered from all ranks with
    ``mpi.alltoall``.  Writes are handed to ``lib.call_in_background``.

    Args:
        interval: seconds to sleep between polls for INQUIRY messages.

    Returns:
        The started ``threading.Thread``.  It answers each requested
        (task, slices) with ``self._get_tensor`` and returns when a request
        contains the task ``'Done'``.
    """
    mycc = self._cc
    log = logger.new_logger(mycc)
    # time.clock() no longer exists on Python >= 3.8.  process_time /
    # perf_counter are the CPU/wall clocks recent pyscf logger timers compare
    # against (NOTE(review): confirm for the pinned pyscf version).
    cpu1 = (time.process_time(), time.perf_counter())
    eris = mycc._eris
    nocc, nvir = mycc.t1.shape
    nmo = nocc + nvir
    vloc0, vloc1 = self.vranges[rank]
    nvir_seg = vloc1 - vloc0
    max_memory = min(24000, mycc.max_memory - lib.current_memory()[0])
    blksize = min(nvir_seg // 4 + 1,
                  max(16, int(max_memory * .3e6 / 8 / (nvir * nocc * nmo))))
    self.eri_tmp = lib.H5TmpFile()
    # NOTE(review): the dataset is 'f8' while buf uses t2's dtype; a complex
    # t2 would not fit here — confirm only real amplitudes reach this path.
    vvop = self.eri_tmp.create_dataset('vvop', (nvir_seg, nvir, nocc, nmo), 'f8')
    buf_dtype = mycc.t2.dtype

    def save_vvop(j0, j1, vvvo):
        buf = numpy.empty((j1 - j0, nvir, nocc, nmo), dtype=buf_dtype)
        buf[:, :, :, :nocc] = eris.ovov[:, j0:j1].conj().transpose(1, 3, 0, 2)
        for k, (q0, q1) in enumerate(self.vranges):
            blk = vvvo[k].reshape(q1 - q0, nvir, j1 - j0, nocc)
            buf[:, q0:q1, :, nocc:] = blk.transpose(2, 0, 3, 1)
        vvop[j0:j1] = buf

    with lib.call_in_background(save_vvop) as async_save:
        for p0, p1 in mpi.prange(vloc0, vloc1, blksize):
            j0, j1 = p0 - vloc0, p1 - vloc0
            # Every rank needs the block ranges all other ranks are sending.
            sub_locs = comm.allgather((p0, p1))
            vvvo = mpi.alltoall([eris.vvvo[:, :, q0:q1] for q0, q1 in sub_locs],
                                split_recvbuf=True)
            async_save(j0, j1, vvvo)
            cpu1 = log.timer_debug1('transpose %d:%d' % (p0, p1), *cpu1)

    def send_data():
        while True:
            while comm.Iprobe(source=MPI.ANY_SOURCE, tag=INQUIRY):
                tensors, dest = comm.recv(source=MPI.ANY_SOURCE, tag=INQUIRY)
                for task, slices in tensors:
                    if task == 'Done':
                        return
                    mpi.send(self._get_tensor(task, slices), dest,
                             tag=TRANSFER_DATA)
            time.sleep(interval)

    daemon = threading.Thread(target=send_data)
    daemon.start()
    return daemon
def save(k, p0, p1, segs):
    """Redistribute ``segs`` with alltoall and store the clipped row block.

    The exchange runs unconditionally.  Rows [p0, p1) are clipped to
    ``naux1 - naux0``.  If any rows remain, the received data is cut into
    per-segment slices (``segs_loc_s2`` when ``aosym_s2[k]``, otherwise
    ``segs_loc_s1``).  Each slice is reshaped to that many rows, and the
    slices are joined side by side into ``feri['j3c/<k>']``.
    """
    gathered = mpi.alltoall(segs)
    naux1 = nauxs[uniq_inverse[k]]
    loc0, loc1 = (min(p, naux1 - naux0) for p in (p0, p1))
    nrow = loc1 - loc0
    if nrow > 0:
        locs = segs_loc_s2 if aosym_s2[k] else segs_loc_s1
        block = numpy.hstack([gathered[i0 * nrow:i1 * nrow].reshape(nrow, -1)
                              for i0, i1 in locs])
        feri['j3c/%d' % k][loc0:loc1] = block
def save_vir_frac(p0, p1, eri):
    """Scatter one transformed integral slab into the distributed ``eris`` arrays.

    ``eri`` is reshaped to (p1-p0, nocc, nmo, nmo).  Its [:nocc, :nocc],
    [nocc:, :nocc] and [:nocc, nocc:] sub-blocks are written to
    ``eris.ovoo``, ``eris.ovvo`` and ``eris.ovov`` at ``[:, p0:p1]``.  The
    virtual-virtual sub-block is redistributed with ``mpi.alltoall`` so that
    each rank stores its own first-virtual slice in ``eris.vvvo``.
    """
    log.alldebug1('save_vir_frac %d %d %s', p0, p1, eri.shape)
    eri = eri.reshape(p1 - p0, nocc, nmo, nmo)
    eris.ovoo[:, p0:p1] = eri[:, :, :nocc, :nocc].transpose(1, 0, 2, 3)
    eris.ovvo[:, p0:p1] = eri[:, :, nocc:, :nocc].transpose(1, 0, 2, 3)
    eris.ovov[:, p0:p1] = eri[:, :, :nocc, nocc:].transpose(1, 0, 2, 3)

    # time.clock() no longer exists on Python >= 3.8.
    cput2 = time.process_time(), time.perf_counter()
    ovvv_segs = [eri[:, :, nocc + q0:nocc + q1, nocc:].transpose(2, 3, 0, 1)
                 for q0, q1 in vlocs]
    ovvv_segs = mpi.alltoall(ovvv_segs, split_recvbuf=True)
    cput2 = log.timer_debug1('vvvo alltoall', *cput2)
    # Each sender's (p0, p1) is an offset inside its own virtual segment, so
    # shift it by that segment's start to locate the received columns.
    for task_id, (q0, q1) in enumerate(comm.allgather((p0, p1))):
        ip0 = q0 + vlocs[task_id][0]
        ip1 = q1 + vlocs[task_id][0]
        eris.vvvo[:, :, ip0:ip1] = ovvv_segs[task_id].reshape(
            vseg, nvir, q1 - q0, nocc)
def save_vir_frac(p0, p1, eri):
    """Write one (p1-p0, nocc, nmo, nmo) integral slab into ``eris``.

    The occupied/virtual sub-blocks go into ``eris.ovoo``, ``eris.ovvo`` and
    ``eris.ovov`` at ``[:, p0:p1]``.  The virtual-virtual sub-block is split
    by ``vlocs``, exchanged with ``mpi.alltoall``, and stored in this rank's
    ``eris.vvvo``.
    """
    log.alldebug1('save_vir_frac %d %d %s', p0, p1, eri.shape)
    eri = eri.reshape(p1 - p0, nocc, nmo, nmo)
    eris.ovoo[:, p0:p1] = eri[:, :, :nocc, :nocc].transpose(1, 0, 2, 3)
    eris.ovvo[:, p0:p1] = eri[:, :, nocc:, :nocc].transpose(1, 0, 2, 3)
    eris.ovov[:, p0:p1] = eri[:, :, :nocc, nocc:].transpose(1, 0, 2, 3)

    # time.clock() was removed in Python 3.8; use the process/wall clocks.
    cput2 = time.process_time(), time.perf_counter()
    ovvv_segs = [
        eri[:, :, nocc + q0:nocc + q1, nocc:].transpose(2, 3, 0, 1)
        for q0, q1 in vlocs
    ]
    ovvv_segs = mpi.alltoall(ovvv_segs, split_recvbuf=True)
    cput2 = log.timer_debug1('vvvo alltoall', *cput2)
    # Senders' ranges are local to their own virtual segment; offset them by
    # that segment's start when placing the received data.
    for task_id, (q0, q1) in enumerate(comm.allgather((p0, p1))):
        ip0 = q0 + vlocs[task_id][0]
        ip1 = q1 + vlocs[task_id][0]
        eris.vvvo[:, :, ip0:ip1] = ovvv_segs[task_id].reshape(
            vseg, nvir, q1 - q0, nocc)
def vector_to_amplitudes(vector, nmo, nocc):
    """Rebuild local (t1, t2) from this rank's part of the packed vector.

    Rank 0's vector carries t1 (``nocc*nvir`` entries, broadcast to all
    ranks) ahead of its t2 segment.  Other ranks carry only their t2
    segment.  The t2 segment stores the i >= j occupied pairs; the other
    triangle is assembled from transposed virtual blocks owned by other
    ranks.

    Returns:
        t1 of shape (nocc, nvir) and t2 of shape (nocc, nocc, nvir_seg, nvir).
    """
    nvir = nmo - nocc
    npair = nocc * (nocc + 1) // 2
    vlocs = [_task_location(nvir, task) for task in range(mpi.pool.size)]
    v_start, v_stop = vlocs[rank]
    nvir_seg = v_stop - v_start

    if rank == 0:
        nov = nocc * nvir
        t1T = vector[:nov].copy().reshape((nvir, nocc))
        mpi.bcast(t1T)
        packed = vector[nov:]
    else:
        t1T = mpi.bcast(None)
        packed = vector
    t2tril = packed.reshape(nvir_seg, nvir, npair)

    flat = t2tril.reshape(nvir_seg * nvir, npair)
    t2T = lib.unpack_tril(flat, filltriu=lib.PLAIN).reshape(
        nvir_seg, nvir, nocc, nocc)

    exchanged = mpi.alltoall([t2tril[:, lo:hi] for lo, hi in vlocs],
                             split_recvbuf=True)
    row, col = numpy.tril_indices(nocc)
    for task, (lo, hi) in enumerate(vlocs):
        part = exchanged[task].reshape(hi - lo, nvir_seg, npair)
        # Fill t2T[a, b, j, i] from the peer's t2T[b, a, i, j].
        t2T[:, lo:hi, col, row] = part.transpose(1, 0, 2)

    t1 = t1T.T
    t2 = t2T.transpose(2, 3, 0, 1)
    return t1, t2
def update_amps(mycc, t1, t2, eris):
    """Compute the next (t1, t2) amplitudes for this rank's virtual segment.

    Each rank owns the virtual slice ``vlocs[rank]`` of t2.  Terms that need
    other ranks' slices are formed by rotating blocks with
    ``_rotate_tensor_block`` or by ``mpi.alltoall``.  One-particle pieces
    are summed with ``mpi.allreduce``.  Large intermediates are staged in a
    ``lib.H5TmpFile``.

    Args:
        mycc: CC object (reads max_memory, level_shift, direct, stdout,
            verbose and calls _add_vvvv).
        t1: (nocc, nvir) singles amplitudes.
        t2: (nocc, nocc, nvir_seg, nvir) local doubles amplitudes.
        eris: integral container (fock, mo_energy, ovvo, vvvo, ovoo, oovv,
            oooo).

    Returns:
        (t1new, t2new) in the same layouts as (t1, t2), divided by the
        orbital-energy denominators (virtual energies shifted by
        ``mycc.level_shift``).

    Raises:
        NotImplementedError: if ``mycc.direct`` is false (raised inside the
            first block loop).
    """
    # time.clock() no longer exists on Python >= 3.8.  process_time /
    # perf_counter are the CPU/wall clocks recent pyscf logger timers compare
    # against (NOTE(review): confirm for the pinned pyscf version).
    time1 = time0 = time.process_time(), time.perf_counter()
    log = logger.Logger(mycc.stdout, mycc.verbose)
    t1T = t1.T
    t2T = numpy.asarray(t2.transpose(2, 3, 0, 1), order='C')
    nvir_seg, nvir, nocc = t2T.shape[:3]
    t1 = t2 = None
    ntasks = mpi.pool.size
    vlocs = [_task_location(nvir, task_id) for task_id in range(ntasks)]
    vloc0, vloc1 = vlocs[rank]
    log.debug2('vlocs %s', vlocs)
    assert (vloc1 - vloc0 == nvir_seg)

    fock = eris.fock
    mo_e_o = eris.mo_energy[:nocc]
    mo_e_v = eris.mo_energy[nocc:] + mycc.level_shift

    def _rotate_vir_block(buf):
        for task_id, buf in _rotate_tensor_block(buf):
            loc0, loc1 = vlocs[task_id]
            yield task_id, buf, loc0, loc1

    fswap = lib.H5TmpFile()
    wVooV = numpy.zeros((nvir_seg, nocc, nocc, nvir))
    eris_voov = _cp(eris.ovvo).transpose(1, 0, 3, 2)
    tau = t2T * .5
    tau += numpy.einsum('ai,bj->abij', t1T[vloc0:vloc1], t1T)
    for task_id, tau, p0, p1 in _rotate_vir_block(tau):
        wVooV += lib.einsum('bkic,cajk->bija', eris_voov[:, :, :, p0:p1], tau)
    fswap['wVooV1'] = wVooV
    wVooV = tau = None
    time1 = log.timer_debug1('wVooV', *time1)

    # wVOov reuses (and modifies in place) the eris_voov buffer; eris_VOov
    # is a separate array computed before that buffer is touched.
    wVOov = eris_voov
    eris_VOov = eris_voov - eris_voov.transpose(0, 2, 1, 3) * .5
    tau = t2T.transpose(2, 0, 3, 1) - t2T.transpose(3, 0, 2, 1) * .5
    tau -= numpy.einsum('ai,bj->jaib', t1T[vloc0:vloc1], t1T)
    for task_id, tau, p0, p1 in _rotate_vir_block(tau):
        wVOov += lib.einsum('dlkc,kcjb->dljb', eris_VOov[:, :, :, p0:p1], tau)
    fswap['wVOov1'] = wVOov
    wVOov = tau = eris_VOov = eris_voov = None
    time1 = log.timer_debug1('wVOov', *time1)

    t1Tnew = numpy.zeros_like(t1T)
    t2Tnew = mycc._add_vvvv(t1T, t2T, eris, t2sym='jiba')
    time1 = log.timer_debug1('vvvv', *time1)

    #** make_inter_F
    fov = fock[:nocc, nocc:].copy()
    t1Tnew += fock[nocc:, :nocc]

    foo = fock[:nocc, :nocc] - numpy.diag(mo_e_o)
    foo += .5 * numpy.einsum('ia,aj->ij', fock[:nocc, nocc:], t1T)

    fvv = fock[nocc:, nocc:] - numpy.diag(mo_e_v)
    fvv -= .5 * numpy.einsum('ai,ib->ab', t1T, fock[:nocc, nocc:])

    # Per-rank partial sums, combined with allreduce later.
    foo_priv = numpy.zeros_like(foo)
    fov_priv = numpy.zeros_like(fov)
    fvv_priv = numpy.zeros_like(fvv)
    t1T_priv = numpy.zeros_like(t1T)

    # Pass 1: loop over blocks of eris.vvvo.
    max_memory = mycc.max_memory - lib.current_memory()[0]
    unit = nocc * nvir**2 * 3 + nocc**2 * nvir + 1
    blksize = min(nvir, max(BLKMIN, int((max_memory * .9e6 / 8 - t2T.size) / unit)))
    log.debug1('pass 1, max_memory %d MB, nocc,nvir = %d,%d blksize = %d',
               max_memory, nocc, nvir, blksize)

    buf = numpy.empty((blksize, nvir, nvir, nocc))

    def load_vvvo(p0):
        p1 = min(nvir_seg, p0 + blksize)
        if p0 < p1:
            buf[:p1 - p0] = eris.vvvo[p0:p1]

    fswap.create_dataset('wVooV', (nvir_seg, nocc, nocc, nvir), 'f8')
    wVOov = []

    with lib.call_in_background(load_vvvo) as prefetch:
        load_vvvo(0)
        for p0, p1 in lib.prange(vloc0, vloc1, blksize):
            i0, i1 = p0 - vloc0, p1 - vloc0
            # Take the filled buffer and give the prefetcher a fresh one.
            eris_vvvo, buf = buf[:p1 - p0], numpy.empty_like(buf)
            prefetch(i1)

            fvv_priv[p0:p1] += 2 * numpy.einsum('ck,abck->ab', t1T, eris_vvvo)
            fvv_priv -= numpy.einsum('ck,cabk->ab', t1T[p0:p1], eris_vvvo)

            if not mycc.direct:
                raise NotImplementedError
                # Unreachable: reference for a future non-direct path.
                tau = t2T[i0:i1] + numpy.einsum('ai,bj->abij', t1T[p0:p1], t1T)
                for task_id, tau, q0, q1 in _rotate_vir_block(tau):
                    tmp = lib.einsum('bdck,cdij->bkij', eris_vvvo[:, :, q0:q1], tau)
                    t2Tnew -= lib.einsum('ak,bkij->baji', t1T, tmp)
                tau = tmp = None

            fswap['wVooV'][i0:i1] = lib.einsum('cj,baci->bija', -t1T, eris_vvvo)

            theta = t2T[i0:i1].transpose(0, 2, 1, 3) * 2
            theta -= t2T[i0:i1].transpose(0, 3, 1, 2)
            t1T_priv += lib.einsum('bicj,bacj->ai', theta, eris_vvvo)
            wVOov.append(lib.einsum('acbi,cj->abij', eris_vvvo, t1T))
            theta = eris_vvvo = None
            time1 = log.timer_debug1('vvvo [%d:%d]' % (p0, p1), *time1)

    wVOov = numpy.vstack(wVOov)
    wVOov = mpi.alltoall([wVOov[:, q0:q1] for q0, q1 in vlocs], split_recvbuf=True)
    wVOov = numpy.vstack([x.reshape(-1, nvir_seg, nocc, nocc) for x in wVOov])
    fswap['wVOov'] = wVOov.transpose(1, 2, 3, 0)
    wVooV = None

    # Pass 2: loop over blocks of ovoo / oovv / ovvo.
    unit = nocc**2 * nvir * 7 + nocc**3 + nocc * nvir**2
    max_memory = max(0, mycc.max_memory - lib.current_memory()[0])
    blksize = min(nvir, max(BLKMIN, int((max_memory * .9e6 / 8 - nocc**4) / unit)))
    log.debug1('pass 2, max_memory %d MB, nocc,nvir = %d,%d blksize = %d',
               max_memory, nocc, nvir, blksize)

    woooo = numpy.zeros((nocc, nocc, nocc, nocc))

    for p0, p1 in lib.prange(vloc0, vloc1, blksize):
        i0, i1 = p0 - vloc0, p1 - vloc0
        wVOov = fswap['wVOov'][i0:i1]
        wVooV = fswap['wVooV'][i0:i1]
        eris_ovoo = eris.ovoo[:, i0:i1]
        eris_oovv = numpy.empty((nocc, nocc, i1 - i0, nvir))

        def load_oovv(p0, p1):
            eris_oovv[:] = eris.oovv[:, :, p0:p1]

        with lib.call_in_background(load_oovv) as prefetch_oovv:
            #:eris_oovv = eris.oovv[:,:,i0:i1]
            prefetch_oovv(i0, i1)
            foo_priv += numpy.einsum('ck,kcji->ij', 2 * t1T[p0:p1], eris_ovoo)
            foo_priv += numpy.einsum('ck,icjk->ij', -t1T[p0:p1], eris_ovoo)
            tmp = lib.einsum('al,jaik->lkji', t1T[p0:p1], eris_ovoo)
            woooo += tmp + tmp.transpose(1, 0, 3, 2)
            tmp = None

            wVOov -= lib.einsum('jbik,ak->bjia', eris_ovoo, t1T)
            t2Tnew[i0:i1] += wVOov.transpose(0, 3, 1, 2)

            wVooV += lib.einsum('kbij,ak->bija', eris_ovoo, t1T)
            eris_ovoo = None
        load_oovv = prefetch_oovv = None

        eris_ovvo = numpy.empty((nocc, i1 - i0, nvir, nocc))

        def load_ovvo(p0, p1):
            eris_ovvo[:] = eris.ovvo[:, p0:p1]

        with lib.call_in_background(load_ovvo) as prefetch_ovvo:
            #:eris_ovvo = eris.ovvo[:,i0:i1]
            prefetch_ovvo(i0, i1)
            t1T_priv[p0:p1] -= numpy.einsum('bj,jiab->ai', t1T, eris_oovv)
            wVooV -= eris_oovv.transpose(2, 0, 1, 3)
            wVOov += wVooV * .5  #: bjia + bija*.5
        # eris_ovvo is only used after its background load has finished.
        eris_voov = eris_ovvo.transpose(1, 0, 3, 2)
        eris_ovvo = None
        load_ovvo = prefetch_ovvo = None

        def update_wVooV(i0, i1):
            wVooV[:] += fswap['wVooV1'][i0:i1]
            fswap['wVooV1'][i0:i1] = wVooV
            wVOov[:] += fswap['wVOov1'][i0:i1]
            fswap['wVOov1'][i0:i1] = wVOov

        with lib.call_in_background(update_wVooV) as update_wVooV:
            update_wVooV(i0, i1)
            t2Tnew[i0:i1] += eris_voov.transpose(0, 3, 1, 2) * .5
            t1T_priv[p0:p1] += 2 * numpy.einsum('bj,aijb->ai', t1T, eris_voov)

            tmp = lib.einsum('ci,kjbc->bijk', t1T, eris_oovv)
            tmp += lib.einsum('bjkc,ci->bjik', eris_voov, t1T)
            t2Tnew[i0:i1] -= lib.einsum('bjik,ak->baji', tmp, t1T)
            eris_oovv = tmp = None

            fov_priv[:, p0:p1] += numpy.einsum('ck,aikc->ia', t1T, eris_voov) * 2
            fov_priv[:, p0:p1] -= numpy.einsum('ck,akic->ia', t1T, eris_voov)

            tau = numpy.einsum('ai,bj->abij', t1T[p0:p1] * .5, t1T)
            tau += t2T[i0:i1]
            theta = tau.transpose(0, 1, 3, 2) * 2
            theta -= tau
            fvv_priv -= lib.einsum('caij,cjib->ab', theta, eris_voov)
            foo_priv += lib.einsum('aikb,abkj->ij', eris_voov, theta)
            tau = theta = None

            tau = t2T[i0:i1] + numpy.einsum('ai,bj->abij', t1T[p0:p1], t1T)
            woooo += lib.einsum('abij,aklb->ijkl', tau, eris_voov)
            tau = None
        # Released only after the background update (which reads wVooV and
        # wVOov by closure) has completed.
        eris_VOov = wVOov = wVooV = update_wVooV = None
        time1 = log.timer_debug1('voov [%d:%d]' % (p0, p1), *time1)

    wVooV = _cp(fswap['wVooV1'])
    for task_id, wVooV, p0, p1 in _rotate_vir_block(wVooV):
        tmp = lib.einsum('ackj,ckib->ajbi', t2T[:, p0:p1], wVooV)
        t2Tnew += tmp.transpose(0, 2, 3, 1)
        t2Tnew += tmp.transpose(0, 2, 1, 3) * .5
    wVooV = tmp = None
    time1 = log.timer_debug1('contracting wVooV', *time1)

    wVOov = _cp(fswap['wVOov1'])
    theta = t2T * 2
    theta -= t2T.transpose(0, 1, 3, 2)
    for task_id, wVOov, p0, p1 in _rotate_vir_block(wVOov):
        t2Tnew += lib.einsum('acik,ckjb->abij', theta[:, p0:p1], wVOov)
    wVOov = theta = None
    fswap = None
    time1 = log.timer_debug1('contracting wVOov', *time1)

    foo += mpi.allreduce(foo_priv)
    fov += mpi.allreduce(fov_priv)
    fvv += mpi.allreduce(fvv_priv)

    theta = t2T.transpose(0, 1, 3, 2) * 2 - t2T
    t1T_priv[vloc0:vloc1] += numpy.einsum('jb,abji->ai', fov, theta)
    ovoo = _cp(eris.ovoo)
    for task_id, ovoo, p0, p1 in _rotate_vir_block(ovoo):
        t1T_priv[vloc0:vloc1] -= lib.einsum('jbki,abjk->ai', ovoo, theta[:, p0:p1])
    theta = ovoo = None

    woooo = mpi.allreduce(woooo)
    woooo += _cp(eris.oooo).transpose(0, 2, 1, 3)
    tau = t2T + numpy.einsum('ai,bj->abij', t1T[vloc0:vloc1], t1T)
    t2Tnew += .5 * lib.einsum('abkl,ijkl->abij', tau, woooo)
    tau = woooo = None

    t1Tnew += mpi.allreduce(t1T_priv)

    ft_ij = foo + numpy.einsum('aj,ia->ij', .5 * t1T, fov)
    ft_ab = fvv - numpy.einsum('ai,ib->ab', .5 * t1T, fov)
    t2Tnew += lib.einsum('acij,bc->abij', t2T, ft_ab)
    t2Tnew -= lib.einsum('ki,abkj->abij', ft_ij, t2T)

    eia = mo_e_o[:, None] - mo_e_v
    t1Tnew += numpy.einsum('bi,ab->ai', t1T, fvv)
    t1Tnew -= numpy.einsum('aj,ji->ai', t1T, foo)
    t1Tnew /= eia.T

    # Add the (a,b,i,j) -> (b,a,j,i) permuted partner held by other ranks.
    t2tmp = mpi.alltoall([t2Tnew[:, p0:p1] for p0, p1 in vlocs], split_recvbuf=True)
    for task_id, (p0, p1) in enumerate(vlocs):
        tmp = t2tmp[task_id].reshape(p1 - p0, nvir_seg, nocc, nocc)
        t2Tnew[:, p0:p1] += tmp.transpose(1, 0, 3, 2)

    for i in range(vloc0, vloc1):
        t2Tnew[i - vloc0] /= lib.direct_sum('i+jb->bij', eia[:, i], eia)

    time0 = log.timer_debug1('update t1 t2', *time0)
    return t1Tnew.T, t2Tnew.transpose(2, 3, 0, 1)
def update_amps(mycc, t1, t2, eris):
    """Compute the next (t1, t2) amplitudes for this rank's virtual segment.

    Each rank owns the virtual slice ``vlocs[rank]`` of t2.  Terms that need
    other ranks' slices are formed by rotating blocks with
    ``_rotate_tensor_block`` or by ``mpi.alltoall``.  One-particle pieces
    are summed with ``mpi.allreduce``.  Large intermediates are staged in a
    ``lib.H5TmpFile``.

    Args:
        mycc: CC object (reads max_memory, level_shift, direct, stdout,
            verbose and calls _add_vvvv).
        t1: (nocc, nvir) singles amplitudes.
        t2: (nocc, nocc, nvir_seg, nvir) local doubles amplitudes.
        eris: integral container (fock, mo_energy, ovvo, vvvo, ovoo, oovv,
            oooo).

    Returns:
        (t1new, t2new) in the same layouts as (t1, t2), divided by the
        orbital-energy denominators (virtual energies shifted by
        ``mycc.level_shift``).

    Raises:
        NotImplementedError: if ``mycc.direct`` is false (raised inside the
            first block loop).
    """
    # time.clock() no longer exists on Python >= 3.8.  process_time /
    # perf_counter are the CPU/wall clocks recent pyscf logger timers compare
    # against (NOTE(review): confirm for the pinned pyscf version).
    time1 = time0 = time.process_time(), time.perf_counter()
    log = logger.Logger(mycc.stdout, mycc.verbose)
    t1T = t1.T
    t2T = numpy.asarray(t2.transpose(2, 3, 0, 1), order='C')
    nvir_seg, nvir, nocc = t2T.shape[:3]
    t1 = t2 = None
    ntasks = mpi.pool.size
    vlocs = [_task_location(nvir, task_id) for task_id in range(ntasks)]
    vloc0, vloc1 = vlocs[rank]
    log.debug2('vlocs %s', vlocs)
    assert (vloc1 - vloc0 == nvir_seg)

    fock = eris.fock
    mo_e_o = eris.mo_energy[:nocc]
    mo_e_v = eris.mo_energy[nocc:] + mycc.level_shift

    def _rotate_vir_block(buf):
        for task_id, buf in _rotate_tensor_block(buf):
            loc0, loc1 = vlocs[task_id]
            yield task_id, buf, loc0, loc1

    fswap = lib.H5TmpFile()
    wVooV = numpy.zeros((nvir_seg, nocc, nocc, nvir))
    eris_voov = _cp(eris.ovvo).transpose(1, 0, 3, 2)
    tau = t2T * .5
    tau += numpy.einsum('ai,bj->abij', t1T[vloc0:vloc1], t1T)
    for task_id, tau, p0, p1 in _rotate_vir_block(tau):
        wVooV += lib.einsum('bkic,cajk->bija', eris_voov[:, :, :, p0:p1], tau)
    fswap['wVooV1'] = wVooV
    wVooV = tau = None
    time1 = log.timer_debug1('wVooV', *time1)

    # wVOov reuses (and modifies in place) the eris_voov buffer; eris_VOov
    # is a separate array computed before that buffer is touched.
    wVOov = eris_voov
    eris_VOov = eris_voov - eris_voov.transpose(0, 2, 1, 3) * .5
    tau = t2T.transpose(2, 0, 3, 1) - t2T.transpose(3, 0, 2, 1) * .5
    tau -= numpy.einsum('ai,bj->jaib', t1T[vloc0:vloc1], t1T)
    for task_id, tau, p0, p1 in _rotate_vir_block(tau):
        wVOov += lib.einsum('dlkc,kcjb->dljb', eris_VOov[:, :, :, p0:p1], tau)
    fswap['wVOov1'] = wVOov
    wVOov = tau = eris_VOov = eris_voov = None
    time1 = log.timer_debug1('wVOov', *time1)

    t1Tnew = numpy.zeros_like(t1T)
    t2Tnew = mycc._add_vvvv(t1T, t2T, eris, t2sym='jiba')
    time1 = log.timer_debug1('vvvv', *time1)

    #** make_inter_F
    fov = fock[:nocc, nocc:].copy()
    t1Tnew += fock[nocc:, :nocc]

    foo = fock[:nocc, :nocc] - numpy.diag(mo_e_o)
    foo += .5 * numpy.einsum('ia,aj->ij', fock[:nocc, nocc:], t1T)

    fvv = fock[nocc:, nocc:] - numpy.diag(mo_e_v)
    fvv -= .5 * numpy.einsum('ai,ib->ab', t1T, fock[:nocc, nocc:])

    # Per-rank partial sums, combined with allreduce later.
    foo_priv = numpy.zeros_like(foo)
    fov_priv = numpy.zeros_like(fov)
    fvv_priv = numpy.zeros_like(fvv)
    t1T_priv = numpy.zeros_like(t1T)

    # Pass 1: loop over blocks of eris.vvvo.
    max_memory = mycc.max_memory - lib.current_memory()[0]
    unit = nocc * nvir**2 * 3 + nocc**2 * nvir + 1
    blksize = min(nvir, max(BLKMIN, int((max_memory * .9e6 / 8 - t2T.size) / unit)))
    log.debug1('pass 1, max_memory %d MB, nocc,nvir = %d,%d blksize = %d',
               max_memory, nocc, nvir, blksize)

    buf = numpy.empty((blksize, nvir, nvir, nocc))

    def load_vvvo(p0):
        p1 = min(nvir_seg, p0 + blksize)
        if p0 < p1:
            buf[:p1 - p0] = eris.vvvo[p0:p1]

    fswap.create_dataset('wVooV', (nvir_seg, nocc, nocc, nvir), 'f8')
    wVOov = []

    with lib.call_in_background(load_vvvo) as prefetch:
        load_vvvo(0)
        for p0, p1 in lib.prange(vloc0, vloc1, blksize):
            i0, i1 = p0 - vloc0, p1 - vloc0
            # Take the filled buffer and give the prefetcher a fresh one.
            eris_vvvo, buf = buf[:p1 - p0], numpy.empty_like(buf)
            prefetch(i1)

            fvv_priv[p0:p1] += 2 * numpy.einsum('ck,abck->ab', t1T, eris_vvvo)
            fvv_priv -= numpy.einsum('ck,cabk->ab', t1T[p0:p1], eris_vvvo)

            if not mycc.direct:
                raise NotImplementedError
                # Unreachable: reference for a future non-direct path.
                tau = t2T[i0:i1] + numpy.einsum('ai,bj->abij', t1T[p0:p1], t1T)
                for task_id, tau, q0, q1 in _rotate_vir_block(tau):
                    tmp = lib.einsum('bdck,cdij->bkij', eris_vvvo[:, :, q0:q1], tau)
                    t2Tnew -= lib.einsum('ak,bkij->baji', t1T, tmp)
                tau = tmp = None

            fswap['wVooV'][i0:i1] = lib.einsum('cj,baci->bija', -t1T, eris_vvvo)

            theta = t2T[i0:i1].transpose(0, 2, 1, 3) * 2
            theta -= t2T[i0:i1].transpose(0, 3, 1, 2)
            t1T_priv += lib.einsum('bicj,bacj->ai', theta, eris_vvvo)
            wVOov.append(lib.einsum('acbi,cj->abij', eris_vvvo, t1T))
            theta = eris_vvvo = None
            time1 = log.timer_debug1('vvvo [%d:%d]' % (p0, p1), *time1)

    wVOov = numpy.vstack(wVOov)
    wVOov = mpi.alltoall([wVOov[:, q0:q1] for q0, q1 in vlocs], split_recvbuf=True)
    wVOov = numpy.vstack([x.reshape(-1, nvir_seg, nocc, nocc) for x in wVOov])
    fswap['wVOov'] = wVOov.transpose(1, 2, 3, 0)
    wVooV = None

    # Pass 2: loop over blocks of ovoo / oovv / ovvo.
    unit = nocc**2 * nvir * 7 + nocc**3 + nocc * nvir**2
    max_memory = max(0, mycc.max_memory - lib.current_memory()[0])
    blksize = min(nvir, max(BLKMIN, int((max_memory * .9e6 / 8 - nocc**4) / unit)))
    log.debug1('pass 2, max_memory %d MB, nocc,nvir = %d,%d blksize = %d',
               max_memory, nocc, nvir, blksize)

    woooo = numpy.zeros((nocc, nocc, nocc, nocc))

    for p0, p1 in lib.prange(vloc0, vloc1, blksize):
        i0, i1 = p0 - vloc0, p1 - vloc0
        wVOov = fswap['wVOov'][i0:i1]
        wVooV = fswap['wVooV'][i0:i1]
        eris_ovoo = eris.ovoo[:, i0:i1]
        eris_oovv = numpy.empty((nocc, nocc, i1 - i0, nvir))

        def load_oovv(p0, p1):
            eris_oovv[:] = eris.oovv[:, :, p0:p1]

        with lib.call_in_background(load_oovv) as prefetch_oovv:
            #:eris_oovv = eris.oovv[:,:,i0:i1]
            prefetch_oovv(i0, i1)
            foo_priv += numpy.einsum('ck,kcji->ij', 2 * t1T[p0:p1], eris_ovoo)
            foo_priv += numpy.einsum('ck,icjk->ij', -t1T[p0:p1], eris_ovoo)
            tmp = lib.einsum('al,jaik->lkji', t1T[p0:p1], eris_ovoo)
            woooo += tmp + tmp.transpose(1, 0, 3, 2)
            tmp = None

            wVOov -= lib.einsum('jbik,ak->bjia', eris_ovoo, t1T)
            t2Tnew[i0:i1] += wVOov.transpose(0, 3, 1, 2)

            wVooV += lib.einsum('kbij,ak->bija', eris_ovoo, t1T)
            eris_ovoo = None
        load_oovv = prefetch_oovv = None

        eris_ovvo = numpy.empty((nocc, i1 - i0, nvir, nocc))

        def load_ovvo(p0, p1):
            eris_ovvo[:] = eris.ovvo[:, p0:p1]

        with lib.call_in_background(load_ovvo) as prefetch_ovvo:
            #:eris_ovvo = eris.ovvo[:,i0:i1]
            prefetch_ovvo(i0, i1)
            t1T_priv[p0:p1] -= numpy.einsum('bj,jiab->ai', t1T, eris_oovv)
            wVooV -= eris_oovv.transpose(2, 0, 1, 3)
            wVOov += wVooV * .5  #: bjia + bija*.5
        # eris_ovvo is only used after its background load has finished.
        eris_voov = eris_ovvo.transpose(1, 0, 3, 2)
        eris_ovvo = None
        load_ovvo = prefetch_ovvo = None

        def update_wVooV(i0, i1):
            wVooV[:] += fswap['wVooV1'][i0:i1]
            fswap['wVooV1'][i0:i1] = wVooV
            wVOov[:] += fswap['wVOov1'][i0:i1]
            fswap['wVOov1'][i0:i1] = wVOov

        with lib.call_in_background(update_wVooV) as update_wVooV:
            update_wVooV(i0, i1)
            t2Tnew[i0:i1] += eris_voov.transpose(0, 3, 1, 2) * .5
            t1T_priv[p0:p1] += 2 * numpy.einsum('bj,aijb->ai', t1T, eris_voov)

            tmp = lib.einsum('ci,kjbc->bijk', t1T, eris_oovv)
            tmp += lib.einsum('bjkc,ci->bjik', eris_voov, t1T)
            t2Tnew[i0:i1] -= lib.einsum('bjik,ak->baji', tmp, t1T)
            eris_oovv = tmp = None

            fov_priv[:, p0:p1] += numpy.einsum('ck,aikc->ia', t1T, eris_voov) * 2
            fov_priv[:, p0:p1] -= numpy.einsum('ck,akic->ia', t1T, eris_voov)

            tau = numpy.einsum('ai,bj->abij', t1T[p0:p1] * .5, t1T)
            tau += t2T[i0:i1]
            theta = tau.transpose(0, 1, 3, 2) * 2
            theta -= tau
            fvv_priv -= lib.einsum('caij,cjib->ab', theta, eris_voov)
            foo_priv += lib.einsum('aikb,abkj->ij', eris_voov, theta)
            tau = theta = None

            tau = t2T[i0:i1] + numpy.einsum('ai,bj->abij', t1T[p0:p1], t1T)
            woooo += lib.einsum('abij,aklb->ijkl', tau, eris_voov)
            tau = None
        # Released only after the background update (which reads wVooV and
        # wVOov by closure) has completed.
        eris_VOov = wVOov = wVooV = update_wVooV = None
        time1 = log.timer_debug1('voov [%d:%d]' % (p0, p1), *time1)

    wVooV = _cp(fswap['wVooV1'])
    for task_id, wVooV, p0, p1 in _rotate_vir_block(wVooV):
        tmp = lib.einsum('ackj,ckib->ajbi', t2T[:, p0:p1], wVooV)
        t2Tnew += tmp.transpose(0, 2, 3, 1)
        t2Tnew += tmp.transpose(0, 2, 1, 3) * .5
    wVooV = tmp = None
    time1 = log.timer_debug1('contracting wVooV', *time1)

    wVOov = _cp(fswap['wVOov1'])
    theta = t2T * 2
    theta -= t2T.transpose(0, 1, 3, 2)
    for task_id, wVOov, p0, p1 in _rotate_vir_block(wVOov):
        t2Tnew += lib.einsum('acik,ckjb->abij', theta[:, p0:p1], wVOov)
    wVOov = theta = None
    fswap = None
    time1 = log.timer_debug1('contracting wVOov', *time1)

    foo += mpi.allreduce(foo_priv)
    fov += mpi.allreduce(fov_priv)
    fvv += mpi.allreduce(fvv_priv)

    theta = t2T.transpose(0, 1, 3, 2) * 2 - t2T
    t1T_priv[vloc0:vloc1] += numpy.einsum('jb,abji->ai', fov, theta)
    ovoo = _cp(eris.ovoo)
    for task_id, ovoo, p0, p1 in _rotate_vir_block(ovoo):
        t1T_priv[vloc0:vloc1] -= lib.einsum('jbki,abjk->ai', ovoo, theta[:, p0:p1])
    theta = ovoo = None

    woooo = mpi.allreduce(woooo)
    woooo += _cp(eris.oooo).transpose(0, 2, 1, 3)
    tau = t2T + numpy.einsum('ai,bj->abij', t1T[vloc0:vloc1], t1T)
    t2Tnew += .5 * lib.einsum('abkl,ijkl->abij', tau, woooo)
    tau = woooo = None

    t1Tnew += mpi.allreduce(t1T_priv)

    ft_ij = foo + numpy.einsum('aj,ia->ij', .5 * t1T, fov)
    ft_ab = fvv - numpy.einsum('ai,ib->ab', .5 * t1T, fov)
    t2Tnew += lib.einsum('acij,bc->abij', t2T, ft_ab)
    t2Tnew -= lib.einsum('ki,abkj->abij', ft_ij, t2T)

    eia = mo_e_o[:, None] - mo_e_v
    t1Tnew += numpy.einsum('bi,ab->ai', t1T, fvv)
    t1Tnew -= numpy.einsum('aj,ji->ai', t1T, foo)
    t1Tnew /= eia.T

    # Add the (a,b,i,j) -> (b,a,j,i) permuted partner held by other ranks.
    t2tmp = mpi.alltoall([t2Tnew[:, p0:p1] for p0, p1 in vlocs], split_recvbuf=True)
    for task_id, (p0, p1) in enumerate(vlocs):
        tmp = t2tmp[task_id].reshape(p1 - p0, nvir_seg, nocc, nocc)
        t2Tnew[:, p0:p1] += tmp.transpose(1, 0, 3, 2)

    for i in range(vloc0, vloc1):
        t2Tnew[i - vloc0] /= lib.direct_sum('i+jb->bij', eia[:, i], eia)

    time0 = log.timer_debug1('update t1 t2', *time0)
    return t1Tnew.T, t2Tnew.transpose(2, 3, 0, 1)