def main(_):
    """Train an NP2P model with separate encoder and decoder word vocabularies.

    All configuration is read from the module-level ``FLAGS``. The function:

    1. creates ``FLAGS.model_dir`` and writes ``NP2P.<suffix>.log`` and
       ``NP2P.<suffix>.config.json`` into it;
    2. loads the training set (``FLAGS.train_path``), the dev set
       (``FLAGS.test_path``) and, when ``FLAGS.finetune_path`` is non-empty,
       a fine-tuning set;
    3. reloads vocabularies when ``NP2P.<suffix>.best.model`` already exists,
       otherwise builds them (the character vocab is dumped to disk);
    4. builds a training graph plus a variable-sharing validation graph,
       restores the best checkpoint if present, and trains for
       ``FLAGS.max_epochs`` epochs, periodically calling ``fine_tune`` (when a
       fine-tune set exists) or ``validate_and_save``.

    Args:
        _: Unused (the argv list handed over by the application launcher).

    Raises:
        ValueError: if ``FLAGS.mode`` is not ``'ce_train'`` or ``'rl_train'``.
    """
    print('Configurations:')
    print(FLAGS)

    log_dir = FLAGS.model_dir
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    path_prefix = log_dir + "/NP2P.{}".format(FLAGS.suffix)
    config_path = path_prefix + ".config.json"
    log_file_path = path_prefix + ".log"
    print('Log file path: {}'.format(log_file_path))
    # The with-block closes the log even if loading or training raises.
    with open(log_file_path, 'wt') as log_file:
        log_file.write("{}\n".format(FLAGS))
        log_file.flush()

        # save configuration
        namespace_utils.save_namespace(FLAGS, config_path)

        print('Loading training set.')
        trainset, train_ans_len = NP2P_data_stream.read_all_GenerationDatasets(
            FLAGS.train_path, isLower=FLAGS.isLower)
        print('Number of training samples: {}'.format(len(trainset)))

        print('Loading dev set.')
        devset, dev_ans_len = NP2P_data_stream.read_all_GenerationDatasets(
            FLAGS.test_path, isLower=FLAGS.isLower)
        print('Number of dev samples: {}'.format(len(devset)))

        if FLAGS.finetune_path != "":
            print('Loading finetune set.')
            # Read from the same flag that enables this branch.
            ftset, ft_ans_len = NP2P_data_stream.read_all_GenerationDatasets(
                FLAGS.finetune_path, isLower=FLAGS.isLower)
            print('Number of finetune samples: {}'.format(len(ftset)))
        else:
            ftset, ft_ans_len = (None, 0)

        max_actual_len = max(train_ans_len, ft_ans_len, dev_ans_len)
        print('Max answer length: {}, truncated to {}'.format(
            max_actual_len, FLAGS.max_answer_len))

        enc_word_vocab = None
        dec_word_vocab = None
        char_vocab = None
        best_path = path_prefix + ".best.model"
        # An existing "best" checkpoint means we resume: reuse the char vocab
        # saved next to it instead of rebuilding it from the training data.
        has_pretrained_model = os.path.exists(best_path + ".index")
        if has_pretrained_model:
            print('!!Existing pretrained model. Loading vocabs.')
            if FLAGS.with_word:
                enc_word_vocab = Vocab(FLAGS.enc_word_vec_path,
                                       fileformat='txt2')
                dec_word_vocab = Vocab(FLAGS.dec_word_vec_path,
                                       fileformat='txt2')
                print('Encoder word vocab: {}'.format(
                    enc_word_vocab.word_vecs.shape))
                print('Decoder word vocab: {}'.format(
                    dec_word_vocab.word_vecs.shape))
            if FLAGS.with_char:
                char_vocab = Vocab(path_prefix + ".char_vocab",
                                   fileformat='txt2')
                print('char_vocab: {}'.format(char_vocab.word_vecs.shape))
        else:
            print('Collecting vocabs.')
            (allWords, allChars) = NP2P_data_stream.collect_vocabs(trainset)
            print('Number of words: {}'.format(len(allWords)))
            print('Number of allChars: {}'.format(len(allChars)))

            if FLAGS.with_word:
                enc_word_vocab = Vocab(FLAGS.enc_word_vec_path,
                                       fileformat='txt2')
                dec_word_vocab = Vocab(FLAGS.dec_word_vec_path,
                                       fileformat='txt2')
            if FLAGS.with_char:
                char_vocab = Vocab(voc=allChars,
                                   dim=FLAGS.char_dim,
                                   fileformat='build')
                char_vocab.dump_to_txt2(path_prefix + ".char_vocab")

        # Word vocabs are only built when FLAGS.with_word is set.
        if enc_word_vocab is not None:
            print('Encoder word vocab size {}'.format(
                enc_word_vocab.vocab_size))
        if dec_word_vocab is not None:
            print('Decoder word vocab size {}'.format(
                dec_word_vocab.vocab_size))
        sys.stdout.flush()

        def make_stream(dataset, is_train):
            """Wrap *dataset* in a DataStream that uses this run's vocabs.

            Training-style streams enable ``isShuffle`` and ``isLoop``;
            evaluation streams disable both. All streams are sorted.
            """
            return NP2P_data_stream.DataStream(dataset,
                                               enc_word_vocab,
                                               dec_word_vocab,
                                               char_vocab,
                                               options=FLAGS,
                                               isShuffle=is_train,
                                               isLoop=is_train,
                                               isSort=True)

        print('Build DataStream ... ')
        trainDataStream = make_stream(trainset, is_train=True)
        devDataStream = make_stream(devset, is_train=False)
        print('Number of instances in trainDataStream: {}'.format(
            trainDataStream.get_num_instance()))
        print('Number of instances in devDataStream: {}'.format(
            devDataStream.get_num_instance()))
        print('Number of batches in trainDataStream: {}'.format(
            trainDataStream.get_num_batch()))
        print('Number of batches in devDataStream: {}'.format(
            devDataStream.get_num_batch()))
        ftDataStream = None
        if ftset is not None:
            ftDataStream = make_stream(ftset, is_train=True)
            print('Number of instances in ftDataStream: {}'.format(
                ftDataStream.get_num_instance()))
            print('Number of batches in ftDataStream: {}'.format(
                ftDataStream.get_num_batch()))

        sys.stdout.flush()

        init_scale = 0.01
        # Initialize the best bleu and accu scores for this training session
        # from FLAGS when present (this function writes them back to FLAGS and
        # the saved config after scoring a restored model).
        best_accu = FLAGS.best_accu if 'best_accu' in FLAGS.__dict__ else 0.0
        best_bleu = FLAGS.best_bleu if 'best_bleu' in FLAGS.__dict__ else 0.0
        if best_accu > 0.0:
            print('With initial dev accuracy {}'.format(best_accu))
        if best_bleu > 0.0:
            print('With initial dev BLEU score {}'.format(best_bleu))

        # Explicit check (not assert) so it survives `python -O`; otherwise an
        # unknown mode would leave `loss_value` unbound in the training loop.
        if FLAGS.mode not in ('ce_train', 'rl_train'):
            raise ValueError(
                'Unsupported training mode: {}'.format(FLAGS.mode))
        valid_mode = 'evaluate' if FLAGS.mode == 'ce_train' else 'evaluate_bleu'

        with tf.Graph().as_default():
            weight_init = tf.random_uniform_initializer(-init_scale, init_scale)
            with tf.name_scope("Train"):
                with tf.variable_scope("Model",
                                       reuse=None,
                                       initializer=weight_init):
                    train_graph = ModelGraph(enc_word_vocab=enc_word_vocab,
                                             dec_word_vocab=dec_word_vocab,
                                             char_vocab=char_vocab,
                                             POS_vocab=None,
                                             NER_vocab=None,
                                             options=FLAGS,
                                             mode=FLAGS.mode)

            # reuse=True: the validation graph shares the "Model" variables
            # created by the training graph.
            with tf.name_scope("Valid"):
                with tf.variable_scope("Model",
                                       reuse=True,
                                       initializer=weight_init):
                    valid_graph = ModelGraph(enc_word_vocab=enc_word_vocab,
                                             dec_word_vocab=dec_word_vocab,
                                             char_vocab=char_vocab,
                                             POS_vocab=None,
                                             NER_vocab=None,
                                             options=FLAGS,
                                             mode=valid_mode)

            init_op = tf.global_variables_initializer()

            # Checkpoint only variables under "Model"; word embeddings are left
            # out when FLAGS.fix_word_vec is set.
            vars_ = {}
            for var in tf.global_variables():
                if FLAGS.fix_word_vec and "word_embedding" in var.name:
                    continue
                if not var.name.startswith("Model"):
                    continue
                print(var)
                vars_[var.name.split(":")[0]] = var
            saver = tf.train.Saver(vars_)

            with tf.Session() as sess:
                sess.run(init_op)
                if has_pretrained_model:
                    print("Restoring model from " + best_path)
                    saver.restore(sess, best_path)
                    print("DONE!")

                    # A restored model with no recorded score is evaluated once
                    # so later checkpoints are compared against a real baseline.
                    if FLAGS.mode == 'rl_train' and abs(best_bleu) < 0.00001:
                        print("Getting BLEU score for the model")
                        best_bleu = evaluate(sess,
                                             valid_graph,
                                             devDataStream,
                                             options=FLAGS)['dev_bleu']
                        FLAGS.best_bleu = best_bleu
                        namespace_utils.save_namespace(FLAGS, config_path)
                        print('BLEU = %.4f' % best_bleu)
                        log_file.write('BLEU = %.4f\n' % best_bleu)
                    if FLAGS.mode == 'ce_train' and abs(best_accu) < 0.00001:
                        print("Getting ACCU score for the model")
                        best_accu = evaluate(sess,
                                             valid_graph,
                                             devDataStream,
                                             options=FLAGS)['dev_accu']
                        FLAGS.best_accu = best_accu
                        namespace_utils.save_namespace(FLAGS, config_path)
                        print('ACCU = %.4f' % best_accu)
                        log_file.write('ACCU = %.4f\n' % best_accu)

                print('Start the training loop.')
                train_size = trainDataStream.get_num_batch()
                max_steps = train_size * FLAGS.max_epochs
                total_loss = 0.0
                start_time = time.time()
                for step in xrange(max_steps):
                    cur_batch = trainDataStream.nextBatch()
                    if FLAGS.mode == 'rl_train':
                        loss_value = train_graph.run_rl_training_2(
                            sess, cur_batch, FLAGS)
                    else:  # 'ce_train' (validated before graph construction)
                        loss_value = train_graph.run_ce_training(
                            sess, cur_batch, FLAGS)
                    total_loss += loss_value

                    if step % 100 == 0:
                        print('{} '.format(step), end="")
                        sys.stdout.flush()

                    # Save a checkpoint and evaluate at the end of every epoch,
                    # at the final step, and additionally every 2000 steps
                    # when an epoch is very long (> 10000 batches).
                    is_epoch_end = (step + 1) % train_size == 0
                    is_last_step = (step + 1) == max_steps
                    is_long_epoch_tick = (train_size > 10000 and
                                          (step + 1) % 2000 == 0)
                    if not (is_epoch_end or is_last_step or
                            is_long_epoch_tick):
                        continue

                    print()
                    duration = time.time() - start_time
                    print('Step %d: loss = %.2f (%.3f sec)' %
                          (step, total_loss, duration))
                    log_file.write('Step %d: loss = %.2f (%.3f sec)\n' %
                                   (step, total_loss, duration))
                    log_file.flush()
                    sys.stdout.flush()
                    total_loss = 0.0

                    if ftset is not None:
                        best_accu, best_bleu = fine_tune(
                            sess, saver, FLAGS, log_file, ftDataStream,
                            devDataStream, train_graph, valid_graph,
                            path_prefix, best_accu, best_bleu)
                    else:
                        best_accu, best_bleu = validate_and_save(
                            sess, saver, FLAGS, log_file, devDataStream,
                            valid_graph, path_prefix, best_accu, best_bleu)
                    # Restart the timer after validation so the next "Step"
                    # line reports training time only.
                    start_time = time.time()
