Beispiel #1
0
def cc_Wovvo(t1T, t2T, eris, vlocs=None):
    """
    Build the W_{mb[e]j} intermediate; the bracketed e is this rank's segment.

    Only the t2 contraction and the bare ``eris.xovo`` term are included.
    The t1-dependent terms are dropped, so ``t1T`` is accepted for interface
    compatibility but is not referenced.

    Returns:
        ndarray indexed (m, b, e_local, j).
    """
    _, nvir, _, _ = t2T.shape
    if vlocs is None:
        vlocs = [_task_location(nvir, tid) for tid in range(mpi.pool.size)]

    # Dropped t1 terms (kept for reference):
    #   Wmbej  = einsum('efmn, fj -> mnej', eris.xvoo, -t1T)
    #   Wmbej  = einsum('mnej, bn -> mbej', Wmbej, t1T)
    #   Wmbej -= einsum('mbfe, fj -> mbej', eris.ovvx, t1T)
    #   Wmbej += einsum('bn, mnje -> mbej', t1T, eris.ooox)

    w_mbej = 0.0
    # f is distributed: visit every rank's t2T block and pair it with the
    # matching f-columns of the locally held xvoo slice.
    for _tid, t2T_blk, f0, f1 in _rotate_vir_block(t2T, vlocs=vlocs):
        w_mbej -= einsum('fbjn, efmn -> mbej', t2T_blk, eris.xvoo[:, f0:f1])
        del t2T_blk
    w_mbej *= 0.5

    w_mbej -= np.asarray(eris.xovo).transpose(3, 2, 0, 1)
    return w_mbej
Beispiel #2
0
def init_amps(mycc, eris=None):
    """
    Initial (MP2-like) amplitudes for the distributed GCCD solver.

    t1T = f_vo / (e_o - e_v) and, for this rank's virtual segment,
    t2T[a, b, i, j] = xvoo[a, b, i, j] / (eia[i, a] + eia[j, b]), with the
    virtual energies shifted by ``mycc.level_shift``. The MP2 energy is
    accumulated blockwise, summed over ranks with comm.allreduce, scaled by
    1/4 and stored in ``mycc.emp2``.

    NOTE(review): the ``eris`` argument is immediately overwritten by
    ``mycc._eris`` (built via ``mycc.ao2mo()`` if missing), so a
    caller-supplied eris is never used. Confirm this is intended for the MPI
    call path.
    NOTE(review): t1 is initialized from the vo block of the Fock matrix, not
    zero. Confirm this is intended for CCD.

    Returns:
        (emp2, t1, t2); t1 and t2 are also stored on ``mycc``.
    """
    eris = getattr(mycc, '_eris', None)
    if eris is None:
        mycc.ao2mo()
        eris = mycc._eris

    time0 = logger.process_clock(), logger.perf_counter()
    mo_e = eris.mo_energy
    nocc = mycc.nocc
    nvir = mo_e.size - nocc
    mo_e_o = mo_e[:nocc]
    mo_e_v = mo_e[nocc:] + mycc.level_shift
    eia = mo_e_o[:, None] - mo_e_v  # (nocc, nvir)
    t1T = eris.fock[nocc:, :nocc] / eia.T  # (nvir, nocc)
    # presumably this rank's virtual segment [loc0, loc1)
    loc0, loc1 = _task_location(nvir)

    t2T = np.empty((loc1-loc0, nvir, nocc, nocc))
    # Block size so one block uses ~30% of the remaining memory, in 8-byte
    # elements (assumes max_memory is in MB).
    max_memory = mycc.max_memory - lib.current_memory()[0]
    blksize = int(min(nvir, max(BLKMIN, max_memory*.3e6/8/(nocc**2*nvir+1))))
    emp2 = 0
    # p0/p1 are local offsets into this rank's segment; loc0 shifts them to
    # global virtual indices for eia.
    for p0, p1 in lib.prange(0, loc1-loc0, blksize):
        eris_vvoo = eris.xvoo[p0:p1]
        t2T[p0:p1] = (eris_vvoo / lib.direct_sum('ia, jb -> abij', eia[:, loc0+p0:loc0+p1], eia))
        emp2 += np.einsum('abij, abij', t2T[p0:p1], eris_vvoo.conj(), optimize=True).real
        eris_vvoo = None

    mycc.emp2 = comm.allreduce(emp2) * 0.25
    logger.info(mycc, 'Init t2, MP2 energy = %.15g', mycc.emp2)
    logger.timer(mycc, 'init mp2', *time0)
    # Store back in the (nocc, nvir) / (nocc, nocc, [nvir_seg], nvir) layout.
    mycc.t1 = t1T.T
    mycc.t2 = t2T.transpose(2, 3, 0, 1)
    return mycc.emp2, mycc.t1, mycc.t2
