Exemplo n.º 1
0
    def setUp(self):
        """Load the CaO, N2 and C band-structure fixtures and build plotters.

        Also checks that the single-structure plotter holds one band
        structure and the two-structure plotter holds two (with 96 bands
        each), then silences warnings for the remainder of the test.
        """
        def load(name):
            # Read one JSON fixture from test_dir into a BandStructureSymmLine.
            path = os.path.join(test_dir, name)
            with open(path, "r", encoding="utf-8") as handle:
                return BandStructureSymmLine.from_dict(
                    json.loads(handle.read()))

        self.bs = load("CaO_2605_bandstructure.json")
        self.plotter = BSPlotter(self.bs)
        self.assertEqual(len(self.plotter._bs), 1,
                         "wrong number of band objects")

        self.sbs_sc = load("N2_12103_bandstructure.json")
        self.sbs_met = load("C_48_bandstructure.json")

        self.plotter_multi = BSPlotter([self.sbs_sc, self.sbs_met])
        self.assertEqual(len(self.plotter_multi._bs), 2,
                         "wrong number of band objects")
        self.assertEqual(self.plotter_multi._nb_bands, [96, 96],
                         "wrong number of bands")
        warnings.simplefilter("ignore")
Exemplo n.º 2
0
    def make_band_plot_info(self):
        """Collect everything needed to draw the band plot into a BandPlotInfo.

        Side effect: stores the composition of ``self.vasprun``'s final
        structure on ``self._composition``.

        Returns:
            BandPlotInfo holding one BandInfo for ``self.bs`` (plus a second
            one built from ``self.vasprun2`` when that is set), the per-branch
            k-path distances, the x-axis ticks and ``self._title``.
        """
        plotter = BSPlotter(self.bs)
        primary_data = plotter.bs_plot_data(zero_to_efermi=False)
        branch_distances = [list(branch)
                            for branch in primary_data["distances"]]
        self._composition = self.vasprun.final_structure.composition

        def to_band_info(band_structure, data):
            # NOTE(review): both entries use self.bs.efermi as fermi_level,
            # including the one built from vasprun2 — confirm this is intended.
            return BandInfo(band_energies=self._remove_spin_key(data),
                            band_edge=self._band_edge(band_structure, data),
                            fermi_level=self.bs.efermi)

        band_info = [to_band_info(self.bs, primary_data)]

        if self.vasprun2:
            second_bs = self.vasprun2.get_band_structure(
                self.kpoints_filename, line_mode=True)
            second_data = BSPlotter(second_bs).bs_plot_data(
                zero_to_efermi=False)
            band_info.append(to_band_info(second_bs, second_data))

        ticks = plotter.get_ticks_old()
        x_ticks = XTicks(_sanitize_labels(ticks["label"]), ticks["distance"])

        return BandPlotInfo(band_info_set=band_info,
                            distances_by_branch=branch_distances,
                            x_ticks=x_ticks,
                            title=self._title)
Exemplo n.º 3
0
def plot_simple_smoothed_band_structure(ylim=(-1.5, 3.5), filename=None):
    """Plot the smoothed band structure read from ``./vasprun.xml``.

    Args:
        ylim: (min, max) energy range of the y-axis. A tuple default is used
            instead of a shared mutable list; any 2-sequence is accepted.
        filename: if None, show the plot interactively; otherwise save it to
            this file. The saved plot uses the same ``ylim`` and smoothing as
            the interactive one (previously both were silently dropped when
            saving).
    """
    vasprun = Vasprun('./vasprun.xml')
    bs = vasprun.get_band_structure(line_mode=True)
    plotter = BSPlotter(bs)
    if filename is None:
        plotter.show(smooth=True, ylim=ylim)
    else:
        plotter.save_plot(filename, ylim=ylim, smooth=True)
Exemplo n.º 4
0
def find_dirac_nodes(tol=0.1):
    """
    Look for band crossings near (within `tol` eV) the Fermi level.

    Reads ``vasprun.xml`` and the labelled ``KPOINTS`` file from the current
    directory. Only spin-up bands are examined. A crossing is reported when,
    at some k-point of some branch of the high-symmetry path, one band lies
    strictly within ``tol`` of E_F and a different band lies strictly within
    ``tol`` of that band.

    Args:
        tol (float): energy tolerance in eV. Defaults to 0.1, the value that
            was previously hard-coded. It is also the band-gap pre-screen:
            when the gap is not below ``tol``, no search is done.

    Returns:
        boolean. Whether or not a band crossing occurs at or near
            the fermi level.
    """

    vasprun = Vasprun('vasprun.xml')
    if not vasprun.get_band_structure().get_band_gap()['energy'] < tol:
        return False

    efermi = vasprun.efermi
    bsp = BSPlotter(vasprun.get_band_structure('KPOINTS', line_mode=True,
                                               efermi=efermi))
    data = bsp.bs_plot_data(zero_to_efermi=True)
    spin_key = str(Spin.up)

    # Compare bands only within the same branch and at the same k-point.
    # (The old code stored the x-data as a 1-tuple because of a trailing
    # comma, so only the first k-point was checked. It also compared bands
    # belonging to different branches.)
    for branch_idx, distances in enumerate(data['distances']):
        branch_bands = data['energy'][branch_idx][spin_key]
        for k in range(len(distances)):
            energies = [band[k] for band in branch_bands]
            for i, e_i in enumerate(energies):
                if abs(e_i) >= tol:
                    continue
                # e_i is near E_F: is any other band within tol of it?
                if any(abs(e_i - e_j) < tol
                       for j, e_j in enumerate(energies) if j != i):
                    return True
    return False
Exemplo n.º 5
0
def plot_bandstructure():
    """Compute the band structure along the symmetry k-path and save a PDF.

    Command-line options (read from ``sys.argv``):
        -h          print usage and exit
        -g          use ``mode="risb"`` instead of the default ``"tb"``
        -f fname    output file; ``.pdf`` is appended if it does not already
                    end in ``.pdf`` (default ``bndstr.pdf``)
        -el emin    lower energy limit (default: lowest band energy)
        -eh emax    upper energy limit (default: highest band energy)

    ``get_bands_symkpath`` runs on every MPI rank; only rank 0 writes the
    plot.
    """
    argv = sys.argv
    if "-h" in argv:
        print("usage: complot_bands.py [-g] [-f fname] [-el emin] [-eh emax]")
        sys.exit()

    mode = "risb" if "-g" in argv else "tb"
    bs = get_bands_symkpath(mode=mode)
    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()
    if rank != 0:
        return

    def option_value(flag):
        # Return the argument following ``flag``, or None if flag is absent.
        if flag not in argv:
            return None
        pos = argv.index(flag) + 1
        if pos >= len(argv):
            sys.exit("missing value after {}".format(flag))
        return argv[pos]

    bsplot = BSPlotter(bs)

    fname = option_value("-f")
    if fname is None:
        fname = "bndstr.pdf"
    elif not fname.endswith(".pdf"):
        fname += ".pdf"

    # bs.bands is a mapping (e.g. one entry per spin). Reduce each value
    # separately: numpy.min(dict.values()) does not work on a Python 3 view.
    emin_arg = option_value("-el")
    if emin_arg is not None:
        emin = float(emin_arg)
    else:
        emin = min(numpy.min(b) for b in bs.bands.values())

    emax_arg = option_value("-eh")
    if emax_arg is not None:
        emax = float(emax_arg)
    else:
        emax = max(numpy.max(b) for b in bs.bands.values())

    bsplot.save_plot(fname, img_format="pdf", ylim=(emin, emax),
                     zero_to_efermi=False)
Exemplo n.º 6
0
def plot_band_structure(ylim=(-5, 5), draw_fermi=False, fmt='pdf'):
    """
    Plot a standard band structure with no projections.

    Reads ``vasprun.xml`` and the labelled ``KPOINTS`` file from the current
    directory and writes ``band_structure.<fmt>``.

    Args:
        ylim (tuple): minimum and maximum energies for the plot's y-axis.
        draw_fermi (bool): whether or not to draw a dashed line at E_F.
        fmt (str): output format passed as the file extension to
            ``savefig`` (e.g. 'pdf', 'png'). The string "None" skips plotting
            and returns the raw plot data instead.

    Returns:
        The ``BSPlotter.bs_plot_data()`` result when ``fmt == "None"``,
        otherwise None.
    """

    vasprun = Vasprun('vasprun.xml')
    efermi = vasprun.efermi
    bsp = BSPlotter(
        vasprun.get_band_structure('KPOINTS', line_mode=True, efermi=efermi))
    if fmt == "None":
        return bsp.bs_plot_data()

    def to_mathrm(tick_label):
        # Wrap the tick's *text* in \mathrm (formatting the Text object
        # itself yielded its repr, e.g. "Text(0, 0, 'X')"). Existing '$'
        # delimiters are dropped so labels such as '$\Gamma$' do not end up
        # with nested math mode. Empty labels stay empty.
        text = tick_label.get_text().replace('$', '')
        return r'$\mathrm{%s}$' % text if text else ''

    plot = bsp.get_plot(ylim=ylim)
    fig = plot.gcf()
    ax = fig.gca()
    # Draw once so that formatter-generated tick labels have their text
    # populated before being read back.
    fig.canvas.draw()
    ax.set_xticklabels([to_mathrm(t) for t in ax.get_xticklabels()])
    ax.set_yticklabels([to_mathrm(t) for t in ax.get_yticklabels()])
    if draw_fermi:
        ax.plot([ax.get_xlim()[0], ax.get_xlim()[1]], [0, 0], 'k--')
    plt.savefig('band_structure.{}'.format(fmt), transparent=True)

    plt.close()
