示例#1
0
def _getitem(h5group, label, kpti_kptj, kptij_lst, ignore_key_error=False):
    """Fetch the tensor stored for a k-point pair from an HDF5 group.

    The pair ``kpti_kptj`` is searched in ``kptij_lst``.  If it is absent,
    the swapped pair (kj, ki) is searched instead and the data is loaded with
    the hermitian flag set.

    Args:
        h5group: mapping-like HDF5 group holding ``label/<index>`` entries.
        label: name prefix of the entries.
        kpti_kptj: the requested pair; indexed with ``[[1, 0]]`` to swap it.
        kptij_lst: list of pairs the stored entries correspond to.
        ignore_key_error: return an empty array instead of raising when the
            entry is missing from ``h5group``.

    Returns:
        The result of ``_load_and_unpack`` for the located entry, or
        ``numpy.zeros(0)`` when the entry is missing and
        ``ignore_key_error`` is set.

    Raises:
        RuntimeError: neither the pair nor its swap is in ``kptij_lst``.
        KeyError: the entry is missing and ``ignore_key_error`` is false.
    """
    matches = member(kpti_kptj, kptij_lst)
    hermi = len(matches) == 0
    if hermi:
        # swap ki,kj due to the hermiticity
        matches = member(kpti_kptj[[1, 0]], kptij_lst)
        if len(matches) == 0:
            raise RuntimeError(
                '%s for kpts %s is not initialized.\n'
                'You need to update the attribute .kpts then call '
                '.build() to initialize %s.' % (label, kpti_kptj, label))

    key = label + '/' + str(matches[0])
    if key in h5group:
        return _load_and_unpack(h5group[key], hermi)
    if ignore_key_error:
        return numpy.zeros(0)
    raise KeyError('Key "%s" not found' % key)
示例#2
0
文件: df.py 项目: zwang123/pyscf
    def __enter__(self):
        """Open the integral file and return the tensor for ``self.kpti_kptj``.

        The pair is looked up in the ``<label>-kptij`` dataset.  If the pair
        itself is not stored, the swapped pair (kj, ki) is looked up instead
        and the stored data is passed through ``_load_and_unpack``.

        Returns:
            For a direct match: the HDF5 dataset, or the ``numpy.hstack`` of
            its numbered pieces when the entry is an HDF5 group.  For a
            swapped match: the result of ``_load_and_unpack``.

        Raises:
            RuntimeError: neither (ki, kj) nor (kj, ki) is in the stored list.
        """
        self.feri = h5py.File(self.cderi, 'r')
        try:
            kpti_kptj = numpy.asarray(self.kpti_kptj)
            # Dataset.value was removed in h5py 3.0; [()] reads the whole
            # dataset and works on old and new h5py alike.
            kptij_lst = self.feri['%s-kptij' % self.label][()]
            k_id = member(kpti_kptj, kptij_lst)
            swapped = len(k_id) == 0
            if swapped:
                # swap ki,kj due to the hermiticity
                kptji = kpti_kptj[[1, 0]]
                k_id = member(kptji, kptij_lst)
                if len(k_id) == 0:
                    raise RuntimeError('%s for kpts %s is not initialized.\n'
                                       'You need to update the attribute .kpts then call '
                                       '.build() to initialize %s.'
                                       % (self.label, kpti_kptj, self.label))

            dat = self.feri['%s/%d' % (self.label, k_id[0])]
            if isinstance(dat, h5py.Group):
                # Check whether the integral tensor is stored with the old
                # data format (v1.5.1 or older). The old format puts the
                # entire 3-index tensor in one HDF5 dataset. The new format
                # divides the tensor into pieces stored under a group; the
                # line below combines the pieces into a single tensor.
                dat = numpy.hstack([dat[str(i)] for i in range(len(dat))])

            if swapped:
                dat = _load_and_unpack(dat)
        except Exception:
            # __exit__ is not invoked when __enter__ raises, so the handle
            # has to be released here to avoid leaking the open file.
            self.feri.close()
            raise
        return dat
示例#3
0
def get_member(kpti, kptj, kptij_lst):
    """Locate the pair (kpti, kptj) in ``kptij_lst``, trying the swap as well.

    Returns:
        A tuple ``(ijid, dagger)``.  ``ijid`` is what ``member`` returns for
        the pair; ``dagger`` is True when the direct lookup was empty and the
        swapped pair (kptj, kpti) was searched instead (the swapped lookup may
        itself be empty).
    """
    direct = member(np.vstack((kpti, kptj)), kptij_lst)
    if len(direct) != 0:
        return direct, False
    return member(np.vstack((kptj, kpti)), kptij_lst), True
示例#4
0
文件: df.py 项目: zwang123/pyscf
 def has_kpts(self, kpts):
     """Tell whether every requested k-point is already known to this object.

     A k-point counts as known when ``member`` finds it in ``self.kpts`` or,
     if ``self.kpts_band`` is set, in ``self.kpts_band``.  ``kpts=None`` is
     always reported as known.
     """
     if kpts is None:
         return True

     requested = numpy.asarray(kpts).reshape(-1,3)
     band_available = self.kpts_band is not None
     for kpt in requested:
         if len(member(kpt, self.kpts)) > 0:
             continue
         if band_available and len(member(kpt, self.kpts_band)) > 0:
             continue
         return False
     return True
