Exemplo n.º 1
0
class BSPlotterTest(unittest.TestCase):
    """Tests for ``BSPlotter`` built from the CaO_2605 band-structure fixture."""

    def setUp(self):
        with open(os.path.join(test_dir, "CaO_2605_bandstructure.json"),
                  "r",
                  encoding='utf-8') as f:
            d = json.load(f)
        self.bs = BandStructureSymmLine.from_dict(d)
        self.plotter = BSPlotter(self.bs)

    def test_bs_plot_data(self):
        # Call bs_plot_data() once and check every property on the same
        # result instead of recomputing it for each assertion.
        plot_data = self.plotter.bs_plot_data()
        distances = plot_data['distances']
        labels = plot_data['ticks']['label']

        self.assertEqual(len(distances[0]), 16,
                         "wrong number of distances in the first branch")
        self.assertEqual(len(distances), 10, "wrong number of branches")
        self.assertEqual(sum(len(e) for e in distances), 160,
                         "wrong number of distances")
        self.assertEqual(labels[5], "K", "wrong tick label")
        self.assertEqual(len(labels), 19, "wrong number of tick labels")

    def test_qvertex_target(self):
        results = _qvertex_target(
            [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 1.0, 0.0],
             [0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 1.0],
             [1.0, 1.0, 1.0], [0.0, 1.0, 1.0], [0.5, 0.5, 0.5]], 8)
        self.assertEqual(len(results), 6)
        self.assertEqual(results[3][1], 0.5)
Exemplo n.º 2
0
class BSPlotterTest(unittest.TestCase):
    """Tests for ``BSPlotter`` built from the CaO_2605 band-structure fixture."""

    def setUp(self):
        with open(os.path.join(test_dir, "CaO_2605_bandstructure.json"),
                  "r", encoding='utf-8') as f:
            d = json.load(f)
        self.bs = BandStructureSymmLine.from_dict(d)
        self.plotter = BSPlotter(self.bs)

    def test_bs_plot_data(self):
        # Call bs_plot_data() once and check every property on the same
        # result instead of recomputing it for each assertion.
        plot_data = self.plotter.bs_plot_data()
        distances = plot_data['distances']
        labels = plot_data['ticks']['label']

        self.assertEqual(len(distances[0]), 16,
                         "wrong number of distances in the first branch")
        self.assertEqual(len(distances), 10, "wrong number of branches")
        self.assertEqual(sum(len(e) for e in distances), 160,
                         "wrong number of distances")
        self.assertEqual(labels[5], "K", "wrong tick label")
        self.assertEqual(len(labels), 19, "wrong number of tick labels")

    def test_qvertex_target(self):
        results = _qvertex_target([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0],
                                   [1.0, 1.0, 0.0], [0.0, 1.0, 0.0],
                                   [0.0, 0.0, 1.0], [1.0, 0.0, 1.0],
                                   [1.0, 1.0, 1.0], [0.0, 1.0, 1.0],
                                   [0.5, 0.5, 0.5]], 8)
        self.assertEqual(len(results), 6)
        self.assertEqual(results[3][1], 0.5)
Exemplo n.º 3
0
class BSPlotterTest(unittest.TestCase):
    """Tests for ``BSPlotter`` built from the CaO_2605 band-structure fixture.

    This variant expects ``bs_plot_data()['distances']`` to be a flat list.
    """

    def setUp(self):
        with open(os.path.join(test_dir, "CaO_2605_bandstructure.json"),
                  "rb") as f:
            d = json.loads(f.read())
        self.bs = BandStructureSymmLine.from_dict(d)
        self.plotter = BSPlotter(self.bs)

    def test_bs_plot_data(self):
        # Call bs_plot_data() once and check every property on the same
        # result instead of recomputing it for each assertion.
        plot_data = self.plotter.bs_plot_data()
        labels = plot_data['ticks']['label']

        self.assertEqual(len(plot_data['distances']), 160,
                         "wrong number of distances")
        self.assertEqual(labels[5], "K", "wrong tick label")
        self.assertEqual(len(labels), 19, "wrong number of tick labels")
Exemplo n.º 4
0
def find_dirac_nodes(tol=0.1):
    """
    Look for band crossings near (within `tol` eV) the Fermi level.

    Reads ``vasprun.xml`` and ``KPOINTS`` from the current directory. The
    line-mode band structure is only inspected when the band gap of the
    ``vasprun.xml`` band structure is below ``tol``; otherwise no crossing
    is reported.

    Two spin-up bands count as crossing when, at the same k-point of the
    same branch, at least one of them lies within ``tol`` of the Fermi level
    and their energies differ by less than ``tol``.

    Args:
        tol (float): energy tolerance in eV. Defaults to 0.1, the value that
            was previously hard-coded.

    Returns:
        boolean. Whether or not a band crossing occurs at or near
            the fermi level.
    """
    vasprun = Vasprun('vasprun.xml')
    if vasprun.get_band_structure().get_band_gap()['energy'] >= tol:
        return False

    efermi = vasprun.efermi
    bsp = BSPlotter(vasprun.get_band_structure('KPOINTS', line_mode=True,
                                               efermi=efermi))
    data = bsp.bs_plot_data(zero_to_efermi=True)
    spin_up = str(Spin.up)

    # Bands are compared only within one branch: a k-point index refers to a
    # different k-point in every branch, and branches can differ in length.
    for d, branch_distances in enumerate(data['distances']):
        n_kpoints = len(branch_distances)
        branch = data['energy'][d][spin_up]
        bands = [branch[i][:n_kpoints] for i in range(bsp._nb_bands)]
        # Each unordered pair once; the condition below is symmetric.
        for i, band_i in enumerate(bands):
            for band_j in bands[i + 1:]:
                for e_i, e_j in zip(band_i, band_j):
                    near_fermi = abs(e_i) < tol or abs(e_j) < tol
                    if near_fermi and abs(e_i - e_j) < tol:
                        return True
    return False
Exemplo n.º 5
0
def plot_band_structure(ylim=(-5, 5), draw_fermi=False, fmt='pdf'):
    """
    Plot a standard band structure with no projections.

    Reads ``vasprun.xml`` and ``KPOINTS`` from the current directory.

    Args:
        ylim (tuple): minimum and maximum potentials for the plot's y-axis.
        draw_fermi (bool): whether or not to draw a dashed line at E_F.
        fmt (str): output file format, used as the extension of
            ``band_structure.<fmt>``. The string "None" skips plotting.

    Returns:
        The ``bs_plot_data()`` dict when ``fmt == "None"``; otherwise None
        (the figure is saved and closed).
    """
    vasprun = Vasprun('vasprun.xml')
    efermi = vasprun.efermi
    bsp = BSPlotter(
        vasprun.get_band_structure('KPOINTS', line_mode=True, efermi=efermi))
    if fmt == "None":
        return bsp.bs_plot_data()

    def _as_mathrm(label):
        # Empty labels and labels already in math mode are returned as-is;
        # wrapping "$...$" again would produce invalid mathtext.
        if not label or '$' in label:
            return label
        return r'$\mathrm{%s}$' % label

    plot = bsp.get_plot(ylim=ylim)
    fig = plot.gcf()
    ax = fig.gca()
    # Tick label text is filled in when the figure is drawn, so draw before
    # reading it. get_*ticklabels() returns Text artists; use their string.
    fig.canvas.draw()
    ax.set_xticklabels(
        [_as_mathrm(t.get_text()) for t in ax.get_xticklabels()])
    ax.set_yticklabels(
        [_as_mathrm(t.get_text()) for t in ax.get_yticklabels()])
    if draw_fermi:
        ax.plot([ax.get_xlim()[0], ax.get_xlim()[1]], [0, 0], 'k--')
    plt.savefig('band_structure.{}'.format(fmt), transparent=True)

    plt.close()
Exemplo n.º 6
0
def find_dirac_nodes(tol=0.1):
    """
    Look for band crossings near (within `tol` eV) the Fermi level.

    Reads ``vasprun.xml`` and ``KPOINTS`` from the current directory. The
    line-mode band structure is only inspected when the band gap of the
    ``vasprun.xml`` band structure is below ``tol``; otherwise no crossing
    is reported.

    Two spin-up bands count as crossing when, at the same k-point of the
    same branch, at least one of them lies within ``tol`` of the Fermi level
    and their energies differ by less than ``tol``.

    Args:
        tol (float): energy tolerance in eV. Defaults to 0.1, the value that
            was previously hard-coded.

    Returns:
        boolean. Whether or not a band crossing occurs at or near
            the fermi level.
    """
    vasprun = Vasprun('vasprun.xml')
    if vasprun.get_band_structure().get_band_gap()['energy'] >= tol:
        return False

    efermi = vasprun.efermi
    bsp = BSPlotter(vasprun.get_band_structure('KPOINTS', line_mode=True,
                                               efermi=efermi))
    data = bsp.bs_plot_data(zero_to_efermi=True)
    spin_up = str(Spin.up)

    # Bands are compared only within one branch: a k-point index refers to a
    # different k-point in every branch, and branches can differ in length.
    for d, branch_distances in enumerate(data['distances']):
        n_kpoints = len(branch_distances)
        branch = data['energy'][d][spin_up]
        bands = [branch[i][:n_kpoints] for i in range(bsp._nb_bands)]
        # Each unordered pair once; the condition below is symmetric.
        for i, band_i in enumerate(bands):
            for band_j in bands[i + 1:]:
                for e_i, e_j in zip(band_i, band_j):
                    near_fermi = abs(e_i) < tol or abs(e_j) < tol
                    if near_fermi and abs(e_i - e_j) < tol:
                        return True
    return False