Exemplo n.º 7
0
 def setUp(self):
     """Load the CaO band-structure fixture and wrap it in a BSPlotter."""
     fixture = os.path.join(test_dir, "CaO_2605_bandstructure.json")
     with open(fixture, "r", encoding='utf-8') as handle:
         payload = json.loads(handle.read())
     self.bs = BandStructureSymmLine.from_dict(payload)
     self.plotter = BSPlotter(self.bs)
Exemplo n.º 8
0
def plot_bs(vasprun_path: str):
    """Show the line-mode band structure stored in ``vasprun_path``.

    The k-point labels are taken from ``KPOINTS``. If the XML cannot be
    parsed, a message is printed and nothing is plotted.
    """
    try:
        run = BSVasprun(vasprun_path)
    except xml.etree.ElementTree.ParseError:
        print("\tskipped due to parse error")
        return
    band_structure = run.get_band_structure(kpoints_filename="KPOINTS",
                                            line_mode=True)
    BSPlotter(band_structure).show()
Exemplo n.º 9
0
    def plot_bands(self):
        """
        Plot the interpolated band structure along symmetry lines.

        Returns:
            Whatever ``BSPlotter(...).get_plot()`` returns for the band
            structure produced by ``self.bzt_interp``.

        Raises:
            BoltztrapError: if no BztInterpolator has been set.
        """
        interpolator = self.bzt_interp
        if interpolator is None:
            raise BoltztrapError("BztInterpolator not present")
        return BSPlotter(interpolator.get_band_structure()).get_plot()
Exemplo n.º 10
0
def banddos(pref='', storedir=None):
    """Plot the band structure from ``vasprun.xml``/``KPOINTS`` to BAND.png.

    Prints the band gap (rounded to 3 decimals) and the total k-path length.
    Energies are plotted relative to E_F. Spin-up bands are drawn as solid
    blue lines and spin-down bands (if present) as dashed red lines. CBM
    points are marked red and VBM points green.

    Args:
        pref (str): not used by this function.
        storedir: not used by this function.
    """
    ru = "vasprun.xml"
    kpfile = "KPOINTS"

    run = Vasprun(ru, parse_projected_eigen=True)
    bands = run.get_band_structure(kpfile, line_mode=True, efermi=run.efermi)
    bsp = BSPlotter(bands)
    zero_to_efermi = True
    bandgap = str(round(bands.get_band_gap()['energy'], 3))
    # Python 3 print (the Python 2 print statement was a SyntaxError).
    print("bg=", bandgap)
    data = bsp.bs_plot_data(zero_to_efermi)
    plt = get_publication_quality_plot(12, 8)
    band_linewidth = 3
    x_max = data['distances'][-1][-1]
    print(x_max)
    for distances, energies in zip(data['distances'], data['energy']):
        n_points = len(distances)
        for i in range(bsp._nb_bands):
            plt.plot(distances,
                     [energies['1'][i][j] for j in range(n_points)],
                     'b-', linewidth=band_linewidth)
            if bsp._bs.is_spin_polarized:
                plt.plot(distances,
                         [energies['-1'][i][j] for j in range(n_points)],
                         'r--', linewidth=band_linewidth)
    bsp._maketicks(plt)

    # The VBM loop used to be nested inside the CBM loop, which duplicated
    # the VBM markers (or dropped them entirely when there was no CBM).
    for cbm in data['cbm']:
        plt.scatter(cbm[0], cbm[1], color='r', marker='o', s=100)
    for vbm in data['vbm']:
        plt.scatter(vbm[0], vbm[1], color='g', marker='o', s=100)

    plt.xlabel(r'$\mathrm{Wave\ Vector}$', fontsize=30)
    ylabel = (r'$\mathrm{E\ -\ E_f\ (eV)}$' if zero_to_efermi
              else r'$\mathrm{Energy\ (eV)}$')
    plt.ylabel(ylabel, fontsize=30)
    plt.ylim(-4, 4)
    plt.xlim(0, x_max)
    plt.tight_layout()
    # matplotlib's savefig takes ``format``, not ``img_format``.
    plt.savefig('BAND.png', format="png")

    plt.close()
Exemplo n.º 11
0
def get_small_plot(bs):
    """Return ``BSPlotter(bs).bs_plot_data()`` keeping only bands near the gap.

    For every branch and spin, a band is kept when its minimum is below
    ``gap + 3`` and its maximum is above ``-3``, where ``gap`` is
    ``bs.get_band_gap()["energy"]``. The plot-data dict is filtered in place
    and returned.
    """
    plot_small = BSPlotter(bs).bs_plot_data()
    upper = bs.get_band_gap()["energy"] + 3

    def keep(band):
        return min(band) < upper and max(band) > -3

    for branch in plot_small['energy']:
        for spin in branch:
            branch[spin] = [band for band in branch[spin] if keep(band)]
    return plot_small
Exemplo n.º 12
0
    def get_bandsxy(self, bs, bandrange):
        """Return plotting coordinates for the bands listed in ``bandrange``.

        Args:
            bs: band structure handed to ``BSPlotter``. Its ``bands``,
                ``efermi`` and ``is_spin_polarized`` attributes are read.
            bandrange: iterable of band indices to extract.

        Returns:
            list: ``[x, yu, yd]``.
                - ``x`` is the k-path distances flattened over all branches.
                - ``yu`` holds one list of spin-up energies (minus
                  ``bs.efermi``) per requested band.
                - ``yd`` is the same for spin down, or ``None`` when ``bs``
                  is not spin polarized.
        """
        x = []
        for branch in BSPlotter(bs).bs_plot_data()["distances"]:
            x.extend(branch)

        def relative_energies(spin):
            return [[energy - bs.efermi for energy in bs.bands[spin][idx]]
                    for idx in bandrange]

        yu = relative_energies(Spin.up)
        yd = relative_energies(Spin.down) if bs.is_spin_polarized else None
        return [x, yu, yd]
Exemplo n.º 13
0
def get_mp_banddos():
    """Fetch band structure and DOS for a Materials Project ID and plot them.

    Prompts for an mp-ID, prints the database entry, then writes
    ``<mp-id>_band.png`` and ``<mp-id>_dos.png``. If either the band
    structure or the DOS is unavailable, a message is printed and the
    function returns early.
    """
    check_matplotlib()
    mpr = check_apikey()
    print("input the mp-ID")
    wait_sep()
    mp_id = wait()
    step_count = 1
    proc_str = "Reading Data From " + web + " ..."
    procs(proc_str, step_count, sp='-->>')
    data = mpr.get_entry_by_material_id(mp_id)
    sepline()
    print(data)
    sepline()
    step_count += 1
    proc_str = "Reading Band Data From " + web + " ..."
    procs(proc_str, step_count, sp='-->>')

    band = mpr.get_bandstructure_by_material_id(mp_id)
    if band is None:
        print("No data obtained online, stop now!")
        # ``os.exit`` does not exist (it raised AttributeError); just stop.
        return

    step_count += 1
    filename = mp_id + '_band.png'
    proc_str = "Writing Data to " + filename + " File..."
    bsp = BSPlotter(band)
    procs(proc_str, step_count, sp='-->>')
    bsp.save_plot(filename=filename, img_format="png")

    step_count += 1
    proc_str = "Reading DOS Data From " + web + " ..."
    procs(proc_str, step_count, sp='-->>')
    dos = mpr.get_dos_by_material_id(mp_id)
    if dos is None:
        print("No data obtained online, stop now!")
        # Previously execution continued and failed in add_dos(None).
        return

    step_count += 1
    filename = mp_id + '_dos.png'
    proc_str = "Writing Data to " + filename + " File..."
    dsp = DosPlotter()
    dsp.add_dos('Total', dos)
    procs(proc_str, step_count, sp='-->>')
    dsp.save_plot(filename=filename, img_format="png")
