class BztInterpolatorTest(unittest.TestCase):
    """Smoke tests for BztInterpolator built from ``VasprunLoader().from_file``."""

    def setUp(self):
        vasprun_loader = VasprunLoader().from_file(vrunfile)
        self.bztInterp = BztInterpolator(vasprun_loader)
        self.assertIsNotNone(self.bztInterp)
        warnings.simplefilter("ignore")

    def tearDown(self):
        warnings.resetwarnings()

    def test_properties(self):
        interp = self.bztInterp
        expected_shapes = (
            ("cband", (5, 3, 3, 3, 148877)),
            ("eband", (5, 148877)),
            ("coeffs", (5, 1429)),
        )
        for attr_name, shape in expected_shapes:
            self.assertTupleEqual(getattr(interp, attr_name).shape, shape)
        self.assertEqual(interp.nemax, 12)

    def test_get_band_structure(self):
        band_struct = self.bztInterp.get_band_structure()
        self.assertIsNotNone(band_struct)
        bands_per_spin = list(band_struct.bands.values())
        self.assertTupleEqual(bands_per_spin[0].shape, (5, 137))

    def test_tot_dos(self):
        dos = self.bztInterp.get_dos(T=200)
        self.assertIsNotNone(dos)
        n_energies = len(dos.energies)
        self.assertEqual(n_energies, 10000)

    def test_tot_proj_dos(self):
        proj_dos = self.bztInterp.get_dos(partial_dos=True, T=200)
        self.assertIsNotNone(proj_dos)
        spd_dos = proj_dos.get_spd_dos()
        self.assertEqual(len(spd_dos.values()), 3)
class BztInterpolatorTest(unittest.TestCase):
    """BztInterpolator checks on a VasprunLoader interpolated with lpfac=2."""

    def setUp(self):
        self.loader = VasprunLoader(vrun)
        projections = self.loader.proj
        self.assertTupleEqual(projections.shape, (120, 20, 2, 9))
        self.bztInterp = BztInterpolator(self.loader, lpfac=2)
        self.assertIsNotNone(self.bztInterp)
        warnings.simplefilter("ignore")

    def tearDown(self):
        warnings.resetwarnings()

    def test_properties(self):
        interp = self.bztInterp
        expected_shapes = (
            ("cband", (5, 3, 3, 3, 29791)),
            ("eband", (5, 29791)),
            ("coeffs", (5, 322)),
        )
        for attr_name, shape in expected_shapes:
            self.assertTupleEqual(getattr(interp, attr_name).shape, shape)
        self.assertEqual(interp.nemax, 12)

    def test_get_band_structure(self):
        band_struct = self.bztInterp.get_band_structure()
        self.assertIsNotNone(band_struct)
        spin_up_bands = band_struct.bands[Spin.up]
        self.assertTupleEqual(spin_up_bands.shape, (5, 137))

    def test_tot_dos(self):
        dos = self.bztInterp.get_dos(T=200, npts_mu=100)
        self.assertIsNotNone(dos)
        self.assertEqual(len(dos.energies), 100)
        first_density = dos.densities[Spin.up][0]
        self.assertAlmostEqual(first_density, 1.42859939, places=5)

    def test_tot_proj_dos(self):
        proj_dos = self.bztInterp.get_dos(partial_dos=True, T=200, npts_mu=100)
        self.assertIsNotNone(proj_dos)
        self.assertEqual(len(proj_dos.get_spd_dos().values()), 3)
        s_orbital = proj_dos.get_spd_dos()[OrbitalType.s]
        self.assertAlmostEqual(s_orbital.densities[Spin.up][0], 15.474392020,
                               places=5)
