def _gamma1_intermediates(mycc, t1, t2, l1, l2):
    """Return the one-particle density-matrix intermediates (doo, dov, dvo, dvv).

    Only the doubles (t2/l2) contributions to the occupied-occupied and
    virtual-virtual blocks are evaluated.  The occupied-virtual blocks are
    returned as zero arrays, and the singles-dependent terms of the full
    expressions are listed below only as reference comments.

    The amplitudes are transposed to virtual-leading order ``(a, b, i, j)``.
    The results are summed over MPI ranks with ``mpi.allreduce`` (presumably
    each rank holds a slice of the leading virtual index -- confirm against
    the caller).  ``mycc`` and ``l1`` are not used.

    Returns:
        tuple ``(doo, dov, dvo, dvv)``; ``dov``/``dvo`` are zero arrays of
        shape ``(nocc, nvir)`` / ``(nvir, nocc)`` with the dtype of ``t1``.
    """
    amp1 = t1.T
    amp2 = t2.transpose(2, 3, 0, 1)
    lam2 = l2.transpose(2, 3, 0, 1)
    # Drop the caller's references to the original arrays.
    t1 = t2 = l1 = l2 = None

    # Reference singles term, not evaluated: doo = -np.dot(l1T.T, t1T)
    doo = mpi.allreduce(einsum('efim, efjm -> ij', lam2, amp2) * (-0.5))
    # Reference singles term, not evaluated: dvv = np.dot(t1T, l1T.T)
    dvv = mpi.allreduce(einsum('eamn, ebmn -> ab', amp2, lam2) * 0.5)

    # Reference (not evaluated):
    #   xt1 = allreduce(einsum('efmn, efin -> mi', l2T, t2T) * 0.5)
    #   xt2 = allreduce(einsum('famn, femn -> ae', t2T, l2T) * 0.5) + t1T.l1T.T
    #   dvo = allgather(einsum('aeim, em -> ai', t2T, l1T))
    #         - t1T.xt1 - xt2.t1T + t1T
    #   dov = l1T.T
    nvir, nocc = amp1.shape
    zero_dtype = amp1.dtype
    dvo = np.zeros((nvir, nocc), dtype=zero_dtype)
    dov = np.zeros((nocc, nvir), dtype=zero_dtype)
    return doo, dov, dvo, dvv
def get_j_kpts(mydf, dm_kpts, hermi=1, kpts=numpy.zeros((1,3)), kpts_band=None):
    """Coulomb (J) matrices of k-point density matrices via real-space FFT.

    The density is accumulated on the real-space grid slices this rank
    receives from ``mydf.mpi_aoR_loop`` and summed over ranks.  The Poisson
    equation is then solved in G space (``tools.fft`` / ``coulG`` /
    ``tools.ifft``), and the potential is contracted back with the AO values
    at ``kpts_band`` (or ``kpts``).  The J matrices are summed with
    ``mpi.reduce``.

    Args:
        mydf: FFT density-fitting object (synchronised with ``_sync_mydf``).
        dm_kpts: density matrices, formatted by ``_format_dms``.
        hermi: accepted for interface compatibility; not used here.
        kpts: k-points of ``dm_kpts``.
        kpts_band: optional k-points at which J is evaluated.

    Returns:
        The J matrices, shaped by ``_format_jks``.
    """
    mydf = _sync_mydf(mydf)
    cell = mydf.cell
    mesh = mydf.mesh

    dm_kpts = lib.asarray(dm_kpts, order='C')
    dms = _format_dms(dm_kpts, kpts)
    nset, nkpts, nao = dms.shape[:3]

    coulG = tools.get_coulG(cell, mesh=mesh)
    ngrids = len(coulG)

    # Density on this rank's grid slices, then summed over all ranks.
    rho_r = numpy.zeros((nset, ngrids))
    for grid_data, g0, g1 in mydf.mpi_aoR_loop(mydf.grids, kpts):
        for kidx, ao_val in enumerate(grid_data[0]):
            for iset in range(nset):
                rho_r[iset, g0:g1] += numint.eval_rho(cell, ao_val, dms[iset, kidx])
        ao_val = grid_data = None
    rho_r = mpi.allreduce(rho_r)

    # Potential on the full grid for every density set.  This is a separate
    # buffer rather than an alias of the pre-reduction density array.
    v_r = numpy.zeros((nset, ngrids))
    for iset in range(nset):
        rho_r[iset] *= 1./nkpts
        v_r[iset] = tools.ifft(coulG * tools.fft(rho_r[iset], mesh), mesh).real

    input_band = kpts_band
    kpts_band = _format_kpts_band(kpts_band, kpts)
    nband = len(kpts_band)
    v_r *= cell.vol / ngrids

    band_is_gamma = gamma_point(kpts_band)
    if band_is_gamma:
        vj_kpts = numpy.zeros((nset, nband, nao, nao))
    else:
        vj_kpts = numpy.zeros((nset, nband, nao, nao), dtype=numpy.complex128)

    for grid_data, g0, g1 in mydf.mpi_aoR_loop(mydf.grids, kpts_band):
        for kidx, ao_val in enumerate(grid_data[0]):
            for iset in range(nset):
                vj_kpts[iset, kidx] += lib.dot(ao_val.T.conj() * v_r[iset, g0:g1],
                                               ao_val)

    vj_kpts = mpi.reduce(vj_kpts)
    if band_is_gamma:
        vj_kpts = vj_kpts.real
    return _format_jks(vj_kpts, dm_kpts, input_band, kpts)
