import numpy as np
import torch.nn as nn
from sklearn.metrics import (accuracy_score, classification_report, f1_score,
                             precision_score, recall_score)

# BertModel / BertPreTrainedModel and the CRF layer are assumed to be provided
# elsewhere in this example's repository (pytorch-pretrained-bert style BERT
# classes and a custom CRF exposing negative_log_loss / get_batch_best_path).


class Bert_CRF(BertPreTrainedModel):
    def __init__(self,
                 config,
                 num_tag,
                 need_birnn=True,
                 hidden_size=768,
                 hidden_dropout_prob=0.5,
                 rnn_dim=768):
        super(Bert_CRF, self).__init__(config)
        self.bert = BertModel(config)
        self.dropout = nn.Dropout(hidden_dropout_prob)
        self.apply(self.init_bert_weights)
        out_dim = hidden_size
        self.need_birnn = need_birnn
        print("need_birnn=", need_birnn)

        # If need_birnn is False, skip the BiLSTM layer and feed the BERT
        # output to the classifier directly.
        if need_birnn:
            self.bilstm = nn.LSTM(input_size=hidden_size,
                                  hidden_size=rnn_dim // 2,
                                  num_layers=2,
                                  bidirectional=True,
                                  batch_first=True)
            # the bidirectional LSTM outputs rnn_dim features per token
            out_dim = rnn_dim
        self.classifier = nn.Linear(out_dim, num_tag)
        self.crf = CRF(num_tag)

    def forward(self,
                input_ids,
                token_type_ids,
                attention_mask,
                label_id=None,
                output_all_encoded_layers=False):
        # BERT encoder output; with output_all_encoded_layers=False this is
        # the last layer's hidden states, shape (batch, seq_len, hidden_size)
        sequence_output, _ = self.bert(
            input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
            output_all_encoded_layers=output_all_encoded_layers)
        if self.need_birnn:
            sequence_output, _ = self.bilstm(sequence_output)
        sequence_output = self.dropout(sequence_output)
        output = self.classifier(sequence_output)
        return output

    def loss_fn(self, bert_encode, output_mask, tags):
        loss = self.crf.negative_log_loss(bert_encode, output_mask, tags)
        return loss

    def predict(self, bert_encode, output_mask):
        predicts = self.crf.get_batch_best_path(bert_encode, output_mask)
        # flatten the (batch, seq_len) best paths and drop the -1 padding
        predicts = predicts.view(1, -1).squeeze()
        predicts = predicts[predicts != -1]
        return predicts

    def result(self, y_pred, y_true):
        y_pred = np.array(y_pred)
        y_true = np.array(y_true)
        f_score = f1_score(y_true, y_pred, average="weighted")
        acc = accuracy_score(y_true, y_pred)
        precision = precision_score(y_true, y_pred, average="weighted")
        recall = recall_score(y_true, y_pred, average="weighted")
        return acc, f_score, precision, recall

    def class_report(self, y_pred, y_true):
        y_true = y_true.numpy()
        y_pred = y_pred.numpy()
        classify_report = classification_report(y_true, y_pred)
        print('\n\nclassify_report:\n', classify_report)
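

# ---------------------------------------------------------------------------
# The CRF layer used by Bert_CRF is not shown in this example.  The class
# below is only a sketch of the interface it relies on
# (negative_log_loss / get_batch_best_path), built on the third-party
# pytorch-crf package; the repository's own CRF implementation is likely
# different in detail, so treat this purely as an illustration.
# ---------------------------------------------------------------------------
import torch
from torchcrf import CRF as TorchCRF


class SimpleCRF(nn.Module):
    def __init__(self, num_tag):
        super(SimpleCRF, self).__init__()
        self.crf = TorchCRF(num_tag, batch_first=True)

    def negative_log_loss(self, emissions, mask, tags):
        # pytorch-crf returns the log likelihood; negate it to get a loss
        return -self.crf(emissions, tags, mask=mask.bool(), reduction='mean')

    def get_batch_best_path(self, emissions, mask):
        # Viterbi decoding; pad the variable-length paths with -1 so callers
        # can strip the padding afterwards, as Bert_CRF.predict() does
        paths = self.crf.decode(emissions, mask=mask.bool())
        max_len = emissions.size(1)
        padded = [path + [-1] * (max_len - len(path)) for path in paths]
        return torch.tensor(padded, dtype=torch.long)


# Quick check of the sketch with random emissions (3 tags, batch of 2, len 5):
#   emissions = torch.randn(2, 5, 3)
#   tags = torch.randint(0, 3, (2, 5))
#   mask = torch.ones(2, 5, dtype=torch.long)
#   crf = SimpleCRF(3)
#   loss = crf.negative_log_loss(emissions, mask, tags)
#   best = crf.get_batch_best_path(emissions, mask)   # (2, 5) LongTensor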
Example #2
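
# NOTE: the helpers referenced in this example (positional_encoding,
# multihead_attention, feedforward), the CRF layer and the hyperparameter
# namespace hp are defined elsewhere in the source repository.  hp is assumed
# to carry at least the following attributes (meanings inferred from usage,
# values not shown here):
#   hp.num_blocks    - number of transformer encoder blocks
#   hp.hidden_units  - model width used by the attention / feed-forward layers
#   hp.num_heads     - number of attention heads
#   hp.dropout_rate  - dropout probability inside the attention blocks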
class NER_Model(nn.Module):
    def __init__(self, bert_path, bert_dim, n_class, drop_p, num_pre):
        super(NER_Model, self).__init__()

        self.bert_model = BertModel.from_pretrained(bert_path)
        self.fc = nn.Linear(bert_dim * 2, n_class)
        self.dropout = nn.Dropout(drop_p)
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')

        # pre embedding
        self.pre_dim = bert_dim
        self.pre_embedding = nn.Embedding(num_pre, self.pre_dim)
        self.pre_embedding.weight.data.copy_(
            torch.from_numpy(
                self.random_embedding_label(num_pre, self.pre_dim, 0.025)))

        # transformer
        self.enc_positional_encoding = positional_encoding(768,
                                                           zeros_pad=True,
                                                           scale=True)
        for i in range(hp.num_blocks):
            self.__setattr__(
                'enc_self_attention_%d' % i,
                multihead_attention(num_units=hp.hidden_units,
                                    num_heads=hp.num_heads,
                                    dropout_rate=hp.dropout_rate,
                                    causality=False))
            self.__setattr__(
                'enc_feed_forward_%d' % i,
                feedforward(hp.hidden_units,
                            [4 * hp.hidden_units, hp.hidden_units]))

        # crf
        self.crf = CRF(n_class, use_cuda=torch.cuda.is_available())

    def random_embedding_label(self, vocab_size, embedding_dim, scale):
        # uniform random initialisation of the pre-embedding table in [-scale, scale]
        pretrain_emb = np.empty([vocab_size, embedding_dim])
        # scale = np.sqrt(3.0 / embedding_dim)
        # scale = 0.025
        for index in range(vocab_size):
            pretrain_emb[index, :] = np.random.uniform(-scale, scale,
                                                       [1, embedding_dim])
        return pretrain_emb

    def encoder(self, embed, device):
        # Dropout
        self.enc = self.dropout(embed)
        # Blocks
        for i in range(hp.num_blocks):
            # Self-attention: q, k and v are all the same tensor (self.enc),
            # which is why this is called self-attention
            self.enc = self.__getattr__('enc_self_attention_%d' % i)(self.enc,
                                                                     self.enc,
                                                                     self.enc,
                                                                     device)
            # Position-wise feed-forward, output shape (b, l, 768)
            self.enc = self.__getattr__('enc_feed_forward_%d' % i)(self.enc)
        return self.enc

    def forward(self, x, seg_id, p, mask, device):
        self.input_ids = x
        x_mask = (x != 0).int()
        out = self.bert_model(x, token_type_ids=seg_id,
                              attention_mask=x_mask)[0]
        # Positional Encoding (b, l, 768)
        pos = self.enc_positional_encoding(self.input_ids)
        out += pos.to(self.device)
        # word embedding + pre_embedding: concatenate along the feature axis
        pre_embedding = self.pre_embedding(p)
        # zero the pre-embeddings at masked positions
        pre_embedding = pre_embedding * mask.unsqueeze(dim=2).expand(
            mask.size(0), mask.size(1), self.pre_dim)
        out = torch.cat((out, pre_embedding), dim=2)
        # transformer output
        encoder_out = self.encoder(out, device)
        out = self.fc(encoder_out)

        return out

    def loss_fn(self, transformer_encode, output_mask, tags):
        loss = self.crf.negative_log_loss(transformer_encode, output_mask,
                                          tags)
        return loss.cpu()

    def predict(self, transformer_encode, output_mask):
        predicts = self.crf.get_batch_best_path(transformer_encode,
                                                output_mask)  # (b,l)
        return predicts
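

# ---------------------------------------------------------------------------
# The positional_encoding helper used by NER_Model is not shown in this
# example.  The module below is only a sketch of that idea, assuming it is the
# sinusoidal encoding from "Attention Is All You Need": it maps a batch of
# token ids (b, l) to position embeddings (b, l, num_units), zeroing the
# encoding at padding positions (token id 0) when zeros_pad=True and scaling
# by sqrt(num_units) when scale=True.  The repository's own implementation may
# differ in detail.
# ---------------------------------------------------------------------------
import torch
import torch.nn as nn


class SinusoidalPositionalEncoding(nn.Module):
    def __init__(self, num_units, zeros_pad=True, scale=True, max_len=512):
        super(SinusoidalPositionalEncoding, self).__init__()
        self.num_units = num_units
        self.zeros_pad = zeros_pad
        self.scale = scale
        position = torch.arange(max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.pow(
            10000.0,
            torch.arange(0, num_units, 2, dtype=torch.float) / num_units)
        pe = torch.zeros(max_len, num_units)
        pe[:, 0::2] = torch.sin(position / div_term)  # even feature indices
        pe[:, 1::2] = torch.cos(position / div_term)  # odd feature indices
        self.register_buffer('pe', pe)

    def forward(self, input_ids):
        b, l = input_ids.size()
        pos = self.pe[:l].unsqueeze(0).expand(b, l, self.num_units)
        if self.zeros_pad:
            # zero the encoding wherever the input is the padding id (0)
            pos = pos * (input_ids != 0).unsqueeze(-1).float()
        if self.scale:
            pos = pos * self.num_units ** 0.5
        return pos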