class BztInterpolatorTest(unittest.TestCase):
    """Checks a BztInterpolator built from the ``vrunfile`` fixture."""

    def setUp(self):
        """Build the interpolator with warnings silenced for the whole test."""
        # Enter a catch_warnings context *before* building the interpolator so
        # warnings raised during construction are suppressed as well.  The
        # context is undone via addCleanup, which also runs when setUp fails
        # part-way, and it restores the previous filters instead of wiping
        # them globally the way warnings.resetwarnings() did.
        guard = warnings.catch_warnings()
        guard.__enter__()
        self.addCleanup(guard.__exit__, None, None, None)
        warnings.simplefilter("ignore")

        loader = VasprunLoader().from_file(vrunfile)
        self.bztInterp = BztInterpolator(loader)
        self.assertIsNotNone(self.bztInterp)

    def test_properties(self):
        """Interpolated arrays have the expected shapes."""
        self.assertTupleEqual(self.bztInterp.cband.shape, (5, 3, 3, 3, 148877))
        self.assertTupleEqual(self.bztInterp.eband.shape, (5, 148877))
        self.assertTupleEqual(self.bztInterp.coeffs.shape, (5, 1429))
        self.assertEqual(self.bztInterp.nemax, 12)

    def test_get_band_structure(self):
        """The default-path band structure has 5 bands on 137 k-points."""
        sbs = self.bztInterp.get_band_structure()
        self.assertIsNotNone(sbs)
        first_spin_bands = next(iter(sbs.bands.values()))
        self.assertTupleEqual(first_spin_bands.shape, (5, 137))

    def test_tot_dos(self):
        """The total DOS uses the default energy grid of 10000 points."""
        tot_dos = self.bztInterp.get_dos(T=200)
        self.assertIsNotNone(tot_dos)
        self.assertEqual(len(tot_dos.energies), 10000)

    def test_tot_proj_dos(self):
        """The projected DOS resolves into s, p and d contributions."""
        tot_proj_dos = self.bztInterp.get_dos(partial_dos=True, T=200)
        self.assertIsNotNone(tot_proj_dos)
        self.assertEqual(len(tot_proj_dos.get_spd_dos().values()), 3)
class BztInterpolatorTest(unittest.TestCase):
    """Checks a BztInterpolator (lpfac=2) built from the ``vrun`` fixture."""

    def setUp(self):
        """Build loader and interpolator with warnings silenced for the test."""
        # Silence warnings from the start (including construction below) and
        # restore the previous filter state afterwards; addCleanup runs even
        # if setUp fails, unlike tearDown + warnings.resetwarnings().
        guard = warnings.catch_warnings()
        guard.__enter__()
        self.addCleanup(guard.__exit__, None, None, None)
        warnings.simplefilter("ignore")

        self.loader = VasprunLoader(vrun)
        self.assertTupleEqual(self.loader.proj.shape, (120, 20, 2, 9))
        self.bztInterp = BztInterpolator(self.loader, lpfac=2)
        self.assertIsNotNone(self.bztInterp)

    def test_properties(self):
        """Interpolated arrays have the expected shapes."""
        self.assertTupleEqual(self.bztInterp.cband.shape, (5, 3, 3, 3, 29791))
        self.assertTupleEqual(self.bztInterp.eband.shape, (5, 29791))
        self.assertTupleEqual(self.bztInterp.coeffs.shape, (5, 322))
        self.assertEqual(self.bztInterp.nemax, 12)

    def test_get_band_structure(self):
        """The default-path band structure has 5 spin-up bands on 137 k-points."""
        sbs = self.bztInterp.get_band_structure()
        self.assertIsNotNone(sbs)
        self.assertTupleEqual(sbs.bands[Spin.up].shape, (5, 137))

    def test_tot_dos(self):
        """Total DOS on a 100-point grid matches the reference value."""
        tot_dos = self.bztInterp.get_dos(T=200, npts_mu=100)
        self.assertIsNotNone(tot_dos)
        self.assertEqual(len(tot_dos.energies), 100)
        self.assertAlmostEqual(tot_dos.densities[Spin.up][0], 1.42859939, places=5)

    def test_tot_proj_dos(self):
        """Projected DOS has s/p/d parts and the expected s-orbital value."""
        tot_proj_dos = self.bztInterp.get_dos(partial_dos=True, T=200, npts_mu=100)
        self.assertIsNotNone(tot_proj_dos)
        spd_dos = tot_proj_dos.get_spd_dos()
        self.assertEqual(len(spd_dos.values()), 3)
        pdos = spd_dos[OrbitalType.s].densities[Spin.up][0]
        self.assertAlmostEqual(pdos, 15.474392020, places=5)