示例#5
0
文件: df.py 项目: chrinide/pyscf
 def has_kpts(self, kpts):
     """Return True when all of ``kpts`` are found among the stored k-points.

     ``kpts`` is flattened to rows of three components; each row must be
     located by ``member`` in ``self.kpts`` or (when it is not None) in
     ``self.kpts_band``.  Passing ``None`` short-circuits to True.
     """
     if kpts is None:
         return True

     rows = numpy.asarray(kpts).reshape(-1,3)
     use_band = self.kpts_band is not None

     def _is_known(kpt):
         if len(member(kpt, self.kpts)) > 0:
             return True
         return use_band and len(member(kpt, self.kpts_band)) > 0

     return all(_is_known(kpt) for kpt in rows)
示例#6
0
文件: df.py 项目: zwgsyntax/pyscf
    def build(self, j_only=None, with_j3c=True, kpts_band=None):
        """Set up the auxiliary cell and, optionally, generate the integrals.

        Args:
            j_only: If true, only the diagonal pairs (k, k) are put in the
                k-point pair list.  ``None`` falls back to ``self._j_only``.
            with_j3c: If true, call ``self._make_j3c`` and record the target
                file name in ``self._cderi``.
            kpts_band: Optional extra k-points; reshaped to (-1, 3) and merged
                into ``self.kpts_band``.

        Returns:
            ``self``.
        """
        if self.kpts_band is not None:
            self.kpts_band = numpy.reshape(self.kpts_band, (-1, 3))
        if kpts_band is not None:
            kpts_band = numpy.reshape(kpts_band, (-1, 3))
            if self.kpts_band is None:
                self.kpts_band = kpts_band
            else:
                self.kpts_band = unique(
                    numpy.vstack((self.kpts_band, kpts_band)))[0]

        self.check_sanity()
        self.dump_flags()

        self.auxcell = make_modrho_basis(self.cell, self.auxbasis,
                                         self.exp_to_discard)

        kpts = self.kpts
        if self.kpts_band is None:
            kband_uniq = numpy.zeros((0, 3))
        else:
            # Band k-points that are not already in kpts.  The reshape keeps
            # a (0, 3) array when nothing is left; a bare empty list would
            # make the numpy.vstack call below fail on mismatched shapes.
            kband_uniq = numpy.reshape(
                [k for k in self.kpts_band if len(member(k, kpts)) == 0],
                (-1, 3))
        if j_only is None:
            j_only = self._j_only
        if j_only:
            # Only (k, k) pairs, for kpts followed by the extra band points.
            kall = numpy.vstack([kpts, kband_uniq])
            kptij_lst = numpy.hstack((kall, kall)).reshape(-1, 2, 3)
        else:
            # Lower-triangular pairs of kpts, then (band, k) pairs, then the
            # diagonal (band, band) pairs.
            kptij_lst = [(ki, kpts[j]) for i, ki in enumerate(kpts)
                         for j in range(i + 1)]
            kptij_lst.extend([(ki, kj) for ki in kband_uniq for kj in kpts])
            kptij_lst.extend([(ki, ki) for ki in kband_uniq])
            kptij_lst = numpy.asarray(kptij_lst)

        if with_j3c:
            if isinstance(self._cderi_to_save, str):
                cderi = self._cderi_to_save
            else:
                cderi = self._cderi_to_save.name
            if isinstance(self._cderi, str):
                if self._cderi == cderi and os.path.isfile(cderi):
                    logger.warn(
                        self, 'DF integrals in %s (specified by '
                        '._cderi) is overwritten by GDF '
                        'initialization. ', cderi)
                else:
                    logger.warn(
                        self, 'Value of ._cderi is ignored. '
                        'DF integrals will be saved in file %s .', cderi)
            self._cderi = cderi
            # time.clock() was removed in Python 3.8; process_time() is the
            # CPU-time replacement.
            t1 = (time.process_time(), time.time())
            self._make_j3c(self.cell, self.auxcell, kptij_lst, cderi)
            t1 = logger.timer_debug1(self, 'j3c', *t1)
        return self
