class BSPlotterTest(unittest.TestCase):
    """Checks the plot data BSPlotter builds from the CaO_2605 fixture."""

    def setUp(self):
        # Deserialize the band structure fixture and wrap it in a plotter.
        with open(os.path.join(test_dir, "CaO_2605_bandstructure.json"), "r",
                  encoding='utf-8') as f:
            d = json.loads(f.read())
        self.bs = BandStructureSymmLine.from_dict(d)
        self.plotter = BSPlotter(self.bs)

    def test_bs_plot_data(self):
        # Compute the plot data once; all assertions inspect the same result
        # instead of rebuilding it for every check.
        data = self.plotter.bs_plot_data()
        distances = data['distances']
        self.assertEqual(len(distances[0]), 16,
                         "wrong number of distances in the first branch")
        self.assertEqual(len(distances), 10, "wrong number of branches")
        self.assertEqual(sum(len(e) for e in distances), 160,
                         "wrong number of distances")
        labels = data['ticks']['label']
        self.assertEqual(labels[5], "K", "wrong tick label")
        self.assertEqual(len(labels), 19, "wrong number of tick labels")

    def test_qvertex_target(self):
        # Eight corners of the unit cube followed by its centre; the second
        # argument (8) is the index of the centre point in this list.
        results = _qvertex_target(
            [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 1.0, 0.0],
             [0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 1.0],
             [1.0, 1.0, 1.0], [0.0, 1.0, 1.0], [0.5, 0.5, 0.5]], 8)
        self.assertEqual(len(results), 6)
        self.assertEqual(results[3][1], 0.5)
class BSPlotterTest(unittest.TestCase):
    """Checks the plot data BSPlotter builds from the CaO_2605 fixture."""

    def setUp(self):
        # Deserialize the band structure fixture and wrap it in a plotter.
        fixture = os.path.join(test_dir, "CaO_2605_bandstructure.json")
        with open(fixture, "r", encoding='utf-8') as f:
            d = json.loads(f.read())
        self.bs = BandStructureSymmLine.from_dict(d)
        self.plotter = BSPlotter(self.bs)

    def test_bs_plot_data(self):
        # The plot data is built a single time and shared by every assertion
        # rather than being recomputed for each one.
        plot_data = self.plotter.bs_plot_data()
        branches = plot_data['distances']
        self.assertEqual(len(branches[0]), 16,
                         "wrong number of distances in the first branch")
        self.assertEqual(len(branches), 10, "wrong number of branches")
        self.assertEqual(sum(len(e) for e in branches), 160,
                         "wrong number of distances")
        tick_labels = plot_data['ticks']['label']
        self.assertEqual(tick_labels[5], "K", "wrong tick label")
        self.assertEqual(len(tick_labels), 19, "wrong number of tick labels")

    def test_qvertex_target(self):
        # Unit-cube corners plus the cube centre; 8 is the centre's index.
        points = [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 1.0, 0.0],
                  [0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 1.0],
                  [1.0, 1.0, 1.0], [0.0, 1.0, 1.0], [0.5, 0.5, 0.5]]
        results = _qvertex_target(points, 8)
        self.assertEqual(len(results), 6)
        self.assertEqual(results[3][1], 0.5)
class BSPlotterTest(unittest.TestCase):
    """Checks the plot data BSPlotter builds from the CaO_2605 fixture."""

    def setUp(self):
        # Deserialize the band structure fixture and wrap it in a plotter.
        with open(os.path.join(test_dir, "CaO_2605_bandstructure.json"),
                  "rb") as f:
            d = json.loads(f.read())
        self.bs = BandStructureSymmLine.from_dict(d)
        self.plotter = BSPlotter(self.bs)

    def test_bs_plot_data(self):
        # Build the plot data once and reuse it for all assertions instead of
        # recomputing it for each check.
        data = self.plotter.bs_plot_data()
        self.assertEqual(len(data['distances']), 160,
                         "wrong number of distances")
        labels = data['ticks']['label']
        self.assertEqual(labels[5], "K", "wrong tick label")
        self.assertEqual(len(labels), 19, "wrong number of tick labels")
def find_dirac_nodes(tol=0.1):
    """
    Look for band crossings near (within `tol` eV) the Fermi level.

    Reads 'vasprun.xml' and 'KPOINTS' from the current directory. Only
    spin-up energies are inspected, and two bands are only compared with
    each other at the same k-point of the same branch.

    Args:
        tol (float): energy tolerance in eV. It is used as the maximum
            band gap for which a search is attempted, as the half-width of
            the energy window around the Fermi level, and as the maximum
            separation between two bands that counts as a crossing.
            Defaults to 0.1.

    Returns:
        boolean. Whether or not a band crossing occurs at or near
            the fermi level.
    """
    import itertools

    vasprun = Vasprun('vasprun.xml')

    # Only materials whose gap is below the tolerance are searched.
    if not vasprun.get_band_structure().get_band_gap()['energy'] < tol:
        return False

    efermi = vasprun.efermi
    bsp = BSPlotter(vasprun.get_band_structure('KPOINTS', line_mode=True,
                                               efermi=efermi))
    data = bsp.bs_plot_data(zero_to_efermi=True)
    spin_key = str(Spin.up)

    for d, distances in enumerate(data['distances']):
        n_kpts = len(distances)
        branch = data['energy'][d][spin_key]
        bands = [[branch[i][j] for j in range(n_kpts)]
                 for i in range(bsp._nb_bands)]

        # Each unordered pair of bands is examined exactly once.
        for band_a, band_b in itertools.combinations(bands, 2):
            for e_a, e_b in zip(band_a, band_b):
                # Energies are relative to the Fermi level
                # (zero_to_efermi=True), so "near E_F" means |E| < tol.
                near_fermi = (-tol < e_a < tol) or (-tol < e_b < tol)
                if near_fermi and (-tol < e_a - e_b < tol):
                    return True

    return False
def plot_band_structure(ylim=(-5, 5), draw_fermi=False, fmt='pdf'):
    """
    Plot a standard band structure with no projections.

    The band structure is read from 'vasprun.xml' and 'KPOINTS' in the
    current directory and the figure is written to 'band_structure.<fmt>'.

    Args:
        ylim (tuple): minimum and maximum energies for the plot's y-axis.
        draw_fermi (bool): whether or not to draw a dashed line at E_F.
        fmt (str): matplotlib format style. Check the matplotlib docs
            for options. The literal string "None" skips plotting and
            returns the plotter's raw plot data instead.
    """
    run = Vasprun('vasprun.xml')
    fermi_energy = run.efermi
    plotter = BSPlotter(
        run.get_band_structure('KPOINTS', line_mode=True,
                               efermi=fermi_energy))

    # Data-only mode: nothing is drawn or saved.
    if fmt == "None":
        return plotter.bs_plot_data()

    figure = plotter.get_plot(ylim=ylim).gcf()
    axes = figure.gca()

    # NOTE(review): get_xticklabels()/get_yticklabels() are formatted with
    # '%s' directly, so the label text comes from str() of each returned
    # object -- confirm this is the intended text.
    x_labels = [r'$\mathrm{%s}$' % tick for tick in axes.get_xticklabels()]
    axes.set_xticklabels(x_labels)
    y_labels = [r'$\mathrm{%s}$' % tick for tick in axes.get_yticklabels()]
    axes.set_yticklabels(y_labels)

    if draw_fermi:
        # Dashed horizontal line at zero energy across the full x-range.
        x_span = [axes.get_xlim()[0], axes.get_xlim()[1]]
        axes.plot(x_span, [0, 0], 'k--')

    plt.savefig('band_structure.{}'.format(fmt), transparent=True)
    plt.close()
def find_dirac_nodes(tol=0.1):
    """
    Look for band crossings near (within `tol` eV) the Fermi level.

    Reads 'vasprun.xml' and 'KPOINTS' from the current directory. Only
    spin-up energies are inspected, and two bands are only compared with
    each other at the same k-point of the same branch.

    Args:
        tol (float): energy tolerance in eV. It is used as the maximum
            band gap for which a search is attempted, as the half-width of
            the energy window around the Fermi level, and as the maximum
            separation between two bands that counts as a crossing.
            Defaults to 0.1.

    Returns:
        boolean. Whether or not a band crossing occurs at or near
            the fermi level.
    """
    import itertools

    vasprun = Vasprun('vasprun.xml')

    # Only materials whose gap is below the tolerance are searched.
    if not vasprun.get_band_structure().get_band_gap()['energy'] < tol:
        return False

    efermi = vasprun.efermi
    bsp = BSPlotter(vasprun.get_band_structure('KPOINTS', line_mode=True,
                                               efermi=efermi))
    data = bsp.bs_plot_data(zero_to_efermi=True)
    spin_key = str(Spin.up)

    for d, distances in enumerate(data['distances']):
        n_kpts = len(distances)
        branch = data['energy'][d][spin_key]
        bands = [[branch[i][j] for j in range(n_kpts)]
                 for i in range(bsp._nb_bands)]

        # Each unordered pair of bands is examined exactly once.
        for band_a, band_b in itertools.combinations(bands, 2):
            for e_a, e_b in zip(band_a, band_b):
                # Energies are relative to the Fermi level
                # (zero_to_efermi=True), so "near E_F" means |E| < tol.
                near_fermi = (-tol < e_a < tol) or (-tol < e_b < tol)
                if near_fermi and (-tol < e_a - e_b < tol):
                    return True

    return False
