def testone(id=1):  ## run with default settings
    """Train and predict one default run of the 'mcls' model.

    Args:
        id: run identifier appended to the checkpoint directory
            (default 1). NOTE(fix): the original body referenced the
            *builtin* ``id`` function — ``str(id)`` produced
            ``"<built-in function id>"`` in the checkpoint path; it is
            now a proper parameter, matching the sibling ``testone``
            overload below.
    """
    # get_data()
    args = parser.parse_args()
    # Seed both NumPy and the stdlib RNG for reproducibility.
    np.random.seed(args.random_seed)
    random.seed(args.random_seed)
    args.do_model = 'mcls'
    schema_labels, predict_data, predict_sents = process_data(args)
    shiyan = """
    mcls
        """
    write_title('./work/log/' + args.do_model + '.txt', args, shiyan)
    args.checkpoint_dir = 'models/' + args.do_model + str(id)
    one(args, schema_labels, predict_data, predict_sents, str(id))
# Esempio n. 2  (scrape artifact — commented out to keep the file importable)
# 0
def testone(id=1, do_model="mrc_relation"):  ## run with default settings
    """Run one training/prediction pass of *do_model*.

    Either delegates to k-fold cross validation (when
    ``args.use_cross_validation`` is set) or trains a single model and
    predicts with it.

    Args:
        id: run identifier appended to the checkpoint directory.
        do_model: model name to train; written into ``args.do_model``.
    """
    args = parser.parse_args()
    # Seed both NumPy and the stdlib RNG for reproducibility.
    np.random.seed(args.random_seed)
    random.seed(args.random_seed)
    args.do_model = do_model
    shiyan = """
    mrc_relation
        """
    log_path = './work/log/' + args.do_model + '.txt'
    write_title(log_path, args, shiyan)
    if args.use_cross_validation:
        id = cross_validation(args, id)
        return
    # Single run: prepare data, then train + predict into a per-run dir.
    schema_labels, predict_data, predict_sents = process_data(args, 4)
    args.checkpoint_dir = 'models/' + args.do_model + str(id)
    one(args, schema_labels, predict_data, predict_sents, str(id))
# Esempio n. 3  (scrape artifact — commented out to keep the file importable)
# 0
def bzsearch():
    """Grid-search the trigger batch size over {32, 16, 8}.

    Each candidate gets its own run id (4, 5, 6) and checkpoint
    directory; ``one`` trains and predicts for that configuration.
    """
    args = parser.parse_args()
    schema_labels, predict_data, predict_sents = process_data(args)
    shiyan = """
######################################################################################################################################
                                trigger_batch_size gridsearch
######################################################################################################################################
    """
    write_title('./work/log/' + args.do_model + '.txt', args, shiyan)
    first_id = 4  # formerly: str(datetime.now().strftime('%m%d%H%M'))
    print(first_id)
    for run_id, bz in enumerate([32, 16, 8], start=first_id):
        args.batch_size = bz
        args.checkpoint_dir = 'models/' + args.do_model + str(run_id)
        one(args, schema_labels, predict_data, predict_sents, str(run_id))
# Esempio n. 4  (scrape artifact — commented out to keep the file importable)
# 0
def lrsearch():
    """Grid-search the trigger learning rate over {3e-5, 1e-5, 1e-4}.

    Each candidate gets its own run id (1, 2, 3) and checkpoint
    directory; ``one`` trains and predicts for that configuration.
    """
    args = parser.parse_args()
    schema_labels, predict_data, predict_sents = process_data(args)
    shiyan = """
######################################################################################################################################
                                trigger_lrgridsearch
######################################################################################################################################
    """
    write_title('./work/log/' + args.do_model + '.txt', args, shiyan)
    first_id = 1  # formerly: str(datetime.now().strftime('%m%d%H%M'))
    print(first_id)
    for run_id, lr in enumerate([3e-5, 1e-5, 1e-4], start=first_id):
        args.learning_rate = lr
        args.checkpoint_dir = 'models/' + args.do_model + str(run_id)
        one(args, schema_labels, predict_data, predict_sents, str(run_id))