Exemplo n.º 14
0
    def get_kpt_labels(self, bs):
        """Return de-duplicated k-point tick labels and their positions.

        Consecutive repeated labels from ``BSPlotter(bs).get_ticks()`` are
        collapsed into one entry. The LaTeX "mid" separator is replaced by
        ``|`` and LaTeX Gamma by the Unicode character, because Dash cannot
        render LaTeX.

        Returns:
            tuple: ``(labels, positions)``, two lists of equal length. Both
            are empty when there are no ticks.
        """
        # get_ticks() is called once rather than twice.
        ticks = BSPlotter(bs).get_ticks()

        ## get unique K-points
        labels_uniq = []
        labelspos_uniq = []
        for label, pos in zip(ticks["label"], ticks["distance"]):
            # Skip a label identical to the one immediately before it.
            if labels_uniq and label == labels_uniq[-1]:
                continue
            labels_uniq.append(label)
            labelspos_uniq.append(pos)

        # Raw strings: "\m" and "\G" are invalid escape sequences in plain
        # literals (SyntaxWarning on recent Pythons); the values are unchanged.
        labels_uniq = [
            label.replace(r"$\mid$", "|").replace(r"$\Gamma$", "\u0393")
            for label in labels_uniq
        ]

        return labels_uniq, labelspos_uniq
Exemplo n.º 15
0
def plot_bs(bs, **kwargs):
    """
    Build a band-structure plot with pymatgen's BSPlotter.

    Parameters
    ----------
    bs :
        BandStructureSymmLine object, typically generated from Vasprun or
        BSVasprun.
    **kwargs : (dict)
        Forwarded unchanged to ``BSPlotter.get_plot``.

    Returns
    -------
    plt :
        Whatever ``BSPlotter.get_plot`` returns (a Matplotlib object).
    """
    return BSPlotter(bs).get_plot(**kwargs)
Exemplo n.º 16
0
        def bs_dos_traces(bandStructureSymmLine, densityOfStates):
            """Build plotly traces for a band structure and its DOS.

            Args:
                bandStructureSymmLine: dict form of a BandStructureSymmLine
                    (passed to ``BSML.from_dict``), ``None`` or ``"error"``.
                densityOfStates: dict form of a CompleteDos (passed to
                    ``CompleteDos.from_dict``), ``None`` or ``"error"``.

            Returns:
                ``"error"`` if either input is ``"error"``. Otherwise the list
                ``[bstraces, dostraces, bs_data]``:
                    - ``bstraces`` and ``dostraces`` are lists of
                      ``go.Scatter`` traces;
                    - ``bs_data`` is ``BSPlotter.bs_plot_data()`` with its
                      tick labels converted from LaTeX to plain Unicode.

            Raises:
                PreventUpdate: if either input is ``None``.
            """

            if bandStructureSymmLine == "error" or densityOfStates == "error":
                return "error"

            if bandStructureSymmLine is None or densityOfStates is None:
                raise PreventUpdate

            # - BS Data
            bstraces = []

            bs_reg_plot = BSPlotter(BSML.from_dict(bandStructureSymmLine))

            bs_data = bs_reg_plot.bs_plot_data()

            # -- Strip latex math wrapping. Replacements are applied in this
            # order, so "$" is removed before the commands are mapped.
            str_replace = {
                "$": "",
                "\\mid": "|",
                "\\Gamma": "Γ",
                "\\Sigma": "Σ",
                "_1": "₁",
                "_2": "₂",
                "_3": "₃",
                "_4": "₄",
            }

            tick_labels = bs_data["ticks"]["label"]
            for entry_num, label in enumerate(tick_labels):
                for key, replacement in str_replace.items():
                    label = label.replace(key, replacement)
                tick_labels[entry_num] = label

            bs_spin_polarized = bs_reg_plot._bs.is_spin_polarized
            for d, distances in enumerate(bs_data["distances"]):
                n_points = len(distances)
                for i in range(bs_reg_plot._nb_bands):
                    bstraces.append(
                        go.Scatter(
                            x=distances,
                            y=[
                                bs_data["energy"][d][str(Spin.up)][i][j]
                                for j in range(n_points)
                            ],
                            mode="lines",
                            line=dict(color=("#666666"), width=2),
                            hoverinfo="skip",
                            showlegend=False,
                        ))

                    if bs_spin_polarized:
                        bstraces.append(
                            go.Scatter(
                                x=distances,
                                y=[
                                    bs_data["energy"][d][str(Spin.down)][i][j]
                                    for j in range(n_points)
                                ],
                                mode="lines",
                                line=dict(color=("#666666"),
                                          width=2,
                                          dash="dash"),
                                hoverinfo="skip",
                                showlegend=False,
                            ))

            # -- DOS Data
            dostraces = []

            dos = CompleteDos.from_dict(densityOfStates)
            dos_energies = dos.energies - dos.efermi

            if Spin.down in dos.densities:
                # Add second spin data if available
                trace_tdos = go.Scatter(
                    x=dos.densities[Spin.down],
                    y=dos_energies,
                    mode="lines",
                    name="Total DOS (spin ↓)",
                    line=go.scatter.Line(color="#444444", dash="dash"),
                    fill="tozeroy",
                )

                dostraces.append(trace_tdos)

                tdos_label = "Total DOS (spin ↑)"
            else:
                tdos_label = "Total DOS"

            # Total DOS
            trace_tdos = go.Scatter(
                x=dos.densities[Spin.up],
                y=dos_energies,
                mode="lines",
                name=tdos_label,
                line=go.scatter.Line(color="#444444"),
                fill="tozeroy",
                legendgroup="spinup",
            )

            dostraces.append(trace_tdos)

            p_ele_dos = dos.get_element_dos()

            # Projected DOS
            colors = [
                "#1f77b4",  # muted blue
                "#ff7f0e",  # safety orange
                "#2ca02c",  # cooked asparagus green
                "#d62728",  # brick red
                "#9467bd",  # muted purple
                "#8c564b",  # chestnut brown
                "#e377c2",  # raspberry yogurt pink
                "#bcbd22",  # curry yellow-green
                "#17becf",  # blue-teal
            ]

            for count, ele in enumerate(p_ele_dos):
                # Cycle the palette; indexing it directly raised IndexError
                # for compounds with more elements than colors.
                color = colors[count % len(colors)]
                ele_densities = p_ele_dos[ele].densities

                # Check the DOS's own spin channels, as for the total DOS
                # above. Keying this on the band structure could raise a
                # KeyError when the two disagree.
                if Spin.down in ele_densities:
                    trace = go.Scatter(
                        x=ele_densities[Spin.down],
                        y=dos_energies,
                        mode="lines",
                        name=ele.symbol + " (spin ↓)",
                        line=dict(width=3, color=color, dash="dash"),
                    )

                    dostraces.append(trace)
                    spin_up_label = ele.symbol + " (spin ↑)"

                else:
                    spin_up_label = ele.symbol

                trace = go.Scatter(
                    x=ele_densities[Spin.up],
                    y=dos_energies,
                    mode="lines",
                    name=spin_up_label,
                    line=dict(width=3, color=color),
                )

                dostraces.append(trace)

            traces = [bstraces, dostraces, bs_data]

            return traces