def make_band_plot_info(self):
    """Assemble a BandPlotInfo from self.bs (and self.vasprun2 if set).

    Side effect: stores the final structure's composition of self.vasprun
    in self._composition.
    """
    plotter = BSPlotter(self.bs)
    data = plotter.bs_plot_data(zero_to_efermi=False)
    branch_distances = [list(branch) for branch in data["distances"]]
    self._composition = self.vasprun.final_structure.composition

    infos = [
        BandInfo(band_energies=self._remove_spin_key(data),
                 band_edge=self._band_edge(self.bs, data),
                 fermi_level=self.bs.efermi)
    ]

    if self.vasprun2:
        # Second calculation: parsed along the same k-point file; note the
        # Fermi level recorded for it is still the one of self.bs.
        second_bs = self.vasprun2.get_band_structure(self.kpoints_filename,
                                                     line_mode=True)
        second_data = BSPlotter(second_bs).bs_plot_data(zero_to_efermi=False)
        second_info = BandInfo(
            band_energies=self._remove_spin_key(second_data),
            band_edge=self._band_edge(second_bs, second_data),
            fermi_level=self.bs.efermi)
        infos.append(second_info)

    ticks = plotter.get_ticks_old()
    x_ticks = XTicks(_sanitize_labels(ticks["label"]), ticks["distance"])

    return BandPlotInfo(band_info_set=infos,
                        distances_by_branch=branch_distances,
                        x_ticks=x_ticks,
                        title=self._title)
class BSPlotterTest(unittest.TestCase):
    """Checks the plot data BSPlotter builds from the CaO_2605 fixture."""

    def setUp(self):
        # Deserialize the band structure fixture and wrap it in a plotter.
        fixture = os.path.join(test_dir, "CaO_2605_bandstructure.json")
        with open(fixture, "rb") as f:
            d = json.loads(f.read())
        self.bs = BandStructureSymmLine.from_dict(d)
        self.plotter = BSPlotter(self.bs)

    def test_bs_plot_data(self):
        # One bs_plot_data() call feeds all assertions; the data is not
        # rebuilt for every check.
        plot_data = self.plotter.bs_plot_data()
        self.assertEqual(len(plot_data['distances']), 160,
                         "wrong number of distances")
        tick_labels = plot_data['ticks']['label']
        self.assertEqual(tick_labels[5], "K", "wrong tick label")
        self.assertEqual(len(tick_labels), 19, "wrong number of tick labels")
class BSPlotterTest(unittest.TestCase):
    """Checks BSPlotter plot data and plotting on the CaO_2605 fixture."""

    def setUp(self):
        # Deserialize the band structure fixture and wrap it in a plotter.
        with open(os.path.join(test_dir, "CaO_2605_bandstructure.json"), "r",
                  encoding='utf-8') as f:
            d = json.loads(f.read())
        self.bs = BandStructureSymmLine.from_dict(d)
        self.plotter = BSPlotter(self.bs)
        # Silence warnings for the duration of each test.
        warnings.simplefilter("ignore")

    def tearDown(self):
        warnings.simplefilter("default")

    def test_bs_plot_data(self):
        # Compute the plot data once; all assertions inspect the same result.
        data = self.plotter.bs_plot_data()
        distances = data['distances']
        self.assertEqual(len(distances[0]), 16,
                         "wrong number of distances in the first branch")
        self.assertEqual(len(distances), 10, "wrong number of branches")
        self.assertEqual(sum(len(e) for e in distances), 160,
                         "wrong number of distances")
        labels = data['ticks']['label']
        self.assertEqual(labels[5], "K", "wrong tick label")
        self.assertEqual(len(labels), 19, "wrong number of tick labels")

    # Minimal baseline testing for get_plot. not a true test. Just checks that
    # it can actually execute.
    def test_get_plot(self):
        # zero_to_efermi = True, ylim = None, smooth = False,
        # vbm_cbm_marker = False, smooth_tol = None

        # Disabling latex is needed for this test to work.
        from matplotlib import rc
        rc('text', usetex=False)

        plt = self.plotter.get_plot()
        try:
            plt = self.plotter.get_plot(smooth=True)
            plt = self.plotter.get_plot(vbm_cbm_marker=True)
            self.plotter.save_plot("bsplot.png")
            self.assertTrue(os.path.isfile("bsplot.png"))
        finally:
            # Clean up even when a step above fails, so a failing run leaves
            # neither the image file nor open figures behind.
            if os.path.isfile("bsplot.png"):
                os.remove("bsplot.png")
            plt.close("all")
def banddos(pref='', storedir=None):
    """
    Plot the band structure from 'vasprun.xml'/'KPOINTS' to 'BAND.png'.

    Energies are plotted relative to the Fermi level on a fixed y-range of
    -4 to 4. The band gap (rounded to 3 decimals) and the last k-path
    distance are printed to stdout.

    Args:
        pref (str): currently unused by this function.
        storedir: currently unused by this function.
    """
    ru = str("vasprun.xml")
    kpfile = str("KPOINTS")

    run = Vasprun(ru, parse_projected_eigen=True)
    bands = run.get_band_structure(kpfile, line_mode=True, efermi=run.efermi)
    bsp = BSPlotter(bands)
    zero_to_efermi = True

    bandgap = str(round(bands.get_band_gap()['energy'], 3))
    # Single-argument print call: same output under Python 2 and Python 3.
    print("bg= %s" % bandgap)

    data = bsp.bs_plot_data(zero_to_efermi)
    plt = get_publication_quality_plot(12, 8)
    band_linewidth = 3
    x_max = data['distances'][-1][-1]
    print(x_max)

    spin_polarized = bsp._bs.is_spin_polarized
    for d, distances in enumerate(data['distances']):
        branch = data['energy'][d]
        n_kpts = len(distances)
        for i in range(bsp._nb_bands):
            # Spin-up bands: solid blue lines.
            plt.plot(distances,
                     [branch['1'][i][j] for j in range(n_kpts)],
                     'b-', linewidth=band_linewidth)
            if spin_polarized:
                # Spin-down bands: dashed red lines.
                plt.plot(distances,
                         [branch['-1'][i][j] for j in range(n_kpts)],
                         'r--', linewidth=band_linewidth)

    bsp._maketicks(plt)

    # Mark the conduction band minima (red) and valence band maxima (green).
    for cbm in data['cbm']:
        plt.scatter(cbm[0], cbm[1], color='r', marker='o', s=100)
    for vbm in data['vbm']:
        plt.scatter(vbm[0], vbm[1], color='g', marker='o', s=100)

    plt.xlabel(r'$\mathrm{Wave\ Vector}$', fontsize=30)
    ylabel = r'$\mathrm{E\ -\ E_f\ (eV)}$' if zero_to_efermi \
        else r'$\mathrm{Energy\ (eV)}$'
    plt.ylabel(ylabel, fontsize=30)
    plt.ylim(-4, 4)
    plt.xlim(0, x_max)
    plt.tight_layout()
    # savefig selects the output type through `format`.
    plt.savefig('BAND.png', format="png")
    plt.close()
class BSPlotterTest(unittest.TestCase):
    """Checks BSPlotter plot data and plotting on the CaO_2605 fixture."""

    def setUp(self):
        # Deserialize the band structure fixture and wrap it in a plotter.
        fixture = os.path.join(test_dir, "CaO_2605_bandstructure.json")
        with open(fixture, "r", encoding='utf-8') as f:
            d = json.loads(f.read())
        self.bs = BandStructureSymmLine.from_dict(d)
        self.plotter = BSPlotter(self.bs)
        # Silence warnings for the duration of each test.
        warnings.simplefilter("ignore")

    def tearDown(self):
        warnings.resetwarnings()

    def test_bs_plot_data(self):
        # The plot data is built once and shared by every assertion.
        plot_data = self.plotter.bs_plot_data()
        branches = plot_data['distances']
        self.assertEqual(len(branches[0]), 16,
                         "wrong number of distances in the first branch")
        self.assertEqual(len(branches), 10, "wrong number of branches")
        self.assertEqual(sum(len(e) for e in branches), 160,
                         "wrong number of distances")
        tick_labels = plot_data['ticks']['label']
        self.assertEqual(tick_labels[5], "K", "wrong tick label")
        self.assertEqual(len(tick_labels), 19, "wrong number of tick labels")

    # Minimal baseline testing for get_plot. not a true test. Just checks that
    # it can actually execute.
    def test_get_plot(self):
        # zero_to_efermi = True, ylim = None, smooth = False,
        # vbm_cbm_marker = False, smooth_tol = None

        # Disabling latex is needed for this test to work.
        from matplotlib import pyplot, rc
        rc('text', usetex=False)

        try:
            plt = self.plotter.get_plot()
            plt = self.plotter.get_plot(smooth=True)
            plt = self.plotter.get_plot(vbm_cbm_marker=True)
            self.plotter.save_plot("bsplot.png")
            self.assertTrue(os.path.isfile("bsplot.png"))
        finally:
            # Always remove the image and release the figures created above,
            # including when one of the steps fails.
            if os.path.isfile("bsplot.png"):
                os.remove("bsplot.png")
            pyplot.close("all")
def get_bandsxy(self, bs, bandrange):
    """Return [x, yu, yd] plot coordinates for the selected bands.

    x is the flattened list of k-path distances from BSPlotter; yu holds,
    for each band index in `bandrange`, the spin-up energies shifted by
    bs.efermi; yd is the same for spin-down, or None when `bs` is not
    spin polarized.
    """
    plot_data = BSPlotter(bs).bs_plot_data()

    # Flatten the per-branch distance lists into one x axis.
    x = []
    for kbranch in plot_data["distances"]:
        for k in kbranch:
            x.append(k)

    yu = []
    for band in bandrange:
        yu.append([e - bs.efermi for e in bs.bands[Spin.up][band]])

    yd = None
    if bs.is_spin_polarized:
        yd = []
        for band in bandrange:
            yd.append([e - bs.efermi for e in bs.bands[Spin.down][band]])

    return [x, yu, yd]
