示例#1
0
    def setUp(self):
        """Load the CaO, N2 and C band-structure fixtures and build plotters.

        ``self.plotter`` wraps the single CaO band structure; ``self.plotter_multi``
        wraps the N2 and C band structures together. Sanity-checks the number of
        stored band objects and bands, then silences warnings.
        """
        def _load(filename):
            # Read one JSON fixture from test_dir and deserialize it.
            with open(os.path.join(test_dir, filename), "r",
                      encoding="utf-8") as handle:
                return BandStructureSymmLine.from_dict(json.load(handle))

        self.bs = _load("CaO_2605_bandstructure.json")
        self.plotter = BSPlotter(self.bs)
        self.assertEqual(len(self.plotter._bs), 1,
                         "wrong number of band objects")

        self.sbs_sc = _load("N2_12103_bandstructure.json")
        self.sbs_met = _load("C_48_bandstructure.json")

        pair = [self.sbs_sc, self.sbs_met]
        self.plotter_multi = BSPlotter(pair)
        self.assertEqual(len(self.plotter_multi._bs), 2,
                         "wrong number of band objects")
        self.assertEqual(self.plotter_multi._nb_bands, [96, 96],
                         "wrong number of bands")
        warnings.simplefilter("ignore")
示例#2
0
    def setUp(self):
        """Load band-structure fixtures and sanity-check their parsed contents.

        Cu2O is loaded first and checked for projections. CaO is loaded next;
        it replaces ``self.bs`` and is checked for bands, branches and
        distances. NiO (``self.bs_spin``) is checked for spin-up and
        spin-down band energies.
        """
        with open(os.path.join(test_dir, "Cu2O_361_bandstructure.json"),
                  "r",
                  encoding='utf-8') as f:
            d = json.load(f)
            self.bs = BandStructureSymmLine.from_dict(d)
            self.assertListEqual(
                self.bs._projections[Spin.up][10][12][Orbital.s],
                [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], "wrong projections")
            self.assertListEqual(
                self.bs._projections[Spin.up][25][0][Orbital.dyz],
                [0.0, 0.0, 0.0011, 0.0219, 0.0219, 0.069], "wrong projections")
            self.assertAlmostEqual(
                self.bs.get_projection_on_elements()[Spin.up][25][10]['O'],
                0.0328)
            self.assertAlmostEqual(
                self.bs.get_projection_on_elements()[Spin.up][22][25]['Cu'],
                0.8327)
            self.assertAlmostEqual(
                self.bs.get_projections_on_elts_and_orbitals(
                    {'Cu': ['s', 'd']})[Spin.up][25][0]['Cu']['s'], 0.0027)
            self.assertAlmostEqual(
                self.bs.get_projections_on_elts_and_orbitals(
                    {'Cu': ['s', 'd']})[Spin.up][25][0]['Cu']['d'],
                0.8495999999999999)

        with open(os.path.join(test_dir, "CaO_2605_bandstructure.json"),
                  "r",
                  encoding='utf-8') as f:
            d = json.load(f)
            # NOTE: the as_dict() -> from_dict() round trip is not checked here.
            self.bs = BandStructureSymmLine.from_dict(d)
            self.one_kpoint = self.bs.kpoints[31]
            self.assertEqual(self.bs._nb_bands, 16)
            # (A duplicated copy of the following assertion was removed.)
            self.assertAlmostEqual(self.bs._bands[Spin.up][5][10], 0.5608)
            self.assertEqual(self.bs._branches[5]['name'], "L-U")
            self.assertEqual(self.bs._branches[5]['start_index'], 80)
            self.assertEqual(self.bs._branches[5]['end_index'], 95)
            self.assertAlmostEqual(self.bs._distance[70], 4.2335127528765737)
        with open(os.path.join(test_dir, "NiO_19009_bandstructure.json"),
                  "r",
                  encoding='utf-8') as f:
            d = json.load(f)
            # NOTE: the as_dict() -> from_dict() round trip is not checked here.
            self.bs_spin = BandStructureSymmLine.from_dict(d)
            self.assertEqual(self.bs_spin._nb_bands, 27)
            self.assertAlmostEqual(self.bs_spin._bands[Spin.up][5][10], 0.262)
            self.assertAlmostEqual(self.bs_spin._bands[Spin.down][5][10],
                                   1.6156)
    def setUp(self):
        """Load the Cu2O (``self.bs``), CaO (``self.bs2``) and NiO (``self.bs_spin``) fixtures."""
        fixtures = (
            ("bs", "Cu2O_361_bandstructure.json"),
            ("bs2", "CaO_2605_bandstructure.json"),
            ("bs_spin", "NiO_19009_bandstructure.json"),
        )
        for attr_name, filename in fixtures:
            path = os.path.join(test_dir, filename)
            with open(path, "r", encoding="utf-8") as handle:
                setattr(self, attr_name,
                        BandStructureSymmLine.from_dict(json.load(handle)))
示例#4
0
 def setUp(self):
     """Load the CaO band structure fixture and wrap it in a BSPlotter."""
     fixture_path = os.path.join(test_dir, "CaO_2605_bandstructure.json")
     with open(fixture_path, "r", encoding='utf-8') as handle:
         payload = json.load(handle)
     self.bs = BandStructureSymmLine.from_dict(payload)
     self.plotter = BSPlotter(self.bs)
示例#5
0
 def setUp(self):
     """Load the CaO band structure, build a BSPlotter for it, and silence warnings."""
     fixture = os.path.join(test_dir, "CaO_2605_bandstructure.json")
     with open(fixture, "r", encoding='utf-8') as stream:
         band_structure = BandStructureSymmLine.from_dict(json.load(stream))
     self.bs = band_structure
     self.plotter = BSPlotter(band_structure)
     warnings.simplefilter("ignore")
示例#6
0
 def test_old_format_load(self):
     """An old-format ZnS band-structure JSON still loads via from_dict.

     Checks the Zn element projection of the first band at the first k-point
     (spin up).
     """
     with open(os.path.join(test_dir, "bs_ZnS_old.json"),
               "r", encoding='utf-8') as f:
         d = json.load(f)
         bs_old = BandStructureSymmLine.from_dict(d)
         # The value is produced by get_projection_on_elements(), so compare
         # with a tolerance instead of exact float equality.
         self.assertAlmostEqual(
             bs_old.get_projection_on_elements()[Spin.up][0][0]['Zn'],
             0.0971)
示例#7
0
    def setUp(self):
        """Build DataFrame fixtures holding silicon and VBr2 band structures.

        ``self.df`` holds the Si line-mode and uniform band structures, both
        assigned the Si structure fixture. ``self.si_kpts`` holds the k-point
        coordinates from HighSymmKpath for Si. ``self.df2`` holds the VBr2
        band structure, and ``self.vbr2kpts`` is a hard-coded list of
        fractional k-points for it.
        """
        with open(os.path.join(test_dir, 'si_structure.json'), 'r',
                  encoding='utf-8') as sth:
            si_str = Structure.from_dict(json.load(sth))

        with open(os.path.join(test_dir, 'si_bandstructure_line.json'),
                  'r', encoding='utf-8') as bsh:
            si_bs_line = BandStructureSymmLine.from_dict(json.load(bsh))
        si_bs_line.structure = si_str

        with open(os.path.join(test_dir, 'si_bandstructure_uniform.json'),
                  'r', encoding='utf-8') as bsh:
            si_bs_uniform = BandStructure.from_dict(json.load(bsh))
        si_bs_uniform.structure = si_str
        self.si_kpts = list(HighSymmKpath(si_str).kpath['kpoints'].values())
        self.df = pd.DataFrame({
            'bs_line': [si_bs_line],
            'bs_uniform': [si_bs_uniform]
        })

        with open(os.path.join(test_dir, 'VBr2_971787_bandstructure.json'),
                  'r', encoding='utf-8') as bsh:
            vbr2_uniform = BandStructure.from_dict(json.load(bsh))

        # Hard-coded fractional k-points for VBr2 (a value previously derived
        # from labels_dict here was dead: it was immediately overwritten).
        self.vbr2kpts = [
            [0.0, 0.0, 0.0],  # \\Gamma
            [0.2, 0.0, 0.0],  # between \\Gamma and M
            [0.5, 0.0, 0.0],  # M
            [0.5, 0.0, 0.5],  # L
        ]

        self.df2 = pd.DataFrame({'bs_line': [vbr2_uniform]})
