Example #1
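All of the snippets below assume roughly the same preamble. CRF, RNN, and get_embedding are project-local helpers, and args/config are module-level settings objects; none of them are shown here. A sketch of the shared imports (module paths follow the old pytorch-pretrained-bert package and may differ by version):

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import classification_report, f1_score
from pytorch_pretrained_bert.modeling import (BertModel,
                                              BertForTokenClassification,
                                              BertPreTrainedModel)
from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE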
class BILSTM_CRF(nn.Module):
    def __init__(self, vocab_size, word_embedding_dim, word2id, hidden_size,
                 bi_flag, num_layer, input_size, cell_type, dropout, num_tag,
                 checkpoint_dir):
        super(BILSTM_CRF, self).__init__()

        # Pretrained word vectors, kept frozen during training.
        self.embedding = nn.Embedding(vocab_size, word_embedding_dim)
        for p in self.embedding.parameters():
            p.requires_grad = False
        self.embedding.weight.data.copy_(
            torch.from_numpy(
                get_embedding(vocab_size, word_embedding_dim, word2id)))

        self.rnn = RNN(hidden_size, bi_flag, num_layer, input_size, cell_type,
                       dropout, num_tag)

        self.crf = CRF(num_tag=num_tag)

        self.checkpoint_dir = checkpoint_dir

    def forward(self, inputs, length):
        embeddings = self.embedding(inputs)
        rnn_output = self.rnn(embeddings,
                              length)  # (batch_size, time_steps, num_tag+2)
        return rnn_output

    def loss_fn(self, rnn_output, labels, length):
        # CRF negative log-likelihood of the gold tag sequences.
        loss = self.crf.negative_log_loss(inputs=rnn_output,
                                          length=length,
                                          tags=labels)
        return loss

    def predict(self, rnn_output, length):
        best_path = self.crf.get_batch_best_path(rnn_output, length)  # Viterbi decoding
        return best_path

    def load(self):
        self.load_state_dict(torch.load(self.checkpoint_dir))

    def save(self):
        torch.save(self.state_dict(), self.checkpoint_dir)

    def evaluate(self, y_pred, y_true):
        y_true = y_true.cpu().numpy()
        y_pred = y_pred.cpu().numpy()  # .cpu() in case predictions live on the GPU
        # `config` refers to a project-level settings module holding the label list.
        f1 = f1_score(y_true, y_pred, labels=config.labels, average="macro")
        correct = np.sum((y_true == y_pred).astype(int))
        acc = correct / y_pred.shape[0]
        return acc, f1

    def class_report(self, y_pred, y_true):
        y_true = y_true.cpu().numpy()
        y_pred = y_pred.cpu().numpy()
        classify_report = classification_report(y_true, y_pred)
        print('\n\nclassify_report:\n', classify_report)
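For orientation, a minimal training step with BILSTM_CRF might look like the sketch below. The constructor arguments, batch tensors (inputs, length, labels), and word2id mapping are assumptions for illustration, not taken from the original repository:

# Hypothetical usage sketch; all values below are made up.
model = BILSTM_CRF(vocab_size=20000, word_embedding_dim=300, word2id=word2id,
                   hidden_size=256, bi_flag=True, num_layer=1, input_size=300,
                   cell_type="LSTM", dropout=0.5, num_tag=9,
                   checkpoint_dir="bilstm_crf.pt")
# Skip the frozen embedding weights when building the optimizer.
optimizer = torch.optim.Adam(
    filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3)

rnn_output = model(inputs, length)               # (batch, time_steps, num_tag + 2)
loss = model.loss_fn(rnn_output, labels, length)
optimizer.zero_grad()
loss.backward()
optimizer.step()

best_path = model.predict(rnn_output, length)    # decoded tag ids per sequence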
class Bert_CRF(BertPreTrainedModel):

    def __init__(self, config, num_tag):
        super(Bert_CRF, self).__init__(config)
        self.bert = BertModel(config)
        # for p in self.bert.parameters():
        #     p.requires_grad = False
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, num_tag)
        self.apply(self.init_bert_weights)

        self.crf = CRF(num_tag)

    def forward(self,
                input_ids,
                token_type_ids,
                attention_mask,
                label_id=None,
                output_all_encoded_layers=False):
        bert_encode, _ = self.bert(input_ids,
                                   token_type_ids,
                                   attention_mask,
                                   output_all_encoded_layers=output_all_encoded_layers)
        output = self.classifier(bert_encode)  # (batch, seq_len, num_tag) emission scores for the CRF
        return output

    def loss_fn(self, bert_encode, output_mask, tags):
        loss = self.crf.negative_log_loss(bert_encode, output_mask, tags)
        return loss

    def predict(self, bert_encode, output_mask):
        predicts = self.crf.get_batch_best_path(bert_encode, output_mask)
        predicts = predicts.view(1, -1).squeeze()
        predicts = predicts[predicts != -1]  # drop the -1 padding values
        return predicts

    def acc_f1(self, y_pred, y_true):
        y_pred = y_pred.numpy()
        y_true = y_true.numpy()
        f1 = f1_score(y_true, y_pred, average="macro")
        correct = np.sum((y_true == y_pred).astype(int))
        acc = correct / y_pred.shape[0]
        return acc, f1

    def class_report(self, y_pred, y_true):
        y_true = y_true.numpy()
        y_pred = y_pred.numpy()
        classify_report = classification_report(y_true, y_pred)
        print('\n\nclassify_report:\n', classify_report)
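A rough sketch of how this subclassing variant is typically instantiated: BertPreTrainedModel.from_pretrained forwards extra keyword arguments such as num_tag to __init__, so construction and one evaluation step might look like the following (model name, tag count, and input tensors are assumptions):

# Hypothetical usage sketch; "bert-base-chinese" and num_tag=9 are made up.
model = Bert_CRF.from_pretrained("bert-base-chinese", num_tag=9)

emissions = model(input_ids, token_type_ids, attention_mask)  # (batch, seq_len, num_tag)
loss = model.loss_fn(emissions, output_mask, tags)
preds = model.predict(emissions, output_mask)                 # flat tensor of tag ids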
Example #6
class Bert_CRF(nn.Module):
    def __init__(self):
        super(Bert_CRF, self).__init__()
        self.bert = BertForTokenClassification.from_pretrained(
            args.bert_model,
            cache_dir=PYTORCH_PRETRAINED_BERT_CACHE,
            num_labels=len(args.labels))

        self.crf = CRF(len(args.labels))

    def forward(self,
                input_ids,
                token_type_ids,
                attention_mask,
                label_id=None,
                output_all_encoded_layers=False):
        # With labels omitted, BertForTokenClassification returns per-token logits;
        # pass attention_mask through so padding tokens are ignored.
        logit = self.bert(input_ids, token_type_ids, attention_mask)
        return logit

    def loss_fn(self, bert_encode, output_mask, tags):
        loss = self.crf.negative_log_loss(bert_encode, output_mask, tags)
        return loss

    def predict(self, bert_encode, output_mask):
        predicts = self.crf.get_batch_best_path(bert_encode, output_mask)
        # predicts = predicts.view(1, -1).squeeze()
        # predicts = predicts[predicts != -1]
        return predicts

    def acc_f1(self, y_pred, y_true):
        y_pred = y_pred.numpy()
        y_true = y_true.numpy()
        f1 = f1_score(y_true, y_pred, average="macro")
        correct = np.sum((y_true == y_pred).astype(int))
        acc = correct / y_pred.shape[0]
        return acc, f1

    def class_report(self, y_pred, y_true):
        y_true = y_true.numpy()
        y_pred = y_pred.numpy()
        classify_report = classification_report(y_true, y_pred)
        print('\n\nclassify_report:\n', classify_report)
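Unlike the subclassing variants, this example wraps a ready-made BertForTokenClassification head, so forward already returns per-token logits of shape (batch, seq_len, num_labels); those logits then serve as emission scores for the external CRF. A hypothetical call sequence (args.bert_model and args.labels come from the surrounding training script, which is not shown):

# Hypothetical usage sketch for this wrapper.
model = Bert_CRF()
logits = model(input_ids, token_type_ids, attention_mask)  # (batch, seq_len, num_labels)
loss = model.loss_fn(logits, output_mask, tags)
preds = model.predict(logits, output_mask)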
Example #7
class Bert_CRF(BertPreTrainedModel):
    def __init__(self, config, num_tag):
        super(Bert_CRF, self).__init__(config)
        self.bert = BertModel(config)
        if args.do_not_train_ernie:
            for p in self.bert.parameters():
                p.requires_grad = False
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, num_tag)
        self.apply(self.init_bert_weights)
        self.crf = CRF(num_tag)
        self.num_tag = num_tag

    def forward(self,
                input_ids,
                token_type_ids,
                attention_mask,
                label_id=None,
                output_all_encoded_layers=False):
        bert_encode, _ = self.bert(
            input_ids,
            token_type_ids,
            attention_mask,
            output_all_encoded_layers=output_all_encoded_layers)
        output = self.classifier(bert_encode)

        return output

    def loss_fn(self, bert_encode, output_mask, tags):
        if args.do_CRF:
            loss = self.crf.negative_log_loss(bert_encode, output_mask, tags)
        else:
            # Per-token cross-entropy over the unmasked positions of each sequence.
            loss = torch.tensor(0., requires_grad=True)  # Variable is deprecated; a plain tensor suffices
            loss_fct = nn.CrossEntropyLoss(ignore_index=0)
            for ix, (features, tag) in enumerate(zip(bert_encode, tags)):
                num_valid = torch.sum(output_mask[ix].detach())
                features = features[output_mask[ix] == 1]
                tag = tag[:num_valid]
                loss = loss + loss_fct(
                    features.view(-1, self.num_tag).cpu(),
                    tag.view(-1).cpu())
        return loss

    def predict(self, bert_encode, output_mask):
        if args.do_CRF:
            predicts = self.crf.get_batch_best_path(bert_encode, output_mask)
            if not args.do_inference:
                predicts = predicts.view(1, -1).squeeze()
                predicts = predicts[predicts != -1]
            else:
                predicts_ = []
                for ix, features in enumerate(predicts):
                    #features = features[output_mask[ix] == 1]
                    predict = features[features != -1]
                    predicts_.append(predict)
                predicts = predicts_
        else:
            predicts_ = []
            for ix, features in enumerate(bert_encode):
                features = features[output_mask[ix] == 1]
                predict = F.softmax(features, dim=1)
                predict = torch.argmax(predict, dim=1)
                predicts_.append(predict)
            if not args.do_inference:
                predicts = torch.cat(predicts_, 0)
            else:
                predicts = predicts_
        return predicts

    def acc_f1(self, y_pred, y_true):
        try:
            y_pred = y_pred.numpy()
            y_true = y_true.numpy()
        except AttributeError:
            pass  # inputs may already be numpy arrays
        f1 = f1_score(y_true, y_pred, average="macro")
        correct = np.sum((y_true == y_pred).astype(int))
        acc = correct / y_pred.shape[0]
        return acc, f1

    def class_report(self, y_pred, y_true):
        y_true = y_true.numpy()
        y_pred = y_pred.numpy()
        classify_report = classification_report(y_true, y_pred)
        print('\n\nclassify_report:\n', classify_report)
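Example #7 switches between CRF decoding and per-token softmax via args.do_CRF, between flattened evaluation output and per-sentence inference output via args.do_inference, and can freeze the encoder via args.do_not_train_ernie. A minimal sketch of the flags the class expects (names taken from the code above; the argparse wiring itself is an assumption):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--do_CRF", action="store_true",
                    help="decode with the CRF layer instead of per-token softmax")
parser.add_argument("--do_inference", action="store_true",
                    help="return one prediction tensor per sentence, not one flat tensor")
parser.add_argument("--do_not_train_ernie", action="store_true",
                    help="freeze the BERT/ERNIE encoder parameters")
args = parser.parse_args()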