# Example #2
def main(_):
    """Train an NP2P model with one shared word vocab plus char/POS/NER vocabs.

    All configuration is read from the module-level ``FLAGS``. The function:

    1. creates ``FLAGS.model_dir`` and writes ``NP2P.<suffix>.log`` and
       ``NP2P.<suffix>.config.json`` into it;
    2. loads the train (``FLAGS.train_path``) and test (``FLAGS.test_path``)
       sets with the reader selected by ``FLAGS.infile_format``;
    3. reloads vocabularies when ``NP2P.<suffix>.best.model`` already exists,
       otherwise builds them (char/POS/NER vocabs are dumped to disk);
    4. builds a training graph plus a variable-sharing validation graph,
       restores the best checkpoint if present, and trains for
       ``FLAGS.max_epochs`` epochs. At each epoch end and at the final step it
       evaluates on the test set and saves the weights whenever dev accuracy
       (``ce_train``) or dev BLEU (``rl_train``) improves.

    Args:
        _: Unused (the argv list handed over by the application launcher).

    Raises:
        ValueError: if ``FLAGS.mode`` is not ``'ce_train'`` or ``'rl_train'``.
    """
    print('Configurations:')
    print(FLAGS)

    log_dir = FLAGS.model_dir
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    path_prefix = log_dir + "/NP2P.{}".format(FLAGS.suffix)
    config_path = path_prefix + ".config.json"
    log_file_path = path_prefix + ".log"
    print('Log file path: {}'.format(log_file_path))
    # The with-block closes the log even if loading or training raises.
    with open(log_file_path, 'wt') as log_file:
        log_file.write("{}\n".format(FLAGS))
        log_file.flush()

        # save configuration
        namespace_utils.save_namespace(FLAGS, config_path)

        def read_dataset(path):
            """Load *path* with the reader chosen by FLAGS.infile_format.

            'fof' and 'plain' use the generation-dataset readers; any other
            value falls back to the GQA question reader. Returns the
            (dataset, answer-length) pair produced by that reader.
            """
            if FLAGS.infile_format == 'fof':
                return NP2P_data_stream.read_generation_datasets_from_fof(
                    path, isLower=FLAGS.isLower)
            if FLAGS.infile_format == 'plain':
                return NP2P_data_stream.read_all_GenerationDatasets(
                    path, isLower=FLAGS.isLower)
            return NP2P_data_stream.read_all_GQA_questions(
                path, isLower=FLAGS.isLower, switch=FLAGS.switch_qa)

        print('Loading train set.')
        trainset, train_ans_len = read_dataset(FLAGS.train_path)
        print('Number of training samples: {}'.format(len(trainset)))

        print('Loading test set.')
        testset, test_ans_len = read_dataset(FLAGS.test_path)
        print('Number of test samples: {}'.format(len(testset)))

        max_actual_len = max(train_ans_len, test_ans_len)
        print('Max answer length: {}, truncated to {}'.format(
            max_actual_len, FLAGS.max_answer_len))

        word_vocab = None
        POS_vocab = None
        NER_vocab = None
        char_vocab = None
        best_path = path_prefix + ".best.model"
        # An existing "best" checkpoint means we resume: reuse the vocabs
        # saved next to it instead of rebuilding them from the training data.
        has_pretrained_model = os.path.exists(best_path + ".index")
        if has_pretrained_model:
            print('!!Existing pretrained model. Loading vocabs.')
            if FLAGS.with_word:
                word_vocab = Vocab(FLAGS.word_vec_path, fileformat='txt2')
                print('word_vocab: {}'.format(word_vocab.word_vecs.shape))
            if FLAGS.with_char:
                char_vocab = Vocab(path_prefix + ".char_vocab",
                                   fileformat='txt2')
                print('char_vocab: {}'.format(char_vocab.word_vecs.shape))
            if FLAGS.with_POS:
                POS_vocab = Vocab(path_prefix + ".POS_vocab",
                                  fileformat='txt2')
                print('POS_vocab: {}'.format(POS_vocab.word_vecs.shape))
            if FLAGS.with_NER:
                NER_vocab = Vocab(path_prefix + ".NER_vocab",
                                  fileformat='txt2')
                print('NER_vocab: {}'.format(NER_vocab.word_vecs.shape))
        else:
            print('Collecting vocabs.')
            (allWords, allChars, allPOSs,
             allNERs) = NP2P_data_stream.collect_vocabs(trainset)
            print('Number of words: {}'.format(len(allWords)))
            print('Number of allChars: {}'.format(len(allChars)))
            print('Number of allPOSs: {}'.format(len(allPOSs)))
            print('Number of allNERs: {}'.format(len(allNERs)))

            if FLAGS.with_word:
                word_vocab = Vocab(FLAGS.word_vec_path, fileformat='txt2')
            if FLAGS.with_char:
                char_vocab = Vocab(voc=allChars,
                                   dim=FLAGS.char_dim,
                                   fileformat='build')
                char_vocab.dump_to_txt2(path_prefix + ".char_vocab")
            if FLAGS.with_POS:
                POS_vocab = Vocab(voc=allPOSs,
                                  dim=FLAGS.POS_dim,
                                  fileformat='build')
                POS_vocab.dump_to_txt2(path_prefix + ".POS_vocab")
            if FLAGS.with_NER:
                NER_vocab = Vocab(voc=allNERs,
                                  dim=FLAGS.NER_dim,
                                  fileformat='build')
                NER_vocab.dump_to_txt2(path_prefix + ".NER_vocab")

        # The word vocab is only built when FLAGS.with_word is set.
        if word_vocab is not None:
            print('word vocab size {}'.format(word_vocab.vocab_size))
        sys.stdout.flush()

        def make_stream(dataset, is_train):
            """Wrap *dataset* in a QADataStream that uses this run's vocabs.

            Training streams enable ``isShuffle`` and ``isLoop``; evaluation
            streams disable both. All streams are sorted.
            """
            return NP2P_data_stream.QADataStream(dataset,
                                                 word_vocab,
                                                 char_vocab,
                                                 POS_vocab,
                                                 NER_vocab,
                                                 options=FLAGS,
                                                 isShuffle=is_train,
                                                 isLoop=is_train,
                                                 isSort=True)

        print('Build DataStream ... ')
        trainDataStream = make_stream(trainset, is_train=True)
        devDataStream = make_stream(testset, is_train=False)
        print('Number of instances in trainDataStream: {}'.format(
            trainDataStream.get_num_instance()))
        print('Number of instances in devDataStream: {}'.format(
            devDataStream.get_num_instance()))
        print('Number of batches in trainDataStream: {}'.format(
            trainDataStream.get_num_batch()))
        print('Number of batches in devDataStream: {}'.format(
            devDataStream.get_num_batch()))
        sys.stdout.flush()

        init_scale = 0.01
        # Initialize the best bleu and accu scores for this training session
        # from FLAGS when present (this function writes improved scores back
        # to FLAGS and the saved config).
        best_accu = FLAGS.best_accu if 'best_accu' in FLAGS.__dict__ else 0.0
        best_bleu = FLAGS.best_bleu if 'best_bleu' in FLAGS.__dict__ else 0.0
        if best_accu > 0.0:
            print('With initial dev accuracy {}'.format(best_accu))
        if best_bleu > 0.0:
            print('With initial dev BLEU score {}'.format(best_bleu))

        # Explicit check (not assert) so it survives `python -O`; otherwise an
        # unknown mode would leave `loss_value` unbound in the training loop.
        if FLAGS.mode not in ('ce_train', 'rl_train'):
            raise ValueError(
                'Unsupported training mode: {}'.format(FLAGS.mode))
        valid_mode = 'evaluate' if FLAGS.mode == 'ce_train' else 'evaluate_bleu'

        with tf.Graph().as_default():
            weight_init = tf.random_uniform_initializer(-init_scale, init_scale)
            with tf.name_scope("Train"):
                with tf.variable_scope("Model",
                                       reuse=None,
                                       initializer=weight_init):
                    train_graph = ModelGraph(word_vocab=word_vocab,
                                             char_vocab=char_vocab,
                                             POS_vocab=POS_vocab,
                                             NER_vocab=NER_vocab,
                                             options=FLAGS,
                                             mode=FLAGS.mode)

            # reuse=True: the validation graph shares the "Model" variables
            # created by the training graph.
            with tf.name_scope("Valid"):
                with tf.variable_scope("Model",
                                       reuse=True,
                                       initializer=weight_init):
                    valid_graph = ModelGraph(word_vocab=word_vocab,
                                             char_vocab=char_vocab,
                                             POS_vocab=POS_vocab,
                                             NER_vocab=NER_vocab,
                                             options=FLAGS,
                                             mode=valid_mode)

            init_op = tf.global_variables_initializer()

            # Checkpoint only variables under "Model", never the word
            # embeddings (presumably because they are reloaded from
            # FLAGS.word_vec_path each run -- TODO confirm).
            vars_ = {}
            for var in tf.global_variables():
                if "word_embedding" in var.name:
                    continue
                if not var.name.startswith("Model"):
                    continue
                vars_[var.name.split(":")[0]] = var
            saver = tf.train.Saver(vars_)

            with tf.Session() as sess:
                sess.run(init_op)
                if has_pretrained_model:
                    print("Restoring model from " + best_path)
                    saver.restore(sess, best_path)
                    print("DONE!")

                    # A restored model with no recorded score is evaluated once
                    # so later checkpoints are compared against a real baseline.
                    if FLAGS.mode == 'rl_train' and abs(best_bleu) < 0.00001:
                        print("Getting BLEU score for the model")
                        best_bleu = evaluate(sess,
                                             valid_graph,
                                             devDataStream,
                                             options=FLAGS)['dev_bleu']
                        FLAGS.best_bleu = best_bleu
                        namespace_utils.save_namespace(FLAGS, config_path)
                        print('BLEU = %.4f' % best_bleu)
                        log_file.write('BLEU = %.4f\n' % best_bleu)
                    if FLAGS.mode == 'ce_train' and abs(best_accu) < 0.00001:
                        print("Getting ACCU score for the model")
                        best_accu = evaluate(sess,
                                             valid_graph,
                                             devDataStream,
                                             options=FLAGS)['dev_accu']
                        FLAGS.best_accu = best_accu
                        namespace_utils.save_namespace(FLAGS, config_path)
                        print('ACCU = %.4f' % best_accu)
                        log_file.write('ACCU = %.4f\n' % best_accu)

                print('Start the training loop.')
                train_size = trainDataStream.get_num_batch()
                max_steps = train_size * FLAGS.max_epochs
                total_loss = 0.0
                start_time = time.time()
                for step in xrange(max_steps):
                    cur_batch = trainDataStream.nextBatch()
                    if FLAGS.mode == 'rl_train':
                        loss_value = train_graph.run_rl_training_2(
                            sess, cur_batch, FLAGS)
                    else:  # 'ce_train' (validated before graph construction)
                        loss_value = train_graph.run_ce_training(
                            sess, cur_batch, FLAGS)
                    total_loss += loss_value

                    if step % 100 == 0:
                        print('{} '.format(step), end="")
                        sys.stdout.flush()

                    # Save a checkpoint and evaluate at the end of every epoch
                    # and at the final step.
                    if not ((step + 1) % train_size == 0 or
                            (step + 1) == max_steps):
                        continue

                    print()
                    duration = time.time() - start_time
                    print('Step %d: loss = %.2f (%.3f sec)' %
                          (step, total_loss, duration))
                    log_file.write('Step %d: loss = %.2f (%.3f sec)\n' %
                                   (step, total_loss, duration))
                    log_file.flush()
                    sys.stdout.flush()
                    total_loss = 0.0

                    # Evaluate against the validation set.
                    eval_start = time.time()
                    print('Validation Data Eval:')
                    res_dict = evaluate(sess,
                                        valid_graph,
                                        devDataStream,
                                        options=FLAGS,
                                        suffix=str(step))
                    if valid_graph.mode == 'evaluate':
                        dev_loss = res_dict['dev_loss']
                        dev_accu = res_dict['dev_accu']
                        dev_right = int(res_dict['dev_right'])
                        dev_total = int(res_dict['dev_total'])
                        print('Dev loss = %.4f' % dev_loss)
                        log_file.write('Dev loss = %.4f\n' % dev_loss)
                        print('Dev accu = %.4f %d/%d' %
                              (dev_accu, dev_right, dev_total))
                        log_file.write('Dev accu = %.4f %d/%d\n' %
                                       (dev_accu, dev_right, dev_total))
                        log_file.flush()
                        if best_accu < dev_accu:
                            print('Saving weights, ACCU {} (prev_best) < {} (cur)'.
                                  format(best_accu, dev_accu))
                            saver.save(sess, best_path)
                            best_accu = dev_accu
                            FLAGS.best_accu = dev_accu
                            namespace_utils.save_namespace(FLAGS, config_path)
                    else:
                        dev_bleu = res_dict['dev_bleu']
                        print('Dev bleu = %.4f' % dev_bleu)
                        log_file.write('Dev bleu = %.4f\n' % dev_bleu)
                        log_file.flush()
                        if best_bleu < dev_bleu:
                            print('Saving weights, BLEU {} (prev_best) < {} (cur)'.
                                  format(best_bleu, dev_bleu))
                            saver.save(sess, best_path)
                            best_bleu = dev_bleu
                            FLAGS.best_bleu = dev_bleu
                            namespace_utils.save_namespace(FLAGS, config_path)
                    duration = time.time() - eval_start
                    print('Duration %.3f sec' % (duration))
                    sys.stdout.flush()

                    log_file.write('Duration %.3f sec\n' % (duration))
                    log_file.flush()
                    # Restart the training timer only after evaluation so the
                    # next "Step" line excludes evaluation time.
                    start_time = time.time()