示例#8
0
def extract_bs(mat):
    """Build a ``BandStructureSymmLine`` from a materials document.

    If the serialized band structure lacks a structure, ``mat["structure"]``
    is used. If it lacks a ``labels_dict``, one is built from the NSCF line
    k-point labels in ``mat["inputs"]``, falling back to HighSymmKpath of the
    structure.

    Args:
        mat (dict): materials document whose ``"bandstructure"`` entry may hold
            a serialized band structure under ``"bs"``. It is not modified.

    Returns:
        BandStructureSymmLine or None: None when there is no ``"bs"`` entry.
    """
    if "bs" not in mat["bandstructure"]:
        return None

    # Shallow copy so filling in missing keys does not mutate the caller's document.
    bs_dict = dict(mat["bandstructure"]["bs"])

    # Add in structure if not already there
    if "structure" not in bs_dict:
        bs_dict["structure"] = mat["structure"]

    # Add in High Symm K Path if not already there (also covers a None value)
    if not bs_dict.get("labels_dict"):
        labels = get(mat, "inputs.nscf_line.kpoints.labels", None)
        kpts = get(mat, "inputs.nscf_line.kpoints.kpoints", None)
        if labels and kpts:
            labels_dict = dict(zip(labels, kpts))
            # k-points without a label map under a None key; drop it.
            labels_dict.pop(None, None)
        else:
            struc = Structure.from_dict(mat["structure"])
            labels_dict = HighSymmKpath(struc).kpath["kpoints"]

        bs_dict["labels_dict"] = labels_dict

    # Round-trip through BandStructure before building the line-mode object.
    return BandStructureSymmLine.from_dict(
        BandStructure.from_dict(bs_dict).as_dict())
示例#9
0
 def setUp(self):
     """Load the Cu2O band structure, build a projected plotter, and silence warnings."""
     source = os.path.join(test_dir, "Cu2O_361_bandstructure.json")
     with open(source, "r", encoding='utf-8') as reader:
         raw = json.load(reader)
     self.bs = BandStructureSymmLine.from_dict(raw)
     self.plotter = BSPlotterProjected(self.bs)
     warnings.simplefilter("ignore")
示例#10
0
 def test_old_format_load(self):
     """An old-format ZnS band-structure JSON still loads via from_dict.

     Checks the Zn element projection of the first band at the first k-point
     (spin up).
     """
     with open(os.path.join(PymatgenTest.TEST_FILES_DIR, "bs_ZnS_old.json"),
               encoding="utf-8") as f:
         d = json.load(f)
         bs_old = BandStructureSymmLine.from_dict(d)
         # The value is produced by get_projection_on_elements(), so compare
         # with a tolerance instead of exact float equality.
         self.assertAlmostEqual(
             bs_old.get_projection_on_elements()[Spin.up][0][0]["Zn"],
             0.0971)
示例#11
0
 def setUp(self):
     """Load the Cu2O test fixture, build a projected plotter, and silence warnings."""
     cu2o_file = os.path.join(PymatgenTest.TEST_FILES_DIR,
                              "Cu2O_361_bandstructure.json")
     with open(cu2o_file, encoding="utf-8") as stream:
         self.bs = BandStructureSymmLine.from_dict(json.load(stream))
     self.plotter = BSPlotterProjected(self.bs)
     warnings.simplefilter("ignore")
示例#12
0
    def setUp(self):
        """Deserialize three band-structure fixtures onto the test case.

        ``self.bs`` is Cu2O, ``self.bs2`` is CaO and ``self.bs_spin`` is NiO.
        """
        def read_fixture(name):
            # Open a fixture from test_dir as UTF-8 JSON and deserialize it.
            with open(os.path.join(test_dir, name), "r",
                      encoding='utf-8') as fp:
                return BandStructureSymmLine.from_dict(json.load(fp))

        self.bs = read_fixture("Cu2O_361_bandstructure.json")
        self.bs2 = read_fixture("CaO_2605_bandstructure.json")
        self.bs_spin = read_fixture("NiO_19009_bandstructure.json")
def get_bs(db_file):
    """Return a ``BandStructureSymmLine`` loaded from the database in ``db_file``.

    Takes the third value returned by ``get_collections`` (assumed to be the
    task document; TODO confirm) and reads the ``bandstructure_fs_id`` of its
    first ``calcs_reversed`` entry. The zlib-compressed JSON is then fetched
    from the ``bandstructure_fs`` GridFS bucket and deserialized.
    """
    _, _, task_doc, _ = get_collections(db_file)
    bucket = gridfs.GridFS(get_db(db_file), 'bandstructure_fs')
    fs_id = task_doc["calcs_reversed"][0]["bandstructure_fs_id"]
    compressed = bucket.get(fs_id).read()
    payload = zlib.decompress(compressed).decode()
    return BandStructureSymmLine.from_dict(json.loads(payload))
def update_dosbandsfig(n_clicks, dos, bs, projlist):
    """Rebuild the combined band-structure / DOS figure.

    The figure updates when the inputs change or the button is clicked. It
    does NOT update when elements or orbitals are merely selected.
    ``dos`` and ``bs`` arrive as JSON strings and are deserialized into
    pymatgen objects when truthy; falsy values are passed through unchanged.
    ``n_clicks`` is not read in the body (presumably required by the callback
    signature).
    """
    dos_obj = CompleteDos.from_dict(json.loads(dos)) if dos else dos
    bs_obj = BandStructureSymmLine.from_dict(json.loads(bs)) if bs else bs
    return BandsFig().generate_fig(dos_obj, bs_obj, projlist)
示例#15
0
 def get_band_structure(self, task_id, line_mode=False):
     """Fetch and deserialize the band structure stored for ``task_id``.

     Reads the ``bandstructure_fs_id`` of the first ``calcs_reversed`` entry
     of the task document and loads the zlib-compressed JSON from the
     ``bandstructure_fs`` GridFS bucket.

     Args:
         task_id: value matched against the ``task_id`` field.
         line_mode (bool): return a BandStructureSymmLine instead of a
             plain BandStructure.
     """
     projection = {"calcs_reversed": 1}
     task_doc = self.collection.find_one({"task_id": task_id}, projection)
     gridfs_id = task_doc['calcs_reversed'][0]['bandstructure_fs_id']
     store = gridfs.GridFS(self.db, 'bandstructure_fs')
     payload = json.loads(zlib.decompress(store.get(gridfs_id).read()))
     target_cls = BandStructureSymmLine if line_mode else BandStructure
     return target_cls.from_dict(payload)
示例#16
0
 def get_band_structure(self, task_id, line_mode=False):
     """Load the band structure for ``task_id`` from GridFS.

     The task document's first ``calcs_reversed`` entry points, via
     ``bandstructure_fs_id``, to zlib-compressed JSON in the
     ``bandstructure_fs`` bucket. Returns a BandStructure, or a
     BandStructureSymmLine when ``line_mode`` is true.
     """
     task = self.collection.find_one({"task_id": task_id},
                                     {"calcs_reversed": 1})
     first_calc = task['calcs_reversed'][0]
     blob_id = first_calc['bandstructure_fs_id']
     bucket = gridfs.GridFS(self.db, 'bandstructure_fs')
     decoded = json.loads(zlib.decompress(bucket.get(blob_id).read()))
     if not line_mode:
         return BandStructure.from_dict(decoded)
     return BandStructureSymmLine.from_dict(decoded)
示例#17
0
 def get_band_structure(self, task_id):
     """Load the band structure for ``task_id``, choosing the class from its ``@class`` tag.

     Raises:
         ValueError: if ``@class`` is neither "BandStructure" nor
             "BandStructureSymmLine".
     """
     task_doc = self.collection.find_one({"task_id": task_id}, {"calcs_reversed": 1})
     blob_id = task_doc['calcs_reversed'][0]['bandstructure_fs_id']
     bucket = gridfs.GridFS(self.db, 'bandstructure_fs')
     raw_bytes = zlib.decompress(bucket.get(blob_id).read())
     bs_dict = json.loads(raw_bytes.decode())
     class_tag = bs_dict["@class"]
     if class_tag == "BandStructure":
         return BandStructure.from_dict(bs_dict)
     if class_tag == "BandStructureSymmLine":
         return BandStructureSymmLine.from_dict(bs_dict)
     raise ValueError("Unknown class for band structure! {}".format(class_tag))
示例#18
0
 def get_band_structure(self, task_id):
     """Return the stored band structure for ``task_id`` as the class named in ``@class``.

     The serialized object is zlib-compressed JSON in the ``bandstructure_fs``
     GridFS bucket, referenced by the first ``calcs_reversed`` entry.

     Raises:
         ValueError: for an unrecognised ``@class`` value.
     """
     doc = self.collection.find_one({"task_id": task_id}, {"calcs_reversed": 1})
     fs_key = doc['calcs_reversed'][0]['bandstructure_fs_id']
     gfs = gridfs.GridFS(self.db, 'bandstructure_fs')
     text = zlib.decompress(gfs.get(fs_key).read()).decode()
     serialized = json.loads(text)
     kind = serialized["@class"]
     if kind in ("BandStructure", "BandStructureSymmLine"):
         cls = BandStructure if kind == "BandStructure" else BandStructureSymmLine
         return cls.from_dict(serialized)
     raise ValueError("Unknown class for band structure! {}".format(kind))