Exemplo n.º 17
0
    def get_plot(
        self,
        n_idx,
        t_idx,
        zero_to_efermi=True,
        estep=0.01,
        line_density=100,
        height=6,
        width=6,
        emin=None,
        emax=None,
        ylabel="Energy (eV)",
        plt=None,
        aspect=None,
        distance_factor=10,
        kpath=None,
        style=None,
        no_base_style=False,
        fonts=None,
    ):
        """Plot the band structure as a scattering-rate-broadened intensity map.

        Each band along the line-mode k-path is rendered as a Lorentzian in
        energy. Its width parameter is ``10**rate * hbar / 2``, where the
        (log) rates come from the interpolated ``"rates"`` property, clipped
        and smoothed with a Savitzky-Golay filter.

        Args:
            n_idx, t_idx: indices passed to ``self._get_interpolater``
                (presumably doping / temperature indices — confirm).
            zero_to_efermi: shift energies so the Fermi level is at 0.
            estep: spacing of the energy grid.
            line_density: k-point density of the line-mode band structure.
            height, width: figure size.
            emin, emax: energy window. When ``None`` they default to
                ``self.fd_cutoffs`` converted with ``hartree_to_ev`` (and
                shifted by E_F if ``zero_to_efermi``). An explicit ``0`` is
                honoured.
            ylabel: y-axis label.
            plt: existing pyplot object, passed to ``pretty_plot``.
            aspect: aspect ratio, passed to ``_makeplot``.
            distance_factor: oversampling factor for the k-distance axis.
            kpath: k-path passed to ``get_line_mode_band_structure``.
            style, no_base_style, fonts: accepted but not used here.

        Returns:
            The pyplot object returned by ``pretty_plot``.
        """
        interpolater = self._get_interpolater(n_idx, t_idx)

        bs, prop = interpolater.get_line_mode_band_structure(
            line_density=line_density,
            return_other_properties=True,
            kpath=kpath)

        fd_emin, fd_emax = self.fd_cutoffs
        # Test against None: ``if not emin`` treated an explicit 0 eV limit
        # as "unset" and silently replaced it.
        if emin is None:
            emin = fd_emin * hartree_to_ev
            if zero_to_efermi:
                emin -= bs.efermi

        if emax is None:
            emax = fd_emax * hartree_to_ev
            if zero_to_efermi:
                emax -= bs.efermi

        logger.info("Plotting band structure")
        plt = pretty_plot(width=width, height=height, plt=plt)
        ax = plt.gca()

        if zero_to_efermi:
            bs.bands = {s: b - bs.efermi for s, b in bs.bands.items()}
            bs.efermi = 0

        bs_plotter = BSPlotter(bs)
        plot_data = bs_plotter.bs_plot_data(zero_to_efermi=zero_to_efermi)

        energies = np.linspace(emin, emax, int((emax - emin) / estep))
        distances = np.array([d for x in plot_data["distances"] for d in x])

        # rates are currently log(rate)
        rates = {}
        for spin, spin_data in prop.items():
            rates[spin] = spin_data["rates"]
            rates[spin][rates[spin] <= 0] = np.min(
                rates[spin][rates[spin] > 0])
            rates[spin][rates[spin] >= 15] = 15

        interp_distances = np.linspace(distances.min(), distances.max(),
                                       int(len(distances) * distance_factor))

        # Savitzky-Golay needs an odd window length.
        window = np.min([len(distances) - 2, 71])
        window += window % 2 + 1
        mesh_data = np.full((len(interp_distances), len(energies)), 1e-2)

        for spin in self.spins:
            for spin_energies, spin_rates in zip(bs.bands[spin], rates[spin]):
                interp_energies = interp1d(distances,
                                           spin_energies)(interp_distances)
                spin_rates = savgol_filter(spin_rates, window, 3)
                interp_rates = interp1d(distances,
                                        spin_rates)(interp_distances)
                linewidths = 10**interp_rates * hbar / 2

                for d_idx in range(len(interp_distances)):
                    energy = interp_energies[d_idx]
                    linewidth = linewidths[d_idx]

                    broadening = lorentzian(energies, energy, linewidth)
                    # Keep the strongest contribution from any band at this
                    # k-point (a duplicated, no-op second update was removed).
                    mesh_data[d_idx] = np.maximum(broadening, mesh_data[d_idx])

        ax.pcolormesh(
            interp_distances,
            energies,
            mesh_data.T,
            rasterized=True,
            norm=LogNorm(vmin=mesh_data.min(), vmax=mesh_data.max()),
        )

        _maketicks(ax, bs_plotter, ylabel=ylabel)
        _makeplot(
            ax,
            plot_data,
            bs,
            zero_to_efermi=zero_to_efermi,
            width=width,
            height=height,
            ymin=emin,
            ymax=emax,
            aspect=aspect,
        )
        return plt
Exemplo n.º 18
0
# Distributed under the terms of the MIT License.

from __future__ import division, print_function, unicode_literals, \
    absolute_import
"""
reads in KPOINTS(with labels for high symmetry kpoints) and 
vasprun.xml files and plots the band structure along the high 
symmetry kpoints
"""

# To use matplotlib on Hipergator, uncomment the following 2 lines:
# import matplotlib
# matplotlib.use('Agg')

from pymatgen.io.vasp.outputs import Vasprun
from pymatgen.electronic_structure.plotter import BSPlotter

if __name__ == "__main__":
    # Read the band structure from vasprun.xml. The labelled KPOINTS file
    # supplies the high-symmetry path; E_F is taken from the same run.
    run = Vasprun("vasprun.xml", parse_projected_eigen=True)
    bands = run.get_band_structure(
        "KPOINTS", line_mode=True, efermi=run.efermi
    )
    bsp = BSPlotter(bands)
    # Blue lines are up spin, red lines are down spin.
    bsp.save_plot('band_diagram.eps', ylim=(-5, 5))
    # Projected alternative (disabled):
    #   bsp = BSPlotterProjected(bands)
    #   plt = bsp.get_projected_plots_dots({'Fe': ['s', 'p', 'd'],
    #                                       'Sb': ['s', 'p', 'd']})
    #   get_elt_projected_plots_color()
Exemplo n.º 19
0
# !/usr/bin/env python
# -*- coding: utf-8 -*-

from pymatgen.io.vasp.outputs import Vasprun
from pymatgen.electronic_structure.plotter import BSPlotter
vasprun = Vasprun("vasprun.xml")
bss = vasprun.get_band_structure(
    kpoints_filename="KPOINTS",
    line_mode=True,
)
plotter = BSPlotter(bss)
# Other outputs used with this script at various times:
#   plotter.save_plot("bandStructure.svg", img_format="svg")
#   plotter.save_plot("bandStructure.png", img_format="png")
#   plotter.save_plot("lim_bandStructure.svg", img_format="svg",
#                     ylim=(-.2, 1.4))
plotter.save_plot(
    "MAPbI3-primitive.png",
    img_format="png",
    ylim=(-5, 5),
)
plotter.plot_brillouin()
Exemplo n.º 20
0
# -*- coding: utf-8 -*-
"""
Created on Fri Sep 14 16:23:21 2018

@author: hxjia
"""
import pymatgen
from pymatgen.io.vasp.outputs import Vasprun
from pymatgen.electronic_structure.plotter import BSPlotter

vaspout = Vasprun("vasprun.xml")
# line_mode=True forces the band structure to be treated as a run along
# symmetry lines.
bandstr = vaspout.get_band_structure(line_mode=True)

print(bandstr.get_band_gap())

