示例#1
0
    def __init__(
        self,
        band_structure: BandStructure,
        num_electrons: int,
        interpolation_factor: float = defaults["interpolation_factor"],
        soc: bool = False,
        magmom: Optional[np.ndarray] = None,
        mommat: Optional[np.ndarray] = None,
        other_properties: Optional[Dict[Spin, Dict[str, np.ndarray]]] = None,
    ):
        """Fit BoltzTraP2 interpolation coefficients for a band structure.

        Args:
            band_structure: Band structure to interpolate. Its band energies
                are multiplied by ``ev_to_hartree`` before fitting, so they are
                expected in eV.
            num_electrons: Number of electrons; stored on the instance.
            interpolation_factor: Multiplier applied to the number of input
                k-points to obtain the ``nkpt`` passed to
                ``sphere.get_equivalences``.
            soc: Whether spin-orbit coupling was used; stored on the instance.
            magmom: Optional magnetic moments forwarded to
                ``sphere.get_equivalences``.
            mommat: Optional momentum matrix elements forwarded to ``DFTData``.
            other_properties: Optional per-spin mapping of label -> array of
                additional band-resolved quantities, fitted with the same
                equivalences as the energies. When given, every spin present
                in ``band_structure.bands`` must have an entry.
        """
        self._band_structure = band_structure
        self._num_electrons = num_electrons
        self._soc = soc
        self._spins = self._band_structure.bands.keys()
        self._other_properties = other_properties
        self.interpolation_factor = interpolation_factor
        # pymatgen stores lattice vectors as rows; the transpose gives them as
        # columns, then convert from Angstrom to bohr.
        self._lattice_matrix = (band_structure.structure.lattice.matrix.T *
                                angstrom_to_bohr)
        self._coefficients = {}
        self._other_coefficients = defaultdict(dict)

        kpoints = np.array([k.frac_coords for k in band_structure.kpoints])
        atoms = AseAtomsAdaptor.get_atoms(band_structure.structure)

        logger.info("Getting band interpolation coefficients")

        t0 = time.perf_counter()
        self._equivalences = sphere.get_equivalences(
            atoms=atoms,
            nkpt=kpoints.shape[0] * interpolation_factor,
            magmom=magmom)

        # get the interpolation mesh used by BoltzTraP2
        self.interpolation_mesh = (
            2 * np.max(np.abs(np.vstack(self._equivalences)), axis=0) + 1)

        for spin in self._spins:
            energies = band_structure.bands[spin] * ev_to_hartree
            data = DFTData(kpoints,
                           energies,
                           self._lattice_matrix,
                           mommat=mommat)
            self._coefficients[spin] = fite.fitde3D(data, self._equivalences)

        log_time_taken(t0)

        t0 = time.perf_counter()
        if self._other_properties:
            logger.info("Getting additional interpolation coefficients")

            # indexing by spin: a missing spin entry raises KeyError
            for spin in self._spins:
                for label, prop in self._other_properties[spin].items():
                    data = DFTData(kpoints,
                                   prop,
                                   self._lattice_matrix,
                                   mommat=mommat)
                    self._other_coefficients[spin][label] = fite.fitde3D(
                        data, self._equivalences)
            log_time_taken(t0)
