def get_kpt_labels(self, bs):
    """Return the high-symmetry k-point labels of ``bs`` and their positions.

    The ticks reported by ``BSPlotter(bs).get_ticks()`` are scanned in order
    and runs of identical consecutive labels are collapsed to their first
    occurrence (presumably the shared end/start point of adjacent k-path
    branches). The surviving labels are then rewritten for display without
    LaTeX: the LaTeX mid-bar markup becomes ``|`` and the LaTeX Gamma markup
    becomes the Unicode capital Gamma (U+0393).

    Args:
        bs: band structure object accepted by ``BSPlotter``.

    Returns:
        tuple[list, list]: the de-duplicated labels and the matching tick
        ``"distance"`` values. Both lists are empty when there are no ticks.
    """
    ticks = BSPlotter(bs).get_ticks()

    labels_uniq = []
    labelspos_uniq = []
    for label, position in zip(ticks["label"], ticks["distance"]):
        # The last kept label always equals the previous raw label, so this
        # drops exactly the consecutive repeats.
        if not labels_uniq or label != labels_uniq[-1]:
            labels_uniq.append(label)
            labelspos_uniq.append(position)

    # Hack for dash, which can't display LaTeX. Raw strings keep the
    # backslashes literal without relying on invalid-escape fallback.
    labels_uniq = [
        label.replace(r"$\mid$", "|").replace(r"$\Gamma$", "\u0393")
        for label in labels_uniq
    ]
    return labels_uniq, labelspos_uniq
# Plotting session with pymatgen: total/integrated DOS, band structure,
# Brillouin zone and a combined band-structure + DOS figure.
# NOTE(review): tdos, idos, cdos and bs are not defined in this snippet and
# must already exist when it runs -- confirm where they are created.
dosplotter = DosPlotter()
# NOTE(review): the return values of add_dos / add_dos_dict are kept under the
# original names; pymatgen's implementations presumably return None -- confirm.
Totaldos = dosplotter.add_dos('Total DOS', tdos)
Integrateddos = dosplotter.add_dos('Integrated DOS', idos)
# Other DOS flavours that can be added the same way:
#Pdos = dosplotter.add_dos('Partial DOS',pdos)
#Spd_dos = dosplotter.add_dos('spd DOS',spd_dos)
#Element_dos = dosplotter.add_dos('Element DOS',element_dos)
#Element_spd_dos = dosplotter.add_dos('Element_spd DOS',element_spd_dos)
# NOTE(review): this registers the same two labels again; entries are
# presumably replaced by label rather than duplicated -- confirm, then keep
# only one of the two registration styles.
dos_dict = {
    'Total DOS': tdos,
    'Integrated DOS': idos,
    # 'Partial DOS': pdos, 'spd DOS': spd_dos, 'Element DOS': element_dos,
    # 'Element_spd DOS': element_spd_dos,
}
add_dos_dict = dosplotter.add_dos_dict(dos_dict)
get_dos_dict = dosplotter.get_dos_dict()
dos_plot = dosplotter.get_plot()
# dosplotter.save_plot("MAPbI3_dos", img_format="png")
# dos_plot.show()

bsplotter = BSPlotter(bs)
bs_plot_data = bsplotter.bs_plot_data()
bs_plot = bsplotter.get_plot()
# bsplotter.save_plot("MAPbI3_bs", img_format="png")
# bsplotter.show()
ticks = bsplotter.get_ticks()
# Function-call form: the Python 2 print statement does not parse on Python 3.
print(ticks)
bsplotter.plot_brillouin()

