Ejemplo n.º 1
0
 def setUp(self):
     """Load the CaO band-structure fixture and build a BSPlotter for it."""
     fixture = os.path.join(test_dir, "CaO_2605_bandstructure.json")
     with open(fixture, "r", encoding='utf-8') as fh:
         self.bs = BandStructureSymmLine.from_dict(json.load(fh))
         self.plotter = BSPlotter(self.bs)
Ejemplo n.º 2
0
def find_dirac_nodes(tol=0.1):
    """
    Look for band crossings near (within `tol` eV) the Fermi level.

    Reads ``vasprun.xml`` and a line-mode ``KPOINTS`` file from the current
    directory. Only the spin-up channel is inspected.

    Args:
        tol (float): energy window. The search only runs when the band gap
            is below ``tol``. A crossing is reported when, at some k-point,
            one band lies within ``tol`` of the Fermi level and another band
            lies within ``tol`` of it. Defaults to 0.1 (the previously
            hard-coded value).

    Returns:
        boolean. Whether or not a band crossing occurs at or near
            the fermi level.
    """

    vasprun = Vasprun('vasprun.xml')
    if not vasprun.get_band_structure().get_band_gap()['energy'] < tol:
        return False

    efermi = vasprun.efermi
    bsp = BSPlotter(vasprun.get_band_structure('KPOINTS', line_mode=True,
                                               efermi=efermi))
    data = bsp.bs_plot_data(zero_to_efermi=True)
    up = str(Spin.up)

    for d, distances in enumerate(data['distances']):
        # Bands are compared only within the same branch: equal indices on
        # different branches are different k-points, and branches may have
        # different lengths.
        n_points = len(distances)
        branch = [[data['energy'][d][up][i][j] for j in range(n_points)]
                  for i in range(bsp._nb_bands)]
        # Every ordered pair (i, j) is checked, because the test is not
        # symmetric: band i itself must lie inside the window.
        for i, band_i in enumerate(branch):
            for j, band_j in enumerate(branch):
                if i == j:
                    continue
                for e_i, e_j in zip(band_i, band_j):
                    if -tol < e_i < tol and -tol < e_i - e_j < tol:
                        return True
    return False
Ejemplo n.º 3
0
def find_dirac_nodes(tol=0.1):
    """
    Look for band crossings near (within `tol` eV) the Fermi level.

    Reads ``vasprun.xml`` and a line-mode ``KPOINTS`` file from the current
    directory. Only the spin-up channel is inspected.

    Args:
        tol (float): energy window. The search only runs when the band gap
            is below ``tol``. A crossing is reported when, at some k-point,
            one band lies within ``tol`` of the Fermi level and another band
            lies within ``tol`` of it. Defaults to 0.1 (the previously
            hard-coded value).

    Returns:
        boolean. Whether or not a band crossing occurs at or near
            the fermi level.
    """

    vasprun = Vasprun('vasprun.xml')
    if not vasprun.get_band_structure().get_band_gap()['energy'] < tol:
        return False

    efermi = vasprun.efermi
    bsp = BSPlotter(vasprun.get_band_structure('KPOINTS', line_mode=True,
                                               efermi=efermi))
    data = bsp.bs_plot_data(zero_to_efermi=True)
    up = str(Spin.up)

    for d, distances in enumerate(data['distances']):
        # Bands are compared only within the same branch: equal indices on
        # different branches are different k-points, and branches may have
        # different lengths.
        n_points = len(distances)
        branch = [[data['energy'][d][up][i][j] for j in range(n_points)]
                  for i in range(bsp._nb_bands)]
        # Every ordered pair (i, j) is checked, because the test is not
        # symmetric: band i itself must lie inside the window.
        for i, band_i in enumerate(branch):
            for j, band_j in enumerate(branch):
                if i == j:
                    continue
                for e_i, e_j in zip(band_i, band_j):
                    if -tol < e_i < tol and -tol < e_i - e_j < tol:
                        return True
    return False
Ejemplo n.º 4
0
def plot_bandstructure():
    """
    Command-line driver: compute a band structure along the symmetry k-path
    and save it as a PDF.

    Options read from ``sys.argv``:
        -h          print usage and exit.
        -g          call ``get_bands_symkpath`` with mode "risb"
                    (default mode is "tb").
        -f FNAME    output file. ".pdf" is appended if FNAME does not end
                    with it. Default: "bndstr.pdf".
        -el EMIN    lower y-limit. Default: lowest value in ``bs.bands``.
        -eh EMAX    upper y-limit. Default: highest value in ``bs.bands``.

    Every MPI rank calls ``get_bands_symkpath``; only rank 0 plots and
    writes the file. Exits with a message if -f/-el/-eh has no value.
    """
    argv = sys.argv
    if "-h" in argv:
        print("usage: complot_bands.py [-g] [-f fname] [-el emin] [-eh emax]")
        sys.exit()

    mode = "risb" if "-g" in argv else "tb"
    bs = get_bands_symkpath(mode=mode)
    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()
    if rank != 0:
        return

    def option_value(flag):
        # Argument following `flag`, or None when the flag is absent.
        # A trailing flag with no value previously raised IndexError.
        if flag not in argv:
            return None
        pos = argv.index(flag) + 1
        if pos >= len(argv):
            sys.exit("missing value for option {}".format(flag))
        return argv[pos]

    bsplot = BSPlotter(bs)

    fname = option_value("-f")
    if fname is None:
        fname = "bndstr.pdf"
    elif not fname.endswith(".pdf"):
        fname += ".pdf"

    # bs.bands is a mapping (presumably spin -> band energies). In Python 3,
    # dict.values() is a view that numpy.min/max cannot reduce, so reduce
    # each entry first and then take the extreme over all of them.
    emin = option_value("-el")
    if emin is not None:
        emin = float(emin)
    else:
        emin = numpy.min([numpy.min(b) for b in bs.bands.values()])

    emax = option_value("-eh")
    if emax is not None:
        emax = float(emax)
    else:
        emax = numpy.max([numpy.max(b) for b in bs.bands.values()])

    bsplot.save_plot(fname, img_format="pdf", ylim=(emin, emax),
                     zero_to_efermi=False)
Ejemplo n.º 5
0
    def make_band_plot_info(self):
        """Assemble a BandPlotInfo from ``self.bs`` and optional ``self.vasprun2``.

        Side effect: stores the composition of ``self.vasprun``'s final
        structure in ``self._composition``.

        Returns:
            BandPlotInfo holding one BandInfo per band structure, the
            k-distances of each branch, the x-axis ticks and ``self._title``.
        """
        plotter = BSPlotter(self.bs)
        data = plotter.bs_plot_data(zero_to_efermi=False)
        branch_distances = list(map(list, data["distances"]))
        self._composition = self.vasprun.final_structure.composition

        infos = [BandInfo(band_energies=self._remove_spin_key(data),
                          band_edge=self._band_edge(self.bs, data),
                          fermi_level=self.bs.efermi)]

        if self.vasprun2:
            second_bs = self.vasprun2.get_band_structure(
                self.kpoints_filename, line_mode=True)
            second_data = BSPlotter(second_bs).bs_plot_data(
                zero_to_efermi=False)
            # NOTE(review): the primary band structure's Fermi level is
            # reused here rather than second_bs.efermi; confirm intended.
            infos.append(BandInfo(
                band_energies=self._remove_spin_key(second_data),
                band_edge=self._band_edge(second_bs, second_data),
                fermi_level=self.bs.efermi))

        ticks = plotter.get_ticks_old()
        tick_axis = XTicks(_sanitize_labels(ticks["label"]), ticks["distance"])

        return BandPlotInfo(band_info_set=infos,
                            distances_by_branch=branch_distances,
                            x_ticks=tick_axis,
                            title=self._title)
