Example #1
# Assumed imports, following the bert4keras NER example conventions:
from bert4keras.backend import K
from bert4keras.models import build_transformer_model
from bert4keras.layers import ConditionalRandomField
from bert4keras.optimizers import Adam
from keras.layers import Dense
from keras.models import Model


def train():
    # First dimension: number of sentences across all training samples;
    # second dimension: number of (entity, label) pairs in each sentence.
    train_data = loader.load_data('./round1_train/data/train.txt')
    valid_data = loader.load_data('./round1_train/data/val.txt')

    global train_generator
    train_generator = generator.Generator(train_data=train_data,
                                          batch_size=batch_size,
                                          tokenizer=tokenizer,
                                          maxlen=maxlen,
                                          label2id=loader.label2id)

    global model
    model = build_transformer_model(
        config_path,
        checkpoint_path,
    )  # build the transformer encoder from bert_config.json and bert_model.ckpt

    output_layer = 'Transformer-%s-FeedForward-Norm' % (bert_layers - 1)
    output = model.get_layer(output_layer).output  # shape=(None, None, 768)
    output = Dense(loader.num_labels)(output)  # 27 classes: 13 entity types * (B + I) + O

    global CRF
    CRF = ConditionalRandomField()  # instantiate the CRF layer before applying it
    output = CRF(output)

    model = Model(model.input, output)
    model.summary()

    model.compile(loss=CRF.sparse_loss,
                  optimizer=Adam(learning_rate),
                  metrics=[CRF.sparse_accuracy])

    NER = models.NamedEntityRecognizer(trans=K.eval(CRF.trans),
                                       starts=[0],
                                       ends=[0])
    evaluate = evaluator.Evaluator(valid_data, tokenizer, model, NER, CRF,
                                   loader)

    model.fit_generator(train_generator.forfit(),
                        steps_per_epoch=len(train_generator),
                        epochs=epochs,
                        callbacks=[evaluate])
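
train() relies on module-level names defined elsewhere in the project (loader, generator, tokenizer, and the hyperparameters referenced above). A minimal sketch of that setup, following bert4keras conventions; every path and value below is an illustrative assumption, not the project's actual configuration:

from bert4keras.tokenizers import Tokenizer

# Hypothetical hyperparameters and checkpoint paths (adjust to your setup).
maxlen = 256
epochs = 10
batch_size = 16
bert_layers = 12
learning_rate = 1e-5  # a typical fine-tuning rate for BERT-base

config_path = './chinese_L-12_H-768_A-12/bert_config.json'
checkpoint_path = './chinese_L-12_H-768_A-12/bert_model.ckpt'
dict_path = './chinese_L-12_H-768_A-12/vocab.txt'

tokenizer = Tokenizer(dict_path, do_lower_case=True)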
Example #2
# Assumed imports:
import codecs
import glob
import os

from bert4keras.backend import K


def verify():
    NER = models.NamedEntityRecognizer(trans=K.eval(CRF.trans),
                                       starts=[0],
                                       ends=[0])
    # Validation set
    X, Y, Z = 1e-10, 1e-10, 1e-10  # matched, predicted, and gold counts (epsilon avoids division by zero)
    data_dir = './round1_train/val_data/'
    val_data_flist = glob.glob(os.path.join(data_dir, '*.txt'))
    for file in val_data_flist:
        # The glob already restricts to .txt files, so no extension check is needed.
        file_name = os.path.splitext(os.path.basename(file))[0]
        r_ann_path = os.path.join(data_dir, "%s.ann" % file_name)
        r_txt_path = os.path.join(data_dir, "%s.txt" % file_name)

        R = []  # predicted entities
        with codecs.open(r_txt_path, "r", encoding="utf-8") as f:
            text_lines = f.readlines()
            predictions = predict_test(text_lines, NER)
            for pred in predictions[0]:
                R.append("%s %s %s\t%s" % (pred['label_type'], pred['start_pos'],
                                           pred['end_pos'], pred['res']))
        T = []  # gold entities
        with codecs.open(r_ann_path, "r", encoding="utf-8") as f:
            for line in f:
                fields = line.strip('\n').split('\t')
                T.append(fields[1] + '\t' + fields[2])  # "TYPE start end" + text
        R = set(R)
        T = set(T)
        X += len(R & T)  # exact matches
        Y += len(R)      # predicted
        Z += len(T)      # gold
    precision, recall = X / Y, X / Z
    f1 = 2 * precision * recall / (precision + recall)
    print('precision: %.5f, recall: %.5f, f1: %.5f' % (precision, recall, f1))
    return f1
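
The scoring above is a micro-averaged exact-match F1: a predicted entity counts only if its type, start, end, and text all coincide with a gold annotation. A self-contained sketch of the same computation on made-up data (the entity strings are illustrative):

def micro_f1(pred_sets, gold_sets):
    """Micro-averaged exact-match P/R/F1 over per-document entity sets."""
    tp = pred = gold = 1e-10  # epsilon avoids division by zero, as in verify()
    for R, T in zip(pred_sets, gold_sets):
        tp += len(R & T)
        pred += len(R)
        gold += len(T)
    precision, recall = tp / pred, tp / gold
    return 2 * precision * recall / (precision + recall)

# Toy check: one of two predictions matches the single gold entity exactly.
preds = [{"DRUG 0 2\t当归", "DRUG 5 7\t人参"}]
golds = [{"DRUG 0 2\t当归"}]
print(micro_f1(preds, golds))  # ~0.6667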
Example #3
# Assumed imports:
import codecs
import os

from bert4keras.backend import K


def tcm_test():
    # Test set
    NER = models.NamedEntityRecognizer(trans=K.eval(CRF.trans),
                                       starts=[0],
                                       ends=[0])

    test_files = os.listdir("./round1_test/chusai_xuanshou/")
    os.makedirs("./round1_test/submission_4/", exist_ok=True)  # make sure the output directory exists

    for file in test_files:
        with codecs.open("./round1_test/chusai_xuanshou/" + file,
                         "r",
                         encoding="utf-8") as f:
            text_lines = f.readlines()
            predictions = predict_test(text_lines, NER)
        out_path = "./round1_test/submission_4/" + file.split('.')[0] + ".ann"
        with codecs.open(out_path, "w", encoding="utf-8") as ff:
            # The with block closes the file; no explicit close() is needed.
            for pred in predictions[0]:
                ff.write("%s\t%s %s %s\t%s\n" %
                         (pred['overlap'], pred['label_type'],
                          pred['start_pos'], pred['end_pos'], pred['res']))
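
Each line written to the .ann file follows the brat standoff convention that verify() parses: a tab-separated entity id, "TYPE start end", and the mention text. Assuming pred['overlap'] carries the brat-style id, one emitted line would look like this (id, label, and offsets are illustrative):

T1	DRUG 0 2	当归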