コード例 #1
0
def main(unused_argv):
    """Fit the model on TRAIN_FILE, validating on VALIDATION_FILE as it trains.

    Builds the model function from the project hyper-parameters, wraps it in
    an ``Estimator`` rooted at ``MODEL_DIR``, attaches a ``ValidationMonitor``
    and then calls ``fit``.

    Args:
        unused_argv: Not used by this function.
    """
    params = cn_hparams.create_hparams()

    # NOTE(review): the original carried the remark "model_fun=[2,3,4,5],30"
    # at this point; its meaning is not evident from this function.
    encoder_options = {
        "model_impl": dual_encoder_model,
        "model_fun": model.RNN_MaxPooling,
        "RNNInit": tf.nn.rnn_cell.LSTMCell,
        "is_bidirection": True,
    }
    build_model = cn_model.create_model_fn(params, **encoder_options)

    run_config = tf.contrib.learn.RunConfig()
    trainer = Estimator(model_fn=build_model,
                        model_dir=MODEL_DIR,
                        config=run_config)

    mode_keys = tf.contrib.learn.ModeKeys

    # Training input: epoch count comes from the command-line flags.
    train_input = cn_inputs.create_input_fn(mode=mode_keys.TRAIN,
                                            input_files=[TRAIN_FILE],
                                            batch_size=params.batch_size,
                                            num_epochs=FLAGS.num_epochs)

    # Validation input: a single pass over the validation file.
    validation_input = cn_inputs.create_input_fn(
        mode=mode_keys.EVAL,
        input_files=[VALIDATION_FILE],
        batch_size=params.eval_batch_size,
        num_epochs=1)

    metrics = cn_metrics.create_evaluation_metrics()

    # The monitor feeds the validation data through the model every
    # FLAGS.eval_every steps.
    validation_monitor = tf.contrib.learn.monitors.ValidationMonitor(
        input_fn=validation_input,
        every_n_steps=FLAGS.eval_every,
        metrics=metrics)

    # steps=None: no explicit step limit is passed to fit().
    trainer.fit(input_fn=train_input,
                steps=None,
                monitors=[validation_monitor])
コード例 #2
0
def main(unused_argv):
    """Train the encoder model with periodic validation and early stopping.

    Builds the model function from the project hyper-parameters and the
    module-level ``encoder_model`` / ``model`` / ``RNNInit`` /
    ``is_bidirection`` settings, wraps it in an ``Estimator`` rooted at
    ``MODEL_DIR``, and fits it for 10000 steps while a ``ValidationMonitor``
    evaluates on ``VALIDATION_FILE`` every ``FLAGS.eval_every`` steps.

    Args:
        unused_argv: Not used by this function.
    """
    hparams = cn_hparams.create_hparams()
    # NOTE(review): the original carried the remark "model_fun=[2,3,4,5],30"
    # at this point; its meaning is not evident from this function.
    # Both keep probabilities are 1.0 (presumably this disables dropout --
    # confirm against cn_model.create_model_fn).
    model_fn = cn_model.create_model_fn(hparams,
                                        model_impl=encoder_model,
                                        model_fun=model,
                                        RNNInit=RNNInit,
                                        is_bidirection=is_bidirection,
                                        input_keep_prob=1.0,
                                        output_keep_prob=1.0)
    estimator = Estimator(model_fn=model_fn,
                          model_dir=MODEL_DIR,
                          config=tf.contrib.learn.RunConfig())

    # Training input: epoch count comes from the command-line flags.
    input_fn_train = cn_inputs.create_input_fn(
        mode=tf.contrib.learn.ModeKeys.TRAIN,
        input_files=[TRAIN_FILE],
        batch_size=hparams.batch_size,
        num_epochs=FLAGS.num_epochs)
    # Validation input: a single pass over the validation file.
    input_fn_eval = cn_inputs.create_input_fn(
        mode=tf.contrib.learn.ModeKeys.EVAL,
        input_files=[VALIDATION_FILE],
        batch_size=hparams.eval_batch_size,
        num_epochs=1)

    eval_metrics = cn_metrics.create_evaluation_metrics()

    # The monitor feeds the validation data through the model every
    # FLAGS.eval_every steps and tracks "recall_at_1" for early stopping.
    # Recall is a higher-is-better metric, so it must NOT be minimised:
    # with early_stopping_metric_minimize=True the monitor would treat the
    # lowest recall seen as the best value.
    # NOTE(review): early_stopping_rounds equals the step budget passed to
    # fit() below, so early stopping is unlikely to fire within this run.
    eval_monitor = tf.contrib.learn.monitors.ValidationMonitor(
        input_fn=input_fn_eval,
        every_n_steps=FLAGS.eval_every,
        metrics=eval_metrics,
        early_stopping_metric="recall_at_1",
        early_stopping_metric_minimize=False,
        early_stopping_rounds=10000)

    estimator.fit(input_fn=input_fn_train,
                  steps=10000,
                  monitors=[eval_monitor])