class BztInterpolatorTest(unittest.TestCase):
    """BztInterpolator checks, plus a spin-polarised N2 up/down DOS merge."""

    def setUp(self):
        self.loader = VasprunLoader(vrun)
        self.assertTupleEqual(self.loader.proj.shape, (120, 20, 2, 9))
        self.bztInterp = BztInterpolator(self.loader, lpfac=2)
        self.assertIsNotNone(self.bztInterp)
        warnings.simplefilter("ignore")

        n2_bands = loadfn(os.path.join(test_dir, "N2_bandstructure.json"))
        loader_up, loader_dn = (
            BandstructureLoader(n2_bands, vrun_sp.structures[-1], spin=channel)
            for channel in (1, -1)
        )
        # Give both spin channels the same energy window.
        lowest = min(loader_up.ebands.min(), loader_dn.ebands.min())
        highest = max(loader_up.ebands.max(), loader_dn.ebands.max())
        for spin_loader in (loader_up, loader_dn):
            spin_loader.set_upper_lower_bands(lowest, highest)
        self.bztI_up, self.bztI_dn = (
            BztInterpolator(spin_loader, lpfac=2, energy_range=np.inf,
                            curvature=False)
            for spin_loader in (loader_up, loader_dn)
        )

    def tearDown(self):
        warnings.simplefilter("default")

    def test_properties(self):
        interp = self.bztInterp
        expected_shapes = (
            ("cband", (5, 3, 3, 3, 29791)),
            ("eband", (5, 29791)),
            ("coeffs", (5, 322)),
        )
        for attr_name, shape in expected_shapes:
            self.assertTupleEqual(getattr(interp, attr_name).shape, shape)
        self.assertEqual(interp.nemax, 12)

    def test_get_band_structure(self):
        band_struct = self.bztInterp.get_band_structure()
        self.assertIsNotNone(band_struct)
        self.assertTupleEqual(band_struct.bands[Spin.up].shape, (5, 137))

    def test_tot_dos(self):
        dos = self.bztInterp.get_dos(T=200, npts_mu=100)
        self.assertIsNotNone(dos)
        self.assertEqual(len(dos.energies), 100)
        self.assertAlmostEqual(dos.densities[Spin.up][0], 1.42859939, places=5)

        up_dos = self.bztI_up.get_dos(partial_dos=False, npts_mu=100)
        down_dos = self.bztI_dn.get_dos(partial_dos=False, npts_mu=100)
        merged = merge_up_down_doses(up_dos, down_dos)
        self.assertAlmostEqual(merged.densities[Spin.down][50], 92.87836778,
                               places=5)
        self.assertAlmostEqual(merged.densities[Spin.up][45], 9.564067,
                               places=5)
        self.assertEqual(len(merged.energies), 100)

    def test_tot_proj_dos(self):
        proj_dos = self.bztInterp.get_dos(partial_dos=True, T=200, npts_mu=100)
        self.assertIsNotNone(proj_dos)
        self.assertEqual(len(proj_dos.get_spd_dos().values()), 3)
        s_orbital = proj_dos.get_spd_dos()[OrbitalType.s]
        self.assertAlmostEqual(s_orbital.densities[Spin.up][0], 15.474392020,
                               places=5)
class BztInterpolatorTest(unittest.TestCase):
    """BztInterpolator checks with a merged spin-polarised N2 DOS."""

    def setUp(self):
        self.loader = VasprunLoader(vrun)
        proj_shape = self.loader.proj.shape
        self.assertTupleEqual(proj_shape, (120, 20, 2, 9))
        self.bztInterp = BztInterpolator(self.loader, lpfac=2)
        self.assertIsNotNone(self.bztInterp)
        warnings.simplefilter("ignore")

        bs_path = os.path.join(test_dir, "N2_bandstructure.json")
        spin_bs = loadfn(bs_path)
        up_loader, dn_loader = (
            BandstructureLoader(spin_bs, vrun_sp.structures[-1], spin=sign)
            for sign in (1, -1)
        )
        # Use a common band window for both spin channels.
        e_low = min(up_loader.ebands.min(), dn_loader.ebands.min())
        e_high = max(up_loader.ebands.max(), dn_loader.ebands.max())
        for channel_loader in (up_loader, dn_loader):
            channel_loader.set_upper_lower_bands(e_low, e_high)
        self.bztI_up, self.bztI_dn = (
            BztInterpolator(channel_loader, lpfac=2, energy_range=np.inf,
                            curvature=False)
            for channel_loader in (up_loader, dn_loader)
        )

    def tearDown(self):
        warnings.resetwarnings()

    def test_properties(self):
        interp = self.bztInterp
        for attr_name, shape in (("cband", (5, 3, 3, 3, 29791)),
                                 ("eband", (5, 29791)),
                                 ("coeffs", (5, 322))):
            self.assertTupleEqual(getattr(interp, attr_name).shape, shape)
        self.assertEqual(interp.nemax, 12)

    def test_get_band_structure(self):
        band_struct = self.bztInterp.get_band_structure()
        self.assertIsNotNone(band_struct)
        up_shape = band_struct.bands[Spin.up].shape
        self.assertTupleEqual(up_shape, (5, 137))

    def test_tot_dos(self):
        dos = self.bztInterp.get_dos(T=200, npts_mu=100)
        self.assertIsNotNone(dos)
        self.assertEqual(len(dos.energies), 100)
        self.assertAlmostEqual(dos.densities[Spin.up][0], 1.42859939, places=5)

        dos_by_channel = [
            self.bztI_up.get_dos(partial_dos=False, npts_mu=100),
            self.bztI_dn.get_dos(partial_dos=False, npts_mu=100),
        ]
        combined = merge_up_down_doses(*dos_by_channel)
        self.assertAlmostEqual(combined.densities[Spin.down][50], 92.87836778,
                               places=5)
        self.assertAlmostEqual(combined.densities[Spin.up][45], 9.564067,
                               places=5)
        self.assertEqual(len(combined.energies), 100)

    def test_tot_proj_dos(self):
        proj_dos = self.bztInterp.get_dos(partial_dos=True, T=200, npts_mu=100)
        self.assertIsNotNone(proj_dos)
        self.assertEqual(len(proj_dos.get_spd_dos().values()), 3)
        s_density = proj_dos.get_spd_dos()[OrbitalType.s].densities[Spin.up][0]
        self.assertAlmostEqual(s_density, 15.474392020, places=5)
