Code example #1
0
def evaluate(eval_file, model_dir, summary_dir, train_steps):
    """Compute the mean perplexity of the latest checkpoint on ``eval_file``.

    A fresh EVAL graph is built and the newest checkpoint in ``model_dir`` is
    restored. The perplexity op is then run batch by batch until the input
    pipeline raises ``OutOfRangeError`` (or the coordinator is stopped). The
    mean of the per-batch values is logged and written to ``summary_dir``
    under the ``eval_ppl`` tag.

    Note: the result is an unweighted mean over batches, not a token-weighted
    corpus perplexity.

    Args:
        eval_file: Evaluation data file handed to ``HRAN.create_input_layer``.
        model_dir: Directory searched for the latest checkpoint.
        summary_dir: Output directory passed to ``write_to_summary``.
        train_steps: Global step at which the summary value is recorded.

    Returns:
        The mean per-batch perplexity.

    Raises:
        ValueError: If ``model_dir`` holds no checkpoint, or if no evaluation
            batch could be read (this used to surface as ZeroDivisionError).
    """
    hp = hparam.create_hparam()

    eval_graph = tf.Graph()
    with eval_graph.as_default():
        input_features = HRAN.create_input_layer(mode=modekeys.EVAL,
                                                 filename=eval_file,
                                                 hp=hp)

        ppl = HRAN.impl(features=input_features, hp=hp, mode=modekeys.EVAL)

        saver = tf.train.Saver()
        checkpoint = saver_lib.latest_checkpoint(model_dir)
        if checkpoint is None:
            raise ValueError('No checkpoint found in {}'.format(model_dir))

        # The context manager closes the session even if evaluation fails.
        # This function is called once per eval interval during training, so
        # leaking a session per call would accumulate resources.
        with tf.Session(graph=eval_graph) as sess:
            saver.restore(sess=sess, save_path=checkpoint)
            # The Saver restores only global variables; locals need init.
            sess.run(tf.local_variables_initializer())

            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            tf.logging.info('Begin evaluation at model {} on file {}'.format(
                checkpoint, eval_file))

            total_ppl = 0
            eval_step = 0
            try:
                while not coord.should_stop():
                    total_ppl += sess.run(fetches=ppl)
                    eval_step += 1
            except tf.errors.OutOfRangeError:
                # Raised by the input pipeline once the file is exhausted.
                tf.logging.info('Finish evaluation')
            finally:
                coord.request_stop()
            # Join the queue-runner threads before the session is closed.
            coord.join(threads)

        if eval_step == 0:
            raise ValueError(
                'No evaluation batches were read from {}'.format(eval_file))

        avg_ppl = total_ppl / eval_step
        write_to_summary(output_dir=summary_dir,
                         summary_tag='eval_ppl',
                         summary_value=avg_ppl,
                         current_global_step=train_steps)
        tf.logging.info('eval ppl is {}'.format(avg_ppl))
        return avg_ppl
