Example #1
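The snippets on this page are function bodies lifted from a BiMPM-style sentence-matching project, so the module-level imports are not shown. A plausible preamble for Example #1 is sketched below; the module names are assumptions inferred from the identifiers used in the code, not verified against the original repository.

import os
import sys

import tensorflow as tf

import namespace_utils                                        # assumed module name
from vocab_utils import Vocab                                 # assumed module name
from SentenceMatchDataStream import SentenceMatchDataStream   # assumed module name
from SentenceMatchModelGraph import SentenceMatchModelGraph   # assumed module name
# collect_vocabs() and train() are assumed to be defined elsewhere in the same script.
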
def main(FLAGS):
    # np.random.seed(FLAGS.seed)

    train_path = FLAGS.train_path
    train_feat_path = FLAGS.train_feat_path
    dev_path = FLAGS.dev_path
    dev_feat_path = FLAGS.dev_feat_path
    word_vec_path = FLAGS.word_vec_path
    word_vec_path2 = FLAGS.word_vec_path2
    char_vec_path = FLAGS.char_vec_path
    log_dir = FLAGS.model_dir
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    path_prefix = log_dir + "/SentenceMatch.{}".format(FLAGS.suffix)

    namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json")

    # build vocabs
    word_vocab = Vocab(word_vec_path, fileformat='txt3')
    # word_vocab2 = Vocab(word_vec_path2, fileformat='txt3')
    # word_vocab = Vocab(word_vec_path, word_vec_path2, fileformat='txt4')
    char_vocab = None

    best_path = path_prefix + '.best.model'
    char_path = path_prefix + ".char_vocab"
    label_path = path_prefix + ".label_vocab"
    has_pre_trained_model = False

    if os.path.exists(best_path + ".index"):
        has_pre_trained_model = True
        print('Loading vocabs from a pre-trained model ...')
        label_vocab = Vocab(label_path, fileformat='txt2')
        if FLAGS.with_char: char_vocab = Vocab(char_path, fileformat='txt2')
    else:
        print('Collecting words, chars and labels ...')
        (all_words, all_chars, all_labels, all_POSs,
         all_NERs) = collect_vocabs(train_path)
        print('Number of words: {}'.format(len(all_words)))
        label_vocab = Vocab(fileformat='voc', voc=all_labels, dim=2)
        label_vocab.dump_to_txt2(label_path)

        if FLAGS.with_char:
            print('Number of chars: {}'.format(len(all_chars)))
            if char_vec_path == "":
                char_vocab = Vocab(fileformat='voc',
                                   voc=all_chars,
                                   dim=FLAGS.char_emb_dim)
            else:
                char_vocab = Vocab(char_vec_path, fileformat='txt3')
            char_vocab.dump_to_txt2(char_path)

    print('word_vocab shape is {}'.format(word_vocab.word_vecs.shape))
    if FLAGS.with_char:
        print('char_vocab shape is {}'.format(char_vocab.word_vecs.shape))
    num_classes = label_vocab.size()
    print("Number of labels: {}".format(num_classes))
    sys.stdout.flush()

    print('Build SentenceMatchDataStream ... ')
    trainDataStream = SentenceMatchDataStream(train_path,
                                              train_feat_path,
                                              word_vocab=word_vocab,
                                              char_vocab=char_vocab,
                                              label_vocab=label_vocab,
                                              isShuffle=True,
                                              isLoop=True,
                                              isSort=True,
                                              options=FLAGS)
    print('Number of instances in trainDataStream: {}'.format(
        trainDataStream.get_num_instance()))
    print('Number of batches in trainDataStream: {}'.format(
        trainDataStream.get_num_batch()))
    sys.stdout.flush()

    devDataStream = SentenceMatchDataStream(dev_path,
                                            dev_feat_path,
                                            word_vocab=word_vocab,
                                            char_vocab=char_vocab,
                                            label_vocab=label_vocab,
                                            isShuffle=False,
                                            isLoop=True,
                                            isSort=True,
                                            options=FLAGS)
    print('Number of instances in devDataStream: {}'.format(
        devDataStream.get_num_instance()))
    print('Number of batches in devDataStream: {}'.format(
        devDataStream.get_num_batch()))
    sys.stdout.flush()

    init_scale = 0.01
    with tf.Graph().as_default():
        # tf.set_random_seed(FLAGS.seed)

        initializer = tf.random_uniform_initializer(-init_scale, init_scale)
        global_step = tf.train.get_or_create_global_step()
        with tf.variable_scope("Model", reuse=None, initializer=initializer):
            train_graph = SentenceMatchModelGraph(num_classes,
                                                  word_vocab=word_vocab,
                                                  char_vocab=char_vocab,
                                                  is_training=True,
                                                  options=FLAGS,
                                                  global_step=global_step)

        with tf.variable_scope("Model", reuse=True, initializer=initializer):
            valid_graph = SentenceMatchModelGraph(num_classes,
                                                  word_vocab=word_vocab,
                                                  char_vocab=char_vocab,
                                                  is_training=False,
                                                  options=FLAGS)

        initializer = tf.global_variables_initializer()
        initializer_local = tf.local_variables_initializer()
        vars_ = {}
        for var in tf.global_variables():
            if "word_embedding" in var.name: continue
            # if not var.name.startswith("Model"): continue
            vars_[var.name.split(":")[0]] = var
        saver = tf.train.Saver(vars_)

        sess = tf.Session()
        # char_vocab is None when FLAGS.with_char is off, so only feed the char embedding
        # placeholder when a char vocab was actually built.
        init_feed = {train_graph.w_embedding: word_vocab.word_vecs}
        if char_vocab is not None:
            init_feed[train_graph.c_embedding] = char_vocab.word_vecs
        sess.run(initializer, feed_dict=init_feed)
        # sess.run(initializer, feed_dict={train_graph.w_embedding: word_vocab.word_vecs, train_graph.w_embedding_trainable: word_vocab2.word_vecs,
        #                                 train_graph.c_embedding: char_vocab.word_vecs})
        sess.run(initializer_local)
        if has_pre_trained_model:
            print("Restoring model from " + best_path)
            saver.restore(sess, best_path)
            print("DONE!")

        # training
        train(sess, saver, train_graph, valid_graph, trainDataStream,
              devDataStream, FLAGS, best_path, label_vocab)
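
Since main() in Example #1 receives the parsed flags as an argument, a small driver is needed to build them. The sketch below is not part of the original example; it only declares the flags read directly in main(), while the data-stream and model-graph code read further hyper-parameters (batch size, sentence-length limits, dropout rate, ...) from the same namespace via options=FLAGS.

import argparse

if __name__ == '__main__':
    # Hedged driver sketch: flag names mirror the attributes accessed in main() above.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument('--train_path', type=str, help='Path to the training data.')
    arg_parser.add_argument('--train_feat_path', type=str, help='Pre-computed features for the training data.')
    arg_parser.add_argument('--dev_path', type=str, help='Path to the dev data.')
    arg_parser.add_argument('--dev_feat_path', type=str, help='Pre-computed features for the dev data.')
    arg_parser.add_argument('--word_vec_path', type=str, help='Pre-trained word embeddings (txt3 format).')
    arg_parser.add_argument('--word_vec_path2', type=str, default=None, help='Optional second embedding file.')
    arg_parser.add_argument('--char_vec_path', type=str, default='', help='Pre-trained char embeddings, or "" to learn them.')
    arg_parser.add_argument('--model_dir', type=str, help='Directory for checkpoints and logs.')
    arg_parser.add_argument('--suffix', type=str, default='normal', help='Suffix used in the model file names.')
    arg_parser.add_argument('--with_char', default=False, action='store_true', help='Use character-level embeddings.')
    arg_parser.add_argument('--char_emb_dim', type=int, default=20, help='Dimension of learned char embeddings.')
    FLAGS, unparsed = arg_parser.parse_known_args()
    main(FLAGS)
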
Example #2
def main(_):
    print('Configurations:')
    print(FLAGS)

    train_path = FLAGS.train_path
    dev_path = FLAGS.dev_path
    test_path = FLAGS.test_path
    word_vec_path = FLAGS.word_vec_path
    log_dir = FLAGS.model_dir
    if FLAGS.train == "sick":
        train_path = FLAGS.SICK_train_path
        dev_path = FLAGS.SICK_dev_path
    if FLAGS.test == "sick":
        test_path = FLAGS.SICK_test_path
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    
    path_prefix = log_dir + "/SentenceMatch.{}".format(FLAGS.suffix)
    namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json")

    # build vocabs
    parser, image_feats = None, None
    if FLAGS.with_dep:
        parser = Parser('snli')
    if FLAGS.with_image:
        image_feats = ImageFeatures()
    word_vocab = Vocab(word_vec_path, fileformat='txt3', parser=parser, beginning=FLAGS.beginning)
    best_path = path_prefix + '.best.model'
    char_path = path_prefix + ".char_vocab"
    label_path = path_prefix + ".label_vocab"
    POS_path = path_prefix + ".POS_vocab"
    NER_path = path_prefix + ".NER_vocab"
    DEP_path = path_prefix + ".DEP_vocab"
    has_pre_trained_model = False
    POS_vocab = None
    NER_vocab = None
    DEP_vocab = None
    print('has pretrained model: ', os.path.exists(best_path + '.meta'))
    print('best_path: ' + best_path)
    if os.path.exists(best_path + '.meta'):
        has_pre_trained_model = True
        label_vocab = Vocab(label_path, fileformat='txt2')
        char_vocab = Vocab(char_path, fileformat='txt2')
        if FLAGS.with_POS: POS_vocab = Vocab(POS_path, fileformat='txt2')
        if FLAGS.with_NER: NER_vocab = Vocab(NER_path, fileformat='txt2')
    else:
        print('Collect words, chars and labels ...')
        (all_words, all_chars, all_labels, all_POSs, all_NERs) = collect_vocabs(train_path, with_POS=FLAGS.with_POS, with_NER=FLAGS.with_NER)
        print('Number of words: {}'.format(len(all_words)))
        print('Number of labels: {}'.format(len(all_labels)))
        label_vocab = Vocab(fileformat='voc', voc=all_labels,dim=2)
        label_vocab.dump_to_txt2(label_path)

        print('Number of chars: {}'.format(len(all_chars)))
        char_vocab = Vocab(fileformat='voc', voc=all_chars,dim=FLAGS.char_emb_dim, beginning=FLAGS.beginning)
        char_vocab.dump_to_txt2(char_path)
        
        if FLAGS.with_POS:
            print('Number of POSs: {}'.format(len(all_POSs)))
            POS_vocab = Vocab(fileformat='voc', voc=all_POSs,dim=FLAGS.POS_dim)
            POS_vocab.dump_to_txt2(POS_path)
        if FLAGS.with_NER:
            print('Number of NERs: {}'.format(len(all_NERs)))
            NER_vocab = Vocab(fileformat='voc', voc=all_NERs,dim=FLAGS.NER_dim)
            NER_vocab.dump_to_txt2(NER_path)
            

    print('word_vocab shape is {}'.format(word_vocab.word_vecs.shape))
    print('tag_vocab shape is {}'.format(label_vocab.word_vecs.shape))
    num_classes = label_vocab.size()

    print('Build DataStream ... ')
    print('Reading trainDataStream')
    if not FLAGS.decoding_only:
        trainDataStream = DataStream(train_path, word_vocab=word_vocab, char_vocab=char_vocab, 
                                              POS_vocab=POS_vocab, NER_vocab=NER_vocab, label_vocab=label_vocab, 
                                              batch_size=FLAGS.batch_size, isShuffle=True, isLoop=True, isSort=True, 
                                              max_char_per_word=FLAGS.max_char_per_word, max_sent_length=FLAGS.max_sent_length, 
                                              with_dep=FLAGS.with_dep, with_image=FLAGS.with_image, image_feats=image_feats, sick_data=(FLAGS.test == "sick"))
    
        print('Reading devDataStream')
        devDataStream = DataStream(dev_path, word_vocab=word_vocab, char_vocab=char_vocab,
                                              POS_vocab=POS_vocab, NER_vocab=NER_vocab, label_vocab=label_vocab, 
                                              batch_size=FLAGS.batch_size, isShuffle=False, isLoop=True, isSort=True, 
                                              max_char_per_word=FLAGS.max_char_per_word, max_sent_length=FLAGS.max_sent_length, 
                                              with_dep=FLAGS.with_dep, with_image=FLAGS.with_image, image_feats=image_feats, sick_data=(FLAGS.test == "sick"))

    print('Reading testDataStream')
    testDataStream = DataStream(test_path, word_vocab=word_vocab, char_vocab=char_vocab,
                                              POS_vocab=POS_vocab, NER_vocab=NER_vocab, label_vocab=label_vocab,
                                              batch_size=FLAGS.batch_size, isShuffle=False, isLoop=True, isSort=True,
                                              max_char_per_word=FLAGS.max_char_per_word, max_sent_length=FLAGS.max_sent_length,
                                              with_dep=FLAGS.with_dep, with_image=FLAGS.with_image, image_feats=image_feats, sick_data=(FLAGS.test == "sick"))
    print('save cache file')
    #word_vocab.parser.save_cache()
    #image_feats.save_feat()
    if not FLAGS.decoding_only:
        print('Number of instances in trainDataStream: {}'.format(trainDataStream.get_num_instance()))
        print('Number of instances in devDataStream: {}'.format(devDataStream.get_num_instance()))
        print('Number of batches in trainDataStream: {}'.format(trainDataStream.get_num_batch()))
        print('Number of batches in devDataStream: {}'.format(devDataStream.get_num_batch()))
    print('Number of instances in testDataStream: {}'.format(testDataStream.get_num_instance()))
    print('Number of batches in testDataStream: {}'.format(testDataStream.get_num_batch()))
    
    sys.stdout.flush()
    if FLAGS.wo_char: char_vocab = None

    best_accuracy = 0.0
    init_scale = 0.01
    
    if not FLAGS.decoding_only:
        with tf.Graph().as_default():
            initializer = tf.random_uniform_initializer(-init_scale, init_scale)
            with tf.variable_scope("Model", reuse=None, initializer=initializer):
                train_graph = ModelGraph(num_classes, word_vocab=word_vocab, char_vocab=char_vocab,POS_vocab=POS_vocab, NER_vocab=NER_vocab, 
                    dropout_rate=FLAGS.dropout_rate, learning_rate=FLAGS.learning_rate, optimize_type=FLAGS.optimize_type,
                    lambda_l2=FLAGS.lambda_l2, char_lstm_dim=FLAGS.char_lstm_dim, context_lstm_dim=FLAGS.context_lstm_dim, 
                    aggregation_lstm_dim=FLAGS.aggregation_lstm_dim, is_training=True, MP_dim=FLAGS.MP_dim, 
                    context_layer_num=FLAGS.context_layer_num, aggregation_layer_num=FLAGS.aggregation_layer_num, 
                    fix_word_vec=FLAGS.fix_word_vec,with_filter_layer=FLAGS.with_filter_layer, with_highway=FLAGS.with_highway,
                    word_level_MP_dim=FLAGS.word_level_MP_dim,
                    with_match_highway=FLAGS.with_match_highway, with_aggregation_highway=FLAGS.with_aggregation_highway,
                    highway_layer_num=FLAGS.highway_layer_num,with_lex_decomposition=FLAGS.with_lex_decomposition, 
                    lex_decompsition_dim=FLAGS.lex_decompsition_dim,
                    with_left_match=(not FLAGS.wo_left_match), with_right_match=(not FLAGS.wo_right_match),
                    with_full_match=(not FLAGS.wo_full_match), with_maxpool_match=(not FLAGS.wo_maxpool_match), 
                    with_attentive_match=(not FLAGS.wo_attentive_match), with_max_attentive_match=(not FLAGS.wo_max_attentive_match), 
                    with_dep=FLAGS.with_dep, with_image=FLAGS.with_image, image_with_hypothesis_only=FLAGS.image_with_hypothesis_only,
                    with_img_full_match=FLAGS.with_img_full_match, with_img_maxpool_match=FLAGS.with_img_full_match, 
                    with_img_attentive_match=FLAGS.with_img_attentive_match, image_context_layer=FLAGS.image_context_layer, 
                    with_img_max_attentive_match=FLAGS.with_img_max_attentive_match, img_dim=FLAGS.img_dim)
                tf.summary.scalar("Training Loss", train_graph.get_loss()) # Add a scalar summary for the snapshot loss.
        
            with tf.variable_scope("Model", reuse=True, initializer=initializer):
                valid_graph = ModelGraph(num_classes, word_vocab=word_vocab, char_vocab=char_vocab, POS_vocab=POS_vocab, NER_vocab=NER_vocab, 
                    dropout_rate=FLAGS.dropout_rate, learning_rate=FLAGS.learning_rate, optimize_type=FLAGS.optimize_type,
                    lambda_l2=FLAGS.lambda_l2, char_lstm_dim=FLAGS.char_lstm_dim, context_lstm_dim=FLAGS.context_lstm_dim, 
                    aggregation_lstm_dim=FLAGS.aggregation_lstm_dim, is_training=False, MP_dim=FLAGS.MP_dim, 
                    context_layer_num=FLAGS.context_layer_num, aggregation_layer_num=FLAGS.aggregation_layer_num, 
                    fix_word_vec=FLAGS.fix_word_vec,with_filter_layer=FLAGS.with_filter_layer, with_highway=FLAGS.with_highway,
                    word_level_MP_dim=FLAGS.word_level_MP_dim,
                    with_match_highway=FLAGS.with_match_highway, with_aggregation_highway=FLAGS.with_aggregation_highway,
                    highway_layer_num=FLAGS.highway_layer_num, with_lex_decomposition=FLAGS.with_lex_decomposition, 
                    lex_decompsition_dim=FLAGS.lex_decompsition_dim,
                    with_left_match=(not FLAGS.wo_left_match), with_right_match=(not FLAGS.wo_right_match),
                    with_full_match=(not FLAGS.wo_full_match), with_maxpool_match=(not FLAGS.wo_maxpool_match), 
                    with_attentive_match=(not FLAGS.wo_attentive_match), with_max_attentive_match=(not FLAGS.wo_max_attentive_match), 
                    with_dep=FLAGS.with_dep, with_image=FLAGS.with_image, image_with_hypothesis_only=FLAGS.image_with_hypothesis_only,
                    with_img_attentive_match=FLAGS.with_img_attentive_match, with_img_full_match=FLAGS.with_img_full_match, 
                    with_img_maxpool_match=FLAGS.with_img_full_match, image_context_layer=FLAGS.image_context_layer, 
                    with_img_max_attentive_match=FLAGS.with_img_max_attentive_match, img_dim=FLAGS.img_dim)

                
            initializer = tf.global_variables_initializer()
            vars_ = {}
            for var in tf.global_variables():
                if "word_embedding" in var.name: continue
                vars_[var.name.split(":")[0]] = var
            saver = tf.train.Saver(vars_)
         
            sess = tf.Session()
            sess.run(initializer) #, feed_dict={valid_graph.emb_init: word_vocab.word_vecs, train_graph.emb_init: word_vocab.word_vecs})
            if has_pre_trained_model:
                print("Restoring model from " + best_path)
                saver.restore(sess, best_path)
                print("DONE!")
                #if best_path.startswith('bimpm_baseline'):
                #best_path = best_path + '_sick' 

            print('Start the training loop.')
            train_size = trainDataStream.get_num_batch()
            max_steps = train_size * FLAGS.max_epochs
            total_loss = 0.0
            start_time = time.time()
            for step in xrange(max_steps):
                # read data
                cur_batch = trainDataStream.nextBatch()
                (label_batch, sent1_batch, sent2_batch, label_id_batch, word_idx_1_batch, word_idx_2_batch, 
                                 char_matrix_idx_1_batch, char_matrix_idx_2_batch, sent1_length_batch, sent2_length_batch, 
                                 sent1_char_length_batch, sent2_char_length_batch,
                                 POS_idx_1_batch, POS_idx_2_batch, NER_idx_1_batch, NER_idx_2_batch, 
                                 dependency1_batch, dependency2_batch, dep_con1_batch, dep_con2_batch, img_feats_batch, img_id_batch) = cur_batch
                feed_dict = {
                         train_graph.get_truth(): label_id_batch, 
                         train_graph.get_question_lengths(): sent1_length_batch, 
                         train_graph.get_passage_lengths(): sent2_length_batch, 
                         train_graph.get_in_question_words(): word_idx_1_batch, 
                         train_graph.get_in_passage_words(): word_idx_2_batch,
                         #train_graph.get_emb_init(): word_vocab.word_vecs,
                         #train_graph.get_in_question_dependency(): dependency1_batch,
                         #train_graph.get_in_passage_dependency(): dependency2_batch,
#                          train_graph.get_question_char_lengths(): sent1_char_length_batch, 
#                          train_graph.get_passage_char_lengths(): sent2_char_length_batch, 
#                          train_graph.get_in_question_chars(): char_matrix_idx_1_batch, 
#                          train_graph.get_in_passage_chars(): char_matrix_idx_2_batch, 
                         }
                if FLAGS.with_dep:
                    feed_dict[train_graph.get_in_question_dependency()] = dependency1_batch
                    feed_dict[train_graph.get_in_passage_dependency()] = dependency2_batch
                    feed_dict[train_graph.get_in_question_dep_con()] = dep_con1_batch
                    feed_dict[train_graph.get_in_passage_dep_con()] = dep_con2_batch

                if FLAGS.with_image:
                    feed_dict[train_graph.get_image_feats()] = img_feats_batch

                if char_vocab is not None:
                    feed_dict[train_graph.get_question_char_lengths()] = sent1_char_length_batch
                    feed_dict[train_graph.get_passage_char_lengths()] = sent2_char_length_batch
                    feed_dict[train_graph.get_in_question_chars()] = char_matrix_idx_1_batch
                    feed_dict[train_graph.get_in_passage_chars()] = char_matrix_idx_2_batch

                if POS_vocab is not None:
                    feed_dict[train_graph.get_in_question_poss()] = POS_idx_1_batch
                    feed_dict[train_graph.get_in_passage_poss()] = POS_idx_2_batch

                if NER_vocab is not None:
                    feed_dict[train_graph.get_in_question_ners()] = NER_idx_1_batch
                    feed_dict[train_graph.get_in_passage_ners()] = NER_idx_2_batch

                _, loss_value = sess.run([train_graph.get_train_op(), train_graph.get_loss()], feed_dict=feed_dict)
                total_loss += loss_value
            
                if step % 100==0: 
                    print('{} '.format(step), end="")
                    sys.stdout.flush()

                # Save a checkpoint and evaluate the model periodically.
                if (step + 1) % trainDataStream.get_num_batch() == 0 or (step + 1) == max_steps:
                    print()
                    # Print status to stdout.
                    duration = time.time() - start_time
                    start_time = time.time()
                    print('Step %d: loss = %.2f (%.3f sec)' % (step, total_loss, duration))
                    total_loss = 0.0

                    # Evaluate against the validation set.
                    print('Validation Data Eval:')
                    accuracy = evaluate(devDataStream, valid_graph, sess,char_vocab=char_vocab, POS_vocab=POS_vocab, NER_vocab=NER_vocab, word_vocab=word_vocab)
                    print("Current accuracy on dev is %.2f" % accuracy)
                
                    #accuracy_train = evaluate(trainDataStream, valid_graph, sess,char_vocab=char_vocab, POS_vocab=POS_vocab, NER_vocab=NER_vocab)
                    #print("Current accuracy on train is %.2f" % accuracy_train)
                    if accuracy>best_accuracy:
                        best_accuracy = accuracy
                        saver.save(sess, best_path)
        print("Best accuracy on dev set is %.2f" % best_accuracy)
    
    # decoding
    print('Decoding on the test set:')
    init_scale = 0.01
    with tf.Graph().as_default():
        initializer = tf.random_uniform_initializer(-init_scale, init_scale)
        with tf.variable_scope("Model", reuse=False, initializer=initializer):
            valid_graph = ModelGraph(num_classes, word_vocab=word_vocab, char_vocab=char_vocab, POS_vocab=POS_vocab, NER_vocab=NER_vocab, 
                 dropout_rate=FLAGS.dropout_rate, learning_rate=FLAGS.learning_rate, optimize_type=FLAGS.optimize_type,
                 lambda_l2=FLAGS.lambda_l2, char_lstm_dim=FLAGS.char_lstm_dim, context_lstm_dim=FLAGS.context_lstm_dim, 
                 aggregation_lstm_dim=FLAGS.aggregation_lstm_dim, is_training=False, MP_dim=FLAGS.MP_dim, 
                 context_layer_num=FLAGS.context_layer_num, aggregation_layer_num=FLAGS.aggregation_layer_num, 
                 fix_word_vec=FLAGS.fix_word_vec,with_filter_layer=FLAGS.with_filter_layer, with_highway=FLAGS.with_highway,
                 word_level_MP_dim=FLAGS.word_level_MP_dim,
                 with_match_highway=FLAGS.with_match_highway, with_aggregation_highway=FLAGS.with_aggregation_highway,
                 highway_layer_num=FLAGS.highway_layer_num, with_lex_decomposition=FLAGS.with_lex_decomposition, 
                 lex_decompsition_dim=FLAGS.lex_decompsition_dim,
                 with_left_match=(not FLAGS.wo_left_match), with_right_match=(not FLAGS.wo_right_match),
                 with_full_match=(not FLAGS.wo_full_match), with_maxpool_match=(not FLAGS.wo_maxpool_match), 
                 with_attentive_match=(not FLAGS.wo_attentive_match), with_max_attentive_match=(not FLAGS.wo_max_attentive_match), 
                 with_dep=FLAGS.with_dep, with_image=FLAGS.with_image, image_with_hypothesis_only=FLAGS.image_with_hypothesis_only,
                 with_img_attentive_match=FLAGS.with_img_attentive_match, with_img_full_match=FLAGS.with_img_full_match, 
                 with_img_maxpool_match=FLAGS.with_img_full_match, image_context_layer=FLAGS.image_context_layer, 
                 with_img_max_attentive_match=FLAGS.with_img_max_attentive_match, img_dim=FLAGS.img_dim)
        vars_ = {}
        for var in tf.global_variables():
            if "word_embedding" in var.name: continue
            if not var.name.startswith("Model"): continue
            vars_[var.name.split(":")[0]] = var
        saver = tf.train.Saver(vars_)
                
        sess = tf.Session()
        sess.run(tf.global_variables_initializer())#, feed_dict={valid_graph.emb_init: word_vocab.word_vecs})
        step = 0
        saver.restore(sess, best_path)

        accuracy = evaluate(testDataStream, valid_graph, sess, outpath=FLAGS.suffix+ FLAGS.train + FLAGS.test + ".result",char_vocab=char_vocab,label_vocab=label_vocab, word_vocab=word_vocab)
        print("Accuracy for test set is %.2f" % accuracy)
        accuracy_train = evaluate(trainDataStream, valid_graph, sess,char_vocab=char_vocab,word_vocab=word_vocab)
        print("Accuracy for train set is %.2f" % accuracy_train)
    args, unparsed = parser.parse_known_args()

    # load the configuration file
    print('Loading configurations.')
    options = namespace_utils.load_namespace(args.model_prefix +
                                             ".config.json")

    if args.word_vec_path is None: args.word_vec_path = options.word_vec_path

    # load vocabs
    print('Loading vocabs.')
    word_vocab = Vocab(args.word_vec_path, fileformat='txt3')
    label_vocab = Vocab(args.model_prefix + ".label_vocab", fileformat='txt2')
    print('word_vocab: {}'.format(word_vocab.word_vecs.shape))
    print('label_vocab: {}'.format(label_vocab.word_vecs.shape))
    num_classes = label_vocab.size()

    char_vocab = None
    if options.with_char:
        char_vocab = Vocab(args.model_prefix + ".char_vocab",
                           fileformat='txt2')
        print('char_vocab: {}'.format(char_vocab.word_vecs.shape))

    print('Build SentenceMatchDataStream ... ')
    testDataStream = SentenceMatchDataStream(args.in_path,
                                             word_vocab=word_vocab,
                                             char_vocab=char_vocab,
                                             label_vocab=label_vocab,
                                             isShuffle=False,
                                             isLoop=True,
                                             isSort=True,
                                             options=options)
Example #4
def get_test_result(in_p,root_path):
    print('Loading configurations.')
    model_prefix = root_path + "/stsapp/src/logs/SentenceMatch.snli"
    word_vec_path = root_path + "/stsapp/src/data/snli/wordvec.txt"

    in_path = in_p
    out_path = root_path + "/stsapp/src/result.txt"

    print("access decoder")

    options = namespace_utils.load_namespace(model_prefix + ".config.json")

    if word_vec_path is None: word_vec_path = options.word_vec_path


    # load vocabs
    print('Loading vocabs.')
    word_vocab = Vocab(word_vec_path, fileformat='txt3')
    label_vocab = Vocab(model_prefix + ".label_vocab", fileformat='txt2')
    print('word_vocab: {}'.format(word_vocab.word_vecs.shape))
    print('label_vocab: {}'.format(label_vocab.word_vecs.shape))
    num_classes = label_vocab.size()

    char_vocab = None
    if options.with_char:
        char_vocab = Vocab(model_prefix + ".char_vocab", fileformat='txt2')
        print('char_vocab: {}'.format(char_vocab.word_vecs.shape))
    
    print('Build SentenceMatchDataStream ... ')
    testDataStream = SentenceMatchDataStream(in_path, word_vocab=word_vocab, char_vocab=char_vocab,
                                            label_vocab=label_vocab,
                                            isShuffle=False, isLoop=True, isSort=True, options=options)
    print('Number of instances in testDataStream: {}'.format(testDataStream.get_num_instance()))
    print('Number of batches in testDataStream: {}'.format(testDataStream.get_num_batch()))
    sys.stdout.flush()

    best_path = model_prefix + ".best.model"
    init_scale = 0.01
    with tf.Graph().as_default():
        initializer = tf.random_uniform_initializer(-init_scale, init_scale)
        global_step = tf.train.get_or_create_global_step()
        with tf.variable_scope("Model", reuse=False, initializer=initializer):
            valid_graph = SentenceMatchModelGraph(num_classes, word_vocab=word_vocab, char_vocab=char_vocab,
                                                  is_training=False, options=options)

        initializer = tf.global_variables_initializer()
        vars_ = {}
        for var in tf.global_variables():
            if "word_embedding" in var.name: continue
            if not var.name.startswith("Model"): continue
            vars_[var.name.split(":")[0]] = var
        saver = tf.train.Saver(vars_)

        sess = tf.Session()
        sess.run(initializer)
        print("Restoring model from " + best_path)
        saver.restore(sess, best_path)
        print("DONE!")
        acc,result = train.evaluation(sess, valid_graph, testDataStream, outpath=out_path,
                                              label_vocab=label_vocab)

        print("Accuracy for test set is : ",colored(acc, 'green'),"\n")

        # print(result['probs'])

        return acc,result
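
A hedged usage sketch for get_test_result (not taken from the original source): the input file is expected in the same sentence-pair format that SentenceMatchDataStream reads, and both paths below are placeholders.

# Hypothetical call: in_p points at a file of sentence pairs, root_path at the project root
# that contains the hard-coded stsapp/src/logs and stsapp/src/data directories used above.
acc, result = get_test_result('/path/to/test_pairs.tsv', '/path/to/project_root')
print('test accuracy:', acc)
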
Example #5
def main(_):


    #for x in range (100):
    #    Generate_random_initialization()
    #    print (FLAGS.is_aggregation_lstm, FLAGS.context_lstm_dim, FLAGS.context_layer_num, FLAGS. aggregation_lstm_dim, FLAGS.aggregation_layer_num, FLAGS.max_window_size, FLAGS.MP_dim)

    print('Configurations:')
    #print(FLAGS)


    train_path = FLAGS.train_path
    dev_path = FLAGS.dev_path
    test_path = FLAGS.test_path
    word_vec_path = FLAGS.word_vec_path
    log_dir = FLAGS.model_dir
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    
    path_prefix = log_dir + "/SentenceMatch.{}".format(FLAGS.suffix)

    namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json")

    # build vocabs
    word_vocab = Vocab(word_vec_path, fileformat='txt3')
    best_path = path_prefix + '.best.model'
    char_path = path_prefix + ".char_vocab"
    label_path = path_prefix + ".label_vocab"
    POS_path = path_prefix + ".POS_vocab"
    NER_path = path_prefix + ".NER_vocab"
    has_pre_trained_model = False
    POS_vocab = None
    NER_vocab = None
    if os.path.exists(best_path):
        has_pre_trained_model = True
        label_vocab = Vocab(label_path, fileformat='txt2')
        char_vocab = Vocab(char_path, fileformat='txt2')
        if FLAGS.with_POS: POS_vocab = Vocab(POS_path, fileformat='txt2')
        if FLAGS.with_NER: NER_vocab = Vocab(NER_path, fileformat='txt2')
    else:
        print('Collect words, chars and labels ...')
        (all_words, all_chars, all_labels, all_POSs, all_NERs) = collect_vocabs(train_path, with_POS=FLAGS.with_POS, with_NER=FLAGS.with_NER)
        print('Number of words: {}'.format(len(all_words)))
        print('Number of labels: {}'.format(len(all_labels)))
        label_vocab = Vocab(fileformat='voc', voc=all_labels,dim=2)
        label_vocab.dump_to_txt2(label_path)

        print('Number of chars: {}'.format(len(all_chars)))
        char_vocab = Vocab(fileformat='voc', voc=all_chars,dim=FLAGS.char_emb_dim)
        char_vocab.dump_to_txt2(char_path)
        
        if FLAGS.with_POS:
            print('Number of POSs: {}'.format(len(all_POSs)))
            POS_vocab = Vocab(fileformat='voc', voc=all_POSs,dim=FLAGS.POS_dim)
            POS_vocab.dump_to_txt2(POS_path)
        if FLAGS.with_NER:
            print('Number of NERs: {}'.format(len(all_NERs)))
            NER_vocab = Vocab(fileformat='voc', voc=all_NERs,dim=FLAGS.NER_dim)
            NER_vocab.dump_to_txt2(NER_path)
            

    print('word_vocab shape is {}'.format(word_vocab.word_vecs.shape))
    print('tag_vocab shape is {}'.format(label_vocab.word_vecs.shape))
    num_classes = label_vocab.size()

    print('Build SentenceMatchDataStream ... ')
    trainDataStream = SentenceMatchDataStream(train_path, word_vocab=word_vocab, char_vocab=char_vocab, 
                                              POS_vocab=POS_vocab, NER_vocab=NER_vocab, label_vocab=label_vocab, 
                                              batch_size=FLAGS.batch_size, isShuffle=True, isLoop=True, isSort=True, 
                                              max_char_per_word=FLAGS.max_char_per_word, max_sent_length=FLAGS.max_sent_length,
                                              is_as=FLAGS.is_answer_selection)
                                    
    devDataStream = SentenceMatchDataStream(dev_path, word_vocab=word_vocab, char_vocab=char_vocab,
                                              POS_vocab=POS_vocab, NER_vocab=NER_vocab, label_vocab=label_vocab, 
                                              batch_size=FLAGS.batch_size, isShuffle=False, isLoop=True, isSort=True,
                                              max_char_per_word=FLAGS.max_char_per_word, max_sent_length=FLAGS.max_sent_length,
                                              is_as=FLAGS.is_answer_selection)

    testDataStream = SentenceMatchDataStream(test_path, word_vocab=word_vocab, char_vocab=char_vocab, 
                                              POS_vocab=POS_vocab, NER_vocab=NER_vocab, label_vocab=label_vocab, 
                                              batch_size=FLAGS.batch_size, isShuffle=False, isLoop=True, isSort=True,
                                              max_char_per_word=FLAGS.max_char_per_word, max_sent_length=FLAGS.max_sent_length,
                                              is_as=FLAGS.is_answer_selection)

    print('Number of instances in trainDataStream: {}'.format(trainDataStream.get_num_instance()))
    print('Number of instances in devDataStream: {}'.format(devDataStream.get_num_instance()))
    print('Number of instances in testDataStream: {}'.format(testDataStream.get_num_instance()))
    print('Number of batches in trainDataStream: {}'.format(trainDataStream.get_num_batch()))
    print('Number of batches in devDataStream: {}'.format(devDataStream.get_num_batch()))
    print('Number of batches in testDataStream: {}'.format(testDataStream.get_num_batch()))
    
    sys.stdout.flush()
    if FLAGS.wo_char: char_vocab = None
    output_res_index = 1
    while True:
        Generate_random_initialization()
        st_cuda = ''
        if FLAGS.is_server == True:
            st_cuda = str(os.environ['CUDA_VISIBLE_DEVICES']) + '.'
        output_res_file = open('../result/' + st_cuda + str(output_res_index), 'wt')
        output_res_index += 1
        output_res_file.write(str(FLAGS) + '\n\n')
        stt = str (FLAGS)
        best_accuracy = 0.0
        init_scale = 0.01
        with tf.Graph().as_default():
            initializer = tf.random_uniform_initializer(-init_scale, init_scale)
    #         with tf.name_scope("Train"):
            with tf.variable_scope("Model", reuse=None, initializer=initializer):
                train_graph = SentenceMatchModelGraph(num_classes, word_vocab=word_vocab, char_vocab=char_vocab, POS_vocab=POS_vocab, NER_vocab=NER_vocab,
                                                      dropout_rate=FLAGS.dropout_rate, learning_rate=FLAGS.learning_rate, optimize_type=FLAGS.optimize_type,
                                                      lambda_l2=FLAGS.lambda_l2, char_lstm_dim=FLAGS.char_lstm_dim, context_lstm_dim=FLAGS.context_lstm_dim,
                                                      aggregation_lstm_dim=FLAGS.aggregation_lstm_dim, is_training=True, MP_dim=FLAGS.MP_dim,
                                                      context_layer_num=FLAGS.context_layer_num, aggregation_layer_num=FLAGS.aggregation_layer_num,
                                                      fix_word_vec=FLAGS.fix_word_vec, with_filter_layer=FLAGS.with_filter_layer, with_input_highway=FLAGS.with_highway,
                                                      word_level_MP_dim=FLAGS.word_level_MP_dim,
                                                      with_match_highway=FLAGS.with_match_highway, with_aggregation_highway=FLAGS.with_aggregation_highway,
                                                      highway_layer_num=FLAGS.highway_layer_num, with_lex_decomposition=FLAGS.with_lex_decomposition,
                                                      lex_decompsition_dim=FLAGS.lex_decompsition_dim,
                                                      with_left_match=(not FLAGS.wo_left_match), with_right_match=(not FLAGS.wo_right_match),
                                                      with_full_match=(not FLAGS.wo_full_match), with_maxpool_match=(not FLAGS.wo_maxpool_match),
                                                      with_attentive_match=(not FLAGS.wo_attentive_match), with_max_attentive_match=(not FLAGS.wo_max_attentive_match),
                                                      with_bilinear_att=(FLAGS.attention_type)
                                                      , type1=FLAGS.type1, type2 = FLAGS.type2, type3=FLAGS.type3,
                                                      with_aggregation_attention=not FLAGS.wo_agg_self_att,
                                                      is_answer_selection= FLAGS.is_answer_selection,
                                                      is_shared_attention=FLAGS.is_shared_attention,
                                                      modify_loss=FLAGS.modify_loss, is_aggregation_lstm=FLAGS.is_aggregation_lstm
                                                      , max_window_size=FLAGS.max_window_size
                                                      , prediction_mode=FLAGS.prediction_mode,
                                                      context_lstm_dropout=not FLAGS.wo_lstm_drop_out,
                                                      is_aggregation_siamese=FLAGS.is_aggregation_siamese
                                                      , unstack_cnn=FLAGS.unstack_cnn,with_context_self_attention=FLAGS.with_context_self_attention)
                tf.summary.scalar("Training Loss", train_graph.get_loss()) # Add a scalar summary for the snapshot loss.

    #         with tf.name_scope("Valid"):
            with tf.variable_scope("Model", reuse=True, initializer=initializer):
                valid_graph = SentenceMatchModelGraph(num_classes, word_vocab=word_vocab, char_vocab=char_vocab, POS_vocab=POS_vocab, NER_vocab=NER_vocab,
                                                      dropout_rate=FLAGS.dropout_rate, learning_rate=FLAGS.learning_rate, optimize_type=FLAGS.optimize_type,
                                                      lambda_l2=FLAGS.lambda_l2, char_lstm_dim=FLAGS.char_lstm_dim, context_lstm_dim=FLAGS.context_lstm_dim,
                                                      aggregation_lstm_dim=FLAGS.aggregation_lstm_dim, is_training=False, MP_dim=FLAGS.MP_dim,
                                                      context_layer_num=FLAGS.context_layer_num, aggregation_layer_num=FLAGS.aggregation_layer_num,
                                                      fix_word_vec=FLAGS.fix_word_vec, with_filter_layer=FLAGS.with_filter_layer, with_input_highway=FLAGS.with_highway,
                                                      word_level_MP_dim=FLAGS.word_level_MP_dim,
                                                      with_match_highway=FLAGS.with_match_highway, with_aggregation_highway=FLAGS.with_aggregation_highway,
                                                      highway_layer_num=FLAGS.highway_layer_num, with_lex_decomposition=FLAGS.with_lex_decomposition,
                                                      lex_decompsition_dim=FLAGS.lex_decompsition_dim,
                                                      with_left_match=(not FLAGS.wo_left_match), with_right_match=(not FLAGS.wo_right_match),
                                                      with_full_match=(not FLAGS.wo_full_match), with_maxpool_match=(not FLAGS.wo_maxpool_match),
                                                      with_attentive_match=(not FLAGS.wo_attentive_match), with_max_attentive_match=(not FLAGS.wo_max_attentive_match),
                                                      with_bilinear_att=(FLAGS.attention_type)
                                                      , type1=FLAGS.type1, type2 = FLAGS.type2, type3=FLAGS.type3,
                                                      with_aggregation_attention=not FLAGS.wo_agg_self_att,
                                                      is_answer_selection= FLAGS.is_answer_selection,
                                                      is_shared_attention=FLAGS.is_shared_attention,
                                                      modify_loss=FLAGS.modify_loss, is_aggregation_lstm=FLAGS.is_aggregation_lstm,
                                                      max_window_size=FLAGS.max_window_size
                                                      , prediction_mode=FLAGS.prediction_mode,
                                                      context_lstm_dropout=not FLAGS.wo_lstm_drop_out,
                                                      is_aggregation_siamese=FLAGS.is_aggregation_siamese
                                                      , unstack_cnn=FLAGS.unstack_cnn,with_context_self_attention=FLAGS.with_context_self_attention)


            initializer = tf.global_variables_initializer()
            vars_ = {}
            #for var in tf.all_variables():
            for var in tf.global_variables():
                if "word_embedding" in var.name: continue
    #             if not var.name.startswith("Model"): continue
                vars_[var.name.split(":")[0]] = var
            saver = tf.train.Saver(vars_)

            with tf.Session() as sess:
                sess.run(initializer)
                if has_pre_trained_model:
                    print("Restoring model from " + best_path)
                    saver.restore(sess, best_path)
                    print("DONE!")

                print('Start the training loop.')
                train_size = trainDataStream.get_num_batch()
                max_steps = train_size * FLAGS.max_epochs
                total_loss = 0.0
                start_time = time.time()

                for step in xrange(max_steps):
                    # read data
                    cur_batch, batch_index = trainDataStream.nextBatch()
                    (label_batch, sent1_batch, sent2_batch, label_id_batch, word_idx_1_batch, word_idx_2_batch,
                                         char_matrix_idx_1_batch, char_matrix_idx_2_batch, sent1_length_batch, sent2_length_batch,
                                         sent1_char_length_batch, sent2_char_length_batch,
                                         POS_idx_1_batch, POS_idx_2_batch, NER_idx_1_batch, NER_idx_2_batch) = cur_batch
                    feed_dict = {
                                 train_graph.get_truth(): label_id_batch,
                                 train_graph.get_question_lengths(): sent1_length_batch,
                                 train_graph.get_passage_lengths(): sent2_length_batch,
                                 train_graph.get_in_question_words(): word_idx_1_batch,
                                 train_graph.get_in_passage_words(): word_idx_2_batch,
        #                          train_graph.get_question_char_lengths(): sent1_char_length_batch,
        #                          train_graph.get_passage_char_lengths(): sent2_char_length_batch,
        #                          train_graph.get_in_question_chars(): char_matrix_idx_1_batch,
        #                          train_graph.get_in_passage_chars(): char_matrix_idx_2_batch,
                                 }
                    if char_vocab is not None:
                        feed_dict[train_graph.get_question_char_lengths()] = sent1_char_length_batch
                        feed_dict[train_graph.get_passage_char_lengths()] = sent2_char_length_batch
                        feed_dict[train_graph.get_in_question_chars()] = char_matrix_idx_1_batch
                        feed_dict[train_graph.get_in_passage_chars()] = char_matrix_idx_2_batch

                    if POS_vocab is not None:
                        feed_dict[train_graph.get_in_question_poss()] = POS_idx_1_batch
                        feed_dict[train_graph.get_in_passage_poss()] = POS_idx_2_batch

                    if NER_vocab is not None:
                        feed_dict[train_graph.get_in_question_ners()] = NER_idx_1_batch
                        feed_dict[train_graph.get_in_passage_ners()] = NER_idx_2_batch

                    if FLAGS.is_answer_selection == True:
                        feed_dict[train_graph.get_question_count()] = trainDataStream.question_count(batch_index)
                        feed_dict[train_graph.get_answer_count()] = trainDataStream.answer_count(batch_index)

                    _, loss_value = sess.run([train_graph.get_train_op(), train_graph.get_loss()], feed_dict=feed_dict)
                    total_loss += loss_value
                    if FLAGS.is_answer_selection == True and FLAGS.is_server == False:
                        print ("q: {} a: {} loss_value: {}".format(trainDataStream.question_count(batch_index)
                                                   ,trainDataStream.answer_count(batch_index), loss_value))

                    if step % 100==0:
                        print('{} '.format(step), end="")
                        sys.stdout.flush()

                    # Save a checkpoint and evaluate the model periodically.
                    if (step + 1) % trainDataStream.get_num_batch() == 0 or (step + 1) == max_steps:
                        #print(total_loss)
                        # Print status to stdout.
                        duration = time.time() - start_time
                        start_time = time.time()
                        output_res_file.write('Step %d: loss = %.2f (%.3f sec)\n' % (step, total_loss, duration))
                        total_loss = 0.0


                        #Evaluate against the validation set.
                        output_res_file.write('valid- ')
                        my_map, my_mrr = evaluate(devDataStream, valid_graph, sess,char_vocab=char_vocab,
                                            POS_vocab=POS_vocab, NER_vocab=NER_vocab, label_vocab=label_vocab)
                        output_res_file.write("map: '{}', mrr: '{}'\n".format(my_map, my_mrr))
                        #print ("dev map: {}".format(my_map))
                        #print("Current accuracy is %.2f" % accuracy)

                        #accuracy = my_map
                        #if accuracy>best_accuracy:
                        #    best_accuracy = accuracy
                        #    saver.save(sess, best_path)

                        # Evaluate against the test set.
                        output_res_file.write ('test- ')
                        my_map, my_mrr = evaluate(testDataStream, valid_graph, sess, char_vocab=char_vocab,
                                 POS_vocab=POS_vocab, NER_vocab=NER_vocab, label_vocab=label_vocab)
                        output_res_file.write("map: '{}', mrr: '{}\n\n".format(my_map, my_mrr))
                        if FLAGS.is_server == False:
                            print ("test map: {}".format(my_map))

                        #Evaluate against the train set only for final epoch.
                        if (step + 1) == max_steps:
                            output_res_file.write ('train- ')
                            my_map, my_mrr = evaluate(trainDataStream, valid_graph, sess, char_vocab=char_vocab,
                                POS_vocab=POS_vocab, NER_vocab=NER_vocab, label_vocab=label_vocab)
                            output_res_file.write("map: '{}', mrr: '{}'\n".format(my_map, my_mrr))

        # print("Best accuracy on dev set is %.2f" % best_accuracy)
        # # decoding
        # print('Decoding on the test set:')
        # init_scale = 0.01
        # with tf.Graph().as_default():
        #     initializer = tf.random_uniform_initializer(-init_scale, init_scale)
        #     with tf.variable_scope("Model", reuse=False, initializer=initializer):
        #         valid_graph = SentenceMatchModelGraph(num_classes, word_vocab=word_vocab, char_vocab=char_vocab, POS_vocab=POS_vocab, NER_vocab=NER_vocab,
        #              dropout_rate=FLAGS.dropout_rate, learning_rate=FLAGS.learning_rate, optimize_type=FLAGS.optimize_type,
        #              lambda_l2=FLAGS.lambda_l2, char_lstm_dim=FLAGS.char_lstm_dim, context_lstm_dim=FLAGS.context_lstm_dim,
        #              aggregation_lstm_dim=FLAGS.aggregation_lstm_dim, is_training=False, MP_dim=FLAGS.MP_dim,
        #              context_layer_num=FLAGS.context_layer_num, aggregation_layer_num=FLAGS.aggregation_layer_num,
        #              fix_word_vec=FLAGS.fix_word_vec,with_filter_layer=FLAGS.with_filter_layer, with_highway=FLAGS.with_highway,
        #              word_level_MP_dim=FLAGS.word_level_MP_dim,
        #              with_match_highway=FLAGS.with_match_highway, with_aggregation_highway=FLAGS.with_aggregation_highway,
        #              highway_layer_num=FLAGS.highway_layer_num, with_lex_decomposition=FLAGS.with_lex_decomposition,
        #              lex_decompsition_dim=FLAGS.lex_decompsition_dim,
        #              with_left_match=(not FLAGS.wo_left_match), with_right_match=(not FLAGS.wo_right_match),
        #              with_full_match=(not FLAGS.wo_full_match), with_maxpool_match=(not FLAGS.wo_maxpool_match),
        #              with_attentive_match=(not FLAGS.wo_attentive_match), with_max_attentive_match=(not FLAGS.wo_max_attentive_match),
        #                                               with_bilinear_att=(not FLAGS.wo_bilinear_att)
        #                                               , type1=FLAGS.type1, type2 = FLAGS.type2, type3=FLAGS.type3,
        #                                               with_aggregation_attention=not FLAGS.wo_agg_self_att,
        #                                               is_answer_selection= FLAGS.is_answer_selection,
        #                                               is_shared_attention=FLAGS.is_shared_attention,
        #                                               modify_loss=FLAGS.modify_loss,is_aggregation_lstm=FLAGS.is_aggregation_lstm,
        #                                               max_window_size=FLAGS.max_window_size,
        #                                               prediction_mode=FLAGS.prediction_mode,
        #                                               context_lstm_dropout=not FLAGS.wo_lstm_drop_out,
        #                                              is_aggregation_siamese=FLAGS.is_aggregation_siamese)
        #
        #     vars_ = {}
        #     for var in tf.global_variables():
        #         if "word_embedding" in var.name: continue
        #         if not var.name.startswith("Model"): continue
        #         vars_[var.name.split(":")[0]] = var
        #     saver = tf.train.Saver(vars_)
        #
        #     sess = tf.Session()
        #     sess.run(tf.global_variables_initializer())
        #     step = 0
        #     saver.restore(sess, best_path)
        #
        #     accuracy, mrr = evaluate(testDataStream, valid_graph, sess,char_vocab=char_vocab,POS_vocab=POS_vocab, NER_vocab=NER_vocab, label_vocab=label_vocab
        #                         , mode='trec')
        #     output_res_file.write("map for test set is %.2f\n" % accuracy)
        output_res_file.close()
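
Every example above rebuilds essentially the same tf.train.Saver: a dictionary over all global variables except the large word-embedding variable, which is re-fed from the vocab at initialization time instead of being checkpointed. A small helper capturing that pattern is sketched here; it is an illustration, not a function from the original code.

def build_saver_without_embeddings(scope_prefix="Model"):
    """Return a Saver over all global variables except word embeddings."""
    vars_ = {}
    for var in tf.global_variables():
        if "word_embedding" in var.name:
            continue  # embeddings are re-initialized from the vocab, so they need not be saved
        if scope_prefix and not var.name.startswith(scope_prefix):
            continue  # keep only variables created under the model scope
        vars_[var.name.split(":")[0]] = var
    return tf.train.Saver(vars_)
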
Example #6
def main(_):
    print('Configurations:')
    print(FLAGS)

    train_path = FLAGS.train_path
    dev_path = FLAGS.dev_path
    test_path = FLAGS.test_path
    word_vec_path = FLAGS.word_vec_path
    log_dir = FLAGS.model_dir
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    
    path_prefix = log_dir + "/SentenceMatch.{}".format(FLAGS.suffix)

    namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json")

    # build vocabs
    word_vocab = Vocab(word_vec_path, fileformat='txt3')
    best_path = path_prefix + '.best.model'
    char_path = path_prefix + ".char_vocab"
    label_path = path_prefix + ".label_vocab"
    POS_path = path_prefix + ".POS_vocab"
    NER_path = path_prefix + ".NER_vocab"
    has_pre_trained_model = False
    POS_vocab = None
    NER_vocab = None
    if os.path.exists(best_path):
        has_pre_trained_model = True
        label_vocab = Vocab(label_path, fileformat='txt2')
        char_vocab = Vocab(char_path, fileformat='txt2')
        if FLAGS.with_POS: POS_vocab = Vocab(POS_path, fileformat='txt2')
        if FLAGS.with_NER: NER_vocab = Vocab(NER_path, fileformat='txt2')
    else:
        print('Collect words, chars and labels ...')
        (all_words, all_chars, all_labels, all_POSs, all_NERs) = collect_vocabs(train_path, with_POS=FLAGS.with_POS, with_NER=FLAGS.with_NER)
        print('Number of words: {}'.format(len(all_words)))
        print('Number of labels: {}'.format(len(all_labels)))
        label_vocab = Vocab(fileformat='voc', voc=all_labels,dim=2)
        label_vocab.dump_to_txt2(label_path)

        print('Number of chars: {}'.format(len(all_chars)))
        char_vocab = Vocab(fileformat='voc', voc=all_chars,dim=FLAGS.char_emb_dim)
        char_vocab.dump_to_txt2(char_path)
        
        if FLAGS.with_POS:
            print('Number of POSs: {}'.format(len(all_POSs)))
            POS_vocab = Vocab(fileformat='voc', voc=all_POSs,dim=FLAGS.POS_dim)
            POS_vocab.dump_to_txt2(POS_path)
        if FLAGS.with_NER:
            print('Number of NERs: {}'.format(len(all_NERs)))
            NER_vocab = Vocab(fileformat='voc', voc=all_NERs,dim=FLAGS.NER_dim)
            NER_vocab.dump_to_txt2(NER_path)
            

    print('word_vocab shape is {}'.format(word_vocab.word_vecs.shape))
    print('tag_vocab shape is {}'.format(label_vocab.word_vecs.shape))
    num_classes = label_vocab.size()

    print('Build SentenceMatchDataStream ... ')
    trainDataStream = SentenceMatchDataStream(train_path, word_vocab=word_vocab, char_vocab=char_vocab, 
                                              POS_vocab=POS_vocab, NER_vocab=NER_vocab, label_vocab=label_vocab, 
                                              batch_size=FLAGS.batch_size, isShuffle=True, isLoop=True, isSort=True, 
                                              max_char_per_word=FLAGS.max_char_per_word, max_sent_length=FLAGS.max_sent_length)
                                    
    devDataStream = SentenceMatchDataStream(dev_path, word_vocab=word_vocab, char_vocab=char_vocab,
                                              POS_vocab=POS_vocab, NER_vocab=NER_vocab, label_vocab=label_vocab, 
                                              batch_size=FLAGS.batch_size, isShuffle=False, isLoop=True, isSort=True, 
                                              max_char_per_word=FLAGS.max_char_per_word, max_sent_length=FLAGS.max_sent_length)

    testDataStream = SentenceMatchDataStream(test_path, word_vocab=word_vocab, char_vocab=char_vocab, 
                                              POS_vocab=POS_vocab, NER_vocab=NER_vocab, label_vocab=label_vocab, 
                                              batch_size=FLAGS.batch_size, isShuffle=False, isLoop=True, isSort=True, 
                                              max_char_per_word=FLAGS.max_char_per_word, max_sent_length=FLAGS.max_sent_length)

    print('Number of instances in trainDataStream: {}'.format(trainDataStream.get_num_instance()))
    print('Number of instances in devDataStream: {}'.format(devDataStream.get_num_instance()))
    print('Number of instances in testDataStream: {}'.format(testDataStream.get_num_instance()))
    print('Number of batches in trainDataStream: {}'.format(trainDataStream.get_num_batch()))
    print('Number of batches in devDataStream: {}'.format(devDataStream.get_num_batch()))
    print('Number of batches in testDataStream: {}'.format(testDataStream.get_num_batch()))
    
    sys.stdout.flush()
    if FLAGS.wo_char: char_vocab = None

    best_accuracy = 0.0
    init_scale = 0.01
    with tf.Graph().as_default():
        initializer = tf.random_uniform_initializer(-init_scale, init_scale)
#         with tf.name_scope("Train"):
        with tf.variable_scope("Model", reuse=None, initializer=initializer):
            train_graph = SentenceMatchModelGraph(num_classes, word_vocab=word_vocab, char_vocab=char_vocab,POS_vocab=POS_vocab, NER_vocab=NER_vocab, 
                 dropout_rate=FLAGS.dropout_rate, learning_rate=FLAGS.learning_rate, optimize_type=FLAGS.optimize_type,
                 lambda_l2=FLAGS.lambda_l2, char_lstm_dim=FLAGS.char_lstm_dim, context_lstm_dim=FLAGS.context_lstm_dim, 
                 aggregation_lstm_dim=FLAGS.aggregation_lstm_dim, is_training=True, MP_dim=FLAGS.MP_dim, 
                 context_layer_num=FLAGS.context_layer_num, aggregation_layer_num=FLAGS.aggregation_layer_num, 
                 fix_word_vec=FLAGS.fix_word_vec,with_filter_layer=FLAGS.with_filter_layer, with_highway=FLAGS.with_highway,
                 word_level_MP_dim=FLAGS.word_level_MP_dim,
                 with_match_highway=FLAGS.with_match_highway, with_aggregation_highway=FLAGS.with_aggregation_highway,
                 highway_layer_num=FLAGS.highway_layer_num,with_lex_decomposition=FLAGS.with_lex_decomposition, 
                 lex_decompsition_dim=FLAGS.lex_decompsition_dim,
                 with_left_match=(not FLAGS.wo_left_match), with_right_match=(not FLAGS.wo_right_match),
                 with_full_match=(not FLAGS.wo_full_match), with_maxpool_match=(not FLAGS.wo_maxpool_match), 
                 with_attentive_match=(not FLAGS.wo_attentive_match), with_max_attentive_match=(not FLAGS.wo_max_attentive_match))
            tf.summary.scalar("Training Loss", train_graph.get_loss()) # Add a scalar summary for the snapshot loss.
        
#         with tf.name_scope("Valid"):
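        # The validation graph is built in the same "Model" variable scope with reuse=True, so it
        # shares all trainable weights with the training graph while running with is_training=False
        # (presumably disabling dropout for evaluation).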
        with tf.variable_scope("Model", reuse=True, initializer=initializer):
            valid_graph = SentenceMatchModelGraph(num_classes, word_vocab=word_vocab, char_vocab=char_vocab, POS_vocab=POS_vocab, NER_vocab=NER_vocab, 
                 dropout_rate=FLAGS.dropout_rate, learning_rate=FLAGS.learning_rate, optimize_type=FLAGS.optimize_type,
                 lambda_l2=FLAGS.lambda_l2, char_lstm_dim=FLAGS.char_lstm_dim, context_lstm_dim=FLAGS.context_lstm_dim, 
                 aggregation_lstm_dim=FLAGS.aggregation_lstm_dim, is_training=False, MP_dim=FLAGS.MP_dim, 
                 context_layer_num=FLAGS.context_layer_num, aggregation_layer_num=FLAGS.aggregation_layer_num, 
                 fix_word_vec=FLAGS.fix_word_vec,with_filter_layer=FLAGS.with_filter_layer, with_highway=FLAGS.with_highway,
                 word_level_MP_dim=FLAGS.word_level_MP_dim,
                 with_match_highway=FLAGS.with_match_highway, with_aggregation_highway=FLAGS.with_aggregation_highway,
                 highway_layer_num=FLAGS.highway_layer_num, with_lex_decomposition=FLAGS.with_lex_decomposition, 
                 lex_decompsition_dim=FLAGS.lex_decompsition_dim,
                 with_left_match=(not FLAGS.wo_left_match), with_right_match=(not FLAGS.wo_right_match),
                 with_full_match=(not FLAGS.wo_full_match), with_maxpool_match=(not FLAGS.wo_maxpool_match), 
                 with_attentive_match=(not FLAGS.wo_attentive_match), with_max_attentive_match=(not FLAGS.wo_max_attentive_match))

                
        initializer = tf.global_variables_initializer()
        vars_ = {}
        for var in tf.global_variables():
            if "word_embedding" in var.name: continue
#             if not var.name.startswith("Model"): continue
            vars_[var.name.split(":")[0]] = var
        saver = tf.train.Saver(vars_)
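        # Variables whose names contain "word_embedding" are excluded from the saver, presumably so
        # the large pre-trained embedding matrix is not written into every checkpoint (it is rebuilt
        # from the word-vector file whenever the model is loaded).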
         
        sess = tf.Session()
        sess.run(initializer)
        if has_pre_trained_model:
            print("Restoring model from " + best_path)
            saver.restore(sess, best_path)
            print("DONE!")

        print('Start the training loop.')
        train_size = trainDataStream.get_num_batch()
        max_steps = train_size * FLAGS.max_epochs
        total_loss = 0.0
        start_time = time.time()
        for step in range(max_steps):
            # read data
            cur_batch = trainDataStream.nextBatch()
            (label_batch, sent1_batch, sent2_batch, label_id_batch, word_idx_1_batch, word_idx_2_batch, 
                                 char_matrix_idx_1_batch, char_matrix_idx_2_batch, sent1_length_batch, sent2_length_batch, 
                                 sent1_char_length_batch, sent2_char_length_batch,
                                 POS_idx_1_batch, POS_idx_2_batch, NER_idx_1_batch, NER_idx_2_batch) = cur_batch
            feed_dict = {
                         train_graph.get_truth(): label_id_batch, 
                         train_graph.get_question_lengths(): sent1_length_batch, 
                         train_graph.get_passage_lengths(): sent2_length_batch, 
                         train_graph.get_in_question_words(): word_idx_1_batch, 
                         train_graph.get_in_passage_words(): word_idx_2_batch, 
#                          train_graph.get_question_char_lengths(): sent1_char_length_batch, 
#                          train_graph.get_passage_char_lengths(): sent2_char_length_batch, 
#                          train_graph.get_in_question_chars(): char_matrix_idx_1_batch, 
#                          train_graph.get_in_passage_chars(): char_matrix_idx_2_batch, 
                         }
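            # Optional feature inputs (chars, POS, NER) are only fed when the corresponding
            # vocabulary exists, presumably because the graph only defines those placeholders
            # in that case.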
            if char_vocab is not None:
                feed_dict[train_graph.get_question_char_lengths()] = sent1_char_length_batch
                feed_dict[train_graph.get_passage_char_lengths()] = sent2_char_length_batch
                feed_dict[train_graph.get_in_question_chars()] = char_matrix_idx_1_batch
                feed_dict[train_graph.get_in_passage_chars()] = char_matrix_idx_2_batch

            if POS_vocab is not None:
                feed_dict[train_graph.get_in_question_poss()] = POS_idx_1_batch
                feed_dict[train_graph.get_in_passage_poss()] = POS_idx_2_batch

            if NER_vocab is not None:
                feed_dict[train_graph.get_in_question_ners()] = NER_idx_1_batch
                feed_dict[train_graph.get_in_passage_ners()] = NER_idx_2_batch

            _, loss_value = sess.run([train_graph.get_train_op(), train_graph.get_loss()], feed_dict=feed_dict)
            total_loss += loss_value
            
            if step % 100 == 0:
                print('{} '.format(step), end="")
                sys.stdout.flush()

            # Save a checkpoint and evaluate the model periodically.
            if (step + 1) % trainDataStream.get_num_batch() == 0 or (step + 1) == max_steps:
                print()
                # Print status to stdout.
                duration = time.time() - start_time
                start_time = time.time()
                print('Step %d: loss = %.2f (%.3f sec)' % (step, total_loss, duration))
                total_loss = 0.0
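                # The dev set is evaluated once per pass over the training batches (and at the
                # final step); the checkpoint at best_path is only overwritten when accuracy improves.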

                # Evaluate against the validation set.
                print('Validation Data Eval:')
                accuracy = evaluate(devDataStream, valid_graph, sess, char_vocab=char_vocab, POS_vocab=POS_vocab, NER_vocab=NER_vocab)
                print("Current accuracy is %.2f" % accuracy)
                if accuracy > best_accuracy:
                    best_accuracy = accuracy
                    saver.save(sess, best_path)

    print("Best accuracy on dev set is %.2f" % best_accuracy)
    # decoding
    print('Decoding on the test set:')
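    # Final decoding builds a fresh graph in a new session, restores the best dev checkpoint saved
    # above, and evaluates it once on the test stream.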
    init_scale = 0.01
    with tf.Graph().as_default():
        initializer = tf.random_uniform_initializer(-init_scale, init_scale)
        with tf.variable_scope("Model", reuse=False, initializer=initializer):
            valid_graph = SentenceMatchModelGraph(num_classes, word_vocab=word_vocab, char_vocab=char_vocab, POS_vocab=POS_vocab, NER_vocab=NER_vocab, 
                 dropout_rate=FLAGS.dropout_rate, learning_rate=FLAGS.learning_rate, optimize_type=FLAGS.optimize_type,
                 lambda_l2=FLAGS.lambda_l2, char_lstm_dim=FLAGS.char_lstm_dim, context_lstm_dim=FLAGS.context_lstm_dim, 
                 aggregation_lstm_dim=FLAGS.aggregation_lstm_dim, is_training=False, MP_dim=FLAGS.MP_dim, 
                 context_layer_num=FLAGS.context_layer_num, aggregation_layer_num=FLAGS.aggregation_layer_num, 
                 fix_word_vec=FLAGS.fix_word_vec,with_filter_layer=FLAGS.with_filter_layer, with_highway=FLAGS.with_highway,
                 word_level_MP_dim=FLAGS.word_level_MP_dim,
                 with_match_highway=FLAGS.with_match_highway, with_aggregation_highway=FLAGS.with_aggregation_highway,
                 highway_layer_num=FLAGS.highway_layer_num, with_lex_decomposition=FLAGS.with_lex_decomposition, 
                 lex_decompsition_dim=FLAGS.lex_decompsition_dim,
                 with_left_match=(not FLAGS.wo_left_match), with_right_match=(not FLAGS.wo_right_match),
                 with_full_match=(not FLAGS.wo_full_match), with_maxpool_match=(not FLAGS.wo_maxpool_match), 
                 with_attentive_match=(not FLAGS.wo_attentive_match), with_max_attentive_match=(not FLAGS.wo_max_attentive_match))
        vars_ = {}
        for var in tf.global_variables():
            if "word_embedding" in var.name: continue
            if not var.name.startswith("Model"): continue
            vars_[var.name.split(":")[0]] = var
        saver = tf.train.Saver(vars_)
                
        sess = tf.Session()
        sess.run(tf.global_variables_initializer())
        step = 0
        saver.restore(sess, best_path)

        accuracy = evaluate(testDataStream, valid_graph, sess, char_vocab=char_vocab, POS_vocab=POS_vocab, NER_vocab=NER_vocab)
        print("Accuracy for test set is %.2f" % accuracy)
Example No. 7
def main(_):
    print('Configurations:')
    print(FLAGS)

    train_path = FLAGS.train_path
    dev_path = FLAGS.dev_path
    test_path = FLAGS.test_path
    word_vec_path = FLAGS.word_vec_path
    log_dir = FLAGS.model_dir
    tolower = FLAGS.use_lower_letter
    FLAGS.rl_matches = json.loads(FLAGS.rl_matches)

    # if not os.path.exists(log_dir):
    #     os.makedirs(log_dir)

    path_prefix = log_dir + "/TriMatch.{}".format(FLAGS.suffix)

    namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json")

    # build vocabs
    word_vocab = Vocab(word_vec_path, fileformat='txt3', tolower=tolower)
    best_path = path_prefix + '.best.model'
    char_path = path_prefix + ".char_vocab"
    label_path = path_prefix + ".label_vocab"
    POS_path = path_prefix + ".POS_vocab"
    NER_path = path_prefix + ".NER_vocab"
    has_pre_trained_model = False
    POS_vocab = None
    NER_vocab = None

    print('best path:', best_path)
    if os.path.exists(best_path +
                      '.data-00000-of-00001') and not (FLAGS.create_new_model):
        print('Using pretrained model')
        has_pre_trained_model = True
        label_vocab = Vocab(label_path, fileformat='txt2', tolower=tolower)
        char_vocab = Vocab(char_path, fileformat='txt2', tolower=tolower)
        if FLAGS.with_POS:
            POS_vocab = Vocab(POS_path, fileformat='txt2', tolower=tolower)
        if FLAGS.with_NER:
            NER_vocab = Vocab(NER_path, fileformat='txt2', tolower=tolower)
    else:
        print('Creating new model')
        print('Collect words, chars and labels ...')
        (all_words, all_chars, all_labels, all_POSs,
         all_NERs) = collect_vocabs(train_path,
                                    with_POS=FLAGS.with_POS,
                                    with_NER=FLAGS.with_NER,
                                    tolower=tolower)
        if FLAGS.use_options:
            all_labels = ['0', '1']
        print('Number of words: {}'.format(len(all_words)))
        print('Number of labels: {}'.format(len(all_labels)))
        # for word in all_labels:
        #     print('label',word)
        # input('check')

        label_vocab = Vocab(fileformat='voc',
                            voc=all_labels,
                            dim=2,
                            tolower=tolower)
        label_vocab.dump_to_txt2(label_path)

        print('Number of chars: {}'.format(len(all_chars)))
        char_vocab = Vocab(fileformat='voc',
                           voc=all_chars,
                           dim=FLAGS.char_emb_dim,
                           tolower=tolower)
        char_vocab.dump_to_txt2(char_path)

        if FLAGS.with_POS:
            print('Number of POSs: {}'.format(len(all_POSs)))
            POS_vocab = Vocab(fileformat='voc',
                              voc=all_POSs,
                              dim=FLAGS.POS_dim,
                              tolower=tolower)
            POS_vocab.dump_to_txt2(POS_path)
        if FLAGS.with_NER:
            print('Number of NERs: {}'.format(len(all_NERs)))
            NER_vocab = Vocab(fileformat='voc',
                              voc=all_NERs,
                              dim=FLAGS.NER_dim,
                              tolower=tolower)
            NER_vocab.dump_to_txt2(NER_path)

    print('all_labels:', label_vocab)
    print('has pretrained model:', has_pre_trained_model)
    # for word in word_vocab.word_vecs:

    print('word_vocab shape is {}'.format(word_vocab.word_vecs.shape))
    print('tag_vocab shape is {}'.format(label_vocab.word_vecs.shape))
    num_classes = label_vocab.size()

    print('Build TriMatchDataStream ... ')

    gen_concat_mat = False
    gen_split_mat = False
    if FLAGS.matching_option == 7:
        gen_concat_mat = True
        if FLAGS.concat_context:
            gen_split_mat = True
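    # matching_option 7 apparently works on a concatenated passage/question context, so the data
    # streams are asked to emit the extra concat (and, with concat_context, split) index matrices.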
    trainDataStream = TriMatchDataStream(
        train_path,
        word_vocab=word_vocab,
        char_vocab=char_vocab,
        POS_vocab=POS_vocab,
        NER_vocab=NER_vocab,
        label_vocab=label_vocab,
        batch_size=FLAGS.batch_size,
        isShuffle=True,
        isLoop=True,
        isSort=(not FLAGS.wo_sort_instance_based_on_length),
        max_char_per_word=FLAGS.max_char_per_word,
        max_sent_length=FLAGS.max_sent_length,
        max_hyp_length=FLAGS.max_hyp_length,
        max_choice_length=FLAGS.max_choice_length,
        tolower=tolower,
        gen_concat_mat=gen_concat_mat,
        gen_split_mat=gen_split_mat)

    devDataStream = TriMatchDataStream(
        dev_path,
        word_vocab=word_vocab,
        char_vocab=char_vocab,
        POS_vocab=POS_vocab,
        NER_vocab=NER_vocab,
        label_vocab=label_vocab,
        batch_size=FLAGS.batch_size,
        isShuffle=False,
        isLoop=True,
        isSort=(not FLAGS.wo_sort_instance_based_on_length),
        max_char_per_word=FLAGS.max_char_per_word,
        max_sent_length=FLAGS.max_sent_length,
        max_hyp_length=FLAGS.max_hyp_length,
        max_choice_length=FLAGS.max_choice_length,
        tolower=tolower,
        gen_concat_mat=gen_concat_mat,
        gen_split_mat=gen_split_mat)

    testDataStream = TriMatchDataStream(
        test_path,
        word_vocab=word_vocab,
        char_vocab=char_vocab,
        POS_vocab=POS_vocab,
        NER_vocab=NER_vocab,
        label_vocab=label_vocab,
        batch_size=FLAGS.batch_size,
        isShuffle=False,
        isLoop=True,
        isSort=(not FLAGS.wo_sort_instance_based_on_length),
        max_char_per_word=FLAGS.max_char_per_word,
        max_sent_length=FLAGS.max_sent_length,
        max_hyp_length=FLAGS.max_hyp_length,
        max_choice_length=FLAGS.max_choice_length,
        tolower=tolower,
        gen_concat_mat=gen_concat_mat,
        gen_split_mat=gen_split_mat)

    print('Number of instances in trainDataStream: {}'.format(
        trainDataStream.get_num_instance()))
    print('Number of instances in devDataStream: {}'.format(
        devDataStream.get_num_instance()))
    print('Number of instances in testDataStream: {}'.format(
        testDataStream.get_num_instance()))
    print('Number of batches in trainDataStream: {}'.format(
        trainDataStream.get_num_batch()))
    print('Number of batches in devDataStream: {}'.format(
        devDataStream.get_num_batch()))
    print('Number of batches in testDataStream: {}'.format(
        testDataStream.get_num_batch()))

    sys.stdout.flush()
    if FLAGS.wo_char: char_vocab = None

    best_accuracy = 0.0
    init_scale = 0.01
    with tf.Graph().as_default():
        initializer = tf.random_uniform_initializer(-init_scale, init_scale)
        #         with tf.name_scope("Train"):
        with tf.variable_scope("Model", reuse=None, initializer=initializer):
            train_graph = TriMatchModelGraph(
                num_classes,
                word_vocab=word_vocab,
                char_vocab=char_vocab,
                POS_vocab=POS_vocab,
                NER_vocab=NER_vocab,
                dropout_rate=FLAGS.dropout_rate,
                learning_rate=FLAGS.learning_rate,
                optimize_type=FLAGS.optimize_type,
                lambda_l2=FLAGS.lambda_l2,
                char_lstm_dim=FLAGS.char_lstm_dim,
                context_lstm_dim=FLAGS.context_lstm_dim,
                aggregation_lstm_dim=FLAGS.aggregation_lstm_dim,
                is_training=True,
                MP_dim=FLAGS.MP_dim,
                context_layer_num=FLAGS.context_layer_num,
                aggregation_layer_num=FLAGS.aggregation_layer_num,
                fix_word_vec=FLAGS.fix_word_vec,
                with_highway=FLAGS.with_highway,
                word_level_MP_dim=FLAGS.word_level_MP_dim,
                with_match_highway=FLAGS.with_match_highway,
                with_aggregation_highway=FLAGS.with_aggregation_highway,
                highway_layer_num=FLAGS.highway_layer_num,
                match_to_question=FLAGS.match_to_question,
                match_to_passage=FLAGS.match_to_passage,
                match_to_choice=FLAGS.match_to_choice,
                with_full_match=(not FLAGS.wo_full_match),
                with_maxpool_match=(not FLAGS.wo_maxpool_match),
                with_attentive_match=(not FLAGS.wo_attentive_match),
                with_max_attentive_match=(not FLAGS.wo_max_attentive_match),
                use_options=FLAGS.use_options,
                num_options=num_options,
                with_no_match=FLAGS.with_no_match,
                verbose=FLAGS.verbose,
                matching_option=FLAGS.matching_option,
                concat_context=FLAGS.concat_context,
                tied_aggre=FLAGS.tied_aggre,
                rl_training_method=FLAGS.rl_training_method,
                rl_matches=FLAGS.rl_matches)

            tf.summary.scalar("Training Loss", train_graph.get_loss()
                              )  # Add a scalar summary for the snapshot loss.
        if FLAGS.verbose:
            valid_graph = train_graph
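            # In verbose/debug mode the training graph doubles as the validation graph so its debug
            # tensors (predictions, probabilities, matching vectors) can be inspected directly.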
        else:
            #         with tf.name_scope("Valid"):
            with tf.variable_scope("Model",
                                   reuse=True,
                                   initializer=initializer):
                valid_graph = TriMatchModelGraph(
                    num_classes,
                    word_vocab=word_vocab,
                    char_vocab=char_vocab,
                    POS_vocab=POS_vocab,
                    NER_vocab=NER_vocab,
                    dropout_rate=FLAGS.dropout_rate,
                    learning_rate=FLAGS.learning_rate,
                    optimize_type=FLAGS.optimize_type,
                    lambda_l2=FLAGS.lambda_l2,
                    char_lstm_dim=FLAGS.char_lstm_dim,
                    context_lstm_dim=FLAGS.context_lstm_dim,
                    aggregation_lstm_dim=FLAGS.aggregation_lstm_dim,
                    is_training=False,
                    MP_dim=FLAGS.MP_dim,
                    context_layer_num=FLAGS.context_layer_num,
                    aggregation_layer_num=FLAGS.aggregation_layer_num,
                    fix_word_vec=FLAGS.fix_word_vec,
                    with_highway=FLAGS.with_highway,
                    word_level_MP_dim=FLAGS.word_level_MP_dim,
                    with_match_highway=FLAGS.with_match_highway,
                    with_aggregation_highway=FLAGS.with_aggregation_highway,
                    highway_layer_num=FLAGS.highway_layer_num,
                    match_to_question=FLAGS.match_to_question,
                    match_to_passage=FLAGS.match_to_passage,
                    match_to_choice=FLAGS.match_to_choice,
                    with_full_match=(not FLAGS.wo_full_match),
                    with_maxpool_match=(not FLAGS.wo_maxpool_match),
                    with_attentive_match=(not FLAGS.wo_attentive_match),
                    with_max_attentive_match=(
                        not FLAGS.wo_max_attentive_match),
                    use_options=FLAGS.use_options,
                    num_options=num_options,
                    with_no_match=FLAGS.with_no_match,
                    matching_option=FLAGS.matching_option,
                    concat_context=FLAGS.concat_context,
                    tied_aggre=FLAGS.tied_aggre,
                    rl_training_method=FLAGS.rl_training_method,
                    rl_matches=FLAGS.rl_matches)

        initializer = tf.global_variables_initializer()
        vars_ = {}
        for var in tf.global_variables():
            # print(var.name,var.get_shape().as_list())
            if "word_embedding" in var.name: continue
            #             if not var.name.startswith("Model"): continue
            vars_[var.name.split(":")[0]] = var
        saver = tf.train.Saver(vars_)
        # input('check')

        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        sess = tf.Session(config=config)
        sess.run(initializer)
        if has_pre_trained_model:
            print("Restoring model from " + best_path)
            saver.restore(sess, best_path)
            print("DONE!")

        print('Start the training loop.')
        train_size = trainDataStream.get_num_batch()
        max_steps = train_size * FLAGS.max_epochs
        total_loss = 0.0
        start_time = time.time()
        sub_loss_counter = 0.0
        for step in range(max_steps):
            # read data
            cur_batch = trainDataStream.nextBatch()
            (label_batch, sent1_batch, sent2_batch, sent3_batch,
             label_id_batch, word_idx_1_batch, word_idx_2_batch,
             word_idx_3_batch, char_matrix_idx_1_batch,
             char_matrix_idx_2_batch, char_matrix_idx_3_batch,
             sent1_length_batch, sent2_length_batch, sent3_length_batch,
             sent1_char_length_batch, sent2_char_length_batch,
             sent3_char_length_batch, POS_idx_1_batch, POS_idx_2_batch,
             NER_idx_1_batch, NER_idx_2_batch, concat_mat_batch,
             split_mat_batch_q, split_mat_batch_c) = cur_batch

            # print(label_id_batch)
            if FLAGS.verbose:
                print(label_id_batch)
                print(sent1_length_batch)
                print(sent2_length_batch)
                print(sent3_length_batch)
                # print(word_idx_1_batch)
                # print(word_idx_2_batch)
                # print(word_idx_3_batch)
                # print(sent1_batch)
                # print(sent2_batch)
                # print(sent3_batch)
                print(concat_mat_batch)
                input('check')
            feed_dict = {
                train_graph.get_truth(): label_id_batch,
                train_graph.get_passage_lengths(): sent1_length_batch,
                train_graph.get_question_lengths(): sent2_length_batch,
                train_graph.get_choice_lengths(): sent3_length_batch,
                train_graph.get_in_passage_words(): word_idx_1_batch,
                train_graph.get_in_question_words(): word_idx_2_batch,
                train_graph.get_in_choice_words(): word_idx_3_batch,
                #                          train_graph.get_question_char_lengths(): sent1_char_length_batch,
                #                          train_graph.get_passage_char_lengths(): sent2_char_length_batch,
                #                          train_graph.get_in_question_chars(): char_matrix_idx_1_batch,
                #                          train_graph.get_in_passage_chars(): char_matrix_idx_2_batch,
            }
            if char_vocab is not None:
                feed_dict[train_graph.get_passage_char_lengths(
                )] = sent1_char_length_batch
                feed_dict[train_graph.get_question_char_lengths(
                )] = sent2_char_length_batch
                feed_dict[train_graph.get_choice_char_lengths(
                )] = sent3_char_length_batch
                feed_dict[train_graph.get_in_passage_chars(
                )] = char_matrix_idx_1_batch
                feed_dict[train_graph.get_in_question_chars(
                )] = char_matrix_idx_2_batch
                feed_dict[train_graph.get_in_choice_chars(
                )] = char_matrix_idx_3_batch

            if POS_vocab is not None:
                feed_dict[train_graph.get_in_passage_poss()] = POS_idx_1_batch
                feed_dict[train_graph.get_in_question_poss()] = POS_idx_2_batch

            if NER_vocab is not None:
                feed_dict[train_graph.get_in_passage_ners()] = NER_idx_1_batch
                feed_dict[train_graph.get_in_question_ners()] = NER_idx_2_batch
            if concat_mat_batch is not None:
                feed_dict[train_graph.concat_idx_mat] = concat_mat_batch
            if split_mat_batch_q is not None:
                feed_dict[train_graph.split_idx_mat_q] = split_mat_batch_q
                feed_dict[train_graph.split_idx_mat_c] = split_mat_batch_c
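            # The concat/split index matrices only exist for matching_option 7 (see gen_concat_mat
            # and gen_split_mat above), so they are added to the feed_dict conditionally.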

            if FLAGS.verbose:
                return_list = sess.run([
                    train_graph.get_train_op(),
                    train_graph.get_loss(),
                    train_graph.get_predictions(),
                    train_graph.get_prob(), train_graph.all_probs,
                    train_graph.correct
                ] + train_graph.matching_vectors,
                                       feed_dict=feed_dict)
                print(len(return_list))
                _, loss_value, pred, prob, all_probs, correct = return_list[
                    0:6]
                print('pred=', pred)
                print('prob=', prob)
                print('logits=', all_probs)
                print('correct=', correct)
                for val in return_list[6:]:
                    if isinstance(val, list):
                        print('list len ', len(val))
                        for objj in val:
                            print('this shape=', objj.shape)
                    else:
                        print('this shape=', val.shape)
                    # print(val)
                input('check')
            else:
                _, loss_value = sess.run(
                    [train_graph.get_train_op(),
                     train_graph.get_loss()],
                    feed_dict=feed_dict)
            total_loss += loss_value
            sub_loss_counter += loss_value

            if step % int(FLAGS.display_every) == 0:
                print('{},{} '.format(step, sub_loss_counter), end="")
                sys.stdout.flush()
                sub_loss_counter = 0.0

            # Save a checkpoint and evaluate the model periodically.
            if (step + 1) % trainDataStream.get_num_batch() == 0 or (
                    step + 1) == max_steps:
                print()
                # Print status to stdout.
                duration = time.time() - start_time
                start_time = time.time()
                print('Step %d: loss = %.2f (%.3f sec)' %
                      (step, total_loss, duration))
                total_loss = 0.0

                # Evaluate against the validation set.
                print('Validation Data Eval:')
                if FLAGS.predict_val:
                    outpath = path_prefix + '.iter%d' % (step) + '.probs'
                else:
                    outpath = None
                accuracy = evaluate(devDataStream,
                                    valid_graph,
                                    sess,
                                    char_vocab=char_vocab,
                                    POS_vocab=POS_vocab,
                                    NER_vocab=NER_vocab,
                                    use_options=FLAGS.use_options,
                                    outpath=outpath,
                                    mode='prob')
                print("Current accuracy on dev set is %.2f" % accuracy)
                if accuracy >= best_accuracy:
                    best_accuracy = accuracy
                    saver.save(sess, best_path)
                    print('saving the current model.')
                accuracy = evaluate(testDataStream,
                                    valid_graph,
                                    sess,
                                    char_vocab=char_vocab,
                                    POS_vocab=POS_vocab,
                                    NER_vocab=NER_vocab,
                                    use_options=FLAGS.use_options,
                                    outpath=outpath,
                                    mode='prob')
                print("Current accuracy on test set is %.2f" % accuracy)

    print("Best accuracy on dev set is %.2f" % best_accuracy)
    # decoding
    print('Decoding on the test set:')
    init_scale = 0.01
    with tf.Graph().as_default():
        initializer = tf.random_uniform_initializer(-init_scale, init_scale)
        with tf.variable_scope("Model", reuse=False, initializer=initializer):
            valid_graph = TriMatchModelGraph(
                num_classes,
                word_vocab=word_vocab,
                char_vocab=char_vocab,
                POS_vocab=POS_vocab,
                NER_vocab=NER_vocab,
                dropout_rate=FLAGS.dropout_rate,
                learning_rate=FLAGS.learning_rate,
                optimize_type=FLAGS.optimize_type,
                lambda_l2=FLAGS.lambda_l2,
                char_lstm_dim=FLAGS.char_lstm_dim,
                context_lstm_dim=FLAGS.context_lstm_dim,
                aggregation_lstm_dim=FLAGS.aggregation_lstm_dim,
                is_training=False,
                MP_dim=FLAGS.MP_dim,
                context_layer_num=FLAGS.context_layer_num,
                aggregation_layer_num=FLAGS.aggregation_layer_num,
                fix_word_vec=FLAGS.fix_word_vec,
                with_highway=FLAGS.with_highway,
                word_level_MP_dim=FLAGS.word_level_MP_dim,
                with_match_highway=FLAGS.with_match_highway,
                with_aggregation_highway=FLAGS.with_aggregation_highway,
                highway_layer_num=FLAGS.highway_layer_num,
                match_to_question=FLAGS.match_to_question,
                match_to_passage=FLAGS.match_to_passage,
                match_to_choice=FLAGS.match_to_choice,
                with_full_match=(not FLAGS.wo_full_match),
                with_maxpool_match=(not FLAGS.wo_maxpool_match),
                with_attentive_match=(not FLAGS.wo_attentive_match),
                with_max_attentive_match=(not FLAGS.wo_max_attentive_match),
                use_options=FLAGS.use_options,
                num_options=num_options,
                with_no_match=FLAGS.with_no_match,
                matching_option=FLAGS.matching_option,
                concat_context=FLAGS.concat_context,
                tied_aggre=FLAGS.tied_aggre,
                rl_training_method=FLAGS.rl_training_method,
                rl_matches=FLAGS.rl_matches)
        vars_ = {}
        for var in tf.global_variables():
            if "word_embedding" in var.name: continue
            if not var.name.startswith("Model"): continue
            vars_[var.name.split(":")[0]] = var
        saver = tf.train.Saver(vars_)

        sess = tf.Session()
        sess.run(tf.global_variables_initializer())
        step = 0
        saver.restore(sess, best_path)

        accuracy = evaluate(testDataStream,
                            valid_graph,
                            sess,
                            char_vocab=char_vocab,
                            POS_vocab=POS_vocab,
                            NER_vocab=NER_vocab,
                            use_options=FLAGS.use_options)
        print("Accuracy for test set is %.2f" % accuracy)
Example No. 8
def main(_):
    print('Configurations:')
    print(FLAGS)
    train_path = FLAGS.train_path
    dev_path = FLAGS.dev_path
    test_path = FLAGS.test_path
    word_vec_path = FLAGS.word_vec_path
    log_dir = FLAGS.model_dir
    result_dir = '../result'
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    if not os.path.exists(result_dir):
        os.makedirs(result_dir)

    path_prefix = log_dir + "/Han.{}".format(FLAGS.suffix)
    save_namespace(FLAGS, path_prefix + ".config.json")
    word_vocab = Vocab(word_vec_path, fileformat='txt3')
    best_path = path_prefix + '.best.model'
    label_path = path_prefix + ".label_vocab"

    print('Collect words and labels ...')
    (all_words, all_labels) = collect_vocabs(train_path)
    print('Number of words: {}'.format(len(all_words)))
    print('Number of labels: {}'.format(len(all_labels)))
    label_vocab = Vocab(fileformat='voc', voc=all_labels, dim=2)
    label_vocab.dump_to_txt2(label_path)

    print('word_vocab shape is {}'.format(word_vocab.word_vecs.shape))
    print('tag_vocab shape is {}'.format(label_vocab.word_vecs.shape))
    num_classes = label_vocab.size()

    print('Build HanDataStream ... ')
    trainDataStream = HanDataStream(inpath=train_path, word_vocab=word_vocab,
                                              label_vocab=label_vocab,
                                              isShuffle=True, isLoop=True,
                                              max_sent_length=FLAGS.max_sent_length)
    devDataStream = HanDataStream(inpath=dev_path, word_vocab=word_vocab,
                                              label_vocab=label_vocab,
                                              isShuffle=False, isLoop=True,
                                              max_sent_length=FLAGS.max_sent_length)
    testDataStream = HanDataStream(inpath=test_path, word_vocab=word_vocab,
                                              label_vocab=label_vocab,
                                              isShuffle=False, isLoop=True,
                                              max_sent_length=FLAGS.max_sent_length)

    print('Number of instances in trainDataStream: {}'.format(trainDataStream.get_num_instance()))
    print('Number of instances in devDataStream: {}'.format(devDataStream.get_num_instance()))
    print('Number of instances in testDataStream: {}'.format(testDataStream.get_num_instance()))

    with tf.Graph().as_default():
        initializer = tf.contrib.layers.xavier_initializer()
        with tf.variable_scope("Model", reuse=None, initializer=initializer):
            train_graph = HanModelGraph(num_classes=num_classes, word_vocab=word_vocab,
                                                  dropout_rate=FLAGS.dropout_rate, learning_rate=FLAGS.learning_rate,
                                                  lambda_l2=FLAGS.lambda_l2,
                                                  context_lstm_dim=FLAGS.context_lstm_dim,
                                                  is_training=True, batch_size = FLAGS.batch_size)
            tf.summary.scalar("Training Loss", train_graph.loss)  # Add a scalar summary for the snapshot loss.
        print("Train Graph Build")
        with tf.variable_scope("Model", reuse=True, initializer=initializer):
            valid_graph = HanModelGraph(num_classes=num_classes, word_vocab=word_vocab,
                                                  dropout_rate=FLAGS.dropout_rate, learning_rate=FLAGS.learning_rate,
                                                  lambda_l2=FLAGS.lambda_l2,
                                                  context_lstm_dim=FLAGS.context_lstm_dim,
                                                  is_training=False, batch_size = 1)
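        # The evaluation graph shares weights via reuse=True but is built with batch_size=1,
        # presumably so evaluate() can score dev/test instances one at a time.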
        print ("dev Graph Build")
        initializer = tf.global_variables_initializer()
        vars_ = {}
        for var in tf.global_variables():
            if "word_embedding" in var.name: continue
            #             if not var.name.startswith("Model"): continue
            vars_[var.name.split(":")[0]] = var
        saver = tf.train.Saver(vars_)

        output_res_file = open(result_dir + '/' + FLAGS.suffix, 'wt')
        output_res_file.write(str(FLAGS))
        with tf.Session() as sess:
            sess.run(initializer)
            train_size = trainDataStream.get_num_instance()
            max_steps = (train_size * FLAGS.max_epochs) // FLAGS.batch_size
            epoch_size = max_steps // (FLAGS.max_epochs)  # + 1

            total_loss = 0.0
            start_time = time.time()
            best_accuracy = 0
            for step in range(max_steps):
                # read data
                # _truth = []
                # _sents_length = []
                # _in_text_words = []
                # for i in range(FLAGS.batch_size):
                #     cur_instance, instance_index = trainDataStream.nextInstance ()
                #     (label,text,label_id, word_idx, sents_length) = cur_instance
                #
                #     _truth.append(label_id)
                #     _sents_length.append(sents_length)
                #     _in_text_words.append(word_idx)
                #
                # feed_dict = {
                #     train_graph.truth: np.array(_truth),
                #     train_graph.sents_length: tuple(_sents_length),
                #     train_graph.in_text_words: tuple(_in_text_words),
                # }

                feed_dict = get_feed_dict(data_stream=trainDataStream, graph=train_graph, batch_size=FLAGS.batch_size, is_testing=False)
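                # get_feed_dict assembles one training batch for the graph, replacing the manual
                # batching shown in the commented-out block above.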

                _, loss_value, _score = sess.run([train_graph.train_op, train_graph.loss
                                                     , train_graph.batch_class_scores],
                                                 feed_dict=feed_dict)
                total_loss += loss_value

                if step % 100 == 0:
                    print('{} '.format(step), end="")
                    sys.stdout.flush()
                if (step + 1) % epoch_size == 0 or (step + 1) == max_steps:
                    # print(total_loss)
                    duration = time.time() - start_time
                    start_time = time.time()
                    print(duration, step, "Loss: ", total_loss)
                    output_res_file.write('\nStep %d: loss = %.2f (%.3f sec)\n' % (step, total_loss, duration))
                    total_loss = 0.0
                    # Evaluate against the validation set.
                    output_res_file.write('valid- ')
                    dev_accuracy = evaluate(devDataStream, valid_graph, sess)
                    output_res_file.write("%.2f\n" % dev_accuracy)
                    print("Current dev accuracy is %.2f" % dev_accuracy)
                    if dev_accuracy > best_accuracy:
                        best_accuracy = dev_accuracy
                        saver.save(sess, best_path)
                    output_res_file.write('test- ')
                    test_accuracy = evaluate(testDataStream, valid_graph, sess)
                    print("Current test accuracy is %.2f" % test_accuracy)
                    output_res_file.write("%.2f\n" % test_accuracy)

    output_res_file.close()
    sys.stdout.flush()
Example No. 9
    def decode(self,
               model_prefix,
               in_path,
               out_path,
               word_vec_path,
               mode,
               out_json_path=None,
               dump_prob_path=None):
        #     model_prefix = args.model_prefix
        #     in_path = args.in_path
        #     out_path = args.out_path
        #     word_vec_path = args.word_vec_path
        #     mode = args.mode
        #     out_json_path = None
        #     dump_prob_path = None

        # load the configuration file
        print('Loading configurations.')
        FLAGS = namespace_utils.load_namespace(model_prefix + ".config.json")
        print(FLAGS)

        with_POS = False
        if hasattr(FLAGS, 'with_POS'): with_POS = FLAGS.with_POS
        with_NER = False
        if hasattr(FLAGS, 'with_NER'): with_NER = FLAGS.with_NER
        wo_char = False
        if hasattr(FLAGS, 'wo_char'): wo_char = FLAGS.wo_char

        wo_left_match = False
        if hasattr(FLAGS, 'wo_left_match'): wo_left_match = FLAGS.wo_left_match

        wo_right_match = False
        if hasattr(FLAGS, 'wo_right_match'):
            wo_right_match = FLAGS.wo_right_match

        wo_full_match = False
        if hasattr(FLAGS, 'wo_full_match'): wo_full_match = FLAGS.wo_full_match

        wo_maxpool_match = False
        if hasattr(FLAGS, 'wo_maxpool_match'):
            wo_maxpool_match = FLAGS.wo_maxpool_match

        wo_attentive_match = False
        if hasattr(FLAGS, 'wo_attentive_match'):
            wo_attentive_match = FLAGS.wo_attentive_match

        wo_max_attentive_match = False
        if hasattr(FLAGS, 'wo_max_attentive_match'):
            wo_max_attentive_match = FLAGS.wo_max_attentive_match
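        # The hasattr checks above keep decoding compatible with config files saved by older
        # versions of the trainer that lack these ablation flags; missing flags default to False.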

        # load vocabs
        print('Loading vocabs.')
        word_vocab = Vocab(word_vec_path, fileformat='txt3')
        label_vocab = Vocab(model_prefix + ".label_vocab", fileformat='txt2')
        print('word_vocab: {}'.format(word_vocab.word_vecs.shape))
        print('label_vocab: {}'.format(label_vocab.word_vecs.shape))
        num_classes = label_vocab.size()

        POS_vocab = None
        NER_vocab = None
        char_vocab = None
        if with_POS:
            POS_vocab = Vocab(model_prefix + ".POS_vocab", fileformat='txt2')
        if with_NER:
            NER_vocab = Vocab(model_prefix + ".NER_vocab", fileformat='txt2')
        char_vocab = Vocab(model_prefix + ".char_vocab", fileformat='txt2')
        print('char_vocab: {}'.format(char_vocab.word_vecs.shape))

        print('Build SentenceMatchDataStream ... ')
        testDataStream = SentenceMatchTrainer.SentenceMatchDataStream(
            in_path,
            word_vocab=word_vocab,
            char_vocab=char_vocab,
            POS_vocab=POS_vocab,
            NER_vocab=NER_vocab,
            label_vocab=label_vocab,
            batch_size=FLAGS.batch_size,
            isShuffle=False,
            isLoop=True,
            isSort=True,
            max_char_per_word=FLAGS.max_char_per_word,
            max_sent_length=FLAGS.max_sent_length)
        print('Number of instances in testDataStream: {}'.format(
            testDataStream.get_num_instance()))
        print('Number of batches in testDataStream: {}'.format(
            testDataStream.get_num_batch()))

        if wo_char: char_vocab = None

        init_scale = 0.01
        best_path = model_prefix + ".best.model"
        print('Decoding on the test set:')
        with tf.Graph().as_default():
            initializer = tf.random_uniform_initializer(
                -init_scale, init_scale)
            with tf.variable_scope("Model",
                                   reuse=False,
                                   initializer=initializer):
                valid_graph = SentenceMatchModelGraph(
                    num_classes,
                    word_vocab=word_vocab,
                    char_vocab=char_vocab,
                    POS_vocab=POS_vocab,
                    NER_vocab=NER_vocab,
                    dropout_rate=FLAGS.dropout_rate,
                    learning_rate=FLAGS.learning_rate,
                    optimize_type=FLAGS.optimize_type,
                    lambda_l2=FLAGS.lambda_l2,
                    char_lstm_dim=FLAGS.char_lstm_dim,
                    context_lstm_dim=FLAGS.context_lstm_dim,
                    aggregation_lstm_dim=FLAGS.aggregation_lstm_dim,
                    is_training=False,
                    MP_dim=FLAGS.MP_dim,
                    context_layer_num=FLAGS.context_layer_num,
                    aggregation_layer_num=FLAGS.aggregation_layer_num,
                    fix_word_vec=FLAGS.fix_word_vec,
                    with_filter_layer=FLAGS.with_filter_layer,
                    with_highway=FLAGS.with_highway,
                    word_level_MP_dim=FLAGS.word_level_MP_dim,
                    with_match_highway=FLAGS.with_match_highway,
                    with_aggregation_highway=FLAGS.with_aggregation_highway,
                    highway_layer_num=FLAGS.highway_layer_num,
                    with_lex_decomposition=FLAGS.with_lex_decomposition,
                    lex_decompsition_dim=FLAGS.lex_decompsition_dim,
                    with_char=(not FLAGS.wo_char),
                    with_left_match=(not FLAGS.wo_left_match),
                    with_right_match=(not FLAGS.wo_right_match),
                    with_full_match=(not FLAGS.wo_full_match),
                    with_maxpool_match=(not FLAGS.wo_maxpool_match),
                    with_attentive_match=(not FLAGS.wo_attentive_match),
                    with_max_attentive_match=(
                        not FLAGS.wo_max_attentive_match))

            # exclude the word_embedding variables from the saver
            vars_ = {}
            for var in tf.global_variables():
                if "word_embedding" in var.name: continue
                if not var.name.startswith("Model"): continue
                vars_[var.name.split(":")[0]] = var
            saver = tf.train.Saver(vars_)

            sess = tf.Session()
            sess.run(tf.global_variables_initializer())
            step = 0
            best_path = best_path.replace('//', '/')
            saver.restore(sess, best_path)

            accuracy = SentenceMatchTrainer.evaluate(testDataStream,
                                                     valid_graph,
                                                     sess,
                                                     outpath=out_path,
                                                     label_vocab=label_vocab,
                                                     mode=mode,
                                                     char_vocab=char_vocab,
                                                     POS_vocab=POS_vocab,
                                                     NER_vocab=NER_vocab)
Example No. 10
def main(FLAGS):
    train_path = FLAGS.train_path
    dev_path = FLAGS.dev_path
    word_vec_path = FLAGS.word_vec_path
    log_dir = FLAGS.model_dir
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    path_prefix = log_dir + "/SentenceMatch.{}".format(FLAGS.suffix)

    namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json")

    # build vocabs
    word_vocab = Vocab(word_vec_path, fileformat='txt3')

    best_path = path_prefix + '.best.model'
    char_path = path_prefix + ".char_vocab"
    label_path = path_prefix + ".label_vocab"
    has_pre_trained_model = False
    char_vocab = None
    if os.path.exists(best_path + ".index"):
        has_pre_trained_model = True
        logger.info('Loading vocabs from a pre-trained model ...')
        label_vocab = Vocab(label_path, fileformat='txt2')
        if FLAGS.with_char: char_vocab = Vocab(char_path, fileformat='txt2')
    else:
        logger.info('Collecting words, chars and labels ...')
        (all_words, all_chars, all_labels, all_POSs,
         all_NERs) = collect_vocabs(train_path)
        logger.info('Number of words: {}'.format(len(all_words)))
        label_vocab = Vocab(fileformat='voc', voc=all_labels, dim=2)
        label_vocab.dump_to_txt2(label_path)

        if FLAGS.with_char:
            logger.info('Number of chars: {}'.format(len(all_chars)))
            char_vocab = Vocab(fileformat='voc',
                               voc=all_chars,
                               dim=FLAGS.char_emb_dim)
            char_vocab.dump_to_txt2(char_path)

    logger.info('word_vocab shape is {}'.format(word_vocab.word_vecs.shape))
    num_classes = label_vocab.size()
    logger.info("Number of labels: {}".format(num_classes))
    sys.stdout.flush()

    logger.info('Build SentenceMatchDataStream ... ')
    trainDataStream = SentenceMatchDataStream(train_path,
                                              word_vocab=word_vocab,
                                              char_vocab=char_vocab,
                                              label_vocab=label_vocab,
                                              isShuffle=True,
                                              isLoop=True,
                                              isSort=True,
                                              options=FLAGS)
    logger.info('Number of instances in trainDataStream: {}'.format(
        trainDataStream.get_num_instance()))
    logger.info('Number of batches in trainDataStream: {}'.format(
        trainDataStream.get_num_batch()))
    sys.stdout.flush()

    devDataStream = SentenceMatchDataStream(dev_path,
                                            word_vocab=word_vocab,
                                            char_vocab=char_vocab,
                                            label_vocab=label_vocab,
                                            isShuffle=False,
                                            isLoop=True,
                                            isSort=True,
                                            options=FLAGS)
    logger.info('Number of instances in devDataStream: {}'.format(
        devDataStream.get_num_instance()))
    logger.info('Number of batches in devDataStream: {}'.format(
        devDataStream.get_num_batch()))
    sys.stdout.flush()

    init_scale = 0.01
    with tf.Graph().as_default():
        initializer = tf.random_uniform_initializer(-init_scale, init_scale)
        global_step = tf.train.get_or_create_global_step()
        with tf.variable_scope("Model", reuse=None, initializer=initializer):
            train_graph = SentenceMatchModelGraph(num_classes,
                                                  word_vocab=word_vocab,
                                                  char_vocab=char_vocab,
                                                  is_training=True,
                                                  options=FLAGS,
                                                  global_step=global_step)

        with tf.variable_scope("Model", reuse=True, initializer=initializer):
            valid_graph = SentenceMatchModelGraph(num_classes,
                                                  word_vocab=word_vocab,
                                                  char_vocab=char_vocab,
                                                  is_training=False,
                                                  options=FLAGS)

        initializer = tf.global_variables_initializer()
        vars_ = {}
        for var in tf.global_variables():
            if "word_embedding" in var.name:
                continue
            vars_[var.name.split(":")[0]] = var
        saver = tf.train.Saver(vars_)

        sess = tf.Session()
        # Initialize the summary writer and write the current TensorFlow graph to the log
        train_writer = tf.summary.FileWriter(SUMMARY_DIR, sess.graph)
        # valid_writer = tf.summary.FileWriter(SUMMARY_DIR + '/valid')

        sess.run(initializer)
        if has_pre_trained_model:
            logger.info("Restoring model from " + best_path)
            saver.restore(sess, best_path)
            logger.info("DONE!")

        # training
        train(sess, saver, train_graph, valid_graph, trainDataStream,
              devDataStream, FLAGS, best_path, train_writer, label_vocab)
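        # Unlike the earlier examples, the per-step optimisation loop lives in a separate train()
        # helper; this main() only builds the graphs, restores any existing checkpoint, and sets up
        # the summary writer.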

        train_writer.close()
Example No. 11
def main(FLAGS):
    train_path = FLAGS.train_path
    dev_path = FLAGS.dev_path
    test_path = FLAGS.test_path
    dev_path_target = FLAGS.dev_path_target
    test_path_target = FLAGS.test_path_target
    word_vec_path = FLAGS.word_vec_path
    log_dir = FLAGS.model_dir
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
        os.makedirs(os.path.join(log_dir, '../result_source'))
        os.makedirs(os.path.join(log_dir, '../logits_source'))
        os.makedirs(os.path.join(log_dir, '../result_target'))
        os.makedirs(os.path.join(log_dir, '../logits_target'))

    log_dir_target = FLAGS.model_dir + '_target'
    if not os.path.exists(log_dir_target):
        os.makedirs(log_dir_target)

    path_prefix = log_dir + "/ESIM.{}".format(FLAGS.suffix)
    namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json")
    path_prefix_target = log_dir_target + "/ESIM.{}".format(FLAGS.suffix)
    namespace_utils.save_namespace(FLAGS, path_prefix_target + ".config.json")
    # build vocabs
    word_vocab = Vocab(word_vec_path, fileformat='txt3')

    best_path = path_prefix + '.best.model'
    best_path_target = path_prefix_target + '.best.model'
    char_path = path_prefix + ".char_vocab"
    label_path = path_prefix + ".label_vocab"
    has_pre_trained_model = False
    char_vocab = None
    # if os.path.exists(best_path + ".index"):
    print('Collecting words, chars and labels ...')
    (all_words, all_chars, all_labels, all_POSs,
     all_NERs) = collect_vocabs(train_path)
    print('Number of words: {}'.format(len(all_words)))
    label_vocab = Vocab(fileformat='voc', voc=all_labels, dim=2)
    label_vocab.dump_to_txt2(label_path)

    print('word_vocab shape is {}'.format(word_vocab.word_vecs.shape))
    num_classes = label_vocab.size()
    print("Number of labels: {}".format(num_classes))
    sys.stdout.flush()

    print('Build SentenceMatchDataStream ... ')
    trainDataStream = DataStream(train_path,
                                 word_vocab=word_vocab,
                                 label_vocab=None,
                                 isShuffle=True,
                                 isLoop=True,
                                 isSort=True,
                                 options=FLAGS)
    print('Number of instances in trainDataStream: {}'.format(
        trainDataStream.get_num_instance()))
    print('Number of batches in trainDataStream: {}'.format(
        trainDataStream.get_num_batch()))
    sys.stdout.flush()

    devDataStream = DataStream(dev_path,
                               word_vocab=word_vocab,
                               label_vocab=None,
                               isShuffle=True,
                               isLoop=True,
                               isSort=True,
                               options=FLAGS)
    print('Number of instances in devDataStream: {}'.format(
        devDataStream.get_num_instance()))
    print('Number of batches in devDataStream: {}'.format(
        devDataStream.get_num_batch()))
    sys.stdout.flush()

    testDataStream = DataStream(test_path,
                                word_vocab=word_vocab,
                                label_vocab=None,
                                isShuffle=True,
                                isLoop=True,
                                isSort=True,
                                options=FLAGS)

    print('Number of instances in testDataStream: {}'.format(
        testDataStream.get_num_instance()))
    print('Number of batches in testDataStream: {}'.format(
        testDataStream.get_num_batch()))
    sys.stdout.flush()

    devDataStream_target = DataStream(dev_path_target,
                                      word_vocab=word_vocab,
                                      label_vocab=None,
                                      isShuffle=True,
                                      isLoop=True,
                                      isSort=True,
                                      options=FLAGS)
    print('Number of instances in devDataStream_target: {}'.format(
        devDataStream_target.get_num_instance()))
    print('Number of batches in devDataStream_target: {}'.format(
        devDataStream_target.get_num_batch()))
    sys.stdout.flush()

    testDataStream_target = DataStream(test_path_target,
                                       word_vocab=word_vocab,
                                       label_vocab=None,
                                       isShuffle=True,
                                       isLoop=True,
                                       isSort=True,
                                       options=FLAGS)

    print('Number of instances in testDataStream_target: {}'.format(
        testDataStream_target.get_num_instance()))
    print('Number of batches in testDataStream_target: {}'.format(
        testDataStream_target.get_num_batch()))
    sys.stdout.flush()

    with tf.Graph().as_default():
        initializer = tf.contrib.layers.xavier_initializer()
        global_step = tf.train.get_or_create_global_step()
        with tf.variable_scope("Model", reuse=None, initializer=initializer):
            train_graph = Model(num_classes,
                                word_vocab=word_vocab,
                                is_training=True,
                                options=FLAGS,
                                global_step=global_step)
        with tf.variable_scope("Model", reuse=True, initializer=initializer):
            valid_graph = Model(num_classes,
                                word_vocab=word_vocab,
                                is_training=False,
                                options=FLAGS)

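        # Checkpoint every variable except the (fixed) word embeddings, keyed by name
        # without the ":0" tensor suffix.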
        initializer = tf.global_variables_initializer()
        vars_ = {}
        for var in tf.global_variables():
            if "word_embedding" in var.name: continue
            # if not var.name.startswith("Model"): continue
            vars_[var.name.split(":")[0]] = var
        saver = tf.train.Saver(vars_)

        gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=1)
        config = tf.ConfigProto(allow_soft_placement=True,
                                gpu_options=gpu_options)

        config.gpu_options.allow_growth = True
        with tf.Session(config=config) as sess:
            sess.run(initializer)

            # training
            train(sess, saver, train_graph, valid_graph, trainDataStream,
                  devDataStream, testDataStream, devDataStream_target,
                  testDataStream_target, FLAGS, best_path, best_path_target)
Example #12
def main(FLAGS):
    tf.logging.set_verbosity(tf.logging.INFO)
    train_path = FLAGS.train_path
    dev_path = FLAGS.dev_path
    test_path = FLAGS.test_path
    word_vec_path = FLAGS.word_vec_path
    kg_path = FLAGS.kg_path
    wordnet_path = FLAGS.wordnet_path
    lemma_vec_path = FLAGS.lemma_vec_path
    log_dir = FLAGS.model_dir
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
        os.makedirs(os.path.join(log_dir, '../result'))
        os.makedirs(os.path.join(log_dir, '../logits'))

    path_prefix = log_dir + "/KEIM.{}".format(FLAGS.suffix)
    namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json")

    # build vocabs
    word_vocab = Vocab(word_vec_path, fileformat='txt3')
    lemma_vocab = Vocab(lemma_vec_path, fileformat='txt3')
    best_path = path_prefix + '.best.model'
    char_path = path_prefix + ".char_vocab"
    label_path = path_prefix + ".label_vocab"
    char_vocab = None

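    # Label (and optional character) vocabularies are rebuilt from the training data on every run.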
    tf.logging.info('Collecting words, chars and labels ...')
    (all_words, all_chars, all_labels, all_POSs,
     all_NERs) = collect_vocabs(train_path)
    tf.logging.info('Number of words: {}'.format(len(all_words)))
    label_vocab = Vocab(fileformat='voc', voc=all_labels, dim=2)
    label_vocab.dump_to_txt2(label_path)

    if FLAGS.with_char:
        tf.logging.info('Number of chars: {}'.format(len(all_chars)))
        char_vocab = Vocab(fileformat='voc',
                           voc=all_chars,
                           dim=FLAGS.char_emb_dim)
        char_vocab.dump_to_txt2(char_path)

    tf.logging.info('word_vocab shape is {}'.format(
        word_vocab.word_vecs.shape))
    tf.logging.info('lemma_word_vocab shape is {}'.format(
        lemma_vocab.word_vecs.shape))
    num_classes = label_vocab.size()
    tf.logging.info("Number of labels: {}".format(num_classes))
    sys.stdout.flush()

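    # Load the pickled external-knowledge resources (WordNet vocabulary and
    # knowledge-graph dictionary) that are passed to the data streams below.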
    with open(wordnet_path, 'rb') as f:
        wordnet_vocab = pkl.load(f)
    tf.logging.info('wordnet_vocab size is {}'.format(len(wordnet_vocab)))
    with open(kg_path, 'rb') as f:
        kg = pkl.load(f)
    tf.logging.info('kg size is {}'.format(len(kg)))

    tf.logging.info('Build SentenceMatchDataStream ... ')
    trainDataStream = DataStream(train_path,
                                 word_vocab=word_vocab,
                                 char_vocab=char_vocab,
                                 label_vocab=None,
                                 kg=kg,
                                 wordnet_vocab=wordnet_vocab,
                                 lemma_vocab=lemma_vocab,
                                 isShuffle=True,
                                 isLoop=True,
                                 isSort=True,
                                 options=FLAGS)
    tf.logging.info('Number of instances in trainDataStream: {}'.format(
        trainDataStream.get_num_instance()))
    tf.logging.info('Number of batches in trainDataStream: {}'.format(
        trainDataStream.get_num_batch()))
    sys.stdout.flush()

    devDataStream = DataStream(dev_path,
                               word_vocab=word_vocab,
                               char_vocab=char_vocab,
                               label_vocab=None,
                               kg=kg,
                               wordnet_vocab=wordnet_vocab,
                               lemma_vocab=lemma_vocab,
                               isShuffle=True,
                               isLoop=True,
                               isSort=True,
                               options=FLAGS)
    tf.logging.info('Number of instances in devDataStream: {}'.format(
        devDataStream.get_num_instance()))
    tf.logging.info('Number of batches in devDataStream: {}'.format(
        devDataStream.get_num_batch()))
    sys.stdout.flush()

    testDataStream = DataStream(test_path,
                                word_vocab=word_vocab,
                                char_vocab=char_vocab,
                                label_vocab=None,
                                kg=kg,
                                wordnet_vocab=wordnet_vocab,
                                lemma_vocab=lemma_vocab,
                                isShuffle=True,
                                isLoop=True,
                                isSort=True,
                                options=FLAGS)

    tf.logging.info('Number of instances in testDataStream: {}'.format(
        testDataStream.get_num_instance()))
    tf.logging.info('Number of batches in testDataStream: {}'.format(
        testDataStream.get_num_batch()))
    sys.stdout.flush()

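    # Training and validation graphs share parameters through the reused "Model" variable scope.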
    with tf.Graph().as_default():
        initializer = tf.contrib.layers.xavier_initializer()
        # initializer = tf.truncated_normal_initializer(stddev=0.02)
        global_step = tf.train.get_or_create_global_step()
        with tf.variable_scope("Model", reuse=None, initializer=initializer):
            train_graph = Model(num_classes,
                                word_vocab=word_vocab,
                                char_vocab=char_vocab,
                                lemma_vocab=lemma_vocab,
                                is_training=True,
                                options=FLAGS,
                                global_step=global_step)
        with tf.variable_scope("Model", reuse=True, initializer=initializer):
            valid_graph = Model(num_classes,
                                word_vocab=word_vocab,
                                char_vocab=char_vocab,
                                lemma_vocab=lemma_vocab,
                                is_training=False,
                                options=FLAGS)

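        # Here the saver keeps every global variable, including the embedding matrices.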
        initializer = tf.global_variables_initializer()
        vars_ = {}
        for var in tf.global_variables():
            # if "word_embedding" in var.name: continue
            # if not var.name.startswith("Model"): continue
            vars_[var.name.split(":")[0]] = var
        saver = tf.train.Saver(vars_)

        gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=1)
        config = tf.ConfigProto(allow_soft_placement=True,
                                gpu_options=gpu_options)

        config.gpu_options.allow_growth = True
        with tf.Session(config=config) as sess:
            sess.run(initializer)
            # training
            train(sess, saver, train_graph, valid_graph, trainDataStream,
                  devDataStream, testDataStream, FLAGS, best_path)