示例#7
0
文件: df.py 项目: BoothGroup/pyscf
    def build(self, j_only=None, with_j3c=True, kpts_band=None):
        """Set up the auxiliary cell and, optionally, generate the integrals.

        Args:
            j_only: If true, only the diagonal pairs (k, k) are put in the
                k-point pair list.  ``None`` falls back to ``self._j_only``.
            with_j3c: If true, call ``self._make_j3c`` and record the target
                file name in ``self._cderi``.
            kpts_band: Optional extra k-points; reshaped to (-1, 3) and merged
                into ``self.kpts_band``.

        Returns:
            ``self``.
        """
        if self.kpts_band is not None:
            self.kpts_band = numpy.reshape(self.kpts_band, (-1,3))
        if kpts_band is not None:
            kpts_band = numpy.reshape(kpts_band, (-1,3))
            if self.kpts_band is None:
                self.kpts_band = kpts_band
            else:
                self.kpts_band = unique(numpy.vstack((self.kpts_band,kpts_band)))[0]

        self.check_sanity()
        self.dump_flags()

        self.auxcell = make_modrho_basis(self.cell, self.auxbasis,
                                         self.exp_to_discard)

        # Remove duplicated k-points. Duplicated kpts may lead to a buffer
        # located in incore.wrap_int3c larger than necessary. Integral code
        # only fills necessary part of the buffer, leaving some space in the
        # buffer unfilled.
        uniq_idx = unique(self.kpts)[1]
        kpts = numpy.asarray(self.kpts)[uniq_idx]
        if self.kpts_band is None:
            kband_uniq = numpy.zeros((0,3))
        else:
            # Band k-points that are not already in kpts.  The reshape keeps
            # a (0, 3) array when nothing is left; a bare empty list would
            # make the numpy.vstack call below fail on mismatched shapes.
            kband_uniq = numpy.reshape(
                [k for k in self.kpts_band if len(member(k, kpts))==0],
                (-1,3))
        if j_only is None:
            j_only = self._j_only
        if j_only:
            # Only (k, k) pairs, for kpts followed by the extra band points.
            kall = numpy.vstack([kpts,kband_uniq])
            kptij_lst = numpy.hstack((kall,kall)).reshape(-1,2,3)
        else:
            # Lower-triangular pairs of kpts, then (band, k) pairs, then the
            # diagonal (band, band) pairs.
            kptij_lst = [(ki, kpts[j]) for i, ki in enumerate(kpts) for j in range(i+1)]
            kptij_lst.extend([(ki, kj) for ki in kband_uniq for kj in kpts])
            kptij_lst.extend([(ki, ki) for ki in kband_uniq])
            kptij_lst = numpy.asarray(kptij_lst)

        if with_j3c:
            if isinstance(self._cderi_to_save, str):
                cderi = self._cderi_to_save
            else:
                cderi = self._cderi_to_save.name
            if isinstance(self._cderi, str):
                if self._cderi == cderi and os.path.isfile(cderi):
                    logger.warn(self, 'DF integrals in %s (specified by '
                                '._cderi) is overwritten by GDF '
                                'initialization. ', cderi)
                else:
                    logger.warn(self, 'Value of ._cderi is ignored. '
                                'DF integrals will be saved in file %s .',
                                cderi)
            self._cderi = cderi
            t1 = (logger.process_clock(), logger.perf_counter())
            self._make_j3c(self.cell, self.auxcell, kptij_lst, cderi)
            t1 = logger.timer_debug1(self, 'j3c', *t1)
        return self
示例#8
0
文件: df.py 项目: xlzan/pyscf
 def __enter__(self):
     """Open the integral file and return the tensor for ``self.kpti_kptj``.

     The pair is looked up in the ``<label>-kptij`` dataset.  A direct match
     returns the HDF5 entry as is; if only the swapped pair (kj, ki) is
     stored, that entry is passed through ``_load_and_unpack``.

     Raises:
         RuntimeError: neither (ki, kj) nor (kj, ki) is in the stored list.
     """
     self.feri = h5py.File(self.cderi, 'r')
     try:
         kpti_kptj = numpy.asarray(self.kpti_kptj)
         # Dataset.value was removed in h5py 3.0; [()] reads the whole
         # dataset and works on old and new h5py alike.
         kptij_lst = self.feri['%s-kptij'%self.label][()]
         k_id = member(kpti_kptj, kptij_lst)
         if len(k_id) > 0:
             dat = self.feri['%s/%d' % (self.label,k_id[0])]
         else:
             # swap ki,kj due to the hermiticity
             kptji = kpti_kptj[[1,0]]
             k_id = member(kptji, kptij_lst)
             if len(k_id) == 0:
                 raise RuntimeError('%s for kpts %s is not initialized.\n'
                                    'You need to update the attribute .kpts then call '
                                    '.build() to initialize %s.'
                                    % (self.label, kpti_kptj, self.label))
             dat = self.feri['%s/%d' % (self.label, k_id[0])]
             dat = _load_and_unpack(dat)
     except Exception:
         # __exit__ is not invoked when __enter__ raises, so the handle has
         # to be released here to avoid leaking the open file.
         self.feri.close()
         raise
     return dat
示例#9
0
def _getitem(h5group, label, kpti_kptj, kptij_lst, ignore_key_error=False):
    """Fetch the tensor stored for a k-point pair from an HDF5 group.

    The pair is searched in ``kptij_lst``; if absent, the swapped pair
    (kj, ki) is searched instead.

    Returns:
        * direct match: the entry ``h5group[label/<index>]``, with the
          numbered pieces of an HDF5 group joined by ``numpy.hstack``;
        * swapped match: ``_load_and_unpack`` applied to the entry;
        * ``numpy.zeros(0)`` when the entry is missing and
          ``ignore_key_error`` is set.

    Raises:
        RuntimeError: neither the pair nor its swap is in ``kptij_lst``.
        KeyError: the entry is missing and ``ignore_key_error`` is false.
    """
    matches = member(kpti_kptj, kptij_lst)
    swapped = len(matches) == 0
    if swapped:
        # swap ki,kj due to the hermiticity
        matches = member(kpti_kptj[[1, 0]], kptij_lst)
        if len(matches) == 0:
            raise RuntimeError(
                '%s for kpts %s is not initialized.\n'
                'You need to update the attribute .kpts then call '
                '.build() to initialize %s.' % (label, kpti_kptj, label))

    key = label + '/' + str(matches[0])
    if key not in h5group:
        if ignore_key_error:
            return numpy.zeros(0)
        raise KeyError('Key "%s" not found' % key)

    if swapped:
        # TODO: put the numpy.hstack() call in _load_and_unpack class to
        # lazily load the 3D tensor if it is too big.
        return _load_and_unpack(h5group[key])

    node = h5group[key]
    if isinstance(node, h5py.Group):
        # Check whether the integral tensor is stored with the old data
        # format (v1.5.1 or older). The old format puts the entire 3-index
        # tensor in one HDF5 dataset. The new format divides the tensor into
        # pieces stored under a group; combine the pieces into one tensor.
        node = numpy.hstack([node[str(i)] for i in range(len(node))])
    return node