Example #3
0
class BztInterpolatorTest(unittest.TestCase):
    """Checks BztInterpolator on the ``vrun`` fixture and on the N2 band
    structure interpolated separately for spin up and spin down."""

    def setUp(self):
        """Build the unpolarised and per-spin interpolators, warnings silenced."""
        # Silence warnings for the whole test, including the interpolations
        # below, and restore the previous filter state on cleanup (instead of
        # leaving a global "default" filter behind as before).
        guard = warnings.catch_warnings()
        guard.__enter__()
        self.addCleanup(guard.__exit__, None, None, None)
        warnings.simplefilter("ignore")

        self.loader = VasprunLoader(vrun)
        self.assertTupleEqual(self.loader.proj.shape, (120, 20, 2, 9))
        self.bztInterp = BztInterpolator(self.loader, lpfac=2)
        self.assertIsNotNone(self.bztInterp)

        bs_sp = loadfn(os.path.join(test_dir, "N2_bandstructure.json"))
        structure = vrun_sp.structures[-1]
        loader_sp_up = BandstructureLoader(bs_sp, structure, spin=1)
        loader_sp_dn = BandstructureLoader(bs_sp, structure, spin=-1)

        # Give both spin channels the same band window (the union of their
        # energy ranges), presumably so their DOS can be merged consistently.
        min_bnd = min(loader_sp_up.ebands.min(), loader_sp_dn.ebands.min())
        max_bnd = max(loader_sp_up.ebands.max(), loader_sp_dn.ebands.max())
        loader_sp_up.set_upper_lower_bands(min_bnd, max_bnd)
        loader_sp_dn.set_upper_lower_bands(min_bnd, max_bnd)

        self.bztI_up = BztInterpolator(loader_sp_up, lpfac=2, energy_range=np.inf, curvature=False)
        self.bztI_dn = BztInterpolator(loader_sp_dn, lpfac=2, energy_range=np.inf, curvature=False)

    def test_properties(self):
        """Interpolated arrays have the expected shapes."""
        self.assertTupleEqual(self.bztInterp.cband.shape, (5, 3, 3, 3, 29791))
        self.assertTupleEqual(self.bztInterp.eband.shape, (5, 29791))
        self.assertTupleEqual(self.bztInterp.coeffs.shape, (5, 322))
        self.assertEqual(self.bztInterp.nemax, 12)

    def test_get_band_structure(self):
        """The default-path band structure has 5 spin-up bands on 137 k-points."""
        sbs = self.bztInterp.get_band_structure()
        self.assertIsNotNone(sbs)
        self.assertTupleEqual(sbs.bands[Spin.up].shape, (5, 137))

    def test_tot_dos(self):
        """Total DOS values, including the merged per-spin DOS, match references."""
        tot_dos = self.bztInterp.get_dos(T=200, npts_mu=100)
        self.assertIsNotNone(tot_dos)
        self.assertEqual(len(tot_dos.energies), 100)
        self.assertAlmostEqual(tot_dos.densities[Spin.up][0], 1.42859939, places=5)

        dos_up = self.bztI_up.get_dos(partial_dos=False, npts_mu=100)
        dos_dn = self.bztI_dn.get_dos(partial_dos=False, npts_mu=100)
        cdos = merge_up_down_doses(dos_up, dos_dn)
        self.assertAlmostEqual(cdos.densities[Spin.down][50], 92.87836778, places=5)
        self.assertAlmostEqual(cdos.densities[Spin.up][45], 9.564067, places=5)
        self.assertEqual(len(cdos.energies), 100)

    def test_tot_proj_dos(self):
        """Projected DOS has s/p/d parts and the expected s-orbital value."""
        tot_proj_dos = self.bztInterp.get_dos(partial_dos=True, T=200, npts_mu=100)
        self.assertIsNotNone(tot_proj_dos)
        spd_dos = tot_proj_dos.get_spd_dos()
        self.assertEqual(len(spd_dos.values()), 3)
        pdos = spd_dos[OrbitalType.s].densities[Spin.up][0]
        self.assertAlmostEqual(pdos, 15.474392020, places=5)
