def __init__(self, config: Config, *args, **kwargs):
    """Load the transformer named in ``config`` and expose its parts.

    Args:
        config: Encoder configuration. Its ``bert_model_name`` selects what
            ``from_pretrained`` loads, and the whole object is handed to
            ``_build_encoder_config`` to produce the ``config`` keyword
            passed to ``from_pretrained``.
        *args: Accepted but not used here.
        **kwargs: Accepted but not used here.
    """
    super().__init__()
    self.config = config
    encoder_config = self._build_encoder_config(config)
    model_name = self.config.bert_model_name

    # Names with the "bert-" prefix are loaded through the Jit BERT class;
    # every other name is resolved by AutoModel.
    loader = BertModelJit if model_name.startswith("bert-") else AutoModel
    self.module = loader.from_pretrained(model_name, config=encoder_config)
    self.embeddings = self.module.embeddings

    # Keep the config we were given under a separate attribute, then make
    # `self.config` point at the loaded module's own config.
    self.original_config = self.config
    self.config = self.module.config
    self._init_segment_embeddings()