def test_new_config_registration(self):
        """Register ``CustomConfig`` as ``"custom"`` and round-trip it through ``AutoConfig``.

        Also checks that ``AutoConfig.register`` rejects a model type that does not
        match the config class, and rejects a model type that already exists in the
        Transformers library. The ``"custom"`` entry is removed from
        ``CONFIG_MAPPING._extra_content`` afterwards, even when an assertion fails.
        """
        try:
            AutoConfig.register("custom", CustomConfig)

            # "model" is the wrong model type for CustomConfig, so this must fail.
            with self.assertRaises(ValueError):
                AutoConfig.register("model", CustomConfig)

            # "bert" already exists in the Transformers library and cannot be re-registered.
            with self.assertRaises(ValueError):
                AutoConfig.register("bert", BertConfig)

            # With "custom" registered, a saved CustomConfig reloads through the auto-API.
            custom_config = CustomConfig()
            with tempfile.TemporaryDirectory() as save_dir:
                custom_config.save_pretrained(save_dir)
                reloaded = AutoConfig.from_pretrained(save_dir)
                self.assertIsInstance(reloaded, CustomConfig)
        finally:
            # Drop the registration (if it happened) so other tests see a clean registry.
            CONFIG_MAPPING._extra_content.pop("custom", None)
Beispiel #2
0
    def test_push_to_hub_dynamic_config(self):
        """Push a ``CustomConfig`` plus its source module to the Hub, then reload it.

        Saving the config into a cloned Hub repository must add an ``auto_map``
        entry and copy ``custom_configuration.py`` next to it. After pushing,
        ``AutoConfig.from_pretrained(..., trust_remote_code=True)`` must return a
        config whose class is named ``CustomConfig`` and whose ``attribute`` is 42.
        """
        # NOTE(review): register_for_auto_class() presumably sets class-level state on
        # CustomConfig that this test never resets — confirm that is intended.
        CustomConfig.register_for_auto_class()
        config = CustomConfig(attribute=42)
        repo_id = f"{USER}/test-dynamic-config"

        with tempfile.TemporaryDirectory() as work_dir:
            repo = Repository(work_dir, clone_from=repo_id, use_auth_token=self._token)
            config.save_pretrained(work_dir)

            # Saving filled in the auto_map that points at the copied module...
            expected_auto_map = {"AutoConfig": "custom_configuration.CustomConfig"}
            self.assertDictEqual(config.auto_map, expected_auto_map)
            # ...and copied the module defining CustomConfig from the fixtures.
            copied_module = os.path.join(work_dir, "custom_configuration.py")
            self.assertTrue(os.path.isfile(copied_module))

            repo.push_to_hub()

        reloaded = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
        # The reloaded class comes from a dynamically imported module, so it is not
        # the local CustomConfig object and isinstance would fail; compare by name.
        self.assertEqual(reloaded.__class__.__name__, "CustomConfig")
        self.assertEqual(reloaded.attribute, 42)
