from transformers import AutoConfig, AutoModel


class PretrainedAutoModel:
    """Loads pretrained transformer encoders, optionally caching them."""

    # Class-level cache mapping model names to loaded model instances.
    _cache: dict = {}

    @classmethod
    def load(
        cls,
        model_name: str,
        tokenizer_name: str,  # accepted for interface symmetry; unused here
        cache_model: bool = True,
        adapter_size: int = 8,
        pretrained: bool = True,
    ) -> AutoModel:
        # Names of the form "adapter_<model>" select the adapter variant.
        has_adapter = False
        if model_name.startswith("adapter"):
            has_adapter = True
            # Split on the first underscore only, so model names that
            # themselves contain underscores are kept intact.
            _, model_name = model_name.split("_", 1)

        # Serve repeat requests from the cache instead of reloading.
        if model_name in cls._cache:
            return cls._cache[model_name]

        pretrained_config = AutoConfig.from_pretrained(
            model_name, output_hidden_states=True
        )
        if has_adapter:
            from src.modules.modeling_adapter_bert import AdapterBertModel

            pretrained_config.adapter_size = adapter_size
            model = AdapterBertModel.from_pretrained(
                model_name, config=pretrained_config
            )
        elif pretrained:
            model = AutoModel.from_pretrained(model_name, config=pretrained_config)
        else:
            # Build the architecture with randomly initialized weights.
            model = AutoModel.from_config(config=pretrained_config)

        if cache_model:
            cls._cache[model_name] = model
        return model
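# A minimal usage sketch, assuming network access to the Hugging Face hub; the
# model name "bert-base-uncased" is illustrative and not taken from the original.
if __name__ == "__main__":
    model = PretrainedAutoModel.load("bert-base-uncased", "bert-base-uncased")
    cached = PretrainedAutoModel.load("bert-base-uncased", "bert-base-uncased")
    assert model is cached  # the second call is served from the class cache
    # An "adapter_"-prefixed name would instead route through the
    # project-internal AdapterBertModel with adapter_size=8.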
def __init__(self, config: Config, *args, **kwargs):
    super().__init__()
    self.config = config
    hf_params = {"config": self._build_encoder_config(config)}
    should_random_init = self.config.get("random_init", False)

    # For BERT models, initialize using the TorchScript-compatible Jit version.
    if self.config.bert_model_name.startswith("bert-"):
        if should_random_init:
            self.module = BertModelJit(**hf_params)
        else:
            self.module = BertModelJit.from_pretrained(
                self.config.bert_model_name, **hf_params
            )
    else:
        if should_random_init:
            self.module = AutoModel.from_config(**hf_params)
        else:
            self.module = AutoModel.from_pretrained(
                self.config.bert_model_name, **hf_params
            )

    self.embeddings = self.module.embeddings
    # Preserve the user-supplied options, then expose the Hugging Face model
    # config as `self.config` so callers can read hidden sizes, layer counts, etc.
    self.original_config = self.config
    self.config = self.module.config
    self._init_segment_embeddings()
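# A hedged usage sketch. `TransformerEncoder` stands in for whichever nn.Module
# subclass defines the __init__ above (the class name is not shown in this
# snippet), and the OmegaConf-backed Config is an assumption:
if __name__ == "__main__":
    from omegaconf import OmegaConf

    config = OmegaConf.create(
        {"bert_model_name": "bert-base-uncased", "random_init": False}
    )
    encoder = TransformerEncoder(config)  # hypothetical enclosing class
    # After construction, `encoder.config` is the Hugging Face model config,
    # while the original options remain available as `encoder.original_config`.
    print(encoder.config.hidden_size)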