def get_plot(
    self,
    n_idx,
    t_idx,
    zero_to_efermi=True,
    estep=0.01,
    line_density=100,
    height=3.2,
    width=3.2,
    emin=None,
    emax=None,
    amin=5e-5,
    amax=1e-1,
    ylabel="Energy (eV)",
    plt=None,
    aspect=None,
    kpath=None,
    cmap="viridis",
    colorbar=True,
    style=None,
    no_base_style=False,
    fonts=None,
):
    """
    Plot a band structure broadened by the scattering rates as a colour mesh.

    For every band and k-point a Lorentzian (centred on the band energy,
    width 10**rate * hbar / 2) is evaluated on the energy grid and the
    contributions of all bands are summed into the mesh.

    Args:
        n_idx: Passed to ``self._get_interpolater``.
        t_idx: Passed to ``self._get_interpolater``.
        zero_to_efermi: Shift band energies (and default energy limits) so
            that the Fermi level is at zero.
        estep: Spacing of the energy grid.
        line_density: Passed to the interpolater's line-mode calculation.
        height: Figure height.
        width: Figure width.
        emin: Lower energy limit. When None, the lower Fermi-Dirac cutoff
            is used (shifted by the Fermi level if ``zero_to_efermi``).
            An explicit value, including 0, is used as given.
        emax: Upper energy limit; handled like ``emin``.
        amin: Lower limit of the logarithmic colour scale.
        amax: Upper limit of the logarithmic colour scale.
        ylabel: y-axis label.
        plt: Either an axes object to draw on, or a value forwarded to
            ``pretty_plot``.
        aspect: Passed to ``_makeplot``.
        kpath: Passed to the interpolater's line-mode calculation.
        cmap: Colour map name.
        colorbar: Whether to add a colour bar next to the axes.
        style: Not used in this function body.
        no_base_style: Not used in this function body.
        fonts: Not used in this function body.

    Returns:
        The ``plt`` object in use: the axes that were passed in, or the
        object returned by ``pretty_plot``.
    """
    interpolater = self._get_interpolater(n_idx, t_idx)

    bs, prop = interpolater.get_line_mode_band_structure(
        line_density=line_density,
        return_other_properties=True,
        kpath=kpath,
        symprec=self.symprec,
    )
    bs, rates = force_branches(bs, {s: p["rates"] for s, p in prop.items()})

    fd_emin, fd_emax = self.fd_cutoffs
    # Compare against None so that an explicit limit of 0 is respected.
    if emin is None:
        emin = fd_emin
        if zero_to_efermi:
            emin -= bs.efermi

    if emax is None:
        emax = fd_emax
        if zero_to_efermi:
            emax -= bs.efermi

    logger.info("Plotting band structure")

    if isinstance(plt, (Axis, SubplotBase)):
        ax = plt
    else:
        plt = pretty_plot(width=width, height=height, plt=plt)
        ax = plt.gca()

    if zero_to_efermi:
        bs.bands = {s: b - bs.efermi for s, b in bs.bands.items()}
        bs.efermi = 0

    bs_plotter = BSPlotter(bs)
    plot_data = bs_plotter.bs_plot_data(zero_to_efermi=zero_to_efermi)

    energies = np.linspace(emin, emax, int((emax - emin) / estep))
    distances = np.array([d for x in plot_data["distances"] for d in x])

    # rates are currently log(rate)
    mesh_data = np.full((len(distances), len(energies)), 0.0)
    for spin in self.spins:
        for spin_energies, spin_rates in zip(bs.bands[spin], rates[spin]):
            for d_idx in range(len(distances)):
                energy = spin_energies[d_idx]
                linewidth = 10 ** spin_rates[d_idx] * hbar / 2

                broadening = lorentzian(energies, energy, linewidth)
                broadening /= 1000  # convert 1/eV to 1/meV
                mesh_data[d_idx] += broadening

    im = ax.pcolormesh(
        distances,
        energies,
        mesh_data.T,
        rasterized=True,
        cmap=cmap,
        norm=LogNorm(vmin=amin, vmax=amax),
        shading="auto",
    )

    if colorbar:
        # Go through the axes' own figure so this also works when an axes
        # object (rather than pyplot) was supplied as `plt`.
        fig = ax.get_figure()
        pos = ax.get_position()
        cax = fig.add_axes([pos.x1 + 0.035, pos.y0, 0.035, pos.height])
        cbar = fig.colorbar(im, cax=cax)
        cbar.ax.tick_params(axis="y", length=rcParams["ytick.major.size"] * 0.5)
        cbar.ax.set_ylabel(
            r"$A_\mathbf{k}$ (meV$^{-1}$)", rotation=270, va="bottom"
        )

    _maketicks(ax, bs_plotter, ylabel=ylabel)
    _makeplot(
        ax,
        plot_data,
        bs,
        zero_to_efermi=zero_to_efermi,
        width=width,
        height=height,
        ymin=emin,
        ymax=emax,
        aspect=aspect,
    )
    return plt
def get_bandstructure_traces(bs, path_convention, energy_window=(-6.0, 10.0)):
    """
    Build plotly trace dictionaries for a band structure.

    Args:
        bs: band structure object; for the "lm" convention it is first
            passed through ``HighSymmKpath.get_continuous_path``.
        path_convention (str): k-path convention; only "lm" triggers
            special handling here.
        energy_window (tuple): (min, max) energies. Only bands with at
            least one value at or below the max and one at or above the
            min in some segment are plotted; the window is also the
            extent of the vertical separator lines.

    Returns:
        tuple: (bstraces, bs_data), where bstraces is the list of trace
        dicts (band lines, vertical separators, CBM/VBM markers) and
        bs_data is the BSPlotter plot data with tick labels rewritten to
        plain unicode text.
    """
    if path_convention == "lm":
        bs = HighSymmKpath.get_continuous_path(bs)

    bs_reg_plot = BSPlotter(bs)
    bs_data = bs_reg_plot.bs_plot_data(split_branches=False)

    # Select each band at most once, even if it enters the energy window in
    # several segments; duplicates would be drawn as duplicate traces.
    bands = []
    for band_num in range(bs.nb_bands):
        for segment in bs_data["energy"][str(Spin.up)]:
            if any(segment[band_num] <= energy_window[1]) and any(
                    segment[band_num] >= energy_window[0]):
                bands.append(band_num)
                break

    bstraces = []

    cbm = bs.get_cbm()
    vbm = bs.get_vbm()

    cbm_new = bs_data["cbm"]
    vbm_new = bs_data["vbm"]

    bar_loc = []

    for d, dist_val in enumerate(bs_data["distances"]):
        x_dat = dist_val

        traces_for_segment = []

        segment = bs_data["energy"][str(Spin.up)][d]

        traces_for_segment += [{
            "x": x_dat,
            "y": segment[band_num],
            "mode": "lines",
            "line": {
                "color": "#1f77b4"
            },
            "hoverinfo": "skip",
            "name": "spin ↑" if bs.is_spin_polarized else "Total",
            "hovertemplate": "%{y:.2f} eV",
            "showlegend": False,
            "xaxis": "x",
            "yaxis": "y",
        } for band_num in bands]

        if bs.is_spin_polarized:
            traces_for_segment += [{
                "x": x_dat,
                "y": [
                    bs_data["energy"][str(Spin.down)][d][i][j]
                    for j in range(len(bs_data["distances"][d]))
                ],
                "mode": "lines",
                "line": {
                    "color": "#ff7f0e",
                    "dash": "dot"
                },
                "hoverinfo": "skip",
                "showlegend": False,
                "name": "spin ↓",
                "hovertemplate": "%{y:.2f} eV",
                "xaxis": "x",
                "yaxis": "y",
            } for i in bands]

        bstraces += traces_for_segment

        # The end of every segment gets a vertical separator line.
        bar_loc.append(dist_val[-1])

    # - Strip latex math wrapping for labels
    str_replace = {
        "$": "",
        "\\mid": "|",
        "\\Gamma": "Γ",
        "\\Sigma": "Σ",
        "GAMMA": "Γ",
        "_1": "₁",
        "_2": "₂",
        "_3": "₃",
        "_4": "₄",
        "_{1}": "₁",
        "_{2}": "₂",
        "_{3}": "₃",
        "_{4}": "₄",
        "^{*}": "*",
    }

    tick_labels = bs_data["ticks"]["label"]
    for entry_num, label in enumerate(tick_labels):
        # Replacements are applied in the order of the table above.
        for old, new in str_replace.items():
            label = label.replace(old, new)
        tick_labels[entry_num] = label

    # Vertical lines for disjointed segments
    vert_traces = [{
        "x": [x_point, x_point],
        "y": energy_window,
        "mode": "lines",
        "marker": {
            "color": "white"
        },
        "hoverinfo": "skip",
        "showlegend": False,
        "xaxis": "x",
        "yaxis": "y",
    } for x_point in bar_loc]

    bstraces += vert_traces

    def _edge_markers(points, edge, template):
        # One marker trace per distinct (x, y) point. The hover text is
        # formatted per point so `edge` is only read when points exist.
        return [{
            "x": [x_point],
            "y": [y_point],
            "mode": "markers",
            "marker": {
                "color": "#7E259B",
                "size": 16,
                "line": {
                    "color": "white",
                    "width": 2
                },
            },
            "showlegend": False,
            "hoverinfo": "text",
            "name": "",
            "hovertemplate": template.format(
                list(edge["kpoint"].frac_coords), edge["energy"]),
            "xaxis": "x",
            "yaxis": "y",
        } for (x_point, y_point) in set(points)]

    # Dots for cbm and vbm
    dot_traces = (_edge_markers(cbm_new, cbm, "CBM: k = {}, {} eV") +
                  _edge_markers(vbm_new, vbm, "VBM: k = {}, {} eV"))

    bstraces += dot_traces

    return bstraces, bs_data