示例#10
0
def _ewald_exxdiv_1d2d(cell, kpts, dms, vk, kpts_band=None):
    """Add the exchange-divergence correction to ``vk`` in place.

    ``vk`` is updated through ``+=``; nothing is returned.  The k-point
    layout selects how ``vk``/``dms`` are indexed:

    * ``kpts is None`` or a single k-point of shape (3,): ``vk[i]``/``dms[i]``
      (for a single k-point, only when ``kpts_band`` is None or equals it);
    * several k-points with no distinct band points: ``vk[:, k]``/``dms[:, k]``
      weighted by ``1/nkpts``;
    * otherwise: every band index ``kp`` that ``member`` matches to k-point
      ``k`` receives the contribution of ``dms[:, k]``.
    """
    ovlp = cell.pbc_intor('int1e_ovlp', hermi=1, kpts=kpts)
    madelung = tools.pbc.madelung(cell, kpts)

    Gv, Gvbase, kws = cell.get_Gv_weights(cell.mesh)
    G0idx, SI_on_z = gto.cell._SI_for_uniform_model_charge(cell, Gv)
    coulG = 4 * numpy.pi / numpy.linalg.norm(Gv[G0idx], axis=1)**2
    wcoulG = coulG * kws[G0idx]
    aoao_ij = ft_ao._ft_aopair_kpts(cell, Gv[G0idx], kptjs=kpts)
    aoao_kl = ft_ao._ft_aopair_kpts(cell, -Gv[G0idx], kptjs=kpts)

    def _accumulate(vk_k, dms_k, ovlp_k, ft_ij, ft_kl, kweight):
        # Without removing aoao(Gx=0,Gy=0), the summation of vk and ewald probe
        # charge correction (as _ewald_exxdiv_3d did) gives the reasonable
        # finite value for vk.  Here madelung constant and vk were calculated
        # without (Gx=0,Gy=0).  The code below restores the (Gx=0,Gy=0) part.
        #
        # Reference (unoptimized) form of the per-dm term:
        #:aoaomod_ij = aoao_ij - numpy.einsum('g,ij->gij', SI_on_z       , s)
        #:aoaomod_kl = aoao_kl - numpy.einsum('g,ij->gij', SI_on_z.conj(), s)
        #:ktmp  = kweight * lib.einsum('gij,jk,g,gkl->il', aoao_ij   , dm, wcoulG, aoao_kl   )
        #:ktmp -= kweight * lib.einsum('gij,jk,g,gkl->il', aoaomod_ij, dm, wcoulG, aoaomod_kl)
        #:ktmp += (madelung - kweight*wcoulG.sum()) * reduce(numpy.dot, (s, dm, s))
        madelung_mod = numpy.einsum('g,g,g', SI_on_z.conj(), wcoulG, SI_on_z)
        half_ij = numpy.einsum('gij,g,g->ij', ft_ij, wcoulG, SI_on_z.conj())
        half_kl = numpy.einsum('gij,g,g->ij', ft_kl, wcoulG, SI_on_z)
        sds_factor = (madelung - kweight * wcoulG.sum()
                      - kweight * madelung_mod)
        for i, dm in enumerate(dms_k):
            ktmp = kweight * lib.einsum('ij,jk,kl->il', half_ij, dm, ovlp_k)
            ktmp += kweight * lib.einsum('ij,jk,kl->il', ovlp_k, dm, half_kl)
            ktmp += sds_factor * reduce(numpy.dot, (ovlp_k, dm, ovlp_k))
            # Real-valued output arrays only take the real part.
            if vk_k.dtype == numpy.double:
                vk_k[i] += ktmp.real
            else:
                vk_k[i] += ktmp

    if kpts is None:
        _accumulate(vk, dms, ovlp, aoao_ij[0], aoao_kl[0], 1)
        return

    if numpy.shape(kpts) == (3, ):
        if kpts_band is None or is_zero(kpts_band - kpts):
            _accumulate(vk, dms, ovlp, aoao_ij[0], aoao_kl[0], 1)
        return

    if kpts_band is None or numpy.array_equal(kpts, kpts_band):
        nkpts = len(kpts)
        for k in range(nkpts):
            _accumulate(vk[:, k], dms[:, k], ovlp[k], aoao_ij[k], aoao_kl[k],
                        1. / nkpts)
        return

    nkpts = len(kpts)
    for k, kpt in enumerate(kpts):
        for kp in member(kpt, kpts_band.reshape(-1, 3)):
            _accumulate(vk[:, kp], dms[:, k], ovlp[k], aoao_ij[k], aoao_kl[k],
                        1. / nkpts)
