Example #1
def get_solute_def_profile(mpid, solute, solute_conc, T, def_file, sol_file, 
        trial_chem_pot):

    raw_energy_dict = loadfn(def_file,cls=MontyDecoder)
    sol_raw_energy_dict = loadfn(sol_file,cls=MontyDecoder)

    #try:
    e0 = raw_energy_dict[mpid]['e0']
    struct = raw_energy_dict[mpid]['structure']
    vacs = raw_energy_dict[mpid]['vacancies']
    antisites = raw_energy_dict[mpid]['antisites']
    solutes = sol_raw_energy_dict[mpid]['solutes']
    for vac_def in vacs:
        if not vac_def:
            print('All vacancy defect energies not present')
            continue
    for antisite_def in antisites:
        if not antisite_def:
            print('All antisite defect energies not present')
            continue
    for solute_def in solutes:
        if not solute_def:
            print('All solute defect energies not present')
            continue

    try:
        def_conc = solute_defect_density(struct, e0, vacs, 
                antisites, solutes, solute_concen=solute_conc, T=T, 
                trial_chem_pot=trial_chem_pot, plot_style="gnuplot")
        return def_conc
    except:
        raise
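For reference, loadfn/dumpfn from monty.serialization are the round-trip pair used throughout these examples; a minimal sketch (the file name and data are hypothetical):

# dumpfn with MontyEncoder serializes MSON-able objects (e.g. pymatgen
# structures) to JSON; loadfn with MontyDecoder restores them.
from monty.json import MontyDecoder, MontyEncoder
from monty.serialization import dumpfn, loadfn

dumpfn({"e0": -3.74}, "raw_defect_energy.json", cls=MontyEncoder)
data = loadfn("raw_defect_energy.json", cls=MontyDecoder)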
Example #2
    def __init__(self,vasp_task=None,name='vaspfw',handlers=None, handler_params=None, config_file=None):
        self.name = name
        self.handlers=handlers if handlers else []
        self.handler_params=handler_params if handler_params else {}

        if config_file:
            config_dict = loadfn(config_file)
        elif os.path.exists(os.path.join(os.environ['HOME'], 'vasp_interface_defaults.yaml')):
            config_dict = loadfn(os.path.join(os.environ['HOME'], 'vasp_interface_defaults.yaml'))
        else:
            config_dict = {}

        if config_dict:
            self.custodian_opts = config_dict.get('CUSTODIAN_PARAMS', {})
            if self.custodian_opts.get('handlers', []):
                self.handlers.extend(self.custodian_opts.get('handlers', []))
            self.handler_params.update(self.custodian_opts.get('handler_params', {}))

        self.tasks=[vasp_task.input,RunCustodianTask(handlers=self.handlers, 
                handler_params=self.handler_params)] if isinstance(vasp_task, 
                VaspInputInterface) else [vasp_task]
        self.Firework=Firework(self.tasks,name=self.name)
        # Try to establish connection with Launchpad
        try:
            self.LaunchPad=LaunchPad.from_file(os.path.join(os.environ["HOME"], ".fireworks", "my_launchpad.yaml"))
        except Exception:
            self.LaunchPad = None
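The constructor above only reads the CUSTODIAN_PARAMS block of vasp_interface_defaults.yaml; a skeletal sketch of that file, with placeholder values, might look like:

CUSTODIAN_PARAMS:
  handlers: []
  handler_params: {}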
Example #3
def get_solute_def_profile1(mpid, solute, solute_conc, T, def_file, sol_file):

    raw_energy_dict = loadfn(def_file,cls=MontyDecoder)
    sol_raw_energy_dict = loadfn(sol_file,cls=MontyDecoder)

    #try:
    print(raw_energy_dict[mpid].keys())
    e0 = raw_energy_dict[mpid]['e0']
    struct = raw_energy_dict[mpid]['structure']
    vacs = raw_energy_dict[mpid]['vacancies']
    antisites = raw_energy_dict[mpid]['antisites']
    solutes = sol_raw_energy_dict[mpid]['solutes']
    for vac_def in vacs:
        if not vac_def:
            print('All vacancy defect energies not present')
            continue
    for antisite_def in antisites:
        if not antisite_def:
            print('All antisite defect energies not present')
            continue
    for solute_def in solutes:
        if not solute_def:
            print('All solute defect energies not present')
            continue

    try:
        sol_conc = solute_site_preference_finder(struct, e0, T, vacs, 
                antisites, solutes, solute_conc)#, 
                #trial_chem_pot={'Al':-4.120, 'Ni':-6.5136, 'Ti':-7.7861})
        return sol_conc
    except:
        raise
Example #4
    def test_construction(self):
        edges_frag = {(e[0], e[1]): {"weight":1.0} for e in self.pc_frag1_edges}
        mol_graph = MoleculeGraph.with_edges(self.pc_frag1, edges_frag)
        #dumpfn(mol_graph.as_dict(), os.path.join(module_dir,"pc_frag1_mg.json"))
        ref_mol_graph = loadfn(os.path.join(module_dir, "pc_frag1_mg.json"))
        self.assertEqual(mol_graph, ref_mol_graph)
        self.assertEqual(mol_graph.graph.adj, ref_mol_graph.graph.adj)
        for node in mol_graph.graph.nodes:
            self.assertEqual(mol_graph.graph.node[node]["specie"],
                             ref_mol_graph.graph.node[node]["specie"])
            for ii in range(3):
                self.assertEqual(
                    mol_graph.graph.node[node]["coords"][ii],
                    ref_mol_graph.graph.node[node]["coords"][ii])

        edges_pc = {(e[0], e[1]): {"weight":1.0} for e in self.pc_edges}
        mol_graph = MoleculeGraph.with_edges(self.pc, edges_pc)
        #dumpfn(mol_graph.as_dict(), os.path.join(module_dir,"pc_mg.json"))
        ref_mol_graph = loadfn(os.path.join(module_dir, "pc_mg.json"))
        self.assertEqual(mol_graph, ref_mol_graph)
        self.assertEqual(mol_graph.graph.adj, ref_mol_graph.graph.adj)
        for node in mol_graph.graph:
            self.assertEqual(mol_graph.graph.node[node]["specie"],
                             ref_mol_graph.graph.node[node]["specie"])
            for ii in range(3):
                self.assertEqual(
                    mol_graph.graph.node[node]["coords"][ii],
                    ref_mol_graph.graph.node[node]["coords"][ii])

        mol_graph_edges = MoleculeGraph.with_edges(self.pc, edges=edges_pc)
        mol_graph_strat = MoleculeGraph.with_local_env_strategy(self.pc, OpenBabelNN(), reorder=False, extend_structure=False)
        self.assertTrue(mol_graph_edges.isomorphic_to(mol_graph_strat))
Example #5
    def run_task(self, fw_spec):
        # the FW.json/yaml file is mandatory to get the fw_id
        # no need to deserialize the whole FW

        if '_add_launchpad_and_fw_id' in fw_spec:
            lp = self.launchpad
            fw_id = self.fw_id
        else:
            try:
                fw_dict = loadfn('FW.json')
            except IOError:
                try:
                    fw_dict = loadfn('FW.yaml')
                except IOError:
                    raise RuntimeError("Launchpad/fw_id not present in spec and No FW.json nor FW.yaml file present: "
                                       "impossible to determine fw_id")
            lp = LaunchPad.auto_load()
            fw_id = fw_dict['fw_id']

        wf = lp.get_wf_by_fw_id_lzyfw(fw_id)

        deleted_files = []
        # iterate over all the fws and launches
        for fw_id, fw in wf.id_fw.items():
            for l in fw.launches+fw.archived_launches:
                l_dir = l.launch_dir

                deleted_files.extend(self.delete_files(os.path.join(l_dir, TMPDIR_NAME)))
                deleted_files.extend(self.delete_files(os.path.join(l_dir, INDIR_NAME)))
                deleted_files.extend(self.delete_files(os.path.join(l_dir, OUTDIR_NAME), self.out_exts))

        logging.info("Deleted files:\n {}".format("\n".join(deleted_files)))

        return FWAction(stored_data={'deleted_files': deleted_files})
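The fallback branch above only needs the fw_id key from the serialized Firework, so a minimal FW.json that satisfies it could be as small as (all other Firework fields omitted):

{"fw_id": 42}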
Example #6
    def run_task(self, fw_spec):
        # the FW.json/yaml file is mandatory to get the fw_id
        # no need to deserialize the whole FW
        try:
            fw_dict = loadfn('FW.json')
        except IOError:
            try:
                fw_dict = loadfn('FW.yaml')
            except IOError:
                raise RuntimeError("No FW.json nor FW.yaml file present: impossible to determine fw_id")

        fw_id = fw_dict['fw_id']
        lp = LaunchPad.auto_load()
        wf = lp.get_wf_by_fw_id_lzyfw(fw_id)
        wf_module = importlib.import_module(wf.metadata['workflow_module'])
        wf_class = getattr(wf_module, wf.metadata['workflow_class'])

        get_results_method = getattr(wf_class, 'get_final_structure_and_history')
        #TODO: make this more general ... just to test right now ...
        results = get_results_method(wf)

        database = MongoDatabase.from_dict(fw_spec['mongo_database'])

        database.insert_entry({'structure': results['structure'], 'history': results['history']})

        logging.info("Inserted data:\n something")


        return FWAction()
Example #7
 def test_check_acc_bzt_bands(self):
     structure = loadfn(os.path.join(test_dir,'boltztrap/structure_mp-12103.json'))
     sbs = loadfn(os.path.join(test_dir,'boltztrap/dft_bs_sym_line.json'))
     sbs_bzt = self.bz_bands.get_symm_bands(structure,-5.25204548)
     corr,werr_vbm,werr_cbm,warn = self.bz_bands.check_acc_bzt_bands(sbs_bzt,sbs)
     self.assertAlmostEqual(corr[2],9.16851750e-05)
     self.assertAlmostEqual(werr_vbm['K-H'],0.18260273521047862)
     self.assertAlmostEqual(werr_cbm['M-K'],0.071552669981356981)
     self.assertFalse(warn)
Example #8
 def test_get_symm_bands(self):
     structure = loadfn(os.path.join(test_dir,'boltztrap/structure_mp-12103.json'))
     sbs = loadfn(os.path.join(test_dir,'boltztrap/dft_bs_sym_line.json'))
     kpoints = [kp.frac_coords for kp in sbs.kpoints]
     labels_dict = {k: sbs.labels_dict[k].frac_coords for k in sbs.labels_dict}
     for kpt_line,labels_dict in zip([None,sbs.kpoints,kpoints],[None,sbs.labels_dict,labels_dict]):
         print(kpt_line)
         sbs_bzt = self.bz_bands.get_symm_bands(structure,-5.25204548,kpt_line=kpt_line,labels_dict=labels_dict)
         self.assertAlmostEqual(len(sbs_bzt.bands[Spin.up]),20)
         self.assertAlmostEqual(len(sbs_bzt.bands[Spin.up][1]),143)
Example #9
 def test_dumpf_loadf(self):
     d = {"hello": "world"}
     dumpfn(d, "monte_test.json", indent=4)
     d2 = loadfn("monte_test.json")
     self.assertEqual(d, d2)
     os.remove("monte_test.json")
     dumpfn(d, "monte_test.yaml", default_flow_style=False)
     d2 = loadfn("monte_test.yaml")
     self.assertEqual(d, d2)
     dumpfn(d, "monte_test.yaml", Dumper=Dumper)
     d2 = loadfn("monte_test.yaml")
     os.remove("monte_test.yaml")
Example #10
    def setUp(self):
        bs = loadfn(os.path.join(test_dir, "PbTe_bandstructure.json"))
        bs_sp = loadfn(os.path.join(test_dir, "N2_bandstructure.json"))
        self.loader = BandstructureLoader(bs, vrun.structures[-1])
        self.assertIsNotNone(self.loader)

        self.loader_sp_up = BandstructureLoader(bs_sp, vrun_sp.structures[-1],spin=1)
        self.loader_sp_dn = BandstructureLoader(bs_sp, vrun_sp.structures[-1],spin=-1)
        self.assertTupleEqual(self.loader_sp_up.ebands.shape, (12, 198))
        self.assertTupleEqual(self.loader_sp_dn.ebands.shape, (12, 198))
        self.assertIsNotNone(self.loader_sp_dn)
        self.assertIsNotNone(self.loader_sp_up)
        
        warnings.simplefilter("ignore")
Example #11
    def test_get_wf_from_spec_dict(self):
        d = loadfn(os.path.join(os.path.abspath(os.path.dirname(__file__)), "spec.yaml"))
        wf = get_wf_from_spec_dict(self.structure, d)
        self.assertEqual(len(wf.fws), 4)
        for f in wf.fws:
            self.assertEqual(f.spec['_tasks'][-1]["db_file"], "db.json")

        self.assertEqual(sorted([len(v) for v in wf.links.values()]),
                         [0, 0, 1, 2])

        self.assertEqual(wf.name, "Si:band structure")
        d = loadfn(os.path.join(os.path.abspath(os.path.dirname(__file__)),
                                "badspec.yaml"))

        self.assertRaises(ImportError, get_wf_from_spec_dict, self.structure, d)
Example #12
    def test_get_corrections_for_Mo_Ta_and_W(self):
        os.chdir(os.path.join(ROOT, 'Mo_Ta_W_controls'))
        get_corrections(write_yaml=True)
        test_corrections = loadfn('ion_corrections.yaml')
        control_corrections = {'Mo': 0.379, 'Ta': 0.161, 'W': 0.635}
        test_mus = loadfn('chemical_potentials.yaml')
        control_mus = {'Mo': -8.885, 'Ta': -10.17, 'W': -11.179, 'O': -4.658}

        for elt in control_corrections:
            self.assertEqual(control_corrections[elt], test_corrections[elt])
        for elt in control_mus:
            self.assertEqual(control_mus[elt], test_mus[elt])
        os.remove('ion_corrections.yaml')
        os.remove('chemical_potentials.yaml')
        os.chdir(ROOT)
Example #13
 def from_config(cls, config):
     db_yaml = os.path.expandvars(config.db_yaml)
     db_cfg = loadfn(db_yaml)
     client = MongoClient(db_cfg['host'], db_cfg['port'], j=False)
     db = client[db_cfg['db']]
     try:
         db.authenticate(db_cfg['username'], db_cfg['password'])
     except Exception:
         logger.error('authentication failed for {}'.format(db_yaml))
         sys.exit(1)
     logger.debug('using DB from {}'.format(db_yaml))
     duplicates_file = os.path.expandvars(config.duplicates_file)
     duplicates = loadfn(duplicates_file) \
             if os.path.exists(duplicates_file) else {}
     return OstiMongoAdapter(db, duplicates, config.osti.elink)
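A placeholder sketch of the db_yaml file implied by the lookups above (only the keys dereferenced in the code; values are illustrative):

host: localhost
port: 27017
db: mydb
username: user
password: secret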
Example #14
File: pydii.py Project: bocklund/pymatgen
def get_solute_def_profile(args):
    if not args.mpid:
        print ('===========\nERROR: mpid is not given.\n===========')
        return
    if not args.solute:
        print ('===========\nERROR: Solute atom is not given.\n===========')
        return

    mpid = args.mpid 
    solute = args.solute 
    solute_conc = args.solute_conc/100.0
    T = args.T 

    def_file = mpid + '_raw_defect_energy.json'
    raw_energy_dict = loadfn(def_file,cls=MontyDecoder)
    sol_file = mpid+'_solute-'+solute+'_raw_defect_energy.json'
    sol_raw_energy_dict = loadfn(sol_file,cls=MontyDecoder)

    #try:
    e0 = raw_energy_dict[mpid]['e0']
    struct = raw_energy_dict[mpid]['structure']
    vacs = raw_energy_dict[mpid]['vacancies']
    antisites = raw_energy_dict[mpid]['antisites']
    solutes = sol_raw_energy_dict[mpid]['solutes']

    for vac_def in vacs:
        if not vac_def:
            print('All vacancy defect energies not present')
            continue
    for antisite_def in antisites:
        if not antisite_def:
            print('All antisite defect energies not preset')
            continue
    for solute_def in solutes:
        if not solute_def:
            print('All solute defect energies not preset')
            continue

    try:
        def_conc = solute_defect_density(struct, e0, vacs, 
                antisites, solutes, solute_concen=solute_conc, T=T, 
                plot_style="gnuplot")
        fl_nm = args.mpid+'_solute-'+args.solute+'_def_concentration.dat'
        with open(fl_nm, 'w') as fp:
            for row in def_conc:
                print(row, file=fp)
    except:
        raise
Example #15
    def from_db_file(cls, db_file, admin=True):
        """
        Create MMDB from database file. File requires host, port, database,
        collection, and optionally admin_user/readonly_user and
        admin_password/readonly_password

        Args:
            db_file (str): path to the file containing the credentials
            admin (bool): whether to use the admin user

        Returns:
            MMDb object
        """
        creds = loadfn(db_file)

        if admin and "admin_user" not in creds and "readonly_user" in creds:
            raise ValueError("Trying to use admin credentials, "
                             "but no admin credentials are defined. "
                             "Use admin=False if only read_only "
                             "credentials are available.")

        if admin:
            user = creds.get("admin_user")
            password = creds.get("admin_password")
        else:
            user = creds.get("readonly_user")
            password = creds.get("readonly_password")

        return cls(creds["host"], int(creds["port"]), creds["database"], creds["collection"],
                   user, password)
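A hypothetical credentials file matching the docstring's requirements might look like (all values are placeholders):

host: localhost
port: 27017
database: vasp
collection: tasks
admin_user: admin
admin_password: secret
readonly_user: reader
readonly_password: secret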
Example #16
    def from_db_file(cls, db_file, admin=True):
        """
        Args:
            db_file (str): path to the filepad cred file

        Returns:
            FilePad object
        """
        creds = loadfn(db_file)

        if admin:
            user = creds.get("admin_user")
            password = creds.get("admin_password")
        else:
            user = creds.get("readonly_user")
            password = creds.get("readonly_password")

        coll_name = creds.get("filepad", "filepad")
        gfs_name = creds.get("filepad_gridfs", "filepad_gfs")

        return cls(
            creds.get("host", "localhost"),
            int(creds.get("port", 27017)),
            creds.get("name", "fireworks"),
            user,
            password,
            coll_name,
            gfs_name,
        )
Example #17
    def setUp(self):
        self.vbm_val = 2.6682
        self.gap = 1.5
        self.entries = list(loadfn(os.path.join(os.path.dirname(__file__), "GaAs_test_defentries.json")).values())
        for entry in self.entries:
            entry.parameters.update( {'vbm': self.vbm_val})
        self.pd = DefectPhaseDiagram(self.entries, self.vbm_val, self.gap)
        self.mu_elts = {Element("As"): -4.658070555, Element("Ga"): -3.7317319750000006}

        # make Vac_As (q= -2) only defect test single-stable-charge exceptions
        self.extra_entry = DefectEntry(self.entries[5].defect.copy(), 100.)
        sep_entries = [ent for ent in self.entries if not (ent.name == 'Vac_As_mult4' and
                                                           ent.charge in [-2,-1,0,1,2])]
        sep_entries.append( self.extra_entry.copy())
        self.sep_pd = DefectPhaseDiagram( sep_entries, self.vbm_val, self.gap)

        # make Vac_As (q= -2) is incompatible for larger supercell
        ls_entries = self.entries[:]
        for entry in ls_entries:
            if entry.name == 'Vac_As_mult4' and entry.charge == -2.:
                entry.parameters['is_compatible'] = False
        self.pd_ls_fcTrue = DefectPhaseDiagram(ls_entries, self.vbm_val, self.gap, filter_compatible=True)
        self.pd_ls_fcFalse = DefectPhaseDiagram(ls_entries, self.vbm_val, self.gap, filter_compatible=False)

        # load complete dos for fermi energy solving
        with open(os.path.join(test_dir, "complete_dos.json"), "r") as f:
            dos_dict = json.load(f)
        self.dos = CompleteDos.from_dict(dos_dict)
Example #18
 def __init__(self, config_file, correct_peroxide=True):
     c = loadfn(config_file)
     self.oxide_correction = c['OxideCorrections']
     self.sulfide_correction = c.get('SulfideCorrections', defaultdict(
         float))
     self.name = c['Name']
     self.correct_peroxide = correct_peroxide
Example #19
def im_sol_sub_def_profile():
    m_description = 'Command to generate solute defect site preference ' \
                    'in an intermetallic from the raw defect energies.'

    parser = ArgumentParser(description=m_description)

    parser.add_argument("--mpid",
            type=str.lower,
            help="Materials Project id of the intermetallic structure.\n" \
                 "For more info on Materials Project, please refer to " \
                 "www.materialsproject.org")

    parser.add_argument("--solute", help="Solute Element")

    parser.add_argument("--sol_conc", type=float, default=1.0,
            help="Solute Concentration in %. Default is 1%")

    parser.add_argument("--T", type=float, help="Temperature in Kelvin")
    parser.add_argument("--trail_mu_file",  default=None,
            help="Trial chemcal potential in dict format stored in file")

    args = parser.parse_args()
    print(args)

    if not args.mpid:
        print ('===========\nERROR: mpid is not given.\n===========')
        return
    if not args.T:
        print ('===========\nERROR: Temperature is not given.\n===========')
        return
    if not args.solute:
        print ('===========\nERROR: Solute atom is not given.\n===========')
        return
    def_file = args.mpid+'_raw_defect_energy.json'
    sol_file = args.mpid+'_solute-'+args.solute+'_raw_defect_energy.json'
    sol_conc = args.sol_conc/100.0 # Convert from percentage
    if not os.path.exists(def_file):
        print ('===========\nERROR: Defect file not found.\n===========')
        return
    if not os.path.exists(sol_file):
        print ('===========\nERROR: Solute file not found.\n===========')
        return
    if args.trial_mu_file:
        trial_chem_pot = loadfn(args.trial_mu_file, cls=MontyDecoder)
    else:
        trial_chem_pot = None

    conc_dat = get_solute_def_profile(args.mpid, args.solute, sol_conc,
            args.T, def_file, sol_file, trial_chem_pot=trial_chem_pot)

    if conc_dat:
        #print plot_dict.keys()
        #for key in plot_dict:
        #    print key, type(key)
        #    print plot_dict[key]
        fl_nm = args.mpid+'_solute-'+args.solute+'_site_pref.dat'
        #fl_nm = args.mpid+'_def_concentration.dat'
        with open(fl_nm, 'w') as fp:
            for row in conc_dat:
                print(row, file=fp)
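Assuming this parser is exposed as a console script named im_sol_sub_def_profile, an invocation might look like (the mpid and solute values are placeholders):

im_sol_sub_def_profile --mpid mp-1234 --solute Al --sol_conc 1.0 --T 900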
Example #20
 def test_dumpfn_loadfn(self):
     d = {"hello": "world"}
     dumpfn(d, "monte_test.json", indent=4)
     d2 = loadfn("monte_test.json")
     self.assertEqual(d, d2)
     os.remove("monte_test.json")
     dumpfn(d, "monte_test.yaml", default_flow_style=False)
     d2 = loadfn("monte_test.yaml")
     self.assertEqual(d, d2)
     dumpfn(d, "monte_test.yaml", Dumper=Dumper)
     d2 = loadfn("monte_test.yaml")
     os.remove("monte_test.yaml")
     dumpfn(d, "monte_test.mpk")
     d2 = loadfn("monte_test.mpk")
     self.assertEqual(d, {k.decode('utf-8'): v.decode('utf-8') for k, v in d2.items()})
     os.remove("monte_test.mpk")
Example #21
    def __init__(self, materials_write, counter_write, tasks_read, tasks_prefix="t",
                 materials_prefix="m", query=None, settings_file=None):
        """
        Create a materials collection from a tasks collection.

        Args:
            materials_write (pymongo.collection): mongodb collection for materials (write access needed)
            counter_write (pymongo.collection): mongodb collection for counter (write access needed)
            tasks_read (pymongo.collection): mongodb collection for tasks (suggest read-only for safety)
            tasks_prefix (str): a string prefix for tasks, e.g. "t" gives a task_id like "t-132"
            materials_prefix (str): a string prefix to prepend to material_ids
            query (dict): a pymongo query on tasks_read for which tasks to include in the builder
            settings_file (str): filepath to a custom settings file
        """

        settings_file = settings_file or os.path.join(
            module_dir, "tasks_materials_settings.yaml")
        x = loadfn(settings_file)
        self.supported_task_labels = x['supported_task_labels']
        self.property_settings = x['property_settings']
        self.indexes = x.get('indexes', [])
        self.properties_root = x.get('properties_root', [])

        self._materials = materials_write
        if self._materials.count() == 0:
            self._build_indexes()

        self._counter = counter_write
        if self._counter.find({"_id": "materialid"}).count() == 0:
            self._counter.insert_one({"_id": "materialid", "c": 0})

        self._tasks = tasks_read
        self._t_prefix = tasks_prefix
        self._m_prefix = materials_prefix
        self.query = query
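A custom settings_file must mirror the keys read from tasks_materials_settings.yaml; a skeletal example with placeholder values:

supported_task_labels: []
property_settings: {}
indexes: []
properties_root: []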
Example #22
    def __init__(self, user_incar_settings, 
                 constrain_total_magmom=False, sort_structure=False,
                 kpoints_density=1000, sym_prec=0.1):
        self.sym_prec = sym_prec
        
        DictVaspInputSet.__init__(
            self, "MaterialsProject Static",
            loadfn(os.path.join(MODULE_DIR, "MPVaspInputSet.yaml")),
            constrain_total_magmom=constrain_total_magmom,
            sort_structure=sort_structure)

        self.user_incar_settings = user_incar_settings

        # impose tetrahedron method
        self.incar_settings.update(
            {"IBRION": -1, "ISMEAR": -5, "LCHARG": False,
             "LORBIT": 11, "LWAVE": False, "NSW": 0, "ISYM": 0, "ICHARG": 11})

        # This variable may have been used in the ground-state calculation,
        # but it is not relevant to the tetrahedron DOS calculation; its
        # superfluous presence in the INCAR leads to (human) confusion
        # (VASP doesn't care).
        self.incar_settings.pop('SIGMA', None)

        self.kpoints_settings.update({"kpoints_density": kpoints_density})

        # Set dense DOS output
        self.incar_settings.update({"NEDOS": 2001})

        if "NBANDS" not in user_incar_settings:
            raise KeyError("For NonSCF runs, NBANDS value from SC runs is "
                           "required!")
        else:
            self.incar_settings.update(user_incar_settings)
Example #23
    def __init__(self, materials_write, counter_write, tasks_read, tasks_prefix="t", materials_prefix="m"):
        """
        Create a materials collection from a tasks collection.

        Args:
            materials_write (pymongo.collection): mongodb collection for materials (write access needed)
            counter_write (pymongo.collection): mongodb collection for counter (write access needed)
            tasks_read (pymongo.collection): mongodb collection for tasks (suggest read-only for safety)
        """
        x = loadfn(os.path.join(module_dir, "tasks_materials_settings.yaml"))
        self.supported_task_labels = x['supported_task_labels']
        self.property_settings = x['property_settings']
        self.indexes = x.get('indexes', [])
        self.properties_root = x.get('properties_root', [])

        self._materials = materials_write
        if self._materials.count() == 0:
            self._build_indexes()

        self._counter = counter_write
        if self._counter.find({"_id": "materialid"}).count() == 0:
            self._counter.insert_one({"_id": "materialid", "c": 0})

        self._tasks = tasks_read

        self._t_prefix = tasks_prefix
        self._m_prefix = materials_prefix
Example #24
def run(args):
    FORMAT = '%(asctime)s %(message)s'
    logging.basicConfig(format=FORMAT, level=logging.INFO, filename="run.log")
    logging.info("Spec file is %s" % args.spec_file)
    d = loadfn(args.spec_file[0])
    c = Custodian.from_spec(d)
    c.run()
Example #25
def get_def_profile(mpid, T, def_file):

    raw_energy_dict = loadfn(def_file, cls=MontyDecoder)

    e0 = raw_energy_dict[mpid]['e0']
    struct = raw_energy_dict[mpid]['structure']
    vacs = raw_energy_dict[mpid]['vacancies']
    antisites = raw_energy_dict[mpid]['antisites']
    vacs.sort(key=lambda entry: entry['site_index'])
    antisites.sort(key=lambda entry: entry['site_index'])
    for vac_def in vacs:
        if not vac_def:
            print('All vacancy defect energies not present')
            continue
    for antisite_def in antisites:
        if not antisite_def:
            print('All antisite defect energies not present')
            continue

    try:
        def_conc, def_en, mu = compute_defect_density(struct, e0, vacs, antisites, T,
                plot_style='gnuplot')
        return def_conc, def_en, mu
    except:
        raise
Example #26
 def test_get_complete_dos(self):
     structure = loadfn(os.path.join(test_dir,'boltztrap/structure_mp-12103.json'))
     cdos = self.bz_up.get_complete_dos(structure,self.bz_dw)
     self.assertIs(list(cdos.densities.keys())[0], Spin.down)
     self.assertIs(list(cdos.densities.keys())[1], Spin.up)
     self.assertAlmostEqual(cdos.get_spd_dos()[OrbitalType.p].densities[Spin.up][3134],43.839230100999991)
     self.assertAlmostEqual(cdos.get_spd_dos()[OrbitalType.s].densities[Spin.down][716],6.5383268000000001)
Example #27
File: fw_workflows.py Project: kidaa/abipy
    def get_final_structure_and_history(cls, wf):
        assert wf.metadata['workflow_class'] == cls.workflow_class
        assert wf.metadata['workflow_module'] == cls.workflow_module
        ioncell = -1
        final_fw_id = None
        for fw_id, fw in wf.id_fw.items():
            if 'wf_task_index' in fw.spec and fw.spec['wf_task_index'][:8] == 'ioncell_':
                try:
                    this_ioncell = int(fw.spec['wf_task_index'].split('_')[-1])
                except ValueError:
                    # skip if the index is not an int
                    continue
                if this_ioncell > ioncell:
                    ioncell = this_ioncell
                    final_fw_id = fw_id
        if final_fw_id is None:
            raise RuntimeError('Final structure not found ...')
        myfw = wf.id_fw[final_fw_id]
        #TODO add a check on the state of the launches
        last_launch = (myfw.archived_launches + myfw.launches)[-1]
        #TODO add a cycle to find the instance of AbiFireTask?
        myfw.tasks[-1].set_workdir(workdir=last_launch.launch_dir)
        structure = myfw.tasks[-1].get_final_structure()
        history = loadfn(os.path.join(last_launch.launch_dir, 'history.json'))

        return {'structure': structure.as_dict(), 'history': history}
Example #28
    def setup(self):
        """
        Performs initial setup for VaspJob, including overriding any settings
        and backing up.
        """
        decompress_dir('.')

        if self.backup:
            for f in VASP_INPUT_FILES:
                shutil.copy(f, "{}.orig".format(f))

        if self.auto_npar:
            try:
                incar = Incar.from_file("INCAR")
                # Only optimize NPAR for non-HF and non-RPA calculations.
                if not (incar.get("LHFCALC") or incar.get("LRPA") or
                        incar.get("LEPSILON")):
                    if incar.get("IBRION") in [5, 6, 7, 8]:
                        # NPAR should not be set for Hessian matrix
                        # calculations, whether in DFPT or otherwise.
                        del incar["NPAR"]
                    else:
                        import multiprocessing
                        # try sge environment variable first
                        # (since multiprocessing counts cores on the current
                        # machine only)
                        ncores = os.environ.get('NSLOTS') or \
                            multiprocessing.cpu_count()
                        ncores = int(ncores)
                        for npar in range(int(math.sqrt(ncores)),
                                          ncores):
                            if ncores % npar == 0:
                                incar["NPAR"] = npar
                                break
                    incar.write_file("INCAR")
            except Exception:
                pass

        if self.auto_continue:
            if os.path.exists("continue.json"):
                actions = loadfn("continue.json").get("actions")
                logger.info("Continuing previous VaspJob. Actions: {}".format(actions))
                backup(VASP_BACKUP_FILES, prefix="prev_run")
                VaspModder().apply_actions(actions)

            else:
                # Default functionality is to copy CONTCAR to POSCAR and set
                # ISTART to 1 in the INCAR, but other actions can be specified
                if self.auto_continue is True:
                    actions = [{"file": "CONTCAR",
                                "action": {"_file_copy": {"dest": "POSCAR"}}},
                               {"dict": "INCAR",
                                "action": {"_set": {"ISTART": 1}}}]
                else:
                    actions = self.auto_continue
                dumpfn({"actions": actions}, "continue.json")

        if self.settings_override is not None:
            VaspModder().apply_actions(self.settings_override)
Example #29
 def setUp(self):
     self.seq_tc = [t for t in np.arange(4*3**3).reshape((4, 3, 3, 3))]
     self.seq_tc = TensorCollection(self.seq_tc)
     self.rand_tc = TensorCollection([t for t in np.random.random((4, 3, 3))])
     self.diff_rank = TensorCollection([np.ones([3]*i) for i in range(2, 5)])
     self.struct = self.get_structure("Si")
     ieee_file_path = os.path.join(test_dir, "ieee_conversion_data.json")
     self.ieee_data = loadfn(ieee_file_path)
Example #30
    def test_group_entries_by_structure(self):

        entries = loadfn(str(test_dir / "TiO2_entries.json"))
        groups = group_entries_by_structure(entries)
        self.assertEqual(sorted([len(g) for g in groups]),
                         [1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 4])
        self.assertLess(len(groups), len(entries))
        # Make sure no entries are left behind
        self.assertEqual(sum([len(g) for g in groups]), len(entries))
Example #31
 def test_get_symm_bands(self):
     structure = loadfn(
         os.path.join(test_dir, 'boltztrap/structure_mp-12103.json'))
     sbs_bzt = self.bz_bands.get_symm_bands(structure, -5.25204548)
     self.assertAlmostEqual(len(sbs_bzt.bands[Spin.up]), 20)
     self.assertAlmostEqual(len(sbs_bzt.bands[Spin.up][1]), 143)
Example #32
def get_wf_single(structure, WORKFLOW="get_wf_gibbs", settings=None):
    """
    Get a single workflow

    Parameters
        structure: pymatgen.Structure
            The structure
        WORKFLOW: str
            The name of the workflow, now only gibbs energy workflow(get_wf_gibbs) is supported
        settings: dict
            User settings for the workflow
    Returns
        The constructed Workflow object.
    """
    settings = settings or {}
    ################ PARAMETERS FOR WF #############################
    #str, the absolute path of db.json file, e.g. /storage/home/mjl6505/atomate/config/db.json
    #  If None, it will use the configuration in fireworks
    db_file = settings.get('db_file', None)
    #list, the MAGMOM of the structure, e.g. [4.0, 4.0, -4.0, -4.0]
    magmom = settings.get('magmom', None)
    #int, the number of initial deformations, e.g. 7
    num_deformations = settings.get('num_deformations', 7)
    #list/tuple(min, max) or float(-max, max), the maximum amplitude of deformation, e.g. (-0.05, 0.1) means (0.95, 1.1) in volume
    deformation_fraction = settings.get('deformation_fraction', (-0.1, 0.1))
    #float, minimum ratio of Volumes spacing, e.g. 0.03
    volume_spacing_min = settings.get('volume_spacing_min', 0.03)
    #bool, run phonon(True) or not(False)
    phonon = settings.get('phonon', False)
    #list(3x3), the supercell matrix for phonon, e.g. [[2.0, 0, 0], [0, 2.0, 0], [0, 0, 2.0]]
    phonon_supercell_matrix = settings.get('phonon_supercell_matrix', None)
    #float, the minimum temperature in QHA process, e.g. 5
    t_min = settings.get('t_min', 5)
    #float, the maximum of temperature in QHA process, e.g. 2000
    t_max = settings.get('t_max', 2000)
    #float, the step of temperature in QHA process, e.g. 5
    t_step = settings.get('t_step', 5)
    #float, acceptable value for average RMS, recommend >= 0.005
    tolerance = settings.get('tolerance', 0.01)
    #str, the vasp command, if None then find in the FWorker configuration
    vasp_cmd = settings.get('vasp_cmd', None)
    #dict, metadata to be included; this parameter is useful for filtering the data, e.g. metadata={"phase": "BCC_A2", "tag": "AFM"}
    metadata = settings.get('metadata', None)
    #float, the tolerance for symmetry, e.g. 0.05
    symmetry_tolerance = settings.get('symmetry_tolerance', 0.05)
    #bool, set True to pass initial VASP running if the results exist in DB, use carefully to keep data consistent.
    passinitrun = settings.get('passinitrun', False)
    #bool, Whether run isif=2 calculation before isif=4 running
    run_isif2 = settings.get('run_isif2', False)
    #bool, Whether pass isif=4 calculation.
    pass_isif4 = settings.get('pass_isif4', False)
    #Set the path already exists for new static calculations; if set as '', will try to get the path from db_file
    relax_path = settings.get('relax_path', '')
    #dict, dict of class ModifyIncar with keywords in Workflow name. e.g.
    """
    modify_incar_params = { 'Full relax': {'incar_update': {"LAECHG":False,"LCHARG":False,"LWAVE":False}},
                            'PreStatic': {'incar_update': {"LAECHG":False,"LCHARG":False,"LWAVE":False}},
                            'PS2': {'incar_update': {"LAECHG":False,"LCHARG":False,"LWAVE":False}}, 
                            'static': {'incar_update': {"LAECHG":False,"LCHARG":False,"LWAVE":False}}}
    """
    modify_incar_params = settings.get('modify_incar_params', {})
    #dict, dict of class ModifyKpoints with keywords in Workflow name, similar with modify_incar_params
    modify_kpoints_params = settings.get('modify_kpoints_params', {})
    #bool, whether to print some information (True) or not (False), used for debugging
    verbose = settings.get('verbose', False)

    if magmom:
        structure.add_site_property('magmom', magmom)
    if not db_file:
        from fireworks.fw_config import config_to_dict
        db_file = loadfn(config_to_dict()["FWORKER_LOC"])["env"]["db_file"]

    if WORKFLOW == "get_wf_gibbs":
        #Currently, only this workflow is supported
        wf = get_wf_gibbs(structure,
                          num_deformations=num_deformations,
                          deformation_fraction=deformation_fraction,
                          phonon=phonon,
                          phonon_supercell_matrix=phonon_supercell_matrix,
                          t_min=t_min,
                          t_max=t_max,
                          t_step=t_step,
                          tolerance=tolerance,
                          volume_spacing_min=volume_spacing_min,
                          vasp_cmd=vasp_cmd,
                          db_file=db_file,
                          metadata=metadata,
                          name='EV_QHA',
                          symmetry_tolerance=symmetry_tolerance,
                          run_isif2=run_isif2,
                          pass_isif4=pass_isif4,
                          passinitrun=passinitrun,
                          relax_path=relax_path,
                          modify_incar_params=modify_incar_params,
                          modify_kpoints_params=modify_kpoints_params,
                          verbose=verbose)
    else:
        raise ValueError(
            "Currently, only the gibbs energy workflow is supported.")
    return wf
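A hedged usage sketch (the input file and settings values are illustrative, not defaults taken from any project):

from pymatgen.core import Structure

structure = Structure.from_file("POSCAR")  # hypothetical input file
wf = get_wf_single(structure,
                   settings={"num_deformations": 7,
                             "deformation_fraction": (-0.1, 0.1),
                             "t_max": 2000})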
Example #33

# Let's initialize some module level properties.

# List of electronegative elements specified in M. O'Keefe, & N. Brese,
# JACS, 1991, 113(9), 3226-3229. doi:10.1021/ja00009a002.
ELECTRONEG = [Element(sym) for sym in ["H", "B", "C", "Si",
                                       "N", "P", "As", "Sb",
                                       "O", "S", "Se", "Te",
                                       "F", "Cl", "Br", "I"]]

module_dir = os.path.dirname(os.path.abspath(__file__))

# Read in BV parameters.
BV_PARAMS = {}
for k, v in loadfn(os.path.join(module_dir, "bvparam_1991.yaml")).items():
    BV_PARAMS[Element(k)] = v

# Read in yaml containing data-mined ICSD BV data.
all_data = loadfn(os.path.join(module_dir, "icsd_bv.yaml"))
ICSD_BV_DATA = {Specie.from_string(sp): data
                for sp, data in all_data["bvsum"].items()}
PRIOR_PROB = {Specie.from_string(sp): data
              for sp, data in all_data["occurrence"].items()}


def calculate_bv_sum(site, nn_list, scale_factor=1.0):
    """
    Calculates the BV sum of a site.

    Args:
Example #34
    def test_json_processing(self):

        with ScratchDir('.'):
            os.environ['BEEP_ROOT'] = os.getcwd()
            os.mkdir("data-share")
            os.mkdir(os.path.join("data-share", "structure"))

            # Create dummy json obj
            json_obj = {
                "mode": self.events_mode,
                "file_list": [self.arbin_file, "garbage_file"],
                'run_list': [0, 1],
                "validity": ['valid', 'invalid']
            }
            json_string = json.dumps(json_obj)
            # Get json output from method
            json_output = process_file_list_from_json(json_string)
            reloaded = json.loads(json_output)

            # Actual tests here
            # Ensure garbage file doesn't have output string
            self.assertEqual(reloaded['invalid_file_list'][0], 'garbage_file')

            # Ensure first is correct
            loaded_processed_cycler_run = loadfn(reloaded['file_list'][0])
            loaded_from_raw = RawCyclerRun.from_file(
                json_obj['file_list'][0]).to_processed_cycler_run()
            self.assertTrue(
                np.all(loaded_processed_cycler_run.summary ==
                       loaded_from_raw.summary),
                "Loaded processed cycler_run is not equal to that loaded from raw file"
            )

        # Test same functionality with json file
        with ScratchDir('.'):
            os.environ['BEEP_ROOT'] = os.getcwd()
            os.mkdir("data-share")
            os.mkdir(os.path.join("data-share", "structure"))

            json_obj = {
                "mode": self.events_mode,
                "file_list": [self.arbin_file, "garbage_file"],
                'run_list': [0, 1],
                "validity": ['valid', 'invalid']
            }
            dumpfn(json_obj, "test.json")
            # Get json output from method
            json_output = process_file_list_from_json("test.json")
            reloaded = json.loads(json_output)

            # Actual tests here
            # Ensure garbage file doesn't have output string
            self.assertEqual(reloaded['invalid_file_list'][0], 'garbage_file')

            # Ensure first is correct
            loaded_processed_cycler_run = loadfn(reloaded['file_list'][0])
            loaded_from_raw = RawCyclerRun.from_file(
                json_obj['file_list'][0]).to_processed_cycler_run()
            self.assertTrue(
                np.all(loaded_processed_cycler_run.summary ==
                       loaded_from_raw.summary),
                "Loaded processed cycler_run is not equal to that loaded from raw file"
            )
Example #35
File: ehull.py Project: lq131/garnetdnn
            "num_atoms": 2,
            "max_ordering": 10,
            'cn': "XII"
        },
        'b': {
            "num_atoms": 2,
            "max_ordering": 10,
            'cn': "VI"
        }
    }
}

DATA_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../data")
GARNET_CALC_ENTRIES_PATH = os.path.join(
    DATA_DIR, "garnet/garnet_calc_entries_dict.json")
GARNET_CALC_ENTRIES = loadfn(GARNET_CALC_ENTRIES_PATH)
PEROVSKITE_CALC_ENTRIES_PATH = os.path.join(
    DATA_DIR, "perovskite/perov_calc_entries_dict.json")
PEROVSKITE_CALC_ENTRIES = loadfn(PEROVSKITE_CALC_ENTRIES_PATH)
CALC_ENTRIES = {
    'garnet': GARNET_CALC_ENTRIES,
    'perovskite': PEROVSKITE_CALC_ENTRIES
}
GARNET_EHULL_ENTRIES_PATH = os.path.join(
    DATA_DIR, "garnet/garnet_ehull_entries_dict.json")
GARNET_EHULL_ENTRIES = loadfn(GARNET_EHULL_ENTRIES_PATH)
PEROVSKITE_EHULL_ENTRIES_PATH = os.path.join(
    DATA_DIR, "perovskite/perov_ehull_entries_dict.json")
PEROVSKITE_EHULL_ENTRIES = loadfn(PEROVSKITE_EHULL_ENTRIES_PATH)
EHULL_ENTRIES = {
    'garnet': GARNET_EHULL_ENTRIES,
Example #36
    def test_get_diagnostic(self):
        os.environ['BEEP_ROOT'] = TEST_FILE_DIR

        cycler_run = RawCyclerRun.from_file(self.maccor_file_w_parameters)

        v_range, resolution, nominal_capacity, full_fast_charge, diagnostic_available = \
            cycler_run.determine_structuring_parameters()
        self.assertEqual(nominal_capacity, 4.84)
        self.assertEqual(v_range, [2.7, 4.2])
        self.assertEqual(diagnostic_available['cycle_type'],
                         ['reset', 'hppc', 'rpt_0.2C', 'rpt_1C', 'rpt_2C'])
        diag_summary = cycler_run.get_diagnostic_summary(diagnostic_available)
        self.assertEqual(diag_summary.cycle_index.tolist(), [
            1, 2, 3, 4, 5, 36, 37, 38, 39, 40, 141, 142, 143, 144, 145, 246,
            247
        ])
        self.assertEqual(diag_summary.cycle_type.tolist(), [
            'reset', 'hppc', 'rpt_0.2C', 'rpt_1C', 'rpt_2C', 'reset', 'hppc',
            'rpt_0.2C', 'rpt_1C', 'rpt_2C', 'reset', 'hppc', 'rpt_0.2C',
            'rpt_1C', 'rpt_2C', 'reset', 'hppc'
        ])
        diag_interpolated = cycler_run.get_interpolated_diagnostic_cycles(
            diagnostic_available, resolution=1000)
        diag_cycle = diag_interpolated[
            (diag_interpolated.cycle_type == 'rpt_0.2C')
            & (diag_interpolated.step_type == 1)]
        self.assertEqual(diag_cycle.cycle_index.unique().tolist(),
                         [3, 38, 143])
        plt.figure()
        plt.plot(diag_cycle.discharge_capacity, diag_cycle.voltage)
        plt.savefig(
            os.path.join(TEST_FILE_DIR,
                         "discharge_capacity_interpolation.png"))
        plt.figure()
        plt.plot(diag_cycle.voltage, diag_cycle.discharge_dQdV)
        plt.savefig(
            os.path.join(TEST_FILE_DIR, "discharge_dQdV_interpolation.png"))

        self.assertEqual(len(diag_cycle.index), 3000)

        hppcs = diag_interpolated[(diag_interpolated.cycle_type == 'hppc')
                                  & pd.isnull(diag_interpolated.current)]
        self.assertEqual(len(hppcs), 0)

        hppc_dischg1 = diag_interpolated[
            (diag_interpolated.cycle_index == 37)
            & (diag_interpolated.step_type == 2)
            & (diag_interpolated.step_index_counter == 3)
            & ~pd.isnull(diag_interpolated.current)]
        print(hppc_dischg1)
        plt.figure()
        plt.plot(hppc_dischg1.test_time, hppc_dischg1.voltage)
        plt.savefig(os.path.join(TEST_FILE_DIR, "hppc_discharge_pulse_1.png"))
        self.assertEqual(len(hppc_dischg1), 176)

        processed_cycler_run = cycler_run.to_processed_cycler_run()
        self.assertNotIn(
            diag_summary.index.tolist(),
            processed_cycler_run.cycles_interpolated.cycle_index.unique())
        processed_cycler_run_loc = os.path.join(TEST_FILE_DIR,
                                                'processed_diagnostic.json')
        dumpfn(processed_cycler_run, processed_cycler_run_loc)
        proc_size = os.path.getsize(processed_cycler_run_loc)
        self.assertLess(proc_size, 29000000)
        test = loadfn(processed_cycler_run_loc)
        self.assertIsInstance(test.diagnostic_summary, pd.DataFrame)
        os.remove(processed_cycler_run_loc)
Example #37
def _load_yaml_config(filename):
    config = loadfn(os.path.join(MODULE_DIR, "%s.yaml" % filename))
    return config
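Callers would pass the yaml basename relative to MODULE_DIR; the file name below is hypothetical:

config = _load_yaml_config("MyInputSet")  # loads MODULE_DIR/MyInputSet.yaml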
Example #38
 def test_capacities_at_set_cycles(self):
     pcycler_run = loadfn(self.pcycler_run_file)
     capacities = pcycler_run.capacities_at_set_cycles()
     self.assertLessEqual(capacities.iloc[0, 0], 1.1)
Example #39
def test_json_file_mixin(diele_func_data, tmpdir):
    tmpdir.chdir()
    diele_func_data.to_json_file()
    actual = loadfn("diele_func_data.json")
    assert actual.diele_func_real == diele_func_data.diele_func_real
Example #40
This module provides some useful functions for dealing with magnetic Structures
(e.g. Structures with associated magmom tags).
"""

__author__ = "Matthew Horton"
__copyright__ = "Copyright 2017, The Materials Project"
__version__ = "0.1"
__maintainer__ = "Matthew Horton"
__email__ = "*****@*****.**"
__status__ = "Development"
__date__ = "Feb 2017"

MODULE_DIR = os.path.dirname(os.path.abspath(__file__))

try:
    DEFAULT_MAGMOMS = loadfn(os.path.join(MODULE_DIR, "default_magmoms.yaml"))
except Exception:
    warnings.warn(
        "Could not load default_magmoms.yaml, falling back to VASPIncarBase.yaml"
    )
    DEFAULT_MAGMOMS = loadfn(
        os.path.join(MODULE_DIR, "../../io/vasp/VASPIncarBase.yaml"))
    DEFAULT_MAGMOMS = DEFAULT_MAGMOMS["MAGMOM"]


@unique
class Ordering(Enum):
    """
    Enumeration defining possible magnetic orderings.
    """
Example #41
from monty.serialization import loadfn
from pymatgen.analysis.structure_matcher import StructureMatcher
"""
This module uses data from the AFLOW LIBRARY OF CRYSTALLOGRAPHIC PROTOTYPES.
If using this module, please cite their publication appropriately:

Mehl, M. J., Hicks, D., Toher, C., Levy, O., Hanson, R. M., Hart, G., & Curtarolo, S. (2017).
The AFLOW library of crystallographic prototypes: part 1.
Computational Materials Science, 136, S1-S828.
http://doi.org/10.1016/j.commatsci.2017.01.017
"""

module_dir = os.path.dirname(os.path.abspath(__file__))
AFLOW_PROTOTYPE_LIBRARY = loadfn(
    os.path.join(module_dir, "aflow_prototypes.json"))


class AflowPrototypeMatcher:
    """
    This class will match structures to their crystal prototypes, and will
    attempt to group species together to match structures derived from
    prototypes (e.g. an A_xB_1-x_C from a binary prototype), and will
    give these names the "-like" suffix.

    This class uses data from the AFLOW LIBRARY OF CRYSTALLOGRAPHIC PROTOTYPES.
    If using this class, please cite their publication appropriately:

    Mehl, M. J., Hicks, D., Toher, C., Levy, O., Hanson, R. M., Hart, G., & Curtarolo, S. (2017).
    The AFLOW library of crystallographic prototypes: part 1.
Example #42
    def run_interrupted(self):
        """
        Runs custodian in an interrupted mode, which sets up and
        validates jobs but doesn't run the executable

        Returns:
            number of remaining jobs

        Raises:
            ValidationError: if a job fails validation
            ReturnCodeError: if the process has a return code different from 0
            NonRecoverableError: if an unrecoverable error occurs
            MaxCorrectionsPerJobError: if max_errors_per_job is reached
            MaxCorrectionsError: if max_errors is reached
            MaxCorrectionsPerHandlerError: if max_errors_per_handler is reached
        """
        start = datetime.datetime.now()
        try:
            cwd = os.getcwd()
            v = sys.version.replace("\n", " ")
            logger.info(
                "Custodian started in singleshot mode at {} in {}.".format(
                    start, cwd))
            logger.info("Custodian running on Python version {}".format(v))

            # load run log
            if os.path.exists(Custodian.LOG_FILE):
                self.run_log = loadfn(Custodian.LOG_FILE, cls=MontyDecoder)

            if len(self.run_log) == 0:
                # starting up an initial job - setup input and quit
                job_n = 0
                job = self.jobs[job_n]
                logger.info("Setting up job no. 1 ({}) ".format(job.name))
                job.setup()
                self.run_log.append({
                    "job": job.as_dict(),
                    "corrections": [],
                    "job_n": job_n
                })
                return len(self.jobs)

            # Continuing after running calculation
            job_n = self.run_log[-1]["job_n"]
            job = self.jobs[job_n]

            # If we had to fix errors from a previous run, insert clean log
            # dict
            if len(self.run_log[-1]["corrections"]) > 0:
                logger.info("Reran {}.run due to fixable errors".format(
                    job.name))

            # check error handlers
            logger.info("Checking error handlers for {}.run".format(job.name))
            if self._do_check(self.handlers):
                logger.info("Failed validation based on error handlers")
                # raise an error for an unrecoverable error
                for x in self.run_log[-1]["corrections"]:
                    if not x["actions"] and x["handler"].raises_runtime_error:
                        self.run_log[-1]["handler"] = x["handler"]
                        s = "Unrecoverable error for handler: {}. " "Raising RuntimeError".format(
                            x["handler"])
                        raise NonRecoverableError(s, True, x["handler"])
                logger.info("Corrected input based on error handlers")
                # Return with more jobs to run if recoverable error caught
                # and corrected for
                return len(self.jobs) - job_n

            # check validators
            logger.info("Checking validator for {}.run".format(job.name))
            for v in self.validators:
                if v.check():
                    self.run_log[-1]["validator"] = v
                    logger.info("Failed validation based on validator")
                    s = "Validation failed: {}".format(v)
                    raise ValidationError(s, True, v)

            logger.info("Postprocessing for {}.run".format(job.name))
            job.postprocess()

            # IF DONE WITH ALL JOBS - DELETE ALL CHECKPOINTS AND RETURN
            # VALIDATED
            if len(self.jobs) == (job_n + 1):
                self.finished = True
                return 0

            # Setup next job_n
            job_n += 1
            job = self.jobs[job_n]
            self.run_log.append({
                "job": job.as_dict(),
                "corrections": [],
                "job_n": job_n
            })
            job.setup()
            return len(self.jobs) - job_n

        except CustodianError as ex:
            logger.error(ex.message)
            if ex.raises:
                raise

        finally:
            # Log the corrections to a json file.
            logger.info("Logging to {}...".format(Custodian.LOG_FILE))
            dumpfn(self.run_log,
                   Custodian.LOG_FILE,
                   cls=MontyEncoder,
                   indent=4)
            end = datetime.datetime.now()
            logger.info("Run ended at {}.".format(end))
            run_time = end - start
            logger.info(
                "Run completed. Total time taken = {}.".format(run_time))
            if self.finished and self.gzipped_output:
                gzip_dir(".")
        return None
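A schematic driver loop for the interrupted mode described in the docstring; the job list and the external execution step are illustrative, not taken from custodian's docs:

import subprocess

from custodian.custodian import Custodian
from custodian.vasp.jobs import VaspJob

c = Custodian([], [VaspJob(vasp_cmd=["vasp_std"])])
while True:
    remaining = c.run_interrupted()  # sets up / validates, never executes
    if not remaining:
        break
    # run the actual executable out-of-band (e.g. via a scheduler)
    subprocess.check_call(["vasp_std"])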
Example #43
import os
import shutil
import unittest
import tempfile

from monty.os.path import which
from monty.serialization import loadfn
from sklearn.linear_model import LinearRegression

from maml.apps.pes import SNAPotential
from maml import SKLModel
from maml.describer import BispectrumCoefficients

CWD = os.getcwd()
DIR = os.path.abspath(os.path.dirname(__file__))
test_datapool = loadfn(os.path.join(DIR, 'datapool.json'))
coeff_file = os.path.join(DIR, 'SNAP', 'SNAPotential.snapcoeff')
param_file = os.path.join(DIR, 'SNAP', 'SNAPotential.snapparam')


@unittest.skipIf(not which('lmp_serial'), 'No LAMMPS cmd found.')
class SNAPotentialTest(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.this_dir = os.path.dirname(os.path.abspath(__file__))
        cls.test_dir = tempfile.mkdtemp()
        os.chdir(cls.test_dir)

    @classmethod
    def tearDownClass(cls):
        os.chdir(CWD)
Example #44
 def test_spin_polarization(self):
     dos_path = os.path.join(test_dir, "dos_spin_polarization_mp-865805.json")
     dos = loadfn(dos_path)
     self.assertAlmostEqual(dos.spin_polarization, 0.6460514663341762)
Example #45
def update_checkpoint(job_ids=None, jfile=None, **kwargs):
    """
    Rerun the jobs whose ids are in the job_ids list. The jobs are
    read from the json checkpoint file, jfile.
    If no job_ids are given, the checkpoint file is simply updated
    with the corresponding final energies.

    Args:
        job_ids: list of job ids to update or resolve
        jfile: checkpoint file
    """
    cal_log = loadfn(jfile, cls=MontyDecoder)
    cal_log_new = []
    all_jobs = []
    run_jobs = []
    handlers = []
    final_energy = None
    incar = None
    kpoints = None
    qadapter = None
    # if updating the specs of the job
    for k, v in kwargs.items():
        if k == 'incar':
            incar = v
        if k == 'kpoints':
            kpoints = v
        if k == 'que':
            qadapter = v
    for j in cal_log:
        job = j["job"]
        job.job_id = j['job_id']
        all_jobs.append(job)
        if job_ids and (j['job_id'] in job_ids or job.job_dir in job_ids):
            logger.info('setting job {0} in {1} to rerun'.format(j['job_id'],
                                                                 job.job_dir))
            contcar_file = job.job_dir + os.sep + 'CONTCAR'
            poscar_file = job.job_dir + os.sep + 'POSCAR'
            if os.path.isfile(contcar_file) and len(
                    open(contcar_file).readlines()) != 0:
                logger.info('setting poscar file from {}'
                            .format(contcar_file))
                job.vis.poscar = Poscar.from_file(contcar_file)
            else:
                logger.info('setting poscar file from {}'
                            .format(poscar_file))
                job.vis.poscar = Poscar.from_file(poscar_file)
            if incar:
                logger.info('incar overridden')
                job.vis.incar = incar
            if kpoints:
                logger.info('kpoints overridden')
                job.vis.kpoints = kpoints
            if qadapter:
                logger.info('qadapter overridden')
                job.vis.qadapter = qadapter
            run_jobs.append(job)
    if run_jobs:
        c = Custodian(handlers, run_jobs, max_errors=5)
        c.run()
    for j in all_jobs:
        final_energy = j.get_final_energy()
        cal_log_new.append({"job": j.as_dict(),
                            'job_id': j.job_id,
                            "corrections": [],
                            'final_energy': final_energy})
    dumpfn(cal_log_new, jfile, cls=MontyEncoder, indent=4)
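
A usage sketch (the checkpoint filename and job ids below are hypothetical placeholders, not from the original source): rerun two checkpointed jobs with an overridden INCAR.

from pymatgen.io.vasp.inputs import Incar

# override ENCUT/EDIFF for the rerun; any job not listed is only re-read
# so its final energy can be written back to the checkpoint
new_incar = Incar({"ENCUT": 600, "EDIFF": 1e-6})
update_checkpoint(job_ids=["1001", "1002"], jfile="checkpoint.json",
                  incar=new_incar)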
Example #46
import os
import warnings

from pymatgen.matproj.rest import MPRester

from monty.serialization import loadfn

__author__ = "Kiran Mathew, Joshua J. Gabriel, Michael Ashton, " \
             "Arunima K. Singh, Joshua T. Paul, Seve G. Monahan, " \
             "Richard G. Hennig"
__date__ = "March 3 2017"
__version__ = "1.7.0"

PACKAGE_PATH = os.path.dirname(__file__)

try:
    MPINT_CONFIG = loadfn(os.path.join(PACKAGE_PATH, 'mpint_config.yaml'))
except:
    MPINT_CONFIG = {}
    warnings.warn('mpint_config.yaml file not configured.')

# set environ variables for MAPI_KEY and VASP_PSP_DIR
if MPINT_CONFIG.get('potentials', ''):
    os.environ['PMG_VASP_PSP_DIR'] = MPINT_CONFIG.get('potentials', '')
MP_API = MPINT_CONFIG.get('mp_api', '')
if MP_API:
    os.environ['MAPI_KEY'] = MP_API

MPR = MPRester(MP_API)
USERNAME = MPINT_CONFIG.get('username', None)
VASP_STD_BIN = MPINT_CONFIG.get('normal_binary', None)
VASP_TWOD_BIN = MPINT_CONFIG.get('twod_binary', None)
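
For reference, a minimal sketch of the keys this module reads from mpint_config.yaml, written here with monty's dumpfn; every value is a placeholder, not from the original source.

from monty.serialization import dumpfn

dumpfn(
    {
        "mp_api": "YOUR_MAPI_KEY",       # exported as MAPI_KEY
        "potentials": "/path/to/psps",   # exported as PMG_VASP_PSP_DIR
        "username": "hpc_user",
        "normal_binary": "vasp_std",
        "twod_binary": "vasp_twod",
    },
    "mpint_config.yaml",
)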
Example #47
    """
    Create a HTML table from a list of elements.
    :param rows: list of list of cell contents
    :return: html.Table
    """
    contents = []
    for row in rows:
        contents.append(html.Tr([html.Td(item) for item in row]))
    if not header:
        return html.Table([html.Tbody(contents)], className="table")
    else:
        header = html.Thead([html.Tr([html.Th(item) for item in header])])
        return html.Table([header, html.Tbody(contents)], className="table")


DOI_CACHE = loadfn(MODULE_PATH / "apps/assets/doi_cache.json")


def cite_me(doi: str = None,
            manual_ref: str = None,
            cite_text: str = "Cite me") -> html.Div:
    """
    Create a button to show users how to cite a particular resource.
    :param doi: DOI
    :param manual_ref: If DOI not available
    :param cite_text: Text to show as button label
    :return: A button
    """
    if doi:
        if doi in DOI_CACHE:
            ref = DOI_CACHE[doi]
Example #48
"""

__author__ = "Joseph Montoya"
__copyright__ = "Copyright 2017, The Materials Project"
__credits__ = ("Maarten de Jong, Shyam Dwaraknath, Wei Chen, "
               "Mark Asta, Anubhav Jain, Terence Lew")
__version__ = "1.0"
__maintainer__ = "Joseph Montoya"
__email__ = "*****@*****.**"
__status__ = "Production"
__date__ = "July 24, 2018"

voigt_map = [(0, 0), (1, 1), (2, 2), (1, 2), (0, 2), (0, 1)]
reverse_voigt_map = np.array([[0, 5, 4], [5, 1, 3], [4, 3, 2]])

DEFAULT_QUAD = loadfn(os.path.join(os.path.dirname(__file__),
                                   "quad_data.json"))


class Tensor(np.ndarray, MSONable):
    """
    Base class for doing useful general operations on Nth order tensors,
    without restrictions on the type (stress, elastic, strain, piezo, etc.)
    """
    symbol = "T"

    def __new__(cls, input_array, vscale=None, check_rank=None):
        """
        Create a Tensor object.  Note that the constructor uses __new__
        rather than __init__ according to the standard method of
        subclassing numpy ndarrays.
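
As an aside, a minimal self-contained sketch of the __new__-based ndarray subclassing pattern the docstring above refers to (MiniTensor is an illustrative name, not part of pymatgen):

import numpy as np

class MiniTensor(np.ndarray):
    def __new__(cls, input_array):
        # view-cast the input to this subclass; ndarray subclasses are
        # created in __new__, before __init__ would run
        return np.asarray(input_array).view(cls)

t = MiniTensor([[1, 0], [0, 1]])
assert isinstance(t, np.ndarray) and t.shape == (2, 2)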
Example #49
    return lattice.get_cartesian_coords([c - np.floor(c) for c in fc])


def reorient_z(structure):
    """
    reorients a structure such that the z axis is concurrent with the
    normal to the A-B plane
    """
    struct = structure.copy()
    sop = get_rot(struct)
    struct.apply_operation(sop)
    return struct


# Get color dictionary
colors = loadfn(
    os.path.join(os.path.dirname(vis.__file__), "ElementColorSchemes.yaml"))
color_dict = {
    el: [j / 256.001 for j in colors["Jmol"][el]]
    for el in colors["Jmol"].keys()
}


def plot_slab(slab,
              ax,
              scale=0.8,
              repeat=5,
              window=1.5,
              draw_unit_cell=True,
              decay=0.2,
              adsorption_sites=True):
    """
def test_totalnumber():
    AFLOW_PROTOTYPE_LIBRARY = loadfn(
        os.path.join(os.path.dirname(parse_proto.__file__),
                     "aflow_prototype_db.json"))
    assert (len(AFLOW_PROTOTYPE_LIBRARY) == 590)
Example #51
import dash

from crystal_toolkit.helpers.layouts import *
import crystal_toolkit.components as ctc

from pymatgen import MPRester, Structure, Molecule
from pymatgen.analysis.graphs import StructureGraph, MoleculeGraph
from pymatgen import __version__ as pmg_version

from json import loads
from uuid import uuid4
from urllib import parse
from random import choice
from ast import literal_eval
from monty.serialization import loadfn

# choose a default structure on load
DEFAULT_MPIDS = loadfn("task_ids_on_load.json")

################################################################################
# region SET UP APP
################################################################################

meta_tags = [
    {
        "name": "description",
        "content": "Crystal Toolkit allows you to import, view, analyze and transform "
        "crystal structures and molecules using the full power of the Materials "
        "Project.",
    }
]

crystal_toolkit_app = dash.Dash(__name__, meta_tags=meta_tags)
Example #52
def test_get_cycle_life(self):
    pcycler_run = loadfn(self.pcycler_run_file)
    self.assertEqual(pcycler_run.get_cycle_life(30, 0.99), 82)
    self.assertEqual(pcycler_run.get_cycle_life(), 189)
Example #53
def _get_symm_data(name):
    global SYMM_DATA
    if SYMM_DATA is None:
        SYMM_DATA = loadfn(
            os.path.join(os.path.dirname(__file__), "symm_data.json"))
    return SYMM_DATA[name]
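
A usage sketch: the keys available through _get_symm_data come from the bundled symm_data.json; the two shown here are the ones other snippets in this document use.

encoding = _get_symm_data("space_group_encoding")
max_subgroups = _get_symm_data("maximal_subgroups")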
Example #54
import logging
import re
from os import path
from urllib.parse import parse_qs, urlsplit

from propnet.core.symbols import Symbol
from propnet.core.models import Model

from monty.serialization import loadfn
import networkx as nx

from propnet.core.graph import Graph
# noinspection PyUnresolvedReferences
import propnet.models
from propnet.core.registry import Registry

log = logging.getLogger(__name__)

GRAPH_LAYOUT_CONFIG = loadfn(path.join(path.dirname(__file__), 'graph_layout_config.yaml'))
GRAPH_STYLESHEET = loadfn(path.join(path.dirname(__file__), 'graph_stylesheet.yaml'))
GRAPH_SETTINGS = loadfn(path.join(path.dirname(__file__), 'graph_settings.yaml'))

GRAPH_HEIGHT_PX = re.match(r'^([0-9]+)[^0-9]*',
                           GRAPH_SETTINGS['full_view']['style']['height']).group(1)
SUBGRAPH_HEIGHT_PX = re.match(r'^([0-9]+)[^0-9]*',
                              GRAPH_SETTINGS['model_symbol_view']['style']['height']).group(1)

propnet_nx_graph = Graph().get_networkx_graph()


# TODO: use the attributes of the graph class, rather than networkx
def graph_conversion(graph: nx.DiGraph,
                     derivation_pathway=None,
                     hide_unconnected_nodes=True,
Example #55
# dos and bs data from local jsons
from monty.serialization import loadfn
import os

import dash

# assumed import paths for names used below (the original snippet is
# truncated above this point)
from crystal_toolkit.components import PhononBandstructureAndDosComponent
from crystal_toolkit.helpers.layouts import Container, H1
from crystal_toolkit.settings import SETTINGS

# create Dash app as normal, assets folder set for visual styles only
app = dash.Dash(assets_folder=SETTINGS.ASSETS_PATH)

# If callbacks are created dynamically, they cannot be statically checked at
# app startup. For this simple example this is not a problem, but if creating
# a complicated, nested layout this will need to be enabled -- consult the
# Dash documentation for more information.
# app.config["suppress_callback_exceptions"] = True

path = os.path.dirname(os.path.realpath(__file__))
bandstructure_symm_line = loadfn(path + "/BaTiO3_ph_bs.json")
density_of_states = loadfn(path + "/BaTiO3_ph_dos.json")

# create the Crystal Toolkit component
bsdos_component = PhononBandstructureAndDosComponent(
    bandstructure_symm_line=bandstructure_symm_line,
    density_of_states=density_of_states,
    id="ph_bs_dos",
)

# example layout to demonstrate capabilities of component
my_layout = Container([
    H1("Phonon Band Structure and Density of States Example"),
    bsdos_component.layout(),
])
Example #56
    def from_db_file(cls, db_file, admin=True):
        """
        Create MMDB from database file. File requires host, port, database,
        collection, and optionally admin_user/readonly_user and
        admin_password/readonly_password

        Args:
            db_file (str): path to the file containing the credentials
            admin (bool): whether to use the admin user

        Returns:
            MMDb object
        """
        creds = loadfn(db_file)

        maggma_kwargs = creds.get("maggma_store", {})
        maggma_prefix = creds.get("maggma_store_prefix", "atomate")
        database = creds.get("database", None)

        kwargs = creds.get(
            "mongoclient_kwargs", {}
        )  # any other MongoClient kwargs can go here ...
        if "host_uri" in creds:
            return cls(
                host_uri=creds["host_uri"],
                database=database,
                collection=creds["collection"],
                maggma_store_kwargs=maggma_kwargs,
                maggma_store_prefix=maggma_prefix,
                **kwargs,
            )

        if admin and "admin_user" not in creds and "readonly_user" in creds:
            raise ValueError(
                "Trying to use admin credentials, "
                "but no admin credentials are defined. "
                "Use admin=False if only read_only "
                "credentials are available."
            )

        if admin:
            user = creds.get("admin_user", "")
            password = creds.get("admin_password", "")
        else:
            user = creds.get("readonly_user", "")
            password = creds.get("readonly_password", "")

        if "authsource" in creds:
            kwargs["authsource"] = creds["authsource"]
        else:
            kwargs["authsource"] = creds["database"]

        return cls(
            host=creds["host"],
            port=int(creds.get("port", 27017)),
            database=creds["database"],
            collection=creds["collection"],
            user=user,
            password=password,
            maggma_store_kwargs=maggma_kwargs,
            maggma_store_prefix=maggma_prefix,
            **kwargs,
        )
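
A minimal sketch of the credentials file from_db_file expects; the values are placeholders and the filename db.json is an assumption, not from the original source.

from monty.serialization import dumpfn

dumpfn(
    {
        "host": "localhost",
        "port": 27017,
        "database": "vasp",
        "collection": "tasks",
        "admin_user": "admin",      # used when admin=True
        "admin_password": "secret",
        # "authsource": "admin",    # optional; defaults to the database name
    },
    "db.json",
)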
Example #57
class PymatgenTest(unittest.TestCase):
    """
    Extends unittest.TestCase with functions (taken from numpy.testing.utils)
    that support the comparison of arrays.
    """
    _multiprocess_shared_ = True
    MODULE_DIR = Path(__file__).absolute().parent
    STRUCTURES_DIR = MODULE_DIR / "structures"
    TEST_FILES_DIR = MODULE_DIR / ".." / ".." / "test_files"
    """
    Dict for test structures to aid testing.
    """
    TEST_STRUCTURES = {}
    for fn in STRUCTURES_DIR.iterdir():
        TEST_STRUCTURES[fn.name.rsplit(".", 1)[0]] = loadfn(str(fn))

    @classmethod
    def get_structure(cls, name):
        """
        Get a structure from the template directories.

        :param name: Name of a structure.
        :return: Structure
        """
        return cls.TEST_STRUCTURES[name].copy()

    @classmethod
    @requires(SETTINGS.get("PMG_MAPI_KEY"), "PMG_MAPI_KEY needs to be set.")
    def get_mp_structure(cls, mpid):
        """
        Get a structure from MP.

        :param mpid: Materials Project id.
        :return: Structure
        """
        m = MPRester()
        return m.get_structure_by_material_id(mpid)

    @staticmethod
    def assertArrayAlmostEqual(actual, desired, decimal=7, err_msg='',
                               verbose=True):
        """
        Tests if two arrays are almost equal to a tolerance. The CamelCase
        naming is so that it is consistent with standard unittest methods.
        """
        return nptu.assert_almost_equal(actual, desired, decimal, err_msg,
                                        verbose)

    @staticmethod
    def assertDictsAlmostEqual(actual, desired, decimal=7, err_msg='',
                               verbose=True):
        """
        Tests if two dicts are almost equal to a tolerance, recursing into
        nested values. The CamelCase naming is so that it is consistent with
        standard unittest methods.
        """

        for k, v in actual.items():
            if k not in desired:
                return False
            v2 = desired[k]
            if isinstance(v, dict):
                pass_test = PymatgenTest.assertDictsAlmostEqual(
                    v, v2, decimal=decimal, err_msg=err_msg, verbose=verbose)
                if not pass_test:
                    return False
            elif isinstance(v, (list, tuple)):
                # nptu.assert_almost_equal raises AssertionError on mismatch
                # rather than returning a boolean, so catch it explicitly
                try:
                    nptu.assert_almost_equal(v, v2, decimal, err_msg, verbose)
                except AssertionError:
                    return False
            elif isinstance(v, (int, float)):
                if round(abs(v - v2), decimal) != 0:
                    return False
            else:
                assert v == v2
        return True

    @staticmethod
    def assertArrayEqual(actual, desired, err_msg='', verbose=True):
        """
        Tests if two arrays are equal. The CamelCase naming is so that it is
         consistent with standard unittest methods.
        """
        return nptu.assert_equal(actual, desired, err_msg=err_msg,
                                 verbose=verbose)

    @staticmethod
    def assertStrContentEqual(actual, desired, err_msg='', verbose=True):
        """
        Tests if two strings are equal, ignoring things like trailing spaces,
        etc.
        """
        lines1 = actual.split("\n")
        lines2 = desired.split("\n")
        if len(lines1) != len(lines2):
            return False
        failed = []
        for l1, l2 in zip(lines1, lines2):
            if l1.strip() != l2.strip():
                failed.append("%s != %s" % (l1, l2))
        return len(failed) == 0

    def serialize_with_pickle(self, objects, protocols=None, test_eq=True):
        """
        Test whether the object(s) can be serialized and deserialized with
        pickle. This method tries to serialize the objects with pickle and the
        protocols specified in input. Then it deserializes the pickle format
        and compares the two objects with the __eq__ operator if
        test_eq == True.

        Args:
            objects: Object or list of objects.
            protocols: List of pickle protocols to test. If protocols is None,
                HIGHEST_PROTOCOL is tested.

        Returns:
            Nested list with the objects deserialized with the specified
            protocols.
        """
        # Use the python version so that we get the traceback in case of errors
        import pickle
        from pymatgen.util.serialization import pmg_pickle_load, pmg_pickle_dump

        # Build a list even when we receive a single object.
        got_single_object = False
        if not isinstance(objects, (list, tuple)):
            got_single_object = True
            objects = [objects]

        if protocols is None:
            # protocols = set([0, 1, 2] + [pickle.HIGHEST_PROTOCOL])
            protocols = [pickle.HIGHEST_PROTOCOL]

        # This list will contain the objects deserialized with the different
        # protocols.
        objects_by_protocol, errors = [], []

        for protocol in protocols:
            # Serialize and deserialize the object.
            mode = "wb"
            fd, tmpfile = tempfile.mkstemp(text="b" not in mode)

            try:
                with open(tmpfile, mode) as fh:
                    pmg_pickle_dump(objects, fh, protocol=protocol)
            except Exception as exc:
                errors.append("pickle.dump with protocol %s raised:\n%s" %
                              (protocol, str(exc)))
                continue

            try:
                with open(tmpfile, "rb") as fh:
                    new_objects = pmg_pickle_load(fh)
            except Exception as exc:
                errors.append("pickle.load with protocol %s raised:\n%s" %
                              (protocol, str(exc)))
                continue

            # Test for equality
            if test_eq:
                for old_obj, new_obj in zip(objects, new_objects):
                    self.assertEqual(old_obj, new_obj)

            # Save the deserialized objects.
            objects_by_protocol.append(new_objects)

        if errors:
            raise ValueError("\n".join(errors))

        # Return nested list so that client code can perform additional tests.
        if got_single_object:
            return [o[0] for o in objects_by_protocol]
        return objects_by_protocol
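
    # Usage sketch inside a TestCase (my_obj is a hypothetical picklable,
    # MSONable object):
    #     new_objs = self.serialize_with_pickle(my_obj)
    #     # one entry per protocol tested; equality is asserted internally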

    def assertMSONable(self, obj, test_if_subclass=True):
        """
        Tests if obj is MSONable and tries to verify whether the contract is
        fulfilled.

        By default, the method tests whether obj is an instance of MSONable.
        This check can be deactivated by setting test_if_subclass to False.
        """
        if test_if_subclass:
            self.assertIsInstance(obj, MSONable)
        self.assertDictEqual(obj.as_dict(), obj.__class__.from_dict(
            obj.as_dict()).as_dict())
        json.loads(obj.to_json(), cls=MontyDecoder)
Example #58
def test_cycles_to_reach_set_capacities(self):
    pcycler_run = loadfn(self.pcycler_run_file)
    cycles = pcycler_run.cycles_to_reach_set_capacities()
    self.assertGreaterEqual(cycles.iloc[0, 0], 100)
Example #59
# coding: utf-8
# Copyright (c) Materials Virtual Lab
# Distributed under the terms of the BSD License.

import os
import shutil
import tempfile

import unittest
import numpy as np
from monty.serialization import loadfn
from maml.utils._data_conversion import convert_docs, pool_from

CWD = os.getcwd()
test_datapool = loadfn(
    os.path.join(os.path.dirname(__file__),
                 "../../apps/pes/tests/datapool.json"))


class ProcessingTest(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.this_dir = os.path.dirname(os.path.abspath(__file__))
        cls.test_dir = tempfile.mkdtemp()
        os.chdir(cls.test_dir)

    @classmethod
    def tearDownClass(cls):
        os.chdir(CWD)
        shutil.rmtree(cls.test_dir)
Example #60
class SpaceGroup(SymmetryGroup):
    """
    Class representing a SpaceGroup.

    .. attribute:: symbol

        Full International or Hermann-Mauguin Symbol.

    .. attribute:: int_number

        International number

    .. attribute:: generators

        List of generator matrices. Note that 4x4 matrices are used for Space
        Groups.

    .. attribute:: order

        Order of Space Group
    """

    SYMM_OPS = loadfn(os.path.join(os.path.dirname(__file__), "symm_ops.json"))
    SG_SYMBOLS = set(_get_symm_data("space_group_encoding").keys())
    for op in SYMM_OPS:
        op["hermann_mauguin"] = re.sub(r" ", "", op["hermann_mauguin"])
        op["universal_h_m"] = re.sub(r" ", "", op["universal_h_m"])
        SG_SYMBOLS.add(op["hermann_mauguin"])
        SG_SYMBOLS.add(op["universal_h_m"])

    gen_matrices = _get_symm_data("generator_matrices")
    # POINT_GROUP_ENC = SYMM_DATA["point_group_encoding"]
    sgencoding = _get_symm_data("space_group_encoding")
    abbrev_sg_mapping = _get_symm_data("abbreviated_spacegroup_symbols")
    translations = {
        k: Fraction(v)
        for k, v in _get_symm_data("translations").items()
    }
    full_sg_mapping = {
        v["full_symbol"]: k
        for k, v in _get_symm_data("space_group_encoding").items()
    }

    def __init__(self, int_symbol: str) -> None:
        """
        Initializes a Space Group from its full or abbreviated international
        symbol. Only standard settings are supported.

        Args:
            int_symbol (str): Full International (e.g., "P2/m2/m2/m") or
                Hermann-Mauguin Symbol ("Pmmm") or abbreviated symbol. The
                notation is a LaTeX-like string, with screw axes being
                represented by an underscore. For example, "P6_3/mmc".
                Alternative settings can be accessed by adding a ":identifier".
                For example, the hexagonal setting for rhombohedral cells can be
                accessed by adding a ":H", e.g., "R-3m:H". To find out all
                possible settings for a spacegroup, use the get_settings()
                classmethod. Alternative origin choices can be indicated by a
                translation vector, e.g., 'Fm-3m(a-1/4,b-1/4,c-1/4)'.
        """
        from pymatgen.core.operations import SymmOp

        int_symbol = re.sub(r" ", "", int_symbol)
        if int_symbol in SpaceGroup.abbrev_sg_mapping:
            int_symbol = SpaceGroup.abbrev_sg_mapping[int_symbol]
        elif int_symbol in SpaceGroup.full_sg_mapping:
            int_symbol = SpaceGroup.full_sg_mapping[int_symbol]

        self._symmetry_ops: set[SymmOp] | None

        for spg in SpaceGroup.SYMM_OPS:
            if int_symbol in [spg["hermann_mauguin"], spg["universal_h_m"]]:
                ops = [SymmOp.from_xyz_string(s) for s in spg["symops"]]
                self.symbol = re.sub(r":", "",
                                     re.sub(r" ", "", spg["universal_h_m"]))
                if int_symbol in SpaceGroup.sgencoding:
                    self.full_symbol = SpaceGroup.sgencoding[int_symbol][
                        "full_symbol"]
                    self.point_group = SpaceGroup.sgencoding[int_symbol][
                        "point_group"]
                else:
                    self.full_symbol = re.sub(r" ", "", spg["universal_h_m"])
                    self.point_group = spg["schoenflies"]
                self.int_number = spg["number"]
                self.order = len(ops)
                self._symmetry_ops = {*ops}
                break
        else:
            if int_symbol not in SpaceGroup.sgencoding:
                raise ValueError(f"Bad international symbol {int_symbol}")

            data = SpaceGroup.sgencoding[int_symbol]

            self.symbol = int_symbol
            # TODO: Support different origin choices.
            enc = list(data["enc"])
            inversion = int(enc.pop(0))
            ngen = int(enc.pop(0))
            symm_ops = [np.eye(4)]
            if inversion:
                symm_ops.append(
                    np.array([[-1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0],
                              [0, 0, 0, 1]]))
            for i in range(ngen):
                m = np.eye(4)
                m[:3, :3] = SpaceGroup.gen_matrices[enc.pop(0)]
                m[0, 3] = SpaceGroup.translations[enc.pop(0)]
                m[1, 3] = SpaceGroup.translations[enc.pop(0)]
                m[2, 3] = SpaceGroup.translations[enc.pop(0)]
                symm_ops.append(m)
            self.generators = symm_ops
            self.full_symbol = data["full_symbol"]
            self.point_group = data["point_group"]
            self.int_number = data["int_number"]
            self.order = data["order"]

            self._symmetry_ops = None

    def _generate_full_symmetry_ops(self) -> list[SymmOp]:
        symm_ops = np.array(self.generators)
        for op in symm_ops:
            op[0:3, 3] = np.mod(op[0:3, 3], 1)
        new_ops = symm_ops
        while len(new_ops) > 0 and len(symm_ops) < self.order:
            gen_ops = []
            for g in new_ops:
                temp_ops = np.einsum("ijk,kl", symm_ops, g)
                for op in temp_ops:
                    op[0:3, 3] = np.mod(op[0:3, 3], 1)
                    ind = np.where(np.abs(1 - op[0:3, 3]) < 1e-5)
                    op[ind, 3] = 0
                    if not in_array_list(symm_ops, op):
                        gen_ops.append(op)
                        symm_ops = np.append(symm_ops, [op], axis=0)
            new_ops = gen_ops
        assert len(symm_ops) == self.order
        return symm_ops

    @classmethod
    def get_settings(cls, int_symbol: str) -> set[str]:
        """
        Returns all the settings for a particular international symbol.

        Args:
            int_symbol (str): Full International (e.g., "P2/m2/m2/m") or
                Hermann-Mauguin Symbol ("Pmmm") or abbreviated symbol. The
                notation is a LaTeX-like string, with screw axes being
                represented by an underscore. For example, "P6_3/mmc".

        Returns:
            set[str]: All possible settings for the given international symbol.
        """
        symbols = []
        if int_symbol in SpaceGroup.abbrev_sg_mapping:
            symbols.append(SpaceGroup.abbrev_sg_mapping[int_symbol])
            int_number = SpaceGroup.sgencoding[int_symbol]["int_number"]
        elif int_symbol in SpaceGroup.full_sg_mapping:
            symbols.append(SpaceGroup.full_sg_mapping[int_symbol])
            int_number = SpaceGroup.sgencoding[int_symbol]["int_number"]
        else:
            for spg in SpaceGroup.SYMM_OPS:
                if int_symbol in [
                        re.split(r"\(|:", spg["hermann_mauguin"])[0],
                        re.split(r"\(|:", spg["universal_h_m"])[0],
                ]:
                    int_number = spg["number"]
                    break

        for spg in SpaceGroup.SYMM_OPS:
            if int_number == spg["number"]:
                symbols.append(spg["hermann_mauguin"])
                symbols.append(spg["universal_h_m"])
        return set(symbols)
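
    # Usage sketch (hedged: the exact strings come from the bundled
    # symm_ops.json data):
    #     SpaceGroup.get_settings("R-3m")
    #     # expected to include the hexagonal ("R-3m:H") and
    #     # rhombohedral ("R-3m:R") settings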

    @property
    def symmetry_ops(self) -> set[SymmOp]:
        """
        Full set of symmetry operations as matrices. Lazily initialized as
        generation sometimes takes a bit of time.
        """
        from pymatgen.core.operations import SymmOp

        if self._symmetry_ops is None:
            self._symmetry_ops = {
                SymmOp(m)
                for m in self._generate_full_symmetry_ops()
            }
        return self._symmetry_ops

    def get_orbit(self, p: ArrayLike, tol: float = 1e-5) -> list[ArrayLike]:
        """
        Returns the orbit for a point.

        Args:
            p: Point as a 3x1 array.
            tol: Tolerance for determining if sites are the same. 1e-5 should
                be sufficient for most purposes. Set to 0 for exact matching
                (and also needed for symbolic orbits).

        Returns:
            ([array]) Orbit for point.
        """
        orbit: list[ArrayLike] = []
        for o in self.symmetry_ops:
            pp = o.operate(p)
            pp = np.mod(np.round(pp, decimals=10), 1)
            if not in_array_list(orbit, pp, tol=tol):
                orbit.append(pp)
        return orbit

    def is_compatible(self,
                      lattice: Lattice,
                      tol: float = 1e-5,
                      angle_tol: float = 5) -> bool:
        """
        Checks whether a particular lattice is compatible with the
        *conventional* unit cell.

        Args:
            lattice (Lattice): A Lattice.
            tol (float): The tolerance to check for equality of lengths.
            angle_tol (float): The tolerance to check for equality of angles
                in degrees.
        """
        abc = lattice.lengths
        angles = lattice.angles
        crys_system = self.crystal_system

        def check(param, ref, tolerance):
            return all(
                abs(i - j) < tolerance for i, j in zip(param, ref)
                if j is not None)

        if crys_system == "cubic":
            a = abc[0]
            return check(abc, [a, a, a], tol) and check(
                angles, [90, 90, 90], angle_tol)
        if crys_system == "hexagonal" or (crys_system == "trigonal" and
                                          (self.symbol.endswith("H")
                                           or self.int_number in [
                                               143, 144, 145, 147, 149, 150,
                                               151, 152, 153, 154, 156, 157,
                                               158, 159, 162, 163, 164, 165,
                                           ])):
            a = abc[0]
            return check(abc, [a, a, None], tol) and check(
                angles, [90, 90, 120], angle_tol)
        if crys_system == "trigonal":
            a = abc[0]
            alpha = angles[0]
            return check(abc, [a, a, a], tol) and check(
                angles, [alpha, alpha, alpha], angle_tol)
        if crys_system == "tetragonal":
            a = abc[0]
            return check(abc, [a, a, None], tol) and check(
                angles, [90, 90, 90], angle_tol)
        if crys_system == "orthorhombic":
            return check(angles, [90, 90, 90], angle_tol)
        if crys_system == "monoclinic":
            return check(angles, [90, None, 90], angle_tol)
        return True

    @property
    def crystal_system(
        self,
    ) -> Literal["cubic", "hexagonal", "trigonal", "tetragonal",
                 "orthorhombic", "monoclinic", "triclinic"]:
        """
        Returns:
            str: Crystal system of the space group, e.g., cubic, hexagonal, etc.
        """
        i = self.int_number
        if i <= 2:
            return "triclinic"
        if i <= 15:
            return "monoclinic"
        if i <= 74:
            return "orthorhombic"
        if i <= 142:
            return "tetragonal"
        if i <= 167:
            return "trigonal"
        if i <= 194:
            return "hexagonal"
        return "cubic"

    def is_subgroup(self, supergroup: SymmetryGroup) -> bool:
        """
        True if this space group is a subgroup of the supplied group.

        Args:
            supergroup (SpaceGroup): Supergroup to test.

        Returns:
            True if this space group is a subgroup of the supplied group.
        """
        if not isinstance(supergroup, SpaceGroup):
            return NotImplemented

        if len(supergroup.symmetry_ops) < len(self.symmetry_ops):
            return False

        groups = [{supergroup.int_number}]
        all_groups = [supergroup.int_number]
        max_subgroups = {
            int(k): v
            for k, v in _get_symm_data("maximal_subgroups").items()
        }
        while True:
            new_sub_groups = set()
            for i in groups[-1]:
                new_sub_groups.update(
                    [j for j in max_subgroups[i] if j not in all_groups])
            if self.int_number in new_sub_groups:
                return True

            if len(new_sub_groups) == 0:
                break

            groups.append(new_sub_groups)
            all_groups.extend(new_sub_groups)
        return False

    def is_supergroup(self, subgroup: SymmetryGroup) -> bool:
        """
        True if this space group is a supergroup of the supplied group.

        Args:
            subgroup (SpaceGroup): Subgroup to test.

        Returns:
            True if this space group is a supergroup of the supplied group.
        """
        return subgroup.is_subgroup(self)

    @classmethod
    def from_int_number(cls,
                        int_number: int,
                        hexagonal: bool = True) -> SpaceGroup:
        """
        Obtains a SpaceGroup from its international number.

        Args:
            int_number (int): International number.
            hexagonal (bool): For rhombohedral groups, whether to return the
                hexagonal setting (default) or rhombohedral setting.

        Returns:
            (SpaceGroup)
        """
        sym = sg_symbol_from_int_number(int_number, hexagonal=hexagonal)
        if not hexagonal and int_number in [146, 148, 155, 160, 161, 166, 167]:
            sym += ":R"
        return SpaceGroup(sym)

    def __str__(self) -> str:
        return (
            f"Spacegroup {self.symbol} with international number {self.int_number} and order {len(self.symmetry_ops)}"
        )

    def to_pretty_string(self) -> str:
        """
        Returns:
            (str): A pretty string representation of the space group.
        """
        return self.symbol