# --- Density of states plot ---
dosplotter = DosPlotter()
Totaldos = dosplotter.add_dos('Total DOS', tdos)
Integrateddos = dosplotter.add_dos('Integrated DOS', idos)
#Pdos = dosplotter.add_dos('Partial DOS',pdos)
#Spd_dos = dosplotter.add_dos('spd DOS',spd_dos)
#Element_dos = dosplotter.add_dos('Element DOS',element_dos)
#Element_spd_dos = dosplotter.add_dos('Element_spd DOS',element_spd_dos)
dos_dict = {
    'Total DOS': tdos,
    'Integrated DOS': idos
}  #'Partial DOS':pdos,'spd DOS':spd_dos,'Element DOS':element_dos}#'Element_spd DOS':element_spd_dos
add_dos_dict = dosplotter.add_dos_dict(dos_dict)
get_dos_dict = dosplotter.get_dos_dict()
dos_plot = dosplotter.get_plot()
##dosplotter.save_plot("MAPbI3_dos",img_format="png")
##dos_plot.show()

# --- Band structure plot ---
bsplotter = BSPlotter(bs)
bs_plot_data = bsplotter.bs_plot_data()
bs_plot = bsplotter.get_plot()
#bsplotter.save_plot("MAPbI3_bs",img_format="png")
#bsplotter.show()
ticks = bsplotter.get_ticks()
# Call form of print: valid on Python 3 and, for a single argument,
# prints the same text on Python 2.
print(ticks)
bsplotter.plot_brillouin()

# --- Combined band structure / DOS plot ---
bsdos = BSDOSPlotter(
    tick_fontsize=10,
    egrid_interval=20,
    dos_projection="orbitals",
    bs_legend=None)  #bs_projection="HPbCIN",dos_projection="HPbCIN")
bds = bsdos.get_plot(bs, cdos)
def bs_dos_traces(bandStructureSymmLine, densityOfStates):
    """
    Build plotly traces for a band structure and its density of states.

    Args:
        bandStructureSymmLine: dict representation of a band structure,
            the string "error", or None.
        densityOfStates: dict representation of a complete DOS, the
            string "error", or None.

    Returns:
        The string "error" if either input is "error"; otherwise the list
        [bstraces, dostraces, bs_data] holding the band structure traces,
        the DOS traces and the BSPlotter plot data (with tick labels
        rewritten to plain unicode text).

    Raises:
        PreventUpdate: if either input is None.
    """
    if bandStructureSymmLine == "error" or densityOfStates == "error":
        return "error"

    if bandStructureSymmLine is None or densityOfStates is None:
        raise PreventUpdate

    # - BS Data
    bstraces = []

    bs_reg_plot = BSPlotter(BSML.from_dict(bandStructureSymmLine))
    bs_data = bs_reg_plot.bs_plot_data()

    # -- Strip latex math wrapping
    str_replace = {
        "$": "",
        "\\mid": "|",
        "\\Gamma": "Γ",
        "\\Sigma": "Σ",
        "_1": "₁",
        "_2": "₂",
        "_3": "₃",
        "_4": "₄",
    }

    tick_labels = bs_data["ticks"]["label"]
    for entry_num, label in enumerate(tick_labels):
        # Replacements are applied in the order of the table above.
        for old, new in str_replace.items():
            label = label.replace(old, new)
        tick_labels[entry_num] = label

    spin_polarized = bs_reg_plot._bs.is_spin_polarized

    for d, distances in enumerate(bs_data["distances"]):
        n_kpts = len(distances)
        for i in range(bs_reg_plot._nb_bands):
            bstraces.append(
                go.Scatter(
                    x=distances,
                    y=[
                        bs_data["energy"][d][str(Spin.up)][i][j]
                        for j in range(n_kpts)
                    ],
                    mode="lines",
                    line=dict(color=("#666666"), width=2),
                    hoverinfo="skip",
                    showlegend=False,
                ))

            if spin_polarized:
                bstraces.append(
                    go.Scatter(
                        x=distances,
                        y=[
                            bs_data["energy"][d][str(Spin.down)][i][j]
                            for j in range(n_kpts)
                        ],
                        mode="lines",
                        line=dict(color=("#666666"), width=2, dash="dash"),
                        hoverinfo="skip",
                        showlegend=False,
                    ))

    # -- DOS Data
    dostraces = []

    dos = CompleteDos.from_dict(densityOfStates)

    if Spin.down in dos.densities:
        # Add second spin data if available
        trace_tdos = go.Scatter(
            x=dos.densities[Spin.down],
            y=dos.energies - dos.efermi,
            mode="lines",
            name="Total DOS (spin ↓)",
            line=go.scatter.Line(color="#444444", dash="dash"),
            fill="tozeroy",
        )

        dostraces.append(trace_tdos)

        tdos_label = "Total DOS (spin ↑)"
    else:
        tdos_label = "Total DOS"

    # Total DOS
    trace_tdos = go.Scatter(
        x=dos.densities[Spin.up],
        y=dos.energies - dos.efermi,
        mode="lines",
        name=tdos_label,
        line=go.scatter.Line(color="#444444"),
        fill="tozeroy",
        legendgroup="spinup",
    )

    dostraces.append(trace_tdos)

    p_ele_dos = dos.get_element_dos()

    # Projected DOS
    colors = [
        "#1f77b4",  # muted blue
        "#ff7f0e",  # safety orange
        "#2ca02c",  # cooked asparagus green
        "#d62728",  # brick red
        "#9467bd",  # muted purple
        "#8c564b",  # chestnut brown
        "#e377c2",  # raspberry yogurt pink
        "#bcbd22",  # curry yellow-green
        "#17becf",  # blue-teal
    ]

    for count, ele in enumerate(p_ele_dos.keys()):
        # Cycle through the palette so materials with more elements than
        # colours do not fail with an IndexError.
        color = colors[count % len(colors)]

        # NOTE(review): the spin-down element traces are gated on the band
        # structure's spin polarization, while the total DOS above checks
        # the DOS itself -- confirm the two always agree.
        if spin_polarized:
            trace = go.Scatter(
                x=p_ele_dos[ele].densities[Spin.down],
                y=dos.energies - dos.efermi,
                mode="lines",
                name=ele.symbol + " (spin ↓)",
                line=dict(width=3, color=color, dash="dash"),
            )

            dostraces.append(trace)
            spin_up_label = ele.symbol + " (spin ↑)"

        else:
            spin_up_label = ele.symbol

        trace = go.Scatter(
            x=p_ele_dos[ele].densities[Spin.up],
            y=dos.energies - dos.efermi,
            mode="lines",
            name=spin_up_label,
            line=dict(width=3, color=color),
        )

        dostraces.append(trace)

    traces = [bstraces, dostraces, bs_data]

    return traces
