class HybridXC(HybridXCBase):
    """Hybrid XC functional with exact exchange in a plane-wave basis.

    The semi-local part is delegated to ``self.xc``.  ``self.xc``,
    ``self.hybrid`` and ``self.omega`` are not assigned in this class; they
    are expected to be provided by ``HybridXCBase``.  The exact-exchange
    energy is computed non-selfconsistently by :meth:`calculate_exx` and
    added, scaled by ``self.hybrid``, in :meth:`calculate`.
    """

    orbital_dependent = True

    def __init__(self, name, hybrid=None, xc=None,
                 alpha=None,
                 gamma_point=1,
                 method='standard',
                 bandstructure=False,
                 logfilename='-', bands=None,
                 fcut=1e-10,
                 molecule=False,
                 qstride=1,
                 world=None):
        """Mix standard functionals with exact exchange.

        name: str
            Name of functional: EXX, PBE0, HSE03, HSE06
        hybrid: float
            Fraction of exact exchange.
        xc: str or XCFunctional object
            Standard DFT functional with scaled down exchange.
        method: str
            Use 'standard' for the standard formula and 'acdf' for the
            adiabatic-connection dissipation fluctuation formula.
        alpha: float
            Exponent (in Bohr^2) of the Gaussian exp(-alpha*|G+q|^2) used
            to compute the q=0 (gamma point) correction when
            gamma_point=1.  Default (None) is 6 * V^(2/3) / pi^2, where V
            is the unit-cell volume.  alpha=0 uses instead an analytic
            estimate based on a sphere with the volume of one q-point cell.
        gamma_point: int
            0: Skip k2-k1=0 interactions.
            1: Use the alpha method.
            2: Integrate the gamma point (value obtained via madelung()
               of the cell scaled by the k-point grid).
        bandstructure: bool
            Calculate bandstructure instead of just the total energy.
        logfilename: str or file
            Log destination; converted with logfile() in initialize().
        bands: list of int
            List of bands to calculate bandstructure for.  Default is
            all bands.
        molecule: bool
            Decouple electrostatic interactions between periodically
            repeated images.
        fcut: float
            Threshold for empty band: bands with |f| < fcut count as empty.
        qstride: int or sequence of 3 ints
            Stride of the q-vector grid.  A scalar is used for all three
            directions.
        world: MPI communicator
            Communicator used for the q-vector parallelization.  Defaults
            to mpi.world.
        """

        self.alpha = alpha
        self.fcut = fcut

        self.gamma_point = gamma_point
        self.method = method
        self.bandstructure = bandstructure
        self.bands = bands

        # Replaced by a file object in initialize():
        self.fd = logfilename
        self.write_timing_information = True

        HybridXCBase.__init__(self, name, hybrid, xc)

        # EXX energies:
        self.exx = None  # total
        self.evv = None  # valence-valence (pseudo part, standard formula)
        self.evvacdf = None  # valence-valence (pseudo part, ACDF formula)
        self.devv = None  # valence-valence (PAW correction)
        self.evc = None  # valence-core
        self.ecc = None  # core-core

        self.exx_skn = None  # bandstructure

        # Index of the q-vector whose compensation-charge expansion is
        # currently cached in self.f_IG (see calculate_interaction):
        self.qlatest = None

        if world is None:
            world = mpi.world
        self.world = world

        self.molecule = molecule

        # Accept any scalar (Python or NumPy integer) as well as a sequence
        # of three strides.  A NumPy scalar would otherwise become a 0-d
        # array whose prod() is the stride itself instead of its cube.
        if np.ndim(qstride) == 0:
            qstride = [qstride] * 3
        self.qstride_c = np.asarray(qstride)

        self.timer = Timer()

    def log(self, *args, **kwargs):
        """Write a line to the log and flush it.

        self.fd must be file-like; initialize() takes care of that.
        """
        prnt(file=self.fd, *args, **kwargs)
        self.fd.flush()

    def calculate_radial(self, rgd, n_sLg, Y_L, v_sg,
                         dndr_sLg=None, rnablaY_Lv=None,
                         tau_sg=None, dedtau_sg=None):
        """Delegate the radial XC calculation to the semi-local functional.

        NOTE(review): tau_sg and dedtau_sg are accepted but not forwarded.
        """
        return self.xc.calculate_radial(rgd, n_sLg, Y_L, v_sg,
                                        dndr_sLg, rnablaY_Lv)

    def calculate_paw_correction(self, setup, D_sp, dEdD_sp=None,
                                 addcoredensity=True, a=None):
        """Delegate the PAW XC correction to the semi-local functional."""
        return self.xc.calculate_paw_correction(setup, D_sp, dEdD_sp,
                                                addcoredensity, a)

    def initialize(self, dens, ham, wfs, occupations):
        """Store density/wave-function objects and open the log file.

        Band parallelization is not supported (asserted).
        """
        assert wfs.bd.comm.size == 1

        self.xc.initialize(dens, ham, wfs, occupations)

        self.dens = dens
        self.wfs = wfs

        # Make a k-point descriptor that is not distributed
        # (self.kd.comm is serial_comm):
        self.kd = wfs.kd.copy()

        self.fd = logfile(self.fd, self.world.rank)

        wfs.initialize_wave_functions_from_restart_file()

    def set_positions(self, spos_ac):
        """Store scaled atomic positions (used for compensation charges)."""
        self.spos_ac = spos_ac

    def calculate(self, gd, n_sg, v_sg=None, e_g=None):
        """Return the semi-local XC energy plus hybrid * EXX energy.

        calculate_exx() must have been called first; otherwise self.exx
        is still None.
        """
        # Normal XC contribution:
        exc = self.xc.calculate(gd, n_sg, v_sg, e_g)

        # Add EXX contribution:
        return exc + self.exx * self.hybrid

    def calculate_exx(self):
        """Non-selfconsistent calculation of the exact-exchange energy.

        Sets self.evv, self.evvacdf, self.devv, self.evc, self.ecc and
        self.exx (and self.exx_skn when bandstructure=True).

        Raises ValueError if self.method is not 'standard' or 'acdf'.
        """

        # Fail before doing any (expensive) work:
        if self.method not in ('standard', 'acdf'):
            raise ValueError("Unknown method %r: use 'standard' or 'acdf'"
                             % (self.method,))

        self.timer.start('EXX')
        self.timer.start('Initialization')

        kd = self.kd
        wfs = self.wfs

        if fftw.FFTPlan is fftw.NumpyFFTPlan:
            self.log('NOT USING FFTW !!')

        self.log('Spins:', self.wfs.nspins)

        W = max(1, self.wfs.kd.comm.size // self.wfs.nspins)
        # Are the k-points distributed?
        kparallel = (W > 1)

        # Find number of occupied bands (index of first band with
        # |f| < fcut, or nbands if all are occupied):
        self.nocc_sk = np.zeros((self.wfs.nspins, kd.nibzkpts), int)
        for kpt in self.wfs.kpt_u:
            for n, f in enumerate(kpt.f_n):
                if abs(f) < self.fcut:
                    self.nocc_sk[kpt.s, kpt.k] = n
                    break
            else:
                self.nocc_sk[kpt.s, kpt.k] = self.wfs.bd.nbands
        self.wfs.kd.comm.sum(self.nocc_sk)

        noccmin = self.nocc_sk.min()
        noccmax = self.nocc_sk.max()
        self.log('Number of occupied bands (min, max): %d, %d' %
                 (noccmin, noccmax))

        self.log('Number of valence electrons:', self.wfs.setups.nvalence)

        if self.bandstructure:
            self.log('Calculating eigenvalue shifts.')

            # allocate array for eigenvalue shifts:
            self.exx_skn = np.zeros((self.wfs.nspins,
                                     kd.nibzkpts,
                                     self.wfs.bd.nbands))

            if self.bands is None:
                noccmax = self.wfs.bd.nbands
            else:
                noccmax = max(max(self.bands) + 1, noccmax)

        N_c = self.kd.N_c

        vol = wfs.gd.dv * wfs.gd.N_c.prod()
        if self.alpha is None:
            alpha = 6 * vol**(2 / 3.0) / pi**2
        else:
            alpha = self.alpha
        if self.gamma_point == 1:
            if alpha == 0.0:
                qvol = (2*np.pi)**3 / vol / N_c.prod()
                self.gamma = 4*np.pi * (3*qvol / (4*np.pi))**(1/3.) / qvol
            else:
                self.gamma = self.calculate_gamma(vol, alpha)
        else:
            # Cell vectors scaled by the k-point grid:
            kcell_cv = wfs.gd.cell_cv * N_c[:, np.newaxis]
            self.gamma = madelung(kcell_cv) * vol * N_c.prod() / (4 * np.pi)

        self.log('Value of alpha parameter: %.3f Bohr^2' % alpha)
        self.log('Value of gamma parameter: %.3f Bohr^2' % self.gamma)

        # Construct all possible q=k2-k1 vectors:
        Nq_c = (N_c - 1) // self.qstride_c
        i_qc = np.indices(Nq_c * 2 + 1, float).transpose(
            (1, 2, 3, 0)).reshape((-1, 3))
        self.bzq_qc = (i_qc - Nq_c) / N_c * self.qstride_c
        self.q0 = ((Nq_c * 2 + 1).prod() - 1) // 2  # index of q=(0,0,0)
        assert not self.bzq_qc[self.q0].any()

        # Count number of pairs for each q-vector:
        self.npairs_q = np.zeros(len(self.bzq_qc), int)
        for s in range(kd.nspins):
            for k1 in range(kd.nibzkpts):
                for k2 in range(kd.nibzkpts):
                    for K2, q, n1_n, n2 in self.indices(s, k1, k2):
                        self.npairs_q[q] += len(n1_n)

        self.npairs0 = self.npairs_q.sum()  # total number of pairs

        self.log('Number of pairs:', self.npairs0)

        # Distribute q-vectors to Q processors so that each gets roughly
        # the same number of pairs:
        Q = self.world.size // self.wfs.kd.comm.size
        myrank = self.world.rank // self.wfs.kd.comm.size
        rank = 0
        N = 0
        myq = []
        nq = 0
        for q, n in enumerate(self.npairs_q):
            if n > 0:
                nq += 1
                if rank == myrank:
                    myq.append(q)
            N += n
            if N >= (rank + 1.0) * self.npairs0 / Q:
                rank += 1

        assert len(myq) > 0, 'Too few q-vectors for too many processes!'
        self.bzq_qc = self.bzq_qc[myq]
        # From here on self.q0 indexes the local q-vector list (None when
        # q=0 is handled by another process):
        try:
            self.q0 = myq.index(self.q0)
        except ValueError:
            self.q0 = None

        self.log('%d x %d x %d k-points' % tuple(self.kd.N_c))
        self.log('Distributing %d IBZ k-points over %d process(es).' %
                 (kd.nibzkpts, self.wfs.kd.comm.size))
        self.log('Distributing %d q-vectors over %d process(es).' % (nq, Q))

        # q-point descriptor for my q-vectors:
        qd = KPointDescriptor(self.bzq_qc)

        # Plane-wave descriptor for all wave-functions:
        self.pd = PWDescriptor(wfs.pd.ecut, wfs.gd,
                               dtype=wfs.pd.dtype, kd=kd)

        # Plane-wave descriptor pair-densities:
        self.pd2 = PWDescriptor(self.dens.pd2.ecut, self.dens.gd,
                                dtype=wfs.dtype, kd=qd)

        self.log('Cutoff energies:')
        self.log('    Wave functions:       %10.3f eV' %
                 (self.pd.ecut * Hartree))
        self.log('    Density:              %10.3f eV' %
                 (self.pd2.ecut * Hartree))

        # Calculate 1/|G+q|^2 with special treatment of |G+q|=0:
        G2_qG = self.pd2.G2_qG
        if self.q0 is None:
            if self.omega is None:
                self.iG2_qG = [1.0 / G2_G for G2_G in G2_qG]
            else:
                self.iG2_qG = [(1.0 / G2_G *
                                (1 - np.exp(-G2_G / (4 * self.omega**2))))
                               for G2_G in G2_qG]
        else:
            G2_qG[self.q0][0] = 117.0  # avoid division by zero
            if self.omega is None:
                self.iG2_qG = [1.0 / G2_G for G2_G in G2_qG]
                self.iG2_qG[self.q0][0] = self.gamma
            else:
                self.iG2_qG = [(1.0 / G2_G *
                                (1 - np.exp(-G2_G / (4 * self.omega**2))))
                               for G2_G in G2_qG]
                self.iG2_qG[self.q0][0] = 1 / (4 * self.omega**2)
            G2_qG[self.q0][0] = 0.0  # restore correct value

        # Compensation charges:
        self.ghat = PWLFC([setup.ghat_l for setup in wfs.setups], self.pd2)
        self.ghat.set_positions(self.spos_ac)

        if self.molecule:
            self.initialize_gaussian()
            self.log('Value of beta parameter: %.3f 1/Bohr^2' % self.beta)

        self.timer.stop('Initialization')

        # Ready ... set ... go:
        self.t0 = time()
        self.npairs = 0
        self.evv = 0.0
        self.evvacdf = 0.0
        for s in range(self.wfs.nspins):
            kpt1_q = [KPoint(self.wfs, noccmax).initialize(kpt)
                      for kpt in self.wfs.kpt_u if kpt.s == s]
            kpt2_q = kpt1_q[:]

            if len(kpt1_q) == 0:
                # No s-spins on this CPU:
                continue

            # Send and receive ranks:
            srank = self.wfs.kd.get_rank_and_index(
                s, (kpt1_q[0].k - 1) % kd.nibzkpts)[0]
            rrank = self.wfs.kd.get_rank_and_index(
                s, (kpt1_q[-1].k + 1) % kd.nibzkpts)[0]

            # Shift k-points kd.nibzkpts - 1 times.  Each step the first
            # entry of kpt2_q is dropped and the next k-point appended
            # (received from rrank when the k-points are distributed):
            for i in range(kd.nibzkpts):
                if i < kd.nibzkpts - 1:
                    if kparallel:
                        kpt = kpt2_q[-1].next(self.wfs)
                        kpt.start_receiving(rrank)
                        kpt2_q[0].start_sending(srank)
                    else:
                        kpt = kpt2_q[0]

                self.timer.start('Calculate')
                for kpt1, kpt2 in zip(kpt1_q, kpt2_q):
                    # Loop over all k-points that k2 can be mapped to:
                    for K2, q, n1_n, n2 in self.indices(s, kpt1.k, kpt2.k):
                        self.apply(K2, q, kpt1, kpt2, n1_n, n2)
                self.timer.stop('Calculate')

                if i < kd.nibzkpts - 1:
                    self.timer.start('Wait')
                    if kparallel:
                        kpt.wait()
                        kpt2_q[0].wait()
                    self.timer.stop('Wait')
                    kpt2_q.pop(0)
                    kpt2_q.append(kpt)

        self.evv = self.world.sum(self.evv)
        self.evvacdf = self.world.sum(self.evvacdf)
        self.calculate_exx_paw_correction()

        if self.method == 'standard':
            self.exx = self.evv + self.devv + self.evc + self.ecc
        else:  # 'acdf' (validated at the top of this method)
            self.exx = self.evvacdf + self.devv + self.evc + self.ecc

        self.log('Exact exchange energy:')
        for txt, e in [
            ('core-core', self.ecc),
            ('valence-core', self.evc),
            ('valence-valence (pseudo, acdf)', self.evvacdf),
            ('valence-valence (pseudo, standard)', self.evv),
            ('valence-valence (correction)', self.devv),
            ('total (%s)' % self.method, self.exx)]:
            self.log('    %-36s %14.6f eV' % (txt + ':', e * Hartree))

        self.log('Total time: %10.3f seconds' % (time() - self.t0))

        self.npairs = self.world.sum(self.npairs)
        assert self.npairs == self.npairs0

        self.timer.stop('EXX')
        self.timer.write(self.fd)

    def calculate_gamma(self, vol, alpha):
        """Return the gamma-point value used with the alpha method.

        Returns 0.0 when molecule=True.
        """
        if self.molecule:
            return 0.0

        N_c = self.kd.N_c
        offset_c = (N_c + 1) % 2 * 0.5 / N_c
        bzq_qc = monkhorst_pack(N_c) + offset_c
        qd = KPointDescriptor(bzq_qc)
        pd = PWDescriptor(self.wfs.pd.ecut, self.wfs.gd, kd=qd)
        gamma = (vol / (2 * pi)**2 * sqrt(pi / alpha) *
                 self.kd.nbzkpts)
        for G2_G in pd.G2_qG:
            # Leave out the |G+q|=0 term:
            if G2_G[0] < 1e-7:
                G2_G = G2_G[1:]
            gamma -= np.dot(np.exp(-alpha * G2_G), G2_G**-1)
        return gamma / self.qstride_c.prod()

    def indices(self, s, k1, k2):
        """Generator for (K2, q, n1, n2) indices for (k1, k2) pair.

        s: int
            Spin index.
        k1: int
            Index of k-point in the IBZ.
        k2: int
            Index of k-point in the IBZ.

        Yields (K, q, n1_n, n2), where K is the index of the k-point in
        the BZ that k2 is mapped to, q is the index of the q-vector
        between K and k1, and n1_n is a list of bands that should be
        combined with band n2."""

        for K2, k in enumerate(self.kd.bz2ibz_k):
            if k == k2:
                for item in self._indices(s, k1, k2, K2):
                    yield item

    def _indices(self, s, k1, k2, K2):
        """Yield (K2, q, n1_n, n2) for IBZ k-point k1 and BZ k-point K2.

        Yields nothing if K2 - k1 is not one of the local q-vectors, or
        if it is q=0 and gamma_point == 0.
        """
        k1_c = self.kd.ibzk_kc[k1]
        k2_c = self.kd.bzk_kc[K2]
        q_c = k2_c - k1_c
        q = abs(self.bzq_qc - q_c).sum(1).argmin()
        if abs(self.bzq_qc[q] - q_c).sum() > 1e-7:
            return

        if self.gamma_point == 0 and q == self.q0:
            return

        nocc1 = self.nocc_sk[s, k1]
        nocc2 = self.nocc_sk[s, k2]

        # Is k2 in the IBZ?
        is_ibz2 = (self.kd.ibz2bz_k[k2] == K2)

        for n2 in range(self.wfs.bd.nbands):
            # Find range of n1's (from n1a to n1b-1):
            if is_ibz2:
                # We get this combination twice, so let's only do half:
                if k1 >= k2:
                    n1a = n2
                else:
                    n1a = n2 + 1
            else:
                n1a = 0

            n1b = self.wfs.bd.nbands

            if self.bandstructure:
                if n2 >= nocc2:
                    n1b = min(n1b, nocc1)
            else:
                if n2 >= nocc2:
                    break
                n1b = min(n1b, nocc1)

            if self.bands is not None:
                assert self.bandstructure
                n1_n = []
                for n1 in range(n1a, n1b):
                    if (n1 in self.bands and n2 < nocc2 or
                        is_ibz2 and n2 in self.bands and n1 < nocc1):
                        n1_n.append(n1)
                n1_n = np.array(n1_n)
            else:
                n1_n = np.arange(n1a, n1b)

            if len(n1_n) == 0:
                continue

            yield K2, q, n1_n, n2

    def apply(self, K2, q, kpt1, kpt2, n1_n, n2):
        """Accumulate the exchange contribution of bands n1_n with n2.

        Adds to self.evv and self.evvacdf (and self.exx_skn when
        bandstructure=True).
        """
        k20_c = self.kd.ibzk_kc[kpt2.k]
        k2_c = self.kd.bzk_kc[K2]

        if k2_c.any():
            self.timer.start('Initialize plane waves')
            eik2r_R = self.wfs.gd.plane_wave(k2_c)
            eik20r_R = self.wfs.gd.plane_wave(k20_c)
            self.timer.stop('Initialize plane waves')
        else:
            eik2r_R = 1.0
            eik20r_R = 1.0

        w1 = self.kd.weight_k[kpt1.k]
        w2 = self.kd.weight_k[kpt2.k]

        # Is K2 the IBZ representative of kpt2.k itself?
        is_ibz2 = (self.kd.ibz2bz_k[kpt2.k] == K2)

        e_n = self.calculate_interaction(n1_n, n2, kpt1, kpt2, q, K2,
                                         eik20r_R, eik2r_R,
                                         is_ibz2)

        e_n *= 1.0 / self.kd.nbzkpts / self.wfs.nspins * self.qstride_c.prod()

        if q == self.q0:
            e_n[n1_n == n2] *= 0.5

        f1_n = kpt1.f_n[n1_n]
        eps1_n = kpt1.eps_n[n1_n]
        f2 = kpt2.f_n[n2]
        eps2 = kpt2.eps_n[n2]

        s_n = np.sign(eps2 - eps1_n)

        evv = (f1_n * f2 * e_n).sum()
        evvacdf = 0.5 * (f1_n * (1 - s_n) * e_n +
                         f2 * (1 + s_n) * e_n).sum()
        self.evv += evv * w1
        self.evvacdf += evvacdf * w1
        if is_ibz2:
            self.evv += evv * w2
            self.evvacdf += evvacdf * w2

        if self.bandstructure:
            x = self.wfs.nspins
            self.exx_skn[kpt1.s, kpt1.k, n1_n] += x * f2 * e_n
            if is_ibz2:
                self.exx_skn[kpt2.s, kpt2.k, n2] += x * np.dot(f1_n, e_n)

    def calculate_interaction(self, n1_n, n2, kpt1, kpt2, q, k,
                              eik20r_R, eik2r_R, is_ibz2):
        """Calculate Coulomb interactions.

        For all n1 in the n1_n list, calculate interaction with n2.
        Returns a real array with one energy per entry of n1_n."""

        # number of plane waves:
        ng1 = self.wfs.ng_k[kpt1.k]
        ng2 = self.wfs.ng_k[kpt2.k]

        # Transform to real space and apply symmetry operation:
        self.timer.start('IFFT1')
        if is_ibz2:
            u2_R = self.pd.ifft(kpt2.psit_nG[n2, :ng2], kpt2.k)
        else:
            psit2_R = self.pd.ifft(kpt2.psit_nG[n2, :ng2], kpt2.k) * eik20r_R
            self.timer.start('Symmetry transform')
            u2_R = self.kd.transform_wave_function(psit2_R, k) / eik2r_R
            self.timer.stop()
        self.timer.stop()

        # Calculate pair densities:
        nt_nG = self.pd2.zeros(len(n1_n), q=q)
        for n1, nt_G in zip(n1_n, nt_nG):
            self.timer.start('IFFT2')
            u1_R = self.pd.ifft(kpt1.psit_nG[n1, :ng1], kpt1.k)
            self.timer.stop()
            nt_R = u1_R.conj() * u2_R
            self.timer.start('FFT')
            nt_G[:] = self.pd2.fft(nt_R, q)
            self.timer.stop()

        s = self.kd.sym_k[k]
        time_reversal = self.kd.time_reversal_k[k]
        k2_c = self.kd.ibzk_kc[kpt2.k]

        self.timer.start('Compensation charges')
        Q_anL = {}  # coefficients for shape functions
        for a, P1_ni in kpt1.P_ani.items():
            P1_ni = P1_ni[n1_n]

            if is_ibz2:
                P2_i = kpt2.P_ani[a][n2]
            else:
                # Rotate projections of the symmetry-equivalent atom b:
                b = self.kd.symmetry.a_sa[s, a]
                S_c = (np.dot(self.spos_ac[a], self.kd.symmetry.op_scc[s]) -
                       self.spos_ac[b])
                assert abs(S_c.round() - S_c).max() < 1e-5
                if self.ghat.dtype == complex:
                    x = np.exp(2j * pi * np.dot(k2_c, S_c))
                else:
                    x = 1.0
                P2_i = np.dot(self.wfs.setups[a].R_sii[s],
                              kpt2.P_ani[b][n2]) * x
                if time_reversal:
                    P2_i = P2_i.conj()

            D_np = []
            for P1_i in P1_ni:
                D_ii = np.outer(P1_i.conj(), P2_i)
                D_np.append(pack(D_ii))
            Q_anL[a] = np.dot(D_np, self.wfs.setups[a].Delta_pL)

        self.timer.start('Expand')
        # Re-expand only when q changes (cached in self.f_IG):
        if q != self.qlatest:
            self.f_IG = self.ghat.expand(q)
            self.qlatest = q
        self.timer.stop('Expand')

        # Add compensation charges:
        self.ghat.add(nt_nG, Q_anL, q, self.f_IG)
        self.timer.stop('Compensation charges')

        if self.molecule and n2 in n1_n:
            nn = (n1_n == n2).nonzero()[0][0]
            nt_nG[nn] -= self.ngauss_G
        else:
            nn = None

        iG2_G = self.iG2_qG[q]

        # Calculate energies:
        e_n = np.empty(len(n1_n))
        for n, nt_G in enumerate(nt_nG):
            e_n[n] = -4 * pi * np.real(self.pd2.integrate(nt_G, nt_G * iG2_G))
            self.npairs += 1

        if nn is not None:
            # Take the real part explicitly (as above) instead of relying
            # on a lossy complex -> float cast into e_n:
            e_n[nn] -= 2 * np.real(
                self.pd2.integrate(nt_nG[nn], self.vgauss_G) +
                (self.beta / 2 / pi)**0.5)

        if self.write_timing_information:
            t = (time() - self.t0) / len(n1_n)
            self.log('Time for first pair-density: %10.3f seconds' % t)
            self.log('Estimated total time:        %10.3f seconds' %
                     (t * self.npairs0 / self.world.size))
            self.write_timing_information = False

        return e_n

    def calculate_exx_paw_correction(self):
        """Compute PAW corrections to the EXX energy.

        Sets self.devv, self.evc and self.ecc; with bandstructure=True also
        adds PAW terms to self.exx_skn, sums it over self.world and scales
        it by self.hybrid.
        """
        self.timer.start('PAW correction')
        self.devv = 0.0
        self.evc = 0.0
        self.ecc = 0.0

        deg = 2 // self.wfs.nspins  # spin degeneracy
        for a, D_sp in self.dens.D_asp.items():
            setup = self.wfs.setups[a]
            for D_p in D_sp:
                D_ii = unpack2(D_p)
                ni = len(D_ii)

                for i1 in range(ni):
                    for i2 in range(ni):
                        A = 0.0
                        for i3 in range(ni):
                            p13 = packed_index(i1, i3, ni)
                            for i4 in range(ni):
                                p24 = packed_index(i2, i4, ni)
                                A += setup.M_pp[p13, p24] * D_ii[i3, i4]
                        self.devv -= D_ii[i1, i2] * A / deg

                self.evc -= np.dot(D_p, setup.X_p)
            self.ecc += setup.ExxC

        if not self.bandstructure:
            self.timer.stop('PAW correction')
            return

        Q = self.world.size // self.wfs.kd.comm.size
        self.exx_skn *= Q
        for kpt in self.wfs.kpt_u:
            for a, D_sp in self.dens.D_asp.items():
                setup = self.wfs.setups[a]
                # NOTE(review): this loops over all spin channels of D_sp
                # for every kpt regardless of kpt.s; confirm this is the
                # intended treatment for spin-polarized calculations.
                for D_p in D_sp:
                    D_ii = unpack2(D_p)
                    ni = len(D_ii)
                    P_ni = kpt.P_ani[a]
                    for i1 in range(ni):
                        for i2 in range(ni):
                            A = 0.0
                            for i3 in range(ni):
                                p13 = packed_index(i1, i3, ni)
                                for i4 in range(ni):
                                    p24 = packed_index(i2, i4, ni)
                                    A += setup.M_pp[p13, p24] * D_ii[i3, i4]
                            self.exx_skn[kpt.s, kpt.k] -= \
                                (A * P_ni[:, i1].conj() * P_ni[:, i2]).real
                            p12 = packed_index(i1, i2, ni)
                            self.exx_skn[kpt.s, kpt.k] -= \
                                (P_ni[:, i1].conj() * setup.X_p[p12] *
                                 P_ni[:, i2]).real / self.wfs.nspins

        self.world.sum(self.exx_skn)
        self.exx_skn *= self.hybrid / Q
        self.timer.stop('PAW correction')

    def initialize_gaussian(self):
        """Calculate gaussian compensation charge and its potential.

        Used to decouple electrostatic interactions between
        periodically repeated images for molecular calculations.

        Charge containing one electron::

            (beta/pi)^(3/2)*exp(-beta*r^2),

        its Fourier transform::

            exp(-G^2/(4*beta)),

        and its potential::

            erf(beta^0.5*r)/r.
        """

        gd = self.wfs.gd

        # Set exponent of exp-function to -19 on the boundary:
        self.beta = 4 * 19 * (gd.icell_cv**2).sum(1).max()

        # Calculate gaussian:
        G_Gv = self.pd2.get_reciprocal_vectors()
        G2_G = self.pd2.G2_qG[0]
        C_v = gd.cell_cv.sum(0) / 2  # center of cell
        self.ngauss_G = np.exp(-1.0 / (4 * self.beta) * G2_G +
                               1j * np.dot(G_Gv, C_v)) / gd.dv

        # Calculate potential from gaussian:
        R_Rv = gd.get_grid_point_coordinates().transpose((1, 2, 3, 0))
        r_R = ((R_Rv - C_v)**2).sum(3)**0.5
        if (gd.N_c % 2 == 0).all():
            r_R[tuple(gd.N_c // 2)] = 1.0  # avoid dividing by zero
        v_R = erf(self.beta**0.5 * r_R) / r_R
        if (gd.N_c % 2 == 0).all():
            # r -> 0 limit of erf(sqrt(beta)*r)/r:
            v_R[tuple(gd.N_c // 2)] = (4 * self.beta / pi)**0.5
        self.vgauss_G = self.pd2.fft(v_R)

        # Compare self-interaction to analytic result:
        assert abs(0.5 * self.pd2.integrate(self.ngauss_G, self.vgauss_G) -
                   (self.beta / 2 / pi)**0.5) < 1e-6
Exemple #2
0
        for n in range(0, nb):
            n2_nmv[n] = pair.optical_pair_velocity(n, np.arange(0, nb),
                                                   kptpair.kpt1, kptpair.kpt2)

    # Check for nan's
    assert not np.isnan(n_nmG).any()
    assert not np.isnan(n_nmvG).any()
    if ol:
        assert not np.isnan(n2_nmv).any()

    # PAW correction test
    if ol:
        print('Checking PAW corrections')
        # Check that PAW-corrections are
        # are equal to nabla-PAW corrections
        G_Gv = pd.get_reciprocal_vectors()

        for id, atomdata in pair.calc.wfs.setups.setups.items():
            nabla_vii = atomdata.nabla_iiv.transpose((2, 0, 1))
            Q_vGii = two_phi_nabla_planewave_integrals(G_Gv, atomdata)
            ni = atomdata.ni
            Q_vGii.shape = (3, -1, ni, ni)
            equal(nabla_vii.astype(complex),
                  Q_vGii[:, 0],
                  tolerance=1e-10,
                  msg='Planewave-nabla PAW corrections not equal ' +
                  'to nabla PAW corrections when q + G = 0!')

        # Check optical limit nabla matrix elements
        err = np.abs(n_nmvG[..., 0] - n2_nmv)
        maxerr = np.max(err)
Exemple #3
0
class TDDFT(object):
    """
    Time-dependent DFT (optionally with Hartree-Fock exchange) in the basis
    of ground-state Kohn-Sham orbitals.

    The propagated state at IBZ k-point ``k`` is stored as a coefficient
    matrix ``wavefunction[k, m, n]``: the weight of ground-state Kohn-Sham
    orbital ``m`` in propagated state ``n``.

        calc: GPAW calculator (setups='sg15')
            Converged plane-wave ground-state calculation.
        nbands (int): number of bands in calculation
            Defaults to all bands of ``calc``.
        Fock (bool): precompute the pair densities used by get_Fock_matrix.
    """

    def __init__(self, calc, nbands=None, Fock=False):
        self.calc = calc
        self.Fock = Fock
        self.K = calc.get_ibz_k_points()  # k-points of the irreducible BZ
        self.NK = self.K.shape[0]

        self.wk = calc.get_k_point_weights()  # weights of the IBZ k-points
        if nbands is None:
            self.nbands = calc.get_number_of_bands()
        else:
            self.nbands = nbands
        # Number of doubly occupied bands (assumes a spin-paired calculation):
        self.nvalence = int(calc.get_number_of_electrons() / 2)

        # Band energies converted to Hartree, shape (NK, nbands):
        self.EK = np.array([
            calc.get_eigenvalues(k)[:self.nbands] for k in range(self.NK)
        ]) / Hartree
        self.shape = tuple(
            calc.get_number_of_grid_points())  # shape of real-space grid
        self.density = calc.get_pseudo_density(
        ) * Bohr**3  # density at zero time

        # Periodic parts u_nk of the Kohn-Sham orbitals (IBZ only),
        # shape (NK, nbands) + real-space grid shape:
        self.ukn = np.zeros((self.NK, self.nbands) + self.shape,
                            dtype=complex)
        for k in range(self.NK):
            kpt = calc.wfs.kpt_u[k]
            for n in range(self.nbands):
                self.ukn[k, n] = calc.wfs.pd.ifft(kpt.psit_nG[n], kpt.q)

        self.icell = 2.0 * np.pi * calc.wfs.gd.icell_cv  # reciprocal cell
        self.cell = calc.wfs.gd.cell_cv  # cell
        # Grid coordinates relative to the cell centre.  Only the diagonal of
        # the cell is used (NOTE(review): assumes an orthogonal cell).
        self.r = calc.wfs.gd.get_grid_point_coordinates()
        for i in range(3):
            self.r[i] -= self.cell[i, i] / 2.
        self.volume = np.abs(np.linalg.det(
            calc.wfs.gd.cell_cv))  # volume of cell
        self.norm = calc.wfs.gd.dv  # volume element of real-space integrals
        self.Fermi = calc.get_fermi_level() / Hartree  # Fermi level

        # Descriptors at q = Gamma for the Hartree term:
        self.kdH = KPointDescriptor([[0, 0, 0]])
        self.pdH = PWDescriptor(ecut=calc.wfs.pd.ecut,
                                gd=calc.wfs.gd,
                                kd=self.kdH,
                                dtype=complex)

        # Descriptors at q = Gamma for the Fock term (quarter of the cutoff):
        self.kdF = KPointDescriptor([[0, 0, 0]])
        self.pdF = PWDescriptor(ecut=calc.wfs.pd.ecut / 4.,
                                gd=calc.wfs.gd,
                                kd=self.kdF,
                                dtype=complex)

        # Fermi-Dirac smearing width:
        self.temperature = calc.occupations.width

        # Pair densities M[n1, n2, k1, k2](G) of u*_{k1 n1} u_{k2 n2}:
        if Fock:
            nG = self.pdF.get_reciprocal_vectors().shape[0]
            self.M = np.zeros(
                (self.nbands, self.nbands, self.NK, self.NK, nG),
                dtype=complex)
            indexes = list(product(range(self.nbands), range(self.NK)))
            # Every ordered pair is transformed explicitly: the swapped pair
            # obeys M[n2, n1, k2, k1](G) = conj(M[n1, n2, k1, k2](-G)), i.e.
            # it is NOT the conjugate at the same G, so it cannot be filled
            # in with a plain .conj() of the other triangle.
            for (n1, k1), (n2, k2) in product(indexes, repeat=2):
                self.M[n1, n2, k1, k2] = self.pdF.fft(
                    self.ukn[k1, n1].conj() * self.ukn[k2, n2])
            self.M *= calc.wfs.gd.dv

        # Fermi-Dirac occupation factors, shape (NK, nbands):
        self.f = 1 / (1 + np.exp((self.EK - self.Fermi) / self.temperature))

        # Per-band matrix elements used by the fast_* methods:
        # X_elements[k, n] holds the matrix (from operator_matrix_periodic)
        # of the potential generated by the density 2 |u_kn|^2.
        elements_shape = (self.NK, self.nbands, self.NK, self.nbands,
                          self.nbands)
        self.Hartree_elements = np.zeros(elements_shape, dtype=complex)
        self.LDAx_elements = np.zeros(elements_shape, dtype=complex)
        self.LDAc_elements = np.zeros(elements_shape, dtype=complex)
        G = self.pdH.get_reciprocal_vectors()
        G2 = np.linalg.norm(G, axis=1)**2
        G2[G2 == 0] = np.inf  # drop the divergent G = 0 term
        matrix = np.zeros((self.NK, self.nbands, self.nbands), dtype=complex)
        ukn_conj = self.ukn.conj()  # loop invariant, computed once
        # NOTE(review): these elements use +4*pi*n(G)/G^2 and the density
        # 2|u_kn|^2, whereas get_Hartree_potential uses -4*pi and
        # occupation() already contains a factor 2.  fast_Hartree_matrix
        # subtracts VH0 (built via get_Hartree_potential) from a sum over
        # these elements -- confirm the sign/normalisation conventions agree.
        for k in tqdm(range(self.NK)):
            for n in range(self.nbands):
                density = 2 * np.abs(self.ukn[k, n])**2
                operator = xc.VLDAx(density)
                self.LDAx_elements[k, n] = operator_matrix_periodic(
                    matrix, operator, ukn_conj, self.ukn) * self.norm
                operator = xc.VLDAc(density)
                self.LDAc_elements[k, n] = operator_matrix_periodic(
                    matrix, operator, ukn_conj, self.ukn) * self.norm

                density = self.pdH.fft(density)
                operator = 4 * np.pi * self.pdH.ifft(density / G2)
                self.Hartree_elements[k, n] = operator_matrix_periodic(
                    matrix, operator, ukn_conj, self.ukn) * self.norm

        self.wavefunction = np.zeros((self.NK, self.nbands, self.nbands),
                                     dtype=complex)
        self.Kinetic = np.zeros((self.NK, self.nbands, self.nbands),
                                dtype=complex)
        self.dipole = self.get_dipole_matrix()
        for k in range(self.NK):
            self.wavefunction[k] = np.eye(self.nbands)  # ground state
            self.Kinetic[k] = np.diag(self.EK[k])  # diagonal in the KS basis
        # Ground-state reference matrices subtracted by the fast_* methods:
        self.VH0 = self.get_Hartree_matrix(self.wavefunction)
        self.VLDAc0 = self.get_LDA_correlation_matrix(self.wavefunction)
        self.VLDAx0 = self.get_LDA_exchange_matrix(self.wavefunction)

        self.Full_BZ = calc.get_bz_k_points()
        self.IBZ_map = calc.get_bz_to_ibz_map()

    def get_dipole_matrix(self, direction=(1, 0, 0)):
        """
        return complex numpy array [N_kpoint X N_band X N_band] of matrix
        elements of the position operator projected on ``direction``
        (computed between the periodic orbitals by operator_matrix_periodic)

        direction: sequence of three numbers; normalised internally, the
            caller's object is not modified.
        """
        direction = np.asarray(direction, dtype=float)
        direction = direction / np.linalg.norm(direction)
        r = np.sum(direction[:, None, None, None] * self.r, axis=0)
        dipole = np.zeros((self.NK, self.nbands, self.nbands), dtype=complex)
        return operator_matrix_periodic(dipole, r, self.ukn.conj(),
                                        self.ukn) * self.norm

    def get_density(self, wavefunction):
        """
        return real numpy array (real-space grid shape) of the electron density
        wavefunction: numpy array [N_kpoint X N_band X N_band] of wavefunction
            in basis of Kohn-Sham orbital, or None for the ground-state density

        Cross terms between different Kohn-Sham orbitals are neglected:
        density = sum_{k,m} occupation[k, m] * |u_km|^2.
        """
        if wavefunction is None:
            return self.density
        # |c_kmn u_km|^2 = |c_kmn|^2 |u_km|^2, so the sum over n collapses
        # into the orbital occupations.
        return np.einsum('km,km...->...', self.occupation(wavefunction),
                         np.abs(self.ukn)**2)

    def get_Hartree_potential(self, wavefunction):
        """
        return numpy array of the Hartree potential on the real-space grid
        wavefunction: numpy array [N_kpoint X N_band X N_band] of wavefunction
            in basis of Kohn-Sham orbital, or None for the ground state

        The G = 0 component is dropped.  NOTE(review): the prefactor here is
        -4*pi while the per-band Hartree elements in __init__ use +4*pi.
        """
        density = self.get_density(wavefunction)
        G = self.pdH.get_reciprocal_vectors()
        G2 = np.linalg.norm(G, axis=1)**2
        G2[G2 == 0] = np.inf
        nG = self.pdH.fft(density)
        return -4 * np.pi * self.pdH.ifft(nG / G2)

    def get_coloumb_potential(self, q):
        """
        return coulomb potential in plane wave space V = 4 pi / |q+G|**2
        (the |q+G| = 0 component is set to zero)
        q: [qx,qy,qz] vector in units of reciprocal space
        """
        G = self.pdF.get_reciprocal_vectors() + np.dot(q, self.icell)
        G2 = np.linalg.norm(G, axis=1)**2
        G2[G2 == 0] = np.inf
        return 4 * np.pi / G2

    def get_Hartree_matrix(self, wavefunction=None):
        """
        return numpy array [N_kpoint X N_band X N_band] of Hartree potential matrix elements
        wavefunction: numpy array [N_kpoint X N_band X N_band] of wavefunction in basis of Kohn-Sham orbital
        """
        VH = self.get_Hartree_potential(wavefunction)
        VH_matrix = np.zeros((self.NK, self.nbands, self.nbands),
                             dtype=complex)
        return operator_matrix_periodic(VH_matrix, VH, self.ukn.conj(),
                                        self.ukn) * self.norm

    def get_Fock_matrix(self, wavefunction=None):
        """
        return numpy array [N_kpoint X N_band X N_band] of Fock potential
        matrix elements (all zeros when created with Fock=False)

        wavefunction: accepted for symmetry with the other get_*_matrix
            methods but currently unused -- the matrix is built from the
            ground-state pair densities self.M.
        """
        VF_matrix = np.zeros((self.NK, self.nbands, self.nbands),
                             dtype=complex)
        if self.Fock:
            K = self.Full_BZ
            NK = K.shape[0]
            NG = self.pdF.get_reciprocal_vectors().shape[0]
            # Coulomb kernel for every (IBZ k, full-BZ k') combination:
            V = np.zeros((self.NK, NK, NG))
            for k in range(self.NK):
                for q in range(NK):
                    V[k, q] = self.get_coloumb_potential(K[q] - self.K[k])

            VF_matrix = Fock_matrix(VF_matrix, V, self.M.conj(), self.M,
                                    self.IBZ_map, self.nvalence)
        return VF_matrix / self.volume

    def get_LDA_exchange_matrix(self, wavefunction=None):
        """
        return numpy array [N_kpoint X N_band X N_band] of LDA exchange potential matrix elements
        wavefunction: numpy array [N_kpoint X N_band X N_band] of wavefunction in basis of Kohn-Sham orbital
        """
        density = self.get_density(wavefunction)
        exchange = xc.VLDAx(density)
        LDAx_matrix = np.zeros((self.NK, self.nbands, self.nbands),
                               dtype=complex)
        return operator_matrix_periodic(
            LDAx_matrix, exchange, self.ukn.conj(), self.ukn) * self.norm

    def get_LDA_correlation_matrix(self, wavefunction=None):
        """
        return numpy array [N_kpoint X N_band X N_band] of LDA correlation potential matrix elements
        wavefunction: numpy array [N_kpoint X N_band X N_band] of wavefunction in basis of Kohn-Sham orbital
        """
        density = self.get_density(wavefunction)
        correlation = xc.VLDAc(density)
        LDAc_matrix = np.zeros((self.NK, self.nbands, self.nbands),
                               dtype=complex)
        return operator_matrix_periodic(
            LDAc_matrix, correlation, self.ukn.conj(), self.ukn) * self.norm

    def occupation(self, wavefunction):
        """
        return array [N_kpoint X N_band] of ground-state orbital occupations
        occ[k, m] = 2 * wk[k] * sum_n f[k, n] * |wavefunction[k, m, n]|**2
        """
        return 2 * np.sum(self.wk[:, None, None] * self.f[:, None, :] *
                          np.abs(wavefunction)**2,
                          axis=2)

    def fast_Hartree_matrix(self, wavefunction):
        """Hartree matrix change relative to the ground state (VH0)."""
        return np.einsum('kn,knqij->qij', self.occupation(wavefunction),
                         self.Hartree_elements) - self.VH0

    def fast_LDA_correlation_matrix(self, wavefunction):
        """LDA correlation matrix change relative to the ground state."""
        return np.einsum('kn,knqij->qij', self.occupation(wavefunction),
                         self.LDAc_elements) - self.VLDAc0

    def fast_LDA_exchange_matrix(self, wavefunction):
        """LDA exchange matrix change relative to the ground state."""
        return np.einsum('kn,knqij->qij', self.occupation(wavefunction),
                         self.LDAx_elements) - self.VLDAx0

    def propagate(self, dt, steps, E, operator, corrections=10):
        """
        Propagate the Kohn-Sham coefficients with a self-consistent
        Crank-Nicolson scheme, starting from the ground state.

        dt: time step
        steps: number of time steps
        E: sequence of at least ``steps`` field amplitudes; E[t] multiplies
            the dipole matrix
        operator: per-k-point (N_band X N_band) matrices; only their
            diagonals are used for the recorded observable
        corrections: maximum number of predictor-corrector iterations per
            time step (at least one is always done).  A step stops earlier
            once the largest change of any coefficient is <= 1e-8; a warning
            reports how many steps hit the limit.

        Sets self.wavefunction (final state) and self.macro_dipole[t] =
        sum(occupation(wavefunction) * diag(operator)).
        """
        import warnings

        self.wavefunction = np.zeros((self.NK, self.nbands, self.nbands),
                                     dtype=complex)
        for k in range(self.NK):
            self.wavefunction[k] = np.eye(self.nbands)
        H = np.copy(self.Kinetic) + E[0] * self.dipole
        operator_macro = np.array(
            [operator[k].diagonal() for k in range(self.NK)])
        self.macro_dipole = np.zeros(steps, dtype=complex)
        identity = np.eye(self.nbands)
        max_iterations = max(1, corrections)
        unconverged = 0
        for t in tqdm(range(steps)):
            wavefunction_next = np.copy(self.wavefunction)
            for _ in range(max_iterations):
                wavefunction_check = np.copy(wavefunction_next)
                H_next = self.Kinetic + E[t] * self.dipole
                H_next += self.fast_Hartree_matrix(wavefunction_next)
                H_next += self.fast_LDA_correlation_matrix(wavefunction_next)
                H_next += self.fast_LDA_exchange_matrix(wavefunction_next)

                # (1 + i dt/2 H_mid) c(t + dt) = (1 - i dt/2 H_mid) c(t)
                H_mid = 0.5 * (H + H_next)
                for k in range(self.NK):
                    H_left = identity + 0.5j * dt * H_mid[k]
                    H_right = identity - 0.5j * dt * H_mid[k]
                    wavefunction_next[k] = linalg.solve(
                        H_left, H_right @ self.wavefunction[k])
                # Largest change of any coefficient (a signed sum of the
                # differences could cancel and stop the iteration too early):
                error = np.max(np.abs(wavefunction_next - wavefunction_check),
                               initial=0.0)
                if error <= 1e-8:
                    break
            else:
                unconverged += 1
            self.wavefunction = np.copy(wavefunction_next)
            H = np.copy(H_next)
            self.macro_dipole[t] = np.sum(
                self.occupation(self.wavefunction) * operator_macro)
        if unconverged:
            warnings.warn('propagate: %d of %d time steps did not converge '
                          'within %d corrections' %
                          (unconverged, steps, max_iterations))
class HybridXC(HybridXCBase):
    orbital_dependent = True

    def __init__(self,
                 name,
                 hybrid=None,
                 xc=None,
                 alpha=None,
                 gamma_point=1,
                 method='standard',
                 bandstructure=False,
                 logfilename='-',
                 bands=None,
                 fcut=1e-10,
                 molecule=False,
                 qstride=1,
                 world=None):
        """Mix standard functionals with exact exchange.

        name: str
            Name of functional: EXX, PBE0, HSE03, HSE06
        hybrid: float
            Fraction of exact exchange.
        xc: str or XCFunctional object
            Standard DFT functional with scaled down exchange.
        alpha: float
            Parameter of the alpha method for the q = 0 term.  When None,
            calculate_exx picks a value from the cell volume.
        gamma_point: int
            0: Skip k2-k1=0 interactions.
            1: Use the alpha method.
            2: Integrate the gamma point.
        method: str
            'standard' for the standard formula or 'acdf' for the
            adiabatic-connection dissipation-fluctuation formula.
        bandstructure: bool
            Calculate bandstructure instead of just the total energy.
        logfilename: str or file
            Log destination; turned into a file object in initialize().
        bands: list of int
            List of bands to calculate bandstructure for.  Default is
            all bands.
        fcut: float
            Threshold for empty band.
        molecule: bool
            Decouple electrostatic interactions between periodically
            repeated images.
        qstride: int or sequence of three ints
            Stride of the q-vector grid; an int applies to all directions.
        world: MPI communicator
            Defaults to mpi.world.
        """
        # Options stored before calling the base-class constructor:
        self.alpha, self.fcut = alpha, fcut
        self.gamma_point = gamma_point
        self.method = method
        self.bandstructure, self.bands = bandstructure, bands
        self.fd = logfilename
        self.write_timing_information = True

        HybridXCBase.__init__(self, name, hybrid, xc)

        # EXX energy contributions (None until computed):
        self.exx = None  # total
        self.evv = None  # valence-valence (pseudo part)
        self.evvacdf = None  # valence-valence (pseudo part)
        self.devv = None  # valence-valence (PAW correction)
        self.evc = None  # valence-core
        self.ecc = None  # core-core
        self.exx_skn = None  # bandstructure
        # q index of the compensation-charge expansion cached in self.f_IG:
        self.qlatest = None

        self.world = mpi.world if world is None else world
        self.molecule = molecule

        stride = [qstride] * 3 if isinstance(qstride, int) else qstride
        self.qstride_c = np.asarray(stride)

        self.timer = Timer()

    def log(self, *args, **kwargs):
        """Write ``args`` to the log file and flush it straight away."""
        prnt(*args, file=self.fd, **kwargs)
        self.fd.flush()

    def calculate_radial(self,
                         rgd,
                         n_sLg,
                         Y_L,
                         v_sg,
                         dndr_sLg=None,
                         rnablaY_Lv=None,
                         tau_sg=None,
                         dedtau_sg=None):
        """Delegate the radial calculation to the wrapped functional.

        tau_sg and dedtau_sg are accepted for interface compatibility but
        are not forwarded.
        """
        semilocal = self.xc
        return semilocal.calculate_radial(rgd, n_sLg, Y_L, v_sg,
                                          dndr_sLg, rnablaY_Lv)

    def calculate_paw_correction(self,
                                 setup,
                                 D_sp,
                                 dEdD_sp=None,
                                 addcoredensity=True,
                                 a=None):
        """Return the PAW correction of the wrapped semi-local functional."""
        semilocal = self.xc
        return semilocal.calculate_paw_correction(setup, D_sp, dEdD_sp,
                                                  addcoredensity, a)

    def initialize(self, dens, ham, wfs, occupations):
        """Keep references to density and wave functions and open the log.

        Band parallelization is not supported (asserted below).
        """
        assert wfs.bd.comm.size == 1

        self.xc.initialize(dens, ham, wfs, occupations)
        self.dens, self.wfs = dens, wfs

        # Serial copy of the k-point descriptor (self.kd.comm is
        # serial_comm), so every k-point is known locally:
        self.kd = wfs.kd.copy()

        self.fd = logfile(self.fd, self.world.rank)

        wfs.initialize_wave_functions_from_restart_file()

    def set_positions(self, spos_ac):
        """Remember the scaled atomic positions (used for ghat and symmetry)."""
        self.spos_ac = spos_ac

    def calculate(self, gd, n_sg, v_sg=None, e_g=None):
        """Return the semi-local XC energy plus the scaled EXX energy.

        self.exx must already have been set (by calculate_exx).
        """
        semilocal_energy = self.xc.calculate(gd, n_sg, v_sg, e_g)
        return semilocal_energy + self.exx * self.hybrid

    def calculate_exx(self):
        """Non-selfconsistent calculation.

        Accumulates the pseudo valence-valence energies self.evv and
        self.evvacdf (summed over self.world), calls
        calculate_exx_paw_correction() and combines the parts into self.exx
        according to self.method.  With self.bandstructure, eigenvalue
        shifts are accumulated in self.exx_skn.

        Raises ValueError if self.method is neither 'standard' nor 'acdf'.
        The check is done on entry so an invalid method fails before the
        expensive pair loop instead of after it.
        """
        if self.method not in ('standard', 'acdf'):
            raise ValueError("Unknown method %r: use 'standard' or 'acdf'" %
                             (self.method,))

        self.timer.start('EXX')
        self.timer.start('Initialization')

        kd = self.kd
        wfs = self.wfs

        if fftw.FFTPlan is fftw.NumpyFFTPlan:
            self.log('NOT USING FFTW !!')

        self.log('Spins:', self.wfs.nspins)

        W = max(1, self.wfs.kd.comm.size // self.wfs.nspins)
        # Are the k-points distributed?
        kparallel = (W > 1)

        # Find number of occupied bands (index of the first band with
        # |f| < fcut, or nbands if none):
        self.nocc_sk = np.zeros((self.wfs.nspins, kd.nibzkpts), int)
        for kpt in self.wfs.kpt_u:
            for n, f in enumerate(kpt.f_n):
                if abs(f) < self.fcut:
                    self.nocc_sk[kpt.s, kpt.k] = n
                    break
            else:
                self.nocc_sk[kpt.s, kpt.k] = self.wfs.bd.nbands
        self.wfs.kd.comm.sum(self.nocc_sk)

        noccmin = self.nocc_sk.min()
        noccmax = self.nocc_sk.max()
        self.log('Number of occupied bands (min, max): %d, %d' %
                 (noccmin, noccmax))

        self.log('Number of valence electrons:', self.wfs.setups.nvalence)

        if self.bandstructure:
            self.log('Calculating eigenvalue shifts.')

            # allocate array for eigenvalue shifts:
            self.exx_skn = np.zeros(
                (self.wfs.nspins, kd.nibzkpts, self.wfs.bd.nbands))

            if self.bands is None:
                noccmax = self.wfs.bd.nbands
            else:
                noccmax = max(max(self.bands) + 1, noccmax)

        N_c = self.kd.N_c

        vol = wfs.gd.dv * wfs.gd.N_c.prod()
        if self.alpha is None:
            alpha = 6 * vol**(2 / 3.0) / pi**2
        else:
            alpha = self.alpha
        if self.gamma_point == 1:
            if alpha == 0.0:
                qvol = (2 * np.pi)**3 / vol / N_c.prod()
                self.gamma = 4 * np.pi * (3 * qvol /
                                          (4 * np.pi))**(1 / 3.) / qvol
            else:
                self.gamma = self.calculate_gamma(vol, alpha)
        else:
            kcell_cv = wfs.gd.cell_cv.copy()
            kcell_cv[0] *= N_c[0]
            kcell_cv[1] *= N_c[1]
            kcell_cv[2] *= N_c[2]
            self.gamma = madelung(kcell_cv) * vol * N_c.prod() / (4 * np.pi)

        self.log('Value of alpha parameter: %.3f Bohr^2' % alpha)
        self.log('Value of gamma parameter: %.3f Bohr^2' % self.gamma)

        # Construct all possible q=k2-k1 vectors:
        Nq_c = (N_c - 1) // self.qstride_c
        i_qc = np.indices(Nq_c * 2 + 1, float).transpose((1, 2, 3, 0)).reshape(
            (-1, 3))
        self.bzq_qc = (i_qc - Nq_c) / N_c * self.qstride_c
        self.q0 = ((Nq_c * 2 + 1).prod() - 1) // 2  # index of q=(0,0,0)
        assert not self.bzq_qc[self.q0].any()

        # Count number of pairs for each q-vector:
        self.npairs_q = np.zeros(len(self.bzq_qc), int)
        for s in range(kd.nspins):
            for k1 in range(kd.nibzkpts):
                for k2 in range(kd.nibzkpts):
                    for K2, q, n1_n, n2 in self.indices(s, k1, k2):
                        self.npairs_q[q] += len(n1_n)

        self.npairs0 = self.npairs_q.sum()  # total number of pairs

        self.log('Number of pairs:', self.npairs0)

        # Distribute q-vectors to Q processors (balanced by pair count):
        Q = self.world.size // self.wfs.kd.comm.size
        myrank = self.world.rank // self.wfs.kd.comm.size
        rank = 0
        N = 0
        myq = []
        nq = 0
        for q, n in enumerate(self.npairs_q):
            if n > 0:
                nq += 1
                if rank == myrank:
                    myq.append(q)
            N += n
            if N >= (rank + 1.0) * self.npairs0 / Q:
                rank += 1

        assert len(myq) > 0, 'Too few q-vectors for too many processes!'
        self.bzq_qc = self.bzq_qc[myq]
        try:
            self.q0 = myq.index(self.q0)
        except ValueError:
            self.q0 = None

        self.log('%d x %d x %d k-points' % tuple(self.kd.N_c))
        self.log('Distributing %d IBZ k-points over %d process(es).' %
                 (kd.nibzkpts, self.wfs.kd.comm.size))
        self.log('Distributing %d q-vectors over %d process(es).' % (nq, Q))

        # q-point descriptor for my q-vectors:
        qd = KPointDescriptor(self.bzq_qc)

        # Plane-wave descriptor for all wave-functions:
        self.pd = PWDescriptor(wfs.pd.ecut, wfs.gd, dtype=wfs.pd.dtype, kd=kd)

        # Plane-wave descriptor pair-densities:
        self.pd2 = PWDescriptor(self.dens.pd2.ecut,
                                self.dens.gd,
                                dtype=wfs.dtype,
                                kd=qd)

        self.log('Cutoff energies:')
        self.log('    Wave functions:       %10.3f eV' %
                 (self.pd.ecut * Hartree))
        self.log('    Density:              %10.3f eV' %
                 (self.pd2.ecut * Hartree))

        # Calculate 1/|G+q|^2 with special treatment of |G+q|=0:
        G2_qG = self.pd2.G2_qG
        if self.q0 is None:
            if self.omega is None:
                self.iG2_qG = [1.0 / G2_G for G2_G in G2_qG]
            else:
                self.iG2_qG = [
                    (1.0 / G2_G * (1 - np.exp(-G2_G / (4 * self.omega**2))))
                    for G2_G in G2_qG
                ]
        else:
            G2_qG[self.q0][0] = 117.0  # avoid division by zero
            if self.omega is None:
                self.iG2_qG = [1.0 / G2_G for G2_G in G2_qG]
                self.iG2_qG[self.q0][0] = self.gamma
            else:
                self.iG2_qG = [
                    (1.0 / G2_G * (1 - np.exp(-G2_G / (4 * self.omega**2))))
                    for G2_G in G2_qG
                ]
                self.iG2_qG[self.q0][0] = 1 / (4 * self.omega**2)
            G2_qG[self.q0][0] = 0.0  # restore correct value

        # Compensation charges:
        self.ghat = PWLFC([setup.ghat_l for setup in wfs.setups], self.pd2)
        self.ghat.set_positions(self.spos_ac)

        if self.molecule:
            self.initialize_gaussian()
            self.log('Value of beta parameter: %.3f 1/Bohr^2' % self.beta)

        self.timer.stop('Initialization')

        # Ready ... set ... go:
        self.t0 = time()
        self.npairs = 0
        self.evv = 0.0
        self.evvacdf = 0.0
        for s in range(self.wfs.nspins):
            kpt1_q = [
                KPoint(self.wfs, noccmax).initialize(kpt)
                for kpt in self.wfs.kpt_u if kpt.s == s
            ]
            kpt2_q = kpt1_q[:]

            if len(kpt1_q) == 0:
                # No s-spins on this CPU:
                continue

            # Send and receive ranks:
            srank = self.wfs.kd.get_rank_and_index(s, (kpt1_q[0].k - 1) %
                                                   kd.nibzkpts)[0]
            rrank = self.wfs.kd.get_rank_and_index(s, (kpt1_q[-1].k + 1) %
                                                   kd.nibzkpts)[0]

            # Shift k-points kd.nibzkpts - 1 times:
            for i in range(kd.nibzkpts):
                if i < kd.nibzkpts - 1:
                    if kparallel:
                        kpt = kpt2_q[-1].next(self.wfs)
                        kpt.start_receiving(rrank)
                        kpt2_q[0].start_sending(srank)
                    else:
                        kpt = kpt2_q[0]

                self.timer.start('Calculate')
                for kpt1, kpt2 in zip(kpt1_q, kpt2_q):
                    # Loop over all k-points that k2 can be mapped to:
                    for K2, q, n1_n, n2 in self.indices(s, kpt1.k, kpt2.k):
                        self.apply(K2, q, kpt1, kpt2, n1_n, n2)
                self.timer.stop('Calculate')

                if i < kd.nibzkpts - 1:
                    self.timer.start('Wait')
                    if kparallel:
                        kpt.wait()
                        kpt2_q[0].wait()
                    self.timer.stop('Wait')
                    kpt2_q.pop(0)
                    kpt2_q.append(kpt)

        self.evv = self.world.sum(self.evv)
        self.evvacdf = self.world.sum(self.evvacdf)
        self.calculate_exx_paw_correction()

        # Combine the contributions (self.method was validated on entry):
        if self.method == 'standard':
            evv_used = self.evv
        else:  # 'acdf'
            evv_used = self.evvacdf
        self.exx = evv_used + self.devv + self.evc + self.ecc

        self.log('Exact exchange energy:')
        for txt, e in [('core-core', self.ecc), ('valence-core', self.evc),
                       ('valence-valence (pseudo, acdf)', self.evvacdf),
                       ('valence-valence (pseudo, standard)', self.evv),
                       ('valence-valence (correction)', self.devv),
                       ('total (%s)' % self.method, self.exx)]:
            self.log('    %-36s %14.6f eV' % (txt + ':', e * Hartree))

        self.log('Total time: %10.3f seconds' % (time() - self.t0))

        self.npairs = self.world.sum(self.npairs)
        assert self.npairs == self.npairs0

        self.timer.stop('EXX')
        self.timer.write(self.fd)

    def calculate_gamma(self, vol, alpha):
        """Return the q -> 0 correction used by the alpha method.

        Molecules get 0.0.  Otherwise the value is
        vol / (2 pi)^2 * sqrt(pi / alpha) * nbzkpts minus the sum of
        exp(-alpha |q+G|^2) / |q+G|^2 over a q-grid and plane waves
        (leaving out a leading |q+G|^2 < 1e-7 term), divided by the product
        of qstride_c.
        """
        if self.molecule:
            return 0.0

        N_c = self.kd.N_c
        # Even grid sizes get a half-spacing shift (presumably so that the
        # Monkhorst-Pack grid becomes Gamma-centred -- TODO confirm).
        q_grid_qc = monkhorst_pack(N_c) + (N_c + 1) % 2 * 0.5 / N_c
        pd = PWDescriptor(self.wfs.pd.ecut, self.wfs.gd,
                          kd=KPointDescriptor(q_grid_qc))

        gamma = vol / (2 * pi)**2 * sqrt(pi / alpha) * self.kd.nbzkpts
        for G2_G in pd.G2_qG:
            if G2_G[0] < 1e-7:
                G2_G = G2_G[1:]  # skip the singular q + G = 0 term
            gamma -= np.dot(np.exp(-alpha * G2_G), G2_G**-1)
        return gamma / self.qstride_c.prod()

    def indices(self, s, k1, k2):
        """Generator for (K2, q, n1_n, n2) indices for the (k1, k2) pair.

        s: int
            Spin index.
        k1: int
            Index of k-point in the IBZ.
        k2: int
            Index of k-point in the IBZ.

        Yields (K, q, n1_n, n2): K is the index of a BZ k-point that k2 is
        mapped to, q the index of the q-vector between K and k1, and n1_n
        the bands that should be combined with band n2."""

        for K2, ibz_index in enumerate(self.kd.bz2ibz_k):
            if ibz_index != k2:
                continue
            yield from self._indices(s, k1, k2, K2)

    def _indices(self, s, k1, k2, K2):
        """Yield (K2, q, n1_n, n2) for IBZ k-point k1 and BZ k-point K2.

        K2 is a BZ k-point that maps onto IBZ k-point k2.  Nothing is
        yielded when q = K2 - k1 is not found in self.bzq_qc, or when it is
        the q = 0 vector and gamma_point == 0.  For every band n2 that is
        kept, n1_n is a non-empty array of partner bands at k1.
        """
        k1_c = self.kd.ibzk_kc[k1]
        k2_c = self.kd.bzk_kc[K2]
        q_c = k2_c - k1_c
        # Nearest q-vector (L1 distance) in the current list of q-vectors:
        q = abs(self.bzq_qc - q_c).sum(1).argmin()
        if abs(self.bzq_qc[q] - q_c).sum() > 1e-7:
            # q_c is not in self.bzq_qc (after the distribution in
            # calculate_exx this list only holds this process's q-vectors).
            return

        if self.gamma_point == 0 and q == self.q0:
            # gamma_point == 0: skip k2 - k1 = 0 interactions.
            return

        nocc1 = self.nocc_sk[s, k1]
        nocc2 = self.nocc_sk[s, k2]

        # Is k2 in the IBZ?
        is_ibz2 = (self.kd.ibz2bz_k[k2] == K2)

        for n2 in range(self.wfs.bd.nbands):
            # Find range of n1's (from n1a to n1b-1):
            if is_ibz2:
                # We get this combination twice, so let's only do half:
                if k1 >= k2:
                    n1a = n2
                else:
                    n1a = n2 + 1
            else:
                n1a = 0

            n1b = self.wfs.bd.nbands

            if self.bandstructure:
                # Unoccupied n2 bands are only paired with occupied n1 bands.
                if n2 >= nocc2:
                    n1b = min(n1b, nocc1)
            else:
                # Total energy only: stop at the first unoccupied n2 band
                # and restrict n1 to occupied bands.
                if n2 >= nocc2:
                    break
                n1b = min(n1b, nocc1)

            if self.bands is not None:
                assert self.bandstructure
                # Keep n1 if it is a requested band (with n2 occupied) or,
                # when k2 is in the IBZ, if n2 is a requested band (with n1
                # occupied).
                n1_n = []
                for n1 in range(n1a, n1b):
                    if (n1 in self.bands and n2 < nocc2
                            or is_ibz2 and n2 in self.bands and n1 < nocc1):
                        n1_n.append(n1)
                n1_n = np.array(n1_n)
            else:
                n1_n = np.arange(n1a, n1b)

            if len(n1_n) == 0:
                continue

            yield K2, q, n1_n, n2

    def apply(self, K2, q, kpt1, kpt2, n1_n, n2):
        """Add the contributions of the pairs (n1_n at k1, n2 at K2).

        Updates self.evv and self.evvacdf (weighted by the IBZ weight of k1,
        and additionally of k2 when K2 is k2's own IBZ representative).
        With self.bandstructure, it also updates self.exx_skn.
        """
        kd = self.kd
        k2_c = kd.bzk_kc[K2]

        if not k2_c.any():
            eik2r_R = eik20r_R = 1.0
        else:
            self.timer.start('Initialize plane waves')
            eik2r_R = self.wfs.gd.plane_wave(k2_c)
            eik20r_R = self.wfs.gd.plane_wave(kd.ibzk_kc[kpt2.k])
            self.timer.stop('Initialize plane waves')

        w1 = kd.weight_k[kpt1.k]
        w2 = kd.weight_k[kpt2.k]

        # Is K2 the IBZ representative of k2 itself?
        is_ibz2 = (kd.ibz2bz_k[kpt2.k] == K2)

        e_n = self.calculate_interaction(n1_n, n2, kpt1, kpt2, q, K2,
                                         eik20r_R, eik2r_R, is_ibz2)
        e_n *= 1.0 / kd.nbzkpts / self.wfs.nspins * self.qstride_c.prod()

        if q == self.q0:
            # Diagonal (n1 == n2) terms at q = 0 count only half.
            e_n[n1_n == n2] *= 0.5

        f1_n = kpt1.f_n[n1_n]
        f2 = kpt2.f_n[n2]
        s_n = np.sign(kpt2.eps_n[n2] - kpt1.eps_n[n1_n])

        evv = (f1_n * f2 * e_n).sum()
        evvacdf = 0.5 * (f1_n * (1 - s_n) * e_n + f2 * (1 + s_n) * e_n).sum()
        for weight in ([w1, w2] if is_ibz2 else [w1]):
            self.evv += evv * weight
            self.evvacdf += evvacdf * weight

        if not self.bandstructure:
            return
        nspins = self.wfs.nspins
        self.exx_skn[kpt1.s, kpt1.k, n1_n] += nspins * f2 * e_n
        if is_ibz2:
            self.exx_skn[kpt2.s, kpt2.k, n2] += nspins * np.dot(f1_n, e_n)

    def calculate_interaction(self, n1_n, n2, kpt1, kpt2, q, k, eik20r_R,
                              eik2r_R, is_ibz2):
        """Calculate Coulomb interactions.

        For all n1 in the n1_n list, calculate interaction with n2.

        n1_n: array of int
            Band indices at kpt1 (compared elementwise with n2 below).
        n2: int
            Band index at kpt2.
        kpt1, kpt2: k-point objects
            Provide psit_nG, P_ani and the k index used for the FFTs.
        q: int
            Index passed to pd2, ghat and iG2_qG for the pair densities.
        k: int
            Index into kd.sym_k / kd.time_reversal_k selecting the symmetry
            operation applied to kpt2 when is_ibz2 is False.
        eik20r_R, eik2r_R:
            Real-space phase factors; the kpt2 wave function is multiplied
            by eik20r_R before the symmetry transform and divided by
            eik2r_R after it.
        is_ibz2: bool
            True when kpt2 needs no symmetry transformation.

        Returns e_n: a real array with one interaction energy per entry of
        n1_n.  Side effects: increments self.npairs, caches the expanded
        compensation-charge functions in self.f_IG / self.qlatest, and logs
        a timing estimate once (while self.write_timing_information is set).
        """

        # number of plane waves:
        ng1 = self.wfs.ng_k[kpt1.k]
        ng2 = self.wfs.ng_k[kpt2.k]

        # Transform to real space and apply symmetry operation:
        self.timer.start('IFFT1')
        if is_ibz2:
            u2_R = self.pd.ifft(kpt2.psit_nG[n2, :ng2], kpt2.k)
        else:
            psit2_R = self.pd.ifft(kpt2.psit_nG[n2, :ng2], kpt2.k) * eik20r_R
            self.timer.start('Symmetry transform')
            u2_R = self.kd.transform_wave_function(psit2_R, k) / eik2r_R
            self.timer.stop()
        self.timer.stop()

        # Calculate pair densities:
        # conj(u1(r)) * u2(r) for every n1, Fourier transformed with pd2 at q.
        nt_nG = self.pd2.zeros(len(n1_n), q=q)
        for n1, nt_G in zip(n1_n, nt_nG):
            self.timer.start('IFFT2')
            u1_R = self.pd.ifft(kpt1.psit_nG[n1, :ng1], kpt1.k)
            self.timer.stop()
            nt_R = u1_R.conj() * u2_R
            self.timer.start('FFT')
            nt_G[:] = self.pd2.fft(nt_R, q)
            self.timer.stop()

        s = self.kd.sym_k[k]
        time_reversal = self.kd.time_reversal_k[k]
        k2_c = self.kd.ibzk_kc[kpt2.k]

        self.timer.start('Compensation charges')
        Q_anL = {}  # coefficients for shape functions
        for a, P1_ni in kpt1.P_ani.items():
            P1_ni = P1_ni[n1_n]

            if is_ibz2:
                P2_i = kpt2.P_ani[a][n2]
            else:
                # Rotate the kpt2 projections: atom a is mapped to atom b by
                # symmetry s; S_c is the (integer) lattice translation
                # between them, giving a Bloch phase for complex ghat.
                b = self.kd.symmetry.a_sa[s, a]
                S_c = (np.dot(self.spos_ac[a], self.kd.symmetry.op_scc[s]) -
                       self.spos_ac[b])
                assert abs(S_c.round() - S_c).max() < 1e-5
                if self.ghat.dtype == complex:
                    x = np.exp(2j * pi * np.dot(k2_c, S_c))
                else:
                    x = 1.0
                P2_i = np.dot(self.wfs.setups[a].R_sii[s],
                              kpt2.P_ani[b][n2]) * x
                if time_reversal:
                    P2_i = P2_i.conj()

            # Packed outer products conj(P1_i) P2_i contracted with Delta_pL.
            D_np = []
            for P1_i in P1_ni:
                D_ii = np.outer(P1_i.conj(), P2_i)
                D_np.append(pack(D_ii))
            Q_anL[a] = np.dot(D_np, self.wfs.setups[a].Delta_pL)

        # Re-expand the compensation charges only when q changes.
        self.timer.start('Expand')
        if q != self.qlatest:
            self.f_IG = self.ghat.expand(q)
            self.qlatest = q
        self.timer.stop('Expand')

        # Add compensation charges:
        self.ghat.add(nt_nG, Q_anL, q, self.f_IG)
        self.timer.stop('Compensation charges')

        # For molecules, remove the precomputed Gaussian (ngauss_G) from the
        # diagonal n1 == n2 pair density; its energy is corrected below.
        if self.molecule and n2 in n1_n:
            nn = (n1_n == n2).nonzero()[0][0]
            nt_nG[nn] -= self.ngauss_G
        else:
            nn = None

        iG2_G = self.iG2_qG[q]

        # Calculate energies:
        # e_n = -4 pi Re <n|iG2|n>; iG2_qG presumably holds 1/|G+q|^2 with
        # some G=0 treatment -- set up elsewhere, TODO confirm.
        e_n = np.empty(len(n1_n))
        for n, nt_G in enumerate(nt_nG):
            e_n[n] = -4 * pi * np.real(self.pd2.integrate(nt_G, nt_G * iG2_G))
            self.npairs += 1

        # Gaussian correction: interaction of the remaining density with the
        # Gaussian's potential plus the term (beta / 2 pi)^0.5.
        if nn is not None:
            e_n[nn] -= 2 * (self.pd2.integrate(nt_nG[nn], self.vgauss_G) +
                            (self.beta / 2 / pi)**0.5)

        if self.write_timing_information:
            t = (time() - self.t0) / len(n1_n)
            self.log('Time for first pair-density: %10.3f seconds' % t)
            self.log('Estimated total time:        %10.3f seconds' %
                     (t * self.npairs0 / self.world.size))
            self.write_timing_information = False

        return e_n

    def calculate_exx_paw_correction(self):
        """Accumulate the atom-centred (PAW) exchange corrections.

        Resets and fills ``self.devv`` (valence-valence), ``self.evc``
        (core-valence) and ``self.ecc`` (core-core).  If
        ``self.bandstructure`` is set, the per-state corrections are also
        subtracted from ``self.exx_skn``, which is then summed over
        ``self.world`` and scaled by ``self.hybrid``.
        """
        self.timer.start('PAW correction')
        self.devv = 0.0
        self.evc = 0.0
        self.ecc = 0.0

        def contract_with_density(setup, D_ii, ni, i1, i2):
            # sum over i3, i4 of M_pp[(i1, i3), (i2, i4)] * D_ii[i3, i4]
            total = 0.0
            for i3 in range(ni):
                row = packed_index(i1, i3, ni)
                for i4 in range(ni):
                    col = packed_index(i2, i4, ni)
                    total += setup.M_pp[row, col] * D_ii[i3, i4]
            return total

        spin_degeneracy = 2 // self.wfs.nspins
        for atom, density_sp in self.dens.D_asp.items():
            setup = self.wfs.setups[atom]
            for density_p in density_sp:
                D_ii = unpack2(density_p)
                ni = len(D_ii)
                for i1 in range(ni):
                    for i2 in range(ni):
                        A = contract_with_density(setup, D_ii, ni, i1, i2)
                        self.devv -= D_ii[i1, i2] * A / spin_degeneracy
                self.evc -= np.dot(density_p, setup.X_p)
            self.ecc += setup.ExxC

        if self.bandstructure:
            # exx_skn is multiplied by Q here and divided by Q after the
            # world sum.
            Q = self.world.size // self.wfs.kd.comm.size
            self.exx_skn *= Q
            for kpt in self.wfs.kpt_u:
                for atom, density_sp in self.dens.D_asp.items():
                    setup = self.wfs.setups[atom]
                    for density_p in density_sp:
                        D_ii = unpack2(density_p)
                        ni = len(D_ii)
                        P_ni = kpt.P_ani[atom]
                        for i1 in range(ni):
                            for i2 in range(ni):
                                A = contract_with_density(setup, D_ii, ni,
                                                          i1, i2)
                                self.exx_skn[kpt.s, kpt.k] -= \
                                    (A * P_ni[:, i1].conj() * P_ni[:, i2]).real
                                p12 = packed_index(i1, i2, ni)
                                self.exx_skn[kpt.s, kpt.k] -= \
                                    (P_ni[:, i1].conj() * setup.X_p[p12] *
                                     P_ni[:, i2]).real / self.wfs.nspins

            self.world.sum(self.exx_skn)
            self.exx_skn *= self.hybrid / Q

        self.timer.stop('PAW correction')

    def initialize_gaussian(self):
        """Set up a unit Gaussian charge and its electrostatic potential.

        Used to decouple electrostatic interactions between periodically
        repeated images in molecular calculations.

        Charge containing one electron::

            (beta/pi)^(3/2)*exp(-beta*r^2),

        its Fourier transform::

            exp(-G^2/(4*beta)),

        and its potential::

            erf(beta^0.5*r)/r.

        Sets ``self.beta``, ``self.ngauss_G`` (Gaussian centred in the cell,
        in pd2 plane waves) and ``self.vgauss_G`` (its potential).
        """
        gd = self.wfs.gd

        # Exponent chosen so that exp(-beta*r^2) reaches exp(-19) at the
        # boundary.
        self.beta = 4 * 19 * (gd.icell_cv**2).sum(1).max()
        beta = self.beta

        # Gaussian centred in the middle of the cell, in reciprocal space.
        center_v = gd.cell_cv.sum(0) / 2
        G_Gv = self.pd2.get_reciprocal_vectors()
        G2_G = self.pd2.G2_qG[0]
        self.ngauss_G = np.exp(-1.0 / (4 * beta) * G2_G +
                               1j * np.dot(G_Gv, center_v)) / gd.dv

        # Potential on the real-space grid.  With an even number of points
        # along every axis a grid point sits exactly on the centre: patch
        # r there to avoid 0/0, then insert the analytic r -> 0 limit.
        R_Rv = gd.get_grid_point_coordinates().transpose((1, 2, 3, 0))
        r_R = ((R_Rv - center_v)**2).sum(3)**0.5
        center_index = None
        if (gd.N_c % 2 == 0).all():
            center_index = tuple(gd.N_c // 2)
            r_R[center_index] = 1.0
        v_R = erf(beta**0.5 * r_R) / r_R
        if center_index is not None:
            v_R[center_index] = (4 * beta / pi)**0.5
        self.vgauss_G = self.pd2.fft(v_R)

        # Sanity check against the analytic self-interaction.
        self_interaction = 0.5 * self.pd2.integrate(self.ngauss_G,
                                                    self.vgauss_G)
        assert abs(self_interaction - (beta / 2 / pi)**0.5) < 1e-6
Exemple #5
0
class Unfold:
    """Unfold the bands of a supercell (SC) calculation onto the primitive
    cell (PC).

    As a convention (when possible), capital-letter variables refer to the
    SC and lowercase ones to the PC.
    """

    def __init__(self, name=None, calc=None, M=None, spinorbit=None):
        """Load the supercell calculation.

        name: str
            Tag used in the pickle file names 'weights_<name>.pckl' and
            'sf_<name>.pckl'.
        calc: str or calculator
            Passed to GPAW() with txt=None and a serial communicator.
        M: array-like
            Supercell matrix (stored as a float array); used to map PC
            k-points onto SC ones and to select PC reciprocal vectors.
        spinorbit: bool
            Include spin-orbit corrections; doubles the number of bands.
        """
        self.name = name
        self.calc = GPAW(calc, txt=None, communicator=mpi.serial_comm)
        self.M = np.array(M, dtype=float)
        self.spinorbit = spinorbit

        self.gd = self.calc.wfs.gd.new_descriptor()

        self.kd = self.calc.wfs.kd
        # Compare by value: ``is`` on a str literal depends on interning
        # (and is a SyntaxWarning since Python 3.8).
        if self.calc.wfs.mode == 'pw':
            self.pd = self.calc.wfs.pd
        else:
            self.pd = PWDescriptor(ecut=None,
                                   gd=self.gd,
                                   kd=self.kd,
                                   dtype=complex)

        self.acell_cv = self.gd.cell_cv
        self.bcell_cv = 2 * np.pi * self.gd.icell_cv
        self.vol = self.gd.volume
        self.BZvol = (2 * np.pi)**3 / self.vol

        self.nb = self.calc.get_number_of_bands()

        self.v_Knm = None
        if spinorbit:
            if mpi.world.rank == 0:
                print('Calculating spinorbit Corrections')
            self.nb = 2 * self.calc.get_number_of_bands()
            self.e_mK, self.v_Knm = get_spinorbit_eigenvalues(self.calc,
                                                              return_wfs=True)
            if mpi.world.rank == 0:
                print('Done with the spinorbit Corrections')

    def get_K_index(self, K):
        """Find the index of a given K."""
        K = np.array([K])
        bzKG = to1bz(K, self.acell_cv)[0]
        return self.kd.where_is_q(bzKG, self.kd.bzk_kc)

    def get_g(self, iK):
        """Select the SC G-vectors that are also PC reciprocal vectors.

        Not all the G vectors are relevant for the band unfolding, only the
        ones that match PC reciprocal lattice vectors, i.e. whose
        coordinates multiplied by inv(M).T are all within 1e-5 of an
        integer.

        Returns (iG_list, g_list): their indices in the plane-wave basis at
        iK and their coordinates in the SC reciprocal basis.
        """
        G_Gv = self.pd.get_reciprocal_vectors(q=iK, add_q=False)
        G_Gc = np.dot(G_Gv, np.linalg.inv(self.bcell_cv))

        a_Gc = np.dot(G_Gc, np.linalg.inv(self.M).T)
        frac_Gc = np.abs(a_Gc) % 1
        keep_G = ((frac_Gc < 1e-5) | (1 - frac_Gc < 1e-5)).all(axis=1)
        iG_list = np.nonzero(keep_G)[0]
        return iG_list, G_Gc[iG_list]

    def get_G_index(self, iK, G, G_list):
        """Find the index of a given G.

        Returns the indices of the rows of G_list whose summed absolute
        difference from G is below 1e-5.  Neither argument is modified.
        """
        sumG = np.abs(np.asarray(G_list) - G).sum(axis=1)
        return np.where(sumG < 1e-5)[0]

    def get_eigenvalues(self, iK):
        """Get the list of eigenvalues for a given iK (divided by Hartree)."""
        if not self.spinorbit:
            e_m = self.calc.get_eigenvalues(kpt=iK, spin=0) / Hartree
        else:
            e_m = self.e_mK[:, iK] / Hartree
        return np.array(e_m)

    def get_pw_wavefunctions_k(self, iK):
        """Get the Fourier coefficients of the wave functions at iK.

        For spinors the number of bands is doubled and a spin dimension is
        added: the result has shape (nbands, 2, nG).
        """
        psi_mgrid = get_rs_wavefunctions_k(self.calc, iK, self.spinorbit,
                                           self.v_Knm)
        if not self.spinorbit:
            return np.array([self.pd.fft(psi_R, iK) for psi_R in psi_mgrid])

        u0_mG = np.array([self.pd.fft(psi_sR[0], iK) for psi_sR in psi_mgrid])
        u1_mG = np.array([self.pd.fft(psi_sR[1], iK) for psi_sR in psi_mgrid])

        u_mG = np.zeros((len(u0_mG), 2, u0_mG.shape[1]), complex)
        u_mG[:, 0] = u0_mG
        u_mG[:, 1] = u1_mG
        return u_mG

    def get_spectral_weights_k(self, k_t):
        r"""Return the spectral weights for a given k in the PC:

            P_mK(k_t) = \sum_n |<Km|k_t n>|**2

        which can be shown to be equivalent to:

            P_mK(k_t) = \sum_g |C_Km(g+k_t-K)|**2
        """
        K_c, G_t = find_K_from_k(k_t, self.M)
        iK = self.get_K_index(K_c)
        _, g_list = self.get_g(iK)
        gG_t_list = g_list + G_t

        G_Gv = self.pd.get_reciprocal_vectors(q=iK, add_q=False)
        G_Gc = np.dot(G_Gv, np.linalg.inv(self.bcell_cv))

        # get_G_index does not modify its inputs, so no copies are needed.
        igG_t_list = [self.get_G_index(iK, g, G_Gc) for g in gG_t_list]

        C_mG = self.get_pw_wavefunctions_k(iK)
        P_m = []
        for m in range(self.nb):
            P = 0.
            if not self.spinorbit:
                norm = np.sum(np.linalg.norm(C_mG[m, :])**2)
                for iG in igG_t_list:
                    P += np.linalg.norm(C_mG[m, iG])**2
            else:
                norm = np.sum(
                    np.linalg.norm(C_mG[m, 0, :])**2 +
                    np.linalg.norm(C_mG[m, 1, :])**2)
                for iG in igG_t_list:
                    P += (np.linalg.norm(C_mG[m, 0, iG])**2 +
                          np.linalg.norm(C_mG[m, 1, iG])**2)
            P_m.append(P / norm)

        return np.array(P_m)

    def get_spectral_weights(self, kpoints, filename=None):
        """Collect the spectral weights for the k points in the kpoints list.

        This function is parallelized over k's.  If *filename* is given,
        (e_mK, P_mK) is read from that pickle file.  Otherwise it is read
        from 'weights_<name>.pckl' if that file exists, or computed and
        written there by rank 0.

        NOTE: pickle.load executes arbitrary code from the file; only load
        files you created yourself.
        """
        if filename is not None:
            with open(filename, 'rb') as fd:
                e_mK, P_mK = pickle.load(fd)
            return e_mK, P_mK

        cachename = 'weights_' + self.name + '.pckl'
        try:
            with open(cachename, 'rb') as fd:
                e_mK, P_mK = pickle.load(fd)
        except IOError:
            e_mK, P_mK = self._calculate_spectral_weights(kpoints, cachename)
        return e_mK, P_mK

    def _calculate_spectral_weights(self, kpoints, cachename):
        """Compute (e_mK, P_mK) in parallel over k-points and pickle them."""
        world = mpi.world
        Nk = len(kpoints)
        Nb = self.nb
        if world.rank == 0:
            print('Getting EigenValues and Weights')

        e_Km = np.zeros((Nk, Nb))
        P_Km = np.zeros((Nk, Nb))
        for ik in range(world.rank, Nk, world.size):
            k = kpoints[ik]
            # Wrap in a tuple so tuple-valued k-points format correctly.
            print('kpoint: %s' % (k,))
            K_c, _ = find_K_from_k(k, self.M)
            iK = self.get_K_index(K_c)
            e_Km[ik] = self.get_eigenvalues(iK)
            P_Km[ik] = self.get_spectral_weights_k(k)

        world.barrier()
        world.sum(e_Km)
        world.sum(P_Km)

        e_mK = e_Km.T
        P_mK = P_Km.T
        if world.rank == 0:
            with open(cachename, 'wb') as fd:
                pickle.dump((e_mK, P_mK), fd)
        return e_mK, P_mK

    def spectral_function(self,
                          kpts,
                          x,
                          X,
                          points_name,
                          width=0.002,
                          npts=10000,
                          filename=None):
        r"""Compute the spectral function for all the k-points in *kpts*:

                                          eta / pi
            A_k(e) = \sum_m  P_mK(k) x  ----------------------
                                        (e - e_mk)**2 + eta**2

        at each k-point, on npts energy points in the range [emin, emax],
        which extends 5 * width beyond the lowest and highest eigenvalue.
        The width keyword is FWHM = 2 * eta.

        Rank 0 pickles (e * Hartree, A_ke, x, X, points_name) to
        'sf_<name>.pckl'; nothing is returned.
        """
        Nk = len(kpts)
        A_ke = np.zeros((Nk, npts), float)

        world = mpi.world
        e_mK, P_mK = self.get_spectral_weights(kpts, filename)
        if world.rank == 0:
            print('Calculating the Spectral Function')
        emin = np.min(e_mK) - 5 * width
        emax = np.max(e_mK) + 5 * width
        e = np.linspace(emin, emax, npts)

        for ik in range(Nk):
            for e0, P in zip(e_mK[:, ik], P_mK[:, ik]):
                D = (width / 2 / np.pi) / ((e - e0)**2 + (width / 2)**2)
                A_ke[ik] += P * D
        if world.rank == 0:
            with open('sf_' + self.name + '.pckl', 'wb') as fd:
                pickle.dump((e * Hartree, A_ke, x, X, points_name), fd)
            print('Spectral Function calculation completed!')
Exemple #6
0
class TDDFT(object):
    """
    Time-dependent DFT+Hartree-Fock in Kohn-Sham orbitals basis:

        calc: GPAW calculator (setups='sg15')
        nbands (int): number of bands in calculation

    The state is expanded in the ground-state Kohn-Sham orbitals of the
    irreducible Brillouin zone and propagated with Crank-Nicolson steps
    (see propagate()).
    """

    def __init__(self, calc, nbands=None):
        self.calc = calc
        self.K = calc.get_ibz_k_points()  # irreducible Brillouin zone
        self.NK = self.K.shape[0]

        self.wk = calc.get_k_point_weights()  # weights of the IBZ k-points
        if nbands is None:
            self.nbands = calc.get_number_of_bands()
        else:
            self.nbands = nbands
        self.nvalence = int(calc.get_number_of_electrons() / 2)

        # band energies, converted with Hartree
        self.EK = [calc.get_eigenvalues(k)[:self.nbands]
                   for k in range(self.NK)]
        self.EK = np.array(self.EK) / Hartree
        # shape of real space grid
        self.shape = tuple(calc.get_number_of_grid_points())
        self.density = calc.get_pseudo_density() * Bohr**3  # density at t=0

        # u_nk: periodic parts of the Kohn-Sham orbitals (IBZ only).
        # Builtin ``complex`` replaces the ``np.complex`` alias that was
        # removed in NumPy 1.24.
        self.ukn = np.zeros((self.NK, self.nbands) + self.shape,
                            dtype=complex)
        for k in range(self.NK):
            kpt = calc.wfs.kpt_u[k]
            for n in range(self.nbands):
                self.ukn[k, n] = calc.wfs.pd.ifft(kpt.psit_nG[n], kpt.q)

        self.icell = 2.0 * np.pi * calc.wfs.gd.icell_cv  # reciprocal cell
        self.cell = calc.wfs.gd.cell_cv  # cell
        self.r = calc.wfs.gd.get_grid_point_coordinates()
        for i in range(3):
            self.r[i] -= self.cell[i, i] / 2.
        self.volume = np.abs(np.linalg.det(calc.wfs.gd.cell_cv))
        self.norm = calc.wfs.gd.dv  # grid volume element
        self.Fermi = calc.get_fermi_level() / Hartree  # Fermi level

        # descriptors at q = Gamma for the Hartree potential
        self.kd = KPointDescriptor([[0, 0, 0]])
        self.pd = PWDescriptor(ecut=calc.wfs.pd.ecut, gd=calc.wfs.gd,
                               kd=self.kd, dtype=complex)

        # Fermi-Dirac temperature and distribution
        self.temperature = calc.occupations.width
        self.f = 1 / (1 + np.exp((self.EK - self.Fermi) / self.temperature))

        shape_knqij = (self.NK, self.nbands, self.NK, self.nbands,
                       self.nbands)
        self.Hartree_elements = np.zeros(shape_knqij, dtype=complex)
        self.LDAx_elements = np.zeros(shape_knqij, dtype=complex)
        self.LDAc_elements = np.zeros(shape_knqij, dtype=complex)

        G = self.pd.get_reciprocal_vectors()
        G2 = np.linalg.norm(G, axis=1)**2
        G2[G2 == 0] = np.inf  # drop the G = 0 term of the Hartree potential
        matrix = np.zeros((self.NK, self.nbands, self.nbands), dtype=complex)
        # Conjugate the (large) orbital array once instead of three times
        # per (k, n) pair.
        ukn_conj = self.ukn.conj()

        for k in tqdm(range(self.NK)):
            for n in range(self.nbands):
                density = self.norm * np.abs(self.ukn[k, n])**2

                potential = xc.VLDAx(density)
                self.LDAx_elements[k, n] = operator_matrix_periodic(
                    matrix, potential, ukn_conj, self.ukn) * self.norm
                potential = xc.VLDAc(density)
                self.LDAc_elements[k, n] = operator_matrix_periodic(
                    matrix, potential, ukn_conj, self.ukn) * self.norm

                density = self.pd.fft(density)
                potential = 4 * np.pi * self.pd.ifft(density / G2)
                self.Hartree_elements[k, n] = operator_matrix_periodic(
                    matrix, potential, ukn_conj, self.ukn) * self.norm

        self.wavefunction = np.zeros((self.NK, self.nbands, self.nbands),
                                     dtype=complex)
        self.Kinetic = np.zeros((self.NK, self.nbands, self.nbands),
                                dtype=complex)
        for k in range(self.NK):
            self.wavefunction[k] = np.eye(self.nbands)
            self.Kinetic[k] = np.diag(self.EK[k])

        occupation0 = self.occupation(self.wavefunction)
        self.VH0 = np.einsum('kn,knqij->qij', occupation0,
                             self.Hartree_elements)
        self.VLDAc0 = np.einsum('kn,knqij->qij', occupation0,
                                self.LDAc_elements)
        self.VLDAx0 = np.einsum('kn,knqij->qij', occupation0,
                                self.LDAx_elements)

        self.Full_BZ = calc.get_bz_k_points()
        self.IBZ_map = calc.get_bz_to_ibz_map()

    def get_transition_matrix(self, direction):
        """Build, store in self.dipole and return the dipole matrix.

        Only valence-conduction elements (n < nvalence <= m) are filled:
        pd.integrate(psi_m, (G.d) psi_n) / (E_n - E_m), plus their complex
        conjugates at [m, n].  *direction* is normalised; the caller's
        array is not modified.
        """
        direction = np.asarray(direction, dtype=float)
        direction = direction / np.linalg.norm(direction)
        self.dipole = np.zeros((self.NK, self.nbands, self.nbands),
                               dtype=complex)
        for k in range(self.NK):
            kpt = self.calc.wfs.kpt_u[k]
            G = self.calc.wfs.pd.get_reciprocal_vectors(q=k, add_q=True)
            G = np.sum(G * direction[None, :], axis=1)
            for n in range(self.nvalence):
                Gwfn = G * kpt.psit_nG[n]
                for m in range(self.nvalence, self.nbands):
                    wfm = kpt.psit_nG[m]
                    self.dipole[k, n, m] = (
                        self.calc.wfs.pd.integrate(wfm, Gwfn) /
                        (self.EK[k, n] - self.EK[k, m]))
                    self.dipole[k, m, n] = self.dipole[k, n, m].conj()
        return self.dipole

    def occupation(self, wavefunction):
        """Return occupations (NK, nbands) weighted by wk and f (factor 2)."""
        return 2 * np.sum(self.wk[:, None, None] * self.f[:, None, :] *
                          np.abs(wavefunction)**2, axis=2)

    def fast_Hartree_matrix(self, wavefunction):
        """Change of the Hartree matrix relative to t = 0."""
        return np.einsum('kn,knqij->qij', self.occupation(wavefunction),
                         self.Hartree_elements) - self.VH0

    def fast_LDA_correlation_matrix(self, wavefunction):
        """Change of the LDA correlation matrix relative to t = 0."""
        return np.einsum('kn,knqij->qij', self.occupation(wavefunction),
                         self.LDAc_elements) - self.VLDAc0

    def fast_LDA_exchange_matrix(self, wavefunction):
        """Change of the LDA exchange matrix relative to t = 0."""
        return np.einsum('kn,knqij->qij', self.occupation(wavefunction),
                         self.LDAx_elements) - self.VLDAx0

    def propagate(self, dt, steps, E, direction, corrections=10):
        """Propagate the orbital coefficients in an external field.

        dt: time step; steps: number of time steps; E: field amplitudes
        (E[t] is used at step t); direction: field direction (normalised in
        get_transition_matrix); corrections: currently unused.

        Fills self.time_occupation (steps, nbands) and self.polarization
        (steps,).
        """
        dipole = self.get_transition_matrix(direction)

        self.time_occupation = np.zeros((steps, self.nbands), dtype=complex)
        self.polarization = np.zeros(steps, dtype=complex)

        self.time_occupation[0] = np.sum(
            self.occupation(self.wavefunction), axis=0)
        for k in range(self.NK):
            dipole_nn = np.linalg.multi_dot([self.wavefunction[k].T.conj(),
                                             dipole[k], self.wavefunction[k]])
            self.polarization[0] += self.wk[k] * np.sum(dipole_nn.diagonal())

        identity = np.eye(self.nbands)
        for t in tqdm(range(1, steps)):
            H = self.Kinetic + E[t] * dipole
            H += self.fast_Hartree_matrix(self.wavefunction)
            H += self.fast_LDA_correlation_matrix(self.wavefunction)
            H += self.fast_LDA_exchange_matrix(self.wavefunction)
            for k in range(self.NK):
                # Crank-Nicolson: (1 + i dt/2 H) c(t+dt) = (1 - i dt/2 H) c(t)
                H_left = identity + 0.5j * dt * H[k]
                H_right = identity - 0.5j * dt * H[k]
                self.wavefunction[k] = linalg.solve(
                    H_left, H_right @ self.wavefunction[k])
                dipole_nn = np.linalg.multi_dot(
                    [self.wavefunction[k].T.conj(), dipole[k],
                     self.wavefunction[k]])
                self.polarization[t] += (self.wk[k] *
                                         np.sum(dipole_nn.diagonal()))
            self.time_occupation[t] = np.sum(
                self.occupation(self.wavefunction), axis=0)
Exemple #7
0
    def calculate_QEH(self):
        """Calculate the QEH contribution to the self-energy.

        For every IBZ q-point the screened-potential correction
        ``self.dW_qw[iq]`` is scaled by abs(cell_cv[2, 2]) (presumably the
        out-of-plane length -- TODO confirm), Hilbert transformed
        (``self.htp`` / ``self.htm``) and, using only the G = 0 plane wave,
        combined with the pair densities of the states in
        ``self.mysKn1n2``.  Results are summed over ``self.world`` and the
        state file is saved.

        Returns
        -------
        sigma_sin, dsigma_sin: ndarray
            Self-energies and their derivatives, both of shape
            ``self.shape``.
        """
        print('Calculating QEH self-energy contribution', file=self.fd)

        kd = self.calc.wfs.kd

        # Reset calculation
        self.sigma_sin = np.zeros(self.shape)  # self-energies
        self.dsigma_sin = np.zeros(self.shape)  # derivatives of self-energies

        # Get KS eigenvalues and occupation numbers:
        b1, b2 = self.bands
        nibzk = self.calc.wfs.kd.nibzkpts
        for i, k in enumerate(self.kpts):
            for s in range(self.nspins):
                u = s * nibzk + k
                kpt = self.calc.wfs.kpt_u[u]
                self.eps_sin[s, i] = kpt.eps_n[b1:b2]
                self.f_sin[s, i] = kpt.f_n[b1:b2] / kpt.weight

        # My part of the states we want to calculate QP-energies for:
        mykpts = [self.get_k_point(s, K, n1, n2)
                  for s, K, n1, n2 in self.mysKn1n2]

        L = abs(self.calc.wfs.gd.cell_cv[2, 2])
        nw = self.nw
        Nq = len(self.qd.ibzk_kc)
        for iq, q_c in enumerate(self.qd.ibzk_kc):
            self.nq = iq
            nq = iq
            self.save_state_file()

            qcstr = '(' + ', '.join(['%.3f' % x for x in q_c]) + ')'
            print('Calculating contribution from IBZ q-point #%d/%d q_c=%s' %
                  (nq, Nq, qcstr),
                  file=self.fd)

            # Screened potential.  Scale out of place: indexing
            # self.dW_qw[nq] yields a view, so an in-place ``*=`` would
            # rescale the stored dW_qw on every call.
            dW_w = self.dW_qw[nq][:, np.newaxis, np.newaxis] * L

            Wpm_w = np.zeros([2 * nw, 1, 1], dtype=complex)
            Wpm_w[:nw] = dW_w
            Wpm_w[nw:] = Wpm_w[0:nw]

            with self.timer('Hilbert transform'):
                self.htp(Wpm_w[:nw])
                self.htm(Wpm_w[nw:])

            qd = KPointDescriptor([q_c])
            pd0 = PWDescriptor(self.ecut, self.calc.wfs.gd, complex, qd)

            # modify pd0 by hand - only G=0 component is needed
            pd0.G_Qv = np.array([1e-17, 1e-17, 1e-17])[np.newaxis, :]
            pd0.Q_qG = [np.array([0], dtype='int32')]
            pd0.ngmax = 1
            G_Gv = pd0.get_reciprocal_vectors()
            N_c = pd0.gd.N_c
            pos_av = np.dot(self.spos_ac, pd0.gd.cell_cv)

            self.Q_aGii = self.initialize_paw_corrections(pd0)

            # Loop over all k-points in the BZ and find those that are related
            # to the current IBZ k-point by symmetry
            Q1 = self.qd.ibz2bz_k[iq]
            Q2s = {Q2 for Q2 in self.qd.bz2bz_ks[Q1] if Q2 >= 0}

            for Q2 in Q2s:
                s = self.qd.sym_k[Q2]
                self.s = s
                U_cc = self.qd.symmetry.op_scc[s]
                time_reversal = self.qd.time_reversal_k[Q2]
                self.sign = 1 - 2 * time_reversal
                Q_c = self.qd.bzk_kc[Q2]
                d_c = self.sign * np.dot(U_cc, q_c) - Q_c
                assert np.allclose(d_c.round(), d_c)

                # The rotated G-vector indices and PAW corrections depend
                # only on the symmetry operation, not on the k-point, so
                # they are computed once per Q2.
                i_cG = self.sign * np.dot(
                    U_cc, np.unravel_index(pd0.Q_qG[0], N_c))
                M_vv = np.dot(
                    pd0.gd.cell_cv.T,
                    np.dot(U_cc.T,
                           np.linalg.inv(pd0.gd.cell_cv).T))
                Q_aGii = []
                for a, Q_Gii in enumerate(self.Q_aGii):
                    x_G = np.exp(
                        1j * np.dot(G_Gv,
                                    (pos_av[a] - np.dot(M_vv, pos_av[a]))))
                    U_ii = self.calc.wfs.setups[a].R_sii[self.s]
                    Q_Gii = np.dot(
                        np.dot(U_ii, Q_Gii * x_G[:, None, None]),
                        U_ii.T).transpose(1, 0, 2)
                    if self.sign == -1:
                        Q_Gii = Q_Gii.conj()
                    Q_aGii.append(Q_Gii)

                for kpt1 in mykpts:
                    K2 = kd.find_k_plus_q(Q_c, [kpt1.K])[0]
                    kpt2 = self.get_k_point(kpt1.s,
                                            K2,
                                            0,
                                            self.nbands,
                                            block=True)
                    k1 = kd.bz2ibz_k[kpt1.K]
                    i = self.kpts.index(k1)

                    # This is the q that connects K1 and K2 in the 1st BZ
                    q1_c = kd.bzk_kc[K2] - kd.bzk_kc[kpt1.K]

                    # G-vector that connects the full Q_c with q1_c
                    shift1_c = q1_c - self.sign * np.dot(U_cc, q_c)
                    assert np.allclose(shift1_c.round(), shift1_c)
                    shift1_c = shift1_c.round().astype(int)
                    shift_c = kpt1.shift_c - kpt2.shift_c - shift1_c
                    I_G = np.ravel_multi_index(i_cG + shift_c[:, None], N_c,
                                               'wrap')

                    for n in range(kpt1.n2 - kpt1.n1):
                        ut1cc_R = kpt1.ut_nR[n].conj()
                        eps1 = kpt1.eps_n[n]
                        C1_aGi = [
                            np.dot(Qa_Gii, P1_ni[n].conj())
                            for Qa_Gii, P1_ni in zip(Q_aGii, kpt1.P_ani)
                        ]

                        n_mG = self.calculate_pair_densities(
                            ut1cc_R, C1_aGi, kpt2, pd0, I_G)
                        if self.sign == 1:
                            n_mG = n_mG.conj()

                        f_m = kpt2.f_n
                        deps_m = eps1 - kpt2.eps_n
                        sigma, dsigma = self.calculate_sigma(
                            n_mG, deps_m, f_m, Wpm_w)
                        nn = kpt1.n1 + n - self.bands[0]
                        self.sigma_sin[kpt1.s, i, nn] += sigma
                        self.dsigma_sin[kpt1.s, i, nn] += dsigma

        self.world.sum(self.sigma_sin)
        self.world.sum(self.dsigma_sin)

        self.complete = True
        self.save_state_file()

        return self.sigma_sin, self.dsigma_sin