class PhononPerturbation(Perturbation):
    """Implementation of a phonon perturbation.

    This class implements the change in the effective potential due to a
    displacement of an atom ``a`` in direction ``v`` with wave-vector ``q``.
    The action of the perturbing potential on a state vector is implemented in
    the ``apply`` member function.

    """

    def __init__(self, calc, kd, poisson_solver, dtype=float, **kwargs):
        """Store useful objects, e.g. lfc's for the various atomic functions.

        Depending on whether the system is periodic or finite, Poisson's
        equation is solved with FFT or multigrid techniques, respectively.

        Parameters
        ----------
        calc: Calculator
            Ground-state calculation.
        kd: KPointDescriptor
            Descriptor for the q-vectors of the dynamical matrix.
        poisson_solver:
            Solver stored as ``self.poisson``; ``solve_poisson`` calls its
            ``solve_neutral`` method.
        dtype: dtype
            Data type handed to the lfc's and to the grid transformer.
        kwargs:
            Accepted but not used by this constructor.

        """

        self.kd = kd
        self.dtype = dtype
        self.poisson = poisson_solver

        # Gamma wrt q-vector
        if self.kd.gamma:
            self.phase_cd = None
        else:
            assert self.kd.mynks == len(self.kd.ibzk_qc)

            # One set of phase factors per q-vector; ``set_q`` picks the entry
            # that belongs to the active q-vector.
            self.phase_qcd = []
            sdisp_cd = calc.wfs.gd.sdisp_cd

            for q in range(self.kd.mynks):
                phase_cd = np.exp(2j * np.pi *
                                  sdisp_cd * self.kd.ibzk_qc[q, :, np.newaxis])
                self.phase_qcd.append(phase_cd)

        # Store grid-descriptors
        self.gd = calc.density.gd
        self.finegd = calc.density.finegd

        # Steal setups for the lfc's
        setups = calc.wfs.setups

        # Store projector coefficients
        self.dH_asp = calc.hamiltonian.dH_asp.copy()

        # Localized functions:
        # core corrections
        self.nct = LFC(self.gd, [[setup.nct] for setup in setups],
                       integral=[setup.Nct for setup in setups],
                       dtype=self.dtype)
        # compensation charges
        #XXX what is the consequence of numerical errors in the integral ??
        self.ghat = LFC(self.finegd, [setup.ghat_l for setup in setups],
                        dtype=self.dtype)
        ## self.ghat = LFC(self.finegd, [setup.ghat_l for setup in setups],
        ##                 integral=sqrt(4 * pi), dtype=self.dtype)
        # vbar potential
        self.vbar = LFC(self.finegd, [[setup.vbar] for setup in setups],
                        dtype=self.dtype)

        # Expansion coefficients for the compensation charges
        self.Q_aL = calc.density.Q_aL.copy()

        # Grid transformer -- convert array from fine to coarse grid
        self.restrictor = Transformer(self.finegd, self.gd, nn=3,
                                      dtype=self.dtype, allocate=False)

        # Atom, cartesian coordinate and q-vector of the perturbation
        self.a = None
        self.v = None

        # Local q-vector index of the perturbation; -1 marks a Gamma-point
        # calculation so that the ``self.q is not None`` checks pass there.
        if self.kd.gamma:
            self.q = -1
        else:
            self.q = None

    def initialize(self, spos_ac):
        """Prepare the various attributes for a calculation.

        Parameters
        ----------
        spos_ac: ndarray
            Scaled atomic positions handed to the lfc's.

        """

        # NOTE(review): the statements acting on the lfc's, the Poisson solver
        # and the restrictor were absent from the source (only their comments
        # were left). They are restored here from those comments and the GPAW
        # API of the same generation (``allocate=False`` transformer) --
        # confirm against the GPAW version in use.

        # Set positions on LFC's
        self.nct.set_positions(spos_ac)
        self.ghat.set_positions(spos_ac)
        self.vbar.set_positions(spos_ac)

        if not self.kd.gamma:
            # Set q-vectors and update
            self.ghat.set_k_points(self.kd.ibzk_qc)
            self.ghat._update(spos_ac)
            # Set q-vectors and update
            self.vbar.set_k_points(self.kd.ibzk_qc)
            self.vbar._update(spos_ac)

            # Phase factor exp(iq.r) needed to obtain the periodic part of lfcs
            coor_vg = self.finegd.get_grid_point_coordinates()
            cell_cv = self.finegd.cell_cv
            # Convert to scaled coordinates
            scoor_cg = np.dot(la.inv(cell_cv), coor_vg.swapaxes(0, -2))
            scoor_cg = scoor_cg.swapaxes(1, -2)
            # Phase factor
            phase_qg = np.exp(2j * pi *
                              np.dot(self.kd.ibzk_qc,
                                     scoor_cg.swapaxes(0, -2)))
            self.phase_qg = phase_qg.swapaxes(1, -2)

        #XXX To be removed from this class !!
        # Setup the Poisson solver -- to be used on the fine grid
        self.poisson.set_grid_descriptor(self.finegd)
        self.poisson.initialize()

        # Grid transformer -- it was created with ``allocate=False``
        self.restrictor.allocate()

    def set_q(self, q):
        """Select the q-vector of the perturbation by its index.

        q: int
            Index into the stored list of phase factors ``self.phase_qcd``.

        """

        assert not self.kd.gamma, "Gamma-point calculation"

        # Remember the index, then pick the phase factors that go with it
        self.q = q
        self.phase_cd = self.phase_qcd[q]

        # The local part of the perturbing potential has to be recalculated
        # after the q-vector changed, so drop the stored one
        self.v1_G = None

    def set_av(self, a, v):
        """Set atom and cartesian component of the perturbation.

        Parameters
        ----------
        a: int
            Index of the atom.
        v: int
            Cartesian component (0, 1 or 2) of the atomic displacement.

        """

        assert self.q is not None

        self.a = a
        self.v = v

        # Update derivative of local potential; ``apply`` reads the resulting
        # ``self.v1_G``, which ``set_q`` resets to None.
        self.calculate_local_potential()

    def get_phase_cd(self):
        """Return the currently stored phase factors ``self.phase_cd``.

        Overwrites the base class member function.

        """

        phase_cd = self.phase_cd
        return phase_cd

    def has_q(self):
        """Tell whether the perturbation carries a q-vector.

        Overwrites the base class member function; the answer is the negation
        of the ``gamma`` flag of the q-vector descriptor.

        """

        return not self.kd.gamma

    def get_q(self):
        """Return the q-vector selected by the current index ``self.q``."""

        assert not self.kd.gamma, "Gamma-point calculation."

        q_c = self.kd.ibzk_qc[self.q]
        return q_c

    def solve_poisson(self, phi_g, rho_g):
        """Solve Poisson's equation for a Bloch-type charge distribution.

        More to come here ...

        Parameters
        ----------
        phi_g: ndarray
            Array for the solution of Poisson's equation; modified in place.
        rho_g: ndarray
            Array with the charge distribution.

        """

        #assert phi_g.shape == rho_g.shape == self.phase_qg.shape[-3:], \
        #       ("Arrays have incompatible shapes.")
        assert self.q is not None, ("q-vector not set")

        # Gamma point calculation wrt the q-vector -> rho_g periodic
        if self.kd.gamma:
            #XXX NOTICE: solve_neutral
            self.poisson.solve_neutral(phi_g, rho_g)
        else:
            phase_g = self.phase_qg[self.q]

            # Divide out the phase factor to get the periodic part
            rhot_g = rho_g / phase_g

            # Solve Poisson's equation for the periodic part of the potential
            #XXX NOTICE: solve_neutral
            self.poisson.solve_neutral(phi_g, rhot_g)

            # Return to Bloch form
            phi_g *= phase_g

    def calculate_local_potential(self):
        """Derivative of the local potential wrt an atomic displacement.

        The local part of the PAW potential has contributions from the
        compensation charges (``ghat``) and a spherical symmetric atomic
        potential (``vbar``).

        The results are stored on the instance: ``vghat1_g`` (potential from
        the compensation charge derivative, fine grid), ``v1_g`` (full local
        derivative, fine grid) and ``v1_G`` (the same on the coarse grid).

        """

        assert self.a is not None
        assert self.v is not None
        assert self.q is not None

        a = self.a
        v = self.v

        # Expansion coefficients for the ghat functions
        Q_aL = self.ghat.dict(zero=True)
        # Remember sign convention for add_derivative method
        # And be sure not to change the dtype of the arrays by assigning values
        # to array elements.
        Q_aL[a][:] = -1 * self.Q_aL[a]

        # Grid for derivative of compensation charges
        ghat1_g = self.finegd.zeros(dtype=self.dtype)
        self.ghat.add_derivative(a, v, ghat1_g, c_axi=Q_aL, q=self.q)

        # Solve Poisson's eq. for the potential from the periodic part of the
        # compensation charge derivative
        v1_g = self.finegd.zeros(dtype=self.dtype)
        self.solve_poisson(v1_g, ghat1_g)

        # Store potential from the compensation charge
        self.vghat1_g = v1_g.copy()

        # Add derivative of vbar - sign convention in add_derivative method
        c_ai = self.vbar.dict(zero=True)
        c_ai[a][0] = -1.
        self.vbar.add_derivative(a, v, v1_g, c_axi=c_ai, q=self.q)

        # Store potential for the evaluation of the energy derivative
        self.v1_g = v1_g.copy()

        # Transfer to coarse grid
        v1_G = self.gd.zeros(dtype=self.dtype)
        self.restrictor.apply(v1_g, v1_G, phases=self.phase_cd)

        self.v1_G = v1_G

    def apply(self, psi_nG, y_nG, wfs, k, kplusq):
        """Apply perturbation to unperturbed wave-functions.

        The local and the non-local contributions are both added to ``y_nG``
        in place.

        Parameters
        ----------
        psi_nG: ndarray
            Set of grid vectors to which the perturbation is applied.
        y_nG: ndarray
            Output vectors.
        wfs: WaveFunctions
            Instance of class ``WaveFunctions``.
        k: int
            Index of the k-point for the vectors.
        kplusq: int
            Index of the k+q vector.

        """

        assert self.a is not None
        assert self.v is not None
        assert self.q is not None
        assert psi_nG.ndim in (3, 4)
        assert tuple(self.gd.n_c) == psi_nG.shape[-3:]

        # Local part: a single grid vector or a set of them
        if psi_nG.ndim == 3:
            y_nG += self.v1_G * psi_nG
        else:
            y_nG += self.v1_G[np.newaxis, :] * psi_nG

        self.apply_nonlocal_potential(psi_nG, y_nG, wfs, k, kplusq)

    def apply_nonlocal_potential(self, psi_nG, y_nG, wfs, k, kplusq):
        """Derivative of the non-local PAW potential wrt an atomic displacement.

        The result is added to ``y_nG`` in place.

        Parameters
        ----------
        psi_nG: ndarray
            Grid vector(s) being operated on; only the shape is used here, the
            projections are read from ``wfs.kpt_u[k]``.
        y_nG: ndarray
            Output vectors.
        wfs: WaveFunctions
            Provides the projectors ``pt`` and the stored coefficients
            ``P_ani`` and ``dP_aniv`` of k-point ``k``.
        k: int
            Index of the k-point being operated on.
        kplusq: int
            Index of the k+q vector.

        """

        assert self.a is not None
        assert self.v is not None
        assert psi_nG.ndim in (3, 4)
        assert tuple(self.gd.n_c) == psi_nG.shape[-3:]

        # Number of vectors being operated on
        if psi_nG.ndim == 3:
            n = 1
        else:
            n = psi_nG.shape[0]

        a = self.a
        v = self.v

        P_ani = wfs.kpt_u[k].P_ani
        dP_aniv = wfs.kpt_u[k].dP_aniv
        pt = wfs.pt

        # < p_a^i | Psi_nk >
        P_ni = P_ani[a]
        # < dp_av^i | Psi_nk > - remember the sign convention of the derivative
        dP_ni = -1 * dP_aniv[a][..., v]

        # Expansion coefficients for the projectors on atom a (first entry of
        # the second index of ``dH_asp`` only)
        dH_ii = unpack(self.dH_asp[a][0])

        # The derivative of the non-local PAW potential has two contributions
        # 1) Sum over projectors
        c_ni = np.dot(dP_ni, dH_ii)
        c_ani = pt.dict(shape=n, zero=True)
        c_ani[a] = c_ni
        # k+q !!
        pt.add(y_nG, c_ani, q=kplusq)

        # 2) Sum over derivatives of the projectors
        dc_ni = np.dot(P_ni, dH_ii)
        dc_ani = pt.dict(shape=n, zero=True)
        # Take care of sign of derivative in the coefficients
        dc_ani[a] = -1 * dc_ni
        # k+q !!
        pt.add_derivative(a, v, y_nG, dc_ani, q=kplusq)

    def get_projections(self, locfun, spin=0):
        """Project wave functions onto localized functions

        Determine the projections of the Kohn-Sham eigenstates
        onto specified localized functions of the format::

          locfun = [[spos_c, l, sigma], [...]]

        spos_c can be an atom index, or a scaled position vector. l is
        the angular momentum, and sigma is the (half-) width of the
        radial gaussian.

        Return format is::

          f_kni = <psi_kn | f_i>

        where psi_kn are the wave functions, and f_i are the specified
        localized functions.

        As a special case, locfun can be the string 'projectors', in which
        case the bound state projectors are used as localized functions.

        Only k-points whose spin index equals ``spin`` contribute.

        """

        wfs = self.wfs

        if locfun == 'projectors':
            f_kin = []
            for kpt in wfs.kpt_u:
                if kpt.s == spin:
                    f_in = []
                    for a, P_ni in kpt.P_ani.items():
                        i = 0
                        setup = wfs.setups[a]
                        for l, n in zip(setup.l_j, setup.n_j):
                            # Only entries with n >= 0 are kept
                            if n >= 0:
                                for j in range(i, i + 2 * l + 1):
                                    f_in.append(P_ni[:, j])
                            i += 2 * l + 1
                    # Collect the projections of this k-point; the transpose
                    # below needs one entry per k-point.
                    f_kin.append(f_in)
            f_kni = np.array(f_kin).transpose(0, 2, 1)
            return f_kni.conj()

        from gpaw.lfc import LocalizedFunctionsCollection as LFC
        from gpaw.spline import Spline
        from gpaw.utilities import _fact

        nkpts = len(wfs.ibzk_kc)
        nbf = np.sum([2 * l + 1 for pos, l, a in locfun])
        f_kni = np.zeros((nkpts, wfs.nbands, nbf), wfs.dtype)

        spos_ac = self.atoms.get_scaled_positions() % 1.0
        spos_xc = []
        splines_x = []
        for spos_c, l, sigma in locfun:
            # An integer selects the position of that atom
            if isinstance(spos_c, int):
                spos_c = spos_ac[spos_c]
            spos_xc.append(spos_c)
            alpha = .5 * Bohr**2 / sigma**2
            r = np.linspace(0, 6. * sigma, 500)
            f_g = (_fact[l] * (4 * alpha)**(l + 3 / 2.) *
                   np.exp(-alpha * r**2) /
                   (np.sqrt(4 * np.pi) * _fact[2 * l + 1]))
            splines_x.append([Spline(l, rmax=r[-1], f_g=f_g, points=61)])

        lf = LFC(wfs.gd, splines_x, wfs.kpt_comm, dtype=wfs.dtype)
        # NOTE(review): the two set-up calls below were missing from the source
        # (the ``if`` had no body and ``spos_xc`` was never used). They follow
        # the public LFC API; confirm that ``wfs.ibzk_kc`` is the k-point list
        # matching the ``kpt.q`` index used in ``integrate`` below.
        if not wfs.gamma:
            lf.set_k_points(wfs.ibzk_kc)
        lf.set_positions(spos_xc)

        k = 0
        f_ani = lf.dict(wfs.nbands)
        for kpt in wfs.kpt_u:
            if kpt.s != spin:
                continue
            lf.integrate(kpt.psit_nG[:], f_ani, kpt.q)
            # Copy the per-function blocks side by side into row k
            i1 = 0
            for x, f_ni in f_ani.items():
                i2 = i1 + f_ni.shape[1]
                f_kni[k, :, i1:i2] = f_ni
                i1 = i2
            k += 1

        return f_kni.conj()