Ejemplo n.º 6
0
def plot_band_structure(ylim=(-5, 5), draw_fermi=False, fmt='pdf'):
    """
    Plot a standard band structure with no projections.

    Reads ``vasprun.xml`` and a line-mode ``KPOINTS`` file from the current
    directory and writes ``band_structure.<fmt>``.

    Args:
        ylim (tuple): minimum and maximum potentials for the plot's y-axis.
        draw_fermi (bool): whether or not to draw a dashed line at E_F.
        fmt (str): matplotlib format style. Check the matplotlib docs
            for options. The string "None" skips plotting and returns the
            plot data instead.

    Returns:
        The dict from ``BSPlotter.bs_plot_data()`` when ``fmt == "None"``,
        otherwise None.
    """

    def as_mathrm(label):
        # Tick labels are matplotlib Text objects: format their text, not
        # their repr. Empty labels and labels already containing math
        # ($...$) are left unchanged, because nested $ breaks mathtext.
        text = label.get_text()
        if not text or '$' in text:
            return text
        return r'$\mathrm{%s}$' % text

    vasprun = Vasprun('vasprun.xml')
    efermi = vasprun.efermi
    bsp = BSPlotter(
        vasprun.get_band_structure('KPOINTS', line_mode=True, efermi=efermi))
    if fmt == "None":
        return bsp.bs_plot_data()

    plot = bsp.get_plot(ylim=ylim)
    fig = plot.gcf()
    ax = fig.gca()
    ax.set_xticklabels([as_mathrm(t) for t in ax.get_xticklabels()])
    ax.set_yticklabels([as_mathrm(t) for t in ax.get_yticklabels()])
    if draw_fermi:
        ax.plot([ax.get_xlim()[0], ax.get_xlim()[1]], [0, 0], 'k--')
    plt.savefig('band_structure.{}'.format(fmt), transparent=True)
    plt.close()
Ejemplo n.º 7
0
class BSPlotterTest(unittest.TestCase):
    """Regression tests for BSPlotter using the CaO line-mode fixture."""

    def setUp(self):
        fixture = os.path.join(test_dir, "CaO_2605_bandstructure.json")
        with open(fixture, "r", encoding='utf-8') as stream:
            raw = json.load(stream)
            self.bs = BandStructureSymmLine.from_dict(raw)
            self.plotter = BSPlotter(self.bs)

    def test_bs_plot_data(self):
        plot_data = self.plotter.bs_plot_data()
        branches = plot_data['distances']
        self.assertEqual(len(branches[0]), 16,
                         "wrong number of distances in the first branch")
        self.assertEqual(len(branches), 10, "wrong number of branches")
        self.assertEqual(sum(len(branch) for branch in branches), 160,
                         "wrong number of distances")
        self.assertEqual(plot_data['ticks']['label'][5], "K",
                         "wrong tick label")
        self.assertEqual(len(plot_data['ticks']['label']), 19,
                         "wrong number of tick labels")

    def test_qvertex_target(self):
        # The 8 corners of the unit cube followed by its centre.
        corners = [
            [0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 1.0, 0.0],
            [0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 1.0],
            [1.0, 1.0, 1.0], [0.0, 1.0, 1.0],
        ]
        results = _qvertex_target(corners + [[0.5, 0.5, 0.5]], 8)
        self.assertEqual(len(results), 6)
        self.assertEqual(results[3][1], 0.5)
Ejemplo n.º 8
0
class BSPlotterTest(unittest.TestCase):
    """Checks BSPlotter plot data and the _qvertex_target helper."""

    def setUp(self):
        json_path = os.path.join(test_dir, "CaO_2605_bandstructure.json")
        with open(json_path, "r", encoding='utf-8') as fobj:
            content = json.loads(fobj.read())
            self.bs = BandStructureSymmLine.from_dict(content)
            self.plotter = BSPlotter(self.bs)

    def test_bs_plot_data(self):
        result = self.plotter.bs_plot_data()
        dists = result['distances']
        self.assertEqual(len(dists[0]), 16,
                         "wrong number of distances in the first branch")
        self.assertEqual(len(dists), 10,
                         "wrong number of branches")
        n_total = sum(map(len, dists))
        self.assertEqual(n_total, 160, "wrong number of distances")
        tick_labels = result['ticks']['label']
        self.assertEqual(tick_labels[5], "K", "wrong tick label")
        self.assertEqual(len(tick_labels), 19, "wrong number of tick labels")

    def test_qvertex_target(self):
        points = [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 1.0, 0.0],
                  [0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 1.0],
                  [1.0, 1.0, 1.0], [0.0, 1.0, 1.0], [0.5, 0.5, 0.5]]
        results = _qvertex_target(points, 8)
        self.assertEqual(len(results), 6)
        self.assertEqual(results[3][1], 0.5)
Ejemplo n.º 9
0
    def setUp(self):
        """Load band-structure fixtures and build single and multi plotters."""

        def read_fixture(filename):
            # Parse one JSON fixture from test_dir into a BandStructureSymmLine.
            with open(os.path.join(test_dir, filename), "r",
                      encoding="utf-8") as fh:
                return BandStructureSymmLine.from_dict(json.loads(fh.read()))

        self.bs = read_fixture("CaO_2605_bandstructure.json")
        self.plotter = BSPlotter(self.bs)
        self.assertEqual(len(self.plotter._bs), 1,
                         "wrong number of band objects")

        self.sbs_sc = read_fixture("N2_12103_bandstructure.json")
        self.sbs_met = read_fixture("C_48_bandstructure.json")

        self.plotter_multi = BSPlotter([self.sbs_sc, self.sbs_met])
        self.assertEqual(len(self.plotter_multi._bs), 2,
                         "wrong number of band objects")
        self.assertEqual(self.plotter_multi._nb_bands, [96, 96],
                         "wrong number of bands")
        warnings.simplefilter("ignore")
Ejemplo n.º 10
0
def plot_simple_smoothed_band_structure(ylim=(-1.5, 3.5), filename=None):
    """
    Plot the smoothed line-mode band structure read from ./vasprun.xml.

    Args:
        ylim: (min, max) y-axis range. The default is a tuple, so it cannot
            be mutated and shared between calls.
        filename: if None, show the plot interactively. Otherwise save it to
            this file, using the same smoothing and y-range as the
            interactive plot.
    """
    vasprun = Vasprun('./vasprun.xml')
    bs = vasprun.get_band_structure(line_mode=True)
    plotter = BSPlotter(bs)
    if filename is None:
        plotter.show(smooth=True, ylim=ylim)
    else:
        plotter.save_plot(filename, ylim=ylim, smooth=True)
Ejemplo n.º 11
0
def plot_bs(vasprun_path: str):
    """Display the line-mode band structure stored in *vasprun_path*.

    The k-point labels come from a "KPOINTS" file. If the XML cannot be
    parsed, a notice is printed and the function returns None.
    """
    try:
        run = BSVasprun(vasprun_path)
    except xml.etree.ElementTree.ParseError:
        print("\tskipped due to parse error")
        return
    band_structure = run.get_band_structure(kpoints_filename="KPOINTS",
                                            line_mode=True)
    BSPlotter(band_structure).show()
Ejemplo n.º 12
0
def banddos(pref='', storedir=None):
    """
    Plot the band structure from ./vasprun.xml and ./KPOINTS into BAND.png.

    Energies are plotted relative to the Fermi level in a fixed -4..4 window.
    Spin-up bands are solid blue; spin-down bands (if spin polarized) are
    dashed red. CBM points are marked in red and VBM points in green.

    Args:
        pref: unused; kept for backward compatibility.
        storedir: unused; kept for backward compatibility.
    """
    run = Vasprun("vasprun.xml", parse_projected_eigen=True)
    bands = run.get_band_structure("KPOINTS", line_mode=True,
                                   efermi=run.efermi)
    bsp = BSPlotter(bands)
    zero_to_efermi = True
    bandgap = str(round(bands.get_band_gap()['energy'], 3))
    # Print function: the original Python 2 print statement is a
    # SyntaxError under Python 3.
    print("bg=", bandgap)
    data = bsp.bs_plot_data(zero_to_efermi)
    plt = get_publication_quality_plot(12, 8)
    band_linewidth = 3
    x_max = data['distances'][-1][-1]
    print(x_max)

    for d, distances in enumerate(data['distances']):
        n_points = len(distances)
        for i in range(bsp._nb_bands):
            plt.plot(distances,
                     [data['energy'][d]['1'][i][j] for j in range(n_points)],
                     'b-', linewidth=band_linewidth)
            if bsp._bs.is_spin_polarized:
                plt.plot(distances,
                         [data['energy'][d]['-1'][i][j]
                          for j in range(n_points)],
                         'r--', linewidth=band_linewidth)
    bsp._maketicks(plt)

    # Band edges. The VBM loop was previously nested inside the CBM loop,
    # which redrew every VBM once per CBM (and never when there was no CBM).
    for cbm in data['cbm']:
        plt.scatter(cbm[0], cbm[1], color='r', marker='o', s=100)
    for vbm in data['vbm']:
        plt.scatter(vbm[0], vbm[1], color='g', marker='o', s=100)

    plt.xlabel(r'$\mathrm{Wave\ Vector}$', fontsize=30)
    ylabel = (r'$\mathrm{E\ -\ E_f\ (eV)}$' if zero_to_efermi
              else r'$\mathrm{Energy\ (eV)}$')
    plt.ylabel(ylabel, fontsize=30)
    plt.ylim(-4, 4)
    plt.xlim(0, x_max)
    plt.tight_layout()
    # `format` is savefig's keyword; `img_format` is not one of its options.
    plt.savefig('BAND.png', format="png")
    plt.close()