def get_plot(
    self,
    n_idx,
    t_idx,
    zero_to_efermi=True,
    estep=0.01,
    line_density=100,
    height=6,
    width=6,
    emin=None,
    emax=None,
    ylabel="Energy (eV)",
    plt=None,
    aspect=None,
    distance_factor=10,
    kpath=None,
    style=None,
    no_base_style=False,
    fonts=None,
):
    """
    Plot a band structure broadened by the scattering rates as a colour mesh.

    Band energies and (smoothed) rates are interpolated onto a finer
    distance grid; for every band and grid point a Lorentzian (centred on
    the band energy, width 10**rate * hbar / 2) is evaluated on the energy
    grid, and the mesh keeps the element-wise maximum over all bands.

    Args:
        n_idx: Passed to ``self._get_interpolater``.
        t_idx: Passed to ``self._get_interpolater``.
        zero_to_efermi: Shift band energies (and default energy limits) so
            that the Fermi level is at zero.
        estep: Spacing of the energy grid.
        line_density: Passed to the interpolater's line-mode calculation.
        height: Figure height.
        width: Figure width.
        emin: Lower energy limit. When None, the lower Fermi-Dirac cutoff
            times ``hartree_to_ev`` is used (shifted by the Fermi level if
            ``zero_to_efermi``). An explicit value, including 0, is used
            as given.
        emax: Upper energy limit; handled like ``emin``.
        ylabel: y-axis label.
        plt: Forwarded to ``pretty_plot``.
        aspect: Passed to ``_makeplot``.
        distance_factor: Multiplier for the number of points on the
            interpolated distance grid.
        kpath: Passed to the interpolater's line-mode calculation.
        style: Not used in this function body.
        no_base_style: Not used in this function body.
        fonts: Not used in this function body.

    Returns:
        The object returned by ``pretty_plot``.
    """
    interpolater = self._get_interpolater(n_idx, t_idx)

    bs, prop = interpolater.get_line_mode_band_structure(
        line_density=line_density,
        return_other_properties=True,
        kpath=kpath)

    fd_emin, fd_emax = self.fd_cutoffs
    # Compare against None so that an explicit limit of 0 is respected.
    if emin is None:
        emin = fd_emin * hartree_to_ev
        if zero_to_efermi:
            emin -= bs.efermi

    if emax is None:
        emax = fd_emax * hartree_to_ev
        if zero_to_efermi:
            emax -= bs.efermi

    logger.info("Plotting band structure")
    plt = pretty_plot(width=width, height=height, plt=plt)
    ax = plt.gca()

    if zero_to_efermi:
        bs.bands = {s: b - bs.efermi for s, b in bs.bands.items()}
        bs.efermi = 0

    bs_plotter = BSPlotter(bs)
    plot_data = bs_plotter.bs_plot_data(zero_to_efermi=zero_to_efermi)

    energies = np.linspace(emin, emax, int((emax - emin) / estep))
    distances = np.array([d for x in plot_data["distances"] for d in x])

    # rates are currently log(rate)
    rates = {}
    for spin, spin_data in prop.items():
        rates[spin] = spin_data["rates"]
        # Non-positive values are replaced by the smallest positive one and
        # values are capped at 15 (both done in place on the rates array).
        rates[spin][rates[spin] <= 0] = np.min(
            rates[spin][rates[spin] > 0])
        rates[spin][rates[spin] >= 15] = 15

    interp_distances = np.linspace(distances.min(), distances.max(),
                                   int(len(distances) * distance_factor))

    # Smoothing window for savgol_filter; the increment makes it odd.
    window = np.min([len(distances) - 2, 71])
    window += window % 2 + 1

    mesh_data = np.full((len(interp_distances), len(energies)), 1e-2)
    for spin in self.spins:
        for spin_energies, spin_rates in zip(bs.bands[spin], rates[spin]):
            interp_energies = interp1d(distances,
                                       spin_energies)(interp_distances)
            spin_rates = savgol_filter(spin_rates, window, 3)
            interp_rates = interp1d(distances, spin_rates)(interp_distances)
            linewidths = 10**interp_rates * hbar / 2

            for d_idx in range(len(interp_distances)):
                energy = interp_energies[d_idx]
                linewidth = linewidths[d_idx]

                broadening = lorentzian(energies, energy, linewidth)
                mesh_data[d_idx] = np.maximum(broadening, mesh_data[d_idx])

    ax.pcolormesh(
        interp_distances,
        energies,
        mesh_data.T,
        rasterized=True,
        norm=LogNorm(vmin=mesh_data.min(), vmax=mesh_data.max()),
    )

    _maketicks(ax, bs_plotter, ylabel=ylabel)
    _makeplot(
        ax,
        plot_data,
        bs,
        zero_to_efermi=zero_to_efermi,
        width=width,
        height=height,
        ymin=emin,
        ymax=emax,
        aspect=aspect,
    )
    return plt