Exemplo n.º 7
0
    def make_band_plot_info(self):
        """
        Assemble a ``BandPlotInfo`` describing ``self.bs``.

        When ``self.vasprun2`` is set, a second line-mode band structure is
        read from it (using ``self.kpoints_filename``) and appended as a
        second ``BandInfo``. Also stores the final-structure composition of
        ``self.vasprun`` on ``self._composition``.

        Returns:
            BandPlotInfo: band info set, per-branch distances, x ticks from
            ``get_ticks_old`` and ``self._title``.
        """
        plotter = BSPlotter(self.bs)
        primary_data = plotter.bs_plot_data(zero_to_efermi=False)
        branch_distances = [list(branch) for branch in primary_data["distances"]]
        self._composition = self.vasprun.final_structure.composition

        def to_band_info(band_structure, data):
            # NOTE(review): every BandInfo uses self.bs.efermi as its Fermi
            # level, including the one for vasprun2 — confirm this shared
            # reference is intended.
            return BandInfo(band_energies=self._remove_spin_key(data),
                            band_edge=self._band_edge(band_structure, data),
                            fermi_level=self.bs.efermi)

        band_info = [to_band_info(self.bs, primary_data)]

        if self.vasprun2:
            second_bs = self.vasprun2.get_band_structure(self.kpoints_filename,
                                                         line_mode=True)
            second_data = BSPlotter(second_bs).bs_plot_data(zero_to_efermi=False)
            band_info.append(to_band_info(second_bs, second_data))

        ticks = plotter.get_ticks_old()
        x_ticks = XTicks(_sanitize_labels(ticks["label"]), ticks["distance"])

        return BandPlotInfo(band_info_set=band_info,
                            distances_by_branch=branch_distances,
                            x_ticks=x_ticks,
                            title=self._title)
Exemplo n.º 8
0
class BSPlotterTest(unittest.TestCase):
    """Tests for ``BSPlotter`` built from the CaO_2605 band-structure fixture.

    This variant expects ``bs_plot_data()['distances']`` to be a flat list.
    """

    def setUp(self):
        with open(os.path.join(test_dir, "CaO_2605_bandstructure.json"),
                  "rb") as f:
            d = json.loads(f.read())
        self.bs = BandStructureSymmLine.from_dict(d)
        self.plotter = BSPlotter(self.bs)

    def test_bs_plot_data(self):
        # Call bs_plot_data() once and check every property on the same
        # result instead of recomputing it for each assertion.
        plot_data = self.plotter.bs_plot_data()
        labels = plot_data['ticks']['label']

        self.assertEqual(len(plot_data['distances']), 160,
                         "wrong number of distances")
        self.assertEqual(labels[5], "K", "wrong tick label")
        self.assertEqual(len(labels), 19, "wrong number of tick labels")
Exemplo n.º 9
0
class BSPlotterTest(unittest.TestCase):
    """Tests for ``BSPlotter`` built from the CaO_2605 band-structure fixture."""

    def setUp(self):
        with open(os.path.join(test_dir, "CaO_2605_bandstructure.json"),
                  "r",
                  encoding='utf-8') as f:
            d = json.load(f)
        self.bs = BandStructureSymmLine.from_dict(d)
        self.plotter = BSPlotter(self.bs)
        # Silence warnings for this test only. catch_warnings() snapshots the
        # current filter list so tearDown restores it exactly, rather than
        # overwriting filters installed by the test runner.
        self._warnings_ctx = warnings.catch_warnings()
        self._warnings_ctx.__enter__()
        warnings.simplefilter("ignore")

    def tearDown(self):
        self._warnings_ctx.__exit__(None, None, None)

    @staticmethod
    def _remove_if_exists(path):
        """Delete ``path`` if it is an existing file."""
        if os.path.isfile(path):
            os.remove(path)

    def test_bs_plot_data(self):
        # Call bs_plot_data() once and check every property on the same
        # result instead of recomputing it for each assertion.
        plot_data = self.plotter.bs_plot_data()
        distances = plot_data['distances']
        labels = plot_data['ticks']['label']

        self.assertEqual(len(distances[0]), 16,
                         "wrong number of distances in the first branch")
        self.assertEqual(len(distances), 10, "wrong number of branches")
        self.assertEqual(sum(len(e) for e in distances), 160,
                         "wrong number of distances")
        self.assertEqual(labels[5], "K", "wrong tick label")
        self.assertEqual(len(labels), 19, "wrong number of tick labels")

    # Minimal baseline testing for get_plot. not a true test. Just checks that
    # it can actually execute.
    def test_get_plot(self):
        # zero_to_efermi = True, ylim = None, smooth = False,
        # vbm_cbm_marker = False, smooth_tol = None
        from matplotlib import rc_context

        plot_path = "bsplot.png"
        # Registered before saving so the file is removed even when the
        # assertion below fails.
        self.addCleanup(self._remove_if_exists, plot_path)

        # Disabling latex is needed for this test to work; rc_context
        # restores the previous rcParams instead of changing them globally.
        with rc_context({"text.usetex": False}):
            self.plotter.get_plot()
            self.plotter.get_plot(smooth=True)
            plot = self.plotter.get_plot(vbm_cbm_marker=True)
            self.plotter.save_plot(plot_path)
            plot.close("all")
        self.assertTrue(os.path.isfile(plot_path))
Exemplo n.º 10
0
def banddos(pref='', storedir=None):
    """
    Plot the band structure from ``vasprun.xml``/``KPOINTS`` and save it as
    ``BAND.png`` in the current directory.

    Energies are shifted so that the Fermi level is at 0. Spin-up bands are
    drawn as solid blue lines; for spin-polarized runs the spin-down bands
    are drawn as dashed red lines. The CBM and VBM points reported by
    ``bs_plot_data`` are marked with red and green circles respectively.
    The band gap (rounded to 3 decimals) and the total path length are
    printed to stdout.

    Args:
        pref (str): currently unused.
        storedir: currently unused.
    """
    run = Vasprun("vasprun.xml", parse_projected_eigen=True)
    bands = run.get_band_structure("KPOINTS", line_mode=True,
                                   efermi=run.efermi)
    bsp = BSPlotter(bands)
    zero_to_efermi = True
    bandgap = str(round(bands.get_band_gap()['energy'], 3))
    # Single-argument form prints the same text under Python 2 and 3.
    print("bg= " + bandgap)
    data = bsp.bs_plot_data(zero_to_efermi)
    plt = get_publication_quality_plot(12, 8)
    band_linewidth = 3
    x_max = data['distances'][-1][-1]
    print(x_max)

    for d, distances in enumerate(data['distances']):
        n_kpoints = len(distances)
        for i in range(bsp._nb_bands):
            plt.plot(distances,
                     [data['energy'][d]['1'][i][j] for j in range(n_kpoints)],
                     'b-', linewidth=band_linewidth)
            if bsp._bs.is_spin_polarized:
                plt.plot(distances,
                         [data['energy'][d]['-1'][i][j]
                          for j in range(n_kpoints)],
                         'r--', linewidth=band_linewidth)
    bsp._maketicks(plt)

    # CBM and VBM markers are independent; the VBM loop used to be nested
    # inside the CBM loop (duplicating VBM dots, or dropping them when no
    # CBM was reported).
    for cbm in data['cbm']:
        plt.scatter(cbm[0], cbm[1], color='r', marker='o', s=100)
    for vbm in data['vbm']:
        plt.scatter(vbm[0], vbm[1], color='g', marker='o', s=100)

    plt.xlabel(r'$\mathrm{Wave\ Vector}$', fontsize=30)
    ylabel = (r'$\mathrm{E\ -\ E_f\ (eV)}$' if zero_to_efermi
              else r'$\mathrm{Energy\ (eV)}$')
    plt.ylabel(ylabel, fontsize=30)
    plt.ylim(-4, 4)
    plt.xlim(0, x_max)
    plt.tight_layout()
    # savefig takes ``format``; ``img_format`` is not a savefig keyword.
    plt.savefig('BAND.png', format="png")

    plt.close()
Exemplo n.º 11
0
class BSPlotterTest(unittest.TestCase):
    """Tests for ``BSPlotter`` built from the CaO_2605 band-structure fixture."""

    def setUp(self):
        with open(os.path.join(test_dir, "CaO_2605_bandstructure.json"),
                  "r", encoding='utf-8') as f:
            d = json.load(f)
        self.bs = BandStructureSymmLine.from_dict(d)
        self.plotter = BSPlotter(self.bs)
        # Silence warnings for this test only. catch_warnings() snapshots the
        # current filter list so tearDown restores it exactly;
        # resetwarnings() would also discard filters set by the test runner.
        self._warnings_ctx = warnings.catch_warnings()
        self._warnings_ctx.__enter__()
        warnings.simplefilter("ignore")

    def tearDown(self):
        self._warnings_ctx.__exit__(None, None, None)

    @staticmethod
    def _remove_if_exists(path):
        """Delete ``path`` if it is an existing file."""
        if os.path.isfile(path):
            os.remove(path)

    def test_bs_plot_data(self):
        # Call bs_plot_data() once and check every property on the same
        # result instead of recomputing it for each assertion.
        plot_data = self.plotter.bs_plot_data()
        distances = plot_data['distances']
        labels = plot_data['ticks']['label']

        self.assertEqual(len(distances[0]), 16,
                         "wrong number of distances in the first branch")
        self.assertEqual(len(distances), 10, "wrong number of branches")
        self.assertEqual(sum(len(e) for e in distances), 160,
                         "wrong number of distances")
        self.assertEqual(labels[5], "K", "wrong tick label")
        self.assertEqual(len(labels), 19, "wrong number of tick labels")

    # Minimal baseline testing for get_plot. not a true test. Just checks that
    # it can actually execute.
    def test_get_plot(self):
        # zero_to_efermi = True, ylim = None, smooth = False,
        # vbm_cbm_marker = False, smooth_tol = None
        from matplotlib import rc_context

        plot_path = "bsplot.png"
        # Registered before saving so the file is removed even when the
        # assertion below fails.
        self.addCleanup(self._remove_if_exists, plot_path)

        # Disabling latex is needed for this test to work; rc_context
        # restores the previous rcParams instead of changing them globally.
        with rc_context({"text.usetex": False}):
            self.plotter.get_plot()
            self.plotter.get_plot(smooth=True)
            plot = self.plotter.get_plot(vbm_cbm_marker=True)
            self.plotter.save_plot(plot_path)
            # Close the figures so they do not leak into later tests.
            plot.close("all")
        self.assertTrue(os.path.isfile(plot_path))
