Ejemplo n.º 1
0
def main(_):
    """Train an NP2P model with cross-entropy (``ce_train``) or RL (``rl_train``).

    All settings come from the module-level ``FLAGS`` namespace. The function:

    * writes a log file and a ``.config.json`` snapshot of ``FLAGS`` under
      ``FLAGS.model_dir`` using the prefix ``NP2P.<FLAGS.suffix>``;
    * loads the train/test sets with the reader selected by
      ``FLAGS.infile_format`` (``'fof'``, ``'plain'`` or GQA otherwise);
    * reloads the char/POS/NER vocabularies if a ``.best.model`` checkpoint
      already exists under that prefix, otherwise builds them from the
      training set and dumps them next to the checkpoint;
    * trains for ``FLAGS.max_epochs`` epochs, evaluating on the test set after
      every epoch and overwriting the checkpoint whenever dev accuracy
      (``ce_train``) or dev BLEU (``rl_train``) improves.

    Args:
        _: Unused; kept so the function matches the ``main(argv)`` entry-point
            signature expected by its caller.

    Raises:
        ValueError: if ``FLAGS.mode`` is neither ``'ce_train'`` nor
            ``'rl_train'``.
    """
    print('Configurations:')
    print(FLAGS)

    log_dir = FLAGS.model_dir
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    path_prefix = log_dir + "/NP2P.{}".format(FLAGS.suffix)
    config_path = path_prefix + ".config.json"
    log_file_path = path_prefix + ".log"
    print('Log file path: {}'.format(log_file_path))

    def _read_dataset(path):
        """Read one split with the reader selected by FLAGS.infile_format.

        Returns the ``(dataset, max_answer_length)`` pair from that reader.
        """
        if FLAGS.infile_format == 'fof':
            return NP2P_data_stream.read_generation_datasets_from_fof(
                path, isLower=FLAGS.isLower)
        if FLAGS.infile_format == 'plain':
            return NP2P_data_stream.read_all_GenerationDatasets(
                path, isLower=FLAGS.isLower)
        return NP2P_data_stream.read_all_GQA_questions(
            path, isLower=FLAGS.isLower, switch=FLAGS.switch_qa)

    # ``with`` closes the log file even when data loading, graph construction
    # or training raises; previously it was only closed on the success path.
    with open(log_file_path, 'wt') as log_file:
        log_file.write("{}\n".format(FLAGS))
        log_file.flush()

        # save configuration
        namespace_utils.save_namespace(FLAGS, config_path)

        # Fail fast (before the slow data loading) on an unsupported mode.
        # An ``assert`` would be stripped under ``python -O``.
        if FLAGS.mode not in ('ce_train', 'rl_train'):
            raise ValueError(
                "FLAGS.mode must be 'ce_train' or 'rl_train', got {!r}".format(
                    FLAGS.mode))

        print('Loading train set.')
        trainset, train_ans_len = _read_dataset(FLAGS.train_path)
        print('Number of training samples: {}'.format(len(trainset)))

        print('Loading test set.')
        testset, test_ans_len = _read_dataset(FLAGS.test_path)
        print('Number of test samples: {}'.format(len(testset)))

        max_actual_len = max(train_ans_len, test_ans_len)
        print('Max answer length: {}, truncated to {}'.format(
            max_actual_len, FLAGS.max_answer_len))

        word_vocab = None
        POS_vocab = None
        NER_vocab = None
        char_vocab = None
        best_path = path_prefix + ".best.model"
        # TF (V2-format) checkpoints write a "<prefix>.index" file, so its
        # presence means an earlier run already saved a best model here.
        has_pretrained_model = os.path.exists(best_path + ".index")
        if has_pretrained_model:
            print('!!Existing pretrained model. Loading vocabs.')
            if FLAGS.with_word:
                word_vocab = Vocab(FLAGS.word_vec_path, fileformat='txt2')
                print('word_vocab: {}'.format(word_vocab.word_vecs.shape))
            if FLAGS.with_char:
                char_vocab = Vocab(path_prefix + ".char_vocab",
                                   fileformat='txt2')
                print('char_vocab: {}'.format(char_vocab.word_vecs.shape))
            if FLAGS.with_POS:
                POS_vocab = Vocab(path_prefix + ".POS_vocab",
                                  fileformat='txt2')
                print('POS_vocab: {}'.format(POS_vocab.word_vecs.shape))
            if FLAGS.with_NER:
                NER_vocab = Vocab(path_prefix + ".NER_vocab",
                                  fileformat='txt2')
                print('NER_vocab: {}'.format(NER_vocab.word_vecs.shape))
        else:
            print('Collecting vocabs.')
            (allWords, allChars, allPOSs,
             allNERs) = NP2P_data_stream.collect_vocabs(trainset)
            print('Number of words: {}'.format(len(allWords)))
            print('Number of allChars: {}'.format(len(allChars)))
            print('Number of allPOSs: {}'.format(len(allPOSs)))
            print('Number of allNERs: {}'.format(len(allNERs)))

            # Freshly built vocabularies are dumped under path_prefix so that a
            # later run takes the "pretrained" branch above and reloads them.
            if FLAGS.with_word:
                word_vocab = Vocab(FLAGS.word_vec_path, fileformat='txt2')
            if FLAGS.with_char:
                char_vocab = Vocab(voc=allChars,
                                   dim=FLAGS.char_dim,
                                   fileformat='build')
                char_vocab.dump_to_txt2(path_prefix + ".char_vocab")
            if FLAGS.with_POS:
                POS_vocab = Vocab(voc=allPOSs,
                                  dim=FLAGS.POS_dim,
                                  fileformat='build')
                POS_vocab.dump_to_txt2(path_prefix + ".POS_vocab")
            if FLAGS.with_NER:
                NER_vocab = Vocab(voc=allNERs,
                                  dim=FLAGS.NER_dim,
                                  fileformat='build')
                NER_vocab.dump_to_txt2(path_prefix + ".NER_vocab")

        # word_vocab is only created when FLAGS.with_word is set; the original
        # crashed here with AttributeError otherwise.
        if word_vocab is not None:
            print('word vocab size {}'.format(word_vocab.vocab_size))
        sys.stdout.flush()

        print('Build DataStream ... ')
        trainDataStream = NP2P_data_stream.QADataStream(trainset,
                                                        word_vocab,
                                                        char_vocab,
                                                        POS_vocab,
                                                        NER_vocab,
                                                        options=FLAGS,
                                                        isShuffle=True,
                                                        isLoop=True,
                                                        isSort=True)

        devDataStream = NP2P_data_stream.QADataStream(testset,
                                                      word_vocab,
                                                      char_vocab,
                                                      POS_vocab,
                                                      NER_vocab,
                                                      options=FLAGS,
                                                      isShuffle=False,
                                                      isLoop=False,
                                                      isSort=True)
        print('Number of instances in trainDataStream: {}'.format(
            trainDataStream.get_num_instance()))
        print('Number of instances in devDataStream: {}'.format(
            devDataStream.get_num_instance()))
        print('Number of batches in trainDataStream: {}'.format(
            trainDataStream.get_num_batch()))
        print('Number of batches in devDataStream: {}'.format(
            devDataStream.get_num_batch()))
        sys.stdout.flush()

        init_scale = 0.01
        # Best dev scores for this session. They are written back into the
        # config whenever they improve, presumably so a restarted run (whose
        # FLAGS are loaded from that config) resumes from them.
        best_accu = FLAGS.best_accu if 'best_accu' in FLAGS.__dict__ else 0.0
        best_bleu = FLAGS.best_bleu if 'best_bleu' in FLAGS.__dict__ else 0.0
        if best_accu > 0.0:
            print('With initial dev accuracy {}'.format(best_accu))
        if best_bleu > 0.0:
            print('With initial dev BLEU score {}'.format(best_bleu))

        with tf.Graph().as_default():
            initializer = tf.random_uniform_initializer(-init_scale,
                                                        init_scale)
            with tf.name_scope("Train"):
                with tf.variable_scope("Model",
                                       reuse=None,
                                       initializer=initializer):
                    train_graph = ModelGraph(word_vocab=word_vocab,
                                             char_vocab=char_vocab,
                                             POS_vocab=POS_vocab,
                                             NER_vocab=NER_vocab,
                                             options=FLAGS,
                                             mode=FLAGS.mode)

            # The validation graph reuses the "Model" variables (reuse=True),
            # so it always evaluates the current training weights.
            valid_mode = ('evaluate' if FLAGS.mode == 'ce_train'
                          else 'evaluate_bleu')
            with tf.name_scope("Valid"):
                with tf.variable_scope("Model",
                                       reuse=True,
                                       initializer=initializer):
                    valid_graph = ModelGraph(word_vocab=word_vocab,
                                             char_vocab=char_vocab,
                                             POS_vocab=POS_vocab,
                                             NER_vocab=NER_vocab,
                                             options=FLAGS,
                                             mode=valid_mode)

            init_op = tf.global_variables_initializer()

            # Checkpoint every "Model..." variable except the word embeddings
            # (presumably initialised from word_vocab on every run -- TODO
            # confirm in ModelGraph). Keys drop the ":0" output suffix.
            vars_ = {}
            for var in tf.global_variables():
                if "word_embedding" in var.name:
                    continue
                if not var.name.startswith("Model"):
                    continue
                vars_[var.name.split(":")[0]] = var
            saver = tf.train.Saver(vars_)

            # Context manager releases the session's resources on exit/error.
            with tf.Session() as sess:
                sess.run(init_op)
                if has_pretrained_model:
                    print("Restoring model from " + best_path)
                    saver.restore(sess, best_path)
                    print("DONE!")

                    # No recorded best score for this mode: measure the
                    # restored model so only real improvements overwrite it.
                    if FLAGS.mode == 'rl_train' and abs(best_bleu) < 0.00001:
                        print("Getting BLEU score for the model")
                        best_bleu = evaluate(sess,
                                             valid_graph,
                                             devDataStream,
                                             options=FLAGS)['dev_bleu']
                        FLAGS.best_bleu = best_bleu
                        namespace_utils.save_namespace(FLAGS, config_path)
                        print('BLEU = %.4f' % best_bleu)
                        log_file.write('BLEU = %.4f\n' % best_bleu)
                    if FLAGS.mode == 'ce_train' and abs(best_accu) < 0.00001:
                        print("Getting ACCU score for the model")
                        best_accu = evaluate(sess,
                                             valid_graph,
                                             devDataStream,
                                             options=FLAGS)['dev_accu']
                        FLAGS.best_accu = best_accu
                        namespace_utils.save_namespace(FLAGS, config_path)
                        print('ACCU = %.4f' % best_accu)
                        log_file.write('ACCU = %.4f\n' % best_accu)

                print('Start the training loop.')
                train_size = trainDataStream.get_num_batch()
                max_steps = train_size * FLAGS.max_epochs
                total_loss = 0.0
                start_time = time.time()
                for step in xrange(max_steps):
                    cur_batch = trainDataStream.nextBatch()
                    if FLAGS.mode == 'rl_train':
                        loss_value = train_graph.run_rl_training_2(
                            sess, cur_batch, FLAGS)
                    else:  # 'ce_train' -- guaranteed by the mode check above
                        loss_value = train_graph.run_ce_training(
                            sess, cur_batch, FLAGS)
                    total_loss += loss_value

                    if step % 100 == 0:
                        print('{} '.format(step), end="")
                        sys.stdout.flush()

                    # Save a checkpoint and evaluate at the end of every epoch
                    # and after the final step.
                    if (step + 1) % trainDataStream.get_num_batch() == 0 or (
                            step + 1) == max_steps:
                        print()
                        duration = time.time() - start_time
                        print('Step %d: loss = %.2f (%.3f sec)' %
                              (step, total_loss, duration))
                        log_file.write('Step %d: loss = %.2f (%.3f sec)\n' %
                                       (step, total_loss, duration))
                        log_file.flush()
                        sys.stdout.flush()
                        total_loss = 0.0

                        # Evaluate against the validation set.
                        start_time = time.time()
                        print('Validation Data Eval:')
                        res_dict = evaluate(sess,
                                            valid_graph,
                                            devDataStream,
                                            options=FLAGS,
                                            suffix=str(step))
                        if valid_graph.mode == 'evaluate':
                            dev_loss = res_dict['dev_loss']
                            dev_accu = res_dict['dev_accu']
                            dev_right = int(res_dict['dev_right'])
                            dev_total = int(res_dict['dev_total'])
                            print('Dev loss = %.4f' % dev_loss)
                            log_file.write('Dev loss = %.4f\n' % dev_loss)
                            print('Dev accu = %.4f %d/%d' %
                                  (dev_accu, dev_right, dev_total))
                            log_file.write('Dev accu = %.4f %d/%d\n' %
                                           (dev_accu, dev_right, dev_total))
                            log_file.flush()
                            if best_accu < dev_accu:
                                print('Saving weights, ACCU {} (prev_best) < '
                                      '{} (cur)'.format(best_accu, dev_accu))
                                saver.save(sess, best_path)
                                best_accu = dev_accu
                                FLAGS.best_accu = dev_accu
                                namespace_utils.save_namespace(FLAGS,
                                                               config_path)
                        else:
                            dev_bleu = res_dict['dev_bleu']
                            print('Dev bleu = %.4f' % dev_bleu)
                            log_file.write('Dev bleu = %.4f\n' % dev_bleu)
                            log_file.flush()
                            if best_bleu < dev_bleu:
                                print('Saving weights, BLEU {} (prev_best) < '
                                      '{} (cur)'.format(best_bleu, dev_bleu))
                                saver.save(sess, best_path)
                                best_bleu = dev_bleu
                                FLAGS.best_bleu = dev_bleu
                                namespace_utils.save_namespace(FLAGS,
                                                               config_path)
                        duration = time.time() - start_time
                        print('Duration %.3f sec' % (duration))
                        sys.stdout.flush()

                        log_file.write('Duration %.3f sec\n' % (duration))
                        log_file.flush()