plt = BSPlotter(bandstr).get_plot(ylim=[-4, 4])
plt.yticks(range(-4, 5))  # one tick per integer across the y-range
plt.savefig("band.pdf")
Exemplo n.º 21
0
        def bs_dos_data(
            mpid,
            path_convention,
            dos_select,
            label_select,
            bandstructure_symm_line,
            density_of_states,
        ):
            if (not mpid
                    or "mpid" not in mpid) and (bandstructure_symm_line is None
                                                or density_of_states is None):
                raise PreventUpdate
            elif mpid:
                raise PreventUpdate
            elif bandstructure_symm_line is None or density_of_states is None:

                # --
                # -- BS and DOS from API
                # --

                mpid = mpid["mpid"]
                bs_data = {"ticks": {}}

                # client = MongoClient(
                #     "mongodb03.nersc.gov", username="******", password="", authSource="fw_bs_prod",
                # )

                db = client.fw_bs_prod

                # - BS traces from DB using task_id
                bs_query = list(
                    db.electronic_structure.find(
                        {"task_id": int(mpid)},
                        [
                            "bandstructure.{}.total.traces".format(
                                path_convention)
                        ],
                    ))[0]

                is_sp = (len(bs_query["bandstructure"][path_convention]
                             ["total"]["traces"]) == 2)

                if is_sp:
                    bstraces = (bs_query["bandstructure"][path_convention]
                                ["total"]["traces"]["1"] +
                                bs_query["bandstructure"][path_convention]
                                ["total"]["traces"]["-1"])
                else:
                    bstraces = bs_query["bandstructure"][path_convention][
                        "total"]["traces"]["1"]

                bs_data["ticks"]["distance"] = bs_query["bandstructure"][
                    path_convention]["total"]["traces"]["ticks"]
                bs_data["ticks"]["label"] = bs_query["bandstructure"][
                    path_convention]["total"]["traces"]["labels"]

                # If LM convention, get equivalent labels
                if path_convention == "lm" and label_select != "lm":
                    bs_equiv_labels = bs_query["bandstructure"][
                        path_convention]["total"]["traces"]["equiv_labels"]

                    alt_choice = label_select

                    if label_select == "hin":
                        alt_choice = "h"

                    new_labels = []
                    for label in bs_data["ticks"]["label"]:
                        label_formatted = label.replace("$", "")

                        if "|" in label_formatted:
                            f_label = label_formatted.split("|")
                            new_labels.append(
                                "$" + bs_equiv_labels[alt_choice][f_label[0]] +
                                "|" + bs_equiv_labels[alt_choice][f_label[1]] +
                                "$")
                        else:
                            new_labels.append(
                                "$" +
                                bs_equiv_labels[alt_choice][label_formatted] +
                                "$")

                    bs_data["ticks"]["label"] = new_labels

                # Strip latex math wrapping
                str_replace = {
                    "$": "",
                    "\\mid": "|",
                    "\\Gamma": "Γ",
                    "\\Sigma": "Σ",
                    "GAMMA": "Γ",
                    "_1": "₁",
                    "_2": "₂",
                    "_3": "₃",
                    "_4": "₄",
                }

                for entry_num in range(len(bs_data["ticks"]["label"])):
                    for key in str_replace.keys():
                        if key in bs_data["ticks"]["label"][entry_num]:
                            bs_data["ticks"]["label"][entry_num] = bs_data[
                                "ticks"]["label"][entry_num].replace(
                                    key, str_replace[key])

                # - DOS traces from DB using task_id
                dostraces = []

                dos_tot_ele_traces = list(
                    db.electronic_structure.find(
                        {"task_id": int(mpid)},
                        ["dos.total.traces", "dos.elements"]))[0]

                dostraces = [
                    dos_tot_ele_traces["dos"]["total"]["traces"][spin] for spin
                    in dos_tot_ele_traces["dos"]["total"]["traces"].keys()
                ]

                elements = [
                    ele
                    for ele in dos_tot_ele_traces["dos"]["elements"].keys()
                ]

                if dos_select == "ap":
                    for ele_label in elements:
                        dostraces += [
                            dos_tot_ele_traces["dos"]["elements"][ele_label]
                            ["total"]["traces"][spin]
                            for spin in dos_tot_ele_traces["dos"]["elements"]
                            [ele_label]["total"]["traces"].keys()
                        ]

                elif dos_select == "op":
                    orb_tot_traces = list(
                        db.electronic_structure.find({"task_id": int(mpid)},
                                                     ["dos.orbitals"]))[0]
                    for orbital in ["s", "p", "d"]:
                        dostraces += [
                            orb_tot_traces["dos"]["orbitals"][orbital]
                            ["traces"][spin] for spin in orb_tot_traces["dos"]
                            ["orbitals"]["s"]["traces"].keys()
                        ]

                elif "orb" in dos_select:
                    ele_label = dos_select.replace("orb", "")

                    for orbital in ["s", "p", "d"]:
                        dostraces += [
                            dos_tot_ele_traces["dos"]["elements"][ele_label]
                            [orbital]["traces"][spin]
                            for spin in dos_tot_ele_traces["dos"]["elements"]
                            [ele_label][orbital]["traces"].keys()
                        ]

                traces = [bstraces, dostraces, bs_data]

                return (traces, elements)

            else:

                # --
                # -- BS and DOS passed manually
                # --

                # - BS Data

                if type(bandstructure_symm_line) != dict:
                    bandstructure_symm_line = bandstructure_symm_line.to_dict()

                if type(density_of_states) != dict:
                    density_of_states = density_of_states.to_dict()

                bs_reg_plot = BSPlotter(
                    BSML.from_dict(bandstructure_symm_line))
                bs_data = bs_reg_plot.bs_plot_data()

                # - Strip latex math wrapping
                str_replace = {
                    "$": "",
                    "\\mid": "|",
                    "\\Gamma": "Γ",
                    "\\Sigma": "Σ",
                    "GAMMA": "Γ",
                    "_1": "₁",
                    "_2": "₂",
                    "_3": "₃",
                    "_4": "₄",
                }

                for entry_num in range(len(bs_data["ticks"]["label"])):
                    for key in str_replace.keys():
                        if key in bs_data["ticks"]["label"][entry_num]:
                            bs_data["ticks"]["label"][entry_num] = bs_data[
                                "ticks"]["label"][entry_num].replace(
                                    key, str_replace[key])

                # Obtain bands to plot over:
                energy_window = (-6.0, 10.0)
                bands = []
                for band_num in range(bs_reg_plot._nb_bands):
                    if (bs_data["energy"][0][str(Spin.up)][band_num][0] <=
                            energy_window[1]) and (bs_data["energy"][0][str(
                                Spin.up)][band_num][0] >= energy_window[0]):
                        bands.append(band_num)

                bstraces = []

                # Generate traces for total BS data
                for d in range(len(bs_data["distances"])):
                    dist_dat = bs_data["distances"][d]
                    energy_ind = [
                        i for i in range(len(bs_data["distances"][d]))
                    ]

                    traces_for_segment = [{
                        "x":
                        dist_dat,
                        "y":
                        [bs_data["energy"][d]["1"][i][j] for j in energy_ind],
                        "mode":
                        "lines",
                        "line": {
                            "color": "#666666"
                        },
                        "hoverinfo":
                        "skip",
                        "showlegend":
                        False,
                    } for i in bands]

                    if bs_reg_plot._bs.is_spin_polarized:
                        traces_for_segment += [{
                            "x":
                            dist_dat,
                            "y": [
                                bs_data["energy"][d]["-1"][i][j]
                                for j in energy_ind
                            ],
                            "mode":
                            "lines",
                            "line": {
                                "color": "#666666"
                            },
                            "hoverinfo":
                            "skip",
                            "showlegend":
                            False,
                        } for i in bands]

                    bstraces += traces_for_segment

                # - DOS Data
                dostraces = []

                dos = CompleteDos.from_dict(density_of_states)

                dos_max = np.abs(
                    (dos.energies - dos.efermi - energy_window[1])).argmin()
                dos_min = np.abs(
                    (dos.energies - dos.efermi - energy_window[0])).argmin()

                if bs_reg_plot._bs.is_spin_polarized:
                    # Add second spin data if available
                    trace_tdos = go.Scatter(
                        x=dos.densities[Spin.down][dos_min:dos_max],
                        y=dos.energies[dos_min:dos_max] - dos.efermi,
                        mode="lines",
                        name="Total DOS (spin ↓)",
                        line=go.scatter.Line(color="#444444", dash="dash"),
                        fill="tozerox",
                    )

                    dostraces.append(trace_tdos)

                    tdos_label = "Total DOS (spin ↑)"
                else:
                    tdos_label = "Total DOS"

                # Total DOS
                trace_tdos = go.Scatter(
                    x=dos.densities[Spin.up][dos_min:dos_max],
                    y=dos.energies[dos_min:dos_max] - dos.efermi,
                    mode="lines",
                    name=tdos_label,
                    line=go.scatter.Line(color="#444444"),
                    fill="tozerox",
                    legendgroup="spinup",
                )

                dostraces.append(trace_tdos)

                ele_dos = dos.get_element_dos()
                elements = [str(entry) for entry in ele_dos.keys()]

                if dos_select == "ap":
                    proj_data = ele_dos
                elif dos_select == "op":
                    proj_data = dos.get_spd_dos()
                elif "orb" in dos_select:
                    proj_data = dos.get_element_spd_dos(
                        Element(dos_select.replace("orb", "")))
                else:
                    raise PreventUpdate

                # Projected DOS
                count = 0
                colors = [
                    "#1f77b4",  # muted blue
                    "#ff7f0e",  # safety orange
                    "#2ca02c",  # cooked asparagus green
                    "#9467bd",  # muted purple
                    "#e377c2",  # raspberry yogurt pink
                    "#d62728",  # brick red
                    "#8c564b",  # chestnut brown
                    "#bcbd22",  # curry yellow-green
                    "#17becf",  # blue-teal
                ]

                for label in proj_data.keys():

                    if bs_reg_plot._bs.is_spin_polarized:
                        trace = go.Scatter(
                            x=proj_data[label].densities[Spin.down]
                            [dos_min:dos_max],
                            y=dos.energies[dos_min:dos_max] - dos.efermi,
                            mode="lines",
                            name=str(label) + " (spin ↓)",
                            line=dict(width=3,
                                      color=colors[count],
                                      dash="dash"),
                        )

                        dostraces.append(trace)
                        spin_up_label = str(label) + " (spin ↑)"

                    else:
                        spin_up_label = str(label)

                    trace = go.Scatter(
                        x=proj_data[label].densities[Spin.up][dos_min:dos_max],
                        y=dos.energies[dos_min:dos_max] - dos.efermi,
                        mode="lines",
                        name=spin_up_label,
                        line=dict(width=3, color=colors[count]),
                    )

                    dostraces.append(trace)

                    count += 1

                traces = [bstraces, dostraces, bs_data]

                return (traces, elements)
Exemplo n.º 22
0
from pymatgen.electronic_structure.plotter import BSPlotter, BSPlotterProjected
from pymatgen.io.vasp import Vasprun, BandStructure

# Line-mode band structure of AgTe, read from the AgTe_bs/ directory.
v = Vasprun("AgTe_bs/vasprun.xml")
bands = v.get_band_structure(
    kpoints_filename="AgTe_bs/KPOINTS",
    line_mode=True,
)

# Print the band-gap summary returned by get_band_gap().
gap_summary = bands.get_band_gap()
print(gap_summary)