示例#11
0
文件: df.py 项目: chrinide/pyscf
def _getitem(h5group, label, kpti_kptj, kptij_lst, ignore_key_error=False):
    """Look up the stored tensor of one k-point pair in ``h5group``.

    A direct hit in ``kptij_lst`` returns the stored entry (pieces of an HDF5
    group are joined with ``numpy.hstack``).  Otherwise the swapped pair
    (kj, ki) is tried and its entry is returned through ``_load_and_unpack``.

    Args:
        ignore_key_error: when the entry ``label/<index>`` does not exist,
            return ``numpy.zeros(0)`` instead of raising ``KeyError``.

    Raises:
        RuntimeError: neither the pair nor its swap is in ``kptij_lst``.
        KeyError: the entry is missing and ``ignore_key_error`` is false.
    """
    found = member(kpti_kptj, kptij_lst)
    use_swap = not len(found) > 0
    if use_swap:
        # swap ki,kj due to the hermiticity
        found = member(kpti_kptj[[1,0]], kptij_lst)
        if len(found) == 0:
            raise RuntimeError('%s for kpts %s is not initialized.\n'
                               'You need to update the attribute .kpts then call '
                               '.build() to initialize %s.'
                               % (label, kpti_kptj, label))

    key = label + '/' + str(found[0])
    if key not in h5group:
        if not ignore_key_error:
            raise KeyError('Key "%s" not found' % key)
        return numpy.zeros(0)

    if not use_swap:
        dat = h5group[key]
        if isinstance(dat, h5py.Group):
            # Check whether the integral tensor is stored with the old data
            # format (v1.5.1 or older). The old format puts the entire
            # 3-index tensor in one HDF5 dataset. The new format divides the
            # tensor into pieces stored under a group; combine the pieces.
            dat = numpy.hstack([dat[str(i)] for i in range(len(dat))])
        return dat

    # TODO: put the numpy.hstack() call in _load_and_unpack class to lazily
    # load the 3D tensor if it is too big.
    return _load_and_unpack(h5group[key])
示例#12
0
文件: df.py 项目: chrinide/pyscf
    def build(self, j_only=None, with_j3c=True, kpts_band=None):
        """Set up the auxiliary cell and, optionally, generate the integrals.

        Args:
            j_only: If true, only the diagonal pairs (k, k) are put in the
                k-point pair list.  ``None`` falls back to ``self._j_only``.
            with_j3c: If true, call ``self._make_j3c`` and record the target
                file name in ``self._cderi``.
            kpts_band: Optional extra k-points; reshaped to (-1, 3) and merged
                into ``self.kpts_band``.

        Returns:
            ``self``.
        """
        if self.kpts_band is not None:
            self.kpts_band = numpy.reshape(self.kpts_band, (-1,3))
        if kpts_band is not None:
            kpts_band = numpy.reshape(kpts_band, (-1,3))
            if self.kpts_band is None:
                self.kpts_band = kpts_band
            else:
                self.kpts_band = unique(numpy.vstack((self.kpts_band,kpts_band)))[0]

        self.check_sanity()
        self.dump_flags()

        self.auxcell = make_modrho_basis(self.cell, self.auxbasis,
                                         self.exp_to_discard)

        kpts = self.kpts
        if self.kpts_band is None:
            kband_uniq = numpy.zeros((0,3))
        else:
            # Band k-points that are not already in kpts.  The reshape keeps
            # a (0, 3) array when nothing is left; a bare empty list would
            # make the numpy.vstack call below fail on mismatched shapes.
            kband_uniq = numpy.reshape(
                [k for k in self.kpts_band if len(member(k, kpts))==0],
                (-1,3))
        if j_only is None:
            j_only = self._j_only
        if j_only:
            # Only (k, k) pairs, for kpts followed by the extra band points.
            kall = numpy.vstack([kpts,kband_uniq])
            kptij_lst = numpy.hstack((kall,kall)).reshape(-1,2,3)
        else:
            # Lower-triangular pairs of kpts, then (band, k) pairs, then the
            # diagonal (band, band) pairs.
            kptij_lst = [(ki, kpts[j]) for i, ki in enumerate(kpts) for j in range(i+1)]
            kptij_lst.extend([(ki, kj) for ki in kband_uniq for kj in kpts])
            kptij_lst.extend([(ki, ki) for ki in kband_uniq])
            kptij_lst = numpy.asarray(kptij_lst)

        if with_j3c:
            if isinstance(self._cderi_to_save, str):
                cderi = self._cderi_to_save
            else:
                cderi = self._cderi_to_save.name
            if isinstance(self._cderi, str):
                if self._cderi == cderi and os.path.isfile(cderi):
                    logger.warn(self, 'DF integrals in %s (specified by '
                                '._cderi) is overwritten by GDF '
                                'initialization. ', cderi)
                else:
                    logger.warn(self, 'Value of ._cderi is ignored. '
                                'DF integrals will be saved in file %s .',
                                cderi)
            self._cderi = cderi
            # time.clock() was removed in Python 3.8; process_time() is the
            # CPU-time replacement.
            t1 = (time.process_time(), time.time())
            self._make_j3c(self.cell, self.auxcell, kptij_lst, cderi)
            t1 = logger.timer_debug1(self, 'j3c', *t1)
        return self