class BztInterpolatorTest(unittest.TestCase):
    """Interpolation, band-structure and DOS checks for BztInterpolator."""

    def setUp(self):
        self.loader = VasprunLoader(vrun)
        self.assertTupleEqual(self.loader.proj.shape, (120, 20, 2, 9))
        interp = BztInterpolator(self.loader, lpfac=2)
        self.bztInterp = interp
        self.assertIsNotNone(self.bztInterp)
        warnings.simplefilter("ignore")

    def tearDown(self):
        warnings.resetwarnings()

    def test_properties(self):
        expected = {
            "cband": (5, 3, 3, 3, 29791),
            "eband": (5, 29791),
            "coeffs": (5, 322),
        }
        for attr_name, shape in expected.items():
            actual = getattr(self.bztInterp, attr_name).shape
            self.assertTupleEqual(actual, shape)
        self.assertEqual(self.bztInterp.nemax, 12)

    def test_get_band_structure(self):
        result = self.bztInterp.get_band_structure()
        self.assertIsNotNone(result)
        self.assertTupleEqual(result.bands[Spin.up].shape, (5, 137))

    def test_tot_dos(self):
        total = self.bztInterp.get_dos(T=200, npts_mu=100)
        self.assertIsNotNone(total)
        self.assertEqual(len(total.energies), 100)
        self.assertAlmostEqual(total.densities[Spin.up][0], 1.42859939,
                               places=5)

    def test_tot_proj_dos(self):
        projected = self.bztInterp.get_dos(partial_dos=True, T=200,
                                           npts_mu=100)
        self.assertIsNotNone(projected)
        orbital_groups = projected.get_spd_dos()
        self.assertEqual(len(orbital_groups.values()), 3)
        s_dos = projected.get_spd_dos()[OrbitalType.s]
        self.assertAlmostEqual(s_dos.densities[Spin.up][0], 15.474392020,
                               places=5)
