def test_new_config_registration(self):
    """Register ``CustomConfig`` with ``AutoConfig`` and reload it through the auto-API.

    Whatever happens in the body, the "custom" entry is removed from
    ``CONFIG_MAPPING._extra_content`` on exit so the registration does not
    stay behind after this test.
    """
    try:
        AutoConfig.register("custom", CustomConfig)

        # A model type that does not match the config class is expected to be rejected.
        with self.assertRaises(ValueError):
            AutoConfig.register("model", CustomConfig)

        # So is re-registering a model type that already exists in the library.
        with self.assertRaises(ValueError):
            AutoConfig.register("bert", BertConfig)

        # Once registered, the custom config should round-trip through the auto-API.
        custom_config = CustomConfig()
        with tempfile.TemporaryDirectory() as save_dir:
            custom_config.save_pretrained(save_dir)
            reloaded_config = AutoConfig.from_pretrained(save_dir)
            self.assertIsInstance(reloaded_config, CustomConfig)
    finally:
        # Undo the registration made above (if it got that far).
        extra_configs = CONFIG_MAPPING._extra_content
        if "custom" in extra_configs:
            del extra_configs["custom"]
def test_push_to_hub_dynamic_config(self):
    """Push a custom config through a cloned ``Repository`` and reload it with ``AutoConfig``.

    The config is saved into a local clone of ``{USER}/test-dynamic-config``,
    pushed, and then fetched back with ``trust_remote_code=True``.
    """
    # NOTE(review): another test with this exact name appears in this file; if both
    # live in the same class, the later definition replaces the earlier one — confirm.
    CustomConfig.register_for_auto_class()
    custom_config = CustomConfig(attribute=42)
    repo_id = f"{USER}/test-dynamic-config"

    with tempfile.TemporaryDirectory() as work_dir:
        hub_repo = Repository(work_dir, clone_from=repo_id, use_auth_token=self._token)
        custom_config.save_pretrained(work_dir)

        # Saving is expected to have filled in the auto_map field on the config.
        expected_auto_map = {"AutoConfig": "custom_configuration.CustomConfig"}
        self.assertDictEqual(custom_config.auto_map, expected_auto_map)

        # The module holding the config code should now sit next to the saved files.
        module_path = os.path.join(work_dir, "custom_configuration.py")
        self.assertTrue(os.path.isfile(module_path))

        hub_repo.push_to_hub()

    reloaded_config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
    # The reloaded object comes from a dynamically loaded module, so an isinstance
    # check is not usable here; compare the class name instead.
    self.assertEqual(reloaded_config.__class__.__name__, "CustomConfig")
    self.assertEqual(reloaded_config.attribute, 42)
def test_new_model_registration(self):
    """Register ``CustomModel`` with each auto-model class and load it through that class.

    For every auto class the test checks that mismatched or already-existing
    registrations raise ``ValueError``, then builds, saves and reloads a
    ``CustomModel``. All registrations made here are removed from the
    ``_extra_content`` of the config and model mappings on exit.
    """
    AutoConfig.register("custom", CustomConfig)

    auto_classes_under_test = [
        AutoModel,
        AutoModelForCausalLM,
        AutoModelForMaskedLM,
        AutoModelForPreTraining,
        AutoModelForQuestionAnswering,
        AutoModelForSequenceClassification,
        AutoModelForTokenClassification,
    ]

    try:
        for auto_cls in auto_classes_under_test:
            with self.subTest(auto_cls.__name__):
                # Pairing the custom model with the wrong config class is expected to fail.
                with self.assertRaises(ValueError):
                    auto_cls.register(BertConfig, CustomModel)

                auto_cls.register(CustomConfig, CustomModel)

                # Re-registering a pair that already exists in the library is expected to fail.
                with self.assertRaises(ValueError):
                    auto_cls.register(BertConfig, BertModel)

                # Once registered, the custom model should be usable through the auto-API.
                base_config = BertModelTester(self).get_config()
                custom_config = CustomConfig(**base_config.to_dict())
                built_model = auto_cls.from_config(custom_config)
                self.assertIsInstance(built_model, CustomModel)

                with tempfile.TemporaryDirectory() as save_dir:
                    built_model.save_pretrained(save_dir)
                    reloaded_model = auto_cls.from_pretrained(save_dir)
                    # The reloaded model should also be a CustomModel.
                    self.assertIsInstance(reloaded_model, CustomModel)
    finally:
        # Remove the config registration first, then every model registration.
        extra_configs = CONFIG_MAPPING._extra_content
        if "custom" in extra_configs:
            del extra_configs["custom"]

        model_mappings = (
            MODEL_MAPPING,
            MODEL_FOR_PRETRAINING_MAPPING,
            MODEL_FOR_QUESTION_ANSWERING_MAPPING,
            MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
            MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
            MODEL_FOR_CAUSAL_LM_MAPPING,
            MODEL_FOR_MASKED_LM_MAPPING,
        )
        for model_mapping in model_mappings:
            extra_models = model_mapping._extra_content
            if CustomConfig in extra_models:
                del extra_models[CustomConfig]
def test_from_pretrained_dynamic_model_local(self):
    """Save a registered ``CustomModel`` locally and reload it via ``AutoModel``.

    The reloaded model's parameters must be equal, tensor by tensor, to those
    of the model that was saved. Registrations are removed on exit.
    """
    try:
        AutoConfig.register("custom", CustomConfig)
        AutoModel.register(CustomConfig, CustomModel)

        custom_config = CustomConfig(hidden_size=32)
        original_model = CustomModel(custom_config)

        with tempfile.TemporaryDirectory() as save_dir:
            original_model.save_pretrained(save_dir)
            reloaded_model = AutoModel.from_pretrained(save_dir, trust_remote_code=True)

            # Compare parameters pairwise, in the order each model yields them.
            param_pairs = zip(original_model.parameters(), reloaded_model.parameters())
            for saved_param, loaded_param in param_pairs:
                self.assertTrue(torch.equal(saved_param, loaded_param))
    finally:
        # Drop the config registration, then the model registration.
        extra_configs = CONFIG_MAPPING._extra_content
        if "custom" in extra_configs:
            del extra_configs["custom"]

        extra_models = MODEL_MAPPING._extra_content
        if CustomConfig in extra_models:
            del extra_models[CustomConfig]
def test_push_to_hub_dynamic_config(self):
    """Push a custom config with ``push_to_hub`` and reload it with ``AutoConfig``.

    The config is pushed under the name "test-dynamic-config" and fetched back
    from ``{USER}/test-dynamic-config`` with ``trust_remote_code=True``.
    """
    # NOTE(review): another test with this exact name appears in this file; if both
    # live in the same class, the later definition replaces the earlier one — confirm.
    CustomConfig.register_for_auto_class()
    custom_config = CustomConfig(attribute=42)

    custom_config.push_to_hub("test-dynamic-config", use_auth_token=self._token)

    # Pushing is expected to have filled in the auto_map field on the config.
    expected_auto_map = {"AutoConfig": "custom_configuration.CustomConfig"}
    self.assertDictEqual(custom_config.auto_map, expected_auto_map)

    reloaded_config = AutoConfig.from_pretrained(
        f"{USER}/test-dynamic-config",
        trust_remote_code=True,
    )
    # The reloaded object comes from a dynamically loaded module, so an isinstance
    # check is not usable here; compare the class name instead.
    self.assertEqual(reloaded_config.__class__.__name__, "CustomConfig")
    self.assertEqual(reloaded_config.attribute, 42)