示例#2
0
    def __init__(self,
                 band_structure: BandStructure,
                 num_electrons: int,
                 interpolation_factor: float = 20,
                 soc: bool = False,
                 magmom: Optional[np.ndarray] = None,
                 mommat: Optional[np.ndarray] = None,
                 interpolate_projections: bool = False):
        """Fit BoltzTraP2 interpolation coefficients for a band structure.

        Args:
            band_structure: Band structure to interpolate. Its energies are
                multiplied by ``units.eV`` before fitting, so they are expected
                in eV.
            num_electrons: Number of electrons; stored on the instance.
            interpolation_factor: Multiplier applied to the number of input
                k-points to obtain the ``nkpt`` passed to
                ``sphere.get_equivalences``.
            soc: Whether spin-orbit coupling was used; stored on the instance.
            magmom: Optional magnetic moments forwarded to
                ``sphere.get_equivalences``.
            mommat: Optional momentum matrix elements forwarded to ``DFTData``.
            interpolate_projections: Also fit the orbital projections of
                ``band_structure``.

        Raises:
            ValueError: If ``interpolate_projections`` is True but the band
                structure has no projections.
        """
        # Validate up front so we do not pay for the band fit first.
        if interpolate_projections and not band_structure.projections:
            raise ValueError(
                "interpolate_projections is True but band structure has no "
                "projections")

        self._band_structure = band_structure
        self._num_electrons = num_electrons
        self._soc = soc
        self._spins = self._band_structure.bands.keys()
        self._interpolate_projections = interpolate_projections
        self.interpolation_factor = interpolation_factor
        # BoltzTraP2 expects the lattice vectors as columns, while pymatgen
        # stores them as rows; transpose before converting to atomic units.
        self._lattice_matrix = (band_structure.structure.lattice.matrix.T *
                                units.Angstrom)
        self._coefficients = {}
        self._projection_coefficients = defaultdict(dict)

        kpoints = np.array([k.frac_coords for k in band_structure.kpoints])
        atoms = AseAtomsAdaptor.get_atoms(band_structure.structure)

        logger.info("Getting band interpolation coefficients")

        t0 = time.perf_counter()
        self._equivalences = sphere.get_equivalences(
            atoms=atoms, nkpt=kpoints.shape[0] * interpolation_factor,
            magmom=magmom)

        # get the interpolation mesh used by BoltzTraP2
        self.interpolation_mesh = 2 * np.max(
            np.abs(np.vstack(self._equivalences)), axis=0) + 1

        for spin in self._spins:
            energies = band_structure.bands[spin] * units.eV
            data = DFTData(kpoints, energies, self._lattice_matrix,
                           mommat=mommat)
            self._coefficients[spin] = fite.fitde3D(data, self._equivalences)

        log_time_taken(t0)

        if self._interpolate_projections:
            logger.info("Getting projection interpolation coefficients")
            # restart the timer so only the projection fits are reported
            t0 = time.perf_counter()

            for spin in self._spins:
                for label, projection in _get_projections(
                        band_structure.projections[spin]):
                    data = DFTData(kpoints, projection, self._lattice_matrix,
                                   mommat=mommat)
                    self._projection_coefficients[spin][label] = fite.fitde3D(
                        data, self._equivalences)
            log_time_taken(t0)
示例#3
0
def retrieve_bs_boltztrap2(vrun, bs, ibands, matrix=None):
    """Interpolate selected bands with BoltzTraP2 and plot them.

    A BoltzTraP2 fit is built from ``vrun`` (through ``PymatgenLoader``) and
    evaluated at the k-points of ``bs``. Three figures are produced:
    energies ('Energy-bt2'), velocity norms vs. energy ('Velocity-bt2') and
    the mean of the mass-tensor diagonal vs. energy ('mass-bt2').

    Args:
        vrun: Input accepted by ``PymatgenLoader`` (presumably a Vasprun
            object -- confirm against caller).
        bs: Band structure whose ``kpoints`` are the evaluation points.
        ibands: Indices of the bands to interpolate; legends use 1-based
            numbering.
        matrix: Forwarded unchanged to ``interpolate_bs``.

    All energies are shifted so that the maximum of the first band in
    ``ibands`` is zero.
    """
    pf = PlotlyFig(filename='Energy-bt2')
    kpts = np.array([k.frac_coords for k in bs.kpoints])
    bz_data = PymatgenLoader(vrun)
    equivalences = sphere.get_equivalences(atoms=bz_data.atoms,
                                           nkpt=len(bz_data.kpoints) * 5,
                                           magmom=None)
    lattvec = bz_data.get_lattvec()
    coeffs = fite.fitde3D(bz_data, equivalences)
    interp_params = (equivalences, lattvec, coeffs)

    plot_data = []
    v_data = []
    mass_data = []
    names = []
    eref = 0.0
    for ith, iband in enumerate(ibands):
        en, vel, masses = interpolate_bs(kpts,
                                         interp_params,
                                         iband=iband,
                                         method="boltztrap2",
                                         matrix=matrix)
        if ith == 0:
            # every band is referenced to the maximum of the first band
            eref = np.max(en)
        en -= eref
        plot_data.append((list(range(len(en))), en))
        v_data.append((en, np.linalg.norm(vel, axis=1)))
        # scalar mass: average of the diagonal of each mass tensor
        mass_data.append((en, [mass.trace() / 3.0 for mass in masses]))
        names.append('band {}'.format(iband + 1))

    pf.xy(plot_data, names=list(names))
    pf2 = PlotlyFig(filename='Velocity-bt2')
    pf2.xy(v_data, names=list(names))
    pf3 = PlotlyFig(filename='mass-bt2')
    pf3.xy(mass_data, names=list(names))
