예제 #1
0
    def setUp(self) -> None:
        """Build the fixtures shared by the tests.

        Reads ``data_structure.pkl_pd`` with ``pd.read_pickle`` and keeps the
        whole object, its first entry and its first three entries on ``self``.
        The complete data set is passed through ``CheckElements.check`` and
        the first ten results are kept, transformed by a ``CrystalGraph`` and
        wrapped in a ``GraphGenerator`` stored as ``self.gen``.
        """
        element_checker = CheckElements.from_pymatgen_structures()

        self.data = pd.read_pickle("data_structure.pkl_pd")
        self.data0 = self.data[0]
        self.data0_3 = self.data[:3]
        # The check runs over the full data set; only the first ten results
        # are kept afterwards.
        checked = element_checker.check(self.data)
        self.data0_checked = checked[:10]

        graph_builder = CrystalGraph(
            n_jobs=1,
            batch_calculate=True,
            batch_size=10,
        )
        graph_data = graph_builder.transform(self.data0_checked)

        self.gen = GraphGenerator(*graph_data, targets=None)
예제 #2
0
    def test_ion(self):
        """Transform up to ten CIF files found under ``./test_gl/data``.

        The files are collected with ``BatchFile``, loaded with
        ``Structure.from_file``, passed through ``CheckElements.check`` and
        transformed by a ``CrystalGraph`` that uses a name-based
        ``AtomTableMap``. There is no assertion; the test passes when nothing
        raises.

        The working directory is process-global, so the directory the test
        started in is remembered and restored when the test ends, whether it
        passes or fails. Without that, later tests would run from
        ``./test_gl`` and a second run of this test could not resolve the
        relative path.
        """
        original_cwd = os.getcwd()
        os.chdir('./test_gl')
        try:
            work_dir = os.getcwd()
            print(work_dir)
            from mgetool.imports import BatchFile

            bf = BatchFile(os.path.join(work_dir, "data"), suffix='cif')
            f = bf.merge()
            # NOTE(review): presumably guards against merge() moving the
            # working directory before the files are read -- confirm.
            os.chdir(work_dir)
            data = [Structure.from_file(i) for i in f[:10]]
            ce = CheckElements.from_pymatgen_structures()
            checked_data = ce.check(data)

            tmps = AtomTableMap(search_tp="name")
            gt = CrystalGraph(n_jobs=2, atom_converter=tmps)
            gt.transform(checked_data, state_attributes=None)
        finally:
            # Undo the chdir so this test does not leak state to other tests.
            os.chdir(original_cwd)
예제 #3
0
 def test_CrystalGraph2(self):
     """Print the last two ``bond`` dimensions for each checked item.

     Every entry of ``self.data0_checked`` is converted by a freshly built
     ``CrystalGraph`` configured with the "EAMD" strategy, the "BaseDesGet"
     bond generator and no cutoff.
     """
     for sample in self.data0_checked:
         converter = CrystalGraph(
             nn_strategy="EAMD",
             bond_generator="BaseDesGet",
             cutoff=None,
         )
         graph = converter(sample)
         # Second-to-last axis first, then the last one.
         for axis in (-2, -1):
             print(graph["bond"].shape[axis])
예제 #4
0
 def test_CrystalGraphsmooth(self):
     """Print the last two ``bond`` dimensions using a ``Smooth`` converter.

     Every entry of ``self.data0_checked`` is converted by a freshly built
     ``CrystalGraph`` that uses the "find_xyz_in_spheres" strategy with a
     cutoff of 5.0 and a ``Smooth(r_c=5.0, r_cs=3.0)`` bond converter.
     """
     for sample in self.data0_checked:
         converter = CrystalGraph(
             nn_strategy="find_xyz_in_spheres",
             return_bonds="bonds",
             cutoff=5.0,
             bond_converter=Smooth(r_c=5.0, r_cs=3.0),
         )
         graph = converter(sample)
         # Second-to-last axis first, then the last one.
         for axis in (-2, -1):
             print(graph["bond"].shape[axis])
예제 #5
0
    def test_CrystalGraph_convert_call(self):
        """Call a default ``CrystalGraph`` with and without state attributes.

        The result of the call that passes ``state_attributes`` must be a
        dict whose keys are, in order, 'atom', 'bond', 'state' and
        'atom_nbr_idx', and every value must be a NumPy array or a list.
        """
        converter = CrystalGraph()
        # First call uses the defaults; its result is not inspected.
        converter(self.data0)
        graph = converter(self.data0, state_attributes=np.array([2, 3.0]))

        self.assertTrue(isinstance(graph, dict))
        self.assertEqual(
            list(graph.keys()),
            ['atom', 'bond', 'state', 'atom_nbr_idx'],
        )
        for value in graph.values():
            print(type(value))
            self.assertTrue(isinstance(value, (np.ndarray, list)))
예제 #6
0
 def test_CrystalGraph_as_dict(self):
     """Round-trip a default ``CrystalGraph`` through ``as_dict``/``from_dict``.

     The rebuilt object is then called on ``self.data0``. There is no
     assertion; the test passes when nothing raises.
     """
     original = CrystalGraph()
     config = original.as_dict()
     rebuilt = CrystalGraph.from_dict(config)
     rebuilt(self.data0)