コード例 #1
0
    def test_crystal_model(self):
        """Fit ``self.model`` on the crystal generator and check checkpointing.

        Trains for two one-step epochs with ``ModelCheckpointMAE``,
        ``GeneratorLog`` and ``ManualStop`` callbacks attached. The training
        generator doubles as the validation generator. The test then asserts
        that at least one ``val_mae*.hdf5`` checkpoint was written.

        All matching checkpoint files are deleted afterwards, even if fitting
        or the assertion fails.
        """
        callbacks = [
            ModelCheckpointMAE(
                filepath='./val_mae_{epoch:05d}_{val_mae:.6f}.hdf5',
                save_best_only=True,
                val_gen=self.train_gen_crystal,
                steps_per_val=1,
                is_pa=False),
            GeneratorLog(self.train_gen_crystal,
                         1,
                         self.train_gen_crystal,
                         1,
                         val_names=['Ef'],
                         val_units=['eV/atom']),
            ManualStop()
        ]

        try:
            self.model.fit_generator(generator=self.train_gen_crystal,
                                     steps_per_epoch=1,
                                     epochs=2,
                                     verbose=1,
                                     callbacks=callbacks)
            model_files = glob('val_mae*.hdf5')
            self.assertGreater(len(model_files), 0)
        finally:
            # Clean up on success and failure alike. The assertion above only
            # counts matching files, so a checkpoint left behind by a crashed
            # run would let a later run pass without writing anything.
            for path in glob('val_mae*.hdf5'):
                os.remove(path)
コード例 #2
0
ファイル: test_model.py プロジェクト: yoshida-lab/megnet
    def test_crystal_model(self):
        """Build a small set2set model and check checkpointing during fit.

        The model comes from ``set2set_with_embedding_mp``, using the fixture's
        bond/global feature counts and small hyper-parameters (one block, one
        pass, width 4, one target). It is trained for two one-step epochs on
        the crystal generator with ``ModelCheckpointMAE``, ``GeneratorLog``
        and ``ManualStop`` callbacks; the training generator doubles as the
        validation generator.

        At least one ``val_mae*.hdf5`` checkpoint must be written. Matching
        files are always deleted afterwards, even on failure.
        """
        model = set2set_with_embedding_mp(self.n_bond_features,
                                          self.n_global_features,
                                          n_blocks=1,
                                          lr=1e-2,
                                          n1=4,
                                          n2=4,
                                          n3=4,
                                          n_pass=1,
                                          n_target=1)
        callbacks = [
            ModelCheckpointMAE(
                filepath='./val_mae_{epoch:05d}_{val_mae:.6f}.hdf5',
                save_best_only=True,
                val_gen=self.train_gen_crystal,
                steps_per_val=1,
                is_pa=False),
            GeneratorLog(self.train_gen_crystal,
                         1,
                         self.train_gen_crystal,
                         1,
                         val_names=['Ef'],
                         val_units=['eV/atom']),
            ManualStop()
        ]

        try:
            model.fit_generator(generator=self.train_gen_crystal,
                                steps_per_epoch=1,
                                epochs=2,
                                verbose=1,
                                callbacks=callbacks)
            model_files = glob('val_mae*.hdf5')
            self.assertGreater(len(model_files), 0)
        finally:
            # Clean up on success and failure alike. The assertion above only
            # counts matching files, so a checkpoint left behind by a crashed
            # run would let a later run pass without writing anything.
            for path in glob('val_mae*.hdf5'):
                os.remove(path)