Ejemplo n.º 13
0
    def get_bandsxy(self, bs, bandrange):
        """Return plot coordinates for the bands listed in *bandrange*.

        Returns:
            [x, yu, yd], where:
            - x is the flattened list of k-path distances from
              ``BSPlotter(bs).bs_plot_data()``;
            - yu is one list per band index in *bandrange* of spin-up
              energies minus ``bs.efermi``;
            - yd is the same for spin down, or None when ``bs`` is not
              spin polarized.
        """
        plot_data = BSPlotter(bs).bs_plot_data()
        x = [dist for branch in plot_data["distances"] for dist in branch]

        def shifted(spin):
            # Energies of each requested band, relative to the Fermi level.
            energies = bs.bands[spin]
            return [[e - bs.efermi for e in energies[b]] for b in bandrange]

        yu = shifted(Spin.up)
        yd = shifted(Spin.down) if bs.is_spin_polarized else None
        return [x, yu, yd]
Ejemplo n.º 14
0
    def __init__(self, bs):
        # Pass every band structure (one, or a list of them) through
        # force_branches before handing it to BSPlotter.
        if isinstance(bs, list):
            prepared = [force_branches(b) for b in bs]
        else:
            prepared = force_branches(bs)
        BSPlotter.__init__(self, prepared)

        # BSPlotter may store a list (multi band structure) or a single
        # object; the original note says old pymatgen versions only support
        # a single band structure. Expose the first/only one either way.
        if not isinstance(self._bs, list):
            self.bs = self._bs
            self.nbands = self._nb_bands
            return
        self.bs = self._bs[0]
        self.nbands = self._nb_bands[0]
Ejemplo n.º 15
0
class BSPlotterTest(unittest.TestCase):
    """Checks bs_plot_data on the CaO fixture (flat distance list layout)."""

    def setUp(self):
        # The fixture is read as bytes; json.loads accepts bytes input.
        fixture = os.path.join(test_dir, "CaO_2605_bandstructure.json")
        with open(fixture, "rb") as f:
            self.bs = BandStructureSymmLine.from_dict(json.loads(f.read()))
            self.plotter = BSPlotter(self.bs)

    def test_bs_plot_data(self):
        data = self.plotter.bs_plot_data()
        self.assertEqual(len(data['distances']), 160,
                         "wrong number of distances")
        labels = data['ticks']['label']
        self.assertEqual(labels[5], "K", "wrong tick label")
        self.assertEqual(len(labels), 19, "wrong number of tick labels")
Ejemplo n.º 16
0
 def setUp(self):
     """Load the CaO fixture, build its plotter, then silence warnings."""
     path = os.path.join(test_dir, "CaO_2605_bandstructure.json")
     with open(path, "r", encoding='utf-8') as handle:
         payload = json.loads(handle.read())
     self.bs = BandStructureSymmLine.from_dict(payload)
     self.plotter = BSPlotter(self.bs)
     warnings.simplefilter("ignore")
Ejemplo n.º 17
0
class BSPlotterTest(unittest.TestCase):
    """bs_plot_data regression checks against the CaO fixture."""

    def setUp(self):
        src = os.path.join(test_dir, "CaO_2605_bandstructure.json")
        with open(src, "rb") as stream:
            parsed = json.load(stream)
            self.bs = BandStructureSymmLine.from_dict(parsed)
            self.plotter = BSPlotter(self.bs)

    def test_bs_plot_data(self):
        plot_data = self.plotter.bs_plot_data()
        self.assertEqual(len(plot_data['distances']), 160,
                         "wrong number of distances")
        self.assertEqual(plot_data['ticks']['label'][5], "K",
                         "wrong tick label")
        self.assertEqual(len(plot_data['ticks']['label']), 19,
                         "wrong number of tick labels")
Ejemplo n.º 18
0
def get_mp_banddos():
    """
    Interactively fetch the band structure and DOS for a Materials Project ID.

    Prompts for the mp-ID and prints the entry. The plots are written to
    ``<mp-id>_band.png`` and ``<mp-id>_dos.png``. Exits the program with
    status 0 if either the band structure or the DOS cannot be retrieved.
    """
    import sys

    check_matplotlib()
    mpr = check_apikey()
    print("input the mp-ID")
    wait_sep()
    in_str = wait()
    mp_id = in_str
    step_count = 1
    proc_str = "Reading Data From " + web + " ..."
    procs(proc_str, step_count, sp='-->>')
    data = mpr.get_entry_by_material_id(mp_id)
    sepline()
    print(data)
    sepline()
    step_count += 1
    proc_str = "Reading Band Data From " + web + " ..."
    procs(proc_str, step_count, sp='-->>')

    band = mpr.get_bandstructure_by_material_id(mp_id)
    if band is None:
        print("No data obtained online, stop now!")
        # `os` has no `exit`; the former os.exit(0) raised AttributeError.
        sys.exit(0)

    step_count += 1
    filename = mp_id + '_band.png'
    proc_str = "Writing Data to " + filename + " File..."
    bsp = BSPlotter(band)
    procs(proc_str, step_count, sp='-->>')
    bsp.save_plot(filename=filename, img_format="png")

    step_count += 1
    proc_str = "Reading DOS Data From " + web + " ..."
    procs(proc_str, step_count, sp='-->>')
    dos = mpr.get_dos_by_material_id(mp_id)
    if dos is None:
        print("No data obtained online, stop now!")
        # Previously execution continued and failed inside add_dos(None).
        sys.exit(0)

    step_count += 1
    filename = mp_id + '_dos.png'
    proc_str = "Writing Data to " + filename + " File..."
    dsp = DosPlotter()
    dsp.add_dos('Total', dos)
    procs(proc_str, step_count, sp='-->>')
    dsp.save_plot(filename=filename, img_format="png")
Ejemplo n.º 19
0
    def get_kpt_labels(self, bs):
        """Return the high-symmetry k-point labels and positions for *bs*.

        Consecutive identical labels are collapsed to their first occurrence.
        The LaTeX "mid" and "Gamma" labels are replaced with the plain-text
        characters "|" and the Greek capital gamma, because Dash cannot
        display LaTeX.

        Returns:
            (labels, positions): two lists of equal length. Both are empty if
            the plotter reports no ticks.
        """
        ticks = BSPlotter(bs).get_ticks()
        labels = ticks["label"]
        labelspos = ticks["distance"]
        if not labels:
            return [], []

        ## get unique K-points
        labels_uniq = [labels[0]]
        labelspos_uniq = [labelspos[0]]
        for i in range(1, len(labels)):
            if labels[i] != labels[i - 1]:
                labels_uniq.append(labels[i])
                labelspos_uniq.append(labelspos[i])

        ## hack for dash which can't display latex :(
        # Raw strings: "\m" and "\G" are invalid escape sequences in normal
        # string literals (SyntaxWarning on recent Python versions).
        labels_uniq = [label.replace(r"$\mid$", "|").replace(r"$\Gamma$", u"\u0393")
                       for label in labels_uniq]

        return labels_uniq, labelspos_uniq
Ejemplo n.º 20
0
class BSPlotterTest(unittest.TestCase):
    """BSPlotter checks on the CaO fixture, with warnings silenced per test."""

    def setUp(self):
        src = os.path.join(test_dir, "CaO_2605_bandstructure.json")
        with open(src, "r", encoding='utf-8') as fp:
            self.bs = BandStructureSymmLine.from_dict(json.load(fp))
            self.plotter = BSPlotter(self.bs)
        warnings.simplefilter("ignore")

    def tearDown(self):
        warnings.simplefilter("default")

    def test_bs_plot_data(self):
        out = self.plotter.bs_plot_data()
        per_branch = out['distances']
        self.assertEqual(len(per_branch[0]), 16,
                         "wrong number of distances in the first branch")
        self.assertEqual(len(per_branch), 10, "wrong number of branches")
        self.assertEqual(sum(len(b) for b in per_branch), 160,
                         "wrong number of distances")
        self.assertEqual(out['ticks']['label'][5], "K", "wrong tick label")
        self.assertEqual(len(out['ticks']['label']), 19,
                         "wrong number of tick labels")

    def test_get_plot(self):
        """Smoke test only: get_plot and save_plot run without raising."""
        # LaTeX rendering must be disabled for this test to work.
        from matplotlib import rc
        rc('text', usetex=False)

        for options in ({}, {"smooth": True}, {"vbm_cbm_marker": True}):
            plt = self.plotter.get_plot(**options)
        self.plotter.save_plot("bsplot.png")
        self.assertTrue(os.path.isfile("bsplot.png"))
        os.remove("bsplot.png")
        plt.close("all")