Code example #2
0
File: predict.py  Project: wangfeng0621/HRAN
def predict(datafile, model_dir):
    """Generate responses for one batch of ``datafile`` and write them out.

    Builds a PREDICT graph and restores the latest checkpoint from
    ``model_dir``. It fetches one batch of contexts, reference responses and
    generated sample ids, then hands them to ``write_to_file``.

    Only the first batch is processed: the coordinator is asked to stop
    right after the first ``write_to_file`` call.

    Args:
        datafile: Input data file handed to ``HRAN.create_input_layer``.
        model_dir: Directory searched for the latest checkpoint. It is also
            forwarded to ``write_to_file``.

    Raises:
        ValueError: If ``model_dir`` holds no checkpoint.
    """
    hp = hparam.create_hparam()

    predict_graph = tf.Graph()
    with predict_graph.as_default():
        input_features = HRAN.create_input_layer(mode=modekeys.PREDICT,
                                                 filename=datafile,
                                                 hp=hp)
        contexts = input_features['contexts']
        response_out = input_features['response_out']
        context_length = input_features['context_length']
        context_utterance_length = input_features['context_utterance_length']
        sample_ids, final_lengths = HRAN.impl(features=input_features,
                                              hp=hp,
                                              mode=modekeys.PREDICT)

        saver = tf.train.Saver()
        checkpoint = saver_lib.latest_checkpoint(model_dir)
        if checkpoint is None:
            raise ValueError('No checkpoint found in {}'.format(model_dir))

        # The context manager guarantees the session is released even if
        # prediction or writing fails.
        with tf.Session(graph=predict_graph) as sess:
            saver.restore(sess=sess, save_path=checkpoint)
            # The Saver restores only global variables; locals need init.
            sess.run(tf.local_variables_initializer())

            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            tf.logging.info('Begin prediction at model {} on file {}'.format(
                checkpoint, datafile))

            try:
                while not coord.should_stop():
                    (contexts_ids, gen_ids, gen_lengths, ref_ids, con_len,
                     con_utte_len) = sess.run(fetches=[
                         contexts, sample_ids, final_lengths, response_out,
                         context_length, context_utterance_length
                     ])
                    tf.logging.info('write prediction to file')
                    write_to_file(contexts_ids, ref_ids, gen_ids,
                                  './twitter_data', model_dir, gen_lengths,
                                  con_len, con_utte_len)
                    # Stop after the first batch: one batch is written per call.
                    coord.request_stop()
            except tf.errors.OutOfRangeError:
                tf.logging.info('Finish prediction')
            finally:
                coord.request_stop()
            # Join the queue-runner threads before the session is closed.
            coord.join(threads)
Code example #3
0
File: evaluate.py  Project: luomuqinghan/HRAN-1
def evaluate(eval_file, model_dir, summary_dir, train_steps):
    """Compute the mean perplexity of the latest checkpoint on ``eval_file``.

    Restores the newest checkpoint in ``model_dir`` into a fresh EVAL graph.
    It then runs the perplexity op until the input pipeline is exhausted (or
    the coordinator is stopped), logs the mean, and writes it to
    ``summary_dir`` under the ``eval_ppl`` tag.

    The mean and the summary are computed after the loop however it ends.
    Previously they lived in the ``except OutOfRangeError`` branch only, so
    any other exit raised UnboundLocalError on ``return avg_ppl``.

    Args:
        eval_file: Evaluation data file handed to ``HRAN.create_input_layer``.
        model_dir: Directory searched for the latest checkpoint.
        summary_dir: Output directory passed to ``write_to_summary``.
        train_steps: Global step at which the summary value is recorded.

    Returns:
        The unweighted mean of the per-batch perplexities.

    Raises:
        ValueError: If ``model_dir`` holds no checkpoint or no evaluation
            batch could be read.
    """
    hp = hparam.create_hparam()

    eval_graph = tf.Graph()
    with eval_graph.as_default():
        input_features = HRAN.create_input_layer(mode=modekeys.EVAL,
                                                 filename=eval_file,
                                                 hp=hp)

        ppl = HRAN.impl(features=input_features, hp=hp, mode=modekeys.EVAL)

        saver = tf.train.Saver()
        checkpoint = saver_lib.latest_checkpoint(model_dir)
        if checkpoint is None:
            raise ValueError('No checkpoint found in {}'.format(model_dir))

        gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=1.0)
        config = tf.ConfigProto(gpu_options=gpu_options)
        # The context manager releases the session (and its GPU memory) on
        # exit; this runs once per eval interval during training.
        with tf.Session(graph=eval_graph, config=config) as sess:
            saver.restore(sess=sess, save_path=checkpoint)
            # The Saver restores only global variables; locals need init.
            sess.run(tf.local_variables_initializer())

            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            tf.logging.info('Begin evaluation')

            total_ppl = 0
            eval_step = 0
            try:
                while not coord.should_stop():
                    total_ppl += sess.run(fetches=ppl)
                    eval_step += 1
            except tf.errors.OutOfRangeError:
                # Normal termination: the evaluation file has been consumed.
                pass
            finally:
                coord.request_stop()
            coord.join(threads)

        if eval_step == 0:
            raise ValueError(
                'No evaluation batches were read from {}'.format(eval_file))

        avg_ppl = total_ppl / eval_step
        tf.logging.info(
            'Finish evaluation. The perplexity is {}'.format(avg_ppl))
        write_to_summary(summary_dir, 'eval_ppl', avg_ppl, train_steps)

        return avg_ppl
