示例#1
0
    def _int2c(self) -> torch.Tensor:
        """Compute the 2-centre integral with the libpbc lattice-sum driver.

        Most of the work happens in numpy/C, so no gradient is propagated
        through this method (which is acceptable here). The procedure closely
        follows the ``ft_aopair_kpts`` function of PySCF:
        https://github.com/pyscf/pyscf/blob/master/pyscf/pbc/df/ft_ao.py
        https://github.com/pyscf/pyscf/blob/c9aa2be600d75a97410c3203abf35046af8ca615/pyscf/pbc/df/ft_ao.py#L52

        Returns:
            A complex tensor placed on ``self.device`` with shape
            ``(nkpts, *comp_shape, nao_0, nao_1, nGv)``.
        """
        assert len(self.wrappers) == 2

        # a very long list of lattice vectors is known to risk a segfault
        n_neighbors = self.ls.shape[0]
        if n_neighbors > 1e6:
            warnings.warn("The number of neighbors in the integral is too many, "
                          "it might causes segfault")

        first = self.wrappers[0]
        second = self.wrappers[1]

        # libpbc shifts the basis of one of the wrappers in place, hence the
        # integral is run on a concatenated copy of the atm/bas/env arrays
        atm, bas, env, ao_loc = _concat_atm_bas_env(first, second)
        row_start, row_stop = first.shell_idxs
        col_start, col_stop = second.shell_idxs
        # shells of the second wrapper come after all shells of the first parent
        col_offset = len(first.parent)
        shls_slice = (row_start, row_stop,
                      col_start + col_offset, col_stop + col_offset)

        # phase factors exp(i k.L) for every (k-point, lattice vector) pair
        k_dot_l = np.dot(self.kpts_inp_np, self.ls.T)
        expkl = np.asarray(np.exp(1j * k_dot_l), order='C')

        # output buffer that the C driver fills in
        nGv = self.GvT.shape[-1]
        nkpts = len(self.kpts_inp_np)
        lead_dims = (nkpts,) + self.comp_shape
        nao_dims = tuple(w.nao() for w in self.wrappers)
        out = np.empty(lead_dims + nao_dims + (nGv,), dtype=np.complex128)

        # C entry points
        cintor = getattr(CGTO, self.opname)
        eval_gz = CPBC.GTO_Gv_general
        fill = CPBC.PBC_ft_fill_ks1
        drv = CPBC.PBC_ft_latsum_drv

        # placeholder arguments required by the driver signature
        p_gxyzT = c_null_ptr()
        p_mesh = (ctypes.c_int * 3)(0, 0, 0)
        p_b = (ctypes.c_double * 1)(0)
        c_shls_slice = (ctypes.c_int * len(shls_slice))(*shls_slice)

        # the positional order below is dictated by the C function signature
        drv(
            cintor,
            eval_gz,
            fill,
            np2ctypes(out),
            int2ctypes(nkpts),
            int2ctypes(self.ncomp),
            int2ctypes(len(self.ls)),
            np2ctypes(self.ls),
            np2ctypes(expkl),
            c_shls_slice,
            np2ctypes(ao_loc),
            np2ctypes(self.GvT),
            p_b,
            p_gxyzT,
            p_mesh,
            int2ctypes(nGv),
            np2ctypes(atm),
            int2ctypes(len(atm)),
            np2ctypes(bas),
            int2ctypes(len(bas)),
            np2ctypes(env),
        )

        return torch.as_tensor(out,
                               dtype=get_complex_dtype(self.dtype),
                               device=self.device)
