예제 #1
0
 def test_tf_model(self):
     """Save a TF model via ModelSaver at epoch end and verify the weights round-trip.

     Any file already sitting at the expected checkpoint path is removed before
     saving, so the existence check proves that this run wrote the file.
     """
     model = fe.build(model_fn=one_layer_tf_model, optimizer_fn='adam')
     model_saver = ModelSaver(model=model, save_dir=self.save_dir)
     model_saver.system = sample_system_object()
     model_name = "{}_epoch_{}".format(model_saver.model.model_name, model_saver.system.epoch_idx)
     tf_model_path = os.path.join(self.save_dir, model_name + '.h5')
     # self.save_dir may be reused across tests/runs; drop a stale checkpoint so the
     # "is saved" assertion below cannot pass on a leftover file.
     if os.path.exists(tf_model_path):
         os.remove(tf_model_path)
     model_saver.on_epoch_end(data=Data())
     with self.subTest('Check if model is saved'):
         self.assertTrue(os.path.exists(tf_model_path))
     with self.subTest('Validate model weights'):
         # Load the saved checkpoint into a separately built model and compare its
         # trainable variables with the original model's.
         m2 = fe.build(model_fn=one_layer_model_without_weights, optimizer_fn='adam')
         fe.backend.load_model(m2, tf_model_path)
         self.assertTrue(is_equal(m2.trainable_variables, model.trainable_variables))
예제 #2
0
    def test_max_to_keep_tf_architecture(self):
        """Check that ModelSaver keeps only the two newest TF checkpoints with architecture.

        Saves four consecutive epochs with ``max_to_keep=2`` and
        ``save_architecture=True``. It then asserts that ``save_dir`` holds exactly
        four entries: the ``.h5`` file and the architecture directory for each of
        the last two epochs.
        """
        import shutil  # local import: only needed to clean up the temporary directory

        save_dir = tempfile.mkdtemp()
        # mkdtemp() never deletes its directory; register removal so saved models
        # don't accumulate on disk. Cleanup runs even if an assertion fails.
        self.addCleanup(shutil.rmtree, save_dir, ignore_errors=True)

        model = fe.build(model_fn=one_layer_tf_model, optimizer_fn='adam')
        model_saver = ModelSaver(model=model, save_dir=save_dir, max_to_keep=2, save_architecture=True)
        model_saver.system = sample_system_object()

        # Save three consecutive epochs; only the most recent one is inspected below.
        model_saver.on_epoch_end(data=Data())
        model_saver.system.epoch_idx += 1
        model_saver.on_epoch_end(data=Data())
        model_saver.system.epoch_idx += 1
        model_saver.on_epoch_end(data=Data())
        model_name = "{}_epoch_{}".format(model_saver.model.model_name, model_saver.system.epoch_idx)
        tf_model_path1 = os.path.join(save_dir, model_name + '.h5')
        tf_architecture_path1 = os.path.join(save_dir, model_name)

        # A fourth save: with max_to_keep=2 this and the previous epoch should survive.
        model_saver.system.epoch_idx += 1
        model_saver.on_epoch_end(data=Data())
        model_name = "{}_epoch_{}".format(model_saver.model.model_name, model_saver.system.epoch_idx)
        tf_model_path2 = os.path.join(save_dir, model_name + '.h5')
        tf_architecture_path2 = os.path.join(save_dir, model_name)

        with self.subTest('Check only four files are kept'):
            # Two epochs x (weights file + architecture directory).
            self.assertEqual(len(os.listdir(save_dir)), 4)

        with self.subTest('Check two latest models are kept'):
            self.assertTrue(os.path.exists(tf_model_path1))
            self.assertTrue(os.path.exists(tf_model_path2))
            self.assertTrue(os.path.exists(tf_architecture_path1))
            self.assertTrue(os.path.isdir(tf_architecture_path1))
            self.assertTrue(os.path.exists(tf_architecture_path2))
            self.assertTrue(os.path.isdir(tf_architecture_path2))
예제 #3
0
 def test_torch_model(self):
     """Round-trip a PyTorch model through ModelSaver and compare its parameters.

     A leftover ``.pt`` file at the target path is deleted first, so the
     existence check reflects the save performed by this run.
     """
     model = fe.build(model_fn=MultiLayerTorchModel, optimizer_fn='adam')
     saver = ModelSaver(model=model, save_dir=self.save_dir)
     saver.system = sample_system_object()
     checkpoint = os.path.join(
         self.save_dir,
         "{}_epoch_{}".format(saver.model.model_name, saver.system.epoch_idx) + '.pt',
     )
     try:
         os.remove(checkpoint)
     except FileNotFoundError:
         pass  # no stale checkpoint to clear
     saver.on_epoch_end(data=Data())
     with self.subTest('Check if model is saved'):
         self.assertTrue(os.path.exists(checkpoint))
     with self.subTest('Validate model weights'):
         # Reload into a separately built model and compare parameter tensors.
         restored = fe.build(model_fn=MultiLayerTorchModelWithoutWeights, optimizer_fn='adam')
         fe.backend.load_model(restored, checkpoint)
         self.assertTrue(is_equal(list(restored.parameters()), list(model.parameters())))
예제 #4
0
    def test_max_to_keep_torch(self):
        """Check that ModelSaver keeps only the two newest PyTorch checkpoints.

        Saves three consecutive epochs with ``max_to_keep=2``. It then asserts that
        ``save_dir`` holds exactly two ``.pt`` files, those of the last two epochs.
        """
        import shutil  # local import: only needed to clean up the temporary directory

        save_dir = tempfile.mkdtemp()
        # mkdtemp() never deletes its directory; register removal so saved models
        # don't accumulate on disk. Cleanup runs even if an assertion fails.
        self.addCleanup(shutil.rmtree, save_dir, ignore_errors=True)

        model = fe.build(model_fn=MultiLayerTorchModel, optimizer_fn='adam')
        model_saver = ModelSaver(model=model, save_dir=save_dir, max_to_keep=2)
        model_saver.system = sample_system_object()

        # Save two consecutive epochs and remember the path of the second one.
        model_saver.on_epoch_end(data=Data())
        model_saver.system.epoch_idx += 1
        model_saver.on_epoch_end(data=Data())
        model_name = "{}_epoch_{}".format(model_saver.model.model_name, model_saver.system.epoch_idx)
        torch_model_path1 = os.path.join(save_dir, model_name + '.pt')

        # A third save: with max_to_keep=2 the first epoch's file should be evicted.
        model_saver.system.epoch_idx += 1
        model_saver.on_epoch_end(data=Data())
        model_name = "{}_epoch_{}".format(model_saver.model.model_name, model_saver.system.epoch_idx)
        torch_model_path2 = os.path.join(save_dir, model_name + '.pt')

        with self.subTest('Check only two file are kept'):
            self.assertEqual(len(os.listdir(save_dir)), 2)

        with self.subTest('Check two latest model are kept'):
            self.assertTrue(os.path.exists(torch_model_path1))
            self.assertTrue(os.path.exists(torch_model_path2))