예제 #1
0
def _add_vvvv_full(mycc, t1T, t2T, eris, out=None, with_ovvv=False):
    '''Ht2 = numpy.einsum('ijcd,acdb->ijab', t2, vvvv)
    without using symmetry t2[ijab] = t2[jiba] in t2 or Ht2

    Only the AO-direct path (``mycc.direct`` is true and ``with_ovvv`` is
    false) is implemented; any other combination raises NotImplementedError.

    Args:
        mycc: CC object; ``stdout``, ``verbose``, ``direct``, ``mol`` and
            (when ``eris`` has no ``mo_coeff``) ``mo_coeff`` are read.
        t1T: transposed singles amplitudes, or None to build tau from
            ``t2T`` alone.
        t2T: this rank's segment of the transposed doubles amplitudes; its
            first three dimensions are read as (nvir_seg, nvir, nocc).
        eris: integral object; only its optional ``mo_coeff`` is read here.
        out: optional buffer backing the returned array.
        with_ovvv: not supported on the AO-direct path.

    Returns:
        Array with the shape of ``t2T``.

    Raises:
        NotImplementedError: if ``mycc.direct`` is false or ``with_ovvv`` is
            true.
    '''
    # time.clock() was removed in Python 3.8.  Use the clocks exported by the
    # logger module when it provides them and fall back to the legacy pair
    # otherwise, so the start stamps stay usable with log.timer_debug1.
    cpu_clock = getattr(logger, 'process_clock', None) or time.clock
    wall_clock = getattr(logger, 'perf_counter', None) or time.time
    time0 = cpu_clock(), wall_clock()
    log = logger.Logger(mycc.stdout, mycc.verbose)

    nvir_seg, nvir, nocc = t2T.shape[:3]
    vloc0, vloc1 = _task_location(nvir, rank)
    nocc2 = nocc * (nocc + 1) // 2
    if t1T is None:
        tau = lib.pack_tril(t2T.reshape(nvir_seg * nvir, nocc2))
    else:
        # tau = t2 + t1 (x) t1, using only this rank's rows of t1T
        tau = t2T + numpy.einsum('ai,bj->abij', t1T[vloc0:vloc1], t1T)
        tau = lib.pack_tril(tau.reshape(nvir_seg * nvir, nocc2))
    tau = tau.reshape(nvir_seg, nvir, nocc2)

    if mycc.direct:  # AO-direct CCSD
        if with_ovvv:
            raise NotImplementedError
        mo = getattr(eris, 'mo_coeff', None)
        if mo is None:  # If eris does not have the attribute mo_coeff
            mo = _mo_without_core(mycc, mycc.mo_coeff)

        ao_loc = mycc.mol.ao_loc_nr()
        nao, nmo = mo.shape
        ntasks = mpi.pool.size
        # AO range [ao_loc0, ao_loc1) assigned to this rank
        task_sh_locs = lib.misc._balanced_partition(ao_loc, ntasks)
        ao_loc0 = ao_loc[task_sh_locs[rank]]
        ao_loc1 = ao_loc[task_sh_locs[rank + 1]]

        orbv = mo[:, nocc:]
        # Transform both virtual indices of tau to the AO basis.  The loop
        # variable rebinds ``tau`` to each block yielded by
        # _rotate_tensor_block, tagged with the task id it belongs to.
        tau = lib.einsum('abij,pb->apij', tau, orbv)
        tau_priv = numpy.zeros((ao_loc1 - ao_loc0, nao, nocc, nocc))
        for task_id, tau in _rotate_tensor_block(tau):
            loc0, loc1 = _task_location(nvir, task_id)
            tau_priv += lib.einsum('pa,abij->pbij', orbv[ao_loc0:ao_loc1,
                                                         loc0:loc1], tau)
        tau = None
        time1 = log.timer_debug1('vvvv-tau mo2ao', *time0)

        buf = _contract_vvvv_t2(mycc, None, tau_priv, task_sh_locs, None, log)
        buf = buf.reshape(tau_priv.shape)
        tau_priv = None
        time1 = log.timer_debug1('vvvv-tau contraction', *time1)

        # Back-transform to the MO basis, accumulating the rows of this
        # rank's virtual segment from every yielded block.
        buf = lib.einsum('apij,pb->abij', buf, orbv)
        Ht2 = numpy.ndarray(t2T.shape, buffer=out)
        Ht2[:] = 0
        for task_id, buf in _rotate_tensor_block(buf):
            ao_loc0 = ao_loc[task_sh_locs[task_id]]
            ao_loc1 = ao_loc[task_sh_locs[task_id + 1]]
            Ht2 += lib.einsum('pa,pbij->abij', orbv[ao_loc0:ao_loc1,
                                                    vloc0:vloc1], buf)

        time1 = log.timer_debug1('vvvv-tau ao2mo', *time1)
    else:
        raise NotImplementedError
    return Ht2.reshape(t2T.shape)
예제 #2
0
파일: df.py 프로젝트: yfyh2013/mpi4pyscf
def build(mydf, j_only=None, with_j3c=True, kpts_band=None):
    """Set up the auxiliary cell and, optionally, the 3-index DF integrals.

    With a single MPI process the call is forwarded to ``df.DF.build``.
    Otherwise ``mydf`` is synchronised through ``_sync_mydf``, the list of
    k-point pairs is assembled and ``mydf._make_j3c`` is invoked.

    Args:
        mydf: density-fitting object.
        j_only: if true, only diagonal k-point pairs (ki, ki) are generated;
            defaults to ``mydf._j_only`` when None.
        with_j3c: when true, compute the 3-index integrals and store their
            location in ``mydf._cderi``.
        kpts_band: extra band k-points, merged into ``mydf.kpts_band``.

    Returns:
        The (synchronised) ``mydf`` object.
    """
    # Unlike DF and AFT class, here MDF objects are synced once
    if mpi.pool.size == 1:
        return df.DF.build(mydf, j_only, with_j3c, kpts_band)

    mydf = _sync_mydf(mydf)
    cell = mydf.cell
    log = logger.Logger(mydf.stdout, mydf.verbose)
    info = rank, platform.node(), platform.os.getpid()
    log.debug('MPI info (rank, host, pid)  %s', comm.gather(info))

    # time.clock() was removed in Python 3.8.  Use the clocks exported by the
    # logger module when it provides them and fall back to the legacy pair.
    cpu_clock = getattr(logger, 'process_clock', None) or time.clock
    wall_clock = getattr(logger, 'perf_counter', None) or time.time
    t1 = (cpu_clock(), wall_clock())
    if mydf.kpts_band is not None:
        mydf.kpts_band = numpy.reshape(mydf.kpts_band, (-1, 3))
    if kpts_band is not None:
        kpts_band = numpy.reshape(kpts_band, (-1, 3))
        if mydf.kpts_band is None:
            mydf.kpts_band = kpts_band
        else:
            mydf.kpts_band = unique(numpy.vstack(
                (mydf.kpts_band, kpts_band)))[0]

    mydf.dump_flags()

    mydf.auxcell = make_modrho_basis(cell, mydf.auxbasis, mydf.eta)

    kpts = mydf.kpts
    if mydf.kpts_band is None:
        kband_uniq = numpy.zeros((0, 3))
    else:
        # Band k-points that are not already part of kpts.  Reshape so that
        # an empty selection is a (0, 3) array; a bare empty list would make
        # the vstack in the j_only branch fail on mismatched dimensions.
        kband_uniq = numpy.reshape(
            [k for k in mydf.kpts_band if len(member(k, kpts)) == 0],
            (-1, 3))
    if j_only is None:
        j_only = mydf._j_only
    if j_only:
        # Only (k, k) pairs
        kall = numpy.vstack([kpts, kband_uniq])
        kptij_lst = numpy.hstack((kall, kall)).reshape(-1, 2, 3)
    else:
        # Lower-triangular pairs of kpts, then band/kpts pairs, then the
        # diagonal band pairs.
        kptij_lst = [(ki, kpts[j]) for i, ki in enumerate(kpts)
                     for j in range(i + 1)]
        kptij_lst.extend([(ki, kj) for ki in kband_uniq for kj in kpts])
        kptij_lst.extend([(ki, ki) for ki in kband_uniq])
        kptij_lst = numpy.asarray(kptij_lst)

    if with_j3c:
        if isinstance(mydf._cderi_to_save, str):
            cderi = mydf._cderi_to_save
        else:
            cderi = mydf._cderi_to_save.name
        if isinstance(mydf._cderi, str):
            log.warn(
                'Value of _cderi is ignored. DF integrals will be '
                'saved in file %s .', cderi)
        mydf._cderi = cderi
        mydf._make_j3c(cell, mydf.auxcell, kptij_lst, cderi)
        t1 = log.timer_debug1('j3c', *t1)
    return mydf