plt = BSPlotter(bands)
# plt.plot_brillouin()  # optional: draw the Brillouin-zone path

# Energies relative to the Fermi level, VBM/CBM marked, y-range (-3, 3).
band_plot = plt.get_plot(zero_to_efermi=True, vbm_cbm_marker=True, ylim=(-3, 3))
band_plot.show()
#plt.get_plot(zero_to_efermi=True,vbm_cbm_marker=True,ylim=(-2.2,0.5)).savefig(fname="bs.eps",img_format="eps")
Exemplo n.º 23
0
def bandstr(vrun="", kpfile="", filename=".", plot=False):
    """
    Plot an electronic band structure from a line-mode VASP calculation.

    Energies are plotted relative to the Fermi level. Spin-up bands are
    drawn as solid blue lines; for spin-polarized runs, spin-down bands are
    drawn as dashed red lines. CBM points are marked in red and VBM points
    in green.

    Args:
        vrun: path to vasprun.xml
        kpfile: path to line mode KPOINTS file
        filename: output image path, used only when ``plot`` is True
        plot: if True, save the figure to ``filename`` as PNG and close it

    Returns:
        matplotlib object (the object returned by
        ``get_publication_quality_plot``)
    """
    run = Vasprun(vrun, parse_projected_eigen=True)
    bands = run.get_band_structure(kpfile, line_mode=True, efermi=run.efermi)
    bsp = BSPlotter(bands)
    zero_to_efermi = True
    data = bsp.bs_plot_data(zero_to_efermi)
    plt = get_publication_quality_plot(12, 8)
    # NOTE(review): close()/clf() discard the figure just created above, so the
    # 12x8 size may not apply to the figure actually drawn on -- confirm intent.
    plt.close()
    plt.clf()
    band_linewidth = 3
    x_max = data["distances"][-1][-1]
    spin_polarized = bsp._bs.is_spin_polarized

    # One line per band per k-path segment.
    for d, seg_distances in enumerate(data["distances"]):
        n_points = len(seg_distances)
        for i in range(bsp._nb_bands):
            plt.plot(
                seg_distances,
                [data["energy"][d]["1"][i][j] for j in range(n_points)],
                "b-",
                linewidth=band_linewidth,
            )
            if spin_polarized:
                plt.plot(
                    seg_distances,
                    [data["energy"][d]["-1"][i][j] for j in range(n_points)],
                    "r--",
                    linewidth=band_linewidth,
                )
    bsp._maketicks(plt)

    # Band-edge markers. The two loops are independent so that every VBM
    # point is drawn exactly once, even when there are no CBM points.
    for cbm in data["cbm"]:
        plt.scatter(cbm[0], cbm[1], color="r", marker="o", s=100)
    for vbm in data["vbm"]:
        plt.scatter(vbm[0], vbm[1], color="g", marker="o", s=100)

    plt.xlabel(r"$\mathrm{Wave\ Vector}$", fontsize=30)
    ylabel = (
        r"$\mathrm{E\ -\ E_f\ (eV)}$" if zero_to_efermi else r"$\mathrm{Energy\ (eV)}$"
    )
    plt.ylabel(ylabel, fontsize=30)
    plt.ylim(-4, 4)
    plt.xlim(0, x_max)
    plt.tight_layout()
    if plot:
        # matplotlib's savefig takes ``format``; ``img_format`` is not a
        # savefig keyword.
        plt.savefig(filename, format="png")
        plt.close()

    return plt
Exemplo n.º 24
0
#!/nfshome/villa/anaconda3/bin/python

from pymatgen.io.vasp.outputs import Vasprun
from pymatgen.electronic_structure.plotter import BSPlotter, BSDOSPlotter

# Parse vasprun.xml once and use it for both the band structure and the DOS;
# the file was previously parsed twice.
vaspout_bs = Vasprun("vasprun.xml")
vaspout_dos = vaspout_bs

# Named ``band_struct`` rather than ``bandstr`` so this script does not rebind
# the ``bandstr`` plotting function defined earlier in this module.
band_struct = vaspout_bs.get_band_structure(line_mode=True, force_hybrid_mode=True)
bs_data = BSPlotter(band_struct).bs_plot_data()

dos = vaspout_dos.complete_dos
dos_dict = dos.as_dict()

# Combined band-structure + element-projected DOS figure.
plt = BSDOSPlotter(bs_projection=None,
                   dos_projection='elements',
                   bs_legend=None,
                   fig_size=(14, 11)).get_plot(band_struct, dos)
plt.savefig("dos_bs.png")
Exemplo n.º 25
0
# NOTE(review): tdos, idos, bs and cdos are not defined in this snippet;
# presumably they are built earlier in the original script -- confirm.
dosplotter = DosPlotter()
# add_dos / add_dos_dict are called for their effect on the plotter; their
# return values are not used.
dosplotter.add_dos('Total DOS', tdos)
dosplotter.add_dos('Integrated DOS', idos)
#Pdos = dosplotter.add_dos('Partial DOS',pdos)
#Spd_dos =  dosplotter.add_dos('spd DOS',spd_dos)
#Element_dos = dosplotter.add_dos('Element DOS',element_dos)
#Element_spd_dos = dosplotter.add_dos('Element_spd DOS',element_spd_dos)
dos_dict = {
    'Total DOS': tdos,
    'Integrated DOS': idos
}  #'Partial DOS':pdos,'spd DOS':spd_dos,'Element DOS':element_dos}#'Element_spd DOS':element_spd_dos
# NOTE(review): registers the same two labels that were added just above.
dosplotter.add_dos_dict(dos_dict)
get_dos_dict = dosplotter.get_dos_dict()
dos_plot = dosplotter.get_plot()
##dosplotter.save_plot("MAPbI3_dos",img_format="png")
##dos_plot.show()
bsplotter = BSPlotter(bs)
bs_plot_data = bsplotter.bs_plot_data()
bs_plot = bsplotter.get_plot()
#bsplotter.save_plot("MAPbI3_bs",img_format="png")
#bsplotter.show()
ticks = bsplotter.get_ticks()
# Function-call form works on Python 2 and 3; ``print ticks`` is a
# SyntaxError on Python 3.
print(ticks)
bsplotter.plot_brillouin()
bsdos = BSDOSPlotter(
    tick_fontsize=10,
    egrid_interval=20,
    dos_projection="orbitals",
    bs_legend=None)  #bs_projection="HPbCIN",dos_projection="HPbCIN")