class BztInterpolatorTest(unittest.TestCase):
    """BztInterpolator tests on VasprunBSLoader data, including a save/load
    round trip through ``bztinterp_fn`` and a spin-polarised calculation."""

    def setUp(self):
        self.loader = VasprunBSLoader(vrun)
        # Fit, fit-and-save, then load-from-file; each must yield an object.
        for options in ({"lpfac": 2},
                        {"lpfac": 2, "save_bztInterp": True,
                         "fname": bztinterp_fn},
                        {"load_bztInterp": True, "fname": bztinterp_fn}):
            self.bztInterp = BztInterpolator(self.loader, **options)
            self.assertIsNotNone(self.bztInterp)
        warnings.simplefilter("ignore")

        self.loader_sp = VasprunBSLoader(vrun_sp)
        for options in ({"lpfac": 2},
                        {"lpfac": 2, "save_bztInterp": True,
                         "fname": bztinterp_fn},
                        {"lpfac": 2, "load_bztInterp": True,
                         "fname": bztinterp_fn}):
            self.bztInterp_sp = BztInterpolator(self.loader_sp, **options)
            self.assertIsNotNone(self.bztInterp_sp)
        warnings.simplefilter("ignore")

    def tearDown(self):
        warnings.simplefilter("default")

    def test_properties(self):
        cases = (
            (self.bztInterp, 6, 29791, 322, 20.0, 120),
            (self.bztInterp_sp, 10, 23275, 519, 10.0, 198),
        )
        for interp, n_bands, n_interp, n_coeffs, nelect_all, n_raw in cases:
            self.assertTupleEqual(interp.cband.shape,
                                  (n_bands, 3, 3, 3, n_interp))
            self.assertTupleEqual(interp.eband.shape, (n_bands, n_interp))
            self.assertTupleEqual(interp.coeffs.shape, (n_bands, n_coeffs))
            self.assertEqual(interp.data.nelect, 6.0)
            self.assertEqual(interp.data.nelect_all, nelect_all)
            self.assertTupleEqual(interp.data.ebands.shape, (n_bands, n_raw))

    def test_get_band_structure(self):
        default_bs = self.bztInterp.get_band_structure()
        self.assertIsNotNone(default_bs)
        self.assertTupleEqual(default_bs.bands[Spin.up].shape, (6, 137))

        kpaths = [["L", "K"]]
        kp_lbl = {"L": np.array([0.5, 0.5, 0.5]),
                  "K": np.array([0.375, 0.375, 0.75])}
        custom_bs = self.bztInterp.get_band_structure(kpaths, kp_lbl)
        self.assertIsNotNone(custom_bs)
        self.assertTupleEqual(custom_bs.bands[Spin.up].shape, (6, 20))

        spin_bs = self.bztInterp_sp.get_band_structure()
        self.assertIsNotNone(spin_bs)
        self.assertTupleEqual(spin_bs.bands[Spin.up].shape, (6, 143))
        self.assertTupleEqual(spin_bs.bands[Spin.down].shape, (4, 143))

    def test_tot_dos(self):
        dos = self.bztInterp.get_dos(T=200, npts_mu=100)
        self.assertIsNotNone(dos)
        self.assertEqual(len(dos.energies), 100)
        self.assertAlmostEqual(dos.densities[Spin.up][0], 1.35371715, places=5)

        dos_sp = self.bztInterp_sp.get_dos(T=200, npts_mu=100)
        self.assertIsNotNone(dos_sp)
        self.assertEqual(len(dos_sp.energies), 100)
        self.assertAlmostEqual(dos_sp.densities[Spin.up][75], 88.034456,
                               places=5)
        self.assertAlmostEqual(dos_sp.densities[Spin.down][75], 41.421367,
                               places=5)

    def test_tot_proj_dos(self):
        proj = self.bztInterp.get_dos(partial_dos=True, T=200, npts_mu=100)
        self.assertIsNotNone(proj)
        self.assertEqual(len(proj.get_spd_dos().values()), 3)
        s_up = proj.get_spd_dos()[OrbitalType.s].densities[Spin.up][75]
        self.assertAlmostEqual(s_up, 2490.169396, places=5)

        proj_sp = self.bztInterp_sp.get_dos(partial_dos=True, T=200,
                                            npts_mu=100)
        self.assertIsNotNone(proj_sp)
        self.assertEqual(len(proj_sp.get_spd_dos().values()), 3)
        s_up_sp = proj_sp.get_spd_dos()[OrbitalType.s].densities[Spin.up][75]
        self.assertAlmostEqual(s_up_sp, 166.4933305, places=5)
        s_dn_sp = proj_sp.get_spd_dos()[OrbitalType.s].densities[Spin.down][75]
        self.assertAlmostEqual(s_dn_sp, 272.194174, places=5)
def _boltz_flag(value):
    """Return the boolean meaning of a BoltzTraP option value.

    Options are conventionally strings such as ``"T"`` or ``""``.  Plain
    ``bool()`` treats every non-empty string -- including ``"F"`` and
    ``"False"`` -- as true, so the usual "false" spellings are recognised
    explicitly.  Any other non-empty string is still true, and non-string
    values fall back to ``bool()``.
    """
    if isinstance(value, str):
        return value.strip().lower() not in {
            '', 'f', 'false', '0', 'no', 'n', 'none'}
    return bool(value)


