Example #1
        def create_and_check_xlnet_base_model(self, config, input_ids_1,
                                              input_ids_2, input_ids_q,
                                              perm_mask, input_mask,
                                              target_mapping, segment_ids,
                                              lm_labels, sequence_labels,
                                              is_impossible_labels):
            model = XLNetModel(config)
            model.eval()

            _, _ = model(input_ids_1, input_mask=input_mask)
            _, _ = model(input_ids_1, attention_mask=input_mask)
            _, _ = model(input_ids_1, token_type_ids=segment_ids)
            outputs, mems_1 = model(input_ids_1)

            result = {
                "mems_1": mems_1,
                "outputs": outputs,
            }

            self.parent.assertListEqual(
                list(result["outputs"].size()),
                [self.batch_size, self.seq_length, self.hidden_size])
            self.parent.assertListEqual(
                list(list(mem.size()) for mem in result["mems_1"]),
                [[self.seq_length, self.batch_size, self.hidden_size]] *
                self.num_hidden_layers)
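For reference, a minimal standalone sketch (not part of the original test) that reproduces the shapes these assertions check; the tiny config sizes are assumptions:

import torch
from transformers import XLNetConfig, XLNetModel

config = XLNetConfig(vocab_size=320, d_model=32, n_layer=2, n_head=4,
                     d_inner=64, mem_len=10)
model = XLNetModel(config)
model.eval()

input_ids = torch.randint(0, config.vocab_size, (2, 10))  # [batch, seq_len]
with torch.no_grad():
    outputs, mems = model(input_ids)[:2]

print(outputs.shape)             # [batch, seq_len, d_model]
print(len(mems), mems[0].shape)  # n_layer memories of [mem_len, batch, d_model]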
Example #2
    def __init__(self, config):
        super().__init__()

        self.vocab = vocab = Vocab.load(config['vocab_file'])
        self.src_word_embed = nn.Embedding(len(vocab.source_tokens), config['source_embedding_size'])
        self.config = config

        self.decoder_cell_init = nn.Linear(config['source_encoding_size'], config['decoder_hidden_size'])

        if self.config['transformer'] == 'none':
            dropout = config['dropout']
            self.lstm_encoder = nn.LSTM(input_size=self.src_word_embed.embedding_dim,
                                        hidden_size=config['source_encoding_size'] // 2, num_layers=config['num_layers'],
                                        batch_first=True, bidirectional=True, dropout=dropout)

            self.dropout = nn.Dropout(dropout)

        elif self.config['transformer'] == 'bert':
            self.vocab_size = len(self.vocab.source_tokens) + 1

            state_dict = torch.load('saved_checkpoints/bert_2604/bert_pretrained_epoch_23_batch_140000.pth')

            keys_to_delete = ["cls.predictions.bias", "cls.predictions.transform.dense.weight", "cls.predictions.transform.dense.bias", "cls.predictions.transform.LayerNorm.weight",
                            "cls.predictions.transform.LayerNorm.bias", "cls.predictions.decoder.weight", "cls.predictions.decoder.bias",
                            "cls.seq_relationship.weight", "cls.seq_relationship.bias"]

            from collections import OrderedDict
            new_state_dict = OrderedDict()
            for k, v in state_dict['model'].items():
                if k in keys_to_delete: continue
                name = k[5:] # remove `bert.`
                new_state_dict[name] = v

            bert_config = BertConfig(vocab_size=self.vocab_size, max_position_embeddings=512, num_hidden_layers=6, hidden_size=256, num_attention_heads=4)
            self.bert_model = BertModel(bert_config)
            self.bert_model.load_state_dict(new_state_dict)

        elif self.config['transformer'] == 'xlnet':
            self.vocab_size = len(self.vocab.source_tokens) + 1

            state_dict = torch.load('saved_checkpoints/xlnet_2704/xlnet1_pretrained_epoch_13_iter_500000.pth')

            keys_to_delete = ["lm_loss.weight", "lm_loss.bias"]

            from collections import OrderedDict
            new_state_dict = OrderedDict()
            for k, v in state_dict['model'].items():
                if k in keys_to_delete: continue
                if k[:12] == 'transformer.': name = k[12:]
                else:                       name = k
                new_state_dict[name] = v

            xlnet_config = XLNetConfig(vocab_size=self.vocab_size, d_model=256, n_layer=12)
            self.xlnet_model = XLNetModel(xlnet_config)
            self.xlnet_model.load_state_dict(new_state_dict)
        else:
            raise ValueError("Unknown transformer type '{}'".format(self.config['transformer']))
Example #3
        def create_and_check_xlnet_base_model_with_att_output(
                self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask,
                input_mask, target_mapping, segment_ids, lm_labels,
                sequence_labels, is_impossible_labels, token_labels):
            model = XLNetModel(config)
            model.eval()

            _, _, attentions = model(input_ids_1,
                                     target_mapping=target_mapping)

            self.parent.assertEqual(len(attentions), config.n_layer)
            self.parent.assertIsInstance(attentions[0], tuple)
            self.parent.assertEqual(len(attentions[0]), 2)
            # the two attention streams should have matching shapes
            self.parent.assertEqual(attentions[0][0].shape,
                                    attentions[0][1].shape)
Example #4
    def __init__(self, model_name, num_labels=2):
        super(ClassificationXLNet, self).__init__()

        self.transformer = XLNetModel.from_pretrained(model_name)
        self.max_pool = nn.MaxPool1d(64)
        self.drop = nn.Dropout(0.3)
        self.linear = nn.Sequential(nn.Linear(768, num_labels))
Example #5
    def __init__(self, bert_config, device, dropout_rate, n_class, out_channel=16):
        """

        :param bert_config: str, BERT configuration description
        :param device: torch.device
        :param dropout_rate: float
        :param n_class: int
        :param out_channel: int, NOTE: out_channel per layer of BERT
        """

        super(CustomBertConvModel, self).__init__()

        self.bert_config = bert_config
        self.dropout_rate = dropout_rate
        self.n_class = n_class
        self.out_channel = out_channel
        self.bert = XLNetModel.from_pretrained(self.bert_config, output_hidden_states=True)
        self.out_channels = self.bert.config.num_hidden_layers*self.out_channel
        self.tokenizer = XLNetTokenizer.from_pretrained(self.bert_config)
        self.conv = nn.Conv2d(in_channels=self.bert.config.num_hidden_layers,
                              out_channels=self.out_channels,
                              kernel_size=(3, self.bert.config.hidden_size),
                              groups=self.bert.config.num_hidden_layers)
        self.hidden_to_softmax = nn.Linear(self.out_channels, self.n_class, bias=True)
        self.dropout = nn.Dropout(p=self.dropout_rate)
        self.device = device
Example #6
    def __init__(self, bert_config, device, dropout_rate, n_class, lstm_hidden_size=None):
        """

        :param bert_config: str, BERT configuration description
        :param device: torch.device
        :param dropout_rate: float
        :param n_class: int
        :param lstm_hidden_size: int
        """

        super(CustomBertLSTMAttentionModel, self).__init__()

        self.bert_config = bert_config
        self.bert = XLNetModel.from_pretrained(self.bert_config, output_hidden_states=False)
        self.tokenizer = XLNetTokenizer.from_pretrained(self.bert_config)

        if not lstm_hidden_size:
            self.lstm_hidden_size = self.bert.config.hidden_size
        else:
            self.lstm_hidden_size = lstm_hidden_size
        self.n_class = n_class
        self.dropout_rate = dropout_rate
        self.lstm = nn.LSTM(self.bert.config.hidden_size, self.lstm_hidden_size, bidirectional=True)
        self.hidden_to_softmax = nn.Linear(self.lstm_hidden_size * 2, n_class, bias=True)
        self.dropout = nn.Dropout(p=self.dropout_rate)
        self.softmax = nn.Softmax(dim=1)
        self.device = device
Example #7
 def __init__(self, model_name, cache_dir, task_list):
     super(MultiTaskModel, self).__init__()
     cache = os.path.join(cache_dir, model_name)
     self.transformer = XLNetModel.from_pretrained(model_name,
                                                   cache_dir=cache)
     self.transformer_config = self.transformer.config
     self.dropout = DropoutWrapper(self.transformer_config.dropout)
     self.decoderID = {}  # maps the model-internal task_id to a decoder_id
     # self.decoder = {}
     self.decoder_list = nn.ModuleList()
     for innerid, task in enumerate(task_list):
         if task[1] == TaskType["classification"]:  # task[1] = tasktype
             classifier = Classification(self.transformer_config)
             # classifier = Classification(self.transformer_config)
             print("use simple classification")
             self.decoder_list.append(classifier)
         elif task[1] == TaskType["SANclassification"]:
             classifier = SANClassifier(self.transformer_config.hidden_size,
                                        self.transformer_config.hidden_size,
                                        label_size=1,
                                        dropout=self.dropout)
             print("use SANClassifier")
             self.decoder_list.append(classifier)
         else:
             pass
         self.decoderID[task[0]] = innerid
Example #8
    def __init__(self,
                 max_seq_len=512,
                 min_window_overlap=128,
                 mask='none',
                 dropout_rate=0.1,
                 fp16=False,
                 yes_no_logits=False,
                 ctx_emb='bert'):
        super(DiaBERT, self).__init__()
        assert min_window_overlap % 2 == 0

        self.ctx_emb = ctx_emb
        if ctx_emb == 'bert':
            pretrained_bert = BertModel.from_pretrained('bert-base-uncased')
            self.bert = Bert(768, pretrained_bert, mask)
        elif ctx_emb == 'xlnet':
            self.bert = XLNetModel.from_pretrained('xlnet-base-cased')

        self.linear_start_end = torch.nn.Linear(768, 2, bias=False)
        self.max_seq_len = max_seq_len
        self.min_window_overlap = min_window_overlap
        self.fp16 = fp16
        if yes_no_logits:
            self.yesno_mlp = torch.nn.Sequential(torch.nn.Linear(768, 256),
                                                 torch.nn.ReLU(),
                                                 torch.nn.Linear(256, 3))
        else:
            self.yesno_mlp = None
Example #9
    def __init__(self,
                 num_labels,
                 pretrained_model_name_or_path=None,
                 cat_num=0,
                 token_size=None,
                 MAX_SEQUENCE_LENGTH=512):
        super(BertModelForBinaryMultiLabelClassifier, self).__init__()
        if pretrained_model_name_or_path:
            # self.model = BertModel.from_pretrained(
            self.model = XLNetModel.from_pretrained(
                pretrained_model_name_or_path)
        else:
            raise NotImplementedError
        self.num_labels = num_labels
        if cat_num > 0:
            self.catembedding = nn.Embedding(cat_num, 768)
            self.catdropout = nn.Dropout(0.2)
            self.catactivate = nn.ReLU()

            self.catembeddingOut = nn.Embedding(cat_num, cat_num // 2 + 1)
            self.catactivateOut = nn.ReLU()
            self.dropout = nn.Dropout(0.2)
            self.classifier = nn.Linear(768 + cat_num // 2 + 1, num_labels)
        else:
            self.catembedding = None
            self.catdropout = None
            self.catactivate = None
            self.catembeddingOut = None
            self.catactivateOut = None
            self.dropout = nn.Dropout(0.2)
            self.classifier = nn.Linear(768, num_labels)

        # resize
        if token_size:
            self.model.resize_token_embeddings(token_size)
Example #10
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = 3
        # RTE Task
        self.num_labels_3way = 3
        # RTE SPs multi-label task
        self.num_labels_multi = 5
        # RTE span detection task
        self.start_n_top = config.start_n_top
        self.end_n_top = config.end_n_top

        self.transformer = XLNetModel(config)
        self.sequence_summary = SequenceSummary(config)
        self.logits_proj_3way = nn.Linear(config.d_model, self.num_labels_3way)
        self.logits_proj_multi = nn.Linear(config.d_model,
                                           self.num_labels_multi)
        self.weights_3way = [1, 1.3, 3.3]
        self.weights_multi = [15, 10, 15, 5, 5]
        self.class_weights_3way = torch.FloatTensor(
            self.weights_3way).to(device)
        self.class_weights_multi = torch.FloatTensor(
            self.weights_multi).to(device)

        # RTE span detection task
        self.start_logits = PoolerStartLogits(config)
        self.end_logits = PoolerEndLogits(config)
        self.answer_class = PoolerAnswerClass(config)

        self.init_weights()
Example #11
    def __init__(self, vocab_size, hidden_size, output_size, num_labels=2, dropout_rate=0.3):
        super(XLNetClassification, self).__init__()
        self.xlnet = XLNetModel.from_pretrained('xlnet-base-cased')
        self.classifier = torch.nn.Linear(hidden_size, output_size)
        self.dropout = nn.Dropout(dropout_rate)
        self.embedding = nn.Embedding(vocab_size, hidden_size, padding_idx=0)
        torch.nn.init.xavier_normal_(self.classifier.weight)
Example #12
    def __init__(self,
                 config,
                 device,
                 pretrained_model,
                 with_semi=True,
                 with_sum=True):
        super().__init__()
        self.cls_x = nn.Linear(config.d_model, config.num_labels)
        self.cls_s = nn.Linear(config.d_model, config.num_labels)
        self.mlp_x = nn.Sequential(
            nn.Linear(config.hidden_size, config.hidden_size), nn.ReLU(),
            nn.Linear(config.hidden_size, 256))
        self.mlp_s = nn.Sequential(
            nn.Linear(config.hidden_size, config.hidden_size), nn.ReLU(),
            nn.Linear(config.hidden_size, 256))

        self.f = XLNetModel(config)
        self.scl_criterion = SupConLoss(temperature=0.3, base_temperature=0.3)
        self.ce_criterion = nn.CrossEntropyLoss()
        # self.f = copy.deepcopy(pretrained_enc)

        # self.f = RobertaModel(config)
        self.device = device
        self.init_weights(pretrained_model)
        self.with_semi = with_semi
        self.with_sum = with_sum
Example #13
def get_model(output_dir):
    '''
    output_dir: path to hugging face directory
    '''
    model = XLNetModel.from_pretrained(output_dir)
    tokeniser = XLNetTokenizer.from_pretrained(output_dir)
    return model, tokeniser
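A hypothetical call (the directory path is an assumption; assumes a transformers version with callable tokenizers):

model, tokeniser = get_model('path/to/finetuned-xlnet')
inputs = tokeniser("The weather is nice today.", return_tensors='pt')
last_hidden_state = model(**inputs)[0]  # [batch, seq_len, d_model]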
Example #14
    def __init__(self, config):
        super().__init__()
        if config['model']['pretrained_model'] == 'XLNet':
            self.pretrainedModel = XLNetModel.from_pretrained(
                config['model']['xlnet_base_chinese'])
            self.tokenizer = XLNetTokenizer.from_pretrained(
                config['model']['xlnet_base_chinese'], do_lower_case=True)

        if config['model']['pretrained_model'] == 'Bert':
            self.pretrainedModel = BertModel.from_pretrained(
                config['model']['bert_base_chinese'])
            self.tokenizer = BertTokenizer.from_pretrained(
                config['model']['bert_base_chinese'], do_lower_case=True)

        #for p in self.bertModel.parameters(): p.requires_grad = False
        self.dropout = nn.Dropout(config['model']['dropout'])
        self.lstm = nn.LSTM(
            input_size=768,
            hidden_size=768 // 2,
            batch_first=True,
            bidirectional=True
        )  #, num_layers=2,dropout=config['model']['dropout'])
        #self.layerNorm = nn.LayerNorm(768)
        self.fc = nn.Linear(768, len(tagDict))
        #weight = torch.Tensor([1, 1, 2.5, 2.5, 2.5]).to(config['DEVICE'])
        weight = torch.Tensor([1, 1, 3, 3, 3]).to(config['DEVICE'])
        self.criterion = nn.CrossEntropyLoss(weight=weight)
Example #15
    def __init__(self, config):
        super(XLNet, self).__init__()
        self.xlnet = XLNetModel.from_pretrained(config.model_path)

        self.isDropout = 0 < config.dropout < 1
        self.dropout = nn.Dropout(p=config.dropout)
        self.fc = nn.Linear(self.xlnet.d_model, config.num_class)
Example #16
    def __init__(self, num_labels=2):
        super(XLNetForMultiLabelSequenceClassification, self).__init__()
        self.num_labels = num_labels
        self.xlnet = XLNetModel.from_pretrained('xlnet-base-cased')
        self.classifier = torch.nn.Linear(768, num_labels)

        torch.nn.init.xavier_normal_(self.classifier.weight)
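The example defines no forward; a minimal sketch of one (an assumption, not from the original) that mean-pools the last hidden state before classifying:

    def forward(self, input_ids, attention_mask=None):
        # last hidden state from XLNet: [batch, seq_len, 768]
        last_hidden = self.xlnet(input_ids, attention_mask=attention_mask)[0]
        pooled = last_hidden.mean(dim=1)  # [batch, 768]
        return self.classifier(pooled)    # [batch, num_labels] logits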
Example #17
 def __init__(self, xlnet_config):
     super(XLNetABSATagger, self).__init__(xlnet_config)
     self.num_labels = xlnet_config.num_labels
     self.xlnet = XLNetModel(xlnet_config)
     self.tagger_config = xlnet_config.absa_tagger_config
     self.tagger = None
     if self.tagger_config.tagger == '':
         # hidden size at the penultimate layer
         penultimate_hidden_size = xlnet_config.d_model
     else:
         self.tagger_dropout = nn.Dropout(
             self.tagger_config.hidden_dropout_prob)
         if self.tagger_config.tagger in ['RNN', 'LSTM', 'GRU']:
             # 2-layer bi-directional rnn decoder
             self.tagger = getattr(nn, self.tagger_config.tagger)(
                 input_size=xlnet_config.d_model,
                 hidden_size=self.tagger_config.hidden_size // 2,
                 num_layers=self.tagger_config.n_rnn_layers,
                 batch_first=True,
                 bidirectional=True)
         elif self.tagger_config.tagger in ['CRF']:
             # crf tagger
             raise Exception("Unimplemented now!!")
         else:
             raise Exception('Unimplemented tagger %s...' %
                             self.tagger_config.tagger)
         penultimate_hidden_size = self.tagger_config.hidden_size
     self.tagger_dropout = nn.Dropout(
         self.tagger_config.hidden_dropout_prob)
     self.classifier = nn.Linear(penultimate_hidden_size,
                                 xlnet_config.num_labels)
     self.apply(self.init_weights)
Example #18
 def __init__(self, config):
     super(Model, self).__init__()
     self.xlnet = XLNetModel.from_pretrained(config.xlnet_path, num_labels=config.num_classes)
     for param in list(self.xlnet.parameters())[:-5]:
         param.requires_grad = False
     self.fc = nn.Linear(config.hidden_size, 192)
     self.fc1 = nn.Linear(192, config.num_classes)
Example #19
 def __init__(self):
     super(StsClassifier, self).__init__()
     self.xlnet = XLNetModel.from_pretrained('xlnet-large-cased')
     self.linear1 = nn.Linear(1024, 1024)
     self.linear2 = nn.Linear(1024, 512)
     self.linear3 = nn.Linear(512, 1)
     self.activation = nn.ReLU()
Example #20
    def __init__(self, args, num_class, use_cls=True):
        super().__init__()
        self.args = args
        # gcn layer
        self.use_cls = use_cls

        self.dropout = nn.Dropout(args.dropout)

        if args.basemodel == 'xlnet':
            self.xlnet = XLNetModel.from_pretrained(args.bert_model_dir)
        elif args.basemodel == 'xlnet_dialog':
            self.xlnet = XLNetModel_dialog.from_pretrained(args.bert_model_dir)
        self.xlnet.mem_len = args.mem_len
        self.xlnet.attn_type = args.attn_type
        in_dim = args.bert_dim

        pool_layers = [nn.Linear(in_dim, args.hidden_dim), nn.ReLU()]
        self.pool_fc = nn.Sequential(*pool_layers)

        # output mlp layers
        layers = []
        for _ in range(args.mlp_layers):
            layers += [nn.Linear(args.hidden_dim, args.hidden_dim), nn.ReLU()]
        layers += [nn.Linear(args.hidden_dim, num_class)]

        self.out_mlp = nn.Sequential(*layers)
Example #21
    def __init__(self,
                 pretrained_model_dir,
                 num_classes,
                 segment_len=150,
                 dropout_p=0.5):
        super(MyXLNetModel, self).__init__()

        self.seg_len = segment_len

        self.config = XLNetConfig.from_json_file(
            os.path.join(pretrained_model_dir, 'config.json'))
        self.config.mem_len = 150  # enable XLNet's segment-level memory
        self.xlnet = XLNetModel.from_pretrained(pretrained_model_dir,
                                                config=self.config)

        if feature_extract:  # module-level flag expected by this snippet
            for p in self.xlnet.parameters():  # transfer learning: use xlnet as a frozen feature extractor
                p.requires_grad = False

        d_model = self.config.hidden_size  # 768
        self.attention_layer1 = NyAttentioin(d_model, d_model // 2)
        self.attention_layer2 = NyAttentioin(d_model, d_model // 2)

        self.dropout = torch.nn.Dropout(p=dropout_p)
        self.fc = torch.nn.Linear(d_model, num_classes)
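Since mem_len is enabled above, long documents are presumably processed segment by segment; a hedged sketch of that pattern (the segment split and variable names are assumptions):

        # inside a forward pass: feed segments sequentially, carrying mems across calls
        mems = None
        for seg in segments:                 # each seg: [batch, seg_len] token ids
            out = self.xlnet(seg, mems=mems)
            hidden, mems = out[0], out[1]    # hidden: [batch, seg_len, d_model]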
Example #22
    def __init__(self, xlnet_path, num_classes, word_embedding, trained=True):
        super(XLNet, self).__init__()
        self.xlnet = XLNetModel.from_pretrained(xlnet_path)
        # whether the encoder is fine-tuned is controlled by `trained`
        for param in self.xlnet.parameters():
            param.requires_grad = trained

        self.fc = nn.Linear(self.xlnet.d_model, num_classes)
Example #23
    def __init__(self, config: XLNetConfig):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.xlnet = XLNetModel(config)
        self.seq_summary = SequenceSummary(config)

        self.dropout = nn.Dropout(config.dropout)

        self.fc = nn.Linear(config.hidden_size, config.num_labels)
        self.fc_bn = nn.BatchNorm1d(config.num_labels)

        self.init_weights()

        # Default: freeze xlnet
        for name, param in self.xlnet.named_parameters():
            param.requires_grad = False
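If partial fine-tuning is wanted later, the top blocks can be unfrozen; a sketch (XLNetModel exposes its blocks as `self.xlnet.layer`; unfreezing exactly two of them is an assumption):

        # unfreeze only the last two transformer blocks
        for param in self.xlnet.layer[-2:].parameters():
            param.requires_grad = True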
Example #24
    def __init__(self, config):
        super(XLNetForXMC, self).__init__(config)
        self.num_labels = config.num_labels

        self.transformer = XLNetModel(config)
        self.sequence_summary = SequenceSummary(config)

        self.init_weights()
Example #25
def get_bert(bert_model, bert_do_lower_case):
    # Avoid a hard dependency on BERT by only importing it if it's being used
    from transformers import XLNetTokenizer, XLNetModel
    model = XLNetModel.from_pretrained('huseinzol05/xlnet-base-bahasa-standard-cased')
    tokenizer = XLNetTokenizer.from_pretrained(
        'huseinzol05/xlnet-base-bahasa-standard-cased', do_lower_case = False
    )
    return tokenizer, model
Example #26
def select_pretrained(model_name, cache_dir):
    cache = os.path.join(cache_dir, model_name)
    if 'bert' in model_name:
        return BertModel.from_pretrained(model_name, cache_dir=cache)
    elif 'xlnet' in model_name:
        return XLNetModel.from_pretrained(model_name, cache_dir=cache)
    else:
        return None
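Hypothetical usage (the model name and cache directory are assumptions):

encoder = select_pretrained('xlnet-base-cased', cache_dir='./cache')
if encoder is None:
    raise ValueError('unsupported model name')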
Example #27
 def __init__(self):
     super(XLNetCls, self).__init__()
     self.xlnet = XLNetModel.from_pretrained(config.XLNET_MODEL_PATH)
     self.liner = torch.nn.Sequential(
         torch.nn.BatchNorm1d(config.EMBEDDING_DIM * 2), torch.nn.Dropout(),
         torch.nn.Linear(config.EMBEDDING_DIM * 2, 256),
         torch.nn.BatchNorm1d(256), torch.nn.Dropout(), torch.nn.ReLU(),
         torch.nn.Linear(256, 1), torch.nn.Sigmoid())
Example #28
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.transformer = XLNetModel(config)
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

        self.init_weights()
Example #29
class XLNetConv(XLNetPreTrainedModel):
    def __init__(self, config: XLNetConfig):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.xlnet = XLNetModel(config)
        self.seq_summary = SequenceSummary(config)

        self.convs = nn.ModuleList([
            nn.Conv2d(in_channels=1,
                      out_channels=config.n_filters,
                      kernel_size=(fsize, config.hidden_size))
            for fsize in config.filter_sizes
        ])

        self.dropout = nn.Dropout(config.dropout)
        # the concatenated pooled features have n_filters * len(filter_sizes) dims
        self.fc = nn.Linear(config.n_filters * len(config.filter_sizes),
                            config.num_labels)
        self.fc_bn = nn.BatchNorm1d(config.num_labels)

        self.init_weights()

        # Default: freeze xlnet
        for name, param in self.xlnet.named_parameters():
            param.requires_grad = False

    def forward(self, doc):
        """     
        Input:
            doc: [batch_size, seq_len, 2]           
        Returns:
            out: [batch_size, output_dim]  

        """
        # input_ids / attention_mask: [batch_size, seq_len]
        xln_out = self.xlnet(input_ids=doc[:, :, 0],
                             attention_mask=doc[:, :, 1])

        xln = xln_out[0]  # [batch_size, seq_len, hidden_size]

        xln = xln.unsqueeze(1)  # [batch_size, 1, seq_len, hidden_size]

        conved = [F.relu(conv(xln)) for conv in self.convs
                  ]  # [batch_size, n_filters, (seq_len-fsize+1), 1]
        conved = [conv.squeeze(3) for conv in conved
                  ]  # [batch_size, n_filters, (seq_len-fsize+1)]
        pooled = [F.max_pool1d(conv, conv.shape[2])
                  for conv in conved]  # [batch_size, n_filters, 1]
        pooled = [pool.squeeze(2)
                  for pool in pooled]  # [batch_size, n_filters]

        cat = torch.cat(pooled,
                        dim=1)  # [batch_size, n_filters * len(filter_sizes)]
        dp = self.dropout(cat)
        out = self.fc(dp)  # [batch_size, output_dim]
        out = self.fc_bn(out)
        # note: nn.CrossEntropyLoss expects pre-softmax logits; the softmax here
        # yields probabilities suitable for inference
        out = F.softmax(out, dim=1)  # [batch_size, output_dim]

        return out
Example #30
    def __init__(self, config):
        super(XLNetTokenClassificationHead, self).__init__(config)
        self.num_labels = config.num_labels

        self.transformer = XLNetModel(config)
        self.logits_proj = torch.nn.Linear(config.d_model, config.num_labels)
        self.dropout = torch.nn.Dropout(config.dropout)

        self.apply(self.init_weights)
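The head above also needs a forward; a hedged per-token sketch (not from the original example):

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        hidden = self.transformer(input_ids, attention_mask=attention_mask,
                                  token_type_ids=token_type_ids)[0]
        # project every position to label space: [batch, seq_len, num_labels]
        return self.logits_proj(self.dropout(hidden))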