示例#1
0
def main(args):
    """Train or evaluate a Seq2SeqAttNN model as directed by ``args``.

    Args:
        args: mapping of command-line options. The keys 'nottrain',
            'savemodel', 'modelpath' and 'epoch' are read directly; the
            whole mapping is also passed to ``load_conf`` and
            ``load_tt_datas``.

    In training mode the model is trained and evaluated and the elapsed
    time is printed. Otherwise the checkpoint at args['modelpath'] is
    restored, the model is tested, and the elapsed time is printed.
    """
    train_mode = not args['nottrain']
    save_model = args['savemodel']
    checkpoint = args['modelpath']
    n_epochs = args['epoch']

    config = load_conf(args)
    config['nepoch'] = n_epochs

    # Load the train/test splits.
    train_data, test_data = load_tt_datas(args, config)

    # Configure Randomer with the stddev taken from the config.
    Randomer.set_stddev(config['stddev'])

    with tf.Graph().as_default():
        net = Seq2SeqAttNN(config)
        net.build_model()

        # A Saver is only required when checkpoints are written (save
        # requested) or when one must be restored (evaluation mode).
        need_saver = save_model or not train_mode
        saver = tf.train.Saver(max_to_keep=30) if need_saver else None

        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            start = time.time()
            if not train_mode:
                saver.restore(sess, checkpoint)
                net.test(sess, test_data)
                print('Testing Time = ', time.time() - start)
            else:
                net.train(sess, train_data, test_data, saver)
                print('Training+Evaluation Time = ', time.time() - start)
示例#2
0
def main(options, modelconf="config/model.conf"):
    """Build the model named in ``options``, then train it or evaluate it.

    Args:
        options: parsed command-line options. Attributes read:
            model -- name of the model to look up in ``modelconf``;
            dataset -- dataset name (also selects the training accuracy
                threshold for "cikm16" / "rsc15");
            reload -- whether to reload the data ("yes" or "no" per the
                original docs); forwarded to ``load_tt_datas``;
            classnum -- number of classification classes;
            not_train -- when truthy, restore ``model_path`` and test
                instead of training;
            not_save_model -- when truthy (and training), no Saver is made;
            model_path -- checkpoint restored in evaluation mode;
            input_data -- "train" evaluates on the training split; any
                other value evaluates on the test split;
            epoch -- number of training epochs (stored as config['nepoch']).
        modelconf: path to the model configuration file.
    """
    # Local import: the file-level import block is not visible from here.
    import importlib

    model_name = options.model
    dataset = options.dataset
    reload_data = options.reload
    class_num = options.classnum
    is_train = not options.not_train
    is_save = not options.not_save_model
    model_path = options.model_path
    input_data = options.input_data
    epoch = options.epoch

    module_name, class_name, config = load_conf(model_name, modelconf)
    config['model'] = model_name
    print(model_name)
    config['dataset'] = dataset
    config['class_num'] = class_num
    config['nepoch'] = epoch
    train_data, test_data = load_tt_datas(config, reload_data)

    # import_module returns the named (sub)module itself. The previous
    # __import__(name, fromlist=True) passed a bool as fromlist, which
    # CPython iterates (and fails on with TypeError) when name is a package.
    module = importlib.import_module(module_name)

    # Configure Randomer with the stddev taken from the config.
    Randomer.set_stddev(config['stddev'])

    # Datasets that train with a dataset-specific accuracy threshold,
    # mapped to the config key holding that threshold.
    threshold_keys = {
        "cikm16": "cikm_threshold_acc",
        "rsc15": "recsys_threshold_acc",
    }

    with tf.Graph().as_default():
        model = getattr(module, class_name)(config)
        model.build_model()

        # A Saver is needed to write checkpoints or to restore one for
        # evaluation; otherwise training runs without one.
        saver = tf.train.Saver(max_to_keep=30) if (is_save or not is_train) else None

        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            if is_train:
                print(dataset)
                threshold_key = threshold_keys.get(dataset)
                if threshold_key is None:
                    model.train(sess, train_data, test_data, saver)
                else:
                    model.train(sess, train_data, test_data, saver,
                                threshold_acc=config[threshold_key])
            else:
                # Compare by value: the original `is "train"` identity test
                # could be False for an equal string and silently fall
                # through to the test split.
                sent_data = train_data if input_data == "train" else test_data
                saver.restore(sess, model_path)
                model.test(sess, sent_data)