def _load_kpath(path='./kpath'):
    r"""Read a custom k-path from the JSON file *path*.

    The file is expected to hold ``{"path": [[label, ...], ...],
    "kpoints_rel": {label: coords, ...}}``.  ``"GAMMA"`` labels are renamed
    to the LaTeX label ``\Gamma`` everywhere (in every path segment entry and
    in the coordinate dict).

    Returns ``(kpaths, kpoints_lbls_dict)``, or ``(None, None)`` when the file
    is missing or malformed, so the interpolator uses its default path.
    """
    gamma = r'\Gamma'

    def rename(label):
        return gamma if label == 'GAMMA' else label

    try:
        with open(path, 'r', encoding='utf-8') as handle:
            kpath = json.load(handle)
        kpaths = [[rename(label) for label in segment]
                  for segment in kpath['path']]
        kpoints_lbls_dict = {rename(label): coords
                             for label, coords in kpath['kpoints_rel'].items()}
    except (OSError, ValueError, KeyError, TypeError, AttributeError):
        # Deliberate best effort: any problem with the file means "fall back
        # to the interpolator's default k-path".
        return None, None
    return kpaths, kpoints_lbls_dict


def _fill_prediction(vasprun, energies):
    """Overwrite, in place, the eigen-energies of *vasprun* with *energies*.

    *energies* is indexed ``[kpoint, band]``.  Only the band range common to
    both arrays is replaced.  Only the energy channel (index 0 of the last
    axis of pymatgen's ``(nkpts, nbands, 2)`` eigenvalue arrays) is written,
    so the parsed occupations are kept.
    """
    for eigen in vasprun.eigenvalues.values():
        n_common = min(eigen.shape[1], energies.shape[1])
        eigen[:, :n_common, 0] = energies[:, :n_common]


def _high_symmetry_labels(band_structure):
    """Return ``[kpoint_index, label]`` for each start of a labelled run.

    A k-point is reported when its label is one of ``labels_dict``'s labels
    and differs from the previous k-point's label.  The first k-point is
    never compared against the last one, so closed paths keep their first
    label.
    """
    known = {kpoint.label for kpoint in band_structure.labels_dict.values()}
    labels = []
    previous = None
    for index, kpoint in enumerate(band_structure.kpoints):
        label = kpoint.label
        if label is not None and label in known and label != previous:
            labels.append([index, label])
        previous = label
    return labels