示例#19
0
    def setUp(self):
        """Load band-structure fixtures and sanity-check their parsed contents.

        Cu2O is loaded first and checked for projections. CaO is loaded next;
        it replaces ``self.bs`` and is checked for bands, branches and
        distances. NiO (``self.bs_spin``) is checked for spin-up and
        spin-down band energies.
        """
        with open(os.path.join(test_dir, "Cu2O_361_bandstructure.json"),
                  "r", encoding='utf-8') as f:
            d = json.load(f)
            self.bs = BandStructureSymmLine.from_dict(d)
            self.assertListEqual(self.bs._projections[Spin.up][10][12][Orbital.s], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], "wrong projections")
            self.assertListEqual(self.bs._projections[Spin.up][25][0][Orbital.dyz], [0.0, 0.0, 0.0011, 0.0219, 0.0219, 0.069], "wrong projections")
            self.assertAlmostEqual(self.bs.get_projection_on_elements()[Spin.up][25][10]['O'], 0.0328)
            self.assertAlmostEqual(self.bs.get_projection_on_elements()[Spin.up][22][25]['Cu'], 0.8327)
            self.assertAlmostEqual(self.bs.get_projections_on_elts_and_orbitals({'Cu': ['s', 'd']})[Spin.up][25][0]['Cu']['s'], 0.0027)
            self.assertAlmostEqual(self.bs.get_projections_on_elts_and_orbitals({'Cu': ['s', 'd']})[Spin.up][25][0]['Cu']['d'], 0.8495999999999999)

        with open(os.path.join(test_dir, "CaO_2605_bandstructure.json"), "r",
                  encoding='utf-8') as f:
            d = json.load(f)
            # NOTE: the as_dict() -> from_dict() round trip is not checked here.
            self.bs = BandStructureSymmLine.from_dict(d)
            self.one_kpoint = self.bs.kpoints[31]
            self.assertEqual(self.bs._nb_bands, 16)
            # (A duplicated copy of the following assertion was removed.)
            self.assertAlmostEqual(self.bs._bands[Spin.up][5][10], 0.5608)
            self.assertEqual(self.bs._branches[5]['name'], "L-U")
            self.assertEqual(self.bs._branches[5]['start_index'], 80)
            self.assertEqual(self.bs._branches[5]['end_index'], 95)
            self.assertAlmostEqual(self.bs._distance[70], 4.2335127528765737)
        with open(os.path.join(test_dir, "NiO_19009_bandstructure.json"),
                  "r", encoding='utf-8') as f:
            d = json.load(f)
            # NOTE: the as_dict() -> from_dict() round trip is not checked here.
            self.bs_spin = BandStructureSymmLine.from_dict(d)
            self.assertEqual(self.bs_spin._nb_bands, 27)
            self.assertAlmostEqual(self.bs_spin._bands[Spin.up][5][10], 0.262)
            self.assertAlmostEqual(self.bs_spin._bands[Spin.down][5][10],
                                   1.6156)
示例#20
0
    def setUp(self):
        """Build ``self.df`` holding the Si line-mode and uniform band structures.

        Both band structures get their ``structure`` replaced by the Si
        structure fixture. ``self.si_kpts`` lists the k-point coordinates that
        HighSymmKpath reports for Si.
        """
        def _read_json(name):
            # Parse one JSON fixture from test_dir.
            with open(os.path.join(test_dir, name), 'r') as handle:
                return json.load(handle)

        si_str = Structure.from_dict(_read_json('si_structure.json'))

        si_bs_line = BandStructureSymmLine.from_dict(
            _read_json('si_bandstructure_line.json'))
        si_bs_line.structure = si_str

        si_bs_uniform = BandStructure.from_dict(
            _read_json('si_bandstructure_uniform.json'))
        si_bs_uniform.structure = si_str

        kpoint_map = HighSymmKpath(si_str).kpath['kpoints']
        self.si_kpts = list(kpoint_map.values())
        self.df = pd.DataFrame({'bs_line': [si_bs_line],
                                'bs_uniform': [si_bs_uniform]})
示例#21
0
    def get_band_structure(self, task_id):
        """
        Read the BS data into a PMG BandStructure or BandStructureSymmLine object

        The class is chosen from the ``@class`` tag of the stored dict.

        Args:
            task_id(int or str): the task_id containing the data
        Returns:
            BandStructure or BandStructureSymmLine
        Raises:
            ValueError: if ``@class`` names neither supported class.
        """
        obj_dict = self.get_data_from_maggma_or_gridfs(task_id,
                                                       key="bandstructure")
        class_name = obj_dict["@class"]
        if class_name == "BandStructureSymmLine":
            return BandStructureSymmLine.from_dict(obj_dict)
        if class_name == "BandStructure":
            return BandStructure.from_dict(obj_dict)
        raise ValueError(
            "Unknown class for band structure! {}".format(class_name))
示例#22
0
    def _get_bs_dos(data):
        """Resolve a ``(band structure, density of states)`` pair from ``data``.

        When ``data["mpid"]`` is set, both objects are fetched with MPRester
        and the supplied objects are ignored. A failed fetch prints the
        exception and yields None for that item. Without an mpid, both
        ``bandstructure_symm_line`` and ``density_of_states`` must be present;
        dict values are deserialized and anything else is returned unchanged.

        Returns:
            tuple: ``(bandstructure_symm_line, density_of_states)``, or
            ``(None, None)`` when there is nothing to load.
        """
        data = data or {}

        mpid = data.get("mpid")
        bs = data.get("bandstructure_symm_line")
        dos = data.get("density_of_states")

        if not mpid and (bs is None or dos is None):
            return None, None

        if not mpid:
            # Objects supplied directly: deserialize any dicts.
            if bs and isinstance(bs, dict):
                bs = BandStructureSymmLine.from_dict(bs)
            if dos and isinstance(dos, dict):
                dos = CompleteDos.from_dict(dos)
            return bs, dos

        # An mpid takes precedence over directly supplied objects.
        with MPRester() as mpr:
            try:
                bs = mpr.get_bandstructure_by_material_id(mpid)
            except Exception as exc:
                print(exc)
                bs = None

            try:
                dos = mpr.get_dos_by_material_id(mpid)
            except Exception as exc:
                print(exc)
                dos = None

        return bs, dos
示例#23
0
def build_bs(bs_dict, mat):
    """Build a ``BandStructureSymmLine`` from a serialized band structure.

    The structure is always taken from ``mat["structure"]``. When
    ``bs_dict`` has no (or an empty/None) ``labels_dict``, one is built from
    the NSCF line k-point labels in ``mat["inputs"]``, falling back to
    HighSymmKpath of the structure.

    Args:
        bs_dict (dict): serialized band structure. It is not modified.
        mat (dict): materials document providing ``"structure"`` and,
            optionally, ``inputs.nscf_line.kpoints``.

    Returns:
        BandStructureSymmLine
    """
    # Shallow copy so the caller's dict is left untouched.
    bs_dict = dict(bs_dict)
    bs_dict["structure"] = mat["structure"]

    # Add in High Symm K Path if not already there
    if not bs_dict.get("labels_dict"):
        labels = get(mat, "inputs.nscf_line.kpoints.labels", None)
        kpts = get(mat, "inputs.nscf_line.kpoints.kpoints", None)
        if labels and kpts:
            labels_dict = dict(zip(labels, kpts))
            # k-points without a label map under a None key; drop it.
            labels_dict.pop(None, None)
        else:
            struc = Structure.from_dict(bs_dict["structure"])
            labels_dict = HighSymmKpath(struc).kpath["kpoints"]

        bs_dict["labels_dict"] = labels_dict

    # NOTE: this relates to BandStructureSymmLine's from_dict being problematic.
    return BandStructureSymmLine.from_dict(bs_dict)