Ejemplo n.º 2
0
def main(_):
    """Fine-tune an NP2P model with sampled-sequence policy-gradient updates.

    Settings come from the module-level ``FLAGS`` namespace.

    * The model weights and the char/POS/NER vocabularies are optionally
      initialised from the checkpoint prefix ``FLAGS.init_model``.
    * The training graph is built in ``"rl_train_for_phrase"`` mode.
    * A weight-sharing ``"decode"`` graph is used both to draw samples and to
      compute dev BLEU.

    Each step works as follows:

    * Target sequences are obtained from ``cur_batch.sample_a_partition()``
      when ``FLAGS.with_target_lattice`` is set, otherwise by multinomial
      decoding.
    * Each sequence is weighted by ``reward - baseline``. The baseline is the
      sentence BLEU of the greedy decode when ``FLAGS.with_baseline`` is set,
      else 0.
    * One optimizer step is run on the weighted sequences.

    The checkpoint at ``<model_dir>/NP2P.<suffix>.best.model`` is overwritten
    whenever dev BLEU improves.

    Args:
        _: Unused; kept so the function matches the ``main(argv)`` entry-point
            signature expected by its caller.
    """
    print('Configurations:')
    print(FLAGS)

    log_dir = FLAGS.model_dir
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    path_prefix = log_dir + "/NP2P.{}".format(FLAGS.suffix)
    # Checkpoint/vocab prefix of the model to start from, e.g.
    # "<logs>/NP2P.phrase_ce_train". When empty/None we train from scratch.
    init_model_prefix = FLAGS.init_model
    log_file_path = path_prefix + ".log"
    print('Log file path: {}'.format(log_file_path))

    # ``with`` closes the log file even when anything below raises;
    # previously it was only closed on the success path.
    with open(log_file_path, 'wt') as log_file:
        log_file.write("{}\n".format(FLAGS))
        log_file.flush()

        # save configuration
        namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json")

        print('Loading train set.')
        if FLAGS.infile_format == 'fof':
            trainset, train_ans_len = (
                NP2P_data_stream.read_generation_datasets_from_fof(
                    FLAGS.train_path, isLower=FLAGS.isLower))
            # Never allow answers longer than the longest training answer.
            if FLAGS.max_answer_len > train_ans_len:
                FLAGS.max_answer_len = train_ans_len
        else:
            trainset, train_ans_len = NP2P_data_stream.read_all_GQA_questions(
                FLAGS.train_path, isLower=FLAGS.isLower)
        print('Number of training samples: {}'.format(len(trainset)))

        print('Loading test set.')
        if FLAGS.infile_format == 'fof':
            testset, test_ans_len = (
                NP2P_data_stream.read_generation_datasets_from_fof(
                    FLAGS.test_path, isLower=FLAGS.isLower))
        else:
            testset, test_ans_len = NP2P_data_stream.read_all_GQA_questions(
                FLAGS.test_path, isLower=FLAGS.isLower)
        print('Number of test samples: {}'.format(len(testset)))

        max_actual_len = max(train_ans_len, test_ans_len)
        print('Max answer length: {}, truncated to {}'.format(
            max_actual_len, FLAGS.max_answer_len))

        word_vocab = None
        POS_vocab = None
        NER_vocab = None
        char_vocab = None
        best_path = path_prefix + ".best.model"
        # Guard against an unset FLAGS.init_model: the original concatenated
        # it with a string unconditionally and raised TypeError on None.
        has_pretrained_model = bool(init_model_prefix) and os.path.exists(
            init_model_prefix + ".best.model.index")
        if has_pretrained_model:
            print('!!Existing pretrained model. Loading vocabs.')
            # NOTE(review): these vocabs are read from init_model_prefix but
            # never written under path_prefix, so loading best_path later with
            # model_prefix=path_prefix may not find its char/POS/NER vocab
            # files -- confirm whether they should be copied.
            if FLAGS.with_word:
                word_vocab = Vocab(FLAGS.word_vec_path, fileformat='txt2')
                print('word_vocab: {}'.format(word_vocab.word_vecs.shape))
            if FLAGS.with_char:
                char_vocab = Vocab(init_model_prefix + ".char_vocab",
                                   fileformat='txt2')
                print('char_vocab: {}'.format(char_vocab.word_vecs.shape))
            if FLAGS.with_POS:
                POS_vocab = Vocab(init_model_prefix + ".POS_vocab",
                                  fileformat='txt2')
                print('POS_vocab: {}'.format(POS_vocab.word_vecs.shape))
            if FLAGS.with_NER:
                NER_vocab = Vocab(init_model_prefix + ".NER_vocab",
                                  fileformat='txt2')
                print('NER_vocab: {}'.format(NER_vocab.word_vecs.shape))
        else:
            print('Collecting vocabs.')
            (allWords, allChars, allPOSs,
             allNERs) = NP2P_data_stream.collect_vocabs(trainset)
            print('Number of words: {}'.format(len(allWords)))
            print('Number of allChars: {}'.format(len(allChars)))
            print('Number of allPOSs: {}'.format(len(allPOSs)))
            print('Number of allNERs: {}'.format(len(allNERs)))

            if FLAGS.with_word:
                word_vocab = Vocab(FLAGS.word_vec_path, fileformat='txt2')
            if FLAGS.with_char:
                char_vocab = Vocab(voc=allChars,
                                   dim=FLAGS.char_dim,
                                   fileformat='build')
                char_vocab.dump_to_txt2(path_prefix + ".char_vocab")
            if FLAGS.with_POS:
                POS_vocab = Vocab(voc=allPOSs,
                                  dim=FLAGS.POS_dim,
                                  fileformat='build')
                POS_vocab.dump_to_txt2(path_prefix + ".POS_vocab")
            if FLAGS.with_NER:
                NER_vocab = Vocab(voc=allNERs,
                                  dim=FLAGS.NER_dim,
                                  fileformat='build')
                NER_vocab.dump_to_txt2(path_prefix + ".NER_vocab")

        # word_vocab is only created when FLAGS.with_word is set; the original
        # crashed here with AttributeError otherwise.
        if word_vocab is not None:
            print('word vocab size {}'.format(word_vocab.vocab_size))
        sys.stdout.flush()

        print('Build DataStream ... ')
        trainDataStream = NP2P_data_stream.QADataStream(trainset,
                                                        word_vocab,
                                                        char_vocab,
                                                        POS_vocab,
                                                        NER_vocab,
                                                        options=FLAGS,
                                                        isShuffle=True,
                                                        isLoop=True,
                                                        isSort=True)

        devDataStream = NP2P_data_stream.QADataStream(testset,
                                                      word_vocab,
                                                      char_vocab,
                                                      POS_vocab,
                                                      NER_vocab,
                                                      options=FLAGS,
                                                      isShuffle=False,
                                                      isLoop=False,
                                                      isSort=True)
        print('Number of instances in trainDataStream: {}'.format(
            trainDataStream.get_num_instance()))
        print('Number of instances in devDataStream: {}'.format(
            devDataStream.get_num_instance()))
        print('Number of batches in trainDataStream: {}'.format(
            trainDataStream.get_num_batch()))
        print('Number of batches in devDataStream: {}'.format(
            devDataStream.get_num_batch()))
        sys.stdout.flush()

        init_scale = 0.01
        with tf.Graph().as_default():
            initializer = tf.random_uniform_initializer(-init_scale,
                                                        init_scale)
            with tf.name_scope("Train"):
                with tf.variable_scope("Model",
                                       reuse=None,
                                       initializer=initializer):
                    train_graph = ModelGraph(word_vocab=word_vocab,
                                             char_vocab=char_vocab,
                                             POS_vocab=POS_vocab,
                                             NER_vocab=NER_vocab,
                                             options=FLAGS,
                                             mode="rl_train_for_phrase")

            # The decode graph reuses the "Model" variables (reuse=True), so
            # sampling and evaluation always use the current training weights.
            with tf.name_scope("Valid"):
                with tf.variable_scope("Model",
                                       reuse=True,
                                       initializer=initializer):
                    valid_graph = ModelGraph(word_vocab=word_vocab,
                                             char_vocab=char_vocab,
                                             POS_vocab=POS_vocab,
                                             NER_vocab=NER_vocab,
                                             options=FLAGS,
                                             mode="decode")

            init_op = tf.global_variables_initializer()

            # Checkpoint every "Model..." variable except the word embeddings
            # (presumably initialised from word_vocab on every run -- TODO
            # confirm in ModelGraph). Keys drop the ":0" output suffix.
            vars_ = {}
            for var in tf.global_variables():
                if "word_embedding" in var.name:
                    continue
                if not var.name.startswith("Model"):
                    continue
                vars_[var.name.split(":")[0]] = var
            saver = tf.train.Saver(vars_)

            # Context manager releases the session's resources on exit/error.
            with tf.Session() as sess:
                sess.run(init_op)
                if has_pretrained_model:
                    init_model_path = init_model_prefix + ".best.model"
                    print("Restoring model from " + init_model_path)
                    saver.restore(sess, init_model_path)
                    print("DONE!")
                sys.stdout.flush()

                # for first-time rl training, we get the current BLEU score
                print("First-time rl training, get the current BLEU score on dev")
                sys.stdout.flush()
                best_bleu = evaluate(sess,
                                     valid_graph,
                                     devDataStream,
                                     word_vocab,
                                     options=FLAGS)
                print('First-time bleu = %.4f' % best_bleu)
                log_file.write('First-time bleu = %.4f\n' % best_bleu)
                log_file.flush()

                print('Start the training loop.')
                sys.stdout.flush()
                train_size = trainDataStream.get_num_batch()
                max_steps = train_size * FLAGS.max_epochs
                total_loss = 0.0
                start_time = time.time()
                for step in xrange(max_steps):
                    cur_batch = trainDataStream.nextBatch()
                    if FLAGS.with_baseline:
                        # greedy search; its BLEU is the reward baseline
                        (greedy_sentences, _, _,
                         _) = NP2P_beam_decoder.search(sess,
                                                       valid_graph,
                                                       word_vocab,
                                                       cur_batch,
                                                       FLAGS,
                                                       decode_mode="greedy")

                    if FLAGS.with_target_lattice:
                        (sampled_sentences, sampled_prediction_lengths,
                         sampled_generator_input_idx,
                         sampled_generator_output_idx
                         ) = cur_batch.sample_a_partition()
                    else:
                        # multinomial sampling
                        (sampled_sentences, sampled_prediction_lengths,
                         sampled_generator_input_idx,
                         sampled_generator_output_idx
                         ) = NP2P_beam_decoder.search(
                             sess,
                             valid_graph,
                             word_vocab,
                             cur_batch,
                             FLAGS,
                             decode_mode="multinomial")

                    # calculate rewards (one scalar per batch instance)
                    rewards = []
                    for i in xrange(cur_batch.batch_size):
                        cur_toks = cur_batch.instances[i][1].tokText.split()
                        # NOTE(review): the reward is a constant 1.0 -- the
                        # sentence-BLEU reward of sampled_sentences[i] is
                        # disabled, so sampled_sentences is unused. Confirm
                        # this is intended (it is most natural with
                        # FLAGS.with_target_lattice).
                        r = 1.0
                        b = 0.0
                        if FLAGS.with_baseline:
                            b = sentence_bleu([cur_toks],
                                              greedy_sentences[i].split(),
                                              smoothing_function=cc.method3)
                        rewards.append(r - b)
                    rewards = np.array(rewards, dtype=np.float32)

                    # update parameters: feed the encoder inputs plus the
                    # sampled decoder inputs/outputs and their rewards.
                    feed_dict = train_graph.run_encoder(sess,
                                                        cur_batch,
                                                        FLAGS,
                                                        only_feed_dict=True)
                    feed_dict[train_graph.reward] = rewards
                    feed_dict[train_graph.gen_input_words] = (
                        sampled_generator_input_idx)
                    feed_dict[train_graph.in_answer_words] = (
                        sampled_generator_output_idx)
                    feed_dict[train_graph.answer_lengths] = (
                        sampled_prediction_lengths)
                    _, loss_value = sess.run(
                        [train_graph.train_op, train_graph.loss], feed_dict)
                    total_loss += loss_value

                    if step % 100 == 0:
                        print('{} '.format(step), end="")
                        sys.stdout.flush()

                    # Save a checkpoint and evaluate at the end of every epoch
                    # and after the final step.
                    if (step + 1) % trainDataStream.get_num_batch() == 0 or (
                            step + 1) == max_steps:
                        print()
                        duration = time.time() - start_time
                        print('Step %d: loss = %.2f (%.3f sec)' %
                              (step, total_loss, duration))
                        log_file.write('Step %d: loss = %.2f (%.3f sec)\n' %
                                       (step, total_loss, duration))
                        log_file.flush()
                        sys.stdout.flush()
                        total_loss = 0.0

                        # Evaluate against the validation set.
                        start_time = time.time()
                        print('Validation Data Eval:')
                        dev_bleu = evaluate(sess,
                                            valid_graph,
                                            devDataStream,
                                            word_vocab,
                                            options=FLAGS)
                        print('Dev bleu = %.4f' % dev_bleu)
                        log_file.write('Dev bleu = %.4f\n' % dev_bleu)
                        log_file.flush()
                        if best_bleu < dev_bleu:
                            print('Saving weights, BLEU {} (prev_best) < {} '
                                  '(cur)'.format(best_bleu, dev_bleu))
                            best_bleu = dev_bleu
                            saver.save(sess, best_path)
                        duration = time.time() - start_time
                        print('Duration %.3f sec' % (duration))
                        sys.stdout.flush()
                        log_file.write('Duration %.3f sec\n' % (duration))
                        log_file.flush()