bsdos = BSDOSPlotter(
    tick_fontsize=10,
    egrid_interval=20,
    dos_projection="orbitals",
    bs_legend=None,
    # bs_projection="HPbCIN", dos_projection="HPbCIN",
)
bds = bsdos.get_plot(bs, cdos)
def projected_band_structure():
    """Write atom-projected band-structure data for a VASP calculation.

    Reads ``vasprun.xml``, ``PROCAR`` and ``KPOINTS`` from the working
    directory (each checked with ``check_file``). Selects atoms with
    ``atom_selection`` and sums their PROCAR orbital weights. Writes:

    * ``PBAND_Up.dat`` and ``PBAND_Down.dat`` for a spin-polarized run,
      otherwise ``PBAND.dat``. Columns are K-distance, E - E_fermi and one
      weight column per PROCAR orbital. Rows are grouped band by band, and
      each energy is paired with the weights of the same spin channel.
    * ``HighSymmetricPoints.dat``: index, label and position of every tick
      reported by ``BSPlotter``.

    Progress messages are emitted through ``procs``. If no atoms are
    selected, the function prints a message and returns None early.
    """
    step_count = 1
    filename = 'vasprun.xml'
    check_file(filename)
    proc_str = "Reading Data From " + filename + " File ..."
    procs(proc_str, step_count, sp='-->>')
    vsr = Vasprun(filename)

    filename = 'PROCAR'
    check_file(filename)
    step_count += 1
    proc_str = "Reading Data From " + filename + " File ..."
    procs(proc_str, step_count, sp='-->>')
    procar = Procar(filename)
    nbands = procar.nbands
    norbitals = len(procar.orbitals)
    nkpoints = procar.nkpoints

    step_count += 1
    filename = 'KPOINTS'
    check_file(filename)
    proc_str = "Reading Data From " + filename + " File ..."
    procs(proc_str, step_count, sp='-->>')
    bands = vsr.get_band_structure(filename, line_mode=True, efermi=vsr.efermi)
    struct = vsr.final_structure
    (atom_index, in_str) = atom_selection(struct)
    if len(atom_index) == 0:
        print("No atoms selected!")
        return

    # One output file per spin channel present in the calculation.
    if vsr.is_spin:
        procs("This Is a Spin-polarized Calculation.", 0, sp='-->>')
        outputs = [("PBAND_Up.dat", Spin.up), ("PBAND_Down.dat", Spin.down)]
    else:
        if vsr.parameters.get('LNONCOLLINEAR', False):
            procs("This Is a Non-Collinear Calculation.", 0, sp='-->>')
        else:
            procs("This Is a Non-Spin Calculation.", 0, sp='-->>')
        outputs = [("PBAND.dat", Spin.up)]

    # The header is identical for every spin channel, so build it once:
    # selection string, 1-based selected atom indices, then column names.
    col_fmt = "#%(key1)+12s%(key2)+12s"
    col_names = {'key1': 'K-Distance', 'key2': 'Energy(ev)'}
    for i, orbital in enumerate(procar.orbitals):
        key = "key" + str(i + 3)
        col_fmt += "%(" + key + ")+12s"
        col_names[key] = orbital
    selected_atoms = ' '.join(str(x + 1) for x in atom_index)
    head_line = ("#String: " + in_str + '\n#Selected atom: ' + selected_atoms
                 + '\n' + col_fmt % col_names)

    for filename, spin in outputs:
        # Sum the selected atoms' weights: shape (nkpoints, nbands, norbitals).
        contrib = np.zeros((nkpoints, nbands, norbitals))
        for i in atom_index:
            contrib += procar.data[spin][:, :, i, :]

        step_count += 1
        proc_str = "Writting Projected Band Structure Data to " + filename + " File ..."
        procs(proc_str, step_count, sp='-->>')

        # Energies of the *same* spin channel, flattened band by band
        # (band_data is [band, k-point]) and shifted so E_fermi sits at 0.
        band_data = bands.bands[spin]
        y_data = band_data.reshape(nbands * nkpoints) - vsr.efermi
        x_data = np.array(bands.distance * nbands)
        # contrib is k-point-major; swap to band-major so that every row's
        # weights belong to the same (band, k-point) as its energy.
        proj_band = np.swapaxes(contrib, 0, 1).reshape(nbands * nkpoints, norbitals)
        data = np.vstack((x_data, y_data, proj_band.T)).T
        # NOTE(review): last argument presumably is the row-block size
        # (one block per band) -- confirm against write_col_data.
        write_col_data(filename, data, head_line, nkpoints)

    step_count += 1
    bsp = BSPlotter(bands)
    filename = "HighSymmetricPoints.dat"
    proc_str = "Writting Label infomation to " + filename + " File ..."
    procs(proc_str, step_count, sp='-->>')
    ticks = bsp.get_ticks()
    rows = ["#%(key1)+12s%(key2)+12s%(key3)+12s" % {'key1': 'index', 'key2': 'label', 'key3': 'position'} + '\n']
    for i, (label, position) in enumerate(zip(ticks['label'], ticks['distance'])):
        rows.append("%(key1)12d%(key2)+12s%(key3)12f\n" % {'key1': i, 'key2': label, 'key3': position})
    write_col_data(filename, ''.join(rows), '', str_data=True)