示例#24
0
    def setUp(self):
        """Build DataFrame fixtures holding silicon and VBr2 band structures.

        ``self.df`` holds the Si line-mode and uniform band structures, both
        assigned the Si structure fixture. ``self.si_kpts`` holds the k-point
        coordinates from HighSymmKpath for Si. ``self.df2`` holds the VBr2
        band structure, and ``self.vbr2kpts`` is a hard-coded list of
        fractional k-points for it.
        """
        with open(os.path.join(test_dir, 'si_structure.json'), 'r', encoding='utf-8') as sth:
            si_str = Structure.from_dict(json.load(sth))

        with open(os.path.join(test_dir, 'si_bandstructure_line.json'), 'r', encoding='utf-8') as bsh:
            si_bs_line = BandStructureSymmLine.from_dict(json.load(bsh))
        si_bs_line.structure = si_str

        with open(os.path.join(test_dir, 'si_bandstructure_uniform.json'), 'r', encoding='utf-8') as bsh:
            si_bs_uniform = BandStructure.from_dict(json.load(bsh))
        si_bs_uniform.structure = si_str
        self.si_kpts = list(HighSymmKpath(si_str).kpath['kpoints'].values())
        self.df = pd.DataFrame({'bs_line': [si_bs_line], 'bs_uniform': [si_bs_uniform]})

        with open(os.path.join(test_dir, 'VBr2_971787_bandstructure.json'), 'r', encoding='utf-8') as bsh:
            vbr2_uniform = BandStructure.from_dict(json.load(bsh))

        # Hard-coded fractional k-points for VBr2 (a value previously derived
        # from labels_dict here was dead: it was immediately overwritten).
        self.vbr2kpts = [[0.0, 0.0, 0.0],  # \\Gamma
                         [0.2, 0.0, 0.0],  # between \\Gamma and M
                         [0.5, 0.0, 0.0],  # M
                         [0.5, 0.0, 0.5]]  # L

        self.df2 = pd.DataFrame({'bs_line': [vbr2_uniform]})
    def process_item(self, mat):
        """
        Process the tasks and materials into just a list of materials

        Builds an electronic-structure document for one material. It contains
        band-gap data (and optionally a small plot) from
        ``mat["bandstructure"]["bs"]`` and, when ``self.static_images`` is
        set, rendered band-structure / DOS images. A Boltztrap-interpolated
        DOS is added to the DOS image when available.

        Args:
            mat (dict): material document. If present,
                ``mat["bandstructure"]["bs"]`` is updated in place with the
                material structure and a high-symmetry ``labels_dict`` when
                those are missing.

        Returns:
            (dict): electronic_structure document, or None if static image
            generation raised an exception.
        """
        self.logger.info("Processing: %s", mat[self.materials.key])

        d = {self.electronic_structure.key: mat[self.materials.key],
             "bandstructure": {}}
        bs = None
        dos = None
        interpolated_dos = None

        # Process the bandstructure for information
        if "bs" in mat["bandstructure"]:
            bs_dict = mat["bandstructure"]["bs"]
            if "structure" not in bs_dict:
                bs_dict["structure"] = mat["structure"]
            if not bs_dict.get("labels_dict"):
                struc = Structure.from_dict(mat["structure"])
                bs_dict["labels_dict"] = HighSymmKpath(struc).kpath["kpoints"]
            # Something is wrong with the as_dict / from_dict encoding in the
            # two band structure objects, so round-trip through BandStructure.
            # TODO: Fix bandstructure objects in pymatgen
            bs = BandStructureSymmLine.from_dict(
                BandStructure.from_dict(bs_dict).as_dict())
            # Compute the band gap once and reuse it for every field.
            band_gap = bs.get_band_gap()
            d["bandstructure"]["band_gap"] = {
                "band_gap": band_gap["energy"],
                "direct_gap": bs.get_direct_band_gap(),
                "is_direct": band_gap["direct"],
                "transition": band_gap["transition"],
            }

            if self.small_plot:
                d["bandstructure"]["plot_small"] = get_small_plot(bs)

        if "dos" in mat["bandstructure"]:
            dos = CompleteDos.from_dict(mat["bandstructure"]["dos"])

        if self.interpolate_dos and "uniform_bs" in mat["bandstructure"]:
            try:
                interpolated_dos = self.get_interpolated_dos(mat)
            except Exception:
                # Best effort: keep going with the regular DOS.
                self.logger.warning(
                    "Boltztrap interpolation failed for %s. Continuing with regular DOS",
                    mat[self.materials.key])

        # Generate static images
        if self.static_images:
            try:
                ylim = None
                if bs:
                    plotter = WebBSPlotter(bs)
                    fig = plotter.get_plot()
                    ylim = fig.ylim()  # Used by DOS plot
                    fig.close()

                    d["bandstructure"]["bs_plot"] = image_from_plotter(plotter)

                if dos:
                    plotter = WebDosVertPlotter()
                    plotter.add_dos_dict(dos.get_element_dos())
                    if interpolated_dos:
                        plotter.add_dos("Total DOS", interpolated_dos)
                    # Render once, after every DOS curve has been added
                    # (previously rendered twice, the first result discarded).
                    d["bandstructure"]["dos_plot"] = image_from_plotter(plotter, ylim=ylim)

            except Exception:
                self.logger.warning(
                    "Caught error in electronic structure plotting for %s: %s",
                    mat[self.materials.key], traceback.format_exc())
                return None

        return d
示例#26
0
        def bs_dos_traces(bandStructureSymmLine, densityOfStates):
            """Build plotly traces for a band structure and its density of states.

            Args:
                bandStructureSymmLine: serialized BandStructureSymmLine (dict),
                    the string "error", or None.
                densityOfStates: serialized CompleteDos (dict), "error", or None.

            Returns:
                "error" if either input is "error"; otherwise
                ``[bstraces, dostraces, bs_data]``. ``bs_data`` is the
                BSPlotter ``bs_plot_data()`` dict with tick labels converted
                from LaTeX to plain Unicode.

            Raises:
                PreventUpdate: if either input is None.
            """
            if bandStructureSymmLine == "error" or densityOfStates == "error":
                return "error"

            if bandStructureSymmLine is None or densityOfStates is None:
                raise PreventUpdate

            # - BS Data
            bstraces = []

            bs_reg_plot = BSPlotter(BSML.from_dict(bandStructureSymmLine))

            bs_data = bs_reg_plot.bs_plot_data()

            # -- Strip latex math wrapping
            str_replace = {
                "$": "",
                "\\mid": "|",
                "\\Gamma": "Γ",
                "\\Sigma": "Σ",
                "_1": "₁",
                "_2": "₂",
                "_3": "₃",
                "_4": "₄",
            }

            tick_labels = bs_data["ticks"]["label"]
            for idx, label in enumerate(tick_labels):
                for latex, plain in str_replace.items():
                    label = label.replace(latex, plain)
                tick_labels[idx] = label

            spin_polarized_bs = bs_reg_plot._bs.is_spin_polarized
            for d, distances in enumerate(bs_data["distances"]):
                branch_energies = bs_data["energy"][d]
                n_points = len(distances)
                for i in range(bs_reg_plot._nb_bands):
                    bstraces.append(
                        go.Scatter(
                            x=distances,
                            y=[branch_energies[str(Spin.up)][i][j]
                               for j in range(n_points)],
                            mode="lines",
                            line=dict(color=("#666666"), width=2),
                            hoverinfo="skip",
                            showlegend=False,
                        ))

                    if spin_polarized_bs:
                        bstraces.append(
                            go.Scatter(
                                x=distances,
                                y=[branch_energies[str(Spin.down)][i][j]
                                   for j in range(n_points)],
                                mode="lines",
                                line=dict(color=("#666666"),
                                          width=2,
                                          dash="dash"),
                                hoverinfo="skip",
                                showlegend=False,
                            ))

            # -- DOS Data
            dostraces = []

            dos = CompleteDos.from_dict(densityOfStates)
            shifted_energies = dos.energies - dos.efermi

            if Spin.down in dos.densities:
                # Add second spin data if available
                dostraces.append(go.Scatter(
                    x=dos.densities[Spin.down],
                    y=shifted_energies,
                    mode="lines",
                    name="Total DOS (spin ↓)",
                    line=go.scatter.Line(color="#444444", dash="dash"),
                    fill="tozeroy",
                ))
                tdos_label = "Total DOS (spin ↑)"
            else:
                tdos_label = "Total DOS"

            # Total DOS
            dostraces.append(go.Scatter(
                x=dos.densities[Spin.up],
                y=shifted_energies,
                mode="lines",
                name=tdos_label,
                line=go.scatter.Line(color="#444444"),
                fill="tozeroy",
                legendgroup="spinup",
            ))

            p_ele_dos = dos.get_element_dos()

            # Projected DOS
            colors = [
                "#1f77b4",  # muted blue
                "#ff7f0e",  # safety orange
                "#2ca02c",  # cooked asparagus green
                "#d62728",  # brick red
                "#9467bd",  # muted purple
                "#8c564b",  # chestnut brown
                "#e377c2",  # raspberry yogurt pink
                "#bcbd22",  # curry yellow-green
                "#17becf",  # blue-teal
            ]

            for count, (ele, ele_dos) in enumerate(p_ele_dos.items()):
                # Cycle through the palette so materials with more elements
                # than colors do not raise IndexError.
                color = colors[count % len(colors)]

                # Decide from the DOS data itself (as the total DOS does
                # above) rather than from the band structure, so a spin
                # mismatch between BS and DOS cannot raise KeyError.
                if Spin.down in ele_dos.densities:
                    dostraces.append(go.Scatter(
                        x=ele_dos.densities[Spin.down],
                        y=shifted_energies,
                        mode="lines",
                        name=ele.symbol + " (spin ↓)",
                        line=dict(width=3, color=color, dash="dash"),
                    ))
                    spin_up_label = ele.symbol + " (spin ↑)"
                else:
                    spin_up_label = ele.symbol

                dostraces.append(go.Scatter(
                    x=ele_dos.densities[Spin.up],
                    y=shifted_energies,
                    mode="lines",
                    name=spin_up_label,
                    line=dict(width=3, color=color),
                ))

            traces = [bstraces, dostraces, bs_data]

            return traces
