class BSPlotterTest(unittest.TestCase):
    """Tests for ``BSPlotter`` driven by band structures stored as JSON fixtures.

    ``setUp`` builds one plotter around a single band structure
    (``self.plotter``) and one around two band structures
    (``self.plotter_multi``); the individual tests check the data the plotters
    derive from them and that plotting/saving executes.
    """

    @staticmethod
    def _load_band_structure(filename):
        """Return a ``BandStructureSymmLine`` built from a JSON file in ``test_dir``."""
        with open(os.path.join(test_dir, filename), "r", encoding="utf-8") as f:
            return BandStructureSymmLine.from_dict(json.load(f))

    def _assert_ylim(self, plt, expected):
        """Check the y-limits reported by ``plt`` against ``expected``.

        The limits are floats, so each bound is compared with
        ``assertAlmostEqual`` instead of exact tuple equality.
        """
        lower, upper = plt.ylim()
        self.assertAlmostEqual(lower, expected[0], msg="wrong ylim")
        self.assertAlmostEqual(upper, expected[1], msg="wrong ylim")

    def _assert_plot_saved(self, plotter, filename="bsplot.png"):
        """Save a plot with ``plotter`` and check that ``filename`` was written.

        The file is removed afterwards even if saving or the assertion fails,
        so a failing run does not leave artifacts in the working directory.
        """
        try:
            plotter.save_plot(filename)
            self.assertTrue(os.path.isfile(filename))
        finally:
            if os.path.isfile(filename):
                os.remove(filename)

    def setUp(self):
        """Load the fixtures and build the single- and multi-structure plotters."""
        self.bs = self._load_band_structure("CaO_2605_bandstructure.json")
        self.plotter = BSPlotter(self.bs)
        self.assertEqual(len(self.plotter._bs), 1, "wrong number of band objects")

        self.sbs_sc = self._load_band_structure("N2_12103_bandstructure.json")
        self.sbs_met = self._load_band_structure("C_48_bandstructure.json")
        self.plotter_multi = BSPlotter([self.sbs_sc, self.sbs_met])
        self.assertEqual(
            len(self.plotter_multi._bs), 2, "wrong number of band objects"
        )
        self.assertEqual(
            self.plotter_multi._nb_bands, [96, 96], "wrong number of bands"
        )

        # Silence warnings for the test body only. The previous filter state is
        # restored on cleanup instead of being forced to "default", so the
        # warning configuration chosen by the test runner is left untouched.
        warning_scope = warnings.catch_warnings()
        warning_scope.__enter__()
        self.addCleanup(warning_scope.__exit__, None, None, None)
        warnings.simplefilter("ignore")

    def test_add_bs(self):
        """Adding a band structure extends both ``_bs`` and ``_nb_bands``."""
        self.plotter_multi.add_bs(self.sbs_sc)
        self.assertEqual(
            len(self.plotter_multi._bs), 3, "wrong number of band objects"
        )
        self.assertEqual(
            self.plotter_multi._nb_bands, [96, 96, 96], "wrong number of bands"
        )

    def test_get_branch_steps(self):
        """``_get_branch_steps`` returns the expected step indices."""
        steps_idx = BSPlotter._get_branch_steps(self.sbs_sc.branches)
        self.assertEqual(steps_idx, [0, 121, 132, 143], "wrong list of steps idx")

    def test_rescale_distances(self):
        """Rescaled distances keep their length and match the reference values."""
        rescaled_distances = self.plotter_multi._rescale_distances(
            self.sbs_sc, self.sbs_met
        )
        self.assertEqual(
            len(rescaled_distances),
            len(self.sbs_met.distance),
            "wrong length of distances list",
        )
        # Distances are computed floats: compare with a tolerance.
        self.assertAlmostEqual(
            rescaled_distances[-1],
            6.5191398067252875,
            msg="wrong last distance value",
        )
        self.assertAlmostEqual(
            rescaled_distances[148],
            self.sbs_sc.distance[19],
            msg="wrong distance at high symm k-point",
        )

    def test_interpolate_bands(self):
        """``_interpolate_bands`` returns per-branch distances and energies."""
        data = self.plotter.bs_plot_data()
        distances = data["distances"]
        energies = data["energy"]["1"]
        int_distances, int_energies = self.plotter._interpolate_bands(
            distances, energies
        )

        self.assertEqual(len(int_distances), 10, "wrong length of distances list")
        self.assertEqual(
            len(int_distances[0]), 100, "wrong length of distances in a branch"
        )
        self.assertEqual(len(int_energies), 10, "wrong length of energies list")
        self.assertEqual(
            int_energies[0].shape, (16, 100), "wrong shape of energies in a branch"
        )

    def test_bs_plot_data(self):
        """``bs_plot_data`` yields the expected distances and tick labels."""
        # Compute each variant once instead of once per assertion.
        split_data = self.plotter.bs_plot_data()
        merged_data = self.plotter.bs_plot_data(split_branches=False)

        split_distances = split_data["distances"]
        self.assertEqual(
            len(split_distances), 10, "wrong number of sequences of branches"
        )
        self.assertEqual(
            len(split_distances[0]),
            16,
            "wrong number of distances in the first sequence of branches",
        )
        self.assertEqual(
            sum(len(branch) for branch in split_distances),
            160,
            "wrong number of distances",
        )

        merged_distances = merged_data["distances"]
        self.assertEqual(
            len(merged_distances[0]),
            144,
            "wrong number of distances in the first sequence of branches",
        )
        self.assertEqual(
            len(merged_distances), 2, "wrong number of sequences of branches"
        )

        tick_labels = split_data["ticks"]["label"]
        self.assertEqual(tick_labels[5], "K", "wrong tick label")
        self.assertEqual(len(tick_labels), 19, "wrong number of tick labels")

    def test_get_ticks(self):
        """``get_ticks`` reports the expected label and distance for tick 5."""
        ticks = self.plotter.get_ticks()
        self.assertEqual(ticks["label"][5], "K", "wrong tick label")
        # Tick positions are computed floats: compare with a tolerance.
        self.assertAlmostEqual(
            ticks["distance"][5], 2.406607625322699, msg="wrong tick distance"
        )

    # Minimal baseline testing for get_plot. not a true test. Just checks that
    # it can actually execute.
    def test_get_plot(self):
        """Smoke-test ``get_plot`` and ``save_plot`` for both plotters."""
        # Keyword values noted by the original author (presumably the
        # get_plot defaults -- confirm against BSPlotter.get_plot):
        # zero_to_efermi = True, ylim = None, smooth = False,
        # vbm_cbm_marker = False, smooth_tol = None

        # Disabling latex is needed for this test to work.
        from matplotlib import rc

        rc("text", usetex=False)

        plt = self.plotter.get_plot()
        # Make sure figures are closed even when an assertion below fails.
        self.addCleanup(plt.close, "all")
        self._assert_ylim(plt, (-4.0, 7.6348))
        plt = self.plotter.get_plot(smooth=True)
        plt = self.plotter.get_plot(vbm_cbm_marker=True)
        self._assert_plot_saved(self.plotter)
        plt.close("all")

        # test plotter with 2 bandstructures
        plt = self.plotter_multi.get_plot()
        self.assertEqual(len(plt.gca().get_lines()), 874, "wrong number of lines")
        self._assert_ylim(plt, (-10.0, 10.0))
        plt = self.plotter_multi.get_plot(zero_to_efermi=False)
        self._assert_ylim(plt, (-15.2379, 12.67141266))
        plt = self.plotter_multi.get_plot(smooth=True)
        self._assert_plot_saved(self.plotter_multi)
        plt.close("all")