Example 1
class Net(nn.Module):
    def __init__(self, config, bert_state_dict, vocab_len, device = 'cpu'):
        super().__init__()
        self.bert = BertModel(config)
        if bert_state_dict is not None:
            self.bert.load_state_dict(bert_state_dict)
        self.bert.eval()
        self.rnn = nn.LSTM(bidirectional=True, num_layers=2, input_size=768, hidden_size=768//2, batch_first=True)
        self.fc = nn.Linear(768, vocab_len)
        self.device = device

    def forward(self, x, y):
        '''
        x: (N, T). int64
        y: (N, T). int64

        Returns
        enc: (N, T, VOCAB)
        '''
        x = x.to(self.device)
        y = y.to(self.device)

        with torch.no_grad():
            encoded_layers, _ = self.bert(x)
            enc = encoded_layers[-1]
        enc, _ = self.rnn(enc)
        logits = self.fc(enc)
        y_hat = logits.argmax(-1)
        return logits, y, y_hat
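
A minimal usage sketch for Net above, assuming the pytorch_pretrained_bert-style BertModel that returns (encoded_layers, pooled_output); the import, config values, vocabulary sizes, and tensor shapes are illustrative assumptions, not part of the original snippet.
import torch
from pytorch_pretrained_bert import BertConfig  # assumed source of BertConfig

config = BertConfig(vocab_size_or_config_json_file=30522, hidden_size=768,
                    num_hidden_layers=12, num_attention_heads=12,
                    intermediate_size=3072)
model = Net(config, bert_state_dict=None, vocab_len=10, device='cpu')

x = torch.randint(0, 30522, (2, 16))  # (N, T) int64 token ids
y = torch.randint(0, 10, (2, 16))     # (N, T) int64 tag ids
logits, y_out, y_hat = model(x, y)    # logits: (2, 16, 10), y_hat: (2, 16)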
Example 2
class Bert_Classification(nn.Module):
    def __init__(self,config,output_size):
        super(Bert_Classification, self).__init__()
        self.word_embeds = BertModel(config)
        self.word_embeds.load_state_dict(torch.load('/home/fayan/lxl/chip_2019/weights/robert/RoBertLarge_weight.bin'))
        #self.pooler = BertPooler(config)
        #self.dropout=nn.Dropout(0.5) 
        self.classification=nn.Linear(config.hidden_size,output_size)
    def forward(self, sentences,attention_mask,flag,labels=None):
        _, pooled_output= self.word_embeds(sentences, attention_mask=attention_mask, output_all_encoded_layers=False)
        #pooled_output = self.pooler(pooled_output)
        #print('model_shape:',_.size())
        #exit(1)
        '''
        if flag=='CLS':
            pooled_output=pooled_output[:,0]
        elif flag=='MAX':
            pooled_output=pooled_output.max(1)[0]
        elif flag=='MEAN':
            pooled_output=pooled_output.mean(1)
        # print(pooled_output.size())
        '''
        #pooled_output=self.dropout(pooled_output)
        logits=self.classification(pooled_output)
        # print('logits:',logits.size())
        # print('labels:', labels.size(),labels)
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits, labels)
            return loss, logits
        else:
            return logits
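
A hedged training-step sketch for Bert_Classification; config is a placeholder BertConfig that must match the checkpoint hard-coded in __init__ (the weight file has to exist on disk), and the optimizer, vocabulary size, and batch shapes are assumptions.
import torch
from torch.optim import Adam

model = Bert_Classification(config, output_size=2)
optimizer = Adam(model.parameters(), lr=2e-5)

sentences = torch.randint(0, 21128, (8, 64))  # (batch, seq_len) token ids, assumed vocab size
attention_mask = torch.ones_like(sentences)
labels = torch.randint(0, 2, (8,))

loss, logits = model(sentences, attention_mask, flag='CLS', labels=labels)
loss.backward()
optimizer.step()
optimizer.zero_grad()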
Example 3
class Dense_Classification(nn.Module):
    def __init__(self,config,output_size):
        super(Dense_Classification, self).__init__()
        self.word_embeds = BertModel(config)
        self.word_embeds.load_state_dict(torch.load('/home/fayan/lxl/chip_2019/weights/robert/RoBertLarge_weight.bin'))
        #self.pooler = BertPooler(config)
        self.linear=nn.Linear(config.hidden_size,128)
        #self.dropout=nn.Dropout(0.5) 
        self.classification=nn.Linear(config.hidden_size+128,output_size)
    def forward(self, sentences,attention_mask,flag,labels=None):
        _, pooled_output= self.word_embeds(sentences, attention_mask=attention_mask, output_all_encoded_layers=False)
        #pooled_output = self.pooler(pooled_output)
        linear=self.linear(pooled_output)
        #print(avg_pool.size(),max_pool.size(),hh_gru.size(),pooled_output.size())
        pooled_output=torch.cat((linear,pooled_output),1)
        #print(pooled_output.size())
        logits=self.classification(pooled_output)
        # print('logits:',logits.size())
        # print('labels:', labels.size(),labels)
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits, labels)
            return loss, logits
        else:
            return logits
Example 4
class LSTM_Classification(nn.Module):
    def __init__(self,config,output_size):
        super(LSTM_Classification, self).__init__()
        self.word_embeds = BertModel(config)
        self.word_embeds.load_state_dict(torch.load('/home/fayan/lxl/chip_2019/weights/robert/RoBertLarge_weight.bin'))
        #self.pooler = BertPooler(config)
        self.lstm=nn.LSTM(config.hidden_size,128,bidirectional=True,batch_first=True)
        #self.dropout=nn.Dropout(0.5) 
        self.classification=nn.Linear(config.hidden_size+128*6,output_size)
    def forward(self, sentences,attention_mask,flag,labels=None):
        _, pooled_output= self.word_embeds(sentences, attention_mask=attention_mask, output_all_encoded_layers=False)
        #pooled_output = self.pooler(pooled_output)
        h_lstm,(hidden_state,cell_state)=self.lstm(_)
        hh_lstm=torch.cat((hidden_state[0],hidden_state[1]),dim=1)
        avg_pool=torch.mean(h_lstm,1)
        max_pool,_=torch.max(h_lstm,1)
        #print(avg_pool.size(),max_pool.size(),hh_lstm.size(),pooled_output.size())
        pooled_output=torch.cat((avg_pool,max_pool,pooled_output,hh_lstm),1)
        logits=self.classification(pooled_output)
        # print('logits:',logits.size())
        # print('labels:', labels.size(),labels)
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits, labels)
            return loss, logits
        else:
            return logits
Example 5
def get_kobert_model(model_file, vocab_file, ctx="cpu"):
    bertmodel = BertModel(config=BertConfig.from_dict(bert_config))
    bertmodel.load_state_dict(torch.load(model_file))
    device = torch.device(ctx)
    bertmodel.to(device)
    bertmodel.eval()
    with open(vocab_file, 'rt') as f:
        vocab_b_obj = nlp.vocab.BERTVocab.from_json(f.read())
    return bertmodel, vocab_b_obj
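
A hedged call-site sketch for get_kobert_model; the file paths are placeholders, the module-level bert_config dict read by the function is assumed to be defined, and nlp is assumed to be gluonnlp, whose vocab objects expose to_indices for token-to-id lookup.
bertmodel, vocab = get_kobert_model('pytorch_kobert.params', 'kobert_vocab.json', ctx='cpu')
token_ids = vocab.to_indices(['[CLS]', '[SEP]'])  # map tokens to ids with the loaded vocab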
Example 6
def get_kobert_model(model_file, vocab_file, ctx="cpu"):
    bertmodel = BertModel(config=BertConfig.from_dict(bert_config))
    bertmodel.load_state_dict(torch.load(model_file))
    device = torch.device(ctx)
    bertmodel.to(device)
    bertmodel.eval()
    vocab_b_obj = nlp.vocab.BERTVocab.from_sentencepiece(vocab_file,
                                                         padding_token='[PAD]')
    return bertmodel, vocab_b_obj
Example 7
def get_kobert_model(ctx="cpu"):
    model_file = './kobert_model/pytorch_kobert_2439f391a6.params'
    vocab_file = './kobert_model/kobertvocab_f38b8a4d6d.json'
    bertmodel = BertModel(config=BertConfig.from_dict(bert_config))
    bertmodel.load_state_dict(torch.load(model_file))
    device = torch.device(ctx)
    bertmodel.to(device)
    bertmodel.eval()
    #print(vocab_file) #./kobertvocab_f38b8a4d6d.json
    with open(vocab_file, 'rt') as f:
        vocab_b_obj = nlp.vocab.BERTVocab.from_json(f.read())
    #print(vocab_b_obj)
    return bertmodel, vocab_b_obj
Example 8
class CNN_Classification(nn.Module):
    def __init__(self,config,output_size):
        super(CNN_Classification, self).__init__()
        self.word_embeds = BertModel(config)
        self.word_embeds.load_state_dict(torch.load('/home/fayan/lxl/chip_2019/weights/robert/RoBertLarge_weight.bin'))
        #self.pooler = BertPooler(config)
        self.embedding_dropout = SpatialDropout1D(config.hidden_dropout_prob)
        filters = [3, 4, 5]
        self.conv_layers = nn.ModuleList()
        for filter_size in filters:
            conv_block = nn.Sequential(
                nn.Conv1d(
                    config.hidden_size,
                    CHANNEL_UNITS,
                    kernel_size=filter_size,
                    padding=1,
                ),
                # nn.BatchNorm1d(CHANNEL_UNITS),
                # nn.ReLU(inplace=True),
            )
            self.conv_layers.append(conv_block)
        #self.dropout=nn.Dropout(0.5) 
        self.classification=nn.Linear(config.hidden_size+CHANNEL_UNITS*6,output_size)
    def forward(self, sentences,attention_mask,flag,labels=None):
        _, pooled_output= self.word_embeds(sentences, attention_mask=attention_mask, output_all_encoded_layers=False)
        #pooled_output = self.pooler(pooled_output)
        h_embedding = _.permute(0, 2, 1)
        feature_maps= []
        for layer in self.conv_layers:
            h_x= layer(h_embedding)
            feature_maps.append(
                F.max_pool1d(h_x, kernel_size=h_x.size(2)).squeeze()
            )
            feature_maps.append(
                F.avg_pool1d(h_x, kernel_size=h_x.size(2)).squeeze()
            )
        conv_features= torch.cat(feature_maps, 1)
        #print(conv_features.size())
        pooled_output= torch.cat((conv_features, pooled_output), 1)
        #print(pooled_output.size())
        logits=self.classification(pooled_output)
        # print('logits:',logits.size())
        # print('labels:', labels.size(),labels)
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits, labels)
            return loss, logits
        else:
            return logits
Example 9
class BioBertNER(nn.Module):

  def __init__(self, vocab_len, config, state_dict):
    super().__init__()
    self.bert = BertModel(config)
    self.bert.load_state_dict(state_dict, strict=False)
    self.dropout = nn.Dropout(p=0.3)
    self.output = nn.Linear(self.bert.config.hidden_size, vocab_len)
    self.softmax = nn.Softmax(dim=1)

  def forward(self, input_ids, attention_mask):
    encoded_layer, _ = self.bert(input_ids=input_ids, attention_mask=attention_mask)
    encl = encoded_layer[-1]
    out = self.dropout(encl)
    out = self.output(out)
    return out, out.argmax(-1)
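
A hedged inference sketch for BioBertNER; config and state_dict are placeholders for a BioBERT config and checkpoint, and the label count and shapes are illustrative.
import torch

model = BioBertNER(vocab_len=5, config=config, state_dict=state_dict)
model.eval()

input_ids = torch.randint(0, 28996, (2, 24))       # (batch, seq_len) token ids
attention_mask = torch.ones_like(input_ids)
with torch.no_grad():
    out, preds = model(input_ids, attention_mask)  # out: (2, 24, 5), preds: (2, 24)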
Example 10
class BioBERTModel(nn.Module):
    def __init__(self, args, config, state_dict):

        super().__init__()
        self.args = args
        self.bert = BertModel(config)
        self.bert.load_state_dict(state_dict, strict=False)
        self.hidden_bert = self.bert.config.hidden_size
        self.dropout_bert = self.bert.config.hidden_dropout_prob
        self.dropout = nn.Dropout(p=0.3)
        self.fc = nn.Linear(self.hidden_bert, self.args.num_labels)

    def forward(self, input_ids, token_type_ids, attention_mask):
        encoded_layer, _ = self.bert(input_ids=input_ids,
                                     token_type_ids=token_type_ids,
                                     attention_mask=attention_mask)
        encl = encoded_layer[-1]
        #print (encl.shape)
        currentLevel = self.dropout(encl)
        currentLevel = self.fc(currentLevel)
        return currentLevel
Example 11
class STS_NET(nn.Module):
    def __init__(self, config, bert_state_dict, device=Param.device):
        super().__init__()
        self.bert = BertModel(config)
        #print('bert initialized from config')
        if bert_state_dict is not None:
            self.bert.load_state_dict(bert_state_dict)
        self.bert.eval()
        self.dropout = nn.Dropout(p=Param.p)
        self.rnn = nn.LSTM(bidirectional=True,
                           num_layers=1,
                           input_size=768,
                           hidden_size=768 // 2)
        self.f1 = nn.Linear(768 // 2, 128)
        self.f2 = nn.Linear(128, 32)
        self.out = nn.Linear(32, 1)
        self.device = device

    def init_hidden(self, batch_size):
        return torch.zeros(2, batch_size,
                           768 // 2).to(self.device), torch.zeros(
                               2, batch_size, 768 // 2).to(self.device)

    def forward(self, x_f, x_r):
        batch_size = x_f.size()[0]
        x_f = x_f.to(self.device)
        x_r = x_r.to(self.device)
        xf_encoded_layers, _ = self.bert(x_f)
        enc_f = xf_encoded_layers[-1]
        enc = enc_f.permute(1, 0, 2)
        enc = self.dropout(enc)
        self.hidden = self.init_hidden(batch_size)
        rnn_out, self.hidden = self.rnn(enc, self.hidden)
        last_hidden_state, last_cell_state = self.hidden
        rnn_out = self.dropout(last_hidden_state)
        f1_out = F.relu(self.f1(rnn_out[-1]))
        f2_out = F.relu(self.f2(f1_out))
        out = self.out(f2_out)
        return out
Example 12
class BERT_LSTM_CRF(nn.Module):
    """
    bert_lstm_crf model
    bert_model=BertModel(config=BertConfig.from_json_file(args.bert_config_json))
    """
    def __init__(self, args, tagset_size, embedding_dim, hidden_dim, rnn_layers, dropout_ratio, dropout1, use_cuda=False):
        super(BERT_LSTM_CRF, self).__init__()
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.word_embeds = BertModel(config=BertConfig.from_json_file(args.bert_config_json))
        # print(self.word_embeds)
        self.word_embeds.load_state_dict(torch.load('./ckpts/9134_bert_weight.bin'))
        self.lstm = nn.LSTM(embedding_dim, hidden_dim,
                            num_layers=rnn_layers, bidirectional=True, dropout=dropout_ratio, batch_first=True)
        self.rnn_layers = rnn_layers
        self.dropout1 = nn.Dropout(p=dropout1)
        self.crf = CRF(target_size=tagset_size, average_batch=True, use_cuda=use_cuda)
        self.liner = nn.Linear(hidden_dim*2, tagset_size+2)
        self.tagset_size = tagset_size

    def rand_init_hidden(self, batch_size):
        """
        random initialize hidden variable
        """
        return Variable(
            torch.randn(2 * self.rnn_layers, batch_size, self.hidden_dim)), Variable(
            torch.randn(2 * self.rnn_layers, batch_size, self.hidden_dim))

    def forward(self, sentence, attention_mask=None):
        '''
        args:
            sentence (batch_size, word_seq_len): word-level representation of the sentence
            attention_mask (batch_size, word_seq_len): optional BERT attention mask

        return:
            emission features (batch_size, word_seq_len, tagset_size + 2) for the CRF layer
        '''
        batch_size = sentence.size(0)
        seq_length = sentence.size(1)
        embeds, _ = self.word_embeds(sentence, attention_mask=attention_mask, output_all_encoded_layers=False)
        # print(embeds,_)
        hidden = self.rand_init_hidden(batch_size)
        # if embeds.is_cuda:
        #     hidden = (i.cuda() for i in hidden)
        # embeds=(embeds,dim=0,keepdim=True)
        lstm_out, hidden = self.lstm(embeds)
        lstm_out = lstm_out.contiguous().view(-1, self.hidden_dim*2)
        d_lstm_out = self.dropout1(lstm_out)
        l_out = self.liner(d_lstm_out)
        lstm_feats = l_out.contiguous().view(batch_size, seq_length, -1)
        return lstm_feats

    def loss(self, feats, mask, tags):
        """
        feats: size=(batch_size, seq_len, tag_size)
            mask: size=(batch_size, seq_len)
            tags: size=(batch_size, seq_len)
        :return:
        """
        loss_value = self.crf.neg_log_likelihood_loss(feats, mask, tags)
        batch_size = feats.size(0)
        loss_value /= float(batch_size)
        return loss_value
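
A sketch of driving BERT_LSTM_CRF using only the methods shown above (forward for emission features, loss for the CRF negative log-likelihood); args, the hyper-parameters, vocabulary size, and batch shapes are placeholders.
import torch

model = BERT_LSTM_CRF(args, tagset_size=10, embedding_dim=768, hidden_dim=256,
                      rnn_layers=1, dropout_ratio=0.5, dropout1=0.5, use_cuda=False)

sentence = torch.randint(0, 21128, (4, 32))   # (batch, seq_len) token ids
mask = torch.ones(4, 32, dtype=torch.uint8)   # 1 for real tokens, 0 for padding
tags = torch.randint(0, 10, (4, 32))

feats = model(sentence, attention_mask=mask)  # (batch, seq_len, tagset_size + 2)
loss = model.loss(feats, mask, tags)
loss.backward()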
Example 13
class BertForTokenClassification(nn.Module):
    """BERT model for token-level classification.
    This module is composed of the BERT model with a linear layer on top of
    the full hidden state of the last layer.
    Params:
        `config`: a BertConfig instance for the underlying BERT encoder.
        `num_labels`: the number of token-level classes.
        `bert_state_dict`: optional pre-trained BERT weights to load into the encoder.
    Outputs:
        if `labels` is not `None`:
            Outputs the CrossEntropy classification loss of the output with the labels.
        if `labels` is `None`:
            Outputs the classification logits of shape [batch_size, sequence_length, num_labels].
    Example usage:
    ```python
    # Already been converted into WordPiece token ids
    input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
    input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
    token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
    config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
        num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
    num_labels = 2
    model = BertForTokenClassification(config, num_labels)
    logits = model(input_ids, token_type_ids, input_mask)
    ```
    """
    def __init__(self, config, num_labels, bert_state_dict=None):
        super().__init__()
        self.num_labels = num_labels
        self.bert = BertModel(config)
        if bert_state_dict is not None:
            self.bert.load_state_dict(bert_state_dict)
        # we don't fine-tune BERT; it requires too much GPU memory
        #self.bert.eval()
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        #self.rnn = nn.LSTM(bidirectional=True, num_layers=2, input_size=768, hidden_size=768//2, batch_first=True)
        self.classifier = nn.Linear(config.hidden_size, num_labels)

    def forward(self,
                input_ids,
                token_type_ids=None,
                attention_mask=None,
                labels=None):
        #with torch.no_grad():
        sequence_output, _ = self.bert(input_ids,
                                       token_type_ids,
                                       attention_mask,
                                       output_all_encoded_layers=False)
        sequence_output = self.dropout(sequence_output)
        #sequence_output, _ = self.rnn(sequence_output)
        #sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)

        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()

            # Only keep active parts of the loss
            if attention_mask is not None:
                active_loss = attention_mask.view(-1) == 1
                active_logits = logits.view(-1, self.num_labels)[active_loss]
                active_labels = labels.view(-1)[active_loss]
                loss = loss_fct(active_logits, active_labels)
            else:
                loss = loss_fct(logits.view(-1, self.num_labels),
                                labels.view(-1))
            return loss
        else:
            return logits
Example 14
class BertBiLstmCrf(BaseModel):  # model definition, four layers: embedding, LSTM, linear, CRF
    def __init__(self, args):
        super(BertBiLstmCrf, self).__init__(args)

        self.args = args
        self.vector_path = args.vector_path  # path to the pre-trained word-vector txt file
        self.embedding_dim = args.embedding_dim
        self.hidden_dim = args.hidden_dim  # hidden layer size
        self.tag_num = args.tag_num
        self.batch_size = args.batch_size
        self.bidirectional = True  # BiLSTM
        self.num_layers = args.num_layers
        self.pad_index = args.pad_index
        self.dropout = args.dropout  # dropout probability applied inside the network during training
        self.save_path = args.save_path

        embedding_dimension = args.bert_embedding_dim

        self.embedding = BertModel(
            config=BertConfig.from_json_file(args.bert_config_json)).to(DEVICE)
        self.embedding.load_state_dict(torch.load(args.bert_weight))
        self.drop = nn.Dropout(0.5)

        self.lstm = nn.LSTM(embedding_dimension,
                            self.hidden_dim,
                            bidirectional=self.bidirectional,
                            num_layers=self.num_layers,
                            dropout=self.dropout).to(DEVICE)
        # Halving hidden would also work, since the bidirectional LSTM doubles the output size;
        # here we keep hidden_dim and use hidden_dim * 2 in the linear layer below instead.
        self.linear1 = nn.Linear(self.hidden_dim * 2,
                                 self.hidden_dim).to(DEVICE)
        self.lin_drop = nn.Dropout(0.5)
        self.linear2 = nn.Linear(self.hidden_dim, self.tag_num + 2).to(
            DEVICE)  # linear projection from the hidden layer to the label space
        self.crf_layer = CRF(self.tag_num).to(DEVICE)

    # def init_weight(self):  # initialize the weight matrices of each layer
    #     nn.init.xavier_normal_(self.embedding.weight)
    #
    #     for name, param in self.lstm.named_parameters():
    #         if 'weight' in name:
    #             nn.init.xavier_normal_(param)
    #
    #     nn.init.xavier_normal_(self.linear.weight)

    def init_hidden(self, batch_size=None):  # build the initial (h0, c0) input for the LSTM layer
        if batch_size is None:
            batch_size = self.batch_size
        h0 = torch.zeros(self.num_layers * 2, batch_size,
                         self.hidden_dim).to(DEVICE)
        c0 = torch.zeros(self.num_layers * 2, batch_size,
                         self.hidden_dim).to(DEVICE)
        return h0, c0

    def loss(self, x, sent_lengths, y):
        mask = torch.ne(x, self.pad_index)  # mask is 0 wherever x equals the pad index (padding positions)
        emissions = self.lstm_forward(x, sent_lengths, mask)
        emissions = torch.transpose(
            emissions, 1, 0)  # transpose to [batch_size, sentence_length, tag_size]
        mask = torch.transpose(mask, 1, 0)
        y = torch.transpose(y, 1, 0)
        loss_function = self.crf_layer.neg_log_likelihood_loss
        return loss_function(emissions, mask, y)
        # return self.crf_layer(emissions, y, mask=mask)  # compared to forward, the 'decode' step is missing here

    def forward(self, x,
                sent_lengths):  # forward pass: model input -> forward -> output
        '''
        :param x: (sentence_length, batch_size)
        :param sent_lengths: (batch_size)
        :return tag_list:(sentence_length, batch_size)
        '''
        mask = torch.ne(x, self.pad_index)
        emissions = self.lstm_forward(x, sent_lengths, mask)
        emissions = torch.transpose(
            emissions, 1, 0)  # transpose to [batch_size, sentence_length, tag_size]
        mask = torch.transpose(mask, 1, 0)
        path_score, best_paths = self.crf_layer(emissions, mask)
        tag_list = []
        for i in range(best_paths.size(0)):
            tag_list.append(
                best_paths[i].cpu().data.numpy()[:torch.sum(mask[i])])
        return tag_list

    def lstm_forward(self, sentence, sent_lengths, mask):
        # x = self.embedding(sentence.to(DEVICE)).to(DEVICE)  # input embedding, output x
        embeds, _ = self.embedding(sentence,
                                   attention_mask=mask,
                                   output_all_encoded_layers=False)
        embeds = self.drop(embeds)
        embeds = pack_padded_sequence(embeds, sent_lengths)  # pack to align variable sequence lengths
        hidden = self.init_hidden(batch_size=len(sent_lengths))
        lstm_out, _ = self.lstm(
            embeds,
            hidden)  # LSTM input -> lstm_out; hidden must be passed as a single (h0, c0) tuple
        lstm_out, new_batch_size = pad_packed_sequence(lstm_out)

        assert torch.equal(sent_lengths, new_batch_size.to(DEVICE))
        y = self.linear1(lstm_out.to(DEVICE))  # linear layer; maps the output dimension toward the number of labels
        y = self.lin_drop(y)
        y = self.linear2(y)
        return y.to(DEVICE)
Example 15
class Net(nn.Module):
    def __init__(self, config, bert_state_dict, vocab_len, device='cuda'):
        super().__init__()
        self.bert = BertModel(config)
        self.num_layers = 2
        self.input_size = 768
        self.hidden_size = 768
        self.tagset_size = vocab_len
        # BERT outputs 768-dim vectors; the BiLSTM below also outputs (768 // 2) * 2 = 768 dims.
        if bert_state_dict is not None:
            self.bert.load_state_dict(bert_state_dict)
        self.bert.eval()
        # Each LSTM input is a 768-dim vector; each direction outputs a 768//2-dim vector.
        self.lstm = nn.LSTM(self.input_size,
                            self.hidden_size // 2,
                            self.num_layers,
                            batch_first=True,
                            bidirectional=True)
        self.fc = nn.Linear(self.hidden_size, vocab_len)
        self.device = device

    def init_hidden(self, batch_size):
        ''' Initializes hidden state '''
        # Create two new tensors with sizes n_layers x batch_size x hidden_dim,
        # initialized to zero, for hidden state and cell state of LSTM
        weight = next(self.parameters()).data

        if self.device == 'cuda':
            hidden = (nn.init.xavier_normal_(
                weight.new(self.num_layers * 2, batch_size,
                           self.hidden_size // 2).zero_()).cuda(),
                      nn.init.xavier_normal_(
                          weight.new(self.num_layers * 2, batch_size,
                                     self.hidden_size // 2).zero_()).cuda())
        else:
            hidden = (nn.init.xavier_normal_(
                weight.new(self.num_layers * 2, batch_size,
                           self.hidden_size // 2).zero_()),
                      nn.init.xavier_normal_(
                          weight.new(self.num_layers * 2, batch_size,
                                     self.hidden_size // 2).zero_()))

        return hidden

    def init_eval_hidden(self, batch_size):
        ''' Initializes hidden state '''
        # Create two new tensors with sizes n_layers x batch_size x hidden_dim,
        # initialized to zero, for hidden state and cell state of LSTM
        weight = next(self.parameters()).data

        hidden = (nn.init.xavier_normal_(
            weight.new(self.num_layers * 2, 1, self.hidden_size // 2).zero_()),
                  nn.init.xavier_normal_(
                      weight.new(self.num_layers * 2, 1,
                                 self.hidden_size // 2).zero_()))

        return hidden

    def forward(self, x, hidden):
        with torch.no_grad():
            encoded_layers, _ = self.bert(x)
            enc = encoded_layers[-1]
        out, hidden = self.lstm(enc, hidden)
        logits = self.fc(out)
        # softmax = torch.nn.Softmax(dim=2)
        # logits = softmax(logits)
        y_hat = logits.argmax(-1)
        return logits, hidden, y_hat
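
A minimal sketch of running the stateful Net above with its explicit hidden state; config, the vocabulary sizes, and the shapes are illustrative, and device='cpu' keeps everything off the GPU.
import torch

model = Net(config, bert_state_dict=None, vocab_len=10, device='cpu')
hidden = model.init_eval_hidden(batch_size=1)  # fresh (h0, c0) for a single sequence
x = torch.randint(0, 30522, (1, 16))           # (N, T) token ids
logits, hidden, y_hat = model(x, hidden)       # logits: (1, 16, 10), y_hat: (1, 16)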