Code Example #1
import os

from configs.event_config import event_config
# EventRolePrepareMRC is provided by this project's data-loader module;
# adjust the import to match the repository layout.


def gen_role_class_data():
    """
    generate role mrc data for verify_neg_fold_data_{}
    """
    # bert vocab file path
    vocab_file_path = os.path.join(
        event_config.get("bert_pretrained_model_path"),
        event_config.get("vocab_file"))
    # event role slot list file path
    slot_file = os.path.join(
        event_config.get("slot_list_root_path"),
        event_config.get("bert_slot_complete_file_name_role"))
    # schema file path
    schema_file = os.path.join(event_config.get("data_dir"),
                               event_config.get("event_schema"))
    # query map file path
    query_file = os.path.join(event_config.get("slot_list_root_path"),
                              event_config.get("query_map_file"))
    data_loader = EventRolePrepareMRC(vocab_file_path, 512, slot_file,
                                      schema_file, query_file)
    train_file = os.path.join(event_config.get("data_dir"),
                              event_config.get("event_data_file_train"))
    eval_file = os.path.join(event_config.get("data_dir"),
                             event_config.get("event_data_file_eval"))
    # split the data into k folds and write the arrays to disk
    data_loader.k_fold_split_data(train_file, eval_file, True)
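The arrays written by k_fold_split_data are plain NumPy dumps that the training code later reads back with np.load; Code Example #4 below loads the full set. A minimal sketch for inspecting one fold (directory and file names taken from Code Example #4; fold_index is illustrative):

import numpy as np

fold_index = 0
train_token_ids = np.load(
    "data/verify_neg_fold_data_{}/token_ids_train.npy".format(fold_index),
    allow_pickle=True)
print(len(train_token_ids))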
Code Example #2
def init_data_loader(self, config, query_map_file):
    vocab_file_path = os.path.join(
        config.get("bert_pretrained_model_path"), config.get("vocab_file"))
    # note: the slot and schema paths are read from the module-level
    # event_config rather than the `config` argument
    slot_file = os.path.join(
        event_config.get("slot_list_root_path"),
        event_config.get("bert_slot_complete_file_name_role"))
    schema_file = os.path.join(event_config.get("data_dir"),
                               event_config.get("event_schema"))
    # query_map_file = os.path.join(event_config.get(
    #         "slot_list_root_path"), event_config.get("query_map_file"))
    data_loader = EventRolePrepareMRC(vocab_file_path, 512, slot_file,
                                      schema_file, query_map_file)
    return data_loader
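A hedged usage sketch: init_data_loader is an instance method, so the call below assumes some object that defines it (named trainer here purely for illustration), with event_config importable as in Code Example #3:

import os

from configs.event_config import event_config

query_map_file = os.path.join(event_config.get("slot_list_root_path"),
                              event_config.get("query_map_file"))
# `trainer` stands in for whatever class instance defines init_data_loader
data_loader = trainer.init_data_loader(event_config, query_map_file)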
Code Example #3
import os

from configs.event_config import event_config

if __name__ == "__main__":
    vocab_file_path = os.path.join(event_config.get("bert_pretrained_model_path"), event_config.get("vocab_file"))
    # bert_config_file = os.path.join(event_config.get("bert_pretrained_model_path"), event_config.get("bert_config_path"))
    event_type_file = os.path.join(event_config.get("slot_list_root_path"), event_config.get("event_type_file"))
    # data_loader =EventTypeClassificationPrepare(vocab_file_path,512,event_type_file)
    # train_file = os.path.join(event_config.get("data_dir"),event_config.get("event_data_file_train"))
    # eval_file = os.path.join(event_config.get("data_dir"),event_config.get("event_data_file_eval"))
    # train_data_list,train_label_list,train_token_type_id_list,dev_data_list,dev_label_list,dev_token_type_id_list = data_loader._read_json_file(train_file,eval_file,is_train=True)
    slot_file = os.path.join(event_config.get("slot_list_root_path"),
                             event_config.get("bert_slot_complete_file_name_role"))
    schema_file = os.path.join(event_config.get("data_dir"), event_config.get("event_schema"))
    query_map_file = os.path.join(event_config.get("slot_list_root_path"), event_config.get("query_map_file"))

    data_loader = EventRolePrepareMRC(vocab_file_path, 512, slot_file, schema_file, query_map_file)
    train_file = os.path.join(event_config.get("data_dir"), event_config.get("event_data_file_train"))
    eval_file = os.path.join(event_config.get("data_dir"), event_config.get("event_data_file_eval"))
    # data_list,label_start_list,label_end_list,query_len_list,token_type_id_list
    # train_datas, train_labels_start,train_labels_end,train_query_lens,train_token_type_id_list,dev_datas, dev_labels_start,dev_labels_end,dev_query_lens,dev_token_type_id_list = data_loader._read_json_file(train_file,eval_file,True)
    # dev_datas, dev_labels_start,dev_labels_end,dev_query_lens,dev_token_type_id_list = data_loader._read_json_file(eval_file,None,False)
    # train_datas, train_labels_start,train_labels_end,train_query_lens,train_token_type_id_list,dev_datas, dev_labels_start,dev_labels_end,dev_query_lens,dev_token_type_id_list = data_loader._merge_ee_and_re_datas(train_file,eval_file,"relation_extraction/data/train_data.json","relation_extraction/data/dev_data.json")
    data_loader.k_fold_split_data(train_file, eval_file, True)
    # import numpy as np
    # train_query_lens = np.load("data/fold_data_{}/query_lens_train.npy".format(0),allow_pickle=True)
    # print(train_query_lens[0])

    # re_train_file = "relation_extraction/data/train_data.json"
    # re_dev_file = "relation_extraction/data/dev_data.json"
    # # data_loader.k_fold_split_data(train_file,eval_file,re_train_file,re_dev_file,True,6)
    # train_datas = np.load("data/re15000_neg_fold_data_{}/token_ids_train.npy".format(0),allow_pickle=True)
Code Example #4
import os

import numpy as np
import tensorflow as tf  # TF 1.x API (tf.estimator / tf.ConfigProto)

from configs.event_config import event_config
# event_verify_mrc_model_fn_builder, event_input_verfify_mrc_fn,
# bert_mrc_serving_input_receiver_fn and logger come from this project's
# model and logging modules.


def run_event_verify_role_mrc(args):
    """
    retro reader 第二阶段的精度模块,同时训练两个任务,role抽取和问题是否可以回答
    :param args:
    :return:
    """
    model_base_dir = event_config.get(args.model_checkpoint_dir).format(
        args.fold_index)
    pb_model_dir = event_config.get(args.model_pb_dir).format(args.fold_index)
    vocab_file_path = os.path.join(
        event_config.get("bert_pretrained_model_path"),
        event_config.get("vocab_file"))
    bert_config_file = os.path.join(
        event_config.get("bert_pretrained_model_path"),
        event_config.get("bert_config_path"))
    slot_file = os.path.join(
        event_config.get("slot_list_root_path"),
        event_config.get("bert_slot_complete_file_name_role"))
    schema_file = os.path.join(event_config.get("data_dir"),
                               event_config.get("event_schema"))
    query_map_file = os.path.join(event_config.get("slot_list_root_path"),
                                  event_config.get("query_map_file"))
    data_loader = EventRolePrepareMRC(vocab_file_path, 512, slot_file,
                                      schema_file, query_map_file)
    # train_file = os.path.join(event_config.get("data_dir"), event_config.get("event_data_file_train"))
    # eval_file = os.path.join(event_config.get("data_dir"), event_config.get("event_data_file_eval"))
    # data_list,label_start_list,label_end_list,query_len_list,token_type_id_list
    # train_datas, train_labels_start,train_labels_end,train_query_lens,train_token_type_id_list,dev_datas, dev_labels_start,dev_labels_end,dev_query_lens,dev_token_type_id_list = data_loader._read_json_file(train_file,eval_file,True)
    # dev_datas, dev_labels_start,dev_labels_end,dev_query_lens,dev_token_type_id_list = data_loader._read_json_file(eval_file,None,False)
    # train_datas, train_labels_start,train_labels_end,train_query_lens,train_token_type_id_list,dev_datas, dev_labels_start,dev_labels_end,dev_query_lens,dev_token_type_id_list = data_loader._merge_ee_and_re_datas(train_file,eval_file,"relation_extraction/data/train_data.json","relation_extraction/data/dev_data.json")
    train_has_answer_label_list = []
    dev_has_answer_label_list = []
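    # the arrays below were precomputed by k_fold_split_data (Code Example
    # #1); the data_loader built above is not used to regenerate them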
    train_datas = np.load(
        "data/verify_neg_fold_data_{}/token_ids_train.npy".format(
            args.fold_index),
        allow_pickle=True)
    # train_has_answer_label_list = np.load("data/verify_neg_fold_data_{}/has_answer_train.npy".format(args.fold_index),allow_pickle=True)
    train_token_type_id_list = np.load(
        "data/verify_neg_fold_data_{}/token_type_ids_train.npy".format(
            args.fold_index),
        allow_pickle=True)
    dev_datas = np.load(
        "data/verify_neg_fold_data_{}/token_ids_dev.npy".format(
            args.fold_index),
        allow_pickle=True)
    # dev_has_answer_label_list = np.load("data/verify_neg_fold_data_{}/has_answer_dev.npy".format(args.fold_index),allow_pickle=True)
    dev_token_type_id_list = np.load(
        "data/verify_neg_fold_data_{}/token_type_ids_dev.npy".format(
            args.fold_index),
        allow_pickle=True)
    train_query_lens = np.load(
        "data/verify_neg_fold_data_{}/query_lens_train.npy".format(
            args.fold_index),
        allow_pickle=True)
    dev_query_lens = np.load(
        "data/verify_neg_fold_data_{}/query_lens_dev.npy".format(
            args.fold_index),
        allow_pickle=True)
    train_start_labels = np.load(
        "data/verify_neg_fold_data_{}/labels_start_train.npy".format(
            args.fold_index),
        allow_pickle=True)
    dev_start_labels = np.load(
        "data/verify_neg_fold_data_{}/labels_start_dev.npy".format(
            args.fold_index),
        allow_pickle=True)
    train_end_labels = np.load(
        "data/verify_neg_fold_data_{}/labels_end_train.npy".format(
            args.fold_index),
        allow_pickle=True)
    dev_end_labels = np.load(
        "data/verify_neg_fold_data_{}/labels_end_dev.npy".format(
            args.fold_index),
        allow_pickle=True)
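    # derive binary answerability labels: a sample whose start-label
    # vector is all zeros contains no answer span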
    train_samples_nums = len(train_datas)
    for i in range(train_samples_nums):
        if sum(train_start_labels[i]) == 0:
            train_has_answer_label_list.append(0)
        else:
            train_has_answer_label_list.append(1)

    train_has_answer_label_list = np.array(
        train_has_answer_label_list).reshape((train_samples_nums, 1))
    dev_samples_nums = len(dev_datas)
    for i in range(dev_samples_nums):
        if sum(dev_start_labels[i]) == 0:
            dev_has_answer_label_list.append(0)
        else:
            dev_has_answer_label_list.append(1)
    dev_has_answer_label_list = np.array(dev_has_answer_label_list).reshape(
        (dev_samples_nums, 1))

    if train_samples_nums % args.train_batch_size != 0:
        each_epoch_steps = train_samples_nums // args.train_batch_size + 1
    else:
        each_epoch_steps = train_samples_nums // args.train_batch_size
    # each_epoch_steps = int(data_loader.train_samples_nums/args.train_batch_size)+1
    logger.info('*****train_set sample nums:{}'.format(train_samples_nums))
    logger.info('*****dev_set sample nums:{}'.format(dev_samples_nums))
    logger.info('*****train each epoch steps:{}'.format(each_epoch_steps))
    train_steps_nums = each_epoch_steps * args.epochs
    # train_steps_nums = each_epoch_steps * args.epochs // hvd.size()
    logger.info('*****train_total_steps:{}'.format(train_steps_nums))
    decay_steps = args.decay_epoch * each_epoch_steps
    logger.info('*****train decay steps:{}'.format(decay_steps))
    # dropout_prob is the dropout probability
    params = {
        "dropout_prob": args.dropout_prob,
        "num_labels": 2,
        "rnn_size": args.rnn_units,
        "num_layers": args.num_layers,
        "hidden_units": args.hidden_units,
        "decay_steps": decay_steps,
        "train_steps": train_steps_nums,
        "num_warmup_steps": int(train_steps_nums * 0.1)
    }
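    # note: num_warmup_steps above follows the common BERT recipe of
    # warming up over the first 10% of training steps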
    # dist_strategy = tf.contrib.distribute.MirroredStrategy(num_gpus=args.gpu_nums)
    config_tf = tf.ConfigProto()
    config_tf.gpu_options.allow_growth = True
    run_config = tf.estimator.RunConfig(
        model_dir=model_base_dir,
        save_summary_steps=each_epoch_steps,
        save_checkpoints_steps=each_epoch_steps,
        session_config=config_tf,
        keep_checkpoint_max=3,
        # train_distribute=dist_strategy
    )
    bert_init_checkpoints = os.path.join(
        event_config.get("bert_pretrained_model_path"),
        event_config.get("bert_init_checkpoints"))
    # init_checkpoints = "output/model/merge_usingtype_roberta_traindev_event_role_bert_mrc_model_desmodified_lowercase/checkpoint/model.ckpt-1218868"
    model_fn = event_verify_mrc_model_fn_builder(bert_config_file,
                                                 bert_init_checkpoints, args)
    estimator = tf.estimator.Estimator(model_fn,
                                       params=params,
                                       config=run_config)
    if args.do_train:
        train_input_fn = lambda: event_input_verfify_mrc_fn(
            train_datas,
            train_start_labels,
            train_end_labels,
            train_token_type_id_list,
            train_query_lens,
            train_has_answer_label_list,
            is_training=True,
            is_testing=False,
            args=args)
        eval_input_fn = lambda: event_input_verfify_mrc_fn(
            dev_datas,
            dev_start_labels,
            dev_end_labels,
            dev_token_type_id_list,
            dev_query_lens,
            dev_has_answer_label_list,
            is_training=False,
            is_testing=False,
            args=args)
        train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn,
                                            max_steps=train_steps_nums)
        # keep only the single best SavedModel export
        exporter = tf.estimator.BestExporter(
            exports_to_keep=1,
            serving_input_receiver_fn=bert_mrc_serving_input_receiver_fn)
        # throttle_secs=0: evaluate each time a new checkpoint is written
        eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_fn,
                                          exporters=[exporter],
                                          throttle_secs=0)
        # for _ in range(args.epochs):

        tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
        # "bert_ce_model_pb"
        estimator.export_saved_model(pb_model_dir,
                                     bert_mrc_serving_input_receiver_fn)
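A minimal sketch of driving this function, assuming an argparse namespace with the attributes the function reads (fold_index, model_checkpoint_dir, model_pb_dir, train_batch_size, epochs, decay_epoch, dropout_prob, rnn_units, num_layers, hidden_units, do_train); all flag names and default values below are illustrative, not the project's actual CLI:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--fold_index", type=int, default=0)
# these two are keys looked up in event_config (assumed values)
parser.add_argument("--model_checkpoint_dir", default="role_verify_mrc_model_dir")
parser.add_argument("--model_pb_dir", default="role_verify_mrc_pb_dir")
parser.add_argument("--train_batch_size", type=int, default=16)
parser.add_argument("--epochs", type=int, default=5)
parser.add_argument("--decay_epoch", type=int, default=3)
parser.add_argument("--dropout_prob", type=float, default=0.1)
parser.add_argument("--rnn_units", type=int, default=128)
parser.add_argument("--num_layers", type=int, default=1)
parser.add_argument("--hidden_units", type=int, default=128)
parser.add_argument("--do_train", action="store_true")
args = parser.parse_args()
run_event_verify_role_mrc(args)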