def test_save_pretrained_optimizer_state(self):
    """Saving with ``include_optimizer=True`` leaves an optimizer on the reloaded model."""
    name = repo_name("save")
    target = f"{WORKING_REPO_DIR}/{name}"

    model = self.model_init()
    model.build((None, 2))

    save_pretrained_keras(model, target, include_optimizer=True)
    reloaded = from_pretrained_keras(target)

    self.assertIsNotNone(reloaded.optimizer)
def test_save_pretrained_fit(self):
    """A model that went through ``model_fit`` saves to exactly 7 files.

    The saved files include the SavedModel graph and the Keras metadata.
    """
    save_dir = f"{WORKING_REPO_DIR}/{repo_name('functional')}"
    trained = self.model_fit(self.model_init())

    save_pretrained_keras(trained, save_dir)

    saved_files = os.listdir(save_dir)
    for expected in ("saved_model.pb", "keras_metadata.pb"):
        self.assertIn(expected, saved_files)
    self.assertEqual(len(saved_files), 7)
def test_from_pretrained_weights(self):
    """Reloaded weights differ from those of a freshly initialised DummyModel."""
    name = repo_name("FROM_PRETRAINED")
    original = self.model_init()
    original.build((None, 2))

    path = f"{WORKING_REPO_DIR}/{name}"
    save_pretrained_keras(original, path)
    reloaded = from_pretrained_keras(path)

    # Build an unrelated model by calling it once on a (2, 2) batch of ones,
    # then check that its first weight tensor is not identical to the reloaded one.
    fresh = DummyModel()
    fresh(tf.ones([2, 2]))
    identical = tf.reduce_all(tf.equal(reloaded.weights[0], fresh.weights[0]))
    self.assertFalse(identical.numpy().item())
def test_save_pretrained(self):
    """Saving requires a built model; once built, it writes exactly 4 files.

    The saved files include ``saved_model.pb`` and ``keras_metadata.pb``.
    """
    REPO_NAME = repo_name("save")
    save_dir = f"{WORKING_REPO_DIR}/{REPO_NAME}"
    model = self.model_init()

    # An unbuilt model must be rejected. ``match`` is a regex searched in the
    # error message, so a plain prefix is enough. A trailing "*" would only
    # make the final "t" optional; it is not a glob wildcard.
    with pytest.raises(ValueError, match="Model should be built"):
        save_pretrained_keras(model, save_dir)

    model.build((None, 2))
    save_pretrained_keras(model, save_dir)

    files = os.listdir(save_dir)
    self.assertIn("saved_model.pb", files)
    self.assertIn("keras_metadata.pb", files)
    self.assertEqual(len(files), 4)
def test_save_pretrained_kwargs_load_fails_without_traces(self):
    """A model saved with ``save_traces=False`` cannot be used after reloading.

    The save/load kwargs are forwarded to Keras. Without saved traces, the
    reloaded model is expected to raise ``ValueError`` when invoked.
    """
    REPO_NAME = repo_name("save")
    model = self.model_init()
    model.build((None, 2))

    save_pretrained_keras(
        model,
        f"{WORKING_REPO_DIR}/{REPO_NAME}",
        include_optimizer=False,
        save_traces=False,
    )

    # Loading itself is expected to succeed; the previous version of this
    # test already called it outside any raises-check.
    loaded_model = from_pretrained_keras(f"{WORKING_REPO_DIR}/{REPO_NAME}")

    # The old ``self.assertRaises(ValueError, msg=...)`` had no callable and no
    # ``with`` block, so it asserted nothing. ``msg`` is only the failure text,
    # not a pattern. Actually exercise the model inside a raises-context.
    # NOTE(review): assumes the failure surfaces on call with Keras' wrapped
    # "Exception encountered when calling layer" message; confirm against the
    # pinned TF/Keras version.
    with pytest.raises(ValueError, match="Exception encountered when calling layer"):
        loaded_model(tf.ones([2, 2]))
def test_save_pretrained_task_name_deprecation(self):
    """Passing ``task_name`` emits a FutureWarning that points users to ``tags``."""
    target = f"{WORKING_REPO_DIR}/{repo_name('save')}"
    model = self.model_init()
    model.build((None, 2))

    expected = "`task_name` input argument is deprecated. Pass `tags` instead."
    with pytest.warns(FutureWarning, match=expected):
        save_pretrained_keras(
            model, target, tags=["test"], task_name="test", save_traces=True
        )
def test_save_pretrained_model_card_fit(self):
    """Saving a fitted model also writes README.md, model.png and history.json."""
    save_dir = f"{WORKING_REPO_DIR}/{repo_name('save')}"
    fitted = self.model_fit(self.model_init())

    save_pretrained_keras(fitted, save_dir)

    listing = os.listdir(save_dir)
    expected_files = (
        "saved_model.pb",
        "keras_metadata.pb",
        "model.png",
        "README.md",
        "history.json",
    )
    for fname in expected_files:
        self.assertIn(fname, listing)
    self.assertEqual(len(listing), 7)
def test_abs_path_from_pretrained(self):
    """Weights and config survive a save/load round trip under ``WORKING_REPO_DIR``.

    Per the test name, ``WORKING_REPO_DIR`` is presumably an absolute path;
    that is not visible here.
    """
    REPO_NAME = repo_name("FROM_PRETRAINED")
    model = self.model_init()
    model.build((None, 2))

    save_pretrained_keras(
        model,
        f"{WORKING_REPO_DIR}/{REPO_NAME}",
        config={"num": 10, "act": "gelu_fast"},
    )
    new_model = from_pretrained_keras(f"{WORKING_REPO_DIR}/{REPO_NAME}")

    # The reloaded first weight tensor must equal the original element-wise.
    self.assertTrue(
        tf.reduce_all(tf.equal(new_model.weights[0], model.weights[0])).numpy()
    )
    # assertEqual shows both dicts on mismatch, unlike assertTrue(a == b).
    self.assertEqual(new_model.config, {"num": 10, "act": "gelu_fast"})
def test_rel_path_from_pretrained(self):
    """Weights and config survive a save/load round trip via a relative path."""
    model = self.model_init()
    model.build((None, 2))

    save_pretrained_keras(
        model,
        f"tests/{WORKING_REPO_SUBDIR}/FROM_PRETRAINED",
        config={"num": 10, "act": "gelu_fast"},
    )
    new_model = from_pretrained_keras(
        f"tests/{WORKING_REPO_SUBDIR}/FROM_PRETRAINED")

    # Check the reloaded model's weights match the original model's weights.
    self.assertTrue(
        tf.reduce_all(tf.equal(new_model.weights[0], model.weights[0])).numpy()
    )
    # Check the saved configuration is what we expect. assertEqual reports
    # both dicts on mismatch, unlike assertTrue(a == b).
    self.assertEqual(new_model.config, {"num": 10, "act": "gelu_fast"})