示例#4
0
    def __init__(self,
                 data,
                 lpfac=10,
                 energy_range=1.5,
                 curvature=True,
                 save_bztInterp=False,
                 load_bztInterp=False,
                 save_bands=False,
                 fname='bztInterp.json.gz'):
        """
        Args:
            data: A loader
            lpfac: the number of interpolation points in the real space. By
                default 10 gives 10 times more points in the real space than
                the number of kpoints given in reciprocal space.
            energy_range: usually the interpolation is not needed on the entire
                energy range but on a specific range around the middle of the
                band gap, E_mid = (CBM + VBM) / 2. This energy in eV fixes the
                range (E_mid - energy_range, E_mid + energy_range) of bands
                that will be interpolated and taken into account to calculate
                the transport properties.
            curvature: boolean value to enable/disable the calculation of second
                derivative related transport properties (Hall coefficient).
            save_bztInterp: Default False. If True coefficients and
                equivalences are saved in fname file.
            load_bztInterp: Default False. If True the coefficients and
                equivalences are loaded from fname file, not calculated. It can
                be faster than re-calculating them in some cases.
            save_bands: Default False. If True interpolated bands are also
                stored. It can be slower than interpolating them. Not
                recommended.
            fname: File path where to store/load from the coefficients and
                equivalences.
        Example:
            data = VasprunLoader().from_file('vasprun.xml')
            bztInterp = BztInterpolator(data)
        """
        bands_loaded = False
        self.data = data
        num_kpts = self.data.kpoints.shape[0]
        self.efermi = self.data.fermi
        # The band window is centred on mid-gap, not on the Fermi level.
        # NOTE(review): the whole (mid-gap +/- energy_range) is scaled by
        # units.eV, which presumes data.cbm / data.vbm are in eV -- confirm
        # against the loader.
        middle_gap_en = (self.data.cbm + self.data.vbm) / 2
        self.accepted = self.data.bandana(
            emin=(middle_gap_en - energy_range) * units.eV,
            emax=(middle_gap_en + energy_range) * units.eV)

        if load_bztInterp:
            # a truthy return from load() means the bands were loaded too
            bands_loaded = self.load(fname)
        else:
            self.equivalences = sphere.get_equivalences(
                self.data.atoms, self.data.magmom, num_kpts * lpfac)
            self.coeffs = fite.fitde3D(self.data, self.equivalences)

        if not bands_loaded:
            self.eband, self.vvband, self.cband = fite.getBTPbands(
                self.equivalences,
                self.coeffs,
                self.data.lattvec,
                curvature=curvature)

        if save_bztInterp:
            self.save(fname, save_bands)
示例#5
0
 def __init__(self, data, lpfac=10, energy_range=1.5, curvature=True):
     """Interpolate the bands held by ``data`` with BoltzTraP2.

     Args:
         data: A loader.
         lpfac: ratio between the number of real-space interpolation points
             and the number of k-points given in reciprocal space; the
             default of 10 uses ten times as many points.
         energy_range: half-width, in eV, of the window
             (E_fermi - energy_range, E_fermi + energy_range) that selects
             the bands to interpolate and to use for the transport
             properties.
         curvature: enable/disable the calculation of second-derivative
             related transport properties (Hall coefficient).

     Example:
         data = VasprunLoader().from_file('vasprun.xml')
         bztInterp = BztInterpolator(data)
     """
     self.data = data
     kpoint_count = self.data.kpoints.shape[0]
     self.efermi = self.data.fermi

     half_width = energy_range * units.eV
     self.nemin, self.nemax = self.data.bandana(
         emin=self.efermi - half_width,
         emax=self.efermi + half_width)

     self.equivalences = sphere.get_equivalences(
         self.data.atoms, self.data.magmom, kpoint_count * lpfac)
     self.coeffs = fite.fitde3D(self.data, self.equivalences)

     interpolated = fite.getBTPbands(self.equivalences, self.coeffs,
                                     self.data.lattvec, curvature=curvature)
     self.eband, self.vvband, self.cband = interpolated
示例#6
0
    def get_partial_doses(self, tdos, eband_ud, spins, enr, npts_mu, T,
                          progress):
        """
        Return a CompleteDos object interpolating the projections

        Args:
            tdos: total dos previously calculated
            eband_ud: interpolated band energies, paired element-wise with
                ``spins``
            spins: spins to process
            enr: energy range forwarded to ``BL.DOS`` as ``erange``
            npts_mu: number of energy points of the Dos
            T: parameter used to smooth the Dos; smoothing is skipped when
                ``T`` is falsy
            progress: If True a progress bar is shown.

        Raises:
            BoltztrapError: if no projections are loaded.

        ``self.data.ebands`` is temporarily overwritten with each projection
        so that ``fite.fitde3D`` can fit it; the original bands are restored
        afterwards, even if a fit raises.
        """
        if not self.data.proj:
            raise BoltztrapError("No projections loaded.")

        # materialise once: ``spins`` may be a one-shot iterator
        spin_bands = list(zip(spins, eband_ud))
        sites = self.data.structure.sites
        orbitals = list(Orbital)

        pbar = None
        if progress:
            # one tick per (spin, site, orbital) fit that is actually done
            n_iter = sum(
                len(sites) * min(len(orbitals), self.data.proj[spin].shape[-1])
                for spin, _ in spin_bands)
            pbar = tqdm(total=n_iter)

        bkp_data_ebands = np.copy(self.data.ebands)
        pdoss = {}
        try:
            for spin, eb in spin_bands:
                # orbitals map to projection columns by position
                spin_orbitals = orbitals[:self.data.proj[spin].shape[-1]]
                for isite, site in enumerate(sites):
                    site_doss = pdoss.setdefault(site, {})
                    for iorb, orb in enumerate(spin_orbitals):
                        self.data.ebands = self.data.proj[spin][:, :, isite,
                                                                iorb].T
                        coeffs = fite.fitde3D(self.data, self.equivalences)
                        proj, _, _ = fite.getBTPbands(
                            self.equivalences, coeffs, self.data.lattvec)

                        edos, pdos = BL.DOS(eb,
                                            npts=npts_mu,
                                            weights=np.abs(proj.real),
                                            erange=enr)

                        if T:
                            pdos = BL.smoothen_DOS(edos, pdos, T)

                        site_doss.setdefault(orb, {})[spin] = pdos

                        if pbar is not None:
                            pbar.update()
        finally:
            self.data.ebands = bkp_data_ebands
            if pbar is not None:
                pbar.close()

        return CompleteDos(self.data.structure, total_dos=tdos, pdoss=pdoss)
