Example #1
0
    def test_push_to_hub_in_organization(self):
        """Check that a BertConfig survives a round trip through an org repo.

        The same small config is uploaded to ``valid_org/test-config-org`` in
        two ways: first directly with ``push_to_hub`` and then, after the repo
        has been deleted, via ``save_pretrained(..., push_to_hub=True)``.
        After each upload the config is downloaded again and compared field
        by field against the local object.
        """
        repo = "valid_org/test-config-org"
        config = BertConfig(
            vocab_size=99,
            hidden_size=32,
            num_hidden_layers=5,
            num_attention_heads=4,
            intermediate_size=37,
        )

        def assert_hub_copy_matches():
            # Download the current Hub copy and compare every serialized
            # field. ``transformers_version`` is deliberately left out of the
            # comparison.
            hub_config = BertConfig.from_pretrained(repo)
            for field, local_value in config.to_dict().items():
                if field == "transformers_version":
                    continue
                self.assertEqual(local_value, getattr(hub_config, field))

        # Upload path 1: the direct push_to_hub API.
        config.push_to_hub(repo, use_auth_token=self._token)
        assert_hub_copy_matches()

        # Delete the repo before exercising the second upload path.
        delete_repo(token=self._token, repo_id=repo)

        # Upload path 2: save_pretrained with push_to_hub=True.
        with tempfile.TemporaryDirectory() as scratch_dir:
            config.save_pretrained(
                scratch_dir,
                repo_id=repo,
                push_to_hub=True,
                use_auth_token=self._token,
            )

        assert_hub_copy_matches()
    def test_push_to_hub_in_organization(self):
        """Round-trip a Wav2Vec2 feature extractor through an organization repo.

        The extractor is uploaded twice. The first upload uses ``push_to_hub``.
        After the repo is deleted, the second uses
        ``save_pretrained(..., push_to_hub=True)``. After each upload the
        extractor is reloaded from the Hub, and every instance attribute must
        equal the local one.
        """
        # Both upload paths share one repo id, as the config variant of this
        # test does. Previously the first push targeted a different repo
        # ("valid_org/test-feature-extractor"). That repo is not in the
        # feature-extractor tearDownClass cleanup list, so a failing assertion
        # before the reset below would leave it on the Hub. This id is the one
        # that tearDownClass deletes (presumably the same test class; confirm).
        repo_id = "valid_org/test-feature-extractor-org"

        feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
            SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR)
        feature_extractor.push_to_hub(repo_id, use_auth_token=self._token)

        new_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(repo_id)
        for k, v in feature_extractor.__dict__.items():
            self.assertEqual(v, getattr(new_feature_extractor, k))

        # Reset repo so the save_pretrained path below is tested on its own.
        delete_repo(token=self._token, repo_id=repo_id)

        # Push to hub via save_pretrained
        with tempfile.TemporaryDirectory() as tmp_dir:
            feature_extractor.save_pretrained(
                tmp_dir,
                repo_id=repo_id,
                push_to_hub=True,
                use_auth_token=self._token)

        new_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(repo_id)
        for k, v in feature_extractor.__dict__.items():
            self.assertEqual(v, getattr(new_feature_extractor, k))
Example #3
0
    def tearDownClass(cls):
        """Best-effort deletion of the Hub repos created by the Flax model tests.

        Uses the ``repo_id`` form of ``delete_repo`` ("<namespace>/<name>", or
        just "<name>" for the token owner's namespace), as the other teardown
        helpers do. The deprecated ``name=``/``organization=`` keywords were
        removed from huggingface_hub. With them, the call fails with a
        TypeError, which the ``except HTTPError`` guard does not catch.
        """
        for repo_id in ("test-model-flax", "valid_org/test-model-flax-org"):
            try:
                delete_repo(token=cls._token, repo_id=repo_id)
            except HTTPError:
                # Deliberately ignored (e.g. presumably the repo was never
                # created) so the remaining repos are still cleaned up.
                pass
Example #4
0
    def tearDownClass(cls):
        """Best-effort deletion of the Hub repos created by the config tests.

        Uses the ``repo_id`` form of ``delete_repo`` instead of the deprecated
        ``name=``/``organization=`` keywords. Those keywords were removed from
        huggingface_hub and would raise a TypeError that ``except HTTPError``
        does not catch.
        """
        for repo_id in (
            "test-config",
            "valid_org/test-config-org",
            "test-dynamic-config",
        ):
            try:
                delete_repo(token=cls._token, repo_id=repo_id)
            except HTTPError:
                # Deliberately ignored so one failed deletion does not stop
                # cleanup of the remaining repos.
                pass
Example #5
0
    def tearDownClass(cls):
        """Delete, on a best-effort basis, every Hub repo the config tests use."""
        doomed = ("test-config", "valid_org/test-config-org", "test-dynamic-config")
        for target in doomed:
            # Each deletion is guarded on its own, so an HTTP failure for one
            # repo does not prevent the others from being removed.
            try:
                delete_repo(token=cls._token, repo_id=target)
            except HTTPError:
                pass
Example #6
0
    def tearDownClass(cls):
        """Delete, on a best-effort basis, the Hub repos of the feature-extractor tests."""
        created_repos = [
            "test-feature-extractor",
            "valid_org/test-feature-extractor-org",
            "test-dynamic-feature-extractor",
        ]
        for name in created_repos:
            # HTTP errors are swallowed per repo so cleanup keeps going.
            try:
                delete_repo(token=cls._token, repo_id=name)
            except HTTPError:
                pass
 def tearDownClass(cls):
     """Delete, on a best-effort basis, the dynamic-pipeline test repo on the Hub."""
     target_repo = "test-dynamic-pipeline"
     try:
         delete_repo(token=cls._token, repo_id=target_repo)
     except HTTPError:
         # Cleanup is best-effort: an HTTP failure here is deliberately ignored.
         pass