示例#2
0
    def _int2c(self) -> torch.Tensor:
        """Compute the 2-centre integral with the ``PBCnr2c_drv`` C driver.

        Most of the work happens in numpy/C, so no gradient is propagated
        through this method (which is acceptable here). The procedure closely
        follows the ``intor_cross`` function of PySCF:
        https://github.com/pyscf/pyscf/blob/master/pyscf/pbc/gto/cell.py
        https://github.com/pyscf/pyscf/blob/f1321d5dd4fa103b5b04f10f31389c408949269d/pyscf/pbc/gto/cell.py#L345

        Returns:
            A complex tensor placed on ``self.device`` with shape
            ``(nkpts, *comp_shape, nao_0, nao_1)``.
        """
        assert len(self.wrappers) == 2

        bra = self.wrappers[0]
        ket = self.wrappers[1]

        # libpbc shifts the basis of one of the wrappers in place, hence the
        # integral is run on a concatenated copy of the atm/bas/env arrays
        atm, bas, env, ao_loc = _concat_atm_bas_env(bra, ket)
        bra_lo, bra_hi = bra.shell_idxs
        ket_lo, ket_hi = ket.shell_idxs
        # shells of the second wrapper come after all shells of the first parent
        ket_shift = len(bra.parent)
        shls_slice = (bra_lo, bra_hi, ket_lo + ket_shift, ket_hi + ket_shift)

        # output buffer that the C driver fills in
        nkpts = len(self.kpts_inp_np)
        lead_dims = (nkpts, ) + self.comp_shape
        nao_dims = tuple(w.nao() for w in self.wrappers)
        out = np.empty(lead_dims + nao_dims, dtype=np.complex128)

        # TODO: add symmetry here
        fill = CPBC().PBCnr2c_fill_ks1
        fintor = getattr(CGTO(), self.opname)
        # TODO: use proper optimizers
        cintopt = _get_intgl_optimizer(self.opname, atm, bas, env)
        cpbcopt = c_null_ptr()

        # phase factors exp(i k.L) for every (k-point, lattice vector) pair
        k_dot_l = np.dot(self.kpts_inp_np, self.ls.T)
        expkl = np.asarray(np.exp(1j * k_dot_l), order='C')

        # a very long list of lattice vectors is known to risk a segfault
        n_neighbors = self.ls.shape[0]
        if n_neighbors > 1e6:
            warnings.warn(
                "The number of neighbors in the integral is too many, "
                "it might causes segfault")

        # run the integration; the positional order below is dictated by the
        # C function signature
        drv = CPBC().PBCnr2c_drv
        c_shls_slice = (ctypes.c_int * len(shls_slice))(*shls_slice)
        drv(
            fintor,
            fill,
            out.ctypes.data_as(ctypes.c_void_p),
            int2ctypes(nkpts),
            int2ctypes(self.ncomp),
            int2ctypes(len(self.ls)),
            np2ctypes(self.ls),
            np2ctypes(expkl),
            c_shls_slice,
            np2ctypes(ao_loc),
            cintopt,
            cpbcopt,
            np2ctypes(atm),
            int2ctypes(atm.shape[0]),
            np2ctypes(bas),
            int2ctypes(bas.shape[0]),
            np2ctypes(env),
            int2ctypes(env.size),
        )

        return torch.as_tensor(out,
                               dtype=get_complex_dtype(self.dtype),
                               device=self.device)
