コード例 #1
0
    def test_push_to_hub_model_kwargs(self):
        """Push a fitted model with extra Keras save kwargs and check the result.

        The model is pushed with ``include_optimizer=True`` and
        ``save_traces=False``. After confirming the repo exists on the staging
        Hub, the test expects reloading it with ``from_pretrained_keras`` to
        raise ``ValueError`` whose message contains
        "Exception encountered when calling layer". The staging repo is
        deleted afterwards, even if an assertion fails.
        """
        REPO_NAME = repo_name("PUSH_TO_HUB")
        model = self.model_init()
        model = self.model_fit(model)
        push_to_hub_keras(
            model,
            repo_path_or_name=f"{WORKING_REPO_DIR}/{REPO_NAME}",
            api_endpoint=ENDPOINT_STAGING,
            use_auth_token=self._token,
            git_user="******",
            git_email="*****@*****.**",
            config={
                "num": 7,
                "act": "gelu_fast"
            },
            include_optimizer=True,
            save_traces=False,
        )

        # The repo exists from here on; delete it even if a check below fails
        # so failed runs do not leave repos behind on the staging Hub.
        try:
            model_info = HfApi(endpoint=ENDPOINT_STAGING).model_info(
                f"{USER}/{REPO_NAME}", )
            self.assertEqual(model_info.modelId, f"{USER}/{REPO_NAME}")

            # `assertRaises(exc, msg=...)` with no callable just returns an
            # unused context manager and checks nothing (and `msg` is the
            # failure message, not a pattern). Wrap the call so the expected
            # error and its message are actually asserted.
            with self.assertRaisesRegex(
                    ValueError, "Exception encountered when calling layer"):
                from_pretrained_keras(f"{WORKING_REPO_DIR}/{REPO_NAME}")
        finally:
            self._api.delete_repo(repo_id=REPO_NAME, token=self._token)
コード例 #2
0
 def test_push_to_hub_model_card_plot_false(self):
     """Check that ``plot_model=False`` keeps ``model.png`` out of the repo.

     A fitted model is pushed to the staging Hub with plotting disabled, the
     repo's file listing is inspected, and the repo is deleted afterwards
     (even if the assertion fails).
     """
     REPO_NAME = repo_name("PUSH_TO_HUB")
     model = self.model_init()
     model = self.model_fit(model)
     push_to_hub_keras(
         model,
         repo_path_or_name=f"{WORKING_REPO_DIR}/{REPO_NAME}",
         api_endpoint=ENDPOINT_STAGING,
         use_auth_token=self._token,
         git_user="******",
         git_email="*****@*****.**",
         plot_model=False,
     )
     # The repo exists from here on; always clean it up.
     try:
         model_info = HfApi(endpoint=ENDPOINT_STAGING).model_info(
             f"{USER}/{REPO_NAME}", )
         # assertNotIn shows the actual file list when it fails.
         self.assertNotIn(
             "model.png", [f.rfilename for f in model_info.siblings])
     finally:
         self._api.delete_repo(repo_id=REPO_NAME, token=self._token)
コード例 #3
0
    def test_push_to_hub(self):
        """Push a built (not fitted) model with a config and verify the repo.

        The model is built for inputs of shape ``(None, 2)``, pushed to the
        staging Hub with a small ``config`` dict, and the repo id reported by
        the Hub is checked. The repo is deleted afterwards, even if the
        assertion fails.
        """
        REPO_NAME = repo_name("PUSH_TO_HUB")
        model = self.model_init()
        # Build with an explicit input shape (batch dimension left as None)
        # instead of fitting the model.
        model.build((None, 2))
        push_to_hub_keras(
            model,
            repo_path_or_name=f"{WORKING_REPO_DIR}/{REPO_NAME}",
            api_endpoint=ENDPOINT_STAGING,
            use_auth_token=self._token,
            git_user="******",
            git_email="*****@*****.**",
            config={
                "num": 7,
                "act": "gelu_fast"
            },
        )

        # The repo exists from here on; always clean it up.
        try:
            model_info = HfApi(endpoint=ENDPOINT_STAGING).model_info(
                f"{USER}/{REPO_NAME}", )
            self.assertEqual(model_info.modelId, f"{USER}/{REPO_NAME}")
        finally:
            # Use `repo_id=` like the other tests here; `name=` is the keyword
            # of an older `delete_repo` signature.
            self._api.delete_repo(repo_id=REPO_NAME, token=self._token)
コード例 #4
0
    def test_override_tensorboard(self):
        """Check that a second push replaces the TensorBoard logs of the first.

        Two pushes to the same repo use different ``log_dir`` directories.
        Afterwards the repo must contain ``logs/override.txt`` from the second
        push and must no longer contain ``logs/tensorboard.txt`` from the
        first. The repo is deleted at the end, even if a step fails.
        """
        REPO_NAME = repo_name("PUSH_TO_HUB")
        with tempfile.TemporaryDirectory() as tmpdirname:
            # First log directory, uploaded by the initial push.
            os.makedirs(f"{tmpdirname}/tb_log_dir")
            with open(f"{tmpdirname}/tb_log_dir/tensorboard.txt", "w") as fp:
                fp.write("Keras FTW")
            model = self.model_init()
            model.build((None, 2))
            push_to_hub_keras(
                model,
                repo_path_or_name=f"{WORKING_REPO_DIR}/{REPO_NAME}",
                log_dir=f"{tmpdirname}/tb_log_dir",
                api_endpoint=ENDPOINT_STAGING,
                use_auth_token=self._token,
                git_user="******",
                git_email="*****@*****.**",
            )
            # The repo exists from here on; always clean it up.
            try:
                # Second, different log directory that should replace the
                # logs uploaded by the first push.
                os.makedirs(f"{tmpdirname}/tb_log_dir2")
                with open(f"{tmpdirname}/tb_log_dir2/override.txt",
                          "w") as fp:
                    fp.write("Keras FTW")
                push_to_hub_keras(
                    model,
                    repo_path_or_name=f"{WORKING_REPO_DIR}/{REPO_NAME}",
                    log_dir=f"{tmpdirname}/tb_log_dir2",
                    api_endpoint=ENDPOINT_STAGING,
                    use_auth_token=self._token,
                    git_user="******",
                    git_email="*****@*****.**",
                )

                model_info = HfApi(endpoint=ENDPOINT_STAGING).model_info(
                    f"{USER}/{REPO_NAME}", )
                # Build the repo file listing once; assertIn/assertNotIn
                # report it on failure.
                filenames = [f.rfilename for f in model_info.siblings]
                self.assertIn("logs/override.txt", filenames)
                self.assertNotIn("logs/tensorboard.txt", filenames)
            finally:
                self._api.delete_repo(repo_id=REPO_NAME, token=self._token)