Code example #4
0
def main(unused_arg):
    """Train the Estimator on TRAIN_FILE with periodic validation on VALID_FILE.

    Checkpoints and summaries are saved every ``eval_step`` steps. A
    ValidationMonitor evaluates on VALID_FILE at the same interval. The
    hyper-parameters are written to MODEL_DIR once ``train`` returns.
    """
    hp = hparam.create_hparam()
    eval_every = hp.eval_step

    run_config = tf.contrib.learn.RunConfig(
        gpu_memory_fraction=1,
        save_summary_steps=eval_every,
        save_checkpoints_steps=eval_every,
        log_step_count_steps=1000,
    )
    estimator = tf.estimator.Estimator(
        model_fn=model.create_model_fn(),
        model_dir=MODEL_DIR,
        config=run_config,
        params=hp,
    )

    # Optional tfdbg command-line debugger, enabled by --debug.
    monitors = [tf_debug.LocalCLIDebugHook()] if FLAGS.debug else []

    # Validation input: one pass over VALID_FILE, no shuffling.
    validation_input = input.create_input_fn(
        tf.estimator.ModeKeys.EVAL, [VALID_FILE], hp.eval_batch_size, 1, False)
    monitors.append(
        tf.contrib.learn.monitors.ValidationMonitor(
            input_fn=validation_input, every_n_steps=eval_every))

    training_hooks = tf.contrib.learn.monitors.replace_monitors_with_hooks(
        monitors, estimator)

    training_input = input.create_input_fn(
        tf.estimator.ModeKeys.TRAIN,
        [TRAIN_FILE],
        hp.batch_size,
        hp.num_epochs,
        hp.shuffle_batch,
    )
    estimator.train(input_fn=training_input, hooks=training_hooks)

    hparam.write_hparams_to_file(hp, MODEL_DIR)
Code example #5
0
def main(unused_arg):
    """Evaluate the checkpoint in ``--model_dir`` on the validation set.

    Builds an Estimator around ``model.create_model_fn()`` and evaluates it
    on ``./data/validation.tfrecords``. Judging by the argument order used
    for training elsewhere, the input is read for one epoch without
    shuffling. The resulting precision and recall are printed.

    Raises:
        KeyError: If the ``--model_dir`` flag is empty or unset. The type is
            kept for backward compatibility; it now carries a message.
    """
    # Validate the flag before building anything else.
    if not FLAGS.model_dir:
        raise KeyError(
            '--model_dir must point to the directory of the model to evaluate')
    model_dir = FLAGS.model_dir

    hyper_parameters = hparam.create_hparam()
    eval_model_fn = model.create_model_fn()

    estimator = tf.estimator.Estimator(eval_model_fn,
                                       model_dir=model_dir,
                                       params=hyper_parameters)

    EVAL_FILE = './data/validation.tfrecords'
    eval_input_fn = input.create_input_fn(tf.estimator.ModeKeys.EVAL,
                                          [EVAL_FILE],
                                          hyper_parameters.eval_batch_size,
                                          1,
                                          False)

    monitors_list = []
    if FLAGS.debug:
        # Optional tfdbg command-line debugger.
        monitors_list.append(tf_debug.LocalCLIDebugHook())
    hooks = tf.contrib.learn.monitors.replace_monitors_with_hooks(
        monitors_list, estimator)

    eval_result = estimator.evaluate(input_fn=eval_input_fn, hooks=hooks)

    print('precision: {}'.format(eval_result['precision']))
    print('recall: {}'.format(eval_result['recall']))