def bs_dos_data(
        mpid,
        path_convention,
        dos_select,
        label_select,
        bandstructure_symm_line,
        density_of_states,
):
    """
    Build plotly traces for a band structure and its density of states.

    When a band structure or DOS is missing, both are fetched from the
    database stores using ``mpid``; otherwise the supplied objects are
    used.

    Args:
        mpid: task identifier used for the database lookup (converted with
            ``int``).
        path_convention (str): k-path convention key; "lm" additionally
            makes the path continuous and remaps the CBM/VBM positions.
        dos_select (str): "ap" for element projections, "op" for orbital
            projections, or a string containing "orb" plus an element
            symbol for that element's orbital projections.
        label_select (str): label convention; when it differs from
            ``path_convention`` during a database lookup, the fetched
            labels are replaced by their equivalents.
        bandstructure_symm_line: band structure object or dict, or None.
        density_of_states: DOS object or dict, or None.

    Returns:
        tuple: (traces, elements) where traces is
        [bstraces, dostraces, bs_data] and elements is the list of
        element names present in the DOS.

    Raises:
        PreventUpdate: if data is missing and no ``mpid`` is given, if a
            lookup is needed but ``label_select`` is empty, or if
            ``dos_select`` is not recognised.
    """
    if not mpid and (bandstructure_symm_line is None
                     or density_of_states is None):
        raise PreventUpdate

    elif bandstructure_symm_line is None or density_of_states is None:
        if label_select == "":
            raise PreventUpdate

        # --
        # -- BS and DOS from API or DB
        # --

        bs_store = GridFSStore(
            database="fw_bs_prod",
            collection_name="bandstructure_fs",
            host="mongodb03.nersc.gov",
            port=27017,
            username="******",
            password="",
        )

        dos_store = GridFSStore(
            database="fw_bs_prod",
            collection_name="dos_fs",
            host="mongodb03.nersc.gov",
            port=27017,
            username="******",
            password="",
        )

        es_store = MongoStore(
            database="fw_bs_prod",
            collection_name="electronic_structure",
            host="mongodb03.nersc.gov",
            port=27017,
            username="******",
            password="",
            key="task_id",
        )

        # - BS traces from DB using task_id
        es_store.connect()
        bs_query = es_store.query_one(
            criteria={"task_id": int(mpid)},
            properties=[
                "bandstructure.{}.task_id".format(path_convention),
                "bandstructure.{}.total.equiv_labels".format(
                    path_convention),
            ],
        )
        es_store.close()

        bs_store.connect()
        bandstructure_symm_line = bs_store.query_one(criteria={
            "metadata.task_id":
            int(bs_query["bandstructure"][path_convention]["task_id"])
        }, )
        # Release the connection once the document has been fetched, as is
        # done for es_store.
        bs_store.close()

        # If LM convention, get equivalent labels
        if path_convention != label_select:
            bs_equiv_labels = bs_query["bandstructure"][
                path_convention]["total"]["equiv_labels"]

            equiv = bs_equiv_labels[label_select]
            old_labels = bandstructure_symm_line["labels_dict"]
            new_labels_dict = {}
            for label in old_labels.keys():

                label_formatted = label.replace("$", "")

                if "|" in label_formatted:
                    # Compound label: translate both halves and store the
                    # result in the replacement dict like any other label.
                    f_label = label_formatted.split("|")
                    new_key = ("$" + equiv[f_label[0]] + "|" +
                               equiv[f_label[1]] + "$")
                else:
                    new_key = "$" + equiv[label_formatted] + "$"

                new_labels_dict[new_key] = old_labels[label]

            bandstructure_symm_line["labels_dict"] = new_labels_dict

        # - DOS traces from DB using task_id
        es_store.connect()
        dos_query = es_store.query_one(
            criteria={"task_id": int(mpid)},
            properties=["dos.task_id"],
        )
        es_store.close()

        dos_store.connect()
        density_of_states = dos_store.query_one(
            criteria={"task_id": int(dos_query["dos"]["task_id"])}, )
        dos_store.close()

    # - BS Data
    if (not isinstance(bandstructure_symm_line, dict)
            and bandstructure_symm_line is not None):
        bandstructure_symm_line = bandstructure_symm_line.to_dict()

    if (not isinstance(density_of_states, dict)
            and density_of_states is not None):
        density_of_states = density_of_states.to_dict()

    bsml = BSML.from_dict(bandstructure_symm_line)

    bs_reg_plot = BSPlotter(bsml)

    bs_data = bs_reg_plot.bs_plot_data()

    # Make plot continous for lm
    if path_convention == "lm":
        distance_map, kpath_euler = HSKP(
            bsml.structure).get_continuous_path(bsml)

        kpath_labels = [pair[0] for pair in kpath_euler]
        kpath_labels.append(kpath_euler[-1][1])

    else:
        distance_map = [(i, False)
                        for i in range(len(bs_data["distances"]))]

        # Collapse consecutive duplicate tick labels.
        kpath_labels = []
        raw_labels = bs_data["ticks"]["label"]
        for label_ind in range(len(raw_labels) - 1):
            if raw_labels[label_ind] != raw_labels[label_ind + 1]:
                kpath_labels.append(raw_labels[label_ind])
        kpath_labels.append(raw_labels[-1])

    bs_data["ticks"]["label"] = kpath_labels

    # Obtain bands to plot over and generate traces for bs data:
    energy_window = (-6.0, 10.0)
    bands = []
    for band_num in range(bs_reg_plot._nb_bands):
        # Selection looks only at the first k-point of the first segment.
        first_energy = bs_data["energy"][0][str(Spin.up)][band_num][0]
        if (first_energy <= energy_window[1]) and (first_energy >=
                                                   energy_window[0]):
            bands.append(band_num)

    bstraces = []

    pmin = 0.0
    tick_vals = [0.0]

    cbm = bsml.get_cbm()
    vbm = bsml.get_vbm()

    cbm_new = bs_data["cbm"]
    vbm_new = bs_data["vbm"]

    spin_polarized = bs_reg_plot._bs.is_spin_polarized

    for dnum, (d, rev) in enumerate(distance_map):

        # Shift the segment so that it starts where the previous one ended.
        x_dat = [
            dval - bs_data["distances"][d][0] + pmin
            for dval in bs_data["distances"][d]
        ]

        pmin = x_dat[-1]

        tick_vals.append(pmin)

        # Reversed segments are plotted with their energies back to front.
        order = list(range(len(bs_data["distances"][d])))
        if rev:
            order.reverse()

        traces_for_segment = [{
            "x": x_dat,
            "y": [bs_data["energy"][d][str(Spin.up)][i][j] for j in order],
            "mode": "lines",
            "line": {
                "color": "#1f77b4"
            },
            "hoverinfo": "skip",
            "name": "spin ↑" if spin_polarized else "Total",
            "hovertemplate": "%{y:.2f} eV",
            "showlegend": False,
            "xaxis": "x",
            "yaxis": "y",
        } for i in bands]

        if spin_polarized:
            traces_for_segment += [{
                "x": x_dat,
                "y": [
                    bs_data["energy"][d][str(Spin.down)][i][j]
                    for j in order
                ],
                "mode": "lines",
                "line": {
                    "color": "#ff7f0e",
                    "dash": "dot"
                },
                "hoverinfo": "skip",
                "showlegend": False,
                "name": "spin ↓",
                "hovertemplate": "%{y:.2f} eV",
                "xaxis": "x",
                "yaxis": "y",
            } for i in bands]

        bstraces += traces_for_segment

        # - Get proper cbm and vbm coords for lm
        if path_convention == "lm":
            for edge, edge_points in ((cbm, cbm_new), (vbm, vbm_new)):
                # Iterate over a snapshot: edge_points is the very list the
                # remapped points are appended to.
                for (x_point, y_point) in list(edge_points):
                    if x_point in bs_data["distances"][d]:
                        xind = bs_data["distances"][d].index(x_point)
                        if not rev:
                            x_point_new = x_dat[xind]
                        else:
                            x_point_new = x_dat[len(x_dat) - xind - 1]

                        new_label = bs_data["ticks"]["label"][
                            tick_vals.index(x_point_new)]

                        if (edge["kpoint"].label is None
                                or edge["kpoint"].label in new_label):
                            edge_points.append((x_point_new, y_point))

    bs_data["ticks"]["distance"] = tick_vals

    # - Strip latex math wrapping for labels
    str_replace = {
        "$": "",
        "\\mid": "|",
        "\\Gamma": "Γ",
        "\\Sigma": "Σ",
        "GAMMA": "Γ",
        "_1": "₁",
        "_2": "₂",
        "_3": "₃",
        "_4": "₄",
        "_{1}": "₁",
        "_{2}": "₂",
        "_{3}": "₃",
        "_{4}": "₄",
        "^{*}": "*",
    }

    bar_loc = []
    for entry_num in range(len(bs_data["ticks"]["label"])):
        for key in str_replace.keys():
            if key in bs_data["ticks"]["label"][entry_num]:
                bs_data["ticks"]["label"][entry_num] = bs_data[
                    "ticks"]["label"][entry_num].replace(
                        key, str_replace[key])
                # A "\mid" label marks a discontinuity in the path.
                if key == "\\mid":
                    bar_loc.append(
                        bs_data["ticks"]["distance"][entry_num])

    # Vertical lines for disjointed segments
    vert_traces = [{
        "x": [x_point, x_point],
        "y": energy_window,
        "mode": "lines",
        "marker": {
            "color": "white"
        },
        "hoverinfo": "skip",
        "showlegend": False,
        "xaxis": "x",
        "yaxis": "y",
    } for x_point in bar_loc]

    bstraces += vert_traces

    def _edge_markers(points, edge, template):
        # One marker trace per distinct (x, y) point. The hover text is
        # formatted per point so `edge` is only read when points exist.
        return [{
            "x": [x_point],
            "y": [y_point],
            "mode": "markers",
            "marker": {
                "color": "#7E259B",
                "size": 16,
                "line": {
                    "color": "white",
                    "width": 2
                },
            },
            "showlegend": False,
            "hoverinfo": "text",
            "name": "",
            "hovertemplate": template.format(
                list(edge["kpoint"].frac_coords), edge["energy"]),
            "xaxis": "x",
            "yaxis": "y",
        } for (x_point, y_point) in set(points)]

    # Dots for cbm and vbm
    dot_traces = (_edge_markers(cbm_new, cbm, "CBM: k = {}, {} eV") +
                  _edge_markers(vbm_new, vbm, "VBM: k = {}, {} eV"))

    bstraces += dot_traces

    # - DOS Data
    dostraces = []

    dos = CompleteDos.from_dict(density_of_states)

    # Indices of the DOS energies closest to the window edges.
    dos_max = np.abs(
        (dos.energies - dos.efermi - energy_window[1])).argmin()
    dos_min = np.abs(
        (dos.energies - dos.efermi - energy_window[0])).argmin()

    if spin_polarized:
        # Add second spin data if available
        trace_tdos = {
            "x": -1.0 * dos.densities[Spin.down][dos_min:dos_max],
            "y": dos.energies[dos_min:dos_max] - dos.efermi,
            "mode": "lines",
            "name": "Total DOS (spin ↓)",
            "line": go.scatter.Line(color="#444444", dash="dot"),
            "fill": "tozerox",
            "fillcolor": "#C4C4C4",
            "xaxis": "x2",
            "yaxis": "y2",
        }

        dostraces.append(trace_tdos)

        tdos_label = "Total DOS (spin ↑)"
    else:
        tdos_label = "Total DOS"

    # Total DOS
    trace_tdos = {
        "x": dos.densities[Spin.up][dos_min:dos_max],
        "y": dos.energies[dos_min:dos_max] - dos.efermi,
        "mode": "lines",
        "name": tdos_label,
        "line": go.scatter.Line(color="#444444"),
        "fill": "tozerox",
        "fillcolor": "#C4C4C4",
        "legendgroup": "spinup",
        "xaxis": "x2",
        "yaxis": "y2",
    }

    dostraces.append(trace_tdos)

    ele_dos = dos.get_element_dos()
    elements = [str(entry) for entry in ele_dos.keys()]

    if dos_select == "ap":
        proj_data = ele_dos
    elif dos_select == "op":
        proj_data = dos.get_spd_dos()
    elif "orb" in dos_select:
        proj_data = dos.get_element_spd_dos(
            Element(dos_select.replace("orb", "")))
    else:
        raise PreventUpdate

    # Projected DOS
    colors = [
        "#d62728",  # brick red
        "#2ca02c",  # cooked asparagus green
        "#17becf",  # blue-teal
        "#bcbd22",  # curry yellow-green
        "#9467bd",  # muted purple
        "#8c564b",  # chestnut brown
        "#e377c2",  # raspberry yogurt pink
    ]

    for count, label in enumerate(proj_data.keys()):
        # Cycle through the palette so more projections than colours do
        # not fail with an IndexError.
        color = colors[count % len(colors)]

        if spin_polarized:
            trace = {
                "x": -1.0 *
                proj_data[label].densities[Spin.down][dos_min:dos_max],
                "y": dos.energies[dos_min:dos_max] - dos.efermi,
                "mode": "lines",
                "name": str(label) + " (spin ↓)",
                "line": dict(width=3, color=color, dash="dot"),
                "xaxis": "x2",
                "yaxis": "y2",
            }

            dostraces.append(trace)
            spin_up_label = str(label) + " (spin ↑)"

        else:
            spin_up_label = str(label)

        trace = {
            "x": proj_data[label].densities[Spin.up][dos_min:dos_max],
            "y": dos.energies[dos_min:dos_max] - dos.efermi,
            "mode": "lines",
            "name": spin_up_label,
            "line": dict(width=2, color=color),
            "xaxis": "x2",
            "yaxis": "y2",
        }

        dostraces.append(trace)

    traces = [bstraces, dostraces, bs_data]

    return (traces, elements)