def _make_j3c(mydf, cell, auxcell, kptij_lst, cderi_file):
    """Build the 3-center density-fitting tensor for ``kptij_lst`` and write
    it to ``cderi_file``, sharing the work among MPI ranks.

    Steps:
      1. For every unique k-point difference ``kptj - kpti``, build the
         2-center metric ``j2c`` of the fused auxiliary cell.  This is the
         analytic ``int2c2e`` integral minus a G-space correction that is
         accumulated per rank and summed with ``mpi.allreduce``.  The metric
         is then contracted with ``fuse`` and factorised:
           - Cholesky when possible (tag ``'CD'``);
           - otherwise an eigen-decomposition that drops eigenvalues below
             ``mydf.linear_dep_threshold`` (tag ``'eig'``).
         For 2D cells with ``low_dim_ft_type != 'inf_vacuum'``, the
         negative-eigenvalue part is stored separately as ``'j2c-/<k>'``.
      2. Define the per-job closures:
           - ``gen_int3c``: analytic 3-center integrals for a block of AO
             shells;
           - ``ft_fuse``: G-space correction, ``fuse`` and metric solve.
      3. Pass them to ``_assemble``, which schedules the jobs and writes the
         output file.

    Intermediate data live in a temporary HDF5 swap file created in the
    directory of ``cderi_file``.
    """
    log = logger.Logger(mydf.stdout, mydf.verbose)
    # time.clock() was removed in Python 3.8; fall back to it only on
    # interpreters that lack time.process_time() (Python 2).
    cpu_clock = getattr(time, 'process_time', None) or time.clock
    t1 = t0 = (cpu_clock(), time.time())

    fused_cell, fuse = fuse_auxcell(mydf, mydf.auxcell)
    ao_loc = cell.ao_loc_nr()
    nao = ao_loc[-1]
    naux = auxcell.nao_nr()
    nkptij = len(kptij_lst)
    mesh = mydf.mesh
    Gv, Gvbase, kws = cell.get_Gv_weights(mesh)
    b = cell.reciprocal_vectors()
    gxyz = lib.cartesian_prod([numpy.arange(len(x)) for x in Gvbase])
    ngrids = gxyz.shape[0]

    kptis = kptij_lst[:, 0]
    kptjs = kptij_lst[:, 1]
    kpt_ji = kptjs - kptis
    uniq_kpts, uniq_index, uniq_inverse = unique(kpt_ji)
    log.debug('Num uniq kpts %d', len(uniq_kpts))
    log.debug2('uniq_kpts %s', uniq_kpts)
    # j2c ~ (-kpt_ji | kpt_ji)
    j2c = fused_cell.pbc_intor('int2c2e', hermi=1, kpts=uniq_kpts)
    j2ctags = []
    t1 = log.timer_debug1('2c2e', *t1)

    swapfile = tempfile.NamedTemporaryFile(dir=os.path.dirname(cderi_file))
    fswap = lib.H5TmpFile(swapfile.name)
    # Unlink swapfile to avoid trash
    swapfile = None

    mem_now = max(comm.allgather(lib.current_memory()[0]))
    max_memory = max(2000, mydf.max_memory - mem_now)
    blksize = max(2048, int(max_memory * .5e6 / 16 / fused_cell.nao_nr()))
    log.debug2('max_memory %s (MB) blocksize %s', max_memory, blksize)
    for k, kpt in enumerate(uniq_kpts):
        coulG = mydf.weighted_coulG(kpt, False, mesh)
        # Partial G-space correction for the rows naux: (the compensating
        # functions); summed over ranks by mpi.allreduce below.
        j2c_k = numpy.zeros_like(j2c[k])
        for p0, p1 in mydf.prange(0, ngrids, blksize):
            aoaux = ft_ao.ft_ao(fused_cell, Gv[p0:p1], None, b, gxyz[p0:p1],
                                Gvbase, kpt).T
            LkR = numpy.asarray(aoaux.real, order='C')
            LkI = numpy.asarray(aoaux.imag, order='C')
            aoaux = None

            if is_zero(kpt):  # kpti == kptj
                j2c_k[naux:] += lib.ddot(LkR[naux:] * coulG[p0:p1], LkR.T)
                j2c_k[naux:] += lib.ddot(LkI[naux:] * coulG[p0:p1], LkI.T)
            else:
                j2cR, j2cI = zdotCN(LkR[naux:] * coulG[p0:p1],
                                    LkI[naux:] * coulG[p0:p1], LkR.T, LkI.T)
                j2c_k[naux:] += j2cR + j2cI * 1j
            # Release this block's buffers before the next block.
            LkR = LkI = None

        # Hermitian completion of the off-diagonal block.
        j2c_k[:naux, naux:] = j2c_k[naux:, :naux].conj().T
        j2c[k] -= mpi.allreduce(j2c_k)
        j2c[k] = fuse(fuse(j2c[k]).T).T
        try:
            fswap['j2c/%d' % k] = scipy.linalg.cholesky(j2c[k], lower=True)
            j2ctags.append('CD')
        except scipy.linalg.LinAlgError:
            # The metric is not positive definite (the mesh may be too
            # coarse): keep only eigenvectors above the linear-dependency
            # threshold.
            w, v = scipy.linalg.eigh(j2c[k])
            log.debug2('metric linear dependency for kpt %s', k)
            log.debug2('cond = %.4g, drop %d bfns', w[0] / w[-1],
                       numpy.count_nonzero(w < mydf.linear_dep_threshold))
            v1 = v[:, w > mydf.linear_dep_threshold].T.conj()
            v1 /= numpy.sqrt(w[w > mydf.linear_dep_threshold]).reshape(-1, 1)
            fswap['j2c/%d' % k] = v1
            if cell.dimension == 2 and cell.low_dim_ft_type != 'inf_vacuum':
                idx = numpy.where(w < -mydf.linear_dep_threshold)[0]
                if len(idx) > 0:
                    fswap['j2c-/%d' % k] = (v[:, idx] / numpy.sqrt(-w[idx])).conj().T
            w = v = v1 = None
            j2ctags.append('eig')
    j2c = coulG = None

    aosym_s2 = numpy.einsum('ix->i', abs(kptis - kptjs)) < 1e-9
    j_only = numpy.all(aosym_s2)
    if gamma_point(kptij_lst):
        dtype = 'f8'
    else:
        dtype = 'c16'
    t1 = log.timer_debug1('aoaux and int2c', *t1)

    # Estimates the buffer size based on the last contraction in G-space.
    # This contraction requires to hold nkptj copies of (naux,?) array
    # simultaneously in memory.
    mem_now = max(comm.allgather(lib.current_memory()[0]))
    max_memory = max(2000, mydf.max_memory - mem_now)
    nkptj_max = max((uniq_inverse == x).sum() for x in set(uniq_inverse))
    buflen = max(int(min(max_memory * .5e6 / 16 / naux / (nkptj_max + 2) / nao,
                         nao / 3 / mpi.pool.size)), 1)
    chunks = (buflen, nao)

    j3c_jobs = grids2d_int3c_jobs(cell, auxcell, kptij_lst, chunks, j_only)
    log.debug1('max_memory = %d MB (%d in use) chunks %s',
               max_memory, mem_now, chunks)
    log.debug2('j3c_jobs %s', j3c_jobs)

    if j_only:
        int3c = wrap_int3c(cell, fused_cell, 'int3c2e', 's2', 1, kptij_lst)
    else:
        int3c = wrap_int3c(cell, fused_cell, 'int3c2e', 's1', 1, kptij_lst)
    idxb = numpy.tril_indices(nao)
    idxb = (idxb[0] * nao + idxb[1]).astype('i')
    aux_loc = fused_cell.ao_loc_nr('ssc' in 'int3c2e')

    def gen_int3c(job_id, ish0, ish1):
        """Write analytic 3-center integrals of AO shells [ish0, ish1) to
        fswap under 'j3c-chunks/<job_id>/<kptij index>'."""
        dataname = 'j3c-chunks/%d' % job_id
        i0 = ao_loc[ish0]
        i1 = ao_loc[ish1]
        dii = i1 * (i1 + 1) // 2 - i0 * (i0 + 1) // 2
        if j_only:
            dij = dii
            buflen = max(8, int(max_memory * 1e6 / 16 / (nkptij * dii + dii)))
        else:
            dij = (i1 - i0) * nao
            buflen = max(8, int(max_memory * 1e6 / 16 / (nkptij * dij + dij)))
        auxranges = balance_segs(aux_loc[1:] - aux_loc[:-1], buflen)
        buflen = max([x[2] for x in auxranges])
        buf = numpy.empty(nkptij * dij * buflen, dtype=dtype)
        buf1 = numpy.empty(dij * buflen, dtype=dtype)

        naux = aux_loc[-1]
        for kpt_id, kptij in enumerate(kptij_lst):
            key = '%s/%d' % (dataname, kpt_id)
            if aosym_s2[kpt_id]:
                shape = (naux, dii)
            else:
                shape = (naux, dij)
            if gamma_point(kptij):
                fswap.create_dataset(key, shape, 'f8')
            else:
                fswap.create_dataset(key, shape, 'c16')

        naux0 = 0
        for istep, auxrange in enumerate(auxranges):
            log.alldebug2("aux_e1 job_id %d step %d", job_id, istep)
            sh0, sh1, nrow = auxrange
            sub_slice = (ish0, ish1, 0, cell.nbas, sh0, sh1)
            mat = numpy.ndarray((nkptij, dij, nrow), dtype=dtype, buffer=buf)
            mat = int3c(sub_slice, mat)

            for k, kptij in enumerate(kptij_lst):
                h5dat = fswap['%s/%d' % (dataname, k)]
                v = lib.transpose(mat[k], out=buf1)
                if not j_only and aosym_s2[k]:
                    # Keep only the lower-triangular (s2) AO pairs.
                    idy = idxb[i0 * (i0 + 1) // 2:i1 * (i1 + 1) // 2] - i0 * nao
                    out = numpy.ndarray((nrow, dii), dtype=v.dtype, buffer=mat[k])
                    v = numpy.take(v, idy, axis=1, out=out)
                if gamma_point(kptij):
                    h5dat[naux0:naux0 + nrow] = v.real
                else:
                    h5dat[naux0:naux0 + nrow] = v
            naux0 += nrow

    def ft_fuse(job_id, uniq_kptji_id, sh0, sh1):
        """Apply the G-space correction, fuse and the metric solve to the
        chunks of job ``job_id`` for unique k-point difference
        ``uniq_kptji_id``; results overwrite the chunk rows [:naux0]."""
        kpt = uniq_kpts[uniq_kptji_id]  # kpt = kptj - kpti
        adapted_ji_idx = numpy.where(uniq_inverse == uniq_kptji_id)[0]
        adapted_kptjs = kptjs[adapted_ji_idx]
        nkptj = len(adapted_kptjs)

        j2c = numpy.asarray(fswap['j2c/%d' % uniq_kptji_id])
        j2ctag = j2ctags[uniq_kptji_id]
        naux0 = j2c.shape[0]
        if ('j2c-/%d' % uniq_kptji_id) in fswap:
            j2c_negative = numpy.asarray(fswap['j2c-/%d' % uniq_kptji_id])
        else:
            j2c_negative = None

        if is_zero(kpt):
            aosym = 's2'
        else:
            aosym = 's1'

        if aosym == 's2' and cell.dimension == 3:
            vbar = fuse(mydf.auxbar(fused_cell))
            ovlp = cell.pbc_intor('int1e_ovlp', hermi=1, kpts=adapted_kptjs)
            ovlp = [lib.pack_tril(s) for s in ovlp]

        j3cR = [None] * nkptj
        j3cI = [None] * nkptj
        i0 = ao_loc[sh0]
        i1 = ao_loc[sh1]
        for k, idx in enumerate(adapted_ji_idx):
            key = 'j3c-chunks/%d/%d' % (job_id, idx)
            v = numpy.asarray(fswap[key])
            if aosym == 's2' and cell.dimension == 3:
                for i in numpy.where(vbar != 0)[0]:
                    v[i] -= vbar[i] * ovlp[k][i0 * (i0 + 1) // 2:i1 * (i1 + 1) // 2].ravel()
            j3cR[k] = numpy.asarray(v.real, order='C')
            if v.dtype == numpy.complex128:
                j3cI[k] = numpy.asarray(v.imag, order='C')
            v = None

        ncol = j3cR[0].shape[1]
        Gblksize = max(16, int(max_memory * 1e6 / 16 / ncol / (nkptj + 1)))  # +1 for pqkRbuf/pqkIbuf
        Gblksize = min(Gblksize, ngrids, 16384)
        pqkRbuf = numpy.empty(ncol * Gblksize)
        pqkIbuf = numpy.empty(ncol * Gblksize)
        buf = numpy.empty(nkptj * ncol * Gblksize, dtype=numpy.complex128)
        log.alldebug2('job_id %d blksize (%d,%d)', job_id, Gblksize, ncol)

        wcoulG = mydf.weighted_coulG(kpt, False, mesh)
        fused_cell_slice = (auxcell.nbas, fused_cell.nbas)
        if aosym == 's2':
            shls_slice = (sh0, sh1, 0, sh1)
        else:
            shls_slice = (sh0, sh1, 0, cell.nbas)
        for p0, p1 in lib.prange(0, ngrids, Gblksize):
            Gaux = ft_ao.ft_ao(fused_cell, Gv[p0:p1], fused_cell_slice, b,
                               gxyz[p0:p1], Gvbase, kpt)
            Gaux *= wcoulG[p0:p1, None]
            kLR = Gaux.real.copy('C')
            kLI = Gaux.imag.copy('C')
            Gaux = None

            dat = ft_ao._ft_aopair_kpts(cell, Gv[p0:p1], shls_slice, aosym, b,
                                        gxyz[p0:p1], Gvbase, kpt,
                                        adapted_kptjs, out=buf)
            nG = p1 - p0
            for k, ji in enumerate(adapted_ji_idx):
                aoao = dat[k].reshape(nG, ncol)
                pqkR = numpy.ndarray((ncol, nG), buffer=pqkRbuf)
                pqkI = numpy.ndarray((ncol, nG), buffer=pqkIbuf)
                pqkR[:] = aoao.real.T
                pqkI[:] = aoao.imag.T

                lib.dot(kLR.T, pqkR.T, -1, j3cR[k][naux:], 1)
                lib.dot(kLI.T, pqkI.T, -1, j3cR[k][naux:], 1)
                if not (is_zero(kpt) and gamma_point(adapted_kptjs[k])):
                    lib.dot(kLR.T, pqkI.T, -1, j3cI[k][naux:], 1)
                    lib.dot(kLI.T, pqkR.T, 1, j3cI[k][naux:], 1)
            kLR = kLI = None

        for k, idx in enumerate(adapted_ji_idx):
            if is_zero(kpt) and gamma_point(adapted_kptjs[k]):
                v = fuse(j3cR[k])
            else:
                v = fuse(j3cR[k] + j3cI[k] * 1j)
            if j2ctag == 'CD':
                v = scipy.linalg.solve_triangular(j2c, v, lower=True,
                                                  overwrite_b=True)
                fswap['j3c-chunks/%d/%d' % (job_id, idx)][:naux0] = v
            else:
                fswap['j3c-chunks/%d/%d' % (job_id, idx)][:naux0] = lib.dot(j2c, v)

            # low-dimension systems
            if j2c_negative is not None:
                fswap['j3c-/%d/%d' % (job_id, idx)] = lib.dot(j2c_negative, v)

    _assemble(mydf, kptij_lst, j3c_jobs, gen_int3c, ft_fuse, cderi_file,
              fswap, log)
def _assemble(mydf, kptij_lst, j3c_jobs, gen_int3c, ft_fuse, cderi_file,
              fswap, log):
    """Run the 3-center jobs and assemble the distributed tensor.

    Pass 1: the jobs in ``j3c_jobs`` are handed out with
    ``mpi.work_stealing_partition``.  For each job it receives, this rank
    calls ``gen_int3c(job_id, ish0, ish1)`` and then ``ft_fuse(job_id, k,
    ish0, ish1)`` for every unique k-point difference ``k``.  The chunks are
    read back from ``fswap`` under ``'j3c-chunks/<job>/<kptij>'``.

    Pass 2: the chunks are redistributed with ``mpi.alltoall``, so that
    rank r writes auxiliary rows ``[r*segsize, (r+1)*segsize)`` of every
    ``'j3c/<kptij>'`` dataset in ``cderi_file``.  Any
    ``'j3c-/<job>/<kptij>'`` data (present when ``fswap`` has a ``'j2c-'``
    group) are gathered on rank 0 and stored as ``'j3c-/<kptij>'``.

    Args:
        mydf: density-fitting object (``cell``, ``max_memory`` are read).
        kptij_lst: array of (kpti, kptj) pairs.
        j3c_jobs: sequence of ``(job_id, ish0, ish1)`` shell ranges.
        gen_int3c: per-job callback ``gen_int3c(job_id, ish0, ish1)``.
        ft_fuse: per-job callback ``ft_fuse(job_id, k, ish0, ish1)``.
        cderi_file: output HDF5 path (opened with mode 'w').
        fswap: this rank's HDF5 swap file.
        log: a ``logger.Logger``.
    """
    # time.clock() was removed in Python 3.8; fall back to it only on
    # interpreters that lack time.process_time() (Python 2).
    cpu_clock = getattr(time, 'process_time', None) or time.clock
    t1 = (cpu_clock(), time.time())
    cell = mydf.cell
    ao_loc = cell.ao_loc_nr()
    nao = ao_loc[-1]
    kptis = kptij_lst[:, 0]
    kptjs = kptij_lst[:, 1]
    kpt_ji = kptjs - kptis
    uniq_kpts, uniq_index, uniq_inverse = unique(kpt_ji)
    aosym_s2 = numpy.einsum('ix->i', abs(kptis - kptjs)) < 1e-9

    t2 = t1
    j3c_workers = numpy.zeros(len(j3c_jobs), dtype=int)
    for job_id, ish0, ish1 in mpi.work_stealing_partition(j3c_jobs):
        gen_int3c(job_id, ish0, ish1)
        t2 = log.alltimer_debug2('int j3c %d' % job_id, *t2)

        for k, kpt in enumerate(uniq_kpts):
            ft_fuse(job_id, k, ish0, ish1)
            t2 = log.alltimer_debug2('ft-fuse %d k %d' % (job_id, k), *t2)

        j3c_workers[job_id] = rank
    # After the sum, every rank knows which rank computed each job.
    j3c_workers = mpi.allreduce(j3c_workers)
    log.debug2('j3c_workers %s', j3c_workers)
    t1 = log.timer_debug1('int3c and fuse', *t1)

    # Pass 2
    # Transpose 3-index tensor and save data in cderi_file
    feri = h5py.File(cderi_file, 'w')
    nauxs = [fswap['j2c/%d' % k].shape[0] for k, kpt in enumerate(uniq_kpts)]
    segsize = (max(nauxs) + mpi.pool.size - 1) // mpi.pool.size
    naux0 = rank * segsize
    for k, kptij in enumerate(kptij_lst):
        naux1 = min(nauxs[uniq_inverse[k]], naux0 + segsize)
        nrow = max(0, naux1 - naux0)
        if gamma_point(kptij):
            dtype = 'f8'
        else:
            dtype = 'c16'
        if aosym_s2[k]:
            nao_pair = nao * (nao + 1) // 2
        else:
            nao_pair = nao * nao
        feri.create_dataset('j3c/%d' % k, (nrow, nao_pair), dtype,
                            maxshape=(None, nao_pair))

    def get_segs_loc(aosym):
        """Column ranges of each job inside the rank-ordered concatenation
        of chunks, returned in job_id order."""
        off0 = numpy.asarray([ao_loc[i0] for x, i0, i1 in j3c_jobs])
        off1 = numpy.asarray([ao_loc[i1] for x, i0, i1 in j3c_jobs])
        if aosym:  # s2
            dims = off1 * (off1 + 1) // 2 - off0 * (off0 + 1) // 2
        else:
            dims = (off1 - off0) * nao
        dims = numpy.hstack(
            [dims[j3c_workers == w] for w in range(mpi.pool.size)])
        job_idx = numpy.hstack(
            [numpy.where(j3c_workers == w)[0] for w in range(mpi.pool.size)])
        segs_loc = numpy.append(0, numpy.cumsum(dims))
        segs_loc = [(segs_loc[j], segs_loc[j + 1]) for j in numpy.argsort(job_idx)]
        return segs_loc
    segs_loc_s1 = get_segs_loc(False)
    segs_loc_s2 = get_segs_loc(True)

    job_ids = numpy.where(rank == j3c_workers)[0]

    def load(k, p0, p1):
        """Collect this rank's chunk rows destined for every rank."""
        naux1 = nauxs[uniq_inverse[k]]
        slices = [(min(i * segsize + p0, naux1), min(i * segsize + p1, naux1))
                  for i in range(mpi.pool.size)]
        segs = []
        for p0, p1 in slices:
            val = [fswap['j3c-chunks/%d/%d' % (job, k)][p0:p1].ravel()
                   for job in job_ids]
            if val:
                segs.append(numpy.hstack(val))
            else:
                segs.append(numpy.zeros(0))
        return segs

    def save(k, p0, p1, segs):
        """Exchange the segments and write this rank's rows of j3c/<k>."""
        segs = mpi.alltoall(segs)
        naux1 = nauxs[uniq_inverse[k]]
        loc0, loc1 = min(p0, naux1 - naux0), min(p1, naux1 - naux0)
        nL = loc1 - loc0
        if nL > 0:
            if aosym_s2[k]:
                segs = numpy.hstack([segs[i0 * nL:i1 * nL].reshape(nL, -1)
                                     for i0, i1 in segs_loc_s2])
            else:
                segs = numpy.hstack([segs[i0 * nL:i1 * nL].reshape(nL, -1)
                                     for i0, i1 in segs_loc_s1])
            feri['j3c/%d' % k][loc0:loc1] = segs

    mem_now = max(comm.allgather(lib.current_memory()[0]))
    max_memory = max(2000, min(8000, mydf.max_memory - mem_now))
    if numpy.all(aosym_s2):
        if gamma_point(kptij_lst):
            blksize = max(16, int(max_memory * .5e6 / 8 / nao**2))
        else:
            blksize = max(16, int(max_memory * .5e6 / 16 / nao**2))
    else:
        blksize = max(16, int(max_memory * .5e6 / 16 / nao**2 / 2))
    log.debug1('max_momory %d MB (%d in use), blksize %d',
               max_memory, mem_now, blksize)

    t2 = t1
    with lib.call_in_background(save) as async_write:
        for k, kptji in enumerate(kptij_lst):
            for p0, p1 in lib.prange(0, segsize, blksize):
                segs = load(k, p0, p1)
                async_write(k, p0, p1, segs)
                t2 = log.timer_debug1(
                    'assemble k=%d %d:%d (in %d)' % (k, p0, p1, segsize), *t2)

    if 'j2c-' in fswap:
        j2c_kpts_lists = []
        for k, kpt in enumerate(uniq_kpts):
            if ('j2c-/%d' % k) in fswap:
                adapted_ji_idx = numpy.where(uniq_inverse == k)[0]
                j2c_kpts_lists.append(adapted_ji_idx)

        for k in numpy.hstack(j2c_kpts_lists):
            # Row count of j3c-: taken from this rank's own 'j2c-' metric.
            # The old lookup of 'j3c-/0/<k>' failed whenever job 0 had been
            # computed by another rank.
            naux1 = fswap['j2c-/%d' % uniq_inverse[k]].shape[0]
            val = [numpy.asarray(fswap['j3c-/%d/%d' % (job, k)]).ravel()
                   for job in job_ids]
            if val:
                val = numpy.hstack(val)
            else:
                # A rank that ran no job must still join the gather;
                # numpy.hstack([]) would raise here.
                if gamma_point(kptij_lst[k]):
                    val = numpy.zeros(0, dtype='f8')
                else:
                    val = numpy.zeros(0, dtype='c16')
            val = mpi.gather(val)
            if rank == 0:
                if aosym_s2[k]:
                    v = [val[i0 * naux1:i1 * naux1].reshape(naux1, -1)
                         for i0, i1 in segs_loc_s2]
                else:
                    v = [val[i0 * naux1:i1 * naux1].reshape(naux1, -1)
                         for i0, i1 in segs_loc_s1]
                feri['j3c-/%d' % k] = numpy.hstack(v)

    if 'j3c-kptij' in feri:
        del feri['j3c-kptij']
    feri['j3c-kptij'] = kptij_lst
    t1 = log.alltimer_debug1('assembling j3c', *t1)
    feri.close()
def _make_j3c(mydf, cell, auxcell, kptij_lst, cderi_file):
    """Build the 3-center density-fitting tensor for ``kptij_lst`` directly
    in ``cderi_file``, sharing the work among MPI ranks.

    The 2-center metric of every unique k-point difference is fused and
    factorised:
      - Cholesky when possible (tag ``'CD'``);
      - otherwise an eigen-decomposition that drops eigenvalues below
        ``LINEAR_DEP_THR`` (tag ``'eig'``).
    The factor is stored as ``'j2c/<k>'``.

    The 3-center jobs are run with ``mpi.work_stealing_partition`` into
    ``'j3c-chunks/...'``.  The chunks are then redistributed over the
    auxiliary index with ``mpi.alltoall`` into ``'j3c/<kptij>'``, and the
    chunk group is deleted.  ``cderi_file`` is reused (mode 'a') if it is
    already an HDF5 file.
    """
    log = logger.Logger(mydf.stdout, mydf.verbose)
    # time.clock() was removed in Python 3.8; fall back to it only on
    # interpreters that lack time.process_time() (Python 2).
    cpu_clock = getattr(time, 'process_time', None) or time.clock
    t1 = t0 = (cpu_clock(), time.time())

    fused_cell, fuse = fuse_auxcell(mydf, mydf.auxcell)
    ao_loc = cell.ao_loc_nr()
    nao = ao_loc[-1]
    naux = auxcell.nao_nr()
    nkptij = len(kptij_lst)
    gs = mydf.gs
    Gv, Gvbase, kws = cell.get_Gv_weights(gs)
    b = cell.reciprocal_vectors()
    gxyz = lib.cartesian_prod([numpy.arange(len(x)) for x in Gvbase])
    ngs = gxyz.shape[0]

    kptis = kptij_lst[:, 0]
    kptjs = kptij_lst[:, 1]
    kpt_ji = kptjs - kptis
    uniq_kpts, uniq_index, uniq_inverse = unique(kpt_ji)
    log.debug('Num uniq kpts %d', len(uniq_kpts))
    log.debug2('uniq_kpts %s', uniq_kpts)
    # j2c ~ (-kpt_ji | kpt_ji)
    j2c = fused_cell.pbc_intor('int2c2e_sph', hermi=1, kpts=uniq_kpts)
    j2ctags = []
    nauxs = []
    t1 = log.timer_debug1('2c2e', *t1)

    if h5py.is_hdf5(cderi_file):
        # An explicit mode is required: h5py >= 3 opens read-only by default.
        feri = h5py.File(cderi_file, 'a')
    else:
        feri = h5py.File(cderi_file, 'w')
    # Stale metric factors from an earlier run would collide with the new
    # 'j2c/<k>' datasets.
    if 'j2c' in feri:
        del feri['j2c']
    for k, kpt in enumerate(uniq_kpts):
        aoaux = ft_ao.ft_ao(fused_cell, Gv, None, b, gxyz, Gvbase, kpt).T
        coulG = numpy.sqrt(mydf.weighted_coulG(kpt, False, gs))
        kLR = (aoaux.real * coulG).T
        kLI = (aoaux.imag * coulG).T
        if not kLR.flags.c_contiguous:
            kLR = lib.transpose(kLR.T)
        if not kLI.flags.c_contiguous:
            kLI = lib.transpose(kLI.T)
        aoaux = None

        kLR1 = numpy.asarray(kLR[:, naux:], order='C')
        kLI1 = numpy.asarray(kLI[:, naux:], order='C')
        if is_zero(kpt):  # kpti == kptj
            for p0, p1 in mydf.mpi_prange(0, ngs):
                j2cR = lib.ddot(kLR1[p0:p1].T, kLR[p0:p1])
                j2cR = lib.ddot(kLI1[p0:p1].T, kLI[p0:p1], 1, j2cR, 1)
                j2c[k][naux:] -= mpi.allreduce(j2cR)
                j2c[k][:naux, naux:] = j2c[k][naux:, :naux].T
        else:
            for p0, p1 in mydf.mpi_prange(0, ngs):
                j2cR, j2cI = zdotCN(kLR1[p0:p1].T, kLI1[p0:p1].T,
                                    kLR[p0:p1], kLI[p0:p1])
                j2cR = mpi.allreduce(j2cR)
                j2cI = mpi.allreduce(j2cI)
                j2c[k][naux:] -= j2cR + j2cI * 1j
                j2c[k][:naux, naux:] = j2c[k][naux:, :naux].T.conj()
        j2c[k] = fuse(fuse(j2c[k]).T).T
        try:
            feri['j2c/%d' % k] = scipy.linalg.cholesky(j2c[k], lower=True)
            j2ctags.append('CD')
            nauxs.append(naux)
        except scipy.linalg.LinAlgError:
            # The metric is not positive definite (gs may be too small): keep
            # only eigenvectors above the linear-dependency threshold.  This
            # previously decomposed the whole j2c list and logged an
            # undefined name, so this branch could never succeed.
            w, v = scipy.linalg.eigh(j2c[k])
            log.debug2('metric linear dependency for kpt %s', k)
            log.debug2('cond = %.4g, drop %d bfns', w[0] / w[-1],
                       numpy.count_nonzero(w < LINEAR_DEP_THR))
            v = v[:, w > LINEAR_DEP_THR].T.conj()
            v /= numpy.sqrt(w[w > LINEAR_DEP_THR]).reshape(-1, 1)
            feri['j2c/%d' % k] = v
            j2ctags.append('eig')
            nauxs.append(v.shape[0])
        kLR = kLI = kLR1 = kLI1 = coulG = None
    j2c = None

    aosym_s2 = numpy.einsum('ix->i', abs(kptis - kptjs)) < 1e-9
    j_only = numpy.all(aosym_s2)
    if gamma_point(kptij_lst):
        dtype = 'f8'
    else:
        dtype = 'c16'
    vbar = mydf.auxbar(fused_cell)
    vbar = fuse(vbar)
    ovlp = cell.pbc_intor('int1e_ovlp_sph', hermi=1, kpts=kptjs[aosym_s2])
    ovlp = [lib.pack_tril(s) for s in ovlp]
    t1 = log.timer_debug1('aoaux and int2c', *t1)

    # Estimates the buffer size based on the last contraction in G-space.
    # This contraction requires to hold nkptj copies of (naux,?) array
    # simultaneously in memory.
    mem_now = max(comm.allgather(lib.current_memory()[0]))
    max_memory = max(2000, mydf.max_memory - mem_now)
    nkptj_max = max((uniq_inverse == x).sum() for x in set(uniq_inverse))
    buflen = max(int(min(max_memory * .5e6 / 16 / naux / (nkptj_max + 2) / nao,
                         nao / 3 / mpi.pool.size)), 1)
    chunks = (buflen, nao)

    j3c_jobs = grids2d_int3c_jobs(cell, auxcell, kptij_lst, chunks, j_only)
    log.debug1('max_memory = %d MB (%d in use) chunks %s',
               max_memory, mem_now, chunks)
    log.debug2('j3c_jobs %s', j3c_jobs)

    if j_only:
        int3c = wrap_int3c(cell, fused_cell, 'int3c2e_sph', 's2', 1, kptij_lst)
    else:
        int3c = wrap_int3c(cell, fused_cell, 'int3c2e_sph', 's1', 1, kptij_lst)
    idxb = numpy.tril_indices(nao)
    idxb = (idxb[0] * nao + idxb[1]).astype('i')
    aux_loc = fused_cell.ao_loc_nr('ssc' in 'int3c2e_sph')

    def gen_int3c(auxcell, job_id, ish0, ish1):
        """Write analytic 3-center integrals of AO shells [ish0, ish1) to
        feri under 'j3c-chunks/<job_id>/<kptij index>'.  The ``auxcell``
        argument is not used."""
        dataname = 'j3c-chunks/%d' % job_id
        if dataname in feri:
            del feri[dataname]

        i0 = ao_loc[ish0]
        i1 = ao_loc[ish1]
        dii = i1 * (i1 + 1) // 2 - i0 * (i0 + 1) // 2
        dij = (i1 - i0) * nao
        if j_only:
            buflen = max(8, int(max_memory * 1e6 / 16 / (nkptij * dii + dii)))
        else:
            buflen = max(8, int(max_memory * 1e6 / 16 / (nkptij * dij + dij)))
        auxranges = balance_segs(aux_loc[1:] - aux_loc[:-1], buflen)
        buflen = max([x[2] for x in auxranges])
        buf = numpy.empty(nkptij * dij * buflen, dtype=dtype)
        buf1 = numpy.empty(dij * buflen, dtype=dtype)

        naux = aux_loc[-1]
        for kpt_id, kptij in enumerate(kptij_lst):
            key = '%s/%d' % (dataname, kpt_id)
            if aosym_s2[kpt_id]:
                shape = (naux, dii)
            else:
                shape = (naux, dij)
            if gamma_point(kptij):
                feri.create_dataset(key, shape, 'f8')
            else:
                feri.create_dataset(key, shape, 'c16')

        naux0 = 0
        for istep, auxrange in enumerate(auxranges):
            log.alldebug2("aux_e2 job_id %d step %d", job_id, istep)
            sh0, sh1, nrow = auxrange
            sub_slice = (ish0, ish1, 0, cell.nbas, sh0, sh1)
            if j_only:
                mat = numpy.ndarray((nkptij, dii, nrow), dtype=dtype, buffer=buf)
            else:
                mat = numpy.ndarray((nkptij, dij, nrow), dtype=dtype, buffer=buf)
            mat = int3c(sub_slice, mat)

            for k, kptij in enumerate(kptij_lst):
                h5dat = feri['%s/%d' % (dataname, k)]
                v = lib.transpose(mat[k], out=buf1)
                if not j_only and aosym_s2[k]:
                    # Keep only the lower-triangular (s2) AO pairs.
                    idy = idxb[i0 * (i0 + 1) // 2:i1 * (i1 + 1) // 2] - i0 * nao
                    out = numpy.ndarray((nrow, dii), dtype=v.dtype, buffer=mat[k])
                    v = numpy.take(v, idy, axis=1, out=out)
                if gamma_point(kptij):
                    h5dat[naux0:naux0 + nrow] = v.real
                else:
                    h5dat[naux0:naux0 + nrow] = v
            naux0 += nrow

    def ft_fuse(job_id, uniq_kptji_id, sh0, sh1):
        """Apply the G-space correction, fuse and the metric solve to the
        chunks of job ``job_id`` for unique k-point difference
        ``uniq_kptji_id``; results overwrite the chunk rows [:naux0]."""
        kpt = uniq_kpts[uniq_kptji_id]  # kpt = kptj - kpti
        adapted_ji_idx = numpy.where(uniq_inverse == uniq_kptji_id)[0]
        adapted_kptjs = kptjs[adapted_ji_idx]
        nkptj = len(adapted_kptjs)

        shls_slice = (auxcell.nbas, fused_cell.nbas)
        Gaux = ft_ao.ft_ao(fused_cell, Gv, shls_slice, b, gxyz, Gvbase, kpt)
        Gaux *= mydf.weighted_coulG(kpt, False, gs).reshape(-1, 1)
        kLR = Gaux.real.copy('C')
        kLI = Gaux.imag.copy('C')
        j2c = numpy.asarray(feri['j2c/%d' % uniq_kptji_id])
        j2ctag = j2ctags[uniq_kptji_id]
        naux0 = j2c.shape[0]

        if is_zero(kpt):
            aosym = 's2'
        else:
            aosym = 's1'

        j3cR = [None] * nkptj
        j3cI = [None] * nkptj
        i0 = ao_loc[sh0]
        i1 = ao_loc[sh1]
        for k, idx in enumerate(adapted_ji_idx):
            key = 'j3c-chunks/%d/%d' % (job_id, idx)
            v = numpy.asarray(feri[key])
            if is_zero(kpt):
                for i, c in enumerate(vbar):
                    if c != 0:
                        v[i] -= c * ovlp[k][i0 * (i0 + 1) // 2:i1 * (i1 + 1) // 2].ravel()
            j3cR[k] = numpy.asarray(v.real, order='C')
            if v.dtype == numpy.complex128:
                j3cI[k] = numpy.asarray(v.imag, order='C')
            v = None

        ncol = j3cR[0].shape[1]
        Gblksize = max(16, int(max_memory * 1e6 / 16 / ncol / (nkptj + 1)))  # +1 for pqkRbuf/pqkIbuf
        Gblksize = min(Gblksize, ngs, 16384)
        pqkRbuf = numpy.empty(ncol * Gblksize)
        pqkIbuf = numpy.empty(ncol * Gblksize)
        buf = numpy.empty(nkptj * ncol * Gblksize, dtype=numpy.complex128)
        log.alldebug2(' blksize (%d,%d)', Gblksize, ncol)

        shls_slice = (sh0, sh1, 0, cell.nbas)
        for p0, p1 in lib.prange(0, ngs, Gblksize):
            dat = ft_ao._ft_aopair_kpts(cell, Gv[p0:p1], shls_slice, aosym, b,
                                        gxyz[p0:p1], Gvbase, kpt,
                                        adapted_kptjs, out=buf)
            nG = p1 - p0
            for k, ji in enumerate(adapted_ji_idx):
                aoao = dat[k].reshape(nG, ncol)
                pqkR = numpy.ndarray((ncol, nG), buffer=pqkRbuf)
                pqkI = numpy.ndarray((ncol, nG), buffer=pqkIbuf)
                pqkR[:] = aoao.real.T
                pqkI[:] = aoao.imag.T

                lib.dot(kLR[p0:p1].T, pqkR.T, -1, j3cR[k][naux:], 1)
                lib.dot(kLI[p0:p1].T, pqkI.T, -1, j3cR[k][naux:], 1)
                if not (is_zero(kpt) and gamma_point(adapted_kptjs[k])):
                    lib.dot(kLR[p0:p1].T, pqkI.T, -1, j3cI[k][naux:], 1)
                    lib.dot(kLI[p0:p1].T, pqkR.T, 1, j3cI[k][naux:], 1)

        for k, idx in enumerate(adapted_ji_idx):
            if is_zero(kpt) and gamma_point(adapted_kptjs[k]):
                v = fuse(j3cR[k])
            else:
                v = fuse(j3cR[k] + j3cI[k] * 1j)
            if j2ctag == 'CD':
                v = scipy.linalg.solve_triangular(j2c, v, lower=True,
                                                  overwrite_b=True)
            else:
                v = lib.dot(j2c, v)
            feri['j3c-chunks/%d/%d' % (job_id, idx)][:naux0] = v

    t2 = t1
    j3c_workers = numpy.zeros(len(j3c_jobs), dtype=int)
    for job_id, ish0, ish1 in mpi.work_stealing_partition(j3c_jobs):
        gen_int3c(fused_cell, job_id, ish0, ish1)
        t2 = log.alltimer_debug2('int j3c %d' % job_id, *t2)

        for k, kpt in enumerate(uniq_kpts):
            ft_fuse(job_id, k, ish0, ish1)
            t2 = log.alltimer_debug2('ft-fuse %d k %d' % (job_id, k), *t2)

        j3c_workers[job_id] = rank
    # After the sum, every rank knows which rank computed each job.
    j3c_workers = mpi.allreduce(j3c_workers)
    log.debug2('j3c_workers %s', j3c_workers)
    j2c = ovlp = vbar = fuse = gen_int3c = ft_fuse = None
    t1 = log.timer_debug1('int3c and fuse', *t1)

    def get_segs_loc(aosym):
        """Column ranges of each job inside the rank-ordered concatenation
        of chunks, returned in job_id order."""
        off0 = numpy.asarray([ao_loc[i0] for x, i0, i1 in j3c_jobs])
        off1 = numpy.asarray([ao_loc[i1] for x, i0, i1 in j3c_jobs])
        if aosym:  # s2
            dims = off1 * (off1 + 1) // 2 - off0 * (off0 + 1) // 2
        else:
            dims = (off1 - off0) * nao
        dims = numpy.hstack(
            [dims[j3c_workers == w] for w in range(mpi.pool.size)])
        job_idx = numpy.hstack(
            [numpy.where(j3c_workers == w)[0] for w in range(mpi.pool.size)])
        segs_loc = numpy.append(0, numpy.cumsum(dims))
        segs_loc = [(segs_loc[j], segs_loc[j + 1]) for j in numpy.argsort(job_idx)]
        return segs_loc
    segs_loc_s1 = get_segs_loc(False)
    segs_loc_s2 = get_segs_loc(True)

    if 'j3c' in feri:
        del feri['j3c']
    segsize = (max(nauxs) + mpi.pool.size - 1) // mpi.pool.size
    naux0 = rank * segsize
    for k, kptij in enumerate(kptij_lst):
        naux1 = min(nauxs[uniq_inverse[k]], naux0 + segsize)
        nrow = max(0, naux1 - naux0)
        if gamma_point(kptij):
            dtype = 'f8'
        else:
            dtype = 'c16'
        if aosym_s2[k]:
            nao_pair = nao * (nao + 1) // 2
        else:
            nao_pair = nao * nao
        feri.create_dataset('j3c/%d' % k, (nrow, nao_pair), dtype,
                            maxshape=(None, nao_pair))

    def load(k, p0, p1):
        """Collect this rank's chunk rows destined for every rank."""
        naux1 = nauxs[uniq_inverse[k]]
        slices = [(min(i * segsize + p0, naux1), min(i * segsize + p1, naux1))
                  for i in range(mpi.pool.size)]
        segs = []
        for p0, p1 in slices:
            val = []
            for job_id, worker in enumerate(j3c_workers):
                if rank == worker:
                    key = 'j3c-chunks/%d/%d' % (job_id, k)
                    val.append(feri[key][p0:p1].ravel())
            if val:
                segs.append(numpy.hstack(val))
            else:
                segs.append(numpy.zeros(0))
        return segs

    def save(k, p0, p1, segs):
        """Exchange the segments and write this rank's rows of j3c/<k>."""
        segs = mpi.alltoall(segs)
        naux1 = nauxs[uniq_inverse[k]]
        loc0, loc1 = min(p0, naux1 - naux0), min(p1, naux1 - naux0)
        nL = loc1 - loc0
        if nL > 0:
            if aosym_s2[k]:
                segs = numpy.hstack([segs[i0 * nL:i1 * nL].reshape(nL, -1)
                                     for i0, i1 in segs_loc_s2])
            else:
                segs = numpy.hstack([segs[i0 * nL:i1 * nL].reshape(nL, -1)
                                     for i0, i1 in segs_loc_s1])
            feri['j3c/%d' % k][loc0:loc1] = segs

    mem_now = max(comm.allgather(lib.current_memory()[0]))
    max_memory = max(2000, min(8000, mydf.max_memory - mem_now))
    if numpy.all(aosym_s2):
        if gamma_point(kptij_lst):
            blksize = max(16, int(max_memory * .5e6 / 8 / nao**2))
        else:
            blksize = max(16, int(max_memory * .5e6 / 16 / nao**2))
    else:
        blksize = max(16, int(max_memory * .5e6 / 16 / nao**2 / 2))
    log.debug1('max_momory %d MB (%d in use), blksize %d',
               max_memory, mem_now, blksize)

    t2 = t1
    with lib.call_in_background(save) as async_write:
        for k, kptji in enumerate(kptij_lst):
            for p0, p1 in lib.prange(0, segsize, blksize):
                segs = load(k, p0, p1)
                async_write(k, p0, p1, segs)
                t2 = log.timer_debug1(
                    'assemble k=%d %d:%d (in %d)' % (k, p0, p1, segsize), *t2)

    if 'j3c-chunks' in feri:
        del feri['j3c-chunks']
    if 'j3c-kptij' in feri:
        del feri['j3c-kptij']
    feri['j3c-kptij'] = kptij_lst
    t1 = log.alltimer_debug1('assembling j3c', *t1)
    feri.close()
def _make_j3c(mydf, cell, auxcell, kptij_lst, cderi_file):
    """Build the 3-center density-fitting tensor for ``kptij_lst`` and write
    it to ``cderi_file`` via ``mpi_df._assemble``.

    In this variant the compensating functions are contracted with ``fuse``
    before any G-space contraction:
      - the analytic metric is fused first;
      - the G-space correction is accumulated from fused Fourier transforms
        on this rank's ``mydf.mpi_prange`` slices and summed with
        ``mpi.allreduce``.
    The metric is then factorised:
      - Cholesky when possible (tag ``'CD'``);
      - otherwise an eigen-decomposition that drops eigenvalues below
        ``mydf.linear_dep_threshold`` (tag ``'eig'``).  For 2D cells with
        ``low_dim_ft_type != 'inf_vacuum'``, the negative-eigenvalue part is
        stored as ``'j2c-/<k>'``.

    Per-job work is done by the ``gen_int3c`` and ``ft_fuse`` closures.
    Intermediate data go to a temporary HDF5 swap file created in the
    directory of ``cderi_file``.
    """
    log = logger.Logger(mydf.stdout, mydf.verbose)
    # time.clock() was removed in Python 3.8; fall back to it only on
    # interpreters that lack time.process_time() (Python 2).
    cpu_clock = getattr(time, 'process_time', None) or time.clock
    t1 = t0 = (cpu_clock(), time.time())

    fused_cell, fuse = fuse_auxcell(mydf, mydf.auxcell)
    ao_loc = cell.ao_loc_nr()
    nao = ao_loc[-1]
    naux = auxcell.nao_nr()
    nkptij = len(kptij_lst)
    mesh = mydf.mesh
    Gv, Gvbase, kws = cell.get_Gv_weights(mesh)
    b = cell.reciprocal_vectors()
    gxyz = lib.cartesian_prod([numpy.arange(len(x)) for x in Gvbase])
    ngrids = gxyz.shape[0]

    kptis = kptij_lst[:, 0]
    kptjs = kptij_lst[:, 1]
    kpt_ji = kptjs - kptis
    uniq_kpts, uniq_index, uniq_inverse = unique(kpt_ji)
    log.debug('Num uniq kpts %d', len(uniq_kpts))
    log.debug2('uniq_kpts %s', uniq_kpts)
    # j2c ~ (-kpt_ji | kpt_ji)
    j2c = fused_cell.pbc_intor('int2c2e', hermi=1, kpts=uniq_kpts)
    j2ctags = []
    t1 = log.timer_debug1('2c2e', *t1)

    swapfile = tempfile.NamedTemporaryFile(dir=os.path.dirname(cderi_file))
    fswap = lib.H5TmpFile(swapfile.name)
    # Unlink swapfile to avoid trash
    swapfile = None

    for k, kpt in enumerate(uniq_kpts):
        coulG = mydf.weighted_coulG(kpt, False, mesh)
        j2c[k] = fuse(fuse(j2c[k]).T).T.copy()
        # Partial G-space correction, summed over ranks below.
        j2c_k = numpy.zeros_like(j2c[k])
        for p0, p1 in mydf.mpi_prange(0, ngrids):
            aoaux = ft_ao.ft_ao(fused_cell, Gv[p0:p1], None, b, gxyz[p0:p1],
                                Gvbase, kpt).T
            aoaux = fuse(aoaux)
            LkR = numpy.asarray(aoaux.real, order='C')
            LkI = numpy.asarray(aoaux.imag, order='C')
            aoaux = None

            if is_zero(kpt):  # kpti == kptj
                j2cR = lib.dot(LkR * coulG[p0:p1], LkR.T)
                j2c_k += lib.dot(LkI * coulG[p0:p1], LkI.T, 1, j2cR, 1)
            else:
                # aoaux ~ kpt_ij, aoaux.conj() ~ kpt_kl
                j2cR, j2cI = zdotCN(LkR * coulG[p0:p1], LkI * coulG[p0:p1],
                                    LkR.T, LkI.T)
                j2c_k += j2cR + j2cI * 1j
            LkR = LkI = None
        j2c[k] -= mpi.allreduce(j2c_k)

        try:
            fswap['j2c/%d' % k] = scipy.linalg.cholesky(j2c[k], lower=True)
            j2ctags.append('CD')
        except scipy.linalg.LinAlgError:
            # The metric is not positive definite: keep only eigenvectors
            # above the linear-dependency threshold.
            w, v = scipy.linalg.eigh(j2c[k])
            log.debug2('metric linear dependency for kpt %s', k)
            log.debug2('cond = %.4g, drop %d bfns', w[0] / w[-1],
                       numpy.count_nonzero(w < mydf.linear_dep_threshold))
            v1 = v[:, w > mydf.linear_dep_threshold].T.conj()
            v1 /= numpy.sqrt(w[w > mydf.linear_dep_threshold]).reshape(-1, 1)
            fswap['j2c/%d' % k] = v1
            if cell.dimension == 2 and cell.low_dim_ft_type != 'inf_vacuum':
                idx = numpy.where(w < -mydf.linear_dep_threshold)[0]
                if len(idx) > 0:
                    fswap['j2c-/%d' % k] = (v[:, idx] / numpy.sqrt(-w[idx])).conj().T
            w = v = v1 = None
            j2ctags.append('eig')
    j2cR = j2cI = coulG = None
    j2c = None

    aosym_s2 = numpy.einsum('ix->i', abs(kptis - kptjs)) < 1e-9
    j_only = numpy.all(aosym_s2)
    if gamma_point(kptij_lst):
        dtype = 'f8'
    else:
        dtype = 'c16'
    t1 = log.timer_debug1('aoaux and int2c', *t1)

    # Estimates the buffer size based on the last contraction in G-space.
    # This contraction requires to hold nkptj copies of (naux,?) array
    # simultaneously in memory.
    mem_now = max(comm.allgather(lib.current_memory()[0]))
    max_memory = max(2000, mydf.max_memory - mem_now)
    nkptj_max = max((uniq_inverse == x).sum() for x in set(uniq_inverse))
    buflen = max(int(min(max_memory * .5e6 / 16 / naux / (nkptj_max + 2) / nao,
                         nao / 3 / mpi.pool.size)), 1)
    chunks = (buflen, nao)

    j3c_jobs = mpi_df.grids2d_int3c_jobs(cell, auxcell, kptij_lst, chunks, j_only)
    log.debug1('max_memory = %d MB (%d in use) chunks %s',
               max_memory, mem_now, chunks)
    log.debug2('j3c_jobs %s', j3c_jobs)

    if j_only:
        int3c = wrap_int3c(cell, fused_cell, 'int3c2e', 's2', 1, kptij_lst)
    else:
        int3c = wrap_int3c(cell, fused_cell, 'int3c2e', 's1', 1, kptij_lst)
    idxb = numpy.tril_indices(nao)
    idxb = (idxb[0] * nao + idxb[1]).astype('i')
    aux_loc = fused_cell.ao_loc_nr(fused_cell.cart)

    def gen_int3c(job_id, ish0, ish1):
        """Write analytic 3-center integrals of AO shells [ish0, ish1) to
        fswap under 'j3c-chunks/<job_id>/<kptij index>'."""
        dataname = 'j3c-chunks/%d' % job_id
        i0 = ao_loc[ish0]
        i1 = ao_loc[ish1]
        dii = i1 * (i1 + 1) // 2 - i0 * (i0 + 1) // 2
        dij = (i1 - i0) * nao
        if j_only:
            buflen = max(8, int(max_memory * 1e6 / 16 / (nkptij * dii + dii)))
        else:
            buflen = max(8, int(max_memory * 1e6 / 16 / (nkptij * dij + dij)))
        auxranges = balance_segs(aux_loc[1:] - aux_loc[:-1], buflen)
        buflen = max([x[2] for x in auxranges])
        buf = numpy.empty(nkptij * dij * buflen, dtype=dtype)
        buf1 = numpy.empty(dij * buflen, dtype=dtype)

        naux = aux_loc[-1]
        for kpt_id, kptij in enumerate(kptij_lst):
            key = '%s/%d' % (dataname, kpt_id)
            if aosym_s2[kpt_id]:
                shape = (naux, dii)
            else:
                shape = (naux, dij)
            if gamma_point(kptij):
                fswap.create_dataset(key, shape, 'f8')
            else:
                fswap.create_dataset(key, shape, 'c16')

        naux0 = 0
        for istep, auxrange in enumerate(auxranges):
            log.alldebug2("aux_e1 job_id %d step %d", job_id, istep)
            sh0, sh1, nrow = auxrange
            sub_slice = (ish0, ish1, 0, cell.nbas, sh0, sh1)
            if j_only:
                mat = numpy.ndarray((nkptij, dii, nrow), dtype=dtype, buffer=buf)
            else:
                mat = numpy.ndarray((nkptij, dij, nrow), dtype=dtype, buffer=buf)
            mat = int3c(sub_slice, mat)

            for k, kptij in enumerate(kptij_lst):
                h5dat = fswap['%s/%d' % (dataname, k)]
                v = lib.transpose(mat[k], out=buf1)
                if not j_only and aosym_s2[k]:
                    # Keep only the lower-triangular (s2) AO pairs.
                    idy = idxb[i0 * (i0 + 1) // 2:i1 * (i1 + 1) // 2] - i0 * nao
                    out = numpy.ndarray((nrow, dii), dtype=v.dtype, buffer=mat[k])
                    v = numpy.take(v, idy, axis=1, out=out)
                if gamma_point(kptij):
                    h5dat[naux0:naux0 + nrow] = v.real
                else:
                    h5dat[naux0:naux0 + nrow] = v
            naux0 += nrow

    def ft_fuse(job_id, uniq_kptji_id, sh0, sh1):
        """Fuse the chunks of job ``job_id`` for unique k-point difference
        ``uniq_kptji_id``, apply the G-space correction and the metric solve;
        results overwrite the chunk rows [:naux0]."""
        kpt = uniq_kpts[uniq_kptji_id]  # kpt = kptj - kpti
        adapted_ji_idx = numpy.where(uniq_inverse == uniq_kptji_id)[0]
        adapted_kptjs = kptjs[adapted_ji_idx]
        nkptj = len(adapted_kptjs)

        Gaux = ft_ao.ft_ao(fused_cell, Gv, None, b, gxyz, Gvbase, kpt).T
        Gaux = fuse(Gaux)
        Gaux *= mydf.weighted_coulG(kpt, False, mesh)
        kLR = lib.transpose(numpy.asarray(Gaux.real, order='C'))
        kLI = lib.transpose(numpy.asarray(Gaux.imag, order='C'))
        j2c = numpy.asarray(fswap['j2c/%d' % uniq_kptji_id])
        j2ctag = j2ctags[uniq_kptji_id]
        naux0 = j2c.shape[0]
        if ('j2c-/%d' % uniq_kptji_id) in fswap:
            j2c_negative = numpy.asarray(fswap['j2c-/%d' % uniq_kptji_id])
        else:
            j2c_negative = None

        if is_zero(kpt):
            aosym = 's2'
        else:
            aosym = 's1'

        if aosym == 's2' and cell.dimension == 3:
            vbar = fuse(mydf.auxbar(fused_cell))
            ovlp = cell.pbc_intor('int1e_ovlp', hermi=1, kpts=adapted_kptjs)
            ovlp = [lib.pack_tril(s) for s in ovlp]

        j3cR = [None] * nkptj
        j3cI = [None] * nkptj
        i0 = ao_loc[sh0]
        i1 = ao_loc[sh1]
        for k, idx in enumerate(adapted_ji_idx):
            key = 'j3c-chunks/%d/%d' % (job_id, idx)
            v = fuse(numpy.asarray(fswap[key]))
            if aosym == 's2' and cell.dimension == 3:
                for i in numpy.where(vbar != 0)[0]:
                    v[i] -= vbar[i] * ovlp[k][i0 * (i0 + 1) // 2:i1 * (i1 + 1) // 2].ravel()
            j3cR[k] = numpy.asarray(v.real, order='C')
            if v.dtype == numpy.complex128:
                j3cI[k] = numpy.asarray(v.imag, order='C')
            v = None

        ncol = j3cR[0].shape[1]
        Gblksize = max(16, int(max_memory * 1e6 / 16 / ncol / (nkptj + 1)))  # +1 for pqkRbuf/pqkIbuf
        Gblksize = min(Gblksize, ngrids, 16384)
        pqkRbuf = numpy.empty(ncol * Gblksize)
        pqkIbuf = numpy.empty(ncol * Gblksize)
        buf = numpy.empty(nkptj * ncol * Gblksize, dtype=numpy.complex128)
        log.alldebug2(' blksize (%d,%d)', Gblksize, ncol)

        if aosym == 's2':
            shls_slice = (sh0, sh1, 0, sh1)
        else:
            shls_slice = (sh0, sh1, 0, cell.nbas)
        for p0, p1 in lib.prange(0, ngrids, Gblksize):
            dat = ft_ao._ft_aopair_kpts(cell, Gv[p0:p1], shls_slice, aosym, b,
                                        gxyz[p0:p1], Gvbase, kpt,
                                        adapted_kptjs, out=buf)
            nG = p1 - p0
            for k, ji in enumerate(adapted_ji_idx):
                aoao = dat[k].reshape(nG, ncol)
                pqkR = numpy.ndarray((ncol, nG), buffer=pqkRbuf)
                pqkI = numpy.ndarray((ncol, nG), buffer=pqkIbuf)
                pqkR[:] = aoao.real.T
                pqkI[:] = aoao.imag.T

                lib.dot(kLR[p0:p1].T, pqkR.T, -1, j3cR[k], 1)
                lib.dot(kLI[p0:p1].T, pqkI.T, -1, j3cR[k], 1)
                if not (is_zero(kpt) and gamma_point(adapted_kptjs[k])):
                    lib.dot(kLR[p0:p1].T, pqkI.T, -1, j3cI[k], 1)
                    lib.dot(kLI[p0:p1].T, pqkR.T, 1, j3cI[k], 1)

        for k, idx in enumerate(adapted_ji_idx):
            if is_zero(kpt) and gamma_point(adapted_kptjs[k]):
                v = j3cR[k]
            else:
                v = j3cR[k] + j3cI[k] * 1j
            if j2ctag == 'CD':
                v = scipy.linalg.solve_triangular(j2c, v, lower=True,
                                                  overwrite_b=True)
                fswap['j3c-chunks/%d/%d' % (job_id, idx)][:naux0] = v
            else:
                fswap['j3c-chunks/%d/%d' % (job_id, idx)][:naux0] = lib.dot(j2c, v)

            # low-dimension systems
            if j2c_negative is not None:
                fswap['j3c-/%d/%d' % (job_id, idx)] = lib.dot(j2c_negative, v)

    mpi_df._assemble(mydf, kptij_lst, j3c_jobs, gen_int3c, ft_fuse, cderi_file,
                     fswap, log)
def update_amps(mycc, t1, t2, eris):
    '''Compute updated CCSD amplitudes with the doubles distributed over MPI ranks.

    Rank ``rank`` owns the slice ``vlocs[rank] = (vloc0, vloc1)`` of the first
    virtual index of the transposed doubles ``t2T``.  Intermediates that need
    other ranks' virtual slices are passed around all ranks with
    ``_rotate_tensor_block``; small private accumulators are summed with
    ``mpi.allreduce``.

    Args:
        mycc : CC object.  Reads ``stdout``, ``verbose``, ``level_shift``,
            ``max_memory``, ``direct`` and ``_add_vvvv``.
        t1 : (nocc, nvir) singles amplitudes; the full array is used on
            every rank.
        t2 : this rank's doubles block.  ``t2.transpose(2,3,0,1)`` must have
            shape (nvir_seg, nvir, nocc, nocc) with
            ``nvir_seg == vloc1 - vloc0``.
        eris : integral container.  Reads ``fock``, ``mo_energy`` and the
            ``ovvo``, ``vvvo``, ``ovoo``, ``oovv`` and ``oooo`` blocks.

    Returns:
        (t1new, t2new) in the same layouts as the inputs (t1, t2).

    Raises:
        NotImplementedError: when ``mycc.direct`` is false; it is raised inside
            the first vvvo pass.
    '''
    # time.clock() was removed in Python 3.8; process_time() is the CPU-time
    # replacement for the first element of the timing tuples used below.
    time1 = time0 = time.process_time(), time.time()
    log = logger.Logger(mycc.stdout, mycc.verbose)

    # Work in virtual-major layouts: t1T is (nvir, nocc) and t2T is
    # (nvir_seg, nvir, nocc, nocc).
    t1T = t1.T
    t2T = numpy.asarray(t2.transpose(2, 3, 0, 1), order='C')
    nvir_seg, nvir, nocc = t2T.shape[:3]
    t1 = t2 = None
    ntasks = mpi.pool.size
    vlocs = [_task_location(nvir, task_id) for task_id in range(ntasks)]
    vloc0, vloc1 = vlocs[rank]
    log.debug2('vlocs %s', vlocs)
    assert (vloc1 - vloc0 == nvir_seg)

    fock = eris.fock
    mo_e_o = eris.mo_energy[:nocc]
    mo_e_v = eris.mo_energy[nocc:] + mycc.level_shift

    def _rotate_vir_block(buf):
        # Yield each rank's block together with the virtual range it covers.
        for task_id, buf in _rotate_tensor_block(buf):
            loc0, loc1 = vlocs[task_id]
            yield task_id, buf, loc0, loc1

    fswap = lib.H5TmpFile()
    wVooV = numpy.zeros((nvir_seg, nocc, nocc, nvir))
    eris_voov = _cp(eris.ovvo).transpose(1, 0, 3, 2)
    tau = t2T * .5
    tau += numpy.einsum('ai,bj->abij', t1T[vloc0:vloc1], t1T)
    for task_id, tau, p0, p1 in _rotate_vir_block(tau):
        wVooV += lib.einsum('bkic,cajk->bija', eris_voov[:, :, :, p0:p1], tau)
    fswap['wVooV1'] = wVooV
    wVooV = tau = None
    time1 = log.timer_debug1('wVooV', *time1)

    wVOov = eris_voov
    eris_VOov = eris_voov - eris_voov.transpose(0, 2, 1, 3) * .5
    tau = t2T.transpose(2, 0, 3, 1) - t2T.transpose(3, 0, 2, 1) * .5
    tau -= numpy.einsum('ai,bj->jaib', t1T[vloc0:vloc1], t1T)
    for task_id, tau, p0, p1 in _rotate_vir_block(tau):
        wVOov += lib.einsum('dlkc,kcjb->dljb', eris_VOov[:, :, :, p0:p1], tau)
    fswap['wVOov1'] = wVOov
    wVOov = tau = eris_VOov = eris_voov = None
    time1 = log.timer_debug1('wVOov', *time1)

    t1Tnew = numpy.zeros_like(t1T)
    t2Tnew = mycc._add_vvvv(t1T, t2T, eris, t2sym='jiba')
    time1 = log.timer_debug1('vvvv', *time1)

    #** make_inter_F
    fov = fock[:nocc, nocc:].copy()
    t1Tnew += fock[nocc:, :nocc]

    foo = fock[:nocc, :nocc] - numpy.diag(mo_e_o)
    foo += .5 * numpy.einsum('ia,aj->ij', fock[:nocc, nocc:], t1T)

    fvv = fock[nocc:, nocc:] - numpy.diag(mo_e_v)
    fvv -= .5 * numpy.einsum('ai,ib->ab', t1T, fock[:nocc, nocc:])

    # Per-rank partial sums, reduced across ranks after the passes below.
    foo_priv = numpy.zeros_like(foo)
    fov_priv = numpy.zeros_like(fov)
    fvv_priv = numpy.zeros_like(fvv)
    t1T_priv = numpy.zeros_like(t1T)

    max_memory = mycc.max_memory - lib.current_memory()[0]
    unit = nocc * nvir**2 * 3 + nocc**2 * nvir + 1
    blksize = min(nvir, max(BLKMIN, int((max_memory * .9e6 / 8 - t2T.size) / unit)))
    log.debug1('pass 1, max_memory %d MB,  nocc,nvir = %d,%d  blksize = %d',
               max_memory, nocc, nvir, blksize)

    # Pass 1: stream eris.vvvo in blocks.  The next block is loaded into
    # `buf` by a background thread while the current block is processed.
    buf = numpy.empty((blksize, nvir, nvir, nocc))
    def load_vvvo(p0):
        p1 = min(nvir_seg, p0 + blksize)
        if p0 < p1:
            buf[:p1 - p0] = eris.vvvo[p0:p1]
    fswap.create_dataset('wVooV', (nvir_seg, nocc, nocc, nvir), 'f8')
    wVOov = []

    with lib.call_in_background(load_vvvo) as prefetch:
        load_vvvo(0)
        for p0, p1 in lib.prange(vloc0, vloc1, blksize):
            i0, i1 = p0 - vloc0, p1 - vloc0
            # Take the loaded block and give the prefetcher a fresh buffer.
            eris_vvvo, buf = buf[:p1 - p0], numpy.empty_like(buf)
            prefetch(i1)

            fvv_priv[p0:p1] += 2 * numpy.einsum('ck,abck->ab', t1T, eris_vvvo)
            fvv_priv -= numpy.einsum('ck,cabk->ab', t1T[p0:p1], eris_vvvo)

            if not mycc.direct:
                raise NotImplementedError
                # NOTE(review): the statements below are unreachable; they
                # appear to sketch the non-direct vvvo contribution.
                tau = t2T[i0:i1] + numpy.einsum('ai,bj->abij', t1T[p0:p1], t1T)
                for task_id, tau, q0, q1 in _rotate_vir_block(tau):
                    tmp = lib.einsum('bdck,cdij->bkij', eris_vvvo[:, :, q0:q1], tau)
                    t2Tnew -= lib.einsum('ak,bkij->baji', t1T, tmp)
                tau = tmp = None

            fswap['wVooV'][i0:i1] = lib.einsum('cj,baci->bija', -t1T, eris_vvvo)

            theta = t2T[i0:i1].transpose(0, 2, 1, 3) * 2
            theta -= t2T[i0:i1].transpose(0, 3, 1, 2)
            t1T_priv += lib.einsum('bicj,bacj->ai', theta, eris_vvvo)
            wVOov.append(lib.einsum('acbi,cj->abij', eris_vvvo, t1T))
            theta = eris_vvvo = None
            time1 = log.timer_debug1('vvvo [%d:%d]' % (p0, p1), *time1)

    # Redistribute the wVOov pieces so each rank holds its own virtual slice.
    wVOov = numpy.vstack(wVOov)
    wVOov = mpi.alltoall([wVOov[:, q0:q1] for q0, q1 in vlocs], split_recvbuf=True)
    wVOov = numpy.vstack([x.reshape(-1, nvir_seg, nocc, nocc) for x in wVOov])
    fswap['wVOov'] = wVOov.transpose(1, 2, 3, 0)
    wVooV = None

    unit = nocc**2 * nvir * 7 + nocc**3 + nocc * nvir**2
    max_memory = max(0, mycc.max_memory - lib.current_memory()[0])
    blksize = min(nvir, max(BLKMIN, int((max_memory * .9e6 / 8 - nocc**4) / unit)))
    log.debug1('pass 2, max_memory %d MB,  nocc,nvir = %d,%d  blksize = %d',
               max_memory, nocc, nvir, blksize)

    woooo = numpy.zeros((nocc, nocc, nocc, nocc))

    # Pass 2: ovoo/oovv/ovvo contributions, block by block over this rank's
    # virtual slice.  Each `with` scope overlaps one background load or store
    # with computation.  Names captured by the background closures are only
    # rebound after the `with` has joined the thread.
    for p0, p1 in lib.prange(vloc0, vloc1, blksize):
        i0, i1 = p0 - vloc0, p1 - vloc0
        wVOov = fswap['wVOov'][i0:i1]
        wVooV = fswap['wVooV'][i0:i1]
        eris_ovoo = eris.ovoo[:, i0:i1]
        eris_oovv = numpy.empty((nocc, nocc, i1 - i0, nvir))
        def load_oovv(p0, p1):
            eris_oovv[:] = eris.oovv[:, :, p0:p1]
        with lib.call_in_background(load_oovv) as prefetch_oovv:
            #:eris_oovv = eris.oovv[:,:,i0:i1]
            prefetch_oovv(i0, i1)
            foo_priv += numpy.einsum('ck,kcji->ij', 2 * t1T[p0:p1], eris_ovoo)
            foo_priv += numpy.einsum('ck,icjk->ij', -t1T[p0:p1], eris_ovoo)
            tmp = lib.einsum('al,jaik->lkji', t1T[p0:p1], eris_ovoo)
            woooo += tmp + tmp.transpose(1, 0, 3, 2)
            tmp = None

            wVOov -= lib.einsum('jbik,ak->bjia', eris_ovoo, t1T)
            t2Tnew[i0:i1] += wVOov.transpose(0, 3, 1, 2)

            wVooV += lib.einsum('kbij,ak->bija', eris_ovoo, t1T)
            eris_ovoo = None
        load_oovv = prefetch_oovv = None

        eris_ovvo = numpy.empty((nocc, i1 - i0, nvir, nocc))
        def load_ovvo(p0, p1):
            eris_ovvo[:] = eris.ovvo[:, p0:p1]
        with lib.call_in_background(load_ovvo) as prefetch_ovvo:
            #:eris_ovvo = eris.ovvo[:,i0:i1]
            prefetch_ovvo(i0, i1)
            t1T_priv[p0:p1] -= numpy.einsum('bj,jiab->ai', t1T, eris_oovv)
            wVooV -= eris_oovv.transpose(2, 0, 1, 3)
            wVOov += wVooV * .5  #: bjia + bija*.5
        eris_voov = eris_ovvo.transpose(1, 0, 3, 2)
        eris_ovvo = None
        load_ovvo = prefetch_ovvo = None

        # Accumulate this block into the wVooV1/wVOov1 scratch datasets in
        # the background; the main thread below does not touch wVooV/wVOov.
        def update_wVooV(i0, i1):
            wVooV[:] += fswap['wVooV1'][i0:i1]
            fswap['wVooV1'][i0:i1] = wVooV
            wVOov[:] += fswap['wVOov1'][i0:i1]
            fswap['wVOov1'][i0:i1] = wVOov
        with lib.call_in_background(update_wVooV) as update_wVooV:
            update_wVooV(i0, i1)
            t2Tnew[i0:i1] += eris_voov.transpose(0, 3, 1, 2) * .5
            t1T_priv[p0:p1] += 2 * numpy.einsum('bj,aijb->ai', t1T, eris_voov)

            tmp = lib.einsum('ci,kjbc->bijk', t1T, eris_oovv)
            tmp += lib.einsum('bjkc,ci->bjik', eris_voov, t1T)
            t2Tnew[i0:i1] -= lib.einsum('bjik,ak->baji', tmp, t1T)
            eris_oovv = tmp = None

            fov_priv[:, p0:p1] += numpy.einsum('ck,aikc->ia', t1T, eris_voov) * 2
            fov_priv[:, p0:p1] -= numpy.einsum('ck,akic->ia', t1T, eris_voov)

            tau = numpy.einsum('ai,bj->abij', t1T[p0:p1] * .5, t1T)
            tau += t2T[i0:i1]
            theta = tau.transpose(0, 1, 3, 2) * 2
            theta -= tau
            fvv_priv -= lib.einsum('caij,cjib->ab', theta, eris_voov)
            foo_priv += lib.einsum('aikb,abkj->ij', eris_voov, theta)
            tau = theta = None

            tau = t2T[i0:i1] + numpy.einsum('ai,bj->abij', t1T[p0:p1], t1T)
            woooo += lib.einsum('abij,aklb->ijkl', tau, eris_voov)
            tau = None
        eris_VOov = wVOov = wVooV = update_wVooV = None
        time1 = log.timer_debug1('voov [%d:%d]' % (p0, p1), *time1)

    wVooV = _cp(fswap['wVooV1'])
    for task_id, wVooV, p0, p1 in _rotate_vir_block(wVooV):
        tmp = lib.einsum('ackj,ckib->ajbi', t2T[:, p0:p1], wVooV)
        t2Tnew += tmp.transpose(0, 2, 3, 1)
        t2Tnew += tmp.transpose(0, 2, 1, 3) * .5
    wVooV = tmp = None
    time1 = log.timer_debug1('contracting wVooV', *time1)

    wVOov = _cp(fswap['wVOov1'])
    theta = t2T * 2
    theta -= t2T.transpose(0, 1, 3, 2)
    for task_id, wVOov, p0, p1 in _rotate_vir_block(wVOov):
        t2Tnew += lib.einsum('acik,ckjb->abij', theta[:, p0:p1], wVOov)
    wVOov = theta = None
    fswap = None
    time1 = log.timer_debug1('contracting wVOov', *time1)

    foo += mpi.allreduce(foo_priv)
    fov += mpi.allreduce(fov_priv)
    fvv += mpi.allreduce(fvv_priv)

    theta = t2T.transpose(0, 1, 3, 2) * 2 - t2T
    t1T_priv[vloc0:vloc1] += numpy.einsum('jb,abji->ai', fov, theta)
    ovoo = _cp(eris.ovoo)
    for task_id, ovoo, p0, p1 in _rotate_vir_block(ovoo):
        t1T_priv[vloc0:vloc1] -= lib.einsum('jbki,abjk->ai', ovoo, theta[:, p0:p1])
    theta = ovoo = None

    woooo = mpi.allreduce(woooo)
    woooo += _cp(eris.oooo).transpose(0, 2, 1, 3)
    tau = t2T + numpy.einsum('ai,bj->abij', t1T[vloc0:vloc1], t1T)
    t2Tnew += .5 * lib.einsum('abkl,ijkl->abij', tau, woooo)
    tau = woooo = None

    t1Tnew += mpi.allreduce(t1T_priv)

    ft_ij = foo + numpy.einsum('aj,ia->ij', .5 * t1T, fov)
    ft_ab = fvv - numpy.einsum('ai,ib->ab', .5 * t1T, fov)
    t2Tnew += lib.einsum('acij,bc->abij', t2T, ft_ab)
    t2Tnew -= lib.einsum('ki,abkj->abij', ft_ij, t2T)

    # Orbital-energy denominators e_i - e_a (virtual energies level-shifted).
    eia = mo_e_o[:, None] - mo_e_v
    t1Tnew += numpy.einsum('bi,ab->ai', t1T, fvv)
    t1Tnew -= numpy.einsum('aj,ji->ai', t1T, foo)
    t1Tnew /= eia.T

    # Add the (a,b,i,j) -> (b,a,j,i) permuted blocks received from every rank.
    t2tmp = mpi.alltoall([t2Tnew[:, p0:p1] for p0, p1 in vlocs],
                         split_recvbuf=True)
    for task_id, (p0, p1) in enumerate(vlocs):
        tmp = t2tmp[task_id].reshape(p1 - p0, nvir_seg, nocc, nocc)
        t2Tnew[:, p0:p1] += tmp.transpose(1, 0, 3, 2)

    for i in range(vloc0, vloc1):
        t2Tnew[i - vloc0] /= lib.direct_sum('i+jb->bij', eia[:, i], eia)

    time0 = log.timer_debug1('update t1 t2', *time0)
    return t1Tnew.T, t2Tnew.transpose(2, 3, 0, 1)
def update_amps(mycc, t1, t2, eris):
    '''Compute updated CCSD amplitudes with the doubles distributed over MPI ranks.

    Rank ``rank`` owns the slice ``vlocs[rank] = (vloc0, vloc1)`` of the first
    virtual index of the transposed doubles ``t2T``.  Intermediates that need
    other ranks' virtual slices are passed around all ranks with
    ``_rotate_tensor_block``; small private accumulators are summed with
    ``mpi.allreduce``.

    Args:
        mycc : CC object.  Reads ``stdout``, ``verbose``, ``level_shift``,
            ``max_memory``, ``direct`` and ``_add_vvvv``.
        t1 : (nocc, nvir) singles amplitudes; the full array is used on
            every rank.
        t2 : this rank's doubles block.  ``t2.transpose(2,3,0,1)`` must have
            shape (nvir_seg, nvir, nocc, nocc) with
            ``nvir_seg == vloc1 - vloc0``.
        eris : integral container.  Reads ``fock``, ``mo_energy`` and the
            ``ovvo``, ``vvvo``, ``ovoo``, ``oovv`` and ``oooo`` blocks.

    Returns:
        (t1new, t2new) in the same layouts as the inputs (t1, t2).

    Raises:
        NotImplementedError: when ``mycc.direct`` is false; it is raised inside
            the first vvvo pass.
    '''
    # time.clock() was removed in Python 3.8; process_time() is the CPU-time
    # replacement for the first element of the timing tuples used below.
    time1 = time0 = time.process_time(), time.time()
    log = logger.Logger(mycc.stdout, mycc.verbose)

    # Work in virtual-major layouts: t1T is (nvir, nocc) and t2T is
    # (nvir_seg, nvir, nocc, nocc).
    t1T = t1.T
    t2T = numpy.asarray(t2.transpose(2,3,0,1), order='C')
    nvir_seg, nvir, nocc = t2T.shape[:3]
    t1 = t2 = None
    ntasks = mpi.pool.size
    vlocs = [_task_location(nvir, task_id) for task_id in range(ntasks)]
    vloc0, vloc1 = vlocs[rank]
    log.debug2('vlocs %s', vlocs)
    assert(vloc1-vloc0 == nvir_seg)

    fock = eris.fock
    mo_e_o = eris.mo_energy[:nocc]
    mo_e_v = eris.mo_energy[nocc:] + mycc.level_shift

    def _rotate_vir_block(buf):
        # Yield each rank's block together with the virtual range it covers.
        for task_id, buf in _rotate_tensor_block(buf):
            loc0, loc1 = vlocs[task_id]
            yield task_id, buf, loc0, loc1

    fswap = lib.H5TmpFile()
    wVooV = numpy.zeros((nvir_seg,nocc,nocc,nvir))
    eris_voov = _cp(eris.ovvo).transpose(1,0,3,2)
    tau = t2T * .5
    tau += numpy.einsum('ai,bj->abij', t1T[vloc0:vloc1], t1T)
    for task_id, tau, p0, p1 in _rotate_vir_block(tau):
        wVooV += lib.einsum('bkic,cajk->bija', eris_voov[:,:,:,p0:p1], tau)
    fswap['wVooV1'] = wVooV
    wVooV = tau = None
    time1 = log.timer_debug1('wVooV', *time1)

    wVOov = eris_voov
    eris_VOov = eris_voov - eris_voov.transpose(0,2,1,3)*.5
    tau = t2T.transpose(2,0,3,1) - t2T.transpose(3,0,2,1)*.5
    tau -= numpy.einsum('ai,bj->jaib', t1T[vloc0:vloc1], t1T)
    for task_id, tau, p0, p1 in _rotate_vir_block(tau):
        wVOov += lib.einsum('dlkc,kcjb->dljb', eris_VOov[:,:,:,p0:p1], tau)
    fswap['wVOov1'] = wVOov
    wVOov = tau = eris_VOov = eris_voov = None
    time1 = log.timer_debug1('wVOov', *time1)

    t1Tnew = numpy.zeros_like(t1T)
    t2Tnew = mycc._add_vvvv(t1T, t2T, eris, t2sym='jiba')
    time1 = log.timer_debug1('vvvv', *time1)

    #** make_inter_F
    fov = fock[:nocc,nocc:].copy()
    t1Tnew += fock[nocc:,:nocc]

    foo = fock[:nocc,:nocc] - numpy.diag(mo_e_o)
    foo += .5 * numpy.einsum('ia,aj->ij', fock[:nocc,nocc:], t1T)

    fvv = fock[nocc:,nocc:] - numpy.diag(mo_e_v)
    fvv -= .5 * numpy.einsum('ai,ib->ab', t1T, fock[:nocc,nocc:])

    # Per-rank partial sums, reduced across ranks after the passes below.
    foo_priv = numpy.zeros_like(foo)
    fov_priv = numpy.zeros_like(fov)
    fvv_priv = numpy.zeros_like(fvv)
    t1T_priv = numpy.zeros_like(t1T)

    max_memory = mycc.max_memory - lib.current_memory()[0]
    unit = nocc*nvir**2*3 + nocc**2*nvir + 1
    blksize = min(nvir, max(BLKMIN, int((max_memory*.9e6/8-t2T.size)/unit)))
    log.debug1('pass 1, max_memory %d MB,  nocc,nvir = %d,%d  blksize = %d',
               max_memory, nocc, nvir, blksize)

    # Pass 1: stream eris.vvvo in blocks.  The next block is loaded into
    # `buf` by a background thread while the current block is processed.
    buf = numpy.empty((blksize,nvir,nvir,nocc))
    def load_vvvo(p0):
        p1 = min(nvir_seg, p0+blksize)
        if p0 < p1:
            buf[:p1-p0] = eris.vvvo[p0:p1]
    fswap.create_dataset('wVooV', (nvir_seg,nocc,nocc,nvir), 'f8')
    wVOov = []

    with lib.call_in_background(load_vvvo) as prefetch:
        load_vvvo(0)
        for p0, p1 in lib.prange(vloc0, vloc1, blksize):
            i0, i1 = p0 - vloc0, p1 - vloc0
            # Take the loaded block and give the prefetcher a fresh buffer.
            eris_vvvo, buf = buf[:p1-p0], numpy.empty_like(buf)
            prefetch(i1)

            fvv_priv[p0:p1] += 2*numpy.einsum('ck,abck->ab', t1T, eris_vvvo)
            fvv_priv -= numpy.einsum('ck,cabk->ab', t1T[p0:p1], eris_vvvo)

            if not mycc.direct:
                raise NotImplementedError
                # NOTE(review): the statements below are unreachable; they
                # appear to sketch the non-direct vvvo contribution.
                tau = t2T[i0:i1] + numpy.einsum('ai,bj->abij', t1T[p0:p1], t1T)
                for task_id, tau, q0, q1 in _rotate_vir_block(tau):
                    tmp = lib.einsum('bdck,cdij->bkij', eris_vvvo[:,:,q0:q1], tau)
                    t2Tnew -= lib.einsum('ak,bkij->baji', t1T, tmp)
                tau = tmp = None

            fswap['wVooV'][i0:i1] = lib.einsum('cj,baci->bija', -t1T, eris_vvvo)

            theta = t2T[i0:i1].transpose(0,2,1,3) * 2
            theta -= t2T[i0:i1].transpose(0,3,1,2)
            t1T_priv += lib.einsum('bicj,bacj->ai', theta, eris_vvvo)
            wVOov.append(lib.einsum('acbi,cj->abij', eris_vvvo, t1T))
            theta = eris_vvvo = None
            time1 = log.timer_debug1('vvvo [%d:%d]'%(p0, p1), *time1)

    # Redistribute the wVOov pieces so each rank holds its own virtual slice.
    wVOov = numpy.vstack(wVOov)
    wVOov = mpi.alltoall([wVOov[:,q0:q1] for q0,q1 in vlocs], split_recvbuf=True)
    wVOov = numpy.vstack([x.reshape(-1,nvir_seg,nocc,nocc) for x in wVOov])
    fswap['wVOov'] = wVOov.transpose(1,2,3,0)
    wVooV = None

    unit = nocc**2*nvir*7 + nocc**3 + nocc*nvir**2
    max_memory = max(0, mycc.max_memory - lib.current_memory()[0])
    blksize = min(nvir, max(BLKMIN, int((max_memory*.9e6/8-nocc**4)/unit)))
    log.debug1('pass 2, max_memory %d MB,  nocc,nvir = %d,%d  blksize = %d',
               max_memory, nocc, nvir, blksize)

    woooo = numpy.zeros((nocc,nocc,nocc,nocc))

    # Pass 2: ovoo/oovv/ovvo contributions, block by block over this rank's
    # virtual slice.  Each `with` scope overlaps one background load or store
    # with computation.  Names captured by the background closures are only
    # rebound after the `with` has joined the thread.
    for p0, p1 in lib.prange(vloc0, vloc1, blksize):
        i0, i1 = p0 - vloc0, p1 - vloc0
        wVOov = fswap['wVOov'][i0:i1]
        wVooV = fswap['wVooV'][i0:i1]
        eris_ovoo = eris.ovoo[:,i0:i1]
        eris_oovv = numpy.empty((nocc,nocc,i1-i0,nvir))
        def load_oovv(p0, p1):
            eris_oovv[:] = eris.oovv[:,:,p0:p1]
        with lib.call_in_background(load_oovv) as prefetch_oovv:
            #:eris_oovv = eris.oovv[:,:,i0:i1]
            prefetch_oovv(i0, i1)
            foo_priv += numpy.einsum('ck,kcji->ij', 2*t1T[p0:p1], eris_ovoo)
            foo_priv += numpy.einsum('ck,icjk->ij',  -t1T[p0:p1], eris_ovoo)
            tmp = lib.einsum('al,jaik->lkji', t1T[p0:p1], eris_ovoo)
            woooo += tmp + tmp.transpose(1,0,3,2)
            tmp = None

            wVOov -= lib.einsum('jbik,ak->bjia', eris_ovoo, t1T)
            t2Tnew[i0:i1] += wVOov.transpose(0,3,1,2)

            wVooV += lib.einsum('kbij,ak->bija', eris_ovoo, t1T)
            eris_ovoo = None
        load_oovv = prefetch_oovv = None

        eris_ovvo = numpy.empty((nocc,i1-i0,nvir,nocc))
        def load_ovvo(p0, p1):
            eris_ovvo[:] = eris.ovvo[:,p0:p1]
        with lib.call_in_background(load_ovvo) as prefetch_ovvo:
            #:eris_ovvo = eris.ovvo[:,i0:i1]
            prefetch_ovvo(i0, i1)
            t1T_priv[p0:p1] -= numpy.einsum('bj,jiab->ai', t1T, eris_oovv)
            wVooV -= eris_oovv.transpose(2,0,1,3)
            wVOov += wVooV*.5  #: bjia + bija*.5
        eris_voov = eris_ovvo.transpose(1,0,3,2)
        eris_ovvo = None
        load_ovvo = prefetch_ovvo = None

        # Accumulate this block into the wVooV1/wVOov1 scratch datasets in
        # the background; the main thread below does not touch wVooV/wVOov.
        def update_wVooV(i0, i1):
            wVooV[:] += fswap['wVooV1'][i0:i1]
            fswap['wVooV1'][i0:i1] = wVooV
            wVOov[:] += fswap['wVOov1'][i0:i1]
            fswap['wVOov1'][i0:i1] = wVOov
        with lib.call_in_background(update_wVooV) as update_wVooV:
            update_wVooV(i0, i1)
            t2Tnew[i0:i1] += eris_voov.transpose(0,3,1,2) * .5
            t1T_priv[p0:p1] += 2*numpy.einsum('bj,aijb->ai', t1T, eris_voov)

            tmp  = lib.einsum('ci,kjbc->bijk', t1T, eris_oovv)
            tmp += lib.einsum('bjkc,ci->bjik', eris_voov, t1T)
            t2Tnew[i0:i1] -= lib.einsum('bjik,ak->baji', tmp, t1T)
            eris_oovv = tmp = None

            fov_priv[:,p0:p1] += numpy.einsum('ck,aikc->ia', t1T, eris_voov) * 2
            fov_priv[:,p0:p1] -= numpy.einsum('ck,akic->ia', t1T, eris_voov)

            tau = numpy.einsum('ai,bj->abij', t1T[p0:p1]*.5, t1T)
            tau += t2T[i0:i1]
            theta = tau.transpose(0,1,3,2) * 2
            theta -= tau
            fvv_priv -= lib.einsum('caij,cjib->ab', theta, eris_voov)
            foo_priv += lib.einsum('aikb,abkj->ij', eris_voov, theta)
            tau = theta = None

            tau = t2T[i0:i1] + numpy.einsum('ai,bj->abij', t1T[p0:p1], t1T)
            woooo += lib.einsum('abij,aklb->ijkl', tau, eris_voov)
            tau = None
        eris_VOov = wVOov = wVooV = update_wVooV = None
        time1 = log.timer_debug1('voov [%d:%d]'%(p0, p1), *time1)

    wVooV = _cp(fswap['wVooV1'])
    for task_id, wVooV, p0, p1 in _rotate_vir_block(wVooV):
        tmp = lib.einsum('ackj,ckib->ajbi', t2T[:,p0:p1], wVooV)
        t2Tnew += tmp.transpose(0,2,3,1)
        t2Tnew += tmp.transpose(0,2,1,3) * .5
    wVooV = tmp = None
    time1 = log.timer_debug1('contracting wVooV', *time1)

    wVOov = _cp(fswap['wVOov1'])
    theta = t2T * 2
    theta -= t2T.transpose(0,1,3,2)
    for task_id, wVOov, p0, p1 in _rotate_vir_block(wVOov):
        t2Tnew += lib.einsum('acik,ckjb->abij', theta[:,p0:p1], wVOov)
    wVOov = theta = None
    fswap = None
    time1 = log.timer_debug1('contracting wVOov', *time1)

    foo += mpi.allreduce(foo_priv)
    fov += mpi.allreduce(fov_priv)
    fvv += mpi.allreduce(fvv_priv)

    theta = t2T.transpose(0,1,3,2) * 2 - t2T
    t1T_priv[vloc0:vloc1] += numpy.einsum('jb,abji->ai', fov, theta)
    ovoo = _cp(eris.ovoo)
    for task_id, ovoo, p0, p1 in _rotate_vir_block(ovoo):
        t1T_priv[vloc0:vloc1] -= lib.einsum('jbki,abjk->ai', ovoo, theta[:,p0:p1])
    theta = ovoo = None

    woooo = mpi.allreduce(woooo)
    woooo += _cp(eris.oooo).transpose(0,2,1,3)
    tau = t2T + numpy.einsum('ai,bj->abij', t1T[vloc0:vloc1], t1T)
    t2Tnew += .5 * lib.einsum('abkl,ijkl->abij', tau, woooo)
    tau = woooo = None

    t1Tnew += mpi.allreduce(t1T_priv)

    ft_ij = foo + numpy.einsum('aj,ia->ij', .5*t1T, fov)
    ft_ab = fvv - numpy.einsum('ai,ib->ab', .5*t1T, fov)
    t2Tnew += lib.einsum('acij,bc->abij', t2T, ft_ab)
    t2Tnew -= lib.einsum('ki,abkj->abij', ft_ij, t2T)

    # Orbital-energy denominators e_i - e_a (virtual energies level-shifted).
    eia = mo_e_o[:,None] - mo_e_v
    t1Tnew += numpy.einsum('bi,ab->ai', t1T, fvv)
    t1Tnew -= numpy.einsum('aj,ji->ai', t1T, foo)
    t1Tnew /= eia.T

    # Add the (a,b,i,j) -> (b,a,j,i) permuted blocks received from every rank.
    t2tmp = mpi.alltoall([t2Tnew[:,p0:p1] for p0,p1 in vlocs],
                         split_recvbuf=True)
    for task_id, (p0, p1) in enumerate(vlocs):
        tmp = t2tmp[task_id].reshape(p1-p0,nvir_seg,nocc,nocc)
        t2Tnew[:,p0:p1] += tmp.transpose(1,0,3,2)

    for i in range(vloc0, vloc1):
        t2Tnew[i-vloc0] /= lib.direct_sum('i+jb->bij', eia[:,i], eia)

    time0 = log.timer_debug1('update t1 t2', *time0)
    return t1Tnew.T, t2Tnew.transpose(2,3,0,1)
def _assemble(mydf, kptij_lst, j3c_jobs, gen_int3c, ft_fuse, cderi_file, fswap, log):
    '''Run the distributed 3-center integral jobs and write the results to cderi_file.

    Pass 1: jobs from ``j3c_jobs`` (tuples ``(job_id, ish0, ish1)``) are
    handed out with ``mpi.work_stealing_partition``.  For each job,
    ``gen_int3c`` is called once and ``ft_fuse`` once per unique k-point
    difference.  The rank that ran the job is recorded in ``j3c_workers``.

    Pass 2: the per-job chunks ``fswap['j3c-chunks/<job>/<k>']`` are exchanged
    between ranks with ``mpi.alltoall``.  Each rank then writes auxiliary rows
    ``[rank*segsize, (rank+1)*segsize)`` of every pair ``k`` to
    ``feri['j3c/<k>']``.  When ``fswap`` contains ``'j2c-'`` data, the
    matching ``'j3c-'`` chunks are gathered on rank 0 and written to
    ``feri['j3c-/<k>']``.  Finally ``kptij_lst`` is stored as
    ``feri['j3c-kptij']``.

    Args:
        mydf : DF object; reads ``cell`` and ``max_memory``.
        kptij_lst : array of (kpti, kptj) pairs, indexed as
            ``kptij_lst[:,0]`` / ``kptij_lst[:,1]``.
        j3c_jobs : list of ``(job_id, ish0, ish1)`` shell ranges.  The
            ``argsort(job_idx)`` below presumes job_id equals the list
            position; NOTE(review): confirm with the caller.
        gen_int3c, ft_fuse : per-job callbacks that fill ``fswap``.
        cderi_file : path of the HDF5 file that is (re)created here.
        fswap : HDF5 scratch file/group holding the intermediate chunks.
        log : logger used for timings and debug output.
    '''
    # time.clock() was removed in Python 3.8; process_time() is its
    # CPU-time replacement.
    t1 = (time.process_time(), time.time())
    cell = mydf.cell
    ao_loc = cell.ao_loc_nr()
    nao = ao_loc[-1]
    kptis = kptij_lst[:,0]
    kptjs = kptij_lst[:,1]
    kpt_ji = kptjs - kptis
    uniq_kpts, uniq_index, uniq_inverse = unique(kpt_ji)

    # Pairs with kpti == kptj use triangular (s2) AO-pair storage.
    aosym_s2 = numpy.einsum('ix->i', abs(kptis-kptjs)) < 1e-9
    t2 = t1
    j3c_workers = numpy.zeros(len(j3c_jobs), dtype=int)
    #for job_id, ish0, ish1 in mpi.work_share_partition(j3c_jobs):
    for job_id, ish0, ish1 in mpi.work_stealing_partition(j3c_jobs):
        gen_int3c(job_id, ish0, ish1)
        t2 = log.alltimer_debug2('int j3c %d' % job_id, *t2)

        for k, kpt in enumerate(uniq_kpts):
            ft_fuse(job_id, k, ish0, ish1)
            t2 = log.alltimer_debug2('ft-fuse %d k %d' % (job_id, k), *t2)

        j3c_workers[job_id] = rank
    # Each job is run by exactly one rank and the rest hold 0, so the sum
    # gives the owner rank of every job on all ranks.
    j3c_workers = mpi.allreduce(j3c_workers)
    log.debug2('j3c_workers %s', j3c_workers)
    t1 = log.timer_debug1('int3c and fuse', *t1)

    # Pass 2
    # Transpose 3-index tensor and save data in cderi_file
    feri = h5py.File(cderi_file, 'w')
    nauxs = [fswap['j2c/%d'%k].shape[0] for k, kpt in enumerate(uniq_kpts)]
    segsize = (max(nauxs)+mpi.pool.size-1) // mpi.pool.size
    naux0 = rank * segsize
    for k, kptij in enumerate(kptij_lst):
        naux1 = min(nauxs[uniq_inverse[k]], naux0+segsize)
        nrow = max(0, naux1-naux0)
        if gamma_point(kptij):
            dtype = 'f8'
        else:
            dtype = 'c16'
        if aosym_s2[k]:
            nao_pair = nao * (nao+1) // 2
        else:
            nao_pair = nao * nao
        feri.create_dataset('j3c/%d'%k, (nrow,nao_pair), dtype, maxshape=(None,nao_pair))

    def get_segs_loc(aosym):
        # Column range (start, stop) of each job inside the buffer produced by
        # concatenating all ranks' chunks in rank order, reordered by job index.
        off0 = numpy.asarray([ao_loc[i0] for x,i0,i1 in j3c_jobs])
        off1 = numpy.asarray([ao_loc[i1] for x,i0,i1 in j3c_jobs])
        if aosym:  # s2
            dims = off1*(off1+1)//2 - off0*(off0+1)//2
        else:
            dims = (off1-off0) * nao
        #dims = numpy.asarray([ao_loc[i1]-ao_loc[i0] for x,i0,i1 in j3c_jobs])
        dims = numpy.hstack([dims[j3c_workers==w] for w in range(mpi.pool.size)])
        job_idx = numpy.hstack([numpy.where(j3c_workers==w)[0]
                                for w in range(mpi.pool.size)])
        segs_loc = numpy.append(0, numpy.cumsum(dims))
        segs_loc = [(segs_loc[j], segs_loc[j+1]) for j in numpy.argsort(job_idx)]
        return segs_loc
    segs_loc_s1 = get_segs_loc(False)
    segs_loc_s2 = get_segs_loc(True)

    job_ids = numpy.where(rank == j3c_workers)[0]
    def load(k, p0, p1):
        # For every destination rank i, flatten rows [i*segsize+p0, i*segsize+p1)
        # of all chunks this rank computed for pair k.
        naux1 = nauxs[uniq_inverse[k]]
        slices = [(min(i*segsize+p0,naux1), min(i*segsize+p1,naux1))
                  for i in range(mpi.pool.size)]
        segs = []
        for p0, p1 in slices:
            val = [fswap['j3c-chunks/%d/%d' % (job, k)][p0:p1].ravel()
                   for job in job_ids]
            if val:
                segs.append(numpy.hstack(val))
            else:
                segs.append(numpy.zeros(0))
        return segs

    def save(k, p0, p1, segs):
        # Exchange the segments, reassemble the rows in job order and write
        # them into this rank's part of feri['j3c/<k>'].
        segs = mpi.alltoall(segs)
        naux1 = nauxs[uniq_inverse[k]]
        loc0, loc1 = min(p0, naux1-naux0), min(p1, naux1-naux0)
        nL = loc1 - loc0
        if nL > 0:
            if aosym_s2[k]:
                segs = numpy.hstack([segs[i0*nL:i1*nL].reshape(nL,-1) for i0,i1 in segs_loc_s2])
            else:
                segs = numpy.hstack([segs[i0*nL:i1*nL].reshape(nL,-1) for i0,i1 in segs_loc_s1])
            feri['j3c/%d'%k][loc0:loc1] = segs

    mem_now = max(comm.allgather(lib.current_memory()[0]))
    max_memory = max(2000, min(8000, mydf.max_memory - mem_now))
    if numpy.all(aosym_s2):
        if gamma_point(kptij_lst):
            blksize = max(16, int(max_memory*.5e6/8/nao**2))
        else:
            blksize = max(16, int(max_memory*.5e6/16/nao**2))
    else:
        blksize = max(16, int(max_memory*.5e6/16/nao**2/2))
    log.debug1('max_memory %d MB (%d in use), blksize %d',
               max_memory, mem_now, blksize)

    t2 = t1
    # Overlap writing block n (background thread) with loading block n+1.
    with lib.call_in_background(save) as async_write:
        for k, kptji in enumerate(kptij_lst):
            for p0, p1 in lib.prange(0, segsize, blksize):
                segs = load(k, p0, p1)
                async_write(k, p0, p1, segs)
                t2 = log.timer_debug1('assemble k=%d %d:%d (in %d)' %
                                      (k, p0, p1, segsize), *t2)

    if 'j2c-' in fswap:
        j2c_kpts_lists = []
        for k, kpt in enumerate(uniq_kpts):
            if ('j2c-/%d' % k) in fswap:
                adapted_ji_idx = numpy.where(uniq_inverse == k)[0]
                j2c_kpts_lists.append(adapted_ji_idx)

        for k in numpy.hstack(j2c_kpts_lists):
            val = [numpy.asarray(fswap['j3c-/%d/%d' % (job, k)]).ravel()
                   for job in job_ids]
            # A rank that ran no job has nothing to send.  numpy.hstack([])
            # would raise ValueError, so contribute an empty array instead,
            # matching load() above.
            if val:
                val = numpy.hstack(val)
            else:
                val = numpy.zeros(0)
            val = mpi.gather(val)
            if rank == 0:
                naux1 = fswap['j3c-/0/%d'%k].shape[0]
                if aosym_s2[k]:
                    v = [val[i0*naux1:i1*naux1].reshape(naux1,-1) for i0,i1 in segs_loc_s2]
                else:
                    v = [val[i0*naux1:i1*naux1].reshape(naux1,-1) for i0,i1 in segs_loc_s1]
                feri['j3c-/%d'%k] = numpy.hstack(v)

    if 'j3c-kptij' in feri:
        del feri['j3c-kptij']
    feri['j3c-kptij'] = kptij_lst
    t1 = log.alltimer_debug1('assembling j3c', *t1)
    feri.close()