Exemplo n.º 12
0
    def get_bandsxy(self, bs, bandrange):
        """
        Return plotting coordinates for the bands listed in ``bandrange``.

        Args:
            bs: band structure; must expose ``bands``, ``efermi`` and
                ``is_spin_polarized``.
            bandrange: iterable of band indices to extract.

        Returns:
            list: ``[x, yu, yd]`` where ``x`` is the concatenation of all
            per-branch distances from ``BSPlotter(bs).bs_plot_data()``,
            ``yu`` holds the spin-up energies minus ``bs.efermi`` (one list
            per band), and ``yd`` is the same for spin down, or ``None``
            when ``bs`` is not spin polarized.
        """
        plot_data = BSPlotter(bs).bs_plot_data()

        # Flatten the per-branch distance lists into one x axis.
        x = []
        for branch in plot_data["distances"]:
            x.extend(branch)

        def shifted(spin):
            return [[energy - bs.efermi for energy in bs.bands[spin][band]]
                    for band in bandrange]

        yu = shifted(Spin.up)
        yd = shifted(Spin.down) if bs.is_spin_polarized else None

        return [x, yu, yd]
Exemplo n.º 13
0
    def get_plot(
        self,
        n_idx,
        t_idx,
        zero_to_efermi=True,
        estep=0.01,
        line_density=100,
        height=3.2,
        width=3.2,
        emin=None,
        emax=None,
        amin=5e-5,
        amax=1e-1,
        ylabel="Energy (eV)",
        plt=None,
        aspect=None,
        kpath=None,
        cmap="viridis",
        colorbar=True,
        style=None,
        no_base_style=False,
        fonts=None,
    ):
        """
        Plot the Lorentzian-broadened band structure along a k-point path.

        Each interpolated band is broadened at every k-point by a Lorentzian
        whose width is ``10 ** rate * hbar / 2`` (the rates are stored as
        log10 values). The summed intensity is drawn with ``pcolormesh`` on a
        logarithmic colour scale.

        Args:
            n_idx, t_idx: indices forwarded to ``self._get_interpolater``.
            zero_to_efermi (bool): shift energies so the Fermi level is 0.
            estep (float): energy grid spacing.
            line_density: forwarded to ``get_line_mode_band_structure``.
            height, width: figure size passed to ``pretty_plot``/``_makeplot``.
            emin, emax: energy limits. When ``None`` they default to
                ``self.fd_cutoffs`` (shifted by the Fermi level if
                ``zero_to_efermi``). An explicit 0 is honoured.
            amin, amax: colour-scale limits for ``LogNorm``.
            ylabel (str): y-axis label.
            plt: existing pyplot module, or an ``Axis``/``SubplotBase`` to
                draw into; otherwise a new plot is created by ``pretty_plot``.
            aspect: forwarded to ``_makeplot``.
            kpath: forwarded to ``get_line_mode_band_structure``.
            cmap (str): colormap name.
            colorbar (bool): add a colour bar to the right of the axes.
            style, no_base_style, fonts: not used by this method body.

        Returns:
            ``plt`` — the object passed in if it was an Axis/SubplotBase,
            otherwise the pyplot object returned by ``pretty_plot``.
        """
        interpolater = self._get_interpolater(n_idx, t_idx)

        bs, prop = interpolater.get_line_mode_band_structure(
            line_density=line_density,
            return_other_properties=True,
            kpath=kpath,
            symprec=self.symprec,
        )
        bs, rates = force_branches(bs, {s: p["rates"] for s, p in prop.items()})

        # Fall back to the cutoffs only when no limit was given; ``not emin``
        # would also discard a deliberate limit of 0.
        fd_emin, fd_emax = self.fd_cutoffs
        if emin is None:
            emin = fd_emin
            if zero_to_efermi:
                emin -= bs.efermi

        if emax is None:
            emax = fd_emax
            if zero_to_efermi:
                emax -= bs.efermi

        logger.info("Plotting band structure")
        if isinstance(plt, (Axis, SubplotBase)):
            ax = plt
        else:
            plt = pretty_plot(width=width, height=height, plt=plt)
            ax = plt.gca()

        if zero_to_efermi:
            bs.bands = {s: b - bs.efermi for s, b in bs.bands.items()}
            bs.efermi = 0

        bs_plotter = BSPlotter(bs)
        plot_data = bs_plotter.bs_plot_data(zero_to_efermi=zero_to_efermi)

        energies = np.linspace(emin, emax, int((emax - emin) / estep))
        distances = np.array([d for x in plot_data["distances"] for d in x])

        # rates are stored as log10(rate); 10 ** converts them back.
        mesh_data = np.full((len(distances), len(energies)), 0.0)
        for spin in self.spins:
            for spin_energies, spin_rates in zip(bs.bands[spin], rates[spin]):
                for d_idx in range(len(distances)):
                    energy = spin_energies[d_idx]
                    linewidth = 10 ** spin_rates[d_idx] * hbar / 2
                    broadening = lorentzian(energies, energy, linewidth)
                    broadening /= 1000  # convert 1/eV to 1/meV
                    mesh_data[d_idx] += broadening

        im = ax.pcolormesh(
            distances,
            energies,
            mesh_data.T,
            rasterized=True,
            cmap=cmap,
            norm=LogNorm(vmin=amin, vmax=amax),
            shading="auto",
        )
        if colorbar:
            # Use the axes' own figure: when an Axes was passed in as ``plt``
            # it has no gcf()/colorbar(), so the pyplot calls would fail.
            fig = ax.figure
            pos = ax.get_position()
            cax = fig.add_axes([pos.x1 + 0.035, pos.y0, 0.035, pos.height])
            cbar = fig.colorbar(im, cax=cax)
            cbar.ax.tick_params(axis="y", length=rcParams["ytick.major.size"] * 0.5)
            cbar.ax.set_ylabel(
                r"$A_\mathbf{k}$ (meV$^{-1}$)", rotation=270, va="bottom"
            )

        _maketicks(ax, bs_plotter, ylabel=ylabel)
        _makeplot(
            ax,
            plot_data,
            bs,
            zero_to_efermi=zero_to_efermi,
            width=width,
            height=height,
            ymin=emin,
            ymax=emax,
            aspect=aspect,
        )
        return plt
Exemplo n.º 14
0
    def get_bandstructure_traces(bs,
                                 path_convention,
                                 energy_window=(-6.0, 10.0)):
        """
        Build Plotly trace dictionaries for a band-structure plot.

        Args:
            bs: band structure to plot. When ``path_convention == "lm"`` it is
                first converted with ``HighSymmKpath.get_continuous_path``.
            path_convention (str): k-path convention; only ``"lm"`` triggers
                the conversion above.
            energy_window (tuple): (min, max) energy range. A band gets traces
                when, in at least one segment, its spin-up energies reach
                below the max and above the min. The window is also the
                y-extent of the vertical segment separators.

        Returns:
            tuple: ``(bstraces, bs_data)`` — a list of Plotly scatter dicts
            (band lines, segment separators, CBM/VBM markers) and the
            ``bs_plot_data(split_branches=False)`` output with LaTeX stripped
            from its tick labels.
        """
        if path_convention == "lm":
            bs = HighSymmKpath.get_continuous_path(bs)

        bs_reg_plot = BSPlotter(bs)

        bs_data = bs_reg_plot.bs_plot_data(split_branches=False)
        up_energies = bs_data["energy"][str(Spin.up)]

        # Each band appears at most once, in ascending order. Previously a
        # band was appended once per overlapping segment, duplicating traces.
        bands = [
            band_num for band_num in range(bs.nb_bands)
            if any(
                any(segment[band_num] <= energy_window[1])
                and any(segment[band_num] >= energy_window[0])
                for segment in up_energies)
        ]

        bstraces = []

        cbm = bs.get_cbm()
        vbm = bs.get_vbm()

        cbm_new = bs_data["cbm"]
        vbm_new = bs_data["vbm"]

        # x positions of the segment ends, used for the vertical separators.
        bar_loc = []

        for d, dist_val in enumerate(bs_data["distances"]):
            segment = up_energies[d]

            bstraces += [{
                "x": dist_val,
                "y": segment[band_num],
                "mode": "lines",
                "line": {
                    "color": "#1f77b4"
                },
                "hoverinfo": "skip",
                "name": "spin ↑" if bs.is_spin_polarized else "Total",
                "hovertemplate": "%{y:.2f} eV",
                "showlegend": False,
                "xaxis": "x",
                "yaxis": "y",
            } for band_num in bands]

            if bs.is_spin_polarized:
                down_segment = bs_data["energy"][str(Spin.down)][d]
                bstraces += [{
                    "x": dist_val,
                    "y": [down_segment[i][j] for j in range(len(dist_val))],
                    "mode": "lines",
                    "line": {
                        "color": "#ff7f0e",
                        "dash": "dot"
                    },
                    "hoverinfo": "skip",
                    "showlegend": False,
                    "name": "spin ↓",
                    "hovertemplate": "%{y:.2f} eV",
                    "xaxis": "x",
                    "yaxis": "y",
                } for i in bands]

            bar_loc.append(dist_val[-1])

        # - Strip latex math wrapping for labels
        str_replace = {
            "$": "",
            "\\mid": "|",
            "\\Gamma": "Γ",
            "\\Sigma": "Σ",
            "GAMMA": "Γ",
            "_1": "₁",
            "_2": "₂",
            "_3": "₃",
            "_4": "₄",
            "_{1}": "₁",
            "_{2}": "₂",
            "_{3}": "₃",
            "_{4}": "₄",
            "^{*}": "*",
        }

        labels = bs_data["ticks"]["label"]
        for entry_num, label in enumerate(labels):
            # Replacements are applied in dict order, as before.
            for key, replacement in str_replace.items():
                label = label.replace(key, replacement)
            labels[entry_num] = label

        # Vertical lines for disjointed segments
        vert_traces = [{
            "x": [x_point, x_point],
            "y": energy_window,
            "mode": "lines",
            "marker": {
                "color": "white"
            },
            "hoverinfo": "skip",
            "showlegend": False,
            "xaxis": "x",
            "yaxis": "y",
        } for x_point in bar_loc]

        bstraces += vert_traces

        # Dots for cbm and vbm

        def _edge_markers(points, template, edge):
            # The hover text is only built when there is a point to mark, so
            # ``edge`` is read only in that case (same as before).
            unique_points = set(points)
            if not unique_points:
                return []
            hovertemplate = template.format(list(edge["kpoint"].frac_coords),
                                            edge["energy"])
            return [{
                "x": [x_point],
                "y": [y_point],
                # Plotly's scatter mode is "markers"; "marker" is invalid.
                "mode": "markers",
                "marker": {
                    "color": "#7E259B",
                    "size": 16,
                    "line": {
                        "color": "white",
                        "width": 2
                    },
                },
                "showlegend": False,
                "hoverinfo": "text",
                "name": "",
                "hovertemplate": hovertemplate,
                "xaxis": "x",
                "yaxis": "y",
            } for (x_point, y_point) in unique_points]

        dot_traces = (_edge_markers(cbm_new, "CBM: k = {}, {} eV", cbm)
                      + _edge_markers(vbm_new, "VBM: k = {}, {} eV", vbm))

        bstraces += dot_traces

        return bstraces, bs_data