def band_structure():
    """Report the gap/magnetic character of a VASP run and write band data.

    Reads ``vasprun.xml``, ``KPOINTS`` and ``OUTCAR`` from the working
    directory (each checked with ``check_file``), then:

    * Spin-polarized: classifies the material from the spin-resolved total-DOS
      gaps (compared against the module-level ``min_gap``) and reports the
      total magnetization. It always reports spin-up vbm/cbm/gap, and also
      reports spin-down values when ``|magnetization| > min_mag``.
      ``BAND.dat`` holds K-distance plus up and down energies.
    * Otherwise: reports metal/semiconductor from the band structure, plus
      vbm/cbm/gap for a semiconductor. ``BAND.dat`` holds K-distance and
      energy.

    Energies are shifted so that E_fermi is 0. Also writes
    ``HighSymmetricPoints.dat`` and tries to save ``BAND.png``. A plotting
    failure is reported but does not abort the function.
    """
    check_matplotlib()
    step_count = 1
    filename = 'vasprun.xml'
    check_file(filename)
    proc_str = "Reading Data From " + filename + " File ..."
    procs(proc_str, step_count, sp='-->>')
    vsr = Vasprun(filename)

    step_count += 1
    filename = 'KPOINTS'
    check_file(filename)
    proc_str = "Reading Data From " + filename + " File ..."
    procs(proc_str, step_count, sp='-->>')
    bands = vsr.get_band_structure(filename, line_mode=True, efermi=vsr.efermi)

    step_count += 1
    filename = 'OUTCAR'
    check_file(filename)
    proc_str = "Reading Data From " + filename + " File ..."
    procs(proc_str, step_count, sp='-->>')
    outcar = Outcar('OUTCAR')
    mag = outcar.as_dict()['total_magnetization']

    if vsr.is_spin:
        procs("This Is a Spin-polarized Calculation.", 0, sp='-->>')
        tdos = vsr.tdos
        up_gap = tdos.get_gap(spin=Spin.up)
        cbm_vbm_up = tdos.get_cbm_vbm(spin=Spin.up)
        down_gap = tdos.get_gap(spin=Spin.down)
        cbm_vbm_down = tdos.get_cbm_vbm(spin=Spin.down)

        # A channel counts as gapped only when its gap exceeds min_gap. A gap
        # exactly equal to min_gap is treated as closed, so every case is
        # classified.
        up_open = up_gap > min_gap
        down_open = down_gap > min_gap
        if not (up_open or down_open):
            procs("This Material Is a Metal.", 0, sp='-->>')
        elif up_open != down_open:
            # NOTE(review): one gapped and one gapless channel is usually
            # called half-metallic; the message text is kept as-is.
            procs("This Material Is a Semimetal.", 0, sp='-->>')
        else:
            procs("This Material Is a Semiconductor.", 0, sp='-->>')

        procs("Total magnetization is " + str(mag), 0, sp='-->>')
        # get_cbm_vbm returns (cbm, vbm), hence the [1], [0] order below.
        proc_str = "SpinUp  : vbm=%f eV cbm=%f eV gap=%f eV" % (cbm_vbm_up[1], cbm_vbm_up[0], up_gap)
        procs(proc_str, 0, sp='-->>')
        # The magnetization can be negative (down-spin majority); only its
        # size decides whether the down channel is worth reporting.
        if abs(mag) > min_mag:
            proc_str = "SpinDown: vbm=%f eV cbm=%f eV gap=%f eV" % (cbm_vbm_down[1], cbm_vbm_down[0], down_gap)
            procs(proc_str, 0, sp='-->>')

        step_count += 1
        filename = "BAND.dat"
        proc_str = "Writting Band Structure Data to " + filename + " File ..."
        procs(proc_str, step_count, sp='-->>')
        band_data_up = bands.bands[Spin.up]
        band_data_down = bands.bands[Spin.down]
        nbands, nkpts = band_data_up.shape[0], band_data_up.shape[1]
        # Flatten band by band and shift the Fermi level to 0.
        y_data_up = band_data_up.reshape(-1) - vsr.efermi
        y_data_down = band_data_down.reshape(-1) - vsr.efermi
        x_data = np.array(bands.distance * nbands)
        data = np.vstack((x_data, y_data_up, y_data_down)).T
        head_line = "#%(key1)+12s%(key2)+13s%(key3)+15s" % {'key1': 'K-Distance', 'key2': 'UpEnergy(ev)', 'key3': 'DownEnergy(ev)'}
        write_col_data(filename, data, head_line, nkpts)
    else:
        if vsr.parameters.get('LNONCOLLINEAR', False):
            proc_str = "This Is a Non-Collinear Calculation."
        else:
            proc_str = "This Is a Non-Spin Calculation."
        procs(proc_str, 0, sp='-->>')
        if not bands.is_metal():
            cbm = bands.get_cbm()['energy']
            vbm = bands.get_vbm()['energy']
            gap = bands.get_band_gap()['energy']
            procs("This Material Is a Semiconductor.", 0, sp='-->>')
            procs("vbm=%f eV cbm=%f eV gap=%f eV" % (vbm, cbm, gap), 0, sp='-->>')
        else:
            procs("This Material Is a Metal.", 0, sp='-->>')

        step_count += 1
        filename = "BAND.dat"
        proc_str = "Writting Band Structure Data to " + filename + " File ..."
        procs(proc_str, step_count, sp='-->>')
        band_data = bands.bands[Spin.up]
        nbands, nkpts = band_data.shape[0], band_data.shape[1]
        y_data = band_data.reshape(-1) - vsr.efermi  # shift fermi level to 0
        x_data = np.array(bands.distance * nbands)
        data = np.vstack((x_data, y_data)).T
        head_line = "#%(key1)+12s%(key2)+13s" % {'key1': 'K-Distance', 'key2': 'Energy(ev)'}
        write_col_data(filename, data, head_line, nkpts)

    step_count += 1
    bsp = BSPlotter(bands)
    filename = "HighSymmetricPoints.dat"
    proc_str = "Writting Label infomation to " + filename + " File ..."
    procs(proc_str, step_count, sp='-->>')
    ticks = bsp.get_ticks()
    rows = ["#%(key1)+12s%(key2)+12s%(key3)+12s" % {'key1': 'index', 'key2': 'label', 'key3': 'position'} + '\n']
    for i, (label, position) in enumerate(zip(ticks['label'], ticks['distance'])):
        rows.append("%(key1)12d%(key2)+12s%(key3)12f\n" % {'key1': i, 'key2': label, 'key3': position})
    rows.append('\n')
    write_col_data(filename, ''.join(rows), '', str_data=True)

    step_count += 1
    filename = "BAND.png"
    proc_str = "Saving Plot to " + filename + " File ..."
    procs(proc_str, step_count, sp='-->>')
    try:
        bsp.save_plot(filename, img_format="png")
    except Exception as err:
        # The figure is a best-effort extra: report the reason and carry on,
        # without swallowing KeyboardInterrupt/SystemExit.
        print("Figure output fails !!!")
        print(err)