Code example #6
0
def train():
    """Train HRAN on ``TRAIN_FILE``, checkpointing and evaluating periodically.

    Builds the TRAIN graph and restores the latest checkpoint in
    ``MODEL_DIR`` if one exists; otherwise all variables are initialised.
    ``train_op`` then runs until the input pipeline is exhausted:

    - every 100 steps, the loss is logged;
    - every ``hp.summary_save_steps`` steps, the training summary is written
      and the learning rate is logged;
    - every ``hp.eval_step`` steps, a checkpoint is saved and
      ``evaluate.evaluate`` runs on ``<FLAGS.data_dir>/valid.tfrecords``.

    A final checkpoint is saved once the loop ends. The summary writer and
    the session are always closed on exit, which flushes buffered summary
    events even if training raises.
    """
    hp = hparam.create_hparam()
    train_graph = tf.Graph()
    with train_graph.as_default():
        input_features = HRAN.create_input_layer(mode=modekeys.TRAIN,
                                                 filename=TRAIN_FILE,
                                                 hp=hp)
        loss = HRAN.impl(features=input_features, mode=modekeys.TRAIN, hp=hp)
        # Registered in GLOBAL_STEP (so it is the graph's global step) and in
        # GLOBAL_VARIABLES (so the Saver checkpoints it).
        global_step_tensor = tf.Variable(initial_value=0,
                                         trainable=False,
                                         collections=[
                                             tf.GraphKeys.GLOBAL_STEP,
                                             tf.GraphKeys.GLOBAL_VARIABLES
                                         ],
                                         name='global_step')
        train_op, lr = create_train_op(loss, hp.learning_rate,
                                       global_step_tensor)

        tf.summary.scalar(name='train_loss', tensor=loss)
        summary_op = tf.summary.merge_all()

        # Paths used inside the loop do not change; compute them once.
        checkpoint_prefix = os.path.join(MODEL_DIR, 'model.ckpt')
        eval_file = os.path.join(os.path.abspath(FLAGS.data_dir),
                                 'valid.tfrecords')
        eval_summary_dir = os.path.join(MODEL_DIR, 'summary/eval')

        summary_writer = tf.summary.FileWriter(
            logdir=os.path.join(os.path.abspath(MODEL_DIR), 'summary'))
        sess = tf.Session()
        try:
            if FLAGS.debug:
                # NOTE: --debug is accepted but currently has no effect here;
                # wrapping `sess` in tf_debug.LocalCLIDebugWrapperSession is
                # disabled.
                pass

            saver = tf.train.Saver(max_to_keep=5)
            checkpoint = saver_lib.latest_checkpoint(MODEL_DIR)
            tf.logging.info('model dir {}'.format(MODEL_DIR))
            tf.logging.info('check point {}'.format(checkpoint))
            if checkpoint:
                tf.logging.info('Restore parameter from {}'.format(checkpoint))
                saver.restore(sess=sess, save_path=checkpoint)
                # The Saver restores only global variables; locals need init.
                sess.run(tf.local_variables_initializer())
            else:
                sess.run([
                    tf.global_variables_initializer(),
                    tf.local_variables_initializer()
                ])

            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            tf.logging.info(msg='Begin training')
            try:
                while not coord.should_stop():
                    _, current_loss, summary, global_step, learning_rate = sess.run(
                        fetches=[
                            train_op, loss, summary_op, global_step_tensor, lr
                        ])

                    if global_step % 100 == 0:
                        tf.logging.info('global step ' + str(global_step) +
                                        ' loss: ' + str(current_loss))

                    if global_step % hp.summary_save_steps == 0:
                        summary_writer.add_summary(summary=summary,
                                                   global_step=global_step)
                        tf.logging.info(
                            'learning rate {}'.format(learning_rate))

                    if global_step % hp.eval_step == 0:
                        tf.logging.info('save model')
                        saver.save(sess=sess,
                                   save_path=checkpoint_prefix,
                                   global_step=global_step)
                        evaluate.evaluate(eval_file, MODEL_DIR,
                                          eval_summary_dir, global_step)

            except tf.errors.OutOfRangeError:
                tf.logging.info('Finish training -- epoch limit reached')
            finally:
                coord.request_stop()
            coord.join(threads)

            # Saver.save evaluates the step tensor itself, so this works even
            # when the loop never completed an iteration.
            saver.save(sess=sess,
                       save_path=checkpoint_prefix,
                       global_step=global_step_tensor)
        finally:
            # FileWriter buffers events; close() flushes them to disk.
            summary_writer.close()
            sess.close()