Ejemplo n.º 21
0
    def plot_bands(self):
        """
        Plot the band structure from ``self.bzt_interp`` with BSPlotter().

        Raises:
            BoltztrapError: if ``self.bzt_interp`` is None.

        Returns:
            The result of ``BSPlotter.get_plot()``.
        """
        interp = self.bzt_interp
        if interp is None:
            raise BoltztrapError("BztInterpolator not present")
        return BSPlotter(interp.get_band_structure()).get_plot()
Ejemplo n.º 22
0
def plot_bs(bs, **kwargs):
    """
    Get a band-structure plot with pymatgen.

    Parameters
    ----------
    bs :
        BandStructureSymmLine object, most likely generated from Vasprun or
        BSVasprun.
    **kwargs : (dict)
        Keyword arguments forwarded unchanged to ``BSPlotter.get_plot``.

    Returns
    -------
    plt :
        Matplotlib object returned by ``BSPlotter.get_plot``.
    """
    return BSPlotter(bs).get_plot(**kwargs)
Ejemplo n.º 23
0
class BSPlotterTest(unittest.TestCase):
    """BSPlotter regression and smoke tests on the CaO fixture."""

    def setUp(self):
        with open(os.path.join(test_dir, "CaO_2605_bandstructure.json"),
                  "r", encoding='utf-8') as f:
            d = json.loads(f.read())
            self.bs = BandStructureSymmLine.from_dict(d)
            self.plotter = BSPlotter(self.bs)
        warnings.simplefilter("ignore")

    def tearDown(self):
        warnings.resetwarnings()

    def test_bs_plot_data(self):
        data = self.plotter.bs_plot_data()
        self.assertEqual(len(data['distances'][0]), 16,
                         "wrong number of distances in the first branch")
        self.assertEqual(len(data['distances']), 10,
                         "wrong number of branches")
        self.assertEqual(
            sum(len(e) for e in data['distances']),
            160, "wrong number of distances")
        self.assertEqual(data['ticks']['label'][5], "K",
                         "wrong tick label")
        self.assertEqual(len(data['ticks']['label']),
                         19, "wrong number of tick labels")

    # Minimal baseline testing for get_plot. Not a true test: it only checks
    # that plotting executes.
    def test_get_plot(self):
        # Disabling latex is needed for this test to work.
        from matplotlib import pyplot, rc
        rc('text', usetex=False)

        try:
            self.plotter.get_plot()
            self.plotter.get_plot(smooth=True)
            self.plotter.get_plot(vbm_cbm_marker=True)
            self.plotter.save_plot("bsplot.png")
            self.assertTrue(os.path.isfile("bsplot.png"))
        finally:
            # Clean up even when an assertion or plot call fails: remove the
            # output file and close the figures, as the sibling test does.
            if os.path.isfile("bsplot.png"):
                os.remove("bsplot.png")
            pyplot.close("all")
Ejemplo n.º 24
0
def get_small_plot(bs):
    """Return ``BSPlotter(bs).bs_plot_data()`` restricted to bands near the gap.

    A band is kept only if its minimum is below ``gap + 3`` and its maximum
    is above ``-3``, where ``gap`` is ``bs.get_band_gap()["energy"]``. The
    energy lists in the returned dict are replaced in place.
    """
    data = BSPlotter(bs).bs_plot_data()
    upper = bs.get_band_gap()["energy"] + 3
    for branch in data['energy']:
        for spin in list(branch):
            branch[spin] = [band for band in branch[spin]
                            if min(band) < upper and max(band) > -3]
    return data
Ejemplo n.º 25
0
def plot_band_structure(ylim=(-5, 5), draw_fermi=False, fmt="pdf"):
    """
    Plot a standard band structure with no projections.

    Reads ``vasprun.xml`` and a line-mode ``KPOINTS`` file from the current
    directory and writes ``band_structure.<fmt>``.

    Args:
        ylim (tuple): minimum and maximum potentials for the plot's
            y-axis.
        draw_fermi (bool): whether or not to draw a dashed line at
            E_F.
        fmt (str): matplotlib format style. Check the matplotlib
            docs for options.
    """

    def as_mathrm(label):
        # Tick labels are matplotlib Text objects: format their text, not
        # their repr. Empty labels and labels already containing math
        # ($...$) are left unchanged, because nested $ breaks mathtext.
        text = label.get_text()
        if not text or "$" in text:
            return text
        return r"$\mathrm{%s}$" % text

    vasprun = Vasprun("vasprun.xml")
    efermi = vasprun.efermi
    bsp = BSPlotter(vasprun.get_band_structure("KPOINTS", line_mode=True, efermi=efermi))
    plot = bsp.get_plot(ylim=ylim)
    fig = plot.gcf()
    ax = fig.gca()
    ax.set_xticklabels([as_mathrm(t) for t in ax.get_xticklabels()])
    if draw_fermi:
        ax.plot([ax.get_xlim()[0], ax.get_xlim()[1]], [0, 0], "k--")
    fig.savefig("band_structure.{}".format(fmt), transparent=True)
    plt.close()
Ejemplo n.º 26
0
from pymatgen.io.vasp.outputs import BSVasprun
from pymatgen.electronic_structure.plotter import BSPlotter
import os

# Show the line-mode band structure of the calculation in this directory.
# NOTE(review): hard-coded, user-specific working directory.
os.chdir("/home/jinho93/half-metal/1.CrO2/3.band")
vrun = BSVasprun("vasprun.xml")
bs = vrun.get_band_structure("KPOINTS", line_mode=True)
bsp = BSPlotter(bs)
bsp.show()
Ejemplo n.º 27
0
# Copyright (c) Henniggroup.
# Distributed under the terms of the MIT License.

from __future__ import division, unicode_literals, print_function

"""
Read a KPOINTS file (with labels for the high-symmetry k-points) and a
vasprun.xml file, and plot the band structure along the high-symmetry
k-point path.
"""

# To use matplotlib on Hipergator, uncomment the following 2 lines:
# import matplotlib
# matplotlib.use('Agg')

from pymatgen.io.vasp.outputs import Vasprun
from pymatgen.electronic_structure.plotter import BSPlotterProjected, BSPlotter

if __name__ == "__main__":
    # Band structure from vasprun.xml; k-point labels come from KPOINTS.
    run = Vasprun("vasprun.xml", parse_projected_eigen=True)
    fermi = run.efermi
    bands = run.get_band_structure("KPOINTS", line_mode=True, efermi=fermi)
    plotter = BSPlotter(bands)
    # Blue lines are up spin, red lines are down spin.
    plotter.save_plot('band_diagram.eps', ylim=(-5, 5))
    # Projected alternative: BSPlotterProjected(bands) provides
    # get_projected_plots_dots({'Fe': ['s', 'p', 'd'], 'Sb': ['s', 'p', 'd']})
    # and get_elt_projected_plots_color().
Ejemplo n.º 28
0
dosplotter = DosPlotter()
# NOTE(review): add_dos / add_dos_dict are called for their effect on the
# plotter; their return values bound below are presumably None (verify).
# The names are kept in case later code refers to them.
Totaldos = dosplotter.add_dos('Total DOS', tdos)
Integrateddos = dosplotter.add_dos('Integrated DOS', idos)
#Pdos = dosplotter.add_dos('Partial DOS',pdos)
#Spd_dos =  dosplotter.add_dos('spd DOS',spd_dos)
#Element_dos = dosplotter.add_dos('Element DOS',element_dos)
#Element_spd_dos = dosplotter.add_dos('Element_spd DOS',element_spd_dos)
dos_dict = {
    'Total DOS': tdos,
    'Integrated DOS': idos
}  #'Partial DOS':pdos,'spd DOS':spd_dos,'Element DOS':element_dos}#'Element_spd DOS':element_spd_dos
add_dos_dict = dosplotter.add_dos_dict(dos_dict)
get_dos_dict = dosplotter.get_dos_dict()
dos_plot = dosplotter.get_plot()
##dosplotter.save_plot("MAPbI3_dos",img_format="png")
##dos_plot.show()
bsplotter = BSPlotter(bs)
bs_plot_data = bsplotter.bs_plot_data()
bs_plot = bsplotter.get_plot()
#bsplotter.save_plot("MAPbI3_bs",img_format="png")
#bsplotter.show()
ticks = bsplotter.get_ticks()
# Print function: the Python 2 print statement is a SyntaxError in Python 3.
print(ticks)
bsplotter.plot_brillouin()
bsdos = BSDOSPlotter(
    tick_fontsize=10,
    egrid_interval=20,
    dos_projection="orbitals",
    bs_legend=None)  #bs_projection="HPbCIN",dos_projection="HPbCIN")
