def setUp(self) -> None:
    """Load the pickled structure fixtures and prepare a graph generator.

    Sets ``data`` (all pickled structures), ``data0`` (the first one),
    ``data0_3`` (the first three), ``data0_checked`` (the first 10 entries
    after element checking over the full dataset) and ``gen`` (a
    ``GraphGenerator`` built from the transformed checked structures).
    """
    checker = CheckElements.from_pymatgen_structures()
    # NOTE(review): path is relative to the current working directory.
    structures = pd.read_pickle("data_structure.pkl_pd")
    self.data = structures
    self.data0 = structures[0]
    self.data0_3 = structures[:3]
    # Element checking runs over the whole dataset before truncating to 10.
    self.data0_checked = checker.check(structures)[:10]
    graph = CrystalGraph(n_jobs=1, batch_calculate=True, batch_size=10)
    self.gen = GraphGenerator(*graph.transform(self.data0_checked), targets=None)
def test_ion(self):
    """Transform the first 10 CIF structures under ``./test_gl/data`` into graphs.

    The test temporarily changes into ``./test_gl``. The original working
    directory is restored via ``addCleanup`` so later tests, which resolve
    fixture paths relative to cwd, are not affected.
    """
    from mgetool.imports import BatchFile

    # Capture the cwd *before* changing it; the original code captured it
    # afterwards, so its final chdir never returned to the starting directory.
    self.addCleanup(os.chdir, os.getcwd())
    os.chdir('./test_gl')
    PATH = os.getcwd()
    print(PATH)
    bf = BatchFile(os.path.join(PATH, "data"), suffix='cif')
    f = bf.merge()
    # Kept from the original: re-enter PATH after merging.
    # NOTE(review): presumably BatchFile may change cwd — confirm.
    os.chdir(PATH)
    data = [Structure.from_file(i) for i in f[:10]]
    ce = CheckElements.from_pymatgen_structures()
    checked_data = ce.check(data)
    tmps = AtomTableMap(search_tp="name")
    gt = CrystalGraph(n_jobs=2, atom_converter=tmps)
    in_data = gt.transform(checked_data, state_attributes=None)
    self.assertIsNotNone(in_data)
def test_CrystalGraph2(self):
    """Convert each checked structure with the EAMD / BaseDesGet configuration.

    Asserts that the produced ``"bond"`` array has at least two dimensions.
    The original test only printed ``shape[-2]`` and ``shape[-1]``.
    """
    # The converter is loop-invariant; one instance can be called repeatedly
    # (test_CrystalGraph_convert_call reuses an instance the same way).
    sg1 = CrystalGraph(nn_strategy="EAMD", bond_generator="BaseDesGet", cutoff=None)
    for idx, structure in enumerate(self.data0_checked):
        with self.subTest(index=idx):
            s12 = sg1(structure)
            bond_shape = s12["bond"].shape
            self.assertGreaterEqual(
                len(bond_shape), 2, msg=f"unexpected bond shape {bond_shape}"
            )
def test_CrystalGraphsmooth(self):
    """Convert each checked structure using spherical neighbours and a Smooth bond converter.

    Uses ``find_xyz_in_spheres`` with ``cutoff=5.0`` and
    ``Smooth(r_c=5.0, r_cs=3.0)``. Asserts that the produced ``"bond"`` array
    has at least two dimensions; the original test only printed its trailing
    dimensions.
    """
    # Built once: neither the graph converter nor the Smooth converter depends
    # on the loop variable.
    sg1 = CrystalGraph(
        nn_strategy="find_xyz_in_spheres",
        return_bonds="bonds",
        cutoff=5.0,
        bond_converter=Smooth(r_c=5.0, r_cs=3.0),
    )
    for idx, structure in enumerate(self.data0_checked):
        with self.subTest(index=idx):
            s12 = sg1(structure)
            bond_shape = s12["bond"].shape
            self.assertGreaterEqual(
                len(bond_shape), 2, msg=f"unexpected bond shape {bond_shape}"
            )
def test_CrystalGraph_convert_call(self):
    """Check the output of calling a default ``CrystalGraph`` on one structure.

    With explicit state attributes the call must return a dict with the keys
    ``atom``, ``bond``, ``state`` and ``atom_nbr_idx`` (in that order). Every
    value must be a numpy array or a list.
    """
    sg1 = CrystalGraph()
    # Smoke check: the default (no state_attributes) path must not raise.
    sg1(self.data0)
    s12 = sg1(self.data0, state_attributes=np.array([2, 3.0]))
    self.assertIsInstance(s12, dict)
    self.assertEqual(list(s12.keys()), ['atom', 'bond', 'state', 'atom_nbr_idx'])
    for key, value in s12.items():
        # subTest reports which key failed instead of printing every type.
        with self.subTest(key=key):
            self.assertIsInstance(value, (np.ndarray, list))
def test_CrystalGraph_as_dict(self):
    """Round-trip a ``CrystalGraph`` through ``as_dict``/``from_dict``.

    The reconstructed converter must be callable on a structure. It must
    produce the same output keys as the converter it was built from.
    """
    sg1 = CrystalGraph()
    dict1 = sg1.as_dict()
    sg2 = CrystalGraph.from_dict(dict1)
    s11 = sg1(self.data0)
    s12 = sg2(self.data0)
    # The original only called sg2; compare against sg1 so the round trip is
    # actually verified.
    self.assertEqual(list(s12.keys()), list(s11.keys()))