示例#13
0
def _ewald_exxdiv_for_G0(cell, kpts, dms, vk, kpts_band=None):
    """Add ``madelung * S.dm.S`` to ``vk`` in place for the matching k-points.

    Nothing is returned.  For ``kpts=None`` or a single k-point of shape (3,)
    the arrays are indexed as ``vk[i]``/``dms[i]`` (a single k-point is only
    corrected when ``kpts_band`` is None or equals it).  For several k-points
    they are indexed as ``vk[i, k]``/``dm[k]``; when distinct band points are
    given, the contribution of k-point ``k`` goes to every band index that
    ``member`` matches to it.
    """
    ovlp = cell.pbc_intor('int1e_ovlp', hermi=1, kpts=kpts)
    madelung = tools.pbc.madelung(cell, kpts)

    if kpts is None or numpy.shape(kpts) == (3,):
        if kpts is None or kpts_band is None or is_zero(kpts_band-kpts):
            for i, dm in enumerate(dms):
                vk[i] += madelung * reduce(numpy.dot, (ovlp, dm, ovlp))
        return

    def _add(k, kp):
        # Contribution of source k-point k accumulated into target index kp.
        for i, dm in enumerate(dms):
            vk[i,kp] += madelung * reduce(numpy.dot, (ovlp[k], dm[k], ovlp[k]))

    if kpts_band is None or numpy.array_equal(kpts, kpts_band):
        for k in range(len(kpts)):
            _add(k, k)
    else:
        for k, kpt in enumerate(kpts):
            for kp in member(kpt, kpts_band.reshape(-1,3)):
                _add(k, kp)
示例#14
0
文件: df_jk.py 项目: chrinide/pyscf
def _ewald_exxdiv_for_G0(cell, kpts, dms, vk, kpts_band=None):
    """Accumulate the term ``madelung * S.dm.S`` into ``vk`` (in place).

    Three layouts are handled: no/single k-point (``vk[i]``), several
    k-points matching the band points (``vk[i, k]``), and several k-points
    with distinct band points, where ``member`` maps each k-point to the band
    indices that receive its contribution.  A single k-point with a different
    ``kpts_band`` leaves ``vk`` untouched.  Returns None.
    """
    s = cell.pbc_intor('int1e_ovlp', hermi=1, kpts=kpts)
    madelung = tools.pbc.madelung(cell, kpts)

    def _shift(ovlp, dm):
        return madelung * reduce(numpy.dot, (ovlp, dm, ovlp))

    multi_k = kpts is not None and numpy.shape(kpts) != (3,)
    if not multi_k:
        if kpts is None or kpts_band is None or is_zero(kpts_band-kpts):
            for i, dm in enumerate(dms):
                vk[i] += _shift(s, dm)
    elif kpts_band is None or numpy.array_equal(kpts, kpts_band):
        for k in range(len(kpts)):
            for i, dm in enumerate(dms):
                vk[i,k] += _shift(s[k], dm[k])
    else:
        for k, kpt in enumerate(kpts):
            for kp in member(kpt, kpts_band.reshape(-1,3)):
                for i, dm in enumerate(dms):
                    vk[i,kp] += _shift(s[k], dm[k])
示例#15
0
    def build(self, j_only=None, with_j3c=True, kpts_band=None):
        """Set up the auxiliary cell and generate the three-index integrals.

        Args:
            j_only: If true, only the diagonal pairs (k, k) are put in the
                k-point pair list.  ``None`` falls back to ``self._j_only``.
            with_j3c: Accepted for interface compatibility; not read by this
                implementation, which always calls ``self._make_j3c``.
            kpts_band: Optional extra k-points; reshaped to (-1, 3) and merged
                into ``self.kpts_band``.

        Returns:
            ``self``.
        """
        log = Logger(self.stdout, self.verbose)
        if self.kpts_band is not None:
            self.kpts_band = np.reshape(self.kpts_band, (-1,3))
        if kpts_band is not None:
            kpts_band = np.reshape(kpts_band, (-1,3))
            if self.kpts_band is None:
                self.kpts_band = kpts_band
            else:
                self.kpts_band = unique(np.vstack((self.kpts_band,kpts_band)))[0]

        self.check_sanity()
        self.dump_flags()

        self.auxcell = df.df.make_modrho_basis(self.cell, self.auxbasis,
                                         self.exp_to_discard)

        kpts = self.kpts
        if self.kpts_band is None:
            kband_uniq = np.zeros((0,3))
        else:
            # Band k-points that are not already in kpts.  The reshape keeps
            # a (0, 3) array when nothing is left; a bare empty list would
            # make the np.vstack call below fail on mismatched shapes.
            kband_uniq = np.reshape(
                [k for k in self.kpts_band if len(member(k, kpts))==0],
                (-1,3))
        if j_only is None:
            j_only = self._j_only
        if j_only:
            # Only (k, k) pairs, for kpts followed by the extra band points.
            kall = np.vstack([kpts,kband_uniq])
            kptij_lst = np.hstack((kall,kall)).reshape(-1,2,3)
        else:
            # Lower-triangular pairs of kpts, then (band, k) pairs, then the
            # diagonal (band, band) pairs.
            kptij_lst = [(ki, kpts[j]) for i, ki in enumerate(kpts) for j in range(i+1)]
            kptij_lst.extend([(ki, kj) for ki in kband_uniq for kj in kpts])
            kptij_lst.extend([(ki, ki) for ki in kband_uniq])
            kptij_lst = np.asarray(kptij_lst)

        # time.clock() was removed in Python 3.8; process_time() is the
        # CPU-time replacement.
        t1 = (time.process_time(), time.time())
        self._make_j3c(self.cell, self.auxcell, kptij_lst)
        t1 = log.timer('j3c', *t1)
        return self