Beispiel #3
0
def amplitudes_to_vector(t1, t2, out=None):
    """
    Pack (t1, t2) into this rank's slice of the flat amplitude vector.

    The layout matches pyscf gccsd. Rank 0 writes t1.T (raveled, nocc*nvir
    entries) first, then its packed t2 elements. Every other rank writes only
    its packed t2 elements. The packed t2 elements are the (virtual pair,
    occupied pair) entries selected by get_vtril for this rank and by
    np.tril_indices(nocc, k=-1).

    Args:
        t1: (nocc, nvir) singles.
        t2: (nocc, nocc, [nvir_seg], nvir) doubles, third index local.
        out: optional buffer to write into.

    Returns:
        1D ndarray of dtype t1.dtype.
    """
    t2T = np.asarray(t2.transpose(2, 3, 0, 1), order='C')
    _, nvir, nocc = t2T.shape[:3]

    vlocs = [_task_location(nvir, tid) for tid in range(mpi.pool.size)]
    vtril = get_vtril(nvir, vlocs[rank])
    otril = np.tril_indices(nocc, k=-1)

    n_t2 = len(vtril[0]) * (nocc * (nocc - 1) // 2)
    nov = nocc * nvir
    # Only rank 0 carries the t1 header.
    head = nov if rank == 0 else 0

    vector = np.ndarray(n_t2 + head, t1.dtype, buffer=out)
    if rank == 0:
        vector[:nov] = t1.T.ravel()
    lib.take_2d(t2T.reshape(-1, nocc**2), vtril[0]*nvir + vtril[1],
                otril[0]*nocc + otril[1], out=vector[head:])
    return vector
Beispiel #4
0
def energy(mycc, t1=None, t2=None, eris=None):
    '''
    CCD correlation energy.

    E = 1/4 * sum t2[i, j, a, b] * xvoo[a, b, i, j]. Each rank contracts its
    local virtual segment of t2 blockwise, and the partial sums are combined
    with comm.allreduce. The singles term is disabled; t1 only supplies
    (nocc, nvir).

    NOTE(review): the ``eris`` argument is ignored; the integrals are always
    taken from ``mycc._eris`` (built via ``mycc.ao2mo()`` if absent).

    Returns:
        Real part of the energy. Rank 0 warns if the imaginary part exceeds
        1e-4.
    '''
    t1 = mycc.t1 if t1 is None else t1
    t2 = mycc.t2 if t2 is None else t2
    eris = getattr(mycc, '_eris', None)
    if eris is None:
        mycc.ao2mo()
        eris = mycc._eris

    nocc, nvir = t1.shape
    loc0, loc1 = _task_location(nvir)
    avail_mem = mycc.max_memory - lib.current_memory()[0]
    blksize = int(
        min(nvir, max(BLKMIN, avail_mem * .3e6 / 8 / (nocc**2 * nvir + 1))))

    e_local = 0.0
    # p0/p1 are offsets inside this rank's virtual segment.
    for p0, p1 in lib.prange(0, loc1 - loc0, blksize):
        e_local += np.einsum('ijab, abij', t2[:, :, p0:p1], eris.xvoo[p0:p1],
                             optimize=True)
    e = comm.allreduce(e_local) * 0.25

    if rank == 0 and abs(e.imag) > 1e-4:
        logger.warn(mycc, 'Non-zero imaginary part found in CCD energy %s', e)
    return e.real
Beispiel #5
0
def make_tauT(t1T, t2T, fac=1, vlocs=None):
    """
    Effective doubles: tauT[a,b,i,j] = t2T[a,b,i,j] + fac*(t1_ai t1_bj - t1_aj t1_bi).

    Args:
        t1T: (nvir, nocc) transposed singles, held in full on every rank.
        t2T: ([nvir_seg], nvir, nocc, nocc); the first index is this rank's
            virtual segment.
        fac: prefactor of the t1*t1 term.
        vlocs: per-rank (start, stop) virtual ranges; computed if omitted.

    Returns:
        tauT with the same distributed layout as t2T.
    """
    nvir_seg, nvir, nocc, _ = t2T.shape
    if vlocs is None:
        vlocs = [_task_location(nvir, tid) for tid in range(mpi.pool.size)]
    a0, a1 = vlocs[rank]

    # Half-weighted outer product, antisymmetrized in (i, j) ...
    tau = np.einsum('ai, bj -> abij', t1T[a0:a1] * (fac * 0.5), t1T, optimize=True)
    tau = tau - tau.transpose(0, 1, 3, 2)
    # ... then in (a, b). Since a is distributed, the (b, a) partner of each
    # local block comes from the rank that owns b.
    recv = mpi.alltoall_new([tau[:, q0:q1] for q0, q1 in vlocs], split_recvbuf=True)
    for tid, (q0, q1) in enumerate(vlocs):
        partner = recv[tid].reshape(q1 - q0, nvir_seg, nocc, nocc)
        tau[:, q0:q1] -= partner.transpose(1, 0, 2, 3)
    tau += t2T
    return tau
Beispiel #6
0
def cc_Wvvvv(t1T, t2T, eris, tauT=None, vlocs=None):
    """
    Wvvvv intermediate W_{[a]bef}; the first index is this rank's segment.

    W = eris.xvvv + 1/4 * sum_mn tauT[a,b,m,n] * xvoo[e,f,m,n]. Here e runs
    over every rank's xvoo block. The t1-dependent correction is not
    included.

    Returns:
        float64 array of shape (nvir_seg, nvir, nvir, nvir).
    """
    # ZHC TODO make Wvvvv outcore
    _, nvir, _, _ = t2T.shape
    if vlocs is None:
        vlocs = [_task_location(nvir, tid) for tid in range(mpi.pool.size)]
    a0, a1 = vlocs[rank]
    if tauT is None:
        tauT = make_tauT(t1T, t2T, vlocs=vlocs)

    w_abef = np.empty((a1 - a0, nvir, nvir, nvir))
    tau_quarter = tauT * 0.25
    # Each rotated block fills the e-columns [e0, e1) it owns.
    for _tid, eri_blk, e0, e1 in _rotate_vir_block(eris.xvoo, vlocs=vlocs):
        w_abef[:, :, e0:e1] = einsum('abmn, efmn -> abef', tau_quarter, eri_blk)
        del eri_blk
    del tau_quarter

    w_abef += np.asarray(eris.xvvv)
    return w_abef
Beispiel #7
0
def transform_t2_to_bo(t2, u, vlocs=None):
    """
    Rotate distributed t2 amplitudes with the orbital rotation ``u``.

    ``u`` is split at nocc into its occupied-occupied and virtual-virtual
    blocks. Every index of t2 is transformed with the matching block.

    Args:
        t2: (nocc, nocc, [nvir_seg], nvir) doubles, third index local.
        u: square rotation matrix covering occupied and virtual orbitals.
        vlocs: per-rank (start, stop) virtual ranges; computed if omitted.

    Returns:
        Rotated t2 in the same distributed layout.
    """
    t2T = t2.transpose(2, 3, 0, 1)
    nvir_seg, nvir, nocc = t2T.shape[:3]
    if vlocs is None:
        vlocs = [_task_location(nvir, tid) for tid in range(mpi.pool.size)]
    a0, a1 = vlocs[rank]
    assert a1 - a0 == nvir_seg

    u_oo = u[:nocc, :nocc]
    u_vv = u[nocc:, nocc:]

    # The distributed first virtual index: every rank's block contributes to
    # this rank's output segment [a0, a1).
    rotated = 0.0
    for _tid, blk, q0, q1 in _rotate_vir_block(t2T, vlocs=vlocs):
        rotated += np.einsum('aA, abij -> Abij', u_vv[q0:q1, a0:a1], blk,
                             optimize=True)
        del blk
    del t2, t2T

    # The other three indices are held in full locally.
    rotated = np.einsum("Abij, bB, iI, jJ -> ABIJ", rotated, u_vv,
                        u_oo, u_oo, optimize=True)
    return rotated.transpose(2, 3, 0, 1)
Beispiel #8
0
def cc_Fov(t1T, eris, vlocs=None):
    """
    Fov: me.

    F_me = f_me + sum_{fn} xvoo[e, f, m, n] * t1T[f, n].

    Args:
        t1T: (nvir, nocc) transposed singles.
        eris: integrals; uses ``eris.fock`` and the rank-local ``eris.xvoo``
            block (first index is this rank's virtual segment).
        vlocs: accepted for interface compatibility with the other cc_*
            builders; not needed here.

    Returns:
        (nocc, nvir) Fov, assembled from all ranks' segments via mpi.allgather.
    """
    nocc = t1T.shape[1]
    # Each rank contracts its own e-segment. allgather presumably stacks the
    # segments into the full (e, m) matrix, which is transposed to (m, e).
    Fme = einsum_mv('efmn, fn -> em', eris.xvoo, t1T)
    Fme = mpi.allgather(Fme).T
    Fme += eris.fock[:nocc, nocc:]
    return Fme
Beispiel #9
0
 def vector_size(self, nmo=None, nocc=None):
     """
     Length of this rank's slice of the flattened amplitude vector.

     The slice holds the packed t2 elements for this rank's virtual pairs
     (from get_vtril) times the nocc*(nocc-1)/2 occupied pairs. Rank 0 also
     holds the nocc*nvir t1 entries.
     """
     nocc = self.nocc if nocc is None else nocc
     nmo = self.nmo if nmo is None else nmo
     nvir = nmo - nocc
     vlocs = [_task_location(nvir, tid) for tid in range(mpi.pool.size)]
     n_vpair = len(get_vtril(nvir, vlocs[rank])[0])
     n_opair = nocc * (nocc - 1) // 2
     n_t1 = nocc * nvir if rank == 0 else 0
     return n_vpair * n_opair + n_t1
Beispiel #10
0
def cc_Woooo(t1T, t2T, eris, tauT=None, vlocs=None):
    """
    Woooo intermediate W_{mnij}, full (nocc, nocc, nocc, nocc).

    W = eris.oooo + 1/4 * sum_ef xvoo[e,f,m,n] * tauT[e,f,i,j]. The e sum
    covers this rank's segment and is completed with mpi.allreduce_inplace.
    The t1-dependent correction is not included.
    """
    nvir = t1T.shape[0]
    if vlocs is None:
        vlocs = [_task_location(nvir, tid) for tid in range(mpi.pool.size)]
    if tauT is None:
        tauT = make_tauT(t1T, t2T, vlocs=vlocs)

    w_mnij = einsum('efmn, efij -> mnij', eris.xvoo, tauT)
    w_mnij *= 0.25
    w_mnij = mpi.allreduce_inplace(w_mnij)
    w_mnij += eris.oooo
    return w_mnij
Beispiel #11
0
def vector_to_amplitudes(vector, nmo, nocc):
    """
    vector to amps, with the same behavior as pyscf gccsd.

    Inverse of amplitudes_to_vector. On rank 0 the vector begins with the
    nvir*nocc entries of t1T, followed by this rank's packed t2 elements.
    Other ranks hold only packed t2 elements. t1 is broadcast from rank 0.

    Args:
        vector: this rank's slice of the flattened amplitude vector.
        nmo: number of orbitals.
        nocc: number of occupied orbitals.

    Returns:
        (t1, t2): t1 as (nocc, nvir); t2 as (nocc, nocc, [nvir_seg], nvir),
        with the third index restricted to this rank's virtual segment.
    """
    nvir = nmo - nocc
    nov = nocc * nvir
    nocc2 = nocc * (nocc - 1) // 2
    otril = np.tril_indices(nocc, k=-1)

    vlocs = [_task_location(nvir, task_id) for task_id in range(mpi.pool.size)]
    vloc0, vloc1 = vlocs[rank]
    nvir_seg = vloc1 - vloc0
    vtril = get_vtril(nvir, vlocs[rank])
    nvir2 = len(vtril[0])

    if rank == 0:
        t1T = vector[:nov].copy().reshape(nvir, nocc)
        mpi.bcast(t1T)
        t2tril = vector[nov:].reshape(nvir2, nocc2)
    else:
        t1T = mpi.bcast(None)
        t2tril = vector.reshape(nvir2, nocc2)
    # Unpack the stored (virtual pair, i>j) elements into a dense
    # (nvir_seg*nvir, nocc*nocc) buffer.
    t2T = np.zeros((nvir_seg * nvir, nocc**2), dtype=t2tril.dtype)
    lib.takebak_2d(t2T, t2tril, vtril[0]*nvir+vtril[1], otril[0]*nocc+otril[1])
    # anti-symmetry when exchanging two particle indices
    lib.takebak_2d(t2T, -t2tril, vtril[0]*nvir+vtril[1], otril[1]*nocc+otril[0])
    t2T = t2T.reshape(nvir_seg, nvir, nocc, nocc)

    # Remaining virtual pairs are filled from their transposed partners,
    # using t2T[b, a, j, i] = t2T[a, b, i, j]. The partner of a local row a
    # with column b in rank q's segment is held by rank q, hence the alltoall.
    t2tmp = mpi.alltoall_new([t2T[:, p0:p1] for p0,p1 in vlocs], split_recvbuf=True)
    for task_id, (p0, p1) in enumerate(vlocs):
        if task_id < rank:
            # do not need this part since it is already filled.
            continue
        elif task_id == rank:
            # fill the trlu by -tril.
            # Diagonal block (both virtual indices local). NOTE(review): the
            # exact index convention of get_vtril(..., p0=p0, p1=p1) is not
            # visible here; the assignment swaps the virtual pair and, via
            # transpose(0, 2, 1), the occupied pair.
            v_idx = get_vtril(nvir, vlocs[task_id], p0=p0, p1=p1)
            tmp = t2tmp[task_id].reshape(p1-p0, nvir_seg, nocc, nocc)
            t2T[v_idx[1]-p0, v_idx[0]+p0] = tmp[v_idx[0], v_idx[1]-p0].transpose(0, 2, 1)
        else:
            # Off-diagonal block: swap (a, b) and (i, j) of the received block.
            tmp = t2tmp[task_id].reshape(p1-p0, nvir_seg, nocc, nocc)
            t2T[:, p0:p1] = tmp.transpose(1, 0, 3, 2)

    return t1T.T, t2T.transpose(2, 3, 0, 1)
Beispiel #12
0
def cc_Foo(t1T, t2T, eris, tauT_tilde=None, vlocs=None):
    """
    Foo intermediate F_{mi}, full (nocc, nocc).

    F = f_oo + 1/2 * sum_{efn} xvoo[e,f,m,n] * tauT_tilde[e,f,i,n]. The e sum
    covers this rank's segment and is completed with mpi.allreduce_inplace.
    The t1-dependent corrections are not included.
    """
    nvir, nocc = t1T.shape
    if vlocs is None:
        vlocs = [_task_location(nvir, tid) for tid in range(mpi.pool.size)]
    if tauT_tilde is None:
        tauT_tilde = make_tauT(t1T, t2T, fac=0.5, vlocs=vlocs)

    f_mi = 0.5 * einsum('efmn, efin -> mi', eris.xvoo, tauT_tilde)
    del tauT_tilde
    f_mi = mpi.allreduce_inplace(f_mi)
    f_mi += eris.fock[:nocc, :nocc]
    return f_mi
Beispiel #13
0
def restore_from_h5(mycc, fname="fcc", umat=None):
    """
    Restore t1, t2, l1, l2 from per-rank h5 files.

    Every rank looks for ``<fname>__rank<rank>.h5``. The presence check is
    gathered over all ranks, so either every rank loads or every rank raises.

    Args:
        mycc: CC object.
        fname: prefix for the filename; a trailing ".h5" is stripped.
        umat: (nmo, nmo), rotation matrix to rotate amps. None keeps the
            amplitudes as stored.

    Return:
        mycc: CC object, with t1, t2, l1, l2 updated.

    Raises:
        ValueError: if the file is missing on any rank.
    """
    _sync_(mycc)
    if fname.endswith(".h5"):
        fname = fname[:-3]
    logger.info(mycc, "restore amps from h5 ...")
    filename = fname + '__rank' + str(rank) + ".h5"
    if not all(comm.allgather(os.path.isfile(filename))):
        raise ValueError("restore_from_h5 failed, (part of) files %s__rank*.h5 "
                         "do not exist." % fname)

    # The stored mo_coeff is not needed here.
    t1, t2, l1, l2, _ = mycc.load_amps(fname=fname)
    if umat is not None:
        logger.info(mycc, "rotate amps ...")
        nvir = t1.shape[1]
        vlocs = [_task_location(nvir, task_id) for task_id in range(mpi.pool.size)]

        t1 = transform_t1_to_bo(t1, umat)
        t2 = transform_t2_to_bo(t2, umat, vlocs=vlocs)
        # load_amps may return None for the lambda amplitudes.
        if l1 is not None:
            l1 = transform_l1_to_bo(l1, umat)
        if l2 is not None:
            l2 = transform_l2_to_bo(l2, umat, vlocs=vlocs)

    mycc.t1 = t1
    mycc.t2 = t2
    mycc.l1 = l1
    mycc.l2 = l2
    return mycc
Beispiel #14
0
def cc_Wovvo(t1T, t2T, eris, vlocs=None):
    """
    W_{mb[e]j} intermediate, including t1 terms; e is this rank's segment.

    Returns:
        ndarray indexed (m, b, e_local, j).
    """
    _, nvir, _, _ = t2T.shape
    if vlocs is None:
        vlocs = [_task_location(nvir, tid) for tid in range(mpi.pool.size)]

    # t1*t1 term: contract f with -t1 first, then n with t1.
    w_mbej = einsum('efmn, fj -> mnej', eris.xvoo, -t1T)
    w_mbej = einsum('mnej, bn -> mbej', w_mbej, t1T)

    # Single-t1 terms.
    w_mbej -= einsum('efbm, fj -> mbej', eris.xvvo, t1T)
    w_mbej += einsum('bn, ejnm -> mbej', t1T, eris.xooo)

    # t2 term; f is summed over every rank's t2T block.
    for _tid, t2T_blk, f0, f1 in _rotate_vir_block(t2T, vlocs=vlocs):
        part = einsum('fbjn, efmn -> mbej', t2T_blk, eris.xvoo[:, f0:f1])
        part *= (-0.5)
        w_mbej += part
        del part, t2T_blk

    w_mbej -= np.asarray(eris.xovo).transpose(3, 2, 0, 1)
    return w_mbej
Beispiel #15
0
def cc_Fvv(t1T, t2T, eris, tauT_tilde=None, vlocs=None):
    """
    Fvv intermediate F_{ae}, full (nvir, nvir).

    F = f_vv - 1/2 * sum_{fmn} xvoo[f,e,m,n] * tauT_tilde[f,a,m,n]. The f sum
    covers this rank's segment and is completed with mpi.allreduce_inplace.
    The Fock part is gathered from each rank's row segment. The t1 terms are
    not included.
    """
    nvir, nocc = t1T.shape
    if vlocs is None:
        vlocs = [_task_location(nvir, tid) for tid in range(mpi.pool.size)]
    if tauT_tilde is None:
        tauT_tilde = make_tauT(t1T, t2T, fac=0.5, vlocs=vlocs)
    e0, e1 = vlocs[rank]

    # Fock rows of this rank's virtual segment (indexed e, a).
    f_ea_local = eris.fock[nocc + e0:nocc + e1, nocc:]

    f_ae = (-0.5) * einsum('femn, famn -> ae', eris.xvoo, tauT_tilde)
    f_ae = mpi.allreduce_inplace(f_ae)
    del tauT_tilde

    f_ae += mpi.allgather(f_ea_local).T
    return f_ae
Beispiel #16
0
def update_amps(mycc, t1, t2, eris):
    """
    Update GCCD amplitudes.

    Performs one update of the distributed doubles. The singles are not
    updated: the returned t1 is all zeros, because t1Tnew is never modified
    below.

    Args:
        mycc: CC object (level_shift, stdout, verbose are read).
        t1: (nocc, nvir) singles. They are passed to the intermediate
            builders (make_tauT, cc_F*, cc_W*) but are not themselves updated.
        t2: (nocc, nocc, [nvir_seg], nvir) doubles; the third index is this
            rank's virtual segment.
        eris: distributed integrals (fock, mo_energy, xvoo, xooo, xvvo, ...).

    Returns:
        (t1new, t2new): zeros shaped like t1, and the updated doubles in the
        same layout as t2.
    """
    time0 = logger.process_clock(), logger.perf_counter()
    log = logger.Logger(mycc.stdout, mycc.verbose)

    # Work in the transposed layout t2T[a_local, b, i, j].
    t1T = t1.T
    t2T = np.asarray(t2.transpose(2, 3, 0, 1), order='C')
    nvir_seg, nvir, nocc = t2T.shape[:3]
    t2 = None
    ntasks = mpi.pool.size
    vlocs = [_task_location(nvir, task_id) for task_id in range(ntasks)]
    vloc0, vloc1 = vlocs[rank]
    log.debug2('vlocs %s', vlocs)
    assert vloc1 - vloc0 == nvir_seg

    fock = eris.fock
    fvo = fock[nocc:, :nocc]  # only referenced by the commented-out t1 terms
    mo_e_o = eris.mo_energy[:nocc]
    mo_e_v = eris.mo_energy[nocc:] + mycc.level_shift

    tauT_tilde = make_tauT(t1T, t2T, fac=0.5, vlocs=vlocs)
    Fvv = cc_Fvv(t1T, t2T, eris, tauT_tilde=tauT_tilde, vlocs=vlocs)
    Foo = cc_Foo(t1T, t2T, eris, tauT_tilde=tauT_tilde, vlocs=vlocs)
    tauT_tilde = None
    Fov = cc_Fov(t1T, eris, vlocs=vlocs)

    # Move energy terms to the other side
    Fvv[np.diag_indices(nvir)] -= mo_e_v
    Foo[np.diag_indices(nocc)] -= mo_e_o

    # T1 equation
    # NOTE(review): t1Tnew stays zero. The residual pieces computed below
    # (tmp, tmp2, and Fov, which is only used here) are never added to it, and
    # `tmp` is overwritten in the T2 section. This block, including its
    # allgather/allreduce collectives, therefore only costs time. Presumably
    # it is kept as a scaffold for the CCSD version; if removed, all ranks
    # must drop the collectives together.
    t1Tnew = np.zeros_like(t1T)
    #t1Tnew  = np.dot(Fvv, t1T)
    #t1Tnew -= np.dot(t1T, Foo)

    tmp = einsum('aeim, me -> ai', t2T, Fov)
    #tmp -= np.einsum('fn, naif -> ai', t1T, eris.oxov, optimize=True)
    tmp = mpi.allgather(tmp)

    #tmp2  = einsum('eamn, mnie -> ai', t2T, eris.ooox)
    tmp2 = einsum('eamn, einm -> ai', t2T, eris.xooo)
    #tmp2 += einsum('efim, mafe -> ai', t2T, eris.ovvx)
    tmp2 += einsum('efim, efam -> ai', t2T, eris.xvvo)
    tmp2 *= 0.5
    tmp2 = mpi.allreduce_inplace(tmp2)
    tmp += tmp2
    tmp2 = None

    #t1Tnew += tmp
    #t1Tnew += fvo

    # T2 equation
    # Fvv term, antisymmetrized in (a, b). The (b, a) partner of each local
    # [a]b block lives on the rank that owns b, hence the alltoall.
    Ftmp = Fvv  #- 0.5 * np.dot(t1T, Fov)
    t2Tnew = einsum('aeij, be -> abij', t2T, Ftmp)
    t2T_tmp = mpi.alltoall_new([t2Tnew[:, p0:p1] for p0, p1 in vlocs],
                               split_recvbuf=True)
    for task_id, (p0, p1) in enumerate(vlocs):
        tmp = t2T_tmp[task_id].reshape(p1 - p0, nvir_seg, nocc, nocc)
        t2Tnew[:, p0:p1] -= tmp.transpose(1, 0, 2, 3)
        tmp = None
    t2T_tmp = None

    # Foo term, antisymmetrized in (i, j) locally.
    Ftmp = Foo  #+ 0.5 * np.dot(Fov, t1T)
    tmp = einsum('abim, mj -> abij', t2T, Ftmp)
    t2Tnew -= tmp
    t2Tnew += tmp.transpose(0, 1, 3, 2)
    tmp = None

    # Bare integrals plus the Woooo (occupied ladder) term.
    t2Tnew += np.asarray(eris.xvoo)
    tauT = make_tauT(t1T, t2T, vlocs=vlocs)
    Woooo = cc_Woooo(t1T, t2T, eris, tauT=tauT, vlocs=vlocs)
    Woooo *= 0.5
    t2Tnew += einsum('abmn, mnij -> abij', tauT, Woooo)
    Woooo = None

    # Wvvvv (virtual ladder) term; e is summed over every rank's tauT block.
    Wvvvv = cc_Wvvvv(t1T, t2T, eris, tauT=tauT, vlocs=vlocs)
    for task_id, tauT_tmp, p0, p1 in _rotate_vir_block(tauT, vlocs=vlocs):
        tmp = einsum('abef, efij -> abij', Wvvvv[:, :, p0:p1], tauT_tmp)
        tmp *= 0.5
        t2Tnew += tmp
        tmp = tauT_tmp = None
    Wvvvv = None
    tauT = None

    #tmp = einsum('mbje, ei -> bmij', eris.oxov, t1T) # [b]mij
    #tmp = mpi.allgather(tmp) # bmij
    #tmp = einsum('am, bmij -> abij', t1T[vloc0:vloc1], tmp) # [a]bij
    tmp = 0.0

    # Wovvo term: e is summed over every rank's Wvovo ([e]mbj) block. P(ij) is
    # applied locally and P(ab) via alltoall, as above.
    Wvovo = cc_Wovvo(t1T, t2T, eris, vlocs=vlocs).transpose(2, 0, 1, 3)
    for task_id, w_tmp, p0, p1 in _rotate_vir_block(Wvovo, vlocs=vlocs):
        tmp += einsum('aeim, embj -> abij', t2T[:, p0:p1], w_tmp)
        w_tmp = None
    Wvovo = None

    tmp = tmp - tmp.transpose(0, 1, 3, 2)
    t2Tnew += tmp
    tmpT = mpi.alltoall_new([tmp[:, p0:p1] for p0, p1 in vlocs],
                            split_recvbuf=True)
    for task_id, (p0, p1) in enumerate(vlocs):
        tmp = tmpT[task_id].reshape(p1 - p0, nvir_seg, nocc, nocc)
        t2Tnew[:, p0:p1] -= tmp.transpose(1, 0, 2, 3)
        tmp = None
    tmpT = None

    #tmp = einsum('ei, jeba -> abij', t1T, eris.ovvx)
    #t2Tnew += tmp
    #t2Tnew -= tmp.transpose(0, 1, 3, 2)

    #tmp = einsum('am, ijmb -> baij', t1T, eris.ooox.conj())
    #t2Tnew += tmp
    #tmpT = mpi.alltoall([tmp[:, p0:p1] for p0, p1 in vlocs],
    #                    split_recvbuf=True)
    #for task_id, (p0, p1) in enumerate(vlocs):
    #    tmp = tmpT[task_id].reshape(p1-p0, nvir_seg, nocc, nocc)
    #    t2Tnew[:, p0:p1] -= tmp.transpose(1, 0, 2, 3)
    #    tmp = None
    #tmpT = None

    # Divide by eia[i, a] + eia[j, b] (level-shifted virtual energies).
    # Note: the loop variable `i` is a global *virtual* index (this rank's a).
    eia = mo_e_o[:, None] - mo_e_v
    #t1Tnew /= eia.T
    for i in range(vloc0, vloc1):
        t2Tnew[i - vloc0] /= lib.direct_sum('i + jb -> bij', eia[:, i], eia)

    time0 = log.timer_debug1('update t1 t2', *time0)
    return t1Tnew.T, t2Tnew.transpose(2, 3, 0, 1)
Beispiel #17
0
def _make_eris_incore_ghf(mycc, mo_coeff=None, ao2mofn=None):
    """
    Make physicist eri with incore ao2mo, for GGHF.

    Rank 0 builds the integral blocks. oooo is broadcast to every rank. The
    virtual-leading blocks (xooo, xovo, xvoo, xovv, xvvo, xvvv) are sliced by
    vlocs along their first index and scattered, so each rank stores only its
    own virtual segment.

    Args:
        mycc: CC object (uses _scf._eri, save_mem, remove_h2, mol).
        mo_coeff: orbitals passed to eris._common_init_ on rank 0.
        ao2mofn: custom transform; not supported (raises NotImplementedError
            on rank 0).

    Returns:
        The eris object, also stored as ``mycc._eris``.
    """
    cput0 = (logger.process_clock(), logger.perf_counter())
    log = logger.Logger(mycc.stdout, mycc.verbose)
    _sync_(mycc)
    eris = gccsd._PhysicistsERIs()

    # Rank 0 sets up orbitals/Fock and shares them with the other ranks.
    if rank == 0:
        eris._common_init_(mycc, mo_coeff)
        comm.bcast((eris.mo_coeff, eris.fock, eris.nocc, eris.mo_energy))
    else:
        eris.mol = mycc.mol
        eris.mo_coeff, eris.fock, eris.nocc, eris.mo_energy = comm.bcast(None)

    nocc = eris.nocc
    nao, nmo = eris.mo_coeff.shape

    nvir = nmo - nocc
    vlocs = [_task_location(nvir, task_id) for task_id in range(mpi.pool.size)]
    vloc0, vloc1 = vlocs[rank]
    vseg = vloc1 - vloc0  # not used below

    # Choose the integral-block accessor `fn(eri, p, q, r, s)` and the
    # orbital selectors o/v. These are index arrays for take_eri and
    # coefficient blocks for ao2mo.general. fn, o, v and eri exist only on
    # rank 0 and are only used inside rank-0 branches below.
    if rank == 0:
        if callable(ao2mofn):
            raise NotImplementedError
        else:
            assert eris.mo_coeff.dtype == np.double
            eri = mycc._scf._eri
            if (nao == nmo) and (la.norm(eris.mo_coeff - np.eye(nmo)) < 1e-12):
                # ZHC NOTE special treatment for OO-CCD,
                # where the ao2mo is not needed for identity mo_coeff.
                from libdmet.utils import take_eri as fn
                o = np.arange(0, nocc)
                v = np.arange(nocc, nmo)
                if eri.size == nmo**4:
                    eri = ao2mo.restore(8, eri, nmo)
            else:
                if mycc.save_mem:
                    # ZHC NOTE the following is slower, although may save some memory.
                    def fn(x, mo0, mo1, mo2, mo3):
                        return ao2mo.general(x, (mo0, mo1, mo2, mo3),
                                             compact=False).reshape(mo0.shape[-1], mo1.shape[-1],
                                                                    mo2.shape[-1], mo3.shape[-1])
                    o = eris.mo_coeff[:, :nocc]
                    v = eris.mo_coeff[:, nocc:]
                    if eri.size == nao**4:
                        eri = ao2mo.restore(8, eri, nao)
                else:
                    # Transform the whole eri to the MO basis once, then take
                    # sub-blocks by index.
                    from libdmet.utils import take_eri as fn
                    o = np.arange(0, nocc)
                    v = np.arange(nocc, nmo)
                    if mycc.remove_h2:
                        mycc._scf._eri = None
                        _release_regs(mycc, remove_h2=True)
                    eri = ao2mo.kernel(eri, eris.mo_coeff)
                    if eri.size == nmo**4:
                        eri = ao2mo.restore(8, eri, nmo)

    comm.Barrier()
    cput2 = log.timer('CCSD ao2mo initialization:     ', *cput0)

    # Chunk and scatter. Assuming fn returns chemist-notation blocks (pq|rs),
    # each `tmp.transpose(0, 2, 1, 3) - tmp.transpose(0, 2, 3, 1)` forms the
    # antisymmetrized physicist integrals <pq||rs> = <pq|rs> - <pq|sr>.

    # 1. oooo (replicated on every rank)
    if rank == 0:
        tmp = fn(eri, o, o, o, o)
        eris.oooo = tmp.transpose(0, 2, 1, 3) - tmp.transpose(0, 2, 3, 1)
        tmp = None
        mpi.bcast(eris.oooo)
    else:
        eris.oooo = mpi.bcast(None)
    cput3 = log.timer('CCSD bcast   oooo:              ', *cput2)

    # 2. xooo
    if rank == 0:
        tmp = fn(eri, v, o, o, o)
        eri_sliced = [tmp[p0:p1] for (p0, p1) in vlocs]
    else:
        tmp = None
        eri_sliced = None
    tmp = mpi.scatter_new(eri_sliced, root=0, data=tmp)
    eri_sliced = None
    eris.xooo = tmp.transpose(0, 2, 1, 3) - tmp.transpose(0, 2, 3, 1)
    tmp = None
    cput4 = log.timer('CCSD scatter xooo:              ', *cput3)

    # 3. xovo (needs both the vvoo and voov blocks)
    if rank == 0:
        tmp_vvoo = fn(eri, v, v, o, o)
        tmp_voov = fn(eri, v, o, o, v)
        # ZHC NOTE need to keep tmp_voov for xvoo
        eri_1 = [tmp_vvoo[p0:p1] for (p0, p1) in vlocs]
        eri_2 = [tmp_voov[p0:p1] for (p0, p1) in vlocs]
    else:
        tmp_vvoo = None
        tmp_voov = None
        eri_1 = None
        eri_2 = None

    tmp_1 = mpi.scatter_new(eri_1, root=0, data=tmp_vvoo)
    eri_1 = None
    tmp_vvoo = None

    tmp_2 = mpi.scatter_new(eri_2, root=0, data=tmp_voov)
    eri_2 = None
    tmp_voov = None

    eris.xovo = tmp_1.transpose(0, 2, 1, 3) - tmp_2.transpose(0, 2, 3, 1)
    tmp_1 = None
    cput5 = log.timer('CCSD scatter xovo:              ', *cput4)

    # 4. xvoo (reuses the scattered voov block; no extra communication)
    eris.xvoo = tmp_2.transpose(0, 3, 1, 2) - tmp_2.transpose(0, 3, 2, 1)
    tmp_2 = None
    cput6 = log.timer('CCSD scatter xvoo:              ', *cput5)

    # 5. 6. xovv, xvvo (both from the vvov block; the second is its full
    # transpose, re-sliced along its new first index)
    if rank == 0:
        tmp = fn(eri, v, v, o, v)
        eri_sliced = [tmp[p0:p1] for (p0, p1) in vlocs]
    else:
        tmp = None
        eri_sliced = None
    tmp_1 = mpi.scatter_new(eri_sliced, root=0, data=tmp)
    eri_sliced = None
    eris.xovv = tmp_1.transpose(0, 2, 1, 3) - tmp_1.transpose(0, 2, 3, 1)

    if rank == 0:
        tmp_2 = np.asarray(tmp.transpose(3, 2, 1, 0), order='C') # vovv
        tmp = None
        eri_sliced = [tmp_2[p0:p1] for (p0, p1) in vlocs]
    else:
        tmp_2 = None
        tmp = None
        eri_sliced = None
    tmp_2 = mpi.scatter_new(eri_sliced, root=0, data=tmp_2)
    eri_sliced = None

    eris.xvvo = tmp_1.transpose(0, 3, 1, 2) - tmp_2.transpose(0, 2, 3, 1)
    tmp_1 = None
    tmp_2 = None
    cput7 = log.timer('CCSD scatter xovv, xvvo:        ', *cput6)

    # 7. xvvv (last block; the source eri can be released afterwards)
    if rank == 0:
        tmp = fn(eri, v, v, v, v)
        if mycc.remove_h2:
            eri = None
            if mycc._scf is not None:
                mycc._scf._eri = None
        eri_sliced = [tmp[p0:p1] for (p0, p1) in vlocs]
    else:
        tmp = None
        eri_sliced = None
    tmp = mpi.scatter_new(eri_sliced, root=0, data=tmp)
    eri_sliced = None
    eris.xvvv = tmp.transpose(0, 2, 1, 3) - tmp.transpose(0, 2, 3, 1)
    tmp = None
    eri = None
    cput8 = log.timer('CCSD scatter xvvv:              ', *cput7)

    mycc._eris = eris
    log.timer('CCSD integral transformation   ', *cput0)
    return eris
Beispiel #18
0
def make_intermediates(mycc, t1, t2, eris):
    """
    Build the intermediates used by the distributed GCCD lambda equations.

    t1-dependent contributions are omitted (tau is taken to be t2 itself);
    the corresponding CCSD terms are kept as comments.

    Args:
        mycc: CC object (not referenced in the body).
        t1: (nocc, nvir) singles; only their dtype is used.
        t2: (nocc, nocc, [nvir_seg], nvir) doubles; the third index is this
            rank's virtual segment.
        eris: distributed integrals.

    Returns:
        An ad-hoc _IMDS object with in-memory v1 (nvir, nvir), v2
        (nocc, nocc) and w3 (nvir, nocc). woooo, wovvo, wovoo and wvvvo are
        created as datasets in a temporary h5 file (imds.ftmp); see the note
        on wvvvo at the end.
    """
    t1T = t1.T
    t2T = np.asarray(t2.transpose(2, 3, 0, 1), order='C')
    nvir_seg, nvir, nocc = t2T.shape[:3]
    t1 = t2 = None
    ntasks = mpi.pool.size
    vlocs = [_task_location(nvir, task_id) for task_id in range(ntasks)]
    vloc0, vloc1 = vlocs[rank]
    assert vloc1 - vloc0 == nvir_seg

    class _IMDS:
        pass

    # Large intermediates go to a temporary h5 file.
    imds = _IMDS()
    imds.ftmp = lib.H5TmpFile()
    dtype = t1T.dtype
    imds.woooo = imds.ftmp.create_dataset('woooo', (nocc, nocc, nocc, nocc),
                                          dtype)
    imds.wovvo = imds.ftmp.create_dataset('wovvo',
                                          (nocc, nvir_seg, nvir, nocc), dtype)
    imds.wovoo = imds.ftmp.create_dataset('wovoo',
                                          (nocc, nvir_seg, nocc, nocc), dtype)
    imds.wvvvo = imds.ftmp.create_dataset('wvvvo',
                                          (nvir_seg, nvir, nvir, nocc), dtype)

    foo = eris.fock[:nocc, :nocc]
    fov = eris.fock[:nocc, nocc:]
    fvo = eris.fock[nocc:, :nocc]
    fvv = eris.fock[nocc:, nocc:]

    #tauT = np.einsum('ai, bj -> abij', t1T[vloc0:vloc1] * 2.0, t1T, optimize=True)
    tauT = t2T

    # v1 (vv) and v4 (j[c]bk) share one pass over every rank's voov block.
    v1 = np.array(fvv, copy=True)  #- np.dot(t1T, fov)
    #tmp = einsum('jbac, cj -> ba', eris.oxvv, t1T)
    tmp = 0.0
    v4 = 0.0

    eris_voov = eris.xvoo.transpose(0, 2, 3, 1)
    for task_id, eri_tmp, p0, p1 in _rotate_vir_block(eris_voov, vlocs=vlocs):
        tmp += einsum('cjka, bcjk -> ba', eri_tmp, tauT[:, p0:p1])
        v4 += einsum('dljb, cdkl -> jcbk', eri_tmp, t2T[:, p0:p1])
        eri_tmp = None
    eris_voov = None
    tmp *= 0.5
    v1 += mpi.allgather(tmp)

    # v2 (oo): local contraction over this rank's b, summed over ranks.
    v2 = np.array(foo, copy=True)  #+ np.dot(fov, t1T)
    #tmp  = einsum('kijb, bk -> ij', eris.ooox, t1T[vloc0:vloc1])
    tmp = einsum('bcik, bcjk -> ij', eris.xvoo, tauT)
    tmp *= 0.5
    v2 += mpi.allreduce_inplace(tmp)

    #v4 -= np.asarray(eris.oxov).transpose(0, 1, 3, 2)
    v4 -= np.asarray(eris.xovo).transpose(1, 0, 2, 3)

    # v5 (vo), later used as w3.
    v5 = fvo + mpi.allgather(einsum('kc, bcjk -> bj', fov, t2T))
    #tmp = fvo[vloc0:vloc1] #+ einsum('cdkl, dl -> ck', eris.xvoo, t1T)
    #v5 += mpi.allreduce(np.einsum('ck, bk, cj -> bj', tmp, t1T, t1T[vloc0:vloc1], optimize=True))

    #v5 += mpi.allreduce(einsum('kljc, cbkl -> bj', eris.ooox, t2T)) * 0.5
    v5 += mpi.allreduce_inplace(
        einsum('cjlk, cbkl -> bj', eris.xooo, t2T) * 0.5)
    tmp = 0.0
    # ZHC NOTE FIXME it seems that the tmp does not contribute to rdm
    for task_id, t2T_tmp, p0, p1 in _rotate_vir_block(t2T, vlocs=vlocs):
        #tmp += einsum('kbcd, cdjk -> bj', eris.oxvv[:, :, p0:p1], t2T_tmp)
        tmp -= einsum('bkcd, cdjk -> bj', eris.xovv[:, :, p0:p1], t2T_tmp)
        t2T_tmp = None
    tmp *= 0.5
    tmp = mpi.allgather(tmp)
    v5 -= tmp

    #w3  = np.array(v5[vloc0:vloc1], copy=True) #+ einsum('jcbk, bj -> ck', v4, t1T)
    #w3 += np.dot(v1[vloc0:vloc1], t1T)
    #w3 -= np.dot(t1T[vloc0:vloc1], v2)
    #w3  = mpi.allgather(w3)
    w3 = v5

    # woooo: summed over ranks, then written to the h5 dataset.
    woooo = einsum('cdij, cdkl -> ijkl', eris.xvoo, tauT)
    woooo *= 0.25
    #woooo += einsum('jilc, ck -> jilk', eris.ooox, t1T[vloc0:vloc1])
    woooo = mpi.allreduce_inplace(woooo)
    woooo += np.asarray(eris.oooo) * 0.5
    imds.woooo[:] = woooo
    woooo = None

    # ZHC NOTE: wovvo, v4 has shape j[c]bk
    wovvo = v4  #+ einsum('jcbd, dk -> jcbk', eris.oxvv, t1T)

    #tmp = einsum('bdlj, dk -> bklj', eris.xvoo, t1T)
    #for task_id, tmp_2, p0, p1 in _rotate_vir_block(tmp, vlocs=vlocs):
    #    wovvo[:, :, p0:p1] += einsum('bklj, cl -> jcbk', tmp_2, t1T[vloc0:vloc1])
    #    tmp_2 = None
    #tmp = None

    #eris_vooo = eris.ooox.transpose(3, 2, 1, 0)
    #for task_id, eri_tmp, p0, p1 in _rotate_vir_block(eris_vooo, vlocs=vlocs):
    #    wovvo[:, :, p0:p1] -= einsum('bkjl, cl -> jcbk', eri_tmp, t1T[vloc0:vloc1])
    #    eri_tmp = None
    #eris_vooo = None
    imds.wovvo[:] = wovvo
    wovvo = None

    # wovoo: d is summed over every rank's tauT block.
    wovoo = 0.0
    for task_id, tauT_tmp, p0, p1 in _rotate_vir_block(tauT, vlocs=vlocs):
        #wovoo += einsum('icdb, dbjk -> icjk', eris.oxvv[:, :, p0:p1], tauT_tmp)
        wovoo -= einsum('cidb, dbjk -> icjk', eris.xovv[:, :, p0:p1], tauT_tmp)
        tauT_tmp = None
    wovoo *= 0.25

    #wovoo += np.asarray(eris.ooox.transpose(2, 3, 0, 1)) * 0.5
    wovoo += np.asarray(eris.xooo.transpose(1, 0, 2, 3)) * (-0.5)
    #wovoo += einsum('icbk, bj -> icjk', v4, t1T)

    # One pass over every rank's vooo block finishes wovoo and fills the
    # first wvvvo contribution column block by column block.
    tauT = tauT * 0.25
    #eris_vooo = eris.ooox.transpose(3, 0, 1, 2)
    eris_vooo = eris.xooo.transpose(0, 3, 2, 1)
    for task_id, eri_tmp, p0, p1 in _rotate_vir_block(eris_vooo, vlocs=vlocs):
        wovoo -= einsum('blij, cbkl -> icjk', eri_tmp, t2T[:, p0:p1])
        imds.wvvvo[:, :, p0:p1] = einsum('bcjl, ajlk -> bcak', tauT, eri_tmp)
        eri_tmp = None
    eris_vooo = None
    imds.wovoo[:] = wovoo
    wovoo = None
    tauT = None

    #v4 = v4.transpose(1, 0, 2, 3)
    #for task_id, v4_tmp, p0, p1 in _rotate_vir_block(v4, vlocs=vlocs):
    #    imds.wvvvo[:, p0:p1] += einsum('bj, cjak -> bcak', t1T[vloc0:vloc1], v4_tmp)
    #    v4_tmp = None
    #v4 = None

    #wvvvo = np.asarray(eris.ovvx).conj().transpose(3, 2, 1, 0) * 0.5
    wvvvo = np.asarray(eris.xvvo) * 0.5
    #eris_ovvv = eris.oxvv
    #for task_id, t2T_tmp, p0, p1 in _rotate_vir_block(t2T, vlocs=vlocs):
    #    wvvvo[:, p0:p1] -= einsum('kbad, cdjk -> bcaj', eris_ovvv, t2T_tmp)
    #    t2T_tmp = None
    eris_vovv = eris.xovv
    for task_id, t2T_tmp, p0, p1 in _rotate_vir_block(t2T, vlocs=vlocs):
        wvvvo[:, p0:p1] -= einsum('bkda, cdjk -> bcaj', eris_vovv, t2T_tmp)
        t2T_tmp = None

    # NOTE(review): imds.wvvvo is an h5py dataset at this point. `-=` is not an
    # in-place dataset update; presumably it falls back to ndarray.__rsub__
    # and rebinds imds.wvvvo to an in-memory ndarray. Confirm this is intended
    # (an explicit `imds.wvvvo[:] = imds.wvvvo[:] - wvvvo` would stay on disk).
    imds.wvvvo -= wvvvo
    wvvvo = None

    imds.v1 = v1
    imds.v2 = v2
    imds.w3 = w3
    imds.ftmp.flush()
    return imds
Beispiel #19
0
def update_lambda(mycc, t1, t2, l1, l2, eris, imds):
    """
    Update GCCSD lambda.

    MPI-distributed update of the lambda amplitudes. Each rank holds the
    slice of t2 / l2 whose third axis is its own virtual segment
    [vloc0, vloc1). After transposition, t2T and l2T have shape
    (nvir_seg, nvir, nocc, nocc). The assert below checks that the local
    segment size matches `_task_location`.

    Only the l2 equations are active. Every term involving t1 or l1 is
    commented out, so the returned l1 is all zeros (same shape as l1).

    Args:
        mycc : CC object; reads `stdout`, `verbose` and `level_shift`.
        t1, t2 : current T amplitudes. t1 is only referenced by the
            commented-out terms.
        l1, l2 : current lambda amplitudes; l2 is this rank's slice.
        eris : integrals; reads `fock`, `mo_energy`, `xvoo` and `xvvv`.
        imds : intermediates; reads `v1`, `v2`, `woooo` and `wovvo`.

    Returns:
        (l1new, l2new) in the same index order as (l1, l2). l2new has been
        divided by the orbital-energy denominators; l1new is zero.
    """
    time0 = logger.process_clock(), logger.perf_counter()
    log = logger.Logger(mycc.stdout, mycc.verbose)

    # Work in the transposed layout: t1T/l1T are (vir, occ) and t2T/l2T are
    # (vir_local, vir, occ, occ), made C-contiguous.
    t1T = t1.T
    t2T = np.asarray(t2.transpose(2, 3, 0, 1), order='C')
    t1 = t2 = None
    nvir_seg, nvir, nocc = t2T.shape[:3]
    l1T = l1.T
    l2T = np.asarray(l2.transpose(2, 3, 0, 1), order='C')
    l1 = l2 = None

    # Virtual-index partition over ranks; this rank owns [vloc0, vloc1).
    ntasks = mpi.pool.size
    vlocs = [_task_location(nvir, task_id) for task_id in range(ntasks)]
    vloc0, vloc1 = vlocs[rank]
    log.debug2('vlocs %s', vlocs)
    assert vloc1 - vloc0 == nvir_seg

    # fvo is only referenced by the commented-out l1/t1 terms below.
    fvo = eris.fock[nocc:, :nocc]
    mo_e_o = eris.mo_energy[:nocc]
    mo_e_v = eris.mo_energy[nocc:] + mycc.level_shift
    # Remove the diagonal orbital energies from the one-body intermediates.
    # The same energies form the denominators (eia) applied at the end.
    v1 = imds.v1 - np.diag(mo_e_v)
    v2 = imds.v2 - np.diag(mo_e_o)

    # l2 . t2 intermediates. The leading virtual index c is distributed, so
    # the allreduce completes the contraction across ranks.
    mba = einsum('cakl, cbkl -> ba', l2T, t2T) * 0.5
    mba = mpi.allreduce_inplace(mba)
    mij = einsum('cdki, cdkj -> ij', l2T, t2T) * 0.5
    mij = mpi.allreduce_inplace(mij)
    # m3 [a]bij: same layout as l2T; the first index [a] is this rank's segment.
    m3 = einsum('abkl, ijkl -> abij', l2T, np.asarray(imds.woooo))

    # tau reduces to t2 here because the t1*t1 term is commented out.
    tauT = t2T  #+ np.einsum('ai, bj -> abij', t1T[vloc0:vloc1] * 2.0, t1T, optimize=True)
    # l2 . tau over the (distributed) virtual pair, summed across ranks.
    tmp = einsum('cdij, cdkl -> ijkl', l2T, tauT)
    tmp = mpi.allreduce_inplace(tmp)
    tauT = None

    vvoo = np.asarray(eris.xvoo)
    tmp = einsum('abkl, ijkl -> abij', vvoo, tmp)
    tmp *= 0.25
    m3 += tmp
    tmp = None

    #tmp  = einsum('cdij, dk -> ckij', l2T, t1T)
    #for task_id, tmp, p0, p1 in _rotate_vir_block(tmp, vlocs=vlocs):
    #    m3 -= einsum('kcba, ckij -> abij', eris.ovvx[:, p0:p1], tmp)
    #    tmp = None
    # Reorder the local xvvv block so that its first axis can be sliced by
    # another rank's virtual range [p0, p1).
    eris_vvvv = eris.xvvv.transpose(2, 3, 0, 1)
    # tmp_2 holds the (b, a)-ordered l2T . v1 product; it is consumed below
    # as `tmp = -tmp_2`.
    tmp_2 = np.empty_like(l2T)
    # Each iteration presumably supplies the l2T slice owned by task_id,
    # covering virtual range [p0, p1) (inferred from the slicing here).
    for task_id, l2T_tmp, p0, p1 in _rotate_vir_block(l2T, vlocs=vlocs):
        tmp = einsum('cdij, cdab -> abij', l2T_tmp, eris_vvvv[p0:p1])
        tmp *= 0.5
        m3 += tmp
        tmp_2[:, p0:p1] = einsum('acij, cb -> baij', l2T_tmp, v1[:,
                                                                 vloc0:vloc1])
        tmp = l2T_tmp = None
    eris_vvvv = None

    #l1Tnew = einsum('abij, bj -> ai', m3, t1T)
    #l1Tnew = mpi.allgather(l1Tnew)
    # The active equations never touch l1, so it stays zero.
    l1Tnew = np.zeros_like(l1T)
    # NOTE: l2Tnew aliases m3; the in-place updates below modify m3 as well.
    l2Tnew = m3

    l2Tnew += vvoo
    #fvo1 = fvo #+ mpi.allreduce(einsum('cbkj, ck -> bj', vvoo, t1T[vloc0:vloc1]))

    #tmp  = np.einsum('ai, bj -> abij', l1T[vloc0:vloc1], fvo1, optimize=True)
    # wovvo term. tmp starts as a float and becomes an array after the first
    # accumulation.
    tmp = 0.0
    wvovo = np.asarray(imds.wovvo).transpose(1, 0, 2, 3)
    for task_id, w_tmp, p0, p1 in _rotate_vir_block(wvovo, vlocs=vlocs):
        tmp -= einsum('acki, cjbk -> abij', l2T[:, p0:p1], w_tmp)
        w_tmp = None
    wvovo = None
    # Antisymmetrize in (i, j) locally ...
    tmp = tmp - tmp.transpose(0, 1, 3, 2)
    l2Tnew += tmp
    # ... and in (a, b). The first virtual index is distributed, so the
    # blocks tmp[:, p0:p1] are exchanged between ranks (assuming the i-th
    # list entry is sent to rank i). Each received block, of shape
    # (p1-p0, nvir_seg, nocc, nocc), is transposed and subtracted from the
    # matching column range of l2Tnew.
    tmpT = mpi.alltoall_new([tmp[:, p0:p1] for p0, p1 in vlocs],
                            split_recvbuf=True)
    for task_id, (p0, p1) in enumerate(vlocs):
        tmp = tmpT[task_id].reshape(p1 - p0, nvir_seg, nocc, nocc)
        l2Tnew[:, p0:p1] -= tmp.transpose(1, 0, 2, 3)
        tmp = None

    #tmp  = einsum('ak, ijkb -> baij', l1T, eris.ooox)
    #tmp -= tmp_2
    # v1 and mba terms; same distributed (a, b) antisymmetrization as above.
    tmp = -tmp_2
    tmp1vv = mba  #+ np.dot(t1T, l1T.T) # ba

    tmp -= einsum('ca, bcij -> baij', tmp1vv, vvoo)
    l2Tnew += tmp
    tmpT = mpi.alltoall_new([tmp[:, p0:p1] for p0, p1 in vlocs],
                            split_recvbuf=True)
    for task_id, (p0, p1) in enumerate(vlocs):
        tmp = tmpT[task_id].reshape(p1 - p0, nvir_seg, nocc, nocc)
        l2Tnew[:, p0:p1] -= tmp.transpose(1, 0, 2, 3)
        tmp = None

    #tmp  = einsum('jcab, ci -> baji', eris.ovvx, -l1T)
    # v2 and mij terms. Both occupied indices are local, so the (i, j)
    # antisymmetrization needs no communication.
    tmp = einsum('abki, jk -> abij', l2T, v2)
    tmp1oo = mij  #+ np.dot(l1T.T, t1T) # ik
    tmp -= einsum('ik, abkj -> abij', tmp1oo, vvoo)
    vvoo = None
    l2Tnew += tmp
    l2Tnew -= tmp.transpose(0, 1, 3, 2)
    tmp = None

    # The l1 equations below are disabled (kept for reference).
    #l1Tnew += fvo
    #tmp = einsum('bj, ibja -> ai', -l1T[vloc0:vloc1], eris.oxov)
    #l1Tnew += np.dot(v1.T, l1T)
    #l1Tnew -= np.dot(l1T, v2.T)
    #tmp -= einsum('cakj, icjk -> ai', l2T, imds.wovoo)
    #tmp -= einsum('bcak, bcik -> ai', imds.wvvvo, l2T)
    #tmp += einsum('baji, bj -> ai', l2T, imds.w3[vloc0:vloc1])

    #tmp_2  = t1T[vloc0:vloc1] - np.dot(tmp1vv[vloc0:vloc1], t1T)
    #tmp_2 -= np.dot(t1T[vloc0:vloc1], mij)
    #tmp_2 += einsum('bcjk, ck -> bj', t2T, l1T)

    #tmp += einsum('baji, bj -> ai', vvoo, tmp_2)
    #tmp_2 = None

    #tmp += einsum('icab, bc -> ai', eris.oxvv, tmp1vv[:, vloc0:vloc1])
    #l1Tnew += mpi.allreduce(tmp)
    #l1Tnew -= mpi.allgather(einsum('jika, kj -> ai', eris.ooox, tmp1oo))

    #tmp = fvo - mpi.allreduce(einsum('bakj, bj -> ak', vvoo, t1T[vloc0:vloc1]))
    #vvoo = None
    #l1Tnew -= np.dot(tmp, mij.T)
    #l1Tnew -= np.dot(mba.T, tmp)

    # eia[i, a] = e_i - e_a (with level shift on the virtuals).
    eia = mo_e_o[:, None] - mo_e_v
    #l1Tnew /= eia.T
    # Divide each local a-slice by D[b, i, j] = eia[i, a] + eia[j, b].
    # The loop variable `i` here is a virtual index, not the einsum label.
    for i in range(vloc0, vloc1):
        l2Tnew[i - vloc0] /= lib.direct_sum('i + jb -> bij', eia[:, i], eia)

    time0 = log.timer_debug1('update l1 l2', *time0)
    # Transpose back to the caller's (occ, occ, vir_local, vir) ordering.
    return l1Tnew.T, l2Tnew.transpose(2, 3, 0, 1)
Beispiel #20
0
def _make_eris_incore(mycc, mo_coeff=None, ao2mofn=None):
    """
    Make antisymmetrized physicist's ERIs with incore ao2mo.

    Each rank transforms the integrals for its own slice [ploc0, ploc1) of
    the first MO index and stores <pq||rs> = (pr|qs) - (ps|qr) for that
    slice in a per-rank HDF5 scratch file ("gccsd_eri_tmp_<rank>.h5" in the
    current working directory). After a barrier, every rank reads back from
    all files:
    - the full oooo block;
    - the xooo/xovo/xvoo/xvvo/xovv/xvvv blocks whose first (virtual) index
      lies in its own virtual segment.
    Each rank then deletes its own scratch file.

    Args:
        mycc : CCSD object. Its `_scf._eri` (AO integrals) is broadcast from
            the root if any rank is missing it.
        mo_coeff : MO coefficients passed to `_common_init_` on the root.
        ao2mofn : not used by this implementation; kept for interface
            compatibility.

    Returns:
        The `gccsd._PhysicistsERIs` object, also stored as `mycc._eris`.
    """
    cput0 = (logger.process_clock(), logger.perf_counter())
    log = logger.Logger(mycc.stdout, mycc.verbose)
    _sync_(mycc)
    eris = gccsd._PhysicistsERIs()

    # Only the root builds the MO quantities; the workers receive them.
    if rank == 0:
        eris._common_init_(mycc, mo_coeff)
        comm.bcast((eris.mo_coeff, eris.fock, eris.nocc, eris.mo_energy))
    else:
        eris.mol = mycc.mol
        eris.mo_coeff, eris.fock, eris.nocc, eris.mo_energy = comm.bcast(None)

    # If any worker does not have _eri, broadcast it from the root.
    if comm.allreduce(mycc._scf._eri is None, op=mpi.MPI.LOR):
        if rank == 0:
            mpi.bcast(mycc._scf._eri)
        else:
            mycc._scf._eri = mpi.bcast(None)
    cput1 = log.timer('CCSD ao2mo initialization:     ', *cput0)

    nocc = eris.nocc
    nao, nmo = eris.mo_coeff.shape
    nvir = nmo - nocc
    vlocs = [_task_location(nvir, task_id) for task_id in range(mpi.pool.size)]
    vloc0, vloc1 = vlocs[rank]
    vseg = vloc1 - vloc0

    # This rank owns the slice [ploc0, ploc1) of the first MO index.
    plocs = [_task_location(nmo, task_id) for task_id in range(mpi.pool.size)]
    ploc0, ploc1 = plocs[rank]
    pseg = ploc1 - ploc0

    # Spin-orbital coefficients: alpha AO rows on top, beta AO rows below.
    mo_a = eris.mo_coeff[:nao//2]
    mo_b = eris.mo_coeff[nao//2:]
    mo_seg_a = mo_a[:, ploc0:ploc1]
    mo_seg_b = mo_b[:, ploc0:ploc1]

    tmp_fmt = "gccsd_eri_tmp_%s.h5"
    with h5py.File(tmp_fmt % rank, 'w') as f:
        # HDF5 rejects zero-sized chunk dimensions, so an empty slice (more
        # ranks than orbitals) gets a contiguous, empty dataset instead.
        chunks = (pseg, 1, nmo, nmo) if pseg > 0 else None
        eri_phys = f.create_dataset('eri_phys', (pseg, nmo, nmo, nmo), 'f8',
                                    chunks=chunks)

        eri_a = ao2mo.incore.half_e1(mycc._scf._eri, (mo_seg_a, mo_a), compact=False)
        eri_b = ao2mo.incore.half_e1(mycc._scf._eri, (mo_seg_b, mo_b), compact=False)
        cput1 = log.timer('CCSD ao2mo half_e1:            ', *cput1)

        # max(1, ...) avoids a ZeroDivisionError when pseg == 0; it does not
        # change blksize otherwise.
        unit = max(1, pseg * nmo * nmo * 2)
        mem_now = lib.current_memory()[0]
        max_memory = max(0, mycc.max_memory - mem_now)
        blksize = min(nmo, max(BLKMIN, int((max_memory*0.9e6/8)/unit)))

        for p0, p1 in lib.prange(0, nmo, blksize):
            klmosym_a, _, mokl_a, klshape_a = \
                    ao2mo.incore._conc_mos(mo_a[:, p0:p1], mo_a, compact=False)
            klmosym_b, _, mokl_b, klshape_b = \
                    ao2mo.incore._conc_mos(mo_b[:, p0:p1], mo_b, compact=False)

            # Spin-orbital (ij|kl) is the sum of the four spin blocks
            # (aa|aa) + (aa|bb) + (bb|aa) + (bb|bb).
            eri  = _ao2mo.nr_e2(eri_a, mokl_a, klshape_a, aosym='s4', mosym=klmosym_a)
            eri += _ao2mo.nr_e2(eri_a, mokl_b, klshape_b, aosym='s4', mosym=klmosym_b)
            eri += _ao2mo.nr_e2(eri_b, mokl_a, klshape_a, aosym='s4', mosym=klmosym_a)
            eri += _ao2mo.nr_e2(eri_b, mokl_b, klshape_b, aosym='s4', mosym=klmosym_b)

            # eri[i, j, k, l] = (ij|kl) with i in this rank's slice and
            # k in [p0, p1). Store <ik||jl> = (ij|kl) - (il|kj) as
            # eri_phys[i, k, j, l].
            eri = eri.reshape(pseg, nmo, p1-p0, nmo)
            eri_phys[:, p0:p1] = eri.transpose(0, 2, 1, 3) - eri.transpose(0, 2, 3, 1)
            eri = None
        eri_a = None
        eri_b = None

    # Every file must be fully written before any rank starts reading.
    comm.Barrier()
    cput1 = log.timer('CCSD ao2mo nr_e2:              ', *cput1)

    # o_idx: last rank whose slice contains an occupied orbital.
    # v_idx: first rank whose slice contains a virtual orbital.
    o_idx = -1
    v_idx = mpi.pool.size
    for r, (p0, p1) in enumerate(plocs):
        if p0 <= nocc - 1 < p1:
            o_idx = r
        if p0 <= nocc < p1:
            v_idx = r
            break
    o_files = range(o_idx + 1)
    v_files = range(v_idx, mpi.pool.size)

    eris.oooo = np.empty((nocc, nocc, nocc, nocc))
    eris.xooo = np.empty((vseg, nocc, nocc, nocc))
    eris.xovo = np.empty((vseg, nocc, nvir, nocc))
    eris.xovv = np.empty((vseg, nocc, nvir, nvir))
    eris.xvvo = np.empty((vseg, nvir, nvir, nocc))
    eris.xvoo = np.empty((vseg, nvir, nocc, nocc))
    eris.xvvv = np.empty((vseg, nvir, nvir, nvir))

    # Absolute MO indices spanned by this rank's virtual segment.
    v_start = nocc + vloc0
    v_stop = nocc + vloc1
    for r in range(mpi.pool.size):
        # Each file is closed as soon as it has been read.
        with h5py.File(tmp_fmt % r, 'r') as f:
            eri_phys = f["eri_phys"]
            if r in o_files:
                # Rows of rank r's file that belong to occupied orbitals.
                p0, p1 = plocs[r]
                p1 = min(p1, nocc)
                if p1 > p0:
                    eris.oooo[p0:p1] = eri_phys[:p1-p0, :nocc, :nocc, :nocc]

            if r in v_files:
                # Overlap of rank r's rows with this rank's virtual segment.
                p00, p10 = plocs[r]
                p0 = max(p00, v_start)
                p1 = min(p10, v_stop)
                if p1 > p0:
                    src = slice(p0 - p00, p1 - p00)
                    dst = slice(p0 - v_start, p1 - v_start)
                    eris.xooo[dst] = eri_phys[src, :nocc, :nocc, :nocc]
                    eris.xovo[dst] = eri_phys[src, :nocc, nocc:, :nocc]
                    eris.xvoo[dst] = eri_phys[src, nocc:, :nocc, :nocc]
                    eris.xvvo[dst] = eri_phys[src, nocc:, nocc:, :nocc]
                    eris.xovv[dst] = eri_phys[src, :nocc, nocc:, nocc:]
                    eris.xvvv[dst] = eri_phys[src, nocc:, nocc:, nocc:]
    log.timer('CCSD ao2mo load:               ', *cput1)

    # Wait until every rank has finished reading before deleting own file.
    comm.Barrier()
    os.remove(tmp_fmt % rank)
    mycc._eris = eris
    log.timer('CCSD integral transformation   ', *cput0)
    return eris