示例#7
0
 def __init__(self, data, lpfac=10, energy_range=1.5, curvature=True):
     """Interpolate the bands of a BoltzTraP2 loader.

     Args:
         data: loader supplying ``kpoints``, ``fermi``, ``atoms``,
             ``magmom``, ``lattvec`` and ``bandana``.
         lpfac: ratio between the number of real-space points used in the
             fit and the number of input k-points.
         energy_range: half-width, in eV, of the window around the Fermi
             level used to select bands through ``data.bandana``.
         curvature: forwarded to ``fite.getBTPbands``; toggles computation
             of the curvature stored in ``self.cband``.
     """
     self.data = data
     nk = self.data.kpoints.shape[0]
     self.efermi = self.data.fermi

     window = energy_range * units.eV
     lower, upper = self.efermi - window, self.efermi + window
     self.nemin, self.nemax = self.data.bandana(emin=lower, emax=upper)

     self.equivalences = sphere.get_equivalences(self.data.atoms,
                                                 self.data.magmom,
                                                 nk * lpfac)
     self.coeffs = fite.fitde3D(self.data, self.equivalences)

     eband, vvband, cband = fite.getBTPbands(self.equivalences,
                                             self.coeffs,
                                             self.data.lattvec,
                                             curvature=curvature)
     self.eband, self.vvband, self.cband = eband, vvband, cband
示例#8
0
    def __init__(self, data, lpfac=10, energy_range=1.5, curvature=True):
        """Select bands near the Fermi level and interpolate them.

        Args:
            data: loader supplying ``kpoints``, ``fermi``, ``atoms``,
                ``magmom``, ``lattvec`` and ``bandana``.
            lpfac: ratio between the number of real-space points used in the
                fit and the number of input k-points.
            energy_range: half-width, in eV, of the band-selection window
                centred on the Fermi level.
            curvature: forwarded to ``fite.getBTPbands``; toggles computation
                of the curvature stored in ``self.cband``.
        """
        self.data = data
        kpoint_total = self.data.kpoints.shape[0]
        fermi = self.data.fermi
        self.efermi = fermi

        delta = energy_range * units.eV
        selected = self.data.bandana(emin=fermi - delta, emax=fermi + delta)
        self.nemin, self.nemax = selected

        self.equivalences = sphere.get_equivalences(
            self.data.atoms, self.data.magmom, kpoint_total * lpfac)
        self.coeffs = fite.fitde3D(self.data, self.equivalences)

        (self.eband,
         self.vvband,
         self.cband) = fite.getBTPbands(self.equivalences,
                                        self.coeffs,
                                        self.data.lattvec,
                                        curvature=curvature)
示例#9
0
    def get_partial_doses(self, tdos, npts_mu, T):
        """
        Return a CompleteDos object interpolating the projections

        Only one spin channel is handled: ``self.data.spin`` when it is an
        int, otherwise 1.

        Args:
            tdos: total dos previously calculated
            npts_mu: number of energy points of the Dos
            T: parameter used to smooth the Dos; no smoothing when None

        Raises:
            BoltztrapError: if ``self.data.proj`` is not a numpy array.
        """
        spin_value = self.data.spin if isinstance(self.data.spin, int) else 1

        if not isinstance(self.data.proj, np.ndarray):
            raise BoltztrapError("No projections loaded.")

        spin = Spin(spin_value)
        # orbitals map to projection columns by position; only as many as
        # there are columns
        orbitals = list(Orbital)[:self.data.proj.shape[-1]]

        # self.data.ebands is temporarily replaced by each projection so that
        # fite.fitde3D can fit it; always restore the real bands.
        bkp_data_ebands = np.copy(self.data.ebands)
        pdoss = {}
        try:
            for isite, site in enumerate(self.data.structure.sites):
                site_doss = pdoss.setdefault(site, {})
                for iorb, orb in enumerate(orbitals):
                    self.data.ebands = self.data.proj[:, :, isite, iorb].T
                    coeffs = fite.fitde3D(self.data, self.equivalences)
                    proj, _, _ = fite.getBTPbands(self.equivalences,
                                                  coeffs,
                                                  self.data.lattvec)

                    edos, pdos = BL.DOS(self.eband,
                                        npts=npts_mu,
                                        weights=np.abs(proj.real))

                    if T is not None:
                        pdos = BL.smoothen_DOS(edos, pdos, T)

                    site_doss.setdefault(orb, {})[spin] = pdos
        finally:
            self.data.ebands = bkp_data_ebands

        return CompleteDos(self.data.structure, total_dos=tdos, pdoss=pdoss)