bds = bsdos.get_plot(bs, cdos)
Exemplo n.º 26
0
def band_structure():
    """
    Read VASP band-structure output, report the gap, and write data files.

    Reads ``vasprun.xml``, ``KPOINTS`` and ``OUTCAR`` from the current
    directory, prints a classification of the material, and writes:

    * ``BAND.dat``: k-path distance plus band energies shifted so that the
      Fermi level is 0. There is one energy column, or spin-up and spin-down
      columns for spin-polarized runs.
    * ``HighSymmetricPoints.dat``: index, label and position of each
      high-symmetry tick.
    * ``BAND.png``: the band-structure plot. This is best effort; a failure
      is reported and not raised.

    Relies on names defined elsewhere in the module: ``min_gap`` (gap
    threshold), ``min_mag`` (magnetization threshold), and the helpers
    ``check_matplotlib``, ``check_file``, ``procs`` and ``write_col_data``.
    """
    check_matplotlib()
    step_count = 1

    filename = 'vasprun.xml'
    check_file(filename)
    proc_str = "Reading Data From " + filename + " File ..."
    procs(proc_str, step_count, sp='-->>')
    vsr = Vasprun(filename)

    step_count += 1
    filename = 'KPOINTS'
    check_file(filename)
    proc_str = "Reading Data From " + filename + " File ..."
    procs(proc_str, step_count, sp='-->>')
    bands = vsr.get_band_structure(filename, line_mode=True, efermi=vsr.efermi)

    step_count += 1
    filename = 'OUTCAR'
    check_file(filename)
    proc_str = "Reading Data From " + filename + " File ..."
    procs(proc_str, step_count, sp='-->>')
    outcar = Outcar('OUTCAR')
    mag = outcar.as_dict()['total_magnetization']

    if vsr.is_spin:
        procs("This Is a Spin-polarized Calculation.", 0, sp='-->>')
        tdos = vsr.tdos
        SpinUp_gap = tdos.get_gap(spin=Spin.up)
        cbm_vbm_up = tdos.get_cbm_vbm(spin=Spin.up)
        SpinDown_gap = tdos.get_gap(spin=Spin.down)
        # Spin-down band edges must come from the spin-down channel.
        cbm_vbm_down = tdos.get_cbm_vbm(spin=Spin.down)

        # A channel is "gapped" when its gap exceeds min_gap. Deriving both
        # flags from these booleans covers every case. A gap exactly equal
        # to min_gap previously left is_metal unbound.
        up_gapped = SpinUp_gap > min_gap
        down_gapped = SpinDown_gap > min_gap
        is_metal = not (up_gapped or down_gapped)
        is_semimetal = up_gapped != down_gapped

        # Exactly one classification is reported. Previously a metal was
        # also reported as a semiconductor.
        if is_metal:
            procs("This Material Is a Metal.", 0, sp='-->>')
        elif is_semimetal:
            procs("This Material Is a Semimetal.", 0, sp='-->>')
        else:
            procs("This Material Is a Semiconductor.", 0, sp='-->>')
            procs("Total magnetization is " + str(mag), 0, sp='-->>')
            # cbm_vbm_* are indexed as (cbm, vbm).
            proc_str = "SpinUp  : vbm=%f eV cbm=%f eV gap=%f eV" % (
                cbm_vbm_up[1], cbm_vbm_up[0], SpinUp_gap)
            procs(proc_str, 0, sp='-->>')
            if mag > min_mag:
                proc_str = "SpinDown: vbm=%f eV cbm=%f eV gap=%f eV" % (
                    cbm_vbm_down[1], cbm_vbm_down[0], SpinDown_gap)
                procs(proc_str, 0, sp='-->>')

        step_count += 1
        filename = "BAND.dat"
        proc_str = "Writting Band Structure Data to " + filename + " File ..."
        procs(proc_str, step_count, sp='-->>')
        band_data_up = bands.bands[Spin.up]
        band_data_down = bands.bands[Spin.down]
        # Flatten band by band and shift the Fermi level to 0.
        y_data_up = band_data_up.reshape(
            1, band_data_up.shape[0] * band_data_up.shape[1])[0] - vsr.efermi
        y_data_down = band_data_down.reshape(
            1, band_data_down.shape[0] * band_data_down.shape[1])[0] - vsr.efermi
        # The k-path distances are repeated once per band to line up with
        # the flattened energies.
        x_data = np.array(bands.distance * band_data_up.shape[0])
        data = np.vstack((x_data, y_data_up, y_data_down)).T
        head_line = "#%(key1)+12s%(key2)+13s%(key3)+15s" % {
            'key1': 'K-Distance', 'key2': 'UpEnergy(ev)', 'key3': 'DownEnergy(ev)'}
        write_col_data(filename, data, head_line, band_data_up.shape[1])

    else:
        # .get(): avoid a KeyError when the parameter block lacks the tag.
        if vsr.parameters.get('LNONCOLLINEAR', False):
            proc_str = "This Is a Non-Collinear Calculation."
        else:
            proc_str = "This Is a Non-Spin Calculation."
        procs(proc_str, 0, sp='-->>')
        if not bands.is_metal():
            cbm = bands.get_cbm()['energy']
            vbm = bands.get_vbm()['energy']
            gap = bands.get_band_gap()['energy']
            procs("This Material Is a Semiconductor.", 0, sp='-->>')
            proc_str = "vbm=%f eV cbm=%f eV gap=%f eV" % (vbm, cbm, gap)
            procs(proc_str, 0, sp='-->>')
        else:
            procs("This Material Is a Metal.", 0, sp='-->>')

        step_count += 1
        filename = "BAND.dat"
        proc_str = "Writting Band Structure Data to " + filename + " File ..."
        procs(proc_str, step_count, sp='-->>')
        band_data = bands.bands[Spin.up]
        # Flatten band by band and shift the Fermi level to 0.
        y_data = band_data.reshape(
            1, band_data.shape[0] * band_data.shape[1])[0] - vsr.efermi
        x_data = np.array(bands.distance * band_data.shape[0])
        data = np.vstack((x_data, y_data)).T
        head_line = "#%(key1)+12s%(key2)+13s" % {
            'key1': 'K-Distance', 'key2': 'Energy(ev)'}
        write_col_data(filename, data, head_line, band_data.shape[1])

    # Shared by both branches, so that ``bsp`` always exists before the plot
    # is saved. Previously it was only created for non-spin runs.
    step_count += 1
    bsp = BSPlotter(bands)
    filename = "HighSymmetricPoints.dat"
    proc_str = "Writting Label infomation to " + filename + " File ..."
    procs(proc_str, step_count, sp='-->>')
    ticks = bsp.get_ticks()
    head_line = "#%(key1)+12s%(key2)+12s%(key3)+12s" % {
        'key1': 'index', 'key2': 'label', 'key3': 'position'}
    line = head_line + '\n'
    for i, (label, position) in enumerate(zip(ticks['label'], ticks['distance'])):
        line += "%(key1)12d%(key2)+12s%(key3)12f\n" % {
            'key1': i, 'key2': label, 'key3': position}
    line += '\n'
    write_col_data(filename, line, '', str_data=True)

    step_count += 1
    filename = "BAND.png"
    proc_str = "Saving Plot to " + filename + " File ..."
    procs(proc_str, step_count, sp='-->>')
    try:
        bsp.save_plot(filename, img_format="png")
    except Exception as err:
        # Plotting is best effort: report why it failed but keep the data
        # files already written.
        print("Figure output fails !!!")
        print(err)
Exemplo n.º 27
0
from pymatgen.io.vasp.outputs import BSVasprun
from pymatgen.electronic_structure.plotter import BSPlotter
import os

# Directory holding the CrO2 line-mode band-structure calculation.
calc_dir = '/home/jinho93/half-metal/1.CrO2/3.band'
os.chdir(calc_dir)

# BSVasprun: pymatgen's band-structure-oriented vasprun.xml reader.
vrun = BSVasprun('vasprun.xml')
bs = vrun.get_band_structure(kpoints_filename='KPOINTS', line_mode=True)

# Show the band structure interactively.
bsp = BSPlotter(bs)
bsp.show()
Exemplo n.º 28
0
def projected_band_structure():
    """
    Write orbital-projected band-structure data summed over selected atoms.

    Reads ``vasprun.xml``, ``PROCAR`` and ``KPOINTS`` from the current
    directory and selects atoms with ``atom_selection(struct)``. The PROCAR
    projections are summed over the selected atoms, and the function writes:

    * ``PBAND_Up.dat`` and ``PBAND_Down.dat`` (spin-polarized runs) or
      ``PBAND.dat`` (otherwise). Columns are the k-path distance, the band
      energy shifted so that the Fermi level is 0, and one column per PROCAR
      orbital.
    * ``HighSymmetricPoints.dat``: index, label and position of each
      high-symmetry tick.

    Returns early, writing no data files, when no atoms are selected.
    Relies on the helpers ``check_file``, ``procs``, ``atom_selection`` and
    ``write_col_data`` defined elsewhere in the module.
    """
    step_count = 1
    filename = 'vasprun.xml'
    check_file(filename)
    proc_str = "Reading Data From " + filename + " File ..."
    procs(proc_str, step_count, sp='-->>')
    vsr = Vasprun(filename)

    filename = 'PROCAR'
    check_file(filename)
    step_count += 1
    proc_str = "Reading Data From " + filename + " File ..."
    procs(proc_str, step_count, sp='-->>')
    procar = Procar(filename)
    nbands = procar.nbands
    norbitals = len(procar.orbitals)
    nkpoints = procar.nkpoints

    step_count += 1
    filename = 'KPOINTS'
    check_file(filename)
    proc_str = "Reading Data From " + filename + " File ..."
    procs(proc_str, step_count, sp='-->>')
    bands = vsr.get_band_structure(filename, line_mode=True, efermi=vsr.efermi)
    struct = vsr.final_structure
    (atom_index, in_str) = atom_selection(struct)

    if len(atom_index) == 0:
        print("No atoms selected!")
        return

    selected_ions = list(atom_index)

    # Header shared by every PBAND file.
    atom_index_str = [str(x + 1) for x in atom_index]
    head_line1 = ("#String: " + in_str + '\n#Selected atom: '
                  + ' '.join(atom_index_str) + '\n')
    tmp1_str = "#%(key1)+12s%(key2)+12s"
    tmp2_dic = {'key1': 'K-Distance', 'key2': 'Energy(ev)'}
    for i in range(norbitals):
        tmp1_str += "%(key" + str(i + 3) + ")+12s"
        tmp2_dic["key" + str(i + 3)] = procar.orbitals[i]
    head_line = head_line1 + tmp1_str % tmp2_dic

    # The k-path distances are repeated once per band (band-major order).
    x_data = np.array(bands.distance * nbands)

    def _write_pband(out_name, spin):
        """Write distance, shifted energy and summed projections for one spin."""
        # procar.data[spin] is indexed (kpoint, band, ion, orbital): sum
        # over the selected ions to get (kpoint, band, orbital).
        contrib = procar.data[spin][:, :, selected_ions, :].sum(axis=2)
        # Reorder to (band, kpoint, orbital) before flattening, so that row r
        # refers to the same (band, kpoint) as x_data[r] and y_data[r].
        # Those are flattened band by band.
        proj_band = contrib.transpose(1, 0, 2).reshape(nbands * nkpoints, norbitals)
        # Energies come from the same spin channel as the projections.
        band_data = bands.bands[spin]
        y_data = band_data.reshape(1, nbands * nkpoints)[0] - vsr.efermi  # shift Fermi level to 0
        data = np.vstack((x_data, y_data, proj_band.T)).T
        write_col_data(out_name, data, head_line, nkpoints)

    if vsr.is_spin:
        procs("This Is a Spin-polarized Calculation.", 0, sp='-->>')
        for spin, out_name in ((Spin.up, "PBAND_Up.dat"), (Spin.down, "PBAND_Down.dat")):
            step_count += 1
            proc_str = "Writting Projected Band Structure Data to " + out_name + " File ..."
            procs(proc_str, step_count, sp='-->>')
            _write_pband(out_name, spin)
    else:
        # .get(): avoid a KeyError when the parameter block lacks the tag.
        if vsr.parameters.get('LNONCOLLINEAR', False):
            proc_str = "This Is a Non-Collinear Calculation."
        else:
            proc_str = "This Is a Non-Spin Calculation."
        procs(proc_str, 0, sp='-->>')

        step_count += 1
        out_name = "PBAND.dat"
        proc_str = "Writting Projected Band Structure Data to " + out_name + " File ..."
        procs(proc_str, step_count, sp='-->>')
        _write_pband(out_name, Spin.up)

    step_count += 1
    bsp = BSPlotter(bands)
    filename = "HighSymmetricPoints.dat"
    proc_str = "Writting Label infomation to " + filename + " File ..."
    procs(proc_str, step_count, sp='-->>')
    ticks = bsp.get_ticks()
    head_line = "#%(key1)+12s%(key2)+12s%(key3)+12s" % {
        'key1': 'index', 'key2': 'label', 'key3': 'position'}
    line = head_line + '\n'
    for i, (label, position) in enumerate(zip(ticks['label'], ticks['distance'])):
        line += "%(key1)12d%(key2)+12s%(key3)12f\n" % {
            'key1': i, 'key2': label, 'key3': position}
    write_col_data(filename, line, '', str_data=True)