class WaveFunctions:
    """Class for wave-function related stuff (e.g. projectors).

    Holds the projector functions (``pt``), the coarse grid descriptor
    (``gd``) and a list of k-point containers (``kpt_u``).

    """

    def __init__(self, nbands, kpt_u, setups, kd, gd, dtype=float):
        """Store and initialize required attributes.

        nbands: int
            Number of occupied bands.
        kpt_u: list of KPoints
            List of KPoint instances from a ground-state calculation (i.e. the
            attribute ``calc.wfs.kpt_u``).
        setups: Setups
            LocalizedFunctionsCollection setups.
        kd: KPointDescriptor
            K-point and symmetry related stuff.
        gd: GridDescriptor
            Descriptor for the coarse grid.
        dtype: dtype
            This is the ``dtype`` for the wave-function derivatives (same as
            the ``dtype`` for the ground-state wave-functions).

        """

        self.dtype = dtype
        # K-point related attributes
        self.kd = kd
        # Number of occupied bands
        self.nbands = nbands
        # Projectors
        self.pt = LFC(gd, [setup.pt_j for setup in setups], dtype=self.dtype)
        # Store grid
        self.gd = gd

        # Unfold the irreducible BZ to the full BZ
        # List of KPointContainers for the k-points in the full BZ
        self.kpt_u = []

        # No symmetries or only time-reversal symmetry used
        if kd.symmetry is None:
            # For now, time-reversal symmetry not allowed
            assert len(kpt_u) == kd.nbzkpts

            for k in range(kd.nbzkpts):
                kpt_ = kpt_u[k]

                # Copy the first ``nbands`` ground-state wave-functions
                psit_nG = gd.empty(nbands, dtype=self.dtype)

                for n, psit_G in enumerate(psit_nG):
                    psit_G[:] = kpt_.psit_nG[n]
                    # psit_0 = psit_G[0, 0, 0]
                    # psit_G *= psit_0.conj() / (abs(psit_0))
                # Strip off KPoint attributes and store in the KPointContainer
                # Note, only the occupied GS wave-functions are retained here !
                # NOTE(review): the call below is cut off in this file -- the
                # remaining keyword arguments and the closing parenthesis are
                # missing, and nothing visible adds ``kpt`` to ``self.kpt_u``.
                # Restore from the original source before use.
                kpt = KPointContainer(weight=kpt_.weight,
                                       # q=kpt.q,
                                       # f_n=kpt.f_n[:nbands])

            # NOTE(review): an ``else:`` opening the branch with symmetries
            # appears to be missing here; as written, the assert below would
            # contradict the one in the branch above unless both counts agree.
            assert len(kpt_u) == kd.nibzkpts

            for k, k_c in enumerate(kd.bzk_kc):

                # Index of symmetry related point in the irreducible BZ
                ik = kd.kibz_k[k]
                # Index of point group operation
                s = kd.sym_k[k]
                # Time-reversal symmetry used
                time_reversal = kd.time_reversal_k[k]

                # Coordinates of symmetry related point in the irreducible BZ
                ik_c = kd.ibzk_kc[ik]
                # Point group operation
                op_cc = kd.symmetry.op_scc[s]
                # KPoint from ground-state calculation
                kpt_ = kpt_u[ik]
                # Every point of the full BZ gets the same weight, scaled by
                # (2 - kpt_.s) as computed below
                weight = 1. / kd.nbzkpts * (2 - kpt_.s)
                phase_cd = np.exp(2j * pi * gd.sdisp_cd * k_c[:, np.newaxis])

                psit_nG = gd.empty(nbands, dtype=self.dtype)

                for n, psit_G in enumerate(psit_nG):
                    #XXX Seems to corrupt my memory somehow ???
                    psit_G[:] = kd.symmetry.symmetrize_wavefunction(
                        kpt_.psit_nG[n], ik_c, k_c, op_cc, time_reversal)
                    # Choose gauge
                    # psit_0 = psit_G[0, 0, 0]
                    # psit_G *= psit_0.conj() / (abs(psit_0))

                # NOTE(review): this call is cut off as well (arguments after
                # ``weight``, the closing parenthesis and the storing of
                # ``kpt`` are not visible); ``phase_cd`` and ``psit_nG``
                # computed above are presumably among the missing arguments.
                kpt = KPointContainer(weight=weight,
    def initialize(self, spos_ac):
        """Initialize projectors according to the ``gamma`` attribute."""

        # NOTE(review): only the comments of this method survive in this
        # file; the statements they announce are missing and the ``if`` below
        # has no body. Presumably the positions ``spos_ac`` (and, when not at
        # Gamma, the k-vectors) are set on ``self.pt`` and
        # ``calculate_projector_coef`` is called afterwards -- restore from
        # the original source and confirm which k-point array is used.

        # Set positions on LFC's
        if not self.kd.gamma:
            # Set k-vectors and update

        # Calculate projector coefficients for the GS wave-functions

    def reset(self):
        """Give every k-point a new, zeroed array for the derivatives.

        The array is created by the coarse grid descriptor with ``nbands``
        vectors of type ``dtype`` and stored as ``psit1_nG`` on each k-point.

        """

        for container in self.kpt_u:
            container.psit1_nG = self.gd.zeros(n=self.nbands,
                                               dtype=self.dtype)

    def calculate_projector_coef(self):
        """Coefficients for the derivative of the non-local part of the PP.

        The coefficients are calculated for every k-point in ``self.kpt_u``
        (using its index ``kpt.k`` and wave-functions ``kpt.psit_nG``) and
        stored on the k-point as ``P_ani`` and ``dP_aniv``.

        The calculated coefficients are the following (except for an overall
        sign of -1; see ``derivative`` member function of class ``LFC``):

        1. Coefficients from the projector functions::

                        /      a
               P_ani =  | dG  p (G) Psi (G)  ,
                        /      i       n

        2. Coefficients from the derivative of the projector functions::

                          /      a
               dP_aniv =  | dG dp  (G) Psi (G)  ,
                          /      iv       n

           where::

                 a        d       a
               dp  (G) =  ---  Phi (G) .
                 iv         a     i
                          dR
                            v

        """

        n = self.nbands

        for kpt in self.kpt_u:

            # K-point index and wave-functions
            k = kpt.k
            psit_nG = kpt.psit_nG

            # Integration dicts
            P_ani = self.pt.dict(shape=n)
            dP_aniv = self.pt.dict(shape=n, derivative=True)

            # 1) Integrate with projectors
            self.pt.integrate(psit_nG, P_ani, q=k)
            kpt.P_ani = P_ani

            # 2) Integrate with derivative of projectors
            self.pt.derivative(psit_nG, dP_aniv, q=k)
            kpt.dP_aniv = dP_aniv

    def initialize(self):
        """Set up k-points, bands, cell, grid, eigenvalues and occupations,
        the k+q mapping, plane waves and projectors from ``self.calc``.

        NOTE(review): this method reads attributes (``eta``, ``ecut``,
        ``calc``, ``ftol``, ``q_c``, ``optical_limit``, ...) that the
        ``__init__`` preceding it in this file never sets, and it reuses the
        name ``initialize`` of an earlier method at the same indentation -- it
        looks like it belongs to a different class. Several statements are
        cut off (flagged below) and must be restored from the original source.

        """

        # NOTE(review): presumably converts from eV to Hartree -- confirm
        self.eta /= Hartree
        self.ecut /= Hartree

        calc = self.calc

        # kpoint init
        self.kd = kd = calc.wfs.kd
        self.bzk_kc = kd.bzk_kc
        self.ibzk_kc = kd.ibzk_kc
        self.nkpt = kd.nbzkpts
        self.ftol /= self.nkpt

        # band init
        if self.nbands is None:
            self.nbands = calc.wfs.nbands
        self.nvalence = calc.wfs.nvalence

        # cell init
        self.acell_cv = calc.atoms.cell / Bohr
        self.bcell_cv, self.vol, self.BZvol = get_primitive_cell(self.acell_cv)

        # grid init
        self.nG = calc.get_number_of_grid_points()
        # Total number of grid points
        self.nG0 = self.nG[0] * self.nG[1] * self.nG[2]
        # NOTE(review): the GridDescriptor call is cut off after its first
        # argument; the remaining arguments are not visible.
        gd = GridDescriptor(self.nG,
        self.gd = gd
        self.h_cv = gd.h_cv

        # obtain eigenvalues, occupations
        nibzkpt = kd.nibzkpts
        kweight_k = kd.weight_k

            # NOTE(review): the statement that opens this indented block is
            # missing, and so is the start of the list passed to ``np.array``.
            # The ``print >>`` form below is Python 2 syntax.
            self.printtxt('Use eigenvalues from the calculator.')
            self.e_kn = np.array(
                 for k in range(nibzkpt)]) / Hartree
            self.printtxt('Eigenvalues(k=0) are:')
            print >> self.txt, self.e_kn[0] * Hartree
        # Occupations divided by the k-point weight and by the number of
        # k-points in the full BZ
        self.f_kn = np.array([
            calc.get_occupation_numbers(kpt=k) / kweight_k[k]
            for k in range(nibzkpt)
        ]) / self.nkpt

        # k + q init
        assert self.q_c is not None
        self.qq_v = np.dot(self.q_c, self.bcell_cv)  # summation over c

        # NOTE(review): an ``else:`` appears to be missing after the first two
        # statements of this branch; as written, the values set for the
        # optical limit are immediately overwritten.
        if self.optical_limit:
            kq_k = np.arange(self.nkpt)
            self.expqr_g = 1.
            r_vg = gd.get_grid_point_coordinates()  # (3, nG)
            qr_g = gemmdot(self.qq_v, r_vg, beta=0.0)
            self.expqr_g = np.exp(-1j * qr_g)
            del r_vg, qr_g
            kq_k = kd.find_k_plus_q(self.q_c)
        self.kq_k = kq_k

        # Plane wave init
        self.npw, self.Gvec_Gc, self.Gindex_G = set_Gvectors(
            self.acell_cv, self.bcell_cv, self.nG, self.ecut)

        # Projectors init
        setups = calc.wfs.setups
        # NOTE(review): the LFC call is cut off, and ``spos_ac`` is computed
        # but not used in the visible code -- presumably it was passed to the
        # projectors in a statement that is missing.
        pt = LFC(gd, [setup.pt_j for setup in setups],
        spos_ac = calc.atoms.get_scaled_positions()
        self.pt = pt

        # Printing calculation information
        # NOTE(review): no statement follows this comment in the visible code.

import numpy as np
from gpaw.lfc import LocalizedFunctionsCollection as LFC
from gpaw.grid_descriptor import GridDescriptor
from gpaw.spline import Spline
# Test script: adds localized functions to a grid, integrates them back and
# takes their derivative, checking the results with asserts.
# NOTE(review): the bodies of one ``if`` and of the final ``for`` are missing
# from this file (flagged below), so the script does not run as shown.
a = 4.0
gd = GridDescriptor(N_c=[16, 20, 20], cell_cv=[a, a + 1, a + 2],
                    pbc_c=(0, 1, 1))
spos_ac = np.array([[0.25, 0.15, 0.35], [0.5, 0.5, 0.5]])
kpts_kc = None
# An l=0 and an l=1 spline sharing the same radial values
s = Spline(l=0, rmax=2.0, f_g=np.array([1, 0.9, 0.1, 0.0]))
p = Spline(l=1, rmax=2.0, f_g=np.array([1, 0.9, 0.1, 0.0]))
# First atom: one s function; second atom: one s and one p function
spline_aj = [[s], [s, p]]
c = LFC(gd, spline_aj, cut=True, forces=True)
# NOTE(review): the body of this ``if`` is missing, and ``spos_ac`` is never
# handed to ``c`` in the visible code -- presumably the k-points and the
# positions were set on ``c`` here.
if kpts_kc is not None:
C_ani = c.dict(3, zero=True)
if 1 in C_ani:
    # Unit coefficients on the three p-type functions of the second atom
    C_ani[1][:, 1:] = np.eye(3)
psi = gd.zeros(3)
c.add(psi, C_ani)
c.integrate(psi, C_ani)
if 1 in C_ani:
    # The diagonal entries must agree with each other and everything else
    # must vanish within the given tolerances.
    # NOTE(review): the ``ndarray.ptp`` method was removed in NumPy 2.0;
    # ``np.ptp(d)`` is the replacement if the script is revived.
    d = C_ani[1][:, 1:].diagonal()
    assert d.ptp() < 4e-6
    C_ani[1][:, 1:] -= np.diag(d)
    assert abs(C_ani[1]).max() < 5e-17
d_aniv = c.dict(3, derivative=True)
c.derivative(psi, d_aniv)
if 1 in d_aniv:
    for v in range(3):
        # NOTE(review): no loop body is visible; the comment below does not
        # seem to belong to this script.
        # Printing calculation information