Example #4
0
class BztInterpolatorTest(unittest.TestCase):
    """Checks BztInterpolator on the ``vrun`` fixture and on the N2 band
    structure interpolated separately for spin up and spin down."""

    def setUp(self):
        """Build the unpolarised and per-spin interpolators, warnings silenced."""
        # Silence warnings before any interpolation runs and restore the
        # caller's filters on cleanup; warnings.resetwarnings() in tearDown
        # used to wipe every filter in the process, runner filters included.
        guard = warnings.catch_warnings()
        guard.__enter__()
        self.addCleanup(guard.__exit__, None, None, None)
        warnings.simplefilter("ignore")

        self.loader = VasprunLoader(vrun)
        self.assertTupleEqual(self.loader.proj.shape, (120, 20, 2, 9))
        self.bztInterp = BztInterpolator(self.loader, lpfac=2)
        self.assertIsNotNone(self.bztInterp)

        bs_sp = loadfn(os.path.join(test_dir, "N2_bandstructure.json"))
        structure = vrun_sp.structures[-1]
        loader_sp_up = BandstructureLoader(bs_sp, structure, spin=1)
        loader_sp_dn = BandstructureLoader(bs_sp, structure, spin=-1)

        # Both spin channels share one band window spanning their combined
        # energy range (presumably so the DOS can be merged consistently).
        min_bnd = min(loader_sp_up.ebands.min(), loader_sp_dn.ebands.min())
        max_bnd = max(loader_sp_up.ebands.max(), loader_sp_dn.ebands.max())
        loader_sp_up.set_upper_lower_bands(min_bnd, max_bnd)
        loader_sp_dn.set_upper_lower_bands(min_bnd, max_bnd)

        self.bztI_up = BztInterpolator(loader_sp_up, lpfac=2, energy_range=np.inf, curvature=False)
        self.bztI_dn = BztInterpolator(loader_sp_dn, lpfac=2, energy_range=np.inf, curvature=False)

    def test_properties(self):
        """Interpolated arrays have the expected shapes."""
        self.assertTupleEqual(self.bztInterp.cband.shape, (5, 3, 3, 3, 29791))
        self.assertTupleEqual(self.bztInterp.eband.shape, (5, 29791))
        self.assertTupleEqual(self.bztInterp.coeffs.shape, (5, 322))
        self.assertEqual(self.bztInterp.nemax, 12)

    def test_get_band_structure(self):
        """The default-path band structure has 5 spin-up bands on 137 k-points."""
        sbs = self.bztInterp.get_band_structure()
        self.assertIsNotNone(sbs)
        self.assertTupleEqual(sbs.bands[Spin.up].shape, (5, 137))

    def test_tot_dos(self):
        """Total DOS values, including the merged per-spin DOS, match references."""
        tot_dos = self.bztInterp.get_dos(T=200, npts_mu=100)
        self.assertIsNotNone(tot_dos)
        self.assertEqual(len(tot_dos.energies), 100)
        self.assertAlmostEqual(tot_dos.densities[Spin.up][0], 1.42859939, places=5)

        dos_up = self.bztI_up.get_dos(partial_dos=False, npts_mu=100)
        dos_dn = self.bztI_dn.get_dos(partial_dos=False, npts_mu=100)
        cdos = merge_up_down_doses(dos_up, dos_dn)
        self.assertAlmostEqual(cdos.densities[Spin.down][50], 92.87836778, places=5)
        self.assertAlmostEqual(cdos.densities[Spin.up][45], 9.564067, places=5)
        self.assertEqual(len(cdos.energies), 100)

    def test_tot_proj_dos(self):
        """Projected DOS has s/p/d parts and the expected s-orbital value."""
        tot_proj_dos = self.bztInterp.get_dos(partial_dos=True, T=200, npts_mu=100)
        self.assertIsNotNone(tot_proj_dos)
        spd_dos = tot_proj_dos.get_spd_dos()
        self.assertEqual(len(spd_dos.values()), 3)
        pdos = spd_dos[OrbitalType.s].densities[Spin.up][0]
        self.assertAlmostEqual(pdos, 15.474392020, places=5)
