Example #1
0
    def test_load_json_nested(self):
        """Round-trip a transformer that holds another transformer in a param.

        The outer Connectivity's ``depth`` parameter is hijacked to hold a
        serialized Connectivity description, so ``load_json`` must rebuild
        both levels. Each level's class name, ``n_jobs`` and ``_base_chains``
        are checked against the serialized data.
        """
        import os
        import shutil
        import tempfile

        # Hack: nest a complete sub-transformer description inside
        # Connectivity's `depth` param purely to exercise nested loading.
        data = {
            'parameters': {
                'n_jobs': 2,
                'depth': {
                    'parameters': {
                        'n_jobs': 3,
                        'input_type': 'list'},
                    'attributes': {'_base_chains': [['H']]},
                    'transformer': 'molml.molecule.Connectivity'},
                'input_type': 'list'},
            'attributes': {'_base_chains': [['H']]},
            'transformer': 'molml.molecule.Connectivity'
        }
        # Use a private temp dir rather than a fixed /tmp path: portable
        # across platforms, safe under concurrent test runs, and cleaned up.
        tmpdir = tempfile.mkdtemp()
        try:
            path = os.path.join(tmpdir, 'somefile.json')
            with open(path, 'w') as f:
                json.dump(data, f)
            m = load_json(path)
        finally:
            shutil.rmtree(tmpdir)

        # Outer transformer.
        self.assertEqual(m.__class__.__name__,
                         data["transformer"].split('.')[-1])
        self.assertEqual(m.n_jobs, data["parameters"]["n_jobs"])
        self.assertEqual(m._base_chains, data["attributes"]["_base_chains"])

        # Nested transformer rebuilt from the `depth` parameter.
        in_data = data["parameters"]["depth"]
        in_m = m.depth
        self.assertEqual(in_m.__class__.__name__,
                         in_data["transformer"].split('.')[-1])
        self.assertEqual(in_m.n_jobs, in_data["parameters"]["n_jobs"])
        self.assertEqual(in_m._base_chains,
                         in_data["attributes"]["_base_chains"])
Example #2
0
    def test_load_json_nested(self):
        """Round-trip a transformer that holds another transformer in a param.

        The outer Connectivity's ``depth`` parameter is hijacked to hold a
        serialized Connectivity description, so ``load_json`` must rebuild
        both levels. Each level's class name, ``n_jobs`` and ``_base_chains``
        are checked against the serialized data.
        """
        import os
        import shutil
        import tempfile

        # Hack: nest a complete sub-transformer description inside
        # Connectivity's `depth` param purely to exercise nested loading.
        data = {
            'parameters': {
                'n_jobs': 2,
                'depth': {
                    'parameters': {
                        'n_jobs': 3,
                        'input_type': 'list'},
                    'attributes': {'_base_chains': [['H']]},
                    'transformer': 'molml.molecule.Connectivity'},
                'input_type': 'list'},
            'attributes': {'_base_chains': [['H']]},
            'transformer': 'molml.molecule.Connectivity'
        }
        # Use a private temp dir rather than a fixed /tmp path: portable
        # across platforms, safe under concurrent test runs, and cleaned up.
        tmpdir = tempfile.mkdtemp()
        try:
            path = os.path.join(tmpdir, 'somefile.json')
            with open(path, 'w') as f:
                json.dump(data, f)
            m = load_json(path)
        finally:
            shutil.rmtree(tmpdir)

        # Outer transformer.
        self.assertEqual(m.__class__.__name__,
                         data["transformer"].split('.')[-1])
        self.assertEqual(m.n_jobs, data["parameters"]["n_jobs"])
        self.assertEqual(m._base_chains, data["attributes"]["_base_chains"])

        # Nested transformer rebuilt from the `depth` parameter.
        in_data = data["parameters"]["depth"]
        in_m = m.depth
        self.assertEqual(in_m.__class__.__name__,
                         in_data["transformer"].split('.')[-1])
        self.assertEqual(in_m.n_jobs, in_data["parameters"]["n_jobs"])
        self.assertEqual(in_m._base_chains,
                         in_data["attributes"]["_base_chains"])
