Esempio n. 1
0
    def get_kpt_labels(self, bs):
        """Return de-duplicated k-point tick labels and their positions.

        Consecutive ticks that carry the same label are collapsed into a
        single entry (the first of the run is kept). LaTeX fragments that the
        front end cannot render are replaced with plain-text equivalents.

        Args:
            bs: band structure object passed straight to ``BSPlotter``.

        Returns:
            A ``(labels_uniq, labelspos_uniq)`` tuple of two equally long
            lists: the cleaned labels and the matching tick distances. Both
            lists are empty when the plotter reports no ticks.
        """
        # Compute the ticks once instead of once per field.
        ticks = BSPlotter(bs).get_ticks()
        labels = ticks["label"]
        labelspos = ticks["distance"]

        # Nothing to de-duplicate; avoid indexing into an empty list.
        if not labels:
            return [], []

        ## get unique K-points
        labels_uniq = [labels[0]]
        labelspos_uniq = [labelspos[0]]
        for previous, current, position in zip(labels, labels[1:], labelspos[1:]):
            if current != previous:
                labels_uniq.append(current)
                labelspos_uniq.append(position)

        # Raw strings: "\m" and "\G" are invalid escape sequences in a normal
        # string literal and trigger a warning on modern Python.
        labels_uniq = [label.replace(r"$\mid$", "|") for label in labels_uniq]
        ## hack for dash which can't display latex :(
        labels_uniq = [label.replace(r"$\Gamma$", u"\u0393") for label in labels_uniq]

        return labels_uniq, labelspos_uniq
Esempio n. 2
0
# --- Density of states -------------------------------------------------------
# Register the total and integrated DOS curves on a DosPlotter, first one by
# one and then again through a dictionary, and build the plot.
dosplotter = DosPlotter()
Totaldos = dosplotter.add_dos('Total DOS', tdos)
Integrateddos = dosplotter.add_dos('Integrated DOS', idos)
#Pdos = dosplotter.add_dos('Partial DOS',pdos)
#Spd_dos =  dosplotter.add_dos('spd DOS',spd_dos)
#Element_dos = dosplotter.add_dos('Element DOS',element_dos)
#Element_spd_dos = dosplotter.add_dos('Element_spd DOS',element_spd_dos)
dos_dict = {
    'Total DOS': tdos,
    'Integrated DOS': idos
}  #'Partial DOS':pdos,'spd DOS':spd_dos,'Element DOS':element_dos}#'Element_spd DOS':element_spd_dos
add_dos_dict = dosplotter.add_dos_dict(dos_dict)
get_dos_dict = dosplotter.get_dos_dict()
dos_plot = dosplotter.get_plot()
##dosplotter.save_plot("MAPbI3_dos",img_format="png")
##dos_plot.show()

# --- Band structure ----------------------------------------------------------
# Build the band structure plot, print its tick information and draw the
# Brillouin zone.
bsplotter = BSPlotter(bs)
bs_plot_data = bsplotter.bs_plot_data()
bs_plot = bsplotter.get_plot()
#bsplotter.save_plot("MAPbI3_bs",img_format="png")
#bsplotter.show()
ticks = bsplotter.get_ticks()
# Function-call form: the bare `print ticks` statement is a SyntaxError on
# Python 3, while print(ticks) prints the same text on Python 2 and 3.
print(ticks)
bsplotter.plot_brillouin()

# --- Combined band structure + DOS figure -----------------------------------
bsdos = BSDOSPlotter(
    tick_fontsize=10,
    egrid_interval=20,
    dos_projection="orbitals",
    bs_legend=None)  #bs_projection="HPbCIN",dos_projection="HPbCIN")