Example #5
0
class BztInterpolatorTest(unittest.TestCase):
    """Checks a BztInterpolator (lpfac=2) built from the ``vrun`` fixture."""

    def setUp(self):
        """Build loader and interpolator with warnings silenced for the test."""
        # Silence warnings from the very start and restore the previous
        # filters afterwards.  addCleanup runs even when setUp fails, and
        # restoring avoids warnings.resetwarnings(), which cleared all filters.
        guard = warnings.catch_warnings()
        guard.__enter__()
        self.addCleanup(guard.__exit__, None, None, None)
        warnings.simplefilter("ignore")

        self.loader = VasprunLoader(vrun)
        self.assertTupleEqual(self.loader.proj.shape, (120, 20, 2, 9))
        self.bztInterp = BztInterpolator(self.loader, lpfac=2)
        self.assertIsNotNone(self.bztInterp)

    def test_properties(self):
        """Interpolated arrays have the expected shapes."""
        self.assertTupleEqual(self.bztInterp.cband.shape, (5, 3, 3, 3, 29791))
        self.assertTupleEqual(self.bztInterp.eband.shape, (5, 29791))
        self.assertTupleEqual(self.bztInterp.coeffs.shape, (5, 322))
        self.assertEqual(self.bztInterp.nemax, 12)

    def test_get_band_structure(self):
        """The default-path band structure has 5 spin-up bands on 137 k-points."""
        sbs = self.bztInterp.get_band_structure()
        self.assertIsNotNone(sbs)
        self.assertTupleEqual(sbs.bands[Spin.up].shape, (5, 137))

    def test_tot_dos(self):
        """Total DOS on a 100-point grid matches the reference value."""
        tot_dos = self.bztInterp.get_dos(T=200, npts_mu=100)
        self.assertIsNotNone(tot_dos)
        self.assertEqual(len(tot_dos.energies), 100)
        self.assertAlmostEqual(tot_dos.densities[Spin.up][0], 1.42859939, places=5)

    def test_tot_proj_dos(self):
        """Projected DOS has s/p/d parts and the expected s-orbital value."""
        tot_proj_dos = self.bztInterp.get_dos(partial_dos=True,
                                              T=200,
                                              npts_mu=100)
        self.assertIsNotNone(tot_proj_dos)
        spd_dos = tot_proj_dos.get_spd_dos()
        self.assertEqual(len(spd_dos.values()), 3)
        pdos = spd_dos[OrbitalType.s].densities[Spin.up][0]
        self.assertAlmostEqual(pdos, 15.474392020, places=5)