Example #3
0
    def test_load_json(self):
        """load_json restores a Connectivity from a path and from a file object.

        Each input form must yield an object whose class name, ``n_jobs`` and
        ``_base_chains`` match the serialized data.
        """
        import os
        import shutil
        import tempfile

        data = {'parameters': {'n_jobs': 2,
                               'input_type': 'list'},
                'attributes': {'_base_chains': [['H']]},
                'transformer': 'molml.molecule.Connectivity'}
        # Private temp dir instead of a fixed /tmp path (portable, no
        # collisions between concurrent runs, cleaned up afterwards).
        tmpdir = tempfile.mkdtemp()
        try:
            path = os.path.join(tmpdir, 'somefile.json')
            with open(path, 'w') as f:
                json.dump(data, f)

            with open(path, 'r') as f:
                # Pass the loop variable (not `path`) so the open-file input
                # form is actually exercised on the second iteration.
                for source in (path, f):
                    m = load_json(source)
                    self.assertEqual(m.__class__.__name__,
                                     data["transformer"].split('.')[-1])
                    self.assertEqual(m.n_jobs, data["parameters"]["n_jobs"])
                    self.assertEqual(m._base_chains,
                                     data["attributes"]["_base_chains"])
        finally:
            shutil.rmtree(tmpdir)
Example #4
0
    def test_load_json(self):
        """load_json restores a Connectivity from a path and from a file object.

        Each input form must yield an object whose class name, ``n_jobs`` and
        ``_base_chains`` match the serialized data.
        """
        import os
        import shutil
        import tempfile

        data = {'parameters': {'n_jobs': 2,
                               'input_type': 'list'},
                'attributes': {'_base_chains': [['H']]},
                'transformer': 'molml.molecule.Connectivity'}
        # Private temp dir instead of a fixed /tmp path (portable, no
        # collisions between concurrent runs, cleaned up afterwards).
        tmpdir = tempfile.mkdtemp()
        try:
            path = os.path.join(tmpdir, 'somefile.json')
            with open(path, 'w') as f:
                json.dump(data, f)

            with open(path, 'r') as f:
                # Pass the loop variable (not `path`) so the open-file input
                # form is actually exercised on the second iteration.
                for source in (path, f):
                    m = load_json(source)
                    self.assertEqual(m.__class__.__name__,
                                     data["transformer"].split('.')[-1])
                    self.assertEqual(m.n_jobs, data["parameters"]["n_jobs"])
                    self.assertEqual(m._base_chains,
                                     data["attributes"]["_base_chains"])
        finally:
            shutil.rmtree(tmpdir)
Example #5
0
    [-1.0, 0.0, 0.0],
    [0.0, 0.0, 0.0],
    [1.0, 0.0, 0.0],
]
HCN = (HCN_ELES, HCN_COORDS)


if __name__ == "__main__":
    # Example of fitting the Coulomb matrix and then saving it
    feat = CoulombMatrix()
    feat.fit([H2, HCN])
    print("Saving Model")
    feat.save_json("coulomb_model.json")

    print("Loading Model")
    feat2 = load_json("coulomb_model.json")
    print(feat2.transform([H2, HCN]))

    # Example of fitting a generalized crystal with the Coulomb matrix and
    # then saving it. The GenerallizedCrystal wrapper (not the inner
    # CoulombMatrix) is what gets fit, saved and reloaded; otherwise the
    # wrapper would be built and never used.
    input_type = ("elements", "coords", "unit_cell")
    radius = 4.1
    feat = CoulombMatrix(input_type=input_type)
    crystal = GenerallizedCrystal(transformer=feat, radius=radius)
    crystal.fit([H2_FULL])
    print("Saving Model")
    crystal.save_json("coulomb_crystal_model.json")

    print("Loading Model")
    crystal2 = load_json("coulomb_crystal_model.json")
    print(crystal2.transform([H2_FULL]))
Example #6
0
HCN_COORDS = [
    [-1.0, 0.0, 0.0],
    [0.0, 0.0, 0.0],
    [1.0, 0.0, 0.0],
]
HCN = (HCN_ELES, HCN_COORDS)


if __name__ == "__main__":
    # Example of fitting the Coulomb matrix and then saving it
    feat = CoulombMatrix()
    feat.fit([H2, HCN])
    print("Saving Model")
    feat.save_json("coulomb_model.json")

    print("Loading Model")
    feat2 = load_json("coulomb_model.json")
    print(feat2.transform([H2, HCN]))

    # Example of fitting a generalized crystal with the Coulomb matrix and
    # then saving it. The GenerallizedCrystal wrapper (not the inner
    # CoulombMatrix) is what gets fit, saved and reloaded; otherwise the
    # wrapper would be built and never used.
    input_type = ("elements", "coords", "unit_cell")
    radius = 4.1
    feat = CoulombMatrix(input_type=input_type)
    crystal = GenerallizedCrystal(transformer=feat, radius=radius)
    crystal.fit([H2_FULL])
    print("Saving Model")
    crystal.save_json("coulomb_crystal_model.json")

    print("Loading Model")
    crystal2 = load_json("coulomb_crystal_model.json")
    print(crystal2.transform([H2_FULL]))