Exemplo n.º 15
0
# Script section: build DOS and band-structure plots with pymatgen plotters.
# NOTE(review): tdos, idos, bs and cdos are not defined in this excerpt; they
# must be created earlier in the script.
dosplotter = DosPlotter()
# The return values of add_dos are stored under these names.
# NOTE(review): presumably None — confirm before relying on these names.
Totaldos = dosplotter.add_dos('Total DOS', tdos)
Integrateddos = dosplotter.add_dos('Integrated DOS', idos)
#Pdos = dosplotter.add_dos('Partial DOS',pdos)
#Spd_dos =  dosplotter.add_dos('spd DOS',spd_dos)
#Element_dos = dosplotter.add_dos('Element DOS',element_dos)
#Element_spd_dos = dosplotter.add_dos('Element_spd DOS',element_spd_dos)
# The same two DOS objects, under the same labels as the add_dos calls above,
# are passed again via add_dos_dict.
# NOTE(review): looks redundant with the add_dos calls — confirm intent.
dos_dict = {
    'Total DOS': tdos,
    'Integrated DOS': idos
}  #'Partial DOS':pdos,'spd DOS':spd_dos,'Element DOS':element_dos}#'Element_spd DOS':element_spd_dos
add_dos_dict = dosplotter.add_dos_dict(dos_dict)
get_dos_dict = dosplotter.get_dos_dict()
dos_plot = dosplotter.get_plot()
##dosplotter.save_plot("MAPbI3_dos",img_format="png")
##dos_plot.show()
# Band-structure plot data, figure, and tick positions/labels.
bsplotter = BSPlotter(bs)
bs_plot_data = bsplotter.bs_plot_data()
bs_plot = bsplotter.get_plot()
#bsplotter.save_plot("MAPbI3_bs",img_format="png")
#bsplotter.show()
ticks = bsplotter.get_ticks()
# NOTE: Python 2 print statement; this line is a syntax error under Python 3.
print ticks
bsplotter.plot_brillouin()
# Combined band-structure + DOS figure with orbital-projected DOS.
bsdos = BSDOSPlotter(
    tick_fontsize=10,
    egrid_interval=20,
    dos_projection="orbitals",
    bs_legend=None)  #bs_projection="HPbCIN",dos_projection="HPbCIN")
bds = bsdos.get_plot(bs, cdos)
Exemplo n.º 16
0
        def bs_dos_traces(bandStructureSymmLine, densityOfStates):
            """
            Build Plotly traces for a combined band-structure / DOS figure.

            Args:
                bandStructureSymmLine: serialized band structure (dict for
                    ``BSML.from_dict``), ``None``, or the string ``"error"``.
                densityOfStates: serialized ``CompleteDos`` dict, ``None``,
                    or the string ``"error"``.

            Returns:
                ``"error"`` if either input is ``"error"``; otherwise
                ``[bstraces, dostraces, bs_data]``.

            Raises:
                PreventUpdate: if either input is ``None``.
            """
            if bandStructureSymmLine == "error" or densityOfStates == "error":
                return "error"

            if bandStructureSymmLine is None or densityOfStates is None:
                raise PreventUpdate

            # - BS Data
            bstraces = []

            bs_reg_plot = BSPlotter(BSML.from_dict(bandStructureSymmLine))

            bs_data = bs_reg_plot.bs_plot_data()

            # -- Strip latex math wrapping
            str_replace = {
                "$": "",
                "\\mid": "|",
                "\\Gamma": "Γ",
                "\\Sigma": "Σ",
                "_1": "₁",
                "_2": "₂",
                "_3": "₃",
                "_4": "₄",
            }

            labels = bs_data["ticks"]["label"]
            for entry_num, label in enumerate(labels):
                # Replacements are applied in dict order, as before.
                for key, replacement in str_replace.items():
                    label = label.replace(key, replacement)
                labels[entry_num] = label

            for d, distances in enumerate(bs_data["distances"]):
                n_kpoints = len(distances)
                for i in range(bs_reg_plot._nb_bands):
                    bstraces.append(
                        go.Scatter(
                            x=distances,
                            y=[
                                bs_data["energy"][d][str(Spin.up)][i][j]
                                for j in range(n_kpoints)
                            ],
                            mode="lines",
                            line=dict(color=("#666666"), width=2),
                            hoverinfo="skip",
                            showlegend=False,
                        ))

                    if bs_reg_plot._bs.is_spin_polarized:
                        bstraces.append(
                            go.Scatter(
                                x=distances,
                                y=[
                                    bs_data["energy"][d][str(Spin.down)][i][j]
                                    for j in range(n_kpoints)
                                ],
                                mode="lines",
                                line=dict(color=("#666666"),
                                          width=2,
                                          dash="dash"),
                                hoverinfo="skip",
                                showlegend=False,
                            ))

            # -- DOS Data
            dostraces = []

            dos = CompleteDos.from_dict(densityOfStates)

            if Spin.down in dos.densities:
                # Add second spin data if available
                trace_tdos = go.Scatter(
                    x=dos.densities[Spin.down],
                    y=dos.energies - dos.efermi,
                    mode="lines",
                    name="Total DOS (spin ↓)",
                    line=go.scatter.Line(color="#444444", dash="dash"),
                    fill="tozeroy",
                )

                dostraces.append(trace_tdos)

                tdos_label = "Total DOS (spin ↑)"
            else:
                tdos_label = "Total DOS"

            # Total DOS
            trace_tdos = go.Scatter(
                x=dos.densities[Spin.up],
                y=dos.energies - dos.efermi,
                mode="lines",
                name=tdos_label,
                line=go.scatter.Line(color="#444444"),
                fill="tozeroy",
                legendgroup="spinup",
            )

            dostraces.append(trace_tdos)

            p_ele_dos = dos.get_element_dos()

            # Projected DOS
            colors = [
                "#1f77b4",  # muted blue
                "#ff7f0e",  # safety orange
                "#2ca02c",  # cooked asparagus green
                "#d62728",  # brick red
                "#9467bd",  # muted purple
                "#8c564b",  # chestnut brown
                "#e377c2",  # raspberry yogurt pink
                "#bcbd22",  # curry yellow-green
                "#17becf",  # blue-teal
            ]

            for count, ele in enumerate(p_ele_dos.keys()):
                # Cycle the palette so more than len(colors) elements do not
                # raise IndexError.
                color = colors[count % len(colors)]
                ele_densities = p_ele_dos[ele].densities

                # Decide on the spin-down channel from the DOS data itself
                # (as for the total DOS above), not from the band structure,
                # so a missing Spin.down entry cannot raise KeyError.
                if Spin.down in ele_densities:
                    trace = go.Scatter(
                        x=ele_densities[Spin.down],
                        y=dos.energies - dos.efermi,
                        mode="lines",
                        name=ele.symbol + " (spin ↓)",
                        line=dict(width=3, color=color, dash="dash"),
                    )

                    dostraces.append(trace)
                    spin_up_label = ele.symbol + " (spin ↑)"

                else:
                    spin_up_label = ele.symbol

                trace = go.Scatter(
                    x=ele_densities[Spin.up],
                    y=dos.energies - dos.efermi,
                    mode="lines",
                    name=spin_up_label,
                    line=dict(width=3, color=color),
                )

                dostraces.append(trace)

            traces = [bstraces, dostraces, bs_data]

            return traces