示例#27
0
        def bs_dos_data(
            mpid,
            path_convention,
            dos_select,
            label_select,
            bandstructure_symm_line,
            density_of_states,
        ):
            if not mpid and (bandstructure_symm_line is None
                             or density_of_states is None):
                raise PreventUpdate

            elif bandstructure_symm_line is None or density_of_states is None:
                if label_select == "":
                    raise PreventUpdate

                # --
                # -- BS and DOS from API or DB
                # --

                bs_data = {"ticks": {}}

                bs_store = GridFSStore(
                    database="fw_bs_prod",
                    collection_name="bandstructure_fs",
                    host="mongodb03.nersc.gov",
                    port=27017,
                    username="******",
                    password="",
                )

                dos_store = GridFSStore(
                    database="fw_bs_prod",
                    collection_name="dos_fs",
                    host="mongodb03.nersc.gov",
                    port=27017,
                    username="******",
                    password="",
                )

                es_store = MongoStore(
                    database="fw_bs_prod",
                    collection_name="electronic_structure",
                    host="mongodb03.nersc.gov",
                    port=27017,
                    username="******",
                    password="",
                    key="task_id",
                )

                # - BS traces from DB using task_id
                es_store.connect()
                bs_query = es_store.query_one(
                    criteria={"task_id": int(mpid)},
                    properties=[
                        "bandstructure.{}.task_id".format(path_convention),
                        "bandstructure.{}.total.equiv_labels".format(
                            path_convention),
                    ],
                )

                es_store.close()

                bs_store.connect()
                bandstructure_symm_line = bs_store.query_one(criteria={
                    "metadata.task_id":
                    int(bs_query["bandstructure"][path_convention]["task_id"])
                }, )

                # If LM convention, get equivalent labels
                if path_convention != label_select:
                    bs_equiv_labels = bs_query["bandstructure"][
                        path_convention]["total"]["equiv_labels"]

                    new_labels_dict = {}
                    for label in bandstructure_symm_line["labels_dict"].keys():

                        label_formatted = label.replace("$", "")

                        if "|" in label_formatted:
                            f_label = label_formatted.split("|")
                            new_labels.append(
                                "$" +
                                bs_equiv_labels[label_select][f_label[0]] +
                                "|" +
                                bs_equiv_labels[label_select][f_label[1]] +
                                "$")
                        else:
                            new_labels_dict["$" + bs_equiv_labels[label_select]
                                            [label_formatted] +
                                            "$"] = bandstructure_symm_line[
                                                "labels_dict"][label]

                    bandstructure_symm_line["labels_dict"] = new_labels_dict

                # - DOS traces from DB using task_id
                es_store.connect()
                dos_query = es_store.query_one(
                    criteria={"task_id": int(mpid)},
                    properties=["dos.task_id"],
                )
                es_store.close()

                dos_store.connect()
                density_of_states = dos_store.query_one(
                    criteria={"task_id": int(dos_query["dos"]["task_id"])}, )

            # - BS Data
            if (type(bandstructure_symm_line) != dict
                    and bandstructure_symm_line is not None):
                bandstructure_symm_line = bandstructure_symm_line.to_dict()

            if type(density_of_states
                    ) != dict and density_of_states is not None:
                density_of_states = density_of_states.to_dict()

            bsml = BSML.from_dict(bandstructure_symm_line)

            bs_reg_plot = BSPlotter(bsml)

            bs_data = bs_reg_plot.bs_plot_data()

            # Make plot continous for lm
            if path_convention == "lm":
                distance_map, kpath_euler = HSKP(
                    bsml.structure).get_continuous_path(bsml)

                kpath_labels = [pair[0] for pair in kpath_euler]
                kpath_labels.append(kpath_euler[-1][1])

            else:
                distance_map = [(i, False)
                                for i in range(len(bs_data["distances"]))]
                kpath_labels = []
                for label_ind in range(len(bs_data["ticks"]["label"]) - 1):
                    if (bs_data["ticks"]["label"][label_ind] !=
                            bs_data["ticks"]["label"][label_ind + 1]):
                        kpath_labels.append(
                            bs_data["ticks"]["label"][label_ind])
                kpath_labels.append(bs_data["ticks"]["label"][-1])

            bs_data["ticks"]["label"] = kpath_labels

            # Obtain bands to plot over and generate traces for bs data:
            energy_window = (-6.0, 10.0)
            bands = []
            for band_num in range(bs_reg_plot._nb_bands):
                if (bs_data["energy"][0][str(Spin.up)][band_num][0] <=
                        energy_window[1]) and (bs_data["energy"][0][str(
                            Spin.up)][band_num][0] >= energy_window[0]):
                    bands.append(band_num)

            bstraces = []

            pmin = 0.0
            tick_vals = [0.0]

            cbm = bsml.get_cbm()
            vbm = bsml.get_vbm()

            cbm_new = bs_data["cbm"]
            vbm_new = bs_data["vbm"]

            for dnum, (d, rev) in enumerate(distance_map):

                x_dat = [
                    dval - bs_data["distances"][d][0] + pmin
                    for dval in bs_data["distances"][d]
                ]

                pmin = x_dat[-1]

                tick_vals.append(pmin)

                if not rev:
                    traces_for_segment = [{
                        "x":
                        x_dat,
                        "y": [
                            bs_data["energy"][d][str(Spin.up)][i][j]
                            for j in range(len(bs_data["distances"][d]))
                        ],
                        "mode":
                        "lines",
                        "line": {
                            "color": "#1f77b4"
                        },
                        "hoverinfo":
                        "skip",
                        "name":
                        "spin ↑"
                        if bs_reg_plot._bs.is_spin_polarized else "Total",
                        "hovertemplate":
                        "%{y:.2f} eV",
                        "showlegend":
                        False,
                        "xaxis":
                        "x",
                        "yaxis":
                        "y",
                    } for i in bands]
                elif rev:
                    traces_for_segment = [{
                        "x":
                        x_dat,
                        "y": [
                            bs_data["energy"][d][str(Spin.up)][i][j]
                            for j in reversed(
                                range(len(bs_data["distances"][d])))
                        ],
                        "mode":
                        "lines",
                        "line": {
                            "color": "#1f77b4"
                        },
                        "hoverinfo":
                        "skip",
                        "name":
                        "spin ↑"
                        if bs_reg_plot._bs.is_spin_polarized else "Total",
                        "hovertemplate":
                        "%{y:.2f} eV",
                        "showlegend":
                        False,
                        "xaxis":
                        "x",
                        "yaxis":
                        "y",
                    } for i in bands]

                if bs_reg_plot._bs.is_spin_polarized:

                    if not rev:
                        traces_for_segment += [{
                            "x":
                            x_dat,
                            "y": [
                                bs_data["energy"][d][str(Spin.down)][i][j]
                                for j in range(len(bs_data["distances"][d]))
                            ],
                            "mode":
                            "lines",
                            "line": {
                                "color": "#ff7f0e",
                                "dash": "dot"
                            },
                            "hoverinfo":
                            "skip",
                            "showlegend":
                            False,
                            "name":
                            "spin ↓",
                            "hovertemplate":
                            "%{y:.2f} eV",
                            "xaxis":
                            "x",
                            "yaxis":
                            "y",
                        } for i in bands]
                    elif rev:
                        traces_for_segment += [{
                            "x":
                            x_dat,
                            "y": [
                                bs_data["energy"][d][str(Spin.down)][i][j]
                                for j in reversed(
                                    range(len(bs_data["distances"][d])))
                            ],
                            "mode":
                            "lines",
                            "line": {
                                "color": "#ff7f0e",
                                "dash": "dot"
                            },
                            "hoverinfo":
                            "skip",
                            "showlegend":
                            False,
                            "name":
                            "spin ↓",
                            "hovertemplate":
                            "%{y:.2f} eV",
                            "xaxis":
                            "x",
                            "yaxis":
                            "y",
                        } for i in bands]

                bstraces += traces_for_segment

                # - Get proper cbm and vbm coords for lm
                if path_convention == "lm":
                    for (x_point, y_point) in bs_data["cbm"]:
                        if x_point in bs_data["distances"][d]:
                            xind = bs_data["distances"][d].index(x_point)
                            if not rev:
                                x_point_new = x_dat[xind]
                            else:
                                x_point_new = x_dat[len(x_dat) - xind - 1]

                            new_label = bs_data["ticks"]["label"][
                                tick_vals.index(x_point_new)]

                            if (cbm["kpoint"].label is None
                                    or cbm["kpoint"].label in new_label):
                                cbm_new.append((x_point_new, y_point))

                    for (x_point, y_point) in bs_data["vbm"]:
                        if x_point in bs_data["distances"][d]:
                            xind = bs_data["distances"][d].index(x_point)
                            if not rev:
                                x_point_new = x_dat[xind]
                            else:
                                x_point_new = x_dat[len(x_dat) - xind - 1]

                            new_label = bs_data["ticks"]["label"][
                                tick_vals.index(x_point_new)]

                            if (vbm["kpoint"].label is None
                                    or vbm["kpoint"].label in new_label):
                                vbm_new.append((x_point_new, y_point))

            bs_data["ticks"]["distance"] = tick_vals

            # - Strip latex math wrapping for labels
            str_replace = {
                "$": "",
                "\\mid": "|",
                "\\Gamma": "Γ",
                "\\Sigma": "Σ",
                "GAMMA": "Γ",
                "_1": "₁",
                "_2": "₂",
                "_3": "₃",
                "_4": "₄",
                "_{1}": "₁",
                "_{2}": "₂",
                "_{3}": "₃",
                "_{4}": "₄",
                "^{*}": "*",
            }

            bar_loc = []
            for entry_num in range(len(bs_data["ticks"]["label"])):
                for key in str_replace.keys():
                    if key in bs_data["ticks"]["label"][entry_num]:
                        bs_data["ticks"]["label"][entry_num] = bs_data[
                            "ticks"]["label"][entry_num].replace(
                                key, str_replace[key])
                        if key == "\\mid":
                            bar_loc.append(
                                bs_data["ticks"]["distance"][entry_num])

            # Vertical lines for disjointed segments
            vert_traces = [{
                "x": [x_point, x_point],
                "y": energy_window,
                "mode": "lines",
                "marker": {
                    "color": "white"
                },
                "hoverinfo": "skip",
                "showlegend": False,
                "xaxis": "x",
                "yaxis": "y",
            } for x_point in bar_loc]

            bstraces += vert_traces

            # Dots for cbm and vbm

            dot_traces = [{
                "x": [x_point],
                "y": [y_point],
                "mode":
                "markers",
                "marker": {
                    "color": "#7E259B",
                    "size": 16,
                    "line": {
                        "color": "white",
                        "width": 2
                    },
                },
                "showlegend":
                False,
                "hoverinfo":
                "text",
                "name":
                "",
                "hovertemplate":
                "CBM: k = {}, {} eV".format(list(cbm["kpoint"].frac_coords),
                                            cbm["energy"]),
                "xaxis":
                "x",
                "yaxis":
                "y",
            } for (x_point, y_point) in set(cbm_new)] + [{
                "x": [x_point],
                "y": [y_point],
                "mode":
                "marker",
                "marker": {
                    "color": "#7E259B",
                    "size": 16,
                    "line": {
                        "color": "white",
                        "width": 2
                    },
                },
                "showlegend":
                False,
                "hoverinfo":
                "text",
                "name":
                "",
                "hovertemplate":
                "VBM: k = {}, {} eV".format(list(vbm["kpoint"].frac_coords),
                                            vbm["energy"]),
                "xaxis":
                "x",
                "yaxis":
                "y",
            } for (x_point, y_point) in set(vbm_new)]

            bstraces += dot_traces

            # - DOS Data
            dostraces = []

            dos = CompleteDos.from_dict(density_of_states)

            dos_max = np.abs(
                (dos.energies - dos.efermi - energy_window[1])).argmin()
            dos_min = np.abs(
                (dos.energies - dos.efermi - energy_window[0])).argmin()

            if bs_reg_plot._bs.is_spin_polarized:
                # Add second spin data if available
                trace_tdos = {
                    "x": -1.0 * dos.densities[Spin.down][dos_min:dos_max],
                    "y": dos.energies[dos_min:dos_max] - dos.efermi,
                    "mode": "lines",
                    "name": "Total DOS (spin ↓)",
                    "line": go.scatter.Line(color="#444444", dash="dot"),
                    "fill": "tozerox",
                    "fillcolor": "#C4C4C4",
                    "xaxis": "x2",
                    "yaxis": "y2",
                }

                dostraces.append(trace_tdos)

                tdos_label = "Total DOS (spin ↑)"
            else:
                tdos_label = "Total DOS"

            # Total DOS
            trace_tdos = {
                "x": dos.densities[Spin.up][dos_min:dos_max],
                "y": dos.energies[dos_min:dos_max] - dos.efermi,
                "mode": "lines",
                "name": tdos_label,
                "line": go.scatter.Line(color="#444444"),
                "fill": "tozerox",
                "fillcolor": "#C4C4C4",
                "legendgroup": "spinup",
                "xaxis": "x2",
                "yaxis": "y2",
            }

            dostraces.append(trace_tdos)

            ele_dos = dos.get_element_dos()
            elements = [str(entry) for entry in ele_dos.keys()]

            if dos_select == "ap":
                proj_data = ele_dos
            elif dos_select == "op":
                proj_data = dos.get_spd_dos()
            elif "orb" in dos_select:
                proj_data = dos.get_element_spd_dos(
                    Element(dos_select.replace("orb", "")))
            else:
                raise PreventUpdate

            # Projected DOS
            count = 0
            colors = [
                "#d62728",  # brick red
                "#2ca02c",  # cooked asparagus green
                "#17becf",  # blue-teal
                "#bcbd22",  # curry yellow-green
                "#9467bd",  # muted purple
                "#8c564b",  # chestnut brown
                "#e377c2",  # raspberry yogurt pink
            ]

            for label in proj_data.keys():

                if bs_reg_plot._bs.is_spin_polarized:
                    trace = {
                        "x":
                        -1.0 *
                        proj_data[label].densities[Spin.down][dos_min:dos_max],
                        "y":
                        dos.energies[dos_min:dos_max] - dos.efermi,
                        "mode":
                        "lines",
                        "name":
                        str(label) + " (spin ↓)",
                        "line":
                        dict(width=3, color=colors[count], dash="dot"),
                        "xaxis":
                        "x2",
                        "yaxis":
                        "y2",
                    }

                    dostraces.append(trace)
                    spin_up_label = str(label) + " (spin ↑)"

                else:
                    spin_up_label = str(label)

                trace = {
                    "x": proj_data[label].densities[Spin.up][dos_min:dos_max],
                    "y": dos.energies[dos_min:dos_max] - dos.efermi,
                    "mode": "lines",
                    "name": spin_up_label,
                    "line": dict(width=2, color=colors[count]),
                    "xaxis": "x2",
                    "yaxis": "y2",
                }

                dostraces.append(trace)

                count += 1
            traces = [bstraces, dostraces, bs_data]

            return (traces, elements)