Code example #7
0
def train():
    """Train HRAN with periodic validation, best-model tracking and early stopping.

    Builds the TRAIN graph and restores the latest checkpoint in
    ``MODEL_DIR`` (or initialises all variables), then runs ``train_op``
    until the input is exhausted or early stopping triggers. Every
    ``hp.eval_step`` steps it:

    - saves a checkpoint;
    - evaluates ``<FLAGS.data_dir>/valid.tfrecords`` with
      ``evaluate.evaluate``;
    - if the returned value beats the best so far, stores it and saves the
      model under ``MODEL_DIR/best_model``.

    After ``early_stop_patience`` consecutive evaluations without
    improvement, training stops. The summary writer and session are always
    closed on exit, so buffered summary events are flushed.
    """
    hp = hparam.create_hparam()
    # Consecutive non-improving evaluations tolerated before stopping.
    early_stop_patience = 10
    train_graph = tf.Graph()
    with train_graph.as_default():
        input_features = HRAN.create_input_layer(mode=modekeys.TRAIN,
                                                 filename=TRAIN_FILE,
                                                 hp=hp)
        loss, _debug_tensors = HRAN.impl(features=input_features,
                                         mode=modekeys.TRAIN,
                                         hp=hp)
        # Registered in GLOBAL_STEP (the graph's global step) and in
        # GLOBAL_VARIABLES (checkpointed by the Savers).
        global_step_tensor = tf.Variable(initial_value=0,
                                         trainable=False,
                                         collections=[
                                             tf.GraphKeys.GLOBAL_STEP,
                                             tf.GraphKeys.GLOBAL_VARIABLES
                                         ],
                                         name='global_step')
        train_op, grad_norm = create_train_op(loss, hp.learning_rate,
                                              global_step_tensor, hp.clip_norm)
        # Best validation value seen so far. It is created before the Savers,
        # so it is checkpointed and survives a restart. It starts at +inf so
        # the first evaluation always becomes the best model; the former
        # sentinel of 10000 could block it.
        stop_criteria_tensor = tf.Variable(initial_value=float('inf'),
                                           trainable=False,
                                           name='stop_criteria',
                                           dtype=tf.float32)
        # Build the update op once. Calling .assign() inside the loop would
        # add new ops to the graph on every improvement.
        new_best_value = tf.placeholder(dtype=tf.float32, shape=[])
        update_best_op = stop_criteria_tensor.assign(new_best_value)

        tf.summary.scalar(name='train_loss', tensor=loss)
        tf.summary.scalar(name='train_grad_norm', tensor=grad_norm)
        summary_op = tf.summary.merge_all()

        # Paths used inside the loop do not change; compute them once.
        checkpoint_prefix = os.path.join(MODEL_DIR, 'model.ckpt')
        best_model_prefix = os.path.join(MODEL_DIR, 'best_model', 'model.ckpt')
        eval_file = os.path.join(os.path.abspath(FLAGS.data_dir),
                                 'valid.tfrecords')
        eval_summary_dir = os.path.join(MODEL_DIR, 'summary/eval')

        summary_writer = tf.summary.FileWriter(
            logdir=os.path.join(os.path.abspath(MODEL_DIR), 'summary'))

        gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=1.0)
        sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
        try:
            if FLAGS.debug:
                # NOTE: --debug is accepted but currently has no effect here;
                # wrapping `sess` in tf_debug.LocalCLIDebugWrapperSession is
                # disabled.
                pass

            saver = tf.train.Saver(max_to_keep=1)
            best_saver = tf.train.Saver(max_to_keep=1)
            checkpoint = saver_lib.latest_checkpoint(MODEL_DIR)
            tf.logging.info('model dir {}'.format(MODEL_DIR))
            tf.logging.info('check point {}'.format(checkpoint))
            if checkpoint:
                tf.logging.info('Restore parameter from {}'.format(checkpoint))
                saver.restore(sess=sess, save_path=checkpoint)
                # The Saver restores only global variables; locals need init.
                sess.run(tf.local_variables_initializer())
            else:
                sess.run([
                    tf.global_variables_initializer(),
                    tf.local_variables_initializer()
                ])

            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            tf.logging.info(msg='Begin training')
            try:
                stop_count = early_stop_patience

                while not coord.should_stop():
                    _, current_loss, summary, global_step = sess.run(
                        fetches=[train_op, loss, summary_op, global_step_tensor])

                    if global_step % 100 == 0:
                        tf.logging.info('global step ' + str(global_step) +
                                        ' loss: ' + str(current_loss))

                    if global_step % hp.summary_save_steps == 0:
                        summary_writer.add_summary(summary=summary,
                                                   global_step=global_step)

                    if global_step % hp.eval_step == 0:
                        saver.save(sess=sess,
                                   save_path=checkpoint_prefix,
                                   global_step=global_step)
                        cur_stop_criteria = evaluate.evaluate(
                            eval_file, MODEL_DIR, eval_summary_dir, global_step)
                        stop_criteria = sess.run(stop_criteria_tensor)
                        if cur_stop_criteria < stop_criteria:
                            sess.run(update_best_op,
                                     feed_dict={new_best_value: cur_stop_criteria})
                            save_path = best_saver.save(
                                sess=sess,
                                save_path=best_model_prefix,
                                global_step=global_step_tensor)
                            tf.logging.info(
                                'Save best model to {}'.format(save_path))
                            stop_count = early_stop_patience
                        else:
                            stop_count -= 1
                            if stop_count == 0:
                                tf.logging.info(
                                    'Early stop at step {}'.format(global_step))
                                break

            except tf.errors.OutOfRangeError:
                tf.logging.info('Finish training -- epoch limit reached')
            finally:
                tf.logging.info('Best ppl is {}'.format(
                    sess.run(stop_criteria_tensor)))
                coord.request_stop()

            coord.join(threads)
        finally:
            # FileWriter buffers events; close() flushes them to disk.
            summary_writer.close()
            sess.close()
