def test_save_pretrained_optimizer_state(self):
        """Check that saving with ``include_optimizer=True`` keeps an optimizer on reload."""
        target_dir = f"{WORKING_REPO_DIR}/{repo_name('save')}"
        model = self.model_init()
        model.build((None, 2))

        save_pretrained_keras(model, target_dir, include_optimizer=True)

        # The reloaded model should expose a (non-None) optimizer attribute.
        reloaded = from_pretrained_keras(target_dir)
        self.assertIsNotNone(reloaded.optimizer)

    def test_save_pretrained_fit(self):
        """Check that saving a fitted model writes the SavedModel files (7 entries total)."""
        repo_dir = f"{WORKING_REPO_DIR}/{repo_name('functional')}"
        fitted = self.model_fit(self.model_init())

        save_pretrained_keras(fitted, repo_dir)
        saved_files = os.listdir(repo_dir)

        for expected in ("saved_model.pb", "keras_metadata.pb"):
            self.assertIn(expected, saved_files)
        self.assertEqual(len(saved_files), 7)

    def test_from_pretrained_weights(self):
        """Check that reloaded weights differ from a separately built DummyModel's."""
        repo_dir = f"{WORKING_REPO_DIR}/{repo_name('FROM_PRETRAINED')}"
        model = self.model_init()
        model.build((None, 2))

        save_pretrained_keras(model, repo_dir)
        reloaded = from_pretrained_keras(repo_dir)

        # A separately constructed DummyModel is expected to end up with its own
        # initial weights, so its first weight tensor should not coincide with
        # the reloaded model's first weight tensor.
        fresh = DummyModel()
        fresh(tf.ones([2, 2]))
        all_equal = tf.reduce_all(tf.equal(reloaded.weights[0], fresh.weights[0]))
        self.assertFalse(all_equal.numpy().item())
    # Example #4
    def test_save_pretrained(self):
        """Check that saving requires a built model; once built, exactly 4 entries are written."""
        repo_dir = f"{WORKING_REPO_DIR}/{repo_name('save')}"
        model = self.model_init()

        # An unbuilt model is rejected with a ValueError.
        with pytest.raises(ValueError, match="Model should be built*"):
            save_pretrained_keras(model, repo_dir)

        model.build((None, 2))
        save_pretrained_keras(model, repo_dir)

        saved_files = os.listdir(repo_dir)
        for expected in ("saved_model.pb", "keras_metadata.pb"):
            self.assertIn(expected, saved_files)
        self.assertEqual(len(saved_files), 4)

    def test_save_pretrained_kwargs_load_fails_without_traces(self):
        """Check that a model saved with ``save_traces=False`` cannot be called after reload.

        Extra keyword arguments (``include_optimizer``, ``save_traces``) are
        passed through ``save_pretrained_keras``. The reload itself is run
        outside the assertion; only calling the reloaded model is expected to
        raise ``ValueError``.
        """
        REPO_NAME = repo_name("save")
        model = self.model_init()

        model.build((None, 2))

        save_pretrained_keras(
            model,
            f"{WORKING_REPO_DIR}/{REPO_NAME}",
            include_optimizer=False,
            save_traces=False,
        )

        loaded_model = from_pretrained_keras(f"{WORKING_REPO_DIR}/{REPO_NAME}")

        # The previous `self.assertRaises(ValueError, msg=...)` had no callable,
        # so it merely built an unused context manager and asserted nothing.
        # NOTE(review): the failure presumably surfaces on call because no call
        # trace was serialized — confirm against the Keras version in use.
        with pytest.raises(
            ValueError, match="Exception encountered when calling layer"
        ):
            loaded_model(tf.ones([2, 2]))

    def test_save_pretrained_task_name_deprecation(self):
        """Check that passing the deprecated ``task_name`` argument emits a FutureWarning."""
        repo_dir = f"{WORKING_REPO_DIR}/{repo_name('save')}"
        model = self.model_init()
        model.build((None, 2))

        expected_message = (
            "`task_name` input argument is deprecated. Pass `tags` instead."
        )
        with pytest.warns(FutureWarning, match=expected_message):
            save_pretrained_keras(
                model,
                repo_dir,
                tags=["test"],
                task_name="test",
                save_traces=True,
            )

    def test_save_pretrained_model_card_fit(self):
        """Check that saving a fitted model writes the model card, plot and history (7 entries)."""
        repo_dir = f"{WORKING_REPO_DIR}/{repo_name('save')}"
        fitted = self.model_fit(self.model_init())

        save_pretrained_keras(fitted, repo_dir)
        saved_files = os.listdir(repo_dir)

        expected_files = (
            "saved_model.pb",
            "keras_metadata.pb",
            "model.png",
            "README.md",
            "history.json",
        )
        for expected in expected_files:
            self.assertIn(expected, saved_files)
        self.assertEqual(len(saved_files), 7)
    # Example #8
    def test_abs_path_from_pretrained(self):
        """Round-trip a built model and its ``config`` through ``WORKING_REPO_DIR``.

        The save path is built from ``WORKING_REPO_DIR`` (presumably an absolute
        path, given the test name — TODO confirm). After reloading, the first
        weight tensor and the stored config must match what was saved.
        """
        REPO_NAME = repo_name("FROM_PRETRAINED")
        model = self.model_init()
        model.build((None, 2))
        save_pretrained_keras(
            model,
            f"{WORKING_REPO_DIR}/{REPO_NAME}",
            config={
                "num": 10,
                "act": "gelu_fast"
            },
        )

        new_model = from_pretrained_keras(f"{WORKING_REPO_DIR}/{REPO_NAME}")

        # Convert the scalar bool tensor to a Python bool explicitly rather than
        # relying on tensor truthiness.
        self.assertTrue(
            tf.reduce_all(tf.equal(new_model.weights[0],
                                   model.weights[0])).numpy().item())
        # assertEqual reports a readable diff when the configs differ.
        self.assertEqual(new_model.config, {"num": 10, "act": "gelu_fast"})

    def test_rel_path_from_pretrained(self):
        """Round-trip a built model and its ``config`` through a relative path.

        The path ``tests/{WORKING_REPO_SUBDIR}/FROM_PRETRAINED`` is relative, so
        it resolves against the current working directory. After reloading,
        the first weight tensor and the stored config must match what was saved.
        """
        model = self.model_init()
        model.build((None, 2))
        save_pretrained_keras(
            model,
            f"tests/{WORKING_REPO_SUBDIR}/FROM_PRETRAINED",
            config={
                "num": 10,
                "act": "gelu_fast"
            },
        )

        new_model = from_pretrained_keras(
            f"tests/{WORKING_REPO_SUBDIR}/FROM_PRETRAINED")

        # Check the reloaded model's weights match the original model's weights;
        # convert the scalar bool tensor to a Python bool explicitly.
        self.assertTrue(
            tf.reduce_all(tf.equal(new_model.weights[0],
                                   model.weights[0])).numpy().item())

        # Check saved configuration is what we expect (assertEqual shows a diff).
        self.assertEqual(new_model.config, {"num": 10, "act": "gelu_fast"})