bds = bsdos.get_plot(bs, cdos)
Esempio n. 3
0
def _write_projected_bands(filename, step_count, energies, distances, contrib,
                           orbitals, atom_index, in_str, efermi):
    """Write one projected-band data file: k-distance, energy, orbital weights.

    Args:
        filename: name of the output file.
        step_count: step number shown in the progress message.
        energies: eigenvalue array for one spin channel. Assumed to be laid
            out as (nbands, nkpoints), which is what repeating ``distances``
            once per band and the per-band block length require.
        distances: list with one k-path distance per k-point.
        contrib: orbital weights summed over the selected atoms, with shape
            (nkpoints, nbands, norbitals).
        orbitals: orbital names, used as column titles.
        atom_index: zero-based indices of the selected atoms.
        in_str: selection string echoed into the file header.
        efermi: Fermi energy subtracted from every eigenvalue.
    """
    nkpoints, nbands, norbitals = contrib.shape
    proc_str = "Writting Projected Band Structure Data to " + filename + " File ..."
    procs(proc_str, step_count, sp='-->>')

    # Rows are band-major: all k-points of band 0, then band 1, and so on.
    y_data = energies.reshape(1, nbands * nkpoints)[0] - efermi  # shift fermi level to 0
    x_data = np.array(distances * nbands)
    # contrib is k-point-major, so swap the first two axes before flattening;
    # otherwise each weight row would be paired with the wrong (k, band) energy.
    proj_band = contrib.transpose(1, 0, 2).reshape(nbands * nkpoints, norbitals)
    data = np.vstack((x_data, y_data, proj_band.T)).T

    # Header: selection info followed by right-aligned 12-character titles.
    column_titles = ['K-Distance', 'Energy(ev)'] + list(orbitals)
    atom_index_str = [str(x + 1) for x in atom_index]  # report 1-based atom numbers
    head_line1 = "#String: " + in_str + '\n#Selected atom: ' + ' '.join(atom_index_str) + '\n'
    head_line2 = "#" + "".join("%+12s" % (title,) for title in column_titles)
    write_col_data(filename, data, head_line1 + head_line2, nkpoints)


def projected_band_structure():
    """Export the atom-projected band structure of a VASP run to text files.

    Reads ``vasprun.xml``, ``PROCAR`` and ``KPOINTS`` from the current
    directory, asks ``atom_selection`` which atoms to project on, and sums
    the PROCAR weights of those atoms per orbital. Spin-polarized runs
    produce ``PBAND_Up.dat`` and ``PBAND_Down.dat``; all other runs produce
    ``PBAND.dat``. The tick labels and positions of the k-path are written
    to ``HighSymmetricPoints.dat``.

    Returns:
        None. Returns early, after printing a message, when no atom is
        selected.
    """
    step_count = 1
    filename = 'vasprun.xml'
    check_file(filename)
    proc_str = "Reading Data From " + filename + " File ..."
    procs(proc_str, step_count, sp='-->>')
    vsr = Vasprun(filename)

    filename = 'PROCAR'
    check_file(filename)
    step_count += 1
    proc_str = "Reading Data From " + filename + " File ..."
    procs(proc_str, step_count, sp='-->>')
    procar = Procar(filename)
    nbands = procar.nbands
    norbitals = len(procar.orbitals)
    nkpoints = procar.nkpoints

    step_count += 1
    filename = 'KPOINTS'
    check_file(filename)
    proc_str = "Reading Data From " + filename + " File ..."
    procs(proc_str, step_count, sp='-->>')
    bands = vsr.get_band_structure(filename, line_mode=True, efermi=vsr.efermi)
    struct = vsr.final_structure
    (atom_index, in_str) = atom_selection(struct)

    if len(atom_index) == 0:
        print("No atoms selected!")
        return

    # Decide which spin channels to export and where each one goes.
    if vsr.is_spin:
        procs("This Is a Spin-polarized Calculation.", 0, sp='-->>')
        outputs = [(Spin.up, "PBAND_Up.dat"), (Spin.down, "PBAND_Down.dat")]
    else:
        if vsr.parameters['LNONCOLLINEAR']:
            procs("This Is a Non-Collinear Calculation.", 0, sp='-->>')
        else:
            procs("This Is a Non-Spin Calculation.", 0, sp='-->>')
        outputs = [(Spin.up, "PBAND.dat")]

    for spin, filename in outputs:
        # Sum the orbital weights of the selected atoms for this spin channel.
        contrib = np.zeros((nkpoints, nbands, norbitals))
        for i in atom_index:
            contrib = contrib + procar.data[spin][:, :, i, :]

        step_count += 1
        # Use the eigenvalues of the SAME spin channel as the weights; the
        # down-spin file previously reused the up-spin energies.
        _write_projected_bands(filename, step_count, bands.bands[spin],
                               bands.distance, contrib, procar.orbitals,
                               atom_index, in_str, vsr.efermi)

    step_count += 1
    bsp = BSPlotter(bands)
    filename = "HighSymmetricPoints.dat"
    proc_str = "Writting Label infomation to " + filename + " File ..."
    procs(proc_str, step_count, sp='-->>')
    head_line = "#%+12s%+12s%+12s" % ('index', 'label', 'position')
    ticks = bsp.get_ticks()  # computed once rather than once per label
    rows = ["%12d%+12s%12f\n" % (i, label, ticks['distance'][i])
            for i, label in enumerate(ticks['label'])]
    line = head_line + '\n' + "".join(rows)
    write_col_data(filename, line, '', str_data=True)