示例#10
0
    def get_partial_doses(self, tdos, npts_mu, T):
        """
        Return a CompleteDos object interpolating the projections

        Only one spin channel is handled: ``self.data.spin`` when it is an
        int, otherwise 1.

        Args:
            tdos: total dos previously calculated
            npts_mu: number of energy points of the Dos
            T: parameter used to smooth the Dos; no smoothing when None

        Raises:
            BoltztrapError: if ``self.data.proj`` is not a numpy array.
        """
        spin_value = self.data.spin if isinstance(self.data.spin, int) else 1

        if not isinstance(self.data.proj, np.ndarray):
            raise BoltztrapError("No projections loaded.")

        spin = Spin(spin_value)
        # orbitals map to projection columns by position; stop at the number
        # of columns present
        orbitals = list(Orbital)[:self.data.proj.shape[-1]]

        # self.data.ebands is swapped for each projection so fite.fitde3D can
        # fit it; the finally clause guarantees the real bands come back.
        bkp_data_ebands = np.copy(self.data.ebands)
        pdoss = {}
        try:
            for isite, site in enumerate(self.data.structure.sites):
                site_doss = pdoss.setdefault(site, {})
                for iorb, orb in enumerate(orbitals):
                    self.data.ebands = self.data.proj[:, :, isite, iorb].T
                    coeffs = fite.fitde3D(self.data, self.equivalences)
                    proj, _, _ = fite.getBTPbands(self.equivalences, coeffs,
                                                  self.data.lattvec)

                    edos, pdos = BL.DOS(self.eband, npts=npts_mu,
                                        weights=np.abs(proj.real))

                    if T is not None:
                        pdos = BL.smoothen_DOS(edos, pdos, T)

                    site_doss.setdefault(orb, {})[spin] = pdos
        finally:
            self.data.ebands = bkp_data_ebands

        return CompleteDos(self.data.structure, total_dos=tdos, pdoss=pdoss)
示例#11
0
def boltztrap(dirname, bt2file, title, T):
    """Run a BoltzTraP2 transport calculation at one temperature.

    If ``bt2file`` does not exist yet, the DFT data in ``dirname`` is loaded,
    restricted with ``bandana`` to bands within 0.2 (in the loader's energy
    units, presumably Hartree) of the Fermi level, interpolated with roughly
    five times the input k-point density, and saved to ``bt2file``. The saved
    interpolation is then reloaded and used to compute the Onsager
    coefficients.

    Args:
        dirname: directory holding the DFT calculation.
        bt2file: path of the BoltzTraP2 interpolation file (created if
            missing).
        title: prefix of the key under which results are stored.
        T: temperature in K.

    Side effects:
        Stores a result dict in ``savedata[title + '-' + str(T)]``.
        NOTE(review): ``savedata`` is a module-level name not defined in this
        function; confirm it exists before this is called.
    """
    print("\n\nWorking in %s for %s at %i K" % (dirname, title, T))
    # If a ready-made file with the interpolation results is available, use
    # it. Otherwise, create the file.
    if not os.path.exists(bt2file):
        # Load the input
        data = BTP.DFTData(dirname)
        # Select the interesting bands (the returned band limits are unused)
        data.bandana(emin=data.fermi - .2, emax=data.fermi + .2)
        # Set up a k point grid with roughly five times the density of the
        # input.
        # NOTE(review): BoltzTraP2 releases with magnetic support take
        # (atoms, magmom, nkpt); this call matches the older (atoms, nkpt)
        # signature -- confirm the targeted BoltzTraP2 version.
        equivalences = sphere.get_equivalences(data.atoms,
                                               len(data.kpoints) * 5)
        # Perform the interpolation
        coeffs = fite.fitde3D(data, equivalences)
        # Save the result
        serialization.save_calculation(
            bt2file, data, equivalences, coeffs,
            serialization.gen_bt2_metadata(data, data.mommat is not None))

    # Load the interpolation results
    print("Load the interpolation results")
    data, equivalences, coeffs, _metadata = serialization.load_calculation(
        bt2file)

    # Reconstruct the bands
    print("Reconstruct the bands")
    lattvec = data.get_lattvec()
    eband, vvband, _cband = fite.getBTPbands(equivalences, coeffs, lattvec)

    # Obtain the Fermi integrals for a range of chemical potentials at the
    # single temperature T.
    TEMP = np.array([T])
    epsilon, dos, vvdos, _cdos = BL.BTPDOS(eband, vvband, npts=4000)
    # keep chemical potentials at least 9 k_B T away from the ends of the
    # energy grid
    margin = 9. * units.BOLTZMANN * TEMP.max()
    mur_indices = np.logical_and(epsilon > epsilon.min() + margin,
                                 epsilon < epsilon.max() - margin)
    mur = epsilon[mur_indices]
    N, L0, L1, L2, _Lm11 = BL.fermiintegrals(epsilon,
                                             dos,
                                             vvdos,
                                             mur=mur,
                                             Tr=TEMP,
                                             dosweight=data.dosweight)

    # Compute the Onsager coefficients from those Fermi integrals
    print("Compute the Onsager coefficients")
    UCvol = data.get_volume()
    sigma, seebeck, kappa, Hall = BL.calc_Onsager_coefficients(
        L0, L1, L2, mur, TEMP, UCvol)

    fermi = BL.solve_for_mu(epsilon, dos, data.nelect, T, data.dosweight)

    savedata[title + '-%s' % T] = {
        "sigma": sigma,
        "seebeck": seebeck,
        "kappa": kappa,
        "Hall": Hall,
        "mu": (mur - fermi) / BL.eV,
        "temp": T,
        "n": N[0] + data.nelect
    }
