Пример #1
0
 def test_index_tril_to_pair(self):
     """Round-trip check: pack (i, j) as i*(i+1)//2 + j, then unpack it."""
     samples = (numpy.random.random((2, 30)) * 100).astype(int)
     # Per column, the larger draw plays the row index and the smaller one
     # the column index, so every pair satisfies i >= j.
     row_ref = samples.max(axis=0)
     col_ref = samples.min(axis=0)
     packed = row_ref * (row_ref + 1) // 2 + col_ref
     row_out, col_out = lib.index_tril_to_pair(packed)
     for expected, actual in ((row_ref, row_out), (col_ref, col_out)):
         self.assertTrue(numpy.all(expected == actual))
Пример #2
0
 def test_index_tril_to_pair(self):
     """Verify lib.index_tril_to_pair recovers (i, j) from i*(i+1)//2 + j."""
     draws = (numpy.random.random((2, 30)) * 100).astype(int)
     # Sorting the two rows column-wise yields (min, max) for every column.
     lo, hi = numpy.sort(draws, axis=0)
     tri_offset = hi * (hi + 1) // 2
     packed = tri_offset + lo
     i1, j1 = lib.index_tril_to_pair(packed)
     rows_match = numpy.all(hi == i1)
     self.assertTrue(rows_match)
     cols_match = numpy.all(lo == j1)
     self.assertTrue(cols_match)