Example #6
0
class BztInterpolatorTest(unittest.TestCase):
    """Checks BztInterpolator on unpolarised (``vrun``) and spin-polarised
    (``vrun_sp``) data, including the save/load round trip."""

    def setUp(self):
        """Interpolate both fixtures; each is computed, saved and reloaded."""
        # One warnings guard for the whole test (previously "ignore" was set
        # twice and only after the first interpolations).  The previous filter
        # state is restored on cleanup instead of forcing a global "default".
        guard = warnings.catch_warnings()
        guard.__enter__()
        self.addCleanup(guard.__exit__, None, None, None)
        warnings.simplefilter("ignore")

        # Unpolarised: plain interpolation, then save to and reload from
        # bztinterp_fn; the reloaded interpolator is the one under test.
        self.loader = VasprunBSLoader(vrun)
        self.bztInterp = BztInterpolator(self.loader, lpfac=2)
        self.assertIsNotNone(self.bztInterp)
        self.bztInterp = BztInterpolator(self.loader, lpfac=2, save_bztInterp=True, fname=bztinterp_fn)
        self.assertIsNotNone(self.bztInterp)
        self.bztInterp = BztInterpolator(self.loader, load_bztInterp=True, fname=bztinterp_fn)
        self.assertIsNotNone(self.bztInterp)

        # Spin-polarised: same sequence, reusing (and overwriting) the file.
        self.loader_sp = VasprunBSLoader(vrun_sp)
        self.bztInterp_sp = BztInterpolator(self.loader_sp, lpfac=2)
        self.assertIsNotNone(self.bztInterp_sp)
        self.bztInterp_sp = BztInterpolator(self.loader_sp, lpfac=2, save_bztInterp=True, fname=bztinterp_fn)
        self.assertIsNotNone(self.bztInterp_sp)
        self.bztInterp_sp = BztInterpolator(self.loader_sp, lpfac=2, load_bztInterp=True, fname=bztinterp_fn)
        self.assertIsNotNone(self.bztInterp_sp)

    def test_properties(self):
        """Array shapes and electron counts for both interpolators."""
        self.assertTupleEqual(self.bztInterp.cband.shape, (6, 3, 3, 3, 29791))
        self.assertTupleEqual(self.bztInterp.eband.shape, (6, 29791))
        self.assertTupleEqual(self.bztInterp.coeffs.shape, (6, 322))
        self.assertEqual(self.bztInterp.data.nelect, 6.0)
        self.assertEqual(self.bztInterp.data.nelect_all, 20.0)
        self.assertTupleEqual(self.bztInterp.data.ebands.shape, (6, 120))

        self.assertTupleEqual(self.bztInterp_sp.cband.shape, (10, 3, 3, 3, 23275))
        self.assertTupleEqual(self.bztInterp_sp.eband.shape, (10, 23275))
        self.assertTupleEqual(self.bztInterp_sp.coeffs.shape, (10, 519))
        self.assertEqual(self.bztInterp_sp.data.nelect, 6.0)
        self.assertEqual(self.bztInterp_sp.data.nelect_all, 10.0)
        self.assertTupleEqual(self.bztInterp_sp.data.ebands.shape, (10, 198))

    def test_get_band_structure(self):
        """Band structures on the default path and on an explicit L-K path."""
        sbs = self.bztInterp.get_band_structure()
        self.assertIsNotNone(sbs)
        self.assertTupleEqual(sbs.bands[Spin.up].shape, (6, 137))

        kpaths = [["L", "K"]]
        kp_lbl = {"L": np.array([0.5, 0.5, 0.5]), "K": np.array([0.375, 0.375, 0.75])}
        sbs = self.bztInterp.get_band_structure(kpaths, kp_lbl)
        self.assertIsNotNone(sbs)
        self.assertTupleEqual(sbs.bands[Spin.up].shape, (6, 20))

        sbs = self.bztInterp_sp.get_band_structure()
        self.assertIsNotNone(sbs)
        self.assertTupleEqual(sbs.bands[Spin.up].shape, (6, 143))
        self.assertTupleEqual(sbs.bands[Spin.down].shape, (4, 143))

    def test_tot_dos(self):
        """Total DOS reference values for both interpolators."""
        tot_dos = self.bztInterp.get_dos(T=200, npts_mu=100)
        self.assertIsNotNone(tot_dos)
        self.assertEqual(len(tot_dos.energies), 100)
        self.assertAlmostEqual(tot_dos.densities[Spin.up][0], 1.35371715, places=5)

        tot_dos = self.bztInterp_sp.get_dos(T=200, npts_mu=100)
        self.assertIsNotNone(tot_dos)
        self.assertEqual(len(tot_dos.energies), 100)
        self.assertAlmostEqual(tot_dos.densities[Spin.up][75], 88.034456, places=5)
        self.assertAlmostEqual(tot_dos.densities[Spin.down][75], 41.421367, places=5)

    def test_tot_proj_dos(self):
        """Projected (s/p/d) DOS reference values for both interpolators."""
        tot_proj_dos = self.bztInterp.get_dos(partial_dos=True, T=200, npts_mu=100)
        self.assertIsNotNone(tot_proj_dos)
        spd_dos = tot_proj_dos.get_spd_dos()
        self.assertEqual(len(spd_dos.values()), 3)
        pdos = spd_dos[OrbitalType.s].densities[Spin.up][75]
        self.assertAlmostEqual(pdos, 2490.169396, places=5)

        tot_proj_dos = self.bztInterp_sp.get_dos(partial_dos=True, T=200, npts_mu=100)
        self.assertIsNotNone(tot_proj_dos)
        spd_dos = tot_proj_dos.get_spd_dos()
        self.assertEqual(len(spd_dos.values()), 3)
        pdos = spd_dos[OrbitalType.s].densities[Spin.up][75]
        self.assertAlmostEqual(pdos, 166.4933305, places=5)
        pdos = spd_dos[OrbitalType.s].densities[Spin.down][75]
        self.assertAlmostEqual(pdos, 272.194174, places=5)