def bandplot_func(
        filenames=None, code='vasp', prefix=None, directory=None,
        vbm_cbm_marker=False, projection_selection=None, mode='rgb',
        pred=None, interpolate_factor=4, circle_size=150, dos_file=None,
        cart_coords=False, scissor=None, ylabel='Energy (eV)', dos_label=None,
        elements=None, lm_orbitals=None, atoms=None, spin=None,
        total_only=False, plot_total=True, legend_cutoff=3, gaussian=None,
        height=None, width=None, ymin=-6., ymax=6., colours=None, yscale=1,
        style=None, no_base_style=False, image_format='pdf', dpi=400,
        plt=None, fonts=None, boltz=None, nelec=0):
    """Interpolate DFT and model-predicted band structures with BoltzTraP2.

    For every vasprun file, the DFT eigenvalues and a copy overwritten with
    *pred* are interpolated with ``BztInterpolator``.  Both are interpolated
    along the same k-path: ``./kpath`` when that file is present and valid,
    otherwise the interpolator's default.  Only the last file's result is
    returned.

    Args:
        filenames: A vasprun path or list of paths.  When empty,
            ``find_vasprun_files()`` is used.
        code: Only ``'vasp'`` is supported.
        projection_selection: Truthy to parse projected eigenvalues.
        pred: Model-predicted energies indexed ``[kpoint, band]``.  When it is
            ``None`` or all zeros, the "model" band structure is the DFT one.
        boltz: BoltzTraP options dict.  Missing keys take the defaults
            ``ifinter='T', lpfac='10', energy_range='50', curvature='',
            load='T', ismetaltolerance='0.01'``.  Flags accept strings such
            as ``'T'``/``'F'``.  ``load`` only applies to the DFT
            interpolation cache.  ``ismetaltolerance`` is accepted but no
            longer used.
        nelec: Electron count used to locate the valence-band maximum.  A
            value giving fewer than one valence band (e.g. the default ``0``)
            skips the VBM correction of the Fermi level.

    The remaining parameters mirror sumo's ``bandplot`` and are accepted for
    compatibility only.  This function returns arrays and does not plot.

    Returns:
        tuple: ``(dft_bands, model_bands, labels)``.  ``dft_bands`` is the
        first-spin band array of the DFT structure, shifted by its Fermi
        level.  ``model_bands`` is the matching model array.  ``labels`` is
        a list of ``[kpoint_index, label]`` pairs (see
        ``_high_symmetry_labels``).

    Raises:
        ValueError: If *code* is not ``'vasp'``, if interpolation is
            disabled, or if no vasprun files are found.
    """
    import logging
    log = logging.getLogger(__name__)

    # Copy the defaults so callers' dicts (and the defaults) are never shared
    # between calls.  Partial dicts are completed from the defaults.
    opts = {
        'ifinter': 'T', 'lpfac': '10', 'energy_range': '50',
        'curvature': '', 'load': 'T', 'ismetaltolerance': '0.01',
    }
    if boltz:
        opts.update(boltz)

    if code != 'vasp':
        raise ValueError(
            "bandplot_func only supports code='vasp', got %r" % (code,))
    if not _boltz_flag(opts['ifinter']):
        # Without interpolation no band structure is ever built, which
        # previously surfaced as an unbound-variable crash.
        raise ValueError("bandplot_func requires BoltzTraP2 interpolation; "
                         "set boltz['ifinter'] to a true value")

    if not filenames:
        filenames = find_vasprun_files()
    elif isinstance(filenames, str):
        filenames = [filenames]
    if not filenames:
        raise ValueError("no vasprun files to process")

    parse_projected = bool(projection_selection)
    pred_energies = (np.asarray(pred)
                     if pred is not None and np.any(pred) else None)

    interp_opts = {
        'lpfac': int(opts['lpfac']),
        'energy_range': float(opts['energy_range']),
        'curvature': _boltz_flag(opts['curvature']),
    }
    load_cache = _boltz_flag(opts['load'])
    kpaths, kpoints_lbls_dict = _load_kpath()
    log.debug("k-path %s, labels %s", kpaths, kpoints_lbls_dict)

    for vr_file in filenames:
        vr = BSVasprun(vr_file, parse_projected_eigen=parse_projected)
        if pred_energies is None:
            model = vr
        else:
            # A second parse gives an independent object to overwrite.
            model = BSVasprun(vr_file, parse_projected_eigen=parse_projected)
            _fill_prediction(model, pred_energies)

        # Only the DFT fit is cached.  Both interpolators used to share the
        # default cache file, so the model one loaded (or overwrote) the DFT
        # coefficients.  The prediction may change between calls, so the
        # model is always re-fitted.
        b_inter = BztInterpolator(VasprunBSLoader(vr), save_bztInterp=True,
                                  load_bztInterp=load_cache, **interp_opts)
        model_inter = BztInterpolator(VasprunBSLoader(model), **interp_opts)
        bs = b_inter.get_band_structure(
            kpaths=kpaths, kpoints_lbls_dict=kpoints_lbls_dict)
        model_bs = model_inter.get_band_structure(
            kpaths=kpaths, kpoints_lbls_dict=kpoints_lbls_dict)

        gap = bs.get_band_gap()
        print("WHC interpolated gap: %s" % gap)
        # NOTE(review): for a non-spin-polarised run this gives nvb = nelec,
        # although each band holds two electrons.  Kept as-is; confirm how
        # callers count nelec.
        nvb = int(np.ceil(nelec / (int(bs.is_spin_polarized) + 1)))
        if nvb >= 1:
            # Highest energy of band nvb-1 over all spin channels.
            vbm = max(float(np.max(bands[nvb - 1]))
                      for bands in bs.bands.values())
            print('WHC WARNNING vasp fermi %s interpolation vbm %s nelec %s '
                  'nvb %s' % (bs.efermi, vbm, nelec, nvb))
            # Lower the Fermi level to the VBM when it lies above it.
            if vbm < bs.efermi:
                bs.efermi = vbm
            if vbm < model_bs.efermi:
                model_bs.efermi = vbm

        # Reference the first spin channel of the DFT bands to its Fermi
        # level; the model bands are left as predicted.  NOTE(review):
        # presumably the predictions are already Fermi-referenced -- confirm.
        spin_key = next(iter(bs.bands))
        bs.bands[spin_key] = bs.bands[spin_key] - bs.efermi
        log.debug("band shapes: dft %s, model %s", bs.bands[spin_key].shape,
                  model_bs.bands[spin_key].shape)

    bs = get_reconstructed_band_structure([bs])
    model_bs = get_reconstructed_band_structure([model_bs])
    labels = _high_symmetry_labels(bs)
    log.debug("efermi: dft %s, model %s; labels %s",
              bs.efermi, model_bs.efermi, labels)
    return bs.bands[spin_key], model_bs.bands[spin_key], labels