bds = bsdos.get_plot(bs, cdos)
Ejemplo n.º 29
0
#!/nfshome/villa/anaconda3/bin/python

from pymatgen.io.vasp.outputs import Vasprun
from pymatgen.electronic_structure.plotter import BSPlotter, BSDOSPlotter

# Parse vasprun.xml only once. The same run supplies both the band structure
# and the DOS; both names are kept so existing references still work.
vaspout_bs = Vasprun("vasprun.xml")
vaspout_dos = vaspout_bs

bandstr = vaspout_bs.get_band_structure(line_mode=True, force_hybrid_mode=True)
bs_data = BSPlotter(bandstr).bs_plot_data()

dos = vaspout_dos.complete_dos
dos_dict = dos.as_dict()

#plt = BSPlotter(bandstr).get_plot()
plt = BSDOSPlotter(bs_projection=None,
                   dos_projection='elements',
                   bs_legend=None,
                   fig_size=(14, 11)).get_plot(bandstr, dos)
plt.savefig("dos_bs.png")
Ejemplo n.º 30
0
    def get_plot(
        self,
        n_idx,
        t_idx,
        zero_to_efermi=True,
        estep=0.01,
        line_density=100,
        height=6,
        width=6,
        emin=None,
        emax=None,
        ylabel="Energy (eV)",
        plt=None,
        aspect=None,
        distance_factor=10,
        kpath=None,
        style=None,
        no_base_style=False,
        fonts=None,
    ):
        """Plot line-mode bands broadened by their scattering-rate linewidths.

        The interpolater for ``(n_idx, t_idx)`` produces a line-mode band
        structure plus per-band rates. Each band is drawn as a Lorentzian
        whose width is ``10**rate * hbar / 2``, using the smoothed rate. The
        maximum over all bands is shown as a log-scaled colour mesh.

        Args:
            n_idx, t_idx: indices passed to ``self._get_interpolater``
                (presumably doping and temperature indices; confirm).
            zero_to_efermi: shift energies so the Fermi level is at 0.
            estep: spacing of the energy grid used for the mesh.
            line_density: k-point density of the interpolated path.
            height, width: figure size passed to ``pretty_plot``.
            emin, emax: energy window. When None, taken from
                ``self.fd_cutoffs`` (times ``hartree_to_ev``), shifted by the
                Fermi level when ``zero_to_efermi``. An explicit 0 is kept.
            ylabel: y-axis label.
            plt: passed through to ``pretty_plot``.
            aspect: aspect ratio passed to ``_makeplot``.
            distance_factor: oversampling factor of the k-distance axis.
            kpath: optional k-path for the interpolated band structure.
            style, no_base_style, fonts: accepted but unused in this body.

        Returns:
            The object returned by ``pretty_plot``.
        """
        interpolater = self._get_interpolater(n_idx, t_idx)

        bs, prop = interpolater.get_line_mode_band_structure(
            line_density=line_density,
            return_other_properties=True,
            kpath=kpath)

        fd_emin, fd_emax = self.fd_cutoffs
        # Compare against None: `if not emin` silently replaced an explicit
        # 0.0 limit with the default.
        if emin is None:
            emin = fd_emin * hartree_to_ev
            if zero_to_efermi:
                emin -= bs.efermi

        if emax is None:
            emax = fd_emax * hartree_to_ev
            if zero_to_efermi:
                emax -= bs.efermi

        logger.info("Plotting band structure")
        plt = pretty_plot(width=width, height=height, plt=plt)
        ax = plt.gca()

        if zero_to_efermi:
            bs.bands = {s: b - bs.efermi for s, b in bs.bands.items()}
            bs.efermi = 0

        bs_plotter = BSPlotter(bs)
        plot_data = bs_plotter.bs_plot_data(zero_to_efermi=zero_to_efermi)

        energies = np.linspace(emin, emax, int((emax - emin) / estep))
        distances = np.array([d for x in plot_data["distances"] for d in x])

        # rates are currently log(rate). Non-positive values are clamped to
        # the smallest positive one, and values are capped at 15.
        rates = {}
        for spin, spin_data in prop.items():
            rates[spin] = spin_data["rates"]
            rates[spin][rates[spin] <= 0] = np.min(
                rates[spin][rates[spin] > 0])
            rates[spin][rates[spin] >= 15] = 15

        interp_distances = np.linspace(distances.min(), distances.max(),
                                       int(len(distances) * distance_factor))

        # savgol_filter needs an odd window. Start from min(len - 2, 71) and
        # move to the next odd number strictly above it.
        window = np.min([len(distances) - 2, 71])
        window += window % 2 + 1
        # Positive floor, so the LogNorm below never sees zeros.
        mesh_data = np.full((len(interp_distances), len(energies)), 1e-2)

        for spin in self.spins:
            for spin_energies, spin_rates in zip(bs.bands[spin], rates[spin]):
                interp_energies = interp1d(distances,
                                           spin_energies)(interp_distances)
                spin_rates = savgol_filter(spin_rates, window, 3)
                interp_rates = interp1d(distances,
                                        spin_rates)(interp_distances)
                linewidths = 10**interp_rates * hbar / 2

                for d_idx in range(len(interp_distances)):
                    energy = interp_energies[d_idx]
                    linewidth = linewidths[d_idx]

                    broadening = lorentzian(energies, energy, linewidth)
                    mesh_data[d_idx] = np.maximum(broadening, mesh_data[d_idx])

        ax.pcolormesh(
            interp_distances,
            energies,
            mesh_data.T,
            rasterized=True,
            norm=LogNorm(vmin=mesh_data.min(), vmax=mesh_data.max()),
        )

        _maketicks(ax, bs_plotter, ylabel=ylabel)
        _makeplot(
            ax,
            plot_data,
            bs,
            zero_to_efermi=zero_to_efermi,
            width=width,
            height=height,
            ymin=emin,
            ymax=emax,
            aspect=aspect,
        )
        return plt