示例#28
0
    def process_item(self, mat):
        """
        Process the tasks and materials into just a list of materials

        Builds one electronic_structure document for ``mat``: a band
        structure plot, a reduced band structure plot, a vertical DOS plot
        and basic band-gap properties. Every step is best-effort: a failure
        is logged as a warning and the corresponding key is left out of the
        returned document.

        Args:
            mat (dict): material document; the keys read are
                ``mat[self.materials.key]``, ``mat["bandstructure"]["bs"]``
                and ``mat["bandstructure"]["dos"]``.

        Returns:
            (dict): electronic_structure document
        """
        mat_id = mat[self.materials.key]
        d = {self.electronic_structure.key: mat_id}
        self.logger.info("Processing: {}".format(mat_id))

        bs = BandStructureSymmLine.from_dict(mat["bandstructure"]["bs"])
        dos = CompleteDos.from_dict(mat["bandstructure"]["dos"])

        # State shared between steps. Both start as None so a later step can
        # detect that an earlier one failed instead of hitting a NameError.
        bs_plotter = None  # reused by the reduced band structure plot
        ylim = None  # y-limits of the BS plot, reused for the DOS plot

        # Plot Band structure
        if bs:
            try:
                bs_plotter = WebBSPlotter(bs)
                plot = bs_plotter.get_plot()
                try:
                    ylim = plot.ylim()
                    d["bs_plot"] = image_from_plot(plot)
                finally:
                    # Release the figure even if rendering it failed.
                    plot.close()
            except Exception:
                self.logger.warning(
                    "Caught error in bandstructure plotting for {}: {}".format(
                        mat_id, traceback.format_exc()))

        # Reduced Band structure plot
        if bs_plotter is None:
            self.logger.warning(
                "Skipping reduced bandstructure plot for {}: no bandstructure "
                "plotter available".format(mat_id))
        else:
            try:
                gap = bs.get_band_gap()["energy"]
                plot_data = bs_plotter.bs_plot_data()
                d["bs_plot_small"] = get_small_plot(plot_data, gap)
            except Exception:
                self.logger.warning(
                    "Caught error in generating reduced bandstructure plot for {}: {}"
                    .format(mat_id, traceback.format_exc()))

        # Plot DOS on the same y-limits as the band structure plot
        if dos:
            if ylim is None:
                self.logger.warning(
                    "Skipping dos plot for {}: no bandstructure plot limits "
                    "available".format(mat_id))
            else:
                try:
                    dos_plotter = WebDosVertPlotter()
                    dos_plotter.add_dos_dict(dos.get_element_dos())
                    plot = dos_plotter.get_plot(ylim=ylim)
                    try:
                        d["dos_plot"] = image_from_plot(plot)
                    finally:
                        # Release the figure even if rendering it failed.
                        plot.close()
                except Exception:
                    self.logger.warning(
                        "Caught error in dos plotting for {}: {}".format(
                            mat_id, traceback.format_exc()))

        # Get basic bandgap properties
        if bs:
            try:
                # Compute the band gap once and reuse it for every field.
                gap_info = bs.get_band_gap()
                d["band_gap"] = {
                    "band_gap": gap_info["energy"],
                    "direct_gap": bs.get_direct_band_gap(),
                    "is_direct": gap_info["direct"],
                    "transition": gap_info["transition"]
                }
            except Exception:
                self.logger.warning(
                    "Caught error in calculating bandgap {}: {}".format(
                        mat_id, traceback.format_exc()))

        return d
