コード例 #1
0
    def test_add_adapter(self):
        """Every adapter config/type pair is registered in the model config and alters the output.

        For each model class, each entry of ``ADAPTER_CONFIG_MAP`` is added once per
        ``AdapterType`` member. Language adapters are skipped for configs without an
        invertible adapter. After adding, the adapter must be listed in the model config,
        and a forward pass that activates it must differ from a plain forward pass.
        """
        for model_class in self.model_classes:
            config_class = model_class.config_class
            model = model_class(config_class())
            # A freshly constructed nn.Module is in training mode. Any dropout would then make
            # two forward passes differ even without an active adapter, turning the final
            # inequality check into a no-op. Eval mode makes the comparison meaningful.
            model.eval()

            for config_name, adapter_config in ADAPTER_CONFIG_MAP.items():
                for type_name, adapter_type in AdapterType.__members__.items():
                    # skip configs without invertible language adapters
                    if adapter_type == AdapterType.text_lang and not adapter_config.invertible_adapter:
                        continue
                    with self.subTest(model_class=model_class,
                                      config=config_name,
                                      adapter_type=type_name):
                        name = f"{type_name}-{config_name}"
                        model.add_adapter(name, adapter_type, config=adapter_config)

                        # adapter is correctly added to config
                        self.assertIn(name, model.config.adapters.adapter_list(adapter_type))
                        self.assertEqual(adapter_config, model.config.adapters.get(name))

                        # check forward pass: activating the adapter must change the output
                        input_ids = ids_tensor((1, 128), 1000)
                        input_data = {"input_ids": input_ids}
                        if adapter_type in (AdapterType.text_task, AdapterType.text_lang):
                            input_data["adapter_names"] = [name]
                        adapter_output = model(**input_data)
                        base_output = model(input_ids)
                        self.assertEqual(len(adapter_output), len(base_output))
                        self.assertFalse(torch.equal(adapter_output[0], base_output[0]))
コード例 #2
0
 def test_custom_attr(self):
     """An extra attribute (``dummy_attr``) can be set on a copy of each predefined config."""
     for original_config in ADAPTER_CONFIG_MAP.values():
         with self.subTest(config=type(original_config).__name__):
             # work on a copy so the shared entry in ADAPTER_CONFIG_MAP stays untouched
             config_copy = original_config.replace()
             config_copy.dummy_attr = "test_value"
             self.assertEqual(config_copy.dummy_attr, "test_value")
コード例 #3
0
    def test_config_immutable(self):
        """Reassigning a field (``ln_before``) on any predefined config raises FrozenInstanceError."""
        for frozen_config in ADAPTER_CONFIG_MAP.values():
            with self.subTest(config=type(frozen_config).__name__):
                with self.assertRaises(FrozenInstanceError):
                    frozen_config.ln_before = True
コード例 #4
0
    def test_model_config_serialization(self):
        """PretrainedConfigurations should not raise an Exception when serializing the config dict

        See, e.g., PretrainedConfig.to_json_string()
        """
        for config_name, adapter_config in ADAPTER_CONFIG_MAP.items():
            # subTest names the adapter config whose serialization failed
            with self.subTest(config=config_name):
                model = AutoModel.from_config(self.config())
                model.add_adapter("test", config=adapter_config)
                # should not raise an exception
                model.config.to_json_string()
コード例 #5
0
    def test_model_config_serialization(self):
        """PretrainedConfigurations should not raise an Exception when serializing the config dict

        See, e.g., PretrainedConfig.to_json_string()
        """
        for model_class in self.model_classes:
            # the config class only depends on the model class, not on the adapter config
            config_class = model_class.config_class
            for config_name, adapter_config in ADAPTER_CONFIG_MAP.items():
                # subTest names the model class / adapter config combination that failed
                with self.subTest(model_class=model_class, config=config_name):
                    model = model_class(config_class())
                    model.add_adapter("test", adapter_type=AdapterType.text_task, config=adapter_config)
                    # should not raise an exception
                    model.config.to_json_string()
コード例 #6
0
 def __init__(
     self,
     input_paths: List[str],
     output_path: str,
     template: str,
     extract_from_models: bool = False,
 ):
     """Store the constructor arguments and build a hash -> config-id lookup.

     Args:
         input_paths: stored as-is on ``self.input_paths``.
         output_path: stored on ``self.output_path``; any falsy value (e.g. ``None`` or
             ``""``) is replaced by ``DEFAULT_OUTPUT_PATH``.
         template: stored as-is on ``self.template``.
         extract_from_models: stored as-is on ``self.extract_from_models``.
     """
     self.input_paths = input_paths
     self.output_path = output_path if output_path else DEFAULT_OUTPUT_PATH
     self.template = template
     self.extract_from_models = extract_from_models
     # returns True for non-empty input, otherwise the error message string
     self._validate_func = lambda value: True if len(value) > 0 else "This field must not be empty."
     self._input_cache = {}
     # maps get_adapter_config_hash(config) -> its key in ADAPTER_CONFIG_MAP;
     # if two default configs hash identically, the later key wins
     self._config_id_lookup = {
         get_adapter_config_hash(default_config): config_id
         for config_id, default_config in ADAPTER_CONFIG_MAP.items()
     }