Exemplo n.º 17
0
    def get_plot(
        self,
        n_idx,
        t_idx,
        zero_to_efermi=True,
        estep=0.01,
        line_density=100,
        height=6,
        width=6,
        emin=None,
        emax=None,
        ylabel="Energy (eV)",
        plt=None,
        aspect=None,
        distance_factor=10,
        kpath=None,
        style=None,
        no_base_style=False,
        fonts=None,
    ):
        """Plot the band structure with each band broadened by its linewidth.

        The band structure for ``(n_idx, t_idx)`` is interpolated along a
        line-mode k-path. The per-k-point rates returned alongside it are
        clipped, smoothed with a Savitzky-Golay filter and converted to a
        linewidth ``10**rate * hbar / 2``. Each band is then drawn as a
        Lorentzian of that width, and the result is rendered as a
        log-scaled ``pcolormesh``.

        Args:
            n_idx: First index forwarded to ``self._get_interpolater``
                (presumably the doping index -- confirm).
            t_idx: Second index forwarded to ``self._get_interpolater``
                (presumably the temperature index -- confirm).
            zero_to_efermi: Shift energies so the Fermi level sits at 0 eV.
            estep: Approximate energy-grid spacing in eV.
            line_density: k-point density of the interpolated path.
            height: Figure height passed to ``pretty_plot``.
            width: Figure width passed to ``pretty_plot``.
            emin: Lower energy bound in eV. ``None`` uses the lower
                Fermi-Dirac cutoff; an explicit 0.0 is honoured.
            emax: Upper energy bound in eV. ``None`` uses the upper
                Fermi-Dirac cutoff; an explicit 0.0 is honoured.
            ylabel: Y-axis label.
            plt: Existing pyplot-like object to draw into, or ``None``.
            aspect: Aspect ratio forwarded to ``_makeplot``.
            distance_factor: Ratio of interpolated to original k-distances.
            kpath: Optional k-path forwarded to the interpolator.
            style: Accepted for API compatibility; currently unused.
            no_base_style: Accepted for API compatibility; currently unused.
            fonts: Accepted for API compatibility; currently unused.

        Returns:
            The object returned by ``pretty_plot``, with the plot drawn on it.
        """
        interpolater = self._get_interpolater(n_idx, t_idx)

        bs, prop = interpolater.get_line_mode_band_structure(
            line_density=line_density,
            return_other_properties=True,
            kpath=kpath)

        # Default the window to the Fermi-Dirac cutoffs (stored in Hartree).
        # ``is None`` (not truthiness) so an explicit 0.0 eV bound is kept.
        fd_emin, fd_emax = self.fd_cutoffs
        if emin is None:
            emin = fd_emin * hartree_to_ev
            if zero_to_efermi:
                emin -= bs.efermi

        if emax is None:
            emax = fd_emax * hartree_to_ev
            if zero_to_efermi:
                emax -= bs.efermi

        logger.info("Plotting band structure")
        plt = pretty_plot(width=width, height=height, plt=plt)
        ax = plt.gca()

        if zero_to_efermi:
            bs.bands = {s: b - bs.efermi for s, b in bs.bands.items()}
            bs.efermi = 0

        bs_plotter = BSPlotter(bs)
        plot_data = bs_plotter.bs_plot_data(zero_to_efermi=zero_to_efermi)

        energies = np.linspace(emin, emax, int((emax - emin) / estep))
        distances = np.array([d for x in plot_data["distances"] for d in x])

        # Rates are stored as log10(rate): clamp non-positive values to the
        # smallest positive one and cap the exponent at 15.
        rates = {}
        for spin, spin_data in prop.items():
            rates[spin] = spin_data["rates"]
            rates[spin][rates[spin] <= 0] = np.min(
                rates[spin][rates[spin] > 0])
            rates[spin][rates[spin] >= 15] = 15

        interp_distances = np.linspace(distances.min(), distances.max(),
                                       int(len(distances) * distance_factor))

        # savgol_filter needs an odd window; this always yields an odd value.
        window = np.min([len(distances) - 2, 71])
        window += window % 2 + 1
        mesh_data = np.full((len(interp_distances), len(energies)), 1e-2)

        for spin in self.spins:
            for spin_energies, spin_rates in zip(bs.bands[spin], rates[spin]):
                interp_energies = interp1d(distances,
                                           spin_energies)(interp_distances)
                spin_rates = savgol_filter(spin_rates, window, 3)
                interp_rates = interp1d(distances,
                                        spin_rates)(interp_distances)
                linewidths = 10**interp_rates * hbar / 2

                for d_idx in range(len(interp_distances)):
                    energy = interp_energies[d_idx]
                    linewidth = linewidths[d_idx]

                    # Keep the strongest contribution of any band at each pixel.
                    broadening = lorentzian(energies, energy, linewidth)
                    mesh_data[d_idx] = np.maximum(broadening, mesh_data[d_idx])

        ax.pcolormesh(
            interp_distances,
            energies,
            mesh_data.T,
            rasterized=True,
            norm=LogNorm(vmin=mesh_data.min(), vmax=mesh_data.max()),
        )

        _maketicks(ax, bs_plotter, ylabel=ylabel)
        _makeplot(
            ax,
            plot_data,
            bs,
            zero_to_efermi=zero_to_efermi,
            width=width,
            height=height,
            ymin=emin,
            ymax=emax,
            aspect=aspect,
        )
        return plt