Ejemplo n.º 31
0
        def bs_dos_traces(bandStructureSymmLine, densityOfStates):
            """Build plotly traces for a band structure and its density of states.

            Args:
                bandStructureSymmLine: dict accepted by ``BSML.from_dict``, or
                    the string ``"error"``.
                densityOfStates: dict accepted by ``CompleteDos.from_dict``, or
                    the string ``"error"``.

            Returns:
                ``"error"`` if either input is ``"error"``; otherwise the list
                ``[bstraces, dostraces, bs_data]`` where ``bstraces`` and
                ``dostraces`` are lists of ``go.Scatter`` traces and
                ``bs_data`` is ``BSPlotter.bs_plot_data()`` with the LaTeX
                markup in its tick labels replaced by plain Unicode.

            Raises:
                PreventUpdate: if either input is ``None``.
            """
            if bandStructureSymmLine == "error" or densityOfStates == "error":
                return "error"

            if bandStructureSymmLine is None or densityOfStates is None:
                raise PreventUpdate

            # - BS Data
            bstraces = []

            bs_reg_plot = BSPlotter(BSML.from_dict(bandStructureSymmLine))

            bs_data = bs_reg_plot.bs_plot_data()

            # -- Strip latex math wrapping
            str_replace = {
                "$": "",
                "\\mid": "|",
                "\\Gamma": "Γ",
                "\\Sigma": "Σ",
                "_1": "₁",
                "_2": "₂",
                "_3": "₃",
                "_4": "₄",
            }

            # Labels are rewritten in place so the returned bs_data carries
            # the cleaned-up versions; replacements are applied in dict order.
            tick_labels = bs_data["ticks"]["label"]
            for entry_num, label in enumerate(tick_labels):
                for key, value in str_replace.items():
                    label = label.replace(key, value)
                tick_labels[entry_num] = label

            bs_is_spin_polarized = bs_reg_plot._bs.is_spin_polarized
            for d, distances in enumerate(bs_data["distances"]):
                n_points = len(distances)
                branch_energies = bs_data["energy"][d]
                for i in range(bs_reg_plot._nb_bands):
                    bstraces.append(
                        go.Scatter(
                            x=distances,
                            y=[
                                branch_energies[str(Spin.up)][i][j]
                                for j in range(n_points)
                            ],
                            mode="lines",
                            line=dict(color=("#666666"), width=2),
                            hoverinfo="skip",
                            showlegend=False,
                        ))

                    if bs_is_spin_polarized:
                        bstraces.append(
                            go.Scatter(
                                x=distances,
                                y=[
                                    branch_energies[str(Spin.down)][i][j]
                                    for j in range(n_points)
                                ],
                                mode="lines",
                                line=dict(color=("#666666"),
                                          width=2,
                                          dash="dash"),
                                hoverinfo="skip",
                                showlegend=False,
                            ))

            # -- DOS Data
            dostraces = []

            dos = CompleteDos.from_dict(densityOfStates)
            # Energies relative to the Fermi level, shared by every DOS trace.
            dos_energies = dos.energies - dos.efermi

            if Spin.down in dos.densities:
                # Add second spin data if available
                dostraces.append(
                    go.Scatter(
                        x=dos.densities[Spin.down],
                        y=dos_energies,
                        mode="lines",
                        name="Total DOS (spin ↓)",
                        line=go.scatter.Line(color="#444444", dash="dash"),
                        fill="tozeroy",
                    ))
                tdos_label = "Total DOS (spin ↑)"
            else:
                tdos_label = "Total DOS"

            # Total DOS
            dostraces.append(
                go.Scatter(
                    x=dos.densities[Spin.up],
                    y=dos_energies,
                    mode="lines",
                    name=tdos_label,
                    line=go.scatter.Line(color="#444444"),
                    fill="tozeroy",
                    legendgroup="spinup",
                ))

            p_ele_dos = dos.get_element_dos()

            # Projected DOS
            colors = [
                "#1f77b4",  # muted blue
                "#ff7f0e",  # safety orange
                "#2ca02c",  # cooked asparagus green
                "#d62728",  # brick red
                "#9467bd",  # muted purple
                "#8c564b",  # chestnut brown
                "#e377c2",  # raspberry yogurt pink
                "#bcbd22",  # curry yellow-green
                "#17becf",  # blue-teal
            ]

            for count, (ele, ele_dos) in enumerate(p_ele_dos.items()):
                # Cycle the palette so systems with more elements than
                # colours do not raise IndexError.
                color = colors[count % len(colors)]

                # Decide on a spin-down trace from the DOS itself (as for the
                # total DOS above) rather than from the band structure, so a
                # spin-polarized band structure paired with a non-polarized
                # DOS does not raise KeyError.
                if Spin.down in ele_dos.densities:
                    dostraces.append(
                        go.Scatter(
                            x=ele_dos.densities[Spin.down],
                            y=dos_energies,
                            mode="lines",
                            name=ele.symbol + " (spin ↓)",
                            line=dict(width=3, color=color, dash="dash"),
                        ))
                    spin_up_label = ele.symbol + " (spin ↑)"
                else:
                    spin_up_label = ele.symbol

                dostraces.append(
                    go.Scatter(
                        x=ele_dos.densities[Spin.up],
                        y=dos_energies,
                        mode="lines",
                        name=spin_up_label,
                        line=dict(width=3, color=color),
                    ))

            traces = [bstraces, dostraces, bs_data]

            return traces
Ejemplo n.º 32
0
# -*- coding: utf-8 -*-
"""
Created on Wed May 15 10:33:43 2019

@author: nwpuf
"""

from pymatgen.io.vasp import Vasprun, BSVasprun
from pymatgen.electronic_structure.plotter import BSPlotter, DosPlotter
import matplotlib.pyplot as plt

# Parse vasprun.xml and build a line-mode band structure using the
# labelled high-symmetry path stored in KPOINTS.
bsv = BSVasprun("vasprun.xml")
bs = bsv.get_band_structure(
    kpoints_filename="KPOINTS",
    line_mode=True,
)

# Report pymatgen's band-gap summary for the structure.
print(bs.get_band_gap())

# Plot with VBM/CBM markers and display it.
bsplot = BSPlotter(bs)
bsplot.get_plot(vbm_cbm_marker=True).show()

#dosrun = Vasprun("DOS/vasprun.xml", parse_dos=True)
#dos = dosrun.complete_dos
#dosplot = DosPlotter(sigma=0.1)
#dosplot.add_dos("Total DOS", dos)
#dosplot.add_dos_dict(dos.get_element_dos())
#ax = plt.gca()
#print(type(dosplot.get_plot()))
#dosplot.get_plot().show()
Ejemplo n.º 33
0
# Distributed under the terms of the MIT License.

from __future__ import division, print_function, unicode_literals, \
    absolute_import
"""
Reads in a KPOINTS file (with labels for the high-symmetry k-points) and a
vasprun.xml file, and plots the band structure along the high-symmetry
k-points.
"""

# To use matplotlib on Hipergator, uncomment the following 2 lines:
# import matplotlib
# matplotlib.use('Agg')

from pymatgen.io.vasp.outputs import Vasprun
from pymatgen.electronic_structure.plotter import BSPlotter

if __name__ == "__main__":
    # Read the band structure along the labelled KPOINTS path. Projected
    # eigenvalues are parsed too, so the projected plots sketched at the
    # bottom can be switched on without re-reading the file.
    run = Vasprun("vasprun.xml", parse_projected_eigen=True)
    bands = run.get_band_structure(
        "KPOINTS",
        line_mode=True,
        efermi=run.efermi,
    )

    # Blue lines are up spin, red lines are down spin
    bsp = BSPlotter(bands)
    bsp.save_plot('band_diagram.eps', ylim=(-5, 5))

    # Projected alternative (requires BSPlotterProjected):
    # bsp = BSPlotterProjected(bands)
    # plt = bsp.get_projected_plots_dots( {'Fe':['s', 'p', 'd'],
    #                                     'Sb':['s', 'p', 'd']})
    # get_elt_projected_plots_color()
Ejemplo n.º 34
0
def _write_projected_bands(filename, contrib, band_energies, efermi,
                           distances, orbitals, in_str, atom_index):
    """Write one PBAND data file for the selected atoms.

    Args:
        filename: output file name.
        contrib: PROCAR weights summed over the selected atoms, shape
            ``(nkpoints, nbands, norbitals)``.
        band_energies: band energies of the same spin channel, shape
            ``(nbands, nkpoints)`` (``bands.bands[spin]``).
        efermi: Fermi energy, subtracted so the Fermi level sits at 0.
        distances: k-path distance of each k-point (length ``nkpoints``).
        orbitals: orbital names, one per weight column.
        in_str: atom-selection string echoed in the header.
        atom_index: 0-based indices of the selected atoms.
    """
    nkpoints, nbands, norbitals = contrib.shape
    # PROCAR weights are indexed (k-point, band, orbital), but the energies
    # and distances below are laid out band-major (all k-points of band 0,
    # then band 1, ...). Transpose before flattening so every output row
    # pairs an energy with the weights of the same (band, k-point) state.
    proj_band = contrib.transpose(1, 0, 2).reshape(nbands * nkpoints, norbitals)
    y_data = band_energies.reshape(nbands * nkpoints) - efermi  # Fermi level at 0
    x_data = np.array(list(distances) * nbands)
    data = np.vstack((x_data, y_data, proj_band.T)).T

    header_fmt = "#%(key1)+12s%(key2)+12s"
    header_keys = {'key1': 'K-Distance', 'key2': 'Energy(ev)'}
    for i, orbital in enumerate(orbitals):
        key = "key" + str(i + 3)
        header_fmt += "%(" + key + ")+12s"
        header_keys[key] = orbital

    atom_index_str = [str(x + 1) for x in atom_index]
    head_line = ("#String: " + in_str + '\n#Selected atom: '
                 + ' '.join(atom_index_str) + '\n' + header_fmt % header_keys)
    write_col_data(filename, data, head_line, nkpoints)