Esempio n. 4
0
def band_structure():
    """Export the band structure of a VASP run to text files and a PNG plot.

    Reads ``vasprun.xml``, ``KPOINTS`` and ``OUTCAR`` from the current
    directory, reports whether the material is a metal, semimetal or
    semiconductor (with VBM/CBM/gap values where applicable), and writes:

    * ``BAND.dat`` - k-distance versus energy, shifted so the Fermi level is
      at zero (two energy columns for spin-polarized runs);
    * ``HighSymmetricPoints.dat`` - tick labels and positions of the k-path;
    * ``BAND.png`` - the plot produced by ``BSPlotter``. A failure while
      saving the figure is reported but does not abort the function.

    The gap and magnetization thresholds ``min_gap`` and ``min_mag`` are read
    from module scope.
    """
    check_matplotlib()
    step_count = 1

    filename = 'vasprun.xml'
    check_file(filename)
    proc_str = "Reading Data From " + filename + " File ..."
    procs(proc_str, step_count, sp='-->>')
    vsr = Vasprun(filename)

    step_count += 1
    filename = 'KPOINTS'
    check_file(filename)
    proc_str = "Reading Data From " + filename + " File ..."
    procs(proc_str, step_count, sp='-->>')
    bands = vsr.get_band_structure(filename, line_mode=True, efermi=vsr.efermi)

    step_count += 1
    filename = 'OUTCAR'
    check_file(filename)
    proc_str = "Reading Data From " + filename + " File ..."
    procs(proc_str, step_count, sp='-->>')
    outcar = Outcar(filename)
    mag = outcar.as_dict()['total_magnetization']

    if vsr.is_spin:
        procs("This Is a Spin-polarized Calculation.", 0, sp='-->>')
        tdos = vsr.tdos
        SpinUp_gap = tdos.get_gap(spin=Spin.up)
        cbm_vbm_up = tdos.get_cbm_vbm(spin=Spin.up)
        SpinDown_gap = tdos.get_gap(spin=Spin.down)
        # The down channel must query Spin.down (it previously reused Spin.up).
        cbm_vbm_down = tdos.get_cbm_vbm(spin=Spin.down)

        # A channel is gapped only if its gap exceeds the threshold. Deriving
        # both flags from these two booleans covers every case, including a
        # gap exactly equal to min_gap, which used to leave the flags unset.
        up_gapped = SpinUp_gap > min_gap
        down_gapped = SpinDown_gap > min_gap
        is_metal = not up_gapped and not down_gapped
        is_semimetal = up_gapped != down_gapped

        # One exclusive chain: a metal must not also be reported as a
        # semiconductor.
        if is_metal:
            procs("This Material Is a Metal.", 0, sp='-->>')
        elif is_semimetal:
            procs("This Material Is a Semimetal.", 0, sp='-->>')
        else:
            procs("This Material Is a Semiconductor.", 0, sp='-->>')
            procs("Total magnetization is " + str(mag), 0, sp='-->>')
            # get_cbm_vbm returns (cbm, vbm); print the VBM first.
            proc_str = "SpinUp  : vbm=%f eV cbm=%f eV gap=%f eV" % (
                cbm_vbm_up[1], cbm_vbm_up[0], SpinUp_gap)
            procs(proc_str, 0, sp='-->>')
            if mag > min_mag:
                # Only a magnetized system gets a separate down-spin line,
                # reported with the down-spin gap.
                proc_str = "SpinDown: vbm=%f eV cbm=%f eV gap=%f eV" % (
                    cbm_vbm_down[1], cbm_vbm_down[0], SpinDown_gap)
                procs(proc_str, 0, sp='-->>')

        step_count += 1
        filename = "BAND.dat"
        proc_str = "Writting Band Structure Data to " + filename + " File ..."
        procs(proc_str, step_count, sp='-->>')
        band_data_up = bands.bands[Spin.up]
        band_data_down = bands.bands[Spin.down]
        nbands, nkpoints = band_data_up.shape[0], band_data_up.shape[1]
        y_data_up = band_data_up.reshape(1, nbands * nkpoints)[0] - vsr.efermi  # shift fermi level to 0
        y_data_down = band_data_down.reshape(
            1, band_data_down.shape[0] * band_data_down.shape[1])[0] - vsr.efermi  # shift fermi level to 0
        # Repeat the k-path once per band so each row pairs a distance with
        # the energy of one (band, k-point).
        x_data = np.array(bands.distance * nbands)
        data = np.vstack((x_data, y_data_up, y_data_down)).T
        head_line = "#%+12s%+13s%+15s" % ('K-Distance', 'UpEnergy(ev)', 'DownEnergy(ev)')
        write_col_data(filename, data, head_line, nkpoints)

    else:
        if vsr.parameters['LNONCOLLINEAR']:
            proc_str = "This Is a Non-Collinear Calculation."
        else:
            proc_str = "This Is a Non-Spin Calculation."
        procs(proc_str, 0, sp='-->>')
        cbm = bands.get_cbm()['energy']
        vbm = bands.get_vbm()['energy']
        gap = bands.get_band_gap()['energy']
        if not bands.is_metal():
            procs("This Material Is a Semiconductor.", 0, sp='-->>')
            proc_str = "vbm=%f eV cbm=%f eV gap=%f eV" % (vbm, cbm, gap)
            procs(proc_str, 0, sp='-->>')
        else:
            procs("This Material Is a Metal.", 0, sp='-->>')

        step_count += 1
        filename = "BAND.dat"
        proc_str = "Writting Band Structure Data to " + filename + " File ..."
        procs(proc_str, step_count, sp='-->>')
        band_data = bands.bands[Spin.up]
        nbands, nkpoints = band_data.shape[0], band_data.shape[1]
        y_data = band_data.reshape(1, nbands * nkpoints)[0] - vsr.efermi  # shift fermi level to 0
        x_data = np.array(bands.distance * nbands)
        data = np.vstack((x_data, y_data)).T
        head_line = "#%+12s%+13s" % ('K-Distance', 'Energy(ev)')
        write_col_data(filename, data, head_line, nkpoints)

    # The label file and the plotter are needed for every calculation type.
    # They used to live inside the non-spin branch only, so spin-polarized
    # runs had no label file and the figure step always failed on the
    # undefined plotter.
    step_count += 1
    bsp = BSPlotter(bands)
    filename = "HighSymmetricPoints.dat"
    proc_str = "Writting Label infomation to " + filename + " File ..."
    procs(proc_str, step_count, sp='-->>')
    head_line = "#%+12s%+12s%+12s" % ('index', 'label', 'position')
    ticks = bsp.get_ticks()  # computed once rather than once per label
    rows = ["%12d%+12s%12f\n" % (i, label, ticks['distance'][i])
            for i, label in enumerate(ticks['label'])]
    line = head_line + '\n' + "".join(rows) + '\n'
    write_col_data(filename, line, '', str_data=True)

    step_count += 1
    filename = "BAND.png"
    proc_str = "Saving Plot to " + filename + " File ..."
    procs(proc_str, step_count, sp='-->>')
    try:
        bsp.save_plot(filename, img_format="png")
    except Exception:
        # Saving the figure is best-effort; the data files are already
        # written. Catch Exception rather than everything so Ctrl-C and
        # interpreter exit still propagate.
        print("Figure output fails !!!")