def main(_):
    print('Configurations:')
    print(FLAGS)

    log_dir = FLAGS.model_dir
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    path_prefix = log_dir + "/NP2P.{}".format(FLAGS.suffix)
    init_model_prefix = FLAGS.init_model  # "/u/zhigwang/zhigwang1/sentence_generation/mscoco/logs/NP2P.phrase_ce_train"
    log_file_path = path_prefix + ".log"
    print('Log file path: {}'.format(log_file_path))
    log_file = open(log_file_path, 'wt')
    log_file.write("{}\n".format(FLAGS))
    log_file.flush()

    # save configuration
    namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json")

    print('Loading train set.')
    if FLAGS.infile_format == 'fof':
        trainset, train_ans_len = NP2P_data_stream.read_generation_datasets_from_fof(
            FLAGS.train_path, isLower=FLAGS.isLower)
        if FLAGS.max_answer_len > train_ans_len:
            FLAGS.max_answer_len = train_ans_len
    else:
        trainset, train_ans_len = NP2P_data_stream.read_all_GQA_questions(
            FLAGS.train_path, isLower=FLAGS.isLower)
    print('Number of training samples: {}'.format(len(trainset)))

    print('Loading test set.')
    if FLAGS.infile_format == 'fof':
        testset, test_ans_len = NP2P_data_stream.read_generation_datasets_from_fof(
            FLAGS.test_path, isLower=FLAGS.isLower)
    else:
        testset, test_ans_len = NP2P_data_stream.read_all_GQA_questions(
            FLAGS.test_path, isLower=FLAGS.isLower)
    print('Number of test samples: {}'.format(len(testset)))

    max_actual_len = max(train_ans_len, test_ans_len)
    print('Max answer length: {}, truncated to {}'.format(
        max_actual_len, FLAGS.max_answer_len))

    word_vocab = None
    POS_vocab = None
    NER_vocab = None
    char_vocab = None
    has_pretrained_model = False
    best_path = path_prefix + ".best.model"
    if os.path.exists(init_model_prefix + ".best.model.index"):
        has_pretrained_model = True
        print('!!Existing pretrained model. Loading vocabs.')
        if FLAGS.with_word:
            word_vocab = Vocab(FLAGS.word_vec_path, fileformat='txt2')
            print('word_vocab: {}'.format(word_vocab.word_vecs.shape))
        if FLAGS.with_char:
            char_vocab = Vocab(init_model_prefix + ".char_vocab",
                               fileformat='txt2')
            print('char_vocab: {}'.format(char_vocab.word_vecs.shape))
        if FLAGS.with_POS:
            POS_vocab = Vocab(init_model_prefix + ".POS_vocab",
                              fileformat='txt2')
            print('POS_vocab: {}'.format(POS_vocab.word_vecs.shape))
        if FLAGS.with_NER:
            NER_vocab = Vocab(init_model_prefix + ".NER_vocab",
                              fileformat='txt2')
            print('NER_vocab: {}'.format(NER_vocab.word_vecs.shape))
    else:
        print('Collecting vocabs.')
        (allWords, allChars, allPOSs,
         allNERs) = NP2P_data_stream.collect_vocabs(trainset)
        print('Number of words: {}'.format(len(allWords)))
        print('Number of allChars: {}'.format(len(allChars)))
        print('Number of allPOSs: {}'.format(len(allPOSs)))
        print('Number of allNERs: {}'.format(len(allNERs)))

        if FLAGS.with_word:
            word_vocab = Vocab(FLAGS.word_vec_path, fileformat='txt2')
        if FLAGS.with_char:
            char_vocab = Vocab(voc=allChars,
                               dim=FLAGS.char_dim,
                               fileformat='build')
            char_vocab.dump_to_txt2(path_prefix + ".char_vocab")
        if FLAGS.with_POS:
            POS_vocab = Vocab(voc=allPOSs,
                              dim=FLAGS.POS_dim,
                              fileformat='build')
            POS_vocab.dump_to_txt2(path_prefix + ".POS_vocab")
        if FLAGS.with_NER:
            NER_vocab = Vocab(voc=allNERs,
                              dim=FLAGS.NER_dim,
                              fileformat='build')
            NER_vocab.dump_to_txt2(path_prefix + ".NER_vocab")

    print('word vocab size {}'.format(word_vocab.vocab_size))
    sys.stdout.flush()

    print('Build DataStream ... ')
    trainDataStream = NP2P_data_stream.QADataStream(trainset,
                                                    word_vocab,
                                                    char_vocab,
                                                    POS_vocab,
                                                    NER_vocab,
                                                    options=FLAGS,
                                                    isShuffle=True,
                                                    isLoop=True,
                                                    isSort=True)

    devDataStream = NP2P_data_stream.QADataStream(testset,
                                                  word_vocab,
                                                  char_vocab,
                                                  POS_vocab,
                                                  NER_vocab,
                                                  options=FLAGS,
                                                  isShuffle=False,
                                                  isLoop=False,
                                                  isSort=True)
    print('Number of instances in trainDataStream: {}'.format(
        trainDataStream.get_num_instance()))
    print('Number of instances in devDataStream: {}'.format(
        devDataStream.get_num_instance()))
    print('Number of batches in trainDataStream: {}'.format(
        trainDataStream.get_num_batch()))
    print('Number of batches in devDataStream: {}'.format(
        devDataStream.get_num_batch()))
    sys.stdout.flush()

    init_scale = 0.01
    with tf.Graph().as_default():
        initializer = tf.random_uniform_initializer(-init_scale, init_scale)
        with tf.name_scope("Train"):
            with tf.variable_scope("Model",
                                   reuse=None,
                                   initializer=initializer):
                train_graph = ModelGraph(word_vocab=word_vocab,
                                         char_vocab=char_vocab,
                                         POS_vocab=POS_vocab,
                                         NER_vocab=NER_vocab,
                                         options=FLAGS,
                                         mode="rl_train_for_phrase")

        with tf.name_scope("Valid"):
            with tf.variable_scope("Model",
                                   reuse=True,
                                   initializer=initializer):
                valid_graph = ModelGraph(word_vocab=word_vocab,
                                         char_vocab=char_vocab,
                                         POS_vocab=POS_vocab,
                                         NER_vocab=NER_vocab,
                                         options=FLAGS,
                                         mode="decode")

        initializer = tf.global_variables_initializer()

        vars_ = {}
        for var in tf.all_variables():
            if "word_embedding" in var.name: continue
            if not var.name.startswith("Model"): continue
            vars_[var.name.split(":")[0]] = var
        saver = tf.train.Saver(vars_)

        sess = tf.Session()
        sess.run(initializer)
        if has_pretrained_model:
            print("Restoring model from " + init_model_prefix + ".best.model")
            saver.restore(sess, init_model_prefix + ".best.model")
            print("DONE!")
        sys.stdout.flush()

        # for first-time rl training, we get the current BLEU score
        print("First-time rl training, get the current BLEU score on dev")
        sys.stdout.flush()
        best_bleu = evaluate(sess,
                             valid_graph,
                             devDataStream,
                             word_vocab,
                             options=FLAGS)
        print('First-time bleu = %.4f' % best_bleu)
        log_file.write('First-time bleu = %.4f\n' % best_bleu)

        print('Start the training loop.')
        sys.stdout.flush()
        train_size = trainDataStream.get_num_batch()
        max_steps = train_size * FLAGS.max_epochs
        total_loss = 0.0
        start_time = time.time()
        for step in xrange(max_steps):
            cur_batch = trainDataStream.nextBatch()
            if FLAGS.with_baseline:
                # greedy search
                (greedy_sentences, _, _,
                 _) = NP2P_beam_decoder.search(sess,
                                               valid_graph,
                                               word_vocab,
                                               cur_batch,
                                               FLAGS,
                                               decode_mode="greedy")

            if FLAGS.with_target_lattice:
                (sampled_sentences, sampled_prediction_lengths,
                 sampled_generator_input_idx, sampled_generator_output_idx
                 ) = cur_batch.sample_a_partition()
            else:
                # multinomial sampling
                (sampled_sentences, sampled_prediction_lengths,
                 sampled_generator_input_idx,
                 sampled_generator_output_idx) = NP2P_beam_decoder.search(
                     sess,
                     valid_graph,
                     word_vocab,
                     cur_batch,
                     FLAGS,
                     decode_mode="multinomial")
            # calculate rewards
            rewards = []
            for i in xrange(cur_batch.batch_size):
                #                 print(sampled_sentences[i])
                #                 print(sampled_generator_input_idx[i])
                #                 print(sampled_generator_output_idx[i])
                cur_toks = cur_batch.instances[i][1].tokText.split()
                #                 r = sentence_bleu([cur_toks], sampled_sentences[i].split(), smoothing_function=cc.method3)
                r = 1.0
                b = 0.0
                if FLAGS.with_baseline:
                    b = sentence_bleu([cur_toks],
                                      greedy_sentences[i].split(),
                                      smoothing_function=cc.method3)