示例#29
0
 def setUp(self):
     """Load the Cu2O band structure fixture and wrap it in a projected plotter."""
     fixture_path = os.path.join(test_dir, "Cu2O_361_bandstructure.json")
     with open(fixture_path, "r", encoding="utf-8") as handle:
         bs_dict = json.load(handle)
     self.bs = BandStructureSymmLine.from_dict(bs_dict)
     self.plotter = BSPlotterProjected(self.bs)
示例#30
0
        def bs_dos_data(
            mpid,
            path_convention,
            dos_select,
            label_select,
            bandstructure_symm_line,
            density_of_states,
        ):
            """Build band structure and DOS plot traces.

            Args:
                mpid: checked with ``"mpid" in mpid`` and indexed as
                    ``mpid["mpid"]``, so presumably a dict -- TODO confirm
                    with the caller.
                path_convention: key of the k-path convention inside the
                    stored band structure document (database branch only).
                dos_select: ``"ap"`` for per-element projections, ``"op"``
                    for s/p/d projections, or ``"<Element>orb"`` for one
                    element's s/p/d projections.
                label_select: tick-label convention requested (database
                    branch only).
                bandstructure_symm_line: band structure object or its dict.
                density_of_states: DOS object or its dict.

            Returns:
                tuple: ``([bstraces, dostraces, bs_data], elements)``.

            Raises:
                PreventUpdate: when ``mpid`` is truthy, when either input
                    object is missing, or (manual branch) when
                    ``dos_select`` is not recognised.
            """
            # Replacements applied, in order, to every k-point tick label to
            # turn LaTeX math markup into plain Unicode text.
            str_replace = {
                "$": "",
                "\\mid": "|",
                "\\Gamma": "Γ",
                "\\Sigma": "Σ",
                "GAMMA": "Γ",
                "_1": "₁",
                "_2": "₂",
                "_3": "₃",
                "_4": "₄",
            }

            def _strip_latex(labels):
                """Apply ``str_replace`` to every entry of ``labels`` in place."""
                for entry_num, entry in enumerate(labels):
                    for key, replacement in str_replace.items():
                        entry = entry.replace(key, replacement)
                    labels[entry_num] = entry

            # NOTE(review): with these guards the database branch below is
            # unreachable: a truthy ``mpid`` always raises PreventUpdate, and
            # with a falsy ``mpid`` a missing input already raises in the
            # first check. That branch also uses ``client``, which is not
            # defined in this function. Only the manual branch currently runs.
            if (not mpid or "mpid" not in mpid) and (
                    bandstructure_symm_line is None
                    or density_of_states is None):
                raise PreventUpdate
            elif mpid:
                raise PreventUpdate
            elif bandstructure_symm_line is None or density_of_states is None:

                # --
                # -- BS and DOS from API
                # --

                mpid = mpid["mpid"]
                bs_data = {"ticks": {}}

                db = client.fw_bs_prod

                # - BS traces from DB using task_id
                bs_query = list(
                    db.electronic_structure.find(
                        {"task_id": int(mpid)},
                        [
                            "bandstructure.{}.total.traces".format(
                                path_convention)
                        ],
                    ))[0]

                bs_traces = bs_query["bandstructure"][path_convention][
                    "total"]["traces"]

                # The traces document also holds "ticks" and "labels"
                # entries, so spin polarization is detected by the presence
                # of the spin-down key rather than by the number of keys.
                is_sp = "-1" in bs_traces

                if is_sp:
                    bstraces = bs_traces["1"] + bs_traces["-1"]
                else:
                    bstraces = bs_traces["1"]

                bs_data["ticks"]["distance"] = bs_traces["ticks"]
                bs_data["ticks"]["label"] = bs_traces["labels"]

                # If LM convention, get equivalent labels
                if path_convention == "lm" and label_select != "lm":
                    bs_equiv_labels = bs_traces["equiv_labels"]

                    # The "hin" selection is looked up under the "h" key.
                    alt_choice = "h" if label_select == "hin" else label_select

                    new_labels = []
                    for label in bs_data["ticks"]["label"]:
                        label_formatted = label.replace("$", "")

                        if "|" in label_formatted:
                            f_label = label_formatted.split("|")
                            new_labels.append(
                                "$" + bs_equiv_labels[alt_choice][f_label[0]] +
                                "|" + bs_equiv_labels[alt_choice][f_label[1]] +
                                "$")
                        else:
                            new_labels.append(
                                "$" +
                                bs_equiv_labels[alt_choice][label_formatted] +
                                "$")

                    bs_data["ticks"]["label"] = new_labels

                # Strip latex math wrapping
                _strip_latex(bs_data["ticks"]["label"])

                # - DOS traces from DB using task_id
                dos_tot_ele_traces = list(
                    db.electronic_structure.find(
                        {"task_id": int(mpid)},
                        ["dos.total.traces", "dos.elements"]))[0]
                dos_doc = dos_tot_ele_traces["dos"]

                # One trace per stored spin channel.
                dostraces = list(dos_doc["total"]["traces"].values())

                elements = list(dos_doc["elements"].keys())

                if dos_select == "ap":
                    for ele_label in elements:
                        dostraces += list(dos_doc["elements"][ele_label]
                                          ["total"]["traces"].values())

                elif dos_select == "op":
                    orb_tot_traces = list(
                        db.electronic_structure.find({"task_id": int(mpid)},
                                                     ["dos.orbitals"]))[0]
                    for orbital in ["s", "p", "d"]:
                        # Use each orbital's own spin channels.
                        dostraces += list(orb_tot_traces["dos"]["orbitals"]
                                          [orbital]["traces"].values())

                elif "orb" in dos_select:
                    ele_label = dos_select.replace("orb", "")

                    for orbital in ["s", "p", "d"]:
                        dostraces += list(dos_doc["elements"][ele_label]
                                          [orbital]["traces"].values())

                traces = [bstraces, dostraces, bs_data]

                return (traces, elements)

            else:

                # --
                # -- BS and DOS passed manually
                # --

                # - BS Data
                if not isinstance(bandstructure_symm_line, dict):
                    bandstructure_symm_line = bandstructure_symm_line.to_dict()

                if not isinstance(density_of_states, dict):
                    density_of_states = density_of_states.to_dict()

                bs_reg_plot = BSPlotter(
                    BSML.from_dict(bandstructure_symm_line))
                bs_data = bs_reg_plot.bs_plot_data()

                # - Strip latex math wrapping
                _strip_latex(bs_data["ticks"]["label"])

                # Obtain bands to plot over: keep a band when its spin-up
                # energy at the first point of the first segment lies inside
                # the window.
                energy_window = (-6.0, 10.0)
                first_segment_up = bs_data["energy"][0][str(Spin.up)]
                bands = [
                    band_num for band_num in range(bs_reg_plot._nb_bands)
                    if energy_window[0] <= first_segment_up[band_num][0] <=
                    energy_window[1]
                ]

                # Spin channels to draw: spin-up always, spin-down if present.
                spin_keys = ["1"]
                if bs_reg_plot._bs.is_spin_polarized:
                    spin_keys.append("-1")

                bstraces = []

                # Generate traces for total BS data: per segment, all
                # spin-up bands first, then all spin-down bands.
                for d, dist_dat in enumerate(bs_data["distances"]):
                    energy_ind = range(len(dist_dat))
                    for spin in spin_keys:
                        bstraces += [{
                            "x": dist_dat,
                            "y": [
                                bs_data["energy"][d][spin][i][j]
                                for j in energy_ind
                            ],
                            "mode": "lines",
                            "line": {
                                "color": "#666666"
                            },
                            "hoverinfo": "skip",
                            "showlegend": False,
                        } for i in bands]

                # - DOS Data
                dostraces = []

                dos = CompleteDos.from_dict(density_of_states)

                # Indices of the DOS energies (relative to the Fermi level)
                # closest to the window edges.
                dos_max = np.abs(
                    (dos.energies - dos.efermi - energy_window[1])).argmin()
                dos_min = np.abs(
                    (dos.energies - dos.efermi - energy_window[0])).argmin()

                if bs_reg_plot._bs.is_spin_polarized:
                    # Add second spin data if available
                    trace_tdos = go.Scatter(
                        x=dos.densities[Spin.down][dos_min:dos_max],
                        y=dos.energies[dos_min:dos_max] - dos.efermi,
                        mode="lines",
                        name="Total DOS (spin ↓)",
                        line=go.scatter.Line(color="#444444", dash="dash"),
                        fill="tozerox",
                    )

                    dostraces.append(trace_tdos)

                    tdos_label = "Total DOS (spin ↑)"
                else:
                    tdos_label = "Total DOS"

                # Total DOS
                trace_tdos = go.Scatter(
                    x=dos.densities[Spin.up][dos_min:dos_max],
                    y=dos.energies[dos_min:dos_max] - dos.efermi,
                    mode="lines",
                    name=tdos_label,
                    line=go.scatter.Line(color="#444444"),
                    fill="tozerox",
                    legendgroup="spinup",
                )

                dostraces.append(trace_tdos)

                ele_dos = dos.get_element_dos()
                elements = [str(entry) for entry in ele_dos.keys()]

                if dos_select == "ap":
                    proj_data = ele_dos
                elif dos_select == "op":
                    proj_data = dos.get_spd_dos()
                elif "orb" in dos_select:
                    proj_data = dos.get_element_spd_dos(
                        Element(dos_select.replace("orb", "")))
                else:
                    raise PreventUpdate

                # Projected DOS
                colors = [
                    "#1f77b4",  # muted blue
                    "#ff7f0e",  # safety orange
                    "#2ca02c",  # cooked asparagus green
                    "#9467bd",  # muted purple
                    "#e377c2",  # raspberry yogurt pink
                    "#d62728",  # brick red
                    "#8c564b",  # chestnut brown
                    "#bcbd22",  # curry yellow-green
                    "#17becf",  # blue-teal
                ]

                for count, label in enumerate(proj_data.keys()):
                    # Cycle the palette so more projections than colors do
                    # not raise IndexError.
                    color = colors[count % len(colors)]

                    if bs_reg_plot._bs.is_spin_polarized:
                        trace = go.Scatter(
                            x=proj_data[label].densities[Spin.down]
                            [dos_min:dos_max],
                            y=dos.energies[dos_min:dos_max] - dos.efermi,
                            mode="lines",
                            name=str(label) + " (spin ↓)",
                            line=dict(width=3, color=color, dash="dash"),
                        )

                        dostraces.append(trace)
                        spin_up_label = str(label) + " (spin ↑)"

                    else:
                        spin_up_label = str(label)

                    trace = go.Scatter(
                        x=proj_data[label].densities[Spin.up][dos_min:dos_max],
                        y=dos.energies[dos_min:dos_max] - dos.efermi,
                        mode="lines",
                        name=spin_up_label,
                        line=dict(width=3, color=color),
                    )

                    dostraces.append(trace)

                traces = [bstraces, dostraces, bs_data]

                return (traces, elements)