示例#3
0
    def _int3c(self) -> torch.Tensor:
        """Compute the 3-centre integral with the ``PBCnr3c_drv`` C driver.

        Most of the work happens in numpy/C, so no gradient is propagated
        through this method (which is acceptable here). The procedure closely
        follows the ``aux_e2`` and ``wrap_int3c`` functions of PySCF:
        https://github.com/pyscf/pyscf/blob/master/pyscf/pbc/df/incore.py
        https://github.com/pyscf/pyscf/blob/f1321d5dd4fa103b5b04f10f31389c408949269d/pyscf/pbc/df/incore.py#L46

        Returns:
            A complex tensor placed on ``self.device`` with shape
            ``(nkpts_ij, *comp_shape, nao_0, nao_1, nao_2)``.
        """
        assert len(self.wrappers) == 3

        # libpbc shifts the basis of one of the wrappers in place, hence the
        # integral is run on a concatenated copy of the atm/bas/env arrays
        atm, bas, env, ao_loc = _concat_atm_bas_env(*self.wrappers)
        i_lo, i_hi = self.wrappers[0].shell_idxs
        j_lo, j_hi = self.wrappers[1].shell_idxs
        k_lo, k_hi = self.wrappers[2].shell_idxs
        # each wrapper's shells are offset by the shell counts of the parents
        # that precede it in the concatenated arrays
        j_shift = len(self.wrappers[0].parent)
        k_shift = len(self.wrappers[1].parent) + j_shift
        shls_slice = (i_lo, i_hi,
                      j_lo + j_shift, j_hi + j_shift,
                      k_lo + k_shift, k_hi + k_shift)

        # output buffer; the k-points given to this object are k-point pairs
        nkpts_ij = len(self.kpts_inp_np)
        lead_dims = (nkpts_ij, ) + self.comp_shape
        nao_dims = tuple(w.nao() for w in self.wrappers)
        out = np.empty(lead_dims + nao_dims, dtype=np.complex128)

        # collect the unique k-points out of both members of every pair
        kpts_i = self.kpts_inp_np[:, 0, :]  # (nkpts_ij, NDIM)
        kpts_j = self.kpts_inp_np[:, 1, :]
        kpts_stack = np.concatenate((kpts_i, kpts_j), axis=0)
        kpt_diff_tol = self.options.kpt_diff_tol
        # snap to a grid of spacing kpt_diff_tol before looking for duplicates
        snapped = np.floor(kpts_stack / kpt_diff_tol) * kpt_diff_tol
        first_occurrence = np.unique(snapped, axis=0, return_index=True)[1]
        kpts = kpts_stack[first_occurrence, :]
        nkpts = len(kpts)
        k_dot_l = np.dot(kpts, self.ls.T)
        expkl = np.asarray(np.exp(1j * k_dot_l), order="C")

        # TODO: check if it is the index inverse from unique
        def _locate(points):
            # index into `kpts` of every entry whose L1 distance to a point
            # is below the tolerance
            l1_dist = np.abs(points.reshape(-1, 1, 3) - kpts).sum(axis=2)
            return np.where(l1_dist < kpt_diff_tol)[1]

        wherei = _locate(kpts_i)
        wherej = _locate(kpts_j)
        # flatten the (i, j) pair of unique-k-point indices into one index
        kpts_ij_idxs = np.asarray(wherei * nkpts + wherej, dtype=np.int32)

        # TODO: use proper optimizers
        # NOTE: using _get_intgl_optimizer(self.opname, atm, bas, env) in this
        # case produce wrong results (I don't know why)
        cintopt = c_null_ptr()
        cpbcopt = c_null_ptr()

        # run the integration; the positional order below is dictated by the
        # C function signature
        drv = CPBC().PBCnr3c_drv
        # TODO: optimize the kk-type and symmetry
        fill = CPBC().PBCnr3c_fill_kks1
        fintor = getattr(CPBC(), self.opname)
        c_shls_slice = (ctypes.c_int * len(shls_slice))(*shls_slice)
        drv(
            fintor,
            fill,
            np2ctypes(out),
            int2ctypes(nkpts_ij),
            int2ctypes(nkpts),
            int2ctypes(self.ncomp),
            int2ctypes(len(self.ls)),
            np2ctypes(self.ls),
            np2ctypes(expkl),
            np2ctypes(kpts_ij_idxs),
            c_shls_slice,
            np2ctypes(ao_loc),
            cintopt,
            cpbcopt,
            np2ctypes(atm),
            int2ctypes(atm.shape[0]),
            np2ctypes(bas),
            int2ctypes(bas.shape[0]),
            np2ctypes(env),
            int2ctypes(env.size),
        )

        return torch.as_tensor(out,
                               dtype=get_complex_dtype(self.dtype),
                               device=self.device)