示例#12
0
    def interpolate_bands(self,
                          interpolation_factor: float = 5,
                          energy_cutoff: Optional[float] = None,
                          nworkers: int = -1):
        """Gets a pymatgen band structure.

        The interpolation mesh is determined by the ``interpolation_factor``
        argument. The degree of parallelization is controlled by the
        ``nworkers`` option.

        Args:
            interpolation_factor: Multiplier applied to the number of input
                k-points to choose the density of the interpolation.
            energy_cutoff: The energy cut-off to determine which bands are
                included in the interpolation. If the energy of a band falls
                within the cut-off at any k-point it will be included. For
                metals the range is defined as the Fermi level ± energy_cutoff.
                For gapped materials, the energy range is from the VBM -
                energy_cutoff to the CBM + energy_cutoff.
            nworkers: The number of processors used to perform the
                interpolation. If set to ``-1``, the number of workers will
                be set to the number of CPU cores.

        Returns:
            A tuple of the interpolated band structure and, along each axis,
            the largest absolute integer coordinate among the vectors returned
            by ``sphere.get_equivalences`` (the interpolation mesh is
            ``2 * value + 1``).
        """
        coefficients = {}

        equivalences = sphere.get_equivalences(atoms=self._atoms,
                                               nkpt=self._kpoints.shape[0] *
                                               interpolation_factor,
                                               magmom=self._magmom)

        # get the interpolation mesh used by BoltzTraP2
        max_lattice_index = np.max(np.abs(np.vstack(equivalences)), axis=0)
        interpolation_mesh = 2 * max_lattice_index + 1

        for spin in self._spins:
            energies = self._band_structure.bands[spin] * units.eV
            data = DFTData(self._kpoints,
                           energies,
                           self._lattice_matrix,
                           mommat=self._mommat)
            coefficients[spin] = fite.fitde3D(data, equivalences)
        is_metal = self._band_structure.is_metal()

        nworkers = multiprocessing.cpu_count() if nworkers == -1 else nworkers

        # the VBM energy does not depend on the spin channel; look it up once
        vbm_energy = (None if is_metal else
                      self._band_structure.get_vbm()['energy'])

        # determine energy cutoffs
        if energy_cutoff and is_metal:
            min_e = self._band_structure.efermi - energy_cutoff
            max_e = self._band_structure.efermi + energy_cutoff

        elif energy_cutoff:
            min_e = vbm_energy - energy_cutoff
            max_e = self._band_structure.get_cbm()['energy'] + energy_cutoff

        else:
            min_e = min([
                self._band_structure.bands[spin].min() for spin in self._spins
            ])
            max_e = max([
                self._band_structure.bands[spin].max() for spin in self._spins
            ])

        energies = {}
        new_vb_idx = {}
        for spin in self._spins:
            spin_bands = self._band_structure.bands[spin]
            ibands = np.any((spin_bands > min_e) & (spin_bands < max_e),
                            axis=1)

            energies[spin] = fite.getBTPbands(equivalences,
                                              coefficients[spin][ibands],
                                              self._lattice_matrix,
                                              nworkers=nworkers)[0]

            # BoltzTraP2 works in Hartree atomic units; convert back to eV
            energies[spin] /= units.eV

            if not is_metal:
                # Highest band of this spin with any energy below the VBM.
                # Unlike get_vbm()["band_index"][spin], this also works for a
                # spin channel that does not host the VBM.
                below_vbm = np.any(spin_bands < vbm_energy, axis=1)
                vb_idx = np.max(np.where(below_vbm)[0])
                # need to know the index of the valence band after discounting
                # bands during the interpolation. As ibands is just a list of
                # True/False, we can count the number of Trues up to
                # and including the VBM to get the new number of valence bands
                new_vb_idx[spin] = np.count_nonzero(ibands[:vb_idx + 1]) - 1

        if is_metal:
            efermi = self._band_structure.efermi
        else:
            # If the material is semiconducting, set the Fermi level to the
            # middle of the gap. A spin channel may keep no valence (or no
            # conduction) bands inside the energy window, so skip empty
            # slices instead of taking np.max/np.min of an empty array.
            vb_slices = [energies[s][:new_vb_idx[s] + 1] for s in self._spins]
            cb_slices = [energies[s][new_vb_idx[s] + 1:] for s in self._spins]
            e_vbm = max(np.max(e) for e in vb_slices if e.size)
            e_cbm = min(np.min(e) for e in cb_slices if e.size)
            efermi = (e_vbm + e_cbm) / 2

        atoms = AseAtomsAdaptor().get_atoms(self._band_structure.structure)
        _, grid = spglib.get_ir_reciprocal_mesh(interpolation_mesh,
                                                atoms,
                                                symprec=0.1)
        full_kpoints = grid / interpolation_mesh

        # reorder the spglib k-points (presumably to match the order of the
        # BoltzTraP2 energies -- confirm): non-negative coordinates first
        sort_idx = np.lexsort(
            (full_kpoints[:, 2], full_kpoints[:, 2] < 0,
             full_kpoints[:, 1], full_kpoints[:, 1] < 0,
             full_kpoints[:, 0], full_kpoints[:, 0] < 0))
        reordered_kpoints = full_kpoints[sort_idx]

        # BandStructure expects the reciprocal lattice for fractional k-points
        rlat = self._band_structure.structure.lattice.reciprocal_lattice
        interp_band_structure = BandStructure(reordered_kpoints,
                                              energies,
                                              rlat,
                                              efermi,
                                              structure=self._structure)

        return interp_band_structure, max_lattice_index
