Example #1
def validate_and_save(model, devset_batches, log_file, best_accu):
    path_prefix = FLAGS.log_dir + "/MHQA.{}".format(FLAGS.suffix)
    start_time = time.time()
    res_dict = evaluate_dataset(model, devset_batches)
    dev_loss = res_dict['dev_loss']
    dev_accu = res_dict['dev_accu']
    dev_right = int(res_dict['dev_right'])
    dev_total = int(res_dict['dev_total'])
    print('Dev loss = %.4f' % dev_loss)
    log_file.write('Dev loss = %.4f\n' % dev_loss)
    print('Dev accu = %.4f %d/%d' % (dev_accu, dev_right, dev_total))
    log_file.write('Dev accu = %.4f %d/%d\n' %
                   (dev_accu, dev_right, dev_total))
    log_file.flush()
    if best_accu < dev_accu:
        print('Saving weights, ACCU {} (prev_best) < {} (cur)'.format(
            best_accu, dev_accu))
        best_path = path_prefix + '.model.bin'
        torch.save(model.state_dict(), best_path)
        best_accu = dev_accu
        FLAGS.best_accu = dev_accu
        namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json")
    duration = time.time() - start_time
    print('Duration %.3f sec' % (duration))
    print('-------------')
    log_file.write('-------------\n')
    return best_accu
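
A hypothetical usage sketch (not part of the original code): run one validation/checkpoint pass at the end of every epoch. The model interface, the batch iterables, and the optimizer choice below are placeholder assumptions.

import torch

def train_loop(model, train_batches, devset_batches, log_file, num_epochs):
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    best_accu = 0.0
    for epoch in range(num_epochs):
        model.train()
        for batch in train_batches:
            optimizer.zero_grad()
            loss = model(batch)  # placeholder: assumes the model returns a scalar loss
            loss.backward()
            optimizer.step()
        model.eval()
        # reuse the helper above to log dev metrics and keep the best weights
        best_accu = validate_and_save(model, devset_batches, log_file, best_accu)
    return best_accu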
Example #2
def validate_and_save(sess, saver, FLAGS, log_file, devDataStream, valid_graph,
                      path_prefix, best_accu):
    best_path = path_prefix + ".best.model"
    start_time = time.time()
    print('Validation Data Eval:')
    res_dict = evaluate(sess, valid_graph, devDataStream, options=FLAGS)
    dev_loss = res_dict['dev_loss']
    dev_accu = res_dict['dev_accu']
    dev_right = int(res_dict['dev_right'])
    dev_total = int(res_dict['dev_total'])
    print('Dev loss = %.4f' % dev_loss)
    log_file.write('Dev loss = %.4f\n' % dev_loss)
    print('Dev accu = %.4f %d/%d' % (dev_accu, dev_right, dev_total))
    log_file.write('Dev accu = %.4f %d/%d\n' %
                   (dev_accu, dev_right, dev_total))
    log_file.flush()
    if best_accu < dev_accu:
        print('Saving weights, ACCU {} (prev_best) < {} (cur)'.format(
            best_accu, dev_accu))
        saver.save(sess, best_path)
        best_accu = dev_accu
        FLAGS.best_accu = dev_accu
        namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json")
    duration = time.time() - start_time
    print('Duration %.3f sec' % (duration))
    sys.stdout.flush()
    return best_accu
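
A hedged follow-up sketch (not in the original): after training finishes, the best checkpoint saved above can be restored and scored on the test set. `testDataStream` is assumed to expose the same interface as `devDataStream`.

def test_best_model(sess, saver, FLAGS, testDataStream, valid_graph, path_prefix):
    # Restore the best checkpoint written by validate_and_save and re-run evaluate.
    best_path = path_prefix + ".best.model"
    saver.restore(sess, best_path)
    res_dict = evaluate(sess, valid_graph, testDataStream, options=FLAGS)
    print('Test accu = %.4f %d/%d' % (res_dict['dev_accu'],
                                      int(res_dict['dev_right']),
                                      int(res_dict['dev_total'])))
    return res_dict['dev_accu']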
Example #3
def main(_):
    print('Configurations:')
    print(FLAGS)  # print all configuration parameters

    root_path = FLAGS.root_path
    train_path = root_path + FLAGS.train_path
    dev_path = root_path + FLAGS.dev_path
    test_path = root_path + FLAGS.test_path
    word_vec_path = root_path + FLAGS.word_vec_path
    model_dir = root_path + FLAGS.model_dir

    if tf.gfile.Exists(model_dir + '/mnist_with_summaries'):
        print("delete summaries")
        tf.gfile.DeleteRecursively(model_dir + '/mnist_with_summaries')

    if not os.path.exists(model_dir):
        os.makedirs(model_dir)

    path_prefix = model_dir + "/SentenceMatch.{}".format(FLAGS.suffix)

    namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json")  # save the configuration

    best_path = path_prefix + '.best.model'
    label_path = path_prefix + ".label_vocab"
    has_pre_trained_model = False
    ckpt = tf.train.get_checkpoint_state(model_dir)
    if ckpt and ckpt.model_checkpoint_path:
        print("-------has_pre_trained_model--------")
        print(ckpt.model_checkpoint_path)
        has_pre_trained_model = True

    ############# build vocabs#################
    print('Collect words, chars and labels ...')
    (all_words, all_labels) = collect_vocabs(train_path)
    print('Number of words: {}'.format(len(all_words)))
    print('Number of labels: {}'.format(len(all_labels)))

    word_vocab = Vocab(pattern='word')  # build the word vocabulary
    word_vocab.patternWord(word_vec_path, model_dir)
    label_vocab = Vocab(pattern="label")
    label_vocab.patternLabel(all_labels, label_path)

    print('word_vocab shape is {}'.format(word_vocab.word_vecs.shape))
    print('tag_vocab shape is {}'.format(label_vocab.word_vecs.shape))
    num_classes = len(all_labels)

    if FLAGS.wo_char: char_vocab = None
    #####  Build SentenceMatchDataStream  ################
    print('Build SentenceMatchDataStream ... ')
    trainDataStream = SentenceMatchDataStream(
        train_path,
        word_vocab=word_vocab,
        label_vocab=label_vocab,
        batch_size=FLAGS.batch_size,
        isShuffle=True,
        isLoop=True,
        isSort=False,
        max_sent_length=FLAGS.max_sent_length)

    devDataStream = SentenceMatchDataStream(
        dev_path,
        word_vocab=word_vocab,
        label_vocab=label_vocab,
        batch_size=FLAGS.batch_size,
        isShuffle=False,
        isLoop=True,
        isSort=False,
        max_sent_length=FLAGS.max_sent_length)

    testDataStream = SentenceMatchDataStream(
        test_path,
        word_vocab=word_vocab,
        label_vocab=label_vocab,
        batch_size=FLAGS.batch_size,
        isShuffle=False,
        isLoop=True,
        isSort=False,
        max_sent_length=FLAGS.max_sent_length)

    print('Number of instances in trainDataStream: {}'.format(
        trainDataStream.get_num_instance()))
    print('Number of instances in devDataStream: {}'.format(
        devDataStream.get_num_instance()))
    print('Number of instances in testDataStream: {}'.format(
        testDataStream.get_num_instance()))
    print('Number of batches in trainDataStream: {}'.format(
        trainDataStream.get_num_batch()))
    print('Number of batches in devDataStream: {}'.format(
        devDataStream.get_num_batch()))
    print('Number of batches in testDataStream: {}'.format(
        testDataStream.get_num_batch()))

    sys.stdout.flush()

    best_accuracy = 0.0
    init_scale = 0.01
    g_2 = tf.Graph()
    with g_2.as_default():
        initializer = tf.random_uniform_initializer(-init_scale, init_scale)
        with tf.variable_scope("Model", reuse=None, initializer=initializer):
            train_graph = SentenceMatchModelGraph(
                num_classes,
                word_vocab=word_vocab,
                dropout_rate=FLAGS.dropout_rate,
                learning_rate=FLAGS.learning_rate,
                optimize_type=FLAGS.optimize_type,
                lambda_l2=FLAGS.lambda_l2,
                with_word=True,
                context_lstm_dim=FLAGS.context_lstm_dim,
                aggregation_lstm_dim=FLAGS.aggregation_lstm_dim,
                is_training=True,
                MP_dim=FLAGS.MP_dim,
                context_layer_num=FLAGS.context_layer_num,
                aggregation_layer_num=FLAGS.aggregation_layer_num,
                fix_word_vec=FLAGS.fix_word_vec,
                with_filter_layer=FLAGS.with_filter_layer,
                with_highway=FLAGS.with_highway,
                with_match_highway=FLAGS.with_match_highway,
                with_aggregation_highway=FLAGS.with_aggregation_highway,
                highway_layer_num=FLAGS.highway_layer_num,
                with_lex_decomposition=FLAGS.with_lex_decomposition,
                lex_decompsition_dim=FLAGS.lex_decompsition_dim,
                with_left_match=(not FLAGS.wo_left_match),
                with_right_match=(not FLAGS.wo_right_match),
                with_full_match=(not FLAGS.wo_full_match),
                with_maxpool_match=(not FLAGS.wo_maxpool_match),
                with_attentive_match=(not FLAGS.wo_attentive_match),
                with_max_attentive_match=(not FLAGS.wo_max_attentive_match))
            tf.summary.scalar("Training Loss", train_graph.get_loss())

        with tf.variable_scope("Model", reuse=True, initializer=initializer):
            valid_graph = SentenceMatchModelGraph(
                num_classes,
                word_vocab=word_vocab,
                dropout_rate=FLAGS.dropout_rate,
                learning_rate=FLAGS.learning_rate,
                optimize_type=FLAGS.optimize_type,
                lambda_l2=FLAGS.lambda_l2,
                with_word=True,
                context_lstm_dim=FLAGS.context_lstm_dim,
                aggregation_lstm_dim=FLAGS.aggregation_lstm_dim,
                is_training=False,
                MP_dim=FLAGS.MP_dim,
                context_layer_num=FLAGS.context_layer_num,
                aggregation_layer_num=FLAGS.aggregation_layer_num,
                fix_word_vec=FLAGS.fix_word_vec,
                with_filter_layer=FLAGS.with_filter_layer,
                with_highway=FLAGS.with_highway,
                with_match_highway=FLAGS.with_match_highway,
                with_aggregation_highway=FLAGS.with_aggregation_highway,
                highway_layer_num=FLAGS.highway_layer_num,
                with_lex_decomposition=FLAGS.with_lex_decomposition,
                lex_decompsition_dim=FLAGS.lex_decompsition_dim,
                with_left_match=(not FLAGS.wo_left_match),
                with_right_match=(not FLAGS.wo_right_match),
                with_full_match=(not FLAGS.wo_full_match),
                with_maxpool_match=(not FLAGS.wo_maxpool_match),
                with_attentive_match=(not FLAGS.wo_attentive_match),
                with_max_attentive_match=(not FLAGS.wo_max_attentive_match))

        initializer = tf.global_variables_initializer()
        saver = tf.train.Saver()
        # vars_ = {}
        # for var in tf.global_variables():
        #     if "word_embedding" in var.name: continue
        #     vars_[var.name.split(":")[0]] = var
        # saver = tf.train.Saver(vars_)

        sess = tf.Session()

        merged = tf.summary.merge_all()
        train_writer = tf.summary.FileWriter(
            model_dir + '/mnist_with_summaries/train', sess.graph)
        sess.run(initializer)

        if has_pre_trained_model:
            print("Restoring model from " + best_path)
            saver.restore(sess, best_path)
            print("DONE!")

        print('Start the training loop.')
        train_size = trainDataStream.get_num_batch()
        max_steps = train_size * FLAGS.max_epochs
        total_loss = 0.0
        start_time = time.time()
        for step in range(max_steps):
            cur_batch = trainDataStream.nextBatch()
            (label_id_batch, word_idx_1_batch, word_idx_2_batch,
             sent1_length_batch, sent2_length_batch) = cur_batch
            feed_dict = {
                train_graph.get_truth(): label_id_batch,
                train_graph.get_question_lengths(): sent1_length_batch,
                train_graph.get_passage_lengths(): sent2_length_batch,
                train_graph.get_in_question_words(): word_idx_1_batch,
                train_graph.get_in_passage_words(): word_idx_2_batch,
            }

            # in_question_repres,in_ques=sess.run([train_graph.in_question_repres,train_graph.in_ques],feed_dict=feed_dict)
            # print(in_question_repres,in_ques)
            # break

            _, loss_value, summary = sess.run(
                [train_graph.get_train_op(),
                 train_graph.get_loss(), merged],
                feed_dict=feed_dict)
            total_loss += loss_value

            if step % 5000 == 0:
                # train_writer.add_summary(summary, step)
                print("step:", step, "loss:", loss_value)

            if (step + 1) % trainDataStream.get_num_batch() == 0 or (
                    step + 1) == max_steps:
                print()
                duration = time.time() - start_time
                start_time = time.time()
                print('Step %d: loss = %.2f (%.3f sec)' %
                      (step, total_loss, duration))
                total_loss = 0.0

                print('Validation Data Eval:')
                accuracy = evaluate(devDataStream, valid_graph, sess)
                print("Current accuracy is %.2f" % accuracy)
                if accuracy >= best_accuracy:
                    print('Saving model since it\'s the best so far')
                    best_accuracy = accuracy
                    saver.save(sess, best_path)
            sys.stdout.flush()

    print("Best accuracy on dev set is %.2f" % best_accuracy)
Example #4
def main(_):
    print('Configurations:')
    print(FLAGS)

    train_path = FLAGS.train_path
    dev_path = FLAGS.dev_path
    test_path = FLAGS.test_path
    word_vec_path = FLAGS.word_vec_path
    log_dir = FLAGS.model_dir
    if FLAGS.train == "sick":
        train_path = FLAGS.SICK_train_path
        dev_path = FLAGS.SICK_dev_path
    if FLAGS.test == "sick":
        test_path = FLAGS.SICK_test_path
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    
    path_prefix = log_dir + "/SentenceMatch.{}".format(FLAGS.suffix)
    namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json")

    # build vocabs
    parser, image_feats = None, None
    if FLAGS.with_dep:
        parser = Parser('snli')
    if FLAGS.with_image:
        image_feats = ImageFeatures()
    word_vocab = Vocab(word_vec_path, fileformat='txt3', parser=parser, beginning=FLAGS.beginning)
    best_path = path_prefix + '.best.model'
    char_path = path_prefix + ".char_vocab"
    label_path = path_prefix + ".label_vocab"
    POS_path = path_prefix + ".POS_vocab"
    NER_path = path_prefix + ".NER_vocab"
    DEP_path = path_prefix + ".DEP_vocab"
    has_pre_trained_model = False
    POS_vocab = None
    NER_vocab = None
    DEP_vocab = None
    print('has pretrained model: ', os.path.exists(best_path))
    print('best_path: ' + best_path)
    if os.path.exists(best_path + '.meta'):
        has_pre_trained_model = True
        label_vocab = Vocab(label_path, fileformat='txt2')
        char_vocab = Vocab(char_path, fileformat='txt2')
        if FLAGS.with_POS: POS_vocab = Vocab(POS_path, fileformat='txt2')
        if FLAGS.with_NER: NER_vocab = Vocab(NER_path, fileformat='txt2')
    else:
        print('Collect words, chars and labels ...')
        (all_words, all_chars, all_labels, all_POSs, all_NERs) = collect_vocabs(train_path, with_POS=FLAGS.with_POS, with_NER=FLAGS.with_NER)
        print('Number of words: {}'.format(len(all_words)))
        print('Number of labels: {}'.format(len(all_labels)))
        label_vocab = Vocab(fileformat='voc', voc=all_labels,dim=2)
        label_vocab.dump_to_txt2(label_path)

        print('Number of chars: {}'.format(len(all_chars)))
        char_vocab = Vocab(fileformat='voc', voc=all_chars,dim=FLAGS.char_emb_dim, beginning=FLAGS.beginning)
        char_vocab.dump_to_txt2(char_path)
        
        if FLAGS.with_POS:
            print('Number of POSs: {}'.format(len(all_POSs)))
            POS_vocab = Vocab(fileformat='voc', voc=all_POSs,dim=FLAGS.POS_dim)
            POS_vocab.dump_to_txt2(POS_path)
        if FLAGS.with_NER:
            print('Number of NERs: {}'.format(len(all_NERs)))
            NER_vocab = Vocab(fileformat='voc', voc=all_NERs,dim=FLAGS.NER_dim)
            NER_vocab.dump_to_txt2(NER_path)
            

    print('word_vocab shape is {}'.format(word_vocab.word_vecs.shape))
    print('tag_vocab shape is {}'.format(label_vocab.word_vecs.shape))
    num_classes = label_vocab.size()

    print('Build DataStream ... ')
    print('Reading trainDataStream')
    if not FLAGS.decoding_only:
        trainDataStream = DataStream(train_path, word_vocab=word_vocab, char_vocab=char_vocab, 
                                              POS_vocab=POS_vocab, NER_vocab=NER_vocab, label_vocab=label_vocab, 
                                              batch_size=FLAGS.batch_size, isShuffle=True, isLoop=True, isSort=True, 
                                              max_char_per_word=FLAGS.max_char_per_word, max_sent_length=FLAGS.max_sent_length, 
                                              with_dep=FLAGS.with_dep, with_image=FLAGS.with_image, image_feats=image_feats, sick_data=(FLAGS.test == "sick"))
    
        print('Reading devDataStream')
        devDataStream = DataStream(dev_path, word_vocab=word_vocab, char_vocab=char_vocab,
                                              POS_vocab=POS_vocab, NER_vocab=NER_vocab, label_vocab=label_vocab, 
                                              batch_size=FLAGS.batch_size, isShuffle=False, isLoop=True, isSort=True, 
                                              max_char_per_word=FLAGS.max_char_per_word, max_sent_length=FLAGS.max_sent_length, 
                                              with_dep=FLAGS.with_dep, with_image=FLAGS.with_image, image_feats=image_feats, sick_data=(FLAGS.test == "sick"))

    print('Reading testDataStream')
    testDataStream = DataStream(test_path, word_vocab=word_vocab, char_vocab=char_vocab, 
                                              POS_vocab=POS_vocab, NER_vocab=NER_vocab, label_vocab=label_vocab, 
                                              batch_size=FLAGS.batch_size, isShuffle=False, isLoop=True, isSort=True, 
                                                  max_char_per_word=FLAGS.max_char_per_word, max_sent_length=FLAGS.max_sent_length, 
                                                  with_dep=FLAGS.with_dep, with_image=FLAGS.with_image, image_feats=image_feats, sick_data=(FLAGS.test == "sick"))
    print('save cache file')
    #word_vocab.parser.save_cache()
    #image_feats.save_feat()
    if not FLAGS.decoding_only:
        print('Number of instances in trainDataStream: {}'.format(trainDataStream.get_num_instance()))
        print('Number of instances in devDataStream: {}'.format(devDataStream.get_num_instance()))
        print('Number of batches in trainDataStream: {}'.format(trainDataStream.get_num_batch()))
        print('Number of batches in devDataStream: {}'.format(devDataStream.get_num_batch()))
    print('Number of instances in testDataStream: {}'.format(testDataStream.get_num_instance()))
    print('Number of batches in testDataStream: {}'.format(testDataStream.get_num_batch()))
    
    sys.stdout.flush()
    if FLAGS.wo_char: char_vocab = None

    best_accuracy = 0.0
    init_scale = 0.01
    
    if not FLAGS.decoding_only:
        with tf.Graph().as_default():
            initializer = tf.random_uniform_initializer(-init_scale, init_scale)
            with tf.variable_scope("Model", reuse=None, initializer=initializer):
                train_graph = ModelGraph(num_classes, word_vocab=word_vocab, char_vocab=char_vocab,POS_vocab=POS_vocab, NER_vocab=NER_vocab, 
                    dropout_rate=FLAGS.dropout_rate, learning_rate=FLAGS.learning_rate, optimize_type=FLAGS.optimize_type,
                    lambda_l2=FLAGS.lambda_l2, char_lstm_dim=FLAGS.char_lstm_dim, context_lstm_dim=FLAGS.context_lstm_dim, 
                    aggregation_lstm_dim=FLAGS.aggregation_lstm_dim, is_training=True, MP_dim=FLAGS.MP_dim, 
                    context_layer_num=FLAGS.context_layer_num, aggregation_layer_num=FLAGS.aggregation_layer_num, 
                    fix_word_vec=FLAGS.fix_word_vec,with_filter_layer=FLAGS.with_filter_layer, with_highway=FLAGS.with_highway,
                    word_level_MP_dim=FLAGS.word_level_MP_dim,
                    with_match_highway=FLAGS.with_match_highway, with_aggregation_highway=FLAGS.with_aggregation_highway,
                    highway_layer_num=FLAGS.highway_layer_num,with_lex_decomposition=FLAGS.with_lex_decomposition, 
                    lex_decompsition_dim=FLAGS.lex_decompsition_dim,
                    with_left_match=(not FLAGS.wo_left_match), with_right_match=(not FLAGS.wo_right_match),
                    with_full_match=(not FLAGS.wo_full_match), with_maxpool_match=(not FLAGS.wo_maxpool_match), 
                    with_attentive_match=(not FLAGS.wo_attentive_match), with_max_attentive_match=(not FLAGS.wo_max_attentive_match), 
                    with_dep=FLAGS.with_dep, with_image=FLAGS.with_image, image_with_hypothesis_only=FLAGS.image_with_hypothesis_only,
                    with_img_full_match=FLAGS.with_img_full_match, with_img_maxpool_match=FLAGS.with_img_full_match, 
                    with_img_attentive_match=FLAGS.with_img_attentive_match, image_context_layer=FLAGS.image_context_layer, 
                    with_img_max_attentive_match=FLAGS.with_img_max_attentive_match, img_dim=FLAGS.img_dim)
                tf.summary.scalar("Training Loss", train_graph.get_loss()) # Add a scalar summary for the snapshot loss.
        
            with tf.variable_scope("Model", reuse=True, initializer=initializer):
                valid_graph = ModelGraph(num_classes, word_vocab=word_vocab, char_vocab=char_vocab, POS_vocab=POS_vocab, NER_vocab=NER_vocab, 
                    dropout_rate=FLAGS.dropout_rate, learning_rate=FLAGS.learning_rate, optimize_type=FLAGS.optimize_type,
                    lambda_l2=FLAGS.lambda_l2, char_lstm_dim=FLAGS.char_lstm_dim, context_lstm_dim=FLAGS.context_lstm_dim, 
                    aggregation_lstm_dim=FLAGS.aggregation_lstm_dim, is_training=False, MP_dim=FLAGS.MP_dim, 
                    context_layer_num=FLAGS.context_layer_num, aggregation_layer_num=FLAGS.aggregation_layer_num, 
                    fix_word_vec=FLAGS.fix_word_vec,with_filter_layer=FLAGS.with_filter_layer, with_highway=FLAGS.with_highway,
                    word_level_MP_dim=FLAGS.word_level_MP_dim,
                    with_match_highway=FLAGS.with_match_highway, with_aggregation_highway=FLAGS.with_aggregation_highway,
                    highway_layer_num=FLAGS.highway_layer_num, with_lex_decomposition=FLAGS.with_lex_decomposition, 
                    lex_decompsition_dim=FLAGS.lex_decompsition_dim,
                    with_left_match=(not FLAGS.wo_left_match), with_right_match=(not FLAGS.wo_right_match),
                    with_full_match=(not FLAGS.wo_full_match), with_maxpool_match=(not FLAGS.wo_maxpool_match), 
                    with_attentive_match=(not FLAGS.wo_attentive_match), with_max_attentive_match=(not FLAGS.wo_max_attentive_match), 
                    with_dep=FLAGS.with_dep, with_image=FLAGS.with_image, image_with_hypothesis_only=FLAGS.image_with_hypothesis_only,
                    with_img_attentive_match=FLAGS.with_img_attentive_match, with_img_full_match=FLAGS.with_img_full_match, 
                    with_img_maxpool_match=FLAGS.with_img_full_match, image_context_layer=FLAGS.image_context_layer, 
                    with_img_max_attentive_match=FLAGS.with_img_max_attentive_match, img_dim=FLAGS.img_dim)

                
            initializer = tf.global_variables_initializer()
            vars_ = {}
            for var in tf.global_variables():
                if "word_embedding" in var.name: continue
                vars_[var.name.split(":")[0]] = var
            saver = tf.train.Saver(vars_)
         
            sess = tf.Session()
            sess.run(initializer) #, feed_dict={valid_graph.emb_init: word_vocab.word_vecs, train_graph.emb_init: word_vocab.word_vecs})
            if has_pre_trained_model:
                print("Restoring model from " + best_path)
                saver.restore(sess, best_path)
                print("DONE!")
                #if best_path.startswith('bimpm_baseline'):
                #best_path = best_path + '_sick' 

            print('Start the training loop.')
            train_size = trainDataStream.get_num_batch()
            max_steps = train_size * FLAGS.max_epochs
            total_loss = 0.0
            start_time = time.time()
            for step in xrange(max_steps):
                # read data
                cur_batch = trainDataStream.nextBatch()
                (label_batch, sent1_batch, sent2_batch, label_id_batch, word_idx_1_batch, word_idx_2_batch, 
                                 char_matrix_idx_1_batch, char_matrix_idx_2_batch, sent1_length_batch, sent2_length_batch, 
                                 sent1_char_length_batch, sent2_char_length_batch,
                                 POS_idx_1_batch, POS_idx_2_batch, NER_idx_1_batch, NER_idx_2_batch, 
                                 dependency1_batch, dependency2_batch, dep_con1_batch, dep_con2_batch, img_feats_batch, img_id_batch) = cur_batch
                feed_dict = {
                         train_graph.get_truth(): label_id_batch, 
                         train_graph.get_question_lengths(): sent1_length_batch, 
                         train_graph.get_passage_lengths(): sent2_length_batch, 
                         train_graph.get_in_question_words(): word_idx_1_batch, 
                         train_graph.get_in_passage_words(): word_idx_2_batch,
                         #train_graph.get_emb_init(): word_vocab.word_vecs,
                         #train_graph.get_in_question_dependency(): dependency1_batch,
                         #train_graph.get_in_passage_dependency(): dependency2_batch,
#                          train_graph.get_question_char_lengths(): sent1_char_length_batch, 
#                          train_graph.get_passage_char_lengths(): sent2_char_length_batch, 
#                          train_graph.get_in_question_chars(): char_matrix_idx_1_batch, 
#                          train_graph.get_in_passage_chars(): char_matrix_idx_2_batch, 
                         }
                if FLAGS.with_dep:
                    feed_dict[train_graph.get_in_question_dependency()] = dependency1_batch
                    feed_dict[train_graph.get_in_passage_dependency()] = dependency2_batch
                    feed_dict[train_graph.get_in_question_dep_con()] = dep_con1_batch
                    feed_dict[train_graph.get_in_passage_dep_con()] = dep_con2_batch

                if FLAGS.with_image:
                    feed_dict[train_graph.get_image_feats()] = img_feats_batch

                if char_vocab is not None:
                    feed_dict[train_graph.get_question_char_lengths()] = sent1_char_length_batch
                    feed_dict[train_graph.get_passage_char_lengths()] = sent2_char_length_batch
                    feed_dict[train_graph.get_in_question_chars()] = char_matrix_idx_1_batch
                    feed_dict[train_graph.get_in_passage_chars()] = char_matrix_idx_2_batch

                if POS_vocab is not None:
                    feed_dict[train_graph.get_in_question_poss()] = POS_idx_1_batch
                    feed_dict[train_graph.get_in_passage_poss()] = POS_idx_2_batch

                if NER_vocab is not None:
                    feed_dict[train_graph.get_in_question_ners()] = NER_idx_1_batch
                    feed_dict[train_graph.get_in_passage_ners()] = NER_idx_2_batch

                _, loss_value = sess.run([train_graph.get_train_op(), train_graph.get_loss()], feed_dict=feed_dict)
                total_loss += loss_value
            
                if step % 100==0: 
                    print('{} '.format(step), end="")
                    sys.stdout.flush()

                # Save a checkpoint and evaluate the model periodically.
                if (step + 1) % trainDataStream.get_num_batch() == 0 or (step + 1) == max_steps:
                    print()
                    # Print status to stdout.
                    duration = time.time() - start_time
                    start_time = time.time()
                    print('Step %d: loss = %.2f (%.3f sec)' % (step, total_loss, duration))
                    total_loss = 0.0

                    # Evaluate against the validation set.
                    print('Validation Data Eval:')
                    accuracy = evaluate(devDataStream, valid_graph, sess,char_vocab=char_vocab, POS_vocab=POS_vocab, NER_vocab=NER_vocab, word_vocab=word_vocab)
                    print("Current accuracy on dev is %.2f" % accuracy)
                
                    #accuracy_train = evaluate(trainDataStream, valid_graph, sess,char_vocab=char_vocab, POS_vocab=POS_vocab, NER_vocab=NER_vocab)
                    #print("Current accuracy on train is %.2f" % accuracy_train)
                    if accuracy>best_accuracy:
                        best_accuracy = accuracy
                        saver.save(sess, best_path)
        print("Best accuracy on dev set is %.2f" % best_accuracy)
    
    # decoding
    print('Decoding on the test set:')
    init_scale = 0.01
    with tf.Graph().as_default():
        initializer = tf.random_uniform_initializer(-init_scale, init_scale)
        with tf.variable_scope("Model", reuse=False, initializer=initializer):
            valid_graph = ModelGraph(num_classes, word_vocab=word_vocab, char_vocab=char_vocab, POS_vocab=POS_vocab, NER_vocab=NER_vocab, 
                 dropout_rate=FLAGS.dropout_rate, learning_rate=FLAGS.learning_rate, optimize_type=FLAGS.optimize_type,
                 lambda_l2=FLAGS.lambda_l2, char_lstm_dim=FLAGS.char_lstm_dim, context_lstm_dim=FLAGS.context_lstm_dim, 
                 aggregation_lstm_dim=FLAGS.aggregation_lstm_dim, is_training=False, MP_dim=FLAGS.MP_dim, 
                 context_layer_num=FLAGS.context_layer_num, aggregation_layer_num=FLAGS.aggregation_layer_num, 
                 fix_word_vec=FLAGS.fix_word_vec,with_filter_layer=FLAGS.with_filter_layer, with_highway=FLAGS.with_highway,
                 word_level_MP_dim=FLAGS.word_level_MP_dim,
                 with_match_highway=FLAGS.with_match_highway, with_aggregation_highway=FLAGS.with_aggregation_highway,
                 highway_layer_num=FLAGS.highway_layer_num, with_lex_decomposition=FLAGS.with_lex_decomposition, 
                 lex_decompsition_dim=FLAGS.lex_decompsition_dim,
                 with_left_match=(not FLAGS.wo_left_match), with_right_match=(not FLAGS.wo_right_match),
                 with_full_match=(not FLAGS.wo_full_match), with_maxpool_match=(not FLAGS.wo_maxpool_match), 
                 with_attentive_match=(not FLAGS.wo_attentive_match), with_max_attentive_match=(not FLAGS.wo_max_attentive_match), 
                 with_dep=FLAGS.with_dep, with_image=FLAGS.with_image, image_with_hypothesis_only=FLAGS.image_with_hypothesis_only,
                 with_img_attentive_match=FLAGS.with_img_attentive_match, with_img_full_match=FLAGS.with_img_full_match, 
                 with_img_maxpool_match=FLAGS.with_img_full_match, image_context_layer=FLAGS.image_context_layer, 
                 with_img_max_attentive_match=FLAGS.with_img_max_attentive_match, img_dim=FLAGS.img_dim)
        vars_ = {}
        for var in tf.global_variables():
            if "word_embedding" in var.name: continue
            if not var.name.startswith("Model"): continue
            vars_[var.name.split(":")[0]] = var
        saver = tf.train.Saver(vars_)
                
        sess = tf.Session()
        sess.run(tf.global_variables_initializer())#, feed_dict={valid_graph.emb_init: word_vocab.word_vecs})
        step = 0
        saver.restore(sess, best_path)

        accuracy = evaluate(testDataStream, valid_graph, sess, outpath=FLAGS.suffix+ FLAGS.train + FLAGS.test + ".result",char_vocab=char_vocab,label_vocab=label_vocab, word_vocab=word_vocab)
        print("Accuracy for test set is %.2f" % accuracy)
        accuracy_train = evaluate(trainDataStream, valid_graph, sess,char_vocab=char_vocab,word_vocab=word_vocab)
        print("Accuracy for train set is %.2f" % accuracy_train)
Example #5
def main(_):


    #for x in range (100):
    #    Generate_random_initialization()
    #    print (FLAGS.is_aggregation_lstm, FLAGS.context_lstm_dim, FLAGS.context_layer_num, FLAGS. aggregation_lstm_dim, FLAGS.aggregation_layer_num, FLAGS.max_window_size, FLAGS.MP_dim)

    print('Configurations:')
    #print(FLAGS)


    train_path = FLAGS.train_path
    dev_path = FLAGS.dev_path
    test_path = FLAGS.test_path
    word_vec_path = FLAGS.word_vec_path
    log_dir = FLAGS.model_dir
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    
    path_prefix = log_dir + "/SentenceMatch.{}".format(FLAGS.suffix)

    namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json")

    # build vocabs
    word_vocab = Vocab(word_vec_path, fileformat='txt3')
    best_path = path_prefix + '.best.model'
    char_path = path_prefix + ".char_vocab"
    label_path = path_prefix + ".label_vocab"
    POS_path = path_prefix + ".POS_vocab"
    NER_path = path_prefix + ".NER_vocab"
    has_pre_trained_model = False
    POS_vocab = None
    NER_vocab = None
    if os.path.exists(best_path):
        has_pre_trained_model = True
        label_vocab = Vocab(label_path, fileformat='txt2')
        char_vocab = Vocab(char_path, fileformat='txt2')
        if FLAGS.with_POS: POS_vocab = Vocab(POS_path, fileformat='txt2')
        if FLAGS.with_NER: NER_vocab = Vocab(NER_path, fileformat='txt2')
    else:
        print('Collect words, chars and labels ...')
        (all_words, all_chars, all_labels, all_POSs, all_NERs) = collect_vocabs(train_path, with_POS=FLAGS.with_POS, with_NER=FLAGS.with_NER)
        print('Number of words: {}'.format(len(all_words)))
        print('Number of labels: {}'.format(len(all_labels)))
        label_vocab = Vocab(fileformat='voc', voc=all_labels,dim=2)
        label_vocab.dump_to_txt2(label_path)

        print('Number of chars: {}'.format(len(all_chars)))
        char_vocab = Vocab(fileformat='voc', voc=all_chars,dim=FLAGS.char_emb_dim)
        char_vocab.dump_to_txt2(char_path)
        
        if FLAGS.with_POS:
            print('Number of POSs: {}'.format(len(all_POSs)))
            POS_vocab = Vocab(fileformat='voc', voc=all_POSs,dim=FLAGS.POS_dim)
            POS_vocab.dump_to_txt2(POS_path)
        if FLAGS.with_NER:
            print('Number of NERs: {}'.format(len(all_NERs)))
            NER_vocab = Vocab(fileformat='voc', voc=all_NERs,dim=FLAGS.NER_dim)
            NER_vocab.dump_to_txt2(NER_path)
            

    print('word_vocab shape is {}'.format(word_vocab.word_vecs.shape))
    print('tag_vocab shape is {}'.format(label_vocab.word_vecs.shape))
    num_classes = label_vocab.size()

    print('Build SentenceMatchDataStream ... ')
    trainDataStream = SentenceMatchDataStream(train_path, word_vocab=word_vocab, char_vocab=char_vocab, 
                                              POS_vocab=POS_vocab, NER_vocab=NER_vocab, label_vocab=label_vocab, 
                                              batch_size=FLAGS.batch_size, isShuffle=True, isLoop=True, isSort=True, 
                                              max_char_per_word=FLAGS.max_char_per_word, max_sent_length=FLAGS.max_sent_length,
                                              is_as=FLAGS.is_answer_selection)
                                    
    devDataStream = SentenceMatchDataStream(dev_path, word_vocab=word_vocab, char_vocab=char_vocab,
                                              POS_vocab=POS_vocab, NER_vocab=NER_vocab, label_vocab=label_vocab, 
                                              batch_size=FLAGS.batch_size, isShuffle=False, isLoop=True, isSort=True,
                                              max_char_per_word=FLAGS.max_char_per_word, max_sent_length=FLAGS.max_sent_length,
                                              is_as=FLAGS.is_answer_selection)

    testDataStream = SentenceMatchDataStream(test_path, word_vocab=word_vocab, char_vocab=char_vocab, 
                                              POS_vocab=POS_vocab, NER_vocab=NER_vocab, label_vocab=label_vocab, 
                                              batch_size=FLAGS.batch_size, isShuffle=False, isLoop=True, isSort=True,
                                              max_char_per_word=FLAGS.max_char_per_word, max_sent_length=FLAGS.max_sent_length,
                                              is_as=FLAGS.is_answer_selection)

    print('Number of instances in trainDataStream: {}'.format(trainDataStream.get_num_instance()))
    print('Number of instances in devDataStream: {}'.format(devDataStream.get_num_instance()))
    print('Number of instances in testDataStream: {}'.format(testDataStream.get_num_instance()))
    print('Number of batches in trainDataStream: {}'.format(trainDataStream.get_num_batch()))
    print('Number of batches in devDataStream: {}'.format(devDataStream.get_num_batch()))
    print('Number of batches in testDataStream: {}'.format(testDataStream.get_num_batch()))
    
    sys.stdout.flush()
    if FLAGS.wo_char: char_vocab = None
    output_res_index = 1
    while True:
        Generate_random_initialization()
        st_cuda = ''
        if FLAGS.is_server:
            st_cuda = str(os.environ['CUDA_VISIBLE_DEVICES']) + '.'
        output_res_file = open('../result/' + st_cuda + str(output_res_index), 'wt')
        output_res_index += 1
        output_res_file.write(str(FLAGS) + '\n\n')
        stt = str(FLAGS)
        best_accuracy = 0.0
        init_scale = 0.01
        with tf.Graph().as_default():
            initializer = tf.random_uniform_initializer(-init_scale, init_scale)
    #         with tf.name_scope("Train"):
            with tf.variable_scope("Model", reuse=None, initializer=initializer):
                train_graph = SentenceMatchModelGraph(num_classes, word_vocab=word_vocab, char_vocab=char_vocab, POS_vocab=POS_vocab, NER_vocab=NER_vocab,
                                                      dropout_rate=FLAGS.dropout_rate, learning_rate=FLAGS.learning_rate, optimize_type=FLAGS.optimize_type,
                                                      lambda_l2=FLAGS.lambda_l2, char_lstm_dim=FLAGS.char_lstm_dim, context_lstm_dim=FLAGS.context_lstm_dim,
                                                      aggregation_lstm_dim=FLAGS.aggregation_lstm_dim, is_training=True, MP_dim=FLAGS.MP_dim,
                                                      context_layer_num=FLAGS.context_layer_num, aggregation_layer_num=FLAGS.aggregation_layer_num,
                                                      fix_word_vec=FLAGS.fix_word_vec, with_filter_layer=FLAGS.with_filter_layer, with_input_highway=FLAGS.with_highway,
                                                      word_level_MP_dim=FLAGS.word_level_MP_dim,
                                                      with_match_highway=FLAGS.with_match_highway, with_aggregation_highway=FLAGS.with_aggregation_highway,
                                                      highway_layer_num=FLAGS.highway_layer_num, with_lex_decomposition=FLAGS.with_lex_decomposition,
                                                      lex_decompsition_dim=FLAGS.lex_decompsition_dim,
                                                      with_left_match=(not FLAGS.wo_left_match), with_right_match=(not FLAGS.wo_right_match),
                                                      with_full_match=(not FLAGS.wo_full_match), with_maxpool_match=(not FLAGS.wo_maxpool_match),
                                                      with_attentive_match=(not FLAGS.wo_attentive_match), with_max_attentive_match=(not FLAGS.wo_max_attentive_match),
                                                      with_bilinear_att=(FLAGS.attention_type)
                                                      , type1=FLAGS.type1, type2 = FLAGS.type2, type3=FLAGS.type3,
                                                      with_aggregation_attention=not FLAGS.wo_agg_self_att,
                                                      is_answer_selection= FLAGS.is_answer_selection,
                                                      is_shared_attention=FLAGS.is_shared_attention,
                                                      modify_loss=FLAGS.modify_loss, is_aggregation_lstm=FLAGS.is_aggregation_lstm
                                                      , max_window_size=FLAGS.max_window_size
                                                      , prediction_mode=FLAGS.prediction_mode,
                                                      context_lstm_dropout=not FLAGS.wo_lstm_drop_out,
                                                      is_aggregation_siamese=FLAGS.is_aggregation_siamese
                                                      , unstack_cnn=FLAGS.unstack_cnn,with_context_self_attention=FLAGS.with_context_self_attention)
                tf.summary.scalar("Training Loss", train_graph.get_loss()) # Add a scalar summary for the snapshot loss.

    #         with tf.name_scope("Valid"):
            with tf.variable_scope("Model", reuse=True, initializer=initializer):
                valid_graph = SentenceMatchModelGraph(num_classes, word_vocab=word_vocab, char_vocab=char_vocab, POS_vocab=POS_vocab, NER_vocab=NER_vocab,
                                                      dropout_rate=FLAGS.dropout_rate, learning_rate=FLAGS.learning_rate, optimize_type=FLAGS.optimize_type,
                                                      lambda_l2=FLAGS.lambda_l2, char_lstm_dim=FLAGS.char_lstm_dim, context_lstm_dim=FLAGS.context_lstm_dim,
                                                      aggregation_lstm_dim=FLAGS.aggregation_lstm_dim, is_training=False, MP_dim=FLAGS.MP_dim,
                                                      context_layer_num=FLAGS.context_layer_num, aggregation_layer_num=FLAGS.aggregation_layer_num,
                                                      fix_word_vec=FLAGS.fix_word_vec, with_filter_layer=FLAGS.with_filter_layer, with_input_highway=FLAGS.with_highway,
                                                      word_level_MP_dim=FLAGS.word_level_MP_dim,
                                                      with_match_highway=FLAGS.with_match_highway, with_aggregation_highway=FLAGS.with_aggregation_highway,
                                                      highway_layer_num=FLAGS.highway_layer_num, with_lex_decomposition=FLAGS.with_lex_decomposition,
                                                      lex_decompsition_dim=FLAGS.lex_decompsition_dim,
                                                      with_left_match=(not FLAGS.wo_left_match), with_right_match=(not FLAGS.wo_right_match),
                                                      with_full_match=(not FLAGS.wo_full_match), with_maxpool_match=(not FLAGS.wo_maxpool_match),
                                                      with_attentive_match=(not FLAGS.wo_attentive_match), with_max_attentive_match=(not FLAGS.wo_max_attentive_match),
                                                      with_bilinear_att=(FLAGS.attention_type)
                                                      , type1=FLAGS.type1, type2 = FLAGS.type2, type3=FLAGS.type3,
                                                      with_aggregation_attention=not FLAGS.wo_agg_self_att,
                                                      is_answer_selection= FLAGS.is_answer_selection,
                                                      is_shared_attention=FLAGS.is_shared_attention,
                                                      modify_loss=FLAGS.modify_loss, is_aggregation_lstm=FLAGS.is_aggregation_lstm,
                                                      max_window_size=FLAGS.max_window_size
                                                      , prediction_mode=FLAGS.prediction_mode,
                                                      context_lstm_dropout=not FLAGS.wo_lstm_drop_out,
                                                      is_aggregation_siamese=FLAGS.is_aggregation_siamese
                                                      , unstack_cnn=FLAGS.unstack_cnn,with_context_self_attention=FLAGS.with_context_self_attention)


            initializer = tf.global_variables_initializer()
            vars_ = {}
            #for var in tf.all_variables():
            for var in tf.global_variables():
                if "word_embedding" in var.name: continue
    #             if not var.name.startswith("Model"): continue
                vars_[var.name.split(":")[0]] = var
            saver = tf.train.Saver(vars_)

            with tf.Session() as sess:
                sess.run(initializer)
                if has_pre_trained_model:
                    print("Restoring model from " + best_path)
                    saver.restore(sess, best_path)
                    print("DONE!")

                print('Start the training loop.')
                train_size = trainDataStream.get_num_batch()
                max_steps = train_size * FLAGS.max_epochs
                total_loss = 0.0
                start_time = time.time()

                for step in xrange(max_steps):
                    # read data
                    cur_batch, batch_index = trainDataStream.nextBatch()
                    (label_batch, sent1_batch, sent2_batch, label_id_batch, word_idx_1_batch, word_idx_2_batch,
                                         char_matrix_idx_1_batch, char_matrix_idx_2_batch, sent1_length_batch, sent2_length_batch,
                                         sent1_char_length_batch, sent2_char_length_batch,
                                         POS_idx_1_batch, POS_idx_2_batch, NER_idx_1_batch, NER_idx_2_batch) = cur_batch
                    feed_dict = {
                                 train_graph.get_truth(): label_id_batch,
                                 train_graph.get_question_lengths(): sent1_length_batch,
                                 train_graph.get_passage_lengths(): sent2_length_batch,
                                 train_graph.get_in_question_words(): word_idx_1_batch,
                                 train_graph.get_in_passage_words(): word_idx_2_batch,
        #                          train_graph.get_question_char_lengths(): sent1_char_length_batch,
        #                          train_graph.get_passage_char_lengths(): sent2_char_length_batch,
        #                          train_graph.get_in_question_chars(): char_matrix_idx_1_batch,
        #                          train_graph.get_in_passage_chars(): char_matrix_idx_2_batch,
                                 }
                    if char_vocab is not None:
                        feed_dict[train_graph.get_question_char_lengths()] = sent1_char_length_batch
                        feed_dict[train_graph.get_passage_char_lengths()] = sent2_char_length_batch
                        feed_dict[train_graph.get_in_question_chars()] = char_matrix_idx_1_batch
                        feed_dict[train_graph.get_in_passage_chars()] = char_matrix_idx_2_batch

                    if POS_vocab is not None:
                        feed_dict[train_graph.get_in_question_poss()] = POS_idx_1_batch
                        feed_dict[train_graph.get_in_passage_poss()] = POS_idx_2_batch

                    if NER_vocab is not None:
                        feed_dict[train_graph.get_in_question_ners()] = NER_idx_1_batch
                        feed_dict[train_graph.get_in_passage_ners()] = NER_idx_2_batch

                    if FLAGS.is_answer_selection:
                        feed_dict[train_graph.get_question_count()] = trainDataStream.question_count(batch_index)
                        feed_dict[train_graph.get_answer_count()] = trainDataStream.answer_count(batch_index)

                    _, loss_value = sess.run([train_graph.get_train_op(), train_graph.get_loss()], feed_dict=feed_dict)
                    total_loss += loss_value
                    if FLAGS.is_answer_selection and not FLAGS.is_server:
                        print("q: {} a: {} loss_value: {}".format(
                            trainDataStream.question_count(batch_index),
                            trainDataStream.answer_count(batch_index), loss_value))

                    if step % 100==0:
                        print('{} '.format(step), end="")
                        sys.stdout.flush()

                    # Save a checkpoint and evaluate the model periodically.
                    if (step + 1) % trainDataStream.get_num_batch() == 0 or (step + 1) == max_steps:
                        #print(total_loss)
                        # Print status to stdout.
                        duration = time.time() - start_time
                        start_time = time.time()
                        output_res_file.write('Step %d: loss = %.2f (%.3f sec)\n' % (step, total_loss, duration))
                        total_loss = 0.0


                        #Evaluate against the validation set.
                        output_res_file.write('valid- ')
                        my_map, my_mrr = evaluate(devDataStream, valid_graph, sess,char_vocab=char_vocab,
                                            POS_vocab=POS_vocab, NER_vocab=NER_vocab, label_vocab=label_vocab)
                        output_res_file.write("map: '{}', mrr: '{}'\n".format(my_map, my_mrr))
                        #print ("dev map: {}".format(my_map))
                        #print("Current accuracy is %.2f" % accuracy)

                        #accuracy = my_map
                        #if accuracy>best_accuracy:
                        #    best_accuracy = accuracy
                        #    saver.save(sess, best_path)

                        # Evaluate against the test set.
                        output_res_file.write('test- ')
                        my_map, my_mrr = evaluate(testDataStream, valid_graph, sess, char_vocab=char_vocab,
                                 POS_vocab=POS_vocab, NER_vocab=NER_vocab, label_vocab=label_vocab)
                        output_res_file.write("map: '{}', mrr: '{}'\n\n".format(my_map, my_mrr))
                        if not FLAGS.is_server:
                            print("test map: {}".format(my_map))

                        #Evaluate against the train set only for final epoch.
                        if (step + 1) == max_steps:
                            output_res_file.write('train- ')
                            my_map, my_mrr = evaluate(trainDataStream, valid_graph, sess, char_vocab=char_vocab,
                                POS_vocab=POS_vocab, NER_vocab=NER_vocab, label_vocab=label_vocab)
                            output_res_file.write("map: '{}', mrr: '{}'\n".format(my_map, my_mrr))

        # print("Best accuracy on dev set is %.2f" % best_accuracy)
        # # decoding
        # print('Decoding on the test set:')
        # init_scale = 0.01
        # with tf.Graph().as_default():
        #     initializer = tf.random_uniform_initializer(-init_scale, init_scale)
        #     with tf.variable_scope("Model", reuse=False, initializer=initializer):
        #         valid_graph = SentenceMatchModelGraph(num_classes, word_vocab=word_vocab, char_vocab=char_vocab, POS_vocab=POS_vocab, NER_vocab=NER_vocab,
        #              dropout_rate=FLAGS.dropout_rate, learning_rate=FLAGS.learning_rate, optimize_type=FLAGS.optimize_type,
        #              lambda_l2=FLAGS.lambda_l2, char_lstm_dim=FLAGS.char_lstm_dim, context_lstm_dim=FLAGS.context_lstm_dim,
        #              aggregation_lstm_dim=FLAGS.aggregation_lstm_dim, is_training=False, MP_dim=FLAGS.MP_dim,
        #              context_layer_num=FLAGS.context_layer_num, aggregation_layer_num=FLAGS.aggregation_layer_num,
        #              fix_word_vec=FLAGS.fix_word_vec,with_filter_layer=FLAGS.with_filter_layer, with_highway=FLAGS.with_highway,
        #              word_level_MP_dim=FLAGS.word_level_MP_dim,
        #              with_match_highway=FLAGS.with_match_highway, with_aggregation_highway=FLAGS.with_aggregation_highway,
        #              highway_layer_num=FLAGS.highway_layer_num, with_lex_decomposition=FLAGS.with_lex_decomposition,
        #              lex_decompsition_dim=FLAGS.lex_decompsition_dim,
        #              with_left_match=(not FLAGS.wo_left_match), with_right_match=(not FLAGS.wo_right_match),
        #              with_full_match=(not FLAGS.wo_full_match), with_maxpool_match=(not FLAGS.wo_maxpool_match),
        #              with_attentive_match=(not FLAGS.wo_attentive_match), with_max_attentive_match=(not FLAGS.wo_max_attentive_match),
        #                                               with_bilinear_att=(not FLAGS.wo_bilinear_att)
        #                                               , type1=FLAGS.type1, type2 = FLAGS.type2, type3=FLAGS.type3,
        #                                               with_aggregation_attention=not FLAGS.wo_agg_self_att,
        #                                               is_answer_selection= FLAGS.is_answer_selection,
        #                                               is_shared_attention=FLAGS.is_shared_attention,
        #                                               modify_loss=FLAGS.modify_loss,is_aggregation_lstm=FLAGS.is_aggregation_lstm,
        #                                               max_window_size=FLAGS.max_window_size,
        #                                               prediction_mode=FLAGS.prediction_mode,
        #                                               context_lstm_dropout=not FLAGS.wo_lstm_drop_out,
        #                                              is_aggregation_siamese=FLAGS.is_aggregation_siamese)
        #
        #     vars_ = {}
        #     for var in tf.global_variables():
        #         if "word_embedding" in var.name: continue
        #         if not var.name.startswith("Model"): continue
        #         vars_[var.name.split(":")[0]] = var
        #     saver = tf.train.Saver(vars_)
        #
        #     sess = tf.Session()
        #     sess.run(tf.global_variables_initializer())
        #     step = 0
        #     saver.restore(sess, best_path)
        #
        #     accuracy, mrr = evaluate(testDataStream, valid_graph, sess,char_vocab=char_vocab,POS_vocab=POS_vocab, NER_vocab=NER_vocab, label_vocab=label_vocab
        #                         , mode='trec')
        #     output_res_file.write("map for test set is %.2f\n" % accuracy)
        output_res_file.close()
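
Generate_random_initialization is called at the top of every pass of the while-loop but is not defined in this snippet; judging from the commented-out debug print near the top of main, it resamples a handful of FLAGS fields for random hyper-parameter search. A hedged sketch follows; the candidate value lists are invented for illustration.

import random

def Generate_random_initialization():
    # Hypothetical sketch: resample the hyper-parameters listed in the
    # commented-out debug loop above so each while-loop pass trains a new
    # random configuration.  Candidate lists are illustrative only.
    FLAGS.is_aggregation_lstm = random.choice([True, False])
    FLAGS.context_lstm_dim = random.choice([50, 100, 200])
    FLAGS.context_layer_num = random.choice([1, 2])
    FLAGS.aggregation_lstm_dim = random.choice([50, 100, 200])
    FLAGS.aggregation_layer_num = random.choice([1, 2])
    FLAGS.max_window_size = random.choice([1, 2, 3])
    FLAGS.MP_dim = random.choice([10, 20, 50])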
Example #6
def main(_):
    print('Configurations:')
    print(FLAGS)

    log_dir = FLAGS.model_dir
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    path_prefix = log_dir + "/G2S.{}".format(FLAGS.suffix)
    log_file_path = path_prefix + ".log"
    print('Log file path: {}'.format(log_file_path))
    log_file = open(log_file_path, 'wt')
    log_file.write("{}\n".format(FLAGS))
    log_file.flush()

    # save configuration
    namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json")

    print('Loading train set.')
    if FLAGS.infile_format == 'fof':
        trainset, trn_node, trn_in_neigh, trn_out_neigh, trn_sent = G2S_data_stream.read_nary_from_fof(
            FLAGS.train_path, FLAGS)
    else:
        trainset, trn_node, trn_in_neigh, trn_out_neigh, trn_sent = G2S_data_stream.read_nary_file(
            FLAGS.train_path, FLAGS)

    random.shuffle(trainset)
    devset = trainset[:200]
    trainset = trainset[200:]

    print('Number of training samples: {}'.format(len(trainset)))
    print('Number of dev samples: {}'.format(len(devset)))

    max_node = trn_node
    max_in_neigh = trn_in_neigh
    max_out_neigh = trn_out_neigh
    max_sent = trn_sent
    print('Max node number: {}, while max allowed is {}'.format(
        max_node, FLAGS.max_node_num))
    print('Max parent number: {}, truncated to {}'.format(
        max_in_neigh, FLAGS.max_in_neigh_num))
    print('Max children number: {}, truncated to {}'.format(
        max_out_neigh, FLAGS.max_out_neigh_num))
    print('Max entity size: {}, truncated to {}'.format(
        max_sent, FLAGS.max_entity_size))

    word_vocab = None
    char_vocab = None
    edgelabel_vocab = None
    has_pretrained_model = False
    best_path = path_prefix + ".best.model"
    if os.path.exists(best_path + ".index"):
        has_pretrained_model = True
        print('!!Existing pretrained model. Loading vocabs.')
        word_vocab = Vocab(FLAGS.word_vec_path, fileformat='txt2')
        print('word_vocab: {}'.format(word_vocab.word_vecs.shape))
        char_vocab = None
        if FLAGS.with_char:
            char_vocab = Vocab(path_prefix + ".char_vocab", fileformat='txt2')
            print('char_vocab: {}'.format(char_vocab.word_vecs.shape))
        edgelabel_vocab = Vocab(path_prefix + ".edgelabel_vocab",
                                fileformat='txt2')
    else:
        print('Collecting vocabs.')
        (allWords, allChars,
         allEdgelabels) = G2S_data_stream.collect_vocabs(trainset)
        print('Number of words: {}'.format(len(allWords)))
        print('Number of allChars: {}'.format(len(allChars)))
        print('Number of allEdgelabels: {}'.format(len(allEdgelabels)))

        word_vocab = Vocab(FLAGS.word_vec_path, fileformat='txt2')
        char_vocab = None
        if FLAGS.with_char:
            char_vocab = Vocab(voc=allChars,
                               dim=FLAGS.char_dim,
                               fileformat='build')
            char_vocab.dump_to_txt2(path_prefix + ".char_vocab")
        edgelabel_vocab = Vocab(voc=allEdgelabels,
                                dim=FLAGS.edgelabel_dim,
                                fileformat='build')
        edgelabel_vocab.dump_to_txt2(path_prefix + ".edgelabel_vocab")

    print('word vocab size {}'.format(word_vocab.vocab_size))
    sys.stdout.flush()

    print('Build DataStream ... ')
    trainDataStream = G2S_data_stream.G2SDataStream(trainset,
                                                    word_vocab,
                                                    char_vocab,
                                                    edgelabel_vocab,
                                                    options=FLAGS,
                                                    isShuffle=True,
                                                    isLoop=True,
                                                    isSort=False)

    devDataStream = G2S_data_stream.G2SDataStream(devset,
                                                  word_vocab,
                                                  char_vocab,
                                                  edgelabel_vocab,
                                                  options=FLAGS,
                                                  isShuffle=False,
                                                  isLoop=False,
                                                  isSort=False)
    print('Number of instances in trainDataStream: {}'.format(
        trainDataStream.get_num_instance()))
    print('Number of instances in devDataStream: {}'.format(
        devDataStream.get_num_instance()))
    print('Number of batches in trainDataStream: {}'.format(
        trainDataStream.get_num_batch()))
    print('Number of batches in devDataStream: {}'.format(
        devDataStream.get_num_batch()))
    sys.stdout.flush()

    # initialize the best bleu and accu scores for current training session
    best_accu = FLAGS.best_accu if 'best_accu' in FLAGS.__dict__ else 0.0
    if best_accu > 0.0:
        print('With initial dev accuracy {}'.format(best_accu))

    init_scale = 0.01
    with tf.Graph().as_default():
        initializer = tf.random_uniform_initializer(-init_scale, init_scale)
        with tf.name_scope("Train"):
            with tf.variable_scope("Model",
                                   reuse=None,
                                   initializer=initializer):
                train_graph = ModelGraph(word_vocab=word_vocab,
                                         Edgelabel_vocab=edgelabel_vocab,
                                         char_vocab=char_vocab,
                                         options=FLAGS,
                                         mode='train')

        with tf.name_scope("Valid"):
            with tf.variable_scope("Model",
                                   reuse=True,
                                   initializer=initializer):
                valid_graph = ModelGraph(word_vocab=word_vocab,
                                         Edgelabel_vocab=edgelabel_vocab,
                                         char_vocab=char_vocab,
                                         options=FLAGS,
                                         mode='evaluate')

        initializer = tf.global_variables_initializer()

        vars_ = {}
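        # collect only model-scope variables for the saver, skipping the word-embedding matrix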
        for var in tf.all_variables():
            if "word_embedding" in var.name: continue
            if not var.name.startswith("Model"): continue
            vars_[var.name.split(":")[0]] = var
        saver = tf.train.Saver(vars_)

        sess = tf.Session()
        sess.run(initializer)
        if has_pretrained_model:
            print("Restoring model from " + best_path)
            saver.restore(sess, best_path)
            print("DONE!")

            if abs(best_accu) < 0.00001:
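                # the restored config has no recorded dev accuracy; run one evaluation to set the baseline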
                print("Getting ACCU score for the model")
                best_accu = evaluate(sess,
                                     valid_graph,
                                     devDataStream,
                                     options=FLAGS)['dev_accu']
                FLAGS.best_accu = best_accu
                namespace_utils.save_namespace(FLAGS,
                                               path_prefix + ".config.json")
                print('ACCU = %.4f' % best_accu)
                log_file.write('ACCU = %.4f\n' % best_accu)

        print('Start the training loop.')
        train_size = trainDataStream.get_num_batch()
        max_steps = train_size * FLAGS.max_epochs
        last_step = 0
        total_loss = 0.0
        start_time = time.time()
        for step in range(max_steps):
            cur_batch = trainDataStream.nextBatch()
            _, loss_value, _ = train_graph.execute(sess,
                                                   cur_batch,
                                                   FLAGS,
                                                   is_train=True)
            total_loss += loss_value

            if step % 100 == 0:
                print('{} '.format(step), end="")
                sys.stdout.flush()

            # Save a checkpoint and evaluate the model periodically.
            if (step + 1) % trainDataStream.get_num_batch() == 0 or (
                    step + 1) == max_steps:
                print()
                duration = time.time() - start_time
                print('Step %d: loss = %.2f (%.3f sec)' %
                      (step, total_loss / (step - last_step), duration))
                log_file.write('Step %d: loss = %.2f (%.3f sec)\n' %
                               (step, total_loss /
                                (step - last_step), duration))
                sys.stdout.flush()
                log_file.flush()
                last_step = step
                total_loss = 0.0

                # Evaluate against the validation set.
                start_time = time.time()
                print('Validation Data Eval:')
                res_dict = evaluate(sess,
                                    valid_graph,
                                    devDataStream,
                                    options=FLAGS,
                                    suffix=str(step))
                dev_loss = res_dict['dev_loss']
                dev_accu = res_dict['dev_accu']
                dev_right = int(res_dict['dev_right'])
                dev_total = int(res_dict['dev_total'])
                print('Dev loss = %.4f' % dev_loss)
                log_file.write('Dev loss = %.4f\n' % dev_loss)
                print('Dev accu = %.4f %d/%d' %
                      (dev_accu, dev_right, dev_total))
                log_file.write('Dev accu = %.4f %d/%d\n' %
                               (dev_accu, dev_right, dev_total))
                log_file.flush()
                if best_accu < dev_accu:
                    print('Saving weights, ACCU {} (prev_best) < {} (cur)'.
                          format(best_accu, dev_accu))
                    saver.save(sess, best_path)
                    best_accu = dev_accu
                    FLAGS.best_accu = dev_accu
                    namespace_utils.save_namespace(
                        FLAGS, path_prefix + ".config.json")
                    with open(FLAGS.output_path, 'w') as output_fp:
                        json.dump(res_dict['data'], output_fp)
                duration = time.time() - start_time
                print('Duration %.3f sec' % (duration))
                sys.stdout.flush()

                log_file.write('Duration %.3f sec\n' % (duration))
                log_file.flush()

    log_file.close()
Example #7
def main(FLAGS):
    train_path = FLAGS.train_path
    dev_path = FLAGS.dev_path
    test_path = FLAGS.test_path
    dev_path_target = FLAGS.dev_path_target
    test_path_target = FLAGS.test_path_target
    word_vec_path = FLAGS.word_vec_path
    log_dir = FLAGS.model_dir
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
        os.makedirs(os.path.join(log_dir, '../result_source'))
        os.makedirs(os.path.join(log_dir, '../logits_source'))
        os.makedirs(os.path.join(log_dir, '../result_target'))
        os.makedirs(os.path.join(log_dir, '../logits_target'))

    log_dir_target = FLAGS.model_dir + '_target'
    if not os.path.exists(log_dir_target):
        os.makedirs(log_dir_target)

    path_prefix = log_dir + "/ESIM.{}".format(FLAGS.suffix)
    namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json")
    path_prefix_target = log_dir_target + "/ESIM.{}".format(FLAGS.suffix)
    namespace_utils.save_namespace(FLAGS, path_prefix_target + ".config.json")
    # build vocabs
    word_vocab = Vocab(word_vec_path, fileformat='txt3')

    best_path = path_prefix + '.best.model'
    best_path_target = path_prefix_target + '.best.model'
    char_path = path_prefix + ".char_vocab"
    label_path = path_prefix + ".label_vocab"
    has_pre_trained_model = False
    char_vocab = None
    # if os.path.exists(best_path + ".index"):
    print('Collecting words, chars and labels ...')
    (all_words, all_chars, all_labels, all_POSs,
     all_NERs) = collect_vocabs(train_path)
    print('Number of words: {}'.format(len(all_words)))
    label_vocab = Vocab(fileformat='voc', voc=all_labels, dim=2)
    label_vocab.dump_to_txt2(label_path)

    print('word_vocab shape is {}'.format(word_vocab.word_vecs.shape))
    num_classes = label_vocab.size()
    print("Number of labels: {}".format(num_classes))
    sys.stdout.flush()

    print('Build SentenceMatchDataStream ... ')
    trainDataStream = DataStream(train_path,
                                 word_vocab=word_vocab,
                                 label_vocab=None,
                                 isShuffle=True,
                                 isLoop=True,
                                 isSort=True,
                                 options=FLAGS)
    print('Number of instances in trainDataStream: {}'.format(
        trainDataStream.get_num_instance()))
    print('Number of batches in trainDataStream: {}'.format(
        trainDataStream.get_num_batch()))
    sys.stdout.flush()

    devDataStream = DataStream(dev_path,
                               word_vocab=word_vocab,
                               label_vocab=None,
                               isShuffle=True,
                               isLoop=True,
                               isSort=True,
                               options=FLAGS)
    print('Number of instances in devDataStream: {}'.format(
        devDataStream.get_num_instance()))
    print('Number of batches in devDataStream: {}'.format(
        devDataStream.get_num_batch()))
    sys.stdout.flush()

    testDataStream = DataStream(test_path,
                                word_vocab=word_vocab,
                                label_vocab=None,
                                isShuffle=True,
                                isLoop=True,
                                isSort=True,
                                options=FLAGS)

    print('Number of instances in testDataStream: {}'.format(
        testDataStream.get_num_instance()))
    print('Number of batches in testDataStream: {}'.format(
        testDataStream.get_num_batch()))
    sys.stdout.flush()

    devDataStream_target = DataStream(dev_path_target,
                                      word_vocab=word_vocab,
                                      label_vocab=None,
                                      isShuffle=True,
                                      isLoop=True,
                                      isSort=True,
                                      options=FLAGS)
    print('Number of instances in devDataStream_target: {}'.format(
        devDataStream_target.get_num_instance()))
    print('Number of batches in devDataStream_target: {}'.format(
        devDataStream_target.get_num_batch()))
    sys.stdout.flush()

    testDataStream_target = DataStream(test_path_target,
                                       word_vocab=word_vocab,
                                       label_vocab=None,
                                       isShuffle=True,
                                       isLoop=True,
                                       isSort=True,
                                       options=FLAGS)

    print('Number of instances in testDataStream_target: {}'.format(
        testDataStream_target.get_num_instance()))
    print('Number of batches in testDataStream_target: {}'.format(
        testDataStream_target.get_num_batch()))
    sys.stdout.flush()

    with tf.Graph().as_default():
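        # the training and evaluation graphs share parameters through variable-scope reuse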
        initializer = tf.contrib.layers.xavier_initializer()
        global_step = tf.train.get_or_create_global_step()
        with tf.variable_scope("Model", reuse=None, initializer=initializer):
            train_graph = Model(num_classes,
                                word_vocab=word_vocab,
                                is_training=True,
                                options=FLAGS,
                                global_step=global_step)
        with tf.variable_scope("Model", reuse=True, initializer=initializer):
            valid_graph = Model(num_classes,
                                word_vocab=word_vocab,
                                is_training=False,
                                options=FLAGS)

        initializer = tf.global_variables_initializer()
        vars_ = {}
        for var in tf.global_variables():
            if "word_embedding" in var.name: continue
            #             if not var.name.startswith("Model"): continue
            vars_[var.name.split(":")[0]] = var
        saver = tf.train.Saver(vars_)

        gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=1)
        config = tf.ConfigProto(allow_soft_placement=True,
                                gpu_options=gpu_options)

        config.gpu_options.allow_growth = True
        with tf.Session(config=config) as sess:
            sess.run(initializer)

            # training
            train(sess, saver, train_graph, valid_graph, trainDataStream,
                  devDataStream, testDataStream, devDataStream_target,
                  testDataStream_target, FLAGS, best_path, best_path_target)
Example #8
def main(_):
    print('Configurations:')
    print(FLAGS)

    train_path = FLAGS.train_path
    dev_path = FLAGS.dev_path
    test_path = FLAGS.test_path
    word_vec_path = FLAGS.word_vec_path
    log_dir = FLAGS.model_dir
    tolower = FLAGS.use_lower_letter
    FLAGS.rl_matches = json.loads(FLAGS.rl_matches)

    # if not os.path.exists(log_dir):
    #     os.makedirs(log_dir)

    path_prefix = log_dir + "/TriMatch.{}".format(FLAGS.suffix)

    namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json")

    # build vocabs
    word_vocab = Vocab(word_vec_path, fileformat='txt3', tolower=tolower)
    best_path = path_prefix + '.best.model'
    char_path = path_prefix + ".char_vocab"
    label_path = path_prefix + ".label_vocab"
    POS_path = path_prefix + ".POS_vocab"
    NER_path = path_prefix + ".NER_vocab"
    has_pre_trained_model = False
    POS_vocab = None
    NER_vocab = None

    print('best path:', best_path)
    if os.path.exists(best_path +
                      '.data-00000-of-00001') and not (FLAGS.create_new_model):
        print('Using pretrained model')
        has_pre_trained_model = True
        label_vocab = Vocab(label_path, fileformat='txt2', tolower=tolower)
        char_vocab = Vocab(char_path, fileformat='txt2', tolower=tolower)
        if FLAGS.with_POS:
            POS_vocab = Vocab(POS_path, fileformat='txt2', tolower=tolower)
        if FLAGS.with_NER:
            NER_vocab = Vocab(NER_path, fileformat='txt2', tolower=tolower)
    else:
        print('Creating new model')
        print('Collect words, chars and labels ...')
        (all_words, all_chars, all_labels, all_POSs,
         all_NERs) = collect_vocabs(train_path,
                                    with_POS=FLAGS.with_POS,
                                    with_NER=FLAGS.with_NER,
                                    tolower=tolower)
        if FLAGS.use_options:
            all_labels = ['0', '1']
        print('Number of words: {}'.format(len(all_words)))
        print('Number of labels: {}'.format(len(all_labels)))
        # for word in all_labels:
        #     print('label',word)
        # input('check')

        label_vocab = Vocab(fileformat='voc',
                            voc=all_labels,
                            dim=2,
                            tolower=tolower)
        label_vocab.dump_to_txt2(label_path)

        print('Number of chars: {}'.format(len(all_chars)))
        char_vocab = Vocab(fileformat='voc',
                           voc=all_chars,
                           dim=FLAGS.char_emb_dim,
                           tolower=tolower)
        char_vocab.dump_to_txt2(char_path)

        if FLAGS.with_POS:
            print('Number of POSs: {}'.format(len(all_POSs)))
            POS_vocab = Vocab(fileformat='voc',
                              voc=all_POSs,
                              dim=FLAGS.POS_dim,
                              tolower=tolower)
            POS_vocab.dump_to_txt2(POS_path)
        if FLAGS.with_NER:
            print('Number of NERs: {}'.format(len(all_NERs)))
            NER_vocab = Vocab(fileformat='voc',
                              voc=all_NERs,
                              dim=FLAGS.NER_dim,
                              tolower=tolower)
            NER_vocab.dump_to_txt2(NER_path)

    print('all_labels:', label_vocab)
    print('has pretrained model:', has_pre_trained_model)
    # for word in word_vocab.word_vecs:

    print('word_vocab shape is {}'.format(word_vocab.word_vecs.shape))
    print('tag_vocab shape is {}'.format(label_vocab.word_vecs.shape))
    num_classes = label_vocab.size()

    print('Build TriMatchDataStream ... ')

    gen_concat_mat = False
    gen_split_mat = False
    if FLAGS.matching_option == 7:
        gen_concat_mat = True
        if FLAGS.concat_context:
            gen_split_mat = True
    trainDataStream = TriMatchDataStream(
        train_path,
        word_vocab=word_vocab,
        char_vocab=char_vocab,
        POS_vocab=POS_vocab,
        NER_vocab=NER_vocab,
        label_vocab=label_vocab,
        batch_size=FLAGS.batch_size,
        isShuffle=True,
        isLoop=True,
        isSort=(not FLAGS.wo_sort_instance_based_on_length),
        max_char_per_word=FLAGS.max_char_per_word,
        max_sent_length=FLAGS.max_sent_length,
        max_hyp_length=FLAGS.max_hyp_length,
        max_choice_length=FLAGS.max_choice_length,
        tolower=tolower,
        gen_concat_mat=gen_concat_mat,
        gen_split_mat=gen_split_mat)

    devDataStream = TriMatchDataStream(
        dev_path,
        word_vocab=word_vocab,
        char_vocab=char_vocab,
        POS_vocab=POS_vocab,
        NER_vocab=NER_vocab,
        label_vocab=label_vocab,
        batch_size=FLAGS.batch_size,
        isShuffle=False,
        isLoop=True,
        isSort=(not FLAGS.wo_sort_instance_based_on_length),
        max_char_per_word=FLAGS.max_char_per_word,
        max_sent_length=FLAGS.max_sent_length,
        max_hyp_length=FLAGS.max_hyp_length,
        max_choice_length=FLAGS.max_choice_length,
        tolower=tolower,
        gen_concat_mat=gen_concat_mat,
        gen_split_mat=gen_split_mat)

    testDataStream = TriMatchDataStream(
        test_path,
        word_vocab=word_vocab,
        char_vocab=char_vocab,
        POS_vocab=POS_vocab,
        NER_vocab=NER_vocab,
        label_vocab=label_vocab,
        batch_size=FLAGS.batch_size,
        isShuffle=False,
        isLoop=True,
        isSort=(not FLAGS.wo_sort_instance_based_on_length),
        max_char_per_word=FLAGS.max_char_per_word,
        max_sent_length=FLAGS.max_sent_length,
        max_hyp_length=FLAGS.max_hyp_length,
        max_choice_length=FLAGS.max_choice_length,
        tolower=tolower,
        gen_concat_mat=gen_concat_mat,
        gen_split_mat=gen_split_mat)

    print('Number of instances in trainDataStream: {}'.format(
        trainDataStream.get_num_instance()))
    print('Number of instances in devDataStream: {}'.format(
        devDataStream.get_num_instance()))
    print('Number of instances in testDataStream: {}'.format(
        testDataStream.get_num_instance()))
    print('Number of batches in trainDataStream: {}'.format(
        trainDataStream.get_num_batch()))
    print('Number of batches in devDataStream: {}'.format(
        devDataStream.get_num_batch()))
    print('Number of batches in testDataStream: {}'.format(
        testDataStream.get_num_batch()))

    sys.stdout.flush()
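    # discard the char vocab when character-level features are disabled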
    if FLAGS.wo_char: char_vocab = None

    best_accuracy = 0.0
    init_scale = 0.01
    with tf.Graph().as_default():
        initializer = tf.random_uniform_initializer(-init_scale, init_scale)
        #         with tf.name_scope("Train"):
        with tf.variable_scope("Model", reuse=None, initializer=initializer):
            train_graph = TriMatchModelGraph(
                num_classes,
                word_vocab=word_vocab,
                char_vocab=char_vocab,
                POS_vocab=POS_vocab,
                NER_vocab=NER_vocab,
                dropout_rate=FLAGS.dropout_rate,
                learning_rate=FLAGS.learning_rate,
                optimize_type=FLAGS.optimize_type,
                lambda_l2=FLAGS.lambda_l2,
                char_lstm_dim=FLAGS.char_lstm_dim,
                context_lstm_dim=FLAGS.context_lstm_dim,
                aggregation_lstm_dim=FLAGS.aggregation_lstm_dim,
                is_training=True,
                MP_dim=FLAGS.MP_dim,
                context_layer_num=FLAGS.context_layer_num,
                aggregation_layer_num=FLAGS.aggregation_layer_num,
                fix_word_vec=FLAGS.fix_word_vec,
                with_highway=FLAGS.with_highway,
                word_level_MP_dim=FLAGS.word_level_MP_dim,
                with_match_highway=FLAGS.with_match_highway,
                with_aggregation_highway=FLAGS.with_aggregation_highway,
                highway_layer_num=FLAGS.highway_layer_num,
                match_to_question=FLAGS.match_to_question,
                match_to_passage=FLAGS.match_to_passage,
                match_to_choice=FLAGS.match_to_choice,
                with_full_match=(not FLAGS.wo_full_match),
                with_maxpool_match=(not FLAGS.wo_maxpool_match),
                with_attentive_match=(not FLAGS.wo_attentive_match),
                with_max_attentive_match=(not FLAGS.wo_max_attentive_match),
                use_options=FLAGS.use_options,
                num_options=num_options,
                with_no_match=FLAGS.with_no_match,
                verbose=FLAGS.verbose,
                matching_option=FLAGS.matching_option,
                concat_context=FLAGS.concat_context,
                tied_aggre=FLAGS.tied_aggre,
                rl_training_method=FLAGS.rl_training_method,
                rl_matches=FLAGS.rl_matches)

            tf.summary.scalar("Training Loss", train_graph.get_loss()
                              )  # Add a scalar summary for the snapshot loss.
        if FLAGS.verbose:
            valid_graph = train_graph
        else:
            #         with tf.name_scope("Valid"):
            with tf.variable_scope("Model",
                                   reuse=True,
                                   initializer=initializer):
                valid_graph = TriMatchModelGraph(
                    num_classes,
                    word_vocab=word_vocab,
                    char_vocab=char_vocab,
                    POS_vocab=POS_vocab,
                    NER_vocab=NER_vocab,
                    dropout_rate=FLAGS.dropout_rate,
                    learning_rate=FLAGS.learning_rate,
                    optimize_type=FLAGS.optimize_type,
                    lambda_l2=FLAGS.lambda_l2,
                    char_lstm_dim=FLAGS.char_lstm_dim,
                    context_lstm_dim=FLAGS.context_lstm_dim,
                    aggregation_lstm_dim=FLAGS.aggregation_lstm_dim,
                    is_training=False,
                    MP_dim=FLAGS.MP_dim,
                    context_layer_num=FLAGS.context_layer_num,
                    aggregation_layer_num=FLAGS.aggregation_layer_num,
                    fix_word_vec=FLAGS.fix_word_vec,
                    with_highway=FLAGS.with_highway,
                    word_level_MP_dim=FLAGS.word_level_MP_dim,
                    with_match_highway=FLAGS.with_match_highway,
                    with_aggregation_highway=FLAGS.with_aggregation_highway,
                    highway_layer_num=FLAGS.highway_layer_num,
                    match_to_question=FLAGS.match_to_question,
                    match_to_passage=FLAGS.match_to_passage,
                    match_to_choice=FLAGS.match_to_choice,
                    with_full_match=(not FLAGS.wo_full_match),
                    with_maxpool_match=(not FLAGS.wo_maxpool_match),
                    with_attentive_match=(not FLAGS.wo_attentive_match),
                    with_max_attentive_match=(
                        not FLAGS.wo_max_attentive_match),
                    use_options=FLAGS.use_options,
                    num_options=num_options,
                    with_no_match=FLAGS.with_no_match,
                    matching_option=FLAGS.matching_option,
                    concat_context=FLAGS.concat_context,
                    tied_aggre=FLAGS.tied_aggre,
                    rl_training_method=FLAGS.rl_training_method,
                    rl_matches=FLAGS.rl_matches)

        initializer = tf.global_variables_initializer()
        vars_ = {}
        for var in tf.global_variables():
            # print(var.name,var.get_shape().as_list())
            if "word_embedding" in var.name: continue
            #             if not var.name.startswith("Model"): continue
            vars_[var.name.split(":")[0]] = var
        saver = tf.train.Saver(vars_)
        # input('check')

        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        sess = tf.Session(config=config)
        sess.run(initializer)
        if has_pre_trained_model:
            print("Restoring model from " + best_path)
            saver.restore(sess, best_path)
            print("DONE!")

        print('Start the training loop.')
        train_size = trainDataStream.get_num_batch()
        max_steps = train_size * FLAGS.max_epochs
        total_loss = 0.0
        start_time = time.time()
        sub_loss_counter = 0.0
        for step in range(max_steps):
            # read data
            cur_batch = trainDataStream.nextBatch()
            (label_batch, sent1_batch, sent2_batch, sent3_batch,
             label_id_batch, word_idx_1_batch, word_idx_2_batch,
             word_idx_3_batch, char_matrix_idx_1_batch,
             char_matrix_idx_2_batch, char_matrix_idx_3_batch,
             sent1_length_batch, sent2_length_batch, sent3_length_batch,
             sent1_char_length_batch, sent2_char_length_batch,
             sent3_char_length_batch, POS_idx_1_batch, POS_idx_2_batch,
             NER_idx_1_batch, NER_idx_2_batch, concat_mat_batch,
             split_mat_batch_q, split_mat_batch_c) = cur_batch

            # print(label_id_batch)
            if FLAGS.verbose:
                print(label_id_batch)
                print(sent1_length_batch)
                print(sent2_length_batch)
                print(sent3_length_batch)
                # print(word_idx_1_batch)
                # print(word_idx_2_batch)
                # print(word_idx_3_batch)
                # print(sent1_batch)
                # print(sent2_batch)
                # print(sent3_batch)
                print(concat_mat_batch)
                input('check')
            feed_dict = {
                train_graph.get_truth(): label_id_batch,
                train_graph.get_passage_lengths(): sent1_length_batch,
                train_graph.get_question_lengths(): sent2_length_batch,
                train_graph.get_choice_lengths(): sent3_length_batch,
                train_graph.get_in_passage_words(): word_idx_1_batch,
                train_graph.get_in_question_words(): word_idx_2_batch,
                train_graph.get_in_choice_words(): word_idx_3_batch,
                #                          train_graph.get_question_char_lengths(): sent1_char_length_batch,
                #                          train_graph.get_passage_char_lengths(): sent2_char_length_batch,
                #                          train_graph.get_in_question_chars(): char_matrix_idx_1_batch,
                #                          train_graph.get_in_passage_chars(): char_matrix_idx_2_batch,
            }
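            # optional inputs (char, POS, NER, concat/split matrices) are fed only when they exist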
            if char_vocab is not None:
                feed_dict[train_graph.get_passage_char_lengths(
                )] = sent1_char_length_batch
                feed_dict[train_graph.get_question_char_lengths(
                )] = sent2_char_length_batch
                feed_dict[train_graph.get_choice_char_lengths(
                )] = sent3_char_length_batch
                feed_dict[train_graph.get_in_passage_chars(
                )] = char_matrix_idx_1_batch
                feed_dict[train_graph.get_in_question_chars(
                )] = char_matrix_idx_2_batch
                feed_dict[train_graph.get_in_choice_chars(
                )] = char_matrix_idx_3_batch

            if POS_vocab is not None:
                feed_dict[train_graph.get_in_passage_poss()] = POS_idx_1_batch
                feed_dict[train_graph.get_in_question_poss()] = POS_idx_2_batch

            if NER_vocab is not None:
                feed_dict[train_graph.get_in_passage_ners()] = NER_idx_1_batch
                feed_dict[train_graph.get_in_question_ners()] = NER_idx_2_batch
            if concat_mat_batch is not None:
                feed_dict[train_graph.concat_idx_mat] = concat_mat_batch
            if split_mat_batch_q is not None:
                feed_dict[train_graph.split_idx_mat_q] = split_mat_batch_q
                feed_dict[train_graph.split_idx_mat_c] = split_mat_batch_c

            if FLAGS.verbose:
                return_list = sess.run([
                    train_graph.get_train_op(),
                    train_graph.get_loss(),
                    train_graph.get_predictions(),
                    train_graph.get_prob(), train_graph.all_probs,
                    train_graph.correct
                ] + train_graph.matching_vectors,
                                       feed_dict=feed_dict)
                print(len(return_list))
                _, loss_value, pred, prob, all_probs, correct = return_list[
                    0:6]
                print('pred=', pred)
                print('prob=', prob)
                print('logits=', all_probs)
                print('correct=', correct)
                for val in return_list[6:]:
                    if isinstance(val, list):
                        print('list len ', len(val))
                        for objj in val:
                            print('this shape=', objj.shape)
                    else:
                        print('this shape=', val.shape)
                    # print(val)
                input('check')
            else:
                _, loss_value = sess.run(
                    [train_graph.get_train_op(),
                     train_graph.get_loss()],
                    feed_dict=feed_dict)
            total_loss += loss_value
            sub_loss_counter += loss_value

            if step % int(FLAGS.display_every) == 0:
                print('{},{} '.format(step, sub_loss_counter), end="")
                sys.stdout.flush()
                sub_loss_counter = 0.0

            # Save a checkpoint and evaluate the model periodically.
            if (step + 1) % trainDataStream.get_num_batch() == 0 or (
                    step + 1) == max_steps:
                print()
                # Print status to stdout.
                duration = time.time() - start_time
                start_time = time.time()
                print('Step %d: loss = %.2f (%.3f sec)' %
                      (step, total_loss, duration))
                total_loss = 0.0

                # Evaluate against the validation set.
                print('Validation Data Eval:')
                if FLAGS.predict_val:
                    outpath = path_prefix + '.iter%d' % (step) + '.probs'
                else:
                    outpath = None
                accuracy = evaluate(devDataStream,
                                    valid_graph,
                                    sess,
                                    char_vocab=char_vocab,
                                    POS_vocab=POS_vocab,
                                    NER_vocab=NER_vocab,
                                    use_options=FLAGS.use_options,
                                    outpath=outpath,
                                    mode='prob')
                print("Current accuracy on dev set is %.2f" % accuracy)
                if accuracy >= best_accuracy:
                    best_accuracy = accuracy
                    saver.save(sess, best_path)
                    print('saving the current model.')
                accuracy = evaluate(testDataStream,
                                    valid_graph,
                                    sess,
                                    char_vocab=char_vocab,
                                    POS_vocab=POS_vocab,
                                    NER_vocab=NER_vocab,
                                    use_options=FLAGS.use_options,
                                    outpath=outpath,
                                    mode='prob')
                print("Current accuracy on test set is %.2f" % accuracy)

    print("Best accuracy on dev set is %.2f" % best_accuracy)
    # decoding
    print('Decoding on the test set:')
    init_scale = 0.01
    with tf.Graph().as_default():
        initializer = tf.random_uniform_initializer(-init_scale, init_scale)
        with tf.variable_scope("Model", reuse=False, initializer=initializer):
            valid_graph = TriMatchModelGraph(
                num_classes,
                word_vocab=word_vocab,
                char_vocab=char_vocab,
                POS_vocab=POS_vocab,
                NER_vocab=NER_vocab,
                dropout_rate=FLAGS.dropout_rate,
                learning_rate=FLAGS.learning_rate,
                optimize_type=FLAGS.optimize_type,
                lambda_l2=FLAGS.lambda_l2,
                char_lstm_dim=FLAGS.char_lstm_dim,
                context_lstm_dim=FLAGS.context_lstm_dim,
                aggregation_lstm_dim=FLAGS.aggregation_lstm_dim,
                is_training=False,
                MP_dim=FLAGS.MP_dim,
                context_layer_num=FLAGS.context_layer_num,
                aggregation_layer_num=FLAGS.aggregation_layer_num,
                fix_word_vec=FLAGS.fix_word_vec,
                with_highway=FLAGS.with_highway,
                word_level_MP_dim=FLAGS.word_level_MP_dim,
                with_match_highway=FLAGS.with_match_highway,
                with_aggregation_highway=FLAGS.with_aggregation_highway,
                highway_layer_num=FLAGS.highway_layer_num,
                match_to_question=FLAGS.match_to_question,
                match_to_passage=FLAGS.match_to_passage,
                match_to_choice=FLAGS.match_to_choice,
                with_full_match=(not FLAGS.wo_full_match),
                with_maxpool_match=(not FLAGS.wo_maxpool_match),
                with_attentive_match=(not FLAGS.wo_attentive_match),
                with_max_attentive_match=(not FLAGS.wo_max_attentive_match),
                use_options=FLAGS.use_options,
                num_options=num_options,
                with_no_match=FLAGS.with_no_match,
                matching_option=FLAGS.matching_option,
                concat_context=FLAGS.concat_context,
                tied_aggre=FLAGS.tied_aggre,
                rl_training_method=FLAGS.rl_training_method,
                rl_matches=FLAGS.rl_matches)
        vars_ = {}
        for var in tf.all_variables():
            if "word_embedding" in var.name: continue
            if not var.name.startswith("Model"): continue
            vars_[var.name.split(":")[0]] = var
        saver = tf.train.Saver(vars_)

        sess = tf.Session()
        sess.run(tf.global_variables_initializer())
        step = 0
        saver.restore(sess, best_path)

        accuracy = evaluate(testDataStream,
                            valid_graph,
                            sess,
                            char_vocab=char_vocab,
                            POS_vocab=POS_vocab,
                            NER_vocab=NER_vocab,
                            use_options=FLAGS.use_options)
        print("Accuracy for test set is %.2f" % accuracy)
Example #9
def main(_):

    FLAGS.is_shuffle = (FLAGS.is_shuffle == 'True')
    print('Configuration')

    log_dir = FLAGS.model_dir
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    path_prefix = log_dir + "/SentenceMatch.{}".format(FLAGS.suffix)

    namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json")
    output_res_file = open('../result/' + FLAGS.run_id, 'wt')
    while Get_Next_box_size(FLAGS.start_batch):
        output_res_file.write('Q' + str(FLAGS) + '\n')
        print('Q' + str(FLAGS))
        train_path = FLAGS.train_path
        dev_path = FLAGS.dev_path
        test_path = FLAGS.test_path
        best_path = path_prefix + '.best.model'

        #zero_pad = True
        zero_pad = False
        if FLAGS.prediction_mode == 'list_wise' and FLAGS.loss_type == 'list_mle':
            zero_pad = True

        trainDataStream = SentenceMatchDataStream(train_path,
                                                  is_training=True,
                                                  isShuffle=FLAGS.is_shuffle,
                                                  isLoop=True,
                                                  isSort=True,
                                                  zero_pad=zero_pad,
                                                  is_ndcg=FLAGS.is_ndcg)
        # isShuffle must be True because it determines whether we are training or not.

        #train_testDataStream = SentenceMatchDataStream(train_path, isShuffle=False, isLoop=True, isSort=True)

        testDataStream = SentenceMatchDataStream(test_path,
                                                 is_training=False,
                                                 isShuffle=False,
                                                 isLoop=True,
                                                 isSort=True,
                                                 is_ndcg=FLAGS.is_ndcg)

        devDataStream = SentenceMatchDataStream(dev_path,
                                                is_training=False,
                                                isShuffle=False,
                                                isLoop=True,
                                                isSort=True,
                                                is_ndcg=FLAGS.is_ndcg)

        print('Number of instances in trainDataStream: {}'.format(
            trainDataStream.get_num_instance()))
        print('Number of instances in devDataStream: {}'.format(
            devDataStream.get_num_instance()))
        print('Number of instances in testDataStream: {}'.format(
            testDataStream.get_num_instance()))
        # print('Number of batches in trainDataStream: {}'.format(trainDataStream.get_num_batch()))
        # print('Number of batches in devDataStream: {}'.format(devDataStream.get_num_batch()))
        # print('Number of batches in testDataStream: {}'.format(testDataStream.get_num_batch()))

        sys.stdout.flush()
        output_res_index = 1
        # best_test_acc = 0
        max_test_ndcg = np.zeros(10)
        max_valid = np.zeros(10)
        max_test = np.zeros(10)
        # max_dev_ndcg = 0
        while output_res_index <= FLAGS.iter_count:
            # st_cuda = ''
            ssst = FLAGS.run_id
            ssst += str(FLAGS.start_batch)
            # output_res_file = open('../result/' + ssst + '.'+ st_cuda + str(output_res_index), 'wt')
            # output_sentence_file = open('../result/' + ssst + '.'+ st_cuda + str(output_res_index) + "S", 'wt')
            # output_train_file = open('../result/' + ssst + '.'+ st_cuda + str(output_res_index) + "T", 'wt')
            # output_sentences = []
            output_res_index += 1
            # output_res_file.write(str(FLAGS) + '\n\n')
            # stt = str (FLAGS)
            # best_dev_acc = 0.0
            init_scale = 0.001
            with tf.Graph().as_default():
                # tf.set_random_seed(0)
                # np.random.seed(123)
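                # feature dimensionality is 46 for the 2008 data and 136 otherwise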
                input_dim = 136
                if train_path.find("2008") > 0:
                    input_dim = 46
                initializer = tf.random_uniform_initializer(
                    -init_scale, init_scale)
                with tf.variable_scope("Model",
                                       reuse=None,
                                       initializer=initializer):
                    train_graph = SentenceMatchModelGraph(
                        num_classes=3,
                        is_training=True,
                        learning_rate=FLAGS.learning_rate,
                        lambda_l2=FLAGS.lambda_l2,
                        prediction_mode=FLAGS.prediction_mode,
                        q_count=FLAGS.question_count_per_batch,
                        loss_type=FLAGS.loss_type,
                        pos_avg=FLAGS.pos_avg,
                        input_dim=input_dim)
                    tf.summary.scalar("Training Loss", train_graph.get_loss(
                    ))  # Add a scalar summary for the snapshot loss.

        #         with tf.name_scope("Valid"):
                with tf.variable_scope("Model",
                                       reuse=True,
                                       initializer=initializer):
                    valid_graph = SentenceMatchModelGraph(
                        num_classes=3,
                        is_training=True,
                        learning_rate=FLAGS.learning_rate,
                        lambda_l2=FLAGS.lambda_l2,
                        prediction_mode=FLAGS.prediction_mode,
                        q_count=1,
                        loss_type=FLAGS.loss_type,
                        pos_avg=FLAGS.pos_avg,
                        input_dim=input_dim)

                # tf.set_random_seed(123)
                # np.random.seed(123)
                initializer = tf.global_variables_initializer()
                # tf.set_random_seed(123)
                # np.random.seed(123)

                vars_ = {}
                #for var in tf.all_variables():
                for var in tf.global_variables():
                    vars_[var.name.split(":")[0]] = var
                saver = tf.train.Saver(vars_)

                max_valid_iter = np.zeros(10)
                max_test_ndcg_iter = np.zeros(10)
                with tf.Session() as sess:
                    # tf.set_random_seed(123)
                    # np.random.seed(123)
                    sess.run(initializer)

                    train_size = trainDataStream.get_num_batch()
                    max_steps = (train_size * FLAGS.max_epochs
                                 ) // FLAGS.question_count_per_batch
                    epoch_size = max_steps // (FLAGS.max_epochs) + 1
                    total_loss = 0.0
                    start_time = time.time()
                    for step in range(max_steps):
                        # read data
                        _truth = []
                        _input_vector = []
                        _mask = []
                        for i in range(FLAGS.question_count_per_batch):
                            cur_batch, batch_index = trainDataStream.nextBatch(
                            )
                            (label_id_batch, input_vector_batch,
                             mask_batch) = cur_batch

                            if FLAGS.prediction_mode == 'list_wise' and FLAGS.loss_type == 'list_mle':
                                label_id_batch, input_vector_batch = sort_mle(
                                    label_id_batch, input_vector_batch)
                            _truth.append(label_id_batch)
                            _input_vector.append(input_vector_batch)
                            _mask.append(mask_batch)

                        #print (_truth)
                        feed_dict = {
                            train_graph.get_truth(): tuple(_truth),
                            train_graph.get_input_vector():
                            tuple(_input_vector),
                            train_graph.get_mask(): tuple(_mask)
                        }
                        _, loss_value = sess.run([
                            train_graph.get_train_op(),
                            train_graph.get_loss()
                        ],
                                                 feed_dict=feed_dict)
                        #print (loss_value)
                        #print (sess.run([train_graph.truth, train_graph.soft_truth], feed_dict=feed_dict))
                        #loss_value = sess.run([train_graph.logits1], feed_dict=feed_dict)
                        import math
                        if math.isnan(loss_value):
                            print(step)
                            print(
                                sess.run([
                                    train_graph.truth, train_graph.mask,
                                    train_graph.mask2, train_graph.mask01
                                ],
                                         feed_dict=feed_dict))

                        total_loss += loss_value
                        if (step + 1) % epoch_size == 0 or (step +
                                                            1) == max_steps:
                            if (step + 1) == max_steps:
                                print(total_loss)
                            # duration = time.time() - start_time
                            # start_time = time.time()
                            # total_loss = 0.0

                            for ndcg_ind in range(10):
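                                # dev score at this cutoff; when it improves, the matching test score is kept for this iteration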
                                v_map = evaluate(devDataStream,
                                                 valid_graph,
                                                 sess,
                                                 is_ndcg=FLAGS.is_ndcg,
                                                 top_k=ndcg_ind)
                                if v_map > max_valid[ndcg_ind]:
                                    max_valid[ndcg_ind] = v_map
                                flag_valid = False
                                if v_map > max_valid_iter[ndcg_ind]:
                                    max_valid_iter[ndcg_ind] = v_map
                                    flag_valid = True
                                te_map = evaluate(testDataStream,
                                                  valid_graph,
                                                  sess,
                                                  is_ndcg=FLAGS.is_ndcg,
                                                  flag_valid=flag_valid,
                                                  top_k=ndcg_ind)
                                if te_map > max_test[ndcg_ind]:
                                    max_test[ndcg_ind] = te_map
                                if flag_valid == True:
                                    # if te_map > max_test_ndcg[ndcg_ind] and FLAGS.store_best == True:
                                    #     #best_test_acc = my_map
                                    #     saver.save(sess, best_path)
                                    # if te_map > max_test_ndcg[ndcg_ind]:
                                    #     max_test_ndcg[ndcg_ind] = te_map
                                    # if te_map > max_test_ndcg_iter[ndcg_ind]:
                                    #     max_test_ndcg_iter[ndcg_ind] = te_map
                                    max_test_ndcg_iter[ndcg_ind] = te_map
                                #print ("{} - {}".format(v_map, my_map))

                    for ndcg_ind in range(10):
                        if max_test_ndcg_iter[ndcg_ind] > max_test_ndcg[
                                ndcg_ind]:
                            max_test_ndcg[ndcg_ind] = max_test_ndcg_iter[
                                ndcg_ind]

            #print (total_loss)
            print("{}-{}: {}".format(FLAGS.start_batch, output_res_index - 1,
                                     max_test_ndcg_iter))
            output_res_file.write("{}-{}: {}\n".format(FLAGS.start_batch,
                                                       output_res_index - 1,
                                                       max_test_ndcg_iter))

        print("*{}-{}: {}-{}-{}".format(FLAGS.fold, FLAGS.start_batch,
                                        max_valid, max_test, max_test_ndcg))
        output_res_file.write("{}-{}: {}-{}-{}\n".format(
            FLAGS.fold, FLAGS.start_batch, max_valid, max_test, max_test_ndcg))
        FLAGS.start_batch += FLAGS.step_batch

    output_res_file.close()
Example #10
def main(_):
    print('Configurations:')
    print(FLAGS)

    log_dir = FLAGS.model_dir
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    path_prefix = log_dir + "/G2S.{}".format(FLAGS.suffix)
    log_file_path = path_prefix + ".log"
    print('Log file path: {}'.format(log_file_path))
    log_file = open(log_file_path, 'wt')
    log_file.write("{}\n".format(FLAGS))
    log_file.flush()

    # save configuration
    namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json")

    word_vocab_enc = None
    word_vocab_dec = None
    char_vocab = None
    edgelabel_vocab = None
    has_pretrained_model = False
    best_path = path_prefix + ".best.model"
    if os.path.exists(best_path + ".index"):
        has_pretrained_model = True
        print('!!Existing pretrained model. Loading vocabs.')
        word_vocab_enc = Vocab(FLAGS.word_vec_src_path, fileformat='txt2')
        print('word_vocab SRC: {}'.format(word_vocab_enc.word_vecs.shape))
        word_vocab_dec = Vocab(FLAGS.word_vec_tgt_path, fileformat='txt2')
        print('word_vocab TGT: {}'.format(word_vocab_dec.word_vecs.shape))
        if FLAGS.with_char:
            char_vocab = Vocab(path_prefix + ".char_vocab", fileformat='txt2')
            print('char_vocab: {}'.format(char_vocab.word_vecs.shape))
        edgelabel_vocab = Vocab(path_prefix + ".edgelabel_vocab", fileformat='txt2')
    else:
        print('Collecting vocabs.')
        word_vocab_enc = Vocab(FLAGS.word_vec_src_path, fileformat='txt2')
        word_vocab_dec = Vocab(FLAGS.word_vec_tgt_path, fileformat='txt2')
        if FLAGS.with_char:
            char_vocab = Vocab(voc=allChars, dim=FLAGS.char_dim, fileformat='build')
            char_vocab.dump_to_txt2(path_prefix + ".char_vocab")
        allEdgelabels = set([line.strip().split()[0] \
                for line in open(FLAGS.edgelabel_vocab, 'rU')])
        edgelabel_vocab = Vocab(voc=allEdgelabels, dim=FLAGS.edgelabel_dim, fileformat='build')
        edgelabel_vocab.dump_to_txt2(path_prefix + ".edgelabel_vocab")

    print('word vocab SRC size {}'.format(word_vocab_enc.vocab_size))
    print('word vocab TGT size {}'.format(word_vocab_dec.vocab_size))
    sys.stdout.flush()

    print('Loading train set.')
    if FLAGS.infile_format == 'fof':
        trainset, trn_node, trn_in_neigh, trn_out_neigh, trn_sent = G2S_data_stream.read_amr_from_fof(FLAGS.train_path, FLAGS,
                word_vocab_enc, word_vocab_dec, char_vocab, edgelabel_vocab)
    else:
        trainset, trn_node, trn_in_neigh, trn_out_neigh, trn_sent = G2S_data_stream.read_amr_file(FLAGS.train_path, FLAGS,
                word_vocab_enc, word_vocab_dec, char_vocab, edgelabel_vocab)
    print('Number of training samples: {}'.format(len(trainset)))

    print('Loading test set.')
    if FLAGS.infile_format == 'fof':
        testset, tst_node, tst_in_neigh, tst_out_neigh, tst_sent = G2S_data_stream.read_amr_from_fof(FLAGS.test_path, FLAGS,
                word_vocab_enc, word_vocab_dec, char_vocab, edgelabel_vocab)
    else:
        testset, tst_node, tst_in_neigh, tst_out_neigh, tst_sent = G2S_data_stream.read_amr_file(FLAGS.test_path, FLAGS,
                word_vocab_enc, word_vocab_dec, char_vocab, edgelabel_vocab)
    print('Number of test samples: {}'.format(len(testset)))

    max_node = max(trn_node, tst_node)
    max_in_neigh = max(trn_in_neigh, tst_in_neigh)
    max_out_neigh = max(trn_out_neigh, tst_out_neigh)
    max_sent = max(trn_sent, tst_sent)
    print('Max node number: {}, while max allowed is {}'.format(max_node, FLAGS.max_node_num))
    print('Max parent number: {}, truncated to {}'.format(max_in_neigh, FLAGS.max_in_neigh_num))
    print('Max children number: {}, truncated to {}'.format(max_out_neigh, FLAGS.max_out_neigh_num))
    print('Max answer length: {}, truncated to {}'.format(max_sent, FLAGS.max_answer_len))

    print('Build DataStream ... ')
    trainDataStream = G2S_data_stream.G2SDataStream(trainset, word_vocab_enc, word_vocab_dec, char_vocab, edgelabel_vocab,
            options=FLAGS, isShuffle=True, isLoop=True, isSort=True)

    devDataStream = G2S_data_stream.G2SDataStream(testset, word_vocab_enc, word_vocab_dec, char_vocab, edgelabel_vocab,
            options=FLAGS, isShuffle=False, isLoop=False, isSort=True)
    print('Number of instances in trainDataStream: {}'.format(trainDataStream.get_num_instance()))
    print('Number of instances in devDataStream: {}'.format(devDataStream.get_num_instance()))
    print('Number of batches in trainDataStream: {}'.format(trainDataStream.get_num_batch()))
    print('Number of batches in devDataStream: {}'.format(devDataStream.get_num_batch()))
    sys.stdout.flush()

    # initialize the best bleu and accu scores for current training session
    best_accu = FLAGS.best_accu if 'best_accu' in FLAGS.__dict__ else 0.0
    best_bleu = FLAGS.best_bleu if 'best_bleu' in FLAGS.__dict__ else 0.0
    if best_accu > 0.0:
        print('With initial dev accuracy {}'.format(best_accu))
    if best_bleu > 0.0:
        print('With initial dev BLEU score {}'.format(best_bleu))

    init_scale = 0.01
    with tf.Graph().as_default():
        initializer = tf.random_uniform_initializer(-init_scale, init_scale)
        with tf.name_scope("Train"):
            with tf.variable_scope("Model", reuse=None, initializer=initializer):
                train_graph = ModelGraph(word_vocab_enc=word_vocab_enc, word_vocab_dec=word_vocab_dec, Edgelabel_vocab=edgelabel_vocab,
                                         char_vocab=char_vocab, options=FLAGS, mode=FLAGS.mode)

        assert FLAGS.mode in ('ce_train', 'rl_train', )
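        # cross-entropy training is validated with accuracy; RL training is validated with BLEU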
        valid_mode = 'evaluate' if FLAGS.mode == 'ce_train' else 'evaluate_bleu'

        with tf.name_scope("Valid"):
            with tf.variable_scope("Model", reuse=True, initializer=initializer):
                valid_graph = ModelGraph(word_vocab_enc=word_vocab_enc, word_vocab_dec=word_vocab_dec, Edgelabel_vocab=edgelabel_vocab,
                                         char_vocab=char_vocab, options=FLAGS, mode=valid_mode)

        initializer = tf.global_variables_initializer()

        for var in tf.trainable_variables():
            print(var)

        vars_ = {}
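        # checkpoint only model-scope variables; skip the word embeddings when they are fixed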
        for var in tf.all_variables():
            if FLAGS.fix_word_vec and "word_embedding" in var.name: continue
            if not var.name.startswith("Model"): continue
            vars_[var.name.split(":")[0]] = var
        saver = tf.train.Saver(vars_)

        sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
        sess.run(initializer)
        if has_pretrained_model:
            print("Restoring model from " + best_path)
            saver.restore(sess, best_path)
            print("DONE!")

            if FLAGS.mode == 'rl_train' and abs(best_bleu) < 0.00001:
                print("Getting BLEU score for the model")
                sys.stdout.flush()
                best_bleu = evaluate(sess, valid_graph, devDataStream, options=FLAGS)['dev_bleu']
                FLAGS.best_bleu = best_bleu
                namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json")
                print('BLEU = %.4f' % best_bleu)
                sys.stdout.flush()
                log_file.write('BLEU = %.4f\n' % best_bleu)
            if FLAGS.mode == 'ce_train' and abs(best_accu) < 0.00001:
                print("Getting ACCU score for the model")
                best_accu = evaluate(sess, valid_graph, devDataStream, options=FLAGS)['dev_accu']
                FLAGS.best_accu = best_accu
                namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json")
                print('ACCU = %.4f' % best_accu)
                log_file.write('ACCU = %.4f\n' % best_accu)

        print('Start the training loop.')
        train_size = trainDataStream.get_num_batch()
        max_steps = train_size * FLAGS.max_epochs
        total_loss = 0.0
        start_time = time.time()
        for step in range(max_steps):
            cur_batch = trainDataStream.nextBatch()
            cur_batch = G2S_data_stream.G2SBatchPadd(cur_batch)
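            # One update per batch: policy-gradient style update for 'rl_train',
            # standard cross-entropy update for 'ce_train'.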
            if FLAGS.mode == 'rl_train':
                loss_value = train_graph.run_rl_training_subsample(sess, cur_batch, FLAGS)
            elif FLAGS.mode == 'ce_train':
                loss_value = train_graph.run_ce_training(sess, cur_batch, FLAGS)
            total_loss += loss_value

            if step % 100 == 0:
                print('{} '.format(step), end="")
                sys.stdout.flush()


            # Save a checkpoint and evaluate the model periodically.
            if (step + 1) % trainDataStream.get_num_batch() == 0 or (step + 1) == max_steps or (step != 0 and step%2000 == 0):
                print()
                duration = time.time() - start_time
                print('Step %d: loss = %.2f (%.3f sec)' % (step, total_loss, duration))
                log_file.write('Step %d: loss = %.2f (%.3f sec)\n' % (step, total_loss, duration))
                log_file.flush()
                sys.stdout.flush()
                total_loss = 0.0

                # Evaluate against the validation set.
                start_time = time.time()
                print('Validation Data Eval:')
                res_dict = evaluate(sess, valid_graph, devDataStream, options=FLAGS, suffix=str(step))
                if valid_graph.mode == 'evaluate':
                    dev_loss = res_dict['dev_loss']
                    dev_accu = res_dict['dev_accu']
                    dev_right = int(res_dict['dev_right'])
                    dev_total = int(res_dict['dev_total'])
                    print('Dev loss = %.4f' % dev_loss)
                    log_file.write('Dev loss = %.4f\n' % dev_loss)
                    print('Dev accu = %.4f %d/%d' % (dev_accu, dev_right, dev_total))
                    log_file.write('Dev accu = %.4f %d/%d\n' % (dev_accu, dev_right, dev_total))
                    log_file.flush()
                    if best_accu < dev_accu:
                        print('Saving weights, ACCU {} (prev_best) < {} (cur)'.format(best_accu, dev_accu))
                        saver.save(sess, best_path)
                        best_accu = dev_accu
                        FLAGS.best_accu = dev_accu
                        namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json")
                else:
                    dev_bleu = res_dict['dev_bleu']
                    print('Dev bleu = %.4f' % dev_bleu)
                    log_file.write('Dev bleu = %.4f\n' % dev_bleu)
                    log_file.flush()
                    if best_bleu < dev_bleu:
                        print('Saving weights, BLEU {} (prev_best) < {} (cur)'.format(best_bleu, dev_bleu))
                        saver.save(sess, best_path)
                        best_bleu = dev_bleu
                        FLAGS.best_bleu = dev_bleu
                        namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json")
                duration = time.time() - start_time
                print('Duration %.3f sec' % (duration))
                sys.stdout.flush()

                log_file.write('Duration %.3f sec\n' % (duration))
                log_file.flush()

    log_file.close()
Example #11
def main(_):
    print('Configurations:')
    print(FLAGS)

    train_path = FLAGS.train_path
    dev_path = FLAGS.dev_path
    test_path = FLAGS.test_path
    word_vec_path = FLAGS.word_vec_path
    log_dir = FLAGS.model_dir
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    
    path_prefix = log_dir + "/SentenceMatch.{}".format(FLAGS.suffix)

    namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json")

    # build vocabs
    word_vocab = Vocab(word_vec_path, fileformat='txt3')
    best_path = path_prefix + '.best.model'
    char_path = path_prefix + ".char_vocab"
    label_path = path_prefix + ".label_vocab"
    POS_path = path_prefix + ".POS_vocab"
    NER_path = path_prefix + ".NER_vocab"
    has_pre_trained_model = False
    POS_vocab = None
    NER_vocab = None
    if os.path.exists(best_path):
        has_pre_trained_model = True
        label_vocab = Vocab(label_path, fileformat='txt2')
        char_vocab = Vocab(char_path, fileformat='txt2')
        if FLAGS.with_POS: POS_vocab = Vocab(POS_path, fileformat='txt2')
        if FLAGS.with_NER: NER_vocab = Vocab(NER_path, fileformat='txt2')
    else:
        print('Collect words, chars and labels ...')
        (all_words, all_chars, all_labels, all_POSs, all_NERs) = collect_vocabs(train_path, with_POS=FLAGS.with_POS, with_NER=FLAGS.with_NER)
        print('Number of words: {}'.format(len(all_words)))
        print('Number of labels: {}'.format(len(all_labels)))
        label_vocab = Vocab(fileformat='voc', voc=all_labels,dim=2)
        label_vocab.dump_to_txt2(label_path)

        print('Number of chars: {}'.format(len(all_chars)))
        char_vocab = Vocab(fileformat='voc', voc=all_chars,dim=FLAGS.char_emb_dim)
        char_vocab.dump_to_txt2(char_path)
        
        if FLAGS.with_POS:
            print('Number of POSs: {}'.format(len(all_POSs)))
            POS_vocab = Vocab(fileformat='voc', voc=all_POSs,dim=FLAGS.POS_dim)
            POS_vocab.dump_to_txt2(POS_path)
        if FLAGS.with_NER:
            print('Number of NERs: {}'.format(len(all_NERs)))
            NER_vocab = Vocab(fileformat='voc', voc=all_NERs,dim=FLAGS.NER_dim)
            NER_vocab.dump_to_txt2(NER_path)
            

    print('word_vocab shape is {}'.format(word_vocab.word_vecs.shape))
    print('tag_vocab shape is {}'.format(label_vocab.word_vecs.shape))
    num_classes = label_vocab.size()

    print('Build SentenceMatchDataStream ... ')
    trainDataStream = SentenceMatchDataStream(train_path, word_vocab=word_vocab, char_vocab=char_vocab, 
                                              POS_vocab=POS_vocab, NER_vocab=NER_vocab, label_vocab=label_vocab, 
                                              batch_size=FLAGS.batch_size, isShuffle=True, isLoop=True, isSort=True, 
                                              max_char_per_word=FLAGS.max_char_per_word, max_sent_length=FLAGS.max_sent_length)
                                    
    devDataStream = SentenceMatchDataStream(dev_path, word_vocab=word_vocab, char_vocab=char_vocab,
                                              POS_vocab=POS_vocab, NER_vocab=NER_vocab, label_vocab=label_vocab, 
                                              batch_size=FLAGS.batch_size, isShuffle=False, isLoop=True, isSort=True, 
                                              max_char_per_word=FLAGS.max_char_per_word, max_sent_length=FLAGS.max_sent_length)

    testDataStream = SentenceMatchDataStream(test_path, word_vocab=word_vocab, char_vocab=char_vocab, 
                                              POS_vocab=POS_vocab, NER_vocab=NER_vocab, label_vocab=label_vocab, 
                                              batch_size=FLAGS.batch_size, isShuffle=False, isLoop=True, isSort=True, 
                                              max_char_per_word=FLAGS.max_char_per_word, max_sent_length=FLAGS.max_sent_length)

    print('Number of instances in trainDataStream: {}'.format(trainDataStream.get_num_instance()))
    print('Number of instances in devDataStream: {}'.format(devDataStream.get_num_instance()))
    print('Number of instances in testDataStream: {}'.format(testDataStream.get_num_instance()))
    print('Number of batches in trainDataStream: {}'.format(trainDataStream.get_num_batch()))
    print('Number of batches in devDataStream: {}'.format(devDataStream.get_num_batch()))
    print('Number of batches in testDataStream: {}'.format(testDataStream.get_num_batch()))
    
    sys.stdout.flush()
    if FLAGS.wo_char: char_vocab = None

    best_accuracy = 0.0
    init_scale = 0.01
    with tf.Graph().as_default():
        initializer = tf.random_uniform_initializer(-init_scale, init_scale)
#         with tf.name_scope("Train"):
        with tf.variable_scope("Model", reuse=None, initializer=initializer):
            train_graph = SentenceMatchModelGraph(num_classes, word_vocab=word_vocab, char_vocab=char_vocab,POS_vocab=POS_vocab, NER_vocab=NER_vocab, 
                 dropout_rate=FLAGS.dropout_rate, learning_rate=FLAGS.learning_rate, optimize_type=FLAGS.optimize_type,
                 lambda_l2=FLAGS.lambda_l2, char_lstm_dim=FLAGS.char_lstm_dim, context_lstm_dim=FLAGS.context_lstm_dim, 
                 aggregation_lstm_dim=FLAGS.aggregation_lstm_dim, is_training=True, MP_dim=FLAGS.MP_dim, 
                 context_layer_num=FLAGS.context_layer_num, aggregation_layer_num=FLAGS.aggregation_layer_num, 
                 fix_word_vec=FLAGS.fix_word_vec,with_filter_layer=FLAGS.with_filter_layer, with_highway=FLAGS.with_highway,
                 word_level_MP_dim=FLAGS.word_level_MP_dim,
                 with_match_highway=FLAGS.with_match_highway, with_aggregation_highway=FLAGS.with_aggregation_highway,
                 highway_layer_num=FLAGS.highway_layer_num,with_lex_decomposition=FLAGS.with_lex_decomposition, 
                 lex_decompsition_dim=FLAGS.lex_decompsition_dim,
                 with_left_match=(not FLAGS.wo_left_match), with_right_match=(not FLAGS.wo_right_match),
                 with_full_match=(not FLAGS.wo_full_match), with_maxpool_match=(not FLAGS.wo_maxpool_match), 
                 with_attentive_match=(not FLAGS.wo_attentive_match), with_max_attentive_match=(not FLAGS.wo_max_attentive_match))
            tf.summary.scalar("Training Loss", train_graph.get_loss()) # Add a scalar summary for the snapshot loss.
        
#         with tf.name_scope("Valid"):
        with tf.variable_scope("Model", reuse=True, initializer=initializer):
            valid_graph = SentenceMatchModelGraph(num_classes, word_vocab=word_vocab, char_vocab=char_vocab, POS_vocab=POS_vocab, NER_vocab=NER_vocab, 
                 dropout_rate=FLAGS.dropout_rate, learning_rate=FLAGS.learning_rate, optimize_type=FLAGS.optimize_type,
                 lambda_l2=FLAGS.lambda_l2, char_lstm_dim=FLAGS.char_lstm_dim, context_lstm_dim=FLAGS.context_lstm_dim, 
                 aggregation_lstm_dim=FLAGS.aggregation_lstm_dim, is_training=False, MP_dim=FLAGS.MP_dim, 
                 context_layer_num=FLAGS.context_layer_num, aggregation_layer_num=FLAGS.aggregation_layer_num, 
                 fix_word_vec=FLAGS.fix_word_vec,with_filter_layer=FLAGS.with_filter_layer, with_highway=FLAGS.with_highway,
                 word_level_MP_dim=FLAGS.word_level_MP_dim,
                 with_match_highway=FLAGS.with_match_highway, with_aggregation_highway=FLAGS.with_aggregation_highway,
                 highway_layer_num=FLAGS.highway_layer_num, with_lex_decomposition=FLAGS.with_lex_decomposition, 
                 lex_decompsition_dim=FLAGS.lex_decompsition_dim,
                 with_left_match=(not FLAGS.wo_left_match), with_right_match=(not FLAGS.wo_right_match),
                 with_full_match=(not FLAGS.wo_full_match), with_maxpool_match=(not FLAGS.wo_maxpool_match), 
                 with_attentive_match=(not FLAGS.wo_attentive_match), with_max_attentive_match=(not FLAGS.wo_max_attentive_match))

                
        initializer = tf.global_variables_initializer()
        vars_ = {}
        for var in tf.all_variables():
            if "word_embedding" in var.name: continue
            #             if not var.name.startswith("Model"): continue
            vars_[var.name.split(":")[0]] = var
        saver = tf.train.Saver(vars_)
         
        sess = tf.Session()
        sess.run(initializer)
        if has_pre_trained_model:
            print("Restoring model from " + best_path)
            saver.restore(sess, best_path)
            print("DONE!")

        print('Start the training loop.')
        train_size = trainDataStream.get_num_batch()
        max_steps = train_size * FLAGS.max_epochs
        total_loss = 0.0
        start_time = time.time()
        for step in xrange(max_steps):
            # read data
            cur_batch = trainDataStream.nextBatch()
            (label_batch, sent1_batch, sent2_batch, label_id_batch, word_idx_1_batch, word_idx_2_batch, 
                                 char_matrix_idx_1_batch, char_matrix_idx_2_batch, sent1_length_batch, sent2_length_batch, 
                                 sent1_char_length_batch, sent2_char_length_batch,
                                 POS_idx_1_batch, POS_idx_2_batch, NER_idx_1_batch, NER_idx_2_batch) = cur_batch
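            # Base feed dict carries labels, word ids and sentence lengths; char, POS and
            # NER inputs are only added below when the corresponding vocab is enabled.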
            feed_dict = {
                         train_graph.get_truth(): label_id_batch, 
                         train_graph.get_question_lengths(): sent1_length_batch, 
                         train_graph.get_passage_lengths(): sent2_length_batch, 
                         train_graph.get_in_question_words(): word_idx_1_batch, 
                         train_graph.get_in_passage_words(): word_idx_2_batch, 
#                          train_graph.get_question_char_lengths(): sent1_char_length_batch, 
#                          train_graph.get_passage_char_lengths(): sent2_char_length_batch, 
#                          train_graph.get_in_question_chars(): char_matrix_idx_1_batch, 
#                          train_graph.get_in_passage_chars(): char_matrix_idx_2_batch, 
                         }
            if char_vocab is not None:
                feed_dict[train_graph.get_question_char_lengths()] = sent1_char_length_batch
                feed_dict[train_graph.get_passage_char_lengths()] = sent2_char_length_batch
                feed_dict[train_graph.get_in_question_chars()] = char_matrix_idx_1_batch
                feed_dict[train_graph.get_in_passage_chars()] = char_matrix_idx_2_batch

            if POS_vocab is not None:
                feed_dict[train_graph.get_in_question_poss()] = POS_idx_1_batch
                feed_dict[train_graph.get_in_passage_poss()] = POS_idx_2_batch

            if NER_vocab is not None:
                feed_dict[train_graph.get_in_question_ners()] = NER_idx_1_batch
                feed_dict[train_graph.get_in_passage_ners()] = NER_idx_2_batch

            _, loss_value = sess.run([train_graph.get_train_op(), train_graph.get_loss()], feed_dict=feed_dict)
            total_loss += loss_value
            
            if step % 100 == 0:
                print('{} '.format(step), end="")
                sys.stdout.flush()

            # Save a checkpoint and evaluate the model periodically.
            if (step + 1) % trainDataStream.get_num_batch() == 0 or (step + 1) == max_steps:
                print()
                # Print status to stdout.
                duration = time.time() - start_time
                start_time = time.time()
                print('Step %d: loss = %.2f (%.3f sec)' % (step, total_loss, duration))
                total_loss = 0.0

                # Evaluate against the validation set.
                print('Validation Data Eval:')
                accuracy = evaluate(devDataStream, valid_graph, sess,char_vocab=char_vocab, POS_vocab=POS_vocab, NER_vocab=NER_vocab)
                print("Current accuracy is %.2f" % accuracy)
                if accuracy>best_accuracy:
                    best_accuracy = accuracy
                    saver.save(sess, best_path)

    print("Best accuracy on dev set is %.2f" % best_accuracy)
    # decoding
    print('Decoding on the test set:')
    init_scale = 0.01
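    # Rebuild the graph in inference mode (is_training=False), restore the best
    # checkpoint, and report accuracy on the test set.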
    with tf.Graph().as_default():
        initializer = tf.random_uniform_initializer(-init_scale, init_scale)
        with tf.variable_scope("Model", reuse=False, initializer=initializer):
            valid_graph = SentenceMatchModelGraph(num_classes, word_vocab=word_vocab, char_vocab=char_vocab, POS_vocab=POS_vocab, NER_vocab=NER_vocab, 
                 dropout_rate=FLAGS.dropout_rate, learning_rate=FLAGS.learning_rate, optimize_type=FLAGS.optimize_type,
                 lambda_l2=FLAGS.lambda_l2, char_lstm_dim=FLAGS.char_lstm_dim, context_lstm_dim=FLAGS.context_lstm_dim, 
                 aggregation_lstm_dim=FLAGS.aggregation_lstm_dim, is_training=False, MP_dim=FLAGS.MP_dim, 
                 context_layer_num=FLAGS.context_layer_num, aggregation_layer_num=FLAGS.aggregation_layer_num, 
                 fix_word_vec=FLAGS.fix_word_vec,with_filter_layer=FLAGS.with_filter_layer, with_highway=FLAGS.with_highway,
                 word_level_MP_dim=FLAGS.word_level_MP_dim,
                 with_match_highway=FLAGS.with_match_highway, with_aggregation_highway=FLAGS.with_aggregation_highway,
                 highway_layer_num=FLAGS.highway_layer_num, with_lex_decomposition=FLAGS.with_lex_decomposition, 
                 lex_decompsition_dim=FLAGS.lex_decompsition_dim,
                 with_left_match=(not FLAGS.wo_left_match), with_right_match=(not FLAGS.wo_right_match),
                 with_full_match=(not FLAGS.wo_full_match), with_maxpool_match=(not FLAGS.wo_maxpool_match), 
                 with_attentive_match=(not FLAGS.wo_attentive_match), with_max_attentive_match=(not FLAGS.wo_max_attentive_match))
        vars_ = {}
        for var in tf.all_variables():
            if "word_embedding" in var.name: continue
            if not var.name.startswith("Model"): continue
            vars_[var.name.split(":")[0]] = var
        saver = tf.train.Saver(vars_)
                
        sess = tf.Session()
        sess.run(tf.global_variables_initializer())
        step = 0
        saver.restore(sess, best_path)

        accuracy = evaluate(testDataStream, valid_graph, sess,char_vocab=char_vocab,POS_vocab=POS_vocab, NER_vocab=NER_vocab)
        print("Accuracy for test set is %.2f" % accuracy)
Example #12
def main(_):
    log_dir = FLAGS.model_dir
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    path_prefix = log_dir + "/G2S.{}".format(FLAGS.suffix)
    log_file_path = path_prefix + ".log"
    print('Log file path: {}'.format(log_file_path))
    log_file = open(log_file_path, 'wt')
    log_file.write("{}\n".format(FLAGS))
    log_file.flush()

    # save configuration
    namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json")

    print('Loading data.')
    FLAGS.num_relations = 2
    trainset = G2S_data_stream.read_bionlp_file(FLAGS.train_path, FLAGS.train_dep_path, FLAGS)
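    # Carve a dev set out of the training data: take the first dev_percent of instances,
    # either after shuffling ('shuffle') or from the end of the file ('last').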
    if FLAGS.dev_gen == 'shuffle':
        random.shuffle(trainset)
    elif FLAGS.dev_gen == 'last':
        trainset.reverse()
    N = int(len(trainset)*FLAGS.dev_percent)
    devset = trainset[:N]
    trainset = trainset[N:]

    print('Number of training samples: {}'.format(len(trainset)))
    print('Number of dev samples: {}'.format(len(devset)))
    print('Number of relations: {}'.format(FLAGS.num_relations))

    word_vocab = None
    char_vocab = None
    POS_vocab = None
    edgelabel_vocab = None
    has_pretrained_model = False
    best_path = path_prefix + ".best.model"
    if os.path.exists(best_path + ".index"):
        has_pretrained_model = True
        print('!!Existing pretrained model. Loading vocabs.')
        word_vocab = Vocab(FLAGS.word_vec_path, fileformat='txt2')
        print('word_vocab: {}'.format(word_vocab.word_vecs.shape))
        if FLAGS.with_char:
            char_vocab = Vocab(path_prefix + ".char_vocab", fileformat='txt2')
            print('char_vocab: {}'.format(char_vocab.word_vecs.shape))
        if FLAGS.with_POS:
            POS_vocab = Vocab(path_prefix + ".POS_vocab", fileformat='txt2')
            print('POS_vocab: {}'.format(POS_vocab.word_vecs.shape))
        edgelabel_vocab = Vocab(path_prefix + ".edgelabel_vocab", fileformat='txt2')
        print('edgelabel_vocab: {}'.format(edgelabel_vocab.word_vecs.shape))
    else:
        print('Collecting vocabs.')
        all_words = set()
        all_chars = set()
        all_poses = set()
        all_edgelabels = set()
        G2S_data_stream.collect_vocabs(trainset, all_words, all_chars, all_poses, all_edgelabels)
        G2S_data_stream.collect_vocabs(devset, all_words, all_chars, all_poses, all_edgelabels)
        print('Number of words: {}'.format(len(all_words)))
        print('Number of chars: {}'.format(len(all_chars)))
        print('Number of poses: {}'.format(len(all_poses)))
        print('Number of edgelabels: {}'.format(len(all_edgelabels)))

        word_vocab = Vocab(FLAGS.word_vec_path, fileformat='txt2')
        if FLAGS.with_char:
            char_vocab = Vocab(voc=all_chars, dim=FLAGS.char_dim, fileformat='build')
            char_vocab.dump_to_txt2(path_prefix + ".char_vocab")
        if FLAGS.with_POS:
            POS_vocab = Vocab(voc=all_poses, dim=FLAGS.POS_dim, fileformat='build')
            POS_vocab.dump_to_txt2(path_prefix + ".POS_vocab")
        edgelabel_vocab = Vocab(voc=all_edgelabels, dim=FLAGS.edgelabel_dim, fileformat='build')
        edgelabel_vocab.dump_to_txt2(path_prefix + ".edgelabel_vocab")

    print('word vocab size {}'.format(word_vocab.vocab_size))
    sys.stdout.flush()

    print('Build DataStream ... ')
    trainDataStream = G2S_data_stream.G2SDataStream(FLAGS, trainset, word_vocab, char_vocab, POS_vocab, edgelabel_vocab,
            isShuffle=True, isLoop=True, isSort=True, is_training=True)

    devDataStream = G2S_data_stream.G2SDataStream(FLAGS, devset, word_vocab, char_vocab, POS_vocab, edgelabel_vocab,
            isShuffle=False, isLoop=False, isSort=True)
    print('Number of instances in trainDataStream: {}'.format(trainDataStream.get_num_instance()))
    print('Number of instances in devDataStream: {}'.format(devDataStream.get_num_instance()))
    print('Number of batches in trainDataStream: {}'.format(trainDataStream.get_num_batch()))
    print('Number of batches in devDataStream: {}'.format(devDataStream.get_num_batch()))
    sys.stdout.flush()

    FLAGS.trn_bch_num = trainDataStream.get_num_batch()

    # initialize the best bleu and accu scores for current training session
    best_accu = FLAGS.best_accu if 'best_accu' in FLAGS.__dict__ else 0.0
    if best_accu > 0.0:
        print('With initial dev accuracy {}'.format(best_accu))

    init_scale = 0.01
    with tf.Graph().as_default():
        initializer = tf.random_uniform_initializer(-init_scale, init_scale)
        with tf.name_scope("Train"):
            with tf.variable_scope("Model", reuse=None, initializer=initializer):
                train_graph = ModelGraph(word_vocab, char_vocab, POS_vocab, edgelabel_vocab,
                                         FLAGS, mode='train')

        with tf.name_scope("Valid"):
            with tf.variable_scope("Model", reuse=True, initializer=initializer):
                valid_graph = ModelGraph(word_vocab, char_vocab, POS_vocab, edgelabel_vocab,
                                         FLAGS, mode='evaluate')

        initializer = tf.global_variables_initializer()

        vars_ = {}
        for var in tf.all_variables():
            if FLAGS.fix_word_vec and "word_embedding" in var.name: continue
            if not var.name.startswith("Model"): continue
            vars_[var.name.split(":")[0]] = var
        saver = tf.train.Saver(vars_)

        sess = tf.Session()
        sess.run(initializer)
        if has_pretrained_model:
            print("Restoring model from " + best_path)
            saver.restore(sess, best_path)
            print("DONE!")

            if abs(best_accu) < 1e-5:
                print("Getting ACCU score for the model")
                best_accu = evaluate(sess, valid_graph, devDataStream, FLAGS)['dev_f1']
                FLAGS.best_accu = best_accu
                namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json")
                print('ACCU = %.4f' % best_accu)
                log_file.write('ACCU = %.4f\n' % best_accu)

        print('Start the training loop.')
        train_size = trainDataStream.get_num_batch()
        max_steps = train_size * FLAGS.max_epochs
        last_step = 0
        total_loss = 0.0
        start_time = time.time()
        for step in xrange(max_steps):
            cur_batch = trainDataStream.nextBatch()
            _, _, cur_loss, _ = train_graph.execute(sess, cur_batch, FLAGS, is_train=True)
            total_loss += cur_loss

            if step % 100 == 0:
                print('{} '.format(step), end="")
                sys.stdout.flush()

            # Save a checkpoint and evaluate the model periodically.
            if (step + 1) % trainDataStream.get_num_batch() == 0 or (step + 1) == max_steps:
                print()
                duration = time.time() - start_time
                print('Step %d: loss = %.2f (%.3f sec)' % (step, total_loss/(step-last_step), duration))
                log_file.write('Step %d: loss = %.2f (%.3f sec)\n' % (step, total_loss/(step-last_step), duration))
                sys.stdout.flush()
                log_file.flush()
                last_step = step
                total_loss = 0.0

                # Evaluate against the validation set.
                start_time = time.time()
                print('Validation Data Eval:')
                res_dict = evaluate(sess, valid_graph, devDataStream, FLAGS)
                dev_loss = res_dict['dev_loss']
                dev_accu = res_dict['dev_f1']
                dev_precision = res_dict['dev_precision']
                dev_recall = res_dict['dev_recall']
                print('Dev loss = %.4f' % dev_loss)
                log_file.write('Dev loss = %.4f\n' % dev_loss)
                print('Dev F1 = %.4f, P = %.4f, R = %.4f' % (dev_accu, dev_precision, dev_recall))
                log_file.write('Dev F1 = %.4f, P =  %.4f, R = %.4f\n' % (dev_accu, dev_precision, dev_recall))
                log_file.flush()
                if best_accu < dev_accu:
                    print('Saving weights, ACCU {} (prev_best) < {} (cur)'.format(best_accu, dev_accu))
                    saver.save(sess, best_path)
                    best_accu = dev_accu
                    FLAGS.best_accu = dev_accu
                    namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json")
                duration = time.time() - start_time
                print('Duration %.3f sec' % (duration))
                sys.stdout.flush()

                log_file.write('Duration %.3f sec\n' % (duration))
                log_file.flush()
                start_time = time.time()

    log_file.close()
Example #13
def main(_):
    print('Configurations:')
    print(FLAGS)
    train_path = FLAGS.train_path
    dev_path = FLAGS.dev_path
    test_path = FLAGS.test_path
    word_vec_path = FLAGS.word_vec_path
    log_dir = FLAGS.model_dir
    result_dir = '../result'
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    if not os.path.exists(result_dir):
        os.makedirs(result_dir)

    path_prefix = log_dir + "/Han.{}".format(FLAGS.suffix)
    save_namespace(FLAGS, path_prefix + ".config.json")
    word_vocab = Vocab(word_vec_path, fileformat='txt3')
    best_path = path_prefix + '.best.model'
    label_path = path_prefix + ".label_vocab"

    print('Collect words and labels ...')
    (all_words, all_labels) = collect_vocabs(train_path)
    print('Number of words: {}'.format(len(all_words)))
    print('Number of labels: {}'.format(len(all_labels)))
    label_vocab = Vocab(fileformat='voc', voc=all_labels, dim=2)
    label_vocab.dump_to_txt2(label_path)

    print('word_vocab shape is {}'.format(word_vocab.word_vecs.shape))
    print('tag_vocab shape is {}'.format(label_vocab.word_vecs.shape))
    num_classes = label_vocab.size()

    print('Build HanDataStream ... ')
    trainDataStream = HanDataStream(inpath=train_path, word_vocab=word_vocab,
                                              label_vocab=label_vocab,
                                              isShuffle=True, isLoop=True,
                                              max_sent_length=FLAGS.max_sent_length)
    devDataStream = HanDataStream(inpath=dev_path, word_vocab=word_vocab,
                                              label_vocab=label_vocab,
                                              isShuffle=False, isLoop=True,
                                              max_sent_length=FLAGS.max_sent_length)
    testDataStream = HanDataStream(inpath=test_path, word_vocab=word_vocab,
                                              label_vocab=label_vocab,
                                              isShuffle=False, isLoop=True,
                                              max_sent_length=FLAGS.max_sent_length)

    print('Number of instances in trainDataStream: {}'.format(trainDataStream.get_num_instance()))
    print('Number of instances in devDataStream: {}'.format(devDataStream.get_num_instance()))
    print('Number of instances in testDataStream: {}'.format(testDataStream.get_num_instance()))

    with tf.Graph().as_default():
        initializer = tf.contrib.layers.xavier_initializer()
        with tf.variable_scope("Model", reuse=None, initializer=initializer):
            train_graph = HanModelGraph(num_classes=num_classes, word_vocab=word_vocab,
                                                  dropout_rate=FLAGS.dropout_rate, learning_rate=FLAGS.learning_rate,
                                                  lambda_l2=FLAGS.lambda_l2,
                                                  context_lstm_dim=FLAGS.context_lstm_dim,
                                                  is_training=True, batch_size = FLAGS.batch_size)
            tf.summary.scalar("Training Loss", train_graph.loss)  # Add a scalar summary for the snapshot loss.
        print("Train Graph Build")
        with tf.variable_scope("Model", reuse=True, initializer=initializer):
            valid_graph = HanModelGraph(num_classes=num_classes, word_vocab=word_vocab,
                                                  dropout_rate=FLAGS.dropout_rate, learning_rate=FLAGS.learning_rate,
                                                  lambda_l2=FLAGS.lambda_l2,
                                                  context_lstm_dim=FLAGS.context_lstm_dim,
                                                  is_training=False, batch_size = 1)
        print ("dev Graph Build")
        initializer = tf.global_variables_initializer()
        vars_ = {}
        for var in tf.all_variables():
            if "word_embedding" in var.name: continue
            #             if not var.name.startswith("Model"): continue
            vars_[var.name.split(":")[0]] = var
        saver = tf.train.Saver(vars_)

        output_res_file = open(result_dir + '/' + FLAGS.suffix, 'wt')
        output_res_file.write(str(FLAGS))
        with tf.Session() as sess:
            sess.run(initializer)
            train_size = trainDataStream.get_num_instance()
            max_steps = (train_size * FLAGS.max_epochs) // FLAGS.batch_size
            epoch_size = max_steps // (FLAGS.max_epochs)  # + 1
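            # epoch_size is the number of batches per epoch; validation (and a test pass)
            # run at every epoch boundary and at the final step.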

            total_loss = 0.0
            start_time = time.time()
            best_accuracy = 0
            for step in range(max_steps):
                # read data
                # _truth = []
                # _sents_length = []
                # _in_text_words = []
                # for i in range(FLAGS.batch_size):
                #     cur_instance, instance_index = trainDataStream.nextInstance ()
                #     (label,text,label_id, word_idx, sents_length) = cur_instance
                #
                #     _truth.append(label_id)
                #     _sents_length.append(sents_length)
                #     _in_text_words.append(word_idx)
                #
                # feed_dict = {
                #     train_graph.truth: np.array(_truth),
                #     train_graph.sents_length: tuple(_sents_length),
                #     train_graph.in_text_words: tuple(_in_text_words),
                # }
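                # The manual batching sketched above is replaced by get_feed_dict, which
                # draws batch_size instances from the stream and fills the graph placeholders.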

                feed_dict = get_feed_dict(data_stream=trainDataStream, graph=train_graph, batch_size=FLAGS.batch_size, is_testing=False)

                _, loss_value, _score = sess.run([train_graph.train_op, train_graph.loss
                                                     , train_graph.batch_class_scores],
                                                 feed_dict=feed_dict)
                total_loss += loss_value

                if step % 100 == 0:
                    print('{} '.format(step), end="")
                    sys.stdout.flush()
                if (step + 1) % epoch_size == 0 or (step + 1) == max_steps:
                    # print(total_loss)
                    duration = time.time() - start_time
                    start_time = time.time()
                    print('Step %d: loss = %.2f (%.3f sec)' % (step, total_loss, duration))
                    output_res_file.write('\nStep %d: loss = %.2f (%.3f sec)\n' % (step, total_loss, duration))
                    total_loss = 0.0
                    # Evaluate against the validation set.
                    output_res_file.write('valid- ')
                    dev_accuracy = evaluate(devDataStream, valid_graph, sess)
                    output_res_file.write("%.2f\n" % dev_accuracy)
                    print("Current dev accuracy is %.2f" % dev_accuracy)
                    if dev_accuracy > best_accuracy:
                        best_accuracy = dev_accuracy
                        saver.save(sess, best_path)
                    output_res_file.write('test- ')
                    test_accuracy = evaluate(testDataStream, valid_graph, sess)
                    print("Current test accuracy is %.2f" % test_accuracy)
                    output_res_file.write("%.2f\n" % test_accuracy)

    output_res_file.close()
    sys.stdout.flush()
Example #14
def main(_):
    print('Configurations:')
    print(FLAGS)

    log_dir = FLAGS.model_dir
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    path_prefix = log_dir + "/NP2P.{}".format(FLAGS.suffix)
    init_model_prefix = FLAGS.init_model  # "/u/zhigwang/zhigwang1/sentence_generation/mscoco/logs/NP2P.phrase_ce_train"
    log_file_path = path_prefix + ".log"
    print('Log file path: {}'.format(log_file_path))
    log_file = open(log_file_path, 'wt')
    log_file.write("{}\n".format(FLAGS))
    log_file.flush()

    # save configuration
    namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json")

    print('Loading train set.')
    if FLAGS.infile_format == 'fof':
        trainset, train_ans_len = NP2P_data_stream.read_generation_datasets_from_fof(
            FLAGS.train_path, isLower=FLAGS.isLower)
        if FLAGS.max_answer_len > train_ans_len:
            FLAGS.max_answer_len = train_ans_len
    else:
        trainset, train_ans_len = NP2P_data_stream.read_all_GQA_questions(
            FLAGS.train_path, isLower=FLAGS.isLower)
    print('Number of training samples: {}'.format(len(trainset)))

    print('Loading test set.')
    if FLAGS.infile_format == 'fof':
        testset, test_ans_len = NP2P_data_stream.read_generation_datasets_from_fof(
            FLAGS.test_path, isLower=FLAGS.isLower)
    else:
        testset, test_ans_len = NP2P_data_stream.read_all_GQA_questions(
            FLAGS.test_path, isLower=FLAGS.isLower)
    print('Number of test samples: {}'.format(len(testset)))

    max_actual_len = max(train_ans_len, test_ans_len)
    print('Max answer length: {}, truncated to {}'.format(
        max_actual_len, FLAGS.max_answer_len))

    word_vocab = None
    POS_vocab = None
    NER_vocab = None
    char_vocab = None
    has_pretrained_model = False
    best_path = path_prefix + ".best.model"
    if os.path.exists(init_model_prefix + ".best.model.index"):
        has_pretrained_model = True
        print('!!Existing pretrained model. Loading vocabs.')
        if FLAGS.with_word:
            word_vocab = Vocab(FLAGS.word_vec_path, fileformat='txt2')
            print('word_vocab: {}'.format(word_vocab.word_vecs.shape))
        if FLAGS.with_char:
            char_vocab = Vocab(init_model_prefix + ".char_vocab",
                               fileformat='txt2')
            print('char_vocab: {}'.format(char_vocab.word_vecs.shape))
        if FLAGS.with_POS:
            POS_vocab = Vocab(init_model_prefix + ".POS_vocab",
                              fileformat='txt2')
            print('POS_vocab: {}'.format(POS_vocab.word_vecs.shape))
        if FLAGS.with_NER:
            NER_vocab = Vocab(init_model_prefix + ".NER_vocab",
                              fileformat='txt2')
            print('NER_vocab: {}'.format(NER_vocab.word_vecs.shape))
    else:
        print('Collecting vocabs.')
        (allWords, allChars, allPOSs,
         allNERs) = NP2P_data_stream.collect_vocabs(trainset)
        print('Number of words: {}'.format(len(allWords)))
        print('Number of allChars: {}'.format(len(allChars)))
        print('Number of allPOSs: {}'.format(len(allPOSs)))
        print('Number of allNERs: {}'.format(len(allNERs)))

        if FLAGS.with_word:
            word_vocab = Vocab(FLAGS.word_vec_path, fileformat='txt2')
        if FLAGS.with_char:
            char_vocab = Vocab(voc=allChars,
                               dim=FLAGS.char_dim,
                               fileformat='build')
            char_vocab.dump_to_txt2(path_prefix + ".char_vocab")
        if FLAGS.with_POS:
            POS_vocab = Vocab(voc=allPOSs,
                              dim=FLAGS.POS_dim,
                              fileformat='build')
            POS_vocab.dump_to_txt2(path_prefix + ".POS_vocab")
        if FLAGS.with_NER:
            NER_vocab = Vocab(voc=allNERs,
                              dim=FLAGS.NER_dim,
                              fileformat='build')
            NER_vocab.dump_to_txt2(path_prefix + ".NER_vocab")

    print('word vocab size {}'.format(word_vocab.vocab_size))
    sys.stdout.flush()

    print('Build DataStream ... ')
    trainDataStream = NP2P_data_stream.QADataStream(trainset,
                                                    word_vocab,
                                                    char_vocab,
                                                    POS_vocab,
                                                    NER_vocab,
                                                    options=FLAGS,
                                                    isShuffle=True,
                                                    isLoop=True,
                                                    isSort=True)

    devDataStream = NP2P_data_stream.QADataStream(testset,
                                                  word_vocab,
                                                  char_vocab,
                                                  POS_vocab,
                                                  NER_vocab,
                                                  options=FLAGS,
                                                  isShuffle=False,
                                                  isLoop=False,
                                                  isSort=True)
    print('Number of instances in trainDataStream: {}'.format(
        trainDataStream.get_num_instance()))
    print('Number of instances in devDataStream: {}'.format(
        devDataStream.get_num_instance()))
    print('Number of batches in trainDataStream: {}'.format(
        trainDataStream.get_num_batch()))
    print('Number of batches in devDataStream: {}'.format(
        devDataStream.get_num_batch()))
    sys.stdout.flush()

    init_scale = 0.01
    with tf.Graph().as_default():
        initializer = tf.random_uniform_initializer(-init_scale, init_scale)
        with tf.name_scope("Train"):
            with tf.variable_scope("Model",
                                   reuse=None,
                                   initializer=initializer):
                train_graph = ModelGraph(word_vocab=word_vocab,
                                         char_vocab=char_vocab,
                                         POS_vocab=POS_vocab,
                                         NER_vocab=NER_vocab,
                                         options=FLAGS,
                                         mode="rl_train_for_phrase")

        with tf.name_scope("Valid"):
            with tf.variable_scope("Model",
                                   reuse=True,
                                   initializer=initializer):
                valid_graph = ModelGraph(word_vocab=word_vocab,
                                         char_vocab=char_vocab,
                                         POS_vocab=POS_vocab,
                                         NER_vocab=NER_vocab,
                                         options=FLAGS,
                                         mode="decode")

        initializer = tf.global_variables_initializer()

        vars_ = {}
        for var in tf.all_variables():
            if "word_embedding" in var.name: continue
            if not var.name.startswith("Model"): continue
            vars_[var.name.split(":")[0]] = var
        saver = tf.train.Saver(vars_)

        sess = tf.Session()
        sess.run(initializer)
        if has_pretrained_model:
            print("Restoring model from " + init_model_prefix + ".best.model")
            saver.restore(sess, init_model_prefix + ".best.model")
            print("DONE!")
        sys.stdout.flush()

        # for first-time rl training, we get the current BLEU score
        print("First-time rl training, get the current BLEU score on dev")
        sys.stdout.flush()
        best_bleu = evaluate(sess,
                             valid_graph,
                             devDataStream,
                             word_vocab,
                             options=FLAGS)
        print('First-time bleu = %.4f' % best_bleu)
        log_file.write('First-time bleu = %.4f\n' % best_bleu)

        print('Start the training loop.')
        sys.stdout.flush()
        train_size = trainDataStream.get_num_batch()
        max_steps = train_size * FLAGS.max_epochs
        total_loss = 0.0
        start_time = time.time()
        for step in xrange(max_steps):
            cur_batch = trainDataStream.nextBatch()
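            # With a baseline enabled, greedy-decode the same batch first; its sentence
            # BLEU is subtracted from the sampled reward below to reduce variance.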
            if FLAGS.with_baseline:
                # greedy search
                (greedy_sentences, _, _,
                 _) = NP2P_beam_decoder.search(sess,
                                               valid_graph,
                                               word_vocab,
                                               cur_batch,
                                               FLAGS,
                                               decode_mode="greedy")

            if FLAGS.with_target_lattice:
                (sampled_sentences, sampled_prediction_lengths,
                 sampled_generator_input_idx, sampled_generator_output_idx
                 ) = cur_batch.sample_a_partition()
            else:
                # multinomial sampling
                (sampled_sentences, sampled_prediction_lengths,
                 sampled_generator_input_idx,
                 sampled_generator_output_idx) = NP2P_beam_decoder.search(
                     sess,
                     valid_graph,
                     word_vocab,
                     cur_batch,
                     FLAGS,
                     decode_mode="multinomial")
            # calculate rewards
            rewards = []
            for i in xrange(cur_batch.batch_size):
                #                 print(sampled_sentences[i])
                #                 print(sampled_generator_input_idx[i])
                #                 print(sampled_generator_output_idx[i])
                cur_toks = cur_batch.instances[i][1].tokText.split()
                #                 r = sentence_bleu([cur_toks], sampled_sentences[i].split(), smoothing_function=cc.method3)
                r = 1.0
                b = 0.0
                if FLAGS.with_baseline:
                    b = sentence_bleu([cur_toks],
                                      greedy_sentences[i].split(),
                                      smoothing_function=cc.method3)


                #                 r = metric_utils.evaluate_captions([cur_toks],[sampled_sentences[i]])
                #                 b = metric_utils.evaluate_captions([cur_toks],[greedy_sentences[i]])
                rewards.append(1.0 * (r - b))
            rewards = np.array(rewards, dtype=np.float32)
            #             sys.exit(-1)

            # update parameters
            feed_dict = train_graph.run_encoder(sess,
                                                cur_batch,
                                                FLAGS,
                                                only_feed_dict=True)
            feed_dict[train_graph.reward] = rewards
            feed_dict[
                train_graph.gen_input_words] = sampled_generator_input_idx
            feed_dict[
                train_graph.in_answer_words] = sampled_generator_output_idx
            feed_dict[train_graph.answer_lengths] = sampled_prediction_lengths
            (_,
             loss_value) = sess.run([train_graph.train_op, train_graph.loss],
                                    feed_dict)
            total_loss += loss_value

            if step % 100 == 0:
                print('{} '.format(step), end="")
                sys.stdout.flush()

            # Save a checkpoint and evaluate the model periodically.
            if (step + 1) % trainDataStream.get_num_batch() == 0 or (
                    step + 1) == max_steps:
                print()
                duration = time.time() - start_time
                print('Step %d: loss = %.2f (%.3f sec)' %
                      (step, total_loss, duration))
                log_file.write('Step %d: loss = %.2f (%.3f sec)\n' %
                               (step, total_loss, duration))
                log_file.flush()
                sys.stdout.flush()
                total_loss = 0.0

                # Evaluate against the validation set.
                start_time = time.time()
                print('Validation Data Eval:')
                dev_bleu = evaluate(sess,
                                    valid_graph,
                                    devDataStream,
                                    word_vocab,
                                    options=FLAGS)
                print('Dev bleu = %.4f' % dev_bleu)
                log_file.write('Dev bleu = %.4f\n' % dev_bleu)
                log_file.flush()
                if best_bleu < dev_bleu:
                    print('Saving weights, BLEU {} (prev_best) < {} (cur)'.
                          format(best_bleu, dev_bleu))
                    best_bleu = dev_bleu
                    saver.save(sess, best_path)  # TODO: save model
                duration = time.time() - start_time
                print('Duration %.3f sec' % (duration))
                sys.stdout.flush()
                log_file.write('Duration %.3f sec\n' % (duration))
                log_file.flush()

    log_file.close()
Example #15
def main(FLAGS):
    train_path = FLAGS.train_path
    dev_path = FLAGS.dev_path
    word_vec_path = FLAGS.word_vec_path
    log_dir = FLAGS.model_dir
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    path_prefix = log_dir + "/SentenceMatch.{}".format(FLAGS.suffix)

    namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json")

    # build vocabs
    word_vocab = Vocab(word_vec_path, fileformat='txt3')

    best_path = path_prefix + '.best.model'
    char_path = path_prefix + ".char_vocab"
    label_path = path_prefix + ".label_vocab"
    has_pre_trained_model = False
    char_vocab = None
    if os.path.exists(best_path + ".index"):
        has_pre_trained_model = True
        logger.info('Loading vocabs from a pre-trained model ...')
        label_vocab = Vocab(label_path, fileformat='txt2')
        if FLAGS.with_char: char_vocab = Vocab(char_path, fileformat='txt2')
    else:
        logger.info('Collecting words, chars and labels ...')
        (all_words, all_chars, all_labels, all_POSs,
         all_NERs) = collect_vocabs(train_path)
        logger.info('Number of words: {}'.format(len(all_words)))
        label_vocab = Vocab(fileformat='voc', voc=all_labels, dim=2)
        label_vocab.dump_to_txt2(label_path)

        if FLAGS.with_char:
            logger.info('Number of chars: {}'.format(len(all_chars)))
            char_vocab = Vocab(fileformat='voc',
                               voc=all_chars,
                               dim=FLAGS.char_emb_dim)
            char_vocab.dump_to_txt2(char_path)

    logger.info('word_vocab shape is {}'.format(word_vocab.word_vecs.shape))
    num_classes = label_vocab.size()
    logger.info("Number of labels: {}".format(num_classes))
    sys.stdout.flush()

    logger.info('Build SentenceMatchDataStream ... ')
    trainDataStream = SentenceMatchDataStream(train_path,
                                              word_vocab=word_vocab,
                                              char_vocab=char_vocab,
                                              label_vocab=label_vocab,
                                              isShuffle=True,
                                              isLoop=True,
                                              isSort=True,
                                              options=FLAGS)
    logger.info('Number of instances in trainDataStream: {}'.format(
        trainDataStream.get_num_instance()))
    logger.info('Number of batches in trainDataStream: {}'.format(
        trainDataStream.get_num_batch()))
    sys.stdout.flush()

    devDataStream = SentenceMatchDataStream(dev_path,
                                            word_vocab=word_vocab,
                                            char_vocab=char_vocab,
                                            label_vocab=label_vocab,
                                            isShuffle=False,
                                            isLoop=True,
                                            isSort=True,
                                            options=FLAGS)
    logger.info('Number of instances in devDataStream: {}'.format(
        devDataStream.get_num_instance()))
    logger.info('Number of batches in devDataStream: {}'.format(
        devDataStream.get_num_batch()))
    sys.stdout.flush()

    init_scale = 0.01
    with tf.Graph().as_default():
        initializer = tf.random_uniform_initializer(-init_scale, init_scale)
        global_step = tf.train.get_or_create_global_step()
        with tf.variable_scope("Model", reuse=None, initializer=initializer):
            train_graph = SentenceMatchModelGraph(num_classes,
                                                  word_vocab=word_vocab,
                                                  char_vocab=char_vocab,
                                                  is_training=True,
                                                  options=FLAGS,
                                                  global_step=global_step)

        with tf.variable_scope("Model", reuse=True, initializer=initializer):
            valid_graph = SentenceMatchModelGraph(num_classes,
                                                  word_vocab=word_vocab,
                                                  char_vocab=char_vocab,
                                                  is_training=False,
                                                  options=FLAGS)

        initializer = tf.global_variables_initializer()
        vars_ = {}
        for var in tf.global_variables():
            if "word_embedding" in var.name:
                continue
            vars_[var.name.split(":")[0]] = var
        saver = tf.train.Saver(vars_)

        sess = tf.Session()
        # Initialize the summary writer and add the current TensorFlow graph to the log
        train_writer = tf.summary.FileWriter(SUMMARY_DIR, sess.graph)
        # valid_writer = tf.summary.FileWriter(SUMMARY_DIR + '/valid')

        sess.run(initializer)
        if has_pre_trained_model:
            logger.info("Restoring model from " + best_path)
            saver.restore(sess, best_path)
            logger.info("DONE!")

        # training
        train(sess, saver, train_graph, valid_graph, trainDataStream,
              devDataStream, FLAGS, best_path, train_writer, label_vocab)

        train_writer.close()
Example #16
def main(_):
    print('Configurations:')
    print(FLAGS)

    log_dir = FLAGS.model_dir
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    path_prefix = log_dir + "/G2S.{}".format(FLAGS.suffix)
    log_file_path = path_prefix + ".log"
    print('Log file path: {}'.format(log_file_path))
    log_file = open(log_file_path, 'wt')
    log_file.write("{}\n".format(FLAGS))
    log_file.flush()

    # save configuration
    namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json")

    print('Loading train set.')
    trainset, trn_node, trn_in_neigh, trn_out_neigh, trn_sent = G2S_data_stream.read_amr_file(
        FLAGS.train_path)
    print('Number of training samples: {}'.format(len(trainset)))

    print('Loading dev set.')
    devset, tst_node, tst_in_neigh, tst_out_neigh, tst_sent = G2S_data_stream.read_amr_file(
        FLAGS.test_path)
    print('Number of dev samples: {}'.format(len(devset)))

    if FLAGS.finetune_path != "":
        print('Loading finetune set.')
        ftset, ft_node, ft_in_neigh, ft_out_neigh, ft_sent = G2S_data_stream.read_amr_file(
            FLAGS.finetune_path)
        print('Number of finetune samples: {}'.format(len(ftset)))
    else:
        ftset, ft_node, ft_in_neigh, ft_out_neigh, ft_sent = (None, 0, 0, 0, 0)

    max_node = max(trn_node, tst_node, ft_node)
    max_in_neigh = max(trn_in_neigh, tst_in_neigh, ft_in_neigh)
    max_out_neigh = max(trn_out_neigh, tst_out_neigh, ft_out_neigh)
    max_sent = max(trn_sent, tst_sent, ft_sent)
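    # Report the observed maxima so it is clear when the FLAGS limits will truncate data.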
    print('Max node number: {}, while max allowed is {}'.format(
        max_node, FLAGS.max_node_num))
    print('Max parent number: {}, truncated to {}'.format(
        max_in_neigh, FLAGS.max_in_neigh_num))
    print('Max children number: {}, truncated to {}'.format(
        max_out_neigh, FLAGS.max_out_neigh_num))
    print('Max answer length: {}, truncated to {}'.format(
        max_sent, FLAGS.max_answer_len))

    word_vocab = None
    char_vocab = None
    edgelabel_vocab = None
    has_pretrained_model = False
    best_path = path_prefix + ".best.model"
    if os.path.exists(best_path + ".index"):
        has_pretrained_model = True
        print('!!Existing pretrained model. Loading vocabs.')
        word_vocab = Vocab(FLAGS.word_vec_path, fileformat='txt2')
        print('word_vocab: {}'.format(word_vocab.word_vecs.shape))
        char_vocab = None
        if FLAGS.with_char:
            char_vocab = Vocab(path_prefix + ".char_vocab", fileformat='txt2')
            print('char_vocab: {}'.format(char_vocab.word_vecs.shape))
        edgelabel_vocab = Vocab(path_prefix + ".edgelabel_vocab",
                                fileformat='txt2')
    else:
        print('Collecting vocabs.')
        (allWords, allChars,
         allEdgelabels) = G2S_data_stream.collect_vocabs(trainset)
        print('Number of words: {}'.format(len(allWords)))
        print('Number of allChars: {}'.format(len(allChars)))
        print('Number of allEdgelabels: {}'.format(len(allEdgelabels)))

        word_vocab = Vocab(FLAGS.word_vec_path, fileformat='txt2')
        char_vocab = None
        if FLAGS.with_char:
            char_vocab = Vocab(voc=allChars,
                               dim=FLAGS.char_dim,
                               fileformat='build')
            char_vocab.dump_to_txt2(path_prefix + ".char_vocab")
        edgelabel_vocab = Vocab(voc=allEdgelabels,
                                dim=FLAGS.edgelabel_dim,
                                fileformat='build')
        edgelabel_vocab.dump_to_txt2(path_prefix + ".edgelabel_vocab")

    print('word vocab size {}'.format(word_vocab.vocab_size))
    sys.stdout.flush()

    print('Build DataStream ... ')
    trainDataStream = G2S_data_stream.G2SDataStream(trainset,
                                                    word_vocab,
                                                    char_vocab,
                                                    edgelabel_vocab,
                                                    options=FLAGS,
                                                    isShuffle=True,
                                                    isLoop=True,
                                                    isSort=True)

    devDataStream = G2S_data_stream.G2SDataStream(devset,
                                                  word_vocab,
                                                  char_vocab,
                                                  edgelabel_vocab,
                                                  options=FLAGS,
                                                  isShuffle=False,
                                                  isLoop=False,
                                                  isSort=True)
    print('Number of instances in trainDataStream: {}'.format(
        trainDataStream.get_num_instance()))
    print('Number of instances in devDataStream: {}'.format(
        devDataStream.get_num_instance()))
    print('Number of batches in trainDataStream: {}'.format(
        trainDataStream.get_num_batch()))
    print('Number of batches in devDataStream: {}'.format(
        devDataStream.get_num_batch()))
    if ftset is not None:
        ftDataStream = G2S_data_stream.G2SDataStream(ftset,
                                                     word_vocab,
                                                     char_vocab,
                                                     edgelabel_vocab,
                                                     options=FLAGS,
                                                     isShuffle=True,
                                                     isLoop=True,
                                                     isSort=True)
        print('Number of instances in ftDataStream: {}'.format(
            ftDataStream.get_num_instance()))
        print('Number of batches in ftDataStream: {}'.format(
            ftDataStream.get_num_batch()))

    sys.stdout.flush()

    # initialize the best bleu and accu scores for current training session
    best_accu = FLAGS.best_accu if 'best_accu' in FLAGS.__dict__ else 0.0
    best_bleu = FLAGS.best_bleu if 'best_bleu' in FLAGS.__dict__ else 0.0
    if best_accu > 0.0:
        print('With initial dev accuracy {}'.format(best_accu))
    if best_bleu > 0.0:
        print('With initial dev BLEU score {}'.format(best_bleu))

    init_scale = 0.01
    with tf.Graph().as_default():
        initializer = tf.random_uniform_initializer(-init_scale, init_scale)
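        # Train and Valid graphs live in the same "Model" variable scope; building
        # the valid graph with reuse=True makes it share the training weights.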
        with tf.name_scope("Train"):
            with tf.variable_scope("Model",
                                   reuse=None,
                                   initializer=initializer):
                train_graph = ModelGraph(word_vocab=word_vocab,
                                         Edgelabel_vocab=edgelabel_vocab,
                                         char_vocab=char_vocab,
                                         options=FLAGS,
                                         mode=FLAGS.mode)

        assert FLAGS.mode in ('ce_train', 'rl_train', 'transformer')
        valid_mode = 'evaluate' if FLAGS.mode in (
            'ce_train', 'transformer') else 'evaluate_bleu'

        with tf.name_scope("Valid"):
            with tf.variable_scope("Model",
                                   reuse=True,
                                   initializer=initializer):
                valid_graph = ModelGraph(word_vocab=word_vocab,
                                         Edgelabel_vocab=edgelabel_vocab,
                                         char_vocab=char_vocab,
                                         options=FLAGS,
                                         mode=valid_mode)

        initializer = tf.global_variables_initializer()

        vars_ = {}
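        # Keep only "Model"-scope variables and, when word vectors are fixed, drop
        # the word embedding so the Saver neither stores nor restores that matrix.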
        for var in tf.all_variables():
            if FLAGS.fix_word_vec and "word_embedding" in var.name: continue
            if not var.name.startswith("Model"): continue
            vars_[var.name.split(":")[0]] = var
        saver = tf.train.Saver(vars_)

        sess = tf.Session()
        sess.run(initializer)
        if has_pretrained_model:
            print("Restoring model from " + best_path)
            saver.restore(sess, best_path)
            print("DONE!")

            if FLAGS.mode == 'rl_train' and abs(best_bleu) < 0.00001:
                print("Getting BLEU score for the model")
                sys.stdout.flush()
                best_bleu = evaluate(sess,
                                     valid_graph,
                                     devDataStream,
                                     options=FLAGS)['dev_bleu']
                FLAGS.best_bleu = best_bleu
                namespace_utils.save_namespace(FLAGS,
                                               path_prefix + ".config.json")
                print('BLEU = %.4f' % best_bleu)
                sys.stdout.flush()
                log_file.write('BLEU = %.4f\n' % best_bleu)
            if FLAGS.mode in ('ce_train', 'rl_train',
                              'transformer') and abs(best_accu) < 0.00001:
                print("Getting ACCU score for the model")
                best_accu = evaluate(sess,
                                     valid_graph,
                                     devDataStream,
                                     options=FLAGS)['dev_accu']
                FLAGS.best_accu = best_accu
                namespace_utils.save_namespace(FLAGS,
                                               path_prefix + ".config.json")
                print('ACCU = %.4f' % best_accu)
                log_file.write('ACCU = %.4f\n' % best_accu)

        print('Start the training loop.')
        train_size = trainDataStream.get_num_batch()
        max_steps = train_size * FLAGS.max_epochs
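        # The train stream loops forever (isLoop=True), so training length is fixed
        # here as batch steps: batches-per-epoch * max_epochs.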
        total_loss = 0.0
        start_time = time.time()
        for step in xrange(max_steps):
            cur_batch = trainDataStream.nextBatch()
            if FLAGS.mode == 'rl_train':
                loss_value = train_graph.run_rl_training_subsample(
                    sess, cur_batch, FLAGS)
            elif FLAGS.mode in ('ce_train', 'rl_train', 'transformer'):
                loss_value = train_graph.run_ce_training(
                    sess, cur_batch, FLAGS)
            total_loss += loss_value

            if step % 100 == 0:
                print('{} '.format(step), end="")
                sys.stdout.flush()

            # Save a checkpoint and evaluate the model periodically: at the end of
            # every epoch, at the very last step, and every 2000 steps when an epoch
            # spans more than 10000 batches.
            if (step + 1) % trainDataStream.get_num_batch() == 0 or (step + 1) == max_steps or \
                    (trainDataStream.get_num_batch() > 10000 and (step+1)%2000 == 0):
                print()
                duration = time.time() - start_time
                print('Step %d: loss = %.2f (%.3f sec)' %
                      (step, total_loss, duration))
                log_file.write('Step %d: loss = %.2f (%.3f sec)\n' %
                               (step, total_loss, duration))
                log_file.flush()
                sys.stdout.flush()
                total_loss = 0.0

                if ftset is not None:
                    best_accu, best_bleu = fine_tune(sess, saver, FLAGS,
                                                     log_file, ftDataStream,
                                                     devDataStream,
                                                     train_graph, valid_graph,
                                                     path_prefix, best_accu,
                                                     best_bleu)
                else:
                    best_accu, best_bleu = validate_and_save(
                        sess, saver, FLAGS, log_file, devDataStream,
                        valid_graph, path_prefix, best_accu, best_bleu)
                start_time = time.time()

    log_file.close()
Example #17
def main(_):
    print('Configurations:')
    print(FLAGS)

    log_dir = FLAGS.model_dir
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    path_prefix = log_dir + "/MHQA.{}".format(FLAGS.suffix)
    log_file_path = path_prefix + ".log"
    print('Log file path: {}'.format(log_file_path))
    log_file = open(log_file_path, 'wt')
    log_file.write("{}\n".format(FLAGS))
    log_file.flush()

    # save configuration
    namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json")

    print('Loading train set.')
    trainset, _, _ = MHQA_data_stream.read_data_file(FLAGS.train_path, FLAGS)
    print('Number of training samples: {}'.format(len(trainset)))

    print('Loading dev set.')
    devset, _, _ = MHQA_data_stream.read_data_file(FLAGS.dev_path, FLAGS)
    print('Number of dev samples: {}'.format(len(devset)))

    word_vocab = None
    char_vocab = None
    has_pretrained_model = False
    best_path = path_prefix + ".best.model"
    if os.path.exists(best_path + ".index"):
        has_pretrained_model = True
        print('!!Existing pretrained model. Loading vocabs.')
        word_vocab = Vocab(FLAGS.word_vec_path, fileformat='txt2')
        print('word_vocab: {}'.format(word_vocab.word_vecs.shape))
        char_vocab = None
        if FLAGS.with_char:
            char_vocab = Vocab(path_prefix + ".char_vocab", fileformat='txt2')
            print('char_vocab: {}'.format(char_vocab.word_vecs.shape))
    else:
        print('Collecting vocabs.')
        (allWords, allChars) = MHQA_data_stream.collect_vocabs(trainset)
        print('Number of words: {}'.format(len(allWords)))
        print('Number of allChars: {}'.format(len(allChars)))

        word_vocab = Vocab(FLAGS.word_vec_path, fileformat='txt2')
        char_vocab = None
        if FLAGS.with_char:
            char_vocab = Vocab(voc=allChars,
                               dim=FLAGS.char_dim,
                               fileformat='build')
            char_vocab.dump_to_txt2(path_prefix + ".char_vocab")

    print('word vocab size {}'.format(word_vocab.vocab_size))
    sys.stdout.flush()

    print('Build DataStream ... ')
    trainDataStream = MHQA_data_stream.DataStream(trainset,
                                                  word_vocab,
                                                  char_vocab,
                                                  options=FLAGS,
                                                  isShuffle=True,
                                                  isLoop=True,
                                                  isSort=True,
                                                  has_ref=True)

    devDataStream = MHQA_data_stream.DataStream(devset,
                                                word_vocab,
                                                char_vocab,
                                                options=FLAGS,
                                                isShuffle=False,
                                                isLoop=False,
                                                isSort=True,
                                                has_ref=True)
    print('Number of instances in trainDataStream: {}'.format(
        trainDataStream.get_num_instance()))
    print('Number of instances in devDataStream: {}'.format(
        devDataStream.get_num_instance()))
    print('Number of batches in trainDataStream: {}'.format(
        trainDataStream.get_num_batch()))
    print('Number of batches in devDataStream: {}'.format(
        devDataStream.get_num_batch()))

    sys.stdout.flush()

    # initialize the best bleu and accu scores for current training session
    best_accu = FLAGS.best_accu if 'best_accu' in FLAGS.__dict__ else 0.0
    if best_accu > 0.0:
        print('With initial dev accuracy {}'.format(best_accu))

    init_scale = 0.01
    with tf.Graph().as_default():
        initializer = tf.random_uniform_initializer(-init_scale, init_scale)
        with tf.name_scope("Train"):
            with tf.variable_scope("Model",
                                   reuse=None,
                                   initializer=initializer):
                train_graph = ModelGraph(word_vocab=word_vocab,
                                         char_vocab=char_vocab,
                                         options=FLAGS,
                                         has_ref=True,
                                         is_training=True)

        with tf.name_scope("Valid"):
            with tf.variable_scope("Model",
                                   reuse=True,
                                   initializer=initializer):
                valid_graph = ModelGraph(word_vocab=word_vocab,
                                         char_vocab=char_vocab,
                                         options=FLAGS,
                                         has_ref=True,
                                         is_training=False)

        initializer = tf.global_variables_initializer()

        vars_ = {}
        for var in tf.all_variables():
            if FLAGS.fix_word_vec and "word_embedding" in var.name: continue
            if not var.name.startswith("Model"): continue
            vars_[var.name.split(":")[0]] = var
            print(var)
        saver = tf.train.Saver(vars_)

        sess = tf.Session()
        sess.run(initializer)
        if has_pretrained_model:
            print("Restoring model from " + best_path)
            saver.restore(sess, best_path)
            print("DONE!")

            if abs(best_accu) < 0.00001:
                print("Getting ACCU score for the model")
                best_accu = evaluate(sess,
                                     valid_graph,
                                     devDataStream,
                                     options=FLAGS)['dev_accu']
                FLAGS.best_accu = best_accu
                namespace_utils.save_namespace(FLAGS,
                                               path_prefix + ".config.json")
                print('ACCU = %.4f' % best_accu)
                log_file.write('ACCU = %.4f\n' % best_accu)

        print('Start the training loop.')
        train_size = trainDataStream.get_num_batch()
        max_steps = train_size * FLAGS.max_epochs
        total_loss = 0.0
        start_time = time.time()
        for step in xrange(max_steps):
            cur_batch = trainDataStream.nextBatch()
            cur_batch = MHQA_data_stream.BatchPadded(cur_batch)
            _, cur_loss, _ = train_graph.execute(sess, cur_batch, FLAGS)
            total_loss += cur_loss

            if step % 100 == 0:
                print('{} '.format(step), end="")
                sys.stdout.flush()

            # Save a checkpoint and evaluate the model periodically.
            if (step + 1) % trainDataStream.get_num_batch() == 0 or (step + 1) == max_steps or \
                    (trainDataStream.get_num_batch() > 10000 and (step+1)%2000 == 0):
                print()
                duration = time.time() - start_time
                print('Step %d: loss = %.2f (%.3f sec)' %
                      (step, total_loss, duration))
                log_file.write('Step %d: loss = %.2f (%.3f sec)\n' %
                               (step, total_loss, duration))
                log_file.flush()
                sys.stdout.flush()
                total_loss = 0.0

                best_accu = validate_and_save(sess, saver, FLAGS, log_file,
                                              devDataStream, valid_graph,
                                              path_prefix, best_accu)
                start_time = time.time()

    log_file.close()
def main_cv(FLAGS):
    # np.random.seed(FLAGS.seed)

    for fold in range(FLAGS.cv_folds):
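        # Each cross-validation fold trains its own model on its own train/dev
        # split and writes checkpoints to a per-fold model directory.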
        print("Start training fold " + str(fold))
        train_path = FLAGS.cv_train_path + str(fold) + '.tsv'
        train_feat_path = FLAGS.cv_train_feat_path + str(fold) + '.tsv'
        dev_path = FLAGS.cv_dev_path + str(fold) + '.tsv'
        dev_feat_path = FLAGS.cv_dev_feat_path + str(fold) + '.tsv'
        word_vec_path = FLAGS.word_vec_path
        char_vec_path = FLAGS.char_vec_path
        log_dir = FLAGS.model_dir + '/cv_fold_' + str(fold)
        if not os.path.exists(log_dir):
            os.makedirs(log_dir)

        path_prefix = log_dir + "/SentenceMatch.{}".format(FLAGS.suffix)

        namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json")

        # build vocabs
        word_vocab = Vocab(word_vec_path, fileformat='txt3')
        char_vocab = None

        best_path = path_prefix + '.best.model'
        char_path = path_prefix + ".char_vocab"
        label_path = path_prefix + ".label_vocab"
        has_pre_trained_model = False

        if os.path.exists(best_path + ".index"):
            has_pre_trained_model = True
            print('Loading vocabs from a pre-trained model ...')
            label_vocab = Vocab(label_path, fileformat='txt2')
            if FLAGS.with_char:
                char_vocab = Vocab(char_path, fileformat='txt2')
        else:
            print('Collecting words, chars and labels ...')
            (all_words, all_chars, all_labels, all_POSs,
             all_NERs) = collect_vocabs(train_path)
            print('Number of words: {}'.format(len(all_words)))
            label_vocab = Vocab(fileformat='voc', voc=all_labels, dim=2)
            label_vocab.dump_to_txt2(label_path)

            if FLAGS.with_char:
                print('Number of chars: {}'.format(len(all_chars)))
                if char_vec_path == "":
                    char_vocab = Vocab(fileformat='voc',
                                       voc=all_chars,
                                       dim=FLAGS.char_emb_dim)
                else:
                    char_vocab = Vocab(char_vec_path, fileformat='txt3')
                char_vocab.dump_to_txt2(char_path)

        print('word_vocab shape is {}'.format(word_vocab.word_vecs.shape))
        if FLAGS.with_char:
            print('char_vocab shape is {}'.format(char_vocab.word_vecs.shape))
        num_classes = label_vocab.size()
        print("Number of labels: {}".format(num_classes))
        sys.stdout.flush()

        print('Build SentenceMatchDataStream ... ')
        trainDataStream = SentenceMatchDataStream(train_path,
                                                  train_feat_path,
                                                  word_vocab=word_vocab,
                                                  char_vocab=char_vocab,
                                                  label_vocab=label_vocab,
                                                  isShuffle=True,
                                                  isLoop=True,
                                                  isSort=True,
                                                  options=FLAGS)
        print('Number of instances in trainDataStream: {}'.format(
            trainDataStream.get_num_instance()))
        print('Number of batches in trainDataStream: {}'.format(
            trainDataStream.get_num_batch()))
        sys.stdout.flush()

        devDataStream = SentenceMatchDataStream(dev_path,
                                                dev_feat_path,
                                                word_vocab=word_vocab,
                                                char_vocab=char_vocab,
                                                label_vocab=label_vocab,
                                                isShuffle=False,
                                                isLoop=True,
                                                isSort=True,
                                                options=FLAGS)
        print('Number of instances in devDataStream: {}'.format(
            devDataStream.get_num_instance()))
        print('Number of batches in devDataStream: {}'.format(
            devDataStream.get_num_batch()))
        sys.stdout.flush()

        init_scale = 0.01
        with tf.Graph().as_default():
            # tf.set_random_seed(FLAGS.seed)

            initializer = tf.random_uniform_initializer(
                -init_scale, init_scale)
            global_step = tf.train.get_or_create_global_step()
            with tf.variable_scope("Model",
                                   reuse=None,
                                   initializer=initializer):
                train_graph = SentenceMatchModelGraph(num_classes,
                                                      word_vocab=word_vocab,
                                                      char_vocab=char_vocab,
                                                      is_training=True,
                                                      options=FLAGS,
                                                      global_step=global_step)

            with tf.variable_scope("Model",
                                   reuse=True,
                                   initializer=initializer):
                valid_graph = SentenceMatchModelGraph(num_classes,
                                                      word_vocab=word_vocab,
                                                      char_vocab=char_vocab,
                                                      is_training=False,
                                                      options=FLAGS)

            initializer = tf.global_variables_initializer()
            initializer_local = tf.local_variables_initializer()
            vars_ = {}
            for var in tf.global_variables():
                if "word_embedding" in var.name: continue
                # if not var.name.startswith("Model"): continue
                vars_[var.name.split(":")[0]] = var
            saver = tf.train.Saver(vars_)

            sess = tf.Session()
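            # w_embedding / c_embedding appear to be placeholders: the (large)
            # embedding matrices are fed in at initialization time rather than baked
            # into the graph; this path assumes char_vocab was built (with_char).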
            sess.run(initializer,
                     feed_dict={
                         train_graph.w_embedding: word_vocab.word_vecs,
                         train_graph.c_embedding: char_vocab.word_vecs
                     })
            sess.run(initializer_local)
            if has_pre_trained_model:
                print("Restoring model from " + best_path)
                saver.restore(sess, best_path)
                print("DONE!")

            # training
            train(sess, saver, train_graph, valid_graph, trainDataStream,
                  devDataStream, FLAGS, best_path, label_vocab)
        print()
Example #19
def main():
    print(FLAGS.__dict__)
    log_dir = FLAGS.log_dir
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    path_prefix = log_dir + "/MHQA.{}".format(FLAGS.suffix)
    log_file_path = path_prefix + ".log"
    print('Log file path: {}'.format(log_file_path))
    log_file = open(log_file_path, 'wt')
    log_file.write("{}\n".format(FLAGS))
    log_file.flush()

    # save configuration
    namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()
    print('device: {}, n_gpu: {}, grad_accum_steps: {}'.format(
        device, n_gpu, FLAGS.grad_accum_steps))
    log_file.write('device: {}, n_gpu: {}, grad_accum_steps: {}\n'.format(
        device, n_gpu, FLAGS.grad_accum_steps))

    glove_vocab = None
    glove_embedding = None
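    # Static GloVe vectors are only loaded when the embedding model is not ELMo-based
    # (find() < 0 means 'elmo' does not occur in the model name).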
    if FLAGS.embedding_model.find('elmo') < 0:
        print('Loading GloVe model from: {}'.format(FLAGS.glove_path))
        glove_vocab, glove_embedding = MHQA_data_stream.load_glove(
            FLAGS.glove_path)

    print('Loading train set.')
    trainset, _ = MHQA_data_stream.read_data_file(FLAGS.train_path, FLAGS)
    trainset_batches = MHQA_data_stream.make_batches(trainset, FLAGS,
                                                     glove_vocab)
    print('Number of training samples: {}'.format(len(trainset)))
    print('Number of training batches: {}'.format(len(trainset_batches)))

    print('Loading dev set.')
    devset, _ = MHQA_data_stream.read_data_file(FLAGS.dev_path, FLAGS)
    devset_batches = MHQA_data_stream.make_batches(devset, FLAGS, glove_vocab)
    print('Number of dev samples: {}'.format(len(devset)))
    print('Number of dev batches: {}'.format(len(devset_batches)))

    # model
    print('Compiling model.')
    model = MHQA_model_graph.ModelGraph(FLAGS, glove_embedding)
    if os.path.exists(path_prefix + ".model.bin"):
        print('!!Existing pretrained model. Loading the model...')
        model.load_state_dict(torch.load(path_prefix + ".model.bin"))
    model.to(device)

    # pretrained performance
    best_accu = 0.0
    if os.path.exists(path_prefix + ".model.bin"):
        # Fall back to evaluating the restored model when no valid best_accu is recorded.
        best_accu = FLAGS.best_accu if 'best_accu' in FLAGS.__dict__ and abs(FLAGS.best_accu) > 1e-4 \
                else evaluate_dataset(model, devset_batches)['dev_accu']
        FLAGS.best_accu = best_accu
        print("!!Accuracy for pretrained model is {}".format(best_accu))

    # optimizer
    train_updates = len(trainset_batches) * FLAGS.num_epochs
    if FLAGS.grad_accum_steps > 1:
        train_updates = train_updates // FLAGS.grad_accum_steps
    if FLAGS.optim == 'bertadam':
        optimizer = BertAdam(model.parameters(),
                             lr=FLAGS.learning_rate,
                             warmup=FLAGS.warmup_proportion,
                             t_total=train_updates)
    elif FLAGS.optim == 'adam':
        optimizer = Adam(model.parameters(),
                         lr=FLAGS.learning_rate,
                         weight_decay=FLAGS.lambda_l2)
    else:
        assert False, 'unsupported optimizer type: {}'.format(FLAGS.optim)

    print('Start the training loop, total *updating* steps = {}'.format(
        train_updates))
    finished_steps, finished_epochs = 0, 0
    train_batch_ids = list(range(0, len(trainset_batches)))
    model.train()
    while finished_epochs < FLAGS.num_epochs:
        epoch_start = time.time()
        epoch_loss = []
        print('Current epoch takes {} steps'.format(len(train_batch_ids)))
        random.shuffle(train_batch_ids)
        start_time = time.time()
        for id in train_batch_ids:
            ori_batch = trainset_batches[id]
            batch = {k: v.to(device) if type(v) == torch.Tensor else v \
                    for k, v in ori_batch.items()}

            outputs = model(batch)
            loss = outputs['loss']
            epoch_loss.append(loss.item())
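            # On multiple GPUs the loss may come back as a per-replica vector, hence
            # the mean(); dividing by grad_accum_steps makes the accumulated gradient
            # match that of one large batch.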

            if n_gpu > 1:
                loss = loss.mean()
            if FLAGS.grad_accum_steps > 1:
                loss = loss / FLAGS.grad_accum_steps
            loss.backward()  # just calculate gradient

            finished_steps += 1
            if finished_steps % FLAGS.grad_accum_steps == 0:
                optimizer.step()
                optimizer.zero_grad()

            if finished_steps % 100 == 0:
                print('{} '.format(finished_steps), end="")
                sys.stdout.flush()
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

            # Save a checkpoint and evaluate the model periodically.
            if finished_steps > 0 and finished_steps % 1000 == 0:
                best_accu = validate_and_save(model, devset_batches, log_file,
                                              best_accu)
        duration = time.time() - start_time
        print('Training loss = %.2f (%.3f sec)' %
              (float(sum(epoch_loss)), duration))
        log_file.write('Training loss = %.2f (%.3f sec)\n' %
                       (float(sum(epoch_loss)), duration))
        finished_epochs += 1
        best_accu = validate_and_save(model, devset_batches, log_file,
                                      best_accu)

    log_file.close()
Example #20
def main(_):
    print('Configurations:')
    print(FLAGS)

    log_dir = FLAGS.model_dir
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    path_prefix = log_dir + "/NP2P.{}".format(FLAGS.suffix)
    log_file_path = path_prefix + ".log"
    print('Log file path: {}'.format(log_file_path))
    log_file = open(log_file_path, 'wt')
    log_file.write("{}\n".format(FLAGS))
    log_file.flush()

    # save configuration
    namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json")

    print('Loading train set.')
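    # 'fof' appears to be a file-of-files listing dataset shards, 'plain' a single
    # dataset file; any other value is treated as GQA-style question data.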
    if FLAGS.infile_format == 'fof':
        trainset, train_ans_len = NP2P_data_stream.read_generation_datasets_from_fof(
            FLAGS.train_path, isLower=FLAGS.isLower)
    elif FLAGS.infile_format == 'plain':
        trainset, train_ans_len = NP2P_data_stream.read_all_GenerationDatasets(
            FLAGS.train_path, isLower=FLAGS.isLower)
    else:
        trainset, train_ans_len = NP2P_data_stream.read_all_GQA_questions(
            FLAGS.train_path, isLower=FLAGS.isLower, switch=FLAGS.switch_qa)
    print('Number of training samples: {}'.format(len(trainset)))

    print('Loading test set.')
    if FLAGS.infile_format == 'fof':
        testset, test_ans_len = NP2P_data_stream.read_generation_datasets_from_fof(
            FLAGS.test_path, isLower=FLAGS.isLower)
    elif FLAGS.infile_format == 'plain':
        testset, test_ans_len = NP2P_data_stream.read_all_GenerationDatasets(
            FLAGS.test_path, isLower=FLAGS.isLower)
    else:
        testset, test_ans_len = NP2P_data_stream.read_all_GQA_questions(
            FLAGS.test_path, isLower=FLAGS.isLower, switch=FLAGS.switch_qa)
    print('Number of test samples: {}'.format(len(testset)))

    max_actual_len = max(train_ans_len, test_ans_len)
    print('Max answer length: {}, truncated to {}'.format(
        max_actual_len, FLAGS.max_answer_len))

    word_vocab = None
    POS_vocab = None
    NER_vocab = None
    char_vocab = None
    has_pretrained_model = False
    best_path = path_prefix + ".best.model"
    if os.path.exists(best_path + ".index"):
        has_pretrained_model = True
        print('!!Existing pretrained model. Loading vocabs.')
        if FLAGS.with_word:
            word_vocab = Vocab(FLAGS.word_vec_path, fileformat='txt2')
            print('word_vocab: {}'.format(word_vocab.word_vecs.shape))
        if FLAGS.with_char:
            char_vocab = Vocab(path_prefix + ".char_vocab", fileformat='txt2')
            print('char_vocab: {}'.format(char_vocab.word_vecs.shape))
        if FLAGS.with_POS:
            POS_vocab = Vocab(path_prefix + ".POS_vocab", fileformat='txt2')
            print('POS_vocab: {}'.format(POS_vocab.word_vecs.shape))
        if FLAGS.with_NER:
            NER_vocab = Vocab(path_prefix + ".NER_vocab", fileformat='txt2')
            print('NER_vocab: {}'.format(NER_vocab.word_vecs.shape))
    else:
        print('Collecting vocabs.')
        (allWords, allChars, allPOSs,
         allNERs) = NP2P_data_stream.collect_vocabs(trainset)
        print('Number of words: {}'.format(len(allWords)))
        print('Number of allChars: {}'.format(len(allChars)))
        print('Number of allPOSs: {}'.format(len(allPOSs)))
        print('Number of allNERs: {}'.format(len(allNERs)))

        if FLAGS.with_word:
            word_vocab = Vocab(FLAGS.word_vec_path, fileformat='txt2')
        if FLAGS.with_char:
            char_vocab = Vocab(voc=allChars,
                               dim=FLAGS.char_dim,
                               fileformat='build')
            char_vocab.dump_to_txt2(path_prefix + ".char_vocab")
        if FLAGS.with_POS:
            POS_vocab = Vocab(voc=allPOSs,
                              dim=FLAGS.POS_dim,
                              fileformat='build')
            POS_vocab.dump_to_txt2(path_prefix + ".POS_vocab")
        if FLAGS.with_NER:
            NER_vocab = Vocab(voc=allNERs,
                              dim=FLAGS.NER_dim,
                              fileformat='build')
            NER_vocab.dump_to_txt2(path_prefix + ".NER_vocab")

    print('word vocab size {}'.format(word_vocab.vocab_size))
    sys.stdout.flush()

    print('Build DataStream ... ')
    trainDataStream = NP2P_data_stream.QADataStream(trainset,
                                                    word_vocab,
                                                    char_vocab,
                                                    POS_vocab,
                                                    NER_vocab,
                                                    options=FLAGS,
                                                    isShuffle=True,
                                                    isLoop=True,
                                                    isSort=True)

    devDataStream = NP2P_data_stream.QADataStream(testset,
                                                  word_vocab,
                                                  char_vocab,
                                                  POS_vocab,
                                                  NER_vocab,
                                                  options=FLAGS,
                                                  isShuffle=False,
                                                  isLoop=False,
                                                  isSort=True)
    print('Number of instances in trainDataStream: {}'.format(
        trainDataStream.get_num_instance()))
    print('Number of instances in devDataStream: {}'.format(
        devDataStream.get_num_instance()))
    print('Number of batches in trainDataStream: {}'.format(
        trainDataStream.get_num_batch()))
    print('Number of batches in devDataStream: {}'.format(
        devDataStream.get_num_batch()))
    sys.stdout.flush()

    init_scale = 0.01
    # initialize the best bleu and accu scores for current training session
    best_accu = FLAGS.best_accu if 'best_accu' in FLAGS.__dict__ else 0.0
    best_bleu = FLAGS.best_bleu if 'best_bleu' in FLAGS.__dict__ else 0.0
    if best_accu > 0.0:
        print('With initial dev accuracy {}'.format(best_accu))
    if best_bleu > 0.0:
        print('With initial dev BLEU score {}'.format(best_bleu))

    with tf.Graph().as_default():
        initializer = tf.random_uniform_initializer(-init_scale, init_scale)
        with tf.name_scope("Train"):
            with tf.variable_scope("Model",
                                   reuse=None,
                                   initializer=initializer):
                train_graph = ModelGraph(word_vocab=word_vocab,
                                         char_vocab=char_vocab,
                                         POS_vocab=POS_vocab,
                                         NER_vocab=NER_vocab,
                                         options=FLAGS,
                                         mode=FLAGS.mode)

        assert FLAGS.mode in (
            'ce_train',
            'rl_train',
        )
        valid_mode = 'evaluate' if FLAGS.mode == 'ce_train' else 'evaluate_bleu'

        with tf.name_scope("Valid"):
            with tf.variable_scope("Model",
                                   reuse=True,
                                   initializer=initializer):
                valid_graph = ModelGraph(word_vocab=word_vocab,
                                         char_vocab=char_vocab,
                                         POS_vocab=POS_vocab,
                                         NER_vocab=NER_vocab,
                                         options=FLAGS,
                                         mode=valid_mode)

        initializer = tf.global_variables_initializer()

        vars_ = {}
        for var in tf.all_variables():
            if "word_embedding" in var.name: continue
            if not var.name.startswith("Model"): continue
            vars_[var.name.split(":")[0]] = var
        saver = tf.train.Saver(vars_)

        sess = tf.Session()
        sess.run(initializer)
        if has_pretrained_model:
            print("Restoring model from " + best_path)
            saver.restore(sess, best_path)
            print("DONE!")

            if FLAGS.mode == 'rl_train' and abs(best_bleu) < 0.00001:
                print("Getting BLEU score for the model")
                best_bleu = evaluate(sess,
                                     valid_graph,
                                     devDataStream,
                                     options=FLAGS)['dev_bleu']
                FLAGS.best_bleu = best_bleu
                namespace_utils.save_namespace(FLAGS,
                                               path_prefix + ".config.json")
                print('BLEU = %.4f' % best_bleu)
                log_file.write('BLEU = %.4f\n' % best_bleu)
            if FLAGS.mode == 'ce_train' and abs(best_accu) < 0.00001:
                print("Getting ACCU score for the model")
                best_accu = evaluate(sess,
                                     valid_graph,
                                     devDataStream,
                                     options=FLAGS)['dev_accu']
                FLAGS.best_accu = best_accu
                namespace_utils.save_namespace(FLAGS,
                                               path_prefix + ".config.json")
                print('ACCU = %.4f' % best_accu)
                log_file.write('ACCU = %.4f\n' % best_accu)

        print('Start the training loop.')
        train_size = trainDataStream.get_num_batch()
        max_steps = train_size * FLAGS.max_epochs
        total_loss = 0.0
        start_time = time.time()
        for step in xrange(max_steps):
            cur_batch = trainDataStream.nextBatch()
            if FLAGS.mode == 'rl_train':
                loss_value = train_graph.run_rl_training_2(
                    sess, cur_batch, FLAGS)
            elif FLAGS.mode == 'ce_train':
                loss_value = train_graph.run_ce_training(
                    sess, cur_batch, FLAGS)
            total_loss += loss_value

            if step % 100 == 0:
                print('{} '.format(step), end="")
                sys.stdout.flush()

            # Save a checkpoint and evaluate the model periodically.
            if (step + 1) % trainDataStream.get_num_batch() == 0 or (
                    step + 1) == max_steps:
                print()
                duration = time.time() - start_time
                print('Step %d: loss = %.2f (%.3f sec)' %
                      (step, total_loss, duration))
                log_file.write('Step %d: loss = %.2f (%.3f sec)\n' %
                               (step, total_loss, duration))
                log_file.flush()
                sys.stdout.flush()
                total_loss = 0.0

                # Evaluate against the validation set.
                start_time = time.time()
                print('Validation Data Eval:')
                res_dict = evaluate(sess,
                                    valid_graph,
                                    devDataStream,
                                    options=FLAGS,
                                    suffix=str(step))
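                # In 'evaluate' mode (ce_train) the tracked dev metric is accuracy,
                # otherwise (rl_train) it is BLEU; either way the checkpoint is only
                # saved when the metric improves on the previous best.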
                if valid_graph.mode == 'evaluate':
                    dev_loss = res_dict['dev_loss']
                    dev_accu = res_dict['dev_accu']
                    dev_right = int(res_dict['dev_right'])
                    dev_total = int(res_dict['dev_total'])
                    print('Dev loss = %.4f' % dev_loss)
                    log_file.write('Dev loss = %.4f\n' % dev_loss)
                    print('Dev accu = %.4f %d/%d' %
                          (dev_accu, dev_right, dev_total))
                    log_file.write('Dev accu = %.4f %d/%d\n' %
                                   (dev_accu, dev_right, dev_total))
                    log_file.flush()
                    if best_accu < dev_accu:
                        print('Saving weights, ACCU {} (prev_best) < {} (cur)'.
                              format(best_accu, dev_accu))
                        saver.save(sess, best_path)
                        best_accu = dev_accu
                        FLAGS.best_accu = dev_accu
                        namespace_utils.save_namespace(
                            FLAGS, path_prefix + ".config.json")
                else:
                    dev_bleu = res_dict['dev_bleu']
                    print('Dev bleu = %.4f' % dev_bleu)
                    log_file.write('Dev bleu = %.4f\n' % dev_bleu)
                    log_file.flush()
                    if best_bleu < dev_bleu:
                        print('Saving weights, BLEU {} (prev_best) < {} (cur)'.
                              format(best_bleu, dev_bleu))
                        saver.save(sess, best_path)
                        best_bleu = dev_bleu
                        FLAGS.best_bleu = dev_bleu
                        namespace_utils.save_namespace(
                            FLAGS, path_prefix + ".config.json")
                duration = time.time() - start_time
                print('Duration %.3f sec' % (duration))
                sys.stdout.flush()

                log_file.write('Duration %.3f sec\n' % (duration))
                log_file.flush()

    log_file.close()
Example #21
File: Main.py  Project: qikunxun/KEIM
def main(FLAGS):
    tf.logging.set_verbosity(tf.logging.INFO)
    train_path = FLAGS.train_path
    dev_path = FLAGS.dev_path
    test_path = FLAGS.test_path
    word_vec_path = FLAGS.word_vec_path
    kg_path = FLAGS.kg_path
    wordnet_path = FLAGS.wordnet_path
    lemma_vec_path = FLAGS.lemma_vec_path
    log_dir = FLAGS.model_dir
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
        os.makedirs(os.path.join(log_dir, '../result'))
        os.makedirs(os.path.join(log_dir, '../logits'))

    path_prefix = log_dir + "/KEIM.{}".format(FLAGS.suffix)
    namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json")

    # build vocabs
    word_vocab = Vocab(word_vec_path, fileformat='txt3')
    lemma_vocab = Vocab(lemma_vec_path, fileformat='txt3')
    best_path = path_prefix + '.best.model'
    char_path = path_prefix + ".char_vocab"
    label_path = path_prefix + ".label_vocab"
    char_vocab = None

    tf.logging.info('Collecting words, chars and labels ...')
    (all_words, all_chars, all_labels, all_POSs,
     all_NERs) = collect_vocabs(train_path)
    tf.logging.info('Number of words: {}'.format(len(all_words)))
    label_vocab = Vocab(fileformat='voc', voc=all_labels, dim=2)
    label_vocab.dump_to_txt2(label_path)

    if FLAGS.with_char:
        tf.logging.info('Number of chars: {}'.format(len(all_chars)))
        char_vocab = Vocab(fileformat='voc',
                           voc=all_chars,
                           dim=FLAGS.char_emb_dim)
        char_vocab.dump_to_txt2(char_path)

    tf.logging.info('word_vocab shape is {}'.format(
        word_vocab.word_vecs.shape))
    tf.logging.info('lemma_word_vocab shape is {}'.format(
        lemma_vocab.word_vecs.shape))
    num_classes = label_vocab.size()
    tf.logging.info("Number of labels: {}".format(num_classes))
    sys.stdout.flush()

    with open(wordnet_path, 'rb') as f:
        wordnet_vocab = pkl.load(f)
    tf.logging.info('wordnet_vocab shape is {}'.format(len(wordnet_vocab)))
    with open(kg_path, 'rb') as f:
        kg = pkl.load(f)
    tf.logging.info('kg shape is {}'.format(len(kg)))
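    # The WordNet vocab and KG dictionaries loaded above are handed to every
    # DataStream, presumably so batches carry the external knowledge features
    # the KEIM model consumes.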

    tf.logging.info('Build SentenceMatchDataStream ... ')
    trainDataStream = DataStream(train_path,
                                 word_vocab=word_vocab,
                                 char_vocab=char_vocab,
                                 label_vocab=None,
                                 kg=kg,
                                 wordnet_vocab=wordnet_vocab,
                                 lemma_vocab=lemma_vocab,
                                 isShuffle=True,
                                 isLoop=True,
                                 isSort=True,
                                 options=FLAGS)
    tf.logging.info('Number of instances in trainDataStream: {}'.format(
        trainDataStream.get_num_instance()))
    tf.logging.info('Number of batches in trainDataStream: {}'.format(
        trainDataStream.get_num_batch()))
    sys.stdout.flush()

    devDataStream = DataStream(dev_path,
                               word_vocab=word_vocab,
                               char_vocab=char_vocab,
                               label_vocab=None,
                               kg=kg,
                               wordnet_vocab=wordnet_vocab,
                               lemma_vocab=lemma_vocab,
                               isShuffle=True,
                               isLoop=True,
                               isSort=True,
                               options=FLAGS)
    tf.logging.info('Number of instances in devDataStream: {}'.format(
        devDataStream.get_num_instance()))
    tf.logging.info('Number of batches in devDataStream: {}'.format(
        devDataStream.get_num_batch()))
    sys.stdout.flush()

    testDataStream = DataStream(test_path,
                                word_vocab=word_vocab,
                                char_vocab=char_vocab,
                                label_vocab=None,
                                kg=kg,
                                wordnet_vocab=wordnet_vocab,
                                lemma_vocab=lemma_vocab,
                                isShuffle=True,
                                isLoop=True,
                                isSort=True,
                                options=FLAGS)

    tf.logging.info('Number of instances in testDataStream: {}'.format(
        testDataStream.get_num_instance()))
    tf.logging.info('Number of batches in testDataStream: {}'.format(
        testDataStream.get_num_batch()))
    sys.stdout.flush()

    with tf.Graph().as_default():
        initializer = tf.contrib.layers.xavier_initializer()
        # initializer = tf.truncated_normal_initializer(stddev=0.02)
        global_step = tf.train.get_or_create_global_step()
        with tf.variable_scope("Model", reuse=None, initializer=initializer):
            train_graph = Model(num_classes,
                                word_vocab=word_vocab,
                                char_vocab=char_vocab,
                                lemma_vocab=lemma_vocab,
                                is_training=True,
                                options=FLAGS,
                                global_step=global_step)
        with tf.variable_scope("Model", reuse=True, initializer=initializer):
            valid_graph = Model(num_classes,
                                word_vocab=word_vocab,
                                char_vocab=char_vocab,
                                lemma_vocab=lemma_vocab,
                                is_training=False,
                                options=FLAGS)

        initializer = tf.global_variables_initializer()
        vars_ = {}
        for var in tf.global_variables():
            # if "word_embedding" in var.name: continue
            # if not var.name.startswith("Model"): continue
            vars_[var.name.split(":")[0]] = var
        saver = tf.train.Saver(vars_)
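        # The session config below sets allow_growth so TensorFlow claims GPU memory
        # on demand rather than reserving the full per-process fraction up front.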

        gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=1)
        config = tf.ConfigProto(allow_soft_placement=True,
                                gpu_options=gpu_options)

        config.gpu_options.allow_growth = True
        with tf.Session(config=config) as sess:
            sess.run(initializer)
            # training
            train(sess, saver, train_graph, valid_graph, trainDataStream,
                  devDataStream, testDataStream, FLAGS, best_path)