예제 #3
0
파일: df.py 프로젝트: plin1112/mpi4pyscf
def _make_j3c(mydf, cell, auxcell, kptij_lst, cderi_file):
    """Compute the 3-index DF integrals for the k-point pairs in kptij_lst.

    The 2-centre metric is corrected in reciprocal space and decomposed per
    unique k-point difference, then the two workers ``gen_int3c`` (real-space
    3-centre integrals, stored chunk-wise in a swap file) and ``ft_fuse``
    (reciprocal-space correction and metric contraction) are handed to
    ``_assemble`` together with the swap file and ``cderi_file``.

    Args:
        mydf: density-fitting object (mesh, max_memory, thresholds, ...).
        cell: the cell providing the AO basis.
        auxcell: the auxiliary-basis cell.
        kptij_lst: array of k-point pairs, indexed as [pair, 0/1, xyz].
        cderi_file: target file name; the swap file is created in its
            directory.
    """
    log = logger.Logger(mydf.stdout, mydf.verbose)
    # time.clock() was removed in Python 3.8.  Use the clocks exported by the
    # logger module when it provides them and fall back to the legacy pair.
    cpu_clock = getattr(logger, 'process_clock', None) or time.clock
    wall_clock = getattr(logger, 'perf_counter', None) or time.time
    t1 = (cpu_clock(), wall_clock())

    fused_cell, fuse = fuse_auxcell(mydf, mydf.auxcell)
    ao_loc = cell.ao_loc_nr()
    nao = ao_loc[-1]
    naux = auxcell.nao_nr()
    nkptij = len(kptij_lst)
    mesh = mydf.mesh
    Gv, Gvbase, kws = cell.get_Gv_weights(mesh)
    b = cell.reciprocal_vectors()
    gxyz = lib.cartesian_prod([numpy.arange(len(x)) for x in Gvbase])
    ngrids = gxyz.shape[0]

    kptis = kptij_lst[:, 0]
    kptjs = kptij_lst[:, 1]
    kpt_ji = kptjs - kptis
    uniq_kpts, uniq_index, uniq_inverse = unique(kpt_ji)
    log.debug('Num uniq kpts %d', len(uniq_kpts))
    log.debug2('uniq_kpts %s', uniq_kpts)
    # j2c ~ (-kpt_ji | kpt_ji)
    j2c = fused_cell.pbc_intor('int2c2e', hermi=1, kpts=uniq_kpts)
    j2ctags = []
    t1 = log.timer_debug1('2c2e', *t1)

    swapfile = tempfile.NamedTemporaryFile(dir=os.path.dirname(cderi_file))
    fswap = lib.H5TmpFile(swapfile.name)
    # Unlink swapfile to avoid trash
    swapfile = None

    mem_now = max(comm.allgather(lib.current_memory()[0]))
    max_memory = max(2000, mydf.max_memory - mem_now)
    blksize = max(2048, int(max_memory * .5e6 / 16 / fused_cell.nao_nr()))
    log.debug2('max_memory %s (MB)  blocksize %s', max_memory, blksize)
    for k, kpt in enumerate(uniq_kpts):
        coulG = mydf.weighted_coulG(kpt, False, mesh)
        j2c_k = numpy.zeros_like(j2c[k])
        for p0, p1 in mydf.prange(0, ngrids, blksize):
            aoaux = ft_ao.ft_ao(fused_cell, Gv[p0:p1], None, b, gxyz[p0:p1],
                                Gvbase, kpt).T
            LkR = numpy.asarray(aoaux.real, order='C')
            LkI = numpy.asarray(aoaux.imag, order='C')
            aoaux = None

            if is_zero(kpt):  # kpti == kptj
                j2c_k[naux:] += lib.ddot(LkR[naux:] * coulG[p0:p1], LkR.T)
                j2c_k[naux:] += lib.ddot(LkI[naux:] * coulG[p0:p1], LkI.T)
            else:
                j2cR, j2cI = zdotCN(LkR[naux:] * coulG[p0:p1],
                                    LkI[naux:] * coulG[p0:p1], LkR.T, LkI.T)
                j2c_k[naux:] += j2cR + j2cI * 1j
            # Drop the references to this block's arrays (the original reset
            # kLR/kLI, which are not the names used in this loop).
            LkR = LkI = None

        # Fill the upper-right block from the lower-left one, then combine
        # the partial sums of all ranks.
        j2c_k[:naux, naux:] = j2c_k[naux:, :naux].conj().T
        j2c[k] -= mpi.allreduce(j2c_k)
        j2c[k] = fuse(fuse(j2c[k]).T).T
        try:
            fswap['j2c/%d' % k] = scipy.linalg.cholesky(j2c[k], lower=True)
            j2ctags.append('CD')
        except scipy.linalg.LinAlgError:
            # A metric that is not positive definite is not treated as an
            # error: fall back to an eigen-decomposition and discard the
            # eigenvectors below the linear-dependency threshold.
            w, v = scipy.linalg.eigh(j2c[k])
            log.debug2('metric linear dependency for kpt %s', k)
            log.debug2('cond = %.4g, drop %d bfns', w[0] / w[-1],
                       numpy.count_nonzero(w < mydf.linear_dep_threshold))
            v1 = v[:, w > mydf.linear_dep_threshold].T.conj()
            v1 /= numpy.sqrt(w[w > mydf.linear_dep_threshold]).reshape(-1, 1)
            fswap['j2c/%d' % k] = v1
            if cell.dimension == 2 and cell.low_dim_ft_type != 'inf_vacuum':
                # Keep the significantly negative part separately
                idx = numpy.where(w < -mydf.linear_dep_threshold)[0]
                if len(idx) > 0:
                    fswap['j2c-/%d' % k] = (v[:, idx] /
                                            numpy.sqrt(-w[idx])).conj().T
            w = v = v1 = None
            j2ctags.append('eig')
    j2c = coulG = None

    # Pairs with kpti == kptj are stored with packed (s2) AO-pair symmetry
    aosym_s2 = numpy.einsum('ix->i', abs(kptis - kptjs)) < 1e-9
    j_only = numpy.all(aosym_s2)
    if gamma_point(kptij_lst):
        dtype = 'f8'
    else:
        dtype = 'c16'
    t1 = log.timer_debug1('aoaux and int2c', *t1)

    # Estimates the buffer size based on the last contraction in G-space.
    # This contraction requires to hold nkptj copies of (naux,?) array
    # simultaneously in memory.
    mem_now = max(comm.allgather(lib.current_memory()[0]))
    max_memory = max(2000, mydf.max_memory - mem_now)
    nkptj_max = max((uniq_inverse == x).sum() for x in set(uniq_inverse))
    buflen = max(
        int(
            min(max_memory * .5e6 / 16 / naux / (nkptj_max + 2) / nao,
                nao / 3 / mpi.pool.size)), 1)
    chunks = (buflen, nao)

    j3c_jobs = grids2d_int3c_jobs(cell, auxcell, kptij_lst, chunks, j_only)
    log.debug1('max_memory = %d MB (%d in use)  chunks %s', max_memory,
               mem_now, chunks)
    log.debug2('j3c_jobs %s', j3c_jobs)

    if j_only:
        int3c = wrap_int3c(cell, fused_cell, 'int3c2e', 's2', 1, kptij_lst)
    else:
        int3c = wrap_int3c(cell, fused_cell, 'int3c2e', 's1', 1, kptij_lst)
        # Flat indices of the lower-triangular AO pairs inside an (nao, nao)
        # matrix; used to pack s1 results of diagonal k-point pairs.
        idxb = numpy.tril_indices(nao)
        idxb = (idxb[0] * nao + idxb[1]).astype('i')
    aux_loc = fused_cell.ao_loc_nr('ssc' in 'int3c2e')

    def gen_int3c(job_id, ish0, ish1):
        """Evaluate the 3-centre integrals of AO shells [ish0, ish1) for all
        k-point pairs and store them under 'j3c-chunks/<job_id>/<pair>'."""
        dataname = 'j3c-chunks/%d' % job_id
        i0 = ao_loc[ish0]
        i1 = ao_loc[ish1]
        dii = i1 * (i1 + 1) // 2 - i0 * (i0 + 1) // 2
        if j_only:
            dij = dii
            buflen = max(8, int(max_memory * 1e6 / 16 / (nkptij * dii + dii)))
        else:
            dij = (i1 - i0) * nao
            buflen = max(8, int(max_memory * 1e6 / 16 / (nkptij * dij + dij)))
        auxranges = balance_segs(aux_loc[1:] - aux_loc[:-1], buflen)
        buflen = max([x[2] for x in auxranges])
        buf = numpy.empty(nkptij * dij * buflen, dtype=dtype)
        buf1 = numpy.empty(dij * buflen, dtype=dtype)

        # Local naux: size of the fused auxiliary basis
        naux = aux_loc[-1]
        for kpt_id, kptij in enumerate(kptij_lst):
            key = '%s/%d' % (dataname, kpt_id)
            if aosym_s2[kpt_id]:
                shape = (naux, dii)
            else:
                shape = (naux, dij)
            if gamma_point(kptij):
                fswap.create_dataset(key, shape, 'f8')
            else:
                fswap.create_dataset(key, shape, 'c16')

        naux0 = 0
        for istep, auxrange in enumerate(auxranges):
            log.alldebug2("aux_e1 job_id %d step %d", job_id, istep)
            sh0, sh1, nrow = auxrange
            sub_slice = (ish0, ish1, 0, cell.nbas, sh0, sh1)
            mat = numpy.ndarray((nkptij, dij, nrow), dtype=dtype, buffer=buf)
            mat = int3c(sub_slice, mat)

            for k, kptij in enumerate(kptij_lst):
                h5dat = fswap['%s/%d' % (dataname, k)]
                v = lib.transpose(mat[k], out=buf1)
                if not j_only and aosym_s2[k]:
                    # Pack the s1 block down to its lower-triangular pairs,
                    # reusing mat[k] as the output buffer.
                    idy = idxb[i0 * (i0 + 1) // 2:i1 *
                               (i1 + 1) // 2] - i0 * nao
                    out = numpy.ndarray((nrow, dii),
                                        dtype=v.dtype,
                                        buffer=mat[k])
                    v = numpy.take(v, idy, axis=1, out=out)
                if gamma_point(kptij):
                    h5dat[naux0:naux0 + nrow] = v.real
                else:
                    h5dat[naux0:naux0 + nrow] = v
            naux0 += nrow

    def ft_fuse(job_id, uniq_kptji_id, sh0, sh1):
        """Apply the reciprocal-space correction to the chunk of job_id for
        every pair with k-point difference uniq_kptji_id, contract with the
        decomposed metric and write the result back to the swap file."""
        kpt = uniq_kpts[uniq_kptji_id]  # kpt = kptj - kpti
        adapted_ji_idx = numpy.where(uniq_inverse == uniq_kptji_id)[0]
        adapted_kptjs = kptjs[adapted_ji_idx]
        nkptj = len(adapted_kptjs)

        j2c = numpy.asarray(fswap['j2c/%d' % uniq_kptji_id])
        j2ctag = j2ctags[uniq_kptji_id]
        naux0 = j2c.shape[0]
        if ('j2c-/%d' % uniq_kptji_id) in fswap:
            j2c_negative = numpy.asarray(fswap['j2c-/%d' % uniq_kptji_id])
        else:
            j2c_negative = None

        if is_zero(kpt):
            aosym = 's2'
        else:
            aosym = 's1'

        if aosym == 's2' and cell.dimension == 3:
            vbar = fuse(mydf.auxbar(fused_cell))
            ovlp = cell.pbc_intor('int1e_ovlp', hermi=1, kpts=adapted_kptjs)
            ovlp = [lib.pack_tril(s) for s in ovlp]

        j3cR = [None] * nkptj
        j3cI = [None] * nkptj
        i0 = ao_loc[sh0]
        i1 = ao_loc[sh1]
        for k, idx in enumerate(adapted_ji_idx):
            key = 'j3c-chunks/%d/%d' % (job_id, idx)
            v = numpy.asarray(fswap[key])
            if aosym == 's2' and cell.dimension == 3:
                # Subtract vbar[i] * overlap from every row with non-zero vbar
                for i in numpy.where(vbar != 0)[0]:
                    v[i] -= vbar[i] * ovlp[k][i0 * (i0 + 1) // 2:i1 *
                                              (i1 + 1) // 2].ravel()
            j3cR[k] = numpy.asarray(v.real, order='C')
            if v.dtype == numpy.complex128:
                j3cI[k] = numpy.asarray(v.imag, order='C')
            v = None

        ncol = j3cR[0].shape[1]
        Gblksize = max(16, int(max_memory * 1e6 / 16 / ncol /
                               (nkptj + 1)))  # +1 for pqkRbuf/pqkIbuf
        Gblksize = min(Gblksize, ngrids, 16384)
        pqkRbuf = numpy.empty(ncol * Gblksize)
        pqkIbuf = numpy.empty(ncol * Gblksize)
        buf = numpy.empty(nkptj * ncol * Gblksize, dtype=numpy.complex128)
        log.alldebug2('job_id %d  blksize (%d,%d)', job_id, Gblksize, ncol)

        wcoulG = mydf.weighted_coulG(kpt, False, mesh)
        fused_cell_slice = (auxcell.nbas, fused_cell.nbas)
        if aosym == 's2':
            shls_slice = (sh0, sh1, 0, sh1)
        else:
            shls_slice = (sh0, sh1, 0, cell.nbas)
        for p0, p1 in lib.prange(0, ngrids, Gblksize):
            Gaux = ft_ao.ft_ao(fused_cell, Gv[p0:p1], fused_cell_slice, b,
                               gxyz[p0:p1], Gvbase, kpt)
            Gaux *= wcoulG[p0:p1, None]
            kLR = Gaux.real.copy('C')
            kLI = Gaux.imag.copy('C')
            Gaux = None

            dat = ft_ao._ft_aopair_kpts(cell,
                                        Gv[p0:p1],
                                        shls_slice,
                                        aosym,
                                        b,
                                        gxyz[p0:p1],
                                        Gvbase,
                                        kpt,
                                        adapted_kptjs,
                                        out=buf)
            nG = p1 - p0
            for k, ji in enumerate(adapted_ji_idx):
                aoao = dat[k].reshape(nG, ncol)
                pqkR = numpy.ndarray((ncol, nG), buffer=pqkRbuf)
                pqkI = numpy.ndarray((ncol, nG), buffer=pqkIbuf)
                pqkR[:] = aoao.real.T
                pqkI[:] = aoao.imag.T

                # Accumulate the correction in place into rows naux: of the
                # real (and, when needed, imaginary) parts.
                lib.dot(kLR.T, pqkR.T, -1, j3cR[k][naux:], 1)
                lib.dot(kLI.T, pqkI.T, -1, j3cR[k][naux:], 1)
                if not (is_zero(kpt) and gamma_point(adapted_kptjs[k])):
                    lib.dot(kLR.T, pqkI.T, -1, j3cI[k][naux:], 1)
                    lib.dot(kLI.T, pqkR.T, 1, j3cI[k][naux:], 1)
            kLR = kLI = None

        for k, idx in enumerate(adapted_ji_idx):
            if is_zero(kpt) and gamma_point(adapted_kptjs[k]):
                v = fuse(j3cR[k])
            else:
                v = fuse(j3cR[k] + j3cI[k] * 1j)
            if j2ctag == 'CD':
                v = scipy.linalg.solve_triangular(j2c,
                                                  v,
                                                  lower=True,
                                                  overwrite_b=True)
                fswap['j3c-chunks/%d/%d' % (job_id, idx)][:naux0] = v
            else:
                fswap['j3c-chunks/%d/%d' % (job_id, idx)][:naux0] = lib.dot(
                    j2c, v)

            # low-dimension systems
            if j2c_negative is not None:
                fswap['j3c-/%d/%d' % (job_id, idx)] = lib.dot(j2c_negative, v)

    _assemble(mydf, kptij_lst, j3c_jobs, gen_int3c, ft_fuse, cderi_file, fswap,
              log)
예제 #4
0
파일: df.py 프로젝트: plin1112/mpi4pyscf
 def dump_flags(self, verbose=None):
     """Report the DF settings via ``df.DF.dump_flags``.

     The ``verbose`` argument is accepted but not consulted: the logger
     passed on is always built from ``self.stdout`` and ``self.verbose``.
     """
     own_logger = logger.Logger(self.stdout, self.verbose)
     return df.DF.dump_flags(self, own_logger)
예제 #5
0
파일: df.py 프로젝트: yfyh2013/mpi4pyscf
def _make_j3c(mydf, cell, auxcell, kptij_lst, cderi_file):
    """Compute the 3-index DF integrals and write them to ``cderi_file``.

    Steps, all operating on the HDF5 file ``cderi_file``:
      1. correct and decompose the 2-centre metric for every unique k-point
         difference ('j2c/<k>');
      2. per job, evaluate the 3-centre integrals (``gen_int3c``) and apply
         the reciprocal-space correction and metric contraction
         (``ft_fuse``), stored under 'j3c-chunks/<job>/<pair>';
      3. redistribute the chunks between the MPI ranks into 'j3c/<pair>',
         each rank holding one segment of the auxiliary index;
      4. delete the chunks and store the pair list as 'j3c-kptij'.

    Args:
        mydf: density-fitting object (gs, max_memory, ...).
        cell: the cell providing the AO basis.
        auxcell: the auxiliary-basis cell.
        kptij_lst: array of k-point pairs, indexed as [pair, 0/1, xyz].
        cderi_file: name of the HDF5 file to create or update.
    """
    log = logger.Logger(mydf.stdout, mydf.verbose)
    # time.clock() was removed in Python 3.8.  Use the clocks exported by the
    # logger module when it provides them and fall back to the legacy pair.
    cpu_clock = getattr(logger, 'process_clock', None) or time.clock
    wall_clock = getattr(logger, 'perf_counter', None) or time.time
    t1 = (cpu_clock(), wall_clock())

    fused_cell, fuse = fuse_auxcell(mydf, mydf.auxcell)
    ao_loc = cell.ao_loc_nr()
    nao = ao_loc[-1]
    naux = auxcell.nao_nr()
    nkptij = len(kptij_lst)
    gs = mydf.gs
    Gv, Gvbase, kws = cell.get_Gv_weights(gs)
    b = cell.reciprocal_vectors()
    gxyz = lib.cartesian_prod([numpy.arange(len(x)) for x in Gvbase])
    ngs = gxyz.shape[0]

    kptis = kptij_lst[:, 0]
    kptjs = kptij_lst[:, 1]
    kpt_ji = kptjs - kptis
    uniq_kpts, uniq_index, uniq_inverse = unique(kpt_ji)
    log.debug('Num uniq kpts %d', len(uniq_kpts))
    log.debug2('uniq_kpts %s', uniq_kpts)
    # j2c ~ (-kpt_ji | kpt_ji)
    j2c = fused_cell.pbc_intor('int2c2e_sph', hermi=1, kpts=uniq_kpts)
    j2ctags = []
    nauxs = []
    t1 = log.timer_debug1('2c2e', *t1)

    if h5py.is_hdf5(cderi_file):
        # Open the existing file read/write explicitly; without a mode,
        # h5py >= 3 opens read-only and every later write would fail.
        feri = h5py.File(cderi_file, 'r+')
    else:
        feri = h5py.File(cderi_file, 'w')
    for k, kpt in enumerate(uniq_kpts):
        aoaux = ft_ao.ft_ao(fused_cell, Gv, None, b, gxyz, Gvbase, kpt).T
        coulG = numpy.sqrt(mydf.weighted_coulG(kpt, False, gs))
        kLR = (aoaux.real * coulG).T
        kLI = (aoaux.imag * coulG).T
        if not kLR.flags.c_contiguous: kLR = lib.transpose(kLR.T)
        if not kLI.flags.c_contiguous: kLI = lib.transpose(kLI.T)
        aoaux = None

        kLR1 = numpy.asarray(kLR[:, naux:], order='C')
        kLI1 = numpy.asarray(kLI[:, naux:], order='C')
        if is_zero(kpt):  # kpti == kptj
            for p0, p1 in mydf.mpi_prange(0, ngs):
                j2cR = lib.ddot(kLR1[p0:p1].T, kLR[p0:p1])
                j2cR = lib.ddot(kLI1[p0:p1].T, kLI[p0:p1], 1, j2cR, 1)
                j2c[k][naux:] -= mpi.allreduce(j2cR)
                j2c[k][:naux, naux:] = j2c[k][naux:, :naux].T
        else:
            for p0, p1 in mydf.mpi_prange(0, ngs):
                j2cR, j2cI = zdotCN(kLR1[p0:p1].T, kLI1[p0:p1].T, kLR[p0:p1],
                                    kLI[p0:p1])
                j2cR = mpi.allreduce(j2cR)
                j2cI = mpi.allreduce(j2cI)
                j2c[k][naux:] -= j2cR + j2cI * 1j
                j2c[k][:naux, naux:] = j2c[k][naux:, :naux].T.conj()
        j2c[k] = fuse(fuse(j2c[k]).T).T
        try:
            feri['j2c/%d' % k] = scipy.linalg.cholesky(j2c[k], lower=True)
            j2ctags.append('CD')
            nauxs.append(naux)
        except scipy.linalg.LinAlgError:
            # A metric that is not positive definite is not treated as an
            # error: fall back to an eigen-decomposition of this k-point's
            # metric and discard the eigenvectors below LINEAR_DEP_THR.
            # (The original decomposed the whole list ``j2c`` and logged an
            # undefined name here.)
            w, v = scipy.linalg.eigh(j2c[k])
            log.debug2('metric linear dependency for kpt %s', k)
            log.debug2('cond = %.4g, drop %d bfns', w[0] / w[-1],
                       numpy.count_nonzero(w < LINEAR_DEP_THR))
            v = v[:, w > LINEAR_DEP_THR].T.conj()
            v /= numpy.sqrt(w[w > LINEAR_DEP_THR]).reshape(-1, 1)
            feri['j2c/%d' % k] = v
            j2ctags.append('eig')
            nauxs.append(v.shape[0])
        kLR = kLI = kLR1 = kLI1 = coulG = None
    j2c = None

    # Pairs with kpti == kptj are stored with packed (s2) AO-pair symmetry
    aosym_s2 = numpy.einsum('ix->i', abs(kptis - kptjs)) < 1e-9
    j_only = numpy.all(aosym_s2)
    if gamma_point(kptij_lst):
        dtype = 'f8'
    else:
        dtype = 'c16'
    vbar = mydf.auxbar(fused_cell)
    vbar = fuse(vbar)
    ovlp = cell.pbc_intor('int1e_ovlp_sph', hermi=1, kpts=kptjs[aosym_s2])
    ovlp = [lib.pack_tril(s) for s in ovlp]
    t1 = log.timer_debug1('aoaux and int2c', *t1)

    # Estimates the buffer size based on the last contraction in G-space.
    # This contraction requires to hold nkptj copies of (naux,?) array
    # simultaneously in memory.
    mem_now = max(comm.allgather(lib.current_memory()[0]))
    max_memory = max(2000, mydf.max_memory - mem_now)
    nkptj_max = max((uniq_inverse == x).sum() for x in set(uniq_inverse))
    buflen = max(
        int(
            min(max_memory * .5e6 / 16 / naux / (nkptj_max + 2) / nao,
                nao / 3 / mpi.pool.size)), 1)
    chunks = (buflen, nao)

    j3c_jobs = grids2d_int3c_jobs(cell, auxcell, kptij_lst, chunks, j_only)
    log.debug1('max_memory = %d MB (%d in use)  chunks %s', max_memory,
               mem_now, chunks)
    log.debug2('j3c_jobs %s', j3c_jobs)

    if j_only:
        int3c = wrap_int3c(cell, fused_cell, 'int3c2e_sph', 's2', 1, kptij_lst)
    else:
        int3c = wrap_int3c(cell, fused_cell, 'int3c2e_sph', 's1', 1, kptij_lst)
        # Flat indices of the lower-triangular AO pairs inside an (nao, nao)
        # matrix; used to pack s1 results of diagonal k-point pairs.
        idxb = numpy.tril_indices(nao)
        idxb = (idxb[0] * nao + idxb[1]).astype('i')
    aux_loc = fused_cell.ao_loc_nr('ssc' in 'int3c2e_sph')

    def gen_int3c(auxcell, job_id, ish0, ish1):
        """Evaluate the 3-centre integrals of AO shells [ish0, ish1) for all
        k-point pairs and store them under 'j3c-chunks/<job_id>/<pair>'.
        The ``auxcell`` argument is not used in the body."""
        dataname = 'j3c-chunks/%d' % job_id
        if dataname in feri:
            del (feri[dataname])

        i0 = ao_loc[ish0]
        i1 = ao_loc[ish1]
        dii = i1 * (i1 + 1) // 2 - i0 * (i0 + 1) // 2
        dij = (i1 - i0) * nao
        if j_only:
            buflen = max(8, int(max_memory * 1e6 / 16 / (nkptij * dii + dii)))
        else:
            buflen = max(8, int(max_memory * 1e6 / 16 / (nkptij * dij + dij)))
        auxranges = balance_segs(aux_loc[1:] - aux_loc[:-1], buflen)
        buflen = max([x[2] for x in auxranges])
        buf = numpy.empty(nkptij * dij * buflen, dtype=dtype)
        buf1 = numpy.empty(dij * buflen, dtype=dtype)

        # Local naux: size of the fused auxiliary basis
        naux = aux_loc[-1]
        for kpt_id, kptij in enumerate(kptij_lst):
            key = '%s/%d' % (dataname, kpt_id)
            if aosym_s2[kpt_id]:
                shape = (naux, dii)
            else:
                shape = (naux, dij)
            if gamma_point(kptij):
                feri.create_dataset(key, shape, 'f8')
            else:
                feri.create_dataset(key, shape, 'c16')

        naux0 = 0
        for istep, auxrange in enumerate(auxranges):
            log.alldebug2("aux_e2 job_id %d step %d", job_id, istep)
            sh0, sh1, nrow = auxrange
            sub_slice = (ish0, ish1, 0, cell.nbas, sh0, sh1)
            if j_only:
                mat = numpy.ndarray((nkptij, dii, nrow),
                                    dtype=dtype,
                                    buffer=buf)
            else:
                mat = numpy.ndarray((nkptij, dij, nrow),
                                    dtype=dtype,
                                    buffer=buf)
            mat = int3c(sub_slice, mat)

            for k, kptij in enumerate(kptij_lst):
                h5dat = feri['%s/%d' % (dataname, k)]
                v = lib.transpose(mat[k], out=buf1)
                if not j_only and aosym_s2[k]:
                    # Pack the s1 block down to its lower-triangular pairs,
                    # reusing mat[k] as the output buffer.
                    idy = idxb[i0 * (i0 + 1) // 2:i1 *
                               (i1 + 1) // 2] - i0 * nao
                    out = numpy.ndarray((nrow, dii),
                                        dtype=v.dtype,
                                        buffer=mat[k])
                    v = numpy.take(v, idy, axis=1, out=out)
                if gamma_point(kptij):
                    h5dat[naux0:naux0 + nrow] = v.real
                else:
                    h5dat[naux0:naux0 + nrow] = v
            naux0 += nrow

    def ft_fuse(job_id, uniq_kptji_id, sh0, sh1):
        """Apply the reciprocal-space correction to the chunk of job_id for
        every pair with k-point difference uniq_kptji_id, contract with the
        decomposed metric and write the result back to the file."""
        kpt = uniq_kpts[uniq_kptji_id]  # kpt = kptj - kpti
        adapted_ji_idx = numpy.where(uniq_inverse == uniq_kptji_id)[0]
        adapted_kptjs = kptjs[adapted_ji_idx]
        nkptj = len(adapted_kptjs)

        shls_slice = (auxcell.nbas, fused_cell.nbas)
        Gaux = ft_ao.ft_ao(fused_cell, Gv, shls_slice, b, gxyz, Gvbase, kpt)
        Gaux *= mydf.weighted_coulG(kpt, False, gs).reshape(-1, 1)
        kLR = Gaux.real.copy('C')
        kLI = Gaux.imag.copy('C')
        j2c = numpy.asarray(feri['j2c/%d' % uniq_kptji_id])
        j2ctag = j2ctags[uniq_kptji_id]
        naux0 = j2c.shape[0]

        if is_zero(kpt):
            aosym = 's2'
        else:
            aosym = 's1'

        j3cR = [None] * nkptj
        j3cI = [None] * nkptj
        i0 = ao_loc[sh0]
        i1 = ao_loc[sh1]
        for k, idx in enumerate(adapted_ji_idx):
            key = 'j3c-chunks/%d/%d' % (job_id, idx)
            v = numpy.asarray(feri[key])
            if is_zero(kpt):
                # Subtract c * overlap from every row with non-zero vbar
                for i, c in enumerate(vbar):
                    if c != 0:
                        v[i] -= c * ovlp[k][i0 * (i0 + 1) // 2:i1 *
                                            (i1 + 1) // 2].ravel()
            j3cR[k] = numpy.asarray(v.real, order='C')
            if v.dtype == numpy.complex128:
                j3cI[k] = numpy.asarray(v.imag, order='C')
            v = None

        ncol = j3cR[0].shape[1]
        Gblksize = max(16, int(max_memory * 1e6 / 16 / ncol /
                               (nkptj + 1)))  # +1 for pqkRbuf/pqkIbuf
        Gblksize = min(Gblksize, ngs, 16384)
        pqkRbuf = numpy.empty(ncol * Gblksize)
        pqkIbuf = numpy.empty(ncol * Gblksize)
        buf = numpy.empty(nkptj * ncol * Gblksize, dtype=numpy.complex128)
        log.alldebug2('    blksize (%d,%d)', Gblksize, ncol)

        shls_slice = (sh0, sh1, 0, cell.nbas)
        for p0, p1 in lib.prange(0, ngs, Gblksize):
            dat = ft_ao._ft_aopair_kpts(cell,
                                        Gv[p0:p1],
                                        shls_slice,
                                        aosym,
                                        b,
                                        gxyz[p0:p1],
                                        Gvbase,
                                        kpt,
                                        adapted_kptjs,
                                        out=buf)
            nG = p1 - p0
            for k, ji in enumerate(adapted_ji_idx):
                aoao = dat[k].reshape(nG, ncol)
                pqkR = numpy.ndarray((ncol, nG), buffer=pqkRbuf)
                pqkI = numpy.ndarray((ncol, nG), buffer=pqkIbuf)
                pqkR[:] = aoao.real.T
                pqkI[:] = aoao.imag.T

                # Accumulate the correction in place into rows naux: of the
                # real (and, when needed, imaginary) parts.
                lib.dot(kLR[p0:p1].T, pqkR.T, -1, j3cR[k][naux:], 1)
                lib.dot(kLI[p0:p1].T, pqkI.T, -1, j3cR[k][naux:], 1)
                if not (is_zero(kpt) and gamma_point(adapted_kptjs[k])):
                    lib.dot(kLR[p0:p1].T, pqkI.T, -1, j3cI[k][naux:], 1)
                    lib.dot(kLI[p0:p1].T, pqkR.T, 1, j3cI[k][naux:], 1)

        for k, idx in enumerate(adapted_ji_idx):
            if is_zero(kpt) and gamma_point(adapted_kptjs[k]):
                v = fuse(j3cR[k])
            else:
                v = fuse(j3cR[k] + j3cI[k] * 1j)
            if j2ctag == 'CD':
                v = scipy.linalg.solve_triangular(j2c,
                                                  v,
                                                  lower=True,
                                                  overwrite_b=True)
            else:
                v = lib.dot(j2c, v)
            feri['j3c-chunks/%d/%d' % (job_id, idx)][:naux0] = v

    t2 = t1
    # j3c_workers[job_id] records which rank processed the job; ranks fill
    # only their own entries and the allreduce below merges them.
    j3c_workers = numpy.zeros(len(j3c_jobs), dtype=int)
    #for job_id, ish0, ish1 in mpi.work_share_partition(j3c_jobs):
    for job_id, ish0, ish1 in mpi.work_stealing_partition(j3c_jobs):
        gen_int3c(fused_cell, job_id, ish0, ish1)
        t2 = log.alltimer_debug2('int j3c %d' % job_id, *t2)

        for k, kpt in enumerate(uniq_kpts):
            ft_fuse(job_id, k, ish0, ish1)
            t2 = log.alltimer_debug2('ft-fuse %d k %d' % (job_id, k), *t2)

        j3c_workers[job_id] = rank
    j3c_workers = mpi.allreduce(j3c_workers)
    log.debug2('j3c_workers %s', j3c_workers)
    # Drop references that are no longer needed
    j2c = ovlp = vbar = fuse = gen_int3c = ft_fuse = None
    t1 = log.timer_debug1('int3c and fuse', *t1)

    def get_segs_loc(aosym):
        """Return, per job, the (start, stop) column offsets of its AO-pair
        block when the blocks are laid out grouped by owning rank."""
        off0 = numpy.asarray([ao_loc[i0] for x, i0, i1 in j3c_jobs])
        off1 = numpy.asarray([ao_loc[i1] for x, i0, i1 in j3c_jobs])
        if aosym:  # s2
            dims = off1 * (off1 + 1) // 2 - off0 * (off0 + 1) // 2
        else:
            dims = (off1 - off0) * nao
        #dims = numpy.asarray([ao_loc[i1]-ao_loc[i0] for x,i0,i1 in j3c_jobs])
        dims = numpy.hstack(
            [dims[j3c_workers == w] for w in range(mpi.pool.size)])
        job_idx = numpy.hstack(
            [numpy.where(j3c_workers == w)[0] for w in range(mpi.pool.size)])
        segs_loc = numpy.append(0, numpy.cumsum(dims))
        # argsort(job_idx) restores the original job order
        segs_loc = [(segs_loc[j], segs_loc[j + 1])
                    for j in numpy.argsort(job_idx)]
        return segs_loc

    segs_loc_s1 = get_segs_loc(False)
    segs_loc_s2 = get_segs_loc(True)

    if 'j3c' in feri: del (feri['j3c'])
    # Each rank stores rows [naux0, naux0 + segsize) of the auxiliary index
    segsize = (max(nauxs) + mpi.pool.size - 1) // mpi.pool.size
    naux0 = rank * segsize
    for k, kptij in enumerate(kptij_lst):
        naux1 = min(nauxs[uniq_inverse[k]], naux0 + segsize)
        nrow = max(0, naux1 - naux0)
        if gamma_point(kptij):
            dtype = 'f8'
        else:
            dtype = 'c16'
        if aosym_s2[k]:
            nao_pair = nao * (nao + 1) // 2
        else:
            nao_pair = nao * nao
        feri.create_dataset('j3c/%d' % k, (nrow, nao_pair),
                            dtype,
                            maxshape=(None, nao_pair))

    def load(k, p0, p1):
        """Collect, for every destination rank, the rows [p0, p1) of that
        rank's auxiliary segment from the chunks this rank computed."""
        naux1 = nauxs[uniq_inverse[k]]
        slices = [(min(i * segsize + p0, naux1), min(i * segsize + p1, naux1))
                  for i in range(mpi.pool.size)]
        segs = []
        for p0, p1 in slices:
            val = []
            for job_id, worker in enumerate(j3c_workers):
                if rank == worker:
                    key = 'j3c-chunks/%d/%d' % (job_id, k)
                    val.append(feri[key][p0:p1].ravel())
            if val:
                segs.append(numpy.hstack(val))
            else:
                segs.append(numpy.zeros(0))
        return segs

    def save(k, p0, p1, segs):
        """Exchange the segments between ranks and write this rank's rows
        [p0, p1) of 'j3c/<k>', reordering the columns into job order."""
        segs = mpi.alltoall(segs)
        naux1 = nauxs[uniq_inverse[k]]
        loc0, loc1 = min(p0, naux1 - naux0), min(p1, naux1 - naux0)
        nL = loc1 - loc0
        if nL > 0:
            if aosym_s2[k]:
                segs = numpy.hstack([
                    segs[i0 * nL:i1 * nL].reshape(nL, -1)
                    for i0, i1 in segs_loc_s2
                ])
            else:
                segs = numpy.hstack([
                    segs[i0 * nL:i1 * nL].reshape(nL, -1)
                    for i0, i1 in segs_loc_s1
                ])
            feri['j3c/%d' % k][loc0:loc1] = segs

    mem_now = max(comm.allgather(lib.current_memory()[0]))
    max_memory = max(2000, min(8000, mydf.max_memory - mem_now))
    if numpy.all(aosym_s2):
        if gamma_point(kptij_lst):
            blksize = max(16, int(max_memory * .5e6 / 8 / nao**2))
        else:
            blksize = max(16, int(max_memory * .5e6 / 16 / nao**2))
    else:
        blksize = max(16, int(max_memory * .5e6 / 16 / nao**2 / 2))
    log.debug1('max_momory %d MB (%d in use), blksize %d', max_memory, mem_now,
               blksize)

    t2 = t1
    # save() is run through lib.call_in_background while the next block is
    # being loaded.
    with lib.call_in_background(save) as async_write:
        for k, kptji in enumerate(kptij_lst):
            for p0, p1 in lib.prange(0, segsize, blksize):
                segs = load(k, p0, p1)
                async_write(k, p0, p1, segs)
                t2 = log.timer_debug1(
                    'assemble k=%d %d:%d (in %d)' % (k, p0, p1, segsize), *t2)

    if 'j3c-chunks' in feri: del (feri['j3c-chunks'])
    if 'j3c-kptij' in feri: del (feri['j3c-kptij'])
    feri['j3c-kptij'] = kptij_lst
    t1 = log.alltimer_debug1('assembling j3c', *t1)
    feri.close()
예제 #6
0
def update_lambda(mycc, t1, t2, l1, l2, eris, imds):
    """
    Update GCCSD lambda.

    Only the doubles (l2) residual is assembled in this version: every
    singles (t1/l1) contribution is commented out, and the returned l1 is an
    all-zero array shaped like the input l1.

    The amplitudes are handled in a virtual-leading layout (``t2T``/``l2T``
    are ``t2``/``l2`` transposed with ``(2, 3, 0, 1)``); the first axis of
    ``t2T`` is asserted to have the length of this rank's virtual segment
    ``vlocs[rank]``.

    Args:
        mycc: CC object; ``stdout``, ``verbose`` and ``level_shift`` are read.
        t1, t2: current T amplitudes (t1 is only transposed, never
            contracted, in the active code).
        l1, l2: current lambda amplitudes (l1 only fixes the shape of the
            returned singles array).
        eris: integrals object; ``fock``, ``mo_energy``, ``xvoo`` and
            ``xvvv`` are read.
        imds: intermediates; ``v1``, ``v2``, ``woooo`` and ``wovvo`` are read.

    Returns:
        Tuple ``(l1new, l2new)``: ``l1new`` is zeros, ``l2new`` is the
        updated doubles lambda transposed back with ``(2, 3, 0, 1)``.
    """
    time0 = logger.process_clock(), logger.perf_counter()
    log = logger.Logger(mycc.stdout, mycc.verbose)

    t1T = t1.T
    t2T = np.asarray(t2.transpose(2, 3, 0, 1), order='C')
    # Drop the caller-layout references; only the transposed arrays are used.
    t1 = t2 = None
    nvir_seg, nvir, nocc = t2T.shape[:3]
    l1T = l1.T
    l2T = np.asarray(l2.transpose(2, 3, 0, 1), order='C')
    l1 = l2 = None

    # Virtual-index range assigned to every MPI task; vlocs[rank] is ours.
    ntasks = mpi.pool.size
    vlocs = [_task_location(nvir, task_id) for task_id in range(ntasks)]
    vloc0, vloc1 = vlocs[rank]
    log.debug2('vlocs %s', vlocs)
    assert vloc1 - vloc0 == nvir_seg

    # NOTE(review): fvo is only referenced by the commented-out l1 terms.
    fvo = eris.fock[nocc:, :nocc]
    mo_e_o = eris.mo_energy[:nocc]
    mo_e_v = eris.mo_energy[nocc:] + mycc.level_shift
    # Remove the (level-shifted) orbital energies from the diagonals; they are
    # applied at the end through the denominators built from eia.
    v1 = imds.v1 - np.diag(mo_e_v)
    v2 = imds.v2 - np.diag(mo_e_o)

    # Each rank contracts over its local leading index; the partial results
    # are combined across ranks by allreduce_inplace.
    mba = einsum('cakl, cbkl -> ba', l2T, t2T) * 0.5
    mba = mpi.allreduce_inplace(mba)
    mij = einsum('cdki, cdkj -> ij', l2T, t2T) * 0.5
    mij = mpi.allreduce_inplace(mij)
    # m3 [a]bij
    m3 = einsum('abkl, ijkl -> abij', l2T, np.asarray(imds.woooo))

    # With the t1 term disabled, tauT is simply an alias of t2T.
    tauT = t2T  #+ np.einsum('ai, bj -> abij', t1T[vloc0:vloc1] * 2.0, t1T, optimize=True)
    tmp = einsum('cdij, cdkl -> ijkl', l2T, tauT)
    tmp = mpi.allreduce_inplace(tmp)
    tauT = None

    vvoo = np.asarray(eris.xvoo)
    tmp = einsum('abkl, ijkl -> abij', vvoo, tmp)
    tmp *= 0.25
    m3 += tmp
    tmp = None

    #tmp  = einsum('cdij, dk -> ckij', l2T, t1T)
    #for task_id, tmp, p0, p1 in _rotate_vir_block(tmp, vlocs=vlocs):
    #    m3 -= einsum('kcba, ckij -> abij', eris.ovvx[:, p0:p1], tmp)
    #    tmp = None
    eris_vvvv = eris.xvvv.transpose(2, 3, 0, 1)
    # Filled block by block in the loop below, consumed later as tmp = -tmp_2.
    tmp_2 = np.empty_like(l2T)
    # _rotate_vir_block yields (task_id, block, p0, p1); p0:p1 is used here as
    # the virtual range that the yielded l2T block corresponds to.
    for task_id, l2T_tmp, p0, p1 in _rotate_vir_block(l2T, vlocs=vlocs):
        tmp = einsum('cdij, cdab -> abij', l2T_tmp, eris_vvvv[p0:p1])
        tmp *= 0.5
        m3 += tmp
        tmp_2[:, p0:p1] = einsum('acij, cb -> baij', l2T_tmp, v1[:,
                                                                 vloc0:vloc1])
        tmp = l2T_tmp = None
    eris_vvvv = None

    #l1Tnew = einsum('abij, bj -> ai', m3, t1T)
    #l1Tnew = mpi.allgather(l1Tnew)
    # Singles are never updated below, so this stays zero until returned.
    l1Tnew = np.zeros_like(l1T)
    # l2Tnew aliases m3: every following update modifies the same array.
    l2Tnew = m3

    l2Tnew += vvoo
    #fvo1 = fvo #+ mpi.allreduce(einsum('cbkj, ck -> bj', vvoo, t1T[vloc0:vloc1]))

    #tmp  = np.einsum('ai, bj -> abij', l1T[vloc0:vloc1], fvo1, optimize=True)
    tmp = 0.0
    wvovo = np.asarray(imds.wovvo).transpose(1, 0, 2, 3)
    for task_id, w_tmp, p0, p1 in _rotate_vir_block(wvovo, vlocs=vlocs):
        tmp -= einsum('acki, cjbk -> abij', l2T[:, p0:p1], w_tmp)
        w_tmp = None
    wvovo = None
    # Subtract the copy with the two trailing (occupied) axes swapped.
    tmp = tmp - tmp.transpose(0, 1, 3, 2)
    l2Tnew += tmp
    # Also subtract the copy with the two leading axes swapped. The swapped
    # blocks belong to other ranks' segments, hence the all-to-all exchange.
    tmpT = mpi.alltoall_new([tmp[:, p0:p1] for p0, p1 in vlocs],
                            split_recvbuf=True)
    for task_id, (p0, p1) in enumerate(vlocs):
        tmp = tmpT[task_id].reshape(p1 - p0, nvir_seg, nocc, nocc)
        l2Tnew[:, p0:p1] -= tmp.transpose(1, 0, 2, 3)
        tmp = None

    #tmp  = einsum('ak, ijkb -> baij', l1T, eris.ooox)
    #tmp -= tmp_2
    tmp = -tmp_2
    tmp1vv = mba  #+ np.dot(t1T, l1T.T) # ba

    tmp -= einsum('ca, bcij -> baij', tmp1vv, vvoo)
    l2Tnew += tmp
    # Same leading-axes exchange as above for this contribution.
    tmpT = mpi.alltoall_new([tmp[:, p0:p1] for p0, p1 in vlocs],
                            split_recvbuf=True)
    for task_id, (p0, p1) in enumerate(vlocs):
        tmp = tmpT[task_id].reshape(p1 - p0, nvir_seg, nocc, nocc)
        l2Tnew[:, p0:p1] -= tmp.transpose(1, 0, 2, 3)
        tmp = None

    #tmp  = einsum('jcab, ci -> baji', eris.ovvx, -l1T)
    tmp = einsum('abki, jk -> abij', l2T, v2)
    tmp1oo = mij  #+ np.dot(l1T.T, t1T) # ik
    tmp -= einsum('ik, abkj -> abij', tmp1oo, vvoo)
    vvoo = None
    # Add the term and subtract its occupied-axes-swapped copy (purely local).
    l2Tnew += tmp
    l2Tnew -= tmp.transpose(0, 1, 3, 2)
    tmp = None

    #l1Tnew += fvo
    #tmp = einsum('bj, ibja -> ai', -l1T[vloc0:vloc1], eris.oxov)
    #l1Tnew += np.dot(v1.T, l1T)
    #l1Tnew -= np.dot(l1T, v2.T)
    #tmp -= einsum('cakj, icjk -> ai', l2T, imds.wovoo)
    #tmp -= einsum('bcak, bcik -> ai', imds.wvvvo, l2T)
    #tmp += einsum('baji, bj -> ai', l2T, imds.w3[vloc0:vloc1])

    #tmp_2  = t1T[vloc0:vloc1] - np.dot(tmp1vv[vloc0:vloc1], t1T)
    #tmp_2 -= np.dot(t1T[vloc0:vloc1], mij)
    #tmp_2 += einsum('bcjk, ck -> bj', t2T, l1T)

    #tmp += einsum('baji, bj -> ai', vvoo, tmp_2)
    #tmp_2 = None

    #tmp += einsum('icab, bc -> ai', eris.oxvv, tmp1vv[:, vloc0:vloc1])
    #l1Tnew += mpi.allreduce(tmp)
    #l1Tnew -= mpi.allgather(einsum('jika, kj -> ai', eris.ooox, tmp1oo))

    #tmp = fvo - mpi.allreduce(einsum('bakj, bj -> ak', vvoo, t1T[vloc0:vloc1]))
    #vvoo = None
    #l1Tnew -= np.dot(tmp, mij.T)
    #l1Tnew -= np.dot(mba.T, tmp)

    # eia[i, a] = e_occ[i] - (e_vir[a] + level_shift)
    eia = mo_e_o[:, None] - mo_e_v
    #l1Tnew /= eia.T
    # Divide each local virtual slice (global index i, local i - vloc0) by
    # the denominator assembled from eia for that virtual index.
    for i in range(vloc0, vloc1):
        l2Tnew[i - vloc0] /= lib.direct_sum('i + jb -> bij', eia[:, i], eia)

    time0 = log.timer_debug1('update l1 l2', *time0)
    return l1Tnew.T, l2Tnew.transpose(2, 3, 0, 1)
예제 #7
0
def update_amps(mycc, t1, t2, eris):
    """
    Update GCCD amplitudes.

    Only the doubles (t2) equation contributes to the result: the lines that
    would accumulate into the singles residual are commented out, so the
    returned t1 is an all-zero array shaped like the input t1.

    ``t2`` is processed in a virtual-leading layout (``t2T`` is ``t2``
    transposed with ``(2, 3, 0, 1)``); the first axis of ``t2T`` is asserted
    to have the length of this rank's virtual segment ``vlocs[rank]``.

    Args:
        mycc: CC object; ``stdout``, ``verbose`` and ``level_shift`` are read.
        t1, t2: current amplitudes.
        eris: integrals object; ``fock``, ``mo_energy``, ``xooo``, ``xvvo``
            and ``xvoo`` are read here and ``eris`` is also forwarded to the
            ``cc_*`` intermediate builders.

    Returns:
        Tuple ``(t1new, t2new)``: ``t1new`` is zeros, ``t2new`` is the updated
        doubles amplitude transposed back with ``(2, 3, 0, 1)``.
    """
    time0 = logger.process_clock(), logger.perf_counter()
    log = logger.Logger(mycc.stdout, mycc.verbose)

    t1T = t1.T
    t2T = np.asarray(t2.transpose(2, 3, 0, 1), order='C')
    nvir_seg, nvir, nocc = t2T.shape[:3]
    t2 = None
    # Virtual-index range assigned to every MPI task; vlocs[rank] is ours.
    ntasks = mpi.pool.size
    vlocs = [_task_location(nvir, task_id) for task_id in range(ntasks)]
    vloc0, vloc1 = vlocs[rank]
    log.debug2('vlocs %s', vlocs)
    assert vloc1 - vloc0 == nvir_seg

    fock = eris.fock
    # NOTE(review): fvo is only referenced by the commented-out t1 terms.
    fvo = fock[nocc:, :nocc]
    mo_e_o = eris.mo_energy[:nocc]
    mo_e_v = eris.mo_energy[nocc:] + mycc.level_shift

    tauT_tilde = make_tauT(t1T, t2T, fac=0.5, vlocs=vlocs)
    Fvv = cc_Fvv(t1T, t2T, eris, tauT_tilde=tauT_tilde, vlocs=vlocs)
    Foo = cc_Foo(t1T, t2T, eris, tauT_tilde=tauT_tilde, vlocs=vlocs)
    tauT_tilde = None
    Fov = cc_Fov(t1T, eris, vlocs=vlocs)

    # Move energy terms to the other side
    Fvv[np.diag_indices(nvir)] -= mo_e_v
    Foo[np.diag_indices(nocc)] -= mo_e_o

    # T1 equation
    # Stays zero: nothing below accumulates into t1Tnew.
    t1Tnew = np.zeros_like(t1T)
    #t1Tnew  = np.dot(Fvv, t1T)
    #t1Tnew -= np.dot(t1T, Foo)

    # NOTE(review): tmp/tmp2 are still computed and communicated here, but
    # the result is never added to t1Tnew ("t1Tnew += tmp" is commented out)
    # and tmp is overwritten in the T2 section below.
    tmp = einsum('aeim, me -> ai', t2T, Fov)
    #tmp -= np.einsum('fn, naif -> ai', t1T, eris.oxov, optimize=True)
    tmp = mpi.allgather(tmp)

    #tmp2  = einsum('eamn, mnie -> ai', t2T, eris.ooox)
    tmp2 = einsum('eamn, einm -> ai', t2T, eris.xooo)
    #tmp2 += einsum('efim, mafe -> ai', t2T, eris.ovvx)
    tmp2 += einsum('efim, efam -> ai', t2T, eris.xvvo)
    tmp2 *= 0.5
    tmp2 = mpi.allreduce_inplace(tmp2)
    tmp += tmp2
    tmp2 = None

    #t1Tnew += tmp
    #t1Tnew += fvo

    # T2 equation
    Ftmp = Fvv  #- 0.5 * np.dot(t1T, Fov)
    t2Tnew = einsum('aeij, be -> abij', t2T, Ftmp)
    # Subtract the copy with the two leading axes swapped. The swapped blocks
    # belong to other ranks' segments, hence the all-to-all exchange.
    t2T_tmp = mpi.alltoall_new([t2Tnew[:, p0:p1] for p0, p1 in vlocs],
                               split_recvbuf=True)
    for task_id, (p0, p1) in enumerate(vlocs):
        tmp = t2T_tmp[task_id].reshape(p1 - p0, nvir_seg, nocc, nocc)
        t2Tnew[:, p0:p1] -= tmp.transpose(1, 0, 2, 3)
        tmp = None
    t2T_tmp = None

    Ftmp = Foo  #+ 0.5 * np.dot(Fov, t1T)
    tmp = einsum('abim, mj -> abij', t2T, Ftmp)
    # Subtract the term and add back its occupied-axes-swapped copy (local).
    t2Tnew -= tmp
    t2Tnew += tmp.transpose(0, 1, 3, 2)
    tmp = None

    t2Tnew += np.asarray(eris.xvoo)
    tauT = make_tauT(t1T, t2T, vlocs=vlocs)
    Woooo = cc_Woooo(t1T, t2T, eris, tauT=tauT, vlocs=vlocs)
    Woooo *= 0.5
    t2Tnew += einsum('abmn, mnij -> abij', tauT, Woooo)
    Woooo = None

    Wvvvv = cc_Wvvvv(t1T, t2T, eris, tauT=tauT, vlocs=vlocs)
    # _rotate_vir_block yields (task_id, block, p0, p1); p0:p1 selects the
    # slice of Wvvvv's third axis that pairs with the yielded tauT block.
    for task_id, tauT_tmp, p0, p1 in _rotate_vir_block(tauT, vlocs=vlocs):
        tmp = einsum('abef, efij -> abij', Wvvvv[:, :, p0:p1], tauT_tmp)
        tmp *= 0.5
        t2Tnew += tmp
        tmp = tauT_tmp = None
    Wvvvv = None
    tauT = None

    #tmp = einsum('mbje, ei -> bmij', eris.oxov, t1T) # [b]mij
    #tmp = mpi.allgather(tmp) # bmij
    #tmp = einsum('am, bmij -> abij', t1T[vloc0:vloc1], tmp) # [a]bij
    tmp = 0.0

    Wvovo = cc_Wovvo(t1T, t2T, eris, vlocs=vlocs).transpose(2, 0, 1, 3)
    for task_id, w_tmp, p0, p1 in _rotate_vir_block(Wvovo, vlocs=vlocs):
        tmp += einsum('aeim, embj -> abij', t2T[:, p0:p1], w_tmp)
        w_tmp = None
    Wvovo = None

    # Subtract the occupied-axes-swapped copy locally, then the
    # leading-axes-swapped copy via the same all-to-all exchange as above.
    tmp = tmp - tmp.transpose(0, 1, 3, 2)
    t2Tnew += tmp
    tmpT = mpi.alltoall_new([tmp[:, p0:p1] for p0, p1 in vlocs],
                            split_recvbuf=True)
    for task_id, (p0, p1) in enumerate(vlocs):
        tmp = tmpT[task_id].reshape(p1 - p0, nvir_seg, nocc, nocc)
        t2Tnew[:, p0:p1] -= tmp.transpose(1, 0, 2, 3)
        tmp = None
    tmpT = None

    #tmp = einsum('ei, jeba -> abij', t1T, eris.ovvx)
    #t2Tnew += tmp
    #t2Tnew -= tmp.transpose(0, 1, 3, 2)

    #tmp = einsum('am, ijmb -> baij', t1T, eris.ooox.conj())
    #t2Tnew += tmp
    #tmpT = mpi.alltoall([tmp[:, p0:p1] for p0, p1 in vlocs],
    #                    split_recvbuf=True)
    #for task_id, (p0, p1) in enumerate(vlocs):
    #    tmp = tmpT[task_id].reshape(p1-p0, nvir_seg, nocc, nocc)
    #    t2Tnew[:, p0:p1] -= tmp.transpose(1, 0, 2, 3)
    #    tmp = None
    #tmpT = None

    # eia[i, a] = e_occ[i] - (e_vir[a] + level_shift)
    eia = mo_e_o[:, None] - mo_e_v
    #t1Tnew /= eia.T
    # Divide each local virtual slice (global index i, local i - vloc0) by
    # the denominator assembled from eia for that virtual index.
    for i in range(vloc0, vloc1):
        t2Tnew[i - vloc0] /= lib.direct_sum('i + jb -> bij', eia[:, i], eia)

    time0 = log.timer_debug1('update t1 t2', *time0)
    return t1Tnew.T, t2Tnew.transpose(2, 3, 0, 1)
예제 #8
0
def _make_eris_outcore(mycc, mo_coeff=None):
    """Transform the AO integrals and store the MO blocks in HDF5 datasets.

    Rank 0 initialises the common ERI data (MO coefficients, Fock matrix,
    ``nocc``, MO energies) and broadcasts it.  Every rank then creates its
    own temporary HDF5 file holding ``oooo`` plus the blocks ``oovv``,
    ``ovoo``, ``ovvo``, ``ovov`` and ``vvvo`` restricted to this rank's
    virtual segment ``vlocs[rank]``.

    Args:
        mycc: CC object; ``mol``, ``stdout``, ``verbose``, ``max_memory`` and
            ``direct`` are read.  ``mycc.direct`` must be true.
        mo_coeff: forwarded to ``eris._common_init_`` on rank 0.

    Returns:
        The populated ERI container, which is also stored as ``mycc._eris``.

    Raises:
        AssertionError: if ``mycc.direct`` is false.
    """
    # time.clock() no longer exists on Python >= 3.8; use the logger's clocks,
    # which are what log.timer* measures elapsed time against.
    cput0 = (logger.process_clock(), logger.perf_counter())
    log = logger.Logger(mycc.stdout, mycc.verbose)
    _sync_(mycc)
    eris = ccsd._ChemistsERIs()
    if rank == 0:
        eris._common_init_(mycc, mo_coeff)
        comm.bcast((eris.mo_coeff, eris.fock, eris.nocc, eris.mo_energy))
    else:
        eris.mol = mycc.mol
        eris.mo_coeff, eris.fock, eris.nocc, eris.mo_energy = comm.bcast(None)

    mol = mycc.mol
    mo_coeff = numpy.asarray(eris.mo_coeff, order='F')
    nocc = eris.nocc
    nao, nmo = mo_coeff.shape
    nvir = nmo - nocc
    orbo = mo_coeff[:, :nocc]
    orbv = mo_coeff[:, nocc:]
    # Virtual-index range assigned to every MPI task; vlocs[rank] is ours.
    vlocs = [_task_location(nvir, task_id) for task_id in range(mpi.pool.size)]
    vloc0, vloc1 = vlocs[rank]
    vseg = vloc1 - vloc0

    eris.feri1 = lib.H5TmpFile()
    eris.oooo = eris.feri1.create_dataset('oooo', (nocc, nocc, nocc, nocc),
                                          'f8')
    eris.oovv = eris.feri1.create_dataset('oovv', (nocc, nocc, vseg, nvir),
                                          'f8',
                                          chunks=(nocc, nocc, 1, nvir))
    eris.ovoo = eris.feri1.create_dataset('ovoo', (nocc, vseg, nocc, nocc),
                                          'f8',
                                          chunks=(nocc, 1, nocc, nocc))
    eris.ovvo = eris.feri1.create_dataset('ovvo', (nocc, vseg, nvir, nocc),
                                          'f8',
                                          chunks=(nocc, 1, nvir, nocc))
    eris.ovov = eris.feri1.create_dataset('ovov', (nocc, vseg, nocc, nvir),
                                          'f8',
                                          chunks=(nocc, 1, nocc, nvir))
    # No packed 'ovvv' dataset is created; the block with three virtual
    # indices is stored as 'vvvo' and filled in save_vir_frac.
    eris.vvvo = eris.feri1.create_dataset('vvvo', (vseg, nvir, nvir, nocc),
                                          'f8',
                                          chunks=(1, nvir, 1, nocc))
    assert (mycc.direct)

    def save_occ_frac(p0, p1, eri):
        # Rows p0:p1 of the transformed buffer have an occupied first index.
        eri = eri.reshape(p1 - p0, nocc, nmo, nmo)
        eris.oooo[p0:p1] = eri[:, :, :nocc, :nocc]
        eris.oovv[p0:p1] = eri[:, :, nocc + vloc0:nocc + vloc1, nocc:]

    def save_vir_frac(p0, p1, eri):
        # Rows p0:p1 are local virtual indices (offsets inside this segment).
        log.alldebug1('save_vir_frac %d %d %s', p0, p1, eri.shape)
        eri = eri.reshape(p1 - p0, nocc, nmo, nmo)
        eris.ovoo[:, p0:p1] = eri[:, :, :nocc, :nocc].transpose(1, 0, 2, 3)
        eris.ovvo[:, p0:p1] = eri[:, :, nocc:, :nocc].transpose(1, 0, 2, 3)
        eris.ovov[:, p0:p1] = eri[:, :, :nocc, nocc:].transpose(1, 0, 2, 3)

        cput2 = logger.process_clock(), logger.perf_counter()
        # One slab per destination rank, cut along that rank's virtual range.
        ovvv_segs = [
            eri[:, :, nocc + q0:nocc + q1, nocc:].transpose(2, 3, 0, 1)
            for q0, q1 in vlocs
        ]
        ovvv_segs = mpi.alltoall(ovvv_segs, split_recvbuf=True)
        cput2 = log.timer_debug1('vvvo alltoall', *cput2)
        # Every rank reports the local row range it just sent, so the received
        # slab can be placed at the sender's global virtual offset.
        for task_id, (q0, q1) in enumerate(comm.allgather((p0, p1))):
            ip0 = q0 + vlocs[task_id][0]
            ip1 = q1 + vlocs[task_id][0]
            eris.vvvo[:, :, ip0:ip1] = ovvv_segs[task_id].reshape(
                vseg, nvir, q1 - q0, nocc)

    fswap = lib.H5TmpFile()
    max_memory = max(MEMORYMIN, mycc.max_memory - lib.current_memory()[0])
    int2e = mol._add_suffix('int2e')
    # First-index orbitals: all occupied plus this rank's virtual segment.
    orbov = numpy.hstack((orbo, orbv[:, vloc0:vloc1]))
    ao2mo.outcore.half_e1(mol, (orbov, orbo),
                          fswap,
                          int2e,
                          's4',
                          1,
                          max_memory,
                          verbose=log)

    ao_loc = mol.ao_loc_nr()
    nao_pair = nao * (nao + 1) // 2
    blksize = int(min(8e9, max_memory * .5e6) / 8 / (nao_pair + nmo**2) / nocc)
    blksize = min(nvir, max(BLKMIN, blksize))
    fload = ao2mo.outcore._load_from_h5g

    # Two buffers are swapped so the next block can be read in the background
    # while the current one is being transformed.
    buf = numpy.empty((blksize * nocc, nao_pair))
    buf_prefetch = numpy.empty_like(buf)

    def prefetch(p0, p1, rowmax):
        # Load the block that follows [p0, p1), clipped to rowmax.
        p0, p1 = p1, min(rowmax, p1 + blksize)
        if p0 < p1:
            fload(fswap['0'], p0 * nocc, p1 * nocc, buf_prefetch)

    cput1 = logger.process_clock(), logger.perf_counter()
    outbuf = numpy.empty((blksize * nocc, nmo**2))
    with lib.call_in_background(prefetch) as bprefetch:
        fload(fswap['0'], 0, min(nocc, blksize) * nocc, buf_prefetch)
        for p0, p1 in lib.prange(0, nocc, blksize):
            nrow = (p1 - p0) * nocc
            buf, buf_prefetch = buf_prefetch, buf
            bprefetch(p0, p1, nocc)
            dat = ao2mo._ao2mo.nr_e2(buf[:nrow],
                                     mo_coeff, (0, nmo, 0, nmo),
                                     's4',
                                     's1',
                                     out=outbuf,
                                     ao_loc=ao_loc)
            save_occ_frac(p0, p1, dat)

        # save_vir_frac performs collectives, so all ranks use the same
        # (smallest) block size for the virtual part.
        blksize = min(comm.allgather(blksize))
        norb_max = nocc + vseg
        fload(fswap['0'], nocc**2,
              min(nocc + blksize, norb_max) * nocc, buf_prefetch)
        for p0, p1 in mpi.prange(vloc0, vloc1, blksize):
            i0, i1 = p0 - vloc0, p1 - vloc0
            nrow = (p1 - p0) * nocc
            buf, buf_prefetch = buf_prefetch, buf
            bprefetch(nocc + i0, nocc + i1, norb_max)
            dat = ao2mo._ao2mo.nr_e2(buf[:nrow],
                                     mo_coeff, (0, nmo, 0, nmo),
                                     's4',
                                     's1',
                                     out=outbuf,
                                     ao_loc=ao_loc)
            save_vir_frac(i0, i1, dat)
    # Release the work buffers.  The original misspelled "buf_prefecth" and so
    # never dropped the prefetch buffer.
    buf = buf_prefetch = outbuf = None

    cput1 = log.timer_debug1('transforming oppp', *cput1)
    log.timer('CCSD integral transformation', *cput0)
    mycc._eris = eris
    return eris
예제 #9
0
def update_amps(mycc, t1, t2, eris):
    """Compute updated CCSD t1/t2 amplitudes with MPI-distributed virtuals.

    ``t2`` is processed in a virtual-leading layout (``t2T`` is ``t2``
    transposed with ``(2, 3, 0, 1)``); the first axis of ``t2T`` is asserted
    to have the length of this rank's virtual segment ``vlocs[rank]``.
    Intermediates are staged in a temporary HDF5 file and the integral blocks
    are read from ``eris`` in chunks sized from ``mycc.max_memory``.

    Args:
        mycc: CC object; ``stdout``, ``verbose``, ``level_shift``,
            ``max_memory``, ``direct`` and ``_add_vvvv`` are used.
        t1, t2: current amplitudes.
        eris: integrals object; ``fock``, ``mo_energy``, ``oooo``, ``ovoo``,
            ``oovv``, ``ovvo`` and ``vvvo`` are read.

    Returns:
        Tuple ``(t1new, t2new)``; ``t2new`` is transposed back with
        ``(2, 3, 0, 1)``.

    Raises:
        NotImplementedError: raised inside the vvvo pass when ``mycc.direct``
            is false.
    """
    # time.clock() no longer exists on Python >= 3.8; use the logger's clocks,
    # which are what log.timer* measures elapsed time against.
    time1 = time0 = logger.process_clock(), logger.perf_counter()
    log = logger.Logger(mycc.stdout, mycc.verbose)

    t1T = t1.T
    t2T = numpy.asarray(t2.transpose(2, 3, 0, 1), order='C')
    nvir_seg, nvir, nocc = t2T.shape[:3]
    t1 = t2 = None
    # Virtual-index range assigned to every MPI task; vlocs[rank] is ours.
    ntasks = mpi.pool.size
    vlocs = [_task_location(nvir, task_id) for task_id in range(ntasks)]
    vloc0, vloc1 = vlocs[rank]
    log.debug2('vlocs %s', vlocs)
    assert (vloc1 - vloc0 == nvir_seg)

    fock = eris.fock
    mo_e_o = eris.mo_energy[:nocc]
    mo_e_v = eris.mo_energy[nocc:] + mycc.level_shift

    def _rotate_vir_block(buf):
        # Pair each block from _rotate_tensor_block with its task's range.
        for task_id, buf in _rotate_tensor_block(buf):
            loc0, loc1 = vlocs[task_id]
            yield task_id, buf, loc0, loc1

    fswap = lib.H5TmpFile()
    wVooV = numpy.zeros((nvir_seg, nocc, nocc, nvir))
    eris_voov = _cp(eris.ovvo).transpose(1, 0, 3, 2)
    tau = t2T * .5
    tau += numpy.einsum('ai,bj->abij', t1T[vloc0:vloc1], t1T)
    for task_id, tau, p0, p1 in _rotate_vir_block(tau):
        wVooV += lib.einsum('bkic,cajk->bija', eris_voov[:, :, :, p0:p1], tau)
    fswap['wVooV1'] = wVooV
    wVooV = tau = None
    time1 = log.timer_debug1('wVooV', *time1)

    # wVOov aliases eris_voov, so the loop below accumulates into it in place.
    wVOov = eris_voov
    eris_VOov = eris_voov - eris_voov.transpose(0, 2, 1, 3) * .5
    tau = t2T.transpose(2, 0, 3, 1) - t2T.transpose(3, 0, 2, 1) * .5
    tau -= numpy.einsum('ai,bj->jaib', t1T[vloc0:vloc1], t1T)
    for task_id, tau, p0, p1 in _rotate_vir_block(tau):
        wVOov += lib.einsum('dlkc,kcjb->dljb', eris_VOov[:, :, :, p0:p1], tau)
    fswap['wVOov1'] = wVOov
    wVOov = tau = eris_VOov = eris_voov = None
    time1 = log.timer_debug1('wVOov', *time1)

    t1Tnew = numpy.zeros_like(t1T)
    t2Tnew = mycc._add_vvvv(t1T, t2T, eris, t2sym='jiba')
    time1 = log.timer_debug1('vvvv', *time1)

    #** make_inter_F
    fov = fock[:nocc, nocc:].copy()
    t1Tnew += fock[nocc:, :nocc]

    foo = fock[:nocc, :nocc] - numpy.diag(mo_e_o)
    foo += .5 * numpy.einsum('ia,aj->ij', fock[:nocc, nocc:], t1T)

    fvv = fock[nocc:, nocc:] - numpy.diag(mo_e_v)
    fvv -= .5 * numpy.einsum('ai,ib->ab', t1T, fock[:nocc, nocc:])

    # Per-rank partial sums; combined with mpi.allreduce further down.
    foo_priv = numpy.zeros_like(foo)
    fov_priv = numpy.zeros_like(fov)
    fvv_priv = numpy.zeros_like(fvv)
    t1T_priv = numpy.zeros_like(t1T)

    max_memory = mycc.max_memory - lib.current_memory()[0]
    unit = nocc * nvir**2 * 3 + nocc**2 * nvir + 1
    blksize = min(nvir,
                  max(BLKMIN, int((max_memory * .9e6 / 8 - t2T.size) / unit)))
    log.debug1('pass 1, max_memory %d MB,  nocc,nvir = %d,%d  blksize = %d',
               max_memory, nocc, nvir, blksize)

    buf = numpy.empty((blksize, nvir, nvir, nocc))

    def load_vvvo(p0):
        # Read the vvvo block starting at local row p0 into the current buf.
        p1 = min(nvir_seg, p0 + blksize)
        if p0 < p1:
            buf[:p1 - p0] = eris.vvvo[p0:p1]

    fswap.create_dataset('wVooV', (nvir_seg, nocc, nocc, nvir), 'f8')
    wVOov = []

    with lib.call_in_background(load_vvvo) as prefetch:
        load_vvvo(0)
        for p0, p1 in lib.prange(vloc0, vloc1, blksize):
            i0, i1 = p0 - vloc0, p1 - vloc0
            # Hand the filled buffer to this iteration and rebind buf to a
            # fresh array so the background prefetch fills the new one.
            eris_vvvo, buf = buf[:p1 - p0], numpy.empty_like(buf)
            prefetch(i1)

            fvv_priv[p0:p1] += 2 * numpy.einsum('ck,abck->ab', t1T, eris_vvvo)
            fvv_priv -= numpy.einsum('ck,cabk->ab', t1T[p0:p1], eris_vvvo)

            if not mycc.direct:
                # The vvvo*tau contribution for the non-direct algorithm is
                # not available; the original carried only unreachable code
                # after this raise, which has been dropped.
                raise NotImplementedError

            fswap['wVooV'][i0:i1] = lib.einsum('cj,baci->bija', -t1T,
                                               eris_vvvo)

            theta = t2T[i0:i1].transpose(0, 2, 1, 3) * 2
            theta -= t2T[i0:i1].transpose(0, 3, 1, 2)
            t1T_priv += lib.einsum('bicj,bacj->ai', theta, eris_vvvo)
            wVOov.append(lib.einsum('acbi,cj->abij', eris_vvvo, t1T))
            theta = eris_vvvo = None
            time1 = log.timer_debug1('vvvo [%d:%d]' % (p0, p1), *time1)

    # Exchange the column blocks between ranks, then store the reassembled
    # array with its first axis moved to the end.
    wVOov = numpy.vstack(wVOov)
    wVOov = mpi.alltoall([wVOov[:, q0:q1] for q0, q1 in vlocs],
                         split_recvbuf=True)
    wVOov = numpy.vstack([x.reshape(-1, nvir_seg, nocc, nocc) for x in wVOov])
    fswap['wVOov'] = wVOov.transpose(1, 2, 3, 0)
    # The data now lives in fswap; release the in-memory copy before sizing
    # pass 2.  (The original reset wVooV here, which was already None.)
    wVOov = None

    unit = nocc**2 * nvir * 7 + nocc**3 + nocc * nvir**2
    max_memory = max(0, mycc.max_memory - lib.current_memory()[0])
    blksize = min(nvir,
                  max(BLKMIN, int((max_memory * .9e6 / 8 - nocc**4) / unit)))
    log.debug1('pass 2, max_memory %d MB,  nocc,nvir = %d,%d  blksize = %d',
               max_memory, nocc, nvir, blksize)

    woooo = numpy.zeros((nocc, nocc, nocc, nocc))

    for p0, p1 in lib.prange(vloc0, vloc1, blksize):
        i0, i1 = p0 - vloc0, p1 - vloc0
        wVOov = fswap['wVOov'][i0:i1]
        wVooV = fswap['wVooV'][i0:i1]
        eris_ovoo = eris.ovoo[:, i0:i1]
        eris_oovv = numpy.empty((nocc, nocc, i1 - i0, nvir))

        def load_oovv(p0, p1):
            eris_oovv[:] = eris.oovv[:, :, p0:p1]

        # oovv is read in the background while the ovoo terms are evaluated;
        # leaving the with-block waits for the read to finish.
        with lib.call_in_background(load_oovv) as prefetch_oovv:
            #:eris_oovv = eris.oovv[:,:,i0:i1]
            prefetch_oovv(i0, i1)
            foo_priv += numpy.einsum('ck,kcji->ij', 2 * t1T[p0:p1], eris_ovoo)
            foo_priv += numpy.einsum('ck,icjk->ij', -t1T[p0:p1], eris_ovoo)
            tmp = lib.einsum('al,jaik->lkji', t1T[p0:p1], eris_ovoo)
            woooo += tmp + tmp.transpose(1, 0, 3, 2)
            tmp = None

            wVOov -= lib.einsum('jbik,ak->bjia', eris_ovoo, t1T)
            t2Tnew[i0:i1] += wVOov.transpose(0, 3, 1, 2)

            wVooV += lib.einsum('kbij,ak->bija', eris_ovoo, t1T)
            eris_ovoo = None
        load_oovv = prefetch_oovv = None

        eris_ovvo = numpy.empty((nocc, i1 - i0, nvir, nocc))

        def load_ovvo(p0, p1):
            eris_ovvo[:] = eris.ovvo[:, p0:p1]

        with lib.call_in_background(load_ovvo) as prefetch_ovvo:
            #:eris_ovvo = eris.ovvo[:,i0:i1]
            prefetch_ovvo(i0, i1)
            t1T_priv[p0:p1] -= numpy.einsum('bj,jiab->ai', t1T, eris_oovv)
            wVooV -= eris_oovv.transpose(2, 0, 1, 3)
            wVOov += wVooV * .5  #: bjia + bija*.5
        eris_voov = eris_ovvo.transpose(1, 0, 3, 2)
        eris_ovvo = None
        load_ovvo = prefetch_ovvo = None

        def update_wVooV(i0, i1):
            # Merge this block into the first-stage intermediates on disk.
            wVooV[:] += fswap['wVooV1'][i0:i1]
            fswap['wVooV1'][i0:i1] = wVooV
            wVOov[:] += fswap['wVOov1'][i0:i1]
            fswap['wVOov1'][i0:i1] = wVOov

        # Use a separate name for the background handle so it does not shadow
        # the function it wraps.
        with lib.call_in_background(update_wVooV) as async_update:
            async_update(i0, i1)
            t2Tnew[i0:i1] += eris_voov.transpose(0, 3, 1, 2) * .5
            t1T_priv[p0:p1] += 2 * numpy.einsum('bj,aijb->ai', t1T, eris_voov)

            tmp = lib.einsum('ci,kjbc->bijk', t1T, eris_oovv)
            tmp += lib.einsum('bjkc,ci->bjik', eris_voov, t1T)
            t2Tnew[i0:i1] -= lib.einsum('bjik,ak->baji', tmp, t1T)
            eris_oovv = tmp = None

            fov_priv[:,
                     p0:p1] += numpy.einsum('ck,aikc->ia', t1T, eris_voov) * 2
            fov_priv[:, p0:p1] -= numpy.einsum('ck,akic->ia', t1T, eris_voov)

            tau = numpy.einsum('ai,bj->abij', t1T[p0:p1] * .5, t1T)
            tau += t2T[i0:i1]
            theta = tau.transpose(0, 1, 3, 2) * 2
            theta -= tau
            fvv_priv -= lib.einsum('caij,cjib->ab', theta, eris_voov)
            foo_priv += lib.einsum('aikb,abkj->ij', eris_voov, theta)
            tau = theta = None

            tau = t2T[i0:i1] + numpy.einsum('ai,bj->abij', t1T[p0:p1], t1T)
            woooo += lib.einsum('abij,aklb->ijkl', tau, eris_voov)
            tau = None
        eris_VOov = wVOov = wVooV = update_wVooV = async_update = None
        time1 = log.timer_debug1('voov [%d:%d]' % (p0, p1), *time1)

    wVooV = _cp(fswap['wVooV1'])
    for task_id, wVooV, p0, p1 in _rotate_vir_block(wVooV):
        tmp = lib.einsum('ackj,ckib->ajbi', t2T[:, p0:p1], wVooV)
        t2Tnew += tmp.transpose(0, 2, 3, 1)
        t2Tnew += tmp.transpose(0, 2, 1, 3) * .5
    wVooV = tmp = None
    time1 = log.timer_debug1('contracting wVooV', *time1)

    wVOov = _cp(fswap['wVOov1'])
    theta = t2T * 2
    theta -= t2T.transpose(0, 1, 3, 2)
    for task_id, wVOov, p0, p1 in _rotate_vir_block(wVOov):
        t2Tnew += lib.einsum('acik,ckjb->abij', theta[:, p0:p1], wVOov)
    wVOov = theta = None
    fswap = None
    time1 = log.timer_debug1('contracting wVOov', *time1)

    # Combine the per-rank partial intermediates.
    foo += mpi.allreduce(foo_priv)
    fov += mpi.allreduce(fov_priv)
    fvv += mpi.allreduce(fvv_priv)

    theta = t2T.transpose(0, 1, 3, 2) * 2 - t2T
    t1T_priv[vloc0:vloc1] += numpy.einsum('jb,abji->ai', fov, theta)
    ovoo = _cp(eris.ovoo)
    for task_id, ovoo, p0, p1 in _rotate_vir_block(ovoo):
        t1T_priv[vloc0:vloc1] -= lib.einsum('jbki,abjk->ai', ovoo,
                                            theta[:, p0:p1])
    theta = ovoo = None

    woooo = mpi.allreduce(woooo)
    woooo += _cp(eris.oooo).transpose(0, 2, 1, 3)
    tau = t2T + numpy.einsum('ai,bj->abij', t1T[vloc0:vloc1], t1T)
    t2Tnew += .5 * lib.einsum('abkl,ijkl->abij', tau, woooo)
    tau = woooo = None

    t1Tnew += mpi.allreduce(t1T_priv)

    ft_ij = foo + numpy.einsum('aj,ia->ij', .5 * t1T, fov)
    ft_ab = fvv - numpy.einsum('ai,ib->ab', .5 * t1T, fov)
    t2Tnew += lib.einsum('acij,bc->abij', t2T, ft_ab)
    t2Tnew -= lib.einsum('ki,abkj->abij', ft_ij, t2T)

    # eia[i, a] = e_occ[i] - (e_vir[a] + level_shift)
    eia = mo_e_o[:, None] - mo_e_v
    t1Tnew += numpy.einsum('bi,ab->ai', t1T, fvv)
    t1Tnew -= numpy.einsum('aj,ji->ai', t1T, foo)
    t1Tnew /= eia.T

    # Add the copy of t2Tnew with both index pairs swapped.  The swapped
    # virtual blocks live on other ranks, hence the all-to-all exchange.
    t2tmp = mpi.alltoall([t2Tnew[:, p0:p1] for p0, p1 in vlocs],
                         split_recvbuf=True)
    for task_id, (p0, p1) in enumerate(vlocs):
        tmp = t2tmp[task_id].reshape(p1 - p0, nvir_seg, nocc, nocc)
        t2Tnew[:, p0:p1] += tmp.transpose(1, 0, 3, 2)

    # Divide each local virtual slice by its orbital-energy denominator.
    for i in range(vloc0, vloc1):
        t2Tnew[i - vloc0] /= lib.direct_sum('i+jb->bij', eia[:, i], eia)

    time0 = log.timer_debug1('update t1 t2', *time0)
    return t1Tnew.T, t2Tnew.transpose(2, 3, 0, 1)
예제 #10
0
def _add_vvvv_tril(mycc, t1T, t2T, eris, out=None, with_ovvv=None):
    '''Ht2 = numpy.einsum('ijcd,acdb->ijab', t2, vvvv)
    Using symmetry t2[ijab] = t2[jiba] and Ht2[ijab] = Ht2[jiba], compute the
    lower triangular part of  Ht2

    Args:
        mycc: CC object; ``stdout``, ``verbose``, ``direct``, ``mol`` and
            (when ``eris`` has no ``mo_coeff``) ``mo_coeff`` are read.
        t1T: singles in virtual-leading layout, or None to use ``t2T`` alone
            instead of ``t2T + t1T (x) t1T``.
        t2T: doubles in virtual-leading layout; its first axis is this rank's
            virtual segment.
        eris: integrals object; only ``mo_coeff`` is looked up here.
        out: optional buffer backing the returned array.
        with_ovvv: whether to add the t1-dependent corrections; defaults to
            ``mycc.direct`` when None.

    Returns:
        Array of shape ``(nvir_seg, nvir, nocc*(nocc+1)//2)`` holding the
        packed lower-triangular occupied pairs.

    Raises:
        NotImplementedError: if ``mycc.direct`` is false.
    '''
    # time.clock() no longer exists on Python >= 3.8; use the logger's clocks,
    # which are what log.timer* measures elapsed time against.
    time0 = logger.process_clock(), logger.perf_counter()
    log = logger.Logger(mycc.stdout, mycc.verbose)
    if with_ovvv is None:
        with_ovvv = mycc.direct
    nvir_seg, nvir, nocc = t2T.shape[:3]
    vloc0, vloc1 = _task_location(nvir, rank)
    nocc2 = nocc * (nocc + 1) // 2
    # Pack the occupied pair (i, j) into its lower triangle.
    if t1T is None:
        tau = lib.pack_tril(t2T.reshape(nvir_seg * nvir, nocc, nocc))
    else:
        tau = t2T + numpy.einsum('ai,bj->abij', t1T[vloc0:vloc1], t1T)
        tau = lib.pack_tril(tau.reshape(nvir_seg * nvir, nocc, nocc))
    tau = tau.reshape(nvir_seg, nvir, nocc2)

    if mycc.direct:  # AO-direct CCSD
        mo = getattr(eris, 'mo_coeff', None)
        if mo is None:  # If eris does not have the attribute mo_coeff
            mo = _mo_without_core(mycc, mycc.mo_coeff)

        ao_loc = mycc.mol.ao_loc_nr()
        orbv = mo[:, nocc:]
        nao, nvir = orbv.shape

        # AO range handled by this rank, from a shell partition over tasks.
        ntasks = mpi.pool.size
        task_sh_locs = lib.misc._balanced_partition(ao_loc, ntasks)
        ao_loc0 = ao_loc[task_sh_locs[rank]]
        ao_loc1 = ao_loc[task_sh_locs[rank + 1]]

        # Second virtual index -> AO locally; the first (distributed) virtual
        # index is transformed while the blocks are passed between ranks.
        tau = lib.einsum('pb,abx->apx', orbv, tau)
        tau_priv = numpy.zeros((ao_loc1 - ao_loc0, nao, nocc2))
        for task_id, tau in _rotate_tensor_block(tau):
            loc0, loc1 = _task_location(nvir, task_id)
            tau_priv += lib.einsum('pa,abx->pbx', orbv[ao_loc0:ao_loc1,
                                                       loc0:loc1], tau)
        tau = None
        time1 = log.timer_debug1('vvvv-tau mo2ao', *time0)

        buf = _contract_vvvv_t2(mycc, None, tau_priv, task_sh_locs, None, log)
        # buf_ao keeps the AO-basis result for the optional t1 terms below.
        buf = buf_ao = buf.reshape(tau_priv.shape)
        tau_priv = None
        time1 = log.timer_debug1('vvvv-tau contraction', *time1)

        # Back-transform: second AO index locally, first one while rotating.
        buf = lib.einsum('apx,pb->abx', buf, orbv)
        Ht2tril = numpy.ndarray((nvir_seg, nvir, nocc2), buffer=out)
        Ht2tril[:] = 0
        for task_id, buf in _rotate_tensor_block(buf):
            ao_loc0 = ao_loc[task_sh_locs[task_id]]
            ao_loc1 = ao_loc[task_sh_locs[task_id + 1]]
            Ht2tril += lib.einsum('pa,pbx->abx', orbv[ao_loc0:ao_loc1,
                                                      vloc0:vloc1], buf)
        time1 = log.timer_debug1('vvvv-tau ao2mo', *time1)

        if with_ovvv:
            #: tmp = numpy.einsum('ijcd,ak,kdcb->ijba', tau, t1T, eris.ovvv)
            #: t2new -= tmp + tmp.transpose(1,0,3,2)
            orbo = mo[:, :nocc]
            buf = lib.einsum('apx,pi->axi', buf_ao, orbo)
            tmp = numpy.zeros((nvir_seg, nocc2, nocc))
            for task_id, buf in _rotate_tensor_block(buf):
                ao_loc0 = ao_loc[task_sh_locs[task_id]]
                ao_loc1 = ao_loc[task_sh_locs[task_id + 1]]
                tmp += lib.einsum('pa,pxi->axi', orbv[ao_loc0:ao_loc1,
                                                      vloc0:vloc1], buf)
            Ht2tril -= lib.einsum('axi,bi->abx', tmp, t1T)
            tmp = buf = None

            t1_ao = numpy.dot(orbo, t1T[vloc0:vloc1].T)
            buf = lib.einsum('apx,pb->abx', buf_ao, orbv)
            for task_id, buf in _rotate_tensor_block(buf):
                ao_loc0 = ao_loc[task_sh_locs[task_id]]
                ao_loc1 = ao_loc[task_sh_locs[task_id + 1]]
                Ht2tril -= lib.einsum('pa,pbx->abx', t1_ao[ao_loc0:ao_loc1],
                                      buf)
        time1 = log.timer_debug1('contracting vvvv-tau', *time0)
    else:
        raise NotImplementedError
    return Ht2tril
예제 #11
0
def _make_eris_incore_ghf(mycc, mo_coeff=None, ao2mofn=None):
    """
    Make physicist-notation ERI blocks with an incore ao2mo on rank 0, for GGHF.

    Rank 0 obtains the MO-basis integrals from ``mycc._scf._eri`` (either by
    transforming them, or, when ``mo_coeff`` is numerically the identity, by
    using them as they are).  Each block is then formed as the difference of
    two index permutations of a chemist-ordered sub-block.  ``oooo`` is
    broadcast to every rank; all other blocks are scattered in slices of
    their leading virtual index, with the slice bounds of each rank given by
    ``_task_location(nvir, task_id)``.

    Args:
        mycc: CC object.  This function reads ``stdout``, ``verbose``,
            ``mol``, ``_scf._eri``, ``save_mem`` and ``remove_h2`` from it and
            stores the result in ``mycc._eris``.
        mo_coeff: forwarded to ``eris._common_init_`` on rank 0.
        ao2mofn: a callable is not supported here.

    Returns:
        The ``gccsd._PhysicistsERIs`` object with the attributes ``oooo``,
        ``xooo``, ``xovo``, ``xvoo``, ``xovv``, ``xvvo`` and ``xvvv`` set.

    Raises:
        NotImplementedError: on rank 0 if ``ao2mofn`` is callable.
        AssertionError: on rank 0 if ``eris.mo_coeff`` is not of dtype double.
    """
    cput0 = (logger.process_clock(), logger.perf_counter())
    log = logger.Logger(mycc.stdout, mycc.verbose)
    _sync_(mycc)
    eris = gccsd._PhysicistsERIs()
    
    # Only rank 0 runs the common initialization; the quantities the other
    # ranks need are broadcast to them.
    if rank == 0:
        eris._common_init_(mycc, mo_coeff)
        comm.bcast((eris.mo_coeff, eris.fock, eris.nocc, eris.mo_energy))
    else:
        eris.mol = mycc.mol
        eris.mo_coeff, eris.fock, eris.nocc, eris.mo_energy = comm.bcast(None)
    
    nocc = eris.nocc
    nao, nmo = eris.mo_coeff.shape

    nvir = nmo - nocc
    # (start, stop) range of the virtual index assigned to every rank.
    vlocs = [_task_location(nvir, task_id) for task_id in range(mpi.pool.size)]
    vloc0, vloc1 = vlocs[rank]
    # NOTE(review): vseg (and hence vloc0/vloc1) is not used again in this function.
    vseg = vloc1 - vloc0
    
    # Rank 0 prepares `eri`, the block extractor `fn` and the occupied/virtual
    # selectors `o`/`v`.  These names are defined on rank 0 only and are only
    # used inside `if rank == 0` branches below.
    if rank == 0:
        if callable(ao2mofn):
            raise NotImplementedError
        else:
            assert eris.mo_coeff.dtype == np.double
            eri = mycc._scf._eri
            if (nao == nmo) and (la.norm(eris.mo_coeff - np.eye(nmo)) < 1e-12):
                # ZHC NOTE special treatment for OO-CCD,
                # where the ao2mo is not needed for identity mo_coeff.
                # `o`/`v` are index arrays here; take_eri presumably picks the
                # sub-block addressed by them -- confirm in libdmet.utils.
                from libdmet.utils import take_eri as fn
                o = np.arange(0, nocc)
                v = np.arange(nocc, nmo)
                # A full nmo**4 array is repacked to the 8-fold form.
                if eri.size == nmo**4:
                    eri = ao2mo.restore(8, eri, nmo)
            else:
                if mycc.save_mem:
                    # ZHC NOTE the following is slower, although may save some memory.
                    # Here every block is transformed from the untransformed
                    # `eri` on demand, so `o`/`v` are coefficient column
                    # slices rather than index arrays.
                    def fn(x, mo0, mo1, mo2, mo3):
                        return ao2mo.general(x, (mo0, mo1, mo2, mo3),
                                             compact=False).reshape(mo0.shape[-1], mo1.shape[-1],
                                                                    mo2.shape[-1], mo3.shape[-1])
                    o = eris.mo_coeff[:, :nocc]
                    v = eris.mo_coeff[:, nocc:]
                    if eri.size == nao**4:
                        eri = ao2mo.restore(8, eri, nao)
                else:
                    # Transform the whole tensor once, then select sub-blocks
                    # with index arrays.
                    from libdmet.utils import take_eri as fn
                    o = np.arange(0, nocc)
                    v = np.arange(nocc, nmo)
                    if mycc.remove_h2:
                        # Drop the reference held by the SCF object; the local
                        # name `eri` still refers to the array for the
                        # transformation below.
                        mycc._scf._eri = None
                        _release_regs(mycc, remove_h2=True)
                    eri = ao2mo.kernel(eri, eris.mo_coeff)
                    if eri.size == nmo**4:
                        eri = ao2mo.restore(8, eri, nmo)

    comm.Barrier()
    cput2 = log.timer('CCSD ao2mo initialization:     ', *cput0)
    
    # Chunk and scatter.  In every step below `tmp*` holds a chemist-ordered
    # block; temporaries are set to None as soon as they are no longer needed
    # so that the large arrays can be freed.
    
    # 1. oooo
    if rank == 0:
        tmp = fn(eri, o, o, o, o)
        eris.oooo = tmp.transpose(0, 2, 1, 3) - tmp.transpose(0, 2, 3, 1)
        tmp = None
        mpi.bcast(eris.oooo)
    else:
        eris.oooo = mpi.bcast(None)
    cput3 = log.timer('CCSD bcast   oooo:              ', *cput2)
    
    # 2. xooo
    if rank == 0:
        tmp = fn(eri, v, o, o, o)
        # One slice of the leading virtual index per rank.
        eri_sliced = [tmp[p0:p1] for (p0, p1) in vlocs]
    else:
        tmp = None
        eri_sliced = None
    # presumably rank r receives eri_sliced[r]; the role of `data` is not
    # visible here -- see mpi.scatter_new.
    tmp = mpi.scatter_new(eri_sliced, root=0, data=tmp)
    eri_sliced = None
    eris.xooo = tmp.transpose(0, 2, 1, 3) - tmp.transpose(0, 2, 3, 1)
    tmp = None
    cput4 = log.timer('CCSD scatter xooo:              ', *cput3)
    
    # 3. xovo
    if rank == 0:
        tmp_vvoo = fn(eri, v, v, o, o)
        tmp_voov = fn(eri, v, o, o, v)
        # ZHC NOTE need to keep tmp_voov for xvoo
        eri_1 = [tmp_vvoo[p0:p1] for (p0, p1) in vlocs]
        eri_2 = [tmp_voov[p0:p1] for (p0, p1) in vlocs]
    else:
        tmp_vvoo = None
        tmp_voov = None
        eri_1 = None
        eri_2 = None

    tmp_1 = mpi.scatter_new(eri_1, root=0, data=tmp_vvoo)
    eri_1 = None
    tmp_vvoo = None
    
    tmp_2 = mpi.scatter_new(eri_2, root=0, data=tmp_voov)
    eri_2 = None
    tmp_voov = None
    
    # Combines the scattered vvoo slice (tmp_1) with the scattered voov slice (tmp_2).
    eris.xovo = tmp_1.transpose(0, 2, 1, 3) - tmp_2.transpose(0, 2, 3, 1)
    tmp_1 = None
    cput5 = log.timer('CCSD scatter xovo:              ', *cput4)
    
    # 4. xvoo
    # No communication in this step: it reuses the voov slice (tmp_2) that was
    # scattered for xovo above.
    eris.xvoo = tmp_2.transpose(0, 3, 1, 2) - tmp_2.transpose(0, 3, 2, 1)
    tmp_2 = None
    cput6 = log.timer('CCSD scatter xvoo:              ', *cput5)
    
    # 5. 6. xovv, xvvo
    if rank == 0:
        tmp = fn(eri, v, v, o, v)
        eri_sliced = [tmp[p0:p1] for (p0, p1) in vlocs]
    else:
        tmp = None
        eri_sliced = None
    tmp_1 = mpi.scatter_new(eri_sliced, root=0, data=tmp)
    eri_sliced = None
    eris.xovv = tmp_1.transpose(0, 2, 1, 3) - tmp_1.transpose(0, 2, 3, 1)

    # Rank 0 still holds the full vvov array in `tmp`; it is reordered to vovv
    # (as a C-contiguous copy) so that it can be sliced along its new leading
    # virtual index and scattered a second time.
    if rank == 0:
        tmp_2 = np.asarray(tmp.transpose(3, 2, 1, 0), order='C') # vovv
        tmp = None
        eri_sliced = [tmp_2[p0:p1] for (p0, p1) in vlocs]
    else:
        tmp_2 = None
        tmp = None
        eri_sliced = None
    tmp_2 = mpi.scatter_new(eri_sliced, root=0, data=tmp_2)
    eri_sliced = None
    
    eris.xvvo = tmp_1.transpose(0, 3, 1, 2) - tmp_2.transpose(0, 2, 3, 1)
    tmp_1 = None
    tmp_2 = None
    cput7 = log.timer('CCSD scatter xovv, xvvo:        ', *cput6)

    # 7. xvvv
    if rank == 0:
        tmp = fn(eri, v, v, v, v)
        # vvvv is the last block taken from `eri`, so with remove_h2 the
        # integrals are dropped here, before the scatter.
        if mycc.remove_h2:
            eri = None
            if mycc._scf is not None:
                mycc._scf._eri = None
        eri_sliced = [tmp[p0:p1] for (p0, p1) in vlocs]
    else:
        tmp = None
        eri_sliced = None
    tmp = mpi.scatter_new(eri_sliced, root=0, data=tmp)
    eri_sliced = None
    eris.xvvv = tmp.transpose(0, 2, 1, 3) - tmp.transpose(0, 2, 3, 1)
    tmp = None
    eri = None
    # NOTE(review): cput8 is assigned but not used afterwards.
    cput8 = log.timer('CCSD scatter xvvv:              ', *cput7)
    
    mycc._eris = eris
    log.timer('CCSD integral transformation   ', *cput0)
    return eris
예제 #12
0
def _make_eris_incore(mycc, mo_coeff=None, ao2mofn=None):
    """
    Make physicist-notation ERI blocks with an incore ao2mo on every rank.

    The work is split over the first MO index: each rank transforms the
    integrals for its own range ``plocs[rank]`` of that index, forms the
    difference of the two index permutations, and writes the result to the
    file ``gccsd_eri_tmp_<rank>.h5`` in the current working directory.  After
    a barrier, every rank reads from all of these files the rows it needs to
    assemble the full ``oooo`` block and its own virtual slice
    (``vlocs[rank]``) of the ``x***`` blocks.  Each rank finally removes its
    own temporary file.

    NOTE(review): the read phase opens the files written by the other ranks,
    so this presumably requires a working directory shared by all ranks.

    Args:
        mycc: CC object.  This function reads ``stdout``, ``verbose``,
            ``mol``, ``max_memory`` and ``_scf._eri`` from it and stores the
            result in ``mycc._eris``.
        mo_coeff: forwarded to ``eris._common_init_`` on rank 0.
        ao2mofn: not used by this function.

    Returns:
        The ``gccsd._PhysicistsERIs`` object with the attributes ``oooo``,
        ``xooo``, ``xovo``, ``xovv``, ``xvvo``, ``xvoo`` and ``xvvv`` set.
    """
    cput0 = (logger.process_clock(), logger.perf_counter())
    log = logger.Logger(mycc.stdout, mycc.verbose)
    _sync_(mycc)
    eris = gccsd._PhysicistsERIs()

    # Only rank 0 runs the common initialization; the quantities the other
    # ranks need are broadcast to them.
    if rank == 0:
        eris._common_init_(mycc, mo_coeff)
        comm.bcast((eris.mo_coeff, eris.fock, eris.nocc, eris.mo_energy))
    else:
        eris.mol = mycc.mol
        eris.mo_coeff, eris.fock, eris.nocc, eris.mo_energy = comm.bcast(None)

    # If any rank lacks _eri, broadcast it from the root.
    if comm.allreduce(mycc._scf._eri is None, op=mpi.MPI.LOR):
        if rank == 0:
            mpi.bcast(mycc._scf._eri)
        else:
            mycc._scf._eri = mpi.bcast(None)
    cput1 = log.timer('CCSD ao2mo initialization:     ', *cput0)

    nocc = eris.nocc
    nao, nmo = eris.mo_coeff.shape
    nvir = nmo - nocc
    # (start, stop) range of the virtual index assigned to every rank.
    vlocs = [_task_location(nvir, task_id) for task_id in range(mpi.pool.size)]
    vloc0, vloc1 = vlocs[rank]
    vseg = vloc1 - vloc0

    # (start, stop) range of the full MO index assigned to every rank.
    plocs = [_task_location(nmo, task_id) for task_id in range(mpi.pool.size)]
    ploc0, ploc1 = plocs[rank]
    pseg = ploc1 - ploc0

    # The coefficient rows are split into two halves (named a/b; presumably
    # the two spin components of the GHF orbitals).
    mo_a = eris.mo_coeff[:nao//2]
    mo_b = eris.mo_coeff[nao//2:]
    mo_seg_a = mo_a[:, ploc0:ploc1]
    mo_seg_b = mo_b[:, ploc0:ploc1]

    fname = "gccsd_eri_tmp_%s.h5"%rank
    # The context manager closes the file even if the transformation raises.
    with h5py.File(fname, 'w') as f:
        eri_phys = f.create_dataset('eri_phys', (pseg, nmo, nmo, nmo), 'f8',
                                    chunks=(pseg, 1, nmo, nmo))

        # First half transformation; the first index is restricted to this
        # rank's MO segment.
        eri_a = ao2mo.incore.half_e1(mycc._scf._eri, (mo_seg_a, mo_a), compact=False)
        eri_b = ao2mo.incore.half_e1(mycc._scf._eri, (mo_seg_b, mo_b), compact=False)
        cput1 = log.timer('CCSD ao2mo half_e1:            ', *cput1)

        # Block size over the third index, derived from the memory still
        # available; never below BLKMIN and never above nmo.
        unit = pseg * nmo * nmo * 2
        mem_now = lib.current_memory()[0]
        max_memory = max(0, mycc.max_memory - mem_now)
        blksize = min(nmo, max(BLKMIN, int((max_memory*0.9e6/8)/unit)))

        for p0, p1 in lib.prange(0, nmo, blksize):
            klmosym_a, nkl_pair_a, mokl_a, klshape_a = \
                    ao2mo.incore._conc_mos(mo_a[:, p0:p1], mo_a, compact=False)
            klmosym_b, nkl_pair_b, mokl_b, klshape_b = \
                    ao2mo.incore._conc_mos(mo_b[:, p0:p1], mo_b, compact=False)

            # Second half transformation, summed over all four combinations
            # of the a/b halves.
            eri  = _ao2mo.nr_e2(eri_a, mokl_a, klshape_a, aosym='s4', mosym=klmosym_a)
            eri += _ao2mo.nr_e2(eri_a, mokl_b, klshape_b, aosym='s4', mosym=klmosym_b)
            eri += _ao2mo.nr_e2(eri_b, mokl_a, klshape_a, aosym='s4', mosym=klmosym_a)
            eri += _ao2mo.nr_e2(eri_b, mokl_b, klshape_b, aosym='s4', mosym=klmosym_b)

            eri = eri.reshape(pseg, nmo, p1-p0, nmo)
            eri_phys[:, p0:p1] = eri.transpose(0, 2, 1, 3) - eri.transpose(0, 2, 3, 1)
            eri = None
        eri_a = None
        eri_b = None

    # Every file must be complete before any rank starts reading.
    comm.Barrier()
    cput1 = log.timer('CCSD ao2mo nr_e2:              ', *cput1)

    # o_idx: rank whose segment contains the last occupied orbital;
    # v_idx: rank whose segment contains the first virtual orbital.
    # Files of ranks <= o_idx hold occupied rows, files of ranks >= v_idx
    # hold virtual rows.
    o_idx = -1
    v_idx = mpi.pool.size
    for r, (p0, p1) in enumerate(plocs):
        if p0 <= nocc - 1 < p1:
            o_idx = r
        if p0 <= nocc < p1:
            v_idx = r
            break

    eris.oooo = np.empty((nocc, nocc, nocc, nocc))
    eris.xooo = np.empty((vseg, nocc, nocc, nocc))
    eris.xovo = np.empty((vseg, nocc, nvir, nocc))
    eris.xovv = np.empty((vseg, nocc, nvir, nvir))
    eris.xvvo = np.empty((vseg, nvir, nvir, nocc))
    eris.xvoo = np.empty((vseg, nvir, nocc, nocc))
    eris.xvvv = np.empty((vseg, nvir, nvir, nvir))

    # First virtual orbital of this rank, as a full MO index.
    v_start = nocc + vloc0
    v_stop = nocc + vloc1
    for r in range(mpi.pool.size):
        # Each file is closed as soon as its rows have been copied, instead
        # of leaving all but the last handle to the garbage collector.
        with lib.H5TmpFile(filename="gccsd_eri_tmp_%s.h5"%r, mode='r') as f:
            eri_phys = f["eri_phys"]
            if r <= o_idx:
                # Occupied rows stored in file r.
                p0, p1 = plocs[r]
                p1 = min(p1, nocc)
                nrow = p1 - p0
                if nrow > 0:
                    eris.oooo[p0:p1] = eri_phys[:nrow, :nocc, :nocc, :nocc]

            if r >= v_idx:
                # Overlap between the rows stored in file r and the virtual
                # range owned by this rank.
                p00, p10 = plocs[r]
                p0 = max(p00, v_start)
                p1 = min(p10, v_stop)
                if p1 - p0 > 0:
                    # Row range in the local x*** arrays ...
                    d0 = p0 - v_start
                    d1 = p1 - v_start
                    # ... and the matching row range inside file r.
                    s0 = p0 - p00
                    s1 = p1 - p00
                    eris.xooo[d0:d1] = eri_phys[s0:s1, :nocc, :nocc, :nocc]
                    eris.xovo[d0:d1] = eri_phys[s0:s1, :nocc, nocc:, :nocc]
                    eris.xvoo[d0:d1] = eri_phys[s0:s1, nocc:, :nocc, :nocc]
                    eris.xvvo[d0:d1] = eri_phys[s0:s1, nocc:, nocc:, :nocc]
                    eris.xovv[d0:d1] = eri_phys[s0:s1, :nocc, nocc:, nocc:]
                    eris.xvvv[d0:d1] = eri_phys[s0:s1, nocc:, nocc:, nocc:]
    cput1 = log.timer('CCSD ao2mo load:               ', *cput1)

    # No rank may delete its file while another rank is still reading it.
    comm.Barrier()
    os.remove("gccsd_eri_tmp_%s.h5"%rank)
    mycc._eris = eris
    log.timer('CCSD integral transformation   ', *cput0)
    return eris