Code example #8
0
File: online_predict.py  Project: luomuqinghan/HRAN-1
def _online_feed_dict(features, raw_query, vocab, hp, dialog_mode):
    """Convert one raw text query into a feed dict for the PREDICT inputs."""
    if dialog_mode == 'single':
        query_id, query_len = preprocess_raw_query(raw_query, vocab,
                                                   hp.max_sentence_length)
        return {
            features['utterance']: [query_id],
            features['utterance_length']: [query_len],
        }
    # 'multi': the raw query is parsed as a multi-turn context.
    context_ids, context_len, context_utterance_lens = preprocess_raw_context(
        raw_query, vocab, hp.max_sentence_length, hp.max_context_length)
    return {
        features['contexts']: [context_ids],
        features['context_utterance_length']: [context_utterance_lens],
        features['context_length']: [context_len],
    }


def _print_online_response(fetches, reverse_vocab, dialog_mode, use_beam):
    """Print decoded response(s); in single greedy mode also dump diagnostics."""
    if use_beam:
        responses = postprocess_k_generated_response(fetches['response_ids'],
                                                     reverse_vocab)
        for res in responses[0]:
            print('Response: {}'.format(res))
        return

    lens = fetches['response_lens']
    responses = postprocess_generated_response(fetches['response_ids'], lens,
                                               reverse_vocab)
    print('Response: {}'.format(responses[0]))
    if dialog_mode != 'single':
        return

    # For each generated word: indices of the 5 largest alignment weights,
    # followed by the full weight vector.
    alignment_his = fetches['alignment_history'][0]
    for i, ali in enumerate(alignment_his[0:lens[0]]):
        print('word{} {}'.format(i, np.argsort(ali)[::-1][0:5]))
        print(ali)
        print('\n')

    key_prob = fetches['keywords_prob']
    print('\n')
    print('keywords prediction')
    print(key_prob[0])
    print(np.argsort(key_prob[0])[::-1][0:5])