コード例 #3
0
    def test_callback(self):
        """Check GeneratorLog's printed MAE report and ModelCheckpointMAE's output.

        Fits ``self.model`` for one one-step epoch, with ``self.train_gen``
        used for both training and validation. Stdout is captured during the
        fit. The test asserts that:

        * the captured output contains a "Train MAE:" and a "Test MAE:"
          header, each followed by a ``conductivity: <number> S/cm`` line;
        * exactly one ``./val_mae*.hdf5`` checkpoint was written, with none
          present beforehand.

        Checkpoint files are always removed afterwards.
        """
        from contextlib import redirect_stdout

        callbacks = [
            GeneratorLog(self.train_gen,
                         steps_per_train=1,
                         val_gen=self.train_gen,
                         steps_per_val=1,
                         n_every=1,
                         val_names=['conductivity'],
                         val_units=['S/cm']),
            ModelCheckpointMAE(
                filepath='./val_mae_{epoch:05d}_{val_mae:.6f}.hdf5',
                val_gen=self.train_gen,
                steps_per_val=1),
        ]

        # A leftover checkpoint would make the file count below meaningless,
        # so fail fast before training.
        before_fit_file = glob.glob("./val_mae*.hdf5")
        self.assertEqual(len(before_fit_file), 0)

        captured_output = StringIO()
        try:
            # redirect_stdout restores the previous sys.stdout (which may be a
            # test runner's capture stream, not sys.__stdout__) even if
            # fitting raises.
            with redirect_stdout(captured_output):
                self.model.fit_generator(self.train_gen,
                                         steps_per_epoch=1,
                                         epochs=1,
                                         callbacks=callbacks,
                                         verbose=0)
            after_fit_file = glob.glob("./val_mae*.hdf5")
        finally:
            for path in glob.glob("./val_mae*.hdf5"):
                os.remove(path)

        result = captured_output.getvalue()
        # Raw strings avoid invalid-escape warnings for \d and \. . The number
        # alternatives are grouped because an ungrouped ``a|b`` splits the
        # whole pattern: a bare ``\d+ S/cm`` branch would match any value line
        # and make the header check vacuous. ``\s*`` tolerates a possible
        # blank line between header and value line.
        self.assertRegex(
            result,
            r"Train MAE:\s*conductivity: [-+]?(?:\d*\.\d+|\d+) S/cm")
        self.assertRegex(
            result,
            r"Test MAE:\s*conductivity: [-+]?(?:\d*\.\d+|\d+) S/cm")

        self.assertEqual(len(after_fit_file), 1)
コード例 #4
0
    def test_callback(self):
        """Check ModelCheckpointMAE's file output and its target-scaler inversion.

        First, ``self.model`` is fit for one one-step epoch, with
        ``self.train_gen`` used for both training and validation. Exactly one
        ``./val_mae*.hdf5`` checkpoint must appear, with none present
        beforehand. The checkpoint is always removed afterwards.

        Second, ``ModelCheckpointMAE`` is built with a
        ``StandardScaler(1, 1, is_intensive=...)`` as ``target_scaler``. For
        both intensive and extensive settings, the test checks that
        ``target_scaler.inverse_transform`` maps fixed dummy targets to
        pinned expected arrays.
        """
        checkpoint_path = './val_mae_{epoch:05d}_{val_mae:.6f}.hdf5'
        callbacks = [GeneratorLog(self.train_gen, steps_per_train=1, val_gen=self.train_gen, steps_per_val=1,
                                  n_every=1, val_names=['conductivity'], val_units=['S/cm']),
                     ModelCheckpointMAE(filepath=checkpoint_path, val_gen=self.train_gen,
                                        steps_per_val=1),
                     ]

        # A leftover checkpoint would make the file count below meaningless,
        # so fail fast before training.
        before_fit_file = glob.glob("./val_mae*.hdf5")
        self.assertEqual(len(before_fit_file), 0)
        try:
            self.model.fit_generator(self.train_gen, steps_per_epoch=1, epochs=1, callbacks=callbacks, verbose=0)
            after_fit_file = glob.glob("./val_mae*.hdf5")
        finally:
            # Always clean up, so a failed run cannot poison the next one.
            for path in glob.glob("./val_mae*.hdf5"):
                os.remove(path)
        self.assertEqual(len(after_fit_file), 1)

        # Expected inverse transforms of dummy_target. The extensive
        # expectation equals the intensive one multiplied row-wise by
        # dummy_nb_atoms.
        cases = [
            (True, np.array([[2, 2], [3, 3]])),
            (False, np.array([[4, 4], [9, 9]])),
        ]
        for is_intensive, expected in cases:
            callback_mae = ModelCheckpointMAE(filepath=checkpoint_path, val_gen=self.train_gen,
                                              steps_per_val=1,
                                              target_scaler=StandardScaler(1, 1, is_intensive=is_intensive))
            # Fresh inputs per case, in case inverse_transform mutates them.
            dummy_target = np.array([[1, 1], [2, 2]])
            dummy_nb_atoms = np.array([[2], [3]])
            transformed = callback_mae.target_scaler.inverse_transform(dummy_target, dummy_nb_atoms)
            # Same tolerances as np.allclose's defaults, but a failure reports
            # the mismatching values instead of just "False is not true".
            np.testing.assert_allclose(transformed, expected, rtol=1e-5, atol=1e-8,
                                       err_msg='is_intensive=%s' % is_intensive)