Example #7
0
def _boltz_flag(value):
    """Interpret one of the ``boltz`` option values as a boolean.

    The options are passed as strings (e.g. ``"T"`` or ``""``) and
    ``bool("F")`` is ``True``, so plain ``bool()`` cannot switch a feature
    off with the natural spelling.  Explicit negatives ("", "F", "false",
    "0", "n", "no", "none", "off"; case-insensitive, surrounding whitespace
    ignored) are false and any other string is true.  Non-strings fall back
    to ``bool()``.
    """
    if isinstance(value, str):
        return value.strip().lower() not in (
            '', 'f', 'false', '0', 'n', 'no', 'none', 'off')
    return bool(value)


def _load_kpath(path='./kpath'):
    """Read an optional JSON k-path description.

    The file must hold a ``"path"`` list of segments and a ``"kpoints_rel"``
    mapping of label -> fractional coordinates.  The label ``GAMMA`` is
    rewritten to the LaTeX label ``\\Gamma``.

    Returns:
        ``(kpaths, kpoints_lbls_dict)``, or ``(None, None)`` when the file is
        missing or malformed; the pair is passed to ``get_band_structure``
        unchanged.
    """
    try:
        with open(path, 'r') as handle:
            kpath = json.load(handle)
        kpaths = kpath['path']
        for segment in kpaths:
            # only the first two entries (segment start/end) are relabelled
            for j in (0, 1):
                if segment[j] == 'GAMMA':
                    segment[j] = r'\Gamma'
        kpoints_lbls_dict = {
            (r'\Gamma' if label == 'GAMMA' else label): coords
            for label, coords in kpath['kpoints_rel'].items()
        }
    except (OSError, ValueError, KeyError, IndexError, TypeError,
            AttributeError):
        return None, None
    return kpaths, kpoints_lbls_dict


