Ejemplo n.º 1
0
    def save(count, tmp_xo):
        """Exchange ``tmp_xo`` between MPI ranks and store two layouts in ``ftmp``.

        The leading axis of ``tmp_xo`` is cut along the ranges in ``olocs``,
        the pieces are exchanged with ``mpi.alltoall`` and the received pieces
        are summed and stored under ``'<count>b'``.  That result is then cut
        along axis 1, exchanged again, and the received pieces are stacked and
        stored under ``'<count>a'``.

        Args:
            count: label used (via ``str``) to build the dataset keys.
            tmp_xo: array with at least four axes; axes 2 and 3 give the
                trailing dimensions of the stored blocks.
        """
        di, dj = tmp_xo.shape[2:4]
        label = str(count)

        # First exchange: one slice of the leading axis per entry of olocs.
        outgoing = [tmp_xo[lo:hi] for lo, hi in olocs]
        incoming = mpi.alltoall(outgoing, split_recvbuf=True)
        tmp_xo = sum(incoming).reshape(nocc_seg, nocc, di, dj)
        ftmp[label + 'b'] = tmp_xo

        # Second exchange: slice axis 1 of the summed block the same way.
        outgoing = [tmp_xo[:, lo:hi] for lo, hi in olocs]
        incoming = mpi.alltoall(outgoing, split_recvbuf=True)
        stacked = []
        for idx, (lo, hi) in enumerate(olocs):
            stacked.append(incoming[idx].reshape(hi - lo, nocc_seg, di, dj))
        ftmp[label + 'a'] = numpy.vstack(stacked)
Ejemplo n.º 2
0
    def save(count, tmp_xo):
        """Write the 'b' and 'a' redistributions of ``tmp_xo`` into ``ftmp``.

        ``'<count>b'`` receives the sum of the pieces obtained from an
        ``mpi.alltoall`` over slices of the leading axis; ``'<count>a'``
        receives the vertically stacked pieces obtained from a second
        ``mpi.alltoall`` over slices of axis 1 of that sum.  Slice bounds come
        from ``olocs``.

        Args:
            count: value converted with ``str`` to prefix the dataset names.
            tmp_xo: array whose axes 2 and 3 fix the trailing block shape.
        """
        di, dj = tmp_xo.shape[2:4]

        row_chunks = [tmp_xo[begin:end] for begin, end in olocs]
        row_recv = mpi.alltoall(row_chunks, split_recvbuf=True)
        summed = sum(row_recv).reshape(nocc_seg, nocc, di, dj)
        ftmp[str(count) + 'b'] = summed

        col_chunks = [summed[:, begin:end] for begin, end in olocs]
        col_recv = mpi.alltoall(col_chunks, split_recvbuf=True)
        # Each received buffer is reshaped to the width of the matching range.
        reshaped = [
            col_recv[which].reshape(end - begin, nocc_seg, di, dj)
            for which, (begin, end) in enumerate(olocs)
        ]
        ftmp[str(count) + 'a'] = numpy.vstack(reshaped)
Ejemplo n.º 3
0
def vector_to_amplitudes(vector, nmo, nocc):
    """Unpack a flat amplitude vector into T1 and the local T2 segment.

    On rank 0 the vector holds ``nocc * nvir`` T1 entries followed by the
    packed T2 data; on every other rank it holds only the packed T2 data and
    T1 is obtained through ``mpi.bcast``.  The packed T2 data is expanded with
    ``lib.unpack_tril(..., filltriu=lib.PLAIN)`` and the remaining triangle is
    written from blocks received through ``mpi.alltoall``.

    Args:
        vector: 1D array as described above.
        nmo: number of orbitals; ``nvir`` is taken as ``nmo - nocc``.
        nocc: number of occupied orbitals.

    Returns:
        Tuple ``(t1T.T, t2T.transpose(2, 3, 0, 1))``.
    """
    nvir = nmo - nocc
    nov = nocc * nvir
    nocc2 = nocc * (nocc + 1) // 2
    vlocs = [_task_location(nvir, task_id) for task_id in range(mpi.pool.size)]
    vloc0, vloc1 = vlocs[rank]
    nvir_seg = vloc1 - vloc0

    if rank != 0:
        t1T = mpi.bcast(None)
        packed = vector
    else:
        t1T = vector[:nov].copy().reshape((nvir, nocc))
        mpi.bcast(t1T)
        packed = vector[nov:]
    t2tril = packed.reshape(nvir_seg, nvir, nocc2)

    flat = t2tril.reshape(nvir_seg * nvir, nocc2)
    t2T = lib.unpack_tril(flat, filltriu=lib.PLAIN)
    t2T = t2T.reshape(nvir_seg, nvir, nocc, nocc)

    outgoing = [t2tril[:, lo:hi] for lo, hi in vlocs]
    received = mpi.alltoall(outgoing, split_recvbuf=True)
    idx, idy = numpy.tril_indices(nocc)
    for task_id, (lo, hi) in enumerate(vlocs):
        block = received[task_id].reshape(hi - lo, nvir_seg, nocc2)
        # Write with swapped (idy, idx) so the received block lands on the
        # transposed occupied-pair positions.
        t2T[:, lo:hi, idy, idx] = block.transpose(1, 0, 2)
    return t1T.T, t2T.transpose(2, 3, 0, 1)
Ejemplo n.º 4
0
def alltoall(n, m):
    """Call ``mpi.alltoall`` with rank-dependent shapes and print what comes back.

    Three exchanges are performed: a merged one (prints ``res.shape``), the
    same data with ``split_recvbuf=True`` (prints the shape of every piece),
    and a 1D exchange whose length is 3 on ranks below 3 and 1 otherwise.

    Args:
        n: base size of the first axis of the 2D test arrays.
        m: base size of the second axis of the 2D test arrays.
    """
    import numpy
    from mpi4pyscf.tools import mpi
    # NOTE(review): presumably lowered so the exchange is split into small
    # messages; confirm against mpi4pyscf.tools.mpi.
    mpi.INT_MAX = 7

    arrs = []
    for i in range(mpi.pool.size):
        shift = i - mpi.rank
        arrs.append(numpy.ones((n + shift, m - shift)))

    merged = mpi.alltoall(arrs)
    print(merged.shape)

    pieces = mpi.alltoall(arrs, split_recvbuf=True)
    print([piece.shape for piece in pieces])

    d1 = 3 if mpi.rank < 3 else 1
    arrs = [numpy.zeros(d1) for _ in range(mpi.pool.size)]
    pieces = mpi.alltoall(arrs, split_recvbuf=True)
    print([piece.shape for piece in pieces])