示例#13
0
文件: silicon_crt.py 项目: suan12/EPA
dirname = './'
ftr = dirname + 'silicon_crt.trace'

# Energies below are expressed via RYDBERG = 0.5, i.e. in Hartree
# (1 Ry = 0.5 Ha), presumably matching BoltzTraP2's internal units.
RYDBERG = 0.5
doping_level = 0.01
tau = 1.0e-14

ecut = 1.0 * RYDBERG        # half-width of the band / DOS energy window
efcut = 0.3 * RYDBERG       # half-width of the chemical-potential window
deltae = 0.0005 * RYDBERG   # spacing of the DOS energy grid
tmax = 1200.0               # highest temperature
deltat = 10.0               # temperature step
lpfac = 5                   # real-space points per input k-point

# Read the DFT data
data = dft.DFTData(dirname)
# Keep the bands that come within ecut of the Fermi level
nemin, nemax = data.bandana(emin=data.fermi - ecut,
                            emax=data.fermi + ecut)
# Star set with roughly lpfac times as many points as the input k-points
equivalences = sphere.get_equivalences(data.atoms,
                                       len(data.kpoints) * lpfac)
# Fit the interpolation coefficients
coeffs = fite.fitde3D(data, equivalences)

# Rebuild bands, group velocities and curvatures on the dense grid
lattvec = data.get_lattvec()
eband, vvband, cband = fite.getBTPbands(equivalences, coeffs, lattvec)
# DOS and transport DOS restricted to the same window around the Fermi level
epsilon, dos, vvdos, cdos = BL.BTPDOS(eband, vvband,
                                      erange=[data.fermi - ecut,
                                              data.fermi + ecut],
                                      npts=round(2 * ecut / deltae),
                                      scattering_model='uniform_tau')