# Esempio n. 5  (scrape artifact — commented out to keep the file importable)
# 0
def my():
    """Train the 'role' model once (run id 5), then predict from its checkpoint."""
    args = parser.parse_args()
    args.do_model = 'role'
    # Entity label set comes from the data directory, not from process_data.
    schema_labels = read_label('{}/entity2id.txt'.format(args.data_dir))

    shiyan = """
######################################################################################################################################
                                202,不复制,不考虑重叠
######################################################################################################################################
    """
    write_title('./work/log/' + args.do_model + '.txt', args, shiyan)
    run_id = 5  # formerly: str(datetime.now().strftime('%m%d%H%M'))

    args.checkpoint_dir = 'models/' + args.do_model + str(run_id)
    one(args, schema_labels, str(run_id))
    # Prediction needs the processed data splits as well.
    schema_labels, predict_data, predict_sents = process_data(args)
    predict_by_model_path(args, args.checkpoint_dir, schema_labels, predict_data, predict_sents, run_id)
def lrepochsearch():
    """Grid-search learning rate x epoch count for the 'role' model.

    For each (lr, epoch) pair, trains one model into its own checkpoint
    directory and immediately predicts from that directory. Run ids
    start at 6 and increment once per (lr, epoch) combination.
    """
    args = parser.parse_args()
    args.do_model = 'role'
    schema_labels = read_label('{}/entity2id.txt'.format(args.data_dir))

    shiyan = """
######################################################################################################################################
                                202,不复制,不考虑重叠,lrepochsearch
######################################################################################################################################
    """
    write_title('./work/log/' + args.do_model + '.txt', args, shiyan)
    id = 6  # str(datetime.now().strftime('%m%d%H%M'))
    for lr in [5e-5, 7e-5, 1e-4, 3e-4]:  # [3e-5,1e-5,5e-6,1e-6]
        args.learning_rate = lr
        for epoch in range(1, 4):
            args.num_epoch = epoch
            # FIX: set the checkpoint dir per run. It was assigned once
            # per lr with the pre-loop id, so epochs 2-3 trained into a
            # stale-id directory while one()/predict_by_model_path were
            # handed the freshly incremented id. This also matches how
            # lrsearch/bzsearch name one directory per run.
            args.checkpoint_dir = 'models/' + args.do_model + str(id)
            one(args, schema_labels, str(id))
            predict_by_model_path(args, args.checkpoint_dir, id)
            id += 1
# Esempio n. 7  (scrape artifact — commented out to keep the file importable)
# 0
    c.to_csv('./work/data.csv', header=None, index=False, sep='\t')


if __name__ == "__main__":
    # Entry point: a single default training run of the 'role' model.
    # get_data()
    args = parser.parse_args()
    # Seed both NumPy and the stdlib RNG for reproducibility.
    np.random.seed(args.random_seed)
    random.seed(args.random_seed)
    args.do_model = 'role'
    schema_labels, predict_data, predict_sents = process_data(args)
    shiyan = """
    ######################################################################################################################################
                                    trigger_batch_size lr gridsearch
    ######################################################################################################################################
        """
    log_path = './work/log/' + args.do_model + '.txt'
    write_title(log_path, args, shiyan)
    run_id = 1
    args.checkpoint_dir = 'models/' + args.do_model + str(run_id)
    one(args, schema_labels, predict_data, predict_sents, str(run_id))

    # args = parser.parse_args()

    # args.do_model = 'role'
    # schema_labels, predict_data, predict_sents = process_data(args)
    # # # 创建一个 LogWriter 对象 log_writer
    # # log_writer = LogWriter("./log", sync_cycle=10)
    #
    # id = str(datetime.now().strftime('%m%d%H%M%S'))
    # print(id)
    #
    # args.checkpoint_dir = 'models/trigger' + str(id)