def online_prediction():
    """Interactive loop: read queries from stdin and print the model's responses.

    Restores the latest checkpoint in ``MODEL_DIR`` and prompts for queries
    until the user enters ``q``.

    ``FLAGS.dialog_mode`` selects the input format: ``'single'`` for one
    utterance, ``'multi'`` for a multi-turn context.

    ``hp.beam_width`` selects the output:

    - greater than 0: every entry of the first beam-search result is printed;
    - otherwise: one greedy response is printed. In single mode the
      alignment and keyword scores are printed as well.

    Raises:
        ValueError: If ``FLAGS.dialog_mode`` is neither ``'single'`` nor
            ``'multi'``.
        Exception: If ``MODEL_DIR`` contains no checkpoint.
    """
    hp = hparam.create_hparam()
    dialog_mode = tf.flags.FLAGS.dialog_mode
    if dialog_mode not in ('single', 'multi'):
        # An unknown mode used to feed and print nothing for every query.
        raise ValueError(
            "dialog_mode must be 'single' or 'multi', got {!r}".format(
                dialog_mode))
    use_beam = hp.beam_width > 0

    vocab_path = hp.vocab_path
    vocab = load_vocabulary(vocab_path)
    reverse_vocab = load_reverse_vocabulary(vocab_path)

    features = model_impl.create_input_layer(filename=None,
                                             hp=hp,
                                             mode=modekeys.PREDICT)
    results = model_impl.impl(features, modekeys.PREDICT, hp)

    # The fetches depend only on the mode, so build them once. Building them
    # inside the query loop (with fresh tf.constant placeholders) would add
    # new nodes to the graph on every query.
    fetch_dict = {'response_ids': results['response_ids']}
    if not use_beam:
        fetch_dict['response_lens'] = results['response_lens']
        if dialog_mode == 'single':
            fetch_dict['alignment_history'] = results['alignment_history']
            fetch_dict['keywords_prob'] = results['keywords_prob']

    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=1.0)
    sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
    try:
        saver = tf.train.Saver()
        checkpoint = saver_lib.latest_checkpoint(MODEL_DIR)
        if not checkpoint:
            raise Exception('no check point')
        saver.restore(sess=sess, save_path=checkpoint)
        print('restore from {}'.format(checkpoint))
        # The Saver restores only global variables; locals need init.
        sess.run(tf.local_variables_initializer())

        print('Finish model initializing')
        raw_query = input('Please enter query\n')

        while raw_query != 'q':
            feed_dict = _online_feed_dict(features, raw_query, vocab, hp,
                                          dialog_mode)
            fetches = sess.run(fetches=fetch_dict, feed_dict=feed_dict)
            _print_online_response(fetches, reverse_vocab, dialog_mode,
                                   use_beam)
            raw_query = input('Please enter query\n')
    finally:
        sess.close()