Ejemplo n.º 3
0
    print('Build DataStream ... ')
    batch_size = -1
    if mode not in (
            'pointwise',
            'multinomial',
            'greedy',
            'greedy_evaluate',
    ):
        batch_size = 1
    devDataStream = NP2P_data_stream.QADataStream(testset,
                                                  enc_word_vocab,
                                                  dec_word_vocab,
                                                  char_vocab,
                                                  POS_vocab,
                                                  NER_vocab,
                                                  options=FLAGS,
                                                  isShuffle=False,
                                                  isLoop=False,
                                                  isSort=True,
                                                  batch_size=batch_size)
    print('Number of instances in testDataStream: {}'.format(
        devDataStream.get_num_instance()))
    print('Number of batches in testDataStream: {}'.format(
        devDataStream.get_num_batch()))

    best_path = model_prefix + ".best.model"
    with tf.Graph().as_default():
        initializer = tf.random_uniform_initializer(-0.01, 0.01)
        with tf.name_scope("Valid"):
            with tf.variable_scope("Model",
Ejemplo n.º 4
0
def question_gen_run(argv):
    """Generate questions with a trained NP2P model and write them to disk.

    Args:
        argv: sequence whose first four items are (extra items are ignored):
            ``model_prefix`` -- prefix of the saved model. The config
                ``<prefix>.config.json``, the ``.char_vocab`` / ``.POS_vocab``
                / ``.NER_vocab`` files and the ``<prefix>.best.model``
                checkpoint are read from it.
            ``in_path`` -- input file, parsed according to
                ``FLAGS.infile_format`` (``fof``, ``plain``, otherwise GQA
                questions).
            ``out_path`` -- output file. For the ``*_evaluate`` modes two files,
                ``<out_path>.ref`` and ``<out_path>.pred``, are written instead.
            ``mode`` -- decoding mode: ``pointwise``, ``greedy``,
                ``multinomial``, ``greedy_evaluate``, ``beam_evaluate``; any
                other value runs beam search and writes every hypothesis.

    Raises:
        ValueError: if ``mode`` ends in ``evaluate`` but is neither
            ``greedy_evaluate`` nor ``beam_evaluate`` (no output file would be
            opened for the branch such a mode reaches).
    """
    print(argv)
    model_prefix = argv[0]
    in_path = argv[1]
    out_path = argv[2]
    mode = argv[3]

    # An unknown "*evaluate" mode would open only the .ref/.pred files and then
    # fall through to the beam branch, which writes to the plain output file.
    # Reject it before the (expensive) model load.
    if mode.endswith('evaluate') and mode not in ('greedy_evaluate',
                                                  'beam_evaluate'):
        raise ValueError('Unsupported evaluation mode: {}'.format(mode))

    # Diagnostic only: the variable is often unset (e.g. CPU-only runs), so a
    # missing key must not abort decoding.
    print("CUDA_VISIBLE_DEVICES " +
          os.environ.get('CUDA_VISIBLE_DEVICES', '<unset>'))

    # load the configuration file
    print('Loading configurations from ' + model_prefix + ".config.json")
    FLAGS = namespace_utils.load_namespace(model_prefix + ".config.json")
    FLAGS = NP2P_trainer.enrich_options(FLAGS)

    # load vocabs; the word vocab comes from the pretrained vectors, the
    # others from files saved next to the model
    print('Loading vocabs.')
    word_vocab = char_vocab = POS_vocab = NER_vocab = None
    if FLAGS.with_word:
        word_vocab = Vocab(FLAGS.word_vec_path, fileformat='txt2')
        print('word_vocab: {}'.format(word_vocab.word_vecs.shape))
    if FLAGS.with_char:
        char_vocab = Vocab(model_prefix + ".char_vocab", fileformat='txt2')
        print('char_vocab: {}'.format(char_vocab.word_vecs.shape))
    if FLAGS.with_POS:
        POS_vocab = Vocab(model_prefix + ".POS_vocab", fileformat='txt2')
        print('POS_vocab: {}'.format(POS_vocab.word_vecs.shape))
    if FLAGS.with_NER:
        NER_vocab = Vocab(model_prefix + ".NER_vocab", fileformat='txt2')
        print('NER_vocab: {}'.format(NER_vocab.word_vecs.shape))

    print('Loading test set.')
    if FLAGS.infile_format == 'fof':
        testset, _ = NP2P_data_stream.read_generation_datasets_from_fof(
            in_path, isLower=FLAGS.isLower)
    elif FLAGS.infile_format == 'plain':
        testset, _ = NP2P_data_stream.read_all_GenerationDatasets(
            in_path, isLower=FLAGS.isLower)
    else:
        testset, _ = NP2P_data_stream.read_all_GQA_questions(
            in_path, isLower=FLAGS.isLower, switch=FLAGS.switch_qa)
    print('Number of samples: {}'.format(len(testset)))

    print('Build DataStream ... ')
    # Beam modes decode one instance per batch: the beam branches below only
    # read cur_batch.instances[0].
    batch_size = 1 if 'beam' in mode else -1
    devDataStream = NP2P_data_stream.QADataStream(testset,
                                                  word_vocab,
                                                  char_vocab,
                                                  POS_vocab,
                                                  NER_vocab,
                                                  options=FLAGS,
                                                  isShuffle=False,
                                                  isLoop=False,
                                                  isSort=True,
                                                  batch_size=batch_size)
    print('Number of instances in testDataStream: {}'.format(
        devDataStream.get_num_instance()))
    print('Number of batches in testDataStream: {}'.format(
        devDataStream.get_num_batch()))

    best_path = model_prefix + ".best.model"
    with tf.Graph().as_default():
        var_initializer = tf.random_uniform_initializer(-0.01, 0.01)
        with tf.name_scope("Valid"):
            with tf.variable_scope("Model",
                                   reuse=False,
                                   initializer=var_initializer):
                valid_graph = ModelGraph(word_vocab=word_vocab,
                                         char_vocab=char_vocab,
                                         POS_vocab=POS_vocab,
                                         NER_vocab=NER_vocab,
                                         options=FLAGS,
                                         mode="decode")

        # Restore every "Model" variable except the word embedding, which is
        # not taken from the checkpoint (presumably it is built from
        # word_vocab / FLAGS.word_vec_path -- confirm in ModelGraph).
        vars_ = {}
        for var in tf.global_variables():
            if "word_embedding" in var.name:
                continue
            if not var.name.startswith("Model"):
                continue
            vars_[var.name.split(":")[0]] = var
        saver = tf.train.Saver(vars_)

        init_op = tf.global_variables_initializer()
        sess = tf.Session()
        outfile = ref_outfile = pred_outfile = None
        try:
            # Initialize everything first so variables excluded from the
            # saver still get a value, then overwrite the rest from disk.
            sess.run(init_op)
            saver.restore(sess, best_path)  # restore the model

            total = 0
            correct = 0
            if mode.endswith('evaluate'):
                ref_outfile = open(out_path + ".ref", 'wt')
                pred_outfile = open(out_path + ".pred", 'wt')
            else:
                outfile = open(out_path, 'wt')
            total_num = devDataStream.get_num_batch()
            devDataStream.reset()
            for i in range(total_num):
                cur_batch = devDataStream.get_batch(i)
                if mode == 'pointwise':
                    (sentences, prediction_lengths, generator_input_idx,
                     generator_output_idx) = search(sess,
                                                    valid_graph,
                                                    word_vocab,
                                                    cur_batch,
                                                    FLAGS,
                                                    decode_mode=mode)
                    for j in xrange(cur_batch.batch_size):
                        # token-level accuracy over the reference answer length
                        cur_total = cur_batch.answer_lengths[j]
                        cur_correct = 0
                        for k in xrange(cur_total):
                            if generator_output_idx[
                                    j, k] == cur_batch.in_answer_words[j, k]:
                                cur_correct += 1.0
                        total += cur_total
                        correct += cur_correct
                        outfile.write(
                            cur_batch.instances[j][1].tokText.encode('utf-8') +
                            "\n")
                        outfile.write(sentences[j].encode('utf-8') + "\n")
                        outfile.write("========\n")
                    outfile.flush()
                    # guard: batches with only empty answers leave total == 0
                    accuracy = correct / float(total) * 100 if total else 0.0
                    print('Current dev accuracy is %d/%d=%.2f' %
                          (correct, total, accuracy))
                elif mode in ['greedy', 'multinomial']:
                    print('Batch {}'.format(i))
                    (sentences, prediction_lengths, generator_input_idx,
                     generator_output_idx) = search(sess,
                                                    valid_graph,
                                                    word_vocab,
                                                    cur_batch,
                                                    FLAGS,
                                                    decode_mode=mode)
                    for j in xrange(cur_batch.batch_size):
                        outfile.write(
                            cur_batch.instances[j][1].ID_num.encode('utf-8') +
                            "\n")
                        outfile.write(
                            cur_batch.instances[j][1].tokText.encode('utf-8') +
                            "\n")
                        outfile.write(sentences[j].encode('utf-8') + "\n")
                        outfile.write("========\n")
                    outfile.flush()
                elif mode == 'greedy_evaluate':
                    print('Batch {}'.format(i))
                    (sentences, prediction_lengths, generator_input_idx,
                     generator_output_idx) = search(sess,
                                                    valid_graph,
                                                    word_vocab,
                                                    cur_batch,
                                                    FLAGS,
                                                    decode_mode="greedy")
                    for j in xrange(cur_batch.batch_size):
                        ref_outfile.write(
                            cur_batch.instances[j][1].tokText.encode('utf-8') +
                            "\n")
                        pred_outfile.write(sentences[j].encode('utf-8') + "\n")
                    ref_outfile.flush()
                    pred_outfile.flush()
                elif mode == 'beam_evaluate':
                    print('Instance {}'.format(i))
                    ref_outfile.write(
                        cur_batch.instances[0][1].tokText.encode('utf-8') +
                        "\n")
                    ref_outfile.flush()
                    hyps = run_beam_search(sess, valid_graph, word_vocab,
                                           cur_batch, FLAGS)
                    cur_passage = cur_batch.instances[0][0]
                    cur_id2phrase = None
                    if FLAGS.with_phrase_projection:
                        (_, cur_id2phrase) = cur_batch.phrase_vocabs[0]
                    # only the top hypothesis is kept for evaluation
                    cur_sent = hyps[0].idx_seq_to_string(cur_passage,
                                                         cur_id2phrase,
                                                         word_vocab, FLAGS)
                    pred_outfile.write(cur_sent.encode('utf-8') + "\n")
                    pred_outfile.flush()
                else:  # beam search: write input, reference, all hypotheses
                    print('Instance {}'.format(i))
                    hyps = run_beam_search(sess, valid_graph, word_vocab,
                                           cur_batch, FLAGS)
                    cur_passage = cur_batch.instances[0][0]
                    outfile.write(
                        "Input: " + cur_passage.tokText.encode('utf-8') + "\n")
                    outfile.write(
                        "Truth: " +
                        cur_batch.instances[0][1].tokText.encode('utf-8') +
                        "\n")
                    # the phrase table depends only on the batch, not the hyp
                    cur_id2phrase = None
                    if FLAGS.with_phrase_projection:
                        (_, cur_id2phrase) = cur_batch.phrase_vocabs[0]
                    for j, hyp in enumerate(hyps):
                        cur_sent = hyp.idx_seq_to_string(cur_passage,
                                                         cur_id2phrase,
                                                         word_vocab, FLAGS)
                        outfile.write("Hyp-{}: ".format(j) +
                                      cur_sent.encode('utf-8') +
                                      " {}".format(hyp.avg_log_prob()) + "\n")
                    outfile.flush()
        finally:
            # close whatever was opened, even if decoding failed midway
            for handle in (outfile, ref_outfile, pred_outfile):
                if handle is not None:
                    handle.close()
            sess.close()