Exemplo n.º 1
0
    def __init__(self, config: Config, *args, **kwargs):
        """Load the wrapped transformer module described by ``config``.

        ``config.bert_model_name`` names the pretrained checkpoint to load:

        * names starting with ``"bert-"`` are loaded via ``BertModelJit``;
        * all other names are loaded via ``AutoModel``.

        In both cases the encoder config returned by
        ``self._build_encoder_config(config)`` is passed as the ``config``
        keyword to ``from_pretrained``.

        Once the module is loaded:

        * ``self.embeddings`` refers to the module's ``embeddings``;
        * the caller's config is kept on ``self.original_config``;
        * ``self.config`` is rebound to the loaded module's ``config``;
        * ``self._init_segment_embeddings()`` is called.

        Args:
            config: Encoder configuration. It must expose ``bert_model_name``.
            *args: Accepted for interface compatibility; unused here.
            **kwargs: Accepted for interface compatibility; unused here.
        """
        super().__init__()
        self.config = config
        encoder_config = self._build_encoder_config(config)

        checkpoint = self.config.bert_model_name
        # BERT checkpoints are initialized with the Jit version of the model;
        # any other checkpoint is resolved generically through AutoModel.
        model_cls = BertModelJit if checkpoint.startswith("bert-") else AutoModel
        self.module = model_cls.from_pretrained(checkpoint, config=encoder_config)

        self.embeddings = self.module.embeddings
        # Preserve the caller-supplied config, then expose the loaded model's
        # own config as ``self.config`` for the rest of the object.
        self.original_config, self.config = self.config, self.module.config
        self._init_segment_embeddings()