Beispiel #3
0
    def test_new_model_registration(self):
        """Register ``CustomModel`` with each auto-model class and round-trip it.

        For every auto class in ``auto_classes`` (each in its own ``subTest``) this checks:
        registering ``CustomModel`` under ``BertConfig`` raises ``ValueError``;
        re-registering the built-in ``BertConfig -> BertModel`` pair raises
        ``ValueError``; ``from_config`` and ``from_pretrained`` both produce a
        ``CustomModel``. The ``"custom"`` config entry and every ``CustomConfig``
        model entry are removed in ``finally`` so the global mappings are restored.
        """
        auto_classes = [
            AutoModel,
            AutoModelForCausalLM,
            AutoModelForMaskedLM,
            AutoModelForPreTraining,
            AutoModelForQuestionAnswering,
            AutoModelForSequenceClassification,
            AutoModelForTokenClassification,
        ]

        try:
            # Register inside the try (as the other registration tests do) so the
            # cleanup in `finally` always runs once anything has been attempted.
            AutoConfig.register("custom", CustomConfig)

            for auto_class in auto_classes:
                with self.subTest(auto_class.__name__):
                    # Wrong config class will raise an error
                    with self.assertRaises(ValueError):
                        auto_class.register(BertConfig, CustomModel)
                    auto_class.register(CustomConfig, CustomModel)
                    # Trying to register something existing in the Transformers library will raise an error
                    with self.assertRaises(ValueError):
                        auto_class.register(BertConfig, BertModel)

                    # Now that config and model are registered, they can be used through the auto-API
                    tiny_config = BertModelTester(self).get_config()
                    config = CustomConfig(**tiny_config.to_dict())
                    model = auto_class.from_config(config)
                    self.assertIsInstance(model, CustomModel)

                    with tempfile.TemporaryDirectory() as tmp_dir:
                        model.save_pretrained(tmp_dir)
                        new_model = auto_class.from_pretrained(tmp_dir)
                        # Loaded through the registered mapping, so this is the very same
                        # CustomModel class (not a dynamically imported copy).
                        self.assertIsInstance(new_model, CustomModel)

        finally:
            if "custom" in CONFIG_MAPPING._extra_content:
                del CONFIG_MAPPING._extra_content["custom"]
            for mapping in (
                MODEL_MAPPING,
                MODEL_FOR_PRETRAINING_MAPPING,
                MODEL_FOR_QUESTION_ANSWERING_MAPPING,
                MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
                MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
                MODEL_FOR_CAUSAL_LM_MAPPING,
                MODEL_FOR_MASKED_LM_MAPPING,
            ):
                if CustomConfig in mapping._extra_content:
                    del mapping._extra_content[CustomConfig]
Beispiel #4
0
    def test_from_pretrained_dynamic_model_local(self):
        """Save a registered ``CustomModel`` locally and reload it through ``AutoModel``.

        The reloaded model must have the same number of parameters as the saved
        one, and each pair of parameters must be equal (``torch.equal``). Both the
        ``"custom"`` config registration and the ``CustomConfig`` model
        registration are undone in ``finally``.
        """
        try:
            AutoConfig.register("custom", CustomConfig)
            AutoModel.register(CustomConfig, CustomModel)

            config = CustomConfig(hidden_size=32)
            model = CustomModel(config)

            with tempfile.TemporaryDirectory() as tmp_dir:
                model.save_pretrained(tmp_dir)

                new_model = AutoModel.from_pretrained(tmp_dir, trust_remote_code=True)
                original_params = list(model.parameters())
                reloaded_params = list(new_model.parameters())
                # zip() stops at the shorter sequence, so compare the counts first;
                # otherwise a reload with missing parameters would pass silently.
                self.assertEqual(len(original_params), len(reloaded_params))
                for p1, p2 in zip(original_params, reloaded_params):
                    self.assertTrue(torch.equal(p1, p2))

        finally:
            if "custom" in CONFIG_MAPPING._extra_content:
                del CONFIG_MAPPING._extra_content["custom"]
            if CustomConfig in MODEL_MAPPING._extra_content:
                del MODEL_MAPPING._extra_content[CustomConfig]
Beispiel #5
0
    def test_push_to_hub_dynamic_config(self):
        """Push a ``CustomConfig`` with ``push_to_hub`` and reload it with remote code.

        After pushing, the local config must carry an ``auto_map`` pointing at
        ``custom_configuration.CustomConfig``. Reloading from
        ``f"{USER}/test-dynamic-config"`` with ``trust_remote_code=True`` must give
        a config whose class is named ``CustomConfig`` and whose ``attribute`` is 42.
        """
        CustomConfig.register_for_auto_class()
        pushed_config = CustomConfig(attribute=42)
        pushed_config.push_to_hub("test-dynamic-config", use_auth_token=self._token)

        # Pushing added the auto_map field to the local config object.
        self.assertDictEqual(
            pushed_config.auto_map,
            {"AutoConfig": "custom_configuration.CustomConfig"},
        )

        remote_config = AutoConfig.from_pretrained(
            f"{USER}/test-dynamic-config", trust_remote_code=True
        )
        # The class comes from a dynamically imported module, so an isinstance check
        # against the local CustomConfig would fail; compare the class name instead.
        self.assertEqual(remote_config.__class__.__name__, "CustomConfig")
        self.assertEqual(remote_config.attribute, 42)