Exemplo n.º 18
0
        def bs_dos_data(
            mpid,
            path_convention,
            dos_select,
            label_select,
            bandstructure_symm_line,
            density_of_states,
        ):
            """Build plotly traces for a band structure and its DOS.

            If either ``bandstructure_symm_line`` or ``density_of_states`` is
            missing, both are fetched from the ``fw_bs_prod`` database, using
            ``mpid`` as the task id. Otherwise the objects (or their dict
            forms) passed in are used directly.

            Args:
                mpid: Task id (castable to ``int``) used for DB lookups.
                path_convention: Key of the k-path convention under
                    ``bandstructure`` in the DB document. ``"lm"`` makes
                    the plotted k-path continuous.
                dos_select: ``"ap"`` uses ``get_element_dos``, ``"op"`` uses
                    ``get_spd_dos``, and ``"<Element>orb"`` uses
                    ``get_element_spd_dos`` for that element.
                label_select: Labelling convention for high-symmetry points.
                    When it differs from ``path_convention``, fetched labels
                    are translated through ``equiv_labels``.
                bandstructure_symm_line: Band structure object, dict, or
                    ``None``.
                density_of_states: ``CompleteDos`` object, dict, or ``None``.

            Returns:
                ``([bstraces, dostraces, bs_data], elements)``.

            Raises:
                PreventUpdate: If there is nothing to plot, ``label_select``
                    is empty when fetching, or ``dos_select`` is unknown.
            """
            if not mpid and (bandstructure_symm_line is None
                             or density_of_states is None):
                raise PreventUpdate

            if bandstructure_symm_line is None or density_of_states is None:
                if label_select == "":
                    raise PreventUpdate

                # --
                # -- BS and DOS from API or DB
                # --

                # NOTE(review): connection details are hard-coded; consider
                # moving them to configuration.
                bs_store = GridFSStore(
                    database="fw_bs_prod",
                    collection_name="bandstructure_fs",
                    host="mongodb03.nersc.gov",
                    port=27017,
                    username="******",
                    password="",
                )

                dos_store = GridFSStore(
                    database="fw_bs_prod",
                    collection_name="dos_fs",
                    host="mongodb03.nersc.gov",
                    port=27017,
                    username="******",
                    password="",
                )

                es_store = MongoStore(
                    database="fw_bs_prod",
                    collection_name="electronic_structure",
                    host="mongodb03.nersc.gov",
                    port=27017,
                    username="******",
                    password="",
                    key="task_id",
                )

                # - BS traces from DB using task_id
                es_store.connect()
                try:
                    bs_query = es_store.query_one(
                        criteria={"task_id": int(mpid)},
                        properties=[
                            "bandstructure.{}.task_id".format(path_convention),
                            "bandstructure.{}.total.equiv_labels".format(
                                path_convention),
                        ],
                    )
                finally:
                    es_store.close()

                bs_store.connect()
                try:
                    bandstructure_symm_line = bs_store.query_one(criteria={
                        "metadata.task_id":
                        int(bs_query["bandstructure"][path_convention]
                            ["task_id"])
                    })
                finally:
                    # Previously never closed: leaked a connection per call.
                    bs_store.close()

                # Translate labels when the requested convention differs.
                if path_convention != label_select:
                    equiv = bs_query["bandstructure"][path_convention][
                        "total"]["equiv_labels"][label_select]

                    new_labels_dict = {}
                    for label, kpoint in bandstructure_symm_line[
                            "labels_dict"].items():
                        label_formatted = label.replace("$", "")

                        if "|" in label_formatted:
                            # Branch-boundary labels such as "A|B".
                            f_label = label_formatted.split("|")
                            new_label = ("$" + equiv[f_label[0]] + "|" +
                                         equiv[f_label[1]] + "$")
                        else:
                            new_label = "$" + equiv[label_formatted] + "$"

                        new_labels_dict[new_label] = kpoint

                    bandstructure_symm_line["labels_dict"] = new_labels_dict

                # - DOS traces from DB using task_id
                es_store.connect()
                try:
                    dos_query = es_store.query_one(
                        criteria={"task_id": int(mpid)},
                        properties=["dos.task_id"],
                    )
                finally:
                    es_store.close()

                dos_store.connect()
                try:
                    density_of_states = dos_store.query_one(
                        criteria={"task_id": int(dos_query["dos"]["task_id"])})
                finally:
                    dos_store.close()

            # - BS Data
            if (not isinstance(bandstructure_symm_line, dict)
                    and bandstructure_symm_line is not None):
                bandstructure_symm_line = bandstructure_symm_line.to_dict()

            if (not isinstance(density_of_states, dict)
                    and density_of_states is not None):
                density_of_states = density_of_states.to_dict()

            bsml = BSML.from_dict(bandstructure_symm_line)

            bs_reg_plot = BSPlotter(bsml)
            is_sp = bs_reg_plot._bs.is_spin_polarized

            bs_data = bs_reg_plot.bs_plot_data()

            # Make the plotted path continuous for lm.
            if path_convention == "lm":
                distance_map, kpath_euler = HSKP(
                    bsml.structure).get_continuous_path(bsml)

                kpath_labels = [pair[0] for pair in kpath_euler]
                kpath_labels.append(kpath_euler[-1][1])

            else:
                distance_map = [(i, False)
                                for i in range(len(bs_data["distances"]))]
                tick_labels = bs_data["ticks"]["label"]
                # Collapse consecutive duplicate labels at segment joints.
                kpath_labels = [
                    tick_labels[k] for k in range(len(tick_labels) - 1)
                    if tick_labels[k] != tick_labels[k + 1]
                ]
                kpath_labels.append(tick_labels[-1])

            bs_data["ticks"]["label"] = kpath_labels

            # Keep bands whose energy at the first k-point lies in the window.
            energy_window = (-6.0, 10.0)
            up_key = str(Spin.up)
            bands = [
                band_num for band_num in range(bs_reg_plot._nb_bands)
                if energy_window[0] <= bs_data["energy"][0][up_key][band_num]
                [0] <= energy_window[1]
            ]

            def _segment_traces(d, j_order, x_dat, spin, line, name):
                """Return one line trace per selected band for segment ``d``."""
                seg_energy = bs_data["energy"][d][str(spin)]
                return [{
                    "x": x_dat,
                    "y": [seg_energy[i][j] for j in j_order],
                    "mode": "lines",
                    "line": dict(line),
                    "hoverinfo": "skip",
                    "name": name,
                    "hovertemplate": "%{y:.2f} eV",
                    "showlegend": False,
                    "xaxis": "x",
                    "yaxis": "y",
                } for i in bands]

            def _remap_extremum(points, extremum, out, d, rev, x_dat):
                """Append lm-shifted copies of ``points`` lying in segment ``d``."""
                seg_dist = bs_data["distances"][d]
                labels = bs_data["ticks"]["label"]
                for x_point, y_point in points:
                    if x_point not in seg_dist:
                        continue
                    xind = seg_dist.index(x_point)
                    if rev:
                        x_point_new = x_dat[len(x_dat) - xind - 1]
                    else:
                        x_point_new = x_dat[xind]

                    # Extrema off a tick have no label; avoid ValueError.
                    new_label = ""
                    if x_point_new in tick_vals:
                        tick_idx = tick_vals.index(x_point_new)
                        if tick_idx < len(labels):
                            new_label = labels[tick_idx]

                    kp_label = extremum["kpoint"].label
                    if kp_label is None or kp_label in new_label:
                        out.append((x_point_new, y_point))

            bstraces = []

            pmin = 0.0
            tick_vals = [0.0]

            cbm = bsml.get_cbm()
            vbm = bsml.get_vbm()

            # Copy rather than alias: the lm loop below iterates the original
            # lists and appending to them in place could loop forever.
            cbm_new = list(bs_data["cbm"])
            vbm_new = list(bs_data["vbm"])

            for d, rev in distance_map:
                seg_dist = bs_data["distances"][d]

                # Shift each segment so it starts where the previous one ended.
                x_dat = [dval - seg_dist[0] + pmin for dval in seg_dist]

                pmin = x_dat[-1]
                tick_vals.append(pmin)

                # Reversed segments are read backwards to keep the path joined.
                j_order = list(range(len(seg_dist)))
                if rev:
                    j_order.reverse()

                bstraces += _segment_traces(
                    d, j_order, x_dat, Spin.up, {"color": "#1f77b4"},
                    "spin ↑" if is_sp else "Total")

                if is_sp:
                    bstraces += _segment_traces(
                        d, j_order, x_dat, Spin.down, {
                            "color": "#ff7f0e",
                            "dash": "dot"
                        }, "spin ↓")

                # - Get proper cbm and vbm coords for lm
                if path_convention == "lm":
                    _remap_extremum(bs_data["cbm"], cbm, cbm_new, d, rev,
                                    x_dat)
                    _remap_extremum(bs_data["vbm"], vbm, vbm_new, d, rev,
                                    x_dat)

            # Preserve the returned bs_data content (previously aliased).
            bs_data["cbm"] = cbm_new
            bs_data["vbm"] = vbm_new

            bs_data["ticks"]["distance"] = tick_vals

            # - Strip latex math wrapping for labels
            str_replace = {
                "$": "",
                "\\mid": "|",
                "\\Gamma": "Γ",
                "\\Sigma": "Σ",
                "GAMMA": "Γ",
                "_1": "₁",
                "_2": "₂",
                "_3": "₃",
                "_4": "₄",
                "_{1}": "₁",
                "_{2}": "₂",
                "_{3}": "₃",
                "_{4}": "₄",
                "^{*}": "*",
            }

            bar_loc = []
            for entry_num, label in enumerate(bs_data["ticks"]["label"]):
                for key, value in str_replace.items():
                    if key in label:
                        label = label.replace(key, value)
                        # "\mid" marks a jump between disconnected segments.
                        if key == "\\mid":
                            bar_loc.append(
                                bs_data["ticks"]["distance"][entry_num])
                bs_data["ticks"]["label"][entry_num] = label

            # Vertical lines for disjointed segments
            bstraces += [{
                "x": [x_point, x_point],
                "y": energy_window,
                "mode": "lines",
                "marker": {
                    "color": "white"
                },
                "hoverinfo": "skip",
                "showlegend": False,
                "xaxis": "x",
                "yaxis": "y",
            } for x_point in bar_loc]

            def _extremum_dots(points, prefix, extremum):
                """Return one marker trace per unique (x, y) extremum point."""
                hover = "{}: k = {}, {} eV".format(
                    prefix, list(extremum["kpoint"].frac_coords),
                    extremum["energy"])
                return [{
                    "x": [x_point],
                    "y": [y_point],
                    # "markers" for both; the VBM trace previously used the
                    # invalid plotly mode "marker".
                    "mode": "markers",
                    "marker": {
                        "color": "#7E259B",
                        "size": 16,
                        "line": {
                            "color": "white",
                            "width": 2
                        },
                    },
                    "showlegend": False,
                    "hoverinfo": "text",
                    "name": "",
                    "hovertemplate": hover,
                    "xaxis": "x",
                    "yaxis": "y",
                } for (x_point, y_point) in set(points)]

            # Dots for cbm and vbm
            bstraces += _extremum_dots(cbm_new, "CBM", cbm)
            bstraces += _extremum_dots(vbm_new, "VBM", vbm)

            # - DOS Data
            dostraces = []

            dos = CompleteDos.from_dict(density_of_states)

            # Index range of the DOS energies closest to the plot window.
            dos_max = np.abs(
                (dos.energies - dos.efermi - energy_window[1])).argmin()
            dos_min = np.abs(
                (dos.energies - dos.efermi - energy_window[0])).argmin()
            dos_energies = dos.energies[dos_min:dos_max] - dos.efermi

            if is_sp:
                # Spin-down densities are mirrored to negative x.
                dostraces.append({
                    "x": -1.0 * dos.densities[Spin.down][dos_min:dos_max],
                    "y": dos_energies,
                    "mode": "lines",
                    "name": "Total DOS (spin ↓)",
                    "line": go.scatter.Line(color="#444444", dash="dot"),
                    "fill": "tozerox",
                    "fillcolor": "#C4C4C4",
                    "xaxis": "x2",
                    "yaxis": "y2",
                })
                tdos_label = "Total DOS (spin ↑)"
            else:
                tdos_label = "Total DOS"

            # Total DOS
            dostraces.append({
                "x": dos.densities[Spin.up][dos_min:dos_max],
                "y": dos_energies,
                "mode": "lines",
                "name": tdos_label,
                "line": go.scatter.Line(color="#444444"),
                "fill": "tozerox",
                "fillcolor": "#C4C4C4",
                "legendgroup": "spinup",
                "xaxis": "x2",
                "yaxis": "y2",
            })

            ele_dos = dos.get_element_dos()
            elements = [str(entry) for entry in ele_dos.keys()]

            if dos_select == "ap":
                proj_data = ele_dos
            elif dos_select == "op":
                proj_data = dos.get_spd_dos()
            elif "orb" in dos_select:
                proj_data = dos.get_element_spd_dos(
                    Element(dos_select.replace("orb", "")))
            else:
                raise PreventUpdate

            # Projected DOS
            colors = [
                "#d62728",  # brick red
                "#2ca02c",  # cooked asparagus green
                "#17becf",  # blue-teal
                "#bcbd22",  # curry yellow-green
                "#9467bd",  # muted purple
                "#8c564b",  # chestnut brown
                "#e377c2",  # raspberry yogurt pink
            ]

            for count, label in enumerate(proj_data.keys()):
                # Cycle colours so more than 7 projections don't IndexError.
                color = colors[count % len(colors)]

                if is_sp:
                    dostraces.append({
                        "x":
                        -1.0 *
                        proj_data[label].densities[Spin.down][dos_min:dos_max],
                        "y": dos_energies,
                        "mode": "lines",
                        "name": str(label) + " (spin ↓)",
                        "line": dict(width=3, color=color, dash="dot"),
                        "xaxis": "x2",
                        "yaxis": "y2",
                    })
                    spin_up_label = str(label) + " (spin ↑)"
                else:
                    spin_up_label = str(label)

                dostraces.append({
                    "x": proj_data[label].densities[Spin.up][dos_min:dos_max],
                    "y": dos_energies,
                    "mode": "lines",
                    "name": spin_up_label,
                    "line": dict(width=2, color=color),
                    "xaxis": "x2",
                    "yaxis": "y2",
                })

            traces = [bstraces, dostraces, bs_data]

            return (traces, elements)