Ejemplo n.º 5
0
 def save(k, p0, p1, segs):
     """Exchange ``segs`` over MPI and write the rows ``loc0:loc1`` of ``j3c/<k>``.

     ``p0`` and ``p1`` are clipped to ``naux1 - naux0``; nothing is written
     when the clipped range is empty.  The column layout of the received data
     is taken from ``segs_loc_s2`` when ``aosym_s2[k]`` is true and from
     ``segs_loc_s1`` otherwise.
     """
     segs = mpi.alltoall(segs)
     naux1 = nauxs[uniq_inverse[k]]
     limit = naux1 - naux0
     loc0 = min(p0, limit)
     loc1 = min(p1, limit)
     nL = loc1 - loc0
     if nL <= 0:
         return

     seg_ranges = segs_loc_s2 if aosym_s2[k] else segs_loc_s1
     columns = []
     for i0, i1 in seg_ranges:
         columns.append(segs[i0*nL:i1*nL].reshape(nL, -1))
     feri['j3c/%d' % k][loc0:loc1] = numpy.hstack(columns)
Ejemplo n.º 6
0
    def start(self, interval=0.02):
        """Build the on-disk ``vvop`` tensor and start the data-serving thread.

        The local slice of ``eris.vvvo`` is exchanged between ranks with
        ``mpi.alltoall`` block by block, combined with ``eris.ovov`` and
        written to the ``'vvop'`` dataset of a temporary HDF5 file kept in
        ``self.eri_tmp``.  Afterwards a thread is started that polls for
        ``INQUIRY`` messages and answers them with the requested tensors.

        Args:
            interval: seconds the serving thread sleeps between polls.

        Returns:
            The started ``threading.Thread``; it exits when it receives a
            request whose task is ``'Done'``.
        """
        mycc = self._cc
        log = logger.new_logger(mycc)
        # time.clock() was removed in Python 3.8; prefer time.process_time()
        # and fall back to time.clock() on interpreters that lack it.
        cpu_clock = getattr(time, 'process_time', None) or time.clock
        cpu1 = (cpu_clock(), time.time())
        eris = mycc._eris
        t2T = mycc.t2.transpose(2,3,0,1)

        nocc, nvir = mycc.t1.shape
        nmo = nocc + nvir
        vloc0, vloc1 = self.vranges[rank]
        nvir_seg = vloc1 - vloc0

        max_memory = min(24000, mycc.max_memory - lib.current_memory()[0])
        blksize = min(nvir_seg//4+1, max(16, int(max_memory*.3e6/8/(nvir*nocc*nmo))))
        self.eri_tmp = lib.H5TmpFile()
        vvop = self.eri_tmp.create_dataset('vvop', (nvir_seg,nvir,nocc,nmo), 'f8')

        def save_vvop(j0, j1, vvvo):
            # Last axis: first nocc columns come from ovov, the rest from the
            # exchanged vvvo blocks (one block per entry of self.vranges).
            buf = numpy.empty((j1-j0,nvir,nocc,nmo), dtype=t2T.dtype)
            buf[:,:,:,:nocc] = eris.ovov[:,j0:j1].conj().transpose(1,3,0,2)
            for k, (q0, q1) in enumerate(self.vranges):
                blk = vvvo[k].reshape(q1-q0,nvir,j1-j0,nocc)
                buf[:,q0:q1,:,nocc:] = blk.transpose(2,0,3,1)
            vvop[j0:j1] = buf

        with lib.call_in_background(save_vvop) as save_vvop:
            for p0, p1 in mpi.prange(vloc0, vloc1, blksize):
                j0, j1 = p0 - vloc0, p1 - vloc0
                # Every rank needs to know which sub-range each peer is
                # currently processing in order to slice what it sends.
                sub_locs = comm.allgather((p0,p1))
                vvvo = mpi.alltoall([eris.vvvo[:,:,q0:q1] for q0, q1 in sub_locs],
                                    split_recvbuf=True)
                save_vvop(j0, j1, vvvo)
                cpu1 = log.timer_debug1('transpose %d:%d'%(p0,p1), *cpu1)

        def send_data():
            # Poll for requests until a 'Done' task arrives.
            while True:
                while comm.Iprobe(source=MPI.ANY_SOURCE, tag=INQUIRY):
                    tensors, dest = comm.recv(source=MPI.ANY_SOURCE, tag=INQUIRY)
                    for task, slices in tensors:
                        if task == 'Done':
                            return
                        else:
                            mpi.send(self._get_tensor(task, slices), dest,
                                     tag=TRANSFER_DATA)
                time.sleep(interval)

        daemon = threading.Thread(target=send_data)
        daemon.start()
        return daemon
Ejemplo n.º 7
0
 def save(k, p0, p1, segs):
     """Store the exchanged ``segs`` into rows ``loc0:loc1`` of ``feri['j3c/<k>']``.

     The row range is ``p0:p1`` clipped to ``naux1 - naux0``; an empty range
     means nothing is written.  ``aosym_s2[k]`` selects whether the received
     buffer is cut according to ``segs_loc_s2`` or ``segs_loc_s1``.
     """
     segs = mpi.alltoall(segs)
     naux1 = nauxs[uniq_inverse[k]]
     naux_avail = naux1 - naux0
     loc0, loc1 = min(p0, naux_avail), min(p1, naux_avail)
     nL = loc1 - loc0
     if nL > 0:
         if aosym_s2[k]:
             layout = segs_loc_s2
         else:
             layout = segs_loc_s1
         pieces = [segs[i0 * nL:i1 * nL].reshape(nL, -1) for i0, i1 in layout]
         segs = numpy.hstack(pieces)
         feri['j3c/%d' % k][loc0:loc1] = segs
Ejemplo n.º 8
0
    def save_vir_frac(p0, p1, eri):
        """Store one block of transformed integrals into the ``eris`` datasets.

        ``eri`` is reshaped to ``(p1-p0, nocc, nmo, nmo)``.  Its occupied and
        virtual sub-blocks are written to ``eris.ovoo``, ``eris.ovvo`` and
        ``eris.ovov`` at ``[:, p0:p1]``.  The virtual-virtual part is cut
        along the ranges in ``vlocs``, exchanged with ``mpi.alltoall`` and
        written to ``eris.vvvo``.

        Args:
            p0, p1: bounds of the block handled in this call.
            eri: array holding ``(p1-p0) * nocc * nmo * nmo`` elements.
        """
        log.alldebug1('save_vir_frac %d %d %s', p0, p1, eri.shape)
        eri = eri.reshape(p1-p0,nocc,nmo,nmo)
        eris.ovoo[:,p0:p1] = eri[:,:,:nocc,:nocc].transpose(1,0,2,3)
        eris.ovvo[:,p0:p1] = eri[:,:,nocc:,:nocc].transpose(1,0,2,3)
        eris.ovov[:,p0:p1] = eri[:,:,:nocc,nocc:].transpose(1,0,2,3)

        # time.clock() was removed in Python 3.8; prefer time.process_time()
        # and fall back to time.clock() on interpreters that lack it.
        cpu_clock = getattr(time, 'process_time', None) or time.clock
        cput2 = cpu_clock(), time.time()
        ovvv_segs = [eri[:,:,nocc+q0:nocc+q1,nocc:].transpose(2,3,0,1) for q0,q1 in vlocs]
        ovvv_segs = mpi.alltoall(ovvv_segs, split_recvbuf=True)
        cput2 = log.timer_debug1('vvvo alltoall', *cput2)
        # Each rank reports its own (p0, p1); offset them by that rank's
        # vlocs start to get the destination range on axis 2 of eris.vvvo.
        for task_id, (q0, q1) in enumerate(comm.allgather((p0,p1))):
            ip0 = q0 + vlocs[task_id][0]
            ip1 = q1 + vlocs[task_id][0]
            eris.vvvo[:,:,ip0:ip1] = ovvv_segs[task_id].reshape(vseg,nvir,q1-q0,nocc)
Ejemplo n.º 9
0
    def save_vir_frac(p0, p1, eri):
        """Write one block of transformed integrals into the ``eris`` datasets.

        After reshaping ``eri`` to ``(p1 - p0, nocc, nmo, nmo)`` the
        occupied/virtual sub-blocks go to ``eris.ovoo``, ``eris.ovvo`` and
        ``eris.ovov`` at ``[:, p0:p1]``.  The virtual-virtual part is sliced
        by the ranges in ``vlocs``, exchanged through ``mpi.alltoall`` and
        stored in ``eris.vvvo``.

        Args:
            p0, p1: bounds of the block handled in this call.
            eri: array holding ``(p1 - p0) * nocc * nmo * nmo`` elements.
        """
        log.alldebug1('save_vir_frac %d %d %s', p0, p1, eri.shape)
        eri = eri.reshape(p1 - p0, nocc, nmo, nmo)
        eris.ovoo[:, p0:p1] = eri[:, :, :nocc, :nocc].transpose(1, 0, 2, 3)
        eris.ovvo[:, p0:p1] = eri[:, :, nocc:, :nocc].transpose(1, 0, 2, 3)
        eris.ovov[:, p0:p1] = eri[:, :, :nocc, nocc:].transpose(1, 0, 2, 3)

        # time.clock() no longer exists on Python >= 3.8; use
        # time.process_time() when available and time.clock() otherwise.
        cpu_clock = getattr(time, 'process_time', None) or time.clock
        cput2 = cpu_clock(), time.time()
        ovvv_segs = [
            eri[:, :, nocc + q0:nocc + q1, nocc:].transpose(2, 3, 0, 1)
            for q0, q1 in vlocs
        ]
        ovvv_segs = mpi.alltoall(ovvv_segs, split_recvbuf=True)
        cput2 = log.timer_debug1('vvvo alltoall', *cput2)
        # The (p0, p1) gathered from each rank is shifted by that rank's
        # vlocs start to obtain the target range on axis 2 of eris.vvvo.
        for task_id, (q0, q1) in enumerate(comm.allgather((p0, p1))):
            ip0 = q0 + vlocs[task_id][0]
            ip1 = q1 + vlocs[task_id][0]
            eris.vvvo[:, :, ip0:ip1] = ovvv_segs[task_id].reshape(
                vseg, nvir, q1 - q0, nocc)
Ejemplo n.º 10
0
def vector_to_amplitudes(vector, nmo, nocc):
    """Rebuild T1 and this rank's T2 segment from a flat vector.

    Rank 0 reads ``nocc * nvir`` T1 values from the front of ``vector`` and
    passes them to ``mpi.bcast``; the other ranks call ``mpi.bcast(None)`` to
    obtain T1 and treat the whole vector as packed T2 data.  The packed data
    is expanded by ``lib.unpack_tril`` (``filltriu=lib.PLAIN``) and the other
    triangle is written from the blocks returned by ``mpi.alltoall``.

    Args:
        vector: flat array laid out as described above.
        nmo: number of orbitals (``nvir = nmo - nocc``).
        nocc: number of occupied orbitals.

    Returns:
        ``(t1T.T, t2T.transpose(2,3,0,1))``.
    """
    nvir = nmo - nocc
    nov = nocc * nvir
    nocc2 = nocc*(nocc+1)//2
    vlocs = [_task_location(nvir, task_id) for task_id in range(mpi.pool.size)]
    vloc0, vloc1 = vlocs[rank]
    nvir_seg = vloc1 - vloc0

    on_root = rank == 0
    if on_root:
        t1T = vector[:nov].copy().reshape((nvir,nocc))
        mpi.bcast(t1T)
    else:
        t1T = mpi.bcast(None)
    t2_flat = vector[nov:] if on_root else vector
    t2tril = t2_flat.reshape(nvir_seg,nvir,nocc2)

    t2T = lib.unpack_tril(t2tril.reshape(nvir_seg*nvir,nocc2),
                          filltriu=lib.PLAIN).reshape(nvir_seg,nvir,nocc,nocc)

    column_blocks = [t2tril[:,start:stop] for start, stop in vlocs]
    t2tmp = mpi.alltoall(column_blocks, split_recvbuf=True)
    idx, idy = numpy.tril_indices(nocc)
    for task_id, (start, stop) in enumerate(vlocs):
        recv = t2tmp[task_id].reshape(stop-start,nvir_seg,nocc2)
        # Swapped index order (idy, idx) targets the transposed pair slots.
        t2T[:,start:stop,idy,idx] = recv.transpose(1,0,2)
    return t1T.T, t2T.transpose(2,3,0,1)
Ejemplo n.º 11
0
def update_amps(mycc, t1, t2, eris):
    """Compute updated T1/T2 amplitudes for this rank's virtual segment.

    ``t2`` is the local segment: after ``t2.transpose(2, 3, 0, 1)`` its
    leading axis must have length ``vloc1 - vloc0`` for this rank (asserted
    below).  Per-rank contributions are accumulated in ``*_priv`` arrays and
    combined with ``mpi.allreduce``; larger intermediates are staged in a
    temporary HDF5 file (``lib.H5TmpFile``).

    Args:
        mycc: object providing ``stdout``, ``verbose``, ``level_shift``,
            ``max_memory``, ``direct`` and the method ``_add_vvvv``.
        t1: T1 amplitudes; used through ``t1.T``.
        t2: local segment of the T2 amplitudes (see above).
        eris: integrals object providing ``fock``, ``mo_energy``, ``oooo``,
            ``ovoo``, ``oovv``, ``ovvo`` and ``vvvo``.

    Returns:
        Tuple ``(t1Tnew.T, t2Tnew.transpose(2, 3, 0, 1))``.

    Raises:
        NotImplementedError: when ``mycc.direct`` is false (raised while
            processing the first vvvo block).
    """
    # time.clock() was removed in Python 3.8; prefer time.process_time() and
    # fall back to time.clock() on interpreters that lack it.
    cpu_clock = getattr(time, 'process_time', None) or time.clock
    time1 = time0 = cpu_clock(), time.time()
    log = logger.Logger(mycc.stdout, mycc.verbose)

    t1T = t1.T
    t2T = numpy.asarray(t2.transpose(2, 3, 0, 1), order='C')
    nvir_seg, nvir, nocc = t2T.shape[:3]
    t1 = t2 = None
    ntasks = mpi.pool.size
    vlocs = [_task_location(nvir, task_id) for task_id in range(ntasks)]
    vloc0, vloc1 = vlocs[rank]
    log.debug2('vlocs %s', vlocs)
    assert (vloc1 - vloc0 == nvir_seg)

    fock = eris.fock
    mo_e_o = eris.mo_energy[:nocc]
    mo_e_v = eris.mo_energy[nocc:] + mycc.level_shift

    def _rotate_vir_block(buf):
        # Pair every block yielded by _rotate_tensor_block with the virtual
        # range of the task it belongs to.
        for task_id, buf in _rotate_tensor_block(buf):
            loc0, loc1 = vlocs[task_id]
            yield task_id, buf, loc0, loc1

    fswap = lib.H5TmpFile()
    wVooV = numpy.zeros((nvir_seg, nocc, nocc, nvir))
    eris_voov = _cp(eris.ovvo).transpose(1, 0, 3, 2)
    tau = t2T * .5
    tau += numpy.einsum('ai,bj->abij', t1T[vloc0:vloc1], t1T)
    for task_id, tau, p0, p1 in _rotate_vir_block(tau):
        wVooV += lib.einsum('bkic,cajk->bija', eris_voov[:, :, :, p0:p1], tau)
    fswap['wVooV1'] = wVooV
    wVooV = tau = None
    time1 = log.timer_debug1('wVooV', *time1)

    wVOov = eris_voov
    eris_VOov = eris_voov - eris_voov.transpose(0, 2, 1, 3) * .5
    tau = t2T.transpose(2, 0, 3, 1) - t2T.transpose(3, 0, 2, 1) * .5
    tau -= numpy.einsum('ai,bj->jaib', t1T[vloc0:vloc1], t1T)
    for task_id, tau, p0, p1 in _rotate_vir_block(tau):
        wVOov += lib.einsum('dlkc,kcjb->dljb', eris_VOov[:, :, :, p0:p1], tau)
    fswap['wVOov1'] = wVOov
    wVOov = tau = eris_VOov = eris_voov = None
    time1 = log.timer_debug1('wVOov', *time1)

    t1Tnew = numpy.zeros_like(t1T)
    t2Tnew = mycc._add_vvvv(t1T, t2T, eris, t2sym='jiba')
    time1 = log.timer_debug1('vvvv', *time1)

    #** make_inter_F
    fov = fock[:nocc, nocc:].copy()
    t1Tnew += fock[nocc:, :nocc]

    foo = fock[:nocc, :nocc] - numpy.diag(mo_e_o)
    foo += .5 * numpy.einsum('ia,aj->ij', fock[:nocc, nocc:], t1T)

    fvv = fock[nocc:, nocc:] - numpy.diag(mo_e_v)
    fvv -= .5 * numpy.einsum('ai,ib->ab', t1T, fock[:nocc, nocc:])

    # Rank-local accumulators; summed over ranks with mpi.allreduce later.
    foo_priv = numpy.zeros_like(foo)
    fov_priv = numpy.zeros_like(fov)
    fvv_priv = numpy.zeros_like(fvv)
    t1T_priv = numpy.zeros_like(t1T)

    max_memory = mycc.max_memory - lib.current_memory()[0]
    unit = nocc * nvir**2 * 3 + nocc**2 * nvir + 1
    blksize = min(nvir,
                  max(BLKMIN, int((max_memory * .9e6 / 8 - t2T.size) / unit)))
    log.debug1('pass 1, max_memory %d MB,  nocc,nvir = %d,%d  blksize = %d',
               max_memory, nocc, nvir, blksize)

    buf = numpy.empty((blksize, nvir, nvir, nocc))

    def load_vvvo(p0):
        # Reads into whatever array ``buf`` is currently bound to; the loop
        # below rebinds ``buf`` before each prefetch.
        p1 = min(nvir_seg, p0 + blksize)
        if p0 < p1:
            buf[:p1 - p0] = eris.vvvo[p0:p1]

    fswap.create_dataset('wVooV', (nvir_seg, nocc, nocc, nvir), 'f8')
    wVOov = []

    with lib.call_in_background(load_vvvo) as prefetch:
        load_vvvo(0)
        for p0, p1 in lib.prange(vloc0, vloc1, blksize):
            # p0:p1 index the full virtual range, i0:i1 the local segment.
            i0, i1 = p0 - vloc0, p1 - vloc0
            eris_vvvo, buf = buf[:p1 - p0], numpy.empty_like(buf)
            prefetch(i1)

            fvv_priv[p0:p1] += 2 * numpy.einsum('ck,abck->ab', t1T, eris_vvvo)
            fvv_priv -= numpy.einsum('ck,cabk->ab', t1T[p0:p1], eris_vvvo)

            if not mycc.direct:
                raise NotImplementedError

            fswap['wVooV'][i0:i1] = lib.einsum('cj,baci->bija', -t1T,
                                               eris_vvvo)

            theta = t2T[i0:i1].transpose(0, 2, 1, 3) * 2
            theta -= t2T[i0:i1].transpose(0, 3, 1, 2)
            t1T_priv += lib.einsum('bicj,bacj->ai', theta, eris_vvvo)
            wVOov.append(lib.einsum('acbi,cj->abij', eris_vvvo, t1T))
            theta = eris_vvvo = None
            time1 = log.timer_debug1('vvvo [%d:%d]' % (p0, p1), *time1)

    wVOov = numpy.vstack(wVOov)
    wVOov = mpi.alltoall([wVOov[:, q0:q1] for q0, q1 in vlocs],
                         split_recvbuf=True)
    wVOov = numpy.vstack([x.reshape(-1, nvir_seg, nocc, nocc) for x in wVOov])
    fswap['wVOov'] = wVOov.transpose(1, 2, 3, 0)
    wVooV = None

    unit = nocc**2 * nvir * 7 + nocc**3 + nocc * nvir**2
    max_memory = max(0, mycc.max_memory - lib.current_memory()[0])
    blksize = min(nvir,
                  max(BLKMIN, int((max_memory * .9e6 / 8 - nocc**4) / unit)))
    log.debug1('pass 2, max_memory %d MB,  nocc,nvir = %d,%d  blksize = %d',
               max_memory, nocc, nvir, blksize)

    woooo = numpy.zeros((nocc, nocc, nocc, nocc))

    for p0, p1 in lib.prange(vloc0, vloc1, blksize):
        i0, i1 = p0 - vloc0, p1 - vloc0
        wVOov = fswap['wVOov'][i0:i1]
        wVooV = fswap['wVooV'][i0:i1]
        eris_ovoo = eris.ovoo[:, i0:i1]
        eris_oovv = numpy.empty((nocc, nocc, i1 - i0, nvir))

        def load_oovv(p0, p1):
            eris_oovv[:] = eris.oovv[:, :, p0:p1]

        with lib.call_in_background(load_oovv) as prefetch_oovv:
            #:eris_oovv = eris.oovv[:,:,i0:i1]
            prefetch_oovv(i0, i1)
            foo_priv += numpy.einsum('ck,kcji->ij', 2 * t1T[p0:p1], eris_ovoo)
            foo_priv += numpy.einsum('ck,icjk->ij', -t1T[p0:p1], eris_ovoo)
            tmp = lib.einsum('al,jaik->lkji', t1T[p0:p1], eris_ovoo)
            woooo += tmp + tmp.transpose(1, 0, 3, 2)
            tmp = None

            wVOov -= lib.einsum('jbik,ak->bjia', eris_ovoo, t1T)
            t2Tnew[i0:i1] += wVOov.transpose(0, 3, 1, 2)

            wVooV += lib.einsum('kbij,ak->bija', eris_ovoo, t1T)
            eris_ovoo = None
        load_oovv = prefetch_oovv = None

        eris_ovvo = numpy.empty((nocc, i1 - i0, nvir, nocc))

        def load_ovvo(p0, p1):
            eris_ovvo[:] = eris.ovvo[:, p0:p1]

        with lib.call_in_background(load_ovvo) as prefetch_ovvo:
            #:eris_ovvo = eris.ovvo[:,i0:i1]
            prefetch_ovvo(i0, i1)
            t1T_priv[p0:p1] -= numpy.einsum('bj,jiab->ai', t1T, eris_oovv)
            wVooV -= eris_oovv.transpose(2, 0, 1, 3)
            wVOov += wVooV * .5  #: bjia + bija*.5
        eris_voov = eris_ovvo.transpose(1, 0, 3, 2)
        eris_ovvo = None
        load_ovvo = prefetch_ovvo = None

        def update_wVooV(i0, i1):
            # Fold the pass-1 intermediates into the current blocks and write
            # the totals back to the swap file.
            wVooV[:] += fswap['wVooV1'][i0:i1]
            fswap['wVooV1'][i0:i1] = wVooV
            wVOov[:] += fswap['wVOov1'][i0:i1]
            fswap['wVOov1'][i0:i1] = wVOov

        with lib.call_in_background(update_wVooV) as update_wVooV:
            update_wVooV(i0, i1)
            t2Tnew[i0:i1] += eris_voov.transpose(0, 3, 1, 2) * .5
            t1T_priv[p0:p1] += 2 * numpy.einsum('bj,aijb->ai', t1T, eris_voov)

            tmp = lib.einsum('ci,kjbc->bijk', t1T, eris_oovv)
            tmp += lib.einsum('bjkc,ci->bjik', eris_voov, t1T)
            t2Tnew[i0:i1] -= lib.einsum('bjik,ak->baji', tmp, t1T)
            eris_oovv = tmp = None

            fov_priv[:, p0:p1] += numpy.einsum('ck,aikc->ia', t1T,
                                               eris_voov) * 2
            fov_priv[:, p0:p1] -= numpy.einsum('ck,akic->ia', t1T, eris_voov)

            tau = numpy.einsum('ai,bj->abij', t1T[p0:p1] * .5, t1T)
            tau += t2T[i0:i1]
            theta = tau.transpose(0, 1, 3, 2) * 2
            theta -= tau
            fvv_priv -= lib.einsum('caij,cjib->ab', theta, eris_voov)
            foo_priv += lib.einsum('aikb,abkj->ij', eris_voov, theta)
            tau = theta = None

            tau = t2T[i0:i1] + numpy.einsum('ai,bj->abij', t1T[p0:p1], t1T)
            woooo += lib.einsum('abij,aklb->ijkl', tau, eris_voov)
            tau = None
        eris_VOov = wVOov = wVooV = update_wVooV = None
        time1 = log.timer_debug1('voov [%d:%d]' % (p0, p1), *time1)

    wVooV = _cp(fswap['wVooV1'])
    for task_id, wVooV, p0, p1 in _rotate_vir_block(wVooV):
        tmp = lib.einsum('ackj,ckib->ajbi', t2T[:, p0:p1], wVooV)
        t2Tnew += tmp.transpose(0, 2, 3, 1)
        t2Tnew += tmp.transpose(0, 2, 1, 3) * .5
    wVooV = tmp = None
    time1 = log.timer_debug1('contracting wVooV', *time1)

    wVOov = _cp(fswap['wVOov1'])
    theta = t2T * 2
    theta -= t2T.transpose(0, 1, 3, 2)
    for task_id, wVOov, p0, p1 in _rotate_vir_block(wVOov):
        t2Tnew += lib.einsum('acik,ckjb->abij', theta[:, p0:p1], wVOov)
    wVOov = theta = None
    fswap = None
    time1 = log.timer_debug1('contracting wVOov', *time1)

    foo += mpi.allreduce(foo_priv)
    fov += mpi.allreduce(fov_priv)
    fvv += mpi.allreduce(fvv_priv)

    theta = t2T.transpose(0, 1, 3, 2) * 2 - t2T
    t1T_priv[vloc0:vloc1] += numpy.einsum('jb,abji->ai', fov, theta)
    ovoo = _cp(eris.ovoo)
    for task_id, ovoo, p0, p1 in _rotate_vir_block(ovoo):
        t1T_priv[vloc0:vloc1] -= lib.einsum('jbki,abjk->ai', ovoo,
                                            theta[:, p0:p1])
    theta = ovoo = None

    woooo = mpi.allreduce(woooo)
    woooo += _cp(eris.oooo).transpose(0, 2, 1, 3)
    tau = t2T + numpy.einsum('ai,bj->abij', t1T[vloc0:vloc1], t1T)
    t2Tnew += .5 * lib.einsum('abkl,ijkl->abij', tau, woooo)
    tau = woooo = None

    t1Tnew += mpi.allreduce(t1T_priv)

    ft_ij = foo + numpy.einsum('aj,ia->ij', .5 * t1T, fov)
    ft_ab = fvv - numpy.einsum('ai,ib->ab', .5 * t1T, fov)
    t2Tnew += lib.einsum('acij,bc->abij', t2T, ft_ab)
    t2Tnew -= lib.einsum('ki,abkj->abij', ft_ij, t2T)

    eia = mo_e_o[:, None] - mo_e_v
    t1Tnew += numpy.einsum('bi,ab->ai', t1T, fvv)
    t1Tnew -= numpy.einsum('aj,ji->ai', t1T, foo)
    t1Tnew /= eia.T

    # Add to each local block the block received from the matching rank,
    # with both index pairs swapped.
    t2tmp = mpi.alltoall([t2Tnew[:, p0:p1] for p0, p1 in vlocs],
                         split_recvbuf=True)
    for task_id, (p0, p1) in enumerate(vlocs):
        tmp = t2tmp[task_id].reshape(p1 - p0, nvir_seg, nocc, nocc)
        t2Tnew[:, p0:p1] += tmp.transpose(1, 0, 3, 2)

    for i in range(vloc0, vloc1):
        t2Tnew[i - vloc0] /= lib.direct_sum('i+jb->bij', eia[:, i], eia)

    time0 = log.timer_debug1('update t1 t2', *time0)
    return t1Tnew.T, t2Tnew.transpose(2, 3, 0, 1)
Ejemplo n.º 12
0
def update_amps(mycc, t1, t2, eris):
    """Compute updated T1/T2 amplitudes for this rank's virtual segment.

    ``t2`` is the local segment: after ``t2.transpose(2,3,0,1)`` its leading
    axis must have length ``vloc1 - vloc0`` for this rank (asserted below).
    Per-rank contributions are accumulated in ``*_priv`` arrays and combined
    with ``mpi.allreduce``; larger intermediates are staged in a temporary
    HDF5 file (``lib.H5TmpFile``).

    Args:
        mycc: object providing ``stdout``, ``verbose``, ``level_shift``,
            ``max_memory``, ``direct`` and the method ``_add_vvvv``.
        t1: T1 amplitudes; used through ``t1.T``.
        t2: local segment of the T2 amplitudes (see above).
        eris: integrals object providing ``fock``, ``mo_energy``, ``oooo``,
            ``ovoo``, ``oovv``, ``ovvo`` and ``vvvo``.

    Returns:
        Tuple ``(t1Tnew.T, t2Tnew.transpose(2,3,0,1))``.

    Raises:
        NotImplementedError: when ``mycc.direct`` is false (raised while
            processing the first vvvo block).
    """
    # time.clock() no longer exists on Python >= 3.8; use time.process_time()
    # when available and time.clock() otherwise.
    cpu_clock = getattr(time, 'process_time', None) or time.clock
    time1 = time0 = cpu_clock(), time.time()
    log = logger.Logger(mycc.stdout, mycc.verbose)

    t1T = t1.T
    t2T = numpy.asarray(t2.transpose(2,3,0,1), order='C')
    nvir_seg, nvir, nocc = t2T.shape[:3]
    t1 = t2 = None
    ntasks = mpi.pool.size
    vlocs = [_task_location(nvir, task_id) for task_id in range(ntasks)]
    vloc0, vloc1 = vlocs[rank]
    log.debug2('vlocs %s', vlocs)
    assert(vloc1-vloc0 == nvir_seg)

    fock = eris.fock
    mo_e_o = eris.mo_energy[:nocc]
    mo_e_v = eris.mo_energy[nocc:] + mycc.level_shift

    def _rotate_vir_block(buf):
        # Attach the owning task's virtual range to every block yielded by
        # _rotate_tensor_block.
        for task_id, buf in _rotate_tensor_block(buf):
            loc0, loc1 = vlocs[task_id]
            yield task_id, buf, loc0, loc1

    fswap = lib.H5TmpFile()
    wVooV = numpy.zeros((nvir_seg,nocc,nocc,nvir))
    eris_voov = _cp(eris.ovvo).transpose(1,0,3,2)
    tau  = t2T * .5
    tau += numpy.einsum('ai,bj->abij', t1T[vloc0:vloc1], t1T)
    for task_id, tau, p0, p1 in _rotate_vir_block(tau):
        wVooV += lib.einsum('bkic,cajk->bija', eris_voov[:,:,:,p0:p1], tau)
    fswap['wVooV1'] = wVooV
    wVooV = tau = None
    time1 = log.timer_debug1('wVooV', *time1)

    wVOov = eris_voov
    eris_VOov = eris_voov - eris_voov.transpose(0,2,1,3)*.5
    tau  = t2T.transpose(2,0,3,1) - t2T.transpose(3,0,2,1)*.5
    tau -= numpy.einsum('ai,bj->jaib', t1T[vloc0:vloc1], t1T)
    for task_id, tau, p0, p1 in _rotate_vir_block(tau):
        wVOov += lib.einsum('dlkc,kcjb->dljb', eris_VOov[:,:,:,p0:p1], tau)
    fswap['wVOov1'] = wVOov
    wVOov = tau = eris_VOov = eris_voov = None
    time1 = log.timer_debug1('wVOov', *time1)

    t1Tnew = numpy.zeros_like(t1T)
    t2Tnew = mycc._add_vvvv(t1T, t2T, eris, t2sym='jiba')
    time1 = log.timer_debug1('vvvv', *time1)

#** make_inter_F
    fov = fock[:nocc,nocc:].copy()
    t1Tnew += fock[nocc:,:nocc]

    foo = fock[:nocc,:nocc] - numpy.diag(mo_e_o)
    foo += .5 * numpy.einsum('ia,aj->ij', fock[:nocc,nocc:], t1T)

    fvv = fock[nocc:,nocc:] - numpy.diag(mo_e_v)
    fvv -= .5 * numpy.einsum('ai,ib->ab', t1T, fock[:nocc,nocc:])

    # Rank-local accumulators; summed over ranks with mpi.allreduce later.
    foo_priv = numpy.zeros_like(foo)
    fov_priv = numpy.zeros_like(fov)
    fvv_priv = numpy.zeros_like(fvv)
    t1T_priv = numpy.zeros_like(t1T)

    max_memory = mycc.max_memory - lib.current_memory()[0]
    unit = nocc*nvir**2*3 + nocc**2*nvir + 1
    blksize = min(nvir, max(BLKMIN, int((max_memory*.9e6/8-t2T.size)/unit)))
    log.debug1('pass 1, max_memory %d MB,  nocc,nvir = %d,%d  blksize = %d',
               max_memory, nocc, nvir, blksize)

    buf = numpy.empty((blksize,nvir,nvir,nocc))
    def load_vvvo(p0):
        # Reads into whatever array ``buf`` is currently bound to; the loop
        # below rebinds ``buf`` before each prefetch.
        p1 = min(nvir_seg, p0+blksize)
        if p0 < p1:
            buf[:p1-p0] = eris.vvvo[p0:p1]
    fswap.create_dataset('wVooV', (nvir_seg,nocc,nocc,nvir), 'f8')
    wVOov = []

    with lib.call_in_background(load_vvvo) as prefetch:
        load_vvvo(0)
        for p0, p1 in lib.prange(vloc0, vloc1, blksize):
            # p0:p1 index the full virtual range, i0:i1 the local segment.
            i0, i1 = p0 - vloc0, p1 - vloc0
            eris_vvvo, buf = buf[:p1-p0], numpy.empty_like(buf)
            prefetch(i1)

            fvv_priv[p0:p1] += 2*numpy.einsum('ck,abck->ab', t1T, eris_vvvo)
            fvv_priv -= numpy.einsum('ck,cabk->ab', t1T[p0:p1], eris_vvvo)

            if not mycc.direct:
                raise NotImplementedError

            fswap['wVooV'][i0:i1] = lib.einsum('cj,baci->bija', -t1T, eris_vvvo)

            theta  = t2T[i0:i1].transpose(0,2,1,3) * 2
            theta -= t2T[i0:i1].transpose(0,3,1,2)
            t1T_priv += lib.einsum('bicj,bacj->ai', theta, eris_vvvo)
            wVOov.append(lib.einsum('acbi,cj->abij', eris_vvvo, t1T))
            theta = eris_vvvo = None
            time1 = log.timer_debug1('vvvo [%d:%d]'%(p0, p1), *time1)

    wVOov = numpy.vstack(wVOov)
    wVOov = mpi.alltoall([wVOov[:,q0:q1] for q0,q1 in vlocs], split_recvbuf=True)
    wVOov = numpy.vstack([x.reshape(-1,nvir_seg,nocc,nocc) for x in wVOov])
    fswap['wVOov'] = wVOov.transpose(1,2,3,0)
    wVooV = None

    unit = nocc**2*nvir*7 + nocc**3 + nocc*nvir**2
    max_memory = max(0, mycc.max_memory - lib.current_memory()[0])
    blksize = min(nvir, max(BLKMIN, int((max_memory*.9e6/8-nocc**4)/unit)))
    log.debug1('pass 2, max_memory %d MB,  nocc,nvir = %d,%d  blksize = %d',
               max_memory, nocc, nvir, blksize)

    woooo = numpy.zeros((nocc,nocc,nocc,nocc))

    for p0, p1 in lib.prange(vloc0, vloc1, blksize):
        i0, i1 = p0 - vloc0, p1 - vloc0
        wVOov = fswap['wVOov'][i0:i1]
        wVooV = fswap['wVooV'][i0:i1]
        eris_ovoo = eris.ovoo[:,i0:i1]
        eris_oovv = numpy.empty((nocc,nocc,i1-i0,nvir))
        def load_oovv(p0, p1):
            eris_oovv[:] = eris.oovv[:,:,p0:p1]
        with lib.call_in_background(load_oovv) as prefetch_oovv:
            #:eris_oovv = eris.oovv[:,:,i0:i1]
            prefetch_oovv(i0, i1)
            foo_priv += numpy.einsum('ck,kcji->ij', 2*t1T[p0:p1], eris_ovoo)
            foo_priv += numpy.einsum('ck,icjk->ij',  -t1T[p0:p1], eris_ovoo)
            tmp = lib.einsum('al,jaik->lkji', t1T[p0:p1], eris_ovoo)
            woooo += tmp + tmp.transpose(1,0,3,2)
            tmp = None

            wVOov -= lib.einsum('jbik,ak->bjia', eris_ovoo, t1T)
            t2Tnew[i0:i1] += wVOov.transpose(0,3,1,2)

            wVooV += lib.einsum('kbij,ak->bija', eris_ovoo, t1T)
            eris_ovoo = None
        load_oovv = prefetch_oovv = None

        eris_ovvo = numpy.empty((nocc,i1-i0,nvir,nocc))
        def load_ovvo(p0, p1):
            eris_ovvo[:] = eris.ovvo[:,p0:p1]
        with lib.call_in_background(load_ovvo) as prefetch_ovvo:
            #:eris_ovvo = eris.ovvo[:,i0:i1]
            prefetch_ovvo(i0, i1)
            t1T_priv[p0:p1] -= numpy.einsum('bj,jiab->ai', t1T, eris_oovv)
            wVooV -= eris_oovv.transpose(2,0,1,3)
            wVOov += wVooV*.5  #: bjia + bija*.5
        eris_voov = eris_ovvo.transpose(1,0,3,2)
        eris_ovvo = None
        load_ovvo = prefetch_ovvo = None

        def update_wVooV(i0, i1):
            # Fold the pass-1 intermediates into the current blocks and write
            # the totals back to the swap file.
            wVooV[:] += fswap['wVooV1'][i0:i1]
            fswap['wVooV1'][i0:i1] = wVooV
            wVOov[:] += fswap['wVOov1'][i0:i1]
            fswap['wVOov1'][i0:i1] = wVOov
        with lib.call_in_background(update_wVooV) as update_wVooV:
            update_wVooV(i0, i1)
            t2Tnew[i0:i1] += eris_voov.transpose(0,3,1,2) * .5
            t1T_priv[p0:p1] += 2*numpy.einsum('bj,aijb->ai', t1T, eris_voov)

            tmp  = lib.einsum('ci,kjbc->bijk', t1T, eris_oovv)
            tmp += lib.einsum('bjkc,ci->bjik', eris_voov, t1T)
            t2Tnew[i0:i1] -= lib.einsum('bjik,ak->baji', tmp, t1T)
            eris_oovv = tmp = None

            fov_priv[:,p0:p1] += numpy.einsum('ck,aikc->ia', t1T, eris_voov) * 2
            fov_priv[:,p0:p1] -= numpy.einsum('ck,akic->ia', t1T, eris_voov)

            tau  = numpy.einsum('ai,bj->abij', t1T[p0:p1]*.5, t1T)
            tau += t2T[i0:i1]
            theta  = tau.transpose(0,1,3,2) * 2
            theta -= tau
            fvv_priv -= lib.einsum('caij,cjib->ab', theta, eris_voov)
            foo_priv += lib.einsum('aikb,abkj->ij', eris_voov, theta)
            tau = theta = None

            tau = t2T[i0:i1] + numpy.einsum('ai,bj->abij', t1T[p0:p1], t1T)
            woooo += lib.einsum('abij,aklb->ijkl', tau, eris_voov)
            tau = None
        eris_VOov = wVOov = wVooV = update_wVooV = None
        time1 = log.timer_debug1('voov [%d:%d]'%(p0, p1), *time1)

    wVooV = _cp(fswap['wVooV1'])
    for task_id, wVooV, p0, p1 in _rotate_vir_block(wVooV):
        tmp = lib.einsum('ackj,ckib->ajbi', t2T[:,p0:p1], wVooV)
        t2Tnew += tmp.transpose(0,2,3,1)
        t2Tnew += tmp.transpose(0,2,1,3) * .5
    wVooV = tmp = None
    time1 = log.timer_debug1('contracting wVooV', *time1)

    wVOov = _cp(fswap['wVOov1'])
    theta  = t2T * 2
    theta -= t2T.transpose(0,1,3,2)
    for task_id, wVOov, p0, p1 in _rotate_vir_block(wVOov):
        t2Tnew += lib.einsum('acik,ckjb->abij', theta[:,p0:p1], wVOov)
    wVOov = theta = None
    fswap = None
    time1 = log.timer_debug1('contracting wVOov', *time1)

    foo += mpi.allreduce(foo_priv)
    fov += mpi.allreduce(fov_priv)
    fvv += mpi.allreduce(fvv_priv)

    theta = t2T.transpose(0,1,3,2) * 2 - t2T
    t1T_priv[vloc0:vloc1] += numpy.einsum('jb,abji->ai', fov, theta)
    ovoo = _cp(eris.ovoo)
    for task_id, ovoo, p0, p1 in _rotate_vir_block(ovoo):
        t1T_priv[vloc0:vloc1] -= lib.einsum('jbki,abjk->ai', ovoo, theta[:,p0:p1])
    theta = ovoo = None

    woooo = mpi.allreduce(woooo)
    woooo += _cp(eris.oooo).transpose(0,2,1,3)
    tau = t2T + numpy.einsum('ai,bj->abij', t1T[vloc0:vloc1], t1T)
    t2Tnew += .5 * lib.einsum('abkl,ijkl->abij', tau, woooo)
    tau = woooo = None

    t1Tnew += mpi.allreduce(t1T_priv)

    ft_ij = foo + numpy.einsum('aj,ia->ij', .5*t1T, fov)
    ft_ab = fvv - numpy.einsum('ai,ib->ab', .5*t1T, fov)
    t2Tnew += lib.einsum('acij,bc->abij', t2T, ft_ab)
    t2Tnew -= lib.einsum('ki,abkj->abij', ft_ij, t2T)

    eia = mo_e_o[:,None] - mo_e_v
    t1Tnew += numpy.einsum('bi,ab->ai', t1T, fvv)
    t1Tnew -= numpy.einsum('aj,ji->ai', t1T, foo)
    t1Tnew /= eia.T

    # Add to each local block the block received from the matching rank,
    # with both index pairs swapped.
    t2tmp = mpi.alltoall([t2Tnew[:,p0:p1] for p0,p1 in vlocs],
                         split_recvbuf=True)
    for task_id, (p0, p1) in enumerate(vlocs):
        tmp = t2tmp[task_id].reshape(p1-p0,nvir_seg,nocc,nocc)
        t2Tnew[:,p0:p1] += tmp.transpose(1,0,3,2)

    for i in range(vloc0, vloc1):
        t2Tnew[i-vloc0] /= lib.direct_sum('i+jb->bij', eia[:,i], eia)

    time0 = log.timer_debug1('update t1 t2', *time0)
    return t1Tnew.T, t2Tnew.transpose(2,3,0,1)