#                 r = metric_utils.evaluate_captions([cur_toks],[sampled_sentences[i]])
#                 b = metric_utils.evaluate_captions([cur_toks],[greedy_sentences[i]])
                rewards.append(1.0 * (r - b))
            rewards = np.array(rewards, dtype=np.float32)
            #             sys.exit(-1)

            # update parameters
            feed_dict = train_graph.run_encoder(sess,
                                                cur_batch,
                                                FLAGS,
                                                only_feed_dict=True)
            feed_dict[train_graph.reward] = rewards
            feed_dict[
                train_graph.gen_input_words] = sampled_generator_input_idx
            feed_dict[
                train_graph.in_answer_words] = sampled_generator_output_idx
            feed_dict[train_graph.answer_lengths] = sampled_prediction_lengths
            (_,
             loss_value) = sess.run([train_graph.train_op, train_graph.loss],
                                    feed_dict)
            total_loss += loss_value

            if step % 100 == 0:
                print('{} '.format(step), end="")
                sys.stdout.flush()

            # Save a checkpoint and evaluate the model periodically.
            if (step + 1) % trainDataStream.get_num_batch() == 0 or (
                    step + 1) == max_steps:
                print()
                duration = time.time() - start_time
                print('Step %d: loss = %.2f (%.3f sec)' %
                      (step, total_loss, duration))
                log_file.write('Step %d: loss = %.2f (%.3f sec)\n' %
                               (step, total_loss, duration))
                log_file.flush()
                sys.stdout.flush()
                total_loss = 0.0

                # Evaluate against the validation set.
                start_time = time.time()
                print('Validation Data Eval:')
                dev_bleu = evaluate(sess,
                                    valid_graph,
                                    devDataStream,
                                    word_vocab,
                                    options=FLAGS)
                print('Dev bleu = %.4f' % dev_bleu)
                log_file.write('Dev bleu = %.4f\n' % dev_bleu)
                log_file.flush()
                if best_bleu < dev_bleu:
                    print('Saving weights, BLEU {} (prev_best) < {} (cur)'.
                          format(best_bleu, dev_bleu))
                    best_bleu = dev_bleu
                    saver.save(sess, best_path)  # TODO: save model
                duration = time.time() - start_time
                print('Duration %.3f sec' % (duration))
                sys.stdout.flush()
                log_file.write('Duration %.3f sec\n' % (duration))
                log_file.flush()

    log_file.close()