Exemplo n.º 29
0
    def get_plot(
        self,
        n_idx,
        t_idx,
        zero_to_efermi=True,
        estep=0.01,
        line_density=100,
        height=3.2,
        width=3.2,
        emin=None,
        emax=None,
        amin=5e-5,
        amax=1e-1,
        ylabel="Energy (eV)",
        plt=None,
        aspect=None,
        kpath=None,
        cmap="viridis",
        colorbar=True,
        style=None,
        no_base_style=False,
        fonts=None,
    ):
        """
        Plot a Lorentzian-broadened band structure as a colour map.

        The band structure is interpolated along a line-mode k-path. Each
        band energy is broadened by a Lorentzian whose width is
        ``10 ** rate * hbar / 2``; the rates are stored as log10 values. The
        summed broadenings (in 1/meV) are drawn with ``pcolormesh`` on a
        log colour scale.

        Args:
            n_idx, t_idx: Indices passed to ``self._get_interpolater``
                (presumably doping and temperature indices -- confirm).
            zero_to_efermi: Shift energies so that the Fermi level is 0.
            estep: Spacing of the energy grid.
            line_density: k-point density passed to the interpolator.
            height, width: Figure size used when a new plot is created.
            emin, emax: Energy window. When None, the window defaults to
                ``self.fd_cutoffs``, shifted by the Fermi level when
                ``zero_to_efermi`` is set. An explicit 0 is honoured.
            amin, amax: Colour-scale limits (``LogNorm`` vmin/vmax).
            ylabel: Y-axis label.
            plt: An existing Axis/SubplotBase to draw on, or a pyplot-like
                object passed to ``pretty_plot``.
            aspect: Aspect ratio passed to ``_makeplot``.
            kpath: Optional k-path passed to the interpolator.
            cmap: Colormap name.
            colorbar: Add a colorbar axis to the right of the plot.
            style, no_base_style, fonts: Accepted for interface
                compatibility; not used by this method.

        Returns:
            The ``plt`` object: the axis that was passed in, or the
            pyplot-like object returned by ``pretty_plot``.
        """
        interpolater = self._get_interpolater(n_idx, t_idx)

        bs, prop = interpolater.get_line_mode_band_structure(
            line_density=line_density,
            return_other_properties=True,
            kpath=kpath,
            symprec=self.symprec,
        )
        bs, rates = force_branches(bs, {s: p["rates"] for s, p in prop.items()})

        fd_emin, fd_emax = self.fd_cutoffs
        # Test against None rather than truthiness, so that an explicit 0
        # (e.g. the Fermi level when zero_to_efermi=True) is not replaced.
        if emin is None:
            emin = fd_emin
            if zero_to_efermi:
                emin -= bs.efermi

        if emax is None:
            emax = fd_emax
            if zero_to_efermi:
                emax -= bs.efermi

        logger.info("Plotting band structure")
        if isinstance(plt, (Axis, SubplotBase)):
            ax = plt
        else:
            plt = pretty_plot(width=width, height=height, plt=plt)
            ax = plt.gca()

        if zero_to_efermi:
            bs.bands = {s: b - bs.efermi for s, b in bs.bands.items()}
            bs.efermi = 0

        bs_plotter = BSPlotter(bs)
        plot_data = bs_plotter.bs_plot_data(zero_to_efermi=zero_to_efermi)

        energies = np.linspace(emin, emax, int((emax - emin) / estep))
        distances = np.array([d for x in plot_data["distances"] for d in x])

        # rates are currently log(rate)
        mesh_data = np.zeros((len(distances), len(energies)))
        for spin in self.spins:
            for spin_energies, spin_rates in zip(bs.bands[spin], rates[spin]):
                for d_idx in range(len(distances)):
                    energy = spin_energies[d_idx]
                    linewidth = 10 ** spin_rates[d_idx] * hbar / 2
                    broadening = lorentzian(energies, energy, linewidth)
                    broadening /= 1000  # convert 1/eV to 1/meV
                    mesh_data[d_idx] += broadening

        im = ax.pcolormesh(
            distances,
            energies,
            mesh_data.T,
            rasterized=True,
            cmap=cmap,
            norm=LogNorm(vmin=amin, vmax=amax),
            shading="auto",
        )
        if colorbar:
            # Go through ax.figure instead of plt.gcf()/plt.colorbar so that
            # this also works when an axis object was passed as ``plt``.
            fig = ax.figure
            pos = ax.get_position()
            cax = fig.add_axes([pos.x1 + 0.035, pos.y0, 0.035, pos.height])
            cbar = fig.colorbar(im, cax=cax)
            cbar.ax.tick_params(axis="y", length=rcParams["ytick.major.size"] * 0.5)
            cbar.ax.set_ylabel(
                r"$A_\mathbf{k}$ (meV$^{-1}$)", rotation=270, va="bottom"
            )

        _maketicks(ax, bs_plotter, ylabel=ylabel)
        _makeplot(
            ax,
            plot_data,
            bs,
            zero_to_efermi=zero_to_efermi,
            width=width,
            height=height,
            ymin=emin,
            ymax=emax,
            aspect=aspect,
        )
        return plt
Exemplo n.º 30
0
# -*- coding: utf-8 -*-
"""
Created on Wed May 15 10:33:43 2019

@author: nwpuf
"""

from pymatgen.io.vasp import Vasprun, BSVasprun
from pymatgen.electronic_structure.plotter import BSPlotter, DosPlotter
import matplotlib.pyplot as plt

# Read the line-mode band-structure run from the current directory.
bsv = BSVasprun("vasprun.xml")
bs = bsv.get_band_structure(kpoints_filename="KPOINTS", line_mode=True)

# Band-gap summary returned by get_band_gap().
gap_summary = bs.get_band_gap()
print(gap_summary)

# Plot with VBM/CBM markers and display it.
bsplot = BSPlotter(bs)
band_figure = bsplot.get_plot(vbm_cbm_marker=True)
band_figure.show()

#dosrun = Vasprun("DOS/vasprun.xml", parse_dos=True)
#dos = dosrun.complete_dos
#dosplot = DosPlotter(sigma=0.1)
#dosplot.add_dos("Total DOS", dos)
#dosplot.add_dos_dict(dos.get_element_dos())
#ax = plt.gca()
#print(type(dosplot.get_plot()))
#dosplot.get_plot().show()