Пример #3
0
def trans_e1_outcore(mol, mo, ncore, ncas, erifile,
                     max_memory=None, level=1, verbose=logger.WARN):
    """Transform AO two-electron integrals in batches and store 'papa'/'ppaa'.

    AO integrals are generated shell-range by shell-range, half transformed
    with ``mo``, and accumulated into an in-memory ``papa`` buffer and a
    temporary HDF5 file holding the ``aapp`` pieces.  A second pass finishes
    the transformation and writes two float64 datasets into the output file:
    ``'papa'`` with shape ``(nmo, ncas, nmo, ncas)`` and ``'ppaa'`` with shape
    ``(nmo, nmo, ncas, ncas)``.

    Args:
        mol: Molecule-like object; this function uses its ``ao_loc_nr()``,
            ``_add_suffix()``, ``_atm``, ``_bas`` and ``_env`` members.
        mo: 2D array of shape ``(nao, nmo)``.
        ncore: Number of leading columns of ``mo`` used for the ``c`` index.
        ncas: Number of columns following ``ncore`` used for the ``a`` index
            (columns ``ncore:ncore+ncas``).
        erifile: An ``h5py.Group`` to write into, or a value handed to
            ``lib.H5TmpFile(erifile, 'w')`` to create the output file.
        max_memory: Memory budget in MB.  ``None`` falls back to
            ``mol.max_memory`` when that attribute exists, else 2000.
        level: When 1, ``j_pc`` and ``k_pc`` are accumulated as well;
            for any other value they are returned as zeros.
        verbose: Verbosity passed to ``logger.new_logger``.

    Returns:
        Tuple ``(j_pc, k_pc)`` of arrays with shape ``(nmo, ncore)``.
    """
    # time.clock() was removed in Python 3.8; use process_time() and only
    # fall back to clock() on interpreters that do not provide it.
    try:
        cpu_clock = time.process_time
    except AttributeError:
        cpu_clock = time.clock
    time0 = (cpu_clock(), time.time())
    log = logger.new_logger(mol, verbose)
    if max_memory is None:
        # The declared default used to fail in the arithmetic below.
        # NOTE(review): assumes mol.max_memory is expressed in MB like this
        # argument -- confirm.  2000 mirrors the floor used for mem_words.
        max_memory = getattr(mol, 'max_memory', 2000)
    log.debug1('trans_e1_outcore level %d  max_memory %d', level, max_memory)
    nao, nmo = mo.shape
    nao_pair = nao*(nao+1)//2
    nocc = ncore + ncas

    faapp_buf = lib.H5TmpFile()
    if isinstance(erifile, h5py.Group):
        feri = erifile
    else:
        feri = lib.H5TmpFile(erifile, 'w')

    # Keep both memory layouts: row slices are taken from mo_c, while the
    # F-ordered copy is what gets passed to the compiled routines.
    mo_c = numpy.asarray(mo, order='C')
    mo = numpy.asarray(mo, order='F')
    pashape = (0, nmo, ncore, nocc)
    papa_buf = numpy.zeros((nao,ncas,nmo*ncas))
    j_pc = numpy.zeros((nmo,ncore))
    k_pc = numpy.zeros((nmo,ncore))

    # Budget left after papa_buf (never below 2000 MB), in 8-byte words.
    mem_words = int(max(2000,max_memory-papa_buf.nbytes/1e6)*1e6/8)
    aobuflen = mem_words//(nao_pair+nocc*nmo) + 1
    ao_loc = numpy.array(mol.ao_loc_nr(), dtype=numpy.int32)
    shranges = outcore.guess_shell_ranges(mol, True, aobuflen, None, ao_loc)
    intor = mol._add_suffix('int2e')
    ao2mopt = _ao2mo.AO2MOpt(mol, intor,
                             'CVHFnr_schwarz_cond', 'CVHFsetnr_direct_scf')
    nstep = len(shranges)
    maxbuflen = max(x[2] for x in shranges)
    log.debug('mem_words %.8g MB, maxbuflen = %d', mem_words*8/1e6, maxbuflen)
    # Work buffers sized for the largest shell range and reused every step.
    bufs1 = numpy.empty((maxbuflen, nao_pair))
    bufs2 = numpy.empty((maxbuflen, nmo*ncas))
    if level == 1:
        bufs3 = numpy.empty((maxbuflen, nao*ncore))
        log.debug('mem cache %.8g MB',
                  (bufs1.nbytes+bufs2.nbytes+bufs3.nbytes)/1e6)
    else:
        log.debug('mem cache %.8g MB', (bufs1.nbytes+bufs2.nbytes)/1e6)
    ti0 = log.timer('Initializing trans_e1_outcore', *time0)

    # fmmm, ftrans, fdrv for level 1
    fmmm = libmcscf.AO2MOmmm_ket_nr_s2
    ftrans = libmcscf.AO2MOtranse1_nr_s4
    fdrv = libmcscf.AO2MOnr_e2_drv
    for istep,sh_range in enumerate(shranges):
        log.debug('[%d/%d], AO [%d:%d], len(buf) = %d',
                  istep+1, nstep, *sh_range)
        buf = bufs1[:sh_range[2]]
        _ao2mo.nr_e1fill(intor, sh_range,
                         mol._atm, mol._bas, mol._env, 's4', 1, ao2mopt, buf)
        # ti1 only exists at DEBUG1 verbosity; every later use is guarded by
        # the same condition.
        if log.verbose >= logger.DEBUG1:
            ti1 = log.timer('AO integrals buffer', *ti0)
        bufpa = bufs2[:sh_range[2]]
        _ao2mo.nr_e1(buf, mo, pashape, 's4', 's1', out=bufpa)
        # jc_pp, kc_pp
        if level == 1: # ppaa, papa and vhf, jcp, kcp
            if log.verbose >= logger.DEBUG1:
                ti1 = log.timer('buffer-pa', *ti1)
            buf1 = bufs3[:sh_range[2]]
            fdrv(ftrans, fmmm,
                 buf1.ctypes.data_as(ctypes.c_void_p),
                 buf.ctypes.data_as(ctypes.c_void_p),
                 mo.ctypes.data_as(ctypes.c_void_p),
                 ctypes.c_int(sh_range[2]), ctypes.c_int(nao),
                 (ctypes.c_int*4)(0, nao, 0, ncore),
                 ctypes.POINTER(ctypes.c_void_p)(), ctypes.c_int(0))
            p0 = 0
            for ij in range(sh_range[0], sh_range[1]):
                i,j = lib.index_tril_to_pair(ij)
                i0 = ao_loc[i]
                j0 = ao_loc[j]
                i1 = ao_loc[i+1]
                j1 = ao_loc[j+1]
                di = i1 - i0
                dj = j1 - j0
                if i == j:
                    # Diagonal shell pair: rows hold only the lower triangle,
                    # so unpack them into a full symmetric (di, di) block.
                    dij = di * (di+1) // 2
                    buf = numpy.empty((di,di,nao*ncore))
                    idx = numpy.tril_indices(di)
                    buf[idx] = buf1[p0:p0+dij]
                    buf[idx[1],idx[0]] = buf1[p0:p0+dij]
                    buf = buf.reshape(di,di,nao,ncore)
                    mo1 = mo_c[i0:i1]
                    tmp = numpy.einsum('uvpc,pc->uvc', buf, mo[:,:ncore])
                    tmp = lib.dot(mo1.T, tmp.reshape(di,-1))
                    j_pc += numpy.einsum('vp,pvc->pc', mo1, tmp.reshape(nmo,di,ncore))
                    tmp = numpy.einsum('uvpc,uc->vcp', buf, mo1[:,:ncore])
                    tmp = lib.dot(tmp.reshape(-1,nmo), mo).reshape(di,ncore,nmo)
                    k_pc += numpy.einsum('vp,vcp->pc', mo1, tmp)
                else:
                    # Off-diagonal shell pair: the (j, i) partner is never
                    # visited, hence the factor 2 for J and the two K terms.
                    dij = di * dj
                    buf = buf1[p0:p0+dij].reshape(di,dj,nao,ncore)
                    mo1 = mo_c[i0:i1]
                    mo2 = mo_c[j0:j1]
                    tmp = numpy.einsum('uvpc,pc->uvc', buf, mo[:,:ncore])
                    tmp = lib.dot(mo1.T, tmp.reshape(di,-1))
                    j_pc += numpy.einsum('vp,pvc->pc',
                                         mo2, tmp.reshape(nmo,dj,ncore)) * 2
                    tmp = numpy.einsum('uvpc,uc->vcp', buf, mo1[:,:ncore])
                    tmp = lib.dot(tmp.reshape(-1,nmo), mo).reshape(dj,ncore,nmo)
                    k_pc += numpy.einsum('vp,vcp->pc', mo2, tmp)
                    tmp = numpy.einsum('uvpc,vc->ucp', buf, mo2[:,:ncore])
                    tmp = lib.dot(tmp.reshape(-1,nmo), mo).reshape(di,ncore,nmo)
                    k_pc += numpy.einsum('up,ucp->pc', mo1, tmp)
                p0 += dij
            if log.verbose >= logger.DEBUG1:
                ti1 = log.timer('j_cp and k_cp', *ti1)

        if log.verbose >= logger.DEBUG1:
            ti1 = log.timer('half transformation of the buffer', *ti1)

        # ppaa, papa
        # Stash the active-active slice of this step, transposed to
        # (ncas**2, n_rows), for the second 'ppaa' pass.
        faapp_buf[str(istep)] = \
                bufpa.reshape(sh_range[2],nmo,ncas)[:,ncore:nocc].reshape(-1,ncas**2).T
        p0 = 0
        for ij in range(sh_range[0], sh_range[1]):
            i,j = lib.index_tril_to_pair(ij)
            i0 = ao_loc[i]
            j0 = ao_loc[j]
            i1 = ao_loc[i+1]
            j1 = ao_loc[j+1]
            di = i1 - i0
            dj = j1 - j0
            if i == j:
                dij = di * (di+1) // 2
                buf1 = numpy.empty((di,di,nmo*ncas))
                idx = numpy.tril_indices(di)
                buf1[idx] = bufpa[p0:p0+dij]
                buf1[idx[1],idx[0]] = bufpa[p0:p0+dij]
            else:
                dij = di * dj
                buf1 = bufpa[p0:p0+dij].reshape(di,dj,-1)
                mo1 = mo[j0:j1,ncore:nocc].copy()
                # Accumulate in place into the rows of papa_buf owned by
                # shell i (alpha=1, beta=1).
                for k in range(di):
                    lib.dot(mo1.T, buf1[k], 1, papa_buf[i0+k], 1)
            mo1 = mo[i0:i1,ncore:nocc].copy()
            buf1 = lib.dot(mo1.T, buf1.reshape(di,-1))
            papa_buf[j0:j1] += buf1.reshape(ncas,dj,-1).transpose(1,0,2)
            p0 += dij
        if log.verbose >= logger.DEBUG1:
            ti1 = log.timer('ppaa and papa buffer', *ti1)

        ti0 = log.timer('gen AO/transform MO [%d/%d]'%(istep+1,nstep), *ti0)
    # Drop references to the large work buffers before the second pass.
    buf = buf1 = bufpa = None
    bufs1 = bufs2 = bufs3 = None
    time1 = log.timer('mc_ao2mo pass 1', *time0)

    log.debug1('Half transformation done. Current memory %d',
               lib.current_memory()[0])

    # Second pass for 'papa': transform the remaining AO index in row blocks.
    nblk = int(max(8, min(nmo, (max_memory*1e6/8-papa_buf.size)/(ncas**2*nmo))))
    log.debug1('nblk for papa = %d', nblk)
    dset = feri.create_dataset('papa', (nmo,ncas,nmo,ncas), 'f8')
    for i0, i1 in prange(0, nmo, nblk):
        tmp = lib.dot(mo[:,i0:i1].T, papa_buf.reshape(nao,-1))
        dset[i0:i1] = tmp.reshape(i1-i0,ncas,nmo,ncas)
    papa_buf = tmp = None
    time1 = log.timer('papa pass 2', *time1)

    # Second pass for 'ppaa': reassemble the per-step slices into one
    # (ncas**2, nao_pair) array, then transform the packed AO pair.
    tmp = numpy.empty((ncas**2,nao_pair))
    p0 = 0
    for istep, sh_range in enumerate(shranges):
        tmp[:,p0:p0+sh_range[2]] = faapp_buf[str(istep)]
        p0 += sh_range[2]
    nblk = int(max(8, min(nmo, (max_memory*1e6/8-tmp.size)/(ncas**2*nmo)-1)))
    log.debug1('nblk for ppaa = %d', nblk)
    dset = feri.create_dataset('ppaa', (nmo,nmo,ncas,ncas), 'f8')
    for i0, i1 in prange(0, nmo, nblk):
        tmp1 = _ao2mo.nr_e2(tmp, mo, (i0,i1,0,nmo), 's4', 's1', ao_loc=ao_loc)
        tmp1 = tmp1.reshape(ncas,ncas,i1-i0,nmo)
        for j in range(i1-i0):
            dset[i0+j] = tmp1[:,:,j].transpose(2,0,1)
    tmp = tmp1 = None
    time1 = log.timer('ppaa pass 2', *time1)

    log.timer('mc_ao2mo', *time0)
    return j_pc, k_pc