def _build_model_vasprun(vr_file, parse_projected, pred):
    """Parse *vr_file* again and overwrite its band energies with *pred*.

    Args:
        vr_file: path of the vasprun.xml file.
        parse_projected: forwarded as ``BSVasprun(parse_projected_eigen=...)``.
        pred: predicted energies indexed ``[kpoint, band]`` (assumed to use
            the same k-point order as *vr_file* -- TODO confirm).

    Returns:
        The modified ``BSVasprun``, or ``None`` when *pred* is ``None`` or
        all zeros.
    """
    if pred is None or not np.any(pred):
        return None
    pred = np.asarray(pred)
    model = BSVasprun(vr_file, parse_projected_eigen=parse_projected)
    for eigen in model.eigenvalues.values():
        # eigen is [kpoint, band, (energy, occupation)]: write only the
        # energies.  Broadcasting pred over the last axis (as before) also
        # replaced the occupations with energies.
        nbands = min(eigen.shape[1], pred.shape[1])
        eigen[:, :nbands, 0] = pred[:, :nbands]
    return model


def _interpolate(vasprun, opts, use_cache):
    """Build a BztInterpolator for *vasprun* from the parsed ``boltz`` options.

    With *use_cache* false the interpolation is neither saved to nor loaded
    from the on-disk cache file.
    """
    return BztInterpolator(
        VasprunBSLoader(vasprun),
        lpfac=int(opts['lpfac']),
        energy_range=float(opts['energy_range']),
        curvature=_boltz_flag(opts['curvature']),
        save_bztInterp=use_cache,
        load_bztInterp=use_cache and _boltz_flag(opts['load']))


def _high_symmetry_labels(bs):
    """Return ``[index, label]`` pairs for labelled k-points of *bs*.

    A k-point contributes when its label is one of the labels in
    ``bs.labels_dict`` and differs from the preceding k-point's label, so a
    label repeated on consecutive k-points is reported once.  The first
    k-point has no predecessor (it is no longer compared with the last one).
    """
    known = [kpoint.label for kpoint in bs.labels_dict.values()]
    labels = []
    previous = None
    for index, kpoint in enumerate(bs.kpoints):
        current = kpoint.label
        if current != previous:
            labels.extend([index, name] for name in known if name == current)
        previous = current
    return labels