def projected_band_structure():
    """Write orbital-projected band data for a selection of atoms.

    Reads ``vasprun.xml``, ``PROCAR`` and ``KPOINTS`` from the working
    directory, selects atoms via ``atom_selection``, sums their PROCAR
    weights and writes:

    * ``PBAND_Up.dat`` and ``PBAND_Down.dat`` for spin-polarized runs, or
      ``PBAND.dat`` otherwise. The columns are k-path distance, energy
      relative to the Fermi level, and one weight column per PROCAR orbital.
    * ``HighSymmetricPoints.dat``, with the index, label and position of
      each tick.

    Returns early, writing nothing, when no atoms are selected.
    """
    step_count = 1
    filename = 'vasprun.xml'
    check_file(filename)
    proc_str = "Reading Data From " + filename + " File ..."
    procs(proc_str, step_count, sp='-->>')
    vsr = Vasprun(filename)

    filename = 'PROCAR'
    check_file(filename)
    step_count += 1
    proc_str = "Reading Data From " + filename + " File ..."
    procs(proc_str, step_count, sp='-->>')
    procar = Procar(filename)
    nbands = procar.nbands
    norbitals = len(procar.orbitals)
    nkpoints = procar.nkpoints

    step_count += 1
    filename = 'KPOINTS'
    check_file(filename)
    proc_str = "Reading Data From " + filename + " File ..."
    procs(proc_str, step_count, sp='-->>')
    bands = vsr.get_band_structure(filename, line_mode=True, efermi=vsr.efermi)
    struct = vsr.final_structure
    atom_index, in_str = atom_selection(struct)

    if len(atom_index) == 0:
        print("No atoms selected!")
        return

    def summed_weights(spin):
        # Sum the PROCAR weights of the selected atoms for one spin channel.
        contrib = np.zeros((nkpoints, nbands, norbitals))
        for i in atom_index:
            contrib += procar.data[spin][:, :, i, :]
        return contrib

    if vsr.is_spin:
        procs("This Is a Spin-polarized Calculation.", 0, sp='-->>')
        channels = ((Spin.up, "PBAND_Up.dat"), (Spin.down, "PBAND_Down.dat"))
    else:
        if vsr.parameters.get('LNONCOLLINEAR', False):
            procs("This Is a Non-Collinear Calculation.", 0, sp='-->>')
        else:
            procs("This Is a Non-Spin Calculation.", 0, sp='-->>')
        channels = ((Spin.up, "PBAND.dat"),)

    for spin, filename in channels:
        step_count += 1
        proc_str = "Writting Projected Band Structure Data to " + filename + " File ..."
        procs(proc_str, step_count, sp='-->>')
        # Pair each channel's weights with that same channel's energies
        # (the spin-down file must not reuse spin-up energies).
        _write_projected_bands(filename, summed_weights(spin),
                               bands.bands[spin], vsr.efermi, bands.distance,
                               procar.orbitals, in_str, atom_index)

    step_count += 1
    bsp = BSPlotter(bands)
    filename = "HighSymmetricPoints.dat"
    proc_str = "Writting Label infomation to " + filename + " File ..."
    procs(proc_str, step_count, sp='-->>')
    ticks = bsp.get_ticks()
    line = "#%(key1)+12s%(key2)+12s%(key3)+12s" % {
        'key1': 'index', 'key2': 'label', 'key3': 'position'} + '\n'
    for i, (label, position) in enumerate(zip(ticks['label'], ticks['distance'])):
        line += "%(key1)12d%(key2)+12s%(key3)12f\n" % {
            'key1': i, 'key2': label, 'key3': position}
    write_col_data(filename, line, '', str_data=True)
Ejemplo n.º 35
0
def band_structure():
    """Report band-edge information and export band-structure data.

    Reads ``vasprun.xml``, ``KPOINTS`` and ``OUTCAR`` from the working
    directory. It reports whether the material is a metal, a semimetal or a
    semiconductor; for semiconductors it also reports the VBM, CBM and gap.
    It then writes:

    * ``BAND.dat``, with the k-path distance and the energies relative to
      the Fermi level (separate up/down columns for spin-polarized runs).
    * ``HighSymmetricPoints.dat``, with the index, label and position of
      each tick.
    * ``BAND.png``, the band-structure plot. A plotting failure is reported,
      not raised.
    """
    check_matplotlib()
    step_count = 1

    filename = 'vasprun.xml'
    check_file(filename)
    proc_str = "Reading Data From " + filename + " File ..."
    procs(proc_str, step_count, sp='-->>')
    vsr = Vasprun(filename)

    step_count += 1
    filename = 'KPOINTS'
    check_file(filename)
    proc_str = "Reading Data From " + filename + " File ..."
    procs(proc_str, step_count, sp='-->>')
    bands = vsr.get_band_structure(filename, line_mode=True, efermi=vsr.efermi)

    step_count += 1
    filename = 'OUTCAR'
    check_file(filename)
    proc_str = "Reading Data From " + filename + " File ..."
    procs(proc_str, step_count, sp='-->>')
    outcar = Outcar('OUTCAR')
    mag = outcar.as_dict()['total_magnetization']

    if vsr.is_spin:
        procs("This Is a Spin-polarized Calculation.", 0, sp='-->>')
        tdos = vsr.tdos
        up_gap = tdos.get_gap(spin=Spin.up)
        cbm_vbm_up = tdos.get_cbm_vbm(spin=Spin.up)
        down_gap = tdos.get_gap(spin=Spin.down)
        cbm_vbm_down = tdos.get_cbm_vbm(spin=Spin.down)

        # A spin channel counts as gapped only if its gap exceeds min_gap.
        # Deriving both flags from these booleans covers every case,
        # including a gap exactly equal to min_gap.
        up_gapped = up_gap > min_gap
        down_gapped = down_gap > min_gap
        is_metal = not up_gapped and not down_gapped
        is_semimetal = up_gapped != down_gapped

        if is_metal:
            procs("This Material Is a Metal.", 0, sp='-->>')
        elif is_semimetal:
            procs("This Material Is a Semimetal.", 0, sp='-->>')
        else:
            procs("This Material Is a Semiconductor.", 0, sp='-->>')
            procs("Total magnetization is " + str(mag), 0, sp='-->>')
            # get_cbm_vbm returns (cbm, vbm).
            proc_str = "SpinUp  : vbm=%f eV cbm=%f eV gap=%f eV" % (
                cbm_vbm_up[1], cbm_vbm_up[0], up_gap)
            procs(proc_str, 0, sp='-->>')
            if mag > min_mag:
                proc_str = "SpinDown: vbm=%f eV cbm=%f eV gap=%f eV" % (
                    cbm_vbm_down[1], cbm_vbm_down[0], down_gap)
                procs(proc_str, 0, sp='-->>')

        step_count += 1
        filename = "BAND.dat"
        proc_str = "Writting Band Structure Data to " + filename + " File ..."
        procs(proc_str, step_count, sp='-->>')
        band_data_up = bands.bands[Spin.up]
        band_data_down = bands.bands[Spin.down]
        nbands, nkpoints = band_data_up.shape
        # Flatten band-major and shift the Fermi level to 0.
        y_data_up = band_data_up.reshape(nbands * nkpoints) - vsr.efermi
        y_data_down = band_data_down.reshape(nbands * nkpoints) - vsr.efermi
        x_data = np.array(bands.distance * nbands)
        data = np.vstack((x_data, y_data_up, y_data_down)).T
        head_line = "#%(key1)+12s%(key2)+13s%(key3)+15s" % {
            'key1': 'K-Distance', 'key2': 'UpEnergy(ev)', 'key3': 'DownEnergy(ev)'}
        write_col_data(filename, data, head_line, nkpoints)

    else:
        if vsr.parameters.get('LNONCOLLINEAR', False):
            proc_str = "This Is a Non-Collinear Calculation."
        else:
            proc_str = "This Is a Non-Spin Calculation."
        procs(proc_str, 0, sp='-->>')
        if not bands.is_metal():
            cbm = bands.get_cbm()['energy']
            vbm = bands.get_vbm()['energy']
            gap = bands.get_band_gap()['energy']
            procs("This Material Is a Semiconductor.", 0, sp='-->>')
            procs("vbm=%f eV cbm=%f eV gap=%f eV" % (vbm, cbm, gap), 0, sp='-->>')
        else:
            procs("This Material Is a Metal.", 0, sp='-->>')

        step_count += 1
        filename = "BAND.dat"
        proc_str = "Writting Band Structure Data to " + filename + " File ..."
        procs(proc_str, step_count, sp='-->>')
        band_data = bands.bands[Spin.up]
        nbands, nkpoints = band_data.shape
        # Flatten band-major and shift the Fermi level to 0.
        y_data = band_data.reshape(nbands * nkpoints) - vsr.efermi
        x_data = np.array(bands.distance * nbands)
        data = np.vstack((x_data, y_data)).T
        head_line = "#%(key1)+12s%(key2)+13s" % {'key1': 'K-Distance', 'key2': 'Energy(ev)'}
        write_col_data(filename, data, head_line, nkpoints)

    # Create the plotter for both spin cases: it is needed for the tick file
    # and for BAND.png below.
    step_count += 1
    bsp = BSPlotter(bands)
    filename = "HighSymmetricPoints.dat"
    proc_str = "Writting Label infomation to " + filename + " File ..."
    procs(proc_str, step_count, sp='-->>')
    ticks = bsp.get_ticks()
    line = "#%(key1)+12s%(key2)+12s%(key3)+12s" % {
        'key1': 'index', 'key2': 'label', 'key3': 'position'} + '\n'
    for i, (label, position) in enumerate(zip(ticks['label'], ticks['distance'])):
        line += "%(key1)12d%(key2)+12s%(key3)12f\n" % {
            'key1': i, 'key2': label, 'key3': position}
    line += '\n'
    write_col_data(filename, line, '', str_data=True)

    step_count += 1
    filename = "BAND.png"
    proc_str = "Saving Plot to " + filename + " File ..."
    procs(proc_str, step_count, sp='-->>')
    try:
        bsp.save_plot(filename, img_format="png")
    except Exception as err:  # plotting is best-effort: report the cause and continue
        print("Figure output fails !!!", err)