示例#4
0
def gto_ft_evaluator(wrapper: LibcintWrapper,
                     Gvgrid: torch.Tensor) -> torch.Tensor:
    """Evaluate the Fourier transform of the basis functions on a G-vector grid.

    The transform is defined as ``FT(f(r)) = integral(f(r) * exp(-ik.r) dr)``.

    NOTE: this function does not propagate gradients and should only be used
    inside this file. It is mainly taken from PySCF:
    https://github.com/pyscf/pyscf/blob/c9aa2be600d75a97410c3203abf35046af8ca615/pyscf/gto/ft_ao.py#L107

    Args:
        wrapper: Basis wrapper whose functions are transformed; its ``dtype``
            and ``device`` decide where the result is placed.
        Gvgrid: Grid points with shape ``(ngrid, ndim)``. It may live on any
            device; it is detached and copied to host memory for the C call.

    Returns:
        A complex tensor of shape ``(nao, ngrid)`` on ``wrapper.device``.
    """
    assert Gvgrid.ndim == 2
    assert Gvgrid.shape[-1] == NDIM

    # Gvgrid: (ngrid, ndim)
    # returns: (nao, ngrid)
    dtype = wrapper.dtype
    device = wrapper.device

    # C entry points
    fill = CGTO.GTO_ft_fill_s1
    if wrapper.spherical:
        intor = CGTO.GTO_ft_ovlp_sph
    else:
        intor = CGTO.GTO_ft_ovlp_cart
    fn = CGTO.GTO_ft_fill_drv

    # placeholder arguments required by the driver signature
    eval_gz = CGTO.GTO_Gv_general
    p_gxyzT = c_null_ptr()
    p_gs = (ctypes.c_int * 3)(0, 0, 0)
    p_b = (ctypes.c_double * 1)(0)

    # add another dummy basis to provide the multiplier
    c = np.sqrt(4 * np.pi)  # s-type normalization
    ghost_basis = CGTOBasis(
        angmom=0,
        alphas=torch.tensor([0.], dtype=dtype, device=device),
        coeffs=torch.tensor([c], dtype=dtype, device=device),
        normalized=True,
    )
    ghost_atom_basis = AtomCGTOBasis(atomz=0,
                                     bases=[ghost_basis],
                                     pos=torch.tensor([0.0, 0.0, 0.0],
                                                      dtype=dtype,
                                                      device=device))
    ghost_wrapper = LibcintWrapper([ghost_atom_basis],
                                   spherical=wrapper.spherical,
                                   lattice=wrapper.lattice)
    # from here on `wrapper` refers to the sub-wrapper of the concatenation
    wrapper, ghost_wrapper = LibcintWrapper.concatenate(wrapper, ghost_wrapper)
    shls_slice = (*wrapper.shell_idxs, *ghost_wrapper.shell_idxs)
    ao_loc = wrapper.full_shell_to_aoloc
    atm, bas, env = wrapper.atm_bas_env

    # prepare the Gvgrid: Tensor.numpy() only works on host tensors, so move
    # the grid to the CPU first. The buffer is handed to C as a raw pointer
    # next to c_double / complex128 companions, so it is forced to float64
    # (a no-op for float64 input).
    # NOTE(review): assumes np2ctypes performs no dtype conversion - confirm.
    GvT = np.asarray(Gvgrid.detach().cpu().numpy().T,
                     dtype=np.float64,
                     order="C")
    nGv = Gvgrid.shape[0]

    # prepare the output matrix
    outshape = (wrapper.nao(), nGv)
    out = np.zeros(outshape, dtype=np.complex128, order="C")

    # the positional order below is dictated by the C function signature
    fn(intor, eval_gz, fill, np2ctypes(out), int2ctypes(1),
       (ctypes.c_int * len(shls_slice))(*shls_slice), np2ctypes(ao_loc),
       ctypes.c_double(0), np2ctypes(GvT), p_b, p_gxyzT, p_gs, int2ctypes(nGv),
       np2ctypes(atm), int2ctypes(len(atm)), np2ctypes(bas),
       int2ctypes(len(bas)), np2ctypes(env))

    return torch.as_tensor(out, dtype=get_complex_dtype(dtype), device=device)