def bandplot_func(
        filenames=None,
        code='vasp',
        prefix=None,
        directory=None,
        vbm_cbm_marker=False,
        projection_selection=None,
        mode='rgb',
        pred=None,
        interpolate_factor=4,
        circle_size=150,
        dos_file=None,
        cart_coords=False,
        scissor=None,
        ylabel='Energy (eV)',
        dos_label=None,
        elements=None,
        lm_orbitals=None,
        atoms=None,
        spin=None,
        total_only=False,
        plot_total=True,
        legend_cutoff=3,
        gaussian=None,
        height=None,
        width=None,
        ymin=-6.,
        ymax=6.,
        colours=None,
        yscale=1,
        style=None,
        no_base_style=False,
        image_format='pdf',
        dpi=400,
        plt=None,
        fonts=None,
        boltz=None,
        nelec=0):
    """Interpolate DFT bands and model-predicted bands along the same k-path.

    The first vasprun file is parsed.  A second copy optionally has its band
    energies replaced by *pred*.  Both are interpolated with
    ``BztInterpolator`` along the k-path read from ``./kpath`` (if present).

    Args:
        filenames: vasprun.xml path or list of paths; defaults to
            ``find_vasprun_files()``.  Only the first file is used.
        code: only ``'vasp'`` is handled; any other value returns ``None``.
        projection_selection: when truthy, projected eigenvalues are parsed.
        pred: predicted band energies indexed ``[kpoint, band]``.  When
            ``None`` or all zero, the DFT data is used for the model as well.
        boltz: dict of string options ``ifinter``, ``lpfac``,
            ``energy_range``, ``curvature``, ``load`` and
            ``ismetaltolerance`` (the last is accepted but unused).  Missing
            keys take the defaults below.  ``ifinter`` must be enabled.
        nelec: number of valence electrons, used to find the highest occupied
            band.  With the default 0, no valence-band-maximum correction of
            the Fermi level is made.

        The remaining plotting/DOS parameters (``ymin``, ``dos_file``,
        ``plt``, ...) are kept for backward compatibility but are unused.

    Returns:
        ``(dft_bands, model_bands, labels)``:
        - the first spin channel of the interpolated DFT bands, shifted by
          the Fermi level;
        - the same channel of the model bands, not shifted;
        - ``[index, label]`` pairs for the labelled k-points.

    Raises:
        NotImplementedError: ``boltz['ifinter']`` is disabled.
        ValueError: no vasprun file was given or found.
    """
    # Fresh defaults per call; caller-supplied (possibly partial) options are
    # layered on top so missing keys do not raise KeyError.
    opts = {
        'ifinter': 'T',
        'lpfac': '10',
        'energy_range': '50',
        'curvature': '',
        'load': 'T',
        'ismetaltolerance': '0.01',
    }
    if boltz:
        opts.update(boltz)

    if not filenames:
        filenames = find_vasprun_files()
    elif isinstance(filenames, str):
        filenames = [filenames]

    # only load the orbital projections if we definitely need them
    parse_projected = bool(projection_selection)

    if code == 'vasp':
        if not _boltz_flag(opts['ifinter']):
            # Without interpolation there is no band structure to return
            # (this path previously failed with a NameError).
            raise NotImplementedError(
                "bandplot_func requires boltz['ifinter'] to be enabled")
        if not filenames:
            raise ValueError("no vasprun.xml files were given or found")

        # Only the first file is processed: results are returned directly.
        vr_file = filenames[0]
        vr = BSVasprun(vr_file, parse_projected_eigen=parse_projected)
        model = _build_model_vasprun(vr_file, parse_projected, pred)
        if model is None:
            model = vr

        # No `fname` is passed, so both interpolators would share one cache
        # file and the model could silently reuse (or overwrite) the DFT
        # interpolation.  Only the DFT interpolation uses the cache.
        b_inter = _interpolate(vr, opts, use_cache=True)
        model_inter = _interpolate(model, opts, use_cache=False)

        kpaths, kpoints_lbls_dict = _load_kpath()
        bs = b_inter.get_band_structure(
            kpaths=kpaths, kpoints_lbls_dict=kpoints_lbls_dict)
        model_bs = model_inter.get_band_structure(
            kpaths=kpaths, kpoints_lbls_dict=kpoints_lbls_dict)

        gap = bs.get_band_gap()
        print("WHC interpolated gap: %s" % gap)

        # Index of the highest occupied band per spin channel.  With
        # nelec <= 0 this would be -1 (the top band), so skip the correction.
        nvb = int(np.ceil(nelec / (int(bs.is_spin_polarized) + 1)))
        vbm = None
        if nvb > 0:
            vbm = max(np.max(bands[nvb - 1]) for bands in bs.bands.values())
        print(
            'WHC WARNNING vasp fermi %s interpolation vbm %s nelec %s nvb %s'
            % (bs.efermi, vbm, nelec, nvb))
        if vbm is not None:
            # Never leave the Fermi level above the valence-band maximum.
            if vbm < bs.efermi:
                bs.efermi = vbm
            if vbm < model_bs.efermi:
                model_bs.efermi = vbm

        # Reference the DFT bands (first spin channel) to the Fermi level.
        # The model bands are left unshifted; presumably `pred` is already
        # Fermi-referenced -- TODO confirm against the model output.
        spin_key = next(iter(bs.bands))
        bs.bands[spin_key] = bs.bands[spin_key] - bs.efermi

        bs = get_reconstructed_band_structure([bs])
        model_bs = get_reconstructed_band_structure([model_bs])

        labels = _high_symmetry_labels(bs)
        return bs.bands[spin_key], model_bs.bands[spin_key], labels