コード例 #3
0
    # features["context_len"]=tf.constant(context_len, shape=[len(context_len), 1], dtype=tf.int64)
    return features, None


if __name__ == "__main__":
    # Interactive console entry point: loads the pickled dataset, rebuilds the
    # estimator from MODEL_DIR, then repeatedly prompts for a question.

    # get raw data
    dataset = lyx.load_pkl('dataset')
    raw_data = dataset.raw_data

    # restore model & parameters
    # Both keep probabilities are 1.0 (presumably this disables dropout for
    # inference -- confirm against cn_model.create_model_fn).
    hparams = cn_hparams.create_hparams()
    model_fn = cn_model.create_model_fn(hparams,
                                        model_impl=encoder_model,
                                        model_fun=model,
                                        RNNInit=RNNInit,
                                        is_bidirection=is_bidirection,
                                        input_keep_prob=1.0,
                                        output_keep_prob=1.0)
    estimator = tf.contrib.learn.Estimator(model_fn=model_fn,
                                           model_dir=MODEL_DIR)

    # Prompt loop; no exit condition is visible in this excerpt.
    while True:
        # Prompt text (Chinese): "Welcome to the medical consultation,
        # please describe your question: "
        input_question = input('欢迎咨询医疗问题,请描述您的问题: ')

        # if input the index of question
        # An all-digit entry is treated as an index into raw_data and replaced
        # by that record's 'question' field.
        # NOTE(review): the index is not range-checked before the lookup.
        if input_question.isdigit():
            input_question_index = int(input_question)
            input_question = raw_data[input_question_index]['question']

        # Echo the question back (Chinese: "Your question: %s").
        print('您的问题:%s' % input_question)
コード例 #4
0
# Command-line flags for the test run (further flags such as model_dir,
# test_file and loglevel are read below but are defined outside this excerpt).
tf.flags.DEFINE_integer("test_batch_size", 16, "Batch size for testing")
tf.flags.DEFINE_boolean("customized_word_vector", True,
                        "choose random or customized word vectors")
FLAGS = tf.flags.FLAGS

# A model directory is mandatory. Abort before anything else runs if it is
# missing.
if not FLAGS.model_dir:
    # sys.exit() with a string writes it to stderr and exits with status 1,
    # so the exit code is unchanged but the failure is no longer silent.
    sys.exit("You must specify a model directory")

tf.logging.set_verbosity(FLAGS.loglevel)

if __name__ == "__main__":
    hparams = cn_hparams.create_hparams()
    model_fn = cn_model.create_model_fn(hparams,
                                        model_impl=encoder_model,
                                        model_fun=model.RNN_CNN_MaxPooling,
                                        RNNInit=tf.nn.rnn_cell.LSTMCell,
                                        is_bidirection=True)
    estimator = tf.contrib.learn.Estimator(model_fn=model_fn,
                                           model_dir=FLAGS.model_dir,
                                           config=tf.contrib.learn.RunConfig())

    input_fn_test = cn_inputs.create_input_fn(
        mode=tf.contrib.learn.ModeKeys.EVAL,
        input_files=[FLAGS.test_file],
        batch_size=FLAGS.test_batch_size,
        num_epochs=1)

    eval_metrics = cn_metrics.create_evaluation_metrics()
    estimator.evaluate(input_fn=input_fn_test,
                       steps=None,