class BSPlotterTest(unittest.TestCase):
    """Regression tests for BSPlotter against stored band-structure fixtures.

    Fixtures are JSON dicts of BandStructureSymmLine objects read from
    ``test_dir``:
      * CaO_2605_bandstructure.json  -> ``self.bs`` / ``self.plotter``
      * N2_12103_bandstructure.json  -> ``self.sbs_sc``
      * C_48_bandstructure.json      -> ``self.sbs_met``
    ``self.plotter_multi`` plots ``sbs_sc`` and ``sbs_met`` together.
    """

    @staticmethod
    def _load_bs(name):
        """Load the BandStructureSymmLine stored in JSON fixture ``name``."""
        with open(os.path.join(test_dir, name), "r", encoding="utf-8") as f:
            return BandStructureSymmLine.from_dict(json.load(f))

    def _assert_ylim(self, plt, expected):
        """Assert ``plt.ylim()`` equals ``expected`` to 7 decimal places."""
        low, high = plt.ylim()
        self.assertAlmostEqual(low, expected[0], msg="wrong ylim")
        self.assertAlmostEqual(high, expected[1], msg="wrong ylim")

    def setUp(self):
        self.bs = self._load_bs("CaO_2605_bandstructure.json")
        self.plotter = BSPlotter(self.bs)
        self.assertEqual(len(self.plotter._bs), 1, "wrong number of band objects")

        self.sbs_sc = self._load_bs("N2_12103_bandstructure.json")
        self.sbs_met = self._load_bs("C_48_bandstructure.json")
        self.plotter_multi = BSPlotter([self.sbs_sc, self.sbs_met])
        self.assertEqual(len(self.plotter_multi._bs), 2, "wrong number of band objects")
        self.assertEqual(self.plotter_multi._nb_bands, [96, 96], "wrong number of bands")
        warnings.simplefilter("ignore")

    def tearDown(self):
        warnings.simplefilter("default")

    def test_add_bs(self):
        self.plotter_multi.add_bs(self.sbs_sc)
        self.assertEqual(len(self.plotter_multi._bs), 3, "wrong number of band objects")
        self.assertEqual(self.plotter_multi._nb_bands, [96, 96, 96], "wrong number of bands")

    def test_get_branch_steps(self):
        steps_idx = BSPlotter._get_branch_steps(self.sbs_sc.branches)
        self.assertEqual(steps_idx, [0, 121, 132, 143], "wrong list of steps idx")

    def test_rescale_distances(self):
        rescaled_distances = self.plotter_multi._rescale_distances(self.sbs_sc, self.sbs_met)
        self.assertEqual(
            len(rescaled_distances),
            len(self.sbs_met.distance),
            "wrong length of distances list",
        )
        # Computed floats: compare approximately, not bit-for-bit.
        self.assertAlmostEqual(rescaled_distances[-1], 6.5191398067252875, msg="wrong last distance value")
        self.assertAlmostEqual(
            rescaled_distances[148],
            self.sbs_sc.distance[19],
            msg="wrong distance at high symm k-point",
        )

    def test_interpolate_bands(self):
        data = self.plotter.bs_plot_data()
        d = data["distances"]
        en = data["energy"]["1"]
        int_distances, int_energies = self.plotter._interpolate_bands(d, en)

        self.assertEqual(len(int_distances), 10, "wrong length of distances list")
        self.assertEqual(len(int_distances[0]), 100, "wrong length of distances in a branch")
        self.assertEqual(len(int_energies), 10, "wrong length of energies list")
        self.assertEqual(int_energies[0].shape, (16, 100), "wrong shape of energies in a branch")

    def test_bs_plot_data(self):
        split = self.plotter.bs_plot_data()
        self.assertEqual(len(split["distances"]), 10, "wrong number of sequences of branches")
        self.assertEqual(
            len(split["distances"][0]),
            16,
            "wrong number of distances in the first sequence of branches",
        )
        self.assertEqual(
            sum(len(e) for e in split["distances"]),
            160,
            "wrong number of distances",
        )

        unsplit = self.plotter.bs_plot_data(split_branches=False)
        self.assertEqual(
            len(unsplit["distances"][0]),
            144,
            "wrong number of distances in the first sequence of branches",
        )
        self.assertEqual(len(unsplit["distances"]), 2, "wrong number of sequences of branches")

        self.assertEqual(split["ticks"]["label"][5], "K", "wrong tick label")
        self.assertEqual(len(split["ticks"]["label"]), 19, "wrong number of tick labels")

    def test_get_ticks(self):
        ticks = self.plotter.get_ticks()
        self.assertEqual(ticks["label"][5], "K", "wrong tick label")
        self.assertAlmostEqual(ticks["distance"][5], 2.406607625322699, msg="wrong tick distance")

    # Minimal baseline testing for get_plot. not a true test. Just checks that
    # it can actually execute.
    def test_get_plot(self):
        # zero_to_efermi = True, ylim = None, smooth = False,
        # vbm_cbm_marker = False, smooth_tol = None
        import tempfile

        from matplotlib import rc_context

        # Disabling latex is needed for this test to work. rc_context restores
        # the global rcParams afterwards, and the temporary directory removes
        # the saved figures even when an assertion fails.
        with rc_context({"text.usetex": False}), tempfile.TemporaryDirectory() as tmpdir:
            single_png = os.path.join(tmpdir, "bsplot.png")
            multi_png = os.path.join(tmpdir, "bsplot_multi.png")

            plt = self.plotter.get_plot()
            self._assert_ylim(plt, (-4.0, 7.6348))
            plt = self.plotter.get_plot(smooth=True)
            plt = self.plotter.get_plot(vbm_cbm_marker=True)
            self.plotter.save_plot(single_png)
            self.assertTrue(os.path.isfile(single_png))
            plt.close("all")

            # test plotter with 2 bandstructures
            plt = self.plotter_multi.get_plot()
            self.assertEqual(len(plt.gca().get_lines()), 874, "wrong number of lines")
            self._assert_ylim(plt, (-10.0, 10.0))
            plt = self.plotter_multi.get_plot(zero_to_efermi=False)
            self._assert_ylim(plt, (-15.2379, 12.67141266))
            plt = self.plotter_multi.get_plot(smooth=True)
            # Distinct file name so this check cannot pass on the first save.
            self.plotter_multi.save_plot(multi_png)
            self.assertTrue(os.path.isfile(multi_png))
            plt.close("all")