示例#31
0
def _queue_settings(nsites):
    """Return ``(pbesol_ntasks, pbesol_walltime, hse_ntasks, hse_walltime)``.

    Resources grow with the number of sites; each tier's first entry is an
    inclusive upper bound on *nsites*. Cells above the last tier get the
    fallback (largest) allocation.
    """
    # (max_sites, pbesol_ntasks, pbesol_walltime, hse_ntasks, hse_walltime)
    tiers = (
        (5, 6, '6:00:00', 12, '24:00:00'),
        (10, 12, '6:00:00', 24, '48:00:00'),
        (20, 12, '24:00:00', 48, '48:00:00'),
        (50, 24, '48:00:00', 60, '72:00:00'),
    )
    fallback = (36, '48:00:00', 96, '72:00:00')
    return next((row[1:] for row in tiers if nsites <= row[0]), fallback)


def HSE_Gap_with_Vasp_from_icsd_data(
        file, vasp_cmd,
        ids_file='/SCRATCH/acad/htbase/yugd/fw_workplace/hse_gaps_icsd/ids',
        wrong_list='wrong_list.dat'):
    """Build and submit a FireWorks workflow that computes HSE gaps for ICSD tasks.

    The ICSD id of a task is ``'icsd-'`` plus the first entry of
    ``task['icsd_ids']``. Only tasks whose id is listed in *ids_file* are
    used. Each of those tasks gets three chained Fireworks:

    1. a PBEsol relaxation;
    2. an HSE SCF run. It adds the PBE VBM k-point, plus the CBM k-point
       when the PBE gap is indirect;
    3. a gap-insertion step.

    All Fireworks are added to ``launchpad`` as a single Workflow.

    Args:
        file: Path to a JSON file containing a list of task dicts. Each dict
            has ``'icsd_ids'``, ``'structure'`` and
            ``['band_structures']['line']`` entries.
        vasp_cmd: VASP command stored in the Firework specs.
        ids_file: Path to a whitespace-separated list of ICSD ids
            (``'icsd-<n>'``) to process.
        wrong_list: Log file. Tasks whose structure or band structure could
            not be processed are appended to it as ``all[<index>]``, one per
            line.

    Returns:
        Number of tasks for which a job chain was created.

    NOTE(review): relies on module-level ``fworker`` and ``launchpad``
    objects that are not defined in this block.
    """
    import json

    with open(file) as fin:
        tasks = json.load(fin)
    with open(ids_file) as fids:
        # A set gives O(1) membership tests inside the loop below.
        ids = set(fids.read().split())

    fws = []
    njobs = 0
    with open(wrong_list, 'a+') as flog:
        for index, task in enumerate(tasks):
            icsd_id = 'icsd-' + str(task['icsd_ids'][0])
            if icsd_id not in ids:
                continue
            try:
                struct = Structure.from_dict(task['structure'])
                pbe_bs = BandStructureSymmLine.from_dict(
                    task['band_structures']['line'])
                is_spin_polarized = pbe_bs.is_spin_polarized
                k_vbm = pbe_bs.get_vbm()['kpoint']._fcoords
                k_cbm = pbe_bs.get_cbm()['kpoint']._fcoords
            except Exception:
                # Best-effort: log the offending task index and move on.
                flog.write('all[' + str(index) + ']\n')
                continue

            njobs += 1

            # ---- queue resources -------------------------------------
            (pbesol_ntasks, pbesol_walltime,
             hse_ntasks, hse_walltime) = _queue_settings(struct.num_sites)
            queue_pbesol = {
                'job_name': icsd_id + '_PBEsol',
                'ntasks': pbesol_ntasks,
                'walltime': pbesol_walltime,
            }
            queue_hse = {
                'job_name': icsd_id + '_HSE',
                'ntasks': hse_ntasks,
                'walltime': hse_walltime,
            }
            queue_insert = {
                'ntasks': 1,
                'walltime': '01:00:00',
                'vmem': '1024mb',
                'job_name': icsd_id + '_Insert',
            }

            # For a direct gap the VBM and CBM share a k-point, so a single
            # extra k-point is enough for the HSE run.
            if pbe_bs.get_band_gap()['direct']:
                ks_add = [k_vbm]
            else:
                ks_add = [k_vbm, k_cbm]

            fw0 = Firework(
                [Generate_VaspInputFiles_for_Relax_with_PBEsol_icsd(),
                 VaspRun()],
                spec={
                    '_pass_job_info': True,
                    '_preserve_fworker': True,
                    '_fworker': fworker,
                    'vasp_cmd': vasp_cmd,
                    'workdir': './',
                    'structure': struct,
                    'icsd_id': icsd_id,
                    'is_spin_polarized': is_spin_polarized,
                    '_queueadapter': queue_pbesol
                },
                name=icsd_id + '_pbesol')

            fw1 = Firework(
                [Generate_VaspInputFiles_for_HSE_scf_icsd(), VaspRun()],
                parents=[fw0],
                spec={
                    '_pass_job_info': True,
                    '_preserve_fworker': True,
                    '_fworker': fworker,
                    'vasp_cmd': vasp_cmd,
                    'workdir': './',
                    'icsd_id': icsd_id,
                    'ks_add': ks_add,
                    'is_spin_polarized': is_spin_polarized,
                    '_queueadapter': queue_hse
                },
                name=icsd_id + '_HSE')

            fw2 = Firework([Insert_Gap_into_monogdb_icsd()],
                           parents=[fw1],
                           spec={
                               'icsd_id': icsd_id,
                               '_queueadapter': queue_insert,
                               'pbe_bs': pbe_bs
                           },
                           name=icsd_id + '_Insert')
            fws.extend([fw0, fw1, fw2])

        wf = Workflow(fws, name='HSE_Gap_for_icsds')
        launchpad.add_wf(wf)
    return njobs
示例#32
0
 def setUp(self):
     with open(os.path.join(test_dir, "CaO_2605_bandstructure.json"),
               "rb") as f:
         d = json.loads(f.read())
         self.bs = BandStructureSymmLine.from_dict(d)
         self.plotter = BSPlotter(self.bs)