def bs_dos_data(
    mpid,
    path_convention,
    dos_select,
    label_select,
    bandstructure_symm_line,
    density_of_states,
):
    """Assemble band-structure and density-of-states plot traces.

    Args:
        mpid: Mapping expected to carry an ``"mpid"`` entry. Any truthy value
            currently aborts the update (see the NOTE on the API branch).
        path_convention: Key selecting the k-path convention in the stored
            band-structure document (only read by the API branch).
        dos_select: Projection to add on top of the total DOS: ``"ap"`` uses
            ``get_element_dos()``, ``"op"`` uses ``get_spd_dos()``, and any
            string containing ``"orb"`` uses ``get_element_spd_dos()`` for
            the element named by the rest of the string.
        label_select: Label convention used to translate tick labels when
            ``path_convention`` is ``"lm"`` (only read by the API branch).
        bandstructure_symm_line: Band structure, either as a dict or as an
            object exposing ``to_dict()``.
        density_of_states: Density of states, either as a dict or as an
            object exposing ``to_dict()``.

    Returns:
        A tuple ``(traces, elements)`` where ``traces`` is the list
        ``[bstraces, dostraces, bs_data]`` and ``elements`` lists the element
        labels found in the DOS.

    Raises:
        PreventUpdate: If ``mpid`` is truthy, if either the band structure or
            the DOS is missing, or if ``dos_select`` is not recognised.
    """
    # Ordered (LaTeX, Unicode) substitutions applied to every tick label.
    latex_to_unicode = (
        ("$", ""),
        ("\\mid", "|"),
        ("\\Gamma", "Γ"),
        ("\\Sigma", "Σ"),
        ("GAMMA", "Γ"),
        ("_1", "₁"),
        ("_2", "₂"),
        ("_3", "₃"),
        ("_4", "₄"),
    )

    def strip_latex(labels):
        """Rewrite the tick labels in place without LaTeX math markup."""
        for idx, text in enumerate(labels):
            for old, new in latex_to_unicode:
                text = text.replace(old, new)
            labels[idx] = text

    if (not mpid or "mpid" not in mpid) and (
            bandstructure_symm_line is None or density_of_states is None):
        raise PreventUpdate
    elif mpid:
        raise PreventUpdate
    elif bandstructure_symm_line is None or density_of_states is None:
        # --
        # -- BS and DOS from API
        # --
        # NOTE(review): this branch cannot be reached as written. A falsy
        # ``mpid`` with missing data is caught by the first test and a truthy
        # ``mpid`` by the second. It also relies on ``client``, whose
        # construction is commented out below - confirm before re-enabling.
        mpid = mpid["mpid"]
        bs_data = {"ticks": {}}

        # client = MongoClient(
        #     "mongodb03.nersc.gov", username="******", password="", authSource="fw_bs_prod",
        # )
        db = client.fw_bs_prod

        # - BS traces from DB using task_id
        bs_query = list(
            db.electronic_structure.find(
                {"task_id": int(mpid)},
                ["bandstructure.{}.total.traces".format(path_convention)],
            ))[0]
        stored = bs_query["bandstructure"][path_convention]["total"]["traces"]

        is_sp = len(stored) == 2
        if is_sp:
            bstraces = stored["1"] + stored["-1"]
        else:
            bstraces = stored["1"]

        bs_data["ticks"]["distance"] = stored["ticks"]
        bs_data["ticks"]["label"] = stored["labels"]

        # If LM convention, get equivalent labels
        if path_convention == "lm" and label_select != "lm":
            bs_equiv_labels = stored["equiv_labels"]

            alt_choice = label_select
            if label_select == "hin":
                alt_choice = "h"
            equivalents = bs_equiv_labels[alt_choice]

            new_labels = []
            for label in bs_data["ticks"]["label"]:
                label_formatted = label.replace("$", "")
                if "|" in label_formatted:
                    # Discontinuity in the path: translate both sides.
                    f_label = label_formatted.split("|")
                    new_labels.append("$" + equivalents[f_label[0]] + "|" +
                                      equivalents[f_label[1]] + "$")
                else:
                    new_labels.append(
                        "$" + equivalents[label_formatted] + "$")
            bs_data["ticks"]["label"] = new_labels

        # Strip latex math wrapping
        strip_latex(bs_data["ticks"]["label"])

        # - DOS traces from DB using task_id
        dos_tot_ele_traces = list(
            db.electronic_structure.find(
                {"task_id": int(mpid)},
                ["dos.total.traces", "dos.elements"]))[0]
        total_traces = dos_tot_ele_traces["dos"]["total"]["traces"]
        element_docs = dos_tot_ele_traces["dos"]["elements"]

        dostraces = [total_traces[spin] for spin in total_traces.keys()]
        elements = list(element_docs.keys())

        if dos_select == "ap":
            for ele_label in elements:
                ele_traces = element_docs[ele_label]["total"]["traces"]
                dostraces += [ele_traces[spin] for spin in ele_traces.keys()]
        elif dos_select == "op":
            orb_tot_traces = list(
                db.electronic_structure.find({"task_id": int(mpid)},
                                             ["dos.orbitals"]))[0]
            orbital_docs = orb_tot_traces["dos"]["orbitals"]
            for orbital in ["s", "p", "d"]:
                # NOTE(review): the spin keys are always taken from the "s"
                # entry, as in the original code - confirm this is intended.
                dostraces += [
                    orbital_docs[orbital]["traces"][spin]
                    for spin in orbital_docs["s"]["traces"].keys()
                ]
        elif "orb" in dos_select:
            ele_label = dos_select.replace("orb", "")
            for orbital in ["s", "p", "d"]:
                orb_traces = element_docs[ele_label][orbital]["traces"]
                dostraces += [orb_traces[spin] for spin in orb_traces.keys()]

        traces = [bstraces, dostraces, bs_data]
        return (traces, elements)
    else:
        # --
        # -- BS and DOS passed manually
        # --

        # - BS Data
        if not isinstance(bandstructure_symm_line, dict):
            bandstructure_symm_line = bandstructure_symm_line.to_dict()
        if not isinstance(density_of_states, dict):
            density_of_states = density_of_states.to_dict()

        bs_reg_plot = BSPlotter(BSML.from_dict(bandstructure_symm_line))
        bs_data = bs_reg_plot.bs_plot_data()
        is_spin_polarized = bs_reg_plot._bs.is_spin_polarized

        # - Strip latex math wrapping
        strip_latex(bs_data["ticks"]["label"])

        # Obtain bands to plot over: keep a band when its first energy in the
        # first branch lies inside the window.
        energy_window = (-6.0, 10.0)
        up_key = str(Spin.up)
        bands = [
            band_num for band_num in range(bs_reg_plot._nb_bands)
            if energy_window[0] <= bs_data["energy"][0][up_key][band_num][0]
            <= energy_window[1]
        ]

        def band_trace(distances, energies):
            """Return the plot dict for one band along one branch."""
            return {
                "x": distances,
                "y": [energies[j] for j in range(len(distances))],
                "mode": "lines",
                "line": {
                    "color": "#666666"
                },
                "hoverinfo": "skip",
                "showlegend": False,
            }

        # Generate traces for total BS data
        bstraces = []
        for d, dist_dat in enumerate(bs_data["distances"]):
            traces_for_segment = [
                band_trace(dist_dat, bs_data["energy"][d]["1"][i])
                for i in bands
            ]
            if is_spin_polarized:
                traces_for_segment += [
                    band_trace(dist_dat, bs_data["energy"][d]["-1"][i])
                    for i in bands
                ]
            bstraces += traces_for_segment

        # - DOS Data
        dostraces = []
        dos = CompleteDos.from_dict(density_of_states)

        # Indices of the grid points closest to the window edges, measured
        # relative to the Fermi level.
        dos_max = np.abs(
            (dos.energies - dos.efermi - energy_window[1])).argmin()
        dos_min = np.abs(
            (dos.energies - dos.efermi - energy_window[0])).argmin()
        shifted_energies = dos.energies[dos_min:dos_max] - dos.efermi

        if is_spin_polarized:
            # Add second spin data if available
            dostraces.append(
                go.Scatter(
                    x=dos.densities[Spin.down][dos_min:dos_max],
                    y=shifted_energies,
                    mode="lines",
                    name="Total DOS (spin ↓)",
                    line=go.scatter.Line(color="#444444", dash="dash"),
                    fill="tozerox",
                ))
            tdos_label = "Total DOS (spin ↑)"
        else:
            tdos_label = "Total DOS"

        # Total DOS
        dostraces.append(
            go.Scatter(
                x=dos.densities[Spin.up][dos_min:dos_max],
                y=shifted_energies,
                mode="lines",
                name=tdos_label,
                line=go.scatter.Line(color="#444444"),
                fill="tozerox",
                legendgroup="spinup",
            ))

        ele_dos = dos.get_element_dos()
        elements = [str(entry) for entry in ele_dos.keys()]

        if dos_select == "ap":
            proj_data = ele_dos
        elif dos_select == "op":
            proj_data = dos.get_spd_dos()
        elif "orb" in dos_select:
            proj_data = dos.get_element_spd_dos(
                Element(dos_select.replace("orb", "")))
        else:
            raise PreventUpdate

        # Projected DOS
        colors = [
            "#1f77b4",  # muted blue
            "#ff7f0e",  # safety orange
            "#2ca02c",  # cooked asparagus green
            "#9467bd",  # muted purple
            "#e377c2",  # raspberry yogurt pink
            "#d62728",  # brick red
            "#8c564b",  # chestnut brown
            "#bcbd22",  # curry yellow-green
            "#17becf",  # blue-teal
        ]

        for count, label in enumerate(proj_data.keys()):
            # Cycle through the palette so that more projections than colours
            # no longer raises IndexError.
            color = colors[count % len(colors)]
            projection = proj_data[label]

            if is_spin_polarized:
                dostraces.append(
                    go.Scatter(
                        x=projection.densities[Spin.down][dos_min:dos_max],
                        y=shifted_energies,
                        mode="lines",
                        name=str(label) + " (spin ↓)",
                        line=dict(width=3, color=color, dash="dash"),
                    ))
                spin_up_label = str(label) + " (spin ↑)"
            else:
                spin_up_label = str(label)

            dostraces.append(
                go.Scatter(
                    x=projection.densities[Spin.up][dos_min:dos_max],
                    y=shifted_energies,
                    mode="lines",
                    name=spin_up_label,
                    line=dict(width=3, color=color),
                ))

        traces = [bstraces, dostraces, bs_data]
        return (traces, elements)