# Define the temperatures and chemical potentials we are interested in
#Tr = np.arange(deltat, tmax + deltat / 2, deltat)
#mur_indices = np.logical_and(epsilon > data.fermi - efcut, epsilon < data.fermi + efcut)
#mur = epsilon[mur_indices]
示例#14
0
    def interpolate_bands(
        self,
        interpolation_factor: float = 5,
        energy_cutoff: Optional[float] = None,
        nworkers: int = -1,
    ):
        """Gets a pymatgen band structure.

        The interpolation mesh is determined by the ``interpolation_factor``
        argument. The degree of parallelization is controlled by the
        ``nworkers`` option.

        Args:
            interpolation_factor: The factor by which the band structure will
                be interpolated (it multiplies the number of input k-points).
            energy_cutoff: The energy cut-off to determine which bands are
                included in the interpolation. If the energy of a band falls
                within the cut-off at any k-point it will be included. For
                metals the range is defined as the Fermi level ± energy_cutoff.
                For gapped materials, the energy range is from the VBM -
                energy_cutoff to the CBM + energy_cutoff.
            nworkers: The number of processors used to perform the
                interpolation. If set to ``-1``, the number of workers will
                be set to the number of CPU cores.

        Returns:
            A tuple of the interpolated band structure and the interpolation
            mesh (grid size along each reciprocal lattice direction).
        """
        coefficients = {}

        equivalences = sphere.get_equivalences(
            atoms=self._atoms,
            nkpt=self._kpoints.shape[0] * interpolation_factor,
            magmom=self._magmom,
        )

        # get the interpolation mesh used by BoltzTraP2
        interpolation_mesh = 2 * np.max(np.abs(np.vstack(equivalences)),
                                        axis=0) + 1

        for spin in self._spins:
            energies = self._band_structure.bands[spin] * eV
            data = DFTData(self._kpoints,
                           energies,
                           self._lattice_matrix,
                           mommat=self._mommat)
            coefficients[spin] = fite.fitde3D(data, equivalences)
        is_metal = self._band_structure.is_metal()

        nworkers = multiprocessing.cpu_count() if nworkers == -1 else nworkers

        # the VBM energy does not depend on the spin channel; look it up once
        vbm_energy = (None if is_metal else
                      self._band_structure.get_vbm()["energy"])

        # determine energy cutoffs
        if energy_cutoff and is_metal:
            min_e = self._band_structure.efermi - energy_cutoff
            max_e = self._band_structure.efermi + energy_cutoff

        elif energy_cutoff:
            min_e = vbm_energy - energy_cutoff
            max_e = self._band_structure.get_cbm()["energy"] + energy_cutoff

        else:
            min_e = min([
                self._band_structure.bands[spin].min() for spin in self._spins
            ])
            max_e = max([
                self._band_structure.bands[spin].max() for spin in self._spins
            ])

        energies = {}
        new_vb_idx = {}
        for spin in self._spins:
            spin_bands = self._band_structure.bands[spin]
            ibands = np.any(
                (spin_bands > min_e) & (spin_bands < max_e),
                axis=1,
            )

            energies[spin] = fite.getBTPbands(
                equivalences,
                coefficients[spin][ibands],
                self._lattice_matrix,
                nworkers=nworkers,
            )[0]

            # BoltzTraP2 works in Hartree atomic units; convert back to eV
            energies[spin] /= eV

            if not is_metal:
                # highest band of this spin with any energy below the VBM
                below_vbm = np.any(spin_bands < vbm_energy, axis=1)
                spin_vb_idx = np.max(np.where(below_vbm)[0])

                # need to know the index of the valence band after discounting
                # bands during the interpolation. As ibands is just a list of
                # True/False, we can count the number of Trues up to
                # and including the VBM to get the new number of valence bands
                new_vb_idx[spin] = (
                    np.count_nonzero(ibands[:spin_vb_idx + 1]) - 1)

        if is_metal:
            efermi = self._band_structure.efermi
        else:
            # if material is semiconducting, set Fermi level to middle of gap
            warnings.warn(
                "The Fermi energy may be different to that in the vasprun.xml file,"
                " due to the material being a semiconductor. The Fermi level has been "
                "set to midway between the top of the valence band and the bottom of "
                "the conduction band.",
                stacklevel=2,
            )
            # A spin channel may keep no valence (or no conduction) bands
            # inside the energy window; skip empty slices rather than taking
            # np.max/np.min of an empty array.
            vb_slices = [energies[s][:new_vb_idx[s] + 1] for s in self._spins]
            cb_slices = [energies[s][new_vb_idx[s] + 1:] for s in self._spins]
            e_vbm = max(np.max(e) for e in vb_slices if e.size)
            e_cbm = min(np.min(e) for e in cb_slices if e.size)
            efermi = (e_vbm + e_cbm) / 2

        atoms = AseAtomsAdaptor().get_atoms(self._band_structure.structure)
        _, grid = spglib.get_ir_reciprocal_mesh(interpolation_mesh,
                                                atoms,
                                                symprec=0.1)
        kpoints = grid / interpolation_mesh

        # sort energies so they have the same order as the k-points generated by spglib
        sort_idx = sort_boltztrap_to_spglib(kpoints)
        energies = {s: ener[:, sort_idx] for s, ener in energies.items()}

        rlat = self._band_structure.structure.lattice.reciprocal_lattice
        interp_band_structure = BandStructure(kpoints,
                                              energies,
                                              rlat,
                                              efermi,
                                              structure=self._structure)

        return interp_band_structure, interpolation_mesh