def test_tf_model(self):
    """Save a one-layer TF model via ModelSaver and verify it round-trips.

    ``ModelSaver.on_epoch_end`` is expected to write
    ``<model_name>_epoch_<epoch_idx>.h5`` into ``self.save_dir``. That file is
    then loaded into a model built from ``one_layer_model_without_weights``.
    Its trainable variables must equal the original model's.
    """
    model = fe.build(model_fn=one_layer_tf_model, optimizer_fn='adam')
    model_saver = ModelSaver(model=model, save_dir=self.save_dir)
    model_saver.system = sample_system_object()
    model_name = "{}_epoch_{}".format(model_saver.model.model_name, model_saver.system.epoch_idx)
    tf_model_path = os.path.join(self.save_dir, model_name + '.h5')
    # self.save_dir may hold a file from an earlier run. Delete it first so the
    # existence check below proves that *this* on_epoch_end call wrote it
    # (the same guard test_torch_model uses).
    if os.path.exists(tf_model_path):
        os.remove(tf_model_path)
    model_saver.on_epoch_end(data=Data())
    with self.subTest('Check if model is saved'):
        self.assertTrue(os.path.exists(tf_model_path))
    with self.subTest('Validate model weights'):
        m2 = fe.build(model_fn=one_layer_model_without_weights, optimizer_fn='adam')
        fe.backend.load_model(m2, tf_model_path)
        self.assertTrue(is_equal(m2.trainable_variables, model.trainable_variables))
def test_max_to_keep_tf_architecture(self):
    """Check ``max_to_keep=2`` with ``save_architecture=True`` for a TF model.

    Four consecutive epochs are saved into a fresh temporary directory.
    Afterwards exactly four entries must remain: the ``.h5`` weight file
    and the architecture directory for each of the last two epochs.
    """
    # A TemporaryDirectory registered with addCleanup is deleted when the test
    # finishes. The previous bare tempfile.mkdtemp() was never removed.
    tmp_dir = tempfile.TemporaryDirectory()
    self.addCleanup(tmp_dir.cleanup)
    save_dir = tmp_dir.name
    model = fe.build(model_fn=one_layer_tf_model, optimizer_fn='adam')
    model_saver = ModelSaver(model=model, save_dir=save_dir, max_to_keep=2, save_architecture=True)
    model_saver.system = sample_system_object()
    model_saver.on_epoch_end(data=Data())
    model_saver.system.epoch_idx += 1
    model_saver.on_epoch_end(data=Data())
    model_saver.system.epoch_idx += 1
    model_saver.on_epoch_end(data=Data())
    # Expected paths for the third save (second-newest after the fourth save).
    model_name = "{}_epoch_{}".format(model_saver.model.model_name, model_saver.system.epoch_idx)
    tf_model_path1 = os.path.join(save_dir, model_name + '.h5')
    tf_architecture_path1 = os.path.join(save_dir, model_name)
    model_saver.system.epoch_idx += 1
    model_saver.on_epoch_end(data=Data())
    # Expected paths for the fourth (newest) save.
    model_name = "{}_epoch_{}".format(model_saver.model.model_name, model_saver.system.epoch_idx)
    tf_model_path2 = os.path.join(save_dir, model_name + '.h5')
    tf_architecture_path2 = os.path.join(save_dir, model_name)
    with self.subTest('Check only four files are kept'):
        self.assertEqual(len(os.listdir(save_dir)), 4)
    with self.subTest('Check two latest models are kept'):
        self.assertTrue(os.path.exists(tf_model_path1))
        self.assertTrue(os.path.exists(tf_model_path2))
        self.assertTrue(os.path.exists(tf_architecture_path1))
        self.assertTrue(os.path.isdir(tf_architecture_path1))
        self.assertTrue(os.path.exists(tf_architecture_path2))
        self.assertTrue(os.path.isdir(tf_architecture_path2))
def test_torch_model(self):
    """Round-trip a multi-layer PyTorch model through ModelSaver.

    The expected ``.pt`` checkpoint path is computed first and any stale copy
    is deleted. The existence check therefore passes only if ``on_epoch_end``
    itself wrote the file. The checkpoint is then loaded into a model built
    from ``MultiLayerTorchModelWithoutWeights``, and its parameters are
    compared with the original model's.
    """
    torch_net = fe.build(model_fn=MultiLayerTorchModel, optimizer_fn='adam')
    saver = ModelSaver(model=torch_net, save_dir=self.save_dir)
    saver.system = sample_system_object()

    checkpoint = os.path.join(
        self.save_dir,
        "{}_epoch_{}.pt".format(saver.model.model_name, saver.system.epoch_idx),
    )
    if os.path.exists(checkpoint):
        os.remove(checkpoint)  # start from a clean slate

    saver.on_epoch_end(data=Data())

    with self.subTest('Check if model is saved'):
        self.assertTrue(os.path.exists(checkpoint))

    with self.subTest('Validate model weights'):
        restored = fe.build(model_fn=MultiLayerTorchModelWithoutWeights, optimizer_fn='adam')
        fe.backend.load_model(restored, checkpoint)
        restored_params = list(restored.parameters())
        original_params = list(torch_net.parameters())
        self.assertTrue(is_equal(restored_params, original_params))
def test_max_to_keep_torch(self):
    """Check that ``ModelSaver(max_to_keep=2)`` keeps only the two newest PyTorch checkpoints.

    Three consecutive epochs are saved into a fresh temporary directory.
    Afterwards exactly two entries must remain, and they must be the ``.pt``
    files for the last two epochs.
    """
    # A TemporaryDirectory registered with addCleanup is deleted when the test
    # finishes. The previous bare tempfile.mkdtemp() was never removed.
    tmp_dir = tempfile.TemporaryDirectory()
    self.addCleanup(tmp_dir.cleanup)
    save_dir = tmp_dir.name
    model = fe.build(model_fn=MultiLayerTorchModel, optimizer_fn='adam')
    model_saver = ModelSaver(model=model, save_dir=save_dir, max_to_keep=2)
    model_saver.system = sample_system_object()
    model_saver.on_epoch_end(data=Data())
    model_saver.system.epoch_idx += 1
    model_saver.on_epoch_end(data=Data())
    # Expected path for the second save (second-newest after the third save).
    model_name = "{}_epoch_{}".format(model_saver.model.model_name, model_saver.system.epoch_idx)
    torch_model_path1 = os.path.join(save_dir, model_name + '.pt')
    model_saver.system.epoch_idx += 1
    model_saver.on_epoch_end(data=Data())
    # Expected path for the third (newest) save.
    model_name = "{}_epoch_{}".format(model_saver.model.model_name, model_saver.system.epoch_idx)
    torch_model_path2 = os.path.join(save_dir, model_name + '.pt')
    with self.subTest('Check only two file are kept'):
        self.assertEqual(len(os.listdir(save_dir)), 2)
    with self.subTest('Check two latest model are kept'):
        self.assertTrue(os.path.exists(torch_model_path1))
        self.assertTrue(os.path.exists(torch_model_path2))