示例#16
0
def _make_j3c(mydf, cell, auxcell, kptij_lst):
    """Build the fitted three-index tensor and store it on ``mydf``.

    The work is done in three stages, each distributed over the tasks that
    ``static_partition`` hands to this process:

    1. ``j3c_junk``: raw ``int3c2e`` integrals against the fused auxiliary
       cell, one slab per k-point pair.
    2. ``j2c``: the two-centre metric for every unique k-point difference,
       with a reciprocal-space term subtracted.
    3. ``j3c``: for every k-point pair, the metric is decomposed (Cholesky,
       falling back to an eigen-decomposition) and applied to the corrected
       three-index slab.

    Side effects: sets ``mydf.kptij_lst`` and ``mydf.j3c``.

    Args:
        mydf: density-fitting object providing ``max_memory``, ``mesh``,
            ``kpts``, ``linear_dep_threshold``, ``weighted_coulG`` and
            ``auxbar``.
        cell: the orbital-basis cell.
        auxcell: the auxiliary-basis cell.
        kptij_lst: array of k-point pairs, indexed as ``[:, 0]`` / ``[:, 1]``.

    Returns:
        None.
    """
    max_memory = max(2000, mydf.max_memory-pyscflib.current_memory()[0])
    fused_cell, fuse = df.df.fuse_auxcell(mydf, auxcell)
    log = Logger(mydf.stdout, mydf.verbose)
    nao, nfao = cell.nao_nr(), fused_cell.nao_nr()
    jobs = np.arange(fused_cell.nbas)
    tasks = list(static_partition(jobs))
    ntasks = max(comm.allgather(len(tasks)))
    j3c_junk = ctf.zeros([len(kptij_lst), nao**2, nfao], dtype=np.complex128)
    # time.clock() was removed in Python 3.8; process_time() is the CPU-time
    # replacement.
    t1 = (time.process_time(), time.time())

    # ---- stage 1: raw three-centre integrals for this process' shells ----
    idx_full = np.arange(j3c_junk.size).reshape(j3c_junk.shape)
    if len(tasks) > 0:
        q0, q1 = tasks[0], tasks[-1] + 1
        shls_slice = (0, cell.nbas, 0, cell.nbas, q0, q1)
        bstart, bend = fused_cell.ao_loc_nr()[q0], fused_cell.ao_loc_nr()[q1]
        # Flat positions of the auxiliary-function slab [bstart:bend].
        idx = idx_full[:,:,bstart:bend].ravel()
        tmp = df.incore.aux_e2(cell, fused_cell, intor='int3c2e', aosym='s2', kptij_lst=kptij_lst, shls_slice=shls_slice)
        nao_pair = nao**2
        if tmp.shape[-2] != nao_pair and tmp.ndim == 2:
            tmp = pyscflib.unpack_tril(tmp, axis=0).reshape(nao_pair,-1)
        j3c_junk.write(idx, tmp.ravel())
    else:
        # No shells assigned: still issue an (empty) write.
        # NOTE(review): presumably ctf write/read are collective calls that
        # every process must take part in - confirm.
        j3c_junk.write([],[])

    t1 = log.timer('j3c_junk', *t1)

    naux = auxcell.nao_nr()
    mesh = mydf.mesh
    Gv, Gvbase, kws = cell.get_Gv_weights(mesh)
    b = cell.reciprocal_vectors()
    gxyz = pyscflib.cartesian_prod([np.arange(len(x)) for x in Gvbase])
    ngrids = gxyz.shape[0]

    kptis = kptij_lst[:,0]
    kptjs = kptij_lst[:,1]
    kpt_ji = kptjs - kptis
    mydf.kptij_lst = kptij_lst
    uniq_kpts, uniq_index, uniq_inverse = unique(kpt_ji)

    jobs = np.arange(len(uniq_kpts))
    tasks = list(static_partition(jobs))
    ntasks = max(comm.allgather(len(tasks)))

    blksize = max(2048, int(max_memory*.5e6/16/fused_cell.nao_nr()))
    log.debug2('max_memory %s (MB)  blocksize %s', max_memory, blksize)
    j2c  = ctf.zeros([len(uniq_kpts),naux,naux], dtype=np.complex128)

    a = cell.lattice_vectors() / (2*np.pi)
    def kconserve_indices(kpt):
        '''search which (kpts+kpt) satisfies momentum conservation'''
        kdif = np.einsum('wx,ix->wi', a, uniq_kpts + kpt)
        kdif_int = np.rint(kdif)
        mask = np.einsum('wi->i', abs(kdif - kdif_int)) < KPT_DIFF_TOL
        uniq_kptji_ids = np.where(mask)[0]
        return uniq_kptji_ids

    def cholesky_decomposed_metric(j2c_kptij):
        '''Decompose the metric: lower Cholesky factor (tag 'CD') or, when
        that fails, an eigenvector-based factor (tag 'eig').'''
        j2c_negative = None
        try:
            j2c_kptij = scipy.linalg.cholesky(j2c_kptij, lower=True)
            j2ctag = 'CD'
        except scipy.linalg.LinAlgError:
            w, v = scipy.linalg.eigh(j2c_kptij)
            log.debug('cond = %.4g, drop %d bfns',
                      w[-1]/w[0], np.count_nonzero(w<mydf.linear_dep_threshold))
            # Keep only eigenvectors above the linear-dependency threshold,
            # scaled by 1/sqrt(eigenvalue); the remaining rows stay zero.
            v1 = np.zeros(v.T.shape, dtype=v.dtype)
            v1[w>mydf.linear_dep_threshold,:] = v[:,w>mydf.linear_dep_threshold].conj().T
            v1[w>mydf.linear_dep_threshold,:] /= np.sqrt(w[w>mydf.linear_dep_threshold]).reshape(-1,1)
            j2c_kptij = v1
            if cell.dimension == 2 and cell.low_dim_ft_type != 'inf_vacuum':
                idx = np.where(w < -mydf.linear_dep_threshold)[0]
                if len(idx) > 0:
                    j2c_negative = np.zeros(v1.shape, dtype=v1.dtype)
                    j2c_negative[idx,:] = (v[:,idx]/np.sqrt(-w[idx])).conj().T

            w = v = None
            j2ctag = 'eig'
        return j2c_kptij, j2c_negative, j2ctag

    # ---- stage 2: two-centre metric per unique k-point difference ----
    for itask in range(ntasks):
        if itask >= len(tasks):
            # Out of tasks: keep issuing empty writes until every process
            # has finished (see the NOTE(review) above).
            j2c.write([],[])
            continue
        k = tasks[itask]
        kpt = uniq_kpts[k]
        j2ctmp = np.asarray(fused_cell.pbc_intor('int2c2e', hermi=1, kpts=kpt))
        coulG = mydf.weighted_coulG(kpt, False, mesh)
        for p0, p1 in pyscflib.prange(0, ngrids, blksize):
            aoaux = ft_ao.ft_ao(fused_cell, Gv[p0:p1], None, b, gxyz[p0:p1], Gvbase, kpt).T
            if is_zero(kpt):
                j2ctmp[naux:] -= np.dot(aoaux[naux:].conj()*coulG[p0:p1].conj(), aoaux.T).real
                j2ctmp[:naux,naux:] = j2ctmp[naux:,:naux].T
            else:
                j2ctmp[naux:] -= np.dot(aoaux[naux:].conj()*coulG[p0:p1].conj(), aoaux.T)
                j2ctmp[:naux,naux:] = j2ctmp[naux:,:naux].T.conj()

        tmp = fuse(fuse(j2ctmp).T).T
        idx = k * naux**2 + np.arange(naux**2)
        j2c.write(idx, tmp.ravel())
        j2ctmp = tmp = None

    coulG = None
    t1 = log.timer('j2c', *t1)

    # ---- stage 3: fitted three-index tensor per k-point pair ----
    j3c = ctf.zeros([len(kpt_ji),nao,nao,naux], dtype=np.complex128)
    jobs = np.arange(len(kpt_ji))
    tasks = list(static_partition(jobs))
    ntasks = max(comm.allgather(len(tasks)))

    for itask in range(ntasks):
        if itask >= len(tasks):
            # Mirror the two reads and one write of a working iteration with
            # empty requests.
            j2c_ji = j2c.read([])
            j3ctmp = j3c_junk.read([])
            j3c.write([],[])
            continue
        idx_ji = tasks[itask]
        kpti, kptj = kptij_lst[idx_ji]
        # NOTE(review): idxi/idxj are not used below.
        idxi, idxj = member(kpti, mydf.kpts), member(kptj, mydf.kpts)
        uniq_idx = uniq_inverse[idx_ji]
        kpt = uniq_kpts[uniq_idx]
        # Pick the lowest-index stored metric equivalent to -kpt or kpt;
        # `conj` records that the chosen one came from the kpt-only list.
        id_eq = kconserve_indices(-kpt)
        id_conj = kconserve_indices(kpt)
        id_conj = np.asarray([i for i in id_conj if i not in id_eq], dtype=int)
        id_full = np.hstack((id_eq, id_conj))
        map_id, conj = min(id_full), np.argmin(id_full) >=len(id_eq)
        j2cidx = map_id * naux**2 + np.arange(naux**2)
        j2c_ji = j2c.read(j2cidx).reshape(naux, naux) # read to be added
        j2c_ji, j2c_negative, j2ctag = cholesky_decomposed_metric(j2c_ji)
        if conj: j2c_ji = j2c_ji.conj()
        shls_slice= (auxcell.nbas, fused_cell.nbas)
        Gaux = ft_ao.ft_ao(fused_cell, Gv, shls_slice, b, gxyz, Gvbase, kpt)
        wcoulG = mydf.weighted_coulG(kpt, False, mesh)
        Gaux *= wcoulG.reshape(-1,1)
        j3c_id = idx_ji * nao**2*nfao + np.arange(nao**2*nfao)
        j3ctmp = j3c_junk.read(j3c_id).reshape(nao**2, fused_cell.nao_nr()).T
        if is_zero(kpt):  # kpti == kptj
            if cell.dimension == 3:
                vbar = fuse(mydf.auxbar(fused_cell))
                ovlp = cell.pbc_intor('int1e_ovlp', hermi=1, kpts=kptj)
                for i in np.where(vbar != 0)[0]:
                    j3ctmp[i] -= vbar[i] * ovlp.reshape(-1)

        aoao = ft_ao._ft_aopair_kpts(cell, Gv, None, 's1', b, gxyz, Gvbase, kpt, kptj)[0].reshape(len(Gv),-1)
        j3ctmp[naux:] -= np.dot(Gaux.T.conj(), aoao)
        j3ctmp = fuse(j3ctmp)
        if j2ctag == 'CD':
            v = scipy.linalg.solve_triangular(j2c_ji, j3ctmp, lower=True, overwrite_b=True)
        else:
            v = np.dot(j2c_ji, j3ctmp)
        v = v.T.reshape(nao,nao,naux)
        j3c_id = idx_ji * nao**2*naux + np.arange(nao**2*naux)
        j3c.write(j3c_id, v.ravel())

    mydf.j3c = j3c
    return None