def bandstr(vrun="", kpfile="", filename=".", plot=False):
    """
    Plot electronic bandstructure

    Args:
        vrun: path to vasprun.xml
        kpfile: path to line mode KPOINTS file
        filename: path the figure is written to when ``plot`` is truthy
        plot: if truthy, save the figure as PNG to ``filename`` and close it
    Returns:
        the plotting object obtained from ``get_publication_quality_plot``
    """
    run = Vasprun(vrun, parse_projected_eigen=True)
    bands = run.get_band_structure(kpfile, line_mode=True, efermi=run.efermi)
    bsp = BSPlotter(bands)

    zero_to_efermi = True
    data = bsp.bs_plot_data(zero_to_efermi)

    plt = get_publication_quality_plot(12, 8)
    plt.close()
    plt.clf()

    band_linewidth = 3
    # The last distance of the last branch is the right edge of the x axis.
    x_max = data["distances"][-1][-1]

    for distances, energies in zip(data["distances"], data["energy"]):
        n_points = len(distances)
        for i in range(bsp._nb_bands):
            plt.plot(
                distances,
                [energies["1"][i][j] for j in range(n_points)],
                "b-",
                linewidth=band_linewidth,
            )
            if bsp._bs.is_spin_polarized:
                plt.plot(
                    distances,
                    [energies["-1"][i][j] for j in range(n_points)],
                    "r--",
                    linewidth=band_linewidth,
                )

    bsp._maketicks(plt)

    # Mark the conduction band minima (red) and valence band maxima (green).
    for cbm in data["cbm"]:
        plt.scatter(cbm[0], cbm[1], color="r", marker="o", s=100)
    for vbm in data["vbm"]:
        plt.scatter(vbm[0], vbm[1], color="g", marker="o", s=100)

    plt.xlabel(r"$\mathrm{Wave\ Vector}$", fontsize=30)
    ylabel = (
        r"$\mathrm{E\ -\ E_f\ (eV)}$"
        if zero_to_efermi
        else r"$\mathrm{Energy\ (eV)}$"
    )
    plt.ylabel(ylabel, fontsize=30)
    plt.ylim(-4, 4)
    plt.xlim(0, x_max)
    plt.tight_layout()

    if plot:
        # ``format`` is the savefig keyword for the output type; the former
        # ``img_format`` keyword is not a savefig parameter.
        plt.savefig(filename, format="png")
        plt.close()

    return plt
class BSPlotterTest(unittest.TestCase):
    """Tests for BSPlotter with one and with several band structures."""

    @staticmethod
    def _load_band_structure(filename):
        """Read a band structure JSON file from ``test_dir``."""
        with open(os.path.join(test_dir, filename), "r",
                  encoding="utf-8") as f:
            return BandStructureSymmLine.from_dict(json.load(f))

    def _assert_ylim(self, plt, expected):
        """Compare the current y-limits to ``expected`` up to rounding."""
        actual = plt.ylim()
        self.assertEqual(len(actual), len(expected), "wrong ylim")
        for got, want in zip(actual, expected):
            self.assertAlmostEqual(got, want, msg="wrong ylim")

    def setUp(self):
        self.bs = self._load_band_structure("CaO_2605_bandstructure.json")
        self.plotter = BSPlotter(self.bs)
        self.assertEqual(len(self.plotter._bs), 1,
                         "wrong number of band objects")

        self.sbs_sc = self._load_band_structure("N2_12103_bandstructure.json")
        self.sbs_met = self._load_band_structure("C_48_bandstructure.json")

        self.plotter_multi = BSPlotter([self.sbs_sc, self.sbs_met])
        self.assertEqual(len(self.plotter_multi._bs), 2,
                         "wrong number of band objects")
        self.assertEqual(self.plotter_multi._nb_bands, [96, 96],
                         "wrong number of bands")

        warnings.simplefilter("ignore")

    def tearDown(self):
        warnings.simplefilter("default")

    def test_add_bs(self):
        self.plotter_multi.add_bs(self.sbs_sc)
        self.assertEqual(len(self.plotter_multi._bs), 3,
                         "wrong number of band objects")
        self.assertEqual(self.plotter_multi._nb_bands, [96, 96, 96],
                         "wrong number of bands")

    def test_get_branch_steps(self):
        steps_idx = BSPlotter._get_branch_steps(self.sbs_sc.branches)
        self.assertEqual(steps_idx, [0, 121, 132, 143],
                         "wrong list of steps idx")

    def test_rescale_distances(self):
        rescaled_distances = self.plotter_multi._rescale_distances(
            self.sbs_sc, self.sbs_met)
        self.assertEqual(
            len(rescaled_distances),
            len(self.sbs_met.distance),
            "wrong length of distances list",
        )
        # Computed floats: compare up to rounding instead of bit-for-bit.
        self.assertAlmostEqual(rescaled_distances[-1], 6.5191398067252875,
                               msg="wrong last distance value")
        self.assertAlmostEqual(
            rescaled_distances[148],
            self.sbs_sc.distance[19],
            msg="wrong distance at high symm k-point",
        )

    def test_interpolate_bands(self):
        data = self.plotter.bs_plot_data()
        d = data["distances"]
        en = data["energy"]["1"]
        int_distances, int_energies = self.plotter._interpolate_bands(d, en)

        self.assertEqual(len(int_distances), 10,
                         "wrong length of distances list")
        self.assertEqual(len(int_distances[0]), 100,
                         "wrong length of distances in a branch")
        self.assertEqual(len(int_energies), 10,
                         "wrong length of energies list")
        self.assertEqual(int_energies[0].shape, (16, 100),
                         "wrong shape of energies in a branch")

    def test_bs_plot_data(self):
        data = self.plotter.bs_plot_data()
        self.assertEqual(
            len(data["distances"]),
            10,
            "wrong number of sequences of branches",
        )
        self.assertEqual(
            len(data["distances"][0]),
            16,
            "wrong number of distances in the first sequence of branches",
        )
        self.assertEqual(
            sum(len(e) for e in data["distances"]),
            160,
            "wrong number of distances",
        )

        merged = self.plotter.bs_plot_data(split_branches=False)
        self.assertEqual(
            len(merged["distances"][0]),
            144,
            "wrong number of distances in the first sequence of branches",
        )
        self.assertEqual(
            len(merged["distances"]),
            2,
            "wrong number of sequences of branches",
        )

        self.assertEqual(data["ticks"]["label"][5], "K", "wrong tick label")
        self.assertEqual(
            len(data["ticks"]["label"]),
            19,
            "wrong number of tick labels",
        )

    def test_get_ticks(self):
        ticks = self.plotter.get_ticks()
        self.assertEqual(ticks["label"][5], "K", "wrong tick label")
        self.assertAlmostEqual(
            ticks["distance"][5],
            2.406607625322699,
            msg="wrong tick distance",
        )

    # Minimal baseline testing for get_plot. not a true test. Just checks that
    # it can actually execute.
    def test_get_plot(self):
        # zero_to_efermi = True, ylim = None, smooth = False,
        # vbm_cbm_marker = False, smooth_tol = None

        # Disabling latex is needed for this test to work.
        from matplotlib import rc

        rc("text", usetex=False)

        plt = self.plotter.get_plot()
        self._assert_ylim(plt, (-4.0, 7.6348))
        plt = self.plotter.get_plot(smooth=True)
        plt = self.plotter.get_plot(vbm_cbm_marker=True)
        self.plotter.save_plot("bsplot.png")
        self.assertTrue(os.path.isfile("bsplot.png"))
        os.remove("bsplot.png")
        plt.close("all")

        # test plotter with 2 bandstructures
        plt = self.plotter_multi.get_plot()
        self.assertEqual(len(plt.gca().get_lines()), 874,
                         "wrong number of lines")
        self._assert_ylim(plt, (-10.0, 10.0))
        plt = self.plotter_multi.get_plot(zero_to_efermi=False)
        self._assert_ylim(plt, (-15.2379, 12.67141266))
        plt = self.plotter_multi.get_plot(smooth=True)
        self.plotter_multi.save_plot("bsplot.png")
        self.assertTrue(os.path.isfile("bsplot.png"))
        os.remove("bsplot.png")
        plt.close("all")
def make_el_band_plot(ax, bands, yticklabels=True, **kargs):
    """
    Draw a band structure whose lines are coloured by the contributions
    returned by compute_contrib_el.

    Args:
        ax: an Axes object
        bands: band structure object
        yticklabels (bool): when false, the y tick labels are removed
        elements: required keyword, forwarded to compute_contrib_el
        linewidth (int): line width of the spin-up bands (default 2); the
            spin-down bands are drawn at half that width
        name: accepted but not used

    Raises:
        KeyError: if the 'elements' keyword is not given
    """
    # options, with their default values
    linewidth = kargs.get("linewidth", 2)
    if "elements" not in kargs:
        raise KeyError("argument 'elements' in make_el_band_plot is missing")
    elements = kargs["elements"]

    known_options = ["name", "linewidth", "elements"]
    for option in kargs:
        if option not in known_options:
            print("WARNING: option {0} not considered".format(option))

    # band structure plot data
    bsplot = BSPlotter(bands)
    plotdata = bsplot.bs_plot_data(zero_to_efermi=True)

    # a spin polarized calculation is drawn once per spin channel
    all_spins = [Spin.up]
    if bands.is_spin_polarized:
        all_spins.append(Spin.down)

    for spin in all_spins:
        # spin-down bands are fainter and thinner than spin-up ones
        if spin == Spin.up:
            alpha, lw = 1, linewidth
        if spin == Spin.down:
            alpha, lw = .7, linewidth / 2

        # normalized contributions for this spin channel
        contrib = compute_contrib_el(bands, spin, elements)
        spin_key = str(spin)

        # plot bands, one branch at a time
        ikpts = 0
        maxd = -1
        mind = 1e10
        for branch_dist, branch_ene in zip(plotdata["distances"],
                                           plotdata["energy"]):
            npts = len(branch_dist)
            maxd = max(max(branch_dist), maxd)
            mind = min(min(branch_dist), mind)
            # k-points of this branch inside the contribution array
            window = slice(ikpts, ikpts + npts)
            for iband in range(bands.nb_bands):
                rgbline(ax, branch_dist, branch_ene[spin_key][iband],
                        contrib[iband, window, 0],
                        contrib[iband, window, 1],
                        contrib[iband, window, 2],
                        alpha, lw)
            ikpts += npts

    # add ticks and vlines
    make_ticks(ax, bsplot)

    ax.set_xlabel("k-points")
    ax.set_xlim(mind, maxd)
    ax.grid(False)

    if not yticklabels:
        ax.set_yticklabels([])