Exemplo n.º 19
0
        def bs_dos_data(
            mpid,
            path_convention,
            dos_select,
            label_select,
            bandstructure_symm_line,
            density_of_states,
        ):
            if (not mpid
                    or "mpid" not in mpid) and (bandstructure_symm_line is None
                                                or density_of_states is None):
                raise PreventUpdate
            elif mpid:
                raise PreventUpdate
            elif bandstructure_symm_line is None or density_of_states is None:

                # --
                # -- BS and DOS from API
                # --

                mpid = mpid["mpid"]
                bs_data = {"ticks": {}}

                # client = MongoClient(
                #     "mongodb03.nersc.gov", username="******", password="", authSource="fw_bs_prod",
                # )

                db = client.fw_bs_prod

                # - BS traces from DB using task_id
                bs_query = list(
                    db.electronic_structure.find(
                        {"task_id": int(mpid)},
                        [
                            "bandstructure.{}.total.traces".format(
                                path_convention)
                        ],
                    ))[0]

                is_sp = (len(bs_query["bandstructure"][path_convention]
                             ["total"]["traces"]) == 2)

                if is_sp:
                    bstraces = (bs_query["bandstructure"][path_convention]
                                ["total"]["traces"]["1"] +
                                bs_query["bandstructure"][path_convention]
                                ["total"]["traces"]["-1"])
                else:
                    bstraces = bs_query["bandstructure"][path_convention][
                        "total"]["traces"]["1"]

                bs_data["ticks"]["distance"] = bs_query["bandstructure"][
                    path_convention]["total"]["traces"]["ticks"]
                bs_data["ticks"]["label"] = bs_query["bandstructure"][
                    path_convention]["total"]["traces"]["labels"]

                # If LM convention, get equivalent labels
                if path_convention == "lm" and label_select != "lm":
                    bs_equiv_labels = bs_query["bandstructure"][
                        path_convention]["total"]["traces"]["equiv_labels"]

                    alt_choice = label_select

                    if label_select == "hin":
                        alt_choice = "h"

                    new_labels = []
                    for label in bs_data["ticks"]["label"]:
                        label_formatted = label.replace("$", "")

                        if "|" in label_formatted:
                            f_label = label_formatted.split("|")
                            new_labels.append(
                                "$" + bs_equiv_labels[alt_choice][f_label[0]] +
                                "|" + bs_equiv_labels[alt_choice][f_label[1]] +
                                "$")
                        else:
                            new_labels.append(
                                "$" +
                                bs_equiv_labels[alt_choice][label_formatted] +
                                "$")

                    bs_data["ticks"]["label"] = new_labels

                # Strip latex math wrapping
                str_replace = {
                    "$": "",
                    "\\mid": "|",
                    "\\Gamma": "Γ",
                    "\\Sigma": "Σ",
                    "GAMMA": "Γ",
                    "_1": "₁",
                    "_2": "₂",
                    "_3": "₃",
                    "_4": "₄",
                }

                for entry_num in range(len(bs_data["ticks"]["label"])):
                    for key in str_replace.keys():
                        if key in bs_data["ticks"]["label"][entry_num]:
                            bs_data["ticks"]["label"][entry_num] = bs_data[
                                "ticks"]["label"][entry_num].replace(
                                    key, str_replace[key])

                # - DOS traces from DB using task_id
                dostraces = []

                dos_tot_ele_traces = list(
                    db.electronic_structure.find(
                        {"task_id": int(mpid)},
                        ["dos.total.traces", "dos.elements"]))[0]

                dostraces = [
                    dos_tot_ele_traces["dos"]["total"]["traces"][spin] for spin
                    in dos_tot_ele_traces["dos"]["total"]["traces"].keys()
                ]

                elements = [
                    ele
                    for ele in dos_tot_ele_traces["dos"]["elements"].keys()
                ]

                if dos_select == "ap":
                    for ele_label in elements:
                        dostraces += [
                            dos_tot_ele_traces["dos"]["elements"][ele_label]
                            ["total"]["traces"][spin]
                            for spin in dos_tot_ele_traces["dos"]["elements"]
                            [ele_label]["total"]["traces"].keys()
                        ]

                elif dos_select == "op":
                    orb_tot_traces = list(
                        db.electronic_structure.find({"task_id": int(mpid)},
                                                     ["dos.orbitals"]))[0]
                    for orbital in ["s", "p", "d"]:
                        dostraces += [
                            orb_tot_traces["dos"]["orbitals"][orbital]
                            ["traces"][spin] for spin in orb_tot_traces["dos"]
                            ["orbitals"]["s"]["traces"].keys()
                        ]

                elif "orb" in dos_select:
                    ele_label = dos_select.replace("orb", "")

                    for orbital in ["s", "p", "d"]:
                        dostraces += [
                            dos_tot_ele_traces["dos"]["elements"][ele_label]
                            [orbital]["traces"][spin]
                            for spin in dos_tot_ele_traces["dos"]["elements"]
                            [ele_label][orbital]["traces"].keys()
                        ]

                traces = [bstraces, dostraces, bs_data]

                return (traces, elements)

            else:

                # --
                # -- BS and DOS passed manually
                # --

                # - BS Data

                if type(bandstructure_symm_line) != dict:
                    bandstructure_symm_line = bandstructure_symm_line.to_dict()

                if type(density_of_states) != dict:
                    density_of_states = density_of_states.to_dict()

                bs_reg_plot = BSPlotter(
                    BSML.from_dict(bandstructure_symm_line))
                bs_data = bs_reg_plot.bs_plot_data()

                # - Strip latex math wrapping
                str_replace = {
                    "$": "",
                    "\\mid": "|",
                    "\\Gamma": "Γ",
                    "\\Sigma": "Σ",
                    "GAMMA": "Γ",
                    "_1": "₁",
                    "_2": "₂",
                    "_3": "₃",
                    "_4": "₄",
                }

                for entry_num in range(len(bs_data["ticks"]["label"])):
                    for key in str_replace.keys():
                        if key in bs_data["ticks"]["label"][entry_num]:
                            bs_data["ticks"]["label"][entry_num] = bs_data[
                                "ticks"]["label"][entry_num].replace(
                                    key, str_replace[key])

                # Obtain bands to plot over:
                energy_window = (-6.0, 10.0)
                bands = []
                for band_num in range(bs_reg_plot._nb_bands):
                    if (bs_data["energy"][0][str(Spin.up)][band_num][0] <=
                            energy_window[1]) and (bs_data["energy"][0][str(
                                Spin.up)][band_num][0] >= energy_window[0]):
                        bands.append(band_num)

                bstraces = []

                # Generate traces for total BS data
                for d in range(len(bs_data["distances"])):
                    dist_dat = bs_data["distances"][d]
                    energy_ind = [
                        i for i in range(len(bs_data["distances"][d]))
                    ]

                    traces_for_segment = [{
                        "x":
                        dist_dat,
                        "y":
                        [bs_data["energy"][d]["1"][i][j] for j in energy_ind],
                        "mode":
                        "lines",
                        "line": {
                            "color": "#666666"
                        },
                        "hoverinfo":
                        "skip",
                        "showlegend":
                        False,
                    } for i in bands]

                    if bs_reg_plot._bs.is_spin_polarized:
                        traces_for_segment += [{
                            "x":
                            dist_dat,
                            "y": [
                                bs_data["energy"][d]["-1"][i][j]
                                for j in energy_ind
                            ],
                            "mode":
                            "lines",
                            "line": {
                                "color": "#666666"
                            },
                            "hoverinfo":
                            "skip",
                            "showlegend":
                            False,
                        } for i in bands]

                    bstraces += traces_for_segment

                # - DOS Data
                dostraces = []

                dos = CompleteDos.from_dict(density_of_states)

                dos_max = np.abs(
                    (dos.energies - dos.efermi - energy_window[1])).argmin()
                dos_min = np.abs(
                    (dos.energies - dos.efermi - energy_window[0])).argmin()

                if bs_reg_plot._bs.is_spin_polarized:
                    # Add second spin data if available
                    trace_tdos = go.Scatter(
                        x=dos.densities[Spin.down][dos_min:dos_max],
                        y=dos.energies[dos_min:dos_max] - dos.efermi,
                        mode="lines",
                        name="Total DOS (spin ↓)",
                        line=go.scatter.Line(color="#444444", dash="dash"),
                        fill="tozerox",
                    )

                    dostraces.append(trace_tdos)

                    tdos_label = "Total DOS (spin ↑)"
                else:
                    tdos_label = "Total DOS"

                # Total DOS
                trace_tdos = go.Scatter(
                    x=dos.densities[Spin.up][dos_min:dos_max],
                    y=dos.energies[dos_min:dos_max] - dos.efermi,
                    mode="lines",
                    name=tdos_label,
                    line=go.scatter.Line(color="#444444"),
                    fill="tozerox",
                    legendgroup="spinup",
                )

                dostraces.append(trace_tdos)

                ele_dos = dos.get_element_dos()
                elements = [str(entry) for entry in ele_dos.keys()]

                if dos_select == "ap":
                    proj_data = ele_dos
                elif dos_select == "op":
                    proj_data = dos.get_spd_dos()
                elif "orb" in dos_select:
                    proj_data = dos.get_element_spd_dos(
                        Element(dos_select.replace("orb", "")))
                else:
                    raise PreventUpdate

                # Projected DOS
                count = 0
                colors = [
                    "#1f77b4",  # muted blue
                    "#ff7f0e",  # safety orange
                    "#2ca02c",  # cooked asparagus green
                    "#9467bd",  # muted purple
                    "#e377c2",  # raspberry yogurt pink
                    "#d62728",  # brick red
                    "#8c564b",  # chestnut brown
                    "#bcbd22",  # curry yellow-green
                    "#17becf",  # blue-teal
                ]

                for label in proj_data.keys():

                    if bs_reg_plot._bs.is_spin_polarized:
                        trace = go.Scatter(
                            x=proj_data[label].densities[Spin.down]
                            [dos_min:dos_max],
                            y=dos.energies[dos_min:dos_max] - dos.efermi,
                            mode="lines",
                            name=str(label) + " (spin ↓)",
                            line=dict(width=3,
                                      color=colors[count],
                                      dash="dash"),
                        )

                        dostraces.append(trace)
                        spin_up_label = str(label) + " (spin ↑)"

                    else:
                        spin_up_label = str(label)

                    trace = go.Scatter(
                        x=proj_data[label].densities[Spin.up][dos_min:dos_max],
                        y=dos.energies[dos_min:dos_max] - dos.efermi,
                        mode="lines",
                        name=spin_up_label,
                        line=dict(width=3, color=colors[count]),
                    )

                    dostraces.append(trace)

                    count += 1

                traces = [bstraces, dostraces, bs_data]

                return (traces, elements)
Exemplo n.º 20
0
def bandstr(vrun="", kpfile="", filename=".", plot=False):
    """
    Plot the electronic band structure along a line-mode k-path.

    Spin-up bands are drawn as solid blue lines; for spin-polarized
    calculations spin-down bands are added as dashed red lines. Energies are
    shifted so the Fermi level is at zero, the y range is fixed to [-4, 4] eV,
    CBM points are marked in red and VBM points in green.

    Args:
        vrun: path to vasprun.xml
        kpfile: path to line-mode KPOINTS file
        filename: output image path; only used when ``plot`` is true
        plot: if true, save the figure to ``filename`` as PNG and close it

    Returns:
        the plotting object obtained from ``get_publication_quality_plot``,
        with the band structure drawn on it
    """
    run = Vasprun(vrun, parse_projected_eigen=True)
    bands = run.get_band_structure(kpfile, line_mode=True, efermi=run.efermi)
    bsp = BSPlotter(bands)
    zero_to_efermi = True
    data = bsp.bs_plot_data(zero_to_efermi)

    plt = get_publication_quality_plot(12, 8)
    # Start from a clean canvas in case a previous figure is still active.
    plt.close()
    plt.clf()
    band_linewidth = 3
    x_max = data["distances"][-1][-1]
    spin_polarized = bsp._bs.is_spin_polarized

    # One polyline per band and per branch segment of the k-path.
    for d, distances in enumerate(data["distances"]):
        kpts = range(len(distances))
        for i in range(bsp._nb_bands):
            plt.plot(
                distances,
                [data["energy"][d]["1"][i][j] for j in kpts],
                "b-",
                linewidth=band_linewidth,
            )
            if spin_polarized:
                plt.plot(
                    distances,
                    [data["energy"][d]["-1"][i][j] for j in kpts],
                    "r--",
                    linewidth=band_linewidth,
                )
    bsp._maketicks(plt)

    # Band-edge markers. The two loops are independent so every VBM is drawn
    # exactly once, even when the CBM list is empty.
    for cbm in data["cbm"]:
        plt.scatter(cbm[0], cbm[1], color="r", marker="o", s=100)
    for vbm in data["vbm"]:
        plt.scatter(vbm[0], vbm[1], color="g", marker="o", s=100)

    plt.xlabel(r"$\mathrm{Wave\ Vector}$", fontsize=30)
    ylabel = (
        r"$\mathrm{E\ -\ E_f\ (eV)}$" if zero_to_efermi else r"$\mathrm{Energy\ (eV)}$"
    )
    plt.ylabel(ylabel, fontsize=30)
    plt.ylim(-4, 4)
    plt.xlim(0, x_max)
    plt.tight_layout()
    if plot:
        # ``format`` is the savefig keyword selecting the output file type.
        plt.savefig(filename, format="png")
        plt.close()

    return plt