Ejemplo n.º 36
0
from pymatgen.io.vasp import Vasprun
from pymatgen.electronic_structure.plotter import BSPlotter, BSPlotterProjected

# Inputs are read from the nself/ subdirectory.
vr = Vasprun("nself/vasprun.xml")
bs = vr.get_band_structure(
    kpoints_filename="nself/KPOINTS",
    line_mode=True,
)
bsp = BSPlotter(bs)

# Element-projected alternative:
#   plt = bsp.get_elt_projected_plots(zero_to_efermi=False)
#   plt.savefig("band_structure.png", format="png")
bsp.save_plot(filename="band_structure.png", img_format="png")
Ejemplo n.º 37
0
 def setUp(self):
     """Load the CaO band-structure fixture and wrap it in a BSPlotter."""
     fixture_path = os.path.join(test_dir, "CaO_2605_bandstructure.json")
     with open(fixture_path, "rb") as handle:
         payload = json.load(handle)
     self.bs = BandStructureSymmLine.from_dict(payload)
     self.plotter = BSPlotter(self.bs)
Ejemplo n.º 38
0
# !/usr/bin/env python
# -*- coding: utf-8 -*-

from pymatgen.io.vasp.outputs import Vasprun
from pymatgen.electronic_structure.plotter import BSPlotter
vasprun = Vasprun("vasprun.xml")
bss = vasprun.get_band_structure(
    kpoints_filename="KPOINTS",
    line_mode=True,
)
plotter = BSPlotter(bss)

# Other exports used previously:
#   plotter.save_plot("bandStructure.svg", img_format="svg")
#   plotter.save_plot("bandStructure.png", img_format="png")
#   plotter.save_plot("lim_bandStructure.svg", img_format="svg", ylim=(-.2, 1.4))

# Save the plot restricted to a (-5, 5) energy window, then plot the
# Brillouin zone.
plotter.save_plot("MAPbI3-primitive.png", img_format="png", ylim=(-5, 5))
plotter.plot_brillouin()
Ejemplo n.º 39
0
    def get_plot(
        self,
        n_idx,
        t_idx,
        zero_to_efermi=True,
        estep=0.01,
        line_density=100,
        height=3.2,
        width=3.2,
        emin=None,
        emax=None,
        amin=5e-5,
        amax=1e-1,
        ylabel="Energy (eV)",
        plt=None,
        aspect=None,
        kpath=None,
        cmap="viridis",
        colorbar=True,
        style=None,
        no_base_style=False,
        fonts=None,
    ):
        """Plot the band structure as a rate-broadened spectral-function map.

        At every k-point along the path, each band energy is broadened by a
        Lorentzian. The Lorentzian width is ``10**rate * hbar / 2``, where
        the stored rates are log10 values. The result is converted from
        1/eV to 1/meV and summed over bands and spins into a
        (k-distance, energy) mesh. The mesh is drawn with ``pcolormesh`` on
        a log colour scale.

        Args:
            n_idx, t_idx: indices forwarded to ``self._get_interpolater``
                (presumably doping/temperature indices -- TODO confirm).
            zero_to_efermi: shift energies so the Fermi level is at 0.
            estep: spacing of the energy mesh (eV).
            line_density: k-point density of the interpolated line-mode
                band structure.
            height, width: figure size forwarded to ``pretty_plot`` and
                ``_makeplot``.
            emin, emax: energy window. ``None`` takes the default from
                ``self.fd_cutoffs``, shifted by the Fermi energy when
                ``zero_to_efermi`` is set. An explicit value, including 0,
                is used as given.
            amin, amax: colour-scale limits passed to ``LogNorm``.
            ylabel: y-axis label.
            plt: an ``Axis``/``SubplotBase`` to draw into; otherwise passed to
                ``pretty_plot`` to create the figure.
            aspect: aspect ratio forwarded to ``_makeplot``.
            kpath: k-path forwarded to the interpolator.
            cmap: matplotlib colour-map name.
            colorbar: add a colour bar axes to the right of the plot.
            style, no_base_style, fonts: accepted but not used here.

        Returns:
            The axes passed in as ``plt``, or the pyplot object returned by
            ``pretty_plot``.
        """
        interpolater = self._get_interpolater(n_idx, t_idx)

        bs, prop = interpolater.get_line_mode_band_structure(
            line_density=line_density,
            return_other_properties=True,
            kpath=kpath,
            symprec=self.symprec,
        )
        bs, rates = force_branches(bs, {s: p["rates"] for s, p in prop.items()})

        fd_emin, fd_emax = self.fd_cutoffs
        # ``is None`` (not truthiness) so an explicit 0 eV limit is respected
        # instead of being silently replaced by the default cutoff.
        if emin is None:
            emin = fd_emin
            if zero_to_efermi:
                emin -= bs.efermi

        if emax is None:
            emax = fd_emax
            if zero_to_efermi:
                emax -= bs.efermi

        logger.info("Plotting band structure")
        if isinstance(plt, (Axis, SubplotBase)):
            ax = plt
        else:
            plt = pretty_plot(width=width, height=height, plt=plt)
            ax = plt.gca()

        if zero_to_efermi:
            bs.bands = {s: b - bs.efermi for s, b in bs.bands.items()}
            bs.efermi = 0

        bs_plotter = BSPlotter(bs)
        plot_data = bs_plotter.bs_plot_data(zero_to_efermi=zero_to_efermi)

        energies = np.linspace(emin, emax, int((emax - emin) / estep))
        distances = np.array([d for x in plot_data["distances"] for d in x])

        # rates are currently log(rate)
        mesh_data = np.full((len(distances), len(energies)), 0.0)
        for spin in self.spins:
            for spin_energies, spin_rates in zip(bs.bands[spin], rates[spin]):
                for d_idx in range(len(distances)):
                    energy = spin_energies[d_idx]
                    linewidth = 10 ** spin_rates[d_idx] * hbar / 2
                    broadening = lorentzian(energies, energy, linewidth)
                    broadening /= 1000  # convert 1/eV to 1/meV
                    mesh_data[d_idx] += broadening

        im = ax.pcolormesh(
            distances,
            energies,
            mesh_data.T,
            rasterized=True,
            cmap=cmap,
            norm=LogNorm(vmin=amin, vmax=amax),
            shading="auto",
        )
        if colorbar:
            pos = ax.get_position()
            cax = plt.gcf().add_axes([pos.x1 + 0.035, pos.y0, 0.035, pos.height])
            cbar = plt.colorbar(im, cax=cax)
            cbar.ax.tick_params(axis="y", length=rcParams["ytick.major.size"] * 0.5)
            cbar.ax.set_ylabel(
                r"$A_\mathbf{k}$ (meV$^{-1}$)", rotation=270, va="bottom"
            )

        _maketicks(ax, bs_plotter, ylabel=ylabel)
        _makeplot(
            ax,
            plot_data,
            bs,
            zero_to_efermi=zero_to_efermi,
            width=width,
            height=height,
            ymin=emin,
            ymax=emax,
            aspect=aspect,
        )
        return plt
Ejemplo n.º 40
0
from __future__ import division, unicode_literals, print_function

"""
Reads in a KPOINTS file (with labels for the high-symmetry k-points) and a
vasprun.xml file, and plots the band structure along the high-symmetry
k-points.
"""

# To use matplotlib on Hipergator, uncomment the following 2 lines:
# import matplotlib
# matplotlib.use('Agg')

from pymatgen.io.vasp.outputs import Vasprun
from pymatgen.electronic_structure.plotter import BSPlotterProjected, BSPlotter

if __name__ == "__main__":
    # Read the band structure along the labelled KPOINTS path. Projected
    # eigenvalues are parsed as well so the projected plots sketched below
    # can be enabled without re-reading the file.
    run = Vasprun("vasprun.xml", parse_projected_eigen=True)
    bands = run.get_band_structure(
        "KPOINTS",
        line_mode=True,
        efermi=run.efermi,
    )

    # Blue lines are up spin, red lines are down spin
    bsp = BSPlotter(bands)
    bsp.save_plot("band_diagram.eps", ylim=(-5, 5))

    # Projected alternative:
    # bsp = BSPlotterProjected(bands)
    # plt = bsp.get_projected_plots_dots( {'Fe':['s', 'p', 'd'],
    #                                     'Sb':['s', 'p', 'd']})
    # get_elt_projected_plots_color()