Esempio n. 5
0
class BSPlotterTest(unittest.TestCase):
    """Tests for ``BSPlotter`` driven by band structure fixtures in ``test_dir``.

    ``self.plotter`` wraps the single CaO band structure; ``self.plotter_multi``
    wraps the N2 and C band structures together.
    """

    @staticmethod
    def _load_band_structure(filename):
        """Build a ``BandStructureSymmLine`` from a JSON fixture in ``test_dir``."""
        with open(os.path.join(test_dir, filename), "r", encoding="utf-8") as f:
            return BandStructureSymmLine.from_dict(json.load(f))

    def _assert_plot_saved(self, plotter, filename="bsplot.png"):
        """Check that ``plotter.save_plot`` creates ``filename``, then delete it.

        The file is removed even when the assertion fails so that a failing
        run does not leave an image behind in the working directory.
        """
        try:
            plotter.save_plot(filename)
            self.assertTrue(os.path.isfile(filename))
        finally:
            if os.path.isfile(filename):
                os.remove(filename)

    def setUp(self):
        self.bs = self._load_band_structure("CaO_2605_bandstructure.json")
        self.plotter = BSPlotter(self.bs)
        self.assertEqual(len(self.plotter._bs), 1,
                         "wrong number of band objects")

        self.sbs_sc = self._load_band_structure("N2_12103_bandstructure.json")
        self.sbs_met = self._load_band_structure("C_48_bandstructure.json")

        self.plotter_multi = BSPlotter([self.sbs_sc, self.sbs_met])
        self.assertEqual(len(self.plotter_multi._bs), 2,
                         "wrong number of band objects")
        self.assertEqual(self.plotter_multi._nb_bands, [96, 96],
                         "wrong number of bands")
        # Silence library warnings for the duration of each test.
        warnings.simplefilter("ignore")

    def tearDown(self):
        warnings.simplefilter("default")

    def test_add_bs(self):
        self.plotter_multi.add_bs(self.sbs_sc)
        self.assertEqual(len(self.plotter_multi._bs), 3,
                         "wrong number of band objects")
        self.assertEqual(self.plotter_multi._nb_bands, [96, 96, 96],
                         "wrong number of bands")

    def test_get_branch_steps(self):
        steps_idx = BSPlotter._get_branch_steps(self.sbs_sc.branches)
        self.assertEqual(steps_idx, [0, 121, 132, 143],
                         "wrong list of steps idx")

    def test_rescale_distances(self):
        rescaled_distances = self.plotter_multi._rescale_distances(
            self.sbs_sc, self.sbs_met)
        self.assertEqual(
            len(rescaled_distances),
            len(self.sbs_met.distance),
            "wrong lenght of distances list",
        )
        self.assertEqual(rescaled_distances[-1], 6.5191398067252875,
                         "wrong last distance value")
        self.assertEqual(
            rescaled_distances[148],
            self.sbs_sc.distance[19],
            "wrong distance at high symm k-point",
        )

    def test_interpolate_bands(self):
        data = self.plotter.bs_plot_data()
        d = data["distances"]
        en = data["energy"]["1"]
        int_distances, int_energies = self.plotter._interpolate_bands(d, en)

        self.assertEqual(len(int_distances), 10,
                         "wrong lenght of distances list")
        self.assertEqual(len(int_distances[0]), 100,
                         "wrong lenght of distances in a branch")
        # These two check the energies; the messages used to be copies of
        # the distance messages above.
        self.assertEqual(len(int_energies), 10,
                         "wrong length of energies list")
        self.assertEqual(int_energies[0].shape, (16, 100),
                         "wrong shape of energies in a branch")

    def test_bs_plot_data(self):
        # Compute each variant once and reuse it for all related assertions.
        split_data = self.plotter.bs_plot_data()
        split_distances = split_data["distances"]
        self.assertEqual(
            len(split_distances),
            10,
            "wrong number of sequences of branches",
        )
        self.assertEqual(
            len(split_distances[0]),
            16,
            "wrong number of distances in the first sequence of branches",
        )
        self.assertEqual(
            sum(len(e) for e in split_distances),
            160,
            "wrong number of distances",
        )

        merged_distances = self.plotter.bs_plot_data(
            split_branches=False)["distances"]
        self.assertEqual(
            len(merged_distances[0]), 144,
            "wrong number of distances in the first sequence of branches")
        # This counts the sequences themselves, not the points in the first.
        self.assertEqual(
            len(merged_distances), 2,
            "wrong number of sequences of branches")

        tick_labels = split_data["ticks"]["label"]
        self.assertEqual(tick_labels[5], "K", "wrong tick label")
        self.assertEqual(
            len(tick_labels),
            19,
            "wrong number of tick labels",
        )

    def test_get_ticks(self):
        ticks = self.plotter.get_ticks()
        self.assertEqual(ticks["label"][5], "K", "wrong tick label")
        self.assertEqual(
            ticks["distance"][5],
            2.406607625322699,
            "wrong tick distance",
        )

    # Minimal baseline testing for get_plot. not a true test. Just checks that
    # it can actually execute.
    def test_get_plot(self):
        # zero_to_efermi = True, ylim = None, smooth = False,
        # vbm_cbm_marker = False, smooth_tol = None

        # Disabling latex is needed for this test to work.
        from matplotlib import rc

        rc("text", usetex=False)

        plt = self.plotter.get_plot()
        # Make sure figures are released even if an assertion below fails.
        self.addCleanup(plt.close, "all")
        self.assertEqual(plt.ylim(), (-4.0, 7.6348), "wrong ylim")
        plt = self.plotter.get_plot(smooth=True)
        plt = self.plotter.get_plot(vbm_cbm_marker=True)
        self._assert_plot_saved(self.plotter)
        plt.close("all")

        # test plotter with 2 bandstructures
        plt = self.plotter_multi.get_plot()
        self.assertEqual(len(plt.gca().get_lines()), 874,
                         "wrong number of lines")
        self.assertEqual(plt.ylim(), (-10.0, 10.0), "wrong ylim")
        plt = self.plotter_multi.get_plot(zero_to_efermi=False)
        self.assertEqual(plt.ylim(), (-15.2379, 12.67141266), "wrong ylim")
        plt = self.plotter_multi.get_plot(smooth=True)
        self._assert_plot_saved(self.plotter_multi)
        plt.close("all")