Exemplo n.º 21
0
class BSPlotterTest(unittest.TestCase):
    """Tests for BSPlotter built from the CaO, N2 and C band-structure fixtures."""

    def setUp(self):
        # Single-band-structure plotter (CaO).
        with open(os.path.join(test_dir, "CaO_2605_bandstructure.json"),
                  "r",
                  encoding="utf-8") as f:
            d = json.loads(f.read())
            self.bs = BandStructureSymmLine.from_dict(d)
            self.plotter = BSPlotter(self.bs)

        self.assertEqual(len(self.plotter._bs), 1,
                         "wrong number of band objects")

        # Two further band structures used for the multi-structure plotter.
        with open(os.path.join(test_dir, "N2_12103_bandstructure.json"),
                  "r",
                  encoding="utf-8") as f:
            d = json.loads(f.read())
            self.sbs_sc = BandStructureSymmLine.from_dict(d)

        with open(os.path.join(test_dir, "C_48_bandstructure.json"),
                  "r",
                  encoding="utf-8") as f:
            d = json.loads(f.read())
            self.sbs_met = BandStructureSymmLine.from_dict(d)

        self.plotter_multi = BSPlotter([self.sbs_sc, self.sbs_met])
        self.assertEqual(len(self.plotter_multi._bs), 2,
                         "wrong number of band objects")
        self.assertEqual(self.plotter_multi._nb_bands, [96, 96],
                         "wrong number of bands")
        warnings.simplefilter("ignore")

    def tearDown(self):
        warnings.simplefilter("default")

    def test_add_bs(self):
        self.plotter_multi.add_bs(self.sbs_sc)
        self.assertEqual(len(self.plotter_multi._bs), 3,
                         "wrong number of band objects")
        self.assertEqual(self.plotter_multi._nb_bands, [96, 96, 96],
                         "wrong number of bands")

    def test_get_branch_steps(self):
        steps_idx = BSPlotter._get_branch_steps(self.sbs_sc.branches)
        self.assertEqual(steps_idx, [0, 121, 132, 143],
                         "wrong list of steps idx")

    def test_rescale_distances(self):
        rescaled_distances = self.plotter_multi._rescale_distances(
            self.sbs_sc, self.sbs_met)
        self.assertEqual(
            len(rescaled_distances),
            len(self.sbs_met.distance),
            "wrong length of distances list",
        )
        # Computed floats: compare with a tolerance, not bit-exactly.
        self.assertAlmostEqual(rescaled_distances[-1], 6.5191398067252875,
                               msg="wrong last distance value")
        self.assertAlmostEqual(
            rescaled_distances[148],
            self.sbs_sc.distance[19],
            msg="wrong distance at high symm k-point",
        )

    def test_interpolate_bands(self):
        data = self.plotter.bs_plot_data()
        d = data["distances"]
        en = data["energy"]["1"]
        int_distances, int_energies = self.plotter._interpolate_bands(d, en)

        self.assertEqual(len(int_distances), 10,
                         "wrong length of distances list")
        self.assertEqual(len(int_distances[0]), 100,
                         "wrong length of distances in a branch")
        self.assertEqual(len(int_energies), 10,
                         "wrong length of energies list")
        self.assertEqual(int_energies[0].shape, (16, 100),
                         "wrong shape of energies in a branch")

    def test_bs_plot_data(self):
        # Compute each variant once and reuse it across assertions.
        data = self.plotter.bs_plot_data()
        self.assertEqual(
            len(data["distances"]),
            10,
            "wrong number of sequences of branches",
        )
        self.assertEqual(
            len(data["distances"][0]),
            16,
            "wrong number of distances in the first sequence of branches",
        )
        self.assertEqual(
            sum(len(e) for e in data["distances"]),
            160,
            "wrong number of distances",
        )

        unsplit = self.plotter.bs_plot_data(split_branches=False)
        self.assertEqual(
            len(unsplit["distances"][0]), 144,
            "wrong number of distances in the first sequence of branches")
        self.assertEqual(
            len(unsplit["distances"]), 2,
            "wrong number of sequences of branches")

        self.assertEqual(data["ticks"]["label"][5], "K",
                         "wrong tick label")
        self.assertEqual(
            len(data["ticks"]["label"]),
            19,
            "wrong number of tick labels",
        )

    def test_get_ticks(self):
        ticks = self.plotter.get_ticks()
        self.assertEqual(ticks["label"][5], "K", "wrong tick label")
        self.assertAlmostEqual(
            ticks["distance"][5],
            2.406607625322699,
            msg="wrong tick distance",
        )

    # Minimal baseline testing for get_plot. not a true test. Just checks that
    # it can actually execute.
    def test_get_plot(self):
        # zero_to_efermi = True, ylim = None, smooth = False,
        # vbm_cbm_marker = False, smooth_tol = None

        # Disabling latex is needed for this test to work.
        from matplotlib import rc

        rc("text", usetex=False)

        plt = self.plotter.get_plot()
        self.assertEqual(plt.ylim(), (-4.0, 7.6348), "wrong ylim")
        plt = self.plotter.get_plot(smooth=True)
        plt = self.plotter.get_plot(vbm_cbm_marker=True)
        self.plotter.save_plot("bsplot.png")
        self.assertTrue(os.path.isfile("bsplot.png"))
        os.remove("bsplot.png")
        plt.close("all")

        # test plotter with 2 bandstructures
        plt = self.plotter_multi.get_plot()
        self.assertEqual(len(plt.gca().get_lines()), 874,
                         "wrong number of lines")
        self.assertEqual(plt.ylim(), (-10.0, 10.0), "wrong ylim")
        plt = self.plotter_multi.get_plot(zero_to_efermi=False)
        self.assertEqual(plt.ylim(), (-15.2379, 12.67141266), "wrong ylim")
        plt = self.plotter_multi.get_plot(smooth=True)
        self.plotter_multi.save_plot("bsplot.png")
        self.assertTrue(os.path.isfile("bsplot.png"))
        os.remove("bsplot.png")
        plt.close("all")
Exemplo n.º 22
0
def make_el_band_plot(ax, bands, yticklabels=True, **kargs):
    """
    Make an element-projected band structure plot on ``ax``.

    Every band is drawn with ``rgbline``, coloured by the three contribution
    channels (indices 0, 1, 2) returned by ``compute_contrib_el`` for the
    requested elements. Energies are shifted so the Fermi level is at zero.
    For spin-polarized band structures the spin-down bands are drawn with
    alpha 0.7 and half the line width.

    Args:
        ax: an Axes object
        bands: band structure object
        yticklabels (bool): if False, the y tick labels are removed
        linewidth (int): line width of spin-up bands (default 2)
        elements: elements passed to ``compute_contrib_el``; required
        name: optional name (default None)

    Raises:
        KeyError: if the ``elements`` option is not given.

    Any other keyword option is ignored with a printed warning.
    """
    # default values of options
    name = kargs.get("name", None)
    linewidth = kargs.get("linewidth", 2)
    if "elements" not in kargs:
        raise KeyError("argument 'elements' in make_el_band_plot is missing")
    elements = kargs["elements"]
    for key in kargs:
        if key not in ("name", "linewidth", "elements"):
            print("WARNING: option {0} not considered".format(key))

    # band structure plot data
    bsplot = BSPlotter(bands)
    plotdata = bsplot.bs_plot_data(zero_to_efermi=True)

    # spin polarized calculation
    if bands.is_spin_polarized:
        all_spins = [Spin.up, Spin.down]
    else:
        all_spins = [Spin.up]

    for spin in all_spins:
        # spin-down bands are drawn lighter and thinner than spin-up ones
        if spin == Spin.up:
            alpha = 1
            lw = linewidth
        if spin == Spin.down:
            alpha = .7
            lw = linewidth / 2

        # normalized per-band, per-k-point contributions (3 channels)
        contrib = compute_contrib_el(bands, spin, elements)

        # plot bands; ikpts tracks the offset of the current branch segment
        # in the flat k-point axis of ``contrib``
        ikpts = 0
        maxd = -1
        mind = 1e10
        for d, ene in zip(plotdata["distances"], plotdata["energy"]):
            npts = len(d)
            maxd = max(max(d), maxd)
            mind = min(min(d), mind)
            for b in range(bands.nb_bands):
                rgbline(ax, d, ene[str(spin)][b],
                        contrib[b, ikpts:ikpts + npts, 0],
                        contrib[b, ikpts:ikpts + npts, 1],
                        contrib[b, ikpts:ikpts + npts, 2],
                        alpha, lw)
            ikpts += npts

    # add ticks and vlines
    make_ticks(ax, bsplot)
    ax.set_xlabel("k-points")
    ax.set_xlim(mind, maxd)
    ax.grid(False)

    if not yticklabels:
        ax.set_yticklabels([])